| |
| """End-to-end demonstration: eval an untrained model, train briefly, eval again, |
| and show the metrics improve. Also writes the §4 protocol-template report and an |
| extrapolation comparison image. Synthetic data only (sanity, not research scale).""" |
|
|
import argparse
import contextlib
import time
from pathlib import Path

import torch

from mapgs.config import load_config
from mapgs.data import SyntheticDataset, collate_samples
from mapgs.eval import Evaluator, write_report
from mapgs.train import Trainer
|
|
|
|
def fmt(d, *, ndigits=3):
    """Return a copy of *d* with its float values rounded for console output.

    Args:
        d: mapping of metric name -> value.
        ndigits: decimal places kept for float values (default 3).

    Returns:
        A new dict with the same keys. Float values are rounded with
        ``round(v, ndigits)``; every non-float value is passed through unchanged.
    """
    return {k: (round(v, ndigits) if isinstance(v, float) else v) for k, v in d.items()}
|
|
|
|
def _parse_args(argv=None):
    """Parse the demo's CLI options; defaults reproduce the original hard-coded run."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--iters", type=int, default=400)
    ap.add_argument("--n-scenes", type=int, default=24)
    ap.add_argument("--out", type=str, default="/mnt/william/eval_report.md")
    ap.add_argument("--config", type=str, default="/mnt/william/configs/synthetic_smoke.yaml")
    ap.add_argument("--data-root", type=str, default="/mnt/william/data/synthetic")
    ap.add_argument("--ckpt", type=str, default="/mnt/william/runs/mapgs_synth/ckpt_demo.pt")
    ap.add_argument("--device", type=str, default="cuda")
    return ap.parse_args(argv)


@contextlib.contextmanager
def _eval_mode(model):
    """Temporarily switch *model* to eval mode, restoring every submodule's flag on exit.

    Assumes *model* is a ``torch.nn.Module`` (it exposes ``.eval()`` and
    ``.modules()``). Per-module flags are restored rather than calling
    ``model.train()``, so any submodule deliberately kept in eval mode stays that way.
    """
    saved = [(module, module.training) for module in model.modules()]
    model.eval()
    try:
        yield model
    finally:
        for module, was_training in saved:
            module.training = was_training


def _print_improvement(before, after):
    """Print the before -> after table for the interpolation metrics."""
    rows = (
        ("PSNR ", "PSNR", ".2f"),
        ("SSIM ", "SSIM", ".3f"),
        ("LPIPS", "LPIPS", ".3f"),
        ("D-RMSE(m)", "D-RMSE", ".3f"),
    )
    for label, key, spec in rows:
        print(f" {label}: {before[key]:{spec}} -> {after[key]:{spec}}")


def main(argv=None):
    """Evaluate at init, train briefly, evaluate again, then write the report and a checkpoint.

    Args:
        argv: optional argument list for argparse. ``None`` (the default) reads
            ``sys.argv``, so ``main()`` behaves as a plain script entry point.
    """
    args = _parse_args(argv)

    cfg = load_config(args.config, [
        f"data.root={args.data_root}", "train.batch_size=4",
        f"train.iters={args.iters}", "train.log_every=50", "train.ckpt_every=0",
        "train.extrap_ramp_iter=100",
    ])
    print("generating data ...")
    train_ds = SyntheticDataset(cfg, "train", n_scenes=args.n_scenes, n_sup_views=6, device=args.device)
    val_ds = SyntheticDataset(cfg, "val", n_scenes=8, n_sup_views=6, device=args.device)

    trainer = Trainer(cfg)
    ev = Evaluator(trainer.model, cfg, device=args.device)
    # Frame index passed to the extrapolation and lane evaluations below.
    mid_frame = cfg.data.num_frames // 2

    print("\n=== BEFORE training (random init) ===")
    # Run in eval mode, like the AFTER pass, so the comparison is like-for-like.
    # The model's prior train/eval flags are restored before training starts.
    with _eval_mode(trainer.model):
        before = ev.interpolation(val_ds, max_scenes=8)
        before_ext = ev.extrapolation_sweep(val_ds, shifts=[2.0], max_scenes=6, frame=mid_frame)
    print("interp:", fmt(before), "| extrap@2m:", fmt(before_ext[2.0]))

    print(f"\n=== training {args.iters} iters ===")
    t0 = time.perf_counter()  # monotonic clock, unaffected by system time changes
    trainer.fit(train_ds, max_iters=args.iters)
    print(f"trained in {time.perf_counter() - t0:.0f}s")

    print("\n=== AFTER training ===")
    with _eval_mode(trainer.model):
        after = ev.interpolation(val_ds, max_scenes=8)
        after_ext = ev.extrapolation_sweep(val_ds, shifts=list(cfg.eval.lateral_shifts), max_scenes=6,
                                           frame=mid_frame)
        lane = ev.lane_consistency(val_ds, max_scenes=6, frame=mid_frame)
    print("interp:", fmt(after))
    for sh, m in after_ext.items():
        print(f" extrap@{sh}m:", fmt(m))
    print("lane:", fmt(lane))

    print("\n=== IMPROVEMENT (interpolation) ===")
    _print_improvement(before, after)

    # Create output directories up front so a fresh machine does not fail at the very end.
    Path(args.out).parent.mkdir(parents=True, exist_ok=True)
    write_report(args.out, interpolation={"Synthetic(val)": after}, extrapolation=after_ext, lane=lane,
                 header=f"_Demo run: {args.iters} iters, {args.n_scenes} synthetic scenes, "
                        f"config {Path(args.config).stem}. Numbers are sanity-scale, not research results._")
    Path(args.ckpt).parent.mkdir(parents=True, exist_ok=True)
    trainer.save(args.ckpt)
    print(f"\nreport -> {args.out}")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|