# Provenance (copied from the hosting page header, not code): author hyzhou404, commit 7accb91 "init".
from typing import Dict
import numpy as np
import torch
import torch.nn as nn
from navsim.agents.transfuser.transfuser_config import TransfuserConfig
from navsim.agents.transfuser.transfuser_backbone import TransfuserBackbone
from navsim.agents.transfuser.transfuser_features import BoundingBox2DIndex
from navsim.common.enums import StateSE2Index
class TransfuserModel(nn.Module):
    """Torch module for Transfuser."""

    def __init__(self, config: TransfuserConfig):
        """
        Initializes TransFuser torch module.
        :param config: global config dataclass of TransFuser.
        """
        super().__init__()

        # Decoder queries: one trajectory query followed by one query per bounding box
        # (split back apart in `forward`).
        self._query_splits = [
            1,
            config.num_bounding_boxes,
        ]

        self._config = config
        self._backbone = TransfuserBackbone(config)

        # Learned embedding added to the decoder keys/values: 8x8 = 64 BEV tokens plus one
        # ego-status token, which `forward` appends after the BEV tokens.
        # NOTE(review): assumes the backbone's BEV feature map is 8x8 — TODO confirm against config.
        self._keyval_embedding = nn.Embedding(8**2 + 1, config.tf_d_model)
        self._query_embedding = nn.Embedding(sum(self._query_splits), config.tf_d_model)

        # usually, the BEV features are variable in size.
        # NOTE(review): the 512 input channels are hard-coded; presumably they match the backbone output.
        self._bev_downscale = nn.Conv2d(512, config.tf_d_model, kernel_size=1)
        # NOTE(review): 4 + 2 + 2 presumably = driving command (4) + velocity (2) + acceleration (2);
        # verify against the feature builder that produces "status_feature".
        self._status_encoding = nn.Linear(4 + 2 + 2, config.tf_d_model)

        self._bev_semantic_head = nn.Sequential(
            nn.Conv2d(
                config.bev_features_channels,
                config.bev_features_channels,
                kernel_size=(3, 3),
                stride=1,
                padding=(1, 1),
                bias=True,
            ),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                config.bev_features_channels,
                config.num_bev_classes,
                kernel_size=(1, 1),
                stride=1,
                padding=0,
                bias=True,
            ),
            # Resize the class logits to (lidar height // 2, lidar width).
            nn.Upsample(
                size=(config.lidar_resolution_height // 2, config.lidar_resolution_width),
                mode="bilinear",
                align_corners=False,
            ),
        )

        tf_decoder_layer = nn.TransformerDecoderLayer(
            d_model=config.tf_d_model,
            nhead=config.tf_num_head,
            dim_feedforward=config.tf_d_ffn,
            dropout=config.tf_dropout,
            batch_first=True,
        )

        self._tf_decoder = nn.TransformerDecoder(tf_decoder_layer, config.tf_num_layers)
        self._agent_head = AgentHead(
            num_agents=config.num_bounding_boxes,
            d_ffn=config.tf_d_ffn,
            d_model=config.tf_d_model,
        )

        self._trajectory_head = TrajectoryHead(
            num_poses=config.trajectory_sampling.num_poses,
            d_ffn=config.tf_d_ffn,
            d_model=config.tf_d_model,
        )

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Torch module forward pass.
        :param features: dictionary holding the "camera_feature", "lidar_feature" and
            "status_feature" tensors (batch dimension first).
        :return: dictionary with "bev_semantic_map" plus the entries produced by the
            trajectory head ("trajectory") and the agent head ("agent_states", "agent_labels").
        """
        # Move inputs onto the device that holds this module's parameters instead of
        # hard-coding `.cuda()`, so the model also works on CPU or on a non-default GPU.
        device = self._query_embedding.weight.device
        camera_feature: torch.Tensor = features["camera_feature"].to(device)
        lidar_feature: torch.Tensor = features["lidar_feature"].to(device)
        status_feature: torch.Tensor = features["status_feature"].to(device)

        batch_size = status_feature.shape[0]

        bev_feature_upscale, bev_feature, _ = self._backbone(camera_feature, lidar_feature)

        # (B, C, H, W) -> (B, d_model, H*W) -> (B, H*W, d_model): one token per BEV cell.
        bev_feature = self._bev_downscale(bev_feature).flatten(-2, -1)
        bev_feature = bev_feature.permute(0, 2, 1)
        status_encoding = self._status_encoding(status_feature)

        # Append the encoded ego status as one extra key/value token after the BEV tokens.
        keyval = torch.concatenate([bev_feature, status_encoding[:, None]], dim=1)
        keyval += self._keyval_embedding.weight[None, ...]

        # The same learned queries are used for every sample in the batch.
        query = self._query_embedding.weight[None, ...].repeat(batch_size, 1, 1)
        query_out = self._tf_decoder(query, keyval)

        bev_semantic_map = self._bev_semantic_head(bev_feature_upscale)
        trajectory_query, agents_query = query_out.split(self._query_splits, dim=1)

        output: Dict[str, torch.Tensor] = {"bev_semantic_map": bev_semantic_map}
        trajectory = self._trajectory_head(trajectory_query)
        output.update(trajectory)

        agents = self._agent_head(agents_query)
        output.update(agents)

        return output
class AgentHead(nn.Module):
    """Prediction head that regresses agent bounding-box states and per-agent label logits."""

    def __init__(
        self,
        num_agents: int,
        d_ffn: int,
        d_model: int,
    ):
        """
        Build the agent prediction head.
        :param num_agents: maximum number of agents to predict
        :param d_ffn: hidden width of the state-regression MLP
        :param d_model: width of each incoming agent query
        """
        super().__init__()

        self._num_objects = num_agents
        self._d_model = d_model
        self._d_ffn = d_ffn

        box_dim = BoundingBox2DIndex.size()
        self._mlp_states = nn.Sequential(
            nn.Linear(d_model, d_ffn),
            nn.ReLU(),
            nn.Linear(d_ffn, box_dim),
        )
        # Kept as a Sequential (not a bare Linear) so parameter names remain "_mlp_label.0.*".
        self._mlp_label = nn.Sequential(nn.Linear(d_model, 1))

    def forward(self, agent_queries) -> Dict[str, torch.Tensor]:
        """
        Decode agent queries into box states and label logits.
        :param agent_queries: decoder outputs for the agent queries
        :return: dict with "agent_states" (last dim = BoundingBox2DIndex.size()) and
            "agent_labels" (label logits with the trailing size-1 dim squeezed away)
        """
        states = self._mlp_states(agent_queries)

        # Squash position into (-32, 32) and heading into (-pi, pi) with tanh.
        point_slot = BoundingBox2DIndex.POINT
        heading_slot = BoundingBox2DIndex.HEADING
        states[..., point_slot] = 32 * torch.tanh(states[..., point_slot])
        states[..., heading_slot] = np.pi * torch.tanh(states[..., heading_slot])

        labels = self._mlp_label(agent_queries).squeeze(dim=-1)
        return {"agent_states": states, "agent_labels": labels}
class TrajectoryHead(nn.Module):
    """Prediction head that regresses a fixed-length sequence of (x, y, heading) poses."""

    def __init__(self, num_poses: int, d_ffn: int, d_model: int):
        """
        Build the trajectory head.
        :param num_poses: number of (x,y,θ) poses to predict
        :param d_ffn: hidden width of the regression MLP
        :param d_model: width of the incoming trajectory query
        """
        super().__init__()

        self._num_poses = num_poses
        self._d_model = d_model
        self._d_ffn = d_ffn

        out_features = num_poses * StateSE2Index.size()
        self._mlp = nn.Sequential(
            nn.Linear(d_model, d_ffn),
            nn.ReLU(),
            nn.Linear(d_ffn, out_features),
        )

    def forward(self, object_queries) -> Dict[str, torch.Tensor]:
        """
        Decode the trajectory query into poses.
        :param object_queries: decoder output(s) for the trajectory query
        :return: dict with "trajectory" of shape (-1, num_poses, StateSE2Index.size())
        """
        flat_poses = self._mlp(object_queries)
        pose_dim = StateSE2Index.size()
        poses = flat_poses.reshape(-1, self._num_poses, pose_dim)

        # Squash the heading channel into (-pi, pi) with tanh.
        heading_slot = StateSE2Index.HEADING
        poses[..., heading_slot] = np.pi * torch.tanh(poses[..., heading_slot])

        return {"trajectory": poses}