Spaces:
Sleeping
Sleeping
Commit ·
ddcb837
1
Parent(s): 9c3599b
refactor: align colab notebook with shared llm helpers
Browse files
training/notebooks/fusion_design_lab_training.ipynb
CHANGED
|
@@ -112,24 +112,28 @@
|
|
| 112 |
"import json\n",
|
| 113 |
"from typing import Final\n",
|
| 114 |
"\n",
|
| 115 |
-
"from fusion_lab.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
"from server.contract import RESET_SEEDS\n",
|
| 117 |
"from server.environment import BUDGET, StellaratorEnvironment\n",
|
| 118 |
"\n",
|
| 119 |
-
"
|
| 120 |
" {\"intent\": \"run\", \"parameter\": p, \"direction\": d, \"magnitude\": m}\n",
|
| 121 |
-
" for p in
|
| 122 |
-
" for d in
|
| 123 |
-
" for m in
|
| 124 |
-
"] + [\n",
|
| 125 |
-
" {\"intent\": \"restore_best\"},\n",
|
| 126 |
"]\n",
|
| 127 |
"\n",
|
| 128 |
-
"
|
| 129 |
-
"
|
| 130 |
-
" for a in AVAILABLE_ACTIONS\n",
|
| 131 |
"]\n",
|
| 132 |
-
"\n",
|
| 133 |
"# Quick smoke test\n",
|
| 134 |
"env = StellaratorEnvironment()\n",
|
| 135 |
"obs = env.reset(seed=0)\n",
|
|
@@ -156,135 +160,25 @@
|
|
| 156 |
"metadata": {},
|
| 157 |
"outputs": [],
|
| 158 |
"source": [
|
| 159 |
-
"
|
| 160 |
-
" str\n",
|
| 161 |
-
"] = \"\"\"You are an expert stellarator fusion reactor designer. Your goal is to optimize a stellarator design by adjusting 4 geometric parameters to minimize max elongation while satisfying physics constraints.\n",
|
| 162 |
-
"\n",
|
| 163 |
-
"Constraints:\n",
|
| 164 |
-
"- aspect_ratio <= 4.0\n",
|
| 165 |
-
"- average_triangularity <= -0.5\n",
|
| 166 |
-
"- edge_iota_over_nfp >= 0.3\n",
|
| 167 |
-
"\n",
|
| 168 |
-
"Available parameters: aspect_ratio, elongation, rotational_transform, triangularity_scale\n",
|
| 169 |
-
"Available directions: increase, decrease\n",
|
| 170 |
-
"Available magnitudes: small, medium, large\n",
|
| 171 |
-
"\n",
|
| 172 |
-
"You have a budget of 6 low-fidelity evaluations. Output a short plan of run actions as a JSON array. Each action is an object with keys: intent, parameter, direction, magnitude. Do not output submit.\n",
|
| 173 |
-
"\n",
|
| 174 |
-
"Example:\n",
|
| 175 |
-
"[{\"intent\":\"run\",\"parameter\":\"triangularity_scale\",\"direction\":\"increase\",\"magnitude\":\"small\"},{\"intent\":\"run\",\"parameter\":\"rotational_transform\",\"direction\":\"increase\",\"magnitude\":\"medium\"}]\"\"\"\n",
|
| 176 |
-
"\n",
|
| 177 |
-
"\n",
|
| 178 |
-
"def format_observation(obs: StellaratorObservation) -> str:\n",
|
| 179 |
-
" return (\n",
|
| 180 |
-
" f\"Current stellarator state:\\n\"\n",
|
| 181 |
-
" f\" max_elongation: {obs.max_elongation:.4f}\\n\"\n",
|
| 182 |
-
" f\" aspect_ratio: {obs.aspect_ratio:.4f} (constraint: <= 4.0)\\n\"\n",
|
| 183 |
-
" f\" average_triangularity: {obs.average_triangularity:.6f} (constraint: <= -0.5)\\n\"\n",
|
| 184 |
-
" f\" edge_iota_over_nfp: {obs.edge_iota_over_nfp:.4f} (constraint: >= 0.3)\\n\"\n",
|
| 185 |
-
" f\" p1_score: {obs.p1_score:.4f}\\n\"\n",
|
| 186 |
-
" f\" feasibility: {obs.p1_feasibility:.4f}\\n\"\n",
|
| 187 |
-
" f\" constraints_satisfied: {obs.constraints_satisfied}\\n\"\n",
|
| 188 |
-
" f\" budget_remaining: {obs.budget_remaining}\\n\"\n",
|
| 189 |
-
" f\"\\nGenerate an action plan as a JSON array to optimize this design.\"\n",
|
| 190 |
-
" )\n",
|
| 191 |
-
"\n",
|
| 192 |
-
"\n",
|
| 193 |
-
"def build_prompt(obs: StellaratorObservation) -> str:\n",
|
| 194 |
-
" return (\n",
|
| 195 |
-
" f\"<|im_start|>system\\n{SYSTEM_PROMPT}<|im_end|>\\n\"\n",
|
| 196 |
-
" f\"<|im_start|>user\\n{format_observation(obs)}<|im_end|>\\n\"\n",
|
| 197 |
-
" f\"<|im_start|>assistant\\n\"\n",
|
| 198 |
-
" )\n",
|
| 199 |
-
"\n",
|
| 200 |
-
"\n",
|
| 201 |
-
"def _extract_json_array(text: str) -> str | None:\n",
|
| 202 |
-
" \"\"\"Return the first balanced [...] substring that parses as a JSON array.\n",
|
| 203 |
-
"\n",
|
| 204 |
-
" Iterates through every [ in text, finds its balanced closing ],\n",
|
| 205 |
-
" and attempts json.loads. Returns the first candidate that decodes as a\n",
|
| 206 |
-
" JSON list, skipping prose fragments like [draft].\n",
|
| 207 |
-
" \"\"\"\n",
|
| 208 |
-
" start = text.find(\"[\")\n",
|
| 209 |
-
" while start != -1:\n",
|
| 210 |
-
" depth = 0\n",
|
| 211 |
-
" in_string = False\n",
|
| 212 |
-
" escape = False\n",
|
| 213 |
-
" matched_end: int | None = None\n",
|
| 214 |
-
" for index in range(start, len(text)):\n",
|
| 215 |
-
" char = text[index]\n",
|
| 216 |
-
" if in_string:\n",
|
| 217 |
-
" if escape:\n",
|
| 218 |
-
" escape = False\n",
|
| 219 |
-
" elif char == \"\\\\\":\n",
|
| 220 |
-
" escape = True\n",
|
| 221 |
-
" elif char == '\"':\n",
|
| 222 |
-
" in_string = False\n",
|
| 223 |
-
" continue\n",
|
| 224 |
-
" if char == '\"':\n",
|
| 225 |
-
" in_string = True\n",
|
| 226 |
-
" elif char == \"[\":\n",
|
| 227 |
-
" depth += 1\n",
|
| 228 |
-
" elif char == \"]\":\n",
|
| 229 |
-
" depth -= 1\n",
|
| 230 |
-
" if depth == 0:\n",
|
| 231 |
-
" matched_end = index\n",
|
| 232 |
-
" break\n",
|
| 233 |
-
" if matched_end is not None:\n",
|
| 234 |
-
" candidate = text[start : matched_end + 1]\n",
|
| 235 |
-
" try:\n",
|
| 236 |
-
" decoded = json.loads(candidate)\n",
|
| 237 |
-
" if isinstance(decoded, list):\n",
|
| 238 |
-
" return candidate\n",
|
| 239 |
-
" except (json.JSONDecodeError, ValueError):\n",
|
| 240 |
-
" pass\n",
|
| 241 |
-
" start = text.find(\"[\", start + 1)\n",
|
| 242 |
-
" return None\n",
|
| 243 |
-
"\n",
|
| 244 |
-
"\n",
|
| 245 |
-
"def parse_action_plan(text: str) -> list[StellaratorAction]:\n",
|
| 246 |
-
" \"\"\"Parse a JSON action plan from model output.\"\"\"\n",
|
| 247 |
-
" array_text = _extract_json_array(text)\n",
|
| 248 |
-
" if array_text is None:\n",
|
| 249 |
-
" return []\n",
|
| 250 |
-
" try:\n",
|
| 251 |
-
" raw = json.loads(array_text)\n",
|
| 252 |
-
" except json.JSONDecodeError:\n",
|
| 253 |
-
" return []\n",
|
| 254 |
-
" actions = []\n",
|
| 255 |
-
" for item in raw:\n",
|
| 256 |
-
" if not isinstance(item, dict) or \"intent\" not in item:\n",
|
| 257 |
-
" continue\n",
|
| 258 |
-
" intent = item[\"intent\"]\n",
|
| 259 |
-
" if intent == \"submit\":\n",
|
| 260 |
-
" continue\n",
|
| 261 |
-
" if intent == \"restore_best\":\n",
|
| 262 |
-
" actions.append(StellaratorAction(intent=\"restore_best\"))\n",
|
| 263 |
-
" continue\n",
|
| 264 |
-
" if intent == \"run\":\n",
|
| 265 |
-
" p = item.get(\"parameter\", \"\")\n",
|
| 266 |
-
" d = item.get(\"direction\", \"\")\n",
|
| 267 |
-
" m = item.get(\"magnitude\", \"small\")\n",
|
| 268 |
-
" if p in (\n",
|
| 269 |
-
" \"aspect_ratio\",\n",
|
| 270 |
-
" \"elongation\",\n",
|
| 271 |
-
" \"rotational_transform\",\n",
|
| 272 |
-
" \"triangularity_scale\",\n",
|
| 273 |
-
" ) and d in (\"increase\", \"decrease\"):\n",
|
| 274 |
-
" if m not in (\"small\", \"medium\", \"large\"):\n",
|
| 275 |
-
" m = \"small\"\n",
|
| 276 |
-
" actions.append(\n",
|
| 277 |
-
" StellaratorAction(intent=\"run\", parameter=p, direction=d, magnitude=m)\n",
|
| 278 |
-
" )\n",
|
| 279 |
-
" return actions\n",
|
| 280 |
-
"\n",
|
| 281 |
-
"\n",
|
| 282 |
-
"# Test prompt\n",
|
| 283 |
"env = StellaratorEnvironment()\n",
|
| 284 |
"obs = env.reset(seed=0)\n",
|
| 285 |
"prompt = build_prompt(obs)\n",
|
| 286 |
"print(prompt[:500])\n",
|
| 287 |
-
"print(\"...\")"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 288 |
]
|
| 289 |
},
|
| 290 |
{
|
|
@@ -308,8 +202,7 @@
|
|
| 308 |
"\n",
|
| 309 |
"prompts = []\n",
|
| 310 |
"for seed_idx in range(len(RESET_SEEDS)):\n",
|
| 311 |
-
"
|
| 312 |
-
" obs = env.reset(seed=seed_idx)\n",
|
| 313 |
" prompt = build_prompt(obs)\n",
|
| 314 |
" # Repeat each seed to create a larger training set\n",
|
| 315 |
" for _ in range(50):\n",
|
|
@@ -327,9 +220,9 @@
|
|
| 327 |
"source": [
|
| 328 |
"## 6. Reward Function\n",
|
| 329 |
"\n",
|
| 330 |
-
"The environment reward executes each generated action plan in the stellarator environment and returns the cumulative V0
|
| 331 |
"\n",
|
| 332 |
-
"
|
| 333 |
]
|
| 334 |
},
|
| 335 |
{
|
|
@@ -360,17 +253,11 @@
|
|
| 360 |
" if len(actions) == 0:\n",
|
| 361 |
" rewards.append(-3.0)\n",
|
| 362 |
" continue\n",
|
| 363 |
-
"
|
| 364 |
-
"
|
| 365 |
-
"
|
| 366 |
-
"
|
| 367 |
-
"
|
| 368 |
-
" continue\n",
|
| 369 |
-
" obs = env.step(action)\n",
|
| 370 |
-
" total_reward += float(obs.reward) if obs.reward is not None else 0.0\n",
|
| 371 |
-
" if obs.done:\n",
|
| 372 |
-
" break\n",
|
| 373 |
-
" rewards.append(total_reward)\n",
|
| 374 |
" except Exception:\n",
|
| 375 |
" traceback.print_exc()\n",
|
| 376 |
" rewards.append(-3.0)\n",
|
|
@@ -543,8 +430,11 @@
|
|
| 543 |
"FastLanguageModel.for_inference(model)\n",
|
| 544 |
"\n",
|
| 545 |
"\n",
|
| 546 |
-
"def reward_term_summary(
|
| 547 |
-
"
|
|
|
|
|
|
|
|
|
|
| 548 |
" terms = []\n",
|
| 549 |
" for key, value in breakdown.items():\n",
|
| 550 |
" if key in {\n",
|
|
@@ -581,35 +471,22 @@
|
|
| 581 |
" outputs[0][inputs[\"input_ids\"].shape[1] :], skip_special_tokens=True\n",
|
| 582 |
" )\n",
|
| 583 |
" actions = parse_action_plan(completion)\n",
|
| 584 |
-
"
|
| 585 |
-
"
|
| 586 |
-
"
|
| 587 |
-
"
|
| 588 |
-
"
|
| 589 |
-
"
|
| 590 |
-
" r = float(obs.reward) if obs.reward is not None else 0.0\n",
|
| 591 |
-
" total_reward += r\n",
|
| 592 |
-
" trace.append(\n",
|
| 593 |
-
" f\" {action.intent} {action.parameter or ''} {action.direction or ''} {action.magnitude or ''} \u2192 reward={r:.3f} score={obs.p1_score:.4f} feasible={obs.constraints_satisfied} terms={reward_term_summary(obs)}\".strip()\n",
|
| 594 |
" )\n",
|
| 595 |
-
"
|
| 596 |
-
"
|
| 597 |
-
" return total_reward, trace\n",
|
| 598 |
"\n",
|
| 599 |
"\n",
|
| 600 |
"def run_random_episode(seed_idx: int) -> float:\n",
|
| 601 |
" \"\"\"Run one episode with random actions for comparison.\"\"\"\n",
|
| 602 |
-
"
|
| 603 |
-
"
|
| 604 |
-
" total_reward = 0.0\n",
|
| 605 |
-
" for step in range(BUDGET):\n",
|
| 606 |
-
" spec = random.choice(AVAILABLE_ACTIONS[:24]) # run actions only\n",
|
| 607 |
-
" action = StellaratorAction(**spec)\n",
|
| 608 |
-
" obs = env.step(action)\n",
|
| 609 |
-
" total_reward += float(obs.reward) if obs.reward is not None else 0.0\n",
|
| 610 |
-
" if obs.done:\n",
|
| 611 |
-
" return total_reward\n",
|
| 612 |
-
" return total_reward\n",
|
| 613 |
"\n",
|
| 614 |
"\n",
|
| 615 |
"# Evaluate\n",
|
|
@@ -690,7 +567,6 @@
|
|
| 690 |
" actions = parse_action_plan(completion)\n",
|
| 691 |
"\n",
|
| 692 |
" print(f\"\\nTrained model generated {len(actions)} actions for remote env:\")\n",
|
| 693 |
-
" done = False\n",
|
| 694 |
" for i, action in enumerate(actions[:BUDGET], start=1):\n",
|
| 695 |
" if action.intent == \"submit\":\n",
|
| 696 |
" continue\n",
|
|
@@ -704,7 +580,6 @@
|
|
| 704 |
" )\n",
|
| 705 |
" if result.done:\n",
|
| 706 |
" print(f\" Episode done. Final score: {step_obs.p1_score:.4f}\")\n",
|
| 707 |
-
" done = True\n",
|
| 708 |
" break\n",
|
| 709 |
"print(\"\\nEnvironment is live and accessible for training and evaluation.\")"
|
| 710 |
]
|
|
|
|
| 112 |
"import json\n",
|
| 113 |
"from typing import Final\n",
|
| 114 |
"\n",
|
| 115 |
+
"from fusion_lab.llm_agent import (\n",
|
| 116 |
+
" RUN_DIRECTIONS,\n",
|
| 117 |
+
" RUN_MAGNITUDES,\n",
|
| 118 |
+
" RUN_PARAMETERS,\n",
|
| 119 |
+
" build_prompt,\n",
|
| 120 |
+
" parse_action_plan,\n",
|
| 121 |
+
" run_episode_with_actions,\n",
|
| 122 |
+
")\n",
|
| 123 |
+
"from fusion_lab.models import StellaratorAction\n",
|
| 124 |
"from server.contract import RESET_SEEDS\n",
|
| 125 |
"from server.environment import BUDGET, StellaratorEnvironment\n",
|
| 126 |
"\n",
|
| 127 |
+
"RUN_ACTION_SPECS: Final[list[dict[str, str]]] = [\n",
|
| 128 |
" {\"intent\": \"run\", \"parameter\": p, \"direction\": d, \"magnitude\": m}\n",
|
| 129 |
+
" for p in RUN_PARAMETERS\n",
|
| 130 |
+
" for d in RUN_DIRECTIONS\n",
|
| 131 |
+
" for m in RUN_MAGNITUDES\n",
|
|
|
|
|
|
|
| 132 |
"]\n",
|
| 133 |
"\n",
|
| 134 |
+
"AVAILABLE_ACTIONS: Final[list[dict[str, str]]] = RUN_ACTION_SPECS + [\n",
|
| 135 |
+
" {\"intent\": \"restore_best\"},\n",
|
|
|
|
| 136 |
"]\n",
|
|
|
|
| 137 |
"# Quick smoke test\n",
|
| 138 |
"env = StellaratorEnvironment()\n",
|
| 139 |
"obs = env.reset(seed=0)\n",
|
|
|
|
| 160 |
"metadata": {},
|
| 161 |
"outputs": [],
|
| 162 |
"source": [
|
| 163 |
+
"# Shared helper smoke test\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
"env = StellaratorEnvironment()\n",
|
| 165 |
"obs = env.reset(seed=0)\n",
|
| 166 |
"prompt = build_prompt(obs)\n",
|
| 167 |
"print(prompt[:500])\n",
|
| 168 |
+
"print(\"...\")\n",
|
| 169 |
+
"\n",
|
| 170 |
+
"sample_completion = json.dumps(\n",
|
| 171 |
+
" [\n",
|
| 172 |
+
" {\n",
|
| 173 |
+
" \"intent\": \"run\",\n",
|
| 174 |
+
" \"parameter\": \"triangularity_scale\",\n",
|
| 175 |
+
" \"direction\": \"increase\",\n",
|
| 176 |
+
" \"magnitude\": \"small\",\n",
|
| 177 |
+
" },\n",
|
| 178 |
+
" {\"intent\": \"submit\"},\n",
|
| 179 |
+
" ]\n",
|
| 180 |
+
")\n",
|
| 181 |
+
"print(parse_action_plan(sample_completion))"
|
| 182 |
]
|
| 183 |
},
|
| 184 |
{
|
|
|
|
| 202 |
"\n",
|
| 203 |
"prompts = []\n",
|
| 204 |
"for seed_idx in range(len(RESET_SEEDS)):\n",
|
| 205 |
+
" obs = StellaratorEnvironment().reset(seed=seed_idx)\n",
|
|
|
|
| 206 |
" prompt = build_prompt(obs)\n",
|
| 207 |
" # Repeat each seed to create a larger training set\n",
|
| 208 |
" for _ in range(50):\n",
|
|
|
|
| 220 |
"source": [
|
| 221 |
"## 6. Reward Function\n",
|
| 222 |
"\n",
|
| 223 |
+
"The environment reward executes each generated action plan in the stellarator environment and returns the cumulative low-fidelity Reward V0 from the live environment. The environment's built-in reward decomposes feasibility (+3/-3 crossing bonuses, feasibility progress), objective (max elongation improvement), step costs, and failure penalties \u2014 see `server/environment.py:_compute_reward_breakdown(...)`.\n",
|
| 224 |
"\n",
|
| 225 |
+
"For the current training workflow, the notebook ignores `submit` and does not auto-submit. GRPO therefore optimizes the low-fidelity `run` path only. The live observation telemetry still exposes `reward_breakdown` and `action_monitor` for debugging reward behavior.\n"
|
| 226 |
]
|
| 227 |
},
|
| 228 |
{
|
|
|
|
| 253 |
" if len(actions) == 0:\n",
|
| 254 |
" rewards.append(-3.0)\n",
|
| 255 |
" continue\n",
|
| 256 |
+
" trace = run_episode_with_actions(\n",
|
| 257 |
+
" actions,\n",
|
| 258 |
+
" seed_idx=int(seeds[i]) % len(RESET_SEEDS),\n",
|
| 259 |
+
" )\n",
|
| 260 |
+
" rewards.append(trace.total_reward)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
" except Exception:\n",
|
| 262 |
" traceback.print_exc()\n",
|
| 263 |
" rewards.append(-3.0)\n",
|
|
|
|
| 430 |
"FastLanguageModel.for_inference(model)\n",
|
| 431 |
"\n",
|
| 432 |
"\n",
|
| 433 |
+
"def reward_term_summary(step_or_obs: object) -> str:\n",
|
| 434 |
+
" breakdown_obj = getattr(step_or_obs, \"reward_breakdown\")\n",
|
| 435 |
+
" breakdown = (\n",
|
| 436 |
+
" breakdown_obj.model_dump() if hasattr(breakdown_obj, \"model_dump\") else breakdown_obj\n",
|
| 437 |
+
" )\n",
|
| 438 |
" terms = []\n",
|
| 439 |
" for key, value in breakdown.items():\n",
|
| 440 |
" if key in {\n",
|
|
|
|
| 471 |
" outputs[0][inputs[\"input_ids\"].shape[1] :], skip_special_tokens=True\n",
|
| 472 |
" )\n",
|
| 473 |
" actions = parse_action_plan(completion)\n",
|
| 474 |
+
" episode = run_episode_with_actions(actions, seed_idx=seed_idx)\n",
|
| 475 |
+
" trace = [\n",
|
| 476 |
+
" (\n",
|
| 477 |
+
" f\"{step.action_label} \u2192 reward={step.reward:.3f} \"\n",
|
| 478 |
+
" f\"score={step.p1_score:.4f} feasible={step.constraints_satisfied} \"\n",
|
| 479 |
+
" f\"terms={reward_term_summary(step)}\"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 480 |
" )\n",
|
| 481 |
+
" for step in episode.steps\n",
|
| 482 |
+
" ]\n",
|
| 483 |
+
" return episode.total_reward, trace\n",
|
| 484 |
"\n",
|
| 485 |
"\n",
|
| 486 |
"def run_random_episode(seed_idx: int) -> float:\n",
|
| 487 |
" \"\"\"Run one episode with random actions for comparison.\"\"\"\n",
|
| 488 |
+
" actions = [StellaratorAction(**random.choice(RUN_ACTION_SPECS)) for _ in range(BUDGET)]\n",
|
| 489 |
+
" return run_episode_with_actions(actions, seed_idx=seed_idx).total_reward\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 490 |
"\n",
|
| 491 |
"\n",
|
| 492 |
"# Evaluate\n",
|
|
|
|
| 567 |
" actions = parse_action_plan(completion)\n",
|
| 568 |
"\n",
|
| 569 |
" print(f\"\\nTrained model generated {len(actions)} actions for remote env:\")\n",
|
|
|
|
| 570 |
" for i, action in enumerate(actions[:BUDGET], start=1):\n",
|
| 571 |
" if action.intent == \"submit\":\n",
|
| 572 |
" continue\n",
|
|
|
|
| 580 |
" )\n",
|
| 581 |
" if result.done:\n",
|
| 582 |
" print(f\" Episode done. Final score: {step_obs.p1_score:.4f}\")\n",
|
|
|
|
| 583 |
" break\n",
|
| 584 |
"print(\"\\nEnvironment is live and accessible for training and evaluation.\")"
|
| 585 |
]
|