prithivMLmods committed on
Commit
9c44f17
·
verified ·
1 Parent(s): 6cd9363

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -156
app.py CHANGED
@@ -5,7 +5,8 @@ import time
5
  import unicodedata
6
  import gc
7
  from io import BytesIO
8
- from typing import Iterable, Tuple, Optional, List, Dict, Any
 
9
 
10
  import gradio as gr
11
  import numpy as np
@@ -115,6 +116,7 @@ except Exception as e:
115
  print("🔄 Loading UI-TARS-1.5-7B...")
116
  MODEL_ID_X = "ByteDance-Seed/UI-TARS-1.5-7B"
117
  try:
 
118
  processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True, use_fast=False)
119
  model_x = AutoModelForImageTextToText.from_pretrained(
120
  MODEL_ID_X,
@@ -126,80 +128,16 @@ except Exception as e:
126
  model_x = None
127
  processor_x = None
128
 
129
- # --- Load Holo2-8B ---
130
- print("🔄 Loading Holo2-8B...")
131
- MODEL_ID_H = "Hcompany/Holo2-8B"
132
- try:
133
- processor_h = AutoProcessor.from_pretrained(MODEL_ID_H, trust_remote_code=True)
134
- model_h = AutoModelForImageTextToText.from_pretrained(
135
- MODEL_ID_H,
136
- trust_remote_code=True,
137
- torch_dtype=torch.float16
138
- ).to(device).eval()
139
- except Exception as e:
140
- print(f"Failed to load Holo2: {e}")
141
- model_h = None
142
- processor_h = None
143
-
144
  print("✅ Models loading sequence complete.")
145
 
146
  # -----------------------------------------------------------------------------
147
- # 3. UTILS & HELPERS
148
  # -----------------------------------------------------------------------------
149
 
150
  def array_to_image(image_array: np.ndarray) -> Image.Image:
151
  if image_array is None: raise ValueError("No image provided.")
152
  return Image.fromarray(np.uint8(image_array))
153
 
154
- # --- Compatibility Helpers ---
155
- def apply_chat_template_compat(processor, messages: List[Dict[str, Any]]) -> str:
156
- """Helper to handle chat template application across different processors"""
157
- tok = getattr(processor, "tokenizer", None)
158
- if hasattr(processor, "apply_chat_template"):
159
- return processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
160
- if tok is not None and hasattr(tok, "apply_chat_template"):
161
- return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
162
-
163
- # Fallback if no template method found
164
- texts = []
165
- for m in messages:
166
- for c in m.get("content", []):
167
- if isinstance(c, dict) and c.get("type") == "text":
168
- texts.append(c.get("text", ""))
169
- return "\n".join(texts)
170
-
171
- def batch_decode_compat(processor, token_id_batches, **kw):
172
- """Helper to handle batch decoding"""
173
- tok = getattr(processor, "tokenizer", None)
174
- if tok is not None and hasattr(tok, "batch_decode"):
175
- return tok.batch_decode(token_id_batches, **kw)
176
- if hasattr(processor, "batch_decode"):
177
- return processor.batch_decode(token_id_batches, **kw)
178
- raise AttributeError("No batch_decode available on processor or tokenizer.")
179
-
180
- def trim_generated(generated_ids, inputs):
181
- """Removes input tokens from output if necessary"""
182
- in_ids = getattr(inputs, "input_ids", None)
183
- if in_ids is None and isinstance(inputs, dict):
184
- in_ids = inputs.get("input_ids", None)
185
- if in_ids is None:
186
- return [out_ids for out_ids in generated_ids]
187
- return [out_ids[len(in_seq):] for in_seq, out_ids in zip(in_ids, generated_ids)]
188
-
189
- def get_image_proc_params(processor) -> Dict[str, int]:
190
- """Extracts resizing parameters from the processor configuration"""
191
- ip = getattr(processor, "image_processor", None)
192
- return {
193
- "patch_size": getattr(ip, "patch_size", 14),
194
- "merge_size": getattr(ip, "merge_size", 2), # Default to 2, Holo2 might differ
195
- "min_pixels": getattr(ip, "min_pixels", 256 * 256),
196
- "max_pixels": getattr(ip, "max_pixels", 1280 * 1280),
197
- }
198
-
199
- # -----------------------------------------------------------------------------
200
- # 4. PROMPTS
201
- # -----------------------------------------------------------------------------
202
-
203
  # --- Fara Prompt ---
204
  def get_fara_prompt(task, image):
205
  OS_SYSTEM_PROMPT = """You are a GUI agent. You are given a task and a screenshot of the current status.
@@ -217,6 +155,7 @@ def get_fara_prompt(task, image):
217
 
218
  # --- UI-TARS Prompt ---
219
  def get_uitars_prompt(task, image):
 
220
  guidelines = (
221
  "Localize an element on the GUI image according to my instructions and "
222
  "output a click position as Click(x, y) with x num pixels from the left edge "
@@ -232,38 +171,29 @@ def get_uitars_prompt(task, image):
232
  }
233
  ]
234
 
235
- # --- Holo2 Prompt ---
236
- def get_holo_prompt(pil_image: Image.Image, instruction: str) -> List[dict]:
237
- guidelines: str = (
238
- "Localize an element on the GUI image according to my instructions and "
239
- "output a click position as Click(x, y) with x num pixels from the left edge "
240
- "and y num pixels from the top edge."
241
- )
242
- return [
243
- {
244
- "role": "user",
245
- "content": [
246
- {"type": "image", "image": pil_image},
247
- {"type": "text", "text": f"{guidelines}\n{instruction}"}
248
- ],
249
- }
250
- ]
251
 
252
  # -----------------------------------------------------------------------------
253
- # 5. PARSING LOGIC
254
  # -----------------------------------------------------------------------------
255
 
256
- def parse_coordinate_response(text: str) -> List[Dict]:
257
- """
258
- Parses UI-TARS and Holo2 output formats.
259
- Targets formats like: Click(x, y), point=[x, y], etc.
260
- """
261
  actions = []
262
  text = text.strip()
263
 
264
- print(f"Parsing Coordinate output: {text}")
 
265
 
266
- # Regex 1: Click(x, y) - Standard prompt output for UI-TARS & Holo2
 
267
  matches_click = re.findall(r"Click\s*\(\s*(\d+)\s*,\s*(\d+)\s*\)", text, re.IGNORECASE)
268
  for m in matches_click:
269
  actions.append({"type": "click", "x": int(m[0]), "y": int(m[1]), "text": ""})
@@ -278,7 +208,7 @@ def parse_coordinate_response(text: str) -> List[Dict]:
278
  for m in matches_box:
279
  actions.append({"type": "click", "x": int(m[0]), "y": int(m[1]), "text": ""})
280
 
281
- # Remove duplicates
282
  unique_actions = []
283
  seen = set()
284
  for a in actions:
@@ -311,13 +241,23 @@ def create_localized_image(original_image: Image.Image, actions: list[dict]) ->
311
  if not actions: return None
312
  img_copy = original_image.copy()
313
  draw = ImageDraw.Draw(img_copy)
 
314
 
315
  try: font = ImageFont.load_default()
316
  except: font = None
317
 
318
  for act in actions:
319
- pixel_x = int(act['x'])
320
- pixel_y = int(act['y'])
 
 
 
 
 
 
 
 
 
321
 
322
  color = 'red' if 'click' in act['type'].lower() else 'blue'
323
 
@@ -343,7 +283,7 @@ def create_localized_image(original_image: Image.Image, actions: list[dict]) ->
343
  return img_copy
344
 
345
  # -----------------------------------------------------------------------------
346
- # 6. CORE LOGIC
347
  # -----------------------------------------------------------------------------
348
 
349
  @spaces.GPU(duration=120)
@@ -352,18 +292,14 @@ def process_screenshot(input_numpy_image: np.ndarray, task: str, model_choice: s
352
 
353
  input_pil_image = array_to_image(input_numpy_image)
354
  orig_w, orig_h = input_pil_image.size
355
-
356
- actions = []
357
- raw_response = ""
358
 
359
- # -----------------------
360
- # MODEL: UI-TARS-1.5-7B
361
- # -----------------------
362
  if model_choice == "UI-TARS-1.5-7B":
363
  if model_x is None: return "Error: UI-TARS model failed to load on startup.", None
364
  print("Using UI-TARS Pipeline...")
365
 
366
- # 1. Smart Resize
 
367
  ip_params = get_image_proc_params(processor_x)
368
  resized_h, resized_w = smart_resize(
369
  input_pil_image.height, input_pil_image.width,
@@ -372,78 +308,36 @@ def process_screenshot(input_numpy_image: np.ndarray, task: str, model_choice: s
372
  )
373
  proc_image = input_pil_image.resize((resized_w, resized_h), Image.Resampling.LANCZOS)
374
 
375
- # 2. Prompt & Inputs
376
  messages = get_uitars_prompt(task, proc_image)
377
  text_prompt = processor_x.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
 
378
  inputs = processor_x(text=[text_prompt], images=[proc_image], padding=True, return_tensors="pt")
379
  inputs = {k: v.to(device) for k, v in inputs.items()}
380
 
381
- # 3. Generate
382
  with torch.no_grad():
383
  generated_ids = model_x.generate(**inputs, max_new_tokens=128)
384
 
 
385
  generated_ids = [out_ids[len(in_seq):] for in_seq, out_ids in zip(inputs.get("input_ids"), generated_ids)]
386
  raw_response = processor_x.batch_decode(generated_ids, skip_special_tokens=True)[0]
387
 
388
- # 4. Parse & Rescale
389
- actions = parse_coordinate_response(raw_response)
390
 
391
- # Map coordinates from resized space back to original space
 
 
392
  scale_x = orig_w / resized_w
393
  scale_y = orig_h / resized_h
394
- for a in actions:
395
- a['x'] = int(a['x'] * scale_x)
396
- a['y'] = int(a['y'] * scale_y)
397
-
398
- # -----------------------
399
- # MODEL: Holo2-8B
400
- # -----------------------
401
- elif model_choice == "Holo2-8B":
402
- if model_h is None: return "Error: Holo2 model failed to load on startup.", None
403
- print("Using Holo2 Pipeline...")
404
-
405
- # 1. Smart Resize (Holo2 typically uses merge_size=1 or similar logic)
406
- ip_params = get_image_proc_params(processor_h)
407
- # Force merge_size to 1 if not detected (as per common practice for this model architecture variant)
408
- ms = ip_params.get("merge_size", 1)
409
 
410
- resized_h, resized_w = smart_resize(
411
- input_pil_image.height, input_pil_image.width,
412
- factor=ip_params["patch_size"] * ms,
413
- min_pixels=ip_params["min_pixels"], max_pixels=ip_params["max_pixels"]
414
- )
415
- proc_image = input_pil_image.resize((resized_w, resized_h), Image.Resampling.LANCZOS)
416
-
417
- # 2. Prompt & Inputs
418
- messages = get_holo_prompt(proc_image, task)
419
- text_prompt = apply_chat_template_compat(processor_h, messages)
420
-
421
- # Holo2 / Qwen2-VL based inputs
422
- inputs = processor_h(text=[text_prompt], images=[proc_image], padding=True, return_tensors="pt")
423
- inputs = {k: v.to(device) for k, v in inputs.items()}
424
-
425
- # 3. Generate
426
- with torch.no_grad():
427
- generated_ids = model_h.generate(**inputs, max_new_tokens=128)
428
-
429
- # Trim input tokens
430
- generated_ids_trimmed = trim_generated(generated_ids, inputs)
431
- raw_response = batch_decode_compat(processor_h, generated_ids_trimmed, skip_special_tokens=True)[0]
432
-
433
- # 4. Parse & Rescale
434
- # Holo2 prompt asks for Click(x,y) similar to UI-TARS
435
- actions = parse_coordinate_response(raw_response)
436
-
437
- # Map coordinates from resized space back to original space
438
- scale_x = orig_w / resized_w
439
- scale_y = orig_h / resized_h
440
  for a in actions:
441
  a['x'] = int(a['x'] * scale_x)
442
  a['y'] = int(a['y'] * scale_y)
443
 
444
- # -----------------------
445
- # MODEL: Fara-7B
446
- # -----------------------
447
  else:
448
  if model_v is None: return "Error: Fara model failed to load on startup.", None
449
  print("Using Fara Pipeline...")
@@ -481,7 +375,7 @@ def process_screenshot(input_numpy_image: np.ndarray, task: str, model_choice: s
481
  return raw_response, output_image
482
 
483
  # -----------------------------------------------------------------------------
484
- # 7. UI SETUP
485
  # -----------------------------------------------------------------------------
486
 
487
  with gr.Blocks(theme=steel_blue_theme, css=css) as demo:
@@ -494,7 +388,7 @@ with gr.Blocks(theme=steel_blue_theme, css=css) as demo:
494
 
495
  with gr.Row():
496
  model_choice = gr.Radio(
497
- choices=["Fara-7B", "UI-TARS-1.5-7B", "Holo2-8B"],
498
  label="Select Model",
499
  value="Fara-7B",
500
  interactive=True
 
5
  import unicodedata
6
  import gc
7
  from io import BytesIO
8
+ from typing import Iterable
9
+ from typing import Tuple, Optional, List, Dict, Any
10
 
11
  import gradio as gr
12
  import numpy as np
 
116
  print("🔄 Loading UI-TARS-1.5-7B...")
117
  MODEL_ID_X = "ByteDance-Seed/UI-TARS-1.5-7B"
118
  try:
119
+ # Important: use_fast=False is often required for custom tokenizers
120
  processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True, use_fast=False)
121
  model_x = AutoModelForImageTextToText.from_pretrained(
122
  MODEL_ID_X,
 
128
  model_x = None
129
  processor_x = None
130
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  print("✅ Models loading sequence complete.")
132
 
133
  # -----------------------------------------------------------------------------
134
+ # 3. UTILS & PROMPTS
135
  # -----------------------------------------------------------------------------
136
 
137
  def array_to_image(image_array: np.ndarray) -> Image.Image:
138
  if image_array is None: raise ValueError("No image provided.")
139
  return Image.fromarray(np.uint8(image_array))
140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  # --- Fara Prompt ---
142
  def get_fara_prompt(task, image):
143
  OS_SYSTEM_PROMPT = """You are a GUI agent. You are given a task and a screenshot of the current status.
 
155
 
156
  # --- UI-TARS Prompt ---
157
  def get_uitars_prompt(task, image):
158
+ # UI-TARS generally responds better to a simpler instruction when finetuned
159
  guidelines = (
160
  "Localize an element on the GUI image according to my instructions and "
161
  "output a click position as Click(x, y) with x num pixels from the left edge "
 
171
  }
172
  ]
173
 
174
+ def get_image_proc_params(processor) -> Dict[str, int]:
175
+ ip = getattr(processor, "image_processor", None)
176
+ return {
177
+ "patch_size": getattr(ip, "patch_size", 14),
178
+ "merge_size": getattr(ip, "merge_size", 2),
179
+ "min_pixels": getattr(ip, "min_pixels", 256 * 256),
180
+ "max_pixels": getattr(ip, "max_pixels", 1280 * 1280),
181
+ }
 
 
 
 
 
 
 
 
182
 
183
  # -----------------------------------------------------------------------------
184
+ # 4. PARSING LOGIC
185
  # -----------------------------------------------------------------------------
186
 
187
+ def parse_uitars_response(text: str) -> List[Dict]:
188
+ """Parse various UI-TARS output formats"""
 
 
 
189
  actions = []
190
  text = text.strip()
191
 
192
+ # Debug print
193
+ print(f"Parsing UI-TARS output: {text}")
194
 
195
+ # Regex 1: Click(x, y) - Standard prompt output
196
+ # Matches: Click(123, 456) or Click(123,456)
197
  matches_click = re.findall(r"Click\s*\(\s*(\d+)\s*,\s*(\d+)\s*\)", text, re.IGNORECASE)
198
  for m in matches_click:
199
  actions.append({"type": "click", "x": int(m[0]), "y": int(m[1]), "text": ""})
 
208
  for m in matches_box:
209
  actions.append({"type": "click", "x": int(m[0]), "y": int(m[1]), "text": ""})
210
 
211
+ # Remove duplicates if any logic matched multiple times
212
  unique_actions = []
213
  seen = set()
214
  for a in actions:
 
241
  if not actions: return None
242
  img_copy = original_image.copy()
243
  draw = ImageDraw.Draw(img_copy)
244
+ width, height = img_copy.size
245
 
246
  try: font = ImageFont.load_default()
247
  except: font = None
248
 
249
  for act in actions:
250
+ x = act['x']
251
+ y = act['y']
252
+
253
+ # Determine if we need to scale normalized coords (0-1) or use absolute
254
+ # UI-TARS usually outputs absolute pixels relative to the image size it saw.
255
+ # But we already scaled them in the main loop.
256
+ # Double check sanity:
257
+ if x < 1.0 and y < 1.0:
258
+ pixel_x, pixel_y = int(x * width), int(y * height)
259
+ else:
260
+ pixel_x, pixel_y = int(x), int(y)
261
 
262
  color = 'red' if 'click' in act['type'].lower() else 'blue'
263
 
 
283
  return img_copy
284
 
285
  # -----------------------------------------------------------------------------
286
+ # 5. CORE LOGIC
287
  # -----------------------------------------------------------------------------
288
 
289
  @spaces.GPU(duration=120)
 
292
 
293
  input_pil_image = array_to_image(input_numpy_image)
294
  orig_w, orig_h = input_pil_image.size
 
 
 
295
 
296
+ # --- UI-TARS Logic ---
 
 
297
  if model_choice == "UI-TARS-1.5-7B":
298
  if model_x is None: return "Error: UI-TARS model failed to load on startup.", None
299
  print("Using UI-TARS Pipeline...")
300
 
301
+ # 1. Smart Resize (Crucial for UI-TARS accuracy)
302
+ # We must resize the image to the resolution the model expects/handles best
303
  ip_params = get_image_proc_params(processor_x)
304
  resized_h, resized_w = smart_resize(
305
  input_pil_image.height, input_pil_image.width,
 
308
  )
309
  proc_image = input_pil_image.resize((resized_w, resized_h), Image.Resampling.LANCZOS)
310
 
311
+ # 2. Prompting
312
  messages = get_uitars_prompt(task, proc_image)
313
  text_prompt = processor_x.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
314
+
315
+ # 3. Inputs
316
  inputs = processor_x(text=[text_prompt], images=[proc_image], padding=True, return_tensors="pt")
317
  inputs = {k: v.to(device) for k, v in inputs.items()}
318
 
319
+ # 4. Generate
320
  with torch.no_grad():
321
  generated_ids = model_x.generate(**inputs, max_new_tokens=128)
322
 
323
+ # Decode
324
  generated_ids = [out_ids[len(in_seq):] for in_seq, out_ids in zip(inputs.get("input_ids"), generated_ids)]
325
  raw_response = processor_x.batch_decode(generated_ids, skip_special_tokens=True)[0]
326
 
327
+ # 5. Parse
328
+ actions = parse_uitars_response(raw_response)
329
 
330
+ # 6. Rescale Coordinates back to Original Image Size
331
+ # The model saw 'resized_w' x 'resized_h', so coordinates are in that space.
332
+ # We need to map them back to 'orig_w' x 'orig_h' for the visualizer.
333
  scale_x = orig_w / resized_w
334
  scale_y = orig_h / resized_h
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
335
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
  for a in actions:
337
  a['x'] = int(a['x'] * scale_x)
338
  a['y'] = int(a['y'] * scale_y)
339
 
340
+ # --- Fara Logic ---
 
 
341
  else:
342
  if model_v is None: return "Error: Fara model failed to load on startup.", None
343
  print("Using Fara Pipeline...")
 
375
  return raw_response, output_image
376
 
377
  # -----------------------------------------------------------------------------
378
+ # 6. UI SETUP
379
  # -----------------------------------------------------------------------------
380
 
381
  with gr.Blocks(theme=steel_blue_theme, css=css) as demo:
 
388
 
389
  with gr.Row():
390
  model_choice = gr.Radio(
391
+ choices=["Fara-7B", "UI-TARS-1.5-7B"],
392
  label="Select Model",
393
  value="Fara-7B",
394
  interactive=True