"""Validation utilities for high-fidelity fixture pairing and submit-side traces."""
from __future__ import annotations

import argparse
import json
from dataclasses import asdict, dataclass
from datetime import UTC, datetime
from pathlib import Path
from pprint import pformat
from time import perf_counter
from typing import Any

from fusion_lab.models import LowDimBoundaryParams, StellaratorAction
from server.contract import N_FIELD_PERIODS
from server.environment import StellaratorEnvironment
from server.physics import EvaluationMetrics, build_boundary_from_params, evaluate_boundary
LOW_FIDELITY_TOLERANCE = 1.0e-6  # max absolute drift allowed for stored low-fidelity numeric fields


def _float(value: Any) -> float | None:
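    """Best-effort float coercion; returns None for bools and non-numeric values."""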
if isinstance(value, bool):
return None
try:
return float(value)
except (TypeError, ValueError):
        return None


@dataclass(frozen=True)
class FixturePairResult:
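    """Low- and high-fidelity evaluation results and comparison for one fixture file."""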
name: str
file: str
status: str
low_fidelity: dict[str, Any]
high_fidelity: dict[str, Any]
    comparison: dict[str, Any]


@dataclass(frozen=True)
class TraceStep:
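    """A single recorded step of the submit-side manual trace."""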
step: int
intent: str
action: str
reward: float
score: float
feasibility: float
constraints_satisfied: bool
feasibility_delta: float | None
score_delta: float | None
max_elongation: float
p1_feasibility: float
budget_remaining: int
evaluation_fidelity: str
    done: bool


def parse_args() -> argparse.Namespace:
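    """Define and parse command-line arguments for the validation run."""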
parser = argparse.ArgumentParser(
description=(
"Run paired high-fidelity fixture checks and a submit-side manual trace "
"for the repaired P1 contract."
)
)
parser.add_argument(
"--fixture-dir",
type=Path,
default=Path("server/data/p1"),
help="Directory containing tracked P1 fixture JSON files.",
)
parser.add_argument(
"--fixture-output",
type=Path,
default=Path("baselines/fixture_high_fidelity_pairs.json"),
help="Output path for paired fixture summary JSON.",
)
parser.add_argument(
"--trace-output",
type=Path,
default=Path("baselines/submit_side_trace.json"),
help="Output path for one submit-side manual trace JSON.",
)
parser.add_argument(
"--no-write-fixture-updates",
action="store_true",
help="Do not write paired high-fidelity results back into fixture files.",
)
parser.add_argument(
"--skip-submit-trace",
action="store_true",
help="Only run paired fixture checks.",
)
parser.add_argument(
"--seed",
type=int,
default=0,
help="Seed for the submit-side manual trace reset state.",
)
parser.add_argument(
"--submit-action-sequence",
type=str,
default=(
"run:rotational_transform:increase:medium,"
"run:triangularity_scale:increase:medium,"
"run:elongation:decrease:small,"
"submit"
),
help=(
"Comma-separated submit trace sequence. "
"Run actions are `run:parameter:direction:magnitude`; include `submit` as the last token."
),
)
    return parser.parse_args()


def _fixture_files(fixture_dir: Path) -> list[Path]:
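    """Return the tracked fixture JSON files in sorted order."""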
    return sorted(path for path in fixture_dir.glob("*.json") if path.is_file())


def _load_fixture(path: Path) -> dict[str, Any]:
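    """Load a fixture JSON file into a dict."""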
    with path.open("r", encoding="utf-8") as file:
        return json.load(file)


def _metrics_payload(metrics: EvaluationMetrics) -> dict[str, Any]:
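    """Flatten an EvaluationMetrics instance into a JSON-serializable dict."""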
return {
"evaluation_failed": metrics.evaluation_failed,
"constraints_satisfied": metrics.constraints_satisfied,
"p1_score": metrics.p1_score,
"p1_feasibility": metrics.p1_feasibility,
"max_elongation": metrics.max_elongation,
"aspect_ratio": metrics.aspect_ratio,
"average_triangularity": metrics.average_triangularity,
"edge_iota_over_nfp": metrics.edge_iota_over_nfp,
"vacuum_well": metrics.vacuum_well,
"evaluation_fidelity": metrics.evaluation_fidelity,
"failure_reason": metrics.failure_reason,
    }


def _parse_submit_sequence(raw: str) -> list[StellaratorAction]:
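    """Parse a comma-separated action string into StellaratorAction objects.

    Tokens are `run:parameter:direction:magnitude` or `submit`; the sequence
    must be non-empty and end with `submit`.
    """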
actions: list[StellaratorAction] = []
for token in raw.split(","):
token = token.strip()
if not token:
continue
if token == "submit":
actions.append(StellaratorAction(intent="submit"))
continue
parts = token.split(":")
if len(parts) != 4 or parts[0] != "run":
            raise ValueError(
                f"Invalid submit-sequence token {token!r}: expected "
                "`run:parameter:direction:magnitude` or `submit`."
            )
_, parameter, direction, magnitude = parts
actions.append(
StellaratorAction(
intent="run",
parameter=parameter,
direction=direction,
magnitude=magnitude,
)
)
if not actions:
raise ValueError("submit-action-sequence must include at least one action.")
if actions[-1].intent != "submit":
raise ValueError("submit-action-sequence must end with submit.")
    return actions


def _compare_low_snapshot(
stored: dict[str, Any],
current: dict[str, Any],
) -> tuple[bool, dict[str, Any]]:
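    """Compare a stored low-fidelity snapshot against a freshly computed one.

    Numeric fields may drift by at most LOW_FIDELITY_TOLERANCE; the remaining
    fields must match exactly. Returns (ok, details).
    """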
numeric_keys = [
"p1_feasibility",
"p1_score",
"max_elongation",
"aspect_ratio",
"average_triangularity",
"edge_iota_over_nfp",
"vacuum_well",
]
exact_keys = [
"constraints_satisfied",
"evaluation_fidelity",
"evaluation_failed",
"failure_reason",
]
missing_fields: list[str] = []
drift_fields: dict[str, dict[str, float]] = {}
mismatches: list[dict[str, Any]] = []
max_abs_drift = 0.0
for key in numeric_keys:
if key not in stored:
missing_fields.append(key)
continue
expected = _float(stored.get(key))
actual = _float(current.get(key))
if expected is None or actual is None:
mismatches.append(
{
"field": key,
"expected": stored.get(key),
"actual": current.get(key),
"reason": "non-numeric",
}
)
continue
drift = abs(expected - actual)
max_abs_drift = max(max_abs_drift, drift)
if drift > LOW_FIDELITY_TOLERANCE:
drift_fields[key] = {
"expected": expected,
"actual": actual,
"abs_drift": drift,
}
mismatches.append(
{
"field": key,
"expected": expected,
"actual": actual,
"abs_drift": drift,
}
)
for key in exact_keys:
if key not in stored:
missing_fields.append(key)
continue
expected = stored.get(key)
actual = current.get(key)
if expected != actual:
mismatches.append(
{
"field": key,
"expected": expected,
"actual": actual,
"reason": "exact-mismatch",
}
)
return (
not missing_fields and not mismatches,
{
"missing_fields": missing_fields,
"drift_fields": drift_fields,
"mismatches": mismatches,
"max_abs_drift": max_abs_drift,
},
    )


def _pair_fixture(path: Path) -> FixturePairResult:
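    """Evaluate one fixture at both fidelities and compare against its stored snapshot."""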
data = _load_fixture(path)
params = LowDimBoundaryParams.model_validate(data["params"])
boundary = build_boundary_from_params(params, n_field_periods=N_FIELD_PERIODS)
low = evaluate_boundary(boundary, fidelity="low")
high = evaluate_boundary(boundary, fidelity="high")
low_payload = _metrics_payload(low)
high_payload = _metrics_payload(high)
low_snapshot_ok, low_snapshot = _compare_low_snapshot(
data.get("low_fidelity", {}),
low_payload,
)
feasible_match = low.constraints_satisfied == high.constraints_satisfied
    if low.evaluation_failed or high.evaluation_failed:
        ranking_compat = "ambiguous"
    elif feasible_match:
        ranking_compat = "match"
    else:
        ranking_compat = "mismatch"
comparison: dict[str, Any] = {
"low_high_feasibility_match": feasible_match,
"feasibility_delta": high.p1_feasibility - low.p1_feasibility,
"score_delta": high.p1_score - low.p1_score,
"ranking_compatibility": ranking_compat,
"low_fidelity_stored_p1_score": data.get("low_fidelity", {}).get("p1_score"),
"low_fidelity_stored_p1_feasibility": data.get("low_fidelity", {}).get("p1_feasibility"),
"low_fidelity_snapshot": low_snapshot,
}
status = "pass"
if low.evaluation_failed or high.evaluation_failed or not feasible_match or not low_snapshot_ok:
status = "fail"
if not low_snapshot_ok:
print(f" low-fidelity snapshot mismatch:\n{pformat(low_snapshot)}")
return FixturePairResult(
name=str(data.get("name", path.stem)),
file=str(path),
status=status,
low_fidelity=low_payload,
high_fidelity=high_payload,
comparison=comparison,
    )


def _write_json(payload: dict[str, Any], path: Path) -> None:
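    """Write *payload* as indented JSON, creating parent directories as needed."""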
path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as file:
        json.dump(payload, file, indent=2)


def _run_fixture_checks(
*,
fixture_dir: Path,
fixture_output: Path,
write_fixture_updates: bool,
) -> tuple[list[FixturePairResult], int]:
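    """Pair every fixture at low and high fidelity and write a summary JSON.

    Returns the per-fixture results and the number of failures.
    """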
results: list[FixturePairResult] = []
fail_count = 0
for path in _fixture_files(fixture_dir):
print(f"Evaluating fixture: {path.name}")
fixture_start = perf_counter()
result = _pair_fixture(path)
if result.status != "pass":
fail_count += 1
results.append(result)
if write_fixture_updates:
fixture = _load_fixture(path)
fixture["high_fidelity"] = result.high_fidelity
fixture["paired_high_fidelity_timestamp_utc"] = datetime.now(tz=UTC).isoformat()
            with path.open("w", encoding="utf-8") as file:
json.dump(fixture, file, indent=2)
elapsed = perf_counter() - fixture_start
print(
" done in "
f"{elapsed:0.1f}s | low_feasible={result.low_fidelity['constraints_satisfied']} "
f"| high_feasible={result.high_fidelity['constraints_satisfied']} "
f"| status={result.status}"
)
pass_count = len(results) - fail_count
payload = {
"timestamp_utc": datetime.now(tz=UTC).isoformat(),
"n_field_periods": N_FIELD_PERIODS,
"fixture_count": len(results),
"pass_count": pass_count,
"fail_count": fail_count,
"results": [asdict(result) for result in results],
}
_write_json(payload, fixture_output)
    return results, fail_count


def _run_submit_trace(
trace_output: Path,
*,
seed: int,
action_sequence: str,
) -> dict[str, Any]:
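    """Replay a manual action sequence through the environment and record a step trace."""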
env = StellaratorEnvironment()
obs = env.reset(seed=seed)
reset_params = env.state.current_params.model_dump()
actions = _parse_submit_sequence(action_sequence)
    # Step 0 records the reset observation; keys mirror the TraceStep field
    # order so every step in the JSON output reads the same way.
    trace: list[dict[str, Any]] = [
        {
            "step": 0,
            "intent": "reset",
            "action": f"reset(seed={seed})",
            "reward": 0.0,
            "score": obs.p1_score,
            "feasibility": obs.p1_feasibility,
            "constraints_satisfied": obs.constraints_satisfied,
            "feasibility_delta": None,
            "score_delta": None,
            "max_elongation": obs.max_elongation,
            "p1_feasibility": obs.p1_feasibility,
            "budget_remaining": obs.budget_remaining,
            "evaluation_fidelity": obs.evaluation_fidelity,
            "done": obs.done,
            "params": reset_params,
        }
    ]
previous_feasibility = obs.p1_feasibility
previous_score = obs.p1_score
for idx, action in enumerate(actions, start=1):
obs = env.step(action)
trace.append(
asdict(
TraceStep(
step=idx,
intent=action.intent,
action=(
f"{action.parameter} {action.direction} {action.magnitude}"
if action.intent == "run"
else action.intent
),
reward=float(obs.reward or 0.0),
score=obs.p1_score,
feasibility=obs.p1_feasibility,
constraints_satisfied=obs.constraints_satisfied,
feasibility_delta=obs.p1_feasibility - previous_feasibility,
score_delta=obs.p1_score - previous_score,
max_elongation=obs.max_elongation,
p1_feasibility=obs.p1_feasibility,
budget_remaining=obs.budget_remaining,
evaluation_fidelity=obs.evaluation_fidelity,
done=obs.done,
)
)
)
previous_feasibility = obs.p1_feasibility
previous_score = obs.p1_score
if obs.done:
break
total_reward = sum(step["reward"] for step in trace)
payload = {
"trace_label": "submit_side_manual",
"trace_profile": action_sequence,
"timestamp_utc": datetime.now(tz=UTC).isoformat(),
"n_field_periods": N_FIELD_PERIODS,
"seed": seed,
"total_reward": total_reward,
"final_score": obs.p1_score,
"final_feasibility": obs.p1_feasibility,
"final_constraints_satisfied": obs.constraints_satisfied,
"final_evaluation_fidelity": obs.evaluation_fidelity,
"final_evaluation_failed": obs.evaluation_failed,
"steps": trace,
"final_best_low_fidelity_score": obs.best_low_fidelity_score,
"final_best_low_fidelity_feasibility": obs.best_low_fidelity_feasibility,
"final_diagnostics_text": obs.diagnostics_text,
}
_write_json(payload, trace_output)
    return payload


def main() -> int:
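    """Run paired fixture checks and (optionally) the submit trace; return an exit code."""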
args = parse_args()
results, fail_count = _run_fixture_checks(
fixture_dir=args.fixture_dir,
fixture_output=args.fixture_output,
write_fixture_updates=not args.no_write_fixture_updates,
)
print(
f"Paired fixtures: {len(results)} total, {len(results) - fail_count} pass, {fail_count} fail"
)
for result in results:
print(
f" - {result.name}: {result.status} "
f"(low={result.low_fidelity['constraints_satisfied']} "
f"high={result.high_fidelity['constraints_satisfied']})"
)
if not args.skip_submit_trace:
trace = _run_submit_trace(
args.trace_output,
seed=args.seed,
action_sequence=args.submit_action_sequence,
)
print(
f"Manual submit trace written to {args.trace_output} | "
f"sequence='{trace['trace_profile']}' | "
f"final_feasibility={trace['final_feasibility']:.6f} | "
f"fidelity={trace['final_evaluation_fidelity']}"
)
    return 1 if fail_count else 0


if __name__ == "__main__":
raise SystemExit(main())