Dusit-P's picture
Update app.py
bf786d6 verified
# app.py — Thai Sentiment Analysis (ยืดหยุ่น + ง่าย)
import os, json, importlib.util, traceback, re, math, tempfile
import gradio as gr
import torch, pandas as pd
import torch.nn.functional as F
import plotly.graph_objects as go
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import AutoTokenizer
# ================= Settings =================
REPO_ID = os.getenv("REPO_ID", "Dusit-P/thai-sentiment")
DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "WCB_BiLSTM")
HF_TOKEN = os.getenv("HF_TOKEN", None)
# เลือกเฉพาะโมเดลที่ให้ผลดีที่สุด
AVAILABLE_CHOICES = ["WCB", "WCB_BiLSTM"]
NEG_COLOR = "#F87171"
POS_COLOR = "#34D399"
TEMPLATE = "plotly_white"
CACHE = {}
# ================= Loader =================
def _import_models():
if "models_module" in CACHE:
return CACHE["models_module"]
models_py = hf_hub_download(REPO_ID, filename="common/models.py", token=HF_TOKEN)
spec = importlib.util.spec_from_file_location("models", models_py)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
CACHE["models_module"] = mod
return mod
def load_model(model_name: str):
key = f"model:{model_name}"
if key in CACHE:
return CACHE[key]
cfg_path = hf_hub_download(REPO_ID, filename=f"{model_name}/config.json", token=HF_TOKEN)
w_path = hf_hub_download(REPO_ID, filename=f"{model_name}/model.safetensors", token=HF_TOKEN)
with open(cfg_path, "r", encoding="utf-8") as f:
cfg = json.load(f)
base_model = cfg.get("base_model", "airesearch/wangchanberta-base-att-spm-uncased")
arch_name = cfg.get("architecture", model_name)
tok = AutoTokenizer.from_pretrained(base_model)
models = _import_models()
model = models._build(arch_name, base_model, int(cfg.get("num_labels",2)),
cfg.get("pooling_after_lstm", "masked_mean"))
state = load_file(w_path)
model.load_state_dict(state, strict=False)
model.eval()
CACHE[key] = (model, tok, cfg)
return CACHE[key]
# ================= Utils =================
_INVALID_STRINGS = {"-", "--","—","n/a","na","null","none","nan",".","…",""}
_RE_HAS_LETTER = re.compile(r"[ก-๙A-Za-z]")
def _norm_text(v):
if v is None: return ""
if isinstance(v, float) and math.isnan(v): return ""
return str(v).strip().strip('"').strip("'").strip(",")
def _is_substantive_text(s, min_chars=2):
if not s: return False
if s.lower() in _INVALID_STRINGS: return False
if not _RE_HAS_LETTER.search(s): return False
if len(s.replace(" ","")) < min_chars: return False
return True
def _format_pct(x): return f"{x*100:.2f}%"
# คำที่น่าจะเป็นคอลัมน์ข้อความ
LIKELY_TEXT_COLS = ["text","review","message","comment","content","sentence","body",
"ข้อความ","รีวิว","ความคิดเห็น"]
# คำที่น่าจะเป็นคอลัมน์หมวดหมู่ (ร้าน/product/category)
LIKELY_GROUP_COLS = ["shop","store","branch","category","product","brand","type","group",
"ร้าน","สาขา","ชื่อร้าน","หมวดหมู่","ประเภท","แบรนด์"]
def detect_columns(df):
"""ตรวจหา text และ group columns อัตโนมัติ"""
cols = list(df.columns)
low = {c.lower(): c for c in cols}
# Text column
text_col = None
for k in LIKELY_TEXT_COLS:
if k in low:
text_col = low[k]
break
if text_col is None:
cand = [c for c in cols if df[c].dtype == object]
text_col = cand[0] if cand else cols[0]
# Group candidates (ร้าน/หมวดหมู่)
group_candidates = []
for c in cols:
if c == text_col: # ข้ามคอลัมน์ที่เป็น text
continue
if c.lower() in LIKELY_GROUP_COLS:
group_candidates.append(c)
continue
# ตรวจว่ามีค่าซ้ำพอสมควร (categorical)
if df[c].dtype == object:
unique_ratio = df[c].nunique() / len(df)
if 0.01 <= unique_ratio <= 0.5: # 1-50% ของข้อมูลเป็นค่าซ้ำ
group_candidates.append(c)
group_candidates = list(dict.fromkeys(group_candidates))
group_col = group_candidates[0] if len(group_candidates) > 0 else None
return text_col, group_candidates, group_col
# ================= Core Predict =================
def _predict_batch(texts, model_name, batch_size=32):
model, tok, cfg = load_model(model_name)
results = []
for i in range(0, len(texts), batch_size):
chunk = texts[i:i+batch_size]
enc = tok(chunk, padding=True, truncation=True,
max_length=cfg.get("max_length",128), return_tensors="pt")
with torch.no_grad():
logits = model(enc["input_ids"], enc["attention_mask"])
probs = F.softmax(logits, dim=1).cpu().numpy()
for txt, p in zip(chunk, probs):
neg, pos = float(p[0]), float(p[1])
label = "positive" if pos >= neg else "negative"
results.append({
"review": txt,
"negative(%)": _format_pct(neg),
"positive(%)": _format_pct(pos),
"label": label
})
return results
# ================= Charts =================
def make_summary_chart(df):
"""สร้างกราฟสรุปแบบ Pie"""
total = len(df)
neg_count = len(df[df["label"]=="negative"])
pos_count = len(df[df["label"]=="positive"])
neg_avg = pd.to_numeric(df["negative(%)"].str.rstrip("%"), errors="coerce").mean()
pos_avg = pd.to_numeric(df["positive(%)"].str.rstrip("%"), errors="coerce").mean()
# Pie chart
fig = go.Figure(go.Pie(
labels=["😞 เชิงลบ", "😊 เชิงบวก"],
values=[neg_count, pos_count],
hole=0.4,
marker=dict(colors=[NEG_COLOR, POS_COLOR]),
textinfo='label+percent',
textfont_size=14
))
fig.update_layout(
title="สัดส่วนรีวิว",
template=TEMPLATE,
height=400
)
# Summary text
info = (f"**📊 สรุปผล**\n\n"
f"- ทั้งหมด: **{total:,}** รีวิว\n"
f"- เชิงลบ: **{neg_count:,}** ({neg_count/total*100:.1f}%)\n"
f"- เชิงบวก: **{pos_count:,}** ({pos_count/total*100:.1f}%)")
return fig, info
def make_group_chart(df, group_col):
"""กราฟแสดงรีวิวแยกตามกลุ่ม (ร้าน/หมวดหมู่/etc)"""
# สรุปแต่ละกลุ่ม
group_data = []
for group in df[group_col].unique():
if pd.isna(group):
continue
group_df = df[df[group_col] == group]
neg = len(group_df[group_df["label"]=="negative"])
pos = len(group_df[group_df["label"]=="positive"])
total = len(group_df)
group_data.append({
"group": str(group),
"negative": neg,
"positive": pos,
"total": total,
"pos_pct": pos/total*100 if total > 0 else 0
})
group_df = pd.DataFrame(group_data).sort_values("total", ascending=False)
# กราฟแท่ง Stacked
fig = go.Figure()
fig.add_bar(
name="😞 เชิงลบ",
x=group_df["group"],
y=group_df["negative"],
marker_color=NEG_COLOR,
text=group_df["negative"],
textposition='inside'
)
fig.add_bar(
name="😊 เชิงบวก",
x=group_df["group"],
y=group_df["positive"],
marker_color=POS_COLOR,
text=group_df["positive"],
textposition='inside'
)
fig.update_layout(
title=f"📊 รีวิวแยกตามกลุ่ม",
barmode='stack',
template=TEMPLATE,
xaxis_title="",
yaxis_title="จำนวนรีวิว",
height=450,
showlegend=True
)
# ตารางสรุป
summary_df = pd.DataFrame({
"กลุ่ม": group_df["group"],
"รีวิวทั้งหมด": group_df["total"],
"😞 เชิงลบ": group_df["negative"],
"😊 เชิงบวก": group_df["positive"],
"% เชิงบวก": group_df["pos_pct"].apply(lambda x: f"{x:.1f}%")
})
return fig, summary_df
# ================= Tab 1: วิเคราะห์ข้อความ =================
def predict_many(text_block, model_choice):
try:
raw = (text_block or "").splitlines()
norm = [_norm_text(t) for t in raw]
clean = [t for t in norm if _is_substantive_text(t)]
if not clean:
return pd.DataFrame(), go.Figure(), "❌ ไม่พบข้อความที่วิเคราะห์ได้"
results = _predict_batch(clean, model_choice)
df = pd.DataFrame(results)
fig, info = make_summary_chart(df)
return df, fig, info
except Exception as e:
return pd.DataFrame(), go.Figure(), f"❌ เกิดข้อผิดพลาด:\n{traceback.format_exc()}"
# ================= Tab 2: อัปโหลด CSV =================
def on_file_change(file_obj):
"""ตรวจหา columns เมื่ออัปโหลดไฟล์"""
if file_obj is None:
return (gr.update(choices=[], value=None),
gr.update(choices=[], value=None),
gr.update(visible=False),
gr.update(visible=False),
"⚠️ กรุณาอัปโหลดไฟล์ CSV")
try:
df = pd.read_csv(file_obj.name)
text_col, group_candidates, group_col = detect_columns(df)
has_group = group_col is not None
note = f"✅ **ตรวจพบคอลัมน์**\n\n"
note += f"📝 **ข้อความรีวิว:** {text_col}\n\n"
if has_group:
note += f"📊 **กลุ่ม/หมวดหมู่:** {group_col} ({df[group_col].nunique()} กลุ่ม)\n\n"
else:
note += f"📊 **กลุ่ม/หมวดหมู่:** _ไม่พบ_\n\n"
note += "_หากต้องการเปลี่ยน สามารถเลือกคอลัมน์ใหม่ได้ด้านบน_"
return (gr.update(choices=list(df.columns), value=text_col),
gr.update(choices=group_candidates if group_candidates else ["ไม่มี"],
value=group_col if group_col else "ไม่มี"),
gr.update(visible=has_group),
gr.update(visible=has_group),
note)
except Exception as e:
return (gr.update(choices=[], value=None),
gr.update(choices=[], value=None),
gr.update(visible=False),
gr.update(visible=False),
f"❌ ไม่สามารถอ่านไฟล์ได้: {str(e)}")
def predict_csv(file_obj, model_choice, text_col, group_col):
"""วิเคราะห์ CSV"""
if file_obj is None:
return (pd.DataFrame(), go.Figure(),
gr.update(visible=False),
gr.update(visible=False),
"❌ กรุณาอัปโหลดไฟล์", None)
try:
df_raw = pd.read_csv(file_obj.name)
total_rows = len(df_raw)
cols = list(df_raw.columns)
# ตรวจสอบ text column
if text_col not in cols:
text_col, _, _ = detect_columns(df_raw)
# ดึงข้อความ
texts = [_norm_text(v) for v in df_raw[text_col].tolist()]
texts_clean = [t for t in texts if _is_substantive_text(t)]
skipped = total_rows - len(texts_clean)
if not texts_clean:
return (pd.DataFrame(), go.Figure(),
gr.update(visible=False),
gr.update(visible=False),
"❌ ไม่พบข้อความที่วิเคราะห์ได้", None)
# ทำนาย
results = _predict_batch(texts_clean, model_choice)
df_out = pd.DataFrame(results)
# กราฟสรุป
fig_main, info = make_summary_chart(df_out)
if skipped > 0:
info += f"\n\n⚠️ ข้ามแถวว่าง: {skipped} แถว (ใช้ {len(texts_clean)}/{total_rows} แถว)"
# วิเคราะห์ตามกลุ่ม (ถ้ามี)
fig_group = go.Figure()
group_summary = pd.DataFrame()
show_group = False
if group_col and group_col in cols and group_col != "ไม่มี":
# เตรียมข้อมูล
df_group = df_out.copy()
df_group[group_col] = df_raw[group_col].iloc[:len(df_out)]
# ลบแถวที่ไม่มีข้อมูลกลุ่ม
df_group = df_group.dropna(subset=[group_col])
if len(df_group) > 0:
fig_group, group_summary = make_group_chart(df_group, group_col)
show_group = True
info += f"\n\n📊 **วิเคราะห์เพิ่มเติม:** แยกตาม '{group_col}'"
# บันทึกไฟล์
fd, path = tempfile.mkstemp(suffix=".csv")
os.close(fd)
df_out.to_csv(path, index=False, encoding="utf-8-sig")
return (df_out, fig_main,
gr.update(visible=show_group, value=fig_group),
gr.update(visible=show_group, value=group_summary),
info, path)
except Exception as e:
return (pd.DataFrame(), go.Figure(),
gr.update(visible=False),
gr.update(visible=False),
f"❌ เกิดข้อผิดพลาด:\n{traceback.format_exc()}", None)
# ================= Gradio UI =================
with gr.Blocks(title="Thai Sentiment Analysis", theme=gr.themes.Soft()) as demo:
gr.Markdown("""
# 🇹🇭 Thai Sentiment Analysis
### วิเคราะห์ความรู้สึกรีวิวภาษาไทย (เชิงบวก/เชิงลบ)
""")
model_radio = gr.Radio(
choices=AVAILABLE_CHOICES,
value=DEFAULT_MODEL,
label="🤖 เลือกโมเดล",
info="WCB = เร็ว | WCB_BiLSTM = แม่นยำสูงสุด (แนะนำ)"
)
# =================== Tab 1: วิเคราะห์ข้อความ ===================
with gr.Tab("📝 วิเคราะห์ข้อความ"):
gr.Markdown("""
**วิธีใช้:** ป้อนรีวิวหลายรายการ (แต่ละบรรทัด = 1 รีวิว)
**ตัวอย่าง:**
```
อาหารอร่อยมาก บริการดี
ของแพง รสชาติธรรมดา
บรรยากาศดี แนะนำเลย
```
""")
text_input = gr.Textbox(
lines=8,
label="📄 ข้อความรีวิว",
placeholder="ป้อนรีวิว แต่ละบรรทัด = 1 รีวิว..."
)
predict_btn_1 = gr.Button("🚀 เริ่มวิเคราะห์", variant="primary", size="lg")
result_df_1 = gr.Dataframe(label="📋 ผลการวิเคราะห์")
with gr.Row():
with gr.Column(scale=1):
result_chart_1 = gr.Plot(label="📊 กราฟสรุป")
with gr.Column(scale=1):
result_info_1 = gr.Markdown()
predict_btn_1.click(
predict_many,
[text_input, model_radio],
[result_df_1, result_chart_1, result_info_1]
)
# =================== Tab 2: อัปโหลด CSV ===================
with gr.Tab("📤 วิเคราะห์ไฟล์ CSV"):
gr.Markdown("""
**วิธีใช้:** อัปโหลดไฟล์ CSV ที่มีคอลัมน์รีวิว
**ระบบจะตรวจหาอัตโนมัติ:**
- 📝 คอลัมน์ข้อความรีวิว
- 📊 คอลัมน์กลุ่ม/หมวดหมู่ (เช่น ร้าน, สาขา, ประเภทสินค้า, แบรนด์)
**ใช้ได้กับหลายสถานการณ์:**
- เปรียบเทียบร้านค้า/สาขา
- วิเคราะห์ตาม product category
- แยกตามแบรนด์/ประเภทสินค้า
- หรือข้อมูล categorical อื่นๆ
""")
file_input = gr.File(file_types=[".csv"], label="📁 อัปโหลดไฟล์ CSV")
detect_note = gr.Markdown("⬆️ อัปโหลดไฟล์เพื่อเริ่มต้น")
with gr.Row():
text_col_dd = gr.Dropdown(
label="📝 คอลัมน์ข้อความรีวิว",
info="เลือกคอลัมน์ที่มีเนื้อหารีวิว"
)
group_col_dd = gr.Dropdown(
label="📊 คอลัมน์กลุ่ม/หมวดหมู่ (ถ้ามี)",
info="เช่น ร้าน, สาขา, ประเภทสินค้า, แบรนด์"
)
predict_btn_2 = gr.Button("🚀 เริ่มวิเคราะห์", variant="primary", size="lg")
gr.Markdown("### 📊 ผลการวิเคราะห์")
result_df_2 = gr.Dataframe(label="📋 รายละเอียดทุกรีวิว")
with gr.Row():
with gr.Column(scale=1):
result_chart_2 = gr.Plot(label="📊 สรุปภาพรวม")
with gr.Column(scale=1):
result_info_2 = gr.Markdown()
result_group = gr.Plot(label="📊 เปรียบเทียบแต่ละกลุ่ม", visible=False)
group_summary = gr.Dataframe(label="📋 สรุปแต่ละกลุ่ม", visible=False)
download_file = gr.File(label="💾 ดาวน์โหลดผลลัพธ์ (CSV)")
# Events
file_input.change(
on_file_change,
[file_input],
[text_col_dd, group_col_dd, result_group, group_summary, detect_note]
)
predict_btn_2.click(
predict_csv,
[file_input, model_radio, text_col_dd, group_col_dd],
[result_df_2, result_chart_2, result_group, group_summary, result_info_2, download_file]
)
gr.Markdown("""
---
### 💡 เกี่ยวกับโมเดล
- **WCB**: เร็ว เหมาะงานทั่วไป
- **WCB_BiLSTM**: แม่นยำสูงสุด ⭐ (แนะนำ)
📌 วิเคราะห์เฉพาะ **เชิงบวก/เชิงลบ** เท่านั้น
""")
if __name__ == "__main__":
demo.launch()