File size: 10,788 Bytes
ebd0ff3
 
 
 
cdc237b
ebd0ff3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cdc237b
 
 
 
 
 
ebd0ff3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cdc237b
ebd0ff3
cdc237b
ebd0ff3
 
 
 
2fccde8
ebd0ff3
e826e11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ebd0ff3
 
 
 
 
 
 
 
 
 
 
 
 
5e0e606
 
 
 
ebd0ff3
 
 
 
 
 
 
 
 
 
 
5e0e606
 
 
 
ebd0ff3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2fccde8
 
 
 
 
ebd0ff3
 
 
 
 
 
cdc237b
ebd0ff3
 
 
 
 
 
cdc237b
 
 
 
 
 
 
ebd0ff3
cdc237b
ebd0ff3
cdc237b
ebd0ff3
 
 
 
e826e11
ebd0ff3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cdc237b
ebd0ff3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c3599b
 
ebd0ff3
9c3599b
ebd0ff3
 
 
 
 
 
 
 
9c3599b
cdc237b
ebd0ff3
 
 
 
 
 
5f2da5f
 
ebd0ff3
5f2da5f
ebd0ff3
 
 
 
 
 
 
 
 
 
 
 
5e0e606
 
 
 
ebd0ff3
 
 
5f2da5f
 
 
 
9c3599b
cdc237b
 
 
 
 
 
 
 
 
 
9c3599b
5f2da5f
 
ebd0ff3
 
9c3599b
5f2da5f
 
ebd0ff3
 
 
 
 
 
 
5e0e606
 
 
 
ebd0ff3
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
from __future__ import annotations

import json
from dataclasses import asdict, dataclass
from typing import Final, Literal, Sequence, TypedDict

from fusion_lab.models import (
    DirectionName,
    MagnitudeName,
    ParameterName,
    StellaratorAction,
    StellaratorObservation,
)
from server.environment import BUDGET, StellaratorEnvironment

# Closed vocabularies for the fields of a "run" action. ``_parse_action_item``
# drops any plan item whose parameter, direction, or magnitude is not listed here.
RUN_PARAMETERS: Final[tuple[ParameterName, ...]] = (
    "aspect_ratio",
    "elongation",
    "rotational_transform",
    "triangularity_scale",
)
RUN_DIRECTIONS: Final[tuple[DirectionName, ...]] = ("increase", "decrease")
RUN_MAGNITUDES: Final[tuple[MagnitudeName, ...]] = ("small", "medium", "large")


class PromptMessage(TypedDict):
    """Single chat message: a ``role`` tag plus its text ``content``.

    ``build_messages`` returns one message with role ``"system"`` and one with
    role ``"user"``.
    """

    role: Literal["system", "user"]
    content: str


# System message used by ``build_messages``. It states the goal, the four
# controllable parameters, the JSON action format that ``parse_action_plan``
# reads, and the constraint thresholds. The parameter names and the
# direction/magnitude choices listed here mirror RUN_PARAMETERS, RUN_DIRECTIONS
# and RUN_MAGNITUDES; keep them in sync when editing either side.
SYSTEM_PROMPT: Final[str] = """You are an expert stellarator designer.

Goal:
- satisfy the P1 physics constraints
- then improve the design score by lowering max elongation

You control a 4-knob low-dimensional design:
- aspect_ratio
- elongation
- rotational_transform
- triangularity_scale

Action rules:
- output a JSON array
- each item must be either:
  - {"intent":"run","parameter":"<parameter>","direction":"increase|decrease","magnitude":"small|medium|large"}
  - {"intent":"restore_best"}
  - {"intent":"submit"}
- keep the plan short and within the remaining budget
- use "submit" once when you want to stop and lock in the current design

Constraint directions:
- aspect_ratio <= 4.0
- average_triangularity <= -0.5
- abs(edge_iota_over_nfp) >= 0.3"""


def _extract_json_array(text: str) -> str | None:
    """Find the first bracketed span of *text* that decodes to a JSON list.

    Every ``[`` is tried in order of appearance.  For each one the matching
    ``]`` is located (nested brackets are counted; brackets inside
    double-quoted strings, including after backslash escapes, are ignored) and
    the enclosed span is handed to ``json.loads``.  The first span that decodes
    to a ``list`` is returned verbatim; spans such as ``[draft]`` that are not
    valid JSON are skipped.  Returns ``None`` when no span qualifies.
    """

    def closing_bracket(opening: int) -> int | None:
        # Walk forward from the opening bracket until the nesting returns to zero.
        nesting = 0
        inside_string = False
        escaped = False
        for position in range(opening, len(text)):
            ch = text[position]
            if inside_string:
                if escaped:
                    # Character after a backslash is consumed unconditionally.
                    escaped = False
                elif ch == "\\":
                    escaped = True
                elif ch == '"':
                    inside_string = False
            elif ch == '"':
                inside_string = True
            elif ch == "[":
                nesting += 1
            elif ch == "]":
                nesting -= 1
                if nesting == 0:
                    return position
        return None

    opening = text.find("[")
    while opening >= 0:
        closing = closing_bracket(opening)
        if closing is not None:
            snippet = text[opening : closing + 1]
            try:
                value = json.loads(snippet)
            except ValueError:
                # json.JSONDecodeError is a ValueError subclass; not JSON, keep looking.
                value = None
            if isinstance(value, list):
                return snippet
        opening = text.find("[", opening + 1)
    return None


@dataclass(frozen=True)
class LLMStepTrace:
    """Immutable record of one environment step taken during a rollout.

    Instances are created by ``run_episode_with_actions`` right after each
    ``environment.step`` call; apart from ``step``, ``action_label`` and
    ``reward``, every field is copied from the observation that call returned.
    """

    # 1-based position of the action within the rollout.
    step: int
    # Text form of the action, as produced by ``action_label``.
    action_label: str
    # Observation reward as a float; 0.0 when the observation carried ``None``.
    reward: float
    p1_score: float
    p1_feasibility: float
    constraints_satisfied: bool
    evaluation_fidelity: str
    evaluation_failed: bool
    budget_remaining: int
    # ``model_dump()`` output of the observation's reward breakdown.
    reward_breakdown: dict[str, object]
    # ``model_dump()`` output of the observation's action monitor.
    action_monitor: dict[str, object]
    episode_total_reward: float
    trajectory_summary: str
    diagnostics_text: str


@dataclass(frozen=True)
class LLMEpisodeTrace:
    """Immutable summary of one rollout built by ``run_episode_with_actions``.

    Fields other than ``seed``, ``total_reward`` and ``steps`` are read from the
    most recent observation of the rollout (the reset observation when no step
    was executed).
    """

    # Seed that was passed to ``environment.reset``.
    seed: int
    # Sum of the per-step rewards, rounded to 4 decimal places.
    total_reward: float
    final_score: float
    final_feasibility: float
    constraints_satisfied: bool
    evaluation_failed: bool
    final_evaluation_fidelity: str
    failure_reason: str
    # ``model_dump()`` output of the last observation's reward breakdown.
    final_reward_breakdown: dict[str, object]
    trajectory_summary: str
    # Per-step records in execution order.
    steps: list[LLMStepTrace]

    def asdict(self) -> dict[str, object]:
        """Return this trace, including nested step traces, as a plain dict."""
        # ``asdict`` here resolves to the module-level ``dataclasses.asdict``
        # import, not to this method: names defined in a class body are not in
        # scope inside method bodies, so there is no recursion.
        return asdict(self)


def action_label(action: StellaratorAction) -> str:
    """Return a short human-readable label for *action*.

    A ``"run"`` action is rendered as its intent, parameter, direction and
    magnitude separated by single spaces; any other action is labelled by its
    intent alone.
    """
    if action.intent == "run":
        return f"{action.intent} {action.parameter} {action.direction} {action.magnitude}"
    return action.intent


def format_observation(observation: StellaratorObservation) -> str:
    """Render *observation* as the bulleted text block used for the user message.

    The result starts with a header line followed by one ``- name: value`` line
    per reported field and ends with a trailing newline.  Scores and the
    elongation, aspect-ratio and iota values use four decimals; feasibility,
    triangularity and violation values use six.
    """
    obs = observation
    report_lines = [
        "Current stellarator state:",
        f"- max_elongation: {obs.max_elongation:.4f}",
        f"- aspect_ratio: {obs.aspect_ratio:.4f} (must stay <= 4.0)",
        f"- average_triangularity: {obs.average_triangularity:.6f} (must stay <= -0.5)",
        f"- edge_iota_over_nfp: {obs.edge_iota_over_nfp:.4f} (must satisfy abs(.) >= 0.3)",
        f"- aspect_ratio_violation: {obs.aspect_ratio_violation:.6f}",
        f"- triangularity_violation: {obs.triangularity_violation:.6f}",
        f"- iota_violation: {obs.iota_violation:.6f}",
        f"- dominant_constraint: {obs.dominant_constraint}",
        f"- p1_score: {obs.p1_score:.4f}",
        f"- p1_feasibility: {obs.p1_feasibility:.6f}",
        f"- constraints_satisfied: {obs.constraints_satisfied}",
        f"- evaluation_fidelity: {obs.evaluation_fidelity}",
        f"- evaluation_failed: {obs.evaluation_failed}",
        f"- budget_remaining: {obs.budget_remaining}",
        f"- no_progress_steps: {obs.no_progress_steps}",
        f"- best_low_fidelity_score: {obs.best_low_fidelity_score:.4f}",
        f"- best_low_fidelity_feasibility: {obs.best_low_fidelity_feasibility:.6f}",
        f"- diagnostics: {obs.diagnostics_text}",
    ]
    # Every line, including the last, is newline-terminated.
    return "\n".join(report_lines) + "\n"


def build_messages(observation: StellaratorObservation) -> tuple[PromptMessage, PromptMessage]:
    """Return the ``(system, user)`` message pair for *observation*.

    The system message carries ``SYSTEM_PROMPT``; the user message carries the
    text produced by ``format_observation``.
    """
    system_message: PromptMessage = {"role": "system", "content": SYSTEM_PROMPT}
    user_message: PromptMessage = {
        "role": "user",
        "content": format_observation(observation),
    }
    return system_message, user_message


def build_prompt(observation: StellaratorObservation) -> str:
    """Flatten the chat messages for *observation* into one plain-text prompt.

    The layout is a ``System:`` section, a blank line, a ``User:`` section, a
    blank line, and a final ``Assistant:`` line terminated by a newline.
    """
    system_message, user_message = build_messages(observation)
    sections = (
        "System:",
        system_message["content"],
        "",
        "User:",
        user_message["content"],
        "",
        "Assistant:",
        "",  # yields the trailing newline after "Assistant:"
    )
    return "\n".join(sections)


def extract_json_plan(text: str) -> str | None:
    """Return the raw JSON-array substring of *text*, or ``None`` if there is none.

    Public wrapper around ``_extract_json_array``; the returned string is the
    matched span itself, not the decoded list.
    """
    return _extract_json_array(text)


def _parse_action_item(item: object) -> StellaratorAction | None:
    """Convert one decoded plan entry into a ``StellaratorAction``.

    Returns ``None`` when *item* is not a dict, when its ``"intent"`` is not one
    of ``"submit"``, ``"restore_best"`` or ``"run"``, or when a ``"run"`` entry
    names a parameter, direction or magnitude outside the ``RUN_*`` tuples.
    A missing ``"magnitude"`` on a ``"run"`` entry is treated as ``"small"``.
    """
    if not isinstance(item, dict):
        return None

    requested = item.get("intent")
    if requested == "submit":
        return StellaratorAction(intent="submit")
    if requested == "restore_best":
        return StellaratorAction(intent="restore_best")
    if requested != "run":
        return None

    # From here on the entry is a "run" request; validate each field.
    knob = item.get("parameter")
    sense = item.get("direction")
    size = item.get("magnitude", "small")
    fields_ok = knob in RUN_PARAMETERS and sense in RUN_DIRECTIONS and size in RUN_MAGNITUDES
    if not fields_ok:
        return None

    return StellaratorAction(intent="run", parameter=knob, direction=sense, magnitude=size)


def parse_action_plan(text: str, *, allow_submit: bool = True) -> list[StellaratorAction]:
    """Parse model output *text* into a list of valid actions.

    The first JSON array found in *text* is decoded and each entry is passed
    through ``_parse_action_item``; entries that fail validation are dropped.
    With ``allow_submit`` true, the plan is cut off right after the first
    ``"submit"`` action; with it false, ``"submit"`` actions are discarded and
    parsing continues.  Returns an empty list when no array can be extracted or
    decoded.
    """
    plan_text = extract_json_plan(text)
    if plan_text is None:
        return []
    try:
        payload = json.loads(plan_text)
    except json.JSONDecodeError:
        return []
    if not isinstance(payload, list):
        return []

    plan: list[StellaratorAction] = []
    for entry in payload:
        candidate = _parse_action_item(entry)
        if candidate is None:
            continue
        if candidate.intent == "submit":
            if not allow_submit:
                continue
            # Submit ends the plan: keep it and ignore anything after it.
            plan.append(candidate)
            break
        plan.append(candidate)
    return plan


def run_episode_with_actions(
    actions: Sequence[StellaratorAction],
    *,
    seed_idx: int,
    auto_submit: bool = False,
    allow_submit: bool = True,
) -> LLMEpisodeTrace:
    """Replay *actions* in a fresh environment and return the episode trace.

    Args:
        actions: Planned actions, executed in order.
        seed_idx: Seed passed to ``environment.reset``.
        auto_submit: When true and the rollout has not reported ``done``, one
            extra ``"submit"`` action is executed after the plan.
        allow_submit: When false, ``"submit"`` actions are removed from the plan
            before execution.

    Returns:
        An ``LLMEpisodeTrace`` holding the per-step records, the rounded reward
        total, and the final values read from the last observation.

    At most ``BUDGET`` planned actions are executed.  If the plan is longer and
    its first ``"submit"`` sits at or beyond position ``BUDGET``, that submit
    replaces the last slot so the plan still ends with it.
    """
    env = StellaratorEnvironment()
    latest = env.reset(seed=seed_idx)
    recorded: list[LLMStepTrace] = []
    reward_sum = 0.0

    def execute(action: StellaratorAction, number: int) -> bool:
        # Apply one action, log its trace, and report whether the episode ended.
        nonlocal latest, reward_sum
        latest = env.step(action)
        step_reward = 0.0 if latest.reward is None else float(latest.reward)
        reward_sum += step_reward
        trace = LLMStepTrace(
            step=number,
            action_label=action_label(action),
            reward=step_reward,
            p1_score=latest.p1_score,
            p1_feasibility=latest.p1_feasibility,
            constraints_satisfied=latest.constraints_satisfied,
            evaluation_fidelity=latest.evaluation_fidelity,
            evaluation_failed=latest.evaluation_failed,
            budget_remaining=latest.budget_remaining,
            reward_breakdown=latest.reward_breakdown.model_dump(),
            action_monitor=latest.action_monitor.model_dump(),
            episode_total_reward=latest.episode_total_reward,
            trajectory_summary=latest.trajectory_summary,
            diagnostics_text=latest.diagnostics_text,
        )
        recorded.append(trace)
        return bool(latest.done)

    if allow_submit:
        plan = list(actions)
    else:
        plan = [planned for planned in actions if planned.intent != "submit"]

    if len(plan) > BUDGET:
        first_submit = next(
            (pos for pos, planned in enumerate(plan) if planned.intent == "submit"),
            None,
        )
        if first_submit is None or first_submit < BUDGET:
            plan = plan[:BUDGET]
        else:
            # Keep terminal submit within the budget if the model over-runs plan length.
            plan = [*plan[: BUDGET - 1], plan[first_submit]]

    finished = False
    last_step = 0
    for last_step, planned in enumerate(plan[:BUDGET], start=1):
        finished = execute(planned, last_step)
        if finished:
            break

    if auto_submit and not finished:
        execute(StellaratorAction(intent="submit"), last_step + 1)

    return LLMEpisodeTrace(
        seed=seed_idx,
        total_reward=round(reward_sum, 4),
        final_score=latest.p1_score,
        final_feasibility=latest.p1_feasibility,
        constraints_satisfied=latest.constraints_satisfied,
        evaluation_failed=latest.evaluation_failed,
        final_evaluation_fidelity=latest.evaluation_fidelity,
        failure_reason=latest.failure_reason,
        final_reward_breakdown=latest.reward_breakdown.model_dump(),
        trajectory_summary=latest.trajectory_summary,
        steps=recorded,
    )