# Learn2Splat / optgs / scripts / convert_dl3dv_utils.py
# SteEsp's picture
# Add Docker-based Learn2Splat demo (viser GUI)
# 78d2329 verified
import json
import os
import subprocess
from glob import glob
from pathlib import Path
from typing import TypedDict
import numpy as np
import torch
from PIL.Image import Image
from jaxtyping import UInt8, Int, Float
from torch import Tensor
# def get_example_keys(stage: Literal["test", "train"]) -> list[str]:
# image_keys = set(
# example.name
# for example in tqdm(list((INPUT_DIR / stage).iterdir()), desc="Indexing scenes")
# if example.is_dir() and not example.name.startswith(".")
# )
# # keys = image_keys & metadata_keys
# keys = image_keys
# # print(keys)
# print(f"Found {len(keys)} keys.")
# return sorted(list(keys))
def get_size(path: Path) -> int:
    """Return the size of a file or folder, in bytes.

    Delegates to the external ``du -b`` command and parses the first
    whitespace-separated field of its output as an integer.
    """
    du_output = subprocess.check_output(["du", "-b", path])
    fields = du_output.split()
    size_text = fields[0].decode("utf-8")
    return int(size_text)
def load_raw(path: Path) -> UInt8[Tensor, " length"]:
    """Read a file's bytes, undecoded, into a 1-D ``uint8`` tensor.

    ``np.fromfile`` is used instead of ``np.memmap`` because memory-mapping a
    zero-length file raises ``ValueError``; here an empty file yields an empty
    tensor. The array from ``np.fromfile`` is freshly allocated and writable,
    so ``torch.from_numpy`` can wrap it without a second copy.

    Args:
        path: File to read.

    Returns:
        A ``uint8`` tensor with one element per byte of the file.
    """
    return torch.from_numpy(np.fromfile(path, dtype=np.uint8))
def load_images(example_path: Path) -> dict[int, UInt8[Tensor, "..."]]:
    """Load the image files in ``example_path`` as raw bytes (not decoded).

    Each file is keyed by the integer after the last ``_`` in its stem
    (e.g. ``frame_00012.png`` -> ``12``).

    Skipped entries:
      * ``.npz`` files;
      * sub-directories, which cannot be read as raw bytes;
      * hidden files (name starts with ``.``). A hidden sidecar such as
        ``._frame_00012.png`` would otherwise parse to the same key and
        silently overwrite the real image.

    No extension filter beyond ``.npz`` is applied.

    Args:
        example_path: Directory containing the images of one example.

    Returns:
        Mapping from integer frame id to a 1-D ``uint8`` tensor of file bytes.
    """
    return {
        int(path.stem.split("_")[-1]): load_raw(path)
        for path in example_path.iterdir()
        if path.is_file()
        and not path.name.startswith(".")
        and path.suffix.lower() != ".npz"
    }
class Metadata(TypedDict):
    """Per-scene camera metadata, as returned by ``load_metadata``."""

    # Scene identifier: load_metadata takes the third-to-last component of the
    # metadata file path.
    url: str
    # One int64 id per camera/frame, parsed from the numeric suffix of each
    # frame's file name.
    timestamps: Int[Tensor, " camera"]
    # One float32 row per camera: 4 normalized intrinsics (fx/w, fy/h, cx/w, cy/h),
    # two zeros, then the flattened top 3 rows of the OpenCV world-to-camera matrix.
    cameras: Float[Tensor, "camera entry"]
class Example(Metadata):
    """A ``Metadata`` record extended with an example key and its images."""

    # NOTE(review): presumably the scene/example identifier; it is not set in
    # the visible code, so confirm against the code that builds Example dicts.
    key: str
    # Presumably raw image bytes as produced by load_raw, one tensor per image.
    # NOTE(review): load_images returns a dict keyed by frame id, not a list;
    # confirm how (and in which order) callers convert it.
    images: list[UInt8[Tensor, "..."]]
def load_metadata(example_path: Path) -> Metadata:
    """Parse a transforms-style JSON metadata file into per-frame camera data.

    The file must provide image size ``h``/``w``, shared intrinsics
    ``fl_x``/``fl_y``/``cx``/``cy``, and a ``frames`` list. Each frame has a
    ``file_path`` and a 4x4 Blender-convention camera-to-world
    ``transform_matrix``.

    Args:
        example_path: Path (``Path`` or ``str``) to the JSON file. Its
            third-to-last path component is used as the scene ``url``.

    Returns:
        A ``Metadata`` dict with the following entries.

        * ``url``: the third-to-last component of ``example_path``.
        * ``timestamps``: int64 tensor of shape ``(num_frames,)``. Each value is
          the integer after the last ``_`` in the frame's file name, with the
          extension removed (e.g. ``images/frame_00012.png`` -> ``12``).
        * ``cameras``: float32 tensor of shape ``(num_frames, 18)``. Each row is
          ``[fx/w, fy/h, cx/w, cy/h, 0, 0]`` followed by the top three rows of
          the OpenCV world-to-camera matrix, flattened row-major.

        An empty ``frames`` list yields tensors of shape ``(0,)`` and ``(0, 18)``.

    Raises:
        IndexError: if ``example_path`` has fewer than three path components.
        KeyError: if a required JSON field is missing.
    """
    # Right-multiplying a camera-to-world matrix by this negates the camera's
    # Y and Z axes (Blender/OpenGL camera convention -> OpenCV convention).
    blender2opencv = np.diag([1, -1, -1, 1])

    # Path.parts is separator-agnostic, unlike splitting the string on "/".
    scene_url = Path(example_path).parts[-3]
    with open(example_path, "r", encoding="utf-8") as f:
        meta_data = json.load(f)

    # Intrinsics are shared by all frames and stored normalized by image size.
    store_w, store_h = float(meta_data["w"]), float(meta_data["h"])
    intrinsics = [
        float(meta_data["fl_x"]) / store_w,
        float(meta_data["fl_y"]) / store_h,
        float(meta_data["cx"]) / store_w,
        float(meta_data["cy"]) / store_h,
        # Two zero placeholders; NOTE(review): meaning defined by the consumer of
        # `cameras` (possibly unused distortion slots) - confirm.
        0.0,
        0.0,
    ]

    timestamps = []
    cameras = []
    for frame in meta_data["frames"]:
        # Parsed so the id is intended to match the keys of load_images
        # (numeric suffix after the last "_").
        file_stem = os.path.basename(frame["file_path"]).split(".")[0]
        timestamps.append(int(file_stem.split("_")[-1]))

        # transform_matrix is a Blender c2w; store the OpenCV w2c instead.
        opencv_c2w = np.array(frame["transform_matrix"]) @ blender2opencv
        opencv_w2c = np.linalg.inv(opencv_c2w)
        cameras.append(intrinsics + opencv_w2c[:3].flatten().tolist())

    # reshape keeps a 2-D (0, 18) layout when there are no frames.
    camera_rows = np.asarray(cameras, dtype=np.float64).reshape(-1, 18)
    return {
        "url": scene_url,
        "timestamps": torch.tensor(timestamps, dtype=torch.int64),
        "cameras": torch.tensor(camera_rows, dtype=torch.float32),
    }
def partition_train_test_splits(root_dir, n_test=10):
    """Split the immediate sub-directories of ``root_dir`` into train/test lists.

    Sub-directories are sorted by path. Every ``n_test``-th one, starting with
    the first, goes to the test split; all others go to the train split.

    Args:
        root_dir: Directory whose immediate sub-directories are the scenes.
        n_test: Stride used to pick test scenes (one in every ``n_test``).
            Despite the name, it is not the number of test scenes.

    Returns:
        ``{"train": [...], "test": [...]}``. The entries are path strings as
        returned by ``glob`` for the pattern ``<root_dir>/*/``.

    Raises:
        ValueError: if ``n_test`` is 0 (slice step cannot be zero).
    """
    sub_folders = sorted(glob(os.path.join(root_dir, "*/")))
    test_list = sub_folders[::n_test]
    # Set membership keeps this linear; testing against the list was O(n) per
    # folder, i.e. O(n^2) overall.
    test_set = set(test_list)
    train_list = [folder for folder in sub_folders if folder not in test_set]
    return {"train": train_list, "test": test_list}
def is_image_shape_matched(image_dir, target_shape):
    """Check whether the first image in ``image_dir`` has size ``target_shape``.

    Only the lexicographically first path matching ``image_dir / "*"`` is
    inspected.

    Args:
        image_dir: Directory (``Path`` or ``str``) containing the images.
        target_shape: Expected ``(height, width)``, as any 2-element sequence.

    Returns:
        ``True`` if the first image's ``(height, width)`` equals ``target_shape``.
        ``False`` otherwise, including when the directory is empty or missing,
        or when the first entry cannot be opened as an image.
    """
    image_paths = sorted(glob(str(Path(image_dir) / "*")))
    if not image_paths:
        return False

    # The module-level ``Image`` name is the ``PIL.Image.Image`` class, which has
    # no ``open`` method. Calling it always raised, and the old bare ``except``
    # turned that into an unconditional False. Use the PIL.Image module here.
    from PIL import Image as pil_image

    try:
        # ``with`` closes the file handle; ``size`` is read from the header.
        with pil_image.open(image_paths[0]) as image:
            width, height = image.size
    except (OSError, ValueError, pil_image.DecompressionBombError):
        # Missing/unreadable file, directory, or unrecognized image format.
        return False
    return (height, width) == tuple(target_shape)