Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,10 +1,6 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
import torch
|
| 3 |
-
from transformers import
|
| 4 |
-
AutoTokenizer,
|
| 5 |
-
AutoModelForSequenceClassification,
|
| 6 |
-
pipeline,
|
| 7 |
-
)
|
| 8 |
|
| 9 |
# =====================
|
| 10 |
# DEVICE
|
|
@@ -19,7 +15,6 @@ def clamp(x: float, lo: float = -5.0, hi: float = 5.0) -> float:
|
|
| 19 |
|
| 20 |
def score01_to_minus5_plus5(p: float) -> float:
|
| 21 |
"""
|
| 22 |
-
Перевод вероятности 0..1 в шкалу -5..+5:
|
| 23 |
0.0 -> -5
|
| 24 |
0.5 -> 0
|
| 25 |
1.0 -> +5
|
|
@@ -43,8 +38,8 @@ def load_mnli():
|
|
| 43 |
|
| 44 |
def agreement_score_minus5_plus5(msg1: str, msg2: str) -> float:
|
| 45 |
"""
|
| 46 |
-
|
| 47 |
-
|
| 48 |
"""
|
| 49 |
load_mnli()
|
| 50 |
inputs = mnli_tokenizer(msg1, msg2, return_tensors="pt", truncation=True).to(DEVICE)
|
|
@@ -52,7 +47,8 @@ def agreement_score_minus5_plus5(msg1: str, msg2: str) -> float:
|
|
| 52 |
logits = mnli_model(**inputs).logits
|
| 53 |
probs = torch.softmax(logits, dim=-1)[0]
|
| 54 |
|
| 55 |
-
|
|
|
|
| 56 |
return round(clamp(raw * 5), 2)
|
| 57 |
|
| 58 |
# =====================
|
|
@@ -79,7 +75,7 @@ def analyze_sentiment(text: str) -> float:
|
|
| 79 |
with torch.no_grad():
|
| 80 |
logits = sent_model(**inputs).logits
|
| 81 |
probs = torch.softmax(logits, dim=-1)
|
| 82 |
-
stars = torch.argmax(probs, dim=-1).item() + 1
|
| 83 |
score = (stars - 3) * 2.5
|
| 84 |
return round(clamp(score), 2)
|
| 85 |
|
|
@@ -96,18 +92,18 @@ def load_sarcasm():
|
|
| 96 |
"text-classification",
|
| 97 |
model=SARCASM_MODEL,
|
| 98 |
device=0 if torch.cuda.is_available() else -1,
|
| 99 |
-
truncation=True
|
| 100 |
)
|
| 101 |
|
| 102 |
def sarcasm_score(text: str) -> float:
|
| 103 |
"""
|
| 104 |
-
+5 =
|
| 105 |
-
-5 =
|
| 106 |
"""
|
| 107 |
load_sarcasm()
|
| 108 |
res = sarcasm_pipe(text)[0]
|
| 109 |
label = res["label"].lower()
|
| 110 |
-
conf = float(res["score"])
|
| 111 |
|
| 112 |
if "irony" in label:
|
| 113 |
return round(clamp(conf * 5), 2)
|
|
@@ -116,15 +112,14 @@ def sarcasm_score(text: str) -> float:
|
|
| 116 |
# =====================
|
| 117 |
# 4) Neutrality -> [-5..+5]
|
| 118 |
# =====================
|
| 119 |
-
# +5 = максимально нейтрально
|
| 120 |
-
# -5 = максимально заряжено/эмоционально/полярно
|
| 121 |
-
#
|
| 122 |
-
# Простая логика:
|
| 123 |
-
# neutrality = 5 - (|sentiment| + max(0, sarcasm))/2
|
| 124 |
-
# (сарказм делает текст менее нейтральным)
|
| 125 |
def neutrality_score(text: str) -> float:
|
| 126 |
-
|
| 127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
neutrality = 5.0 - (sent + sarc) / 2.0
|
| 129 |
return round(clamp(neutrality), 2)
|
| 130 |
|
|
@@ -132,24 +127,22 @@ def neutrality_score(text: str) -> float:
|
|
| 132 |
# 5) Agreement with irony adjustment
|
| 133 |
# =====================
|
| 134 |
def agreement_with_irony(msg1: str, msg2: str) -> float:
|
| 135 |
-
"""
|
| 136 |
-
Ирония снижает "уверенность" agreement.
|
| 137 |
-
"""
|
| 138 |
base = agreement_score_minus5_plus5(msg1, msg2)
|
|
|
|
| 139 |
s2 = max(0.0, sarcasm_score(msg2)) # 0..5
|
| 140 |
sarcasm_strength = s2 / 5.0 # 0..1
|
| 141 |
|
|
|
|
| 142 |
multiplier = 1.0 - 0.65 * sarcasm_strength
|
| 143 |
final_score = base * multiplier
|
| 144 |
return round(clamp(final_score), 2)
|
| 145 |
|
| 146 |
# =====================
|
| 147 |
-
# 6) Zero-Shot
|
| 148 |
# =====================
|
| 149 |
ZS_MODEL = "facebook/bart-large-mnli"
|
| 150 |
zs_classifier = None
|
| 151 |
|
| 152 |
-
# Твои категории + расширение под Twitter/X дискуссии
|
| 153 |
CATEGORIES = [
|
| 154 |
# базовые
|
| 155 |
"politique",
|
|
@@ -192,14 +185,10 @@ def load_zero_shot():
|
|
| 192 |
zs_classifier = pipeline(
|
| 193 |
"zero-shot-classification",
|
| 194 |
model=ZS_MODEL,
|
| 195 |
-
device=0 if torch.cuda.is_available() else -1
|
| 196 |
)
|
| 197 |
|
| 198 |
def classify_message(text: str) -> dict:
|
| 199 |
-
"""
|
| 200 |
-
Возвращает словарь {label: rating} где rating в [-5..+5]
|
| 201 |
-
Важно: это не "истина", а "уверенность модели" относительно метки.
|
| 202 |
-
"""
|
| 203 |
load_zero_shot()
|
| 204 |
result = zs_classifier(text, candidate_labels=CATEGORIES, multi_label=True)
|
| 205 |
|
|
@@ -223,11 +212,10 @@ with gr.Blocks(title="Unified NLP API (-5..+5)") as demo:
|
|
| 223 |
- **Sentiment**: -5 = негатив, +5 = позитив
|
| 224 |
- **Sarcasm**: -5 = уверенно НЕ сарказм, +5 = уверенно сарказм/ирония
|
| 225 |
- **Neutrality**: +5 = максимально нейтрально, -5 = максимально “заряжено”
|
| 226 |
-
- **Multilabel**:
|
| 227 |
"""
|
| 228 |
)
|
| 229 |
|
| 230 |
-
# Agreement
|
| 231 |
with gr.Tab("Agreement"):
|
| 232 |
msg1 = gr.Textbox(label="Message 1")
|
| 233 |
msg2 = gr.Textbox(label="Message 2")
|
|
@@ -241,30 +229,28 @@ with gr.Blocks(title="Unified NLP API (-5..+5)") as demo:
|
|
| 241 |
out_agree_irony = gr.Number(label="Agreement Score (irony-aware) (-5..+5)")
|
| 242 |
btn_agree_irony.click(fn=agreement_with_irony, inputs=[msg1, msg2], outputs=out_agree_irony)
|
| 243 |
|
| 244 |
-
# Sentiment
|
| 245 |
with gr.Tab("Sentiment"):
|
| 246 |
text_sent = gr.Textbox(label="Text")
|
| 247 |
btn_sent = gr.Button("Analyze Sentiment")
|
| 248 |
out_sent = gr.Number(label="Sentiment Score (-5..+5)")
|
| 249 |
btn_sent.click(fn=analyze_sentiment, inputs=text_sent, outputs=out_sent)
|
| 250 |
|
| 251 |
-
# Sarcasm
|
| 252 |
with gr.Tab("Sarcasm / Irony"):
|
| 253 |
text_sarc = gr.Textbox(label="Text")
|
| 254 |
btn_sarc = gr.Button("Analyze Sarcasm")
|
| 255 |
out_sarc = gr.Number(label="Sarcasm Score (-5..+5)")
|
| 256 |
btn_sarc.click(fn=sarcasm_score, inputs=text_sarc, outputs=out_sarc)
|
| 257 |
|
| 258 |
-
# Neutrality
|
| 259 |
with gr.Tab("Neutrality"):
|
| 260 |
text_neu = gr.Textbox(label="Text")
|
| 261 |
btn_neu = gr.Button("Analyze Neutrality")
|
| 262 |
out_neu = gr.Number(label="Neutrality Score (-5..+5)")
|
| 263 |
btn_neu.click(fn=neutrality_score, inputs=text_neu, outputs=out_neu)
|
| 264 |
|
| 265 |
-
# Multilabel
|
| 266 |
with gr.Tab("Multilabel Classification"):
|
| 267 |
text_clf = gr.Textbox(label="Text")
|
| 268 |
btn_clf = gr.Button("Classify")
|
| 269 |
out_clf = gr.Label(label="Categories & Scores (-5..+5)")
|
| 270 |
-
btn_clf.click(fn=classify_message, inputs=text_clf, outputs=
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import torch
|
| 3 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
# =====================
|
| 6 |
# DEVICE
|
|
|
|
| 15 |
|
| 16 |
def score01_to_minus5_plus5(p: float) -> float:
|
| 17 |
"""
|
|
|
|
| 18 |
0.0 -> -5
|
| 19 |
0.5 -> 0
|
| 20 |
1.0 -> +5
|
|
|
|
| 38 |
|
| 39 |
def agreement_score_minus5_plus5(msg1: str, msg2: str) -> float:
|
| 40 |
"""
|
| 41 |
+
-5 = contradiction
|
| 42 |
+
+5 = entailment
|
| 43 |
"""
|
| 44 |
load_mnli()
|
| 45 |
inputs = mnli_tokenizer(msg1, msg2, return_tensors="pt", truncation=True).to(DEVICE)
|
|
|
|
| 47 |
logits = mnli_model(**inputs).logits
|
| 48 |
probs = torch.softmax(logits, dim=-1)[0]
|
| 49 |
|
| 50 |
+
# entailment - contradiction => [-1..+1]
|
| 51 |
+
raw = (probs[2] - probs[0]).item()
|
| 52 |
return round(clamp(raw * 5), 2)
|
| 53 |
|
| 54 |
# =====================
|
|
|
|
| 75 |
with torch.no_grad():
|
| 76 |
logits = sent_model(**inputs).logits
|
| 77 |
probs = torch.softmax(logits, dim=-1)
|
| 78 |
+
stars = torch.argmax(probs, dim=-1).item() + 1
|
| 79 |
score = (stars - 3) * 2.5
|
| 80 |
return round(clamp(score), 2)
|
| 81 |
|
|
|
|
| 92 |
"text-classification",
|
| 93 |
model=SARCASM_MODEL,
|
| 94 |
device=0 if torch.cuda.is_available() else -1,
|
| 95 |
+
truncation=True,
|
| 96 |
)
|
| 97 |
|
| 98 |
def sarcasm_score(text: str) -> float:
|
| 99 |
"""
|
| 100 |
+
+5 = irony
|
| 101 |
+
-5 = non-irony
|
| 102 |
"""
|
| 103 |
load_sarcasm()
|
| 104 |
res = sarcasm_pipe(text)[0]
|
| 105 |
label = res["label"].lower()
|
| 106 |
+
conf = float(res["score"])
|
| 107 |
|
| 108 |
if "irony" in label:
|
| 109 |
return round(clamp(conf * 5), 2)
|
|
|
|
| 112 |
# =====================
|
| 113 |
# 4) Neutrality -> [-5..+5]
|
| 114 |
# =====================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
def neutrality_score(text: str) -> float:
    """Rate how neutral *text* is; a higher value means more neutral.

    The score starts at 5.0 and is lowered by the mean of two "charge"
    terms: the magnitude of the sentiment score and the sarcasm score
    floored at zero (a negative sarcasm score contributes nothing).
    The result is clamped and rounded to two decimals.

    Intended reading of the scale:
        +5 = maximally neutral
        -5 = maximally emotional / charged
    """
    sentiment_magnitude = abs(analyze_sentiment(text))
    # Only irony counts against neutrality; "not ironic" scores are floored to 0.
    irony_component = max(0.0, sarcasm_score(text))
    mean_charge = (sentiment_magnitude + irony_component) / 2.0
    # NOTE(review): if both terms stay within 0..5 (as the -5..+5 scale of the
    # helper scores suggests), the result cannot drop below 0, so the -5 end of
    # the advertised range looks unreachable -- confirm whether this is intended.
    return round(clamp(5.0 - mean_charge), 2)
|
| 125 |
|
|
|
|
| 127 |
# 5) Agreement with irony adjustment
|
| 128 |
# =====================
|
| 129 |
def agreement_with_irony(msg1: str, msg2: str) -> float:
    """Return the agreement score of the two messages, damped by irony in *msg2*.

    The base agreement score is multiplied by ``1 - 0.65 * s``, where ``s``
    is the sarcasm score of *msg2* floored at zero and divided by 5. A
    non-ironic reply therefore leaves the base score untouched, while a
    sarcasm score of 5 keeps roughly 35% of it. The product is clamped and
    rounded to two decimals.
    """
    plain_agreement = agreement_score_minus5_plus5(msg1, msg2)
    # Negative sarcasm scores are treated as "no irony" (fraction 0).
    irony_fraction = max(0.0, sarcasm_score(msg2)) / 5.0
    # The more sarcastic the reply, the less the agreement score is trusted.
    damping = 1.0 - 0.65 * irony_fraction
    return round(clamp(plain_agreement * damping), 2)
|
| 139 |
|
| 140 |
# =====================
|
| 141 |
+
# 6) Multilabel Zero-Shot -> [-5..+5]
|
| 142 |
# =====================
|
| 143 |
ZS_MODEL = "facebook/bart-large-mnli"
|
| 144 |
zs_classifier = None
|
| 145 |
|
|
|
|
| 146 |
CATEGORIES = [
|
| 147 |
# базовые
|
| 148 |
"politique",
|
|
|
|
| 185 |
zs_classifier = pipeline(
|
| 186 |
"zero-shot-classification",
|
| 187 |
model=ZS_MODEL,
|
| 188 |
+
device=0 if torch.cuda.is_available() else -1,
|
| 189 |
)
|
| 190 |
|
| 191 |
def classify_message(text: str) -> dict:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
load_zero_shot()
|
| 193 |
result = zs_classifier(text, candidate_labels=CATEGORIES, multi_label=True)
|
| 194 |
|
|
|
|
| 212 |
- **Sentiment**: -5 = негатив, +5 = позитив
|
| 213 |
- **Sarcasm**: -5 = уверенно НЕ сарказм, +5 = уверенно сарказм/ирония
|
| 214 |
- **Neutrality**: +5 = максимально нейтрально, -5 = максимально “заряжено”
|
| 215 |
+
- **Multilabel**: уверенность метки в шкале -5..+5 (0.5 → 0)
|
| 216 |
"""
|
| 217 |
)
|
| 218 |
|
|
|
|
| 219 |
with gr.Tab("Agreement"):
|
| 220 |
msg1 = gr.Textbox(label="Message 1")
|
| 221 |
msg2 = gr.Textbox(label="Message 2")
|
|
|
|
| 229 |
out_agree_irony = gr.Number(label="Agreement Score (irony-aware) (-5..+5)")
|
| 230 |
btn_agree_irony.click(fn=agreement_with_irony, inputs=[msg1, msg2], outputs=out_agree_irony)
|
| 231 |
|
|
|
|
| 232 |
with gr.Tab("Sentiment"):
|
| 233 |
text_sent = gr.Textbox(label="Text")
|
| 234 |
btn_sent = gr.Button("Analyze Sentiment")
|
| 235 |
out_sent = gr.Number(label="Sentiment Score (-5..+5)")
|
| 236 |
btn_sent.click(fn=analyze_sentiment, inputs=text_sent, outputs=out_sent)
|
| 237 |
|
|
|
|
| 238 |
with gr.Tab("Sarcasm / Irony"):
|
| 239 |
text_sarc = gr.Textbox(label="Text")
|
| 240 |
btn_sarc = gr.Button("Analyze Sarcasm")
|
| 241 |
out_sarc = gr.Number(label="Sarcasm Score (-5..+5)")
|
| 242 |
btn_sarc.click(fn=sarcasm_score, inputs=text_sarc, outputs=out_sarc)
|
| 243 |
|
|
|
|
| 244 |
with gr.Tab("Neutrality"):
|
| 245 |
text_neu = gr.Textbox(label="Text")
|
| 246 |
btn_neu = gr.Button("Analyze Neutrality")
|
| 247 |
out_neu = gr.Number(label="Neutrality Score (-5..+5)")
|
| 248 |
btn_neu.click(fn=neutrality_score, inputs=text_neu, outputs=out_neu)
|
| 249 |
|
|
|
|
| 250 |
with gr.Tab("Multilabel Classification"):
|
| 251 |
text_clf = gr.Textbox(label="Text")
|
| 252 |
btn_clf = gr.Button("Classify")
|
| 253 |
out_clf = gr.Label(label="Categories & Scores (-5..+5)")
|
| 254 |
+
btn_clf.click(fn=classify_message, inputs=text_clf, outputs=out_clf)
|
| 255 |
+
|
| 256 |
+
demo.launch()
|