"""Sample / batch containers and collation. A :class:`MapGSSample` is one posed driving clip + its map (ยง2.1). The model's encoder only consumes the *context* views; supervision views are render targets. Per-scene objects (ground field, polylines, variable-instance dynamics) are kept as Python lists in the batch, while fixed-shape tensors (context views, anchors, sampled supervision views) are stacked for the batched encoder/decoder. """ from __future__ import annotations from dataclasses import dataclass, field from typing import List, Optional import torch from mapgs.hdmap.ground_field import GridGroundField @dataclass class MapGSSample: # context (encoder inputs) ctx_images: torch.Tensor # [Vc, 3, H, W] ctx_K: torch.Tensor # [Vc, 3, 3] ctx_c2w: torch.Tensor # [Vc, 4, 4] ctx_tids: torch.Tensor # [Vc] long (context-timestep index) # supervision targets (render + losses) sup_images: torch.Tensor # [Ns, 3, H, W] sup_K: torch.Tensor # [Ns, 3, 3] sup_c2w: torch.Tensor # [Ns, 4, 4] sup_depth: torch.Tensor # [Ns, H, W] GT depth (eval D-RMSE; -1 where unknown) sup_mono: torch.Tensor # [Ns, H, W] frozen mono-depth prior (L_vert) sup_frame: torch.Tensor # [Ns] long (clip frame index, for dynamic placement) # map anchors (MAGT) anchor_pos: torch.Tensor # [Na, 3] anchor_type: torch.Tensor # [Na] long anchor_normal: torch.Tensor # [Na, 3] # per-scene map objects ground: GridGroundField lanes: List[torch.Tensor] boundaries: List[torch.Tensor] # dynamics (I instances over F frames; I may be 0) box_centers: torch.Tensor # [I, F, 3] box_rots: torch.Tensor # [I, F, 3, 3] box_size: torch.Tensor # [I, 3] canon_idx: torch.Tensor # [I] long scene_id: int = 0 scene_scale: float = 1.0 box_valid: Optional[torch.Tensor] = None # [I, F] bool: actor tracked at that frame def _pad_dynamics(batch: List[MapGSSample], device): I_max = max(s.box_centers.shape[0] for s in batch) if I_max == 0: return None F = batch[0].box_centers.shape[1] B = len(batch) centers = torch.zeros(B, I_max, F, 3) rots = torch.eye(3).view(1, 1, 1, 3, 3).repeat(B, I_max, F, 1, 1) size = torch.ones(B, I_max, 3) canon = torch.zeros(B, I_max, dtype=torch.long) valid = torch.zeros(B, I_max, dtype=torch.bool) for b, s in enumerate(batch): I = s.box_centers.shape[0] if I == 0: continue centers[b, :I] = s.box_centers rots[b, :I] = s.box_rots size[b, :I] = s.box_size canon[b, :I] = s.canon_idx valid[b, :I] = True return dict(box_centers=centers, box_rots=rots, box_size=size, canon_idx=canon, valid=valid) def collate_samples(batch: List[MapGSSample]) -> dict: """Stack fixed-shape tensors; keep per-scene objects as lists.""" def stack(attr): return torch.stack([getattr(s, attr) for s in batch], 0) out = { "ctx_images": stack("ctx_images"), "ctx_K": stack("ctx_K"), "ctx_c2w": stack("ctx_c2w"), "ctx_tids": stack("ctx_tids"), "sup_images": stack("sup_images"), "sup_K": stack("sup_K"), "sup_c2w": stack("sup_c2w"), "sup_depth": stack("sup_depth"), "sup_mono": stack("sup_mono"), "sup_frame": stack("sup_frame"), "anchor_pos": stack("anchor_pos"), "anchor_type": stack("anchor_type"), "anchor_normal": stack("anchor_normal"), "grounds": [s.ground for s in batch], "lanes": [s.lanes for s in batch], "boundaries": [s.boundaries for s in batch], "dynamic": _pad_dynamics(batch, None), "scene_ids": [s.scene_id for s in batch], "scene_scale": torch.tensor([s.scene_scale for s in batch]), } return out