| """ | |
| Evaluation metrics and visualization | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| from sklearn.metrics import ( | |
| accuracy_score, precision_score, recall_score, f1_score, | |
| confusion_matrix, classification_report, roc_curve, auc, | |
| roc_auc_score | |
| ) | |
| from sklearn.preprocessing import label_binarize | |
| from torch.cuda.amp import autocast | |
| from tqdm import tqdm | |
| from typing import Dict, List, Tuple | |
| import json | |
| import pandas as pd | |
| from pathlib import Path | |
| import config | |
| from models import get_model | |
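
# NOTE: `config` and `models` are local project modules. The code below only
# assumes `config` exposes DEVICE, MODELS_DIR, PLOTS_DIR and MODEL_NAMES, and
# that `models.get_model(name, num_classes, pretrained)` returns an nn.Module.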

# Set style
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")


class Evaluator:
    """Model evaluator with comprehensive metrics"""

    def __init__(
        self,
        model: nn.Module,
        model_name: str,
        test_loader,
        class_names: List[str],
        device: str = config.DEVICE
    ):
        self.model = model.to(device)
        self.model_name = model_name
        self.test_loader = test_loader
        self.class_names = class_names
        self.num_classes = len(class_names)
        self.device = device
        self.model.eval()

    def get_predictions(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Return predictions, true labels, and class probabilities as numpy arrays."""
        all_preds = []
        all_labels = []
        all_probs = []

        # Inference only: disable gradient tracking to save memory and time
        with torch.no_grad():
            for images, labels in tqdm(self.test_loader, desc=f"Evaluating {self.model_name}"):
                images = images.to(self.device)

                with autocast():
                    outputs = self.model(images)
                    probs = torch.softmax(outputs, dim=1)

                _, preds = outputs.max(1)

                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.numpy())
                all_probs.extend(probs.cpu().numpy())

        return np.array(all_preds), np.array(all_labels), np.array(all_probs)

    def calculate_metrics(self) -> Dict:
        """Calculate all evaluation metrics"""
        preds, labels, probs = self.get_predictions()

        # Basic metrics
        accuracy = accuracy_score(labels, preds) * 100
        precision_macro = precision_score(labels, preds, average='macro', zero_division=0) * 100
        recall_macro = recall_score(labels, preds, average='macro', zero_division=0) * 100
        f1_macro = f1_score(labels, preds, average='macro', zero_division=0) * 100
        precision_weighted = precision_score(labels, preds, average='weighted', zero_division=0) * 100
        recall_weighted = recall_score(labels, preds, average='weighted', zero_division=0) * 100
        f1_weighted = f1_score(labels, preds, average='weighted', zero_division=0) * 100

        # Per-class metrics
        precision_per_class = precision_score(labels, preds, average=None, zero_division=0) * 100
        recall_per_class = recall_score(labels, preds, average=None, zero_division=0) * 100
        f1_per_class = f1_score(labels, preds, average=None, zero_division=0) * 100

        # ROC AUC (multi-class)
        labels_bin = label_binarize(labels, classes=range(self.num_classes))
        try:
            auc_macro = roc_auc_score(labels_bin, probs, average='macro', multi_class='ovr') * 100
            auc_weighted = roc_auc_score(labels_bin, probs, average='weighted', multi_class='ovr') * 100
        except ValueError:
            # ROC AUC is undefined when a class never appears in the test labels
            auc_macro = 0.0
            auc_weighted = 0.0

        # Confusion matrix
        cm = confusion_matrix(labels, preds)

        metrics = {
            'model_name': self.model_name,
            'accuracy': accuracy,
            'precision_macro': precision_macro,
            'recall_macro': recall_macro,
            'f1_macro': f1_macro,
            'precision_weighted': precision_weighted,
            'recall_weighted': recall_weighted,
            'f1_weighted': f1_weighted,
            'auc_roc_macro': auc_macro,
            'auc_roc_weighted': auc_weighted,
            'confusion_matrix': cm,
            'predictions': preds,
            'labels': labels,
            'probabilities': probs,
            'precision_per_class': precision_per_class,
            'recall_per_class': recall_per_class,
            'f1_per_class': f1_per_class
        }

        return metrics

    def print_metrics(self, metrics: Dict):
        """Print metrics summary"""
        print(f"\n{'='*60}")
        print(f"EVALUATION RESULTS: {metrics['model_name']}")
        print(f"{'='*60}")
        print(f"Accuracy: {metrics['accuracy']:.2f}%")
        print(f"Precision (macro): {metrics['precision_macro']:.2f}%")
        print(f"Recall (macro): {metrics['recall_macro']:.2f}%")
        print(f"F1-Score (macro): {metrics['f1_macro']:.2f}%")
        print(f"AUC-ROC (macro): {metrics['auc_roc_macro']:.2f}%")
        print("-" * 40)
        print(f"Precision (weighted): {metrics['precision_weighted']:.2f}%")
        print(f"Recall (weighted): {metrics['recall_weighted']:.2f}%")
        print(f"F1-Score (weighted): {metrics['f1_weighted']:.2f}%")
        print(f"AUC-ROC (weighted): {metrics['auc_roc_weighted']:.2f}%")


def plot_confusion_matrix(metrics: Dict, class_names: List[str], save_path: Path):
    """Plot and save confusion matrix"""
    cm = metrics['confusion_matrix']

    # Normalize confusion matrix
    cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

    plt.figure(figsize=(20, 16))

    # Plot normalized confusion matrix
    sns.heatmap(
        cm_normalized,
        annot=True,
        fmt='.1%',
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names,
        cbar_kws={'label': 'Percentage'}
    )

    plt.title(f'Confusion Matrix - {metrics["model_name"]}\nAccuracy: {metrics["accuracy"]:.2f}%',
              fontsize=14, fontweight='bold')
    plt.xlabel('Predicted Label', fontsize=12)
    plt.ylabel('True Label', fontsize=12)
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()

    print(f"Confusion matrix saved to {save_path}")


def plot_roc_curves(metrics: Dict, class_names: List[str], save_path: Path):
    """Plot ROC curves for all classes"""
    labels = metrics['labels']
    probs = metrics['probabilities']
    num_classes = len(class_names)

    # Binarize labels
    labels_bin = label_binarize(labels, classes=range(num_classes))

    plt.figure(figsize=(14, 10))

    # Plot ROC curve for each class
    colors = plt.cm.tab20(np.linspace(0, 1, num_classes))
    for i in range(num_classes):
        fpr, tpr, _ = roc_curve(labels_bin[:, i], probs[:, i])
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, color=colors[i], lw=1.5, alpha=0.7,
                 label=f'{class_names[i]} (AUC={roc_auc:.3f})')

    # Plot diagonal
    plt.plot([0, 1], [0, 1], 'k--', lw=2, label='Random (AUC=0.500)')

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate', fontsize=12)
    plt.ylabel('True Positive Rate', fontsize=12)
    plt.title(f'ROC Curves - {metrics["model_name"]}\nMacro AUC: {metrics["auc_roc_macro"]:.2f}%',
              fontsize=14, fontweight='bold')
    plt.legend(loc='center left', bbox_to_anchor=(1.02, 0.5), fontsize=8)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()

    print(f"ROC curves saved to {save_path}")


def plot_training_history(history: Dict, model_name: str, save_path: Path):
    """Plot training history"""
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    epochs = range(1, len(history['train_loss']) + 1)

    # Loss plot
    axes[0, 0].plot(epochs, history['train_loss'], 'b-', label='Train Loss', linewidth=2)
    axes[0, 0].plot(epochs, history['val_loss'], 'r-', label='Val Loss', linewidth=2)
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].set_title('Training & Validation Loss')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)

    # Accuracy plot
    axes[0, 1].plot(epochs, history['train_acc'], 'b-', label='Train Acc', linewidth=2)
    axes[0, 1].plot(epochs, history['val_acc'], 'r-', label='Val Acc', linewidth=2)
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('Accuracy (%)')
    axes[0, 1].set_title('Training & Validation Accuracy')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)

    # Learning rate plot
    axes[1, 0].plot(epochs, history['lr'], 'g-', linewidth=2)
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('Learning Rate')
    axes[1, 0].set_title('Learning Rate Schedule')
    axes[1, 0].grid(True, alpha=0.3)
    axes[1, 0].set_yscale('log')

    # Text summary
    axes[1, 1].axis('off')
    summary_text = f"""
    Model: {model_name}

    Training Summary:
    ─────────────────────────
    Best Val Accuracy: {history['best_val_acc']:.2f}%
    Training Time: {history['training_time']/60:.2f} min
    Total Epochs: {len(epochs)}
    Final Train Loss: {history['train_loss'][-1]:.4f}
    Final Val Loss: {history['val_loss'][-1]:.4f}
    """
    axes[1, 1].text(0.1, 0.5, summary_text, fontsize=12, fontfamily='monospace',
                    verticalalignment='center', bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.5))

    plt.suptitle(f'Training History - {model_name}', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()

    print(f"Training history saved to {save_path}")


def plot_model_comparison(all_metrics: List[Dict], save_path: Path):
    """Plot comparison of all models"""
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    model_names = [m['model_name'] for m in all_metrics]

    # Metrics for comparison
    metrics_to_compare = ['accuracy', 'precision_macro', 'recall_macro', 'f1_macro', 'auc_roc_macro']
    metric_labels = ['Accuracy', 'Precision', 'Recall', 'F1-Score', 'AUC-ROC']

    # Bar chart comparison
    x = np.arange(len(model_names))
    width = 0.15
    for i, (metric, label) in enumerate(zip(metrics_to_compare, metric_labels)):
        values = [m[metric] for m in all_metrics]
        axes[0, 0].bar(x + i * width, values, width, label=label)

    axes[0, 0].set_xlabel('Model')
    axes[0, 0].set_ylabel('Score (%)')
    axes[0, 0].set_title('Model Comparison - All Metrics')
    axes[0, 0].set_xticks(x + width * 2)
    axes[0, 0].set_xticklabels(model_names, rotation=45, ha='right')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3, axis='y')
    axes[0, 0].set_ylim([0, 105])

    # Accuracy comparison (horizontal bar)
    accuracies = [m['accuracy'] for m in all_metrics]
    colors = plt.cm.viridis(np.linspace(0.2, 0.8, len(model_names)))
    bars = axes[0, 1].barh(model_names, accuracies, color=colors)
    axes[0, 1].set_xlabel('Accuracy (%)')
    axes[0, 1].set_title('Model Accuracy Comparison')
    axes[0, 1].set_xlim([0, 105])
    for bar, acc in zip(bars, accuracies):
        axes[0, 1].text(bar.get_width() + 1, bar.get_y() + bar.get_height()/2,
                        f'{acc:.2f}%', va='center', fontweight='bold')
    axes[0, 1].grid(True, alpha=0.3, axis='x')

    # F1-Score comparison
    f1_scores = [m['f1_macro'] for m in all_metrics]
    bars = axes[1, 0].barh(model_names, f1_scores, color=colors)
    axes[1, 0].set_xlabel('F1-Score (%)')
    axes[1, 0].set_title('Model F1-Score Comparison (Macro)')
    axes[1, 0].set_xlim([0, 105])
    for bar, f1 in zip(bars, f1_scores):
        axes[1, 0].text(bar.get_width() + 1, bar.get_y() + bar.get_height()/2,
                        f'{f1:.2f}%', va='center', fontweight='bold')
    axes[1, 0].grid(True, alpha=0.3, axis='x')

    # AUC-ROC comparison
    auc_scores = [m['auc_roc_macro'] for m in all_metrics]
    bars = axes[1, 1].barh(model_names, auc_scores, color=colors)
    axes[1, 1].set_xlabel('AUC-ROC (%)')
    axes[1, 1].set_title('Model AUC-ROC Comparison (Macro)')
    axes[1, 1].set_xlim([0, 105])
    for bar, auc_val in zip(bars, auc_scores):
        axes[1, 1].text(bar.get_width() + 1, bar.get_y() + bar.get_height()/2,
                        f'{auc_val:.2f}%', va='center', fontweight='bold')
    axes[1, 1].grid(True, alpha=0.3, axis='x')

    plt.suptitle('Model Performance Comparison\nIndonesian Herbal Plants Classification',
                 fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()

    print(f"Model comparison saved to {save_path}")


def plot_per_class_metrics(all_metrics: List[Dict], class_names: List[str], save_path: Path):
    """Plot per-class F1 scores for all models"""
    fig, axes = plt.subplots(1, 1, figsize=(20, 10))
    model_names = [m['model_name'] for m in all_metrics]
    x = np.arange(len(class_names))
    width = 0.15

    for i, metrics in enumerate(all_metrics):
        f1_per_class = metrics['f1_per_class']
        axes.bar(x + i * width, f1_per_class, width, label=metrics['model_name'], alpha=0.8)

    axes.set_xlabel('Class', fontsize=12)
    axes.set_ylabel('F1-Score (%)', fontsize=12)
    axes.set_title('Per-Class F1-Score Comparison', fontsize=14, fontweight='bold')
    axes.set_xticks(x + width * (len(all_metrics) - 1) / 2)  # center labels under the grouped bars
    axes.set_xticklabels(class_names, rotation=45, ha='right')
    axes.legend()
    axes.grid(True, alpha=0.3, axis='y')
    axes.set_ylim([0, 105])

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()

    print(f"Per-class metrics saved to {save_path}")


def create_results_table(all_metrics: List[Dict], save_path: Path):
    """Create and save results table"""
    data = []
    for m in all_metrics:
        data.append({
            'Model': m['model_name'],
            'Accuracy (%)': f"{m['accuracy']:.2f}",
            'Precision (%)': f"{m['precision_macro']:.2f}",
            'Recall (%)': f"{m['recall_macro']:.2f}",
            'F1-Score (%)': f"{m['f1_macro']:.2f}",
            'AUC-ROC (%)': f"{m['auc_roc_macro']:.2f}"
        })

    df = pd.DataFrame(data)

    # Save as CSV
    df.to_csv(save_path.with_suffix('.csv'), index=False)

    # Create table image
    fig, ax = plt.subplots(figsize=(14, 4))
    ax.axis('off')
    ax.axis('tight')

    table = ax.table(
        cellText=df.values,
        colLabels=df.columns,
        cellLoc='center',
        loc='center',
        colColours=['#4CAF50'] * len(df.columns)
    )
    table.auto_set_font_size(False)
    table.set_fontsize(11)
    table.scale(1.2, 1.8)

    # Style header
    for i in range(len(df.columns)):
        table[(0, i)].set_text_props(weight='bold', color='white')

    plt.title('Model Evaluation Results Summary\nIndonesian Herbal Plants Classification',
              fontsize=14, fontweight='bold', pad=20)
    plt.tight_layout()
    plt.savefig(save_path.with_suffix('.png'), dpi=150, bbox_inches='tight',
                facecolor='white', edgecolor='none')
    plt.close()

    print(f"Results table saved to {save_path}")
    return df


def evaluate_all_models(test_loader, class_names: List[str], training_results: Dict = None):
    """Evaluate all trained models"""
    print("\n" + "="*70)
    print("EVALUATING ALL MODELS")
    print("="*70)

    all_metrics = []

    for model_name in config.MODEL_NAMES:
        print(f"\nLoading {model_name}...")

        # Load model
        model_path = config.MODELS_DIR / f"{model_name.lower()}.pth"
        if not model_path.exists():
            print(f" Model not found: {model_path}")
            continue

        checkpoint = torch.load(model_path, map_location=config.DEVICE)
        model = get_model(model_name, len(class_names), pretrained=False)
        model.load_state_dict(checkpoint['model_state_dict'])

        # Evaluate
        evaluator = Evaluator(model, model_name, test_loader, class_names)
        metrics = evaluator.calculate_metrics()
        evaluator.print_metrics(metrics)
        all_metrics.append(metrics)

        # Plot confusion matrix
        cm_path = config.PLOTS_DIR / f"confusion_matrix_{model_name.lower()}.png"
        plot_confusion_matrix(metrics, class_names, cm_path)

        # Plot ROC curves
        roc_path = config.PLOTS_DIR / f"roc_curves_{model_name.lower()}.png"
        plot_roc_curves(metrics, class_names, roc_path)

        # Plot training history if available
        if training_results and model_name in training_results:
            history = training_results[model_name]['history']
            history_path = config.PLOTS_DIR / f"training_history_{model_name.lower()}.png"
            plot_training_history(history, model_name, history_path)

    if len(all_metrics) > 0:
        # Plot model comparison
        comparison_path = config.PLOTS_DIR / "model_comparison.png"
        plot_model_comparison(all_metrics, comparison_path)

        # Plot per-class metrics
        per_class_path = config.PLOTS_DIR / "per_class_f1_comparison.png"
        plot_per_class_metrics(all_metrics, class_names, per_class_path)

        # Create results table
        table_path = config.PLOTS_DIR / "results_table"
        results_df = create_results_table(all_metrics, table_path)

        print("\n" + "="*70)
        print("FINAL RESULTS SUMMARY")
        print("="*70)
        print(results_df.to_string(index=False))

        # Find best model
        best_idx = np.argmax([m['accuracy'] for m in all_metrics])
        best_model = all_metrics[best_idx]
        print(f"\n🏆 BEST MODEL: {best_model['model_name']}")
        print(f" Accuracy: {best_model['accuracy']:.2f}%")
        print(f" F1-Score: {best_model['f1_macro']:.2f}%")
        print(f" AUC-ROC: {best_model['auc_roc_macro']:.2f}%")

    return all_metrics


if __name__ == "__main__":
    from dataset import create_data_loaders

    _, _, test_loader, class_names = create_data_loaders()
    all_metrics = evaluate_all_models(test_loader, class_names)
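
    # Training-curve plots are optional: if the training script saved per-model
    # histories (a dict keyed by model name with a 'history' entry, as
    # plot_training_history expects), pass it as the third argument, e.g.:
    #   evaluate_all_models(test_loader, class_names, training_results)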