# classifier/classifier.py — author: narutoSiskovich
# Commit 47caaf3 (verified): "Create classifier.py" · 1.32 kB
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List
import torch
from transformers import AutoTokenizer, XLMRobertaForSequenceClassification
# === Configuration ===
# Model id (or local path) handed to both from_pretrained() calls below.
MODEL_NAME = "xlm-roberta-large"
# Output labels in classifier-head order: probability i produced by the model
# is reported under CATEGORIES[i] (the head is sized with len(CATEGORIES)).
# NOTE(review): "politique" is French while the other labels are English —
# confirm this is the intended wire value before clients depend on it.
CATEGORIES = [
    "politique",
    "woke",
    "racism",
    "crime",
    "police_abuse",
    "corruption",
    "hate_speech",
    "activism",
]
# === Model loading ===
# NOTE(review): the public "xlm-roberta-large" checkpoint is a pretrained base
# model without a sequence-classification head, so from_pretrained() attaches a
# freshly initialised head here (transformers logs a "newly initialized weights"
# warning). Predictions are not meaningful until the model is fine-tuned or
# MODEL_NAME points at a fine-tuned multi-label checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = XLMRobertaForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(CATEGORIES),
    # Store the real label names in the config so model.config.id2label agrees
    # with CATEGORIES (instead of generic LABEL_0..LABEL_7) and is preserved by
    # save_pretrained().
    id2label=dict(enumerate(CATEGORIES)),
    label2id={label: idx for idx, label in enumerate(CATEGORIES)},
    # Every category is an independent yes/no decision (classify_message applies
    # a sigmoid per logit), so declare a multi-label task; this makes the model
    # use BCEWithLogitsLoss if it is ever trained with labels.
    problem_type="multi_label_classification",
)
# Inference only: switch off dropout.
model.eval()
# === FastAPI application ===
# The title is what FastAPI shows in the generated OpenAPI schema and /docs UI.
app = FastAPI(
    title="Multilabel Text Classifier API",
)
# === Request schema ===
# JSON body accepted by the classify endpoint: {"text": "..."}.
# Kept free of a docstring on purpose: pydantic would copy it into the
# OpenAPI schema description.
class TextRequest(BaseModel):
    # Text to classify. No length or emptiness constraint is enforced here;
    # overly long inputs are truncated later by the tokenizer.
    text: str
# === Classification logic ===
def classify_message(message: str, threshold: float = 0.5) -> List[str]:
    """Return every category whose predicted probability exceeds *threshold*.

    The text is tokenized (with truncation to the tokenizer's maximum length),
    passed through ``model`` without gradient tracking, and each logit is
    squashed independently with a sigmoid, so any number of categories can be
    selected at once (multi-label).

    Args:
        message: Text to classify.
        threshold: Per-category probability cut-off; a category is selected
            only when its probability is strictly greater than this value.
            Defaults to 0.5, the value that used to be hard-coded.

    Returns:
        Selected names from ``CATEGORIES`` in their declared order, or
        ``["neutral"]`` when no category clears the threshold.
    """
    inputs = tokenizer(message, return_tensors="pt", truncation=True)
    with torch.no_grad():
        logits = model(**inputs).logits
    # A single string was tokenized, so take row 0 of the batch; .tolist()
    # yields plain Python floats so the comparison below does not build one
    # 0-dim tensor per category.
    probs = torch.sigmoid(logits)[0].tolist()
    # The head was created with num_labels=len(CATEGORIES), so the two
    # sequences line up one-to-one.
    selected = [
        label for label, prob in zip(CATEGORIES, probs) if prob > threshold
    ]
    return selected or ["neutral"]
# === Endpoint ===
# POST /classify: body {"text": "..."} -> {"categories": [...]}.
# Declared with plain `def` (not `async def`), so FastAPI executes it in its
# threadpool and the blocking model call does not stall the event loop.
# Documented with comments rather than a docstring, because FastAPI would
# publish a docstring as the operation description in OpenAPI.
@app.post("/classify")
def classify(request: TextRequest):
    return {"categories": classify_message(request.text)}