Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Gradio app for Trace Model inference visualization. | |
| Takes an image, runs the trace model to predict trajectory points, | |
| overlays the trace on the image, and displays the predicted coordinates. | |
| Model: https://huggingface.co/mihirgrao/trace-model | |
| """ | |
| import base64 | |
| import os | |
| import tempfile | |
| import logging | |
| from typing import List, Optional, Tuple | |
| import gradio as gr | |
| import requests | |
| from trace_inference import ( | |
| DEFAULT_MODEL_ID, | |
| build_prompt, | |
| preprocess_image_for_trace, | |
| run_inference, | |
| ) | |
| from trajectory_viz import visualize_trajectory_on_image | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Global server state (eval server mode): the eval server URL most recently
# selected/connected (None until one is chosen) and the base URL last used
# for server discovery.
_server_state = dict(server_url=None, base_url="http://localhost")
def discover_available_models(
    base_url: str = "http://localhost",
    port_range: Tuple[int, int] = (8000, 8010),
) -> List[Tuple[str, str]]:
    """Discover trace eval servers by pinging ``/health``.

    For ngrok or https URLs, and for URLs that already carry a port, the URL
    is probed as-is. Otherwise every port in ``port_range`` (inclusive on both
    ends) is appended to ``base_url`` and probed in turn.

    Args:
        base_url: Base server URL; surrounding whitespace and trailing
            slashes are ignored. An empty (or None) value yields no results.
        port_range: ``(first_port, last_port)`` to scan when ``base_url``
            has no port of its own.

    Returns:
        ``[(server_url, model_name), ...]`` for each server whose ``/health``
        answered with HTTP 200. ``model_name`` is the ``model_id`` reported by
        ``/model_info`` when available, otherwise a generic label.
    """
    base_url = (base_url or "").strip().rstrip("/")
    if not base_url:
        # Nothing to probe; without this guard the port scan would issue
        # requests against malformed URLs such as ":8000".
        return []

    is_ngrok = "ngrok" in base_url
    host_part = base_url.split("//")[-1].split("/")[0]

    if is_ngrok or base_url.startswith("https://"):
        # Single URL mode: ngrok or https endpoint.
        urls_to_check = [(base_url, "Trace (ngrok/external)")]
    elif ":" in host_part:
        # Already has a port (e.g. http://localhost:8000).
        urls_to_check = [(base_url, "Trace")]
    else:
        # Scan the port range on the given host.
        start_port, end_port = port_range
        urls_to_check = [
            (f"{base_url}:{port}", f"Trace @ port {port}")
            for port in range(start_port, end_port + 1)
        ]

    # ngrok's free tier serves an interstitial page unless this header is set.
    headers = {"ngrok-skip-browser-warning": "true"} if is_ngrok else {}

    available: List[Tuple[str, str]] = []
    for server_url, label in urls_to_check:
        try:
            r = requests.get(f"{server_url}/health", timeout=5.0, headers=headers)
        except requests.exceptions.RequestException as e:
            logger.debug("Could not reach %s/health: %s", server_url, e)
            continue
        if r.status_code != 200:
            continue

        # Best effort: a healthy server is still listed (under its generic
        # label) when /model_info is missing, malformed, or lacks a model_id.
        try:
            info = requests.get(
                f"{server_url}/model_info", timeout=5.0, headers=headers
            ).json()
            name = info.get("model_id") or label
        except Exception:
            name = label
        available.append((server_url, name))
    return available
def get_model_info_for_url(server_url: str) -> Optional[str]:
    """Fetch ``/model_info`` from a trace eval server and format it as markdown.

    Returns None when no URL is given, when the server does not answer with
    HTTP 200, or when fetching/formatting fails (the failure is logged as a
    warning rather than raised).
    """
    if not server_url:
        return None

    extra_headers = {}
    if "ngrok" in server_url:
        # Skip ngrok's browser-warning interstitial.
        extra_headers["ngrok-skip-browser-warning"] = "true"
    endpoint = f"{server_url.rstrip('/')}/model_info"

    try:
        response = requests.get(endpoint, timeout=5.0, headers=extra_headers)
        if response.status_code == 200:
            return format_trace_model_info(response.json())
    except Exception as e:
        logger.warning(f"Could not fetch model info: {e}")
    return None
def format_trace_model_info(info: dict) -> str:
    """Format trace model info as markdown.

    Args:
        info: Model info mapping. Recognized keys: ``model_id`` (shown as
            ``Unknown`` when absent), ``model_class``, ``total_parameters``
            and ``error``; the last three are only rendered when present.

    Returns:
        A markdown string with one newline-terminated line per rendered field.
    """
    lines = [
        "## Model Information\n",
        f"**Model ID:** `{info.get('model_id', 'Unknown')}`\n",
    ]
    if "model_class" in info:
        lines.append(f"**Model Class:** `{info['model_class']}`\n")
    if "total_parameters" in info:
        total = info["total_parameters"]
        if isinstance(total, (int, float)):
            rendered = f"{total:,}"
        elif total is None:
            rendered = "Unknown"
        else:
            # The thousands-separator format spec raises on non-numeric
            # values (e.g. "7B"), so show those verbatim instead.
            rendered = str(total)
        lines.append(f"**Parameters:** {rendered}\n")
    if "error" in info:
        lines.append(f"**Error:** {info['error']}\n")
    return "".join(lines)
def check_server_health(server_url: str) -> Tuple[str, Optional[dict], Optional[str]]:
    """Check trace eval server health.

    On success the URL is also recorded in the module-level ``_server_state``.

    Args:
        server_url: Root URL of the eval server (a trailing slash is ignored).

    Returns:
        ``(status_msg, health_data, model_info_text)``. ``health_data`` and
        ``model_info_text`` are None when no URL was given or the health
        request failed; ``model_info_text`` may also be None when the server
        is healthy but its model info could not be fetched.
    """
    if not server_url:
        return "Please provide a server URL.", None, None

    headers = {"ngrok-skip-browser-warning": "true"} if "ngrok" in server_url else {}
    try:
        r = requests.get(f"{server_url.rstrip('/')}/health", timeout=5.0, headers=headers)
        r.raise_for_status()
        data = r.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers a non-JSON body on requests versions whose JSON
        # decode error is not a RequestException subclass.
        return f"Error connecting to server: {str(e)}", None, None

    # A healthy server may answer with a bare JSON value instead of an
    # object; treat that as "connected" with no extra details.
    if not isinstance(data, dict):
        data = {}

    info = get_model_info_for_url(server_url)
    _server_state["server_url"] = server_url
    return f"Server connected: {data.get('status', 'ok')}", data, info
def _remove_file_quietly(path: str) -> None:
    """Best-effort removal of a temporary file; filesystem errors are ignored."""
    try:
        os.unlink(path)
    except OSError:
        pass


def run_inference_via_server(
    image_path: str,
    instruction: str,
    server_url: str,
    is_oxe: bool = False,
) -> Tuple[str, Optional[str]]:
    """Run inference via trace eval server.

    The image is base64-encoded and POSTed to ``<server_url>/predict``
    together with the instruction and the ``is_oxe`` flag. When the response
    carries a trajectory of at least two points, it is drawn onto the
    preprocessed image and saved to a temporary PNG.

    Args:
        image_path: Path of the input image file.
        instruction: Natural language task instruction.
        server_url: Root URL of the eval server (a trailing slash is ignored).
        is_oxe: Forwarded to the server as ``is_oxe``.

    Returns:
        ``(prediction, overlay_path)``. If the response contains an ``error``
        key, that error text is returned as the prediction with no overlay.
        ``overlay_path`` is None when the trajectory has fewer than two points.

    Raises:
        OSError: If the image file cannot be read.
        requests.exceptions.RequestException: On connection failure, timeout
            or a non-2xx response.
    """
    with open(image_path, "rb") as f:
        image_b64 = base64.b64encode(f.read()).decode("utf-8")

    headers = {"ngrok-skip-browser-warning": "true"} if "ngrok" in server_url else {}
    r = requests.post(
        f"{server_url.rstrip('/')}/predict",
        json={
            "image_base64": image_b64,
            "instruction": instruction,
            "is_oxe": is_oxe,
        },
        timeout=120.0,
        headers=headers,
    )
    r.raise_for_status()
    data = r.json()
    if "error" in data:
        return data["error"], None

    prediction = data.get("prediction", "")
    trajectory = data.get("trajectory", [])
    if not trajectory or len(trajectory) < 2:
        # Not enough points to draw a trace.
        return prediction, None

    _, preprocessed_path = preprocess_image_for_trace(image_path)
    overlay_path = None
    try:
        # delete=False: the file must outlive this function so the caller
        # can display it.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as f:
            overlay_path = f.name
        img_arr = visualize_trajectory_on_image(
            trajectory=trajectory,
            image_path=preprocessed_path,
            output_path=overlay_path,
            normalized=True,
        )
        if img_arr is None:
            # NOTE(review): a None result is treated as "could not draw with
            # normalized coordinates" and retried as unnormalized; confirm
            # against visualize_trajectory_on_image.
            visualize_trajectory_on_image(
                trajectory=trajectory,
                image_path=preprocessed_path,
                output_path=overlay_path,
                normalized=False,
            )
    except Exception:
        # Don't leave the delete=False temp file behind when drawing fails.
        if overlay_path is not None:
            _remove_file_quietly(overlay_path)
        raise
    finally:
        _remove_file_quietly(preprocessed_path)
    return prediction, overlay_path
| # --- Gradio UI --- | |
# Root Blocks container for the UI. A single construction suffices: retrying
# the identical call after a TypeError could never succeed.
demo = gr.Blocks(title="Trace Model Visualizer")
with demo:
    gr.Markdown(
        """
        # Trace Model Visualizer
        Upload an image and provide a natural language task instruction to predict the trajectory/trace using [mihirgrao/trace-model](https://huggingface.co/mihirgrao/trace-model).
        The model predicts coordinate points from your instruction; they are overlaid on the image (green → red gradient) and listed below.
        """
    )

    # Per-session state: URL of the selected eval server (None means "use the
    # local model") and the dropdown-label -> server-URL mapping.
    server_url_state = gr.State(value=None)
    model_url_mapping_state = gr.State(value={})

    def _prompt_markdown(prompt: str) -> str:
        """Wrap a prompt in the markdown shown in the prompt display."""
        return f"**Prompt sent to model:**\n\n```\n{prompt}\n```"

    def _no_servers(message: str):
        """Build the discovery outputs for the "nothing selectable" case."""
        return (
            gr.update(choices=[], value=None),
            gr.update(value=message, visible=True),
            gr.update(value="", visible=True),
            None,
            {},
        )

    def discover_and_select_models(base_url: str):
        """Discover eval servers under ``base_url`` and auto-select the first.

        Returns updates for (dropdown, status, model info) followed by the
        selected server URL and the label -> URL mapping.
        """
        if not base_url:
            return _no_servers("Please provide a base URL")

        _server_state["base_url"] = base_url
        models = discover_available_models(base_url, port_range=(8000, 8010))
        if not models:
            return _no_servers("❌ No trace eval servers found on ports 8000-8010.")

        # Several servers may report the same model name; suffix repeats with
        # their URL so every server keeps its own dropdown entry instead of
        # collapsing onto the last URL seen.
        url_map = {}
        for url, name in models:
            label = name if name not in url_map else f"{name} ({url})"
            url_map[label] = url
        choices = list(url_map)

        selected = choices[0]
        selected_url = url_map[selected]
        model_info_text = get_model_info_for_url(selected_url) or ""
        status = f"✅ Found {len(models)} server(s). Auto-selected first."
        _server_state["server_url"] = selected_url
        return (
            gr.update(choices=choices, value=selected),
            gr.update(value=status, visible=True),
            gr.update(value=model_info_text, visible=True),
            selected_url,
            url_map,
        )

    def on_model_selected(model_choice: str, url_mapping: dict):
        """Connect to the server behind the chosen dropdown entry.

        Returns updates for (status, model info) and the selected server URL.
        """
        if not model_choice:
            return (
                gr.update(value="No model selected", visible=True),
                gr.update(value="", visible=True),
                None,
            )
        server_url = url_mapping.get(model_choice) if url_mapping else None
        if not server_url:
            return (
                gr.update(value="Could not find server URL. Please rediscover.", visible=True),
                gr.update(value="", visible=True),
                None,
            )
        # check_server_health already fetches the model info text, so reuse
        # it rather than querying /model_info a second time.
        status, _, model_info_text = check_server_health(server_url)
        _server_state["server_url"] = server_url
        return (
            gr.update(value=status, visible=True),
            gr.update(value=model_info_text or "", visible=True),
            server_url,
        )

    with gr.Sidebar():
        gr.Markdown("### 🔧 Model Configuration")
        base_url_input = gr.Textbox(
            label="Base Server URL",
            placeholder="http://localhost",
            value="http://localhost",
            interactive=True,
        )
        discover_btn = gr.Button("🔍 Discover Eval Servers", variant="primary", size="lg")
        model_dropdown = gr.Dropdown(
            label="Select Eval Server",
            choices=[],
            value=None,
            interactive=True,
            info="Discover trace eval servers on ports 8000-8010",
        )
        server_status = gr.Markdown("Select an eval server below (auto-connects on selection)")
        gr.Markdown("---")
        gr.Markdown("### 📋 Model Information")
        model_info_display = gr.Markdown("")

    discover_btn.click(
        fn=discover_and_select_models,
        inputs=[base_url_input],
        outputs=[
            model_dropdown,
            server_status,
            model_info_display,
            server_url_state,
            model_url_mapping_state,
        ],
    )
    model_dropdown.change(
        fn=on_model_selected,
        inputs=[model_dropdown, model_url_mapping_state],
        outputs=[server_status, model_info_display, server_url_state],
    )

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(
                label="Upload Image",
                type="filepath",
                height=400,
            )
            instruction_input = gr.Textbox(
                label="Natural language instruction",
                placeholder="e.g. Pick up the red block and place it on the table. Stack the cube on top of the block.",
                value="",
                lines=4,
                info="Enter a task description in natural language. The model predicts the trace for this instruction.",
            )
            prompt_format = gr.Radio(
                choices=["LIBERO", "OXE"],
                value="LIBERO",
                label="Prompt Format",
                info="Switch between LIBERO and OXE training formats.",
            )
            gr.Markdown("### Local model (if no eval server selected)")
            model_id_input = gr.Textbox(
                label="Model ID",
                value=DEFAULT_MODEL_ID,
                info="Hugging Face model ID (auto-loads on first inference if no eval server selected)",
            )
            run_btn = gr.Button("Run Inference", variant="primary")
        with gr.Column(scale=1):
            prompt_display = gr.Markdown(
                _prompt_markdown(build_prompt('')),
                label="Model prompt",
            )
            overlay_output = gr.Image(
                label="Image with Trace Overlay",
                height=400,
            )
            prediction_output = gr.Textbox(
                label="Model Prediction (raw)",
                lines=6,
            )

    status_md = gr.Markdown(
        "Select an eval server from the sidebar (auto-connects), or run inference with local model."
    )

    def on_run_inference(image_path, instruction, model_id, server_url, prompt_mode):
        """Run inference on the selected eval server, or locally if none is set.

        Returns (prompt markdown, raw prediction, overlay image path, status
        markdown), matching the outputs wired to the Run button.
        """
        if image_path is None:
            return (
                "",
                "Please upload an image first.",
                None,
                "**Status:** Please upload an image.",
            )

        is_oxe = prompt_mode == "OXE"
        prompt = build_prompt(instruction, is_oxe=is_oxe)
        prompt_md = _prompt_markdown(prompt)

        if server_url:
            try:
                pred, overlay_path = run_inference_via_server(
                    image_path, instruction, server_url, is_oxe=is_oxe
                )
            except (requests.exceptions.RequestException, ValueError) as e:
                # Report an unreachable or misbehaving server in the UI
                # instead of letting the handler fail with a traceback.
                message = f"Error contacting eval server: {e}"
                return prompt_md, message, None, f"**Status:** {message}"
        else:
            pred, overlay_path, _ = run_inference(image_path, prompt, model_id)

        # Without an overlay, the prediction text doubles as the status line.
        status = "**Status:** Inference complete." if overlay_path else f"**Status:** {pred}"
        return prompt_md, pred, overlay_path, status

    def update_prompt_display(instruction: str, prompt_mode: str):
        """Re-render the prompt preview for the current instruction and format."""
        return _prompt_markdown(build_prompt(instruction, is_oxe=(prompt_mode == "OXE")))

    # Keep the prompt preview in sync with both inputs that shape it.
    instruction_input.change(
        fn=update_prompt_display,
        inputs=[instruction_input, prompt_format],
        outputs=[prompt_display],
    )
    prompt_format.change(
        fn=update_prompt_display,
        inputs=[instruction_input, prompt_format],
        outputs=[prompt_display],
    )
    run_btn.click(
        fn=on_run_inference,
        inputs=[
            image_input,
            instruction_input,
            model_id_input,
            server_url_state,
            prompt_format,
        ],
        outputs=[
            prompt_display,
            prediction_output,
            overlay_output,
            status_md,
        ],
        api_name="run_inference",
    )
def main():
    """Launch the Gradio app on all interfaces, port 7860, without a share link."""
    start_server = demo.launch
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "theme": gr.themes.Soft(),
    }
    start_server(**launch_options)


if __name__ == "__main__":
    main()