Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from torch.utils.data import DataLoader | |
| import numpy as np | |
| from typing import Dict, List, Tuple, Optional | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| from sklearn.metrics import confusion_matrix, classification_report, accuracy_score, precision_recall_fscore_support | |
| import os | |
| import json | |
class Evaluator:
    """Evaluate a classification model on a test loader.

    The evaluator runs the model over ``test_loader`` once, caches the
    predicted labels, the target labels and the softmax probabilities, and
    derives metrics, reports and plots from those cached arrays.

    Attributes:
        model: The model, moved to ``device``.
        device: Device the model and the input batches are placed on.
        test_loader: Iterable yielding ``(inputs, targets)`` batches.
        classes: Class names; position ``i`` names label ``i``.
        all_predictions: Predicted labels from the last ``predict`` call.
        all_targets: Target labels from the last ``predict`` call.
        all_probs: Softmax probabilities from the last ``predict`` call.
    """

    def __init__(
        self,
        model: nn.Module,
        test_loader: DataLoader,
        classes: List[str],
        device: Optional[torch.device] = None
    ):
        """Store the collaborators and move the model to the target device.

        Args:
            model: Model that maps an input batch to per-class scores.
            test_loader: Iterable yielding ``(inputs, targets)`` batches.
            classes: Class names; position ``i`` names label ``i``.
            device: Device to run on. When omitted, CUDA is used if
                available, otherwise the CPU.
        """
        # Resolve the device exactly once; a plain nn.Module exposes no
        # ``device`` attribute, so it cannot be read back from the model.
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = torch.device(device)
        self.model = model.to(self.device)
        self.test_loader = test_loader
        self.classes = classes
        self.all_predictions = []
        self.all_targets = []
        self.all_probs = []

    def _ensure_predictions(self) -> None:
        """Run ``predict`` when no predictions have been cached yet."""
        if len(self.all_predictions) == 0:
            self.predict()

    def _label_indices(self) -> List[int]:
        """Return the label ids ``0 .. len(classes) - 1``."""
        return list(range(len(self.classes)))

    def predict(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Run the model over the whole test loader and cache the results.

        The model is switched to eval mode and gradients are disabled.

        Returns:
            A tuple ``(predictions, targets, probabilities)`` of NumPy
            arrays. All three are empty when the loader yields no batches.
        """
        self.model.eval()
        pred_batches = []
        target_batches = []
        prob_batches = []
        with torch.no_grad():
            for inputs, targets in self.test_loader:
                outputs = self.model(inputs.to(self.device))
                probs = torch.softmax(outputs, dim=1)
                pred_batches.append(outputs.argmax(dim=1).cpu().numpy())
                # Targets may live on an accelerator; move them to the CPU
                # before converting to NumPy.
                target_batches.append(torch.as_tensor(targets).cpu().numpy())
                prob_batches.append(probs.cpu().numpy())

        if pred_batches:
            self.all_predictions = np.concatenate(pred_batches)
            self.all_targets = np.concatenate(target_batches)
            self.all_probs = np.concatenate(prob_batches)
        else:
            self.all_predictions = np.array([])
            self.all_targets = np.array([])
            self.all_probs = np.array([])
        return self.all_predictions, self.all_targets, self.all_probs

    def calculate_metrics(self) -> Dict:
        """Compute accuracy and weighted precision / recall / F1.

        Returns:
            Dict with keys ``accuracy``, ``precision``, ``recall``,
            ``f1_score`` (floats) and ``num_samples`` (int).
        """
        self._ensure_predictions()
        accuracy = accuracy_score(self.all_targets, self.all_predictions)
        precision, recall, f1, _ = precision_recall_fscore_support(
            self.all_targets, self.all_predictions, average='weighted'
        )
        return {
            'accuracy': float(accuracy),
            'precision': float(precision),
            'recall': float(recall),
            'f1_score': float(f1),
            'num_samples': len(self.all_predictions)
        }

    def get_classification_report(self) -> str:
        """Return scikit-learn's text classification report (4 digits)."""
        self._ensure_predictions()
        # Passing ``labels`` keeps ``target_names`` aligned even when some
        # class never appears in the targets or the predictions.
        return classification_report(
            self.all_targets,
            self.all_predictions,
            labels=self._label_indices(),
            target_names=self.classes,
            digits=4
        )

    def get_confusion_matrix(self) -> np.ndarray:
        """Return the confusion matrix with one row/column per class.

        Rows are target labels and columns are predicted labels. The matrix
        always has shape ``(len(classes), len(classes))``.
        """
        self._ensure_predictions()
        return confusion_matrix(
            self.all_targets,
            self.all_predictions,
            labels=self._label_indices()
        )

    def plot_confusion_matrix(self, save_path: Optional[str] = None, figsize: Tuple[int, int] = (12, 10)):
        """Draw the confusion matrix as a heatmap.

        Args:
            save_path: File to write the figure to. When omitted the figure
                is shown instead.
            figsize: Figure size passed to matplotlib.
        """
        cm = self.get_confusion_matrix()
        plt.figure(figsize=figsize)
        sns.heatmap(
            cm,
            annot=False,
            fmt='d',
            cmap='Blues',
            xticklabels=self.classes,
            yticklabels=self.classes,
            cbar_kws={'label': 'Cantidad'}
        )
        plt.xlabel('Predicción', fontsize=12)
        plt.ylabel('Real', fontsize=12)
        plt.title('Matriz de Confusión', fontsize=14)
        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=150, bbox_inches='tight')
            print(f'Matriz de confusión guardada en: {save_path}')
        else:
            plt.show()
        plt.close()

    def plot_per_class_accuracy(self, save_path: Optional[str] = None, figsize: Tuple[int, int] = (14, 6)):
        """Draw a bar chart of the per-class accuracy.

        Bars for classes below 70% accuracy are coloured differently.
        Classes without any target sample are plotted with accuracy 0.

        Args:
            save_path: File to write the figure to. When omitted the figure
                is shown instead.
            figsize: Figure size passed to matplotlib.
        """
        cm = self.get_confusion_matrix()
        support = cm.sum(axis=1)
        # Guarded division: a class with no samples would otherwise yield
        # NaN and a runtime warning.
        class_accuracies = np.divide(
            cm.diagonal(),
            support,
            out=np.zeros(len(support), dtype=float),
            where=support > 0
        )
        positions = range(len(self.classes))
        plt.figure(figsize=figsize)
        bars = plt.bar(positions, class_accuracies, color='steelblue', edgecolor='black')
        for bar, acc in zip(bars, class_accuracies):
            if acc < 0.7:
                bar.set_color('coral')
            plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                     f'{acc:.2%}', ha='center', va='bottom', fontsize=8)
        plt.xlabel('Clase', fontsize=12)
        plt.ylabel('Accuracy', fontsize=12)
        plt.title('Accuracy por Clase', fontsize=14)
        plt.xticks(positions, self.classes, rotation=45, ha='right')
        plt.ylim(0, 1.1)
        plt.grid(axis='y', alpha=0.3)
        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=150, bbox_inches='tight')
        else:
            plt.show()
        plt.close()

    def plot_top_errors(self, n: int = 20, save_path: Optional[str] = None):
        """Print the misclassified samples the model was most confident about.

        Despite its name this method only prints; it draws nothing.

        Args:
            n: Maximum number of errors to list.
            save_path: Accepted for interface compatibility; currently
                unused.
        """
        self._ensure_predictions()
        errors_idx = np.flatnonzero(self.all_predictions != self.all_targets)
        if errors_idx.size == 0:
            print('No se encontraron errores!')
            return
        # Probability the model assigned to the (wrong) class it predicted.
        error_confidence = self.all_probs[errors_idx, self.all_predictions[errors_idx]]
        top_errors = np.argsort(error_confidence)[::-1][:n]
        print(f'\nTop {n} errores con mayor confianza:')
        print('-' * 60)
        for rank, idx in enumerate(top_errors, start=1):
            error_idx = errors_idx[idx]
            print(f'{rank}. Predicción: {self.classes[self.all_predictions[error_idx]]} '
                  f'(Conf: {error_confidence[idx]:.2%}) | '
                  f'Real: {self.classes[self.all_targets[error_idx]]}')

    def save_results(self, output_dir: str):
        """Write metrics, report, raw arrays and plots into ``output_dir``.

        Files written: ``metrics.json``, ``classification_report.txt``,
        ``predictions.npy``, ``targets.npy``, ``probabilities.npy``,
        ``confusion_matrix.png`` and ``per_class_accuracy.png``.

        Args:
            output_dir: Target directory; created when missing.
        """
        os.makedirs(output_dir, exist_ok=True)
        metrics = self.calculate_metrics()
        with open(os.path.join(output_dir, 'metrics.json'), 'w', encoding='utf-8') as f:
            json.dump(metrics, f, indent=2)
        report = self.get_classification_report()
        # Explicit encoding: class names in the report may be non-ASCII.
        with open(os.path.join(output_dir, 'classification_report.txt'), 'w', encoding='utf-8') as f:
            f.write(report)
        np.save(os.path.join(output_dir, 'predictions.npy'), self.all_predictions)
        np.save(os.path.join(output_dir, 'targets.npy'), self.all_targets)
        np.save(os.path.join(output_dir, 'probabilities.npy'), self.all_probs)
        self.plot_confusion_matrix(os.path.join(output_dir, 'confusion_matrix.png'))
        self.plot_per_class_accuracy(os.path.join(output_dir, 'per_class_accuracy.png'))
        print(f'\nResultados guardados en: {output_dir}')

    def print_summary(self):
        """Print the headline metrics as a short console summary."""
        metrics = self.calculate_metrics()
        separator = '=' * 60
        acc = metrics["accuracy"]
        prec = metrics["precision"]
        rec = metrics["recall"]
        f1 = metrics["f1_score"]
        num = metrics["num_samples"]
        print('\n' + separator)
        print('RESUMEN DE EVALUACION')
        print(separator)
        print(f'Accuracy: {acc:.4f} ({acc*100:.2f}%)')
        print(f'Precision (weighted): {prec:.4f}')
        print(f'Recall (weighted): {rec:.4f}')
        print(f'F1-Score (weighted): {f1:.4f}')
        print(f'Muestras evaluadas: {num}')
        print(separator)