narutoSiskovich commited on
Commit
c5b8dc3
·
verified ·
1 Parent(s): 76a8c52

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -5
app.py CHANGED
@@ -1,4 +1,3 @@
1
- # app.py
2
  import gradio as gr
3
  import torch
4
  from transformers import (
@@ -15,7 +14,7 @@ DEVICE = "cpu" # HF Spaces обычно CPU
15
  # =====================
16
  # Agreement (MNLI)
17
  # =====================
18
- MNLI_MODEL = "facebook/bart-base-mnli"
19
  mnli_tokenizer = None
20
  mnli_model = None
21
 
@@ -33,6 +32,7 @@ def check_agreement(msg1: str, msg2: str) -> float:
33
  with torch.no_grad():
34
  logits = mnli_model(**inputs).logits
35
  probs = torch.softmax(logits, dim=-1)[0]
 
36
  return round((probs[2] - probs[0]).item(), 2)
37
 
38
  # =====================
@@ -57,6 +57,7 @@ def analyze_sentiment(text: str) -> float:
57
  logits = sent_model(**inputs).logits
58
  probs = torch.softmax(logits, dim=-1)
59
  stars = torch.argmax(probs, dim=-1).item() + 1
 
60
  return round((stars - 3) * 2.5, 2)
61
 
62
  # =====================
@@ -82,14 +83,14 @@ def load_classifier():
82
  clf_model.to(DEVICE)
83
  clf_model.eval()
84
 
85
- def classify_message(text: str) -> list:
86
  load_classifier()
87
  inputs = clf_tokenizer(text, return_tensors="pt", truncation=True).to(DEVICE)
88
  with torch.no_grad():
89
  logits = clf_model(**inputs).logits
90
  probs = torch.sigmoid(logits)[0]
91
- labels = [CATEGORIES[i] for i, p in enumerate(probs) if p > 0.5]
92
- return labels or ["neutral"]
93
 
94
  # =====================
95
  # Gradio UI
@@ -97,6 +98,7 @@ def classify_message(text: str) -> list:
97
  with gr.Blocks(title="Unified NLP API") as demo:
98
  gr.Markdown("## 📈 Unified NLP API")
99
 
 
100
  with gr.Tab("Agreement"):
101
  msg1 = gr.Textbox(label="Message 1")
102
  msg2 = gr.Textbox(label="Message 2")
@@ -104,12 +106,14 @@ with gr.Blocks(title="Unified NLP API") as demo:
104
  out_agree = gr.Number(label="Agreement Score")
105
  btn_agree.click(fn=check_agreement, inputs=[msg1, msg2], outputs=out_agree)
106
 
 
107
  with gr.Tab("Sentiment"):
108
  text_sent = gr.Textbox(label="Text")
109
  btn_sent = gr.Button("Analyze Sentiment")
110
  out_sent = gr.Number(label="Sentiment Score (-5 to +5)")
111
  btn_sent.click(fn=analyze_sentiment, inputs=text_sent, outputs=out_sent)
112
 
 
113
  with gr.Tab("Multilabel Classification"):
114
  text_clf = gr.Textbox(label="Text")
115
  btn_clf = gr.Button("Classify")
 
 
1
  import gradio as gr
2
  import torch
3
  from transformers import (
 
14
  # =====================
15
  # Agreement (MNLI)
16
  # =====================
17
+ MNLI_MODEL = "facebook/bart-large-mnli" # Fixed: valid public model
18
  mnli_tokenizer = None
19
  mnli_model = None
20
 
 
32
  with torch.no_grad():
33
  logits = mnli_model(**inputs).logits
34
  probs = torch.softmax(logits, dim=-1)[0]
35
+ # Agreement score: entailment - contradiction
36
  return round((probs[2] - probs[0]).item(), 2)
37
 
38
  # =====================
 
57
  logits = sent_model(**inputs).logits
58
  probs = torch.softmax(logits, dim=-1)
59
  stars = torch.argmax(probs, dim=-1).item() + 1
60
+ # Convert 1-5 stars into -5 to +5 scale
61
  return round((stars - 3) * 2.5, 2)
62
 
63
  # =====================
 
83
  clf_model.to(DEVICE)
84
  clf_model.eval()
85
 
86
+ def classify_message(text: str) -> dict:
87
  load_classifier()
88
  inputs = clf_tokenizer(text, return_tensors="pt", truncation=True).to(DEVICE)
89
  with torch.no_grad():
90
  logits = clf_model(**inputs).logits
91
  probs = torch.sigmoid(logits)[0]
92
+ # Return a dict for Gradio Label output
93
+ return {CATEGORIES[i]: float(probs[i]) for i in range(len(CATEGORIES))}
94
 
95
  # =====================
96
  # Gradio UI
 
98
  with gr.Blocks(title="Unified NLP API") as demo:
99
  gr.Markdown("## 📈 Unified NLP API")
100
 
101
+ # ----- Agreement Tab -----
102
  with gr.Tab("Agreement"):
103
  msg1 = gr.Textbox(label="Message 1")
104
  msg2 = gr.Textbox(label="Message 2")
 
106
  out_agree = gr.Number(label="Agreement Score")
107
  btn_agree.click(fn=check_agreement, inputs=[msg1, msg2], outputs=out_agree)
108
 
109
+ # ----- Sentiment Tab -----
110
  with gr.Tab("Sentiment"):
111
  text_sent = gr.Textbox(label="Text")
112
  btn_sent = gr.Button("Analyze Sentiment")
113
  out_sent = gr.Number(label="Sentiment Score (-5 to +5)")
114
  btn_sent.click(fn=analyze_sentiment, inputs=text_sent, outputs=out_sent)
115
 
116
+ # ----- Multilabel Classification Tab -----
117
  with gr.Tab("Multilabel Classification"):
118
  text_clf = gr.Textbox(label="Text")
119
  btn_clf = gr.Button("Classify")