# ---------------------------------------------------------------------------------#
# UniAD: Planning-oriented Autonomous Driving (https://arxiv.org/abs/2212.10156)    #
# Source code: https://github.com/OpenDriveLab/UniAD                                #
# Copyright (c) OpenDriveLab. All rights reserved.                                  #
# ---------------------------------------------------------------------------------#
import copy
import sys
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.schedulers import DDIMScheduler
from mmdet.models import DETECTORS
from mmdet3d.models.detectors.mvx_two_stage import MVXTwoStageDetector, Base3DDetector

from navsim.agents.diffusiondrive.modules.blocks import (
    linear_relu_ln,
    bias_init_with_prob,
    gen_sineembed_for_position,
    GridSampleCrossBEVAttention,
)
from navsim.agents.diffusiondrive.modules.conditional_unet1d import ConditionalUnet1D, SinusoidalPosEmb
from navsim.agents.diffusiondrive.modules.multimodal_loss import LossComputer
from navsim.agents.diffusiondrive.transfuser_backbone import TransfuserBackbone
from navsim.agents.diffusiondrive.transfuser_config import TransfuserConfig
from navsim.agents.diffusiondrive.transfuser_features import BoundingBox2DIndex
from navsim.agents.diffusiondrive.transfuser_loss import transfuser_loss
from navsim.common.enums import StateSE2Index


@DETECTORS.register_module()
class V2TransfuserModel(Base3DDetector):
    def __init__(
        self,
        rejection_sampling=1,
        train_sampling=1,
        online_cost_learning=False,
        pdm_config_path='',
        image_architecture=None,
        bkb_path=None,
        online_pdm_infer=False,
        rlft=False,
        *args,
        **kwargs,
    ):
        """
        Initializes TransFuser torch module.
        :param config: global config dataclass of TransFuser.
        """
        super().__init__()
        config = TransfuserConfig()
        if image_architecture is not None:
            config.image_architecture = image_architecture
            config.bkb_path = bkb_path
        self.rejection_sampling = rejection_sampling
        self.online_cost_learning = online_cost_learning
        self.online_pdm_infer = online_pdm_infer
        self.rlft = rlft

        self._query_splits = [
            1,
            config.num_bounding_boxes,
        ]
        self._config = config
        self._backbone = TransfuserBackbone(config)

        self._keyval_embedding = nn.Embedding(8**2 + 1, config.tf_d_model)  # 8x8 feature grid + trajectory
        self._query_embedding = nn.Embedding(sum(self._query_splits), config.tf_d_model)

        # usually, the BEV features are variable in size.
        self._bev_downscale = nn.Conv2d(512, config.tf_d_model, kernel_size=1)
        self._status_encoding = nn.Linear(4 + 2 + 2, config.tf_d_model)
        # self._status_encoding = nn.Linear(4 + 2, config.tf_d_model)

        self._bev_semantic_head = nn.Sequential(
            nn.Conv2d(
                config.bev_features_channels,
                config.bev_features_channels,
                kernel_size=(3, 3),
                stride=1,
                padding=(1, 1),
                bias=True,
            ),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                config.bev_features_channels,
                config.num_bev_classes,
                kernel_size=(1, 1),
                stride=1,
                padding=0,
                bias=True,
            ),
            nn.Upsample(
                size=(config.lidar_resolution_height // 2, config.lidar_resolution_width),
                mode="bilinear",
                align_corners=False,
            ),
        )

        tf_decoder_layer = nn.TransformerDecoderLayer(
            d_model=config.tf_d_model,
            nhead=config.tf_num_head,
            dim_feedforward=config.tf_d_ffn,
            dropout=config.tf_dropout,
            batch_first=True,
        )
        self._tf_decoder = nn.TransformerDecoder(tf_decoder_layer, config.tf_num_layers)
        self._agent_head = AgentHead(
            num_agents=config.num_bounding_boxes,
            d_ffn=config.tf_d_ffn,
            d_model=config.tf_d_model,
        )
        self._trajectory_head = TrajectoryHead(
            num_poses=config.trajectory_sampling.num_poses,
            d_ffn=config.tf_d_ffn,
            d_model=config.tf_d_model,
            plan_anchor_path=config.plan_anchor_path,
            config=config,
            rejection_sampling=rejection_sampling,
            cost_learning=online_cost_learning,
            pdm_config_path=pdm_config_path,
            online_pdm_infer=online_pdm_infer,
            train_sampling=train_sampling,
            rlft=rlft,
        )
        self.bev_proj = nn.Sequential(
            *linear_relu_ln(config.tf_d_model, 1, 1, config.tf_d_model + 64),
        )

    def aug_test(self, *args, **kwargs):
        raise NotImplementedError("aug_test not implemented")

    def extract_feat(self, *args, **kwargs):
        raise NotImplementedError("extract_feat not implemented")

    def simple_test(self, *args, **kwargs):
        raise NotImplementedError("simple_test not implemented")

    def forward(self, camera_feature, lidar_feature, status_feature, **kwargs):
        """Torch module forward pass."""
        # camera_feature: torch.Tensor = features["camera_feature"]
        # lidar_feature: torch.Tensor = features["lidar_feature"]
        # status_feature: torch.Tensor = features["status_feature"]
        metric_cache = None
        targets = None
        tokens = kwargs['token']
        if self.training:
            targets = dict()
            targets['trajectory'] = kwargs['trajectory']
            targets['agent_states'] = kwargs['agent_states']
            targets['agent_labels'] = kwargs['agent_labels']
            targets['bev_semantic_map'] = kwargs['bev_semantic_map']
            if 'metric_cache' in kwargs.keys():
                metric_cache = kwargs['metric_cache']
        if self.rlft:
            self._backbone.eval()

        batch_size = status_feature.shape[0]

        bev_feature_upscale, bev_feature, _ = self._backbone(camera_feature, lidar_feature)
        cross_bev_feature = bev_feature_upscale
        bev_spatial_shape = bev_feature_upscale.shape[2:]
        concat_cross_bev_shape = bev_feature.shape[2:]

        bev_feature = self._bev_downscale(bev_feature).flatten(-2, -1)
        bev_feature = bev_feature.permute(0, 2, 1)
        status_encoding = self._status_encoding(status_feature)  # [..., :6]

        keyval = torch.cat([bev_feature, status_encoding[:, None]], dim=1)
        keyval += self._keyval_embedding.weight[None, ...]
        concat_cross_bev = keyval[:, :-1].permute(0, 2, 1).contiguous().view(
            batch_size, -1, concat_cross_bev_shape[0], concat_cross_bev_shape[1]
        )
        # upsample to the same shape as bev_feature_upscale
        concat_cross_bev = F.interpolate(
            concat_cross_bev, size=bev_spatial_shape, mode='bilinear', align_corners=False
        )
        # concatenate concat_cross_bev and cross_bev_feature
        cross_bev_feature = torch.cat([concat_cross_bev, cross_bev_feature], dim=1)
        cross_bev_feature = self.bev_proj(cross_bev_feature.flatten(-2, -1).permute(0, 2, 1))
        cross_bev_feature = cross_bev_feature.permute(0, 2, 1).contiguous().view(
            batch_size, -1, bev_spatial_shape[0], bev_spatial_shape[1]
        )

        query = self._query_embedding.weight[None, ...].repeat(batch_size, 1, 1)
        query_out = self._tf_decoder(query, keyval)

        bev_semantic_map = self._bev_semantic_head(bev_feature_upscale)
        trajectory_query, agents_query = query_out.split(self._query_splits, dim=1)

        output = dict()

        trajectory = self._trajectory_head(
            trajectory_query, agents_query, cross_bev_feature, bev_spatial_shape,
            status_encoding[:, None], targets=targets, global_img=None,
            metric_cache=metric_cache, tokens=tokens,
        )
        if self.rlft and self.training:
            loss_dict = trajectory['cost_loss']
            return loss_dict
        output.update(trajectory)

        agents = self._agent_head(agents_query)
        if self.training:
            output.update(agents)
            output["bev_semantic_map"] = bev_semantic_map
            # compute loss:
            loss_dict = transfuser_loss(targets, output, self._config)
            return loss_dict

        zeros = torch.zeros(trajectory['trajectory'].shape[0]).to(trajectory['trajectory'].device)  # idle
        pdm_dict = {
            "no_at_fault_collisions": zeros,
            "drivable_area_compliance": zeros,
            "ego_progress": zeros,
            "time_to_collision_within_bound": zeros,
            "comfort": zeros,
            "score": zeros,
            # 'chosen_ind': []
            # 'token': []
        }
        output.update(pdm_dict)
        return output
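
# Illustrative sketch (hypothetical helper, not part of the original pipeline):
# how the decoder output is separated into one trajectory query and
# `num_bounding_boxes` agent queries via `Tensor.split`, mirroring
# `query_out.split(self._query_splits, dim=1)` above. The sizes are assumptions
# chosen for demonstration only.
def _demo_query_split():
    batch_size, d_model, num_boxes = 2, 256, 30  # assumed sizes
    query_out = torch.randn(batch_size, 1 + num_boxes, d_model)
    trajectory_query, agents_query = query_out.split([1, num_boxes], dim=1)
    assert trajectory_query.shape == (batch_size, 1, d_model)
    assert agents_query.shape == (batch_size, num_boxes, d_model)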
class AgentHead(nn.Module):
    """Bounding box prediction head."""

    def __init__(
        self,
        num_agents: int,
        d_ffn: int,
        d_model: int,
    ):
        """
        Initializes prediction head.
        :param num_agents: maximum number of agents to predict
        :param d_ffn: dimensionality of feed-forward network
        :param d_model: input dimensionality
        """
        super(AgentHead, self).__init__()

        self._num_objects = num_agents
        self._d_model = d_model
        self._d_ffn = d_ffn

        self._mlp_states = nn.Sequential(
            nn.Linear(self._d_model, self._d_ffn),
            nn.ReLU(),
            nn.Linear(self._d_ffn, BoundingBox2DIndex.size()),
        )
        self._mlp_label = nn.Sequential(
            nn.Linear(self._d_model, 1),
        )

    def forward(self, agent_queries) -> Dict[str, torch.Tensor]:
        """Torch module forward pass."""
        agent_states = self._mlp_states(agent_queries)
        agent_states[..., BoundingBox2DIndex.POINT] = agent_states[..., BoundingBox2DIndex.POINT].tanh() * 32
        agent_states[..., BoundingBox2DIndex.HEADING] = agent_states[..., BoundingBox2DIndex.HEADING].tanh() * np.pi
        agent_labels = self._mlp_label(agent_queries).squeeze(dim=-1)
        return {"agent_states": agent_states, "agent_labels": agent_labels}


class DiffMotionPlanningRefinementModule(nn.Module):
    def __init__(
        self,
        embed_dims=256,
        ego_fut_ts=8,
        ego_fut_mode=20,
        if_zeroinit_reg=True,
        cost_learning=False,
    ):
        super(DiffMotionPlanningRefinementModule, self).__init__()
        self.embed_dims = embed_dims
        self.ego_fut_ts = ego_fut_ts
        self.ego_fut_mode = ego_fut_mode
        self.plan_cls_branch = nn.Sequential(
            *linear_relu_ln(embed_dims, 1, 2),
            nn.Linear(embed_dims, 1),
        )
        self.plan_reg_branch = nn.Sequential(
            nn.Linear(embed_dims, embed_dims),
            nn.ReLU(),
            nn.Linear(embed_dims, embed_dims),
            nn.ReLU(),
            nn.Linear(embed_dims, ego_fut_ts * 3),
        )
        # NOTE: the `if_zeroinit_reg` constructor argument is overridden here, so
        # zero-initialization of the regression branch is always disabled.
        self.if_zeroinit_reg = False
        self.cost_learning = cost_learning
        self.init_weight()

    def init_weight(self):
        if self.if_zeroinit_reg:
            nn.init.constant_(self.plan_reg_branch[-1].weight, 0)
            nn.init.constant_(self.plan_reg_branch[-1].bias, 0)

        bias_init = bias_init_with_prob(0.01)
        nn.init.constant_(self.plan_cls_branch[-1].bias, bias_init)

    def forward(
        self,
        traj_feature,
    ):
        bs, ego_fut_mode, _ = traj_feature.shape
        # 6. get final prediction
        # traj_feature = traj_feature.view(bs, ego_fut_mode, -1)
        plan_cls = self.plan_cls_branch(traj_feature).squeeze(-1)
        traj_delta = self.plan_reg_branch(traj_feature)
        plan_reg = traj_delta.reshape(bs, ego_fut_mode, self.ego_fut_ts, 3)
        cost_dict = dict()
        # if self.cost_learning:
        #     for k in self.heads.keys():
        #         cost_dict[k] = self.heads[k](traj_feature)[..., 0]

        return plan_reg, plan_cls
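
# Illustrative sketch (assumption: an mmcv-style `bias_init_with_prob`): the
# classification bias is set so the initial sigmoid output equals a small prior
# probability p, i.e. bias = -log((1 - p) / p), which keeps early training stable
# when positive modes are rare.
def _demo_cls_bias_prior(p: float = 0.01) -> float:
    bias = -np.log((1.0 - p) / p)
    # sigmoid(bias) recovers p up to floating-point error
    assert abs(torch.sigmoid(torch.tensor(bias)).item() - p) < 1e-6
    return bias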
class ModulationLayer(nn.Module):
    def __init__(self, embed_dims: int, condition_dims: int):
        super(ModulationLayer, self).__init__()
        self.if_zeroinit_scale = False
        self.embed_dims = embed_dims
        self.scale_shift_mlp = nn.Sequential(
            nn.Mish(),
            nn.Linear(condition_dims, embed_dims * 2),
        )
        self.init_weight()

    def init_weight(self):
        if self.if_zeroinit_scale:
            nn.init.constant_(self.scale_shift_mlp[-1].weight, 0)
            nn.init.constant_(self.scale_shift_mlp[-1].bias, 0)

    def forward(
        self,
        traj_feature,
        time_embed,
        global_cond=None,
        global_img=None,
    ):
        if global_cond is not None:
            global_feature = torch.cat([global_cond, time_embed], axis=-1)
        else:
            global_feature = time_embed
        if global_img is not None:
            global_img = global_img.flatten(2, 3).permute(0, 2, 1).contiguous()
            global_feature = torch.cat([global_img, global_feature], axis=-1)

        scale_shift = self.scale_shift_mlp(global_feature)
        scale, shift = scale_shift.chunk(2, dim=-1)
        traj_feature = traj_feature * (1 + scale) + shift
        return traj_feature


class CustomTransformerDecoderLayer(nn.Module):
    def __init__(
        self,
        num_poses,
        d_model,
        d_ffn,
        config,
        cost_learning,
        feat_out=False,
    ):
        super().__init__()
        self.dropout = nn.Dropout(0.1)
        self.dropout1 = nn.Dropout(0.1)
        self.feat_out = feat_out
        self.cross_bev_attention = GridSampleCrossBEVAttention(
            config.tf_d_model,
            config.tf_num_head,
            num_points=num_poses,
            config=config,
            in_bev_dims=config.tf_d_model,
        )
        self.cross_agent_attention = nn.MultiheadAttention(
            config.tf_d_model,
            config.tf_num_head,
            dropout=config.tf_dropout,
            batch_first=True,
        )
        self.cross_ego_attention = nn.MultiheadAttention(
            config.tf_d_model,
            config.tf_num_head,
            dropout=config.tf_dropout,
            batch_first=True,
        )
        self.ffn = nn.Sequential(
            nn.Linear(config.tf_d_model, config.tf_d_ffn),
            nn.ReLU(),
            nn.Linear(config.tf_d_ffn, config.tf_d_model),
        )
        self.norm1 = nn.LayerNorm(config.tf_d_model)
        self.norm2 = nn.LayerNorm(config.tf_d_model)
        self.norm3 = nn.LayerNorm(config.tf_d_model)
        self.time_modulation = ModulationLayer(config.tf_d_model, config.tf_d_model)
        if not feat_out:
            self.task_decoder = DiffMotionPlanningRefinementModule(
                embed_dims=config.tf_d_model,
                ego_fut_ts=num_poses,
                ego_fut_mode=20,
                cost_learning=cost_learning,
            )

    def forward(self, traj_feature, noisy_traj_points, bev_feature, bev_spatial_shape,
                agents_query, ego_query, time_embed, status_encoding, global_img=None):
        # 4.1 cross attention with BEV features sampled along the noisy trajectory
        traj_feature = self.cross_bev_attention(traj_feature, noisy_traj_points, bev_feature, bev_spatial_shape)
        # 4.2 cross attention with agent queries
        traj_feature = traj_feature + self.dropout(self.cross_agent_attention(traj_feature, agents_query, agents_query)[0])
        traj_feature = self.norm1(traj_feature)
        # traj_feature = traj_feature + self.dropout(self.self_attn(traj_feature, traj_feature, traj_feature)[0])

        # 4.5 cross attention with ego query
        traj_feature = traj_feature + self.dropout1(self.cross_ego_attention(traj_feature, ego_query, ego_query)[0])
        traj_feature = self.norm2(traj_feature)

        # 4.6 feed-forward network
        traj_feature = self.norm3(self.ffn(traj_feature))
        # 4.8 modulate with time steps
        traj_feature = self.time_modulation(traj_feature, time_embed, global_cond=None, global_img=global_img)
        if self.feat_out:
            return traj_feature
        # 4.9 predict the offset & heading
        poses_reg, poses_cls = self.task_decoder(traj_feature)  # (bs, 20, 8, 3); (bs, 20)
        poses_reg[..., :2] = poses_reg[..., :2] + noisy_traj_points
        poses_reg[..., StateSE2Index.HEADING] = poses_reg[..., StateSE2Index.HEADING].tanh() * np.pi
        return poses_reg, poses_cls, traj_feature
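
# Illustrative sketch (hypothetical helper): the FiLM-style conditioning used by
# `ModulationLayer` above -- a condition vector produces a per-channel
# (scale, shift) pair and features are modulated as x * (1 + scale) + shift,
# broadcasting the condition over all trajectory modes.
def _demo_film_modulation():
    embed_dims = 256  # assumed
    x = torch.randn(2, 20, embed_dims)    # (bs, modes, channels)
    cond = torch.randn(2, 1, embed_dims)  # e.g. a time embedding
    mlp = nn.Sequential(nn.Mish(), nn.Linear(embed_dims, embed_dims * 2))
    scale, shift = mlp(cond).chunk(2, dim=-1)
    return x * (1 + scale) + shift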
def _get_clones(module, N):
    # FIXME: copy.deepcopy() is not defined on nn.Module
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


class CustomTransformerDecoder(nn.Module):
    def __init__(
        self,
        decoder_layer,
        num_layers,
        norm=None,
    ):
        super().__init__()
        torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers

    def forward(self, traj_feature, noisy_traj_points, bev_feature, bev_spatial_shape,
                agents_query, ego_query, time_embed, status_encoding, global_img=None):
        poses_reg_list = []
        poses_cls_list = []
        cost_dict_list = []
        traj_points = noisy_traj_points
        for mod in self.layers:
            poses_reg, poses_cls, cost_dict = mod(
                traj_feature, traj_points, bev_feature, bev_spatial_shape,
                agents_query, ego_query, time_embed, status_encoding, global_img
            )
            poses_reg_list.append(poses_reg)
            poses_cls_list.append(poses_cls)
            cost_dict_list.append(cost_dict)
            traj_points = poses_reg[..., :2].clone().detach()
        return poses_reg_list, poses_cls_list, cost_dict_list


from navsim.evaluate.pdm_score import batched_pdm_score, upsample_traj_nd_torch, wrap_to_pi
from hydra.utils import instantiate
from navsim.planning.simulation.planner.pdm_planner.simulation.pdm_simulator import PDMSimulator
from navsim.planning.simulation.planner.pdm_planner.scoring.pdm_scorer import PDMScorer
from navsim.common.dataclasses import Trajectory
from omegaconf import OmegaConf
from nuplan.planning.utils.multithreading.worker_pool import Task, WorkerPool
from nuplan.planning.utils.multithreading.worker_ray import RayDistributed
import os
import warnings

warnings.filterwarnings("ignore", category=RuntimeWarning)
import ray
from time import time
from copy import deepcopy


@ray.remote(num_cpus=1)
class PDMWorker:
    def __init__(self, cfg_path):
        import os
        os.environ["OMP_NUM_THREADS"] = "1"
        os.environ["MKL_NUM_THREADS"] = "1"
        os.environ["OPENBLAS_NUM_THREADS"] = "1"
        import sys
        sys.path.append('/cpfs04/user/liuhaochen/AlgEngine_nuplan/navsim')  # manually append in the remote process
        from navsim.evaluate.pdm_score import batched_pdm_score as pdm_score_func
        self.batched_pdm_score = pdm_score_func
        self.cfg = OmegaConf.load(cfg_path)
        self.simulator = instantiate(self.cfg.simulator)
        self.scorer = instantiate(self.cfg.scorer)

    def compute(self, metric_cache, traj, token):
        results = []
        for i in range(len(metric_cache)):
            pdm_result = self.batched_pdm_score(
                metric_cache[i],
                traj[i],
                self.simulator.proposal_sampling,
                self.simulator,
                self.scorer,
            )
            pdm_result['token'] = token[i]
            self.simulator.clear()
            self.scorer.clear()
            results.append(pdm_result)
        return results


from ray.util.actor_pool import ActorPool
from navsim.planning.utils.multithreading.worker_ray_no_torch import RayDistributedNoTorch
from nuplan.planning.utils.multithreading.worker_utils import worker_map


def init_ray_actor_pool(cfg_path, num_actors=7):
    if not ray.is_initialized():
        ray.init(
            num_cpus=num_actors,
            _temp_dir=f"/tmp/ray_rank_{os.environ.get('RANK', '0')}",
            ignore_reinit_error=True,
            log_to_driver=False,
        )
        print(ray.available_resources())
    actors = [PDMWorker.remote(cfg_path) for _ in range(num_actors)]
    return ActorPool(actors)


import sys
import lzma
import pickle
from pathlib import Path
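
# Illustrative sketch (hypothetical helper; assumes a valid `cfg_path` and
# pre-chunked inputs): how the actor pool above could be consumed with
# `ActorPool.map_unordered`, fanning PDM scoring out across the PDMWorker actors.
def _demo_actor_pool_usage(cfg_path, metric_cache_chunks, traj_chunks, token_chunks):
    pool = init_ray_actor_pool(cfg_path, num_actors=2)
    results = pool.map_unordered(
        lambda actor, args: actor.compute.remote(*args),
        list(zip(metric_cache_chunks, traj_chunks, token_chunks)),
    )
    # each worker returns a list of per-sample result dicts; flatten them
    return [r for chunk in results for r in chunk]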
def ray_worker_pdm_func(args):
    metric_cache = [x for a in args for x in a["metric_cache"]]
    traj_tensor = [x for a in args for x in a["traj_tensor"]]
    token = [x for a in args for x in a["tokens"]]
    cfg = args[0]['cfg'][0]
    cfg = OmegaConf.load(cfg)
    results = []
    simulator = instantiate(cfg.simulator)
    scorer = instantiate(cfg.scorer)
    for i in range(len(metric_cache)):
        with lzma.open(metric_cache[i], "rb") as f:
            metric_cache_data = pickle.load(f)
        pdm_result = batched_pdm_score(
            metric_cache=metric_cache_data,
            model_trajectory=traj_tensor[i],
            future_sampling=simulator.proposal_sampling,
            simulator=simulator,
            scorer=scorer,
        )
        pdm_result['token'] = token[i]
        simulator.clear()
        scorer.clear()
        results.append(pdm_result)
    return results


from mmdet3d_plugin.uniad.custom_modules.peft import (
    LoRALinear,
    ZeroAdapter,
    LoRACLAdapter,
    MOELoRALinear,
    LoRAMoECLAdapter,
    BayesianLinear,
    finetuning_detach,
    frozen_grad,
    peft_wrapper_forward,
    lora_wrapper,
    retreive_bayesian_lora_param,
)


class TrajectoryHead(nn.Module):
    """Trajectory prediction head."""

    def __init__(self, num_poses: int, d_ffn: int, d_model: int, plan_anchor_path: str,
                 config: TransfuserConfig, rejection_sampling=1, cost_learning=False,
                 pdm_config_path='', online_pdm_infer=False, train_sampling=1, rlft=False):
        """
        Initializes trajectory head.
        :param num_poses: number of (x,y,θ) poses to predict
        :param d_ffn: dimensionality of feed-forward network
        :param d_model: input dimensionality
        """
        super(TrajectoryHead, self).__init__()

        self._num_poses = num_poses
        self._d_model = d_model
        self._d_ffn = d_ffn
        self.diff_loss_weight = 2.0
        self.ego_fut_mode = 20
        self.rejection_sampling = rejection_sampling
        self.train_sampling = train_sampling
        self.rlft = rlft

        self.diffusion_scheduler = DDIMScheduler(
            num_train_timesteps=1000,
            beta_schedule="scaled_linear",
            prediction_type="sample",
        )

        plan_anchor = np.load(plan_anchor_path)
        self.plan_anchor = nn.Parameter(
            torch.tensor(plan_anchor[:, 4::5, :2], dtype=torch.float32),
            requires_grad=False,
        )  # (20, 8, 2)
        self.plan_anchor_encoder = nn.Sequential(
            *linear_relu_ln(d_model, 1, 1, 512),
            nn.Linear(d_model, d_model),
        )
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(d_model),
            nn.Linear(d_model, d_model * 4),
            nn.Mish(),
            nn.Linear(d_model * 4, d_model),
        )

        diff_decoder_layer = CustomTransformerDecoderLayer(
            num_poses=num_poses,
            d_model=d_model,
            d_ffn=d_ffn,
            config=config,
            cost_learning=cost_learning,
        )
        self.diff_decoder = CustomTransformerDecoder(diff_decoder_layer, 2)
        if self.rlft:
            self.cost_diff_decoder = nn.ModuleList([
                CustomTransformerDecoderLayer(
                    num_poses=num_poses,
                    d_model=d_model,
                    d_ffn=d_ffn,
                    config=config,
                    cost_learning=cost_learning,
                    feat_out=True,
                )
                for _ in range(2)
            ])
        self.loss_computer = LossComputer(config, self.train_sampling)
        self.cost_learning = cost_learning
        self.online_pdm_infer = online_pdm_infer
        if self.cost_learning:
            self.heads = nn.ModuleDict({
                'noc': nn.Sequential(nn.Linear(d_model, d_ffn), nn.ReLU(), nn.Linear(d_ffn, 1)),
                'da': nn.Sequential(nn.Linear(d_model, d_ffn), nn.ReLU(), nn.Linear(d_ffn, 1)),
                'ttc': nn.Sequential(nn.Linear(d_model, d_ffn), nn.ReLU(), nn.Linear(d_ffn, 1)),
                'comfort': nn.Sequential(nn.Linear(d_model, d_ffn), nn.ReLU(), nn.Linear(d_ffn, 1)),
                'progress': nn.Sequential(nn.Linear(d_model, d_ffn), nn.ReLU(), nn.Linear(d_ffn, 1)),
                'pdms': nn.Sequential(nn.Linear(d_model, d_ffn), nn.ReLU(), nn.Linear(d_ffn, 1)),
            })
            if self.rlft:
                self.heads_lora = lora_wrapper(self.heads, LoRALinear, rank=16, alpha=1.0, dropout=0.1, num_task=4)
                finetuning_detach(self.heads)
        cfg_path = pdm_config_path
        print('online safe RL scoring, load from:', cfg_path)
        cfg = cfg_path
        self.cfg = cfg
        self.chunk_size = 4
        self.worker = RayDistributedNoTorch(threads_per_node=16)
        self.init_metric_cache_paths()

    def init_metric_cache_paths(self):
        """
        Helper function to load all metric-cache file paths from the metadata CSVs
        and build a token -> file-path dictionary.
        """
        metadata_file = '/xxx/navsim_metric_cache_navtrain2_metadata_node_0.csv'
        with open(str(metadata_file), "r") as f:
            cache_paths = f.read().splitlines()[1:]
        self.metric_cache_dict = {cache_path.split("/")[-2]: cache_path for cache_path in cache_paths}

        # video file:
        vidmetadata_file = '/xxx/navsim_metric_cache_navtrain2/metadata/navsim_metric_cache_navtrain2_metadata_node_0.csv'
        with open(str(vidmetadata_file), "r") as f:
            vid_cache_paths = f.read().splitlines()[1:]
        self.metric_cache_dict.update({cache_path.split("/")[-2]: cache_path for cache_path in vid_cache_paths})
        print(len(list(self.metric_cache_dict.keys())))

    def norm_odo(self, odo_info_fut):
        odo_info_fut_x = odo_info_fut[..., 0:1]
        odo_info_fut_y = odo_info_fut[..., 1:2]
        odo_info_fut_head = odo_info_fut[..., 2:3]

        odo_info_fut_x = 2 * (odo_info_fut_x + 1.2) / 56.9 - 1
        odo_info_fut_y = 2 * (odo_info_fut_y + 20) / 46 - 1
        odo_info_fut_head = 2 * (odo_info_fut_head + 2) / 3.9 - 1
        return torch.cat([odo_info_fut_x, odo_info_fut_y, odo_info_fut_head], dim=-1)

    def denorm_odo(self, odo_info_fut):
        odo_info_fut_x = odo_info_fut[..., 0:1]
        odo_info_fut_y = odo_info_fut[..., 1:2]
        odo_info_fut_head = odo_info_fut[..., 2:3]

        odo_info_fut_x = (odo_info_fut_x + 1) / 2 * 56.9 - 1.2
        odo_info_fut_y = (odo_info_fut_y + 1) / 2 * 46 - 20
        odo_info_fut_head = (odo_info_fut_head + 1) / 2 * 3.9 - 2
        return torch.cat([odo_info_fut_x, odo_info_fut_y, odo_info_fut_head], dim=-1)

    def forward(self, ego_query, agents_query, bev_feature, bev_spatial_shape, status_encoding,
                targets=None, global_img=None, metric_cache=None, tokens=None) -> Dict[str, torch.Tensor]:
        """Torch module forward pass."""
        if self.training:
            return self.forward_train(ego_query, agents_query, bev_feature, bev_spatial_shape,
                                      status_encoding, targets, global_img, metric_cache, tokens)
        else:
            with torch.no_grad():
                return self.forward_test(ego_query, agents_query, bev_feature, bev_spatial_shape,
                                         status_encoding, global_img, tokens)

    def online_cost_loss(self, cost_dict, gt_pdm_score, idx=0):
        loss = {}
        loss[f'loss.noc_{idx}'] = 3 * torch.mean(
            F.binary_cross_entropy_with_logits(cost_dict['noc'], gt_pdm_score['no_at_fault_collisions']))
        loss[f'loss.da_{idx}'] = 3 * torch.mean(
            F.binary_cross_entropy_with_logits(cost_dict['da'], gt_pdm_score['drivable_area_compliance']))
        loss[f'loss.ttc_{idx}'] = 2 * torch.mean(
            F.binary_cross_entropy_with_logits(cost_dict['ttc'], gt_pdm_score['time_to_collision_within_bound']))
        loss[f'loss.ep_{idx}'] = torch.mean(
            F.binary_cross_entropy_with_logits(cost_dict['progress'], gt_pdm_score['ego_progress']))
        loss[f'loss.comfort_{idx}'] = torch.mean(
            F.binary_cross_entropy_with_logits(cost_dict['comfort'], gt_pdm_score['comfort']))

        score = (5 * cost_dict['ttc'].sigmoid() + 5 * cost_dict['progress'].sigmoid()
                 + 2 * cost_dict['comfort'].sigmoid()) / 12
        score = cost_dict['noc'].sigmoid() * cost_dict['da'].sigmoid() * score
        loss[f'loss.RL_score_{idx}'] = 5 * torch.mean(
            F.binary_cross_entropy(score, gt_pdm_score['score']))
        loss[f'loss.score_{idx}'] = 5 * torch.mean(
            F.binary_cross_entropy_with_logits(cost_dict['pdms'], gt_pdm_score['score']))
        return loss
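
    # Illustrative sketch (hypothetical helper, not part of the original pipeline):
    # the PDM-style score aggregation used in `online_cost_loss` above --
    # multiplicative gates (no-at-fault-collision, drivable-area compliance)
    # applied to a weighted mean of soft metrics with weights 5 (TTC),
    # 5 (ego progress), 2 (comfort). Inputs are assumed to be probabilities in [0, 1].
    @staticmethod
    def _demo_pdm_score(noc, da, ttc, progress, comfort):
        weighted = (5 * ttc + 5 * progress + 2 * comfort) / 12
        return noc * da * weighted
    # e.g. _demo_pdm_score(1.0, 1.0, 0.9, 0.8, 1.0) == (4.5 + 4.0 + 2.0) / 12 = 0.875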
    def online_batch_pdm_calculation(self, traj_tensor, metric_cache, traj_len, tokens):
        b, n, t, d = traj_tensor.shape
        device = traj_tensor.device
        # batched interpolation: prepend the current pose and upsample from 0.5s to 0.1s steps
        traj_tensor = traj_tensor.view(b * n, t, d)
        zeros = torch.zeros((b * n, 1, d)).to(device)
        full_traj_tensor = torch.cat([zeros, traj_tensor], dim=1)
        interped_traj_tensor = upsample_traj_nd_torch(full_traj_tensor, dt_in=0.5, dt_out=0.1)
        interped_traj_tensor = interped_traj_tensor.reshape(b, n, -1, d)
        interped_traj_tensor[..., -1] = wrap_to_pi(interped_traj_tensor[..., -1])
        traj_tensor = interped_traj_tensor.detach().cpu().numpy()

        batched_pdm_results = dict(
            no_at_fault_collisions=[],
            drivable_area_compliance=[],
            ego_progress=[],
            time_to_collision_within_bound=[],
            comfort=[],
            driving_direction_compliance=[],
            score=[],
        )
        mc_chunk, traj_chunk, tok_chunk = [], [], []
        worker_args = []
        for i in range(0, len(traj_tensor), self.chunk_size):
            traj_chunk.append(traj_tensor[i:i + self.chunk_size])
            tok_chunk.append(tokens[i:i + self.chunk_size])
            if self.training:
                worker_args.append(dict(
                    metric_cache=[self.metric_cache_dict[t] for t in tokens[i:i + self.chunk_size]],
                    traj_tensor=traj_tensor[i:i + self.chunk_size],
                    tokens=tokens[i:i + self.chunk_size],
                    cfg=[self.cfg] * self.chunk_size,
                ))
            else:
                buf_cache_token = []
                for t in tokens[i:i + self.chunk_size]:
                    if t not in self.metric_cache_dict.keys():
                        # fall back to an arbitrary cache entry for unknown tokens
                        buf_cache_token.append(self.metric_cache_dict[list(self.metric_cache_dict.keys())[0]])
                    else:
                        buf_cache_token.append(self.metric_cache_dict[t])
                worker_args.append(dict(
                    metric_cache=buf_cache_token,
                    traj_tensor=traj_tensor[i:i + self.chunk_size],
                    tokens=tokens[i:i + self.chunk_size],
                    cfg=[self.cfg] * self.chunk_size,
                ))

        args = zip(mc_chunk, traj_chunk, tok_chunk)
        # You can use map with a lambda to call .compute on each actor
        pdm_results = worker_map(self.worker, ray_worker_pdm_func, worker_args)
        pdm_results_dict = {pdm['token']: pdm for pdm in pdm_results}
        for i in range(b):
            pdm_res = pdm_results_dict[tokens[i]]
            for k in batched_pdm_results.keys():
                batched_pdm_results[k].append(pdm_res[k])

        ret_batch_pdm_results_list = [dict() for _ in range(traj_len)]
        for k, v in batched_pdm_results.items():
            buf_value = torch.from_numpy(np.stack(v, axis=0)).to(device)
            if self.train_sampling > 1:
                buf_value = buf_value.view(b, self.train_sampling, n // self.train_sampling)
                buf_value = buf_value.view(b * self.train_sampling, -1)
            buf_value = buf_value.reshape(b * self.train_sampling, traj_len, -1)
            for i in range(traj_len):
                ret_batch_pdm_results_list[i][k] = buf_value[:, i]
        return ret_batch_pdm_results_list

    def expand_and_reshape(self, x, k=4):
        B, rest = x.shape[0], x.shape[1:]  # arbitrary trailing dims
        x = x.unsqueeze(1).repeat(1, k, *([1] * len(rest)))  # [B, k, ...]
        # fold the repeats back into the batch dimension
        return x.view(B * k, *rest)  # [B*k, ...]
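
    # Illustrative sketch (hypothetical helper): shape semantics of
    # `expand_and_reshape` above -- each batch element is repeated k times and the
    # repeats are folded back into the batch, so downstream modules see an
    # effective batch of B*k copies of the same conditioning.
    @staticmethod
    def _demo_expand_and_reshape():
        x = torch.arange(2).float().view(2, 1)         # B=2, one feature dim
        y = x.unsqueeze(1).repeat(1, 3, 1).view(6, 1)  # k=3
        assert y.flatten().tolist() == [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]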
    def forward_train(self, ego_query, agents_query, bev_feature, bev_spatial_shape, status_encoding,
                      targets=None, global_img=None, metric_cache=None, tokens=None) -> Dict[str, torch.Tensor]:
        bs = ego_query.shape[0]
        device = ego_query.device

        # 1. add truncated noise to the plan anchor
        plan_anchor = self.plan_anchor.unsqueeze(0).repeat(bs, 1, 1, 1)
        odo_info_fut = self.norm_odo(plan_anchor)
        if self.train_sampling > 1:
            odo_info_fut = self.expand_and_reshape(odo_info_fut, k=self.train_sampling)
            agents_query = self.expand_and_reshape(agents_query, k=self.train_sampling)
            ego_query = self.expand_and_reshape(ego_query, k=self.train_sampling)
            bev_feature = self.expand_and_reshape(bev_feature, k=self.train_sampling)
            status_encoding = self.expand_and_reshape(status_encoding, k=self.train_sampling)
            bs = bs * self.train_sampling
        timesteps = torch.randint(0, 50, (bs,), device=device)
        noise = torch.randn(odo_info_fut.shape, device=device)
        noisy_traj_points = self.diffusion_scheduler.add_noise(
            original_samples=odo_info_fut,
            noise=noise,
            timesteps=timesteps,
        ).float()
        noisy_traj_points = torch.clamp(noisy_traj_points, min=-1, max=1)
        noisy_traj_points = self.denorm_odo(noisy_traj_points)

        ego_fut_mode = noisy_traj_points.shape[1]
        # 2. project noisy_traj_points to the query
        traj_pos_embed = gen_sineembed_for_position(noisy_traj_points, hidden_dim=64)
        traj_pos_embed = traj_pos_embed.flatten(-2)
        traj_feature = self.plan_anchor_encoder(traj_pos_embed)
        traj_feature = traj_feature.view(bs, ego_fut_mode, -1)

        # 3. embed the timesteps
        time_embed = self.time_mlp(timesteps)
        time_embed = time_embed.view(bs, 1, -1)

        # 4. run the stacked decoder
        poses_reg_list, poses_cls_list, traj_feat_list = self.diff_decoder(
            traj_feature, noisy_traj_points, bev_feature, bev_spatial_shape,
            agents_query, ego_query, time_embed, status_encoding, global_img)

        trajectory_loss_dict = {}
        ret_traj_loss = 0
        if self.cost_learning:
            cost_dict_list = []
            for idx, feat in enumerate(traj_feat_list):
                cost_dict = dict()
                if self.rlft:
                    feat = self.cost_diff_decoder[idx](
                        feat.detach(),
                        poses_reg_list[idx][..., :2].detach(),
                        bev_feature.detach(),
                        bev_spatial_shape,
                        agents_query.detach(),
                        ego_query.detach(),
                        time_embed.detach(),
                        status_encoding.detach(),
                    )
                    for k in self.heads.keys():
                        cost_dict[k] = peft_wrapper_forward(feat, self.heads[k], self.heads_lora['lora_' + k])[..., 0]
                else:
                    for k in self.heads.keys():
                        cost_dict[k] = self.heads[k](feat)[..., 0]
                cost_dict_list.append(cost_dict)

            # assert metric_cache is not None
            b, num_mode, ts, d = poses_reg_list[0].shape
            stack_trajs = torch.cat(poses_reg_list, dim=1)
            if self.train_sampling > 1:
                stack_trajs = stack_trajs.view(bs // self.train_sampling, self.train_sampling, -1, ts, d)
                stack_trajs = stack_trajs.view(bs // self.train_sampling, -1, ts, d)
            batched_pdm_score_list = self.online_batch_pdm_calculation(
                stack_trajs, None, len(poses_reg_list), tokens)

        cost_dict_loss = dict()
        for idx, (poses_reg, poses_cls) in enumerate(zip(poses_reg_list, poses_cls_list)):
            trajectory_loss = self.loss_computer(poses_reg, poses_cls, targets, plan_anchor)
            trajectory_loss_dict[f"loss.planning.{idx}"] = trajectory_loss
            ret_traj_loss += trajectory_loss
            if self.cost_learning:
                cost_dict_loss.update(
                    self.online_cost_loss(cost_dict_list[idx], batched_pdm_score_list[idx], idx))

        mode_idx = poses_cls_list[-1].argmax(dim=-1)
        mode_idx = mode_idx[..., None, None, None].repeat(1, 1, self._num_poses, 3)
        best_reg = torch.gather(poses_reg_list[-1], 1, mode_idx).squeeze(1)
        ret_dict = dict(trajectory=best_reg)
        ret_dict.update(trajectory_loss_dict)
        if self.cost_learning:
            return {'cost_loss': cost_dict_loss, "trajectory": best_reg, "trajectory_loss": ret_traj_loss}
        return {"trajectory": best_reg, "trajectory_loss": ret_traj_loss}
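
    # Illustrative sketch (hypothetical helper, standalone diffusers scheduler):
    # the truncated-noise trick used in `forward_train` above -- anchors are only
    # diffused for a small random timestep (< 50 of 1000), so the noisy samples
    # stay close to the anchor prior instead of becoming pure Gaussian noise.
    @staticmethod
    def _demo_truncated_noising():
        scheduler = DDIMScheduler(
            num_train_timesteps=1000,
            beta_schedule="scaled_linear",
            prediction_type="sample",
        )
        anchors = torch.zeros(4, 20, 8, 2)  # normalized anchors in [-1, 1]
        t = torch.randint(0, 50, (4,))      # truncated timesteps
        noisy = scheduler.add_noise(anchors, torch.randn_like(anchors), t)
        return noisy.clamp(-1, 1)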
    def cal_test_rl_score(self, cost_dict):
        score = (5 * cost_dict['ttc'].sigmoid() + 5 * cost_dict['progress'].sigmoid()
                 + 2 * cost_dict['comfort'].sigmoid()) / 12
        score = (cost_dict['noc'].sigmoid()).float() * (cost_dict['da'].sigmoid()).float() * score
        return score

    def forward_test(self, ego_query, agents_query, bev_feature, bev_spatial_shape,
                     status_encoding, global_img, tokens) -> Dict[str, torch.Tensor]:
        step_num = 2
        bs = ego_query.shape[0]
        device = ego_query.device
        step_ratio = 20 / step_num
        roll_timesteps = (np.arange(0, step_num) * step_ratio).round()[::-1].copy().astype(np.int64)
        roll_timesteps = torch.from_numpy(roll_timesteps).to(device)

        # 1. add truncated noise to the plan anchor
        plan_anchor = self.plan_anchor.unsqueeze(0).repeat(bs, 1, 1, 1)
        org_img = self.norm_odo(plan_anchor)
        trajs, scores, costs = [], [], []
        for _ in range(100):
            self.diffusion_scheduler.set_timesteps(1000, device)
            timesteps = torch.randint(0, 5, (bs,), device=device)
            noise = torch.randn(org_img.shape, device=device)
            trunc_timesteps = torch.ones((bs,), device=device, dtype=torch.long) * 8
            img = self.diffusion_scheduler.add_noise(original_samples=org_img, noise=noise, timesteps=timesteps)
            # noisy_trajs = self.denorm_odo(img)
            ego_fut_mode = img.shape[1]
            # for k in roll_timesteps[:]:
            if True:
                x_boxes = torch.clamp(img, min=-1, max=1)
                noisy_traj_points = self.denorm_odo(x_boxes)

                # 2. project noisy_traj_points to the query
                traj_pos_embed = gen_sineembed_for_position(noisy_traj_points, hidden_dim=64)
                traj_pos_embed = traj_pos_embed.flatten(-2)
                traj_feature = self.plan_anchor_encoder(traj_pos_embed)
                traj_feature = traj_feature.view(bs, ego_fut_mode, -1)

                # timesteps = k
                if not torch.is_tensor(timesteps):
                    # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
                    timesteps = torch.tensor([timesteps], dtype=torch.long, device=img.device)
                elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
                    timesteps = timesteps[None].to(img.device)

                # 3. embed the timesteps
                # timesteps = timesteps.expand(img.shape[0])
                time_embed = self.time_mlp(timesteps)
                time_embed = time_embed.view(bs, 1, -1)

                # 4. run the stacked decoder
                poses_reg_list, poses_cls_list, traj_feat_list = self.diff_decoder(
                    traj_feature, noisy_traj_points, bev_feature, bev_spatial_shape,
                    agents_query, ego_query, time_embed, status_encoding, global_img)
                if self.cost_learning:
                    cost_list = []
                    for idx, feat in enumerate(traj_feat_list):
                        cost_dict = dict()
                        if self.rlft:
                            feat = self.cost_diff_decoder[idx](
                                feat.detach(),
                                poses_reg_list[idx][..., :2].detach(),
                                bev_feature.detach(),
                                bev_spatial_shape,
                                agents_query.detach(),
                                ego_query.detach(),
                                time_embed.detach(),
                                status_encoding.detach(),
                            )
                            for h in self.heads.keys():
                                cost_dict[h] = peft_wrapper_forward(feat, self.heads[h], self.heads_lora['lora_' + h])[..., 0]
                        else:
                            for h in self.heads.keys():
                                cost_dict[h] = self.heads[h](feat)[..., 0]
                        cost_list.append(cost_dict)
                else:
                    cost_list = [0] * len(poses_cls_list)

                poses_reg = poses_reg_list[-1]
                poses_cls = poses_cls_list[-1]
                poses_cost = cost_list[-1]
                x_start = poses_reg[..., :2]
                x_start = self.norm_odo(x_start)

            human_score = poses_cls.argmax(1)
            trajs.append(poses_reg[torch.arange(bs)[:, None], human_score])
            scores.append(poses_cls.softmax(-1)[torch.arange(bs)[:, None], human_score])
            costs.append({k: v[torch.arange(bs)[:, None], human_score] for k, v in poses_cost.items()})

        trajs = torch.cat(trajs, dim=1)
        scores = torch.cat(scores, dim=1)
        mode_idx = torch.argmax(scores, dim=1)
        if self.cost_learning:
            costs = [(self.cal_test_rl_score(cost) + cost['pdms'].sigmoid()) / 2 for cost in costs]
            costs = torch.cat(costs, dim=1)
            # costs[idx] = 0
            mode_idx = torch.argmax(costs, dim=1)
        if self.online_pdm_infer:
            b, m, t, d = trajs.shape
            batched_pdm_score_list = self.online_batch_pdm_calculation(trajs, None, 1, tokens)
            mode_idx = torch.argmax(batched_pdm_score_list[0]['score'], dim=1)
        b = trajs.shape[0]
        best_reg = trajs[torch.arange(b), mode_idx]

        return {"trajectory": best_reg}
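

# Illustrative usage sketch (hypothetical helper; tensor shapes are assumptions --
# real inputs come from the navsim feature builders and depend on the config):
# in eval mode the model returns a dict with the selected trajectory plus
# zero-filled PDM placeholder entries.
def _demo_inference(model: "V2TransfuserModel"):
    model.eval()
    camera = torch.randn(1, 3, 256, 1024)  # assumed camera tensor layout
    lidar = torch.randn(1, 1, 256, 256)    # assumed BEV lidar raster
    status = torch.randn(1, 8)             # 4 + 2 + 2 status features
    with torch.no_grad():
        out = model(camera, lidar, status, token=["demo_token"])
    return out["trajectory"]  # (1, num_poses, 3)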