|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import sys |
|
|
|
|
|
|
|
|
from mmdet.models import DETECTORS |
|
|
from mmdet3d.models.detectors.mvx_two_stage import MVXTwoStageDetector,Base3DDetector |
|
|
from navsim.agents.diffusiondrive.transfuser_loss import transfuser_loss |
|
|
|
|
|
from typing import Dict |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import copy |
|
|
from navsim.agents.diffusiondrive.transfuser_config import TransfuserConfig |
|
|
from navsim.agents.diffusiondrive.transfuser_backbone import TransfuserBackbone |
|
|
from navsim.agents.diffusiondrive.transfuser_features import BoundingBox2DIndex |
|
|
from navsim.agents.diffusiondrive.modules.blocks import linear_relu_ln,bias_init_with_prob, gen_sineembed_for_position, GridSampleCrossBEVAttention |
|
|
from navsim.agents.diffusiondrive.modules.multimodal_loss import LossComputer |
|
|
from navsim.common.enums import StateSE2Index |
|
|
from diffusers.schedulers import DDIMScheduler |
|
|
from navsim.agents.diffusiondrive.modules.conditional_unet1d import ConditionalUnet1D,SinusoidalPosEmb |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
from typing import Any, List, Dict, Optional, Union |
|
|
|
|
|
@DETECTORS.register_module()
class V2TransfuserModel(Base3DDetector):
    """TransFuser planner with a diffusion trajectory head, registered in mmdet's DETECTORS.

    Camera and LiDAR inputs go through ``TransfuserBackbone``. A transformer
    decoder turns learned queries into one ego (trajectory) query and
    ``config.num_bounding_boxes`` agent queries. These feed ``TrajectoryHead``
    and ``AgentHead``.

    In training mode ``forward`` returns a loss dict. In eval mode it returns the
    prediction dict extended with zero-filled PDM metric placeholders.
    """

    def __init__(
        self,
        rejection_sampling=1,
        train_sampling=1,
        online_cost_learning=False,
        pdm_config_path='',
        image_architecture=None,
        bkb_path=None,
        online_pdm_infer=False,
        rlft=False,
        *args,
        **kwargs):
        """
        Initializes TransFuser torch module.

        The TransfuserConfig is created here with its default values. Only
        ``image_architecture`` / ``bkb_path`` are overridden from the arguments.
        Extra ``*args`` / ``**kwargs`` are accepted and ignored.

        :param rejection_sampling: forwarded to TrajectoryHead and stored.
        :param train_sampling: forwarded to TrajectoryHead.
        :param online_cost_learning: forwarded to TrajectoryHead as ``cost_learning``.
        :param pdm_config_path: PDM simulator/scorer config path for online scoring.
        :param image_architecture: optional override of ``config.image_architecture``.
        :param bkb_path: backbone checkpoint path written into the config.
        :param online_pdm_infer: forwarded to TrajectoryHead and stored.
        :param rlft: RL fine-tuning mode. The backbone is put in eval mode during
            training, and ``forward`` returns only the cost losses.
        """

        super().__init__()
        config = TransfuserConfig()
        if image_architecture is not None:
            config.image_architecture = image_architecture
            config.bkb_path = bkb_path
        self.rejection_sampling = rejection_sampling
        self.online_cost_learning = online_cost_learning
        self.online_pdm_infer = online_pdm_infer
        self.rlft = rlft

        # One trajectory query followed by one query per bounding box.
        self._query_splits = [
            1,
            config.num_bounding_boxes,
        ]

        # NOTE: module creation order below determines parameter-init RNG
        # consumption and state_dict layout; keep it stable.
        self._config = config
        self._backbone = TransfuserBackbone(config)

        # 8*8 BEV tokens + 1 status token. forward() adds this table to the
        # keyval sequence, so the downscaled BEV map must have exactly 64 cells.
        self._keyval_embedding = nn.Embedding(8**2 + 1, config.tf_d_model)
        self._query_embedding = nn.Embedding(sum(self._query_splits), config.tf_d_model)

        # Assumes the backbone's coarse BEV feature has 512 channels (TODO confirm).
        self._bev_downscale = nn.Conv2d(512, config.tf_d_model, kernel_size=1)
        # Status feature is expected to be 4 + 2 + 2 = 8 dims. The meaning of the
        # split is not visible here (presumably command / velocity / acceleration).
        self._status_encoding = nn.Linear(4 + 2 + 2, config.tf_d_model)

        self._bev_semantic_head = nn.Sequential(
            nn.Conv2d(
                config.bev_features_channels,
                config.bev_features_channels,
                kernel_size=(3, 3),
                stride=1,
                padding=(1, 1),
                bias=True,
            ),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                config.bev_features_channels,
                config.num_bev_classes,
                kernel_size=(1, 1),
                stride=1,
                padding=0,
                bias=True,
            ),
            nn.Upsample(
                size=(config.lidar_resolution_height // 2, config.lidar_resolution_width),
                mode="bilinear",
                align_corners=False,
            ),
        )

        tf_decoder_layer = nn.TransformerDecoderLayer(
            d_model=config.tf_d_model,
            nhead=config.tf_num_head,
            dim_feedforward=config.tf_d_ffn,
            dropout=config.tf_dropout,
            batch_first=True,
        )

        self._tf_decoder = nn.TransformerDecoder(tf_decoder_layer, config.tf_num_layers)
        self._agent_head = AgentHead(
            num_agents=config.num_bounding_boxes,
            d_ffn=config.tf_d_ffn,
            d_model=config.tf_d_model,
        )

        self._trajectory_head = TrajectoryHead(
            num_poses=config.trajectory_sampling.num_poses,
            d_ffn=config.tf_d_ffn,
            d_model=config.tf_d_model,
            plan_anchor_path=config.plan_anchor_path,
            config=config,
            rejection_sampling=rejection_sampling,
            cost_learning=online_cost_learning,
            pdm_config_path=pdm_config_path,
            online_pdm_infer=online_pdm_infer,
            train_sampling=train_sampling,
            rlft=rlft
        )
        # Projects the channel-concatenation of [keyval BEV map (tf_d_model ch),
        # upscaled BEV feature] back to tf_d_model. This presumes the upscaled BEV
        # feature has 64 channels (TODO confirm against the backbone).
        self.bev_proj = nn.Sequential(
            *linear_relu_ln(config.tf_d_model, 1, 1,config.tf_d_model+64),
        )

    # The next three stubs presumably satisfy abstract methods of the mmdet
    # base detector; they are not supported by this model.
    def aug_test(self, *args, **kwargs):
        raise NotImplementedError("aug_test not implemented")

    def extract_feat(self, *args, **kwargs):
        raise NotImplementedError("extract_feat not implemented")

    def simple_test(self, *args, **kwargs):
        raise NotImplementedError("simple_test not implemented")

    def forward(self,
                camera_feature,
                lidar_feature,
                status_feature,
                **kwargs):
        """Torch module forward pass.

        Required ``kwargs``:

        - ``'token'``: scenario tokens, forwarded to the trajectory head.
        - In training mode, also ``'trajectory'``, ``'agent_states'``,
          ``'agent_labels'`` and ``'bev_semantic_map'``.

        Optional ``kwargs``: ``'metric_cache'``.

        :return: in training, a loss dict. With ``rlft`` this is the trajectory
            head's ``'cost_loss'``; otherwise it is ``transfuser_loss(...)``. In eval,
            the trajectory-head outputs plus zero tensors for each PDM metric key.
        """

        metric_cache = None
        targets = None
        tokens = kwargs['token']

        if self.training:
            targets = dict()
            targets['trajectory'] = kwargs['trajectory']
            targets['agent_states'] = kwargs['agent_states']
            targets['agent_labels'] = kwargs['agent_labels']
            targets['bev_semantic_map'] = kwargs['bev_semantic_map']
            if 'metric_cache' in kwargs.keys():
                metric_cache = kwargs['metric_cache']
            if self.rlft:
                # RL fine-tuning: keep backbone normalisation/dropout in eval mode.
                self._backbone.eval()

        batch_size = status_feature.shape[0]

        bev_feature_upscale, bev_feature, _ = self._backbone(camera_feature, lidar_feature)
        cross_bev_feature = bev_feature_upscale
        bev_spatial_shape = bev_feature_upscale.shape[2:]
        concat_cross_bev_shape = bev_feature.shape[2:]
        # (B, C, H, W) -> (B, H*W, tf_d_model) token sequence.
        bev_feature = self._bev_downscale(bev_feature).flatten(-2, -1)
        bev_feature = bev_feature.permute(0, 2, 1)
        status_encoding = self._status_encoding(status_feature)

        # Key/value memory for the query decoder: BEV tokens + one status token,
        # plus a learned positional embedding per slot.
        keyval = torch.cat([bev_feature, status_encoding[:, None]], dim=1)
        keyval += self._keyval_embedding.weight[None, ...]

        # Fold the BEV tokens (status token dropped) back into a (B, C, H, W) map.
        concat_cross_bev = keyval[:,:-1].permute(0,2,1).contiguous().view(batch_size, -1, concat_cross_bev_shape[0], concat_cross_bev_shape[1])
        # Upsample it to the upscaled BEV resolution.
        concat_cross_bev = F.interpolate(concat_cross_bev, size=bev_spatial_shape, mode='bilinear', align_corners=False)
        # Concatenate it channel-wise with the upscaled BEV feature.
        cross_bev_feature = torch.cat([concat_cross_bev, cross_bev_feature], dim=1)

        # Per-cell projection back to tf_d_model, then reshape to (B, C, H, W).
        cross_bev_feature = self.bev_proj(cross_bev_feature.flatten(-2,-1).permute(0,2,1))
        cross_bev_feature = cross_bev_feature.permute(0,2,1).contiguous().view(batch_size, -1, bev_spatial_shape[0], bev_spatial_shape[1])
        query = self._query_embedding.weight[None, ...].repeat(batch_size, 1, 1)
        query_out = self._tf_decoder(query, keyval)

        bev_semantic_map = self._bev_semantic_head(bev_feature_upscale)
        trajectory_query, agents_query = query_out.split(self._query_splits, dim=1)

        output = dict()
        trajectory = self._trajectory_head(trajectory_query,agents_query, cross_bev_feature,
                                           bev_spatial_shape,status_encoding[:, None],
                                           targets=targets,global_img=None,metric_cache=metric_cache,
                                           tokens=tokens)
        if self.rlft and self.training:
            # RL fine-tuning optimises only the cost-learning losses.
            loss_dict = trajectory['cost_loss']
            return loss_dict
        output.update(trajectory)

        agents = self._agent_head(agents_query)
        if self.training:
            output.update(agents)
            output["bev_semantic_map"] = bev_semantic_map

            loss_dict = transfuser_loss(targets, output, self._config)
            return loss_dict

        # Eval: placeholder PDM metrics, one zero per batch element.
        # NOTE(review): every key shares the same tensor object.
        zeros = torch.zeros(trajectory['trajectory'].shape[0],).to(trajectory['trajectory'].device)

        pdm_dict = {
            "no_at_fault_collisions":zeros,
            "drivable_area_compliance":zeros,
            "ego_progress":zeros,
            "time_to_collision_within_bound":zeros,
            "comfort":zeros,
            "score":zeros,
        }
        output.update(pdm_dict)

        return output
|
|
|
|
|
class AgentHead(nn.Module):
    """Bounding box prediction head.

    Two small MLPs act on every agent query. One regresses the box state vector
    (length ``BoundingBox2DIndex.size()``); the other produces one existence logit.
    """

    def __init__(
        self,
        num_agents: int,
        d_ffn: int,
        d_model: int,
    ):
        """
        Initializes prediction head.

        :param num_agents: maximum number of agents to predict
        :param d_ffn: dimensionality of feed-forward network
        :param d_model: input dimensionality
        """
        super(AgentHead, self).__init__()

        self._num_objects = num_agents
        self._d_model = d_model
        self._d_ffn = d_ffn

        # State MLP is created before the label MLP so parameter init is unchanged.
        hidden = nn.Linear(d_model, d_ffn)
        self._mlp_states = nn.Sequential(
            hidden,
            nn.ReLU(),
            nn.Linear(d_ffn, BoundingBox2DIndex.size()),
        )
        self._mlp_label = nn.Sequential(nn.Linear(d_model, 1))

    def forward(self, agent_queries) -> Dict[str, torch.Tensor]:
        """Predict box states and existence logits for each agent query.

        ``POINT`` entries are squashed into (-32, 32) and ``HEADING`` into
        (-pi, pi) via tanh. Labels are raw logits with the trailing unit dim removed.
        """
        point_idx = BoundingBox2DIndex.POINT
        heading_idx = BoundingBox2DIndex.HEADING

        states = self._mlp_states(agent_queries)
        states[..., point_idx] = torch.tanh(states[..., point_idx]) * 32
        states[..., heading_idx] = torch.tanh(states[..., heading_idx]) * np.pi

        labels = self._mlp_label(agent_queries).squeeze(-1)
        return {"agent_states": states, "agent_labels": labels}
|
|
|
|
|
class DiffMotionPlanningRefinementModule(nn.Module):
    """Per-mode planning head that regresses poses and one classification logit per mode.

    Input is a trajectory feature of shape (bs, ego_fut_mode, embed_dims). The
    outputs are ``plan_reg`` with shape (bs, ego_fut_mode, ego_fut_ts, 3) and
    ``plan_cls`` with shape (bs, ego_fut_mode).
    """

    def __init__(
        self,
        embed_dims=256,
        ego_fut_ts=8,
        ego_fut_mode=20,
        if_zeroinit_reg=True,
        cost_learning=False,
    ):
        """
        :param embed_dims: feature size of the input trajectory feature.
        :param ego_fut_ts: number of future poses regressed per mode.
        :param ego_fut_mode: number of modes. It is stored only; ``forward`` reads
            the mode count from the input shape.
        :param if_zeroinit_reg: currently IGNORED (see the note in the body).
        :param cost_learning: stored on the module; not used by this class.
        """
        super(DiffMotionPlanningRefinementModule, self).__init__()
        self.embed_dims = embed_dims
        self.ego_fut_ts = ego_fut_ts
        self.ego_fut_mode = ego_fut_mode
        self.plan_cls_branch = nn.Sequential(
            *linear_relu_ln(embed_dims, 1, 2),
            nn.Linear(embed_dims, 1),
        )
        self.plan_reg_branch = nn.Sequential(
            nn.Linear(embed_dims, embed_dims),
            nn.ReLU(),
            nn.Linear(embed_dims, embed_dims),
            nn.ReLU(),
            nn.Linear(embed_dims, ego_fut_ts * 3),
        )
        # NOTE(review): the ``if_zeroinit_reg`` argument (default True) is
        # overridden here, so the regression layer is never zero-initialised.
        # Honouring the argument would change the initialisation of every
        # existing caller (none pass it explicitly), so the override is kept.
        self.if_zeroinit_reg = False
        self.cost_learning = cost_learning

        self.init_weight()

    def init_weight(self):
        """Initialise the final regression/classification layers."""
        if self.if_zeroinit_reg:
            nn.init.constant_(self.plan_reg_branch[-1].weight, 0)
            nn.init.constant_(self.plan_reg_branch[-1].bias, 0)

        # Constant classification bias from bias_init_with_prob(0.01). Presumably
        # this gives an initial sigmoid output near 0.01; the helper is external.
        bias_init = bias_init_with_prob(0.01)
        nn.init.constant_(self.plan_cls_branch[-1].bias, bias_init)

    def forward(
        self,
        traj_feature,
    ):
        """
        :param traj_feature: tensor of shape (bs, ego_fut_mode, embed_dims).
        :return: ``(plan_reg, plan_cls)`` with shapes
            (bs, ego_fut_mode, ego_fut_ts, 3) and (bs, ego_fut_mode).
        """
        bs, ego_fut_mode, _ = traj_feature.shape

        plan_cls = self.plan_cls_branch(traj_feature).squeeze(-1)
        traj_delta = self.plan_reg_branch(traj_feature)
        plan_reg = traj_delta.reshape(bs, ego_fut_mode, self.ego_fut_ts, 3)

        return plan_reg, plan_cls
|
|
|
|
|
class ModulationLayer(nn.Module):
    """FiLM-style modulation: ``x * (1 + scale) + shift``.

    The condition vector is the time embedding. It is optionally prefixed by
    ``global_cond`` and then by a flattened ``global_img``. The condition passes
    through Mish + Linear to give ``2 * embed_dims`` values, which are split
    into scale and shift.
    """

    def __init__(self, embed_dims: int, condition_dims: int):
        super(ModulationLayer, self).__init__()
        self.if_zeroinit_scale = False
        self.embed_dims = embed_dims
        self.scale_shift_mlp = nn.Sequential(
            nn.Mish(),
            nn.Linear(condition_dims, embed_dims * 2),
        )
        self.init_weight()

    def init_weight(self):
        """Optionally zero the projection so the layer starts as the identity map."""
        if not self.if_zeroinit_scale:
            return
        projection = self.scale_shift_mlp[-1]
        nn.init.constant_(projection.weight, 0)
        nn.init.constant_(projection.bias, 0)

    def forward(
        self,
        traj_feature,
        time_embed,
        global_cond=None,
        global_img=None,
    ):
        """Modulate ``traj_feature`` using the assembled condition vector."""
        if global_cond is None:
            condition = time_embed
        else:
            condition = torch.cat([global_cond, time_embed], dim=-1)

        if global_img is not None:
            # (B, C, H, W) -> (B, H*W, C), prepended along the feature dim.
            img_tokens = global_img.flatten(2, 3).permute(0, 2, 1).contiguous()
            condition = torch.cat([img_tokens, condition], dim=-1)

        scale, shift = self.scale_shift_mlp(condition).chunk(2, dim=-1)
        return traj_feature * (1 + scale) + shift
|
|
|
|
|
class CustomTransformerDecoderLayer(nn.Module):
    """One refinement step of the diffusion planner decoder.

    Pipeline:
    1. Grid-sample cross attention into the BEV map.
    2. Agent cross attention (residual + LayerNorm).
    3. Ego cross attention (residual + LayerNorm).
    4. FFN + LayerNorm.
    5. Time modulation.

    With ``feat_out=True`` the modulated feature is returned. Otherwise a
    ``DiffMotionPlanningRefinementModule`` regresses refined poses.

    ``d_model`` and ``d_ffn`` are accepted but unused; all sizes come from ``config``.
    """

    def __init__(self,
                 num_poses,
                 d_model,
                 d_ffn,
                 config,
                 cost_learning,
                 feat_out=False
                 ):
        super().__init__()
        width = config.tf_d_model
        n_heads = config.tf_num_head

        # Creation order is preserved: it fixes parameter-init RNG consumption.
        self.dropout = nn.Dropout(0.1)
        self.dropout1 = nn.Dropout(0.1)
        self.feat_out = feat_out
        self.cross_bev_attention = GridSampleCrossBEVAttention(
            width,
            n_heads,
            num_points=num_poses,
            config=config,
            in_bev_dims=width,
        )
        self.cross_agent_attention = nn.MultiheadAttention(
            width, n_heads, dropout=config.tf_dropout, batch_first=True,
        )
        self.cross_ego_attention = nn.MultiheadAttention(
            width, n_heads, dropout=config.tf_dropout, batch_first=True,
        )
        self.ffn = nn.Sequential(
            nn.Linear(width, config.tf_d_ffn),
            nn.ReLU(),
            nn.Linear(config.tf_d_ffn, width),
        )
        self.norm1 = nn.LayerNorm(width)
        self.norm2 = nn.LayerNorm(width)
        self.norm3 = nn.LayerNorm(width)
        self.time_modulation = ModulationLayer(width, width)
        if feat_out == False:  # noqa: E712 - mirrors the original equality test
            self.task_decoder = DiffMotionPlanningRefinementModule(
                embed_dims=width,
                ego_fut_ts=num_poses,
                ego_fut_mode=20,
                cost_learning=cost_learning
            )

    def forward(self,
                traj_feature,
                noisy_traj_points,
                bev_feature,
                bev_spatial_shape,
                agents_query,
                ego_query,
                time_embed,
                status_encoding,
                global_img=None):
        """Refine trajectory features; ``status_encoding`` is accepted but unused.

        :return: the modulated feature when ``feat_out`` is set. Otherwise the tuple
            ``(poses_reg, poses_cls, feature)``. In ``poses_reg``, xy are offsets
            added to ``noisy_traj_points``, and the heading channel is tanh-squashed
            into (-pi, pi).
        """
        feat = self.cross_bev_attention(traj_feature, noisy_traj_points, bev_feature, bev_spatial_shape)

        agent_ctx, _ = self.cross_agent_attention(feat, agents_query, agents_query)
        feat = self.norm1(feat + self.dropout(agent_ctx))

        ego_ctx, _ = self.cross_ego_attention(feat, ego_query, ego_query)
        feat = self.norm2(feat + self.dropout1(ego_ctx))

        feat = self.norm3(self.ffn(feat))
        feat = self.time_modulation(feat, time_embed, global_cond=None, global_img=global_img)
        if self.feat_out:
            return feat

        poses_reg, poses_cls = self.task_decoder(feat)
        heading = StateSE2Index.HEADING
        poses_reg[..., :2] = poses_reg[..., :2] + noisy_traj_points
        poses_reg[..., heading] = poses_reg[..., heading].tanh() * np.pi
        return poses_reg, poses_cls, feat
|
|
|
|
|
def _get_clones(module, N): |
|
|
|
|
|
return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) |
|
|
|
|
|
|
|
|
class CustomTransformerDecoder(nn.Module):
    """Stack of deep-copied decoder layers that iteratively refine trajectory points.

    Every layer receives the same ``traj_feature`` and context tensors. The xy
    part of each layer's regressed poses (detached) is the next layer's input
    points. ``norm`` is accepted for API compatibility and ignored.
    """

    def __init__(
        self,
        decoder_layer,
        num_layers,
        norm=None,
    ):
        super().__init__()
        torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
        self.layers = nn.ModuleList(copy.deepcopy(decoder_layer) for _ in range(num_layers))
        self.num_layers = num_layers

    def forward(self,
                traj_feature,
                noisy_traj_points,
                bev_feature,
                bev_spatial_shape,
                agents_query,
                ego_query,
                time_embed,
                status_encoding,
                global_img=None):
        """Run all layers.

        :return: three per-layer lists: regressed poses, classification logits,
            and each layer's output feature.
        """
        regs, logits, feats = [], [], []
        points = noisy_traj_points
        for layer in self.layers:
            reg, cls_logit, layer_feat = layer(
                traj_feature, points, bev_feature, bev_spatial_shape,
                agents_query, ego_query, time_embed, status_encoding, global_img,
            )
            regs.append(reg)
            logits.append(cls_logit)
            feats.append(layer_feat)
            # Refined xy become the next layer's query points, without gradient.
            points = reg[..., :2].clone().detach()
        return regs, logits, feats
|
|
|
|
|
from navsim.evaluate.pdm_score import batched_pdm_score, upsample_traj_nd_torch, wrap_to_pi |
|
|
from hydra.utils import instantiate |
|
|
from navsim.planning.simulation.planner.pdm_planner.simulation.pdm_simulator import PDMSimulator |
|
|
from navsim.planning.simulation.planner.pdm_planner.scoring.pdm_scorer import PDMScorer |
|
|
from navsim.common.dataclasses import Trajectory |
|
|
from omegaconf import OmegaConf |
|
|
|
|
|
|
|
|
from nuplan.planning.utils.multithreading.worker_pool import Task, WorkerPool |
|
|
from nuplan.planning.utils.multithreading.worker_ray import RayDistributed |
|
|
|
|
|
import os |
|
|
import warnings |
|
|
# Silences every RuntimeWarning in the process once this module is imported,
# presumably to quiet numeric warnings during PDM scoring.
# NOTE(review): this also hides unrelated warnings.
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
|
import ray |
|
|
from time import time |
|
|
from copy import deepcopy |
|
|
|
|
|
|
|
|
|
|
|
@ray.remote(num_cpus=1)
class PDMWorker:
    """Ray actor that scores trajectories with a PDM simulator and scorer built from a config file."""

    def __init__(self, cfg_path):
        """Load the OmegaConf file at ``cfg_path`` and instantiate its simulator and scorer."""
        import os

        # Intended to cap native BLAS/OpenMP thread pools at one thread per
        # actor. Must run before the scoring import below.
        os.environ["OMP_NUM_THREADS"] = "1"
        os.environ["MKL_NUM_THREADS"] = "1"
        os.environ["OPENBLAS_NUM_THREADS"] = "1"
        import sys

        # NOTE(review): hard-coded, user-specific path so the actor process can
        # import ``navsim``; it will not exist on other machines.
        sys.path.append('/cpfs04/user/liuhaochen/AlgEngine_nuplan/navsim')
        from navsim.evaluate.pdm_score import batched_pdm_score as pdm_score_func

        self.batched_pdm_score = pdm_score_func
        self.cfg = OmegaConf.load(cfg_path)
        self.simulator = instantiate(self.cfg.simulator)
        self.scorer = instantiate(self.cfg.scorer)

    def compute(self, metric_cache, traj, token):
        """Score ``traj[i]`` against ``metric_cache[i]`` for every index i.

        ``metric_cache`` items are passed directly to the scoring function; no
        file loading happens here.

        The positional arguments are assumed to follow the order
        (metric_cache, model_trajectory, future_sampling, simulator, scorer);
        confirm against ``batched_pdm_score``.

        :return: list of results, each tagged with ``'token'`` = ``token[i]``.
        """
        results = []
        for i in range(len(metric_cache)):
            pdm_result = self.batched_pdm_score(
                metric_cache[i], traj[i], self.simulator.proposal_sampling, self.simulator, self.scorer
            )
            pdm_result['token'] = token[i]
            # Reset simulator/scorer between scenarios.
            self.simulator.clear()
            self.scorer.clear()
            results.append(pdm_result)
        return results
|
|
|
|
|
from ray.util.actor_pool import ActorPool |
|
|
from navsim.planning.utils.multithreading.worker_ray_no_torch import RayDistributedNoTorch |
|
|
from nuplan.planning.utils.multithreading.worker_utils import worker_map |
|
|
|
|
|
def init_ray_actor_pool(cfg_path, num_actors=7):
    """Start Ray if needed and return an ``ActorPool`` of ``num_actors`` PDMWorker actors.

    Ray's temp dir is suffixed with the ``RANK`` env var (``'0'`` when unset),
    presumably so that distributed ranks on one host do not collide.
    """
    if not ray.is_initialized():
        rank = os.environ.get('RANK', '0')
        ray.init(
            num_cpus=num_actors,
            _temp_dir=f"/tmp/ray_rank_{rank}",
            ignore_reinit_error=True,
            log_to_driver=False,
        )

    print(ray.available_resources())

    workers = []
    for _ in range(num_actors):
        workers.append(PDMWorker.remote(cfg_path))
    return ActorPool(workers)
|
|
|
|
|
import sys |
|
|
import lzma |
|
|
import pickle |
|
|
from pathlib import Path |
|
|
|
|
|
def ray_worker_pdm_func(args):
    """Score one worker's share of trajectories with the PDM simulator/scorer.

    :param args: list of dicts, as produced by
        ``TrajectoryHead.online_batch_pdm_calculation`` and split by ``worker_map``.
        Each dict holds list-like ``metric_cache`` (paths to lzma-compressed
        pickles), ``traj_tensor``, ``tokens`` and ``cfg`` entries. Entries are
        flattened across dicts; the config path is read from the first dict.
    :return: list of PDM results, each tagged with its ``'token'``.
    """
    cache_paths, trajectories, all_tokens = [], [], []
    for job in args:
        cache_paths.extend(job["metric_cache"])
    for job in args:
        trajectories.extend(job["traj_tensor"])
    for job in args:
        all_tokens.extend(job["tokens"])

    config = OmegaConf.load(args[0]['cfg'][0])
    simulator = instantiate(config.simulator)
    scorer = instantiate(config.scorer)

    scored = []
    for idx, cache_path in enumerate(cache_paths):
        # NOTE: pickle.load can execute arbitrary code; only use trusted cache files.
        with lzma.open(cache_path, "rb") as fh:
            cache = pickle.load(fh)
        result = batched_pdm_score(
            metric_cache=cache,
            model_trajectory=trajectories[idx],
            future_sampling=simulator.proposal_sampling,
            simulator=simulator,
            scorer=scorer,
        )
        result['token'] = all_tokens[idx]
        # Reset simulator/scorer between scenarios.
        simulator.clear()
        scorer.clear()
        scored.append(result)
    return scored
|
|
|
|
|
from mmdet3d_plugin.uniad.custom_modules.peft import (LoRALinear, ZeroAdapter, LoRACLAdapter, MOELoRALinear, |
|
|
LoRAMoECLAdapter, BayesianLinear, |
|
|
finetuning_detach, frozen_grad, peft_wrapper_forward, lora_wrapper, retreive_bayesian_lora_param) |
|
|
|
|
|
class TrajectoryHead(nn.Module): |
|
|
"""Trajectory prediction head.""" |
|
|
|
|
|
    def __init__(self, num_poses: int, d_ffn: int, d_model: int, plan_anchor_path: str,config: TransfuserConfig,
                 rejection_sampling=1,cost_learning=False, pdm_config_path='',online_pdm_infer=False,
                 train_sampling=1,rlft=False):
        """
        Initializes trajectory head.

        :param num_poses: number of (x,y,θ) poses to predict
        :param d_ffn: dimensionality of feed-forward network
        :param d_model: input dimensionality
        :param plan_anchor_path: path to a ``.npy`` anchor-trajectory array. The
            xy of every 5th pose starting at index 4 is kept as fixed anchors.
        :param config: TransfuserConfig forwarded to decoder layers and the loss.
        :param rejection_sampling: stored on the head.
        :param cost_learning: if True, build per-metric cost heads and set up
            Ray-based online PDM scoring. This loads the metric-cache metadata files.
        :param pdm_config_path: PDM simulator/scorer config path. It is stored as
            ``self.cfg`` and loaded later inside the Ray workers.
        :param online_pdm_infer: stored on the head.
        :param train_sampling: number of noisy samples per batch element in training.
        :param rlft: if True, add a separate cost decoder and LoRA adapters on the cost heads.
        """
        super(TrajectoryHead, self).__init__()

        self._num_poses = num_poses
        self._d_model = d_model
        self._d_ffn = d_ffn
        self.diff_loss_weight = 2.0
        self.ego_fut_mode = 20
        self.rejection_sampling = rejection_sampling
        self.train_sampling = train_sampling
        self.rlft = rlft

        # The model predicts clean samples ("sample" prediction type).
        self.diffusion_scheduler = DDIMScheduler(
            num_train_timesteps=1000,
            beta_schedule="scaled_linear",
            prediction_type="sample",
        )

        plan_anchor = np.load(plan_anchor_path)

        # Fixed (non-trainable) xy anchors; stored as a Parameter so they move
        # with the module and appear in the state_dict.
        self.plan_anchor = nn.Parameter(
            torch.tensor(plan_anchor[:, 4::5, :2], dtype=torch.float32),
            requires_grad=False,
        )
        # Input width 512 presumably = num_poses (8) x 64-d sine embedding; TODO confirm.
        self.plan_anchor_encoder = nn.Sequential(
            *linear_relu_ln(d_model, 1, 1,512),
            nn.Linear(d_model, d_model),
        )
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(d_model),
            nn.Linear(d_model, d_model * 4),
            nn.Mish(),
            nn.Linear(d_model * 4, d_model),
        )

        # NOTE: module creation order below determines init RNG and state_dict layout.
        diff_decoder_layer = CustomTransformerDecoderLayer(
            num_poses=num_poses,
            d_model=d_model,
            d_ffn=d_ffn,
            config=config,
            cost_learning=cost_learning,
        )
        self.diff_decoder = CustomTransformerDecoder(diff_decoder_layer, 2)
        if self.rlft:
            # One feature-only decoder layer per diffusion decoder layer; it feeds the cost heads.
            self.cost_diff_decoder = nn.ModuleList([CustomTransformerDecoderLayer(
                num_poses=num_poses,
                d_model=d_model,
                d_ffn=d_ffn,
                config=config,
                cost_learning=cost_learning,
                feat_out=True
            ) for _ in range(2)])

        self.loss_computer = LossComputer(config, self.train_sampling)
        self.cost_learning = cost_learning
        self.online_pdm_infer = online_pdm_infer

        if self.cost_learning:
            # One logit head per PDM sub-metric plus a direct overall-score head.
            self.heads = nn.ModuleDict({
                'noc': nn.Sequential(
                    nn.Linear(d_model, d_ffn), nn.ReLU(),nn.Linear(d_ffn, 1),
                ),
                'da': nn.Sequential(
                    nn.Linear(d_model, d_ffn), nn.ReLU(),nn.Linear(d_ffn, 1),
                ),
                'ttc': nn.Sequential(
                    nn.Linear(d_model, d_ffn), nn.ReLU(),nn.Linear(d_ffn, 1),
                ),
                'comfort': nn.Sequential(
                    nn.Linear(d_model, d_ffn), nn.ReLU(),nn.Linear(d_ffn, 1),
                ),
                'progress': nn.Sequential(
                    nn.Linear(d_model, d_ffn), nn.ReLU(),nn.Linear(d_ffn, 1),
                ),
                'pdms': nn.Sequential(
                    nn.Linear(d_model, d_ffn), nn.ReLU(),nn.Linear(d_ffn, 1),
                ),
            })
            if self.rlft:
                self.heads_lora = lora_wrapper(self.heads, LoRALinear,
                                               rank=16, alpha=1.0, dropout=0.1,num_task=4)
                finetuning_detach(self.heads)

            cfg_path = pdm_config_path

            print('online safe RL scoring, load from:',cfg_path)
            # Only the path is stored; ray_worker_pdm_func loads it with OmegaConf.
            cfg = cfg_path
            self.cfg = cfg
            # Batch elements per worker job in online_batch_pdm_calculation.
            self.chunk_size = 4
            self.worker = RayDistributedNoTorch(threads_per_node=16)
            self.init_metric_cache_paths()
|
|
|
|
|
|
|
|
def init_metric_cache_paths(self): |
|
|
""" |
|
|
Helper function to load all cache file paths from folder. |
|
|
:param cache_path: directory of cache folder |
|
|
:return: dictionary of token and file path |
|
|
""" |
|
|
|
|
|
metadata_file = '/xxx/navsim_metric_cache_navtrain2_metadata_node_0.csv' |
|
|
with open(str(metadata_file), "r") as f: |
|
|
cache_paths = f.read().splitlines()[1:] |
|
|
self.metric_cache_dict = {cache_path.split("/")[-2]: cache_path for cache_path in cache_paths} |
|
|
|
|
|
vidmetadata_file = '/xxx/navsim_metric_cache_navtrain2/metadata/navsim_metric_cache_navtrain2_metadata_node_0.csv' |
|
|
with open(str(vidmetadata_file), "r") as f: |
|
|
vid_cache_paths = f.read().splitlines()[1:] |
|
|
self.metric_cache_dict.update({cache_path.split("/")[-2]: cache_path for cache_path in vid_cache_paths}) |
|
|
|
|
|
print(len(list(self.metric_cache_dict.keys()))) |
|
|
|
|
|
def norm_odo(self, odo_info_fut): |
|
|
odo_info_fut_x = odo_info_fut[..., 0:1] |
|
|
odo_info_fut_y = odo_info_fut[..., 1:2] |
|
|
odo_info_fut_head = odo_info_fut[..., 2:3] |
|
|
|
|
|
odo_info_fut_x = 2*(odo_info_fut_x + 1.2)/56.9 -1 |
|
|
odo_info_fut_y = 2*(odo_info_fut_y + 20)/46 -1 |
|
|
odo_info_fut_head = 2*(odo_info_fut_head + 2)/3.9 -1 |
|
|
return torch.cat([odo_info_fut_x, odo_info_fut_y, odo_info_fut_head], dim=-1) |
|
|
def denorm_odo(self, odo_info_fut): |
|
|
odo_info_fut_x = odo_info_fut[..., 0:1] |
|
|
odo_info_fut_y = odo_info_fut[..., 1:2] |
|
|
odo_info_fut_head = odo_info_fut[..., 2:3] |
|
|
|
|
|
odo_info_fut_x = (odo_info_fut_x + 1)/2 * 56.9 - 1.2 |
|
|
odo_info_fut_y = (odo_info_fut_y + 1)/2 * 46 - 20 |
|
|
odo_info_fut_head = (odo_info_fut_head + 1)/2 * 3.9 - 2 |
|
|
return torch.cat([odo_info_fut_x, odo_info_fut_y, odo_info_fut_head], dim=-1) |
|
|
|
|
|
def forward(self, ego_query, agents_query, bev_feature,bev_spatial_shape, |
|
|
status_encoding, targets=None,global_img=None, |
|
|
metric_cache=None,tokens=None) -> Dict[str, torch.Tensor]: |
|
|
"""Torch module forward pass.""" |
|
|
if self.training: |
|
|
return self.forward_train(ego_query, agents_query, |
|
|
bev_feature,bev_spatial_shape,status_encoding, |
|
|
targets,global_img, metric_cache,tokens) |
|
|
else: |
|
|
with torch.no_grad(): |
|
|
return self.forward_test(ego_query, agents_query, |
|
|
bev_feature,bev_spatial_shape,status_encoding,global_img,tokens) |
|
|
|
|
|
def online_cost_loss(self, cost_dict, gt_pdm_score, idx=0): |
|
|
loss = {} |
|
|
loss[f'loss.noc_{idx}'] = 3*torch.mean( |
|
|
F.binary_cross_entropy_with_logits(cost_dict['noc'], gt_pdm_score['no_at_fault_collisions'])) |
|
|
|
|
|
loss[f'loss.da_{idx}'] = 3*torch.mean( |
|
|
F.binary_cross_entropy_with_logits(cost_dict['da'], gt_pdm_score['drivable_area_compliance'])) |
|
|
|
|
|
loss[f'loss.ttc_{idx}'] = 2*torch.mean( |
|
|
F.binary_cross_entropy_with_logits(cost_dict['ttc'], gt_pdm_score['time_to_collision_within_bound'])) |
|
|
|
|
|
loss[f'loss.ep_{idx}'] = torch.mean( |
|
|
F.binary_cross_entropy_with_logits(cost_dict['progress'], gt_pdm_score['ego_progress'])) |
|
|
|
|
|
loss[f'loss.comfort_{idx}'] = torch.mean( |
|
|
F.binary_cross_entropy_with_logits( |
|
|
cost_dict['comfort'], gt_pdm_score['comfort'])) |
|
|
|
|
|
score = (5 * cost_dict['ttc'].sigmoid() + 5 * cost_dict['progress'].sigmoid() + 2*cost_dict['comfort'].sigmoid() ) / 12 |
|
|
score = cost_dict['noc'].sigmoid() * cost_dict['da'].sigmoid() * score |
|
|
|
|
|
loss[f'loss.RL_score_{idx}'] = 5 * torch.mean( |
|
|
F.binary_cross_entropy( |
|
|
score, gt_pdm_score['score']) |
|
|
) |
|
|
|
|
|
loss[f'loss.score_{idx}'] = 5 * torch.mean( |
|
|
F.binary_cross_entropy_with_logits( |
|
|
cost_dict['pdms'], gt_pdm_score['score']) |
|
|
) |
|
|
|
|
|
return loss |
|
|
|
|
|
|
|
|
def online_batch_pdm_calculation(self, traj_tensor, metric_cache,traj_len,tokens): |
|
|
b, n, t, d = traj_tensor.shape |
|
|
device = traj_tensor.device |
|
|
|
|
|
traj_tensor = traj_tensor.view(b*n, t, d) |
|
|
zeros = torch.zeros((b*n, 1, d)).to(device) |
|
|
full_traj_tensor = torch.cat([zeros, traj_tensor], dim=1) |
|
|
interped_traj_tensor = upsample_traj_nd_torch(full_traj_tensor, dt_in=0.5, dt_out=0.1) |
|
|
interped_traj_tensor = interped_traj_tensor.reshape(b, n, -1, d) |
|
|
interped_traj_tensor[..., -1] = wrap_to_pi(interped_traj_tensor[..., -1]) |
|
|
|
|
|
traj_tensor = interped_traj_tensor.detach().cpu().numpy() |
|
|
|
|
|
batched_pdm_results = dict( |
|
|
no_at_fault_collisions=[], |
|
|
drivable_area_compliance=[], |
|
|
ego_progress=[], |
|
|
time_to_collision_within_bound=[], |
|
|
comfort=[], |
|
|
driving_direction_compliance=[], |
|
|
score=[], |
|
|
) |
|
|
mc_chunk, traj_chunk, tok_chunk = [],[],[] |
|
|
worker_args = [] |
|
|
|
|
|
for i in range(0, len(traj_tensor), self.chunk_size): |
|
|
traj_chunk.append(traj_tensor[i:i+self.chunk_size]) |
|
|
tok_chunk.append(tokens[i:i+self.chunk_size]) |
|
|
if self.training: |
|
|
worker_args.append( |
|
|
dict( |
|
|
metric_cache = [self.metric_cache_dict[t] for t in tokens[i:i+self.chunk_size]], |
|
|
traj_tensor=traj_tensor[i:i+self.chunk_size], |
|
|
tokens=tokens[i:i+self.chunk_size], |
|
|
cfg=[self.cfg]*self.chunk_size |
|
|
)) |
|
|
else: |
|
|
buf_cache_token = [] |
|
|
for t in tokens[i:i+self.chunk_size]: |
|
|
if t not in self.metric_cache_dict.keys(): |
|
|
buf_cache_token.append(self.metric_cache_dict[list(self.metric_cache_dict.keys())[0]]) |
|
|
else: |
|
|
buf_cache_token.append(self.metric_cache_dict[t]) |
|
|
worker_args.append( |
|
|
dict( |
|
|
metric_cache =buf_cache_token, |
|
|
traj_tensor=traj_tensor[i:i+self.chunk_size], |
|
|
tokens=tokens[i:i+self.chunk_size], |
|
|
cfg=[self.cfg]*self.chunk_size |
|
|
)) |
|
|
|
|
|
args = zip(mc_chunk, traj_chunk, tok_chunk) |
|
|
|
|
|
|
|
|
pdm_results = worker_map(self.worker, ray_worker_pdm_func, worker_args) |
|
|
|
|
|
pdm_results_dict = { |
|
|
pdm['token']:pdm for pdm in pdm_results |
|
|
} |
|
|
for i in range(b): |
|
|
pdm_res = pdm_results_dict[tokens[i]] |
|
|
for k in batched_pdm_results.keys(): |
|
|
batched_pdm_results[k].append(pdm_res[k]) |
|
|
|
|
|
|
|
|
ret_batch_pdm_results_list = [dict() for _ in range(traj_len)] |
|
|
|
|
|
for k,v in batched_pdm_results.items(): |
|
|
buf_value = torch.from_numpy( |
|
|
np.stack(v, axis=0) |
|
|
).to(device) |
|
|
if self.train_sampling > 1: |
|
|
buf_value = buf_value.view(b, self.train_sampling, n//self.train_sampling) |
|
|
buf_value = buf_value.view(b*self.train_sampling, -1) |
|
|
|
|
|
buf_value = buf_value.reshape(b*self.train_sampling, traj_len, -1) |
|
|
for i in range(traj_len): |
|
|
ret_batch_pdm_results_list[i][k] = buf_value[:, i] |
|
|
|
|
|
return ret_batch_pdm_results_list |
|
|
|
|
|
def expand_and_reshape(self, x, k=4): |
|
|
B, rest = x.shape[0], x.shape[1:] |
|
|
x = x.unsqueeze(1).repeat(1, k, *([1] * len(rest))) |
|
|
|
|
|
return x.view(B*k, *rest) |
|
|
|
|
|
def forward_train(self, ego_query,agents_query,bev_feature,bev_spatial_shape,status_encoding, targets=None,global_img=None, |
|
|
metric_cache=None,tokens=None) -> Dict[str, torch.Tensor]: |
|
|
bs = ego_query.shape[0] |
|
|
device = ego_query.device |
|
|
|
|
|
plan_anchor = self.plan_anchor.unsqueeze(0).repeat(bs,1,1,1) |
|
|
odo_info_fut = self.norm_odo(plan_anchor) |
|
|
|
|
|
if self.train_sampling > 1: |
|
|
odo_info_fut = self.expand_and_reshape(odo_info_fut, k=self.train_sampling) |
|
|
agents_query = self.expand_and_reshape(agents_query, k=self.train_sampling) |
|
|
ego_query = self.expand_and_reshape(ego_query, k=self.train_sampling) |
|
|
bev_feature = self.expand_and_reshape(bev_feature, k=self.train_sampling) |
|
|
status_encoding = self.expand_and_reshape(status_encoding, k=self.train_sampling) |
|
|
bs = bs*self.train_sampling |
|
|
|
|
|
timesteps = torch.randint( |
|
|
0, 50, |
|
|
(bs,), device=device |
|
|
) |
|
|
noise = torch.randn(odo_info_fut.shape, device=device) |
|
|
noisy_traj_points = self.diffusion_scheduler.add_noise( |
|
|
original_samples=odo_info_fut, |
|
|
noise=noise, |
|
|
timesteps=timesteps, |
|
|
).float() |
|
|
noisy_traj_points = torch.clamp(noisy_traj_points, min=-1, max=1) |
|
|
noisy_traj_points = self.denorm_odo(noisy_traj_points) |
|
|
|
|
|
ego_fut_mode = noisy_traj_points.shape[1] |
|
|
|
|
|
traj_pos_embed = gen_sineembed_for_position(noisy_traj_points,hidden_dim=64) |
|
|
traj_pos_embed = traj_pos_embed.flatten(-2) |
|
|
traj_feature = self.plan_anchor_encoder(traj_pos_embed) |
|
|
traj_feature = traj_feature.view(bs,ego_fut_mode,-1) |
|
|
|
|
|
time_embed = self.time_mlp(timesteps) |
|
|
time_embed = time_embed.view(bs,1,-1) |
|
|
|
|
|
|
|
|
poses_reg_list, poses_cls_list, traj_feat_list = self.diff_decoder( |
|
|
traj_feature, noisy_traj_points, bev_feature, bev_spatial_shape, |
|
|
agents_query, ego_query, time_embed, status_encoding,global_img) |
|
|
|
|
|
|
|
|
trajectory_loss_dict = {} |
|
|
ret_traj_loss = 0 |
|
|
|
|
|
if self.cost_learning: |
|
|
|
|
|
cost_dict_list = [] |
|
|
for idx, feat in enumerate(traj_feat_list): |
|
|
cost_dict =dict() |
|
|
if self.rlft: |
|
|
feat = self.cost_diff_decoder[idx]( |
|
|
feat.detach(), poses_reg_list[idx][..., :2].detach(), |
|
|
bev_feature.detach(), bev_spatial_shape, |
|
|
agents_query.detach(), ego_query.detach(), |
|
|
time_embed.detach(), status_encoding.detach() |
|
|
) |
|
|
for k in self.heads.keys(): |
|
|
cost_dict[k] = peft_wrapper_forward(feat, self.heads[k], self.heads_lora['lora_'+k])[..., 0] |
|
|
else: |
|
|
for k in self.heads.keys(): |
|
|
cost_dict[k] = self.heads[k](feat)[..., 0] |
|
|
cost_dict_list.append(cost_dict) |
|
|
|
|
|
b, num_mode, ts, d = poses_reg_list[0].shape |
|
|
stack_trajs = torch.cat(poses_reg_list, dim=1) |
|
|
if self.train_sampling > 1: |
|
|
stack_trajs = stack_trajs.view(bs//self.train_sampling, self.train_sampling, -1, ts, d) |
|
|
stack_trajs = stack_trajs.view(bs//self.train_sampling, -1, ts, d) |
|
|
|
|
|
batched_pdm_score_list = self.online_batch_pdm_calculation( |
|
|
stack_trajs, None, len(poses_reg_list),tokens |
|
|
) |
|
|
|
|
|
cost_dict_loss = dict() |
|
|
|
|
|
for idx, (poses_reg, poses_cls) in enumerate(zip(poses_reg_list, poses_cls_list)): |
|
|
trajectory_loss = self.loss_computer(poses_reg, poses_cls, targets, plan_anchor) |
|
|
trajectory_loss_dict[f"loss.planning.{idx}"] = trajectory_loss |
|
|
ret_traj_loss += trajectory_loss |
|
|
if self.cost_learning: |
|
|
cost_dict_loss.update( |
|
|
self.online_cost_loss( |
|
|
cost_dict_list[idx], batched_pdm_score_list[idx], idx |
|
|
) |
|
|
) |
|
|
|
|
|
mode_idx = poses_cls_list[-1].argmax(dim=-1) |
|
|
mode_idx = mode_idx[...,None,None,None].repeat(1,1,self._num_poses,3) |
|
|
best_reg = torch.gather(poses_reg_list[-1], 1, mode_idx).squeeze(1) |
|
|
|
|
|
ret_dict = dict(trajectory=best_reg) |
|
|
ret_dict.update(trajectory_loss_dict) |
|
|
if self.cost_learning: |
|
|
return {'cost_loss':cost_dict_loss,"trajectory": best_reg,"trajectory_loss":ret_traj_loss} |
|
|
return {"trajectory": best_reg,"trajectory_loss":ret_traj_loss} |
|
|
|
|
|
def cal_test_rl_score(self, cost_dict):
    """Collapse the per-metric cost-head logits into a single trajectory score.

    Every logit is mapped to a probability with ``sigmoid``. ``'ttc'``,
    ``'progress'`` and ``'comfort'`` are averaged with weights 5, 5 and 2
    (normalized by 12). That average is then gated multiplicatively by the
    ``'noc'`` and ``'da'`` probabilities, which are cast to float32 first.

    NOTE(review): this appears to mirror the NAVSIM PDM-score aggregation
    (weighted TTC/progress/comfort terms, multiplicative collision and
    drivable-area terms) -- confirm against the metric definition.

    :param cost_dict: mapping that holds at least the keys ``'ttc'``,
        ``'progress'``, ``'comfort'``, ``'noc'`` and ``'da'``. The values are
        logit tensors that must broadcast against each other.
    :return: tensor of fused scores in [0, 1] with the broadcast shape of the inputs.
    """
    # Soft (weighted-average) part of the score.
    weighted_sum = 5 * cost_dict['ttc'].sigmoid()
    weighted_sum = weighted_sum + 5 * cost_dict['progress'].sigmoid()
    weighted_sum = weighted_sum + 2 * cost_dict['comfort'].sigmoid()
    soft_score = weighted_sum / 12

    # Multiplicative gate. The .float() casts promote the result to float32.
    hard_gate = cost_dict['noc'].sigmoid().float() * cost_dict['da'].sigmoid().float()
    return hard_gate * soft_score
|
|
|
|
|
def forward_test(
    self,
    ego_query,
    agents_query,
    bev_feature,
    bev_spatial_shape,
    status_encoding,
    global_img,
    tokens,
    num_samples: int = 100,
) -> Dict[str, torch.Tensor]:
    """Decode many lightly-noised anchor sets and return one trajectory per batch element.

    Each of the ``num_samples`` rounds does the following:

    - adds scheduler noise, at a random timestep in ``[0, 5)``, to the normalized
      plan anchors;
    - decodes them once with ``self.diff_decoder``;
    - keeps, per batch element, the mode with the highest classification logit
      from the decoder's last output.

    The final pick among those candidates uses these rules, where later rules
    override earlier ones:

    1. highest softmax classification probability;
    2. if ``self.cost_learning``: highest mean of ``cal_test_rl_score`` and the
       sigmoid of the ``'pdms'`` cost head;
    3. if ``self.online_pdm_infer``: highest ``'score'`` returned by
       ``self.online_batch_pdm_calculation``.

    :param ego_query: ego query. ``shape[0]`` is used as the batch size, and its
        device is used for all sampled tensors.
    :param agents_query: agent queries, forwarded to the decoders.
    :param bev_feature: BEV feature map, forwarded to the decoders.
    :param bev_spatial_shape: spatial shape of ``bev_feature``, forwarded to the decoders.
    :param status_encoding: ego-status encoding, forwarded to the decoders.
    :param global_img: forwarded to ``self.diff_decoder`` only.
    :param tokens: scene tokens, used only by the online PDM scorer.
    :param num_samples: number of noisy decoding rounds, i.e. candidates per
        batch element. Must be >= 1.
    :return: ``{"trajectory": best_reg}`` with one selected trajectory per batch element.
    :raises ValueError: if ``num_samples`` < 1 (``torch.cat`` would otherwise
        fail on an empty list).
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    bs = ego_query.shape[0]
    device = ego_query.device
    # (bs, 1) row index. Paired with a (bs, 1) mode index, it selects exactly one
    # mode per batch element.
    batch_idx = torch.arange(bs, device=device)[:, None]

    plan_anchor = self.plan_anchor.unsqueeze(0).repeat(bs, 1, 1, 1)
    org_img = self.norm_odo(plan_anchor)
    # The scheduler configuration is identical for every round, so set it once.
    self.diffusion_scheduler.set_timesteps(1000, device)

    trajs, scores, costs = [], [], []
    for _ in range(num_samples):
        # Perturb the anchors with only a few steps of forward-diffusion noise.
        timesteps = torch.randint(0, 5, (bs,), device=device)
        noise = torch.randn(org_img.shape, device=device)
        img = self.diffusion_scheduler.add_noise(
            original_samples=org_img, noise=noise, timesteps=timesteps
        )
        ego_fut_mode = img.shape[1]

        noisy_traj_points = self.denorm_odo(torch.clamp(img, min=-1, max=1))
        traj_pos_embed = gen_sineembed_for_position(noisy_traj_points, hidden_dim=64).flatten(-2)
        traj_feature = self.plan_anchor_encoder(traj_pos_embed).view(bs, ego_fut_mode, -1)
        # `timesteps` is always a 1-D (bs,) tensor here, so no coercion is needed.
        time_embed = self.time_mlp(timesteps).view(bs, 1, -1)

        poses_reg_list, poses_cls_list, traj_feat_list = self.diff_decoder(
            traj_feature, noisy_traj_points, bev_feature, bev_spatial_shape,
            agents_query, ego_query, time_embed, status_encoding, global_img,
        )
        poses_reg = poses_reg_list[-1]
        poses_cls = poses_cls_list[-1]

        # keepdim=True gives (bs, 1). A plain argmax(1) would be (bs,), which
        # broadcasts against batch_idx to (bs, bs) and mixes trajectories
        # across batch elements.
        best_mode = poses_cls.argmax(1, keepdim=True)
        trajs.append(poses_reg[batch_idx, best_mode])
        scores.append(poses_cls.softmax(-1)[batch_idx, best_mode])

        # Cost heads only exist when cost learning is enabled; there is nothing
        # to collect otherwise.
        if self.cost_learning:
            # Only the last decoder layer's costs take part in the selection,
            # so only that layer's cost heads are evaluated.
            last = len(traj_feat_list) - 1
            feat = traj_feat_list[last]
            if self.rlft:
                feat = self.cost_diff_decoder[last](
                    feat.detach(), poses_reg_list[last][..., :2].detach(),
                    bev_feature.detach(), bev_spatial_shape,
                    agents_query.detach(), ego_query.detach(),
                    time_embed.detach(), status_encoding.detach(),
                )
                cost_dict = {
                    h: peft_wrapper_forward(feat, self.heads[h], self.heads_lora['lora_' + h])[..., 0]
                    for h in self.heads.keys()
                }
            else:
                cost_dict = {h: self.heads[h](feat)[..., 0] for h in self.heads.keys()}
            costs.append({k: v[batch_idx, best_mode] for k, v in cost_dict.items()})

    trajs = torch.cat(trajs, dim=1)
    scores = torch.cat(scores, dim=1)
    mode_idx = torch.argmax(scores, dim=1)

    if self.cost_learning:
        fused = [(self.cal_test_rl_score(cost) + cost['pdms'].sigmoid()) / 2 for cost in costs]
        mode_idx = torch.argmax(torch.cat(fused, dim=1), dim=1)

    if self.online_pdm_infer:
        batched_pdm_score_list = self.online_batch_pdm_calculation(trajs, None, 1, tokens)
        mode_idx = torch.argmax(batched_pdm_score_list[0]['score'], dim=1)

    best_reg = trajs[torch.arange(trajs.shape[0]), mode_idx]
    return {"trajectory": best_reg}