EJ-L committed
Commit a4fdcb3 · Parent: 2287a37
Files changed (1):
  1. app.py +126 -54

app.py CHANGED
@@ -1,6 +1,22 @@
+import os
 import gradio as gr
-import spaces
-from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
+
+# --- Patch gradio_client boolean-schema bug ---
+import gradio_client.utils as gcu
+
+orig_json_schema_to_python_type = gcu._json_schema_to_python_type
+
+def _safe_json_schema_to_python_type(schema, defs):
+    # Fix: handle boolean schema values for additionalProperties
+    if isinstance(schema, bool):
+        # True → any type allowed; False → never allowed
+        return "Any" if schema else "Never"
+    return orig_json_schema_to_python_type(schema, defs)
+
+gcu._json_schema_to_python_type = _safe_json_schema_to_python_type
+# ------------------------------------------------
+
+from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
 from qwen_vl_utils import process_vision_info
 import torch
 import base64
@@ -8,52 +24,60 @@ from PIL import Image, ImageDraw
 from io import BytesIO
 import re
 
+# -------- Runtime / device --------
+# Force CPU usage
+device = "cpu"
+
+# Hugging Face Spaces port
+PORT = int(os.getenv("PORT", "7860"))
 
+# -------- Model / Processor --------
+# NOTE: device_map=None + .to(device) keeps everything on CPU
 models = {
-    "OS-Copilot/OS-Atlas-Base-7B": Qwen2VLForConditionalGeneration.from_pretrained("OS-Copilot/OS-Atlas-Base-7B", torch_dtype="auto", device_map="auto"),
+    "OS-Copilot/OS-Atlas-Base-7B": Qwen2VLForConditionalGeneration.from_pretrained(
+        "OS-Copilot/OS-Atlas-Base-7B",
+        dtype="auto",  # use 'dtype' (new) rather than deprecated 'torch_dtype'
+        device_map=None
+    ).to(device)
 }
 
 processors = {
     "OS-Copilot/OS-Atlas-Base-7B": AutoProcessor.from_pretrained("OS-Copilot/OS-Atlas-Base-7B")
 }
 
-
-def image_to_base64(image):
+# -------- Helpers --------
+def image_to_base64(image: Image.Image) -> str:
     buffered = BytesIO()
     image.save(buffered, format="PNG")
-    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
-    return img_str
+    return base64.b64encode(buffered.getvalue()).decode("utf-8")
 
-
-def draw_bounding_boxes(image, bounding_boxes, outline_color="red", line_width=2):
+def draw_bounding_boxes(image: Image.Image, bounding_boxes, outline_color="red", line_width=2):
     draw = ImageDraw.Draw(image)
-    for box in bounding_boxes:
+    for box in bounding_boxes or []:
         xmin, ymin, xmax, ymax = box
         draw.rectangle([xmin, ymin, xmax, ymax], outline=outline_color, width=line_width)
     return image
 
-
 def rescale_bounding_boxes(bounding_boxes, original_width, original_height, scaled_width=1000, scaled_height=1000):
+    if not bounding_boxes:
+        return []
     x_scale = original_width / scaled_width
     y_scale = original_height / scaled_height
-    rescaled_boxes = []
-    for box in bounding_boxes:
-        xmin, ymin, xmax, ymax = box
-        rescaled_box = [
-            xmin * x_scale,
-            ymin * y_scale,
-            xmax * x_scale,
-            ymax * y_scale
-        ]
-        rescaled_boxes.append(rescaled_box)
-    return rescaled_boxes
-
+    return [
+        [xmin * x_scale, ymin * y_scale, xmax * x_scale, ymax * y_scale]
+        for (xmin, ymin, xmax, ymax) in bounding_boxes
+    ]
 
-@spaces.GPU
+# -------- Inference --------
 def run_example(image, text_input, model_id="OS-Copilot/OS-Atlas-Base-7B"):
+    # Basic validation so the Space doesn't 500
+    if image is None or (text_input is None or str(text_input).strip() == ""):
+        return "", [], image
+
     model = models[model_id].eval()
     processor = processors[model_id]
-    prompt = f"In this UI screenshot, what is the position of the element corresponding to the command \"{text_input}\" (with bbox)?"
+
+    prompt = f'In this UI screenshot, what is the position of the element corresponding to the command "{text_input}" (with bbox)?'
     messages = [
         {
             "role": "user",
@@ -64,9 +88,8 @@ def run_example(image, text_input, model_id="OS-Copilot/OS-Atlas-Base-7B"):
         }
     ]
 
-    text = processor.apply_chat_template(
-        messages, tokenize=False, add_generation_prompt=True
-    )
+    # Build inputs
+    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
     image_inputs, video_inputs = process_vision_info(messages)
     inputs = processor(
         text=[text],
@@ -75,43 +98,70 @@
         padding=True,
         return_tensors="pt",
     )
-    inputs = inputs.to("cuda")
 
-    generated_ids = model.generate(**inputs, max_new_tokens=128)
-    generated_ids_trimmed = [
-        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-    ]
-    output_text = processor.batch_decode(
+    # Move tensors to CPU explicitly
+    inputs = {k: (v.to(device) if hasattr(v, "to") else v) for k, v in inputs.items()}
+
+    # Generate
+    with torch.no_grad():
+        generated_ids = model.generate(**inputs, max_new_tokens=128)
+
+    # Post-process
+    generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)]
+    output_texts = processor.batch_decode(
         generated_ids_trimmed, skip_special_tokens=False, clean_up_tokenization_spaces=False
     )
-    print(output_text)
-    text = output_text[0]
+    text = output_texts[0] if output_texts else ""
 
+    # Parse object_ref and bbox defensively
     object_ref_pattern = r"<\|object_ref_start\|>(.*?)<\|object_ref_end\|>"
     box_pattern = r"<\|box_start\|>(.*?)<\|box_end\|>"
 
-    object_ref = re.search(object_ref_pattern, text).group(1)
-    box_content = re.search(box_pattern, text).group(1)
+    object_match = re.search(object_ref_pattern, text or "")
+    box_match = re.search(box_pattern, text or "")
 
-    boxes = [tuple(map(int, pair.strip("()").split(','))) for pair in box_content.split("),(")]
-    boxes = [[boxes[0][0], boxes[0][1], boxes[1][0], boxes[1][1]]]
+    object_ref = object_match.group(1).strip() if object_match else ""
+    box_content = box_match.group(1).strip() if box_match else ""
 
-    scaled_boxes = rescale_bounding_boxes(boxes, image.width, image.height)
-    return object_ref, text, draw_bounding_boxes(image, scaled_boxes)
+    boxes = []
+    if box_content:
+        try:
+            # Expecting "(x1,y1),(x2,y2)" -> convert to [xmin, ymin, xmax, ymax]
+            parts = [p.strip() for p in box_content.split("),(")]
+            parts[0] = parts[0].lstrip("(")
+            parts[-1] = parts[-1].rstrip(")")
+            coords = [tuple(map(int, p.split(","))) for p in parts]
+            if len(coords) >= 2:
+                (x1, y1), (x2, y2) = coords[0], coords[1]
+                boxes = [[x1, y1, x2, y2]]
+        except Exception:
+            boxes = []
 
+    scaled_boxes = rescale_bounding_boxes(boxes, image.width, image.height) if boxes else []
+    annotated = draw_bounding_boxes(image.copy(), scaled_boxes) if scaled_boxes else image
+
+    return object_ref, scaled_boxes, annotated
+
+# -------- UI --------
 css = """
 #output {
-    height: 500px;
-    overflow: auto;
-    border: 1px solid #ccc;
+    height: 500px;
+    overflow: auto;
+    border: 1px solid #ccc;
 }
 """
-with gr.Blocks() as demo:
+
+with gr.Blocks(css=css) as demo:
+    gr.Markdown("# Demo for OS-ATLAS: A Foundation Action Model For Generalist GUI Agents")
+
     with gr.Row():
-        gr.HTML(f"<style>{css}</style>")
         with gr.Column():
             input_img = gr.Image(label="Input Image", type="pil")
-            model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value="OS-Copilot/OS-Atlas-Base-7B")
+            model_selector = gr.Dropdown(
+                choices=list(models.keys()),
+                label="Model",
+                value="OS-Copilot/OS-Atlas-Base-7B"
+            )
             text_input = gr.Textbox(label="User Prompt")
             submit_btn = gr.Button(value="Submit")
         with gr.Column():
@@ -125,12 +175,34 @@ with gr.Blocks() as demo:
             ["assets/web_6f93090a-81f6-489e-bb35-1a2838b18c01.png", "switch to discussions"],
         ],
         inputs=[input_img, text_input],
-        outputs=[model_output_text, model_output_box, annotated_image],
-        fn=run_example,
-        cache_examples=False,
-        label="Try examples"
+        # remove fn/outputs so examples only prefill inputs
+    )
+
+    submit_btn.click(
+        run_example,
+        [input_img, text_input, model_selector],
+        [model_output_text, model_output_box, annotated_image],
     )
 
-    submit_btn.click(run_example, [input_img, text_input, model_selector], [model_output_text, model_output_box, annotated_image])
+    # ---- Make Gradio/Starlette error responses small & safe (no Content-Length drama) ----
+    from fastapi import Request
+    from starlette.responses import PlainTextResponse
+
+    app = demo.app  # FastAPI app behind Gradio Blocks
+
+    @app.exception_handler(Exception)
+    async def _catch_all_exceptions(request: Request, exc: Exception):
+        # Return a very small body so Starlette/Uvicorn never miscounts bytes
+        return PlainTextResponse("Internal Server Error", status_code=500)
+    # --------------------------------------------------------------------------------------
+
 
-demo.launch(debug=True)
+# -------- Launch (Spaces-friendly) --------
+demo.queue().launch(
+    server_name="0.0.0.0",
+    server_port=PORT,
+    show_error=False,   # avoid large HTML error bodies
+    debug=False,        # avoid big pretty tracebacks (and Content-Length mismatch)
+    show_api=False      # <— key: disables /api/info schema generation
+    # api_open=False    # if your Gradio version expects the old name, use this instead of show_api
+)
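
The monkey-patch at the top of the new file works around a known gradio_client failure: API-schema introspection crashes when a JSON schema is the bare boolean True or False (as emitted for additionalProperties) rather than a dict. A minimal sketch of the same idea in isolation, assuming gradio_client still exposes the private helper _json_schema_to_python_type(schema, defs) under that name and signature (private APIs like this can change between releases):

```python
import gradio_client.utils as gcu

# Keep a reference so non-boolean schemas still use the stock implementation.
_orig = gcu._json_schema_to_python_type

def _safe(schema, defs):
    # A bare boolean is a legal JSON schema: True accepts any value,
    # False accepts none. The stock helper assumes a dict and raises.
    if isinstance(schema, bool):
        return "Any" if schema else "Never"
    return _orig(schema, defs)

gcu._json_schema_to_python_type = _safe

# Boolean schemas now map to a type name instead of raising:
print(gcu._json_schema_to_python_type(True, {}))   # -> Any
print(gcu._json_schema_to_python_type(False, {}))  # -> Never
```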
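run_example parses completions shaped like the sample below: an element description wrapped in <|object_ref_start|>/<|object_ref_end|> and two corner points wrapped in <|box_start|>/<|box_end|>. This standalone sketch uses the same regexes and splitting logic as the commit; the sample completion is invented for illustration:

```python
import re

# Invented sample in the "(x1,y1),(x2,y2)" format the parser expects.
sample = ('<|object_ref_start|>switch to discussions<|object_ref_end|>'
          '<|box_start|>(576,98),(642,122)<|box_end|>')

object_ref = re.search(r"<\|object_ref_start\|>(.*?)<\|object_ref_end\|>", sample).group(1)
box_content = re.search(r"<\|box_start\|>(.*?)<\|box_end\|>", sample).group(1)

parts = [p.strip() for p in box_content.split("),(")]
parts[0] = parts[0].lstrip("(")    # first point keeps a leading "("
parts[-1] = parts[-1].rstrip(")")  # last point keeps a trailing ")"
coords = [tuple(map(int, p.split(","))) for p in parts]

print(object_ref)  # switch to discussions
print(coords)      # [(576, 98), (642, 122)]
```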
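rescale_bounding_boxes assumes the model reports coordinates in a 1000×1000 reference frame (its scaled_width/scaled_height defaults), so each corner is multiplied by (original_width/1000, original_height/1000) to land in image pixels. A worked example with illustrative numbers:

```python
# Box from the model, in the 1000x1000 reference frame (numbers invented).
xmin, ymin, xmax, ymax = 576, 98, 642, 122
orig_w, orig_h = 1920, 1080  # hypothetical screenshot size

x_scale = orig_w / 1000  # 1.92
y_scale = orig_h / 1000  # 1.08

pixel_box = [xmin * x_scale, ymin * y_scale, xmax * x_scale, ymax * y_scale]
print(pixel_box)  # [1105.92, 105.84, 1232.64, 131.76]
```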
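The catch-all exception handler registered on demo.app can be exercised without loading the 7B model by attaching the same handler to a bare FastAPI app. A sketch under that substitution (TestClient additionally requires httpx):

```python
from fastapi import FastAPI, Request
from fastapi.testclient import TestClient
from starlette.responses import PlainTextResponse

app = FastAPI()  # stand-in for demo.app

@app.get("/boom")
def boom():
    raise RuntimeError("something broke")

@app.exception_handler(Exception)
async def catch_all(request: Request, exc: Exception):
    # Small fixed-size body: no large HTML traceback in the response.
    return PlainTextResponse("Internal Server Error", status_code=500)

client = TestClient(app, raise_server_exceptions=False)
resp = client.get("/boom")
print(resp.status_code, resp.text)  # 500 Internal Server Error
```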