CreativeEngineer committed on
Commit
ddcb837
·
1 Parent(s): 9c3599b

refactor: align colab notebook with shared llm helpers

Browse files
training/notebooks/fusion_design_lab_training.ipynb CHANGED
@@ -112,24 +112,28 @@
112
  "import json\n",
113
  "from typing import Final\n",
114
  "\n",
115
- "from fusion_lab.models import StellaratorAction, StellaratorObservation\n",
 
 
 
 
 
 
 
 
116
  "from server.contract import RESET_SEEDS\n",
117
  "from server.environment import BUDGET, StellaratorEnvironment\n",
118
  "\n",
119
- "AVAILABLE_ACTIONS: Final[list[dict[str, str]]] = [\n",
120
  " {\"intent\": \"run\", \"parameter\": p, \"direction\": d, \"magnitude\": m}\n",
121
- " for p in [\"aspect_ratio\", \"elongation\", \"rotational_transform\", \"triangularity_scale\"]\n",
122
- " for d in [\"increase\", \"decrease\"]\n",
123
- " for m in [\"small\", \"medium\", \"large\"]\n",
124
- "] + [\n",
125
- " {\"intent\": \"restore_best\"},\n",
126
  "]\n",
127
  "\n",
128
- "ACTION_LABELS: Final[list[str]] = [\n",
129
- " f\"{a['intent']} {a.get('parameter', '')} {a.get('direction', '')} {a.get('magnitude', '')}\".strip()\n",
130
- " for a in AVAILABLE_ACTIONS\n",
131
  "]\n",
132
- "\n",
133
  "# Quick smoke test\n",
134
  "env = StellaratorEnvironment()\n",
135
  "obs = env.reset(seed=0)\n",
@@ -156,135 +160,25 @@
156
  "metadata": {},
157
  "outputs": [],
158
  "source": [
159
- "SYSTEM_PROMPT: Final[\n",
160
- " str\n",
161
- "] = \"\"\"You are an expert stellarator fusion reactor designer. Your goal is to optimize a stellarator design by adjusting 4 geometric parameters to minimize max elongation while satisfying physics constraints.\n",
162
- "\n",
163
- "Constraints:\n",
164
- "- aspect_ratio <= 4.0\n",
165
- "- average_triangularity <= -0.5\n",
166
- "- edge_iota_over_nfp >= 0.3\n",
167
- "\n",
168
- "Available parameters: aspect_ratio, elongation, rotational_transform, triangularity_scale\n",
169
- "Available directions: increase, decrease\n",
170
- "Available magnitudes: small, medium, large\n",
171
- "\n",
172
- "You have a budget of 6 low-fidelity evaluations. Output a short plan of run actions as a JSON array. Each action is an object with keys: intent, parameter, direction, magnitude. Do not output submit.\n",
173
- "\n",
174
- "Example:\n",
175
- "[{\"intent\":\"run\",\"parameter\":\"triangularity_scale\",\"direction\":\"increase\",\"magnitude\":\"small\"},{\"intent\":\"run\",\"parameter\":\"rotational_transform\",\"direction\":\"increase\",\"magnitude\":\"medium\"}]\"\"\"\n",
176
- "\n",
177
- "\n",
178
- "def format_observation(obs: StellaratorObservation) -> str:\n",
179
- " return (\n",
180
- " f\"Current stellarator state:\\n\"\n",
181
- " f\" max_elongation: {obs.max_elongation:.4f}\\n\"\n",
182
- " f\" aspect_ratio: {obs.aspect_ratio:.4f} (constraint: <= 4.0)\\n\"\n",
183
- " f\" average_triangularity: {obs.average_triangularity:.6f} (constraint: <= -0.5)\\n\"\n",
184
- " f\" edge_iota_over_nfp: {obs.edge_iota_over_nfp:.4f} (constraint: >= 0.3)\\n\"\n",
185
- " f\" p1_score: {obs.p1_score:.4f}\\n\"\n",
186
- " f\" feasibility: {obs.p1_feasibility:.4f}\\n\"\n",
187
- " f\" constraints_satisfied: {obs.constraints_satisfied}\\n\"\n",
188
- " f\" budget_remaining: {obs.budget_remaining}\\n\"\n",
189
- " f\"\\nGenerate an action plan as a JSON array to optimize this design.\"\n",
190
- " )\n",
191
- "\n",
192
- "\n",
193
- "def build_prompt(obs: StellaratorObservation) -> str:\n",
194
- " return (\n",
195
- " f\"<|im_start|>system\\n{SYSTEM_PROMPT}<|im_end|>\\n\"\n",
196
- " f\"<|im_start|>user\\n{format_observation(obs)}<|im_end|>\\n\"\n",
197
- " f\"<|im_start|>assistant\\n\"\n",
198
- " )\n",
199
- "\n",
200
- "\n",
201
- "def _extract_json_array(text: str) -> str | None:\n",
202
- " \"\"\"Return the first balanced [...] substring that parses as a JSON array.\n",
203
- "\n",
204
- " Iterates through every [ in text, finds its balanced closing ],\n",
205
- " and attempts json.loads. Returns the first candidate that decodes as a\n",
206
- " JSON list, skipping prose fragments like [draft].\n",
207
- " \"\"\"\n",
208
- " start = text.find(\"[\")\n",
209
- " while start != -1:\n",
210
- " depth = 0\n",
211
- " in_string = False\n",
212
- " escape = False\n",
213
- " matched_end: int | None = None\n",
214
- " for index in range(start, len(text)):\n",
215
- " char = text[index]\n",
216
- " if in_string:\n",
217
- " if escape:\n",
218
- " escape = False\n",
219
- " elif char == \"\\\\\":\n",
220
- " escape = True\n",
221
- " elif char == '\"':\n",
222
- " in_string = False\n",
223
- " continue\n",
224
- " if char == '\"':\n",
225
- " in_string = True\n",
226
- " elif char == \"[\":\n",
227
- " depth += 1\n",
228
- " elif char == \"]\":\n",
229
- " depth -= 1\n",
230
- " if depth == 0:\n",
231
- " matched_end = index\n",
232
- " break\n",
233
- " if matched_end is not None:\n",
234
- " candidate = text[start : matched_end + 1]\n",
235
- " try:\n",
236
- " decoded = json.loads(candidate)\n",
237
- " if isinstance(decoded, list):\n",
238
- " return candidate\n",
239
- " except (json.JSONDecodeError, ValueError):\n",
240
- " pass\n",
241
- " start = text.find(\"[\", start + 1)\n",
242
- " return None\n",
243
- "\n",
244
- "\n",
245
- "def parse_action_plan(text: str) -> list[StellaratorAction]:\n",
246
- " \"\"\"Parse a JSON action plan from model output.\"\"\"\n",
247
- " array_text = _extract_json_array(text)\n",
248
- " if array_text is None:\n",
249
- " return []\n",
250
- " try:\n",
251
- " raw = json.loads(array_text)\n",
252
- " except json.JSONDecodeError:\n",
253
- " return []\n",
254
- " actions = []\n",
255
- " for item in raw:\n",
256
- " if not isinstance(item, dict) or \"intent\" not in item:\n",
257
- " continue\n",
258
- " intent = item[\"intent\"]\n",
259
- " if intent == \"submit\":\n",
260
- " continue\n",
261
- " if intent == \"restore_best\":\n",
262
- " actions.append(StellaratorAction(intent=\"restore_best\"))\n",
263
- " continue\n",
264
- " if intent == \"run\":\n",
265
- " p = item.get(\"parameter\", \"\")\n",
266
- " d = item.get(\"direction\", \"\")\n",
267
- " m = item.get(\"magnitude\", \"small\")\n",
268
- " if p in (\n",
269
- " \"aspect_ratio\",\n",
270
- " \"elongation\",\n",
271
- " \"rotational_transform\",\n",
272
- " \"triangularity_scale\",\n",
273
- " ) and d in (\"increase\", \"decrease\"):\n",
274
- " if m not in (\"small\", \"medium\", \"large\"):\n",
275
- " m = \"small\"\n",
276
- " actions.append(\n",
277
- " StellaratorAction(intent=\"run\", parameter=p, direction=d, magnitude=m)\n",
278
- " )\n",
279
- " return actions\n",
280
- "\n",
281
- "\n",
282
- "# Test prompt\n",
283
  "env = StellaratorEnvironment()\n",
284
  "obs = env.reset(seed=0)\n",
285
  "prompt = build_prompt(obs)\n",
286
  "print(prompt[:500])\n",
287
- "print(\"...\")"
 
 
 
 
 
 
 
 
 
 
 
 
 
288
  ]
289
  },
290
  {
@@ -308,8 +202,7 @@
308
  "\n",
309
  "prompts = []\n",
310
  "for seed_idx in range(len(RESET_SEEDS)):\n",
311
- " env = StellaratorEnvironment()\n",
312
- " obs = env.reset(seed=seed_idx)\n",
313
  " prompt = build_prompt(obs)\n",
314
  " # Repeat each seed to create a larger training set\n",
315
  " for _ in range(50):\n",
@@ -327,9 +220,9 @@
327
  "source": [
328
  "## 6. Reward Function\n",
329
  "\n",
330
- "The environment reward executes each generated action plan in the stellarator environment and returns the cumulative V0 reward. The environment's built-in reward decomposes feasibility (+3/-3 crossing bonuses, feasibility progress), objective (max elongation improvement), step costs, submit bonuses, and failure penalties \u2014 see `server/environment.py:_compute_reward_breakdown(...)`.\n",
331
  "\n",
332
- "If the model's plan ends before the episode terminates (no submit, budget not exhausted), the reward function auto-submits so terminal reward terms always fire. This ensures GRPO optimizes the full episode return, not truncated partial returns. The live observation telemetry also exposes `reward_breakdown` and `action_monitor` for debugging reward behavior.\n"
333
  ]
334
  },
335
  {
@@ -360,17 +253,11 @@
360
  " if len(actions) == 0:\n",
361
  " rewards.append(-3.0)\n",
362
  " continue\n",
363
- " env = StellaratorEnvironment()\n",
364
- " env.reset(seed=int(seeds[i]) % len(RESET_SEEDS))\n",
365
- " total_reward = 0.0\n",
366
- " for action in actions[:BUDGET]:\n",
367
- " if action.intent == \"submit\":\n",
368
- " continue\n",
369
- " obs = env.step(action)\n",
370
- " total_reward += float(obs.reward) if obs.reward is not None else 0.0\n",
371
- " if obs.done:\n",
372
- " break\n",
373
- " rewards.append(total_reward)\n",
374
  " except Exception:\n",
375
  " traceback.print_exc()\n",
376
  " rewards.append(-3.0)\n",
@@ -543,8 +430,11 @@
543
  "FastLanguageModel.for_inference(model)\n",
544
  "\n",
545
  "\n",
546
- "def reward_term_summary(obs: StellaratorObservation) -> str:\n",
547
- " breakdown = obs.reward_breakdown.model_dump()\n",
 
 
 
548
  " terms = []\n",
549
  " for key, value in breakdown.items():\n",
550
  " if key in {\n",
@@ -581,35 +471,22 @@
581
  " outputs[0][inputs[\"input_ids\"].shape[1] :], skip_special_tokens=True\n",
582
  " )\n",
583
  " actions = parse_action_plan(completion)\n",
584
- " trace = []\n",
585
- " total_reward = 0.0\n",
586
- " for action in actions[:BUDGET]:\n",
587
- " if action.intent == \"submit\":\n",
588
- " continue\n",
589
- " obs = env.step(action)\n",
590
- " r = float(obs.reward) if obs.reward is not None else 0.0\n",
591
- " total_reward += r\n",
592
- " trace.append(\n",
593
- " f\" {action.intent} {action.parameter or ''} {action.direction or ''} {action.magnitude or ''} \u2192 reward={r:.3f} score={obs.p1_score:.4f} feasible={obs.constraints_satisfied} terms={reward_term_summary(obs)}\".strip()\n",
594
  " )\n",
595
- " if obs.done:\n",
596
- " break\n",
597
- " return total_reward, trace\n",
598
  "\n",
599
  "\n",
600
  "def run_random_episode(seed_idx: int) -> float:\n",
601
  " \"\"\"Run one episode with random actions for comparison.\"\"\"\n",
602
- " env = StellaratorEnvironment()\n",
603
- " env.reset(seed=seed_idx)\n",
604
- " total_reward = 0.0\n",
605
- " for step in range(BUDGET):\n",
606
- " spec = random.choice(AVAILABLE_ACTIONS[:24]) # run actions only\n",
607
- " action = StellaratorAction(**spec)\n",
608
- " obs = env.step(action)\n",
609
- " total_reward += float(obs.reward) if obs.reward is not None else 0.0\n",
610
- " if obs.done:\n",
611
- " return total_reward\n",
612
- " return total_reward\n",
613
  "\n",
614
  "\n",
615
  "# Evaluate\n",
@@ -690,7 +567,6 @@
690
  " actions = parse_action_plan(completion)\n",
691
  "\n",
692
  " print(f\"\\nTrained model generated {len(actions)} actions for remote env:\")\n",
693
- " done = False\n",
694
  " for i, action in enumerate(actions[:BUDGET], start=1):\n",
695
  " if action.intent == \"submit\":\n",
696
  " continue\n",
@@ -704,7 +580,6 @@
704
  " )\n",
705
  " if result.done:\n",
706
  " print(f\" Episode done. Final score: {step_obs.p1_score:.4f}\")\n",
707
- " done = True\n",
708
  " break\n",
709
  "print(\"\\nEnvironment is live and accessible for training and evaluation.\")"
710
  ]
 
112
  "import json\n",
113
  "from typing import Final\n",
114
  "\n",
115
+ "from fusion_lab.llm_agent import (\n",
116
+ " RUN_DIRECTIONS,\n",
117
+ " RUN_MAGNITUDES,\n",
118
+ " RUN_PARAMETERS,\n",
119
+ " build_prompt,\n",
120
+ " parse_action_plan,\n",
121
+ " run_episode_with_actions,\n",
122
+ ")\n",
123
+ "from fusion_lab.models import StellaratorAction\n",
124
  "from server.contract import RESET_SEEDS\n",
125
  "from server.environment import BUDGET, StellaratorEnvironment\n",
126
  "\n",
127
+ "RUN_ACTION_SPECS: Final[list[dict[str, str]]] = [\n",
128
  " {\"intent\": \"run\", \"parameter\": p, \"direction\": d, \"magnitude\": m}\n",
129
+ " for p in RUN_PARAMETERS\n",
130
+ " for d in RUN_DIRECTIONS\n",
131
+ " for m in RUN_MAGNITUDES\n",
 
 
132
  "]\n",
133
  "\n",
134
+ "AVAILABLE_ACTIONS: Final[list[dict[str, str]]] = RUN_ACTION_SPECS + [\n",
135
+ " {\"intent\": \"restore_best\"},\n",
 
136
  "]\n",
 
137
  "# Quick smoke test\n",
138
  "env = StellaratorEnvironment()\n",
139
  "obs = env.reset(seed=0)\n",
 
160
  "metadata": {},
161
  "outputs": [],
162
  "source": [
163
+ "# Shared helper smoke test\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  "env = StellaratorEnvironment()\n",
165
  "obs = env.reset(seed=0)\n",
166
  "prompt = build_prompt(obs)\n",
167
  "print(prompt[:500])\n",
168
+ "print(\"...\")\n",
169
+ "\n",
170
+ "sample_completion = json.dumps(\n",
171
+ " [\n",
172
+ " {\n",
173
+ " \"intent\": \"run\",\n",
174
+ " \"parameter\": \"triangularity_scale\",\n",
175
+ " \"direction\": \"increase\",\n",
176
+ " \"magnitude\": \"small\",\n",
177
+ " },\n",
178
+ " {\"intent\": \"submit\"},\n",
179
+ " ]\n",
180
+ ")\n",
181
+ "print(parse_action_plan(sample_completion))"
182
  ]
183
  },
184
  {
 
202
  "\n",
203
  "prompts = []\n",
204
  "for seed_idx in range(len(RESET_SEEDS)):\n",
205
+ " obs = StellaratorEnvironment().reset(seed=seed_idx)\n",
 
206
  " prompt = build_prompt(obs)\n",
207
  " # Repeat each seed to create a larger training set\n",
208
  " for _ in range(50):\n",
 
220
  "source": [
221
  "## 6. Reward Function\n",
222
  "\n",
223
+ "The environment reward executes each generated action plan in the stellarator environment and returns the cumulative low-fidelity Reward V0 from the live environment. The environment's built-in reward decomposes feasibility (+3/-3 crossing bonuses, feasibility progress), objective (max elongation improvement), step costs, and failure penalties \u2014 see `server/environment.py:_compute_reward_breakdown(...)`.\n",
224
  "\n",
225
+ "For the current training workflow, the notebook ignores `submit` and does not auto-submit. GRPO therefore optimizes the low-fidelity `run` path only. The live observation telemetry still exposes `reward_breakdown` and `action_monitor` for debugging reward behavior.\n"
226
  ]
227
  },
228
  {
 
253
  " if len(actions) == 0:\n",
254
  " rewards.append(-3.0)\n",
255
  " continue\n",
256
+ " trace = run_episode_with_actions(\n",
257
+ " actions,\n",
258
+ " seed_idx=int(seeds[i]) % len(RESET_SEEDS),\n",
259
+ " )\n",
260
+ " rewards.append(trace.total_reward)\n",
 
 
 
 
 
 
261
  " except Exception:\n",
262
  " traceback.print_exc()\n",
263
  " rewards.append(-3.0)\n",
 
430
  "FastLanguageModel.for_inference(model)\n",
431
  "\n",
432
  "\n",
433
+ "def reward_term_summary(step_or_obs: object) -> str:\n",
434
+ " breakdown_obj = getattr(step_or_obs, \"reward_breakdown\")\n",
435
+ " breakdown = (\n",
436
+ " breakdown_obj.model_dump() if hasattr(breakdown_obj, \"model_dump\") else breakdown_obj\n",
437
+ " )\n",
438
  " terms = []\n",
439
  " for key, value in breakdown.items():\n",
440
  " if key in {\n",
 
471
  " outputs[0][inputs[\"input_ids\"].shape[1] :], skip_special_tokens=True\n",
472
  " )\n",
473
  " actions = parse_action_plan(completion)\n",
474
+ " episode = run_episode_with_actions(actions, seed_idx=seed_idx)\n",
475
+ " trace = [\n",
476
+ " (\n",
477
+ " f\"{step.action_label} \u2192 reward={step.reward:.3f} \"\n",
478
+ " f\"score={step.p1_score:.4f} feasible={step.constraints_satisfied} \"\n",
479
+ " f\"terms={reward_term_summary(step)}\"\n",
 
 
 
 
480
  " )\n",
481
+ " for step in episode.steps\n",
482
+ " ]\n",
483
+ " return episode.total_reward, trace\n",
484
  "\n",
485
  "\n",
486
  "def run_random_episode(seed_idx: int) -> float:\n",
487
  " \"\"\"Run one episode with random actions for comparison.\"\"\"\n",
488
+ " actions = [StellaratorAction(**random.choice(RUN_ACTION_SPECS)) for _ in range(BUDGET)]\n",
489
+ " return run_episode_with_actions(actions, seed_idx=seed_idx).total_reward\n",
 
 
 
 
 
 
 
 
 
490
  "\n",
491
  "\n",
492
  "# Evaluate\n",
 
567
  " actions = parse_action_plan(completion)\n",
568
  "\n",
569
  " print(f\"\\nTrained model generated {len(actions)} actions for remote env:\")\n",
 
570
  " for i, action in enumerate(actions[:BUDGET], start=1):\n",
571
  " if action.intent == \"submit\":\n",
572
  " continue\n",
 
580
  " )\n",
581
  " if result.done:\n",
582
  " print(f\" Episode done. Final score: {step_obs.p1_score:.4f}\")\n",
 
583
  " break\n",
584
  "print(\"\\nEnvironment is live and accessible for training and evaluation.\")"
585
  ]