|
|
|
|
|
"""Main execution script for SAM3 metrics evaluation.""" |
|
|
|
|
|
import argparse |
|
|
import json |
|
|
import logging |
|
|
import sys |
|
|
from pathlib import Path |
|
|
|
|
|
from metrics_evaluation.config.config_loader import load_config |
|
|
from metrics_evaluation.extraction.cvat_extractor import CVATExtractor |
|
|
from metrics_evaluation.inference.sam3_inference import SAM3Inferencer |
|
|
from metrics_evaluation.metrics.metrics_calculator import MetricsCalculator |
|
|
from metrics_evaluation.utils.logging_config import setup_logging |
|
|
from metrics_evaluation.visualization.visual_comparison import VisualComparator |
|
|
|
|
|
# Module-level logger; handlers/levels are configured later via setup_logging().
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
def _write_overall_metrics(f, overall: dict) -> None:
    """Write the TP/FP/FN counts and summary scores for one IoU threshold."""
    f.write("Overall Metrics:\n")
    f.write(f"  True Positives: {overall['true_positives']}\n")
    f.write(f"  False Positives: {overall['false_positives']}\n")
    f.write(f"  False Negatives: {overall['false_negatives']}\n")
    f.write(f"  Precision: {overall['precision']:.2%}\n")
    f.write(f"  Recall: {overall['recall']:.2%}\n")
    f.write(f"  F1-Score: {overall['f1']:.2%}\n")
    f.write(f"  mAP: {overall['map']:.2%}\n")
    f.write(f"  mAR: {overall['mar']:.2%}\n\n")


def _write_per_class_table(f, by_label: dict) -> None:
    """Write the fixed-width per-class metrics table, sorted by label name."""
    f.write("Per-Class Metrics:\n")
    f.write("-" * 80 + "\n")
    f.write(f"{'Class':<20} {'GT':>6} {'Pred':>6} {'TP':>6} {'FP':>6} {'FN':>6} {'Prec':>8} {'Rec':>8} {'F1':>8}\n")
    f.write("-" * 80 + "\n")

    for label, stats in sorted(by_label.items()):
        f.write(
            f"{label:<20} "
            f"{stats['gt_total']:>6} "
            f"{stats['pred_total']:>6} "
            f"{stats['tp']:>6} "
            f"{stats['fp']:>6} "
            f"{stats['fn']:>6} "
            f"{stats['precision']:>8.2%} "
            f"{stats['recall']:>8.2%} "
            f"{stats['f1']:>8.2%}\n"
        )

    f.write("\n")


def _write_confusion_matrix(f, cm: dict) -> None:
    """Write the actual-vs-predicted confusion matrix; no-op when no labels exist."""
    labels = cm["labels"]
    matrix = cm["matrix"]

    if not labels:
        return

    f.write("Confusion Matrix:\n")
    f.write("-" * 80 + "\n")

    # Column headers are truncated to 10 chars to keep the table aligned.
    header = "Actual \\ Pred |"
    for label in labels:
        header += f" {label[:10]:>10} |"
    f.write(header + "\n")
    f.write("-" * len(header) + "\n")

    # Row labels get 13 chars; cells mirror the 10-char header columns.
    for i, actual_label in enumerate(labels):
        row = f"{actual_label[:13]:>13} |"
        for j in range(len(labels)):
            row += f" {matrix[i][j]:>10} |"
        f.write(row + "\n")

    f.write("\n")


def write_metrics_summary(metrics: dict, output_path: Path) -> None:
    """Write human-readable metrics summary.

    Renders the aggregate evaluation results as a fixed-width text report:
    one section per configured IoU threshold, each containing the overall
    metrics, a per-class table, and a confusion matrix.

    Args:
        metrics: Metrics dictionary; must contain an "aggregate" entry with
            "total_images" and a "by_threshold" mapping
        output_path: Path to output file
    """
    aggregate = metrics["aggregate"]

    # Explicit UTF-8 so the report is stable across platforms (label names
    # may contain non-ASCII characters).
    with open(output_path, "w", encoding="utf-8") as f:
        f.write("=" * 80 + "\n")
        f.write("SAM3 EVALUATION METRICS SUMMARY\n")
        f.write("=" * 80 + "\n\n")

        f.write(f"Total Images Evaluated: {aggregate['total_images']}\n\n")

        # Iterate values directly: the string key duplicates
        # threshold_data["iou_threshold"] and was never used.
        for threshold_data in aggregate["by_threshold"].values():
            iou = threshold_data["iou_threshold"]
            f.write(f"\n{'='*80}\n")
            f.write(f"IoU Threshold: {iou:.0%}\n")
            f.write(f"{'='*80}\n\n")

            _write_overall_metrics(f, threshold_data["overall"])
            _write_per_class_table(f, threshold_data["by_label"])
            _write_confusion_matrix(f, threshold_data["confusion_matrix"])

        f.write("=" * 80 + "\n")
        f.write("END OF REPORT\n")
        f.write("=" * 80 + "\n")

    logger.info(f"Wrote metrics summary to {output_path}")
|
|
|
|
|
|
|
|
def main() -> int:
    """Main execution function.

    Orchestrates the evaluation pipeline: CVAT data extraction, SAM3
    inference, metrics calculation, and optional visual comparisons.

    Returns:
        Exit code (0 for success, non-zero for failure; 130 on Ctrl-C)
    """
    parser = argparse.ArgumentParser(
        description="Run SAM3 metrics evaluation against CVAT ground truth"
    )
    parser.add_argument(
        "--config",
        type=str,
        default="config/config.json",
        help="Path to configuration file"
    )
    parser.add_argument(
        "--force-download",
        action="store_true",
        help="Force re-download images from CVAT"
    )
    parser.add_argument(
        "--force-inference",
        action="store_true",
        help="Force re-run SAM3 inference"
    )
    parser.add_argument(
        "--skip-inference",
        action="store_true",
        help="Skip inference, use cached results"
    )
    parser.add_argument(
        "--visualize",
        action="store_true",
        help="Generate visual comparisons"
    )
    parser.add_argument(
        "--log-level",
        type=str,
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        help="Logging level"
    )

    args = parser.parse_args()

    # Config is loaded before logging is configured, so failures here must
    # go straight to stderr.
    try:
        config = load_config(args.config)
    except Exception as e:
        print(f"ERROR: Failed to load configuration: {e}", file=sys.stderr)
        return 1

    cache_dir = config.get_cache_path()
    log_file = cache_dir / "evaluation_log.txt"
    setup_logging(log_file, getattr(logging, args.log_level))

    logger.info("=" * 80)
    logger.info("SAM3 METRICS EVALUATION")
    logger.info("=" * 80)

    try:
        # --- Phase 1: pull images and ground-truth annotations from CVAT ---
        logger.info("\n" + "=" * 80)
        logger.info("PHASE 1: CVAT Data Extraction")
        logger.info("=" * 80)

        extractor = CVATExtractor(config)

        # NOTE(review): the flag is only logged here; verify that
        # run_extraction() actually honors forced re-download (via config?).
        if args.force_download:
            logger.info("Force download enabled - will re-download all images")

        image_paths = extractor.run_extraction()

        total_extracted = sum(len(paths) for paths in image_paths.values())
        logger.info(f"Extraction complete: {total_extracted} images extracted")

        if total_extracted == 0:
            logger.error("No images extracted. Aborting.")
            return 1

        # --- Phase 2: run SAM3 over the extracted images (unless skipped) ---
        if not args.skip_inference:
            logger.info("\n" + "=" * 80)
            logger.info("PHASE 2: SAM3 Inference")
            logger.info("=" * 80)

            inferencer = SAM3Inferencer(config)
            stats = inferencer.run_inference_batch(image_paths, args.force_inference)

            logger.info(
                f"Inference complete: {stats['successful']} successful, "
                f"{stats['failed']} failed, {stats['skipped']} skipped"
            )

            # "skipped" means a cached result exists, so it still counts as usable.
            if stats['successful'] == 0 and stats['skipped'] == 0:
                logger.error("No successful inferences. Aborting.")
                return 1
        else:
            logger.info("Skipping inference (--skip-inference)")

        # --- Phase 3: compare predictions against ground truth ---
        logger.info("\n" + "=" * 80)
        logger.info("PHASE 3: Metrics Calculation")
        logger.info("=" * 80)

        calculator = MetricsCalculator(config)
        metrics = calculator.run_evaluation(cache_dir)

        # Persist both machine-readable (JSON) and human-readable outputs.
        metrics_json_path = cache_dir / "metrics_detailed.json"
        with open(metrics_json_path, "w", encoding="utf-8") as f:
            json.dump(metrics, f, indent=2)
        logger.info(f"Saved detailed metrics to {metrics_json_path}")

        metrics_summary_path = cache_dir / "metrics_summary.txt"
        write_metrics_summary(metrics, metrics_summary_path)

        # --- Phase 4 (optional): side-by-side visual comparisons ---
        if args.visualize or config.output.generate_visualizations:
            logger.info("\n" + "=" * 80)
            logger.info("PHASE 4: Visual Comparisons")
            logger.info("=" * 80)

            comparator = VisualComparator()
            comparison_paths = comparator.generate_all_comparisons(cache_dir)
            logger.info(f"Generated {len(comparison_paths)} visual comparisons")

        logger.info("\n" + "=" * 80)
        logger.info("EVALUATION COMPLETE")
        logger.info("=" * 80)

        aggregate = metrics["aggregate"]
        logger.info(f"Images evaluated: {aggregate['total_images']}")

        # Prefer the conventional 50% IoU threshold for the console summary,
        # but fall back to the first configured threshold so a non-default
        # threshold list no longer raises KeyError (previously swallowed by
        # the broad except below, turning a successful run into exit code 1).
        by_threshold = aggregate["by_threshold"]
        threshold_data = by_threshold.get("0.5")
        if threshold_data is None and by_threshold:
            threshold_data = next(iter(by_threshold.values()))

        if threshold_data is not None:
            overall = threshold_data["overall"]
            logger.info(f"\nMetrics at {threshold_data['iou_threshold']:.0%} IoU:")
            logger.info(f"  Precision: {overall['precision']:.2%}")
            logger.info(f"  Recall: {overall['recall']:.2%}")
            logger.info(f"  F1-Score: {overall['f1']:.2%}")
            logger.info(f"  mAP: {overall['map']:.2%}")
            logger.info(f"  mAR: {overall['mar']:.2%}")

        logger.info("\nResults saved to:")
        logger.info(f"  Metrics Summary: {metrics_summary_path}")
        logger.info(f"  Detailed JSON: {metrics_json_path}")
        logger.info(f"  Execution Log: {log_file}")

        return 0

    except KeyboardInterrupt:
        logger.warning("\nEvaluation interrupted by user")
        # 130 = conventional exit code for termination by SIGINT (128 + 2).
        return 130

    except Exception as e:
        logger.error(f"\nEvaluation failed with error: {e}", exc_info=True)
        return 1
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate main()'s exit code to the shell.
    raise SystemExit(main())
|
|
|