"""
Evaluation script for Pest and Disease Classification

Generate confusion matrix, classification report, and per-class metrics
"""

import argparse
import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    f1_score,
    precision_recall_fscore_support,
)

from dataset import get_dataloaders
from model import create_model


def _class_index_labels(y_true, y_pred, class_names):
    """
    Return the label ids 0..len(class_names)-1 when the data uses such ids

    Passing an explicit label list to sklearn keeps every matrix row and
    metric entry aligned with `class_names`, even when some class does not
    occur in the evaluated split. When the observed labels are not integer
    indices into `class_names` (or there are no samples), None is returned so
    sklearn infers the labels from the data.

    Args:
        y_true: True labels
        y_pred: Predicted labels
        class_names: List of class names

    Returns:
        Array of label ids, or None
    """
    observed = np.union1d(np.asarray(y_true).ravel(), np.asarray(y_pred).ravel())
    num_classes = len(class_names)
    if (
        observed.size
        and np.issubdtype(observed.dtype, np.integer)
        and observed.min() >= 0
        and observed.max() < num_classes
    ):
        return np.arange(num_classes)
    return None


def evaluate_model(model, dataloader, device, dataset):
    """
    Evaluate model on a dataset

    Args:
        model: Model to evaluate (switched to eval mode)
        dataloader: Iterable yielding (inputs, labels) batches
        device: Device the input batches are moved to
        dataset: Unused; kept so existing callers keep working

    Returns:
        predictions: Array of predicted labels
        true_labels: Array of true labels
        accuracy: Overall accuracy (0.0 when the dataloader yields no samples)
    """
    model.eval()

    pred_batches = []
    label_batches = []

    with torch.no_grad():
        for inputs, labels in dataloader:
            outputs = model(inputs.to(device))
            _, preds = torch.max(outputs, 1)

            # Labels are only compared on the host, so they never need to
            # visit the device.
            pred_batches.append(preds.cpu().numpy())
            label_batches.append(labels.cpu().numpy())

    if not pred_batches:
        # Nothing was evaluated: avoid the NaN (and RuntimeWarning) that
        # np.mean produces for an empty array.
        empty = np.empty(0, dtype=np.int64)
        return empty, empty.copy(), 0.0

    all_preds = np.concatenate(pred_batches)
    all_labels = np.concatenate(label_batches)

    accuracy = float(np.mean(all_preds == all_labels))

    return all_preds, all_labels, accuracy


def _save_heatmap(matrix, fmt, class_names, title, cbar_label, save_path):
    """Draw one annotated heatmap of `matrix` and write it to `save_path`."""
    fig = plt.figure(figsize=(12, 10))
    try:
        sns.heatmap(matrix, annot=True, fmt=fmt, cmap='Blues',
                    xticklabels=class_names, yticklabels=class_names,
                    cbar_kws={'label': cbar_label})

        plt.title(title, fontsize=16, pad=20)
        plt.ylabel('True Label', fontsize=12)
        plt.xlabel('Predicted Label', fontsize=12)
        plt.xticks(rotation=45, ha='right')
        plt.yticks(rotation=0)
        plt.tight_layout()

        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        # Close only the figure created here, not figures owned by the caller.
        plt.close(fig)


def plot_confusion_matrix(y_true, y_pred, class_names, save_path='confusion_matrix.png'):
    """
    Plot and save confusion matrix

    Two figures are written: the raw counts at `save_path`, and a row-wise
    percentage version next to it with `_percent` appended to the file stem.

    Args:
        y_true: True labels
        y_pred: Predicted labels
        class_names: List of class names
        save_path: Path to save figure

    Returns:
        The confusion matrix of counts
    """
    cm = confusion_matrix(
        y_true, y_pred,
        labels=_class_index_labels(y_true, y_pred, class_names)
    )

    # Calculate percentages; rows for classes without any true sample stay at
    # 0 instead of turning into NaN through a division by zero.
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_percent = np.divide(
        cm.astype('float'), row_totals,
        out=np.zeros(cm.shape, dtype=float),
        where=row_totals != 0
    ) * 100

    _save_heatmap(cm, 'd', class_names, 'Confusion Matrix', 'Count', save_path)
    print(f"Confusion matrix saved to {save_path}")

    # Also save percentage version. The name is derived from the file stem so
    # it differs from save_path for any extension and never rewrites a
    # directory component that happens to contain '.png'.
    target = Path(save_path)
    save_path_percent = target.with_name(f"{target.stem}_percent{target.suffix}")
    _save_heatmap(cm_percent, '.1f', class_names, 'Confusion Matrix (Percentage)',
                  'Percentage (%)', save_path_percent)
    print(f"Confusion matrix (percentage) saved to {save_path_percent}")

    return cm


def generate_classification_report(y_true, y_pred, class_names, save_path='classification_report.txt'):
    """
    Generate and save detailed classification report

    The text report is written to `save_path`; the same metrics are written
    as JSON next to it (same file stem, `.json` suffix).

    Args:
        y_true: True labels
        y_pred: Predicted labels
        class_names: List of class names
        save_path: Path to save report

    Returns:
        Dictionary mapping each class name to its precision, recall, f1-score
        and support, plus an 'overall' entry with accuracy, macro F1 and
        weighted F1
    """
    labels = _class_index_labels(y_true, y_pred, class_names)

    # Generate report. zero_division=0 yields the same values as the sklearn
    # default but without a warning for classes that are never predicted.
    report = classification_report(
        y_true, y_pred,
        labels=labels,
        target_names=class_names,
        digits=4,
        zero_division=0
    )

    # Print to console
    print("\n" + "=" * 80)
    print("Classification Report")
    print("=" * 80)
    print(report)

    # Save to file
    with open(save_path, 'w', encoding='utf-8') as f:
        f.write("Classification Report\n")
        f.write("=" * 80 + "\n")
        f.write(report)

    print(f"\nClassification report saved to {save_path}")

    # Calculate per-class metrics
    precision, recall, f1, support = precision_recall_fscore_support(
        y_true, y_pred, labels=labels, average=None, zero_division=0
    )

    # Create detailed metrics dictionary
    metrics = {
        class_name: {
            'precision': float(p),
            'recall': float(r),
            'f1-score': float(f),
            'support': int(s)
        }
        for class_name, p, r, f, s in zip(class_names, precision, recall, f1, support)
    }

    # Add overall metrics
    # NOTE: a class literally named 'overall' would be replaced by this entry.
    metrics['overall'] = {
        'accuracy': float(np.mean(np.asarray(y_true) == np.asarray(y_pred))),
        'macro_avg_f1': float(np.mean(f1)),
        'weighted_avg_f1': float(
            f1_score(y_true, y_pred, labels=labels, average='weighted', zero_division=0)
        )
    }

    # Save metrics as JSON. Swap the suffix instead of doing a text replace so
    # the JSON can never land on top of the text report written above.
    report_path = Path(save_path)
    metrics_path = report_path.with_suffix('.json')
    if metrics_path == report_path:
        metrics_path = report_path.with_name(f"{report_path.stem}_metrics.json")

    with open(metrics_path, 'w', encoding='utf-8') as f:
        json.dump(metrics, f, indent=2, ensure_ascii=False)

    print(f"Metrics JSON saved to {metrics_path}")

    return metrics


def plot_per_class_metrics(metrics, class_names, save_path='per_class_metrics.png'):
    """
    Plot per-class precision, recall, and F1-score

    Args:
        metrics: Dictionary of metrics, keyed by class name, each entry holding
            'precision', 'recall' and 'f1-score'
        class_names: List of class names
        save_path: Path to save figure
    """
    precision = [metrics[name]['precision'] for name in class_names]
    recall = [metrics[name]['recall'] for name in class_names]
    f1 = [metrics[name]['f1-score'] for name in class_names]

    x = np.arange(len(class_names))
    width = 0.25

    fig, ax = plt.subplots(figsize=(14, 6))
    try:
        # Three bars per class, centred on the class tick.
        ax.bar(x - width, precision, width, label='Precision', alpha=0.8)
        ax.bar(x, recall, width, label='Recall', alpha=0.8)
        ax.bar(x + width, f1, width, label='F1-Score', alpha=0.8)

        ax.set_xlabel('Class', fontsize=12)
        ax.set_ylabel('Score', fontsize=12)
        ax.set_title('Per-Class Metrics', fontsize=14, pad=20)
        ax.set_xticks(x)
        ax.set_xticklabels(class_names, rotation=45, ha='right')
        ax.legend()
        ax.grid(axis='y', alpha=0.3)
        ax.set_ylim([0, 1.1])

        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)

    print(f"Per-class metrics plot saved to {save_path}")


def main(args):
    """
    Main evaluation function

    Loads the data and the checkpoint, evaluates the requested split and
    writes the confusion matrices, the classification report (text and JSON)
    and the per-class metrics plot into `args.output_dir`.

    Args:
        args: Parsed command line arguments (see `_build_arg_parser`)
    """
    print("Pest and Disease Classification Evaluation")
    print("=" * 80)
    print("Configuration:")
    print(f"  Checkpoint: {args.checkpoint}")
    print(f"  Split: {args.split}")
    print(f"  Batch size: {args.batch_size}")
    print(f"  Device: {args.device}")
    print("=" * 80)

    # Set device
    cuda_available = torch.cuda.is_available()
    if args.device == 'cuda' and not cuda_available:
        print("\nCUDA requested but not available; falling back to CPU")
    device = torch.device(args.device if cuda_available else 'cpu')
    print(f"\nUsing device: {device}")

    # Load data
    print("\nLoading datasets...")
    loaders = get_dataloaders(
        csv_file=args.csv_file,
        label_mapping_file=args.label_mapping,
        batch_size=args.batch_size,
        img_size=args.img_size,
        num_workers=args.num_workers
    )

    # Get class names
    dataset = loaders['datasets'][args.split]
    class_names = [dataset.get_label_name(i) for i in range(dataset.num_classes)]
    print(f"Classes: {class_names}")

    # Create model
    print(f"\nCreating model: {args.backbone}")
    model = create_model(
        num_classes=loaders['num_classes'],
        backbone=args.backbone,
        pretrained=False
    )

    # Load checkpoint
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    print(f"\nLoading checkpoint: {args.checkpoint}")
    checkpoint = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)

    if 'val_acc' in checkpoint:
        print(f"Checkpoint validation accuracy: {checkpoint['val_acc']:.4f}")

    # Evaluate
    print(f"\nEvaluating on {args.split} set...")
    dataloader = loaders[args.split]
    predictions, true_labels, accuracy = evaluate_model(model, dataloader, device, dataset)

    print(f"\n{args.split.capitalize()} Set Accuracy: {accuracy:.4f}")

    # Create output directory (including missing parents, so a nested
    # --output_dir works)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Generate confusion matrix
    print("\nGenerating confusion matrix...")
    plot_confusion_matrix(
        true_labels, predictions, class_names,
        save_path=output_dir / f'confusion_matrix_{args.split}.png'
    )

    # Generate classification report
    print("\nGenerating classification report...")
    metrics = generate_classification_report(
        true_labels, predictions, class_names,
        save_path=output_dir / f'classification_report_{args.split}.txt'
    )

    # Plot per-class metrics
    print("\nGenerating per-class metrics plot...")
    plot_per_class_metrics(
        metrics, class_names,
        save_path=output_dir / f'per_class_metrics_{args.split}.png'
    )

    print("\n" + "=" * 80)
    print("Evaluation complete!")
    print(f"Results saved to {output_dir}/")
    print("=" * 80)


def _build_arg_parser():
    """Build the command line parser for the evaluation script."""
    parser = argparse.ArgumentParser(description='Evaluate Pest and Disease Classifier')

    # Data parameters
    parser.add_argument('--csv_file', type=str, default='dataset.csv',
                        help='Path to dataset CSV')
    parser.add_argument('--label_mapping', type=str, default='label_mapping.json',
                        help='Path to label mapping JSON')

    # Model parameters
    parser.add_argument('--checkpoint', type=str, default='checkpoints/best_model.pth',
                        help='Path to model checkpoint')
    parser.add_argument('--backbone', type=str, default='resnet50',
                        choices=['resnet50', 'resnet101', 'efficientnet_b0',
                                 'efficientnet_b3', 'mobilenet_v2'],
                        help='Model backbone')

    # Evaluation parameters
    parser.add_argument('--split', type=str, default='test',
                        choices=['train', 'val', 'test'],
                        help='Dataset split to evaluate')
    parser.add_argument('--batch_size', type=int, default=16,
                        help='Batch size')
    parser.add_argument('--img_size', type=int, default=224,
                        help='Image size')

    # System parameters
    parser.add_argument('--device', type=str, default='cuda',
                        choices=['cuda', 'cpu'],
                        help='Device to use')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='Number of data loading workers')
    parser.add_argument('--output_dir', type=str, default='evaluation_results',
                        help='Directory to save results')

    return parser


if __name__ == "__main__":
    main(_build_arg_parser().parse_args())