narutoSiskovich commited on
Commit
1c8f881
·
verified ·
1 Parent(s): 6f3e861

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +131 -26
app.py CHANGED
@@ -12,7 +12,7 @@ from transformers import (
12
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
13
 
14
  # =====================
15
- # Agreement (MNLI)
16
  # =====================
17
  MNLI_MODEL = "facebook/bart-large-mnli"
18
  mnli_tokenizer = None
@@ -26,17 +26,28 @@ def load_mnli():
26
  mnli_model.to(DEVICE)
27
  mnli_model.eval()
28
 
29
- def check_agreement(msg1: str, msg2: str) -> float:
 
 
 
 
30
  load_mnli()
31
  inputs = mnli_tokenizer(msg1, msg2, return_tensors="pt", truncation=True).to(DEVICE)
32
  with torch.no_grad():
33
  logits = mnli_model(**inputs).logits
34
  probs = torch.softmax(logits, dim=-1)[0]
35
- # Считаем: entailment - contradiction
36
- return round((probs[2] - probs[0]).item(), 2)
 
 
 
 
 
 
 
37
 
38
  # =====================
39
- # Sentiment
40
  # =====================
41
  SENTIMENT_MODEL = "nlptown/bert-base-multilingual-uncased-sentiment"
42
  sent_tokenizer = None
@@ -51,17 +62,81 @@ def load_sentiment():
51
  sent_model.eval()
52
 
53
  def analyze_sentiment(text: str) -> float:
 
 
 
54
  load_sentiment()
55
  inputs = sent_tokenizer(text, return_tensors="pt", truncation=True).to(DEVICE)
56
  with torch.no_grad():
57
  logits = sent_model(**inputs).logits
58
  probs = torch.softmax(logits, dim=-1)
59
- stars = torch.argmax(probs, dim=-1).item() + 1
60
- # Приводим шкалу 1–5 к -5..+5
61
- return round((stars - 3) * 2.5, 2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
  # =====================
64
- # ZeroShot Classification
65
  # =====================
66
  ZS_MODEL = "facebook/bart-large-mnli"
67
  zs_classifier = None
@@ -81,40 +156,70 @@ def load_zero_shot():
81
  )
82
 
83
  def classify_message(text: str) -> dict:
 
 
 
 
84
  load_zero_shot()
85
- # Zero‑shot принимает список меток:
86
  result = zs_classifier(text, candidate_labels=CATEGORIES)
87
- scores = result["scores"]
88
  labels = result["labels"]
89
- # Возвращаем словарь {label: score}
90
- return {label: round(score, 3) for label, score in zip(labels, scores)}
 
 
 
 
 
 
91
 
92
  # =====================
93
  # Gradio UI
94
  # =====================
95
- with gr.Blocks(title="Unified NLP API") as demo:
96
- gr.Markdown("## 📈 Unified NLP API")
97
-
 
 
 
 
 
 
 
 
 
98
  # ----- Agreement Tab -----
99
- with gr.Tab("Agreement"):
100
  msg1 = gr.Textbox(label="Message 1")
101
  msg2 = gr.Textbox(label="Message 2")
 
102
  btn_agree = gr.Button("Check Agreement")
103
- out_agree = gr.Number(label="Agreement Score")
104
- btn_agree.click(fn=check_agreement, inputs=[msg1, msg2], outputs=out_agree)
105
-
 
 
 
 
 
106
  # ----- Sentiment Tab -----
107
- with gr.Tab("Sentiment"):
108
  text_sent = gr.Textbox(label="Text")
109
  btn_sent = gr.Button("Analyze Sentiment")
110
- out_sent = gr.Number(label="Sentiment Score (-5 to +5)")
111
  btn_sent.click(fn=analyze_sentiment, inputs=text_sent, outputs=out_sent)
112
-
113
- # ----- Multilabel (Zero‑Shot) Classification Tab -----
114
- with gr.Tab("Multilabel Classification"):
 
 
 
 
 
 
 
115
  text_clf = gr.Textbox(label="Text")
116
  btn_clf = gr.Button("Classify")
117
- out_clf = gr.Label(label="Categories & Scores")
118
  btn_clf.click(fn=classify_message, inputs=text_clf, outputs=out_clf)
119
 
120
  demo.launch()
 
12
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
13
 
14
  # =====================
15
+ # 1) Agreement (MNLI)
16
  # =====================
17
  MNLI_MODEL = "facebook/bart-large-mnli"
18
  mnli_tokenizer = None
 
26
  mnli_model.to(DEVICE)
27
  mnli_model.eval()
28
 
29
+ def agreement_raw_score(msg1: str, msg2: str) -> float:
30
+ """
31
+ Возвращает "сырое" согласие в диапазоне [-1..+1]
32
+ по формуле entailment - contradiction.
33
+ """
34
  load_mnli()
35
  inputs = mnli_tokenizer(msg1, msg2, return_tensors="pt", truncation=True).to(DEVICE)
36
  with torch.no_grad():
37
  logits = mnli_model(**inputs).logits
38
  probs = torch.softmax(logits, dim=-1)[0]
39
+ raw = (probs[2] - probs[0]).item() # [-1..+1]
40
+ return raw
41
+
42
+ def agreement_score_minus5_plus5(msg1: str, msg2: str) -> float:
43
+ """
44
+ Agreement в шкале [-5..+5]
45
+ """
46
+ raw = agreement_raw_score(msg1, msg2)
47
+ return round(raw * 5, 2)
48
 
49
  # =====================
50
+ # 2) Sentiment (-5..+5)
51
  # =====================
52
  SENTIMENT_MODEL = "nlptown/bert-base-multilingual-uncased-sentiment"
53
  sent_tokenizer = None
 
62
  sent_model.eval()
63
 
64
  def analyze_sentiment(text: str) -> float:
65
+ """
66
+ Модель даёт 1..5 звёзд -> переводим в [-5..+5]
67
+ """
68
  load_sentiment()
69
  inputs = sent_tokenizer(text, return_tensors="pt", truncation=True).to(DEVICE)
70
  with torch.no_grad():
71
  logits = sent_model(**inputs).logits
72
  probs = torch.softmax(logits, dim=-1)
73
+ stars = torch.argmax(probs, dim=-1).item() + 1 # 1..5
74
+ score = (stars - 3) * 2.5 # -5..+5
75
+ return round(score, 2)
76
+
77
+ # =====================
78
+ # 3) Sarcasm / Irony (-5..+5)
79
+ # =====================
80
+ # Можно заменить модель на другую, если хочешь.
81
+ # Эта модель популярна для сарказма.
82
+ SARCASM_MODEL = "cardiffnlp/twitter-roberta-base-irony"
83
+ sarcasm_pipe = None
84
+
85
+ def load_sarcasm():
86
+ global sarcasm_pipe
87
+ if sarcasm_pipe is None:
88
+ sarcasm_pipe = pipeline(
89
+ "text-classification",
90
+ model=SARCASM_MODEL,
91
+ device=0 if torch.cuda.is_available() else -1,
92
+ truncation=True
93
+ )
94
+
95
+ def sarcasm_score(text: str) -> float:
96
+ """
97
+ Возвращает рейтинг сарказма в [-5..+5]
98
+ (чем выше, тем больше сарказма/иронии)
99
+ """
100
+ load_sarcasm()
101
+ res = sarcasm_pipe(text)[0]
102
+ # Обычно метки: "irony" / "non_irony"
103
+ label = res["label"].lower()
104
+ conf = float(res["score"]) # 0..1
105
+
106
+ if "irony" in label:
107
+ # 0..1 -> 0..+5
108
+ return round(conf * 5, 2)
109
+ else:
110
+ # 0..1 -> 0..-5
111
+ return round(-conf * 5, 2)
112
+
113
+ # =====================
114
+ # 4) Agreement + Sarcasm
115
+ # =====================
116
+ def agreement_with_irony(msg1: str, msg2: str) -> float:
117
+ """
118
+ Идея:
119
+ - считаем agreement [-5..+5]
120
+ - считаем сарказм msg2 (обычно сарказм в ответе важнее)
121
+ - если сарказм высокий, уменьшаем "уверенность" agreement
122
+
123
+ Это НЕ идеальная логика, но работает лучше, чем игнорировать иронию.
124
+ """
125
+ base = agreement_score_minus5_plus5(msg1, msg2)
126
+
127
+ s2 = sarcasm_score(msg2) # [-5..+5]
128
+ sarcasm_strength = abs(s2) / 5.0 # 0..1
129
+
130
+ # Чем больше сарказм, тем сильнее "сжимаем" agreement к нулю
131
+ # 0 сарказма -> множитель 1
132
+ # сильный сарказм -> множитель ~0.35
133
+ multiplier = 1.0 - 0.65 * sarcasm_strength
134
+
135
+ final_score = base * multiplier
136
+ return round(final_score, 2)
137
 
138
  # =====================
139
+ # 5) Zero-Shot Multilabel -> [-5..+5]
140
  # =====================
141
  ZS_MODEL = "facebook/bart-large-mnli"
142
  zs_classifier = None
 
156
  )
157
 
158
  def classify_message(text: str) -> dict:
159
+ """
160
+ Возвращает рейтинг категорий в [-5..+5]
161
+ (0.5 = нейтрально, >0.5 = ближе к +5, <0.5 = ближе к -5)
162
+ """
163
  load_zero_shot()
 
164
  result = zs_classifier(text, candidate_labels=CATEGORIES)
 
165
  labels = result["labels"]
166
+ scores = result["scores"]
167
+
168
+ # score 0..1 -> [-5..+5]
169
+ out = {}
170
+ for label, score in zip(labels, scores):
171
+ rating = (float(score) - 0.5) * 10
172
+ out[label] = round(rating, 2)
173
+ return out
174
 
175
  # =====================
176
  # Gradio UI
177
  # =====================
178
+ with gr.Blocks(title="Unified NLP API (-5..+5)") as demo:
179
+ gr.Markdown("## 📈 Unified NLP API (all scores: -5 .. +5)")
180
+ gr.Markdown(
181
+ """
182
+ **Что есть что:**
183
+ - **Agreement**: -5 = сильное противоречие, +5 = сильное согласие
184
+ - **Sentiment**: -5 = негатив, +5 = позитив
185
+ - **Sarcasm**: -5 = уверенно *не сарказм*, +5 = уверенно *сарказм/ирония*
186
+ - **Categories**: рейтинг уверенности (0.5 → 0, 1.0 → +5, 0.0 → -5)
187
+ """
188
+ )
189
+
190
  # ----- Agreement Tab -----
191
+ with gr.Tab("Agreement (-5..+5)"):
192
  msg1 = gr.Textbox(label="Message 1")
193
  msg2 = gr.Textbox(label="Message 2")
194
+
195
  btn_agree = gr.Button("Check Agreement")
196
+ out_agree = gr.Number(label="Agreement Score (-5..+5)")
197
+ btn_agree.click(fn=agreement_score_minus5_plus5, inputs=[msg1, msg2], outputs=out_agree)
198
+
199
+ gr.Markdown("### Agreement with Irony adjustment")
200
+ btn_agree_irony = gr.Button("Check Agreement (with irony)")
201
+ out_agree_irony = gr.Number(label="Agreement Score (irony-aware) (-5..+5)")
202
+ btn_agree_irony.click(fn=agreement_with_irony, inputs=[msg1, msg2], outputs=out_agree_irony)
203
+
204
  # ----- Sentiment Tab -----
205
+ with gr.Tab("Sentiment (-5..+5)"):
206
  text_sent = gr.Textbox(label="Text")
207
  btn_sent = gr.Button("Analyze Sentiment")
208
+ out_sent = gr.Number(label="Sentiment Score (-5..+5)")
209
  btn_sent.click(fn=analyze_sentiment, inputs=text_sent, outputs=out_sent)
210
+
211
+ # ----- Sarcasm Tab -----
212
+ with gr.Tab("Sarcasm / Irony (-5..+5)"):
213
+ text_sarc = gr.Textbox(label="Text")
214
+ btn_sarc = gr.Button("Analyze Sarcasm")
215
+ out_sarc = gr.Number(label="Sarcasm Score (-5..+5)")
216
+ btn_sarc.click(fn=sarcasm_score, inputs=text_sarc, outputs=out_sarc)
217
+
218
+ # ----- Multilabel (Zero-Shot) Classification Tab -----
219
+ with gr.Tab("Multilabel Classification (-5..+5)"):
220
  text_clf = gr.Textbox(label="Text")
221
  btn_clf = gr.Button("Classify")
222
+ out_clf = gr.Label(label="Categories & Scores (-5..+5)")
223
  btn_clf.click(fn=classify_message, inputs=text_clf, outputs=out_clf)
224
 
225
  demo.launch()