# Source: 3d_model/scripts/experiments/run_ba_validation_video.py
# Author: Azan
# Clean deployment build (Squashed)
# Commit: 7a87926
#!/usr/bin/env python3
"""
Run BA validation on full video to identify rejected frames.
"""
import os
import sys
from pathlib import Path

# Set environment variable FIRST before any imports: lets duplicate OpenMP
# runtimes (e.g. from torch and cv2) coexist instead of aborting the process.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Add SuperGluePretrainedNetwork to Python path if it exists
superglue_path = Path("/tmp/SuperGluePretrainedNetwork")
if superglue_path.exists():
    if str(superglue_path) not in sys.path:
        sys.path.insert(0, str(superglue_path))

# Set up logging IMMEDIATELY so progress of the slow imports below is visible
import logging  # noqa: E402

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    force=True,  # Force reconfiguration
)
logger = logging.getLogger(__name__)
logger.info("=" * 60)
logger.info("Starting BA Validation Script")
logger.info("=" * 60)
logger.info("Step 0: Importing dependencies...")

# Add project root to path so `ylff` is importable when run as a script
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
logger.info(f"Project root: {project_root}")

# Imports are wrapped so a missing dependency produces a clear log line and a
# non-zero exit instead of a bare traceback with no context.
try:
    logger.info(" - Importing numpy...")
    import numpy as np
    logger.info(" ✓ numpy imported")
    logger.info(" - Importing cv2...")
    import cv2
    logger.info(" ✓ cv2 imported")
    logger.info(" - Importing torch...")
    import torch
    logger.info(" ✓ torch imported")
    logger.info(" - Importing tqdm...")
    from tqdm import tqdm
    logger.info(" ✓ tqdm imported")
    logger.info(" - Importing json...")
    import json
    logger.info(" ✓ json imported")
    logger.info(" - Importing typing...")
    from typing import Optional
    logger.info(" ✓ typing imported")
    logger.info(" - Importing ylff modules...")
    from ylff.utils.model_loader import load_da3_model
    # Fixed: previously logged "ylff.models", which is not the module imported
    logger.info(" ✓ ylff.utils.model_loader imported")
    from ylff.services.ba_validator import BAValidator
    # Fixed: previously logged "ylff.ba_validator" (wrong package path)
    logger.info(" ✓ ylff.services.ba_validator imported")
    logger.info("✓ All imports complete")
except Exception as e:
    logger.error(f"✗ Import failed: {e}")
    import traceback
    traceback.print_exc()
    sys.exit(1)
def extract_all_frames(
    video_path: Path, max_frames: Optional[int] = None, frame_interval: int = 1
) -> list:
    """Extract frames from a video as RGB arrays, ordered by frame index.

    Two-pass strategy: first take every ``frame_interval``-th frame; if that
    yields fewer than the target count, re-read the video and fill the
    remaining slots with the earliest frames not already extracted.

    Args:
        video_path: Path to the video file.
        max_frames: Maximum number of frames to return (``None`` = all).
        frame_interval: Take every Nth frame in the first pass.

    Returns:
        List of RGB frames (BGR->RGB converted), sorted by original index.

    Raises:
        ValueError: If the video cannot be opened.
    """
    logger.info(f"Extracting frames from {video_path}")
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        raise ValueError(f"Could not open video: {video_path}")
    # Release the capture even if decoding/conversion raises mid-extraction
    # (original code leaked the handle on any exception past this point).
    try:
        # Get video properties
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        logger.info(f"Video properties: {total_frames} frames, {fps:.2f} fps")
        # Strategy: Extract every Nth frame first, then fill remaining slots if needed
        frames_to_extract = max_frames if max_frames else total_frames
        # How many frames the interval alone can yield (ceiling division)
        frames_at_interval = (total_frames + frame_interval - 1) // frame_interval
        logger.info(f"Target: {frames_to_extract} frames")
        logger.info(f" - At interval {frame_interval}: can get up to {frames_at_interval} frames")
        interval_frames = []  # list of (frame_index, rgb_image)
        frame_idx = 0
        with tqdm(total=frames_to_extract, desc="Extracting frames", unit="frame") as pbar:
            # First pass: extract every Nth frame
            cap.set(cv2.CAP_PROP_POS_FRAMES, 0)  # Reset to start
            while frame_idx < total_frames and len(interval_frames) < frames_to_extract:
                ret, frame = cap.read()
                if not ret:
                    break
                if frame_idx % frame_interval == 0:
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    interval_frames.append((frame_idx, frame_rgb))
                    pbar.update(1)
                frame_idx += 1
            # Second pass: if we need more frames to reach target, fill in gaps
            if len(interval_frames) < frames_to_extract:
                logger.info(f" - Got {len(interval_frames)} frames at interval {frame_interval}")
                logger.info(
                    f" - Filling remaining {frames_to_extract - len(interval_frames)} frames..."
                )
                cap.set(cv2.CAP_PROP_POS_FRAMES, 0)  # Reset to start
                frame_idx = 0
                extracted_indices = {idx for idx, _ in interval_frames}
                while len(interval_frames) < frames_to_extract:
                    ret, frame = cap.read()
                    if not ret:
                        logger.warning(f" - Video ended. Got {len(interval_frames)} frames total.")
                        break
                    if frame_idx not in extracted_indices:
                        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                        interval_frames.append((frame_idx, frame_rgb))
                        pbar.update(1)
                    frame_idx += 1
        # Sort by frame index and extract just the images
        interval_frames.sort(key=lambda x: x[0])
        frames = [img for _, img in interval_frames]
        logger.info(f" - Final: {len(frames)} frames extracted")
        if len(interval_frames) > 0:
            frame_indices = [idx for idx, _ in interval_frames]
            logger.info(
                f" - Frame indices: {frame_indices[0]}...{frame_indices[-1]} "
                f"(every {frame_interval} + fills)"
            )
    finally:
        cap.release()
    logger.info(f"✓ Extracted {len(frames)} total frames")
    return frames
def main() -> None:
    """CLI entry point.

    Pipeline: parse args -> extract frames from the video -> run DA3 pose
    inference -> validate the predicted poses with bundle adjustment ->
    categorize per-frame rotation errors and save results to JSON.
    """
    import argparse

    logger.info("\n" + "=" * 60)
    logger.info("Parsing arguments...")
    parser = argparse.ArgumentParser(description="Run BA validation on video")
    parser.add_argument("--video", type=str, default=None, help="Path to video file")
    parser.add_argument(
        "--max-frames", type=int, default=None, help="Maximum number of frames to process"
    )
    parser.add_argument(
        "--frame-interval",
        type=int,
        default=1,
        help="Extract every Nth frame (e.g., 15 for every 15th frame)",
    )
    parser.add_argument("--output-dir", type=str, default=None, help="Output directory")
    args = parser.parse_args()
    # Paths: default to the bundled example video and a results dir under data/
    if args.video:
        video_path = Path(args.video)
    else:
        video_path = project_root / "assets" / "examples" / "robot_unitree.mp4"
    if args.output_dir:
        output_dir = Path(args.output_dir)
    else:
        output_dir = project_root / "data" / "ba_validation_results"
    output_dir.mkdir(parents=True, exist_ok=True)
    logger.info("=" * 60)
    logger.info("BA Validation on Full Video")
    logger.info("=" * 60)
    logger.info(f"Video: {video_path}")
    logger.info(f"Output: {output_dir}")
    if args.max_frames:
        logger.info(f"Max frames: {args.max_frames}")
    logger.info("=" * 60)
    # 1. Extract frames
    logger.info("\n[Step 1] Extracting frames from video...")
    frames = extract_all_frames(
        video_path, max_frames=args.max_frames, frame_interval=args.frame_interval
    )
    logger.info(f"✓ Extracted {len(frames)} frames (every {args.frame_interval} frame(s))")
    # 2. Run DA3 inference
    logger.info("\n[Step 2] Running DA3 inference...")
    logger.info("Loading DA3 model (this may take a moment)...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device: {device}")
    model = load_da3_model("depth-anything/DA3-LARGE", device=device)
    logger.info("✓ Model loaded")
    logger.info(f"Running inference on {len(frames)} frames (this may take a while)...")
    logger.info(" - Processing in batches...")
    logger.info(" - This step can take several minutes on CPU...")
    import threading
    import time

    # Start a heartbeat thread to show we're still alive during the long
    # (potentially CPU-bound) forward pass; it exits when the event is set.
    inference_done = threading.Event()

    def heartbeat() -> None:
        # Logs every 30 seconds until inference_done is set.
        count = 0
        while not inference_done.wait(30):  # Log every 30 seconds
            count += 1
            logger.info(f" - Still processing... ({count * 30}s elapsed)")

    heartbeat_thread = threading.Thread(target=heartbeat, daemon=True)
    heartbeat_thread.start()
    start_time = time.time()
    try:
        with torch.no_grad():
            # DA3 inference processes all frames at once
            logger.info(" - Starting DA3 forward pass...")
            da3_output = model.inference(frames)
        elapsed = time.time() - start_time
        logger.info(
            f" - DA3 inference completed in {elapsed:.1f} seconds "
            f"({elapsed / len(frames):.2f}s per frame)"
        )
    finally:
        # Always stop the heartbeat, even if inference raised.
        inference_done.set()
    poses_da3 = da3_output.extrinsics  # (N, 3, 4)
    # NOTE(review): assumes DA3 output may lack `intrinsics` on some versions
    intrinsics = da3_output.intrinsics if hasattr(da3_output, "intrinsics") else None
    logger.info("✓ DA3 inference complete")
    logger.info(f" - Poses shape: {poses_da3.shape}")
    logger.info(f" - Intrinsics shape: {intrinsics.shape if intrinsics is not None else 'None'}")
    # 3. Run BA validation
    logger.info("\n[Step 3] Running BA validation...")
    logger.info("Initializing BA validator...")
    # NOTE(review): these thresholds (2.0 / 30.0 degrees) are duplicated as
    # literals in the categorization loop below — keep them in sync.
    validator = BAValidator(
        accept_threshold=2.0,
        reject_threshold=30.0,
        work_dir=output_dir / "ba_work",
    )
    logger.info("✓ BA validator initialized")
    logger.info("Validating poses with BA (this may take a while)...")
    logger.info(" - Step 3.1: Saving images...")
    logger.info(" - Step 3.2: Extracting features (SuperPoint)...")
    logger.info(" - Step 3.3: Matching features (LightGlue)...")
    logger.info(" - Step 3.4: Running Bundle Adjustment...")
    result = validator.validate(
        images=frames,
        poses_model=poses_da3,
        intrinsics=intrinsics,
    )
    logger.info("✓ BA validation complete")
    # 4. Analyze results
    logger.info("\n[Step 4] Analyzing results...")
    # Result schema (from BAValidator): "status", optional "error" (max
    # rotation error), optional "error_metrics" with per-frame rotation errors.
    status = result["status"]
    error = result.get("error")
    error_metrics = result.get("error_metrics", {})
    logger.info(f"\n{'=' * 60}")
    logger.info("RESULTS")
    logger.info(f"{'=' * 60}")
    logger.info(f"Overall Status: {status}")
    # "error" may be numeric (formatted to 2 dp) or a message string.
    if error is not None and isinstance(error, (int, float)):
        logger.info(f"Max Rotation Error: {error:.2f}°")
    elif error is not None:
        logger.info(f"Max Rotation Error: {error}")
    if error_metrics:
        rot_errors = error_metrics.get("rotation_errors_deg", [])
        if rot_errors:
            logger.info("\nRotation Error Statistics:")
            logger.info(f" - Mean: {np.mean(rot_errors):.2f}°")
            logger.info(f" - Median: {np.median(rot_errors):.2f}°")
            logger.info(f" - Max: {np.max(rot_errors):.2f}°")
            logger.info(f" - Min: {np.min(rot_errors):.2f}°")
            logger.info(f" - Std: {np.std(rot_errors):.2f}°")
            # Categorize individual frames by rotation error (degrees):
            # < 2 accepted, 2-30 learnable rejection, > 30 outlier.
            accepted = []
            rejected_learnable = []
            rejected_outlier = []
            for i, err in enumerate(rot_errors):
                if err < 2.0:
                    accepted.append(i)
                elif err < 30.0:
                    rejected_learnable.append(i)
                else:
                    rejected_outlier.append(i)
            logger.info("\nFrame Categorization:")
            accepted_pct = 100 * len(accepted) / len(rot_errors)
            learnable_pct = 100 * len(rejected_learnable) / len(rot_errors)
            outlier_pct = 100 * len(rejected_outlier) / len(rot_errors)
            logger.info(f" - Accepted (< 2°): {len(accepted)} frames ({accepted_pct:.1f}%)")
            logger.info(
                f" - Rejected-Learnable (2-30°): {len(rejected_learnable)} frames "
                f"({learnable_pct:.1f}%)"
            )
            logger.info(
                f" - Rejected-Outlier (> 30°): {len(rejected_outlier)} frames "
                f"({outlier_pct:.1f}%)"
            )
            # Save detailed results
            # NOTE(review): the JSON is only written when per-frame rotation
            # metrics are present; runs without metrics produce no file.
            results_dict = {
                "status": status,
                "error": float(error) if error is not None else None,
                "error_metrics": {
                    "rotation_errors_deg": [float(e) for e in rot_errors],
                    "mean_rotation_error_deg": float(np.mean(rot_errors)),
                    "median_rotation_error_deg": float(np.median(rot_errors)),
                    "max_rotation_error_deg": float(np.max(rot_errors)),
                    "min_rotation_error_deg": float(np.min(rot_errors)),
                    "std_rotation_error_deg": float(np.std(rot_errors)),
                },
                "frame_categories": {
                    "accepted": accepted,
                    "rejected_learnable": rejected_learnable,
                    "rejected_outlier": rejected_outlier,
                },
                "num_frames": len(frames),
            }
            output_json = output_dir / "validation_results.json"
            with open(output_json, "w") as f:
                json.dump(results_dict, f, indent=2)
            logger.info(f"\n✓ Results saved to {output_json}")
            # Show some examples
            if rejected_learnable:
                logger.info("\nExample Rejected-Learnable frames (first 10):")
                for idx in rejected_learnable[:10]:
                    logger.info(f" Frame {idx}: {rot_errors[idx]:.2f}°")
            if rejected_outlier:
                logger.info("\nExample Rejected-Outlier frames (first 10):")
                for idx in rejected_outlier[:10]:
                    logger.info(f" Frame {idx}: {rot_errors[idx]:.2f}°")
    logger.info(f"\n{'=' * 60}")
    logger.info("✓ BA validation complete!")
    logger.info(f"{'=' * 60}")
if __name__ == "__main__":
main()