# LongStream/longstream/data/dataloader.py
# Cc · init · e340a84
import os
import glob
from dataclasses import dataclass
from typing import List, Dict, Any, Iterator, Optional, Tuple
import torch
from longstream.utils.vendor.dust3r.utils.image import load_images_for_eval
# Per-dataset defaults, keyed by dataset name.
#
# Keys that LongStreamDataLoader (in this file) reads from an entry:
#   img_path            default root directory of the dataset images
#   mask_path           default root directory of masks (None when absent)
#   dir_path_func       (img_path, seq) -> directory holding the frames of `seq`
#   mask_path_seq_func  (mask_path, seq) -> mask directory of `seq`, or None
#   full_seq            default for the loader's `full_seq` switch: when truthy
#                       the loader lists the sub-directories of img_path,
#                       otherwise it falls back to this entry's `seq_list`
#   seq_list            explicit sequence names (used when full_seq is falsy)
#
# The remaining keys (anno_path, gt_traj_func, traj_format, skip_condition,
# process_func) are not read by the code visible in this file; presumably
# they are consumed by evaluation code elsewhere — verify against callers.
dataset_metadata: Dict[str, Dict[str, Any]] = {
    "davis": {
        "img_path": "data/davis/DAVIS/JPEGImages/480p",
        "mask_path": "data/davis/DAVIS/masked_images/480p",
        "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq),
        # No ground-truth trajectory is declared for this dataset.
        "gt_traj_func": lambda img_path, anno_path, seq: None,
        "traj_format": None,
        "seq_list": None,
        "full_seq": True,
        "mask_path_seq_func": lambda mask_path, seq: os.path.join(mask_path, seq),
        "skip_condition": None,
        "process_func": None,
    },
    "kitti": {
        "img_path": "data/kitti/sequences",
        "anno_path": "data/kitti/poses",
        "mask_path": None,
        # Frames are read from the "image_2" sub-folder of each sequence.
        "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq, "image_2"),
        # Yields "<anno_path>/<seq>.txt" only if that file exists, else None.
        "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
            anno_path, f"{seq}.txt"
        )
        if os.path.exists(os.path.join(anno_path, f"{seq}.txt"))
        else None,
        "traj_format": "kitti",
        "seq_list": ["00", "01", "02", "03", "04", "05", "06", "07", "08", "09", "10"],
        "full_seq": True,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": None,
    },
    "bonn": {
        "img_path": "data/bonn/rgbd_bonn_dataset",
        "mask_path": None,
        # Sequence folders are prefixed with "rgbd_bonn_"; frames live in "rgb_110".
        "dir_path_func": lambda img_path, seq: os.path.join(
            img_path, f"rgbd_bonn_{seq}", "rgb_110"
        ),
        "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
            img_path, f"rgbd_bonn_{seq}", "groundtruth_110.txt"
        ),
        "traj_format": "tum",
        "seq_list": ["balloon2", "crowd2", "crowd3", "person_tracking2", "synchronous"],
        "full_seq": False,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": None,
    },
    # This entry only declares paths; the loader falls back to its own
    # defaults for every key that is missing here.
    "nyu": {
        "img_path": "data/nyu-v2/val/nyu_images",
        "mask_path": None,
        "process_func": None,
    },
    "scannet": {
        "img_path": "data/scannetv2",
        "mask_path": None,
        "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq, "color_90"),
        "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
            img_path, seq, "pose_90.txt"
        ),
        "traj_format": "replica",
        "seq_list": None,
        "full_seq": True,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": None,
    },
    "tum": {
        "img_path": "data/tum",
        "mask_path": None,
        "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq, "rgb_90"),
        "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
            img_path, seq, "groundtruth_90.txt"
        ),
        "traj_format": "tum",
        "seq_list": None,
        "full_seq": True,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": None,
    },
    "sintel": {
        "img_path": "data/sintel/training/final",
        "anno_path": "data/sintel/training/camdata_left",
        "mask_path": None,
        "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq),
        # Here the "trajectory" is a per-sequence directory under anno_path.
        "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(anno_path, seq),
        "traj_format": None,
        "seq_list": [
            "alley_2",
            "ambush_4",
            "ambush_5",
            "ambush_6",
            "cave_2",
            "cave_4",
            "market_2",
            "market_5",
            "market_6",
            "shaman_3",
            "sleeping_1",
            "sleeping_2",
            "temple_2",
            "temple_3",
        ],
        "full_seq": False,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": None,
    },
    "waymo": {
        # NOTE(review): absolute, environment-specific default; override it
        # through the loader's `img_path` config key on other machines.
        "img_path": "/horizon-bucket/saturn_v_4dlabel/004_vision/01_users/tao02.xie/datasets/scatt3r_evaluation/waymo_open_dataset_v1_4_3",
        "anno_path": None,
        "mask_path": None,
        # A sequence name may carry a "_cam<id>" suffix: the text before
        # "_cam" selects the scene folder and the text after it selects the
        # camera sub-folder under "images"; without a suffix camera "00" is used.
        "dir_path_func": lambda img_path, seq: os.path.join(
            img_path,
            seq.split("_cam")[0] if "_cam" in seq else seq,
            "images",
            seq.split("_cam")[1] if "_cam" in seq else "00",
        ),
        # Same scene/camera split as above, pointing at "cameras/<id>/extri.yml".
        "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
            img_path,
            seq.split("_cam")[0] if "_cam" in seq else seq,
            "cameras",
            seq.split("_cam")[1] if "_cam" in seq else "00",
            "extri.yml",
        ),
        "traj_format": "waymo",
        "seq_list": None,
        "full_seq": True,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": None,
    },
}
@dataclass
class LongStreamSequenceInfo:
    """Describe one sequence on disk without loading any pixels.

    Instances are produced by ``LongStreamDataLoader.iter_sequence_infos``.
    ``camera`` defaults to ``None`` so that the constructor mirrors
    ``LongStreamSequence``, whose ``camera`` argument is optional as well.
    """

    # Short display name of the sequence.
    name: str
    # Directory treated as the root of the scene.
    scene_root: str
    # Directory the frame files were collected from.
    image_dir: str
    # Frame files selected for this sequence, in load order.
    image_paths: List[str]
    # Camera sub-folder name when one could be derived, otherwise None.
    camera: Optional[str] = None
class LongStreamSequence:
    """A sequence whose frames have already been loaded into a tensor.

    Attributes simply mirror the constructor arguments:
        name: display name of the sequence.
        images: tensor holding the loaded frames (layout is whatever the
            producer passed in; not inspected here).
        image_paths: files the frames were read from.
        scene_root / image_dir / camera: optional provenance information.
    """

    def __init__(
        self,
        name: str,
        images: torch.Tensor,
        image_paths: List[str],
        scene_root: Optional[str] = None,
        image_dir: Optional[str] = None,
        camera: Optional[str] = None,
    ):
        # Required payload.
        self.name, self.images, self.image_paths = name, images, image_paths
        # Optional provenance, kept as given.
        self.scene_root, self.image_dir, self.camera = scene_root, image_dir, camera
def _read_list_file(path: str) -> List[str]:
    """Return the meaningful entries of a plain-text list file.

    Each line is stripped of surrounding whitespace; lines that are empty
    after stripping, or that start with ``#``, are dropped. Order is kept.
    """
    with open(path, "r") as handle:
        stripped = (raw.strip() for raw in handle)
        return [entry for entry in stripped if entry and not entry.startswith("#")]
def _is_generalizable_scene_root(path: str) -> bool:
    """Tell whether ``path`` directly contains an ``images`` directory."""
    images_dir = os.path.join(path, "images")
    return os.path.isdir(images_dir)
def _direct_image_files(dir_path: str) -> List[str]:
    """List image files located directly inside ``dir_path``.

    Extensions are tried in the order ``.png``, ``.jpg``, ``.jpeg``; the
    sorted matches of the first extension that yields any file are returned
    (extensions are never mixed). An empty list means nothing matched.
    """
    for pattern in ("*.png", "*.jpg", "*.jpeg"):
        matches = sorted(glob.glob(os.path.join(dir_path, pattern)))
        if matches:
            return matches
    return []
class LongStreamDataLoader:
    """Enumerate image sequences on disk and load them as tensors.

    Two directory layouts ("formats") are supported:

    * ``"relpose"`` – sequences are sub-directories of ``img_path`` and the
      frame directory of each one is derived by the dataset's
      ``dir_path_func`` from ``dataset_metadata``.
    * ``"generalizable"`` – every scene root contains an ``images`` folder,
      optionally with one sub-folder per camera; scene roots may be listed
      in a ``*data_roots.txt`` file below ``img_path``.

    Recognised ``cfg`` keys: ``dataset``, ``img_path``, ``mask_path``,
    ``full_seq``, ``seq_list``, ``stride``, ``max_frames``, ``size``,
    ``crop``, ``patch_size``, ``format``, ``data_roots_file``, ``split``
    and ``camera``. Values missing from ``cfg`` fall back to the entry of
    ``dataset_metadata`` for ``dataset`` where one exists.
    """

    def __init__(self, cfg: Dict[str, Any]):
        """Read the configuration; nothing on disk is touched here."""
        self.cfg = cfg
        self.dataset = cfg.get("dataset", None)
        meta = dataset_metadata.get(self.dataset, {})
        self.img_path = cfg.get("img_path", meta.get("img_path"))
        self.mask_path = cfg.get("mask_path", meta.get("mask_path"))
        self.dir_path_func = meta.get("dir_path_func", lambda p, s: os.path.join(p, s))
        self.mask_path_seq_func = meta.get("mask_path_seq_func", lambda p, s: None)
        self.full_seq = bool(cfg.get("full_seq", meta.get("full_seq", True)))
        self.seq_list = cfg.get("seq_list", None)
        self.stride = int(cfg.get("stride", 1))
        # max_frames is used as a slice bound, so coerce it like the other
        # numeric options; None keeps meaning "no limit".
        max_frames = cfg.get("max_frames", None)
        self.max_frames = None if max_frames is None else int(max_frames)
        self.size = int(cfg.get("size", 518))
        self.crop = bool(cfg.get("crop", False))
        self.patch_size = int(cfg.get("patch_size", 14))
        self.format = cfg.get("format", "auto")
        self.data_roots_file = cfg.get("data_roots_file", None)
        self.split = cfg.get("split", None)
        self.camera = cfg.get("camera", None)

    def _infer_format(self) -> str:
        """Return ``"relpose"`` or ``"generalizable"``.

        An explicit ``format`` setting wins. Otherwise the layout is
        "generalizable" when ``img_path`` is itself a scene root or contains
        the configured (default ``data_roots.txt``) list file, and
        "relpose" in every other case, including a missing ``img_path``.
        """
        if self.format in ["relpose", "generalizable"]:
            return self.format
        if self.img_path is None:
            return "relpose"
        if _is_generalizable_scene_root(self.img_path):
            return "generalizable"
        default_list = self.data_roots_file or "data_roots.txt"
        if os.path.exists(os.path.join(self.img_path, default_list)):
            return "generalizable"
        return "relpose"

    def _resolve_seq_list_generalizable(self) -> List[str]:
        """Return the scene entries to visit for the generalizable layout.

        Precedence: explicit ``seq_list``; ``img_path`` itself when it is a
        scene root; the first existing list file among the candidates; and
        finally every directory below ``img_path`` that contains an
        ``images`` folder (as paths relative to ``img_path``).
        """
        if self.seq_list is not None:
            return list(self.seq_list)
        if self.img_path is None or not os.path.isdir(self.img_path):
            return []
        if _is_generalizable_scene_root(self.img_path):
            return [self.img_path]

        # Candidate list files, most specific first.
        candidates = []
        if isinstance(self.data_roots_file, str) and self.data_roots_file:
            candidates.append(self.data_roots_file)
        if isinstance(self.split, str) and self.split:
            split_name = self.split.lower()
            if split_name in ["val", "valid", "validate"]:
                split_name = "validate"
            candidates.append(f"{split_name}_data_roots.txt")
        candidates.append("data_roots.txt")
        candidates.append("train_data_roots.txt")
        candidates.append("validate_data_roots.txt")
        for fname in candidates:
            path = os.path.join(self.img_path, fname)
            if os.path.exists(path):
                return _read_list_file(path)

        # No list file: discover scene roots by looking for "images" folders.
        img_dirs = sorted(
            glob.glob(os.path.join(self.img_path, "**", "images"), recursive=True)
        )
        scene_roots = [os.path.dirname(p) for p in img_dirs]
        rels = []
        for p in scene_roots:
            try:
                rels.append(os.path.relpath(p, self.img_path))
            except ValueError:
                # relpath can fail (e.g. different drives); keep the path as is.
                rels.append(p)
        return sorted(set(rels))

    def _resolve_seq_list_relpose(self) -> List[str]:
        """Return the sequence names to visit for the relpose layout.

        Precedence: explicit ``seq_list``; with ``full_seq`` every
        sub-directory of ``img_path`` (sorted); otherwise the ``seq_list``
        declared for the dataset in ``dataset_metadata``.
        """
        if self.seq_list is not None:
            return list(self.seq_list)
        meta = dataset_metadata.get(self.dataset, {})
        if self.full_seq:
            if self.img_path is None or not os.path.isdir(self.img_path):
                return []
            seqs = [
                s
                for s in os.listdir(self.img_path)
                if os.path.isdir(os.path.join(self.img_path, s))
            ]
            return sorted(seqs)
        seqs = meta.get("seq_list", []) or []
        return list(seqs)

    def _resolve_seq_list(self) -> List[str]:
        """Dispatch to the sequence resolver of the inferred format."""
        fmt = self._infer_format()
        if fmt == "generalizable":
            return self._resolve_seq_list_generalizable()
        return self._resolve_seq_list_relpose()

    def _resolve_scene_root(self, seq_entry: str) -> Tuple[str, str]:
        """Map a sequence entry to ``(name, scene_root)``.

        Absolute entries are used verbatim. A relative entry has two
        possible readings: taken as written, or joined onto ``img_path``.
        Entries containing a path separator prefer the former and bare
        names the latter; the other reading is used only when the preferred
        one is not an existing directory while the other one is. This lets
        ``img_path``-relative nested entries (as produced by the directory
        scan in ``_resolve_seq_list_generalizable``) and a bare relative
        ``img_path`` that is itself the scene root resolve correctly.

        ``name`` is the last path component for path-like entries and the
        entry itself for bare names.
        """
        is_absolute = os.path.isabs(seq_entry)
        path_like = is_absolute or os.path.sep in seq_entry
        if path_like:
            name = os.path.basename(os.path.normpath(seq_entry))
        else:
            name = seq_entry
        if is_absolute:
            return name, seq_entry

        # Without an img_path there is nothing to join onto.
        if self.img_path is None:
            joined = seq_entry
        else:
            joined = os.path.join(self.img_path, seq_entry)

        if path_like:
            preferred, alternative = seq_entry, joined
        else:
            preferred, alternative = joined, seq_entry
        if not os.path.isdir(preferred) and os.path.isdir(alternative):
            return name, alternative
        return name, preferred

    def _resolve_image_dir_generalizable(self, scene_root: str) -> Optional[str]:
        """Pick the frame directory inside ``<scene_root>/images``.

        Returns ``None`` when there is no ``images`` folder or it holds
        neither images nor sub-directories. Otherwise, in order:

        1. the configured camera sub-folder, if it exists;
        2. ``images`` itself, if it directly contains image files;
        3. ``images`` itself, if it has more than 10 sub-folders and every
           one of them directly contains exactly one image (this looks like
           a one-folder-per-frame layout, which ``_collect_filelist``
           flattens);
        4. the alphabetically first sub-folder.
        """
        images_root = os.path.join(scene_root, "images")
        if not os.path.isdir(images_root):
            return None
        if isinstance(self.camera, str) and self.camera:
            cam_dir = os.path.join(images_root, self.camera)
            if os.path.isdir(cam_dir):
                return cam_dir
        if _direct_image_files(images_root):
            return images_root
        cams = [
            d
            for d in os.listdir(images_root)
            if os.path.isdir(os.path.join(images_root, d))
        ]
        if not cams:
            return None
        cams = sorted(cams)
        # (sub-folder name, number of images directly inside it), only for
        # sub-folders that contain at least one image.
        frame_dirs = []
        for name in cams:
            child_dir = os.path.join(images_root, name)
            child_images = _direct_image_files(child_dir)
            if child_images:
                frame_dirs.append((name, len(child_images)))
        if (
            len(cams) > 10
            and len(frame_dirs) == len(cams)
            and max(count for _, count in frame_dirs) == 1
        ):
            return images_root
        return os.path.join(images_root, cams[0])

    def _camera_from_image_dir(self, image_dir: str) -> Optional[str]:
        """Return the folder name of ``image_dir`` if its parent is ``images``."""
        parent = os.path.basename(os.path.dirname(image_dir))
        if parent != "images":
            return None
        return os.path.basename(image_dir)

    def _collect_filelist(self, dir_path: str) -> List[str]:
        """Return the frame files of ``dir_path`` after stride and limit.

        If ``dir_path`` holds no images directly, the first image of each
        sub-directory (sorted) is taken instead. ``stride`` sub-samples the
        list and ``max_frames`` truncates the result.
        """
        filelist = _direct_image_files(dir_path)
        if not filelist:
            nested = []
            child_dirs = sorted(
                d for d in glob.glob(os.path.join(dir_path, "*")) if os.path.isdir(d)
            )
            for child_dir in child_dirs:
                child_images = _direct_image_files(child_dir)
                if child_images:
                    nested.append(child_images[0])
            filelist = nested
        if self.stride > 1:
            filelist = filelist[:: self.stride]
        if self.max_frames is not None:
            filelist = filelist[: self.max_frames]
        return filelist

    def _load_images(self, filelist: List[str]) -> torch.Tensor:
        """Load ``filelist`` via ``load_images_for_eval`` into one tensor.

        The per-view ``"img"`` tensors are concatenated along dim 0, a new
        leading dimension is added, and values are mapped with
        ``(x + 1) / 2``.
        """
        views = load_images_for_eval(
            filelist,
            size=self.size,
            verbose=False,
            crop=self.crop,
            patch_size=self.patch_size,
        )
        imgs = torch.cat([view["img"] for view in views], dim=0)
        images = imgs.unsqueeze(0)
        # presumably rescales [-1, 1] to [0, 1] — TODO confirm the value
        # range returned by load_images_for_eval.
        images = (images + 1.0) / 2.0
        return images

    def iter_sequence_infos(self) -> "Iterator[LongStreamSequenceInfo]":
        """Yield a ``LongStreamSequenceInfo`` per usable sequence.

        Sequences whose frame directory is missing, or which contain no
        image files, are skipped. In the relpose layout nothing is yielded
        when ``img_path`` is unset, because every path is derived from it.
        """
        fmt = self._infer_format()
        is_generalizable = fmt == "generalizable"
        if not is_generalizable and self.img_path is None:
            return
        seqs = self._resolve_seq_list()
        for seq_entry in seqs:
            if is_generalizable:
                seq, scene_root = self._resolve_scene_root(seq_entry)
                dir_path = self._resolve_image_dir_generalizable(scene_root)
                if dir_path is None or not os.path.isdir(dir_path):
                    continue
                camera = self._camera_from_image_dir(dir_path)
            else:
                seq = seq_entry
                scene_root = os.path.join(self.img_path, seq)
                dir_path = self.dir_path_func(self.img_path, seq)
                if not os.path.isdir(dir_path):
                    continue
                camera = None
            filelist = self._collect_filelist(dir_path)
            if not filelist:
                continue
            yield LongStreamSequenceInfo(
                name=seq,
                scene_root=scene_root,
                image_dir=dir_path,
                image_paths=filelist,
                camera=camera,
            )

    def __iter__(self) -> "Iterator[LongStreamSequence]":
        """Load and yield each sequence, printing progress before and after."""
        for info in self.iter_sequence_infos():
            print(
                f"[longstream] loading sequence {info.name}: {len(info.image_paths)} frames",
                flush=True,
            )
            images = self._load_images(info.image_paths)
            print(
                f"[longstream] loaded sequence {info.name}: {tuple(images.shape)}",
                flush=True,
            )
            yield LongStreamSequence(
                info.name,
                images,
                info.image_paths,
                scene_root=info.scene_root,
                image_dir=info.image_dir,
                camera=info.camera,
            )