Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
| """Heuristic baseline agent for the stellarator design environment.""" | |
| from __future__ import annotations | |
| import sys | |
| from fusion_lab.models import StellaratorAction, StellaratorObservation | |
| from server.environment import StellaratorEnvironment | |
| FEASIBLE_SUBMIT_ELONGATION_MAX = 7.45 | |
| TRIANGULARITY_TARGET_MAX = -0.5 | |
| LOW_IOTA_RESET_THRESHOLD = 0.305 | |
| IOTA_RECOVERY_THRESHOLD = 0.3 | |
| ASPECT_RATIO_TARGET_MAX = 4.0 | |
| def heuristic_episode( | |
| env: StellaratorEnvironment, seed: int | None = None | |
| ) -> tuple[float, list[dict[str, object]]]: | |
| obs = env.reset(seed=seed) | |
| total_reward = 0.0 | |
| trace: list[dict[str, object]] = [ | |
| { | |
| "step": 0, | |
| "score": obs.p1_score, | |
| "evaluation_fidelity": obs.evaluation_fidelity, | |
| "constraints_satisfied": obs.constraints_satisfied, | |
| "feasibility": obs.p1_feasibility, | |
| "max_elongation": obs.max_elongation, | |
| "average_triangularity": obs.average_triangularity, | |
| "edge_iota_over_nfp": obs.edge_iota_over_nfp, | |
| } | |
| ] | |
| while not obs.done: | |
| action = ( | |
| StellaratorAction(intent="submit") if obs.budget_remaining <= 1 else _choose_action(obs) | |
| ) | |
| obs = env.step(action) | |
| total_reward += obs.reward or 0.0 | |
| trace.append( | |
| { | |
| "step": len(trace), | |
| "action": _action_label(action), | |
| "score": obs.p1_score, | |
| "evaluation_fidelity": obs.evaluation_fidelity, | |
| "constraints_satisfied": obs.constraints_satisfied, | |
| "feasibility": obs.p1_feasibility, | |
| "max_elongation": obs.max_elongation, | |
| "average_triangularity": obs.average_triangularity, | |
| "edge_iota_over_nfp": obs.edge_iota_over_nfp, | |
| "reward": obs.reward, | |
| "evaluation_failed": obs.evaluation_failed, | |
| } | |
| ) | |
| return total_reward, trace | |
| def _choose_action(obs: StellaratorObservation) -> StellaratorAction: | |
| if obs.evaluation_failed: | |
| return StellaratorAction(intent="restore_best") | |
| if obs.constraints_satisfied: | |
| if ( | |
| obs.max_elongation <= FEASIBLE_SUBMIT_ELONGATION_MAX | |
| or obs.budget_remaining <= 2 | |
| or obs.step_number >= 3 | |
| ): | |
| return StellaratorAction(intent="submit") | |
| return StellaratorAction( | |
| intent="run", | |
| parameter="elongation", | |
| direction="decrease", | |
| magnitude="small", | |
| ) | |
| if obs.average_triangularity > TRIANGULARITY_TARGET_MAX: | |
| if obs.step_number == 0 and obs.edge_iota_over_nfp < LOW_IOTA_RESET_THRESHOLD: | |
| return StellaratorAction( | |
| intent="run", | |
| parameter="rotational_transform", | |
| direction="increase", | |
| magnitude="medium", | |
| ) | |
| return StellaratorAction( | |
| intent="run", | |
| parameter="triangularity_scale", | |
| direction="increase", | |
| magnitude="medium", | |
| ) | |
| if obs.edge_iota_over_nfp < IOTA_RECOVERY_THRESHOLD: | |
| return StellaratorAction( | |
| intent="run", | |
| parameter="rotational_transform", | |
| direction="increase", | |
| magnitude="small", | |
| ) | |
| if obs.aspect_ratio > ASPECT_RATIO_TARGET_MAX: | |
| return StellaratorAction( | |
| intent="run", | |
| parameter="aspect_ratio", | |
| direction="decrease", | |
| magnitude="small", | |
| ) | |
| return StellaratorAction( | |
| intent="run", | |
| parameter="elongation", | |
| direction="decrease", | |
| magnitude="small", | |
| ) | |
| def _action_label(action: StellaratorAction) -> str: | |
| if action.intent != "run": | |
| return action.intent | |
| return f"{action.parameter} {action.direction} {action.magnitude}" | |
| def main(n_episodes: int = 20) -> None: | |
| env = StellaratorEnvironment() | |
| rewards: list[float] = [] | |
| for i in range(n_episodes): | |
| total_reward, trace = heuristic_episode(env, seed=i) | |
| final = trace[-1] | |
| rewards.append(total_reward) | |
| print( | |
| f"Episode {i:3d}: steps={len(trace) - 1} " | |
| f"final_score={final['score']:.6f} fidelity={final['evaluation_fidelity']} " | |
| f"constraints={'yes' if final['constraints_satisfied'] else 'no'} " | |
| f"reward={total_reward:+.4f}" | |
| ) | |
| mean_reward = sum(rewards) / len(rewards) | |
| print(f"\nHeuristic baseline ({n_episodes} episodes): mean_reward={mean_reward:+.4f}") | |
| if __name__ == "__main__": | |
| n = int(sys.argv[1]) if len(sys.argv) > 1 else 20 | |
| main(n) | |