# Author: tugot17
# Add Gradio app for visualizing Gemini Robotics-ER 1.6 outputs
# Commit: 26ee000
"""Gradio webapp to visualize Gemini Robotics-ER 1.6 outputs."""
from __future__ import annotations
import io
import json
import os
import re
from pathlib import Path
import gradio as gr
from dotenv import load_dotenv
from google import genai
from google.genai import errors as genai_errors
from google.genai import types
from PIL import Image, ImageDraw, ImageFont
# Load variables from a local .env file into the process environment so that
# GEMINI_API_KEY can be supplied there (get_client reads it via os.getenv).
load_dotenv()
# Gemini model id sent with every request and shown in the page header.
MODEL = "gemini-robotics-er-1.6-preview"
# Folder of bundled example images, resolved relative to this source file.
EXAMPLES_DIR = Path(__file__).parent / "examples"
# Overlay colors (hex). The point/box renderers cycle through them by item
# index; the trajectory renderer uses the first two for its start/end markers.
PALETTE = [
    "#ef4444", "#22c55e", "#3b82f6", "#eab308", "#a855f7",
    "#ec4899", "#06b6d4", "#f97316", "#84cc16", "#14b8a6",
]
# Task name -> default prompt. The keys populate the Task dropdown and the
# value is loaded into the editable prompt box when the task changes.
# visualize() partly chooses its renderer by key prefix ("Trajectory",
# "Bounding"), so renaming a key can change how results are drawn.
# Geometry is requested as [y, x] / [ymin, xmin, ymax, xmax] normalized to
# 0-1000, which is the convention the renderers scale back to pixels.
TASK_TEMPLATES = {
    "Points (object detection)": (
        "Point to the most prominent objects in this image (up to 10). "
        "Output ONLY a JSON list in the form "
        '[{"point": [y, x], "label": <concise object name>}] '
        "with coordinates normalized to 0-1000."
    ),
    "Bounding boxes": (
        "Detect the prominent objects and return BOUNDING BOXES (rectangles), "
        "not points. Output ONLY a JSON array where every item has the key "
        '"box_2d" with four integers: '
        '[{"box_2d": [ymin, xmin, ymax, xmax], "label": <concise object name>}]. '
        "Coordinates normalized to 0-1000. Limit to 25 objects. "
        'Do NOT use "point" — every item must have "box_2d" with exactly 4 values.'
    ),
    # Waypoint labels are step indices; render_trajectory sorts by them.
    "Trajectory (path planning)": (
        "Pick one salient object in the image and plan a trajectory for a "
        "robot gripper to pick it up and move it to a nearby free area. "
        'Return ONLY a JSON list of ordered waypoints [{"point": [y, x], '
        '"label": "<step_index>"}] with coordinates normalized to 0-1000 and '
        'labels as stringified integers starting from "0". Use ~10-15 waypoints.'
    ),
    "Grasp points": (
        "Identify up to 3 good grasp points on objects visible in this scene. "
        "Output ONLY JSON as "
        '[{"point": [y, x], "label": <object + brief reason for this grasp>}] '
        "with coordinates normalized to 0-1000."
    ),
    # Returns a plain object with no geometry, so no overlay is drawn for it.
    "Instrument reading": (
        "If any gauge, meter, clock, dial, or digital display is visible, read it. "
        'Output ONLY a JSON object like {"reading": <value>, "units": <str>, '
        '"explanation": <short>}. If no instrument is visible, return '
        '{"reading": null, "explanation": "no instrument visible"}.'
    ),
    # Free text; the answer is shown raw and usually yields no parsed JSON.
    "Free-form spatial reasoning": (
        "Describe the spatial layout of this scene and suggest a reasonable "
        "next action for a household or warehouse robot. Be concise."
    ),
}
def get_client(api_key: str | None = None) -> genai.Client:
    """Create a Gemini client from an explicit key or the environment.

    A non-blank *api_key* (after stripping surrounding whitespace) wins;
    otherwise the ``GEMINI_API_KEY`` environment variable is used, which may
    have been populated from ``.env`` by ``load_dotenv()``.

    Raises:
        gr.Error: when neither source yields a non-empty key.
    """
    typed_key = (api_key or "").strip()
    resolved = typed_key if typed_key else os.getenv("GEMINI_API_KEY")
    if resolved:
        return genai.Client(api_key=resolved)
    raise gr.Error(
        "No API key. Paste one in the 'Gemini API key' field above, "
        "or set GEMINI_API_KEY in .env."
    )
def extract_json(text: str):
    """Parse the first JSON value out of a model reply; return None on failure.

    Steps, in order:

    1. If the reply contains a Markdown code fence (```json or a bare ```),
       only the contents of the first fence are considered. The ``json`` tag
       is matched case-insensitively.
    2. The (possibly unwrapped) text is parsed as a whole.
    3. Otherwise, each ``[`` / ``{`` is tried left to right as the start of
       an embedded JSON value, and the first one that decodes wins. Trailing
       prose after that value is ignored.

    Args:
        text: Raw model text; may be empty or None.

    Returns:
        The decoded JSON value (typically a list or dict), or None when
        nothing parses.
    """
    if not text:
        return None
    stripped = text.strip()
    fenced = re.search(
        r"```(?:json)?\s*(.+?)```", stripped, re.DOTALL | re.IGNORECASE
    )
    if fenced:
        stripped = fenced.group(1).strip()
    try:
        return json.loads(stripped)
    except json.JSONDecodeError:
        pass
    # Decoding from each opening bracket, rather than matching a greedy
    # regex, returns the outermost value that appears first. With the regex,
    # `{"points": [1, 2]}` in prose yielded the inner list, and a list
    # followed by text containing "]" failed to parse.
    decoder = json.JSONDecoder()
    for match in re.finditer(r"[\[{]", stripped):
        try:
            value, _end = decoder.raw_decode(stripped[match.start():])
        except json.JSONDecodeError:
            continue
        return value
    return None
def load_font(size: int) -> ImageFont.FreeTypeFont | ImageFont.ImageFont:
    """Return a font for overlay labels at roughly *size* pixels.

    Tries a few common system TrueType fonts (macOS Helvetica/Arial, Linux
    DejaVu Sans Bold) in order. If none exists or loads, falls back to
    Pillow's built-in default font.

    For the fallback, *size* is requested via ``load_default(size=...)``
    where the installed Pillow supports it, so labels stay legible. Older
    Pillow (no ``size`` parameter) and builds without FreeType get the
    fixed-size default.
    """
    for candidate in (
        "/System/Library/Fonts/Helvetica.ttc",
        "/System/Library/Fonts/Supplemental/Arial.ttf",
        "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
    ):
        if not os.path.exists(candidate):
            continue
        try:
            return ImageFont.truetype(candidate, size)
        except OSError:
            continue  # unreadable/corrupt font file: try the next one
    try:
        return ImageFont.load_default(size=size)
    except (TypeError, ImportError):
        # TypeError: Pillow too old to accept ``size``.
        # ImportError: sized default needs FreeType, which is unavailable.
        return ImageFont.load_default()
def draw_label(draw: ImageDraw.ImageDraw, xy, text: str, color: str, font):
    """Write *text* in white on a solid *color* box anchored at *xy*.

    The box is the text's bounding box grown by 3 px on each side. It is
    painted before the text so the text sits on top.
    """
    left, top = xy
    margin = 3
    box_x0, box_y0, box_x1, box_y1 = draw.textbbox((left, top), text, font=font)
    draw.rectangle(
        (box_x0 - margin, box_y0 - margin, box_x1 + margin, box_y1 + margin),
        fill=color,
    )
    draw.text((left, top), text, fill="white", font=font)
def render_points(img: Image.Image, items: list) -> Image.Image:
    """Draw each point annotation as a filled circle with an optional label.

    Args:
        img: Source image. An RGB copy is drawn on; *img* itself is untouched.
        items: Annotation dicts whose ``point`` (or ``xy``) is ``[y, x]`` on a
            0-1000 normalized grid, plus an optional ``label``.

    Returns:
        A new RGB image with the overlay. Entries that are not dicts, whose
        point is not a 2-element list/tuple, or whose coordinates are not
        numeric are skipped instead of raising. Numeric strings such as
        ``"500"`` are accepted.
    """
    out = img.copy().convert("RGB")
    draw = ImageDraw.Draw(out)
    W, H = out.size
    font = load_font(max(14, W // 70))
    r = max(6, W // 180)
    for i, item in enumerate(items):
        if not isinstance(item, dict):
            continue
        pt = item.get("point") or item.get("xy")
        if not isinstance(pt, (list, tuple)) or len(pt) != 2:
            continue
        try:
            y, x = float(pt[0]), float(pt[1])
        except (TypeError, ValueError):
            continue  # e.g. null or "n/a" coordinates in model output
        cx, cy = x * W / 1000, y * H / 1000
        # The color is keyed on the original index, so it stays stable even
        # when earlier entries are skipped.
        color = PALETTE[i % len(PALETTE)]
        draw.ellipse((cx - r, cy - r, cx + r, cy + r), fill=color, outline="white", width=2)
        label = str(item.get("label", ""))
        if label:
            draw_label(draw, (cx + r + 4, cy - r), label, color, font)
    return out
def render_boxes(img: Image.Image, items: list) -> Image.Image:
    """Draw each bounding-box annotation as an outlined rectangle with a label.

    Args:
        img: Source image. An RGB copy is drawn on; *img* itself is untouched.
        items: Annotation dicts with ``box_2d`` (or ``bbox``) given as
            ``[ymin, xmin, ymax, xmax]`` on a 0-1000 grid, plus an optional
            ``label``.

    Returns:
        A new RGB image with the overlay. Entries that are not dicts, or whose
        box is not four numeric values, are skipped. Boxes with swapped
        corners are reordered rather than crashing the draw call.
    """
    out = img.copy().convert("RGB")
    draw = ImageDraw.Draw(out)
    W, H = out.size
    font = load_font(max(14, W // 70))
    for i, item in enumerate(items):
        if not isinstance(item, dict):
            continue
        box = item.get("box_2d") or item.get("bbox")
        if not isinstance(box, (list, tuple)) or len(box) != 4:
            continue
        try:
            ymin, xmin, ymax, xmax = [float(v) for v in box]
        except (TypeError, ValueError):
            continue
        # Model output does not guarantee min <= max, and Pillow's
        # rectangle() raises ValueError when x1 < x0 or y1 < y0.
        ymin, ymax = sorted((ymin, ymax))
        xmin, xmax = sorted((xmin, xmax))
        x0, y0 = xmin * W / 1000, ymin * H / 1000
        x1, y1 = xmax * W / 1000, ymax * H / 1000
        color = PALETTE[i % len(PALETTE)]
        draw.rectangle((x0, y0, x1, y1), outline=color, width=3)
        label = str(item.get("label", ""))
        if label:
            draw_label(draw, (x0 + 4, y0 + 4), label, color, font)
    return out
def render_trajectory(img: Image.Image, items: list) -> Image.Image:
    """Draw ordered waypoints as a connected path with labeled markers.

    Waypoints are ordered by their ``label`` parsed as an integer step
    index. Labels that are not integers sort as step 0, and ties keep input
    order because the sort is stable. Consecutive waypoints are joined by a
    4 px blue line. The first marker uses ``PALETTE[0]``, the last uses
    ``PALETTE[1]``, and the rest are blue.

    Args:
        img: Source image. An RGB copy is drawn on; *img* itself is untouched.
        items: Waypoint dicts with ``point`` (or ``xy``) as ``[y, x]`` on a
            0-1000 grid and a ``label`` step index.

    Returns:
        A new RGB image. Non-dict entries, malformed or non-numeric points
        are skipped, and empty labels are not drawn (matching the other
        renderers).
    """
    out = img.copy().convert("RGB")
    draw = ImageDraw.Draw(out)
    W, H = out.size
    font = load_font(max(12, W // 80))

    def key(item):
        try:
            return int(str(item.get("label", "0")))
        except ValueError:
            return 0

    waypoints = [item for item in items if isinstance(item, dict)]
    pts = []
    for item in sorted(waypoints, key=key):
        # Accept "xy" too: visualize() routes "xy" payloads here.
        p = item.get("point") or item.get("xy")
        if not isinstance(p, (list, tuple)) or len(p) != 2:
            continue
        try:
            y, x = float(p[0]), float(p[1])
        except (TypeError, ValueError):
            continue
        pts.append((x * W / 1000, y * H / 1000, str(item.get("label", ""))))
    # zip() over (pts, pts[1:]) yields no segments for fewer than 2 points.
    for (x0, y0, _), (x1, y1, _) in zip(pts, pts[1:]):
        draw.line((x0, y0, x1, y1), fill="#3b82f6", width=4)
    r = max(6, W // 200)
    last = len(pts) - 1
    for i, (x, y, label) in enumerate(pts):
        color = PALETTE[0] if i == 0 else PALETTE[1] if i == last else "#3b82f6"
        draw.ellipse((x - r, y - r, x + r, y + r), fill=color, outline="white", width=2)
        if label:
            draw_label(draw, (x + r + 3, y - r - 2), label, color, font)
    return out
def _as_items(payload) -> list:
if isinstance(payload, list):
return payload
if isinstance(payload, dict):
for key in ("points", "boxes", "waypoints", "items", "detections"):
v = payload.get(key)
if isinstance(v, list):
return v
return []
def visualize(img: Image.Image, task: str, payload) -> Image.Image | None:
    """Overlay the model's annotations on *img*, or return None if nothing is drawable.

    The renderer is chosen from what the payload actually contains, not only
    from the selected task, so a reply in an unexpected shape still renders:

    * a Trajectory task whose items carry points -> connected waypoint path;
    * otherwise, any item with ``box_2d``/``bbox`` -> rectangles;
    * otherwise, any item with ``point``/``xy`` -> dots.
    """
    annotations = _as_items(payload)
    if not annotations:
        return None

    def _any_has(*keys):
        return any(
            isinstance(entry, dict) and any(k in entry for k in keys)
            for entry in annotations
        )

    boxes_present = _any_has("box_2d", "bbox")
    points_present = _any_has("point", "xy")

    if task.startswith("Trajectory") and points_present:
        return render_trajectory(img, annotations)
    # Boxes take precedence over points for every other task (including the
    # Bounding boxes task, which falls back to points when no boxes came back).
    if boxes_present:
        return render_boxes(img, annotations)
    if points_present:
        return render_points(img, annotations)
    return None
def _format_gemini_error(e: genai_errors.APIError) -> str:
"""Extract the useful parts of a Gemini API error into a single-line message."""
status = getattr(e, "code", None) or getattr(e, "status_code", "?")
message = getattr(e, "message", None) or str(e)
# The SDK stores the raw response JSON on the exception; try to pull retry info.
details_blob = getattr(e, "details", None) or {}
retry_hint = ""
try:
err = details_blob.get("error", details_blob) if isinstance(details_blob, dict) else {}
for d in err.get("details", []) or []:
if d.get("@type", "").endswith("RetryInfo"):
delay = d.get("retryDelay", "")
if delay:
retry_hint = f" (retry in {delay})"
break
except Exception:
pass
# Trim super-long messages for the toast.
short = message.split("\n", 1)[0].strip()
if len(short) > 400:
short = short[:397] + "..."
return f"Gemini API error {status}: {short}{retry_hint}"
def run_model(image: Image.Image | None, task: str, prompt: str, thinking_budget: int, api_key: str):
    """Send one image + prompt to Gemini and prepare the three UI outputs.

    Args:
        image: Input picture, either a PIL image or anything ``Image.open``
            accepts.
        task: Selected task name; only used to choose how overlays are drawn.
        prompt: Text sent to the model alongside the image.
        thinking_budget: Forwarded to ``types.ThinkingConfig``. The UI labels
            0 as fast and higher values as more reasoning.
        api_key: Key typed in the UI. When blank, get_client falls back to
            GEMINI_API_KEY.

    Returns:
        ``(image, parsed_json_text, raw_text)``:

        * the annotated image, or the plain RGB input when nothing could be
          drawn;
        * the parsed JSON pretty-printed, or ``"(no JSON parsed)"``;
        * the model's raw text.

    Raises:
        gr.Error: missing image, prompt or API key; a failed request; or an
            empty model response.
    """
    if image is None:
        raise gr.Error("Please pick or upload an image first.")
    if not (prompt or "").strip():
        raise gr.Error("Prompt is empty.")

    if isinstance(image, Image.Image):
        pil = image.convert("RGB")
    else:
        # convert() loads the pixels into a new image, so the file handle
        # opened here can be closed right away.
        with Image.open(image) as src:
            pil = src.convert("RGB")
    buf = io.BytesIO()
    pil.save(buf, format="PNG")
    img_bytes = buf.getvalue()

    try:
        client = get_client(api_key)
        config = types.GenerateContentConfig(
            temperature=1.0,
            thinking_config=types.ThinkingConfig(thinking_budget=int(thinking_budget)),
        )
        response = client.models.generate_content(
            model=MODEL,
            contents=[
                types.Part.from_bytes(data=img_bytes, mime_type="image/png"),
                prompt,
            ],
            config=config,
        )
    except gr.Error:
        # Already user-facing (e.g. get_client's missing-key message). Without
        # this clause, the generic handler below would re-wrap it as
        # "Request failed: Error: '...'".
        raise
    except genai_errors.APIError as e:
        raise gr.Error(_format_gemini_error(e)) from e
    except Exception as e:
        raise gr.Error(f"Request failed: {type(e).__name__}: {e}") from e

    text = response.text or ""
    if not text.strip():
        # The SDK gave us an empty response — surface whatever we can.
        finish = ""
        try:
            finish = str(response.candidates[0].finish_reason) if response.candidates else ""
        except Exception:
            pass  # best effort: the message below is still useful without it
        raise gr.Error(
            "Model returned an empty response"
            + (f" (finish_reason={finish})" if finish else "")
            + ". This can happen if the image was blocked by safety filters or "
            "if the output hit the token limit."
        )

    parsed = extract_json(text)
    if parsed is None:
        return pil, "(no JSON parsed)", text
    overlay = visualize(pil, task, parsed)
    # ensure_ascii=False keeps non-ASCII labels readable in the JSON panel.
    pretty = json.dumps(parsed, indent=2, ensure_ascii=False)
    return (overlay if overlay is not None else pil), pretty, text
def on_task_change(task: str) -> str:
    """Return the default prompt template for *task*, or "" for an unknown task.

    Wired to the Task dropdown's change event so that picking a task refills
    the prompt box.
    """
    try:
        return TASK_TEMPLATES[task]
    except KeyError:
        return ""
# Extra example images, offered after the local files in build_ui().
# NOTE(review): these depend on third-party hosts staying reachable. The
# LinkedIn URL carries signed query parameters (e=..., t=...) that may stop
# working, so check these first if the example gallery breaks.
REMOTE_EXAMPLES = [
    # Licensed/watermarked images — referenced by URL so we don't redistribute them.
    "https://www.shutterstock.com/image-photo/adiyaman-turkey-aug-21-2024-260nw-2514334691.jpg",
    "https://c8.alamy.com/comp/2HXT4ME/full-containers-next-to-sets-of-lego-brand-bricks-with-figurines-on-a-carpet-in-a-playroom-2HXT4ME.jpg",
    "https://media.licdn.com/dms/image/v2/D5612AQF1km1N_jy66Q/article-cover_image-shrink_720_1280/B56Zh8QE__HUAI-/0/1754431246854?e=2147483647&v=beta&t=iEGQih1pUnJNFBbas0K9rCpez-VgJOxOKJ2TGQM1-tY",
]
def build_ui() -> gr.Blocks:
    """Build (but do not launch) the Gradio interface.

    Layout: an API-key field on top, then two columns.

    * Left: input image, example gallery, task dropdown, editable prompt,
      thinking-budget slider and the Run button.
    * Right: the annotated result, plus accordions with the parsed JSON and
      the raw model text.

    Picking a task reloads its prompt template. Run calls ``run_model`` with
    the image, task, prompt, thinking budget and API key.

    Returns:
        The assembled ``gr.Blocks`` app.
    """
    # Local examples: any common image type in EXAMPLES_DIR, matching the
    # extension case-insensitively (so ".JPG" and ".png" are picked up),
    # sorted by path. The remote URLs follow.
    image_suffixes = {".jpg", ".jpeg", ".png", ".webp"}
    local_examples = sorted(
        str(p)
        for p in EXAMPLES_DIR.glob("*")
        if p.suffix.lower() in image_suffixes and p.is_file()
    )
    examples = local_examples + REMOTE_EXAMPLES
    default_task = "Points (object detection)"
    with gr.Blocks(title="Gemini Robotics-ER 1.6 Visualizer") as demo:
        gr.Markdown(
            "# Gemini Robotics-ER 1.6 Visualizer\n"
            f"Model: `{MODEL}` — pick an image, choose a task, edit the prompt, run."
        )
        api_key = gr.Textbox(
            label="Gemini API key",
            value="",
            type="password",
            placeholder="Paste your key here (or leave empty if the server provides one)",
            info="Sent only to this server to call Gemini; never logged or persisted.",
        )
        with gr.Row():
            with gr.Column(scale=1):
                image = gr.Image(type="pil", label="Input image", height=380)
                if examples:
                    gr.Examples(examples=examples, inputs=image, label="Example images")
                task = gr.Dropdown(
                    choices=list(TASK_TEMPLATES.keys()),
                    value=default_task,
                    label="Task",
                )
                prompt = gr.Textbox(
                    label="Prompt (editable)",
                    value=TASK_TEMPLATES[default_task],
                    lines=6,
                )
                thinking = gr.Slider(
                    0, 8192, value=0, step=256,
                    label="Thinking budget (0 = fast, higher = more reasoning)",
                )
                run = gr.Button("Run", variant="primary")
            with gr.Column(scale=1):
                overlay = gr.Image(type="pil", label="Result", height=480)
                with gr.Accordion("Raw JSON", open=True):
                    parsed_out = gr.Code(label="Parsed JSON", language="json")
                with gr.Accordion("Model answer (raw text)", open=True):
                    raw_out = gr.Textbox(label="response.text", lines=8)
        task.change(on_task_change, inputs=task, outputs=prompt)
        run.click(
            run_model,
            inputs=[image, task, prompt, thinking, api_key],
            outputs=[overlay, parsed_out, raw_out],
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the interface, then serve it with the Soft theme.
    app = build_ui()
    app.launch(theme=gr.themes.Soft())