# SignalMod/src/service/meta_stack_predictor.py
# Mirae Kang
# feat: implement new models and improve UI, #23
# commit 46cc63a
# (file-viewer residue: raw / history blame / 3.74 kB)
"""Inference for Notebook 14 meta-feature stacking (frozen CLS + metadata + LR)."""
from __future__ import annotations
import json
import re
from pathlib import Path
from typing import Any
import joblib
import numpy as np
import pandas as pd
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from src.features.metadata_features import extract_metadata_features
# Hugging Face Hub id of the frozen encoder loaded when the caller gives none.
MODEL_ID = "unitary/toxic-bert"

# Matches a *run* of one or more consecutive characters drawn from the listed
# code-point ranges, so findall() yields one hit per run rather than per
# character. The second range lies entirely inside the first; it is kept so the
# compiled pattern string stays exactly what it has always been.
_EMOJI_PAT = re.compile(
    "[\U0001f300-\U0001f9ff\U0001f600-\U0001f64f]+",
    flags=re.UNICODE,
)
def _extended_meta_frame(text: str) -> pd.DataFrame:
    """Build the one-row metadata feature frame for a single comment.

    The base columns come from ``extract_metadata_features`` applied to a
    one-row frame whose ``"Text"`` column holds *text*. Two further columns,
    ``emoji_count`` and ``punctuation_density``, are added here, and every
    column is cast to ``float`` before the frame is returned.
    """
    single_row = pd.DataFrame({"Text": [text]})
    features = extract_metadata_features(single_row, text_column="Text").copy()

    # One entry per match of the module-level emoji pattern.
    emoji_hits = _EMOJI_PAT.findall(text)
    # Characters that are neither word characters nor whitespace.
    punctuation_hits = re.findall(r"[^\w\s]", text)
    # Clamp the denominator so an empty string gives 0.0 rather than a
    # division by zero.
    denominator = max(len(text), 1)

    features["emoji_count"] = len(emoji_hits)
    features["punctuation_density"] = len(punctuation_hits) / denominator
    return features.astype(float)
class MetaStackPredictor:
    """Load the production bundle and score a single comment.

    The feature vector for a comment is the frozen encoder's first-token
    hidden state concatenated with the metadata columns produced by
    ``_extended_meta_frame``. That vector is passed through the bundle's
    ``scaler`` and then the bundle's ``clf``.
    """

    # Decision threshold used when the manifest is missing or carries no
    # usable "threshold" entry.
    DEFAULT_THRESHOLD = 0.381

    def __init__(
        self,
        bundle_path: Path,
        *,
        manifest_path: Path | None = None,
        frozen_model_id: str = MODEL_ID,
    ) -> None:
        """Load the bundle, the optional manifest and the frozen encoder.

        Args:
            bundle_path: joblib file holding a mapping with ``"scaler"`` and
                ``"clf"`` entries and, optionally, ``"meta_columns"``.
            manifest_path: Optional JSON file; its ``"threshold"`` entry, when
                present and not null, becomes ``default_threshold``. A path
                that is not an existing file is ignored.
            frozen_model_id: Model id or path handed to ``from_pretrained`` for
                both the tokenizer and the encoder.

        Raises:
            KeyError: If the bundle lacks ``"scaler"`` or ``"clf"``.
        """
        self.bundle_path = bundle_path
        self.frozen_model_id = frozen_model_id

        self.manifest: dict[str, Any] = {}
        if manifest_path and manifest_path.is_file():
            self.manifest = json.loads(manifest_path.read_text(encoding="utf-8"))

        # SECURITY: joblib.load unpickles its input and can therefore execute
        # arbitrary code. Only point bundle_path at trusted artefacts.
        bundle = joblib.load(bundle_path)
        self.scaler = bundle["scaler"]
        self.clf = bundle["clf"]
        # "or []" also covers a bundle that stores meta_columns as None.
        self.meta_columns: list[str] = list(bundle.get("meta_columns") or [])

        # A null threshold in the manifest is treated like a missing one
        # instead of crashing in float(None).
        manifest_threshold = self.manifest.get("threshold")
        self.default_threshold = float(
            self.DEFAULT_THRESHOLD if manifest_threshold is None else manifest_threshold
        )

        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._tokenizer = AutoTokenizer.from_pretrained(frozen_model_id)
        self._bert = AutoModelForSequenceClassification.from_pretrained(frozen_model_id)
        # The encoder is used purely as a fixed feature extractor.
        for param in self._bert.parameters():
            param.requires_grad = False
        self._bert.eval()
        self._bert.to(self._device)

    def _cls_vector(self, text: str) -> np.ndarray:
        """Return the encoder's first-token hidden state for *text*.

        The text is tokenised as a batch of one (truncated to 128 tokens), so
        the returned array has a leading batch axis of length 1.
        """
        enc = self._tokenizer(
            [text],
            truncation=True,
            max_length=128,
            padding=True,
            return_tensors="pt",
        )
        enc = {key: tensor.to(self._device) for key, tensor in enc.items()}
        with torch.no_grad():
            # base_model resolves to the headless encoder of whichever
            # architecture frozen_model_id names (it is `.bert` for the default
            # checkpoint), so non-BERT checkpoints are not locked out by a
            # hard-coded attribute name.
            hidden = self._bert.base_model(**enc).last_hidden_state
        return hidden[:, 0, :].cpu().numpy()

    def _feature_row(self, text: str) -> np.ndarray:
        """Assemble the one-row feature matrix: encoder vector, then metadata.

        When the bundle lists ``meta_columns``, the metadata frame is reindexed
        to exactly those columns in that order, with 0.0 for any column the
        extractor did not produce.
        """
        meta = _extended_meta_frame(text)
        if self.meta_columns:
            meta = meta.reindex(columns=self.meta_columns, fill_value=0.0)
        cls = self._cls_vector(text)
        return np.hstack([cls, meta.values.astype(float)])

    def predict_proba(self, text: str) -> float:
        """Return the classifier's probability for class index 1 on *text*."""
        row = self._feature_row(text)
        scaled = self.scaler.transform(row)
        return float(self.clf.predict_proba(scaled)[0][1])

    def predict(self, text: str, *, threshold: float | None = None) -> dict[str, Any]:
        """Classify *text* and return a plain-Python result dictionary.

        Args:
            text: Comment to score. Empty or whitespace-only text is reported
                as non-toxic with probability 0.0 without running the model.
            threshold: Decision threshold; ``default_threshold`` when ``None``.

        Returns:
            A dict with ``"is_toxic"`` (bool), ``"probability"`` (float),
            ``"labels"`` (list of str) and ``"recommended_threshold"`` (float).
        """
        if not text or not text.strip():
            return {
                "is_toxic": False,
                "probability": 0.0,
                "labels": [],
                "recommended_threshold": self.default_threshold,
            }
        proba = self.predict_proba(text)
        thresh = self.default_threshold if threshold is None else float(threshold)
        # bool() keeps the flag a builtin bool even when the caller passed a
        # NumPy scalar threshold, so the result stays JSON-serialisable.
        tox = bool(proba >= thresh)
        return {
            "is_toxic": tox,
            "probability": proba,
            "labels": ["Offensive content"] if tox else [],
            "recommended_threshold": self.default_threshold,
        }