"""
Evaluation scripts for measuring improvement after fine-tuning.
"""

import logging
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import torch

from ..utils.wandb_utils import finish_wandb, init_wandb, log_metrics
from .ba_validator import BAValidator

logger = logging.getLogger(__name__)

# Lower-case file suffixes treated as images; matching is case-insensitive.
_IMAGE_EXTENSIONS = frozenset({".jpg", ".jpeg", ".png"})

# Scalar entries of the metrics dict that are forwarded to W&B (as "eval/<key>").
_WANDB_SCALAR_KEYS = (
    "agreement_rate",
    "agreement_count",
    "total_count",
    "mean_rotation_error_deg",
    "mean_translation_error",
)


def evaluate_ba_agreement(
    model: torch.nn.Module,
    sequences: List[Path],
    ba_validator: BAValidator,
    threshold: float = 2.0,  # degrees
    use_wandb: bool = True,
    wandb_project: str = "ylff",
    wandb_name: Optional[str] = None,
) -> Dict:
    """
    Evaluate model agreement with BA.

    For each sequence, the images are loaded, the model's ``inference`` output
    extrinsics are taken as predicted poses, and ``ba_validator.validate`` is
    asked to score them. A sequence "agrees" when ``result["error"]`` is below
    ``threshold``.

    Args:
        model: Model to evaluate. It is switched to eval mode (and left there)
            and must provide ``inference(images)`` returning an object with an
            ``extrinsics`` attribute.
        sequences: List of sequence directory paths containing image files.
        ba_validator: BA validator used to score the predicted poses.
        threshold: Agreement threshold (degrees) compared against
            ``result["error"]``.
        use_wandb: If True, log the scalar summary metrics to W&B.
        wandb_project: W&B project name.
        wandb_name: W&B run name; defaults to
            ``"eval-ba-agreement-<len(sequences)>-seqs"``.

    Returns:
        Dictionary with evaluation metrics: ``agreement_rate``,
        ``agreement_count``, ``total_count``, ``mean_rotation_error_deg``,
        ``mean_translation_error``, ``rotation_errors`` (one entry per counted
        sequence) and ``translation_errors`` (concatenated per-sequence lists).
        Sequences with no readable images, or whose validation status is
        ``"ba_failed"``, are excluded from every count and list. Means and the
        agreement rate are 0.0 when there is nothing to average.
    """
    model.eval()

    agreement_count = 0
    total_count = 0
    rotation_errors: List = []
    translation_errors: List = []

    for seq_path in sequences:
        images = load_images(seq_path)
        if not images:
            logger.warning("Skipping sequence %s: no readable images", seq_path)
            continue

        with torch.no_grad():
            output = model.inference(images)
        poses_model = output.extrinsics

        result = ba_validator.validate(
            images=images,
            poses_model=poses_model,
        )
        if result["status"] == "ba_failed":
            logger.warning("Skipping sequence %s: BA failed", seq_path)
            continue

        total_count += 1
        # NOTE(review): "error" is recorded as the rotation error (degrees),
        # matching the threshold's unit — confirm against BAValidator.
        error = result["error"]
        if error < threshold:
            agreement_count += 1
        rotation_errors.append(error)

        if "error_metrics" in result:
            translation_errors.extend(result["error_metrics"].get("translation_errors", []))

    agreement_rate = agreement_count / total_count if total_count > 0 else 0.0
    mean_rot_error = np.mean(rotation_errors) if rotation_errors else 0.0
    mean_trans_error = np.mean(translation_errors) if translation_errors else 0.0

    metrics = {
        "agreement_rate": agreement_rate,
        "agreement_count": agreement_count,
        "total_count": total_count,
        "mean_rotation_error_deg": mean_rot_error,
        "mean_translation_error": mean_trans_error,
        "rotation_errors": rotation_errors,
        "translation_errors": translation_errors,
    }

    if use_wandb:
        _log_to_wandb(
            metrics,
            threshold=threshold,
            num_sequences=len(sequences),
            wandb_project=wandb_project,
            wandb_name=wandb_name,
        )

    return metrics


def _log_to_wandb(
    metrics: Dict,
    *,
    threshold: float,
    num_sequences: int,
    wandb_project: str,
    wandb_name: Optional[str],
) -> None:
    """Open a W&B run, log the scalar BA-agreement metrics, and always close it."""
    wandb_run = init_wandb(
        project=wandb_project,
        name=wandb_name or f"eval-ba-agreement-{num_sequences}-seqs",
        config={
            "task": "evaluation",
            "threshold": threshold,
            "num_sequences": num_sequences,
        },
        tags=["evaluation", "ba-agreement"],
    )
    if not wandb_run:
        # NOTE(review): init_wandb presumably returns a falsy value when W&B is
        # disabled/unavailable; nothing was opened, so nothing to finish.
        return

    try:
        log_metrics({f"eval/{key}": metrics[key] for key in _WANDB_SCALAR_KEYS})
    finally:
        # Close the run even if logging raises, so it is not left dangling.
        finish_wandb()


def load_images(sequence_path: Path) -> List[np.ndarray]:
    """
    Load images from sequence directory.

    Regular files whose suffix is .jpg, .jpeg or .png (any letter case) are
    read in sorted path order with OpenCV and converted from BGR to RGB. Files
    OpenCV cannot decode are skipped with a warning.

    Args:
        sequence_path: Directory containing the sequence's image files
            (a ``str`` is also accepted).

    Returns:
        List of RGB images as returned by ``cv2.cvtColor``; empty if the
        directory holds no readable images.
    """
    # Local import; presumably keeps OpenCV optional for importers of this module.
    import cv2

    sequence_path = Path(sequence_path)
    image_paths = sorted(
        p
        for p in sequence_path.iterdir()
        if p.is_file() and p.suffix.lower() in _IMAGE_EXTENSIONS
    )

    images = []
    for img_path in image_paths:
        img = cv2.imread(str(img_path))
        if img is None:
            logger.warning("Could not read image %s; skipping", img_path)
            continue
        images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

    return images