File size: 10,788 Bytes
ebd0ff3 cdc237b ebd0ff3 cdc237b ebd0ff3 cdc237b ebd0ff3 cdc237b ebd0ff3 2fccde8 ebd0ff3 e826e11 ebd0ff3 5e0e606 ebd0ff3 5e0e606 ebd0ff3 2fccde8 ebd0ff3 cdc237b ebd0ff3 cdc237b ebd0ff3 cdc237b ebd0ff3 cdc237b ebd0ff3 e826e11 ebd0ff3 cdc237b ebd0ff3 9c3599b ebd0ff3 9c3599b ebd0ff3 9c3599b cdc237b ebd0ff3 5f2da5f ebd0ff3 5f2da5f ebd0ff3 5e0e606 ebd0ff3 5f2da5f 9c3599b cdc237b 9c3599b 5f2da5f ebd0ff3 9c3599b 5f2da5f ebd0ff3 5e0e606 ebd0ff3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 | from __future__ import annotations
import json
from dataclasses import asdict, dataclass
from typing import Final, Literal, Sequence, TypedDict
from fusion_lab.models import (
DirectionName,
MagnitudeName,
ParameterName,
StellaratorAction,
StellaratorObservation,
)
from server.environment import BUDGET, StellaratorEnvironment
# Closed vocabularies for the fields of a {"intent": "run", ...} plan item.
# _parse_action_item rejects any run item whose values fall outside these.
RUN_PARAMETERS: Final[tuple[ParameterName, ...]] = (
    "aspect_ratio", "elongation", "rotational_transform", "triangularity_scale",
)
RUN_DIRECTIONS: Final[tuple[DirectionName, ...]] = ("increase",
                                                     "decrease")
RUN_MAGNITUDES: Final[tuple[MagnitudeName, ...]] = ("small",
                                                     "medium",
                                                     "large")
class PromptMessage(TypedDict):
    """One chat message as produced by ``build_messages``."""

    role: Literal["system", "user"]  # who is speaking
    content: str  # the message text
# System prompt for the planning LLM. It restates the JSON action shapes that
# _parse_action_item accepts ("run" with parameter/direction/magnitude,
# "restore_best", "submit") and the same P1 constraint thresholds that
# format_observation prints (4.0, -0.5, 0.3). Keep the three in sync.
SYSTEM_PROMPT: Final[str] = """You are an expert stellarator designer.
Goal:
- satisfy the P1 physics constraints
- then improve the design score by lowering max elongation
You control a 4-knob low-dimensional design:
- aspect_ratio
- elongation
- rotational_transform
- triangularity_scale
Action rules:
- output a JSON array
- each item must be either:
- {"intent":"run","parameter":"<parameter>","direction":"increase|decrease","magnitude":"small|medium|large"}
- {"intent":"restore_best"}
- {"intent":"submit"}
- keep the plan short and within the remaining budget
- use "submit" once when you want to stop and lock in the current design
Constraint directions:
- aspect_ratio <= 4.0
- average_triangularity <= -0.5
- abs(edge_iota_over_nfp) >= 0.3"""
def _extract_json_array(text: str) -> str | None:
"""Return the first balanced ``[...]`` substring that parses as a JSON array.
Iterates through every ``[`` in *text*, finds its balanced closing ``]``
(respecting nested brackets and JSON string literals), and attempts
``json.loads``. Returns the first candidate that successfully decodes as a
JSON list, skipping prose fragments like ``[draft]``.
"""
start = text.find("[")
while start != -1:
depth = 0
in_string = False
escape = False
matched_end: int | None = None
for index in range(start, len(text)):
char = text[index]
if in_string:
if escape:
escape = False
elif char == "\\":
escape = True
elif char == '"':
in_string = False
continue
if char == '"':
in_string = True
elif char == "[":
depth += 1
elif char == "]":
depth -= 1
if depth == 0:
matched_end = index
break
if matched_end is not None:
candidate = text[start : matched_end + 1]
try:
decoded = json.loads(candidate)
if isinstance(decoded, list):
return candidate
except (json.JSONDecodeError, ValueError):
pass
start = text.find("[", start + 1)
return None
@dataclass(frozen=True)
class LLMStepTrace:
    """Immutable record of one environment step taken while replaying a plan.

    Instances are built in ``run_episode_with_actions`` from the action that was
    executed and the observation returned by ``environment.step``.
    """

    step: int  # 1-based position of this step within the episode
    action_label: str  # human-readable label produced by action_label()
    reward: float  # observation.reward as float, 0.0 when the observation had none
    p1_score: float
    p1_feasibility: float
    constraints_satisfied: bool
    evaluation_fidelity: str
    evaluation_failed: bool
    budget_remaining: int
    reward_breakdown: dict[str, object]  # observation.reward_breakdown.model_dump()
    action_monitor: dict[str, object]  # observation.action_monitor.model_dump()
    episode_total_reward: float  # episode total as reported by the observation
    trajectory_summary: str
    diagnostics_text: str
@dataclass(frozen=True)
class LLMEpisodeTrace:
    """Immutable summary of one replayed episode plus its per-step traces.

    Built by ``run_episode_with_actions`` from the last observation seen.
    """

    seed: int  # seed the environment was reset with
    total_reward: float  # sum of per-step rewards, rounded to 4 decimals
    final_score: float  # p1_score of the final observation
    final_feasibility: float  # p1_feasibility of the final observation
    constraints_satisfied: bool
    evaluation_failed: bool
    final_evaluation_fidelity: str
    failure_reason: str
    final_reward_breakdown: dict[str, object]
    trajectory_summary: str
    steps: list[LLMStepTrace]  # one entry per executed step, in order

    def asdict(self) -> dict[str, object]:
        """Return this trace, including nested step traces, as plain dicts."""
        # Bind the dataclasses helper under a distinct name: this method is
        # itself called ``asdict``, and an explicit alias avoids the call
        # below reading like (or being refactored into) self-recursion.
        from dataclasses import asdict as _dataclass_asdict

        return _dataclass_asdict(self)
def action_label(action: StellaratorAction) -> str:
    """Render *action* as a short human-readable label.

    Non-``run`` intents are labelled by the intent alone; ``run`` actions are
    rendered as ``"<intent> <parameter> <direction> <magnitude>"``.
    """
    if action.intent == "run":
        parts = (action.intent, action.parameter, action.direction, action.magnitude)
        return " ".join(f"{part}" for part in parts)
    return action.intent
def format_observation(observation: StellaratorObservation) -> str:
    """Render *observation* as the multi-line user message shown to the LLM.

    One ``- name: value`` line per observation field, preceded by a header
    line; every line, including the last, ends with a newline. Constraint
    fields carry an inline reminder of the threshold they must satisfy.
    """
    obs = observation
    lines = [
        "Current stellarator state:",
        f"- max_elongation: {obs.max_elongation:.4f}",
        f"- aspect_ratio: {obs.aspect_ratio:.4f} (must stay <= 4.0)",
        f"- average_triangularity: {obs.average_triangularity:.6f} (must stay <= -0.5)",
        f"- edge_iota_over_nfp: {obs.edge_iota_over_nfp:.4f} (must satisfy abs(.) >= 0.3)",
        f"- aspect_ratio_violation: {obs.aspect_ratio_violation:.6f}",
        f"- triangularity_violation: {obs.triangularity_violation:.6f}",
        f"- iota_violation: {obs.iota_violation:.6f}",
        f"- dominant_constraint: {obs.dominant_constraint}",
        f"- p1_score: {obs.p1_score:.4f}",
        f"- p1_feasibility: {obs.p1_feasibility:.6f}",
        f"- constraints_satisfied: {obs.constraints_satisfied}",
        f"- evaluation_fidelity: {obs.evaluation_fidelity}",
        f"- evaluation_failed: {obs.evaluation_failed}",
        f"- budget_remaining: {obs.budget_remaining}",
        f"- no_progress_steps: {obs.no_progress_steps}",
        f"- best_low_fidelity_score: {obs.best_low_fidelity_score:.4f}",
        f"- best_low_fidelity_feasibility: {obs.best_low_fidelity_feasibility:.6f}",
        f"- diagnostics: {obs.diagnostics_text}",
    ]
    return "\n".join(lines) + "\n"
def build_messages(observation: StellaratorObservation) -> tuple[PromptMessage, PromptMessage]:
    """Return the ``(system, user)`` chat messages for *observation*.

    The system message carries ``SYSTEM_PROMPT``; the user message carries
    the ``format_observation`` rendering of the current state.
    """
    system_message: PromptMessage = {"role": "system", "content": SYSTEM_PROMPT}
    user_message: PromptMessage = {"role": "user", "content": format_observation(observation)}
    return system_message, user_message
def build_prompt(observation: StellaratorObservation) -> str:
    """Flatten the chat messages for *observation* into one completion prompt.

    Sections are ``System:``, ``User:`` and a trailing open ``Assistant:``
    turn, separated by blank lines.
    """
    system_message, user_message = build_messages(observation)
    sections = (
        f"System:\n{system_message['content']}",
        f"User:\n{user_message['content']}",
        "Assistant:\n",
    )
    return "\n\n".join(sections)
def extract_json_plan(text: str) -> str | None:
    """Return the raw JSON-array text of the action plan in *text*, or ``None``.

    Public entry point; the search itself is done by ``_extract_json_array``.
    """
    plan_text = _extract_json_array(text)
    return plan_text
def _parse_action_item(item: object) -> StellaratorAction | None:
    """Convert one decoded JSON plan entry into a ``StellaratorAction``.

    Returns ``None`` when *item* is not a dict, its ``intent`` is unknown, or
    it is a ``run`` item whose ``parameter``/``direction``/``magnitude`` is
    not in ``RUN_PARAMETERS``/``RUN_DIRECTIONS``/``RUN_MAGNITUDES``.
    A missing ``magnitude`` defaults to ``"small"``.
    """
    if not isinstance(item, dict):
        return None
    intent = item.get("intent")
    # Intents that take no arguments map straight to an action.
    for terminal in ("submit", "restore_best"):
        if intent == terminal:
            return StellaratorAction(intent=terminal)
    if intent != "run":
        return None
    parameter, direction, magnitude = (
        item.get("parameter"),
        item.get("direction"),
        item.get("magnitude", "small"),
    )
    is_valid = (
        parameter in RUN_PARAMETERS
        and direction in RUN_DIRECTIONS
        and magnitude in RUN_MAGNITUDES
    )
    if not is_valid:
        return None
    return StellaratorAction(
        intent="run",
        parameter=parameter,
        direction=direction,
        magnitude=magnitude,
    )
def parse_action_plan(text: str, *, allow_submit: bool = True) -> list[StellaratorAction]:
    """Parse the first JSON action array found in *text* into actions.

    Invalid entries are silently dropped. When *allow_submit* is false,
    ``submit`` entries are dropped too; otherwise the plan is cut right after
    the first ``submit``. Returns an empty list if no JSON array is found.
    """
    raw_plan = extract_json_plan(text)
    if raw_plan is None:
        return []
    try:
        entries = json.loads(raw_plan)
    except json.JSONDecodeError:
        return []
    if not isinstance(entries, list):
        return []

    plan: list[StellaratorAction] = []
    for entry in entries:
        parsed = _parse_action_item(entry)
        if parsed is None:
            continue
        is_submit = parsed.intent == "submit"
        if is_submit and not allow_submit:
            continue
        plan.append(parsed)
        # Only reachable for a submit when submits are allowed: stop here.
        if is_submit:
            break
    return plan
def _fit_plan_to_budget(
    actions: Sequence[StellaratorAction], *, allow_submit: bool
) -> list[StellaratorAction]:
    """Drop disallowed submits from *actions* and trim the plan to ``BUDGET`` steps.

    If the plan is longer than ``BUDGET`` and its first ``submit`` lies beyond
    the budget, the plan becomes the first ``BUDGET - 1`` actions followed by
    that ``submit``; otherwise it is simply truncated to ``BUDGET`` actions.
    """
    plan = [action for action in actions if allow_submit or action.intent != "submit"]
    if len(plan) <= BUDGET:
        return plan
    first_submit = next(
        (position for position, action in enumerate(plan) if action.intent == "submit"),
        None,
    )
    if first_submit is None or first_submit < BUDGET:
        return plan[:BUDGET]
    # Keep terminal submit within the budget if the model over-runs plan length.
    return (plan[: BUDGET - 1] + [plan[first_submit]])[:BUDGET]


def run_episode_with_actions(
    actions: Sequence[StellaratorAction],
    *,
    seed_idx: int,
    auto_submit: bool = False,
    allow_submit: bool = True,
) -> LLMEpisodeTrace:
    """Replay an action plan in a fresh environment and trace every step.

    Args:
        actions: Planned actions, executed in order after budget fitting.
        seed_idx: Seed passed to ``environment.reset``.
        auto_submit: When true and the environment has not reported ``done``
            after the plan runs out, one extra ``submit`` step is taken.
        allow_submit: When false, ``submit`` actions are removed from the plan.

    Returns:
        An ``LLMEpisodeTrace`` built from the final observation, with one
        ``LLMStepTrace`` per executed step.
    """
    environment = StellaratorEnvironment()
    observation = environment.reset(seed=seed_idx)
    traces: list[LLMStepTrace] = []
    reward_sum = 0.0

    def _execute(action: StellaratorAction, step_number: int) -> bool:
        # Step the environment, append a trace entry, and report whether the
        # environment flagged the episode as done.
        nonlocal observation, reward_sum
        observation = environment.step(action)
        step_reward = 0.0 if observation.reward is None else float(observation.reward)
        reward_sum += step_reward
        traces.append(
            LLMStepTrace(
                step=step_number,
                action_label=action_label(action),
                reward=step_reward,
                p1_score=observation.p1_score,
                p1_feasibility=observation.p1_feasibility,
                constraints_satisfied=observation.constraints_satisfied,
                evaluation_fidelity=observation.evaluation_fidelity,
                evaluation_failed=observation.evaluation_failed,
                budget_remaining=observation.budget_remaining,
                reward_breakdown=observation.reward_breakdown.model_dump(),
                action_monitor=observation.action_monitor.model_dump(),
                episode_total_reward=observation.episode_total_reward,
                trajectory_summary=observation.trajectory_summary,
                diagnostics_text=observation.diagnostics_text,
            )
        )
        return bool(observation.done)

    plan = _fit_plan_to_budget(actions, allow_submit=allow_submit)
    last_step = 0
    episode_done = False
    for last_step, action in enumerate(plan, start=1):
        episode_done = _execute(action, last_step)
        if episode_done:
            break

    if auto_submit and not episode_done:
        _execute(StellaratorAction(intent="submit"), last_step + 1)

    return LLMEpisodeTrace(
        seed=seed_idx,
        total_reward=round(reward_sum, 4),
        final_score=observation.p1_score,
        final_feasibility=observation.p1_feasibility,
        constraints_satisfied=observation.constraints_satisfied,
        evaluation_failed=observation.evaluation_failed,
        final_evaluation_fidelity=observation.evaluation_fidelity,
        failure_reason=observation.failure_reason,
        final_reward_breakdown=observation.reward_breakdown.model_dump(),
        trajectory_summary=observation.trajectory_summary,
        steps=traces,
    )
|