narutoSiskovich commited on
Commit
f6f0b87
·
verified ·
1 Parent(s): 5e8a79a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -47
app.py CHANGED
@@ -1,6 +1,5 @@
1
- from fastapi import FastAPI
2
- from pydantic import BaseModel
3
- from typing import List
4
  import torch
5
  from transformers import (
6
  AutoTokenizer,
@@ -8,8 +7,6 @@ from transformers import (
8
  XLMRobertaForSequenceClassification,
9
  )
10
 
11
- app = FastAPI(title="Unified NLP API")
12
-
13
  # =====================
14
  # Agreement (MNLI)
15
  # =====================
@@ -32,7 +29,6 @@ def check_agreement(msg1: str, msg2: str) -> float:
32
  probs = torch.softmax(logits, dim=-1)[0]
33
  return round((probs[2] - probs[0]).item(), 2) # entailment - contradiction
34
 
35
-
36
  # =====================
37
  # Sentiment
38
  # =====================
@@ -56,17 +52,14 @@ def analyze_sentiment(text: str) -> float:
56
  stars = torch.argmax(probs, dim=-1).item() + 1
57
  return round((stars - 3) * 2.5, 2) # -5 .. +5
58
 
59
-
60
  # =====================
61
  # Multilabel classifier
62
  # =====================
63
  CLASSIFIER_MODEL = "xlm-roberta-base"
64
-
65
  CATEGORIES = [
66
  "politique", "woke", "racism", "crime",
67
  "police_abuse", "corruption", "hate_speech", "activism"
68
  ]
69
-
70
  clf_tokenizer = None
71
  clf_model = None
72
 
@@ -76,11 +69,12 @@ def load_classifier():
76
  clf_tokenizer = AutoTokenizer.from_pretrained(CLASSIFIER_MODEL)
77
  clf_model = XLMRobertaForSequenceClassification.from_pretrained(
78
  CLASSIFIER_MODEL,
79
- num_labels=len(CATEGORIES)
 
80
  )
81
  clf_model.eval()
82
 
83
- def classify_message(text: str) -> List[str]:
84
  load_classifier()
85
  inputs = clf_tokenizer(text, return_tensors="pt", truncation=True)
86
  with torch.no_grad():
@@ -89,42 +83,31 @@ def classify_message(text: str) -> List[str]:
89
  labels = [CATEGORIES[i] for i, p in enumerate(probs) if p > 0.5]
90
  return labels or ["neutral"]
91
 
92
-
93
- # =====================
94
- # API schemas
95
- # =====================
96
- class AgreementRequest(BaseModel):
97
- msg1: str
98
- msg2: str
99
-
100
- class TextRequest(BaseModel):
101
- text: str
102
-
103
-
104
  # =====================
105
- # Endpoints
106
  # =====================
107
- @app.post("/agreement")
108
- def agreement(req: AgreementRequest):
109
- return {"agreement_score": check_agreement(req.msg1, req.msg2)}
110
-
111
- @app.post("/sentiment")
112
- def sentiment(req: TextRequest):
113
- return {"sentiment_score": analyze_sentiment(req.text)}
114
-
115
- @app.post("/classify")
116
- def classify(req: TextRequest):
117
- return {"categories": classify_message(req.text)}
118
-
119
- @app.get("/")
120
- def root():
121
- return {
122
- "status": "ok",
123
- "endpoints": {
124
- "POST /sentiment": "sentiment analysis",
125
- "POST /agreement": "text agreement",
126
- "POST /classify": "multilabel classification",
127
- "GET /docs": "swagger UI"
128
- }
129
- }
130
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import gradio as gr
 
3
  import torch
4
  from transformers import (
5
  AutoTokenizer,
 
7
  XLMRobertaForSequenceClassification,
8
  )
9
 
 
 
10
  # =====================
11
  # Agreement (MNLI)
12
  # =====================
 
29
  probs = torch.softmax(logits, dim=-1)[0]
30
  return round((probs[2] - probs[0]).item(), 2) # entailment - contradiction
31
 
 
32
  # =====================
33
  # Sentiment
34
  # =====================
 
52
  stars = torch.argmax(probs, dim=-1).item() + 1
53
  return round((stars - 3) * 2.5, 2) # -5 .. +5
54
 
 
55
  # =====================
56
  # Multilabel classifier
57
  # =====================
58
  CLASSIFIER_MODEL = "xlm-roberta-base"
 
59
  CATEGORIES = [
60
  "politique", "woke", "racism", "crime",
61
  "police_abuse", "corruption", "hate_speech", "activism"
62
  ]
 
63
  clf_tokenizer = None
64
  clf_model = None
65
 
 
69
  clf_tokenizer = AutoTokenizer.from_pretrained(CLASSIFIER_MODEL)
70
  clf_model = XLMRobertaForSequenceClassification.from_pretrained(
71
  CLASSIFIER_MODEL,
72
+ num_labels=len(CATEGORIES),
73
+ problem_type="multi_label_classification"
74
  )
75
  clf_model.eval()
76
 
77
+ def classify_message(text: str) -> list:
78
  load_classifier()
79
  inputs = clf_tokenizer(text, return_tensors="pt", truncation=True)
80
  with torch.no_grad():
 
83
  labels = [CATEGORIES[i] for i, p in enumerate(probs) if p > 0.5]
84
  return labels or ["neutral"]
85
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  # =====================
87
+ # Gradio interfaces
88
  # =====================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
+ with gr.Blocks(title="Unified NLP API") as demo:
91
+
92
+ gr.Markdown("## 📈 Unified NLP API")
93
+
94
+ with gr.Tab("Agreement"):
95
+ msg1 = gr.Textbox(label="Message 1")
96
+ msg2 = gr.Textbox(label="Message 2")
97
+ btn_agree = gr.Button("Check Agreement")
98
+ out_agree = gr.Number(label="Agreement Score")
99
+ btn_agree.click(fn=check_agreement, inputs=[msg1, msg2], outputs=out_agree)
100
+
101
+ with gr.Tab("Sentiment"):
102
+ text_sent = gr.Textbox(label="Text")
103
+ btn_sent = gr.Button("Analyze Sentiment")
104
+ out_sent = gr.Number(label="Sentiment Score (-5 to +5)")
105
+ btn_sent.click(fn=analyze_sentiment, inputs=text_sent, outputs=out_sent)
106
+
107
+ with gr.Tab("Multilabel Classification"):
108
+ text_clf = gr.Textbox(label="Text")
109
+ btn_clf = gr.Button("Classify")
110
+ out_clf = gr.Label(label="Categories")
111
+ btn_clf.click(fn=classify_message, inputs=text_clf, outputs=out_clf)
112
+
113
+ demo.launch()