File size: 1,323 Bytes
47caaf3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import os
from typing import List

import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, XLMRobertaForSequenceClassification

# === Configuration ===
# Hugging Face model id or local path of the checkpoint to serve. It can be
# overridden through the CLASSIFIER_MODEL_NAME environment variable so that a
# fine-tuned checkpoint can be deployed without editing code; the default is
# unchanged ("xlm-roberta-large").
MODEL_NAME = os.environ.get("CLASSIFIER_MODEL_NAME", "xlm-roberta-large")

# Output labels. Position i names logit i of the classifier head (the head is
# built with num_labels=len(CATEGORIES) and results are read back by index).
# NOTE(review): when serving a fine-tuned checkpoint, this order must match
# the label order it was trained with — confirm.
CATEGORIES = [
    "politique",
    "woke",
    "racism",
    "crime",
    "police_abuse",
    "corruption",
    "hate_speech",
    "activism",
]

# === Model loading ===
# Runs once at import time; the resulting tokenizer and model are module
# globals shared by every request.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# NOTE(review): "xlm-roberta-large" is a pre-trained base checkpoint with no
# sequence-classification head, so the head created here is newly initialized
# and its outputs are not meaningful until MODEL_NAME points to a checkpoint
# fine-tuned on CATEGORIES — confirm which checkpoint production uses.
model = XLMRobertaForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(CATEGORIES),
    # Declare the task as multi-label: the model then uses a sigmoid/BCE loss
    # when labels are supplied, matching the per-label sigmoid applied at
    # inference time.
    problem_type="multi_label_classification",
    # Attach human-readable label names to the config, in CATEGORIES order.
    id2label=dict(enumerate(CATEGORIES)),
    label2id={label: idx for idx, label in enumerate(CATEGORIES)},
)
# Inference mode (e.g. disables dropout).
model.eval()

# === FastAPI application ===
# Single ASGI app; the /classify route is registered on it further below.
app = FastAPI(
    title="Multilabel Text Classifier API",
)

# === Request schema ===
class TextRequest(BaseModel):
    """JSON body accepted by ``POST /classify``."""

    # Raw message to be classified.
    text: str

# === Classification logic ===
def classify_message(message: str, threshold: float = 0.5) -> List[str]:
    """Return every category whose predicted probability exceeds ``threshold``.

    The message is tokenized with truncation enabled and run through ``model``
    without gradient tracking. Each logit is turned into an independent
    probability with a sigmoid, so any number of categories may be selected
    at once (multi-label).

    Args:
        message: Raw input text.
        threshold: Strict lower bound on a category's sigmoid probability
            for it to be selected. Defaults to 0.5, the original cut-off.

    Returns:
        The selected names from ``CATEGORIES`` in their declared order, or
        ``["neutral"]`` when no category clears the threshold.
    """
    # Put the encoded tensors on the model's device so that moving the model
    # (e.g. to a GPU) does not cause a device-mismatch error.
    inputs = tokenizer(message, return_tensors="pt", truncation=True).to(
        model.device
    )
    with torch.no_grad():
        logits = model(**inputs).logits

    # Batch of one: take row 0 and convert to plain Python floats.
    probs = torch.sigmoid(logits)[0].tolist()
    selected = [
        category for category, prob in zip(CATEGORIES, probs) if prob > threshold
    ]
    return selected or ["neutral"]

# === Endpoint ===
@app.post("/classify")
def classify(request: TextRequest):
    """Classify ``request.text`` and respond with ``{"categories": [...]}``.

    Declared with a plain ``def`` so FastAPI runs the blocking model call in
    its threadpool rather than on the event loop.
    """
    return {"categories": classify_message(request.text)}