labrats / app.py
Smestern's picture
Deploy Labrats live Space (free-tier, remote embeddings)
f049226 verified
Raw
History Blame Contribute Delete
42.3 kB
"""Labrats replay UI \u2014 a Gradio Space that walks through a recorded
DiscoveryWorld episode and surfaces the memory layer's behaviour.
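Three tabs:
1. Run live: start a new episode on a remote (Hugging Face Inference
Providers, needs `HF_TOKEN`) or local (llama-cpp) backend and watch
frames, team chat and memory stream in while it runs.
2. Episode replay: step through any run, see action / success /
scorecard, the memory hits the agent retrieved on that step, and
the cumulative notebook so far.
3. Compare two runs: side-by-side baseline-vs-stage2 action histograms,
score curves, and step-count summary to make the "memory fixes the
loop" delta visible at a glance.
Replay and comparison only read pre-recorded traces from ./runs and do
not run any model, so they work on the free CPU tier with no GPU / quota
dependencies. Live runs are also written into ./runs, so they become
replayable once they finish.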
"""
from __future__ import annotations
import json
import sys
from pathlib import Path
import gradio as gr
import pandas as pd
# Make `src/` importable when the app is launched from the repo root.
# ROOT is the directory containing this file. `src/` is inserted at the
# FRONT of sys.path so the in-repo `labrats` package is found before any
# other copy on the path. This has to execute before the `labrats.*`
# imports that follow (hence their `noqa: E402` markers).
ROOT = Path(__file__).resolve().parent
SRC = ROOT / "src"
if SRC.exists() and str(SRC) not in sys.path:
    sys.path.insert(0, str(SRC))
from labrats.live import ( # noqa: E402
LiveConfig,
LiveRunner,
backend_available,
has_hf_token,
has_local_model,
)
from labrats.replay import ( # noqa: E402
Episode,
MemoryHitSnap,
StepRow,
author_color_map,
discover_episodes,
discover_frame_groups,
load_episode,
)
# `runs/` holds recorded traces (and frames); `video/` is the legacy frame
# directory passed to discover_frame_groups as `legacy_video_dir`.
RUNS_DIR = ROOT / "runs"
VIDEO_DIR = ROOT / "video"
# Serve frames directly so swapping <img src=...> doesn't go
# through Gradio's lazy file pipeline (which causes the fade).
# Only directories that already exist at import time are registered.
# NOTE(review): live runs are written into RUNS_DIR; if that directory does
# not exist when the app starts, frames written there later are not covered
# by these static paths — confirm the deployed repo always ships `runs/`.
STATIC_FRAME_PATHS = [p for p in (RUNS_DIR, VIDEO_DIR) if p.exists()]
if STATIC_FRAME_PATHS:
    gr.set_static_paths(paths=[str(p.resolve()) for p in STATIC_FRAME_PATHS])
def load_all_episodes() -> dict[str, Episode]:
    """Scan RUNS_DIR and index every discovered episode by its label.

    If two episodes share a label, the one discovered later wins.
    """
    index: dict[str, Episode] = {}
    for episode in discover_episodes([RUNS_DIR]):
        index[episode.label] = episode
    return index
EPISODES = load_all_episodes()
# Initial dropdown selection: the first label mentioning "stage2" (the
# memory-on run) if any, otherwise the first discovered episode, else None.
DEFAULT_PICK = next(
    filter(lambda label: "stage2" in label.lower(), EPISODES), None
) or next(iter(EPISODES), None)
# ---- rendering helpers --------------------------------------------
def _ep(label: str | None) -> Episode | None:
    """Return the loaded episode for a dropdown *label*.

    Blank labels and labels not present in EPISODES both yield ``None``.
    """
    return EPISODES.get(label) if label else None
def _format_action(a: dict) -> str:
name = a.get("action", "?")
bits = [f"**{name}**"]
for k in ("arg1", "arg2"):
v = a.get(k)
if v is not None:
bits.append(f"{k}={v}")
return " &nbsp; ".join(bits)
def _format_hits(hits: list[MemoryHitSnap]) -> str:
if not hits:
return "_(none)_"
lines = []
for h in hits:
lines.append(
f"- **[{h.type} @step {h.timestamp_step}]** "
f"_score={h.score:.2f} rec={h.recency:.2f} "
f"imp={h.importance:.2f} rel={h.relevance:.2f}_ \n"
f" {h.content}"
)
return "\n".join(lines)
def _format_notebook(items: list[str]) -> str:
if not items:
return "_(empty)_"
return "\n".join(f"- {x}" for x in items)
def _html_escape(s: str) -> str:
return (
s.replace("&", "&amp;")
.replace("<", "&lt;")
.replace(">", "&gt;")
)
def _format_chat(
ep: Episode | None,
*,
current_step: int,
visible_senders: list[str] | None = None,
) -> str:
"""Render the chat log as colored HTML; highlight the current step's row."""
if ep is None or not ep.chat_log:
if ep is not None and not ep.dialogue_enabled:
return "_Dialogue was disabled for this run._"
return "_(no chat in this run)_"
palette = author_color_map(ep.agent_names)
allowed = set(visible_senders) if visible_senders else set(ep.agent_names)
rows: list[str] = []
for line in ep.chat_log:
if line.sender not in allowed:
continue
color = palette.get(line.sender, "#444")
is_current = line.step == current_step
bg = "background:#080121;" if is_current else ""
rows.append(
f'<div style="padding:2px 6px;border-left:3px solid {color};'
f'margin:2px 0;{bg}font-family:ui-monospace,Menlo,monospace;font-size:12px;">'
f'<span style="color:#888">t{line.tick:02d} s{line.step:03d}</span> '
f'<b style="color:{color}">{_html_escape(line.sender)}</b>'
f' &rarr; <b>{_html_escape(line.to)}</b>: '
f'{_html_escape(line.text)}'
f"</div>"
)
if not rows:
return "_(no chat matches the current filter)_"
return (
'<div style="max-height:360px;overflow-y:auto;border:1px solid #ddd;'
'border-radius:6px;padding:4px;">'
+ "".join(rows)
+ "</div>"
)
def _action_hist_df(ep: Episode) -> pd.DataFrame:
hist = ep.action_histogram()
return pd.DataFrame({"action": list(hist.keys()), "count": list(hist.values())})
def _score_curve_df(ep: Episode) -> pd.DataFrame:
rows = []
for r in ep.steps:
sc = r.scorecard or {}
rows.append(
{
"step": r.step,
"scoreNormalized": float(sc.get("scoreNormalized") or 0.0),
}
)
return pd.DataFrame(rows) if rows else pd.DataFrame({"step": [], "scoreNormalized": []})
def _episode_header(ep: Episode) -> str:
n_steps = len(ep.steps)
last_sc = ep.steps[-1].scorecard if ep.steps else {}
badges = (
f"`agents={ep.n_agents}` &nbsp; "
f"`notebook={ep.notebook_mode}` &nbsp; "
f"`dialogue={'on' if ep.dialogue_enabled else 'off'}` &nbsp; "
f"`inbox_k={ep.inbox_k}` &nbsp; "
f"`chat_lines={len(ep.chat_log)}`"
)
return (
f"### {ep.path.parent.name}\n"
f"- **scenario:** {ep.scenario} ({ep.difficulty}, seed={ep.seed})\n"
f"- **agents:** {', '.join(ep.agent_names) or '?'}\n"
f"- **steps used:** {n_steps} / {ep.max_steps}\n"
f"- **final:** {ep.score_summary}\n"
f"- **last scorecard:** {json.dumps(last_sc)}\n"
f"- {badges}"
)
def _step_panel(ep: Episode | None, step_idx: int):
"""Returns header_md, action_md, scorecard_md, notebook_hits_md,
private_hits_md, notebook_cum_md, private_cum_md."""
if ep is None or not ep.steps:
empty = "_(no episode loaded)_"
return empty, empty, empty, empty, empty, empty, empty
step_idx = max(0, min(step_idx, len(ep.steps) - 1))
row: StepRow = ep.steps[step_idx]
status = "✅ success" if row.success else "❌ failure"
header = (
f"### Step {row.step} &nbsp;(tick {row.tick}, agent `{row.agent_name}`)\n\n"
f"{_format_action(row.action)} &nbsp;\u2192 "
f"{status}"
+ (f" \nretries: {row.retries}" if row.retries else "")
)
if row.errors:
header += f"\n\n**Errors:**\n```\n{chr(10).join(row.errors)}\n```"
if row.thought:
header += f"\n\n**Thought:** {row.thought}"
if row.utterance:
header += (
f"\n\n**\U0001f4ac Said to `{row.utterance.get('to', '?')}`:** "
f"{row.utterance.get('text', '')}"
)
sc_md = "**Scorecard progress:** " + (json.dumps(row.scorecard) if row.scorecard else "_(none)_")
notebook_hits = _format_hits(row.memory_hits.get("notebook", []))
private_hits = _format_hits(row.memory_hits.get("private", []))
notebook_cum = _format_notebook(row.notebook_so_far)
private_cum = _format_notebook(row.private_so_far)
return header, _episode_header(ep), sc_md, notebook_hits, private_hits, notebook_cum, private_cum
def _comparison_md(left: Episode | None, right: Episode | None) -> str:
"""Side-by-side summary table."""
if left is None or right is None:
return "_Select two episodes to compare._"
def _row(label: str, l_val, r_val) -> str:
return f"| **{label}** | {l_val} | {r_val} |"
lh, rh = left.action_histogram(), right.action_histogram()
actions = sorted(set(lh) | set(rh))
head = "| metric | " + left.path.parent.name + " | " + right.path.parent.name + " |"
sep = "| --- | --- | --- |"
rows = [
head,
sep,
_row("score", left.score_summary, right.score_summary),
_row("steps used", f"{len(left.steps)} / {left.max_steps}", f"{len(right.steps)} / {right.max_steps}"),
_row("ticks_run", left.ticks_run, right.ticks_run),
_row("tasks_complete", left.tasks_complete, right.tasks_complete),
]
for a in actions:
rows.append(_row(a, lh.get(a, 0), rh.get(a, 0)))
return "\n".join(rows)
# ---- callbacks ----------------------------------------------------
def on_episode_change(label: str):
    """Reset the replay tab for a newly selected episode.

    Returns 12 updates, in output order: step slider, step header, episode
    header, scorecard, notebook hits, private hits, notebook cumulative,
    private cumulative, action histogram, score curve, chat sender filter,
    and chat HTML. The slider and panels are positioned at the first step,
    and the chat filter is set to all of the episode's agents.
    """
    ep = _ep(label)
    if ep is None or not ep.steps:
        placeholder = "_(no episode)_"
        return (
            gr.update(maximum=0, value=0, label="Step"),
            *([placeholder] * 7),
            pd.DataFrame(),
            pd.DataFrame(),
            gr.update(choices=[], value=[]),
            placeholder,
        )
    last_idx = len(ep.steps) - 1
    panels = _step_panel(ep, 0)
    chat_html = _format_chat(
        ep,
        current_step=ep.steps[0].step,
        visible_senders=list(ep.agent_names),
    )
    slider = gr.update(
        minimum=0, maximum=last_idx, value=0, step=1, label=f"Step (0 .. {last_idx})"
    )
    return (
        slider,
        *panels,
        _action_hist_df(ep),
        _score_curve_df(ep),
        gr.update(choices=list(ep.agent_names), value=list(ep.agent_names)),
        chat_html,
    )
def on_step_change(label: str, step_idx: float, senders: list[str] | None):
    """Re-render the per-step panels and chat when the step slider moves.

    Returns 7 updates: step header, scorecard, notebook hits, private hits,
    notebook cumulative, private cumulative, and chat HTML. The episode
    header produced by ``_step_panel`` is not re-emitted.

    The step index used for the chat highlight is clamped into the episode's
    range, the same way ``_step_panel`` clamps. Without this, a stale slider
    value (e.g. after the episode list was refreshed) could raise
    ``IndexError``, or a negative value could silently wrap to the end.
    """
    ep = _ep(label)
    header, _ep_md, sc_md, nb_hits, pr_hits, nb_cum, pr_cum = _step_panel(ep, int(step_idx))
    if ep and ep.steps:
        idx = max(0, min(int(step_idx), len(ep.steps) - 1))
        current_step = ep.steps[idx].step
    else:
        current_step = -1
    chat_md = _format_chat(ep, current_step=current_step, visible_senders=senders or None)
    return header, sc_md, nb_hits, pr_hits, nb_cum, pr_cum, chat_md
def on_chat_filter_change(label: str, step_idx: float, senders: list[str] | None):
    """Re-render only the chat panel after the sender filter changes.

    An empty filter is passed on as ``None``, which shows all agents. The
    step index is clamped into the episode's range (as ``_step_panel``
    does), so a stale slider value cannot raise ``IndexError`` or wrap
    around via a negative index.
    """
    ep = _ep(label)
    if ep and ep.steps:
        idx = max(0, min(int(step_idx), len(ep.steps) - 1))
        current_step = ep.steps[idx].step
    else:
        current_step = -1
    return _format_chat(ep, current_step=current_step, visible_senders=senders or None)
def on_step_nudge(label: str, step_idx: float, delta: int):
    """Return the step index moved by *delta*, clamped to the episode's range.

    This never wraps around. It yields 0 when no episode (or no steps) is
    loaded.
    """
    ep = _ep(label)
    if ep is None or not ep.steps:
        return 0
    upper = len(ep.steps) - 1
    return min(upper, max(0, int(step_idx) + delta))
def on_compare(left_label: str, right_label: str):
    """Build the compare tab's outputs for two episode labels.

    Returns the summary table, left/right action histograms, and left/right
    score curves. A side whose episode cannot be found gets empty frames.
    """
    left, right = _ep(left_label), _ep(right_label)

    def _or_empty(builder, episode):
        return builder(episode) if episode else pd.DataFrame()

    return (
        _comparison_md(left, right),
        _or_empty(_action_hist_df, left),
        _or_empty(_action_hist_df, right),
        _or_empty(_score_curve_df, left),
        _or_empty(_score_curve_df, right),
    )
def on_refresh():
    """Rescan RUNS_DIR and refresh the three episode dropdowns.

    The module-level EPISODES is rebound. The replay and right-compare
    pickers default to the first "stage2" label (else the first label). The
    left-compare picker prefers the first "baseline" label, falling back to
    that same default.
    """
    global EPISODES
    EPISODES = load_all_episodes()
    labels = list(EPISODES)
    stage2_label = next((k for k in labels if "stage2" in k.lower()), None)
    default = stage2_label or (labels[0] if labels else None)
    baseline_label = next((k for k in labels if "baseline" in k.lower()), None)
    return (
        gr.update(choices=labels, value=default),
        gr.update(choices=labels, value=(baseline_label or default)),
        gr.update(choices=labels, value=default),
    )
# ---- live inference ----------------------------------------------
# A single background runner shared by the Space. Live episodes are written
# into RUNS_DIR so they become replayable like any recorded run. Only one
# episode runs at a time (the shared HF token pays for every visitor).
# Process-wide runner; its `status` (state / message / run_label / error /
# trace_path) is read by the live-tab callbacks below.
LIVE_RUNNER = LiveRunner(RUNS_DIR)
# Curated scenarios for the dropdown. "Archaeology Dating" is the validated
# multi-agent map; the Small Skills tests are cheap single-agent smoke runs.
# The dropdown allows custom values so any DiscoveryWorld scenario name works.
LIVE_SCENARIOS = [
    "Archaeology Dating",
    "Proteomics",
    "Plant Nutrients",
    "Reactor Lab",
    "Space Sick",
    "Small Skills: Instrument Measurement Test",
    "Small Skills: Dialog Test",
]
def _live_status_md() -> str:
    """Summarize the live runner's status as Markdown.

    Idle shows a usage hint. Other states show an icon, the state and its
    message, then optionally the run label and the last 1500 characters of
    any error in a code fence.
    """
    status = LIVE_RUNNER.status
    if status.state == "idle":
        return "_Idle — configure a run and press **Start live episode**._"
    state_icons = {"running": "\u23f3", "done": "\u2705", "error": "\u274c"}
    icon = state_icons.get(status.state, "\u2022")
    sections = [f"{icon} **{status.state}** — {status.message}"]
    if status.run_label:
        sections.append(f"\nRun: `{status.run_label}`")
    if status.error:
        tail = status.error[-1500:]
        sections.append("\n```\n" + tail + "\n```")
    return "\n".join(sections)
def on_live_start(
    scenario, difficulty, seed, n_agents, max_steps, memory, dialogue, backend,
    model, max_tokens,
):
    """Validate the live-run form and start a background episode.

    Returns ``(status_markdown, timer_update)``. The poll timer is activated
    only when a run is actually started. If the selected backend is not
    available, an explanatory message is returned and the timer stays off.

    Empty form fields are normalized before building ``LiveConfig``:

    - A cleared seed (``None``) falls back to 0, the widget's default.
    - A blank or ``None`` model becomes ``None``, so the backend picks its
      default. Previously ``str(None)`` requested a model literally named
      "None".
    """
    backend = str(backend)
    if not backend_available(backend):
        if backend == "local":
            msg = (
                "\u274c No local model configured. Set `LLAMA_MODEL_PATH` (or "
                "`LLAMA_REPO_ID` + `LLAMA_FILENAME`) and restart, then try again."
            )
        else:
            msg = (
                "\u274c No `HF_TOKEN` is configured, so the HF backend is "
                "unavailable. Use the **Episode replay** tab to view recorded runs."
            )
        return msg, gr.update(active=False)
    # gr.Number can yield None when its field is cleared; int(None) would raise.
    seed_value = 0 if seed is None else int(seed)
    model_name = "" if model is None else str(model).strip()
    cfg = LiveConfig(
        scenario=str(scenario),
        difficulty=str(difficulty),
        seed=seed_value,
        n_agents=int(n_agents),
        max_steps=int(max_steps),
        memory=bool(memory),
        dialogue=bool(dialogue),
        backend=backend,
        model=(model_name or None),
        max_tokens=int(max_tokens),
    )
    LIVE_RUNNER.start(cfg)
    return _live_status_md(), gr.update(active=True)
def _live_thread_for(run_label: str, thread_keys: list[str]) -> str | None:
"""Find the frame-group thread belonging to the current live run."""
if run_label:
match = next((t for t in thread_keys if run_label in t), None)
if match is not None:
return match
return thread_keys[0] if thread_keys else None
def _live_memory_views() -> tuple[str, str, str, str, str]:
    """Render the live tab's chat and memory panels from the in-flight trace.

    Returns a 5-tuple, in this order::

        (chat_html, notebook_hits_md, private_hits_md,
         notebook_cum_md, private_cum_md)

    The hits and cumulative notebook/private notes come from the most recent
    step in the trace. The trace is loaded with ``tolerant=True`` because its
    final line may be half-written when we catch it mid-flush. A missing
    trace file, a failed load, or a trace with no steps all yield the
    placeholder tuple.
    """
    # The chat panel is a gr.HTML component, so its placeholder is HTML
    # (gr.HTML would show Markdown underscores literally).
    placeholders = (
        "<em>(no chat yet)</em>",
        "_(none yet)_",
        "_(none yet)_",
        "_(empty)_",
        "_(empty)_",
    )
    st = LIVE_RUNNER.status
    trace_path = Path(st.trace_path) if st.trace_path else None
    if trace_path is None or not trace_path.exists():
        return placeholders
    try:
        ep = load_episode(trace_path, tolerant=True)
    except Exception:  # noqa: BLE001 — never let a bad read break the poll loop
        return placeholders
    if not ep.steps:
        return placeholders
    last = ep.steps[-1]
    return (
        _format_chat(ep, current_step=last.step, visible_senders=list(ep.agent_names)),
        _format_hits(last.memory_hits.get("notebook", [])),
        _format_hits(last.memory_hits.get("private", [])),
        _format_notebook(last.notebook_so_far),
        _format_notebook(last.private_so_far),
    )
def on_live_poll():
    """Poll the background runner and refresh the live tab.

    While the episode is running, the run's frame directory is re-scanned
    and the newest all-agents frame is shown in the live viewport. When the
    run finishes (``done`` or ``error``), episodes are reloaded, the replay
    dropdown is pointed at the finished run, and the timer is stopped.

    Note: this deliberately drives only the dedicated live viewport (an HTML
    image + caption), never the replay tab's frame slider. Feeding a slider a
    shrinking maximum while it holds a larger cached value makes Gradio raise
    "Value N is greater than maximum value M" on the next preprocess.

    Returns 10 updates: status, timer, replay dropdown, live frame HTML,
    live caption, then the five panels from ``_live_memory_views``.
    """
    global EPISODES, THREADS, DEFAULT_THREAD
    status = LIVE_RUNNER.status
    is_finished = status.state in ("done", "error")

    # Pick up any frames written so far (and, when done, the full trace).
    THREADS = discover_threads()
    thread = _live_thread_for(status.run_label, list(THREADS))
    DEFAULT_THREAD = thread
    agent_ids = _thread_agents(thread)
    latest_idx = _max_frame_idx(thread)

    # Always show every agent at the newest frame (auto-advance while streaming).
    viewport_html = _frame_html_multi(thread, latest_idx)
    if thread and latest_idx >= 0 and agent_ids:
        caption = f"**{thread}** \u00b7 frame **{latest_idx + 1}** \u00b7 {len(agent_ids)} agent(s)"
    else:
        caption = "_(waiting for first frame…)_"

    # Stream chat / notebook / private / hits from the partial trace.
    memory_views = _live_memory_views()

    if is_finished:
        EPISODES = load_all_episodes()
        labels = list(EPISODES)
        finished_label = next(
            (k for k in labels if status.run_label and status.run_label in k), None
        )
        target = finished_label or (labels[0] if labels else None)
        status_md = _live_status_md()
        timer_update = gr.Timer(active=False)
        picker_update = gr.update(choices=labels, value=target)
    else:
        # Still running: keep the timer ticking and leave the replay dropdown alone.
        status_md = _live_status_md()
        timer_update = gr.update()
        picker_update = gr.update()

    return (status_md, timer_update, picker_update, viewport_html, caption, *memory_views)
# ---- frame playback ----------------------------------------------
def discover_threads() -> dict[str, dict[str, list[Path]]]:
    """Scan for frame groups under RUNS_DIR (plus the legacy VIDEO_DIR).

    Returns ``{frame_group: {agent_id: [frame_path, ...]}}``, with each
    agent's frames sorted by frame index.
    """
    groups = discover_frame_groups(RUNS_DIR, legacy_video_dir=VIDEO_DIR)
    return groups
# Frame groups discovered at import time. Both names are rebound later by
# on_live_poll and on_refresh_video (via `global`), so readers must look
# them up at call time rather than caching them.
THREADS = discover_threads()
DEFAULT_THREAD = next(iter(THREADS), None)
def _thread_agents(thread: str | None) -> list[str]:
    """List the agent ids that have frames in *thread* ([] if unknown/blank)."""
    agents_by_id = THREADS.get(thread) if thread else None
    return [] if agents_by_id is None else list(agents_by_id)
def _max_frame_idx(thread: str | None) -> int:
    """Return the last valid frame index across all agents in *thread*.

    This is the longest agent's frame count minus one. It is 0 for an
    unknown or blank thread or a thread with no agents, and -1 if every
    agent's frame list is empty.
    """
    if not thread or thread not in THREADS:
        return 0
    longest = max(map(len, THREADS[thread].values()), default=1)
    return longest - 1
def _frame_for(thread: str | None, agent: str | None, idx: int) -> str | None:
    """Return the path (as str) of *agent*'s frame *idx* in *thread*.

    The index is clamped to the agent's available frames. Returns ``None``
    when the thread, agent, or frames are missing.
    """
    if thread and agent and thread in THREADS:
        frames = THREADS[thread].get(agent, [])
        if frames:
            clamped = min(len(frames) - 1, max(0, int(idx)))
            return str(frames[clamped])
    return None
def _frame_url(p: Path | str) -> str:
"""URL for a path served via gr.set_static_paths."""
return "/gradio_api/file=" + str(Path(p).resolve()).replace("\\", "/")
_IMG_STYLE_SINGLE = (
"max-width:100%;max-height:560px;object-fit:contain;display:block;"
"margin:0 auto;image-rendering:pixelated;"
)
_IMG_STYLE_MULTI = (
"width:100%;max-height:520px;object-fit:contain;display:block;"
"image-rendering:pixelated;"
)
def _frame_html_single(thread: str | None, agent: str | None, idx: int) -> str:
    """Render one agent's frame as an <img> tag, or a grey "(no frame)" note."""
    frame_path = _frame_for(thread, agent, idx)
    if frame_path:
        src = _frame_url(frame_path)
        alt = f"agent {agent} frame {int(idx) + 1}"
        return f'<img src="{src}" style="{_IMG_STYLE_SINGLE}" alt="{alt}">'
    return "<div style='padding:1em;color:#888'>(no frame)</div>"
def _frame_html_multi(thread: str | None, idx: int) -> str:
    """Render every agent's frame *idx* side by side in a flex row.

    Each agent's index is clamped to its own frame count, and a small
    caption above each image shows the agent and frame position. Agents with
    no frames are skipped. Grey notes are returned for an unknown thread or
    when no agent has frames.
    """
    if not thread or thread not in THREADS:
        return "<div style='padding:1em;color:#888'>(no thread)</div>"

    def _cell(agent: str, frames: list) -> str:
        pos = max(0, min(int(idx), len(frames) - 1))
        src = _frame_url(frames[pos])
        label = f"agent {agent} \u00b7 frame {pos + 1}/{len(frames)}"
        return (
            "<div style='flex:1 1 0;min-width:0;text-align:center;'>"
            "<div style='font:12px ui-monospace,Menlo,monospace;color:#666;"
            f"padding:2px 0;'>{label}</div>"
            f"<img src='{src}' style='{_IMG_STYLE_MULTI}' alt='agent {agent}'>"
            "</div>"
        )

    cells = [_cell(agent, frames) for agent, frames in THREADS[thread].items() if frames]
    if not cells:
        return "<div style='padding:1em;color:#888'>(no frames)</div>"
    return "<div style='display:flex;flex-direction:row;gap:8px;'>" + "".join(cells) + "</div>"
def _frame_caption(thread: str | None, agent: str | None, idx: int) -> str:
    """Markdown caption "thread · agent · frame i / n" for the single view.

    The displayed position is clamped to the agent's frames. Returns
    ``"_(no frames)_"`` when nothing is available.
    """
    has_source = bool(thread and agent and thread in THREADS)
    frames = THREADS[thread].get(agent, []) if has_source else []
    if not frames:
        return "_(no frames)_"
    shown = max(0, min(int(idx), len(frames) - 1)) + 1
    return f"**{thread}** \u00b7 agent `{agent}` \u00b7 frame **{shown} / {len(frames)}**"
def _frame_html_for_mode(thread: str | None, agent: str | None, idx: int, mode: str) -> str:
    """Dispatch to the multi-agent view for "All agents", else the single view."""
    if mode != "All agents":
        return _frame_html_single(thread, agent, idx)
    return _frame_html_multi(thread, idx)
def on_thread_change(thread: str, agent: str | None, idx: float, mode: str):
    """Re-sync agent dropdown, frame slider and view when the thread changes.

    The current agent is kept if it exists in the new thread, otherwise the
    first agent is selected. The frame index is clamped to the new thread's
    range.
    """
    agent_ids = _thread_agents(thread)
    if agent in agent_ids:
        chosen_agent = agent
    else:
        chosen_agent = agent_ids[0] if agent_ids else None
    last = _max_frame_idx(thread)
    frame = max(0, min(int(idx), last))
    agent_update = gr.update(choices=agent_ids, value=chosen_agent)
    slider_update = gr.update(
        minimum=0, maximum=max(last, 0), value=frame, label=f"Frame (0 .. {last})"
    )
    return (
        agent_update,
        slider_update,
        _frame_html_for_mode(thread, chosen_agent, frame, mode),
        _frame_caption(thread, chosen_agent, frame),
    )
def on_playback_change(thread: str, agent: str, idx: float, mode: str):
    """Re-render the frame view and caption for the current playback controls."""
    frame = int(idx)
    html = _frame_html_for_mode(thread, agent, frame, mode)
    caption = _frame_caption(thread, agent, frame)
    return html, caption
def on_frame_nudge(thread: str, idx: float, delta: int):
    """Step the frame index by *delta*, wrapping at both ends.

    Stepping past the last frame goes back to 0, and stepping below 0 goes
    to the last frame. Threads with at most one frame always give 0.
    """
    last = _max_frame_idx(thread)
    if last <= 0:
        return 0
    target = int(idx) + delta
    if target > last:
        return 0
    if target < 0:
        return last
    return target
def on_tick(thread: str, idx: float, playing: bool):
    """Advance playback by one (wrapping) frame while playing; else no-op update."""
    return on_frame_nudge(thread, idx, +1) if playing else gr.update()
def on_refresh_video():
    """Rescan frame groups and reset the playback controls to the first one.

    Rebinds the module-level THREADS / DEFAULT_THREAD. Returns updates for
    the thread dropdown, agent dropdown, frame slider, frame HTML, and
    caption. The frame HTML is always rendered in single-agent view at
    frame 0.
    """
    global THREADS, DEFAULT_THREAD
    THREADS = discover_threads()
    thread_names = list(THREADS)
    DEFAULT_THREAD = next(iter(thread_names), None)
    agent_ids = _thread_agents(DEFAULT_THREAD)
    first_agent = next(iter(agent_ids), None)
    last = _max_frame_idx(DEFAULT_THREAD)
    slider_update = gr.update(
        minimum=0, maximum=max(last, 0), value=0, label=f"Frame (0 .. {last})"
    )
    return (
        gr.update(choices=thread_names, value=DEFAULT_THREAD),
        gr.update(choices=agent_ids, value=first_agent),
        slider_update,
        _frame_html_single(DEFAULT_THREAD, first_agent, 0),
        _frame_caption(DEFAULT_THREAD, first_agent, 0),
    )
# ---- UI -----------------------------------------------------------
with gr.Blocks(
    title="Labrats — Tiny lab agents in DiscoveryWorld",
) as demo:
    # Layout: three tabs (Run live / Episode replay / Compare two runs),
    # followed by all event wiring. Listeners are registered inside the
    # Blocks context.
    gr.Markdown(
        """
# Labrats — Tiny lab agents in DiscoveryWorld
A small (\u226432B) LLM agent inside the DiscoveryWorld simulator, with a
two-tier memory layer (private episodic + shared lab notebook).
"""
    )
    with gr.Tabs():
        # ---- tab 1: live inference ---------------------------------
        with gr.Tab("Run live"):
            # Backend availability is probed once, while the UI is built.
            _token_ok = has_hf_token()
            _local_ok = has_local_model()
            _default_backend = "local" if _local_ok else "hf"
            _any_backend = _token_ok or _local_ok
            _status_bits = [
                "`HF_TOKEN` detected" if _token_ok else "no `HF_TOKEN`",
                "local model configured" if _local_ok else "no local model",
            ]
            gr.Markdown(
                "Run a **live episode** — spin up N ReAct agents on a fresh "
                "DiscoveryWorld scenario. Choose the **HF** backend (Hugging "
                "Face Inference Providers, needs an `HF_TOKEN`) or the **local** "
                "backend (in-process llama-cpp, needs the `LLAMA_*` env vars). "
                "The run streams into a new trace that appears in the "
                "**Episode replay** tab when it finishes.\n\n"
                + (
                    "**Status:** " + ", ".join(_status_bits) + "."
                    if _any_backend
                    else "**Status:** no backend is configured (no `HF_TOKEN` and "
                    "no local model), so live runs are disabled. Recorded "
                    "episodes still work in the other tabs."
                )
            )
            with gr.Row():
                live_backend = gr.Radio(
                    choices=["local", "hf"],
                    value=_default_backend,
                    label="Backend",
                    scale=2,
                )
                live_scenario = gr.Dropdown(
                    choices=LIVE_SCENARIOS,
                    value="Archaeology Dating",
                    label="Scenario",
                    allow_custom_value=True,
                    scale=3,
                )
                live_difficulty = gr.Dropdown(
                    choices=["Easy", "Normal", "Challenge"],
                    value="Normal",
                    label="Difficulty",
                    scale=2,
                )
                live_seed = gr.Number(value=0, precision=0, label="Seed", scale=1)
            with gr.Row():
                live_agents = gr.Slider(
                    minimum=1, maximum=LiveRunner.MAX_AGENTS_CAP, value=2, step=1,
                    label="Agents", scale=2,
                )
                live_steps = gr.Slider(
                    minimum=1, maximum=LiveRunner.MAX_STEPS_CAP, value=15, step=1,
                    label="Max steps", scale=3,
                )
                live_memory = gr.Checkbox(value=True, label="Memory", scale=1)
                live_dialogue = gr.Checkbox(value=True, label="Dialogue", scale=1)
            with gr.Row():
                live_model = gr.Textbox(
                    value="",
                    label="Model (HF backend)",
                    placeholder="blank = HF_MODEL / server default, e.g. "
                    "unsloth/gemma-3-12b-it",
                    scale=3,
                )
                live_max_tokens = gr.Slider(
                    minimum=256, maximum=4096, value=1024, step=128,
                    label="Max tokens / step",
                    info="Reasoning models need room to think + emit the action.",
                    scale=2,
                )
            # Disabled outright when neither backend is configured.
            live_start_btn = gr.Button(
                "\u25b6 Start live episode", variant="primary", interactive=_any_backend
            )
            live_status_md = gr.Markdown(_live_status_md())
            # Live viewport (left) and out-of-band team chat (right), streamed
            # in as the episode runs. The frame mirrors the playback view on
            # the replay tab; the chat reuses its render helper so the live
            # view matches the recorded view exactly.
            with gr.Row():
                with gr.Column(scale=3):
                    live_frame_caption = gr.Markdown(
                        "_(frames appear here once a run starts)_"
                    )
                    live_frame_view = gr.HTML(
                        "<div style='padding:1em;color:#888'>(no frames yet)</div>"
                    )
                with gr.Column(scale=2), gr.Accordion(
                    "\U0001f4ac Team chat", open=True
                ):
                    gr.Markdown(
                        "Out-of-band peer-to-peer messages between agents, "
                        "newest at the bottom. The latest step is "
                        "highlighted in dark blue."
                    )
                    # gr.HTML does not render Markdown, so use an HTML placeholder.
                    live_chat = gr.HTML("<em>(no chat yet)</em>")
            with gr.Accordion("Memory hits surfaced this step", open=False), gr.Row():
                with gr.Column():
                    gr.Markdown("**Notebook (shared):**")
                    live_nb_hits_md = gr.Markdown("_(none yet)_")
                with gr.Column():
                    gr.Markdown("**Private:**")
                    live_pr_hits_md = gr.Markdown("_(none yet)_")
            with gr.Accordion("Cumulative memory so far", open=True), gr.Row():
                with gr.Column():
                    gr.Markdown("### \U0001f4d3 Lab notebook (shared)")
                    live_notebook_md = gr.Markdown("_(empty)_")
                with gr.Column():
                    gr.Markdown("### \U0001f5d2\ufe0f Private notes")
                    live_private_md = gr.Markdown("_(empty)_")
            # Polls the background runner while an episode is in flight.
            live_timer = gr.Timer(value=1.0, active=False)
        # ---- tab 2: episode replay ---------------------------------
        with gr.Tab("Episode replay"):
            with gr.Row():
                ep_picker = gr.Dropdown(
                    choices=list(EPISODES),
                    value=DEFAULT_PICK,
                    label="Episode",
                    scale=4,
                )
                refresh_btn = gr.Button("\u21bb Refresh", scale=1)
            ep_header = gr.Markdown()
            with gr.Accordion("\U0001f3ac Frame playback", open=True):
                # Initial agent for the default thread (None if it has none).
                _agents0 = _thread_agents(DEFAULT_THREAD)
                _agent0 = _agents0[0] if _agents0 else None
                with gr.Row():
                    thread_pick = gr.Dropdown(
                        choices=list(THREADS),
                        value=DEFAULT_THREAD,
                        label="Thread",
                        scale=3,
                    )
                    agent_pick = gr.Dropdown(
                        choices=_agents0,
                        value=_agent0,
                        label="Agent (single view)",
                        scale=2,
                    )
                    view_mode = gr.Radio(
                        choices=["Single agent", "All agents"],
                        value="Single agent",
                        label="View",
                        scale=2,
                    )
                    refresh_video_btn = gr.Button("\u21bb Refresh", scale=1)
                frame_caption = gr.Markdown(_frame_caption(DEFAULT_THREAD, _agent0, 0))
                _initial_html = _frame_html_single(DEFAULT_THREAD, _agent0, 0)
                frame_view = gr.HTML(value=_initial_html)
                with gr.Row():
                    frame_prev = gr.Button("\u25c0 prev")
                    _max_idx0 = _max_frame_idx(DEFAULT_THREAD)
                    # Max is at least 1 because the slider needs max>min at
                    # construction; the label still shows the real range.
                    frame_slider = gr.Slider(
                        minimum=0,
                        maximum=max(_max_idx0, 1),
                        value=0,
                        step=1,
                        label=f"Frame (0 .. {_max_idx0})",
                        scale=4,
                    )
                    frame_next = gr.Button("next \u25b6")
                    play_btn = gr.Button("\u25b6 play")
                with gr.Row():
                    speed = gr.Slider(
                        minimum=0.1, maximum=2.0, value=0.5, step=0.1,
                        label="Frame interval (seconds)", scale=3,
                    )
                # Whether playback is running; toggled by the play button.
                playing_state = gr.State(False)
                playback_timer = gr.Timer(value=0.5, active=False)
            with gr.Row():
                prev_btn = gr.Button("\u25c0 prev")
                # Slider needs max>min at construction in Gradio 6; the
                # demo.load callback resizes it once an episode is picked.
                step_slider = gr.Slider(minimum=0, maximum=1, value=0, step=1, label="Step", scale=4)
                next_btn = gr.Button("next \u25b6")
            with gr.Row():
                with gr.Column(scale=3):
                    step_header = gr.Markdown()
                    sc_md = gr.Markdown()
                with gr.Column(scale=2):
                    gr.Markdown("### Memory hits surfaced this step")
                    gr.Markdown("**Notebook (shared):**")
                    notebook_hits_md = gr.Markdown()
                    gr.Markdown("**Private:**")
                    private_hits_md = gr.Markdown()
            with gr.Accordion("Cumulative memory so far", open=False), gr.Row():
                with gr.Column():
                    gr.Markdown("### \U0001f4d3 Lab notebook (shared)")
                    notebook_cum_md = gr.Markdown()
                with gr.Column():
                    gr.Markdown("### \U0001f5d2\ufe0f Private notes")
                    private_cum_md = gr.Markdown()
            with gr.Accordion("\U0001f4ac Team chat", open=True):
                # _format_chat highlights the current step with a dark-blue
                # background, so describe it that way (matching the live tab).
                gr.Markdown(
                    "Out-of-band peer-to-peer messages between agents. The "
                    "current step's utterance is highlighted in dark blue."
                )
                chat_filter = gr.CheckboxGroup(
                    choices=[], value=[], label="Show senders",
                )
                # gr.HTML does not render Markdown, so use an HTML placeholder.
                chat_md = gr.HTML("<em>(no episode loaded)</em>")
            with gr.Accordion("Episode summary plots", open=False), gr.Row():
                hist_plot = gr.BarPlot(
                    value=pd.DataFrame(),
                    x="action",
                    y="count",
                    title="Action histogram",
                    height=240,
                )
                score_plot = gr.LinePlot(
                    value=pd.DataFrame(),
                    x="step",
                    y="scoreNormalized",
                    title="Score over time",
                    height=240,
                    y_lim=[0, 1],
                )
        # ---- tab 3: baseline vs stage-2 comparison -----------------
        with gr.Tab("Compare two runs"):
            gr.Markdown(
                "Pick a baseline (no memory) and a stage-2 run (memory on) "
                "of the same scenario + seed to see the behavioural delta."
            )
            with gr.Row():
                left_pick = gr.Dropdown(
                    choices=list(EPISODES),
                    value=next((k for k in EPISODES if "baseline" in k.lower()), DEFAULT_PICK),
                    label="Left episode",
                )
                right_pick = gr.Dropdown(
                    choices=list(EPISODES),
                    value=DEFAULT_PICK,
                    label="Right episode",
                )
            compare_btn = gr.Button("Compare")
            cmp_md = gr.Markdown()
            with gr.Row():
                cmp_left_hist = gr.BarPlot(
                    value=pd.DataFrame(), x="action", y="count",
                    title="Left action histogram", height=260,
                )
                cmp_right_hist = gr.BarPlot(
                    value=pd.DataFrame(), x="action", y="count",
                    title="Right action histogram", height=260,
                )
            with gr.Row():
                cmp_left_score = gr.LinePlot(
                    value=pd.DataFrame(), x="step", y="scoreNormalized",
                    title="Left score over time", height=260, y_lim=[0, 1],
                )
                cmp_right_score = gr.LinePlot(
                    value=pd.DataFrame(), x="step", y="scoreNormalized",
                    title="Right score over time", height=260, y_lim=[0, 1],
                )
    # ---- wiring --------------------------------------------------
    # Output order must match the tuple returned by on_episode_change.
    _episode_outputs = [
        step_slider, step_header, ep_header, sc_md,
        notebook_hits_md, private_hits_md,
        notebook_cum_md, private_cum_md,
        hist_plot, score_plot,
        chat_filter, chat_md,
    ]
    _compare_outputs = [cmp_md, cmp_left_hist, cmp_right_hist, cmp_left_score, cmp_right_score]
    ep_picker.change(
        on_episode_change,
        inputs=[ep_picker],
        outputs=_episode_outputs,
    )
    step_slider.change(
        on_step_change,
        inputs=[ep_picker, step_slider, chat_filter],
        outputs=[step_header, sc_md, notebook_hits_md, private_hits_md,
                 notebook_cum_md, private_cum_md, chat_md],
    )
    chat_filter.change(
        on_chat_filter_change,
        inputs=[ep_picker, step_slider, chat_filter],
        outputs=[chat_md],
    )
    prev_btn.click(
        lambda lbl, s: on_step_nudge(lbl, s, -1),
        inputs=[ep_picker, step_slider],
        outputs=[step_slider],
    )
    next_btn.click(
        lambda lbl, s: on_step_nudge(lbl, s, +1),
        inputs=[ep_picker, step_slider],
        outputs=[step_slider],
    )
    refresh_btn.click(
        on_refresh,
        outputs=[ep_picker, left_pick, right_pick],
    )
    compare_btn.click(
        on_compare,
        inputs=[left_pick, right_pick],
        outputs=_compare_outputs,
    )
    # ---- live inference wiring ----------------------------------
    live_start_btn.click(
        on_live_start,
        inputs=[
            live_scenario, live_difficulty, live_seed,
            live_agents, live_steps, live_memory, live_dialogue, live_backend,
            live_model, live_max_tokens,
        ],
        outputs=[live_status_md, live_timer],
    )
    live_timer.tick(
        on_live_poll,
        outputs=[
            live_status_md, live_timer, ep_picker,
            live_frame_view, live_frame_caption,
            live_chat, live_nb_hits_md, live_pr_hits_md,
            live_notebook_md, live_private_md,
        ],
        show_progress="hidden",
    )
    # ---- frame playback wiring ----------------------------------
    thread_pick.change(
        on_thread_change,
        inputs=[thread_pick, agent_pick, frame_slider, view_mode],
        outputs=[agent_pick, frame_slider, frame_view, frame_caption],
        show_progress="hidden",
    )
    agent_pick.change(
        on_playback_change,
        inputs=[thread_pick, agent_pick, frame_slider, view_mode],
        outputs=[frame_view, frame_caption],
        show_progress="hidden",
    )
    frame_slider.change(
        on_playback_change,
        inputs=[thread_pick, agent_pick, frame_slider, view_mode],
        outputs=[frame_view, frame_caption],
        show_progress="hidden",
    )
    view_mode.change(
        on_playback_change,
        inputs=[thread_pick, agent_pick, frame_slider, view_mode],
        outputs=[frame_view, frame_caption],
        show_progress="hidden",
    )
    frame_prev.click(
        lambda th, idx: on_frame_nudge(th, idx, -1),
        inputs=[thread_pick, frame_slider],
        outputs=[frame_slider],
        show_progress="hidden",
    )
    frame_next.click(
        lambda th, idx: on_frame_nudge(th, idx, +1),
        inputs=[thread_pick, frame_slider],
        outputs=[frame_slider],
        show_progress="hidden",
    )

    def _toggle_play(playing: bool):
        """Flip the play state; relabel the button and (de)activate the timer."""
        new = not bool(playing)
        return (
            new,
            gr.update(value=("\u23f8 pause" if new else "\u25b6 play")),
            gr.Timer(active=new),
        )

    play_btn.click(
        _toggle_play,
        inputs=[playing_state],
        outputs=[playing_state, play_btn, playback_timer],
        show_progress="hidden",
    )
    # Changing the speed retunes the timer interval without changing whether
    # playback is running.
    speed.change(
        lambda s, active: gr.Timer(value=float(s), active=bool(active)),
        inputs=[speed, playing_state],
        outputs=[playback_timer],
        show_progress="hidden",
    )
    playback_timer.tick(
        on_tick,
        inputs=[thread_pick, frame_slider, playing_state],
        outputs=[frame_slider],
        show_progress="hidden",
    )
    refresh_video_btn.click(
        on_refresh_video,
        outputs=[thread_pick, agent_pick, frame_slider,
                 frame_view, frame_caption],
        show_progress="hidden",
    )
    # Initial population so the first-load view isn't blank.
    demo.load(
        on_episode_change,
        inputs=[ep_picker],
        outputs=_episode_outputs,
    )
    demo.load(
        on_compare,
        inputs=[left_pick, right_pick],
        outputs=_compare_outputs,
    )
if __name__ == "__main__":
    # VIDEO_DIR (when present) is also passed to launch() as an allowed
    # path, so Gradio may serve files from it.
    _launch_allowed_paths = [str(VIDEO_DIR.resolve())] if VIDEO_DIR.exists() else None
    demo.launch(
        theme=gr.themes.Soft(),
        allowed_paths=_launch_allowed_paths,
    )