# unknownuser6666's picture
# Upload folder using huggingface_hub
# 663494c verified
# ---------------------------------------------------------------------------------#
# UniAD: Planning-oriented Autonomous Driving (https://arxiv.org/abs/2212.10156) #
# Source code: https://github.com/OpenDriveLab/UniAD #
# Copyright (c) OpenDriveLab. All rights reserved. #
# ---------------------------------------------------------------------------------#
import sys
from mmdet.models import DETECTORS
from mmdet3d.models.detectors.mvx_two_stage import MVXTwoStageDetector,Base3DDetector
from navsim.agents.diffusiondrive.transfuser_loss import transfuser_loss
from typing import Dict
import numpy as np
import torch
import torch.nn as nn
import copy
from navsim.agents.diffusiondrive.transfuser_config import TransfuserConfig
from navsim.agents.diffusiondrive.transfuser_backbone import TransfuserBackbone
from navsim.agents.diffusiondrive.transfuser_features import BoundingBox2DIndex
from navsim.agents.diffusiondrive.modules.blocks import linear_relu_ln,bias_init_with_prob, gen_sineembed_for_position, GridSampleCrossBEVAttention
from navsim.agents.diffusiondrive.modules.multimodal_loss import LossComputer
from navsim.common.enums import StateSE2Index
from diffusers.schedulers import DDIMScheduler
from navsim.agents.diffusiondrive.modules.conditional_unet1d import ConditionalUnet1D,SinusoidalPosEmb
import torch.nn.functional as F
from typing import Any, List, Dict, Optional, Union
@DETECTORS.register_module()
class V2TransfuserModel(Base3DDetector):
    """TransFuser perception stack with a diffusion trajectory head, registered as an mmdet3d detector.

    In training mode ``forward`` returns a loss dict (or, with ``rlft``, the cost-learning
    loss dict produced by the trajectory head). In eval mode it returns the trajectory-head
    output plus all-zero placeholder PDM metric entries.
    """

    def __init__(
            self,
            rejection_sampling=1,
            train_sampling=1,
            online_cost_learning=False,
            pdm_config_path='',
            image_architecture=None,
            bkb_path=None,
            online_pdm_infer=False,
            rlft=False,
            *args,
            **kwargs):
        """
        Initializes TransFuser torch module.

        :param rejection_sampling: stored and forwarded to TrajectoryHead.
        :param train_sampling: number of noisy copies per sample drawn during training (TrajectoryHead).
        :param online_cost_learning: enable cost heads trained against online PDM scores.
        :param pdm_config_path: OmegaConf file with the PDM simulator/scorer config.
        :param image_architecture: optional override of ``TransfuserConfig.image_architecture``.
        :param bkb_path: written to ``TransfuserConfig.bkb_path``.
        :param online_pdm_infer: at test time, pick the trajectory by online PDM score.
        :param rlft: RL fine-tuning mode; training forward returns only the cost-learning losses.
            NOTE(review): the trajectory head only produces those losses when
            ``online_cost_learning`` is also enabled.
        """
        super().__init__()
        config = TransfuserConfig()
        if image_architecture is not None:
            config.image_architecture = image_architecture
        config.bkb_path = bkb_path
        self.rejection_sampling = rejection_sampling
        self.online_cost_learning = online_cost_learning
        self.online_pdm_infer = online_pdm_infer
        self.rlft = rlft
        # One trajectory (ego) query followed by one query per agent.
        self._query_splits = [
            1,
            config.num_bounding_boxes,
        ]
        self._config = config
        self._backbone = TransfuserBackbone(config)
        # 8x8 BEV grid tokens + 1 ego-status token (see the concatenation in forward).
        self._keyval_embedding = nn.Embedding(8**2 + 1, config.tf_d_model)
        self._query_embedding = nn.Embedding(sum(self._query_splits), config.tf_d_model)
        # usually, the BEV features are variable in size.
        self._bev_downscale = nn.Conv2d(512, config.tf_d_model, kernel_size=1)
        self._status_encoding = nn.Linear(4 + 2 + 2, config.tf_d_model)
        self._bev_semantic_head = nn.Sequential(
            nn.Conv2d(
                config.bev_features_channels,
                config.bev_features_channels,
                kernel_size=(3, 3),
                stride=1,
                padding=(1, 1),
                bias=True,
            ),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                config.bev_features_channels,
                config.num_bev_classes,
                kernel_size=(1, 1),
                stride=1,
                padding=0,
                bias=True,
            ),
            nn.Upsample(
                size=(config.lidar_resolution_height // 2, config.lidar_resolution_width),
                mode="bilinear",
                align_corners=False,
            ),
        )
        tf_decoder_layer = nn.TransformerDecoderLayer(
            d_model=config.tf_d_model,
            nhead=config.tf_num_head,
            dim_feedforward=config.tf_d_ffn,
            dropout=config.tf_dropout,
            batch_first=True,
        )
        self._tf_decoder = nn.TransformerDecoder(tf_decoder_layer, config.tf_num_layers)
        self._agent_head = AgentHead(
            num_agents=config.num_bounding_boxes,
            d_ffn=config.tf_d_ffn,
            d_model=config.tf_d_model,
        )
        self._trajectory_head = TrajectoryHead(
            num_poses=config.trajectory_sampling.num_poses,
            d_ffn=config.tf_d_ffn,
            d_model=config.tf_d_model,
            plan_anchor_path=config.plan_anchor_path,
            config=config,
            rejection_sampling=rejection_sampling,
            cost_learning=online_cost_learning,
            pdm_config_path=pdm_config_path,
            online_pdm_infer=online_pdm_infer,
            train_sampling=train_sampling,
            rlft=rlft
        )
        # Projects [upsampled key/val tokens ; upscaled BEV] (tf_d_model + 64 channels) back to tf_d_model.
        self.bev_proj = nn.Sequential(
            *linear_relu_ln(config.tf_d_model, 1, 1, config.tf_d_model + 64),
        )

    def aug_test(self, *args, **kwargs):
        raise NotImplementedError("aug_test not implemented")

    def extract_feat(self, *args, **kwargs):
        raise NotImplementedError("extract_feat not implemented")

    def simple_test(self, *args, **kwargs):
        raise NotImplementedError("simple_test not implemented")

    def forward(self,
                camera_feature,
                lidar_feature,
                status_feature,
                **kwargs):
        """Torch module forward pass.

        :param camera_feature: camera input for the TransFuser backbone.
        :param lidar_feature: lidar input for the TransFuser backbone.
        :param status_feature: ego status vector, last dim 8 (input of ``_status_encoding``).
        :param kwargs: ``token`` (optional; only needed for online PDM scoring); in training also
            ``trajectory``, ``agent_states``, ``agent_labels``, ``bev_semantic_map`` and
            optionally ``metric_cache``.
        :return: training -> loss dict; eval -> trajectory-head output plus zeroed PDM metric keys.
        """
        metric_cache = None
        targets = None
        # Tokens are only consumed by online PDM scoring; tolerate their absence.
        tokens = kwargs.get('token')
        if self.training:
            targets = {
                key: kwargs[key]
                for key in ('trajectory', 'agent_states', 'agent_labels', 'bev_semantic_map')
            }
            metric_cache = kwargs.get('metric_cache')
        if self.rlft:
            # RL fine-tuning keeps the perception backbone in eval mode even when training.
            self._backbone.eval()
        batch_size = status_feature.shape[0]
        bev_feature_upscale, bev_feature, _ = self._backbone(camera_feature, lidar_feature)
        bev_spatial_shape = bev_feature_upscale.shape[2:]
        concat_cross_bev_shape = bev_feature.shape[2:]

        # Key/value tokens: flattened, channel-reduced BEV grid + one ego-status token.
        bev_feature = self._bev_downscale(bev_feature).flatten(-2, -1).permute(0, 2, 1)
        status_encoding = self._status_encoding(status_feature)
        keyval = torch.cat([bev_feature, status_encoding[:, None]], dim=1)
        keyval += self._keyval_embedding.weight[None, ...]

        # Reshape the BEV key/val tokens back to a grid, upsample to the upscaled BEV resolution,
        # concatenate with the upscaled BEV features and project to tf_d_model channels.
        concat_cross_bev = keyval[:, :-1].permute(0, 2, 1).contiguous().view(
            batch_size, -1, concat_cross_bev_shape[0], concat_cross_bev_shape[1])
        concat_cross_bev = F.interpolate(concat_cross_bev, size=bev_spatial_shape, mode='bilinear',
                                         align_corners=False)
        cross_bev_feature = torch.cat([concat_cross_bev, bev_feature_upscale], dim=1)
        cross_bev_feature = self.bev_proj(cross_bev_feature.flatten(-2, -1).permute(0, 2, 1))
        cross_bev_feature = cross_bev_feature.permute(0, 2, 1).contiguous().view(
            batch_size, -1, bev_spatial_shape[0], bev_spatial_shape[1])

        query = self._query_embedding.weight[None, ...].repeat(batch_size, 1, 1)
        query_out = self._tf_decoder(query, keyval)
        trajectory_query, agents_query = query_out.split(self._query_splits, dim=1)
        trajectory = self._trajectory_head(trajectory_query, agents_query, cross_bev_feature,
                                           bev_spatial_shape, status_encoding[:, None],
                                           targets=targets, global_img=None, metric_cache=metric_cache,
                                           tokens=tokens)
        if self.training:
            if self.rlft:
                return trajectory['cost_loss']
            # Auxiliary heads are only needed for the supervised losses.
            output = dict(trajectory)
            output.update(self._agent_head(agents_query))
            output["bev_semantic_map"] = self._bev_semantic_head(bev_feature_upscale)
            return transfuser_loss(targets, output, self._config)

        output = dict(trajectory)
        zeros = torch.zeros(trajectory['trajectory'].shape[0], device=trajectory['trajectory'].device)
        # Placeholder ("idle") PDM metrics: one zero per sample for each key.
        for key in ("no_at_fault_collisions", "drivable_area_compliance", "ego_progress",
                    "time_to_collision_within_bound", "comfort", "score"):
            output[key] = zeros
        return output
class AgentHead(nn.Module):
    """Predicts one bounding box state vector and one label logit per agent query."""

    def __init__(
        self,
        num_agents: int,
        d_ffn: int,
        d_model: int,
    ):
        """
        :param num_agents: maximum number of agents to predict
        :param d_ffn: hidden width of the state MLP
        :param d_model: width of the incoming agent queries
        """
        super().__init__()
        self._num_objects = num_agents
        self._d_model = d_model
        self._d_ffn = d_ffn
        box_dims = BoundingBox2DIndex.size()
        self._mlp_states = nn.Sequential(
            nn.Linear(d_model, d_ffn),
            nn.ReLU(),
            nn.Linear(d_ffn, box_dims),
        )
        self._mlp_label = nn.Sequential(nn.Linear(d_model, 1))

    def forward(self, agent_queries) -> Dict[str, torch.Tensor]:
        """Map agent queries to box states (position squashed to +/-32, heading to +/-pi) and label logits."""
        states = self._mlp_states(agent_queries)
        point_idx = BoundingBox2DIndex.POINT
        heading_idx = BoundingBox2DIndex.HEADING
        states[..., point_idx] = torch.tanh(states[..., point_idx]) * 32
        states[..., heading_idx] = torch.tanh(states[..., heading_idx]) * np.pi
        logits = self._mlp_label(agent_queries).squeeze(dim=-1)
        return {"agent_states": states, "agent_labels": logits}
class DiffMotionPlanningRefinementModule(nn.Module):
    """Per-mode classification logit and (x, y, heading) regression head for the diffusion planner."""

    def __init__(
        self,
        embed_dims=256,
        ego_fut_ts=8,
        ego_fut_mode=20,
        if_zeroinit_reg=True,
        cost_learning=False,
    ):
        """
        :param embed_dims: width of the incoming trajectory features.
        :param ego_fut_ts: number of future poses regressed per mode.
        :param ego_fut_mode: number of modes (stored only).
        :param if_zeroinit_reg: NOTE(review): ignored — zero-init of the regression output is
            always disabled below; kept for interface compatibility.
        :param cost_learning: stored only; not used by this module.
        """
        super(DiffMotionPlanningRefinementModule, self).__init__()
        self.embed_dims = embed_dims
        self.ego_fut_ts = ego_fut_ts
        self.ego_fut_mode = ego_fut_mode
        cls_layers = list(linear_relu_ln(embed_dims, 1, 2))
        cls_layers.append(nn.Linear(embed_dims, 1))
        self.plan_cls_branch = nn.Sequential(*cls_layers)
        self.plan_reg_branch = nn.Sequential(
            nn.Linear(embed_dims, embed_dims),
            nn.ReLU(),
            nn.Linear(embed_dims, embed_dims),
            nn.ReLU(),
            nn.Linear(embed_dims, ego_fut_ts * 3),
        )
        self.if_zeroinit_reg = False
        self.cost_learning = cost_learning
        self.init_weight()

    def init_weight(self):
        """Optionally zero the regression output layer and bias the classifier towards probability 0.01."""
        if self.if_zeroinit_reg:
            reg_out = self.plan_reg_branch[-1]
            nn.init.constant_(reg_out.weight, 0)
            nn.init.constant_(reg_out.bias, 0)
        nn.init.constant_(self.plan_cls_branch[-1].bias, bias_init_with_prob(0.01))

    def forward(
        self,
        traj_feature,
    ):
        """
        :param traj_feature: (batch, modes, embed_dims) features.
        :return: (regression of shape (batch, modes, ego_fut_ts, 3), class logits of shape (batch, modes)).
        """
        batch, modes, _ = traj_feature.shape
        logits = self.plan_cls_branch(traj_feature).squeeze(-1)
        regression = self.plan_reg_branch(traj_feature).reshape(batch, modes, self.ego_fut_ts, 3)
        return regression, logits
class ModulationLayer(nn.Module):
    """FiLM-style modulation: ``x * (1 + scale) + shift`` with scale/shift predicted from a condition."""

    def __init__(self, embed_dims: int, condition_dims: int):
        """
        :param embed_dims: feature width of the modulated tensor.
        :param condition_dims: width of the conditioning vector fed to the MLP.
        """
        super().__init__()
        self.if_zeroinit_scale = False
        self.embed_dims = embed_dims
        self.scale_shift_mlp = nn.Sequential(
            nn.Mish(),
            nn.Linear(condition_dims, embed_dims * 2),
        )
        self.init_weight()

    def init_weight(self):
        """Zero the output projection when ``if_zeroinit_scale`` is set (disabled by default)."""
        if not self.if_zeroinit_scale:
            return
        projection = self.scale_shift_mlp[-1]
        nn.init.constant_(projection.weight, 0)
        nn.init.constant_(projection.bias, 0)

    def forward(
        self,
        traj_feature,
        time_embed,
        global_cond=None,
        global_img=None,
    ):
        """
        :param traj_feature: tensor to modulate; last dim is ``embed_dims``.
        :param time_embed: conditioning tensor (last dim concatenated with the optional inputs).
        :param global_cond: optional extra condition, concatenated before ``time_embed``.
        :param global_img: optional 4-D map, flattened over its last two dims, transposed and
            concatenated in front of the condition.
        """
        condition = time_embed if global_cond is None else torch.cat([global_cond, time_embed], dim=-1)
        if global_img is not None:
            img_tokens = global_img.flatten(2, 3).permute(0, 2, 1).contiguous()
            condition = torch.cat([img_tokens, condition], dim=-1)
        scale, shift = self.scale_shift_mlp(condition).chunk(2, dim=-1)
        return traj_feature * (1 + scale) + shift
class CustomTransformerDecoderLayer(nn.Module):
    """One diffusion-planner decoder layer.

    The layer applies, in order:
    - BEV grid-sample attention;
    - agent cross-attention;
    - ego cross-attention;
    - an FFN;
    - time modulation.

    It then either returns the features (``feat_out=True``) or decodes refined
    poses and mode logits.
    """

    def __init__(self,
                 num_poses,
                 d_model,
                 d_ffn,
                 config,
                 cost_learning,
                 feat_out=False
                 ):
        """
        :param num_poses: poses per trajectory (sampling points for the BEV attention).
        :param d_model: unused; widths come from ``config.tf_d_model`` / ``config.tf_d_ffn``.
        :param d_ffn: unused (see above).
        :param config: TransfuserConfig-like object.
        :param cost_learning: forwarded to the refinement module.
        :param feat_out: if True, skip the pose decoder and return features only.
        """
        super().__init__()
        width = config.tf_d_model
        heads = config.tf_num_head
        self.dropout = nn.Dropout(0.1)
        self.dropout1 = nn.Dropout(0.1)
        self.feat_out = feat_out
        self.cross_bev_attention = GridSampleCrossBEVAttention(
            width,
            heads,
            num_points=num_poses,
            config=config,
            in_bev_dims=width,
        )
        self.cross_agent_attention = nn.MultiheadAttention(
            width, heads, dropout=config.tf_dropout, batch_first=True,
        )
        self.cross_ego_attention = nn.MultiheadAttention(
            width, heads, dropout=config.tf_dropout, batch_first=True,
        )
        self.ffn = nn.Sequential(
            nn.Linear(width, config.tf_d_ffn),
            nn.ReLU(),
            nn.Linear(config.tf_d_ffn, width),
        )
        self.norm1 = nn.LayerNorm(width)
        self.norm2 = nn.LayerNorm(width)
        self.norm3 = nn.LayerNorm(width)
        self.time_modulation = ModulationLayer(width, width)
        if not feat_out:
            self.task_decoder = DiffMotionPlanningRefinementModule(
                embed_dims=width,
                ego_fut_ts=num_poses,
                ego_fut_mode=20,
                cost_learning=cost_learning
            )

    def forward(self,
                traj_feature,
                noisy_traj_points,
                bev_feature,
                bev_spatial_shape,
                agents_query,
                ego_query,
                time_embed,
                status_encoding,
                global_img=None):
        """
        :param noisy_traj_points: current (x, y) trajectory points; the regressed xy is an offset on top of them.
        :param status_encoding: accepted for interface compatibility, not used.
        :return: features if ``feat_out``; otherwise (poses, mode logits, features).
        """
        feat = self.cross_bev_attention(traj_feature, noisy_traj_points, bev_feature, bev_spatial_shape)
        agent_ctx, _ = self.cross_agent_attention(feat, agents_query, agents_query)
        feat = self.norm1(feat + self.dropout(agent_ctx))
        ego_ctx, _ = self.cross_ego_attention(feat, ego_query, ego_query)
        feat = self.norm2(feat + self.dropout1(ego_ctx))
        # NOTE: the FFN output is normalised without a residual connection.
        feat = self.norm3(self.ffn(feat))
        feat = self.time_modulation(feat, time_embed, global_cond=None, global_img=global_img)
        if self.feat_out:
            return feat
        poses, logits = self.task_decoder(feat)
        poses[..., :2] = poses[..., :2] + noisy_traj_points
        heading = StateSE2Index.HEADING
        poses[..., heading] = poses[..., heading].tanh() * np.pi
        return poses, logits, feat
def _get_clones(module, N):
# FIXME: copy.deepcopy() is not defined on nn.module
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
class CustomTransformerDecoder(nn.Module):
    """Stack of independently-parameterised decoder layers with iterative point refinement.

    Every layer receives the same input ``traj_feature``. Only the trajectory points are
    updated between layers: each layer's regressed xy, detached, becomes the next layer's points.
    """

    def __init__(
        self,
        decoder_layer,
        num_layers,
        norm=None,
    ):
        """
        :param decoder_layer: layer template; deep-copied ``num_layers`` times.
        :param num_layers: number of layers.
        :param norm: accepted for interface compatibility, not used.
        """
        super().__init__()
        torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
        self.layers = nn.ModuleList([copy.deepcopy(decoder_layer) for _ in range(num_layers)])
        self.num_layers = num_layers

    def forward(self,
                traj_feature,
                noisy_traj_points,
                bev_feature,
                bev_spatial_shape,
                agents_query,
                ego_query,
                time_embed,
                status_encoding,
                global_img=None):
        """Run all layers; return per-layer lists of (poses, mode logits, output features)."""
        regs, logits, feats = [], [], []
        points = noisy_traj_points
        for layer in self.layers:
            reg, cls, feat = layer(traj_feature, points, bev_feature, bev_spatial_shape,
                                   agents_query, ego_query, time_embed, status_encoding, global_img)
            regs.append(reg)
            logits.append(cls)
            feats.append(feat)
            points = reg[..., :2].clone().detach()
        return regs, logits, feats
from navsim.evaluate.pdm_score import batched_pdm_score, upsample_traj_nd_torch, wrap_to_pi
from hydra.utils import instantiate
from navsim.planning.simulation.planner.pdm_planner.simulation.pdm_simulator import PDMSimulator
from navsim.planning.simulation.planner.pdm_planner.scoring.pdm_scorer import PDMScorer
from navsim.common.dataclasses import Trajectory
from omegaconf import OmegaConf
from nuplan.planning.utils.multithreading.worker_pool import Task, WorkerPool
from nuplan.planning.utils.multithreading.worker_ray import RayDistributed
import os
import warnings
warnings.filterwarnings("ignore", category=RuntimeWarning)
import ray
from time import time
from copy import deepcopy
@ray.remote(num_cpus=1)
class PDMWorker:
    """Ray actor (1 CPU) holding a PDM simulator/scorer pair for scoring trajectories.

    NOTE(review): within this file only ``init_ray_actor_pool`` creates these actors; the
    training/inference path in ``TrajectoryHead`` scores via ``ray_worker_pdm_func`` instead.
    """

    def __init__(self, cfg_path):
        """
        :param cfg_path: path to an OmegaConf file whose ``simulator`` and ``scorer`` entries
            are instantiated with hydra's ``instantiate``.
        """
        import os
        # Restrict native math libraries to one thread per actor. These are set before the
        # pdm_score import below, presumably so they take effect when those libraries load.
        os.environ["OMP_NUM_THREADS"] = "1"
        os.environ["MKL_NUM_THREADS"] = "1"
        os.environ["OPENBLAS_NUM_THREADS"] = "1"
        import sys
        # NOTE(review): hard-coded, user-specific path; on other machines the import below
        # only works if navsim is importable some other way.
        sys.path.append('/cpfs04/user/liuhaochen/AlgEngine_nuplan/navsim') # Manually append in the remote process
        from navsim.evaluate.pdm_score import batched_pdm_score as pdm_score_func
        self.batched_pdm_score = pdm_score_func
        self.cfg = OmegaConf.load(cfg_path)
        self.simulator = instantiate(self.cfg.simulator)
        self.scorer = instantiate(self.cfg.scorer)

    def compute(self, metric_cache, traj, token):
        """Score ``traj[i]`` against ``metric_cache[i]`` for every i.

        :param metric_cache: sequence of metric-cache objects, passed as-is to ``batched_pdm_score``.
        :param traj: trajectories, indexed in parallel with ``metric_cache``.
        :param token: scenario tokens, indexed in parallel with ``metric_cache``.
        :return: list of PDM result dicts, each with an added ``'token'`` key.
        """
        results = []
        for i in range(len(metric_cache)):
            pdm_result = self.batched_pdm_score(
                metric_cache[i], traj[i], self.simulator.proposal_sampling, self.simulator, self.scorer
            )
            pdm_result['token'] = token[i]
            # Reset simulator/scorer state before the next scenario.
            self.simulator.clear()
            self.scorer.clear()
            results.append(pdm_result)
        return results
from ray.util.actor_pool import ActorPool
from navsim.planning.utils.multithreading.worker_ray_no_torch import RayDistributedNoTorch
from nuplan.planning.utils.multithreading.worker_utils import worker_map
def init_ray_actor_pool(cfg_path, num_actors=7):
    """Start Ray if necessary and return an ``ActorPool`` of ``num_actors`` PDMWorker actors.

    :param cfg_path: PDM config path handed to every ``PDMWorker``.
    :param num_actors: number of actors; also the CPU count given to ``ray.init``.
    """
    if not ray.is_initialized():
        # Per-rank temp dir keeps Ray sessions of different distributed ranks apart.
        rank = os.environ.get('RANK', '0')
        ray.init(
            num_cpus=num_actors,
            _temp_dir=f"/tmp/ray_rank_{rank}",
            ignore_reinit_error=True,
            log_to_driver=False,
        )
    print(ray.available_resources())
    return ActorPool([PDMWorker.remote(cfg_path) for _ in range(num_actors)])
import sys
import lzma
import pickle
from pathlib import Path
def ray_worker_pdm_func(args):
    """Compute PDM results for a batch of chunk dicts inside a Ray worker.

    :param args: list of chunk dicts as built in ``TrajectoryHead.online_batch_pdm_calculation``.
        Each chunk has the keys:

        - ``metric_cache``: lzma-compressed pickle file paths;
        - ``traj_tensor``: trajectory arrays;
        - ``tokens``: scenario tokens;
        - ``cfg``: a list whose first entry is the OmegaConf config path.

        Entries are flattened across chunks and matched by position.
    :return: list of PDM result dicts (one per trajectory set), each tagged with ``'token'``.
        An empty ``args`` yields an empty list.
    """
    if not args:
        # Nothing to score; avoids IndexError on args[0] below.
        return []
    metric_cache_paths = [path for chunk in args for path in chunk["metric_cache"]]
    trajectories = [traj for chunk in args for traj in chunk["traj_tensor"]]
    tokens = [tok for chunk in args for tok in chunk["tokens"]]
    # All chunks carry the same config path; the first one is used.
    cfg = OmegaConf.load(args[0]['cfg'][0])
    simulator = instantiate(cfg.simulator)
    scorer = instantiate(cfg.scorer)
    results = []
    for i, cache_path in enumerate(metric_cache_paths):
        # NOTE(security): pickle.load can execute arbitrary code; these paths must point to
        # trusted, locally generated metric-cache files.
        with lzma.open(cache_path, "rb") as f:
            metric_cache_data = pickle.load(f)
        pdm_result = batched_pdm_score(
            metric_cache=metric_cache_data,
            model_trajectory=trajectories[i],
            future_sampling=simulator.proposal_sampling,
            simulator=simulator,
            scorer=scorer,
        )
        pdm_result['token'] = tokens[i]
        # Reset simulator/scorer state before the next scenario.
        simulator.clear()
        scorer.clear()
        results.append(pdm_result)
    return results
from mmdet3d_plugin.uniad.custom_modules.peft import (LoRALinear, ZeroAdapter, LoRACLAdapter, MOELoRALinear,
LoRAMoECLAdapter, BayesianLinear,
finetuning_detach, frozen_grad, peft_wrapper_forward, lora_wrapper, retreive_bayesian_lora_param)
class TrajectoryHead(nn.Module):
    """Diffusion trajectory head with optional online PDM cost learning.

    - Anchors loaded from ``plan_anchor_path`` are noised with a DDIM schedule and
      refined by a two-layer ``CustomTransformerDecoder``.
    - With ``cost_learning``, per-mode cost heads are trained against PDM scores computed
      online through Ray workers.
    - ``rlft`` adds a separate cost decoder and LoRA adapters on the cost heads.
    """

    # Metadata CSVs (header line, then one metric-cache file path per line) used to build the
    # token -> cache-file mapping. NOTE(review): the '/xxx/' prefix looks like a placeholder — confirm.
    _METRIC_CACHE_METADATA_FILES = (
        '/xxx/navsim_metric_cache_navtrain2_metadata_node_0.csv',
        '/xxx/navsim_metric_cache_navtrain2/metadata/navsim_metric_cache_navtrain2_metadata_node_0.csv',
    )

    def __init__(self, num_poses: int, d_ffn: int, d_model: int, plan_anchor_path: str, config: TransfuserConfig,
                 rejection_sampling=1, cost_learning=False, pdm_config_path='', online_pdm_infer=False,
                 train_sampling=1, rlft=False):
        """
        Initializes trajectory head.

        :param num_poses: number of (x,y,θ) poses to predict
        :param d_ffn: dimensionality of feed-forward network
        :param d_model: input dimensionality
        :param plan_anchor_path: ``.npy`` file with anchor trajectories; every 5th pose
            starting at index 4 and only (x, y) are kept.
        :param config: TransfuserConfig.
        :param rejection_sampling: stored only.
        :param cost_learning: build PDM cost heads and score trajectories online during training.
        :param pdm_config_path: OmegaConf path handed to the Ray PDM workers.
        :param online_pdm_infer: at test time select the trajectory with the best online PDM score.
        :param train_sampling: noisy copies per sample during training.
        :param rlft: RL fine-tuning: separate cost decoder + LoRA-adapted cost heads.
        """
        super(TrajectoryHead, self).__init__()
        self._num_poses = num_poses
        self._d_model = d_model
        self._d_ffn = d_ffn
        self.diff_loss_weight = 2.0
        self.ego_fut_mode = 20
        self.rejection_sampling = rejection_sampling
        self.train_sampling = train_sampling
        self.rlft = rlft
        self.diffusion_scheduler = DDIMScheduler(
            num_train_timesteps=1000,
            beta_schedule="scaled_linear",
            prediction_type="sample",
        )
        plan_anchor = np.load(plan_anchor_path)
        self.plan_anchor = nn.Parameter(
            torch.tensor(plan_anchor[:, 4::5, :2], dtype=torch.float32),
            requires_grad=False,
        )  # 20,8,2
        self.plan_anchor_encoder = nn.Sequential(
            *linear_relu_ln(d_model, 1, 1, 512),
            nn.Linear(d_model, d_model),
        )
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(d_model),
            nn.Linear(d_model, d_model * 4),
            nn.Mish(),
            nn.Linear(d_model * 4, d_model),
        )
        diff_decoder_layer = CustomTransformerDecoderLayer(
            num_poses=num_poses,
            d_model=d_model,
            d_ffn=d_ffn,
            config=config,
            cost_learning=cost_learning,
        )
        self.diff_decoder = CustomTransformerDecoder(diff_decoder_layer, 2)
        if self.rlft:
            # One feature-only layer per planner layer, used to refine features for the cost heads.
            self.cost_diff_decoder = nn.ModuleList([CustomTransformerDecoderLayer(
                num_poses=num_poses,
                d_model=d_model,
                d_ffn=d_ffn,
                config=config,
                cost_learning=cost_learning,
                feat_out=True
            ) for _ in range(2)])
        self.loss_computer = LossComputer(config, self.train_sampling)
        self.cost_learning = cost_learning
        self.online_pdm_infer = online_pdm_infer
        if self.cost_learning:
            self.heads = nn.ModuleDict({
                'noc': nn.Sequential(
                    nn.Linear(d_model, d_ffn), nn.ReLU(), nn.Linear(d_ffn, 1),
                ),
                'da': nn.Sequential(
                    nn.Linear(d_model, d_ffn), nn.ReLU(), nn.Linear(d_ffn, 1),
                ),
                'ttc': nn.Sequential(
                    nn.Linear(d_model, d_ffn), nn.ReLU(), nn.Linear(d_ffn, 1),
                ),
                'comfort': nn.Sequential(
                    nn.Linear(d_model, d_ffn), nn.ReLU(), nn.Linear(d_ffn, 1),
                ),
                'progress': nn.Sequential(
                    nn.Linear(d_model, d_ffn), nn.ReLU(), nn.Linear(d_ffn, 1),
                ),
                'pdms': nn.Sequential(
                    nn.Linear(d_model, d_ffn), nn.ReLU(), nn.Linear(d_ffn, 1),
                ),
            })
            if self.rlft:
                self.heads_lora = lora_wrapper(self.heads, LoRALinear,
                                               rank=16, alpha=1.0, dropout=0.1, num_task=4)
                finetuning_detach(self.heads)
        # Online PDM scoring is needed both for cost learning and for online_pdm_infer; set up the
        # Ray workers and metric-cache index whenever either is enabled.
        if self.cost_learning or self.online_pdm_infer:
            print('online safe RL scoring, load from:', pdm_config_path)
            self.cfg = pdm_config_path
            self.chunk_size = 4
            self.worker = RayDistributedNoTorch(threads_per_node=16)
            self.init_metric_cache_paths()

    def init_metric_cache_paths(self, metadata_files=None):
        """
        Build ``self.metric_cache_dict`` (token -> metric-cache file path) from metadata CSVs.

        The token is taken as the second-to-last component of each listed path; later files
        override earlier ones for duplicate tokens.

        :param metadata_files: iterable of metadata CSV paths; defaults to
            ``_METRIC_CACHE_METADATA_FILES``.
        """
        if metadata_files is None:
            metadata_files = self._METRIC_CACHE_METADATA_FILES
        self.metric_cache_dict = {}
        for metadata_file in metadata_files:
            with open(str(metadata_file), "r") as f:
                cache_paths = f.read().splitlines()[1:]  # skip the header line
            self.metric_cache_dict.update(
                {cache_path.split("/")[-2]: cache_path for cache_path in cache_paths})
        print(len(self.metric_cache_dict))

    def norm_odo(self, odo_info_fut):
        """Map x, y (and heading if present) from fixed metric ranges to roughly [-1, 1]."""
        odo_info_fut_x = odo_info_fut[..., 0:1]
        odo_info_fut_y = odo_info_fut[..., 1:2]
        odo_info_fut_head = odo_info_fut[..., 2:3]  # empty when the last dim is 2
        odo_info_fut_x = 2 * (odo_info_fut_x + 1.2) / 56.9 - 1
        odo_info_fut_y = 2 * (odo_info_fut_y + 20) / 46 - 1
        odo_info_fut_head = 2 * (odo_info_fut_head + 2) / 3.9 - 1
        return torch.cat([odo_info_fut_x, odo_info_fut_y, odo_info_fut_head], dim=-1)

    def denorm_odo(self, odo_info_fut):
        """Inverse of :meth:`norm_odo`."""
        odo_info_fut_x = odo_info_fut[..., 0:1]
        odo_info_fut_y = odo_info_fut[..., 1:2]
        odo_info_fut_head = odo_info_fut[..., 2:3]
        odo_info_fut_x = (odo_info_fut_x + 1) / 2 * 56.9 - 1.2
        odo_info_fut_y = (odo_info_fut_y + 1) / 2 * 46 - 20
        odo_info_fut_head = (odo_info_fut_head + 1) / 2 * 3.9 - 2
        return torch.cat([odo_info_fut_x, odo_info_fut_y, odo_info_fut_head], dim=-1)

    def forward(self, ego_query, agents_query, bev_feature, bev_spatial_shape,
                status_encoding, targets=None, global_img=None,
                metric_cache=None, tokens=None) -> Dict[str, torch.Tensor]:
        """Dispatch to :meth:`forward_train` in training mode, else :meth:`forward_test` under no_grad."""
        if self.training:
            return self.forward_train(ego_query, agents_query,
                                      bev_feature, bev_spatial_shape, status_encoding,
                                      targets, global_img, metric_cache, tokens)
        with torch.no_grad():
            return self.forward_test(ego_query, agents_query,
                                     bev_feature, bev_spatial_shape, status_encoding, global_img, tokens)

    def online_cost_loss(self, cost_dict, gt_pdm_score, idx=0):
        """BCE losses between predicted cost logits and online PDM sub-scores for decoder layer ``idx``."""
        loss = {}
        loss[f'loss.noc_{idx}'] = 3 * torch.mean(
            F.binary_cross_entropy_with_logits(cost_dict['noc'], gt_pdm_score['no_at_fault_collisions']))
        loss[f'loss.da_{idx}'] = 3 * torch.mean(
            F.binary_cross_entropy_with_logits(cost_dict['da'], gt_pdm_score['drivable_area_compliance']))
        loss[f'loss.ttc_{idx}'] = 2 * torch.mean(
            F.binary_cross_entropy_with_logits(cost_dict['ttc'], gt_pdm_score['time_to_collision_within_bound']))
        loss[f'loss.ep_{idx}'] = torch.mean(
            F.binary_cross_entropy_with_logits(cost_dict['progress'], gt_pdm_score['ego_progress']))
        loss[f'loss.comfort_{idx}'] = torch.mean(
            F.binary_cross_entropy_with_logits(
                cost_dict['comfort'], gt_pdm_score['comfort']))
        # Composite score assembled from the sub-score probabilities, supervised with the PDM score.
        score = (5 * cost_dict['ttc'].sigmoid() + 5 * cost_dict['progress'].sigmoid()
                 + 2 * cost_dict['comfort'].sigmoid()) / 12
        score = cost_dict['noc'].sigmoid() * cost_dict['da'].sigmoid() * score
        loss[f'loss.RL_score_{idx}'] = 5 * torch.mean(
            F.binary_cross_entropy(
                score, gt_pdm_score['score'])
        )
        loss[f'loss.score_{idx}'] = 5 * torch.mean(
            F.binary_cross_entropy_with_logits(
                cost_dict['pdms'], gt_pdm_score['score'])
        )
        return loss

    def _metric_cache_path(self, token):
        """Metric-cache file for ``token``.

        In training a missing token raises ``KeyError``. In eval it falls back to the first
        cache entry, so the result for that token is scored against another scenario.
        """
        if self.training or token in self.metric_cache_dict:
            return self.metric_cache_dict[token]
        return next(iter(self.metric_cache_dict.values()))

    def online_batch_pdm_calculation(self, traj_tensor, metric_cache, traj_len, tokens):
        """Score trajectories with PDM through the Ray workers.

        :param traj_tensor: (b, n, t, d) trajectories sampled every 0.5 s.
        :param metric_cache: unused; cache files are looked up by token.
        :param traj_len: number of equal groups the ``n`` trajectories per sample are split into
            (e.g. one group per decoder layer).
        :param tokens: one scenario token per batch entry (at least ``b``).
        :return: list of ``traj_len`` dicts mapping PDM metric name -> tensor with
            ``b * sampling`` rows, where ``sampling`` is ``train_sampling`` in training mode and
            1 in eval mode.
        """
        b, n, t, d = traj_tensor.shape
        device = traj_tensor.device
        # Prepend the origin pose (zeros) and upsample from 0.5 s to 0.1 s steps.
        flat_trajs = traj_tensor.reshape(b * n, t, d)
        origin = torch.zeros((b * n, 1, d), device=device)
        full_traj_tensor = torch.cat([origin, flat_trajs], dim=1)
        interped = upsample_traj_nd_torch(full_traj_tensor, dt_in=0.5, dt_out=0.1)
        interped = interped.reshape(b, n, -1, d)
        interped[..., -1] = wrap_to_pi(interped[..., -1])
        traj_np = interped.detach().cpu().numpy()

        worker_args = []
        for start in range(0, b, self.chunk_size):
            stop = start + self.chunk_size
            chunk_tokens = tokens[start:stop]
            worker_args.append(dict(
                metric_cache=[self._metric_cache_path(tok) for tok in chunk_tokens],
                traj_tensor=traj_np[start:stop],
                tokens=chunk_tokens,
                cfg=[self.cfg] * self.chunk_size,
            ))
        pdm_results = worker_map(self.worker, ray_worker_pdm_func, worker_args)
        pdm_by_token = {pdm['token']: pdm for pdm in pdm_results}

        metric_keys = (
            'no_at_fault_collisions', 'drivable_area_compliance', 'ego_progress',
            'time_to_collision_within_bound', 'comfort', 'driving_direction_compliance', 'score',
        )
        per_sample = [pdm_by_token[tokens[i]] for i in range(b)]
        # Training batches were expanded train_sampling times (see forward_train); eval batches were not.
        sampling = self.train_sampling if self.training else 1
        ret_batch_pdm_results_list = [dict() for _ in range(traj_len)]
        for key in metric_keys:
            values = torch.from_numpy(np.stack([res[key] for res in per_sample], axis=0)).to(device)
            values = values.reshape(b * sampling, traj_len, -1)
            for i in range(traj_len):
                ret_batch_pdm_results_list[i][key] = values[:, i]
        return ret_batch_pdm_results_list

    def expand_and_reshape(self, x, k=4):
        """Repeat each batch entry ``k`` times consecutively: (B, ...) -> (B*k, ...)."""
        B, rest = x.shape[0], x.shape[1:]
        x = x.unsqueeze(1).repeat(1, k, *([1] * len(rest)))  # (B, k, ...)
        return x.view(B * k, *rest)

    def _cost_dict_for_layer(self, idx, feat, poses_reg, bev_feature, bev_spatial_shape,
                             agents_query, ego_query, time_embed, status_encoding):
        """Cost logits (one entry per cost head, last dim squeezed) for decoder layer ``idx``.

        With ``rlft`` the features are first refined by ``cost_diff_decoder[idx]`` on detached
        inputs and scored through the LoRA-wrapped heads.
        """
        if self.rlft:
            feat = self.cost_diff_decoder[idx](
                feat.detach(), poses_reg[..., :2].detach(),
                bev_feature.detach(), bev_spatial_shape,
                agents_query.detach(), ego_query.detach(),
                time_embed.detach(), status_encoding.detach()
            )
            return {name: peft_wrapper_forward(feat, head, self.heads_lora['lora_' + name])[..., 0]
                    for name, head in self.heads.items()}
        return {name: head(feat)[..., 0] for name, head in self.heads.items()}

    def forward_train(self, ego_query, agents_query, bev_feature, bev_spatial_shape, status_encoding,
                      targets=None, global_img=None,
                      metric_cache=None, tokens=None) -> Dict[str, torch.Tensor]:
        """Noise the anchors, denoise with the decoder and compute planning (and cost) losses."""
        bs = ego_query.shape[0]
        device = ego_query.device
        # 1. add truncated noise to the plan anchor
        plan_anchor = self.plan_anchor.unsqueeze(0).repeat(bs, 1, 1, 1)
        odo_info_fut = self.norm_odo(plan_anchor)
        if self.train_sampling > 1:
            k = self.train_sampling
            odo_info_fut = self.expand_and_reshape(odo_info_fut, k=k)
            agents_query = self.expand_and_reshape(agents_query, k=k)
            ego_query = self.expand_and_reshape(ego_query, k=k)
            bev_feature = self.expand_and_reshape(bev_feature, k=k)
            status_encoding = self.expand_and_reshape(status_encoding, k=k)
            bs = bs * k
        timesteps = torch.randint(
            0, 50,
            (bs,), device=device
        )
        noise = torch.randn(odo_info_fut.shape, device=device)
        noisy_traj_points = self.diffusion_scheduler.add_noise(
            original_samples=odo_info_fut,
            noise=noise,
            timesteps=timesteps,
        ).float()
        noisy_traj_points = torch.clamp(noisy_traj_points, min=-1, max=1)
        noisy_traj_points = self.denorm_odo(noisy_traj_points)
        ego_fut_mode = noisy_traj_points.shape[1]
        # 2. proj noisy_traj_points to the query
        traj_pos_embed = gen_sineembed_for_position(noisy_traj_points, hidden_dim=64).flatten(-2)
        traj_feature = self.plan_anchor_encoder(traj_pos_embed).view(bs, ego_fut_mode, -1)
        # 3. embed the timesteps
        time_embed = self.time_mlp(timesteps).view(bs, 1, -1)
        # 4. begin the stacked decoder
        poses_reg_list, poses_cls_list, traj_feat_list = self.diff_decoder(
            traj_feature, noisy_traj_points, bev_feature, bev_spatial_shape,
            agents_query, ego_query, time_embed, status_encoding, global_img)

        if self.cost_learning:
            cost_dict_list = [
                self._cost_dict_for_layer(idx, feat, poses_reg_list[idx], bev_feature, bev_spatial_shape,
                                          agents_query, ego_query, time_embed, status_encoding)
                for idx, feat in enumerate(traj_feat_list)
            ]
            _, _, ts, d = poses_reg_list[0].shape
            # (bs, layers*modes, ts, d); with train_sampling, regroup the copies per original sample.
            stack_trajs = torch.cat(poses_reg_list, dim=1)
            if self.train_sampling > 1:
                stack_trajs = stack_trajs.view(bs // self.train_sampling, self.train_sampling, -1, ts, d)
                stack_trajs = stack_trajs.view(bs // self.train_sampling, -1, ts, d)
            batched_pdm_score_list = self.online_batch_pdm_calculation(
                stack_trajs, None, len(poses_reg_list), tokens
            )

        ret_traj_loss = 0
        cost_dict_loss = dict()
        for idx, (poses_reg, poses_cls) in enumerate(zip(poses_reg_list, poses_cls_list)):
            ret_traj_loss += self.loss_computer(poses_reg, poses_cls, targets, plan_anchor)
            if self.cost_learning:
                cost_dict_loss.update(
                    self.online_cost_loss(cost_dict_list[idx], batched_pdm_score_list[idx], idx)
                )
        mode_idx = poses_cls_list[-1].argmax(dim=-1)
        mode_idx = mode_idx[..., None, None, None].repeat(1, 1, self._num_poses, 3)
        best_reg = torch.gather(poses_reg_list[-1], 1, mode_idx).squeeze(1)
        if self.cost_learning:
            return {'cost_loss': cost_dict_loss, "trajectory": best_reg, "trajectory_loss": ret_traj_loss}
        return {"trajectory": best_reg, "trajectory_loss": ret_traj_loss}

    def cal_test_rl_score(self, cost_dict):
        """Composite score from cost logits: P(noc) * P(da) * weighted mean of ttc/progress/comfort."""
        score = (5 * cost_dict['ttc'].sigmoid() + 5 * cost_dict['progress'].sigmoid()
                 + 2 * cost_dict['comfort'].sigmoid()) / 12
        score = (cost_dict['noc'].sigmoid()).float() * (cost_dict['da'].sigmoid()).float() * score
        return score

    def forward_test(self, ego_query, agents_query, bev_feature, bev_spatial_shape, status_encoding,
                     global_img, tokens) -> Dict[str, torch.Tensor]:
        """Draw 100 noisy anchor sets, keep the top-classified mode of each, then pick one trajectory.

        The final choice uses the classifier score, overridden by the learned cost score when
        ``cost_learning`` is set, and by the online PDM score when ``online_pdm_infer`` is set.
        """
        bs = ego_query.shape[0]
        device = ego_query.device
        # (bs, 1) batch index, paired with a (bs, 1) mode index to pick one mode per sample.
        batch_idx = torch.arange(bs, device=device)[:, None]
        plan_anchor = self.plan_anchor.unsqueeze(0).repeat(bs, 1, 1, 1)
        org_img = self.norm_odo(plan_anchor)
        ego_fut_mode = org_img.shape[1]
        # Loop-invariant; previously re-run on every iteration with the same arguments.
        self.diffusion_scheduler.set_timesteps(1000, device)
        trajs, scores, costs = [], [], []
        for _ in range(100):
            timesteps = torch.randint(0, 5, (bs,), device=device)
            noise = torch.randn(org_img.shape, device=device)
            img = self.diffusion_scheduler.add_noise(original_samples=org_img, noise=noise, timesteps=timesteps)
            noisy_traj_points = self.denorm_odo(torch.clamp(img, min=-1, max=1))
            traj_pos_embed = gen_sineembed_for_position(noisy_traj_points, hidden_dim=64).flatten(-2)
            traj_feature = self.plan_anchor_encoder(traj_pos_embed).view(bs, ego_fut_mode, -1)
            time_embed = self.time_mlp(timesteps).view(bs, 1, -1)
            poses_reg_list, poses_cls_list, traj_feat_list = self.diff_decoder(
                traj_feature, noisy_traj_points, bev_feature, bev_spatial_shape,
                agents_query, ego_query, time_embed, status_encoding, global_img)
            poses_reg = poses_reg_list[-1]
            poses_cls = poses_cls_list[-1]
            best_mode = poses_cls.argmax(dim=1, keepdim=True)  # (bs, 1)
            trajs.append(poses_reg[batch_idx, best_mode])
            scores.append(poses_cls.softmax(-1)[batch_idx, best_mode])
            if self.cost_learning:
                # Only the last layer's costs are used for selection.
                last = len(traj_feat_list) - 1
                cost_dict = self._cost_dict_for_layer(
                    last, traj_feat_list[last], poses_reg, bev_feature, bev_spatial_shape,
                    agents_query, ego_query, time_embed, status_encoding)
                costs.append({k: v[batch_idx, best_mode] for k, v in cost_dict.items()})
        trajs = torch.cat(trajs, dim=1)  # (bs, 100, num_poses, 3)
        scores = torch.cat(scores, dim=1)  # (bs, 100)
        mode_idx = torch.argmax(scores, dim=1)
        if self.cost_learning:
            rl_scores = torch.cat(
                [(self.cal_test_rl_score(cost) + cost['pdms'].sigmoid()) / 2 for cost in costs], dim=1)
            mode_idx = torch.argmax(rl_scores, dim=1)
        if self.online_pdm_infer:
            batched_pdm_score_list = self.online_batch_pdm_calculation(trajs, None, 1, tokens)
            mode_idx = torch.argmax(batched_pdm_score_list[0]['score'], dim=1)
        best_reg = trajs[torch.arange(trajs.shape[0], device=device), mode_idx]
        return {"trajectory": best_reg}