Dusit-P committed on
Commit
f6f7109
·
verified ·
1 Parent(s): cc03334

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +139 -282
app.py CHANGED
@@ -1,8 +1,9 @@
1
  # app.py — Thai Sentiment (WangchanBERTa Variants)
2
- # - No Single tab
3
- # - No aspect analysis (focus on POS/NEG)
4
- # - CSV tab: date pickers appear ONLY if a date column exists (use DatePicker)
5
- # - Predict buttons right below inputs
 
6
  import os, json, importlib.util, traceback, re, math, tempfile, datetime
7
  import gradio as gr
8
  import torch, pandas as pd
@@ -21,12 +22,31 @@ AVAILABLE_CHOICES = ["WCB", "WCB_BiLSTM", "WCB_CNN_BiLSTM", "WCB_4Layer_BiLSTM"]
21
  if DEFAULT_MODEL not in AVAILABLE_CHOICES:
22
  DEFAULT_MODEL = "WCB"
23
 
24
- NEG_COLOR = "#F87171" # red
25
- POS_COLOR = "#34D399" # green
26
  TEMPLATE = "plotly_white"
27
-
28
  CACHE = {}
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  # ================= Loader =================
31
  def _import_models():
32
  if "models_module" in CACHE:
@@ -42,84 +62,56 @@ def load_model(model_name: str):
42
  key = f"model:{model_name}"
43
  if key in CACHE:
44
  return CACHE[key]
45
-
46
  cfg_path = hf_hub_download(REPO_ID, filename=f"{model_name}/config.json", token=HF_TOKEN)
47
  w_path = hf_hub_download(REPO_ID, filename=f"{model_name}/model.safetensors", token=HF_TOKEN)
48
-
49
  with open(cfg_path, "r", encoding="utf-8") as f:
50
  cfg = json.load(f)
51
-
52
  base_model = cfg.get("base_model", "airesearch/wangchanberta-base-att-spm-uncased")
53
  arch_name = cfg.get("architecture", model_name)
54
-
55
  tok = AutoTokenizer.from_pretrained(base_model)
56
  models = _import_models()
57
  model = models._build(arch_name, base_model, int(cfg.get("num_labels",2)),
58
  cfg.get("pooling_after_lstm", "masked_mean"))
59
-
60
  state = load_file(w_path)
61
  model.load_state_dict(state, strict=False)
62
  model.eval()
63
-
64
  CACHE[key] = (model, tok, cfg)
65
  return CACHE[key]
66
 
67
  # ================= Utils =================
68
- _INVALID_STRINGS = {"-", "--", "—", "n/a", "na", "null", "none", "nan", ".", "…", ""}
69
  _RE_HAS_LETTER = re.compile(r"[ก-๙A-Za-z]")
70
 
71
- def _norm_text(v) -> str:
72
  if v is None: return ""
73
  if isinstance(v, float) and math.isnan(v): return ""
74
  return str(v).strip().strip('"').strip("'").strip(",")
75
 
76
- def _is_substantive_text(s: str, min_chars: int = 2) -> bool:
77
  if not s: return False
78
  if s.lower() in _INVALID_STRINGS: return False
79
  if not _RE_HAS_LETTER.search(s): return False
80
- if len(s.replace(" ", "")) < min_chars: return False
81
  return True
82
 
83
- def _format_pct(x: float) -> str:
84
- return f"{x*100:.2f}%"
85
-
86
- def _to_datetime_safe(s):
87
- return pd.to_datetime(s, errors="coerce", infer_datetime_format=True, utc=False)
88
-
89
- def _normalize_datepicker_value(v):
90
- """รับค่าจาก gr.DatePicker (datetime.date หรือ str หรือ None) → pandas.Timestamp หรือ None"""
91
- if v is None or (isinstance(v, float) and math.isnan(v)):
92
- return None
93
- if isinstance(v, datetime.date):
94
- return pd.Timestamp(v)
95
- # เผื่อบางเวอร์ชันส่ง str 'YYYY-MM-DD'
96
- try:
97
- ts = pd.to_datetime(v, errors="coerce")
98
- return ts if pd.notna(ts) else None
99
- except Exception:
100
- return None
101
 
102
  LIKELY_TEXT_COLS = ["text","review","message","comment","content","sentence","body","ข้อความ","รีวิว"]
103
  LIKELY_DATE_COLS = ["date","created_at","time","timestamp","datetime","วันที่","วันเวลา","เวลา"]
104
 
105
- def detect_text_and_date_cols(df: pd.DataFrame):
106
  cols = list(df.columns)
107
- # text col
108
  low = {c.lower(): c for c in cols}
109
  text_col = None
110
  for k in LIKELY_TEXT_COLS:
111
- if k in low:
112
- text_col = low[k]; break
113
  if text_col is None:
114
  cand = [c for c in cols if df[c].dtype == object]
115
  text_col = cand[0] if cand else cols[0]
116
-
117
- # date candidates
118
  date_candidates = []
119
  for c in cols:
120
- if c.lower() in LIKELY_DATE_COLS:
121
- date_candidates.append(c)
122
- continue
123
  sample = df[c].head(50)
124
  if _to_datetime_safe(sample).notna().sum() >= max(3, int(len(sample)*0.2)):
125
  date_candidates.append(c)
@@ -128,277 +120,142 @@ def detect_text_and_date_cols(df: pd.DataFrame):
128
  return text_col, date_candidates, date_col
129
 
130
  # ================= Charts =================
131
- def make_basic_charts(df: pd.DataFrame):
132
  total = len(df)
133
- neg_df = df[df["label"] == "negative"].copy()
134
- pos_df = df[df["label"] == "positive"].copy()
135
-
136
- # bar counts
137
  fig_bar = go.Figure()
138
  fig_bar.add_bar(name="negative", x=["negative"], y=[len(neg_df)], marker_color=NEG_COLOR)
139
  fig_bar.add_bar(name="positive", x=["positive"], y=[len(pos_df)], marker_color=POS_COLOR)
140
  fig_bar.update_layout(barmode="group", title="Label counts", template=TEMPLATE)
141
-
142
- # pie pos/neg
143
- labels = ["negative", "positive"]
144
- values = [len(neg_df), len(pos_df)]
145
- fig_pie = go.Figure(go.Pie(labels=labels, values=values, hole=0.35, sort=False,
146
  marker=dict(colors=[NEG_COLOR, POS_COLOR])))
147
  fig_pie.update_layout(title="Positive vs Negative", template=TEMPLATE)
148
-
149
  neg_avg = pd.to_numeric(df["negative(%)"].str.rstrip("%"), errors="coerce").mean()
150
  pos_avg = pd.to_numeric(df["positive(%)"].str.rstrip("%"), errors="coerce").mean()
151
- info = (
152
- f"**Summary** \n"
153
- f"- Total: {total} \n"
154
- f"- Negative: {len(neg_df)} \n"
155
- f"- Positive: {len(pos_df)} \n"
156
- f"- Avg negative: {neg_avg:.2f}% \n"
157
- f"- Avg positive: {pos_avg:.2f}%"
158
- )
159
  return fig_bar, fig_pie, info
160
 
161
  def _resample_counts(df, date_col, freq):
162
- g = df.groupby([pd.Grouper(key=date_col, freq=freq), "label"]).size().unstack(fill_value=0)
163
- for col in ["negative","positive"]:
164
- if col not in g.columns:
165
- g[col] = 0
166
  return g[["negative","positive"]].sort_index()
167
 
168
- def _rolling_window(freq):
169
- return 7 if freq == "D" else (4 if freq == "W" else 3)
170
-
171
- def make_time_chart(df: pd.DataFrame, date_col: str, freq: str, use_ma: bool):
172
- ts = _resample_counts(df, date_col, freq)
173
- if use_ma:
174
- win = _rolling_window(freq)
175
- ts = ts.rolling(win, min_periods=1).mean()
176
 
177
- fig_line = go.Figure()
178
- fig_line.add_scatter(x=ts.index, y=ts["negative"], mode="lines",
179
- name="negative", line=dict(color=NEG_COLOR))
180
- fig_line.add_scatter(x=ts.index, y=ts["positive"], mode="lines",
181
- name="positive", line=dict(color=POS_COLOR))
182
- fig_line.update_layout(title="Reviews over time (POS/NEG)", template=TEMPLATE,
183
- xaxis_title="Date", yaxis_title="Count")
184
- return fig_line
 
185
 
186
  # ================= Core Predict =================
187
  def _predict_batch(texts, model_name, batch_size=32):
188
- model, tok, cfg = load_model(model_name)
189
- results = []
190
- for i in range(0, len(texts), batch_size):
191
- chunk = texts[i:i+batch_size]
192
- enc = tok(chunk, padding=True, truncation=True,
193
- max_length=cfg.get("max_length",128), return_tensors="pt")
194
  with torch.no_grad():
195
- logits = model(enc["input_ids"], enc["attention_mask"])
196
- probs = F.softmax(logits, dim=1).cpu().numpy()
197
- for txt, p in zip(chunk, probs):
198
- neg, pos = float(p[0]), float(p[1])
199
- label = "positive" if pos >= neg else "negative"
200
- results.append({
201
- "review": txt,
202
- "negative(%)": _format_pct(neg),
203
- "positive(%)": _format_pct(pos),
204
- "label": label,
205
- })
206
  return results
207
 
208
- # ================= Batch (Textarea) =================
209
- def predict_many(text_block: str, model_choice: str):
210
  try:
211
- raw_lines = (text_block or "").splitlines()
212
- all_norm = [_norm_text(t) for t in raw_lines]
213
- cleaned = [t for t in all_norm if _is_substantive_text(t)]
214
- skipped = len(all_norm) - len(cleaned)
215
- if len(cleaned) == 0:
216
- empty = pd.DataFrame(columns=["review","negative(%)","positive(%)","label"])
217
- return empty, go.Figure(), go.Figure(), "No valid text"
218
- results = _predict_batch(cleaned, model_choice)
219
- df = pd.DataFrame(results)
220
- fig_bar, fig_pie, info_md = make_basic_charts(df)
221
- info_md = f"{info_md} \n- Skipped: {skipped}"
222
- return df, fig_bar, fig_pie, info_md
223
- except Exception:
224
- tb = traceback.format_exc()
225
- empty = pd.DataFrame(columns=["review","negative(%)","positive(%)","label"])
226
- return empty, go.Figure(), go.Figure(), f"**Error**\n```\n{tb}\n```"
227
-
228
- # ================= CSV Inspect (auto-detect & toggle UI) =================
229
  def on_file_change(file_obj):
230
- """
231
- เมื่ออัปโหลดไฟล์:
232
- - คืน options ของ text/date dropdown
233
- - ชื่อ default ที่เลือก
234
- - toggle visibility ของ date controls + line chart placeholder
235
- """
236
  if file_obj is None:
237
- return (
238
- gr.update(choices=[], value=None), # text_dd
239
- gr.update(choices=[], value=None), # date_dd
240
- gr.update(visible=False), # date_from
241
- gr.update(visible=False), # date_to
242
- gr.update(visible=False), # freq
243
- gr.update(visible=False), # use_ma
244
- gr.update(visible=False), # line chart
245
- "Please upload a CSV file."
246
- )
247
-
248
  try:
249
- df_raw = pd.read_csv(file_obj.name)
250
- cols = list(df_raw.columns)
251
- text_col, date_candidates, date_col = detect_text_and_date_cols(df_raw)
252
-
253
- has_date = date_col is not None
254
- note = "Detected text column: **{}**".format(text_col)
255
- if has_date:
256
- note += "; detected date column: **{}**".format(date_col)
257
- else:
258
- note += "; _no date/timestamp column detected_"
259
-
260
- return (
261
- gr.update(choices=cols, value=text_col),
262
- gr.update(choices=date_candidates, value=date_col),
263
- gr.update(visible=has_date),
264
- gr.update(visible=has_date),
265
- gr.update(visible=has_date),
266
- gr.update(visible=has_date),
267
- gr.update(visible=has_date),
268
- note
269
- )
270
- except Exception:
271
- tb = traceback.format_exc()
272
- return (
273
- gr.update(choices=[], value=None),
274
- gr.update(choices=[], value=None),
275
- gr.update(visible=False),
276
- gr.update(visible=False),
277
- gr.update(visible=False),
278
- gr.update(visible=False),
279
- gr.update(visible=False),
280
- f"**Error reading CSV**\n```\n{tb}\n```"
281
- )
282
 
283
  # ================= CSV Predict =================
284
- def predict_csv(file_obj, model_choice: str, text_col_name: str,
285
- date_col_name: str, date_from, date_to,
286
- freq_choice: str, use_ma: bool):
287
-
288
  try:
289
- if file_obj is None:
290
- return pd.DataFrame(), go.Figure(), go.Figure(), gr.update(visible=False, value=go.Figure()), "Please upload a CSV.", None
291
-
292
- df_raw = pd.read_csv(file_obj.name)
293
- cols = list(df_raw.columns)
294
-
295
- col_text = text_col_name if text_col_name in cols else detect_text_and_date_cols(df_raw)[0]
296
-
297
- texts = [_norm_text(v) for v in df_raw[col_text].tolist()]
298
- texts = [t for t in texts if _is_substantive_text(t)]
299
- if len(texts) == 0:
300
- return pd.DataFrame(), go.Figure(), go.Figure(), gr.update(visible=False, value=go.Figure()), "No valid texts in selected column.", None
301
-
302
- # predict
303
- results = _predict_batch(texts, model_choice)
304
- out_df = pd.DataFrame(results)
305
-
306
- # basic charts
307
- fig_bar, fig_pie, info_basic = make_basic_charts(out_df)
308
-
309
- # time charts (optional)
310
- show_time = False
311
- fig_line = go.Figure()
312
- if date_col_name and (date_col_name in cols):
313
- dts = _to_datetime_safe(df_raw[date_col_name])
314
  if dts.notna().any():
315
- df_time = out_df.copy()
316
- df_time["__dt__"] = dts
317
- df_time = df_time.dropna(subset=["__dt__"])
318
-
319
- # normalize datepicker values
320
- start_ts = _normalize_datepicker_value(date_from)
321
- end_ts = _normalize_datepicker_value(date_to)
322
-
323
- if start_ts is not None:
324
- df_time = df_time[df_time["__dt__"] >= start_ts]
325
- if end_ts is not None:
326
- df_time = df_time[df_time["__dt__"] <= end_ts]
327
-
328
- if len(df_time) > 0:
329
- fig_line = make_time_chart(df_time, "__dt__", freq_choice, use_ma)
330
- show_time = True
331
-
332
- # downloadable CSV
333
- fd, out_path = tempfile.mkstemp(prefix="pred_", suffix=".csv")
334
- os.close(fd)
335
- out_df.to_csv(out_path, index=False, encoding="utf-8-sig")
336
-
337
- info_time = ""
338
- if date_col_name:
339
- if show_time:
340
- info_time = f"\n\nTime chart based on date column: **{date_col_name}**, Freq: **{freq_choice}**, MA: **{use_ma}**"
341
- else:
342
- info_time = "\n\n_Selected date range has no data OR unable to parse dates._"
343
- else:
344
- info_time = "\n\n_No date/timestamp column selected — time chart hidden._"
345
-
346
- info_md = info_basic + info_time
347
- return out_df, fig_bar, fig_pie, gr.update(visible=show_time, value=fig_line), info_md, out_path
348
-
349
- except Exception:
350
- tb = traceback.format_exc()
351
- return pd.DataFrame(), go.Figure(), go.Figure(), gr.update(visible=False, value=go.Figure()), f"**Error**\n```\n{tb}\n```", None
352
 
353
  # ================= Gradio UI =================
354
- with gr.Blocks(title="Thai Sentiment (WangchanBERTa Variants)") as demo:
355
- gr.Markdown("### Thai Sentiment (WangchanBERTa Variants) Focus on POS/NEG")
 
356
 
357
- model_radio = gr.Radio(choices=AVAILABLE_CHOICES, value=DEFAULT_MODEL, label="เลือกโมเดล")
 
 
 
 
358
 
359
- # ---- Batch (Textarea) ----
360
- with gr.Tab("Batch (หลายข้อความ)"):
361
- t2 = gr.Textbox(lines=8, label="พิมพ์หลายรีวิว (บรรทัดละ 1 รีวิว)")
362
- btn_batch = gr.Button("Predict", variant="primary")
363
- df2 = gr.Dataframe(label="ผลลัพธ์", interactive=False)
364
- bar2 = gr.Plot(label="Label counts (bar)")
365
- pie2 = gr.Plot(label="Positive vs Negative (pie)")
366
- sum2 = gr.Markdown()
367
- btn_batch.click(predict_many, [t2, model_radio], [df2, bar2, pie2, sum2])
368
-
369
- # ---- CSV Upload ----
370
  with gr.Tab("CSV Upload"):
371
  with gr.Row():
372
- file_in = gr.File(label="อัปโหลดไฟล์ .csv", file_types=[".csv"])
373
- text_dd = gr.Dropdown(label="คอลัมน์ข้อความ", choices=[], value=None)
374
- date_dd = gr.Dropdown(label="คอลัมน์วันเวลา (ถ้ามี)", choices=[], value=None)
375
  with gr.Row():
376
- # ใช้ DatePicker แทน Date (รองรับ gradio เวอร์ชันที่ไม่เคยมี gr.Date)
377
- date_from = gr.DatePicker(label="เริ่มวันที่", visible=False)
378
- date_to = gr.DatePicker(label="ถึงวันที่", visible=False)
379
- freq = gr.Radio(choices=["D","W","M"], value="D", label="ความถี่ (Day/Week/Month)", visible=False)
380
- use_ma = gr.Checkbox(value=True, label="Moving average (7/4/3)", visible=False)
381
-
382
- btn_csv = gr.Button("Predict CSV", variant="primary")
383
- note_detect = gr.Markdown()
384
-
385
- df3 = gr.Dataframe(label="ผลลัพธ์", interactive=False)
386
- bar3 = gr.Plot(label="Label counts (bar)")
387
- pie3 = gr.Plot(label="Positive vs Negative (pie)")
388
- line = gr.Plot(label="Reviews over time (POS/NEG)", visible=False)
389
- sum3 = gr.Markdown()
390
- dl3 = gr.File(label="ดาวน์โหลดผลเป็น CSV", interactive=False)
391
-
392
- file_in.change(
393
- on_file_change, [file_in],
394
- [text_dd, date_dd, date_from, date_to, freq, use_ma, line, note_detect]
395
- )
396
-
397
- btn_csv.click(
398
- predict_csv,
399
- [file_in, model_radio, text_dd, date_dd, date_from, date_to, freq, use_ma],
400
- [df3, bar3, pie3, line, sum3, dl3]
401
- )
402
-
403
- if __name__ == "__main__":
404
- demo.launch()
 
1
  # app.py — Thai Sentiment (WangchanBERTa Variants)
2
+ # - Focus on POS/NEG only
3
+ # - Batch + CSV tabs
4
+ # - CSV: auto-detect text/date cols, hide date widgets if no date col
5
+ # - DatePicker fallback to Textbox if component missing
6
+
7
  import os, json, importlib.util, traceback, re, math, tempfile, datetime
8
  import gradio as gr
9
  import torch, pandas as pd
 
22
  if DEFAULT_MODEL not in AVAILABLE_CHOICES:
23
  DEFAULT_MODEL = "WCB"
24
 
25
+ NEG_COLOR = "#F87171"
26
+ POS_COLOR = "#34D399"
27
  TEMPLATE = "plotly_white"
 
28
  CACHE = {}
29
 
30
+ # ================= Date Component Fallback =================
31
+ try:
32
+ DateInput = getattr(gr, "Date", None) or getattr(gr, "DatePicker", None)
33
+ except Exception:
34
+ DateInput = None
35
+ DATE_FALLBACK_TO_TEXT = False
36
+ if DateInput is None:
37
+ DateInput = gr.Textbox
38
+ DATE_FALLBACK_TO_TEXT = True
39
+
40
+ def _normalize_date_input(v):
41
+ if v is None: return None
42
+ if isinstance(v, float) and math.isnan(v): return None
43
+ if isinstance(v, datetime.date): return pd.Timestamp(v)
44
+ try:
45
+ ts = pd.to_datetime(v, errors="coerce")
46
+ return ts if pd.notna(ts) else None
47
+ except Exception:
48
+ return None
49
+
50
  # ================= Loader =================
51
  def _import_models():
52
  if "models_module" in CACHE:
 
62
  key = f"model:{model_name}"
63
  if key in CACHE:
64
  return CACHE[key]
 
65
  cfg_path = hf_hub_download(REPO_ID, filename=f"{model_name}/config.json", token=HF_TOKEN)
66
  w_path = hf_hub_download(REPO_ID, filename=f"{model_name}/model.safetensors", token=HF_TOKEN)
 
67
  with open(cfg_path, "r", encoding="utf-8") as f:
68
  cfg = json.load(f)
 
69
  base_model = cfg.get("base_model", "airesearch/wangchanberta-base-att-spm-uncased")
70
  arch_name = cfg.get("architecture", model_name)
 
71
  tok = AutoTokenizer.from_pretrained(base_model)
72
  models = _import_models()
73
  model = models._build(arch_name, base_model, int(cfg.get("num_labels",2)),
74
  cfg.get("pooling_after_lstm", "masked_mean"))
 
75
  state = load_file(w_path)
76
  model.load_state_dict(state, strict=False)
77
  model.eval()
 
78
  CACHE[key] = (model, tok, cfg)
79
  return CACHE[key]
80
 
81
  # ================= Utils =================
82
+ _INVALID_STRINGS = {"-", "--","—","n/a","na","null","none","nan",".","…",""}
83
  _RE_HAS_LETTER = re.compile(r"[ก-๙A-Za-z]")
84
 
85
+ def _norm_text(v):
86
  if v is None: return ""
87
  if isinstance(v, float) and math.isnan(v): return ""
88
  return str(v).strip().strip('"').strip("'").strip(",")
89
 
90
+ def _is_substantive_text(s, min_chars=2):
91
  if not s: return False
92
  if s.lower() in _INVALID_STRINGS: return False
93
  if not _RE_HAS_LETTER.search(s): return False
94
+ if len(s.replace(" ","")) < min_chars: return False
95
  return True
96
 
97
+ def _format_pct(x): return f"{x*100:.2f}%"
98
+ def _to_datetime_safe(s): return pd.to_datetime(s, errors="coerce", infer_datetime_format=True, utc=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
  LIKELY_TEXT_COLS = ["text","review","message","comment","content","sentence","body","ข้อความ","รีวิว"]
101
  LIKELY_DATE_COLS = ["date","created_at","time","timestamp","datetime","วันที่","วันเวลา","เวลา"]
102
 
103
+ def detect_text_and_date_cols(df):
104
  cols = list(df.columns)
 
105
  low = {c.lower(): c for c in cols}
106
  text_col = None
107
  for k in LIKELY_TEXT_COLS:
108
+ if k in low: text_col = low[k]; break
 
109
  if text_col is None:
110
  cand = [c for c in cols if df[c].dtype == object]
111
  text_col = cand[0] if cand else cols[0]
 
 
112
  date_candidates = []
113
  for c in cols:
114
+ if c.lower() in LIKELY_DATE_COLS: date_candidates.append(c); continue
 
 
115
  sample = df[c].head(50)
116
  if _to_datetime_safe(sample).notna().sum() >= max(3, int(len(sample)*0.2)):
117
  date_candidates.append(c)
 
120
  return text_col, date_candidates, date_col
121
 
122
  # ================= Charts =================
123
+ def make_basic_charts(df):
124
  total = len(df)
125
+ neg_df = df[df["label"]=="negative"]; pos_df = df[df["label"]=="positive"]
 
 
 
126
  fig_bar = go.Figure()
127
  fig_bar.add_bar(name="negative", x=["negative"], y=[len(neg_df)], marker_color=NEG_COLOR)
128
  fig_bar.add_bar(name="positive", x=["positive"], y=[len(pos_df)], marker_color=POS_COLOR)
129
  fig_bar.update_layout(barmode="group", title="Label counts", template=TEMPLATE)
130
+ labels=["negative","positive"]; values=[len(neg_df), len(pos_df)]
131
+ fig_pie = go.Figure(go.Pie(labels=labels, values=values, hole=0.35,
 
 
 
132
  marker=dict(colors=[NEG_COLOR, POS_COLOR])))
133
  fig_pie.update_layout(title="Positive vs Negative", template=TEMPLATE)
 
134
  neg_avg = pd.to_numeric(df["negative(%)"].str.rstrip("%"), errors="coerce").mean()
135
  pos_avg = pd.to_numeric(df["positive(%)"].str.rstrip("%"), errors="coerce").mean()
136
+ info=(f"**Summary**\n- Total: {total}\n- Negative: {len(neg_df)}\n- Positive: {len(pos_df)}\n"
137
+ f"- Avg negative: {neg_avg:.2f}%\n- Avg positive: {pos_avg:.2f}%")
 
 
 
 
 
 
138
  return fig_bar, fig_pie, info
139
 
140
  def _resample_counts(df, date_col, freq):
141
+ g = df.groupby([pd.Grouper(key=date_col, freq=freq),"label"]).size().unstack(fill_value=0)
142
+ for c in ["negative","positive"]:
143
+ if c not in g.columns: g[c]=0
 
144
  return g[["negative","positive"]].sort_index()
145
 
146
+ def _rolling_window(freq): return 7 if freq=="D" else (4 if freq=="W" else 3)
 
 
 
 
 
 
 
147
 
148
+ def make_time_chart(df, date_col, freq, use_ma):
149
+ ts=_resample_counts(df,date_col,freq)
150
+ if use_ma: ts=ts.rolling(_rolling_window(freq), min_periods=1).mean()
151
+ fig=go.Figure()
152
+ fig.add_scatter(x=ts.index,y=ts["negative"],mode="lines",name="negative",line=dict(color=NEG_COLOR))
153
+ fig.add_scatter(x=ts.index,y=ts["positive"],mode="lines",name="positive",line=dict(color=POS_COLOR))
154
+ fig.update_layout(title="Reviews over time (POS/NEG)",template=TEMPLATE,
155
+ xaxis_title="Date",yaxis_title="Count")
156
+ return fig
157
 
158
  # ================= Core Predict =================
159
  def _predict_batch(texts, model_name, batch_size=32):
160
+ model,tok,cfg=load_model(model_name); results=[]
161
+ for i in range(0,len(texts),batch_size):
162
+ chunk=texts[i:i+batch_size]
163
+ enc=tok(chunk,padding=True,truncation=True,
164
+ max_length=cfg.get("max_length",128),return_tensors="pt")
 
165
  with torch.no_grad():
166
+ logits=model(enc["input_ids"],enc["attention_mask"])
167
+ probs=F.softmax(logits,dim=1).cpu().numpy()
168
+ for txt,p in zip(chunk,probs):
169
+ neg,pos=float(p[0]),float(p[1])
170
+ label="positive" if pos>=neg else "negative"
171
+ results.append({"review":txt,"negative(%)":_format_pct(neg),
172
+ "positive(%)":_format_pct(pos),"label":label})
 
 
 
 
173
  return results
174
 
175
+ # ================= Batch =================
176
+ def predict_many(text_block, model_choice):
177
  try:
178
+ raw=(text_block or "").splitlines()
179
+ norm=[_norm_text(t) for t in raw]; clean=[t for t in norm if _is_substantive_text(t)]
180
+ if not clean: return pd.DataFrame(),go.Figure(),go.Figure(),"No valid text"
181
+ results=_predict_batch(clean,model_choice); df=pd.DataFrame(results)
182
+ bar,pie,info=make_basic_charts(df)
183
+ return df,bar,pie,info
184
+ except: return pd.DataFrame(),go.Figure(),go.Figure(),traceback.format_exc()
185
+
186
+ # ================= CSV Inspect =================
 
 
 
 
 
 
 
 
 
187
  def on_file_change(file_obj):
 
 
 
 
 
 
188
  if file_obj is None:
189
+ return gr.update(choices=[],value=None),gr.update(choices=[],value=None),\
190
+ gr.update(visible=False),gr.update(visible=False),\
191
+ gr.update(visible=False),gr.update(visible=False),\
192
+ gr.update(visible=False),"Please upload a CSV"
 
 
 
 
 
 
 
193
  try:
194
+ df=pd.read_csv(file_obj.name)
195
+ text_col,date_candidates,date_col=detect_text_and_date_cols(df)
196
+ has_date=date_col is not None
197
+ note=f"Detected text col: **{text_col}**; "+("date col: **{}**".format(date_col) if has_date else "_no date col_")
198
+ return gr.update(choices=list(df.columns),value=text_col),\
199
+ gr.update(choices=date_candidates,value=date_col),\
200
+ gr.update(visible=has_date),gr.update(visible=has_date),\
201
+ gr.update(visible=has_date),gr.update(visible=has_date),\
202
+ gr.update(visible=has_date),note
203
+ except: return gr.update(choices=[],value=None),gr.update(choices=[],value=None),\
204
+ gr.update(visible=False),gr.update(visible=False),\
205
+ gr.update(visible=False),gr.update(visible=False),\
206
+ gr.update(visible=False),"Error reading CSV"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
 
208
  # ================= CSV Predict =================
209
+ def predict_csv(file_obj,model_choice,text_col,date_col,date_from,date_to,freq,use_ma):
210
+ if file_obj is None: return pd.DataFrame(),go.Figure(),go.Figure(),gr.update(visible=False), "No file",None
 
 
211
  try:
212
+ df_raw=pd.read_csv(file_obj.name); cols=list(df_raw.columns)
213
+ if text_col not in cols: text_col,_d,_dc=detect_text_and_date_cols(df_raw);
214
+ texts=[_norm_text(v) for v in df_raw[text_col].tolist()]
215
+ texts=[t for t in texts if _is_substantive_text(t)]
216
+ if not texts: return pd.DataFrame(),go.Figure(),go.Figure(),gr.update(visible=False),"No valid texts",None
217
+ results=_predict_batch(texts,model_choice); out=pd.DataFrame(results)
218
+ bar,pie,info=make_basic_charts(out)
219
+ fig_line=go.Figure(); show_time=False
220
+ if date_col and date_col in cols:
221
+ dts=_to_datetime_safe(df_raw[date_col])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
  if dts.notna().any():
223
+ df_time=out.copy(); df_time["__dt__"]=dts; df_time=df_time.dropna(subset=["__dt__"])
224
+ start_ts=_normalize_date_input(date_from); end_ts=_normalize_date_input(date_to)
225
+ if start_ts is not None: df_time=df_time[df_time["__dt__"]>=start_ts]
226
+ if end_ts is not None: df_time=df_time[df_time["__dt__"]<=end_ts]
227
+ if len(df_time)>0: fig_line=make_time_chart(df_time,"__dt__",freq,use_ma); show_time=True
228
+ fd,path=tempfile.mkstemp(suffix=".csv"); os.close(fd)
229
+ out.to_csv(path,index=False,encoding="utf-8-sig")
230
+ return out,bar,pie,gr.update(visible=show_time,value=fig_line),info,path
231
+ except: return pd.DataFrame(),go.Figure(),go.Figure(),gr.update(visible=False),"Error\n"+traceback.format_exc(),None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
 
233
  # ================= Gradio UI =================
234
+ with gr.Blocks(title="Thai Sentiment") as demo:
235
+ gr.Markdown("### Thai Sentiment — WangchanBERTa Variants")
236
+ model_radio=gr.Radio(choices=AVAILABLE_CHOICES,value=DEFAULT_MODEL,label="เลือกโมเดล")
237
 
238
+ with gr.Tab("Batch"):
239
+ t2=gr.Textbox(lines=8,label="รีวิว (บรรทัดละ 1)")
240
+ btn2=gr.Button("Predict",variant="primary")
241
+ df2=gr.Dataframe(); bar2=gr.Plot(); pie2=gr.Plot(); sum2=gr.Markdown()
242
+ btn2.click(predict_many,[t2,model_radio],[df2,bar2,pie2,sum2])
243
 
 
 
 
 
 
 
 
 
 
 
 
244
  with gr.Tab("CSV Upload"):
245
  with gr.Row():
246
+ file_in=gr.File(file_types=[".csv"]); text_dd=gr.Dropdown(label="Text col")
247
+ date_dd=gr.Dropdown(label="Date col (opt)")
 
248
  with gr.Row():
249
+ date_from=DateInput(label="เริ่มวันที่"+(" (YYYY-MM-DD)" if DATE_FALLBACK_TO_TEXT else ""),visible=False)
250
+ date_to=DateInput(label="ถึงวันที่"+(" (YYYY-MM-DD)" if DATE_FALLBACK_TO_TEXT else ""),visible=False)
251
+ freq=gr.Radio(choices=["D","W","M"],value="D",label="Freq",visible=False)
252
+ use_ma=gr.Checkbox(value=True,label="MA",visible=False)
253
+ btn3=gr.Button("Predict CSV",variant="primary")
254
+ note=gr.Markdown()
255
+ df3=gr.Dataframe(); bar3=gr.Plot(); pie3=gr.Plot()
256
+ line=gr.Plot(visible=False); sum3=gr.Markdown(); dl=gr.File()
257
+
258
+ file_in.change(on_file_change,[file_in],[text_dd,date_dd,date_from,date_to,freq,use_ma,line,note])
259
+ btn3.click(predict_csv,[file_in,model_radio,text_dd,date_dd,date_from,date_to,freq,use_ma],[df3,bar3,pie3,line,sum3,dl])
260
+
261
+ if __name__=="__main__": demo.launch()