Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from functools import cached_property | |
| from app.schemas import CategoryName, CategoryScore, ConversationTurn | |
@dataclass
class TransformerModerationModel:
    """Moderation backend that scores text with a Hugging Face classifier.

    The tokenizer and model are loaded lazily from ``artifact_path`` the first
    time they are accessed and then cached on the instance, so constructing
    the object is cheap and ``warmup`` can force the load ahead of traffic.

    Attributes:
        artifact_path: Directory or hub id passed to ``from_pretrained`` for
            both the tokenizer and the model.
        max_length: Token budget for the packed input; longer inputs are
            truncated by the tokenizer.
        min_score: Per-label sigmoid probability below which a category is
            not reported.
        backend_name: Name identifying this backend.
        max_context_turns: Number of most recent context turns packed in
            front of the current message (``0`` or less disables context).
    """

    artifact_path: str
    max_length: int = 256
    min_score: float = 0.35
    backend_name: str = "transformer"
    # Trailing field with a default so existing positional/keyword
    # construction keeps working; 5 matches the previously hard-coded window.
    max_context_turns: int = 5

    @cached_property
    def tokenizer(self):
        """Load the tokenizer from ``artifact_path`` (once per instance).

        ``truncation_side="left"`` makes the tokenizer drop tokens from the
        start of the packed text. ``_pack_input`` puts the current message
        last, so with right-side truncation (the library default) a long
        conversation history could push the message being moderated out of
        the model input entirely.
        """
        from transformers import AutoTokenizer

        return AutoTokenizer.from_pretrained(
            self.artifact_path,
            truncation_side="left",
        )

    @cached_property
    def model(self):
        """Load the sequence-classification model (once per instance)."""
        from transformers import AutoModelForSequenceClassification

        return AutoModelForSequenceClassification.from_pretrained(self.artifact_path)

    def is_loaded(self) -> bool:
        """Return True once both the tokenizer and the model have been loaded."""
        # ``cached_property`` stores the computed value in the instance
        # ``__dict__`` under the property name, so presence there means the
        # loader has already run.
        return "model" in self.__dict__ and "tokenizer" in self.__dict__

    def warmup(self) -> None:
        """Force the lazy tokenizer and model loads."""
        _ = self.tokenizer
        _ = self.model

    def predict(
        self,
        text: str,
        context: list[ConversationTurn],
    ) -> list[CategoryScore]:
        """Score ``text`` (with recent ``context``) against the model's labels.

        Each logit is passed through an independent sigmoid (multi-label).
        Raw label names are mapped to ``CategoryName`` via ``normalize_label``.
        Labels that do not map are dropped, as are labels scoring below
        ``min_score``. When several raw labels map to the same category, only
        the highest-scoring one is reported.

        Args:
            text: The message being moderated.
            context: Prior conversation turns; only the most recent
                ``max_context_turns`` are used.

        Returns:
            One ``CategoryScore`` per detected category, sorted by descending
            score.
        """
        import torch

        packed_text = self._pack_input(text=text, context=context)
        encoded = self.tokenizer(
            packed_text,
            truncation=True,
            padding=False,
            max_length=self.max_length,
            return_tensors="pt",
        )
        with torch.no_grad():
            # Single (unbatched) input, so take the first row of logits.
            logits = self.model(**encoded).logits[0]
        probabilities = torch.sigmoid(logits).tolist()

        # Guard against a config whose id2label is missing or explicitly None.
        id_to_label = getattr(self.model.config, "id2label", None) or {}

        # Several raw labels can map to one category (e.g. "toxic",
        # "severe_toxic" and "insult"); keep only the strongest score for each.
        best_scores: dict[CategoryName, float] = {}
        for index, score in enumerate(probabilities):
            normalized_label = normalize_label(id_to_label.get(index, str(index)))
            if normalized_label is None or score < self.min_score:
                continue
            if score > best_scores.get(normalized_label, float("-inf")):
                best_scores[normalized_label] = score

        results = [
            CategoryScore(
                name=category,
                score=round(float(score), 4),
                source="transformer",
                rationale=f"Model logits crossed threshold for {category.value}.",
            )
            for category, score in best_scores.items()
        ]
        return sorted(results, key=lambda category: category.score, reverse=True)

    def _pack_input(
        self,
        text: str,
        context: list[ConversationTurn],
    ) -> str:
        """Serialize recent context turns plus the current message into one string.

        Context turns are numbered from 1 (oldest kept) and emitted one per
        line as ``[TURN-n] <role>: <text>``. The current message is always
        the final line, ``[CURRENT] user: <text>``.
        """
        # ``context[-0:]`` would return the whole list, so treat a
        # non-positive window as "no context" explicitly.
        window = self.max_context_turns
        recent = context[-window:] if window > 0 else []
        lines = [
            f"[TURN-{offset}] {turn.role.value}: {turn.text}"
            for offset, turn in enumerate(recent, start=1)
        ]
        lines.append(f"[CURRENT] user: {text}")
        return "\n".join(lines)
# Separator characters treated as word breaks in raw model label names.
_LABEL_SEPARATORS = str.maketrans({"-": " ", "/": " ", "_": " "})

# Raw (normalized) model label -> moderation category. Built once at import
# time instead of on every normalize_label call.
_LABEL_TO_CATEGORY: dict[str, CategoryName] = {
    "harassment_or_insult": CategoryName.HARASSMENT_OR_INSULT,
    "insult": CategoryName.HARASSMENT_OR_INSULT,
    "harassment": CategoryName.HARASSMENT_OR_INSULT,
    "toxic": CategoryName.HARASSMENT_OR_INSULT,
    "severe_toxic": CategoryName.HARASSMENT_OR_INSULT,
    "threat_or_violence": CategoryName.THREAT_OR_VIOLENCE,
    "threat": CategoryName.THREAT_OR_VIOLENCE,
    "violence": CategoryName.THREAT_OR_VIOLENCE,
    "hate": CategoryName.HATE,
    "identity_hate": CategoryName.HATE,
    "self_harm": CategoryName.SELF_HARM,
    "sexual_explicit": CategoryName.SEXUAL_EXPLICIT,
    "sexual": CategoryName.SEXUAL_EXPLICIT,
    "obscene": CategoryName.PROFANITY,
    "profanity": CategoryName.PROFANITY,
    "spam": CategoryName.SPAM,
}


def normalize_label(label: str) -> CategoryName | None:
    """Map a raw model label name to a ``CategoryName``.

    The label is lower-cased and every run of separators (``-``, ``/``,
    ``_`` and whitespace) is collapsed into a single ``_``. Leading and
    trailing separators are discarded. For example ``"Severe-Toxic"``,
    ``"severe  toxic"`` and ``"severe_toxic"`` all normalize to
    ``"severe_toxic"``.

    Args:
        label: Label name as reported by the model config.

    Returns:
        The matching category, or ``None`` if the label is not recognized.
    """
    # ``str.split()`` with no argument splits on runs of whitespace and drops
    # empty pieces, which both collapses repeated separators and strips ends.
    normalized = "_".join(label.translate(_LABEL_SEPARATORS).lower().split())
    return _LABEL_TO_CATEGORY.get(normalized)