# File size: 1,398 Bytes
# 46cc63a
"""
Validation-set threshold search to maximize F1 (default: toxic class).
"""

from __future__ import annotations

import numpy as np
from sklearn.metrics import f1_score


def search_best_threshold(
    y_true: np.ndarray,
    probs: np.ndarray,
    *,
    metric: str = "f1_toxic",
    min_threshold: float = 0.05,
    max_threshold: float = 0.95,
    step: float = 0.01,
) -> tuple[float, float]:
    """
    Grid-search the classification threshold on validation probabilities.

    Every candidate threshold ``t`` turns ``probs`` into hard labels via
    ``probs >= t``; the threshold with the highest F1 against ``y_true`` is
    returned. Ties are resolved in favour of the lowest threshold.

    Args:
        y_true: Ground-truth labels; cast to ``int``.
        probs: Scores for the positive class (label 1); cast to ``float``.
        metric: ``"f1_weighted"`` selects support-weighted F1. Any other
            value (including ``"f1_toxic"`` and ``"f1_binary"``) scores the
            F1 of the positive class, label 1.
        min_threshold: First threshold of the grid.
        max_threshold: Upper bound of the grid; never exceeded.
        step: Grid spacing; must be a positive finite number.

    Returns:
        ``(best_threshold, best_score)``.

    Raises:
        ValueError: If ``step`` is not a positive finite number, a bound is
            not finite, or ``min_threshold > max_threshold``.
    """
    # Fail loudly on a degenerate grid instead of silently returning the
    # (0.5, -1.0) placeholders without evaluating a single threshold.
    if not (np.isfinite(step) and step > 0):
        raise ValueError(f"step must be a positive finite number, got {step!r}")
    if not (np.isfinite(min_threshold) and np.isfinite(max_threshold)):
        raise ValueError(
            "min_threshold and max_threshold must be finite, got "
            f"{min_threshold!r} and {max_threshold!r}"
        )
    if min_threshold > max_threshold:
        raise ValueError(
            f"min_threshold ({min_threshold!r}) must not exceed "
            f"max_threshold ({max_threshold!r})"
        )

    y = np.asarray(y_true).astype(int)
    p = np.asarray(probs, dtype=float)

    # Build the grid from an integer point count rather than a float-step
    # arange: the count (with a small tolerance for division noise) keeps the
    # last point at or below max_threshold, and rounding removes artefacts
    # such as 0.15000000000000002 that would make ``p >= t`` reject a
    # probability of exactly 0.15.
    n_points = int(np.floor((max_threshold - min_threshold) / step + 1e-9)) + 1
    thresholds = np.round(min_threshold + step * np.arange(n_points), 10)

    # With the default pos_label=1, average="binary" is the positive-class F1.
    average = "weighted" if metric == "f1_weighted" else "binary"

    best_t = 0.5
    best_score = -1.0
    for t in thresholds:
        preds = (p >= t).astype(int)
        score = float(f1_score(y, preds, average=average, zero_division=0))
        # Strict comparison keeps the lowest threshold among equal scores.
        if score > best_score:
            best_score = score
            best_t = float(t)

    return best_t, best_score


def predict_with_threshold(probs: np.ndarray, threshold: float) -> np.ndarray:
    """
    Convert probabilities to hard 0/1 labels.

    An entry becomes 1 when its probability is greater than or equal to
    ``threshold`` and 0 otherwise.
    """
    scores = np.asarray(probs, dtype=float)
    is_positive = scores >= threshold
    return is_positive.astype(int)