# fusion-design-lab/baselines/random_agent.py
# CreativeEngineer: "refactor: align p1 runtime contract and baseline reporting" (6deaccc)
"""Random baseline agent for the stellarator design environment."""
from __future__ import annotations
import random
import sys
from fusion_lab.models import StellaratorAction
from server.environment import StellaratorEnvironment
# Choices the random agent samples from when it builds a "run" action.

# Candidate values for the action's ``parameter`` field.
PARAMETERS = [
    "aspect_ratio",
    "elongation",
    "rotational_transform",
    "triangularity_scale",
]

# Candidate values for the action's ``direction`` field.
DIRECTIONS = [
    "increase",
    "decrease",
]

# Candidate values for the action's ``magnitude`` field.
MAGNITUDES = [
    "small",
    "medium",
    "large",
]
def random_episode(
    env: StellaratorEnvironment, seed: int | None = None
) -> tuple[float, list[dict[str, object]]]:
    """Play one episode by sampling actions at random.

    ``seed`` is used both for a private ``random.Random`` instance and for
    ``env.reset``. While the observation is not done, the agent sends a
    "submit" action once ``budget_remaining`` is 1 or less; otherwise it sends
    a "run" action whose parameter, direction and magnitude are drawn (in that
    order) from ``PARAMETERS``, ``DIRECTIONS`` and ``MAGNITUDES``.

    Args:
        env: Environment to reset and step.
        seed: Seed shared by the local RNG and the environment reset.

    Returns:
        A ``(total_reward, trace)`` pair. ``trace[0]`` describes the
        observation returned by ``reset``; each later entry describes one
        step, including the action intent and the reward reported for it.
    """
    sampler = random.Random(seed)
    observation = env.reset(seed=seed)

    # Entry 0 records the state right after reset, before any action.
    history: list[dict[str, object]] = []
    history.append(
        {
            "step": 0,
            "score": observation.p1_score,
            "evaluation_fidelity": observation.evaluation_fidelity,
            "constraints_satisfied": observation.constraints_satisfied,
        }
    )

    cumulative = 0.0
    while True:
        if observation.done:
            return cumulative, history

        if observation.budget_remaining > 1:
            move = StellaratorAction(
                intent="run",
                parameter=sampler.choice(PARAMETERS),
                direction=sampler.choice(DIRECTIONS),
                magnitude=sampler.choice(MAGNITUDES),
            )
        else:
            # Budget is down to 1 or less: submit instead of exploring.
            move = StellaratorAction(intent="submit")

        observation = env.step(move)
        # A falsy reward (e.g. None) contributes nothing to the running total.
        cumulative += observation.reward or 0.0

        # The history already holds one entry per earlier step plus the
        # reset snapshot, so its length is this step's 1-based index.
        step_number = len(history)
        history.append(
            {
                "step": step_number,
                "action": move.intent,
                "score": observation.p1_score,
                "evaluation_fidelity": observation.evaluation_fidelity,
                "constraints_satisfied": observation.constraints_satisfied,
                "evaluation_failed": observation.evaluation_failed,
                "reward": observation.reward,
            }
        )
def main(n_episodes: int = 20) -> None:
    """Run the random baseline and print a per-episode and summary report.

    Episode ``i`` is played with ``seed=i`` on a single shared environment.
    For each episode one line is printed with the step count, the final
    score, fidelity, constraint status and total reward; a final line reports
    the mean total reward over all episodes.

    Args:
        n_episodes: Number of episodes to run. When zero or negative, a short
            notice is printed and nothing is run.
    """
    if n_episodes <= 0:
        # With no episodes the reward list stays empty and the mean below
        # would raise ZeroDivisionError, so report and stop early.
        print("Random baseline: no episodes requested, nothing to run.")
        return

    env = StellaratorEnvironment()
    rewards: list[float] = []
    for i in range(n_episodes):
        total_reward, trace = random_episode(env, seed=i)
        # The last trace entry describes the episode's final observation.
        final = trace[-1]
        rewards.append(total_reward)
        print(
            # trace[0] is the reset snapshot, so steps taken = len(trace) - 1.
            f"Episode {i:3d}: steps={len(trace) - 1} "
            f"final_score={final['score']:.6f} fidelity={final['evaluation_fidelity']} "
            f"constraints={'yes' if final['constraints_satisfied'] else 'no'} "
            f"reward={total_reward:+.4f}"
        )

    mean_reward = sum(rewards) / len(rewards)
    print(f"\nRandom baseline ({n_episodes} episodes): mean_reward={mean_reward:+.4f}")
if __name__ == "__main__":
    # An optional first command-line argument overrides the default of 20
    # episodes.
    cli_args = sys.argv[1:]
    episode_count = int(cli_args[0]) if cli_args else 20
    main(episode_count)