| """TopoMLP -> Atlas map token adapter. |
| |
| Paper-aligned: Top-K selection from TopoMLP decoder outputs, |
| followed by a single linear projector (handled by AtlasUnifiedProjector). |
| No Perceiver resampler -- queries and ref_points are passed through directly. |
| |
| Reference points are rotated from nuScenes ego frame (x=forward, y=left) |
| to the Atlas paper frame (x=right, y=forward) and normalized with the same |
| PC_RANGE as the detection branch so that the shared projector_rp receives |
| a unified coordinate space. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from typing import Dict, Optional, Tuple |
|
|
| import torch |
|
|
# Detection-branch point-cloud range (x_min, y_min, z_min, x_max, y_max, z_max).
# Shared with the detection branch so map and detection reference points are
# normalized into the same unit cube for the shared projector_rp (see module docstring).
_DET_PC_RANGE = (-50.0, -50.0, -5.0, 50.0, 50.0, 3.0)
|
|
|
|
| def _lane_control_points_to_center_xyz(lane_preds: torch.Tensor) -> torch.Tensor: |
| if lane_preds.ndim != 3 or lane_preds.shape[-1] % 3 != 0: |
| raise ValueError(f"lane_preds expected [B,Q,3*K], got {tuple(lane_preds.shape)}") |
| B, Q, D = lane_preds.shape |
| K = D // 3 |
| pts = lane_preds.view(B, Q, K, 3) |
| return pts.mean(dim=2) |
|
|
|
|
| def _nuscenes_ego_to_paper(ref: torch.Tensor) -> torch.Tensor: |
| """Rotate nuScenes ego coords (x=fwd, y=left) to paper frame (x=right, y=fwd).""" |
| out = ref.clone() |
| out[..., 0] = -ref[..., 1] |
| out[..., 1] = ref[..., 0] |
| return out |
|
|
|
|
| def _normalize_xyz(xyz: torch.Tensor, xyz_min: torch.Tensor, xyz_max: torch.Tensor) -> torch.Tensor: |
| denom = (xyz_max - xyz_min).clamp(min=1e-6) |
| out = (xyz - xyz_min) / denom |
| return out.clamp(0.0, 1.0) |
|
|
|
|
class TopoMLPToAtlasMapTokens(torch.nn.Module):
    """Select top-K lane queries from TopoMLP and return them with reference points.

    Aligned with Atlas paper Section 3.1:
        "these queries are streamlined through a single linear layer"
    The linear projection itself is in AtlasUnifiedProjector.projector_map.
    This module only does Top-K selection + ref_point computation.
    """

    def __init__(
        self,
        num_map_tokens: int = 256,
        hidden_size: int = 256,
        bev_range: Tuple[float, float, float, float, float, float] = (-51.2, -25.6, -8.0, 51.2, 25.6, 4.0),
        **kwargs,
    ):
        """Configure the Top-K selector.

        Args:
            num_map_tokens: fixed number of map tokens emitted per sample
                (zero-padded when fewer lane queries are available).
            hidden_size: lane-token channel width; stored but not used in the
                computation here (kept for config symmetry).
            bev_range: map-branch BEV range; stored for introspection only.
                Normalization deliberately uses _DET_PC_RANGE instead so the
                shared projector_rp sees one coordinate space (module docstring).
            **kwargs: ignored; tolerates extra config keys.
        """
        super().__init__()
        self.num_map_tokens = int(num_map_tokens)
        self.hidden_size = int(hidden_size)
        self.bev_range = tuple(float(x) for x in bev_range)

        # Non-persistent buffers: follow the module across .to(device)/.cuda()
        # moves but stay out of state_dict checkpoints.
        rp_min = torch.tensor(_DET_PC_RANGE[:3], dtype=torch.float32)
        rp_max = torch.tensor(_DET_PC_RANGE[3:], dtype=torch.float32)
        self.register_buffer("_rp_min", rp_min, persistent=False)
        self.register_buffer("_rp_max", rp_max, persistent=False)

    @torch.no_grad()
    def infer_lane_centers_from_outs(self, outs: Dict) -> torch.Tensor:
        """Return per-query lane centers [B, Q, 3] from the last decoder layer.

        no_grad is deliberate: reference points are treated as geometry, not a
        gradient path back into the lane head.
        """
        one2one_preds = outs["all_lc_preds_list"][-1]
        return _lane_control_points_to_center_xyz(one2one_preds)

    def forward(self, outs: Dict) -> Dict[str, torch.Tensor]:
        """Pick the top-scoring lane queries and their normalized ref points.

        Args:
            outs: TopoMLP head output dict. Reads "lc_outs_dec_list" (decoder
                tokens, last layer [B, N, D]), "all_lc_cls_scores_list"
                (classification scores) and, via infer_lane_centers_from_outs,
                "all_lc_preds_list".

        Returns:
            Dict with:
                "map": [B, num_map_tokens, D] selected (zero-padded) lane tokens.
                "map_ref_points": [B, num_map_tokens, 3] paper-frame, [0, 1]-
                    normalized lane centers aligned with "map".
        """
        lane_tokens = outs["lc_outs_dec_list"][-1]
        lane_scores = outs["all_lc_cls_scores_list"][-1]
        if lane_scores.ndim == 3:
            # Robustness fix: the previous `.squeeze(-1)` was a silent no-op for
            # a multi-class head (C > 1), leaving a 3-D tensor that broke the
            # top-k index expansion below. max() over the class dim is identical
            # to squeeze for C == 1 and ranks by best class score otherwise.
            lane_scores = lane_scores.max(dim=-1).values

        lane_centers = self.infer_lane_centers_from_outs(outs)
        lane_centers_paper = _nuscenes_ego_to_paper(lane_centers)
        lane_ref_norm = _normalize_xyz(lane_centers_paper, self._rp_min, self._rp_max)

        B, N, D = lane_tokens.shape
        if N == 0:
            # Degenerate input: emit all-zero tokens/ref points at the fixed size.
            return {
                "map": torch.zeros(B, self.num_map_tokens, D, dtype=lane_tokens.dtype, device=lane_tokens.device),
                "map_ref_points": torch.zeros(B, self.num_map_tokens, 3, dtype=lane_ref_norm.dtype, device=lane_ref_norm.device),
            }

        k = min(self.num_map_tokens, N)
        topk_idx = torch.topk(lane_scores, k=k, dim=1, largest=True, sorted=True).indices
        # Broadcast the [B, k] indices across channels so gather selects whole rows.
        tok_idx = topk_idx.unsqueeze(-1).expand(-1, -1, D)
        ref_idx = topk_idx.unsqueeze(-1).expand(-1, -1, 3)
        map_tokens = lane_tokens.gather(dim=1, index=tok_idx)
        map_ref = lane_ref_norm.gather(dim=1, index=ref_idx)

        if k < self.num_map_tokens:
            # Zero-pad to the fixed token budget so downstream shapes are static.
            pad_t = torch.zeros(B, self.num_map_tokens - k, D, dtype=map_tokens.dtype, device=map_tokens.device)
            pad_r = torch.zeros(B, self.num_map_tokens - k, 3, dtype=map_ref.dtype, device=map_ref.device)
            map_tokens = torch.cat([map_tokens, pad_t], dim=1)
            map_ref = torch.cat([map_ref, pad_r], dim=1)

        return {"map": map_tokens, "map_ref_points": map_ref}
|
|