| from typing import Dict, Tuple, List
|
|
|
| import pytorch_lightning as pl
|
| import torch
|
| from nuplan.planning.simulation.trajectory.trajectory_sampling import TrajectorySampling
|
| from torch import Tensor
|
|
|
| from navsim.agents.abstract_agent import AbstractAgent
|
| from navsim.agents.vadv2.vadv2_agent import Vadv2Agent
|
| from navsim.common.dataclasses import Trajectory
|
|
|
|
|
class AgentLightningModuleMap(pl.LightningModule):
    """Lightning wrapper that trains, validates and runs prediction for an ``AbstractAgent``.

    Forward passes, loss computation and optimizer construction are all
    delegated to the wrapped agent; this module only handles the Lightning
    plumbing (step dispatch, metric logging, packaging predictions).
    """

    def __init__(
        self,
        agent: AbstractAgent,
    ):
        """Wrap ``agent`` so Lightning can drive it.

        :param agent: agent providing ``forward``, ``compute_loss`` and
            ``get_optimizers`` (and ``forward_train`` when it is a ``Vadv2Agent``).
        """
        super().__init__()
        self.agent = agent

    def _step(
        self,
        batch: Tuple[Dict[str, Tensor], Dict[str, Tensor], List[str]],
        logging_prefix: str,
    ):
        """Run one forward + loss pass and log every loss term.

        :param batch: ``(features, targets)`` or ``(features, targets, tokens)``;
            only the first two elements are used, so both layouts are accepted.
        :param logging_prefix: metric namespace (``"train"`` / ``"val"``); also
            selects ``forward_train`` for ``Vadv2Agent`` in those two phases.
        :return: the total loss returned by ``agent.compute_loss``.
        """
        # Index instead of tuple-unpacking so a trailing token list does not
        # raise "too many values to unpack".
        features, targets = batch[0], batch[1]

        if logging_prefix in ("train", "val") and isinstance(self.agent, Vadv2Agent):
            # Vadv2Agent's training forward additionally receives targets["interpolated_traj"].
            prediction = self.agent.forward_train(features, targets["interpolated_traj"])
        else:
            prediction = self.agent.forward(features)

        loss, loss_dict = self.agent.compute_loss(features, targets, prediction)

        # Each loss component is logged as "<prefix>/<name>", the total as "<prefix>/loss".
        for k, v in loss_dict.items():
            self.log(f"{logging_prefix}/{k}", v, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log(f"{logging_prefix}/loss", loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
        return loss

    def training_step(
        self,
        batch: Tuple[Dict[str, Tensor], Dict[str, Tensor]],
        batch_idx: int
    ):
        """Lightning training hook; delegates to ``_step`` with the ``"train"`` prefix."""
        return self._step(batch, "train")

    def validation_step(
        self,
        batch: Tuple[Dict[str, Tensor], Dict[str, Tensor]],
        batch_idx: int
    ):
        """Lightning validation hook; delegates to ``_step`` with the ``"val"`` prefix."""
        return self._step(batch, "val")

    def configure_optimizers(self):
        """Return whatever ``agent.get_optimizers()`` builds."""
        return self.agent.get_optimizers()

    def predict_step(
        self,
        batch: Tuple[Dict[str, Tensor], Dict[str, Tensor], List[str]],
        batch_idx: int
    ):
        """Run inference and package per-token trajectories and log-scores.

        :param batch: ``(features, targets, tokens)``; ``targets`` is unused.
        :param batch_idx: index of the batch (unused).
        :return: dict mapping each token to a dict with ``"trajectory"`` (a
            ``Trajectory``) followed by the log-scores ``"imi"``, ``"noc"``,
            ``"da"``, ``"ttc"``, ``"comfort"`` and ``"progress"`` for that sample.
        """
        features, _, tokens = batch
        self.agent.eval()
        with torch.no_grad():
            predictions = self.agent.forward(features)
            poses = predictions["trajectory"].cpu().numpy()

            # log_softmax is the numerically stable form of softmax().log():
            # for finite logits it never underflows to log(0) = -inf.
            scores = {"imi": predictions["imi"].log_softmax(-1).cpu().numpy()}
            # NOTE(review): these heads presumably output probabilities; plain
            # log() still yields -inf wherever a value is exactly 0 — confirm.
            for name in ("noc", "da", "ttc", "comfort", "progress"):
                scores[name] = predictions[name].log().cpu().numpy()

        # Assumes poses is (batch, num_poses, 3) — TODO confirm. With a 4 s
        # horizon, 40 poses means 0.1 s spacing; otherwise 0.5 s (8 poses).
        interval_length = 0.1 if poses.shape[1] == 40 else 0.5

        results = {}
        # zip truncates to the shortest input, keeping samples aligned by index.
        for token, pose, *row in zip(tokens, poses, *scores.values()):
            entry = {
                "trajectory": Trajectory(
                    pose, TrajectorySampling(time_horizon=4, interval_length=interval_length)
                )
            }
            entry.update(zip(scores.keys(), row))
            results[token] = entry
        return results
|
|
|
|
|
|
|
|
|
|
|
| |