"""Rendering helpers used by the trainer and evaluator. Renders a scene's supervision views, grouping by clip frame so dynamic Gaussians can be rigidly placed at each frame (ยง2.5) before rasterization. Also provides batched Plucker computation for the encoder. """ from __future__ import annotations from typing import Optional import torch from mapgs.geometry.cameras import plucker_embedding from mapgs.model.dynamic import place_dynamic_gaussians from mapgs.render.gaussians import Gaussians from mapgs.render.rasterizer import GaussianRasterizer def batched_plucker(K: torch.Tensor, c2w: torch.Tensor, H: int, W: int) -> torch.Tensor: """``K [B,V,3,3]``, ``c2w [B,V,4,4]`` -> ``[B,V,6,H,W]``.""" B = K.shape[0] return torch.stack([plucker_embedding(K[b], c2w[b], H, W) for b in range(B)], 0) def _as_float(g: Gaussians) -> Gaussians: return Gaussians( means=g.means.float(), scales=g.scales.float(), quats=g.quats.float(), opacities=g.opacities.float(), colors=g.colors.float(), group=g.group, instance_id=g.instance_id, aux=None if g.aux is None else g.aux.float(), ) def extract_scene_dynamic(batch_dynamic: Optional[dict], b: int) -> Optional[dict]: """Slice the padded batch dynamic dict to scene ``b`` (None if no instances).""" if batch_dynamic is None: return None valid = batch_dynamic["valid"][b] if not bool(valid.any()): return None return { "box_centers": batch_dynamic["box_centers"][b][valid], "box_rots": batch_dynamic["box_rots"][b][valid], "box_size": batch_dynamic["box_size"][b][valid], "canon_idx": batch_dynamic["canon_idx"][b][valid], } def render_scene_views( model, ras: GaussianRasterizer, g_canon: Gaussians, dyn_scene: Optional[dict], K: torch.Tensor, # [N, 3, 3] c2w: torch.Tensor, # [N, 4, 4] frames: Optional[torch.Tensor], # [N] long clip-frame indices (for dynamic placement) H: int, W: int, uses_features: bool, ) -> dict: """Render N views; returns rgb [N,3,H,W], depth [N,H,W], alpha [N,H,W], lane [N,H,W].""" N = K.shape[0] device = g_canon.means.device g_canon = _as_float(g_canon) rgb = g_canon.means.new_zeros(N, 3, H, W) depth = g_canon.means.new_zeros(N, H, W) alpha = g_canon.means.new_zeros(N, H, W) lane = g_canon.means.new_zeros(N, H, W) if dyn_scene is None or frames is None: groups = [(None, torch.arange(N, device=device))] else: groups = [(int(f), (frames == f).nonzero(as_tuple=True)[0]) for f in frames.unique().tolist()] for fr, sel in groups: if sel.numel() == 0: continue g = g_canon if fr is None else place_dynamic_gaussians( g_canon, dyn_scene["box_centers"], dyn_scene["box_rots"], dyn_scene["canon_idx"], fr ) out = ras.render(g, K[sel], c2w[sel], H, W) col = model.feature_to_rgb(out.color) if uses_features else out.color[:, :3].clamp(0, 1) # under autocast the UNet may emit bf16; keep buffers in a single (fp32) dtype rgb = rgb.index_copy(0, sel, col.to(rgb.dtype)) depth = depth.index_copy(0, sel, out.depth.to(depth.dtype)) alpha = alpha.index_copy(0, sel, out.alpha.to(alpha.dtype)) if out.aux is not None: lane = lane.index_copy(0, sel, out.aux[:, 0].to(lane.dtype)) return {"rgb": rgb, "depth": depth, "alpha": alpha, "lane": lane}