"""Multilabel text classification service built on FastAPI and XLM-RoBERTa.

Exposes a single ``POST /classify`` endpoint that scores a text against a
fixed set of categories and returns every category whose sigmoid probability
exceeds a threshold, or ``["neutral"]`` when none does.
"""

import os
from typing import List

import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, XLMRobertaForSequenceClassification

# === Configuration ===
# NOTE(review): the base "xlm-roberta-large" checkpoint ships without a trained
# sequence-classification head; `from_pretrained` initialises that head
# randomly, so predictions are meaningless until a fine-tuned checkpoint is
# used. The checkpoint can be overridden via the CLASSIFIER_MODEL_NAME
# environment variable; the default is unchanged.
MODEL_NAME = os.environ.get("CLASSIFIER_MODEL_NAME", "xlm-roberta-large")
CATEGORIES = [
    "politique",
    "woke",
    "racism",
    "crime",
    "police_abuse",
    "corruption",
    "hate_speech",
    "activism",
]
# Probability a label must strictly exceed to be selected (applied per label).
DEFAULT_THRESHOLD = 0.5
# Label returned when no category passes the threshold.
FALLBACK_LABEL = "neutral"

# === Model loading ===
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = XLMRobertaForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(CATEGORIES),
    # Each label is an independent yes/no decision (sigmoid, not softmax).
    problem_type="multi_label_classification",
    id2label=dict(enumerate(CATEGORIES)),
    label2id={label: idx for idx, label in enumerate(CATEGORIES)},
)
model.eval()

# === FastAPI application ===
app = FastAPI(title="Multilabel Text Classifier API")


# === Request schema ===
class TextRequest(BaseModel):
    """Request body for ``POST /classify``."""

    # Raw text to classify.
    text: str


# === Classification logic ===
def classify_message(message: str, threshold: float = DEFAULT_THRESHOLD) -> List[str]:
    """Return the categories predicted for *message*.

    Args:
        message: Raw input text. It is tokenized with ``truncation=True``, so
            text beyond the tokenizer's maximum length is dropped.
        threshold: Per-label probability a category must strictly exceed to
            be selected. Defaults to 0.5.

    Returns:
        The names from ``CATEGORIES`` (in that order) whose sigmoid
        probability is greater than ``threshold``, or ``["neutral"]`` if
        none is.
    """
    inputs = tokenizer(message, return_tensors="pt", truncation=True)
    with torch.no_grad():
        logits = model(**inputs).logits
    # A single message is a batch of one: take row 0 and convert it to plain
    # Python floats before comparing against the threshold.
    probs = torch.sigmoid(logits)[0].tolist()
    selected = [label for label, prob in zip(CATEGORIES, probs) if prob > threshold]
    return selected or [FALLBACK_LABEL]


# === Endpoint ===
@app.post("/classify")
def classify(request: TextRequest):
    """Classify ``request.text`` and return ``{"categories": [...]}``."""
    categories = classify_message(request.text)
    return {"categories": categories}