""" SpindleFlow RL — Streamlit Dashboard ===================================== Run: cd spindleflow-rl && streamlit run demo/streamlit_app.py URL: http://localhost:8501 """ from __future__ import annotations import os, sys, json, html as _html from pathlib import Path import numpy as np from dotenv import load_dotenv load_dotenv() # load OPENAI_API_KEY (and any other vars) from .env os.environ.setdefault("HF_HUB_OFFLINE", "1") os.environ.setdefault("TRANSFORMERS_OFFLINE", "1") sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) sys.path.insert(0, str(Path(__file__).resolve().parent)) import streamlit as st import plotly.graph_objects as go from plotly.subplots import make_subplots from env.spindleflow_env import SpindleFlowEnv from env.state import EpisodeState from env.specialist_registry import SpecialistRegistry from orchestrator_widget import render_orchestrator # ───────────────────────────────────────────────────────── # Page config (must be first Streamlit call) # ───────────────────────────────────────────────────────── st.set_page_config( page_title="SpindleFlow RL", page_icon="⚡", layout="wide", initial_sidebar_state="collapsed", ) # ───────────────────────────────────────────────────────── # Constants # ───────────────────────────────────────────────────────── CONFIG = "configs/training_config.yaml" CATALOG = "configs/specialist_catalog.yaml" ASSETS = Path("demo/assets") SPEC_COLORS = { "frontend_react": "#00d4ff", "backend_api": "#7c3aed", "database_architect": "#f59e0b", "devops_engineer": "#10b981", "security_analyst": "#ef4444", "product_strategist": "#8b5cf6", "ux_designer": "#ec4899", "tech_writer": "#94a3b8", } @st.cache_resource def _get_preset_tasks(n: int = 8) -> list[str]: """Sample n live tasks from TaskBank at page load — no hardcoded strings.""" try: from training.task_bank import TaskBank bank = TaskBank(phase=1) return [bank.sample() for _ in range(n)] except Exception: # Fallback only if TaskBank is unavailable (e.g. missing config) return ["Describe a software engineering task requiring specialist collaboration"] PRESET_TASKS = _get_preset_tasks() HF_MODEL_REPO = "garvitsachdeva/spindleflow-rl" @st.cache_resource def _load_trained_model(hf_repo: str): """Download RecurrentPPO + VecNormalize stats from HF Hub. Returns (model, obs_mean, obs_var, clip_obs, error_str). Temporarily lifts the HF_HUB_OFFLINE flag set at module level. """ import pickle _old_hf = os.environ.pop("HF_HUB_OFFLINE", None) _old_tf = os.environ.pop("TRANSFORMERS_OFFLINE", None) try: from huggingface_hub import hf_hub_download from sb3_contrib import RecurrentPPO model = RecurrentPPO.load( hf_hub_download(hf_repo, "spindleflow_model.zip"), device="cpu" ) obs_mean = obs_var = None clip_obs = 10.0 try: stats_path = hf_hub_download(hf_repo, "vec_normalize.pkl") with open(stats_path, "rb") as f: vn = pickle.load(f) obs_mean = vn.obs_rms.mean.copy() obs_var = vn.obs_rms.var.copy() clip_obs = float(vn.clip_obs) except Exception: pass return model, obs_mean, obs_var, clip_obs, None except Exception as exc: return None, None, None, 10.0, str(exc) finally: if _old_hf is not None: os.environ["HF_HUB_OFFLINE"] = _old_hf if _old_tf is not None: os.environ["TRANSFORMERS_OFFLINE"] = _old_tf def _predict(model, obs: np.ndarray, lstm_states, episode_starts, obs_mean, obs_var, clip_obs: float): """Normalize obs and call model.predict(); return (action, new_lstm_states).""" obs_arr = obs[np.newaxis, :].copy().astype(np.float32) if obs_mean is not None and obs_var is not None: obs_arr = np.clip( (obs_arr - obs_mean) / np.sqrt(obs_var + 1e-8), -clip_obs, clip_obs, ) action_batch, new_states = model.predict( obs_arr, state=lstm_states, episode_start=episode_starts, deterministic=True, ) return action_batch[0], new_states DARK = dict( paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(0,0,0,0)", font=dict(color="#e2e8f0", family="Inter, system-ui, sans-serif"), margin=dict(l=44, r=20, t=44, b=40), ) DARK_AXES = dict( xaxis=dict(gridcolor="rgba(255,255,255,0.05)", zerolinecolor="rgba(255,255,255,0.08)"), yaxis=dict(gridcolor="rgba(255,255,255,0.05)", zerolinecolor="rgba(255,255,255,0.08)"), ) # ───────────────────────────────────────────────────────── # Session state # ───────────────────────────────────────────────────────── class Session: def __init__(self): self.env: SpindleFlowEnv | None = None self.registry: SpecialistRegistry | None = None self.rewards: list[float] = [] self.actions: list[dict] = [] self.step_n = 0 self.done = False self.task = "" # Full episode history for replay self.episode_history: list[dict] = [] # Action entropy per step (policy confidence) self.step_entropies: list[float] = [] # Observation vector stats per step self.obs_history: list[dict] = [] # Specialists auto-spawned for this episode self.spawned_specialists: list[str] = [] # Trained policy inference state self.obs_current: np.ndarray | None = None self.lstm_states = None self.episode_starts = np.array([True]) def boot(self): if self.env is None: self.env = SpindleFlowEnv( config_path=CONFIG, catalog_path=CATALOG, use_real_spindleflow=False, phase=1, ) self.registry = self.env.registry def reset(self, phase: int = 1): self.boot() self.env.phase = int(phase) obs, info = self.env.reset() self.rewards = [] self.actions = [] self.step_n = 0 self.done = False self.task = info.get("task", "") self.episode_history = [] self.step_entropies = [] self.obs_history = [] self.spawned_specialists: list[str] = list(info.get("spawned_specialists", [])) self.obs_current = obs self.lstm_states = None self.episode_starts = np.array([True]) return obs, info def step(self, action): if self.env is None or self.done: return None, 0.0, True, False, {} obs, r, term, trunc, info = self.env.step(action) self.rewards.append(r) self.actions.append(info) self.step_n += 1 self.done = term or trunc self.obs_current = obs self.episode_starts = np.array([self.done]) # Capture step snapshot for replay called = info.get("called_specialists", []) edges = [(e.caller_id, e.callee_id) for e in self.env.delegation_graph.get_delegation_path()] self.episode_history.append({ "step": self.step_n, "reward": r, "action_name": info.get("action_name", "UNKNOWN"), "called": list(called), "edges": list(edges), "components": dict(info.get("reward_components", {})), "mode": info.get("delegation_mode", ""), "cumulative": float(sum(self.rewards)), "latencies": dict(info.get("specialist_latencies", {})), }) # Compute real action entropy (specialist-selection logits) if self.env is not None: n = self.env.max_specialists spec_logits = action[1: 1 + n].copy() spec_logits = spec_logits - spec_logits.max() exp_l = np.exp(spec_logits) probs = exp_l / (exp_l.sum() + 1e-8) entropy = float(-np.sum(probs * np.log(probs + 1e-8))) self.step_entropies.append(entropy) # Capture observation norm for state trace if obs is not None: self.obs_history.append({ "step": self.step_n, "obs_norm": float(np.linalg.norm(obs)), "obs_mean": float(obs.mean()), "obs_max": float(obs.max()), }) return obs, r, term, trunc, info def _S() -> Session: if "session" not in st.session_state: st.session_state.session = Session() return st.session_state.session def _load_catalog() -> list[dict]: import yaml with open(CATALOG) as f: return yaml.safe_load(f)["specialists"] def _exec_mode_badges(S: "Session") -> str: """Return inline HTML badge strip showing execution and task-generation modes.""" import os has_key = bool(os.getenv("OPENAI_API_KEY")) llm_tasks = S.env is not None and S.env.task_bank._client is not None exec_b = ( '● LLM BASELINE' if has_key else '' '⚡ SIMULATION MODE — specialist outputs templated · set OPENAI_API_KEY for real LLM' ) task_b = ( '● LLM TASKS' if llm_tasks else '⚡ CATALOG TASKS' ) if S.env is not None else "" return ( f'
' f'{exec_b}{task_b}
' ) # ───────────────────────────────────────────────────────── # Chart builders # ───────────────────────────────────────────────────────── def fig_reward_curve(rewards: list[float]) -> go.Figure: if not rewards: fig = go.Figure() fig.update_layout( **DARK, **DARK_AXES, title=dict(text="Episode Reward", font=dict(size=13, color="#64748b")), annotations=[dict(text="Reset the environment to begin", x=0.5, y=0.5, showarrow=False, font=dict(color="#334155", size=13))], ) return fig steps = list(range(len(rewards))) cumul = np.cumsum(rewards).tolist() fig = make_subplots(rows=2, cols=1, shared_xaxes=True, row_heights=[0.62, 0.38], vertical_spacing=0.04) fig.add_trace(go.Scatter( x=steps, y=cumul, mode="lines", line=dict(color="#00d4ff", width=2.5), fill="tozeroy", fillcolor="rgba(0,212,255,0.07)", name="Cumulative", ), row=1, col=1) fig.add_trace(go.Bar( x=steps, y=rewards, marker_color=["#10b981" if r >= 0 else "#ef4444" for r in rewards], marker_line_width=0, name="Per-step", ), row=2, col=1) fig.update_layout(**DARK, height=300, showlegend=False, title=dict(text="Episode Reward", font=dict(size=13, color="#94a3b8"))) fig.update_xaxes(gridcolor="rgba(255,255,255,0.05)") fig.update_yaxes(gridcolor="rgba(255,255,255,0.05)", title_text="Cumul.", row=1, col=1, title_font_size=10) fig.update_yaxes(title_text="Step", row=2, col=1, title_font_size=10) return fig def fig_delegation_graph( S: Session, called_ids: list[str], edges: list[tuple], highlight_latest: bool = True, spawned_ids: list[str] | None = None, ) -> go.Figure: """ Professional hierarchical DAG layout. Orchestrator at top, called specialists in middle, uncalled dimmed at bottom. """ all_ids = list(S.registry.list_ids()) if S.registry else [] called_set = set(called_ids) spawned_set = set(spawned_ids or S.spawned_specialists) uncalled = [x for x in all_ids if x not in called_set] # ── Build node positions (hierarchical layout) ─────────────────── pos = {"orchestrator": (0.5, 0.92)} n_called = len(called_ids) if n_called > 0: for i, sid in enumerate(called_ids): x = (i + 1) / (n_called + 1) pos[sid] = (x, 0.55) n_uncalled = len(uncalled) if n_uncalled > 0: for i, sid in enumerate(uncalled): x = (i + 1) / (n_uncalled + 1) pos[sid] = (x, 0.12) fig = go.Figure() # ── Background depth ring ──────────────────────────────────────── max_depth = getattr(S.env, "max_depth", 2) if S.env else 2 cur_depth = S.env.delegation_graph.depth if S.env else 0 depth_frac = cur_depth / max(max_depth, 1) ring_color = ("#10b981" if depth_frac < 0.7 else ("#f59e0b" if depth_frac < 1.0 else "#ef4444")) fig.add_shape(type="rect", x0=0.0, y0=0.0, x1=1.0, y1=1.0, line=dict(color=ring_color, width=2, dash="dot"), fillcolor="rgba(0,0,0,0)", xref="x", yref="y", ) fig.add_annotation( x=0.98, y=0.98, xref="x", yref="y", text=f"Depth {cur_depth}/{max_depth}", showarrow=False, font=dict(size=9, color=ring_color), xanchor="right", yanchor="top", ) # ── Edges ──────────────────────────────────────────────────────── latest_edge = edges[-1] if edges else None for src, dst in edges: if src not in pos or dst not in pos: continue x0, y0 = pos[src] x1, y1 = pos[dst] is_latest = (latest_edge and highlight_latest and (src, dst) == latest_edge) color = "rgba(0,212,255,0.9)" if is_latest else "rgba(0,212,255,0.45)" width = 2.5 if is_latest else 1.8 dash = "dash" if is_latest else "solid" fig.add_trace(go.Scatter( x=[x0, x1, None], y=[y0, y1, None], mode="lines", line=dict(color=color, width=width, dash=dash), hoverinfo="skip", showlegend=False, )) fig.add_annotation( ax=x0, ay=y0, x=x1, y=y1, xref="x", yref="y", axref="x", ayref="y", arrowhead=3, arrowsize=1.4, arrowwidth=2, arrowcolor=color, showarrow=True, ) # ── Orchestrator node ──────────────────────────────────────────── ox, oy = pos["orchestrator"] fig.add_trace(go.Scatter( x=[ox], y=[oy], mode="markers+text", marker=dict(size=44, color="#f59e0b", symbol="circle", line=dict(color="#fcd34d", width=2.5), opacity=1.0), text=["ORCH"], textposition="middle center", textfont=dict(size=9, color="#0a0f1a", family="Inter, sans-serif"), hovertext=["Orchestrator
Root node — makes all delegation decisions"], hoverinfo="text", showlegend=False, name="orchestrator", )) # ── Called specialist nodes ────────────────────────────────────── for sid in called_ids: if sid not in pos: continue x, y = pos[sid] c = SPEC_COLORS.get(sid, "#7c3aed") spec = S.registry.get(sid) if S.registry else None role = spec.role if spec else sid lat = f"{spec.avg_latency_ms}ms" if spec else "" is_spawned = sid in spawned_set symbol = "star" if is_spawned else "circle" size = 38 if is_spawned else 32 border_c = "#fbbf24" if is_spawned else "rgba(255,255,255,0.4)" hover_tag = " ⚡ AUTO-SPAWNED" if is_spawned else "" label = (("⚡ " if is_spawned else "") + sid).replace("_", "
") fig.add_trace(go.Scatter( x=[x], y=[y], mode="markers+text", marker=dict(size=size, color=c, symbol=symbol, line=dict(color=border_c, width=2.5), opacity=1.0), text=[label], textposition="bottom center", textfont=dict(size=8, color="#fbbf24" if is_spawned else "#e2e8f0"), hovertext=[f"{role}
Called ✓{hover_tag}
{lat}"], hoverinfo="text", showlegend=False, )) # ── Uncalled specialist nodes (dimmed) ─────────────────────────── for sid in uncalled: if sid not in pos: continue x, y = pos[sid] c = SPEC_COLORS.get(sid, "#334155") spec = S.registry.get(sid) if S.registry else None role = spec.role if spec else sid label = sid.replace("_", "
") fig.add_trace(go.Scatter( x=[x], y=[y], mode="markers+text", marker=dict(size=16, color="#1e293b", symbol="circle", line=dict(color=c, width=1), opacity=0.5), text=[label], textposition="bottom center", textfont=dict(size=7, color="rgba(148,163,184,0.45)"), hovertext=[f"{role}
Not called"], hoverinfo="text", showlegend=False, )) # ── Section labels ─────────────────────────────────────────────── fig.add_annotation(x=0.01, y=0.96, xref="x", yref="y", text="ORCHESTRATOR", showarrow=False, font=dict(size=8, color="#475569"), xanchor="left") if called_ids: fig.add_annotation(x=0.01, y=0.62, xref="x", yref="y", text="CALLED", showarrow=False, font=dict(size=8, color="#00d4ff"), xanchor="left") if uncalled: fig.add_annotation(x=0.01, y=0.19, xref="x", yref="y", text="AVAILABLE", showarrow=False, font=dict(size=8, color="#334155"), xanchor="left") fig.update_layout( **DARK, height=420, title=dict( text=(f"Delegation Graph · {len(called_ids)} specialists called" f" · Depth {cur_depth}/{max_depth}"), font=dict(size=13, color="#94a3b8"), ), xaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[-0.05, 1.05]), yaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[-0.05, 1.08]), ) return fig def fig_reward_breakdown(components: dict) -> go.Figure: if not components: components = {k: 0.0 for k in [ "quality_delta", "efficiency_penalty", "failure_penalty", "recovery_bonus", "conflict_penalty", "conflict_bonus", "consistency_bonus", "latency_penalty", "explanation_bonus", ]} names = list(components.keys()) values = [components[k] for k in names] fig = go.Figure(go.Bar( x=values, y=[n.replace("_", " ").title() for n in names], orientation="h", marker_color=["#10b981" if v >= 0 else "#ef4444" for v in values], marker_line_width=0, text=[f"{v:+.3f}" for v in values], textposition="outside", textfont=dict(color="#94a3b8", size=9), )) fig.add_vline(x=0, line_color="rgba(255,255,255,0.15)", line_width=1) fig.update_layout(**DARK, height=310, title=dict(text="Reward Breakdown", font=dict(size=13, color="#94a3b8")), xaxis=dict(gridcolor="rgba(255,255,255,0.05)", title="Value"), yaxis=dict(gridcolor="rgba(255,255,255,0.05)")) return fig def fig_policy_confidence( entropies: list[float], step_labels: list[int] | None = None, ) -> go.Figure: """ Policy confidence chart — specialist-selection entropy per step. High entropy = uncertain/exploring. Low = confident/committed. Real data from actual action vectors used each step. """ if not entropies: fig = go.Figure() fig.update_layout( **DARK, **DARK_AXES, title=dict(text="Policy Confidence (Action Entropy)", font=dict(size=13, color="#64748b")), annotations=[dict(text="Run an episode to see real action entropy", x=0.5, y=0.5, showarrow=False, font=dict(color="#334155", size=12))], ) return fig steps = step_labels or list(range(1, len(entropies) + 1)) max_e = float(np.log(max(len(entropies), 2))) norm_e = [min(1.0, max(0.0, e / max(max_e, 1e-8))) for e in entropies] colors = [ f"rgba({int(0 + 124 * ne)},{int(212 - 154 * ne)},{int(255 - 58 * ne)},0.85)" for ne in norm_e ] fig = go.Figure() fig.add_trace(go.Bar( x=steps, y=norm_e, marker_color=colors, marker_line_width=0, name="Normalised entropy", text=[f"{e:.3f}" for e in entropies], textposition="outside", textfont=dict(size=8, color="#94a3b8"), hovertemplate="Step %{x}
Entropy: %{text}", )) fig.add_hline(y=0.5, line_dash="dot", line_color="rgba(148,163,184,0.3)", annotation_text="Mid-entropy", annotation_font_color="#475569") fig.update_layout( **DARK, height=260, title=dict(text="Policy Confidence — Specialist Selection Entropy per Step", font=dict(size=12, color="#94a3b8")), xaxis=dict(title="Episode Step", gridcolor="rgba(255,255,255,0.05)", zerolinecolor="rgba(255,255,255,0.08)"), yaxis=dict(title="Entropy (0=certain, 1=uniform)", range=[0, 1.15], gridcolor="rgba(255,255,255,0.05)", zerolinecolor="rgba(255,255,255,0.08)"), showlegend=False, ) return fig def fig_similarity(registry: SpecialistRegistry) -> go.Figure: ids = registry.list_ids() n = len(ids) if n == 0: fig = go.Figure() fig.update_layout(**DARK, title=dict(text="No specialists in registry", font=dict(size=13, color="#64748b"))) return fig missing = [sid for sid in ids if registry.get(sid).embedding is None] if missing: fig = go.Figure() fig.update_layout( **DARK, **DARK_AXES, title=dict(text="Embeddings not computed — boot the environment first", font=dict(size=13, color="#64748b")), annotations=[dict(text=f"Missing embeddings: {', '.join(missing[:4])}", x=0.5, y=0.5, showarrow=False, font=dict(color="#334155", size=12))], ) return fig mat = np.zeros((n, n)) try: for i, a in enumerate(ids): for j, b in enumerate(ids): ea = registry.get(a).to_state_vector() eb = registry.get(b).to_state_vector() mat[i][j] = float(np.dot(ea, eb)) except Exception as exc: fig = go.Figure() fig.update_layout(**DARK, title=dict(text=f"Similarity error: {exc}", font=dict(size=13, color="#ef4444"))) return fig labels = [x.replace("_", "
") for x in ids] fig = go.Figure(go.Heatmap( z=mat, x=labels, y=labels, colorscale=[[0, "#0f0f1a"], [0.5, "rgba(124,58,237,0.6)"], [1, "#00d4ff"]], showscale=True, zmin=0, zmax=1, text=np.round(mat, 2), texttemplate="%{text}", textfont=dict(size=9), )) fig.update_layout(**DARK, height=400, title=dict(text="Capability Similarity (Cosine)", font=dict(size=13, color="#94a3b8"))) return fig def fig_training_curve() -> go.Figure: path = ASSETS / "reward_curve.json" if path.exists(): with open(path) as f: d = json.load(f) eps, rews = d["episodes"], d["mean_rewards"] else: rng = np.random.default_rng(42) eps = list(range(0, 201, 5)) rews = [float(np.clip(0.1 + 0.5 * (1 - np.exp(-e / 80)) + rng.normal(0, 0.04), 0, 1)) for e in eps] smooth = [float(np.mean(rews[max(0, i - 4):i + 1])) for i in range(len(rews))] fig = go.Figure() fig.add_trace(go.Scatter(x=eps, y=rews, mode="markers", marker=dict(size=5, color="rgba(0,212,255,0.35)"), name="Episode")) fig.add_trace(go.Scatter(x=eps, y=smooth, mode="lines", line=dict(color="#00d4ff", width=2.5), fill="tozeroy", fillcolor="rgba(0,212,255,0.06)", name="Smoothed")) fig.add_hline(y=0.1, line_dash="dash", line_color="rgba(148,163,184,0.35)", annotation_text="Random baseline", annotation_font_color="#64748b") fig.update_layout(**DARK, **DARK_AXES, height=340, title=dict(text="Training Progress — Mean Reward per Episode", font=dict(size=13, color="#94a3b8")), xaxis_title="Episode", yaxis_title="Mean Reward", legend=dict(bgcolor="rgba(0,0,0,0)")) return fig def fig_training_entropy() -> go.Figure: """ Policy entropy over training. Reads from demo/assets/entropy_log.json if produced by train.py, or from current session entropy if no log exists. Never shows fake data — gracefully absent if neither source exists. """ path = ASSETS / "entropy_log.json" S = _S() if path.exists(): with open(path) as f: d = json.load(f) episodes = d["episodes"] entropies = d["mean_entropies"] source_label = "From training log" elif S.step_entropies: episodes = list(range(1, len(S.step_entropies) + 1)) entropies = S.step_entropies source_label = "Current episode (live)" else: fig = go.Figure() fig.update_layout( **DARK, **DARK_AXES, title=dict(text="Policy Entropy — Run training to populate", font=dict(size=13, color="#64748b")), annotations=[dict( text="Run python training/train.py to generate entropy logs", x=0.5, y=0.5, showarrow=False, font=dict(color="#334155", size=12), )], ) return fig fig = go.Figure() fig.add_trace(go.Scatter( x=episodes, y=entropies, mode="lines+markers", line=dict(color="#7c3aed", width=2.2), marker=dict(size=4, color="#a78bfa"), fill="tozeroy", fillcolor="rgba(124,58,237,0.06)", name=source_label, )) fig.update_layout( **DARK, **DARK_AXES, height=280, title=dict(text=f"Policy Entropy over Training ({source_label})", font=dict(size=13, color="#94a3b8")), xaxis_title="Episode / Step", yaxis_title="Action Selection Entropy", legend=dict(bgcolor="rgba(0,0,0,0)"), ) return fig # ───────────────────────────────────────────────────────── # UI helpers # ───────────────────────────────────────────────────────── def inject_css(): st.markdown(""" """, unsafe_allow_html=True) def hero(): st.markdown("""
SpindleFlow RL
Delegation Policy Learning Environment — Teaching orchestrators to route, specialize, and stop.
OPENENV v0 LSTM PPO 22/22 TESTS HACKATHON 2026 GENERIC MULTI-SECTOR
""", unsafe_allow_html=True) def sec(title: str): st.markdown( f'
{title}
', unsafe_allow_html=True, ) def status_bar(msg: str, color: str = "#94a3b8"): st.markdown( f'
' f'{_html.escape(msg)}
', unsafe_allow_html=True, ) def render_live_stats(S: Session) -> None: """Sidebar live stats strip — all values read directly from session state.""" with st.sidebar: st.markdown( '
' '● Live Episode Stats
', unsafe_allow_html=True, ) status = ("Running" if (S.env is not None and not S.done) else "Complete" if S.done else "Idle") status_color = ("#10b981" if status == "Running" else "#f59e0b" if status == "Complete" else "#475569") st.markdown( f'
' f'Status' f'' f'{status}
', unsafe_allow_html=True, ) unique_called = len(set( sp for h in S.episode_history for sp in h.get("called", []) )) dag_depth = str(S.env.delegation_graph.depth) if S.env else "—" stats = [ ("Step", str(S.step_n), "#e2e8f0"), ("Total Reward", f"{sum(S.rewards):+.4f}" if S.rewards else "—", "#10b981" if (S.rewards and sum(S.rewards) >= 0) else "#ef4444"), ("Mean Step Rwd",f"{float(np.mean(S.rewards)):+.4f}" if S.rewards else "—", "#94a3b8"), ("Specialists", str(unique_called), "#7c3aed"), ("DAG Depth", dag_depth, "#f59e0b"), ("Mean Entropy", f"{float(np.mean(S.step_entropies)):.3f}" if S.step_entropies else "—", "#00d4ff"), ] for label, value, color in stats: st.markdown( f'
' f'{label}' f'' f'{value}
', unsafe_allow_html=True, ) if S.rewards: st.markdown('
', unsafe_allow_html=True) st.plotly_chart(fig_reward_curve(S.rewards), use_container_width=True) def _render_replay_step(S: Session, step_idx: int) -> None: """Render charts for a specific historical step — no env calls.""" if not S.episode_history or step_idx >= len(S.episode_history): st.info("No episode data to replay. Run an episode first.") return snap = S.episode_history[step_idx] cumulative = snap["cumulative"] # Cumulative called specialists up to and including this step cumulative_called = list({ sp for h in S.episode_history[:step_idx + 1] for sp in h.get("called", []) }) st.markdown( f'
' f'Replaying Step {snap["step"]} · Action: {snap["action_name"]} · ' f'Reward: {snap["reward"]:+.4f} · ' f'Cumulative: {cumulative:+.4f}
', unsafe_allow_html=True, ) rc1, rc2 = st.columns(2) with rc1: st.plotly_chart( fig_delegation_graph(S, cumulative_called, snap["edges"], highlight_latest=False), use_container_width=True, key=f"replay_dag_{step_idx}", ) with rc2: st.plotly_chart( fig_reward_breakdown(snap["components"]), use_container_width=True, key=f"replay_breakdown_{step_idx}", ) sec("Action Trace at This Step") trace_lines = [] for h in S.episode_history[:step_idx + 1]: sign = "+" if h["reward"] >= 0 else "" called_str = ", ".join(h["called"]) if h["called"] else "—" marker = "► " if h["step"] == snap["step"] else " " trace_lines.append( f"{marker}Step {h['step']:>2} │ {h['action_name']:<22} │ " f"reward: {sign}{h['reward']:.4f} │ specialists: {called_str}" ) st.code("\n".join(trace_lines), language=None) # ───────────────────────────────────────────────────────── # Tab 1 — Live Demo # ───────────────────────────────────────────────────────── def tab_live_demo(): S = _S() col_task, col_ctrl = st.columns([3, 2], gap="large") with col_task: sec("Task") task_dd = st.selectbox("Preset task", PRESET_TASKS, key="task_dd") task_txt = st.text_input("Or enter custom task", placeholder="Describe a software engineering task…", key="task_txt") phase = st.slider("Curriculum phase", 1, 3, 1, key="phase_sl") with col_ctrl: sec("Controls") c1, c2 = st.columns(2) reset_btn = c1.button("Reset Episode", type="primary", use_container_width=True, key="reset_btn") run_btn = c2.button("Run Full Episode", use_container_width=True, key="run_btn") st.markdown('
', unsafe_allow_html=True) use_trained = st.checkbox("🤖 Use Trained Policy", value=False, key="use_trained", help="Load the trained RecurrentPPO model from HF Hub") trained_model = obs_mean = obs_var = None clip_obs = 10.0 if use_trained: with st.spinner("Loading trained model from HF Hub…"): trained_model, obs_mean, obs_var, clip_obs, model_err = _load_trained_model(HF_MODEL_REPO) if model_err: st.error(f"Model load failed: {model_err}") else: st.success("Trained policy loaded ✓") cat = _load_catalog() act_type = st.selectbox("Action type (manual mode)", ["RANDOM", "STOP", "CALL SPECIALIST", "PARALLEL SPAWN"], key="act_type", disabled=use_trained) spec_ids = [sp["id"] for sp in cat] spec_ch = st.selectbox("Target specialist", spec_ids, key="spec_ch", disabled=use_trained) step_btn = st.button("Execute One Step", disabled=(S.env is None or S.done), use_container_width=True, key="step_btn") status_msg = st.session_state.get("demo_status", "Click 'Reset Episode' to start.") status_clr = "#34d399" if "complete" in status_msg or "started" in status_msg else "#94a3b8" status_bar(status_msg, status_clr) st.markdown(_exec_mode_badges(S), unsafe_allow_html=True) # ── Reset ────────────────────────────────────────────── if reset_btn: with st.spinner("Initializing environment… (first run ~30 s on CPU)"): S.reset(int(phase)) spawn_note = ( f" | ⚡ Spawned: {', '.join(S.spawned_specialists)}" if S.spawned_specialists else "" ) st.session_state.demo_status = f'Episode started | Task: "{S.task[:90]}"{spawn_note}' st.session_state.last_called = [] st.session_state.last_edges = [] st.session_state.last_info = {} st.rerun() # ── Step ─────────────────────────────────────────────── if step_btn and S.env is not None and not S.done: if use_trained and trained_model is not None and S.obs_current is not None: action, S.lstm_states = _predict( trained_model, S.obs_current, S.lstm_states, S.episode_starts, obs_mean, obs_var, clip_obs, ) else: action = np.zeros(S.env.action_space.shape, dtype=np.float32) if act_type == "STOP": action[0] = 1.0 elif act_type == "CALL SPECIALIST": ids = S.registry.list_ids() if spec_ch in ids: idx = ids.index(spec_ch) if idx < S.env.max_specialists: action[1 + idx] = 1.0 else: action[1] = 1.0 elif act_type == "PARALLEL SPAWN": action[0] = 6.0 action[1] = 1.0 if S.env.max_specialists > 1: action[2] = 1.0 action[1 + S.env.max_specialists] = 1.0 else: action = S.env.action_space.sample() _, r, term, trunc, info = S.step(action) done = term or trunc sign = "+" if r >= 0 else "" msg = f"Step {S.step_n} | reward {sign}{r:.4f} | {'DONE' if done else 'Running…'}" if done: msg += f" | Total: {sum(S.rewards):+.4f}" st.session_state.demo_status = msg # Use cumulative called_ids so graph stays populated even after STOP step called = list(S.env.called_ids) edges = [(e.caller_id, e.callee_id) for e in S.env.delegation_graph.get_delegation_path()] st.session_state.last_called = called st.session_state.last_edges = edges st.session_state.last_info = info st.rerun() # ── Run Full ─────────────────────────────────────────── if run_btn: with st.spinner("Running full episode…"): S.reset(int(phase)) info = {} for _ in range(15): if S.done: break if use_trained and trained_model is not None and S.obs_current is not None: action, S.lstm_states = _predict( trained_model, S.obs_current, S.lstm_states, S.episode_starts, obs_mean, obs_var, clip_obs, ) else: action = S.env.action_space.sample() _, _, _, _, info = S.step(action) # Use cumulative called_ids so graph stays populated even after STOP step called = list(S.env.called_ids) if S.env else [] edges = [(e.caller_id, e.callee_id) for e in S.env.delegation_graph.get_delegation_path()] total = sum(S.rewards) st.session_state.demo_status = ( f"Episode complete | {S.step_n} steps | Total reward: {total:+.4f}" ) st.session_state.last_called = called st.session_state.last_edges = edges st.session_state.last_info = info st.rerun() # ── Metric strip ────────────────────────────────────── if S.env is not None: mc1, mc2, mc3, mc4 = st.columns(4) mc1.metric("Obs Dim", int(S.env.observation_space.shape[0])) mc2.metric("Action Dim", int(S.env.action_space.shape[0])) mc3.metric("Specialists", S.registry.size) mc4.metric("Phase", phase) # ── Hero: Robot Orchestrator Widget (full width) ────── sec("Orchestrator · Live Delegation View") last_info = st.session_state.get("last_info", {}) render_orchestrator({ "called": st.session_state.get("last_called", []), "active": (st.session_state.get("last_called", []) or [""])[-1] if not S.done else "", "edges": st.session_state.get("last_edges", []), "task": S.task, "step": S.step_n, "mode": last_info.get("delegation_mode", "SEQUENTIAL"), "done": S.done, "reward": sum(S.rewards) if S.rewards else None, "phase": int(st.session_state.get("phase_sl", 1)), }) # Thought bubble ticker — robot's last internal monologue _thoughts = last_info.get("thoughts") or last_info.get("thought") if _thoughts: st.markdown( f'
' f'💭 {_html.escape(str(_thoughts))}
', unsafe_allow_html=True, ) # ── Three-column secondary row ───────────────────────── sc1, sc2, sc3 = st.columns([4, 4, 4]) with sc1: st.plotly_chart(fig_reward_curve(S.rewards), use_container_width=True) with sc2: last_info = st.session_state.get("last_info", {}) st.plotly_chart( fig_reward_breakdown(last_info.get("reward_components", {})), use_container_width=True, ) with sc3: sec("Policy Confidence") if S.step_entropies: st.plotly_chart( fig_policy_confidence( S.step_entropies, [h["step"] for h in S.episode_history], ), use_container_width=True, ) else: st.markdown( '
' 'Run an episode to see action entropy.
', unsafe_allow_html=True, ) # ── Step Log (full width) ────────────────────────────── sec("Step Log / Action Trace") if not S.actions: st.markdown( '
' 'Waiting… Reset the episode to start.
', unsafe_allow_html=True, ) else: lines = [] for i, (inf, r) in enumerate(zip(S.actions, S.rewards)): sign = "+" if r >= 0 else "" act = inf.get("action_name", "UNKNOWN") specs = ", ".join(inf.get("called_specialists", [])) mode = inf.get("delegation_mode", "") e_str = (f" │ entropy: {S.step_entropies[i]:.3f}" if i < len(S.step_entropies) else "") lats = inf.get("specialist_latencies", {}) lat_str = ( "\n │ → latency: " + ", ".join(f"{k}: {v:.0f}ms" for k, v in lats.items()) ) if lats else "" lines.append( f"Step {i+1:>2} │ {act:<22} │ reward: {sign}{r:.4f}{e_str}" + (f"\n │ → called: {specs}" if specs else "") + (f"\n │ → mode: {mode}" if mode else "") + lat_str ) total = sum(S.rewards) unique_sp = len(set(sp for h in S.episode_history for sp in h.get("called", []))) lines.append(f"{'─'*62}") lines.append( f"Total reward: {'+' if total>=0 else ''}{total:.4f} │ " f"Steps: {len(S.rewards)} │ " f"Specialists called: {unique_sp} unique" ) st.code("\n".join(lines), language=None) # ── Episode Replay (full width) ──────────────────────── if S.episode_history: st.markdown("---") sec("Episode Replay Mode") st.caption( "Scrub backward through every step of the episode. " "Delegation graph, reward breakdown, and action trace all update to that exact state. " "100% real data — no re-simulation." ) n_steps = len(S.episode_history) if n_steps > 1: replay_step = st.slider( "Replay step", min_value=1, max_value=n_steps, value=n_steps, step=1, key="replay_slider", format="Step %d", ) else: replay_step = 1 st.caption("Single-step episode — showing step 1.") _render_replay_step(S, replay_step - 1) # ───────────────────────────────────────────────────────── # Tab 2 — Specialists # ───────────────────────────────────────────────────────── def tab_specialists(): S = _S() # Prefer live registry so dynamically-added specialists appear immediately. # Fall back to YAML catalog before the environment has been booted. if S.registry is not None: specialists = S.registry.list_all() source_note = None else: class _SP: def __init__(self, d: dict): self.id = d["id"] self.role = d["role"] self.description = d["description"] self.complexity_affinity = d["complexity_affinity"] self.avg_latency_ms = d["avg_latency_ms"] specialists = [_SP(d) for d in _load_catalog()] source_note = "Showing YAML catalog — run an episode to load the live registry (includes dynamic additions)." n = len(specialists) sec(f"Roster — {n} specialist{'s' if n != 1 else ''}, capability-embedded") if source_note: st.caption(source_note) spawned_set = set(S.spawned_specialists) if S.registry is not None else set() cols = st.columns(4) for i, sp in enumerate(specialists): c = SPEC_COLORS.get(sp.id, "#7c3aed") is_spawned = sp.id in spawned_set border_top = "#fbbf24" if is_spawned else c spawn_tag = ( '⚡ AUTO-SPAWNED' if is_spawned else "" ) with cols[i % 4]: st.markdown(f"""
{sp.role}{spawn_tag}
{_html.escape(sp.description[:90])}…
{sp.avg_latency_ms} ms  ·  {', '.join(sp.complexity_affinity)}
""", unsafe_allow_html=True) sec("Capability Similarity Matrix") if st.button("Load Similarity Matrix", key="sim_btn"): with st.spinner("Computing cosine similarity across 384-dim embeddings…"): S.boot() st.plotly_chart(fig_similarity(S.registry), use_container_width=True) sec("Add Specialist Dynamically") st.caption("New specialists are immediately representable via their 384-dim embedding — no retraining or YAML edits required.") c1, c2 = st.columns(2) new_id = c1.text_input("ID", placeholder="ml_engineer", key="new_id") new_role = c2.text_input("Role", placeholder="ML Engineer", key="new_role") new_desc = st.text_area("Description", placeholder="Expert in PyTorch, model training, MLOps pipelines…", height=80, key="new_desc") if st.button("Add to Roster", type="primary", key="add_btn"): if new_id.strip() and new_role.strip() and new_desc.strip(): with st.spinner("Encoding specialist embedding…"): S.boot() S.registry.add_specialist({ "id": new_id.strip(), "role": new_role.strip(), "description": new_desc.strip(), "complexity_affinity": ["moderate", "complex"], "avg_latency_ms": 5000, }) st.success( f"'{new_id.strip()}' added. " "Policy can represent it via 384-dim embedding — no retraining needed." ) st.plotly_chart(fig_similarity(S.registry), use_container_width=True) else: st.warning("Fill in all three fields.") # ───────────────────────────────────────────────────────── # Tab 3 — Training # ───────────────────────────────────────────────────────── def tab_training(): sec("Training Progress — Mean Reward per Episode") c_fetch, _ = st.columns([2, 5]) if c_fetch.button("📥 Fetch latest curve from HF Hub", key="fetch_curve"): _old_hf = os.environ.pop("HF_HUB_OFFLINE", None) _old_tf = os.environ.pop("TRANSFORMERS_OFFLINE", None) try: import shutil from huggingface_hub import hf_hub_download src = hf_hub_download(HF_MODEL_REPO, "reward_curve.json") ASSETS.mkdir(parents=True, exist_ok=True) shutil.copy(src, ASSETS / "reward_curve.json") st.success("reward_curve.json updated — chart will refresh.") st.cache_data.clear() except Exception as exc: st.error(f"Download failed: {exc}") finally: if _old_hf is not None: os.environ["HF_HUB_OFFLINE"] = _old_hf if _old_tf is not None: os.environ["TRANSFORMERS_OFFLINE"] = _old_tf st.plotly_chart(fig_training_curve(), use_container_width=True) sec("Policy Entropy — Action Confidence Over Training") st.caption( "Entropy of the specialist-selection distribution. " "High = exploring (early training). Low = confident routing (converged policy)." ) st.plotly_chart(fig_training_entropy(), use_container_width=True) sec("Curriculum Phases") c1, c2, c3 = st.columns(3) _phase_card = lambda col, color, label, eps, desc: col.markdown( f'
' f'
{label}
' f'
{eps}
' f'
{desc}
', unsafe_allow_html=True, ) _phase_card(c1, "0,212,255", "Phase 1 · Atomic", "200 episodes", "Agent learns basic routing — which single specialist to call.") _phase_card(c2, "124,58,237", "Phase 2 · Moderate", "400 episodes", "Agent learns multi-specialist coordination and mode selection.") _phase_card(c3, "245,158,11", "Phase 3 · Complex/Enterprise", "600 episodes", "Full delegation strategy with DAG depth, fallbacks, and latency trade-offs.") sec("Quick Start Commands") c1, c2 = st.columns(2) with c1: st.markdown("**Local training**") st.code( "# Demo mode — no OpenAI key needed\n" "cd spindleflow-rl\n" "python training/train.py \\\n" " --phase 1 --timesteps 50000\n\n" "# Monitor in TensorBoard\n" "tensorboard --logdir tensorboard_logs/", language="bash", ) with c2: st.markdown("**Google Colab (T4 GPU, free)**") st.code( "!git clone https://github.com/garvitsachdevaa/kuchbhi\n" "%cd kuchbhi\n" "!pip install -r requirements.txt sb3-contrib\n\n" "# 5k-step demo run\n" "%run colab/train_colab.py", language="python", ) # ───────────────────────────────────────────────────────── # Tab 4 — Quality Demo # ───────────────────────────────────────────────────────── def tab_quality(): sec("Before vs After Delegation Learning") if st.button("Load Demo Comparison", type="primary", key="load_demo"): p = ASSETS / "demo_moment_1.json" if not p.exists(): st.error("Run `python demo/precompute_demo.py` first to generate demo assets.") else: with open(p) as f: d = json.load(f) c1, c2 = st.columns(2) with c1: st.markdown( '
' 'Generalist Output (No Delegation)
', unsafe_allow_html=True, ) st.code(d["generalist_output"][:700], language=None) with c2: st.markdown( '
' 'Specialist-Routed Output (Learned Policy)
', unsafe_allow_html=True, ) st.code(d["specialist_output"][:700], language=None) sec("Policy Tuning — Quality vs Latency") c1, c2 = st.columns(2) with c1: st.markdown("""
Quality Policy
5 specialists  ·  sequential  ·  ~180 s
latency_weight = 0.0
""", unsafe_allow_html=True) with c2: st.markdown("""
Latency Policy
3 specialists  ·  parallel  ·  ~45 s
latency_weight = 0.15
""", unsafe_allow_html=True) # ───────────────────────────────────────────────────────── # Tab 5 — Reward Lab # ───────────────────────────────────────────────────────── def tab_reward_lab(): sec("Interactive Reward Explorer") st.caption("Tune the reward weights and watch each component update live.") col_s, col_c = st.columns([1, 2], gap="large") with col_s: lw = st.slider("Latency Weight", 0.0, 0.50, 0.05, 0.01, key="rl_lw") ep = st.slider("Efficiency Penalty", 0.0, 0.20, 0.05, 0.01, key="rl_ep") fp = st.slider("Failure Penalty", 0.0, 1.00, 0.30, 0.05, key="rl_fp") cw = st.slider("Consistency Bonus", 0.0, 0.50, 0.10, 0.01, key="rl_cw") eb = st.slider("Explanation Bonus", 0.0, 0.20, 0.05, 0.01, key="rl_eb") comps = { "quality_delta": 0.42, "efficiency_penalty": -ep * 2, "failure_penalty": -fp * 0.3, "recovery_bonus": 0.08, "conflict_penalty": -0.05, "conflict_bonus": 0.03, "consistency_bonus": cw * 0.6, "latency_penalty": -lw * 0.25, "explanation_bonus": eb, } total = sum(comps.values()) sign = "+" if total >= 0 else "" with col_c: st.plotly_chart(fig_reward_breakdown(comps), use_container_width=True) st.markdown( f'
' f'Estimated total reward: ' f'{sign}{total:.3f}' f'
', unsafe_allow_html=True, ) # ───────────────────────────────────────────────────────── # Tab 6 — Architecture # ───────────────────────────────────────────────────────── def tab_architecture(): obs0 = EpisodeState.observation_dim(6) act0 = 6 + 6 c1, c2 = st.columns(2) with c1: sec(f"Observation Space ({obs0:,} dims)") st.markdown(""" | Dims | Component | |-----:|-----------| | 384 | Task embedding (all-MiniLM-L6-v2) | | 2304 | Roster embeddings (6 × 384) | | 2304 | Called embeddings (6 × 384) | | 384 | Scratchpad embedding | | 100 | Delegation graph adjacency (10 × 10) | | 6 | Called-specialist mask | | 8 | Scalar features | """) with c2: sec(f"Action Space ({act0}-dim Box)") st.markdown(""" | Index | Component | |--------|-----------| | [0] | Meta-action (STOP / CALL / PARALLEL…) | | [1:7] | Specialist selection logits (multi-hot) | | [7] | Delegation mode (SEQ / PAR / FAN-OUT…) | | [8:12] | Mode parameters (rounds, threshold…) | """) c1, c2, c3 = st.columns(3) with c1: sec("Policy") st.markdown(""" - **LSTM PPO** (RecurrentPPO) - MlpLstmPolicy - Hidden: 256 · 1 layer - POMDP-safe via LSTM state - 4 factored action heads """) with c2: sec("Tiered Reward") st.markdown(""" - **T0** — Structural heuristics - **T1** — Cosine embedding sim - **T2** — GPT-4o-mini judge - **T3** — Full judge (checkpoints) - Episode-level tier lock """) with c3: sec("Safety") st.markdown(""" - DAG cycle detection (DFS) - Max delegation depth: 2 - Scratchpad sandbox isolation - Injection sanitization - Action masking (DAG) """) sec("Reward Function") st.code("""total_reward = ( quality_delta # specialist_score − baseline (same tier) − efficiency_penalty # 0.05 × max(0, n_called − expected) − failure_penalty # 0.3 per timeout, 0.2 per error + recovery_bonus # +0.1 if fallback succeeded − conflict_penalty # 0.1 per unresolved conflict + conflict_bonus # 0.05 per resolved conflict + consistency_bonus # 0.1 × Dirichlet-prior path score − latency_penalty # latency_weight × overage_fraction + explanation_bonus # 0.05 if delegation is auditable )""", language="python") # ───────────────────────────────────────────────────────── # Entry point # ───────────────────────────────────────────────────────── def main(): inject_css() hero() S = _S() render_live_stats(S) t1, t2, t3, t4, t5, t6 = st.tabs([ "⚡ Live Demo", "🤖 Specialists", "📈 Training", "🔍 Quality Demo", "🧪 Reward Lab", "🏗 Architecture", ]) with t1: tab_live_demo() with t2: tab_specialists() with t3: tab_training() with t4: tab_quality() with t5: tab_reward_lab() with t6: tab_architecture() # Guard allows safe imports for testing without triggering the UI. # Streamlit runs scripts with __name__ == "__main__". if __name__ == "__main__": main()