|
|
import os |
|
|
import zipfile |
|
|
from typing import Tuple, Optional |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
def _center_crop(tensor_bchw: torch.Tensor, crop_h: int, crop_w: int) -> torch.Tensor: |
|
|
_, _, h, w = tensor_bchw.shape |
|
|
top = max((h - crop_h) // 2, 0) |
|
|
left = max((w - crop_w) // 2, 0) |
|
|
return tensor_bchw[:, :, top : top + crop_h, left : left + crop_w] |
|
|
|
|
|
|
|
|
def _adjust_intrinsics_for_resize_and_crop( |
|
|
intrinsics_3x3: np.ndarray, |
|
|
src_hw: Tuple[int, int], |
|
|
resize_hw: Tuple[int, int], |
|
|
crop_hw: Tuple[int, int], |
|
|
) -> np.ndarray: |
|
|
src_h, src_w = src_hw |
|
|
resize_h, resize_w = resize_hw |
|
|
crop_h, crop_w = crop_hw |
|
|
|
|
|
K = intrinsics_3x3.copy() |
|
|
|
|
|
sx = resize_w / float(src_w) |
|
|
sy = resize_h / float(src_h) |
|
|
K[0, 0] *= sx |
|
|
K[1, 1] *= sy |
|
|
K[0, 2] *= sx |
|
|
K[1, 2] *= sy |
|
|
|
|
|
off_x = max((resize_w - crop_w) // 2, 0) |
|
|
off_y = max((resize_h - crop_h) // 2, 0) |
|
|
K[0, 2] -= off_x |
|
|
K[1, 2] -= off_y |
|
|
|
|
|
return K |
|
|
|
|
|
|
|
|
def _intrinsics_from_fxfycxcy(fxfycxcy: np.ndarray) -> np.ndarray: |
|
|
fx, fy, cx, cy = [float(x) for x in fxfycxcy] |
|
|
K = np.array([[fx, 0.0, cx], [0.0, fy, cy], [0.0, 0.0, 1.0]], dtype=np.float32) |
|
|
return K |
|
|
|
|
|
|
|
|
def _load_pose_matrix_for_frame(pose_npz_path: str, frame_idx: int) -> np.ndarray: |
|
|
data = np.load(pose_npz_path) |
|
|
inds = data["inds"] |
|
|
arr = data["data"] |
|
|
pos = int(np.searchsorted(inds, frame_idx)) |
|
|
if not (0 <= pos < len(inds)) or int(inds[pos]) != int(frame_idx): |
|
|
raise FileNotFoundError(f"Pose for frame {frame_idx} not found in {pose_npz_path}") |
|
|
mat = arr[pos] |
|
|
if mat.shape == (16,): |
|
|
mat = mat.reshape(4, 4) |
|
|
assert mat.shape == (4, 4) |
|
|
return mat.astype(np.float32) |
|
|
|
|
|
|
|
|
def _load_intrinsics_for_frame(intrinsics_npz_path: str, frame_idx: int) -> np.ndarray: |
|
|
data = np.load(intrinsics_npz_path) |
|
|
inds = data["inds"] |
|
|
arr = data["data"] |
|
|
pos = int(np.searchsorted(inds, frame_idx)) |
|
|
if not (0 <= pos < len(inds)) or int(inds[pos]) != int(frame_idx): |
|
|
raise FileNotFoundError( |
|
|
f"Intrinsics for frame {frame_idx} not found in {intrinsics_npz_path}" |
|
|
) |
|
|
item = arr[pos] |
|
|
if item.shape == (3, 3): |
|
|
K = item.astype(np.float32) |
|
|
elif item.shape[-1] == 4: |
|
|
K = _intrinsics_from_fxfycxcy(item) |
|
|
else: |
|
|
raise ValueError( |
|
|
f"Unsupported intrinsics format {item.shape} in {intrinsics_npz_path}" |
|
|
) |
|
|
return K |
|
|
|
|
|
|
|
|
def _read_depth_from_zip(zip_path: str, frame_idx: int) -> np.ndarray:
    """Read one depth map (half-float EXR "Z" channel) from a VIPE depth zip.

    Frames are stored under zero-padded five-digit names, e.g. "00042.exr".

    Returns:
        A float32 (H, W) depth array.

    Raises:
        ImportError: if the OpenEXR package is not installed.
    """
    try:
        import OpenEXR
    except ImportError as e:
        raise ImportError("OpenEXR package is required to read VIPE depth EXR files") from e

    member = f"{frame_idx:05d}.exr"
    with zipfile.ZipFile(zip_path, "r") as archive:
        with archive.open(member, "r") as stream:
            exr = OpenEXR.InputFile(stream)
            window = exr.header()["dataWindow"]
            rows = window.max.y - window.min.y + 1
            cols = window.max.x - window.min.x + 1
            raw = np.frombuffer(exr.channel("Z"), np.float16)
            return raw.astype(np.float32).reshape(rows, cols)
|
|
|
|
|
|
|
|
def _read_mask_from_zip(zip_path: str, frame_idx: int) -> Optional[np.ndarray]: |
|
|
try: |
|
|
import cv2 |
|
|
except ImportError: |
|
|
return None |
|
|
fname = f"{frame_idx:05d}.png" |
|
|
if not os.path.exists(zip_path): |
|
|
return None |
|
|
with zipfile.ZipFile(zip_path, "r") as zf: |
|
|
try: |
|
|
with zf.open(fname, "r") as f: |
|
|
buf = np.frombuffer(f.read(), np.uint8) |
|
|
img = cv2.imdecode(buf, cv2.IMREAD_UNCHANGED) |
|
|
except KeyError: |
|
|
return None |
|
|
if img is None: |
|
|
return None |
|
|
if img.ndim == 3: |
|
|
img = img[..., 0] |
|
|
mask = (img > 0).astype(np.float32) |
|
|
return mask |
|
|
|
|
|
|
|
|
def _read_mp4_frame(mp4_path: str, frame_idx: int) -> np.ndarray:
    """Decode a single frame from an mp4 via decord.

    Returns:
        The decoded frame as a (H, W, C) numpy array.

    Raises:
        ImportError: if decord is not installed.
        IndexError: if frame_idx is outside the video.
    """
    try:
        from decord import VideoReader
    except ImportError as e:
        raise ImportError("decord is required to read VIPE rgb mp4") from e

    reader = VideoReader(mp4_path, num_threads=4)
    n_frames = len(reader)
    if not 0 <= frame_idx < n_frames:
        raise IndexError(
            f"Requested frame_idx {frame_idx} is out of bounds for video length {n_frames}"
        )
    batch = reader.get_batch([frame_idx])
    # decord bridges differ: the native NDArray exposes asnumpy(), while a
    # configured torch/tf bridge exposes numpy().
    if hasattr(batch, "asnumpy"):
        arr = batch.asnumpy()
    else:
        arr = batch.numpy()
    return arr[0]
|
|
|
|
|
|
|
|
def _find_clip_paths(vipe_root_or_mp4: str, video_idx: int = 0) -> Tuple[str, str, str, str, Optional[str]]: |
|
|
if vipe_root_or_mp4.endswith(".mp4"): |
|
|
mp4_path = vipe_root_or_mp4 |
|
|
base = os.path.splitext(os.path.basename(mp4_path))[0] |
|
|
root = os.path.dirname(os.path.dirname(mp4_path)) |
|
|
else: |
|
|
rgb_dir = os.path.join(vipe_root_or_mp4, "rgb") |
|
|
mp4_files = [ |
|
|
os.path.join(rgb_dir, f) |
|
|
for f in sorted(os.listdir(rgb_dir)) |
|
|
if f.endswith(".mp4") |
|
|
] |
|
|
if len(mp4_files) == 0: |
|
|
raise FileNotFoundError(f"No mp4 found under {rgb_dir}") |
|
|
mp4_path = mp4_files[video_idx] |
|
|
base = os.path.splitext(os.path.basename(mp4_path))[0] |
|
|
root = vipe_root_or_mp4 |
|
|
|
|
|
depth_zip = os.path.join(root, "depth", f"{base}.zip") |
|
|
pose_npz = os.path.join(root, "pose", f"{base}.npz") |
|
|
intr_npz = os.path.join(root, "intrinsics", f"{base}.npz") |
|
|
mask_zip = os.path.join(root, "mask", f"{base}.zip") |
|
|
if not os.path.exists(mask_zip): |
|
|
mask_zip = None |
|
|
return mp4_path, depth_zip, pose_npz, intr_npz, mask_zip |
|
|
|
|
|
|
|
|
def load_vipe_data(
    vipe_root_or_mp4: str,
    starting_frame_idx: int,
    resize_hw: Tuple[int, int] = (720, 1280),
    crop_hw: Tuple[int, int] = (704, 1280),
    num_frames: int = 121,
    read_mask: bool = False,
    video_idx: int = 0,
):
    """Load a window of RGB/depth/mask frames and cameras from a VIPE clip.

    Pipeline: decode num_frames RGB frames starting at starting_frame_idx
    (clamped to the video; short videos are padded by repeating the final
    frame), load per-frame poses / intrinsics / depth (and optionally masks),
    resize everything to resize_hw, center-crop to crop_hw, and map RGB to
    [-1, 1].

    Args:
        vipe_root_or_mp4: VIPE root directory or a direct rgb mp4 path
            (resolved by _find_clip_paths).
        starting_frame_idx: first frame index; clamped to the last valid
            frame when it exceeds the video length.
        resize_hw: (H, W) for the resize applied before cropping.
        crop_hw: (H, W) of the final center crop.
        num_frames: number of frames in every returned sequence.
        read_mask: if True and a mask zip exists, load per-frame masks;
            unavailable masks fall back to all-ones.
        video_idx: which mp4 to use when a root directory is given.

    Returns:
        Tuple of:
            frames: (T, 3, ch, cw) float32 tensor in [-1, 1]
            depth:  (T, 1, ch, cw) float32 tensor (bilinearly resized)
            mask:   (T, 1, ch, cw) float32 tensor of {0, 1}
            w2cs:   (T, 4, 4) float32 world-to-camera matrices
            Ks:     (T, 3, 3) float32 intrinsics adjusted for resize + crop

    Raises:
        ImportError: if decord is not installed.
    """
    mp4_path, depth_zip, pose_npz, intr_npz, mask_zip = _find_clip_paths(vipe_root_or_mp4, video_idx=video_idx)

    try:
        from decord import VideoReader
    except ImportError as e:
        raise ImportError("decord is required to read VIPE rgb mp4") from e
    vr = VideoReader(mp4_path, num_threads=4)
    total_len = len(vr)

    # Clamp the start so at least one real frame exists.
    if starting_frame_idx >= total_len:
        starting_frame_idx = max(0, total_len - 1)
    last_available_idx = total_len - 1

    # Take up to num_frames consecutive indices, then pad by repeating the
    # last available frame so every output has exactly num_frames entries.
    frame_indices = list(range(starting_frame_idx, min(starting_frame_idx + num_frames, total_len)))
    while len(frame_indices) < num_frames:
        frame_indices.append(last_available_idx)
    batch = vr.get_batch(frame_indices)
    # decord bridges differ: native NDArray exposes asnumpy(), a configured
    # torch/tf bridge exposes numpy().
    try:
        frames_np = batch.asnumpy()
    except AttributeError:
        frames_np = batch.numpy()

    frames_np = frames_np.astype(np.float32) / 255.0  # uint8 -> [0, 1]
    src_h, src_w = frames_np.shape[1], frames_np.shape[2]

    # Per-frame cameras. Stored poses are treated as camera-to-world and
    # inverted to world-to-camera.
    # NOTE(review): the helpers reopen the pose/intrinsics npz archives once
    # per frame — correct, but O(T) file opens.
    w2cs_list = []
    Ks_list = []
    for fidx in frame_indices:
        c2w_44 = _load_pose_matrix_for_frame(pose_npz, fidx)
        w2c_44 = np.linalg.inv(c2w_44).astype(np.float32)
        w2cs_list.append(w2c_44)

        K_src = _load_intrinsics_for_frame(intr_npz, fidx)
        # Intrinsics refer to the source resolution; adjust for the
        # resize + center crop applied to the images below.
        K_adj = _adjust_intrinsics_for_resize_and_crop(K_src, (src_h, src_w), resize_hw, crop_hw)
        Ks_list.append(K_adj)

    w2cs_np = np.stack(w2cs_list, axis=0)
    Ks_np = np.stack(Ks_list, axis=0)

    # Per-frame depth maps, plus masks when requested and available.
    depth_list = []
    mask_list = []
    for fidx in frame_indices:
        d_hw = _read_depth_from_zip(depth_zip, fidx)
        depth_list.append(d_hw)
        if read_mask and mask_zip:
            m_hw = _read_mask_from_zip(mask_zip, fidx)
        else:
            m_hw = None
        mask_list.append(m_hw)

    # To tensors in (T, C, H, W) layout.
    frames_t = torch.from_numpy(frames_np).permute(0, 3, 1, 2).contiguous()
    depth_seq = torch.from_numpy(np.stack(depth_list, axis=0)).unsqueeze(1).contiguous()
    mask_seq_np = []
    for m in mask_list:
        if m is None:
            # Missing mask -> treat the whole frame as valid.
            # NOTE(review): assumes real masks share the rgb resolution
            # (src_h, src_w); a different size would break np.stack — confirm.
            mask_seq_np.append(np.ones((src_h, src_w), dtype=np.float32))
        else:
            mask_seq_np.append(m.astype(np.float32))
    mask_seq = torch.from_numpy(np.stack(mask_seq_np, axis=0)).unsqueeze(1).contiguous()

    rh, rw = resize_hw
    ch, cw = crop_hw

    # Resize: bilinear for continuous data, nearest to keep masks binary.
    frames_t = F.interpolate(frames_t, size=(rh, rw), mode="bilinear", align_corners=False)
    depth_seq = F.interpolate(depth_seq, size=(rh, rw), mode="bilinear", align_corners=False)
    mask_seq = F.interpolate(mask_seq, size=(rh, rw), mode="nearest")

    frames_t = _center_crop(frames_t, ch, cw)
    depth_seq = _center_crop(depth_seq, ch, cw)
    mask_seq = _center_crop(mask_seq, ch, cw)

    frames_t = frames_t * 2.0 - 1.0  # [0, 1] -> [-1, 1]

    w2cs_T44 = torch.from_numpy(w2cs_np).contiguous()
    Ks_T33 = torch.from_numpy(Ks_np).contiguous()

    return (
        frames_t,
        depth_seq,
        mask_seq,
        w2cs_T44,
        Ks_T33,
    )
|
|
|