File size: 9,281 Bytes
af758d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
import os
import zipfile
from typing import Tuple, Optional

import numpy as np
import torch
import torch.nn.functional as F


def _center_crop(tensor_bchw: torch.Tensor, crop_h: int, crop_w: int) -> torch.Tensor:
    """Return the centred ``crop_h`` x ``crop_w`` window of a 4-D tensor.

    The last two dimensions are treated as height and width. When the
    requested window is larger than the input along an axis, the start
    offset is clamped to zero, so that axis is returned at its full size.
    """
    _batch, _channels, height, width = tensor_bchw.shape
    row_start = max(0, (height - crop_h) // 2)
    col_start = max(0, (width - crop_w) // 2)
    rows = slice(row_start, row_start + crop_h)
    cols = slice(col_start, col_start + crop_w)
    return tensor_bchw[:, :, rows, cols]


def _adjust_intrinsics_for_resize_and_crop(
    intrinsics_3x3: np.ndarray,
    src_hw: Tuple[int, int],
    resize_hw: Tuple[int, int],
    crop_hw: Tuple[int, int],
) -> np.ndarray:
    """Return a copy of the intrinsics updated for a resize followed by a centre crop.

    The focal lengths and principal point are scaled by the per-axis resize
    ratio (``resize / src``); the principal point is then shifted by the
    centre-crop offset, which is clamped to zero when the crop is larger than
    the resized image. The input matrix is not modified.

    Args:
        intrinsics_3x3: matrix indexed as ``[0, 0]``/``[1, 1]`` for the focal
            lengths and ``[0, 2]``/``[1, 2]`` for the principal point.
        src_hw: ``(height, width)`` of the original image.
        resize_hw: ``(height, width)`` after resizing.
        crop_hw: ``(height, width)`` of the centre crop.
    """
    (src_h, src_w), (resize_h, resize_w), (crop_h, crop_w) = src_hw, resize_hw, crop_hw

    scale_x = resize_w / float(src_w)
    scale_y = resize_h / float(src_h)
    shift_x = max((resize_w - crop_w) // 2, 0)
    shift_y = max((resize_h - crop_h) // 2, 0)

    adjusted = intrinsics_3x3.copy()

    # Horizontal axis: scale fx and cx, then move cx into the cropped frame.
    adjusted[0, 0] *= scale_x
    adjusted[0, 2] *= scale_x
    adjusted[0, 2] -= shift_x

    # Vertical axis: scale fy and cy, then move cy into the cropped frame.
    adjusted[1, 1] *= scale_y
    adjusted[1, 2] *= scale_y
    adjusted[1, 2] -= shift_y

    return adjusted


def _intrinsics_from_fxfycxcy(fxfycxcy: np.ndarray) -> np.ndarray:
    """Build a 3x3 float32 intrinsics matrix from ``(fx, fy, cx, cy)``."""
    fx, fy, cx, cy = map(float, fxfycxcy)
    matrix = np.eye(3, dtype=np.float32)
    matrix[0, 0] = fx
    matrix[1, 1] = fy
    matrix[0, 2] = cx
    matrix[1, 2] = cy
    return matrix


def _load_pose_matrix_for_frame(pose_npz_path: str, frame_idx: int) -> np.ndarray:
    """Load the 4x4 pose matrix stored for ``frame_idx`` in a pose ``.npz``.

    The archive must contain an ``inds`` array of frame indices and a
    ``data`` array whose ``i``-th entry is the pose for ``inds[i]``, stored
    either as a 4x4 matrix or as a flat vector of 16 values. ``inds`` is
    searched with a binary search, so it has to be sorted in ascending order.

    Returns:
        The pose as a ``(4, 4)`` float32 array.

    Raises:
        FileNotFoundError: if ``frame_idx`` is not listed in ``inds``.
        ValueError: if the stored entry is neither ``(4, 4)`` nor ``(16,)``.
    """
    # np.load returns an NpzFile that keeps the archive open; the context
    # manager closes it instead of leaking one file handle per call.
    with np.load(pose_npz_path) as data:
        inds = data["inds"]
        arr = data["data"]

    # searchsorted never returns a negative position, so only the upper
    # bound and the exact-match condition need checking.
    pos = int(np.searchsorted(inds, frame_idx))
    if pos >= len(inds) or int(inds[pos]) != int(frame_idx):
        raise FileNotFoundError(f"Pose for frame {frame_idx} not found in {pose_npz_path}")

    mat = arr[pos]
    if mat.shape == (16,):
        mat = mat.reshape(4, 4)
    # Explicit raise rather than ``assert`` so the check survives ``python -O``.
    if mat.shape != (4, 4):
        raise ValueError(f"Unsupported pose format {mat.shape} in {pose_npz_path}")
    return mat.astype(np.float32)


def _load_intrinsics_for_frame(intrinsics_npz_path: str, frame_idx: int) -> np.ndarray:
    """Load the 3x3 intrinsics matrix stored for ``frame_idx`` in an ``.npz``.

    The archive must contain an ``inds`` array of frame indices and a
    ``data`` array whose ``i``-th entry belongs to ``inds[i]``. An entry is
    either a full 3x3 matrix or a 4-vector ``(fx, fy, cx, cy)``. ``inds`` is
    searched with a binary search, so it has to be sorted in ascending order.

    Returns:
        The intrinsics as a ``(3, 3)`` float32 array.

    Raises:
        FileNotFoundError: if ``frame_idx`` is not listed in ``inds``.
        ValueError: if the stored entry has an unsupported shape.
    """
    # np.load returns an NpzFile that keeps the archive open; the context
    # manager closes it instead of leaking one file handle per call.
    with np.load(intrinsics_npz_path) as data:
        inds = data["inds"]
        arr = data["data"]

    # searchsorted never returns a negative position, so only the upper
    # bound and the exact-match condition need checking.
    pos = int(np.searchsorted(inds, frame_idx))
    if pos >= len(inds) or int(inds[pos]) != int(frame_idx):
        raise FileNotFoundError(
            f"Intrinsics for frame {frame_idx} not found in {intrinsics_npz_path}"
        )

    item = arr[pos]
    if item.shape == (3, 3):
        return item.astype(np.float32)
    # The ndim guard keeps 0-d entries from raising a bare IndexError on
    # ``shape[-1]``; they fall through to the ValueError below instead.
    if item.ndim >= 1 and item.shape[-1] == 4:
        return _intrinsics_from_fxfycxcy(item)
    raise ValueError(
        f"Unsupported intrinsics format {item.shape} in {intrinsics_npz_path}"
    )


def _read_depth_from_zip(zip_path: str, frame_idx: int) -> np.ndarray:
    """Decode the depth map of one frame from a zip archive of EXR files.

    The frame is looked up under the member name ``{frame_idx:05d}.exr``.
    Its ``Z`` channel is interpreted as 16-bit floats, widened to float32
    and shaped as (height, width) according to the EXR data window.

    Raises:
        ImportError: if the ``OpenEXR`` package is not installed.
    """
    try:
        import OpenEXR  # type: ignore
    except ImportError as e:
        raise ImportError("OpenEXR package is required to read VIPE depth EXR files") from e

    member_name = f"{frame_idx:05d}.exr"
    with zipfile.ZipFile(zip_path, "r") as archive, archive.open(member_name, "r") as stream:
        exr_file = OpenEXR.InputFile(stream)
        window = exr_file.header()["dataWindow"]
        # The data window bounds are inclusive, hence the +1.
        n_rows = window.max.y - window.min.y + 1
        n_cols = window.max.x - window.min.x + 1
        z_half = np.frombuffer(exr_file.channel("Z"), np.float16)
        depth_hw = z_half.astype(np.float32).reshape(n_rows, n_cols)
    return depth_hw


def _read_mask_from_zip(zip_path: str, frame_idx: int) -> Optional[np.ndarray]:
    """Read the binary mask of one frame from a zip archive of PNG files.

    The frame is looked up under the member name ``{frame_idx:05d}.png``.
    For multi-channel images only the first channel is used. The result is
    a float32 array holding 1.0 where the pixel value is greater than zero
    and 0.0 elsewhere.

    Returns:
        The mask, or ``None`` when OpenCV (``cv2``) is not installed, the
        archive does not exist, the member is missing, or the PNG cannot be
        decoded.
    """
    try:
        import cv2  # type: ignore
    except ImportError:
        return None

    member_name = f"{frame_idx:05d}.png"
    if not os.path.exists(zip_path):
        return None

    with zipfile.ZipFile(zip_path, "r") as archive:
        try:
            with archive.open(member_name, "r") as stream:
                encoded = np.frombuffer(stream.read(), np.uint8)
                decoded = cv2.imdecode(encoded, cv2.IMREAD_UNCHANGED)
        except KeyError:
            # ZipFile.open raises KeyError for a member that is not present.
            return None

    if decoded is None:
        return None
    first_channel = decoded[..., 0] if decoded.ndim == 3 else decoded
    return (first_channel > 0).astype(np.float32)


def _read_mp4_frame(mp4_path: str, frame_idx: int) -> np.ndarray:
    """Decode a single frame of an mp4 file with decord.

    Returns:
        The frame as an array laid out (H, W, C).

    Raises:
        ImportError: if ``decord`` is not installed.
        IndexError: if ``frame_idx`` is negative or past the end of the video.
    """
    try:
        from decord import VideoReader  # type: ignore
    except ImportError as e:
        raise ImportError("decord is required to read VIPE rgb mp4") from e

    reader = VideoReader(mp4_path, num_threads=4)
    n_frames = len(reader)
    if not 0 <= frame_idx < n_frames:
        raise IndexError(
            f"Requested frame_idx {frame_idx} is out of bounds for video length {n_frames}"
        )

    batch = reader.get_batch([frame_idx])
    # The batch type exposes either ``asnumpy`` or ``numpy``; try both.
    try:
        batch_np = batch.asnumpy()
    except AttributeError:
        batch_np = batch.numpy()
    # Drop the leading batch axis of length one -> (H, W, C).
    return batch_np[0]


def _find_clip_paths(vipe_root_or_mp4: str, video_idx: int = 0) -> Tuple[str, str, str, str, Optional[str]]:
    """Resolve the rgb / depth / pose / intrinsics / mask paths of one clip.

    ``vipe_root_or_mp4`` is either a path ending in ``.mp4`` (the dataset
    root is then taken to be two directory levels up) or the dataset root
    itself, in which case the ``video_idx``-th mp4 under ``<root>/rgb``
    (sorted by file name) is selected.

    Returns:
        ``(mp4_path, depth_zip, pose_npz, intr_npz, mask_zip)``, where every
        side file shares the mp4's base name. ``mask_zip`` is ``None`` when
        the mask archive does not exist; the other paths are not checked.

    Raises:
        FileNotFoundError: if a root is given and ``<root>/rgb`` holds no mp4.
    """
    if vipe_root_or_mp4.endswith(".mp4"):
        mp4_path = vipe_root_or_mp4
        root = os.path.dirname(os.path.dirname(mp4_path))
    else:
        root = vipe_root_or_mp4
        rgb_dir = os.path.join(root, "rgb")
        mp4_names = sorted(name for name in os.listdir(rgb_dir) if name.endswith(".mp4"))
        if not mp4_names:
            raise FileNotFoundError(f"No mp4 found under {rgb_dir}")
        mp4_path = os.path.join(rgb_dir, mp4_names[video_idx])

    stem, _ext = os.path.splitext(os.path.basename(mp4_path))

    depth_zip = os.path.join(root, "depth", f"{stem}.zip")
    pose_npz = os.path.join(root, "pose", f"{stem}.npz")
    intr_npz = os.path.join(root, "intrinsics", f"{stem}.npz")
    mask_candidate = os.path.join(root, "mask", f"{stem}.zip")
    mask_zip = mask_candidate if os.path.exists(mask_candidate) else None

    return mp4_path, depth_zip, pose_npz, intr_npz, mask_zip


def load_vipe_data(
    vipe_root_or_mp4: str,
    starting_frame_idx: int,
    resize_hw: Tuple[int, int] = (720, 1280),
    crop_hw: Tuple[int, int] = (704, 1280),
    num_frames: int = 121,
    read_mask: bool = False,
    video_idx: int = 0,
):
    """Load a fixed-length window of one VIPE clip as resized, cropped tensors.

    Starting at ``starting_frame_idx`` (clamped to the last frame when it
    lies past the end of the video), ``num_frames`` consecutive frames are
    read. If the video ends early, the last available frame is repeated so
    the outputs always hold exactly ``num_frames`` entries. RGB, depth and
    mask are resized to ``resize_hw`` and then centre-cropped to ``crop_hw``;
    the per-frame intrinsics are adjusted for the same resize and crop.

    Args:
        vipe_root_or_mp4: dataset root directory or path to a clip's mp4.
        starting_frame_idx: index of the first frame of the window.
        resize_hw: ``(height, width)`` to resize to before cropping.
        crop_hw: ``(height, width)`` of the centre crop.
        num_frames: number of frames ``T`` to return; must be positive.
        read_mask: when true and a mask archive exists, read per-frame masks;
            otherwise (or for frames whose mask cannot be read) the mask is
            all ones.
        video_idx: which mp4 to pick when a dataset root is given.

    Returns:
        A tuple ``(frames, depth, mask, w2cs, Ks)`` of tensors. With
        ``(ch, cw) = crop_hw`` (assuming the crop fits in the resized image):

        * ``frames``: ``(T, C, ch, cw)``, rescaled from [0, 255] to [-1, 1].
        * ``depth``: ``(T, 1, ch, cw)``, bilinearly resized.
        * ``mask``: ``(T, 1, ch, cw)``, nearest-neighbour resized.
        * ``w2cs``: ``(T, 4, 4)``, the inverse of each stored pose (the
          stored poses are treated as camera-to-world).
        * ``Ks``: ``(T, 3, 3)``, intrinsics matching the cropped images.

    Raises:
        ImportError: if ``decord`` is not installed.
        ValueError: if ``num_frames`` is not positive or the video is empty.
    """
    if num_frames <= 0:
        raise ValueError(f"num_frames must be positive, got {num_frames}")

    mp4_path, depth_zip, pose_npz, intr_npz, mask_zip = _find_clip_paths(vipe_root_or_mp4, video_idx=video_idx)

    # Read the sequence of RGB frames
    try:
        from decord import VideoReader  # type: ignore
    except ImportError as e:
        raise ImportError("decord is required to read VIPE rgb mp4") from e
    vr = VideoReader(mp4_path, num_threads=4)
    total_len = len(vr)
    if total_len == 0:
        # Without this the padding below would use index -1 and fail obscurely.
        raise ValueError(f"Video {mp4_path} contains no frames")
    last_available_idx = total_len - 1

    # If the starting index is beyond the video, clamp it to the last frame.
    starting_frame_idx = min(starting_frame_idx, last_available_idx)

    # Build the index list and repeat the last available frame if the video
    # is too short to supply num_frames distinct frames.
    stop_idx = min(starting_frame_idx + num_frames, total_len)
    frame_indices = list(range(starting_frame_idx, stop_idx))
    frame_indices += [last_available_idx] * (num_frames - len(frame_indices))

    batch = vr.get_batch(frame_indices)
    # The batch type exposes either ``asnumpy`` or ``numpy``; try both.
    try:
        frames_np = batch.asnumpy()
    except AttributeError:
        frames_np = batch.numpy()
    # frames_np: (T, H, W, C) in [0,255] -> float32 in [0,1]
    frames_np = frames_np.astype(np.float32) / 255.0
    src_h, src_w = frames_np.shape[1], frames_np.shape[2]

    use_mask = bool(read_mask and mask_zip)

    # Per-frame pose (c2w -> w2c), adjusted intrinsics, depth and mask.
    # Results are cached by frame index so padded repeats of the last frame
    # are read from disk only once instead of once per padded slot.
    per_frame = {}
    for fidx in frame_indices:
        if fidx in per_frame:
            continue
        c2w_44 = _load_pose_matrix_for_frame(pose_npz, fidx)
        w2c_44 = np.linalg.inv(c2w_44).astype(np.float32)

        K_src = _load_intrinsics_for_frame(intr_npz, fidx)
        K_adj = _adjust_intrinsics_for_resize_and_crop(K_src, (src_h, src_w), resize_hw, crop_hw)

        d_hw = _read_depth_from_zip(depth_zip, fidx)

        m_hw = _read_mask_from_zip(mask_zip, fidx) if use_mask else None
        if m_hw is None:
            # No mask available: treat every pixel as valid.
            m_hw = np.ones((src_h, src_w), dtype=np.float32)
        else:
            m_hw = m_hw.astype(np.float32)

        per_frame[fidx] = (w2c_44, K_adj, d_hw, m_hw)

    w2cs_np = np.stack([per_frame[fidx][0] for fidx in frame_indices], axis=0)   # (T, 4, 4)
    Ks_np = np.stack([per_frame[fidx][1] for fidx in frame_indices], axis=0)     # (T, 3, 3)
    depth_np = np.stack([per_frame[fidx][2] for fidx in frame_indices], axis=0)  # (T, H, W)
    mask_np = np.stack([per_frame[fidx][3] for fidx in frame_indices], axis=0)   # (T, H, W)

    # Convert to torch and apply resize/crop
    frames_t = torch.from_numpy(frames_np).permute(0, 3, 1, 2).contiguous()  # (T, C, H, W)
    depth_seq = torch.from_numpy(depth_np).unsqueeze(1).contiguous()         # (T, 1, H, W)
    mask_seq = torch.from_numpy(mask_np).unsqueeze(1).contiguous()           # (T, 1, H, W)

    rh, rw = resize_hw
    ch, cw = crop_hw

    frames_t = F.interpolate(frames_t, size=(rh, rw), mode="bilinear", align_corners=False)
    depth_seq = F.interpolate(depth_seq, size=(rh, rw), mode="bilinear", align_corners=False)
    # Nearest-neighbour keeps the mask strictly binary.
    mask_seq = F.interpolate(mask_seq, size=(rh, rw), mode="nearest")

    frames_t = _center_crop(frames_t, ch, cw)    # (T, C, ch, cw)
    depth_seq = _center_crop(depth_seq, ch, cw)  # (T, 1, ch, cw)
    mask_seq = _center_crop(mask_seq, ch, cw)    # (T, 1, ch, cw)

    frames_t = frames_t * 2.0 - 1.0  # to [-1, 1]

    # Full sequences (T, ...)
    w2cs_T44 = torch.from_numpy(w2cs_np).contiguous()
    Ks_T33 = torch.from_numpy(Ks_np).contiguous()

    return (
        frames_t,
        depth_seq,
        mask_seq,
        w2cs_T44,
        Ks_T33,
    )