#!/usr/bin/env python3
"""Main execution script for SAM3 metrics evaluation."""

import argparse
import json
import logging
import sys
from pathlib import Path
from typing import TextIO

from metrics_evaluation.config.config_loader import load_config
from metrics_evaluation.extraction.cvat_extractor import CVATExtractor
from metrics_evaluation.inference.sam3_inference import SAM3Inferencer
from metrics_evaluation.metrics.metrics_calculator import MetricsCalculator
from metrics_evaluation.utils.logging_config import setup_logging
from metrics_evaluation.visualization.visual_comparison import VisualComparator

logger = logging.getLogger(__name__)

# Width of the "=" / "-" rules used in the summary file and the log banners.
RULE_WIDTH = 80


def _write_overall(out: TextIO, overall: dict) -> None:
    """Write the all-class metrics section for one IoU threshold."""
    out.write("Overall Metrics:\n")
    out.write(f" True Positives: {overall['true_positives']}\n")
    out.write(f" False Positives: {overall['false_positives']}\n")
    out.write(f" False Negatives: {overall['false_negatives']}\n")
    out.write(f" Precision: {overall['precision']:.2%}\n")
    out.write(f" Recall: {overall['recall']:.2%}\n")
    out.write(f" F1-Score: {overall['f1']:.2%}\n")
    out.write(f" mAP: {overall['map']:.2%}\n")
    out.write(f" mAR: {overall['mar']:.2%}\n\n")


def _write_per_class(out: TextIO, by_label: dict) -> None:
    """Write the per-class table for one IoU threshold, sorted by label."""
    out.write("Per-Class Metrics:\n")
    out.write("-" * RULE_WIDTH + "\n")
    out.write(f"{'Class':<20} {'GT':>6} {'Pred':>6} {'TP':>6} {'FP':>6} {'FN':>6} {'Prec':>8} {'Rec':>8} {'F1':>8}\n")
    out.write("-" * RULE_WIDTH + "\n")
    for label, stats in sorted(by_label.items()):
        out.write(
            f"{label:<20} "
            f"{stats['gt_total']:>6} "
            f"{stats['pred_total']:>6} "
            f"{stats['tp']:>6} "
            f"{stats['fp']:>6} "
            f"{stats['fn']:>6} "
            f"{stats['precision']:>8.2%} "
            f"{stats['recall']:>8.2%} "
            f"{stats['f1']:>8.2%}\n"
        )
    out.write("\n")


def _write_confusion_matrix(out: TextIO, cm: dict) -> None:
    """Write the confusion matrix (rows = actual, columns = predicted).

    Nothing is written when ``cm["labels"]`` is empty. Labels are truncated
    to fit the fixed column widths (10 chars per column, 13 for row names).
    """
    labels = cm["labels"]
    matrix = cm["matrix"]
    if not labels:
        return

    out.write("Confusion Matrix:\n")
    out.write("-" * RULE_WIDTH + "\n")

    header = "Actual \\ Pred |" + "".join(f" {label[:10]:>10} |" for label in labels)
    out.write(header + "\n")
    out.write("-" * len(header) + "\n")

    n_labels = len(labels)
    for i, actual_label in enumerate(labels):
        # Index explicitly (rather than zip) so a matrix smaller than the
        # label list fails loudly instead of being silently truncated.
        cells = "".join(f" {matrix[i][j]:>10} |" for j in range(n_labels))
        out.write(f"{actual_label[:13]:>13} |" + cells + "\n")
    out.write("\n")


def write_metrics_summary(metrics: dict, output_path: Path) -> None:
    """Write human-readable metrics summary.

    The report has one section per entry in
    ``metrics["aggregate"]["by_threshold"]``. Each section contains the
    overall metrics, a per-class table and, when labels exist, a confusion
    matrix.

    Args:
        metrics: Metrics dictionary (as produced by
            ``MetricsCalculator.run_evaluation``)
        output_path: Path to output file (overwritten, UTF-8 encoded)

    Raises:
        KeyError: If ``metrics`` lacks an expected key. This is raised before
            the output file is opened, so no partial report is left behind
            for a missing ``"aggregate"``.
    """
    aggregate = metrics["aggregate"]
    rule = "=" * RULE_WIDTH

    # Explicit UTF-8: class labels may be non-ASCII, and the platform
    # default encoding is not guaranteed to handle them.
    with open(output_path, "w", encoding="utf-8") as out:
        out.write(rule + "\n")
        out.write("SAM3 EVALUATION METRICS SUMMARY\n")
        out.write(rule + "\n\n")
        out.write(f"Total Images Evaluated: {aggregate['total_images']}\n\n")

        for threshold_data in aggregate["by_threshold"].values():
            iou = threshold_data["iou_threshold"]
            out.write(f"\n{'='*80}\n")
            out.write(f"IoU Threshold: {iou:.0%}\n")
            out.write(f"{'='*80}\n\n")

            _write_overall(out, threshold_data["overall"])
            _write_per_class(out, threshold_data["by_label"])
            _write_confusion_matrix(out, threshold_data["confusion_matrix"])

        out.write(rule + "\n")
        out.write("END OF REPORT\n")
        out.write(rule + "\n")

    logger.info(f"Wrote metrics summary to {output_path}")


def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the evaluation script."""
    parser = argparse.ArgumentParser(
        description="Run SAM3 metrics evaluation against CVAT ground truth"
    )
    parser.add_argument(
        "--config",
        type=str,
        default="config/config.json",
        help="Path to configuration file",
    )
    parser.add_argument(
        "--force-download",
        action="store_true",
        help="Force re-download images from CVAT",
    )
    parser.add_argument(
        "--force-inference",
        action="store_true",
        help="Force re-run SAM3 inference",
    )
    parser.add_argument(
        "--skip-inference",
        action="store_true",
        help="Skip inference, use cached results",
    )
    parser.add_argument(
        "--visualize",
        action="store_true",
        help="Generate visual comparisons",
    )
    parser.add_argument(
        "--log-level",
        type=str,
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        help="Logging level",
    )
    return parser


def _log_phase(title: str) -> None:
    """Log a phase banner: a blank line, a rule, the title, and a rule."""
    logger.info("\n" + "=" * RULE_WIDTH)
    logger.info(title)
    logger.info("=" * RULE_WIDTH)


def _log_headline_metrics(aggregate: dict) -> None:
    """Log the overall metrics at 50% IoU, if that threshold was evaluated.

    A missing "0.5" entry only produces a warning. By the time this runs,
    all results have already been written, so it must not turn a
    successful run into a failure.
    """
    threshold_50 = aggregate["by_threshold"].get("0.5")
    if threshold_50 is None:
        logger.warning(
            "No metrics at IoU threshold 0.5; see the metrics summary "
            "for the thresholds that were evaluated"
        )
        return

    overall = threshold_50["overall"]
    logger.info("\nMetrics at 50% IoU:")
    logger.info(f" Precision: {overall['precision']:.2%}")
    logger.info(f" Recall: {overall['recall']:.2%}")
    logger.info(f" F1-Score: {overall['f1']:.2%}")
    logger.info(f" mAP: {overall['map']:.2%}")
    logger.info(f" mAR: {overall['mar']:.2%}")


def main() -> int:
    """Main execution function.

    Runs four phases:

    1. CVAT extraction.
    2. SAM3 inference (unless ``--skip-inference``).
    3. Metrics calculation.
    4. Optional visualisation.

    Results are written under the configured cache directory.

    Returns:
        Exit code (0 for success, 1 for failure, 130 when interrupted)
    """
    args = _build_parser().parse_args()

    # Load configuration (logging is not configured yet, so report to stderr)
    try:
        config = load_config(args.config)
    except Exception as e:
        print(f"ERROR: Failed to load configuration: {e}", file=sys.stderr)
        return 1

    # Setup logging
    cache_dir = config.get_cache_path()
    log_file = cache_dir / "evaluation_log.txt"
    setup_logging(log_file, getattr(logging, args.log_level))

    logger.info("=" * RULE_WIDTH)
    logger.info("SAM3 METRICS EVALUATION")
    logger.info("=" * RULE_WIDTH)

    if args.skip_inference and args.force_inference:
        # Skip wins (inference is never run); make the ignored flag visible.
        logger.warning("--force-inference is ignored because --skip-inference is set")

    try:
        # Phase 1: Extract from CVAT
        _log_phase("PHASE 1: CVAT Data Extraction")

        extractor = CVATExtractor(config)
        if args.force_download:
            # NOTE(review): the flag is only logged here and is not passed to
            # run_extraction(). Confirm whether CVATExtractor honours a forced
            # re-download by some other route; otherwise --force-download has
            # no effect.
            logger.info("Force download enabled - will re-download all images")

        image_paths = extractor.run_extraction()
        total_extracted = sum(len(paths) for paths in image_paths.values())
        logger.info(f"Extraction complete: {total_extracted} images extracted")

        if total_extracted == 0:
            logger.error("No images extracted. Aborting.")
            return 1

        # Phase 2: Run SAM3 Inference
        if args.skip_inference:
            logger.info("Skipping inference (--skip-inference)")
        else:
            _log_phase("PHASE 2: SAM3 Inference")

            inferencer = SAM3Inferencer(config)
            stats = inferencer.run_inference_batch(image_paths, args.force_inference)
            logger.info(
                f"Inference complete: {stats['successful']} successful, "
                f"{stats['failed']} failed, {stats['skipped']} skipped"
            )

            # Cached (skipped) results still count as usable output.
            if stats['successful'] == 0 and stats['skipped'] == 0:
                logger.error("No successful inferences. Aborting.")
                return 1

        # Phase 3: Calculate Metrics
        _log_phase("PHASE 3: Metrics Calculation")

        calculator = MetricsCalculator(config)
        metrics = calculator.run_evaluation(cache_dir)

        # Save detailed metrics
        metrics_json_path = cache_dir / "metrics_detailed.json"
        with open(metrics_json_path, "w", encoding="utf-8") as f:
            json.dump(metrics, f, indent=2)
        logger.info(f"Saved detailed metrics to {metrics_json_path}")

        # Write summary
        metrics_summary_path = cache_dir / "metrics_summary.txt"
        write_metrics_summary(metrics, metrics_summary_path)

        # Phase 4: Visualization (optional)
        if args.visualize or config.output.generate_visualizations:
            _log_phase("PHASE 4: Visual Comparisons")

            comparator = VisualComparator()
            comparison_paths = comparator.generate_all_comparisons(cache_dir)
            logger.info(f"Generated {len(comparison_paths)} visual comparisons")

        # Summary
        _log_phase("EVALUATION COMPLETE")

        aggregate = metrics["aggregate"]
        logger.info(f"Images evaluated: {aggregate['total_images']}")
        _log_headline_metrics(aggregate)

        logger.info("\nResults saved to:")
        logger.info(f" Metrics Summary: {metrics_summary_path}")
        logger.info(f" Detailed JSON: {metrics_json_path}")
        logger.info(f" Execution Log: {log_file}")

        return 0

    except KeyboardInterrupt:
        logger.warning("\nEvaluation interrupted by user")
        return 130
    except Exception as e:
        # logger.exception == logger.error(..., exc_info=True)
        logger.exception(f"\nEvaluation failed with error: {e}")
        return 1


if __name__ == "__main__":
    sys.exit(main())