Spaces:
Sleeping
Sleeping
File size: 4,675 Bytes
"""Heuristic baseline agent for the stellarator design environment."""
from __future__ import annotations
import sys
from fusion_lab.models import StellaratorAction, StellaratorObservation
from server.environment import StellaratorEnvironment
# Thresholds read by _choose_action. The numeric values are tuning choices;
# their physical justification is not visible in this file.

# With constraints satisfied, submit once obs.max_elongation is at or below this.
FEASIBLE_SUBMIT_ELONGATION_MAX = 7.45
# While obs.average_triangularity is above this, the triangularity branch runs.
TRIANGULARITY_TARGET_MAX = -0.5
# On step 0 only, within the triangularity branch: obs.edge_iota_over_nfp below
# this selects a medium rotational_transform increase instead.
LOW_IOTA_RESET_THRESHOLD = 0.305
# obs.edge_iota_over_nfp below this selects a small rotational_transform increase.
IOTA_RECOVERY_THRESHOLD = 0.3
# obs.aspect_ratio above this selects a small aspect_ratio decrease.
ASPECT_RATIO_TARGET_MAX = 4.0
def heuristic_episode(
    env: StellaratorEnvironment, seed: int | None = None
) -> tuple[float, list[dict[str, object]]]:
    """Play one episode with the rule-based policy and log every observation.

    Args:
        env: Environment to drive. It is reset with ``seed`` before stepping.
        seed: Value forwarded to ``env.reset``.

    Returns:
        A ``(total_reward, trace)`` pair. ``trace[0]`` describes the
        observation returned by the reset; every later entry describes one
        step and additionally records the action label, the step reward and
        the evaluation-failed flag.
    """

    def _metrics(observation: StellaratorObservation) -> dict[str, object]:
        # Fields shared by the reset entry and every step entry, in log order.
        return {
            "score": observation.p1_score,
            "evaluation_fidelity": observation.evaluation_fidelity,
            "constraints_satisfied": observation.constraints_satisfied,
            "feasibility": observation.p1_feasibility,
            "max_elongation": observation.max_elongation,
            "average_triangularity": observation.average_triangularity,
            "edge_iota_over_nfp": observation.edge_iota_over_nfp,
        }

    obs = env.reset(seed=seed)
    trace: list[dict[str, object]] = [{"step": 0, **_metrics(obs)}]
    total_reward = 0.0

    while not obs.done:
        # When budget_remaining is 1 or less, a submit is forced.
        if obs.budget_remaining <= 1:
            action = StellaratorAction(intent="submit")
        else:
            action = _choose_action(obs)

        obs = env.step(action)
        # A falsy reward (e.g. None) contributes 0.0 to the running total.
        total_reward += obs.reward or 0.0
        trace.append(
            {
                "step": len(trace),
                "action": _action_label(action),
                **_metrics(obs),
                "reward": obs.reward,
                "evaluation_failed": obs.evaluation_failed,
            }
        )

    return total_reward, trace
def _choose_action(obs: StellaratorObservation) -> StellaratorAction:
    """Select the next action from the current observation using fixed rules.

    Rules are tried in order and the first match wins:

    1. The last evaluation failed -> ``restore_best``.
    2. Constraints are satisfied -> ``submit`` when max elongation is at or
       below ``FEASIBLE_SUBMIT_ELONGATION_MAX``, or ``budget_remaining <= 2``,
       or ``step_number >= 3``; otherwise a small elongation decrease.
    3. Average triangularity is above ``TRIANGULARITY_TARGET_MAX`` -> a medium
       rotational-transform increase on step 0 when edge iota is below
       ``LOW_IOTA_RESET_THRESHOLD``; otherwise a medium triangularity-scale
       increase.
    4. Edge iota is below ``IOTA_RECOVERY_THRESHOLD`` -> a small
       rotational-transform increase.
    5. Aspect ratio is above ``ASPECT_RATIO_TARGET_MAX`` -> a small
       aspect-ratio decrease.
    6. Fallback -> a small elongation decrease.
    """

    def _run(parameter: str, direction: str, magnitude: str) -> StellaratorAction:
        # Every tuning move is a "run" action differing only in these fields.
        return StellaratorAction(
            intent="run",
            parameter=parameter,
            direction=direction,
            magnitude=magnitude,
        )

    if obs.evaluation_failed:
        return StellaratorAction(intent="restore_best")

    if obs.constraints_satisfied:
        ready_to_submit = (
            obs.max_elongation <= FEASIBLE_SUBMIT_ELONGATION_MAX
            or obs.budget_remaining <= 2
            or obs.step_number >= 3
        )
        if ready_to_submit:
            return StellaratorAction(intent="submit")
        return _run("elongation", "decrease", "small")

    if obs.average_triangularity > TRIANGULARITY_TARGET_MAX:
        low_iota_at_start = (
            obs.step_number == 0
            and obs.edge_iota_over_nfp < LOW_IOTA_RESET_THRESHOLD
        )
        if low_iota_at_start:
            return _run("rotational_transform", "increase", "medium")
        return _run("triangularity_scale", "increase", "medium")

    if obs.edge_iota_over_nfp < IOTA_RECOVERY_THRESHOLD:
        return _run("rotational_transform", "increase", "small")

    if obs.aspect_ratio > ASPECT_RATIO_TARGET_MAX:
        return _run("aspect_ratio", "decrease", "small")

    return _run("elongation", "decrease", "small")
def _action_label(action: StellaratorAction) -> str:
    """Return a short human-readable label for ``action``.

    "run" actions are rendered as ``"<parameter> <direction> <magnitude>"``;
    any other action is labelled by its intent alone.
    """
    if action.intent == "run":
        return f"{action.parameter} {action.direction} {action.magnitude}"
    return action.intent
def main(n_episodes: int = 20) -> None:
    """Run the heuristic baseline for ``n_episodes`` episodes and print a report.

    One environment instance is reused for all episodes; episode ``i`` is
    played with ``seed=i``. A summary line is printed per episode, followed
    by the mean total reward over all episodes.

    Args:
        n_episodes: Number of episodes to play. When it is zero or negative
            no episode runs and the mean-reward line is replaced by a notice.
    """
    env = StellaratorEnvironment()
    rewards: list[float] = []
    for i in range(n_episodes):
        total_reward, trace = heuristic_episode(env, seed=i)
        final = trace[-1]
        rewards.append(total_reward)
        print(
            # len(trace) - 1 because trace[0] is the reset entry, not a step.
            f"Episode {i:3d}: steps={len(trace) - 1} "
            f"final_score={final['score']:.6f} fidelity={final['evaluation_fidelity']} "
            f"constraints={'yes' if final['constraints_satisfied'] else 'no'} "
            f"reward={total_reward:+.4f}"
        )

    if not rewards:
        # range() produced no iterations (n_episodes <= 0); computing the mean
        # would divide by zero.
        print(f"\nHeuristic baseline ({n_episodes} episodes): no episodes run")
        return

    mean_reward = sum(rewards) / len(rewards)
    print(f"\nHeuristic baseline ({n_episodes} episodes): mean_reward={mean_reward:+.4f}")
if __name__ == "__main__":
    # The optional first command-line argument overrides the default of 20 episodes.
    cli_args = sys.argv[1:]
    main(int(cli_args[0]) if cli_args else 20)
|