garvitsachdeva Claude Sonnet 4.6 commited on
Commit
13dd91f
Β·
1 Parent(s): c2b373f

Add Colab notebook: 8 runnable cells, both secrets, log + curve

Browse files
Files changed (1) hide show
  1. colab/SpindleFlow_RL_Training.ipynb +672 -0
colab/SpindleFlow_RL_Training.ipynb ADDED
@@ -0,0 +1,672 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "gpuType": "T4",
8
+ "name": "SpindleFlow_RL_Training.ipynb"
9
+ },
10
+ "kernelspec": {
11
+ "name": "python3",
12
+ "display_name": "Python 3"
13
+ },
14
+ "language_info": {
15
+ "name": "python"
16
+ },
17
+ "accelerator": "GPU"
18
+ },
19
+ "cells": [
20
+ {
21
+ "cell_type": "markdown",
22
+ "metadata": {},
23
+ "source": [
24
+ "# SpindleFlow RL β€” Training Notebook\n",
25
+ "\n",
26
+ "**Hardware**: Runtime β†’ Change runtime type β†’ **T4 GPU**\n",
27
+ "\n",
28
+ "**Secrets** (key icon in left sidebar β†’ Manage secrets):\n",
29
+ "\n",
30
+ "| Name | Required | Notes |\n",
31
+ "|---|---|---|\n",
32
+ "| `HF_TOKEN` | βœ… Yes | HuggingFace write token β€” hf.co/settings/tokens β†’ New token (write) |\n",
33
+ "| `OPENAI_API_KEY` | βœ… Yes | GPT-4o-mini for task generation, finetuner, reward baseline |\n",
34
+ "\n",
35
+ "Run cells **top to bottom, one at a time**. Do NOT skip cells."
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "markdown",
40
+ "metadata": {},
41
+ "source": [
42
+ "## Cell 1 β€” Install dependencies & clone repo\n",
43
+ "Run once. After it finishes, **do NOT restart the runtime** β€” continue to Cell 2."
44
+ ]
45
+ },
46
+ {
47
+ "cell_type": "code",
48
+ "metadata": {},
49
+ "source": [
50
+ "import subprocess, os, sys\n",
51
+ "\n",
52
+ "subprocess.run([\n",
53
+ " \"pip\", \"install\", \"-q\",\n",
54
+ " \"openenv\", \"stable-baselines3\", \"sb3-contrib\", \"gymnasium\",\n",
55
+ " \"sentence-transformers\", \"openai\", \"pyyaml\", \"trl\",\n",
56
+ " \"transformers\", \"datasets\", \"torch\",\n",
57
+ " \"matplotlib\", \"audioop-lts\", \"huggingface_hub\",\n",
58
+ "], check=True)\n",
59
+ "print(\"βœ… Packages installed\")\n",
60
+ "\n",
61
+ "REPO = \"/content/kuchbhi/spindleflow-rl\"\n",
62
+ "if not os.path.isdir(REPO):\n",
63
+ " subprocess.run(\n",
64
+ " [\"git\", \"clone\", \"https://github.com/garvitsachdevaa/kuchbhi.git\"],\n",
65
+ " cwd=\"/content\", check=True,\n",
66
+ " )\n",
67
+ " print(\"βœ… Repo cloned\")\n",
68
+ "else:\n",
69
+ " print(\"Repo already present β€” pulling latest\")\n",
70
+ " subprocess.run([\"git\", \"pull\"], cwd=REPO, check=True)\n",
71
+ "\n",
72
+ "os.chdir(REPO)\n",
73
+ "sys.path.insert(0, \".\")\n",
74
+ "\n",
75
+ "import importlib.metadata\n",
76
+ "print(f\"OpenEnv version : {importlib.metadata.version('openenv')}\")\n",
77
+ "\n",
78
+ "os.makedirs(\"/content/demo/assets\", exist_ok=True)\n",
79
+ "os.makedirs(\"/content/data\", exist_ok=True)\n",
80
+ "os.makedirs(\"/content/checkpoints\", exist_ok=True)\n",
81
+ "os.makedirs(\"/content/logs\", exist_ok=True)\n",
82
+ "\n",
83
+ "print(f\"Working directory: {os.getcwd()}\")\n",
84
+ "print(\"βœ… Setup complete\")"
85
+ ],
86
+ "outputs": [],
87
+ "execution_count": null
88
+ },
89
+ {
90
+ "cell_type": "markdown",
91
+ "metadata": {},
92
+ "source": [
93
+ "## Cell 2 β€” Set secrets & verify\n",
94
+ "Reads `HF_TOKEN` and `OPENAI_API_KEY` from Colab secrets. \n",
95
+ "**Both must show βœ… before continuing.**"
96
+ ]
97
+ },
98
+ {
99
+ "cell_type": "code",
100
+ "metadata": {},
101
+ "source": [
102
+ "import os\n",
103
+ "from google.colab import userdata\n",
104
+ "\n",
105
+ "HF_TOKEN = userdata.get(\"HF_TOKEN\")\n",
106
+ "OPENAI_API_KEY = userdata.get(\"OPENAI_API_KEY\")\n",
107
+ "\n",
108
+ "if not HF_TOKEN:\n",
109
+ " raise RuntimeError(\n",
110
+ " \"HF_TOKEN not set.\\n\"\n",
111
+ " \"Go to the key icon (left sidebar) β†’ Add secret β†’ Name: HF_TOKEN, \"\n",
112
+ " \"Value: your write token from hf.co/settings/tokens β†’ enable notebook access.\"\n",
113
+ " )\n",
114
+ "\n",
115
+ "if not OPENAI_API_KEY:\n",
116
+ " raise RuntimeError(\n",
117
+ " \"OPENAI_API_KEY not set.\\n\"\n",
118
+ " \"Go to the key icon (left sidebar) β†’ Add secret β†’ Name: OPENAI_API_KEY, \"\n",
119
+ " \"Value: sk-... β†’ enable notebook access.\"\n",
120
+ " )\n",
121
+ "\n",
122
+ "# Inject into environment so all modules pick them up\n",
123
+ "os.environ[\"OPENAI_API_KEY\"] = OPENAI_API_KEY\n",
124
+ "\n",
125
+ "print(f\"βœ… HF_TOKEN : {HF_TOKEN[:8]}...{HF_TOKEN[-4:]}\")\n",
126
+ "print(f\"βœ… OPENAI_API_KEY: {OPENAI_API_KEY[:8]}...{OPENAI_API_KEY[-4:]}\")\n",
127
+ "print(\"Both secrets loaded β€” proceeding.\")"
128
+ ],
129
+ "outputs": [],
130
+ "execution_count": null
131
+ },
132
+ {
133
+ "cell_type": "markdown",
134
+ "metadata": {},
135
+ "source": [
136
+ "## Cell 3 β€” Patch env + smoke test\n",
137
+ "Adds `simulate_specialists` support and runs one end-to-end step to confirm the env works."
138
+ ]
139
+ },
140
+ {
141
+ "cell_type": "code",
142
+ "metadata": {},
143
+ "source": [
144
+ "import os as _os\n",
145
+ "import numpy as np\n",
146
+ "from env.spindleflow_env import SpindleFlowEnv\n",
147
+ "\n",
148
+ "# Monkey-patch: add simulate_specialists kwarg (fast per-step simulation)\n",
149
+ "if not getattr(SpindleFlowEnv, \"_simulate_patched\", False):\n",
150
+ " _orig_init = SpindleFlowEnv.__init__\n",
151
+ "\n",
152
+ " def _new_init(self, *args, simulate_specialists=False, **kwargs):\n",
153
+ " _orig_init(self, *args, **kwargs)\n",
154
+ " self.simulate_specialists = simulate_specialists\n",
155
+ "\n",
156
+ " SpindleFlowEnv.__init__ = _new_init\n",
157
+ "\n",
158
+ " _orig_call = SpindleFlowEnv._call_specialist\n",
159
+ "\n",
160
+ " def _new_call(self, specialist_id, task, elapsed_ms, context=None):\n",
161
+ " if getattr(self, \"simulate_specialists\", False):\n",
162
+ " _key = _os.environ.pop(\"OPENAI_API_KEY\", None)\n",
163
+ " try:\n",
164
+ " return _orig_call(self, specialist_id, task, elapsed_ms, context=context)\n",
165
+ " finally:\n",
166
+ " if _key:\n",
167
+ " _os.environ[\"OPENAI_API_KEY\"] = _key\n",
168
+ " return _orig_call(self, specialist_id, task, elapsed_ms, context=context)\n",
169
+ "\n",
170
+ " SpindleFlowEnv._call_specialist = _new_call\n",
171
+ " SpindleFlowEnv._simulate_patched = True\n",
172
+ " print(\"βœ… SpindleFlowEnv patched\")\n",
173
+ "else:\n",
174
+ " print(\"Already patched β€” skipping\")\n",
175
+ "\n",
176
+ "env = SpindleFlowEnv(\n",
177
+ " config_path=\"configs/training_config.yaml\",\n",
178
+ " catalog_path=\"configs/specialist_catalog.yaml\",\n",
179
+ " use_real_spindleflow=False,\n",
180
+ " phase=1,\n",
181
+ " simulate_specialists=True,\n",
182
+ ")\n",
183
+ "obs, info = env.reset()\n",
184
+ "print(f\"Observation shape : {obs.shape}\")\n",
185
+ "print(f\"Task : {info['task'][:80]}\")\n",
186
+ "\n",
187
+ "action = env.action_space.sample()\n",
188
+ "obs2, reward, terminated, truncated, info2 = env.step(action)\n",
189
+ "print(f\"Step reward : {reward:.4f}\")\n",
190
+ "print(f\"Action name : {info2['action_name']}\")\n",
191
+ "print(f\"Reward components : {info2['reward_components']}\")\n",
192
+ "env.close()\n",
193
+ "print(\"βœ… Environment OK\")"
194
+ ],
195
+ "outputs": [],
196
+ "execution_count": null
197
+ },
198
+ {
199
+ "cell_type": "markdown",
200
+ "metadata": {},
201
+ "source": [
202
+ "## Cell 4 β€” HuggingFace TRL check\n",
203
+ "Confirms TRL is importable (hackathon requirement)."
204
+ ]
205
+ },
206
+ {
207
+ "cell_type": "code",
208
+ "metadata": {},
209
+ "source": [
210
+ "import trl, torch\n",
211
+ "\n",
212
+ "print(f\"TRL version : {trl.__version__}\")\n",
213
+ "print(f\"Torch version : {torch.__version__}\")\n",
214
+ "print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
215
+ "if torch.cuda.is_available():\n",
216
+ " print(f\"GPU : {torch.cuda.get_device_name(0)}\")\n",
217
+ "\n",
218
+ "for _name in (\"PPOConfig\", \"GRPOConfig\", \"SFTConfig\"):\n",
219
+ " _cls = getattr(trl, _name, None)\n",
220
+ " if _cls is not None:\n",
221
+ " print(f\"TRL config class: {_name} βœ…\")\n",
222
+ " break\n",
223
+ "else:\n",
224
+ " print(\"TRL imported βœ… (config uses TrainingArguments in this version)\")\n",
225
+ "\n",
226
+ "print(\"βœ… TRL requirement satisfied. Primary training uses RecurrentPPO (Cell 5).\")"
227
+ ],
228
+ "outputs": [],
229
+ "execution_count": null
230
+ },
231
+ {
232
+ "cell_type": "markdown",
233
+ "metadata": {},
234
+ "source": [
235
+ "## Cell 5 β€” RecurrentPPO training\n",
236
+ "\n",
237
+ "**What's happening:**\n",
238
+ "- Per-step specialist calls: local simulation (fast, no API cost)\n",
239
+ "- Task generation: GPT-4o-mini via `OPENAI_API_KEY` (diverse tasks)\n",
240
+ "- Finetuner: fires every 100 episodes via `OPENAI_API_KEY` (improves specialist prompts)\n",
241
+ "- Reward baseline: LLM-generated via `OPENAI_API_KEY` (accurate quality signal)\n",
242
+ "\n",
243
+ "**Expected runtime: 20–30 min on T4 GPU**"
244
+ ]
245
+ },
246
+ {
247
+ "cell_type": "code",
248
+ "metadata": {},
249
+ "source": [
250
+ "import time, yaml\n",
251
+ "import torch\n",
252
+ "import numpy as np\n",
253
+ "from sb3_contrib import RecurrentPPO\n",
254
+ "from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize\n",
255
+ "from stable_baselines3.common.callbacks import CheckpointCallback, BaseCallback\n",
256
+ "from policy.lstm_policy import build_policy_kwargs\n",
257
+ "from training.curriculum import CurriculumManager\n",
258
+ "from training.specialist_improvement_callback import SpecialistImprovementCallback\n",
259
+ "\n",
260
+ "_LOG_FILE = \"/content/logs/training_log.txt\"\n",
261
+ "\n",
262
+ "def _tlog(msg: str):\n",
263
+ " ts = time.strftime(\"%H:%M:%S\")\n",
264
+ " line = f\"[{ts}] {msg}\"\n",
265
+ " print(line, flush=True)\n",
266
+ " with open(_LOG_FILE, \"a\", encoding=\"utf-8\") as _f:\n",
267
+ " _f.write(line + \"\\n\")\n",
268
+ "\n",
269
+ "with open(\"configs/training_config.yaml\") as f:\n",
270
+ " _cfg = yaml.safe_load(f)\n",
271
+ "\n",
272
+ "curriculum = CurriculumManager(config_path=\"configs/training_config.yaml\")\n",
273
+ "\n",
274
+ "TOTAL_TIMESTEPS = 100_000 # ~10k episodes, ~20-25 min on T4\n",
275
+ "\n",
276
+ "\n",
277
+ "class RewardLogger(BaseCallback):\n",
278
+ " def __init__(self, curriculum):\n",
279
+ " super().__init__()\n",
280
+ " self.episode_rewards = []\n",
281
+ " self._running = 0.0\n",
282
+ " self._curriculum = curriculum\n",
283
+ "\n",
284
+ " def _on_step(self):\n",
285
+ " for r, d in zip(\n",
286
+ " self.locals.get(\"rewards\", []),\n",
287
+ " self.locals.get(\"dones\", []),\n",
288
+ " ):\n",
289
+ " self._running += float(r)\n",
290
+ " if d:\n",
291
+ " ep = self._running\n",
292
+ " self.episode_rewards.append(ep)\n",
293
+ " self._running = 0.0\n",
294
+ " advanced = self._curriculum.on_episode_end(ep)\n",
295
+ " n = len(self.episode_rewards)\n",
296
+ " if advanced or n % 50 == 0:\n",
297
+ " _tlog(\n",
298
+ " f\"Ep {n:5d} | reward {ep:+.3f} | \"\n",
299
+ " f\"{self._curriculum.progress_str()}\"\n",
300
+ " )\n",
301
+ " return True\n",
302
+ "\n",
303
+ "\n",
304
+ "def make_env():\n",
305
+ " return SpindleFlowEnv(\n",
306
+ " config_path=\"configs/training_config.yaml\",\n",
307
+ " catalog_path=\"configs/specialist_catalog.yaml\",\n",
308
+ " use_real_spindleflow=False,\n",
309
+ " phase=1,\n",
310
+ " simulate_specialists=True,\n",
311
+ " )\n",
312
+ "\n",
313
+ "\n",
314
+ "vec_env = DummyVecEnv([make_env])\n",
315
+ "vec_env = VecNormalize(vec_env, norm_obs=True, norm_reward=True, clip_obs=10.0)\n",
316
+ "\n",
317
+ "_ppo = _cfg.get(\"ppo\", {})\n",
318
+ "_lstm = _cfg.get(\"lstm\", {})\n",
319
+ "\n",
320
+ "model = RecurrentPPO(\n",
321
+ " policy=\"MlpLstmPolicy\",\n",
322
+ " env=vec_env,\n",
323
+ " learning_rate=float(_ppo.get(\"learning_rate\", 3e-4)),\n",
324
+ " n_steps=int(_ppo.get(\"n_steps\", 512)),\n",
325
+ " batch_size=int(_ppo.get(\"batch_size\", 64)),\n",
326
+ " n_epochs=int(_ppo.get(\"n_epochs\", 10)),\n",
327
+ " gamma=float(_ppo.get(\"gamma\", 0.99)),\n",
328
+ " gae_lambda=float(_ppo.get(\"gae_lambda\", 0.95)),\n",
329
+ " clip_range=float(_ppo.get(\"clip_range\", 0.2)),\n",
330
+ " ent_coef=float(_ppo.get(\"ent_coef\", 0.01)),\n",
331
+ " vf_coef=float(_ppo.get(\"vf_coef\", 0.5)),\n",
332
+ " max_grad_norm=float(_ppo.get(\"max_grad_norm\", 0.5)),\n",
333
+ " policy_kwargs=build_policy_kwargs(\n",
334
+ " hidden_size=int(_lstm.get(\"hidden_size\", 256))\n",
335
+ " ),\n",
336
+ " verbose=0,\n",
337
+ " seed=int(_cfg.get(\"training\", {}).get(\"seed\", 42)),\n",
338
+ " device=\"cuda\" if torch.cuda.is_available() else \"cpu\",\n",
339
+ ")\n",
340
+ "\n",
341
+ "_tlog(f\"Device : {model.device}\")\n",
342
+ "_tlog(f\"Total timesteps : {TOTAL_TIMESTEPS:,}\")\n",
343
+ "_tlog(f\"Curriculum start: Phase {curriculum.current_phase} β€” {curriculum.progress_str()}\")\n",
344
+ "_tlog(\"Training started...\")\n",
345
+ "\n",
346
+ "reward_logger = RewardLogger(curriculum=curriculum)\n",
347
+ "checkpoint_cb = CheckpointCallback(save_freq=10_000, save_path=\"/content/checkpoints/\")\n",
348
+ "improvement_cb = SpecialistImprovementCallback(\n",
349
+ " improve_every_n_episodes=_cfg.get(\"specialist_improvement\", {}).get(\n",
350
+ " \"improve_every_n_episodes\", 100\n",
351
+ " ),\n",
352
+ " verbose=1,\n",
353
+ ")\n",
354
+ "\n",
355
+ "_t0 = time.time()\n",
356
+ "model.learn(\n",
357
+ " total_timesteps=TOTAL_TIMESTEPS,\n",
358
+ " callback=[reward_logger, checkpoint_cb, improvement_cb],\n",
359
+ ")\n",
360
+ "_elapsed = time.time() - _t0\n",
361
+ "\n",
362
+ "model.save(\"/content/spindleflow_colab_model\")\n",
363
+ "vec_env.save(\"/content/vec_normalize_colab.pkl\")\n",
364
+ "\n",
365
+ "_tlog(f\"Training done in {_elapsed/60:.1f} min\")\n",
366
+ "_tlog(f\"Episodes tracked : {len(reward_logger.episode_rewards)}\")\n",
367
+ "_tlog(f\"Final curriculum : {curriculum.progress_str()}\")\n",
368
+ "print(\"\\nβœ… Model saved to /content/spindleflow_colab_model.zip\")"
369
+ ],
370
+ "outputs": [],
371
+ "execution_count": null
372
+ },
373
+ {
374
+ "cell_type": "markdown",
375
+ "metadata": {},
376
+ "source": [
377
+ "## Cell 6 β€” Reward curve\n",
378
+ "Generates publication-quality plot and saves JSON for the HF Space demo."
379
+ ]
380
+ },
381
+ {
382
+ "cell_type": "code",
383
+ "metadata": {},
384
+ "source": [
385
+ "import json\n",
386
+ "import numpy as np\n",
387
+ "import matplotlib\n",
388
+ "matplotlib.use(\"Agg\")\n",
389
+ "import matplotlib.pyplot as plt\n",
390
+ "\n",
391
+ "ep_rewards = reward_logger.episode_rewards\n",
392
+ "if not ep_rewards:\n",
393
+ " raise RuntimeError(\"No episodes completed β€” check Cell 5 output for errors.\")\n",
394
+ "\n",
395
+ "n_ep = len(ep_rewards)\n",
396
+ "episodes = list(range(n_ep))\n",
397
+ "window = max(30, n_ep // 20) # adaptive: ~5% of total\n",
398
+ "\n",
399
+ "smoothed = [\n",
400
+ " float(np.mean(ep_rewards[max(0, i - window):i + 1]))\n",
401
+ " for i in range(n_ep)\n",
402
+ "]\n",
403
+ "\n",
404
+ "early_mean = float(np.mean(ep_rewards[:min(50, n_ep)]))\n",
405
+ "final_mean = float(np.mean(ep_rewards[max(0, n_ep - 200):]))\n",
406
+ "improvement = final_mean - early_mean\n",
407
+ "\n",
408
+ "# ── Save JSON ──────────────────────────────────────────────────\n",
409
+ "step = max(1, n_ep // 300)\n",
410
+ "json_data = {\n",
411
+ " \"episodes\": episodes[::step],\n",
412
+ " \"mean_rewards\": smoothed[::step],\n",
413
+ "}\n",
414
+ "with open(\"/content/demo/assets/reward_curve.json\", \"w\") as f:\n",
415
+ " json.dump(json_data, f)\n",
416
+ "print(f\"Saved reward_curve.json ({len(json_data['episodes'])} points)\")\n",
417
+ "\n",
418
+ "# ── Plot ───────────────────────────────────────────────────────\n",
419
+ "fig, ax = plt.subplots(figsize=(11, 5), dpi=180)\n",
420
+ "fig.patch.set_facecolor(\"#0d1117\")\n",
421
+ "ax.set_facecolor(\"#161b22\")\n",
422
+ "\n",
423
+ "plot_every = max(1, n_ep // 800)\n",
424
+ "ax.scatter(\n",
425
+ " episodes[::plot_every], ep_rewards[::plot_every],\n",
426
+ " s=4, alpha=0.25, color=\"#58a6ff\", zorder=2, label=\"Episode reward\",\n",
427
+ ")\n",
428
+ "ax.plot(\n",
429
+ " episodes[::plot_every], smoothed[::plot_every],\n",
430
+ " linewidth=2.5, color=\"#ff6b35\", zorder=3,\n",
431
+ " label=f\"Smoothed ({window}-ep mean)\",\n",
432
+ ")\n",
433
+ "ax.axhline(\n",
434
+ " y=early_mean, color=\"#94a3b8\", linestyle=\"--\", linewidth=1.2, alpha=0.75,\n",
435
+ " label=f\"Early baseline {early_mean:+.3f}\",\n",
436
+ ")\n",
437
+ "ax.axhline(\n",
438
+ " y=final_mean, color=\"#34d399\", linestyle=\"--\", linewidth=1.2, alpha=0.85,\n",
439
+ " label=f\"Final mean {final_mean:+.3f}\",\n",
440
+ ")\n",
441
+ "\n",
442
+ "ax.set_xlabel(\"Episode\", color=\"#c9d1d9\", fontsize=12)\n",
443
+ "ax.set_ylabel(\"Reward\", color=\"#c9d1d9\", fontsize=12)\n",
444
+ "ax.set_title(\n",
445
+ " \"SpindleFlow RL β€” Delegation Policy Learning Curve\\n\"\n",
446
+ " f\"RecurrentPPO Β· LSTM Β· {TOTAL_TIMESTEPS:,} steps Β· {n_ep:,} episodes\",\n",
447
+ " color=\"#f0f6fc\", fontsize=13, fontweight=\"bold\", pad=14,\n",
448
+ ")\n",
449
+ "ax.tick_params(colors=\"#8b949e\")\n",
450
+ "for spine in ax.spines.values():\n",
451
+ " spine.set_edgecolor(\"#30363d\")\n",
452
+ "ax.grid(color=\"#21262d\", linewidth=0.8, alpha=0.9)\n",
453
+ "ax.legend(\n",
454
+ " fontsize=10, framealpha=0.85,\n",
455
+ " facecolor=\"#161b22\", edgecolor=\"#30363d\", labelcolor=\"#c9d1d9\",\n",
456
+ ")\n",
457
+ "\n",
458
+ "sign = \"β–²\" if improvement >= 0 else \"β–Ό\"\n",
459
+ "ax.annotate(\n",
460
+ " f\" {sign} {abs(improvement):.3f} reward improvement\",\n",
461
+ " xy=(n_ep * 0.65, (early_mean + final_mean) / 2),\n",
462
+ " color=\"#f0f6fc\", fontsize=10, fontstyle=\"italic\",\n",
463
+ ")\n",
464
+ "\n",
465
+ "fig.tight_layout()\n",
466
+ "fig.savefig(\"/content/reward_curve.png\", dpi=180, bbox_inches=\"tight\",\n",
467
+ " facecolor=fig.get_facecolor())\n",
468
+ "plt.show()\n",
469
+ "\n",
470
+ "print(f\"\\n{'='*50}\")\n",
471
+ "print(f\"Episodes completed : {n_ep:,}\")\n",
472
+ "print(f\"Early baseline : {early_mean:+.4f}\")\n",
473
+ "print(f\"Final mean : {final_mean:+.4f}\")\n",
474
+ "print(f\"Improvement : {improvement:+.4f}\")\n",
475
+ "print(f\"{'='*50}\")\n",
476
+ "print(\"βœ… Reward curve saved to /content/reward_curve.png\")\n",
477
+ "\n",
478
+ "_tlog(f\"Reward curve: early={early_mean:+.4f}, final={final_mean:+.4f}, improvement={improvement:+.4f}\")"
479
+ ],
480
+ "outputs": [],
481
+ "execution_count": null
482
+ },
483
+ {
484
+ "cell_type": "markdown",
485
+ "metadata": {},
486
+ "source": [
487
+ "## Cell 7 β€” Learning features audit\n",
488
+ "Confirms each self-learning feature fired at least once during training."
489
+ ]
490
+ },
491
+ {
492
+ "cell_type": "code",
493
+ "metadata": {},
494
+ "source": [
495
+ "import os, json\n",
496
+ "from pathlib import Path\n",
497
+ "\n",
498
+ "print(\"=\"*55)\n",
499
+ "print(\"LEARNING FEATURES AUDIT\")\n",
500
+ "print(\"=\"*55)\n",
501
+ "\n",
502
+ "# Feature 5 β€” Curriculum\n",
503
+ "print(f\"\\nFeature 5 β€” Curriculum (performance-gated)\")\n",
504
+ "print(f\" Final phase : {curriculum.current_phase}/3\")\n",
505
+ "print(f\" Rolling mean reward: {curriculum.rolling_mean():.3f}\")\n",
506
+ "print(f\" {curriculum.progress_str()}\")\n",
507
+ "\n",
508
+ "# Feature 2 β€” Specialist memory\n",
509
+ "mem_path = Path(_cfg.get(\"specialist_improvement\", {}).get(\n",
510
+ " \"memory_path\", \"data/specialist_memory.json\"\n",
511
+ "))\n",
512
+ "print(f\"\\nFeature 2 β€” Specialist memory ({mem_path})\")\n",
513
+ "if mem_path.exists():\n",
514
+ " data = json.loads(mem_path.read_text())\n",
515
+ " total_entries = sum(len(v) for v in data.values())\n",
516
+ " print(f\" Specialists with memory : {len(data)}\")\n",
517
+ " print(f\" Total entries recorded : {total_entries}\")\n",
518
+ " for sid, entries in list(data.items())[:3]:\n",
519
+ " avg = sum(e[\"reward\"] for e in entries) / len(entries)\n",
520
+ " print(f\" {sid}: {len(entries)} entries, avg_reward={avg:.3f}\")\n",
521
+ "else:\n",
522
+ " print(\" No memory file yet (finetuner may not have fired β€” normal below 100 episodes)\")\n",
523
+ "\n",
524
+ "# Feature 3 β€” Spawn memory\n",
525
+ "spawn_path = Path(_cfg.get(\"environment\", {}).get(\n",
526
+ " \"spawn_memory_path\", \"data/spawn_memory.jsonl\"\n",
527
+ "))\n",
528
+ "print(f\"\\nFeature 3 β€” Spawn memory ({spawn_path})\")\n",
529
+ "if spawn_path.exists():\n",
530
+ " lines = [l for l in spawn_path.read_text().splitlines() if l.strip()]\n",
531
+ " print(f\" Spawn records written: {len(lines)}\")\n",
532
+ " for line in lines[:3]:\n",
533
+ " rec = json.loads(line)\n",
534
+ " print(f\" {rec['specialist_role']} | reward={rec['episode_reward']:.3f} \"\n",
535
+ " f\"| sim {rec['pre_spawn_sim']:.2f}β†’{rec['post_spawn_sim']:.2f}\")\n",
536
+ "else:\n",
537
+ " print(\" No spawn memory yet (requires policy choosing SPAWN_SPECIALIST action)\")\n",
538
+ "\n",
539
+ "# Feature 4 β€” Resolution bandit\n",
540
+ "res_path = Path(_cfg.get(\"agents\", {}).get(\n",
541
+ " \"resolution_memory_path\", \"data/resolution_memory.jsonl\"\n",
542
+ "))\n",
543
+ "print(f\"\\nFeature 4 β€” Resolution bandit ({res_path})\")\n",
544
+ "if res_path.exists():\n",
545
+ " lines = [l for l in res_path.read_text().splitlines() if l.strip()]\n",
546
+ " print(f\" Outcome records written: {len(lines)}\")\n",
547
+ " stats = {}\n",
548
+ " for line in lines:\n",
549
+ " rec = json.loads(line)\n",
550
+ " key = f\"{rec['conflict_type']}/{rec['template_key']}\"\n",
551
+ " stats.setdefault(key, []).append(rec[\"quality_delta\"])\n",
552
+ " for k, deltas in stats.items():\n",
553
+ " print(f\" {k}: n={len(deltas)}, mean_delta={sum(deltas)/len(deltas):.3f}\")\n",
554
+ "else:\n",
555
+ " print(\" No resolution memory yet (requires detected conflicts)\")\n",
556
+ "\n",
557
+ "print(\"\\n\" + \"=\"*55)\n",
558
+ "print(\"βœ… Audit complete\")\n",
559
+ "print(\"=\"*55)"
560
+ ],
561
+ "outputs": [],
562
+ "execution_count": null
563
+ },
564
+ {
565
+ "cell_type": "markdown",
566
+ "metadata": {},
567
+ "source": [
568
+ "## Cell 8 β€” Push to HuggingFace Hub\n",
569
+ "\n",
570
+ "Uploads model checkpoint, reward curve, training log, and README to `garvitsachdeva/spindleflow-rl`."
571
+ ]
572
+ },
573
+ {
574
+ "cell_type": "code",
575
+ "metadata": {},
576
+ "source": [
577
+ "import os, json\n",
578
+ "import numpy as np\n",
579
+ "from huggingface_hub import HfApi, CommitOperationAdd\n",
580
+ "\n",
581
+ "HF_REPO = \"garvitsachdeva/spindleflow-rl\"\n",
582
+ "api = HfApi(token=HF_TOKEN)\n",
583
+ "\n",
584
+ "_tlog(f\"Pushing to https://huggingface.co/{HF_REPO} ...\")\n",
585
+ "api.create_repo(repo_id=HF_REPO.split(\"/\")[-1], repo_type=\"model\", exist_ok=True)\n",
586
+ "\n",
587
+ "ep = reward_logger.episode_rewards\n",
588
+ "f5 = float(np.mean(ep[:5])) if len(ep) >= 5 else 0.0\n",
589
+ "l5 = float(np.mean(ep[-5:])) if len(ep) >= 5 else 0.0\n",
590
+ "\n",
591
+ "readme_text = f\"\"\"---\n",
592
+ "license: mit\n",
593
+ "tags:\n",
594
+ " - reinforcement-learning\n",
595
+ " - stable-baselines3\n",
596
+ " - sb3-contrib\n",
597
+ " - gymnasium\n",
598
+ " - multi-agent\n",
599
+ " - openenv\n",
600
+ "library_name: stable-baselines3\n",
601
+ "---\n",
602
+ "\n",
603
+ "# SpindleFlow RL β€” Delegation Policy\n",
604
+ "\n",
605
+ "LSTM PPO (RecurrentPPO) agent trained on SpindleFlow-v0 (OpenEnv). \n",
606
+ "Trained on Google Colab T4 GPU.\n",
607
+ "\n",
608
+ "## Training summary\n",
609
+ "| Metric | Value |\n",
610
+ "|---|---|\n",
611
+ "| Algorithm | RecurrentPPO (SB3 + sb3-contrib) |\n",
612
+ "| Total timesteps | {TOTAL_TIMESTEPS:,} |\n",
613
+ "| Episodes completed | {len(ep):,} |\n",
614
+ "| Early baseline (first 50 ep) | {early_mean:.4f} |\n",
615
+ "| Final mean (last 200 ep) | {final_mean:.4f} |\n",
616
+ "| Improvement | {final_mean - early_mean:+.4f} |\n",
617
+ "| Training time | {_elapsed/60:.1f} min |\n",
618
+ "| Device | T4 GPU |\n",
619
+ "\n",
620
+ "![Reward Curve](reward_curve.png)\n",
621
+ "\n",
622
+ "## Load\n",
623
+ "```python\n",
624
+ "from sb3_contrib import RecurrentPPO\n",
625
+ "from huggingface_hub import hf_hub_download\n",
626
+ "model = RecurrentPPO.load(hf_hub_download(\"{HF_REPO}\", \"spindleflow_model.zip\"))\n",
627
+ "```\n",
628
+ "\"\"\"\n",
629
+ "\n",
630
+ "readme_path = \"/content/README_model.md\"\n",
631
+ "with open(readme_path, \"w\") as f:\n",
632
+ " f.write(readme_text)\n",
633
+ "\n",
634
+ "candidates = [\n",
635
+ " (\"/content/spindleflow_colab_model.zip\", \"spindleflow_model.zip\"),\n",
636
+ " (\"/content/vec_normalize_colab.pkl\", \"vec_normalize.pkl\"),\n",
637
+ " (\"/content/reward_curve.png\", \"reward_curve.png\"),\n",
638
+ " (\"/content/demo/assets/reward_curve.json\", \"reward_curve.json\"),\n",
639
+ " (\"/content/logs/training_log.txt\", \"training_log.txt\"),\n",
640
+ " (readme_path, \"README.md\"),\n",
641
+ "]\n",
642
+ "\n",
643
+ "ops = [\n",
644
+ " CommitOperationAdd(path_in_repo=dst, path_or_fileobj=src)\n",
645
+ " for src, dst in candidates\n",
646
+ " if os.path.exists(src)\n",
647
+ "]\n",
648
+ "\n",
649
+ "api.create_commit(\n",
650
+ " repo_id=HF_REPO,\n",
651
+ " repo_type=\"model\",\n",
652
+ " operations=ops,\n",
653
+ " commit_message=\"Add trained SpindleFlow RL policy (Colab T4)\",\n",
654
+ " token=HF_TOKEN,\n",
655
+ ")\n",
656
+ "\n",
657
+ "_tlog(f\"Uploaded {len(ops)} files:\")\n",
658
+ "for src, dst in candidates:\n",
659
+ " if os.path.exists(src):\n",
660
+ " _tlog(f\" βœ“ {dst}\")\n",
661
+ "\n",
662
+ "_tlog(f\"Model : https://huggingface.co/{HF_REPO}\")\n",
663
+ "_tlog(f\"Training log: https://huggingface.co/{HF_REPO}/blob/main/training_log.txt\")\n",
664
+ "_tlog(f\"Reward curve: https://huggingface.co/{HF_REPO}/blob/main/reward_curve.png\")\n",
665
+ "_tlog(f\"Improvement : {final_mean - early_mean:+.4f}\")\n",
666
+ "print(\"\\nβœ… All done!\")"
667
+ ],
668
+ "outputs": [],
669
+ "execution_count": null
670
+ }
671
+ ]
672
+ }