| from __future__ import annotations |
|
|
| import argparse |
| import json |
| from pathlib import Path |
| import time |
|
|
| from eval.ablations import MANDATORY_ABLATIONS |
| from eval.report import write_comparison_report |
| from eval.run_reveal_benchmark import _paired_seed_summary, evaluate_model, load_model |
| from sim_reveal import available_proxy_names |
|
|
| import torch |
|
|
|
|
def _write_json_atomic(path: Path, payload: object) -> None:
    """Write *payload* as indented UTF-8 JSON to *path* without a torn file.

    The JSON is first written to a sibling ``<name>.tmp`` file and then moved
    over *path* with ``Path.replace``. An interrupted run therefore leaves
    either the previous file or the new one, never a half-written document
    that ``--resume`` would fail to parse.
    """
    tmp_path = path.with_name(path.name + ".tmp")
    tmp_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
    tmp_path.replace(path)


def main() -> None:
    """Evaluate the full model plus every mandatory ablation and write reports.

    After each label ("full_model" or an ablation name) is evaluated, all
    results so far are checkpointed to ``<output-root>/ablations.partial.json``.
    ``--resume`` then skips labels already present there. When the loop
    finishes, the raw metrics are written to ``ablations.json`` and a markdown
    comparison to ``ablations.md``.

    With ``--resume``, the run aborts via ``parser.error`` if the partial file
    recorded a checkpoint, episode count, chunk-commit setting, resolution or
    proxy list different from the current invocation. Mixing such runs would
    make the paired comparisons against ``full_model`` meaningless.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--episodes", type=int, default=24)
    parser.add_argument("--resolution", type=int, default=None)
    parser.add_argument("--output-root", default="/workspace/reports/reveal_ablation")
    parser.add_argument("--proxies", nargs="*", default=None)
    parser.add_argument("--chunk-commit-steps", type=int, default=0)
    parser.add_argument("--resume", action="store_true")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, checkpoint = load_model(args.checkpoint, device=device)
    resolution = int(args.resolution or checkpoint.get("data_resolution", 96))
    proxies = list(args.proxies or available_proxy_names())
    output_root = Path(args.output_root)
    output_root.mkdir(parents=True, exist_ok=True)
    # A non-positive --chunk-commit-steps is passed to evaluate_model as None.
    chunk_commit_steps = None if args.chunk_commit_steps <= 0 else args.chunk_commit_steps

    # Settings that must be identical for a previous run's results to be reusable.
    # They are also stored in the partial file so later resumes can be checked.
    run_config = {
        "checkpoint": args.checkpoint,
        "episodes": args.episodes,
        "chunk_commit_steps": 0 if chunk_commit_steps is None else chunk_commit_steps,
        "resolution": resolution,
        "proxies": [str(name) for name in proxies],
    }

    json_path = output_root / "ablations.json"
    partial_path = output_root / "ablations.partial.json"
    sections = {}
    raw = {}
    full_episode_records: list[dict[str, float | int | str]] | None = None
    completed_labels: set[str] = set()
    prior_elapsed = 0.0
    if args.resume and partial_path.exists():
        partial = json.loads(partial_path.read_text(encoding="utf-8"))
        # Compare only keys the partial file actually recorded, so partial files
        # written before resolution/proxies were stored remain resumable.
        mismatched = {
            key: {"partial": partial[key], "current": value}
            for key, value in run_config.items()
            if key in partial and partial[key] != value
        }
        if mismatched:
            parser.error(
                f"cannot resume from {partial_path}: it was produced with different "
                f"settings {json.dumps(mismatched)}; delete it or match the settings"
            )
        raw = partial.get("raw", {})
        sections = partial.get("sections", {})
        completed_labels = set(raw)
        full_episode_records = raw.get("full_model", {}).get("episode_records")
        # Carry forward time already spent so elapsed_seconds is cumulative.
        prior_elapsed = float(partial.get("elapsed_seconds", 0.0))
        print(json.dumps({"resume_from": str(partial_path), "completed": sorted(completed_labels)}, indent=2))

    # None stands for the un-ablated full model and is evaluated first, so its
    # episode records are available for the paired comparisons that follow.
    ablations = (None, *MANDATORY_ABLATIONS)
    start_time = time.monotonic()
    for index, ablation in enumerate(ablations, start=1):
        label = "full_model" if ablation is None else ablation
        if label in completed_labels:
            continue
        print(json.dumps({"running": label, "index": index, "total": len(ablations)}, indent=2))
        metrics, episode_records = evaluate_model(
            model=model,
            device=device,
            proxies=proxies,
            episodes=args.episodes,
            resolution=resolution,
            ablation=ablation,
            chunk_commit_steps=chunk_commit_steps,
        )
        raw[label] = {
            "per_task_success": metrics.per_task_success,
            "mean_success": metrics.mean_success,
            "visibility_integral": metrics.visibility_integral,
            "corridor_availability": metrics.corridor_availability,
            "reocclusion_rate": metrics.reocclusion_rate,
            "persistence_horizon_mae": metrics.persistence_horizon_mae,
            "disturbance_cost": metrics.disturbance_cost,
            "episode_records": episode_records,
        }
        # The report table gets 0.0 in place of falsy (e.g. None) metrics.
        sections[label] = {
            "mean_success": metrics.mean_success,
            "visibility_integral": metrics.visibility_integral or 0.0,
            "corridor_availability": metrics.corridor_availability or 0.0,
            "reocclusion_rate": metrics.reocclusion_rate or 0.0,
            "persistence_horizon_mae": metrics.persistence_horizon_mae or 0.0,
            "disturbance_cost": metrics.disturbance_cost or 0.0,
            "chunk_commit_steps": float(0 if chunk_commit_steps is None else chunk_commit_steps),
        }
        if label == "full_model":
            full_episode_records = episode_records
        elif full_episode_records is not None:
            paired = _paired_seed_summary(full_episode_records, episode_records)
            raw[label]["paired_seed_summary_vs_full_model"] = paired
            for key, value in paired.items():
                sections[label][f"paired_{key}_vs_full_model"] = value
        # Checkpoint after every label; the atomic write keeps the file parseable
        # even if the process dies mid-write.
        _write_json_atomic(
            partial_path,
            {
                **run_config,
                "sections": sections,
                "raw": raw,
                "elapsed_seconds": prior_elapsed + (time.monotonic() - start_time),
            },
        )

    _write_json_atomic(json_path, raw)
    write_comparison_report(output_root / "ablations.md", "Reveal Ablations", sections)
    print(json.dumps({"output_json": str(json_path), "sections": sections}, indent=2))
|
|
|
|
if __name__ == "__main__":
    # Run the ablation sweep only when executed as a script, not when imported.
    main()
|
|