classifier / app.py
narutoSiskovich's picture
Update app.py
f0cf89f verified
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
# =====================
# DEVICE
# =====================
# Run inference on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# =====================
# Helpers
# =====================
def clamp(x: float, lo: float = -5.0, hi: float = 5.0) -> float:
    """Restrict ``x`` to the closed interval ``[lo, hi]``.

    Args:
        x: Value to limit.
        lo: Lower bound (defaults to -5.0).
        hi: Upper bound (defaults to 5.0).

    Returns:
        ``lo`` when ``x`` is below it, ``hi`` when ``x`` is above it,
        otherwise ``x`` itself.
    """
    # Cap from above first, then raise to the floor.
    capped = min(hi, x)
    return max(lo, capped)
def score01_to_minus5_plus5(p: float) -> float:
    """Map a probability-like value onto the unified [-5..+5] scale.

    0.0 -> -5
    0.5 ->  0
    1.0 -> +5

    Args:
        p: Value to convert; it is passed through ``float()`` first.

    Returns:
        ``(p - 0.5) * 10`` limited to [-5, +5] by ``clamp``.
    """
    centered = float(p) - 0.5
    return clamp(centered * 10)
# =====================
# 1) Agreement (MNLI) -> [-5..+5]
# =====================
MNLI_MODEL = "facebook/bart-large-mnli"
# Filled in by load_mnli() on first use; both stay None until then.
mnli_tokenizer = None
mnli_model = None


def load_mnli():
    """Load the MNLI tokenizer and model once and keep them in module globals.

    Does nothing when ``mnli_model`` is already set. Otherwise it loads
    ``MNLI_MODEL``, moves the model to ``DEVICE`` and switches it to eval mode.
    """
    global mnli_tokenizer, mnli_model
    if mnli_model is not None:
        return
    mnli_tokenizer = AutoTokenizer.from_pretrained(MNLI_MODEL)
    mnli_model = AutoModelForSequenceClassification.from_pretrained(MNLI_MODEL)
    mnli_model.to(DEVICE)
    mnli_model.eval()
def agreement_score_minus5_plus5(msg1: str, msg2: str) -> float:
    """Score how strongly ``msg2`` agrees with ``msg1`` on the [-5..+5] scale.

    -5 = contradiction
    +5 = entailment

    Args:
        msg1: First message, passed to the tokenizer as the first sequence.
        msg2: Second message, passed to the tokenizer as the second sequence.

    Returns:
        Five times (entailment probability - contradiction probability),
        clamped to [-5, +5] and rounded to 2 decimals.
    """
    load_mnli()
    encoded = mnli_tokenizer(msg1, msg2, return_tensors="pt", truncation=True).to(DEVICE)
    with torch.no_grad():
        output = mnli_model(**encoded)
    class_probs = torch.softmax(output.logits, dim=-1)[0]
    # NOTE(review): index 0 is read as "contradiction" and index 2 as
    # "entailment" -- confirm against the label order of the checkpoint.
    contradiction_p = class_probs[0]
    entailment_p = class_probs[2]
    # entailment - contradiction lies in [-1..+1]; scale it to [-5..+5].
    difference = (entailment_p - contradiction_p).item()
    return round(clamp(difference * 5), 2)
# =====================
# 2) Sentiment -> [-5..+5]
# =====================
SENTIMENT_MODEL = "nlptown/bert-base-multilingual-uncased-sentiment"
# Filled in by load_sentiment() on first use; both stay None until then.
sent_tokenizer = None
sent_model = None


def load_sentiment():
    """Load the sentiment tokenizer and model once into module globals.

    Does nothing when ``sent_model`` is already set. Otherwise it loads
    ``SENTIMENT_MODEL``, moves the model to ``DEVICE`` and switches it to
    eval mode.
    """
    global sent_tokenizer, sent_model
    if sent_model is not None:
        return
    sent_tokenizer = AutoTokenizer.from_pretrained(SENTIMENT_MODEL)
    sent_model = AutoModelForSequenceClassification.from_pretrained(SENTIMENT_MODEL)
    sent_model.to(DEVICE)
    sent_model.eval()
def analyze_sentiment(text: str) -> float:
    """Rate the sentiment of ``text`` on the unified [-5..+5] scale.

    The most probable class index (0-based) is turned into a 1..5 star rating
    and mapped linearly: 1 -> -5, 2 -> -2.5, 3 -> 0, 4 -> +2.5, 5 -> +5.

    Args:
        text: Message to analyse.

    Returns:
        The mapped score, clamped to [-5, +5] and rounded to 2 decimals.
    """
    load_sentiment()
    encoded = sent_tokenizer(text, return_tensors="pt", truncation=True).to(DEVICE)
    with torch.no_grad():
        output = sent_model(**encoded)
    distribution = torch.softmax(output.logits, dim=-1)
    # Class indices start at 0, star ratings at 1.
    star_rating = torch.argmax(distribution, dim=-1).item() + 1
    mapped = (star_rating - 3) * 2.5
    return round(clamp(mapped), 2)
# =====================
# 3) Sarcasm / Irony -> [-5..+5]
# =====================
SARCASM_MODEL = "cardiffnlp/twitter-roberta-base-irony"
# Filled in by load_sarcasm() on first use; stays None until then.
sarcasm_pipe = None


def load_sarcasm():
    """Create the irony text-classification pipeline once and cache it.

    Does nothing when ``sarcasm_pipe`` is already set. The pipeline is built
    for ``SARCASM_MODEL`` with truncation enabled, on device 0 when CUDA is
    available and on device -1 otherwise.
    """
    global sarcasm_pipe
    if sarcasm_pipe is not None:
        return
    if torch.cuda.is_available():
        device_index = 0
    else:
        device_index = -1
    sarcasm_pipe = pipeline(
        "text-classification",
        model=SARCASM_MODEL,
        device=device_index,
        truncation=True,
    )
def sarcasm_score(text: str) -> float:
    """Score how ironic/sarcastic ``text`` is on the unified [-5..+5] scale.

    +5 = confidently irony
    -5 = confidently non-irony

    The confidence of the pipeline's first returned prediction is scaled by 5
    and signed by that prediction's label: positive for the irony class,
    negative otherwise.

    Args:
        text: Message to classify.

    Returns:
        Signed confidence, clamped to [-5, +5] and rounded to 2 decimals.
    """
    load_sarcasm()
    top = sarcasm_pipe(text)[0]
    label = top["label"].strip().lower()
    confidence = float(top["score"])
    # A bare `"irony" in label` test is also true for negated labels such as
    # "non_irony", which would score every prediction as ironic. Reject
    # negated labels explicitly.
    negated = label.startswith(("non", "not"))
    # NOTE(review): a generic "label_1" is treated as the irony class,
    # following the TweetEval irony mapping (0 = non_irony, 1 = irony) --
    # confirm against the model config.
    is_irony = label == "label_1" or ("irony" in label and not negated)
    signed = confidence if is_irony else -confidence
    return round(clamp(signed * 5), 2)
# =====================
# 4) Neutrality -> [-5..+5]
# =====================
def neutrality_score(text: str) -> float:
    """Score how neutral ``text`` is on the unified [-5..+5] scale.

    +5 = maximally neutral
    -5 = maximally emotional / charged

    The "charge" of a message is the sum of its sentiment strength
    (absolute sentiment score, 0..5) and its irony strength (sarcasm score
    with negative values floored to 0, 0..5).

    Args:
        text: Message to analyse.

    Returns:
        ``5 - charge``, clamped to [-5, +5] and rounded to 2 decimals.
    """
    sentiment_strength = abs(analyze_sentiment(text))  # 0..5
    irony_strength = max(0.0, sarcasm_score(text))  # 0..5 (only when ironic)
    # The combined charge spans 0..10, so subtracting it from 5 covers the
    # whole documented range. Averaging the two terms instead would confine
    # the result to 0..+5 and make the "-5 = maximally charged" end
    # unreachable.
    neutrality = 5.0 - (sentiment_strength + irony_strength)
    return round(clamp(neutrality), 2)
# =====================
# 5) Agreement with irony adjustment
# =====================
def agreement_with_irony(msg1: str, msg2: str) -> float:
    """Agreement score for the pair, damped when ``msg2`` looks ironic.

    Args:
        msg1: First message.
        msg2: Second message; only this one is checked for irony.

    Returns:
        The plain agreement score multiplied by ``1 - 0.65 * irony``, where
        ``irony`` is the positive part of the sarcasm score rescaled to 0..1.
        The result is clamped to [-5, +5] and rounded to 2 decimals.
    """
    raw_agreement = agreement_score_minus5_plus5(msg1, msg2)
    irony_level = max(0.0, sarcasm_score(msg2))  # 0..5
    irony_fraction = irony_level / 5.0  # 0..1
    # The more sarcasm there is, the less the agreement score is trusted.
    damping = 1.0 - 0.65 * irony_fraction
    return round(clamp(raw_agreement * damping), 2)
# =====================
# 6) Multilabel Zero-Shot -> [-5..+5]
# =====================
ZS_MODEL = "facebook/bart-large-mnli"
# Filled in by load_zero_shot() on first use; stays None until then.
zs_classifier = None

# Candidate labels handed to the zero-shot classifier; every label is scored
# independently of the others.
CATEGORIES = list(
    (
        # core topics
        # NOTE(review): "politique" is the French spelling -- confirm intent.
        "politique",
        "woke",
        "racism",
        "crime",
        "police_abuse",
        "corruption",
        "hate_speech",
        "activism",
        # typical Twitter-discussion patterns
        "outrage / moral outrage",
        "cancel culture",
        "culture war",
        "polarization / us vs them",
        "misinformation / fake news",
        "conspiracy / deep state",
        "propaganda / spin",
        "whataboutism",
        "virtue signaling",
        "dogwhistle / coded language",
        "trolling / bait",
        "ragebait",
        "harassment / bullying",
        "callout / public shaming",
        "ratio / pile-on",
        "stan / fandom war",
        "hot take",
        "doomposting",
        "memes / shitposting",
        "political satire",
        "debunking / fact-checking",
        "support / solidarity",
    )
)
def load_zero_shot():
    """Create the zero-shot classification pipeline once and cache it.

    Does nothing when ``zs_classifier`` is already set. The pipeline is built
    for ``ZS_MODEL`` on device 0 when CUDA is available and on device -1
    otherwise.
    """
    global zs_classifier
    if zs_classifier is not None:
        return
    if torch.cuda.is_available():
        device_index = 0
    else:
        device_index = -1
    zs_classifier = pipeline(
        "zero-shot-classification",
        model=ZS_MODEL,
        device=device_index,
    )
def classify_message(text: str) -> dict:
    """Score ``text`` against every entry of ``CATEGORIES``.

    Args:
        text: Message to classify.

    Returns:
        A dict mapping each returned label to its score converted with
        ``score01_to_minus5_plus5`` and rounded to 2 decimals, in the order
        the classifier returned the labels.
    """
    load_zero_shot()
    prediction = zs_classifier(text, candidate_labels=CATEGORIES, multi_label=True)
    return {
        label: round(score01_to_minus5_plus5(probability), 2)
        for label, probability in zip(prediction["labels"], prediction["scores"])
    }
# =====================
# Gradio UI
# =====================
# One tab per scorer; every tab wires a button to the matching function above.
with gr.Blocks(title="Unified NLP API (-5..+5)") as demo:
    gr.Markdown("## 📈 Unified NLP API (all scores: -5 .. +5)")
    # Scale legend shown to the user (runtime text, kept verbatim).
    gr.Markdown(
        """
**Шкалы:**
- **Agreement**: -5 = сильное противоречие, +5 = сильное согласие
- **Sentiment**: -5 = негатив, +5 = позитив
- **Sarcasm**: -5 = уверенно НЕ сарказм, +5 = уверенно сарказм/ирония
- **Neutrality**: +5 = максимально нейтрально, -5 = максимально “заряжено”
- **Multilabel**: уверенность метки в шкале -5..+5 (0.5 → 0)
"""
    )

    with gr.Tab("Agreement"):
        msg1 = gr.Textbox(label="Message 1")
        msg2 = gr.Textbox(label="Message 2")

        # Plain agreement score.
        btn_agree = gr.Button("Check Agreement")
        out_agree = gr.Number(label="Agreement Score (-5..+5)")
        btn_agree.click(
            fn=agreement_score_minus5_plus5,
            inputs=[msg1, msg2],
            outputs=out_agree,
        )

        # Same message pair, with the score damped by the irony of message 2.
        gr.Markdown("### Agreement (irony-aware)")
        btn_agree_irony = gr.Button("Check Agreement (with irony)")
        out_agree_irony = gr.Number(label="Agreement Score (irony-aware) (-5..+5)")
        btn_agree_irony.click(
            fn=agreement_with_irony,
            inputs=[msg1, msg2],
            outputs=out_agree_irony,
        )

    with gr.Tab("Sentiment"):
        text_sent = gr.Textbox(label="Text")
        btn_sent = gr.Button("Analyze Sentiment")
        out_sent = gr.Number(label="Sentiment Score (-5..+5)")
        btn_sent.click(
            fn=analyze_sentiment,
            inputs=text_sent,
            outputs=out_sent,
        )

    with gr.Tab("Sarcasm / Irony"):
        text_sarc = gr.Textbox(label="Text")
        btn_sarc = gr.Button("Analyze Sarcasm")
        out_sarc = gr.Number(label="Sarcasm Score (-5..+5)")
        btn_sarc.click(
            fn=sarcasm_score,
            inputs=text_sarc,
            outputs=out_sarc,
        )

    with gr.Tab("Neutrality"):
        text_neu = gr.Textbox(label="Text")
        btn_neu = gr.Button("Analyze Neutrality")
        out_neu = gr.Number(label="Neutrality Score (-5..+5)")
        btn_neu.click(
            fn=neutrality_score,
            inputs=text_neu,
            outputs=out_neu,
        )

    with gr.Tab("Multilabel Classification"):
        text_clf = gr.Textbox(label="Text")
        btn_clf = gr.Button("Classify")
        # NOTE(review): classify_message returns values in -5..+5, while
        # gr.Label presumably expects confidences in 0..1 -- confirm that the
        # rendering is what is wanted.
        out_clf = gr.Label(label="Categories & Scores (-5..+5)")
        btn_clf.click(
            fn=classify_message,
            inputs=text_clf,
            outputs=out_clf,
        )

# Start the web server as soon as the module is executed.
demo.launch()