fusion-design-lab / baselines /heuristic_agent.py
CreativeEngineer's picture
feat: reward verifier alignment, notebook hardening, model name fix
cdc237b
"""Heuristic baseline agent for the stellarator design environment."""
from __future__ import annotations
import sys
from fusion_lab.models import StellaratorAction, StellaratorObservation
from server.environment import StellaratorEnvironment
# Decision thresholds consulted by ``_choose_action``.
# Once constraints are satisfied, submit right away if max elongation is at or below this.
FEASIBLE_SUBMIT_ELONGATION_MAX: float = 7.45
# Average triangularity above this value triggers a triangularity_scale increase.
TRIANGULARITY_TARGET_MAX: float = -0.5
# On step 0 only: edge iota/nfp below this raises rotational transform before
# any triangularity push.
LOW_IOTA_RESET_THRESHOLD: float = 0.305
# Edge iota/nfp below this triggers a small rotational_transform increase.
IOTA_RECOVERY_THRESHOLD: float = 0.3
# Aspect ratio above this triggers a small aspect_ratio decrease.
ASPECT_RATIO_TARGET_MAX: float = 4.0
def _observation_snapshot(obs: StellaratorObservation) -> dict[str, object]:
    """Return the observation fields recorded in every trace entry.

    Keeping this in one place guarantees the reset entry and the per-step
    entries always report the same set of fields, in the same order.
    """
    return {
        "score": obs.p1_score,
        "evaluation_fidelity": obs.evaluation_fidelity,
        "constraints_satisfied": obs.constraints_satisfied,
        "feasibility": obs.p1_feasibility,
        "max_elongation": obs.max_elongation,
        "average_triangularity": obs.average_triangularity,
        "edge_iota_over_nfp": obs.edge_iota_over_nfp,
    }


def heuristic_episode(
    env: StellaratorEnvironment, seed: int | None = None
) -> tuple[float, list[dict[str, object]]]:
    """Run one episode with the rule-based policy and record a per-step trace.

    The loop steps the environment until an observation reports ``done``.
    While more than one unit of budget remains, the action comes from
    ``_choose_action``; otherwise the agent submits.

    Args:
        env: Environment to reset and step.
        seed: Seed forwarded to ``env.reset``.

    Returns:
        ``(total_reward, trace)``. ``total_reward`` is the sum of per-step
        rewards, with a falsy reward (e.g. ``None``) counted as 0.0.
        ``trace`` starts with an entry for the reset observation
        (``step`` 0). It is followed by one entry per step that also carries
        the action label, the raw reward and the evaluation-failure flag.
    """
    obs = env.reset(seed=seed)
    total_reward = 0.0
    trace: list[dict[str, object]] = [{"step": 0, **_observation_snapshot(obs)}]
    while not obs.done:
        # With at most one unit of budget left, submit instead of spending it
        # on another design move.
        if obs.budget_remaining <= 1:
            action = StellaratorAction(intent="submit")
        else:
            action = _choose_action(obs)
        obs = env.step(action)
        total_reward += obs.reward or 0.0
        trace.append(
            {
                # len(trace) before appending == index of this step.
                "step": len(trace),
                "action": _action_label(action),
                **_observation_snapshot(obs),
                "reward": obs.reward,
                "evaluation_failed": obs.evaluation_failed,
            }
        )
    return total_reward, trace
def _choose_action(obs: StellaratorObservation) -> StellaratorAction:
    """Pick the next action from a fixed, priority-ordered rule cascade.

    The first matching rule wins:

    1. The last evaluation failed: ``restore_best``.
    2. Constraints are satisfied: ``submit`` when max elongation is at or
       below ``FEASIBLE_SUBMIT_ELONGATION_MAX``, when two or fewer budget
       units remain, or once ``step_number`` reaches 3. Otherwise apply a
       small elongation decrease.
    3. Average triangularity is above ``TRIANGULARITY_TARGET_MAX``: apply a
       medium ``triangularity_scale`` increase. The exception is step 0 with
       edge iota/nfp below ``LOW_IOTA_RESET_THRESHOLD``, where rotational
       transform is raised (medium) instead.
    4. Edge iota/nfp is below ``IOTA_RECOVERY_THRESHOLD``: apply a small
       rotational transform increase.
    5. Aspect ratio is above ``ASPECT_RATIO_TARGET_MAX``: apply a small
       aspect ratio decrease.
    6. Otherwise: apply a small elongation decrease.
    """

    def run(parameter, direction, magnitude):
        # Every design move is a "run" action; only the knob and step differ.
        return StellaratorAction(
            intent="run",
            parameter=parameter,
            direction=direction,
            magnitude=magnitude,
        )

    # A failed evaluation takes precedence over every other rule.
    if obs.evaluation_failed:
        return StellaratorAction(intent="restore_best")

    if obs.constraints_satisfied:
        should_submit = (
            obs.max_elongation <= FEASIBLE_SUBMIT_ELONGATION_MAX
            or obs.budget_remaining <= 2
            or obs.step_number >= 3
        )
        if should_submit:
            return StellaratorAction(intent="submit")
        return run("elongation", "decrease", "small")

    if obs.average_triangularity > TRIANGULARITY_TARGET_MAX:
        low_iota_at_start = (
            obs.step_number == 0 and obs.edge_iota_over_nfp < LOW_IOTA_RESET_THRESHOLD
        )
        if low_iota_at_start:
            return run("rotational_transform", "increase", "medium")
        return run("triangularity_scale", "increase", "medium")

    if obs.edge_iota_over_nfp < IOTA_RECOVERY_THRESHOLD:
        return run("rotational_transform", "increase", "small")

    if obs.aspect_ratio > ASPECT_RATIO_TARGET_MAX:
        return run("aspect_ratio", "decrease", "small")

    return run("elongation", "decrease", "small")
def _action_label(action: StellaratorAction) -> str:
    """Return a short human-readable label for *action*.

    ``run`` actions render as ``"<parameter> <direction> <magnitude>"``.
    Any other intent is returned unchanged.
    """
    if action.intent == "run":
        return f"{action.parameter} {action.direction} {action.magnitude}"
    return action.intent
def main(n_episodes: int = 20) -> None:
    """Run the heuristic baseline over several seeded episodes and print a summary.

    Episode ``i`` is reset with ``seed=i``, so repeated runs match as long as
    the environment is deterministic for a given seed. One line is printed per
    episode, followed by the mean total reward across all episodes.

    Args:
        n_episodes: Number of episodes to run; must be at least 1.

    Raises:
        ValueError: If ``n_episodes`` is less than 1.
    """
    # Validate before building the (potentially expensive) environment. With
    # no episodes the mean below would otherwise divide by zero.
    if n_episodes < 1:
        raise ValueError(f"n_episodes must be at least 1, got {n_episodes}")
    env = StellaratorEnvironment()
    rewards: list[float] = []
    for i in range(n_episodes):
        total_reward, trace = heuristic_episode(env, seed=i)
        final = trace[-1]
        rewards.append(total_reward)
        print(
            f"Episode {i:3d}: steps={len(trace) - 1} "
            f"final_score={final['score']:.6f} fidelity={final['evaluation_fidelity']} "
            f"constraints={'yes' if final['constraints_satisfied'] else 'no'} "
            f"reward={total_reward:+.4f}"
        )
    mean_reward = sum(rewards) / len(rewards)
    print(f"\nHeuristic baseline ({n_episodes} episodes): mean_reward={mean_reward:+.4f}")
if __name__ == "__main__":
    # Optional first CLI argument: number of episodes (defaults to 20).
    cli_args = sys.argv[1:]
    main(int(cli_args[0]) if cli_args else 20)