| """Sample / batch containers and collation. |
| |
| A :class:`MapGSSample` is one posed driving clip + its map (§2.1). The model's |
| encoder only consumes the *context* views; supervision views are render targets. |
| Per-scene objects (ground field, polylines, variable-instance dynamics) are kept |
| as Python lists in the batch, while fixed-shape tensors (context views, anchors, |
| sampled supervision views) are stacked for the batched encoder/decoder. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from dataclasses import dataclass, field |
| from typing import List, Optional |
|
|
| import torch |
|
|
| from mapgs.hdmap.ground_field import GridGroundField |
|
|
|
|
@dataclass
class MapGSSample:
    """One scene sample: context views, supervision views, map data and boxes.

    Field order defines the positional constructor signature; only the three
    trailing fields have defaults. ``collate_samples`` stacks every ``ctx_*``,
    ``sup_*`` and ``anchor_*`` tensor along a new leading batch axis, so those
    tensors must have matching shapes across the samples of one batch.
    """

    # --- Context views: the inputs consumed by the encoder. ---
    # NOTE(review): K / c2w presumably hold camera intrinsics and
    # camera-to-world poses, tids some per-view id — confirm with the loader.
    ctx_images: torch.Tensor
    ctx_K: torch.Tensor
    ctx_c2w: torch.Tensor
    ctx_tids: torch.Tensor

    # --- Supervision views: render targets, not fed to the encoder. ---
    # NOTE(review): shapes and the meaning of depth / mono / frame are not
    # established in this module — confirm with the loader.
    sup_images: torch.Tensor
    sup_K: torch.Tensor
    sup_c2w: torch.Tensor
    sup_depth: torch.Tensor
    sup_mono: torch.Tensor
    sup_frame: torch.Tensor

    # --- Anchors: stacked across the batch like the views above. ---
    anchor_pos: torch.Tensor
    anchor_type: torch.Tensor
    anchor_normal: torch.Tensor

    # --- Per-scene map objects: collated as plain Python lists. ---
    ground: GridGroundField
    lanes: List[torch.Tensor]
    boundaries: List[torch.Tensor]

    # --- Dynamic boxes: the instance count (dim 0) varies per sample. ---
    # ``_pad_dynamics`` copies these into per-sample slots shaped
    # (I, F, 3), (I, F, 3, 3), (I, 3) and (I,) respectively.
    box_centers: torch.Tensor
    box_rots: torch.Tensor
    box_size: torch.Tensor
    canon_idx: torch.Tensor

    # --- Optional metadata with defaults. ---
    scene_id: int = 0
    scene_scale: float = 1.0
    # Not read by ``collate_samples`` or ``_pad_dynamics``; intended layout
    # is not visible here — TODO confirm before relying on it.
    box_valid: Optional[torch.Tensor] = None
|
|
|
|
def _pad_dynamics(batch: List[MapGSSample], device):
    """Pad per-sample dynamic boxes to a common instance count.

    Args:
        batch: Samples exposing ``box_centers`` (I, F, 3), ``box_rots``
            (I, F, 3, 3), ``box_size`` (I, 3) and ``canon_idx`` (I,), where
            the instance count ``I`` may differ per sample.
        device: Device on which the padded buffers are allocated. ``None``
            keeps torch's default device.

    Returns:
        ``None`` when the batch is empty or no sample has any box. Otherwise
        a dict with ``box_centers`` (B, I_max, F, 3), ``box_rots``
        (B, I_max, F, 3, 3), ``box_size`` (B, I_max, 3), ``canon_idx``
        (B, I_max) and a boolean ``valid`` (B, I_max) mask. Padding slots
        hold zero centers, identity rotations, unit sizes, index 0 and
        ``valid == False``.
    """
    counts = [sample.box_centers.shape[0] for sample in batch]
    max_count = max(counts, default=0)
    if max_count == 0:
        return None

    # Take the frame count from a sample that really has boxes: an empty
    # sample may carry a degenerate placeholder (e.g. shape (0,) or (0, 3))
    # whose second dimension is missing or is not the frame count.
    reference = batch[counts.index(max_count)]
    num_frames = reference.box_centers.shape[1]
    batch_size = len(batch)

    centers = torch.zeros(batch_size, max_count, num_frames, 3, device=device)
    rots = (
        torch.eye(3, device=device)
        .view(1, 1, 1, 3, 3)
        .repeat(batch_size, max_count, num_frames, 1, 1)
    )
    size = torch.ones(batch_size, max_count, 3, device=device)
    canon = torch.zeros(batch_size, max_count, dtype=torch.long, device=device)
    valid = torch.zeros(batch_size, max_count, dtype=torch.bool, device=device)

    for b, (sample, count) in enumerate(zip(batch, counts)):
        if count == 0:
            # Nothing to copy; the slot keeps its padding values.
            continue
        centers[b, :count] = sample.box_centers
        rots[b, :count] = sample.box_rots
        size[b, :count] = sample.box_size
        canon[b, :count] = sample.canon_idx
        valid[b, :count] = True

    return dict(
        box_centers=centers,
        box_rots=rots,
        box_size=size,
        canon_idx=canon,
        valid=valid,
    )
|
|
|
|
def collate_samples(batch: List[MapGSSample]) -> dict:
    """Collate samples into one batch dict.

    The context-view, supervision-view and anchor tensors are stacked along a
    new leading dimension. Ground fields, lanes, boundaries and scene ids are
    returned as per-sample lists. ``"dynamic"`` holds the result of
    ``_pad_dynamics`` (``None`` when no sample has boxes), and
    ``"scene_scale"`` is a 1-D tensor with one entry per sample.
    """
    # Attributes whose tensors are stacked; the order fixes the key order
    # of the returned dict.
    stacked_fields = (
        "ctx_images",
        "ctx_K",
        "ctx_c2w",
        "ctx_tids",
        "sup_images",
        "sup_K",
        "sup_c2w",
        "sup_depth",
        "sup_mono",
        "sup_frame",
        "anchor_pos",
        "anchor_type",
        "anchor_normal",
    )
    collated = {
        name: torch.stack([getattr(sample, name) for sample in batch], dim=0)
        for name in stacked_fields
    }

    # Per-scene objects stay as Python lists, one entry per sample.
    collated["grounds"] = [sample.ground for sample in batch]
    collated["lanes"] = [sample.lanes for sample in batch]
    collated["boundaries"] = [sample.boundaries for sample in batch]

    # Variable-instance box data is padded to a common instance count.
    collated["dynamic"] = _pad_dynamics(batch, None)

    collated["scene_ids"] = [sample.scene_id for sample in batch]
    collated["scene_scale"] = torch.tensor(
        [sample.scene_scale for sample in batch]
    )
    return collated
|
|