File size: 5,940 Bytes
8cf92b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#!/usr/bin/env python3
"""Visualize the trained MapTokenGS: held-out novel-view reconstruction (GT vs
predicted RGB vs predicted depth) and an out-of-trajectory simulation strip
(render the scene from laterally-shifted cameras the ego never drove through).
Saves PNGs under runs/."""
import argparse, numpy as np, torch
import matplotlib; matplotlib.use("Agg")
import matplotlib.pyplot as plt

from tokengs import options
from safetensors.torch import load_file
from mapgs.config import load_config
from mapgs.data import UnifiedClipDataset
from mapgs.eval.metrics import psnr, ssim
import scripts.finetune_maptokengs as FT

# Torch device string used for the model (`.to(dev)`), passed to FT.prep, and
# used for the camera/intrinsics tensors built in main(). It is hard-coded, so
# running this script as-is needs a CUDA device.
dev = "cuda"


def to_img(t):
    """Convert a channel-first float image tensor to an HWC uint8 array.

    Args:
        t: image tensor indexed as [3, H, W]. Values are clamped to [0, 1]
            before scaling, so out-of-range predictions are safe to pass.

    Returns:
        ``np.uint8`` array of shape [H, W, 3] with values in 0..255. The
        integer cast truncates the fractional part (0.5 -> 127).
    """
    # detach() first: Tensor.numpy() raises on tensors that require grad, and
    # an image dump never needs the autograd history.
    hwc = t.detach().clamp(0, 1).permute(1, 2, 0).cpu().numpy()
    return (hwc * 255).astype(np.uint8)


def depth_img(d):
    """Colorize a depth map with matplotlib's ``turbo`` colormap.

    The depth is normalized to [0, 1] between the 2nd and 98th percentile of
    its valid pixels so a few extreme values do not flatten the color range.

    Args:
        d: depth tensor indexed as [H, W]. Only finite, strictly positive
            entries are treated as valid when estimating the range.

    Returns:
        ``np.uint8`` RGB array of shape [H, W, 3].
    """
    # detach() so a tensor that still carries autograd history can be converted.
    d = d.detach().cpu().numpy()

    # Build the validity mask once. Non-finite depths are excluded as well,
    # because a single inf would otherwise corrupt the percentile estimate.
    valid = d[np.isfinite(d) & (d > 0)]
    if valid.size:
        lo, hi = np.percentile(valid, [2, 98])
    else:
        # No valid depth at all: fall back to a unit range.
        lo, hi = 0, 1

    # Guard the denominator so a constant-depth image does not divide by zero.
    dn = np.clip((d - lo) / max(hi - lo, 1e-6), 0, 1)
    # The colormap returns RGBA floats in [0, 1]; drop alpha and scale to bytes.
    return (plt.cm.turbo(dn)[..., :3] * 255).astype(np.uint8)


def _select_best_clip(m, ds, n_scan):
    """Reconstruct clips ``0 .. n_scan-1`` of ``ds`` and return the best one.

    "Best" is the clip whose rendered views reach the highest PSNR against
    ``d["gt"]``; it is used as a clean, representative example for the figures.

    Args:
        m: the loaded model (already on ``dev`` and in eval mode).
        ds: dataset indexed as ``ds[i]``.
        n_scan: number of leading clips to evaluate.

    Returns:
        ``(psnr, idx, (d, g, grp, inst, rgb, dep))`` for the winning clip.

    Raises:
        RuntimeError: if no clip produced a comparable (non-NaN) PSNR.
    """
    best_p, best_i, best_state = float("-inf"), None, None
    for i in range(n_scan):
        d = FT.prep(ds[i], 8, 1.0, dev)
        with torch.no_grad():
            g = m.forward_reconstruction_mapped(d["mi"], d["anchors"], d["atype"], d["anormal"])
            grp = m.gaussian_group()[0]
            inst = m.gaussian_instance()[0]
            rgb, dep, _alpha = FT.render_dyn(m, g, grp, inst, d["mi"].decoder.cam_view,
                                             d["mi"].decoder.intrinsics, d["sup_frame"], d["dyn"], gain=1.0)
        p = float(psnr(rgb.clamp(0, 1), d["gt"]))
        # A NaN PSNR compares False here and is therefore never selected.
        if p > best_p:
            best_p, best_i, best_state = p, i, (d, g, grp, inst, rgb, dep)
    if best_state is None:
        raise RuntimeError(f"no clip among the first {n_scan} produced a valid PSNR")
    return best_p, best_i, best_state


def _save_nvs_panel(gt, rgb, dep, title, out_path):
    """Save the held-out NVS grid: one row per view, columns GT | Pred | Depth.

    Args:
        gt: ground-truth views, indexed ``gt[v]`` -> [3, H, W].
        rgb: predicted views, same indexing as ``gt``.
        dep: predicted depth maps, indexed ``dep[v]`` -> [H, W].
        title: figure suptitle.
        out_path: PNG path to write.
    """
    V = gt.shape[0]
    # squeeze=False keeps `ax` 2-D even when V == 1, so ax[v, c] always works.
    fig, ax = plt.subplots(V, 3, figsize=(3 * 3.4, V * 1.9), squeeze=False)
    for v in range(V):
        pv = float(psnr(rgb[v:v+1].clamp(0, 1), gt[v:v+1]))
        ax[v, 0].imshow(to_img(gt[v]))
        ax[v, 0].set_ylabel(f"view {v}", fontsize=9)
        ax[v, 1].imshow(to_img(rgb[v]))
        ax[v, 1].set_title(f"PSNR {pv:.1f}", fontsize=8)
        ax[v, 2].imshow(depth_img(dep[v]))
        for c in range(3):
            ax[v, c].set_xticks([])
            ax[v, c].set_yticks([])

    headers = ["GT (held-out)", "MapTokenGS pred", "pred depth"]
    # NOTE(review): only column 0 gets a regular title in addition to the bold
    # header text drawn above every column. This is kept as-is to preserve the
    # existing figure; confirm whether the duplicate label is intended.
    ax[0, 0].set_title(headers[0], fontsize=10)
    for c, t in enumerate(headers):
        ax[0, c].text(0.5, 1.25, t, transform=ax[0, c].transAxes, ha="center", fontsize=11, weight="bold")

    fig.suptitle(title, y=0.995, fontsize=12)
    fig.tight_layout(rect=[0, 0, 1, 0.97])
    fig.savefig(out_path, dpi=110, bbox_inches="tight")
    plt.close(fig)


def _save_oot_panel(m, d, g, grp, inst, title, out_path):
    """Save the out-of-trajectory strip: renders from laterally shifted cameras.

    Each column renders the same gaussians from ``d["ref_c2w_norm"]`` perturbed
    by a lateral offset (no yaw change). An offset of 0 is labeled "ego path".

    Args:
        m: the loaded model; ``m.render_gaussians`` does the rendering.
        d: prepared clip dict (reads ``ref_K``, ``ref_c2w_norm``, ``sup_frame``, ``dyn``).
        g, grp, inst: reconstruction outputs for this clip.
        title: figure suptitle.
        out_path: PNG path to write.
    """
    # Kept as function-scope imports, as in the original script.
    from mapgs.losses.extrap import perturb_pose

    refK = d["ref_K"]
    # Intrinsics packed as (fx, fy, cx, cy), read from the 3x3 matrix entries.
    intr = torch.stack([refK[0, 0], refK[1, 1], refK[0, 2], refK[1, 2]])[None, None].to(dev)

    # The gaussians do not depend on the camera shift, so dynamic-object
    # placement is done once here instead of once per rendered column.
    gg = g[0]
    if d["dyn"] is not None:
        from maptokengs import place_dynamics_14d
        f0 = int(d["sup_frame"][0])
        with torch.no_grad():
            gg = place_dynamics_14d(g[0], grp, inst, f0, d["dyn"]["c"], d["dyn"]["R"],
                                    d["dyn"]["canon"], d["dyn"]["valid"], d["dyn"]["radius"], gain=1.0)

    shifts = [-0.6, -0.3, 0.0, 0.3, 0.6]
    fig2, axes = plt.subplots(1, len(shifts), figsize=(len(shifts) * 3.0, 2.6), squeeze=False)
    for axk, s in zip(axes[0], shifts):
        P = perturb_pose(d["ref_c2w_norm"], lateral=s, yaw_deg=0.0)
        # Invert the perturbed pose and transpose it into the layout that
        # render_gaussians is given for its camera-view argument.
        cv = torch.inverse(P).transpose(0, 1)[None, None].to(dev)
        with torch.no_grad():
            o = m.render_gaussians(gg[None], cv, intr)
        axk.imshow(to_img(o["images_pred"][0, 0].clamp(0, 1)))
        axk.set_xticks([])
        axk.set_yticks([])
        axk.set_title(("ego path" if s == 0 else f"{'+' if s>0 else ''}{s} lat"), fontsize=10,
                      weight=("bold" if s == 0 else "normal"))

    fig2.suptitle(title, fontsize=11)
    fig2.tight_layout(rect=[0, 0, 1, 0.93])
    fig2.savefig(out_path, dpi=110, bbox_inches="tight")
    plt.close(fig2)


def main():
    """Load a MapTokenGS checkpoint and write the two visualization PNGs.

    Command-line arguments:
        --ckpt: safetensors checkpoint to load (loaded with ``strict=False``).
        --scan: number of leading dataset clips to evaluate; the best-PSNR one
            is visualized. Must be >= 1.
        --s: value assigned to ``m.cur_s``.
        --H / --W: image height / width used for the model and the dataset.
        --out_dir: directory receiving ``viz_nvs.png`` and ``viz_oot.png``
            (created if missing). Defaults to the previously hard-coded location.
    """
    import os

    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", default="/mnt/william/runs/maptokengs_highres_best.safetensors")
    ap.add_argument("--scan", type=int, default=20)
    ap.add_argument("--s", type=float, default=1.5)
    ap.add_argument("--H", type=int, default=448)
    ap.add_argument("--W", type=int, default=768)
    ap.add_argument("--out_dir", default="/mnt/william/runs")
    args = ap.parse_args()
    if args.scan < 1:
        # Fail before the expensive model/dataset setup, with a clear message.
        ap.error("--scan must be >= 1")
    H, W = args.H, args.W

    opt = options.config_defaults["train_dl3dv_base"]
    opt = opt() if callable(opt) else opt
    opt.img_size = (H, W)
    m = FT.MapTokenGS(opt, s0=0.5, s_max=2.0, max_instances=8, dyn_per_instance=32).to(dev).eval()
    load_result = m.load_state_dict(load_file(args.ckpt), strict=False)
    # strict=False ignores key mismatches; report them so a wrong checkpoint
    # does not silently produce meaningless renders.
    missing = list(getattr(load_result, "missing_keys", ()))
    unexpected = list(getattr(load_result, "unexpected_keys", ()))
    if missing or unexpected:
        print(f"warning: checkpoint mismatch | {len(missing)} missing keys, "
              f"{len(unexpected)} unexpected keys", flush=True)
    m.cur_s = args.s

    cfg = load_config(overrides=["data.name=unified",
                                 "data.root=/mnt/william/data/unified/av2,/mnt/william/data/unified/waymo",
                                 f"data.height={H}", f"data.width={W}", "model.tokens.n_map=2048",
                                 "data.max_instances=8"])
    ds = UnifiedClipDataset(cfg, roots=["/mnt/william/data/unified/av2", "/mnt/william/data/unified/waymo"],
                            split="train", n_sup_views=4)

    # pick the best-PSNR clip among the first --scan (a clean, representative example)
    p, i, (d, g, grp, inst, rgb, dep) = _select_best_clip(m, ds, args.scan)
    src = "waymo" if "waymo" in ds.clips[i] else "av2"
    print(f"chosen clip idx {i} ({src}) | held-out PSNR {p:.2f} SSIM {float(ssim(rgb.clamp(0,1), d['gt'])):.3f}", flush=True)

    os.makedirs(args.out_dir, exist_ok=True)

    # ---- panel 1: held-out NVS grid (rows = views, cols = GT | Pred | Depth) ----
    out1 = os.path.join(args.out_dir, "viz_nvs.png")
    _save_nvs_panel(d["gt"], rgb, dep,
                    f"Held-out novel-view reconstruction — {src} clip {i} — mean PSNR {p:.2f}", out1)
    print("saved", out1, flush=True)

    # ---- panel 2: out-of-trajectory simulation (lateral shifts from input view 0) ----
    out2 = os.path.join(args.out_dir, "viz_oot.png")
    _save_oot_panel(m, d, g, grp, inst,
                    f"Out-of-trajectory simulation — render from cameras off the driven path ({src} clip {i})", out2)
    print("saved", out2, flush=True)


# Run the visualization only when this file is executed as a script, so that
# importing the module does not trigger main().
if __name__ == "__main__":
    main()