# Atlas-online / src/model/streampetr_adapter.py
# Uploaded by guoyb0 via upload-large-folder tool (commit 9fe982a, verified).
"""
StreamPETR -> Atlas detection token adapter WITHOUT modifying StreamPETR source code.
Rationale:
- StreamPETRHead updates internal memory with Top-K proposals (topk_proposals, typically 256).
- After a forward pass, the head stores:
- self.memory_embedding: [B, memory_len + topk_proposals, D] (prepend new Top-K)
- self.memory_reference_point: [B, memory_len + topk_proposals, 3]
(exact ordering: new Top-K are concatenated in front)
IMPORTANT: after post_update_memory, memory_reference_point is in the **GLOBAL**
coordinate frame (ego_pose applied). We must invert the ego_pose to bring
ref_points back to ego frame before normalizing with pc_range.
"""
from __future__ import annotations
from typing import Any, Dict, Optional, Tuple
import torch
# Point-cloud range (x_min, y_min, z_min, x_max, y_max, z_max) in meters.
# Matches the standard StreamPETR nuScenes configuration; used to normalize
# reference points into [0, 1]^3.
PC_RANGE = (-51.2, -51.2, -5.0, 51.2, 51.2, 3.0)
def _normalize_ref_points(
    ref: torch.Tensor,
    pc_range: Tuple[float, float, float, float, float, float] = PC_RANGE,
) -> torch.Tensor:
    """Rescale ego-frame points into the unit cube using ``pc_range``.

    Args:
        ref: [..., 3] coordinates in meters.
        pc_range: (x_min, y_min, z_min, x_max, y_max, z_max).

    Returns:
        Tensor of the same shape with each axis linearly mapped to [0, 1]
        and clamped at the boundaries.
    """
    lower = ref.new_tensor(pc_range[:3])
    upper = ref.new_tensor(pc_range[3:])
    # Clamp the extent to avoid division by zero for a degenerate range.
    extent = (upper - lower).clamp(min=1e-6)
    scaled = (ref - lower) / extent
    return scaled.clamp(0.0, 1.0)
def _global_to_ego(ref: torch.Tensor, ego_pose: torch.Tensor) -> torch.Tensor:
"""Transform reference points from global frame back to ego frame.
Args:
ref: [B, N, 3] in global coordinates
ego_pose: [B, 4, 4] ego-to-global transform
Returns:
[B, N, 3] in ego coordinates
"""
B, N, _ = ref.shape
ones = torch.ones(B, N, 1, device=ref.device, dtype=ref.dtype)
ref_homo = torch.cat([ref, ones], dim=-1) # [B, N, 4]
ego_pose_inv = torch.inverse(ego_pose) # [B, 4, 4]
ref_ego = (ego_pose_inv.unsqueeze(1) @ ref_homo.unsqueeze(-1)).squeeze(-1)[..., :3]
return ref_ego
def _nuscenes_ego_to_paper(ref: torch.Tensor) -> torch.Tensor:
"""Convert nuScenes ego coords to Atlas paper frame.
nuScenes ego uses x=forward, y=left. Atlas detection QA uses
x=right, y=forward, so (x_p, y_p) = (-y_n, x_n).
"""
ref_paper = ref.clone()
ref_paper[..., 0] = -ref[..., 1]
ref_paper[..., 1] = ref[..., 0]
return ref_paper
@torch.no_grad()
def extract_streampetr_topk_tokens(
    pts_bbox_head: Any,
    topk: int = 256,
    pc_range: Tuple[float, float, float, float, float, float] = PC_RANGE,
    ego_pose: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
    """Export the newest Top-K memory tokens from a StreamPETR head.

    Reads the head's memory buffers (populated by a prior forward pass);
    the newest Top-K proposals are concatenated at the front, so the first
    ``topk`` entries are the current frame's detections.

    Args:
        pts_bbox_head: the StreamPETRHead instance (model.pts_bbox_head)
        topk: number of tokens to export; should match pts_bbox_head.topk_proposals
        pc_range: point cloud range used by StreamPETR, for normalizing ref_points
        ego_pose: [B, 4, 4] ego-to-global transform. If provided, ref_points are
            transformed back from global to ego frame (post_update_memory stores
            them in the global frame) and rotated into the Atlas paper frame
            before normalization.

    Returns:
        dict:
            - detection: [B, topk, D]
            - detection_ref_points: [B, topk, 3] (normalized to [0, 1];
              if ego_pose is provided, aligned to Atlas paper frame)

    Raises:
        RuntimeError: if the memory buffers are missing/unset, have unexpected
            shapes, hold fewer than ``topk`` tokens, or ``ego_pose`` is malformed.
    """
    if not hasattr(pts_bbox_head, "memory_embedding") or not hasattr(pts_bbox_head, "memory_reference_point"):
        raise RuntimeError("pts_bbox_head missing memory buffers; ensure you have run a forward pass first.")
    mem = pts_bbox_head.memory_embedding
    ref = pts_bbox_head.memory_reference_point
    if mem is None or ref is None:
        raise RuntimeError("pts_bbox_head memory is None; ensure you have run a forward pass and prev_exists is set.")
    if mem.ndim != 3 or ref.ndim != 3 or ref.shape[-1] != 3:
        raise RuntimeError(f"unexpected shapes: memory_embedding={getattr(mem,'shape',None)} memory_reference_point={getattr(ref,'shape',None)}")
    B = mem.shape[0]
    if mem.shape[1] < topk or ref.shape[1] < topk:
        raise RuntimeError(f"memory length too small: mem_len={mem.shape[1]} ref_len={ref.shape[1]} topk={topk}")
    # Newest Top-K proposals are prepended by post_update_memory, so slicing
    # the front yields the current frame's detections.
    det = mem[:, :topk, :].contiguous()
    det_ref = ref[:, :topk, :].contiguous()
    # post_update_memory transforms ref_points to global frame via ego_pose.
    # We invert this to get ego-frame coordinates, then rotate to Atlas paper
    # frame so projector_rp sees the same XY semantics as detection QA/GT.
    if ego_pose is not None:
        # Fail fast on a malformed pose instead of surfacing a cryptic
        # broadcast/matmul error from inside the transform helpers.
        if ego_pose.ndim != 3 or ego_pose.shape[0] != B or tuple(ego_pose.shape[1:]) != (4, 4):
            raise RuntimeError(f"ego_pose must be [B, 4, 4] with B={B}; got shape={tuple(ego_pose.shape)}")
        det_ref = _global_to_ego(det_ref, ego_pose)
        det_ref = _nuscenes_ego_to_paper(det_ref)
    det_ref = _normalize_ref_points(det_ref, pc_range)
    return {"detection": det, "detection_ref_points": det_ref}