File size: 6,024 Bytes
0584798
 
 
 
 
 
32b3543
09a2d95
32b3543
 
 
09a2d95
32b3543
 
0584798
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
09a2d95
 
 
0584798
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
from __future__ import annotations

from functools import lru_cache

import torch

try:
    from .config import IAB_PARENT_FALLBACK_CONFIDENCE_FLOOR, _looks_like_local_hf_model_dir  # type: ignore
    from .iab_taxonomy import get_iab_taxonomy, parse_path_label, path_to_label  # type: ignore
    from .model_runtime import get_head  # type: ignore
except ImportError:
    from config import IAB_PARENT_FALLBACK_CONFIDENCE_FLOOR, _looks_like_local_hf_model_dir
    from iab_taxonomy import get_iab_taxonomy, parse_path_label, path_to_label
    from model_runtime import get_head


def round_score(value: float) -> float:
    """Coerce *value* to a float and round it to four decimal places."""
    as_float = float(value)
    return round(as_float, 4)


@lru_cache(maxsize=1)
def _prefix_label_ids() -> dict[tuple[str, ...], list[int]]:
    """Map each taxonomy path prefix to the ids of all labels under it.

    Built once from the "iab_content" head's ``label2id`` mapping and cached
    for the lifetime of the process.
    """
    mapping: dict[tuple[str, ...], list[int]] = {}
    for label, label_id in get_head("iab_content").config.label2id.items():
        segments = parse_path_label(label)
        for cut in range(1, len(segments) + 1):
            mapping.setdefault(segments[:cut], []).append(label_id)
    return mapping


def _effective_exact_threshold(confidence_threshold: float | None) -> float:
    """Resolve the exact-match confidence threshold.

    Falls back to the head's calibrated threshold when no override is given;
    an explicit override is clamped into the [0.0, 1.0] range.
    """
    head = get_head("iab_content")
    if confidence_threshold is None:
        return float(head.calibration.confidence_threshold)
    clamped_low = max(float(confidence_threshold), 0.0)
    return min(clamped_low, 1.0)


def _effective_parent_threshold(exact_threshold: float) -> float:
    """Parent-fallback threshold: at least the configured floor, capped at 1.0."""
    floor_or_exact = max(IAB_PARENT_FALLBACK_CONFIDENCE_FLOOR, exact_threshold)
    return min(floor_or_exact, 1.0)


def _build_prediction(
    accepted_path: tuple[str, ...],
    *,
    exact_label: str,
    confidence: float,
    raw_confidence: float,
    exact_threshold: float,
    calibrated: bool,
    meets_confidence_threshold: bool,
    mapping_mode: str,
    stopped_reason: str,
) -> dict:
    """Assemble the serializable prediction payload for *accepted_path*.

    ``exact_label`` records the classifier's argmax label even when the
    accepted path is an ancestor of it; scores are rounded via round_score.
    """
    content = get_iab_taxonomy().build_content_object(
        accepted_path,
        mapping_mode=mapping_mode,
        mapping_confidence=confidence,
    )
    prediction = {
        "label": path_to_label(accepted_path),
        "exact_label": exact_label,
        "path": list(accepted_path),
        "confidence": round_score(confidence),
        "raw_confidence": round_score(raw_confidence),
        "confidence_threshold": round_score(exact_threshold),
        "calibrated": calibrated,
        "meets_confidence_threshold": meets_confidence_threshold,
        "content": content,
        "mapping_mode": mapping_mode,
        "mapping_confidence": round_score(confidence),
        "source": "supervised_classifier",
        "stopped_reason": stopped_reason,
    }
    return prediction


def predict_iab_content_classifier_batch(
    texts: list[str],
    confidence_threshold: float | None = None,
) -> list[dict | None]:
    """Classify each text against the IAB content taxonomy.

    Returns one prediction dict per input text, or ``None`` entries when the
    local model directory is absent/unusable. Each text is mapped to the
    argmax leaf label when its calibrated probability clears the exact
    threshold; otherwise the deepest ancestor prefix whose aggregated
    probability mass clears the parent-fallback threshold is used, falling
    back to the top-level prefix when no ancestor qualifies.
    """
    if not texts:
        return []

    head = get_head("iab_content")
    # `SequenceClassifierHead` will raise if the folder exists but is incomplete
    # (missing `model.safetensors` / `pytorch_model.bin`). Treat that as "no model".
    if not _looks_like_local_hf_model_dir(head.config.model_dir):
        return [None for _ in texts]

    raw_probs, calibrated_probs = head.predict_probs_batch(texts)
    prefix_map = _prefix_label_ids()
    exact_threshold = _effective_exact_threshold(confidence_threshold)
    parent_threshold = _effective_parent_threshold(exact_threshold)
    results: list[dict | None] = []

    for raw_row, cal_row in zip(raw_probs, calibrated_probs):
        best_id = int(torch.argmax(cal_row).item())
        best_label = head.model.config.id2label[best_id]
        best_path = parse_path_label(best_label)
        best_conf = float(cal_row[best_id].item())
        best_raw = float(raw_row[best_id].item())

        # Fast path: the argmax leaf itself is confident enough.
        if best_conf >= exact_threshold:
            results.append(
                _build_prediction(
                    best_path,
                    exact_label=best_label,
                    confidence=best_conf,
                    raw_confidence=best_raw,
                    exact_threshold=exact_threshold,
                    calibrated=head.calibration.calibrated,
                    meets_confidence_threshold=True,
                    mapping_mode="exact",
                    stopped_reason="exact_threshold_met",
                )
            )
            continue

        # Walk the proper ancestors of the argmax path from deepest to
        # shallowest and accept the first whose summed mass clears the
        # parent threshold.
        chosen: tuple[tuple[str, ...], float, float, bool, str] | None = None
        for depth in range(len(best_path) - 1, 0, -1):
            candidate = best_path[:depth]
            candidate_ids = prefix_map[candidate]
            mass = float(cal_row[candidate_ids].sum().item())
            if mass >= parent_threshold:
                raw_mass = float(raw_row[candidate_ids].sum().item())
                chosen = (candidate, mass, raw_mass, True, "parent_fallback_threshold_met")
                break

        if chosen is None:
            # No ancestor qualified: report the top-level prefix as a safe
            # fallback, flagged as below-threshold.
            top = best_path[:1]
            top_ids = prefix_map[top]
            chosen = (
                top,
                float(cal_row[top_ids].sum().item()),
                float(raw_row[top_ids].sum().item()),
                False,
                "top_level_safe_fallback",
            )

        accepted_path, accepted_conf, accepted_raw, met, reason = chosen
        results.append(
            _build_prediction(
                accepted_path,
                exact_label=best_label,
                confidence=accepted_conf,
                raw_confidence=accepted_raw,
                exact_threshold=exact_threshold,
                calibrated=head.calibration.calibrated,
                meets_confidence_threshold=met,
                mapping_mode="nearest_equivalent",
                stopped_reason=reason,
            )
        )

    return results


def predict_iab_content_classifier(text: str, confidence_threshold: float | None = None) -> dict | None:
    """Classify a single *text*; thin wrapper over the batch entry point."""
    batch = predict_iab_content_classifier_batch([text], confidence_threshold=confidence_threshold)
    return batch[0]