| | """ |
| | Evaluation scripts for measuring improvement after fine-tuning. |
| | """ |
| |
|
| | import logging |
| | from pathlib import Path |
| | from typing import Dict, List, Optional |
| | import numpy as np |
| | import torch |
| |
|
| | from ..utils.wandb_utils import finish_wandb, init_wandb, log_metrics |
| | from .ba_validator import BAValidator |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
def evaluate_ba_agreement(
    model: torch.nn.Module,
    sequences: List[Path],
    ba_validator: BAValidator,
    threshold: float = 2.0,
    use_wandb: bool = True,
    wandb_project: str = "ylff",
    wandb_name: Optional[str] = None,
) -> Dict:
    """
    Measure how often the model's predicted poses agree with bundle adjustment.

    Args:
        model: Model to evaluate (switched to eval mode).
        sequences: List of sequence directory paths to process.
        ba_validator: BA validator used to score each sequence.
        threshold: Rotation-error threshold (degrees) below which a sequence
            counts as agreeing with BA.
        use_wandb: Whether to log summary metrics to Weights & Biases.
        wandb_project: W&B project name used when logging is enabled.
        wandb_name: Optional W&B run name; a default derived from the number
            of sequences is used when omitted.

    Returns:
        Dictionary with aggregate agreement metrics and per-sequence errors.
    """
    model.eval()

    n_agree = 0
    n_scored = 0
    rot_errs: List[float] = []
    trans_errs: List[float] = []

    for seq_dir in sequences:
        frames = load_images(seq_dir)
        if not frames:
            continue

        with torch.no_grad():
            prediction = model.inference(frames)

        result = ba_validator.validate(
            images=frames,
            poses_model=prediction.extrinsics,
        )

        # Sequences where BA itself failed are excluded from the statistics.
        if result["status"] == "ba_failed":
            continue

        n_scored += 1
        err = result["error"]
        if err < threshold:
            n_agree += 1

        rot_errs.append(err)
        if "error_metrics" in result:
            trans_errs.extend(result["error_metrics"].get("translation_errors", []))

    # Guard against an empty evaluation (all sequences skipped or BA failed).
    agreement_rate = n_agree / n_scored if n_scored > 0 else 0.0
    mean_rot = np.mean(rot_errs) if rot_errs else 0.0
    mean_trans = np.mean(trans_errs) if trans_errs else 0.0

    metrics = {
        "agreement_rate": agreement_rate,
        "agreement_count": n_agree,
        "total_count": n_scored,
        "mean_rotation_error_deg": mean_rot,
        "mean_translation_error": mean_trans,
        "rotation_errors": rot_errs,
        "translation_errors": trans_errs,
    }

    if use_wandb:
        wandb_run = init_wandb(
            project=wandb_project,
            name=wandb_name or f"eval-ba-agreement-{len(sequences)}-seqs",
            config={
                "task": "evaluation",
                "threshold": threshold,
                "num_sequences": len(sequences),
            },
            tags=["evaluation", "ba-agreement"],
        )
        if wandb_run:
            log_metrics(
                {
                    "eval/agreement_rate": agreement_rate,
                    "eval/agreement_count": n_agree,
                    "eval/total_count": n_scored,
                    "eval/mean_rotation_error_deg": mean_rot,
                    "eval/mean_translation_error": mean_trans,
                }
            )
            finish_wandb()

    return metrics
| |
|
| |
|
def load_images(sequence_path: Path) -> List[np.ndarray]:
    """
    Load all images from a sequence directory, sorted by filename.

    Args:
        sequence_path: Directory containing the sequence's image files.

    Returns:
        List of RGB images (OpenCV-decoded arrays); files that fail to decode
        are skipped. Empty list if no readable images are found.
    """
    # Imported lazily so the rest of the module works without OpenCV
    # installed when this helper is never called.
    import cv2

    # Match extensions case-insensitively so mixed-case suffixes such as
    # ".Jpg" or ".Png" are also picked up (the previous explicit
    # upper/lower set missed them).
    image_extensions = {".jpg", ".jpeg", ".png"}
    image_paths = sorted(
        p for p in sequence_path.iterdir() if p.suffix.lower() in image_extensions
    )

    images = []
    for img_path in image_paths:
        img = cv2.imread(str(img_path))
        if img is None:
            # Unreadable/corrupt file: skip it rather than fail the sequence.
            continue
        # OpenCV decodes to BGR; convert to RGB for downstream consumers.
        images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

    return images
| |
|