Dusit-P committed on
Commit
c1fbd91
·
verified ·
1 Parent(s): c6bf28b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +186 -186
app.py CHANGED
@@ -1,186 +1,186 @@
1
- import os, json, importlib.util, tempfile, traceback, torch, re, math
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- import gradio as gr
5
- import pandas as pd
6
- import plotly.graph_objects as go
7
- from huggingface_hub import hf_hub_download
8
- from safetensors.torch import load_file
9
- from transformers import AutoTokenizer, AutoModel
10
-
11
- # ===== Settings =====
12
- REPO_ID = os.getenv("REPO_ID", "Dusit-P/thai-sentiment-wcb")
13
- DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "WCB") # default model
14
- HF_TOKEN = os.getenv("HF_TOKEN", None)
15
-
16
- # ---- theme colors ----
17
- NEG_COLOR = "#F87171" # red-400
18
- POS_COLOR = "#34D399" # emerald-400
19
- TEMPLATE = "plotly_white"
20
-
21
- CACHE = {}
22
-
23
- # ---------- load models from common/models.py ----------
24
- def _import_models():
25
- if "models_module" in CACHE:
26
- return CACHE["models_module"]
27
- models_py = hf_hub_download(REPO_ID, filename="common/models.py", token=HF_TOKEN)
28
- spec = importlib.util.spec_from_file_location("models", models_py)
29
- mod = importlib.util.module_from_spec(spec)
30
- spec.loader.exec_module(mod)
31
- CACHE["models_module"] = mod
32
- return mod
33
-
34
- def load_model(model_name: str):
35
- key = f"model:{model_name}"
36
- if key in CACHE:
37
- return CACHE[key]
38
-
39
- cfg_path = hf_hub_download(REPO_ID, filename=f"{model_name}/config.json", token=HF_TOKEN)
40
- w_path = hf_hub_download(REPO_ID, filename=f"{model_name}/model.safetensors", token=HF_TOKEN)
41
-
42
- with open(cfg_path, "r", encoding="utf-8") as f:
43
- cfg = json.load(f)
44
-
45
- base_model = cfg.get("base_model", "airesearch/wangchanberta-base-att-spm-uncased")
46
- arch_name = cfg.get("architecture", model_name)
47
- tok = AutoTokenizer.from_pretrained(base_model)
48
-
49
- models = _import_models()
50
- model = models._build(arch_name, base_model, int(cfg.get("num_labels",2)), cfg.get("pooling_after_lstm","masked_mean"))
51
-
52
- state = load_file(w_path)
53
- model.load_state_dict(state, strict=False)
54
- model.eval()
55
-
56
- CACHE[key] = (model, tok, cfg)
57
- return CACHE[key]
58
-
59
- # ---------- helpers ----------
60
- def _format_pct(x: float) -> str:
61
- return f"{x*100:.2f}%"
62
-
63
- _INVALID_STRINGS = {"-", "--", "—", "n/a", "na", "null", "none", "nan", ".", "…", ""}
64
- _RE_HAS_LETTER = re.compile(r"[ก-๙A-Za-z]")
65
-
66
- def _norm_text(v) -> str:
67
- if v is None: return ""
68
- if isinstance(v, float) and math.isnan(v): return ""
69
- return str(v).strip()
70
-
71
- def _is_substantive_text(s: str, min_chars: int = 2) -> bool:
72
- if not s: return False
73
- if s.lower() in _INVALID_STRINGS: return False
74
- if not _RE_HAS_LETTER.search(s): return False
75
- if len(s.replace(" ", "")) < min_chars: return False
76
- return True
77
-
78
- def _clean_texts(texts):
79
- all_norm = [_norm_text(t) for t in texts]
80
- cleaned = [t for t in all_norm if _is_substantive_text(t)]
81
- skipped = len(all_norm) - len(cleaned)
82
- return cleaned, skipped
83
-
84
- def _make_figures(df: pd.DataFrame):
85
- total = len(df)
86
- neg = int((df["label"] == "negative").sum())
87
- pos = int((df["label"] == "positive").sum())
88
- neg_avg = pd.to_numeric(df["negative(%)"].str.rstrip("%"), errors="coerce").mean()
89
- pos_avg = pd.to_numeric(df["positive(%)"].str.rstrip("%"), errors="coerce").mean()
90
-
91
- info = (
92
- f"**Summary** \n"
93
- f"- Total: {total} \n"
94
- f"- Negative: {neg} \n"
95
- f"- Positive: {pos} \n"
96
- f"- Avg negative: {neg_avg:.2f}% \n"
97
- f"- Avg positive: {pos_avg:.2f}%"
98
- )
99
-
100
- fig_bar = go.Figure()
101
- fig_bar.add_bar(name="negative", x=["negative"], y=[neg], marker_color=NEG_COLOR)
102
- fig_bar.add_bar(name="positive", x=["positive"], y=[pos], marker_color=POS_COLOR)
103
- fig_bar.update_layout(barmode="group", title="Label counts", template=TEMPLATE)
104
-
105
- fig_pie = go.Figure(go.Pie(
106
- labels=["negative", "positive"],
107
- values=[neg, pos],
108
- hole=0.35,
109
- sort=False,
110
- marker=dict(colors=[NEG_COLOR, POS_COLOR])
111
- ))
112
- fig_pie.update_layout(title="Label share", template=TEMPLATE)
113
-
114
- return fig_bar, fig_pie, info
115
-
116
- # ---------- core prediction ----------
117
- def _predict_batch(texts, model_name, batch_size=32):
118
- model, tok, cfg = load_model(model_name)
119
- results = []
120
- for i in range(0, len(texts), batch_size):
121
- chunk = texts[i:i+batch_size]
122
- enc = tok(chunk, padding=True, truncation=True,
123
- max_length=cfg.get("max_length",128), return_tensors="pt")
124
- with torch.no_grad():
125
- logits = model(enc["input_ids"], enc["attention_mask"])
126
- probs = F.softmax(logits, dim=1).cpu().numpy()
127
- for txt, p in zip(chunk, probs):
128
- neg, pos = float(p[0]), float(p[1])
129
- label = "positive" if pos >= neg else "negative"
130
- results.append({
131
- "review": txt,
132
- "negative(%)": _format_pct(neg),
133
- "positive(%)": _format_pct(pos),
134
- "label": label,
135
- })
136
- return results
137
-
138
- def predict_one(text: str, model_choice: str):
139
- s = _norm_text(text)
140
- if not _is_substantive_text(s):
141
- return {"negative": 0.0, "positive": 0.0}, "invalid"
142
- out = _predict_batch([s], model_choice)[0]
143
- probs = {
144
- "negative": float(out["negative(%)"].rstrip("%"))/100.0,
145
- "positive": float(out["positive(%)"].rstrip("%"))/100.0,
146
- }
147
- return probs, out["label"]
148
-
149
- def predict_many(text_block: str, model_choice: str):
150
- raw_lines = (text_block or "").splitlines()
151
- cleaned, skipped = _clean_texts(raw_lines)
152
- if len(cleaned) == 0:
153
- empty = pd.DataFrame(columns=["review","negative(%)","positive(%)","label"])
154
- return empty, go.Figure(), go.Figure(), "No valid text"
155
- results = _predict_batch(cleaned, model_choice)
156
- df = pd.DataFrame(results)
157
- fig_bar, fig_pie, info_md = _make_figures(df)
158
- info_md = f"{info_md} \n- Skipped: {skipped}"
159
- return df, fig_bar, fig_pie, info_md
160
-
161
- # ---------- Gradio UI ----------
162
- AVAILABLE_CHOICES = ["WCB", "WCB_BiLSTM", "WCB_CNN_BiLSTM", "WCB_4Layer_BiLSTM"]
163
- if DEFAULT_MODEL not in AVAILABLE_CHOICES:
164
- DEFAULT_MODEL = "WCB"
165
-
166
- with gr.Blocks(title="Thai Sentiment GUI") as demo:
167
- gr.Markdown("### Thai Sentiment (WangchanBERTa Variants)")
168
-
169
- model_radio = gr.Radio(choices=AVAILABLE_CHOICES, value=DEFAULT_MODEL, label="เลือกโมเดล")
170
-
171
- with gr.Tab("Single"):
172
- t1 = gr.Textbox(lines=3, label="ข้อความรีวิว (1 ข้อความ)")
173
- probs = gr.Label(label="Probabilities")
174
- pred = gr.Textbox(label="Prediction", interactive=False)
175
- gr.Button("Predict").click(predict_one, [t1, model_radio], [probs, pred])
176
-
177
- with gr.Tab("Batch (หลายข้อความ)"):
178
- t2 = gr.Textbox(lines=8, label="พิมพ์หลายรีวิว (บรรทัดละ 1 รีวิว)")
179
- df2 = gr.Dataframe(label="ผลลัพธ์", interactive=False)
180
- bar2 = gr.Plot(label="Label counts (bar)")
181
- pie2 = gr.Plot(label="Label share (pie)")
182
- sum2 = gr.Markdown()
183
- gr.Button("Run Batch").click(predict_many, [t2, model_radio], [df2, bar2, pie2, sum2])
184
-
185
- if __name__ == "__main__":
186
- demo.launch()
 
1
+ import os, json, importlib.util, tempfile, traceback, torch, re, math
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import gradio as gr
5
+ import pandas as pd
6
+ import plotly.graph_objects as go
7
+ from huggingface_hub import hf_hub_download
8
+ from safetensors.torch import load_file
9
+ from transformers import AutoTokenizer, AutoModel
10
+
11
+ # ===== Settings =====
12
+ REPO_ID = os.getenv("REPO_ID", "Dusit-P/thai-sentiment")
13
+ DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "WCB") # default model
14
+ HF_TOKEN = os.getenv("HF_TOKEN", None)
15
+
16
+ # ---- theme colors ----
17
+ NEG_COLOR = "#F87171" # red-400
18
+ POS_COLOR = "#34D399" # emerald-400
19
+ TEMPLATE = "plotly_white"
20
+
21
+ CACHE = {}
22
+
23
# ---------- load models from common/models.py ----------
def _import_models():
    """Download the repo's ``common/models.py`` and import it as a module.

    The imported module is memoized in CACHE under "models_module" so the
    hub download and exec happen at most once per process.
    """
    cached = CACHE.get("models_module")
    if cached is not None:
        return cached
    path = hf_hub_download(REPO_ID, filename="common/models.py", token=HF_TOKEN)
    spec = importlib.util.spec_from_file_location("models", path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    CACHE["models_module"] = module
    return module
33
+
34
def load_model(model_name: str):
    """Fetch and build one model variant from the HF Hub.

    Returns a memoized ``(model, tokenizer, cfg)`` triple; subsequent calls
    for the same *model_name* hit CACHE and do no network I/O.
    """
    cache_key = f"model:{model_name}"
    hit = CACHE.get(cache_key)
    if hit is not None:
        return hit

    cfg_file = hf_hub_download(REPO_ID, filename=f"{model_name}/config.json", token=HF_TOKEN)
    weights_file = hf_hub_download(REPO_ID, filename=f"{model_name}/model.safetensors", token=HF_TOKEN)

    with open(cfg_file, "r", encoding="utf-8") as fh:
        cfg = json.load(fh)

    base_model = cfg.get("base_model", "airesearch/wangchanberta-base-att-spm-uncased")
    arch_name = cfg.get("architecture", model_name)
    tok = AutoTokenizer.from_pretrained(base_model)

    builders = _import_models()
    model = builders._build(
        arch_name,
        base_model,
        int(cfg.get("num_labels", 2)),
        cfg.get("pooling_after_lstm", "masked_mean"),
    )

    # NOTE(review): strict=False silently drops missing/unexpected keys when
    # loading weights — confirm that is intentional for these architectures.
    model.load_state_dict(load_file(weights_file), strict=False)
    model.eval()  # inference only; disables dropout etc.

    CACHE[cache_key] = (model, tok, cfg)
    return CACHE[cache_key]
58
+
59
+ # ---------- helpers ----------
60
+ def _format_pct(x: float) -> str:
61
+ return f"{x*100:.2f}%"
62
+
63
+ _INVALID_STRINGS = {"-", "--", "—", "n/a", "na", "null", "none", "nan", ".", "…", ""}
64
+ _RE_HAS_LETTER = re.compile(r"[ก-๙A-Za-z]")
65
+
66
+ def _norm_text(v) -> str:
67
+ if v is None: return ""
68
+ if isinstance(v, float) and math.isnan(v): return ""
69
+ return str(v).strip()
70
+
71
+ def _is_substantive_text(s: str, min_chars: int = 2) -> bool:
72
+ if not s: return False
73
+ if s.lower() in _INVALID_STRINGS: return False
74
+ if not _RE_HAS_LETTER.search(s): return False
75
+ if len(s.replace(" ", "")) < min_chars: return False
76
+ return True
77
+
78
def _clean_texts(texts):
    """Normalize raw inputs and drop non-substantive ones.

    Returns ``(cleaned, skipped)`` where *skipped* counts discarded entries.
    """
    normalized = [_norm_text(t) for t in texts]
    kept = list(filter(_is_substantive_text, normalized))
    return kept, len(normalized) - len(kept)
83
+
84
def _make_figures(df: pd.DataFrame):
    """Build summary visuals from a prediction dataframe.

    Returns ``(bar_figure, pie_figure, markdown_summary)``.
    """
    total = len(df)
    neg_count = int((df["label"] == "negative").sum())
    pos_count = int((df["label"] == "positive").sum())
    # Percent columns hold strings like "12.34%"; strip the sign to average.
    neg_avg = pd.to_numeric(df["negative(%)"].str.rstrip("%"), errors="coerce").mean()
    pos_avg = pd.to_numeric(df["positive(%)"].str.rstrip("%"), errors="coerce").mean()

    info = (
        f"**Summary** \n"
        f"- Total: {total} \n"
        f"- Negative: {neg_count} \n"
        f"- Positive: {pos_count} \n"
        f"- Avg negative: {neg_avg:.2f}% \n"
        f"- Avg positive: {pos_avg:.2f}%"
    )

    # Grouped bar chart of label counts.
    fig_bar = go.Figure()
    fig_bar.add_bar(name="negative", x=["negative"], y=[neg_count], marker_color=NEG_COLOR)
    fig_bar.add_bar(name="positive", x=["positive"], y=[pos_count], marker_color=POS_COLOR)
    fig_bar.update_layout(barmode="group", title="Label counts", template=TEMPLATE)

    # Donut chart of label share; sort=False keeps the fixed color mapping.
    fig_pie = go.Figure(
        go.Pie(
            labels=["negative", "positive"],
            values=[neg_count, pos_count],
            hole=0.35,
            sort=False,
            marker=dict(colors=[NEG_COLOR, POS_COLOR]),
        )
    )
    fig_pie.update_layout(title="Label share", template=TEMPLATE)

    return fig_bar, fig_pie, info
115
+
116
# ---------- core prediction ----------
def _predict_batch(texts, model_name, batch_size=32):
    """Run sentiment inference over *texts* in mini-batches.

    Returns one dict per input with the review text, formatted class
    percentages, and the winning label.
    """
    model, tok, cfg = load_model(model_name)
    max_len = cfg.get("max_length", 128)
    rows = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        enc = tok(batch, padding=True, truncation=True,
                  max_length=max_len, return_tensors="pt")
        with torch.no_grad():
            logits = model(enc["input_ids"], enc["attention_mask"])
            probs = F.softmax(logits, dim=1).cpu().numpy()
        for text, p in zip(batch, probs):
            neg, pos = float(p[0]), float(p[1])
            rows.append({
                "review": text,
                "negative(%)": _format_pct(neg),
                "positive(%)": _format_pct(pos),
                # Ties resolve to "positive" (pos >= neg).
                "label": "positive" if pos >= neg else "negative",
            })
    return rows
137
+
138
def predict_one(text: str, model_choice: str):
    """Classify a single review.

    Returns ``({"negative": p, "positive": p}, label)``; placeholder input
    yields zero probabilities and the label "invalid".
    """
    cleaned = _norm_text(text)
    if not _is_substantive_text(cleaned):
        return {"negative": 0.0, "positive": 0.0}, "invalid"
    row = _predict_batch([cleaned], model_choice)[0]
    # Parse the formatted percentage strings back to 0..1 floats for gr.Label.
    probs = {
        name: float(row[f"{name}(%)"].rstrip("%")) / 100.0
        for name in ("negative", "positive")
    }
    return probs, row["label"]
148
+
149
def predict_many(text_block: str, model_choice: str):
    """Classify one review per input line.

    Returns ``(dataframe, bar_figure, pie_figure, summary_markdown)``; when
    no line survives cleaning, empty outputs and "No valid text" come back.
    """
    lines = (text_block or "").splitlines()
    cleaned, skipped = _clean_texts(lines)
    if not cleaned:
        empty = pd.DataFrame(columns=["review","negative(%)","positive(%)","label"])
        return empty, go.Figure(), go.Figure(), "No valid text"
    df = pd.DataFrame(_predict_batch(cleaned, model_choice))
    fig_bar, fig_pie, info_md = _make_figures(df)
    info_md = f"{info_md} \n- Skipped: {skipped}"
    return df, fig_bar, fig_pie, info_md
160
+
161
# ---------- Gradio UI ----------
AVAILABLE_CHOICES = ["WCB", "WCB_BiLSTM", "WCB_CNN_BiLSTM", "WCB_4Layer_BiLSTM"]
# Fall back to the baseline model if the env override names an unknown variant.
if DEFAULT_MODEL not in AVAILABLE_CHOICES:
    DEFAULT_MODEL = "WCB"

with gr.Blocks(title="Thai Sentiment GUI") as demo:
    gr.Markdown("### Thai Sentiment (WangchanBERTa Variants)")

    # Model selector shared by both tabs.  (Thai label: "choose model".)
    model_radio = gr.Radio(choices=AVAILABLE_CHOICES, value=DEFAULT_MODEL, label="เลือกโมเดล")

    # Single-review tab: one textbox in, probabilities + label out.
    with gr.Tab("Single"):
        t1 = gr.Textbox(lines=3, label="ข้อความรีวิว (1 ข้อความ)")  # Thai: "review text (1 message)"
        probs = gr.Label(label="Probabilities")
        pred = gr.Textbox(label="Prediction", interactive=False)
        gr.Button("Predict").click(predict_one, [t1, model_radio], [probs, pred])

    # Batch tab: one review per line, results table + charts + summary out.
    with gr.Tab("Batch (หลายข้อความ)"):
        t2 = gr.Textbox(lines=8, label="พิมพ์หลายรีวิว (บรรทัดละ 1 รีวิว)")  # Thai: "type several reviews (one per line)"
        df2 = gr.Dataframe(label="ผลลัพธ์", interactive=False)  # Thai: "results"
        bar2 = gr.Plot(label="Label counts (bar)")
        pie2 = gr.Plot(label="Label share (pie)")
        sum2 = gr.Markdown()
        gr.Button("Run Batch").click(predict_many, [t2, model_radio], [df2, bar2, pie2, sum2])

if __name__ == "__main__":
    demo.launch()