File size: 4,675 Bytes
fe3a41d
2f5db5e
 
 
 
 
fe3a41d
2f5db5e
 
f238af4
 
 
 
 
 
2f5db5e
 
 
 
 
 
6deaccc
 
 
 
 
 
f238af4
 
 
 
6deaccc
 
2f5db5e
fe3a41d
6deaccc
 
 
2f5db5e
 
 
 
 
fe3a41d
daba1b9
6deaccc
 
f238af4
 
 
 
2f5db5e
cdc237b
2f5db5e
 
 
fe3a41d
 
 
 
f238af4
 
 
fe3a41d
f238af4
 
 
 
 
fe3a41d
 
 
 
 
 
2f5db5e
 
f238af4
 
 
 
 
 
 
 
fe3a41d
 
 
 
f238af4
fe3a41d
 
f238af4
fe3a41d
 
 
 
 
 
 
f238af4
fe3a41d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f5db5e
 
 
 
 
 
 
 
 
 
 
 
6deaccc
 
2f5db5e
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""Heuristic baseline agent for the stellarator design environment."""

from __future__ import annotations

import sys

from fusion_lab.models import StellaratorAction, StellaratorObservation
from server.environment import StellaratorEnvironment

# Thresholds consulted by `_choose_action`; each comment describes how the
# constant is used in this module.
# NOTE(review): the numeric values look hand-tuned — confirm their origin.

# With constraints satisfied, submit as soon as max_elongation is at or below this.
FEASIBLE_SUBMIT_ELONGATION_MAX = 7.45
# While average_triangularity is above this, keep increasing triangularity_scale.
TRIANGULARITY_TARGET_MAX = -0.5
# On step 0 only: if triangularity is still above target and edge_iota_over_nfp
# is below this, raise rotational_transform (medium) before touching triangularity.
LOW_IOTA_RESET_THRESHOLD = 0.305
# Once triangularity is on target, edge_iota_over_nfp below this triggers a
# small rotational_transform increase.
IOTA_RECOVERY_THRESHOLD = 0.3
# aspect_ratio above this triggers a small aspect_ratio decrease.
ASPECT_RATIO_TARGET_MAX = 4.0


def heuristic_episode(
    env: StellaratorEnvironment, seed: int | None = None
) -> tuple[float, list[dict[str, object]]]:
    """Play one episode with the rule-based policy and record a per-step trace.

    Args:
        env: Environment to drive; it is reset with ``seed`` before stepping.
        seed: Value forwarded to ``env.reset``.

    Returns:
        A ``(total_reward, trace)`` pair. ``trace[0]`` describes the
        observation returned by ``reset``; every later entry describes the
        observation after one step and additionally carries the action label,
        the step reward and the ``evaluation_failed`` flag.
    """

    def metrics(observation: StellaratorObservation) -> dict[str, object]:
        # Fields shared by the reset entry and every step entry.
        return {
            "score": observation.p1_score,
            "evaluation_fidelity": observation.evaluation_fidelity,
            "constraints_satisfied": observation.constraints_satisfied,
            "feasibility": observation.p1_feasibility,
            "max_elongation": observation.max_elongation,
            "average_triangularity": observation.average_triangularity,
            "edge_iota_over_nfp": observation.edge_iota_over_nfp,
        }

    obs = env.reset(seed=seed)
    trace: list[dict[str, object]] = [{"step": 0, **metrics(obs)}]
    reward_sum = 0.0

    while not obs.done:
        # With at most one unit of budget left, submit instead of asking the policy.
        if obs.budget_remaining <= 1:
            action = StellaratorAction(intent="submit")
        else:
            action = _choose_action(obs)

        obs = env.step(action)
        # A missing (None) reward counts as zero in the running total.
        reward_sum += obs.reward or 0.0

        entry: dict[str, object] = {
            "step": len(trace),
            "action": _action_label(action),
            **metrics(obs),
            "reward": obs.reward,
            "evaluation_failed": obs.evaluation_failed,
        }
        trace.append(entry)

    return reward_sum, trace


def _choose_action(obs: StellaratorObservation) -> StellaratorAction:
    """Pick the next action from the latest observation.

    Rules are tried in order and the first match wins:

    1. The last evaluation failed: ``restore_best``.
    2. Constraints are satisfied: ``submit`` when elongation is at or below
       ``FEASIBLE_SUBMIT_ELONGATION_MAX``, when ``budget_remaining <= 2``, or
       when ``step_number >= 3``; otherwise a small elongation decrease.
    3. Average triangularity is above ``TRIANGULARITY_TARGET_MAX``: a medium
       ``triangularity_scale`` increase, except on step 0 with
       ``edge_iota_over_nfp`` below ``LOW_IOTA_RESET_THRESHOLD``, where a
       medium ``rotational_transform`` increase is issued instead.
    4. ``edge_iota_over_nfp`` is below ``IOTA_RECOVERY_THRESHOLD``: a small
       ``rotational_transform`` increase.
    5. ``aspect_ratio`` is above ``ASPECT_RATIO_TARGET_MAX``: a small
       ``aspect_ratio`` decrease.
    6. Otherwise: a small elongation decrease.
    """

    def run_action(parameter: str, direction: str, magnitude: str) -> StellaratorAction:
        # All tuning moves share intent="run"; only the three knobs vary.
        return StellaratorAction(
            intent="run",
            parameter=parameter,
            direction=direction,
            magnitude=magnitude,
        )

    if obs.evaluation_failed:
        return StellaratorAction(intent="restore_best")

    if obs.constraints_satisfied:
        ready_to_submit = (
            obs.max_elongation <= FEASIBLE_SUBMIT_ELONGATION_MAX
            or obs.budget_remaining <= 2
            or obs.step_number >= 3
        )
        if ready_to_submit:
            return StellaratorAction(intent="submit")
        return run_action("elongation", "decrease", "small")

    if obs.average_triangularity > TRIANGULARITY_TARGET_MAX:
        low_iota_on_first_step = (
            obs.step_number == 0 and obs.edge_iota_over_nfp < LOW_IOTA_RESET_THRESHOLD
        )
        if low_iota_on_first_step:
            return run_action("rotational_transform", "increase", "medium")
        return run_action("triangularity_scale", "increase", "medium")

    if obs.edge_iota_over_nfp < IOTA_RECOVERY_THRESHOLD:
        return run_action("rotational_transform", "increase", "small")

    if obs.aspect_ratio > ASPECT_RATIO_TARGET_MAX:
        return run_action("aspect_ratio", "decrease", "small")

    return run_action("elongation", "decrease", "small")


def _action_label(action: StellaratorAction) -> str:
    """Return a short label for ``action`` for use in the episode trace.

    ``run`` actions are rendered as ``"<parameter> <direction> <magnitude>"``;
    any other action is labelled with its intent alone.
    """
    if action.intent == "run":
        parameter, direction, magnitude = (
            action.parameter,
            action.direction,
            action.magnitude,
        )
        return f"{parameter} {direction} {magnitude}"
    return action.intent


def main(n_episodes: int = 20) -> None:
    """Run the heuristic baseline for several episodes and print a report.

    One line is printed per episode (step count, final score, evaluation
    fidelity, whether constraints were satisfied, total reward), followed by
    a summary line with the mean total reward.

    Args:
        n_episodes: Number of episodes to play. Episode ``i`` is seeded with
            ``i``. For zero or a negative value nothing is run and the
            summary reports the mean as ``n/a``.
    """
    if n_episodes <= 0:
        # Nothing to run: skip building the environment and avoid dividing
        # by the length of an empty reward list.
        print(f"\nHeuristic baseline ({n_episodes} episodes): mean_reward=n/a")
        return

    env = StellaratorEnvironment()
    rewards: list[float] = []

    for i in range(n_episodes):
        total_reward, trace = heuristic_episode(env, seed=i)
        final = trace[-1]
        rewards.append(total_reward)
        print(
            # trace[0] is the reset snapshot, so steps taken is len(trace) - 1.
            f"Episode {i:3d}: steps={len(trace) - 1}  "
            f"final_score={final['score']:.6f}  fidelity={final['evaluation_fidelity']}  "
            f"constraints={'yes' if final['constraints_satisfied'] else 'no'}  "
            f"reward={total_reward:+.4f}"
        )

    mean_reward = sum(rewards) / len(rewards)
    print(f"\nHeuristic baseline ({n_episodes} episodes): mean_reward={mean_reward:+.4f}")


if __name__ == "__main__":
    # Script entry point: an optional first CLI argument overrides the
    # default of 20 episodes.
    episode_count = 20
    if len(sys.argv) > 1:
        episode_count = int(sys.argv[1])
    main(episode_count)