# 3d_model/ylff/services/preprocessed_dataset.py
# Azan
# Clean deployment build (Squashed)
# 7a87926
"""
Preprocessed Dataset: Load pre-computed oracle results for training.
This dataset loads pre-processed results (BA, oracle uncertainty) from cache,
enabling fast training iteration without expensive BA computation.
"""
import logging
from pathlib import Path
from typing import Dict, List, Optional
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
from .preprocessing import load_preprocessed_sample
logger = logging.getLogger(__name__)
class PreprocessedARKitDataset(Dataset):
    """
    Dataset that loads pre-computed oracle results for training.

    This dataset loads pre-processed results from the preprocessing phase, enabling
    fast training iteration without expensive BA computation. All oracle targets
    (BA poses, LiDAR depth) and uncertainty results are pre-computed and cached.

    Key Features:
    - Fast loading: No BA computation during training (100-1000x faster)
    - Uncertainty-aware: Includes confidence maps for loss weighting
    - Lazy image loading: Can load images on-demand or pre-load into memory

    Dataset Structure:
    Each sample contains:
    - 'images': [N, C, H, W] float tensor in [0, 1], or None when
      load_images=False (then 'image_paths' is provided instead)
    - 'oracle_targets': Dict with:
        - 'poses': [N, 3, 4] camera poses (w2c) from BA or ARKit
        - 'depth': [N, H, W] depth maps from LiDAR or BA (optional)
    - 'uncertainty_results': Dict with:
        - 'pose_confidence': [N] per-frame pose confidence
        - 'depth_confidence': [N, H, W] per-pixel depth confidence
        - 'collective_confidence': [N] overall sequence confidence
    - 'sequence_id': str identifier for the sequence
    - 'metadata': Dict with sequence metadata

    Example:
        >>> from ylff.services.preprocessed_dataset import PreprocessedARKitDataset
        >>> from torch.utils.data import DataLoader
        >>>
        >>> dataset = PreprocessedARKitDataset(
        ...     cache_dir=Path("cache/preprocessed"),
        ...     arkit_sequences_dir=Path("data/arkit_sequences"),
        ...     load_images=True,
        ... )
        >>>
        >>> dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
        >>> for batch in dataloader:
        ...     images = batch['images']
        ...     oracle_targets = batch['oracle_targets']
        ...     uncertainty_results = batch['uncertainty_results']
    """

    # Files a cache subdirectory must contain to count as a pre-processed sequence.
    REQUIRED_FILES = ("oracle_targets.npz", "uncertainty_results.npz", "metadata.json")

    def __init__(
        self,
        cache_dir: Path,
        arkit_sequences_dir: Optional[Path] = None,
        load_images: bool = True,
    ):
        """
        Args:
            cache_dir: Directory containing pre-processed results
            arkit_sequences_dir: Optional directory with original ARKit sequences
                (for loading images if not cached)
            load_images: If True, load images into memory; if False, return paths
        """
        self.cache_dir = Path(cache_dir)
        self.arkit_sequences_dir = Path(arkit_sequences_dir) if arkit_sequences_dir else None
        self.load_images = load_images
        # Discover all usable sequences up front so __len__ is cheap.
        self.sequences = self._find_preprocessed_sequences()
        if len(self.sequences) == 0:
            logger.warning(f"No pre-processed sequences found in {cache_dir}")
        else:
            logger.info(f"Found {len(self.sequences)} pre-processed sequences")

    def _find_preprocessed_sequences(self) -> List[str]:
        """Return sorted names of cache subdirectories containing all required files."""
        if not self.cache_dir.exists():
            return []
        return sorted(
            item.name
            for item in self.cache_dir.iterdir()
            if item.is_dir()
            and all((item / fname).exists() for fname in self.REQUIRED_FILES)
        )

    def __len__(self) -> int:
        """Number of pre-processed sequences discovered in the cache."""
        return len(self.sequences)

    def __getitem__(self, idx: int) -> Dict:
        """
        Load one pre-processed sequence as tensors.

        Args:
            idx: Index into the sorted list of discovered sequences.

        Returns:
            Dict with keys 'images', 'oracle_targets', 'uncertainty_results',
            'sequence_id', 'metadata' (and 'image_paths' when load_images=False).

        Raises:
            ValueError: If the cached sample cannot be loaded, or no frames
                could be decoded for the sequence.
        """
        sequence_id = self.sequences[idx]
        # Load pre-processed results from the cache.
        sample = load_preprocessed_sample(self.cache_dir, sequence_id)
        if sample is None:
            raise ValueError(f"Failed to load pre-processed sample: {sequence_id}")

        if self.load_images:
            images = self._load_images_for_sequence(sequence_id)
            if not images:
                # torch.stack([]) raises an opaque RuntimeError; fail with context.
                raise ValueError(f"No frames decoded for sequence: {sequence_id}")
            # uint8 HWC -> float CHW in [0, 1].
            images_tensor = torch.stack(
                [torch.from_numpy(img).permute(2, 0, 1).float() / 255.0 for img in images]
            )
        else:
            # NOTE(review): default DataLoader collate cannot batch None — callers
            # using load_images=False presumably supply a custom collate_fn.
            images_tensor = None
            image_paths = self._get_image_paths_for_sequence(sequence_id)

        oracle_targets = {
            "poses": torch.from_numpy(sample["oracle_targets"]["poses"]).float(),
        }
        # .get() guards against caches written without a 'depth' entry at all;
        # a (1, 1, 1) array is the "no depth" placeholder and is skipped.
        depth_array = sample["oracle_targets"].get("depth")
        if depth_array is not None and depth_array.shape != (1, 1, 1):
            oracle_targets["depth"] = torch.from_numpy(depth_array).float()

        # Mandatory confidence arrays.
        uncertainty_results = {
            key: torch.from_numpy(sample["uncertainty_results"][key]).float()
            for key in ("pose_confidence", "depth_confidence", "collective_confidence")
        }
        # Optional uncertainty tensors (present only for some preprocessing configs).
        for key in ("pose_uncertainty", "depth_uncertainty"):
            if key in sample["uncertainty_results"]:
                uncertainty_results[key] = torch.from_numpy(
                    sample["uncertainty_results"][key]
                ).float()

        result = {
            "images": images_tensor,
            "oracle_targets": oracle_targets,
            "uncertainty_results": uncertainty_results,
            "sequence_id": sequence_id,
            "metadata": sample.get("metadata", {}),
        }
        if not self.load_images:
            result["image_paths"] = image_paths
        return result

    def _load_images_for_sequence(self, sequence_id: str) -> List[np.ndarray]:
        """
        Decode every frame of the sequence's video into an RGB uint8 array.

        Args:
            sequence_id: Name of the sequence directory to look up.

        Returns:
            List of (H, W, 3) RGB arrays, one per decoded frame.

        Raises:
            ValueError: If arkit_sequences_dir was not configured.
            FileNotFoundError: If the sequence directory or its video is missing.
            IOError: If the video file exists but cannot be opened.
        """
        if self.arkit_sequences_dir is None:
            raise ValueError("arkit_sequences_dir required when load_images=True")
        # Recursive search to handle nested folder layouts: we want a directory
        # named <sequence_id> that contains a 'videos' subfolder.
        found_dirs = list(self.arkit_sequences_dir.rglob(f"*/{sequence_id}/videos"))
        if not found_dirs:
            # Fallback: the sequence may sit directly under arkit_sequences_dir.
            found_dirs = list(self.arkit_sequences_dir.rglob(f"{sequence_id}/videos"))
        if not found_dirs:
            raise FileNotFoundError(
                f"Sequence directory with 'videos' subfolder for '{sequence_id}' not found in {self.arkit_sequences_dir}"
            )
        videos_dir = found_dirs[0]

        # Sort so the selected video is deterministic; glob order is
        # filesystem-dependent.
        video_files = sorted(videos_dir.glob("*.MOV")) + sorted(videos_dir.glob("*.mov"))
        if not video_files:
            raise FileNotFoundError(f"No video file found in {videos_dir}")
        video_path = video_files[0]
        logger.info(f"Loading images from {video_path}")

        cap = cv2.VideoCapture(str(video_path))
        if not cap.isOpened():
            # Without this check a corrupt file would silently yield zero frames.
            raise IOError(f"Failed to open video: {video_path}")
        images = []
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                # OpenCV decodes BGR; downstream consumers expect RGB.
                images.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        finally:
            # Always release the capture handle, even if decoding raises.
            cap.release()
        return images

    def _get_image_paths_for_sequence(self, sequence_id: str) -> List[Path]:
        """
        Get image paths for a sequence (lazy loading).

        Args:
            sequence_id: Name of the sequence directory.

        Returns:
            List containing the sequence directory path (images are decoded from
            the video inside it by the consumer).

        Raises:
            ValueError: If arkit_sequences_dir was not configured.
        """
        if self.arkit_sequences_dir is None:
            raise ValueError("arkit_sequences_dir required when load_images=False")
        sequence_dir = self.arkit_sequences_dir / sequence_id
        # For now, return the sequence directory (images are loaded from video).
        return [sequence_dir]