"""Main evaluation pipeline for bean detection models."""
import json
import time
from pathlib import Path
from typing import Dict, List, Optional, Union
import numpy as np
import torch
from torch.utils.data import DataLoader
from .metrics import BeanMetrics, DetectionMetrics, SegmentationMetrics, SizeBasedMetrics


class BeanEvaluator:
    """Main evaluator for bean detection and segmentation models."""

    def __init__(self, model: torch.nn.Module, device: torch.device,
                 confidence_threshold: float = 0.5,
                 iou_thresholds: Optional[List[float]] = None):
        """Initialize the evaluator.

        Args:
            model: PyTorch model to evaluate.
            device: Device to run evaluation on.
            confidence_threshold: Minimum confidence for a valid detection.
            iou_thresholds: IoU thresholds for mAP calculation.
        """
        self.model = model
        self.device = device
        self.confidence_threshold = confidence_threshold
        self.metrics_calculator = BeanMetrics(
            confidence_threshold=confidence_threshold,
            iou_thresholds=iou_thresholds
        )

    def evaluate_detection(self, dataloader: DataLoader,
                           save_predictions: bool = False,
                           output_dir: Optional[Path] = None) -> DetectionMetrics:
        """Evaluate detection performance on a dataset.

        Args:
            dataloader: DataLoader for the evaluation dataset.
            save_predictions: Whether to save predictions to file.
            output_dir: Directory to save results.

        Returns:
            DetectionMetrics with computed metrics.
        """
        self.model.eval()
        all_predictions = []
        all_ground_truths = []

        print(f"Evaluating detection on {len(dataloader)} batches...")

        with torch.no_grad():
            for batch_idx, (images, targets) in enumerate(dataloader):
                print(f"Processing batch {batch_idx + 1}/{len(dataloader)}")

                # Move images to the evaluation device
                images = [img.to(self.device) for img in images]

                # Get predictions
                outputs = self.model(images)

                # Process each image in the batch
                for output, target in zip(outputs, targets):
                    pred_dict = {
                        'boxes': output['boxes'].cpu(),
                        'scores': output['scores'].cpu(),
                        'labels': output['labels'].cpu()
                    }
                    gt_dict = {
                        'boxes': target['boxes'],
                        'labels': target['labels']
                    }
                    all_predictions.append(pred_dict)
                    all_ground_truths.append(gt_dict)

        # Compute detection metrics
        print("Computing detection metrics...")
        detection_metrics = self.metrics_calculator.compute_detection_metrics(
            all_predictions, all_ground_truths
        )

        # Save predictions if requested
        if save_predictions and output_dir:
            self._save_predictions(all_predictions, all_ground_truths, output_dir)

        return detection_metrics

    def evaluate_segmentation(self, dataloader: DataLoader,
                              output_dir: Optional[Path] = None) -> SegmentationMetrics:
        """Evaluate segmentation performance on a dataset.

        Args:
            dataloader: DataLoader for the evaluation dataset.
            output_dir: Directory to save results.

        Returns:
            SegmentationMetrics with computed metrics.
        """
        self.model.eval()
        all_predictions = []
        all_ground_truths = []

        print(f"Evaluating segmentation on {len(dataloader)} batches...")

        with torch.no_grad():
            for batch_idx, (images, targets) in enumerate(dataloader):
                print(f"Processing batch {batch_idx + 1}/{len(dataloader)}")

                images = [img.to(self.device) for img in images]
                outputs = self.model(images)

                for output, target in zip(outputs, targets):
                    pred_masks = output.get('masks', torch.tensor([]))
                    gt_masks = target.get('masks', torch.tensor([]))

                    # Binarize predicted soft masks
                    if len(pred_masks) > 0:
                        pred_masks = (pred_masks > 0.5).float()

                    pred_dict = {'masks': pred_masks.cpu()}
                    gt_dict = {'masks': gt_masks}
                    all_predictions.append(pred_dict)
                    all_ground_truths.append(gt_dict)

        # Compute segmentation metrics
        print("Computing segmentation metrics...")
        seg_metrics = self.metrics_calculator.compute_segmentation_metrics(
            all_predictions, all_ground_truths
        )

        return seg_metrics

    def evaluate_by_size(self, dataloader: DataLoader) -> SizeBasedMetrics:
        """Evaluate performance by object size categories.

        Args:
            dataloader: DataLoader for the evaluation dataset.

        Returns:
            SizeBasedMetrics with size-specific performance.
        """
        self.model.eval()
        all_predictions = []
        all_ground_truths = []

        print("Evaluating performance by object size...")

        with torch.no_grad():
            for images, targets in dataloader:
                images = [img.to(self.device) for img in images]
                outputs = self.model(images)

                for output, target in zip(outputs, targets):
                    pred_dict = {
                        'boxes': output['boxes'].cpu(),
                        'scores': output['scores'].cpu(),
                        'labels': output['labels'].cpu()
                    }
                    gt_dict = {
                        'boxes': target['boxes'],
                        'labels': target['labels']
                    }
                    all_predictions.append(pred_dict)
                    all_ground_truths.append(gt_dict)

        size_metrics = self.metrics_calculator.compute_size_based_metrics(
            all_predictions, all_ground_truths
        )

        return size_metrics

    def full_evaluation(self, dataloader: DataLoader,
                        output_dir: Optional[Path] = None,
                        save_predictions: bool = True) -> Dict[str, Union[DetectionMetrics, SegmentationMetrics, SizeBasedMetrics, float]]:
        """Run the complete evaluation suite.

        Args:
            dataloader: DataLoader for the evaluation dataset.
            output_dir: Directory to save results.
            save_predictions: Whether to save predictions.

        Returns:
            Dictionary with all computed metrics and the total evaluation time.
        """
        if output_dir:
            output_dir.mkdir(parents=True, exist_ok=True)

        print("Starting full evaluation...")
        start_time = time.time()

        # Detection evaluation
        detection_metrics = self.evaluate_detection(
            dataloader, save_predictions=save_predictions, output_dir=output_dir
        )

        # Segmentation evaluation
        segmentation_metrics = self.evaluate_segmentation(dataloader, output_dir)

        # Size-based evaluation
        size_metrics = self.evaluate_by_size(dataloader)

        evaluation_time = time.time() - start_time

        results = {
            'detection': detection_metrics,
            'segmentation': segmentation_metrics,
            'size_based': size_metrics,
            'evaluation_time': evaluation_time
        }

        # Save results to file
        if output_dir:
            self._save_results(results, output_dir)

        print(f"Evaluation completed in {evaluation_time:.2f} seconds")
        return results

    def compare_models(self, model_paths: List[Path], dataloader: DataLoader,
                       output_dir: Path) -> Dict[str, Dict]:
        """Compare multiple models on the same dataset.

        Args:
            model_paths: List of paths to model checkpoints.
            dataloader: DataLoader for the evaluation dataset.
            output_dir: Directory to save comparison results.

        Returns:
            Dictionary with results for each model.
        """
        output_dir.mkdir(parents=True, exist_ok=True)
        comparison_results = {}

        # Keep a copy of the current weights so they can be restored afterwards;
        # loading each checkpoint mutates self.model in place.
        original_state = copy.deepcopy(self.model.state_dict())

        for model_path in model_paths:
            print(f"Evaluating model: {model_path.name}")

            # Load checkpoint (either a full training checkpoint or a bare state dict)
            checkpoint = torch.load(model_path, map_location=self.device)
            if 'model_state_dict' in checkpoint:
                self.model.load_state_dict(checkpoint['model_state_dict'])
            else:
                self.model.load_state_dict(checkpoint)

            # Run evaluation
            model_results = self.full_evaluation(
                dataloader,
                output_dir=output_dir / model_path.stem,
                save_predictions=False
            )
            comparison_results[model_path.stem] = model_results

        # Restore the original weights
        self.model.load_state_dict(original_state)

        # Save comparison results
        self._save_model_comparison(comparison_results, output_dir)

        return comparison_results

    def confidence_analysis(self, dataloader: DataLoader,
                            thresholds: Optional[List[float]] = None) -> Dict[str, List[float]]:
        """Analyze performance across different confidence thresholds.

        Args:
            dataloader: DataLoader for the evaluation dataset.
            thresholds: List of confidence thresholds to test.

        Returns:
            Dictionary with metrics at each threshold.
        """
        if thresholds is None:
            thresholds = np.arange(0.1, 1.0, 0.1).tolist()

        print("Analyzing performance across confidence thresholds...")

        # Collect all predictions once, then re-score them at each threshold
        all_predictions = []
        all_ground_truths = []

        self.model.eval()
        with torch.no_grad():
            for images, targets in dataloader:
                images = [img.to(self.device) for img in images]
                outputs = self.model(images)

                for output, target in zip(outputs, targets):
                    pred_dict = {
                        'boxes': output['boxes'].cpu(),
                        'scores': output['scores'].cpu(),
                        'labels': output['labels'].cpu()
                    }
                    gt_dict = {
                        'boxes': target['boxes'],
                        'labels': target['labels']
                    }
                    all_predictions.append(pred_dict)
                    all_ground_truths.append(gt_dict)

        # Evaluate at each threshold
        results = {
            'thresholds': thresholds,
            'precision': [],
            'recall': [],
            'f1_score': [],
            'ap_50': []
        }

        original_threshold = self.metrics_calculator.confidence_threshold
        for threshold in thresholds:
            print(f"Evaluating at threshold {threshold:.2f}")
            self.metrics_calculator.confidence_threshold = threshold

            metrics = self.metrics_calculator.compute_detection_metrics(
                all_predictions, all_ground_truths
            )

            results['precision'].append(metrics.precision)
            results['recall'].append(metrics.recall)
            results['f1_score'].append(metrics.f1_score)
            results['ap_50'].append(metrics.ap_50)

        # Restore the original threshold
        self.metrics_calculator.confidence_threshold = original_threshold

        return results
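
    # Illustrative use of the dict returned by confidence_analysis (assumes an
    # `evaluator` instance and an `eval_loader` DataLoader exist; both names are
    # placeholders, not defined in this module):
    #   res = evaluator.confidence_analysis(eval_loader)
    #   best_threshold = res['thresholds'][int(np.argmax(res['f1_score']))]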

    def _save_predictions(self, predictions: List[Dict], ground_truths: List[Dict],
                          output_dir: Path) -> None:
        """Save predictions and ground truths to a JSON file."""
        output_file = output_dir / 'predictions.json'

        # Convert tensors to lists for JSON serialization
        serializable_preds = []
        for pred in predictions:
            pred_dict = {}
            for k, v in pred.items():
                if torch.is_tensor(v):
                    pred_dict[k] = v.tolist()
                else:
                    pred_dict[k] = v
            serializable_preds.append(pred_dict)

        serializable_gts = []
        for gt in ground_truths:
            gt_dict = {}
            for k, v in gt.items():
                if torch.is_tensor(v):
                    gt_dict[k] = v.tolist()
                else:
                    gt_dict[k] = v
            serializable_gts.append(gt_dict)

        with open(output_file, 'w') as f:
            json.dump({
                'predictions': serializable_preds,
                'ground_truths': serializable_gts
            }, f, indent=2)

        print(f"Predictions saved to {output_file}")

    def _save_results(self, results: Dict, output_dir: Path) -> None:
        """Save evaluation results to a JSON file."""
        output_file = output_dir / 'evaluation_results.json'

        # Convert metric dataclasses to plain dictionaries
        serializable_results = {}
        for key, value in results.items():
            if hasattr(value, '__dict__'):
                serializable_results[key] = value.__dict__
            else:
                serializable_results[key] = value

        with open(output_file, 'w') as f:
            json.dump(serializable_results, f, indent=2, default=str)

        print(f"Results saved to {output_file}")

    def _save_model_comparison(self, comparison_results: Dict, output_dir: Path) -> None:
        """Save model comparison results to a JSON file."""
        output_file = output_dir / 'model_comparison.json'

        # Convert metric dataclasses to plain dictionaries
        serializable_results = {}
        for model_name, results in comparison_results.items():
            serializable_results[model_name] = {}
            for key, value in results.items():
                if hasattr(value, '__dict__'):
                    serializable_results[model_name][key] = value.__dict__
                else:
                    serializable_results[model_name][key] = value

        with open(output_file, 'w') as f:
            json.dump(serializable_results, f, indent=2, default=str)

        print(f"Model comparison saved to {output_file}")

    def print_metrics_summary(self, metrics: Union[DetectionMetrics, Dict]) -> None:
        """Print a formatted summary of metrics."""
        if isinstance(metrics, dict):
            detection_metrics = metrics.get('detection')
            if detection_metrics:
                self._print_detection_metrics(detection_metrics)

            segmentation_metrics = metrics.get('segmentation')
            if segmentation_metrics:
                self._print_segmentation_metrics(segmentation_metrics)

            size_metrics = metrics.get('size_based')
            if size_metrics:
                self._print_size_metrics(size_metrics)
        elif isinstance(metrics, DetectionMetrics):
            self._print_detection_metrics(metrics)

    def _print_detection_metrics(self, metrics: DetectionMetrics) -> None:
        """Print detection metrics as a formatted table."""
        print("\n" + "=" * 60)
        print("DETECTION METRICS SUMMARY")
        print("=" * 60)
        print(f"{'Metric':<25} {'Value':<15}")
        print("-" * 40)
        print(f"{'mAP@0.5':<25} {metrics.ap_50:<15.4f}")
        print(f"{'mAP@0.75':<25} {metrics.ap_75:<15.4f}")
        print(f"{'mAP@0.5:0.95':<25} {metrics.ap_50_95:<15.4f}")
        print(f"{'Precision':<25} {metrics.precision:<15.4f}")
        print(f"{'Recall':<25} {metrics.recall:<15.4f}")
        print(f"{'F1 Score':<25} {metrics.f1_score:<15.4f}")
        print(f"{'Mean Confidence':<25} {metrics.confidence_mean:<15.4f}")
        print(f"{'Mean IoU':<25} {metrics.iou_mean:<15.4f}")
        print(f"{'Total Detections':<25} {metrics.total_detections:<15}")
        print(f"{'Valid Detections':<25} {metrics.valid_detections:<15}")
        print(f"{'False Positives':<25} {metrics.false_positives:<15}")
        print(f"{'False Negatives':<25} {metrics.false_negatives:<15}")

    def _print_segmentation_metrics(self, metrics: SegmentationMetrics) -> None:
        """Print segmentation metrics."""
        print(f"\nSegmentation AP@0.5: {metrics.mask_ap_50:.4f}")
        print(f"Dice Coefficient: {metrics.dice_coefficient:.4f}")
        print(f"Jaccard Index: {metrics.jaccard_index:.4f}")
        print(f"Pixel Accuracy: {metrics.pixel_accuracy:.4f}")

    def _print_size_metrics(self, metrics: SizeBasedMetrics) -> None:
        """Print size-based metrics."""
        print("\nSize-based Performance:")
        print(f"Small objects AP: {metrics.small_ap:.4f}")
        print(f"Medium objects AP: {metrics.medium_ap:.4f}")
        print(f"Large objects AP: {metrics.large_ap:.4f}")
        print(f"Size distribution: {metrics.size_distribution}")