3d_model / ylff /services /evaluate.py
Azan
Clean deployment build (Squashed)
7a87926
"""
Evaluation scripts for measuring improvement after fine-tuning.
"""
import logging
from pathlib import Path
from typing import Dict, List, Optional
import numpy as np
import torch
from ..utils.wandb_utils import finish_wandb, init_wandb, log_metrics
from .ba_validator import BAValidator
logger = logging.getLogger(__name__)
def evaluate_ba_agreement(
    model: torch.nn.Module,
    sequences: List[Path],
    ba_validator: "BAValidator",
    threshold: float = 2.0,  # degrees
    use_wandb: bool = True,
    wandb_project: str = "ylff",
    wandb_name: Optional[str] = None,
) -> Dict:
    """
    Evaluate model agreement with BA.

    For each sequence the images are loaded, the model is run through
    ``model.inference(images)``, and its ``extrinsics`` are handed to
    ``ba_validator.validate``. A sequence "agrees" when the validator's
    reported ``error`` is strictly below ``threshold``. Sequences with no
    loadable images, or whose validation status is ``"ba_failed"``, are
    skipped and do not count toward the totals.

    Args:
        model: Model to evaluate (put into eval mode by this function)
        sequences: List of sequence paths
        ba_validator: BA validator
        threshold: Agreement threshold (degrees)
        use_wandb: Whether to log the aggregate metrics to wandb
        wandb_project: wandb project name
        wandb_name: wandb run name; defaults to
            ``eval-ba-agreement-<num sequences>-seqs``

    Returns:
        Dictionary with evaluation metrics: ``agreement_rate``,
        ``agreement_count``, ``total_count``, ``mean_rotation_error_deg``,
        ``mean_translation_error``, and the raw ``rotation_errors`` /
        ``translation_errors`` lists. Rates and means are ``0.0`` when no
        sequence could be evaluated.
    """
    model.eval()

    agreement_count = 0
    total_count = 0
    rotation_errors = []
    translation_errors = []

    for seq_path in sequences:
        # Load images
        images = load_images(seq_path)
        if not images:
            logger.warning("No images loaded from %s; skipping sequence", seq_path)
            continue

        # Run model
        with torch.no_grad():
            output = model.inference(images)
        poses_model = output.extrinsics

        # Validate with BA
        result = ba_validator.validate(
            images=images,
            poses_model=poses_model,
        )
        if result["status"] == "ba_failed":
            logger.warning("BA failed for %s; skipping sequence", seq_path)
            continue

        total_count += 1
        error = result["error"]
        if error < threshold:
            agreement_count += 1
        # Every validated sequence contributes to the mean, not only agreeing ones.
        rotation_errors.append(error)

        if "error_metrics" in result:
            translation_errors.extend(result["error_metrics"].get("translation_errors", []))

    agreement_rate = agreement_count / total_count if total_count > 0 else 0.0
    # Cast to plain float so the value has the same type whether or not any
    # errors were collected (np.mean would otherwise yield a NumPy scalar).
    mean_rot_error = float(np.mean(rotation_errors)) if rotation_errors else 0.0
    mean_trans_error = float(np.mean(translation_errors)) if translation_errors else 0.0

    metrics = {
        "agreement_rate": agreement_rate,
        "agreement_count": agreement_count,
        "total_count": total_count,
        "mean_rotation_error_deg": mean_rot_error,
        "mean_translation_error": mean_trans_error,
        "rotation_errors": rotation_errors,
        "translation_errors": translation_errors,
    }

    # Log to wandb
    if use_wandb:
        wandb_run = init_wandb(
            project=wandb_project,
            name=wandb_name or f"eval-ba-agreement-{len(sequences)}-seqs",
            config={
                "task": "evaluation",
                "threshold": threshold,
                "num_sequences": len(sequences),
            },
            tags=["evaluation", "ba-agreement"],
        )
        if wandb_run:
            try:
                log_metrics(
                    {
                        "eval/agreement_rate": agreement_rate,
                        "eval/agreement_count": agreement_count,
                        "eval/total_count": total_count,
                        "eval/mean_rotation_error_deg": mean_rot_error,
                        "eval/mean_translation_error": mean_trans_error,
                    }
                )
            finally:
                # Close the run even if logging fails so it is not left open.
                finish_wandb()

    return metrics
def load_images(sequence_path: Path) -> List[np.ndarray]:
    """Load images from sequence directory.

    Files directly inside ``sequence_path`` whose suffix is ``.jpg``,
    ``.jpeg`` or ``.png`` (matched case-insensitively) are read in sorted
    path order and converted from OpenCV's BGR channel order to RGB.
    Files that OpenCV cannot decode are skipped with a warning.

    Args:
        sequence_path: Directory containing the sequence's image files

    Returns:
        List of RGB image arrays; empty when the directory has no
        matching, readable images.
    """
    image_extensions = {".jpg", ".jpeg", ".png"}
    # Compare the lower-cased suffix so mixed-case names (".Jpg") also match.
    image_paths = sorted(
        p
        for p in sequence_path.iterdir()
        if p.is_file() and p.suffix.lower() in image_extensions
    )
    if not image_paths:
        return []

    # Imported only once there is something to decode, keeping OpenCV an
    # optional dependency for callers that never hit an image.
    import cv2

    images = []
    for img_path in image_paths:
        img = cv2.imread(str(img_path))
        if img is None:
            logger.warning("Could not read image %s; skipping", img_path)
            continue
        images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    return images