Dusit-P committed on
Commit
3d347ae
·
verified ·
1 Parent(s): f8698d9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +239 -64
app.py CHANGED
@@ -1,64 +1,239 @@
1
- import os, json, importlib.util, torch
2
- import torch.nn.functional as F
3
- import gradio as gr
4
- from huggingface_hub import hf_hub_download
5
- from safetensors.torch import load_file
6
- from transformers import AutoTokenizer
7
-
8
- # ===== ปรับได้ผ่าน Settings > Variables (Environment) =====
9
- REPO_ID = os.getenv("REPO_ID", "Dusit-P/thai-sentiment-wcb")
10
- DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "cnn_bilstm") # หรื"baseline"
11
- HF_TOKEN = os.getenv("HF_TOKEN", None) # ถ้าโมเดลเป็น private ให้เพิ่ม secret ชื่อนี้
12
-
13
- CACHE = {}
14
-
15
- def _import_models():
16
- if "models_module" in CACHE:
17
- return CACHE["models_module"]
18
- models_py = hf_hub_download(REPO_ID, filename="common/models.py", token=HF_TOKEN)
19
- spec = importlib.util.spec_from_file_location("models", models_py)
20
- mod = importlib.util.module_from_spec(spec)
21
- spec.loader.exec_module(mod)
22
- CACHE["models_module"] = mod
23
- return mod
24
-
25
- def load_model(model_name: str):
26
- key = f"model:{model_name}"
27
- if key in CACHE:
28
- return CACHE[key]
29
- cfg_path = hf_hub_download(REPO_ID, filename=f"{model_name}/config.json", token=HF_TOKEN)
30
- w_path = hf_hub_download(REPO_ID, filename=f"{model_name}/model.safetensors", token=HF_TOKEN)
31
- with open(cfg_path, "r", encoding="utf-8") as f:
32
- cfg = json.load(f)
33
- models = _import_models()
34
- tok = AutoTokenizer.from_pretrained(cfg["base_model"])
35
- model = models.create_model_by_name(cfg["arch"])
36
- state = load_file(w_path)
37
- model.load_state_dict(state, strict=True)
38
- model.eval()
39
- CACHE[key] = (model, tok, cfg)
40
- return CACHE[key]
41
-
42
- def predict_api(text: str, model_choice: str):
43
- if not text.strip():
44
- return {"negative": 0.0, "positive": 0.0}, ""
45
- model_name = "baseline" if model_choice == "baseline" else "cnn_bilstm"
46
- model, tok, cfg = load_model(model_name)
47
- enc = tok([text], padding=True, truncation=True, max_length=cfg["max_len"], return_tensors="pt")
48
- with torch.no_grad():
49
- logits = model(enc["input_ids"], enc["attention_mask"])
50
- probs = F.softmax(logits, dim=1)[0].tolist()
51
- out = {"negative": float(probs[0]), "positive": float(probs[1])}
52
- label = "positive" if out["positive"] >= out["negative"] else "negative"
53
- return out, label
54
-
55
- with gr.Blocks(title="Thai Sentiment API (Dusit-P)") as demo:
56
- gr.Markdown("### Thai Sentiment (WangchanBERTa + LSTM Heads)")
57
- inp_text = gr.Textbox(lines=3, label="ข้อความรีวิวภาษาไทย", placeholder="พิมพ์รีวิวที่นี่")
58
- inp_model = gr.Radio(choices=["cnn_bilstm","baseline"], value=DEFAULT_MODEL, label="เลือกโมเดล")
59
- out_probs = gr.Label(label="Probabilities")
60
- out_label = gr.Textbox(label="Prediction", interactive=False)
61
- gr.Button("Predict").click(predict_api, [inp_text, inp_model], [out_probs, out_label])
62
-
63
- if __name__ == "__main__":
64
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, json, importlib.util, tempfile, torch
2
+ import torch.nn.functional as F
3
+ import gradio as gr
4
+ import pandas as pd
5
+ import plotly.graph_objects as go
6
+ from huggingface_hub import hf_hub_download
7
+ from safetensors.torch import load_file
8
+ from transformers import AutoTokenizer
9
+
10
+ # ===== ปรับได้จาก Settings > Variables & secrets ข Space =====
11
+ REPO_ID = os.getenv("REPO_ID", "Dusit-P/thai-sentiment-wcb")
12
+ DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "cnn_bilstm") # หรือ "baseline"
13
+ HF_TOKEN = os.getenv("HF_TOKEN", None) # ถ้าโมเดลเป็น private ให้เพิ่ม secret ชื่อนี้
14
+
15
+ CACHE = {}
16
+
17
+ # ---------- load architecture & weights from model repo ----------
18
def _import_models():
    """Fetch ``common/models.py`` from the model repo and import it as a module.

    The resulting module is memoized in CACHE so the download and exec
    happen at most once per process.
    """
    cached = CACHE.get("models_module")
    if cached is not None:
        return cached
    path = hf_hub_download(REPO_ID, filename="common/models.py", token=HF_TOKEN)
    module_spec = importlib.util.spec_from_file_location("models", path)
    module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)
    CACHE["models_module"] = module
    return module
27
+
28
def load_model(model_name: str):
    """Load and cache (model, tokenizer, config) for *model_name*.

    Downloads ``{model_name}/config.json`` and ``{model_name}/model.safetensors``
    from REPO_ID, rebuilds the architecture through the repo's models module,
    loads the weights strictly, and returns the model in eval mode.
    """
    cache_key = f"model:{model_name}"
    cached = CACHE.get(cache_key)
    if cached is not None:
        return cached

    config_file = hf_hub_download(REPO_ID, filename=f"{model_name}/config.json", token=HF_TOKEN)
    weights_file = hf_hub_download(REPO_ID, filename=f"{model_name}/model.safetensors", token=HF_TOKEN)

    with open(config_file, "r", encoding="utf-8") as fh:
        config = json.load(fh)

    arch_module = _import_models()
    tokenizer = AutoTokenizer.from_pretrained(config["base_model"])
    net = arch_module.create_model_by_name(config["arch"])
    net.load_state_dict(load_file(weights_file), strict=True)
    net.eval()

    CACHE[cache_key] = (net, tokenizer, config)
    return CACHE[cache_key]
47
+
48
+ # ---------- helpers ----------
49
+ def _format_pct(x: float) -> str:
50
+ return f"{x*100:.2f}%"
51
+
52
def _predict_batch(texts, model_name, batch_size=64):
    """Score an iterable of texts; return one result dict per non-blank text.

    Each dict has keys ``review``, ``negative(%)``, ``positive(%)``, ``label``.
    NOTE(review): blank/whitespace-only entries are dropped, so the output
    list may be shorter than the input — callers that zip results back to
    their source rows must pre-filter.
    """
    net, tokenizer, cfg = load_model(model_name)
    cleaned = [str(t) for t in texts if str(t).strip()]
    out = []
    for start in range(0, len(cleaned), batch_size):
        batch = cleaned[start:start + batch_size]
        encoded = tokenizer(batch, padding=True, truncation=True,
                            max_length=cfg["max_len"], return_tensors="pt")
        with torch.no_grad():
            logits = net(encoded["input_ids"], encoded["attention_mask"])
            scores = F.softmax(logits, dim=1).cpu().numpy()
        for text, pair in zip(batch, scores):
            negative, positive = float(pair[0]), float(pair[1])
            out.append({
                "review": text,
                "negative(%)": _format_pct(negative),
                "positive(%)": _format_pct(positive),
                "label": "positive" if positive >= negative else "negative",
            })
    return out
73
+
74
+ def _detect_cols(df: pd.DataFrame):
75
+ """เดาชื่อคอลัมน์รีวิว/ร้านอัตโนมัติ ถ้าไม่พบรีวิว เลือกคอลัมน์ object ตัวแรก"""
76
+ rev_cands = ["review", "text", "comment", "content", "message", "ข้อความ", "รีวิว"]
77
+ shop_cands = ["shop", "shop_name", "store", "restaurant", "brand", "merchant", "ชื่อร้าน"]
78
+
79
+ review_col = next((c for c in rev_cands if c in df.columns), None)
80
+ shop_col = next((c for c in shop_cands if c in df.columns), None)
81
+
82
+ if review_col is None:
83
+ obj_cols = [c for c in df.columns if df[c].dtype == object]
84
+ if obj_cols:
85
+ review_col = obj_cols[0]
86
+
87
+ return review_col, shop_col
88
+
89
+ def _summarize_df(df: pd.DataFrame):
90
+ """สรุปภาพรวม + ตัวเลขเฉลี่ยความมั่นใจ"""
91
+ total = len(df)
92
+ neg = int((df["label"] == "negative").sum())
93
+ pos = int((df["label"] == "positive").sum())
94
+ neg_avg = pd.to_numeric(df["negative(%)"].str.rstrip("%"), errors="coerce").mean()
95
+ pos_avg = pd.to_numeric(df["positive(%)"].str.rstrip("%"), errors="coerce").mean()
96
+ info = (
97
+ f"**Summary** \n"
98
+ f"- Total: {total} \n"
99
+ f"- Negative: {neg} \n"
100
+ f"- Positive: {pos} \n"
101
+ f"- Avg negative: {neg_avg:.2f}% \n"
102
+ f"- Avg positive: {pos_avg:.2f}%"
103
+ )
104
+ return {"total": total, "neg": neg, "pos": pos, "neg_avg": neg_avg, "pos_avg": pos_avg, "md": info}
105
+
106
def _make_figures(df: pd.DataFrame):
    """Build the overall bar and pie figures plus the summary markdown for *df*."""
    summary = _summarize_df(df)
    labels = ["negative", "positive"]
    values = [summary["neg"], summary["pos"]]

    # Overall count bar chart.
    bar_fig = go.Figure([go.Bar(x=labels, y=values)])
    bar_fig.update_layout(title="Label counts", xaxis_title="label", yaxis_title="count")

    # Overall share donut chart.
    pie_fig = go.Figure(go.Pie(labels=labels, values=values, hole=0.35))
    pie_fig.update_layout(title="Label share")

    return bar_fig, pie_fig, summary["md"]
115
+
116
def _shop_summary(out_df: pd.DataFrame, max_shops=15):
    """Per-shop breakdown: stacked bar figure + summary table.

    Returns an empty figure and empty table when *out_df* has no ``shop``
    column. The figure shows only the top *max_shops* shops by review count.
    """
    if "shop" not in out_df.columns:
        empty = pd.DataFrame(columns=["shop","total","positive","negative","positive_rate(%)","negative_rate(%)"])
        return go.Figure(), empty

    counts = out_df.groupby("shop")["label"].value_counts().unstack(fill_value=0)
    # Guarantee both label columns exist even when one class is absent.
    for needed in ("positive", "negative"):
        if needed not in counts.columns:
            counts[needed] = 0
    counts["total"] = counts["positive"] + counts["negative"]
    counts = counts.sort_values("total", ascending=False)

    table = counts[["total", "positive", "negative"]].copy()
    table["positive_rate(%)"] = (table["positive"] / table["total"] * 100).round(2)
    table["negative_rate(%)"] = (table["negative"] / table["total"] * 100).round(2)
    table = table.reset_index().rename(columns={"index": "shop"})

    # Chart only the top-N shops by total review count.
    top = table.head(max_shops)
    fig = go.Figure()
    for lbl in ("positive", "negative"):
        fig.add_bar(name=lbl, x=top["shop"], y=top[lbl])
    fig.update_layout(barmode="stack", title=f"Per-shop counts (top {len(top)})",
                      xaxis_title="shop", yaxis_title="count", legend_title="label")
    return fig, table
141
+
142
+ # ---------- API wrappers ----------
143
def predict_one(text: str, model_choice: str):
    """Classify a single review.

    Args:
        text: the review text; None/blank short-circuits to zero scores.
        model_choice: "baseline" selects the baseline model, anything else
            selects "cnn_bilstm".

    Returns:
        (probs, label): probs maps "negative"/"positive" to floats in [0, 1];
        label is the winning class name ("" for blank input).
    """
    # Guard against None as well as empty/whitespace input — Gradio can hand
    # over None when the textbox was never touched; the original crashed here.
    if not (text or "").strip():
        return {"negative": 0.0, "positive": 0.0}, ""
    model_name = "baseline" if model_choice == "baseline" else "cnn_bilstm"
    row = _predict_batch([text], model_name)[0]
    # _predict_batch returns formatted percent strings; convert back to
    # [0, 1] fractions for the gr.Label component.
    probs = {
        "negative": float(row["negative(%)"].rstrip("%")) / 100.0,
        "positive": float(row["positive(%)"].rstrip("%")) / 100.0,
    }
    return probs, row["label"]
153
+
154
def predict_many(text_block: str, model_choice: str):
    """Classify one review per line of *text_block*.

    Returns (results dataframe, bar figure, pie figure, summary markdown);
    empty input yields empty figures and the string "No data".
    """
    model_name = "baseline" if model_choice == "baseline" else "cnn_bilstm"
    reviews = [line.strip() for line in (text_block or "").splitlines() if line.strip()]
    table = pd.DataFrame(_predict_batch(reviews, model_name),
                         columns=["review","negative(%)","positive(%)","label"])
    if table.empty:
        return table, go.Figure(), go.Figure(), "No data"

    bar_fig, pie_fig, summary_md = _make_figures(table)
    return table, bar_fig, pie_fig, summary_md
164
+
165
def predict_csv(file_obj, model_choice: str, review_col_override: str = "", shop_col_override: str = ""):
    """Classify every review in an uploaded CSV.

    Args:
        file_obj: Gradio file object (``.name`` is the uploaded path); None
            returns empty outputs with a Thai "please upload" message.
        model_choice: "baseline" or anything else for "cnn_bilstm".
        review_col_override / shop_col_override: explicit column names;
            blank falls back to auto-detection via _detect_cols.

    Returns:
        (results df, download path, bar fig, pie fig, per-shop fig,
         per-shop table, info markdown).

    Raises:
        ValueError: when the review column cannot be found in the CSV.
    """
    if file_obj is None:
        return pd.DataFrame(), None, go.Figure(), go.Figure(), go.Figure(), pd.DataFrame(), "กรุณาอัปโหลดไฟล์ CSV"

    model_name = "baseline" if model_choice == "baseline" else "cnn_bilstm"
    df = pd.read_csv(file_obj.name)

    auto_rev, auto_shop = _detect_cols(df)
    rev_col = (review_col_override or "").strip() or auto_rev
    shop_col = (shop_col_override or "").strip() or auto_shop

    if rev_col not in df.columns:
        raise ValueError(f"ไม่พบคอลัมน์รีวิว '{rev_col}' ใน CSV (columns = {list(df.columns)})")

    # BUGFIX: drop blank reviews BEFORE predicting. _predict_batch silently
    # skips blank rows, which previously misaligned the results with
    # df[shop_col] and made out.insert raise on a length mismatch.
    keep = df[rev_col].astype(str).str.strip() != ""
    df = df[keep].reset_index(drop=True)

    results = _predict_batch(df[rev_col].astype(str).tolist(), model_name)
    out = pd.DataFrame(results, columns=["review","negative(%)","positive(%)","label"])

    if shop_col and shop_col in df.columns:
        out.insert(0, "shop", df[shop_col].astype(str).fillna(""))

    # Result file for download. Close the handle before pandas re-opens the
    # path: required on Windows and avoids leaking a file descriptor.
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".csv")
    tmp.close()
    out.to_csv(tmp.name, index=False, encoding="utf-8-sig")

    # Overall figures / summary.
    fig_bar, fig_pie, info_md = _make_figures(out)
    # Per-shop figure/table (populated only when a shop column exists).
    fig_shop, tbl_shop = _shop_summary(out)

    # Append which columns were used (Thai user-facing text preserved).
    info_md = f"{info_md} \nใช้คอลัมน์รีวิว: {rev_col}" + (f" | คอลัมน์ร้าน: {shop_col}" if ("shop" in out.columns) else " | ไม่มีคอลัมน์ร้าน")

    return out, tmp.name, fig_bar, fig_pie, fig_shop, tbl_shop, info_md
198
+
199
+ # ---------- Gradio UI ----------
200
# Gradio UI: three tabs (single text, multi-line batch, CSV upload) sharing
# one model selector. Thai strings are user-facing labels and must stay as-is.
with gr.Blocks(title="Thai Sentiment API (Dusit-P)") as demo:
    gr.Markdown("### Thai Sentiment (WangchanBERTa + LSTM Heads)")

    # Shared model selector wired into all three tabs' click handlers.
    model_radio = gr.Radio(choices=["cnn_bilstm","baseline"], value=DEFAULT_MODEL, label="เลือกโมเดล")

    # Tab 1: classify a single review.
    with gr.Tab("Single"):
        t1 = gr.Textbox(lines=3, label="ข้อความรีวิว (1 ข้อความ)")
        probs = gr.Label(label="Probabilities")
        pred = gr.Textbox(label="Prediction", interactive=False)
        gr.Button("Predict").click(predict_one, [t1, model_radio], [probs, pred])

    # Tab 2: one review per line, with aggregate bar/pie charts and summary.
    with gr.Tab("Batch (หลายข้อความ)"):
        t2 = gr.Textbox(lines=8, label="พิมพ์หลายรีวิว (บรรทัดละ 1 รีวิว)")
        df2 = gr.Dataframe(label="ผลลัพธ์", interactive=False)
        bar2 = gr.Plot(label="Label counts (bar)")
        pie2 = gr.Plot(label="Label share (pie)")
        sum2 = gr.Markdown()
        gr.Button("Run Batch").click(predict_many, [t2, model_radio], [df2, bar2, pie2, sum2])

    # Tab 3: CSV upload with auto-detected (or overridden) review/shop
    # columns, a downloadable result file, and a per-shop breakdown.
    with gr.Tab("CSV (auto-detect columns)"):
        f = gr.File(label="อัปโหลด CSV", file_types=[".csv"])
        review_col_inp = gr.Textbox(label="ชื่อคอลัมน์รีวิว (เว้นว่างให้เดาได้)")
        shop_col_inp = gr.Textbox(label="ชื่อคอลัมน์ร้าน (เว้นว่างได้)")

        df3 = gr.Dataframe(label="ผลลัพธ์ CSV", interactive=False)
        download = gr.File(label="ดาวน์โหลดผลลัพธ์")
        bar3 = gr.Plot(label="Label counts (bar)")
        pie3 = gr.Plot(label="Label share (pie)")
        shop_bar = gr.Plot(label="Per-shop stacked bar")
        shop_tbl = gr.Dataframe(label="Per-shop summary", interactive=False)
        info = gr.Markdown()

        gr.Button("Run CSV").click(
            predict_csv,
            inputs=[f, model_radio, review_col_inp, shop_col_inp],
            outputs=[df3, download, bar3, pie3, shop_bar, shop_tbl, info]
        )

if __name__ == "__main__":
    demo.launch()