import torch import numpy as np from sklearn.metrics import ( r2_score, mean_squared_error, mean_absolute_error, f1_score, precision_score, recall_score, roc_auc_score, precision_recall_curve, auc, matthews_corrcoef, confusion_matrix, hamming_loss, accuracy_score, make_scorer, ) from scipy.stats import pearsonr, spearmanr from transformers import EvalPrediction def softmax(x: np.ndarray) -> np.ndarray: return np.exp(x) / np.sum(np.exp(x), axis=-1, keepdims=True) def regression_scorer(): def dual_score(y_true, y_pred): return spearmanr(y_true, y_pred).correlation * r2_score(y_true, y_pred) return dual_score def classification_scorer(): def mcc_scorer(y_true, y_pred): return matthews_corrcoef(y_true, y_pred) return mcc_scorer def get_classification_scorer(): return make_scorer(classification_scorer(), greater_is_better=True) def get_regression_scorer(): return make_scorer(regression_scorer(), greater_is_better=True) def calculate_max_metrics(ss: torch.Tensor, labels: torch.Tensor, cutoff: float) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Calculate precision, recall and F1 metrics for binary classification at a specific cutoff threshold. Args: ss: Prediction scores tensor, typically between -1 and 1 labels: Ground truth binary labels tensor (0 or 1) cutoff: Classification threshold value Returns: Tuple containing: - F1 score (torch.Tensor) - Precision score (torch.Tensor) - Recall score (torch.Tensor) Note: - Input tensors are converted to float type - Handles division by zero cases by returning 0 - Uses standard binary classification metrics formulas: - Precision = TP / (TP + FP) - Recall = TP / (TP + FN) - F1 = 2 * (Precision * Recall) / (Precision + Recall) """ ss, labels = ss.float(), labels.float() tp = torch.sum((ss >= cutoff) & (labels == 1.0)) fp = torch.sum((ss >= cutoff) & (labels == 0.0)) fn = torch.sum((ss < cutoff) & (labels == 1.0)) precision_denominator = tp + fp precision = torch.where(precision_denominator != 0, tp / precision_denominator, torch.tensor(0.0)) recall_denominator = tp + fn recall = torch.where(recall_denominator != 0, tp / recall_denominator, torch.tensor(0.0)) f1 = torch.where((precision + recall) != 0, (2 * precision * recall) / (precision + recall), torch.tensor(0.0)) return f1, precision, recall def max_metrics(ss: torch.Tensor, labels: torch.Tensor, increment: float = 0.01) -> tuple[float, float, float, float]: """ Find optimal classification metrics by scanning different cutoff thresholds. Optimized version that vectorizes calculations across all cutoffs. Args: ss: Prediction scores tensor, typically between -1 and 1 labels: Ground truth binary labels tensor (0 or 1) increment: Step size for scanning cutoff values, defaults to 0.01 Returns: Tuple containing: - Maximum F1 score (float) - Maximum precision score (float) - Maximum recall score (float) - Optimal cutoff threshold (float) Note: - Input scores are clamped to [-1, 1] range - Handles edge case where all scores are >= 1 - Scans cutoff values from min score to 1 in increments - Handles NaN F1 scores by replacing with -1 before finding max - Returns metrics at the threshold that maximizes F1 score - Optimized to compute metrics for all cutoffs in parallel using vectorization """ # Handle NaNs by replacing with 0.0 ss = torch.nan_to_num(ss, nan=0.0) ss = torch.clamp(ss, -1.0, 1.0) min_val = ss.min().item() max_val = 1 if min_val >= max_val: min_val = 0 # Convert to float and ensure labels are binary ss = ss.float() labels = labels.float() # Create cutoff tensor cutoffs = torch.arange(min_val, max_val, increment, device=ss.device, dtype=ss.dtype) n_cutoffs = len(cutoffs) if n_cutoffs == 0: # Edge case: no cutoffs to test return 0.0, 0.0, 0.0, min_val # Vectorize across all cutoffs: shape (n_cutoffs, n_samples) # Expand cutoffs to (n_cutoffs, 1) and ss to (1, n_samples) for broadcasting ss_expanded = ss.unsqueeze(0) # (1, n_samples) cutoffs_expanded = cutoffs.unsqueeze(1) # (n_cutoffs, 1) labels_expanded = labels.unsqueeze(0) # (1, n_samples) # Compute predictions for all cutoffs at once: (n_cutoffs, n_samples) predictions = (ss_expanded >= cutoffs_expanded).float() # Compute TP, FP, FN for all cutoffs simultaneously # TP: predicted positive and label positive tp = torch.sum(predictions * labels_expanded, dim=1) # (n_cutoffs,) # FP: predicted positive but label negative fp = torch.sum(predictions * (1.0 - labels_expanded), dim=1) # (n_cutoffs,) # FN: predicted negative but label positive fn = torch.sum((1.0 - predictions) * labels_expanded, dim=1) # (n_cutoffs,) # Compute precision, recall, F1 for all cutoffs precision_denominator = tp + fp precision = torch.where(precision_denominator != 0, tp / precision_denominator, torch.tensor(0.0, device=ss.device)) recall_denominator = tp + fn recall = torch.where(recall_denominator != 0, tp / recall_denominator, torch.tensor(0.0, device=ss.device)) # Compute F1 scores f1_denominator = precision + recall f1s = torch.where(f1_denominator != 0, (2 * precision * recall) / f1_denominator, torch.tensor(0.0, device=ss.device)) # Handle NaN values by replacing with -1 valid_f1s = torch.where(torch.isnan(f1s), torch.tensor(-1.0, device=ss.device), f1s) max_index = torch.argmax(valid_f1s) return f1s[max_index].item(), precision[max_index].item(), recall[max_index].item(), cutoffs[max_index].item() def calculate_robust_roc_auc_multiclass(y_true: np.ndarray, probs: np.ndarray) -> float: """ Robust ROC AUC for multi-class (single-label) tasks. Handles missing classes in y_true by ignoring them in the weighted average. """ # Check for NaNs in probs if np.isnan(probs).any(): probs = np.nan_to_num(probs, nan=0.0) n_classes = probs.shape[1] try: if n_classes == 2: if len(np.unique(y_true)) == 2: return roc_auc_score(y_true, probs[:, 1]) return -100.0 y_true_onehot = np.eye(n_classes)[y_true] scores = [] weights = [] for i in range(n_classes): # Only calculate if both positive and negative samples exist if len(np.unique(y_true_onehot[:, i])) == 2: scores.append(roc_auc_score(y_true_onehot[:, i], probs[:, i])) weights.append(np.sum(y_true_onehot[:, i])) if not scores: return -100.0 return float(np.average(scores, weights=weights)) except Exception: return -100.0 def calculate_robust_pr_auc_multiclass(y_true: np.ndarray, probs: np.ndarray) -> float: """ Robust PR AUC for multi-class (single-label) tasks. """ # Check for NaNs in probs if np.isnan(probs).any(): probs = np.nan_to_num(probs, nan=0.0) n_classes = probs.shape[1] try: if n_classes == 2: if len(np.unique(y_true)) == 2: precision, recall, _ = precision_recall_curve(y_true, probs[:, 1]) return auc(recall, precision) return -100.0 y_true_onehot = np.eye(n_classes)[y_true] scores = [] weights = [] for i in range(n_classes): if len(np.unique(y_true_onehot[:, i])) == 2: precision, recall, _ = precision_recall_curve(y_true_onehot[:, i], probs[:, i]) scores.append(auc(recall, precision)) weights.append(np.sum(y_true_onehot[:, i])) if not scores: return -100.0 return float(np.average(scores, weights=weights)) except Exception: return -100.0 def calculate_robust_roc_auc_multilabel(y_true: np.ndarray, probs: np.ndarray) -> float: """ Robust ROC AUC for multi-label tasks (macro average). """ if np.isnan(probs).any(): probs = np.nan_to_num(probs, nan=0.0) scores = [] try: for i in range(y_true.shape[1]): if len(np.unique(y_true[:, i])) == 2: scores.append(roc_auc_score(y_true[:, i], probs[:, i])) if not scores: return -100.0 return float(np.mean(scores)) except Exception: return -100.0 def calculate_robust_pr_auc_multilabel(y_true: np.ndarray, probs: np.ndarray) -> float: """ Robust PR AUC for multi-label tasks (macro average). """ if np.isnan(probs).any(): probs = np.nan_to_num(probs, nan=0.0) scores = [] try: for i in range(y_true.shape[1]): if len(np.unique(y_true[:, i])) == 2: precision, recall, _ = precision_recall_curve(y_true[:, i], probs[:, i]) scores.append(auc(recall, precision)) if not scores: return -100.0 return float(np.mean(scores)) except Exception: return -100.0 def compute_single_label_classification_metrics(p: EvalPrediction) -> dict[str, float]: """ Compute comprehensive metrics for single-label classification tasks. Args: p: EvalPrediction object containing model predictions and ground truth labels Returns: Dictionary with the following metrics (all rounded to 5 decimal places): - f1: F1 score (weighted average) - precision: Precision score (weighted average) - recall: Recall score (weighted average) - accuracy: Overall accuracy - mcc: Matthews Correlation Coefficient - roc_auc: Area Under ROC Curve (weighted average) - pr_auc: Area Under Precision-Recall Curve (weighted average) Note: - Handles both binary and multi-class cases - For binary case: uses 0.5 threshold on probabilities - For multi-class: uses argmax for class prediction - Prints confusion matrix for detailed error analysis - Uses weighted averaging for multi-class metrics - Handles AUC calculation for both binary and multi-class cases """ logits = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions labels = p.label_ids[1] if isinstance(p.label_ids, tuple) else p.label_ids y_pred = logits.argmax(axis=-1).flatten() y_true = labels.flatten().astype(int) probs = softmax(logits) # Calculate ROC AUC roc_auc = calculate_robust_roc_auc_multiclass(y_true, probs) # Calculate PR AUC (true AUC of Precision-Recall curve) pr_auc = calculate_robust_pr_auc_multiclass(y_true, probs) cm = confusion_matrix(y_true, y_pred) print("\nConfusion Matrix:") print(cm) f1 = f1_score(y_true, y_pred, average='macro', zero_division=0) precision = precision_score(y_true, y_pred, average='macro', zero_division=0) recall = recall_score(y_true, y_pred, average='macro', zero_division=0) accuracy = accuracy_score(y_true, y_pred) mcc = matthews_corrcoef(y_true, y_pred) return { 'f1': round(f1, 5), 'precision': round(precision, 5), 'recall': round(recall, 5), 'accuracy': round(accuracy, 5), 'mcc': round(mcc, 5), 'roc_auc': round(roc_auc, 5), 'pr_auc': round(pr_auc, 5) } def compute_tokenwise_classification_metrics(p: EvalPrediction) -> dict[str, float]: """ Compute metrics for token-level classification tasks. Args: p: EvalPrediction object containing model predictions and ground truth labels Returns: Dictionary containing the following metrics (all rounded to 5 decimal places): - accuracy: Overall accuracy - f1: F1 score (macro average) - precision: Precision score (macro average) - recall: Recall score (macro average) - mcc: Matthews Correlation Coefficient - roc_auc: Area Under ROC Curve (weighted average) - pr_auc: Area Under Precision-Recall Curve (weighted average) Note: - Handles special token padding (-100) by filtering before metric calculation - Uses macro averaging for multi-class metrics - Converts predictions to class labels using argmax - Handles AUC calculation for both binary and multi-class cases """ logits = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions labels = p.label_ids # Compute f1 score y_pred = logits.argmax(axis=-1).flatten() y_true = labels.flatten() valid_indices = y_true != -100 y_pred = y_pred[valid_indices] y_true = y_true[valid_indices] cm = confusion_matrix(y_true, y_pred) print("\nConfusion Matrix:") print(cm) f1 = f1_score(y_true, y_pred, average='macro', zero_division=0) precision = precision_score(y_true, y_pred, average='macro', zero_division=0) recall = recall_score(y_true, y_pred, average='macro', zero_division=0) accuracy = accuracy_score(y_true, y_pred) mcc = matthews_corrcoef(y_true, y_pred) # Calculate probabilities for AUC metrics probs = softmax(logits) probs = probs.reshape(-1, probs.shape[-1]) # Flatten to (n_samples, n_classes) probs = probs[valid_indices] # Filter by valid indices # Calculate ROC AUC roc_auc = calculate_robust_roc_auc_multiclass(y_true, probs) # Calculate PR AUC (true AUC of Precision-Recall curve) pr_auc = calculate_robust_pr_auc_multiclass(y_true, probs) return { 'accuracy': round(accuracy, 5), 'f1': round(f1, 5), 'precision': round(precision, 5), 'recall': round(recall, 5), 'mcc': round(mcc, 5), 'roc_auc': round(roc_auc, 5), 'pr_auc': round(pr_auc, 5) } def compute_multi_label_classification_metrics(p: EvalPrediction) -> dict[str, float]: """ Compute comprehensive metrics for multi-label classification tasks. Args: p: EvalPrediction object containing model predictions and ground truth labels Returns: Dictionary containing the following metrics (all rounded to 5 decimal places): - accuracy: Overall accuracy - f1: F1 score (optimized across thresholds) - precision: Precision score (at optimal threshold) - recall: Recall score (at optimal threshold) - hamming_loss: Proportion of wrong labels - threshold: Optimal classification threshold - mcc: Matthews Correlation Coefficient - roc_auc: Area Under ROC Curve (macro average) - pr_auc: Area Under Precision-Recall Curve (macro average) Note: - Converts inputs to PyTorch tensors - Applies softmax to raw predictions - Uses threshold optimization for best F1 score - Handles multi-class ROC AUC using one-vs-rest - All metrics are computed on flattened predictions """ preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions labels = p.label_ids[1] if isinstance(p.label_ids, tuple) else p.label_ids # Convert to tensors efficiently, avoiding unnecessary numpy round-trip if not isinstance(preds, torch.Tensor): preds = torch.tensor(preds) if not isinstance(labels, torch.Tensor): y_true = torch.tensor(labels, dtype=torch.int) else: y_true = labels.int() probs = preds.sigmoid() y_pred = (probs > 0.5).int() # Flatten before max_metrics for efficiency - max_metrics expects flattened tensors probs_flat = probs.flatten() y_true_flat = y_true.flatten() f1, prec, recall, thres = max_metrics(probs_flat, y_true_flat) y_pred_flat, y_true_flat = y_pred.flatten().numpy(), y_true.flatten().numpy() accuracy = accuracy_score(y_pred_flat, y_true_flat) hamming = hamming_loss(y_pred_flat, y_true_flat) mcc = matthews_corrcoef(y_true_flat, y_pred_flat) # Calculate ROC AUC for multilabel case # Use unflattened arrays for macro averaging roc_auc = calculate_robust_roc_auc_multilabel(y_true.numpy(), probs.numpy()) # Calculate PR AUC for multilabel case (true AUC of Precision-Recall curve) pr_auc = calculate_robust_pr_auc_multilabel(y_true.numpy(), probs.numpy()) return { 'accuracy': round(accuracy, 5), 'f1': round(f1, 5), 'precision': round(prec, 5), 'recall': round(recall, 5), 'hamming_loss': round(hamming, 5), 'threshold': round(thres, 5), 'mcc': round(mcc, 5), 'roc_auc': round(roc_auc, 5), 'pr_auc': round(pr_auc, 5) } def compute_regression_metrics(p: EvalPrediction) -> dict[str, float]: """ Compute comprehensive metrics for regression tasks. Args: p: EvalPrediction object containing model predictions and ground truth values Returns: Dictionary containing the following metrics (all rounded to 5 decimal places): - r_squared: Coefficient of determination (R²) - spearman_rho: Spearman rank correlation coefficient - spear_pval: P-value for Spearman correlation - pearson_rho: Pearson correlation coefficient - pear_pval: P-value for Pearson correlation - mse: Mean Squared Error - mae: Mean Absolute Error - rmse: Root Mean Squared Error Note: - Handles both raw predictions and tuple predictions - Flattens inputs to 1D arrays - Includes both correlation and error metrics - P-values indicate statistical significance of correlations - RMSE is calculated as square root of MSE """ preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions labels = p.label_ids[1] if isinstance(p.label_ids, tuple) else p.label_ids y_pred = np.array(preds).flatten() y_true = np.array(labels).flatten() if np.isnan(y_true).any(): print("y_true Nans were cast to 0") y_true = np.where(np.isnan(y_true), 0, y_true) if np.isnan(y_pred).any(): print("y_pred Nans were cast to 0") y_pred = np.where(np.isnan(y_pred), 0, y_pred) try: spearman_rho, spear_pval = spearmanr(y_pred, y_true) pearson_rho, pear_pval = pearsonr(y_pred, y_true) except: spearman_rho = -100.0 spear_pval = -100.0 pearson_rho = -100.0 pear_pval = -100.0 r2 = r2_score(y_true, y_pred) mse = mean_squared_error(y_true, y_pred) mae = mean_absolute_error(y_true, y_pred) rmse = np.sqrt(mse) return { 'r_squared': round(r2, 5), 'spearman_rho': round(spearman_rho, 5), 'spear_pval': round(spear_pval, 5), 'pearson_rho': round(pearson_rho, 5), 'pear_pval': round(pear_pval, 5), 'mse': round(mse, 5), 'mae': round(mae, 5), 'rmse': round(rmse, 5), } def compute_tokenwise_regression_metrics(p: EvalPrediction) -> dict[str, float]: """ Compute regression metrics tokenwise, ignoring label positions equal to -100. Compatible with HF Trainer `compute_metrics` API. """ preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions labels = p.label_ids[1] if isinstance(p.label_ids, tuple) else p.label_ids y_pred = np.array(preds) y_true = np.array(labels) # If predictions have an extra trailing dim of size 1, squeeze it if y_pred.ndim == y_true.ndim + 1 and y_pred.shape[-1] == 1: y_pred = np.squeeze(y_pred, axis=-1) # Flatten to align and filter by valid positions (labels != -100) valid_mask = (y_true != -100) y_true = y_true[valid_mask].astype(float) y_pred = y_pred[valid_mask].astype(float) if y_true.size == 0: return { 'r_squared': -100.0, 'spearman_rho': -100.0, 'spear_pval': -100.0, 'pearson_rho': -100.0, 'pear_pval': -100.0, 'mse': -100.0, 'mae': -100.0, 'rmse': -100.0, } if np.isnan(y_true).any(): print("y_true Nans were cast to 0") y_true = np.where(np.isnan(y_true), 0, y_true) if np.isnan(y_pred).any(): print("y_pred Nans were cast to 0") y_pred = np.where(np.isnan(y_pred), 0, y_pred) try: spearman_rho, spear_pval = spearmanr(y_pred, y_true) pearson_rho, pear_pval = pearsonr(y_pred, y_true) except Exception: spearman_rho = -100.0 spear_pval = -100.0 pearson_rho = -100.0 pear_pval = -100.0 r2 = r2_score(y_true, y_pred) mse = mean_squared_error(y_true, y_pred) mae = mean_absolute_error(y_true, y_pred) rmse = np.sqrt(mse) return { 'r_squared': round(float(r2), 5), 'spearman_rho': round(float(spearman_rho), 5), 'spear_pval': round(float(spear_pval), 5), 'pearson_rho': round(float(pearson_rho), 5), 'pear_pval': round(float(pear_pval), 5), 'mse': round(float(mse), 5), 'mae': round(float(mae), 5), 'rmse': round(float(rmse), 5), } def get_compute_metrics(task_type: str, tokenwise: bool = False): if task_type == 'singlelabel': compute_metrics = compute_single_label_classification_metrics elif task_type == 'multilabel': compute_metrics = compute_multi_label_classification_metrics elif task_type == 'sigmoid_regression': # Treat sigmoid_regression like regression for metrics compute_metrics = compute_tokenwise_regression_metrics if tokenwise else compute_regression_metrics elif not task_type == 'regression' and tokenwise: compute_metrics = compute_tokenwise_classification_metrics elif task_type == 'regression' and not tokenwise: compute_metrics = compute_regression_metrics elif task_type == 'regression' and tokenwise: compute_metrics = compute_tokenwise_regression_metrics else: raise ValueError(f'Task type {task_type} not supported') return compute_metrics if __name__ == "__main__": # py -m metrics print("Running tests for metrics functions...") # Test compute_single_label_classification_metrics print("\n--- compute_single_label_classification_metrics (Binary) ---") # 2 samples, 2 classes. # Logits: Sample 0 -> class 0 (high, low), Sample 1 -> class 1 (low, high) predictions = np.array([[2.0, -1.0], [-1.0, 2.0]]) label_ids = np.array([0, 1]) p = EvalPrediction(predictions=predictions, label_ids=label_ids) metrics = compute_single_label_classification_metrics(p) print(metrics) print("\n--- compute_single_label_classification_metrics (Multi-class) ---") # 3 samples, 3 classes. predictions = np.array([[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0]]) label_ids = np.array([0, 1, 2]) p = EvalPrediction(predictions=predictions, label_ids=label_ids) metrics = compute_single_label_classification_metrics(p) print(metrics) # Test compute_tokenwise_classification_metrics print("\n--- compute_tokenwise_classification_metrics ---") # 1 sample, 3 tokens, 2 classes. # Token 0: pred 0, label 0 # Token 1: pred 1, label 1 # Token 2: pred 0, label -100 (ignored) predictions = np.array([[[2.0, -1.0], [-1.0, 2.0], [2.0, -1.0]]]) label_ids = np.array([[0, 1, -100]]) p = EvalPrediction(predictions=predictions, label_ids=label_ids) metrics = compute_tokenwise_classification_metrics(p) print(metrics) # Test compute_multi_label_classification_metrics print("\n--- compute_multi_label_classification_metrics ---") # 2 samples, 3 classes # Sample 0: pred [1, 0, 1], label [1, 0, 1] # Sample 1: pred [0, 1, 0], label [0, 1, 0] # Logits need to be high for 1, low for 0. predictions = np.array([[5.0, -5.0, 5.0], [-5.0, 5.0, -5.0]]) label_ids = np.array([[1, 0, 1], [0, 1, 0]]) p = EvalPrediction(predictions=predictions, label_ids=label_ids) metrics = compute_multi_label_classification_metrics(p) print(metrics) # Test compute_regression_metrics print("\n--- compute_regression_metrics ---") predictions = np.array([1.0, 2.0, 3.0]) label_ids = np.array([1.1, 1.9, 3.2]) p = EvalPrediction(predictions=predictions, label_ids=label_ids) metrics = compute_regression_metrics(p) print(metrics) # Test compute_tokenwise_regression_metrics print("\n--- compute_tokenwise_regression_metrics ---") # 1 sample, 3 tokens # Token 2 is ignored (-100) predictions = np.array([[1.0, 2.0, 5.0]]) label_ids = np.array([[1.1, 1.9, -100.0]]) p = EvalPrediction(predictions=predictions, label_ids=label_ids) metrics = compute_tokenwise_regression_metrics(p) print(metrics)