File size: 1,803 Bytes
8cf92b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#!/usr/bin/env python3
"""Evaluate MapGS and write the §4 protocol-template report.

    python scripts/eval.py --config configs/synthetic_smoke.yaml --ckpt runs/mapgs_synth/ckpt_final.pt
"""

import argparse

import torch

from mapgs.config import load_config
from mapgs.data import build_dataset
from mapgs.eval import Evaluator, write_report
from mapgs.model import MapGS


def main():
    """Evaluate a MapGS model and write the protocol-template report.

    Builds the model from ``--config`` (plus ``--override`` entries), optionally
    loads weights from ``--ckpt``, then runs three evaluations on ``--split``:
    interpolation, an extrapolation shift sweep, and lane consistency. Each
    result is printed and all of them are passed to ``write_report`` at ``--out``.
    """
    import sys  # local import: only used for stderr warnings

    ap = argparse.ArgumentParser()
    ap.add_argument("--config", type=str, default=None)
    ap.add_argument("--override", nargs="*", default=[])
    ap.add_argument("--ckpt", type=str, default=None)
    ap.add_argument("--split", type=str, default="val")
    ap.add_argument("--out", type=str, default="eval_report.md")
    ap.add_argument("--max-scenes", type=int, default=50)
    ap.add_argument(
        "--max-sweep-scenes",
        type=int,
        default=30,
        help="cap on scenes for the extrapolation sweep and lane metrics "
        "(effective count is min(--max-scenes, this))",
    )
    ap.add_argument("--tt", action="store_true", help="apply test-time token tuning")
    args = ap.parse_args()

    if args.tt:
        # Nothing below consumes args.tt. Warn instead of silently reporting
        # un-tuned numbers as if test-time tuning had been applied.
        print(
            "warning: --tt is not wired into this script; "
            "evaluating without test-time token tuning",
            file=sys.stderr,
        )

    cfg = load_config(args.config, args.override)
    model = MapGS(cfg).to(cfg.device)
    if args.ckpt:
        # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects,
        # which can execute code. Only pass checkpoints from trusted sources.
        ckpt = torch.load(args.ckpt, map_location=cfg.device, weights_only=False)
        model.load_state_dict(ckpt["model"])  # checkpoint is expected to hold weights under "model"
        print(f"loaded {args.ckpt}")
    else:
        # Evaluating without weights is allowed (e.g. smoke tests), but make it visible.
        print("warning: no --ckpt given; evaluating MapGS(cfg) without loaded weights", file=sys.stderr)
    model.eval()  # switch the module to evaluation mode before scoring

    dataset = build_dataset(cfg, args.split)
    ev = Evaluator(model, cfg, device=cfg.device)

    interp = ev.interpolation(dataset, max_scenes=args.max_scenes)
    print("interpolation:", interp)

    # The sweep and lane metrics use a separate, smaller scene budget (default 30).
    # Presumably this is because they are costlier per scene; confirm with the Evaluator.
    n_sweep_scenes = min(args.max_scenes, args.max_sweep_scenes)
    sweep = ev.extrapolation_sweep(dataset, max_scenes=n_sweep_scenes)
    for sh, m in sweep.items():
        print(f"  shift {sh}m:", m)
    lane = ev.lane_consistency(dataset, max_scenes=n_sweep_scenes)
    print("lane:", lane)

    # Interpolation results are keyed by dataset name for the report.
    write_report(args.out, interpolation={cfg.data.name: interp}, extrapolation=sweep, lane=lane)
    print(f"report written to {args.out}")


if __name__ == "__main__":
    # Script entry point: main() returns None, so the process exits with status 0
    # unless main() raises.
    raise SystemExit(main())