from __future__ import annotations import json from dataclasses import asdict, dataclass from typing import Final, Literal, Sequence, TypedDict from fusion_lab.models import ( DirectionName, MagnitudeName, ParameterName, StellaratorAction, StellaratorObservation, ) from server.environment import BUDGET, StellaratorEnvironment RUN_PARAMETERS: Final[tuple[ParameterName, ...]] = ( "aspect_ratio", "elongation", "rotational_transform", "triangularity_scale", ) RUN_DIRECTIONS: Final[tuple[DirectionName, ...]] = ("increase", "decrease") RUN_MAGNITUDES: Final[tuple[MagnitudeName, ...]] = ("small", "medium", "large") class PromptMessage(TypedDict): role: Literal["system", "user"] content: str SYSTEM_PROMPT: Final[str] = """You are an expert stellarator designer. Goal: - satisfy the P1 physics constraints - then improve the design score by lowering max elongation You control a 4-knob low-dimensional design: - aspect_ratio - elongation - rotational_transform - triangularity_scale Action rules: - output a JSON array - each item must be either: - {"intent":"run","parameter":"","direction":"increase|decrease","magnitude":"small|medium|large"} - {"intent":"restore_best"} - {"intent":"submit"} - keep the plan short and within the remaining budget - use "submit" once when you want to stop and lock in the current design Constraint directions: - aspect_ratio <= 4.0 - average_triangularity <= -0.5 - abs(edge_iota_over_nfp) >= 0.3""" def _extract_json_array(text: str) -> str | None: """Return the first balanced ``[...]`` substring that parses as a JSON array. Iterates through every ``[`` in *text*, finds its balanced closing ``]`` (respecting nested brackets and JSON string literals), and attempts ``json.loads``. Returns the first candidate that successfully decodes as a JSON list, skipping prose fragments like ``[draft]``. """ start = text.find("[") while start != -1: depth = 0 in_string = False escape = False matched_end: int | None = None for index in range(start, len(text)): char = text[index] if in_string: if escape: escape = False elif char == "\\": escape = True elif char == '"': in_string = False continue if char == '"': in_string = True elif char == "[": depth += 1 elif char == "]": depth -= 1 if depth == 0: matched_end = index break if matched_end is not None: candidate = text[start : matched_end + 1] try: decoded = json.loads(candidate) if isinstance(decoded, list): return candidate except (json.JSONDecodeError, ValueError): pass start = text.find("[", start + 1) return None @dataclass(frozen=True) class LLMStepTrace: step: int action_label: str reward: float p1_score: float p1_feasibility: float constraints_satisfied: bool evaluation_fidelity: str evaluation_failed: bool budget_remaining: int reward_breakdown: dict[str, object] action_monitor: dict[str, object] episode_total_reward: float trajectory_summary: str diagnostics_text: str @dataclass(frozen=True) class LLMEpisodeTrace: seed: int total_reward: float final_score: float final_feasibility: float constraints_satisfied: bool evaluation_failed: bool final_evaluation_fidelity: str failure_reason: str final_reward_breakdown: dict[str, object] trajectory_summary: str steps: list[LLMStepTrace] def asdict(self) -> dict[str, object]: return asdict(self) def action_label(action: StellaratorAction) -> str: if action.intent != "run": return action.intent return f"{action.intent} {action.parameter} {action.direction} {action.magnitude}" def format_observation(observation: StellaratorObservation) -> str: return ( "Current stellarator state:\n" f"- max_elongation: {observation.max_elongation:.4f}\n" f"- aspect_ratio: {observation.aspect_ratio:.4f} (must stay <= 4.0)\n" f"- average_triangularity: {observation.average_triangularity:.6f} " "(must stay <= -0.5)\n" f"- edge_iota_over_nfp: {observation.edge_iota_over_nfp:.4f} " "(must satisfy abs(.) >= 0.3)\n" f"- aspect_ratio_violation: {observation.aspect_ratio_violation:.6f}\n" f"- triangularity_violation: {observation.triangularity_violation:.6f}\n" f"- iota_violation: {observation.iota_violation:.6f}\n" f"- dominant_constraint: {observation.dominant_constraint}\n" f"- p1_score: {observation.p1_score:.4f}\n" f"- p1_feasibility: {observation.p1_feasibility:.6f}\n" f"- constraints_satisfied: {observation.constraints_satisfied}\n" f"- evaluation_fidelity: {observation.evaluation_fidelity}\n" f"- evaluation_failed: {observation.evaluation_failed}\n" f"- budget_remaining: {observation.budget_remaining}\n" f"- no_progress_steps: {observation.no_progress_steps}\n" f"- best_low_fidelity_score: {observation.best_low_fidelity_score:.4f}\n" f"- best_low_fidelity_feasibility: {observation.best_low_fidelity_feasibility:.6f}\n" f"- diagnostics: {observation.diagnostics_text}\n" ) def build_messages(observation: StellaratorObservation) -> tuple[PromptMessage, PromptMessage]: return ( {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": format_observation(observation)}, ) def build_prompt(observation: StellaratorObservation) -> str: system_message, user_message = build_messages(observation) return ( f"System:\n{system_message['content']}\n\nUser:\n{user_message['content']}\n\nAssistant:\n" ) def extract_json_plan(text: str) -> str | None: return _extract_json_array(text) def _parse_action_item(item: object) -> StellaratorAction | None: if not isinstance(item, dict): return None intent = item.get("intent") if intent == "submit": return StellaratorAction(intent="submit") if intent == "restore_best": return StellaratorAction(intent="restore_best") if intent != "run": return None parameter = item.get("parameter") direction = item.get("direction") magnitude = item.get("magnitude", "small") if parameter not in RUN_PARAMETERS: return None if direction not in RUN_DIRECTIONS: return None if magnitude not in RUN_MAGNITUDES: return None return StellaratorAction( intent="run", parameter=parameter, direction=direction, magnitude=magnitude, ) def parse_action_plan(text: str, *, allow_submit: bool = True) -> list[StellaratorAction]: raw_plan = extract_json_plan(text) if raw_plan is None: return [] try: decoded = json.loads(raw_plan) except json.JSONDecodeError: return [] if not isinstance(decoded, list): return [] parsed: list[StellaratorAction] = [] for item in decoded: action = _parse_action_item(item) if action is None: continue if action.intent == "submit" and not allow_submit: continue parsed.append(action) if action.intent == "submit" and allow_submit: break return parsed def run_episode_with_actions( actions: Sequence[StellaratorAction], *, seed_idx: int, auto_submit: bool = False, allow_submit: bool = True, ) -> LLMEpisodeTrace: environment = StellaratorEnvironment() observation = environment.reset(seed=seed_idx) step_traces: list[LLMStepTrace] = [] total_reward = 0.0 def _step_and_record(action: StellaratorAction, step_index: int) -> bool: nonlocal observation, total_reward observation = environment.step(action) reward = float(observation.reward) if observation.reward is not None else 0.0 total_reward += reward step_traces.append( LLMStepTrace( step=step_index, action_label=action_label(action), reward=reward, p1_score=observation.p1_score, p1_feasibility=observation.p1_feasibility, constraints_satisfied=observation.constraints_satisfied, evaluation_fidelity=observation.evaluation_fidelity, evaluation_failed=observation.evaluation_failed, budget_remaining=observation.budget_remaining, reward_breakdown=observation.reward_breakdown.model_dump(), action_monitor=observation.action_monitor.model_dump(), episode_total_reward=observation.episode_total_reward, trajectory_summary=observation.trajectory_summary, diagnostics_text=observation.diagnostics_text, ) ) return bool(observation.done) done = False step_index = 0 rollout_actions = [action for action in actions if allow_submit or action.intent != "submit"] if len(rollout_actions) > BUDGET: submit_index = next( (idx for idx, action in enumerate(rollout_actions) if action.intent == "submit"), None, ) if submit_index is not None and submit_index >= BUDGET: # Keep terminal submit within the budget if the model over-runs plan length. rollout_actions = rollout_actions[: BUDGET - 1] + [rollout_actions[submit_index]] else: rollout_actions = rollout_actions[:BUDGET] for step_index, action in enumerate(rollout_actions[:BUDGET], start=1): if _step_and_record(action, step_index): done = True break if auto_submit and not done: _step_and_record(StellaratorAction(intent="submit"), step_index + 1) return LLMEpisodeTrace( seed=seed_idx, total_reward=round(total_reward, 4), final_score=observation.p1_score, final_feasibility=observation.p1_feasibility, constraints_satisfied=observation.constraints_satisfied, evaluation_failed=observation.evaluation_failed, final_evaluation_fidelity=observation.evaluation_fidelity, failure_reason=observation.failure_reason, final_reward_breakdown=observation.reward_breakdown.model_dump(), trajectory_summary=observation.trajectory_summary, steps=step_traces, )