# 3AM / data.py (nycu-cplab, commit 0bb5fcf: "app overall")
import json
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import torch
from tqdm import tqdm
import os
from glob import glob
from torch.utils.data import Dataset
from must3r.tools.image import get_resize_function
from PIL import Image
import numpy as np
from einops import rearrange
from typing import List, Dict, Optional, Tuple
from pycocotools import mask as mask_utils
import random, cv2
from scipy.spatial.transform import Rotation
# SA-V: annotations at 6 fps, video at 24 fps -> one annotated frame every 4 video frames
SAV_ANNOT_RATE = 24 // 6
def load_images(folder_content, size, patch_size = 16, verbose = True):
    """Load, normalize and resize a list of images with the MUSt3R resize policy.

    Args:
        folder_content: iterable whose items are either image file paths
            (``str``) or already-opened ``PIL.Image.Image`` objects.
        size: target size forwarded to ``get_resize_function``.
        patch_size: patch size forwarded to ``get_resize_function``.
        verbose: print the source path and the resolution change per image.

    Returns:
        ``(imgs, resize_funcs)``: ``imgs`` is a list of dicts with ``img``
        (the normalized, resized tensor) and ``true_shape``
        (``np.int32([H, W])`` of that tensor); ``resize_funcs`` holds the
        per-image resize callable returned by ``get_resize_function``.

    Raises:
        ValueError: if an item is neither a ``str`` nor a ``PIL.Image.Image``.
    """
    # ToTensor yields values in [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1].
    transform = T.Compose([T.ToTensor(), T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    imgs = []
    resize_funcs = []
    for content in folder_content:
        if isinstance(content, str):
            if verbose:
                print(f'Loading image from {content} ', end = '')
            # convert() returns a new, fully loaded image, so the source file
            # can be closed deterministically here.
            with Image.open(content) as src:
                rgb_image = src.convert('RGB')
        elif isinstance(content, Image.Image):
            rgb_image = content
        else:
            raise ValueError(f'Unknown content type: {type(content)}')
        rgb_image.load()
        W, H = rgb_image.size
        # The third return value (inverse mapping) is not needed here.
        resize_func, _, _ = get_resize_function(size, patch_size, H, W)
        resize_funcs.append(resize_func)
        rgb_tensor = resize_func(transform(rgb_image))
        imgs.append(dict(img=rgb_tensor, true_shape=np.int32([rgb_tensor.shape[-2], rgb_tensor.shape[-1]])))
        if verbose:
            print(f'with resolution {W}x{H} --> {rgb_tensor.shape[-1]}x{rgb_tensor.shape[-2]}')
    return imgs, resize_funcs
def _decode_rle(rle: Dict, h: int, w: int) -> np.ndarray:
if not rle or "counts" not in rle:
return np.zeros((h, w), dtype=np.uint8)
counts = rle["counts"]
if isinstance(counts, str):
counts = counts.encode("utf-8")
m = mask_utils.decode({"size": [h, w], "counts": counts})
return (np.asarray(m).squeeze() > 0)
def _read_frame_rgb(cap: cv2.VideoCapture, idx: int, fallback_hw: Optional[Tuple[int,int]]=None) -> np.ndarray:
    """Seek ``cap`` to frame ``idx`` and return that frame converted BGR -> RGB.

    Args:
        cap: an opened ``cv2.VideoCapture``.
        idx: 0-based frame index to seek to.
        fallback_hw: accepted for API compatibility; currently unused.
            NOTE(review): the name suggests a blank-frame fallback size was
            intended; confirm before implementing one.

    Raises:
        RuntimeError: if seeking fails or no frame can be decoded at ``idx``.
    """
    if not cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx)):
        raise RuntimeError(f"cv2.VideoCapture.set({idx}) failed")
    ok, bgr = cap.read()
    if not ok or bgr is None:
        # Without this check cvtColor receives None and fails with an opaque cv2.error.
        raise RuntimeError(f"cv2.VideoCapture.read() failed at frame {idx}")
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
class SAVTrainDataset(Dataset):
    """
    SA-V train Dataset (mp4 + {video_id}_{manual|auto}.json).
    Scans JSON with pattern: root/*/*.json (non-recursive).

    __getitem__ procedure:
      1. load the masklet JSON of the selected video;
      2. pick a random center annotation frame and take the span
         [center - N*max_stride, center + N*max_stride), clipped to the video;
      3. drop leading frames until the first frame contains an object whose
         area ratio (at the original HxW) is >= area_thresh;
      4. subsample the span to exactly N frames with step
         min(len(span) // N, max_stride);
      5. read those frames from the mp4 and build the chosen object's masks.
    Whenever a step fails, the reason is appended to the resample log and a
    random index is drawn instead.
    """

    def __init__(
        self,
        data_root: str,
        mask_type: Optional[str] = None,  # None | "manual" | "auto"
        img_mean = (0.485, 0.456, 0.406),
        img_std = (0.229, 0.224, 0.225),
        N: int = 8,
        image_size: int = 1024,
        verbose: bool = False,
        max_stride: int = 1,  # span half-width is N * max_stride; also caps the subsampling step
        dataset_scale: int = 32,
        area_thresh: float = 0.01,  # area ratio threshold at original HxW
        valid_must3r_sizes = (224, 512),
    ):
        assert mask_type in (None, "manual", "auto")
        assert N >= 1
        # max_stride <= 0 makes every span empty, i.e. endless resampling.
        assert int(max_stride) >= 1
        self.verbose = verbose
        self.data_root = data_root
        self.dataset_scale = int(dataset_scale)
        self.N = int(N)
        self.mask_type = mask_type
        self.area_thresh = float(area_thresh)
        self.max_stride = int(max_stride)
        # Copy so instances never share (or mutate) the default sequence.
        self.valid_must3r_sizes = list(valid_must3r_sizes)
        self.image_transform = T.Compose([
            T.Resize((image_size, image_size), interpolation=T.InterpolationMode.NEAREST_EXACT),
            T.Normalize(mean=img_mean, std=img_std),
        ])
        self.instance_transform = T.Compose([
            T.Resize((image_size, image_size), interpolation=T.InterpolationMode.NEAREST_EXACT),
        ])
        # --- collect (video, json) pairs through JSONs (non-recursive) ---
        json_paths = glob(os.path.join(data_root, "*", "*.json"))
        self.items: List[Tuple[str, str]] = []  # (vpath, jpath)
        for jpath in tqdm(json_paths, desc="scanning jsons"):
            base = os.path.splitext(os.path.basename(jpath))[0]
            # filter by mask_type if specified
            if self.mask_type is not None and not base.endswith(f"_{self.mask_type}"):
                continue
            if base.endswith("_manual"):
                vid = base[:-len("_manual")]
            elif base.endswith("_auto"):
                vid = base[:-len("_auto")]
            else:
                # strictly require suffix
                continue
            vpath = os.path.join(os.path.dirname(jpath), f"{vid}.mp4")
            if os.path.isfile(vpath):
                self.items.append((vpath, jpath))
        print(f"Collected {len(self.items)} video-json pairs")
        self._log_path = "./sav_dataset_resample.log"

    def __len__(self):
        return self.dataset_scale * len(self.items)

    def _resample(self):
        """Return a sample for a uniformly random index (used on any failure)."""
        return self[random.randrange(len(self))]

    def _log(self, msg: str):
        """Best-effort append of ``msg`` to the resample log; I/O errors are ignored."""
        try:
            with open(self._log_path, "a") as f:
                f.write(msg.rstrip() + "\n")
        except (OSError, UnicodeError):
            pass

    def _has_large_object(self, rles, obj_order, H, W):
        """True if some object in ``obj_order`` has a non-empty mask with area ratio >= area_thresh."""
        for oid in obj_order:
            area = int(_decode_rle(rles[oid], H, W).sum())
            if area > 0 and area / float(H * W + 1e-6) >= self.area_thresh:
                return True
        return False

    def _read_frames(self, vpath, sample_indices, H, W):
        """Read the video frames matching ``sample_indices`` as RGB arrays.

        Returns ``None`` (after logging) if any frame cannot be read. The
        capture is always released.
        """
        cap = cv2.VideoCapture(vpath)
        try:
            return [
                _read_frame_rgb(cap, int(f_annot * SAV_ANNOT_RATE), fallback_hw=(H, W))
                for f_annot in sample_indices
            ]
        except RuntimeError as err:
            self._log(f"[read_fail] {vpath}: {err}")
            return None
        finally:
            cap.release()

    def __getitem__(self, idx: int):
        vpath, jpath = self.items[idx % len(self.items)]

        # 1) load json
        with open(jpath, "r") as f:
            meta = json.load(f)
        masklet: List[List[Dict]] = meta.get("masklet", [])
        if not isinstance(masklet, list) or len(masklet) < self.N:
            self._log(f"[short_json] {jpath}: len(masklet)={len(masklet)} < N={self.N}")
            return self._resample()
        H, W = int(meta["video_height"]), int(meta["video_width"])

        # 2) random center annotation frame; span [center - N*max_stride, center + N*max_stride)
        center = random.randrange(len(masklet))
        left = max(0, center - self.N * self.max_stride)
        right = min(len(masklet), center + self.N * self.max_stride)
        sample_indices = list(range(left, right))
        if len(sample_indices) < self.N:
            self._log(f"[short_span] {jpath}: span={len(sample_indices)} < N={self.N}")
            return self._resample()

        # 3) drop leading frames until the first one holds a large-enough object
        obj_order = None
        while True:
            if len(sample_indices) < self.N:
                self._log(f"[exhausted_span] {jpath}: remaining span < N; resample")
                return self._resample()
            f0 = sample_indices[0]
            rles = masklet[f0] if isinstance(masklet[f0], list) else []
            if len(rles) == 0:
                sample_indices.pop(0)  # no objects at this frame
                continue
            obj_order = list(range(len(rles)))
            random.shuffle(obj_order)
            if self._has_large_object(rles, obj_order, H, W):
                break
            sample_indices.pop(0)  # no object passed the threshold

        # 4) subsample to exactly N frames (step >= 1: len >= N and max_stride >= 1)
        step = min(len(sample_indices) // self.N, self.max_stride)
        sample_indices = sample_indices[::step][:self.N]
        assert len(sample_indices) == self.N

        # 5) read frames (video frame = annotation index * SAV_ANNOT_RATE), build masks
        frames_rgb = self._read_frames(vpath, sample_indices, H, W)
        if frames_rgb is None:
            return self._resample()
        original_imgs_pil = [Image.fromarray(fr) for fr in frames_rgb]
        # must3r parity fields
        must3r_size = np.random.choice(self.valid_must3r_sizes).item()
        views, resize_funcs = load_images(original_imgs_pil, size = must3r_size, patch_size = 16, verbose = self.verbose)

        original_instances = []
        original_imgs = []
        instance_id = None
        for frame_idx, (resize_func, sample_idx) in enumerate(zip(resize_funcs, sample_indices)):
            assert len(resize_func.transforms) == 2, f'Expected 2 transforms, got {len(resize_func.transforms)}'
            resize = resize_func.transforms[0]
            if frame_idx == 0:
                # Pick the first object (in shuffled order) still large enough after resizing.
                min_area = resize.size[0] * resize.size[1] * self.area_thresh
                instance_map = None
                for oid in obj_order:
                    candidate = resize(torch.from_numpy(_decode_rle(masklet[sample_idx][oid], H, W)))
                    if candidate.sum() > min_area:
                        instance_id, instance_map = oid, candidate
                        break
                if instance_map is None:
                    self._log(f"[small_after_resize] {jpath}: no object passes area_thresh after resize")
                    return self._resample()
            else:
                instance_map = resize(torch.from_numpy(_decode_rle(masklet[sample_idx][instance_id], H, W)))
            original_instances.append(instance_map)
            original_imgs.append(resize(TF.to_tensor(original_imgs_pil[frame_idx])))

        stacked = torch.stack(original_instances)
        # (N, 1, h, w) whether each map was (h, w) or (1, h, w); unlike
        # .squeeze(), this keeps the batch dimension when N == 1.
        original_instances = stacked.reshape(stacked.shape[0], 1, stacked.shape[-2], stacked.shape[-1])
        instances = self.instance_transform(original_instances)
        assert instances[0].sum() > 0 and instances.ndim == 4, f'{instances.shape=}, {instances[0].sum()=}'
        original_imgs = torch.stack(original_imgs)
        imgs = self.image_transform(original_imgs)
        return {
            "original_images": original_imgs,      # [N,3,H,W]
            "images": imgs,                        # [N,3,S,S]
            "original_masks": original_instances,  # [N,1,H,W]
            "masks": instances,                    # [N,1,S,S]
            "filelist": sample_indices,
            "must3r_views": views,
            "video": os.path.splitext(os.path.basename(vpath))[0],
            "instance_id": int(instance_id),
            "dataset": "sav",
            "valid_masks": torch.ones_like(instances),  # [N,1,S,S]
            "must3r_size": must3r_size
        }
class MOSEDataset(Dataset):
    """MOSE-layout video object segmentation dataset.

    Expects ``<data_root>/JPEGImages/<video>/<frame>.jpg`` and
    ``<data_root>/Annotations/<video>/<frame>.png`` with integer frame names.
    Pixel value 0 in the PNGs is ignored as background. Every (video, frame)
    pair is one index. ``__getitem__`` takes the window
    ``[frame - N, frame + N * max_stride)``, drops leading frames until one
    contains an id covering > 1% of the pixels, subsamples the window to ``N``
    frames, and returns images plus the binary masks of one randomly chosen id.
    """

    def __init__(
        self,
        data_root: str,
        img_mean = (0.485, 0.456, 0.406),
        img_std = (0.229, 0.224, 0.225),
        N: int = 8,
        image_size: int = 1024,
        verbose = False,
        max_stride = 2,
        dataset_scale = 1,
        valid_must3r_sizes = (224, 512),
    ):
        self.verbose = verbose
        self.data_root = data_root
        self.dataset_scale = dataset_scale
        self.N = N
        self.max_stride = max_stride
        self.image_transform = T.Compose([
            T.Resize((image_size, image_size), interpolation = T.InterpolationMode.NEAREST_EXACT),
            T.Normalize(mean = img_mean, std = img_std)
        ])
        self.instance_transform = T.Compose([
            T.Resize((image_size, image_size), interpolation = T.InterpolationMode.NEAREST_EXACT),
        ])
        # Copy so instances never share (or mutate) the default sequence.
        self.valid_must3r_sizes = list(valid_must3r_sizes)
        self.videos = os.listdir(os.path.join(data_root, 'JPEGImages'))
        self.frames = {}
        self.masks = {}
        self.indices = []

        def frame_number(path):
            # '00012.jpg' -> 12, so frames sort numerically rather than lexically.
            return int(os.path.basename(path).split('.')[0])

        for video in tqdm(self.videos):
            if not os.path.isdir(os.path.join(data_root, 'JPEGImages', video)):
                continue
            frames = sorted(glob(os.path.join(data_root, 'JPEGImages', video, '*.jpg')), key = frame_number)
            masks = sorted(glob(os.path.join(data_root, 'Annotations', video, '*.png')), key = frame_number)
            if len(frames) < self.N:
                if self.verbose:
                    print(f"skip video {video} as not enough frames")
                continue
            assert len(frames) == len(masks), f'{len(frames)=}, {len(masks)=} in {video}'
            self.frames[video] = frames
            self.masks[video] = masks
            self.indices += [(video, idx) for idx in range(len(frames))]
        lengths = [len(fl) for fl in self.frames.values()]
        if lengths:
            print(f'Found {len(self.indices)} frames, and {len(self.frames)} videos, with min length {min(lengths)} and max length {max(lengths)}')
        else:
            # min()/max() of an empty list would raise here.
            print(f'Found 0 frames, and 0 videos under {data_root}')

    def __len__(self):
        return len(self.indices) * self.dataset_scale

    def _load_mask(self, video, frame_idx):
        """Return the annotation PNG of frame ``frame_idx`` of ``video`` as a numpy array."""
        with Image.open(self.masks[video][frame_idx]) as im:
            return np.array(im)

    def __getitem__(self, idx):
        video, center = self.indices[idx % len(self.indices)]
        n_frames = len(self.frames[video])
        # Window: up to N frames before `center`, up to N * max_stride from it on.
        sampled_indices = list(range(max(0, center - self.N), min(n_frames, center + self.N * self.max_stride)))

        # Drop leading frames until the first one has an id covering > 1% of its pixels.
        while True:
            if len(sampled_indices) < self.N:
                return self[np.random.randint(len(self))]
            ids, counts = np.unique(self._load_mask(video, sampled_indices[0]), return_counts = True)
            unique_ids = ids[(ids != 0) & (counts > counts.sum() * 0.01)]
            if len(unique_ids) > 0:
                break
            sampled_indices.pop(0)

        # Evenly subsample to exactly N frames (step >= 1 because len >= N).
        sampled_indices = sampled_indices[::len(sampled_indices) // self.N][:self.N]
        assert len(unique_ids) > 0 and len(sampled_indices) == self.N
        filelist = [self.frames[video][i] for i in sampled_indices]
        must3r_size = np.random.choice(self.valid_must3r_sizes).item()
        views, resize_funcs = load_images(filelist, size = must3r_size, patch_size = 16, verbose = self.verbose)

        original_instances = []
        original_imgs = []
        instance_id = None
        for frame_idx, (resize_func, sample_idx) in enumerate(zip(resize_funcs, sampled_indices)):
            assert len(resize_func.transforms) == 2, f'Expected 2 transforms, got {len(resize_func.transforms)}'
            resize = resize_func.transforms[0]
            # Each mask PNG is read once per frame and reused below.
            mask = torch.from_numpy(self._load_mask(video, sample_idx))
            if frame_idx == 0:
                # Random id whose resized mask still covers > 1% of the resized image.
                min_area = resize.size[0] * resize.size[1] * 0.01
                instance_map = None
                for candidate_id in np.random.permutation(unique_ids).tolist():
                    candidate = resize(mask == candidate_id)
                    if candidate.sum() > min_area:
                        instance_id, instance_map = candidate_id, candidate
                        break
                if instance_map is None:
                    return self[np.random.randint(len(self))]
            else:
                instance_map = resize(mask == instance_id)
            original_instances.append(instance_map)
            with Image.open(self.frames[video][sample_idx]) as frame:
                original_imgs.append(resize(TF.to_tensor(frame)))

        stacked = torch.stack(original_instances)
        # (N, 1, h, w) whether each map was (h, w) or (1, h, w); unlike
        # .squeeze(), this keeps the batch dimension when N == 1.
        original_instances = stacked.reshape(stacked.shape[0], 1, stacked.shape[-2], stacked.shape[-1])
        instances = self.instance_transform(original_instances)
        assert instances[0].sum() > 0 and instances.ndim == 4, f'{instances.shape=}, {instances[0].sum()=}'
        original_imgs = torch.stack(original_imgs)
        imgs = self.image_transform(original_imgs)
        return {
            'original_images': original_imgs,
            'images': imgs,
            'original_masks': original_instances,
            'masks': instances,
            'filelist': filelist,
            'must3r_views': views,
            'video': video,
            'instance_id': instance_id,
            'dataset': 'mose',
            'valid_masks': torch.ones_like(instances),
            'must3r_size': must3r_size,
        }
# Reads a ground-truth trajectory CSV file.
def read_trajectory_file(filepath):
    """Read a ground-truth trajectory CSV into stacked pose arrays.

    The first line is a header and is skipped. Blank lines are ignored. Every
    other line is comma-separated:
    - column 1 is an integer timestamp;
    - columns 3-5 are the translation;
    - columns 6-9 are a quaternion in ``(x, y, z, w)`` order, the order
      ``scipy``'s ``Rotation.from_quat`` expects.

    Returns:
        dict with
        - ``"ts"``: ``(M, 3)`` translations;
        - ``"Rs"``: ``(M, 3, 3)`` rotation matrices;
        - ``"Ts_world_from_device"``: ``(M, 4, 4)`` homogeneous ``[R | t]``
          matrices;
        - ``"timestamps"``: ``(M,)`` integer timestamps.

    Raises:
        FileNotFoundError: if ``filepath`` does not exist.
        ValueError: if the file contains no pose rows.
    """
    # An explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.exists(filepath):
        raise FileNotFoundError(f"Could not find trajectory file: {filepath}")

    positions, rotations, transforms, timestamps = [], [], [], []
    with open(filepath, "r") as f:
        next(f, None)  # header
        for raw in f:
            line = raw.strip()
            if not line:
                continue  # tolerate blank / trailing lines
            fields = line.split(",")
            timestamp = int(fields[1])
            translation = np.array([float(p) for p in fields[3:6]])
            quat_xyzw = np.array([float(q) for q in fields[6:10]])
            rot_matrix = Rotation.from_quat(quat_xyzw).as_matrix()
            transform = np.identity(4)
            transform[:3, :3] = rot_matrix
            transform[:3, 3] = translation
            positions.append(translation)
            rotations.append(rot_matrix)
            transforms.append(transform)
            timestamps.append(timestamp)

    if not transforms:
        raise ValueError(f"No poses found in trajectory file: {filepath}")
    return {
        "ts": np.stack(positions),
        "Rs": np.stack(rotations),
        "Ts_world_from_device": np.stack(transforms),
        "timestamps": np.array(timestamps),
    }
from projectaria_tools.core import calibration
from projectaria_tools.core.image import InterpolationMethod
class ASEDataset(Dataset):
    """Aria Synthetic Environments (ASE) multi-view instance dataset.

    Expected files per scene directory ``<data_root>/<video>/``:
    - ``undistorted/*.jpg`` frames;
    - ``undistorted-instances/*.png`` instance-id masks;
    - ``must3r-features/*.pt`` precomputed MUSt3R features, with the same
      stems as the frames;
    - depth PNGs at the frame path with 'undistorted' replaced by 'depth';
    - ``trajectory.csv``;
    - ``instances-appearances.json`` and ``object_instances_to_classes.json``.

    Each item picks one instance of a valid class in a first frame and returns
    N views of it: images, masks, depths, world point maps and camera
    intrinsics/extrinsics. When the MUSt3R size is unchanged, the cached
    MUSt3R features for those views are returned as well.
    """
    def __init__(
        self,
        data_root: str,
        img_mean = (0.485, 0.456, 0.406),
        img_std = (0.229, 0.224, 0.225),
        N: int = 8,
        image_size: int = 1024,
        verbose = False,
        dataset_scale = 1,
        continuous_prob = 0,  # probability of sampling N consecutive frames (see __getitem__)
        invalid_classes = ['ceiling', 'wall', 'empty_space', 'background', 'floor', 'window'],  # class names never used as targets
        valid_must3r_sizes = [224, 512]
    ):
        self.verbose = verbose
        self.data_root = data_root
        self.dataset_scale = dataset_scale
        self.continuous_prob = continuous_prob
        self.N = N
        self.image_transform = T.Compose([
            T.Resize((image_size, image_size), interpolation = T.InterpolationMode.NEAREST_EXACT),
            T.Normalize(mean = img_mean, std = img_std)
        ])
        self.instance_transform = T.Compose([
            T.Resize((image_size, image_size), interpolation = T.InterpolationMode.NEAREST_EXACT),
        ])
        self.valid_must3r_sizes = valid_must3r_sizes
        from projectaria_tools.projects import ase
        # NOTE(review): `calibration` is also imported at module level; this re-import is redundant.
        from projectaria_tools.core import calibration
        self.ase_device = ase.get_ase_rgb_calibration()
        self.ase_width, self.ase_height = self.ase_device.get_image_size()
        # Square images make the W/H swaps used for the -90 degree rotation below harmless.
        assert self.ase_width == self.ase_height, f"Expected square images, got {self.ase_width}x{self.ase_height}"
        # Linear (pinhole) camera of the same size; presumably 320 is the focal
        # length in pixels per the projectaria_tools signature - confirm.
        self.ase_pinhole = calibration.get_linear_camera_calibration(
            self.ase_width, self.ase_height, 320, "camera-rgb", self.ase_device.get_transform_device_camera()
        )
        self.fx, self.fy = self.ase_pinhole.get_focal_lengths()
        self.cx, self.cy = self.ase_pinhole.get_principal_point()
        # 3x3 intrinsics of the pinhole camera.
        self.K = np.array([[self.fx, 0, self.cx],
                           [0, self.fy, self.cy],
                           [0, 0, 1 ]], dtype = np.float32)
        self.videos = os.listdir(os.path.join(data_root))
        self.frames = {}
        self.masks = {}
        self.must3r_feats = {}
        self.appearances = {}  # video -> {instance id (str): [mask file names]}
        self.mask2indices = {}  # video -> {mask file name: frame index}
        self.validindices = {}  # video -> [instance ids whose class is not in invalid_classes]
        self.indices = []  # (video, frame index) pairs usable as a first frame
        for video in tqdm(self.videos, desc='Loading ASE videos'):
            if not os.path.isdir(os.path.join(data_root, video)):
                print(f"skip {video} as not a directory")
                continue
            frames = sorted(glob(os.path.join(data_root, video, 'undistorted', '*.jpg')))
            masks = sorted(glob(os.path.join(data_root, video, 'undistorted-instances', '*.png')))
            must3r_feats = sorted(glob(os.path.join(data_root, video, 'must3r-features', '*.pt')))
            if not (len(must3r_feats) == len(frames) == len(masks)):
                if self.verbose:
                    print(f"skip {video} as {len(must3r_feats)=}, {len(frames)=}, {len(masks)=} in {video}")
                continue
            assert all([os.path.splitext(os.path.basename(must3r_feat))[0] == os.path.splitext(os.path.basename(frame))[0] for must3r_feat, frame in zip(must3r_feats, frames)]), f'Must3r features and frames do not match in {video}'
            if len(frames) < self.N:
                if self.verbose:
                    print(f"skip video {video} as not enough frames")
                continue
            self.frames[video] = frames
            self.masks[video] = masks
            self.must3r_feats[video] = must3r_feats
            # NOTE(review): json.load(open(...)) here and below leaves closing the file to the GC.
            self.appearances[video] = json.load(open(os.path.join(data_root, video, 'instances-appearances.json')))
            self.mask2indices[video] = {
                os.path.basename(m): i for i, m in enumerate(masks)
            }
            # Only frames 0 .. len - N can start a sample.
            self.indices += [(video, idx) for idx in range(len(frames) - self.N + 1)]
            self.validindices[video] = [int(instance_id) for instance_id, class_name in json.load(open(os.path.join(data_root, video, 'object_instances_to_classes.json'))).items() if class_name not in invalid_classes] # if os.path.exists(os.path.join(data_root, video, 'object_instances_to_classes.json')) else None
        # NOTE(review): min()/max() below raise ValueError if no video was kept.
        print(f'Found {len(self.indices)} frames, and {len(self.frames)} videos, with min length {min([len(self.frames[video]) for video in self.frames])} and max length {max([len(self.frames[video]) for video in self.frames])} and {sum([(len(ids) if ids is not None else 0) for ids in self.validindices.values()])} valid instances')
        self._log_path = "./ase_dataset_resample.log"
    def __len__(self):
        # Every start index is repeated `dataset_scale` times.
        return len(self.indices) * self.dataset_scale
    def __getitem__(self, idx):
        # Any failure below resamples a random index via recursion.
        idx = idx % len(self.indices)
        video, idx = self.indices[idx]
        ## 1. Randomly shuffle frames
        # Candidate first frames are 0 .. len - N: `idx` first, the rest in random order.
        choices = np.delete(np.arange(len(self.frames[video]) - self.N + 1), idx)
        sampled_indices = [idx] + np.random.choice(choices, size = len(choices), replace = False).tolist()
        ## 2. Find unique instance IDs in the first frame
        # Drop leading candidates until one contains a non-zero id of a valid class.
        unique_ids = None
        while unique_ids is None or len(unique_ids) == 0:
            if unique_ids is not None:
                sampled_indices.pop(0)
            if len(sampled_indices) < self.N:
                return self[np.random.randint(len(self))]
            unique_ids = np.unique(np.array(Image.open(self.masks[video][sampled_indices[0]])), return_counts = False)
            unique_ids = unique_ids[(unique_ids != 0) & np.array([class_id in self.validindices[video] for class_id in unique_ids])] # if self.validindices[video] is not None else True
        first_frame_idx = sampled_indices[0]
        assert len(unique_ids) > 0
        ## 3. Load the resize funcs of the first frame
        # The cached features' token count tells the size they were computed at:
        # 196 tokens = (224 / 16) ** 2; anything else is treated as 512.
        feat_len = torch.load(self.must3r_feats[video][first_frame_idx], map_location = 'cpu')[-1].shape[-2]
        must3r_size = original_must3r_size = (224 if feat_len == 196 else 512)
        # Consecutive-frame sampling is forced when the cached size is not an allowed size.
        is_continuous = (np.random.rand() < self.continuous_prob) or original_must3r_size not in self.valid_must3r_sizes
        if is_continuous:
            must3r_size = np.random.choice(self.valid_must3r_sizes).item()
        _, [resize_func] = load_images([self.frames[video][first_frame_idx]], size = must3r_size, patch_size = 16, verbose = self.verbose)
        assert len(resize_func.transforms) == 2, f'Expected 2 transforms, got {len(resize_func.transforms)}'
        # 256 = 16 * 16 pixels per patch token.
        assert must3r_size != original_must3r_size or resize_func.transforms[1].size[0] * resize_func.transforms[1].size[1] == feat_len * 256, f'Expected {resize_func.transforms[1].size[0]}x{resize_func.transforms[1].size[1]} to be {feat_len * 256}, got {feat_len}'
        # Random instance whose resized first-frame mask covers between 1% and 20% of the image.
        for instance_id in np.random.permutation(unique_ids).tolist() + [None]:
            if instance_id is None:
                return self[np.random.randint(len(self))]
            if (resize_func.transforms[0].size[0] * resize_func.transforms[0].size[1] * 0.2) > (resize_func.transforms[0](torch.from_numpy(np.array(Image.open(self.masks[video][first_frame_idx]))) == instance_id)).sum() > (resize_func.transforms[0].size[0] * resize_func.transforms[0].size[1] * 0.01):
                break
        if is_continuous:
            # N consecutive frames starting at the first frame.
            sampled_indices = np.arange(first_frame_idx, min(len(self.frames[video]), first_frame_idx + self.N)).tolist()
            # sampled_indices += np.random.choice(first_frame_idx, size = first_frame_idx, replace = False).tolist()
            sampled_indices = sampled_indices[:self.N]
            assert len(sampled_indices) == self.N and sampled_indices[0] == first_frame_idx, f'Expected {self.N} sampled indices and first index {first_frame_idx}, got {len(sampled_indices)} with first index {sampled_indices[0]}'
        else:
            # At most two consecutive frames, ordered by instance-mask size (largest first).
            sampled_indices = np.arange(first_frame_idx, len(self.frames[video])).tolist()[:2]
            sampled_indices = sorted(sampled_indices, key = lambda sample_idx: resize_func.transforms[0](torch.from_numpy(np.array(Image.open(self.masks[video][sample_idx]))) == instance_id).sum(), reverse = True) ## prioritize frames with larger masks
            first_frame_idx = sampled_indices[0]
        views, original_instances, original_imgs, filelist, extrinsics, depths, point_maps, fov_ratios = [], [], [], [], [], [], [], {}
        # Frames at positions < pre_sampled_len are kept unconditionally; later ones must pass the checks below.
        pre_sampled_len = len(sampled_indices)
        if len(sampled_indices) < self.N:
            # Fill up with frames where the instance appears (random order), then all remaining frames (random order).
            instance_appearance_candidates = set([self.mask2indices[video][p] for p in self.appearances[video][str(instance_id)]]) - set(sampled_indices)
            sampled_indices += np.random.permutation(list(instance_appearance_candidates)).tolist()
            sampled_indices += np.random.permutation(list(set(np.arange(len(self.frames[video])).tolist()) - set(instance_appearance_candidates) - set(sampled_indices))).tolist()
        trajectory = read_trajectory_file(os.path.join(self.data_root, video, 'trajectory.csv'))
        # Accept candidate frames in order until N views are kept; rejected candidates are popped.
        while len(views) < self.N and len(sampled_indices) >= self.N:
            sample_idx = sampled_indices[len(views)]
            [view], [resize_func] = load_images([self.frames[video][sample_idx]], size = must3r_size, patch_size = 16, verbose = self.verbose)
            instance_map = resize_func.transforms[0](torch.from_numpy(np.array(Image.open(self.masks[video][sample_idx])) == instance_id))
            # Extra frames must have the instance cover more than 0.5% and less than 25% of the image.
            if len(views) >= pre_sampled_len and not (instance_map.shape[-1] * instance_map.shape[-2] * 0.005 < instance_map.sum() < instance_map.shape[-1] * instance_map.shape[-2] * 0.25):
                sampled_indices.pop(len(views))
                continue
            # Camera pose = trajectory device pose @ device-from-camera transform of the pinhole camera.
            extrinsic = trajectory['Ts_world_from_device'][sample_idx] @ self.ase_pinhole.get_transform_device_camera().to_matrix()
            # Depth PNG warped with distort_by_calibration(..., self.ase_pinhole, self.ase_device, ...)
            # and divided by 1000 - presumably millimetres to metres; confirm.
            depth = calibration.distort_by_calibration(
                np.array(Image.open(self.frames[video][sample_idx].replace('undistorted', 'depth').replace('vignette', 'depth').replace('.jpg', '.png'))), self.ase_pinhole, self.ase_device, InterpolationMethod.NEAREST_NEIGHBOR
            ).astype(np.float32) / 1000.0
            # World point map (3, H, W), rotated with rot90(k=-1) like the depth below, then resized.
            point_map = resize_func.transforms[0](torch.rot90(torch.from_numpy(depth_to_world_pointmap(depth, extrinsic, self.K).astype(np.float32)).permute(2, 0, 1), k = -1, dims = (1, 2)))
            assert point_map.shape[-2] == instance_map.shape[-2], f"Expected height {instance_map.shape[-2]}, got {point_map.shape[-2]}"
            fov_ratio = None
            # Keep pre-sampled frames (extra frames with empty masks were rejected above);
            # otherwise keep the frame only if > 25% of the instance's 3D points project
            # inside the first kept view (extrinsics[0]).
            if len(views) < pre_sampled_len or instance_map.sum().item() == 0 or \
                (fov_ratio := (in_fov_ratio(point_map[:, instance_map].permute(1, 0), extrinsics[0], K = self.K, W = self.ase_height, H = self.ase_width, ## for rot -90
                    W_crop = abs(int(self.ase_height) - original_instances[0].shape[-2]) // 2,
                    H_crop = abs(int(self.ase_width) - original_instances[0].shape[-1]) // 2)[0])) > 0.25:
                views.append(view)
                original_instances.append(instance_map)
                original_imgs.append(resize_func.transforms[0](TF.to_tensor(Image.open(self.frames[video][sample_idx]))))
                filelist.append(self.frames[video][sample_idx])
                extrinsics.append(extrinsic)
                depths.append(resize_func.transforms[0](torch.rot90(torch.from_numpy(depth), k = -1, dims = (0, 1))))
                point_maps.append(point_map)
                fov_ratios[self.frames[video][sample_idx]] = fov_ratio if fov_ratio is not None else -1
            else:
                sampled_indices.pop(len(views))
                continue
        sampled_indices = sampled_indices[:len(views)]
        if len(sampled_indices) < self.N:
            # NOTE(review): the log file handle is not closed explicitly; relies on GC.
            open(self._log_path, "a").write(f"[short_span] {video}: span={len(sampled_indices)} < N={self.N}\n")
            return self[np.random.randint(len(self))]
        assert len(sampled_indices) == self.N and sampled_indices[0] == first_frame_idx, f'Expected {self.N} sampled indices and first index {first_frame_idx}, got {len(sampled_indices)} with first index {sampled_indices[0]}'
        # Use cached MUSt3R features: always when not continuous, 80% of the time when
        # continuous with an unchanged size; otherwise they are omitted (None).
        if not is_continuous or (np.random.rand() < 0.8 and must3r_size == original_must3r_size):
            assert original_must3r_size == must3r_size, f'If not continuous, must3r size should not change, got {must3r_size} and {original_must3r_size}'
            must3r_feats_filelist = [self.must3r_feats[video][idx] for idx in sampled_indices]
            must3r_feats = [torch.load(must3r_filepath, map_location = 'cpu') for must3r_filepath in must3r_feats_filelist]
            # The last entry of each cached list is kept separately as the "head" features.
            must3r_feats_head = torch.cat([f[-1] for f in must3r_feats], dim = 0)
            must3r_feats = [f[:-1] for f in must3r_feats]
            must3r_feats = [torch.cat(f, dim = 0) for f in zip(*must3r_feats)]
            # Token sequences -> spatial maps on the 16-pixel patch grid of the first view.
            must3r_feats = [
                rearrange(f, 'b (h w) c -> b c h w', h = views[0]['true_shape'][0] // 16, w = views[0]['true_shape'][1] // 16)
                for f in must3r_feats
            ]
        else:
            assert is_continuous, f'If must3r size changed, should be continuous sampling, got {must3r_size} and {original_must3r_size}'
            must3r_feats = None
            must3r_feats_head = None
        original_instances = torch.stack(original_instances).squeeze()[:, None]
        instances = self.instance_transform(original_instances)
        assert instances[0].sum() > 0 and instances.ndim == 4, f'{instances.shape=}, {instances[0].sum()=}'
        original_imgs = torch.stack(original_imgs)
        imgs = self.image_transform(original_imgs)
        # if is_continuous:
        #     permutation = torch.arange(len(instances))
        # else:
        #     permutation = torch.argsort(instances.squeeze().sum(dim = (1, 2)), descending = True)
        # Keep the pre-sampled frames in place; shuffle the order of the remaining views.
        permutation = torch.arange(len(instances))
        permutation[pre_sampled_len:] = torch.randperm(len(instances) - pre_sampled_len) + pre_sampled_len
        return {
            'original_images': original_imgs[permutation],
            'images': imgs[permutation],
            'original_masks': original_instances[permutation],
            'masks': instances[permutation],
            'filelist': [filelist[idx] for idx in permutation],
            'must3r_views': [views[idx] for idx in permutation],
            'must3r_size': must3r_size,
            'video': video,
            'instance_id': instance_id,
            # NOTE(review): labelled 'scannetpp' although this is ASE - possibly so downstream
            # code treats it like ScanNet++; confirm before relying on this field.
            'dataset': 'scannetpp',
            'valid_masks': torch.ones_like(instances),
            'intrinsics': torch.from_numpy(self.K).unsqueeze(0).repeat(self.N, 1, 1)[permutation],
            'extrinsics': torch.from_numpy(np.stack(extrinsics, axis = 0))[permutation],
            'depths': torch.from_numpy(np.stack(depths, axis = 0))[permutation],
            'point_maps': torch.from_numpy(np.stack(point_maps, axis = 0))[permutation],
            'fov_ratios': fov_ratios,  # frame path -> in-FOV ratio, -1 when not computed
            'is_continuous': is_continuous
        } | (
            {
                'must3r_feats': [f[permutation] for f in must3r_feats],
                'must3r_feats_head': must3r_feats_head[permutation],
                'must3r_feats_filelist': [must3r_feats_filelist[idx] for idx in permutation],
            } if must3r_feats is not None else {}
        )
def pose_from_qwxyz_txyz(elems):
    """Return a 4x4 cam2world pose from a ``QW QX QY QZ TX TY TZ`` record.

    ``elems`` must hold exactly seven numbers or numeric strings; otherwise
    unpacking raises ``ValueError``. They describe a world-to-camera rigid
    transform: the rotation comes from the (w, x, y, z) quaternion and the
    translation is (tx, ty, tz). The inverse of that transform is returned.
    """
    qw, qx, qy, qz, tx, ty, tz = (float(v) for v in elems)
    # scipy expects scalar-last (x, y, z, w) quaternions.
    rotation = Rotation.from_quat([qx, qy, qz, qw]).as_matrix()
    world_to_cam = np.eye(4)
    world_to_cam[:3, :3] = rotation
    world_to_cam[:3, 3] = [tx, ty, tz]
    return np.linalg.inv(world_to_cam)  # returns cam2world
def depth_to_world_pointmap(depth, c2w, K, depth_type = 'range'):
    """Back-project a depth map into a world-space point map.

    Args:
        depth: (H, W) depth map, interpreted according to ``depth_type``.
        c2w: (4, 4) camera-to-world transform.
        K: (3, 3) camera intrinsics.
        depth_type: ``'range'`` (default) means depth is the distance along
            each pixel's ray, so rays are normalized to unit length before
            scaling. ``'z-buf'`` means depth is the camera-space Z coordinate,
            so rays ``K^-1 [u, v, 1]`` are scaled as-is.

    Returns:
        (H, W, 3) world xyz; NaN where depth is <= 0 or non-finite.

    Raises:
        ValueError: for an unknown ``depth_type`` (checked before any work).
    """
    if depth_type not in ('range', 'z-buf'):
        raise ValueError(f'Unknown depth_type {depth_type}')
    Kinv = np.linalg.inv(K)
    H_, W_ = depth.shape
    ys, xs = np.meshgrid(np.arange(H_), np.arange(W_), indexing='ij')
    ones = np.ones_like(xs, dtype=np.float64)
    pix = np.stack([xs, ys, ones], axis=-1).reshape(-1, 3).T  # (3, H*W) homogeneous (u, v, 1)
    rays_cam = Kinv @ pix  # (3, H*W)
    if depth_type == 'range':
        rays_cam = rays_cam / np.linalg.norm(rays_cam, axis = 0, keepdims = True)
    z = depth.reshape(-1)  # (H*W,)
    xyz_cam = rays_cam * z  # scale each ray by its depth
    xyz_cam_h = np.vstack([xyz_cam, np.ones_like(z)])  # (4, H*W)
    xyz_w = (c2w @ xyz_cam_h)[:3].T.reshape(H_, W_, 3)
    invalid = (depth <= 0) | ~np.isfinite(depth)
    xyz_w[invalid] = np.nan
    return xyz_w
def in_fov_ratio(points, c2w, K, H, W, H_crop, W_crop):
    """
    Fraction of world points that project inside the (cropped) image of a camera.

    points: (N,3) world coords. The caller in this file passes a torch tensor.
    c2w: (4,4) camera-to-world transform. It is inverted with np.linalg.inv,
        so a numpy-compatible array is expected.
    K: (3,3) pinhole intrinsics. Only K[0,0], K[1,1], K[0,2] and K[1,2] are used.
    H,W: image size in pixels.
    H_crop, W_crop: margins excluded on each side along v / u.

    Returns (ratio, mask). mask marks points with Z > 0 whose projection (u, v)
    satisfies W_crop <= u < W - W_crop and H_crop <= v < H - H_crop.
    ratio is mask.float().mean().

    NOTE(review): a torch `points` is multiplied by the numpy `w2c`. The result
    type relies on torch/numpy operator interop, and mask.float() requires a
    torch tensor - confirm before changing input types. An empty `points`
    presumably yields a NaN ratio.
    """
    # device = points.device
    K = K # .to(device)
    # world -> camera
    w2c = np.linalg.inv(c2w) # .to(device)
    Pc = (points @ w2c[:3, :3].T) + w2c[:3, 3]
    X, Y, Z = Pc[:,0], Pc[:,1], Pc[:,2]
    # projection
    u = K[0, 0] * (X / Z) + K[0, 2]
    v = K[1, 1] * (Y / Z) + K[1, 2]
    # In front of the camera and inside the image shrunk by the crop margins.
    mask = (Z > 0) & (u >= W_crop) & (u < W - W_crop) & (v >= H_crop) & (v < H - H_crop)
    return mask.float().mean(), mask
class ScanNetPPV2Dataset(Dataset):
def __init__(
    self,
    data_root: str,
    must3r_data_root: str = None,
    img_mean = (0.485, 0.456, 0.406),
    img_std = (0.229, 0.224, 0.225),
    N: int = 8,
    image_size: int = 1024,
    verbose = False,
    dataset_scale = 1,
    continuous_prob = 0,
    instance_classes_file = '<your path to scannetppv2>/metadata/semantic_benchmark/top100_instance.txt',
    split_file: str = '<your path to scannetppv2>/splits/nvs_sem_train.txt',
    excluding_scenes = ("09d6e808b4", "0f69aefe3d", "1b379f1114", "1cbb105c6a", "2c7c10379b", "46638cfd0f", "4f341f3af0", "6ef2ac745a", "898a7dfd0c", "aa852f7871", "eea4ad9c04", 'd27235711b'), ## horizontal / vertical flip issues
    valid_must3r_sizes = (224, 512),
):
    """Index ScanNet++ v2 iPhone scenes that have rendered instance masks.

    Args:
        data_root: directory with one sub-directory per scene.
        must3r_data_root: root of ``<scene>/must3r-features/*.pt``; defaults
            to ``data_root``.
        N: views per sample; scenes contribute ``len(frames) - N + 1`` indices.
        image_size: side of the square images/masks after the final resize.
        verbose: print why scenes are skipped.
        dataset_scale: multiplier applied by ``__len__``.
        continuous_prob: stored for ``__getitem__``.
        instance_classes_file: text file, one allowed instance class per line.
        split_file: text file, one scene name per line; other scenes are skipped.
        excluding_scenes: scene names that are always skipped.
        valid_must3r_sizes: MUSt3R resolutions that may be sampled.
    """
    def _read_text(path):
        # Read a whole text file and close it deterministically.
        with open(path) as fh:
            return fh.read()

    self.verbose = verbose
    self.data_root = data_root
    self.must3r_data_root = must3r_data_root if must3r_data_root is not None else data_root
    self.dataset_scale = dataset_scale
    # Copies, so instances never share (or mutate) the default sequences.
    self.excluding_scenes = list(excluding_scenes)
    self.instance_classes = _read_text(instance_classes_file).splitlines()
    self.valid_scene_names = _read_text(split_file).splitlines()
    self.continuous_prob = continuous_prob
    self.N = N
    self.image_transform = T.Compose([
        T.Resize((image_size, image_size), interpolation = T.InterpolationMode.NEAREST_EXACT),
        T.Normalize(mean = img_mean, std = img_std)
    ])
    self.instance_transform = T.Compose([
        T.Resize((image_size, image_size), interpolation = T.InterpolationMode.NEAREST_EXACT),
    ])
    self.valid_must3r_sizes = list(valid_must3r_sizes)
    self.videos = os.listdir(os.path.join(data_root))
    self.frames = {}
    self.masks = {}
    self.must3r_feats = {}
    self.appearances = {}
    self.id2label_name = {}
    self.intrinsics = {}
    self.extrinsics = {}  # scene -> path of COLMAP images.txt (parsed lazily)
    self.indices = []
    self._log_path = "./scannetppv2_dataset_resample.log"
    # Sets give O(1) membership tests inside the scan loop.
    valid_scenes = set(self.valid_scene_names)
    excluded_scenes = set(self.excluding_scenes)
    for video in tqdm(self.videos, desc = 'Loading ScanNet++V2 videos'):
        if video not in valid_scenes or video in excluded_scenes:
            if self.verbose:
                print(f"skip {video} as not in split or excluded")
            continue
        if not os.path.isdir(os.path.join(data_root, video)):
            print(f"skip {video} as not a directory")
            continue
        # Kept even though it is in the default exclusion list, in case callers override it.
        if video in ['46638cfd0f']:
            if self.verbose:
                print(f"skip {video} as broken")
            continue
        masks = sorted(glob(os.path.join(self.data_root, video, 'iphone', 'render_instance', '*.png')))
        if len(masks) == 0:
            if self.verbose:
                print(f"skip {video} as no masks found")
            continue
        frames = [m.replace('render_instance', 'rgb').replace('.png', '.jpg') for m in masks]
        must3r_feats = [m.replace(self.data_root, self.must3r_data_root).replace('iphone/render_instance', 'must3r-features').replace('.png', '.pt') for m in masks]
        # Only the first feature file is checked, to keep the scan fast.
        if not all([os.path.exists(p) for p in must3r_feats[:1]]):
            if self.verbose:
                print(f"skip {video} as not all must3r features or frames exist")
            continue
        self.frames[video] = frames
        self.masks[video] = masks
        self.must3r_feats[video] = must3r_feats
        self.appearances[video] = json.loads(_read_text(os.path.join(data_root, video, 'scans/instance-appearances.json')))
        self.intrinsics[video] = self.load_intrinsics(os.path.join(data_root, video, 'iphone', 'colmap', 'cameras.txt'))
        assert len(self.intrinsics[video]) == 1, f'Expected 1 camera, got {len(self.intrinsics[video])} in {video}'
        self.extrinsics[video] = os.path.join(data_root, video, 'iphone', 'colmap', 'images.txt')
        assert all([f_name == os.path.basename(m) for f_name, m in zip(self.appearances[video]['framenames'], self.masks[video])]), f'Frame names in appearances do not match masks in {video}'
        self.id2label_name[video] = json.loads(_read_text(os.path.join(data_root, video, 'scans/instance_id2label_name.json')))
        self.indices += [(video, idx) for idx in range(len(frames) - self.N + 1)]
    lengths = [len(fl) for fl in self.frames.values()]
    if lengths:
        print(f'Found {len(self.indices)} frames, and {len(self.frames)} videos, with min length {min(lengths)} and max length {max(lengths)}')
    else:
        # min()/max() of an empty list would raise here.
        print(f'Found 0 frames, and 0 videos under {data_root}')
def load_intrinsics(self, path):
    """Parse a COLMAP-style ``cameras.txt`` file into a camera-id -> parameters dict.

    Every non-comment, non-blank line is expected to read
    ``CAMERA_ID MODEL WIDTH HEIGHT PARAMS...`` separated by whitespace.

    Args:
        path: Path to the text file.

    Returns:
        dict mapping the integer camera id to ``[model, width, height, *params]``.
        ``model`` is kept as a string; every following field is parsed as ``float``.
    """
    intrinsics = {}
    with open(path, 'r') as f:
        for line in f:
            # split() (no separator) tolerates repeated spaces, tabs, trailing
            # whitespace and '\r'. split(' ') would yield '' tokens that break
            # int()/float().
            fields = line.split()
            # Skip the '#' header and blank lines rather than assuming the header is
            # exactly three lines long.
            if not fields or fields[0].startswith('#'):
                continue
            intrinsics[int(fields[0])] = [fields[1]] + [float(value) for value in fields[2:]]
    return intrinsics
def __len__(self):
return len(self.indices) * self.dataset_scale
def __getitem__(self, idx):
    """Build one sample: ``self.N`` frames of a ScanNet++ video that follow a single instance.

    ``idx`` is wrapped modulo ``len(self.indices)`` and selects a ``(video, start_frame)``
    pair; ``__len__`` repeats that index list ``dataset_scale`` times. When a usable
    sample cannot be built, the method returns ``self[random_index]`` instead. That
    happens when:
      * the video has no depth maps;
      * no candidate first frame holds an eligible instance;
      * no eligible instance covers more than 1% of the first frame;
      * fewer than ``self.N`` frames are accepted.

    Sampling modes:
      * Continuous. Chosen with probability ``self.continuous_prob``, or always when the
        stored MUSt3R feature resolution is not in ``self.valid_must3r_sizes``. The clip
        is the ``self.N`` consecutive frames starting at the first frame, and the MUSt3R
        input size is redrawn from ``self.valid_must3r_sizes``.
      * Otherwise, the clip starts with the first frame and its successor, ordered by
        instance-mask area (largest first). It is then topped up with randomly ordered
        frames, with frames listed in ``self.appearances`` for the instance's label
        coming first. A top-up frame is kept only if its instance mask is empty, or if
        the mask covers at least 1% of the image and its ``in_fov_ratio`` w.r.t. the
        first accepted frame's camera exceeds 0.25.

    Returns:
        dict containing:
          * images and masks, both as loaded and after ``self.image_transform`` /
            ``self.instance_transform``;
          * file names, MUSt3R views and the MUSt3R size;
          * intrinsics, extrinsics, depths and world point maps;
          * per-file FoV ratios (-1 when not computed);
          * the sampling mode;
          * when precomputed MUSt3R features are used, those features, their last
            (head) entry and their file list.
        Frames after the first ``pre_sampled_len`` are randomly permuted.
    """
    idx = idx % len(self.indices)
    video, idx = self.indices[idx]
    # Without depth maps no point maps can be built, so resample another item.
    # NOTE(review): this glob runs on every call. The recursive resampling used
    # throughout this method also has no depth limit.
    if len(glob(os.path.join(self.data_root, video, 'iphone/depth/*.png'))) == 0:
        return self[np.random.randint(len(self))]
    ## 1. Randomly shuffle frames
    # The indexed start frame comes first, followed by every other valid start frame
    # in random order.
    choices = np.delete(np.arange(len(self.frames[video]) - self.N + 1), idx)
    sampled_indices = [idx] + np.random.choice(choices, size = len(choices), replace = False).tolist()
    ## 2. Find unique instance IDs in the first frame
    # Drop candidate first frames until one contains an eligible instance. Eligible means:
    #   * id is not 0 or 65535;
    #   * label is in self.instance_classes;
    #   * label does not contain wall/floor/ceiling/window/curtain/blind/table.
    unique_ids = None
    while unique_ids is None or len(unique_ids) == 0:
        if unique_ids is not None:
            sampled_indices.pop(0)
            if len(sampled_indices) == 0:
                return self[np.random.randint(len(self))]
        unique_ids, _ = np.unique(np.array(Image.open(self.masks[video][sampled_indices[0]])), return_counts = True)
        unique_ids = unique_ids[np.array([class_id not in [0, 65535] and self.id2label_name[video][str(class_id)] in self.instance_classes and all([s not in self.id2label_name[video][str(class_id)].lower() for s in ['wall', 'floor', 'ceiling', 'window', 'curtain', 'blind', 'table']]) for class_id in unique_ids])]
    first_frame_idx = sampled_indices[0]
    assert len(unique_ids) > 0
    ## 3. Load the resize funcs of the first frame
    # The stored token count identifies the resolution the features were extracted at.
    # 196 tokens = (224 / 16) ** 2 patches, which maps to 224; anything else maps to 512.
    feat_len = torch.load(self.must3r_feats[video][first_frame_idx], map_location = 'cpu')[-1].shape[-2]
    must3r_size = original_must3r_size = (224 if feat_len == 196 else 512)
    is_continuous = (np.random.rand() < self.continuous_prob) or original_must3r_size not in self.valid_must3r_sizes
    if is_continuous:
        must3r_size = np.random.choice(self.valid_must3r_sizes).item()
    _, [resize_func] = load_images([self.frames[video][first_frame_idx]], size = must3r_size, patch_size = 16, verbose = self.verbose)
    # transforms[0] is applied below to native-resolution masks, depths and images.
    # Presumably it is a crop (see the H_crop / W_crop offsets further down).
    # transforms[1] yields the MUSt3R input size.
    # TODO confirm both against must3r.tools.image.get_resize_function.
    assert len(resize_func.transforms) == 2, f'Expected 2 transforms, got {len(resize_func.transforms)}'
    # assert resize_func.transforms[0].size[0] > resize_func.transforms[1].size[0], f'Expected first transform to be larger than second, got {resize_func.transforms[0].size} and {resize_func.transforms[1].size}'
    # assert resize_func.transforms[0].size[1] / resize_func.transforms[1].size[1] == resize_func.transforms[0].size[0] / resize_func.transforms[1].size[0], f'Expected aspect ratio to be preserved, got {resize_func.transforms[0].size} and {resize_func.transforms[1].size}'
    # If the size was not redrawn, the MUSt3R input must hold exactly feat_len 16x16
    # patches so the stored features line up with it.
    assert must3r_size != original_must3r_size or resize_func.transforms[1].size[0] * resize_func.transforms[1].size[1] == feat_len * 256, f'Expected {resize_func.transforms[1].size[0]}x{resize_func.transforms[1].size[1]} to be {feat_len * 256}, got {feat_len}'
    # Try eligible instances in random order and take the first one that covers more
    # than 1% of the first frame. The trailing None sentinel means none qualified.
    for instance_id in np.random.permutation(unique_ids).tolist() + [None]:
        if instance_id is None:
            return self[np.random.randint(len(self))]
        if (resize_func.transforms[0](torch.from_numpy(np.array(Image.open(self.masks[video][first_frame_idx]))) == instance_id)).sum() > (resize_func.transforms[0].size[0] * resize_func.transforms[0].size[1] * 0.01):
            break
    if is_continuous:
        # The self.N consecutive frames starting at the first frame.
        sampled_indices = np.arange(first_frame_idx, len(self.frames[video])).tolist()
        # sampled_indices += np.random.permutation(list(set(np.arange(len(self.frames[video])).tolist()) - set(self.appearances[video][str(instance_id)]) - set(sampled_indices))).tolist()
        sampled_indices = sampled_indices[:self.N]
        assert len(sampled_indices) == self.N and sampled_indices[0] == first_frame_idx, f'Expected {self.N} sampled indices and first index {first_frame_idx}, got {len(sampled_indices)} with first index {sampled_indices[0]}'
    else:
        # The first frame and its successor, larger instance mask first.
        sampled_indices = np.arange(first_frame_idx, len(self.frames[video])).tolist()[:2]
        sampled_indices = sorted(sampled_indices, key = lambda sample_idx: resize_func.transforms[0](torch.from_numpy(np.array(Image.open(self.masks[video][sample_idx]))) == instance_id).sum(), reverse = True) ## prioritize frames with larger masks
        first_frame_idx = sampled_indices[0]
    # Parse the per-image pose file.
    #   key:   image name, with everything up to 'iphone/' and 'video/' stripped;
    #   value: the tokens between the first and last, i.e. the pose followed by a
    #          camera id (consumed below as pose = value[:-1], camera id = value[-1]).
    # The file is closed explicitly instead of being left to garbage collection.
    with open(self.extrinsics[video], 'r') as extrinsics_file:
        raw_pose_lines = extrinsics_file.read().splitlines()
    raw_poses = {
        raw.split()[-1].split('iphone/')[-1].split('video/')[-1]: raw.split()[1:-1]
        for raw in raw_pose_lines if (not raw.startswith('#')) and len(raw.split()) > 0
    }
    views, original_instances, original_imgs, filelist, extrinsics, raw_intrinsics, intrinsics, depths, point_maps, fov_ratios = [], [], [], [], [], [], [], [], [], {}
    # Frames chosen so far are accepted unconditionally by the loop below.
    pre_sampled_len = len(sampled_indices)
    if len(sampled_indices) < self.N:
        # Top up the candidates in two shuffled groups:
        #   1. frames listed as appearances of the instance's label;
        #   2. all remaining frames.
        # NOTE(review): appearances is keyed by the label name here, whereas the
        # commented-out line above keys it by str(instance_id). Confirm which is intended.
        sampled_indices = sampled_indices + np.random.permutation(list(set(self.appearances[video][self.id2label_name[video][str(instance_id)]]) - set(sampled_indices))).tolist() + \
            np.random.permutation(list(set(np.arange(len(self.frames[video])).tolist()) - set(self.appearances[video][self.id2label_name[video][str(instance_id)]]) - set(sampled_indices))).tolist()
    # Walk the candidates. Rejected ones are popped, so sampled_indices[:len(views)]
    # always lists exactly the accepted frames.
    while len(views) < self.N and len(sampled_indices) >= self.N:
        sample_idx = sampled_indices[len(views)]
        [view], [resize_func] = load_images([self.frames[video][sample_idx]], size = must3r_size, patch_size = 16, verbose = self.verbose)
        instance_map = resize_func.transforms[0](torch.from_numpy(np.array(Image.open(self.masks[video][sample_idx])) == instance_id))
        # Beyond the pre-sampled frames, reject masks that are non-empty but cover
        # under 1% of the image.
        if len(views) >= pre_sampled_len and (0 < instance_map.sum() < instance_map.shape[-1] * instance_map.shape[-2] * 0.01):
            sampled_indices.pop(len(views))
            continue
        f_name = os.path.basename(self.frames[video][sample_idx])
        extrinsic = pose_from_qwxyz_txyz(raw_poses[f_name][:-1])
        # raw_intrinsic layout: [model, width, height, fx, fy, cx, cy, ...]
        raw_intrinsic = self.intrinsics[video][int(raw_poses[f_name][-1])]
        intrinsic = np.array([[raw_intrinsic[3], 0,                raw_intrinsic[5]],
                              [0,                raw_intrinsic[4], raw_intrinsic[6]],
                              [0,                0,                1               ]], dtype = np.float32)
        # Depth PNG resized to the camera's (width, height) and divided by 1000
        # (presumably millimetres -> metres).
        depth = np.array(Image.open(self.frames[video][sample_idx].replace('rgb', 'depth').replace('.jpg', '.png')).resize((int(raw_intrinsic[1]), int(raw_intrinsic[2]))), dtype = np.float32) / 1000.0
        point_map = resize_func.transforms[0](torch.from_numpy(depth_to_world_pointmap(depth, extrinsic, intrinsic).astype(np.float32)).permute(2, 0, 1))
        assert point_map.shape[-2] == instance_map.shape[-2] == int(raw_intrinsic[2]), f'Expected height {int(raw_intrinsic[2])}, got {point_map.shape[-2]} and {instance_map.shape[-2]}'
        fov_ratio = None
        # Pre-sampled frames and empty-mask frames are accepted outright. Any other
        # frame needs its instance points to have in_fov_ratio > 0.25 against the
        # first accepted frame's camera. The crop offsets account for the difference
        # between the raw camera size and the transforms[0] output size.
        if len(views) < pre_sampled_len or instance_map.sum().item() == 0 or \
            (fov_ratio := (in_fov_ratio(point_map[:, instance_map].permute(1, 0), extrinsics[0], K = intrinsics[0], H = int(raw_intrinsics[0][2]), W = int(raw_intrinsics[0][1]),
                                        H_crop = abs(int(raw_intrinsics[0][2]) - original_instances[0].shape[-2]) // 2,
                                        W_crop = abs(int(raw_intrinsics[0][1]) - original_instances[0].shape[-1]) // 2)[0])) > 0.25:
            views.append(view)
            original_instances.append(instance_map)
            original_imgs.append(resize_func.transforms[0](TF.to_tensor(Image.open(self.frames[video][sample_idx]))))
            filelist.append(self.frames[video][sample_idx])
            extrinsics.append(extrinsic)
            raw_intrinsics.append(raw_intrinsic)
            intrinsics.append(intrinsic)
            depths.append(resize_func.transforms[0](torch.from_numpy(depth)))
            point_maps.append(point_map)
            fov_ratios[self.frames[video][sample_idx]] = fov_ratio if fov_ratio is not None else -1
        else:
            sampled_indices.pop(len(views))
            continue
    sampled_indices = sampled_indices[:len(views)]
    if len(sampled_indices) < self.N:
        # Log the short clip and resample. The log file is closed (and flushed)
        # deterministically instead of relying on garbage collection.
        with open(self._log_path, "a") as log_file:
            log_file.write(f"[short_span] {video}: span={len(sampled_indices)} < N={self.N}\n")
        return self[np.random.randint(len(self))]
    assert len(sampled_indices) == self.N and sampled_indices[0] == first_frame_idx, f'Expected {self.N} sampled indices and first index {first_frame_idx}, got {len(sampled_indices)} with first index {sampled_indices[0]}'
    # Precomputed MUSt3R features are only valid at their original size. They are
    # always used in non-continuous mode, and 80% of the time in continuous mode
    # when the size was not changed.
    if not is_continuous or (np.random.rand() < 0.8 and must3r_size == original_must3r_size):
        assert original_must3r_size == must3r_size, f'If not continuous, must3r size should not change, got {must3r_size} and {original_must3r_size}'
        must3r_feats_filelist = [self.must3r_feats[video][idx] for idx in sampled_indices]
        must3r_feats = [torch.load(must3r_filepath, map_location = 'cpu') for must3r_filepath in must3r_feats_filelist]
        # The last entry of each stored list is kept separately as the "head" output;
        # the others are concatenated per position across frames.
        must3r_feats_head = torch.cat([f[-1] for f in must3r_feats], dim = 0)
        must3r_feats = [f[:-1] for f in must3r_feats]
        must3r_feats = [torch.cat(f, dim = 0) for f in zip(*must3r_feats)]
        # Tokens (b, h*w, c) -> feature maps (b, c, h, w) on the 16-px patch grid.
        must3r_feats = [
            rearrange(f, 'b (h w) c -> b c h w', h = views[0]['true_shape'][0] // 16, w = views[0]['true_shape'][1] // 16)
            for f in must3r_feats
        ]
    else:
        assert is_continuous, f'If must3r size changed, should be continuous sampling, got {must3r_size} and {original_must3r_size}'
        must3r_feats = None
        must3r_feats_head = None
    original_instances = torch.stack(original_instances).squeeze()[:, None]
    instances = self.instance_transform(original_instances)
    assert instances[0].sum() > 0 and instances.ndim == 4, f'{instances.shape=}, {instances[0].sum()=}'
    # assert instances[1:].sum() == 0, f"Only first frame should have the instance, got {instances.sum()=}"
    original_imgs = torch.stack(original_imgs)
    imgs = self.image_transform(original_imgs)
    # if is_continuous:
    #     permutation = torch.arange(len(instances))
    # else:
    #     permutation = torch.argsort(instances.squeeze().sum(dim = (1, 2)), descending = True)
    # Keep the pre-sampled frames in place and shuffle the rest.
    permutation = torch.arange(len(instances))
    permutation[pre_sampled_len:] = torch.randperm(len(instances) - pre_sampled_len) + pre_sampled_len
    return {
        'original_images': original_imgs[permutation],
        'images': imgs[permutation],
        'original_masks': original_instances[permutation],
        'masks': instances[permutation],
        'filelist': [filelist[idx] for idx in permutation],
        'must3r_views': [views[idx] for idx in permutation],
        'must3r_size': must3r_size,
        'video': video,
        'instance_id': instance_id,
        'dataset': 'scannetpp',
        'valid_masks': torch.ones_like(instances),
        'intrinsics': torch.from_numpy(np.stack(intrinsics, axis = 0))[permutation],
        'extrinsics': torch.from_numpy(np.stack(extrinsics, axis = 0))[permutation],
        'depths': torch.from_numpy(np.stack(depths, axis = 0))[permutation],
        'point_maps': torch.from_numpy(np.stack(point_maps, axis = 0))[permutation],
        'fov_ratios': fov_ratios,
        'is_continuous': is_continuous,
    } | (
        {
            'must3r_feats': [f[permutation] for f in must3r_feats],
            'must3r_feats_head': must3r_feats_head[permutation],
            'must3r_feats_filelist': [must3r_feats_filelist[idx] for idx in permutation],
        } if must3r_feats is not None else {}
    )