sherdd committed on
Commit
016dbc6
·
verified ·
1 Parent(s): e1377d6
Files changed (1) hide show
  1. app.py +193 -15
app.py CHANGED
@@ -1,33 +1,211 @@
1
- import os
2
  import gradio as gr
3
- from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- MODEL_ID = os.getenv("MODEL_ID", "cardiffnlp/twitter-xlm-roberta-base-sentiment")
6
- LABEL_MAP = {0: "negative", 1: "neutral", 2: "positive"} # modelin etiket sirasi
 
 
 
 
 
 
 
 
7
 
8
- # modeli ve tokenizer'i bir kez yukle
9
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
10
- model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
11
- pipe = TextClassificationPipeline(model=model, tokenizer=tokenizer, return_all_scores=True, framework="pt", device=-1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  def analyze(text: str):
14
  text = (text or "").strip()
15
  if not text:
16
  return {"label": "neutral", "score": 1.0}
17
- scores = pipe(text)[0] # [{label: "...", score: ...}, ...]
18
- max_idx = max(range(len(scores)), key=lambda i: scores[i]["score"])
19
- label = LABEL_MAP.get(max_idx, scores[max_idx]["label"]).lower()
20
- score = float(scores[max_idx]["score"])
21
- return {"label": label, "score": round(score, 4)}
22
 
23
- demo = gr.Interface(
 
 
 
 
 
 
 
 
 
24
  fn=analyze,
25
  inputs=gr.Textbox(lines=3, placeholder="type a message..."),
26
  outputs=gr.JSON(),
27
  title="chat sentiment api",
28
  description="returns json: {label: positive|neutral|negative, score: 0..1}",
29
  )
30
- demo.api_name = "analyze" # endpoint: /api/predict/analyze
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  if __name__ == "__main__":
33
  demo.launch()
 
1
+ import os, re, time
2
  import gradio as gr
3
+ from typing import List, Dict, Tuple
4
+ from transformers import (
5
+ AutoTokenizer, AutoModelForSequenceClassification,
6
+ TextClassificationPipeline, AutoConfig
7
+ )
8
+
9
# ----------------------------
# MODEL REGISTRY
# ----------------------------
# One entry per candidate model: display name, hub id, label scheme
# ("3class" / "5star" / "2class"), and whether the benchmark UI
# preselects it by default.
MODELS: Dict[str, Dict] = {
    "xlmr": {
        "id": "cardiffnlp/twitter-xlm-roberta-base-sentiment",
        "name": "XLM-R (3-class)",
        "kind": "3class",
        "default": True,
    },
    "distilmulti": {
        "id": "lxyuan/distilbert-base-multilingual-cased-sentiments-student",
        "name": "DistilBERT (5-star)",
        "kind": "5star",
        "default": True,
    },
    "mbert5": {
        "id": "nlptown/bert-base-multilingual-uncased-sentiment",
        "name": "mBERT (5-star)",
        "kind": "5star",
        "default": False,
    },
    "turkish2": {
        "id": "savasy/bert-base-turkish-sentiment-cased",
        "name": "Turkish BERT (2-class)",
        "kind": "2class",
        "default": False,
    },
}
38
+
39
# Configuration for the single-model API tab (overridable via env var).
MODEL_ID = os.getenv("MODEL_ID", MODELS["xlmr"]["id"])
# Index -> human-readable label for the default 3-class model.
LABEL_MAP_3CLS = {0: "negative", 1: "neutral", 2: "positive"}
42
+
43
+
44
# Lazy per-model caches: each model is loaded at most once per process.
_PIPE_CACHE: Dict[str, TextClassificationPipeline] = {}
_CFG_CACHE: Dict[str, AutoConfig] = {}

def get_pipe_and_cfg(model_id: str) -> Tuple[TextClassificationPipeline, AutoConfig]:
    """Return a cached (pipeline, config) pair for *model_id*, loading on first use.

    The config is taken from the already-loaded model (``mdl.config``)
    instead of a second ``AutoConfig.from_pretrained`` call, which avoided
    nothing and triggered a redundant hub lookup.
    """
    if model_id not in _PIPE_CACHE:
        tok = AutoTokenizer.from_pretrained(model_id)
        mdl = AutoModelForSequenceClassification.from_pretrained(model_id)
        # NOTE(review): return_all_scores is deprecated in newer transformers
        # (top_k=None is the replacement), but switching it changes the output
        # nesting that the callers' `pipe(t)[0]` indexing relies on — left as-is.
        _PIPE_CACHE[model_id] = TextClassificationPipeline(
            model=mdl, tokenizer=tok, return_all_scores=True, framework="pt", device=-1
        )
        _CFG_CACHE[model_id] = mdl.config
    return _PIPE_CACHE[model_id], _CFG_CACHE[model_id]
56
+
57
# ----------------------------
# LABEL NORMALIZATION
# ----------------------------
def normalize_label(raw_label: str, cfg, kind: str) -> str:
    """Collapse a model's raw label into "positive" / "neutral" / "negative".

    ``cfg`` only needs an ``id2label`` mapping (e.g. a transformers model
    config); the annotation is deliberately duck-typed so this pure helper
    does not require transformers to be importable.
    ``kind`` is the label scheme: "3class", "5star", or "2class".
    """
    lbl = raw_label.lower()

    # Generic "LABEL_<i>" ids: resolve through the config's id2label table.
    if lbl.startswith("label_") and hasattr(cfg, "id2label"):
        try:
            idx = int(lbl.split("_")[-1])
            lbl = str(cfg.id2label[idx]).lower()
        except Exception:
            pass  # unknown/unparsable index: fall through with the raw text

    # 5-star models: map 1..5 stars onto neg/neu/pos.
    if kind == "5star":
        m = re.search(r"([1-5])", lbl)
        if m:
            s = int(m.group(1))
            if s <= 2:
                return "negative"
            if s == 3:
                return "neutral"
            return "positive"

    # Textual matching ("negative", "NEG", "positive", ...).
    if "neg" in lbl:
        return "negative"
    if "neu" in lbl:
        return "neutral"
    if "pos" in lbl:
        return "positive"

    # Some 2-class models expose only pos/neg; anything else maps to neutral.
    return "neutral"
93
+
94
# ----------------------------
# SINGLE-TEXT ANALYSIS (API)
# ----------------------------
# endpoint: /api/predict/analyze
# Route through the shared cache so the default model is not loaded a
# second time when it is also selected in the benchmark tab.
_pipe, _cfg = get_pipe_and_cfg(MODEL_ID)
_tokenizer = _pipe.tokenizer  # kept for backward compatibility
_model = _pipe.model
101
 
102
def analyze(text: str):
    """Classify one message with the default model.

    Returns ``{"label": positive|neutral|negative, "score": 0..1}``;
    blank or missing input short-circuits to neutral with score 1.0.
    """
    cleaned = (text or "").strip()
    if not cleaned:
        return {"label": "neutral", "score": 1.0}

    candidates = _pipe(cleaned)[0]
    best = max(candidates, key=lambda c: c["score"])

    # Generic LABEL_0/1/2 ids -> readable labels; anything else lowercased.
    raw_label = best["label"]
    label = raw_label.lower()
    if raw_label.startswith("LABEL_"):
        index = int(raw_label.rsplit("_", 1)[-1])
        label = LABEL_MAP_3CLS.get(index, raw_label).lower()
    return {"label": label, "score": round(float(best["score"]), 4)}
117
+
118
# Assigning `.api_name` after construction does not rename the endpoint —
# the event is wired in __init__, so the name must be passed to the
# constructor for /api/analyze to be registered.
api_intf = gr.Interface(
    fn=analyze,
    inputs=gr.Textbox(lines=3, placeholder="type a message..."),
    outputs=gr.JSON(),
    title="chat sentiment api",
    description="returns json: {label: positive|neutral|negative, score: 0..1}",
    api_name="analyze",
)
126
+
127
# ----------------------------
# MULTI-MODEL COMPARISON (UI)
# ----------------------------
def run_benchmark(texts_blob: str, selected_keys: List[str]):
    """Run every selected model over every non-blank input line.

    Always returns a 3-tuple ``(summary_markdown, rows, headers)`` so the
    caller can unpack it unconditionally — the original error branches
    returned only 2 values, which crashed the caller's 3-way unpack.
    """
    headers = ["text", "model", "label", "score", "latency_ms"]

    texts = [t.strip() for t in (texts_blob or "").splitlines() if t.strip()]
    if not texts:
        return " Metin alanı boş. Her satıra bir örnek yaz.", [], headers
    if not selected_keys:
        return " En az bir model seç.", [], headers

    rows = []
    for t in texts:
        for key in selected_keys:
            spec = MODELS[key]
            pipe, cfg = get_pipe_and_cfg(spec["id"])

            t0 = time.perf_counter()
            out = pipe(t)[0]  # list of {"label": ..., "score": ...} dicts
            top = max(out, key=lambda s: s["score"])
            latency = (time.perf_counter() - t0) * 1000.0

            label = normalize_label(top["label"], cfg, spec["kind"])
            score = float(top["score"])
            rows.append([t, spec["name"], label, round(score, 4), round(latency, 1)])

    # Per-model summary: average latency and label distribution.
    by_model: Dict[str, Dict] = {}
    for _t, m, lab, _sc, lat in rows:
        d = by_model.setdefault(m, {"n": 0, "lat_sum": 0.0, "neg": 0, "neu": 0, "pos": 0})
        d["n"] += 1
        d["lat_sum"] += lat
        d[lab[:3]] += 1  # neg/neu/pos counter (labels are normalized above)

    lines = ["### Summary"]
    for m, d in by_model.items():
        avg_lat = d["lat_sum"] / max(d["n"], 1)
        lines.append(f"- **{m}** → avg latency: **{avg_lat:.1f} ms**, counts: neg={d['neg']}, neu={d['neu']}, pos={d['pos']}")

    return "\n".join(lines), rows, headers
174
+
175
with gr.Blocks(title="sentiment multi-model bench") as bench_ui:
    gr.Markdown("## Compare models on the same inputs\nEnter one sentence per line. Select models and run.")
    txt = gr.Textbox(lines=8, label="Sentences (one per line)", placeholder="bugün hava harika\nama içim biraz buruk\nnötr bir cümle örneği")

    # CheckboxGroup choices must be plain strings, not gr.Checkbox
    # components — build it directly from the display names and preselect
    # the models flagged as default (replaces the invalid construction
    # that was patched post-hoc via attribute assignment).
    model_names = [v["name"] for v in MODELS.values()]
    default_names = [v["name"] for v in MODELS.values() if v["default"]]
    choices = gr.CheckboxGroup(
        choices=model_names,
        value=default_names,
        label="Models to test",
    )

    run_btn = gr.Button("Run benchmark")
    out_md = gr.Markdown()
    out_tbl = gr.Dataframe(row_count=(0, "dynamic"), col_count=(5, "fixed"), wrap=True)

    def _resolve_keys(selected_names: List[str]) -> List[str]:
        # Map display names back to MODELS registry keys.
        rev = {v["name"]: k for k, v in MODELS.items()}
        return [rev[n] for n in (selected_names or []) if n in rev]

    def _runner(texts_blob, selected_names):
        keys = _resolve_keys(selected_names)
        result = run_benchmark(texts_blob, keys)
        # run_benchmark returns (summary, rows, headers) on success but some
        # versions return only (message, rows) on validation errors —
        # tolerate both shapes instead of crashing on empty input.
        if len(result) == 3:
            summary_md, rows, headers = result
        else:
            summary_md, rows = result
            headers = ["text", "model", "label", "score", "latency_ms"]
        return summary_md, gr.update(value=rows, headers=headers)

    run_btn.click(_runner, inputs=[txt, choices], outputs=[out_md, out_tbl])
203
+
204
+
205
# Combine the single-model API tab and the multi-model comparison tab
# into one app.
demo = gr.TabbedInterface(
    [api_intf, bench_ui],
    tab_names=["API (single model)", "Compare models"],
)

if __name__ == "__main__":
    demo.launch()