mapvggt / mapgs /data /types.py
ChenmingWu's picture
Upload folder using huggingface_hub
b2efbe4 verified
Raw
History Blame Contribute Delete
3.88 kB
"""Sample / batch containers and collation.
A :class:`MapGSSample` is one posed driving clip + its map (§2.1). The model's
encoder only consumes the *context* views; supervision views are render targets.
Per-scene objects (ground field, polylines, variable-instance dynamics) are kept
as Python lists in the batch, while fixed-shape tensors (context views, anchors,
sampled supervision views) are stacked for the batched encoder/decoder.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import List, Optional
import torch
from mapgs.hdmap.ground_field import GridGroundField
@dataclass
class MapGSSample:
    """One posed driving clip together with its map.

    Fields fall into five groups: context views fed to the encoder,
    supervision views used as render targets, map anchors, per-scene map
    objects, and dynamic-actor boxes.

    Shape legend used in the field comments: ``Vc`` context views, ``Ns``
    supervision views, ``Na`` map anchors, ``I`` dynamic instances (may be 0),
    ``F`` clip frames, ``H`` x ``W`` image size.
    """

    # ---- context views (encoder inputs) ----
    ctx_images: torch.Tensor  # [Vc, 3, H, W]
    ctx_K: torch.Tensor  # [Vc, 3, 3]
    ctx_c2w: torch.Tensor  # [Vc, 4, 4]
    ctx_tids: torch.Tensor  # [Vc] long; context-timestep index

    # ---- supervision views (render targets + losses) ----
    sup_images: torch.Tensor  # [Ns, 3, H, W]
    sup_K: torch.Tensor  # [Ns, 3, 3]
    sup_c2w: torch.Tensor  # [Ns, 4, 4]
    sup_depth: torch.Tensor  # [Ns, H, W] GT depth for eval D-RMSE; -1 marks unknown
    sup_mono: torch.Tensor  # [Ns, H, W] frozen mono-depth prior (L_vert)
    sup_frame: torch.Tensor  # [Ns] long; clip frame index used for dynamic placement

    # ---- map anchors (MAGT) ----
    anchor_pos: torch.Tensor  # [Na, 3]
    anchor_type: torch.Tensor  # [Na] long
    anchor_normal: torch.Tensor  # [Na, 3]

    # ---- per-scene map objects (kept as Python objects, never stacked) ----
    ground: GridGroundField
    lanes: List[torch.Tensor]
    boundaries: List[torch.Tensor]

    # ---- dynamics: I instances tracked over F frames (I may be 0) ----
    box_centers: torch.Tensor  # [I, F, 3]
    box_rots: torch.Tensor  # [I, F, 3, 3]
    box_size: torch.Tensor  # [I, 3]
    canon_idx: torch.Tensor  # [I] long

    # ---- optional / defaulted metadata ----
    scene_id: int = 0
    scene_scale: float = 1.0
    box_valid: Optional[torch.Tensor] = None  # [I, F] bool; actor tracked at that frame
def _pad_dynamics(batch: List[MapGSSample], device):
I_max = max(s.box_centers.shape[0] for s in batch)
if I_max == 0:
return None
F = batch[0].box_centers.shape[1]
B = len(batch)
centers = torch.zeros(B, I_max, F, 3)
rots = torch.eye(3).view(1, 1, 1, 3, 3).repeat(B, I_max, F, 1, 1)
size = torch.ones(B, I_max, 3)
canon = torch.zeros(B, I_max, dtype=torch.long)
valid = torch.zeros(B, I_max, dtype=torch.bool)
for b, s in enumerate(batch):
I = s.box_centers.shape[0]
if I == 0:
continue
centers[b, :I] = s.box_centers
rots[b, :I] = s.box_rots
size[b, :I] = s.box_size
canon[b, :I] = s.canon_idx
valid[b, :I] = True
return dict(box_centers=centers, box_rots=rots, box_size=size, canon_idx=canon, valid=valid)
def collate_samples(batch: List[MapGSSample]) -> dict:
    """Collate a list of samples into one batch dict.

    Fixed-shape tensor fields are stacked along a new leading batch axis.
    Per-scene objects (ground field, lanes, boundaries) and scene ids stay as
    Python lists with one entry per sample. Dynamic-actor boxes are padded by
    ``_pad_dynamics`` and stored under ``"dynamic"``; ``scene_scale`` becomes
    a 1-D tensor with one value per sample.
    """
    # Fields whose tensors are stacked as-is; the order fixes the dict key order.
    stacked_fields = (
        "ctx_images",
        "ctx_K",
        "ctx_c2w",
        "ctx_tids",
        "sup_images",
        "sup_K",
        "sup_c2w",
        "sup_depth",
        "sup_mono",
        "sup_frame",
        "anchor_pos",
        "anchor_type",
        "anchor_normal",
    )
    collated = {
        name: torch.stack([getattr(sample, name) for sample in batch], 0)
        for name in stacked_fields
    }

    # Variable-structure map objects are passed through as per-sample lists.
    collated["grounds"] = [sample.ground for sample in batch]
    collated["lanes"] = [sample.lanes for sample in batch]
    collated["boundaries"] = [sample.boundaries for sample in batch]

    collated["dynamic"] = _pad_dynamics(batch, None)
    collated["scene_ids"] = [sample.scene_id for sample in batch]
    collated["scene_scale"] = torch.tensor([sample.scene_scale for sample in batch])
    return collated