CreativeEngineer Claude Opus 4.6 committed on
Commit
3bfd80a
·
1 Parent(s): 2dba2cf

feat: add HF Space deployment + GRPO training notebook

Browse files

- Add root-level re-export files (__init__.py, client.py, models.py)
for OpenEnv packaging convention
- Switch Dockerfile base from openenv-base to python:3.12-slim for
reliable HF Space builds
- Add Colab-ready GRPO training notebook using Unsloth + TRL
with environment reward functions

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Fusion Design Lab — OpenEnv P1 stellarator environment."""
client.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """Root-level re-export for OpenEnv packaging convention."""
2
+
3
+ from fusion_lab.client import FusionLabClient
4
+
5
+ __all__ = ["FusionLabClient"]
models.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Root-level re-export for OpenEnv packaging convention."""
2
+
3
+ from fusion_lab.models import (
4
+ ActionIntent,
5
+ DirectionName,
6
+ EvaluationFidelityName,
7
+ LowDimBoundaryParams,
8
+ MagnitudeName,
9
+ ParameterName,
10
+ StellaratorAction,
11
+ StellaratorObservation,
12
+ StellaratorState,
13
+ default_low_dim_boundary_params,
14
+ )
15
+
16
+ __all__ = [
17
+ "ActionIntent",
18
+ "DirectionName",
19
+ "EvaluationFidelityName",
20
+ "LowDimBoundaryParams",
21
+ "MagnitudeName",
22
+ "ParameterName",
23
+ "StellaratorAction",
24
+ "StellaratorObservation",
25
+ "StellaratorState",
26
+ "default_low_dim_boundary_params",
27
+ ]
server/Dockerfile CHANGED
@@ -1,43 +1,33 @@
1
- ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
2
- FROM ${BASE_IMAGE} AS builder
3
 
4
  WORKDIR /app
5
 
6
  RUN apt-get update && \
7
- apt-get install -y --no-install-recommends git && \
8
  rm -rf /var/lib/apt/lists/*
9
 
10
- ARG BUILD_MODE=standalone
11
- ARG ENV_NAME=fusion_design_lab
12
-
13
  COPY . /app/env
14
 
15
  WORKDIR /app/env
16
 
17
- RUN if ! command -v uv >/dev/null 2>&1; then \
18
- curl -LsSf https://astral.sh/uv/install.sh | sh && \
19
- mv /root/.local/bin/uv /usr/local/bin/uv && \
20
- mv /root/.local/bin/uvx /usr/local/bin/uvx; \
21
- fi
22
 
23
  RUN --mount=type=cache,target=/root/.cache/uv \
24
- if [ -f uv.lock ]; then \
25
- uv sync --frozen --no-install-project --no-editable; \
26
- else \
27
- uv sync --no-install-project --no-editable; \
28
- fi
29
 
30
  RUN --mount=type=cache,target=/root/.cache/uv \
31
- if [ -f uv.lock ]; then \
32
- uv sync --frozen --no-editable; \
33
- else \
34
- uv sync --no-editable; \
35
- fi
36
 
37
- FROM ${BASE_IMAGE}
38
 
39
  WORKDIR /app
40
 
 
 
 
 
41
  COPY --from=builder /app/env/.venv /app/.venv
42
  COPY --from=builder /app/env /app/env
43
 
 
1
+ FROM python:3.12-slim AS builder
 
2
 
3
  WORKDIR /app
4
 
5
  RUN apt-get update && \
6
+ apt-get install -y --no-install-recommends git curl && \
7
  rm -rf /var/lib/apt/lists/*
8
 
 
 
 
9
  COPY . /app/env
10
 
11
  WORKDIR /app/env
12
 
13
+ RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \
14
+ mv /root/.local/bin/uv /usr/local/bin/uv && \
15
+ mv /root/.local/bin/uvx /usr/local/bin/uvx
 
 
16
 
17
  RUN --mount=type=cache,target=/root/.cache/uv \
18
+ uv sync --frozen --no-install-project --no-editable
 
 
 
 
19
 
20
  RUN --mount=type=cache,target=/root/.cache/uv \
21
+ uv sync --frozen --no-editable
 
 
 
 
22
 
23
+ FROM python:3.12-slim
24
 
25
  WORKDIR /app
26
 
27
+ RUN apt-get update && \
28
+ apt-get install -y --no-install-recommends curl && \
29
+ rm -rf /var/lib/apt/lists/*
30
+
31
  COPY --from=builder /app/env/.venv /app/.venv
32
  COPY --from=builder /app/env /app/env
33
 
training/notebooks/fusion_design_lab_training.ipynb ADDED
@@ -0,0 +1,653 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "7fb27b941602401d91542211134fc71a",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Fusion Design Lab — GRPO Training\n",
9
+ "\n",
10
+ "Train an LLM to optimize stellarator fusion reactor designs using **GRPO** (Group Relative Policy Optimization) with **Unsloth** and **TRL**.\n",
11
+ "\n",
12
+ "The agent interacts with a constrained optimization environment where it adjusts 4 geometric knobs of a stellarator boundary, aiming to **minimize max elongation** while satisfying 3 hard physics constraints:\n",
13
+ "- `aspect_ratio ≤ 4.0`\n",
14
+ "- `average_triangularity ≤ -0.5`\n",
15
+ "- `edge_iota_over_nfp ≥ 0.3`\n",
16
+ "\n",
17
+ "Each episode has **6 evaluations** budgeted. The agent produces a plan of actions and the environment scores it via the `constellaration` physics verifier.\n",
18
+ "\n",
19
+ "**Environment deployed at**: https://creativeengineer-fusion-design-lab.hf.space\n",
20
+ "\n",
21
+ "**Runtime**: Select GPU (T4 or better) via `Runtime > Change runtime type`."
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "markdown",
26
+ "id": "acae54e37e7d407bbb7b55eff062a284",
27
+ "metadata": {},
28
+ "source": [
29
+ "## 1. Install Dependencies"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": null,
35
+ "id": "9a63283cbaf04dbcab1f6479b197f3a8",
36
+ "metadata": {},
37
+ "outputs": [],
38
+ "source": [
39
+ "%%capture\n",
40
+ "!pip install unsloth vllm\n",
41
+ "!pip install --no-deps trl\n",
42
+ "!pip install constellaration openenv-core[core] pydantic fastapi uvicorn\n",
43
+ "!pip install matplotlib"
44
+ ]
45
+ },
46
+ {
47
+ "cell_type": "markdown",
48
+ "id": "8dd0d8092fe74a7c96281538738b07e2",
49
+ "metadata": {},
50
+ "source": [
51
+ "## 2. Load Model with Unsloth"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "code",
56
+ "execution_count": null,
57
+ "id": "72eea5119410473aa328ad9291626812",
58
+ "metadata": {},
59
+ "outputs": [],
60
+ "source": [
61
+ "from unsloth import FastLanguageModel\n",
62
+ "\n",
63
+ "MODEL_NAME = \"unsloth/Qwen3-0.6B\"\n",
64
+ "MAX_SEQ_LENGTH = 2048\n",
65
+ "\n",
66
+ "model, tokenizer = FastLanguageModel.from_pretrained(\n",
67
+ " model_name=MODEL_NAME,\n",
68
+ " max_seq_length=MAX_SEQ_LENGTH,\n",
69
+ " load_in_4bit=True,\n",
70
+ " fast_inference=True,\n",
71
+ ")\n",
72
+ "\n",
73
+ "model = FastLanguageModel.get_peft_model(\n",
74
+ " model,\n",
75
+ " r=32,\n",
76
+ " target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
77
+ " lora_alpha=32,\n",
78
+ " use_gradient_checkpointing=\"unsloth\",\n",
79
+ ")\n",
80
+ "\n",
81
+ "print(f\"Model loaded: {MODEL_NAME}\")"
82
+ ]
83
+ },
84
+ {
85
+ "cell_type": "markdown",
86
+ "id": "8edb47106e1a46a883d545849b8ab81b",
87
+ "metadata": {},
88
+ "source": [
89
+ "## 3. Setup Stellarator Environment\n",
90
+ "\n",
91
+ "We install the environment package directly from the HF Space repository so training runs locally (no network latency). The same environment is deployed at the HF Space URL above."
92
+ ]
93
+ },
94
+ {
95
+ "cell_type": "code",
96
+ "execution_count": null,
97
+ "id": "10185d26023b46108eb7d9f57d49d2b3",
98
+ "metadata": {},
99
+ "outputs": [],
100
+ "source": [
101
+ "%%capture\n",
102
+ "!pip install git+https://huggingface.co/spaces/CreativeEngineer/fusion-design-lab"
103
+ ]
104
+ },
105
+ {
106
+ "cell_type": "code",
107
+ "execution_count": null,
108
+ "id": "8763a12b2bbd4a93a75aff182afb95dc",
109
+ "metadata": {},
110
+ "outputs": [],
111
+ "source": [
112
+ "import json\n",
113
+ "import re\n",
114
+ "from typing import Final\n",
115
+ "\n",
116
+ "from fusion_lab.models import StellaratorAction, StellaratorObservation\n",
117
+ "from server.contract import RESET_SEEDS\n",
118
+ "from server.environment import BUDGET, StellaratorEnvironment\n",
119
+ "\n",
120
+ "AVAILABLE_ACTIONS: Final[list[dict[str, str]]] = [\n",
121
+ " {\"intent\": \"run\", \"parameter\": p, \"direction\": d, \"magnitude\": m}\n",
122
+ " for p in [\"aspect_ratio\", \"elongation\", \"rotational_transform\", \"triangularity_scale\"]\n",
123
+ " for d in [\"increase\", \"decrease\"]\n",
124
+ " for m in [\"small\", \"medium\", \"large\"]\n",
125
+ "] + [\n",
126
+ " {\"intent\": \"restore_best\"},\n",
127
+ " {\"intent\": \"submit\"},\n",
128
+ "]\n",
129
+ "\n",
130
+ "ACTION_LABELS: Final[list[str]] = [\n",
131
+ " f\"{a['intent']} {a.get('parameter', '')} {a.get('direction', '')} {a.get('magnitude', '')}\".strip()\n",
132
+ " for a in AVAILABLE_ACTIONS\n",
133
+ "]\n",
134
+ "\n",
135
+ "# Quick smoke test\n",
136
+ "env = StellaratorEnvironment()\n",
137
+ "obs = env.reset(seed=0)\n",
138
+ "print(\n",
139
+ " f\"Environment ready. Initial score: {obs.p1_score:.4f}, feasibility: {obs.p1_feasibility:.4f}\"\n",
140
+ ")\n",
141
+ "print(f\"Budget: {obs.budget_remaining}, Constraints satisfied: {obs.constraints_satisfied}\")"
142
+ ]
143
+ },
144
+ {
145
+ "cell_type": "markdown",
146
+ "id": "7623eae2785240b9bd12b16a66d81610",
147
+ "metadata": {},
148
+ "source": [
149
+ "## 4. Prompt Template & Action Parsing\n",
150
+ "\n",
151
+ "Each training sample is a prompt describing the stellarator task and initial state. The model generates a plan of actions to optimize the design."
152
+ ]
153
+ },
154
+ {
155
+ "cell_type": "code",
156
+ "execution_count": null,
157
+ "id": "7cdc8c89c7104fffa095e18ddfef8986",
158
+ "metadata": {},
159
+ "outputs": [],
160
+ "source": [
161
+ "SYSTEM_PROMPT: Final[\n",
162
+ " str\n",
163
+ "] = \"\"\"You are an expert stellarator fusion reactor designer. Your goal is to optimize a stellarator design by adjusting 4 geometric parameters to minimize max elongation while satisfying physics constraints.\n",
164
+ "\n",
165
+ "Constraints:\n",
166
+ "- aspect_ratio <= 4.0\n",
167
+ "- average_triangularity <= -0.5\n",
168
+ "- edge_iota_over_nfp >= 0.3\n",
169
+ "\n",
170
+ "Available parameters: aspect_ratio, elongation, rotational_transform, triangularity_scale\n",
171
+ "Available directions: increase, decrease\n",
172
+ "Available magnitudes: small, medium, large\n",
173
+ "\n",
174
+ "You have a budget of 6 evaluations. Output a plan of actions as a JSON array. Each action is an object with keys: intent, parameter, direction, magnitude. The last action should be {\"intent\": \"submit\"} to finalize your design.\n",
175
+ "\n",
176
+ "Example:\n",
177
+ "[{\"intent\":\"run\",\"parameter\":\"triangularity_scale\",\"direction\":\"increase\",\"magnitude\":\"small\"},{\"intent\":\"run\",\"parameter\":\"rotational_transform\",\"direction\":\"increase\",\"magnitude\":\"medium\"},{\"intent\":\"submit\"}]\"\"\"\n",
178
+ "\n",
179
+ "\n",
180
+ "def format_observation(obs: StellaratorObservation) -> str:\n",
181
+ " return (\n",
182
+ " f\"Current stellarator state:\\n\"\n",
183
+ " f\" max_elongation: {obs.max_elongation:.4f}\\n\"\n",
184
+ " f\" aspect_ratio: {obs.aspect_ratio:.4f} (constraint: <= 4.0)\\n\"\n",
185
+ " f\" average_triangularity: {obs.average_triangularity:.6f} (constraint: <= -0.5)\\n\"\n",
186
+ " f\" edge_iota_over_nfp: {obs.edge_iota_over_nfp:.4f} (constraint: >= 0.3)\\n\"\n",
187
+ " f\" p1_score: {obs.p1_score:.4f}\\n\"\n",
188
+ " f\" feasibility: {obs.p1_feasibility:.4f}\\n\"\n",
189
+ " f\" constraints_satisfied: {obs.constraints_satisfied}\\n\"\n",
190
+ " f\" budget_remaining: {obs.budget_remaining}\\n\"\n",
191
+ " f\"\\nGenerate an action plan as a JSON array to optimize this design.\"\n",
192
+ " )\n",
193
+ "\n",
194
+ "\n",
195
+ "def build_prompt(obs: StellaratorObservation) -> str:\n",
196
+ " return (\n",
197
+ " f\"<|im_start|>system\\n{SYSTEM_PROMPT}<|im_end|>\\n\"\n",
198
+ " f\"<|im_start|>user\\n{format_observation(obs)}<|im_end|>\\n\"\n",
199
+ " f\"<|im_start|>assistant\\n\"\n",
200
+ " )\n",
201
+ "\n",
202
+ "\n",
203
+ "def parse_action_plan(text: str) -> list[StellaratorAction]:\n",
204
+ " \"\"\"Parse a JSON action plan from model output.\"\"\"\n",
205
+ " # Find JSON array in the text\n",
206
+ " match = re.search(r\"\\[.*?\\]\", text, re.DOTALL)\n",
207
+ " if not match:\n",
208
+ " return []\n",
209
+ " try:\n",
210
+ " raw = json.loads(match.group())\n",
211
+ " except json.JSONDecodeError:\n",
212
+ " return []\n",
213
+ " actions = []\n",
214
+ " for item in raw:\n",
215
+ " if not isinstance(item, dict) or \"intent\" not in item:\n",
216
+ " continue\n",
217
+ " intent = item[\"intent\"]\n",
218
+ " if intent == \"submit\":\n",
219
+ " actions.append(StellaratorAction(intent=\"submit\"))\n",
220
+ " break\n",
221
+ " if intent == \"restore_best\":\n",
222
+ " actions.append(StellaratorAction(intent=\"restore_best\"))\n",
223
+ " continue\n",
224
+ " if intent == \"run\":\n",
225
+ " p = item.get(\"parameter\", \"\")\n",
226
+ " d = item.get(\"direction\", \"\")\n",
227
+ " m = item.get(\"magnitude\", \"small\")\n",
228
+ " if p in (\n",
229
+ " \"aspect_ratio\",\n",
230
+ " \"elongation\",\n",
231
+ " \"rotational_transform\",\n",
232
+ " \"triangularity_scale\",\n",
233
+ " ) and d in (\"increase\", \"decrease\"):\n",
234
+ " if m not in (\"small\", \"medium\", \"large\"):\n",
235
+ " m = \"small\"\n",
236
+ " actions.append(\n",
237
+ " StellaratorAction(intent=\"run\", parameter=p, direction=d, magnitude=m)\n",
238
+ " )\n",
239
+ " return actions\n",
240
+ "\n",
241
+ "\n",
242
+ "# Test prompt\n",
243
+ "env = StellaratorEnvironment()\n",
244
+ "obs = env.reset(seed=0)\n",
245
+ "prompt = build_prompt(obs)\n",
246
+ "print(prompt[:500])\n",
247
+ "print(\"...\")"
248
+ ]
249
+ },
250
+ {
251
+ "cell_type": "markdown",
252
+ "id": "b118ea5561624da68c537baed56e602f",
253
+ "metadata": {},
254
+ "source": [
255
+ "## 5. Training Dataset\n",
256
+ "\n",
257
+ "Create prompts from all 3 reset seeds. Each prompt is an initial observation that the model must optimize."
258
+ ]
259
+ },
260
+ {
261
+ "cell_type": "code",
262
+ "execution_count": null,
263
+ "id": "938c804e27f84196a10c8828c723f798",
264
+ "metadata": {},
265
+ "outputs": [],
266
+ "source": [
267
+ "from datasets import Dataset\n",
268
+ "\n",
269
+ "prompts = []\n",
270
+ "for seed_idx in range(len(RESET_SEEDS)):\n",
271
+ " env = StellaratorEnvironment()\n",
272
+ " obs = env.reset(seed=seed_idx)\n",
273
+ " prompt = build_prompt(obs)\n",
274
+ " # Repeat each seed to create a larger training set\n",
275
+ " for _ in range(50):\n",
276
+ " prompts.append({\"prompt\": prompt, \"seed_idx\": seed_idx})\n",
277
+ "\n",
278
+ "dataset = Dataset.from_list(prompts)\n",
279
+ "dataset = dataset.shuffle(seed=42)\n",
280
+ "print(f\"Training dataset: {len(dataset)} samples from {len(RESET_SEEDS)} seeds\")"
281
+ ]
282
+ },
283
+ {
284
+ "cell_type": "markdown",
285
+ "id": "504fb2a444614c0babb325280ed9130a",
286
+ "metadata": {},
287
+ "source": [
288
+ "## 6. Reward Functions\n",
289
+ "\n",
290
+ "Two reward signals:\n",
291
+ "1. **Format reward**: Does the completion contain a valid JSON action plan?\n",
292
+ "2. **Environment reward**: Execute the plan in the stellarator environment and return cumulative reward."
293
+ ]
294
+ },
295
+ {
296
+ "cell_type": "code",
297
+ "execution_count": null,
298
+ "id": "59bbdb311c014d738909a11f9e486628",
299
+ "metadata": {},
300
+ "outputs": [],
301
+ "source": [
302
+ "import traceback\n",
303
+ "\n",
304
+ "\n",
305
+ "def format_reward_fn(completions: list[str], **kwargs) -> list[float]:\n",
306
+ " \"\"\"Reward for producing a valid, parseable action plan.\"\"\"\n",
307
+ " rewards = []\n",
308
+ " for completion in completions:\n",
309
+ " actions = parse_action_plan(completion)\n",
310
+ " if len(actions) == 0:\n",
311
+ " rewards.append(-1.0)\n",
312
+ " elif any(a.intent == \"submit\" for a in actions):\n",
313
+ " rewards.append(1.0) # valid plan ending with submit\n",
314
+ " else:\n",
315
+ " rewards.append(0.0) # valid actions but no submit\n",
316
+ " return rewards\n",
317
+ "\n",
318
+ "\n",
319
+ "def environment_reward_fn(\n",
320
+ " completions: list[str], seed_idx: list[int] | None = None, **kwargs\n",
321
+ ") -> list[float]:\n",
322
+ " \"\"\"Execute each action plan in the environment and return cumulative reward.\"\"\"\n",
323
+ " rewards = []\n",
324
+ " seeds = seed_idx if seed_idx is not None else [0] * len(completions)\n",
325
+ " for i, completion in enumerate(completions):\n",
326
+ " try:\n",
327
+ " actions = parse_action_plan(completion)\n",
328
+ " if len(actions) == 0:\n",
329
+ " rewards.append(-3.0)\n",
330
+ " continue\n",
331
+ " env = StellaratorEnvironment()\n",
332
+ " env.reset(seed=int(seeds[i]) % len(RESET_SEEDS))\n",
333
+ " total_reward = 0.0\n",
334
+ " for action in actions[:BUDGET]:\n",
335
+ " obs = env.step(action)\n",
336
+ " total_reward += float(obs.reward or 0.0)\n",
337
+ " if obs.done:\n",
338
+ " break\n",
339
+ " rewards.append(total_reward)\n",
340
+ " except Exception:\n",
341
+ " traceback.print_exc()\n",
342
+ " rewards.append(-3.0)\n",
343
+ " return rewards\n",
344
+ "\n",
345
+ "\n",
346
+ "# Test reward functions with a hand-crafted plan\n",
347
+ "test_plan = json.dumps(\n",
348
+ " [\n",
349
+ " {\n",
350
+ " \"intent\": \"run\",\n",
351
+ " \"parameter\": \"triangularity_scale\",\n",
352
+ " \"direction\": \"increase\",\n",
353
+ " \"magnitude\": \"small\",\n",
354
+ " },\n",
355
+ " {\n",
356
+ " \"intent\": \"run\",\n",
357
+ " \"parameter\": \"rotational_transform\",\n",
358
+ " \"direction\": \"increase\",\n",
359
+ " \"magnitude\": \"medium\",\n",
360
+ " },\n",
361
+ " {\"intent\": \"submit\"},\n",
362
+ " ]\n",
363
+ ")\n",
364
+ "print(f\"Format reward: {format_reward_fn([test_plan])}\")\n",
365
+ "print(f\"Environment reward: {environment_reward_fn([test_plan], seed_idx=[0])}\")"
366
+ ]
367
+ },
368
+ {
369
+ "cell_type": "markdown",
370
+ "id": "b43b363d81ae4b689946ece5c682cd59",
371
+ "metadata": {},
372
+ "source": [
373
+ "## 7. GRPO Training\n",
374
+ "\n",
375
+ "Train the model using Group Relative Policy Optimization. GRPO generates multiple completions per prompt and updates the policy toward higher-reward completions."
376
+ ]
377
+ },
378
+ {
379
+ "cell_type": "code",
380
+ "execution_count": null,
381
+ "id": "8a65eabff63a45729fe45fb5ade58bdc",
382
+ "metadata": {},
383
+ "outputs": [],
384
+ "source": [
385
+ "from trl import GRPOConfig, GRPOTrainer\n",
386
+ "\n",
387
+ "MAX_PROMPT_LENGTH = 768\n",
388
+ "MAX_COMPLETION_LENGTH = MAX_SEQ_LENGTH - MAX_PROMPT_LENGTH\n",
389
+ "\n",
390
+ "training_args = GRPOConfig(\n",
391
+ " output_dir=\"./grpo_fusion_output\",\n",
392
+ " learning_rate=2e-4,\n",
393
+ " num_generations=4,\n",
394
+ " max_completion_length=MAX_COMPLETION_LENGTH,\n",
395
+ " max_prompt_length=MAX_PROMPT_LENGTH,\n",
396
+ " per_device_train_batch_size=4,\n",
397
+ " gradient_accumulation_steps=1,\n",
398
+ " max_steps=60,\n",
399
+ " temperature=1.0,\n",
400
+ " logging_steps=1,\n",
401
+ " save_steps=20,\n",
402
+ " bf16=True,\n",
403
+ " report_to=\"none\",\n",
404
+ " seed=42,\n",
405
+ ")\n",
406
+ "\n",
407
+ "trainer = GRPOTrainer(\n",
408
+ " model=model,\n",
409
+ " processing_class=tokenizer,\n",
410
+ " reward_funcs=[format_reward_fn, environment_reward_fn],\n",
411
+ " args=training_args,\n",
412
+ " train_dataset=dataset,\n",
413
+ ")\n",
414
+ "\n",
415
+ "print(\"Starting GRPO training...\")\n",
416
+ "train_result = trainer.train()\n",
417
+ "print(f\"Training complete. Total steps: {train_result.global_step}\")"
418
+ ]
419
+ },
420
+ {
421
+ "cell_type": "markdown",
422
+ "id": "c3933fab20d04ec698c2621248eb3be0",
423
+ "metadata": {},
424
+ "source": [
425
+ "## 8. Training Results\n",
426
+ "\n",
427
+ "Visualize reward improvement over training steps."
428
+ ]
429
+ },
430
+ {
431
+ "cell_type": "code",
432
+ "execution_count": null,
433
+ "id": "4dd4641cc4064e0191573fe9c69df29b",
434
+ "metadata": {},
435
+ "outputs": [],
436
+ "source": [
437
+ "import matplotlib.pyplot as plt\n",
438
+ "\n",
439
+ "log_history = trainer.state.log_history\n",
440
+ "steps = [entry[\"step\"] for entry in log_history if \"loss\" in entry]\n",
441
+ "losses = [entry[\"loss\"] for entry in log_history if \"loss\" in entry]\n",
442
+ "\n",
443
+ "# Extract reward metrics if available\n",
444
+ "reward_steps = [\n",
445
+ " entry[\"step\"]\n",
446
+ " for entry in log_history\n",
447
+ " if \"reward\" in entry or \"rewards/environment_reward_fn\" in entry\n",
448
+ "]\n",
449
+ "rewards = [\n",
450
+ " entry.get(\"reward\", entry.get(\"rewards/environment_reward_fn\", 0))\n",
451
+ " for entry in log_history\n",
452
+ " if \"reward\" in entry or \"rewards/environment_reward_fn\" in entry\n",
453
+ "]\n",
454
+ "\n",
455
+ "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
456
+ "\n",
457
+ "axes[0].plot(steps, losses, \"b-\", alpha=0.7)\n",
458
+ "axes[0].set_xlabel(\"Step\")\n",
459
+ "axes[0].set_ylabel(\"Loss\")\n",
460
+ "axes[0].set_title(\"GRPO Training Loss\")\n",
461
+ "axes[0].grid(True, alpha=0.3)\n",
462
+ "\n",
463
+ "if rewards:\n",
464
+ " axes[1].plot(reward_steps, rewards, \"g-o\", alpha=0.7, markersize=3)\n",
465
+ " axes[1].set_xlabel(\"Step\")\n",
466
+ " axes[1].set_ylabel(\"Mean Reward\")\n",
467
+ " axes[1].set_title(\"Environment Reward Over Training\")\n",
468
+ " axes[1].grid(True, alpha=0.3)\n",
469
+ "else:\n",
470
+ " axes[1].text(0.5, 0.5, \"Reward metrics not logged\", ha=\"center\", va=\"center\")\n",
471
+ "\n",
472
+ "plt.suptitle(\"Fusion Design Lab — GRPO Training Curves\", fontsize=14, fontweight=\"bold\")\n",
473
+ "plt.tight_layout()\n",
474
+ "plt.savefig(\"training_curves.png\", dpi=150, bbox_inches=\"tight\")\n",
475
+ "plt.show()\n",
476
+ "print(\"Saved training_curves.png\")"
477
+ ]
478
+ },
479
+ {
480
+ "cell_type": "markdown",
481
+ "id": "8309879909854d7188b41380fd92a7c3",
482
+ "metadata": {},
483
+ "source": [
484
+ "## 9. Evaluate Trained Policy\n",
485
+ "\n",
486
+ "Generate action plans from the trained model and compare against random baselines."
487
+ ]
488
+ },
489
+ {
490
+ "cell_type": "code",
491
+ "execution_count": null,
492
+ "id": "3ed186c9a28b402fb0bc4494df01f08d",
493
+ "metadata": {},
494
+ "outputs": [],
495
+ "source": [
496
+ "import random\n",
497
+ "\n",
498
+ "FastLanguageModel.for_inference(model)\n",
499
+ "\n",
500
+ "\n",
501
+ "def run_episode_with_model(seed_idx: int) -> tuple[float, list[str]]:\n",
502
+ " \"\"\"Run one episode using the trained model.\"\"\"\n",
503
+ " env = StellaratorEnvironment()\n",
504
+ " obs = env.reset(seed=seed_idx)\n",
505
+ " prompt = build_prompt(obs)\n",
506
+ " inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
507
+ " outputs = model.generate(\n",
508
+ " **inputs,\n",
509
+ " max_new_tokens=MAX_COMPLETION_LENGTH,\n",
510
+ " temperature=0.7,\n",
511
+ " do_sample=True,\n",
512
+ " )\n",
513
+ " completion = tokenizer.decode(\n",
514
+ " outputs[0][inputs[\"input_ids\"].shape[1] :], skip_special_tokens=True\n",
515
+ " )\n",
516
+ " actions = parse_action_plan(completion)\n",
517
+ " trace = []\n",
518
+ " total_reward = 0.0\n",
519
+ " for action in actions[:BUDGET]:\n",
520
+ " obs = env.step(action)\n",
521
+ " r = float(obs.reward or 0.0)\n",
522
+ " total_reward += r\n",
523
+ " trace.append(\n",
524
+ " f\" {action.intent} {action.parameter or ''} {action.direction or ''} {action.magnitude or ''} → reward={r:.3f} score={obs.p1_score:.4f} feasible={obs.constraints_satisfied}\".strip()\n",
525
+ " )\n",
526
+ " if obs.done:\n",
527
+ " break\n",
528
+ " return total_reward, trace\n",
529
+ "\n",
530
+ "\n",
531
+ "def run_random_episode(seed_idx: int) -> float:\n",
532
+ " \"\"\"Run one episode with random actions for comparison.\"\"\"\n",
533
+ " env = StellaratorEnvironment()\n",
534
+ " env.reset(seed=seed_idx)\n",
535
+ " total_reward = 0.0\n",
536
+ " for step in range(BUDGET - 1):\n",
537
+ " spec = random.choice(AVAILABLE_ACTIONS[:24]) # run actions only\n",
538
+ " action = StellaratorAction(**spec)\n",
539
+ " obs = env.step(action)\n",
540
+ " total_reward += float(obs.reward or 0.0)\n",
541
+ " if obs.done:\n",
542
+ " return total_reward\n",
543
+ " # submit on last step\n",
544
+ " obs = env.step(StellaratorAction(intent=\"submit\"))\n",
545
+ " total_reward += float(obs.reward or 0.0)\n",
546
+ " return total_reward\n",
547
+ "\n",
548
+ "\n",
549
+ "# Evaluate\n",
550
+ "print(\"=\" * 60)\n",
551
+ "print(\"TRAINED MODEL EPISODES\")\n",
552
+ "print(\"=\" * 60)\n",
553
+ "trained_rewards = []\n",
554
+ "for seed in range(len(RESET_SEEDS)):\n",
555
+ " reward, trace = run_episode_with_model(seed)\n",
556
+ " trained_rewards.append(reward)\n",
557
+ " print(f\"\\nSeed {seed} — Total reward: {reward:.3f}\")\n",
558
+ " for line in trace:\n",
559
+ " print(f\" {line}\")\n",
560
+ "\n",
561
+ "print(f\"\\nMean trained reward: {sum(trained_rewards) / len(trained_rewards):.3f}\")\n",
562
+ "\n",
563
+ "print(\"\\n\" + \"=\" * 60)\n",
564
+ "print(\"RANDOM BASELINE (10 episodes per seed)\")\n",
565
+ "print(\"=\" * 60)\n",
566
+ "random_rewards = []\n",
567
+ "for seed in range(len(RESET_SEEDS)):\n",
568
+ " seed_rewards = [run_random_episode(seed) for _ in range(10)]\n",
569
+ " random_rewards.extend(seed_rewards)\n",
570
+ " print(\n",
571
+ " f\"Seed {seed} — Mean: {sum(seed_rewards) / len(seed_rewards):.3f}, Best: {max(seed_rewards):.3f}\"\n",
572
+ " )\n",
573
+ "\n",
574
+ "print(f\"\\nMean random reward: {sum(random_rewards) / len(random_rewards):.3f}\")\n",
575
+ "print(f\"Mean trained reward: {sum(trained_rewards) / len(trained_rewards):.3f}\")"
576
+ ]
577
+ },
578
+ {
579
+ "cell_type": "markdown",
580
+ "id": "cb1e1581032b452c9409d6c6813c49d1",
581
+ "metadata": {},
582
+ "source": [
583
+ "## 10. Connect to Deployed HF Space\n",
584
+ "\n",
585
+ "Demonstrate connecting to the live environment on Hugging Face Spaces."
586
+ ]
587
+ },
588
+ {
589
+ "cell_type": "code",
590
+ "execution_count": null,
591
+ "id": "379cbbc1e968416e875cc15c1202d7eb",
592
+ "metadata": {},
593
+ "outputs": [],
594
+ "source": [
595
+ "from fusion_lab.client import FusionLabClient\n",
596
+ "from fusion_lab.models import StellaratorAction\n",
597
+ "\n",
598
+ "HF_SPACE_URL = \"https://creativeengineer-fusion-design-lab.hf.space\"\n",
599
+ "\n",
600
+ "with FusionLabClient(base_url=HF_SPACE_URL).sync() as client:\n",
601
+ " obs = client.reset()\n",
602
+ " print(f\"Connected to HF Space: {HF_SPACE_URL}\")\n",
603
+ " print(\"Initial observation:\")\n",
604
+ " print(f\" max_elongation: {obs.observation.max_elongation:.4f}\")\n",
605
+ " print(f\" aspect_ratio: {obs.observation.aspect_ratio:.4f}\")\n",
606
+ " print(f\" p1_score: {obs.observation.p1_score:.4f}\")\n",
607
+ " print(f\" constraints_satisfied: {obs.observation.constraints_satisfied}\")\n",
608
+ " print(f\" budget_remaining: {obs.observation.budget_remaining}\")\n",
609
+ "\n",
610
+ " # Run one action from the trained model\n",
611
+ " prompt = build_prompt(obs.observation)\n",
612
+ " inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
613
+ " outputs = model.generate(\n",
614
+ " **inputs, max_new_tokens=MAX_COMPLETION_LENGTH, temperature=0.7, do_sample=True\n",
615
+ " )\n",
616
+ " completion = tokenizer.decode(\n",
617
+ " outputs[0][inputs[\"input_ids\"].shape[1] :], skip_special_tokens=True\n",
618
+ " )\n",
619
+ " actions = parse_action_plan(completion)\n",
620
+ "\n",
621
+ " print(f\"\\nModel generated {len(actions)} actions:\")\n",
622
+ " for i, action in enumerate(actions[:BUDGET]):\n",
623
+ " result = client.step(action)\n",
624
+ " print(\n",
625
+ " f\" Step {i + 1}: {action.intent} {action.parameter or ''} {action.direction or ''} {action.magnitude or ''} → reward={result.reward:.3f}\"\n",
626
+ " )\n",
627
+ " if result.done:\n",
628
+ " print(f\" Episode done. Final score: {result.observation.p1_score:.4f}\")\n",
629
+ " break\n",
630
+ "\n",
631
+ "print(\"\\nDone! Environment is live and accessible for training and evaluation.\")"
632
+ ]
633
+ }
634
+ ],
635
+ "metadata": {
636
+ "accelerator": "GPU",
637
+ "colab": {
638
+ "gpuType": "T4",
639
+ "provenance": []
640
+ },
641
+ "kernelspec": {
642
+ "display_name": "Python 3",
643
+ "language": "python",
644
+ "name": "python3"
645
+ },
646
+ "language_info": {
647
+ "name": "python",
648
+ "version": "3.12.0"
649
+ }
650
+ },
651
+ "nbformat": 4,
652
+ "nbformat_minor": 5
653
+ }