File size: 3,743 Bytes
46cc63a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""Inference for Notebook 14 meta-feature stacking (frozen CLS + metadata + LR)."""

from __future__ import annotations

import json
import re
from pathlib import Path
from typing import Any

import joblib
import numpy as np
import pandas as pd
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from src.features.metadata_features import extract_metadata_features

MODEL_ID = "unitary/toxic-bert"
_EMOJI_PAT = re.compile(
    "["
    "\U0001f300-\U0001f9ff"
    "\U0001f600-\U0001f64f"
    "]+",
    flags=re.UNICODE,
)


def _extended_meta_frame(text: str) -> pd.DataFrame:
    """Build the metadata feature frame for a single comment.

    The base columns come from ``extract_metadata_features`` applied to a
    one-row frame whose ``"Text"`` column holds *text*. Two extra columns are
    appended:

    * ``emoji_count`` - number of matches of ``_EMOJI_PAT`` (runs of emoji).
    * ``punctuation_density`` - count of characters that are neither word
      characters nor whitespace, divided by ``len(text)`` (floored at 1 so an
      empty string does not divide by zero). Emoji also fall in this class.

    Every column is cast to ``float`` before returning.
    """
    # A copy so the new columns never write into a frame the helper may share.
    features = extract_metadata_features(
        pd.DataFrame({"Text": [text]}), text_column="Text"
    ).copy()
    emoji_runs = _EMOJI_PAT.findall(text)
    non_word_chars = re.findall(r"[^\w\s]", text)
    features["emoji_count"] = len(emoji_runs)
    features["punctuation_density"] = len(non_word_chars) / max(len(text), 1)
    return features.astype(float)


class MetaStackPredictor:
    """Load the production stacking bundle and score a single comment.

    The feature row for one comment is the frozen encoder's first-token
    ([CLS]) hidden state concatenated with the metadata columns produced by
    ``_extended_meta_frame``. The row is passed through the bundle's
    ``scaler`` and then scored by the bundle's ``clf``.
    """

    def __init__(
        self,
        bundle_path: Path,
        *,
        manifest_path: Path | None = None,
        frozen_model_id: str = MODEL_ID,
    ) -> None:
        """Load the bundle, the optional manifest, and the frozen encoder.

        Args:
            bundle_path: joblib file containing a mapping with ``"scaler"``,
                ``"clf"`` and, optionally, ``"meta_columns"``.
            manifest_path: optional JSON file; only its ``"threshold"`` key is
                read here. If the path is ``None`` or not an existing file,
                the manifest is treated as empty.
            frozen_model_id: Hugging Face id for the tokenizer and encoder.

        Raises:
            KeyError: if the bundle has no ``"scaler"`` or ``"clf"`` entry.
        """
        self.bundle_path = bundle_path
        self.frozen_model_id = frozen_model_id
        self.manifest: dict[str, Any] = {}
        if manifest_path and manifest_path.is_file():
            self.manifest = json.loads(manifest_path.read_text(encoding="utf-8"))

        # SECURITY: joblib.load unpickles the file and can execute arbitrary
        # code; only point this at bundles from a trusted location.
        bundle = joblib.load(bundle_path)
        self.scaler = bundle["scaler"]
        self.clf = bundle["clf"]
        # Presumably the training-time metadata column order; when empty, the
        # order produced by _extended_meta_frame is used unchanged.
        self.meta_columns: list[str] = list(bundle.get("meta_columns", []))
        # 0.381 is the fallback when the manifest has no threshold (or an
        # explicit null, which previously crashed in float(None)); presumably
        # the value tuned in the training notebook — TODO confirm.
        raw_threshold = self.manifest.get("threshold")
        self.default_threshold = (
            0.381 if raw_threshold is None else float(raw_threshold)
        )

        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._tokenizer = AutoTokenizer.from_pretrained(frozen_model_id)
        self._bert = AutoModelForSequenceClassification.from_pretrained(frozen_model_id)
        # The encoder is frozen: no gradients, inference-mode layers.
        for p in self._bert.parameters():
            p.requires_grad = False
        self._bert.eval()
        self._bert.to(self._device)

    def _cls_vector(self, text: str) -> np.ndarray:
        """Return the encoder's first-token hidden state for *text*.

        The result has shape ``(1, hidden_size)`` (a batch of one comment).
        The bare encoder is reached through ``base_model`` rather than a
        hard-coded ``.bert`` attribute, so non-BERT ``frozen_model_id`` values
        work too; for BERT checkpoints ``base_model`` is the ``.bert`` module.
        """
        enc = self._tokenizer(
            [text],
            truncation=True,
            # Presumably must match the training-time tokenisation — TODO confirm.
            max_length=128,
            padding=True,
            return_tensors="pt",
        )
        enc = {k: v.to(self._device) for k, v in enc.items()}
        with torch.no_grad():
            hidden = self._bert.base_model(**enc).last_hidden_state
        return hidden[:, 0, :].cpu().numpy()

    def _feature_row(self, text: str) -> np.ndarray:
        """Return the unscaled feature row ``[cls | metadata]`` for *text*.

        When the bundle specified ``meta_columns``, the metadata is reordered to
        that list; columns missing from the computed frame are filled with 0.0
        and extra columns are dropped.
        """
        meta = _extended_meta_frame(text)
        if self.meta_columns:
            meta = meta.reindex(columns=self.meta_columns, fill_value=0.0)
        cls = self._cls_vector(text)
        return np.hstack([cls, meta.values.astype(float)])

    def predict_proba(self, text: str) -> float:
        """Return the classifier's probability for column 1 of ``predict_proba``.

        Column 1 is assumed to be the toxic class (i.e. ``clf.classes_`` is
        ordered with the positive label second) — TODO confirm against the
        bundle. Empty text is not special-cased here; use ``predict`` for that.
        """
        row = self._feature_row(text)
        scaled = self.scaler.transform(row)
        return float(self.clf.predict_proba(scaled)[0][1])

    def predict(self, text: str, *, threshold: float | None = None) -> dict[str, Any]:
        """Score *text* and return a JSON-friendly result dict.

        Args:
            text: comment to score. ``None``, empty, or whitespace-only text is
                not run through the model and yields a non-toxic result with
                probability 0.0.
            threshold: decision threshold; defaults to ``default_threshold``.

        Returns:
            ``{"is_toxic": bool, "probability": float, "labels": list[str],
            "recommended_threshold": float}``.
        """
        if not text or not text.strip():
            return {
                "is_toxic": False,
                "probability": 0.0,
                "labels": [],
                "recommended_threshold": self.default_threshold,
            }
        proba = self.predict_proba(text)
        thresh = self.default_threshold if threshold is None else threshold
        # bool() so a NumPy threshold cannot leak np.bool_ into the result
        # (np.bool_ fails `is True` checks and is not JSON-serialisable).
        tox = bool(proba >= thresh)
        return {
            "is_toxic": tox,
            "probability": proba,
            "labels": ["Offensive content"] if tox else [],
            "recommended_threshold": self.default_threshold,
        }