| from typing import Dict, List, Tuple
|
| import torch
|
|
|
| from det_map.data.datasets.dataloader import SceneLoader
|
| from det_map.data.datasets.dataset import Dataset
|
| from navsim.planning.training.abstract_feature_target_builder import AbstractFeatureBuilder, AbstractTargetBuilder
|
|
|
class DetDataset(Dataset):
    """Dataset that runs scene builders and then a fixed sequence of pipelines.

    ``__getitem__`` loads one scene and runs the feature and target builders
    on it. It then threads the resulting dictionaries through the
    ``'lidar_aug'``, ``'depth'``, ``'lidar_filter'`` and ``'point_shuffle'``
    entries of ``self.pipelines``, in that order.

    ``self._scene_loader``, ``self._feature_builders``,
    ``self._target_builders`` and ``self.pipelines`` are not set in this class.
    They are presumably initialised by the base ``Dataset`` (TODO confirm).
    """

    def __init__(self, **kwargs):
        """Pass every keyword argument unchanged to the base ``Dataset``.

        Only keyword arguments are accepted; passing a positional argument
        raises ``TypeError``.
        """
        super().__init__(**kwargs)

    def __getitem__(
        self, idx: int
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        """Build the ``(features, targets)`` pair for sample ``idx``.

        Args:
            idx: Index into ``self._scene_loader.tokens`` selecting the scene.

        Returns:
            The ``(features, targets)`` tuple produced by the last pipeline
            stage (``'point_shuffle'``).

        Raises:
            KeyError: presumably, if a stage name is missing from
                ``self.pipelines`` (assumes it is a dict — TODO confirm).
        """
        loader = self._scene_loader
        token = loader.tokens[idx]
        scene = loader.get_scene_from_token(token)

        # Builder outputs are merged with dict.update, so when two builders
        # emit the same key the later builder's value wins.
        feature_dict: Dict[str, torch.Tensor] = {}
        for feature_builder in self._feature_builders:
            # get_agent_input() is invoked once per feature builder; the
            # result is deliberately not shared between builders.
            agent_input = scene.get_agent_input()
            feature_dict.update(feature_builder.compute_features(agent_input))

        target_dict: Dict[str, torch.Tensor] = {}
        for target_builder in self._target_builders:
            target_dict.update(target_builder.compute_targets(scene))

        # Stage order matters: each stage consumes the previous stage's output.
        for stage_name in ("lidar_aug", "depth", "lidar_filter", "point_shuffle"):
            stage = self.pipelines[stage_name]
            feature_dict, target_dict = stage(feature_dict, target_dict)

        return feature_dict, target_dict