from __future__ import annotations

from dataclasses import dataclass
from functools import cached_property

from app.schemas import CategoryName, CategoryScore, ConversationTurn


@dataclass
class TransformerModerationModel:
    """Multi-label moderation classifier backed by a Hugging Face sequence model.

    The tokenizer and model are loaded lazily from ``artifact_path`` on first
    access (or eagerly via :meth:`warmup`) and cached on the instance.

    Attributes:
        artifact_path: Location passed to ``from_pretrained`` for both the
            tokenizer and the model.
        max_length: Token budget for the packed input fed to the model.
        min_score: Minimum sigmoid probability a category must reach to be
            reported.
        backend_name: Backend identifier (defaults to ``"transformer"``).
        context_turns: Number of most recent context turns packed ahead of
            the current message; ``0`` or less disables context.
    """

    artifact_path: str
    max_length: int = 256
    min_score: float = 0.35
    backend_name: str = "transformer"
    context_turns: int = 5

    @cached_property
    def tokenizer(self):
        """Tokenizer loaded from ``artifact_path`` on first access."""
        from transformers import AutoTokenizer

        return AutoTokenizer.from_pretrained(self.artifact_path)

    @cached_property
    def model(self):
        """Sequence-classification model loaded from ``artifact_path`` on first access."""
        from transformers import AutoModelForSequenceClassification

        return AutoModelForSequenceClassification.from_pretrained(self.artifact_path)

    @property
    def is_loaded(self) -> bool:
        """True once both the tokenizer and the model have been materialized."""
        # ``cached_property`` stores its value in the instance ``__dict__``.
        return "model" in self.__dict__ and "tokenizer" in self.__dict__

    def warmup(self) -> None:
        """Force-load the tokenizer and model so the first request is not slow."""
        _ = self.tokenizer
        _ = self.model

    def predict(
        self,
        text: str,
        context: list[ConversationTurn],
    ) -> list[CategoryScore]:
        """Score ``text`` (plus recent ``context``) against the moderation taxonomy.

        Each model output goes through a sigmoid, giving independent
        multi-label probabilities. Each output's label is mapped to a
        :class:`CategoryName` via :func:`normalize_label`. It is kept when it
        reaches ``min_score``. Several raw labels can map to the same category
        (e.g. ``toxic`` and ``insult``), so only the highest score per category
        is reported.

        Args:
            text: The message being moderated.
            context: Prior conversation turns; only the most recent
                ``context_turns`` are used.

        Returns:
            One ``CategoryScore`` per triggered category, highest score first.
        """
        import torch

        encoded = self._encode(text=text, context=context)
        with torch.no_grad():
            logits = self.model(**encoded).logits[0]
        probabilities = torch.sigmoid(logits).tolist()

        # ``id2label`` keys are normally ints; tolerate string keys and a
        # missing/None mapping by falling back to the index itself.
        id_to_label = getattr(self.model.config, "id2label", None) or {}
        best: dict[CategoryName, float] = {}
        for index, score in enumerate(probabilities):
            label = id_to_label.get(index, id_to_label.get(str(index), str(index)))
            category = normalize_label(label)
            if category is None or score < self.min_score:
                continue
            # Multiple raw labels may alias one category; keep the strongest
            # so callers never see duplicate category entries.
            if score > best.get(category, float("-inf")):
                best[category] = score

        results = [
            CategoryScore(
                name=category,
                score=round(float(score), 4),
                source="transformer",
                rationale=f"Model logits crossed threshold for {category.value}.",
            )
            for category, score in best.items()
        ]
        return sorted(results, key=lambda item: item.score, reverse=True)

    def _encode(self, text: str, context: list[ConversationTurn]):
        """Tokenize the packed input, shedding the oldest context turns first.

        The packed string ends with the current message. Plain right-side
        truncation would therefore cut off the very text being moderated once
        the context grows long. Instead, drop context turns (oldest first)
        until the sequence fits in ``max_length``. Hard truncation applies only
        when the current message alone is still too long.
        """
        window = self._context_window(context)
        while True:
            packed_text = self._pack_input(text=text, context=window)
            if not window:
                break
            encoded = self.tokenizer(
                packed_text,
                truncation=False,
                padding=False,
                return_tensors="pt",
            )
            if encoded["input_ids"].shape[-1] <= self.max_length:
                return encoded
            window = window[1:]
        return self.tokenizer(
            packed_text,
            truncation=True,
            padding=False,
            max_length=self.max_length,
            return_tensors="pt",
        )

    def _context_window(self, context: list[ConversationTurn]) -> list[ConversationTurn]:
        """Return the most recent ``context_turns`` turns (empty when disabled)."""
        # Guard explicitly: ``context[-0:]`` would return the *whole* list.
        if self.context_turns <= 0:
            return []
        return list(context[-self.context_turns:])

    def _pack_input(
        self,
        text: str,
        context: list[ConversationTurn],
    ) -> str:
        """Render recent context turns and the current message as one string.

        Context turns are numbered from 1 (oldest kept turn). The current
        message always comes last under a ``[CURRENT]`` tag.
        """
        lines = [
            f"[TURN-{offset}] {turn.role.value}: {turn.text}"
            for offset, turn in enumerate(self._context_window(context), start=1)
        ]
        lines.append(f"[CURRENT] user: {text}")
        return "\n".join(lines)
# Separator characters folded to "_" when normalizing raw model labels.
_LABEL_SEPARATORS = str.maketrans({"-": "_", "/": "_", " ": "_"})

# Normalized raw label -> moderation category. Built once at import time
# instead of on every ``normalize_label`` call.
_LABEL_ALIASES: dict[str, CategoryName] = {
    "harassment_or_insult": CategoryName.HARASSMENT_OR_INSULT,
    "insult": CategoryName.HARASSMENT_OR_INSULT,
    "harassment": CategoryName.HARASSMENT_OR_INSULT,
    "toxic": CategoryName.HARASSMENT_OR_INSULT,
    "severe_toxic": CategoryName.HARASSMENT_OR_INSULT,
    "threat_or_violence": CategoryName.THREAT_OR_VIOLENCE,
    "threat": CategoryName.THREAT_OR_VIOLENCE,
    "violence": CategoryName.THREAT_OR_VIOLENCE,
    "hate": CategoryName.HATE,
    "identity_hate": CategoryName.HATE,
    "self_harm": CategoryName.SELF_HARM,
    "sexual_explicit": CategoryName.SEXUAL_EXPLICIT,
    "sexual": CategoryName.SEXUAL_EXPLICIT,
    "obscene": CategoryName.PROFANITY,
    "profanity": CategoryName.PROFANITY,
    "spam": CategoryName.SPAM,
}


def normalize_label(label: str) -> CategoryName | None:
    """Map a raw classifier label to a moderation category.

    The label is stripped and lower-cased, and ``-``, ``/`` and spaces are
    replaced with ``_`` before lookup. For example, ``"Self-Harm"`` becomes
    ``"self_harm"``.

    Args:
        label: Label string as emitted by the model config.

    Returns:
        The matching ``CategoryName``, or ``None`` for unrecognized labels.
    """
    key = label.strip().lower().translate(_LABEL_SEPARATORS)
    return _LABEL_ALIASES.get(key)