narutoSiskovich commited on
Commit
4d26922
·
verified ·
1 Parent(s): 99aa3de

Delete classifier.py

Browse files
Files changed (1) hide show
  1. classifier.py +0 -46
classifier.py DELETED
@@ -1,46 +0,0 @@
1
- from fastapi import FastAPI
2
- from pydantic import BaseModel
3
- from typing import List
4
- import torch
5
- from transformers import AutoTokenizer, XLMRobertaForSequenceClassification
6
-
7
- # === Конфигурация ===
8
- MODEL_NAME = "xlm-roberta-large"
9
-
10
- CATEGORIES = [
11
- "politique", "woke", "racism", "crime",
12
- "police_abuse", "corruption", "hate_speech", "activism"
13
- ]
14
-
15
- # === Загрузка модели ===
16
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
17
- model = XLMRobertaForSequenceClassification.from_pretrained(
18
- MODEL_NAME,
19
- num_labels=len(CATEGORIES)
20
- )
21
- model.eval()
22
-
23
- # === FastAPI приложение ===
24
- app = FastAPI(title="Multilabel Text Classifier API")
25
-
26
- # === Схема запроса ===
27
- class TextRequest(BaseModel):
28
- text: str
29
-
30
- # === Логика классификации ===
31
- def classify_message(message: str) -> List[str]:
32
- inputs = tokenizer(message, return_tensors="pt", truncation=True)
33
- with torch.no_grad():
34
- logits = model(**inputs).logits
35
-
36
- probs = torch.sigmoid(logits)[0]
37
- selected = [CATEGORIES[i] for i, p in enumerate(probs) if p > 0.5]
38
- return selected or ["neutral"]
39
-
40
- # === Эндпоинт ===
41
- @app.post("/classify")
42
- def classify(request: TextRequest):
43
- categories = classify_message(request.text)
44
- return {
45
- "categories": categories
46
- }