Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
| 1 |
# app.py — Thai Sentiment (WangchanBERTa Variants)
|
| 2 |
-
# -
|
| 3 |
-
# -
|
| 4 |
-
# - CSV
|
| 5 |
-
# -
|
|
|
|
| 6 |
import os, json, importlib.util, traceback, re, math, tempfile, datetime
|
| 7 |
import gradio as gr
|
| 8 |
import torch, pandas as pd
|
|
@@ -21,12 +22,31 @@ AVAILABLE_CHOICES = ["WCB", "WCB_BiLSTM", "WCB_CNN_BiLSTM", "WCB_4Layer_BiLSTM"]
|
|
| 21 |
if DEFAULT_MODEL not in AVAILABLE_CHOICES:
|
| 22 |
DEFAULT_MODEL = "WCB"
|
| 23 |
|
| 24 |
-
NEG_COLOR = "#F87171"
|
| 25 |
-
POS_COLOR = "#34D399"
|
| 26 |
TEMPLATE = "plotly_white"
|
| 27 |
-
|
| 28 |
CACHE = {}
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
# ================= Loader =================
|
| 31 |
def _import_models():
|
| 32 |
if "models_module" in CACHE:
|
|
@@ -42,84 +62,56 @@ def load_model(model_name: str):
|
|
| 42 |
key = f"model:{model_name}"
|
| 43 |
if key in CACHE:
|
| 44 |
return CACHE[key]
|
| 45 |
-
|
| 46 |
cfg_path = hf_hub_download(REPO_ID, filename=f"{model_name}/config.json", token=HF_TOKEN)
|
| 47 |
w_path = hf_hub_download(REPO_ID, filename=f"{model_name}/model.safetensors", token=HF_TOKEN)
|
| 48 |
-
|
| 49 |
with open(cfg_path, "r", encoding="utf-8") as f:
|
| 50 |
cfg = json.load(f)
|
| 51 |
-
|
| 52 |
base_model = cfg.get("base_model", "airesearch/wangchanberta-base-att-spm-uncased")
|
| 53 |
arch_name = cfg.get("architecture", model_name)
|
| 54 |
-
|
| 55 |
tok = AutoTokenizer.from_pretrained(base_model)
|
| 56 |
models = _import_models()
|
| 57 |
model = models._build(arch_name, base_model, int(cfg.get("num_labels",2)),
|
| 58 |
cfg.get("pooling_after_lstm", "masked_mean"))
|
| 59 |
-
|
| 60 |
state = load_file(w_path)
|
| 61 |
model.load_state_dict(state, strict=False)
|
| 62 |
model.eval()
|
| 63 |
-
|
| 64 |
CACHE[key] = (model, tok, cfg)
|
| 65 |
return CACHE[key]
|
| 66 |
|
| 67 |
# ================= Utils =================
|
| 68 |
-
_INVALID_STRINGS = {"-", "--",
|
| 69 |
_RE_HAS_LETTER = re.compile(r"[ก-๙A-Za-z]")
|
| 70 |
|
| 71 |
-
def _norm_text(v)
|
| 72 |
if v is None: return ""
|
| 73 |
if isinstance(v, float) and math.isnan(v): return ""
|
| 74 |
return str(v).strip().strip('"').strip("'").strip(",")
|
| 75 |
|
| 76 |
-
def _is_substantive_text(s
|
| 77 |
if not s: return False
|
| 78 |
if s.lower() in _INVALID_STRINGS: return False
|
| 79 |
if not _RE_HAS_LETTER.search(s): return False
|
| 80 |
-
if len(s.replace(" ",
|
| 81 |
return True
|
| 82 |
|
| 83 |
-
def _format_pct(x:
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
def _to_datetime_safe(s):
|
| 87 |
-
return pd.to_datetime(s, errors="coerce", infer_datetime_format=True, utc=False)
|
| 88 |
-
|
| 89 |
-
def _normalize_datepicker_value(v):
|
| 90 |
-
"""รับค่าจาก gr.DatePicker (datetime.date หรือ str หรือ None) → pandas.Timestamp หรือ None"""
|
| 91 |
-
if v is None or (isinstance(v, float) and math.isnan(v)):
|
| 92 |
-
return None
|
| 93 |
-
if isinstance(v, datetime.date):
|
| 94 |
-
return pd.Timestamp(v)
|
| 95 |
-
# เผื่อบางเวอร์ชันส่ง str 'YYYY-MM-DD'
|
| 96 |
-
try:
|
| 97 |
-
ts = pd.to_datetime(v, errors="coerce")
|
| 98 |
-
return ts if pd.notna(ts) else None
|
| 99 |
-
except Exception:
|
| 100 |
-
return None
|
| 101 |
|
| 102 |
LIKELY_TEXT_COLS = ["text","review","message","comment","content","sentence","body","ข้อความ","รีวิว"]
|
| 103 |
LIKELY_DATE_COLS = ["date","created_at","time","timestamp","datetime","วันที่","วันเวลา","เวลา"]
|
| 104 |
|
| 105 |
-
def detect_text_and_date_cols(df
|
| 106 |
cols = list(df.columns)
|
| 107 |
-
# text col
|
| 108 |
low = {c.lower(): c for c in cols}
|
| 109 |
text_col = None
|
| 110 |
for k in LIKELY_TEXT_COLS:
|
| 111 |
-
if k in low:
|
| 112 |
-
text_col = low[k]; break
|
| 113 |
if text_col is None:
|
| 114 |
cand = [c for c in cols if df[c].dtype == object]
|
| 115 |
text_col = cand[0] if cand else cols[0]
|
| 116 |
-
|
| 117 |
-
# date candidates
|
| 118 |
date_candidates = []
|
| 119 |
for c in cols:
|
| 120 |
-
if c.lower() in LIKELY_DATE_COLS:
|
| 121 |
-
date_candidates.append(c)
|
| 122 |
-
continue
|
| 123 |
sample = df[c].head(50)
|
| 124 |
if _to_datetime_safe(sample).notna().sum() >= max(3, int(len(sample)*0.2)):
|
| 125 |
date_candidates.append(c)
|
|
@@ -128,277 +120,142 @@ def detect_text_and_date_cols(df: pd.DataFrame):
|
|
| 128 |
return text_col, date_candidates, date_col
|
| 129 |
|
| 130 |
# ================= Charts =================
|
| 131 |
-
def make_basic_charts(df
|
| 132 |
total = len(df)
|
| 133 |
-
neg_df = df[df["label"]
|
| 134 |
-
pos_df = df[df["label"] == "positive"].copy()
|
| 135 |
-
|
| 136 |
-
# bar counts
|
| 137 |
fig_bar = go.Figure()
|
| 138 |
fig_bar.add_bar(name="negative", x=["negative"], y=[len(neg_df)], marker_color=NEG_COLOR)
|
| 139 |
fig_bar.add_bar(name="positive", x=["positive"], y=[len(pos_df)], marker_color=POS_COLOR)
|
| 140 |
fig_bar.update_layout(barmode="group", title="Label counts", template=TEMPLATE)
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
labels = ["negative", "positive"]
|
| 144 |
-
values = [len(neg_df), len(pos_df)]
|
| 145 |
-
fig_pie = go.Figure(go.Pie(labels=labels, values=values, hole=0.35, sort=False,
|
| 146 |
marker=dict(colors=[NEG_COLOR, POS_COLOR])))
|
| 147 |
fig_pie.update_layout(title="Positive vs Negative", template=TEMPLATE)
|
| 148 |
-
|
| 149 |
neg_avg = pd.to_numeric(df["negative(%)"].str.rstrip("%"), errors="coerce").mean()
|
| 150 |
pos_avg = pd.to_numeric(df["positive(%)"].str.rstrip("%"), errors="coerce").mean()
|
| 151 |
-
info
|
| 152 |
-
|
| 153 |
-
f"- Total: {total} \n"
|
| 154 |
-
f"- Negative: {len(neg_df)} \n"
|
| 155 |
-
f"- Positive: {len(pos_df)} \n"
|
| 156 |
-
f"- Avg negative: {neg_avg:.2f}% \n"
|
| 157 |
-
f"- Avg positive: {pos_avg:.2f}%"
|
| 158 |
-
)
|
| 159 |
return fig_bar, fig_pie, info
|
| 160 |
|
| 161 |
def _resample_counts(df, date_col, freq):
|
| 162 |
-
g = df.groupby([pd.Grouper(key=date_col, freq=freq),
|
| 163 |
-
for
|
| 164 |
-
if
|
| 165 |
-
g[col] = 0
|
| 166 |
return g[["negative","positive"]].sort_index()
|
| 167 |
|
| 168 |
-
def _rolling_window(freq):
|
| 169 |
-
return 7 if freq == "D" else (4 if freq == "W" else 3)
|
| 170 |
-
|
| 171 |
-
def make_time_chart(df: pd.DataFrame, date_col: str, freq: str, use_ma: bool):
|
| 172 |
-
ts = _resample_counts(df, date_col, freq)
|
| 173 |
-
if use_ma:
|
| 174 |
-
win = _rolling_window(freq)
|
| 175 |
-
ts = ts.rolling(win, min_periods=1).mean()
|
| 176 |
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
|
|
|
| 185 |
|
| 186 |
# ================= Core Predict =================
|
| 187 |
def _predict_batch(texts, model_name, batch_size=32):
|
| 188 |
-
model,
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
chunk
|
| 192 |
-
|
| 193 |
-
max_length=cfg.get("max_length",128), return_tensors="pt")
|
| 194 |
with torch.no_grad():
|
| 195 |
-
logits
|
| 196 |
-
probs
|
| 197 |
-
for txt,
|
| 198 |
-
neg,
|
| 199 |
-
label
|
| 200 |
-
results.append({
|
| 201 |
-
|
| 202 |
-
"negative(%)": _format_pct(neg),
|
| 203 |
-
"positive(%)": _format_pct(pos),
|
| 204 |
-
"label": label,
|
| 205 |
-
})
|
| 206 |
return results
|
| 207 |
|
| 208 |
-
# ================= Batch
|
| 209 |
-
def predict_many(text_block
|
| 210 |
try:
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
fig_bar, fig_pie, info_md = make_basic_charts(df)
|
| 221 |
-
info_md = f"{info_md} \n- Skipped: {skipped}"
|
| 222 |
-
return df, fig_bar, fig_pie, info_md
|
| 223 |
-
except Exception:
|
| 224 |
-
tb = traceback.format_exc()
|
| 225 |
-
empty = pd.DataFrame(columns=["review","negative(%)","positive(%)","label"])
|
| 226 |
-
return empty, go.Figure(), go.Figure(), f"**Error**\n```\n{tb}\n```"
|
| 227 |
-
|
| 228 |
-
# ================= CSV Inspect (auto-detect & toggle UI) =================
|
| 229 |
def on_file_change(file_obj):
|
| 230 |
-
"""
|
| 231 |
-
เมื่ออัปโหลดไฟล์:
|
| 232 |
-
- คืน options ของ text/date dropdown
|
| 233 |
-
- ชื่อ default ที่เลือก
|
| 234 |
-
- toggle visibility ของ date controls + line chart placeholder
|
| 235 |
-
"""
|
| 236 |
if file_obj is None:
|
| 237 |
-
return (
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
gr.update(visible=False), # date_to
|
| 242 |
-
gr.update(visible=False), # freq
|
| 243 |
-
gr.update(visible=False), # use_ma
|
| 244 |
-
gr.update(visible=False), # line chart
|
| 245 |
-
"Please upload a CSV file."
|
| 246 |
-
)
|
| 247 |
-
|
| 248 |
try:
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
gr.update(choices=date_candidates, value=date_col),
|
| 263 |
-
gr.update(visible=has_date),
|
| 264 |
-
gr.update(visible=has_date),
|
| 265 |
-
gr.update(visible=has_date),
|
| 266 |
-
gr.update(visible=has_date),
|
| 267 |
-
gr.update(visible=has_date),
|
| 268 |
-
note
|
| 269 |
-
)
|
| 270 |
-
except Exception:
|
| 271 |
-
tb = traceback.format_exc()
|
| 272 |
-
return (
|
| 273 |
-
gr.update(choices=[], value=None),
|
| 274 |
-
gr.update(choices=[], value=None),
|
| 275 |
-
gr.update(visible=False),
|
| 276 |
-
gr.update(visible=False),
|
| 277 |
-
gr.update(visible=False),
|
| 278 |
-
gr.update(visible=False),
|
| 279 |
-
gr.update(visible=False),
|
| 280 |
-
f"**Error reading CSV**\n```\n{tb}\n```"
|
| 281 |
-
)
|
| 282 |
|
| 283 |
# ================= CSV Predict =================
|
| 284 |
-
def predict_csv(file_obj,
|
| 285 |
-
|
| 286 |
-
freq_choice: str, use_ma: bool):
|
| 287 |
-
|
| 288 |
try:
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
if len(texts) == 0:
|
| 300 |
-
return pd.DataFrame(), go.Figure(), go.Figure(), gr.update(visible=False, value=go.Figure()), "No valid texts in selected column.", None
|
| 301 |
-
|
| 302 |
-
# predict
|
| 303 |
-
results = _predict_batch(texts, model_choice)
|
| 304 |
-
out_df = pd.DataFrame(results)
|
| 305 |
-
|
| 306 |
-
# basic charts
|
| 307 |
-
fig_bar, fig_pie, info_basic = make_basic_charts(out_df)
|
| 308 |
-
|
| 309 |
-
# time charts (optional)
|
| 310 |
-
show_time = False
|
| 311 |
-
fig_line = go.Figure()
|
| 312 |
-
if date_col_name and (date_col_name in cols):
|
| 313 |
-
dts = _to_datetime_safe(df_raw[date_col_name])
|
| 314 |
if dts.notna().any():
|
| 315 |
-
df_time =
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
df_time = df_time[df_time["__dt__"] >= start_ts]
|
| 325 |
-
if end_ts is not None:
|
| 326 |
-
df_time = df_time[df_time["__dt__"] <= end_ts]
|
| 327 |
-
|
| 328 |
-
if len(df_time) > 0:
|
| 329 |
-
fig_line = make_time_chart(df_time, "__dt__", freq_choice, use_ma)
|
| 330 |
-
show_time = True
|
| 331 |
-
|
| 332 |
-
# downloadable CSV
|
| 333 |
-
fd, out_path = tempfile.mkstemp(prefix="pred_", suffix=".csv")
|
| 334 |
-
os.close(fd)
|
| 335 |
-
out_df.to_csv(out_path, index=False, encoding="utf-8-sig")
|
| 336 |
-
|
| 337 |
-
info_time = ""
|
| 338 |
-
if date_col_name:
|
| 339 |
-
if show_time:
|
| 340 |
-
info_time = f"\n\nTime chart based on date column: **{date_col_name}**, Freq: **{freq_choice}**, MA: **{use_ma}**"
|
| 341 |
-
else:
|
| 342 |
-
info_time = "\n\n_Selected date range has no data OR unable to parse dates._"
|
| 343 |
-
else:
|
| 344 |
-
info_time = "\n\n_No date/timestamp column selected — time chart hidden._"
|
| 345 |
-
|
| 346 |
-
info_md = info_basic + info_time
|
| 347 |
-
return out_df, fig_bar, fig_pie, gr.update(visible=show_time, value=fig_line), info_md, out_path
|
| 348 |
-
|
| 349 |
-
except Exception:
|
| 350 |
-
tb = traceback.format_exc()
|
| 351 |
-
return pd.DataFrame(), go.Figure(), go.Figure(), gr.update(visible=False, value=go.Figure()), f"**Error**\n```\n{tb}\n```", None
|
| 352 |
|
| 353 |
# ================= Gradio UI =================
|
| 354 |
-
with gr.Blocks(title="Thai Sentiment
|
| 355 |
-
gr.Markdown("### Thai Sentiment
|
|
|
|
| 356 |
|
| 357 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 358 |
|
| 359 |
-
# ---- Batch (Textarea) ----
|
| 360 |
-
with gr.Tab("Batch (หลายข้อความ)"):
|
| 361 |
-
t2 = gr.Textbox(lines=8, label="พิมพ์หลายรีวิว (บรรทัดละ 1 รีวิว)")
|
| 362 |
-
btn_batch = gr.Button("Predict", variant="primary")
|
| 363 |
-
df2 = gr.Dataframe(label="ผลลัพธ์", interactive=False)
|
| 364 |
-
bar2 = gr.Plot(label="Label counts (bar)")
|
| 365 |
-
pie2 = gr.Plot(label="Positive vs Negative (pie)")
|
| 366 |
-
sum2 = gr.Markdown()
|
| 367 |
-
btn_batch.click(predict_many, [t2, model_radio], [df2, bar2, pie2, sum2])
|
| 368 |
-
|
| 369 |
-
# ---- CSV Upload ----
|
| 370 |
with gr.Tab("CSV Upload"):
|
| 371 |
with gr.Row():
|
| 372 |
-
file_in
|
| 373 |
-
|
| 374 |
-
date_dd = gr.Dropdown(label="คอลัมน์วันเวลา (ถ้ามี)", choices=[], value=None)
|
| 375 |
with gr.Row():
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
sum3 = gr.Markdown()
|
| 390 |
-
dl3 = gr.File(label="ดาวน์โหลดผลเป็น CSV", interactive=False)
|
| 391 |
-
|
| 392 |
-
file_in.change(
|
| 393 |
-
on_file_change, [file_in],
|
| 394 |
-
[text_dd, date_dd, date_from, date_to, freq, use_ma, line, note_detect]
|
| 395 |
-
)
|
| 396 |
-
|
| 397 |
-
btn_csv.click(
|
| 398 |
-
predict_csv,
|
| 399 |
-
[file_in, model_radio, text_dd, date_dd, date_from, date_to, freq, use_ma],
|
| 400 |
-
[df3, bar3, pie3, line, sum3, dl3]
|
| 401 |
-
)
|
| 402 |
-
|
| 403 |
-
if __name__ == "__main__":
|
| 404 |
-
demo.launch()
|
|
|
|
| 1 |
# app.py — Thai Sentiment (WangchanBERTa Variants)
|
| 2 |
+
# - Focus on POS/NEG only
|
| 3 |
+
# - Batch + CSV tabs
|
| 4 |
+
# - CSV: auto-detect text/date cols, hide date widgets if no date col
|
| 5 |
+
# - DatePicker fallback to Textbox if component missing
|
| 6 |
+
|
| 7 |
import os, json, importlib.util, traceback, re, math, tempfile, datetime
|
| 8 |
import gradio as gr
|
| 9 |
import torch, pandas as pd
|
|
|
|
| 22 |
if DEFAULT_MODEL not in AVAILABLE_CHOICES:
|
| 23 |
DEFAULT_MODEL = "WCB"
|
| 24 |
|
| 25 |
+
NEG_COLOR = "#F87171"
|
| 26 |
+
POS_COLOR = "#34D399"
|
| 27 |
TEMPLATE = "plotly_white"
|
|
|
|
| 28 |
CACHE = {}
|
| 29 |
|
| 30 |
+
# ================= Date Component Fallback =================
|
| 31 |
+
try:
|
| 32 |
+
DateInput = getattr(gr, "Date", None) or getattr(gr, "DatePicker", None)
|
| 33 |
+
except Exception:
|
| 34 |
+
DateInput = None
|
| 35 |
+
DATE_FALLBACK_TO_TEXT = False
|
| 36 |
+
if DateInput is None:
|
| 37 |
+
DateInput = gr.Textbox
|
| 38 |
+
DATE_FALLBACK_TO_TEXT = True
|
| 39 |
+
|
| 40 |
+
def _normalize_date_input(v):
|
| 41 |
+
if v is None: return None
|
| 42 |
+
if isinstance(v, float) and math.isnan(v): return None
|
| 43 |
+
if isinstance(v, datetime.date): return pd.Timestamp(v)
|
| 44 |
+
try:
|
| 45 |
+
ts = pd.to_datetime(v, errors="coerce")
|
| 46 |
+
return ts if pd.notna(ts) else None
|
| 47 |
+
except Exception:
|
| 48 |
+
return None
|
| 49 |
+
|
| 50 |
# ================= Loader =================
|
| 51 |
def _import_models():
|
| 52 |
if "models_module" in CACHE:
|
|
|
|
| 62 |
key = f"model:{model_name}"
|
| 63 |
if key in CACHE:
|
| 64 |
return CACHE[key]
|
|
|
|
| 65 |
cfg_path = hf_hub_download(REPO_ID, filename=f"{model_name}/config.json", token=HF_TOKEN)
|
| 66 |
w_path = hf_hub_download(REPO_ID, filename=f"{model_name}/model.safetensors", token=HF_TOKEN)
|
|
|
|
| 67 |
with open(cfg_path, "r", encoding="utf-8") as f:
|
| 68 |
cfg = json.load(f)
|
|
|
|
| 69 |
base_model = cfg.get("base_model", "airesearch/wangchanberta-base-att-spm-uncased")
|
| 70 |
arch_name = cfg.get("architecture", model_name)
|
|
|
|
| 71 |
tok = AutoTokenizer.from_pretrained(base_model)
|
| 72 |
models = _import_models()
|
| 73 |
model = models._build(arch_name, base_model, int(cfg.get("num_labels",2)),
|
| 74 |
cfg.get("pooling_after_lstm", "masked_mean"))
|
|
|
|
| 75 |
state = load_file(w_path)
|
| 76 |
model.load_state_dict(state, strict=False)
|
| 77 |
model.eval()
|
|
|
|
| 78 |
CACHE[key] = (model, tok, cfg)
|
| 79 |
return CACHE[key]
|
| 80 |
|
| 81 |
# ================= Utils =================
|
| 82 |
+
_INVALID_STRINGS = {"-", "--","—","n/a","na","null","none","nan",".","…",""}
|
| 83 |
_RE_HAS_LETTER = re.compile(r"[ก-๙A-Za-z]")
|
| 84 |
|
| 85 |
+
def _norm_text(v):
|
| 86 |
if v is None: return ""
|
| 87 |
if isinstance(v, float) and math.isnan(v): return ""
|
| 88 |
return str(v).strip().strip('"').strip("'").strip(",")
|
| 89 |
|
| 90 |
+
def _is_substantive_text(s, min_chars=2):
|
| 91 |
if not s: return False
|
| 92 |
if s.lower() in _INVALID_STRINGS: return False
|
| 93 |
if not _RE_HAS_LETTER.search(s): return False
|
| 94 |
+
if len(s.replace(" ","")) < min_chars: return False
|
| 95 |
return True
|
| 96 |
|
| 97 |
+
def _format_pct(x): return f"{x*100:.2f}%"
|
| 98 |
+
def _to_datetime_safe(s): return pd.to_datetime(s, errors="coerce", infer_datetime_format=True, utc=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
|
| 100 |
LIKELY_TEXT_COLS = ["text","review","message","comment","content","sentence","body","ข้อความ","รีวิว"]
|
| 101 |
LIKELY_DATE_COLS = ["date","created_at","time","timestamp","datetime","วันที่","วันเวลา","เวลา"]
|
| 102 |
|
| 103 |
+
def detect_text_and_date_cols(df):
|
| 104 |
cols = list(df.columns)
|
|
|
|
| 105 |
low = {c.lower(): c for c in cols}
|
| 106 |
text_col = None
|
| 107 |
for k in LIKELY_TEXT_COLS:
|
| 108 |
+
if k in low: text_col = low[k]; break
|
|
|
|
| 109 |
if text_col is None:
|
| 110 |
cand = [c for c in cols if df[c].dtype == object]
|
| 111 |
text_col = cand[0] if cand else cols[0]
|
|
|
|
|
|
|
| 112 |
date_candidates = []
|
| 113 |
for c in cols:
|
| 114 |
+
if c.lower() in LIKELY_DATE_COLS: date_candidates.append(c); continue
|
|
|
|
|
|
|
| 115 |
sample = df[c].head(50)
|
| 116 |
if _to_datetime_safe(sample).notna().sum() >= max(3, int(len(sample)*0.2)):
|
| 117 |
date_candidates.append(c)
|
|
|
|
| 120 |
return text_col, date_candidates, date_col
|
| 121 |
|
| 122 |
# ================= Charts =================
|
| 123 |
+
def make_basic_charts(df):
|
| 124 |
total = len(df)
|
| 125 |
+
neg_df = df[df["label"]=="negative"]; pos_df = df[df["label"]=="positive"]
|
|
|
|
|
|
|
|
|
|
| 126 |
fig_bar = go.Figure()
|
| 127 |
fig_bar.add_bar(name="negative", x=["negative"], y=[len(neg_df)], marker_color=NEG_COLOR)
|
| 128 |
fig_bar.add_bar(name="positive", x=["positive"], y=[len(pos_df)], marker_color=POS_COLOR)
|
| 129 |
fig_bar.update_layout(barmode="group", title="Label counts", template=TEMPLATE)
|
| 130 |
+
labels=["negative","positive"]; values=[len(neg_df), len(pos_df)]
|
| 131 |
+
fig_pie = go.Figure(go.Pie(labels=labels, values=values, hole=0.35,
|
|
|
|
|
|
|
|
|
|
| 132 |
marker=dict(colors=[NEG_COLOR, POS_COLOR])))
|
| 133 |
fig_pie.update_layout(title="Positive vs Negative", template=TEMPLATE)
|
|
|
|
| 134 |
neg_avg = pd.to_numeric(df["negative(%)"].str.rstrip("%"), errors="coerce").mean()
|
| 135 |
pos_avg = pd.to_numeric(df["positive(%)"].str.rstrip("%"), errors="coerce").mean()
|
| 136 |
+
info=(f"**Summary**\n- Total: {total}\n- Negative: {len(neg_df)}\n- Positive: {len(pos_df)}\n"
|
| 137 |
+
f"- Avg negative: {neg_avg:.2f}%\n- Avg positive: {pos_avg:.2f}%")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
return fig_bar, fig_pie, info
|
| 139 |
|
| 140 |
def _resample_counts(df, date_col, freq):
|
| 141 |
+
g = df.groupby([pd.Grouper(key=date_col, freq=freq),"label"]).size().unstack(fill_value=0)
|
| 142 |
+
for c in ["negative","positive"]:
|
| 143 |
+
if c not in g.columns: g[c]=0
|
|
|
|
| 144 |
return g[["negative","positive"]].sort_index()
|
| 145 |
|
| 146 |
+
def _rolling_window(freq): return 7 if freq=="D" else (4 if freq=="W" else 3)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
|
| 148 |
+
def make_time_chart(df, date_col, freq, use_ma):
|
| 149 |
+
ts=_resample_counts(df,date_col,freq)
|
| 150 |
+
if use_ma: ts=ts.rolling(_rolling_window(freq), min_periods=1).mean()
|
| 151 |
+
fig=go.Figure()
|
| 152 |
+
fig.add_scatter(x=ts.index,y=ts["negative"],mode="lines",name="negative",line=dict(color=NEG_COLOR))
|
| 153 |
+
fig.add_scatter(x=ts.index,y=ts["positive"],mode="lines",name="positive",line=dict(color=POS_COLOR))
|
| 154 |
+
fig.update_layout(title="Reviews over time (POS/NEG)",template=TEMPLATE,
|
| 155 |
+
xaxis_title="Date",yaxis_title="Count")
|
| 156 |
+
return fig
|
| 157 |
|
| 158 |
# ================= Core Predict =================
|
| 159 |
def _predict_batch(texts, model_name, batch_size=32):
|
| 160 |
+
model,tok,cfg=load_model(model_name); results=[]
|
| 161 |
+
for i in range(0,len(texts),batch_size):
|
| 162 |
+
chunk=texts[i:i+batch_size]
|
| 163 |
+
enc=tok(chunk,padding=True,truncation=True,
|
| 164 |
+
max_length=cfg.get("max_length",128),return_tensors="pt")
|
|
|
|
| 165 |
with torch.no_grad():
|
| 166 |
+
logits=model(enc["input_ids"],enc["attention_mask"])
|
| 167 |
+
probs=F.softmax(logits,dim=1).cpu().numpy()
|
| 168 |
+
for txt,p in zip(chunk,probs):
|
| 169 |
+
neg,pos=float(p[0]),float(p[1])
|
| 170 |
+
label="positive" if pos>=neg else "negative"
|
| 171 |
+
results.append({"review":txt,"negative(%)":_format_pct(neg),
|
| 172 |
+
"positive(%)":_format_pct(pos),"label":label})
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
return results
|
| 174 |
|
| 175 |
+
# ================= Batch =================
|
| 176 |
+
def predict_many(text_block, model_choice):
|
| 177 |
try:
|
| 178 |
+
raw=(text_block or "").splitlines()
|
| 179 |
+
norm=[_norm_text(t) for t in raw]; clean=[t for t in norm if _is_substantive_text(t)]
|
| 180 |
+
if not clean: return pd.DataFrame(),go.Figure(),go.Figure(),"No valid text"
|
| 181 |
+
results=_predict_batch(clean,model_choice); df=pd.DataFrame(results)
|
| 182 |
+
bar,pie,info=make_basic_charts(df)
|
| 183 |
+
return df,bar,pie,info
|
| 184 |
+
except: return pd.DataFrame(),go.Figure(),go.Figure(),traceback.format_exc()
|
| 185 |
+
|
| 186 |
+
# ================= CSV Inspect =================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
def on_file_change(file_obj):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
if file_obj is None:
|
| 189 |
+
return gr.update(choices=[],value=None),gr.update(choices=[],value=None),\
|
| 190 |
+
gr.update(visible=False),gr.update(visible=False),\
|
| 191 |
+
gr.update(visible=False),gr.update(visible=False),\
|
| 192 |
+
gr.update(visible=False),"Please upload a CSV"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
try:
|
| 194 |
+
df=pd.read_csv(file_obj.name)
|
| 195 |
+
text_col,date_candidates,date_col=detect_text_and_date_cols(df)
|
| 196 |
+
has_date=date_col is not None
|
| 197 |
+
note=f"Detected text col: **{text_col}**; "+("date col: **{}**".format(date_col) if has_date else "_no date col_")
|
| 198 |
+
return gr.update(choices=list(df.columns),value=text_col),\
|
| 199 |
+
gr.update(choices=date_candidates,value=date_col),\
|
| 200 |
+
gr.update(visible=has_date),gr.update(visible=has_date),\
|
| 201 |
+
gr.update(visible=has_date),gr.update(visible=has_date),\
|
| 202 |
+
gr.update(visible=has_date),note
|
| 203 |
+
except: return gr.update(choices=[],value=None),gr.update(choices=[],value=None),\
|
| 204 |
+
gr.update(visible=False),gr.update(visible=False),\
|
| 205 |
+
gr.update(visible=False),gr.update(visible=False),\
|
| 206 |
+
gr.update(visible=False),"Error reading CSV"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
|
| 208 |
# ================= CSV Predict =================
|
| 209 |
+
def predict_csv(file_obj,model_choice,text_col,date_col,date_from,date_to,freq,use_ma):
|
| 210 |
+
if file_obj is None: return pd.DataFrame(),go.Figure(),go.Figure(),gr.update(visible=False), "No file",None
|
|
|
|
|
|
|
| 211 |
try:
|
| 212 |
+
df_raw=pd.read_csv(file_obj.name); cols=list(df_raw.columns)
|
| 213 |
+
if text_col not in cols: text_col,_d,_dc=detect_text_and_date_cols(df_raw);
|
| 214 |
+
texts=[_norm_text(v) for v in df_raw[text_col].tolist()]
|
| 215 |
+
texts=[t for t in texts if _is_substantive_text(t)]
|
| 216 |
+
if not texts: return pd.DataFrame(),go.Figure(),go.Figure(),gr.update(visible=False),"No valid texts",None
|
| 217 |
+
results=_predict_batch(texts,model_choice); out=pd.DataFrame(results)
|
| 218 |
+
bar,pie,info=make_basic_charts(out)
|
| 219 |
+
fig_line=go.Figure(); show_time=False
|
| 220 |
+
if date_col and date_col in cols:
|
| 221 |
+
dts=_to_datetime_safe(df_raw[date_col])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
if dts.notna().any():
|
| 223 |
+
df_time=out.copy(); df_time["__dt__"]=dts; df_time=df_time.dropna(subset=["__dt__"])
|
| 224 |
+
start_ts=_normalize_date_input(date_from); end_ts=_normalize_date_input(date_to)
|
| 225 |
+
if start_ts is not None: df_time=df_time[df_time["__dt__"]>=start_ts]
|
| 226 |
+
if end_ts is not None: df_time=df_time[df_time["__dt__"]<=end_ts]
|
| 227 |
+
if len(df_time)>0: fig_line=make_time_chart(df_time,"__dt__",freq,use_ma); show_time=True
|
| 228 |
+
fd,path=tempfile.mkstemp(suffix=".csv"); os.close(fd)
|
| 229 |
+
out.to_csv(path,index=False,encoding="utf-8-sig")
|
| 230 |
+
return out,bar,pie,gr.update(visible=show_time,value=fig_line),info,path
|
| 231 |
+
except: return pd.DataFrame(),go.Figure(),go.Figure(),gr.update(visible=False),"Error\n"+traceback.format_exc(),None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
|
| 233 |
# ================= Gradio UI =================
|
| 234 |
+
with gr.Blocks(title="Thai Sentiment") as demo:
|
| 235 |
+
gr.Markdown("### Thai Sentiment — WangchanBERTa Variants")
|
| 236 |
+
model_radio=gr.Radio(choices=AVAILABLE_CHOICES,value=DEFAULT_MODEL,label="เลือกโมเดล")
|
| 237 |
|
| 238 |
+
with gr.Tab("Batch"):
|
| 239 |
+
t2=gr.Textbox(lines=8,label="รีวิว (บรรทัดละ 1)")
|
| 240 |
+
btn2=gr.Button("Predict",variant="primary")
|
| 241 |
+
df2=gr.Dataframe(); bar2=gr.Plot(); pie2=gr.Plot(); sum2=gr.Markdown()
|
| 242 |
+
btn2.click(predict_many,[t2,model_radio],[df2,bar2,pie2,sum2])
|
| 243 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
with gr.Tab("CSV Upload"):
|
| 245 |
with gr.Row():
|
| 246 |
+
file_in=gr.File(file_types=[".csv"]); text_dd=gr.Dropdown(label="Text col")
|
| 247 |
+
date_dd=gr.Dropdown(label="Date col (opt)")
|
|
|
|
| 248 |
with gr.Row():
|
| 249 |
+
date_from=DateInput(label="เริ่มวันที่"+(" (YYYY-MM-DD)" if DATE_FALLBACK_TO_TEXT else ""),visible=False)
|
| 250 |
+
date_to=DateInput(label="ถึงวันที่"+(" (YYYY-MM-DD)" if DATE_FALLBACK_TO_TEXT else ""),visible=False)
|
| 251 |
+
freq=gr.Radio(choices=["D","W","M"],value="D",label="Freq",visible=False)
|
| 252 |
+
use_ma=gr.Checkbox(value=True,label="MA",visible=False)
|
| 253 |
+
btn3=gr.Button("Predict CSV",variant="primary")
|
| 254 |
+
note=gr.Markdown()
|
| 255 |
+
df3=gr.Dataframe(); bar3=gr.Plot(); pie3=gr.Plot()
|
| 256 |
+
line=gr.Plot(visible=False); sum3=gr.Markdown(); dl=gr.File()
|
| 257 |
+
|
| 258 |
+
file_in.change(on_file_change,[file_in],[text_dd,date_dd,date_from,date_to,freq,use_ma,line,note])
|
| 259 |
+
btn3.click(predict_csv,[file_in,model_radio,text_dd,date_dd,date_from,date_to,freq,use_ma],[df3,bar3,pie3,line,sum3,dl])
|
| 260 |
+
|
| 261 |
+
if __name__=="__main__": demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|