prithivMLmods commited on
Commit
83e0b46
·
verified ·
1 Parent(s): 7696916

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -110
app.py CHANGED
@@ -5,8 +5,7 @@ import time
5
  import unicodedata
6
  import gc
7
  from io import BytesIO
8
- from typing import Iterable
9
- from typing import Tuple, Optional, List, Dict, Any
10
 
11
  import gradio as gr
12
  import numpy as np
@@ -94,7 +93,6 @@ class OrangeRedTheme(Soft):
94
 
95
  orange_red_theme = OrangeRedTheme()
96
 
97
- # --- Device Setup ---
98
  device = "cuda" if torch.cuda.is_available() else "cpu"
99
  print(f"Running on device: {device}")
100
 
@@ -129,16 +127,19 @@ except Exception as e:
129
  processor_x = None
130
 
131
  print("🔄 Loading Holo2-4B...")
132
- MODEL_ID_H = "Hcompany/Holo2-4B"
133
  try:
134
  processor_h = AutoProcessor.from_pretrained(MODEL_ID_H, trust_remote_code=True)
135
  model_h = AutoModelForImageTextToText.from_pretrained(
136
  MODEL_ID_H,
137
  trust_remote_code=True,
138
  torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
139
- ).to(device).eval()
 
 
 
140
  except Exception as e:
141
- print(f"Failed to load Holo2: {e}")
142
  model_h = None
143
  processor_h = None
144
 
@@ -177,17 +178,7 @@ def apply_chat_template_compat(processor, messages: List[Dict[str, Any]]) -> str
177
  return processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
178
  if tok is not None and hasattr(tok, "apply_chat_template"):
179
  return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
180
-
181
- texts = []
182
- for m in messages:
183
- content = m.get("content", "")
184
- if isinstance(content, list):
185
- for c in content:
186
- if isinstance(c, dict) and c.get("type") == "text":
187
- texts.append(c.get("text", ""))
188
- elif isinstance(content, str):
189
- texts.append(content)
190
- return "\n".join(texts)
191
 
192
  def batch_decode_compat(processor, token_id_batches, **kw):
193
  tok = getattr(processor, "tokenizer", None)
@@ -205,7 +196,7 @@ def trim_generated(generated_ids, inputs):
205
  return generated_ids
206
  return [out_ids[len(in_seq):] for in_seq, out_ids in zip(in_ids, generated_ids)]
207
 
208
- # --- Prompts ---
209
 
210
  def get_fara_prompt(task, image):
211
  OS_SYSTEM_PROMPT = """You are a GUI agent. You are given a task and a screenshot of the current status.
@@ -237,15 +228,23 @@ def get_localization_prompt(task, image):
237
  }
238
  ]
239
 
240
- # --- Parsing Logic ---
 
 
 
 
 
 
 
 
 
 
 
241
 
242
  def parse_click_response(text: str) -> List[Dict]:
243
- """Parses standard (x,y) text responses from TARS/General VLMs"""
244
  actions = []
245
  text = text.strip()
246
 
247
- print(f"Parsing click-style output: {text}")
248
-
249
  matches_click = re.findall(r"Click\s*\(\s*(\d+)\s*,\s*(\d+)\s*\)", text, re.IGNORECASE)
250
  for m in matches_click:
251
  actions.append({"type": "click", "x": int(m[0]), "y": int(m[1]), "text": ""})
@@ -262,6 +261,7 @@ def parse_click_response(text: str) -> List[Dict]:
262
  for m in matches_tuple:
263
  actions.append({"type": "click", "x": int(m[0]), "y": int(m[1]), "text": ""})
264
 
 
265
  unique_actions = []
266
  seen = set()
267
  for a in actions:
@@ -273,7 +273,6 @@ def parse_click_response(text: str) -> List[Dict]:
273
  return unique_actions
274
 
275
  def parse_fara_response(response: str) -> List[Dict]:
276
- """Parses Fara's specific tool_call JSON format"""
277
  actions = []
278
  matches = re.findall(r"<tool_call>(.*?)</tool_call>", response, re.DOTALL)
279
  for match in matches:
@@ -292,35 +291,44 @@ def parse_fara_response(response: str) -> List[Dict]:
292
  pass
293
  return actions
294
 
295
- def parse_holo_reasoning(generated_ids: torch.Tensor, processor) -> tuple[str, str]:
296
- """Parses Holo2 content separating thought process from JSON output"""
297
  all_ids = generated_ids[0].tolist()
298
 
299
- # 151667 = <think>, 151668 = </think> for this tokenizer
300
  try:
301
  think_start_index = all_ids.index(151667)
302
  except ValueError:
303
  think_start_index = -1
304
-
305
  try:
306
  think_end_index = all_ids.index(151668)
307
  except ValueError:
308
- think_end_index = len(all_ids)
309
-
310
- thinking_content = ""
311
- if think_start_index != -1:
312
- thinking_content = processor.decode(
313
- all_ids[think_start_index+1 : think_end_index],
314
- skip_special_tokens=True
315
- ).strip("\n")
316
 
317
- # Content comes after the thinking block
318
- start_content = think_end_index + 1 if think_end_index < len(all_ids) else 0
319
- content = processor.decode(all_ids[start_content:], skip_special_tokens=True).strip("\n")
320
-
 
 
 
 
321
  return content, thinking_content
322
 
323
- # --- Visualization ---
 
 
 
 
 
 
 
 
 
 
 
 
324
 
325
  def create_localized_image(original_image: Image.Image, actions: list[dict]) -> Optional[Image.Image]:
326
  if not actions: return None
@@ -337,15 +345,18 @@ def create_localized_image(original_image: Image.Image, actions: list[dict]) ->
337
  y = act['y']
338
 
339
  pixel_x, pixel_y = int(x), int(y)
340
-
341
  color = 'red' if 'click' in act['type'].lower() else 'blue'
342
 
343
- r = 20
344
- line_width = 5
 
 
 
 
345
 
346
- # Draw target circles
347
- draw.ellipse([pixel_x - r, pixel_y - r, pixel_x + r, pixel_y + r], outline=color, width=line_width)
348
- draw.ellipse([pixel_x - 4, pixel_y - 4, pixel_x + 4, pixel_y + 4], fill=color)
349
 
350
  label = f"{act['type'].capitalize()}"
351
  if act.get('text'): label += f": \"{act['text']}\""
@@ -355,10 +366,10 @@ def create_localized_image(original_image: Image.Image, actions: list[dict]) ->
355
  try:
356
  bbox = draw.textbbox(text_pos, label, font=font)
357
  padded_bbox = (bbox[0]-4, bbox[1]-2, bbox[2]+4, bbox[3]+2)
358
- draw.rectangle(padded_bbox, fill="black", outline=color)
359
- draw.text(text_pos, label, fill="white", font=font)
360
- except Exception as e:
361
- draw.text(text_pos, label, fill="white")
362
 
363
  return img_copy
364
 
@@ -372,9 +383,9 @@ def process_screenshot(input_numpy_image: np.ndarray, task: str, model_choice: s
372
  input_pil_image = array_to_image(input_numpy_image)
373
  orig_w, orig_h = input_pil_image.size
374
  actions = []
375
- final_text_response = ""
 
376
 
377
- # --- Fara-7B ---
378
  if model_choice == "Fara-7B":
379
  if model_v is None: return "Error: Fara model failed to load on startup.", None
380
  print("Using Fara Pipeline...")
@@ -397,76 +408,63 @@ def process_screenshot(input_numpy_image: np.ndarray, task: str, model_choice: s
397
 
398
  generated_ids = trim_generated(generated_ids, inputs)
399
  raw_response = processor_v.batch_decode(generated_ids, skip_special_tokens=True)[0]
400
- final_text_response = raw_response
401
  actions = parse_fara_response(raw_response)
402
 
403
- # --- Holo2-4B ---
404
  elif model_choice == "Holo2-4B":
405
- if model_h is None: return "Error: Holo2 model failed to load.", None
406
  print("Using Holo2-4B Pipeline...")
407
 
408
- # Holo2 standard chat format
409
- messages = [
410
- {"role": "user", "content": [
411
- {"type": "image", "image": input_pil_image},
412
- {"type": "text", "text": task}
413
- ]}
414
- ]
 
 
 
415
 
416
- # Prepare inputs
417
- text_prompt = processor_h.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
418
- image_inputs, video_inputs = process_vision_info(messages)
 
 
 
 
 
 
419
 
420
  inputs = processor_h(
421
- text=[text_prompt],
422
- images=image_inputs,
423
- padding=True,
424
  return_tensors="pt"
425
- )
426
- inputs = {k: v.to(device) for k, v in inputs.items()}
427
-
428
  with torch.no_grad():
429
- # Adjust max_new_tokens to accommodate thinking process + json
430
- generated_ids = model_h.generate(**inputs, max_new_tokens=1024)
431
-
432
- # Parse Thinking vs Content
433
- generated_ids_trimmed = trim_generated(generated_ids, inputs)
434
- content, thinking = parse_holo_reasoning(generated_ids_trimmed, processor_h)
435
 
436
- final_text_response = f"💭 **Reasoning:**\n{thinking}\n\n📍 **Action Output:**\n{content}"
 
 
 
437
 
438
- # Parse JSON Coordinate {"x": int, "y": int} (0-1000 scale)
439
- try:
440
- # Clean content just in case of markdown blocks
441
- clean_content = content.replace("```json", "").replace("```", "").strip()
442
- data = json.loads(clean_content)
443
-
444
- norm_x = data.get("x", 0)
445
- norm_y = data.get("y", 0)
446
-
447
- # Convert 0-1000 scale to original image pixels
448
- pixel_x = (norm_x / 1000) * orig_w
449
- pixel_y = (norm_y / 1000) * orig_h
450
-
451
- actions.append({
452
- "type": "click",
453
- "x": int(pixel_x),
454
- "y": int(pixel_y),
455
- "text": "Target"
456
- })
457
-
458
- except json.JSONDecodeError:
459
- print(f"Failed to parse Holo2 JSON: {content}")
460
- except Exception as e:
461
- print(f"Error processing Holo2 output: {e}")
462
 
463
- # --- UI-TARS-1.5-7B ---
464
  elif model_choice == "UI-TARS-1.5-7B":
465
  if model_x is None: return "Error: UI-TARS model failed to load.", None
466
  print("Using UI-TARS Pipeline...")
467
 
468
  ip_params = get_image_proc_params(processor_x)
469
-
470
  resized_h, resized_w = smart_resize(
471
  input_pil_image.height, input_pil_image.width,
472
  factor=ip_params["patch_size"] * ip_params["merge_size"],
@@ -486,11 +484,10 @@ def process_screenshot(input_numpy_image: np.ndarray, task: str, model_choice: s
486
 
487
  generated_ids = trim_generated(generated_ids, inputs)
488
  raw_response = batch_decode_compat(processor_x, generated_ids, skip_special_tokens=True)[0]
489
- final_text_response = raw_response
490
 
491
  actions = parse_click_response(raw_response)
492
 
493
- # Scale back from resized dims to original
494
  if resized_w > 0 and resized_h > 0:
495
  scale_x = orig_w / resized_w
496
  scale_y = orig_h / resized_h
@@ -501,17 +498,17 @@ def process_screenshot(input_numpy_image: np.ndarray, task: str, model_choice: s
501
  else:
502
  return f"Error: Unknown model '{model_choice}'", None
503
 
 
504
  print(f"Parsed Actions: {actions}")
505
 
506
- # Generate visual output
507
  output_image = input_pil_image
508
  if actions:
509
  vis = create_localized_image(input_pil_image, actions)
510
  if vis: output_image = vis
511
 
512
- return final_text_response, output_image
513
 
514
- # --- Gradio UI ---
515
 
516
  css="""
517
  #col-container {
@@ -545,7 +542,7 @@ with gr.Blocks() as demo:
545
 
546
  with gr.Column(scale=3):
547
  output_image = gr.Image(label="Visualized Action Points", elem_id="out_img", height=500)
548
- output_text = gr.Textbox(label="Agent Model Response (Thinking & Action)", lines=12)
549
 
550
  submit_btn.click(
551
  fn=process_screenshot,
 
5
  import unicodedata
6
  import gc
7
  from io import BytesIO
8
+ from typing import Iterable, Tuple, Optional, List, Dict, Any
 
9
 
10
  import gradio as gr
11
  import numpy as np
 
93
 
94
  orange_red_theme = OrangeRedTheme()
95
 
 
96
  device = "cuda" if torch.cuda.is_available() else "cpu"
97
  print(f"Running on device: {device}")
98
 
 
127
  processor_x = None
128
 
129
  print("🔄 Loading Holo2-4B...")
130
+ MODEL_ID_H = "Hcompany/Holo2-4B"
131
  try:
132
  processor_h = AutoProcessor.from_pretrained(MODEL_ID_H, trust_remote_code=True)
133
  model_h = AutoModelForImageTextToText.from_pretrained(
134
  MODEL_ID_H,
135
  trust_remote_code=True,
136
  torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
137
+ device_map="auto" if device == "cuda" else None
138
+ ).eval()
139
+ if device == "cpu":
140
+ model_h = model_h.to(device)
141
  except Exception as e:
142
+ print(f"Failed to load Holo2-4B: {e}")
143
  model_h = None
144
  processor_h = None
145
 
 
178
  return processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
179
  if tok is not None and hasattr(tok, "apply_chat_template"):
180
  return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
181
+ return ""
 
 
 
 
 
 
 
 
 
 
182
 
183
  def batch_decode_compat(processor, token_id_batches, **kw):
184
  tok = getattr(processor, "tokenizer", None)
 
196
  return generated_ids
197
  return [out_ids[len(in_seq):] for in_seq, out_ids in zip(in_ids, generated_ids)]
198
 
199
+ # --- Prompt Builders ---
200
 
201
  def get_fara_prompt(task, image):
202
  OS_SYSTEM_PROMPT = """You are a GUI agent. You are given a task and a screenshot of the current status.
 
228
  }
229
  ]
230
 
231
+ def get_holo2_messages(task, image):
232
+ return [
233
+ {
234
+ "role": "user",
235
+ "content": [
236
+ {"type": "image", "image": image},
237
+ {"type": "text", "text": task}
238
+ ]
239
+ }
240
+ ]
241
+
242
+ # --- Response Parsers ---
243
 
244
  def parse_click_response(text: str) -> List[Dict]:
 
245
  actions = []
246
  text = text.strip()
247
 
 
 
248
  matches_click = re.findall(r"Click\s*\(\s*(\d+)\s*,\s*(\d+)\s*\)", text, re.IGNORECASE)
249
  for m in matches_click:
250
  actions.append({"type": "click", "x": int(m[0]), "y": int(m[1]), "text": ""})
 
261
  for m in matches_tuple:
262
  actions.append({"type": "click", "x": int(m[0]), "y": int(m[1]), "text": ""})
263
 
264
+ # Deduplicate
265
  unique_actions = []
266
  seen = set()
267
  for a in actions:
 
273
  return unique_actions
274
 
275
  def parse_fara_response(response: str) -> List[Dict]:
 
276
  actions = []
277
  matches = re.findall(r"<tool_call>(.*?)</tool_call>", response, re.DOTALL)
278
  for match in matches:
 
291
  pass
292
  return actions
293
 
294
+ def parse_holo2_reasoning(processor, generated_ids) -> tuple[str, str]:
295
+ """Parse content from generated_ids specifically for Holo2"""
296
  all_ids = generated_ids[0].tolist()
297
 
298
+ # Try to find thinking block indices
299
  try:
300
  think_start_index = all_ids.index(151667)
301
  except ValueError:
302
  think_start_index = -1
303
+
304
  try:
305
  think_end_index = all_ids.index(151668)
306
  except ValueError:
307
+ think_end_index = -1
 
 
 
 
 
 
 
308
 
309
+ if think_start_index != -1 and think_end_index != -1:
310
+ thinking_content = processor.decode(all_ids[think_start_index+1:think_end_index], skip_special_tokens=True).strip("\n")
311
+ content = processor.decode(all_ids[think_end_index+1:], skip_special_tokens=True).strip("\n")
312
+ else:
313
+ # If no thinking tags or incomplete, decode everything
314
+ thinking_content = ""
315
+ content = processor.decode(all_ids, skip_special_tokens=True).strip("\n")
316
+
317
  return content, thinking_content
318
 
319
+ def parse_holo2_json(content: str) -> List[Dict]:
320
+ actions = []
321
+ try:
322
+ # Clean potential markdown
323
+ cleaned = content.replace("```json", "").replace("```", "").strip()
324
+ data = json.loads(cleaned)
325
+ if "x" in data and "y" in data:
326
+ actions.append({"type": "click", "x": data["x"], "y": data["y"], "text": ""})
327
+ except json.JSONDecodeError:
328
+ print(f"Failed to parse Holo2 JSON: {content}")
329
+ return actions
330
+
331
+ # --- Visualizer ---
332
 
333
  def create_localized_image(original_image: Image.Image, actions: list[dict]) -> Optional[Image.Image]:
334
  if not actions: return None
 
345
  y = act['y']
346
 
347
  pixel_x, pixel_y = int(x), int(y)
 
348
  color = 'red' if 'click' in act['type'].lower() else 'blue'
349
 
350
+ # Draw Cross and Circle style (as requested by user preference)
351
+ cross_size = 20
352
+ # Horizontal line
353
+ draw.line([pixel_x - cross_size, pixel_y, pixel_x + cross_size, pixel_y], fill=color, width=4)
354
+ # Vertical line
355
+ draw.line([pixel_x, pixel_y - cross_size, pixel_x, pixel_y + cross_size], fill=color, width=4)
356
 
357
+ # Circle
358
+ r = 15
359
+ draw.ellipse([pixel_x - r, pixel_y - r, pixel_x + r, pixel_y + r], outline=color, width=3)
360
 
361
  label = f"{act['type'].capitalize()}"
362
  if act.get('text'): label += f": \"{act['text']}\""
 
366
  try:
367
  bbox = draw.textbbox(text_pos, label, font=font)
368
  padded_bbox = (bbox[0]-4, bbox[1]-2, bbox[2]+4, bbox[3]+2)
369
+ draw.rectangle(padded_bbox, fill="yellow", outline=color)
370
+ draw.text(text_pos, label, fill="black", font=font)
371
+ except Exception:
372
+ draw.text(text_pos, label, fill="black")
373
 
374
  return img_copy
375
 
 
383
  input_pil_image = array_to_image(input_numpy_image)
384
  orig_w, orig_h = input_pil_image.size
385
  actions = []
386
+ raw_response = ""
387
+ thinking_output = ""
388
 
 
389
  if model_choice == "Fara-7B":
390
  if model_v is None: return "Error: Fara model failed to load on startup.", None
391
  print("Using Fara Pipeline...")
 
408
 
409
  generated_ids = trim_generated(generated_ids, inputs)
410
  raw_response = processor_v.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
411
  actions = parse_fara_response(raw_response)
412
 
 
413
  elif model_choice == "Holo2-4B":
414
+ if model_h is None: return "Error: Holo2-4B model failed to load.", None
415
  print("Using Holo2-4B Pipeline...")
416
 
417
+ # Specific Holo2 resizing logic
418
+ ip_config = processor_h.image_processor
419
+ resized_h, resized_w = smart_resize(
420
+ input_pil_image.height,
421
+ input_pil_image.width,
422
+ factor=ip_config.patch_size * ip_config.merge_size,
423
+ min_pixels=ip_config.size.get("shortest_edge", 256*256),
424
+ max_pixels=ip_config.size.get("longest_edge", 1280*1280),
425
+ )
426
+ processed_image = input_pil_image.resize((resized_w, resized_h), Image.Resampling.LANCZOS)
427
 
428
+ messages = get_holo2_messages(task, processed_image)
429
+
430
+ # Apply template with thinking=False for localization as per documentation/snippet
431
+ text_prompt = processor_h.apply_chat_template(
432
+ messages,
433
+ tokenize=False,
434
+ add_generation_prompt=True,
435
+ thinking=False
436
+ )
437
 
438
  inputs = processor_h(
439
+ text=[text_prompt],
440
+ images=[processed_image],
441
+ padding=True,
442
  return_tensors="pt"
443
+ ).to(model_h.device)
444
+
 
445
  with torch.no_grad():
446
+ generated_ids = model_h.generate(**inputs, max_new_tokens=128)
447
+
448
+ # Parse reasoning/content
449
+ content, thinking_output = parse_holo2_reasoning(processor_h, trim_generated(generated_ids, inputs))
450
+ raw_response = content
 
451
 
452
+ if thinking_output:
453
+ raw_response = f"[Thinking Process]:\n{thinking_output}\n\n[Action]:\n{content}"
454
+
455
+ actions = parse_holo2_json(content)
456
 
457
+ # Handle Holo2 coordinate normalization (0-1000) relative to image
458
+ # Math: (x_norm / 1000) * orig_w
459
+ for a in actions:
460
+ a['x'] = (a['x'] / 1000) * orig_w
461
+ a['y'] = (a['y'] / 1000) * orig_h
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
462
 
 
463
  elif model_choice == "UI-TARS-1.5-7B":
464
  if model_x is None: return "Error: UI-TARS model failed to load.", None
465
  print("Using UI-TARS Pipeline...")
466
 
467
  ip_params = get_image_proc_params(processor_x)
 
468
  resized_h, resized_w = smart_resize(
469
  input_pil_image.height, input_pil_image.width,
470
  factor=ip_params["patch_size"] * ip_params["merge_size"],
 
484
 
485
  generated_ids = trim_generated(generated_ids, inputs)
486
  raw_response = batch_decode_compat(processor_x, generated_ids, skip_special_tokens=True)[0]
 
487
 
488
  actions = parse_click_response(raw_response)
489
 
490
+ # UI-TARS returns coordinates relative to resized image size
491
  if resized_w > 0 and resized_h > 0:
492
  scale_x = orig_w / resized_w
493
  scale_y = orig_h / resized_h
 
498
  else:
499
  return f"Error: Unknown model '{model_choice}'", None
500
 
501
+ print(f"Raw Output: {raw_response}")
502
  print(f"Parsed Actions: {actions}")
503
 
 
504
  output_image = input_pil_image
505
  if actions:
506
  vis = create_localized_image(input_pil_image, actions)
507
  if vis: output_image = vis
508
 
509
+ return raw_response, output_image
510
 
511
+ # --- UI Setup ---
512
 
513
  css="""
514
  #col-container {
 
542
 
543
  with gr.Column(scale=3):
544
  output_image = gr.Image(label="Visualized Action Points", elem_id="out_img", height=500)
545
+ output_text = gr.Textbox(label="Agent Model Response", lines=10)
546
 
547
  submit_btn.click(
548
  fn=process_screenshot,