# Source: mapvggt/scripts/demo_train_eval.py
# Uploaded by ChenmingWu via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 8cf92b3 (verified), 3.15 kB.
#!/usr/bin/env python3
"""End-to-end demonstration: eval an untrained model, train briefly, eval again,
and show the metrics improve. Also writes the §4 protocol-template report and an
extrapolation comparison image. Synthetic data only (sanity, not research scale)."""
import argparse
import contextlib
import time
from pathlib import Path

import torch

from mapgs.config import load_config
from mapgs.data import SyntheticDataset, collate_samples
from mapgs.eval import Evaluator, write_report
from mapgs.train import Trainer
def fmt(d):
    """Return a new dict like *d* with every float value rounded to 3 decimals.

    Values that are not ``float`` instances (ints, bools, strings, nested
    objects, ...) are copied through untouched. Key order is preserved, so the
    result prints compactly and in the same order as the input.
    """
    rounded = {}
    for key, value in d.items():
        if isinstance(value, float):
            rounded[key] = round(value, 3)
        else:
            rounded[key] = value
    return rounded
def _parse_args():
    """Parse command-line options.

    Every default reproduces the setup the script previously hard-coded, so
    running it without arguments behaves as before.
    """
    ap = argparse.ArgumentParser(description=__doc__)
    ap.add_argument("--iters", type=int, default=400)
    ap.add_argument("--n-scenes", type=int, default=24)
    ap.add_argument("--out", type=str, default="/mnt/william/eval_report.md")
    ap.add_argument("--config", type=str, default="/mnt/william/configs/synthetic_smoke.yaml",
                    help="base config file passed to load_config")
    ap.add_argument("--data-root", type=str, default="/mnt/william/data/synthetic",
                    help="value used for the data.root config override")
    ap.add_argument("--ckpt", type=str, default="/mnt/william/runs/mapgs_synth/ckpt_demo.pt",
                    help="path the trained checkpoint is saved to")
    return ap.parse_args()


def _ensure_parent_dir(path):
    """Create the parent directory of *path* (and any missing ancestors) if needed."""
    Path(path).parent.mkdir(parents=True, exist_ok=True)


@contextlib.contextmanager
def _eval_mode(module):
    """Temporarily switch *module* to eval mode, then restore every submodule's flag.

    Assumes *module* is a ``torch.nn.Module`` (it must provide ``.modules()``,
    ``.eval()`` and per-module ``.training``). Each submodule's own ``training``
    flag is restored, not just the root's, so any deliberately frozen
    sub-networks keep the mode they had before.
    """
    saved = [(m, m.training) for m in module.modules()]
    module.eval()
    try:
        yield module
    finally:
        for m, was_training in saved:
            m.training = was_training


def _print_improvement(before, after):
    """Print ``before -> after`` for the headline interpolation metrics, one per line."""
    # (label, metric key, decimals); labels are padded so the colons line up.
    rows = (
        ("PSNR ", "PSNR", 2),
        ("SSIM ", "SSIM", 3),
        ("LPIPS", "LPIPS", 3),
        ("D-RMSE(m)", "D-RMSE", 3),
    )
    for label, key, ndigits in rows:
        print(f" {label}: {before[key]:.{ndigits}f} -> {after[key]:.{ndigits}f}")


def main():
    """Evaluate an untrained model, train it briefly, evaluate again, and report.

    Side effects: generates synthetic data, saves a checkpoint to ``--ckpt`` and
    writes a markdown report to ``--out``. Parent directories for both are
    created if missing. Runs on CUDA.
    """
    args = _parse_args()
    cfg = load_config(args.config, [
        f"data.root={args.data_root}", "train.batch_size=4",
        f"train.iters={args.iters}", "train.log_every=50", "train.ckpt_every=0",
        "train.extrap_ramp_iter=100",
    ])

    print("generating data ...")
    train_ds = SyntheticDataset(cfg, "train", n_scenes=args.n_scenes, n_sup_views=6, device="cuda")
    val_ds = SyntheticDataset(cfg, "val", n_scenes=8, n_sup_views=6, device="cuda")
    trainer = Trainer(cfg)
    ev = Evaluator(trainer.model, cfg, device="cuda")
    mid_frame = cfg.data.num_frames // 2  # frame used for extrapolation/lane evaluation

    print("\n=== BEFORE training (random init) ===")
    # Evaluate the baseline in eval mode, like the post-training pass below.
    # Otherwise train-mode layers (e.g. dropout/batch-norm) make before/after
    # incomparable. The prior mode is restored before fit().
    with _eval_mode(trainer.model):
        before = ev.interpolation(val_ds, max_scenes=8)
        before_ext = ev.extrapolation_sweep(val_ds, shifts=[2.0], max_scenes=6, frame=mid_frame)
    print("interp:", fmt(before), "| extrap@2m:", fmt(before_ext[2.0]))

    print(f"\n=== training {args.iters} iters ===")
    t0 = time.perf_counter()
    trainer.fit(train_ds, max_iters=args.iters)
    print(f"trained in {time.perf_counter() - t0:.0f}s")
    # Persist the weights before the evaluation/report steps, so a failure
    # there does not throw away the training run.
    _ensure_parent_dir(args.ckpt)
    trainer.save(args.ckpt)
    print(f"checkpoint -> {args.ckpt}")

    print("\n=== AFTER training ===")
    trainer.model.eval()
    after = ev.interpolation(val_ds, max_scenes=8)
    after_ext = ev.extrapolation_sweep(val_ds, shifts=list(cfg.eval.lateral_shifts), max_scenes=6,
                                       frame=mid_frame)
    lane = ev.lane_consistency(val_ds, max_scenes=6, frame=mid_frame)
    print("interp:", fmt(after))
    for sh, m in after_ext.items():
        print(f" extrap@{sh}m:", fmt(m))
    print("lane:", fmt(lane))

    print("\n=== IMPROVEMENT (interpolation) ===")
    _print_improvement(before, after)

    _ensure_parent_dir(args.out)
    write_report(args.out, interpolation={"Synthetic(val)": after}, extrapolation=after_ext, lane=lane,
                 header=f"_Demo run: {args.iters} iters, {args.n_scenes} synthetic scenes, "
                        f"config {Path(args.config).stem}. Numbers are sanity-scale, not research results._")
    print(f"\nreport -> {args.out}")


if __name__ == "__main__":
    main()