ICCV2025-RealADSim-ClosedLoop-SubmissionDemo
/
navsim
/planning
/training
/agent_lightning_module.py
| import pytorch_lightning as pl | |
| from torch import Tensor | |
| from typing import Dict, Tuple | |
| from navsim.agents.abstract_agent import AbstractAgent | |
class AgentLightningModule(pl.LightningModule):
    """Lightning module that delegates the training logic to a NAVSIM agent."""

    def __init__(self, agent: AbstractAgent):
        """
        Store the agent that this wrapper drives.
        :param agent: agent interface in NAVSIM
        """
        super().__init__()
        self.agent = agent

    def _step(self, batch: Tuple[Dict[str, Tensor], Dict[str, Tensor]], logging_prefix: str) -> Tensor:
        """
        Run the agent's forward pass on a batch, compute the loss and log it.
        The loss is logged as "<logging_prefix>/loss", both per step and per epoch.
        :param batch: tuple of dictionaries for feature and target tensors (batched)
        :param logging_prefix: prefix under which the loss is logged
        :return: loss as returned by the agent's compute_loss
        """
        inputs, labels = batch
        outputs = self.agent.forward(inputs)
        step_loss = self.agent.compute_loss(inputs, labels, outputs)

        metric_name = f"{logging_prefix}/loss"
        self.log(
            metric_name,
            step_loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )
        return step_loss

    def training_step(self, batch: Tuple[Dict[str, Tensor], Dict[str, Tensor]], batch_idx: int) -> Tensor:
        """
        Process one batch of training samples, logging under the "train" prefix.
        :param batch: tuple of dictionaries for feature and target tensors (batched)
        :param batch_idx: index of batch (ignored)
        :return: loss for this batch
        """
        train_loss = self._step(batch, "train")
        return train_loss

    def validation_step(self, batch: Tuple[Dict[str, Tensor], Dict[str, Tensor]], batch_idx: int):
        """
        Process one batch of validation samples, logging under the "val" prefix.
        :param batch: tuple of dictionaries for feature and target tensors (batched)
        :param batch_idx: index of batch (ignored)
        :return: loss for this batch
        """
        val_loss = self._step(batch, "val")
        return val_loss

    def configure_optimizers(self):
        """Inherited, see superclass. Returns whatever the agent's get_optimizers provides."""
        optimizers = self.agent.get_optimizers()
        return optimizers