|
|
""" |
|
|
Pre-Processing Pipeline: Compute BA and oracle uncertainty offline. |
|
|
|
|
|
This module handles the offline preprocessing phase that runs OUTSIDE the training |
|
|
loop to pre-compute expensive operations: |
|
|
- BA validation (CPU, expensive, slow) |
|
|
- Oracle uncertainty propagation (CPU, moderate) |
|
|
- Oracle target selection (BA vs ARKit) |
|
|
|
|
|
Results are cached to disk and loaded during training for fast iteration. |
|
|
|
|
|
Key Design: |
|
|
The training pipeline is split into two phases: |
|
|
1. **Pre-Processing Phase** (offline, expensive): Compute BA and oracle uncertainty |
|
|
2. **Training Phase** (online, fast): Load pre-computed results and train |
|
|
|
|
|
This separation allows: |
|
|
- BA computation outside training loop (can be parallelized) |
|
|
- Reuse of expensive computations across training runs |
|
|
- Continuous confidence weighting (not binary rejection) |
|
|
- Efficient training iteration (100-1000x faster) |
|
|
|
|
|
See `docs/TRAINING_PIPELINE_ARCHITECTURE.md` for detailed architecture. |
|
|
""" |
|
|
|
|
|
import json |
|
|
import logging |
|
|
from pathlib import Path |
|
|
from typing import Dict, Optional |
|
|
import numpy as np |
|
|
|
|
|
from ..utils.oracle_uncertainty import OracleUncertaintyPropagator |
|
|
from .arkit_processor import ARKitProcessor |
|
|
from .ba_validator import BAValidator |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
def _run_chunked_da3_inference(model, images, arkit_poses_c2w, intrinsics, is_video_only):
    """Run DA3 inference over ``images`` in overlapping chunks and stitch results.

    Inference runs in chunks of 8 frames with a 1-frame overlap between
    consecutive chunks so that adjacent chunks share a common frame. In
    video-only mode (no trustworthy ARKit poses) each chunk's poses are
    re-anchored onto the previous chunk's trajectory via that shared frame.

    Args:
        model: DA3 model exposing ``inference(images, extrinsics=..., intrinsics=...)``.
        images: Sequence of frames (length >= 2).
        arkit_poses_c2w: Optional per-frame ARKit camera-to-world poses used to
            condition inference; may be None.
        intrinsics: Optional per-frame intrinsics; may be None.
        is_video_only: When True, stitch chunk poses into one trajectory.

    Returns:
        Tuple ``(da3_depth, da3_poses, da3_intrinsics)`` concatenated over all
        frames. ``da3_intrinsics`` falls back to the input ``intrinsics`` (or
        None) when the model does not emit intrinsics.
    """
    import torch

    batch_size = 8
    overlap = 1

    all_depths = []
    all_poses = []
    all_intrinsics = []
    last_pose = None  # pose of the overlap frame emitted by the previous chunk

    for i in range(0, len(images), batch_size - overlap):
        end_idx = min(i + batch_size, len(images))
        chunk_images = images[i:end_idx]

        # A trailing single-frame chunk cannot be processed (inference needs
        # >= 2 frames); that frame was already covered by the previous chunk.
        if len(chunk_images) < 2 and i > 0:
            break

        chunk_arkit = arkit_poses_c2w[i:end_idx] if arkit_poses_c2w is not None else None
        chunk_ix = intrinsics[i:end_idx] if intrinsics is not None else None

        with torch.no_grad():
            chunk_output = model.inference(
                chunk_images,
                extrinsics=chunk_arkit,
                intrinsics=chunk_ix,
            )

        c_depth = chunk_output.depth
        c_poses = chunk_output.extrinsics  # per-frame (3, 4) pose matrices
        c_ix = getattr(chunk_output, "intrinsics", None)

        if is_video_only and last_pose is not None:
            # Re-anchor this chunk: transform all chunk poses so that the
            # chunk's first (overlap) frame coincides with the previous
            # chunk's last frame, producing one continuous trajectory.
            p_prev = np.eye(4)
            p_prev[:3, :] = last_pose
            p_curr_start = np.eye(4)
            p_curr_start[:3, :] = c_poses[0]
            stitch_trans = p_prev @ np.linalg.inv(p_curr_start)
            for j in range(len(c_poses)):
                p_j = np.eye(4)
                p_j[:3, :] = c_poses[j]
                c_poses[j] = (stitch_trans @ p_j)[:3, :]

        # Drop the overlap frame (already emitted by the previous chunk).
        skip = overlap if i > 0 else 0
        all_depths.append(c_depth[skip:])
        all_poses.append(c_poses[skip:])
        if c_ix is not None:
            all_intrinsics.append(c_ix[skip:])

        last_pose = c_poses[-1]

        if end_idx == len(images):
            break

    da3_depth = np.concatenate(all_depths, axis=0)
    da3_poses = np.concatenate(all_poses, axis=0)
    da3_intrinsics = (
        np.concatenate(all_intrinsics, axis=0)
        if all_intrinsics else (intrinsics if intrinsics is not None else None)
    )
    return da3_depth, da3_poses, da3_intrinsics


def preprocess_arkit_sequence(
    arkit_dir: Path,
    output_cache_dir: Path,
    model,
    ba_validator: BAValidator,
    oracle_propagator: OracleUncertaintyPropagator,
    device: str = "cuda",
    prefer_arkit_poses: bool = True,
    min_arkit_quality: float = 0.8,
    use_lidar: bool = True,
    use_ba_depth: bool = False,
) -> Dict:
    """
    Pre-process a single ARKit sequence: compute BA and oracle uncertainty.

    This runs OUTSIDE the training loop and can be parallelized across sequences.
    The preprocessing phase computes expensive operations once and caches results
    for fast training iteration.

    Processing Steps:
        1. Extract ARKit data (poses, LiDAR depth) - FREE, fast
        2. Run DA3 inference (GPU, batchable) - Moderate cost
        3. Run BA validation (CPU, expensive) - Only if ARKit quality is poor
        4. Compute oracle uncertainty propagation - Moderate cost
        5. Save to cache - Fast disk I/O

    Oracle Target Selection:
        - If ARKit tracking quality >= min_arkit_quality: Use ARKit poses directly
          (fast, no BA needed)
        - Otherwise: Run BA validation to refine poses (expensive but necessary)
        - Below 50% good tracking the sequence is treated as video-only and BA
          is mandatory; if BA fails there, the sequence is skipped.

    Args:
        arkit_dir: Directory containing ARKit sequence with:
            - videos/*.MOV: Video file
            - metadata.json: ARKit metadata (poses, LiDAR, intrinsics)
        output_cache_dir: Directory to save pre-processed results. Each sequence
            is saved as a subdirectory with:
            - oracle_targets.npz: BA/ARKit poses and depth
            - uncertainty_results.npz: Confidence and uncertainty maps
            - arkit_data.npz: Raw ARKit poses/LiDAR (only when ARKit poses exist)
            - metadata.json: Sequence metadata
        model: DA3 model for initial inference. Used to generate initial
            predictions that are then validated/refined by BA.
        ba_validator: BAValidator instance for pose refinement via Bundle
            Adjustment. Only used if ARKit tracking quality is below threshold.
        oracle_propagator: OracleUncertaintyPropagator for computing uncertainty
            and confidence maps from multiple oracle sources (ARKit, BA, LiDAR).
        device: Device for DA3 inference ('cuda' or 'cpu'). Default 'cuda'.
            NOTE(review): currently unused here — the model is assumed to be
            placed on its device by the caller; kept for interface stability.
        prefer_arkit_poses: If True, use ARKit poses when tracking quality is
            good. This avoids expensive BA computation. Default True.
        min_arkit_quality: Minimum ARKit tracking quality (0-1) to use ARKit
            poses directly. Below this threshold, BA validation is run.
            Default 0.8.
        use_lidar: Include ARKit LiDAR depth in oracle uncertainty computation.
            Default True.
        use_ba_depth: Include BA depth maps in oracle uncertainty computation.
            BA depth is optional and may not always be available. Default False.

    Returns:
        Dictionary with preprocessing results:
        {
            'status': str,              # 'success', 'skipped', 'failed'
            'reason': str,              # Reason if skipped
            'error': str,               # Exception text if failed
            'sequence_id': str,         # Sequence identifier
            'num_frames': int,          # Number of frames processed
            'pose_source': str,         # 'arkit', 'arkit_fallback' or 'ba'
            'depth_source': str,        # 'lidar', 'ba' or 'none'
            'mean_confidence': float,   # Mean collective confidence
        }
        Only the keys relevant to the given status are present.

    Example:
        >>> from ylff.services.preprocessing import preprocess_arkit_sequence
        >>> from ylff.services.ba_validator import BAValidator
        >>> from ylff.utils.oracle_uncertainty import OracleUncertaintyPropagator
        >>>
        >>> result = preprocess_arkit_sequence(
        ...     arkit_dir=Path("data/arkit_sequences/seq001"),
        ...     output_cache_dir=Path("cache/preprocessed"),
        ...     model=da3_model,
        ...     ba_validator=ba_validator,
        ...     oracle_propagator=oracle_propagator,
        ...     prefer_arkit_poses=True,
        ...     min_arkit_quality=0.8,
        ... )

    Note:
        This function is designed to be called in parallel across multiple
        sequences. Each sequence is processed independently and results are
        cached separately. See `ylff preprocess arkit` CLI command for batch
        processing.
    """
    sequence_id = arkit_dir.name
    sequence_cache_dir = output_cache_dir / sequence_id
    sequence_cache_dir.mkdir(parents=True, exist_ok=True)

    try:
        # --- Step 1: Extract ARKit data (frames, poses, LiDAR) ---
        logger.info(f"Extracting ARKit data for {sequence_id}...")
        processor = ARKitProcessor(arkit_dir=arkit_dir)
        images = processor.extract_frames(
            output_dir=None, max_frames=None, frame_interval=1, return_images=True
        )

        if len(images) < 2:
            return {"status": "skipped", "reason": "insufficient_frames"}

        good_indices = processor.filter_good_frames()
        good_tracking_ratio = len(good_indices) / len(images) if images else 0.0

        # Below 50% good tracking we ignore ARKit poses and rely purely on BA.
        is_video_only = good_tracking_ratio < 0.5
        if is_video_only:
            logger.info(
                f"ARKit tracking missing or poor for {sequence_id} ({good_tracking_ratio:.1%}). "
                "Falling back to Video-only (BA-driven) mode."
            )

        arkit_poses_c2w, intrinsics = processor.get_arkit_poses()
        arkit_poses_w2c = processor.convert_arkit_to_w2c(arkit_poses_c2w)

        # Video frame count and metadata pose count can disagree (e.g. dropped
        # frames); truncate everything to the common prefix so indices align.
        if arkit_poses_c2w is not None and len(arkit_poses_c2w) > 0:
            min_len = min(len(images), len(arkit_poses_c2w))
            if len(images) != len(arkit_poses_c2w):
                logger.warning(
                    f"Syncing {sequence_id}: video has {len(images)} frames, "
                    f"metadata has {len(arkit_poses_c2w)}. Slicing to {min_len}."
                )
                images = images[:min_len]
                arkit_poses_c2w = arkit_poses_c2w[:min_len]
                arkit_poses_w2c = arkit_poses_w2c[:min_len]
                if intrinsics is not None and len(intrinsics) > 0:
                    intrinsics = intrinsics[:min_len]

        # Normalize empty arrays to None so downstream checks are uniform.
        if arkit_poses_c2w is not None and arkit_poses_c2w.size == 0:
            arkit_poses_c2w = None
        if arkit_poses_w2c is not None and arkit_poses_w2c.size == 0:
            arkit_poses_w2c = None
        if intrinsics is not None and intrinsics.size == 0:
            intrinsics = None

        lidar_depth = None
        if use_lidar:
            lidar_depth = processor.get_lidar_depths()

        # --- Step 2: DA3 inference (chunked, GPU) ---
        logger.info(f"Running DA3 inference for {sequence_id} (length: {len(images)})...")
        da3_depth, da3_poses, da3_intrinsics = _run_chunked_da3_inference(
            model, images, arkit_poses_c2w, intrinsics, is_video_only
        )

        # --- Step 3: Oracle pose selection (ARKit vs BA) ---
        use_arkit_poses = (
            prefer_arkit_poses and
            good_tracking_ratio >= min_arkit_quality and
            not is_video_only
        )

        if use_arkit_poses:
            logger.info(
                f"Using ARKit poses for {sequence_id} "
                f"(tracking quality: {good_tracking_ratio:.1%})"
            )
            oracle_poses = arkit_poses_w2c
            pose_source = "arkit"
            ba_poses = None
            ba_depths = None
        else:
            if is_video_only:
                logger.info(f"Running video-only BA reconstruction for {sequence_id}...")
            else:
                logger.info(
                    f"Running BA validation for {sequence_id} "
                    f"(ARKit tracking quality: {good_tracking_ratio:.1%} < {min_arkit_quality:.1%})"
                )
            ba_result = ba_validator.validate(
                images=images,
                poses_model=da3_poses,
                intrinsics=da3_intrinsics,
            )

            ba_poses_extracted = ba_result.get("poses_ba")

            if ba_poses_extracted is None:
                if is_video_only:
                    # No ARKit poses to fall back to in video-only mode.
                    logger.warning(f"BA reconstruction failed for video-only sequence {sequence_id}")
                    return {"status": "skipped", "reason": "ba_failed"}

                logger.warning(f"BA failed for {sequence_id}, falling back to ARKit poses")
                oracle_poses = arkit_poses_w2c
                pose_source = "arkit_fallback"
                ba_poses = None
                ba_depths = None
            else:
                oracle_poses = ba_poses_extracted
                pose_source = "ba"
                ba_poses = ba_poses_extracted
                ba_depths = ba_result.get("ba_depths") if use_ba_depth else None

        # --- Step 4: Oracle uncertainty propagation ---
        logger.info(f"Computing oracle uncertainty for {sequence_id}...")
        uncertainty_results = oracle_propagator.propagate_uncertainty(
            da3_poses=da3_poses,
            da3_depth=da3_depth,
            intrinsics=intrinsics,
            arkit_poses=arkit_poses_c2w,
            ba_poses=ba_poses,
            lidar_depth=lidar_depth if use_lidar else None,
        )

        # Oracle depth preference: LiDAR > BA depth > none.
        oracle_depth = None
        if use_lidar and lidar_depth is not None:
            oracle_depth = lidar_depth
            depth_source = "lidar"
        elif use_ba_depth and ba_depths is not None:
            oracle_depth = ba_depths
            depth_source = "ba"
        else:
            depth_source = "none"

        # --- Step 5: Save everything to cache ---
        logger.info(f"Saving pre-processed results for {sequence_id}...")

        # A (1, 1, 1) zeros array is the "no depth" sentinel; the loader
        # translates it back to None.
        np.savez_compressed(
            sequence_cache_dir / "oracle_targets.npz",
            poses=oracle_poses,
            depth=oracle_depth if oracle_depth is not None else np.zeros((1, 1, 1)),
        )

        np.savez_compressed(
            sequence_cache_dir / "uncertainty_results.npz",
            pose_confidence=uncertainty_results["pose_confidence"],
            depth_confidence=uncertainty_results["depth_confidence"],
            collective_confidence=uncertainty_results["collective_confidence"],
            pose_uncertainty=uncertainty_results.get(
                "pose_uncertainty",
                np.zeros((len(images), 6)),
            ),
            depth_uncertainty=uncertainty_results.get(
                "depth_uncertainty", np.zeros_like(da3_depth)
            ),
        )

        # BUGFIX: only write arkit_data.npz when ARKit poses exist. Saving
        # poses=None would serialize a pickled object array, which the loader
        # (np.load with the default allow_pickle=False) cannot read back,
        # making the whole cache entry unreadable. The loader already treats
        # a missing arkit_data.npz as "no ARKit data".
        if arkit_poses_c2w is not None:
            np.savez_compressed(
                sequence_cache_dir / "arkit_data.npz",
                poses=arkit_poses_c2w,
                lidar_depth=lidar_depth if lidar_depth is not None else np.zeros((1, 1, 1)),
            )

        metadata = {
            "sequence_id": sequence_id,
            "num_frames": len(images),
            "tracking_quality": float(good_tracking_ratio),
            "pose_source": pose_source,
            "depth_source": depth_source,
            "has_lidar": lidar_depth is not None,
            "has_ba_depth": ba_depths is not None,
            "mean_pose_confidence": float(uncertainty_results["pose_confidence"].mean()),
            "mean_depth_confidence": float(uncertainty_results["depth_confidence"].mean()),
        }

        with open(sequence_cache_dir / "metadata.json", "w") as f:
            json.dump(metadata, f, indent=2)

        # Record the source directory so training can resolve frames later.
        image_paths_file = sequence_cache_dir / "image_paths.txt"

        with open(image_paths_file, "w") as f:
            f.write(f"{arkit_dir}\n")

        logger.info(f"Pre-processing complete for {sequence_id}")

        return {
            "status": "success",
            "sequence_id": sequence_id,
            "num_frames": len(images),
            "pose_source": pose_source,
            "depth_source": depth_source,
            "mean_confidence": float(uncertainty_results["collective_confidence"].mean()),
        }

    except Exception as e:
        logger.error(f"Pre-processing failed for {sequence_id}: {e}", exc_info=True)
        return {"status": "failed", "sequence_id": sequence_id, "error": str(e)}
|
|
|
|
|
|
|
|
def load_preprocessed_sample(cache_dir: Path, sequence_id: str) -> Optional[Dict]:
    """
    Load a pre-processed sample from cache.

    Args:
        cache_dir: Root cache directory (one sub-directory per sequence)
        sequence_id: Sequence identifier (sub-directory name)

    Returns:
        Dict with pre-processed data ('oracle_targets', 'uncertainty_results',
        'arkit_data', 'metadata', 'sequence_id') or None if not found or
        unreadable.
    """
    seq_dir = cache_dir / sequence_id

    if not seq_dir.exists():
        return None

    # A (1, 1, 1) zeros array is the "no data" sentinel written during
    # preprocessing; translate it back to None when loading.
    def _unsentinel(arr):
        return arr if arr.shape != (1, 1, 1) else None

    try:
        # Oracle targets: poses plus optional depth.
        targets_npz = np.load(seq_dir / "oracle_targets.npz")
        oracle_targets = {
            "poses": targets_npz["poses"],
            "depth": _unsentinel(targets_npz["depth"]),
        }

        # Uncertainty/confidence maps; the *_uncertainty entries are optional.
        unc_npz = np.load(seq_dir / "uncertainty_results.npz")
        uncertainty_results = {
            key: unc_npz[key]
            for key in ("pose_confidence", "depth_confidence", "collective_confidence")
        }
        for key in ("pose_uncertainty", "depth_uncertainty"):
            uncertainty_results[key] = unc_npz[key] if key in unc_npz else None

        # Raw ARKit data is optional — the file may be absent entirely.
        arkit_file = seq_dir / "arkit_data.npz"
        arkit_data = None
        if arkit_file.exists():
            arkit_npz = np.load(arkit_file)
            arkit_data = {
                "poses": arkit_npz["poses"],
                "lidar_depth": _unsentinel(arkit_npz["lidar_depth"]),
            }

        # Sequence metadata defaults to an empty dict when the file is missing.
        meta_file = seq_dir / "metadata.json"
        metadata = {}
        if meta_file.exists():
            with open(meta_file) as f:
                metadata = json.load(f)

        return {
            "oracle_targets": oracle_targets,
            "uncertainty_results": uncertainty_results,
            "arkit_data": arkit_data,
            "metadata": metadata,
            "sequence_id": sequence_id,
        }

    except Exception as e:
        logger.error(f"Failed to load pre-processed sample {sequence_id}: {e}")
        return None
|
|
|