CreativeEngineer committed on
Commit
2f5db5e
·
1 Parent(s): 5354ca9

feat: add local environment scaffold and baselines

Browse files
baselines/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Random and heuristic baselines for the stellarator design environment."""
baselines/compare.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Run both baselines and print a comparison summary."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import sys
6
+
7
+ from baselines.heuristic_agent import heuristic_episode
8
+ from baselines.random_agent import random_episode
9
+ from server.environment import StellaratorEnvironment
10
+
11
+
12
def main(n_episodes: int = 20) -> None:
    """Run the random and heuristic baselines and print a comparison table.

    Episode index ``i`` is used as the seed for both agents, so each pair of
    episodes starts from the same seeded environment reset.

    Args:
        n_episodes: Number of episodes to run per agent. Must be at least 1.

    Raises:
        ValueError: If ``n_episodes`` is less than 1; the means and the win
            percentage below would otherwise divide by zero.
    """
    if n_episodes < 1:
        raise ValueError(f"n_episodes must be at least 1, got {n_episodes}")

    env = StellaratorEnvironment()

    random_rewards: list[float] = []
    heuristic_rewards: list[float] = []
    random_best_qs: list[float] = []
    heuristic_best_qs: list[float] = []

    for i in range(n_episodes):
        rr, rt = random_episode(env, seed=i)
        random_rewards.append(rr)
        # The last trace entry carries the best residual seen in the episode.
        random_best_qs.append(rt[-1]["best_qs"])

        hr, ht = heuristic_episode(env, seed=i)
        heuristic_rewards.append(hr)
        heuristic_best_qs.append(ht[-1]["best_qs"])

    r_mean = sum(random_rewards) / len(random_rewards)
    h_mean = sum(heuristic_rewards) / len(heuristic_rewards)
    r_qs = sum(random_best_qs) / len(random_best_qs)
    h_qs = sum(heuristic_best_qs) / len(heuristic_best_qs)

    print(f"{'Metric':<25} {'Random':>12} {'Heuristic':>12}")
    print("-" * 51)
    print(f"{'Mean reward':<25} {r_mean:>+12.4f} {h_mean:>+12.4f}")
    print(f"{'Mean best QS residual':<25} {r_qs:>12.6f} {h_qs:>12.6f}")
    print(f"{'Episodes':<25} {n_episodes:>12d} {n_episodes:>12d}")
    print()

    # A "win" is a strictly higher total reward on the same seed.
    wins = sum(1 for h, r in zip(heuristic_rewards, random_rewards) if h > r)
    print(f"Heuristic wins: {wins}/{n_episodes} episodes ({100 * wins / n_episodes:.0f}%)")


if __name__ == "__main__":
    n = int(sys.argv[1]) if len(sys.argv) > 1 else 20
    main(n)
baselines/heuristic_agent.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Heuristic baseline agent for the stellarator design environment.
2
+
3
+ Strategy: guided perturbations informed by domain knowledge.
4
+ 1. Probe the most sensitive coefficient (zs12) first with a small move.
5
+ 2. Apply medium perturbations in directions that typically improve QS.
6
+ 3. Use restore_best to recover from any worsening.
7
+ 4. Submit before exhausting budget.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import sys
13
+
14
+ from fusion_lab.models import StellaratorAction
15
+ from server.environment import StellaratorEnvironment
16
+
17
# Ordered (operator, direction, magnitude, restart) moves tried by the heuristic.
STRATEGY: list[tuple[str, str, str, str]] = [
    ("tune_zs12", "decrease", "small", "hot"),
    ("tune_zs12", "decrease", "medium", "hot"),
    ("tune_rc11", "increase", "small", "hot"),
    ("tune_rc10", "increase", "medium", "hot"),
    ("tune_zs11", "decrease", "small", "hot"),
]


def heuristic_episode(
    env: StellaratorEnvironment, seed: int | None = None
) -> tuple[float, list[dict[str, object]]]:
    """Play one episode with the fixed ``STRATEGY`` and return its outcome.

    After each run the agent issues ``restore_best`` when the run converged
    but left the current residual above the best residual so far. The
    previous check compared the best residual against its own earlier value;
    since the environment only ever lowers the best residual, that check
    could never fire. Non-converged runs are not restored because the
    environment reports that the change was already reverted.

    Args:
        env: Environment to reset and step.
        seed: Seed forwarded to ``env.reset``.

    Returns:
        ``(total_reward, trace)`` where ``trace[0]`` describes the reset
        observation and each later entry describes one step.
    """

    def entry(label, observation):
        # One trace row per environment step, keyed like the random baseline.
        return {
            "step": len(trace),
            "action": label,
            "qs": observation.quasi_symmetry_residual,
            "best_qs": observation.best_qs_residual,
            "reward": observation.reward,
        }

    obs = env.reset(seed=seed)
    total_reward = 0.0
    trace: list[dict[str, object]] = [
        {
            "step": 0,
            "qs": obs.quasi_symmetry_residual,
            # Included so the last trace entry always exposes "best_qs".
            "best_qs": obs.best_qs_residual,
        }
    ]

    for operator, direction, magnitude, restart in STRATEGY:
        # Keep at least one unit of budget in hand before submitting.
        if obs.done or obs.budget_remaining <= 1:
            break

        action = StellaratorAction(
            intent="run",
            operator=operator,
            direction=direction,
            magnitude=magnitude,
            restart=restart,
        )
        obs = env.step(action)
        total_reward += obs.reward or 0.0
        trace.append(entry(f"{operator} {direction} {magnitude}", obs))

        worsened = (
            obs.vmec_converged
            and obs.quasi_symmetry_residual > obs.best_qs_residual
        )
        if worsened and obs.budget_remaining > 1:
            restore = StellaratorAction(intent="restore_best")
            obs = env.step(restore)
            total_reward += obs.reward or 0.0
            trace.append(entry("restore_best", obs))

    if not obs.done:
        submit = StellaratorAction(intent="submit")
        obs = env.step(submit)
        total_reward += obs.reward or 0.0
        trace.append(entry("submit", obs))

    return total_reward, trace
88
+
89
+
90
def main(n_episodes: int = 20) -> None:
    """Run the heuristic baseline and print per-episode and mean results.

    Episode index ``i`` is used as the seed of episode ``i``.

    Args:
        n_episodes: Number of episodes to run. Must be at least 1.

    Raises:
        ValueError: If ``n_episodes`` is less than 1; the mean reward would
            otherwise divide by zero.
    """
    if n_episodes < 1:
        raise ValueError(f"n_episodes must be at least 1, got {n_episodes}")

    env = StellaratorEnvironment()
    rewards: list[float] = []

    for i in range(n_episodes):
        total_reward, trace = heuristic_episode(env, seed=i)
        final = trace[-1]
        rewards.append(total_reward)
        # len(trace) - 1 excludes the initial reset entry.
        print(
            f"Episode {i:3d}: steps={len(trace) - 1} "
            f"final_qs={final['qs']:.6f} best_qs={final['best_qs']:.6f} "
            f"reward={total_reward:+.4f}"
        )

    mean_reward = sum(rewards) / len(rewards)
    print(f"\nHeuristic baseline ({n_episodes} episodes): mean_reward={mean_reward:+.4f}")


if __name__ == "__main__":
    n = int(sys.argv[1]) if len(sys.argv) > 1 else 20
    main(n)
baselines/random_agent.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Random baseline agent for the stellarator design environment."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import random
6
+ import sys
7
+
8
+ from fusion_lab.models import StellaratorAction
9
+ from server.environment import StellaratorEnvironment
10
+
11
# Choice pools sampled uniformly by the random agent.
OPERATORS = ["tune_rc10", "tune_rc11", "tune_zs11", "tune_zs12"]
DIRECTIONS = ["increase", "decrease"]
MAGNITUDES = ["small", "medium", "large"]
RESTARTS = ["hot", "cold"]


def random_episode(
    env: StellaratorEnvironment, seed: int | None = None
) -> tuple[float, list[dict[str, object]]]:
    """Play one episode with uniformly random run actions.

    The agent keeps stepping until the environment reports ``done``. Run
    actions are recorded in the trace as ``"<operator> <direction>
    <magnitude>"``, the same format the heuristic baseline uses, so the
    sampled move is not lost.

    Args:
        env: Environment to reset and step.
        seed: Seeds both the agent's private RNG and ``env.reset``.

    Returns:
        ``(total_reward, trace)`` where ``trace[0]`` describes the reset
        observation and each later entry describes one step.
    """
    rng = random.Random(seed)
    obs = env.reset(seed=seed)
    total_reward = 0.0
    trace: list[dict[str, object]] = [
        {
            "step": 0,
            "qs": obs.quasi_symmetry_residual,
            # Included so the last trace entry always exposes "best_qs".
            "best_qs": obs.best_qs_residual,
        }
    ]

    while not obs.done:
        if obs.budget_remaining <= 0:
            action = StellaratorAction(intent="submit")
            label = "submit"
        else:
            # Sampling order (operator, direction, magnitude, restart) is kept
            # fixed so seeded episodes stay reproducible.
            operator = rng.choice(OPERATORS)
            direction = rng.choice(DIRECTIONS)
            magnitude = rng.choice(MAGNITUDES)
            restart = rng.choice(RESTARTS)
            action = StellaratorAction(
                intent="run",
                operator=operator,
                direction=direction,
                magnitude=magnitude,
                restart=restart,
            )
            label = f"{operator} {direction} {magnitude}"

        obs = env.step(action)
        total_reward += obs.reward or 0.0
        trace.append(
            {
                "step": len(trace),
                "action": label,
                "qs": obs.quasi_symmetry_residual,
                "best_qs": obs.best_qs_residual,
                "reward": obs.reward,
            }
        )

    return total_reward, trace
49
+
50
+
51
def main(n_episodes: int = 20) -> None:
    """Run the random baseline and print per-episode and mean results.

    Episode index ``i`` is used as the seed of episode ``i``.

    Args:
        n_episodes: Number of episodes to run. Must be at least 1.

    Raises:
        ValueError: If ``n_episodes`` is less than 1; the mean reward would
            otherwise divide by zero.
    """
    if n_episodes < 1:
        raise ValueError(f"n_episodes must be at least 1, got {n_episodes}")

    env = StellaratorEnvironment()
    rewards: list[float] = []

    for i in range(n_episodes):
        total_reward, trace = random_episode(env, seed=i)
        final = trace[-1]
        rewards.append(total_reward)
        # len(trace) - 1 excludes the initial reset entry.
        print(
            f"Episode {i:3d}: steps={len(trace) - 1} "
            f"final_qs={final['qs']:.6f} best_qs={final['best_qs']:.6f} "
            f"reward={total_reward:+.4f}"
        )

    mean_reward = sum(rewards) / len(rewards)
    print(f"\nRandom baseline ({n_episodes} episodes): mean_reward={mean_reward:+.4f}")


if __name__ == "__main__":
    n = int(sys.argv[1]) if len(sys.argv) > 1 else 20
    main(n)
fusion_lab/client.py CHANGED
@@ -7,13 +7,13 @@ from fusion_lab.models import StellaratorAction, StellaratorObservation, Stellar
7
 
8
 
9
  class FusionLabClient(EnvClient[StellaratorAction, StellaratorObservation, StellaratorState]):
10
- """Thin typed client wrapper for the remote OpenEnv environment."""
11
 
12
  def _step_payload(self, action: StellaratorAction) -> dict[str, object]:
13
  return action.model_dump(exclude_none=True)
14
 
15
  def _parse_result(self, payload: dict[str, object]) -> StepResult[StellaratorObservation]:
16
- observation = StellaratorObservation(**payload)
17
  return StepResult(
18
  observation=observation,
19
  reward=observation.reward,
@@ -21,4 +21,4 @@ class FusionLabClient(EnvClient[StellaratorAction, StellaratorObservation, Stell
21
  )
22
 
23
  def _parse_state(self, payload: dict[str, object]) -> StellaratorState:
24
- return StellaratorState(**payload)
 
7
 
8
 
9
  class FusionLabClient(EnvClient[StellaratorAction, StellaratorObservation, StellaratorState]):
10
+ """Typed client wrapper for the remote Fusion Design Lab environment."""
11
 
12
  def _step_payload(self, action: StellaratorAction) -> dict[str, object]:
13
  return action.model_dump(exclude_none=True)
14
 
15
  def _parse_result(self, payload: dict[str, object]) -> StepResult[StellaratorObservation]:
16
+ observation = StellaratorObservation.model_validate(payload)
17
  return StepResult(
18
  observation=observation,
19
  reward=observation.reward,
 
21
  )
22
 
23
  def _parse_state(self, payload: dict[str, object]) -> StellaratorState:
24
+ return StellaratorState.model_validate(payload)
fusion_lab/models.py CHANGED
@@ -2,8 +2,8 @@ from __future__ import annotations
2
 
3
  from typing import Literal
4
 
5
- from pydantic import BaseModel, Field
6
-
7
 
8
  ActionIntent = Literal["run", "submit", "restore_best"]
9
  OperatorName = Literal["tune_rc10", "tune_rc11", "tune_zs11", "tune_zs12"]
@@ -12,7 +12,7 @@ MagnitudeName = Literal["small", "medium", "large"]
12
  RestartMode = Literal["hot", "cold"]
13
 
14
 
15
- class StellaratorAction(BaseModel):
16
  intent: ActionIntent
17
  operator: OperatorName | None = None
18
  direction: DirectionName | None = None
@@ -21,26 +21,23 @@ class StellaratorAction(BaseModel):
21
  reasoning: str = ""
22
 
23
 
24
- class StellaratorObservation(BaseModel):
25
- diagnostics_text: str
26
- quasi_symmetry_residual: float
27
- aspect_ratio: float
28
- rotational_transform_axis: float
29
- rotational_transform_edge: float
30
- magnetic_well_depth: float
31
- volume: float
32
- vmec_converged: bool
33
- step_number: int
34
- budget_remaining: int
35
- best_qs_residual: float
36
- constraints_satisfied: bool
37
- target_spec: str
38
- reward: float | None = None
39
- done: bool = False
40
-
41
-
42
- class StellaratorState(BaseModel):
43
- step_count: int = 0
44
  initial_qs: float = 0.0
45
  current_qs: float = 0.0
46
  prev_qs: float = 0.0
 
2
 
3
  from typing import Literal
4
 
5
+ from openenv.core import Action, Observation, State
6
+ from pydantic import Field
7
 
8
  ActionIntent = Literal["run", "submit", "restore_best"]
9
  OperatorName = Literal["tune_rc10", "tune_rc11", "tune_zs11", "tune_zs12"]
 
12
  RestartMode = Literal["hot", "cold"]
13
 
14
 
15
+ class StellaratorAction(Action):
16
  intent: ActionIntent
17
  operator: OperatorName | None = None
18
  direction: DirectionName | None = None
 
21
  reasoning: str = ""
22
 
23
 
24
+ class StellaratorObservation(Observation):
25
+ diagnostics_text: str = ""
26
+ quasi_symmetry_residual: float = 0.0
27
+ aspect_ratio: float = 0.0
28
+ rotational_transform_axis: float = 0.0
29
+ rotational_transform_edge: float = 0.0
30
+ magnetic_well_depth: float = 0.0
31
+ volume: float = 0.0
32
+ vmec_converged: bool = True
33
+ step_number: int = 0
34
+ budget_remaining: int = 6
35
+ best_qs_residual: float = float("inf")
36
+ constraints_satisfied: bool = True
37
+ target_spec: str = ""
38
+
39
+
40
+ class StellaratorState(State):
 
 
 
41
  initial_qs: float = 0.0
42
  current_qs: float = 0.0
43
  prev_qs: float = 0.0
server/app.py CHANGED
@@ -1,17 +1,46 @@
1
  from __future__ import annotations
2
 
3
- from fastapi import FastAPI
4
 
5
- from server.environment import TASK, environment_status
 
 
 
 
 
 
 
6
 
7
- app = FastAPI(title="Fusion Design Lab")
8
-
9
-
10
- @app.get("/healthz")
11
- def healthcheck() -> dict[str, str]:
12
- return {"status": "ok", "environment": environment_status()}
13
 
14
 
15
  @app.get("/task")
16
  def task_summary() -> dict[str, object]:
17
- return TASK
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
2
 
3
+ from openenv.core import create_fastapi_app
4
 
5
+ from fusion_lab.models import StellaratorAction, StellaratorObservation
6
+ from server.environment import (
7
+ ASPECT_RATIO_RANGE,
8
+ BUDGET,
9
+ IOTA_EDGE_RANGE,
10
+ VOLUME_MIN,
11
+ StellaratorEnvironment,
12
+ )
13
 
14
+ app = create_fastapi_app(
15
+ env=StellaratorEnvironment,
16
+ action_cls=StellaratorAction,
17
+ observation_cls=StellaratorObservation,
18
+ )
 
19
 
20
 
21
@app.get("/task")
def task_summary() -> dict[str, object]:
    """Return a static description of the task served at ``/task``.

    The payload lists the task description, the constraint limits and budget
    taken from ``server.environment``, and the accepted action vocabulary.
    """
    limits: dict[str, object] = {
        "aspect_ratio": [*ASPECT_RATIO_RANGE],
        "rotational_transform_edge": [*IOTA_EDGE_RANGE],
        "volume_min": VOLUME_MIN,
    }
    summary: dict[str, object] = {
        "description": "Minimize quasi-symmetry error for a 2-period quasi-helical stellarator.",
        "constraints": limits,
        "budget": BUDGET,
    }
    # Action vocabulary, listed in the same key order as before.
    summary["actions"] = ["run", "submit", "restore_best"]
    summary["operators"] = ["tune_rc10", "tune_rc11", "tune_zs11", "tune_zs12"]
    summary["directions"] = ["increase", "decrease"]
    summary["magnitudes"] = ["small", "medium", "large"]
    summary["restart_modes"] = ["hot", "cold"]
    return summary
37
+
38
+
39
def main(host: str = "0.0.0.0", port: int = 8000, reload: bool = True) -> None:
    """Serve ``server.app:app`` with uvicorn.

    The defaults reproduce the previous hard-coded behaviour, so calling
    ``main()`` with no arguments is unchanged.

    Args:
        host: Interface to bind to.
        port: TCP port to listen on.
        reload: Whether uvicorn should run with auto-reload enabled.
    """
    # Imported lazily so importing this module does not require uvicorn.
    import uvicorn

    uvicorn.run("server.app:app", host=host, port=port, reload=reload)


if __name__ == "__main__":
    main()
server/environment.py CHANGED
@@ -1,19 +1,262 @@
1
  from __future__ import annotations
2
 
3
- from typing import Final
4
-
5
- TASK: Final[dict[str, object]] = {
6
- "description": "Minimize quasi-symmetry error for a 2-period quasi-helical stellarator.",
7
- "constraints": {
8
- "aspect_ratio": [4.5, 7.0],
9
- "rotational_transform_edge": [0.3, 0.6],
10
- "volume_min": 0.5,
11
- },
12
- "budget": 6,
13
- "baseline_input": "server/data/input.QH_baseline",
14
- }
15
-
16
-
17
- def environment_status() -> str:
18
- """Return a simple status string until the full environment is implemented."""
19
- return "scaffolded"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
2
 
3
+ from typing import Any, Final, Optional
4
+
5
+ from openenv.core import Environment as BaseEnvironment
6
+
7
+ from fusion_lab.models import (
8
+ StellaratorAction,
9
+ StellaratorObservation,
10
+ StellaratorState,
11
+ )
12
+ from server.physics import Diagnostics, PhysicsEngine
13
+
14
# Number of budgeted actions available per episode.
BUDGET: Final[int] = 6

# Inclusive (low, high) bounds enforced by check_constraints.
ASPECT_RATIO_RANGE: Final[tuple[float, float]] = (4.5, 7.0)
IOTA_EDGE_RANGE: Final[tuple[float, float]] = (0.3, 0.6)
# Minimum accepted volume.
VOLUME_MIN: Final[float] = 0.5

# Built from the constants above so the prose cannot drift from the limits
# that are actually enforced.
TARGET_SPEC: Final[str] = (
    "Minimize quasi-symmetry residual for a 2-period quasi-helical stellarator. "
    f"Constraints: aspect ratio in [{ASPECT_RATIO_RANGE[0]}, {ASPECT_RATIO_RANGE[1]}], "
    f"edge iota in [{IOTA_EDGE_RANGE[0]}, {IOTA_EDGE_RANGE[1]}], "
    f"volume > {VOLUME_MIN} m³. "
    f"Budget: {BUDGET} evaluations."
)


def check_constraints(diag: Diagnostics) -> bool:
    """Return whether ``diag`` satisfies every design constraint.

    Args:
        diag: Diagnostics exposing ``aspect_ratio``, ``iota_edge`` and
            ``volume``.

    Returns:
        ``True`` when the aspect ratio and edge iota lie within their
        inclusive ranges and the volume is at least ``VOLUME_MIN``.
    """
    ar_lo, ar_hi = ASPECT_RATIO_RANGE
    iota_lo, iota_hi = IOTA_EDGE_RANGE
    # NOTE(review): TARGET_SPEC words the volume limit as a strict ">", while
    # a volume exactly equal to VOLUME_MIN passes here — confirm which is meant.
    return (
        ar_lo <= diag.aspect_ratio <= ar_hi
        and iota_lo <= diag.iota_edge <= iota_hi
        and diag.volume >= VOLUME_MIN
    )
35
+
36
+
37
class StellaratorEnvironment(
    BaseEnvironment[StellaratorAction, StellaratorObservation, StellaratorState]
):
    """Budgeted stellarator design episode backed by ``PhysicsEngine``.

    ``reset`` starts an episode; ``step`` dispatches on the action's
    ``intent`` (``"run"``, ``"restore_best"`` or ``"submit"``). Runs, restores
    and invalid runs each consume one unit of budget. An observation is
    flagged ``done`` when the budget reaches zero or a design is submitted.
    """

    def __init__(self) -> None:
        super().__init__()
        self._engine = PhysicsEngine()
        self._state = StellaratorState()
        # Diagnostics from the most recent reset/run/restore; None before reset.
        self._last_diag: Diagnostics | None = None

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> StellaratorObservation:
        """Start a new episode and return the initial observation.

        Args:
            seed: Forwarded to ``PhysicsEngine.reset``.
            episode_id: Stored on the new state.
            **kwargs: Accepted for interface compatibility; not used here.
        """
        diag = self._engine.reset(seed)
        satisfied = check_constraints(diag)
        self._state = StellaratorState(
            episode_id=episode_id,
            step_count=0,
            initial_qs=diag.qs_residual,
            current_qs=diag.qs_residual,
            prev_qs=diag.qs_residual,
            best_qs=diag.qs_residual,
            budget_total=BUDGET,
            budget_remaining=BUDGET,
            constraints_satisfied=satisfied,
        )
        self._last_diag = diag
        return self._build_observation(
            diag, satisfied, action_summary="Episode started. Baseline design loaded."
        )

    def step(
        self,
        action: StellaratorAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> StellaratorObservation:
        """Apply ``action`` and return the resulting observation.

        Args:
            action: Action whose ``intent`` selects the handler.
            timeout_s: Accepted for interface compatibility; not used here.
            **kwargs: Accepted for interface compatibility; not used here.
        """
        # NOTE(review): nothing here rejects steps taken after an observation
        # was returned with done=True — confirm whether callers rely on that.
        # prev_qs must be captured before any handler updates current_qs,
        # because the reward is computed against it.
        self._state.prev_qs = self._state.current_qs
        self._state.step_count += 1

        if action.intent == "submit":
            return self._handle_submit()
        if action.intent == "restore_best":
            return self._handle_restore()
        return self._handle_run(action)

    @property
    def state(self) -> StellaratorState:
        """Current episode state."""
        return self._state

    # ------------------------------------------------------------------
    # Action handlers
    # ------------------------------------------------------------------

    def _handle_run(self, action: StellaratorAction) -> StellaratorObservation:
        """Apply one coefficient perturbation; costs one unit of budget."""
        if not all([action.operator, action.direction, action.magnitude]):
            return self._handle_invalid_run()

        self._state.budget_remaining -= 1

        diag = self._engine.modify_and_run(
            operator=action.operator,
            direction=action.direction,
            magnitude=action.magnitude,
            restart=action.restart or "hot",
        )

        # A failed solve keeps the previously recorded constraint status.
        satisfied = check_constraints(diag) if diag.converged else self._state.constraints_satisfied

        if diag.converged:
            self._state.current_qs = diag.qs_residual
            if diag.qs_residual < self._state.best_qs:
                self._state.best_qs = diag.qs_residual
            self._state.constraints_satisfied = satisfied

        done = self._state.budget_remaining <= 0
        # Reward is computed after best_qs is updated so the terminal bonus
        # sees the newest best value.
        reward = self._compute_reward(diag, action.intent, done)
        summary = self._summary_run(action, diag)
        self._state.history.append(summary)
        self._last_diag = diag

        return self._build_observation(
            diag, satisfied, action_summary=summary, reward=reward, done=done
        )

    def _handle_submit(self) -> StellaratorObservation:
        """End the episode; does not consume budget."""
        # NOTE(review): the submitted diagnostics are the most recent ones,
        # not necessarily the best design. If the last run failed to converge,
        # _compute_reward applies the non-convergence penalty again here —
        # confirm this is intended.
        diag = self._last_diag if self._last_diag is not None else self._engine.restore_best()
        satisfied = check_constraints(diag)
        reward = self._compute_reward(diag, "submit", done=True)
        summary = self._summary_submit(satisfied)
        self._state.history.append(summary)

        return self._build_observation(
            diag, satisfied, action_summary=summary, reward=reward, done=True
        )

    def _handle_restore(self) -> StellaratorObservation:
        """Return to the engine's best coefficients; costs one unit of budget."""
        self._state.budget_remaining -= 1

        diag = self._engine.restore_best()
        self._state.current_qs = diag.qs_residual
        satisfied = check_constraints(diag)
        self._state.constraints_satisfied = satisfied

        done = self._state.budget_remaining <= 0
        reward = self._compute_reward(diag, "restore_best", done)
        summary = f"Restored best design. QS residual: {diag.qs_residual:.6f}."
        self._state.history.append(summary)
        self._last_diag = diag

        return self._build_observation(
            diag, satisfied, action_summary=summary, reward=reward, done=done
        )

    def _handle_invalid_run(self) -> StellaratorObservation:
        """Penalise a run missing operator, direction or magnitude.

        The engine is not advanced; the action still costs one unit of budget
        and earns a fixed reward of -1.0.
        """
        self._state.budget_remaining -= 1
        diag = self._last_diag if self._last_diag is not None else self._engine.restore_best()
        satisfied = check_constraints(diag)
        done = self._state.budget_remaining <= 0
        summary = "Invalid run action: operator, direction, and magnitude are required."
        self._state.history.append(summary)
        return self._build_observation(
            diag, satisfied, action_summary=summary, reward=-1.0, done=done
        )

    # ------------------------------------------------------------------
    # Reward V0
    # ------------------------------------------------------------------

    def _improvement_ratio(self) -> float:
        """Fractional reduction of the best residual relative to the initial one."""
        return 1.0 - (self._state.best_qs / max(self._state.initial_qs, 1e-9))

    def _compute_reward(self, diag: Diagnostics, intent: str, done: bool) -> float:
        """Compute the step reward, rounded to four decimals.

        Terms, in order:
          * 500 x (previous residual - new residual) for converged results;
          * -2.0 if converged but constraints are violated;
          * -1.5 if not converged;
          * -0.1 step cost for every intent except ``"submit"``;
          * on submit: +5 x improvement ratio plus up to +1 for unused budget
            when the best residual beats the initial one, otherwise -1.0;
          * when the budget runs out without a submit: +2 x improvement ratio
            if the best residual beats the initial one.
        """
        reward = 0.0
        improved_overall = self._state.best_qs < self._state.initial_qs

        if diag.converged and self._state.prev_qs < float("inf"):
            improvement = self._state.prev_qs - diag.qs_residual
            reward += improvement * 500.0

        if diag.converged and not check_constraints(diag):
            reward -= 2.0

        if not diag.converged:
            reward -= 1.5

        if intent != "submit":
            reward -= 0.1

        if intent == "submit":
            if improved_overall:
                reward += 5.0 * self._improvement_ratio()
                reward += 1.0 * (self._state.budget_remaining / self._state.budget_total)
            else:
                reward -= 1.0

        if done and intent != "submit":
            if improved_overall:
                reward += 2.0 * self._improvement_ratio()

        return round(reward, 4)

    # ------------------------------------------------------------------
    # Observation builders
    # ------------------------------------------------------------------

    def _build_observation(
        self,
        diag: Diagnostics,
        satisfied: bool,
        action_summary: str,
        reward: float | None = None,
        done: bool = False,
    ) -> StellaratorObservation:
        """Package diagnostics and state into an observation.

        The constraint limits shown in the text come from the module constants
        rather than repeated literals, so the text tracks the enforced limits.
        """
        ar_lo, ar_hi = ASPECT_RATIO_RANGE
        iota_lo, iota_hi = IOTA_EDGE_RANGE
        text_lines = [
            action_summary,
            "",
            f"QS Residual: {diag.qs_residual:.6f} | Best: {self._state.best_qs:.6f}",
            f"Aspect Ratio: {diag.aspect_ratio:.4f} [{ar_lo}, {ar_hi}]",
            f"Edge Iota: {diag.iota_edge:.4f} [{iota_lo}, {iota_hi}]",
            f"Volume: {diag.volume:.4f} m³ (min {VOLUME_MIN})",
            f"Magnetic Well: {diag.magnetic_well_depth:.4f}",
            f"VMEC Converged: {diag.converged}",
            f"Constraints: {'SATISFIED' if satisfied else 'VIOLATED'}",
            f"Step: {self._state.step_count} | Budget: {self._state.budget_remaining}/{self._state.budget_total}",
        ]

        return StellaratorObservation(
            diagnostics_text="\n".join(text_lines),
            quasi_symmetry_residual=diag.qs_residual,
            aspect_ratio=diag.aspect_ratio,
            rotational_transform_axis=diag.iota_axis,
            rotational_transform_edge=diag.iota_edge,
            magnetic_well_depth=diag.magnetic_well_depth,
            volume=diag.volume,
            vmec_converged=diag.converged,
            step_number=self._state.step_count,
            budget_remaining=self._state.budget_remaining,
            best_qs_residual=self._state.best_qs,
            constraints_satisfied=satisfied,
            target_spec=TARGET_SPEC,
            reward=reward,
            done=done,
        )

    # ------------------------------------------------------------------
    # Action summaries
    # ------------------------------------------------------------------

    def _summary_run(self, action: StellaratorAction, diag: Diagnostics) -> str:
        """Describe a run action and its outcome in one sentence pair."""
        restart_note = f" ({action.restart} restart)" if action.restart else ""
        header = f"Applied {action.operator} {action.direction} {action.magnitude}{restart_note}."

        if diag.converged:
            delta = self._state.prev_qs - diag.qs_residual
            direction = "improved" if delta > 0 else "worsened" if delta < 0 else "unchanged"
            return f"{header} VMEC converged. QS {direction}: {self._state.prev_qs:.6f} -> {diag.qs_residual:.6f}."
        return f"{header} VMEC failed to converge. Change reverted."

    def _summary_submit(self, satisfied: bool) -> str:
        """Describe the submitted design relative to the initial residual."""
        status = "Constraints satisfied." if satisfied else "Constraints VIOLATED."
        improvement = self._state.initial_qs - self._state.best_qs
        return (
            f"Design submitted. Best QS residual: {self._state.best_qs:.6f} "
            f"(improved by {improvement:.6f} from initial). {status}"
        )
server/physics.py CHANGED
@@ -1,20 +1,141 @@
1
  from __future__ import annotations
2
 
 
 
 
 
3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  class PhysicsEngine:
5
- """Placeholder for the VMEC-backed physics loop.
6
-
7
- The next implementation step should make this the single place that:
8
- - loads the baseline input
9
- - applies discrete coefficient updates
10
- - runs the solver
11
- - computes diagnostics
12
- - tracks best-known designs
13
- """
14
-
15
- def __init__(self) -> None:
16
- self._status = "unimplemented"
17
-
18
- @property
19
- def status(self) -> str:
20
- return self._status
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
2
 
3
+ import math
4
+ import random
5
+ from dataclasses import dataclass, field
6
+ from typing import Final
7
 
8
# Presumably the number of field periods; not referenced elsewhere in this
# module as shown — TODO confirm.
NFP: Final[int] = 2

# Starting coefficients loaded on every reset.
BASELINE_COEFFS: Final[dict[str, float]] = {
    "rc10": 1.0,
    "rc11": 0.12,
    "zs11": 0.12,
    "zs12": -0.02,
}

# Coefficients at which the quadratic part of the QS residual vanishes.
OPTIMAL_COEFFS: Final[dict[str, float]] = {
    "rc10": 1.02,
    "rc11": 0.135,
    "zs11": 0.115,
    "zs12": -0.035,
}

# Absolute coefficient change applied for each magnitude name.
MAGNITUDE_DELTAS: Final[dict[str, float]] = {
    "small": 0.005,
    "medium": 0.02,
    "large": 0.05,
}

# Sign applied to the magnitude delta for each accepted direction name.
_DIRECTION_SIGNS: Final[dict[str, float]] = {
    "increase": 1.0,
    "decrease": -1.0,
}


@dataclass(frozen=True)
class Diagnostics:
    """Immutable result of one evaluation of the current coefficients."""

    qs_residual: float  # infinity when the evaluation did not converge
    aspect_ratio: float
    iota_axis: float
    iota_edge: float
    volume: float
    magnetic_well_depth: float
    converged: bool


@dataclass
class PhysicsEngine:
    """Synthetic stand-in for the solver loop, driven by four coefficients.

    Tracks the current coefficients, the best coefficients seen so far and the
    best residual, and draws all randomness from a private ``random.Random``.
    """

    coeffs: dict[str, float] = field(default_factory=lambda: dict(BASELINE_COEFFS))
    best_coeffs: dict[str, float] = field(default_factory=lambda: dict(BASELINE_COEFFS))
    best_qs: float = float("inf")
    _rng: random.Random = field(default_factory=random.Random)

    def reset(self, seed: int | None = None) -> Diagnostics:
        """Reload the baseline design and return its diagnostics.

        Args:
            seed: Seeds the private RNG. When not ``None`` each baseline
                coefficient is also perturbed by Gaussian noise (sigma 0.002).
        """
        self.coeffs = dict(BASELINE_COEFFS)
        self._rng = random.Random(seed)
        if seed is not None:
            for key in self.coeffs:
                self.coeffs[key] += self._rng.gauss(0, 0.002)
        self.best_coeffs = dict(self.coeffs)
        diag = self._compute_diagnostics(converged=True)
        self.best_qs = diag.qs_residual
        return diag

    def modify_and_run(
        self,
        operator: str,
        direction: str,
        magnitude: str,
        restart: str,
    ) -> Diagnostics:
        """Perturb one coefficient and evaluate the result.

        Args:
            operator: ``"tune_<coeff>"`` naming the coefficient to change.
            direction: ``"increase"`` or ``"decrease"``.
            magnitude: Key of ``MAGNITUDE_DELTAS``.
            restart: ``"hot"`` halves the base failure probability.

        Returns:
            Diagnostics for the new coefficients, or for the reverted ones
            with ``converged=False`` when the simulated solve fails.

        Raises:
            ValueError: If ``direction`` is not a recognised name. Previously
                any unrecognised value was silently applied as an increase.
            KeyError: If ``magnitude`` or the coefficient named by
                ``operator`` is unknown.
        """
        coeff_key = operator.removeprefix("tune_")
        step = MAGNITUDE_DELTAS[magnitude]
        # Validate before touching state so a bad call leaves coeffs intact.
        if direction not in _DIRECTION_SIGNS:
            raise ValueError(
                f"direction must be one of {sorted(_DIRECTION_SIGNS)}, got {direction!r}"
            )
        delta = _DIRECTION_SIGNS[direction] * step

        prev_value = self.coeffs[coeff_key]
        self.coeffs[coeff_key] = prev_value + delta

        converged = self._simulate_convergence(magnitude, restart)
        if not converged:
            # Failed solves leave the design exactly as it was.
            self.coeffs[coeff_key] = prev_value
            return self._compute_diagnostics(converged=False)

        diag = self._compute_diagnostics(converged=True)
        if diag.qs_residual < self.best_qs:
            self.best_qs = diag.qs_residual
            self.best_coeffs = dict(self.coeffs)
        return diag

    def restore_best(self) -> Diagnostics:
        """Reinstate the best coefficients and re-evaluate them.

        The residual is recomputed with a fresh noise draw, so it need not
        equal ``best_qs`` exactly.
        """
        self.coeffs = dict(self.best_coeffs)
        return self._compute_diagnostics(converged=True)

    def _compute_diagnostics(self, *, converged: bool) -> Diagnostics:
        """Derive all diagnostics from the current coefficients.

        The residual is only evaluated (consuming one noise draw) when
        ``converged`` is true; otherwise it is reported as infinity.
        """
        rc10 = self.coeffs["rc10"]
        rc11 = self.coeffs["rc11"]
        zs11 = self.coeffs["zs11"]
        zs12 = self.coeffs["zs12"]

        r_minor = math.sqrt(rc11**2 + zs11**2)
        # The 1e-6 floors guard the divisions against vanishing denominators.
        aspect_ratio = rc10 / max(r_minor, 1e-6)
        volume = 2.0 * math.pi**2 * rc10 * r_minor**2

        helical_excursion = abs(zs11 / max(abs(rc11), 1e-6))
        iota_axis = 0.35 + 0.15 * helical_excursion + 0.5 * abs(zs12)
        shear = 0.04 + 0.02 * abs(rc10 - 1.0)
        iota_edge = iota_axis + shear

        magnetic_well = 0.02 + 0.01 * (rc11 / max(abs(zs11), 1e-6) - 1.0)

        qs_residual = self._compute_qs_residual() if converged else float("inf")

        return Diagnostics(
            qs_residual=round(qs_residual, 6),
            aspect_ratio=round(aspect_ratio, 4),
            iota_axis=round(iota_axis, 4),
            iota_edge=round(iota_edge, 4),
            volume=round(volume, 4),
            magnetic_well_depth=round(magnetic_well, 4),
            converged=converged,
        )

    def _compute_qs_residual(self) -> float:
        """Return the noisy residual of the current coefficients.

        A weighted quadratic plus cross terms in the offsets from
        ``OPTIMAL_COEFFS``, plus a 0.002 offset and Gaussian noise
        (sigma 0.0003), floored at 0.001.
        """
        d = {k: self.coeffs[k] - OPTIMAL_COEFFS[k] for k in OPTIMAL_COEFFS}
        quadratic = (
            2.0 * d["rc10"] ** 2
            + 8.0 * d["rc11"] ** 2
            + 8.0 * d["zs11"] ** 2
            + 15.0 * d["zs12"] ** 2
        )
        cross = 4.0 * d["rc11"] * d["zs11"] - 3.0 * d["rc10"] * d["zs12"]
        noise = self._rng.gauss(0, 0.0003)
        return max(quadratic + cross + 0.002 + noise, 0.001)

    def _simulate_convergence(self, magnitude: str, restart: str) -> bool:
        """Randomly decide whether the simulated solve converges.

        The failure probability starts from a per-magnitude base, is halved
        for hot restarts, grows for every coefficient that has drifted more
        than 0.05 (or 0.1) from the baseline, and is capped at 0.8.
        """
        fail_prob = {"small": 0.02, "medium": 0.08, "large": 0.20}[magnitude]
        if restart == "hot":
            fail_prob *= 0.5
        for key, val in self.coeffs.items():
            deviation = abs(val - BASELINE_COEFFS[key])
            if deviation > 0.1:
                fail_prob += 0.15
            elif deviation > 0.05:
                fail_prob += 0.05
        return self._rng.random() > min(fail_prob, 0.8)