# Source: sam3/metrics_evaluation/run_evaluation.py
# Last commit 03a45bc (Thibaut): "Fix import paths for metrics evaluation -
# corrected relative imports and client class names"
#!/usr/bin/env python3
"""Main execution script for SAM3 metrics evaluation."""
import argparse
import json
import logging
import sys
from pathlib import Path
from metrics_evaluation.config.config_loader import load_config
from metrics_evaluation.extraction.cvat_extractor import CVATExtractor
from metrics_evaluation.inference.sam3_inference import SAM3Inferencer
from metrics_evaluation.metrics.metrics_calculator import MetricsCalculator
from metrics_evaluation.utils.logging_config import setup_logging
from metrics_evaluation.visualization.visual_comparison import VisualComparator
logger = logging.getLogger(__name__)
def write_metrics_summary(metrics: dict, output_path: Path) -> None:
    """Render the evaluation metrics as a human-readable text report.

    Args:
        metrics: Metrics dictionary; must contain an "aggregate" entry with
            per-threshold overall stats, per-label stats, and a confusion
            matrix.
        output_path: Destination file for the plain-text summary.
    """
    heavy = "=" * 80
    light = "-" * 80

    with open(output_path, "w") as f:
        # Accumulate the report line-by-line, then flush in a single call.
        lines: list[str] = [
            heavy + "\n",
            "SAM3 EVALUATION METRICS SUMMARY\n",
            heavy + "\n\n",
        ]
        aggregate = metrics["aggregate"]
        lines.append(f"Total Images Evaluated: {aggregate['total_images']}\n\n")

        # One section per configured IoU threshold.
        for data in aggregate["by_threshold"].values():
            lines.append(f"\n{heavy}\n")
            lines.append(f"IoU Threshold: {data['iou_threshold']:.0%}\n")
            lines.append(f"{heavy}\n\n")

            overall = data["overall"]
            lines.append("Overall Metrics:\n")
            lines.append(f" True Positives: {overall['true_positives']}\n")
            lines.append(f" False Positives: {overall['false_positives']}\n")
            lines.append(f" False Negatives: {overall['false_negatives']}\n")
            lines.append(f" Precision: {overall['precision']:.2%}\n")
            lines.append(f" Recall: {overall['recall']:.2%}\n")
            lines.append(f" F1-Score: {overall['f1']:.2%}\n")
            lines.append(f" mAP: {overall['map']:.2%}\n")
            lines.append(f" mAR: {overall['mar']:.2%}\n\n")

            # Fixed-width per-class table, sorted by label for stable output.
            lines.append("Per-Class Metrics:\n")
            lines.append(light + "\n")
            lines.append(
                f"{'Class':<20} {'GT':>6} {'Pred':>6} {'TP':>6} {'FP':>6} "
                f"{'FN':>6} {'Prec':>8} {'Rec':>8} {'F1':>8}\n"
            )
            lines.append(light + "\n")
            for label, stats in sorted(data["by_label"].items()):
                lines.append(
                    f"{label:<20} {stats['gt_total']:>6} "
                    f"{stats['pred_total']:>6} {stats['tp']:>6} "
                    f"{stats['fp']:>6} {stats['fn']:>6} "
                    f"{stats['precision']:>8.2%} {stats['recall']:>8.2%} "
                    f"{stats['f1']:>8.2%}\n"
                )
            lines.append("\n")

            # Confusion matrix (skipped when there are no labels).
            cm = data["confusion_matrix"]
            cm_labels = cm["labels"]
            cm_matrix = cm["matrix"]
            if cm_labels:
                lines.append("Confusion Matrix:\n")
                lines.append(light + "\n")
                # Labels are truncated to keep the columns aligned.
                header = "Actual \\ Pred |" + "".join(
                    f" {lbl[:10]:>10} |" for lbl in cm_labels
                )
                lines.append(header + "\n")
                lines.append("-" * len(header) + "\n")
                for i, actual in enumerate(cm_labels):
                    cells = "".join(
                        f" {cm_matrix[i][j]:>10} |"
                        for j in range(len(cm_labels))
                    )
                    lines.append(f"{actual[:13]:>13} |" + cells + "\n")
                lines.append("\n")

        lines.extend([heavy + "\n", "END OF REPORT\n", heavy + "\n"])
        f.writelines(lines)

    logger.info(f"Wrote metrics summary to {output_path}")
def main() -> int:
"""Main execution function.
Returns:
Exit code (0 for success, non-zero for failure)
"""
parser = argparse.ArgumentParser(
description="Run SAM3 metrics evaluation against CVAT ground truth"
)
parser.add_argument(
"--config",
type=str,
default="config/config.json",
help="Path to configuration file"
)
parser.add_argument(
"--force-download",
action="store_true",
help="Force re-download images from CVAT"
)
parser.add_argument(
"--force-inference",
action="store_true",
help="Force re-run SAM3 inference"
)
parser.add_argument(
"--skip-inference",
action="store_true",
help="Skip inference, use cached results"
)
parser.add_argument(
"--visualize",
action="store_true",
help="Generate visual comparisons"
)
parser.add_argument(
"--log-level",
type=str,
default="INFO",
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
help="Logging level"
)
args = parser.parse_args()
# Load configuration
try:
config = load_config(args.config)
except Exception as e:
print(f"ERROR: Failed to load configuration: {e}", file=sys.stderr)
return 1
# Setup logging
cache_dir = config.get_cache_path()
log_file = cache_dir / "evaluation_log.txt"
setup_logging(log_file, getattr(logging, args.log_level))
logger.info("=" * 80)
logger.info("SAM3 METRICS EVALUATION")
logger.info("=" * 80)
try:
# Phase 1: Extract from CVAT
logger.info("\n" + "=" * 80)
logger.info("PHASE 1: CVAT Data Extraction")
logger.info("=" * 80)
extractor = CVATExtractor(config)
if args.force_download:
logger.info("Force download enabled - will re-download all images")
image_paths = extractor.run_extraction()
total_extracted = sum(len(paths) for paths in image_paths.values())
logger.info(f"Extraction complete: {total_extracted} images extracted")
if total_extracted == 0:
logger.error("No images extracted. Aborting.")
return 1
# Phase 2: Run SAM3 Inference
if not args.skip_inference:
logger.info("\n" + "=" * 80)
logger.info("PHASE 2: SAM3 Inference")
logger.info("=" * 80)
inferencer = SAM3Inferencer(config)
stats = inferencer.run_inference_batch(image_paths, args.force_inference)
logger.info(
f"Inference complete: {stats['successful']} successful, "
f"{stats['failed']} failed, {stats['skipped']} skipped"
)
if stats['successful'] == 0 and stats['skipped'] == 0:
logger.error("No successful inferences. Aborting.")
return 1
else:
logger.info("Skipping inference (--skip-inference)")
# Phase 3: Calculate Metrics
logger.info("\n" + "=" * 80)
logger.info("PHASE 3: Metrics Calculation")
logger.info("=" * 80)
calculator = MetricsCalculator(config)
metrics = calculator.run_evaluation(cache_dir)
# Save detailed metrics
metrics_json_path = cache_dir / "metrics_detailed.json"
with open(metrics_json_path, "w") as f:
json.dump(metrics, f, indent=2)
logger.info(f"Saved detailed metrics to {metrics_json_path}")
# Write summary
metrics_summary_path = cache_dir / "metrics_summary.txt"
write_metrics_summary(metrics, metrics_summary_path)
# Phase 4: Visualization (optional)
if args.visualize or config.output.generate_visualizations:
logger.info("\n" + "=" * 80)
logger.info("PHASE 4: Visual Comparisons")
logger.info("=" * 80)
comparator = VisualComparator()
comparison_paths = comparator.generate_all_comparisons(cache_dir)
logger.info(f"Generated {len(comparison_paths)} visual comparisons")
# Summary
logger.info("\n" + "=" * 80)
logger.info("EVALUATION COMPLETE")
logger.info("=" * 80)
aggregate = metrics["aggregate"]
logger.info(f"Images evaluated: {aggregate['total_images']}")
# Show metrics at 50% IoU
threshold_50 = aggregate["by_threshold"]["0.5"]
overall = threshold_50["overall"]
logger.info(f"\nMetrics at 50% IoU:")
logger.info(f" Precision: {overall['precision']:.2%}")
logger.info(f" Recall: {overall['recall']:.2%}")
logger.info(f" F1-Score: {overall['f1']:.2%}")
logger.info(f" mAP: {overall['map']:.2%}")
logger.info(f" mAR: {overall['mar']:.2%}")
logger.info(f"\nResults saved to:")
logger.info(f" Metrics Summary: {metrics_summary_path}")
logger.info(f" Detailed JSON: {metrics_json_path}")
logger.info(f" Execution Log: {log_file}")
return 0
except KeyboardInterrupt:
logger.warning("\nEvaluation interrupted by user")
return 130
except Exception as e:
logger.error(f"\nEvaluation failed with error: {e}", exc_info=True)
return 1
if __name__ == "__main__":
    # Propagate main()'s exit status (0 success, 1 failure, 130 interrupt)
    # to the calling shell.
    sys.exit(main())