vineet88's picture
Deploy standalone ML service
16f57d9 verified
from __future__ import annotations
from dataclasses import dataclass
from functools import cached_property
from app.schemas import CategoryName, CategoryScore, ConversationTurn
@dataclass
class TransformerModerationModel:
    """Moderation backend that scores text with a sequence-classification model.

    The tokenizer and model are loaded lazily from ``artifact_path`` the first
    time they are accessed and are then cached on the instance.
    """

    # Path or identifier handed to ``from_pretrained`` for tokenizer and model.
    artifact_path: str
    # Passed to the tokenizer as ``max_length`` (truncation is enabled).
    max_length: int = 256
    # Categories whose sigmoid score is below this value are not reported.
    min_score: float = 0.35
    # Name of this backend; not read by any method of this class.
    backend_name: str = "transformer"
    # How many of the most recent context turns are packed ahead of the
    # current message. Zero or a negative value means "no context".
    max_context_turns: int = 5

    @cached_property
    def tokenizer(self):
        """Load the tokenizer from ``artifact_path`` (cached after first use)."""
        # Imported lazily so importing this module does not pull in transformers.
        from transformers import AutoTokenizer

        return AutoTokenizer.from_pretrained(self.artifact_path)

    @cached_property
    def model(self):
        """Load the classification model from ``artifact_path`` (cached)."""
        from transformers import AutoModelForSequenceClassification

        return AutoModelForSequenceClassification.from_pretrained(self.artifact_path)

    @property
    def is_loaded(self) -> bool:
        """Return True once both the model and the tokenizer have been loaded."""
        # ``cached_property`` stores its result in the instance ``__dict__``,
        # so key presence tells us whether loading happened without
        # triggering a load ourselves.
        return "model" in self.__dict__ and "tokenizer" in self.__dict__

    def warmup(self) -> None:
        """Force the tokenizer and the model to load now."""
        _ = self.tokenizer
        _ = self.model

    def predict(
        self,
        text: str,
        context: list[ConversationTurn],
    ) -> list[CategoryScore]:
        """Score ``text`` (together with recent ``context``) per category.

        Args:
            text: The current message to moderate.
            context: Earlier conversation turns; only the most recent
                ``max_context_turns`` are used.

        Returns:
            At most one ``CategoryScore`` per category whose score is at
            least ``min_score``, ordered from highest to lowest score.
        """
        import torch

        packed_text = self._pack_input(text=text, context=context)
        # NOTE(review): truncation drops tokens from the tokenizer's
        # truncation side (presumably the right, i.e. the end of the packed
        # text, where the [CURRENT] message lives). Confirm this matches how
        # the model was trained before changing it.
        encoded = self.tokenizer(
            packed_text,
            truncation=True,
            padding=False,
            max_length=self.max_length,
            return_tensors="pt",
        )
        with torch.no_grad():
            logits = self.model(**encoded).logits[0]
        # Each logit is squashed independently (sigmoid, not softmax), so
        # several labels can exceed the threshold at once.
        probabilities = torch.sigmoid(logits).tolist()

        # ``id2label`` may be absent or explicitly None on the config.
        id_to_label = getattr(self.model.config, "id2label", None) or {}

        # Several model labels can normalise to the same category; keep only
        # the strongest hit per category so callers never see duplicates.
        best_by_category: dict[CategoryName, CategoryScore] = {}
        for index, score in enumerate(probabilities):
            if score < self.min_score:
                continue
            label = id_to_label.get(index, str(index))
            normalized_label = normalize_label(label)
            if normalized_label is None:
                continue
            rounded_score = round(float(score), 4)
            current = best_by_category.get(normalized_label)
            if current is not None and current.score >= rounded_score:
                continue
            best_by_category[normalized_label] = CategoryScore(
                name=normalized_label,
                score=rounded_score,
                source="transformer",
                rationale=f"Model logits crossed threshold for {normalized_label.value}.",
            )
        return sorted(
            best_by_category.values(),
            key=lambda category: category.score,
            reverse=True,
        )

    def _pack_input(
        self,
        text: str,
        context: list[ConversationTurn],
    ) -> str:
        """Flatten recent context plus the current message into one string.

        Each retained context turn becomes ``[TURN-k] <role>: <text>`` with
        ``k`` counting up from 1 for the oldest retained turn; the current
        message is appended last as ``[CURRENT] user: <text>``. Lines are
        joined with newlines.
        """
        # Tolerate a missing context, and avoid ``context[-0:]`` (which would
        # select every turn) when the window is zero or negative.
        recent_turns = list(context or [])
        if self.max_context_turns > 0:
            recent_turns = recent_turns[-self.max_context_turns:]
        else:
            recent_turns = []

        turns = [
            f"[TURN-{offset}] {turn.role.value}: {turn.text}"
            for offset, turn in enumerate(recent_turns, start=1)
        ]
        turns.append(f"[CURRENT] user: {text}")
        return "\n".join(turns)
def normalize_label(label: str) -> CategoryName | None:
    """Translate a raw model label into a ``CategoryName``, if one is known.

    The label is stripped of surrounding whitespace, lower-cased, and has
    ``-``, ``/`` and spaces turned into underscores before it is matched
    against the recognised aliases.

    Args:
        label: Label text as reported by the model.

    Returns:
        The matching ``CategoryName`` member, or ``None`` when the cleaned-up
        label is not a recognised alias.
    """
    separators = str.maketrans({"-": "_", "/": "_", " ": "_"})
    key = label.strip().lower().translate(separators)

    # Aliases grouped by the category they collapse into.
    alias_groups = (
        (
            CategoryName.HARASSMENT_OR_INSULT,
            ("harassment_or_insult", "insult", "harassment", "toxic", "severe_toxic"),
        ),
        (
            CategoryName.THREAT_OR_VIOLENCE,
            ("threat_or_violence", "threat", "violence"),
        ),
        (CategoryName.HATE, ("hate", "identity_hate")),
        (CategoryName.SELF_HARM, ("self_harm",)),
        (CategoryName.SEXUAL_EXPLICIT, ("sexual_explicit", "sexual")),
        (CategoryName.PROFANITY, ("obscene", "profanity")),
        (CategoryName.SPAM, ("spam",)),
    )
    for category, aliases in alias_groups:
        if key in aliases:
            return category
    return None