from __future__ import annotations

from functools import lru_cache

import torch

try:
    from .config import IAB_PARENT_FALLBACK_CONFIDENCE_FLOOR, _looks_like_local_hf_model_dir  # type: ignore
    from .iab_taxonomy import get_iab_taxonomy, parse_path_label, path_to_label  # type: ignore
    from .model_runtime import get_head  # type: ignore
except ImportError:
    # Same names through top-level imports when the relative form fails
    # (presumably when the file is run outside its package -- confirm).
    from config import IAB_PARENT_FALLBACK_CONFIDENCE_FLOOR, _looks_like_local_hf_model_dir
    from iab_taxonomy import get_iab_taxonomy, parse_path_label, path_to_label
    from model_runtime import get_head


def round_score(value: float) -> float:
    """Return ``value`` coerced to ``float`` and rounded to four decimals."""
    return round(float(value), 4)


@lru_cache(maxsize=1)
def _prefix_label_ids() -> dict[tuple[str, ...], list[int]]:
    """Map every taxonomy path prefix to the ids of the labels under it.

    Each label in the ``iab_content`` head's ``config.label2id`` is parsed
    into a path; the label's id is then registered under every non-empty
    prefix of that path, the full path included.  The result is memoised, so
    the mapping is built once and the same dict object is returned afterwards.

    NOTE(review): ids come from ``head.config.label2id`` while prediction
    resolves labels through ``head.model.config.id2label``; this assumes the
    two mappings agree -- confirm.
    """
    head = get_head("iab_content")
    prefix_map: dict[tuple[str, ...], list[int]] = {}
    for label, label_id in head.config.label2id.items():
        path = parse_path_label(label)
        for depth in range(1, len(path) + 1):
            prefix_map.setdefault(path[:depth], []).append(label_id)
    return prefix_map


def _effective_exact_threshold(confidence_threshold: float | None) -> float:
    """Resolve the confidence required to accept the exact predicted label.

    Args:
        confidence_threshold: Caller override, or ``None`` to use the
            ``iab_content`` head's calibration threshold.

    Returns:
        The override clamped to ``[0.0, 1.0]``, or the head's calibration
        threshold as a float (returned as stored, without clamping).
    """
    if confidence_threshold is not None:
        return min(max(float(confidence_threshold), 0.0), 1.0)
    # The head is only needed for the default, so it is looked up lazily.
    return float(get_head("iab_content").calibration.confidence_threshold)


def _effective_parent_threshold(exact_threshold: float) -> float:
    """Return the confidence a parent path needs to be accepted as fallback.

    This is the larger of ``IAB_PARENT_FALLBACK_CONFIDENCE_FLOOR`` and
    ``exact_threshold``, capped at ``1.0``.
    """
    return min(max(IAB_PARENT_FALLBACK_CONFIDENCE_FLOOR, exact_threshold), 1.0)


def _build_prediction(
    accepted_path: tuple[str, ...],
    *,
    exact_label: str,
    confidence: float,
    raw_confidence: float,
    exact_threshold: float,
    calibrated: bool,
    meets_confidence_threshold: bool,
    mapping_mode: str,
    stopped_reason: str,
) -> dict:
    """Assemble the prediction payload for one input text.

    Args:
        accepted_path: Taxonomy path that is reported (exact or a parent).
        exact_label: Label of the classifier's top prediction.
        confidence: Calibrated confidence attributed to ``accepted_path``.
        raw_confidence: Uncalibrated confidence for the same path.
        exact_threshold: Threshold that was applied to the exact label.
        calibrated: Value of the head's ``calibration.calibrated`` flag.
        meets_confidence_threshold: Whether ``accepted_path`` passed its
            threshold.
        mapping_mode: Mapping mode forwarded to the taxonomy content object.
        stopped_reason: Why path resolution stopped at ``accepted_path``.

    Returns:
        A dict with the label, path, rounded scores, the taxonomy content
        object and bookkeeping fields.  The content object is built from the
        unrounded ``confidence``; the score fields in the dict are rounded.
    """
    taxonomy = get_iab_taxonomy()
    content = taxonomy.build_content_object(
        accepted_path,
        mapping_mode=mapping_mode,
        mapping_confidence=confidence,
    )
    return {
        "label": path_to_label(accepted_path),
        "exact_label": exact_label,
        "path": list(accepted_path),
        "confidence": round_score(confidence),
        "raw_confidence": round_score(raw_confidence),
        "confidence_threshold": round_score(exact_threshold),
        "calibrated": calibrated,
        "meets_confidence_threshold": meets_confidence_threshold,
        "content": content,
        "mapping_mode": mapping_mode,
        "mapping_confidence": round_score(confidence),
        "source": "supervised_classifier",
        "stopped_reason": stopped_reason,
    }


def _summed_confidence(row, label_ids: list[int]) -> float:
    """Return the sum of ``row`` over the positions in ``label_ids``."""
    # presumably ``row`` is a 1-D torch tensor of per-label probabilities
    # -- TODO confirm against ``predict_probs_batch``.
    return float(row[label_ids].sum().item())


def _parent_fallback(
    exact_path: tuple[str, ...],
    raw_row,
    calibrated_row,
    prefix_map: dict[tuple[str, ...], list[int]],
    parent_threshold: float,
) -> tuple[tuple[str, ...], float, float, bool, str]:
    """Pick the deepest proper parent of ``exact_path`` that is confident enough.

    Parents are tried from the deepest down to the top level; a parent's
    confidence is the summed calibrated score of every label under it.  If
    none reaches ``parent_threshold`` the top-level prefix is returned and
    flagged as not meeting the threshold.

    Returns:
        ``(path, confidence, raw_confidence, meets_threshold, stopped_reason)``.

    Raises:
        KeyError: If a prefix of ``exact_path`` is missing from ``prefix_map``.
    """
    for depth in range(len(exact_path) - 1, 0, -1):
        prefix = exact_path[:depth]
        label_ids = prefix_map[prefix]
        confidence = _summed_confidence(calibrated_row, label_ids)
        if confidence >= parent_threshold:
            # The raw score is only needed for the parent actually accepted.
            raw_confidence = _summed_confidence(raw_row, label_ids)
            return prefix, confidence, raw_confidence, True, "parent_fallback_threshold_met"

    # Nothing qualified (or the exact path has no proper parent): report the
    # top-level prefix, which for a one-segment path is the exact path itself.
    top_level = exact_path[:1]
    label_ids = prefix_map[top_level]
    return (
        top_level,
        _summed_confidence(calibrated_row, label_ids),
        _summed_confidence(raw_row, label_ids),
        False,
        "top_level_safe_fallback",
    )


def predict_iab_content_classifier_batch(
    texts: list[str],
    confidence_threshold: float | None = None,
) -> list[dict | None]:
    """Classify ``texts`` into IAB content paths with parent-level fallback.

    For each text the label with the highest calibrated score is taken.  If
    its score reaches the exact threshold it is reported with mapping mode
    ``"exact"``; otherwise the deepest parent whose summed score reaches the
    parent threshold is reported, falling back to the top-level category,
    both with mapping mode ``"nearest_equivalent"``.

    Args:
        texts: Input texts to classify.
        confidence_threshold: Optional override for the exact-label
            threshold; ``None`` uses the head's calibration threshold.

    Returns:
        ``[]`` for empty input; one ``None`` per text when the head's model
        directory does not look like a local model; otherwise one prediction
        dict per row yielded by the head.
    """
    if not texts:
        return []

    head = get_head("iab_content")
    # `SequenceClassifierHead` will raise if the folder exists but is incomplete
    # (missing `model.safetensors` / `pytorch_model.bin`). Treat that as "no model".
    if not _looks_like_local_hf_model_dir(head.config.model_dir):
        return [None for _ in texts]

    raw_probs, calibrated_probs = head.predict_probs_batch(texts)
    prefix_map = _prefix_label_ids()
    exact_threshold = _effective_exact_threshold(confidence_threshold)
    parent_threshold = _effective_parent_threshold(exact_threshold)

    # Loop invariants, resolved once rather than per row.
    id2label = head.model.config.id2label
    calibrated = head.calibration.calibrated

    predictions: list[dict | None] = []
    for raw_row, calibrated_row in zip(raw_probs, calibrated_probs):
        pred_id = int(torch.argmax(calibrated_row).item())
        exact_label = id2label[pred_id]
        exact_path = parse_path_label(exact_label)
        exact_confidence = float(calibrated_row[pred_id].item())

        if exact_confidence >= exact_threshold:
            predictions.append(
                _build_prediction(
                    exact_path,
                    exact_label=exact_label,
                    confidence=exact_confidence,
                    raw_confidence=float(raw_row[pred_id].item()),
                    exact_threshold=exact_threshold,
                    calibrated=calibrated,
                    meets_confidence_threshold=True,
                    mapping_mode="exact",
                    stopped_reason="exact_threshold_met",
                )
            )
            continue

        accepted_path, accepted_confidence, accepted_raw_confidence, meets_threshold, stopped_reason = (
            _parent_fallback(exact_path, raw_row, calibrated_row, prefix_map, parent_threshold)
        )
        predictions.append(
            _build_prediction(
                accepted_path,
                exact_label=exact_label,
                confidence=accepted_confidence,
                raw_confidence=accepted_raw_confidence,
                exact_threshold=exact_threshold,
                calibrated=calibrated,
                meets_confidence_threshold=meets_threshold,
                mapping_mode="nearest_equivalent",
                stopped_reason=stopped_reason,
            )
        )
    return predictions


def predict_iab_content_classifier(text: str, confidence_threshold: float | None = None) -> dict | None:
    """Classify a single text; see ``predict_iab_content_classifier_batch``."""
    return predict_iab_content_classifier_batch([text], confidence_threshold=confidence_threshold)[0]