# mapvggt/mapgs/train/render.py
# Uploaded by ChenmingWu using huggingface_hub (commit b2efbe4, 3.48 kB).
"""Rendering helpers used by the trainer and evaluator.
Renders a scene's supervision views, grouping by clip frame so dynamic
Gaussians can be rigidly placed at each frame (§2.5) before rasterization. Also
provides batched Plucker computation for the encoder.
"""
from __future__ import annotations
from typing import Optional
import torch
from mapgs.geometry.cameras import plucker_embedding
from mapgs.model.dynamic import place_dynamic_gaussians
from mapgs.render.gaussians import Gaussians
from mapgs.render.rasterizer import GaussianRasterizer
def batched_plucker(K: torch.Tensor, c2w: torch.Tensor, H: int, W: int) -> torch.Tensor:
    """Apply ``plucker_embedding`` to each batch element and stack the results.

    Shapes: ``K [B,V,3,3]``, ``c2w [B,V,4,4]`` -> ``[B,V,6,H,W]``.

    Element ``b`` is computed from ``K[b]`` and ``c2w[b]`` for every ``b`` in
    ``range(K.shape[0])``. The per-element outputs are stacked along a new
    leading dimension.
    """
    per_scene = []
    for b in range(K.shape[0]):
        per_scene.append(plucker_embedding(K[b], c2w[b], H, W))
    return torch.stack(per_scene, dim=0)
def _as_float(g: Gaussians) -> Gaussians:
    """Return a new ``Gaussians`` whose tensor fields are cast with ``.float()``.

    ``means``, ``scales``, ``quats``, ``opacities`` and ``colors`` are always
    converted. ``aux`` is converted only when it is not None. ``group`` and
    ``instance_id`` are forwarded unchanged.
    """
    cast_fields = {
        name: getattr(g, name).float()
        for name in ("means", "scales", "quats", "opacities", "colors")
    }
    return Gaussians(
        **cast_fields,
        group=g.group,
        instance_id=g.instance_id,
        aux=g.aux.float() if g.aux is not None else None,
    )
def extract_scene_dynamic(batch_dynamic: Optional[dict], b: int) -> Optional[dict]:
    """Pick out scene ``b``'s valid dynamic-instance entries from a padded batch dict.

    Returns None when ``batch_dynamic`` is None, or when ``valid[b]`` has no
    set entry. Otherwise it returns a dict with ``box_centers``, ``box_rots``,
    ``box_size`` and ``canon_idx``. Each value is indexed at ``[b]`` and then
    by ``valid[b]``, which is presumably a boolean mask (confirm with the
    collate code). The ``valid`` key itself is not copied into the result.
    """
    if batch_dynamic is None:
        return None
    keep = batch_dynamic["valid"][b]
    if not bool(keep.any()):
        # Scene has only padding slots: report "no dynamic instances".
        return None
    return {
        key: batch_dynamic[key][b][keep]
        for key in ("box_centers", "box_rots", "box_size", "canon_idx")
    }
def render_scene_views(
    model,
    ras: GaussianRasterizer,
    g_canon: Gaussians,
    dyn_scene: Optional[dict],
    K: torch.Tensor,  # [N, 3, 3]
    c2w: torch.Tensor,  # [N, 4, 4]
    frames: Optional[torch.Tensor],  # [N] long clip-frame indices (for dynamic placement)
    H: int,
    W: int,
    uses_features: bool,
) -> dict:
    """Render ``N`` views of one scene, grouping views that share a clip frame.

    If either ``dyn_scene`` or ``frames`` is None, every view is rendered in a
    single ``ras.render`` call from the (float-cast) canonical Gaussians.
    Otherwise each distinct value in ``frames`` forms one group. For each
    group, ``place_dynamic_gaussians`` places the dynamic Gaussians at that
    frame, and only that group's views are rasterized.

    Colour comes from ``model.feature_to_rgb(out.color)`` when
    ``uses_features`` is true. Otherwise it is the first three channels of
    ``out.color`` clamped to [0, 1].

    Returns a dict with ``rgb`` [N,3,H,W], ``depth`` [N,H,W], ``alpha``
    [N,H,W] and ``lane`` [N,H,W]. ``lane`` holds channel 0 of the rasterizer's
    ``aux`` output and stays zero for any group whose ``aux`` is None.
    """
    num_views = K.shape[0]
    device = g_canon.means.device
    g_canon = _as_float(g_canon)
    template = g_canon.means  # all output buffers take this tensor's dtype/device
    rgb = template.new_zeros(num_views, 3, H, W)
    depth = template.new_zeros(num_views, H, W)
    alpha = template.new_zeros(num_views, H, W)
    lane = template.new_zeros(num_views, H, W)

    # Each bucket is (frame index or None, indices of the views to render).
    if dyn_scene is None or frames is None:
        buckets = [(None, torch.arange(num_views, device=device))]
    else:
        buckets = []
        for f in frames.unique().tolist():
            buckets.append((int(f), (frames == f).nonzero(as_tuple=True)[0]))

    for frame, idx in buckets:
        if idx.numel() == 0:
            continue
        if frame is None:
            g = g_canon
        else:
            g = place_dynamic_gaussians(
                g_canon, dyn_scene["box_centers"], dyn_scene["box_rots"], dyn_scene["canon_idx"], frame
            )
        out = ras.render(g, K[idx], c2w[idx], H, W)
        if uses_features:
            col = model.feature_to_rgb(out.color)
        else:
            col = out.color[:, :3].clamp(0, 1)
        # Autocast can make the UNet produce bf16; cast every result to the
        # buffers' dtype so all outputs stay in a single (fp32) dtype.
        rgb = rgb.index_copy(0, idx, col.to(rgb.dtype))
        depth = depth.index_copy(0, idx, out.depth.to(depth.dtype))
        alpha = alpha.index_copy(0, idx, out.alpha.to(alpha.dtype))
        if out.aux is not None:
            lane = lane.index_copy(0, idx, out.aux[:, 0].to(lane.dtype))
    return {"rgb": rgb, "depth": depth, "alpha": alpha, "lane": lane}