""" Model evaluation script for crop disease detection """ import torch import torch.nn as nn import numpy as np import json from sklearn.metrics import classification_report, confusion_matrix, accuracy_score from sklearn.metrics import precision_recall_fscore_support import matplotlib.pyplot as plt import seaborn as sns from pathlib import Path from .dataset import create_data_loaders from .model import create_model, load_checkpoint class ModelEvaluator: """Evaluate trained model performance""" def __init__(self, model, test_loader, class_names, device='cpu'): self.model = model self.test_loader = test_loader self.class_names = class_names self.device = device def evaluate(self): """Evaluate model on test dataset""" self.model.eval() all_preds = [] all_labels = [] all_probs = [] with torch.no_grad(): for inputs, labels in self.test_loader: inputs = inputs.to(self.device) labels = labels.to(self.device) outputs = self.model(inputs) probs = torch.softmax(outputs, dim=1) _, preds = torch.max(outputs, 1) all_preds.extend(preds.cpu().numpy()) all_labels.extend(labels.cpu().numpy()) all_probs.extend(probs.cpu().numpy()) return np.array(all_preds), np.array(all_labels), np.array(all_probs) def calculate_metrics(self, y_true, y_pred, y_probs): """Calculate comprehensive evaluation metrics""" # Basic metrics accuracy = accuracy_score(y_true, y_pred) # Per-class metrics precision, recall, f1, support = precision_recall_fscore_support( y_true, y_pred, average=None, labels=range(len(self.class_names)) ) # Macro and weighted averages precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support( y_true, y_pred, average='macro' ) precision_weighted, recall_weighted, f1_weighted, _ = precision_recall_fscore_support( y_true, y_pred, average='weighted' ) # Confusion matrix cm = confusion_matrix(y_true, y_pred) # Classification report class_report = classification_report( y_true, y_pred, target_names=self.class_names, output_dict=True ) metrics = { 'accuracy': float(accuracy), 'precision_macro': float(precision_macro), 'recall_macro': float(recall_macro), 'f1_macro': float(f1_macro), 'precision_weighted': float(precision_weighted), 'recall_weighted': float(recall_weighted), 'f1_weighted': float(f1_weighted), 'per_class_metrics': { 'precision': precision.tolist(), 'recall': recall.tolist(), 'f1_score': f1.tolist(), 'support': support.tolist() }, 'confusion_matrix': cm.tolist(), 'classification_report': class_report } return metrics def plot_confusion_matrix(self, cm, save_path='outputs/confusion_matrix.png'): """Plot and save confusion matrix""" plt.figure(figsize=(12, 10)) # Normalize confusion matrix cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] # Create heatmap sns.heatmap( cm_normalized, annot=True, fmt='.2f', cmap='Blues', xticklabels=[name.replace('___', '\n') for name in self.class_names], yticklabels=[name.replace('___', '\n') for name in self.class_names], cbar_kws={'label': 'Normalized Frequency'} ) plt.title('Confusion Matrix (Normalized)', fontsize=16, pad=20) plt.xlabel('Predicted Label', fontsize=12) plt.ylabel('True Label', fontsize=12) plt.xticks(rotation=45, ha='right') plt.yticks(rotation=0) plt.tight_layout() # Save plot Path(save_path).parent.mkdir(parents=True, exist_ok=True) plt.savefig(save_path, dpi=300, bbox_inches='tight') plt.close() print(f"Confusion matrix saved to: {save_path}") def plot_per_class_metrics(self, metrics, save_path='outputs/per_class_metrics.png'): """Plot per-class performance metrics""" fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6)) class_names_short = [name.replace('___', '\n') for name in self.class_names] x_pos = np.arange(len(self.class_names)) # Precision ax1.bar(x_pos, metrics['per_class_metrics']['precision'], color='skyblue', alpha=0.7) ax1.set_title('Precision per Class') ax1.set_ylabel('Precision') ax1.set_xticks(x_pos) ax1.set_xticklabels(class_names_short, rotation=45, ha='right') ax1.set_ylim(0, 1) ax1.grid(True, alpha=0.3) # Recall ax2.bar(x_pos, metrics['per_class_metrics']['recall'], color='lightcoral', alpha=0.7) ax2.set_title('Recall per Class') ax2.set_ylabel('Recall') ax2.set_xticks(x_pos) ax2.set_xticklabels(class_names_short, rotation=45, ha='right') ax2.set_ylim(0, 1) ax2.grid(True, alpha=0.3) # F1-Score ax3.bar(x_pos, metrics['per_class_metrics']['f1_score'], color='lightgreen', alpha=0.7) ax3.set_title('F1-Score per Class') ax3.set_ylabel('F1-Score') ax3.set_xticks(x_pos) ax3.set_xticklabels(class_names_short, rotation=45, ha='right') ax3.set_ylim(0, 1) ax3.grid(True, alpha=0.3) plt.tight_layout() plt.savefig(save_path, dpi=300, bbox_inches='tight') plt.close() print(f"Per-class metrics plot saved to: {save_path}") def save_results(self, metrics, save_path='outputs/results.json'): """Save evaluation results to JSON file""" # Add class names to results results = { 'class_names': self.class_names, 'num_classes': len(self.class_names), 'test_samples': len(self.test_loader.dataset), 'metrics': metrics, 'model_info': { 'architecture': 'ResNet50', 'pretrained': True, 'transfer_learning': True } } # Save to file Path(save_path).parent.mkdir(parents=True, exist_ok=True) with open(save_path, 'w') as f: json.dump(results, f, indent=2) print(f"Results saved to: {save_path}") return results def print_summary(self, metrics): """Print evaluation summary""" print("\n" + "="*60) print("MODEL EVALUATION SUMMARY") print("="*60) print(f"Test Accuracy: {metrics['accuracy']:.4f}") print(f"Precision (Macro): {metrics['precision_macro']:.4f}") print(f"Recall (Macro): {metrics['recall_macro']:.4f}") print(f"F1-Score (Macro): {metrics['f1_macro']:.4f}") print(f"F1-Score (Weighted): {metrics['f1_weighted']:.4f}") print("\nPer-Class Performance:") print("-" * 60) for i, class_name in enumerate(self.class_names): precision = metrics['per_class_metrics']['precision'][i] recall = metrics['per_class_metrics']['recall'][i] f1 = metrics['per_class_metrics']['f1_score'][i] support = metrics['per_class_metrics']['support'][i] print(f"{class_name:40} | P: {precision:.3f} | R: {recall:.3f} | F1: {f1:.3f} | N: {support:2d}") print("="*60) def evaluate_model(checkpoint_path, data_dir='data', batch_size=32): """Main evaluation function""" # Device setup device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f"Using device: {device}") # Load data print("Loading test dataset...") _, _, test_loader, class_names = create_data_loaders( data_dir=data_dir, batch_size=batch_size, num_workers=0 ) print(f"Test dataset loaded: {len(test_loader.dataset)} samples") # Create and load model print("Loading trained model...") model = create_model(num_classes=len(class_names), device=device) try: model, _, epoch, _ = load_checkpoint(checkpoint_path, model, device=device) print(f"Model loaded successfully from epoch {epoch}") except Exception as e: print(f"Error loading checkpoint: {e}") print("Using untrained model for testing...") # Create evaluator evaluator = ModelEvaluator(model, test_loader, class_names, device) # Run evaluation print("Evaluating model...") y_pred, y_true, y_probs = evaluator.evaluate() # Calculate metrics metrics = evaluator.calculate_metrics(y_true, y_pred, y_probs) # Print summary evaluator.print_summary(metrics) # Generate plots evaluator.plot_confusion_matrix(metrics['confusion_matrix']) evaluator.plot_per_class_metrics(metrics) # Save results results = evaluator.save_results(metrics) return results if __name__ == "__main__": # Evaluate the trained model results = evaluate_model( checkpoint_path='models/crop_disease_resnet50.pth', data_dir='data', batch_size=16 )