lhmd commited on
Commit
3472e2c
·
verified ·
1 Parent(s): 0299d25

Upload convert.py

Browse files
Files changed (1) hide show
  1. convert.py +289 -0
convert.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Standard library imports, grouped and alphabetized (PEP 8).
import argparse
import json
import os
import subprocess
import sys
from glob import glob
from pathlib import Path
from typing import Literal, TypedDict

# Third-party imports.
import numpy as np
import torch
from jaxtyping import Float, Int, UInt8
from PIL import Image
from torch import Tensor
from tqdm import tqdm

# Command-line interface; parsed once at import time and read as a module
# global by the processing functions below.
parser = argparse.ArgumentParser()
parser.add_argument("--input_base_dir", type=str, help="base directory containing 1K, 2K, ..., 11K subdirectories")
parser.add_argument("--output_base_dir", type=str, help="base output directory for processed datasets")
parser.add_argument(
    "--img_subdir",
    type=str,
    default="images_8",
    help="image directory name",
    choices=["images_4", "images_8"],
)
parser.add_argument("--n_test", type=int, default=10, help="test skip")
parser.add_argument("--which_stage", type=str, default=None, help="dataset directory")
parser.add_argument("--detect_overlap", action="store_true")
parser.add_argument("--start_k", type=int, default=1, help="starting K value (default: 1)")
parser.add_argument("--end_k", type=int, default=11, help="ending K value (default: 11)")

args = parser.parse_args()


# Target 200 MB per chunk.
TARGET_BYTES_PER_CHUNK = int(2e8)
43
+
44
+
45
def get_size(path: Path) -> int:
    """Return the size of *path* in bytes.

    A regular file reports its own size; a directory reports the total size
    of all regular files beneath it.

    Replaces the original ``subprocess.check_output(["du", "-b", path])``,
    which depended on GNU coreutils (``du -b`` is unavailable on macOS/BSD)
    and forked an external process for every scene.
    """
    path = Path(path)
    if path.is_file():
        return path.stat().st_size
    # Sum every regular file below the directory, recursively.
    return sum(p.stat().st_size for p in path.rglob("*") if p.is_file())
48
+
49
+
50
def load_raw(path: Path) -> UInt8[Tensor, " length"]:
    """Read the file at *path* into a flat uint8 tensor of its raw bytes."""
    raw_bytes = np.fromfile(path, dtype=np.uint8)
    return torch.from_numpy(raw_bytes)
52
+
53
+
54
def load_images(example_path: Path) -> dict[int, UInt8[Tensor, "..."]]:
    """Load JPG images as raw bytes (do not decode).

    The dictionary key is the integer after the last underscore of each
    file's stem; ``.npz`` files in the folder are skipped.
    """
    images = {}
    for image_path in example_path.iterdir():
        # Skip auxiliary .npz files; everything else is treated as an image.
        if image_path.suffix.lower() == ".npz":
            continue
        frame_index = int(image_path.stem.split("_")[-1])
        images[frame_index] = load_raw(image_path)
    return images
62
+
63
+
64
class Metadata(TypedDict):
    """Per-scene camera metadata parsed from a transforms.json file."""

    # Scene identifier: the third-from-last component of the scene path.
    url: str
    # Integer frame indices; align with the keys returned by load_images.
    timestamps: Int[Tensor, " camera"]
    # One row per camera: [fx, fy, cx, cy, 0, 0] (intrinsics normalized by
    # image size) followed by the flattened top 3 rows of the OpenCV
    # world-to-camera matrix.
    cameras: Float[Tensor, "camera entry"]
68
+
69
+
70
class Example(Metadata):
    """A Metadata record extended with the serialized example payload."""

    # Chunk key: "dl3dv_" + the scene folder name.
    key: str
    # Raw (undecoded) JPG bytes, one tensor per frame, in timestamp order.
    images: list[UInt8[Tensor, "..."]]
73
+
74
+
75
def load_metadata(example_path: Path) -> Metadata:
    """Parse transforms.json into normalized intrinsics and OpenCV w2c poses."""
    # Blender camera convention -> OpenCV: flip the y and z axes.
    blender_to_opencv = np.array(
        [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
    )
    # The scene name sits three path components above transforms.json.
    url = str(example_path).split("/")[-3]

    with open(example_path, "r") as f:
        meta = json.load(f)

    height, width = meta["h"], meta["w"]
    # Normalize the intrinsics by image size so they are resolution independent.
    fx_norm = float(meta["fl_x"]) / float(width)
    fy_norm = float(meta["fl_y"]) / float(height)
    cx_norm = float(meta["cx"]) / float(width)
    cy_norm = float(meta["cy"]) / float(height)

    timestamps = []
    cameras = []
    opencv_c2ws = []  # will be used to calculate camera distance

    for frame in meta["frames"]:
        # Frame index is the integer after the last underscore of the file stem.
        stem = os.path.basename(frame["file_path"]).split(".")[0]
        timestamps.append(int(stem.split("_")[-1]))

        entry = [fx_norm, fy_norm, cx_norm, cy_norm, 0.0, 0.0]
        # transform_matrix is a Blender c2w; store the OpenCV w2c matrix here.
        c2w = np.array(frame["transform_matrix"]) @ blender_to_opencv
        opencv_c2ws.append(c2w)
        entry.extend(np.linalg.inv(c2w)[:3].flatten().tolist())
        cameras.append(np.array(entry))

    # Timestamps match the image keys produced by load_images (used for indexing).
    return {
        "url": url,
        "timestamps": torch.tensor(timestamps, dtype=torch.int64),
        "cameras": torch.tensor(np.stack(cameras), dtype=torch.float32),
    }
115
+
116
+
117
def partition_train_test_splits(root_dir, n_test=10):
    """Split scene folders under *root_dir* into train/test lists.

    Every n_test-th folder (in sorted order) goes to the test split; all
    remaining folders form the train split.
    """
    scene_dirs = sorted(glob(os.path.join(root_dir, "*/")))
    test_split = scene_dirs[::n_test]
    train_split = [d for d in scene_dirs if d not in test_split]
    return {"train": train_split, "test": test_split}
123
+
124
+
125
def is_image_shape_matched(image_dir, target_shape):
    """Return True iff the first image in *image_dir* has shape (h, w) == target_shape.

    Only the first file is inspected (assumes all frames in a scene share one
    resolution — TODO confirm). Returns False for an empty directory or an
    unreadable first file.

    Fixes vs. original: the PIL image handle is closed via a context manager
    (it was leaked), and the bare ``except:`` is narrowed to ``Exception`` so
    KeyboardInterrupt/SystemExit are no longer swallowed.
    """
    image_paths = sorted(glob(str(image_dir / "*")))
    if not image_paths:
        return False
    try:
        with Image.open(image_paths[0]) as im:
            w, h = im.size
    except Exception:
        # Unreadable or corrupt first image: treat the scene as invalid.
        return False
    # PIL reports (w, h); the target is given as (h, w).
    return (h, w) == target_shape
140
+
141
+
142
def legal_check_for_all_scenes(root_dir, target_shape, img_subdir=None):
    """Return the scene folders under ``root_dir/*/*`` that pass validation.

    A scene is valid when its image subdirectory contains images of
    *target_shape* (h, w) and a ``transforms.json`` pose file exists.

    Args:
        root_dir: base directory, scanned two levels deep for scene folders.
        target_shape: required (h, w) of the stored images.
        img_subdir: name of the image subdirectory. Defaults to the
            ``--img_subdir`` CLI argument. (The original hardcoded
            ``'images_8'`` here, silently ignoring ``--img_subdir``.)
    """
    if img_subdir is None:
        img_subdir = args.img_subdir

    valid_folders = []
    sub_folders = sorted(glob(os.path.join(root_dir, "*/*")))
    for sub_folder in tqdm(sub_folders, desc="checking scenes..."):
        img_dir = os.path.join(sub_folder, img_subdir)
        if not is_image_shape_matched(Path(img_dir), target_shape):
            print(f"image shape does not match for {sub_folder}")
            continue
        pose_file = os.path.join(sub_folder, "transforms.json")
        if not os.path.isfile(pose_file):
            print(f"cannot find pose file for {sub_folder}")
            continue

        valid_folders.append(sub_folder)

    return valid_folders
159
+
160
+
161
def process_single_directory(input_dir: Path, output_dir: Path):
    """Convert one K directory of DL3DV scenes into ~200 MB .torch chunks.

    Validates every scene under *input_dir*, skips scenes listed in the
    benchmark index, packs each remaining scene (raw image bytes + camera
    metadata) into a rolling chunk, and flushes the chunk to
    ``output_dir/<stage>/NNNNNN.torch`` whenever it exceeds
    TARGET_BYTES_PER_CHUNK.
    """
    print(f"\n=== Processing {input_dir.name} ===")

    INPUT_DIR = input_dir
    OUTPUT_DIR = output_dir

    # Expected stored frame resolution (h, w) for each downscale level.
    if "images_8" in args.img_subdir:
        target_shape = (270, 480)  # (h, w)
    elif "images_4" in args.img_subdir:
        target_shape = (540, 960)
    else:
        raise ValueError

    print("checking all scenes...")
    valid_scenes = legal_check_for_all_scenes(INPUT_DIR, target_shape)
    print("valid scenes:", len(valid_scenes))

    # test scenes
    # NOTE(review): hardcoded absolute benchmark-index path — breaks outside
    # this specific cluster; consider promoting it to a CLI argument.
    test_scenes = "/scratch/azureml/cr/j/e8e7ca980a5641daa86426c3fa644c10/exe/wd/dl3dv_benchmark/index.json"
    with open(test_scenes, "r") as f:
        overlap_scenes = json.load(f)

    assert len(overlap_scenes) == 140, "test scenes should contain 140 scenes"

    # NOTE(review): only the "train" split is ever produced; args.which_stage
    # and args.n_test are not consulted here.
    for stage in ["train"]:

        error_logs = []
        image_dirs = valid_scenes

        # Rolling chunk state, shared with save_chunk() via nonlocal.
        chunk_size = 0
        chunk_index = 0
        chunk: list[Example] = []

        def save_chunk():
            # Flush accumulated examples to OUTPUT_DIR/<stage>/NNNNNN.torch
            # and reset the rolling chunk state.
            nonlocal chunk_size, chunk_index, chunk

            chunk_key = f"{chunk_index:0>6}"
            dir = OUTPUT_DIR / stage
            dir.mkdir(exist_ok=True, parents=True)
            torch.save(chunk, dir / f"{chunk_key}.torch")

            # Reset the chunk.
            chunk_size = 0
            chunk_index += 1
            chunk = []

        for image_dir in tqdm(image_dirs, desc=f"Processing {stage}"):
            # Scene key is the last path component of the scene folder.
            key = os.path.basename(image_dir.strip("/"))
            # skip test scenes
            if key in overlap_scenes:
                print(f"scene {key} in benchmark, skip.")
                continue

            # NOTE(review): hardcoded to images_8 even when --img_subdir is
            # images_4 — confirm this matches legal_check_for_all_scenes.
            image_dir = Path(image_dir) / "images_8"  # 270x480
            # image_dir = Path(image_dir) / 'images_4' # 540x960

            # On-disk size of this scene's images, used for chunk budgeting.
            num_bytes = get_size(image_dir)

            # Read images and metadata.
            try:
                images = load_images(image_dir)
            except:
                # NOTE(review): bare except silently hides the real error.
                print("image loading error")
                continue
            meta_path = image_dir.parent / "transforms.json"
            if not meta_path.is_file():
                error_msg = f"---------> [ERROR] no meta file in {key}, skip."
                print(error_msg)
                error_logs.append(error_msg)
                continue
            example = load_metadata(meta_path)

            # Merge the images into the example, ordered by timestamp; a
            # KeyError here means a frame referenced by the poses is missing.
            try:
                example["images"] = [
                    images[timestamp.item()] for timestamp in example["timestamps"]
                ]
            except:
                error_msg = f"---------> [ERROR] Some images missing in {key}, skip."
                print(error_msg)
                error_logs.append(error_msg)
                continue

            # Add the key to the example.
            example["key"] = "dl3dv_" + key

            chunk.append(example)
            chunk_size += num_bytes

            if chunk_size >= TARGET_BYTES_PER_CHUNK:
                save_chunk()

        # Flush the final, partially filled chunk.
        if chunk_size > 0:
            save_chunk()
256
+
257
+
258
if __name__ == "__main__":
    in_root = Path(args.input_base_dir)
    out_root = Path(args.output_base_dir)
    banner = "=" * 50

    # Process every K directory in the inclusive [start_k, end_k] range.
    total_dirs = args.end_k - args.start_k + 1
    processed_dirs = 0

    for k in range(args.start_k, args.end_k + 1):
        k_dir = f"{k}K"
        input_dir = in_root / k_dir
        output_dir = out_root / k_dir

        # Missing K directories are skipped with a warning rather than failing.
        if not input_dir.exists():
            print(f"Warning: Input directory {input_dir} does not exist, skipping...")
            continue

        print(f"\n{banner}")
        print(f"Processing directory {k_dir} ({processed_dirs + 1}/{total_dirs})")
        print(f"Input: {input_dir}")
        print(f"Output: {output_dir}")
        print(f"{banner}")

        # Process this directory
        process_single_directory(input_dir, output_dir)

        processed_dirs += 1
        print(f"\nCompleted {k_dir} ({processed_dirs}/{total_dirs})")

    print(f"\n{banner}")
    print(f"All processing complete! Processed {processed_dirs}/{total_dirs} directories.")
    print(f"{banner}")