"""Gradio webapp to visualize Gemini Robotics-ER 1.6 outputs.""" from __future__ import annotations import io import json import os import re from pathlib import Path import gradio as gr from dotenv import load_dotenv from google import genai from google.genai import errors as genai_errors from google.genai import types from PIL import Image, ImageDraw, ImageFont load_dotenv() MODEL = "gemini-robotics-er-1.6-preview" EXAMPLES_DIR = Path(__file__).parent / "examples" PALETTE = [ "#ef4444", "#22c55e", "#3b82f6", "#eab308", "#a855f7", "#ec4899", "#06b6d4", "#f97316", "#84cc16", "#14b8a6", ] TASK_TEMPLATES = { "Points (object detection)": ( "Point to the most prominent objects in this image (up to 10). " "Output ONLY a JSON list in the form " '[{"point": [y, x], "label": }] ' "with coordinates normalized to 0-1000." ), "Bounding boxes": ( "Detect the prominent objects and return BOUNDING BOXES (rectangles), " "not points. Output ONLY a JSON array where every item has the key " '"box_2d" with four integers: ' '[{"box_2d": [ymin, xmin, ymax, xmax], "label": }]. ' "Coordinates normalized to 0-1000. Limit to 25 objects. " 'Do NOT use "point" — every item must have "box_2d" with exactly 4 values.' ), "Trajectory (path planning)": ( "Pick one salient object in the image and plan a trajectory for a " "robot gripper to pick it up and move it to a nearby free area. " 'Return ONLY a JSON list of ordered waypoints [{"point": [y, x], ' '"label": ""}] with coordinates normalized to 0-1000 and ' 'labels as stringified integers starting from "0". Use ~10-15 waypoints.' ), "Grasp points": ( "Identify up to 3 good grasp points on objects visible in this scene. " "Output ONLY JSON as " '[{"point": [y, x], "label": }] ' "with coordinates normalized to 0-1000." ), "Instrument reading": ( "If any gauge, meter, clock, dial, or digital display is visible, read it. " 'Output ONLY a JSON object like {"reading": , "units": , ' '"explanation": }. If no instrument is visible, return ' '{"reading": null, "explanation": "no instrument visible"}.' ), "Free-form spatial reasoning": ( "Describe the spatial layout of this scene and suggest a reasonable " "next action for a household or warehouse robot. Be concise." ), } def get_client(api_key: str | None = None) -> genai.Client: key = (api_key or "").strip() or os.getenv("GEMINI_API_KEY") if not key: raise gr.Error( "No API key. Paste one in the 'Gemini API key' field above, " "or set GEMINI_API_KEY in .env." ) return genai.Client(api_key=key) def extract_json(text: str): """Strip markdown fences and parse JSON; return None on failure.""" if not text: return None stripped = text.strip() fenced = re.search(r"```(?:json)?\s*(.+?)```", stripped, re.DOTALL) if fenced: stripped = fenced.group(1).strip() try: return json.loads(stripped) except json.JSONDecodeError: # try to find the first [...] or {...} blob for pattern in (r"\[.*\]", r"\{.*\}"): m = re.search(pattern, stripped, re.DOTALL) if m: try: return json.loads(m.group(0)) except json.JSONDecodeError: continue return None def load_font(size: int) -> ImageFont.ImageFont: for candidate in [ "/System/Library/Fonts/Helvetica.ttc", "/System/Library/Fonts/Supplemental/Arial.ttf", "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", ]: if os.path.exists(candidate): try: return ImageFont.truetype(candidate, size) except OSError: continue return ImageFont.load_default() def draw_label(draw: ImageDraw.ImageDraw, xy, text: str, color: str, font): x, y = xy bbox = draw.textbbox((x, y), text, font=font) pad = 3 draw.rectangle( (bbox[0] - pad, bbox[1] - pad, bbox[2] + pad, bbox[3] + pad), fill=color, ) draw.text((x, y), text, fill="white", font=font) def render_points(img: Image.Image, items: list) -> Image.Image: out = img.copy().convert("RGB") draw = ImageDraw.Draw(out) W, H = out.size font = load_font(max(14, W // 70)) r = max(6, W // 180) for i, item in enumerate(items): pt = item.get("point") or item.get("xy") if not pt or len(pt) != 2: continue y, x = pt cx, cy = x * W / 1000, y * H / 1000 color = PALETTE[i % len(PALETTE)] draw.ellipse((cx - r, cy - r, cx + r, cy + r), fill=color, outline="white", width=2) label = str(item.get("label", "")) if label: draw_label(draw, (cx + r + 4, cy - r), label, color, font) return out def render_boxes(img: Image.Image, items: list) -> Image.Image: out = img.copy().convert("RGB") draw = ImageDraw.Draw(out) W, H = out.size font = load_font(max(14, W // 70)) for i, item in enumerate(items): box = item.get("box_2d") or item.get("bbox") if not box or len(box) != 4: continue ymin, xmin, ymax, xmax = box x0, y0 = xmin * W / 1000, ymin * H / 1000 x1, y1 = xmax * W / 1000, ymax * H / 1000 color = PALETTE[i % len(PALETTE)] draw.rectangle((x0, y0, x1, y1), outline=color, width=3) label = str(item.get("label", "")) if label: draw_label(draw, (x0 + 4, y0 + 4), label, color, font) return out def render_trajectory(img: Image.Image, items: list) -> Image.Image: out = img.copy().convert("RGB") draw = ImageDraw.Draw(out) W, H = out.size font = load_font(max(12, W // 80)) def key(item): try: return int(str(item.get("label", "0"))) except ValueError: return 0 pts = [] for item in sorted(items, key=key): p = item.get("point") if p and len(p) == 2: y, x = p pts.append((x * W / 1000, y * H / 1000, str(item.get("label", "")))) if len(pts) >= 2: for (x0, y0, _), (x1, y1, _) in zip(pts, pts[1:]): draw.line((x0, y0, x1, y1), fill="#3b82f6", width=4) r = max(6, W // 200) for i, (x, y, label) in enumerate(pts): color = PALETTE[0] if i == 0 else PALETTE[1] if i == len(pts) - 1 else "#3b82f6" draw.ellipse((x - r, y - r, x + r, y + r), fill=color, outline="white", width=2) draw_label(draw, (x + r + 3, y - r - 2), label, color, font) return out def _as_items(payload) -> list: if isinstance(payload, list): return payload if isinstance(payload, dict): for key in ("points", "boxes", "waypoints", "items", "detections"): v = payload.get(key) if isinstance(v, list): return v return [] def visualize(img: Image.Image, task: str, payload) -> Image.Image | None: """Render overlays. Shape-aware: falls back to whatever the payload actually contains so the visualization still works when the model returns a different shape than the task template requested.""" items = _as_items(payload) if not items: return None has_box = any(isinstance(i, dict) and ("box_2d" in i or "bbox" in i) for i in items) has_point = any(isinstance(i, dict) and ("point" in i or "xy" in i) for i in items) if task.startswith("Trajectory") and has_point: return render_trajectory(img, items) if task.startswith("Bounding"): if has_box: return render_boxes(img, items) if has_point: return render_points(img, items) # fallback: model returned points if has_box: return render_boxes(img, items) if has_point: return render_points(img, items) return None def _format_gemini_error(e: genai_errors.APIError) -> str: """Extract the useful parts of a Gemini API error into a single-line message.""" status = getattr(e, "code", None) or getattr(e, "status_code", "?") message = getattr(e, "message", None) or str(e) # The SDK stores the raw response JSON on the exception; try to pull retry info. details_blob = getattr(e, "details", None) or {} retry_hint = "" try: err = details_blob.get("error", details_blob) if isinstance(details_blob, dict) else {} for d in err.get("details", []) or []: if d.get("@type", "").endswith("RetryInfo"): delay = d.get("retryDelay", "") if delay: retry_hint = f" (retry in {delay})" break except Exception: pass # Trim super-long messages for the toast. short = message.split("\n", 1)[0].strip() if len(short) > 400: short = short[:397] + "..." return f"Gemini API error {status}: {short}{retry_hint}" def run_model(image: Image.Image | None, task: str, prompt: str, thinking_budget: int, api_key: str): if image is None: raise gr.Error("Please pick or upload an image first.") if not prompt.strip(): raise gr.Error("Prompt is empty.") pil = image if isinstance(image, Image.Image) else Image.open(image) pil = pil.convert("RGB") buf = io.BytesIO() pil.save(buf, format="PNG") img_bytes = buf.getvalue() try: client = get_client(api_key) config = types.GenerateContentConfig( temperature=1.0, thinking_config=types.ThinkingConfig(thinking_budget=int(thinking_budget)), ) response = client.models.generate_content( model=MODEL, contents=[ types.Part.from_bytes(data=img_bytes, mime_type="image/png"), prompt, ], config=config, ) except genai_errors.APIError as e: raise gr.Error(_format_gemini_error(e)) except Exception as e: raise gr.Error(f"Request failed: {type(e).__name__}: {e}") text = response.text or "" if not text.strip(): # The SDK gave us an empty response — surface whatever we can. finish = "" try: finish = str(response.candidates[0].finish_reason) if response.candidates else "" except Exception: pass raise gr.Error( "Model returned an empty response" + (f" (finish_reason={finish})" if finish else "") + ". This can happen if the image was blocked by safety filters or " "if the output hit the token limit." ) parsed = extract_json(text) overlay = visualize(pil, task, parsed) if parsed is not None else None pretty = json.dumps(parsed, indent=2) if parsed is not None else "(no JSON parsed)" return overlay if overlay is not None else pil, pretty, text def on_task_change(task: str) -> str: return TASK_TEMPLATES.get(task, "") REMOTE_EXAMPLES = [ # Licensed/watermarked images — referenced by URL so we don't redistribute them. "https://www.shutterstock.com/image-photo/adiyaman-turkey-aug-21-2024-260nw-2514334691.jpg", "https://c8.alamy.com/comp/2HXT4ME/full-containers-next-to-sets-of-lego-brand-bricks-with-figurines-on-a-carpet-in-a-playroom-2HXT4ME.jpg", "https://media.licdn.com/dms/image/v2/D5612AQF1km1N_jy66Q/article-cover_image-shrink_720_1280/B56Zh8QE__HUAI-/0/1754431246854?e=2147483647&v=beta&t=iEGQih1pUnJNFBbas0K9rCpez-VgJOxOKJ2TGQM1-tY", ] def build_ui() -> gr.Blocks: local_examples = sorted(str(p) for p in EXAMPLES_DIR.glob("*.jpg")) examples = local_examples + REMOTE_EXAMPLES with gr.Blocks(title="Gemini Robotics-ER 1.6 Visualizer") as demo: gr.Markdown( "# Gemini Robotics-ER 1.6 Visualizer\n" f"Model: `{MODEL}` — pick an image, choose a task, edit the prompt, run." ) api_key = gr.Textbox( label="Gemini API key", value="", type="password", placeholder="Paste your key here (or leave empty if the server provides one)", info="Sent only to this server to call Gemini; never logged or persisted.", ) with gr.Row(): with gr.Column(scale=1): image = gr.Image(type="pil", label="Input image", height=380) if examples: gr.Examples(examples=examples, inputs=image, label="Example images") task = gr.Dropdown( choices=list(TASK_TEMPLATES.keys()), value="Points (object detection)", label="Task", ) prompt = gr.Textbox( label="Prompt (editable)", value=TASK_TEMPLATES["Points (object detection)"], lines=6, ) thinking = gr.Slider( 0, 8192, value=0, step=256, label="Thinking budget (0 = fast, higher = more reasoning)", ) run = gr.Button("Run", variant="primary") with gr.Column(scale=1): overlay = gr.Image(type="pil", label="Result", height=480) with gr.Accordion("Raw JSON", open=True): parsed_out = gr.Code(label="Parsed JSON", language="json") with gr.Accordion("Model answer (raw text)", open=True): raw_out = gr.Textbox(label="response.text", lines=8) task.change(on_task_change, inputs=task, outputs=prompt) run.click( run_model, inputs=[image, task, prompt, thinking, api_key], outputs=[overlay, parsed_out, raw_out], ) return demo if __name__ == "__main__": build_ui().launch(theme=gr.themes.Soft())