narutoSiskovich commited on
Commit
6f3e861
·
verified ·
1 Parent(s): c5b8dc3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -28
app.py CHANGED
@@ -3,18 +3,18 @@ import torch
3
  from transformers import (
4
  AutoTokenizer,
5
  AutoModelForSequenceClassification,
6
- XLMRobertaForSequenceClassification,
7
  )
8
 
9
  # =====================
10
  # DEVICE
11
  # =====================
12
- DEVICE = "cpu" # HF Spaces обычно CPU
13
 
14
  # =====================
15
  # Agreement (MNLI)
16
  # =====================
17
- MNLI_MODEL = "facebook/bart-large-mnli" # Fixed: valid public model
18
  mnli_tokenizer = None
19
  mnli_model = None
20
 
@@ -32,7 +32,7 @@ def check_agreement(msg1: str, msg2: str) -> float:
32
  with torch.no_grad():
33
  logits = mnli_model(**inputs).logits
34
  probs = torch.softmax(logits, dim=-1)[0]
35
- # Agreement score: entailment - contradiction
36
  return round((probs[2] - probs[0]).item(), 2)
37
 
38
  # =====================
@@ -57,40 +57,37 @@ def analyze_sentiment(text: str) -> float:
57
  logits = sent_model(**inputs).logits
58
  probs = torch.softmax(logits, dim=-1)
59
  stars = torch.argmax(probs, dim=-1).item() + 1
60
- # Convert 1-5 stars into -5 to +5 scale
61
  return round((stars - 3) * 2.5, 2)
62
 
63
  # =====================
64
- # Multilabel classifier
65
  # =====================
66
- CLASSIFIER_MODEL = "xlm-roberta-base"
 
 
67
  CATEGORIES = [
68
  "politique", "woke", "racism", "crime",
69
  "police_abuse", "corruption", "hate_speech", "activism"
70
  ]
71
- clf_tokenizer = None
72
- clf_model = None
73
 
74
- def load_classifier():
75
- global clf_tokenizer, clf_model
76
- if clf_model is None:
77
- clf_tokenizer = AutoTokenizer.from_pretrained(CLASSIFIER_MODEL)
78
- clf_model = XLMRobertaForSequenceClassification.from_pretrained(
79
- CLASSIFIER_MODEL,
80
- num_labels=len(CATEGORIES),
81
- problem_type="multi_label_classification"
82
  )
83
- clf_model.to(DEVICE)
84
- clf_model.eval()
85
 
86
  def classify_message(text: str) -> dict:
87
- load_classifier()
88
- inputs = clf_tokenizer(text, return_tensors="pt", truncation=True).to(DEVICE)
89
- with torch.no_grad():
90
- logits = clf_model(**inputs).logits
91
- probs = torch.sigmoid(logits)[0]
92
- # Return a dict for Gradio Label output
93
- return {CATEGORIES[i]: float(probs[i]) for i in range(len(CATEGORIES))}
94
 
95
  # =====================
96
  # Gradio UI
@@ -113,11 +110,11 @@ with gr.Blocks(title="Unified NLP API") as demo:
113
  out_sent = gr.Number(label="Sentiment Score (-5 to +5)")
114
  btn_sent.click(fn=analyze_sentiment, inputs=text_sent, outputs=out_sent)
115
 
116
- # ----- Multilabel Classification Tab -----
117
  with gr.Tab("Multilabel Classification"):
118
  text_clf = gr.Textbox(label="Text")
119
  btn_clf = gr.Button("Classify")
120
- out_clf = gr.Label(label="Categories")
121
  btn_clf.click(fn=classify_message, inputs=text_clf, outputs=out_clf)
122
 
123
  demo.launch()
 
3
  from transformers import (
4
  AutoTokenizer,
5
  AutoModelForSequenceClassification,
6
+ pipeline,
7
  )
8
 
9
  # =====================
10
  # DEVICE
11
  # =====================
12
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
13
 
14
  # =====================
15
  # Agreement (MNLI)
16
  # =====================
17
+ MNLI_MODEL = "facebook/bart-large-mnli"
18
  mnli_tokenizer = None
19
  mnli_model = None
20
 
 
32
  with torch.no_grad():
33
  logits = mnli_model(**inputs).logits
34
  probs = torch.softmax(logits, dim=-1)[0]
35
+ # Считаем: entailment - contradiction
36
  return round((probs[2] - probs[0]).item(), 2)
37
 
38
  # =====================
 
57
  logits = sent_model(**inputs).logits
58
  probs = torch.softmax(logits, dim=-1)
59
  stars = torch.argmax(probs, dim=-1).item() + 1
60
+ # Приводим шкалу 15 к -5..+5
61
  return round((stars - 3) * 2.5, 2)
62
 
63
  # =====================
64
+ # Zero‑Shot Classification
65
  # =====================
66
+ ZS_MODEL = "facebook/bart-large-mnli"
67
+ zs_classifier = None
68
+
69
  CATEGORIES = [
70
  "politique", "woke", "racism", "crime",
71
  "police_abuse", "corruption", "hate_speech", "activism"
72
  ]
 
 
73
 
74
+ def load_zero_shot():
75
+ global zs_classifier
76
+ if zs_classifier is None:
77
+ zs_classifier = pipeline(
78
+ "zero-shot-classification",
79
+ model=ZS_MODEL,
80
+ device=0 if torch.cuda.is_available() else -1
 
81
  )
 
 
82
 
83
  def classify_message(text: str) -> dict:
84
+ load_zero_shot()
85
+ # Zero‑shot принимает список меток:
86
+ result = zs_classifier(text, candidate_labels=CATEGORIES)
87
+ scores = result["scores"]
88
+ labels = result["labels"]
89
+ # Возвращаем словарь {label: score}
90
+ return {label: round(score, 3) for label, score in zip(labels, scores)}
91
 
92
  # =====================
93
  # Gradio UI
 
110
  out_sent = gr.Number(label="Sentiment Score (-5 to +5)")
111
  btn_sent.click(fn=analyze_sentiment, inputs=text_sent, outputs=out_sent)
112
 
113
+ # ----- Multilabel (Zero‑Shot) Classification Tab -----
114
  with gr.Tab("Multilabel Classification"):
115
  text_clf = gr.Textbox(label="Text")
116
  btn_clf = gr.Button("Classify")
117
+ out_clf = gr.Label(label="Categories & Scores")
118
  btn_clf.click(fn=classify_message, inputs=text_clf, outputs=out_clf)
119
 
120
  demo.launch()