de99 commited on
Commit
90e657e
·
verified ·
1 Parent(s): 2126e06

Upload datasets_v0.py

Browse files
Files changed (1) hide show
  1. datasets_v0.py +368 -0
datasets_v0.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # --------------------------------------------------------
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ import os
7
+ from PIL import Image
8
+ from typing import Tuple
9
+ import yaml
10
+ import pickle
11
+ import tqdm
12
+ from torch.utils.data import Dataset
13
+ from misc import angle_difference, get_data_path, get_delta_np, normalize_data, to_local_coords
14
+ from project_functions import reproject_depth_to_other_pose_2seq, project_to_2d_image_2seq
15
+
16
+ class BaseDataset(Dataset):
17
+ def __init__(
18
+ self,
19
+ data_folder: str,
20
+ data_split_folder: str,
21
+ dataset_name: str,
22
+ image_size: Tuple[int, int],
23
+ min_dist_cat: int,
24
+ max_dist_cat: int,
25
+ len_traj_pred: int,
26
+ traj_stride: int,
27
+ context_size: int,
28
+ transform: object,
29
+ traj_names: str,
30
+ normalize: bool = True,
31
+ predefined_index: list = None,
32
+ goals_per_obs: int = 1,
33
+ ):
34
+ self.data_folder = data_folder
35
+ self.data_split_folder = data_split_folder
36
+ self.dataset_name = dataset_name
37
+ self.goals_per_obs = goals_per_obs
38
+
39
+
40
+ traj_names_file = os.path.join(data_split_folder, traj_names)
41
+ with open(traj_names_file, "r") as f:
42
+ file_lines = f.read()
43
+ self.traj_names = file_lines.split("\n")
44
+ if "" in self.traj_names:
45
+ self.traj_names.remove("")
46
+
47
+ self.image_size = image_size
48
+ self.distance_categories = list(range(min_dist_cat, max_dist_cat + 1))
49
+ self.min_dist_cat = self.distance_categories[0]
50
+ self.max_dist_cat = self.distance_categories[-1]
51
+ self.len_traj_pred = len_traj_pred
52
+ self.traj_stride = traj_stride
53
+
54
+ self.context_size = context_size
55
+ self.normalize = normalize
56
+
57
+ # load data/data_config.yaml
58
+ with open("config/data_config.yaml", "r") as f:
59
+ all_data_config = yaml.safe_load(f)
60
+
61
+ dataset_names = list(all_data_config.keys())
62
+ dataset_names.sort()
63
+ # use this index to retrieve the dataset name from the data_config.yaml
64
+ self.data_config = all_data_config[self.dataset_name]
65
+ self.transform = transform
66
+ self._load_index(predefined_index)
67
+ self.ACTION_STATS = {}
68
+ for key in all_data_config['action_stats']:
69
+ self.ACTION_STATS[key] = np.expand_dims(all_data_config['action_stats'][key], axis=0)
70
+
71
+ def _load_index(self, predefined_index) -> None:
72
+ """
73
+ Generates a list of tuples of (obs_traj_name, goal_traj_name, obs_time, goal_time) for each observation in the dataset
74
+ """
75
+ if predefined_index:
76
+ print(f"****** Using a predefined evaluation index... {predefined_index}******")
77
+ with open(predefined_index, "rb") as f:
78
+ self.index_to_data = pickle.load(f)
79
+ return
80
+ else:
81
+ print("****** Evaluating from NON PREDEFINED index... ******")
82
+ index_to_data_path = os.path.join(
83
+ self.data_split_folder,
84
+ f"dataset_dist_{self.min_dist_cat}_to_{self.max_dist_cat}_n{self.context_size}_len_traj_pred_{self.len_traj_pred}.pkl",
85
+ )
86
+
87
+ self.index_to_data, self.goals_index = self._build_index()
88
+ with open(index_to_data_path, "wb") as f:
89
+ pickle.dump((self.index_to_data, self.goals_index), f)
90
+
91
+ def _build_index(self, use_tqdm: bool = False):
92
+ """
93
+ Build an index consisting of tuples (trajectory name, time, max goal distance)
94
+ """
95
+ samples_index = []
96
+ goals_index = []
97
+
98
+ for traj_name in tqdm.tqdm(self.traj_names, disable=not use_tqdm, dynamic_ncols=True):
99
+ traj_data = self._get_trajectory(traj_name)
100
+ traj_len = len(traj_data["position"])
101
+ for goal_time in range(0, traj_len):
102
+ goals_index.append((traj_name, goal_time))
103
+
104
+ begin_time = self.context_size - 1
105
+ end_time = traj_len - self.len_traj_pred
106
+ for curr_time in range(begin_time, end_time, self.traj_stride):
107
+ max_goal_distance = min(self.max_dist_cat, traj_len - curr_time - 1)
108
+ min_goal_distance = max(self.min_dist_cat, -curr_time)
109
+ samples_index.append((traj_name, curr_time, min_goal_distance, max_goal_distance))
110
+
111
+ return samples_index, goals_index
112
+
113
+ def _get_trajectory(self, trajectory_name):
114
+ with open(os.path.join(self.data_folder, trajectory_name, "traj_data.pkl"), "rb") as f:
115
+ traj_data = pickle.load(f)
116
+ for k,v in traj_data.items():
117
+ traj_data[k] = v.astype('float')
118
+ return traj_data
119
+
120
+ def __len__(self) -> int:
121
+ return len(self.index_to_data)
122
+
123
+ def _compute_actions(self, traj_data, curr_time, goal_time):
124
+ start_index = curr_time
125
+ end_index = curr_time + self.len_traj_pred + 1
126
+ yaw = traj_data["yaw"][start_index:end_index]
127
+ positions = traj_data["position"][start_index:end_index]
128
+ goal_pos = traj_data["position"][goal_time]
129
+ goal_yaw = traj_data["yaw"][goal_time]
130
+
131
+ if len(yaw.shape) == 2:
132
+ yaw = yaw.squeeze(1)
133
+
134
+ if yaw.shape != (self.len_traj_pred + 1,):
135
+ raise ValueError("is used?")
136
+ # const_len = self.len_traj_pred + 1 - yaw.shape[0]
137
+ # yaw = np.concatenate([yaw, np.repeat(yaw[-1], const_len)])
138
+ # positions = np.concatenate([positions, np.repeat(positions[-1][None], const_len, axis=0)], axis=0)
139
+
140
+ waypoints_pos = to_local_coords(positions, positions[0], yaw[0])
141
+ waypoints_yaw = angle_difference(yaw[0], yaw)
142
+ actions = np.concatenate([waypoints_pos, waypoints_yaw.reshape(-1, 1)], axis=-1)
143
+ actions = actions[1:]
144
+
145
+ goal_pos = to_local_coords(goal_pos, positions[0], yaw[0])
146
+ goal_yaw = angle_difference(yaw[0], goal_yaw)
147
+
148
+ if self.normalize:
149
+ actions[:, :2] /= self.data_config["metric_waypoint_spacing"]
150
+ goal_pos[:, :2] /= self.data_config["metric_waypoint_spacing"]
151
+
152
+ goal_pos = np.concatenate([goal_pos, goal_yaw.reshape(-1, 1)], axis=-1)
153
+ return actions, goal_pos
154
+
155
+ def _compute_projected_image(self, traj_data, curr_time, goal_time, rgb_img):
156
+ start_index = curr_time
157
+ end_index = curr_time + self.len_traj_pred + 1
158
+ pose_src = traj_data["pose"][start_index:end_index][-1]
159
+ pose_dst = traj_data["pose"][goal_time]
160
+ depth_map = traj_data["depth"][start_index:end_index][-1]
161
+ K = traj_data["K"]
162
+
163
+ projected_images = self.generate_augmented_image(K=K, depth_map=depth_map, rgb_img=rgb_img, pose_src=pose_src, pose_dst=pose_dst)
164
+ return projected_images
165
+
166
+ def generate_augmented_image(self, K, depth_map, rgb_img, pose_src, pose_dst) -> np.ndarray:
167
+ """
168
+ 基于深度图 + pose 生成从另一个相机视角观察到的图像。
169
+ """
170
+ image_size = depth_map.shape # (H, W)
171
+ # pose_dst = create_relative_pose(pose_src, delta_translation, delta_angle_deg)
172
+ points_3d, colors = reproject_depth_to_other_pose_2seq(K, depth_map, rgb_img, pose_src, pose_dst)
173
+ images = project_to_2d_image_2seq(K, points_3d, colors, image_size) # (H, W, 3, goal_time)
174
+ return images
175
+
176
+
177
+ class TrainingDataset(BaseDataset):
178
+ def __init__(
179
+ self,
180
+ data_folder: str,
181
+ data_split_folder: str,
182
+ dataset_name: str,
183
+ image_size: Tuple[int, int],
184
+ min_dist_cat: int,
185
+ max_dist_cat: int,
186
+ len_traj_pred: int,
187
+ traj_stride: int,
188
+ context_size: int,
189
+ transform: object,
190
+ traj_names: str = 'traj_names.txt',
191
+ normalize: bool = True,
192
+ predefined_index: list = None,
193
+ goals_per_obs: int = 1,
194
+ ):
195
+ super().__init__(data_folder, data_split_folder, dataset_name, image_size, min_dist_cat, max_dist_cat,
196
+ len_traj_pred, traj_stride, context_size, transform, traj_names, normalize, predefined_index, goals_per_obs)
197
+
198
+
199
+ def __getitem__(self, i: int) -> Tuple[torch.Tensor]:
200
+ try:
201
+ f_curr, curr_time, min_goal_dist, max_goal_dist = self.index_to_data[i]
202
+ goal_offset = np.random.randint(min_goal_dist, max_goal_dist + 1, size=(self.goals_per_obs))
203
+ goal_time = (curr_time + goal_offset).astype('int')
204
+ rel_time = (goal_offset).astype('float')/(128.) # TODO: refactor, currently a fixed const
205
+
206
+ context_times = list(range(curr_time - self.context_size + 1, curr_time + 1))
207
+ context = [(f_curr, t) for t in context_times] + [(f_curr, t) for t in goal_time]
208
+
209
+ obs_image = torch.stack([self.transform(Image.open(get_data_path(self.data_folder, f, t))) for f, t in context])
210
+
211
+ # Load other trajectory data
212
+ curr_traj_data = self._get_trajectory(f_curr)
213
+
214
+ # Compute actions
215
+ _, goal_pos = self._compute_actions(curr_traj_data, curr_time, goal_time)
216
+ goal_pos[:, :2] = normalize_data(goal_pos[:, :2], self.ACTION_STATS)
217
+ # Compute projected images
218
+ f_img, t_img = context[-1] # curr_time img
219
+ rgb_img = cv2.imread(get_data_path(self.data_folder, f_img, t_img))
220
+ rgb_img = cv2.cvtColor(rgb_img, cv2.COLOR_BGR2RGB)
221
+ original_height, original_width = rgb_img.shape[:2]
222
+ resized_width = original_width // 2
223
+ resized_height = original_height // 2
224
+ rgb_img = cv2.resize(rgb_img, (resized_width, resized_height))
225
+ projected_images = self._compute_projected_image(curr_traj_data, curr_time, goal_time, rgb_img)
226
+ projected_tensor_list = [self.transform(Image.fromarray(img)) for img in projected_images]
227
+ projected_tensor = torch.stack(projected_tensor_list, dim=0)
228
+
229
+ return (
230
+ torch.as_tensor(obs_image, dtype=torch.float32),
231
+ torch.as_tensor(goal_pos, dtype=torch.float32),
232
+ torch.as_tensor(rel_time, dtype=torch.float32),
233
+ torch.as_tensor(projected_tensor, dtype=torch.float32),
234
+ )
235
+ except Exception as e:
236
+ print(f"Exception in {self.dataset_name}", e)
237
+ raise Exception(e)
238
+
239
+ class EvalDataset(BaseDataset):
240
+ def __init__(
241
+ self,
242
+ data_folder: str,
243
+ data_split_folder: str,
244
+ dataset_name: str,
245
+ image_size: Tuple[int, int],
246
+ min_dist_cat: int,
247
+ max_dist_cat: int,
248
+ len_traj_pred: int,
249
+ traj_stride: int,
250
+ context_size: int,
251
+ transform: object,
252
+ traj_names: str,
253
+ normalize: bool = True,
254
+ predefined_index: list = None,
255
+ goals_per_obs: int = 1,
256
+ ):
257
+ super().__init__(data_folder, data_split_folder, dataset_name, image_size, min_dist_cat, max_dist_cat,
258
+ len_traj_pred, traj_stride, context_size, transform, traj_names, normalize, predefined_index, goals_per_obs)
259
+
260
+ def __getitem__(self, i: int) -> Tuple[torch.Tensor]:
261
+ try:
262
+ f_curr, curr_time, _, _ = self.index_to_data[i]
263
+ context_times = list(range(curr_time - self.context_size + 1, curr_time + 1))
264
+ pred_times = list(range(curr_time + 1, curr_time + self.len_traj_pred + 1))
265
+
266
+ context = [(f_curr, t) for t in context_times]
267
+ pred = [(f_curr, t) for t in pred_times]
268
+
269
+ obs_image = torch.stack([self.transform(Image.open(get_data_path(self.data_folder, f, t))) for f, t in context])
270
+ pred_image = torch.stack([self.transform(Image.open(get_data_path(self.data_folder, f, t))) for f, t in pred])
271
+
272
+ curr_traj_data = self._get_trajectory(f_curr)
273
+
274
+ # Compute actions
275
+ actions, _ = self._compute_actions(curr_traj_data, curr_time, np.array([curr_time+1])) # last argument is dummy goal
276
+ actions[:, :2] = normalize_data(actions[:, :2], self.ACTION_STATS)
277
+ delta = get_delta_np(actions)
278
+ # Compute projected images
279
+ f_img, t_img = context[-1] # curr_time img
280
+ rgb_img = cv2.imread(get_data_path(self.data_folder, f_img, t_img))
281
+ rgb_img = cv2.cvtColor(rgb_img, cv2.COLOR_BGR2RGB)
282
+ original_height, original_width = rgb_img.shape[:2]
283
+ resized_width = original_width // 2
284
+ resized_height = original_height // 2
285
+ rgb_img = cv2.resize(rgb_img, (resized_width, resized_height))
286
+ projected_images = self._compute_projected_image(curr_traj_data, curr_time, goal_time, rgb_img)
287
+ projected_tensor_list = [self.transform(Image.fromarray(img)) for img in projected_images]
288
+ projected_tensor = torch.stack(projected_tensor_list, dim=0)
289
+
290
+ return (
291
+ torch.tensor([i], dtype=torch.float32), # for logging purposes
292
+ torch.as_tensor(obs_image, dtype=torch.float32),
293
+ torch.as_tensor(pred_image, dtype=torch.float32),
294
+ torch.as_tensor(delta, dtype=torch.float32),
295
+ torch.as_tensor(projected_tensor, dtype=torch.float32),
296
+ )
297
+ except Exception as e:
298
+ print(f"Exception in {self.dataset_name}", e)
299
+ raise Exception(e)
300
+
301
+ class TrajectoryEvalDataset(BaseDataset):
302
+ def __init__(
303
+ self,
304
+ data_folder: str,
305
+ data_split_folder: str,
306
+ dataset_name: str,
307
+ image_size: Tuple[int, int],
308
+ min_dist_cat: int,
309
+ max_dist_cat: int,
310
+ len_traj_pred: int,
311
+ traj_stride: int,
312
+ context_size: int,
313
+ transform: object,
314
+ traj_names: str,
315
+ normalize: bool = True,
316
+ predefined_index: list = None,
317
+ goals_per_obs: int = 1,
318
+ ):
319
+ super().__init__(data_folder, data_split_folder, dataset_name, image_size, min_dist_cat, max_dist_cat,
320
+ len_traj_pred, traj_stride, context_size, transform, traj_names, normalize, predefined_index, goals_per_obs)
321
+
322
+
323
+ def _sample_goal(self, trajectory_name, curr_time, min_goal_dist, max_goal_dist):
324
+ """
325
+ Sample a goal from the future in the same trajectory.
326
+ Returns: (trajectory_name, goal_time, goal_is_negative)
327
+ """
328
+ goal_offset = np.random.randint(min_goal_dist, max_goal_dist + 1)
329
+ goal_time = curr_time + int(goal_offset)
330
+ return trajectory_name, goal_time, False
331
+
332
+ def __getitem__(self, i: int) -> Tuple[torch.Tensor]:
333
+ try:
334
+ f_curr, curr_time, min_goal_dist, max_goal_dist = self.index_to_data[i]
335
+ f_goal, goal_time, _ = self._sample_goal(f_curr, curr_time, min_goal_dist, max_goal_dist)
336
+
337
+ context_times = list(range(curr_time - self.context_size + 1, curr_time + 1))
338
+ context = [(f_curr, t) for t in context_times]
339
+
340
+ obs_image = torch.stack([self.transform(Image.open(get_data_path(self.data_folder, f, t))) for f, t in context])
341
+ goal_image = self.transform(Image.open(get_data_path(self.data_folder, f_goal, goal_time))).unsqueeze(0)
342
+ curr_traj_data = self._get_trajectory(f_curr)
343
+
344
+ actions, goal_pos = self._compute_actions(curr_traj_data, curr_time, np.array([goal_time]))
345
+
346
+ # Compute projected images
347
+ f_img, t_img = context[-1] # curr_time img
348
+ rgb_img = cv2.imread(get_data_path(self.data_folder, f_img, t_img))
349
+ rgb_img = cv2.cvtColor(rgb_img, cv2.COLOR_BGR2RGB)
350
+ original_height, original_width = rgb_img.shape[:2]
351
+ resized_width = original_width // 2
352
+ resized_height = original_height // 2
353
+ rgb_img = cv2.resize(rgb_img, (resized_width, resized_height))
354
+ projected_images = self._compute_projected_image(curr_traj_data, curr_time, goal_time, rgb_img)
355
+ projected_tensor_list = [self.transform(Image.fromarray(img)) for img in projected_images]
356
+ projected_tensor = torch.stack(projected_tensor_list, dim=0)
357
+
358
+ return (
359
+ torch.tensor([i], dtype=torch.float32), # for logging purposes
360
+ torch.as_tensor(obs_image, dtype=torch.float32),
361
+ torch.as_tensor(goal_image, dtype=torch.float32),
362
+ torch.as_tensor(actions, dtype=torch.float32),
363
+ torch.as_tensor(goal_pos, dtype=torch.float32),
364
+ torch.as_tensor(projected_tensor, dtype=torch.float32),
365
+ )
366
+ except Exception as e:
367
+ print(f"Exception in {self.dataset_name}", e)
368
+ raise Exception(e)