# Export artifact (repository metadata, not code): author "DaCrow13",
# commit message "Deploy to HF Spaces (Clean)", commit 225af6a.
"""
Threshold Optimization for Multi-Label Classification
This module provides functions to optimize decision thresholds for multi-label
classification tasks to maximize F1-score (or other metrics).
In multi-label classification, the default threshold of 0.5 for converting
probabilities to binary predictions is often suboptimal, especially for
imbalanced classes. This module finds optimal thresholds per-class or globally.
Designed to work with Random Forest (baseline and improved models).
Usage:
from threshold_optimization import optimize_thresholds, apply_thresholds
from sklearn.ensemble import RandomForestClassifier
# Train Random Forest
model = RandomForestClassifier(n_estimators=100)
model.fit(X_train, y_train)
# Get probability predictions
y_proba = model.predict_proba(X_val)
# Find optimal thresholds on validation set
thresholds = optimize_thresholds(y_val, y_proba, method='per_class')
# Apply thresholds to test set
y_pred = apply_thresholds(model.predict_proba(X_test), thresholds)
"""
from typing import Dict, Tuple, Union
import warnings
import numpy as np
from sklearn.metrics import f1_score
def optimize_thresholds(
    y_true: np.ndarray,
    y_proba: np.ndarray,
    method: str = "per_class",
    metric: str = "f1_weighted",
    search_range: Tuple[float, float] = (0.1, 0.9),
    n_steps: int = 50,
) -> Union[float, np.ndarray]:
    """
    Search for decision thresholds that maximize the requested metric.

    Converts probability predictions to binary decisions by comparing
    against a threshold; this function finds the threshold(s) that make
    that conversion score best on the supplied labels.

    Args:
        y_true: True binary labels, shape (n_samples, n_labels).
        y_proba: Predicted probabilities, shape (n_samples, n_labels).
        method: 'global' for a single shared threshold, or 'per_class'
            (default) for one threshold per label column.
        metric: Metric to optimize ('f1_weighted', 'f1_macro', 'f1_micro').
        search_range: (min, max) bounds of the threshold grid.
        n_steps: Number of grid points to evaluate.

    Returns:
        A single float when method='global', otherwise an array with one
        threshold per class.

    Raises:
        ValueError: If the label/probability shapes differ or the method
            name is unrecognized.

    Example:
        >>> y_true = np.array([[1, 0, 1], [0, 1, 0], [1, 1, 0]])
        >>> y_proba = np.array([[0.9, 0.3, 0.7], [0.2, 0.8, 0.4], [0.85, 0.6, 0.3]])
        >>> thresholds = optimize_thresholds(y_true, y_proba, method='per_class')
    """
    if y_true.shape != y_proba.shape:
        raise ValueError(f"Shape mismatch: y_true {y_true.shape} vs y_proba {y_proba.shape}")

    # Dispatch table instead of an if/elif chain; unknown keys are rejected
    # with the same error message as before.
    optimizers = {
        "global": _optimize_global_threshold,
        "per_class": _optimize_per_class_thresholds,
    }
    if method not in optimizers:
        raise ValueError(f"Invalid method: {method}. Must be 'global' or 'per_class'")
    return optimizers[method](y_true, y_proba, metric, search_range, n_steps)
def _optimize_global_threshold(
    y_true: np.ndarray,
    y_proba: np.ndarray,
    metric: str,
    search_range: Tuple[float, float],
    n_steps: int,
) -> float:
    """
    Find one threshold shared by every class.

    Faster but coarser than per-class optimization; appropriate when the
    label columns have similar score distributions.
    """
    candidates = np.linspace(search_range[0], search_range[1], n_steps)

    # Score every candidate, then take the first maximizer (np.argmax
    # returns the earliest index on ties, matching a strict ">" scan).
    scores = [
        _compute_score(y_true, (y_proba >= candidate).astype(int), metric)
        for candidate in candidates
    ]
    winner = int(np.argmax(scores))
    best_threshold = candidates[winner]
    best_score = scores[winner]

    print(f"Optimal global threshold: {best_threshold:.3f} (score: {best_score:.4f})")
    return best_threshold
def _optimize_per_class_thresholds(
    y_true: np.ndarray,
    y_proba: np.ndarray,
    metric: str,
    search_range: Tuple[float, float],
    n_steps: int,
) -> np.ndarray:
    """
    Find an independent optimal threshold for each class column.

    Typically beats a single global threshold on imbalanced multi-label
    data, at the cost of more f1 evaluations.

    NOTE(review): the ``metric`` argument is accepted for signature parity
    but each class is scored with binary F1 regardless — confirm intended.
    """
    n_classes = y_true.shape[1]
    candidates = np.linspace(search_range[0], search_range[1], n_steps)
    chosen_thresholds = np.zeros(n_classes)

    print(f"Optimizing thresholds for {n_classes} classes...")
    for idx in range(n_classes):
        truth_col = y_true[:, idx]
        proba_col = y_proba[:, idx]

        # A column with no positives has no meaningful F1 to optimize;
        # fall back to the conventional 0.5 and warn.
        if truth_col.sum() == 0:
            chosen_thresholds[idx] = 0.5
            warnings.warn(
                f"Class {idx} has no positive samples, using default threshold 0.5"
            )
            continue

        best_t, best_f1 = 0.5, -np.inf
        for candidate in candidates:
            preds = (proba_col >= candidate).astype(int)
            # Binary F1 for this single column; any scorer failure just
            # skips the candidate.
            try:
                f1 = f1_score(truth_col, preds, average="binary", zero_division=0)
            except Exception:
                continue
            if f1 > best_f1:
                best_f1, best_t = f1, candidate
        chosen_thresholds[idx] = best_t

    print(
        f"Threshold statistics: min={chosen_thresholds.min():.3f}, "
        f"max={chosen_thresholds.max():.3f}, mean={chosen_thresholds.mean():.3f}"
    )
    return chosen_thresholds
def _compute_score(y_true: np.ndarray, y_pred: np.ndarray, metric: str) -> float:
    """Return the requested F1 variant for binary predictions.

    Supported metric names map directly onto sklearn's ``average`` modes.
    Raises ValueError for anything else.
    """
    averaging = {
        "f1_weighted": "weighted",
        "f1_macro": "macro",
        "f1_micro": "micro",
    }
    if metric not in averaging:
        raise ValueError(f"Unsupported metric: {metric}")
    return f1_score(y_true, y_pred, average=averaging[metric], zero_division=0)
def apply_thresholds(y_proba: np.ndarray, thresholds: Union[float, np.ndarray]) -> np.ndarray:
    """
    Apply thresholds to probability predictions to get binary predictions.

    Args:
        y_proba: Predicted probabilities, shape (n_samples, n_labels)
        thresholds: Threshold(s) to apply:
            - Scalar (Python float/int or NumPy scalar): same threshold
              for all classes
            - Array: one threshold per class

    Returns:
        Binary predictions, shape (n_samples, n_labels)

    Raises:
        ValueError: If a threshold array's length differs from the number
            of classes in ``y_proba``.

    Example:
        >>> y_proba = np.array([[0.9, 0.3, 0.7], [0.2, 0.8, 0.4]])
        >>> thresholds = np.array([0.5, 0.4, 0.6])
        >>> y_pred = apply_thresholds(y_proba, thresholds)
        >>> print(y_pred)
        [[1 0 1]
         [0 1 0]]
    """
    # BUGFIX: the previous isinstance(thresholds, float) check missed
    # np.float32 (not a float subclass) and ints, routing scalars into the
    # per-class branch where len() raised. ndim==0 detects any scalar.
    thresholds = np.asarray(thresholds)
    if thresholds.ndim == 0:
        # Global threshold: broadcast the scalar over the whole matrix.
        return (y_proba >= thresholds).astype(int)

    # Per-class thresholds: one threshold per label column.
    if len(thresholds) != y_proba.shape[1]:
        raise ValueError(
            f"Number of thresholds ({len(thresholds)}) must match "
            f"number of classes ({y_proba.shape[1]})"
        )
    # Broadcasting: compare each column with its threshold.
    return (y_proba >= thresholds[np.newaxis, :]).astype(int)
def _stack_positive_proba(proba) -> np.ndarray:
    """Normalize predict_proba output to an (n_samples, n_labels) matrix.

    MultiOutputClassifier returns a list of per-label (n_samples, 2)
    arrays; extract the positive-class column from each and stack them.
    Plain arrays pass through unchanged.
    """
    if isinstance(proba, list):
        return np.column_stack([p[:, 1] for p in proba])
    return proba


def evaluate_with_thresholds(
    model,
    X_val: np.ndarray,
    y_val: np.ndarray,
    X_test: np.ndarray,
    y_test: np.ndarray,
    method: str = "per_class",
) -> Dict:
    """
    Complete workflow: optimize thresholds on validation set and evaluate on test set.

    Pipeline:
        1. Get probability predictions on the validation set.
        2. Optimize thresholds using validation data.
        3. Apply optimized thresholds to the test set.
        4. Compare with the default threshold (0.5).

    Args:
        model: Trained model with a ``predict_proba`` method
        X_val: Validation features
        y_val: Validation labels (binary)
        X_test: Test features
        y_test: Test labels (binary)
        method: 'global' or 'per_class'

    Returns:
        Dictionary with:
            - 'thresholds': Optimized thresholds
            - 'f1_default': Weighted F1 with default threshold (0.5)
            - 'f1_optimized': Weighted F1 with optimized thresholds
            - 'improvement': Absolute improvement in F1-score
            - 'y_pred_optimized': Test-set predictions under the
              optimized thresholds

    Example:
        >>> results = evaluate_with_thresholds(model, X_val, y_val, X_test, y_test)
        >>> print(f"F1 improvement: {results['improvement']:.4f}")
    """
    # Optimize thresholds on validation data only (no test leakage).
    print("Getting probability predictions on validation set...")
    y_val_proba = _stack_positive_proba(model.predict_proba(X_val))
    print(f"Optimizing thresholds ({method})...")
    thresholds = optimize_thresholds(y_val, y_val_proba, method=method)

    # Evaluate on the held-out test set.
    print("Evaluating on test set...")
    y_test_proba = _stack_positive_proba(model.predict_proba(X_test))

    # Baseline: conventional 0.5 cutoff.
    y_test_pred_default = (y_test_proba >= 0.5).astype(int)
    f1_default = f1_score(y_test, y_test_pred_default, average="weighted", zero_division=0)

    # Optimized thresholds.
    y_test_pred_optimized = apply_thresholds(y_test_proba, thresholds)
    f1_optimized = f1_score(y_test, y_test_pred_optimized, average="weighted", zero_division=0)
    improvement = f1_optimized - f1_default

    print("\nResults:")
    print(f"  F1-score (default threshold=0.5): {f1_default:.4f}")
    print(f"  F1-score (optimized thresholds):  {f1_optimized:.4f}")
    # BUGFIX: guard the relative-change computation — the baseline F1 can
    # legitimately be 0, which previously produced a divide-by-zero.
    if f1_default > 0:
        print(f"  Improvement: {improvement:+.4f} ({improvement / f1_default * 100:+.2f}%)")
    else:
        print(f"  Improvement: {improvement:+.4f}")

    return {
        "thresholds": thresholds,
        "f1_default": f1_default,
        "f1_optimized": f1_optimized,
        "improvement": improvement,
        "y_pred_optimized": y_test_pred_optimized,
    }