narutoSiskovich committed on
Commit
47caaf3
·
verified ·
1 Parent(s): 9380e7e

Create classifier.py

Browse files
Files changed (1) hide show
  1. classifier.py +46 -0
classifier.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ from typing import List
4
+ import torch
5
+ from transformers import AutoTokenizer, XLMRobertaForSequenceClassification
6
+
7
# === Configuration ===
# Hugging Face checkpoint name used to load both the tokenizer and the model.
# NOTE(review): this looks like the base pretrained checkpoint rather than one
# fine-tuned on CATEGORIES -- confirm that this is the intended model.
MODEL_NAME = "xlm-roberta-large"

# Label names; the entry at index i is reported for the model's i-th output.
CATEGORIES = [
    "politique",
    "woke",
    "racism",
    "crime",
    "police_abuse",
    "corruption",
    "hate_speech",
    "activism",
]
14
+
15
# === Model loading ===
# Tokenizer and model are created once at import time and shared by every
# classification call in this module.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# One output logit per entry in CATEGORIES; eval() switches the module to
# evaluation mode and returns the same model object.
model = XLMRobertaForSequenceClassification.from_pretrained(
    MODEL_NAME, num_labels=len(CATEGORIES)
).eval()

# === FastAPI application ===
app = FastAPI(title="Multilabel Text Classifier API")
25
+
26
# === Request schema ===
class TextRequest(BaseModel):
    """Request body accepted by the classification endpoint."""

    # Raw text to be classified.
    text: str
29
+
30
# === Classification logic ===
def classify_message(message: str, threshold: float = 0.5) -> List[str]:
    """Return the categories whose predicted probability exceeds *threshold*.

    Args:
        message: Text to classify. It is tokenized with ``truncation=True``,
            so over-long input is cut by the tokenizer rather than rejected.
        threshold: Minimum sigmoid probability (exclusive) a label needs in
            order to be reported. Defaults to 0.5.

    Returns:
        The matching names from ``CATEGORIES`` in their declared order, or
        ``["neutral"]`` when no label passes the threshold.
    """
    inputs = tokenizer(message, return_tensors="pt", truncation=True)
    with torch.no_grad():
        logits = model(**inputs).logits

    # Sigmoid scores each label independently (multilabel); [0] selects the
    # single message in the batch. tolist() yields plain Python floats so the
    # comparison below does not build one tensor per label.
    scores = torch.sigmoid(logits)[0].tolist()
    selected = [
        label
        for label, score in zip(CATEGORIES, scores)
        if score > threshold
    ]
    return selected or ["neutral"]
39
+
40
# === Endpoint ===
@app.post("/classify")
def classify(request: TextRequest):
    """Classify the text from the request body.

    Returns a dict with a single ``"categories"`` key holding the list
    produced by ``classify_message`` for ``request.text``.
    """
    return {"categories": classify_message(request.text)}