from __future__ import annotations

import lzma
import pickle
from pathlib import Path
from typing import Any, Dict, List

from tqdm import tqdm

from navsim.common.dataclasses import AgentInput, Scene, SceneFilter, SensorConfig
from navsim.planning.metric_caching.metric_cache import MetricCache


def filter_scenes(data_path: Path, scene_filter: SceneFilter) -> Dict[str, List[Dict[str, Any]]]:
    """Load the log pickles in `data_path` and keep the scenes that pass `scene_filter`."""

    def split_list(input_list: List[Any], num_frames: int, frame_interval: int) -> List[List[Any]]:
        # Windows of `num_frames` frames, one window starting every `frame_interval` frames.
        return [input_list[i : i + num_frames] for i in range(0, len(input_list), frame_interval)]

    # Keyed by the token of the current frame; values are lists of raw frame dicts.
    filtered_scenes: Dict[str, List[Dict[str, Any]]] = {}
    stop_loading: bool = False

    # filter logs
    log_files = list(data_path.iterdir())
    if scene_filter.log_names is not None:
        log_files = [
            log_file
            for log_file in log_files
            if log_file.name.replace(".pkl", "") in scene_filter.log_names
        ]

    if scene_filter.tokens is not None:
        filter_tokens = True
        tokens = set(scene_filter.tokens)
    else:
        filter_tokens = False

    for log_pickle_path in tqdm(log_files, desc="Loading logs"):
        # Use a context manager so the file handle is closed after unpickling.
        with open(log_pickle_path, "rb") as f:
            scene_dict_list = pickle.load(f)
        for frame_list in split_list(
            scene_dict_list, scene_filter.num_frames, scene_filter.frame_interval
        ):
            # Filter scenes which are too short
            if len(frame_list) < scene_filter.num_frames:
                continue

            # Filter scenes with no route
            if (
                scene_filter.has_route
                and len(frame_list[scene_filter.num_history_frames - 1]["roadblock_ids"]) == 0
            ):
                continue

            # Filter by token
            token = frame_list[scene_filter.num_history_frames - 1]["token"]
            if filter_tokens and token not in tokens:
                continue

            filtered_scenes[token] = frame_list
            if (scene_filter.max_scenes is not None) and (
                len(filtered_scenes) >= scene_filter.max_scenes
            ):
                stop_loading = True
                break

        if stop_loading:
            break

    return filtered_scenes
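
# Illustrative sketch (commented out; the data path and filter values are assumptions,
# not part of this module). `filter_scenes` expects `data_path` to contain one pickle
# per log, each holding a chronologically ordered list of frame dicts:
#
#   scenes = filter_scenes(Path("/data/navsim_logs/mini"), scene_filter)
#
# With num_frames=3 and frame_interval=2, the inner `split_list` chunks a six-frame
# log as [[0, 1, 2], [2, 3, 4], [4, 5]]; the trailing two-frame chunk is then dropped
# by the "too short" check above.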


class SceneLoader:
    """Builds `Scene` / `AgentInput` objects on demand from the filtered scene dicts."""

    def __init__(
        self,
        data_path: Path,
        sensor_blobs_path: Path,
        scene_filter: SceneFilter,
        sensor_config: SensorConfig = SensorConfig.build_no_sensors(),
    ):
        self.scene_frames_dicts = filter_scenes(data_path, scene_filter)
        self._sensor_blobs_path = sensor_blobs_path
        self._scene_filter = scene_filter
        self._sensor_config = sensor_config

    @property
    def tokens(self) -> List[str]:
        return list(self.scene_frames_dicts.keys())

    def __len__(self) -> int:
        return len(self.tokens)

    def __getitem__(self, idx: int) -> str:
        return self.tokens[idx]

    def get_scene_from_token(self, token: str) -> Scene:
        assert token in self.tokens
        return Scene.from_scene_dict_list(
            self.scene_frames_dicts[token],
            self._sensor_blobs_path,
            num_history_frames=self._scene_filter.num_history_frames,
            num_future_frames=self._scene_filter.num_future_frames,
            sensor_config=self._sensor_config,
        )

    def get_agent_input_from_token(self, token: str) -> AgentInput:
        assert token in self.tokens
        return AgentInput.from_scene_dict_list(
            self.scene_frames_dicts[token],
            self._sensor_blobs_path,
            num_history_frames=self._scene_filter.num_history_frames,
            sensor_config=self._sensor_config,
        )

    def get_tokens_list_per_log(self) -> Dict[str, List[str]]:
        # generate a dict that contains a list of tokens for each log-name
        tokens_per_logs: Dict[str, List[str]] = {}
        for token, scene_dict_list in self.scene_frames_dicts.items():
            log_name = scene_dict_list[0]["log_name"]
            tokens_per_logs.setdefault(log_name, []).append(token)
        return tokens_per_logs
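
# Illustrative usage sketch (commented out; the paths and the SceneFilter instance
# are hypothetical placeholders, not part of this module):
#
#   loader = SceneLoader(
#       data_path=Path("/data/navsim_logs/mini"),
#       sensor_blobs_path=Path("/data/navsim_sensor_blobs/mini"),
#       scene_filter=scene_filter,
#   )
#   for token in loader.tokens:
#       scene = loader.get_scene_from_token(token)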


class MetricCacheLoader:
    """Looks up precomputed `MetricCache` files by scene token."""

    def __init__(
        self,
        cache_path: Path,
        file_name: str = "metric_cache.pkl",
    ):
        self._file_name = file_name
        self.metric_cache_paths = self._load_metric_cache_paths(cache_path)

    def _load_metric_cache_paths(self, cache_path: Path) -> Dict[str, Path]:
        # The metadata CSV lists one cache file per line (after a header row).
        metadata_dir = cache_path / "metadata"
        metadata_file = [file for file in metadata_dir.iterdir() if file.suffix == ".csv"][0]
        with open(metadata_file, "r") as f:
            cache_paths = f.read().splitlines()[1:]
        # The token is the name of the directory containing the cache file. Avoid
        # shadowing the `cache_path` argument and return `Path` values, as the
        # return annotation promises.
        metric_cache_dict = {Path(path).parent.name: Path(path) for path in cache_paths}
        return metric_cache_dict
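
    # Expected metadata layout (an assumption inferred from the parsing above, not a
    # documented contract): a header row, then one cache path per line, with the
    # token as the parent directory of the cache file, e.g.
    #
    #   file_path
    #   /data/metric_cache/<log_name>/<token>/metric_cache.pkl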

    @property
    def tokens(self) -> List[str]:
        return list(self.metric_cache_paths.keys())

    def __len__(self) -> int:
        return len(self.metric_cache_paths)

    def __getitem__(self, idx: int) -> MetricCache:
        return self.get_from_token(self.tokens[idx])

    def get_from_token(self, token: str) -> MetricCache:
        # Metric caches are stored LZMA-compressed on disk.
        with lzma.open(self.metric_cache_paths[token], "rb") as f:
            metric_cache: MetricCache = pickle.load(f)
        return metric_cache

    def to_pickle(self, path: Path) -> None:
        # Materializes every cache in memory before writing one combined pickle.
        full_metric_cache = {}
        for token in tqdm(self.tokens):
            full_metric_cache[token] = self.get_from_token(token)
        with open(path, "wb") as f:
            pickle.dump(full_metric_cache, f)
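
# Illustrative usage sketch (commented out; the cache directory is a hypothetical
# placeholder, not part of this module):
#
#   cache_loader = MetricCacheLoader(Path("/data/metric_cache"))
#   metric_cache = cache_loader.get_from_token(cache_loader.tokens[0])
#   cache_loader.to_pickle(Path("full_metric_cache.pkl"))  # merge all caches into one file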