#!/usr/bin/env python3
"""
Run BA validation on full video to identify rejected frames.
"""

import os
import sys
from pathlib import Path

# Set environment variable FIRST, before the heavy third-party imports below
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Add SuperGluePretrainedNetwork to Python path if it exists
superglue_path = Path("/tmp/SuperGluePretrainedNetwork")
if superglue_path.exists():
    if str(superglue_path) not in sys.path:
        sys.path.insert(0, str(superglue_path))

# Set up logging IMMEDIATELY so the slow imports below can report progress
import logging  # noqa: E402

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    force=True,  # Force reconfiguration
)
logger = logging.getLogger(__name__)

logger.info("=" * 60)
logger.info("Starting BA Validation Script")
logger.info("=" * 60)
logger.info("Step 0: Importing dependencies...")

# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
logger.info(f"Project root: {project_root}")

# Each import is logged individually so a hang or failure can be attributed
# to a specific dependency.
try:
    logger.info(" - Importing numpy...")
    import numpy as np

    logger.info(" ✓ numpy imported")

    logger.info(" - Importing cv2...")
    import cv2

    logger.info(" ✓ cv2 imported")

    logger.info(" - Importing torch...")
    import torch

    logger.info(" ✓ torch imported")

    logger.info(" - Importing tqdm...")
    from tqdm import tqdm

    logger.info(" ✓ tqdm imported")

    logger.info(" - Importing json...")
    import json

    logger.info(" ✓ json imported")

    logger.info(" - Importing typing...")
    from typing import Optional

    logger.info(" ✓ typing imported")

    logger.info(" - Importing ylff modules...")
    from ylff.utils.model_loader import load_da3_model

    logger.info(" ✓ ylff.models imported")
    from ylff.services.ba_validator import BAValidator

    logger.info(" ✓ ylff.ba_validator imported")

    logger.info("✓ All imports complete")
except Exception as e:
    # Top-level boundary: report the failure with a traceback and abort.
    logger.error(f"✗ Import failed: {e}")
    import traceback

    traceback.print_exc()
    sys.exit(1)


# Rotation-error thresholds (degrees). They are passed to BAValidator and also
# used for the per-frame categorisation in main(), so both stay in agreement.
_ACCEPT_THRESHOLD_DEG = 2.0
_REJECT_THRESHOLD_DEG = 30.0


def extract_all_frames(
    video_path: Path, max_frames: Optional[int] = None, frame_interval: int = 1
) -> list:
    """Extract RGB frames from a video, prioritising every Nth frame.

    The first pass collects every ``frame_interval``-th frame. If ``max_frames``
    is given and the first pass yielded fewer frames than that, a second pass
    fills the remaining slots with the skipped frames, in file order. Without
    ``max_frames`` only the interval frames are returned (the fill pass would
    otherwise backfill every skipped frame and make ``frame_interval`` a no-op).

    Args:
        video_path: Path of the video file to read.
        max_frames: Upper bound on the number of frames returned. ``None``
            (or 0) means "all frames at the given interval".
        frame_interval: Keep every Nth frame in the first pass; must be >= 1.

    Returns:
        The extracted frames converted from BGR to RGB, ordered by their index
        in the video.

    Raises:
        ValueError: If ``frame_interval`` is < 1 or the video cannot be opened.
    """
    if frame_interval < 1:
        raise ValueError(f"frame_interval must be >= 1, got {frame_interval}")

    logger.info(f"Extracting frames from {video_path}")
    cap = cv2.VideoCapture(str(video_path))
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {video_path}")

        # Get video properties
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        logger.info(f"Video properties: {total_frames} frames, {fps:.2f} fps")

        # Ceiling division: how many frames the interval pass can yield.
        frames_at_interval = (total_frames + frame_interval - 1) // frame_interval

        # Strategy: extract every Nth frame first, then fill remaining slots
        # only when the caller asked for an explicit frame budget.
        fill_gaps = bool(max_frames)
        frames_to_extract = max_frames if fill_gaps else frames_at_interval

        logger.info(f"Target: {frames_to_extract} frames")
        logger.info(f" - At interval {frame_interval}: can get up to {frames_at_interval} frames")

        # (frame index, RGB image) pairs; sorted by index at the end.
        interval_frames = []

        with tqdm(total=frames_to_extract, desc="Extracting frames", unit="frame") as pbar:
            # First pass: extract every Nth frame
            cap.set(cv2.CAP_PROP_POS_FRAMES, 0)  # Reset to start
            frame_idx = 0
            while frame_idx < total_frames and len(interval_frames) < frames_to_extract:
                ret, frame = cap.read()
                if not ret:
                    break
                if frame_idx % frame_interval == 0:
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    interval_frames.append((frame_idx, frame_rgb))
                    pbar.update(1)
                frame_idx += 1

            # Second pass: if we need more frames to reach target, fill in gaps
            if fill_gaps and len(interval_frames) < frames_to_extract:
                logger.info(f" - Got {len(interval_frames)} frames at interval {frame_interval}")
                logger.info(
                    f" - Filling remaining {frames_to_extract - len(interval_frames)} frames..."
                )
                cap.set(cv2.CAP_PROP_POS_FRAMES, 0)  # Reset to start
                frame_idx = 0
                # Set lookup so already-extracted frames are skipped in O(1).
                extracted_indices = {idx for idx, _ in interval_frames}

                while len(interval_frames) < frames_to_extract:
                    ret, frame = cap.read()
                    if not ret:
                        logger.warning(
                            f" - Video ended. Got {len(interval_frames)} frames total."
                        )
                        break
                    if frame_idx not in extracted_indices:
                        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                        interval_frames.append((frame_idx, frame_rgb))
                        pbar.update(1)
                    frame_idx += 1

        # Sort by frame index and extract just the images
        interval_frames.sort(key=lambda x: x[0])
        frames = [img for _, img in interval_frames]

        logger.info(f" - Final: {len(frames)} frames extracted")
        if interval_frames:
            frame_indices = [idx for idx, _ in interval_frames]
            logger.info(
                f" - Frame indices: {frame_indices[0]}...{frame_indices[-1]} "
                f"(every {frame_interval} + fills)"
            )
    finally:
        # Always free the capture handle, including on the error paths above.
        cap.release()

    logger.info(f"✓ Extracted {len(frames)} total frames")
    return frames


def _categorize_rotation_errors(rot_errors, accept_threshold, reject_threshold):
    """Split frame indices into (accepted, rejected_learnable, rejected_outlier).

    A frame is accepted when its error is below ``accept_threshold``, learnable
    when it is below ``reject_threshold``, and an outlier otherwise.
    """
    accepted = []
    rejected_learnable = []
    rejected_outlier = []
    for i, err in enumerate(rot_errors):
        if err < accept_threshold:
            accepted.append(i)
        elif err < reject_threshold:
            rejected_learnable.append(i)
        else:
            rejected_outlier.append(i)
    return accepted, rejected_learnable, rejected_outlier


def _json_safe_error(error):
    """Return ``error`` as a float when numeric, else as a string (None stays None)."""
    if error is None:
        return None
    try:
        return float(error)
    except (TypeError, ValueError):
        # The validator may report a non-numeric error (e.g. a message); keep
        # it in the output instead of crashing while building the JSON payload.
        return str(error)


def main():
    """Run the full pipeline: extract frames, DA3 inference, BA validation, report.

    Parses ``--video``, ``--max-frames``, ``--frame-interval`` and
    ``--output-dir`` from the command line. When per-frame rotation errors are
    returned by the validator, summary statistics and frame categories are
    logged and written to ``validation_results.json`` in the output directory.
    Exits with status 1 if no frames could be extracted.
    """
    import argparse
    import threading
    import time

    logger.info("\n" + "=" * 60)
    logger.info("Parsing arguments...")
    parser = argparse.ArgumentParser(description="Run BA validation on video")
    parser.add_argument("--video", type=str, default=None, help="Path to video file")
    parser.add_argument(
        "--max-frames", type=int, default=None, help="Maximum number of frames to process"
    )
    parser.add_argument(
        "--frame-interval",
        type=int,
        default=1,
        help="Extract every Nth frame (e.g., 15 for every 15th frame)",
    )
    parser.add_argument("--output-dir", type=str, default=None, help="Output directory")
    args = parser.parse_args()

    # Paths
    if args.video:
        video_path = Path(args.video)
    else:
        video_path = project_root / "assets" / "examples" / "robot_unitree.mp4"

    if args.output_dir:
        output_dir = Path(args.output_dir)
    else:
        output_dir = project_root / "data" / "ba_validation_results"
    output_dir.mkdir(parents=True, exist_ok=True)

    logger.info("=" * 60)
    logger.info("BA Validation on Full Video")
    logger.info("=" * 60)
    logger.info(f"Video: {video_path}")
    logger.info(f"Output: {output_dir}")
    if args.max_frames:
        logger.info(f"Max frames: {args.max_frames}")
    logger.info("=" * 60)

    # 1. Extract frames
    logger.info("\n[Step 1] Extracting frames from video...")
    frames = extract_all_frames(
        video_path, max_frames=args.max_frames, frame_interval=args.frame_interval
    )
    logger.info(f"✓ Extracted {len(frames)} frames (every {args.frame_interval} frame(s))")
    if not frames:
        # Nothing to run inference on; also avoids dividing by len(frames) below.
        logger.error(f"✗ No frames were extracted from {video_path}; nothing to validate")
        sys.exit(1)

    # 2. Run DA3 inference
    logger.info("\n[Step 2] Running DA3 inference...")
    logger.info("Loading DA3 model (this may take a moment)...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device: {device}")
    model = load_da3_model("depth-anything/DA3-LARGE", device=device)
    logger.info("✓ Model loaded")

    logger.info(f"Running inference on {len(frames)} frames (this may take a while)...")
    logger.info(" - Processing in batches...")
    logger.info(" - This step can take several minutes on CPU...")

    # Start a heartbeat thread to show we're still alive
    inference_done = threading.Event()
    heartbeat_interval_s = 30

    def heartbeat():
        count = 0
        # Event.wait() returns False on timeout, i.e. inference is still running.
        while not inference_done.wait(heartbeat_interval_s):
            count += 1
            logger.info(f" - Still processing... ({count * heartbeat_interval_s}s elapsed)")

    heartbeat_thread = threading.Thread(target=heartbeat, daemon=True)
    heartbeat_thread.start()

    # Monotonic clock: unaffected by wall-clock adjustments during a long run.
    start_time = time.perf_counter()
    try:
        with torch.no_grad():
            # DA3 inference processes all frames at once
            logger.info(" - Starting DA3 forward pass...")
            da3_output = model.inference(frames)
            elapsed = time.perf_counter() - start_time
            logger.info(
                f" - DA3 inference completed in {elapsed:.1f} seconds "
                f"({elapsed / len(frames):.2f}s per frame)"
            )
    finally:
        # Stop the heartbeat whether inference succeeded or raised.
        inference_done.set()

    poses_da3 = da3_output.extrinsics  # (N, 3, 4) per the original note; not verified here
    intrinsics = da3_output.intrinsics if hasattr(da3_output, "intrinsics") else None
    logger.info("✓ DA3 inference complete")
    logger.info(f" - Poses shape: {poses_da3.shape}")
    logger.info(f" - Intrinsics shape: {intrinsics.shape if intrinsics is not None else 'None'}")

    # 3. Run BA validation
    logger.info("\n[Step 3] Running BA validation...")
    logger.info("Initializing BA validator...")
    validator = BAValidator(
        accept_threshold=_ACCEPT_THRESHOLD_DEG,
        reject_threshold=_REJECT_THRESHOLD_DEG,
        work_dir=output_dir / "ba_work",
    )
    logger.info("✓ BA validator initialized")

    # The sub-steps are announced up front; from this script's point of view
    # validate() is a single blocking call.
    logger.info("Validating poses with BA (this may take a while)...")
    logger.info(" - Step 3.1: Saving images...")
    logger.info(" - Step 3.2: Extracting features (SuperPoint)...")
    logger.info(" - Step 3.3: Matching features (LightGlue)...")
    logger.info(" - Step 3.4: Running Bundle Adjustment...")

    result = validator.validate(
        images=frames,
        poses_model=poses_da3,
        intrinsics=intrinsics,
    )
    logger.info("✓ BA validation complete")

    # 4. Analyze results
    logger.info("\n[Step 4] Analyzing results...")
    status = result["status"]
    error = result.get("error")
    error_metrics = result.get("error_metrics", {})

    logger.info(f"\n{'=' * 60}")
    logger.info("RESULTS")
    logger.info(f"{'=' * 60}")
    logger.info(f"Overall Status: {status}")
    if error is not None and isinstance(error, (int, float)):
        logger.info(f"Max Rotation Error: {error:.2f}°")
    elif error is not None:
        logger.info(f"Max Rotation Error: {error}")

    if error_metrics:
        # Normalise to a plain list of floats so truthiness and JSON encoding
        # work even if the validator hands back an array-like (type not
        # visible from this file).
        raw_errors = error_metrics.get("rotation_errors_deg")
        rot_errors = [] if raw_errors is None else [float(e) for e in raw_errors]
        if rot_errors:
            mean_err = float(np.mean(rot_errors))
            median_err = float(np.median(rot_errors))
            max_err = float(np.max(rot_errors))
            min_err = float(np.min(rot_errors))
            std_err = float(np.std(rot_errors))

            logger.info("\nRotation Error Statistics:")
            logger.info(f" - Mean: {mean_err:.2f}°")
            logger.info(f" - Median: {median_err:.2f}°")
            logger.info(f" - Max: {max_err:.2f}°")
            logger.info(f" - Min: {min_err:.2f}°")
            logger.info(f" - Std: {std_err:.2f}°")

            # Categorize individual frames
            accepted, rejected_learnable, rejected_outlier = _categorize_rotation_errors(
                rot_errors, _ACCEPT_THRESHOLD_DEG, _REJECT_THRESHOLD_DEG
            )

            logger.info("\nFrame Categorization:")
            accepted_pct = 100 * len(accepted) / len(rot_errors)
            learnable_pct = 100 * len(rejected_learnable) / len(rot_errors)
            outlier_pct = 100 * len(rejected_outlier) / len(rot_errors)
            logger.info(
                f" - Accepted (< {_ACCEPT_THRESHOLD_DEG:g}°): {len(accepted)} frames "
                f"({accepted_pct:.1f}%)"
            )
            logger.info(
                f" - Rejected-Learnable ({_ACCEPT_THRESHOLD_DEG:g}-{_REJECT_THRESHOLD_DEG:g}°): "
                f"{len(rejected_learnable)} frames "
                f"({learnable_pct:.1f}%)"
            )
            logger.info(
                f" - Rejected-Outlier (> {_REJECT_THRESHOLD_DEG:g}°): "
                f"{len(rejected_outlier)} frames "
                f"({outlier_pct:.1f}%)"
            )

            # Save detailed results
            results_dict = {
                "status": status,
                "error": _json_safe_error(error),
                "error_metrics": {
                    "rotation_errors_deg": rot_errors,
                    "mean_rotation_error_deg": mean_err,
                    "median_rotation_error_deg": median_err,
                    "max_rotation_error_deg": max_err,
                    "min_rotation_error_deg": min_err,
                    "std_rotation_error_deg": std_err,
                },
                "frame_categories": {
                    "accepted": accepted,
                    "rejected_learnable": rejected_learnable,
                    "rejected_outlier": rejected_outlier,
                },
                "num_frames": len(frames),
            }

            output_json = output_dir / "validation_results.json"
            with open(output_json, "w", encoding="utf-8") as f:
                json.dump(results_dict, f, indent=2)
            logger.info(f"\n✓ Results saved to {output_json}")

            # Show some examples
            if rejected_learnable:
                logger.info("\nExample Rejected-Learnable frames (first 10):")
                for idx in rejected_learnable[:10]:
                    logger.info(f" Frame {idx}: {rot_errors[idx]:.2f}°")

            if rejected_outlier:
                logger.info("\nExample Rejected-Outlier frames (first 10):")
                for idx in rejected_outlier[:10]:
                    logger.info(f" Frame {idx}: {rot_errors[idx]:.2f}°")

    logger.info(f"\n{'=' * 60}")
    logger.info("✓ BA validation complete!")
    logger.info(f"{'=' * 60}")


if __name__ == "__main__":
    main()