Dusit-P committed on
Commit 5dbe10a · verified · 1 Parent(s): c1fbd91

Update app.py

Files changed (1):
  1. app.py +114 -52
app.py CHANGED
@@ -1,26 +1,29 @@
-import os, json, importlib.util, tempfile, traceback, torch, re, math
-import torch.nn as nn
-import torch.nn.functional as F
+# app.py Thai Sentiment (WangchanBERTa Variants) GUI
+import os, json, importlib.util, traceback, sys, re, math, tempfile
 import gradio as gr
-import pandas as pd
+import torch, pandas as pd
+import torch.nn.functional as F
 import plotly.graph_objects as go
 from huggingface_hub import hf_hub_download
 from safetensors.torch import load_file
 from transformers import AutoTokenizer, AutoModel
 
-# ===== Settings =====
-REPO_ID = os.getenv("REPO_ID", "Dusit-P/thai-sentiment")
-DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "WCB") # default model
+# ================= Settings =================
+REPO_ID = os.getenv("REPO_ID", "Dusit-P/thai-sentiment") # <<< use the new repo
+DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "WCB")
 HF_TOKEN = os.getenv("HF_TOKEN", None)
 
-# ---- theme colors ----
-NEG_COLOR = "#F87171" # red-400
-POS_COLOR = "#34D399" # emerald-400
+AVAILABLE_CHOICES = ["WCB", "WCB_BiLSTM", "WCB_CNN_BiLSTM", "WCB_4Layer_BiLSTM"]
+if DEFAULT_MODEL not in AVAILABLE_CHOICES:
+    DEFAULT_MODEL = "WCB"
+
+NEG_COLOR = "#F87171"
+POS_COLOR = "#34D399"
 TEMPLATE = "plotly_white"
 
 CACHE = {}
 
-# ---------- load models from common/models.py ----------
+# ================= Loader =================
 def _import_models():
     if "models_module" in CACHE:
         return CACHE["models_module"]
@@ -44,10 +47,11 @@ def load_model(model_name: str):
 
     base_model = cfg.get("base_model", "airesearch/wangchanberta-base-att-spm-uncased")
     arch_name = cfg.get("architecture", model_name)
-    tok = AutoTokenizer.from_pretrained(base_model)
 
+    tok = AutoTokenizer.from_pretrained(base_model)
     models = _import_models()
-    model = models._build(arch_name, base_model, int(cfg.get("num_labels",2)), cfg.get("pooling_after_lstm","masked_mean"))
+    model = models._build(arch_name, base_model, int(cfg.get("num_labels",2)),
+                          cfg.get("pooling_after_lstm", "masked_mean"))
 
     state = load_file(w_path)
     model.load_state_dict(state, strict=False)
@@ -56,17 +60,14 @@ def load_model(model_name: str):
     CACHE[key] = (model, tok, cfg)
     return CACHE[key]
 
-# ---------- helpers ----------
-def _format_pct(x: float) -> str:
-    return f"{x*100:.2f}%"
-
+# ================= Utils =================
 _INVALID_STRINGS = {"-", "--", "—", "n/a", "na", "null", "none", "nan", ".", "…", ""}
 _RE_HAS_LETTER = re.compile(r"[ก-๙A-Za-z]")
 
 def _norm_text(v) -> str:
     if v is None: return ""
     if isinstance(v, float) and math.isnan(v): return ""
-    return str(v).strip()
+    return str(v).strip().strip('"').strip("'").strip(",")
 
 def _is_substantive_text(s: str, min_chars: int = 2) -> bool:
     if not s: return False
@@ -75,11 +76,8 @@ def _is_substantive_text(s: str, min_chars: int = 2) -> bool:
     if len(s.replace(" ", "")) < min_chars: return False
     return True
 
-def _clean_texts(texts):
-    all_norm = [_norm_text(t) for t in texts]
-    cleaned = [t for t in all_norm if _is_substantive_text(t)]
-    skipped = len(all_norm) - len(cleaned)
-    return cleaned, skipped
+def _format_pct(x: float) -> str:
+    return f"{x*100:.2f}%"
 
 def _make_figures(df: pd.DataFrame):
     total = len(df)
@@ -110,10 +108,9 @@ def _make_figures(df: pd.DataFrame):
         marker=dict(colors=[NEG_COLOR, POS_COLOR])
     ))
     fig_pie.update_layout(title="Label share", template=TEMPLATE)
-
     return fig_bar, fig_pie, info
 
-# ---------- core prediction ----------
+# ================= Core Predict =================
 def _predict_batch(texts, model_name, batch_size=32):
     model, tok, cfg = load_model(model_name)
     results = []
@@ -135,37 +132,92 @@ def _predict_batch(texts, model_name, batch_size=32):
         })
     return results
 
+# ----- single -----
 def predict_one(text: str, model_choice: str):
-    s = _norm_text(text)
-    if not _is_substantive_text(s):
-        return {"negative": 0.0, "positive": 0.0}, "invalid"
-    out = _predict_batch([s], model_choice)[0]
-    probs = {
-        "negative": float(out["negative(%)"].rstrip("%"))/100.0,
-        "positive": float(out["positive(%)"].rstrip("%"))/100.0,
-    }
-    return probs, out["label"]
-
+    try:
+        s = _norm_text(text)
+        if not _is_substantive_text(s):
+            return {"negative": 0.0, "positive": 0.0}, "invalid"
+        out = _predict_batch([s], model_choice)[0]
+        probs = {
+            "negative": float(out["negative(%)"].rstrip("%"))/100.0,
+            "positive": float(out["positive(%)"].rstrip("%"))/100.0,
+        }
+        return probs, out["label"]
+    except Exception as e:
+        return {"error": str(e)}, "error"
+
+# ----- textarea batch -----
 def predict_many(text_block: str, model_choice: str):
-    raw_lines = (text_block or "").splitlines()
-    cleaned, skipped = _clean_texts(raw_lines)
-    if len(cleaned) == 0:
+    try:
+        raw_lines = (text_block or "").splitlines()
+        all_norm = [_norm_text(t) for t in raw_lines]
+        cleaned = [t for t in all_norm if _is_substantive_text(t)]
+        skipped = len(all_norm) - len(cleaned)
+        if len(cleaned) == 0:
+            empty = pd.DataFrame(columns=["review","negative(%)","positive(%)","label"])
+            return empty, go.Figure(), go.Figure(), "No valid text"
+        results = _predict_batch(cleaned, model_choice)
+        df = pd.DataFrame(results)
+        fig_bar, fig_pie, info_md = _make_figures(df)
+        info_md = f"{info_md} \n- Skipped: {skipped}"
+        return df, fig_bar, fig_pie, info_md
+    except Exception:
+        tb = traceback.format_exc()
         empty = pd.DataFrame(columns=["review","negative(%)","positive(%)","label"])
-        return empty, go.Figure(), go.Figure(), "No valid text"
-    results = _predict_batch(cleaned, model_choice)
-    df = pd.DataFrame(results)
-    fig_bar, fig_pie, info_md = _make_figures(df)
-    info_md = f"{info_md} \n- Skipped: {skipped}"
-    return df, fig_bar, fig_pie, info_md
-
-# ---------- Gradio UI ----------
-AVAILABLE_CHOICES = ["WCB", "WCB_BiLSTM", "WCB_CNN_BiLSTM", "WCB_4Layer_BiLSTM"]
-if DEFAULT_MODEL not in AVAILABLE_CHOICES:
-    DEFAULT_MODEL = "WCB"
-
-with gr.Blocks(title="Thai Sentiment GUI") as demo:
+        return empty, go.Figure(), go.Figure(), f"**Error**\n```\n{tb}\n```"
+
+# ----- CSV upload -----
+LIKELY_TEXT_COLS = ["text","review","message","comment","content","sentence","body"]
+
+def predict_csv(file_obj, model_choice: str, text_col_name: str):
+    """
+    file_obj: gr.File (temp file), text_col_name: optional override
+    """
+    try:
+        if file_obj is None:
+            return pd.DataFrame(), go.Figure(), go.Figure(), "Please upload a CSV.", None
+
+        df = pd.read_csv(file_obj.name)
+        cols = [c for c in df.columns]
+        # autodetect column if not provided
+        col = text_col_name or ""
+        if not col or col not in df.columns:
+            # pick first matching likely name; else first object dtype
+            found = None
+            low = {c.lower(): c for c in cols}
+            for k in LIKELY_TEXT_COLS:
+                if k in low:
+                    found = low[k]; break
+            if found is None:
+                cand = [c for c in cols if df[c].dtype == object]
+                found = cand[0] if cand else cols[0]
+            col = found
+
+        # clean & predict
+        texts = [_norm_text(v) for v in df[col].tolist()]
+        texts = [t for t in texts if _is_substantive_text(t)]
+        if len(texts) == 0:
+            return pd.DataFrame(), go.Figure(), go.Figure(), "No valid texts in selected column.", None
+
+        results = _predict_batch(texts, model_choice)
+        out_df = pd.DataFrame(results)
+        fig_bar, fig_pie, info_md = _make_figures(out_df)
+
+        # write downloadable csv
+        fd, out_path = tempfile.mkstemp(prefix="pred_", suffix=".csv")
+        os.close(fd)
+        out_df.to_csv(out_path, index=False, encoding="utf-8-sig")
+
+        info_md = f"{info_md} \n- Column used: **{col}**"
+        return out_df, fig_bar, fig_pie, info_md, out_path
+    except Exception:
+        tb = traceback.format_exc()
+        return pd.DataFrame(), go.Figure(), go.Figure(), f"**Error**\n```\n{tb}\n```", None
+
+# ================= Gradio UI =================
+with gr.Blocks(title="Thai Sentiment (WangchanBERTa Variants)") as demo:
     gr.Markdown("### Thai Sentiment (WangchanBERTa Variants)")
-
    model_radio = gr.Radio(choices=AVAILABLE_CHOICES, value=DEFAULT_MODEL, label="เลือกโมเดล")
 
     with gr.Tab("Single"):
@@ -182,5 +234,15 @@ with gr.Blocks(title="Thai Sentiment GUI") as demo:
         sum2 = gr.Markdown()
         gr.Button("Run Batch").click(predict_many, [t2, model_radio], [df2, bar2, pie2, sum2])
 
+    with gr.Tab("CSV Upload"):
+        file_in = gr.File(label="อัปโหลดไฟล์ .csv", file_types=[".csv"])
+        col_in = gr.Textbox(label="ชื่อคอลัมน์ข้อความ (เว้นว่างให้เลือกอัตโนมัติได้)", value="")
+        df3 = gr.Dataframe(label="ผลลัพธ์", interactive=False)
+        bar3 = gr.Plot(label="Label counts (bar)")
+        pie3 = gr.Plot(label="Label share (pie)")
+        sum3 = gr.Markdown()
+        dl3 = gr.File(label="ดาวน์โหลดผลเป็น CSV", interactive=False)
+        gr.Button("Predict CSV").click(predict_csv, [file_in, model_radio, col_in], [df3, bar3, pie3, sum3, dl3])
+
 if __name__ == "__main__":
     demo.launch()
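
Note on the new CSV tab: when the column-name box is left empty, predict_csv autodetects which column holds the review text. A minimal standalone sketch of that selection rule, kept separate from the committed code (the helper name pick_text_column and the sample frame are illustrative only): prefer a known text-like column name, matched case-insensitively, otherwise fall back to the first object-dtype column.

import pandas as pd

# Mirrors the selection strategy in predict_csv: likely name first,
# then the first string (object-dtype) column, then the first column.
LIKELY_TEXT_COLS = ["text", "review", "message", "comment", "content", "sentence", "body"]

def pick_text_column(df: pd.DataFrame, override: str = "") -> str:
    # An explicit override wins when it names a real column.
    if override and override in df.columns:
        return override
    low = {c.lower(): c for c in df.columns}
    for name in LIKELY_TEXT_COLS:
        if name in low:
            return low[name]
    object_cols = [c for c in df.columns if df[c].dtype == object]
    return object_cols[0] if object_cols else df.columns[0]

sample = pd.DataFrame({"id": [1, 2], "Review": ["อร่อยมาก", "ไม่อร่อยเลย"]})
print(pick_text_column(sample))  # -> "Review" (matched case-insensitively)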