Spaces:
Sleeping
Sleeping
| """ | |
| Model Evaluator Module | |
| ====================== | |
| Provides comprehensive model evaluation with visualization | |
| support for confusion matrices, learning curves, and metrics. | |
| """ | |
| import os | |
| # Set environment variables before transformers import | |
| os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '3') | |
| os.environ.setdefault('TRANSFORMERS_NO_TF', '1') | |
| import json | |
| import logging | |
| from typing import Dict, List, Optional, Tuple, Any | |
| from dataclasses import dataclass | |
| import numpy as np | |
| import torch | |
| from sklearn.metrics import ( | |
| accuracy_score, | |
| precision_recall_fscore_support, | |
| confusion_matrix, | |
| classification_report, | |
| roc_curve, | |
| auc, | |
| precision_recall_curve | |
| ) | |
| import matplotlib.pyplot as plt | |
| import matplotlib | |
| matplotlib.use('Agg') # Non-interactive backend for server use | |
| import seaborn as sns | |
| logger = logging.getLogger(__name__) | |
@dataclass
class EvaluationResults:
    """Container for evaluation results.

    Missing ``@dataclass`` decorator added: the class is instantiated with
    keyword arguments (see ``ModelEvaluator.evaluate``), which requires the
    generated ``__init__``. All fields default to "empty" values so a
    partially populated result is still usable.
    """

    accuracy: float = 0.0
    precision: float = 0.0
    recall: float = 0.0
    f1: float = 0.0
    support: int = 0  # number of evaluated samples
    confusion_matrix: Optional[np.ndarray] = None
    classification_report: str = ""
    predictions: Optional[List[int]] = None
    probabilities: Optional[List[float]] = None
    true_labels: Optional[List[int]] = None

    def to_dict(self) -> dict:
        """Return the JSON-serializable scalar metrics (arrays omitted).

        Values are cast to builtin ``float``/``int`` so numpy scalar types
        coming from sklearn serialize cleanly with ``json.dump``.
        """
        return {
            "accuracy": float(self.accuracy),
            "precision": float(self.precision),
            "recall": float(self.recall),
            "f1": float(self.f1),
            "support": int(self.support),
            "classification_report": self.classification_report
        }
class ModelEvaluator:
    """
    Comprehensive model evaluation with visualization support.

    Holds an optional model/tokenizer pair (loadable later via
    :meth:`load_model`) and exposes prediction, metric computation,
    plotting, and report/artifact export helpers.
    """

    def __init__(self, model=None, tokenizer=None, label_names: Optional[List[str]] = None):
        """
        Initialize evaluator.

        Args:
            model: Trained model (optional, can be loaded later)
            tokenizer: Tokenizer (optional, can be loaded later)
            label_names: List of label names for display; defaults to
                generic binary class names.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.label_names = label_names or ["Class 0", "Class 1"]
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def load_model(self, model_path: str) -> bool:
        """
        Load model and tokenizer from path.

        Args:
            model_path: Path to saved model directory

        Returns:
            True if successful, False otherwise (failure is logged, not raised).
        """
        try:
            # Imported lazily so the module stays importable without transformers.
            from transformers import AutoTokenizer, AutoModelForSequenceClassification
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
            self.model.to(self.device)
            self.model.eval()
            logger.info(f"Model loaded from {model_path}")
            return True
        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}")
            return False

    def predict(self, texts: List[str], batch_size: int = 16,
                max_length: int = 256) -> Tuple[List[int], List[float]]:
        """
        Make predictions on a list of texts.

        Args:
            texts: List of texts to predict
            batch_size: Batch size for inference
            max_length: Maximum sequence length

        Returns:
            Tuple of (predicted class indices, positive-class probabilities)

        Raises:
            ValueError: If model or tokenizer has not been loaded.
        """
        if self.model is None or self.tokenizer is None:
            raise ValueError("Model and tokenizer must be loaded first")
        self.model.eval()
        all_predictions = []
        all_probabilities = []
        with torch.no_grad():
            for i in range(0, len(texts), batch_size):
                batch_texts = texts[i:i + batch_size]
                encodings = self.tokenizer(
                    batch_texts,
                    truncation=True,
                    padding=True,
                    max_length=max_length,
                    return_tensors="pt"
                )
                encodings = {k: v.to(self.device) for k, v in encodings.items()}
                outputs = self.model(**encodings)
                probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
                preds = torch.argmax(probs, dim=-1)
                all_predictions.extend(preds.cpu().numpy().tolist())
                # Probability of the positive class (class 1). Guard against
                # single-logit heads, where indexing column 1 would crash.
                pos_col = 1 if probs.shape[-1] > 1 else 0
                all_probabilities.extend(probs[:, pos_col].cpu().numpy().tolist())
        return all_predictions, all_probabilities

    def evaluate(self, texts: List[str], true_labels: List[int],
                 batch_size: int = 16) -> EvaluationResults:
        """
        Evaluate model on a dataset.

        Args:
            texts: List of texts
            true_labels: True labels
            batch_size: Batch size for inference

        Returns:
            EvaluationResults object
        """
        predictions, probabilities = self.predict(texts, batch_size)
        # Calculate metrics
        accuracy = accuracy_score(true_labels, predictions)
        precision, recall, f1, _ = precision_recall_fscore_support(
            true_labels, predictions, average='weighted', zero_division=0
        )
        cm = confusion_matrix(true_labels, predictions)
        # classification_report raises when target_names does not match the
        # number of classes actually observed (e.g. predictions collapse to a
        # single class) -- fall back to sklearn's default names in that case.
        observed_classes = sorted(set(true_labels) | set(predictions))
        target_names = (self.label_names
                        if len(self.label_names) == len(observed_classes)
                        else None)
        report = classification_report(
            true_labels, predictions,
            target_names=target_names,
            zero_division=0
        )
        results = EvaluationResults(
            accuracy=accuracy,
            precision=precision,
            recall=recall,
            f1=f1,
            support=len(true_labels),
            confusion_matrix=cm,
            classification_report=report,
            predictions=predictions,
            probabilities=probabilities,
            true_labels=true_labels
        )
        return results

    def plot_confusion_matrix(self, results: EvaluationResults,
                              figsize: Tuple[int, int] = (8, 6),
                              cmap: str = "Blues") -> plt.Figure:
        """
        Plot confusion matrix.

        Args:
            results: EvaluationResults object
            figsize: Figure size
            cmap: Color map

        Returns:
            Matplotlib figure

        Raises:
            ValueError: If results contain no confusion matrix.
        """
        if results.confusion_matrix is None:
            raise ValueError("Confusion matrix required; run evaluate() first")
        fig, ax = plt.subplots(figsize=figsize)
        sns.heatmap(
            results.confusion_matrix,
            annot=True,
            fmt='d',
            cmap=cmap,
            xticklabels=self.label_names,
            yticklabels=self.label_names,
            ax=ax
        )
        ax.set_xlabel('Predicted Label', fontsize=12)
        ax.set_ylabel('True Label', fontsize=12)
        ax.set_title('Confusion Matrix', fontsize=14)
        plt.tight_layout()
        return fig

    def plot_roc_curve(self, results: EvaluationResults,
                       figsize: Tuple[int, int] = (8, 6)) -> plt.Figure:
        """
        Plot ROC curve for binary classification.

        Args:
            results: EvaluationResults object
            figsize: Figure size

        Returns:
            Matplotlib figure

        Raises:
            ValueError: If probabilities or true labels are missing.
        """
        if results.probabilities is None or results.true_labels is None:
            raise ValueError("Probabilities and true labels required for ROC curve")
        fpr, tpr, thresholds = roc_curve(results.true_labels, results.probabilities)
        roc_auc = auc(fpr, tpr)
        fig, ax = plt.subplots(figsize=figsize)
        ax.plot(fpr, tpr, color='darkorange', lw=2,
                label=f'ROC curve (AUC = {roc_auc:.3f})')
        # Diagonal reference line: performance of a random classifier.
        ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--',
                label='Random classifier')
        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel('False Positive Rate', fontsize=12)
        ax.set_ylabel('True Positive Rate', fontsize=12)
        ax.set_title('Receiver Operating Characteristic (ROC) Curve', fontsize=14)
        ax.legend(loc='lower right')
        ax.grid(True, alpha=0.3)
        plt.tight_layout()
        return fig

    def plot_precision_recall_curve(self, results: EvaluationResults,
                                    figsize: Tuple[int, int] = (8, 6)) -> plt.Figure:
        """
        Plot precision-recall curve.

        Args:
            results: EvaluationResults object
            figsize: Figure size

        Returns:
            Matplotlib figure

        Raises:
            ValueError: If probabilities or true labels are missing.
        """
        if results.probabilities is None or results.true_labels is None:
            raise ValueError("Probabilities and true labels required")
        precision, recall, thresholds = precision_recall_curve(
            results.true_labels, results.probabilities
        )
        fig, ax = plt.subplots(figsize=figsize)
        ax.plot(recall, precision, color='blue', lw=2)
        ax.fill_between(recall, precision, alpha=0.2, color='blue')
        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel('Recall', fontsize=12)
        ax.set_ylabel('Precision', fontsize=12)
        ax.set_title('Precision-Recall Curve', fontsize=14)
        ax.grid(True, alpha=0.3)
        plt.tight_layout()
        return fig

    def plot_training_history(self, metrics_history: List[Dict],
                              figsize: Tuple[int, int] = (12, 4)) -> plt.Figure:
        """
        Plot training history (loss and metrics over epochs).

        Args:
            metrics_history: List of metric dictionaries; recognized keys are
                'epoch', 'train_loss', 'eval_loss', 'accuracy', and 'f1'.
            figsize: Figure size

        Returns:
            Matplotlib figure with three panels (loss, accuracy, F1).

        Raises:
            ValueError: If metrics_history is empty.
        """
        if not metrics_history:
            raise ValueError("No metrics history to plot")
        fig, axes = plt.subplots(1, 3, figsize=figsize)
        # Extract data; missing keys fall back to the entry index / zero so a
        # partial history still plots.
        epochs = [m.get('epoch', i) for i, m in enumerate(metrics_history)]
        train_loss = [m.get('train_loss', 0) for m in metrics_history]
        eval_loss = [m.get('eval_loss', 0) for m in metrics_history]
        accuracy = [m.get('accuracy', 0) for m in metrics_history]
        f1 = [m.get('f1', 0) for m in metrics_history]
        # Loss plot (only draw series that contain a nonzero value)
        if any(train_loss):
            axes[0].plot(epochs, train_loss, 'b-', label='Train Loss', marker='o')
        if any(eval_loss):
            axes[0].plot(epochs, eval_loss, 'r-', label='Eval Loss', marker='s')
        axes[0].set_xlabel('Epoch')
        axes[0].set_ylabel('Loss')
        axes[0].set_title('Training & Validation Loss')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)
        # Accuracy plot
        if any(accuracy):
            axes[1].plot(epochs, accuracy, 'g-', marker='o')
        axes[1].set_xlabel('Epoch')
        axes[1].set_ylabel('Accuracy')
        axes[1].set_title('Accuracy over Training')
        axes[1].grid(True, alpha=0.3)
        # F1 score plot
        if any(f1):
            axes[2].plot(epochs, f1, 'm-', marker='o')
        axes[2].set_xlabel('Epoch')
        axes[2].set_ylabel('F1 Score')
        axes[2].set_title('F1 Score over Training')
        axes[2].grid(True, alpha=0.3)
        plt.tight_layout()
        return fig

    def plot_class_distribution(self, labels: List[int],
                                figsize: Tuple[int, int] = (8, 5)) -> plt.Figure:
        """
        Plot class distribution in dataset.

        Args:
            labels: List of labels
            figsize: Figure size

        Returns:
            Matplotlib figure
        """
        unique, counts = np.unique(labels, return_counts=True)
        fig, ax = plt.subplots(figsize=figsize)
        colors = plt.cm.Set3(np.linspace(0, 1, len(unique)))
        # Fall back to a generic name for class ids beyond label_names.
        bars = ax.bar(
            [self.label_names[i] if i < len(self.label_names) else f"Class {i}"
             for i in unique],
            counts,
            color=colors
        )
        # Add value labels on bars
        for bar, count in zip(bars, counts):
            height = bar.get_height()
            ax.annotate(f'{count}',
                        xy=(bar.get_x() + bar.get_width() / 2, height),
                        xytext=(0, 3),
                        textcoords="offset points",
                        ha='center', va='bottom',
                        fontsize=12, fontweight='bold')
        ax.set_xlabel('Class', fontsize=12)
        ax.set_ylabel('Count', fontsize=12)
        ax.set_title('Class Distribution', fontsize=14)
        ax.grid(True, alpha=0.3, axis='y')
        plt.tight_layout()
        return fig

    def generate_report(self, results: EvaluationResults,
                        output_path: Optional[str] = None) -> str:
        """
        Generate a text report of evaluation results.

        Args:
            results: EvaluationResults object
            output_path: Optional path to save the report

        Returns:
            Report string
        """
        report = []
        report.append("=" * 60)
        report.append("MODEL EVALUATION REPORT")
        report.append("=" * 60)
        report.append("")
        report.append("OVERALL METRICS:")
        report.append(f"  Accuracy:  {results.accuracy:.4f} ({results.accuracy*100:.2f}%)")
        report.append(f"  Precision: {results.precision:.4f}")
        report.append(f"  Recall:    {results.recall:.4f}")
        report.append(f"  F1 Score:  {results.f1:.4f}")
        report.append(f"  Samples:   {results.support}")
        report.append("")
        report.append("CLASSIFICATION REPORT:")
        report.append(results.classification_report)
        report.append("")
        report.append("CONFUSION MATRIX:")
        if results.confusion_matrix is not None:
            report.append(str(results.confusion_matrix))
        report.append("")
        report.append("=" * 60)
        report_str = "\n".join(report)
        if output_path:
            with open(output_path, 'w', encoding='utf-8') as f:
                f.write(report_str)
            logger.info(f"Report saved to {output_path}")
        return report_str

    def save_results(self, results: EvaluationResults, output_dir: str):
        """
        Save evaluation results to files (JSON metrics, confusion-matrix
        image, and text report).

        Args:
            results: EvaluationResults object
            output_dir: Output directory (created if missing)
        """
        os.makedirs(output_dir, exist_ok=True)
        # Save metrics as JSON
        metrics_path = os.path.join(output_dir, "evaluation_metrics.json")
        with open(metrics_path, 'w', encoding='utf-8') as f:
            json.dump(results.to_dict(), f, indent=2, ensure_ascii=False)
        # Save confusion matrix as image; plotting failure is non-fatal.
        try:
            fig = self.plot_confusion_matrix(results)
            fig.savefig(os.path.join(output_dir, "confusion_matrix.png"), dpi=150)
            plt.close(fig)
        except Exception as e:
            logger.warning(f"Could not save confusion matrix: {e}")
        # Save text report
        report_path = os.path.join(output_dir, "evaluation_report.txt")
        self.generate_report(results, report_path)
        logger.info(f"Results saved to {output_dir}")
def create_evaluator(label_names: Optional[List[str]] = None) -> ModelEvaluator:
    """Factory function to create a ModelEvaluator instance.

    Args:
        label_names: Optional display names for the labels; when omitted the
            evaluator falls back to generic binary class names.

    Returns:
        A new ModelEvaluator with no model loaded yet.
    """
    return ModelEvaluator(label_names=label_names)