File size: 8,295 Bytes
436b829
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
"""
Generic video clip dataset wrapper.

Pulls fixed-length clips of length L from any underlying frame-indexed
trajectory dataset. Used for video training (temporal Kalman filter loss
ablations) and video eval.

Each item is a dict with stacked-frame tensors:
    image       : (L, 3, H, W)
    depth       : (L, 1, H, W)
    mask        : (L, 1, H, W)
    sequence_id : str
    frame_ids   : list[int]
"""
from __future__ import annotations

import json
import os
import re
from typing import Sequence

import cv2
import numpy as np
import torch
from omegaconf import DictConfig
from torchvision.transforms import Compose

from ppd.utils.logger import Log


def _read_rgb(path: str) -> np.ndarray:
    """Read an image file as an RGB float32 array with values in [0, 1].

    OpenCV loads images in BGR channel order, so the channels are reordered
    to RGB before dividing by 255.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path`` (missing file,
            unreadable or unsupported format). ``cv2.imread`` reports this by
            returning ``None`` instead of raising, which would otherwise only
            surface later as an opaque ``cv2.cvtColor`` error.
    """
    bgr = cv2.imread(path)
    if bgr is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    return (rgb / 255.0).astype(np.float32)


def _read_depth_npy(path: str) -> np.ndarray:
    """Load a depth map stored with ``np.save`` and return it as float32.

    A trailing singleton channel axis (shape ``(H, W, 1)``) is squeezed away,
    giving ``(H, W)``. Arrays of any other shape are returned unchanged apart
    from the dtype cast.
    """
    depth = np.load(path).astype(np.float32)
    has_trailing_channel = depth.ndim == 3 and depth.shape[-1] == 1
    return depth[..., 0] if has_trailing_channel else depth


class TartanAirVideoClip:
    """Sample fixed-length clips from extracted TartanAir trajectories.

    Expected layout (from the V1 zips we downloaded under
    /mnt/sig/datasets/train/tartanair/extracted/):

        <data_root>/<scene>/<difficulty>/<P###>/image_left/NNNNNN_left.png
        <data_root>/<scene>/<difficulty>/<P###>/depth_left/NNNNNN_left_depth.npy

    Only the ``Easy`` and ``Hard`` difficulty directories are scanned. A clip
    takes ``clip_length`` frames spaced ``stride`` apart. Consecutive clips of
    one trajectory start ``max(clip_length // 2, 1)`` frames apart.

    Args:
        data_root: Root of the extracted dataset (layout above).
        clip_length: Number of frames per clip (must be >= 1).
        stride: Frame step between consecutive frames of a clip (must be >= 1).
        scenes: Optional scene names to keep; ``None`` keeps every scene.
            A single string is treated as one scene name.
        transforms: Per-frame transforms, combined with ``Compose`` and
            applied to ``{"image", "depth", "mask"}`` dicts.
        max_depth: Depths at or above this value are masked out.
        min_depth: Depths at or below this value are masked out.

    Raises:
        ValueError: if ``clip_length`` or ``stride`` is smaller than 1.
    """

    def __init__(
        self,
        data_root: str,
        clip_length: int = 8,
        stride: int = 1,
        scenes: Sequence[str] | None = None,
        transforms: list | None = None,
        max_depth: float = 80.0,
        min_depth: float = 0.1,
    ):
        if clip_length < 1 or stride < 1:
            raise ValueError(
                f"clip_length and stride must be >= 1, got {clip_length} and {stride}"
            )
        self.cfg = DictConfig(
            dict(
                data_root=data_root,
                clip_length=clip_length,
                stride=stride,
                max_depth=max_depth,
                min_depth=min_depth,
            )
        )
        self.dataset_name = "tartanair_video"
        self.transform = Compose(transforms or [])
        self._build(scenes)

    @staticmethod
    def _clip_starts(num_frames: int, clip_length: int, stride: int) -> range:
        """Return the start index of every clip that fits in ``num_frames`` frames.

        A clip reads frames ``start + i * stride`` for ``i < clip_length``, so it
        spans ``(clip_length - 1) * stride + 1`` frames (not ``clip_length *
        stride``). Starts are spaced ``max(clip_length // 2, 1)`` apart. The
        range is empty when the sequence is shorter than one clip.
        """
        span = (clip_length - 1) * stride + 1
        hop = max(clip_length // 2, 1)
        return range(0, max(num_frames - span + 1, 0), hop)

    def _build(self, scenes: Sequence[str] | None) -> None:
        """Scan ``data_root`` and populate ``self.clips``.

        Each entry is ``(sequence_id, rgb_paths, depth_paths)`` with
        ``sequence_id = "<scene>/<difficulty>/<traj>"``. If ``data_root`` is
        not a directory, a warning is logged and ``self.clips`` stays empty.
        """
        root = self.cfg.data_root
        L = self.cfg.clip_length
        S = self.cfg.stride
        self.clips: list[tuple[str, list[str], list[str]]] = []
        if not os.path.isdir(root):
            Log.warn(f"TartanAir video root missing: {root}")
            return
        if isinstance(scenes, str):
            # `in` on a bare string does substring matching; treat it as one name.
            scenes = [scenes]
        wanted = None if scenes is None else set(scenes)
        for scene in sorted(os.listdir(root)):
            if wanted is not None and scene not in wanted:
                continue
            scene_dir = os.path.join(root, scene)
            if not os.path.isdir(scene_dir):
                continue
            for difficulty in ("Easy", "Hard"):
                diff_dir = os.path.join(scene_dir, difficulty)
                if not os.path.isdir(diff_dir):
                    continue
                for traj in sorted(os.listdir(diff_dir)):
                    img_dir = os.path.join(diff_dir, traj, "image_left")
                    dpt_dir = os.path.join(diff_dir, traj, "depth_left")
                    if not (os.path.isdir(img_dir) and os.path.isdir(dpt_dir)):
                        continue
                    frames = sorted(
                        f for f in os.listdir(img_dir) if f.endswith("_left.png")
                    )
                    seq_id = f"{scene}/{difficulty}/{traj}"
                    for start in self._clip_starts(len(frames), L, S):
                        idx = range(start, start + L * S, S)
                        rgb_paths = [os.path.join(img_dir, frames[i]) for i in idx]
                        # Depth files share the RGB frame number with a different suffix.
                        dpt_paths = [
                            os.path.join(
                                dpt_dir, frames[i].replace("_left.png", "_left_depth.npy")
                            )
                            for i in idx
                        ]
                        self.clips.append((seq_id, rgb_paths, dpt_paths))
        Log.info(f"TartanAir video: {len(self.clips)} clips")

    def __len__(self) -> int:
        return len(self.clips)

    def __getitem__(self, idx: int) -> dict:
        """Load clip ``idx`` and return the stacked per-frame arrays.

        ``mask`` is 1 where depth is finite and strictly inside
        ``(min_depth, max_depth)``. ``frame_ids`` holds the positions
        ``0..L-1`` within the clip, not frame numbers in the trajectory.
        """
        seq_id, rgb_paths, dpt_paths = self.clips[idx]
        # Hoisted: DictConfig attribute access is comparatively slow.
        min_depth = self.cfg.min_depth
        max_depth = self.cfg.max_depth
        images, depths, masks = [], [], []
        for r, d in zip(rgb_paths, dpt_paths):
            rgb = _read_rgb(r)
            depth = _read_depth_npy(d)
            mask = np.isfinite(depth) & (depth > min_depth) & (depth < max_depth)
            sample = {"image": rgb, "depth": depth, "mask": mask.astype(np.uint8)}
            sample = self.transform(sample)
            images.append(sample["image"])
            depths.append(sample["depth"])
            masks.append(sample["mask"])
        return {
            "image": np.stack(images, axis=0),
            "depth": np.stack(depths, axis=0),
            "mask": np.stack(masks, axis=0),
            "dataset_name": self.dataset_name,
            "sequence_id": seq_id,
            "frame_ids": list(range(len(images))),
        }


class BonnRGBDVideoClip:
    """Bonn dynamic RGB-D dataset clip loader.

    Each unzipped sequence has:
        rgb/<timestamp>.png
        depth/<timestamp>.png  (16-bit; raw / 5000 = meters by Bonn convention)
        rgb.txt, depth.txt     (timestamps)
        associated.txt         (rgb-depth pairing, optional)

    For simplicity, frames are paired by index after sorting each folder; the
    timestamp files are not read. This assumes the two folders correspond
    one-to-one. When their counts differ, only the first
    ``min(len(rgb), len(depth))`` frames are used.

    Args:
        data_root: Directory whose subdirectories are Bonn sequences.
        clip_length: Number of frames per clip (must be >= 1).
        stride: Frame step between consecutive frames of a clip (must be >= 1).
        sequences: Optional sequence directory names to keep; ``None`` keeps
            all. A single string is treated as one sequence name.
        transforms: Per-frame transforms, combined with ``Compose`` and
            applied to ``{"image", "depth", "mask"}`` dicts.
        depth_scale: Raw depth units per meter (5000 by Bonn convention).
        min_depth: Depths (meters) at or below this value are masked out.
        max_depth: Depths (meters) at or above this value are masked out.

    Raises:
        ValueError: if ``clip_length`` or ``stride`` is smaller than 1.
    """

    def __init__(
        self,
        data_root: str,
        clip_length: int = 8,
        stride: int = 1,
        sequences: Sequence[str] | None = None,
        transforms: list | None = None,
        depth_scale: float = 5000.0,
        min_depth: float = 0.01,
        max_depth: float = 10.0,
    ):
        if clip_length < 1 or stride < 1:
            raise ValueError(
                f"clip_length and stride must be >= 1, got {clip_length} and {stride}"
            )
        self.cfg = DictConfig(
            dict(
                data_root=data_root,
                clip_length=clip_length,
                stride=stride,
                depth_scale=depth_scale,
                min_depth=min_depth,
                max_depth=max_depth,
            )
        )
        self.dataset_name = "bonn_rgbd_video"
        self.transform = Compose(transforms or [])
        self._build(sequences)

    @staticmethod
    def _clip_starts(num_frames: int, clip_length: int, stride: int) -> range:
        """Return the start index of every clip that fits in ``num_frames`` frames.

        A clip reads frames ``start + i * stride`` for ``i < clip_length``, so it
        spans ``(clip_length - 1) * stride + 1`` frames (not ``clip_length *
        stride``). Starts are spaced ``max(clip_length // 2, 1)`` apart. The
        range is empty when the sequence is shorter than one clip.
        """
        span = (clip_length - 1) * stride + 1
        hop = max(clip_length // 2, 1)
        return range(0, max(num_frames - span + 1, 0), hop)

    def _build(self, sequences: Sequence[str] | None) -> None:
        """Scan ``data_root`` and populate ``self.clips``.

        Each entry is ``(sequence_name, rgb_paths, depth_paths)``. If
        ``data_root`` is not a directory, a warning is logged and
        ``self.clips`` stays empty.
        """
        root = self.cfg.data_root
        L = self.cfg.clip_length
        S = self.cfg.stride
        self.clips: list[tuple[str, list[str], list[str]]] = []
        if not os.path.isdir(root):
            Log.warn(f"Bonn root missing: {root}")
            return
        if isinstance(sequences, str):
            # `in` on a bare string does substring matching; treat it as one name.
            sequences = [sequences]
        wanted = None if sequences is None else set(sequences)
        # Bonn sequences live in subdirectories of data_root.
        for seq_name in sorted(os.listdir(root)):
            if wanted is not None and seq_name not in wanted:
                continue
            seq_dir = os.path.join(root, seq_name)
            if not os.path.isdir(seq_dir):
                continue
            rgb_dir = os.path.join(seq_dir, "rgb")
            dpt_dir = os.path.join(seq_dir, "depth")
            if not (os.path.isdir(rgb_dir) and os.path.isdir(dpt_dir)):
                continue
            rgb_files = sorted(f for f in os.listdir(rgb_dir) if f.endswith(".png"))
            dpt_files = sorted(f for f in os.listdir(dpt_dir) if f.endswith(".png"))
            n = min(len(rgb_files), len(dpt_files))
            for start in self._clip_starts(n, L, S):
                idx = range(start, start + L * S, S)
                rgb_paths = [os.path.join(rgb_dir, rgb_files[i]) for i in idx]
                dpt_paths = [os.path.join(dpt_dir, dpt_files[i]) for i in idx]
                self.clips.append((seq_name, rgb_paths, dpt_paths))
        num_sequences = len({seq_id for seq_id, _, _ in self.clips})
        Log.info(f"Bonn video: {len(self.clips)} clips from {num_sequences} sequences")

    def __len__(self) -> int:
        return len(self.clips)

    def __getitem__(self, idx: int) -> dict:
        """Load clip ``idx`` and return the stacked per-frame arrays.

        ``mask`` is 1 where depth (meters) is strictly inside
        ``(min_depth, max_depth)``. ``frame_ids`` holds the positions
        ``0..L-1`` within the clip, not frame numbers in the sequence.

        Raises:
            FileNotFoundError: if a depth PNG cannot be read by OpenCV.
        """
        seq_id, rgb_paths, dpt_paths = self.clips[idx]
        # Hoisted: DictConfig attribute access is comparatively slow.
        depth_scale = self.cfg.depth_scale
        min_depth = self.cfg.min_depth
        max_depth = self.cfg.max_depth
        images, depths, masks = [], [], []
        for r, d in zip(rgb_paths, dpt_paths):
            rgb = _read_rgb(r)
            raw = cv2.imread(d, cv2.IMREAD_ANYDEPTH)
            if raw is None:
                # cv2.imread signals failure by returning None rather than raising.
                raise FileNotFoundError(f"Could not read depth image: {d}")
            # Raw 16-bit values are 1/5000 m per unit by Bonn convention.
            depth = raw.astype(np.float32) / depth_scale
            mask = (depth > min_depth) & (depth < max_depth)
            sample = {"image": rgb, "depth": depth, "mask": mask.astype(np.uint8)}
            sample = self.transform(sample)
            images.append(sample["image"])
            depths.append(sample["depth"])
            masks.append(sample["mask"])
        return {
            "image": np.stack(images, axis=0),
            "depth": np.stack(depths, axis=0),
            "mask": np.stack(masks, axis=0),
            "dataset_name": self.dataset_name,
            "sequence_id": seq_id,
            "frame_ids": list(range(len(images))),
        }