# mapvggt/tests/test_model_render.py
# Uploaded by ChenmingWu via huggingface_hub (commit 195056b, 3.92 kB).
import torch
from mapgs.model import MapGS
from mapgs.render import Gaussians, GaussianRasterizer, GROUP_MAP, GROUP_DYNAMIC
from mapgs.geometry import plucker_embedding, look_at_pose
from .conftest import requires_cuda
def _model_inputs(cfg, device, B=1, dynamic=False):
    """Build a small synthetic batch of MapGS inputs for the tests.

    Three camera views are placed side by side. ``look_at_pose`` receives eyes at
    x in {-1, 0, 1}, y=0, z=1.5 and targets at (x, 20, 0.5). All views share a
    pinhole intrinsic matrix with focal 60 and principal point at the image centre.

    Returns:
        ``(imgs, pl, tids, apos, atype, anorm, K, c2w, dyn)`` where
        ``imgs`` is ``(B, 3, 3, H, W)`` uniform noise, ``K`` is ``(B, 3, 3, 3)``,
        ``c2w`` stacks ``look_at_pose`` outputs per view and batch (exact pose
        shape depends on that helper -- TODO confirm), ``pl`` stacks
        ``plucker_embedding`` outputs per batch element, ``tids`` are all-zero
        long ids of shape ``(B, 3)``, ``apos`` are ``(B, n_map, 3)`` random
        anchors whose z spread is compressed by 10x, ``atype`` are ints in
        ``[0, 3)``, ``anorm`` are ``(B, n_map, 3)`` unit +z vectors, and ``dyn``
        is ``None`` unless ``dynamic`` is true (then a dict describing two boxes).
    """
    n_views = 3
    H, W = cfg.data.height, cfg.data.width
    imgs = torch.rand(B, n_views, 3, H, W, device=device)

    fx = fy = 60.0
    intrinsics = torch.tensor(
        [[fx, 0.0, W / 2], [0.0, fy, H / 2], [0.0, 0.0, 1.0]], device=device
    )
    K = intrinsics.view(1, 1, 3, 3).repeat(B, n_views, 1, 1)

    poses = []
    for dx in (-1.0, 0.0, 1.0):
        eye = torch.tensor([dx, 0.0, 1.5])
        target = torch.tensor([dx, 20.0, 0.5])
        poses.append(look_at_pose(eye, target))
    c2w = torch.stack(poses).to(device).unsqueeze(0).repeat(B, 1, 1, 1)

    pl = torch.stack([plucker_embedding(k_b, pose_b, H, W) for k_b, pose_b in zip(K, c2w)])
    tids = torch.zeros(B, n_views, dtype=torch.long, device=device)

    n_map = cfg.model.tokens.n_map
    apos = torch.randn(B, n_map, 3, device=device)
    apos[..., 2] *= 0.1  # keep anchors close to z = 0
    atype = torch.randint(0, 3, (B, n_map), device=device)
    anorm = torch.zeros(B, n_map, 3, device=device)
    anorm[..., 2] = 1  # every anchor normal points along +z

    if not dynamic:
        return imgs, pl, tids, apos, atype, anorm, K, c2w, None

    # Two boxes: instance 0 drives along +y at 0.5 per frame, instance 1 is parked.
    n_inst, n_frames = 2, cfg.data.num_frames
    centers = torch.zeros(B, n_inst, n_frames, 3, device=device)
    frame_idx = torch.arange(n_frames, device=device, dtype=centers.dtype)
    centers[:, 0, :, 0] = 3.0
    centers[:, 0, :, 1] = 5.0 + 0.5 * frame_idx
    centers[:, 1, :, 0] = -3.0
    centers[:, 1, :, 1] = 8.0
    centers[..., 2] = 0.5

    dyn = {
        "box_centers": centers,
        "box_rots": torch.eye(3, device=device)[None, None, None].repeat(B, n_inst, n_frames, 1, 1),
        "box_size": torch.full((B, n_inst, 3), 2.0, device=device),
        "valid": torch.ones(B, n_inst, dtype=torch.bool, device=device),
        "canon_idx": torch.zeros(B, n_inst, dtype=torch.long, device=device),
    }
    return imgs, pl, tids, apos, atype, anorm, K, c2w, dyn
@requires_cuda
def test_forward_shapes_and_budget(cfg):
    """The forward pass emits the expected Gaussian budget and colour width.

    Each map token and each free token contributes ``gaussians_per_token``
    Gaussians. The batch-1 output therefore holds
    ``(n_map + n_free) * gaussians_per_token`` means. Colours carry
    ``cfg.model.feature_dim`` channels.
    """
    net = MapGS(cfg).cuda()
    imgs, pl, tids, apos, atype, anorm, *_ = _model_inputs(cfg, "cuda")
    out = net(imgs, pl, tids, apos, atype, anorm, s_t=0.5)

    tokens = cfg.model.tokens
    expected_count = (tokens.n_map + tokens.n_free) * tokens.gaussians_per_token
    assert out.means.shape == (1, expected_count, 3)
    assert out.colors.shape[-1] == cfg.model.feature_dim
@requires_cuda
def test_bounded_residual_for_map_tokens(cfg):
    """Map-token Gaussian means stay within ``s_t`` of their anchor on every axis.

    The check is per coordinate: an L-infinity bound of ``s_t``, i.e. each mean
    lies in a cube of half-width ``s_t`` around its anchor. This implies a
    Euclidean distance of at most ``sqrt(3) * s_t``.
    """
    model = MapGS(cfg).cuda()
    imgs, pl, tids, apos, atype, anorm, *_ = _model_inputs(cfg, "cuda")
    s_t = 0.5
    g = model(imgs, pl, tids, apos, atype, anorm, s_t=s_t)
    is_map = g.group[0] == GROUP_MAP
    means = g.means[0][is_map]
    ng = cfg.model.tokens.gaussians_per_token
    # Assumes map Gaussians come out in anchor order, with ``ng`` consecutive
    # Gaussians per map token -- TODO confirm against MapGS.forward.
    anchors = apos[0].repeat_interleave(ng, 0)
    # Compare counts explicitly instead of truncating ``anchors``. A mismatch
    # would otherwise silently pair means with the wrong anchors, and an empty
    # selection would crash inside ``.max()`` with an unhelpful error.
    assert means.shape[0] > 0, "model produced no GROUP_MAP gaussians"
    assert means.shape == anchors.shape, (
        f"{means.shape[0]} map gaussians but {anchors.shape[0]} anchor slots "
        f"(n_map={apos.shape[1]}, gaussians_per_token={ng})"
    )
    dev = (means - anchors).abs().max()
    assert dev <= s_t + 1e-4, f"max per-axis deviation {dev.item():.4g} exceeds s_t={s_t}"
@requires_cuda
def test_dynamic_placement_moves_only_dynamic(cfg):
    """Placing dynamics at two frames moves only the GROUP_DYNAMIC Gaussians.

    In the synthetic inputs, instance 0's box centre advances 0.5 along +y per
    frame and instance 1 stays fixed. Comparing the first and last frames
    (needs ``cfg.data.num_frames >= 2``) should therefore show motion for the
    dynamic Gaussians. Every other Gaussian should be unchanged (within
    ``torch.allclose`` tolerance).
    """
    model = MapGS(cfg).cuda()
    imgs, pl, tids, apos, atype, anorm, *_, dyn = _model_inputs(cfg, "cuda", dynamic=True)
    g = model(imgs, pl, tids, apos, atype, anorm, s_t=0.5, dynamic=dyn)
    g0 = model.place_dynamics(g, dyn, 0)
    g1 = model.place_dynamics(g, dyn, cfg.data.num_frames - 1)
    dynm = g.group[0] == GROUP_DYNAMIC
    # Without any dynamic Gaussians the mean below is NaN and the comparison
    # fails with no hint of the real cause, so check the selection up front.
    assert dynm.any(), "model produced no GROUP_DYNAMIC gaussians"
    # The group mask is taken from ``g`` and reused on the placed outputs.
    assert g0.means.shape[1] == dynm.shape[0] == g1.means.shape[1], (
        "place_dynamics changed the number of gaussians"
    )
    assert (g1.means[0][dynm] - g0.means[0][dynm]).norm(dim=-1).mean() > 0.1
    assert torch.allclose(g1.means[0][~dynm], g0.means[0][~dynm])
@requires_cuda
def test_render_and_backward(cfg):
    """Render batch element 0 and backpropagate a colour + depth loss.

    The scene is rendered from the cameras ``K[0]`` and ``c2w[0]``. The test
    checks the decoded RGB shape, that the auxiliary (lane) channel is
    produced, and that at least one model parameter receives a non-zero
    gradient.
    """
    H, W = cfg.data.height, cfg.data.width
    net = MapGS(cfg).cuda()
    imgs, pl, tids, apos, atype, anorm, K, c2w, _ = _model_inputs(cfg, "cuda")
    gaussians = net(imgs, pl, tids, apos, atype, anorm, s_t=0.5)

    rendered = GaussianRasterizer().render(gaussians.scene(0), K[0], c2w[0], H, W)
    rgb = net.feature_to_rgb(rendered.color)
    # presumably (views, RGB, H, W); _model_inputs supplies 3 views -- TODO confirm
    assert rgb.shape == (3, 3, H, W)
    assert rendered.aux is not None  # lane channel rendered

    loss = rgb.mean() + rendered.depth.mean()
    loss.backward()
    grads = [p.grad for p in net.parameters() if p.grad is not None]
    # Summing (rather than any()) also fails if a gradient is NaN.
    assert sum(grad.abs().sum() for grad in grads) > 0