from typing import Dict, Tuple

import pytorch_lightning as pl
from torch import Tensor

from navsim.agents.abstract_agent import AbstractAgent


class AgentLightningModule(pl.LightningModule):
    """
    Pytorch lightning wrapper for a learnable NAVSIM agent.

    The forward pass, loss computation and optimizer construction are all delegated
    to the wrapped agent; this module only connects them to Lightning's train/val hooks.
    """

    def __init__(self, agent: AbstractAgent):
        """
        Initialise the lightning module wrapper.
        :param agent: agent interface in NAVSIM
        """
        super().__init__()
        self.agent = agent

    def _step(self, batch: Tuple[Dict[str, Tensor], Dict[str, Tensor]], logging_prefix: str) -> Tensor:
        """
        Runs the agent forward on a batch, computes the loss and logs it.

        No backward pass happens here: Lightning performs backpropagation and the
        optimizer step on the loss returned from ``training_step``.
        :param batch: tuple of dictionaries for feature and target tensors (batched)
        :param logging_prefix: prefix of the logged metric key, i.e. "<prefix>/loss"
        :return: scalar loss
        """
        features, targets = batch
        prediction = self.agent.forward(features)
        loss = self.agent.compute_loss(features, targets, prediction)
        # Log both per-step and epoch-aggregated values. sync_dist=True reduces the
        # logged value across processes when training is distributed.
        self.log(f"{logging_prefix}/loss", loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
        return loss

    def training_step(self, batch: Tuple[Dict[str, Tensor], Dict[str, Tensor]], batch_idx: int) -> Tensor:
        """
        Step called on training samples.
        :param batch: tuple of dictionaries for feature and target tensors (batched)
        :param batch_idx: index of batch (ignored)
        :return: scalar loss, which Lightning backpropagates
        """
        return self._step(batch, "train")

    def validation_step(self, batch: Tuple[Dict[str, Tensor], Dict[str, Tensor]], batch_idx: int) -> Tensor:
        """
        Step called on validation samples.
        :param batch: tuple of dictionaries for feature and target tensors (batched)
        :param batch_idx: index of batch (ignored)
        :return: scalar loss (logged under "val/loss")
        """
        return self._step(batch, "val")

    def configure_optimizers(self):
        """
        Inherited, see superclass.

        Returns whatever the agent's ``get_optimizers`` produces. Presumably this is
        one of the formats Lightning accepts (optimizer, list, or dict with a
        scheduler) — confirm against the AbstractAgent implementation.
        """
        return self.agent.get_optimizers()