| from typing import Any, List, Dict, Optional, Union | |
| import torch | |
| from torch.optim import Optimizer | |
| from torch.optim.lr_scheduler import LRScheduler | |
| import pytorch_lightning as pl | |
| from navsim.agents.abstract_agent import AbstractAgent | |
| from navsim.agents.transfuser.transfuser_config import TransfuserConfig | |
| from navsim.agents.transfuser.transfuser_model import TransfuserModel | |
| from navsim.agents.transfuser.transfuser_callback import TransfuserCallback | |
| from navsim.agents.transfuser.transfuser_loss import transfuser_loss | |
| from navsim.agents.transfuser.transfuser_features import TransfuserFeatureBuilder, TransfuserTargetBuilder | |
| from navsim.common.dataclasses import SensorConfig | |
| from navsim.planning.training.abstract_feature_target_builder import AbstractFeatureBuilder, AbstractTargetBuilder | |
class TransfuserAgent(AbstractAgent):
    """Agent interface for TransFuser baseline.

    Wraps a ``TransfuserModel`` and connects it to the TransFuser-specific feature/target
    builders, loss function, optimizer and training callbacks.
    """

    # Leading prefix removed from checkpoint keys before they are loaded into this agent.
    # NOTE(review): presumably the training wrapper holds the agent under an attribute named
    # ``agent``, so saved keys look like "agent.<param>" — confirm against the training module.
    _CHECKPOINT_KEY_PREFIX = "agent."

    def __init__(
        self,
        config: TransfuserConfig,
        lr: float,
        checkpoint_path: Optional[str] = None,
    ):
        """
        Initializes TransFuser agent.
        :param config: global config of TransFuser agent
        :param lr: learning rate during training
        :param checkpoint_path: optional path string to checkpoint, defaults to None;
            required if ``initialize`` is called
        """
        super().__init__()
        self._config = config
        self._lr = lr
        self._checkpoint_path = checkpoint_path
        self._transfuser_model = TransfuserModel(config)

    def name(self) -> str:
        """Inherited, see superclass."""
        return self.__class__.__name__

    def initialize(self) -> None:
        """Inherited, see superclass.

        Loads the weights stored under the ``"state_dict"`` key of the checkpoint at
        ``checkpoint_path`` into this agent, after stripping the leading ``"agent."``
        prefix from each key.

        :raises ValueError: if the agent was constructed without a checkpoint path
        """
        if self._checkpoint_path is None:
            raise ValueError(f"{self.name()} cannot be initialized: no checkpoint_path was provided.")
        # Without CUDA, remap tensors saved on GPU onto the CPU; with CUDA, keep torch.load's
        # default placement (map_location=None), matching the original two-branch behavior.
        map_location = None if torch.cuda.is_available() else torch.device("cpu")
        state_dict: Dict[str, Any] = torch.load(self._checkpoint_path, map_location=map_location)["state_dict"]
        self.load_state_dict({self._strip_checkpoint_prefix(key): value for key, value in state_dict.items()})

    @classmethod
    def _strip_checkpoint_prefix(cls, key: str) -> str:
        """Remove one leading checkpoint prefix from ``key``; other keys are returned unchanged.

        Only the leading prefix is removed, so a parameter path that merely contains
        ``"agent."`` later on (e.g. ``"x_agent.weight"``) is left intact.
        """
        prefix = cls._CHECKPOINT_KEY_PREFIX
        return key[len(prefix):] if key.startswith(prefix) else key

    def get_sensor_config(self) -> SensorConfig:
        """Inherited, see superclass."""
        # NOTE(review): ``include=[3]`` presumably selects the sensor sample at index 3 only —
        # confirm against ``SensorConfig.build_all_sensors``.
        return SensorConfig.build_all_sensors(include=[3])

    def get_target_builders(self) -> List[AbstractTargetBuilder]:
        """Inherited, see superclass."""
        return [TransfuserTargetBuilder(config=self._config)]

    def get_feature_builders(self) -> List[AbstractFeatureBuilder]:
        """Inherited, see superclass."""
        return [TransfuserFeatureBuilder(config=self._config)]

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Inherited, see superclass."""
        return self._transfuser_model(features)

    def compute_loss(
        self,
        features: Dict[str, torch.Tensor],
        targets: Dict[str, torch.Tensor],
        predictions: Dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Inherited, see superclass."""
        # ``features`` is part of the agent interface but is not needed by the TransFuser loss.
        return transfuser_loss(targets, predictions, self._config)

    def get_optimizers(self) -> Union[Optimizer, Dict[str, Union[Optimizer, LRScheduler]]]:
        """Inherited, see superclass."""
        # Single Adam optimizer over the wrapped model's parameters; no LR scheduler.
        return torch.optim.Adam(self._transfuser_model.parameters(), lr=self._lr)

    def get_training_callbacks(self) -> List[pl.Callback]:
        """Inherited, see superclass."""
        return [TransfuserCallback(self._config)]