# Author: Muhammad Taqi Raza
# adding lyra files
# (commit af758d1)
import os
import zipfile
from typing import Tuple, Optional
import numpy as np
import torch
import torch.nn.functional as F
def _center_crop(tensor_bchw: torch.Tensor, crop_h: int, crop_w: int) -> torch.Tensor:
_, _, h, w = tensor_bchw.shape
top = max((h - crop_h) // 2, 0)
left = max((w - crop_w) // 2, 0)
return tensor_bchw[:, :, top : top + crop_h, left : left + crop_w]
def _adjust_intrinsics_for_resize_and_crop(
intrinsics_3x3: np.ndarray,
src_hw: Tuple[int, int],
resize_hw: Tuple[int, int],
crop_hw: Tuple[int, int],
) -> np.ndarray:
src_h, src_w = src_hw
resize_h, resize_w = resize_hw
crop_h, crop_w = crop_hw
K = intrinsics_3x3.copy()
sx = resize_w / float(src_w)
sy = resize_h / float(src_h)
K[0, 0] *= sx
K[1, 1] *= sy
K[0, 2] *= sx
K[1, 2] *= sy
off_x = max((resize_w - crop_w) // 2, 0)
off_y = max((resize_h - crop_h) // 2, 0)
K[0, 2] -= off_x
K[1, 2] -= off_y
return K
def _intrinsics_from_fxfycxcy(fxfycxcy: np.ndarray) -> np.ndarray:
fx, fy, cx, cy = [float(x) for x in fxfycxcy]
K = np.array([[fx, 0.0, cx], [0.0, fy, cy], [0.0, 0.0, 1.0]], dtype=np.float32)
return K
def _load_pose_matrix_for_frame(pose_npz_path: str, frame_idx: int) -> np.ndarray:
data = np.load(pose_npz_path)
inds = data["inds"]
arr = data["data"]
pos = int(np.searchsorted(inds, frame_idx))
if not (0 <= pos < len(inds)) or int(inds[pos]) != int(frame_idx):
raise FileNotFoundError(f"Pose for frame {frame_idx} not found in {pose_npz_path}")
mat = arr[pos]
if mat.shape == (16,):
mat = mat.reshape(4, 4)
assert mat.shape == (4, 4)
return mat.astype(np.float32)
def _load_intrinsics_for_frame(intrinsics_npz_path: str, frame_idx: int) -> np.ndarray:
data = np.load(intrinsics_npz_path)
inds = data["inds"]
arr = data["data"]
pos = int(np.searchsorted(inds, frame_idx))
if not (0 <= pos < len(inds)) or int(inds[pos]) != int(frame_idx):
raise FileNotFoundError(
f"Intrinsics for frame {frame_idx} not found in {intrinsics_npz_path}"
)
item = arr[pos]
if item.shape == (3, 3):
K = item.astype(np.float32)
elif item.shape[-1] == 4:
K = _intrinsics_from_fxfycxcy(item)
else:
raise ValueError(
f"Unsupported intrinsics format {item.shape} in {intrinsics_npz_path}"
)
return K
def _read_depth_from_zip(zip_path: str, frame_idx: int) -> np.ndarray:
    """Read one depth map for frame ``frame_idx`` from a zip of EXR files.

    Returns:
        (H, W) float32 depth array decoded from the EXR "Z" channel.

    Raises:
        ImportError: if the OpenEXR bindings are not installed.
        KeyError: if the frame's EXR member is missing from the zip.
    """
    try:
        import OpenEXR  # type: ignore
    except ImportError as e:
        raise ImportError("OpenEXR package is required to read VIPE depth EXR files") from e
    # Members are named with zero-padded 5-digit frame indices, e.g. "00012.exr".
    fname = f"{frame_idx:05d}.exr"
    with zipfile.ZipFile(zip_path, "r") as zf:
        with zf.open(fname, "r") as f:
            # NOTE(review): the zip member file object is handed straight to
            # OpenEXR.InputFile — assumes the installed binding accepts a
            # seekable file-like object, not only a path; confirm.
            exr = OpenEXR.InputFile(f)
            # Image extent comes from the EXR data window (inclusive bounds).
            dw = exr.header()["dataWindow"]
            height = dw.max.y - dw.min.y + 1
            width = dw.max.x - dw.min.x + 1
            # NOTE(review): assumes channel "Z" is stored as HALF (float16)
            # — TODO confirm against the files VIPE actually writes.
            depth = np.frombuffer(exr.channel("Z"), np.float16).astype(np.float32)
            depth = depth.reshape(height, width)
    return depth
def _read_mask_from_zip(zip_path: str, frame_idx: int) -> Optional[np.ndarray]:
try:
import cv2 # type: ignore
except ImportError:
return None
fname = f"{frame_idx:05d}.png"
if not os.path.exists(zip_path):
return None
with zipfile.ZipFile(zip_path, "r") as zf:
try:
with zf.open(fname, "r") as f:
buf = np.frombuffer(f.read(), np.uint8)
img = cv2.imdecode(buf, cv2.IMREAD_UNCHANGED)
except KeyError:
return None
if img is None:
return None
if img.ndim == 3:
img = img[..., 0]
mask = (img > 0).astype(np.float32)
return mask
def _read_mp4_frame(mp4_path: str, frame_idx: int) -> np.ndarray:
    """Decode a single (H, W, C) RGB frame from an mp4 using decord.

    Raises:
        ImportError: if decord is not installed.
        IndexError: if ``frame_idx`` is outside the video's length.
    """
    try:
        from decord import VideoReader  # type: ignore
    except ImportError as e:
        raise ImportError("decord is required to read VIPE rgb mp4") from e
    reader = VideoReader(mp4_path, num_threads=4)
    n_frames = len(reader)
    if not (0 <= frame_idx < n_frames):
        raise IndexError(
            f"Requested frame_idx {frame_idx} is out of bounds for video length {n_frames}"
        )
    batch = reader.get_batch([frame_idx])
    # decord returns either its own NDArray (asnumpy) or a torch-bridged
    # tensor (numpy), depending on the configured bridge.
    try:
        frame_np = batch.asnumpy()
    except AttributeError:
        frame_np = batch.numpy()
    return frame_np[0]  # (H, W, C)
def _find_clip_paths(vipe_root_or_mp4: str, video_idx: int = 0) -> Tuple[str, str, str, str, Optional[str]]:
if vipe_root_or_mp4.endswith(".mp4"):
mp4_path = vipe_root_or_mp4
base = os.path.splitext(os.path.basename(mp4_path))[0]
root = os.path.dirname(os.path.dirname(mp4_path))
else:
rgb_dir = os.path.join(vipe_root_or_mp4, "rgb")
mp4_files = [
os.path.join(rgb_dir, f)
for f in sorted(os.listdir(rgb_dir))
if f.endswith(".mp4")
]
if len(mp4_files) == 0:
raise FileNotFoundError(f"No mp4 found under {rgb_dir}")
mp4_path = mp4_files[video_idx]
base = os.path.splitext(os.path.basename(mp4_path))[0]
root = vipe_root_or_mp4
depth_zip = os.path.join(root, "depth", f"{base}.zip")
pose_npz = os.path.join(root, "pose", f"{base}.npz")
intr_npz = os.path.join(root, "intrinsics", f"{base}.npz")
mask_zip = os.path.join(root, "mask", f"{base}.zip")
if not os.path.exists(mask_zip):
mask_zip = None
return mp4_path, depth_zip, pose_npz, intr_npz, mask_zip
def load_vipe_data(
    vipe_root_or_mp4: str,
    starting_frame_idx: int,
    resize_hw: Tuple[int, int] = (720, 1280),
    crop_hw: Tuple[int, int] = (704, 1280),
    num_frames: int = 121,
    read_mask: bool = False,
    video_idx: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Load one VIPE clip as a set of aligned per-frame tensors.

    Reads ``num_frames`` RGB frames starting at ``starting_frame_idx`` (the
    last available frame is repeated as padding when the video is shorter),
    plus per-frame depth, optional masks, poses, and intrinsics. RGB, depth
    and mask are resized to ``resize_hw`` then center-cropped to ``crop_hw``;
    intrinsics are adjusted to match the same resize+crop.

    Args:
        vipe_root_or_mp4: VIPE root directory or direct path to an rgb mp4.
        starting_frame_idx: first frame index; clamped to the last frame when
            past the end of the video.
        resize_hw: (H, W) target size for the initial resize.
        crop_hw: (H, W) of the subsequent center crop.
        num_frames: output sequence length T.
        read_mask: when True and a mask zip exists, decode per-frame masks;
            frames with no mask fall back to all-ones.
        video_idx: which mp4 to pick when a root directory is given.

    Returns:
        Tuple of (frames, depth, mask, w2cs, Ks):
        frames (T, C, ch, cw) in [-1, 1]; depth (T, 1, ch, cw);
        mask (T, 1, ch, cw); w2cs (T, 4, 4); Ks (T, 3, 3).

    Raises:
        ImportError: if decord is not installed.
    """
    mp4_path, depth_zip, pose_npz, intr_npz, mask_zip = _find_clip_paths(vipe_root_or_mp4, video_idx=video_idx)
    # Read the sequence of RGB frames
    try:
        from decord import VideoReader  # type: ignore
    except ImportError as e:
        raise ImportError("decord is required to read VIPE rgb mp4") from e
    vr = VideoReader(mp4_path, num_threads=4)
    total_len = len(vr)
    # If starting index is beyond the video, clamp to last frame
    if starting_frame_idx >= total_len:
        starting_frame_idx = max(0, total_len - 1)
    last_available_idx = total_len - 1
    # Build index list and repeat the last available frame if not enough
    frame_indices = list(range(starting_frame_idx, min(starting_frame_idx + num_frames, total_len)))
    while len(frame_indices) < num_frames:
        frame_indices.append(last_available_idx)
    batch = vr.get_batch(frame_indices)
    # decord may hand back its own NDArray (asnumpy) or a torch-bridged
    # tensor (numpy), depending on the configured bridge.
    try:
        frames_np = batch.asnumpy()
    except AttributeError:
        frames_np = batch.numpy()
    # frames_np: (T, H, W, C) in [0,255]
    frames_np = frames_np.astype(np.float32) / 255.0
    src_h, src_w = frames_np.shape[1], frames_np.shape[2]
    # Load per-frame pose (c2w) and intrinsics, convert and adjust K
    w2cs_list = []
    Ks_list = []
    for fidx in frame_indices:
        # NOTE(review): stored poses are treated as camera-to-world (hence
        # the inversion to world-to-camera below) — confirm with the writer.
        c2w_44 = _load_pose_matrix_for_frame(pose_npz, fidx)
        w2c_44 = np.linalg.inv(c2w_44).astype(np.float32)
        w2cs_list.append(w2c_44)
        K_src = _load_intrinsics_for_frame(intr_npz, fidx)
        # Adjust intrinsics so they stay valid for the resized+cropped frames.
        K_adj = _adjust_intrinsics_for_resize_and_crop(K_src, (src_h, src_w), resize_hw, crop_hw)
        Ks_list.append(K_adj)
    w2cs_np = np.stack(w2cs_list, axis=0)  # (T, 4, 4)
    Ks_np = np.stack(Ks_list, axis=0)  # (T, 3, 3)
    # Depth/mask for the whole sequence
    depth_list = []
    mask_list = []
    for fidx in frame_indices:
        d_hw = _read_depth_from_zip(depth_zip, fidx)
        depth_list.append(d_hw)
        if read_mask and mask_zip:
            m_hw = _read_mask_from_zip(mask_zip, fidx)
        else:
            m_hw = None
        mask_list.append(m_hw)
    # Convert to torch and apply resize/crop
    frames_t = torch.from_numpy(frames_np).permute(0, 3, 1, 2).contiguous()  # (T, C, H, W)
    depth_seq = torch.from_numpy(np.stack(depth_list, axis=0)).unsqueeze(1).contiguous()  # (T,1,H,W)
    # Missing masks become all-ones (i.e. "everything valid") at source size.
    mask_seq_np = []
    for m in mask_list:
        if m is None:
            mask_seq_np.append(np.ones((src_h, src_w), dtype=np.float32))
        else:
            mask_seq_np.append(m.astype(np.float32))
    mask_seq = torch.from_numpy(np.stack(mask_seq_np, axis=0)).unsqueeze(1).contiguous()  # (T,1,H,W)
    rh, rw = resize_hw
    ch, cw = crop_hw
    # Bilinear for continuous data, nearest for the binary mask.
    frames_t = F.interpolate(frames_t, size=(rh, rw), mode="bilinear", align_corners=False)
    depth_seq = F.interpolate(depth_seq, size=(rh, rw), mode="bilinear", align_corners=False)
    mask_seq = F.interpolate(mask_seq, size=(rh, rw), mode="nearest")
    frames_t = _center_crop(frames_t, ch, cw)  # (T, C, ch, cw)
    depth_seq = _center_crop(depth_seq, ch, cw)  # (T, 1, ch, cw)
    mask_seq = _center_crop(mask_seq, ch, cw)  # (T, 1, ch, cw)
    frames_t = frames_t * 2.0 - 1.0  # to [-1, 1]
    # Full sequences (T, ...)
    w2cs_T44 = torch.from_numpy(w2cs_np).contiguous()
    Ks_T33 = torch.from_numpy(Ks_np).contiguous()
    return (
        frames_t,
        depth_seq,
        mask_seq,
        w2cs_T44,
        Ks_T33,
    )