|
|
|
|
|
""" |
|
|
Run BA validation on full video to identify rejected frames. |
|
|
""" |
|
|
|
|
|
import os |
|
|
import sys |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" |
|
|
|
|
|
|
|
|
superglue_path = Path("/tmp/SuperGluePretrainedNetwork") |
|
|
if superglue_path.exists(): |
|
|
if str(superglue_path) not in sys.path: |
|
|
sys.path.insert(0, str(superglue_path)) |
|
|
|
|
|
|
|
|
import logging |
|
|
|
|
|
logging.basicConfig( |
|
|
level=logging.INFO, |
|
|
format="%(asctime)s - %(levelname)s - %(message)s", |
|
|
force=True, |
|
|
) |
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
logger.info("=" * 60) |
|
|
logger.info("Starting BA Validation Script") |
|
|
logger.info("=" * 60) |
|
|
logger.info("Step 0: Importing dependencies...") |
|
|
|
|
|
|
|
|
project_root = Path(__file__).parent.parent |
|
|
sys.path.insert(0, str(project_root)) |
|
|
logger.info(f"Project root: {project_root}") |
|
|
|
|
|
try: |
|
|
logger.info(" - Importing numpy...") |
|
|
import numpy as np |
|
|
|
|
|
logger.info(" ✓ numpy imported") |
|
|
|
|
|
logger.info(" - Importing cv2...") |
|
|
import cv2 |
|
|
|
|
|
logger.info(" ✓ cv2 imported") |
|
|
|
|
|
logger.info(" - Importing torch...") |
|
|
import torch |
|
|
|
|
|
logger.info(" ✓ torch imported") |
|
|
|
|
|
logger.info(" - Importing tqdm...") |
|
|
from tqdm import tqdm |
|
|
|
|
|
logger.info(" ✓ tqdm imported") |
|
|
|
|
|
logger.info(" - Importing json...") |
|
|
import json |
|
|
|
|
|
logger.info(" ✓ json imported") |
|
|
|
|
|
logger.info(" - Importing typing...") |
|
|
from typing import Optional |
|
|
|
|
|
logger.info(" ✓ typing imported") |
|
|
|
|
|
logger.info(" - Importing ylff modules...") |
|
|
from ylff.utils.model_loader import load_da3_model |
|
|
|
|
|
logger.info(" ✓ ylff.models imported") |
|
|
|
|
|
from ylff.services.ba_validator import BAValidator |
|
|
|
|
|
logger.info(" ✓ ylff.ba_validator imported") |
|
|
|
|
|
logger.info("✓ All imports complete") |
|
|
except Exception as e: |
|
|
logger.error(f"✗ Import failed: {e}") |
|
|
import traceback |
|
|
|
|
|
traceback.print_exc() |
|
|
sys.exit(1) |
|
|
|
|
|
|
|
|
def extract_all_frames(
    video_path: Path, max_frames: Optional[int] = None, frame_interval: int = 1
) -> list:
    """Extract frames from a video as RGB numpy arrays.

    Two-pass strategy: first keep every ``frame_interval``-th frame; if that
    yields fewer than the requested count, rewind and fill with previously
    skipped frames until the target is reached or the video is exhausted.
    Frames are returned sorted by their original frame index.

    Args:
        video_path: Path to the video file.
        max_frames: Upper bound on frames returned; ``None`` means all frames.
        frame_interval: Keep every Nth frame during the first pass.

    Returns:
        List of RGB frames (as returned by ``cv2.cvtColor``), in temporal order.

    Raises:
        ValueError: If the video cannot be opened.
    """
    logger.info(f"Extracting frames from {video_path}")

    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        raise ValueError(f"Could not open video: {video_path}")

    # Ensure the capture is released even if extraction fails partway through
    # (the original leaked it on any exception).
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        logger.info(f"Video properties: {total_frames} frames, {fps:.2f} fps")

        frames_to_extract = max_frames if max_frames else total_frames
        # Ceiling division: how many frames the interval alone can yield.
        frames_at_interval = (total_frames + frame_interval - 1) // frame_interval

        logger.info(f"Target: {frames_to_extract} frames")
        logger.info(f" - At interval {frame_interval}: can get up to {frames_at_interval} frames")

        interval_frames = []  # list of (frame_index, rgb_frame)
        frame_idx = 0

        with tqdm(total=frames_to_extract, desc="Extracting frames", unit="frame") as pbar:
            # Pass 1: sequential read, keeping every `frame_interval`-th frame.
            cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
            while frame_idx < total_frames and len(interval_frames) < frames_to_extract:
                ret, frame = cap.read()
                if not ret:
                    break

                if frame_idx % frame_interval == 0:
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    interval_frames.append((frame_idx, frame_rgb))
                    pbar.update(1)

                frame_idx += 1

            # Pass 2: if the interval alone was not enough, rewind and fill
            # with frames that were skipped, until the target is reached.
            if len(interval_frames) < frames_to_extract:
                logger.info(f" - Got {len(interval_frames)} frames at interval {frame_interval}")
                logger.info(
                    f" - Filling remaining {frames_to_extract - len(interval_frames)} frames..."
                )
                cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
                frame_idx = 0
                extracted_indices = {idx for idx, _ in interval_frames}

                while len(interval_frames) < frames_to_extract:
                    ret, frame = cap.read()
                    if not ret:
                        # Target unreachable (e.g. max_frames > total frames).
                        logger.warning(f" - Video ended. Got {len(interval_frames)} frames total.")
                        break

                    if frame_idx not in extracted_indices:
                        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                        interval_frames.append((frame_idx, frame_rgb))
                        pbar.update(1)

                    frame_idx += 1

        # Restore temporal order (pass 2 appends fill frames out of order).
        interval_frames.sort(key=lambda x: x[0])
        frames = [img for _, img in interval_frames]

        logger.info(f" - Final: {len(frames)} frames extracted")
        if len(interval_frames) > 0:
            frame_indices = [idx for idx, _ in interval_frames]
            logger.info(
                f" - Frame indices: {frame_indices[0]}...{frame_indices[-1]} "
                f"(every {frame_interval} + fills)"
            )
    finally:
        cap.release()

    logger.info(f"✓ Extracted {len(frames)} total frames")
    return frames
|
|
|
|
|
|
|
|
def main():
    """Run the pipeline: extract frames, DA3 inference, BA validation, report.

    Reads CLI args (--video, --max-frames, --frame-interval, --output-dir),
    writes ``validation_results.json`` into the output directory, and logs a
    per-frame accept/reject breakdown.
    """
    import argparse

    logger.info("\n" + "=" * 60)
    logger.info("Parsing arguments...")

    parser = argparse.ArgumentParser(description="Run BA validation on video")
    parser.add_argument("--video", type=str, default=None, help="Path to video file")
    parser.add_argument(
        "--max-frames", type=int, default=None, help="Maximum number of frames to process"
    )
    parser.add_argument(
        "--frame-interval",
        type=int,
        default=1,
        help="Extract every Nth frame (e.g., 15 for every 15th frame)",
    )
    parser.add_argument("--output-dir", type=str, default=None, help="Output directory")
    args = parser.parse_args()

    # Resolve input/output locations, falling back to repository defaults.
    if args.video:
        video_path = Path(args.video)
    else:
        video_path = project_root / "assets" / "examples" / "robot_unitree.mp4"

    if args.output_dir:
        output_dir = Path(args.output_dir)
    else:
        output_dir = project_root / "data" / "ba_validation_results"
    output_dir.mkdir(parents=True, exist_ok=True)

    logger.info("=" * 60)
    logger.info("BA Validation on Full Video")
    logger.info("=" * 60)
    logger.info(f"Video: {video_path}")
    logger.info(f"Output: {output_dir}")
    if args.max_frames:
        logger.info(f"Max frames: {args.max_frames}")
    logger.info("=" * 60)

    # ---- Step 1: frame extraction ------------------------------------------
    logger.info("\n[Step 1] Extracting frames from video...")
    frames = extract_all_frames(
        video_path, max_frames=args.max_frames, frame_interval=args.frame_interval
    )
    logger.info(f"✓ Extracted {len(frames)} frames (every {args.frame_interval} frame(s))")

    # ---- Step 2: DA3 inference ---------------------------------------------
    logger.info("\n[Step 2] Running DA3 inference...")
    logger.info("Loading DA3 model (this may take a moment)...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device: {device}")

    model = load_da3_model("depth-anything/DA3-LARGE", device=device)
    logger.info("✓ Model loaded")

    logger.info(f"Running inference on {len(frames)} frames (this may take a while)...")
    logger.info(" - Processing in batches...")
    logger.info(" - This step can take several minutes on CPU...")

    import threading
    import time

    # Background heartbeat so long inference runs still emit periodic output.
    inference_done = threading.Event()

    def heartbeat():
        count = 0
        while not inference_done.wait(30):
            count += 1
            logger.info(f" - Still processing... ({count * 30}s elapsed)")

    heartbeat_thread = threading.Thread(target=heartbeat, daemon=True)
    heartbeat_thread.start()

    start_time = time.time()

    try:
        with torch.no_grad():
            logger.info(" - Starting DA3 forward pass...")
            da3_output = model.inference(frames)
        elapsed = time.time() - start_time
        logger.info(
            f" - DA3 inference completed in {elapsed:.1f} seconds "
            f"({elapsed / len(frames):.2f}s per frame)"
        )
    finally:
        # Always stop the heartbeat, even if inference raised.
        inference_done.set()

    poses_da3 = da3_output.extrinsics
    intrinsics = da3_output.intrinsics if hasattr(da3_output, "intrinsics") else None

    logger.info("✓ DA3 inference complete")
    logger.info(f" - Poses shape: {poses_da3.shape}")
    logger.info(f" - Intrinsics shape: {intrinsics.shape if intrinsics is not None else 'None'}")

    # ---- Step 3: BA validation ---------------------------------------------
    logger.info("\n[Step 3] Running BA validation...")
    logger.info("Initializing BA validator...")
    validator = BAValidator(
        accept_threshold=2.0,
        reject_threshold=30.0,
        work_dir=output_dir / "ba_work",
    )
    logger.info("✓ BA validator initialized")

    logger.info("Validating poses with BA (this may take a while)...")
    logger.info(" - Step 3.1: Saving images...")
    logger.info(" - Step 3.2: Extracting features (SuperPoint)...")
    logger.info(" - Step 3.3: Matching features (LightGlue)...")
    logger.info(" - Step 3.4: Running Bundle Adjustment...")

    result = validator.validate(
        images=frames,
        poses_model=poses_da3,
        intrinsics=intrinsics,
    )

    logger.info("✓ BA validation complete")

    # ---- Step 4: analyze and persist results -------------------------------
    logger.info("\n[Step 4] Analyzing results...")

    status = result["status"]
    error = result.get("error")
    error_metrics = result.get("error_metrics", {})

    logger.info(f"\n{'=' * 60}")
    logger.info("RESULTS")
    logger.info(f"{'=' * 60}")
    logger.info(f"Overall Status: {status}")

    if error is not None and isinstance(error, (int, float)):
        logger.info(f"Max Rotation Error: {error:.2f}°")
    elif error is not None:
        logger.info(f"Max Rotation Error: {error}")

    # Bug fix: bind these before the conditional so building results_dict
    # below never raises NameError when the validator returns no metrics.
    rot_errors = []
    accepted = []
    rejected_learnable = []
    rejected_outlier = []

    if error_metrics:
        rot_errors = error_metrics.get("rotation_errors_deg", [])
        if rot_errors:
            logger.info("\nRotation Error Statistics:")
            logger.info(f" - Mean: {np.mean(rot_errors):.2f}°")
            logger.info(f" - Median: {np.median(rot_errors):.2f}°")
            logger.info(f" - Max: {np.max(rot_errors):.2f}°")
            logger.info(f" - Min: {np.min(rot_errors):.2f}°")
            logger.info(f" - Std: {np.std(rot_errors):.2f}°")

            # Buckets mirror the validator thresholds (accept 2°, reject 30°).
            for i, err in enumerate(rot_errors):
                if err < 2.0:
                    accepted.append(i)
                elif err < 30.0:
                    rejected_learnable.append(i)
                else:
                    rejected_outlier.append(i)

            logger.info("\nFrame Categorization:")
            accepted_pct = 100 * len(accepted) / len(rot_errors)
            learnable_pct = 100 * len(rejected_learnable) / len(rot_errors)
            outlier_pct = 100 * len(rejected_outlier) / len(rot_errors)
            logger.info(f" - Accepted (< 2°): {len(accepted)} frames ({accepted_pct:.1f}%)")
            logger.info(
                f" - Rejected-Learnable (2-30°): {len(rejected_learnable)} frames "
                f"({learnable_pct:.1f}%)"
            )
            logger.info(
                f" - Rejected-Outlier (> 30°): {len(rejected_outlier)} frames "
                f"({outlier_pct:.1f}%)"
            )

    results_dict = {
        "status": status,
        # Bug fix: float(error) crashed on non-numeric error payloads, which
        # the logging above explicitly anticipates.
        "error": float(error) if isinstance(error, (int, float)) else None,
        "error_metrics": {
            "rotation_errors_deg": [float(e) for e in rot_errors],
            # Bug fix: aggregates are None instead of crashing on empty input
            # (np.max/np.min raise ValueError on an empty sequence).
            "mean_rotation_error_deg": float(np.mean(rot_errors)) if rot_errors else None,
            "median_rotation_error_deg": float(np.median(rot_errors)) if rot_errors else None,
            "max_rotation_error_deg": float(np.max(rot_errors)) if rot_errors else None,
            "min_rotation_error_deg": float(np.min(rot_errors)) if rot_errors else None,
            "std_rotation_error_deg": float(np.std(rot_errors)) if rot_errors else None,
        },
        "frame_categories": {
            "accepted": accepted,
            "rejected_learnable": rejected_learnable,
            "rejected_outlier": rejected_outlier,
        },
        "num_frames": len(frames),
    }

    output_json = output_dir / "validation_results.json"
    with open(output_json, "w") as f:
        json.dump(results_dict, f, indent=2)

    logger.info(f"\n✓ Results saved to {output_json}")

    if rejected_learnable:
        logger.info("\nExample Rejected-Learnable frames (first 10):")
        for idx in rejected_learnable[:10]:
            logger.info(f" Frame {idx}: {rot_errors[idx]:.2f}°")

    if rejected_outlier:
        logger.info("\nExample Rejected-Outlier frames (first 10):")
        for idx in rejected_outlier[:10]:
            logger.info(f" Frame {idx}: {rot_errors[idx]:.2f}°")

    logger.info(f"\n{'=' * 60}")
    logger.info("✓ BA validation complete!")
    logger.info(f"{'=' * 60}")
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |
|
|
|