| """ |
| StreamPETR -> Atlas detection token adapter WITHOUT modifying StreamPETR source code. |
| |
Rationale:
- StreamPETRHead updates its internal memory with the Top-K proposals of each
  frame (``topk_proposals``, typically 256).
- After a forward pass, the head stores:
    - self.memory_embedding:       [B, memory_len + topk_proposals, D]
    - self.memory_reference_point: [B, memory_len + topk_proposals, 3]
  The new Top-K entries are concatenated in FRONT of the older memory, so the
  first ``topk_proposals`` rows along dim 1 are the current frame's proposals.

IMPORTANT: after post_update_memory, memory_reference_point is in the **GLOBAL**
coordinate frame (ego_pose applied). We must apply the inverse of ego_pose to
bring the reference points back to the ego frame before normalizing them with
pc_range.
| """ |
|
|
| from __future__ import annotations |
|
|
| from typing import Any, Dict, Optional, Tuple |
|
|
| import torch |
|
|
# Detection range as (x_min, y_min, z_min, x_max, y_max, z_max): the first three
# entries are the lower corner and the last three the upper corner used by
# _normalize_ref_points. Presumably metres in the ego frame -- confirm against
# the StreamPETR config in use.
PC_RANGE = (
    -51.2, -51.2, -5.0,  # lower corner (x, y, z)
    51.2, 51.2, 3.0,  # upper corner (x, y, z)
)
|
|
|
|
def _normalize_ref_points(
    ref: torch.Tensor,
    pc_range: Tuple[float, float, float, float, float, float] = PC_RANGE,
) -> torch.Tensor:
    """Scale points into the unit box described by ``pc_range``.

    Args:
        ref: tensor whose last dimension holds (x, y, z).
        pc_range: (x_min, y_min, z_min, x_max, y_max, z_max).

    Returns:
        Tensor shaped like ``ref``, with each coordinate mapped linearly so that
        the lower corner goes to 0 and the upper corner to 1. Values outside the
        range are clipped into [0, 1].
    """
    lower, upper = ref.new_tensor(pc_range[:3]), ref.new_tensor(pc_range[3:])
    # Keep the per-axis width strictly positive so the division is always finite.
    extent = torch.clamp(upper - lower, min=1e-6)
    unit = (ref - lower) / extent
    return torch.clamp(unit, min=0.0, max=1.0)
|
|
|
|
def _global_to_ego(ref: torch.Tensor, ego_pose: torch.Tensor) -> torch.Tensor:
    """Transform reference points from the global frame back to the ego frame.

    Args:
        ref: [B, N, 3] points in global coordinates.
        ego_pose: [B, 4, 4] ego-to-global homogeneous transform, or a single
            [4, 4] transform that is shared by every sample in the batch.
            It may differ from ``ref`` in dtype and device.

    Returns:
        [B, N, 3] points in ego coordinates, with ``ref``'s dtype and device.

    Raises:
        ValueError: if ``ref`` is not [B, N, 3] or ``ego_pose`` is not
            [B, 4, 4] / [4, 4].
    """
    if ref.dim() != 3 or ref.shape[-1] != 3:
        raise ValueError(f"ref must be [B, N, 3], got {tuple(ref.shape)}")
    if ego_pose.dim() == 2:
        # A single pose: add a batch axis so it broadcasts over B below.
        # (Previously unsqueeze(1) turned [4, 4] into [4, 1, 4] and mis-broadcast.)
        ego_pose = ego_pose.unsqueeze(0)
    if ego_pose.dim() != 3 or ego_pose.shape[-2:] != (4, 4):
        raise ValueError(f"ego_pose must be [B, 4, 4] or [4, 4], got {tuple(ego_pose.shape)}")

    # Invert in the pose's own precision (e.g. float64 poses keep their
    # precision), then align with the points. A mixed dtype/device pair would
    # otherwise make the matmul below fail.
    pose_inv = torch.inverse(ego_pose).to(device=ref.device, dtype=ref.dtype)
    rot_inv = pose_inv[:, :3, :3]  # [B, 3, 3]
    trans_inv = pose_inv[:, :3, 3]  # [B, 3]
    # Row-vector form of p_ego = R_inv @ p_global + t_inv; equivalent to the
    # homogeneous product with the bottom row dropped.
    return ref @ rot_inv.transpose(-1, -2) + trans_inv.unsqueeze(1)
|
|
|
|
def _nuscenes_ego_to_paper(ref: torch.Tensor) -> torch.Tensor:
    """Rotate nuScenes ego coordinates into the Atlas paper frame.

    nuScenes ego uses x=forward, y=left, while the Atlas detection QA frame uses
    x=right, y=forward. The planar axes are therefore remapped as
    (x_p, y_p) = (-y_n, x_n). Every remaining channel (z, ...) is copied
    unchanged. The input tensor is not modified; a new tensor is returned.
    """
    swapped_xy = torch.stack((-ref[..., 1], ref[..., 0]), dim=-1)
    out = ref.clone()
    out[..., :2] = swapped_xy
    return out
|
|
|
|
def _paper_pc_range(
    pc_range: Tuple[float, float, float, float, float, float],
) -> Tuple[float, float, float, float, float, float]:
    """Express a nuScenes-ego ``pc_range`` in the Atlas paper frame.

    The paper frame maps (x_p, y_p) = (-y_n, x_n), so x_p spans
    [-y_max, -y_min] and y_p spans [x_min, x_max]; z is unchanged.
    For a range that is symmetric in x/y this returns the same values.
    """
    x_min, y_min, z_min, x_max, y_max, z_max = pc_range
    return (-y_max, x_min, z_min, -y_min, x_max, z_max)


@torch.no_grad()
def extract_streampetr_topk_tokens(
    pts_bbox_head: Any,
    topk: int = 256,
    pc_range: Tuple[float, float, float, float, float, float] = PC_RANGE,
    ego_pose: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
    """Export the newest Top-K StreamPETR memory entries as Atlas detection tokens.

    Reads the head's ``memory_embedding`` / ``memory_reference_point`` buffers
    and takes their first ``topk`` entries along dim 1, which is where the
    newest proposals are stored.

    Args:
        pts_bbox_head: the StreamPETRHead instance (model.pts_bbox_head).
        topk: number of tokens to export; should match
            pts_bbox_head.topk_proposals. Must be non-negative.
        pc_range: point cloud range used by StreamPETR, given in the nuScenes
            ego frame as (x_min, y_min, z_min, x_max, y_max, z_max). It is used
            to normalize ref_points. When ``ego_pose`` is given, the range is
            remapped into the paper frame together with the points.
        ego_pose: [B, 4, 4] ego-to-global transform. If provided, ref_points are
            transformed back from the global to the ego frame, then rotated into
            the Atlas paper frame before normalization. If omitted, the points
            are normalized in whatever frame the head stores them. After
            post_update_memory that is the global frame, which is usually not
            what you want.

    Returns:
        dict:
            - detection: [B, topk, D], an independent copy of the memory
              embeddings.
            - detection_ref_points: [B, topk, 3], normalized to [0, 1]. If
              ego_pose is provided, they are aligned to the Atlas paper frame.

    Raises:
        ValueError: if ``topk`` is negative.
        RuntimeError: if the memory buffers are missing, None, mis-shaped,
            disagree on batch size, or hold fewer than ``topk`` entries.
    """
    # A negative topk would pass the length check below and silently slice
    # everything except the last |topk| entries.
    if topk < 0:
        raise ValueError(f"topk must be non-negative, got {topk}")

    if not hasattr(pts_bbox_head, "memory_embedding") or not hasattr(pts_bbox_head, "memory_reference_point"):
        raise RuntimeError("pts_bbox_head missing memory buffers; ensure you have run a forward pass first.")

    mem = pts_bbox_head.memory_embedding
    ref = pts_bbox_head.memory_reference_point
    if mem is None or ref is None:
        raise RuntimeError("pts_bbox_head memory is None; ensure you have run a forward pass and prev_exists is set.")

    if mem.ndim != 3 or ref.ndim != 3 or ref.shape[-1] != 3:
        raise RuntimeError(f"unexpected shapes: memory_embedding={getattr(mem,'shape',None)} memory_reference_point={getattr(ref,'shape',None)}")
    if mem.shape[0] != ref.shape[0]:
        raise RuntimeError(
            f"batch size mismatch: memory_embedding={tuple(mem.shape)} "
            f"memory_reference_point={tuple(ref.shape)}"
        )
    if mem.shape[1] < topk or ref.shape[1] < topk:
        raise RuntimeError(f"memory length too small: mem_len={mem.shape[1]} ref_len={ref.shape[1]} topk={topk}")

    # Use clone rather than .contiguous(): when the slice is already contiguous
    # (e.g. B == 1), .contiguous() returns a view that aliases the head's live
    # memory buffer.
    det = mem[:, :topk, :].clone(memory_format=torch.contiguous_format)
    det_ref = ref[:, :topk, :].contiguous()

    if ego_pose is not None:
        det_ref = _global_to_ego(det_ref, ego_pose)
        det_ref = _nuscenes_ego_to_paper(det_ref)
        # The x/y axes were permuted and negated, so the normalization range
        # must follow. This is a no-op for the symmetric default PC_RANGE.
        pc_range = _paper_pc_range(pc_range)

    det_ref = _normalize_ref_points(det_ref, pc_range)
    return {"detection": det, "detection_ref_points": det_ref}
|
|