Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
File size: 2,588 Bytes
"""Run both baselines and print a comparison summary."""
from __future__ import annotations
import sys
from baselines.heuristic_agent import heuristic_episode
from baselines.random_agent import random_episode
from server.environment import StellaratorEnvironment
def main(n_episodes: int = 20) -> None:
    """Run the random and heuristic baselines and print a comparison table.

    For each seed in ``range(n_episodes)`` one random episode and one
    heuristic episode are run against the same environment instance.  Each
    episode's total reward is recorded, along with the ``"score"`` and
    ``"constraints_satisfied"`` entries of its final trajectory step.  A
    summary table and the per-seed heuristic win count are printed to stdout.

    Args:
        n_episodes: Number of seeds to evaluate for each baseline; must be
            at least 1.

    Raises:
        ValueError: If ``n_episodes`` is less than 1, or if either baseline's
            final step fails the ``_require_successful_submit`` check.
    """
    # Every mean below divides by the episode count; reject an empty run up
    # front instead of failing later with an opaque ZeroDivisionError.
    if n_episodes < 1:
        raise ValueError(f"n_episodes must be at least 1, got {n_episodes}.")

    env = StellaratorEnvironment()

    random_rewards: list[float] = []
    heuristic_rewards: list[float] = []
    random_final_scores: list[float] = []
    heuristic_final_scores: list[float] = []
    random_feasible: list[int] = []
    heuristic_feasible: list[int] = []

    for i in range(n_episodes):
        # Both baselines get the same seed so the per-episode comparison
        # (the "wins" count below) pairs like with like.
        rr, rt = random_episode(env, seed=i)
        _require_successful_submit(rt[-1], baseline_name="random")
        random_rewards.append(rr)
        random_final_scores.append(rt[-1]["score"])
        random_feasible.append(1 if rt[-1]["constraints_satisfied"] else 0)

        hr, ht = heuristic_episode(env, seed=i)
        _require_successful_submit(ht[-1], baseline_name="heuristic")
        heuristic_rewards.append(hr)
        heuristic_final_scores.append(ht[-1]["score"])
        heuristic_feasible.append(1 if ht[-1]["constraints_satisfied"] else 0)

    r_mean = sum(random_rewards) / len(random_rewards)
    h_mean = sum(heuristic_rewards) / len(heuristic_rewards)
    r_score = sum(random_final_scores) / len(random_final_scores)
    h_score = sum(heuristic_final_scores) / len(heuristic_final_scores)
    r_feasible = sum(random_feasible)
    h_feasible = sum(heuristic_feasible)

    print(f"{'Metric':<25} {'Random':>12} {'Heuristic':>12}")
    print("-" * 51)
    print(f"{'Mean reward':<25} {r_mean:>+12.4f} {h_mean:>+12.4f}")
    print(f"{'Mean final P1 score':<25} {r_score:>12.6f} {h_score:>12.6f}")
    print(f"{'Feasible finals':<25} {r_feasible:>12d} {h_feasible:>12d}")
    print(f"{'Episodes':<25} {n_episodes:>12d} {n_episodes:>12d}")
    print()

    # A "win" is a seed on which the heuristic's total reward strictly
    # exceeds the random baseline's; ties do not count.
    wins = sum(1 for h, r in zip(heuristic_rewards, random_rewards) if h > r)
    print(f"Heuristic wins: {wins}/{n_episodes} episodes ({100 * wins / n_episodes:.0f}%)")
def _require_successful_submit(final_step: dict[str, object], *, baseline_name: str) -> None:
action = final_step.get("action")
if action != "submit":
raise ValueError(
f"{baseline_name} baseline ended on {action!r} instead of an explicit submit."
)
if bool(final_step.get("evaluation_failed")):
raise ValueError(f"{baseline_name} baseline submit ended in evaluation failure.")
if __name__ == "__main__":
    # An optional first command-line argument overrides the default of 20
    # episodes; any further arguments are ignored.
    cli_args = sys.argv[1:]
    episode_count = int(cli_args[0]) if cli_args else 20
    main(episode_count)
|