# agentic-intent-classifier / iab_classifier.py
# manikumargouni's picture
# Upload iab_classifier.py with huggingface_hub
# 09a2d95 verified
from __future__ import annotations
from functools import lru_cache
import torch
try:
from .config import IAB_PARENT_FALLBACK_CONFIDENCE_FLOOR, _looks_like_local_hf_model_dir # type: ignore
from .iab_taxonomy import get_iab_taxonomy, parse_path_label, path_to_label # type: ignore
from .model_runtime import get_head # type: ignore
except ImportError:
from config import IAB_PARENT_FALLBACK_CONFIDENCE_FLOOR, _looks_like_local_hf_model_dir
from iab_taxonomy import get_iab_taxonomy, parse_path_label, path_to_label
from model_runtime import get_head
def round_score(value: float) -> float:
    """Coerce *value* to ``float`` and round it to four decimal places."""
    as_float = float(value)
    return round(as_float, 4)
@lru_cache(maxsize=1)
def _prefix_label_ids() -> dict[tuple[str, ...], list[int]]:
    """Map every taxonomy path prefix to the ids of the labels beneath it.

    Each label of the ``iab_content`` head is parsed into a path; the label's
    id is then recorded under every leading slice of that path, including the
    full path itself. The result is memoised, so the head's label table is
    read only on the first call.
    """
    label_to_id = get_head("iab_content").config.label2id
    ids_by_prefix: dict[tuple[str, ...], list[int]] = {}
    for label_name, class_id in label_to_id.items():
        segments = parse_path_label(label_name)
        for end in range(len(segments)):
            # Slices are used as dict keys, so the parsed path must be hashable
            # (presumably a tuple -- matches the return annotation).
            key = segments[: end + 1]
            if key not in ids_by_prefix:
                ids_by_prefix[key] = []
            ids_by_prefix[key].append(class_id)
    return ids_by_prefix
def _effective_exact_threshold(confidence_threshold: float | None) -> float:
head = get_head("iab_content")
if confidence_threshold is None:
return float(head.calibration.confidence_threshold)
return min(max(float(confidence_threshold), 0.0), 1.0)
def _effective_parent_threshold(exact_threshold: float) -> float:
    """Return the threshold a parent category must reach during fallback.

    The result is the larger of ``IAB_PARENT_FALLBACK_CONFIDENCE_FLOOR`` and
    *exact_threshold*, capped at ``1.0``.
    """
    floored = max(IAB_PARENT_FALLBACK_CONFIDENCE_FLOOR, exact_threshold)
    return min(floored, 1.0)
def _build_prediction(
    accepted_path: tuple[str, ...],
    *,
    exact_label: str,
    confidence: float,
    raw_confidence: float,
    exact_threshold: float,
    calibrated: bool,
    meets_confidence_threshold: bool,
    mapping_mode: str,
    stopped_reason: str,
) -> dict:
    """Assemble the result dictionary for one classified text.

    Args:
        accepted_path: Taxonomy path that is being reported.
        exact_label: Label string of the model's top prediction.
        confidence: Calibrated confidence attached to *accepted_path*.
        raw_confidence: Uncalibrated confidence attached to *accepted_path*.
        exact_threshold: Threshold that was applied to the exact prediction.
        calibrated: Calibration flag copied through to the result.
        meets_confidence_threshold: Whether the accepted path passed its
            threshold.
        mapping_mode: Mapping mode recorded in the result and forwarded to the
            taxonomy content builder.
        stopped_reason: Short code describing why this path was selected.

    Returns:
        A dict with the label, path, rounded scores, the taxonomy content
        object, and bookkeeping fields. ``source`` is always
        ``"supervised_classifier"``.
    """
    taxonomy = get_iab_taxonomy()

    prediction: dict = {}
    prediction["label"] = path_to_label(accepted_path)
    prediction["exact_label"] = exact_label
    prediction["path"] = list(accepted_path)
    prediction["confidence"] = round_score(confidence)
    prediction["raw_confidence"] = round_score(raw_confidence)
    prediction["confidence_threshold"] = round_score(exact_threshold)
    prediction["calibrated"] = calibrated
    prediction["meets_confidence_threshold"] = meets_confidence_threshold
    # The content builder receives the unrounded confidence; only the
    # top-level score fields of the result are rounded.
    prediction["content"] = taxonomy.build_content_object(
        accepted_path,
        mapping_mode=mapping_mode,
        mapping_confidence=confidence,
    )
    prediction["mapping_mode"] = mapping_mode
    prediction["mapping_confidence"] = round_score(confidence)
    prediction["source"] = "supervised_classifier"
    prediction["stopped_reason"] = stopped_reason
    return prediction
def predict_iab_content_classifier_batch(
    texts: list[str],
    confidence_threshold: float | None = None,
) -> list[dict | None]:
    """Classify a batch of texts into IAB content taxonomy paths.

    For each text the top calibrated label is accepted as-is when its
    confidence reaches the exact threshold. Otherwise the function walks from
    the label's immediate parent up to the top level, summing the probabilities
    of all labels that share each path prefix, and accepts the deepest prefix
    whose calibrated sum reaches the parent threshold. If none qualifies, the
    top-level category is reported with ``meets_confidence_threshold=False``.

    Args:
        texts: Texts to classify.
        confidence_threshold: Optional override for the exact-match threshold;
            ``None`` uses the head's calibrated threshold.

    Returns:
        An empty list for empty input; a list of ``None`` (one per text) when
        the head's model directory does not look like a local Hugging Face
        model directory; otherwise one prediction dict per probability row.
    """
    if not texts:
        return []

    head = get_head("iab_content")
    # `SequenceClassifierHead` will raise if the folder exists but is incomplete
    # (missing `model.safetensors` / `pytorch_model.bin`). Treat that as "no model".
    if not _looks_like_local_hf_model_dir(head.config.model_dir):
        return [None for _text in texts]

    raw_probs, calibrated_probs = head.predict_probs_batch(texts)
    ids_by_prefix = _prefix_label_ids()
    exact_threshold = _effective_exact_threshold(confidence_threshold)
    parent_threshold = _effective_parent_threshold(exact_threshold)

    results: list[dict | None] = []
    # Rows are presumably 1-D per-text probability tensors over the label set.
    for raw_row, cal_row in zip(raw_probs, calibrated_probs):
        top_id = int(torch.argmax(cal_row).item())
        top_label = head.model.config.id2label[top_id]
        top_path = parse_path_label(top_label)
        top_confidence = float(cal_row[top_id].item())
        top_raw_confidence = float(raw_row[top_id].item())

        if top_confidence >= exact_threshold:
            chosen_path = top_path
            chosen_confidence = top_confidence
            chosen_raw_confidence = top_raw_confidence
            threshold_met = True
            mode = "exact"
            reason = "exact_threshold_met"
        else:
            mode = "nearest_equivalent"
            # Try ancestors from the immediate parent down to depth 1.
            for depth in reversed(range(1, len(top_path))):
                ancestor = top_path[:depth]
                member_ids = ids_by_prefix[ancestor]
                ancestor_confidence = float(cal_row[member_ids].sum().item())
                ancestor_raw_confidence = float(raw_row[member_ids].sum().item())
                if ancestor_confidence >= parent_threshold:
                    chosen_path = ancestor
                    chosen_confidence = ancestor_confidence
                    chosen_raw_confidence = ancestor_raw_confidence
                    threshold_met = True
                    reason = "parent_fallback_threshold_met"
                    break
            else:
                # No ancestor qualified (or the label has none): report the
                # top-level category with its aggregated probability.
                root = top_path[:1]
                root_ids = ids_by_prefix[root]
                chosen_path = root
                chosen_confidence = float(cal_row[root_ids].sum().item())
                chosen_raw_confidence = float(raw_row[root_ids].sum().item())
                threshold_met = False
                reason = "top_level_safe_fallback"

        results.append(
            _build_prediction(
                chosen_path,
                exact_label=top_label,
                confidence=chosen_confidence,
                raw_confidence=chosen_raw_confidence,
                exact_threshold=exact_threshold,
                calibrated=head.calibration.calibrated,
                meets_confidence_threshold=threshold_met,
                mapping_mode=mode,
                stopped_reason=reason,
            )
        )
    return results
def predict_iab_content_classifier(text: str, confidence_threshold: float | None = None) -> dict | None:
    """Classify a single text by delegating to the batch predictor.

    Returns the sole element of the batch result for ``[text]``: a prediction
    dict, or ``None`` when the batch predictor reports no usable model.
    """
    batch_result = predict_iab_content_classifier_batch(
        [text],
        confidence_threshold=confidence_threshold,
    )
    return batch_result[0]