Spaces:
Sleeping
Sleeping
| """Labrats replay UI \u2014 a Gradio Space that walks through a recorded | |
DiscoveryWorld episode and surfaces the memory layer's behaviour.

Three views:

1. Episode replay: step through any run and see the action, success and
   scorecard, the memory hits the agent retrieved on that step, the
   cumulative notebook so far, the team chat, and the rendered frames.
2. Compare two runs: side-by-side baseline-vs-stage2 action histograms,
   score curves and step-count summary, which make the "memory fixes the
   loop" delta visible at a glance.
3. Run live: start a fresh episode on either the Hugging Face Inference
   Providers backend (needs ``HF_TOKEN``) or an in-process llama-cpp model
   (needs the ``LLAMA_*`` env vars). The trace is written into ./runs and
   becomes replayable once the run finishes.

Replay and comparison only read pre-recorded traces from ./runs and run no
model. That keeps those views usable on the free CPU tier with no GPU /
quota dependencies.
"""
| from __future__ import annotations | |
| import json | |
| import sys | |
| from pathlib import Path | |
| import gradio as gr | |
| import pandas as pd | |
# Make `src/` importable when the app is launched from the repo root.
ROOT = Path(__file__).resolve().parent
SRC = ROOT / "src"
# Prepend (rather than append) so the in-repo `labrats` package takes
# priority over any other copy on sys.path; the membership test avoids
# adding a duplicate entry if this code runs more than once.
if SRC.exists() and str(SRC) not in sys.path:
    sys.path.insert(0, str(SRC))
| from labrats.live import ( # noqa: E402 | |
| LiveConfig, | |
| LiveRunner, | |
| backend_available, | |
| has_hf_token, | |
| has_local_model, | |
| ) | |
| from labrats.replay import ( # noqa: E402 | |
| Episode, | |
| MemoryHitSnap, | |
| StepRow, | |
| author_color_map, | |
| discover_episodes, | |
| discover_frame_groups, | |
| ) | |
RUNS_DIR = ROOT / "runs"
VIDEO_DIR = ROOT / "video"
# Live episodes write their traces and frames into RUNS_DIR (see the live
# inference section), but STATIC_FRAME_PATHS below is computed only once, at
# import time. If RUNS_DIR did not exist yet, it would never be registered
# and every live-run frame would be unservable until the Space restarts, so
# create it up front. Best-effort: on a read-only filesystem we simply fall
# back to registering whatever already exists.
try:
    RUNS_DIR.mkdir(parents=True, exist_ok=True)
except OSError:
    pass
# Serve frames directly so swapping <img src=...> doesn't go
# through Gradio's lazy file pipeline (which causes the fade).
STATIC_FRAME_PATHS = [p for p in (RUNS_DIR, VIDEO_DIR) if p.exists()]
if STATIC_FRAME_PATHS:
    gr.set_static_paths(paths=[str(p.resolve()) for p in STATIC_FRAME_PATHS])
def load_all_episodes() -> dict[str, Episode]:
    """Scan RUNS_DIR and index every discovered episode by its label.

    If two episodes share a label, the one discovered later wins.
    """
    index: dict[str, Episode] = {}
    for episode in discover_episodes([RUNS_DIR]):
        index[episode.label] = episode
    return index
# Episodes indexed by label, loaded once at import. The refresh and
# live-poll callbacks rebind this global when new traces appear.
EPISODES = load_all_episodes()
# Initial replay pick: prefer a "stage2" (memory-on) run, otherwise the
# first episode found, or None when no episodes were discovered.
DEFAULT_PICK = (
    next((k for k in EPISODES if "stage2" in k.lower()), None)
    or next(iter(EPISODES), None)
)
| # ---- rendering helpers -------------------------------------------- | |
| def _ep(label: str | None) -> Episode | None: | |
| if not label: | |
| return None | |
| return EPISODES.get(label) | |
| def _format_action(a: dict) -> str: | |
| name = a.get("action", "?") | |
| bits = [f"**{name}**"] | |
| for k in ("arg1", "arg2"): | |
| v = a.get(k) | |
| if v is not None: | |
| bits.append(f"{k}={v}") | |
| return " ".join(bits) | |
| def _format_hits(hits: list[MemoryHitSnap]) -> str: | |
| if not hits: | |
| return "_(none)_" | |
| lines = [] | |
| for h in hits: | |
| lines.append( | |
| f"- **[{h.type} @step {h.timestamp_step}]** " | |
| f"_score={h.score:.2f} rec={h.recency:.2f} " | |
| f"imp={h.importance:.2f} rel={h.relevance:.2f}_ \n" | |
| f" {h.content}" | |
| ) | |
| return "\n".join(lines) | |
| def _format_notebook(items: list[str]) -> str: | |
| if not items: | |
| return "_(empty)_" | |
| return "\n".join(f"- {x}" for x in items) | |
| def _html_escape(s: str) -> str: | |
| return ( | |
| s.replace("&", "&") | |
| .replace("<", "<") | |
| .replace(">", ">") | |
| ) | |
def _format_chat(
    ep: Episode | None,
    *,
    current_step: int,
    visible_senders: list[str] | None = None,
) -> str:
    """Render an episode's chat log as coloured HTML.

    Each line gets a left border in its sender's colour (via
    ``author_color_map``) and a ``tNN sNNN`` tick/step prefix. Lines whose
    step equals ``current_step`` get a yellow background. Only senders in
    ``visible_senders`` are shown; if that is None or empty, every name in
    ``ep.agent_names`` is shown. Lines from senders outside that set are
    always dropped. Markdown placeholder strings are returned when there is
    nothing to show.
    """
    if ep is None or not ep.chat_log:
        # Distinguish "dialogue switched off" from "dialogue on, nobody spoke".
        if ep is not None and not ep.dialogue_enabled:
            return "_Dialogue was disabled for this run._"
        return "_(no chat in this run)_"

    colours = author_color_map(ep.agent_names)
    shown = set(visible_senders or ep.agent_names)
    rendered: list[str] = []
    for msg in ep.chat_log:
        if msg.sender not in shown:
            continue
        colour = colours.get(msg.sender, "#444")
        highlight = "background:#fff8c5;" if msg.step == current_step else ""
        rendered.append(
            f'<div style="padding:2px 6px;border-left:3px solid {colour};'
            f'margin:2px 0;{highlight}font-family:ui-monospace,Menlo,monospace;font-size:12px;">'
            f'<span style="color:#888">t{msg.tick:02d} s{msg.step:03d}</span> '
            f'<b style="color:{colour}">{_html_escape(msg.sender)}</b>'
            f' → <b>{_html_escape(msg.to)}</b>: '
            f'{_html_escape(msg.text)}'
            "</div>"
        )

    if not rendered:
        return "_(no chat matches the current filter)_"
    wrapper_open = (
        '<div style="max-height:360px;overflow-y:auto;border:1px solid #ddd;'
        'border-radius:6px;padding:4px;">'
    )
    return wrapper_open + "".join(rendered) + "</div>"
| def _action_hist_df(ep: Episode) -> pd.DataFrame: | |
| hist = ep.action_histogram() | |
| return pd.DataFrame({"action": list(hist.keys()), "count": list(hist.values())}) | |
| def _score_curve_df(ep: Episode) -> pd.DataFrame: | |
| rows = [] | |
| for r in ep.steps: | |
| sc = r.scorecard or {} | |
| rows.append( | |
| { | |
| "step": r.step, | |
| "scoreNormalized": float(sc.get("scoreNormalized") or 0.0), | |
| } | |
| ) | |
| return pd.DataFrame(rows) if rows else pd.DataFrame({"step": [], "scoreNormalized": []}) | |
| def _episode_header(ep: Episode) -> str: | |
| n_steps = len(ep.steps) | |
| last_sc = ep.steps[-1].scorecard if ep.steps else {} | |
| badges = ( | |
| f"`agents={ep.n_agents}` " | |
| f"`notebook={ep.notebook_mode}` " | |
| f"`dialogue={'on' if ep.dialogue_enabled else 'off'}` " | |
| f"`inbox_k={ep.inbox_k}` " | |
| f"`chat_lines={len(ep.chat_log)}`" | |
| ) | |
| return ( | |
| f"### {ep.path.parent.name}\n" | |
| f"- **scenario:** {ep.scenario} ({ep.difficulty}, seed={ep.seed})\n" | |
| f"- **agents:** {', '.join(ep.agent_names) or '?'}\n" | |
| f"- **steps used:** {n_steps} / {ep.max_steps}\n" | |
| f"- **final:** {ep.score_summary}\n" | |
| f"- **last scorecard:** {json.dumps(last_sc)}\n" | |
| f"- {badges}" | |
| ) | |
def _step_panel(ep: Episode | None, step_idx: int):
    """Build the markdown panels describing one step of ``ep``.

    ``step_idx`` is clamped into ``[0, len(ep.steps) - 1]``.

    Returns a 7-tuple of markdown strings, in this order:
      1. step header: action, success/failure, retries, errors, thought and
         any utterance for this step;
      2. episode header (``_episode_header(ep)``);
      3. scorecard progress line;
      4. notebook (shared) memory hits retrieved on this step;
      5. private memory hits retrieved on this step;
      6. cumulative notebook entries so far;
      7. cumulative private notes so far.
    With no episode, or an episode without steps, all seven are the same
    placeholder.
    """
    if ep is None or not ep.steps:
        empty = "_(no episode loaded)_"
        return empty, empty, empty, empty, empty, empty, empty
    step_idx = max(0, min(step_idx, len(ep.steps) - 1))
    row: StepRow = ep.steps[step_idx]
    status = "✅ success" if row.success else "❌ failure"
    header = (
        f"### Step {row.step} (tick {row.tick}, agent `{row.agent_name}`)\n\n"
        f"{_format_action(row.action)} \u2192 "
        f"{status}"
        # Two trailing spaces before "\n" force a markdown hard line break.
        # With a single space, "retries" would run on after the status.
        + (f"  \nretries: {row.retries}" if row.retries else "")
    )
    if row.errors:
        errors_text = "\n".join(row.errors)
        header += f"\n\n**Errors:**\n```\n{errors_text}\n```"
    if row.thought:
        header += f"\n\n**Thought:** {row.thought}"
    if row.utterance:
        header += (
            f"\n\n**\U0001f4ac Said to `{row.utterance.get('to', '?')}`:** "
            f"{row.utterance.get('text', '')}"
        )
    sc_md = "**Scorecard progress:** " + (json.dumps(row.scorecard) if row.scorecard else "_(none)_")
    # memory_hits maps tier name ("notebook" / "private") -> hits; tolerate a
    # missing mapping instead of crashing the whole panel.
    hits_by_tier = row.memory_hits or {}
    notebook_hits = _format_hits(hits_by_tier.get("notebook", []))
    private_hits = _format_hits(hits_by_tier.get("private", []))
    notebook_cum = _format_notebook(row.notebook_so_far)
    private_cum = _format_notebook(row.private_so_far)
    return header, _episode_header(ep), sc_md, notebook_hits, private_hits, notebook_cum, private_cum
| def _comparison_md(left: Episode | None, right: Episode | None) -> str: | |
| """Side-by-side summary table.""" | |
| if left is None or right is None: | |
| return "_Select two episodes to compare._" | |
| def _row(label: str, l_val, r_val) -> str: | |
| return f"| **{label}** | {l_val} | {r_val} |" | |
| lh, rh = left.action_histogram(), right.action_histogram() | |
| actions = sorted(set(lh) | set(rh)) | |
| head = "| metric | " + left.path.parent.name + " | " + right.path.parent.name + " |" | |
| sep = "| --- | --- | --- |" | |
| rows = [ | |
| head, | |
| sep, | |
| _row("score", left.score_summary, right.score_summary), | |
| _row("steps used", f"{len(left.steps)} / {left.max_steps}", f"{len(right.steps)} / {right.max_steps}"), | |
| _row("ticks_run", left.ticks_run, right.ticks_run), | |
| _row("tasks_complete", left.tasks_complete, right.tasks_complete), | |
| ] | |
| for a in actions: | |
| rows.append(_row(a, lh.get(a, 0), rh.get(a, 0))) | |
| return "\n".join(rows) | |
| # ---- callbacks ---------------------------------------------------- | |
def on_episode_change(label: str):
    """Reset the replay tab for a newly picked episode.

    Returns, in order:
      1. step-slider update;
      2-8. step header, episode header, scorecard line, notebook hits,
           private hits, cumulative notebook, cumulative private notes;
      9-10. action-histogram frame, score-curve frame;
      11. chat-filter update;
      12. chat HTML.
    With no usable episode every text output is a placeholder and the plots
    and sender filter are cleared.
    """
    ep = _ep(label)
    if ep is None or not ep.steps:
        placeholder = "_(no episode)_"
        return (
            gr.update(maximum=0, value=0, label="Step"),
            *([placeholder] * 7),
            pd.DataFrame(),
            pd.DataFrame(),
            gr.update(choices=[], value=[]),
            placeholder,
        )

    last = len(ep.steps) - 1
    step_md, episode_md, score_md, nb_md, priv_md, nb_all_md, priv_all_md = _step_panel(ep, 0)
    # Start with every sender visible and the first step highlighted.
    chat_html = _format_chat(ep, current_step=ep.steps[0].step, visible_senders=list(ep.agent_names))
    slider = gr.update(minimum=0, maximum=last, value=0, step=1, label=f"Step (0 .. {last})")
    return (
        slider,
        step_md,
        episode_md,
        score_md,
        nb_md,
        priv_md,
        nb_all_md,
        priv_all_md,
        _action_hist_df(ep),
        _score_curve_df(ep),
        gr.update(choices=list(ep.agent_names), value=list(ep.agent_names)),
        chat_html,
    )
| def on_step_change(label: str, step_idx: float, senders: list[str] | None): | |
| ep = _ep(label) | |
| header, ep_md, sc_md, nb_hits, pr_hits, nb_cum, pr_cum = _step_panel(ep, int(step_idx)) | |
| current_step = ep.steps[int(step_idx)].step if (ep and ep.steps) else -1 | |
| chat_md = _format_chat(ep, current_step=current_step, visible_senders=senders or None) | |
| return header, sc_md, nb_hits, pr_hits, nb_cum, pr_cum, chat_md | |
| def on_chat_filter_change(label: str, step_idx: float, senders: list[str] | None): | |
| ep = _ep(label) | |
| current_step = ep.steps[int(step_idx)].step if (ep and ep.steps) else -1 | |
| return _format_chat(ep, current_step=current_step, visible_senders=senders or None) | |
def on_step_nudge(label: str, step_idx: float, delta: int):
    """Move the step slider by ``delta`` without wrapping.

    The result is clamped to ``[0, len(steps) - 1]``; it is 0 when no
    episode with steps is selected.
    """
    ep = _ep(label)
    if ep is None or not ep.steps:
        return 0
    upper = len(ep.steps) - 1
    return min(max(int(step_idx) + delta, 0), upper)
def on_compare(left_label: str, right_label: str):
    """Populate the compare tab.

    Returns the summary table, then left/right action histograms, then
    left/right score curves. Each label is resolved exactly once, so all
    five outputs describe the same episode objects even if EPISODES is
    rebound by a concurrent refresh. Missing episodes yield empty frames.
    """
    left, right = _ep(left_label), _ep(right_label)

    def _frame_or_empty(build, ep):
        # Plots get an empty frame rather than None when a side is unset.
        return build(ep) if ep else pd.DataFrame()

    return (
        _comparison_md(left, right),
        _frame_or_empty(_action_hist_df, left),
        _frame_or_empty(_action_hist_df, right),
        _frame_or_empty(_score_curve_df, left),
        _frame_or_empty(_score_curve_df, right),
    )
def on_refresh():
    """Reload EPISODES from disk and repopulate the three episode dropdowns.

    - The replay picker and the right-hand compare picker default to the
      first label containing "stage2" (else the first label, else None).
    - The left-hand compare picker prefers a "baseline" label, falling back
      to that same default.
    """
    global EPISODES
    EPISODES = load_all_episodes()
    labels = list(EPISODES)

    def first_containing(tag: str):
        return next((lbl for lbl in labels if tag in lbl.lower()), None)

    fallback = first_containing("stage2") or (labels[0] if labels else None)
    baseline = first_containing("baseline") or fallback
    return (
        gr.update(choices=labels, value=fallback),
        gr.update(choices=labels, value=baseline),
        gr.update(choices=labels, value=fallback),
    )
| # ---- live inference ---------------------------------------------- | |
# A single background runner shared by the Space. Live episodes are written
# into RUNS_DIR so they become replayable like any recorded run. Only one
# episode runs at a time (the shared HF token pays for every visitor).
# It is created at import time, so every browser session drives and polls
# this same runner (its `.status` is what the Run-live tab renders).
LIVE_RUNNER = LiveRunner(RUNS_DIR)
# Curated scenarios for the dropdown. "Archaeology Dating" is the validated
# multi-agent map; the Small Skills tests are cheap single-agent smoke runs.
# The dropdown allows custom values so any DiscoveryWorld scenario name works.
LIVE_SCENARIOS = [
    "Archaeology Dating",
    "Proteomics",
    "Plant Nutrients",
    "Reactor Lab",
    "Space Sick",
    "Small Skills: Instrument Measurement Test",
    "Small Skills: Dialog Test",
]
def _live_status_md() -> str:
    """Summarise LIVE_RUNNER's current status as markdown.

    Idle gets a one-line hint. Any other state gets:
      - an icon, the state and the runner's message;
      - the run label, when known;
      - the last 1500 characters of any error text, in a code fence.
    """
    status = LIVE_RUNNER.status
    if status.state == "idle":
        return "_Idle — configure a run and press **Start live episode**._"
    icons = {"running": "\u23f3", "done": "\u2705", "error": "\u274c"}
    icon = icons.get(status.state, "\u2022")
    parts = [f"{icon} **{status.state}** — {status.message}"]
    if status.run_label:
        parts.append(f"\nRun: `{status.run_label}`")
    if status.error:
        # Only the tail: tracebacks can be long and the end is what matters.
        parts.append("\n```\n" + status.error[-1500:] + "\n```")
    return "\n".join(parts)
def on_live_start(scenario, difficulty, seed, n_agents, max_steps, memory, dialogue, backend):
    """Handle the **Start live episode** button.

    Checks that the chosen backend is usable, then starts a background
    episode on the shared LIVE_RUNNER.

    Returns:
        ``(status_markdown, timer_update)``. The poll timer is activated
        only when an episode was started or one is already running.
    """
    backend = str(backend)
    if not backend_available(backend):
        if backend == "local":
            msg = (
                "\u274c No local model configured. Set `LLAMA_MODEL_PATH` (or "
                "`LLAMA_REPO_ID` + `LLAMA_FILENAME`) and restart, then try again."
            )
        else:
            msg = (
                "\u274c No `HF_TOKEN` is configured, so the HF backend is "
                "unavailable. Use the **Episode replay** tab to view recorded runs."
            )
        return msg, gr.update(active=False)
    # The runner is shared and only one episode should run at a time. A
    # second click while a run is in flight just resumes polling that run
    # instead of starting another one on top of it.
    if LIVE_RUNNER.status.state == "running":
        return _live_status_md(), gr.update(active=True)
    cfg = LiveConfig(
        scenario=str(scenario),
        difficulty=str(difficulty),
        # gr.Number hands over None when the field is cleared. Fall back to
        # the form's initial seed (0) instead of crashing on int(None).
        seed=0 if seed is None else int(seed),
        n_agents=int(n_agents),
        max_steps=int(max_steps),
        memory=bool(memory),
        dialogue=bool(dialogue),
        backend=backend,
    )
    LIVE_RUNNER.start(cfg)
    return _live_status_md(), gr.update(active=True)
| def _live_thread_for(run_label: str, thread_keys: list[str]) -> str | None: | |
| """Find the frame-group thread belonging to the current live run.""" | |
| if run_label: | |
| match = next((t for t in thread_keys if run_label in t), None) | |
| if match is not None: | |
| return match | |
| return thread_keys[0] if thread_keys else None | |
def on_live_poll():
    """Timer callback that polls the background live runner.

    On every tick, the frame directories are re-scanned and the newest
    all-agents frame of the live run is shown in the Run-live viewport.
    Once the run is done (or errored), EPISODES is reloaded, the replay
    dropdown is pointed at the finished run, and the timer is stopped.

    Returns:
      - status markdown;
      - timer update;
      - replay-dropdown update;
      - viewport HTML;
      - viewport caption.

    Only the dedicated live viewport (an HTML image + caption) is driven
    here, never the replay tab's frame slider. Feeding a slider a shrinking
    maximum while it holds a larger cached value makes Gradio raise "Value N
    is greater than maximum value M" on the next prepress.
    """
    global EPISODES, THREADS, DEFAULT_THREAD
    status = LIVE_RUNNER.status
    is_finished = status.state in ("done", "error")

    # Pick up any frames written so far (and, once done, the full trace).
    THREADS = discover_threads()
    thread = _live_thread_for(status.run_label, list(THREADS))
    DEFAULT_THREAD = thread
    agents = _thread_agents(thread)
    newest = _max_frame_idx(thread)

    # Always the all-agents view, auto-advanced to the newest frame.
    viewport_html = _frame_html_multi(thread, newest)
    if thread and newest >= 0 and agents:
        caption = f"**{thread}** \u00b7 frame **{newest + 1}** \u00b7 {len(agents)} agent(s)"
    else:
        caption = "_(waiting for first frame…)_"

    if not is_finished:
        # Still running: keep ticking and leave the replay dropdown alone.
        return _live_status_md(), gr.update(), gr.update(), viewport_html, caption

    EPISODES = load_all_episodes()
    labels = list(EPISODES)
    finished_label = next(
        (k for k in labels if status.run_label and status.run_label in k), None
    )
    pick = finished_label or (labels[0] if labels else None)
    return (
        _live_status_md(),
        gr.Timer(active=False),
        gr.update(choices=labels, value=pick),
        viewport_html,
        caption,
    )
| # ---- frame playback ---------------------------------------------- | |
def discover_threads() -> dict[str, dict[str, list[Path]]]:
    """Scan RUNS_DIR (plus the legacy VIDEO_DIR) for rendered frames.

    Returns ``{frame_group: {agent_id: [frame paths sorted by frame index]}}``,
    as built by ``discover_frame_groups``.
    """
    groups = discover_frame_groups(RUNS_DIR, legacy_video_dir=VIDEO_DIR)
    return groups
# Frame groups discovered at import time. on_refresh_video and every
# live-poll tick rebind this global so newly written frames show up.
THREADS = discover_threads()
# First frame group (None when there are no frames): the initial thread pick.
DEFAULT_THREAD = next(iter(THREADS), None)
| def _thread_agents(thread: str | None) -> list[str]: | |
| if not thread or thread not in THREADS: | |
| return [] | |
| return list(THREADS[thread].keys()) | |
| def _max_frame_idx(thread: str | None) -> int: | |
| if not thread or thread not in THREADS: | |
| return 0 | |
| return max((len(v) for v in THREADS[thread].values()), default=1) - 1 | |
| def _frame_for(thread: str | None, agent: str | None, idx: int) -> str | None: | |
| if not thread or not agent or thread not in THREADS: | |
| return None | |
| frames = THREADS[thread].get(agent, []) | |
| if not frames: | |
| return None | |
| i = max(0, min(int(idx), len(frames) - 1)) | |
| return str(frames[i]) | |
| def _frame_url(p: Path | str) -> str: | |
| """URL for a path served via gr.set_static_paths.""" | |
| return "/gradio_api/file=" + str(Path(p).resolve()).replace("\\", "/") | |
# Inline <img> styles. `image-rendering:pixelated` keeps low-resolution
# frames crisp when they are scaled up, instead of being smoothed.
# Single-agent view: centred, capped at 560px tall.
_IMG_STYLE_SINGLE = (
    "max-width:100%;max-height:560px;object-fit:contain;display:block;"
    "margin:0 auto;image-rendering:pixelated;"
)
# All-agents view: each image fills its flex cell, capped at 520px tall.
_IMG_STYLE_MULTI = (
    "width:100%;max-height:520px;object-fit:contain;display:block;"
    "image-rendering:pixelated;"
)
def _frame_html_single(thread: str | None, agent: str | None, idx: int) -> str:
    """<img> tag for one agent's frame, or a grey "(no frame)" placeholder."""
    frame_path = _frame_for(thread, agent, idx)
    if frame_path:
        src = _frame_url(frame_path)
        return (
            f'<img src="{src}" style="{_IMG_STYLE_SINGLE}" '
            f'alt="agent {agent} frame {int(idx) + 1}">'
        )
    return "<div style='padding:1em;color:#888'>(no frame)</div>"
def _frame_html_multi(thread: str | None, idx: int) -> str:
    """Side-by-side flex row with every agent's frame ``idx`` in ``thread``.

    ``idx`` is clamped per agent, so agents with fewer frames show their
    last one. Agents without frames are skipped. Grey placeholders are
    returned for an unknown thread or when no agent has frames.
    """
    if not thread or thread not in THREADS:
        return "<div style='padding:1em;color:#888'>(no thread)</div>"
    panels: list[str] = []
    for agent, frames in THREADS[thread].items():
        if not frames:
            continue
        shown = max(0, min(int(idx), len(frames) - 1))
        src = _frame_url(frames[shown])
        caption = (
            f"<div style='font:12px ui-monospace,Menlo,monospace;color:#666;"
            f"padding:2px 0;'>agent {agent} \u00b7 frame {shown + 1}/{len(frames)}</div>"
        )
        image = f"<img src='{src}' style='{_IMG_STYLE_MULTI}' alt='agent {agent}'>"
        panels.append(
            "<div style='flex:1 1 0;min-width:0;text-align:center;'>"
            + caption
            + image
            + "</div>"
        )
    if panels:
        return (
            "<div style='display:flex;flex-direction:row;gap:8px;'>"
            + "".join(panels)
            + "</div>"
        )
    return "<div style='padding:1em;color:#888'>(no frames)</div>"
| def _frame_caption(thread: str | None, agent: str | None, idx: int) -> str: | |
| if not thread or not agent or thread not in THREADS: | |
| return "_(no frames)_" | |
| frames = THREADS[thread].get(agent, []) | |
| if not frames: | |
| return "_(no frames)_" | |
| i = max(0, min(int(idx), len(frames) - 1)) | |
| return f"**{thread}** \u00b7 agent `{agent}` \u00b7 frame **{i + 1} / {len(frames)}**" | |
def _frame_html_for_mode(thread: str | None, agent: str | None, idx: int, mode: str) -> str:
    """Dispatch on the view radio: "All agents" renders every agent, anything else only ``agent``."""
    show_everyone = mode == "All agents"
    if show_everyone:
        return _frame_html_multi(thread, idx)
    return _frame_html_single(thread, agent, idx)
def on_thread_change(thread: str, agent: str | None, idx: float, mode: str):
    """Re-sync the playback controls after a different frame thread is picked.

    Behaviour:
      - keeps the current agent if the new thread has it, else picks the
        thread's first agent;
      - clamps the frame index into the new thread's range;
      - re-renders the view.
    Returns updates for: agent dropdown, frame slider, frame HTML, caption.
    """
    available = _thread_agents(thread)
    if agent in available:
        chosen = agent
    else:
        chosen = available[0] if available else None
    last = _max_frame_idx(thread)
    frame = max(0, min(int(idx), last))
    slider = gr.update(
        minimum=0,
        maximum=max(last, 0),
        value=frame,
        label=f"Frame (0 .. {last})",
    )
    return (
        gr.update(choices=available, value=chosen),
        slider,
        _frame_html_for_mode(thread, chosen, frame, mode),
        _frame_caption(thread, chosen, frame),
    )
def on_playback_change(thread: str, agent: str, idx: float, mode: str):
    """Re-render the frame view and caption after an agent, frame or view-mode change."""
    frame = int(idx)
    view_html = _frame_html_for_mode(thread, agent, frame, mode)
    return view_html, _frame_caption(thread, agent, frame)
def on_frame_nudge(thread: str, idx: float, delta: int):
    """Step the frame slider by ``delta``, wrapping around at both ends.

    Returns 0 when the thread has at most one frame.
    """
    last = _max_frame_idx(thread)
    if last <= 0:
        return 0
    target = int(idx) + delta
    if target > last:
        return 0
    if target < 0:
        return last
    return target
def on_tick(thread: str, idx: float, playing: bool):
    """Playback-timer handler: advance one frame (wrapping) while playing, otherwise a no-op update."""
    return on_frame_nudge(thread, idx, +1) if playing else gr.update()
def on_refresh_video():
    """Re-scan frame groups from disk and reset playback to the first thread.

    Rebinds THREADS and DEFAULT_THREAD, then returns updates for:
      - thread dropdown;
      - agent dropdown;
      - frame slider (reset to 0);
      - frame HTML (single-agent view of the first agent);
      - caption.
    """
    global THREADS, DEFAULT_THREAD
    THREADS = discover_threads()
    thread_names = list(THREADS)
    DEFAULT_THREAD = thread_names[0] if thread_names else None
    agent_ids = _thread_agents(DEFAULT_THREAD)
    last_idx = _max_frame_idx(DEFAULT_THREAD)
    lead_agent = agent_ids[0] if agent_ids else None
    slider = gr.update(
        minimum=0,
        maximum=max(last_idx, 0),
        value=0,
        label=f"Frame (0 .. {last_idx})",
    )
    return (
        gr.update(choices=thread_names, value=DEFAULT_THREAD),
        gr.update(choices=agent_ids, value=lead_agent),
        slider,
        _frame_html_single(DEFAULT_THREAD, lead_agent, 0),
        _frame_caption(DEFAULT_THREAD, lead_agent, 0),
    )
| # ---- UI ----------------------------------------------------------- | |
# Build the whole app inside one Blocks context, in two parts:
#   - three tabs: episode replay (with frame playback), run comparison,
#     live runs;
#   - the event wiring and the initial-load hooks.
with gr.Blocks(
    title="Labrats — Tiny lab agents in DiscoveryWorld",
) as demo:
    # NOTE(review): this intro predates the "Run live" tab (which does run a
    # model); consider updating the copy.
    gr.Markdown(
        """
        # Labrats — Tiny lab agents in DiscoveryWorld
        A small (\u226432B) LLM agent inside the DiscoveryWorld simulator, with a
        two-tier memory layer (private episodic + shared lab notebook).
        This Space replays recorded episodes. **Phase B** is a vanilla ReAct
        loop; **Phase C** adds the memory layer. The comparison tab shows the
        behavioural delta from adding memory on the same scenario and seed.
        """
    )
    with gr.Tabs():
        # ---- Tab 1: step-by-step replay of a single recorded episode ----
        with gr.Tab("Episode replay"):
            with gr.Row():
                ep_picker = gr.Dropdown(
                    choices=list(EPISODES),
                    value=DEFAULT_PICK,
                    label="Episode",
                    scale=4,
                )
                refresh_btn = gr.Button("\u21bb Refresh", scale=1)
            ep_header = gr.Markdown()
            # Rendered frames, browsed by thread (frame group) and agent.
            with gr.Accordion("\U0001f3ac Frame playback", open=True):
                with gr.Row():
                    thread_pick = gr.Dropdown(
                        choices=list(THREADS),
                        value=DEFAULT_THREAD,
                        label="Thread",
                        scale=3,
                    )
                    agent_pick = gr.Dropdown(
                        choices=_thread_agents(DEFAULT_THREAD),
                        value=(_thread_agents(DEFAULT_THREAD)[0] if _thread_agents(DEFAULT_THREAD) else None),
                        label="Agent (single view)",
                        scale=2,
                    )
                    view_mode = gr.Radio(
                        choices=["Single agent", "All agents"],
                        value="Single agent",
                        label="View",
                        scale=2,
                    )
                    refresh_video_btn = gr.Button("\u21bb Refresh", scale=1)
                # Initial caption and frame show the default thread's first
                # agent at frame 0, in the default single-agent view.
                frame_caption = gr.Markdown(
                    _frame_caption(
                        DEFAULT_THREAD,
                        (_thread_agents(DEFAULT_THREAD)[0] if _thread_agents(DEFAULT_THREAD) else None),
                        0,
                    )
                )
                _initial_html = _frame_html_single(
                    DEFAULT_THREAD,
                    (_thread_agents(DEFAULT_THREAD)[0] if _thread_agents(DEFAULT_THREAD) else None),
                    0,
                )
                frame_view = gr.HTML(value=_initial_html)
                with gr.Row():
                    frame_prev = gr.Button("\u25c0 prev")
                    _max_idx0 = _max_frame_idx(DEFAULT_THREAD)
                    # maximum is floored at 1 so the slider has max > min even
                    # for single-frame threads; the label shows the real range.
                    frame_slider = gr.Slider(
                        minimum=0,
                        maximum=max(_max_idx0, 1),
                        value=0,
                        step=1,
                        label=f"Frame (0 .. {_max_idx0})",
                        scale=4,
                    )
                    frame_next = gr.Button("next \u25b6")
                    play_btn = gr.Button("\u25b6 play")
                with gr.Row():
                    speed = gr.Slider(
                        minimum=0.1, maximum=2.0, value=0.5, step=0.1,
                        label="Frame interval (seconds)", scale=3,
                    )
                # Play/pause flag, plus the timer that advances frames; the
                # timer is only activated while playing.
                playing_state = gr.State(False)
                playback_timer = gr.Timer(value=0.5, active=False)
            with gr.Row():
                prev_btn = gr.Button("\u25c0 prev")
                # Slider needs max>min at construction in Gradio 6; the
                # demo.load callback resizes it once an episode is picked.
                step_slider = gr.Slider(minimum=0, maximum=1, value=0, step=1, label="Step", scale=4)
                next_btn = gr.Button("next \u25b6")
            with gr.Row():
                with gr.Column(scale=3):
                    step_header = gr.Markdown()
                    sc_md = gr.Markdown()
                with gr.Column(scale=2):
                    gr.Markdown("### Memory hits surfaced this step")
                    gr.Markdown("**Notebook (shared):**")
                    notebook_hits_md = gr.Markdown()
                    gr.Markdown("**Private:**")
                    private_hits_md = gr.Markdown()
            with gr.Accordion("Cumulative memory so far", open=False), gr.Row():
                with gr.Column():
                    gr.Markdown("### \U0001f4d3 Lab notebook (shared)")
                    notebook_cum_md = gr.Markdown()
                with gr.Column():
                    gr.Markdown("### \U0001f5d2\ufe0f Private notes")
                    private_cum_md = gr.Markdown()
            with gr.Accordion("\U0001f4ac Team chat", open=True):
                gr.Markdown(
                    "Out-of-band peer-to-peer messages between agents. The "
                    "current step's utterance is highlighted in yellow."
                )
                # Choices and values are filled per episode by on_episode_change.
                chat_filter = gr.CheckboxGroup(
                    choices=[], value=[], label="Show senders",
                )
                # NOTE(review): this is an HTML component, so markdown-style
                # placeholders like this one presumably render with literal
                # underscores; confirm and switch to plain text/HTML if so.
                chat_md = gr.HTML("_(no episode loaded)_")
            with gr.Accordion("Episode summary plots", open=False), gr.Row():
                hist_plot = gr.BarPlot(
                    value=pd.DataFrame(),
                    x="action",
                    y="count",
                    title="Action histogram",
                    height=240,
                )
                score_plot = gr.LinePlot(
                    value=pd.DataFrame(),
                    x="step",
                    y="scoreNormalized",
                    title="Score over time",
                    height=240,
                    y_lim=[0, 1],
                )
        # ---- Tab 2: two runs side by side (e.g. baseline vs. memory) ----
        with gr.Tab("Compare two runs"):
            gr.Markdown(
                "Pick a baseline (no memory) and a stage-2 run (memory on) "
                "of the same scenario + seed to see the behavioural delta."
            )
            with gr.Row():
                # Default the left side to a "baseline" run when one exists.
                left_pick = gr.Dropdown(
                    choices=list(EPISODES),
                    value=next((k for k in EPISODES if "baseline" in k.lower()), DEFAULT_PICK),
                    label="Left episode",
                )
                right_pick = gr.Dropdown(
                    choices=list(EPISODES),
                    value=DEFAULT_PICK,
                    label="Right episode",
                )
            compare_btn = gr.Button("Compare")
            cmp_md = gr.Markdown()
            with gr.Row():
                cmp_left_hist = gr.BarPlot(
                    value=pd.DataFrame(), x="action", y="count",
                    title="Left action histogram", height=260,
                )
                cmp_right_hist = gr.BarPlot(
                    value=pd.DataFrame(), x="action", y="count",
                    title="Right action histogram", height=260,
                )
            with gr.Row():
                cmp_left_score = gr.LinePlot(
                    value=pd.DataFrame(), x="step", y="scoreNormalized",
                    title="Left score over time", height=260, y_lim=[0, 1],
                )
                cmp_right_score = gr.LinePlot(
                    value=pd.DataFrame(), x="step", y="scoreNormalized",
                    title="Right score over time", height=260, y_lim=[0, 1],
                )
        # ---- Tab 3: start a live episode on the shared background runner ----
        with gr.Tab("Run live"):
            # Backend availability is probed once, when the UI is built.
            _token_ok = has_hf_token()
            _local_ok = has_local_model()
            _default_backend = "local" if _local_ok else "hf"
            _any_backend = _token_ok or _local_ok
            _status_bits = []
            _status_bits.append(
                "`HF_TOKEN` detected" if _token_ok else "no `HF_TOKEN`"
            )
            _status_bits.append(
                "local model configured" if _local_ok else "no local model"
            )
            gr.Markdown(
                "Run a **live episode** — spin up N ReAct agents on a fresh "
                "DiscoveryWorld scenario. Choose the **HF** backend (Hugging "
                "Face Inference Providers, needs an `HF_TOKEN`) or the **local** "
                "backend (in-process llama-cpp, needs the `LLAMA_*` env vars). "
                "The run streams into a new trace that appears in the "
                "**Episode replay** tab when it finishes.\n\n"
                + (
                    "**Status:** " + ", ".join(_status_bits) + "."
                    if _any_backend
                    else "**Status:** no backend is configured (no `HF_TOKEN` and "
                    "no local model), so live runs are disabled. Recorded "
                    "episodes still work in the other tabs."
                )
            )
            with gr.Row():
                live_backend = gr.Radio(
                    choices=["local", "hf"],
                    value=_default_backend,
                    label="Backend",
                    scale=2,
                )
                live_scenario = gr.Dropdown(
                    choices=LIVE_SCENARIOS,
                    value="Archaeology Dating",
                    label="Scenario",
                    allow_custom_value=True,
                    scale=3,
                )
                live_difficulty = gr.Dropdown(
                    choices=["Easy", "Normal", "Challenge"],
                    value="Normal",
                    label="Difficulty",
                    scale=2,
                )
                live_seed = gr.Number(value=0, precision=0, label="Seed", scale=1)
            with gr.Row():
                # Slider upper bounds come from the runner's own caps.
                live_agents = gr.Slider(
                    minimum=1, maximum=LiveRunner.MAX_AGENTS_CAP, value=2, step=1,
                    label="Agents", scale=2,
                )
                live_steps = gr.Slider(
                    minimum=1, maximum=LiveRunner.MAX_STEPS_CAP, value=15, step=1,
                    label="Max steps", scale=3,
                )
                live_memory = gr.Checkbox(value=True, label="Memory", scale=1)
                live_dialogue = gr.Checkbox(value=True, label="Dialogue", scale=1)
            # Not clickable at all when neither backend is configured.
            live_start_btn = gr.Button(
                "\u25b6 Start live episode", variant="primary", interactive=_any_backend
            )
            live_status_md = gr.Markdown(_live_status_md())
            # Live viewport: the latest rendered frame, streamed in as the
            # episode runs. Mirrors the playback view on the replay tab.
            live_frame_caption = gr.Markdown("_(frames appear here once a run starts)_")
            live_frame_view = gr.HTML(
                "<div style='padding:1em;color:#888'>(no frames yet)</div>"
            )
            # Polls the background runner while an episode is in flight.
            live_timer = gr.Timer(value=1.0, active=False)
    # ---- wiring --------------------------------------------------
    # Picking an episode resets the step slider and every per-step panel.
    ep_picker.change(
        on_episode_change,
        inputs=[ep_picker],
        outputs=[
            step_slider, step_header, ep_header, sc_md,
            notebook_hits_md, private_hits_md,
            notebook_cum_md, private_cum_md,
            hist_plot, score_plot,
            chat_filter, chat_md,
        ],
    )
    step_slider.change(
        on_step_change,
        inputs=[ep_picker, step_slider, chat_filter],
        outputs=[step_header, sc_md, notebook_hits_md, private_hits_md,
                 notebook_cum_md, private_cum_md, chat_md],
    )
    chat_filter.change(
        on_chat_filter_change,
        inputs=[ep_picker, step_slider, chat_filter],
        outputs=[chat_md],
    )
    # prev/next only move the slider; step_slider.change does the re-render.
    prev_btn.click(
        lambda lbl, s: on_step_nudge(lbl, s, -1),
        inputs=[ep_picker, step_slider],
        outputs=[step_slider],
    )
    next_btn.click(
        lambda lbl, s: on_step_nudge(lbl, s, +1),
        inputs=[ep_picker, step_slider],
        outputs=[step_slider],
    )
    refresh_btn.click(
        on_refresh,
        outputs=[ep_picker, left_pick, right_pick],
    )
    compare_btn.click(
        on_compare,
        inputs=[left_pick, right_pick],
        outputs=[cmp_md, cmp_left_hist, cmp_right_hist, cmp_left_score, cmp_right_score],
    )
    # ---- live inference wiring ----------------------------------
    live_start_btn.click(
        on_live_start,
        inputs=[
            live_scenario, live_difficulty, live_seed,
            live_agents, live_steps, live_memory, live_dialogue, live_backend,
        ],
        outputs=[live_status_md, live_timer],
    )
    # on_live_poll switches this timer off once the run finishes.
    # NOTE(review): only ep_picker is refreshed when a live run completes;
    # the compare tab's left/right pickers keep stale choices until the
    # replay tab's Refresh button is used.
    live_timer.tick(
        on_live_poll,
        outputs=[
            live_status_md, live_timer, ep_picker,
            live_frame_view, live_frame_caption,
        ],
        show_progress="hidden",
    )
    # ---- frame playback wiring ----------------------------------
    thread_pick.change(
        on_thread_change,
        inputs=[thread_pick, agent_pick, frame_slider, view_mode],
        outputs=[agent_pick, frame_slider, frame_view, frame_caption],
        show_progress="hidden",
    )
    agent_pick.change(
        on_playback_change,
        inputs=[thread_pick, agent_pick, frame_slider, view_mode],
        outputs=[frame_view, frame_caption],
        show_progress="hidden",
    )
    frame_slider.change(
        on_playback_change,
        inputs=[thread_pick, agent_pick, frame_slider, view_mode],
        outputs=[frame_view, frame_caption],
        show_progress="hidden",
    )
    view_mode.change(
        on_playback_change,
        inputs=[thread_pick, agent_pick, frame_slider, view_mode],
        outputs=[frame_view, frame_caption],
        show_progress="hidden",
    )
    frame_prev.click(
        lambda th, idx: on_frame_nudge(th, idx, -1),
        inputs=[thread_pick, frame_slider],
        outputs=[frame_slider],
        show_progress="hidden",
    )
    frame_next.click(
        lambda th, idx: on_frame_nudge(th, idx, +1),
        inputs=[thread_pick, frame_slider],
        outputs=[frame_slider],
        show_progress="hidden",
    )

    def _toggle_play(playing: bool):
        # Flip play/pause: returns the new flag, the relabelled button and
        # the playback timer's new active state.
        new = not bool(playing)
        return (
            new,
            gr.update(value=("\u23f8 pause" if new else "\u25b6 play")),
            gr.Timer(active=new),
        )

    play_btn.click(
        _toggle_play,
        inputs=[playing_state],
        outputs=[playing_state, play_btn, playback_timer],
        show_progress="hidden",
    )
    # Changing the speed re-creates the timer with the new interval while
    # preserving the current play/pause state.
    speed.change(
        lambda s, active: gr.Timer(value=float(s), active=bool(active)),
        inputs=[speed, playing_state],
        outputs=[playback_timer],
        show_progress="hidden",
    )
    playback_timer.tick(
        on_tick,
        inputs=[thread_pick, frame_slider, playing_state],
        outputs=[frame_slider],
        show_progress="hidden",
    )
    refresh_video_btn.click(
        on_refresh_video,
        outputs=[thread_pick, agent_pick, frame_slider,
                 frame_view, frame_caption],
        show_progress="hidden",
    )
    # Initial population so the first-load view isn't blank.
    demo.load(
        on_episode_change,
        inputs=[ep_picker],
        outputs=[
            step_slider, step_header, ep_header, sc_md,
            notebook_hits_md, private_hits_md,
            notebook_cum_md, private_cum_md,
            hist_plot, score_plot,
            chat_filter, chat_md,
        ],
    )
    demo.load(
        on_compare,
        inputs=[left_pick, right_pick],
        outputs=[cmp_md, cmp_left_hist, cmp_right_hist, cmp_left_score, cmp_right_score],
    )
if __name__ == "__main__":
    # NOTE(review): passing `theme` to launch() assumes a Gradio version whose
    # launch() accepts it (Gradio 6+) — confirm against the pinned version.
    # Frames under RUNS_DIR / VIDEO_DIR are already exposed via
    # gr.set_static_paths at import time; allowed_paths additionally
    # whitelists VIDEO_DIR when that directory exists.
    demo.launch(
        theme=gr.themes.Soft(),
        allowed_paths=[str(VIDEO_DIR.resolve())] if VIDEO_DIR.exists() else None,
    )