| | """ |
| | Preprocessed Dataset: Load pre-computed oracle results for training. |
| | |
| | This dataset loads pre-processed results (BA, oracle uncertainty) from cache, |
| | enabling fast training iteration without expensive BA computation. |
| | """ |
| |
|
| | import logging |
| | from pathlib import Path |
| | from typing import Dict, List, Optional |
| | import cv2 |
| | import numpy as np |
| | import torch |
| | from torch.utils.data import Dataset |
| |
|
| | from .preprocessing import load_preprocessed_sample |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
class PreprocessedARKitDataset(Dataset):
    """
    Dataset that loads pre-computed oracle results for training.

    Loads pre-processed results from the preprocessing phase, enabling fast
    training iteration without expensive BA computation. All oracle targets
    (BA poses, LiDAR depth) and uncertainty results are pre-computed and
    cached on disk.

    Key Features:
        - Fast loading: no BA computation during training
        - Uncertainty-aware: includes confidence maps for loss weighting
        - Lazy image loading: decode images on demand, or return paths only

    Each sample dict contains:
        - 'images': float tensor [N, C, H, W] in [0, 1], or None when
          load_images=False (then 'image_paths' is also included)
        - 'oracle_targets': dict with:
            - 'poses': [N, 3, 4] camera poses (w2c) from BA or ARKit
            - 'depth': [N, H, W] depth maps (optional; omitted when absent
              or when the cache holds only a placeholder array)
        - 'uncertainty_results': dict with:
            - 'pose_confidence': [N] per-frame pose confidence
            - 'depth_confidence': [N, H, W] per-pixel depth confidence
            - 'collective_confidence': [N] overall sequence confidence
            - 'pose_uncertainty' / 'depth_uncertainty': optional raw maps
        - 'sequence_id': str identifier for the sequence
        - 'metadata': dict with sequence metadata

    Example:
        >>> from ylff.services.preprocessed_dataset import PreprocessedARKitDataset
        >>> from torch.utils.data import DataLoader
        >>>
        >>> dataset = PreprocessedARKitDataset(
        ...     cache_dir=Path("cache/preprocessed"),
        ...     arkit_sequences_dir=Path("data/arkit_sequences"),
        ...     load_images=True,
        ... )
        >>> dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
        >>> for batch in dataloader:
        ...     images = batch['images']
        ...     oracle_targets = batch['oracle_targets']
        ...     uncertainty_results = batch['uncertainty_results']
    """

    # A cache entry is usable only if every one of these files exists.
    _REQUIRED_FILES = ("oracle_targets.npz", "uncertainty_results.npz", "metadata.json")

    def __init__(
        self,
        cache_dir: Path,
        arkit_sequences_dir: Optional[Path] = None,
        load_images: bool = True,
    ):
        """
        Args:
            cache_dir: Directory containing pre-processed results.
            arkit_sequences_dir: Optional directory with the original ARKit
                sequences (needed to decode images or resolve paths).
            load_images: If True, decode frames into memory in __getitem__;
                if False, return the sequence path under 'image_paths'.
        """
        self.cache_dir = Path(cache_dir)
        self.arkit_sequences_dir = Path(arkit_sequences_dir) if arkit_sequences_dir else None
        self.load_images = load_images

        self.sequences = self._find_preprocessed_sequences()

        # Lazy %-style args so formatting only happens when the level is on.
        if not self.sequences:
            logger.warning("No pre-processed sequences found in %s", cache_dir)
        else:
            logger.info("Found %d pre-processed sequences", len(self.sequences))

    def _find_preprocessed_sequences(self) -> List[str]:
        """Return sorted names of complete pre-processed sequence directories."""
        if not self.cache_dir.exists():
            return []
        return sorted(
            entry.name
            for entry in self.cache_dir.iterdir()
            if entry.is_dir()
            and all((entry / fname).exists() for fname in self._REQUIRED_FILES)
        )

    def __len__(self) -> int:
        """Number of complete pre-processed sequences found in the cache."""
        return len(self.sequences)

    def __getitem__(self, idx: int) -> Dict:
        """Load one pre-processed sample, optionally decoding its images.

        Raises:
            ValueError: if the cached sample fails to load.
        """
        sequence_id = self.sequences[idx]

        sample = load_preprocessed_sample(self.cache_dir, sequence_id)
        if sample is None:
            raise ValueError(f"Failed to load pre-processed sample: {sequence_id}")

        result = {
            "oracle_targets": self._build_oracle_targets(sample),
            "uncertainty_results": self._build_uncertainty_results(sample),
            "sequence_id": sequence_id,
            "metadata": sample.get("metadata", {}),
        }

        if self.load_images:
            images = self._load_images_for_sequence(sequence_id)
            # HWC uint8 -> CHW float in [0, 1], stacked to [N, C, H, W].
            result["images"] = torch.stack(
                [torch.from_numpy(img).permute(2, 0, 1).float() / 255.0 for img in images]
            )
        else:
            result["images"] = None
            result["image_paths"] = self._get_image_paths_for_sequence(sequence_id)

        return result

    @staticmethod
    def _build_oracle_targets(sample: Dict) -> Dict:
        """Convert cached oracle arrays (poses, optional depth) to float tensors."""
        oracle_targets = {
            "poses": torch.from_numpy(sample["oracle_targets"]["poses"]).float(),
        }
        # Depth is optional: use .get so a cache entry without the key is
        # handled instead of raising KeyError.
        depth_array = sample["oracle_targets"].get("depth")
        if depth_array is not None and depth_array.shape != (1, 1, 1):
            # (1, 1, 1) appears to be the preprocessing placeholder for
            # "no depth" -- TODO confirm against the preprocessing module.
            oracle_targets["depth"] = torch.from_numpy(depth_array).float()
        return oracle_targets

    @staticmethod
    def _build_uncertainty_results(sample: Dict) -> Dict:
        """Convert cached uncertainty arrays to float tensors."""
        raw = sample["uncertainty_results"]
        results = {
            key: torch.from_numpy(raw[key]).float()
            for key in ("pose_confidence", "depth_confidence", "collective_confidence")
        }
        # Raw uncertainty maps are optional extras alongside the confidences.
        for key in ("pose_uncertainty", "depth_uncertainty"):
            if key in raw:
                results[key] = torch.from_numpy(raw[key]).float()
        return results

    def _load_images_for_sequence(self, sequence_id: str) -> List[np.ndarray]:
        """Decode all frames of the sequence's video into RGB uint8 arrays.

        Raises:
            ValueError: if arkit_sequences_dir was not provided.
            FileNotFoundError: if the sequence directory or video is missing.
            IOError: if the video file exists but cannot be opened.
        """
        if self.arkit_sequences_dir is None:
            raise ValueError("arkit_sequences_dir required when load_images=True")

        # ARKit layouts vary in nesting depth; try the deeper layout first.
        found_dirs = list(self.arkit_sequences_dir.rglob(f"*/{sequence_id}/videos"))
        if not found_dirs:
            found_dirs = list(self.arkit_sequences_dir.rglob(f"{sequence_id}/videos"))
        if not found_dirs:
            raise FileNotFoundError(
                f"Sequence directory with 'videos' subfolder for '{sequence_id}' not found in {self.arkit_sequences_dir}"
            )

        videos_dir = found_dirs[0]
        video_files = list(videos_dir.glob("*.MOV")) + list(videos_dir.glob("*.mov"))
        if not video_files:
            raise FileNotFoundError(f"No video file found in {videos_dir}")

        video_path = video_files[0]
        logger.info("Loading images from %s", video_path)

        cap = cv2.VideoCapture(str(video_path))
        # Previously an unopenable video silently produced an empty list,
        # which only surfaced later as a confusing torch.stack error.
        if not cap.isOpened():
            cap.release()
            raise IOError(f"Could not open video file: {video_path}")

        images: List[np.ndarray] = []
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                # OpenCV decodes to BGR; training expects RGB.
                images.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        finally:
            cap.release()  # guarantee the capture handle is freed

        return images

    def _get_image_paths_for_sequence(self, sequence_id: str) -> List[Path]:
        """Return the path a lazy loader can use to locate the images.

        NOTE(review): this currently returns the sequence *directory*, not
        individual frame paths -- callers must enumerate frames themselves.
        Confirm whether per-frame paths were intended here.

        Raises:
            ValueError: if arkit_sequences_dir was not provided.
        """
        if self.arkit_sequences_dir is None:
            raise ValueError("arkit_sequences_dir required when load_images=False")

        return [self.arkit_sequences_dir / sequence_id]
| |
|