| """Rendering helpers used by the trainer and evaluator. |
| |
| Renders a scene's supervision views, grouping by clip frame so dynamic |
| Gaussians can be rigidly placed at each frame (§2.5) before rasterization. Also |
| provides batched Plucker computation for the encoder. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from typing import Optional |
|
|
| import torch |
|
|
| from mapgs.geometry.cameras import plucker_embedding |
| from mapgs.model.dynamic import place_dynamic_gaussians |
| from mapgs.render.gaussians import Gaussians |
| from mapgs.render.rasterizer import GaussianRasterizer |
|
|
|
|
def batched_plucker(K: torch.Tensor, c2w: torch.Tensor, H: int, W: int) -> torch.Tensor:
    """Compute Plucker embeddings scene by scene and stack them on a new batch axis.

    ``K [B,V,3,3]``, ``c2w [B,V,4,4]`` -> ``[B,V,6,H,W]``.

    ``plucker_embedding`` is invoked once per leading-axis entry of ``K`` with
    the matching entry of ``c2w``; the per-scene results are stacked along
    dimension 0.
    """
    per_scene = []
    for scene_idx in range(K.shape[0]):
        embedding = plucker_embedding(K[scene_idx], c2w[scene_idx], H, W)
        per_scene.append(embedding)
    return torch.stack(per_scene, 0)
|
|
|
|
def _as_float(g: Gaussians) -> Gaussians:
    """Build a new ``Gaussians`` whose numeric fields have been passed through ``.float()``.

    ``means``, ``scales``, ``quats``, ``opacities`` and ``colors`` are always
    converted; ``aux`` is converted only when it is not ``None``. ``group`` and
    ``instance_id`` are forwarded untouched.
    """
    converted = {}
    for field_name in ("means", "scales", "quats", "opacities", "colors"):
        converted[field_name] = getattr(g, field_name).float()

    # aux is optional, so keep None as None instead of calling .float() on it.
    if g.aux is None:
        aux = None
    else:
        aux = g.aux.float()

    return Gaussians(
        group=g.group,
        instance_id=g.instance_id,
        aux=aux,
        **converted,
    )
|
|
|
|
def extract_scene_dynamic(batch_dynamic: Optional[dict], b: int) -> Optional[dict]:
    """Pull scene ``b`` out of the padded batch dynamic dict.

    Returns ``None`` when there is no dynamic dict at all or when scene ``b``
    has no entry flagged in ``batch_dynamic["valid"]``. Otherwise returns a
    dict with ``box_centers``, ``box_rots``, ``box_size`` and ``canon_idx``,
    each indexed at ``b`` and then filtered by that scene's ``valid`` mask so
    padded slots are dropped.
    """
    if batch_dynamic is None:
        return None

    mask = batch_dynamic["valid"][b]
    has_instances = bool(mask.any())
    if not has_instances:
        return None

    kept_fields = ("box_centers", "box_rots", "box_size", "canon_idx")
    return {name: batch_dynamic[name][b][mask] for name in kept_fields}
|
|
|
|
def render_scene_views(
    model,
    ras: GaussianRasterizer,
    g_canon: Gaussians,
    dyn_scene: Optional[dict],
    K: torch.Tensor,
    c2w: torch.Tensor,
    frames: Optional[torch.Tensor],
    H: int,
    W: int,
    uses_features: bool,
) -> dict:
    """Render N views of one scene, one rasterizer call per distinct clip frame.

    Returns a dict with ``rgb [N,3,H,W]``, ``depth [N,H,W]``, ``alpha [N,H,W]``
    and ``lane [N,H,W]``.

    When either ``dyn_scene`` or ``frames`` is ``None`` all views are rendered
    in a single group from the canonical Gaussians. Otherwise views are grouped
    by the unique values in ``frames`` and, for each group,
    ``place_dynamic_gaussians`` is applied with that frame index before
    rasterizing. ``lane`` is filled from channel 0 of the rasterizer's ``aux``
    output and stays zero for groups where ``aux`` is ``None``.
    """
    num_views = K.shape[0]
    device = g_canon.means.device
    g_canon = _as_float(g_canon)

    # Output buffers inherit dtype/device from the float-converted means.
    reference = g_canon.means
    rgb = reference.new_zeros(num_views, 3, H, W)
    depth = reference.new_zeros(num_views, H, W)
    alpha = reference.new_zeros(num_views, H, W)
    lane = reference.new_zeros(num_views, H, W)

    # Each group is (frame index or None, indices of the views in that group).
    if dyn_scene is None or frames is None:
        groups = [(None, torch.arange(num_views, device=device))]
    else:
        groups = []
        for frame_value in frames.unique().tolist():
            view_idx = (frames == frame_value).nonzero(as_tuple=True)[0]
            groups.append((int(frame_value), view_idx))

    for frame_idx, view_idx in groups:
        if view_idx.numel() == 0:
            continue

        if frame_idx is None:
            gaussians = g_canon
        else:
            gaussians = place_dynamic_gaussians(
                g_canon,
                dyn_scene["box_centers"],
                dyn_scene["box_rots"],
                dyn_scene["canon_idx"],
                frame_idx,
            )

        rendered = ras.render(gaussians, K[view_idx], c2w[view_idx], H, W)

        if uses_features:
            color = model.feature_to_rgb(rendered.color)
        else:
            # Plain colour output: keep the first three channels, clamped to [0, 1].
            color = rendered.color[:, :3].clamp(0, 1)

        # Out-of-place scatter of this group's results into the per-view buffers.
        rgb = rgb.index_copy(0, view_idx, color.to(rgb.dtype))
        depth = depth.index_copy(0, view_idx, rendered.depth.to(depth.dtype))
        alpha = alpha.index_copy(0, view_idx, rendered.alpha.to(alpha.dtype))
        if rendered.aux is not None:
            lane = lane.index_copy(0, view_idx, rendered.aux[:, 0].to(lane.dtype))

    return {"rgb": rgb, "depth": depth, "alpha": alpha, "lane": lane}
|
|