""" Threshold Optimization for Multi-Label Classification This module provides functions to optimize decision thresholds for multi-label classification tasks to maximize F1-score (or other metrics). In multi-label classification, the default threshold of 0.5 for converting probabilities to binary predictions is often suboptimal, especially for imbalanced classes. This module finds optimal thresholds per-class or globally. Designed to work with Random Forest (baseline and improved models). Usage: from threshold_optimization import optimize_thresholds, apply_thresholds from sklearn.ensemble import RandomForestClassifier # Train Random Forest model = RandomForestClassifier(n_estimators=100) model.fit(X_train, y_train) # Get probability predictions y_proba = model.predict_proba(X_val) # Find optimal thresholds on validation set thresholds = optimize_thresholds(y_val, y_proba, method='per_class') # Apply thresholds to test set y_pred = apply_thresholds(model.predict_proba(X_test), thresholds) """ from typing import Dict, Tuple, Union import warnings import numpy as np from sklearn.metrics import f1_score def optimize_thresholds( y_true: np.ndarray, y_proba: np.ndarray, method: str = "per_class", metric: str = "f1_weighted", search_range: Tuple[float, float] = (0.1, 0.9), n_steps: int = 50, ) -> Union[float, np.ndarray]: """ Optimize decision thresholds to maximize a given metric. This function searches for optimal thresholds that convert probability predictions to binary predictions (0/1) in a way that maximizes the specified metric (default: weighted F1-score). Args: y_true: True binary labels, shape (n_samples, n_labels) y_proba: Predicted probabilities, shape (n_samples, n_labels) method: Threshold optimization method: - 'global': Single threshold for all classes - 'per_class': One threshold per class (default, recommended) metric: Metric to optimize ('f1_weighted', 'f1_macro', 'f1_micro') search_range: Range of thresholds to search (min, max) n_steps: Number of threshold values to try Returns: - If method='global': Single float threshold - If method='per_class': Array of thresholds, one per class Example: >>> y_true = np.array([[1, 0, 1], [0, 1, 0], [1, 1, 0]]) >>> y_proba = np.array([[0.9, 0.3, 0.7], [0.2, 0.8, 0.4], [0.85, 0.6, 0.3]]) >>> thresholds = optimize_thresholds(y_true, y_proba, method='per_class') >>> print(thresholds) # Array of 3 thresholds, one per class """ if y_true.shape != y_proba.shape: raise ValueError(f"Shape mismatch: y_true {y_true.shape} vs y_proba {y_proba.shape}") if method == "global": return _optimize_global_threshold(y_true, y_proba, metric, search_range, n_steps) elif method == "per_class": return _optimize_per_class_thresholds(y_true, y_proba, metric, search_range, n_steps) else: raise ValueError(f"Invalid method: {method}. Must be 'global' or 'per_class'") def _optimize_global_threshold( y_true: np.ndarray, y_proba: np.ndarray, metric: str, search_range: Tuple[float, float], n_steps: int, ) -> float: """ Find single optimal threshold for all classes. This approach is faster but less flexible than per-class optimization. Useful when classes have similar distributions. """ thresholds_to_try = np.linspace(search_range[0], search_range[1], n_steps) best_threshold = 0.5 best_score = -np.inf for threshold in thresholds_to_try: y_pred = (y_proba >= threshold).astype(int) score = _compute_score(y_true, y_pred, metric) if score > best_score: best_score = score best_threshold = threshold print(f"Optimal global threshold: {best_threshold:.3f} (score: {best_score:.4f})") return best_threshold def _optimize_per_class_thresholds( y_true: np.ndarray, y_proba: np.ndarray, metric: str, search_range: Tuple[float, float], n_steps: int, ) -> np.ndarray: """ Find optimal threshold for each class independently. This approach is more flexible and typically yields better results for imbalanced multi-label problems, but is slower. """ n_classes = y_true.shape[1] optimal_thresholds = np.zeros(n_classes) thresholds_to_try = np.linspace(search_range[0], search_range[1], n_steps) print(f"Optimizing thresholds for {n_classes} classes...") for class_idx in range(n_classes): y_true_class = y_true[:, class_idx] y_proba_class = y_proba[:, class_idx] # Skip classes with no positive samples if y_true_class.sum() == 0: optimal_thresholds[class_idx] = 0.5 warnings.warn( f"Class {class_idx} has no positive samples, using default threshold 0.5" ) continue best_threshold = 0.5 best_score = -np.inf for threshold in thresholds_to_try: y_pred_class = (y_proba_class >= threshold).astype(int) # Compute binary F1 for this class try: score = f1_score(y_true_class, y_pred_class, average="binary", zero_division=0) except Exception: continue if score > best_score: best_score = score best_threshold = threshold optimal_thresholds[class_idx] = best_threshold print( f"Threshold statistics: min={optimal_thresholds.min():.3f}, " f"max={optimal_thresholds.max():.3f}, mean={optimal_thresholds.mean():.3f}" ) return optimal_thresholds def _compute_score(y_true: np.ndarray, y_pred: np.ndarray, metric: str) -> float: """Compute the specified metric.""" if metric == "f1_weighted": return f1_score(y_true, y_pred, average="weighted", zero_division=0) elif metric == "f1_macro": return f1_score(y_true, y_pred, average="macro", zero_division=0) elif metric == "f1_micro": return f1_score(y_true, y_pred, average="micro", zero_division=0) else: raise ValueError(f"Unsupported metric: {metric}") def apply_thresholds(y_proba: np.ndarray, thresholds: Union[float, np.ndarray]) -> np.ndarray: """ Apply thresholds to probability predictions to get binary predictions. Args: y_proba: Predicted probabilities, shape (n_samples, n_labels) thresholds: Threshold(s) to apply: - Single float: same threshold for all classes - Array: one threshold per class Returns: Binary predictions, shape (n_samples, n_labels) Example: >>> y_proba = np.array([[0.9, 0.3, 0.7], [0.2, 0.8, 0.4]]) >>> thresholds = np.array([0.5, 0.4, 0.6]) >>> y_pred = apply_thresholds(y_proba, thresholds) >>> print(y_pred) [[1 0 1] [0 1 0]] """ if isinstance(thresholds, float): # Global threshold return (y_proba >= thresholds).astype(int) else: # Per-class thresholds if len(thresholds) != y_proba.shape[1]: raise ValueError( f"Number of thresholds ({len(thresholds)}) must match " f"number of classes ({y_proba.shape[1]})" ) # Broadcasting: compare each column with its threshold return (y_proba >= thresholds[np.newaxis, :]).astype(int) def evaluate_with_thresholds( model, X_val: np.ndarray, y_val: np.ndarray, X_test: np.ndarray, y_test: np.ndarray, method: str = "per_class", ) -> Dict: """ Complete workflow: optimize thresholds on validation set and evaluate on test set. This function encapsulates the entire threshold optimization pipeline: 1. Get probability predictions on validation set 2. Optimize thresholds using validation data 3. Apply optimized thresholds to test set 4. Compare with default threshold (0.5) Args: model: Trained model with predict_proba method X_val: Validation features y_val: Validation labels (binary) X_test: Test features y_test: Test labels (binary) method: 'global' or 'per_class' Returns: Dictionary with results: - 'thresholds': Optimized thresholds - 'f1_default': F1-score with default threshold (0.5) - 'f1_optimized': F1-score with optimized thresholds - 'improvement': Absolute improvement in F1-score Example: >>> results = evaluate_with_thresholds(model, X_val, y_val, X_test, y_test) >>> print(f"F1 improvement: {results['improvement']:.4f}") """ # Get probability predictions print("Getting probability predictions on validation set...") y_val_proba = model.predict_proba(X_val) # Handle MultiOutputClassifier (returns list of arrays) if isinstance(y_val_proba, list): y_val_proba = np.column_stack([proba[:, 1] for proba in y_val_proba]) # Optimize thresholds print(f"Optimizing thresholds ({method})...") thresholds = optimize_thresholds(y_val, y_val_proba, method=method) # Evaluate on test set print("Evaluating on test set...") y_test_proba = model.predict_proba(X_test) # Handle MultiOutputClassifier if isinstance(y_test_proba, list): y_test_proba = np.column_stack([proba[:, 1] for proba in y_test_proba]) # Default predictions (threshold=0.5) y_test_pred_default = (y_test_proba >= 0.5).astype(int) f1_default = f1_score(y_test, y_test_pred_default, average="weighted", zero_division=0) # Optimized predictions y_test_pred_optimized = apply_thresholds(y_test_proba, thresholds) f1_optimized = f1_score(y_test, y_test_pred_optimized, average="weighted", zero_division=0) improvement = f1_optimized - f1_default print("\nResults:") print(f" F1-score (default threshold=0.5): {f1_default:.4f}") print(f" F1-score (optimized thresholds): {f1_optimized:.4f}") print(f" Improvement: {improvement:+.4f} ({improvement / f1_default * 100:+.2f}%)") return { "thresholds": thresholds, "f1_default": f1_default, "f1_optimized": f1_optimized, "improvement": improvement, "y_pred_optimized": y_test_pred_optimized, }