| import gradio as gr |
| import json, os, re, traceback, contextlib, math, random |
| from typing import Any, List, Dict, Optional, Tuple |
|
|
| import spaces |
| import torch |
| from PIL import Image, ImageDraw |
| import requests |
| from transformers import AutoModelForImageTextToText, AutoProcessor |
| from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize |
|
|
| |
| MODEL_ID = "Hcompany/Holo1-3B" |
|
|
| |
|
|
def pick_device() -> str:
    """
    Choose the runtime device string: "cuda", "mps", or "cpu".

    On HF Spaces (ZeroGPU), CUDA is only available inside @spaces.GPU calls.
    The FORCE_DEVICE environment variable still overrides detection for
    local testing.
    """
    override = os.getenv("FORCE_DEVICE", "").lower().strip()
    if override in ("cpu", "cuda", "mps"):
        return override
    if torch.cuda.is_available():
        return "cuda"
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend and mps_backend.is_available():
        return "mps"
    return "cpu"
|
|
def pick_dtype(device: str) -> torch.dtype:
    """Pick a sensible compute dtype for *device*.

    CUDA gets bf16 on Ampere (SM 8.x) or newer and fp16 otherwise; MPS
    gets fp16; everything else runs in full fp32.
    """
    if device == "cuda":
        cc_major, _cc_minor = torch.cuda.get_device_capability()
        return torch.bfloat16 if cc_major >= 8 else torch.float16
    return torch.float16 if device == "mps" else torch.float32
|
|
def move_to_device(batch, device: str):
    """Move a tensor-like object (or each tensor value of a dict) to *device*.

    Values without a ``.to`` method are passed through untouched, so mixed
    dicts (tensors plus plain Python values) are safe.
    """
    if isinstance(batch, dict):
        moved = {}
        for key, value in batch.items():
            moved[key] = value.to(device, non_blocking=True) if hasattr(value, "to") else value
        return moved
    if hasattr(batch, "to"):
        return batch.to(device, non_blocking=True)
    return batch
|
|
| |
def apply_chat_template_compat(processor, messages: List[Dict[str, Any]]) -> str:
    """Render chat *messages* into a prompt string across processor variants.

    Preference order: ``processor.apply_chat_template``, then
    ``processor.tokenizer.apply_chat_template``, then a plain newline join of
    all text fragments as a last-resort fallback.
    """
    if hasattr(processor, "apply_chat_template"):
        return processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    tokenizer = getattr(processor, "tokenizer", None)
    if tokenizer is not None and hasattr(tokenizer, "apply_chat_template"):
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # Fallback: concatenate every text part in message order.
    fragments = [
        part.get("text", "")
        for message in messages
        for part in message.get("content", [])
        if isinstance(part, dict) and part.get("type") == "text"
    ]
    return "\n".join(fragments)
|
|
def batch_decode_compat(processor, token_id_batches, **kw):
    """Decode token-id batches, preferring the tokenizer's ``batch_decode``.

    Falls back to the processor's own ``batch_decode``; raises
    AttributeError when neither object provides one.
    """
    tokenizer = getattr(processor, "tokenizer", None)
    if tokenizer is not None and hasattr(tokenizer, "batch_decode"):
        return tokenizer.batch_decode(token_id_batches, **kw)
    if hasattr(processor, "batch_decode"):
        return processor.batch_decode(token_id_batches, **kw)
    raise AttributeError("No batch_decode available on processor or tokenizer.")
|
|
def get_image_proc_params(processor) -> Dict[str, int]:
    """Return image-processor geometry params with Qwen2-VL-style defaults.

    Missing attributes (or a missing image_processor entirely) fall back to
    patch_size=14, merge_size=1, min_pixels=256*256, max_pixels=1280*1280.
    """
    image_processor = getattr(processor, "image_processor", None)
    defaults = {
        "patch_size": 14,
        "merge_size": 1,
        "min_pixels": 256 * 256,
        "max_pixels": 1280 * 1280,
    }
    return {name: getattr(image_processor, name, fallback) for name, fallback in defaults.items()}
|
|
def trim_generated(generated_ids, inputs):
    """Strip the prompt tokens from each generated sequence.

    The prompt length is read from ``inputs.input_ids`` (attribute or dict
    key); when it cannot be found, the sequences are returned as-is in a
    fresh list.
    """
    prompt_ids = getattr(inputs, "input_ids", None)
    if prompt_ids is None and isinstance(inputs, dict):
        prompt_ids = inputs.get("input_ids", None)
    if prompt_ids is None:
        return list(generated_ids)
    return [full_seq[len(prompt_seq):] for prompt_seq, full_seq in zip(prompt_ids, generated_ids)]
|
|
| |
# Module-level model state. The model is loaded once at import time on CPU —
# on ZeroGPU, CUDA only exists inside @spaces.GPU calls, so the move to GPU
# happens lazily in predict_click_location(). On failure the app stays alive
# and the error message is surfaced in the UI instead.
print(f"Loading model and processor for {MODEL_ID} on CPU startup (ZeroGPU safe)...")
model = None            # AutoModelForImageTextToText once loaded, else None
processor = None        # AutoProcessor once loaded, else None
model_loaded = False    # gates both the UI and the inference handler
load_error_message = "" # shown to the user when loading fails


try:
    # fp32 on CPU at startup; device/dtype are switched per-request later.
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,
        trust_remote_code=True,
    )
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
    model.eval()
    model_loaded = True
    print("Model and processor loaded on CPU.")
except Exception as e:
    # Keep a human-readable summary for the error page; the full traceback
    # goes to the logs below.
    load_error_message = (
        f"Error loading model/processor: {e}\n"
        "This might be due to network/model ID/library versions.\n"
        "Check the full traceback in the logs."
    )
    print(load_error_message)
    traceback.print_exc()
|
|
| |
def get_localization_prompt(pil_image: Image.Image, instruction: str) -> List[dict]:
    """Build the single-turn chat message asking the model to localize a UI element.

    Returns the chat-template structure the processor expects: one user turn
    holding the image followed by the guideline text plus *instruction*.
    """
    guidelines: str = (
        "Localize an element on the GUI image according to my instructions and "
        "output a click position as Click(x, y) with x num pixels from the left edge "
        "and y num pixels from the top edge."
    )
    content = [
        {"type": "image", "image": pil_image},
        {"type": "text", "text": f"{guidelines}\n{instruction}"},
    ]
    return [{"role": "user", "content": content}]
|
|
| |
@torch.inference_mode()
def run_inference_localization(
    messages_for_template: List[dict[str, Any]],
    pil_image_for_processing: Image.Image,
    device: str,
    dtype: torch.dtype,
    do_sample: bool = False,
    temperature: float = 0.6,
    top_p: float = 0.9,
    max_new_tokens: int = 128,
) -> str:
    """Run one generation pass and return the decoded, prompt-stripped text.

    Uses the module-level ``model``/``processor``. Autocast mixed precision is
    enabled on cuda/mps; CPU runs in plain fp32 via a null context.
    """
    prompt_text = apply_chat_template_compat(processor, messages_for_template)

    model_inputs = processor(
        text=[prompt_text],
        images=[pil_image_for_processing],
        padding=True,
        return_tensors="pt",
    )
    model_inputs = move_to_device(model_inputs, device)

    # Mixed precision only on accelerators; nullcontext keeps CPU untouched.
    if device == "cuda":
        autocast_ctx = torch.autocast(device_type="cuda", dtype=dtype)
    elif device == "mps":
        autocast_ctx = torch.autocast(device_type="mps", dtype=torch.float16)
    else:
        autocast_ctx = contextlib.nullcontext()

    with autocast_ctx:
        generated_ids = model.generate(
            **model_inputs,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
        )

    # Drop the prompt tokens, then decode only the newly generated part.
    trimmed_ids = trim_generated(generated_ids, model_inputs)
    decoded = batch_decode_compat(
        processor,
        trimmed_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return decoded[0] if decoded else ""
|
|
| |
# Matches the model's click output. Whitespace inside the parentheses is
# tolerated (e.g. "Click( 10 , 20 )") so minor formatting drift in sampled
# outputs is not miscounted as a parse failure; the canonical "Click(x, y)"
# form still matches unchanged.
CLICK_RE = re.compile(r"Click\(\s*(\d+)\s*,\s*(\d+)\s*\)")


def parse_click(s: str) -> Optional[Tuple[int, int]]:
    """Extract the first Click(x, y) from *s*.

    Returns (x, y) as ints, or None when no well-formed click is present.
    """
    m = CLICK_RE.search(s)
    if not m:
        return None
    try:
        return int(m.group(1)), int(m.group(2))
    except Exception:
        # Defensive: int() on matched digit groups should never fail.
        return None
|
|
@torch.inference_mode()
def sample_clicks(
    messages: List[dict],
    img: Image.Image,
    device: str,
    dtype: torch.dtype,
    n_samples: int = 7,
    temperature: float = 0.6,
    top_p: float = 0.9,
    seed: Optional[int] = None,
) -> List[Optional[Tuple[int, int]]]:
    """
    Run multiple stochastic decodes to estimate self-consistency.
    Returns a list of (x,y) or None (if parsing failed) for each sample.
    """
    results: List[Optional[Tuple[int, int]]] = []

    if seed is not None:
        torch.manual_seed(seed)
        random.seed(seed)
    for sample_idx in range(n_samples):
        # Re-seed per sample so runs are reproducible yet samples differ.
        if seed is not None:
            per_sample_seed = seed + sample_idx + 1
            torch.manual_seed(per_sample_seed)
            random.seed(per_sample_seed & 0xFFFFFFFF)
        raw_output = run_inference_localization(
            messages, img, device, dtype,
            do_sample=True, temperature=temperature, top_p=top_p,
        )
        results.append(parse_click(raw_output))
    return results
|
|
def cluster_and_confidence(
    clicks: List[Optional[Tuple[int, int]]],
    img_w: int,
    img_h: int,
) -> Dict[str, Any]:
    """
    Robust consensus over sampled click points.

    Pipeline: discard unparsed samples, take the coordinate-wise integer
    median, classify points against an inlier radius of
    max(8 px, 2% of min(img_w, img_h)), and score
    confidence = (inliers / total_samples) * clamp(1 - rms_inlier_dist / thr, 0, 1).
    Returns a dict with the consensus point, confidence, dispersion and
    counts; ok=False (with a reason) when there is nothing usable.
    """
    total = len(clicks)
    if total == 0:
        return dict(ok=False, reason="no_samples")

    valid = [pt for pt in clicks if pt is not None]
    if not valid:
        return dict(ok=False, reason="no_valid_points", total=total)

    # Coordinate-wise integer median, computed independently in x and y.
    xs = sorted(x for x, _ in valid)
    ys = sorted(y for _, y in valid)
    half = len(valid) // 2
    if len(valid) % 2 == 1:
        x_med, y_med = xs[half], ys[half]
    else:
        x_med = (xs[half - 1] + xs[half]) // 2
        y_med = (ys[half - 1] + ys[half]) // 2

    # Inlier radius: never below 8 px, scales with the short image side.
    thr = max(8.0, 0.02 * min(img_w, img_h))
    tagged = [((x, y), math.hypot(x - x_med, y - y_med)) for (x, y) in valid]
    inlier_pairs = [pair for pair in tagged if pair[1] <= thr]
    outlier_pairs = [pair for pair in tagged if pair[1] > thr]

    # RMS distance of the inliers measures cluster tightness ("sharpness").
    if inlier_pairs:
        rms = math.sqrt(sum(d * d for _, d in inlier_pairs) / len(inlier_pairs))
        sharp = max(0.0, min(1.0, 1.0 - (rms / thr)))
    else:
        rms = float("inf")
        sharp = 0.0
    confidence = (len(inlier_pairs) / total) * sharp

    return dict(
        ok=True,
        x=x_med, y=y_med,
        confidence=confidence,
        total_samples=total,
        valid_samples=len(valid),
        inliers=len(inlier_pairs),
        outliers=len(outlier_pairs),
        sigma_px=rms if math.isfinite(rms) else None,
        inlier_threshold_px=thr,
        all_points=valid,
        inlier_points=[pt for pt, _ in inlier_pairs],
        outlier_points=[pt for pt, _ in outlier_pairs],
    )
|
|
def draw_samples(
    base_img: Image.Image,
    consensus_xy: Optional[Tuple[int, int]],
    inliers: List[Tuple[int, int]],
    outliers: List[Tuple[int, int]],
    ring_color: str = "red",
) -> Image.Image:
    """
    Overlay all sampled points: green=inliers, red=outliers, plus a ring for consensus.
    """
    canvas = base_img.copy().convert("RGB")
    pen = ImageDraw.Draw(canvas)
    width, height = canvas.size
    # Dot radius scales with image size, never below 3 px.
    dot_r = max(3, min(width, height) // 200)

    # Inliers first (green), then outliers (red), so colors match the legend.
    for colour, points in (("green", inliers), ("red", outliers)):
        for x, y in points:
            pen.ellipse((x - dot_r, y - dot_r, x + dot_r, y + dot_r), fill=colour, outline=None)

    # Ring the consensus click last so it sits on top of the dots.
    if consensus_xy is not None:
        cx, cy = consensus_xy
        ring_r = max(5, min(width, height) // 100, dot_r * 3)
        pen.ellipse(
            (cx - ring_r, cy - ring_r, cx + ring_r, cy + ring_r),
            outline=ring_color,
            width=max(2, ring_r // 4),
        )
    return canvas
|
|
| |
| |
@spaces.GPU(duration=120)
def predict_click_location(
    input_pil_image: Image.Image,
    instruction: str,
    estimate_confidence: bool = True,
    num_samples: int = 7,
    temperature: float = 0.6,
    top_p: float = 0.9,
    seed: Optional[int] = None,
):
    """Gradio handler: localize *instruction* on the screenshot.

    Returns a 3-tuple matching the UI outputs:
    (coordinates+diagnostics string, annotated PIL image or None, runtime-info string).

    Runs inside @spaces.GPU, so on ZeroGPU this is the only place CUDA is
    visible. With estimate_confidence and num_samples >= 3 it draws several
    stochastic samples and reports a consensus click plus a confidence score;
    otherwise it does a single greedy decode.
    """
    # --- Input validation: bail out early with user-facing messages. ---
    if not model_loaded or not processor or not model:
        return f"Model not loaded. Error: {load_error_message}", None, "device: n/a | dtype: n/a"
    if not input_pil_image:
        return "No image provided. Please upload an image.", None, "device: n/a | dtype: n/a"
    if not instruction or instruction.strip() == "":
        return "No instruction provided. Please type an instruction.", input_pil_image.copy().convert("RGB"), "device: n/a | dtype: n/a"

    # --- Device/dtype selection (must happen inside the GPU-decorated call). ---
    device = pick_device()
    dtype = pick_dtype(device)

    # TF32 matmuls for extra throughput on Ampere+ GPUs.
    if device == "cuda":
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.set_float32_matmul_precision("high")

    # Inspect where the model currently lives (it was loaded on CPU/fp32).
    try:
        p = next(model.parameters())
        cur_dev = p.device.type
        cur_dtype = p.dtype
    except StopIteration:
        # No parameters at all: assume the startup CPU/fp32 state.
        cur_dev, cur_dtype = "cpu", torch.float32

    # Lazily move/cast the model only when the target differs.
    if cur_dev != device or cur_dtype != dtype:
        model.to(device=device, dtype=dtype)
        model.eval()

    # --- Resize the screenshot to the processor's pixel budget. ---
    try:
        ip = get_image_proc_params(processor)
        # smart_resize keeps dimensions divisible by patch_size * merge_size.
        resized_height, resized_width = smart_resize(
            input_pil_image.height,
            input_pil_image.width,
            factor=ip["patch_size"] * ip["merge_size"],
            min_pixels=ip["min_pixels"],
            max_pixels=ip["max_pixels"],
        )
        resized_image = input_pil_image.resize(
            size=(resized_width, resized_height),
            resample=Image.Resampling.LANCZOS
        )
    except Exception as e:
        traceback.print_exc()
        return f"Error resizing image: {e}", input_pil_image.copy().convert("RGB"), f"device: {device} | dtype: {dtype}"

    # Chat messages for the localization prompt (coordinates are in the
    # resized image's pixel space, which is also what we draw on).
    messages = get_localization_prompt(resized_image, instruction)

    try:
        if estimate_confidence and num_samples >= 3:
            # Multi-sample path: stochastic decodes -> robust consensus.
            clicks = sample_clicks(
                messages, resized_image, device, dtype,
                n_samples=int(num_samples),
                temperature=float(temperature),
                top_p=float(top_p),
                seed=seed
            )
            summary = cluster_and_confidence(clicks, resized_image.width, resized_image.height)

            # If no sample parsed, fall back to one greedy decode.
            if not summary.get("ok", False):
                coord_str = run_inference_localization(messages, resized_image, device, dtype, do_sample=False)
                out_img = resized_image.copy().convert("RGB")
                match = CLICK_RE.search(coord_str or "")
                if match:
                    x, y = int(match.group(1)), int(match.group(2))
                    out_img = draw_samples(out_img, (x, y), [], [])
                coords_text = f"{coord_str} | confidence=0.00 (fallback)"
                return coords_text, out_img, f"device: {device} | dtype: {str(dtype).replace('torch.', '')}"

            # Unpack the consensus summary for display.
            x, y = int(summary["x"]), int(summary["y"])
            conf = summary["confidence"]
            inliers = summary["inlier_points"]
            outliers = summary["outlier_points"]
            sigma = summary["sigma_px"]
            thr = summary["inlier_threshold_px"]
            total = summary["total_samples"]
            valid = summary["valid_samples"]

            # Diagnostics line shown next to the predicted coordinates.
            coord_str = f"Click({x}, {y})"
            diag = (
                f"confidence={conf:.2f} | samples(valid/total)={valid}/{total} | "
                f"inliers={len(inliers)} | σ={sigma:.1f}px | thr={thr:.1f}px | "
                f"T={temperature:.2f}, p={top_p:.2f}"
            )

            # Annotate every sample plus the consensus ring.
            out_img = draw_samples(resized_image, (x, y), inliers, outliers)
            return f"{coord_str} | {diag}", out_img, f"device: {device} | dtype: {str(dtype).replace('torch.', '')}"

        else:
            # Single-shot path: one greedy decode, no confidence estimate.
            coord_str = run_inference_localization(messages, resized_image, device, dtype, do_sample=False)
            out_img = resized_image.copy().convert("RGB")
            match = CLICK_RE.search(coord_str or "")
            if match:
                x = int(match.group(1))
                y = int(match.group(2))
                # Reuse draw_samples for the consensus ring only.
                out_img = draw_samples(out_img, (x, y), [], [])
            else:
                print(f"Could not parse 'Click(x, y)' from model output: {coord_str}")
            return coord_str, out_img, f"device: {device} | dtype: {str(dtype).replace('torch.', '')}"

    except Exception as e:
        traceback.print_exc()
        return f"Error during model inference: {e}", resized_image.copy().convert("RGB"), f"device: {device} | dtype: {dtype}"
|
|
| |
# Example assets for gr.Examples: best-effort download of a demo screenshot,
# with a generated placeholder so the UI still renders offline.
example_image = None
example_instruction = "Enter the server address readyforquantum.com to check its security"
try:
    example_image_url = "https://readyforquantum.com/img/screentest.jpg"
    # timeout is essential: without it an unreachable host would hang app
    # startup indefinitely (requests has no default timeout).
    example_image = Image.open(requests.get(example_image_url, stream=True, timeout=10).raw)
except Exception as e:
    print(f"Could not load example image from URL: {e}")
    traceback.print_exc()
    try:
        # Placeholder so gr.Examples still has something to show.
        example_image = Image.new("RGB", (200, 150), color="lightgray")
        draw = ImageDraw.Draw(example_image)
        draw.text((10, 10), "Example image\nfailed to load", fill="black")
    except Exception:
        pass
|
|
| |
# Static UI copy: page title and the HTML blurb (model card, paper, blog
# links) rendered under it.
title = "Holo1-3B: Holo1 Localization Demo (ZeroGPU-ready)"
article = f"""
<p style='text-align: center'>
Model: <a href='https://huggingface.co/{MODEL_ID}' target='_blank'>{MODEL_ID}</a> by HCompany |
Paper: <a href='https://cdn.prod.website-files.com/67e2dbd9acff0c50d4c8a80c/683ec8095b353e8b38317f80_h_tech_report_v1.pdf' target='_blank'>HCompany Tech Report</a> |
Blog: <a href='https://www.hcompany.ai/surfer-h' target='_blank'>Surfer-H Blog Post</a><br/>
<small>GPU (if available) is requested only during inference via @spaces.GPU.</small>
</p>
"""
|
|
# --- Gradio UI -----------------------------------------------------------
# When model loading failed at import time, show an error-only page instead
# of the interactive demo.
if not model_loaded:
    with gr.Blocks() as demo:
        gr.Markdown(f"# <center>⚠️ Error: Model Failed to Load ⚠️</center>")
        gr.Markdown(f"<center>{load_error_message}</center>")
        gr.Markdown("<center>See logs for the full traceback.</center>")
else:
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown(f"<h1 style='text-align: center;'>{title}</h1>")
        gr.Markdown(article)

        with gr.Row():
            # Left column: inputs plus the sampling controls that feed
            # predict_click_location's confidence-estimation path.
            with gr.Column(scale=1):
                input_image_component = gr.Image(type="pil", label="Input UI Image", height=400)
                instruction_component = gr.Textbox(
                    label="Instruction",
                    placeholder="e.g., Click the 'Login' button",
                    info="Type the action you want the model to localize on the image."
                )
                estimate_conf = gr.Checkbox(value=True, label="Estimate confidence (slower)")
                num_samples_slider = gr.Slider(3, 15, value=7, step=1, label="Samples (for confidence)")
                temperature_slider = gr.Slider(0.2, 1.2, value=0.6, step=0.05, label="Temperature")
                top_p_slider = gr.Slider(0.5, 0.99, value=0.9, step=0.01, label="Top-p")
                # precision=0 makes Gradio deliver an int (or None) for the seed.
                seed_box = gr.Number(value=None, precision=0, label="Seed (optional, for reproducibility)")
                submit_button = gr.Button("Localize Click", variant="primary")

            # Right column: the three outputs returned by the handler.
            with gr.Column(scale=1):
                output_coords_component = gr.Textbox(
                    label="Predicted Coordinates + Confidence",
                    interactive=False
                )
                output_image_component = gr.Image(
                    type="pil",
                    label="Image with Samples (green=inliers, red=outliers) and Final Ring",
                    height=400,
                    interactive=False
                )
                runtime_info = gr.Textbox(
                    label="Runtime Info",
                    value="device: n/a | dtype: n/a",
                    interactive=False
                )

        # Pre-filled example row, only when the example image was obtained.
        if example_image:
            gr.Examples(
                examples=[[example_image, example_instruction, True, 7, 0.6, 0.9, None]],
                inputs=[
                    input_image_component,
                    instruction_component,
                    estimate_conf,
                    num_samples_slider,
                    temperature_slider,
                    top_p_slider,
                    seed_box,
                ],
                outputs=[output_coords_component, output_image_component, runtime_info],
                fn=predict_click_location,
                cache_examples="lazy",
            )

        # Wire the button to the inference handler.
        submit_button.click(
            fn=predict_click_location,
            inputs=[
                input_image_component,
                instruction_component,
                estimate_conf,
                num_samples_slider,
                temperature_slider,
                top_p_slider,
                seed_box,
            ],
            outputs=[output_coords_component, output_image_component, runtime_info]
        )


if __name__ == "__main__":
    demo.launch(debug=True)
|
|