Spaces:
Runtime error
Runtime error
"""
SpindleFlow RL — Streamlit Dashboard
=====================================
Run: cd spindleflow-rl && streamlit run demo/streamlit_app.py
URL: http://localhost:8501
"""
| from __future__ import annotations | |
| import os, sys, json, html as _html | |
| from pathlib import Path | |
| import numpy as np | |
| from dotenv import load_dotenv | |
| load_dotenv() # load OPENAI_API_KEY (and any other vars) from .env | |
# HF_HUB_OFFLINE is intentionally NOT set — manual HF Hub downloads must work
| sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) | |
| sys.path.insert(0, str(Path(__file__).resolve().parent)) | |
| import streamlit as st | |
| import plotly.graph_objects as go | |
| from plotly.subplots import make_subplots | |
| from env.spindleflow_env import SpindleFlowEnv | |
| from env.state import EpisodeState | |
| from env.specialist_registry import SpecialistRegistry | |
| from orchestrator_widget import render_orchestrator | |
# ─────────────────────────────────────────────────────────
# Page config (must be first Streamlit call)
# ─────────────────────────────────────────────────────────
# Configure the browser tab (title, icon) and start with a wide layout and a
# collapsed sidebar.
# NOTE(review): the page_icon value looks mis-encoded (mojibake) — confirm the
# intended emoji against the original file.
st.set_page_config(
    page_title="SpindleFlow RL",
    page_icon="β‘",
    layout="wide",
    initial_sidebar_state="collapsed",
)
# ─────────────────────────────────────────────────────────
# Constants
# ─────────────────────────────────────────────────────────
# Relative paths to the training config, the specialist catalog and the
# directory holding pre-computed demo assets.
CONFIG, CATALOG = "configs/training_config.yaml", "configs/specialist_catalog.yaml"
ASSETS = Path("demo") / "assets"

# Hex colour used for each known specialist id in the charts.
SPEC_COLORS = dict(
    frontend_react="#00d4ff",
    backend_api="#7c3aed",
    database_architect="#f59e0b",
    devops_engineer="#10b981",
    security_analyst="#ef4444",
    product_strategist="#8b5cf6",
    ux_designer="#ec4899",
    tech_writer="#94a3b8",
)
def _get_preset_tasks(n: int = 8) -> list[str]:
    """Return ``n`` tasks sampled from a phase-1 ``TaskBank``.

    If the task bank cannot be imported, constructed or sampled, a single
    generic placeholder task is returned instead.
    """
    placeholder = ["Describe a software engineering task requiring specialist collaboration"]
    try:
        from training.task_bank import TaskBank

        task_bank = TaskBank(phase=1)
        sampled = []
        for _ in range(n):
            sampled.append(task_bank.sample())
    except Exception:
        # TaskBank unavailable (e.g. missing config): fall back to the placeholder.
        return placeholder
    return sampled


# Sampled once, when the module is executed.
PRESET_TASKS = _get_preset_tasks()
HF_MODEL_REPO = "garvitsachdeva/spindleflow-rl"


def _load_trained_model(hf_repo: str):
    """Download the RecurrentPPO policy and its VecNormalize stats from the HF Hub.

    The final artifacts (``spindleflow_model.zip`` / ``vec_normalize.pkl``) are
    tried first; the ``*_latest`` checkpoints are the fallback. An ``HF_TOKEN``
    environment variable, when set, is passed to the download calls.

    Normalisation stats are optional: if they cannot be fetched or read, the
    model is still returned with ``obs_mean``/``obs_var`` left as ``None``.

    Returns:
        ``(model, obs_mean, obs_var, clip_obs, error_str)``. On success
        ``error_str`` is ``None``. On failure the first three items are
        ``None``, ``clip_obs`` is ``10.0`` and ``error_str`` is the exception text.
    """
    try:
        import pickle

        from huggingface_hub import hf_hub_download
        from sb3_contrib import RecurrentPPO

        token = os.getenv("HF_TOKEN") or None

        def _fetch(primary: str, fallback: str) -> str:
            # Prefer the final artifact; fall back to the latest periodic checkpoint.
            try:
                return hf_hub_download(hf_repo, primary, token=token)
            except Exception:
                return hf_hub_download(hf_repo, fallback, token=token)

        model_path = _fetch("spindleflow_model.zip", "spindleflow_model_latest.zip")
        model = RecurrentPPO.load(model_path, device="cpu")

        obs_mean = obs_var = None
        clip_obs = 10.0
        try:
            stats_path = _fetch("vec_normalize.pkl", "vec_normalize_latest.pkl")
            # SECURITY: pickle.load executes arbitrary code from the file. This is
            # only acceptable while ``hf_repo`` points at a trusted repository.
            with open(stats_path, "rb") as fh:
                vec_norm = pickle.load(fh)
            obs_mean = vec_norm.obs_rms.mean.copy()
            obs_var = vec_norm.obs_rms.var.copy()
            clip_obs = float(vec_norm.clip_obs)
        except Exception as stats_exc:
            # Best effort: the policy can run without stats, but observations
            # will then be fed to it un-normalised, so leave a trace.
            import logging

            logging.getLogger(__name__).warning(
                "VecNormalize stats unavailable for %s: %s", hf_repo, stats_exc
            )
        return model, obs_mean, obs_var, clip_obs, None
    except Exception as exc:
        return None, None, None, 10.0, str(exc)
def _predict(model, obs: np.ndarray, lstm_states, episode_starts,
             obs_mean, obs_var, clip_obs: float):
    """Run one deterministic policy step on a single observation.

    The observation gets a leading batch axis and is cast to float32. When both
    ``obs_mean`` and ``obs_var`` are given it is standardised with them and
    clipped to ``[-clip_obs, clip_obs]`` before being passed to ``model.predict``.

    Returns:
        ``(action, new_lstm_states)`` where ``action`` is the first row of the
        predicted action batch.
    """
    batch = obs[np.newaxis, :].copy().astype(np.float32)

    have_stats = not (obs_mean is None or obs_var is None)
    if have_stats:
        standardised = (batch - obs_mean) / np.sqrt(obs_var + 1e-8)
        batch = np.clip(standardised, -clip_obs, clip_obs)

    actions, next_states = model.predict(
        batch,
        state=lstm_states,
        episode_start=episode_starts,
        deterministic=True,
    )
    return actions[0], next_states
# Shared Plotly layout: transparent backgrounds, light text, compact margins.
DARK = {
    "paper_bgcolor": "rgba(0,0,0,0)",
    "plot_bgcolor": "rgba(0,0,0,0)",
    "font": {"color": "#e2e8f0", "family": "Inter, system-ui, sans-serif"},
    "margin": {"l": 44, "r": 20, "t": 44, "b": 40},
}

# Faint grid and zero-line colours, applied identically to both cartesian axes.
DARK_AXES = {
    axis: {
        "gridcolor": "rgba(255,255,255,0.05)",
        "zerolinecolor": "rgba(255,255,255,0.08)",
    }
    for axis in ("xaxis", "yaxis")
}
# ─────────────────────────────────────────────────────────
# Session state
# ─────────────────────────────────────────────────────────
class Session:
    """Episode state for one dashboard session, wrapped around a ``SpindleFlowEnv``.

    Besides driving the environment, it records per-step rewards, info dicts,
    replay snapshots, specialist-selection entropy and observation statistics
    so the charts can be rebuilt on every rerun.
    """

    def __init__(self):
        self.env: SpindleFlowEnv | None = None
        self.registry: SpecialistRegistry | None = None
        self.rewards: list[float] = []
        self.actions: list[dict] = []
        self.step_n = 0
        self.done = False
        self.task = ""
        # Full episode history for replay
        self.episode_history: list[dict] = []
        # Action entropy per step (policy confidence)
        self.step_entropies: list[float] = []
        # Observation vector stats per step
        self.obs_history: list[dict] = []
        # Specialists auto-spawned for this episode
        self.spawned_specialists: list[str] = []
        # Trained policy inference state
        self.obs_current: np.ndarray | None = None
        self.lstm_states = None
        self.episode_starts = np.array([True])

    def boot(self):
        """Create the environment on first use and expose its registry."""
        if self.env is None:
            self.env = SpindleFlowEnv(
                config_path=CONFIG, catalog_path=CATALOG,
                use_real_spindleflow=False, phase=1,
            )
            self.registry = self.env.registry

    def reset(self, phase: int = 1):
        """Start a new episode in ``phase`` and clear all per-episode records.

        Returns:
            The ``(obs, info)`` pair produced by ``env.reset()``.
        """
        self.boot()
        self.env.phase = int(phase)
        obs, info = self.env.reset()
        self.rewards = []
        self.actions = []
        self.step_n = 0
        self.done = False
        self.task = info.get("task", "")
        self.episode_history = []
        self.step_entropies = []
        self.obs_history = []
        self.spawned_specialists = list(info.get("spawned_specialists", []))
        self.obs_current = obs
        self.lstm_states = None
        self.episode_starts = np.array([True])
        return obs, info

    def step(self, action):
        """Apply ``action`` to the environment and record the outcome.

        Returns:
            ``(obs, reward, terminated, truncated, info)`` from ``env.step``.
            If the environment has not been booted or the episode is already
            over, ``(None, 0.0, True, False, {})`` is returned and nothing is
            recorded.
        """
        if self.env is None or self.done:
            return None, 0.0, True, False, {}

        obs, r, term, trunc, info = self.env.step(action)
        self.rewards.append(r)
        self.actions.append(info)
        self.step_n += 1
        self.done = term or trunc
        self.obs_current = obs
        self.episode_starts = np.array([self.done])

        # Capture step snapshot for replay
        called = info.get("called_specialists", [])
        edges = [(e.caller_id, e.callee_id)
                 for e in self.env.delegation_graph.get_delegation_path()]
        self.episode_history.append({
            "step": self.step_n,
            "reward": r,
            "action_name": info.get("action_name", "UNKNOWN"),
            "called": list(called),
            "edges": list(edges),
            "components": dict(info.get("reward_components", {})),
            "mode": info.get("delegation_mode", ""),
            "cumulative": float(sum(self.rewards)),
            "latencies": dict(info.get("specialist_latencies", {})),
        })

        # One entropy value per step keeps step_entropies aligned with rewards.
        self.step_entropies.append(
            self._selection_entropy(action, self.env.max_specialists)
        )

        # Capture observation norm for state trace
        if obs is not None:
            self.obs_history.append({
                "step": self.step_n,
                "obs_norm": float(np.linalg.norm(obs)),
                "obs_mean": float(obs.mean()),
                "obs_max": float(obs.max()),
            })
        return obs, r, term, trunc, info

    @staticmethod
    def _selection_entropy(action, n_specialists: int) -> float:
        """Softmax entropy of the specialist-selection logits in ``action``.

        The logits are read from ``action[1 : 1 + n_specialists]``. ``action``
        may be any array-like (ndarray, list, tuple, scalar). An empty logit
        slice yields ``0.0`` instead of raising.
        """
        logits = np.atleast_1d(action)[1: 1 + n_specialists]
        if logits.size == 0:
            return 0.0
        # Subtract the max before exponentiating for numerical stability.
        weights = np.exp(logits - logits.max())
        probs = weights / (weights.sum() + 1e-8)
        return float(-np.sum(probs * np.log(probs + 1e-8)))
def _S() -> Session:
    """Return the ``Session`` kept in Streamlit's session state, creating it on first access."""
    state = st.session_state
    if "session" in state:
        return state.session
    state.session = Session()
    return state.session
def _load_catalog() -> list[dict]:
    """Read the specialist catalog YAML at ``CATALOG`` and return its ``specialists`` list."""
    import yaml

    # Decode explicitly as UTF-8; otherwise open() uses the locale's default
    # encoding, which can mis-read non-ASCII text in the catalog.
    with open(CATALOG, encoding="utf-8") as fh:
        return yaml.safe_load(fh)["specialists"]
def _exec_mode_badges(S: "Session") -> str:
    """Build the inline HTML strip of badges for execution and task-generation modes.

    The first badge depends on whether ``OPENAI_API_KEY`` is set. The second
    is only emitted once ``S.env`` exists and depends on whether its task bank
    holds a client.
    """
    def _pill(background: str, color: str, border: str, text: str) -> str:
        # All badges share the same pill geometry; only colours and text vary.
        return (
            '<span style="padding:3px 10px;border-radius:999px;font-size:10px;font-weight:700;'
            f'background:{background};color:{color};'
            f'border:1px solid {border};">{text}</span>'
        )

    green = ("rgba(16,185,129,0.1)", "#34d399", "rgba(16,185,129,0.22)")
    amber = ("rgba(245,158,11,0.1)", "#fbbf24", "rgba(245,158,11,0.22)")
    slate = ("rgba(148,163,184,0.08)", "#64748b", "rgba(148,163,184,0.18)")

    has_key = bool(os.getenv("OPENAI_API_KEY"))
    env_ready = S.env is not None
    llm_tasks = env_ready and S.env.task_bank._client is not None

    if has_key:
        exec_badge = _pill(*green, "β LLM BASELINE")
    else:
        exec_badge = _pill(
            *amber,
            "β‘ SIMULATION MODE β specialist outputs templated Β· set OPENAI_API_KEY for real LLM",
        )

    if not env_ready:
        task_badge = ""
    elif llm_tasks:
        task_badge = _pill(*green, "β LLM TASKS")
    else:
        task_badge = _pill(*slate, "β‘ CATALOG TASKS")

    return (
        '<div style="display:flex;gap:8px;flex-wrap:wrap;margin:4px 0 12px;">'
        + exec_badge + task_badge + '</div>'
    )
# ─────────────────────────────────────────────────────────
# Chart builders
# ─────────────────────────────────────────────────────────
def fig_reward_curve(rewards: list[float]) -> go.Figure:
    """Plot the episode's cumulative reward (top row) and per-step reward (bottom row).

    With no rewards recorded yet, an empty themed figure carrying a
    "Reset the environment to begin" hint is returned instead.
    """
    if not rewards:
        hint = dict(
            text="Reset the environment to begin",
            x=0.5, y=0.5, showarrow=False,
            font=dict(color="#334155", size=13),
        )
        placeholder = go.Figure()
        placeholder.update_layout(
            **DARK, **DARK_AXES,
            title=dict(text="Episode Reward", font=dict(size=13, color="#64748b")),
            annotations=[hint],
        )
        return placeholder

    grid = "rgba(255,255,255,0.05)"
    step_idx = list(range(len(rewards)))

    cumulative_trace = go.Scatter(
        x=step_idx, y=np.cumsum(rewards).tolist(), mode="lines",
        line=dict(color="#00d4ff", width=2.5),
        fill="tozeroy", fillcolor="rgba(0,212,255,0.07)",
        name="Cumulative",
    )
    # Green bars for non-negative rewards, red otherwise.
    per_step_trace = go.Bar(
        x=step_idx, y=rewards,
        marker_color=["#10b981" if r >= 0 else "#ef4444" for r in rewards],
        marker_line_width=0, name="Per-step",
    )

    fig = make_subplots(rows=2, cols=1, shared_xaxes=True,
                        row_heights=[0.62, 0.38], vertical_spacing=0.04)
    fig.add_trace(cumulative_trace, row=1, col=1)
    fig.add_trace(per_step_trace, row=2, col=1)

    fig.update_layout(
        **DARK, height=300, showlegend=False,
        title=dict(text="Episode Reward", font=dict(size=13, color="#94a3b8")),
    )
    fig.update_xaxes(gridcolor=grid)
    fig.update_yaxes(gridcolor=grid, title_text="Cumul.", row=1, col=1, title_font_size=10)
    fig.update_yaxes(title_text="Step", row=2, col=1, title_font_size=10)
    return fig
def fig_delegation_graph(
    S: Session,
    called_ids: list[str],
    edges: list[tuple],
    highlight_latest: bool = True,
    spawned_ids: list[str] | None = None,
) -> go.Figure:
    """Draw the delegation DAG as three rows.

    The orchestrator sits at the top, called specialists in the middle and
    uncalled registry specialists, dimmed, at the bottom. A dotted frame and a
    corner label show the current delegation depth against the maximum.

    Args:
        S: Session supplying the registry, environment and default spawned list.
        called_ids: Specialist ids to place in the middle row, in order.
        edges: ``(source, destination)`` id pairs. Pairs whose endpoints have
            no position are skipped.
        highlight_latest: Draw the last edge brighter, thicker and dashed.
        spawned_ids: Ids rendered as auto-spawned (star marker). ``None`` means
            "use ``S.spawned_specialists``"; an explicit list, even an empty
            one, is used as given.
    """
    all_ids = list(S.registry.list_ids()) if S.registry else []
    called_set = set(called_ids)
    # Only None falls back to the session default. Using `or` here would also
    # discard an explicit empty list.
    spawned_set = set(S.spawned_specialists if spawned_ids is None else spawned_ids)
    uncalled = [sid for sid in all_ids if sid not in called_set]

    # Node positions (hierarchical layout)
    def _row(ids: list[str], y: float) -> dict:
        # Spread the ids evenly across the open interval (0, 1) at height y.
        return {sid: ((i + 1) / (len(ids) + 1), y) for i, sid in enumerate(ids)}

    pos = {"orchestrator": (0.5, 0.92)}
    pos.update(_row(called_ids, 0.55))
    pos.update(_row(uncalled, 0.12))

    fig = go.Figure()

    # Background depth frame: green, amber, then red as depth nears the maximum
    max_depth = getattr(S.env, "max_depth", 2) if S.env else 2
    cur_depth = S.env.delegation_graph.depth if S.env else 0
    depth_frac = cur_depth / max(max_depth, 1)
    if depth_frac < 0.7:
        ring_color = "#10b981"
    elif depth_frac < 1.0:
        ring_color = "#f59e0b"
    else:
        ring_color = "#ef4444"
    fig.add_shape(
        type="rect",
        x0=0.0, y0=0.0, x1=1.0, y1=1.0,
        line=dict(color=ring_color, width=2, dash="dot"),
        fillcolor="rgba(0,0,0,0)", xref="x", yref="y",
    )
    fig.add_annotation(
        x=0.98, y=0.98, xref="x", yref="y",
        text=f"Depth {cur_depth}/{max_depth}", showarrow=False,
        font=dict(size=9, color=ring_color), xanchor="right", yanchor="top",
    )

    # Edges
    latest_edge = edges[-1] if edges else None
    for src, dst in edges:
        if src not in pos or dst not in pos:
            continue
        x0, y0 = pos[src]
        x1, y1 = pos[dst]
        is_latest = (latest_edge and highlight_latest and (src, dst) == latest_edge)
        color = "rgba(0,212,255,0.9)" if is_latest else "rgba(0,212,255,0.45)"
        width = 2.5 if is_latest else 1.8
        dash = "dash" if is_latest else "solid"
        fig.add_trace(go.Scatter(
            x=[x0, x1, None], y=[y0, y1, None], mode="lines",
            line=dict(color=color, width=width, dash=dash),
            hoverinfo="skip", showlegend=False,
        ))
        # The arrowhead is a separate annotation placed on top of the line.
        fig.add_annotation(
            ax=x0, ay=y0, x=x1, y=y1,
            xref="x", yref="y", axref="x", ayref="y",
            arrowhead=3, arrowsize=1.4, arrowwidth=2,
            arrowcolor=color, showarrow=True,
        )

    # Orchestrator node
    ox, oy = pos["orchestrator"]
    fig.add_trace(go.Scatter(
        x=[ox], y=[oy], mode="markers+text",
        marker=dict(size=44, color="#f59e0b", symbol="circle",
                    line=dict(color="#fcd34d", width=2.5), opacity=1.0),
        text=["<b>ORCH</b>"], textposition="middle center",
        textfont=dict(size=9, color="#0a0f1a", family="Inter, sans-serif"),
        hovertext=["<b>Orchestrator</b><br>Root node β makes all delegation decisions"],
        hoverinfo="text", showlegend=False, name="orchestrator",
    ))

    # Called specialist nodes (every called id has a position by construction)
    for sid in called_ids:
        x, y = pos[sid]
        c = SPEC_COLORS.get(sid, "#7c3aed")
        spec = S.registry.get(sid) if S.registry else None
        role = spec.role if spec else sid
        lat = f"{spec.avg_latency_ms}ms" if spec else ""
        is_spawned = sid in spawned_set
        symbol = "star" if is_spawned else "circle"
        size = 38 if is_spawned else 32
        border_c = "#fbbf24" if is_spawned else "rgba(255,255,255,0.4)"
        hover_tag = " β‘ AUTO-SPAWNED" if is_spawned else ""
        label = (("β‘ " if is_spawned else "") + sid).replace("_", "<br>")
        fig.add_trace(go.Scatter(
            x=[x], y=[y], mode="markers+text",
            marker=dict(size=size, color=c, symbol=symbol,
                        line=dict(color=border_c, width=2.5), opacity=1.0),
            text=[label], textposition="bottom center",
            textfont=dict(size=8, color="#fbbf24" if is_spawned else "#e2e8f0"),
            hovertext=[f"<b>{role}</b><br>Called β{hover_tag}<br>{lat}"],
            hoverinfo="text", showlegend=False,
        ))

    # Uncalled specialist nodes (dimmed)
    for sid in uncalled:
        x, y = pos[sid]
        c = SPEC_COLORS.get(sid, "#334155")
        spec = S.registry.get(sid) if S.registry else None
        role = spec.role if spec else sid
        label = sid.replace("_", "<br>")
        fig.add_trace(go.Scatter(
            x=[x], y=[y], mode="markers+text",
            marker=dict(size=16, color="#1e293b", symbol="circle",
                        line=dict(color=c, width=1), opacity=0.5),
            text=[label], textposition="bottom center",
            textfont=dict(size=7, color="rgba(148,163,184,0.45)"),
            hovertext=[f"<b>{role}</b><br>Not called"],
            hoverinfo="text", showlegend=False,
        ))

    # Section labels
    fig.add_annotation(x=0.01, y=0.96, xref="x", yref="y",
                       text="ORCHESTRATOR", showarrow=False,
                       font=dict(size=8, color="#475569"), xanchor="left")
    if called_ids:
        fig.add_annotation(x=0.01, y=0.62, xref="x", yref="y",
                           text="CALLED", showarrow=False,
                           font=dict(size=8, color="#00d4ff"), xanchor="left")
    if uncalled:
        fig.add_annotation(x=0.01, y=0.19, xref="x", yref="y",
                           text="AVAILABLE", showarrow=False,
                           font=dict(size=8, color="#334155"), xanchor="left")

    fig.update_layout(
        **DARK, height=420,
        title=dict(
            text=(f"Delegation Graph Β· {len(called_ids)} specialists called"
                  f" Β· Depth {cur_depth}/{max_depth}"),
            font=dict(size=13, color="#94a3b8"),
        ),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[-0.05, 1.05]),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[-0.05, 1.08]),
    )
    return fig
def fig_reward_breakdown(components: dict) -> go.Figure:
    """Horizontal bar chart of the reward components for one step.

    When ``components`` is empty, the nine known component names are shown
    with a value of zero so the chart keeps its shape.
    """
    default_names = (
        "quality_delta", "efficiency_penalty", "failure_penalty",
        "recovery_bonus", "conflict_penalty", "conflict_bonus",
        "consistency_bonus", "latency_penalty", "explanation_bonus",
    )
    shown = components if components else dict.fromkeys(default_names, 0.0)

    labels, values = [], []
    for name, value in shown.items():
        labels.append(name.replace("_", " ").title())
        values.append(value)

    bars = go.Bar(
        x=values,
        y=labels,
        orientation="h",
        # Green for non-negative contributions, red otherwise.
        marker_color=["#10b981" if v >= 0 else "#ef4444" for v in values],
        marker_line_width=0,
        text=[f"{v:+.3f}" for v in values],
        textposition="outside",
        textfont=dict(color="#94a3b8", size=9),
    )
    fig = go.Figure(bars)
    fig.add_vline(x=0, line_color="rgba(255,255,255,0.15)", line_width=1)
    fig.update_layout(
        **DARK, height=310,
        title=dict(text="Reward Breakdown", font=dict(size=13, color="#94a3b8")),
        xaxis=dict(gridcolor="rgba(255,255,255,0.05)", title="Value"),
        yaxis=dict(gridcolor="rgba(255,255,255,0.05)"),
    )
    return fig
def fig_policy_confidence(
    entropies: list[float],
    step_labels: list[int] | None = None,
    n_choices: int | None = None,
) -> go.Figure:
    """Bar chart of specialist-selection entropy per step, normalised to [0, 1].

    High entropy means the policy is uncertain or exploring; low entropy means
    it is confident.

    Args:
        entropies: Raw entropy value for each step.
        step_labels: X-axis labels. Defaults to ``1..len(entropies)``.
        n_choices: Number of options the entropy was computed over. The
            normaliser is ``log(n_choices)``, the entropy of a uniform
            distribution over that many options. When omitted, the legacy
            normaliser ``log(len(entropies))`` is used for backward
            compatibility.

    NOTE(review): the legacy normaliser depends on the number of *steps*, not
    the number of choices, so "1 = uniform" on the y-axis only holds when
    callers pass ``n_choices``.
    """
    if not entropies:
        fig = go.Figure()
        fig.update_layout(
            **DARK, **DARK_AXES,
            title=dict(text="Policy Confidence (Action Entropy)",
                       font=dict(size=13, color="#64748b")),
            annotations=[dict(text="Run an episode to see real action entropy",
                              x=0.5, y=0.5, showarrow=False,
                              font=dict(color="#334155", size=12))],
        )
        return fig

    steps = step_labels or list(range(1, len(entropies) + 1))
    support = n_choices if n_choices else len(entropies)
    # At least two options, so the normaliser is never log(1) == 0.
    max_e = float(np.log(max(support, 2)))
    norm_e = [min(1.0, max(0.0, e / max(max_e, 1e-8))) for e in entropies]
    # Interpolate the bar colour along the normalised entropy.
    colors = [
        f"rgba({int(0 + 124 * ne)},{int(212 - 154 * ne)},{int(255 - 58 * ne)},0.85)"
        for ne in norm_e
    ]

    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=steps, y=norm_e,
        marker_color=colors, marker_line_width=0,
        name="Normalised entropy",
        text=[f"{e:.3f}" for e in entropies],
        textposition="outside",
        textfont=dict(size=8, color="#94a3b8"),
        hovertemplate="Step %{x}<br>Entropy: %{text}<extra></extra>",
    ))
    fig.add_hline(y=0.5, line_dash="dot", line_color="rgba(148,163,184,0.3)",
                  annotation_text="Mid-entropy", annotation_font_color="#475569")
    fig.update_layout(
        **DARK, height=260,
        title=dict(text="Policy Confidence β Specialist Selection Entropy per Step",
                   font=dict(size=12, color="#94a3b8")),
        xaxis=dict(title="Episode Step", gridcolor="rgba(255,255,255,0.05)",
                   zerolinecolor="rgba(255,255,255,0.08)"),
        yaxis=dict(title="Entropy (0=certain, 1=uniform)", range=[0, 1.15],
                   gridcolor="rgba(255,255,255,0.05)", zerolinecolor="rgba(255,255,255,0.08)"),
        showlegend=False,
    )
    return fig
def fig_similarity(registry: SpecialistRegistry) -> go.Figure:
    """Heatmap of pairwise dot products between specialists' state vectors.

    Returns a titled placeholder figure when the registry is empty, when any
    specialist has no embedding yet, or when building the matrix fails.

    NOTE(review): the title says "Cosine" but the vectors are not normalised
    here — presumably ``to_state_vector()`` returns unit vectors; confirm.
    """
    ids = registry.list_ids()
    n = len(ids)
    if n == 0:
        fig = go.Figure()
        fig.update_layout(**DARK, title=dict(text="No specialists in registry",
                                             font=dict(size=13, color="#64748b")))
        return fig

    missing = [sid for sid in ids if registry.get(sid).embedding is None]
    if missing:
        fig = go.Figure()
        fig.update_layout(
            **DARK, **DARK_AXES,
            title=dict(text="Embeddings not computed β boot the environment first",
                       font=dict(size=13, color="#64748b")),
            annotations=[dict(text=f"Missing embeddings: {', '.join(missing[:4])}",
                              x=0.5, y=0.5, showarrow=False,
                              font=dict(color="#334155", size=12))],
        )
        return fig

    try:
        # Fetch each state vector once, then take all pairwise dot products in
        # a single matrix product instead of recomputing vectors per cell.
        vectors = np.asarray(
            [registry.get(sid).to_state_vector() for sid in ids], dtype=float
        )
        mat = vectors @ vectors.T
    except Exception as exc:
        fig = go.Figure()
        fig.update_layout(**DARK, title=dict(text=f"Similarity error: {exc}",
                                             font=dict(size=13, color="#ef4444")))
        return fig

    labels = [x.replace("_", "<br>") for x in ids]
    fig = go.Figure(go.Heatmap(
        z=mat, x=labels, y=labels,
        colorscale=[[0, "#0f0f1a"], [0.5, "rgba(124,58,237,0.6)"], [1, "#00d4ff"]],
        showscale=True, zmin=0, zmax=1,
        text=np.round(mat, 2), texttemplate="%{text}", textfont=dict(size=9),
    ))
    fig.update_layout(
        **DARK, height=400,
        title=dict(text="Capability Similarity (Cosine)", font=dict(size=13, color="#94a3b8")),
    )
    return fig
def fig_training_curve() -> go.Figure:
    """Plot mean reward per episode with a trailing 5-point moving average.

    Data comes from ``ASSETS / "reward_curve.json"`` (keys ``episodes`` and
    ``mean_rewards``). If that file does not exist, a seeded synthetic curve
    is drawn instead and the title is suffixed to say so, so placeholder data
    is never presented as a real training result.
    """
    path = ASSETS / "reward_curve.json"
    synthetic = not path.exists()
    if synthetic:
        # Deterministic placeholder: saturating curve plus small Gaussian noise.
        rng = np.random.default_rng(42)
        eps = list(range(0, 201, 5))
        rews = [float(np.clip(0.1 + 0.5 * (1 - np.exp(-e / 80)) + rng.normal(0, 0.04), 0, 1))
                for e in eps]
    else:
        with open(path, encoding="utf-8") as f:
            d = json.load(f)
        eps, rews = d["episodes"], d["mean_rewards"]

    # Trailing window of up to 5 points (shorter at the start of the series).
    smooth = [float(np.mean(rews[max(0, i - 4):i + 1])) for i in range(len(rews))]

    title_text = "Training Progress β Mean Reward per Episode"
    if synthetic:
        title_text += " (synthetic placeholder)"

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=eps, y=rews, mode="markers",
                             marker=dict(size=5, color="rgba(0,212,255,0.35)"),
                             name="Episode"))
    fig.add_trace(go.Scatter(x=eps, y=smooth, mode="lines",
                             line=dict(color="#00d4ff", width=2.5),
                             fill="tozeroy", fillcolor="rgba(0,212,255,0.06)",
                             name="Smoothed"))
    fig.add_hline(y=0.1, line_dash="dash", line_color="rgba(148,163,184,0.35)",
                  annotation_text="Random baseline", annotation_font_color="#64748b")
    fig.update_layout(
        **DARK, **DARK_AXES, height=340,
        title=dict(text=title_text, font=dict(size=13, color="#94a3b8")),
        xaxis_title="Episode", yaxis_title="Mean Reward",
        legend=dict(bgcolor="rgba(0,0,0,0)"),
    )
    return fig
def fig_training_entropy() -> go.Figure:
    """Plot policy entropy from the training log, or from the live session.

    Sources, in priority order:
      1. ``ASSETS / "entropy_log.json"`` (keys ``episodes`` and ``mean_entropies``);
      2. the current session's per-step entropies.
    If neither is available, a placeholder figure with a hint is returned.
    No synthetic data is ever drawn.
    """
    log_path = ASSETS / "entropy_log.json"
    session = _S()

    if log_path.exists():
        with open(log_path) as fh:
            log = json.load(fh)
        xs, ys = log["episodes"], log["mean_entropies"]
        source_label = "From training log"
    elif session.step_entropies:
        ys = session.step_entropies
        xs = list(range(1, len(ys) + 1))
        source_label = "Current episode (live)"
    else:
        hint = dict(
            text="Run python training/train.py to generate entropy logs",
            x=0.5, y=0.5, showarrow=False,
            font=dict(color="#334155", size=12),
        )
        placeholder = go.Figure()
        placeholder.update_layout(
            **DARK, **DARK_AXES,
            title=dict(text="Policy Entropy β Run training to populate",
                       font=dict(size=13, color="#64748b")),
            annotations=[hint],
        )
        return placeholder

    entropy_trace = go.Scatter(
        x=xs, y=ys, mode="lines+markers",
        line=dict(color="#7c3aed", width=2.2),
        marker=dict(size=4, color="#a78bfa"),
        fill="tozeroy", fillcolor="rgba(124,58,237,0.06)",
        name=source_label,
    )
    fig = go.Figure()
    fig.add_trace(entropy_trace)
    fig.update_layout(
        **DARK, **DARK_AXES, height=280,
        title=dict(text=f"Policy Entropy over Training ({source_label})",
                   font=dict(size=13, color="#94a3b8")),
        xaxis_title="Episode / Step",
        yaxis_title="Action Selection Entropy",
        legend=dict(bgcolor="rgba(0,0,0,0)"),
    )
    return fig
# ─────────────────────────────────────────────────────────
# Quality-comparison helpers
# ─────────────────────────────────────────────────────────
def _generate_generic_output(task: str) -> str:
    """Ask GPT-4o-mini to solve ``task`` directly, with no specialist routing.

    Always returns a string:
      * a canned five-step outline when ``OPENAI_API_KEY`` is not set;
      * the model's reply on success;
      * a parenthesised failure note if the call raises or the reply carries
        no text content.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        return (
            "General problem-solving approach:\n"
            "1. Gather and clarify requirements\n"
            "2. Research common solution patterns\n"
            "3. Draft a high-level architecture\n"
            "4. Implement in small, testable increments\n"
            "5. Validate against acceptance criteria and deploy\n"
            "No specialist domain expertise applied."
        )
    try:
        from openai import OpenAI

        resp = OpenAI(api_key=api_key).chat.completions.create(
            model="gpt-4o-mini",
            max_tokens=600,
            messages=[
                {"role": "system",
                 "content": "You are a general-purpose software engineering assistant."},
                {"role": "user",
                 "content": f"Provide a detailed solution approach for this task:\n\n{task}"},
            ],
        )
        content = resp.choices[0].message.content
    except Exception as exc:
        return f"(Generic output generation failed: {exc})"
    # message.content is nullable; callers slice the result, so never return None.
    if content is None:
        return "(Generic output generation failed: empty response)"
    return content
def _t1_relevance(task: str, output: str, registry) -> float:
    """Score how relevant ``output`` is to ``task`` on a 0-10 scale.

    Both texts are embedded with ``registry.embed_query`` (the output is
    truncated to its first 800 characters). The score is their cosine
    similarity, floored at zero, multiplied by ten and rounded to two decimals.
    Missing embeddings or any error give ``0.0``.
    """
    try:
        task_vec = registry.embed_query(task)
        out_vec = registry.embed_query(output[:800])
        if task_vec is None or out_vec is None:
            return 0.0
        # The small epsilon guards against a zero-norm embedding.
        denom = np.linalg.norm(task_vec) * np.linalg.norm(out_vec) + 1e-8
        cosine = float(np.dot(task_vec, out_vec) / denom)
        return round(max(0.0, cosine) * 10, 2)
    except Exception:
        return 0.0
def _judge_compare(task: str, generic: str, specialist: str) -> dict | None:
    """Have GPT-4o-mini rate the generic and specialist outputs on four dimensions.

    Returns the parsed JSON reply, expected to look like
    ``{dimension: [generic_score, specialist_score]}``, or ``None`` when
    ``OPENAI_API_KEY`` is unset or the request or parsing fails.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        return None

    # Inputs are truncated to keep the judging prompt short.
    prompt_parts = [
        f"Task:\n{task[:400]}\n\n",
        f"Output A (generic, no specialist routing):\n{generic[:700]}\n\n",
        f"Output B (specialist-routed by trained policy):\n{specialist[:700]}\n\n",
        "Rate each output 1β10 on: technical_depth, specificity, actionability, coverage.\n",
        'Return JSON only: {"technical_depth":[A,B],"specificity":[A,B],',
        '"actionability":[A,B],"coverage":[A,B]}',
    ]
    prompt = "".join(prompt_parts)

    try:
        from openai import OpenAI

        client = OpenAI(api_key=api_key)
        reply = client.chat.completions.create(
            model="gpt-4o-mini",
            max_tokens=150,
            response_format={"type": "json_object"},
            messages=[{"role": "user", "content": prompt}],
        )
        return json.loads(reply.choices[0].message.content)
    except Exception:
        return None
def fig_radar_comparison(
    gen_scores: dict,
    spec_scores: dict,
) -> go.Figure:
    """Radar chart comparing generic and specialist-routed scores per dimension.

    Args:
        gen_scores: ``{dimension: score}`` for the generic output. Its keys
            define the axes and their order.
        spec_scores: Scores for the specialist-routed output. It must contain
            every key of ``gen_scores``.

    Empty score dicts produce an empty themed radar instead of raising.
    """
    dims = list(gen_scores.keys())
    g_vals = [gen_scores[d] for d in dims]
    s_vals = [spec_scores[d] for d in dims]
    # Repeat the first point to close each polygon. Slicing with [:1] is a
    # no-op for empty input, where indexing [0] would raise IndexError.
    dims_c = dims + dims[:1]
    g_c = g_vals + g_vals[:1]
    s_c = s_vals + s_vals[:1]

    fig = go.Figure()
    fig.add_trace(go.Scatterpolar(
        r=g_c, theta=dims_c, fill="toself",
        fillcolor="rgba(239,68,68,0.10)",
        line=dict(color="#ef4444", width=2),
        name="Generic (no routing)",
    ))
    fig.add_trace(go.Scatterpolar(
        r=s_c, theta=dims_c, fill="toself",
        fillcolor="rgba(0,212,255,0.13)",
        line=dict(color="#00d4ff", width=2.5),
        name="Specialist-routed",
    ))
    fig.update_layout(
        paper_bgcolor="rgba(0,0,0,0)",
        font=dict(color="#e2e8f0", family="Inter, system-ui, sans-serif"),
        polar=dict(
            bgcolor="rgba(0,0,0,0)",
            radialaxis=dict(
                visible=True, range=[0, 10],
                gridcolor="rgba(255,255,255,0.08)",
                tickfont=dict(size=9, color="#475569"),
            ),
            angularaxis=dict(
                gridcolor="rgba(255,255,255,0.08)",
                tickfont=dict(size=11, color="#94a3b8"),
            ),
        ),
        title=dict(
            text="Quality Radar β Generic vs Specialist-Routed",
            font=dict(size=13, color="#94a3b8"),
        ),
        legend=dict(bgcolor="rgba(0,0,0,0)", font=dict(color="#94a3b8", size=11)),
        height=420,
        margin=dict(l=60, r=60, t=60, b=40),
    )
    return fig
# ─────────────────────────────────────────────────────────
# UI helpers
# ─────────────────────────────────────────────────────────
def inject_css():
    """Inject the dashboard's dark-theme stylesheet into the page.

    The CSS is written through ``st.markdown`` with ``unsafe_allow_html=True``
    so the ``<style>`` tag is emitted as raw HTML.
    """
    stylesheet = """
<style>
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600;700;800&display=swap');
html, body, [data-testid="stAppViewContainer"] {
background: #0f0f1a !important;
font-family: 'Inter', system-ui, sans-serif !important;
}
[data-testid="stHeader"] { background: transparent !important; }
[data-testid="stToolbar"] { display: none !important; }
[data-testid="stTabs"] > div:first-child button {
color: #475569 !important; font-weight: 600 !important; font-size: 13px !important;
}
[data-testid="stTabs"] > div:first-child button[aria-selected="true"] {
color: #00d4ff !important; border-bottom-color: #00d4ff !important;
}
.stButton > button {
border-radius: 8px !important; font-weight: 600 !important;
font-size: 13px !important; transition: all .18s !important;
border: 1px solid rgba(255,255,255,0.18) !important;
background: rgba(255,255,255,0.10) !important; color: #e2e8f0 !important;
}
.stButton > button:hover {
background: rgba(255,255,255,0.18) !important;
border-color: rgba(0,212,255,0.45) !important;
color: #ffffff !important;
}
.stButton > button[kind="primary"] {
background: linear-gradient(135deg,#00d4ff,#0092bb) !important;
border: none !important; color: #0a0f1a !important;
}
.stButton > button[kind="primary"]:hover {
box-shadow: 0 4px 18px rgba(0,212,255,0.35) !important;
}
[data-testid="stTextInput"] input,
[data-testid="stTextArea"] textarea {
background: rgba(0,0,0,0.3) !important;
border: 1px solid rgba(255,255,255,0.09) !important;
color: #e2e8f0 !important; border-radius: 8px !important;
}
[data-testid="stSelectbox"] > div > div {
background: rgba(0,0,0,0.35) !important;
border: 1px solid rgba(255,255,255,0.09) !important;
border-radius: 8px !important; color: #e2e8f0 !important;
}
[data-testid="stSlider"] [data-testid="stTickBar"] { color: #475569 !important; }
[data-testid="metric-container"] {
background: rgba(255,255,255,0.03) !important;
border: 1px solid rgba(255,255,255,0.07) !important;
border-radius: 12px !important; padding: 16px !important;
}
[data-testid="stMetric"] label { color: #475569 !important; font-size: 11px !important; }
[data-testid="stMetricValue"] { color: #00d4ff !important; font-weight: 700 !important; }
[data-testid="stCode"], .stCodeBlock {
background: rgba(0,0,0,0.4) !important;
border: 1px solid rgba(255,255,255,0.07) !important;
border-radius: 10px !important;
}
hr { border-color: rgba(255,255,255,0.07) !important; }
::-webkit-scrollbar { width: 4px; height: 4px; }
::-webkit-scrollbar-thumb { background: rgba(255,255,255,0.1); border-radius: 4px; }
::-webkit-scrollbar-track { background: transparent; }
</style>
"""
    st.markdown(stylesheet, unsafe_allow_html=True)
def hero() -> None:
    """Render the gradient title banner at the top of the dashboard.

    Emits one static HTML card via ``st.markdown`` (``unsafe_allow_html=True``)
    containing two decorative radial-gradient glows, the gradient-filled
    "SpindleFlow RL" title and a one-line tagline. Takes no input and reads
    no session state.
    """
    # Markup is kept flush-left inside the literal, exactly as it appears in
    # the source.
    st.markdown("""
<div style="background:linear-gradient(135deg,#0f0f1a,#130a22,#091422);
border:1px solid rgba(0,212,255,0.14);border-radius:16px;
padding:28px 36px;margin-bottom:4px;position:relative;overflow:hidden;">
<div style="position:absolute;top:-60px;right:-40px;width:360px;height:360px;
background:radial-gradient(circle,rgba(124,58,237,0.11) 0%,transparent 70%);
pointer-events:none;"></div>
<div style="position:absolute;bottom:-60px;left:15%;width:280px;height:280px;
background:radial-gradient(circle,rgba(0,212,255,0.07) 0%,transparent 70%);
pointer-events:none;"></div>
<div style="font-size:28px;font-weight:800;
background:linear-gradient(90deg,#00d4ff,#7c3aed,#00d4ff);
background-size:200% auto;-webkit-background-clip:text;
-webkit-text-fill-color:transparent;background-clip:text;
margin:0 0 8px;">SpindleFlow RL</div>
<div style="color:#64748b;font-size:13px;margin:0;">
Delegation Policy Learning Environment — 
Teaching orchestrators to route, specialize, and stop.
</div>
</div>
""", unsafe_allow_html=True)
def sec(title: str):
    """Render a small uppercase section heading with a hairline underline.

    Args:
        title: Heading text. It is interpolated into the markup as-is (not
            HTML-escaped), so callers may include inline markup.
    """
    heading_style = (
        "font-size:11px;font-weight:700;color:#475569;text-transform:uppercase;"
        "letter-spacing:1px;padding-bottom:10px;border-bottom:1px solid rgba(255,255,255,0.07);"
        "margin:18px 0 14px;"
    )
    markup = f'<div style="{heading_style}">{title}</div>'
    st.markdown(markup, unsafe_allow_html=True)
def status_bar(msg: str, color: str = "#94a3b8"):
    """Render a one-line status box.

    Args:
        msg: Status text; HTML-escaped before being placed in the markup.
        color: CSS colour applied to the text (inserted verbatim).
    """
    safe_msg = _html.escape(msg)
    box_style = (
        "background:rgba(0,0,0,0.3);border:1px solid rgba(255,255,255,0.07);"
        f"border-radius:8px;padding:10px 16px;font-size:12px;color:{color};margin:6px 0 10px;"
    )
    st.markdown(
        f'<div style="{box_style}">{safe_msg}</div>',
        unsafe_allow_html=True,
    )
def render_live_stats(S: Session) -> None:
    """Render the live episode statistics strip in the sidebar.

    Everything shown is read from the session object: run status, step
    counter, total and mean reward, number of distinct specialists called,
    delegation-graph depth, mean step entropy, and (once at least one reward
    exists) the reward curve.

    Args:
        S: Session state holder. Reads ``env``, ``done``, ``step_n``,
            ``rewards``, ``episode_history`` and ``step_entropies``.
    """
    rewards = S.rewards
    # Summed once; the original recomputed the sum for the value and the colour.
    total_reward = sum(rewards) if rewards else 0.0

    with st.sidebar:
        st.markdown(
            '<div style="font-size:10px;font-weight:700;color:#00d4ff;'
            'text-transform:uppercase;letter-spacing:1px;margin-bottom:12px;">'
            'β Live Episode Stats</div>',
            unsafe_allow_html=True,
        )

        # "Running" needs a live env that has not finished; a finished
        # episode is "Complete"; anything else is "Idle".
        if S.env is not None and not S.done:
            status, status_color = "Running", "#10b981"
        elif S.done:
            status, status_color = "Complete", "#f59e0b"
        else:
            status, status_color = "Idle", "#475569"
        st.markdown(
            f'<div style="display:flex;justify-content:space-between;'
            f'padding:6px 0;border-bottom:1px solid rgba(255,255,255,0.05);">'
            f'<span style="font-size:11px;color:#475569;">Status</span>'
            f'<span style="font-size:11px;font-weight:700;color:{status_color};">'
            f'{status}</span></div>',
            unsafe_allow_html=True,
        )

        unique_called = len({
            sp for h in S.episode_history for sp in h.get("called", [])
        })
        dag_depth = str(S.env.delegation_graph.depth) if S.env else "β"
        stats = [
            ("Step", str(S.step_n), "#e2e8f0"),
            ("Total Reward", f"{total_reward:+.4f}" if rewards else "β",
             "#10b981" if (rewards and total_reward >= 0) else "#ef4444"),
            ("Mean Step Rwd", f"{float(np.mean(rewards)):+.4f}" if rewards else "β", "#94a3b8"),
            ("Specialists", str(unique_called), "#7c3aed"),
            ("DAG Depth", dag_depth, "#f59e0b"),
            ("Mean Entropy", f"{float(np.mean(S.step_entropies)):.3f}"
             if S.step_entropies else "β", "#00d4ff"),
        ]
        for label, value, color in stats:
            st.markdown(
                f'<div style="display:flex;justify-content:space-between;'
                f'padding:5px 0;border-bottom:1px solid rgba(255,255,255,0.04);">'
                f'<span style="font-size:11px;color:#475569;">{label}</span>'
                f'<span style="font-size:11px;font-weight:600;color:{color};">'
                f'{value}</span></div>',
                unsafe_allow_html=True,
            )

        if rewards:
            st.markdown('<div style="margin-top:12px;"></div>', unsafe_allow_html=True)
            # Explicit key: the Live Demo tab draws the same reward curve with
            # the same arguments, and two identical keyless charts in one
            # script run can collide on their auto-generated element ID.
            st.plotly_chart(
                fig_reward_curve(rewards),
                use_container_width=True,
                key="sidebar_reward_curve",
            )
def _render_replay_step(S: Session, step_idx: int) -> None:
    """Render the replay view for one recorded step, without touching the env.

    Shows a summary banner, the delegation graph and reward breakdown as of
    that step, and a text trace of every step up to and including it.

    Args:
        S: Session state holder; only ``episode_history`` is read here
            (``S`` itself is also forwarded to ``fig_delegation_graph``).
        step_idx: Zero-based index into ``S.episode_history``. Out-of-range
            values (including negatives) show an info message instead of
            silently wrapping around to the end of the list.
    """
    history = S.episode_history
    if not history or not 0 <= step_idx < len(history):
        st.info("No episode data to replay. Run an episode first.")
        return

    snap = history[step_idx]
    visible = history[:step_idx + 1]
    cumulative = snap["cumulative"]

    # Specialists called up to and including this step, de-duplicated in
    # first-call order. A plain set would hand the graph builder an arbitrary
    # order that can change as the set grows.
    cumulative_called = list(dict.fromkeys(
        sp for h in visible for sp in h.get("called", [])
    ))

    # action_name originates from env info; escape it because the banner is
    # rendered with unsafe_allow_html.
    action_label = _html.escape(str(snap["action_name"]))
    st.markdown(
        f'<div style="background:rgba(124,58,237,0.07);border:1px solid rgba(124,58,237,0.2);'
        f'border-radius:10px;padding:12px 18px;font-size:12px;color:#a78bfa;margin-bottom:12px;">'
        f'Replaying Step {snap["step"]} Β· Action: <b>{action_label}</b> Β· '
        f'Reward: <b>{snap["reward"]:+.4f}</b> Β· '
        f'Cumulative: <b>{cumulative:+.4f}</b></div>',
        unsafe_allow_html=True,
    )

    rc1, rc2 = st.columns(2)
    with rc1:
        st.plotly_chart(
            fig_delegation_graph(S, cumulative_called, snap["edges"], highlight_latest=False),
            use_container_width=True,
            key=f"replay_dag_{step_idx}",
        )
    with rc2:
        st.plotly_chart(
            fig_reward_breakdown(snap["components"]),
            use_container_width=True,
            key=f"replay_breakdown_{step_idx}",
        )

    sec("Action Trace at This Step")
    trace_lines = []
    for h in visible:
        sign = "+" if h["reward"] >= 0 else ""
        called = h.get("called", [])
        called_str = ", ".join(called) if called else "β"
        # Mark the step currently being replayed.
        marker = "βΊ " if h["step"] == snap["step"] else " "
        trace_lines.append(
            f"{marker}Step {h['step']:>2} β {h['action_name']:<22} β "
            f"reward: {sign}{h['reward']:.4f} β specialists: {called_str}"
        )
    st.code("\n".join(trace_lines), language=None)
# ─────────────────────────────────────────────────────────
# Tab 1 — Live Demo
# ─────────────────────────────────────────────────────────
def _build_manual_action(env, registry, act_type: str, spec_ch: str):
    """Translate the manual-mode widget selection into an action vector.

    Args:
        env: Live environment; supplies ``action_space`` and ``max_specialists``.
        registry: Specialist registry; ``list_ids()`` maps an ID to its slot.
        act_type: One of "RANDOM", "STOP", "CALL SPECIALIST", "PARALLEL SPAWN".
        spec_ch: Specialist ID picked in the "Target specialist" select box.

    Returns:
        A float32 action array, or a random sample for "RANDOM" / unknown types.
    """
    if act_type not in ("STOP", "CALL SPECIALIST", "PARALLEL SPAWN"):
        return env.action_space.sample()

    action = np.zeros(env.action_space.shape, dtype=np.float32)
    if act_type == "STOP":
        action[0] = 1.0
    elif act_type == "CALL SPECIALIST":
        ids = registry.list_ids()
        idx = ids.index(spec_ch) if spec_ch in ids else -1
        if 0 <= idx < env.max_specialists:
            action[1 + idx] = 1.0
        else:
            # Unknown ID or slot beyond the action space: fall back to the
            # first specialist rather than sending an all-zero action.
            action[1] = 1.0
    else:  # PARALLEL SPAWN
        # NOTE(review): slot 0 = 6.0 and slot 1 + max_specialists = 1.0 look
        # like a meta-action code and a mode flag — confirm against
        # SpindleFlowEnv's action layout.
        action[0] = 6.0
        action[1] = 1.0
        if env.max_specialists > 1:
            action[2] = 1.0
        action[1 + env.max_specialists] = 1.0
    return action


def _policy_action(S, model, obs_mean, obs_var, clip_obs):
    """Query the trained policy for the next action and store the new LSTM state on ``S``."""
    action, S.lstm_states = _predict(
        model, S.obs_current, S.lstm_states,
        S.episode_starts, obs_mean, obs_var, clip_obs,
    )
    return action


def tab_live_demo():
    """Render the "Live Demo" tab: controls, orchestrator view, charts, log, replay.

    Flow per script run:
      1. Draw the task / control widgets.
      2. Handle at most one of Reset, Execute One Step or Run Full Episode;
         each handler writes its result into ``st.session_state`` and calls
         ``st.rerun()``.
      3. Draw the metric strip, orchestrator widget, reward / confidence
         charts, step log and replay slider from the session state.
    """
    S = _S()
    max_auto_steps = 15  # hard cap on steps for "Run Full Episode"

    col_task, col_ctrl = st.columns([3, 2], gap="large")
    with col_task:
        sec("Task")
        # NOTE(review): the two task widgets are not read again in this
        # function; presumably the session picks them up through their keys
        # ("task_dd" / "task_txt") — confirm.
        st.selectbox("Preset task", PRESET_TASKS, key="task_dd")
        st.text_input("Or enter custom task",
                      placeholder="Describe a software engineering taskβ¦",
                      key="task_txt")
        phase = st.slider("Curriculum phase", 1, 3, 1, key="phase_sl")

    with col_ctrl:
        sec("Controls")
        c1, c2 = st.columns(2)
        reset_btn = c1.button("Reset Episode", type="primary", use_container_width=True, key="reset_btn")
        run_btn = c2.button("Run Full Episode", use_container_width=True, key="run_btn")
        st.markdown('<div style="height:6px"></div>', unsafe_allow_html=True)
        use_trained = st.checkbox("π€ Use Trained Policy", value=False, key="use_trained",
                                  help="Load the trained RecurrentPPO model from HF Hub")
        trained_model = obs_mean = obs_var = None
        clip_obs = 10.0
        if use_trained:
            with st.spinner("Loading trained model from HF Hubβ¦"):
                trained_model, obs_mean, obs_var, clip_obs, model_err = _load_trained_model(HF_MODEL_REPO)
            if model_err:
                st.error(f"Model load failed: {model_err}")
            else:
                st.success("Trained policy loaded β")
        cat = _load_catalog()
        act_type = st.selectbox("Action type (manual mode)",
                                ["RANDOM", "STOP", "CALL SPECIALIST", "PARALLEL SPAWN"],
                                key="act_type",
                                disabled=use_trained)
        spec_ids = [sp["id"] for sp in cat]
        spec_ch = st.selectbox("Target specialist", spec_ids, key="spec_ch",
                               disabled=use_trained)
        step_btn = st.button("Execute One Step",
                             disabled=(S.env is None or S.done),
                             use_container_width=True, key="step_btn")

    status_msg = st.session_state.get("demo_status", "Click 'Reset Episode' to start.")
    status_clr = "#34d399" if "complete" in status_msg or "started" in status_msg else "#94a3b8"
    status_bar(status_msg, status_clr)
    st.markdown(_exec_mode_badges(S), unsafe_allow_html=True)

    # The policy is used only when requested, loaded, and an observation exists.
    def _can_use_policy() -> bool:
        return bool(use_trained) and trained_model is not None and S.obs_current is not None

    # -- Reset ---------------------------------------------------------
    if reset_btn:
        with st.spinner("Initializing environment⦠(first run ~30 s on CPU)"):
            S.reset(int(phase))
        spawn_note = (
            f" | β‘ Spawned: {', '.join(S.spawned_specialists)}"
            if S.spawned_specialists else ""
        )
        st.session_state.demo_status = f'Episode started | Task: "{S.task[:90]}"{spawn_note}'
        st.session_state.last_called = []
        st.session_state.last_edges = []
        st.session_state.last_info = {}
        st.rerun()

    # -- Single step ---------------------------------------------------
    if step_btn and S.env is not None and not S.done:
        if _can_use_policy():
            action = _policy_action(S, trained_model, obs_mean, obs_var, clip_obs)
        else:
            action = _build_manual_action(S.env, S.registry, act_type, spec_ch)
        _, r, term, trunc, info = S.step(action)
        done = term or trunc
        sign = "+" if r >= 0 else ""
        msg = f"Step {S.step_n} | reward {sign}{r:.4f} | {'DONE' if done else 'Runningβ¦'}"
        if done:
            msg += f" | Total: {sum(S.rewards):+.4f}"
        st.session_state.demo_status = msg
        # Cumulative called_ids keep the graph populated even after a STOP step.
        st.session_state.last_called = list(S.env.called_ids)
        st.session_state.last_edges = [
            (e.caller_id, e.callee_id)
            for e in S.env.delegation_graph.get_delegation_path()
        ]
        st.session_state.last_info = info
        st.rerun()

    # -- Full episode --------------------------------------------------
    if run_btn:
        with st.spinner("Running full episodeβ¦"):
            S.reset(int(phase))
            info = {}
            for _ in range(max_auto_steps):
                if S.done:
                    break
                if _can_use_policy():
                    action = _policy_action(S, trained_model, obs_mean, obs_var, clip_obs)
                else:
                    action = S.env.action_space.sample()
                _, _, _, _, info = S.step(action)
        # Same None-guard for both values (the original guarded only `called`).
        env_ready = S.env is not None
        called = list(S.env.called_ids) if env_ready else []
        edges = [
            (e.caller_id, e.callee_id)
            for e in S.env.delegation_graph.get_delegation_path()
        ] if env_ready else []
        total = sum(S.rewards)
        st.session_state.demo_status = (
            f"Episode complete | {S.step_n} steps | Total reward: {total:+.4f}"
        )
        st.session_state.last_called = called
        st.session_state.last_edges = edges
        st.session_state.last_info = info
        st.rerun()

    # -- Metric strip --------------------------------------------------
    if S.env is not None:
        mc1, mc2, mc3, mc4 = st.columns(4)
        mc1.metric("Obs Dim", int(S.env.observation_space.shape[0]))
        mc2.metric("Action Dim", int(S.env.action_space.shape[0]))
        mc3.metric("Specialists", S.registry.size)
        mc4.metric("Phase", phase)

    # -- Orchestrator widget (full width) ------------------------------
    sec("Orchestrator Β· Live Delegation View")
    last_info = st.session_state.get("last_info", {})
    last_called = st.session_state.get("last_called", [])
    render_orchestrator({
        "called": last_called,
        # Most recently called specialist while running; none once done.
        "active": (last_called or [""])[-1] if not S.done else "",
        "edges": st.session_state.get("last_edges", []),
        "task": S.task,
        "step": S.step_n,
        "mode": last_info.get("delegation_mode", "SEQUENTIAL"),
        "done": S.done,
        "reward": sum(S.rewards) if S.rewards else None,
        "phase": int(st.session_state.get("phase_sl", 1)),
    })

    # Thought ticker: last internal monologue reported by the env, if any.
    _thoughts = last_info.get("thoughts") or last_info.get("thought")
    if _thoughts:
        st.markdown(
            f'<div style="font-size:11px;color:#64748b;margin-top:-8px;padding:4px 8px;">'
            f'π {_html.escape(str(_thoughts))}</div>',
            unsafe_allow_html=True,
        )

    # -- Three-column chart row ----------------------------------------
    # Explicit keys: the sidebar draws the same reward curve, and identical
    # keyless charts in one run can collide on their auto-generated ID.
    sc1, sc2, sc3 = st.columns([4, 4, 4])
    with sc1:
        st.plotly_chart(fig_reward_curve(S.rewards), use_container_width=True,
                        key="tab_live_reward_curve")
    with sc2:
        st.plotly_chart(
            fig_reward_breakdown(last_info.get("reward_components", {})),
            use_container_width=True,
            key="tab_live_reward_breakdown",
        )
    with sc3:
        sec("Policy Confidence")
        if S.step_entropies:
            st.plotly_chart(
                fig_policy_confidence(
                    S.step_entropies,
                    [h["step"] for h in S.episode_history],
                ),
                use_container_width=True,
                key="tab_live_policy_confidence",
            )
        else:
            st.markdown(
                '<div style="color:#334155;font-size:11px;padding:24px;text-align:center;">'
                'Run an episode to see action entropy.</div>',
                unsafe_allow_html=True,
            )

    # -- Step log (full width) -----------------------------------------
    sec("Step Log / Action Trace")
    if not S.actions:
        st.markdown(
            '<div style="color:#334155;font-size:12px;padding:16px;text-align:center;">'
            'Waiting⦠Reset the episode to start.</div>',
            unsafe_allow_html=True,
        )
    else:
        lines = []
        for i, (inf, r) in enumerate(zip(S.actions, S.rewards)):
            sign = "+" if r >= 0 else ""
            act = inf.get("action_name", "UNKNOWN")
            specs = ", ".join(inf.get("called_specialists", []))
            mode = inf.get("delegation_mode", "")
            e_str = (f" β entropy: {S.step_entropies[i]:.3f}"
                     if i < len(S.step_entropies) else "")
            lats = inf.get("specialist_latencies", {})
            lat_str = (
                "\n β β latency: "
                + ", ".join(f"{k}: {v:.0f}ms" for k, v in lats.items())
            ) if lats else ""
            lines.append(
                f"Step {i+1:>2} β {act:<22} β reward: {sign}{r:.4f}{e_str}"
                + (f"\n β β called: {specs}" if specs else "")
                + (f"\n β β mode: {mode}" if mode else "")
                + lat_str
            )
        total = sum(S.rewards)
        unique_sp = len({sp for h in S.episode_history for sp in h.get("called", [])})
        lines.append(f"{'β'*62}")
        lines.append(
            f"Total reward: {'+' if total>=0 else ''}{total:.4f} β "
            f"Steps: {len(S.rewards)} β "
            f"Specialists called: {unique_sp} unique"
        )
        st.code("\n".join(lines), language=None)

    # -- Episode replay (full width) -----------------------------------
    if S.episode_history:
        st.markdown("---")
        sec("Episode Replay Mode")
        st.caption(
            "Scrub backward through every step of the episode. "
            "Delegation graph, reward breakdown, and action trace all update to that exact state. "
            "100% real data β no re-simulation."
        )
        n_steps = len(S.episode_history)
        if n_steps > 1:
            replay_step = st.slider(
                "Replay step",
                min_value=1,
                max_value=n_steps,
                value=n_steps,
                step=1,
                key="replay_slider",
                format="Step %d",
            )
        else:
            # A slider needs min < max, so a one-step episode skips it.
            replay_step = 1
            st.caption("Single-step episode β showing step 1.")
        _render_replay_step(S, replay_step - 1)
# ─────────────────────────────────────────────────────────
# Tab 2 — Specialists
# ─────────────────────────────────────────────────────────
def tab_specialists():
    """Render the "Specialists" tab: spawned pool, roster, similarity, add form.

    The roster comes from the live registry when one exists, so dynamically
    added specialists show up; before the environment has been booted it
    falls back to the YAML catalog. All specialist-supplied text is
    HTML-escaped before being placed in ``unsafe_allow_html`` markup, because
    roles and descriptions can come from the "Add Specialist" form.
    """
    S = _S()

    if S.registry is not None:
        specialists = S.registry.list_all()
        source_note = None
    else:
        class _SP:
            """Attribute-style view over one catalog dict, mirroring the fields read below."""

            def __init__(self, d: dict):
                self.id = d["id"]
                self.role = d["role"]
                self.description = d["description"]
                self.complexity_affinity = d["complexity_affinity"]
                self.avg_latency_ms = d["avg_latency_ms"]

        specialists = [_SP(d) for d in _load_catalog()]
        source_note = "Showing YAML catalog β run an episode to load the live registry (includes dynamic additions)."

    # -- Dynamically spawned specialists (accumulated from Output tab runs) --
    spawned_pool = st.session_state.get("spawned_pool", [])
    if spawned_pool:
        sec(f"β‘ Dynamically Spawned Β· {len(spawned_pool)} new agent{'s' if len(spawned_pool) != 1 else ''}")
        st.caption(
            "These specialists were auto-created during Output tab runs β "
            "triggered when no existing specialist had sufficient domain coverage (similarity < threshold)."
        )
        pool_cols = st.columns(min(len(spawned_pool), 4))
        for i, sp in enumerate(spawned_pool):
            pool_role = _html.escape(str(sp['role']))
            pool_trigger = _html.escape(sp['triggered_by'][:70])
            pool_desc = _html.escape(sp['description'][:100])
            pool_latency = _html.escape(str(sp['avg_latency_ms']))
            pool_affinity = _html.escape(', '.join(sp.get('complexity_affinity', [])))
            with pool_cols[i % 4]:
                st.markdown(f"""
<div style="background:rgba(251,191,36,0.06);border:1px solid rgba(251,191,36,0.28);
border-left:3px solid #fbbf24;border-radius:12px;
padding:14px;margin-bottom:10px;">
<div style="font-size:11px;font-weight:700;color:#fbbf24;margin-bottom:5px;">
β‘ {pool_role}
</div>
<div style="font-size:10px;color:#475569;margin-bottom:6px;font-style:italic;">
Triggered by: {pool_trigger}β¦
</div>
<div style="font-size:11px;color:#64748b;line-height:1.5;">
{pool_desc}β¦
</div>
<div style="font-size:10px;color:#334155;margin-top:8px;padding-top:8px;
border-top:1px solid rgba(255,255,255,0.05);">
{pool_latency} ms Β· {pool_affinity}
</div>
</div>""", unsafe_allow_html=True)
        st.markdown("---")

    n = len(specialists)
    sec(f"Roster β {n} specialist{'s' if n != 1 else ''}, capability-embedded")
    if source_note:
        st.caption(source_note)

    spawned_set = set(S.spawned_specialists) if S.registry is not None else set()
    cols = st.columns(4)
    for i, sp in enumerate(specialists):
        c = SPEC_COLORS.get(sp.id, "#7c3aed")
        is_spawned = sp.id in spawned_set
        border_top = "#fbbf24" if is_spawned else c
        spawn_tag = (
            '<span style="font-size:9px;font-weight:700;color:#fbbf24;'
            'background:rgba(251,191,36,0.1);border:1px solid rgba(251,191,36,0.25);'
            'border-radius:999px;padding:1px 7px;margin-left:6px;">β‘ AUTO-SPAWNED</span>'
            if is_spawned else ""
        )
        # Role was previously injected raw; escape it like the other fields.
        role_html = _html.escape(str(sp.role))
        desc_html = _html.escape(sp.description[:90])
        latency_html = _html.escape(str(sp.avg_latency_ms))
        affinity_html = _html.escape(', '.join(sp.complexity_affinity))
        with cols[i % 4]:
            st.markdown(f"""
<div style="background:rgba(255,255,255,0.025);border:1px solid {c}22;
border-left:3px solid {border_top};border-radius:12px;
padding:14px;margin-bottom:10px;">
<div style="font-size:11px;font-weight:700;color:{c};margin-bottom:6px;">
{role_html}{spawn_tag}
</div>
<div style="font-size:11px;color:#64748b;line-height:1.5;">
{desc_html}β¦
</div>
<div style="font-size:10px;color:#334155;margin-top:8px;padding-top:8px;
border-top:1px solid rgba(255,255,255,0.05);">
{latency_html} ms Β· {affinity_html}
</div>
</div>""", unsafe_allow_html=True)

    sec("Capability Similarity Matrix")
    if st.button("Load Similarity Matrix", key="sim_btn"):
        with st.spinner("Computing cosine similarity across 384-dim embeddingsβ¦"):
            S.boot()
        st.plotly_chart(fig_similarity(S.registry), use_container_width=True,
                        key="specialists_similarity_loaded")

    sec("Add Specialist Dynamically")
    st.caption("New specialists are immediately representable via their 384-dim embedding β no retraining or YAML edits required.")
    c1, c2 = st.columns(2)
    new_id = c1.text_input("ID", placeholder="ml_engineer", key="new_id")
    new_role = c2.text_input("Role", placeholder="ML Engineer", key="new_role")
    new_desc = st.text_area("Description",
                            placeholder="Expert in PyTorch, model training, MLOps pipelinesβ¦",
                            height=80, key="new_desc")
    if st.button("Add to Roster", type="primary", key="add_btn"):
        if new_id.strip() and new_role.strip() and new_desc.strip():
            with st.spinner("Encoding specialist embeddingβ¦"):
                S.boot()
                S.registry.add_specialist({
                    "id": new_id.strip(), "role": new_role.strip(),
                    "description": new_desc.strip(),
                    "complexity_affinity": ["moderate", "complex"],
                    "avg_latency_ms": 5000,
                })
            st.success(
                f"'{new_id.strip()}' added. "
                "Policy can represent it via 384-dim embedding β no retraining needed."
            )
            st.plotly_chart(fig_similarity(S.registry), use_container_width=True,
                            key="specialists_similarity_after_add")
        else:
            st.warning("Fill in all three fields.")
# ─────────────────────────────────────────────────────────
# Tab 3 — Training
# ─────────────────────────────────────────────────────────
def tab_training():
    """Render the "Training" tab.

    Sections, top to bottom: a call-out linking to the training Space, a
    button that downloads ``reward_curve.json`` from the model repo into the
    assets directory, the reward and entropy training charts, three
    curriculum-phase cards, and quick-start command snippets.
    """
    sec("Training Progress β Mean Reward per Episode")

    callout_html = (
        '<div style="background:rgba(0,212,255,0.06);border:1px solid rgba(0,212,255,0.20);'
        'border-radius:12px;padding:16px 20px;margin-bottom:18px;">'
        '<div style="font-size:13px;font-weight:700;color:#00d4ff;margin-bottom:6px;">'
        'π Want to run a fresh training run?</div>'
        '<div style="font-size:12px;color:#94a3b8;margin-bottom:10px;">'
        'Open the <strong style="color:#e2e8f0;">Training Space</strong> below, then click '
        '<strong style="color:#e2e8f0;">βΆ Start Training</strong>. '
        'When the run completes the new model is pushed to HF Hub and this demo loads it automatically.<br>'
        '<span style="color:#fb923c;font-size:11px;">β οΈ Starting a new run will overwrite the current A100-trained policy.</span>'
        '</div>'
        '<a href="https://huggingface.co/spaces/garvitsachdeva/finalRLEnv" target="_blank" '
        'style="display:inline-block;background:rgba(0,212,255,0.12);border:1px solid rgba(0,212,255,0.35);'
        'color:#00d4ff;padding:7px 18px;border-radius:8px;text-decoration:none;font-size:13px;font-weight:600;">'
        'π Open Training Space β</a>'
        '</div>'
    )
    st.markdown(callout_html, unsafe_allow_html=True)

    fetch_col, _spacer = st.columns([2, 5])
    fetch_clicked = fetch_col.button("π₯ Fetch latest curve from HF Hub", key="fetch_curve")
    if fetch_clicked:
        # Imports stay inside the try so a missing package is reported in the
        # UI like any other download failure.
        try:
            import shutil
            from huggingface_hub import hf_hub_download
            _tok = os.getenv("HF_TOKEN") or None
            downloaded = hf_hub_download(HF_MODEL_REPO, "reward_curve.json",
                                         token=_tok, force_download=True)
            ASSETS.mkdir(parents=True, exist_ok=True)
            shutil.copy(downloaded, ASSETS / "reward_curve.json")
            st.success("reward_curve.json updated β chart will refresh.")
            st.cache_data.clear()
        except Exception as exc:
            st.error(f"Download failed: {exc}")

    st.plotly_chart(fig_training_curve(), use_container_width=True)

    sec("Policy Entropy β Action Confidence Over Training")
    st.caption(
        "Entropy of the specialist-selection distribution. "
        "High = exploring (early training). Low = confident routing (converged policy)."
    )
    st.plotly_chart(fig_training_entropy(), use_container_width=True)

    sec("Curriculum Phases")

    def _phase_card(col, color, label, eps, desc):
        """Draw one phase card in ``col``; ``color`` is an "r,g,b" triple."""
        return col.markdown(
            f'<div style="background:rgba({color},0.04);border:1px solid rgba({color},0.18);'
            f'border-radius:12px;padding:18px;">'
            f'<div style="font-size:10px;font-weight:700;color:rgb({color});text-transform:uppercase;'
            f'letter-spacing:1px;margin-bottom:8px;">{label}</div>'
            f'<div style="font-size:22px;font-weight:700;color:#e2e8f0;margin-bottom:5px;">{eps}</div>'
            f'<div style="font-size:11px;color:#475569;">{desc}</div></div>',
            unsafe_allow_html=True,
        )

    phase_specs = (
        ("0,212,255", "Phase 1 Β· Atomic", "200 episodes",
         "Agent learns basic routing β which single specialist to call."),
        ("124,58,237", "Phase 2 Β· Moderate", "400 episodes",
         "Agent learns multi-specialist coordination and mode selection."),
        ("245,158,11", "Phase 3 Β· Complex/Enterprise", "600 episodes",
         "Full delegation strategy with DAG depth, fallbacks, and latency trade-offs."),
    )
    for card_col, (rgb, label, episodes, blurb) in zip(st.columns(3), phase_specs):
        _phase_card(card_col, rgb, label, episodes, blurb)

    sec("Quick Start Commands")
    local_snippet = (
        "# Demo mode β no OpenAI key needed\n"
        "cd spindleflow-rl\n"
        "python training/train.py \\\n"
        " --phase 1 --timesteps 50000\n\n"
        "# Monitor in TensorBoard\n"
        "tensorboard --logdir tensorboard_logs/"
    )
    colab_snippet = (
        "!git clone https://github.com/garvitsachdevaa/kuchbhi\n"
        "%cd kuchbhi\n"
        "!pip install -r requirements.txt sb3-contrib\n\n"
        "# 5k-step demo run\n"
        "%run colab/train_colab.py"
    )
    local_col, colab_col = st.columns(2)
    with local_col:
        st.markdown("**Local training**")
        st.code(local_snippet, language="bash")
    with colab_col:
        st.markdown("**Google Colab (T4 GPU, free)**")
        st.code(colab_snippet, language="python")
# ─────────────────────────────────────────────────────────
# Tab 4 — Quality Demo
# ─────────────────────────────────────────────────────────
| def tab_quality(): | |
| results = st.session_state.get("output_results") | |
| env_obj = st.session_state.get("output_env") | |
| sec("Live Quality Comparison β Generic vs Specialist-Routed") | |
| if results is None: | |
| st.markdown( | |
| '<div style="background:rgba(245,158,11,0.05);border:1px solid rgba(245,158,11,0.2);' | |
| 'border-radius:12px;padding:28px;text-align:center;">' | |
| '<div style="font-size:13px;color:#fbbf24;font-weight:600;margin-bottom:8px;">' | |
| 'No Output run yet</div>' | |
| '<div style="font-size:12px;color:#64748b;">' | |
| 'Go to the <b>π― Output</b> tab, enter a task, and click ' | |
| '"Run Trained Policy" β then return here to generate the quality comparison.' | |
| '</div></div>', | |
| unsafe_allow_html=True, | |
| ) | |
| else: | |
| task = results["task"] | |
| spec_results = results["specialist_results"] | |
| specialist_text = "\n\n".join( | |
| f"[{sr['id'].upper()}]\n{sr['output'] or ''}" | |
| for sr in spec_results if sr.get("output") | |
| ) or "(no specialist output)" | |
| # Task banner | |
| st.markdown( | |
| f'<div style="background:rgba(0,212,255,0.04);border:1px solid rgba(0,212,255,0.18);' | |
| f'border-radius:10px;padding:12px 18px;margin-bottom:16px;">' | |
| f'<span style="font-size:9px;font-weight:700;color:#475569;text-transform:uppercase;' | |
| f'letter-spacing:1px;">Comparing outputs for: </span>' | |
| f'<span style="font-size:12px;color:#e2e8f0;">{_html.escape(task[:140])}</span>' | |
| f'</div>', | |
| unsafe_allow_html=True, | |
| ) | |
| comp_data = st.session_state.get("quality_comparison") | |
| already_computed = comp_data is not None and comp_data.get("task") == task | |
| if not already_computed: | |
| if st.button("β‘ Generate Quality Comparison", type="primary", key="gen_comp_btn"): | |
| with st.spinner("Generating generic output + running GPT-4o-mini judgeβ¦"): | |
| generic_text = _generate_generic_output(task) | |
| registry = env_obj.registry if env_obj else None | |
| gen_t1 = _t1_relevance(task, generic_text, registry) if registry else 5.0 | |
| spec_t1 = _t1_relevance(task, specialist_text, registry) if registry else 7.0 | |
| judge = _judge_compare(task, generic_text, specialist_text) | |
| def _pick(key, fallback_g, fallback_s): | |
| pair = (judge or {}).get(key, [fallback_g, fallback_s]) | |
| return float(pair[0]), float(pair[1]) | |
| td_g, td_s = _pick("technical_depth", 5, 7) | |
| sp_g, sp_s = _pick("specificity", 4, 8) | |
| ac_g, ac_s = _pick("actionability", 4, 7) | |
| cv_g, cv_s = _pick("coverage", 5, 8) | |
| gen_scores = {"Task Relevance": gen_t1, "Technical Depth": td_g, | |
| "Specificity": sp_g, "Actionability": ac_g, "Coverage": cv_g} | |
| spec_scores = {"Task Relevance": spec_t1, "Technical Depth": td_s, | |
| "Specificity": sp_s, "Actionability": ac_s, "Coverage": cv_s} | |
| st.session_state.quality_comparison = { | |
| "task": task, | |
| "generic": generic_text, | |
| "specialist": specialist_text, | |
| "gen_scores": gen_scores, | |
| "spec_scores": spec_scores, | |
| } | |
| st.rerun() | |
| comp_data = st.session_state.get("quality_comparison") | |
| if comp_data and comp_data.get("task") == task: | |
| gen_scores = comp_data["gen_scores"] | |
| spec_scores = comp_data["spec_scores"] | |
| # ββ Score summary strip βββββββββββββββββββββββββββββββββββββ | |
| sec("Score Summary") | |
| cols = st.columns(len(gen_scores)) | |
| for i, (dim, g_val) in enumerate(gen_scores.items()): | |
| s_val = spec_scores[dim] | |
| delta = round(s_val - g_val, 1) | |
| cols[i].metric( | |
| dim, | |
| f"{s_val:.1f} / 10", | |
| f"{delta:+.1f} vs generic", | |
| ) | |
| # ββ Radar chart βββββββββββββββββββββββββββββββββββββββββββββ | |
| sec("Quality Radar") | |
| st.plotly_chart( | |
| fig_radar_comparison(gen_scores, spec_scores), | |
| use_container_width=True, | |
| key="quality_radar", | |
| ) | |
| # ββ Side-by-side score bars ββββββββββββββββββββββββββββββββββ | |
| sec("Per-Dimension Score Breakdown") | |
| dims = list(gen_scores.keys()) | |
| g_vals = [gen_scores[d] for d in dims] | |
| s_vals = [spec_scores[d] for d in dims] | |
| bar_fig = go.Figure() | |
| bar_fig.add_trace(go.Bar( | |
| name="Generic", x=dims, y=g_vals, | |
| marker_color="rgba(239,68,68,0.75)", marker_line_width=0, | |
| text=[f"{v:.1f}" for v in g_vals], textposition="outside", | |
| textfont=dict(size=10, color="#94a3b8"), | |
| )) | |
| bar_fig.add_trace(go.Bar( | |
| name="Specialist", x=dims, y=s_vals, | |
| marker_color="rgba(0,212,255,0.75)", marker_line_width=0, | |
| text=[f"{v:.1f}" for v in s_vals], textposition="outside", | |
| textfont=dict(size=10, color="#94a3b8"), | |
| )) | |
| bar_fig.update_layout( | |
| **DARK, **DARK_AXES, height=300, barmode="group", | |
| legend=dict(bgcolor="rgba(0,0,0,0)", font=dict(color="#94a3b8")), | |
| ) | |
| bar_fig.update_yaxes(range=[0, 11], gridcolor="rgba(255,255,255,0.05)") | |
| st.plotly_chart(bar_fig, use_container_width=True, key="quality_bars") | |
| # ββ Side-by-side text ββββββββββββββββββββββββββββββββββββββββ | |
| sec("Output Text Comparison") | |
| c1, c2 = st.columns(2) | |
| with c1: | |
| st.markdown( | |
| '<div style="font-size:10px;font-weight:700;color:#ef4444;' | |
| 'text-transform:uppercase;letter-spacing:1px;margin-bottom:8px;">' | |
| 'β Generic Output (No Delegation)</div>', | |
| unsafe_allow_html=True, | |
| ) | |
| st.code(comp_data["generic"][:1200], language=None) | |
| with c2: | |
| st.markdown( | |
| '<div style="font-size:10px;font-weight:700;color:#10b981;' | |
| 'text-transform:uppercase;letter-spacing:1px;margin-bottom:8px;">' | |
| 'β Specialist-Routed Output (Trained Policy)</div>', | |
| unsafe_allow_html=True, | |
| ) | |
| st.code(comp_data["specialist"][:1200], language=None) | |
| sec("Policy Tuning β Quality vs Latency") | |
| c1, c2 = st.columns(2) | |
| with c1: | |
| st.markdown(""" | |
| <div style="background:rgba(124,58,237,0.05);border:1px solid rgba(124,58,237,0.2); | |
| border-radius:12px;padding:16px;"> | |
| <div style="font-size:10px;font-weight:700;color:#a78bfa;text-transform:uppercase; | |
| letter-spacing:1px;margin-bottom:8px;">Quality Policy</div> | |
| <div style="font-size:12px;color:#64748b;line-height:1.8;"> | |
| 5 specialists Β· sequential Β· ~180 s<br> | |
| <code style="color:#a78bfa;background:rgba(124,58,237,0.12); | |
| padding:2px 6px;border-radius:4px;">latency_weight = 0.0</code> | |
| </div> | |
| </div>""", unsafe_allow_html=True) | |
| with c2: | |
| st.markdown(""" | |
| <div style="background:rgba(0,212,255,0.05);border:1px solid rgba(0,212,255,0.2); | |
| border-radius:12px;padding:16px;"> | |
| <div style="font-size:10px;font-weight:700;color:#00d4ff;text-transform:uppercase; | |
| letter-spacing:1px;margin-bottom:8px;">Latency Policy</div> | |
| <div style="font-size:12px;color:#64748b;line-height:1.8;"> | |
| 3 specialists Β· parallel Β· ~45 s<br> | |
| <code style="color:#00d4ff;background:rgba(0,212,255,0.1); | |
| padding:2px 6px;border-radius:4px;">latency_weight = 0.15</code> | |
| </div> | |
| </div>""", unsafe_allow_html=True) | |
# ─────────────────────────────────────────────────────────
# Tab 5 — Reward Lab
# ─────────────────────────────────────────────────────────
def tab_reward_lab():
    """Render the Reward Lab tab: weight sliders next to a live reward chart.

    Five sliders scale a set of illustrative reward components (the other
    components are fixed constants). The components are plotted with
    ``fig_reward_breakdown`` and their sum is shown as the estimated total.
    """
    sec("Interactive Reward Explorer")
    st.caption("Tune the reward weights and watch each component update live.")

    slider_col, chart_col = st.columns([1, 2], gap="large")

    # Left column: one slider per tunable weight (min, max, default, step).
    with slider_col:
        latency_w = st.slider("Latency Weight", 0.0, 0.50, 0.05, 0.01, key="rl_lw")
        efficiency_p = st.slider("Efficiency Penalty", 0.0, 0.20, 0.05, 0.01, key="rl_ep")
        failure_p = st.slider("Failure Penalty", 0.0, 1.00, 0.30, 0.05, key="rl_fp")
        consistency_w = st.slider("Consistency Bonus", 0.0, 0.50, 0.10, 0.01, key="rl_cw")
        explanation_b = st.slider("Explanation Bonus", 0.0, 0.20, 0.05, 0.01, key="rl_eb")

    # Insertion order is kept so the chart and the sum see the components in
    # the same sequence as before.
    components = dict(
        quality_delta=0.42,
        efficiency_penalty=-efficiency_p * 2,
        failure_penalty=-failure_p * 0.3,
        recovery_bonus=0.08,
        conflict_penalty=-0.05,
        conflict_bonus=0.03,
        consistency_bonus=consistency_w * 0.6,
        latency_penalty=-latency_w * 0.25,
        explanation_bonus=explanation_b,
    )
    total = sum(components.values())
    # Negative totals already carry their own "-" from the float format.
    sign = "+" if total >= 0 else ""

    summary_html = "".join((
        '<div style="background:rgba(0,212,255,0.05);border:1px solid rgba(0,212,255,0.18);',
        'border-radius:10px;padding:14px 18px;font-size:13px;color:#94a3b8;">',
        'Estimated total reward: ',
        f'<span style="color:#00d4ff;font-weight:700;font-size:20px;">{sign}{total:.3f}</span>',
        '</div>',
    ))

    # Right column: per-component chart with the total underneath.
    with chart_col:
        st.plotly_chart(fig_reward_breakdown(components), use_container_width=True)
        st.markdown(summary_html, unsafe_allow_html=True)
# ─────────────────────────────────────────────────────────
# Tab 6 — Architecture
# ─────────────────────────────────────────────────────────
def tab_architecture():
    """Render the static Architecture tab.

    Shows reference tables for the observation and action spaces, bullet
    summaries of the policy, tiered reward and safety features, and the
    reward formula. Apart from the observation size, which is queried from
    ``EpisodeState.observation_dim``, everything displayed is fixed text.
    """
    # Roster size the tables below are written for (rows read "6 × 384").
    roster_size = 6
    obs0 = EpisodeState.observation_dim(roster_size)
    # 1 meta-action + 6 specialist logits + 1 delegation mode + 4 mode
    # parameters, matching the index ranges in the action-space table.
    act0 = 1 + roster_size + 1 + 4

    observation_table = "\n".join((
        "| Dims | Component |",
        "|-----:|-----------|",
        "| 384 | Task embedding (all-MiniLM-L6-v2) |",
        "| 2304 | Roster embeddings (6 × 384) |",
        "| 2304 | Called embeddings (6 × 384) |",
        "| 384 | Scratchpad embedding |",
        "| 100 | Delegation graph adjacency (10 × 10) |",
        "| 6 | Called-specialist mask |",
        "| 8 | Scalar features |",
    ))
    action_table = "\n".join((
        "| Index | Component |",
        "|--------|-----------|",
        "| [0] | Meta-action (STOP / CALL / PARALLEL…) |",
        "| [1:7] | Specialist selection logits (multi-hot) |",
        "| [7] | Delegation mode (SEQ / PAR / FAN-OUT…) |",
        "| [8:12] | Mode parameters (rounds, threshold…) |",
    ))

    c1, c2 = st.columns(2)
    with c1:
        sec(f"Observation Space ({obs0:,} dims)")
        st.markdown(observation_table)
    with c2:
        sec(f"Action Space ({act0}-dim Box)")
        st.markdown(action_table)

    policy_notes = "\n".join((
        "- **LSTM PPO** (RecurrentPPO)",
        "- MlpLstmPolicy",
        "- Hidden: 256 · 1 layer",
        "- POMDP-safe via LSTM state",
        "- 4 factored action heads",
    ))
    reward_notes = "\n".join((
        "- **T0** — Structural heuristics",
        "- **T1** — Cosine embedding sim",
        "- **T2** — GPT-4o-mini judge",
        "- **T3** — Full judge (checkpoints)",
        "- Episode-level tier lock",
    ))
    safety_notes = "\n".join((
        "- DAG cycle detection (DFS)",
        "- Max delegation depth: 2",
        "- Scratchpad sandbox isolation",
        "- Injection sanitization",
        "- Action masking (DAG)",
    ))

    c1, c2, c3 = st.columns(3)
    with c1:
        sec("Policy")
        st.markdown(policy_notes)
    with c2:
        sec("Tiered Reward")
        st.markdown(reward_notes)
    with c3:
        sec("Safety")
        st.markdown(safety_notes)

    # Display-only pseudo-code: it uses typographic minus/times signs and is
    # never executed.
    reward_formula = "\n".join((
        "total_reward = (",
        "      quality_delta        # specialist_score − baseline (same tier)",
        "    − efficiency_penalty   # 0.05 × max(0, n_called − expected)",
        "    − failure_penalty      # 0.3 per timeout, 0.2 per error",
        "    + recovery_bonus       # +0.1 if fallback succeeded",
        "    − conflict_penalty     # 0.1 per unresolved conflict",
        "    + conflict_bonus       # 0.05 per resolved conflict",
        "    + consistency_bonus    # 0.1 × Dirichlet-prior path score",
        "    − latency_penalty      # latency_weight × overage_fraction",
        "    + explanation_bonus    # 0.05 if delegation is auditable",
        ")",
    ))
    sec("Reward Function")
    st.code(reward_formula, language="python")
# ─────────────────────────────────────────────────────────
# Tab 7 — Output (Trained Policy)
# ─────────────────────────────────────────────────────────
def tab_output():
    """Run the trained LSTM PPO policy on a custom task and show every specialist's output.

    When the run button is pressed the function loads the model through
    ``_load_trained_model(HF_MODEL_REPO)``, plays one episode (at most 15
    steps) in a ``SpindleFlowEnv`` whose task sampler is replaced by the
    user's text, stores the outcome in ``st.session_state.output_results``
    and ``st.session_state.output_env``, and triggers a rerun. On every run
    it then renders whatever result is stored in the session state.
    """
    hero()
    st.markdown(
        '<div style="font-size:12px;color:#64748b;margin-bottom:16px;">'
        'Enter any software engineering task. The trained LSTM PPO policy decides which '
        'specialists to delegate to — each specialist\'s individual output and the collective '
        'synthesis are shown below.</div>',
        unsafe_allow_html=True,
    )

    col_input, col_ctrl = st.columns([3, 1], gap="large")
    with col_input:
        sec("Task")
        task_input = st.text_area(
            "Task description",
            height=110,
            key="output_task_input",
            placeholder=(
                "Build a real-time collaborative code review tool with inline comments, "
                "role-based access control, GitHub webhook integration, and CI/CD pipeline "
                "status display. Include authentication with OAuth2."
            ),
        )
    with col_ctrl:
        sec("Config")
        out_phase = st.selectbox("Curriculum phase", [1, 2, 3], index=1, key="output_phase")
        st.markdown('<div style="height:8px"></div>', unsafe_allow_html=True)
        run_btn = st.button(
            "🚀 Run Trained Policy",
            type="primary",
            use_container_width=True,
            key="output_run_btn",
        )

    if run_btn:
        _task = (task_input or "").strip()
        if not _task:
            st.warning("Please enter a task description.")
            return

        with st.spinner("Loading trained model from HF Hub…"):
            model, obs_mean, obs_var, clip_obs, model_err = _load_trained_model(HF_MODEL_REPO)
        if model_err:
            st.error(f"Model load failed: {model_err}")
            return
        st.success("Trained policy loaded ✓")

        with st.spinner("Running episode with trained policy…"):
            try:
                env = SpindleFlowEnv(
                    config_path=CONFIG, catalog_path=CATALOG,
                    use_real_spindleflow=False, phase=int(out_phase),
                )
                # Inject custom task so the env uses the user's input
                env.task_bank.sample = lambda: _task
                obs, info = env.reset()
                task_used = info.get("task", _task)

                lstm_states = None
                episode_starts = np.array([True])
                done = False
                rewards: list[float] = []
                min_specialists = 4   # no STOP is let through until this many were called
                max_steps = 15

                for _ in range(max_steps):
                    if done:
                        break
                    obs_arr = obs[np.newaxis, :].copy().astype(np.float32)
                    # Apply the saved observation statistics when available.
                    if obs_mean is not None and obs_var is not None:
                        obs_arr = np.clip(
                            (obs_arr - obs_mean) / np.sqrt(obs_var + 1e-8),
                            -clip_obs, clip_obs,
                        )
                    action_batch, lstm_states = model.predict(
                        obs_arr,
                        state=lstm_states,
                        episode_start=episode_starts,
                        deterministic=True,
                    )
                    action = action_batch[0].copy()

                    called_set = set(env.called_ids)
                    if len(called_set) < min_specialists:
                        # While fewer than `min_specialists` have been called,
                        # the predicted action is replaced (whatever the policy
                        # chose) by a fresh one that selects the first uncalled
                        # roster entry with a hard positive logit. Merely
                        # zeroing action[0] would keep the policy's low or
                        # negative selection logits.
                        roster = env.active_specialist_ids
                        uncalled = [sid for sid in roster if sid not in called_set]
                        if uncalled:
                            action = np.zeros(env.action_space.shape, dtype=np.float32)
                            action[0] = 0.0  # MetaAction.CALL_SPECIALIST
                            idx = roster.index(uncalled[0])
                            if 1 + idx < len(action):
                                action[1 + idx] = 1.0

                    obs, r, term, trunc, _ = env.step(action)
                    rewards.append(float(r))
                    done = term or trunc
                    episode_starts = np.array([done])

                called = list(env.called_ids)
                edges = [(e.caller_id, e.callee_id)
                         for e in env.delegation_graph.get_delegation_path()]
                spawned = list(getattr(env, "spawned_this_episode", []))

                st.session_state.output_results = {
                    "task": task_used,
                    "rewards": rewards,
                    "called": called,
                    "edges": edges,
                    "specialist_results": [
                        {
                            "id": sr.specialist_id,
                            "output": sr.output,
                            "status": sr.status,
                            "latency_ms": sr.latency_ms,
                        }
                        for sr in env.specialist_results
                    ],
                    "spawned": spawned,
                    # Remember the phase this run used so later selectbox
                    # changes do not relabel an already stored result.
                    "phase": int(out_phase),
                }
                # Keep env alive for delegation-graph rendering
                st.session_state.output_env = env

                # Persist spawned specialists to shared pool for Specialists tab
                if "spawned_pool" not in st.session_state:
                    st.session_state.spawned_pool = []
                existing_ids = {sp["id"] for sp in st.session_state.spawned_pool}
                for sid in spawned:
                    if sid in existing_ids:
                        continue
                    sp_obj = env.registry.get(sid)
                    if not sp_obj:
                        continue
                    st.session_state.spawned_pool.append({
                        "id": sid,
                        "role": sp_obj.role,
                        "description": sp_obj.description,
                        "complexity_affinity": list(sp_obj.complexity_affinity),
                        "avg_latency_ms": sp_obj.avg_latency_ms,
                        "triggered_by": task_used[:120],
                    })
                    # Guard against the same id appearing twice in `spawned`.
                    existing_ids.add(sid)
            except Exception as exc:
                # UI boundary: surface the failure in the page instead of
                # crashing the whole dashboard.
                import traceback
                st.error(f"Episode failed: {exc}")
                st.code(traceback.format_exc(), language=None)
                return
        st.rerun()

    # ── Display results ────────────────────────────────────────────────
    results = st.session_state.get("output_results")
    env_obj = st.session_state.get("output_env")
    if results is None:
        st.markdown(
            '<div style="color:#334155;font-size:12px;padding:40px;text-align:center;">'
            'Enter a task and click "Run Trained Policy" to see delegation and specialist outputs.'
            '</div>',
            unsafe_allow_html=True,
        )
        return

    # Task banner
    st.markdown(
        f'<div style="background:rgba(0,212,255,0.04);'
        f'border:1px solid rgba(0,212,255,0.18);border-radius:10px;'
        f'padding:14px 18px;margin:10px 0 16px;">'
        f'<div style="font-size:9px;font-weight:700;color:#475569;'
        f'text-transform:uppercase;letter-spacing:1px;margin-bottom:5px;">Task</div>'
        f'<div style="font-size:13px;color:#e2e8f0;">{_html.escape(results["task"])}</div>'
        f'</div>',
        unsafe_allow_html=True,
    )

    # Metrics strip
    total_r = sum(results["rewards"])
    mc1, mc2, mc3, mc4 = st.columns(4)
    mc1.metric("Total Reward", f"{total_r:+.3f}")
    mc2.metric("Steps", len(results["rewards"]))
    mc3.metric("Specialists Called", len(results["called"]))
    mc4.metric("Auto-Spawned", len(results["spawned"]))

    # Orchestrator widget. Prefer the phase stored with the run; results saved
    # without one fall back to the current selectbox value.
    shown_phase = results.get("phase")
    if shown_phase is None:
        shown_phase = int(st.session_state.get("output_phase", 2))
    sec("Orchestrator · Delegation Visualization")
    render_orchestrator({
        "called": results["called"],
        "active": "",
        "edges": results["edges"],
        "task": results["task"],
        "step": len(results["rewards"]),
        "mode": "SEQUENTIAL",
        "done": True,
        "reward": total_r,
        "phase": int(shown_phase),
        "spawned": results["spawned"],
    })

    # Delegation graph
    sec("Delegation Graph")
    if env_obj is not None:
        # Minimal stand-in exposing the attributes handed to
        # fig_delegation_graph alongside the stored env.
        class _GraphProxy:
            registry = env_obj.registry
            spawned_specialists = results["spawned"]
            env = env_obj

        st.plotly_chart(
            fig_delegation_graph(
                _GraphProxy(),
                results["called"],
                results["edges"],
                highlight_latest=False,
                spawned_ids=results["spawned"],
            ),
            use_container_width=True,
            key="output_dag",
        )

    # Auto-spawn alert
    if results["spawned"]:
        # Ids are escaped because they are rendered with unsafe_allow_html.
        spawned_html = ", ".join(_html.escape(str(s)) for s in results["spawned"])
        st.markdown(
            '<div style="background:rgba(251,191,36,0.06);'
            'border:1px solid rgba(251,191,36,0.22);border-radius:10px;'
            'padding:10px 16px;margin:8px 0;">'
            '<span style="font-size:10px;font-weight:700;color:#fbbf24;'
            'text-transform:uppercase;letter-spacing:1px;">⚡ Auto-Spawned: </span>'
            '<span style="font-size:12px;color:#e2e8f0;">'
            + spawned_html
            + '</span></div>',
            unsafe_allow_html=True,
        )

    # Individual specialist outputs
    spec_results = results["specialist_results"]
    sec(f"Individual Specialist Outputs · {len(spec_results)} called")
    if not spec_results:
        st.markdown(
            '<div style="color:#475569;font-size:12px;padding:16px;'
            'background:rgba(0,0,0,0.2);border-radius:8px;">'
            'The policy issued STOP without delegating to any specialists.</div>',
            unsafe_allow_html=True,
        )
    else:
        for sr in spec_results:
            sid = sr["id"]
            color = SPEC_COLORS.get(sid, "#7c3aed")
            ok_clr = "#10b981" if sr["status"] == "success" else "#ef4444"
            # The key always exists, so a plain default would not cover a
            # stored None; `or 0` keeps the ":.0f" format from raising.
            lat = sr.get("latency_ms") or 0
            label = (
                f"🤖 {sid.replace('_', ' ').title()}"
                f" · {sr['status']} · {lat:.0f} ms"
            )
            # Escape values that end up inside raw HTML below.
            safe_sid = _html.escape(str(sid))
            safe_status = _html.escape(f"{sr['status']}")
            with st.expander(label, expanded=True):
                st.markdown(
                    f'<div style="border-left:3px solid {color};'
                    f'padding:4px 0 4px 12px;margin-bottom:8px;">'
                    f'<span style="font-size:10px;color:{color};font-weight:700;">{safe_sid}</span>'
                    f'<span style="font-size:10px;color:#475569;"> · status: </span>'
                    f'<span style="font-size:10px;color:{ok_clr};">{safe_status}</span>'
                    f'<span style="font-size:10px;color:#475569;"> · {lat:.0f} ms</span>'
                    f'</div>',
                    unsafe_allow_html=True,
                )
                st.code(sr["output"] or "(no output)", language=None)

    # Synthesized / collective output
    sec("Synthesized Output · Collective Response")
    st.caption("All specialist outputs combined — this is what the orchestrator received.")
    if spec_results:
        rule = "─" * 52
        parts = [
            f"{rule}\n[{sr['id'].upper()}]\n{rule}\n{sr['output'] or '(empty)'}"
            for sr in spec_results
        ]
        synthesis = "\n\n".join(parts)
    else:
        synthesis = "(no specialists called — policy chose STOP on first step)"
    st.code(synthesis, language=None)
# ─────────────────────────────────────────────────────────
# Entry point
# ─────────────────────────────────────────────────────────
def main():
    """Build the dashboard: inject CSS, render live stats, then one tab per view.

    Tab order is defined by ``tab_specs``; each entry pairs the label shown
    in the tab bar with the function that renders that tab's content.
    """
    inject_css()
    state = _S()
    render_live_stats(state)

    # NOTE(review): the emoji on "Training", "Quality Demo" and "Architecture"
    # were unrecoverable from the mis-encoded labels and are best-effort
    # replacements. Confirm against the intended design.
    tab_specs = [
        ("🎯 Output", tab_output),
        ("⚡ Training Interface Example", tab_live_demo),
        ("🤖 Specialists", tab_specialists),
        ("📈 Training", tab_training),
        ("🔍 Quality Demo", tab_quality),
        ("🧪 Reward Lab", tab_reward_lab),
        ("📐 Architecture", tab_architecture),
    ]
    containers = st.tabs([label for label, _ in tab_specs])
    for container, (_, render_tab) in zip(containers, tab_specs):
        with container:
            render_tab()
# Entry-point guard: allows the module to be imported (e.g. for testing)
# without triggering the UI. Streamlit runs scripts with
# __name__ == "__main__", so `streamlit run` still reaches main().
if __name__ == "__main__":
    main()