Upload 13 files
Browse files- LICENSE +1 -0
- README.md +5 -0
- WCB/config.json +18 -0
- WCB/model.safetensors +3 -0
- WCB_4Layer_BiLSTM/config.json +18 -0
- WCB_4Layer_BiLSTM/model.safetensors +3 -0
- WCB_BiLSTM/config.json +18 -0
- WCB_BiLSTM/model.safetensors +3 -0
- WCB_CNN_BiLSTM/config.json +18 -0
- WCB_CNN_BiLSTM/model.safetensors +3 -0
- common/models.py +58 -0
- infer.py +44 -0
- requirements.txt +3 -0
LICENSE
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Apache-2.0
|
README.md
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Thai Sentiment (WangchanBERTa + LSTM Heads)
|
| 2 |
+
|
| 3 |
+
## Install
|
| 4 |
+
```bash
pip install -r requirements.txt
```
|
WCB/config.json
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "custom_wcb_sentiment",
|
| 3 |
+
"base_model": "airesearch/wangchanberta-base-att-spm-uncased",
|
| 4 |
+
"architecture": "WCB",
|
| 5 |
+
"num_labels": 2,
|
| 6 |
+
"id2label": {
|
| 7 |
+
"0": "NEG",
|
| 8 |
+
"1": "POS"
|
| 9 |
+
},
|
| 10 |
+
"label2id": {
|
| 11 |
+
"NEG": 0,
|
| 12 |
+
"POS": 1
|
| 13 |
+
},
|
| 14 |
+
"max_length": 128,
|
| 15 |
+
"pooling_after_lstm": "masked_mean",
|
| 16 |
+
"export_source_checkpoint": "best_m1_wcb_5models_wcb_comparison.pth",
|
| 17 |
+
"export_experiment": "5models_wcb_comparison"
|
| 18 |
+
}
|
WCB/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:21f68a987efcb981173b90aa55d2827622265a108842e97dd27975f9ca99bfd5
|
| 3 |
+
size 421007280
|
WCB_4Layer_BiLSTM/config.json
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "custom_wcb_sentiment",
|
| 3 |
+
"base_model": "airesearch/wangchanberta-base-att-spm-uncased",
|
| 4 |
+
"architecture": "WCB_4Layer_BiLSTM",
|
| 5 |
+
"num_labels": 2,
|
| 6 |
+
"id2label": {
|
| 7 |
+
"0": "NEG",
|
| 8 |
+
"1": "POS"
|
| 9 |
+
},
|
| 10 |
+
"label2id": {
|
| 11 |
+
"NEG": 0,
|
| 12 |
+
"POS": 1
|
| 13 |
+
},
|
| 14 |
+
"max_length": 128,
|
| 15 |
+
"pooling_after_lstm": "masked_mean",
|
| 16 |
+
"export_source_checkpoint": "best_m4_wcb_4layer_bilstm_5models_wcb_comparison.pth",
|
| 17 |
+
"export_experiment": "5models_wcb_comparison"
|
| 18 |
+
}
|
WCB_4Layer_BiLSTM/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0450384ee6ee1dea3352ef6805f92efcace76c6b80fe2b08d1ce7b1d4e340254
|
| 3 |
+
size 424682216
|
WCB_BiLSTM/config.json
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "custom_wcb_sentiment",
|
| 3 |
+
"base_model": "airesearch/wangchanberta-base-att-spm-uncased",
|
| 4 |
+
"architecture": "WCB_BiLSTM",
|
| 5 |
+
"num_labels": 2,
|
| 6 |
+
"id2label": {
|
| 7 |
+
"0": "NEG",
|
| 8 |
+
"1": "POS"
|
| 9 |
+
},
|
| 10 |
+
"label2id": {
|
| 11 |
+
"NEG": 0,
|
| 12 |
+
"POS": 1
|
| 13 |
+
},
|
| 14 |
+
"max_length": 128,
|
| 15 |
+
"pooling_after_lstm": "masked_mean",
|
| 16 |
+
"export_source_checkpoint": "best_m2_wcb_bilstm_5models_wcb_comparison.pth",
|
| 17 |
+
"export_experiment": "5models_wcb_comparison"
|
| 18 |
+
}
|
WCB_BiLSTM/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bc6742020cc8966603c74b03bd8d091c9d19c451b307f0e2103e05200d07c090
|
| 3 |
+
size 424682128
|
WCB_CNN_BiLSTM/config.json
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "custom_wcb_sentiment",
|
| 3 |
+
"base_model": "airesearch/wangchanberta-base-att-spm-uncased",
|
| 4 |
+
"architecture": "WCB_CNN_BiLSTM",
|
| 5 |
+
"num_labels": 2,
|
| 6 |
+
"id2label": {
|
| 7 |
+
"0": "NEG",
|
| 8 |
+
"1": "POS"
|
| 9 |
+
},
|
| 10 |
+
"label2id": {
|
| 11 |
+
"NEG": 0,
|
| 12 |
+
"POS": 1
|
| 13 |
+
},
|
| 14 |
+
"max_length": 128,
|
| 15 |
+
"pooling_after_lstm": "masked_mean",
|
| 16 |
+
"export_source_checkpoint": "best_m3_wcb_cnn_bilstm_5models_wcb_comparison.pth",
|
| 17 |
+
"export_experiment": "5models_wcb_comparison"
|
| 18 |
+
}
|
WCB_CNN_BiLSTM/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bc09bb197646b9ea0f534396ebac19e5eb6e5162616bcf6a62ed6941409316e2
|
| 3 |
+
size 423569368
|
common/models.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# common/models.py
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel

# Base settings; these must match the values used during training
# (they are also echoed in each exported config.json).
BASE_MODEL_NAME = "airesearch/wangchanberta-base-att-spm-uncased"
POOLING_AFTER_LSTM = "masked_mean"
|
| 9 |
+
|
| 10 |
+
class BaseHead(nn.Module):
    """BiLSTM classification head applied on top of transformer token states.

    Runs a single bidirectional LSTM over the sequence, pools the LSTM
    outputs ('cls' token, masked mean, or masked max), then projects the
    pooled vector to class logits through dropout + a linear layer.
    """

    def __init__(self, hidden_in, hidden_lstm=128, num_classes=2, dropout=0.3, pooling='masked_mean'):
        super().__init__()
        # Attribute names are part of the state_dict contract — keep them.
        self.lstm = nn.LSTM(hidden_in, hidden_lstm, bidirectional=True, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_lstm*2, num_classes)
        assert pooling in ['cls','masked_mean','masked_max']
        self.pooling = pooling

    def pool(self, x, mask):
        """Collapse (batch, seq, feat) -> (batch, feat) per the configured strategy."""
        if self.pooling == 'cls':
            return x[:, 0, :]
        expanded = mask.unsqueeze(-1)
        if self.pooling == 'masked_mean':
            total = (x * expanded).sum(1)
            # clamp avoids division by zero when a row's mask is all zeros
            count = expanded.sum(1).clamp(min=1e-6)
            return total / count
        # masked_max: suppress padded positions before taking the max
        return x.masked_fill(expanded == 0, -1e9).max(1).values

    def forward_after_bert(self, seq, mask):
        """Return logits for pre-computed encoder states ``seq`` with attention ``mask``."""
        lstm_out, _ = self.lstm(seq)
        pooled = self.pool(lstm_out, mask)
        return self.fc(self.dropout(pooled))
|
| 28 |
+
|
| 29 |
+
class Model1Baseline(nn.Module):
    """Baseline model: WangchanBERTa encoder followed directly by a BiLSTM head."""

    def __init__(self, name=BASE_MODEL_NAME, hidden=128, dropout=0.3, classes=2, pooling=POOLING_AFTER_LSTM):
        super().__init__()
        self.bert = AutoModel.from_pretrained(name)
        self.head = BaseHead(self.bert.config.hidden_size, hidden, classes, dropout, pooling)

    def forward(self, ids, mask):
        """Return class logits for token ids ``ids`` with attention ``mask``."""
        encoded = self.bert(input_ids=ids, attention_mask=mask)
        return self.head.forward_after_bert(encoded.last_hidden_state, mask)
|
| 37 |
+
|
| 38 |
+
class Model2CNNBiLSTM(nn.Module):
    """WangchanBERTa encoder -> two stacked 1D convolutions -> BiLSTM head."""

    def __init__(self, name=BASE_MODEL_NAME, hidden=128, dropout=0.3, classes=2, pooling=POOLING_AFTER_LSTM):
        super().__init__()
        self.bert = AutoModel.from_pretrained(name)
        emb_dim = self.bert.config.hidden_size
        # Attribute names c1/c2/head are part of the state_dict contract.
        self.c1 = nn.Conv1d(emb_dim, 128, 3, padding=1)  # kernel 3, same-length output
        self.c2 = nn.Conv1d(128, 128, 5, padding=2)      # kernel 5, same-length output
        self.head = BaseHead(128, hidden, classes, dropout, pooling)

    def forward(self, ids, mask):
        """Return class logits for token ids ``ids`` with attention ``mask``."""
        states = self.bert(input_ids=ids, attention_mask=mask).last_hidden_state
        # Conv1d expects (batch, channels, seq); transpose in and back out.
        conv = F.relu(self.c1(states.transpose(1, 2)))
        conv = F.relu(self.c2(conv)).transpose(1, 2)
        return self.head.forward_after_bert(conv, mask)
|
| 51 |
+
|
| 52 |
+
def create_model_by_name(model_name):
    """Instantiate the model architecture identified by ``model_name``.

    Accepts both the training-time names ("Model1_Baseline",
    "Model2_CNN_BiLSTM") and the ``"architecture"`` values written into the
    exported config.json files ("WCB", "WCB_CNN_BiLSTM") — the exported
    configs use the latter, so the factory must recognize them too.

    Args:
        model_name: architecture identifier string.

    Returns:
        A freshly constructed ``nn.Module`` (pretrained encoder is
        downloaded/loaded inside the model constructor).

    Raises:
        ValueError: if ``model_name`` matches no known architecture.
    """
    if model_name in ("Model1_Baseline", "WCB"):
        return Model1Baseline()
    if model_name in ("Model2_CNN_BiLSTM", "WCB_CNN_BiLSTM"):
        return Model2CNNBiLSTM()
    raise ValueError(f"Unknown model name: {model_name}")
|
infer.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# infer.py
import os, sys, json, torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from safetensors.torch import load_file

# Make the shared architecture definitions in common/ importable
# regardless of the current working directory.
sys.path.append(os.path.join(os.path.dirname(__file__), "common"))
from models import create_model_by_name
|
| 10 |
+
|
| 11 |
+
def load_model(model_dir: str):
    """Load an exported sentiment model from ``model_dir``.

    Expects the directory to contain ``config.json`` and
    ``model.safetensors`` (the layout of the WCB* folders in this repo).

    Args:
        model_dir: path to one exported model directory.

    Returns:
        (model, tok, cfg): the model in eval mode, its tokenizer, and the
        parsed config dict.

    Raises:
        FileNotFoundError: if either expected file is missing.
    """
    cfg_path = os.path.join(model_dir, "config.json")
    w_path = os.path.join(model_dir, "model.safetensors")
    if not (os.path.exists(cfg_path) and os.path.exists(w_path)):
        raise FileNotFoundError("config.json หรือ model.safetensors ไม่ครบ")

    with open(cfg_path, "r", encoding="utf-8") as f:
        cfg = json.load(f)

    tok = AutoTokenizer.from_pretrained(cfg["base_model"])
    # The exported config.json files store the architecture under
    # "architecture"; fall back to the legacy "arch" key for old exports.
    arch = cfg.get("architecture", cfg.get("arch"))
    model = create_model_by_name(arch)
    state = load_file(w_path)
    model.load_state_dict(state)
    model.eval()
    return model, tok, cfg
|
| 26 |
+
|
| 27 |
+
def predict(texts, model, tok, cfg):
    """Run sentiment inference on a batch of raw strings.

    Args:
        texts: list of input strings.
        model: loaded model returning logits of shape (batch, num_labels).
        tok: the matching tokenizer.
        cfg: parsed config dict; its "max_length" entry (128 in the shipped
            configs) caps the tokenized sequence length.

    Returns:
        (pred, prob): numpy arrays — predicted label ids of shape (batch,)
        and softmax probabilities of shape (batch, num_labels).
    """
    # Exported configs use "max_length"; accept a legacy "max_len" key and
    # default to 128 (the training-time value) if neither is present.
    max_len = cfg.get("max_length", cfg.get("max_len", 128))
    enc = tok(texts, padding=True, truncation=True, max_length=max_len, return_tensors="pt")
    with torch.no_grad():
        logits = model(enc["input_ids"], enc["attention_mask"])
        prob = F.softmax(logits, dim=1).cpu().numpy()
        pred = prob.argmax(1)
    return pred, prob
|
| 34 |
+
|
| 35 |
+
if __name__ == "__main__":
    # Pick a model folder by name; must match one of the exported
    # directories in this repo (WCB, WCB_BiLSTM, WCB_4Layer_BiLSTM,
    # WCB_CNN_BiLSTM). The old default "cnn_bilstm" matched none of them.
    MODEL_DIR = sys.argv[1] if len(sys.argv) > 1 else "WCB_CNN_BiLSTM"

    model, tok, cfg = load_model(MODEL_DIR)
    xs = ["อาหารอร่อยมาก บริการดี", "ไม่ประทับใจเลย ช้ามาก"]
    y, p = predict(xs, model, tok, cfg)
    # Prefer the label names shipped in config.json ("id2label": NEG/POS);
    # fall back to the generic ordering if the key is absent.
    id2label = cfg.get("id2label")
    if id2label:
        labels = [id2label[str(i)] for i in range(len(id2label))]
    else:
        labels = ["negative", "positive"]
    for t, yy, pp in zip(xs, y, p):
        print(f"{t} => {labels[yy]} | prob={pp}")
|
requirements.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
transformers
|
| 3 |
+
safetensors
|