VLAarchtests / code /reveal_vla_bimanual /eval /run_ablations.py
lsnu's picture
2026-03-25 runpod handoff update
e7d8e79 verified
from __future__ import annotations
import argparse
import json
from pathlib import Path
import time
from eval.ablations import MANDATORY_ABLATIONS
from eval.report import write_comparison_report
from eval.run_reveal_benchmark import _paired_seed_summary, evaluate_model, load_model
from sim_reveal import available_proxy_names
import torch
def _parse_args() -> argparse.Namespace:
    """Parse the command-line options of the ablation sweep."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--episodes", type=int, default=24)
    parser.add_argument("--resolution", type=int, default=None)
    parser.add_argument("--output-root", default="/workspace/reports/reveal_ablation")
    parser.add_argument("--proxies", nargs="*", default=None)
    parser.add_argument("--chunk-commit-steps", type=int, default=0)
    parser.add_argument("--resume", action="store_true")
    return parser.parse_args()


def _write_text_atomic(path: Path, text: str) -> None:
    """Write *text* to *path* so that readers never observe a half-written file.

    The content goes to a sibling temporary file first and is then moved over
    *path* with ``Path.replace``. An interruption while writing therefore
    leaves the previous version of *path* intact.
    """
    tmp_path = path.with_name(path.name + ".tmp")
    tmp_path.write_text(text, encoding="utf-8")
    tmp_path.replace(path)


def _raw_entry(metrics, episode_records) -> dict:
    """Build the ``raw`` record of one run: metric attributes plus episode records."""
    return {
        "per_task_success": metrics.per_task_success,
        "mean_success": metrics.mean_success,
        "visibility_integral": metrics.visibility_integral,
        "corridor_availability": metrics.corridor_availability,
        "reocclusion_rate": metrics.reocclusion_rate,
        "persistence_horizon_mae": metrics.persistence_horizon_mae,
        "disturbance_cost": metrics.disturbance_cost,
        "episode_records": episode_records,
    }


def _section_entry(metrics, commit_steps: int) -> dict:
    """Build the report section of one run; falsy metric values become ``0.0``."""
    return {
        "mean_success": metrics.mean_success,
        "visibility_integral": metrics.visibility_integral or 0.0,
        "corridor_availability": metrics.corridor_availability or 0.0,
        "reocclusion_rate": metrics.reocclusion_rate or 0.0,
        "persistence_horizon_mae": metrics.persistence_horizon_mae or 0.0,
        "disturbance_cost": metrics.disturbance_cost or 0.0,
        "chunk_commit_steps": float(commit_steps),
    }


def main() -> None:
    """Evaluate the full model and every mandatory ablation, then write reports.

    Runs ``evaluate_model`` once with ``ablation=None`` (labelled
    ``"full_model"``) and once per entry of ``MANDATORY_ABLATIONS``. After each
    run the accumulated results are checkpointed to ``ablations.partial.json``
    under ``--output-root``. With ``--resume``, labels already present in that
    file are skipped. At the end ``ablations.json`` and ``ablations.md`` are
    written to the same directory.

    Raises:
        SystemExit: if ``--resume`` is given and the partial file was produced
            with a different ``episodes`` or ``chunk_commit_steps`` value.
    """
    args = _parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, checkpoint = load_model(args.checkpoint, device=device)
    resolution = int(args.resolution or checkpoint.get("data_resolution", 96))
    proxies = list(args.proxies or available_proxy_names())

    output_root = Path(args.output_root)
    output_root.mkdir(parents=True, exist_ok=True)
    json_path = output_root / "ablations.json"
    partial_path = output_root / "ablations.partial.json"

    # A non-positive --chunk-commit-steps is passed to evaluate_model as None;
    # the reports record it as 0.
    chunk_commit_steps = None if args.chunk_commit_steps <= 0 else args.chunk_commit_steps
    commit_steps_value = 0 if chunk_commit_steps is None else chunk_commit_steps

    sections = {}
    raw = {}
    full_episode_records: list[dict[str, float | int | str]] | None = None
    completed_labels: set[str] = set()
    prior_elapsed = 0.0

    if args.resume and partial_path.exists():
        partial = json.loads(partial_path.read_text(encoding="utf-8"))

        # Refuse to merge results produced under different settings: the paired
        # comparison against the full model would pair unrelated records.
        requested = {"episodes": args.episodes, "chunk_commit_steps": commit_steps_value}
        mismatched = {
            key: {"partial": partial[key], "requested": value}
            for key, value in requested.items()
            if key in partial and partial[key] != value
        }
        if mismatched:
            raise SystemExit(
                f"Cannot resume from {partial_path}: settings differ from the partial run: "
                + json.dumps(mismatched)
            )

        raw = partial.get("raw", {})
        sections = partial.get("sections", {})
        completed_labels = set(raw)
        full_episode_records = raw.get("full_model", {}).get("episode_records")
        # Carry over time already spent so elapsed_seconds covers every session.
        prior_elapsed = float(partial.get("elapsed_seconds") or 0.0)
        print(json.dumps({"resume_from": str(partial_path), "completed": sorted(completed_labels)}, indent=2))

    # None stands for the unablated model and always runs first, so its episode
    # records are available for the paired comparison of every ablation.
    ablations = (None, *MANDATORY_ABLATIONS)
    start_time = time.monotonic()

    for index, ablation in enumerate(ablations, start=1):
        label = "full_model" if ablation is None else ablation
        if label in completed_labels:
            continue
        print(json.dumps({"running": label, "index": index, "total": len(ablations)}, indent=2))

        metrics, episode_records = evaluate_model(
            model=model,
            device=device,
            proxies=proxies,
            episodes=args.episodes,
            resolution=resolution,
            ablation=ablation,
            chunk_commit_steps=chunk_commit_steps,
        )
        raw[label] = _raw_entry(metrics, episode_records)
        sections[label] = _section_entry(metrics, commit_steps_value)

        if label == "full_model":
            full_episode_records = episode_records
        elif full_episode_records is not None:
            paired = _paired_seed_summary(full_episode_records, episode_records)
            raw[label]["paired_seed_summary_vs_full_model"] = paired
            for key, value in paired.items():
                sections[label][f"paired_{key}_vs_full_model"] = value

        # Checkpoint after every run so --resume can skip completed labels.
        _write_text_atomic(
            partial_path,
            json.dumps(
                {
                    "checkpoint": args.checkpoint,
                    "episodes": args.episodes,
                    "chunk_commit_steps": commit_steps_value,
                    "sections": sections,
                    "raw": raw,
                    "elapsed_seconds": prior_elapsed + (time.monotonic() - start_time),
                },
                indent=2,
            ),
        )

    json_path.write_text(json.dumps(raw, indent=2), encoding="utf-8")
    write_comparison_report(output_root / "ablations.md", "Reveal Ablations", sections)
    print(json.dumps({"output_json": str(json_path), "sections": sections}, indent=2))
# Script entry point: run the ablation sweep only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()