# Hugging Face upload metadata (uploader: Boyun7, message: "upload all files",
# commit 03d5bce) — page residue, kept as a comment so the module parses.
"""
Evaluation script for Pest and Disease Classification
Generate confusion matrix, classification report, and per-class metrics
"""
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report, f1_score
import argparse
import json
from pathlib import Path
from dataset import get_dataloaders
from model import create_model
def evaluate_model(model, dataloader, device, dataset):
    """
    Run inference over every batch of a dataloader and collect predictions.

    Args:
        model: Trained torch.nn.Module; switched to eval mode here.
        dataloader: Iterable yielding (inputs, labels) batches.
        device: torch.device the batches are moved to.
        dataset: Unused; kept so existing call sites remain valid.

    Returns:
        predictions: np.ndarray of predicted class indices
        true_labels: np.ndarray of ground-truth class indices
        accuracy: fraction of correct predictions
    """
    model.eval()
    predicted, observed = [], []
    # Pure inference — no gradient bookkeeping needed.
    with torch.no_grad():
        for batch_inputs, batch_labels in dataloader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)
            logits = model(batch_inputs)
            _, batch_preds = torch.max(logits, 1)
            predicted.extend(batch_preds.cpu().numpy())
            observed.extend(batch_labels.cpu().numpy())
    predicted = np.array(predicted)
    observed = np.array(observed)
    accuracy = np.mean(predicted == observed)
    return predicted, observed, accuracy
def plot_confusion_matrix(y_true, y_pred, class_names, save_path='confusion_matrix.png'):
    """
    Plot the confusion matrix twice — raw counts and row percentages —
    and save both figures to disk.

    Args:
        y_true: True labels
        y_pred: Predicted labels
        class_names: List of class names (axis tick labels)
        save_path: Path for the counts figure; the percentage figure gets
            a '_percent' suffix before the extension.

    Returns:
        The raw (count) confusion matrix.
    """
    cm = confusion_matrix(y_true, y_pred)
    # Row-normalise so each true-class row sums to 100%.
    cm_percent = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] * 100

    variants = [
        (cm, 'd', 'Count', 'Confusion Matrix',
         save_path, 'Confusion matrix'),
        (cm_percent, '.1f', 'Percentage (%)', 'Confusion Matrix (Percentage)',
         str(save_path).replace('.png', '_percent.png'), 'Confusion matrix (percentage)'),
    ]
    for matrix, value_fmt, cbar_label, title, out_path, label in variants:
        plt.figure(figsize=(12, 10))
        sns.heatmap(matrix, annot=True, fmt=value_fmt, cmap='Blues',
                    xticklabels=class_names,
                    yticklabels=class_names,
                    cbar_kws={'label': cbar_label})
        plt.title(title, fontsize=16, pad=20)
        plt.ylabel('True Label', fontsize=12)
        plt.xlabel('Predicted Label', fontsize=12)
        plt.xticks(rotation=45, ha='right')
        plt.yticks(rotation=0)
        plt.tight_layout()
        plt.savefig(out_path, dpi=300, bbox_inches='tight')
        print(f"{label} saved to {out_path}")

    # Release both figures' memory.
    plt.close('all')
    return cm
def generate_classification_report(y_true, y_pred, class_names, save_path='classification_report.txt'):
    """
    Print, save, and return per-class and overall classification metrics.

    Writes the sklearn text report to `save_path` and a JSON metrics dict
    (per-class precision/recall/f1/support plus overall accuracy and
    macro/weighted F1) to the same path with a '.json' extension.

    Args:
        y_true: True labels
        y_pred: Predicted labels
        class_names: List of class names
        save_path: Path for the text report

    Returns:
        Metrics dictionary (per-class entries plus an 'overall' entry).
    """
    from sklearn.metrics import precision_recall_fscore_support

    report = classification_report(
        y_true, y_pred,
        target_names=class_names,
        digits=4
    )

    divider = "=" * 80
    # Echo the report to the console.
    print("\n" + divider)
    print("Classification Report")
    print(divider)
    print(report)

    # Persist the human-readable report.
    with open(save_path, 'w', encoding='utf-8') as fh:
        fh.write("Classification Report\n")
        fh.write(divider + "\n")
        fh.write(report)
    print(f"\nClassification report saved to {save_path}")

    precision, recall, f1, support = precision_recall_fscore_support(
        y_true, y_pred, average=None
    )

    # Per-class entries first (insertion order matches class_names) ...
    metrics = {
        name: {
            'precision': float(precision[idx]),
            'recall': float(recall[idx]),
            'f1-score': float(f1[idx]),
            'support': int(support[idx]),
        }
        for idx, name in enumerate(class_names)
    }
    # ... then the aggregate entry.
    metrics['overall'] = {
        'accuracy': float(np.mean(y_true == y_pred)),
        'macro_avg_f1': float(np.mean(f1)),
        'weighted_avg_f1': float(f1_score(y_true, y_pred, average='weighted'))
    }

    metrics_path = str(save_path).replace('.txt', '.json')
    with open(metrics_path, 'w', encoding='utf-8') as fh:
        json.dump(metrics, fh, indent=2, ensure_ascii=False)
    print(f"Metrics JSON saved to {metrics_path}")
    return metrics
def plot_per_class_metrics(metrics, class_names, save_path='per_class_metrics.png'):
    """
    Save a grouped bar chart of per-class precision, recall, and F1-score.

    Args:
        metrics: Dict mapping class name -> {'precision', 'recall', 'f1-score', ...}
        class_names: List of class names (x-axis order)
        save_path: Path to save the figure
    """
    # Pull each metric into a list ordered by class_names.
    scores = {key: [metrics[name][key] for name in class_names]
              for key in ('precision', 'recall', 'f1-score')}

    positions = np.arange(len(class_names))
    bar_width = 0.25

    fig, ax = plt.subplots(figsize=(14, 6))
    ax.bar(positions - bar_width, scores['precision'], bar_width, label='Precision', alpha=0.8)
    ax.bar(positions, scores['recall'], bar_width, label='Recall', alpha=0.8)
    ax.bar(positions + bar_width, scores['f1-score'], bar_width, label='F1-Score', alpha=0.8)

    ax.set_xlabel('Class', fontsize=12)
    ax.set_ylabel('Score', fontsize=12)
    ax.set_title('Per-Class Metrics', fontsize=14, pad=20)
    ax.set_xticks(positions)
    ax.set_xticklabels(class_names, rotation=45, ha='right')
    ax.legend()
    ax.grid(axis='y', alpha=0.3)
    # Headroom above 1.0 so full-score bars are not clipped.
    ax.set_ylim([0, 1.1])

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"Per-class metrics plot saved to {save_path}")
    plt.close()
def main(args):
    """
    Main evaluation entry point.

    Loads the dataset split and checkpoint named in `args`, runs inference,
    and writes the confusion matrix, classification report, and per-class
    metrics plot into `args.output_dir`.

    Args:
        args: argparse.Namespace with csv_file, label_mapping, checkpoint,
            backbone, split, batch_size, img_size, device, num_workers,
            and output_dir attributes.
    """
    print("Pest and Disease Classification Evaluation")
    print("=" * 80)
    print(f"Configuration:")
    print(f"  Checkpoint: {args.checkpoint}")
    print(f"  Split: {args.split}")
    print(f"  Batch size: {args.batch_size}")
    print(f"  Device: {args.device}")
    print("=" * 80)

    # Fall back to CPU when CUDA is unavailable, whatever was requested.
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    print(f"\nUsing device: {device}")

    # Load data
    print("\nLoading datasets...")
    loaders = get_dataloaders(
        csv_file=args.csv_file,
        label_mapping_file=args.label_mapping,
        batch_size=args.batch_size,
        img_size=args.img_size,
        num_workers=args.num_workers
    )

    # Resolve human-readable class names for plots/reports.
    dataset = loaders['datasets'][args.split]
    class_names = [dataset.get_label_name(i) for i in range(dataset.num_classes)]
    print(f"Classes: {class_names}")

    # Create model; pretrained=False because weights come from the checkpoint.
    print(f"\nCreating model: {args.backbone}")
    model = create_model(
        num_classes=loaders['num_classes'],
        backbone=args.backbone,
        pretrained=False
    )

    # Load checkpoint.
    # NOTE(security): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    print(f"\nLoading checkpoint: {args.checkpoint}")
    checkpoint = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    if 'val_acc' in checkpoint:
        print(f"Checkpoint validation accuracy: {checkpoint['val_acc']:.4f}")

    # Evaluate
    print(f"\nEvaluating on {args.split} set...")
    dataloader = loaders[args.split]
    predictions, true_labels, accuracy = evaluate_model(model, dataloader, device, dataset)
    print(f"\n{args.split.capitalize()} Set Accuracy: {accuracy:.4f}")

    # Create output directory; parents=True so a nested --output_dir
    # (e.g. results/run1) does not raise FileNotFoundError.
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Generate confusion matrix
    print("\nGenerating confusion matrix...")
    cm = plot_confusion_matrix(
        true_labels, predictions, class_names,
        save_path=output_dir / f'confusion_matrix_{args.split}.png'
    )

    # Generate classification report
    print("\nGenerating classification report...")
    metrics = generate_classification_report(
        true_labels, predictions, class_names,
        save_path=output_dir / f'classification_report_{args.split}.txt'
    )

    # Plot per-class metrics
    print("\nGenerating per-class metrics plot...")
    plot_per_class_metrics(
        metrics, class_names,
        save_path=output_dir / f'per_class_metrics_{args.split}.png'
    )

    print("\n" + "=" * 80)
    print("Evaluation complete!")
    print(f"Results saved to {output_dir}/")
    print("=" * 80)
if __name__ == "__main__":
    # Command-line interface for the evaluation script.
    arg_parser = argparse.ArgumentParser(description='Evaluate Pest and Disease Classifier')

    # --- Data parameters ---
    arg_parser.add_argument('--csv_file', type=str, default='dataset.csv',
                            help='Path to dataset CSV')
    arg_parser.add_argument('--label_mapping', type=str, default='label_mapping.json',
                            help='Path to label mapping JSON')

    # --- Model parameters ---
    arg_parser.add_argument('--checkpoint', type=str, default='checkpoints/best_model.pth',
                            help='Path to model checkpoint')
    arg_parser.add_argument('--backbone', type=str, default='resnet50',
                            choices=['resnet50', 'resnet101', 'efficientnet_b0',
                                     'efficientnet_b3', 'mobilenet_v2'],
                            help='Model backbone')

    # --- Evaluation parameters ---
    arg_parser.add_argument('--split', type=str, default='test',
                            choices=['train', 'val', 'test'],
                            help='Dataset split to evaluate')
    arg_parser.add_argument('--batch_size', type=int, default=16,
                            help='Batch size')
    arg_parser.add_argument('--img_size', type=int, default=224,
                            help='Image size')

    # --- System parameters ---
    arg_parser.add_argument('--device', type=str, default='cuda',
                            choices=['cuda', 'cpu'],
                            help='Device to use')
    arg_parser.add_argument('--num_workers', type=int, default=4,
                            help='Number of data loading workers')
    arg_parser.add_argument('--output_dir', type=str, default='evaluation_results',
                            help='Directory to save results')

    main(arg_parser.parse_args())