narutoSiskovich commited on
Commit
2d1205b
·
verified ·
1 Parent(s): a72dc55

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -12
app.py CHANGED
@@ -1,12 +1,98 @@
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
-
4
- from agreement_score import check_agreement
5
- from sentimental import analyze_sentiment
6
- from classifier import classify_message
 
 
 
7
 
8
  app = FastAPI(title="Unified NLP API")
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  class AgreementRequest(BaseModel):
11
  msg1: str
12
  msg2: str
@@ -15,19 +101,17 @@ class TextRequest(BaseModel):
15
  text: str
16
 
17
 
 
 
 
18
  @app.post("/agreement")
19
  def agreement(req: AgreementRequest):
20
- score = check_agreement(req.msg1, req.msg2)
21
- return {"agreement_score": score}
22
-
23
 
24
  @app.post("/sentiment")
25
  def sentiment(req: TextRequest):
26
- score = analyze_sentiment(req.text)
27
- return {"sentiment_score": score}
28
-
29
 
30
  @app.post("/classify")
31
  def classify(req: TextRequest):
32
- categories = classify_message(req.text)
33
- return {"categories": categories}
 
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
+ from typing import List
4
+ import torch
5
+ from transformers import (
6
+ AutoTokenizer,
7
+ AutoModelForSequenceClassification,
8
+ XLMRobertaForSequenceClassification,
9
+ )
10
 
11
  app = FastAPI(title="Unified NLP API")
12
 
13
+ # =====================
14
+ # Agreement (MNLI)
15
+ # =====================
16
+ MNLI_MODEL = "facebook/bart-base-mnli"
17
+ mnli_tokenizer = None
18
+ mnli_model = None
19
+
20
+ def load_mnli():
21
+ global mnli_tokenizer, mnli_model
22
+ if mnli_model is None:
23
+ mnli_tokenizer = AutoTokenizer.from_pretrained(MNLI_MODEL)
24
+ mnli_model = AutoModelForSequenceClassification.from_pretrained(MNLI_MODEL)
25
+ mnli_model.eval()
26
+
27
+ def check_agreement(msg1: str, msg2: str) -> float:
28
+ load_mnli()
29
+ inputs = mnli_tokenizer(msg1, msg2, return_tensors="pt", truncation=True)
30
+ with torch.no_grad():
31
+ logits = mnli_model(**inputs).logits
32
+ probs = torch.softmax(logits, dim=-1)[0]
33
+ return round((probs[2] - probs[0]).item(), 2) # entailment - contradiction
34
+
35
+
36
+ # =====================
37
+ # Sentiment
38
+ # =====================
39
+ SENTIMENT_MODEL = "nlptown/bert-base-multilingual-uncased-sentiment"
40
+ sent_tokenizer = None
41
+ sent_model = None
42
+
43
+ def load_sentiment():
44
+ global sent_tokenizer, sent_model
45
+ if sent_model is None:
46
+ sent_tokenizer = AutoTokenizer.from_pretrained(SENTIMENT_MODEL)
47
+ sent_model = AutoModelForSequenceClassification.from_pretrained(SENTIMENT_MODEL)
48
+ sent_model.eval()
49
+
50
+ def analyze_sentiment(text: str) -> float:
51
+ load_sentiment()
52
+ inputs = sent_tokenizer(text, return_tensors="pt", truncation=True)
53
+ with torch.no_grad():
54
+ logits = sent_model(**inputs).logits
55
+ probs = torch.softmax(logits, dim=-1)
56
+ stars = torch.argmax(probs, dim=-1).item() + 1
57
+ return round((stars - 3) * 2.5, 2) # -5 .. +5
58
+
59
+
60
+ # =====================
61
+ # Multilabel classifier
62
+ # =====================
63
+ CLASSIFIER_MODEL = "xlm-roberta-base"
64
+
65
+ CATEGORIES = [
66
+ "politique", "woke", "racism", "crime",
67
+ "police_abuse", "corruption", "hate_speech", "activism"
68
+ ]
69
+
70
+ clf_tokenizer = None
71
+ clf_model = None
72
+
73
+ def load_classifier():
74
+ global clf_tokenizer, clf_model
75
+ if clf_model is None:
76
+ clf_tokenizer = AutoTokenizer.from_pretrained(CLASSIFIER_MODEL)
77
+ clf_model = XLMRobertaForSequenceClassification.from_pretrained(
78
+ CLASSIFIER_MODEL,
79
+ num_labels=len(CATEGORIES)
80
+ )
81
+ clf_model.eval()
82
+
83
+ def classify_message(text: str) -> List[str]:
84
+ load_classifier()
85
+ inputs = clf_tokenizer(text, return_tensors="pt", truncation=True)
86
+ with torch.no_grad():
87
+ logits = clf_model(**inputs).logits
88
+ probs = torch.sigmoid(logits)[0]
89
+ labels = [CATEGORIES[i] for i, p in enumerate(probs) if p > 0.5]
90
+ return labels or ["neutral"]
91
+
92
+
93
+ # =====================
94
+ # API schemas
95
+ # =====================
96
  class AgreementRequest(BaseModel):
97
  msg1: str
98
  msg2: str
 
101
  text: str
102
 
103
 
104
+ # =====================
105
+ # Endpoints
106
+ # =====================
107
  @app.post("/agreement")
108
  def agreement(req: AgreementRequest):
109
+ return {"agreement_score": check_agreement(req.msg1, req.msg2)}
 
 
110
 
111
  @app.post("/sentiment")
112
  def sentiment(req: TextRequest):
113
+ return {"sentiment_score": analyze_sentiment(req.text)}
 
 
114
 
115
  @app.post("/classify")
116
  def classify(req: TextRequest):
117
+ return {"categories": classify_message(req.text)}