Spaces:
Sleeping
Sleeping
| import json | |
| import os | |
| import subprocess | |
| from glob import glob | |
| from pathlib import Path | |
| from typing import TypedDict | |
| import numpy as np | |
| import torch | |
| from PIL.Image import Image | |
| from jaxtyping import UInt8, Int, Float | |
| from torch import Tensor | |
| # def get_example_keys(stage: Literal["test", "train"]) -> list[str]: | |
| # image_keys = set( | |
| # example.name | |
| # for example in tqdm(list((INPUT_DIR / stage).iterdir()), desc="Indexing scenes") | |
| # if example.is_dir() and not example.name.startswith(".") | |
| # ) | |
| # # keys = image_keys & metadata_keys | |
| # keys = image_keys | |
| # # print(keys) | |
| # print(f"Found {len(keys)} keys.") | |
| # return sorted(list(keys)) | |
def get_size(path: Path) -> int:
    """Return the size of a file or folder in bytes.

    Delegates to ``du -b`` and parses the first whitespace-separated field of
    its output. ``-b`` is presumably the GNU ``du`` option, so a GNU userland
    is likely required — confirm before running elsewhere.

    Raises:
        subprocess.CalledProcessError: if ``du`` exits with a non-zero status.
    """
    du_output = subprocess.check_output(["du", "-b", path])
    size_field = du_output.split()[0]
    return int(size_field.decode("utf-8"))
def load_raw(path: Path) -> UInt8[Tensor, " length"]:
    """Return the undecoded bytes of ``path`` as a 1-D ``uint8`` tensor.

    The file is memory-mapped read-only. ``torch.tensor`` then copies the
    data, so the returned tensor owns its storage and does not alias the
    mapping.
    """
    byte_view = np.memmap(path, dtype="uint8", mode="r")
    return torch.tensor(byte_view)
def load_images(example_path: Path) -> dict[int, UInt8[Tensor, "..."]]:
    """Load the image files in ``example_path`` as raw bytes (not decoded).

    Each regular, non-hidden file whose suffix is not ``.npz`` is read with
    ``load_raw``. It is keyed by the integer after the last ``_`` in its stem,
    e.g. ``frame_00012.jpg`` -> ``12``.

    Hidden files (such as ``.DS_Store``) and sub-directories are skipped.
    Their names carry no frame index, and ``load_raw`` cannot read a
    directory.

    Returns:
        Mapping from frame index to the file's bytes, inserted in sorted
        file-name order. If two files map to the same index, the later one wins.

    Raises:
        ValueError: if a remaining file's stem does not end in ``_<int>``.
    """
    excluded_suffixes = {".npz"}
    images = {}
    # Sort so the result does not depend on directory-listing order.
    for path in sorted(example_path.iterdir()):
        if not path.is_file() or path.name.startswith("."):
            continue
        if path.suffix.lower() in excluded_suffixes:
            continue
        images[int(path.stem.split("_")[-1])] = load_raw(path)
    return images
class Metadata(TypedDict):
    """Per-scene camera metadata, as returned by ``load_metadata``."""

    # Scene identifier; load_metadata takes it from the third-to-last path component.
    url: str
    # One integer frame index per frame, parsed from each frame's file name.
    timestamps: Int[Tensor, " camera"]
    # One row of camera parameters per frame. load_metadata fills 18 entries:
    # fx, fy, cx, cy normalized by image size, two zeros, then the first
    # three rows of the OpenCV world-to-camera matrix, flattened.
    cameras: Float[Tensor, "camera entry"]
class Example(Metadata):
    """A ``Metadata`` record extended with its key and image payloads."""

    # Identifier of this example.
    key: str
    # Image tensors; presumably the undecoded bytes produced by load_images,
    # one per timestamp — NOTE(review): confirm against the code that builds Example.
    images: list[UInt8[Tensor, "..."]]
def load_metadata(example_path: Path) -> Metadata:
    """Parse a scene's camera JSON into per-frame timestamps and camera rows.

    ``example_path`` points at a JSON file with these entries:
      * image size ``h`` and ``w``;
      * intrinsics ``fl_x``, ``fl_y``, ``cx`` and ``cy``;
      * a ``frames`` list. Each frame has a ``file_path`` and a 4x4
        ``transform_matrix``, which is a Blender camera-to-world matrix.

    Each camera row has 18 entries:
      * fx / w, fy / h, cx / w, cy / h (intrinsics normalized by image size),
      * two zeros,
      * the first three rows of the OpenCV world-to-camera matrix, flattened.

    Returns:
        ``url``: the third-to-last component of ``example_path``.
        ``timestamps``: int64 tensor with one entry per frame. Each entry is
        the integer after the last ``_`` in the frame's file name, before its
        first ``.``.
        ``cameras``: float32 tensor of shape (num_frames, 18). A JSON with
        no frames yields empty tensors instead of an error.
    """
    # diag(1, -1, -1, 1): negates the camera y and z axes, converting
    # Blender camera axes to OpenCV ones.
    blender2opencv = np.array(
        [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
    )
    # Path.parts is separator-agnostic, unlike splitting str(path) on "/".
    url = Path(example_path).parts[-3]

    with open(example_path, "r", encoding="utf-8") as f:
        meta_data = json.load(f)

    store_h, store_w = float(meta_data["h"]), float(meta_data["w"])
    # The intrinsics are shared by every frame, so normalize them once.
    intrinsics = [
        float(meta_data["fl_x"]) / store_w,
        float(meta_data["fl_y"]) / store_h,
        float(meta_data["cx"]) / store_w,
        float(meta_data["cy"]) / store_h,
        0.0,
        0.0,
    ]

    timestamps = []
    cameras = []
    for frame in meta_data["frames"]:
        # e.g. ".../frame_00012.png" -> 12. The timestamps must match the
        # image keys from load_images, because they are used for indexing.
        file_name = os.path.basename(frame["file_path"])
        timestamps.append(int(file_name.split(".")[0].split("_")[-1]))
        # transform_matrix is a Blender c2w; we store the OpenCV w2c instead.
        opencv_c2w = np.array(frame["transform_matrix"]) @ blender2opencv
        opencv_w2c = np.linalg.inv(opencv_c2w)
        cameras.append(np.array(intrinsics + opencv_w2c[:3].flatten().tolist()))

    if cameras:
        cameras_tensor = torch.tensor(np.stack(cameras), dtype=torch.float32)
    else:
        # np.stack rejects an empty list. Keep the (num_frames, 18) layout:
        # 6 intrinsic entries + 12 world-to-camera entries.
        cameras_tensor = torch.zeros((0, len(intrinsics) + 12), dtype=torch.float32)
    return {
        "url": url,
        "timestamps": torch.tensor(timestamps, dtype=torch.int64),
        "cameras": cameras_tensor,
    }
def partition_train_test_splits(root_dir, n_test=10):
    """Split the immediate sub-directories of ``root_dir`` into train and test.

    Sub-directories are sorted by path. Every ``n_test``-th one, starting
    with the first, goes to the test split; the rest go to train. Paths are
    returned as produced by ``glob``, i.e. with a trailing separator.

    Args:
        root_dir: Directory whose sub-directories are partitioned. It is
            joined into a glob pattern, so glob metacharacters in it act as
            wildcards.
        n_test: Stride used to pick test folders; must be >= 1.

    Returns:
        ``{"train": [...], "test": [...]}``, with both lists in sorted order.

    Raises:
        ValueError: if ``n_test`` < 1.
    """
    if n_test < 1:
        # A step of 0 is invalid. A negative step would silently walk the
        # list backwards and put the wrong folders into the test split.
        raise ValueError(f"n_test must be >= 1, got {n_test}")
    sub_folders = sorted(glob(os.path.join(root_dir, "*/")))
    test_list = sub_folders[::n_test]
    # O(1) membership instead of scanning test_list once per folder.
    test_set = set(test_list)
    train_list = [x for x in sub_folders if x not in test_set]
    return {"train": train_list, "test": test_list}
def is_image_shape_matched(image_dir, target_shape):
    """Check whether the first image in ``image_dir`` has size ``target_shape``.

    Only the first entry of the sorted ``image_dir/*`` listing is inspected.

    Args:
        image_dir: Directory holding the images, as a ``str`` or ``Path``.
        target_shape: Expected ``(height, width)``; any 2-element sequence.

    Returns:
        True if the first entry's (height, width) equals ``target_shape``.
        False if it differs, if there are no entries, or if the first entry
        cannot be opened as an image.
    """
    image_paths = sorted(glob(str(Path(image_dir) / "*")))
    if not image_paths:
        return False
    # The module-level ``Image`` is the PIL.Image.Image *class*, which has no
    # ``open``. Calling it raised AttributeError, which the old bare except
    # swallowed, so the check never succeeded. The opener is the
    # module-level function PIL.Image.open.
    from PIL.Image import open as open_image

    try:
        # ``with`` closes the file handle Pillow keeps open for lazy loading.
        with open_image(image_paths[0]) as im:
            w, h = im.size
    except Exception:
        # Best-effort predicate: an unreadable or non-image entry means
        # "no match". Unlike a bare except, this lets KeyboardInterrupt
        # and SystemExit propagate.
        return False
    return (h, w) == tuple(target_shape)