# Snapshot provenance (non-code text from the upload page, kept as a comment
# so the module parses): guoyb0 -- "Upload code snapshot (2task with caption)"
# -- commit 95f6448 (verified).
"""TopoMLP -> Atlas map token adapter.
Paper-aligned: Top-K selection from TopoMLP decoder outputs,
followed by a single linear projector (handled by AtlasUnifiedProjector).
No Perceiver resampler -- queries and ref_points are passed through directly.
Reference points are rotated from nuScenes ego frame (x=forward, y=left)
to the Atlas paper frame (x=right, y=forward) and normalized with the same
PC_RANGE as the detection branch so that the shared projector_rp receives
a unified coordinate space.
"""
from __future__ import annotations
from typing import Dict, Optional, Tuple
import torch
# Detection-branch point-cloud range (x_min, y_min, z_min, x_max, y_max, z_max),
# presumably in meters (nuScenes convention) -- shared with the detection branch
# so map reference points are normalized into the same coordinate space.
_DET_PC_RANGE = (-50.0, -50.0, -5.0, 50.0, 50.0, 3.0)
def _lane_control_points_to_center_xyz(lane_preds: torch.Tensor) -> torch.Tensor:
if lane_preds.ndim != 3 or lane_preds.shape[-1] % 3 != 0:
raise ValueError(f"lane_preds expected [B,Q,3*K], got {tuple(lane_preds.shape)}")
B, Q, D = lane_preds.shape
K = D // 3
pts = lane_preds.view(B, Q, K, 3)
return pts.mean(dim=2)
def _nuscenes_ego_to_paper(ref: torch.Tensor) -> torch.Tensor:
"""Rotate nuScenes ego coords (x=fwd, y=left) to paper frame (x=right, y=fwd)."""
out = ref.clone()
out[..., 0] = -ref[..., 1]
out[..., 1] = ref[..., 0]
return out
def _normalize_xyz(xyz: torch.Tensor, xyz_min: torch.Tensor, xyz_max: torch.Tensor) -> torch.Tensor:
denom = (xyz_max - xyz_min).clamp(min=1e-6)
out = (xyz - xyz_min) / denom
return out.clamp(0.0, 1.0)
class TopoMLPToAtlasMapTokens(torch.nn.Module):
    """Select top-K lane queries from TopoMLP and return them with reference points.

    Aligned with Atlas paper Section 3.1:
        "these queries are streamlined through a single linear layer"
    The linear projection itself is in AtlasUnifiedProjector.projector_map.
    This module only does Top-K selection + ref_point computation.
    """

    def __init__(
        self,
        num_map_tokens: int = 256,
        hidden_size: int = 256,
        bev_range: Tuple[float, float, float, float, float, float] = (-51.2, -25.6, -8.0, 51.2, 25.6, 4.0),
        **kwargs,
    ):
        """
        Args:
            num_map_tokens: Number of map tokens to emit; output is zero-padded
                when fewer lane queries are available.
            hidden_size: Expected lane-token channel width (stored for reference;
                forward() uses the incoming tensor's actual width).
            bev_range: BEV extent of the map branch. NOTE(review): stored but not
                used for normalization -- ref points deliberately use the shared
                _DET_PC_RANGE so projector_rp sees one unified coordinate space.
            **kwargs: Ignored; absorbs config keys intended for other adapters.
        """
        super().__init__()
        self.num_map_tokens = int(num_map_tokens)
        self.hidden_size = int(hidden_size)
        self.bev_range = tuple(float(x) for x in bev_range)
        # Non-persistent buffers so the bounds follow .to(device)/.cuda() but
        # are not written into checkpoints.
        rp_min = torch.tensor(_DET_PC_RANGE[:3], dtype=torch.float32)
        rp_max = torch.tensor(_DET_PC_RANGE[3:], dtype=torch.float32)
        self.register_buffer("_rp_min", rp_min, persistent=False)
        self.register_buffer("_rp_max", rp_max, persistent=False)

    @torch.no_grad()
    def infer_lane_centers_from_outs(self, outs: Dict) -> torch.Tensor:
        """Return [B, Q, 3] lane centers from the final decoder layer's control points.

        Runs under no_grad: reference points are positional conditioning and do
        not propagate gradients back into the lane regression head.
        """
        one2one_preds = outs["all_lc_preds_list"][-1]
        return _lane_control_points_to_center_xyz(one2one_preds)

    @staticmethod
    def _scores_per_query(cls_scores: torch.Tensor) -> torch.Tensor:
        """Reduce classification scores to one scalar per query -> [B, Q].

        Generalizes the original `squeeze(-1)`, which assumed [B, Q, 1] and
        silently misbehaved on [B, Q] with Q == 1 (squeezed the query dim) or
        [B, Q, C] with C > 1 (left a 3-D tensor that breaks topk/gather below).
        Accepts:
            [B, Q]    -> returned unchanged,
            [B, Q, 1] -> equivalent to squeeze(-1),
            [B, Q, C] -> max over the class dim (monotonic w.r.t. logits,
                         so Top-K ordering is meaningful).
        """
        if cls_scores.ndim == 3:
            return cls_scores.max(dim=-1).values
        if cls_scores.ndim == 2:
            return cls_scores
        raise ValueError(f"cls_scores expected [B,Q] or [B,Q,C], got {tuple(cls_scores.shape)}")

    def forward(self, outs: Dict) -> Dict[str, torch.Tensor]:
        """Build map tokens + reference points from TopoMLP decoder outputs.

        Args:
            outs: TopoMLP head output dict. Reads:
                - "lc_outs_dec_list"[-1]: lane query embeddings [B, Q, D]
                - "all_lc_cls_scores_list"[-1]: per-query scores
                  ([B, Q, 1] originally; see _scores_per_query)
                - "all_lc_preds_list"[-1]: lane control points [B, Q, 3*K]

        Returns:
            {"map": [B, num_map_tokens, D] top-scoring lane tokens (zero-padded),
             "map_ref_points": [B, num_map_tokens, 3] paper-frame centers
             normalized to [0, 1] (zero-padded)}.
        """
        lane_tokens = outs["lc_outs_dec_list"][-1]
        lane_scores = self._scores_per_query(outs["all_lc_cls_scores_list"][-1])
        lane_centers = self.infer_lane_centers_from_outs(outs)
        # Rotate to the paper frame, then normalize with the shared detection
        # PC_RANGE so projector_rp receives a unified coordinate space.
        lane_centers_paper = _nuscenes_ego_to_paper(lane_centers)
        lane_ref_norm = _normalize_xyz(lane_centers_paper, self._rp_min, self._rp_max)
        B, N, D = lane_tokens.shape
        if N == 0:
            # No lane queries at all: emit all-zero tokens and ref points.
            return {
                "map": torch.zeros(B, self.num_map_tokens, D, dtype=lane_tokens.dtype, device=lane_tokens.device),
                "map_ref_points": torch.zeros(B, self.num_map_tokens, 3, dtype=lane_ref_norm.dtype, device=lane_ref_norm.device),
            }
        k = min(self.num_map_tokens, N)
        topk_idx = torch.topk(lane_scores, k=k, dim=1, largest=True, sorted=True).indices
        # Broadcast the [B, k] index across the channel dims for gather.
        tok_idx = topk_idx.unsqueeze(-1).expand(-1, -1, D)
        ref_idx = topk_idx.unsqueeze(-1).expand(-1, -1, 3)
        map_tokens = lane_tokens.gather(dim=1, index=tok_idx)
        map_ref = lane_ref_norm.gather(dim=1, index=ref_idx)
        if k < self.num_map_tokens:
            # Fewer queries than requested tokens: right-pad with zeros.
            pad_t = torch.zeros(B, self.num_map_tokens - k, D, dtype=map_tokens.dtype, device=map_tokens.device)
            pad_r = torch.zeros(B, self.num_map_tokens - k, 3, dtype=map_ref.dtype, device=map_ref.device)
            map_tokens = torch.cat([map_tokens, pad_t], dim=1)
            map_ref = torch.cat([map_ref, pad_r], dim=1)
        return {"map": map_tokens, "map_ref_points": map_ref}