| from typing import Dict |
| import numpy as np |
| import torch |
| import torch.nn as nn |
|
|
| from navsim.agents.transfuser.transfuser_config import TransfuserConfig |
| from navsim.agents.transfuser.transfuser_backbone import TransfuserBackbone |
| from navsim.agents.transfuser.transfuser_features import BoundingBox2DIndex |
| from navsim.common.enums import StateSE2Index |
|
|
|
|
class TransfuserModel(nn.Module):
    """Torch module for Transfuser.

    Combines the sensor backbone, a transformer decoder that attends learned
    queries to the BEV/status tokens, and three output heads: a BEV semantic
    map, an ego trajectory and agent bounding boxes.
    """

    def __init__(self, config: TransfuserConfig):
        """
        Initializes TransFuser torch module.
        :param config: global config dataclass of TransFuser.
        """
        super().__init__()

        # Decoder queries: one for the trajectory head followed by one per
        # bounding box. forward() splits the decoder output with these sizes.
        self._query_splits = [
            1,
            config.num_bounding_boxes,
        ]

        self._config = config
        self._backbone = TransfuserBackbone(config)

        # Learned embedding added to the key/value sequence. It has 8**2 + 1
        # rows and is added to [flattened BEV tokens, one status token], so the
        # backbone's BEV feature map has to flatten to 64 spatial positions.
        self._keyval_embedding = nn.Embedding(8**2 + 1, config.tf_d_model)
        self._query_embedding = nn.Embedding(sum(self._query_splits), config.tf_d_model)

        # 1x1 convolution projecting the 512-channel BEV feature to the
        # transformer width.
        self._bev_downscale = nn.Conv2d(512, config.tf_d_model, kernel_size=1)
        # Status vector has 4 + 2 + 2 = 8 entries (presumably driving command,
        # velocity and acceleration -- confirm against the feature builder).
        self._status_encoding = nn.Linear(4 + 2 + 2, config.tf_d_model)

        # Semantic segmentation head on the upscaled BEV feature; the result is
        # resized to (lidar_resolution_height // 2, lidar_resolution_width).
        self._bev_semantic_head = nn.Sequential(
            nn.Conv2d(
                config.bev_features_channels,
                config.bev_features_channels,
                kernel_size=(3, 3),
                stride=1,
                padding=(1, 1),
                bias=True,
            ),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                config.bev_features_channels,
                config.num_bev_classes,
                kernel_size=(1, 1),
                stride=1,
                padding=0,
                bias=True,
            ),
            nn.Upsample(
                size=(config.lidar_resolution_height // 2, config.lidar_resolution_width),
                mode="bilinear",
                align_corners=False,
            ),
        )

        tf_decoder_layer = nn.TransformerDecoderLayer(
            d_model=config.tf_d_model,
            nhead=config.tf_num_head,
            dim_feedforward=config.tf_d_ffn,
            dropout=config.tf_dropout,
            batch_first=True,
        )

        self._tf_decoder = nn.TransformerDecoder(tf_decoder_layer, config.tf_num_layers)
        self._agent_head = AgentHead(
            num_agents=config.num_bounding_boxes,
            d_ffn=config.tf_d_ffn,
            d_model=config.tf_d_model,
        )

        self._trajectory_head = TrajectoryHead(
            num_poses=config.trajectory_sampling.num_poses,
            d_ffn=config.tf_d_ffn,
            d_model=config.tf_d_model,
        )

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Torch module forward pass.
        :param features: dictionary holding the tensors "camera_feature",
            "lidar_feature" and "status_feature" (batch on the first axis).
        :return: dictionary with "bev_semantic_map" plus the entries produced
            by the trajectory head ("trajectory") and the agent head
            ("agent_states", "agent_labels").
        """
        # Follow the device the module's weights live on instead of forcing
        # CUDA: this keeps CPU inference working and stays correct when the
        # model sits on a non-default GPU. `.to` is a no-op for tensors that
        # are already on that device.
        device = self._status_encoding.weight.device
        camera_feature: torch.Tensor = features["camera_feature"].to(device)
        lidar_feature: torch.Tensor = features["lidar_feature"].to(device)
        status_feature: torch.Tensor = features["status_feature"].to(device)

        batch_size = status_feature.shape[0]

        bev_feature_upscale, bev_feature, _ = self._backbone(camera_feature, lidar_feature)

        # (B, C, H, W) -> (B, C, H*W) -> (B, H*W, C): one token per BEV cell.
        bev_feature = self._bev_downscale(bev_feature).flatten(-2, -1)
        bev_feature = bev_feature.permute(0, 2, 1)
        status_encoding = self._status_encoding(status_feature)

        # Append the status encoding as one extra token, then add the learned
        # per-token embedding (broadcast over the batch axis).
        keyval = torch.cat([bev_feature, status_encoding[:, None]], dim=1)
        keyval += self._keyval_embedding.weight[None, ...]

        # The same learned queries are used for every sample in the batch.
        query = self._query_embedding.weight[None, ...].repeat(batch_size, 1, 1)
        query_out = self._tf_decoder(query, keyval)

        bev_semantic_map = self._bev_semantic_head(bev_feature_upscale)
        trajectory_query, agents_query = query_out.split(self._query_splits, dim=1)

        output: Dict[str, torch.Tensor] = {"bev_semantic_map": bev_semantic_map}
        trajectory = self._trajectory_head(trajectory_query)
        output.update(trajectory)

        agents = self._agent_head(agents_query)
        output.update(agents)

        return output
|
|
|
|
class AgentHead(nn.Module):
    """Head that maps agent queries to 2D bounding boxes and label scores."""

    def __init__(
        self,
        num_agents: int,
        d_ffn: int,
        d_model: int,
    ):
        """
        Builds the box-regression MLP and the label projection.
        :param num_agents: maximum number of agents to predict
        :param d_ffn: dimensionality of feed-forward network
        :param d_model: input dimensionality
        """
        super().__init__()

        self._num_objects = num_agents
        self._d_model = d_model
        self._d_ffn = d_ffn

        # Two-layer MLP regressing one value per BoundingBox2DIndex entry.
        hidden_layer = nn.Linear(d_model, d_ffn)
        box_layer = nn.Linear(d_ffn, BoundingBox2DIndex.size())
        self._mlp_states = nn.Sequential(hidden_layer, nn.ReLU(), box_layer)

        # Single linear projection to one raw score per query.
        self._mlp_label = nn.Sequential(nn.Linear(d_model, 1))

    def forward(self, agent_queries) -> Dict[str, torch.Tensor]:
        """
        Torch module forward pass.
        :param agent_queries: query features with d_model on the last axis
        :return: dictionary with "agent_states" (box parameters) and
            "agent_labels" (one raw score per query, last axis squeezed)
        """
        point_index = BoundingBox2DIndex.POINT
        heading_index = BoundingBox2DIndex.HEADING

        boxes = self._mlp_states(agent_queries)

        # Squash the position entries into +/-32 via tanh.
        squashed_point = boxes[..., point_index].tanh() * 32
        boxes[..., point_index] = squashed_point

        # Squash the heading entry into +/-pi via tanh.
        squashed_heading = boxes[..., heading_index].tanh() * np.pi
        boxes[..., heading_index] = squashed_heading

        scores = self._mlp_label(agent_queries)
        scores = scores.squeeze(dim=-1)

        return {"agent_states": boxes, "agent_labels": scores}
|
|
|
|
class TrajectoryHead(nn.Module):
    """Head that maps the trajectory query to a sequence of poses."""

    def __init__(self, num_poses: int, d_ffn: int, d_model: int):
        """
        Builds the pose-regression MLP.
        :param num_poses: number of (x,y,θ) poses to predict
        :param d_ffn: dimensionality of feed-forward network
        :param d_model: input dimensionality
        """
        super().__init__()

        self._num_poses = num_poses
        self._d_model = d_model
        self._d_ffn = d_ffn

        # Two-layer MLP emitting all poses as one flat vector per query.
        layers = [
            nn.Linear(d_model, d_ffn),
            nn.ReLU(),
            nn.Linear(d_ffn, num_poses * StateSE2Index.size()),
        ]
        self._mlp = nn.Sequential(*layers)

    def forward(self, object_queries) -> Dict[str, torch.Tensor]:
        """
        Torch module forward pass.
        :param object_queries: query features with d_model on the last axis
        :return: dictionary with "trajectory", reshaped to
            (-1, num_poses, StateSE2Index.size())
        """
        flat_poses = self._mlp(object_queries)
        trajectory = flat_poses.reshape(-1, self._num_poses, StateSE2Index.size())

        # Squash the heading entry into +/-pi via tanh.
        heading_index = StateSE2Index.HEADING
        squashed_heading = trajectory[..., heading_index].tanh() * np.pi
        trajectory[..., heading_index] = squashed_heading

        return {"trajectory": trajectory}
|
|