Spaces:
Sleeping
Sleeping
| """ | |
| Evaluation script for Pest and Disease Classification | |
| Generate confusion matrix, classification report, and per-class metrics | |
| """ | |
| import torch | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| from sklearn.metrics import confusion_matrix, classification_report, f1_score | |
| import argparse | |
| import json | |
| from pathlib import Path | |
| from dataset import get_dataloaders | |
| from model import create_model | |
def evaluate_model(model, dataloader, device, dataset):
    """
    Run the model over every batch of a dataloader and score its predictions

    Args:
        model: Network to evaluate; it is switched to eval mode
        dataloader: Iterable yielding ``(inputs, labels)`` tensor pairs
        device: Device the inputs and labels are moved to before the forward pass
        dataset: Not used by this function (kept for caller compatibility)

    Returns:
        predictions: Array of predicted labels (index of the largest output
            along dimension 1)
        true_labels: Array of true labels
        accuracy: Fraction of predictions equal to the true labels
    """
    model.eval()

    predicted = []
    expected = []

    # Gradients are never needed for evaluation.
    with torch.no_grad():
        for batch_inputs, batch_labels in dataloader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            scores = model(batch_inputs)
            batch_preds = scores.max(dim=1).indices

            predicted += list(batch_preds.cpu().numpy())
            expected += list(batch_labels.cpu().numpy())

    predicted = np.array(predicted)
    expected = np.array(expected)

    hits = predicted == expected
    return predicted, expected, np.mean(hits)
def _save_confusion_heatmap(matrix, class_names, value_fmt, colorbar_label, title, out_path):
    """Render one annotated heatmap of ``matrix`` and write it to ``out_path``."""
    fig, ax = plt.subplots(figsize=(12, 10))
    sns.heatmap(matrix, annot=True, fmt=value_fmt, cmap='Blues',
                xticklabels=class_names,
                yticklabels=class_names,
                cbar_kws={'label': colorbar_label},
                ax=ax)
    ax.set_title(title, fontsize=16, pad=20)
    ax.set_ylabel('True Label', fontsize=12)
    ax.set_xlabel('Predicted Label', fontsize=12)
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    plt.setp(ax.get_yticklabels(), rotation=0)
    fig.tight_layout()
    fig.savefig(out_path, dpi=300, bbox_inches='tight')
    # Close only the figure created here; figures owned by the caller stay open.
    plt.close(fig)


def plot_confusion_matrix(y_true, y_pred, class_names, save_path='confusion_matrix.png'):
    """
    Plot and save the confusion matrix as raw counts and as row percentages

    Two figures are written: ``save_path`` (counts) and a sibling file whose
    stem carries a ``_percent`` suffix (each row scaled to sum to 100).

    Args:
        y_true: True labels, expected to be integer class indices in
            ``range(len(class_names))``
        y_pred: Predicted labels, same encoding as ``y_true``
        class_names: List of class names, ordered by class index
        save_path: Path (str or Path) of the count figure

    Returns:
        The confusion matrix of raw counts, with one row and one column per
        entry of ``class_names``
    """
    # Pin the label set so the matrix always has one row/column per class name,
    # even when a class is absent from this split and never predicted.
    # Without this the tick labels can silently shift onto the wrong classes.
    class_indices = np.arange(len(class_names))
    cm = confusion_matrix(y_true, y_pred, labels=class_indices)

    # Row-normalise to percentages. A class with no true samples keeps an
    # all-zero row instead of turning into NaN through a division by zero.
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_percent = np.divide(
        cm.astype(float), row_totals,
        out=np.zeros(cm.shape, dtype=float),
        where=row_totals != 0,
    ) * 100

    # Derive the percentage file name from the stem so that directory names
    # containing ".png" are untouched and a non-.png suffix cannot make the
    # second figure overwrite the first.
    save_path = Path(save_path)
    save_path_percent = save_path.with_name(f"{save_path.stem}_percent{save_path.suffix}")

    _save_confusion_heatmap(cm, class_names, 'd', 'Count',
                            'Confusion Matrix', save_path)
    print(f"Confusion matrix saved to {save_path}")

    _save_confusion_heatmap(cm_percent, class_names, '.1f', 'Percentage (%)',
                            'Confusion Matrix (Percentage)', save_path_percent)
    print(f"Confusion matrix (percentage) saved to {save_path_percent}")

    return cm
def generate_classification_report(y_true, y_pred, class_names, save_path='classification_report.txt'):
    """
    Generate, print and save a detailed classification report

    Writes the text report to ``save_path`` and a JSON file with the same
    metrics next to it (same name, ``.json`` suffix).

    Args:
        y_true: True labels, expected to be integer class indices in
            ``range(len(class_names))``
        y_pred: Predicted labels, same encoding as ``y_true``
        class_names: List of class names, ordered by class index
        save_path: Path (str or Path) of the text report

    Returns:
        Dictionary mapping each class name to its precision, recall,
        f1-score and support, plus an ``'overall'`` entry holding accuracy,
        macro-averaged F1 and weighted-averaged F1
    """
    from sklearn.metrics import precision_recall_fscore_support

    # Arrays make the element-wise comparison below valid for list inputs too.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    # Pin the label set: without it scikit-learn infers the classes from the
    # data, which raises when a class is missing from this split and would
    # misalign the per-class arrays with class_names.
    labels = list(range(len(class_names)))

    report = classification_report(
        y_true, y_pred,
        labels=labels,
        target_names=class_names,
        digits=4,
        zero_division=0,
    )

    # Print to console
    separator = "=" * 80
    print("\n" + separator)
    print("Classification Report")
    print(separator)
    print(report)

    # Save to file
    with open(save_path, 'w', encoding='utf-8') as f:
        f.write("Classification Report\n")
        f.write(separator + "\n")
        f.write(report)
    print(f"\nClassification report saved to {save_path}")

    # Per-class metrics, one entry per class index (zero for absent classes)
    precision, recall, f1, support = precision_recall_fscore_support(
        y_true, y_pred, labels=labels, average=None, zero_division=0
    )
    metrics = {
        class_name: {
            'precision': float(precision[i]),
            'recall': float(recall[i]),
            'f1-score': float(f1[i]),
            'support': int(support[i]),
        }
        for i, class_name in enumerate(class_names)
    }

    # Overall metrics
    metrics['overall'] = {
        'accuracy': float(np.mean(y_true == y_pred)),
        'macro_avg_f1': float(np.mean(f1)),
        'weighted_avg_f1': float(
            f1_score(y_true, y_pred, labels=labels, average='weighted', zero_division=0)
        ),
    }

    # Swap the suffix rather than string-replacing ".txt": a path without that
    # suffix would otherwise make the JSON overwrite the text report.
    metrics_path = Path(save_path).with_suffix('.json')
    with open(metrics_path, 'w', encoding='utf-8') as f:
        json.dump(metrics, f, indent=2, ensure_ascii=False)
    print(f"Metrics JSON saved to {metrics_path}")

    return metrics
def plot_per_class_metrics(metrics, class_names, save_path='per_class_metrics.png'):
    """
    Draw a grouped bar chart of precision, recall and F1-score for each class

    Args:
        metrics: Dictionary mapping class name to a dict with the keys
            'precision', 'recall' and 'f1-score'
        class_names: List of class names; sets the order of the bar groups
        save_path: Path to save figure
    """
    bar_width = 0.25
    positions = np.arange(len(class_names))

    # (metrics key, legend label, horizontal shift of the bar within its group)
    series = (
        ('precision', 'Precision', -bar_width),
        ('recall', 'Recall', 0),
        ('f1-score', 'F1-Score', bar_width),
    )

    # Collect all scores up front so a missing key fails before any figure exists.
    scores = {
        key: [metrics[name][key] for name in class_names]
        for key, _, _ in series
    }

    fig, ax = plt.subplots(figsize=(14, 6))
    for key, legend_label, shift in series:
        ax.bar(positions + shift, scores[key], bar_width, label=legend_label, alpha=0.8)

    ax.set_title('Per-Class Metrics', fontsize=14, pad=20)
    ax.set_xlabel('Class', fontsize=12)
    ax.set_ylabel('Score', fontsize=12)
    ax.set_xticks(positions)
    ax.set_xticklabels(class_names, rotation=45, ha='right')
    ax.set_ylim([0, 1.1])
    ax.grid(axis='y', alpha=0.3)
    ax.legend()

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"Per-class metrics plot saved to {save_path}")
    plt.close()
def main(args):
    """
    Main evaluation function

    Loads the data and the checkpointed model, evaluates the requested split,
    then writes the confusion matrices, the classification report (text and
    JSON) and the per-class metrics plot into ``args.output_dir``.

    Args:
        args: Parsed command-line namespace providing csv_file, label_mapping,
            checkpoint, backbone, split, batch_size, img_size, device,
            num_workers and output_dir
    """
    banner = "=" * 80
    print("Pest and Disease Classification Evaluation")
    print(banner)
    print("Configuration:")
    print(f" Checkpoint: {args.checkpoint}")
    print(f" Split: {args.split}")
    print(f" Batch size: {args.batch_size}")
    print(f" Device: {args.device}")
    print(banner)

    # Set device (falls back to CPU whenever CUDA is unavailable)
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    print(f"\nUsing device: {device}")

    # Load data
    print("\nLoading datasets...")
    loaders = get_dataloaders(
        csv_file=args.csv_file,
        label_mapping_file=args.label_mapping,
        batch_size=args.batch_size,
        img_size=args.img_size,
        num_workers=args.num_workers
    )

    # Get class names
    dataset = loaders['datasets'][args.split]
    class_names = [dataset.get_label_name(i) for i in range(dataset.num_classes)]
    print(f"Classes: {class_names}")

    # Create model
    print(f"\nCreating model: {args.backbone}")
    model = create_model(
        num_classes=loaders['num_classes'],
        backbone=args.backbone,
        pretrained=False
    )

    # Load checkpoint
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    print(f"\nLoading checkpoint: {args.checkpoint}")
    checkpoint = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    if 'val_acc' in checkpoint:
        print(f"Checkpoint validation accuracy: {checkpoint['val_acc']:.4f}")

    # Evaluate
    print(f"\nEvaluating on {args.split} set...")
    dataloader = loaders[args.split]
    predictions, true_labels, accuracy = evaluate_model(model, dataloader, device, dataset)
    print(f"\n{args.split.capitalize()} Set Accuracy: {accuracy:.4f}")

    # Create output directory; parents=True lets a nested --output_dir such as
    # "results/run1" work instead of raising FileNotFoundError.
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Generate confusion matrix
    print("\nGenerating confusion matrix...")
    plot_confusion_matrix(
        true_labels, predictions, class_names,
        save_path=output_dir / f'confusion_matrix_{args.split}.png'
    )

    # Generate classification report
    print("\nGenerating classification report...")
    metrics = generate_classification_report(
        true_labels, predictions, class_names,
        save_path=output_dir / f'classification_report_{args.split}.txt'
    )

    # Plot per-class metrics
    print("\nGenerating per-class metrics plot...")
    plot_per_class_metrics(
        metrics, class_names,
        save_path=output_dir / f'per_class_metrics_{args.split}.png'
    )

    print("\n" + banner)
    print("Evaluation complete!")
    print(f"Results saved to {output_dir}/")
    print(banner)
if __name__ == "__main__":
    # Command-line options as (flag, add_argument keyword options) pairs.
    # They are registered in this order, so the generated --help is unchanged.
    cli_options = (
        # Data parameters
        ('--csv_file', dict(type=str, default='dataset.csv',
                            help='Path to dataset CSV')),
        ('--label_mapping', dict(type=str, default='label_mapping.json',
                                 help='Path to label mapping JSON')),
        # Model parameters
        ('--checkpoint', dict(type=str, default='checkpoints/best_model.pth',
                              help='Path to model checkpoint')),
        ('--backbone', dict(type=str, default='resnet50',
                            choices=['resnet50', 'resnet101', 'efficientnet_b0',
                                     'efficientnet_b3', 'mobilenet_v2'],
                            help='Model backbone')),
        # Evaluation parameters
        ('--split', dict(type=str, default='test',
                         choices=['train', 'val', 'test'],
                         help='Dataset split to evaluate')),
        ('--batch_size', dict(type=int, default=16,
                              help='Batch size')),
        ('--img_size', dict(type=int, default=224,
                            help='Image size')),
        # System parameters
        ('--device', dict(type=str, default='cuda',
                          choices=['cuda', 'cpu'],
                          help='Device to use')),
        ('--num_workers', dict(type=int, default=4,
                               help='Number of data loading workers')),
        ('--output_dir', dict(type=str, default='evaluation_results',
                              help='Directory to save results')),
    )

    parser = argparse.ArgumentParser(description='Evaluate Pest and Disease Classifier')
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    args = parser.parse_args()
    main(args)