garvitsachdeva committed on
Commit
61e8a52
·
1 Parent(s): 51db558

fix: importlib wrapper with correct __file__; remove garbage files

Browse files
Files changed (3) hide show
  1. =4.40.0 +0 -0
  2. =5.22.0 +0 -0
  3. streamlit_app.py +18 -2128
=4.40.0 DELETED
File without changes
=5.22.0 DELETED
File without changes
streamlit_app.py CHANGED
@@ -1,2132 +1,22 @@
1
- """
2
- SpindleFlow RL Streamlit Dashboard
3
- =====================================
4
- Run: cd spindleflow-rl && streamlit run demo/streamlit_app.py
5
- URL: http://localhost:8501
6
- """
7
-
8
- from __future__ import annotations
9
- import os, sys, json, html as _html
10
  from pathlib import Path
11
- import numpy as np
12
- from dotenv import load_dotenv
13
-
14
- load_dotenv() # load OPENAI_API_KEY (and any other vars) from .env
15
-
16
- # HF_HUB_OFFLINE intentionally NOT set — manual HF Hub downloads must work
17
 
18
- sys.path.insert(0, str(Path(__file__).resolve().parent))
19
- sys.path.insert(0, str(Path(__file__).resolve().parent / "demo"))
20
-
21
- import streamlit as st
22
- import plotly.graph_objects as go
23
- from plotly.subplots import make_subplots
24
 
25
  try:
26
- from env.spindleflow_env import SpindleFlowEnv
27
- from env.state import EpisodeState
28
- from env.specialist_registry import SpecialistRegistry
29
- from orchestrator_widget import render_orchestrator
30
- except Exception as _import_err:
31
- import traceback as _tb
32
- st.error(f"Import failed: {_import_err}")
33
- st.code(_tb.format_exc())
34
- st.stop()
35
-
36
- # ─────────────────────────────────────────────────────────
37
- # Page config (must be first Streamlit call)
38
- # ─────────────────────────────────────────────────────────
39
st.set_page_config(
    layout="wide",
    initial_sidebar_state="collapsed",
    page_title="SpindleFlow RL",
    page_icon="⚡",
)

# ─────────────────────────────────────────────────────────
# Constants
# ─────────────────────────────────────────────────────────
# Paths are relative to the working directory the app is launched from.
CONFIG = "configs/training_config.yaml"
CATALOG = "configs/specialist_catalog.yaml"
ASSETS = Path("demo") / "assets"

# Display colour (hex) for each known specialist id.
SPEC_COLORS = dict(
    frontend_react="#00d4ff",
    backend_api="#7c3aed",
    database_architect="#f59e0b",
    devops_engineer="#10b981",
    security_analyst="#ef4444",
    product_strategist="#8b5cf6",
    ux_designer="#ec4899",
    tech_writer="#94a3b8",
)
63
-
64
@st.cache_resource
def _get_preset_tasks(n: int = 8) -> list[str]:
    """Draw ``n`` preset task strings from ``TaskBank(phase=1)``.

    If importing, constructing, or sampling the task bank raises anything,
    a single generic placeholder task is returned instead.
    """
    placeholder = ["Describe a software engineering task requiring specialist collaboration"]
    try:
        from training.task_bank import TaskBank

        task_bank = TaskBank(phase=1)
        sampled = []
        for _ in range(n):
            sampled.append(task_bank.sample())
    except Exception:
        # TaskBank is unavailable (e.g. missing config) — use the placeholder.
        return placeholder
    return sampled


# Evaluated once at import; the decorator above caches the result.
PRESET_TASKS = _get_preset_tasks()

HF_MODEL_REPO = "garvitsachdeva/spindleflow-rl"
79
-
80
-
81
@st.cache_resource
def _load_trained_model(hf_repo: str):
    """Download a RecurrentPPO model and its VecNormalize stats from HF Hub.

    The final artefacts (``spindleflow_model.zip`` / ``vec_normalize.pkl``)
    are tried first; if a download fails, the ``*_latest`` checkpoint files
    are tried instead.  ``HF_TOKEN`` from the environment is used when set.

    Args:
        hf_repo: Hub repository id to download from.

    Returns:
        ``(model, obs_mean, obs_var, clip_obs, error_str)``.  On success
        ``error_str`` is ``None``.  Normalisation statistics are best-effort:
        when they cannot be loaded ``obs_mean``/``obs_var`` are ``None`` and
        ``clip_obs`` is ``10.0``.  When the model itself cannot be loaded the
        result is ``(None, None, None, 10.0, <message>)``.
    """
    import pickle

    try:
        from huggingface_hub import hf_hub_download
        from sb3_contrib import RecurrentPPO

        _tok = os.getenv("HF_TOKEN") or None
        # Try final model first, fall back to latest periodic checkpoint
        try:
            _model_path = hf_hub_download(hf_repo, "spindleflow_model.zip", token=_tok)
        except Exception:
            _model_path = hf_hub_download(hf_repo, "spindleflow_model_latest.zip", token=_tok)
        model = RecurrentPPO.load(_model_path, device="cpu")

        obs_mean = obs_var = None
        clip_obs = 10.0
        try:
            try:
                stats_path = hf_hub_download(hf_repo, "vec_normalize.pkl", token=_tok)
            except Exception:
                stats_path = hf_hub_download(hf_repo, "vec_normalize_latest.pkl", token=_tok)
            # SECURITY NOTE(review): pickle.load executes arbitrary code from
            # the file; only point hf_repo at a repository you trust.
            with open(stats_path, "rb") as f:
                vn = pickle.load(f)
            # Read everything before assigning so a malformed stats object
            # cannot leave a half-populated (mean, var, clip) triple.
            loaded = (
                vn.obs_rms.mean.copy(),
                vn.obs_rms.var.copy(),
                float(vn.clip_obs),
            )
            obs_mean, obs_var, clip_obs = loaded
        except Exception:
            # Deliberate best-effort: the model is still usable without
            # normalisation statistics, so fall through with the defaults.
            obs_mean = obs_var = None
            clip_obs = 10.0
        return model, obs_mean, obs_var, clip_obs, None
    except Exception as exc:
        return None, None, None, 10.0, str(exc)
119
-
120
-
121
def _predict(model, obs: np.ndarray, lstm_states, episode_starts,
             obs_mean, obs_var, clip_obs: float):
    """Normalize ``obs`` and call ``model.predict``.

    Args:
        model: Object exposing ``predict(obs, state=, episode_start=,
            deterministic=)`` and returning ``(action_batch, new_states)``.
        obs: A single observation (array or sequence); a leading batch axis
            of size 1 is added before prediction.
        lstm_states: Recurrent state forwarded as ``state``.
        episode_starts: Forwarded as ``episode_start``.
        obs_mean: Running mean used for normalisation, or ``None``.
        obs_var: Running variance used for normalisation, or ``None``.
        clip_obs: Symmetric clipping bound applied after normalisation.

    Returns:
        ``(action, new_lstm_states)`` where ``action`` is the first row of the
        predicted batch.
    """
    # np.array copies, so the caller's observation is never modified, and it
    # also accepts plain sequences.
    obs_arr = np.array(obs, dtype=np.float32)[np.newaxis, :]
    if obs_mean is not None and obs_var is not None:
        normalised = np.clip(
            (obs_arr - obs_mean) / np.sqrt(obs_var + 1e-8),
            -clip_obs, clip_obs,
        )
        # The statistics are typically float64, which would silently upcast
        # the batch; keep the float32 dtype the un-normalised path uses.
        obs_arr = normalised.astype(np.float32)
    action_batch, new_states = model.predict(
        obs_arr,
        state=lstm_states,
        episode_start=episode_starts,
        deterministic=True,
    )
    return action_batch[0], new_states
137
-
138
-
139
# Shared dark-theme layout keyword arguments for the Plotly figures below.
DARK = {
    "paper_bgcolor": "rgba(0,0,0,0)",
    "plot_bgcolor": "rgba(0,0,0,0)",
    "font": {"color": "#e2e8f0", "family": "Inter, system-ui, sans-serif"},
    "margin": {"l": 44, "r": 20, "t": 44, "b": 40},
}
# Identical grid styling for both cartesian axes (one fresh dict per axis).
DARK_AXES = {
    axis: {
        "gridcolor": "rgba(255,255,255,0.05)",
        "zerolinecolor": "rgba(255,255,255,0.08)",
    }
    for axis in ("xaxis", "yaxis")
}
149
-
150
- # ─────────────────────────────────────────────────────────
151
- # Session state
152
- # ─────────────────────────────────────────────────────────
153
class Session:
    """Mutable episode state wrapped around one ``SpindleFlowEnv`` instance.

    Holds the environment, the per-step reward/action logs, replay snapshots,
    and the recurrent-policy inference state (latest observation, LSTM state
    and episode-start flag).
    """

    def __init__(self):
        self.env: SpindleFlowEnv | None = None
        self.registry: SpecialistRegistry | None = None
        self.rewards: list[float] = []
        self.actions: list[dict] = []
        self.step_n = 0
        self.done = False
        self.task = ""
        # Full episode history for replay
        self.episode_history: list[dict] = []
        # Action entropy per step (policy confidence)
        self.step_entropies: list[float] = []
        # Observation vector stats per step
        self.obs_history: list[dict] = []
        # Specialists auto-spawned for this episode
        self.spawned_specialists: list[str] = []
        # Trained policy inference state
        self.obs_current: np.ndarray | None = None
        self.lstm_states = None
        self.episode_starts = np.array([True])

    def boot(self):
        """Create the environment on first use and cache its registry."""
        if self.env is None:
            self.env = SpindleFlowEnv(
                config_path=CONFIG, catalog_path=CATALOG,
                use_real_spindleflow=False, phase=1,
            )
            self.registry = self.env.registry

    def reset(self, phase: int = 1):
        """Start a new episode and clear all per-episode logs.

        Args:
            phase: Value assigned to ``env.phase`` before resetting.

        Returns:
            The ``(obs, info)`` pair returned by ``env.reset()``.
        """
        self.boot()
        self.env.phase = int(phase)
        obs, info = self.env.reset()
        self.rewards = []
        self.actions = []
        self.step_n = 0
        self.done = False
        self.task = info.get("task", "")
        self.episode_history = []
        self.step_entropies = []
        self.obs_history = []
        self.spawned_specialists = list(info.get("spawned_specialists", []))
        self.obs_current = obs
        self.lstm_states = None
        self.episode_starts = np.array([True])
        return obs, info

    @staticmethod
    def _selection_entropy(action, n_specialists: int) -> float:
        """Softmax entropy of the specialist-selection slice of ``action``.

        The slice is ``action[1 : 1 + n_specialists]``.  Returns ``0.0`` when
        that slice is empty instead of raising on ``max()`` of an empty array.
        """
        logits = np.asarray(action, dtype=np.float64)[1: 1 + int(n_specialists)]
        if logits.size == 0:
            return 0.0
        # Shift by the max for a numerically stable softmax.
        weights = np.exp(logits - logits.max())
        probs = weights / (weights.sum() + 1e-8)
        return float(-np.sum(probs * np.log(probs + 1e-8)))

    def step(self, action):
        """Advance the environment by one step and record diagnostics.

        Returns ``(None, 0.0, True, False, {})`` without touching the
        environment when no environment exists or the episode is finished;
        otherwise returns the environment's ``(obs, reward, terminated,
        truncated, info)`` tuple.
        """
        if self.env is None or self.done:
            return None, 0.0, True, False, {}
        obs, r, term, trunc, info = self.env.step(action)
        self.rewards.append(r)
        self.actions.append(info)
        self.step_n += 1
        self.done = term or trunc
        self.obs_current = obs
        self.episode_starts = np.array([self.done])

        # Capture step snapshot for replay
        called = info.get("called_specialists", [])
        edges = [(e.caller_id, e.callee_id)
                 for e in self.env.delegation_graph.get_delegation_path()]
        self.episode_history.append({
            "step": self.step_n,
            "reward": r,
            "action_name": info.get("action_name", "UNKNOWN"),
            "called": list(called),
            "edges": list(edges),
            "components": dict(info.get("reward_components", {})),
            "mode": info.get("delegation_mode", ""),
            "cumulative": float(sum(self.rewards)),
            "latencies": dict(info.get("specialist_latencies", {})),
        })

        # Entropy of the specialist-selection logits.  Always appended so the
        # list stays aligned with the step count, even for list-like actions
        # or an empty specialist slice.
        self.step_entropies.append(
            self._selection_entropy(action, self.env.max_specialists)
        )

        # Capture observation norm for state trace
        if obs is not None:
            self.obs_history.append({
                "step": self.step_n,
                "obs_norm": float(np.linalg.norm(obs)),
                "obs_mean": float(obs.mean()),
                "obs_max": float(obs.max()),
            })

        return obs, r, term, trunc, info
248
-
249
-
250
def _S() -> Session:
    """Return the ``Session`` kept in ``st.session_state``, creating it if absent."""
    state = st.session_state
    if "session" in state:
        return state.session
    state.session = Session()
    return state.session
254
-
255
-
256
def _load_catalog() -> list[dict]:
    """Return the ``specialists`` list from the YAML file at ``CATALOG``.

    Raises:
        OSError: If the catalog file cannot be opened.
        KeyError: If the document has no ``specialists`` key.
    """
    import yaml

    # Explicit UTF-8 so parsing does not depend on the platform's locale
    # default encoding.
    with open(CATALOG, encoding="utf-8") as f:
        return yaml.safe_load(f)["specialists"]
260
-
261
-
262
def _exec_mode_badges(S: "Session") -> str:
    """Return inline HTML badge strip showing execution and task-generation modes.

    The execution badge reads "LLM BASELINE" when ``OPENAI_API_KEY`` is set in
    the environment and "SIMULATION MODE" otherwise.  The task badge is only
    emitted when ``S.env`` exists: "LLM TASKS" when the env's task bank has a
    non-``None`` ``_client``, "CATALOG TASKS" otherwise.
    """
    import os

    has_key = bool(os.getenv("OPENAI_API_KEY"))
    env = S.env
    # `_client` is a private attribute of the task bank; read it defensively
    # so a task bank without it degrades to "CATALOG TASKS" instead of raising.
    task_bank = getattr(env, "task_bank", None)
    llm_tasks = env is not None and getattr(task_bank, "_client", None) is not None

    exec_b = (
        '<span style="padding:3px 10px;border-radius:999px;font-size:10px;font-weight:700;'
        'background:rgba(16,185,129,0.1);color:#34d399;'
        'border:1px solid rgba(16,185,129,0.22);">● LLM BASELINE</span>'
        if has_key else
        '<span style="padding:3px 10px;border-radius:999px;font-size:10px;font-weight:700;'
        'background:rgba(245,158,11,0.1);color:#fbbf24;'
        'border:1px solid rgba(245,158,11,0.22);">'
        '⚡ SIMULATION MODE — specialist outputs templated · set OPENAI_API_KEY for real LLM</span>'
    )
    task_b = (
        '<span style="padding:3px 10px;border-radius:999px;font-size:10px;font-weight:700;'
        'background:rgba(16,185,129,0.1);color:#34d399;'
        'border:1px solid rgba(16,185,129,0.22);">● LLM TASKS</span>'
        if llm_tasks else
        '<span style="padding:3px 10px;border-radius:999px;font-size:10px;font-weight:700;'
        'background:rgba(148,163,184,0.08);color:#64748b;'
        'border:1px solid rgba(148,163,184,0.18);">⚡ CATALOG TASKS</span>'
    ) if env is not None else ""

    return (
        f'<div style="display:flex;gap:8px;flex-wrap:wrap;margin:4px 0 12px;">'
        f'{exec_b}{task_b}</div>'
    )
292
-
293
- # ─────────────────────────────────────────────────────────
294
- # Chart builders
295
- # ─────────────────────────────────────────────────────────
296
def fig_reward_curve(rewards: list[float]) -> go.Figure:
    """Plot cumulative reward (top panel) and per-step reward bars (bottom).

    With no rewards, an empty themed figure carrying a "reset" hint is
    returned instead.
    """
    grid = "rgba(255,255,255,0.05)"

    if not rewards:
        placeholder = go.Figure()
        hint = dict(
            text="Reset the environment to begin",
            x=0.5, y=0.5, showarrow=False,
            font=dict(color="#334155", size=13),
        )
        placeholder.update_layout(
            **DARK, **DARK_AXES,
            title=dict(text="Episode Reward", font=dict(size=13, color="#64748b")),
            annotations=[hint],
        )
        return placeholder

    x_steps = list(range(len(rewards)))
    running_total = np.cumsum(rewards).tolist()
    # Green bars for non-negative step rewards, red for negative ones.
    bar_colors = []
    for value in rewards:
        bar_colors.append("#10b981" if value >= 0 else "#ef4444")

    cumulative_trace = go.Scatter(
        x=x_steps, y=running_total, mode="lines",
        line=dict(color="#00d4ff", width=2.5),
        fill="tozeroy", fillcolor="rgba(0,212,255,0.07)",
        name="Cumulative",
    )
    per_step_trace = go.Bar(
        x=x_steps, y=rewards,
        marker_color=bar_colors,
        marker_line_width=0, name="Per-step",
    )

    chart = make_subplots(
        rows=2, cols=1, shared_xaxes=True,
        row_heights=[0.62, 0.38], vertical_spacing=0.04,
    )
    chart.add_trace(cumulative_trace, row=1, col=1)
    chart.add_trace(per_step_trace, row=2, col=1)
    chart.update_layout(
        **DARK, height=300, showlegend=False,
        title=dict(text="Episode Reward", font=dict(size=13, color="#94a3b8")),
    )
    chart.update_xaxes(gridcolor=grid)
    chart.update_yaxes(gridcolor=grid, title_text="Cumul.",
                       row=1, col=1, title_font_size=10)
    chart.update_yaxes(title_text="Step", row=2, col=1, title_font_size=10)
    return chart
330
-
331
-
332
def fig_delegation_graph(
    S: Session,
    called_ids: list[str],
    edges: list[tuple],
    highlight_latest: bool = True,
    spawned_ids: list[str] | None = None,
) -> go.Figure:
    """Draw the delegation graph as three horizontal tiers.

    The orchestrator sits on the top tier, ids in ``called_ids`` on the middle
    tier, and every other registry id (dimmed) on the bottom tier.  ``edges``
    are ``(src, dst)`` pairs drawn as arrows; the last one is dashed and
    brighter when ``highlight_latest`` is true.  Ids in ``spawned_ids`` (or in
    ``S.spawned_specialists`` when that argument is empty or ``None``) are
    drawn as larger gold-bordered stars.
    """
    registry = S.registry
    env = S.env
    registry_ids = list(registry.list_ids()) if registry else []
    called_lookup = set(called_ids)
    spawned_lookup = set(spawned_ids or S.spawned_specialists)
    idle_ids = [sid for sid in registry_ids if sid not in called_lookup]

    def _spec_for(sid):
        """Registry entry for ``sid``, or None when there is no registry."""
        return registry.get(sid) if registry else None

    # ── Node coordinates: each tier spreads its ids evenly along x ──────
    coords = {"orchestrator": (0.5, 0.92)}
    for tier_ids, tier_y in ((called_ids, 0.55), (idle_ids, 0.12)):
        slots = len(tier_ids) + 1
        for rank, sid in enumerate(tier_ids, start=1):
            coords[sid] = (rank / slots, tier_y)

    figure = go.Figure()

    # ── Depth frame: colour encodes how much delegation depth is used ───
    if env:
        depth_cap = getattr(env, "max_depth", 2)
        depth_now = env.delegation_graph.depth
    else:
        depth_cap, depth_now = 2, 0
    used = depth_now / max(depth_cap, 1)
    if used < 0.7:
        frame_color = "#10b981"
    elif used < 1.0:
        frame_color = "#f59e0b"
    else:
        frame_color = "#ef4444"

    figure.add_shape(
        type="rect",
        x0=0.0, y0=0.0, x1=1.0, y1=1.0,
        line=dict(color=frame_color, width=2, dash="dot"),
        fillcolor="rgba(0,0,0,0)", xref="x", yref="y",
    )
    figure.add_annotation(
        x=0.98, y=0.98, xref="x", yref="y",
        text=f"Depth {depth_now}/{depth_cap}", showarrow=False,
        font=dict(size=9, color=frame_color), xanchor="right", yanchor="top",
    )

    # ── Edges ───────────────────────────────────────────────────────────
    newest = edges[-1] if edges else None
    for src, dst in edges:
        if not (src in coords and dst in coords):
            continue
        (sx, sy), (tx, ty) = coords[src], coords[dst]
        emphasised = bool(newest and highlight_latest and (src, dst) == newest)
        if emphasised:
            stroke, thickness, pattern = "rgba(0,212,255,0.9)", 2.5, "dash"
        else:
            stroke, thickness, pattern = "rgba(0,212,255,0.45)", 1.8, "solid"

        figure.add_trace(go.Scatter(
            x=[sx, tx, None], y=[sy, ty, None], mode="lines",
            line=dict(color=stroke, width=thickness, dash=pattern),
            hoverinfo="skip", showlegend=False,
        ))
        figure.add_annotation(
            ax=sx, ay=sy, x=tx, y=ty,
            xref="x", yref="y", axref="x", ayref="y",
            arrowhead=3, arrowsize=1.4, arrowwidth=2,
            arrowcolor=stroke, showarrow=True,
        )

    # ── Orchestrator node ───────────────────────────────────────────────
    root_x, root_y = coords["orchestrator"]
    figure.add_trace(go.Scatter(
        x=[root_x], y=[root_y], mode="markers+text",
        marker=dict(size=44, color="#f59e0b", symbol="circle",
                    line=dict(color="#fcd34d", width=2.5), opacity=1.0),
        text=["<b>ORCH</b>"], textposition="middle center",
        textfont=dict(size=9, color="#0a0f1a", family="Inter, sans-serif"),
        hovertext=["<b>Orchestrator</b><br>Root node — makes all delegation decisions"],
        hoverinfo="text", showlegend=False, name="orchestrator",
    ))

    # ── Called specialist nodes ─────────────────────────────────────────
    for sid in called_ids:
        if sid not in coords:
            continue
        node_x, node_y = coords[sid]
        entry = _spec_for(sid)
        role = entry.role if entry else sid
        latency = f"{entry.avg_latency_ms}ms" if entry else ""
        if sid in spawned_lookup:
            shape, diameter, rim = "star", 38, "#fbbf24"
            tag, prefix, text_color = " ⚡ AUTO-SPAWNED", "⚡ ", "#fbbf24"
        else:
            shape, diameter, rim = "circle", 32, "rgba(255,255,255,0.4)"
            tag, prefix, text_color = "", "", "#e2e8f0"
        caption = (prefix + sid).replace("_", "<br>")
        figure.add_trace(go.Scatter(
            x=[node_x], y=[node_y], mode="markers+text",
            marker=dict(size=diameter, color=SPEC_COLORS.get(sid, "#7c3aed"),
                        symbol=shape,
                        line=dict(color=rim, width=2.5), opacity=1.0),
            text=[caption], textposition="bottom center",
            textfont=dict(size=8, color=text_color),
            hovertext=[f"<b>{role}</b><br>Called ✓{tag}<br>{latency}"],
            hoverinfo="text", showlegend=False,
        ))

    # ── Uncalled specialist nodes (dimmed) ──────────────────────────────
    for sid in idle_ids:
        if sid not in coords:
            continue
        node_x, node_y = coords[sid]
        entry = _spec_for(sid)
        role = entry.role if entry else sid
        figure.add_trace(go.Scatter(
            x=[node_x], y=[node_y], mode="markers+text",
            marker=dict(size=16, color="#1e293b", symbol="circle",
                        line=dict(color=SPEC_COLORS.get(sid, "#334155"), width=1),
                        opacity=0.5),
            text=[sid.replace("_", "<br>")], textposition="bottom center",
            textfont=dict(size=7, color="rgba(148,163,184,0.45)"),
            hovertext=[f"<b>{role}</b><br>Not called"],
            hoverinfo="text", showlegend=False,
        ))

    # ── Section labels (tier captions shown only for non-empty tiers) ───
    tier_captions = [(True, 0.96, "ORCHESTRATOR", "#475569")]
    if called_ids:
        tier_captions.append((True, 0.62, "CALLED", "#00d4ff"))
    if idle_ids:
        tier_captions.append((True, 0.19, "AVAILABLE", "#334155"))
    for _, caption_y, caption_text, caption_color in tier_captions:
        figure.add_annotation(x=0.01, y=caption_y, xref="x", yref="y",
                              text=caption_text, showarrow=False,
                              font=dict(size=8, color=caption_color), xanchor="left")

    hidden_axis = dict(showgrid=False, zeroline=False, showticklabels=False)
    figure.update_layout(
        **DARK, height=420,
        title=dict(
            text=(f"Delegation Graph · {len(called_ids)} specialists called"
                  f" · Depth {depth_now}/{depth_cap}"),
            font=dict(size=13, color="#94a3b8"),
        ),
        xaxis=dict(range=[-0.05, 1.05], **hidden_axis),
        yaxis=dict(range=[-0.05, 1.08], **hidden_axis),
    )
    return figure
487
-
488
-
489
def fig_reward_breakdown(components: dict) -> go.Figure:
    """Horizontal bar chart of reward components, one bar per dict entry.

    An empty/falsy ``components`` is replaced by the nine known component
    names, all at 0.0, so the chart keeps its shape.
    """
    if not components:
        default_names = (
            "quality_delta", "efficiency_penalty", "failure_penalty",
            "recovery_bonus", "conflict_penalty", "conflict_bonus",
            "consistency_bonus", "latency_penalty", "explanation_bonus",
        )
        components = dict.fromkeys(default_names, 0.0)

    keys = list(components.keys())
    amounts = [components[key] for key in keys]
    pretty_names = [key.replace("_", " ").title() for key in keys]
    # Green for non-negative contributions, red for negative ones.
    bar_colors = ["#10b981" if amount >= 0 else "#ef4444" for amount in amounts]
    value_labels = [f"{amount:+.3f}" for amount in amounts]

    bars = go.Bar(
        x=amounts,
        y=pretty_names,
        orientation="h",
        marker_color=bar_colors,
        marker_line_width=0,
        text=value_labels,
        textposition="outside",
        textfont=dict(color="#94a3b8", size=9),
    )
    chart = go.Figure(bars)
    chart.add_vline(x=0, line_color="rgba(255,255,255,0.15)", line_width=1)
    chart.update_layout(
        **DARK, height=310,
        title=dict(text="Reward Breakdown", font=dict(size=13, color="#94a3b8")),
        xaxis=dict(gridcolor="rgba(255,255,255,0.05)", title="Value"),
        yaxis=dict(gridcolor="rgba(255,255,255,0.05)"),
    )
    return chart
514
-
515
-
516
def fig_policy_confidence(
    entropies: list[float],
    step_labels: list[int] | None = None,
    n_choices: int | None = None,
) -> go.Figure:
    """Bar chart of specialist-selection entropy per step, scaled to [0, 1].

    High entropy = uncertain/exploring. Low = confident/committed.

    Args:
        entropies: Raw entropy value for each step.
        step_labels: X-axis labels; defaults to ``1..len(entropies)``.
        n_choices: Number of options the entropy was computed over.  When
            given (>= 2), bars are normalised by ``log(n_choices)``, the
            maximum entropy of a distribution over that many options.  When
            omitted, the previous behaviour is kept and the scale is
            ``log(max(len(entropies), 2))``, which depends on the number of
            steps rather than the number of options and is only a heuristic.

    Returns:
        The themed figure; an annotated empty figure when ``entropies`` is
        empty.
    """
    if not entropies:
        fig = go.Figure()
        fig.update_layout(
            **DARK, **DARK_AXES,
            title=dict(text="Policy Confidence (Action Entropy)",
                       font=dict(size=13, color="#64748b")),
            annotations=[dict(text="Run an episode to see real action entropy",
                              x=0.5, y=0.5, showarrow=False,
                              font=dict(color="#334155", size=12))],
        )
        return fig

    steps = step_labels or list(range(1, len(entropies) + 1))
    if n_choices is not None and n_choices >= 2:
        max_e = float(np.log(n_choices))
    else:
        # Legacy scale, kept so existing callers see unchanged output.
        max_e = float(np.log(max(len(entropies), 2)))
    scale = max(max_e, 1e-8)
    norm_e = [min(1.0, max(0.0, e / scale)) for e in entropies]
    # Colour blends with the normalised entropy (0 → cyan, 1 → violet).
    colors = [
        f"rgba({int(0 + 124 * ne)},{int(212 - 154 * ne)},{int(255 - 58 * ne)},0.85)"
        for ne in norm_e
    ]

    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=steps, y=norm_e,
        marker_color=colors, marker_line_width=0,
        name="Normalised entropy",
        text=[f"{e:.3f}" for e in entropies],
        textposition="outside",
        textfont=dict(size=8, color="#94a3b8"),
        hovertemplate="Step %{x}<br>Entropy: %{text}<extra></extra>",
    ))
    fig.add_hline(y=0.5, line_dash="dot", line_color="rgba(148,163,184,0.3)",
                  annotation_text="Mid-entropy", annotation_font_color="#475569")
    fig.update_layout(
        **DARK, height=260,
        title=dict(text="Policy Confidence — Specialist Selection Entropy per Step",
                   font=dict(size=12, color="#94a3b8")),
        xaxis=dict(title="Episode Step", gridcolor="rgba(255,255,255,0.05)",
                   zerolinecolor="rgba(255,255,255,0.08)"),
        yaxis=dict(title="Entropy (0=certain, 1=uniform)", range=[0, 1.15],
                   gridcolor="rgba(255,255,255,0.05)", zerolinecolor="rgba(255,255,255,0.08)"),
        showlegend=False,
    )
    return fig
568
-
569
-
570
def fig_similarity(registry: SpecialistRegistry) -> go.Figure:
    """Heatmap of pairwise cosine similarity between specialist state vectors.

    Each specialist's ``to_state_vector()`` is evaluated once, scaled to unit
    length, and the similarity matrix is the product of those unit vectors —
    so the plotted values are cosines, as the chart title states.  (For
    vectors that are already unit-length this equals the plain dot product.)

    Returns a placeholder figure when the registry is empty, when any
    specialist has ``embedding is None``, or when building the matrix raises.
    """
    ids = registry.list_ids()
    n = len(ids)

    if n == 0:
        fig = go.Figure()
        fig.update_layout(**DARK, title=dict(text="No specialists in registry",
                                             font=dict(size=13, color="#64748b")))
        return fig

    missing = [sid for sid in ids if registry.get(sid).embedding is None]
    if missing:
        fig = go.Figure()
        fig.update_layout(
            **DARK, **DARK_AXES,
            title=dict(text="Embeddings not computed — boot the environment first",
                       font=dict(size=13, color="#64748b")),
            annotations=[dict(text=f"Missing embeddings: {', '.join(missing[:4])}",
                              x=0.5, y=0.5, showarrow=False,
                              font=dict(color="#334155", size=12))],
        )
        return fig

    try:
        # One to_state_vector() call per specialist instead of 2·n² calls.
        vectors = np.stack([
            np.asarray(registry.get(sid).to_state_vector(), dtype=float).ravel()
            for sid in ids
        ])
        norms = np.linalg.norm(vectors, axis=1, keepdims=True)
        # Leave all-zero vectors untouched rather than dividing by zero.
        unit = vectors / np.where(norms > 0, norms, 1.0)
        mat = unit @ unit.T
    except Exception as exc:
        fig = go.Figure()
        fig.update_layout(**DARK, title=dict(text=f"Similarity error: {exc}",
                                             font=dict(size=13, color="#ef4444")))
        return fig

    labels = [x.replace("_", "<br>") for x in ids]
    fig = go.Figure(go.Heatmap(
        z=mat, x=labels, y=labels,
        colorscale=[[0, "#0f0f1a"], [0.5, "rgba(124,58,237,0.6)"], [1, "#00d4ff"]],
        showscale=True, zmin=0, zmax=1,
        text=np.round(mat, 2), texttemplate="%{text}", textfont=dict(size=9),
    ))
    fig.update_layout(**DARK, height=400,
                      title=dict(text="Capability Similarity (Cosine)", font=dict(size=13, color="#94a3b8")))
    return fig
615
-
616
-
617
def fig_training_curve() -> go.Figure:
    """Plot mean reward per episode with a 5-point trailing average.

    Data comes from ``ASSETS / "reward_curve.json"`` (keys ``episodes`` and
    ``mean_rewards``) when that file exists.  Otherwise a synthetic,
    fixed-seed saturating curve is generated as a placeholder.
    """
    path = ASSETS / "reward_curve.json"
    if path.exists():
        # Explicit UTF-8: do not depend on the platform's locale default
        # encoding when reading JSON.
        with open(path, encoding="utf-8") as f:
            d = json.load(f)
        eps, rews = d["episodes"], d["mean_rewards"]
    else:
        # NOTE: synthetic placeholder data, not real training results.
        rng = np.random.default_rng(42)
        eps = list(range(0, 201, 5))
        rews = [float(np.clip(0.1 + 0.5 * (1 - np.exp(-e / 80)) + rng.normal(0, 0.04), 0, 1))
                for e in eps]
    # Trailing mean over the current point and up to four preceding points.
    smooth = [float(np.mean(rews[max(0, i - 4):i + 1])) for i in range(len(rews))]
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=eps, y=rews, mode="markers",
                             marker=dict(size=5, color="rgba(0,212,255,0.35)"),
                             name="Episode"))
    fig.add_trace(go.Scatter(x=eps, y=smooth, mode="lines",
                             line=dict(color="#00d4ff", width=2.5),
                             fill="tozeroy", fillcolor="rgba(0,212,255,0.06)",
                             name="Smoothed"))
    fig.add_hline(y=0.1, line_dash="dash", line_color="rgba(148,163,184,0.35)",
                  annotation_text="Random baseline", annotation_font_color="#64748b")
    fig.update_layout(**DARK, **DARK_AXES, height=340,
                      title=dict(text="Training Progress — Mean Reward per Episode",
                                 font=dict(size=13, color="#94a3b8")),
                      xaxis_title="Episode", yaxis_title="Mean Reward",
                      legend=dict(bgcolor="rgba(0,0,0,0)"))
    return fig
645
-
646
-
647
def fig_training_entropy() -> go.Figure:
    """Plot policy entropy from the training log or the live session.

    Source priority:
      1. ``ASSETS / "entropy_log.json"`` (keys ``episodes`` and
         ``mean_entropies``) when the file exists.
      2. The current session's ``step_entropies`` when non-empty.
      3. Otherwise an annotated empty figure — no synthetic data is drawn.
    """
    path = ASSETS / "entropy_log.json"
    S = _S()

    if path.exists():
        # Explicit UTF-8: do not depend on the platform's locale default
        # encoding when reading JSON.
        with open(path, encoding="utf-8") as f:
            d = json.load(f)
        episodes = d["episodes"]
        entropies = d["mean_entropies"]
        source_label = "From training log"
    elif S.step_entropies:
        episodes = list(range(1, len(S.step_entropies) + 1))
        entropies = S.step_entropies
        source_label = "Current episode (live)"
    else:
        fig = go.Figure()
        fig.update_layout(
            **DARK, **DARK_AXES,
            title=dict(text="Policy Entropy — Run training to populate",
                       font=dict(size=13, color="#64748b")),
            annotations=[dict(
                text="Run python training/train.py to generate entropy logs",
                x=0.5, y=0.5, showarrow=False,
                font=dict(color="#334155", size=12),
            )],
        )
        return fig

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=episodes, y=entropies, mode="lines+markers",
        line=dict(color="#7c3aed", width=2.2),
        marker=dict(size=4, color="#a78bfa"),
        fill="tozeroy", fillcolor="rgba(124,58,237,0.06)",
        name=source_label,
    ))
    fig.update_layout(
        **DARK, **DARK_AXES, height=280,
        title=dict(text=f"Policy Entropy over Training ({source_label})",
                   font=dict(size=13, color="#94a3b8")),
        xaxis_title="Episode / Step",
        yaxis_title="Action Selection Entropy",
        legend=dict(bgcolor="rgba(0,0,0,0)"),
    )
    return fig
698
-
699
-
700
- # ─────────────────────────────────────────────────────────
701
- # Quality-comparison helpers
702
- # ─────────────────────────────────────────────────────────
703
def _generate_generic_output(task: str) -> str:
    """Produce a baseline answer for ``task`` with no specialist routing.

    With ``OPENAI_API_KEY`` set, ``gpt-4o-mini`` is asked directly for a
    solution approach.  Without a key, a fixed generic checklist is returned.

    Returns:
        Always a string: the model's answer (empty string if the response
        carried no text), the fixed checklist, or a parenthesised failure
        message if the API call raised.
    """
    import os

    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        return (
            "General problem-solving approach:\n"
            "1. Gather and clarify requirements\n"
            "2. Research common solution patterns\n"
            "3. Draft a high-level architecture\n"
            "4. Implement in small, testable increments\n"
            "5. Validate against acceptance criteria and deploy\n"
            "No specialist domain expertise applied."
        )
    try:
        from openai import OpenAI

        resp = OpenAI(api_key=api_key).chat.completions.create(
            model="gpt-4o-mini",
            max_tokens=600,
            messages=[
                {"role": "system",
                 "content": "You are a general-purpose software engineering assistant."},
                {"role": "user",
                 "content": f"Provide a detailed solution approach for this task:\n\n{task}"},
            ],
        )
        # `content` may be None; normalise so the declared `str` return type
        # holds and callers can slice the result safely.
        return resp.choices[0].message.content or ""
    except Exception as exc:
        return f"(Generic output generation failed: {exc})"
732
-
733
-
734
def _t1_relevance(task: str, output: str, registry) -> float:
    """Score how related ``output`` is to ``task`` on a 0–10 scale.

    Both texts are embedded with ``registry.embed_query`` (the output is cut
    to its first 800 characters); the score is their cosine similarity,
    floored at zero, times ten, rounded to two decimals.  A missing embedding
    or any raised exception yields ``0.0``.
    """
    try:
        import numpy as np

        task_vec = registry.embed_query(task)
        out_vec = registry.embed_query(output[:800])
        if task_vec is None or out_vec is None:
            return 0.0
        # The small epsilon keeps the division defined for zero-length vectors.
        denominator = np.linalg.norm(task_vec) * np.linalg.norm(out_vec) + 1e-8
        cosine = float(np.dot(task_vec, out_vec) / denominator)
        floored = cosine if cosine > 0.0 else 0.0
        return round(floored * 10, 2)
    except Exception:
        return 0.0
746
-
747
-
748
def _judge_compare(task: str, generic: str, specialist: str) -> dict | None:
    """Ask ``gpt-4o-mini`` to rate both outputs on four dimensions.

    Args:
        task: Task text (first 400 characters are sent).
        generic: Output A, produced without routing (first 700 characters).
        specialist: Output B, produced via routing (first 700 characters).

    Returns:
        The parsed JSON object from the model, requested in the form
        ``{dimension: [generic_score, specialist_score]}``; ``None`` when no
        ``OPENAI_API_KEY`` is set, the call or parsing fails, or the reply is
        not a JSON object.
    """
    import json
    import os

    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        return None
    # Coerce missing texts to "" so slicing below cannot raise TypeError
    # outside the try block.
    task, generic, specialist = (task or ""), (generic or ""), (specialist or "")
    prompt = (
        f"Task:\n{task[:400]}\n\n"
        f"Output A (generic, no specialist routing):\n{generic[:700]}\n\n"
        f"Output B (specialist-routed by trained policy):\n{specialist[:700]}\n\n"
        "Rate each output 1–10 on: technical_depth, specificity, actionability, coverage.\n"
        'Return JSON only: {"technical_depth":[A,B],"specificity":[A,B],'
        '"actionability":[A,B],"coverage":[A,B]}'
    )
    try:
        from openai import OpenAI

        resp = OpenAI(api_key=api_key).chat.completions.create(
            model="gpt-4o-mini",
            max_tokens=150,
            response_format={"type": "json_object"},
            messages=[{"role": "user", "content": prompt}],
        )
        scores = json.loads(resp.choices[0].message.content)
    except Exception:
        return None
    # Honour the declared return type: anything but a JSON object is unusable.
    return scores if isinstance(scores, dict) else None
773
-
774
-
775
def fig_radar_comparison(
    gen_scores: dict,
    spec_scores: dict,
) -> go.Figure:
    """Radar chart comparing generic and specialist-routed scores.

    Args:
        gen_scores: ``{dimension: score}`` for the generic output; its key
            order defines the order of the radar axes.
        spec_scores: ``{dimension: score}`` for the specialist-routed output.

    Returns:
        The themed polar figure.  Only dimensions present in both dicts are
        plotted; when there are none, the figure has the layout but no
        traces (instead of raising ``IndexError``/``KeyError``).
    """
    dims = [d for d in gen_scores if d in spec_scores]

    fig = go.Figure()
    if dims:
        # Repeat the first point at the end so each polygon is closed.
        dims_c = dims + [dims[0]]
        g_c = [gen_scores[d] for d in dims_c]
        s_c = [spec_scores[d] for d in dims_c]
        fig.add_trace(go.Scatterpolar(
            r=g_c, theta=dims_c, fill="toself",
            fillcolor="rgba(239,68,68,0.10)",
            line=dict(color="#ef4444", width=2),
            name="Generic (no routing)",
        ))
        fig.add_trace(go.Scatterpolar(
            r=s_c, theta=dims_c, fill="toself",
            fillcolor="rgba(0,212,255,0.13)",
            line=dict(color="#00d4ff", width=2.5),
            name="Specialist-routed",
        ))
    fig.update_layout(
        paper_bgcolor="rgba(0,0,0,0)",
        font=dict(color="#e2e8f0", family="Inter, system-ui, sans-serif"),
        polar=dict(
            bgcolor="rgba(0,0,0,0)",
            radialaxis=dict(
                visible=True, range=[0, 10],
                gridcolor="rgba(255,255,255,0.08)",
                tickfont=dict(size=9, color="#475569"),
            ),
            angularaxis=dict(
                gridcolor="rgba(255,255,255,0.08)",
                tickfont=dict(size=11, color="#94a3b8"),
            ),
        ),
        title=dict(
            text="Quality Radar — Generic vs Specialist-Routed",
            font=dict(size=13, color="#94a3b8"),
        ),
        legend=dict(bgcolor="rgba(0,0,0,0)", font=dict(color="#94a3b8", size=11)),
        height=420,
        margin=dict(l=60, r=60, t=60, b=40),
    )
    return fig
823
-
824
-
825
- # ─────────────────────────────────────────────────────────
826
- # UI helpers
827
- # ─────────────────────────────────────────────────────────
828
def inject_css():
    """Inject the dashboard's global stylesheet into the Streamlit page.

    Emits one ``<style>`` element via ``st.markdown`` with
    ``unsafe_allow_html=True``. The rules restyle the app container, tabs,
    buttons, text inputs, select boxes, sliders, metrics, code blocks,
    horizontal rules and WebKit scrollbars.
    """
    # The stylesheet is kept flush-left so the markdown renderer always sees
    # the <style> tag at column 0.
    stylesheet = """
<style>
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600;700;800&display=swap');

html, body, [data-testid="stAppViewContainer"] {
background: #0f0f1a !important;
font-family: 'Inter', system-ui, sans-serif !important;
}
[data-testid="stHeader"] { background: transparent !important; }
[data-testid="stToolbar"] { display: none !important; }

[data-testid="stTabs"] > div:first-child button {
color: #475569 !important; font-weight: 600 !important; font-size: 13px !important;
}
[data-testid="stTabs"] > div:first-child button[aria-selected="true"] {
color: #00d4ff !important; border-bottom-color: #00d4ff !important;
}

.stButton > button {
border-radius: 8px !important; font-weight: 600 !important;
font-size: 13px !important; transition: all .18s !important;
border: 1px solid rgba(255,255,255,0.18) !important;
background: rgba(255,255,255,0.10) !important; color: #e2e8f0 !important;
}
.stButton > button:hover {
background: rgba(255,255,255,0.18) !important;
border-color: rgba(0,212,255,0.45) !important;
color: #ffffff !important;
}
.stButton > button[kind="primary"] {
background: linear-gradient(135deg,#00d4ff,#0092bb) !important;
border: none !important; color: #0a0f1a !important;
}
.stButton > button[kind="primary"]:hover {
box-shadow: 0 4px 18px rgba(0,212,255,0.35) !important;
}

[data-testid="stTextInput"] input,
[data-testid="stTextArea"] textarea {
background: rgba(0,0,0,0.3) !important;
border: 1px solid rgba(255,255,255,0.09) !important;
color: #e2e8f0 !important; border-radius: 8px !important;
}

[data-testid="stSelectbox"] > div > div {
background: rgba(0,0,0,0.35) !important;
border: 1px solid rgba(255,255,255,0.09) !important;
border-radius: 8px !important; color: #e2e8f0 !important;
}

[data-testid="stSlider"] [data-testid="stTickBar"] { color: #475569 !important; }

[data-testid="metric-container"] {
background: rgba(255,255,255,0.03) !important;
border: 1px solid rgba(255,255,255,0.07) !important;
border-radius: 12px !important; padding: 16px !important;
}
[data-testid="stMetric"] label { color: #475569 !important; font-size: 11px !important; }
[data-testid="stMetricValue"] { color: #00d4ff !important; font-weight: 700 !important; }

[data-testid="stCode"], .stCodeBlock {
background: rgba(0,0,0,0.4) !important;
border: 1px solid rgba(255,255,255,0.07) !important;
border-radius: 10px !important;
}

hr { border-color: rgba(255,255,255,0.07) !important; }

::-webkit-scrollbar { width: 4px; height: 4px; }
::-webkit-scrollbar-thumb { background: rgba(255,255,255,0.1); border-radius: 4px; }
::-webkit-scrollbar-track { background: transparent; }
</style>
"""
    st.markdown(stylesheet, unsafe_allow_html=True)
902
-
903
-
904
def hero():
    """Render the gradient title banner at the top of the dashboard.

    Writes a single static HTML card (title "SpindleFlow RL" plus tagline)
    through ``st.markdown`` with ``unsafe_allow_html=True``.
    """
    # Static markup only - nothing is interpolated, so no escaping is needed.
    banner_html = """
<div style="background:linear-gradient(135deg,#0f0f1a,#130a22,#091422);
border:1px solid rgba(0,212,255,0.14);border-radius:16px;
padding:28px 36px;margin-bottom:4px;position:relative;overflow:hidden;">
<div style="position:absolute;top:-60px;right:-40px;width:360px;height:360px;
background:radial-gradient(circle,rgba(124,58,237,0.11) 0%,transparent 70%);
pointer-events:none;"></div>
<div style="position:absolute;bottom:-60px;left:15%;width:280px;height:280px;
background:radial-gradient(circle,rgba(0,212,255,0.07) 0%,transparent 70%);
pointer-events:none;"></div>
<div style="font-size:28px;font-weight:800;
background:linear-gradient(90deg,#00d4ff,#7c3aed,#00d4ff);
background-size:200% auto;-webkit-background-clip:text;
-webkit-text-fill-color:transparent;background-clip:text;
margin:0 0 8px;">SpindleFlow RL</div>
<div style="color:#64748b;font-size:13px;margin:0;">
Delegation Policy Learning Environment &mdash;
Teaching orchestrators to route, specialize, and stop.
</div>
</div>
"""
    st.markdown(banner_html, unsafe_allow_html=True)
926
-
927
-
928
def sec(title: str):
    """Render *title* as a small uppercase section heading with a bottom rule.

    The title is inserted into the markup verbatim (it is not HTML-escaped).
    """
    heading_html = (
        '<div style="font-size:11px;font-weight:700;color:#475569;text-transform:uppercase;'
        'letter-spacing:1px;padding-bottom:10px;border-bottom:1px solid rgba(255,255,255,0.07);'
        f'margin:18px 0 14px;">{title}</div>'
    )
    st.markdown(heading_html, unsafe_allow_html=True)
935
-
936
-
937
def status_bar(msg: str, color: str = "#94a3b8"):
    """Render a one-line status banner.

    Args:
        msg: Text to display; HTML-escaped before insertion.
        color: CSS colour used for the text; inserted verbatim.
    """
    escaped_msg = _html.escape(msg)
    banner_html = (
        '<div style="background:rgba(0,0,0,0.3);border:1px solid rgba(255,255,255,0.07);'
        f'border-radius:8px;padding:10px 16px;font-size:12px;color:{color};margin:6px 0 10px;">'
        f'{escaped_msg}</div>'
    )
    st.markdown(banner_html, unsafe_allow_html=True)
944
-
945
-
946
def render_live_stats(S: Session) -> None:
    """Sidebar live stats strip — all values read directly from session state.

    Renders, inside ``st.sidebar``: a header, the episode status
    (Running / Complete / Idle), a table of per-episode statistics, and —
    once at least one reward exists — the reward curve.

    Args:
        S: The active session object. Reads ``env``, ``done``, ``step_n``,
            ``rewards``, ``step_entropies`` and ``episode_history``.
    """
    with st.sidebar:
        st.markdown(
            '<div style="font-size:10px;font-weight:700;color:#00d4ff;'
            'text-transform:uppercase;letter-spacing:1px;margin-bottom:12px;">'
            '● Live Episode Stats</div>',
            unsafe_allow_html=True,
        )

        # "Running" needs a booted env that has not finished; a finished
        # episode is "Complete"; anything else is "Idle".
        status = ("Running" if (S.env is not None and not S.done) else
                  "Complete" if S.done else "Idle")
        status_color = ("#10b981" if status == "Running" else
                        "#f59e0b" if status == "Complete" else "#475569")
        st.markdown(
            f'<div style="display:flex;justify-content:space-between;'
            f'padding:6px 0;border-bottom:1px solid rgba(255,255,255,0.05);">'
            f'<span style="font-size:11px;color:#475569;">Status</span>'
            f'<span style="font-size:11px;font-weight:700;color:{status_color};">'
            f'{status}</span></div>',
            unsafe_allow_html=True,
        )

        # Distinct specialist ids across every recorded step.
        unique_called = len({
            sp for h in S.episode_history for sp in h.get("called", [])
        })
        dag_depth = str(S.env.delegation_graph.depth) if S.env else "—"

        # Sum once; it feeds both the displayed value and its colour.
        total_reward = sum(S.rewards) if S.rewards else None

        stats = [
            ("Step", str(S.step_n), "#e2e8f0"),
            ("Total Reward",
             f"{total_reward:+.4f}" if total_reward is not None else "—",
             "#10b981" if (total_reward is not None and total_reward >= 0) else "#ef4444"),
            ("Mean Step Rwd",
             f"{float(np.mean(S.rewards)):+.4f}" if S.rewards else "—",
             "#94a3b8"),
            ("Specialists", str(unique_called), "#7c3aed"),
            ("DAG Depth", dag_depth, "#f59e0b"),
            ("Mean Entropy",
             f"{float(np.mean(S.step_entropies)):.3f}" if S.step_entropies else "—",
             "#00d4ff"),
        ]

        for label, value, color in stats:
            st.markdown(
                f'<div style="display:flex;justify-content:space-between;'
                f'padding:5px 0;border-bottom:1px solid rgba(255,255,255,0.04);">'
                f'<span style="font-size:11px;color:#475569;">{label}</span>'
                f'<span style="font-size:11px;font-weight:600;color:{color};">'
                f'{value}</span></div>',
                unsafe_allow_html=True,
            )

        if S.rewards:
            st.markdown('<div style="margin-top:12px;"></div>', unsafe_allow_html=True)
            # Explicit key: the Live Demo tab draws the same reward-curve
            # figure with the same arguments, and Streamlit rejects two
            # elements that resolve to the same auto-generated ID.
            st.plotly_chart(
                fig_reward_curve(S.rewards),
                use_container_width=True,
                key="sidebar_reward_curve",
            )
998
-
999
-
1000
def _render_replay_step(S: Session, step_idx: int) -> None:
    """Draw the replay view for one recorded step, using only stored history.

    Args:
        S: Active session; only ``episode_history`` is read here (``S`` is
            also forwarded to ``fig_delegation_graph``).
        step_idx: Zero-based index into ``S.episode_history``.
    """
    history = S.episode_history
    if not history or step_idx >= len(history):
        st.info("No episode data to replay. Run an episode first.")
        return

    snap = history[step_idx]
    cumulative = snap["cumulative"]
    steps_so_far = history[:step_idx + 1]

    # Every specialist called up to and including the selected step.
    cumulative_called = list({
        sp
        for entry in steps_so_far
        for sp in entry.get("called", [])
    })

    summary_html = (
        f'<div style="background:rgba(124,58,237,0.07);border:1px solid rgba(124,58,237,0.2);'
        f'border-radius:10px;padding:12px 18px;font-size:12px;color:#a78bfa;margin-bottom:12px;">'
        f'Replaying Step {snap["step"]} · Action: <b>{snap["action_name"]}</b> · '
        f'Reward: <b>{snap["reward"]:+.4f}</b> · '
        f'Cumulative: <b>{cumulative:+.4f}</b></div>'
    )
    st.markdown(summary_html, unsafe_allow_html=True)

    graph_col, breakdown_col = st.columns(2)
    with graph_col:
        dag_fig = fig_delegation_graph(
            S, cumulative_called, snap["edges"], highlight_latest=False
        )
        st.plotly_chart(dag_fig, use_container_width=True, key=f"replay_dag_{step_idx}")
    with breakdown_col:
        breakdown_fig = fig_reward_breakdown(snap["components"])
        st.plotly_chart(
            breakdown_fig, use_container_width=True, key=f"replay_breakdown_{step_idx}"
        )

    sec("Action Trace at This Step")

    def _trace_line(entry) -> str:
        # One row of the trace; the selected step gets the "► " marker.
        sign = "+" if entry["reward"] >= 0 else ""
        called_str = ", ".join(entry["called"]) if entry["called"] else "—"
        marker = "► " if entry["step"] == snap["step"] else " "
        return (
            f"{marker}Step {entry['step']:>2} │ {entry['action_name']:<22} │ "
            f"reward: {sign}{entry['reward']:.4f} │ specialists: {called_str}"
        )

    st.code("\n".join(_trace_line(entry) for entry in steps_so_far), language=None)
1050
-
1051
-
1052
- # ─────────────────────────────────────────────────────────
1053
- # Tab 1 — Live Demo
1054
- # ─────────────────────────────────────────────────────────
1055
def tab_live_demo():
    """Render the "Live Demo" tab.

    Lays out the task and control widgets, handles the three episode actions
    (Reset, Execute One Step, Run Full Episode), then draws the metric strip,
    the orchestrator widget, the reward / breakdown / entropy charts, the
    step log and the replay slider. Each action handler stores its results in
    ``st.session_state`` and calls ``st.rerun()``.
    """
    # NOTE(review): SOURCE lost its indentation; the nesting below is
    # reconstructed from syntax. The `with col_ctrl:` extent, the `else` in
    # the CALL SPECIALIST branch and the last assignment in PARALLEL SPAWN
    # are the ambiguous spots - confirm against the original file.
    S = _S()

    col_task, col_ctrl = st.columns([3, 2], gap="large")

    with col_task:
        sec("Task")
        # NOTE(review): task_dd / task_txt are not read again in this
        # function - presumably consumed elsewhere via their session keys.
        task_dd = st.selectbox("Preset task", PRESET_TASKS, key="task_dd")
        task_txt = st.text_input("Or enter custom task",
                                 placeholder="Describe a software engineering task…",
                                 key="task_txt")
        phase = st.slider("Curriculum phase", 1, 3, 1, key="phase_sl")

    with col_ctrl:
        sec("Controls")
        c1, c2 = st.columns(2)
        reset_btn = c1.button("Reset Episode", type="primary", use_container_width=True, key="reset_btn")
        run_btn = c2.button("Run Full Episode", use_container_width=True, key="run_btn")
        st.markdown('<div style="height:6px"></div>', unsafe_allow_html=True)

        use_trained = st.checkbox("🤖 Use Trained Policy", value=False, key="use_trained",
                                  help="Load the trained RecurrentPPO model from HF Hub")
        # Defaults used when the trained policy is off or failed to load.
        trained_model = obs_mean = obs_var = None
        clip_obs = 10.0
        if use_trained:
            with st.spinner("Loading trained model from HF Hub…"):
                trained_model, obs_mean, obs_var, clip_obs, model_err = _load_trained_model(HF_MODEL_REPO)
            if model_err:
                st.error(f"Model load failed: {model_err}")
            else:
                st.success("Trained policy loaded ✓")

        cat = _load_catalog()
        # Manual-mode widgets are disabled while the trained policy drives.
        act_type = st.selectbox("Action type (manual mode)",
                                ["RANDOM", "STOP", "CALL SPECIALIST", "PARALLEL SPAWN"],
                                key="act_type",
                                disabled=use_trained)
        spec_ids = [sp["id"] for sp in cat]
        spec_ch = st.selectbox("Target specialist", spec_ids, key="spec_ch",
                               disabled=use_trained)
        step_btn = st.button("Execute One Step",
                             disabled=(S.env is None or S.done),
                             use_container_width=True, key="step_btn")

    status_msg = st.session_state.get("demo_status", "Click 'Reset Episode' to start.")
    # Green only when the message contains "complete" or "started".
    status_clr = "#34d399" if "complete" in status_msg or "started" in status_msg else "#94a3b8"
    status_bar(status_msg, status_clr)
    st.markdown(_exec_mode_badges(S), unsafe_allow_html=True)

    # ── Reset ──────────────────────────────────────────────
    if reset_btn:
        with st.spinner("Initializing environment… (first run ~30 s on CPU)"):
            S.reset(int(phase))
        spawn_note = (
            f" | ⚡ Spawned: {', '.join(S.spawned_specialists)}"
            if S.spawned_specialists else ""
        )
        st.session_state.demo_status = f'Episode started | Task: "{S.task[:90]}"{spawn_note}'
        st.session_state.last_called = []
        st.session_state.last_edges = []
        st.session_state.last_info = {}
        st.rerun()

    # ── Step ───────────────────────────────────────────────
    if step_btn and S.env is not None and not S.done:
        if use_trained and trained_model is not None and S.obs_current is not None:
            action, S.lstm_states = _predict(
                trained_model, S.obs_current, S.lstm_states,
                S.episode_starts, obs_mean, obs_var, clip_obs,
            )
        else:
            # Manual mode: start from an all-zero action vector and set
            # entries according to the selected action type.
            action = np.zeros(S.env.action_space.shape, dtype=np.float32)
            if act_type == "STOP":
                action[0] = 1.0
            elif act_type == "CALL SPECIALIST":
                ids = S.registry.list_ids()
                if spec_ch in ids:
                    idx = ids.index(spec_ch)
                    if idx < S.env.max_specialists:
                        action[1 + idx] = 1.0
                else:
                    # NOTE(review): this `else` may instead belong to the
                    # inner `idx < max_specialists` test - confirm.
                    action[1] = 1.0
            elif act_type == "PARALLEL SPAWN":
                # NOTE(review): the meaning of 6.0 and of the slot at
                # 1 + max_specialists is defined by the env's action layout,
                # which is not visible here.
                action[0] = 6.0
                action[1] = 1.0
                if S.env.max_specialists > 1:
                    action[2] = 1.0
                action[1 + S.env.max_specialists] = 1.0
            else:
                # "RANDOM" (or any other value): sample from the action space.
                action = S.env.action_space.sample()

        _, r, term, trunc, info = S.step(action)
        done = term or trunc
        sign = "+" if r >= 0 else ""
        msg = f"Step {S.step_n} | reward {sign}{r:.4f} | {'DONE' if done else 'Running…'}"
        if done:
            msg += f" | Total: {sum(S.rewards):+.4f}"
        st.session_state.demo_status = msg
        # Use cumulative called_ids so graph stays populated even after STOP step
        called = list(S.env.called_ids)
        edges = [(e.caller_id, e.callee_id)
                 for e in S.env.delegation_graph.get_delegation_path()]
        st.session_state.last_called = called
        st.session_state.last_edges = edges
        st.session_state.last_info = info
        st.rerun()

    # ── Run Full ───────────────────────────────────────────
    if run_btn:
        with st.spinner("Running full episode…"):
            S.reset(int(phase))
            info = {}
            # At most 15 steps; stops early once the session reports done.
            for _ in range(15):
                if S.done:
                    break
                if use_trained and trained_model is not None and S.obs_current is not None:
                    action, S.lstm_states = _predict(
                        trained_model, S.obs_current, S.lstm_states,
                        S.episode_starts, obs_mean, obs_var, clip_obs,
                    )
                else:
                    action = S.env.action_space.sample()
                _, _, _, _, info = S.step(action)
        # Use cumulative called_ids so graph stays populated even after STOP step
        called = list(S.env.called_ids) if S.env else []
        # NOTE(review): unlike `called`, this line does not guard S.env.
        edges = [(e.caller_id, e.callee_id)
                 for e in S.env.delegation_graph.get_delegation_path()]
        total = sum(S.rewards)
        st.session_state.demo_status = (
            f"Episode complete | {S.step_n} steps | Total reward: {total:+.4f}"
        )
        st.session_state.last_called = called
        st.session_state.last_edges = edges
        st.session_state.last_info = info
        st.rerun()

    # ── Metric strip ──────────────────────────────────────
    if S.env is not None:
        mc1, mc2, mc3, mc4 = st.columns(4)
        mc1.metric("Obs Dim", int(S.env.observation_space.shape[0]))
        mc2.metric("Action Dim", int(S.env.action_space.shape[0]))
        mc3.metric("Specialists", S.registry.size)
        mc4.metric("Phase", phase)

    # ── Hero: Robot Orchestrator Widget (full width) ──────
    sec("Orchestrator · Live Delegation View")
    last_info = st.session_state.get("last_info", {})
    render_orchestrator({
        "called": st.session_state.get("last_called", []),
        # Last called specialist while the episode runs; empty once done.
        "active": (st.session_state.get("last_called", []) or [""])[-1]
                  if not S.done else "",
        "edges": st.session_state.get("last_edges", []),
        "task": S.task,
        "step": S.step_n,
        "mode": last_info.get("delegation_mode", "SEQUENTIAL"),
        "done": S.done,
        "reward": sum(S.rewards) if S.rewards else None,
        "phase": int(st.session_state.get("phase_sl", 1)),
    })
    # Thought bubble ticker — robot's last internal monologue
    _thoughts = last_info.get("thoughts") or last_info.get("thought")
    if _thoughts:
        st.markdown(
            f'<div style="font-size:11px;color:#64748b;margin-top:-8px;padding:4px 8px;">'
            f'💭 {_html.escape(str(_thoughts))}</div>',
            unsafe_allow_html=True,
        )

    # ── Three-column secondary row ─────────────────────────
    sc1, sc2, sc3 = st.columns([4, 4, 4])
    with sc1:
        st.plotly_chart(fig_reward_curve(S.rewards), use_container_width=True)
    with sc2:
        last_info = st.session_state.get("last_info", {})
        st.plotly_chart(
            fig_reward_breakdown(last_info.get("reward_components", {})),
            use_container_width=True,
        )
    with sc3:
        sec("Policy Confidence")
        if S.step_entropies:
            st.plotly_chart(
                fig_policy_confidence(
                    S.step_entropies,
                    [h["step"] for h in S.episode_history],
                ),
                use_container_width=True,
            )
        else:
            st.markdown(
                '<div style="color:#334155;font-size:11px;padding:24px;text-align:center;">'
                'Run an episode to see action entropy.</div>',
                unsafe_allow_html=True,
            )

    # ── Step Log (full width) ──────────────────────────────
    sec("Step Log / Action Trace")
    if not S.actions:
        st.markdown(
            '<div style="color:#334155;font-size:12px;padding:16px;text-align:center;">'
            'Waiting… Reset the episode to start.</div>',
            unsafe_allow_html=True,
        )
    else:
        lines = []
        # S.actions holds per-step info dicts (they are read with .get below).
        for i, (inf, r) in enumerate(zip(S.actions, S.rewards)):
            sign = "+" if r >= 0 else ""
            act = inf.get("action_name", "UNKNOWN")
            specs = ", ".join(inf.get("called_specialists", []))
            mode = inf.get("delegation_mode", "")
            # Entropy is appended only for steps that have a recorded value.
            e_str = (f" │ entropy: {S.step_entropies[i]:.3f}"
                     if i < len(S.step_entropies) else "")
            lats = inf.get("specialist_latencies", {})
            lat_str = (
                "\n │ → latency: "
                + ", ".join(f"{k}: {v:.0f}ms" for k, v in lats.items())
            ) if lats else ""
            lines.append(
                f"Step {i+1:>2} │ {act:<22} │ reward: {sign}{r:.4f}{e_str}"
                + (f"\n │ → called: {specs}" if specs else "")
                + (f"\n │ → mode: {mode}" if mode else "")
                + lat_str
            )
        total = sum(S.rewards)
        unique_sp = len(set(sp for h in S.episode_history for sp in h.get("called", [])))
        lines.append(f"{'─'*62}")
        lines.append(
            f"Total reward: {'+' if total>=0 else ''}{total:.4f} │ "
            f"Steps: {len(S.rewards)} │ "
            f"Specialists called: {unique_sp} unique"
        )
        st.code("\n".join(lines), language=None)

    # ── Episode Replay (full width) ────────────────────────
    if S.episode_history:
        st.markdown("---")
        sec("Episode Replay Mode")
        st.caption(
            "Scrub backward through every step of the episode. "
            "Delegation graph, reward breakdown, and action trace all update to that exact state. "
            "100% real data — no re-simulation."
        )
        n_steps = len(S.episode_history)
        # A slider is only drawn when there is more than one step to choose.
        if n_steps > 1:
            replay_step = st.slider(
                "Replay step",
                min_value=1,
                max_value=n_steps,
                value=n_steps,
                step=1,
                key="replay_slider",
                format="Step %d",
            )
        else:
            replay_step = 1
            st.caption("Single-step episode — showing step 1.")
        # The slider is 1-based; the history index is 0-based.
        _render_replay_step(S, replay_step - 1)
1312
-
1313
-
1314
- # ─────────────────────────────────────────────────────────
1315
- # Tab 2 — Specialists
1316
- # ─────────────────────────────────────────────────────────
1317
def tab_specialists():
    """Render the "Specialists" tab.

    Shows, in order: any dynamically spawned specialists stored in
    ``st.session_state["spawned_pool"]``, the specialist roster (from the live
    registry when available, otherwise from the YAML catalog), an on-demand
    capability similarity matrix, and a form that adds a new specialist to
    the registry.

    All specialist-provided text written into ``unsafe_allow_html`` markup is
    HTML-escaped, including ``role``, which can come from the add form.
    """
    # NOTE(review): SOURCE lost its indentation; nesting here is
    # reconstructed from syntax.
    S = _S()

    # Prefer live registry so dynamically-added specialists appear immediately.
    # Fall back to YAML catalog before the environment has been booted.
    if S.registry is not None:
        specialists = S.registry.list_all()
        source_note = None
    else:
        class _SP:
            """Attribute-style view over one catalog dict entry."""

            def __init__(self, d: dict):
                self.id = d["id"]
                self.role = d["role"]
                self.description = d["description"]
                self.complexity_affinity = d["complexity_affinity"]
                self.avg_latency_ms = d["avg_latency_ms"]

        specialists = [_SP(d) for d in _load_catalog()]
        source_note = "Showing YAML catalog — run an episode to load the live registry (includes dynamic additions)."

    # ── Dynamically spawned specialists (accumulated from Output tab runs) ──
    spawned_pool = st.session_state.get("spawned_pool", [])
    if spawned_pool:
        sec(f"⚡ Dynamically Spawned · {len(spawned_pool)} new agent{'s' if len(spawned_pool) != 1 else ''}")
        st.caption(
            "These specialists were auto-created during Output tab runs — "
            "triggered when no existing specialist had sufficient domain coverage (similarity < threshold)."
        )
        # At most four columns; cards wrap round-robin across them.
        pool_cols = st.columns(min(len(spawned_pool), 4))
        for i, sp in enumerate(spawned_pool):
            with pool_cols[i % 4]:
                st.markdown(f"""
<div style="background:rgba(251,191,36,0.06);border:1px solid rgba(251,191,36,0.28);
border-left:3px solid #fbbf24;border-radius:12px;
padding:14px;margin-bottom:10px;">
<div style="font-size:11px;font-weight:700;color:#fbbf24;margin-bottom:5px;">
⚡ {_html.escape(sp['role'])}
</div>
<div style="font-size:10px;color:#475569;margin-bottom:6px;font-style:italic;">
Triggered by: {_html.escape(sp['triggered_by'][:70])}…
</div>
<div style="font-size:11px;color:#64748b;line-height:1.5;">
{_html.escape(sp['description'][:100])}…
</div>
<div style="font-size:10px;color:#334155;margin-top:8px;padding-top:8px;
border-top:1px solid rgba(255,255,255,0.05);">
{sp['avg_latency_ms']} ms &nbsp;·&nbsp; {', '.join(sp.get('complexity_affinity', []))}
</div>
</div>""", unsafe_allow_html=True)
        st.markdown("---")

    n = len(specialists)
    sec(f"Roster — {n} specialist{'s' if n != 1 else ''}, capability-embedded")
    if source_note:
        st.caption(source_note)

    spawned_set = set(S.spawned_specialists) if S.registry is not None else set()

    cols = st.columns(4)
    for i, sp in enumerate(specialists):
        c = SPEC_COLORS.get(sp.id, "#7c3aed")
        is_spawned = sp.id in spawned_set
        border_top = "#fbbf24" if is_spawned else c
        spawn_tag = (
            '<span style="font-size:9px;font-weight:700;color:#fbbf24;'
            'background:rgba(251,191,36,0.1);border:1px solid rgba(251,191,36,0.25);'
            'border-radius:999px;padding:1px 7px;margin-left:6px;">⚡ AUTO-SPAWNED</span>'
            if is_spawned else ""
        )
        # Escape the role: it may be user-entered text from the add form.
        # spawn_tag is trusted static markup and stays unescaped.
        role_html = _html.escape(sp.role)
        with cols[i % 4]:
            st.markdown(f"""
<div style="background:rgba(255,255,255,0.025);border:1px solid {c}22;
border-left:3px solid {border_top};border-radius:12px;
padding:14px;margin-bottom:10px;">
<div style="font-size:11px;font-weight:700;color:{c};margin-bottom:6px;">
{role_html}{spawn_tag}
</div>
<div style="font-size:11px;color:#64748b;line-height:1.5;">
{_html.escape(sp.description[:90])}…
</div>
<div style="font-size:10px;color:#334155;margin-top:8px;padding-top:8px;
border-top:1px solid rgba(255,255,255,0.05);">
{sp.avg_latency_ms} ms &nbsp;·&nbsp; {', '.join(sp.complexity_affinity)}
</div>
</div>""", unsafe_allow_html=True)

    sec("Capability Similarity Matrix")
    if st.button("Load Similarity Matrix", key="sim_btn"):
        with st.spinner("Computing cosine similarity across 384-dim embeddings…"):
            # boot() is called before S.registry is used for the figure.
            S.boot()
        st.plotly_chart(fig_similarity(S.registry), use_container_width=True)

    sec("Add Specialist Dynamically")
    st.caption("New specialists are immediately representable via their 384-dim embedding — no retraining or YAML edits required.")
    c1, c2 = st.columns(2)
    new_id = c1.text_input("ID", placeholder="ml_engineer", key="new_id")
    new_role = c2.text_input("Role", placeholder="ML Engineer", key="new_role")
    new_desc = st.text_area("Description",
                            placeholder="Expert in PyTorch, model training, MLOps pipelines…",
                            height=80, key="new_desc")
    if st.button("Add to Roster", type="primary", key="add_btn"):
        if new_id.strip() and new_role.strip() and new_desc.strip():
            with st.spinner("Encoding specialist embedding…"):
                S.boot()
                S.registry.add_specialist({
                    "id": new_id.strip(), "role": new_role.strip(),
                    "description": new_desc.strip(),
                    "complexity_affinity": ["moderate", "complex"],
                    "avg_latency_ms": 5000,
                })
            st.success(
                f"'{new_id.strip()}' added. "
                "Policy can represent it via 384-dim embedding — no retraining needed."
            )
            st.plotly_chart(fig_similarity(S.registry), use_container_width=True)
        else:
            st.warning("Fill in all three fields.")
1433
-
1434
-
1435
- # ─────────────────────────────────────────────────────────
1436
- # Tab 3 — Training
1437
- # ─────────────────────────────────────────────────────────
1438
def tab_training():
    """Render the "Training" tab.

    Shows a call-out linking to the training Space, a button that downloads
    the latest ``reward_curve.json`` from the model repo into ``ASSETS``, the
    training reward and entropy charts, the three curriculum phase cards, and
    quick-start command snippets.
    """
    sec("Training Progress — Mean Reward per Episode")

    callout_html = (
        '<div style="background:rgba(0,212,255,0.06);border:1px solid rgba(0,212,255,0.20);'
        'border-radius:12px;padding:16px 20px;margin-bottom:18px;">'
        '<div style="font-size:13px;font-weight:700;color:#00d4ff;margin-bottom:6px;">'
        '🔁 Want to run a fresh training run?</div>'
        '<div style="font-size:12px;color:#94a3b8;margin-bottom:10px;">'
        'Open the <strong style="color:#e2e8f0;">Training Space</strong> below, then click '
        '<strong style="color:#e2e8f0;">▶ Start Training</strong>. '
        'When the run completes the new model is pushed to HF Hub and this demo loads it automatically.<br>'
        '<span style="color:#fb923c;font-size:11px;">⚠️ Starting a new run will overwrite the current A100-trained policy.</span>'
        '</div>'
        '<a href="https://huggingface.co/spaces/garvitsachdeva/finalRLEnv" target="_blank" '
        'style="display:inline-block;background:rgba(0,212,255,0.12);border:1px solid rgba(0,212,255,0.35);'
        'color:#00d4ff;padding:7px 18px;border-radius:8px;text-decoration:none;font-size:13px;font-weight:600;">'
        '🚀 Open Training Space →</a>'
        '</div>'
    )
    st.markdown(callout_html, unsafe_allow_html=True)

    fetch_col, _ = st.columns([2, 5])
    if fetch_col.button("📥 Fetch latest curve from HF Hub", key="fetch_curve"):
        # Imports stay inside the try so a missing package is reported in
        # the UI like any other download failure.
        try:
            import shutil
            from huggingface_hub import hf_hub_download
            _tok = os.getenv("HF_TOKEN") or None
            downloaded = hf_hub_download(HF_MODEL_REPO, "reward_curve.json",
                                         token=_tok, force_download=True)
            ASSETS.mkdir(parents=True, exist_ok=True)
            shutil.copy(downloaded, ASSETS / "reward_curve.json")
            st.success("reward_curve.json updated — chart will refresh.")
            st.cache_data.clear()
        except Exception as exc:
            st.error(f"Download failed: {exc}")

    st.plotly_chart(fig_training_curve(), use_container_width=True)

    sec("Policy Entropy — Action Confidence Over Training")
    st.caption(
        "Entropy of the specialist-selection distribution. "
        "High = exploring (early training). Low = confident routing (converged policy)."
    )
    st.plotly_chart(fig_training_entropy(), use_container_width=True)

    sec("Curriculum Phases")
    c1, c2, c3 = st.columns(3)

    def _phase_card(col, color, label, eps, desc):
        # `color` is an "r,g,b" triple reused for background, border and text.
        return col.markdown(
            f'<div style="background:rgba({color},0.04);border:1px solid rgba({color},0.18);'
            f'border-radius:12px;padding:18px;">'
            f'<div style="font-size:10px;font-weight:700;color:rgb({color});text-transform:uppercase;'
            f'letter-spacing:1px;margin-bottom:8px;">{label}</div>'
            f'<div style="font-size:22px;font-weight:700;color:#e2e8f0;margin-bottom:5px;">{eps}</div>'
            f'<div style="font-size:11px;color:#475569;">{desc}</div></div>',
            unsafe_allow_html=True,
        )

    _phase_card(c1, "0,212,255", "Phase 1 · Atomic", "200 episodes",
                "Agent learns basic routing — which single specialist to call.")
    _phase_card(c2, "124,58,237", "Phase 2 · Moderate", "400 episodes",
                "Agent learns multi-specialist coordination and mode selection.")
    _phase_card(c3, "245,158,11", "Phase 3 · Complex/Enterprise", "600 episodes",
                "Full delegation strategy with DAG depth, fallbacks, and latency trade-offs.")

    sec("Quick Start Commands")
    local_col, colab_col = st.columns(2)
    with local_col:
        st.markdown("**Local training**")
        st.code(
            "# Demo mode — no OpenAI key needed\n"
            "cd spindleflow-rl\n"
            "python training/train.py \\\n"
            " --phase 1 --timesteps 50000\n\n"
            "# Monitor in TensorBoard\n"
            "tensorboard --logdir tensorboard_logs/",
            language="bash",
        )
    with colab_col:
        st.markdown("**Google Colab (T4 GPU, free)**")
        st.code(
            "!git clone https://github.com/garvitsachdevaa/kuchbhi\n"
            "%cd kuchbhi\n"
            "!pip install -r requirements.txt sb3-contrib\n\n"
            "# 5k-step demo run\n"
            "%run colab/train_colab.py",
            language="python",
        )
1525
-
1526
-
1527
- # ─────────────────────────────────────────────────────────
1528
- # Tab 4 — Quality Demo
1529
- # ─────────────────────────────────────────────────────────
1530
- def tab_quality():
1531
- results = st.session_state.get("output_results")
1532
- env_obj = st.session_state.get("output_env")
1533
-
1534
- sec("Live Quality Comparison — Generic vs Specialist-Routed")
1535
-
1536
- if results is None:
1537
- st.markdown(
1538
- '<div style="background:rgba(245,158,11,0.05);border:1px solid rgba(245,158,11,0.2);'
1539
- 'border-radius:12px;padding:28px;text-align:center;">'
1540
- '<div style="font-size:13px;color:#fbbf24;font-weight:600;margin-bottom:8px;">'
1541
- 'No Output run yet</div>'
1542
- '<div style="font-size:12px;color:#64748b;">'
1543
- 'Go to the <b>🎯 Output</b> tab, enter a task, and click '
1544
- '"Run Trained Policy" — then return here to generate the quality comparison.'
1545
- '</div></div>',
1546
- unsafe_allow_html=True,
1547
- )
1548
- else:
1549
- task = results["task"]
1550
- spec_results = results["specialist_results"]
1551
- specialist_text = "\n\n".join(
1552
- f"[{sr['id'].upper()}]\n{sr['output'] or ''}"
1553
- for sr in spec_results if sr.get("output")
1554
- ) or "(no specialist output)"
1555
-
1556
- # Task banner
1557
- st.markdown(
1558
- f'<div style="background:rgba(0,212,255,0.04);border:1px solid rgba(0,212,255,0.18);'
1559
- f'border-radius:10px;padding:12px 18px;margin-bottom:16px;">'
1560
- f'<span style="font-size:9px;font-weight:700;color:#475569;text-transform:uppercase;'
1561
- f'letter-spacing:1px;">Comparing outputs for: </span>'
1562
- f'<span style="font-size:12px;color:#e2e8f0;">{_html.escape(task[:140])}</span>'
1563
- f'</div>',
1564
- unsafe_allow_html=True,
1565
- )
1566
-
1567
- comp_data = st.session_state.get("quality_comparison")
1568
- already_computed = comp_data is not None and comp_data.get("task") == task
1569
-
1570
- if not already_computed:
1571
- if st.button("⚡ Generate Quality Comparison", type="primary", key="gen_comp_btn"):
1572
- with st.spinner("Generating generic output + running GPT-4o-mini judge…"):
1573
- generic_text = _generate_generic_output(task)
1574
- registry = env_obj.registry if env_obj else None
1575
-
1576
- gen_t1 = _t1_relevance(task, generic_text, registry) if registry else 5.0
1577
- spec_t1 = _t1_relevance(task, specialist_text, registry) if registry else 7.0
1578
-
1579
- judge = _judge_compare(task, generic_text, specialist_text)
1580
-
1581
- def _pick(key, fallback_g, fallback_s):
1582
- pair = (judge or {}).get(key, [fallback_g, fallback_s])
1583
- return float(pair[0]), float(pair[1])
1584
-
1585
- td_g, td_s = _pick("technical_depth", 5, 7)
1586
- sp_g, sp_s = _pick("specificity", 4, 8)
1587
- ac_g, ac_s = _pick("actionability", 4, 7)
1588
- cv_g, cv_s = _pick("coverage", 5, 8)
1589
-
1590
- gen_scores = {"Task Relevance": gen_t1, "Technical Depth": td_g,
1591
- "Specificity": sp_g, "Actionability": ac_g, "Coverage": cv_g}
1592
- spec_scores = {"Task Relevance": spec_t1, "Technical Depth": td_s,
1593
- "Specificity": sp_s, "Actionability": ac_s, "Coverage": cv_s}
1594
-
1595
- st.session_state.quality_comparison = {
1596
- "task": task,
1597
- "generic": generic_text,
1598
- "specialist": specialist_text,
1599
- "gen_scores": gen_scores,
1600
- "spec_scores": spec_scores,
1601
- }
1602
- st.rerun()
1603
-
1604
- comp_data = st.session_state.get("quality_comparison")
1605
- if comp_data and comp_data.get("task") == task:
1606
- gen_scores = comp_data["gen_scores"]
1607
- spec_scores = comp_data["spec_scores"]
1608
-
1609
- # ── Score summary strip ─────────────────────────────────────
1610
- sec("Score Summary")
1611
- cols = st.columns(len(gen_scores))
1612
- for i, (dim, g_val) in enumerate(gen_scores.items()):
1613
- s_val = spec_scores[dim]
1614
- delta = round(s_val - g_val, 1)
1615
- cols[i].metric(
1616
- dim,
1617
- f"{s_val:.1f} / 10",
1618
- f"{delta:+.1f} vs generic",
1619
- )
1620
-
1621
- # ── Radar chart ─────────────────────────────────────────────
1622
- sec("Quality Radar")
1623
- st.plotly_chart(
1624
- fig_radar_comparison(gen_scores, spec_scores),
1625
- use_container_width=True,
1626
- key="quality_radar",
1627
- )
1628
-
1629
- # ── Side-by-side score bars ──────────────────────────────────
1630
- sec("Per-Dimension Score Breakdown")
1631
- dims = list(gen_scores.keys())
1632
- g_vals = [gen_scores[d] for d in dims]
1633
- s_vals = [spec_scores[d] for d in dims]
1634
- bar_fig = go.Figure()
1635
- bar_fig.add_trace(go.Bar(
1636
- name="Generic", x=dims, y=g_vals,
1637
- marker_color="rgba(239,68,68,0.75)", marker_line_width=0,
1638
- text=[f"{v:.1f}" for v in g_vals], textposition="outside",
1639
- textfont=dict(size=10, color="#94a3b8"),
1640
- ))
1641
- bar_fig.add_trace(go.Bar(
1642
- name="Specialist", x=dims, y=s_vals,
1643
- marker_color="rgba(0,212,255,0.75)", marker_line_width=0,
1644
- text=[f"{v:.1f}" for v in s_vals], textposition="outside",
1645
- textfont=dict(size=10, color="#94a3b8"),
1646
- ))
1647
- bar_fig.update_layout(
1648
- **DARK, **DARK_AXES, height=300, barmode="group",
1649
- legend=dict(bgcolor="rgba(0,0,0,0)", font=dict(color="#94a3b8")),
1650
- )
1651
- bar_fig.update_yaxes(range=[0, 11], gridcolor="rgba(255,255,255,0.05)")
1652
- st.plotly_chart(bar_fig, use_container_width=True, key="quality_bars")
1653
-
1654
- # ── Side-by-side text ────────────────────────────────────────
1655
- sec("Output Text Comparison")
1656
- c1, c2 = st.columns(2)
1657
- with c1:
1658
- st.markdown(
1659
- '<div style="font-size:10px;font-weight:700;color:#ef4444;'
1660
- 'text-transform:uppercase;letter-spacing:1px;margin-bottom:8px;">'
1661
- '✗ Generic Output (No Delegation)</div>',
1662
- unsafe_allow_html=True,
1663
- )
1664
- st.code(comp_data["generic"][:1200], language=None)
1665
- with c2:
1666
- st.markdown(
1667
- '<div style="font-size:10px;font-weight:700;color:#10b981;'
1668
- 'text-transform:uppercase;letter-spacing:1px;margin-bottom:8px;">'
1669
- '✓ Specialist-Routed Output (Trained Policy)</div>',
1670
- unsafe_allow_html=True,
1671
- )
1672
- st.code(comp_data["specialist"][:1200], language=None)
1673
-
1674
- sec("Policy Tuning — Quality vs Latency")
1675
- c1, c2 = st.columns(2)
1676
- with c1:
1677
- st.markdown("""
1678
- <div style="background:rgba(124,58,237,0.05);border:1px solid rgba(124,58,237,0.2);
1679
- border-radius:12px;padding:16px;">
1680
- <div style="font-size:10px;font-weight:700;color:#a78bfa;text-transform:uppercase;
1681
- letter-spacing:1px;margin-bottom:8px;">Quality Policy</div>
1682
- <div style="font-size:12px;color:#64748b;line-height:1.8;">
1683
- 5 specialists &nbsp;·&nbsp; sequential &nbsp;·&nbsp; ~180 s<br>
1684
- <code style="color:#a78bfa;background:rgba(124,58,237,0.12);
1685
- padding:2px 6px;border-radius:4px;">latency_weight = 0.0</code>
1686
- </div>
1687
- </div>""", unsafe_allow_html=True)
1688
- with c2:
1689
- st.markdown("""
1690
- <div style="background:rgba(0,212,255,0.05);border:1px solid rgba(0,212,255,0.2);
1691
- border-radius:12px;padding:16px;">
1692
- <div style="font-size:10px;font-weight:700;color:#00d4ff;text-transform:uppercase;
1693
- letter-spacing:1px;margin-bottom:8px;">Latency Policy</div>
1694
- <div style="font-size:12px;color:#64748b;line-height:1.8;">
1695
- 3 specialists &nbsp;·&nbsp; parallel &nbsp;·&nbsp; ~45 s<br>
1696
- <code style="color:#00d4ff;background:rgba(0,212,255,0.1);
1697
- padding:2px 6px;border-radius:4px;">latency_weight = 0.15</code>
1698
- </div>
1699
- </div>""", unsafe_allow_html=True)
1700
-
1701
-
1702
- # ─────────────────────────────────────────────────────────
1703
- # Tab 5 — Reward Lab
1704
- # ─────────────────────────────────────────────────────────
1705
def tab_reward_lab():
    """Render the Reward Lab tab: weight sliders on the left, live breakdown on the right.

    Five sliders feed an illustrative reward decomposition; the components
    that have no slider are fixed example values. The breakdown chart and the
    summed total are redrawn on every Streamlit rerun.
    """
    sec("Interactive Reward Explorer")
    st.caption("Tune the reward weights and watch each component update live.")

    slider_col, chart_col = st.columns([1, 2], gap="large")

    # (label, max, default, step, widget key) -- every slider starts at 0.0.
    slider_specs = (
        ("Latency Weight", 0.50, 0.05, 0.01, "rl_lw"),
        ("Efficiency Penalty", 0.20, 0.05, 0.01, "rl_ep"),
        ("Failure Penalty", 1.00, 0.30, 0.05, "rl_fp"),
        ("Consistency Bonus", 0.50, 0.10, 0.01, "rl_cw"),
        ("Explanation Bonus", 0.20, 0.05, 0.01, "rl_eb"),
    )
    with slider_col:
        latency_w, efficiency_p, failure_p, consistency_w, explanation_b = [
            st.slider(label, 0.0, upper, default, step, key=widget_key)
            for label, upper, default, step, widget_key in slider_specs
        ]

    # Key order is preserved: it drives both the chart and the summation order.
    components = {
        "quality_delta": 0.42,
        "efficiency_penalty": -efficiency_p * 2,
        "failure_penalty": -failure_p * 0.3,
        "recovery_bonus": 0.08,
        "conflict_penalty": -0.05,
        "conflict_bonus": 0.03,
        "consistency_bonus": consistency_w * 0.6,
        "latency_penalty": -latency_w * 0.25,
        "explanation_bonus": explanation_b,
    }
    total = sum(components.values())
    if total >= 0:
        prefix = "+"
    else:
        prefix = ""

    with chart_col:
        st.plotly_chart(fig_reward_breakdown(components), use_container_width=True)
        summary_html = (
            f'<div style="background:rgba(0,212,255,0.05);border:1px solid rgba(0,212,255,0.18);'
            f'border-radius:10px;padding:14px 18px;font-size:13px;color:#94a3b8;">'
            f'Estimated total reward: '
            f'<span style="color:#00d4ff;font-weight:700;font-size:20px;">{prefix}{total:.3f}</span>'
            f'</div>'
        )
        st.markdown(summary_html, unsafe_allow_html=True)
1740
-
1741
-
1742
- # ─────────────────────────────────────────────────────────
1743
- # Tab 6 — Architecture
1744
- # ─────────────────────────────────────────────────────────
1745
def tab_architecture():
    """Render the Architecture tab: static reference tables for the RL setup.

    Shows the observation/action space layouts for a six-specialist roster,
    short policy / reward-tier / safety summaries, and the reward formula.
    """
    roster_size = 6
    obs_dims = EpisodeState.observation_dim(roster_size)
    # 1 meta-action + 6 specialist logits + 1 mode + 4 mode parameters.
    action_dims = roster_size + 6

    obs_col, act_col = st.columns(2)
    with obs_col:
        sec(f"Observation Space ({obs_dims:,} dims)")
        st.markdown("""
| Dims | Component |
|-----:|-----------|
| 384 | Task embedding (all-MiniLM-L6-v2) |
| 2304 | Roster embeddings (6 × 384) |
| 2304 | Called embeddings (6 × 384) |
| 384 | Scratchpad embedding |
| 100 | Delegation graph adjacency (10 × 10) |
| 6 | Called-specialist mask |
| 8 | Scalar features |
""")
    with act_col:
        sec(f"Action Space ({action_dims}-dim Box)")
        st.markdown("""
| Index | Component |
|--------|-----------|
| [0] | Meta-action (STOP / CALL / PARALLEL…) |
| [1:7] | Specialist selection logits (multi-hot) |
| [7] | Delegation mode (SEQ / PAR / FAN-OUT…) |
| [8:12] | Mode parameters (rounds, threshold…) |
""")

    policy_col, reward_col, safety_col = st.columns(3)
    with policy_col:
        sec("Policy")
        st.markdown("""
- **LSTM PPO** (RecurrentPPO)
- MlpLstmPolicy
- Hidden: 256 · 1 layer
- POMDP-safe via LSTM state
- 4 factored action heads
""")
    with reward_col:
        sec("Tiered Reward")
        st.markdown("""
- **T0** — Structural heuristics
- **T1** — Cosine embedding sim
- **T2** — GPT-4o-mini judge
- **T3** — Full judge (checkpoints)
- Episode-level tier lock
""")
    with safety_col:
        sec("Safety")
        st.markdown("""
- DAG cycle detection (DFS)
- Max delegation depth: 2
- Scratchpad sandbox isolation
- Injection sanitization
- Action masking (DAG)
""")

    sec("Reward Function")
    reward_formula = """total_reward = (
    quality_delta # specialist_score − baseline (same tier)
    − efficiency_penalty # 0.05 × max(0, n_called − expected)
    − failure_penalty # 0.3 per timeout, 0.2 per error
    + recovery_bonus # +0.1 if fallback succeeded
    − conflict_penalty # 0.1 per unresolved conflict
    + conflict_bonus # 0.05 per resolved conflict
    + consistency_bonus # 0.1 × Dirichlet-prior path score
    − latency_penalty # latency_weight × overage_fraction
    + explanation_bonus # 0.05 if delegation is auditable
)"""
    st.code(reward_formula, language="python")
1815
-
1816
-
1817
- # ─────────────────────────────────────────────────────────
1818
- # Tab 7 — Output (Trained Policy)
1819
- # ─────────────────────────────────────────────────────────
1820
def tab_output():
    """Run the trained LSTM PPO policy on a custom task and show every specialist's output.

    The tab has two phases driven by Streamlit reruns:

    1. When the run button is pressed, the trained model is loaded, one
       episode (at most 15 steps) is played in a ``SpindleFlowEnv`` whose task
       sampler is overridden with the user's text, the outcome is stored in
       ``st.session_state`` and a rerun is requested.
    2. On every run, whatever is stored in ``st.session_state.output_results``
       is rendered: task banner, metrics, orchestrator widget, delegation
       graph, per-specialist outputs and the concatenated synthesis.

    All values interpolated into ``unsafe_allow_html`` markup are HTML-escaped,
    because specialist ids and statuses come from the environment rather than
    from constants in this file.
    """
    hero()
    st.markdown(
        '<div style="font-size:12px;color:#64748b;margin-bottom:16px;">'
        'Enter any software engineering task. The trained LSTM PPO policy decides which '
        'specialists to delegate to — each specialist\'s individual output and the collective '
        'synthesis are shown below.</div>',
        unsafe_allow_html=True,
    )

    col_input, col_ctrl = st.columns([3, 1], gap="large")
    with col_input:
        sec("Task")
        task_input = st.text_area(
            "Task description",
            height=110,
            key="output_task_input",
            placeholder=(
                "Build a real-time collaborative code review tool with inline comments, "
                "role-based access control, GitHub webhook integration, and CI/CD pipeline "
                "status display. Include authentication with OAuth2."
            ),
        )
    with col_ctrl:
        sec("Config")
        out_phase = st.selectbox("Curriculum phase", [1, 2, 3], index=1, key="output_phase")
        st.markdown('<div style="height:8px"></div>', unsafe_allow_html=True)
        run_btn = st.button(
            "🚀 Run Trained Policy",
            type="primary",
            use_container_width=True,
            key="output_run_btn",
        )

    if run_btn:
        _task = (task_input or "").strip()
        if not _task:
            st.warning("Please enter a task description.")
            return

        with st.spinner("Loading trained model from HF Hub…"):
            model, obs_mean, obs_var, clip_obs, model_err = _load_trained_model(HF_MODEL_REPO)
        if model_err:
            st.error(f"Model load failed: {model_err}")
            return

        st.success("Trained policy loaded ✓")

        with st.spinner("Running episode with trained policy…"):
            try:
                env = SpindleFlowEnv(
                    config_path=CONFIG, catalog_path=CATALOG,
                    use_real_spindleflow=False, phase=int(out_phase),
                )
                # Inject custom task so the env uses the user's input
                env.task_bank.sample = lambda: _task

                obs, info = env.reset()
                task_used = info.get("task", _task)

                lstm_states = None
                episode_starts = np.array([True])
                done = False
                rewards: list[float] = []

                MIN_SPECIALISTS = 4  # suppress STOP until this many specialists called

                for _ in range(15):
                    if done:
                        break
                    obs_arr = obs[np.newaxis, :].copy().astype(np.float32)
                    # Apply the saved observation normalisation statistics,
                    # when the loader returned them.
                    if obs_mean is not None and obs_var is not None:
                        obs_arr = np.clip(
                            (obs_arr - obs_mean) / np.sqrt(obs_var + 1e-8),
                            -clip_obs, clip_obs,
                        )
                    action_batch, lstm_states = model.predict(
                        obs_arr,
                        state=lstm_states,
                        episode_start=episode_starts,
                        deterministic=True,
                    )
                    action = action_batch[0].copy()
                    called_set = set(env.called_ids)
                    if len(called_set) < MIN_SPECIALISTS:
                        # The policy may want to STOP early; when it does, its
                        # specialist-selection logits are all low/negative so
                        # simply zeroing action[0] still produces garbage selection.
                        # Fix: build a fresh action that directly picks the first
                        # uncalled specialist with a hard positive logit (1.0).
                        roster = env.active_specialist_ids
                        uncalled = [sid for sid in roster if sid not in called_set]
                        if uncalled:
                            action = np.zeros(env.action_space.shape, dtype=np.float32)
                            action[0] = 0.0  # MetaAction.CALL_SPECIALIST
                            idx = roster.index(uncalled[0])
                            if 1 + idx < len(action):
                                action[1 + idx] = 1.0
                    obs, r, term, trunc, _ = env.step(action)
                    rewards.append(float(r))
                    done = term or trunc
                    episode_starts = np.array([done])

                called = list(env.called_ids)
                edges = [(e.caller_id, e.callee_id)
                         for e in env.delegation_graph.get_delegation_path()]
                spawned = list(getattr(env, "spawned_this_episode", []))

                st.session_state.output_results = {
                    "task": task_used,
                    "rewards": rewards,
                    "called": called,
                    "edges": edges,
                    "specialist_results": [
                        {
                            "id": sr.specialist_id,
                            "output": sr.output,
                            "status": sr.status,
                            "latency_ms": sr.latency_ms,
                        }
                        for sr in env.specialist_results
                    ],
                    "spawned": spawned,
                }
                # Keep env alive for delegation-graph rendering
                st.session_state.output_env = env

                # Persist spawned specialists to shared pool for Specialists tab
                if "spawned_pool" not in st.session_state:
                    st.session_state.spawned_pool = []
                existing_ids = {sp["id"] for sp in st.session_state.spawned_pool}
                for sid in spawned:
                    if sid in existing_ids:
                        continue
                    sp_obj = env.registry.get(sid)
                    if sp_obj:
                        st.session_state.spawned_pool.append({
                            "id": sid,
                            "role": sp_obj.role,
                            "description": sp_obj.description,
                            "complexity_affinity": list(sp_obj.complexity_affinity),
                            "avg_latency_ms": sp_obj.avg_latency_ms,
                            "triggered_by": task_used[:120],
                        })
                        # Track the id so a repeat within `spawned` is not
                        # appended to the pool a second time.
                        existing_ids.add(sid)

            except Exception as exc:
                import traceback
                st.error(f"Episode failed: {exc}")
                st.code(traceback.format_exc(), language=None)
                return

        # Rerun so the display section below renders from session state.
        st.rerun()

    # ── Display results ────────────────────────────────────────────────
    results = st.session_state.get("output_results")
    env_obj = st.session_state.get("output_env")

    if results is None:
        st.markdown(
            '<div style="color:#334155;font-size:12px;padding:40px;text-align:center;">'
            'Enter a task and click "Run Trained Policy" to see delegation and specialist outputs.'
            '</div>',
            unsafe_allow_html=True,
        )
        return

    # Task banner
    st.markdown(
        f'<div style="background:rgba(0,212,255,0.04);'
        f'border:1px solid rgba(0,212,255,0.18);border-radius:10px;'
        f'padding:14px 18px;margin:10px 0 16px;">'
        f'<div style="font-size:9px;font-weight:700;color:#475569;'
        f'text-transform:uppercase;letter-spacing:1px;margin-bottom:5px;">Task</div>'
        f'<div style="font-size:13px;color:#e2e8f0;">{_html.escape(results["task"])}</div>'
        f'</div>',
        unsafe_allow_html=True,
    )

    # Metrics strip
    total_r = sum(results["rewards"])
    mc1, mc2, mc3, mc4 = st.columns(4)
    mc1.metric("Total Reward", f"{total_r:+.3f}")
    mc2.metric("Steps", len(results["rewards"]))
    mc3.metric("Specialists Called", len(results["called"]))
    mc4.metric("Auto-Spawned", len(results["spawned"]))

    # Orchestrator widget
    sec("Orchestrator · Delegation Visualization")
    render_orchestrator({
        "called": results["called"],
        "active": "",
        "edges": results["edges"],
        "task": results["task"],
        "step": len(results["rewards"]),
        "mode": "SEQUENTIAL",
        "done": True,
        "reward": sum(results["rewards"]),
        "phase": int(st.session_state.get("output_phase", 2)),
        "spawned": results["spawned"],
    })

    # Delegation graph
    sec("Delegation Graph")
    if env_obj is not None:
        # Lightweight stand-in exposing the attributes set below to
        # fig_delegation_graph, built from the environment kept in session state.
        class _GraphProxy:
            registry = env_obj.registry
            spawned_specialists = results["spawned"]
            env = env_obj

        st.plotly_chart(
            fig_delegation_graph(
                _GraphProxy(),
                results["called"],
                results["edges"],
                highlight_latest=False,
                spawned_ids=results["spawned"],
            ),
            use_container_width=True,
            key="output_dag",
        )

    # Auto-spawn alert
    if results["spawned"]:
        # Escaped: spawned ids are generated at runtime, not constants.
        spawned_html = _html.escape(", ".join(str(s) for s in results["spawned"]))
        st.markdown(
            '<div style="background:rgba(251,191,36,0.06);'
            'border:1px solid rgba(251,191,36,0.22);border-radius:10px;'
            'padding:10px 16px;margin:8px 0;">'
            '<span style="font-size:10px;font-weight:700;color:#fbbf24;'
            'text-transform:uppercase;letter-spacing:1px;">⚡ Auto-Spawned: </span>'
            '<span style="font-size:12px;color:#e2e8f0;">'
            + spawned_html
            + '</span></div>',
            unsafe_allow_html=True,
        )

    # Individual specialist outputs
    spec_results = results["specialist_results"]
    sec(f"Individual Specialist Outputs · {len(spec_results)} called")

    if not spec_results:
        st.markdown(
            '<div style="color:#475569;font-size:12px;padding:16px;'
            'background:rgba(0,0,0,0.2);border-radius:8px;">'
            'The policy issued STOP without delegating to any specialists.</div>',
            unsafe_allow_html=True,
        )
    else:
        for sr in spec_results:
            sid = sr["id"]
            color = SPEC_COLORS.get(sid, "#7c3aed")
            ok_clr = "#10b981" if sr["status"] == "success" else "#ef4444"
            # A missing or None latency would break the ":.0f" format below.
            lat = sr.get("latency_ms") or 0
            label = (
                f"🤖 {sid.replace('_', ' ').title()}"
                f" · {sr['status']} · {lat:.0f} ms"
            )
            # str() first so a non-string status still renders, as it did
            # when interpolated directly into the f-string.
            safe_sid = _html.escape(str(sid))
            safe_status = _html.escape(str(sr["status"]))
            with st.expander(label, expanded=True):
                st.markdown(
                    f'<div style="border-left:3px solid {color};'
                    f'padding:4px 0 4px 12px;margin-bottom:8px;">'
                    f'<span style="font-size:10px;color:{color};font-weight:700;">{safe_sid}</span>'
                    f'<span style="font-size:10px;color:#475569;"> · status: </span>'
                    f'<span style="font-size:10px;color:{ok_clr};">{safe_status}</span>'
                    f'<span style="font-size:10px;color:#475569;"> · {lat:.0f} ms</span>'
                    f'</div>',
                    unsafe_allow_html=True,
                )
                st.code(sr["output"] or "(no output)", language=None)

    # Synthesized / collective output
    sec("Synthesized Output · Collective Response")
    st.caption("All specialist outputs combined — this is what the orchestrator received.")
    if spec_results:
        parts = [
            f"{'─'*52}\n[{sr['id'].upper()}]\n{'─'*52}\n{sr['output'] or '(empty)'}"
            for sr in spec_results
        ]
        synthesis = "\n\n".join(parts)
    else:
        synthesis = "(no specialists called — policy chose STOP on first step)"
    st.code(synthesis, language=None)
2101
-
2102
-
2103
- # ─────────────────────────────────────────────────────────
2104
- # Entry point
2105
- # ─────────────────────────────────────────────────────────
2106
def main():
    """Render the dashboard: global CSS, the live-stats strip, then one tab per page."""
    inject_css()
    render_live_stats(_S())

    # (tab title, render function) in display order.
    pages = (
        ("🎯 Output", tab_output),
        ("⚡ Training Interface Example", tab_live_demo),
        ("🤖 Specialists", tab_specialists),
        ("📈 Training", tab_training),
        ("🔍 Quality Demo", tab_quality),
        ("🧪 Reward Lab", tab_reward_lab),
        ("🏗 Architecture", tab_architecture),
    )
    containers = st.tabs([title for title, _ in pages])
    for container, (_, render_page) in zip(containers, pages):
        with container:
            render_page()


# Guard allows safe imports for testing without triggering the UI.
# Streamlit runs scripts with __name__ == "__main__".
if __name__ == "__main__":
    main()
 
1
"""Entry-point wrapper: load demo/streamlit_app.py as a module and run its main()."""
import importlib.util
import os
import sys
import traceback
from pathlib import Path

import streamlit as st

root = Path(__file__).resolve().parent
# Streamlit re-executes this script on every interaction; avoid stacking
# duplicate entries on sys.path each time.
if str(root) not in sys.path:
    sys.path.insert(0, str(root))
os.chdir(str(root))

try:
    demo_file = root / "demo" / "streamlit_app.py"
    spec = importlib.util.spec_from_file_location("spindleflow_demo", str(demo_file))
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create an import spec for {demo_file}")
    mod = importlib.util.module_from_spec(spec)
    mod.__file__ = str(demo_file)  # demo's own sys.path logic resolves correctly
    sys.modules["spindleflow_demo"] = mod
    spec.loader.exec_module(mod)  # runs demo/streamlit_app.py in its own context
    mod.main()
except SystemExit:
    pass
except Exception as e:
    # Deliberately not BaseException: Streamlit implements st.rerun() and
    # st.stop() by raising control-flow exceptions derived from BaseException,
    # and those must propagate to the script runner instead of being shown
    # as a load failure. KeyboardInterrupt is likewise left alone.
    st.error(f"SpindleFlow failed to load: {e}")
    st.code(traceback.format_exc())