""" Preprocessed Dataset: Load pre-computed oracle results for training. This dataset loads pre-processed results (BA, oracle uncertainty) from cache, enabling fast training iteration without expensive BA computation. """ import logging from pathlib import Path from typing import Dict, List, Optional import cv2 import numpy as np import torch from torch.utils.data import Dataset from .preprocessing import load_preprocessed_sample logger = logging.getLogger(__name__) class PreprocessedARKitDataset(Dataset): """ Dataset that loads pre-computed oracle results for training. This dataset loads pre-processed results from the preprocessing phase, enabling fast training iteration without expensive BA computation. All oracle targets (BA poses, LiDAR depth) and uncertainty results are pre-computed and cached. Key Features: - Fast loading: No BA computation during training (100-1000x faster) - Uncertainty-aware: Includes confidence maps for loss weighting - Lazy image loading: Can load images on-demand or pre-load into memory Dataset Structure: Each sample contains: - 'images': List of image arrays (H, W, 3) uint8 or tensor [N, C, H, W] - 'oracle_targets': Dict with: - 'poses': [N, 3, 4] camera poses (w2c) from BA or ARKit - 'depth': [N, H, W] depth maps from LiDAR or BA (optional) - 'uncertainty_results': Dict with: - 'pose_confidence': [N] per-frame pose confidence - 'depth_confidence': [N, H, W] per-pixel depth confidence - 'collective_confidence': [N] overall sequence confidence - 'sequence_id': str identifier for the sequence - 'metadata': Dict with sequence metadata Example: >>> from ylff.services.preprocessed_dataset import PreprocessedARKitDataset >>> from torch.utils.data import DataLoader >>> >>> dataset = PreprocessedARKitDataset( ... cache_dir=Path("cache/preprocessed"), ... arkit_sequences_dir=Path("data/arkit_sequences"), ... load_images=True, ... ) >>> >>> dataloader = DataLoader(dataset, batch_size=1, shuffle=True) >>> for batch in dataloader: ... images = batch['images'] ... oracle_targets = batch['oracle_targets'] ... uncertainty_results = batch['uncertainty_results'] """ def __init__( self, cache_dir: Path, arkit_sequences_dir: Optional[Path] = None, load_images: bool = True, ): """ Args: cache_dir: Directory containing pre-processed results arkit_sequences_dir: Optional directory with original ARKit sequences (for loading images if not cached) load_images: If True, load images into memory; if False, return paths """ self.cache_dir = Path(cache_dir) self.arkit_sequences_dir = Path(arkit_sequences_dir) if arkit_sequences_dir else None self.load_images = load_images # Find all pre-processed sequences self.sequences = self._find_preprocessed_sequences() if len(self.sequences) == 0: logger.warning(f"No pre-processed sequences found in {cache_dir}") else: logger.info(f"Found {len(self.sequences)} pre-processed sequences") def _find_preprocessed_sequences(self) -> List[str]: """Find all pre-processed sequences in cache directory.""" sequences = [] if not self.cache_dir.exists(): return sequences for item in self.cache_dir.iterdir(): if item.is_dir(): # Check if it has required files oracle_targets_file = item / "oracle_targets.npz" uncertainty_file = item / "uncertainty_results.npz" metadata_file = item / "metadata.json" if ( oracle_targets_file.exists() and uncertainty_file.exists() and metadata_file.exists() ): sequences.append(item.name) return sorted(sequences) def __len__(self) -> int: return len(self.sequences) def __getitem__(self, idx: int) -> Dict: sequence_id = self.sequences[idx] # Load pre-processed results sample = load_preprocessed_sample(self.cache_dir, sequence_id) if sample is None: raise ValueError(f"Failed to load pre-processed sample: {sequence_id}") # Load images if self.load_images: images = self._load_images_for_sequence(sequence_id) images_tensor = torch.stack( [torch.from_numpy(img).permute(2, 0, 1).float() / 255.0 for img in images] ) else: # Return image paths instead images_tensor = None image_paths = self._get_image_paths_for_sequence(sequence_id) # Convert to tensors oracle_targets = { "poses": torch.from_numpy(sample["oracle_targets"]["poses"]).float(), } # Add depth if available if sample["oracle_targets"]["depth"] is not None: depth_array = sample["oracle_targets"]["depth"] # Check if it's the placeholder (1, 1, 1) array if depth_array.shape != (1, 1, 1): oracle_targets["depth"] = torch.from_numpy(depth_array).float() uncertainty_results = { "pose_confidence": torch.from_numpy( sample["uncertainty_results"]["pose_confidence"] ).float(), "depth_confidence": torch.from_numpy( sample["uncertainty_results"]["depth_confidence"] ).float(), "collective_confidence": torch.from_numpy( sample["uncertainty_results"]["collective_confidence"] ).float(), } # Optional: Add uncertainty tensors if available if "pose_uncertainty" in sample["uncertainty_results"]: uncertainty_results["pose_uncertainty"] = torch.from_numpy( sample["uncertainty_results"]["pose_uncertainty"] ).float() if "depth_uncertainty" in sample["uncertainty_results"]: uncertainty_results["depth_uncertainty"] = torch.from_numpy( sample["uncertainty_results"]["depth_uncertainty"] ).float() result = { "images": images_tensor, "oracle_targets": oracle_targets, "uncertainty_results": uncertainty_results, "sequence_id": sequence_id, "metadata": sample.get("metadata", {}), } if not self.load_images: result["image_paths"] = image_paths return result def _load_images_for_sequence(self, sequence_id: str) -> List[np.ndarray]: """Load images for a sequence.""" if self.arkit_sequences_dir is None: raise ValueError("arkit_sequences_dir required when load_images=True") # Find sequence directory (recursive search to handle new folder structures) # We look for a directory named sequence_id that contains a 'videos' subfolder found_dirs = list(self.arkit_sequences_dir.rglob(f"*/{sequence_id}/videos")) if not found_dirs: # Fallback: maybe it's directly there found_dirs = list(self.arkit_sequences_dir.rglob(f"{sequence_id}/videos")) if not found_dirs: # Last fallback: search for anything containing the sequence_id raise FileNotFoundError( f"Sequence directory with 'videos' subfolder for '{sequence_id}' not found in {self.arkit_sequences_dir}" ) videos_dir = found_dirs[0] video_files = list(videos_dir.glob("*.MOV")) + list(videos_dir.glob("*.mov")) if not video_files: raise FileNotFoundError(f"No video file found in {videos_dir}") video_path = video_files[0] logger.info(f"Loading images from {video_path}") # Extract frames images = [] cap = cv2.VideoCapture(str(video_path)) while True: ret, frame = cap.read() if not ret: break frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) images.append(frame_rgb) cap.release() return images def _get_image_paths_for_sequence(self, sequence_id: str) -> List[Path]: """Get image paths for a sequence (lazy loading).""" if self.arkit_sequences_dir is None: raise ValueError("arkit_sequences_dir required when load_images=False") sequence_dir = self.arkit_sequences_dir / sequence_id # For now, return sequence directory (images loaded from video) return [sequence_dir]