| """Gradio webapp to visualize Gemini Robotics-ER 1.6 outputs.""" |
| from __future__ import annotations |
|
|
| import io |
| import json |
| import os |
| import re |
| from pathlib import Path |
|
|
| import gradio as gr |
| from dotenv import load_dotenv |
| from google import genai |
| from google.genai import errors as genai_errors |
| from google.genai import types |
| from PIL import Image, ImageDraw, ImageFont |
|
|
# Load variables from a .env file into the process environment so that
# get_client() can fall back to GEMINI_API_KEY when the UI key field is empty.
load_dotenv()


# Gemini model id used for every request and shown in the UI header.
MODEL = "gemini-robotics-er-1.6-preview"
# Directory of bundled example images, located next to this file; build_ui()
# lists local examples from here.
EXAMPLES_DIR = Path(__file__).parent / "examples"


# Overlay colors, cycled by item index when drawing points and boxes.
# render_trajectory also uses PALETTE[0] for the first waypoint and
# PALETTE[1] for the last one.
PALETTE = [
    "#ef4444", "#22c55e", "#3b82f6", "#eab308", "#a855f7",
    "#ec4899", "#06b6d4", "#f97316", "#84cc16", "#14b8a6",
]
|
|
# Prompt templates keyed by the task name shown in the "Task" dropdown.
# Selecting a task copies its template into the editable prompt box.
# NOTE: visualize() dispatches on the key prefixes "Trajectory" and
# "Bounding", so renaming those keys changes how results are drawn.
# Coordinate prompts ask for [y, x] / [ymin, xmin, ymax, xmax] normalized to
# 0-1000, which is the convention the render_* helpers assume when scaling.
TASK_TEMPLATES = {
    # Point items -> drawn as labeled dots.
    "Points (object detection)": (
        "Point to the most prominent objects in this image (up to 10). "
        "Output ONLY a JSON list in the form "
        '[{"point": [y, x], "label": <concise object name>}] '
        "with coordinates normalized to 0-1000."
    ),
    # Box items -> drawn as rectangles; the prompt insists on "box_2d"
    # because a points-shaped reply would be rendered as dots instead.
    "Bounding boxes": (
        "Detect the prominent objects and return BOUNDING BOXES (rectangles), "
        "not points. Output ONLY a JSON array where every item has the key "
        '"box_2d" with four integers: '
        '[{"box_2d": [ymin, xmin, ymax, xmax], "label": <concise object name>}]. '
        "Coordinates normalized to 0-1000. Limit to 25 objects. "
        'Do NOT use "point" — every item must have "box_2d" with exactly 4 values.'
    ),
    # Waypoints labeled "0", "1", ... so render_trajectory can order them.
    "Trajectory (path planning)": (
        "Pick one salient object in the image and plan a trajectory for a "
        "robot gripper to pick it up and move it to a nearby free area. "
        'Return ONLY a JSON list of ordered waypoints [{"point": [y, x], '
        '"label": "<step_index>"}] with coordinates normalized to 0-1000 and '
        'labels as stringified integers starting from "0". Use ~10-15 waypoints.'
    ),
    "Grasp points": (
        "Identify up to 3 good grasp points on objects visible in this scene. "
        "Output ONLY JSON as "
        '[{"point": [y, x], "label": <object + brief reason for this grasp>}] '
        "with coordinates normalized to 0-1000."
    ),
    # Requests a JSON object without coordinates, so normally no overlay is
    # drawn; the result appears only in the JSON / raw-text panels.
    "Instrument reading": (
        "If any gauge, meter, clock, dial, or digital display is visible, read it. "
        'Output ONLY a JSON object like {"reading": <value>, "units": <str>, '
        '"explanation": <short>}. If no instrument is visible, return '
        '{"reading": null, "explanation": "no instrument visible"}.'
    ),
    # Free text; no JSON is requested.
    "Free-form spatial reasoning": (
        "Describe the spatial layout of this scene and suggest a reasonable "
        "next action for a household or warehouse robot. Be concise."
    ),
}
|
|
|
|
def get_client(api_key: str | None = None) -> genai.Client:
    """Create a Gemini client from the UI-supplied key or the environment.

    A non-blank ``api_key`` (after stripping whitespace) takes precedence;
    otherwise the ``GEMINI_API_KEY`` environment variable is used.

    Raises:
        gr.Error: if neither source yields a key.
    """
    supplied = (api_key or "").strip()
    resolved = supplied if supplied else os.getenv("GEMINI_API_KEY")
    if resolved:
        return genai.Client(api_key=resolved)
    raise gr.Error(
        "No API key. Paste one in the 'Gemini API key' field above, "
        "or set GEMINI_API_KEY in .env."
    )
|
|
|
|
def extract_json(text: str):
    """Parse the JSON payload out of a model reply; return None on failure.

    Strategy, in order:
      1. If the reply contains a fenced block (```json ... ``` or a bare
         ``` ... ```), only the fence body is considered.
      2. Try to parse the (fence-stripped) text as a whole.
      3. Otherwise treat it as JSON embedded in prose: the earliest '[' or
         '{' at which a complete JSON value decodes wins, and any trailing
         text after that value is ignored.

    Args:
        text: Raw model response text (may be None or empty).

    Returns:
        The decoded JSON value, or None if nothing parseable was found.
    """
    if not text:
        return None
    stripped = text.strip()
    fenced = re.search(r"```(?:json)?\s*(.+?)```", stripped, re.DOTALL)
    if fenced:
        stripped = fenced.group(1).strip()
    try:
        return json.loads(stripped)
    except json.JSONDecodeError:
        pass

    # Walk candidate start positions left to right and let raw_decode find
    # where each value ends. Unlike greedy "\[.*\]" / "\{.*\}" regexes, this
    # never returns a list nested inside an object, and stray brackets in
    # surrounding prose do not prevent finding the real payload.
    decoder = json.JSONDecoder()
    for opener in re.finditer(r"[\[{]", stripped):
        try:
            value, _ = decoder.raw_decode(stripped, opener.start())
        except json.JSONDecodeError:
            continue
        return value
    return None
|
|
|
|
def load_font(size: int) -> ImageFont.ImageFont:
    """Return a font of roughly ``size`` pixels for overlay labels.

    Tries a few common macOS / Linux system fonts first. If none of them can
    be loaded, falls back to Pillow's built-in font, requesting ``size`` when
    the installed Pillow supports it so labels stay legible on large images.
    """
    for candidate in (
        "/System/Library/Fonts/Helvetica.ttc",
        "/System/Library/Fonts/Supplemental/Arial.ttf",
        "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
    ):
        if not os.path.exists(candidate):
            continue
        try:
            return ImageFont.truetype(candidate, size)
        except OSError:
            # Unreadable/corrupt font file: try the next candidate.
            continue
    try:
        # Pillow >= 10.1 accepts a size and returns a scalable font.
        return ImageFont.load_default(size=size)
    except (TypeError, ImportError):
        # Older Pillow (no ``size`` kwarg) or no FreeType support: only the
        # fixed-size bitmap font is available.
        return ImageFont.load_default()
|
|
|
|
def draw_label(draw: ImageDraw.ImageDraw, xy, text: str, color: str, font):
    """Write ``text`` at ``xy`` in white on a filled ``color`` background.

    The background is the text's bounding box expanded by a 3-pixel margin
    on every side. It is painted first so the text lands on top of it.
    """
    margin = 3
    left_x, top_y = xy
    origin = (left_x, top_y)
    x0, y0, x1, y1 = draw.textbbox(origin, text, font=font)
    backdrop = (x0 - margin, y0 - margin, x1 + margin, y1 + margin)
    draw.rectangle(backdrop, fill=color)
    draw.text(origin, text, fill="white", font=font)
|
|
|
|
def render_points(img: Image.Image, items: list) -> Image.Image:
    """Draw a colored dot (and optional label) for each point item.

    Each item is expected to be a dict holding ``"point"`` (or ``"xy"``) as
    ``[y, x]`` normalized to 0-1000, plus an optional ``"label"``. Items that
    are not dicts, lack a 2-element coordinate, or have non-numeric
    coordinates are skipped instead of aborting the whole render.

    Returns:
        A new RGB image with the overlay; ``img`` itself is not modified.
    """
    out = img.copy().convert("RGB")
    draw = ImageDraw.Draw(out)
    W, H = out.size
    font = load_font(max(14, W // 70))
    r = max(6, W // 180)
    for i, item in enumerate(items):
        if not isinstance(item, dict):
            continue
        pt = item.get("point") or item.get("xy")
        if not isinstance(pt, (list, tuple)) or len(pt) != 2:
            continue
        try:
            y, x = float(pt[0]), float(pt[1])
        except (TypeError, ValueError):
            continue
        # Scale normalized 0-1000 coordinates to pixels.
        cx, cy = x * W / 1000, y * H / 1000
        # Color by position in the list so colors stay stable even when
        # some items are skipped.
        color = PALETTE[i % len(PALETTE)]
        draw.ellipse((cx - r, cy - r, cx + r, cy + r), fill=color, outline="white", width=2)
        raw_label = item.get("label")
        # A JSON null label would otherwise be drawn as the text "None".
        label = "" if raw_label is None else str(raw_label)
        if label:
            draw_label(draw, (cx + r + 4, cy - r), label, color, font)
    return out
|
|
|
|
def render_boxes(img: Image.Image, items: list) -> Image.Image:
    """Draw a colored rectangle (and optional label) for each box item.

    Each item is expected to be a dict holding ``"box_2d"`` (or ``"bbox"``)
    as ``[ymin, xmin, ymax, xmax]`` normalized to 0-1000, plus an optional
    ``"label"``. Items that are not dicts, lack 4 coordinates, or have
    non-numeric coordinates are skipped.

    Returns:
        A new RGB image with the overlay; ``img`` itself is not modified.
    """
    out = img.copy().convert("RGB")
    draw = ImageDraw.Draw(out)
    W, H = out.size
    font = load_font(max(14, W // 70))
    for i, item in enumerate(items):
        if not isinstance(item, dict):
            continue
        box = item.get("box_2d") or item.get("bbox")
        if not isinstance(box, (list, tuple)) or len(box) != 4:
            continue
        try:
            ymin, xmin, ymax, xmax = (float(v) for v in box)
        except (TypeError, ValueError):
            continue
        # Pillow's rectangle() raises ValueError when x1 < x0 or y1 < y0, and
        # models occasionally emit swapped corners; normalize the order
        # instead of failing the whole render.
        x0, x1 = sorted((xmin * W / 1000, xmax * W / 1000))
        y0, y1 = sorted((ymin * H / 1000, ymax * H / 1000))
        color = PALETTE[i % len(PALETTE)]
        draw.rectangle((x0, y0, x1, y1), outline=color, width=3)
        raw_label = item.get("label")
        # A JSON null label would otherwise be drawn as the text "None".
        label = "" if raw_label is None else str(raw_label)
        if label:
            draw_label(draw, (x0 + 4, y0 + 4), label, color, font)
    return out
|
|
|
|
def render_trajectory(img: Image.Image, items: list) -> Image.Image:
    """Draw waypoints as a connected path, ordered by their numeric labels.

    Each item is expected to be a dict holding ``"point"`` (or ``"xy"``) as
    ``[y, x]`` normalized to 0-1000 and a ``"label"`` holding the step index
    (e.g. ``"0"``, ``"1"``). Items whose label is not numeric sort as step 0.
    The sort is stable, so they keep their relative input order. Non-dict
    items and items without a numeric 2-element coordinate are skipped.

    The first waypoint is drawn in PALETTE[0], the last in PALETTE[1], and
    the ones in between in blue, joined by blue line segments.

    Returns:
        A new RGB image with the overlay; ``img`` itself is not modified.
    """
    out = img.copy().convert("RGB")
    draw = ImageDraw.Draw(out)
    W, H = out.size
    font = load_font(max(12, W // 80))

    def step_index(item):
        # Accept "3" as well as "3.0"; anything unparsable sorts as step 0.
        try:
            return int(float(str(item.get("label", "0"))))
        except (ValueError, OverflowError):
            return 0

    waypoints = [item for item in items if isinstance(item, dict)]
    pts = []
    for item in sorted(waypoints, key=step_index):
        # Accept "xy" too, matching render_points and visualize's has_point.
        p = item.get("point") or item.get("xy")
        if not isinstance(p, (list, tuple)) or len(p) != 2:
            continue
        try:
            y, x = float(p[0]), float(p[1])
        except (TypeError, ValueError):
            continue
        raw_label = item.get("label")
        label = "" if raw_label is None else str(raw_label)
        pts.append((x * W / 1000, y * H / 1000, label))

    # Connect consecutive waypoints with segments.
    for (x0, y0, _), (x1, y1, _) in zip(pts, pts[1:]):
        draw.line((x0, y0, x1, y1), fill="#3b82f6", width=4)

    r = max(6, W // 200)
    last = len(pts) - 1
    for i, (x, y, label) in enumerate(pts):
        if i == 0:
            color = PALETTE[0]  # start
        elif i == last:
            color = PALETTE[1]  # end
        else:
            color = "#3b82f6"
        draw.ellipse((x - r, y - r, x + r, y + r), fill=color, outline="white", width=2)
        # Skip empty labels so no blank colored box is drawn.
        if label:
            draw_label(draw, (x + r + 3, y - r - 2), label, color, font)
    return out
|
|
|
|
| def _as_items(payload) -> list: |
| if isinstance(payload, list): |
| return payload |
| if isinstance(payload, dict): |
| for key in ("points", "boxes", "waypoints", "items", "detections"): |
| v = payload.get(key) |
| if isinstance(v, list): |
| return v |
| return [] |
|
|
|
|
def visualize(img: Image.Image, task: str, payload) -> Image.Image | None:
    """Draw the overlay that best fits ``payload`` onto ``img``.

    The task name expresses a preference, but the shape of the returned
    items has the final say. A reply in an unexpected shape (e.g. points for
    the bounding-box task) is therefore still drawn:

    * Trajectory task with point items -> connected waypoint path.
    * Any other case: box items -> rectangles, else point items -> dots.

    Returns None when the payload contains nothing drawable.
    """
    items = _as_items(payload)
    if not items:
        return None

    dict_items = [entry for entry in items if isinstance(entry, dict)]
    has_box = any("box_2d" in entry or "bbox" in entry for entry in dict_items)
    has_point = any("point" in entry or "xy" in entry for entry in dict_items)

    if task.startswith("Trajectory") and has_point:
        return render_trajectory(img, items)
    # The bounding-box task and all other tasks share the same preference
    # order: boxes first, then points.
    if has_box:
        return render_boxes(img, items)
    if has_point:
        return render_points(img, items)
    return None
|
|
|
|
| def _format_gemini_error(e: genai_errors.APIError) -> str: |
| """Extract the useful parts of a Gemini API error into a single-line message.""" |
| status = getattr(e, "code", None) or getattr(e, "status_code", "?") |
| message = getattr(e, "message", None) or str(e) |
| |
| details_blob = getattr(e, "details", None) or {} |
| retry_hint = "" |
| try: |
| err = details_blob.get("error", details_blob) if isinstance(details_blob, dict) else {} |
| for d in err.get("details", []) or []: |
| if d.get("@type", "").endswith("RetryInfo"): |
| delay = d.get("retryDelay", "") |
| if delay: |
| retry_hint = f" (retry in {delay})" |
| break |
| except Exception: |
| pass |
| |
| short = message.split("\n", 1)[0].strip() |
| if len(short) > 400: |
| short = short[:397] + "..." |
| return f"Gemini API error {status}: {short}{retry_hint}" |
|
|
|
|
def _image_to_png(image) -> tuple[Image.Image, bytes]:
    """Convert the UI input (PIL image or path) to RGB and return it with its PNG bytes."""
    pil = image if isinstance(image, Image.Image) else Image.open(image)
    pil = pil.convert("RGB")
    buf = io.BytesIO()
    pil.save(buf, format="PNG")
    return pil, buf.getvalue()


def _empty_response_message(response) -> str:
    """Build the user-facing error for an empty reply, naming finish/block reasons when present."""
    reasons = []
    candidates = getattr(response, "candidates", None) or []
    finish = getattr(candidates[0], "finish_reason", None) if candidates else None
    if finish:
        reasons.append(f"finish_reason={finish!s}")
    feedback = getattr(response, "prompt_feedback", None)
    block = getattr(feedback, "block_reason", None) if feedback is not None else None
    if block:
        reasons.append(f"block_reason={block!s}")
    return (
        "Model returned an empty response"
        + (f" ({', '.join(reasons)})" if reasons else "")
        + ". This can happen if the image was blocked by safety filters or "
        "if the output hit the token limit."
    )


def run_model(image: Image.Image | None, task: str, prompt: str, thinking_budget: int, api_key: str):
    """Run one Gemini Robotics-ER request and produce the three UI outputs.

    Args:
        image: Input image from the UI (PIL image, or a path Pillow can open).
        task: Selected task name; used to choose how the result is drawn.
        prompt: Prompt text sent alongside the image.
        thinking_budget: Thinking-token budget passed to the model.
        api_key: Optional key from the UI; see get_client() for the fallback.

    Returns:
        ``(result_image, pretty_json, raw_text)``. ``result_image`` is the
        overlay when one could be drawn, otherwise the (RGB) input image.

    Raises:
        gr.Error: on missing input, missing API key, API/request failure, or
            an empty model response.
    """
    if image is None:
        raise gr.Error("Please pick or upload an image first.")
    if not (prompt or "").strip():
        raise gr.Error("Prompt is empty.")

    pil, img_bytes = _image_to_png(image)

    try:
        client = get_client(api_key)
        config = types.GenerateContentConfig(
            temperature=1.0,
            thinking_config=types.ThinkingConfig(thinking_budget=int(thinking_budget)),
        )
        response = client.models.generate_content(
            model=MODEL,
            contents=[
                types.Part.from_bytes(data=img_bytes, mime_type="image/png"),
                prompt,
            ],
            config=config,
        )
    except gr.Error:
        # get_client() already raises a user-facing gr.Error (e.g. missing
        # key); pass it through instead of re-wrapping it as "Request failed".
        raise
    except genai_errors.APIError as e:
        raise gr.Error(_format_gemini_error(e)) from e
    except Exception as e:
        raise gr.Error(f"Request failed: {type(e).__name__}: {e}") from e

    text = response.text or ""
    if not text.strip():
        raise gr.Error(_empty_response_message(response))

    parsed = extract_json(text)
    overlay = None
    if parsed is not None:
        try:
            overlay = visualize(pil, task, parsed)
        except Exception as e:
            # Drawing is best-effort: a malformed item must not hide the
            # model's answer, so warn and fall back to the plain image.
            gr.Warning(f"Could not draw overlay: {type(e).__name__}: {e}")
    pretty = (
        json.dumps(parsed, indent=2, ensure_ascii=False)
        if parsed is not None
        else "(no JSON parsed)"
    )
    return overlay if overlay is not None else pil, pretty, text
|
|
|
|
def on_task_change(task: str) -> str:
    """Return the prompt template for the selected ``task``, or "" if unknown."""
    if task in TASK_TEMPLATES:
        return TASK_TEMPLATES[task]
    return ""
|
|
|
|
# Extra example images hotlinked from third-party hosts; build_ui() appends
# them after any local examples from EXAMPLES_DIR.
# NOTE(review): these URLs point at stock-photo sites and a LinkedIn CDN link
# with query tokens. They are outside our control, may stop resolving, and may
# carry licensing restrictions; consider vendoring images into EXAMPLES_DIR.
REMOTE_EXAMPLES = [
    "https://www.shutterstock.com/image-photo/adiyaman-turkey-aug-21-2024-260nw-2514334691.jpg",
    "https://c8.alamy.com/comp/2HXT4ME/full-containers-next-to-sets-of-lego-brand-bricks-with-figurines-on-a-carpet-in-a-playroom-2HXT4ME.jpg",
    "https://media.licdn.com/dms/image/v2/D5612AQF1km1N_jy66Q/article-cover_image-shrink_720_1280/B56Zh8QE__HUAI-/0/1754431246854?e=2147483647&v=beta&t=iEGQih1pUnJNFBbas0K9rCpez-VgJOxOKJ2TGQM1-tY",
]
|
|
|
|
def build_ui() -> gr.Blocks:
    """Assemble the Gradio app.

    The left column holds the input image, example gallery, task picker,
    editable prompt, thinking-budget slider and Run button. The right column
    shows the overlay plus the parsed JSON and raw model text.

    Example images are the local files in EXAMPLES_DIR with a common image
    suffix (.jpg/.jpeg/.png/.webp, case-insensitive), sorted by path and
    followed by REMOTE_EXAMPLES.
    """
    image_suffixes = {".jpg", ".jpeg", ".png", ".webp"}
    # glob("*") on a missing directory yields nothing, so no existence check
    # is needed; suffix matching is case-insensitive (".JPG" is included).
    local_examples = sorted(
        str(p)
        for p in EXAMPLES_DIR.glob("*")
        if p.is_file() and p.suffix.lower() in image_suffixes
    )
    examples = local_examples + REMOTE_EXAMPLES
    default_task = "Points (object detection)"

    with gr.Blocks(title="Gemini Robotics-ER 1.6 Visualizer") as demo:
        gr.Markdown(
            "# Gemini Robotics-ER 1.6 Visualizer\n"
            f"Model: `{MODEL}` — pick an image, choose a task, edit the prompt, run."
        )
        api_key = gr.Textbox(
            label="Gemini API key",
            value="",
            type="password",
            placeholder="Paste your key here (or leave empty if the server provides one)",
            info="Sent only to this server to call Gemini; never logged or persisted.",
        )
        with gr.Row():
            with gr.Column(scale=1):
                image = gr.Image(type="pil", label="Input image", height=380)
                if examples:
                    gr.Examples(examples=examples, inputs=image, label="Example images")
                task = gr.Dropdown(
                    choices=list(TASK_TEMPLATES.keys()),
                    value=default_task,
                    label="Task",
                )
                prompt = gr.Textbox(
                    label="Prompt (editable)",
                    value=TASK_TEMPLATES[default_task],
                    lines=6,
                )
                thinking = gr.Slider(
                    0, 8192, value=0, step=256,
                    label="Thinking budget (0 = fast, higher = more reasoning)",
                )
                run = gr.Button("Run", variant="primary")
            with gr.Column(scale=1):
                overlay = gr.Image(type="pil", label="Result", height=480)
                with gr.Accordion("Raw JSON", open=True):
                    parsed_out = gr.Code(label="Parsed JSON", language="json")
                with gr.Accordion("Model answer (raw text)", open=True):
                    raw_out = gr.Textbox(label="response.text", lines=8)

        # Event wiring must happen inside the Blocks context.
        task.change(on_task_change, inputs=task, outputs=prompt)
        run.click(
            run_model,
            inputs=[image, task, prompt, thinking, api_key],
            outputs=[overlay, parsed_out, raw_out],
        )
    return demo
|
|
|
|
if __name__ == "__main__":
    # Build the app, then serve it with the Soft theme.
    app = build_ui()
    app.launch(theme=gr.themes.Soft())
|
|