de99 committed on
Commit
f8f9be6
·
verified ·
1 Parent(s): 53f2c76

Upload datasets_v2.py

Browse files
Files changed (1) hide show
  1. datasets_v2.py +425 -0
datasets_v2.py ADDED
@@ -0,0 +1,425 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # NoMaD, GNM, ViNT: https://github.com/robodhruv/visualnav-transformer
9
+ # --------------------------------------------------------
10
+ import cv2
11
+ import numpy as np
12
+ import torch
13
+ import os
14
+ from PIL import Image
15
+ from typing import Tuple
16
+ import yaml
17
+ import pickle
18
+ import tqdm
19
+ from torch.utils.data import Dataset
20
+ from misc import angle_difference, get_data_path, get_delta_np, normalize_data, to_local_coords
21
+ from project_functions import reproject_depth_to_other_pose_2seq, project_to_2d_image_2seq, resize_image_half
22
+
23
class BaseDataset(Dataset):
    """Common machinery for the trajectory datasets defined in this file.

    The constructor reads the list of trajectory names and the dataset
    configuration, then builds (or loads) the sample index. The remaining
    methods load per-trajectory data and turn it into actions, goal poses and
    depth-reprojected images; subclasses implement ``__getitem__`` on top.
    """

    def __init__(
        self,
        data_folder: str,
        data_split_folder: str,
        dataset_name: str,
        image_size: Tuple[int, int],
        min_dist_cat: int,
        max_dist_cat: int,
        len_traj_pred: int,
        traj_stride: int,
        context_size: int,
        transform: object,
        traj_names: str,
        normalize: bool = True,
        predefined_index: list = None,
        goals_per_obs: int = 1,
    ):
        """Store the configuration and prepare the sample index.

        Args:
            data_folder: Root folder; trajectory data is read from
                ``<data_folder>/<trajectory name>/traj_data.pkl``.
            data_split_folder: Folder containing the ``traj_names`` file. A
                freshly built index is also pickled into this folder.
            dataset_name: Key of this dataset in ``config/data_config.yaml``.
            image_size: Stored on the instance; not used further in this class.
            min_dist_cat: Smallest goal offset (in time steps).
            max_dist_cat: Largest goal offset (in time steps).
            len_traj_pred: Number of future steps used for actions.
            traj_stride: Step between consecutive sample times in the index.
            context_size: Number of context frames ending at the current time.
            transform: Callable applied by subclasses to every PIL image.
            traj_names: Name of the file (inside ``data_split_folder``) listing
                one trajectory name per line.
            normalize: If true, positions are divided by the dataset's
                ``metric_waypoint_spacing``.
            predefined_index: Passed to ``open()`` in ``_load_index``, i.e. it
                is used as a path to a pickled index even though it is
                annotated as ``list``. Falsy means "build the index".
            goals_per_obs: Number of goals sampled per observation.
        """
        self.data_folder = data_folder
        self.data_split_folder = data_split_folder
        self.dataset_name = dataset_name
        self.goals_per_obs = goals_per_obs

        traj_names_file = os.path.join(data_split_folder, traj_names)
        self.traj_names = self._read_traj_names(traj_names_file)

        self.image_size = image_size
        self.distance_categories = list(range(min_dist_cat, max_dist_cat + 1))
        self.min_dist_cat = self.distance_categories[0]
        self.max_dist_cat = self.distance_categories[-1]
        self.len_traj_pred = len_traj_pred
        self.traj_stride = traj_stride

        self.context_size = context_size
        self.normalize = normalize

        # The config path is relative to the current working directory.
        with open("config/data_config.yaml", "r") as f:
            all_data_config = yaml.safe_load(f)

        self.data_config = all_data_config[self.dataset_name]
        self.transform = transform
        self._load_index(predefined_index)
        # Each statistic gets a leading axis of size 1 added.
        self.ACTION_STATS = {}
        for key in all_data_config['action_stats']:
            self.ACTION_STATS[key] = np.expand_dims(all_data_config['action_stats'][key], axis=0)

    @staticmethod
    def _read_traj_names(traj_names_file: str) -> list:
        """Return the trajectory names listed one per line in ``traj_names_file``.

        Blank and whitespace-only lines are skipped wherever they occur, so
        stray empty lines never turn into bogus trajectory names.
        """
        with open(traj_names_file, "r") as f:
            lines = f.read().split("\n")
        return [line.strip() for line in lines if line.strip()]

    def _load_index(self, predefined_index) -> None:
        """Populate ``self.index_to_data`` from a pickle or by building it.

        When ``predefined_index`` is truthy it is opened as a pickle file and
        its content becomes ``self.index_to_data`` as-is (``self.goals_index``
        is not set on this path). Otherwise the index is built with
        ``_build_index`` and the pair ``(index_to_data, goals_index)`` is
        pickled into ``data_split_folder``.

        Each built sample is a tuple
        ``(traj_name, curr_time, min_goal_distance, max_goal_distance)``.
        """
        if predefined_index:
            print(f"****** Using a predefined evaluation index... {predefined_index}******")
            # NOTE(review): pickle.load executes arbitrary code from the file;
            # only use index files from a trusted source.
            with open(predefined_index, "rb") as f:
                self.index_to_data = pickle.load(f)
            return

        print("****** Evaluating from NON PREDEFINED index... ******")
        index_to_data_path = os.path.join(
            self.data_split_folder,
            f"dataset_dist_{self.min_dist_cat}_to_{self.max_dist_cat}_n{self.context_size}_len_traj_pred_{self.len_traj_pred}.pkl",
        )

        self.index_to_data, self.goals_index = self._build_index()
        with open(index_to_data_path, "wb") as f:
            pickle.dump((self.index_to_data, self.goals_index), f)
        print(f"Saved index to {index_to_data_path}, total samples: {len(self.index_to_data)}")

    def _build_index(self, use_tqdm: bool = False):
        """Build the sample and goal indices over all trajectories.

        Returns:
            A pair ``(samples_index, goals_index)`` where ``samples_index``
            holds ``(traj_name, curr_time, min_goal_distance,
            max_goal_distance)`` tuples and ``goals_index`` holds one
            ``(traj_name, goal_time)`` tuple per time step of every trajectory.
        """
        samples_index = []
        goals_index = []

        for traj_name in tqdm.tqdm(self.traj_names, disable=not use_tqdm, dynamic_ncols=True):
            traj_data = self._get_trajectory(traj_name)
            traj_len = len(traj_data["position"])
            for goal_time in range(0, traj_len):
                goals_index.append((traj_name, goal_time))

            # The first sample needs a full context window behind it and the
            # last one needs len_traj_pred future steps ahead of it.
            begin_time = self.context_size - 1
            end_time = traj_len - self.len_traj_pred
            for curr_time in range(begin_time, end_time, self.traj_stride):
                max_goal_distance = min(self.max_dist_cat, traj_len - curr_time - 1)
                min_goal_distance = max(self.min_dist_cat, -curr_time)
                samples_index.append((traj_name, curr_time, min_goal_distance, max_goal_distance))

        return samples_index, goals_index

    def _get_trajectory(self, trajectory_name):
        """Load ``traj_data.pkl`` of one trajectory with every value cast to float.

        Every value of the unpickled mapping must support ``.astype``.
        """
        # NOTE(review): pickle.load executes arbitrary code from the file;
        # only use trajectory data from a trusted source.
        with open(os.path.join(self.data_folder, trajectory_name, "traj_data.pkl"), "rb") as f:
            traj_data = pickle.load(f)
        for k, v in traj_data.items():
            traj_data[k] = v.astype('float')
        return traj_data

    def __len__(self) -> int:
        """Return the number of entries in the sample index."""
        return len(self.index_to_data)

    def _compute_projected_image(self, traj_data, curr_time, goal_time, rgb_img):
        """Reproject the current frame into the camera pose(s) at ``goal_time``.

        Reads ``pose``, ``depth`` and ``K`` from ``traj_data`` and delegates to
        ``generate_augmented_image``.
        """
        pose_src = traj_data["pose"][curr_time]
        pose_dst = traj_data["pose"][goal_time]
        depth_map = traj_data["depth"][curr_time]
        K = traj_data["K"]

        projected_images = self.generate_augmented_image(K=K, depth_map=depth_map, rgb_img=rgb_img, pose_src=pose_src, pose_dst=pose_dst)
        return projected_images

    def generate_augmented_image(self, K, depth_map, rgb_img, pose_src, pose_dst) -> np.ndarray:
        """Render the image seen from another camera pose using depth and pose.

        If the RGB image's height/width differ from the depth map's shape it is
        passed through ``resize_image_half`` first. The result comes from the
        project helpers ``reproject_depth_to_other_pose_2seq`` and
        ``project_to_2d_image_2seq``.
        """
        image_size = depth_map.shape  # presumably (H, W) -- confirm against the data
        if rgb_img.shape[:2] != image_size:
            rgb_img = resize_image_half(rgb_img)
        points_3d, colors = reproject_depth_to_other_pose_2seq(K, depth_map, rgb_img, pose_src, pose_dst)
        images = project_to_2d_image_2seq(K, points_3d, colors, image_size)
        return images

    def _compute_actions(self, traj_data, curr_time, goal_time, rgb_img):
        """Compute local-frame actions, goal pose and reprojected goal images.

        Args:
            traj_data: Mapping as returned by ``_get_trajectory``; ``yaw``,
                ``point``, ``pose``, ``depth`` and ``K`` are read.
            curr_time: Index of the current time step.
            goal_time: Array of goal time-step indices (used for fancy
                indexing, so the goal arrays keep a leading axis).
            rgb_img: RGB image of the current frame, forwarded to the
                reprojection.

        Returns:
            ``(actions, goal_pos, projected_images)``. ``actions`` holds the
            ``len_traj_pred`` future waypoints (position and yaw difference)
            relative to the current pose; ``goal_pos`` holds the same
            quantities for each goal.

        Raises:
            ValueError: If fewer than ``len_traj_pred + 1`` yaw samples are
                available starting at ``curr_time``.
        """
        start_index = curr_time
        end_index = curr_time + self.len_traj_pred + 1
        yaw = traj_data["yaw"][start_index:end_index]
        positions = traj_data["point"][start_index:end_index]
        goal_pos = traj_data["point"][goal_time]
        goal_yaw = traj_data["yaw"][goal_time]

        if len(yaw.shape) == 2:
            yaw = yaw.squeeze(1)

        if yaw.shape != (self.len_traj_pred + 1,):
            raise ValueError(
                f"Expected {self.len_traj_pred + 1} yaw samples starting at time {curr_time}, "
                f"got yaw shape {yaw.shape}"
            )

        waypoints_pos = to_local_coords(positions, positions[0], yaw[0])
        waypoints_yaw = angle_difference(yaw[0], yaw)
        actions = np.concatenate([waypoints_pos, waypoints_yaw.reshape(-1, 1)], axis=-1)
        # Drop the first row, which is the current pose relative to itself.
        actions = actions[1:]

        goal_pos = to_local_coords(goal_pos, positions[0], yaw[0])
        goal_yaw = angle_difference(yaw[0], goal_yaw)

        if self.normalize:
            actions[:, :3] /= self.data_config["metric_waypoint_spacing"]
            goal_pos[:, :3] /= self.data_config["metric_waypoint_spacing"]

        goal_pos = np.concatenate([goal_pos, goal_yaw.reshape(-1, 1)], axis=-1)

        projected_images = self._compute_projected_image(traj_data, curr_time, goal_time, rgb_img)
        return actions, goal_pos, projected_images
203
+
204
class TrainingDataset(BaseDataset):
    """Training samples: context frames plus randomly sampled goal frames."""

    def __init__(
        self,
        data_folder: str,
        data_split_folder: str,
        dataset_name: str,
        image_size: Tuple[int, int],
        min_dist_cat: int,
        max_dist_cat: int,
        len_traj_pred: int,
        traj_stride: int,
        context_size: int,
        transform: object,
        traj_names: str = 'traj_names.txt',
        normalize: bool = True,
        predefined_index: list = None,
        goals_per_obs: int = 1,
    ):
        """Forward every argument unchanged to ``BaseDataset.__init__``."""
        super().__init__(data_folder, data_split_folder, dataset_name, image_size, min_dist_cat, max_dist_cat,
                         len_traj_pred, traj_stride, context_size, transform, traj_names, normalize, predefined_index, goals_per_obs)

    def __getitem__(self, i: int) -> Tuple[torch.Tensor, ...]:
        """Return one training sample.

        Returns:
            ``(obs_image, goal_pos, rel_time, projected_tensor)`` as float32
            tensors. ``obs_image`` stacks the transformed context frames
            followed by the ``goals_per_obs`` sampled goal frames;
            ``goal_pos`` is the normalized goal pose; ``rel_time`` is the goal
            offset divided by 128; ``projected_tensor`` stacks the transformed
            reprojections of the current frame into each goal pose.

        Raises:
            Any exception raised while building the sample is logged with the
            dataset name and re-raised unchanged.
        """
        try:
            f_curr, curr_time, min_goal_dist, max_goal_dist = self.index_to_data[i]
            goal_offset = np.random.randint(min_goal_dist, max_goal_dist + 1, size=(self.goals_per_obs))
            goal_time = (curr_time + goal_offset).astype('int')
            rel_time = (goal_offset).astype('float')/(128.)  # TODO: refactor, currently a fixed const

            context_times = list(range(curr_time - self.context_size + 1, curr_time + 1))
            true_context = [(f_curr, t) for t in context_times]
            # Context frames first, then the sampled goal frames.
            context = true_context + [(f_curr, t) for t in goal_time]

            obs_image = torch.stack([self.transform(Image.open(get_data_path(self.data_folder, f, t))) for f, t in context])

            curr_traj_data = self._get_trajectory(f_curr)

            # The last context frame is the current frame; it is the source
            # image for the depth-based reprojection.
            f_img, t_img = true_context[-1]
            rgb_img = cv2.imread(get_data_path(self.data_folder, f_img, t_img))
            rgb_img = cv2.cvtColor(rgb_img, cv2.COLOR_BGR2RGB)

            # Only the goal pose and the reprojections are needed for training.
            _, goal_pos, projected_images = self._compute_actions(curr_traj_data, curr_time, goal_time, rgb_img)
            goal_pos[:, :3] = normalize_data(goal_pos[:, :3], self.ACTION_STATS)

            projected_tensor_list = [self.transform(Image.fromarray(img)) for img in projected_images]
            projected_tensor = torch.stack(projected_tensor_list, dim=0)

            return (
                torch.as_tensor(obs_image, dtype=torch.float32),
                torch.as_tensor(goal_pos, dtype=torch.float32),
                torch.as_tensor(rel_time, dtype=torch.float32),
                torch.as_tensor(projected_tensor, dtype=torch.float32),
            )
        except Exception as e:
            print(f"Exception in {self.dataset_name}", e)
            # Bare raise keeps the original exception type and traceback.
            raise
285
+
286
+
287
class EvalDataset(BaseDataset):
    """Evaluation samples: context frames plus the next ``len_traj_pred`` frames."""

    def __init__(
        self,
        data_folder: str,
        data_split_folder: str,
        dataset_name: str,
        image_size: Tuple[int, int],
        min_dist_cat: int,
        max_dist_cat: int,
        len_traj_pred: int,
        traj_stride: int,
        context_size: int,
        transform: object,
        traj_names: str,
        normalize: bool = True,
        predefined_index: list = None,
        goals_per_obs: int = 1,
    ):
        """Forward every argument unchanged to ``BaseDataset.__init__``."""
        super().__init__(data_folder, data_split_folder, dataset_name, image_size, min_dist_cat, max_dist_cat,
                         len_traj_pred, traj_stride, context_size, transform, traj_names, normalize, predefined_index, goals_per_obs)

    def __getitem__(self, i: int) -> Tuple[torch.Tensor, ...]:
        """Return one evaluation sample.

        Returns:
            ``(idx, obs_image, pred_image, delta, projected_tensor)`` as
            float32 tensors. ``idx`` wraps the sample index ``i`` (for
            logging); ``obs_image`` and ``pred_image`` stack the transformed
            context and future frames; ``delta`` is ``get_delta_np`` of the
            normalized actions; ``projected_tensor`` stacks the transformed
            reprojections of the current frame into each future pose.

        Raises:
            Any exception raised while building the sample is logged with the
            dataset name and re-raised unchanged.
        """
        try:
            f_curr, curr_time, _, _ = self.index_to_data[i]
            context_times = list(range(curr_time - self.context_size + 1, curr_time + 1))
            pred_times = list(range(curr_time + 1, curr_time + self.len_traj_pred + 1))

            context = [(f_curr, t) for t in context_times]
            pred = [(f_curr, t) for t in pred_times]

            obs_image = torch.stack([self.transform(Image.open(get_data_path(self.data_folder, f, t))) for f, t in context])
            pred_image = torch.stack([self.transform(Image.open(get_data_path(self.data_folder, f, t))) for f, t in pred])

            curr_traj_data = self._get_trajectory(f_curr)

            # The last context frame is the current frame; it is the source
            # image for the depth-based reprojection.
            f_img, t_img = context[-1]
            rgb_img = cv2.imread(get_data_path(self.data_folder, f_img, t_img))
            rgb_img = cv2.cvtColor(rgb_img, cv2.COLOR_BGR2RGB)

            # The future time steps are passed as goals so that one reprojected
            # image is produced per predicted frame; the returned goal pose is
            # not used here.
            actions, _, projected_images = self._compute_actions(curr_traj_data, curr_time, np.array(pred_times), rgb_img)
            actions[:, :3] = normalize_data(actions[:, :3], self.ACTION_STATS)
            delta = get_delta_np(actions)

            projected_tensor_list = [self.transform(Image.fromarray(img)) for img in projected_images]
            projected_tensor = torch.stack(projected_tensor_list, dim=0)

            return (
                torch.tensor([i], dtype=torch.float32),  # for logging purposes
                torch.as_tensor(obs_image, dtype=torch.float32),
                torch.as_tensor(pred_image, dtype=torch.float32),
                torch.as_tensor(delta, dtype=torch.float32),
                torch.as_tensor(projected_tensor, dtype=torch.float32),
            )
        except Exception as e:
            print(f"Exception in {self.dataset_name}", e)
            # Bare raise keeps the original exception type and traceback.
            raise
362
+
363
class TrajectoryEvalDataset(BaseDataset):
    """Evaluation samples with a single goal sampled from the same trajectory."""

    def __init__(
        self,
        data_folder: str,
        data_split_folder: str,
        dataset_name: str,
        image_size: Tuple[int, int],
        min_dist_cat: int,
        max_dist_cat: int,
        len_traj_pred: int,
        traj_stride: int,
        context_size: int,
        transform: object,
        traj_names: str,
        normalize: bool = True,
        predefined_index: list = None,
        goals_per_obs: int = 1,
    ):
        """Forward every argument unchanged to ``BaseDataset.__init__``."""
        super().__init__(data_folder, data_split_folder, dataset_name, image_size, min_dist_cat, max_dist_cat,
                         len_traj_pred, traj_stride, context_size, transform, traj_names, normalize, predefined_index, goals_per_obs)

    def _sample_goal(self, trajectory_name, curr_time, min_goal_dist, max_goal_dist):
        """Sample a goal time in the same trajectory.

        The offset is drawn uniformly from ``[min_goal_dist, max_goal_dist]``
        (both inclusive) and added to ``curr_time``.

        Returns:
            ``(trajectory_name, goal_time, goal_is_negative)``; the last
            element is always ``False``.
        """
        goal_offset = np.random.randint(min_goal_dist, max_goal_dist + 1)
        goal_time = curr_time + int(goal_offset)
        return trajectory_name, goal_time, False

    def __getitem__(self, i: int) -> Tuple[torch.Tensor, ...]:
        """Return one trajectory-evaluation sample.

        Returns:
            ``(idx, obs_image, goal_image, actions, goal_pos,
            projected_tensor)`` as float32 tensors. ``idx`` wraps the sample
            index ``i`` (for logging); ``obs_image`` stacks the transformed
            context frames; ``goal_image`` is the transformed goal frame with
            a leading axis of size 1; ``actions`` and ``goal_pos`` come from
            ``_compute_actions``; ``projected_tensor`` stacks the transformed
            reprojections of the current frame into the goal pose.

        Raises:
            Any exception raised while building the sample is logged with the
            dataset name and re-raised unchanged.
        """
        try:
            f_curr, curr_time, min_goal_dist, max_goal_dist = self.index_to_data[i]
            f_goal, goal_time, _ = self._sample_goal(f_curr, curr_time, min_goal_dist, max_goal_dist)

            context_times = list(range(curr_time - self.context_size + 1, curr_time + 1))
            context = [(f_curr, t) for t in context_times]

            obs_image = torch.stack([self.transform(Image.open(get_data_path(self.data_folder, f, t))) for f, t in context])
            goal_image = self.transform(Image.open(get_data_path(self.data_folder, f_goal, goal_time))).unsqueeze(0)
            curr_traj_data = self._get_trajectory(f_curr)

            # The last context frame is the current frame; it is the source
            # image for the depth-based reprojection.
            f_img, t_img = context[-1]
            rgb_img = cv2.imread(get_data_path(self.data_folder, f_img, t_img))
            rgb_img = cv2.cvtColor(rgb_img, cv2.COLOR_BGR2RGB)

            # The goal time is wrapped in an array because _compute_actions
            # indexes the goal arrays with [:, :3].
            actions, goal_pos, projected_images = self._compute_actions(curr_traj_data, curr_time, np.array([goal_time]), rgb_img)

            projected_tensor_list = [self.transform(Image.fromarray(img)) for img in projected_images]
            projected_tensor = torch.stack(projected_tensor_list, dim=0)

            return (
                torch.tensor([i], dtype=torch.float32),  # for logging purposes
                torch.as_tensor(obs_image, dtype=torch.float32),
                torch.as_tensor(goal_image, dtype=torch.float32),
                torch.as_tensor(actions, dtype=torch.float32),
                torch.as_tensor(goal_pos, dtype=torch.float32),
                torch.as_tensor(projected_tensor, dtype=torch.float32),
            )
        except Exception as e:
            print(f"Exception in {self.dataset_name}", e)
            # Bare raise keeps the original exception type and traceback.
            raise