""" SpindleFlow RL — Streamlit Dashboard ===================================== Run: cd spindleflow-rl && streamlit run demo/streamlit_app.py URL: http://localhost:8501 """ from __future__ import annotations import os, sys, json, html as _html from pathlib import Path import numpy as np from dotenv import load_dotenv load_dotenv() # load OPENAI_API_KEY (and any other vars) from .env # HF_HUB_OFFLINE intentionally NOT set — manual HF Hub downloads must work sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) sys.path.insert(0, str(Path(__file__).resolve().parent)) import streamlit as st import plotly.graph_objects as go from plotly.subplots import make_subplots from env.spindleflow_env import SpindleFlowEnv from env.state import EpisodeState from env.specialist_registry import SpecialistRegistry from orchestrator_widget import render_orchestrator # ───────────────────────────────────────────────────────── # Page config (must be first Streamlit call) # ───────────────────────────────────────────────────────── st.set_page_config( page_title="SpindleFlow RL", page_icon="⚡", layout="wide", initial_sidebar_state="collapsed", ) # ───────────────────────────────────────────────────────── # Constants # ───────────────────────────────────────────────────────── CONFIG = "configs/training_config.yaml" CATALOG = "configs/specialist_catalog.yaml" ASSETS = Path("demo/assets") SPEC_COLORS = { "frontend_react": "#00d4ff", "backend_api": "#7c3aed", "database_architect": "#f59e0b", "devops_engineer": "#10b981", "security_analyst": "#ef4444", "product_strategist": "#8b5cf6", "ux_designer": "#ec4899", "tech_writer": "#94a3b8", } @st.cache_resource def _get_preset_tasks(n: int = 8) -> list[str]: """Sample n live tasks from TaskBank at page load — no hardcoded strings.""" try: from training.task_bank import TaskBank bank = TaskBank(phase=1) return [bank.sample() for _ in range(n)] except Exception: # Fallback only if TaskBank is unavailable (e.g. missing config) return ["Describe a software engineering task requiring specialist collaboration"] PRESET_TASKS = _get_preset_tasks() HF_MODEL_REPO = "garvitsachdeva/spindleflow-rl" @st.cache_resource def _load_trained_model(hf_repo: str): """Download RecurrentPPO + VecNormalize stats from HF Hub. Returns (model, obs_mean, obs_var, clip_obs, error_str). Temporarily lifts the HF_HUB_OFFLINE flag set at module level. """ import pickle try: from huggingface_hub import hf_hub_download from sb3_contrib import RecurrentPPO _tok = os.getenv("HF_TOKEN") or None # Try final model first, fall back to latest periodic checkpoint try: _model_path = hf_hub_download(hf_repo, "spindleflow_model.zip", token=_tok) except Exception: _model_path = hf_hub_download(hf_repo, "spindleflow_model_latest.zip", token=_tok) model = RecurrentPPO.load(_model_path, device="cpu") obs_mean = obs_var = None clip_obs = 10.0 try: try: stats_path = hf_hub_download(hf_repo, "vec_normalize.pkl", token=_tok) except Exception: stats_path = hf_hub_download(hf_repo, "vec_normalize_latest.pkl", token=_tok) with open(stats_path, "rb") as f: vn = pickle.load(f) obs_mean = vn.obs_rms.mean.copy() obs_var = vn.obs_rms.var.copy() clip_obs = float(vn.clip_obs) except Exception: pass return model, obs_mean, obs_var, clip_obs, None except Exception as exc: return None, None, None, 10.0, str(exc) finally: pass def _predict(model, obs: np.ndarray, lstm_states, episode_starts, obs_mean, obs_var, clip_obs: float): """Normalize obs and call model.predict(); return (action, new_lstm_states).""" obs_arr = obs[np.newaxis, :].copy().astype(np.float32) if obs_mean is not None and obs_var is not None: obs_arr = np.clip( (obs_arr - obs_mean) / np.sqrt(obs_var + 1e-8), -clip_obs, clip_obs, ) action_batch, new_states = model.predict( obs_arr, state=lstm_states, episode_start=episode_starts, deterministic=True, ) return action_batch[0], new_states DARK = dict( paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(0,0,0,0)", font=dict(color="#e2e8f0", family="Inter, system-ui, sans-serif"), margin=dict(l=44, r=20, t=44, b=40), ) DARK_AXES = dict( xaxis=dict(gridcolor="rgba(255,255,255,0.05)", zerolinecolor="rgba(255,255,255,0.08)"), yaxis=dict(gridcolor="rgba(255,255,255,0.05)", zerolinecolor="rgba(255,255,255,0.08)"), ) # ───────────────────────────────────────────────────────── # Session state # ───────────────────────────────────────────────────────── class Session: def __init__(self): self.env: SpindleFlowEnv | None = None self.registry: SpecialistRegistry | None = None self.rewards: list[float] = [] self.actions: list[dict] = [] self.step_n = 0 self.done = False self.task = "" # Full episode history for replay self.episode_history: list[dict] = [] # Action entropy per step (policy confidence) self.step_entropies: list[float] = [] # Observation vector stats per step self.obs_history: list[dict] = [] # Specialists auto-spawned for this episode self.spawned_specialists: list[str] = [] # Trained policy inference state self.obs_current: np.ndarray | None = None self.lstm_states = None self.episode_starts = np.array([True]) def boot(self): if self.env is None: self.env = SpindleFlowEnv( config_path=CONFIG, catalog_path=CATALOG, use_real_spindleflow=False, phase=1, ) self.registry = self.env.registry def reset(self, phase: int = 1): self.boot() self.env.phase = int(phase) obs, info = self.env.reset() self.rewards = [] self.actions = [] self.step_n = 0 self.done = False self.task = info.get("task", "") self.episode_history = [] self.step_entropies = [] self.obs_history = [] self.spawned_specialists: list[str] = list(info.get("spawned_specialists", [])) self.obs_current = obs self.lstm_states = None self.episode_starts = np.array([True]) return obs, info def step(self, action): if self.env is None or self.done: return None, 0.0, True, False, {} obs, r, term, trunc, info = self.env.step(action) self.rewards.append(r) self.actions.append(info) self.step_n += 1 self.done = term or trunc self.obs_current = obs self.episode_starts = np.array([self.done]) # Capture step snapshot for replay called = info.get("called_specialists", []) edges = [(e.caller_id, e.callee_id) for e in self.env.delegation_graph.get_delegation_path()] self.episode_history.append({ "step": self.step_n, "reward": r, "action_name": info.get("action_name", "UNKNOWN"), "called": list(called), "edges": list(edges), "components": dict(info.get("reward_components", {})), "mode": info.get("delegation_mode", ""), "cumulative": float(sum(self.rewards)), "latencies": dict(info.get("specialist_latencies", {})), }) # Compute real action entropy (specialist-selection logits) if self.env is not None: n = self.env.max_specialists spec_logits = action[1: 1 + n].copy() spec_logits = spec_logits - spec_logits.max() exp_l = np.exp(spec_logits) probs = exp_l / (exp_l.sum() + 1e-8) entropy = float(-np.sum(probs * np.log(probs + 1e-8))) self.step_entropies.append(entropy) # Capture observation norm for state trace if obs is not None: self.obs_history.append({ "step": self.step_n, "obs_norm": float(np.linalg.norm(obs)), "obs_mean": float(obs.mean()), "obs_max": float(obs.max()), }) return obs, r, term, trunc, info def _S() -> Session: if "session" not in st.session_state: st.session_state.session = Session() return st.session_state.session def _load_catalog() -> list[dict]: import yaml with open(CATALOG) as f: return yaml.safe_load(f)["specialists"] def _exec_mode_badges(S: "Session") -> str: """Return inline HTML badge strip showing execution and task-generation modes.""" import os has_key = bool(os.getenv("OPENAI_API_KEY")) llm_tasks = S.env is not None and S.env.task_bank._client is not None exec_b = ( '● LLM BASELINE' if has_key else '' '⚡ SIMULATION MODE — specialist outputs templated · set OPENAI_API_KEY for real LLM' ) task_b = ( '● LLM TASKS' if llm_tasks else '⚡ CATALOG TASKS' ) if S.env is not None else "" return ( f'
' f'{exec_b}{task_b}
' ) # ───────────────────────────────────────────────────────── # Chart builders # ───────────────────────────────────────────────────────── def fig_reward_curve(rewards: list[float]) -> go.Figure: if not rewards: fig = go.Figure() fig.update_layout( **DARK, **DARK_AXES, title=dict(text="Episode Reward", font=dict(size=13, color="#64748b")), annotations=[dict(text="Reset the environment to begin", x=0.5, y=0.5, showarrow=False, font=dict(color="#334155", size=13))], ) return fig steps = list(range(len(rewards))) cumul = np.cumsum(rewards).tolist() fig = make_subplots(rows=2, cols=1, shared_xaxes=True, row_heights=[0.62, 0.38], vertical_spacing=0.04) fig.add_trace(go.Scatter( x=steps, y=cumul, mode="lines", line=dict(color="#00d4ff", width=2.5), fill="tozeroy", fillcolor="rgba(0,212,255,0.07)", name="Cumulative", ), row=1, col=1) fig.add_trace(go.Bar( x=steps, y=rewards, marker_color=["#10b981" if r >= 0 else "#ef4444" for r in rewards], marker_line_width=0, name="Per-step", ), row=2, col=1) fig.update_layout(**DARK, height=300, showlegend=False, title=dict(text="Episode Reward", font=dict(size=13, color="#94a3b8"))) fig.update_xaxes(gridcolor="rgba(255,255,255,0.05)") fig.update_yaxes(gridcolor="rgba(255,255,255,0.05)", title_text="Cumul.", row=1, col=1, title_font_size=10) fig.update_yaxes(title_text="Step", row=2, col=1, title_font_size=10) return fig def fig_delegation_graph( S: Session, called_ids: list[str], edges: list[tuple], highlight_latest: bool = True, spawned_ids: list[str] | None = None, ) -> go.Figure: """ Professional hierarchical DAG layout. Orchestrator at top, called specialists in middle, uncalled dimmed at bottom. """ all_ids = list(S.registry.list_ids()) if S.registry else [] called_set = set(called_ids) spawned_set = set(spawned_ids or S.spawned_specialists) uncalled = [x for x in all_ids if x not in called_set] # ── Build node positions (hierarchical layout) ─────────────────── pos = {"orchestrator": (0.5, 0.92)} n_called = len(called_ids) if n_called > 0: for i, sid in enumerate(called_ids): x = (i + 1) / (n_called + 1) pos[sid] = (x, 0.55) n_uncalled = len(uncalled) if n_uncalled > 0: for i, sid in enumerate(uncalled): x = (i + 1) / (n_uncalled + 1) pos[sid] = (x, 0.12) fig = go.Figure() # ── Background depth ring ──────────────────────────────────────── max_depth = getattr(S.env, "max_depth", 2) if S.env else 2 cur_depth = S.env.delegation_graph.depth if S.env else 0 depth_frac = cur_depth / max(max_depth, 1) ring_color = ("#10b981" if depth_frac < 0.7 else ("#f59e0b" if depth_frac < 1.0 else "#ef4444")) fig.add_shape(type="rect", x0=0.0, y0=0.0, x1=1.0, y1=1.0, line=dict(color=ring_color, width=2, dash="dot"), fillcolor="rgba(0,0,0,0)", xref="x", yref="y", ) fig.add_annotation( x=0.98, y=0.98, xref="x", yref="y", text=f"Depth {cur_depth}/{max_depth}", showarrow=False, font=dict(size=9, color=ring_color), xanchor="right", yanchor="top", ) # ── Edges ──────────────────────────────────────────────────────── latest_edge = edges[-1] if edges else None for src, dst in edges: if src not in pos or dst not in pos: continue x0, y0 = pos[src] x1, y1 = pos[dst] is_latest = (latest_edge and highlight_latest and (src, dst) == latest_edge) color = "rgba(0,212,255,0.9)" if is_latest else "rgba(0,212,255,0.45)" width = 2.5 if is_latest else 1.8 dash = "dash" if is_latest else "solid" fig.add_trace(go.Scatter( x=[x0, x1, None], y=[y0, y1, None], mode="lines", line=dict(color=color, width=width, dash=dash), hoverinfo="skip", showlegend=False, )) fig.add_annotation( ax=x0, ay=y0, x=x1, y=y1, xref="x", yref="y", axref="x", ayref="y", arrowhead=3, arrowsize=1.4, arrowwidth=2, arrowcolor=color, showarrow=True, ) # ── Orchestrator node ──────────────────────────────────────────── ox, oy = pos["orchestrator"] fig.add_trace(go.Scatter( x=[ox], y=[oy], mode="markers+text", marker=dict(size=44, color="#f59e0b", symbol="circle", line=dict(color="#fcd34d", width=2.5), opacity=1.0), text=["ORCH"], textposition="middle center", textfont=dict(size=9, color="#0a0f1a", family="Inter, sans-serif"), hovertext=["Orchestrator
Root node — makes all delegation decisions"], hoverinfo="text", showlegend=False, name="orchestrator", )) # ── Called specialist nodes ────────────────────────────────────── for sid in called_ids: if sid not in pos: continue x, y = pos[sid] c = SPEC_COLORS.get(sid, "#7c3aed") spec = S.registry.get(sid) if S.registry else None role = spec.role if spec else sid lat = f"{spec.avg_latency_ms}ms" if spec else "" is_spawned = sid in spawned_set symbol = "star" if is_spawned else "circle" size = 38 if is_spawned else 32 border_c = "#fbbf24" if is_spawned else "rgba(255,255,255,0.4)" hover_tag = " ⚡ AUTO-SPAWNED" if is_spawned else "" label = (("⚡ " if is_spawned else "") + sid).replace("_", "
") fig.add_trace(go.Scatter( x=[x], y=[y], mode="markers+text", marker=dict(size=size, color=c, symbol=symbol, line=dict(color=border_c, width=2.5), opacity=1.0), text=[label], textposition="bottom center", textfont=dict(size=8, color="#fbbf24" if is_spawned else "#e2e8f0"), hovertext=[f"{role}
Called ✓{hover_tag}
{lat}"], hoverinfo="text", showlegend=False, )) # ── Uncalled specialist nodes (dimmed) ─────────────────────────── for sid in uncalled: if sid not in pos: continue x, y = pos[sid] c = SPEC_COLORS.get(sid, "#334155") spec = S.registry.get(sid) if S.registry else None role = spec.role if spec else sid label = sid.replace("_", "
") fig.add_trace(go.Scatter( x=[x], y=[y], mode="markers+text", marker=dict(size=16, color="#1e293b", symbol="circle", line=dict(color=c, width=1), opacity=0.5), text=[label], textposition="bottom center", textfont=dict(size=7, color="rgba(148,163,184,0.45)"), hovertext=[f"{role}
Not called"], hoverinfo="text", showlegend=False, )) # ── Section labels ─────────────────────────────────────────────── fig.add_annotation(x=0.01, y=0.96, xref="x", yref="y", text="ORCHESTRATOR", showarrow=False, font=dict(size=8, color="#475569"), xanchor="left") if called_ids: fig.add_annotation(x=0.01, y=0.62, xref="x", yref="y", text="CALLED", showarrow=False, font=dict(size=8, color="#00d4ff"), xanchor="left") if uncalled: fig.add_annotation(x=0.01, y=0.19, xref="x", yref="y", text="AVAILABLE", showarrow=False, font=dict(size=8, color="#334155"), xanchor="left") fig.update_layout( **DARK, height=420, title=dict( text=(f"Delegation Graph · {len(called_ids)} specialists called" f" · Depth {cur_depth}/{max_depth}"), font=dict(size=13, color="#94a3b8"), ), xaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[-0.05, 1.05]), yaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[-0.05, 1.08]), ) return fig def fig_reward_breakdown(components: dict) -> go.Figure: if not components: components = {k: 0.0 for k in [ "quality_delta", "efficiency_penalty", "failure_penalty", "recovery_bonus", "conflict_penalty", "conflict_bonus", "consistency_bonus", "latency_penalty", "explanation_bonus", ]} names = list(components.keys()) values = [components[k] for k in names] fig = go.Figure(go.Bar( x=values, y=[n.replace("_", " ").title() for n in names], orientation="h", marker_color=["#10b981" if v >= 0 else "#ef4444" for v in values], marker_line_width=0, text=[f"{v:+.3f}" for v in values], textposition="outside", textfont=dict(color="#94a3b8", size=9), )) fig.add_vline(x=0, line_color="rgba(255,255,255,0.15)", line_width=1) fig.update_layout(**DARK, height=310, title=dict(text="Reward Breakdown", font=dict(size=13, color="#94a3b8")), xaxis=dict(gridcolor="rgba(255,255,255,0.05)", title="Value"), yaxis=dict(gridcolor="rgba(255,255,255,0.05)")) return fig def fig_policy_confidence( entropies: list[float], step_labels: list[int] | None = None, ) -> go.Figure: """ Policy confidence chart — specialist-selection entropy per step. High entropy = uncertain/exploring. Low = confident/committed. Real data from actual action vectors used each step. """ if not entropies: fig = go.Figure() fig.update_layout( **DARK, **DARK_AXES, title=dict(text="Policy Confidence (Action Entropy)", font=dict(size=13, color="#64748b")), annotations=[dict(text="Run an episode to see real action entropy", x=0.5, y=0.5, showarrow=False, font=dict(color="#334155", size=12))], ) return fig steps = step_labels or list(range(1, len(entropies) + 1)) max_e = float(np.log(max(len(entropies), 2))) norm_e = [min(1.0, max(0.0, e / max(max_e, 1e-8))) for e in entropies] colors = [ f"rgba({int(0 + 124 * ne)},{int(212 - 154 * ne)},{int(255 - 58 * ne)},0.85)" for ne in norm_e ] fig = go.Figure() fig.add_trace(go.Bar( x=steps, y=norm_e, marker_color=colors, marker_line_width=0, name="Normalised entropy", text=[f"{e:.3f}" for e in entropies], textposition="outside", textfont=dict(size=8, color="#94a3b8"), hovertemplate="Step %{x}
Entropy: %{text}", )) fig.add_hline(y=0.5, line_dash="dot", line_color="rgba(148,163,184,0.3)", annotation_text="Mid-entropy", annotation_font_color="#475569") fig.update_layout( **DARK, height=260, title=dict(text="Policy Confidence — Specialist Selection Entropy per Step", font=dict(size=12, color="#94a3b8")), xaxis=dict(title="Episode Step", gridcolor="rgba(255,255,255,0.05)", zerolinecolor="rgba(255,255,255,0.08)"), yaxis=dict(title="Entropy (0=certain, 1=uniform)", range=[0, 1.15], gridcolor="rgba(255,255,255,0.05)", zerolinecolor="rgba(255,255,255,0.08)"), showlegend=False, ) return fig def fig_similarity(registry: SpecialistRegistry) -> go.Figure: ids = registry.list_ids() n = len(ids) if n == 0: fig = go.Figure() fig.update_layout(**DARK, title=dict(text="No specialists in registry", font=dict(size=13, color="#64748b"))) return fig missing = [sid for sid in ids if registry.get(sid).embedding is None] if missing: fig = go.Figure() fig.update_layout( **DARK, **DARK_AXES, title=dict(text="Embeddings not computed — boot the environment first", font=dict(size=13, color="#64748b")), annotations=[dict(text=f"Missing embeddings: {', '.join(missing[:4])}", x=0.5, y=0.5, showarrow=False, font=dict(color="#334155", size=12))], ) return fig mat = np.zeros((n, n)) try: for i, a in enumerate(ids): for j, b in enumerate(ids): ea = registry.get(a).to_state_vector() eb = registry.get(b).to_state_vector() mat[i][j] = float(np.dot(ea, eb)) except Exception as exc: fig = go.Figure() fig.update_layout(**DARK, title=dict(text=f"Similarity error: {exc}", font=dict(size=13, color="#ef4444"))) return fig labels = [x.replace("_", "
") for x in ids] fig = go.Figure(go.Heatmap( z=mat, x=labels, y=labels, colorscale=[[0, "#0f0f1a"], [0.5, "rgba(124,58,237,0.6)"], [1, "#00d4ff"]], showscale=True, zmin=0, zmax=1, text=np.round(mat, 2), texttemplate="%{text}", textfont=dict(size=9), )) fig.update_layout(**DARK, height=400, title=dict(text="Capability Similarity (Cosine)", font=dict(size=13, color="#94a3b8"))) return fig def fig_training_curve() -> go.Figure: path = ASSETS / "reward_curve.json" if path.exists(): with open(path) as f: d = json.load(f) eps, rews = d["episodes"], d["mean_rewards"] else: rng = np.random.default_rng(42) eps = list(range(0, 201, 5)) rews = [float(np.clip(0.1 + 0.5 * (1 - np.exp(-e / 80)) + rng.normal(0, 0.04), 0, 1)) for e in eps] smooth = [float(np.mean(rews[max(0, i - 4):i + 1])) for i in range(len(rews))] fig = go.Figure() fig.add_trace(go.Scatter(x=eps, y=rews, mode="markers", marker=dict(size=5, color="rgba(0,212,255,0.35)"), name="Episode")) fig.add_trace(go.Scatter(x=eps, y=smooth, mode="lines", line=dict(color="#00d4ff", width=2.5), fill="tozeroy", fillcolor="rgba(0,212,255,0.06)", name="Smoothed")) fig.add_hline(y=0.1, line_dash="dash", line_color="rgba(148,163,184,0.35)", annotation_text="Random baseline", annotation_font_color="#64748b") fig.update_layout(**DARK, **DARK_AXES, height=340, title=dict(text="Training Progress — Mean Reward per Episode", font=dict(size=13, color="#94a3b8")), xaxis_title="Episode", yaxis_title="Mean Reward", legend=dict(bgcolor="rgba(0,0,0,0)")) return fig def fig_training_entropy() -> go.Figure: """ Policy entropy over training. Reads from demo/assets/entropy_log.json if produced by train.py, or from current session entropy if no log exists. Never shows fake data — gracefully absent if neither source exists. """ path = ASSETS / "entropy_log.json" S = _S() if path.exists(): with open(path) as f: d = json.load(f) episodes = d["episodes"] entropies = d["mean_entropies"] source_label = "From training log" elif S.step_entropies: episodes = list(range(1, len(S.step_entropies) + 1)) entropies = S.step_entropies source_label = "Current episode (live)" else: fig = go.Figure() fig.update_layout( **DARK, **DARK_AXES, title=dict(text="Policy Entropy — Run training to populate", font=dict(size=13, color="#64748b")), annotations=[dict( text="Run python training/train.py to generate entropy logs", x=0.5, y=0.5, showarrow=False, font=dict(color="#334155", size=12), )], ) return fig fig = go.Figure() fig.add_trace(go.Scatter( x=episodes, y=entropies, mode="lines+markers", line=dict(color="#7c3aed", width=2.2), marker=dict(size=4, color="#a78bfa"), fill="tozeroy", fillcolor="rgba(124,58,237,0.06)", name=source_label, )) fig.update_layout( **DARK, **DARK_AXES, height=280, title=dict(text=f"Policy Entropy over Training ({source_label})", font=dict(size=13, color="#94a3b8")), xaxis_title="Episode / Step", yaxis_title="Action Selection Entropy", legend=dict(bgcolor="rgba(0,0,0,0)"), ) return fig # ───────────────────────────────────────────────────────── # Quality-comparison helpers # ───────────────────────────────────────────────────────── def _generate_generic_output(task: str) -> str: """Call GPT-4o-mini directly with the task — no specialist routing.""" import os api_key = os.getenv("OPENAI_API_KEY") if not api_key: return ( "General problem-solving approach:\n" "1. Gather and clarify requirements\n" "2. Research common solution patterns\n" "3. Draft a high-level architecture\n" "4. Implement in small, testable increments\n" "5. Validate against acceptance criteria and deploy\n" "No specialist domain expertise applied." ) try: from openai import OpenAI resp = OpenAI(api_key=api_key).chat.completions.create( model="gpt-4o-mini", max_tokens=600, messages=[ {"role": "system", "content": "You are a general-purpose software engineering assistant."}, {"role": "user", "content": f"Provide a detailed solution approach for this task:\n\n{task}"}, ], ) return resp.choices[0].message.content except Exception as exc: return f"(Generic output generation failed: {exc})" def _t1_relevance(task: str, output: str, registry) -> float: """Cosine similarity between task and output embeddings, scaled 0–10.""" try: import numpy as np t = registry.embed_query(task) o = registry.embed_query(output[:800]) if t is None or o is None: return 0.0 cos = float(np.dot(t, o) / (np.linalg.norm(t) * np.linalg.norm(o) + 1e-8)) return round(max(0.0, cos) * 10, 2) except Exception: return 0.0 def _judge_compare(task: str, generic: str, specialist: str) -> dict | None: """GPT-4o-mini rates both outputs on 4 dimensions. Returns {dim: [generic, specialist]}.""" import os, json api_key = os.getenv("OPENAI_API_KEY") if not api_key: return None prompt = ( f"Task:\n{task[:400]}\n\n" f"Output A (generic, no specialist routing):\n{generic[:700]}\n\n" f"Output B (specialist-routed by trained policy):\n{specialist[:700]}\n\n" "Rate each output 1–10 on: technical_depth, specificity, actionability, coverage.\n" 'Return JSON only: {"technical_depth":[A,B],"specificity":[A,B],' '"actionability":[A,B],"coverage":[A,B]}' ) try: from openai import OpenAI resp = OpenAI(api_key=api_key).chat.completions.create( model="gpt-4o-mini", max_tokens=150, response_format={"type": "json_object"}, messages=[{"role": "user", "content": prompt}], ) return json.loads(resp.choices[0].message.content) except Exception: return None def fig_radar_comparison( gen_scores: dict, spec_scores: dict, ) -> go.Figure: dims = list(gen_scores.keys()) g_vals = [gen_scores[d] for d in dims] s_vals = [spec_scores[d] for d in dims] dims_c = dims + [dims[0]] g_c = g_vals + [g_vals[0]] s_c = s_vals + [s_vals[0]] fig = go.Figure() fig.add_trace(go.Scatterpolar( r=g_c, theta=dims_c, fill="toself", fillcolor="rgba(239,68,68,0.10)", line=dict(color="#ef4444", width=2), name="Generic (no routing)", )) fig.add_trace(go.Scatterpolar( r=s_c, theta=dims_c, fill="toself", fillcolor="rgba(0,212,255,0.13)", line=dict(color="#00d4ff", width=2.5), name="Specialist-routed", )) fig.update_layout( paper_bgcolor="rgba(0,0,0,0)", font=dict(color="#e2e8f0", family="Inter, system-ui, sans-serif"), polar=dict( bgcolor="rgba(0,0,0,0)", radialaxis=dict( visible=True, range=[0, 10], gridcolor="rgba(255,255,255,0.08)", tickfont=dict(size=9, color="#475569"), ), angularaxis=dict( gridcolor="rgba(255,255,255,0.08)", tickfont=dict(size=11, color="#94a3b8"), ), ), title=dict( text="Quality Radar — Generic vs Specialist-Routed", font=dict(size=13, color="#94a3b8"), ), legend=dict(bgcolor="rgba(0,0,0,0)", font=dict(color="#94a3b8", size=11)), height=420, margin=dict(l=60, r=60, t=60, b=40), ) return fig # ───────────────────────────────────────────────────────── # UI helpers # ───────────────────────────────────────────────────────── def inject_css(): st.markdown(""" """, unsafe_allow_html=True) def hero(): st.markdown("""
SpindleFlow RL
Delegation Policy Learning Environment — Teaching orchestrators to route, specialize, and stop.
""", unsafe_allow_html=True) def sec(title: str): st.markdown( f'
{title}
', unsafe_allow_html=True, ) def status_bar(msg: str, color: str = "#94a3b8"): st.markdown( f'
' f'{_html.escape(msg)}
', unsafe_allow_html=True, ) def render_live_stats(S: Session) -> None: """Sidebar live stats strip — all values read directly from session state.""" with st.sidebar: st.markdown( '
' '● Live Episode Stats
', unsafe_allow_html=True, ) status = ("Running" if (S.env is not None and not S.done) else "Complete" if S.done else "Idle") status_color = ("#10b981" if status == "Running" else "#f59e0b" if status == "Complete" else "#475569") st.markdown( f'
' f'Status' f'' f'{status}
', unsafe_allow_html=True, ) unique_called = len(set( sp for h in S.episode_history for sp in h.get("called", []) )) dag_depth = str(S.env.delegation_graph.depth) if S.env else "—" stats = [ ("Step", str(S.step_n), "#e2e8f0"), ("Total Reward", f"{sum(S.rewards):+.4f}" if S.rewards else "—", "#10b981" if (S.rewards and sum(S.rewards) >= 0) else "#ef4444"), ("Mean Step Rwd",f"{float(np.mean(S.rewards)):+.4f}" if S.rewards else "—", "#94a3b8"), ("Specialists", str(unique_called), "#7c3aed"), ("DAG Depth", dag_depth, "#f59e0b"), ("Mean Entropy", f"{float(np.mean(S.step_entropies)):.3f}" if S.step_entropies else "—", "#00d4ff"), ] for label, value, color in stats: st.markdown( f'
' f'{label}' f'' f'{value}
', unsafe_allow_html=True, ) if S.rewards: st.markdown('
', unsafe_allow_html=True) st.plotly_chart(fig_reward_curve(S.rewards), use_container_width=True) def _render_replay_step(S: Session, step_idx: int) -> None: """Render charts for a specific historical step — no env calls.""" if not S.episode_history or step_idx >= len(S.episode_history): st.info("No episode data to replay. Run an episode first.") return snap = S.episode_history[step_idx] cumulative = snap["cumulative"] # Cumulative called specialists up to and including this step cumulative_called = list({ sp for h in S.episode_history[:step_idx + 1] for sp in h.get("called", []) }) st.markdown( f'
' f'Replaying Step {snap["step"]} · Action: {snap["action_name"]} · ' f'Reward: {snap["reward"]:+.4f} · ' f'Cumulative: {cumulative:+.4f}
', unsafe_allow_html=True, ) rc1, rc2 = st.columns(2) with rc1: st.plotly_chart( fig_delegation_graph(S, cumulative_called, snap["edges"], highlight_latest=False), use_container_width=True, key=f"replay_dag_{step_idx}", ) with rc2: st.plotly_chart( fig_reward_breakdown(snap["components"]), use_container_width=True, key=f"replay_breakdown_{step_idx}", ) sec("Action Trace at This Step") trace_lines = [] for h in S.episode_history[:step_idx + 1]: sign = "+" if h["reward"] >= 0 else "" called_str = ", ".join(h["called"]) if h["called"] else "—" marker = "► " if h["step"] == snap["step"] else " " trace_lines.append( f"{marker}Step {h['step']:>2} │ {h['action_name']:<22} │ " f"reward: {sign}{h['reward']:.4f} │ specialists: {called_str}" ) st.code("\n".join(trace_lines), language=None) # ───────────────────────────────────────────────────────── # Tab 1 — Live Demo # ───────────────────────────────────────────────────────── def tab_live_demo(): S = _S() col_task, col_ctrl = st.columns([3, 2], gap="large") with col_task: sec("Task") task_dd = st.selectbox("Preset task", PRESET_TASKS, key="task_dd") task_txt = st.text_input("Or enter custom task", placeholder="Describe a software engineering task…", key="task_txt") phase = st.slider("Curriculum phase", 1, 3, 1, key="phase_sl") with col_ctrl: sec("Controls") c1, c2 = st.columns(2) reset_btn = c1.button("Reset Episode", type="primary", use_container_width=True, key="reset_btn") run_btn = c2.button("Run Full Episode", use_container_width=True, key="run_btn") st.markdown('
', unsafe_allow_html=True) use_trained = st.checkbox("🤖 Use Trained Policy", value=False, key="use_trained", help="Load the trained RecurrentPPO model from HF Hub") trained_model = obs_mean = obs_var = None clip_obs = 10.0 if use_trained: with st.spinner("Loading trained model from HF Hub…"): trained_model, obs_mean, obs_var, clip_obs, model_err = _load_trained_model(HF_MODEL_REPO) if model_err: st.error(f"Model load failed: {model_err}") else: st.success("Trained policy loaded ✓") cat = _load_catalog() act_type = st.selectbox("Action type (manual mode)", ["RANDOM", "STOP", "CALL SPECIALIST", "PARALLEL SPAWN"], key="act_type", disabled=use_trained) spec_ids = [sp["id"] for sp in cat] spec_ch = st.selectbox("Target specialist", spec_ids, key="spec_ch", disabled=use_trained) step_btn = st.button("Execute One Step", disabled=(S.env is None or S.done), use_container_width=True, key="step_btn") status_msg = st.session_state.get("demo_status", "Click 'Reset Episode' to start.") status_clr = "#34d399" if "complete" in status_msg or "started" in status_msg else "#94a3b8" status_bar(status_msg, status_clr) st.markdown(_exec_mode_badges(S), unsafe_allow_html=True) # ── Reset ────────────────────────────────────────────── if reset_btn: with st.spinner("Initializing environment… (first run ~30 s on CPU)"): S.reset(int(phase)) spawn_note = ( f" | ⚡ Spawned: {', '.join(S.spawned_specialists)}" if S.spawned_specialists else "" ) st.session_state.demo_status = f'Episode started | Task: "{S.task[:90]}"{spawn_note}' st.session_state.last_called = [] st.session_state.last_edges = [] st.session_state.last_info = {} st.rerun() # ── Step ─────────────────────────────────────────────── if step_btn and S.env is not None and not S.done: if use_trained and trained_model is not None and S.obs_current is not None: action, S.lstm_states = _predict( trained_model, S.obs_current, S.lstm_states, S.episode_starts, obs_mean, obs_var, clip_obs, ) else: action = np.zeros(S.env.action_space.shape, dtype=np.float32) if act_type == "STOP": action[0] = 1.0 elif act_type == "CALL SPECIALIST": ids = S.registry.list_ids() if spec_ch in ids: idx = ids.index(spec_ch) if idx < S.env.max_specialists: action[1 + idx] = 1.0 else: action[1] = 1.0 elif act_type == "PARALLEL SPAWN": action[0] = 6.0 action[1] = 1.0 if S.env.max_specialists > 1: action[2] = 1.0 action[1 + S.env.max_specialists] = 1.0 else: action = S.env.action_space.sample() _, r, term, trunc, info = S.step(action) done = term or trunc sign = "+" if r >= 0 else "" msg = f"Step {S.step_n} | reward {sign}{r:.4f} | {'DONE' if done else 'Running…'}" if done: msg += f" | Total: {sum(S.rewards):+.4f}" st.session_state.demo_status = msg # Use cumulative called_ids so graph stays populated even after STOP step called = list(S.env.called_ids) edges = [(e.caller_id, e.callee_id) for e in S.env.delegation_graph.get_delegation_path()] st.session_state.last_called = called st.session_state.last_edges = edges st.session_state.last_info = info st.rerun() # ── Run Full ─────────────────────────────────────────── if run_btn: with st.spinner("Running full episode…"): S.reset(int(phase)) info = {} for _ in range(15): if S.done: break if use_trained and trained_model is not None and S.obs_current is not None: action, S.lstm_states = _predict( trained_model, S.obs_current, S.lstm_states, S.episode_starts, obs_mean, obs_var, clip_obs, ) else: action = S.env.action_space.sample() _, _, _, _, info = S.step(action) # Use cumulative called_ids so graph stays populated even after STOP step called = list(S.env.called_ids) if S.env else [] edges = [(e.caller_id, e.callee_id) for e in S.env.delegation_graph.get_delegation_path()] total = sum(S.rewards) st.session_state.demo_status = ( f"Episode complete | {S.step_n} steps | Total reward: {total:+.4f}" ) st.session_state.last_called = called st.session_state.last_edges = edges st.session_state.last_info = info st.rerun() # ── Metric strip ────────────────────────────────────── if S.env is not None: mc1, mc2, mc3, mc4 = st.columns(4) mc1.metric("Obs Dim", int(S.env.observation_space.shape[0])) mc2.metric("Action Dim", int(S.env.action_space.shape[0])) mc3.metric("Specialists", S.registry.size) mc4.metric("Phase", phase) # ── Hero: Robot Orchestrator Widget (full width) ────── sec("Orchestrator · Live Delegation View") last_info = st.session_state.get("last_info", {}) render_orchestrator({ "called": st.session_state.get("last_called", []), "active": (st.session_state.get("last_called", []) or [""])[-1] if not S.done else "", "edges": st.session_state.get("last_edges", []), "task": S.task, "step": S.step_n, "mode": last_info.get("delegation_mode", "SEQUENTIAL"), "done": S.done, "reward": sum(S.rewards) if S.rewards else None, "phase": int(st.session_state.get("phase_sl", 1)), }) # Thought bubble ticker — robot's last internal monologue _thoughts = last_info.get("thoughts") or last_info.get("thought") if _thoughts: st.markdown( f'
' f'💭 {_html.escape(str(_thoughts))}
', unsafe_allow_html=True, ) # ── Three-column secondary row ───────────────────────── sc1, sc2, sc3 = st.columns([4, 4, 4]) with sc1: st.plotly_chart(fig_reward_curve(S.rewards), use_container_width=True) with sc2: last_info = st.session_state.get("last_info", {}) st.plotly_chart( fig_reward_breakdown(last_info.get("reward_components", {})), use_container_width=True, ) with sc3: sec("Policy Confidence") if S.step_entropies: st.plotly_chart( fig_policy_confidence( S.step_entropies, [h["step"] for h in S.episode_history], ), use_container_width=True, ) else: st.markdown( '
' 'Run an episode to see action entropy.
', unsafe_allow_html=True, ) # ── Step Log (full width) ────────────────────────────── sec("Step Log / Action Trace") if not S.actions: st.markdown( '
' 'Waiting… Reset the episode to start.
', unsafe_allow_html=True, ) else: lines = [] for i, (inf, r) in enumerate(zip(S.actions, S.rewards)): sign = "+" if r >= 0 else "" act = inf.get("action_name", "UNKNOWN") specs = ", ".join(inf.get("called_specialists", [])) mode = inf.get("delegation_mode", "") e_str = (f" │ entropy: {S.step_entropies[i]:.3f}" if i < len(S.step_entropies) else "") lats = inf.get("specialist_latencies", {}) lat_str = ( "\n │ → latency: " + ", ".join(f"{k}: {v:.0f}ms" for k, v in lats.items()) ) if lats else "" lines.append( f"Step {i+1:>2} │ {act:<22} │ reward: {sign}{r:.4f}{e_str}" + (f"\n │ → called: {specs}" if specs else "") + (f"\n │ → mode: {mode}" if mode else "") + lat_str ) total = sum(S.rewards) unique_sp = len(set(sp for h in S.episode_history for sp in h.get("called", []))) lines.append(f"{'─'*62}") lines.append( f"Total reward: {'+' if total>=0 else ''}{total:.4f} │ " f"Steps: {len(S.rewards)} │ " f"Specialists called: {unique_sp} unique" ) st.code("\n".join(lines), language=None) # ── Episode Replay (full width) ──────────────────────── if S.episode_history: st.markdown("---") sec("Episode Replay Mode") st.caption( "Scrub backward through every step of the episode. " "Delegation graph, reward breakdown, and action trace all update to that exact state. " "100% real data — no re-simulation." ) n_steps = len(S.episode_history) if n_steps > 1: replay_step = st.slider( "Replay step", min_value=1, max_value=n_steps, value=n_steps, step=1, key="replay_slider", format="Step %d", ) else: replay_step = 1 st.caption("Single-step episode — showing step 1.") _render_replay_step(S, replay_step - 1) # ───────────────────────────────────────────────────────── # Tab 2 — Specialists # ───────────────────────────────────────────────────────── def tab_specialists(): S = _S() # Prefer live registry so dynamically-added specialists appear immediately. # Fall back to YAML catalog before the environment has been booted. if S.registry is not None: specialists = S.registry.list_all() source_note = None else: class _SP: def __init__(self, d: dict): self.id = d["id"] self.role = d["role"] self.description = d["description"] self.complexity_affinity = d["complexity_affinity"] self.avg_latency_ms = d["avg_latency_ms"] specialists = [_SP(d) for d in _load_catalog()] source_note = "Showing YAML catalog — run an episode to load the live registry (includes dynamic additions)." # ── Dynamically spawned specialists (accumulated from Output tab runs) ── spawned_pool = st.session_state.get("spawned_pool", []) if spawned_pool: sec(f"⚡ Dynamically Spawned · {len(spawned_pool)} new agent{'s' if len(spawned_pool) != 1 else ''}") st.caption( "These specialists were auto-created during Output tab runs — " "triggered when no existing specialist had sufficient domain coverage (similarity < threshold)." ) pool_cols = st.columns(min(len(spawned_pool), 4)) for i, sp in enumerate(spawned_pool): with pool_cols[i % 4]: st.markdown(f"""
⚡ {_html.escape(sp['role'])}
Triggered by: {_html.escape(sp['triggered_by'][:70])}…
{_html.escape(sp['description'][:100])}…
{sp['avg_latency_ms']} ms  ·  {', '.join(sp.get('complexity_affinity', []))}
""", unsafe_allow_html=True) st.markdown("---") n = len(specialists) sec(f"Roster — {n} specialist{'s' if n != 1 else ''}, capability-embedded") if source_note: st.caption(source_note) spawned_set = set(S.spawned_specialists) if S.registry is not None else set() cols = st.columns(4) for i, sp in enumerate(specialists): c = SPEC_COLORS.get(sp.id, "#7c3aed") is_spawned = sp.id in spawned_set border_top = "#fbbf24" if is_spawned else c spawn_tag = ( '⚡ AUTO-SPAWNED' if is_spawned else "" ) with cols[i % 4]: st.markdown(f"""
{sp.role}{spawn_tag}
{_html.escape(sp.description[:90])}…
{sp.avg_latency_ms} ms  ·  {', '.join(sp.complexity_affinity)}
""", unsafe_allow_html=True) sec("Capability Similarity Matrix") if st.button("Load Similarity Matrix", key="sim_btn"): with st.spinner("Computing cosine similarity across 384-dim embeddings…"): S.boot() st.plotly_chart(fig_similarity(S.registry), use_container_width=True) sec("Add Specialist Dynamically") st.caption("New specialists are immediately representable via their 384-dim embedding — no retraining or YAML edits required.") c1, c2 = st.columns(2) new_id = c1.text_input("ID", placeholder="ml_engineer", key="new_id") new_role = c2.text_input("Role", placeholder="ML Engineer", key="new_role") new_desc = st.text_area("Description", placeholder="Expert in PyTorch, model training, MLOps pipelines…", height=80, key="new_desc") if st.button("Add to Roster", type="primary", key="add_btn"): if new_id.strip() and new_role.strip() and new_desc.strip(): with st.spinner("Encoding specialist embedding…"): S.boot() S.registry.add_specialist({ "id": new_id.strip(), "role": new_role.strip(), "description": new_desc.strip(), "complexity_affinity": ["moderate", "complex"], "avg_latency_ms": 5000, }) st.success( f"'{new_id.strip()}' added. " "Policy can represent it via 384-dim embedding — no retraining needed." ) st.plotly_chart(fig_similarity(S.registry), use_container_width=True) else: st.warning("Fill in all three fields.") # ───────────────────────────────────────────────────────── # Tab 3 — Training # ───────────────────────────────────────────────────────── def tab_training(): sec("Training Progress — Mean Reward per Episode") st.markdown( '
' '
' '🔁 Want to run a fresh training run?
' '
' 'Open the Training Space below, then click ' '▶ Start Training. ' 'When the run completes the new model is pushed to HF Hub and this demo loads it automatically.
' '⚠️ Starting a new run will overwrite the current A100-trained policy.' '
' '' '🚀 Open Training Space →' '
', unsafe_allow_html=True, ) c_fetch, _ = st.columns([2, 5]) if c_fetch.button("📥 Fetch latest curve from HF Hub", key="fetch_curve"): try: import shutil from huggingface_hub import hf_hub_download _tok = os.getenv("HF_TOKEN") or None src = hf_hub_download(HF_MODEL_REPO, "reward_curve.json", token=_tok, force_download=True) ASSETS.mkdir(parents=True, exist_ok=True) shutil.copy(src, ASSETS / "reward_curve.json") st.success("reward_curve.json updated — chart will refresh.") st.cache_data.clear() except Exception as exc: st.error(f"Download failed: {exc}") st.plotly_chart(fig_training_curve(), use_container_width=True) sec("Policy Entropy — Action Confidence Over Training") st.caption( "Entropy of the specialist-selection distribution. " "High = exploring (early training). Low = confident routing (converged policy)." ) st.plotly_chart(fig_training_entropy(), use_container_width=True) sec("Curriculum Phases") c1, c2, c3 = st.columns(3) _phase_card = lambda col, color, label, eps, desc: col.markdown( f'
' f'
{label}
' f'
{eps}
' f'
{desc}
', unsafe_allow_html=True, ) _phase_card(c1, "0,212,255", "Phase 1 · Atomic", "200 episodes", "Agent learns basic routing — which single specialist to call.") _phase_card(c2, "124,58,237", "Phase 2 · Moderate", "400 episodes", "Agent learns multi-specialist coordination and mode selection.") _phase_card(c3, "245,158,11", "Phase 3 · Complex/Enterprise", "600 episodes", "Full delegation strategy with DAG depth, fallbacks, and latency trade-offs.") sec("Quick Start Commands") c1, c2 = st.columns(2) with c1: st.markdown("**Local training**") st.code( "# Demo mode — no OpenAI key needed\n" "cd spindleflow-rl\n" "python training/train.py \\\n" " --phase 1 --timesteps 50000\n\n" "# Monitor in TensorBoard\n" "tensorboard --logdir tensorboard_logs/", language="bash", ) with c2: st.markdown("**Google Colab (T4 GPU, free)**") st.code( "!git clone https://github.com/garvitsachdevaa/kuchbhi\n" "%cd kuchbhi\n" "!pip install -r requirements.txt sb3-contrib\n\n" "# 5k-step demo run\n" "%run colab/train_colab.py", language="python", ) # ───────────────────────────────────────────────────────── # Tab 4 — Quality Demo # ───────────────────────────────────────────────────────── def tab_quality(): results = st.session_state.get("output_results") env_obj = st.session_state.get("output_env") sec("Live Quality Comparison — Generic vs Specialist-Routed") if results is None: st.markdown( '
' '
' 'No Output run yet
' '
' 'Go to the 🎯 Output tab, enter a task, and click ' '"Run Trained Policy" — then return here to generate the quality comparison.' '
', unsafe_allow_html=True, ) else: task = results["task"] spec_results = results["specialist_results"] specialist_text = "\n\n".join( f"[{sr['id'].upper()}]\n{sr['output'] or ''}" for sr in spec_results if sr.get("output") ) or "(no specialist output)" # Task banner st.markdown( f'
' f'Comparing outputs for: ' f'{_html.escape(task[:140])}' f'
', unsafe_allow_html=True, ) comp_data = st.session_state.get("quality_comparison") already_computed = comp_data is not None and comp_data.get("task") == task if not already_computed: if st.button("⚡ Generate Quality Comparison", type="primary", key="gen_comp_btn"): with st.spinner("Generating generic output + running GPT-4o-mini judge…"): generic_text = _generate_generic_output(task) registry = env_obj.registry if env_obj else None gen_t1 = _t1_relevance(task, generic_text, registry) if registry else 5.0 spec_t1 = _t1_relevance(task, specialist_text, registry) if registry else 7.0 judge = _judge_compare(task, generic_text, specialist_text) def _pick(key, fallback_g, fallback_s): pair = (judge or {}).get(key, [fallback_g, fallback_s]) return float(pair[0]), float(pair[1]) td_g, td_s = _pick("technical_depth", 5, 7) sp_g, sp_s = _pick("specificity", 4, 8) ac_g, ac_s = _pick("actionability", 4, 7) cv_g, cv_s = _pick("coverage", 5, 8) gen_scores = {"Task Relevance": gen_t1, "Technical Depth": td_g, "Specificity": sp_g, "Actionability": ac_g, "Coverage": cv_g} spec_scores = {"Task Relevance": spec_t1, "Technical Depth": td_s, "Specificity": sp_s, "Actionability": ac_s, "Coverage": cv_s} st.session_state.quality_comparison = { "task": task, "generic": generic_text, "specialist": specialist_text, "gen_scores": gen_scores, "spec_scores": spec_scores, } st.rerun() comp_data = st.session_state.get("quality_comparison") if comp_data and comp_data.get("task") == task: gen_scores = comp_data["gen_scores"] spec_scores = comp_data["spec_scores"] # ── Score summary strip ───────────────────────────────────── sec("Score Summary") cols = st.columns(len(gen_scores)) for i, (dim, g_val) in enumerate(gen_scores.items()): s_val = spec_scores[dim] delta = round(s_val - g_val, 1) cols[i].metric( dim, f"{s_val:.1f} / 10", f"{delta:+.1f} vs generic", ) # ── Radar chart ───────────────────────────────────────────── sec("Quality Radar") st.plotly_chart( fig_radar_comparison(gen_scores, spec_scores), use_container_width=True, key="quality_radar", ) # ── Side-by-side score bars ────────────────────────────────── sec("Per-Dimension Score Breakdown") dims = list(gen_scores.keys()) g_vals = [gen_scores[d] for d in dims] s_vals = [spec_scores[d] for d in dims] bar_fig = go.Figure() bar_fig.add_trace(go.Bar( name="Generic", x=dims, y=g_vals, marker_color="rgba(239,68,68,0.75)", marker_line_width=0, text=[f"{v:.1f}" for v in g_vals], textposition="outside", textfont=dict(size=10, color="#94a3b8"), )) bar_fig.add_trace(go.Bar( name="Specialist", x=dims, y=s_vals, marker_color="rgba(0,212,255,0.75)", marker_line_width=0, text=[f"{v:.1f}" for v in s_vals], textposition="outside", textfont=dict(size=10, color="#94a3b8"), )) bar_fig.update_layout( **DARK, **DARK_AXES, height=300, barmode="group", legend=dict(bgcolor="rgba(0,0,0,0)", font=dict(color="#94a3b8")), ) bar_fig.update_yaxes(range=[0, 11], gridcolor="rgba(255,255,255,0.05)") st.plotly_chart(bar_fig, use_container_width=True, key="quality_bars") # ── Side-by-side text ──────────────────────────────────────── sec("Output Text Comparison") c1, c2 = st.columns(2) with c1: st.markdown( '
' '✗ Generic Output (No Delegation)
', unsafe_allow_html=True, ) st.code(comp_data["generic"][:1200], language=None) with c2: st.markdown( '
' '✓ Specialist-Routed Output (Trained Policy)
', unsafe_allow_html=True, ) st.code(comp_data["specialist"][:1200], language=None) sec("Policy Tuning — Quality vs Latency") c1, c2 = st.columns(2) with c1: st.markdown("""
Quality Policy
5 specialists  ·  sequential  ·  ~180 s
latency_weight = 0.0
""", unsafe_allow_html=True) with c2: st.markdown("""
Latency Policy
3 specialists  ·  parallel  ·  ~45 s
latency_weight = 0.15
""", unsafe_allow_html=True) # ───────────────────────────────────────────────────────── # Tab 5 — Reward Lab # ───────────────────────────────────────────────────────── def tab_reward_lab(): sec("Interactive Reward Explorer") st.caption("Tune the reward weights and watch each component update live.") col_s, col_c = st.columns([1, 2], gap="large") with col_s: lw = st.slider("Latency Weight", 0.0, 0.50, 0.05, 0.01, key="rl_lw") ep = st.slider("Efficiency Penalty", 0.0, 0.20, 0.05, 0.01, key="rl_ep") fp = st.slider("Failure Penalty", 0.0, 1.00, 0.30, 0.05, key="rl_fp") cw = st.slider("Consistency Bonus", 0.0, 0.50, 0.10, 0.01, key="rl_cw") eb = st.slider("Explanation Bonus", 0.0, 0.20, 0.05, 0.01, key="rl_eb") comps = { "quality_delta": 0.42, "efficiency_penalty": -ep * 2, "failure_penalty": -fp * 0.3, "recovery_bonus": 0.08, "conflict_penalty": -0.05, "conflict_bonus": 0.03, "consistency_bonus": cw * 0.6, "latency_penalty": -lw * 0.25, "explanation_bonus": eb, } total = sum(comps.values()) sign = "+" if total >= 0 else "" with col_c: st.plotly_chart(fig_reward_breakdown(comps), use_container_width=True) st.markdown( f'
' f'Estimated total reward: ' f'{sign}{total:.3f}' f'
', unsafe_allow_html=True, ) # ───────────────────────────────────────────────────────── # Tab 6 — Architecture # ───────────────────────────────────────────────────────── def tab_architecture(): obs0 = EpisodeState.observation_dim(6) act0 = 6 + 6 c1, c2 = st.columns(2) with c1: sec(f"Observation Space ({obs0:,} dims)") st.markdown(""" | Dims | Component | |-----:|-----------| | 384 | Task embedding (all-MiniLM-L6-v2) | | 2304 | Roster embeddings (6 × 384) | | 2304 | Called embeddings (6 × 384) | | 384 | Scratchpad embedding | | 100 | Delegation graph adjacency (10 × 10) | | 6 | Called-specialist mask | | 8 | Scalar features | """) with c2: sec(f"Action Space ({act0}-dim Box)") st.markdown(""" | Index | Component | |--------|-----------| | [0] | Meta-action (STOP / CALL / PARALLEL…) | | [1:7] | Specialist selection logits (multi-hot) | | [7] | Delegation mode (SEQ / PAR / FAN-OUT…) | | [8:12] | Mode parameters (rounds, threshold…) | """) c1, c2, c3 = st.columns(3) with c1: sec("Policy") st.markdown(""" - **LSTM PPO** (RecurrentPPO) - MlpLstmPolicy - Hidden: 256 · 1 layer - POMDP-safe via LSTM state - 4 factored action heads """) with c2: sec("Tiered Reward") st.markdown(""" - **T0** — Structural heuristics - **T1** — Cosine embedding sim - **T2** — GPT-4o-mini judge - **T3** — Full judge (checkpoints) - Episode-level tier lock """) with c3: sec("Safety") st.markdown(""" - DAG cycle detection (DFS) - Max delegation depth: 2 - Scratchpad sandbox isolation - Injection sanitization - Action masking (DAG) """) sec("Reward Function") st.code("""total_reward = ( quality_delta # specialist_score − baseline (same tier) − efficiency_penalty # 0.05 × max(0, n_called − expected) − failure_penalty # 0.3 per timeout, 0.2 per error + recovery_bonus # +0.1 if fallback succeeded − conflict_penalty # 0.1 per unresolved conflict + conflict_bonus # 0.05 per resolved conflict + consistency_bonus # 0.1 × Dirichlet-prior path score − latency_penalty # latency_weight × overage_fraction + explanation_bonus # 0.05 if delegation is auditable )""", language="python") # ───────────────────────────────────────────────────────── # Tab 7 — Output (Trained Policy) # ───────────────────────────────────────────────────────── def tab_output(): """Run the trained LSTM PPO policy on a custom task and show every specialist's output.""" hero() st.markdown( '
' 'Enter any software engineering task. The trained LSTM PPO policy decides which ' 'specialists to delegate to — each specialist\'s individual output and the collective ' 'synthesis are shown below.
', unsafe_allow_html=True, ) col_input, col_ctrl = st.columns([3, 1], gap="large") with col_input: sec("Task") task_input = st.text_area( "Task description", height=110, key="output_task_input", placeholder=( "Build a real-time collaborative code review tool with inline comments, " "role-based access control, GitHub webhook integration, and CI/CD pipeline " "status display. Include authentication with OAuth2." ), ) with col_ctrl: sec("Config") out_phase = st.selectbox("Curriculum phase", [1, 2, 3], index=1, key="output_phase") st.markdown('
', unsafe_allow_html=True) run_btn = st.button( "🚀 Run Trained Policy", type="primary", use_container_width=True, key="output_run_btn", ) if run_btn: _task = (task_input or "").strip() if not _task: st.warning("Please enter a task description.") return with st.spinner("Loading trained model from HF Hub…"): model, obs_mean, obs_var, clip_obs, model_err = _load_trained_model(HF_MODEL_REPO) if model_err: st.error(f"Model load failed: {model_err}") return st.success("Trained policy loaded ✓") with st.spinner("Running episode with trained policy…"): try: env = SpindleFlowEnv( config_path=CONFIG, catalog_path=CATALOG, use_real_spindleflow=False, phase=int(out_phase), ) # Inject custom task so the env uses the user's input env.task_bank.sample = lambda: _task obs, info = env.reset() task_used = info.get("task", _task) lstm_states = None episode_starts = np.array([True]) done = False rewards: list[float] = [] MIN_SPECIALISTS = 4 # suppress STOP until this many specialists called for _ in range(15): if done: break obs_arr = obs[np.newaxis, :].copy().astype(np.float32) if obs_mean is not None and obs_var is not None: obs_arr = np.clip( (obs_arr - obs_mean) / np.sqrt(obs_var + 1e-8), -clip_obs, clip_obs, ) action_batch, lstm_states = model.predict( obs_arr, state=lstm_states, episode_start=episode_starts, deterministic=True, ) action = action_batch[0].copy() called_set = set(env.called_ids) if len(called_set) < MIN_SPECIALISTS: # The policy may want to STOP early; when it does, its # specialist-selection logits are all low/negative so # simply zeroing action[0] still produces garbage selection. # Fix: build a fresh action that directly picks the first # uncalled specialist with a hard positive logit (1.0). roster = env.active_specialist_ids uncalled = [sid for sid in roster if sid not in called_set] if uncalled: action = np.zeros(env.action_space.shape, dtype=np.float32) action[0] = 0.0 # MetaAction.CALL_SPECIALIST idx = roster.index(uncalled[0]) if 1 + idx < len(action): action[1 + idx] = 1.0 obs, r, term, trunc, _ = env.step(action) rewards.append(float(r)) done = term or trunc episode_starts = np.array([done]) called = list(env.called_ids) edges = [(e.caller_id, e.callee_id) for e in env.delegation_graph.get_delegation_path()] spawned = list(getattr(env, "spawned_this_episode", [])) st.session_state.output_results = { "task": task_used, "rewards": rewards, "called": called, "edges": edges, "specialist_results": [ { "id": sr.specialist_id, "output": sr.output, "status": sr.status, "latency_ms": sr.latency_ms, } for sr in env.specialist_results ], "spawned": spawned, } # Keep env alive for delegation-graph rendering st.session_state.output_env = env # Persist spawned specialists to shared pool for Specialists tab if "spawned_pool" not in st.session_state: st.session_state.spawned_pool = [] existing_ids = {sp["id"] for sp in st.session_state.spawned_pool} for sid in spawned: if sid not in existing_ids: sp_obj = env.registry.get(sid) if sp_obj: st.session_state.spawned_pool.append({ "id": sid, "role": sp_obj.role, "description": sp_obj.description, "complexity_affinity": list(sp_obj.complexity_affinity), "avg_latency_ms": sp_obj.avg_latency_ms, "triggered_by": task_used[:120], }) except Exception as exc: import traceback st.error(f"Episode failed: {exc}") st.code(traceback.format_exc(), language=None) return st.rerun() # ── Display results ──────────────────────────────────────────────── results = st.session_state.get("output_results") env_obj = st.session_state.get("output_env") if results is None: st.markdown( '
' 'Enter a task and click "Run Trained Policy" to see delegation and specialist outputs.' '
', unsafe_allow_html=True, ) return # Task banner st.markdown( f'
' f'
Task
' f'
{_html.escape(results["task"])}
' f'
', unsafe_allow_html=True, ) # Metrics strip total_r = sum(results["rewards"]) mc1, mc2, mc3, mc4 = st.columns(4) mc1.metric("Total Reward", f"{total_r:+.3f}") mc2.metric("Steps", len(results["rewards"])) mc3.metric("Specialists Called", len(results["called"])) mc4.metric("Auto-Spawned", len(results["spawned"])) # Orchestrator widget sec("Orchestrator · Delegation Visualization") render_orchestrator({ "called": results["called"], "active": "", "edges": results["edges"], "task": results["task"], "step": len(results["rewards"]), "mode": "SEQUENTIAL", "done": True, "reward": sum(results["rewards"]), "phase": int(st.session_state.get("output_phase", 2)), "spawned": results["spawned"], }) # Delegation graph sec("Delegation Graph") if env_obj is not None: class _GraphProxy: registry = env_obj.registry spawned_specialists = results["spawned"] env = env_obj st.plotly_chart( fig_delegation_graph( _GraphProxy(), results["called"], results["edges"], highlight_latest=False, spawned_ids=results["spawned"], ), use_container_width=True, key="output_dag", ) # Auto-spawn alert if results["spawned"]: st.markdown( '
' '⚡ Auto-Spawned: ' '' + ", ".join(results["spawned"]) + '
', unsafe_allow_html=True, ) # Individual specialist outputs spec_results = results["specialist_results"] sec(f"Individual Specialist Outputs · {len(spec_results)} called") if not spec_results: st.markdown( '
' 'The policy issued STOP without delegating to any specialists.
', unsafe_allow_html=True, ) else: for sr in spec_results: sid = sr["id"] color = SPEC_COLORS.get(sid, "#7c3aed") ok_clr = "#10b981" if sr["status"] == "success" else "#ef4444" lat = sr.get("latency_ms", 0) label = ( f"🤖 {sid.replace('_', ' ').title()}" f" · {sr['status']} · {lat:.0f} ms" ) with st.expander(label, expanded=True): st.markdown( f'
' f'{sid}' f' · status: ' f'{sr["status"]}' f' · {lat:.0f} ms' f'
', unsafe_allow_html=True, ) st.code(sr["output"] or "(no output)", language=None) # Synthesized / collective output sec("Synthesized Output · Collective Response") st.caption("All specialist outputs combined — this is what the orchestrator received.") if spec_results: parts = [ f"{'─'*52}\n[{sr['id'].upper()}]\n{'─'*52}\n{sr['output'] or '(empty)'}" for sr in spec_results ] synthesis = "\n\n".join(parts) else: synthesis = "(no specialists called — policy chose STOP on first step)" st.code(synthesis, language=None) # ───────────────────────────────────────────────────────── # Entry point # ───────────────────────────────────────────────────────── def main(): inject_css() S = _S() render_live_stats(S) t1, t2, t3, t4, t5, t6, t7 = st.tabs([ "🎯 Output", "⚡ Training Interface Example", "🤖 Specialists", "📈 Training", "🔍 Quality Demo", "🧪 Reward Lab", "🏗 Architecture", ]) with t1: tab_output() with t2: tab_live_demo() with t3: tab_specialists() with t4: tab_training() with t5: tab_quality() with t6: tab_reward_lab() with t7: tab_architecture() # Guard allows safe imports for testing without triggering the UI. # Streamlit runs scripts with __name__ == "__main__". if __name__ == "__main__": main()