| | import gradio as gr |
| | import json, os, re, traceback, contextlib |
| | from typing import Any, List, Dict |
| |
|
| | import spaces |
| | import torch |
| | from PIL import Image, ImageDraw |
| | import requests |
| | from transformers import AutoModelForImageTextToText, AutoProcessor |
| | from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize |
| |
|
| | |
# Hugging Face Hub repository id; used for both the model weights and the processor.
MODEL_ID: str = "Hcompany/Holo1-3B"
| |
|
| | |
| |
|
def pick_device() -> str:
    """Return the torch device name ("cpu", "cuda" or "mps") to run inference on.

    On HF Spaces (ZeroGPU), CUDA is only available inside @spaces.GPU calls.
    A FORCE_DEVICE environment variable (case/whitespace-insensitive) wins over
    auto-detection, which is useful for local testing. Otherwise CUDA is
    preferred, then Apple MPS, then CPU.
    """
    override = os.getenv("FORCE_DEVICE", "").lower().strip()
    if override in ("cpu", "cuda", "mps"):
        return override

    if torch.cuda.is_available():
        return "cuda"

    # Some torch builds have no `torch.backends.mps` submodule at all.
    mps_backend = getattr(torch.backends, "mps", None)
    mps_ready = bool(mps_backend) and mps_backend.is_available()
    return "mps" if mps_ready else "cpu"
| |
|
def pick_dtype(device: str) -> torch.dtype:
    """Choose the floating-point dtype for the model on ``device``.

    - "mps": float16
    - "cuda": bfloat16 when the GPU's major compute capability is >= 8,
      float16 otherwise
    - anything else: float32
    """
    if device == "mps":
        return torch.float16
    if device != "cuda":
        return torch.float32
    major_capability = torch.cuda.get_device_capability()[0]
    return torch.float16 if major_capability < 8 else torch.bfloat16
| |
|
def move_to_device(batch, device: str):
    """Send ``batch`` (or each value of a dict ``batch``) to ``device``.

    Only objects exposing a ``.to`` method are moved (with
    ``non_blocking=True``); other values are returned untouched. A plain dict
    input yields a new dict; any other object is moved (or returned) directly.
    """
    def _relocate(item):
        if hasattr(item, "to"):
            return item.to(device, non_blocking=True)
        return item

    if isinstance(batch, dict):
        return {key: _relocate(value) for key, value in batch.items()}
    return _relocate(batch)
| |
|
| | |
def apply_chat_template_compat(processor, messages: List[Dict[str, Any]]) -> str:
    """Render chat ``messages`` into a prompt string, tolerating older APIs.

    Uses ``processor.apply_chat_template`` if present, else
    ``processor.tokenizer.apply_chat_template`` (both untokenized, with the
    generation prompt appended). If neither exists, falls back to joining the
    ``"text"`` of every text-type content part with newlines.
    """
    template_kwargs = {"tokenize": False, "add_generation_prompt": True}
    # The processor is preferred over its tokenizer.
    for owner in (processor, getattr(processor, "tokenizer", None)):
        if owner is not None and hasattr(owner, "apply_chat_template"):
            return owner.apply_chat_template(messages, **template_kwargs)

    text_parts = [
        part.get("text", "")
        for message in messages
        for part in message.get("content", [])
        if isinstance(part, dict) and part.get("type") == "text"
    ]
    return "\n".join(text_parts)
| |
|
def batch_decode_compat(processor, token_id_batches, **kw):
    """Decode batches of token ids, preferring the processor's tokenizer.

    Tries ``processor.tokenizer.batch_decode`` first, then
    ``processor.batch_decode``; extra keyword arguments are forwarded.

    Raises:
        AttributeError: if neither object provides ``batch_decode``.
    """
    candidates = (getattr(processor, "tokenizer", None), processor)
    for decoder in candidates:
        if decoder is not None and hasattr(decoder, "batch_decode"):
            return decoder.batch_decode(token_id_batches, **kw)
    raise AttributeError("No batch_decode available on processor or tokenizer.")
| |
|
# Fallbacks mirror the defaults of transformers' Qwen2-VL image processor (whose
# `smart_resize` this file uses), so a pre-resized image lands on the same
# patch grid (factor = patch_size * merge_size = 28) the processor expects.
_QWEN2_VL_IMAGE_DEFAULTS: Dict[str, int] = {
    "patch_size": 14,
    "merge_size": 2,
    "min_pixels": 56 * 56,
    "max_pixels": 28 * 28 * 1280,
}


def get_image_proc_params(processor) -> Dict[str, int]:
    """Return the resize parameters of ``processor.image_processor``.

    Keys: ``patch_size``, ``merge_size``, ``min_pixels``, ``max_pixels``.
    An attribute that is missing *or* set to ``None`` is resolved from the
    image processor's ``size`` dict (``shortest_edge`` -> ``min_pixels``,
    ``longest_edge`` -> ``max_pixels``) when available, and otherwise from the
    Qwen2-VL defaults above. All values are returned as ``int``.
    """
    ip = getattr(processor, "image_processor", None)
    size = getattr(ip, "size", None)
    size = size if isinstance(size, dict) else {}
    size_fallbacks = {
        "min_pixels": size.get("shortest_edge"),
        "max_pixels": size.get("longest_edge"),
    }

    params: Dict[str, int] = {}
    for key, default in _QWEN2_VL_IMAGE_DEFAULTS.items():
        value = getattr(ip, key, None)
        if value is None:
            value = size_fallbacks.get(key)
        params[key] = int(value) if value is not None else default
    return params
| |
|
def trim_generated(generated_ids, inputs):
    """Strip the prompt tokens from each generated sequence.

    The prompt ids are read from ``inputs.input_ids`` or, for dict inputs,
    ``inputs["input_ids"]``. Each generated row loses the first
    ``len(prompt_row)`` tokens. Without prompt ids, the generated rows are
    returned unchanged (as a list).
    """
    prompt_ids = getattr(inputs, "input_ids", None)
    if prompt_ids is None and isinstance(inputs, dict):
        prompt_ids = inputs.get("input_ids")
    if prompt_ids is None:
        return list(generated_ids)

    trimmed = []
    for prompt_row, full_row in zip(prompt_ids, generated_ids):
        trimmed.append(full_row[len(prompt_row):])
    return trimmed
| |
|
| | |
# Load the model and processor once, at import time, on CPU (the log message
# calls this "ZeroGPU safe": per pick_device's docstring, CUDA is only
# reachable inside @spaces.GPU calls).
print(f"Loading model and processor for {MODEL_ID} on CPU startup (ZeroGPU safe)...")
# Module-level state read by predict_click_location and by the UI builder.
# `model_loaded` becomes True only after both loads and model.eval() succeed;
# otherwise `load_error_message` holds the user-facing explanation.
# Note: if the processor load fails, `model` stays assigned while
# `model_loaded` remains False.
model = None
processor = None
model_loaded = False
load_error_message = ""

try:
    # float32 weights at startup; predict_click_location later moves/casts the
    # model to the device and dtype it picks for each request.
    # NOTE(review): trust_remote_code=True allows code shipped in the model repo
    # to execute at load time.
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,
        trust_remote_code=True,
    )
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
    model.eval()
    model_loaded = True
    print("Model and processor loaded on CPU.")
except Exception as e:
    # Broad catch: any load failure leaves model_loaded False, so the app shows
    # the error page below instead of crashing at import time.
    load_error_message = (
        f"Error loading model/processor: {e}\n"
        "This might be due to network/model ID/library versions.\n"
        "Check the full traceback in the logs."
    )
    print(load_error_message)
    traceback.print_exc()
| |
|
| | |
def get_localization_prompt(pil_image: Image.Image, instruction: str) -> List[dict]:
    """Build the one-turn chat message asking the model to localize ``instruction``.

    The result holds a single ``user`` message with two content parts: the
    image, then a text part made of fixed guidelines (requesting a
    ``Click(x, y)`` answer in pixel coordinates) and the instruction, separated
    by a newline.
    """
    prompt_text = (
        "Localize an element on the GUI image according to my instructions and "
        "output a click position as Click(x, y) with x num pixels from the left edge "
        "and y num pixels from the top edge."
    )
    image_part = {"type": "image", "image": pil_image}
    text_part = {"type": "text", "text": f"{prompt_text}\n{instruction}"}
    user_message = {"role": "user", "content": [image_part, text_part]}
    return [user_message]
| |
|
| | |
def _autocast_context(device: str, dtype: torch.dtype):
    """Return an autocast context for ``device``/``dtype``, or a no-op context.

    CPU runs without autocast. On CUDA/MPS autocast uses the requested
    ``dtype``. Torch builds that do not support autocast for the device may
    raise ``RuntimeError`` (e.g. older builds reject ``device_type="mps"``).
    In that case we run without autocast instead of failing the request; the
    caller is expected to have already cast the model to ``dtype``.
    """
    if device not in ("cuda", "mps"):
        return contextlib.nullcontext()
    try:
        return torch.autocast(device_type=device, dtype=dtype)
    except RuntimeError as exc:
        print(f"Autocast unavailable for device '{device}' ({exc}); running without it.")
        return contextlib.nullcontext()


@torch.inference_mode()
def run_inference_localization(
    messages_for_template: List[dict[str, Any]],
    pil_image_for_processing: Image.Image,
    device: str,
    dtype: torch.dtype,
    max_new_tokens: int = 128,
) -> str:
    """Run one greedy generation for the localization prompt and decode it.

    Uses the module-level ``processor`` and ``model``.

    Args:
        messages_for_template: chat messages (see get_localization_prompt).
        pil_image_for_processing: the image passed to the processor.
        device: torch device name the inputs are moved to.
        dtype: autocast dtype for CUDA/MPS (see _autocast_context).
        max_new_tokens: generation budget (default 128, as before).

    Returns:
        The decoded text of the newly generated tokens (prompt stripped), or
        "" if decoding produced nothing.
    """
    text_prompt = apply_chat_template_compat(processor, messages_for_template)

    inputs = processor(
        text=[text_prompt],
        images=[pil_image_for_processing],
        padding=True,
        return_tensors="pt",
    )
    inputs = move_to_device(inputs, device)

    with _autocast_context(device, dtype):
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding for deterministic coordinates
        )

    # Drop the echoed prompt tokens so only the model's answer is decoded.
    generated_ids_trimmed = trim_generated(generated_ids, inputs)
    decoded_output = batch_decode_compat(
        processor,
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return decoded_output[0] if decoded_output else ""
| |
|
| | |
| | |
# Matches the model's "Click(x, y)" answer; tolerates whitespace inside the
# parentheses and decimal coordinates (rounded to the nearest pixel).
_CLICK_PATTERN = re.compile(r"Click\(\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*\)")
# Runtime-info text returned before a device/dtype has been chosen.
_NO_RUNTIME_INFO = "device: n/a | dtype: n/a"


def _runtime_info(device: str, dtype: torch.dtype) -> str:
    """Format the runtime-info line, e.g. ``device: cuda | dtype: bfloat16``."""
    return f"device: {device} | dtype: {str(dtype).replace('torch.', '')}"


def _ensure_model_placement(device: str, dtype: torch.dtype) -> None:
    """Move/cast the global ``model`` to ``device``/``dtype`` unless it is already there."""
    try:
        first_param = next(model.parameters())
        current = (first_param.device.type, first_param.dtype)
    except StopIteration:
        # Parameterless model: assume the CPU/float32 state it was loaded in.
        current = ("cpu", torch.float32)
    if current != (device, dtype):
        model.to(device=device, dtype=dtype)
        model.eval()


def _resize_for_model(image: Image.Image) -> Image.Image:
    """Resize ``image`` to the ``smart_resize`` size derived from the processor settings.

    The click is later drawn on this resized image using the model's raw
    coordinates, so both must share a pixel grid (this assumes the processor
    leaves an already smart_resize'd image at the same size).
    """
    ip = get_image_proc_params(processor)
    resized_height, resized_width = smart_resize(
        image.height,
        image.width,
        factor=ip["patch_size"] * ip["merge_size"],
        min_pixels=ip["min_pixels"],
        max_pixels=ip["max_pixels"],
    )
    return image.resize(
        size=(resized_width, resized_height),
        resample=Image.Resampling.LANCZOS,
    )


def _draw_click(image: Image.Image, x: int, y: int) -> None:
    """Draw, in place, a red circle centred on (x, y) sized relative to ``image``."""
    width, height = image.size
    radius = max(5, min(width // 100, height // 100, 15))
    ImageDraw.Draw(image).ellipse(
        (x - radius, y - radius, x + radius, y + radius),
        outline="red",
        width=max(2, radius // 4),
    )


# ZeroGPU: a GPU is attached only while this handler runs (duration budget in
# seconds; see the `spaces` package docs).
@spaces.GPU(duration=120)
def predict_click_location(input_pil_image: Image.Image, instruction: str):
    """Predict where to click for ``instruction`` and mark it on the image.

    Returns a 3-tuple matching the Gradio outputs:
      (model answer or error message,
       RGB copy of the (resized) image with the click drawn, or None,
       runtime-info string "device: ... | dtype: ...").
    """
    if not model_loaded or not processor or not model:
        return f"Model not loaded. Error: {load_error_message}", None, _NO_RUNTIME_INFO
    if input_pil_image is None:
        return "No image provided. Please upload an image.", None, _NO_RUNTIME_INFO
    if not instruction or not instruction.strip():
        return (
            "No instruction provided. Please type an instruction.",
            input_pil_image.copy().convert("RGB"),
            _NO_RUNTIME_INFO,
        )

    device = pick_device()
    dtype = pick_dtype(device)
    # One formatter for success and error paths (error paths used to show "torch.<dtype>").
    info = _runtime_info(device, dtype)

    if device == "cuda":
        # Permit TF32 for any float32 matmuls on CUDA.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.set_float32_matmul_precision("high")

    _ensure_model_placement(device, dtype)

    try:
        resized_image = _resize_for_model(input_pil_image)
    except Exception as e:
        traceback.print_exc()
        return f"Error resizing image: {e}", input_pil_image.copy().convert("RGB"), info

    messages = get_localization_prompt(resized_image, instruction)

    try:
        coordinates_str = run_inference_localization(messages, resized_image, device, dtype)
    except Exception as e:
        traceback.print_exc()
        return f"Error during model inference: {e}", resized_image.copy().convert("RGB"), info

    output_image_with_click = resized_image.copy().convert("RGB")
    match = _CLICK_PATTERN.search(coordinates_str)
    if match is None:
        print(f"Could not parse 'Click(x, y)' from model output: {coordinates_str}")
    else:
        try:
            x = round(float(match.group(1)))
            y = round(float(match.group(2)))
            _draw_click(output_image_with_click, x, y)
            resized_width, resized_height = output_image_with_click.size
            print(f"Predicted and drawn click at: ({x}, {y}) on resized image ({resized_width}x{resized_height})")
        except Exception as e:
            # Drawing is best effort: the text answer is still returned.
            print(f"Error drawing on image: {e}")
            traceback.print_exc()

    return coordinates_str, output_image_with_click, info
| |
|
| | |
# Example shown under the UI: a remote screenshot plus an instruction. If the
# download fails, a gray placeholder image is used instead.
example_image = None
example_instruction = "Enter the server address readyforquantum.com to check its security"
example_image_url = "https://readyforquantum.com/img/screentest.jpg"
# (connect, read) timeouts in seconds so an unreachable host cannot hang app startup.
EXAMPLE_IMAGE_TIMEOUT = (5, 30)

try:
    with requests.get(example_image_url, stream=True, timeout=EXAMPLE_IMAGE_TIMEOUT) as response:
        # Fail fast on HTTP errors instead of trying to decode an error page.
        response.raise_for_status()
        # Let urllib3 undo any Content-Encoding (e.g. gzip) on the raw stream.
        response.raw.decode_content = True
        downloaded = Image.open(response.raw)
        # Decode fully here so a truncated/corrupt download fails inside this
        # try (where it is handled) rather than later in the UI.
        downloaded.load()
    example_image = downloaded
except Exception as e:
    print(f"Could not load example image from URL: {e}")
    traceback.print_exc()
    try:
        example_image = Image.new("RGB", (200, 150), color="lightgray")
        draw = ImageDraw.Draw(example_image)
        draw.text((10, 10), "Example image\nfailed to load", fill="black")
    except Exception as fallback_error:
        # Best effort: a missing placeholder must not stop the app from starting.
        print(f"Could not build placeholder example image: {fallback_error}")
| |
|
| | |
# Page heading and footer HTML for the Gradio UI. `article` is an f-string that
# interpolates MODEL_ID into the model link; the links and note are static.
title = "Holo1-3B: Holo1 Localization Demo (ZeroGPU-ready)"
article = f"""
<p style='text-align: center'>
Model: <a href='https://huggingface.co/{MODEL_ID}' target='_blank'>{MODEL_ID}</a> by HCompany |
Paper: <a href='https://cdn.prod.website-files.com/67e2dbd9acff0c50d4c8a80c/683ec8095b353e8b38317f80_h_tech_report_v1.pdf' target='_blank'>HCompany Tech Report</a> |
Blog: <a href='https://www.hcompany.ai/surfer-h' target='_blank'>Surfer-H Blog Post</a><br/>
<small>GPU (if available) is requested only during inference via @spaces.GPU.</small>
</p>
"""
| |
|
# Build the Gradio app. On a successful startup load, wire the localization UI
# to predict_click_location; otherwise show a minimal error page.
if model_loaded:
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown(f"<h1 style='text-align: center;'>{title}</h1>")
        gr.Markdown(article)

        with gr.Row():
            # Left column: user inputs.
            with gr.Column(scale=1):
                input_image_component = gr.Image(
                    type="pil",
                    label="Input UI Image",
                    height=400,
                )
                instruction_component = gr.Textbox(
                    label="Instruction",
                    placeholder="e.g., Click the 'Login' button",
                    info="Type the action you want the model to localize on the image.",
                )
                submit_button = gr.Button("Localize Click", variant="primary")

            # Right column: model answer, annotated image, device/dtype info.
            with gr.Column(scale=1):
                output_coords_component = gr.Textbox(
                    label="Predicted Coordinates (Format: Click(x, y))",
                    interactive=False,
                )
                output_image_component = gr.Image(
                    type="pil",
                    label="Image with Predicted Click Point",
                    height=400,
                    interactive=False,
                )
                runtime_info = gr.Textbox(
                    label="Runtime Info",
                    value="device: n/a | dtype: n/a",
                    interactive=False,
                )

        # Shared wiring for the example row and the submit button; the output
        # order matches the 3-tuple returned by predict_click_location.
        localize_inputs = [input_image_component, instruction_component]
        localize_outputs = [output_coords_component, output_image_component, runtime_info]

        if example_image:
            gr.Examples(
                examples=[[example_image, example_instruction]],
                inputs=localize_inputs,
                outputs=localize_outputs,
                fn=predict_click_location,
                cache_examples="lazy",
            )

        submit_button.click(
            fn=predict_click_location,
            inputs=localize_inputs,
            outputs=localize_outputs,
        )
else:
    with gr.Blocks() as demo:
        gr.Markdown("# <center>⚠️ Error: Model Failed to Load ⚠️</center>")
        gr.Markdown(f"<center>{load_error_message}</center>")
        gr.Markdown("<center>See logs for the full traceback.</center>")
| |
|
def _launch_app() -> None:
    """Launch the Gradio `demo` with debug=True."""
    demo.launch(debug=True)


if __name__ == "__main__":
    _launch_app()
| |
|
| |
|