import json
import os
import subprocess
from glob import glob
from pathlib import Path
from typing import TypedDict

import numpy as np
import torch
from PIL import Image as PILImage  # the PIL.Image *module*: provides open() and error types
from PIL.Image import Image  # the image *class* (it has no open() method)
from jaxtyping import UInt8, Int, Float
from torch import Tensor

# def get_example_keys(stage: Literal["test", "train"]) -> list[str]:
#     image_keys = set(
#         example.name
#         for example in tqdm(list((INPUT_DIR / stage).iterdir()), desc="Indexing scenes")
#         if example.is_dir() and not example.name.startswith(".")
#     )
#     # keys = image_keys & metadata_keys
#     keys = image_keys
#     # print(keys)
#     print(f"Found {len(keys)} keys.")
#     return sorted(list(keys))


def get_size(path: Path) -> int:
    """Get file or folder size in bytes.

    Shells out to ``du -b`` and parses the first whitespace-separated field of
    its output. ``-b`` is presumably a GNU coreutils option, so this likely
    requires GNU ``du`` (e.g. Linux) -- TODO confirm on other hosts.
    """
    # "--" ends option parsing so a path beginning with "-" is not taken as a flag.
    output = subprocess.check_output(["du", "-b", "--", path])
    return int(output.split()[0].decode("utf-8"))


def load_raw(path: Path) -> UInt8[Tensor, " length"]:
    """Return the raw bytes of ``path`` as a 1-D uint8 tensor.

    The file is memory-mapped read-only and ``torch.tensor`` copies the data,
    so the returned tensor does not depend on the mapping.
    """
    return torch.tensor(np.memmap(path, dtype="uint8", mode="r"))


def load_images(example_path: Path) -> dict[int, UInt8[Tensor, "..."]]:
    """Load JPG images as raw bytes (do not decode).

    Every entry of ``example_path`` except ``.npz`` files is read. The key is
    the integer after the last ``_`` in the file stem, e.g. ``frame_00012.jpg``
    maps to ``12``.
    """
    return {
        int(path.stem.split("_")[-1]): load_raw(path)
        for path in example_path.iterdir()
        if path.suffix.lower() not in [".npz"]
    }


class Metadata(TypedDict):
    # Third-to-last component of the metadata JSON path (see load_metadata).
    url: str
    # Per-frame integer indices parsed from each frame's file name.
    timestamps: Int[Tensor, " camera"]
    # Per-frame rows as built by load_metadata: [fx/w, fy/h, cx/w, cy/h, 0, 0]
    # followed by the row-major top 3x4 of the OpenCV world-to-camera matrix.
    cameras: Float[Tensor, "camera entry"]


class Example(Metadata):
    key: str
    # Undecoded image bytes (see load_images); presumably one per camera -- confirm.
    images: list[UInt8[Tensor, "..."]]


def load_metadata(example_path: Path) -> Metadata:
    """Load per-frame camera parameters from a transforms-style JSON file.

    Args:
        example_path: Path to the JSON file. Its third-to-last path component is
            returned as ``url`` (``a/b/c.json`` -> ``"a"``).

    Returns:
        A ``Metadata`` dict:

        - ``url``: the path component described above.
        - ``timestamps``: int64 tensor of shape ``(num_frames,)`` holding the
          integer after the last ``_`` of each frame's file name.
        - ``cameras``: float32 tensor of shape ``(num_frames, 18)``. Each row is
          ``[fx/w, fy/h, cx/w, cy/h, 0.0, 0.0]`` followed by the row-major
          top 3x4 of the OpenCV world-to-camera matrix.

        A file with no frames yields shapes ``(0,)`` and ``(0, 18)``.
    """
    # Right-multiplying a Blender camera-to-world matrix by this flips the
    # camera's y and z axes, giving the OpenCV camera convention.
    blender2opencv = np.array(
        [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
    )
    # Path.parts splits on the platform separator; the old str.split("/") did not.
    url = Path(example_path).parts[-3]

    with open(example_path, "r", encoding="utf-8") as f:
        meta_data = json.load(f)

    # Intrinsics are shared by every frame and stored normalized by image size.
    store_h, store_w = float(meta_data["h"]), float(meta_data["w"])
    saved_fx = float(meta_data["fl_x"]) / store_w
    saved_fy = float(meta_data["fl_y"]) / store_h
    saved_cx = float(meta_data["cx"]) / store_w
    saved_cy = float(meta_data["cy"]) / store_h
    # NOTE(review): the two trailing 0.0 entries are placeholders whose meaning is
    # not visible here (possibly distortion terms) -- confirm with consumers.
    intrinsics = [saved_fx, saved_fy, saved_cx, saved_cy, 0.0, 0.0]
    num_entries = len(intrinsics) + 12  # + flattened 3x4 world-to-camera matrix

    timestamps = []
    cameras = []
    for frame in meta_data["frames"]:
        # The timestamp should match the load_images() keys and is used for indexing.
        # NOTE(review): this drops everything after the FIRST "." while load_images()
        # uses Path.stem (drops only the last suffix); names with extra dots may disagree.
        file_name = os.path.basename(frame["file_path"])
        timestamps.append(int(file_name.split(".")[0].split("_")[-1]))
        # transform_matrix is a Blender c2w; we store the OpenCV w2c matrix instead.
        opencv_c2w = np.array(frame["transform_matrix"]) @ blender2opencv
        w2c = np.linalg.inv(opencv_c2w)[:3].flatten()
        cameras.append(intrinsics + w2c.tolist())

    # reshape keeps the (num_frames, 18) layout even when there are no frames,
    # where np.stack([]) used to raise ValueError.
    camera_array = np.array(cameras, dtype=np.float64).reshape(-1, num_entries)
    return {
        "url": url,
        "timestamps": torch.tensor(timestamps, dtype=torch.int64),
        "cameras": torch.tensor(camera_array, dtype=torch.float32),
    }


def partition_train_test_splits(root_dir, n_test=10):
    """Split the immediate sub-directories of ``root_dir`` into train/test lists.

    ``n_test`` is a stride, not a count. After sorting, every ``n_test``-th
    sub-directory (indices 0, n_test, 2*n_test, ...) goes to ``"test"``, and
    all the others go to ``"train"``. Paths are returned as ``glob`` yields them
    for the ``"*/"`` pattern.
    """
    sub_folders = sorted(glob(os.path.join(root_dir, "*/")))
    test_list = sub_folders[::n_test]
    # Set membership avoids rescanning test_list for every folder.
    test_set = set(test_list)
    train_list = [x for x in sub_folders if x not in test_set]
    return {"train": train_list, "test": test_list}


def is_image_shape_matched(image_dir, target_shape):
    """Return True if the first image in ``image_dir`` has size ``target_shape``.

    Only the alphabetically first entry of ``image_dir`` is inspected.
    ``target_shape`` is ``(height, width)``; any two-element sequence is
    accepted. Returns False if the directory is empty or the first entry cannot
    be opened as an image.
    """
    image_paths = sorted(glob(str(Path(image_dir) / "*")))
    if not image_paths:
        return False
    try:
        # Use the PIL.Image module's open(); the Image class has no open(), which
        # previously raised AttributeError and made this always return False.
        with PILImage.open(image_paths[0]) as im:
            w, h = im.size
    except (OSError, PILImage.DecompressionBombError):
        # OSError covers missing or unreadable paths and UnidentifiedImageError.
        return False
    return (h, w) == tuple(target_shape)