Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
| 1 |
-
# app.py — Thai Sentiment (WangchanBERTa Variants) GUI
|
| 2 |
-
|
|
|
|
|
|
|
| 3 |
import gradio as gr
|
| 4 |
import torch, pandas as pd
|
| 5 |
import torch.nn.functional as F
|
|
@@ -9,7 +11,7 @@ from safetensors.torch import load_file
|
|
| 9 |
from transformers import AutoTokenizer
|
| 10 |
|
| 11 |
# ================= Settings =================
|
| 12 |
-
REPO_ID = os.getenv("REPO_ID", "Dusit-P/thai-sentiment")
|
| 13 |
DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "WCB")
|
| 14 |
HF_TOKEN = os.getenv("HF_TOKEN", None)
|
| 15 |
|
|
@@ -17,15 +19,13 @@ AVAILABLE_CHOICES = ["WCB", "WCB_BiLSTM", "WCB_CNN_BiLSTM", "WCB_4Layer_BiLSTM"]
|
|
| 17 |
if DEFAULT_MODEL not in AVAILABLE_CHOICES:
|
| 18 |
DEFAULT_MODEL = "WCB"
|
| 19 |
|
| 20 |
-
#
|
| 21 |
-
|
| 22 |
-
POS_COLOR = "#34D399"
|
| 23 |
TEMPLATE = "plotly_white"
|
| 24 |
|
| 25 |
CACHE = {}
|
| 26 |
|
| 27 |
-
#
|
| 28 |
-
# ปรับแก้/เพิ่มคีย์เวิร์ดได้ตามโดเมนงานของคุณ
|
| 29 |
ASPECT_KEYWORDS = {
|
| 30 |
"บริการ": [
|
| 31 |
"บริการ", "เซอร์วิส", "service", "ดูแล", "เอาใจใส่", "ใส่ใจ",
|
|
@@ -162,7 +162,6 @@ _RE_HAS_LETTER = re.compile(r"[ก-๙A-Za-z]")
|
|
| 162 |
def _norm_text(v) -> str:
    """Normalize a raw cell/line value into a clean review string.

    ``None`` and float NaN become an empty string. Any other value is
    converted with ``str`` and has surrounding whitespace, double quotes,
    single quotes and commas removed (common residue from typing / CSV).

    Args:
        v: Any value (string, number, ``None``, NaN, ...).

    Returns:
        The cleaned text, possibly empty.
    """
    if v is None:
        return ""
    if isinstance(v, float) and math.isnan(v):
        return ""
    text = str(v)
    # Peel wrappers until nothing changes: a single pass of
    # strip('"').strip("'").strip(",") leaves residue on inputs such as
    # '"hello",' (which would come out as 'hello"').
    while True:
        peeled = text.strip().strip("\"',")
        if peeled == text:
            return text
        text = peeled
|
| 167 |
|
| 168 |
def _is_substantive_text(s: str, min_chars: int = 2) -> bool:
|
|
@@ -183,7 +182,6 @@ def _compile_aspect_regex():
|
|
| 183 |
return _aspect_regex_cache
|
| 184 |
cache = {}
|
| 185 |
for aspect, kws in ASPECT_KEYWORDS.items():
|
| 186 |
-
# สร้าง regex OR จากคีย์เวิร์ด (หนีอักขระพิเศษ)
|
| 187 |
parts = [re.escape(k) for k in kws if k.strip()]
|
| 188 |
if not parts:
|
| 189 |
continue
|
|
@@ -193,10 +191,6 @@ def _compile_aspect_regex():
|
|
| 193 |
return cache
|
| 194 |
|
| 195 |
def detect_aspect_for_negative(text: str) -> str:
|
| 196 |
-
"""
|
| 197 |
-
คืนชื่อหมวดแรกที่พบจากคีย์เวิร์ด ถ้าไม่พบเลย → "อื่นๆ"
|
| 198 |
-
(เราจงใจให้เป็น single-label ต่อ 1 ข้อความเพื่ออ่าน pie ง่าย)
|
| 199 |
-
"""
|
| 200 |
text = text or ""
|
| 201 |
regs = _compile_aspect_regex()
|
| 202 |
for aspect, rx in regs.items():
|
|
@@ -204,39 +198,40 @@ def detect_aspect_for_negative(text: str) -> str:
|
|
| 204 |
return aspect
|
| 205 |
return ASPECT_OTHER_LABEL
|
| 206 |
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
df: columns = review, negative(%), positive(%), label
|
| 210 |
-
return: (bar_counts, pie_negative_aspects, info_md)
|
| 211 |
-
"""
|
| 212 |
total = len(df)
|
| 213 |
neg_df = df[df["label"] == "negative"].copy()
|
| 214 |
pos_df = df[df["label"] == "positive"].copy()
|
| 215 |
|
| 216 |
-
#
|
| 217 |
fig_bar = go.Figure()
|
| 218 |
fig_bar.add_bar(name="negative", x=["negative"], y=[len(neg_df)], marker_color=NEG_COLOR)
|
| 219 |
fig_bar.add_bar(name="positive", x=["positive"], y=[len(pos_df)], marker_color=POS_COLOR)
|
| 220 |
fig_bar.update_layout(barmode="group", title="Label counts", template=TEMPLATE)
|
| 221 |
|
| 222 |
-
#
|
| 223 |
if len(neg_df) > 0:
|
| 224 |
neg_df["aspect"] = neg_df["review"].apply(detect_aspect_for_negative)
|
| 225 |
counts = neg_df["aspect"].value_counts()
|
| 226 |
labels = list(counts.index)
|
| 227 |
values = list(counts.values)
|
| 228 |
else:
|
| 229 |
-
labels, values = [ASPECT_OTHER_LABEL], [1] # placeholder
|
| 230 |
-
|
| 231 |
-
fig_pie = go.Figure(go.Pie(
|
| 232 |
-
labels=labels,
|
| 233 |
-
values=values,
|
| 234 |
-
hole=0.35,
|
| 235 |
-
sort=False
|
| 236 |
-
))
|
| 237 |
fig_pie.update_layout(title="Negative aspect distribution", template=TEMPLATE)
|
| 238 |
|
| 239 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
neg_avg = pd.to_numeric(df["negative(%)"].str.rstrip("%"), errors="coerce").mean()
|
| 241 |
pos_avg = pd.to_numeric(df["positive(%)"].str.rstrip("%"), errors="coerce").mean()
|
| 242 |
info = (
|
|
@@ -245,10 +240,62 @@ def make_charts_with_negative_aspects(df: pd.DataFrame):
|
|
| 245 |
f"- Negative: {len(neg_df)} \n"
|
| 246 |
f"- Positive: {len(pos_df)} \n"
|
| 247 |
f"- Avg negative: {neg_avg:.2f}% \n"
|
| 248 |
-
f"- Avg positive: {pos_avg:.2f}
|
| 249 |
-
f"- *Pie
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
)
|
| 251 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 252 |
|
| 253 |
# ================= Core Predict =================
|
| 254 |
def _predict_batch(texts, model_name, batch_size=32):
|
|
@@ -272,22 +319,7 @@ def _predict_batch(texts, model_name, batch_size=32):
|
|
| 272 |
})
|
| 273 |
return results
|
| 274 |
|
| 275 |
-
# -----
|
| 276 |
-
def predict_one(text: str, model_choice: str):
    """Classify a single review.

    Returns a ``(probabilities, label)`` pair. Text that is not substantive
    yields zero probabilities with the label ``"invalid"``; any exception is
    reported as ``({"error": message}, "error")`` instead of propagating.
    """
    try:
        cleaned = _norm_text(text)
        if _is_substantive_text(cleaned):
            row = _predict_batch([cleaned], model_choice)[0]
            # The batch result stores percentages as strings ending in "%";
            # convert them back to 0..1 fractions.
            fractions = {}
            for name in ("negative", "positive"):
                fractions[name] = float(row[f"{name}(%)"].rstrip("%")) / 100.0
            return fractions, row["label"]
        return {"negative": 0.0, "positive": 0.0}, "invalid"
    except Exception as exc:
        return {"error": str(exc)}, "error"
|
| 289 |
-
|
| 290 |
-
# ----- textarea batch -----
|
| 291 |
def predict_many(text_block: str, model_choice: str):
|
| 292 |
try:
|
| 293 |
raw_lines = (text_block or "").splitlines()
|
|
@@ -296,88 +328,158 @@ def predict_many(text_block: str, model_choice: str):
|
|
| 296 |
skipped = len(all_norm) - len(cleaned)
|
| 297 |
if len(cleaned) == 0:
|
| 298 |
empty = pd.DataFrame(columns=["review","negative(%)","positive(%)","label"])
|
| 299 |
-
return empty, go.Figure(), go.Figure(), "No valid text"
|
| 300 |
results = _predict_batch(cleaned, model_choice)
|
| 301 |
df = pd.DataFrame(results)
|
| 302 |
-
fig_bar, fig_pie, info_md =
|
| 303 |
info_md = f"{info_md} \n- Skipped: {skipped}"
|
| 304 |
-
return df, fig_bar, fig_pie, info_md
|
| 305 |
except Exception:
|
| 306 |
tb = traceback.format_exc()
|
| 307 |
empty = pd.DataFrame(columns=["review","negative(%)","positive(%)","label"])
|
| 308 |
-
return empty, go.Figure(), go.Figure(), f"**Error**\n```\n{tb}\n```"
|
| 309 |
|
| 310 |
# ----- CSV upload -----
|
| 311 |
LIKELY_TEXT_COLS = ["text","review","message","comment","content","sentence","body","ข้อความ","รีวิว"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 312 |
|
| 313 |
-
def predict_csv(file_obj, model_choice: str, text_col_name: str):
|
| 314 |
try:
|
| 315 |
if file_obj is None:
|
| 316 |
-
return pd.DataFrame(), go.Figure(), go.Figure(), "Please upload a CSV.", None
|
| 317 |
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
#
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
low = {c.lower(): c for c in cols}
|
|
|
|
| 325 |
for k in LIKELY_TEXT_COLS:
|
| 326 |
if k in low:
|
| 327 |
found = low[k]; break
|
| 328 |
if found is None:
|
| 329 |
-
cand = [c for c in cols if
|
| 330 |
found = cand[0] if cand else cols[0]
|
| 331 |
-
|
| 332 |
|
| 333 |
-
texts = [_norm_text(v) for v in
|
| 334 |
texts = [t for t in texts if _is_substantive_text(t)]
|
| 335 |
if len(texts) == 0:
|
| 336 |
-
return pd.DataFrame(), go.Figure(), go.Figure(),
|
|
|
|
| 337 |
|
|
|
|
| 338 |
results = _predict_batch(texts, model_choice)
|
| 339 |
out_df = pd.DataFrame(results)
|
| 340 |
-
fig_bar, fig_pie, info_md = make_charts_with_negative_aspects(out_df)
|
| 341 |
|
| 342 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 343 |
fd, out_path = tempfile.mkstemp(prefix="pred_", suffix=".csv")
|
| 344 |
os.close(fd)
|
| 345 |
-
out_df.to_csv(out_path, index=False, encoding="utf-8-sig")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
|
| 347 |
-
info_md = f"{info_md} \n- Column used: **{col}**"
|
| 348 |
-
return out_df, fig_bar, fig_pie, info_md, out_path
|
| 349 |
except Exception:
|
| 350 |
tb = traceback.format_exc()
|
| 351 |
-
return pd.DataFrame(), go.Figure(), go.Figure(), f"**Error**\n```\n{tb}\n```", None
|
| 352 |
|
| 353 |
# ================= Gradio UI =================
|
| 354 |
with gr.Blocks(title="Thai Sentiment (WangchanBERTa Variants)") as demo:
|
| 355 |
gr.Markdown("### Thai Sentiment (WangchanBERTa Variants)")
|
| 356 |
model_radio = gr.Radio(choices=AVAILABLE_CHOICES, value=DEFAULT_MODEL, label="เลือกโมเดล")
|
| 357 |
|
| 358 |
-
|
| 359 |
-
t1 = gr.Textbox(lines=3, label="ข้อความรีวิว (1 ข้อความ)")
|
| 360 |
-
probs = gr.Label(label="Probabilities")
|
| 361 |
-
pred = gr.Textbox(label="Prediction", interactive=False)
|
| 362 |
-
gr.Button("Predict").click(predict_one, [t1, model_radio], [probs, pred])
|
| 363 |
-
|
| 364 |
with gr.Tab("Batch (หลายข้อความ)"):
|
| 365 |
t2 = gr.Textbox(lines=8, label="พิมพ์หลายรีวิว (บรรทัดละ 1 รีวิว)")
|
| 366 |
df2 = gr.Dataframe(label="ผลลัพธ์", interactive=False)
|
| 367 |
bar2 = gr.Plot(label="Label counts (bar)")
|
| 368 |
-
pie2 = gr.Plot(label="Negative aspects (pie)")
|
|
|
|
| 369 |
sum2 = gr.Markdown()
|
| 370 |
-
gr.Button("Run Batch").click(
|
|
|
|
|
|
|
| 371 |
|
|
|
|
| 372 |
with gr.Tab("CSV Upload"):
|
| 373 |
-
|
| 374 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 375 |
df3 = gr.Dataframe(label="ผลลัพธ์", interactive=False)
|
| 376 |
bar3 = gr.Plot(label="Label counts (bar)")
|
| 377 |
-
pie3 = gr.Plot(label="Negative aspects (pie)")
|
|
|
|
|
|
|
| 378 |
sum3 = gr.Markdown()
|
| 379 |
dl3 = gr.File(label="ดาวน์โหลดผลเป็น CSV", interactive=False)
|
| 380 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 381 |
|
| 382 |
if __name__ == "__main__":
|
| 383 |
demo.launch()
|
|
|
|
| 1 |
+
# app.py — Thai Sentiment (WangchanBERTa Variants) GUI
|
| 2 |
+
# No Single tab. Tabs: Batch (textarea) + CSV Upload
|
| 3 |
+
# Extended B: Negative-aspect pie + Top-N bar + Time series (POS/NEG & aspects)
|
| 4 |
+
import os, json, importlib.util, traceback, re, math, tempfile
|
| 5 |
import gradio as gr
|
| 6 |
import torch, pandas as pd
|
| 7 |
import torch.nn.functional as F
|
|
|
|
| 11 |
from transformers import AutoTokenizer
|
| 12 |
|
| 13 |
# ================= Settings =================
|
| 14 |
+
REPO_ID = os.getenv("REPO_ID", "Dusit-P/thai-sentiment")
|
| 15 |
DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "WCB")
|
| 16 |
HF_TOKEN = os.getenv("HF_TOKEN", None)
|
| 17 |
|
|
|
|
| 19 |
if DEFAULT_MODEL not in AVAILABLE_CHOICES:
|
| 20 |
DEFAULT_MODEL = "WCB"
|
| 21 |
|
| 22 |
+
NEG_COLOR = "#F87171" # red
|
| 23 |
+
POS_COLOR = "#34D399" # green
|
|
|
|
| 24 |
TEMPLATE = "plotly_white"
|
| 25 |
|
| 26 |
CACHE = {}
|
| 27 |
|
| 28 |
+
# ================ Aspect Keywords ================
|
|
|
|
| 29 |
ASPECT_KEYWORDS = {
|
| 30 |
"บริการ": [
|
| 31 |
"บริการ", "เซอร์วิส", "service", "ดูแล", "เอาใจใส่", "ใส่ใจ",
|
|
|
|
| 162 |
def _norm_text(v) -> str:
    """Normalize a raw cell/line value into a clean review string.

    ``None`` and float NaN become an empty string. Any other value is
    converted with ``str`` and has surrounding whitespace, double quotes,
    single quotes and commas removed (common residue from typing / CSV).

    Args:
        v: Any value (string, number, ``None``, NaN, ...).

    Returns:
        The cleaned text, possibly empty.
    """
    if v is None:
        return ""
    if isinstance(v, float) and math.isnan(v):
        return ""
    text = str(v)
    # Peel wrappers until nothing changes: a single pass of
    # strip('"').strip("'").strip(",") leaves residue on inputs such as
    # '"hello",' (which would come out as 'hello"').
    while True:
        peeled = text.strip().strip("\"',")
        if peeled == text:
            return text
        text = peeled
|
| 166 |
|
| 167 |
def _is_substantive_text(s: str, min_chars: int = 2) -> bool:
|
|
|
|
| 182 |
return _aspect_regex_cache
|
| 183 |
cache = {}
|
| 184 |
for aspect, kws in ASPECT_KEYWORDS.items():
|
|
|
|
| 185 |
parts = [re.escape(k) for k in kws if k.strip()]
|
| 186 |
if not parts:
|
| 187 |
continue
|
|
|
|
| 191 |
return cache
|
| 192 |
|
| 193 |
def detect_aspect_for_negative(text: str) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
text = text or ""
|
| 195 |
regs = _compile_aspect_regex()
|
| 196 |
for aspect, rx in regs.items():
|
|
|
|
| 198 |
return aspect
|
| 199 |
return ASPECT_OTHER_LABEL
|
| 200 |
|
| 201 |
+
# ================= Charts (no time) =================
|
| 202 |
+
def make_basic_charts(df: pd.DataFrame):
|
|
|
|
|
|
|
|
|
|
| 203 |
total = len(df)
|
| 204 |
neg_df = df[df["label"] == "negative"].copy()
|
| 205 |
pos_df = df[df["label"] == "positive"].copy()
|
| 206 |
|
| 207 |
+
# bar counts
|
| 208 |
fig_bar = go.Figure()
|
| 209 |
fig_bar.add_bar(name="negative", x=["negative"], y=[len(neg_df)], marker_color=NEG_COLOR)
|
| 210 |
fig_bar.add_bar(name="positive", x=["positive"], y=[len(pos_df)], marker_color=POS_COLOR)
|
| 211 |
fig_bar.update_layout(barmode="group", title="Label counts", template=TEMPLATE)
|
| 212 |
|
| 213 |
+
# pie negative aspects
|
| 214 |
if len(neg_df) > 0:
|
| 215 |
neg_df["aspect"] = neg_df["review"].apply(detect_aspect_for_negative)
|
| 216 |
counts = neg_df["aspect"].value_counts()
|
| 217 |
labels = list(counts.index)
|
| 218 |
values = list(counts.values)
|
| 219 |
else:
|
| 220 |
+
labels, values = [ASPECT_OTHER_LABEL], [1] # placeholder
|
| 221 |
+
|
| 222 |
+
fig_pie = go.Figure(go.Pie(labels=labels, values=values, hole=0.35, sort=False))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
fig_pie.update_layout(title="Negative aspect distribution", template=TEMPLATE)
|
| 224 |
|
| 225 |
+
# top-N aspects bar (negative only, overall)
|
| 226 |
+
fig_topn = go.Figure()
|
| 227 |
+
if len(neg_df) > 0:
|
| 228 |
+
top_counts = neg_df["aspect"].value_counts().sort_values(ascending=True)
|
| 229 |
+
fig_topn.add_bar(y=top_counts.index.tolist(), x=top_counts.values.tolist(),
|
| 230 |
+
orientation="h", marker_color=NEG_COLOR)
|
| 231 |
+
fig_topn.update_layout(title="Top negative aspects (overall)", template=TEMPLATE)
|
| 232 |
+
else:
|
| 233 |
+
fig_topn.update_layout(title="Top negative aspects (no negative rows)", template=TEMPLATE)
|
| 234 |
+
|
| 235 |
neg_avg = pd.to_numeric(df["negative(%)"].str.rstrip("%"), errors="coerce").mean()
|
| 236 |
pos_avg = pd.to_numeric(df["positive(%)"].str.rstrip("%"), errors="coerce").mean()
|
| 237 |
info = (
|
|
|
|
| 240 |
f"- Negative: {len(neg_df)} \n"
|
| 241 |
f"- Positive: {len(pos_df)} \n"
|
| 242 |
f"- Avg negative: {neg_avg:.2f}% \n"
|
| 243 |
+
f"- Avg positive: {pos_avg:.2f}%\n"
|
| 244 |
+
f"- *Pie & Top-N = เฉพาะข้อความเชิงลบ*"
|
| 245 |
+
)
|
| 246 |
+
return fig_bar, fig_pie, fig_topn, info
|
| 247 |
+
|
| 248 |
+
# ================= Charts (with time) =================
|
| 249 |
+
def _to_datetime_safe(s):
    """Parse *s* into datetimes, turning unparseable entries into NaT.

    Args:
        s: Anything ``pandas.to_datetime`` accepts (Series, list, scalar).

    Returns:
        The parsed result, with NaT wherever parsing failed.
    """
    # `infer_datetime_format=True` is intentionally not passed: it is
    # deprecated in pandas 2.x (format inference is the default there) and
    # only produces a warning on every call.
    return pd.to_datetime(s, errors="coerce", utc=False)
|
| 251 |
+
|
| 252 |
+
def _resample_counts(df, date_col, freq):
    """Count rows per time bucket and label.

    Groups *df* by ``date_col`` (bucketed at *freq*) and by the ``label``
    column, and returns a frame indexed by bucket with exactly two columns,
    ``negative`` and ``positive`` (zero-filled when a label never occurs),
    sorted by time.
    """
    wanted = ["negative", "positive"]
    bucket = pd.Grouper(key=date_col, freq=freq)
    table = df.groupby([bucket, "label"]).size().unstack(fill_value=0)
    # Guarantee both sentiment columns exist even if one label is absent.
    absent = [name for name in wanted if name not in table.columns]
    for name in absent:
        table[name] = 0
    return table[wanted].sort_index()
|
| 258 |
+
|
| 259 |
+
def _rolling_window(freq):
    """Return the moving-average window length for a frequency code.

    ``"D"`` -> 7 periods, ``"W"`` -> 4 periods, anything else -> 3 periods.
    """
    if freq == "D":
        return 7
    if freq == "W":
        return 4
    return 3
|
| 261 |
+
|
| 262 |
+
def make_time_charts(df: pd.DataFrame, date_col: str, freq: str, use_ma: bool, topn_aspects: int):
    """Build one figure showing sentiment over time.

    The figure always carries two line traces (negative / positive row
    counts per *freq* bucket of ``date_col``). When *df* contains negative
    rows, stacked traces for the ``topn_aspects`` most frequent negative
    aspects are added on top; every other aspect is folded into
    ``ASPECT_OTHER_LABEL``. With ``use_ma`` true, both the lines and the
    stacked traces are smoothed with a rolling mean.

    Returns:
        The plotly figure.
    """
    def _smooth(frame):
        # Rolling mean with the window chosen for this frequency; no-op
        # when moving average is disabled.
        if not use_ma:
            return frame
        return frame.rolling(_rolling_window(freq), min_periods=1).mean()

    # 1) positive / negative counts over time
    label_counts = _smooth(_resample_counts(df, date_col, freq))
    figure = go.Figure()
    for label, colour in (("negative", NEG_COLOR), ("positive", POS_COLOR)):
        figure.add_scatter(
            x=label_counts.index,
            y=label_counts[label],
            mode="lines",
            name=label,
            line=dict(color=colour),
        )
    figure.update_layout(
        title="Reviews over time",
        template=TEMPLATE,
        xaxis_title="Date",
        yaxis_title="Count",
    )

    # 2) stacked area of the top-N negative aspects over time
    negatives = df[df["label"] == "negative"].copy()
    if len(negatives) == 0:
        # No negative rows, so there is no area to draw.
        return figure

    negatives["aspect"] = negatives["review"].apply(detect_aspect_for_negative)
    keep_n = max(1, int(topn_aspects))
    leading = negatives["aspect"].value_counts().head(keep_n).index.tolist()
    is_leading = negatives["aspect"].isin(leading)
    negatives["aspect"] = negatives["aspect"].where(is_leading, ASPECT_OTHER_LABEL)

    per_bucket = negatives.groupby([pd.Grouper(key=date_col, freq=freq), "aspect"]).size()
    stacked = _smooth(per_bucket.unstack(fill_value=0).sort_index())

    # Merge the area traces into the same figure as the lines.
    for aspect_name in stacked.columns:
        figure.add_scatter(
            x=stacked.index,
            y=stacked[aspect_name],
            stackgroup="one",
            mode="lines",
            name=f"neg:{aspect_name}",
            opacity=0.4,
        )
    return figure
|
| 299 |
|
| 300 |
# ================= Core Predict =================
|
| 301 |
def _predict_batch(texts, model_name, batch_size=32):
|
|
|
|
| 319 |
})
|
| 320 |
return results
|
| 321 |
|
| 322 |
+
# ----- Batch (textarea) -----
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 323 |
def predict_many(text_block: str, model_choice: str):
|
| 324 |
try:
|
| 325 |
raw_lines = (text_block or "").splitlines()
|
|
|
|
| 328 |
skipped = len(all_norm) - len(cleaned)
|
| 329 |
if len(cleaned) == 0:
|
| 330 |
empty = pd.DataFrame(columns=["review","negative(%)","positive(%)","label"])
|
| 331 |
+
return empty, go.Figure(), go.Figure(), go.Figure(), "No valid text"
|
| 332 |
results = _predict_batch(cleaned, model_choice)
|
| 333 |
df = pd.DataFrame(results)
|
| 334 |
+
fig_bar, fig_pie, fig_topn, info_md = make_basic_charts(df)
|
| 335 |
info_md = f"{info_md} \n- Skipped: {skipped}"
|
| 336 |
+
return df, fig_bar, fig_pie, fig_topn, info_md
|
| 337 |
except Exception:
|
| 338 |
tb = traceback.format_exc()
|
| 339 |
empty = pd.DataFrame(columns=["review","negative(%)","positive(%)","label"])
|
| 340 |
+
return empty, go.Figure(), go.Figure(), go.Figure(), f"**Error**\n```\n{tb}\n```"
|
| 341 |
|
| 342 |
# ----- CSV upload -----
|
| 343 |
LIKELY_TEXT_COLS = ["text","review","message","comment","content","sentence","body","ข้อความ","รีวิว"]
|
| 344 |
+
LIKELY_DATE_COLS = ["date","created_at","time","timestamp","datetime","วันที่","วันเวลา","เวลา"]
|
| 345 |
+
|
| 346 |
+
def predict_csv(file_obj, model_choice: str, text_col_name: str,
                date_col_name: str, date_from: str, date_to: str,
                freq_choice: str, use_ma: bool, topn_aspects: int):
    """Predict sentiment for every usable row of an uploaded CSV and chart it.

    Args:
        file_obj: Uploaded file; either an object exposing ``.name`` (a path)
            or a plain path string. ``None`` means nothing was uploaded.
        model_choice: Model name forwarded to ``_predict_batch``.
        text_col_name: Column holding the review text. When blank or not
            present, a column is chosen from ``LIKELY_TEXT_COLS``, then the
            first object-dtype column, then the first column.
        date_col_name: Column holding dates. When blank or not present, a
            column is chosen from ``LIKELY_DATE_COLS`` if one exists.
        date_from: Optional inclusive lower bound for the time chart.
        date_to: Optional inclusive upper bound for the time chart.
        freq_choice: Bucket frequency passed to ``make_time_charts``.
        use_ma: Whether the time chart is smoothed with a moving average.
        topn_aspects: Number of negative aspects shown in the time chart.

    Returns:
        A 7-tuple ``(result_df, fig_bar, fig_pie, fig_topn, fig_time,
        info_markdown, csv_path)``. On any failure the frame and figures are
        empty, the markdown carries the message/traceback and ``csv_path``
        is ``None``.
    """
    def _failure(message: str):
        # Same output shape as a successful run, but with empty widgets.
        return (pd.DataFrame(), go.Figure(), go.Figure(), go.Figure(),
                go.Figure(), message, None)

    try:
        if file_obj is None:
            return _failure("Please upload a CSV.")

        # Accept both a file-like wrapper with `.name` and a bare path string.
        csv_path = getattr(file_obj, "name", file_obj)
        df_raw = pd.read_csv(csv_path)
        cols = list(df_raw.columns)
        low = {c.lower(): c for c in cols}

        # ----- pick text column -----
        col_text = (text_col_name or "").strip()
        if not col_text or col_text not in cols:
            col_text = next((low[k] for k in LIKELY_TEXT_COLS if k in low), None)
            if col_text is None:
                cand = [c for c in cols if df_raw[c].dtype == object]
                col_text = cand[0] if cand else cols[0]

        # Remember WHICH source rows survive filtering so that any per-row
        # data taken from the CSV (the dates below) stays aligned with the
        # prediction rows.
        normalized = [_norm_text(v) for v in df_raw[col_text].tolist()]
        keep = [_is_substantive_text(t) for t in normalized]
        texts = [t for t, ok in zip(normalized, keep) if ok]
        if not texts:
            return _failure("No valid texts in selected column.")

        # ----- predict -----
        results = _predict_batch(texts, model_choice)
        out_df = pd.DataFrame(results)

        # ----- attach date if available -----
        col_date = (date_col_name or "").strip()
        if col_date and col_date in cols:
            date_label = col_date
        else:
            col_date = next((low[k] for k in LIKELY_DATE_COLS if k in low), None)
            date_label = f"{col_date} (auto-detected)" if col_date else ""

        have_time = False
        if col_date is not None:
            keep_mask = pd.Series(keep, index=df_raw.index)
            kept_dates = df_raw.loc[keep_mask, col_date]
            dts = _to_datetime_safe(kept_dates).reset_index(drop=True)
            # Only attach when there is exactly one date per prediction row
            # (assumes _predict_batch returns one row per input text, in
            # order) and at least one value parsed.
            if len(dts) == len(out_df) and dts.notna().any():
                out_df["__dt__"] = dts
                have_time = True

        # ----- basic charts (always) -----
        neg_mask = out_df["label"] == "negative"
        if neg_mask.any():
            out_df.loc[neg_mask, "aspect"] = (
                out_df.loc[neg_mask, "review"].apply(detect_aspect_for_negative)
            )
        else:
            out_df["aspect"] = None

        fig_bar, fig_pie, fig_topn, info_basic = make_basic_charts(out_df)

        # ----- time charts (optional) -----
        fig_time = go.Figure()
        if have_time:
            df_time = out_df.dropna(subset=["__dt__"]).copy()

            # Parse the bounds first; an unparseable bound would become NaT
            # and silently filter out every row, so it is ignored and
            # reported instead.
            bounds = {}
            ignored = []
            for side, raw in (("from", date_from), ("to", date_to)):
                if not raw or not str(raw).strip():
                    continue
                parsed = pd.to_datetime(raw, errors="coerce")
                if pd.isna(parsed):
                    ignored.append(str(raw))
                else:
                    bounds[side] = parsed
            if "from" in bounds:
                df_time = df_time[df_time["__dt__"] >= bounds["from"]]
            if "to" in bounds:
                df_time = df_time[df_time["__dt__"] <= bounds["to"]]

            if len(df_time) == 0:
                info_time = "**Note:** Selected date range has no data."
            else:
                fig_time = make_time_charts(df_time, "__dt__", freq_choice, use_ma, topn_aspects)
                info_time = (
                    f"Time charts based on column: **{date_label}**, "
                    f"Freq: **{freq_choice}**, MA: **{use_ma}**, TopN: **{topn_aspects}**"
                )
            if ignored:
                info_time += "  \n**Note:** Ignored unparseable date bound(s): " + ", ".join(ignored)
        else:
            info_time = "_No date/timestamp column detected — time charts hidden._"

        # ----- downloadable CSV -----
        # The helper datetime column is internal and never shown or exported.
        public_df = out_df.drop(columns=["__dt__"], errors="ignore")
        fd, out_path = tempfile.mkstemp(prefix="pred_", suffix=".csv")
        os.close(fd)
        public_df.to_csv(out_path, index=False, encoding="utf-8-sig")

        info_md = info_basic + "\n\n" + info_time
        return public_df, fig_bar, fig_pie, fig_topn, fig_time, info_md, out_path

    except Exception:
        # UI boundary: surface the traceback in the summary panel instead of
        # crashing the Gradio callback.
        tb = traceback.format_exc()
        return _failure(f"**Error**\n```\n{tb}\n```")
|
| 439 |
|
| 440 |
# ================= Gradio UI =================
|
| 441 |
with gr.Blocks(title="Thai Sentiment (WangchanBERTa Variants)") as demo:
    gr.Markdown("### Thai Sentiment (WangchanBERTa Variants)")
    model_radio = gr.Radio(choices=AVAILABLE_CHOICES, value=DEFAULT_MODEL, label="เลือกโมเดล")

    # Tab 1: free-text batch input, one review per line.
    with gr.Tab("Batch (หลายข้อความ)"):
        t2 = gr.Textbox(lines=8, label="พิมพ์หลายรีวิว (บรรทัดละ 1 รีวิว)")
        df2 = gr.Dataframe(label="ผลลัพธ์", interactive=False)
        bar2 = gr.Plot(label="Label counts (bar)")
        pie2 = gr.Plot(label="Negative aspects (pie)")
        top2 = gr.Plot(label="Top-N negative aspects (overall)")
        sum2 = gr.Markdown()

        run_batch_btn = gr.Button("Run Batch")
        run_batch_btn.click(
            fn=predict_many,
            inputs=[t2, model_radio],
            outputs=[df2, bar2, pie2, top2, sum2],
        )

    # Tab 2: CSV upload with optional date column and time-series options.
    with gr.Tab("CSV Upload"):
        with gr.Row():
            file_in = gr.File(label="อัปโหลดไฟล์ .csv", file_types=[".csv"])
            col_in = gr.Textbox(label="ชื่อคอลัมน์ข้อความ (เว้นว่างให้เลือกอัตโนมัติ)", value="")
            date_in = gr.Textbox(label="ชื่อคอลัมน์วันเวลา (เว้นว่างให้ auto-detect)", value="")
        with gr.Row():
            date_from = gr.Textbox(label="เริ่มวันที่ (YYYY-MM-DD)", value="")
            date_to = gr.Textbox(label="ถึงวันที่ (YYYY-MM-DD)", value="")
            freq = gr.Radio(choices=["D", "W", "M"], value="D", label="ความถี่ (Day/Week/Month)")
            use_ma = gr.Checkbox(value=True, label="Moving average (7/4/3)")
            topn = gr.Slider(3, 10, value=5, step=1, label="Top-N aspects (ลบ)")

        df3 = gr.Dataframe(label="ผลลัพธ์", interactive=False)
        bar3 = gr.Plot(label="Label counts (bar)")
        pie3 = gr.Plot(label="Negative aspects (pie)")
        top3 = gr.Plot(label="Top-N negative aspects (overall)")
        line = gr.Plot(label="Reviews over time (POS/NEG + stacked negative aspects)")
        sum3 = gr.Markdown()
        dl3 = gr.File(label="ดาวน์โหลดผลเป็น CSV", interactive=False)

        # Inputs follow predict_csv's parameter order; outputs follow the
        # order of the tuple it returns.
        csv_inputs = [file_in, model_radio, col_in, date_in, date_from, date_to, freq, use_ma, topn]
        csv_outputs = [df3, bar3, pie3, top3, line, sum3, dl3]
        predict_csv_btn = gr.Button("Predict CSV")
        predict_csv_btn.click(fn=predict_csv, inputs=csv_inputs, outputs=csv_outputs)

if __name__ == "__main__":
    demo.launch()
|