Eshit commited on
Commit
423d538
·
1 Parent(s): df17371

Delete deprecated GRPO notebook and publish GRPO v2 colab.

Browse files
training/grpo_colab.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
training/grpo_v2_colab.ipynb CHANGED
@@ -15,7 +15,7 @@
15
  "4. GRPO loop too slow - consequence of fix 3\n",
16
  "5. parse_action(text, None) crash - standalone check_json_format() for format reward\n",
17
  "\n",
18
- "**Hardware:** A10G Large 24GB (HuggingFace Space JupyterLab)\n",
19
  "\n",
20
  "**Before running:** In a terminal, authenticate:\n",
21
  "```\n",
@@ -37,8 +37,11 @@
37
  "metadata": {},
38
  "source": [
39
  "!pip install \"unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git\"\n",
40
- "!pip install trl==0.15.2 datasets==3.4.1 wandb\n",
41
- "!pip install torchvision --extra-index-url https://download.pytorch.org/whl/cu121"
 
 
 
42
  ],
43
  "execution_count": null,
44
  "outputs": []
@@ -56,6 +59,72 @@
56
  "execution_count": null,
57
  "outputs": []
58
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  {
60
  "cell_type": "markdown",
61
  "metadata": {},
@@ -104,7 +173,11 @@
104
  "import os, random, json, sys\n",
105
  "import torch\n",
106
  "\n",
107
- "REPO_ROOT = \".\" # Adjust to repo root in Colab\n",
 
 
 
 
108
  "if REPO_ROOT not in sys.path:\n",
109
  " sys.path.insert(0, REPO_ROOT)\n",
110
  "\n",
@@ -219,59 +292,65 @@
219
  "cell_type": "code",
220
  "metadata": {},
221
  "source": [
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
  "def reward_fn_outcome(completions, prompts, tier=None, seed=None, **kwargs):\n",
223
  " \"\"\"\n",
224
  " Score each GRPO completion by:\n",
225
  " 1. Resetting the env to the EXACT (tier, seed) that generated the prompt (Issue 1 fix).\n",
226
  " 2. Applying the sampled completion as the single first action (MODEL_STEPS=1, Issue 3/4 fix).\n",
227
  " 3. Running HeuristicAgent until episode completion (Issue 2 fix - captures terminal reward).\n",
228
- "\n",
229
  " tier and seed are dataset columns forwarded by GRPOTrainer.\n",
230
  " \"\"\"\n",
231
  " global _reward_call_count\n",
232
  " _reward_call_count += 1\n",
233
- " rewards = []\n",
234
- "\n",
235
- " for i, completion in enumerate(completions):\n",
236
- " ep_tier = tier[i] if tier is not None else controller.get_tier()\n",
237
- " ep_seed = seed[i] if seed is not None else random.choice(SEED_POOL)\n",
238
  "\n",
239
- " env = WildfireEnv()\n",
240
- " obs = env.reset(task_id=ep_tier, seed=ep_seed)\n",
241
- " total_reward = 0.0\n",
242
- "\n",
243
- " # Apply the sampled completion as step 0\n",
244
- " text = completion if isinstance(completion, str) else completion[0]['content']\n",
245
- " action, _ = parse_action(text, obs)\n",
246
- " result = env.step(action)\n",
247
- " total_reward += result.reward\n",
248
- " obs = result.observation\n",
249
- "\n",
250
- " # Heuristic drives everything after (full episode to capture terminal reward)\n",
251
- " heuristic = HeuristicAgent()\n",
252
- " while not env.done:\n",
253
- " action = heuristic.act(obs)\n",
254
- " result = env.step(action)\n",
255
- " total_reward += result.reward\n",
256
- " obs = result.observation\n",
257
  "\n",
258
- " rewards.append(total_reward)\n",
 
259
  "\n",
260
- " # Update curriculum (once per batch, not per completion)\n",
261
  " mean_r = sum(rewards) / len(rewards)\n",
262
  " promoted = controller.after_episode(mean_r)\n",
263
  " if promoted:\n",
264
  " print(f' *** Curriculum promoted to: {promoted} (mean batch reward={mean_r:.2f}) ***')\n",
265
  "\n",
266
- " # Sample completions to disk for inspection\n",
267
  " if _reward_call_count % 10 == 0:\n",
 
268
  " sample_path = f'training/samples/call_{_reward_call_count}.txt'\n",
269
  " with open(sample_path, 'w') as f:\n",
270
  " f.write(f'call={_reward_call_count} tier={tier[0] if tier else \"?\"} reward={rewards[0]:.3f}\\n')\n",
271
  " f.write('---\\n')\n",
272
  " c = completions[0]\n",
273
  " f.write(c if isinstance(c, str) else c[0]['content'])\n",
274
- " f.write('\\n')\n",
275
  "\n",
276
  " return rewards\n",
277
  "\n",
@@ -295,7 +374,7 @@
295
  " return rewards\n",
296
  "\n",
297
  "\n",
298
- "print('Reward functions defined.')"
299
  ],
300
  "execution_count": null,
301
  "outputs": []
@@ -400,8 +479,8 @@
400
  " output_dir='./grpo_checkpoints',\n",
401
  " num_generations=8,\n",
402
  " learning_rate=3e-6,\n",
403
- " max_steps=400,\n",
404
- " save_steps=20,\n",
405
  " per_device_train_batch_size=1,\n",
406
  " gradient_accumulation_steps=4,\n",
407
  " max_completion_length=192,\n",
@@ -453,7 +532,11 @@
453
  "stats = [{'step': ep, 'tier': t, 'mean_reward': r} for ep, t, r in history]\n",
454
  "with open('./training_stats.json', 'w') as f:\n",
455
  " json.dump(stats, f, indent=2)\n",
456
- "print('Stats saved -> training_stats.json')"
 
 
 
 
457
  ],
458
  "execution_count": null,
459
  "outputs": []
@@ -526,7 +609,9 @@
526
  "source": [
527
  "import numpy as np\n",
528
  "\n",
529
- "with open('scripts/results.json', 'r') as f:\n",
 
 
530
  " baselines = json.load(f)\n",
531
  "\n",
532
  "FastLanguageModel.for_inference(model)\n",
@@ -612,6 +697,15 @@
612
  "print('Pop saved rate: ', end='')\n",
613
  "print(' '.join(f'{t}={results[t][\"pop_saved_pct\"]:.0f}%' for t in TIERS))\n",
614
  "\n",
 
 
 
 
 
 
 
 
 
615
  "assert any_tier_close, (\n",
616
  " 'Trained model did not come within 1.0 of heuristic on any tier. '\n",
617
  " 'Check training logs and sample completions.'\n",
 
15
  "4. GRPO loop too slow - consequence of fix 3\n",
16
  "5. parse_action(text, None) crash - standalone check_json_format() for format reward\n",
17
  "\n",
18
+ "**Hardware:** A100 Large 40GB (HuggingFace Space JupyterLab) — ~75 min wall-clock for 150 steps\n",
19
  "\n",
20
  "**Before running:** In a terminal, authenticate:\n",
21
  "```\n",
 
37
  "metadata": {},
38
  "source": [
39
  "!pip install \"unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git\"\n",
40
+ "!pip install \"trl==0.20.0\" datasets==3.4.1 wandb\n",
41
+ "# torchvision: choose the index matching your CUDA version\n",
42
+ "# HF Space A100/A10G (CUDA 12.8): use cu128\n",
43
+ "# Standard Colab (CUDA 12.1): replace cu128 with cu121\n",
44
+ "!pip install torchvision --index-url https://download.pytorch.org/whl/cu128"
45
  ],
46
  "execution_count": null,
47
  "outputs": []
 
59
  "execution_count": null,
60
  "outputs": []
61
  },
62
+ {
63
+ "cell_type": "code",
64
+ "metadata": {},
65
+ "source": [
66
+ "import sys\n",
67
+ "from enum import Enum\n",
68
+ "import importlib.machinery\n",
69
+ "from unittest.mock import MagicMock\n",
70
+ "\n",
71
+ "# torchvision C extension is ABI-incompatible with torch 2.10.0+cu128.\n",
72
+ "# Stub it out — text-only GRPO never calls vision ops.\n",
73
+ "# If your torchvision imports correctly, this cell is harmless (setdefault won't overwrite).\n",
74
+ "for _key in list(sys.modules.keys()):\n",
75
+ " if 'torchvision' in _key:\n",
76
+ " del sys.modules[_key]\n",
77
+ "\n",
78
+ "class _InterpolationMode(Enum):\n",
79
+ " NEAREST = \"nearest\"\n",
80
+ " NEAREST_EXACT = \"nearest_exact\"\n",
81
+ " BOX = \"box\"\n",
82
+ " BILINEAR = \"bilinear\"\n",
83
+ " BICUBIC = \"bicubic\"\n",
84
+ " HAMMING = \"hamming\"\n",
85
+ " LANCZOS = \"lanczos\"\n",
86
+ "\n",
87
+ "class _StubModule(type(sys)): \n",
88
+ " def __getattr__(self, name):\n",
89
+ " if name.startswith('__'):\n",
90
+ " raise AttributeError(name)\n",
91
+ " mock = MagicMock()\n",
92
+ " setattr(self, name, mock)\n",
93
+ " return mock\n",
94
+ "\n",
95
+ "def _make(name):\n",
96
+ " m = _StubModule(name)\n",
97
+ " m.__spec__ = importlib.machinery.ModuleSpec(name, None)\n",
98
+ " m.__path__ = []\n",
99
+ " m.__package__ = name\n",
100
+ " return m\n",
101
+ "\n",
102
+ "_tv = _make(\"torchvision\")\n",
103
+ "_tv.__version__ = \"0.20.0\"\n",
104
+ "_tr = _make(\"torchvision.transforms\")\n",
105
+ "_tr.InterpolationMode = _InterpolationMode\n",
106
+ "_tr_v2 = _make(\"torchvision.transforms.v2\")\n",
107
+ "_tvF = _make(\"torchvision.transforms.v2.functional\")\n",
108
+ "_ops = _make(\"torchvision.ops\")\n",
109
+ "_models = _make(\"torchvision.models\")\n",
110
+ "_io = _make(\"torchvision.io\")\n",
111
+ "_utils = _make(\"torchvision.utils\")\n",
112
+ "_datasets= _make(\"torchvision.datasets\")\n",
113
+ "_tv.transforms = _tr\n",
114
+ "_tr.v2 = _tr_v2\n",
115
+ "_tr_v2.functional = _tvF\n",
116
+ "_tv.ops = _ops; _tv.models = _models; _tv.io = _io\n",
117
+ "_tv.utils = _utils; _tv.datasets = _datasets\n",
118
+ "\n",
119
+ "for _mod in [_tv, _tr, _tr_v2, _tvF, _ops, _models, _io, _utils, _datasets]:\n",
120
+ " sys.modules[_mod.__name__] = _mod\n",
121
+ "\n",
122
+ "print(\"torchvision stubbed OK (safe for text-only training)\")"
123
+ ],
124
+ "execution_count": null,
125
+ "outputs": [],
126
+ "id": "c9ae1850"
127
+ },
128
  {
129
  "cell_type": "markdown",
130
  "metadata": {},
 
173
  "import os, random, json, sys\n",
174
  "import torch\n",
175
  "\n",
176
+ "# Clone the simulator repo first (run once in a terminal or notebook cell):\n",
177
+ "# !git clone https://github.com/Abrodolph/Wildfire-Containment-Simulator /home/user/app/Wildfire-Containment-Simulator\n",
178
+ "# !pip install -e /home/user/app/Wildfire-Containment-Simulator --quiet\n",
179
+ "REPO_ROOT = \"/home/user/app/Wildfire-Containment-Simulator\" # HF JupyterLab path\n",
180
+ "# On standard Colab: REPO_ROOT = \"/content/Wildfire-Containment-Simulator\"\n",
181
  "if REPO_ROOT not in sys.path:\n",
182
  " sys.path.insert(0, REPO_ROOT)\n",
183
  "\n",
 
292
  "cell_type": "code",
293
  "metadata": {},
294
  "source": [
295
+ "from concurrent.futures import ThreadPoolExecutor\n",
296
+ "\n",
297
+ "def _run_episode(args):\n",
298
+ " \"\"\"Run one full wildfire episode for a single GRPO completion (parallelizable).\"\"\"\n",
299
+ " completion, ep_tier, ep_seed = args\n",
300
+ " env = WildfireEnv()\n",
301
+ " obs = env.reset(task_id=ep_tier, seed=ep_seed)\n",
302
+ " total_reward = 0.0\n",
303
+ " text = completion if isinstance(completion, str) else completion[0]['content']\n",
304
+ " action, _ = parse_action(text, obs)\n",
305
+ " result = env.step(action)\n",
306
+ " total_reward += result.reward\n",
307
+ " obs = result.observation\n",
308
+ " heuristic = HeuristicAgent()\n",
309
+ " while not env.done:\n",
310
+ " action = heuristic.act(obs)\n",
311
+ " result = env.step(action)\n",
312
+ " total_reward += result.reward\n",
313
+ " obs = result.observation\n",
314
+ " return total_reward\n",
315
+ "\n",
316
+ "\n",
317
  "def reward_fn_outcome(completions, prompts, tier=None, seed=None, **kwargs):\n",
318
  " \"\"\"\n",
319
  " Score each GRPO completion by:\n",
320
  " 1. Resetting the env to the EXACT (tier, seed) that generated the prompt (Issue 1 fix).\n",
321
  " 2. Applying the sampled completion as the single first action (MODEL_STEPS=1, Issue 3/4 fix).\n",
322
  " 3. Running HeuristicAgent until episode completion (Issue 2 fix - captures terminal reward).\n",
323
+ " Episodes are run in parallel threads to reduce wall-clock time.\n",
324
  " tier and seed are dataset columns forwarded by GRPOTrainer.\n",
325
  " \"\"\"\n",
326
  " global _reward_call_count\n",
327
  " _reward_call_count += 1\n",
 
 
 
 
 
328
  "\n",
329
+ " args_list = [\n",
330
+ " (\n",
331
+ " completions[i],\n",
332
+ " tier[i] if tier is not None else controller.get_tier(),\n",
333
+ " seed[i] if seed is not None else random.choice(SEED_POOL),\n",
334
+ " )\n",
335
+ " for i in range(len(completions))\n",
336
+ " ]\n",
 
 
 
 
 
 
 
 
 
 
337
  "\n",
338
+ " with ThreadPoolExecutor(max_workers=len(completions)) as executor:\n",
339
+ " rewards = list(executor.map(_run_episode, args_list))\n",
340
  "\n",
 
341
  " mean_r = sum(rewards) / len(rewards)\n",
342
  " promoted = controller.after_episode(mean_r)\n",
343
  " if promoted:\n",
344
  " print(f' *** Curriculum promoted to: {promoted} (mean batch reward={mean_r:.2f}) ***')\n",
345
  "\n",
 
346
  " if _reward_call_count % 10 == 0:\n",
347
+ " os.makedirs('training/samples', exist_ok=True)\n",
348
  " sample_path = f'training/samples/call_{_reward_call_count}.txt'\n",
349
  " with open(sample_path, 'w') as f:\n",
350
  " f.write(f'call={_reward_call_count} tier={tier[0] if tier else \"?\"} reward={rewards[0]:.3f}\\n')\n",
351
  " f.write('---\\n')\n",
352
  " c = completions[0]\n",
353
  " f.write(c if isinstance(c, str) else c[0]['content'])\n",
 
354
  "\n",
355
  " return rewards\n",
356
  "\n",
 
374
  " return rewards\n",
375
  "\n",
376
  "\n",
377
+ "print('Reward functions defined (parallelized).')"
378
  ],
379
  "execution_count": null,
380
  "outputs": []
 
479
  " output_dir='./grpo_checkpoints',\n",
480
  " num_generations=8,\n",
481
  " learning_rate=3e-6,\n",
482
+ " max_steps=150, # 150 steps ~ 75 min on A100; increase to 400 if time allows\n",
483
+ " save_steps=10,\n",
484
  " per_device_train_batch_size=1,\n",
485
  " gradient_accumulation_steps=4,\n",
486
  " max_completion_length=192,\n",
 
532
  "stats = [{'step': ep, 'tier': t, 'mean_reward': r} for ep, t, r in history]\n",
533
  "with open('./training_stats.json', 'w') as f:\n",
534
  " json.dump(stats, f, indent=2)\n",
535
+ "print('Stats saved -> training_stats.json')\n",
536
+ "\n",
537
+ "# To resume training for more steps later:\n",
538
+ "# grpo_config.max_steps = 300 # new total\n",
539
+ "# trainer.train(resume_from_checkpoint='./grpo_checkpoints')"
540
  ],
541
  "execution_count": null,
542
  "outputs": []
 
609
  "source": [
610
  "import numpy as np\n",
611
  "\n",
612
+ "# Adjust path to repo root if needed\n",
613
+ "BASELINES_PATH = f'{REPO_ROOT}/scripts/results.json'\n",
614
+ "with open(BASELINES_PATH, 'r') as f:\n",
615
  " baselines = json.load(f)\n",
616
  "\n",
617
  "FastLanguageModel.for_inference(model)\n",
 
697
  "print('Pop saved rate: ', end='')\n",
698
  "print(' '.join(f'{t}={results[t][\"pop_saved_pct\"]:.0f}%' for t in TIERS))\n",
699
  "\n",
700
+ "with open('./grpo_eval_results.json', 'w') as f:\n",
701
+ " json.dump({\n",
702
+ " 'trained': results,\n",
703
+ " 'baselines': baselines,\n",
704
+ " 'eval_seeds': EVAL_SEEDS,\n",
705
+ " 'model': 'Eshit/wildfire-grpo-7b',\n",
706
+ " }, f, indent=2)\n",
707
+ "print('Eval results saved -> grpo_eval_results.json')\n",
708
+ "\n",
709
  "assert any_tier_close, (\n",
710
  " 'Trained model did not come within 1.0 of heuristic on any tier. '\n",
711
  " 'Check training logs and sample completions.'\n",