| """BEVFromer model implementation. |
| |
| This file composes the operations associated with BEVFormer |
| `https://arxiv.org/abs/2203.17270` into the full model implementation. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import copy |
| from typing import TypedDict |
|
|
| import torch |
| from torch import Tensor, nn |
|
|
| from vis4d.common.ckpt import load_model_checkpoint |
| from vis4d.op.base import BaseModel |
| from vis4d.op.detect3d.bevformer import BEVFormerHead, GridMask |
| from vis4d.op.detect3d.common import Detect3DOut |
| from vis4d.op.fpp.fpn import FPN, ExtraFPNBlock |
|
|
| REV_KEYS = [ |
| (r"^img_backbone\.", "basemodel."), |
| (r"^img_neck.lateral_convs\.", "fpn.inner_blocks."), |
| (r"^img_neck.fpn_convs\.", "fpn.layer_blocks."), |
| (r"^fpn.layer_blocks.3\.", "fpn.extra_blocks.convs.0."), |
| (r"\.conv.weight", ".weight"), |
| (r"\.conv.bias", ".bias"), |
| ] |
|
|
|
|
| class PrevFrameInfo(TypedDict): |
| """Previous frame information.""" |
|
|
| scene_name: str |
| prev_bev: Tensor | None |
| prev_pos: Tensor |
| prev_angle: Tensor |
|
|
|
|
| class BEVFormer(nn.Module): |
| """BEVFormer 3D Detector.""" |
|
|
| def __init__( |
| self, |
| basemodel: BaseModel, |
| fpn: FPN | None = None, |
| pts_bbox_head: BEVFormerHead | None = None, |
| weights: str | None = None, |
| ) -> None: |
| """Creates an instance of the class. |
| |
| Args: |
| basemodel (BaseModel): Base model network. |
| fpn (FPN, optional): Feature Pyramid Network. Defaults to None. If |
| None, a default FPN will be used. |
| pts_bbox_head (BEVFormerHead, optional): BEVFormer head. Defaults |
| to None. If None, a default BEVFormer head will be used. |
| weights (str, optional): Path to the checkpoint to load. Defaults |
| to None. |
| """ |
| super().__init__() |
| self.basemodel = basemodel |
| self.fpn = fpn or FPN( |
| self.basemodel.out_channels[3:], |
| 256, |
| extra_blocks=ExtraFPNBlock( |
| extra_levels=1, in_channels=256, out_channels=256 |
| ), |
| start_index=3, |
| ) |
|
|
| self.grid_mask = GridMask( |
| True, True, rotate=1, offset=False, ratio=0.5, mode=1, prob=0.7 |
| ) |
|
|
| self.pts_bbox_head = pts_bbox_head or BEVFormerHead() |
|
|
| |
| self.prev_frame_info = PrevFrameInfo( |
| scene_name="", |
| prev_bev=None, |
| prev_pos=torch.zeros(3), |
| prev_angle=torch.zeros(1), |
| ) |
|
|
| if weights is not None: |
| load_model_checkpoint(self, weights, rev_keys=REV_KEYS) |
|
|
| def extract_feat(self, images_list: list[Tensor]) -> list[Tensor]: |
| """Extract features of images.""" |
| n = len(images_list) |
| b = images_list[0].shape[0] |
| images = torch.stack(images_list, dim=1) |
| images = images.view(-1, *images.shape[2:]) |
|
|
| |
| if self.training: |
| images = self.grid_mask(images) |
|
|
| features = self.basemodel(images) |
| features = self.fpn(features)[self.fpn.start_index :] |
|
|
| img_feats = [] |
| for img_feat in features: |
| _, c, h, w = img_feat.size() |
| img_feats.append(img_feat.view(b, n, c, h, w)) |
|
|
| return img_feats |
|
|
| def forward( |
| self, |
| images: list[Tensor], |
| can_bus: list[list[float]], |
| scene_names: list[str], |
| cam_intrinsics: list[Tensor], |
| cam_extrinsics: list[Tensor], |
| lidar_extrinsics: list[Tensor], |
| ) -> Detect3DOut: |
| """Forward.""" |
| |
| lidar_extrinsics_tensor = lidar_extrinsics[0] |
| can_bus_tensor = torch.tensor( |
| can_bus, dtype=torch.float32, device=images[0].device |
| ) |
|
|
| if scene_names[0] != self.prev_frame_info["scene_name"]: |
| |
| self.prev_frame_info["prev_bev"] = None |
|
|
| |
| self.prev_frame_info["scene_name"] = scene_names[0] |
|
|
| |
| tmp_pos = copy.deepcopy(can_bus_tensor[0][:3]) |
| tmp_angle = copy.deepcopy(can_bus_tensor[0][-1]) |
| if self.prev_frame_info["prev_bev"] is not None: |
| can_bus_tensor[0][:3] -= self.prev_frame_info["prev_pos"] |
| can_bus_tensor[0][-1] -= self.prev_frame_info["prev_angle"] |
| else: |
| can_bus_tensor[0][:3] = 0 |
| can_bus_tensor[0][-1] = 0 |
|
|
| images_hw = (int(images[0].shape[-2]), int(images[0].shape[-1])) |
| img_feats = self.extract_feat(images) |
|
|
| out, bev_embed = self.pts_bbox_head( |
| img_feats, |
| can_bus_tensor, |
| images_hw, |
| cam_intrinsics, |
| cam_extrinsics, |
| lidar_extrinsics_tensor, |
| prev_bev=self.prev_frame_info["prev_bev"], |
| ) |
|
|
| |
| |
| self.prev_frame_info["prev_pos"] = tmp_pos |
| self.prev_frame_info["prev_angle"] = tmp_angle |
| self.prev_frame_info["prev_bev"] = bev_embed |
|
|
| return out |
|
|