Dusit-P committed
Commit 48e0979 · verified · 1 Parent(s): 257b029

Upload 13 files
LICENSE ADDED
@@ -0,0 +1 @@
+ Apache-2.0
README.md ADDED
@@ -0,0 +1,6 @@
+ # Thai Sentiment (WangchanBERTa + LSTM Heads)
+
+ ## Install
+ ```bash
+ pip install -r requirements.txt
+ ```
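The README stops at installation; here is a minimal quick-start sketch, assuming the repo is cloned and run from its root. It uses the `load_model`/`predict` helpers from `infer.py` below; only the `WCB` and `WCB_CNN_BiLSTM` folders map onto classes defined in `common/models.py`.

```python
# Quick-start sketch: load one exported head and classify a Thai sentence.
from infer import load_model, predict

model, tok, cfg = load_model("WCB_CNN_BiLSTM")  # or "WCB"
pred, prob = predict(["อาหารอร่อยมาก บริการดี"], model, tok, cfg)
print(pred, prob)  # class indices plus softmax probabilities
```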
WCB/config.json ADDED
@@ -0,0 +1,18 @@
+ {
+   "model_type": "custom_wcb_sentiment",
+   "base_model": "airesearch/wangchanberta-base-att-spm-uncased",
+   "architecture": "WCB",
+   "num_labels": 2,
+   "id2label": {
+     "0": "NEG",
+     "1": "POS"
+   },
+   "label2id": {
+     "NEG": 0,
+     "POS": 1
+   },
+   "max_length": 128,
+   "pooling_after_lstm": "masked_mean",
+   "export_source_checkpoint": "best_m1_wcb_5models_wcb_comparison.pth",
+   "export_experiment": "5models_wcb_comparison"
+ }
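All four exported heads share this config schema; a short sketch of reading the label maps, keeping in mind that JSON object keys are strings, so an integer class index must be stringified before the `id2label` lookup:

```python
import json

with open("WCB/config.json", encoding="utf-8") as f:
    cfg = json.load(f)

pred_id = 1
print(cfg["id2label"][str(pred_id)])          # "POS" (keys are strings in JSON)
print(cfg["label2id"][cfg["id2label"]["0"]])  # round-trips back to 0
```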
WCB/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:21f68a987efcb981173b90aa55d2827622265a108842e97dd27975f9ca99bfd5
+ size 421007280
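The `model.safetensors` entries are Git LFS pointers: per the LFS v1 spec, `oid` is the SHA-256 of the real blob and `size` is its length in bytes. A small sketch for checking a downloaded copy against its pointer:

```python
# Verify a downloaded LFS file against its pointer (oid = SHA-256, size = bytes).
import hashlib
import os

def verify_lfs_file(path: str, oid_hex: str, size: int) -> bool:
    if os.path.getsize(path) != size:
        return False
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest() == oid_hex

print(verify_lfs_file(
    "WCB/model.safetensors",
    "21f68a987efcb981173b90aa55d2827622265a108842e97dd27975f9ca99bfd5",
    421007280,
))
```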
WCB_4Layer_BiLSTM/config.json ADDED
@@ -0,0 +1,18 @@
+ {
+   "model_type": "custom_wcb_sentiment",
+   "base_model": "airesearch/wangchanberta-base-att-spm-uncased",
+   "architecture": "WCB_4Layer_BiLSTM",
+   "num_labels": 2,
+   "id2label": {
+     "0": "NEG",
+     "1": "POS"
+   },
+   "label2id": {
+     "NEG": 0,
+     "POS": 1
+   },
+   "max_length": 128,
+   "pooling_after_lstm": "masked_mean",
+   "export_source_checkpoint": "best_m4_wcb_4layer_bilstm_5models_wcb_comparison.pth",
+   "export_experiment": "5models_wcb_comparison"
+ }
WCB_4Layer_BiLSTM/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0450384ee6ee1dea3352ef6805f92efcace76c6b80fe2b08d1ce7b1d4e340254
+ size 424682216
WCB_BiLSTM/config.json ADDED
@@ -0,0 +1,18 @@
+ {
+   "model_type": "custom_wcb_sentiment",
+   "base_model": "airesearch/wangchanberta-base-att-spm-uncased",
+   "architecture": "WCB_BiLSTM",
+   "num_labels": 2,
+   "id2label": {
+     "0": "NEG",
+     "1": "POS"
+   },
+   "label2id": {
+     "NEG": 0,
+     "POS": 1
+   },
+   "max_length": 128,
+   "pooling_after_lstm": "masked_mean",
+   "export_source_checkpoint": "best_m2_wcb_bilstm_5models_wcb_comparison.pth",
+   "export_experiment": "5models_wcb_comparison"
+ }
WCB_BiLSTM/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bc6742020cc8966603c74b03bd8d091c9d19c451b307f0e2103e05200d07c090
+ size 424682128
WCB_CNN_BiLSTM/config.json ADDED
@@ -0,0 +1,18 @@
+ {
+   "model_type": "custom_wcb_sentiment",
+   "base_model": "airesearch/wangchanberta-base-att-spm-uncased",
+   "architecture": "WCB_CNN_BiLSTM",
+   "num_labels": 2,
+   "id2label": {
+     "0": "NEG",
+     "1": "POS"
+   },
+   "label2id": {
+     "NEG": 0,
+     "POS": 1
+   },
+   "max_length": 128,
+   "pooling_after_lstm": "masked_mean",
+   "export_source_checkpoint": "best_m3_wcb_cnn_bilstm_5models_wcb_comparison.pth",
+   "export_experiment": "5models_wcb_comparison"
+ }
WCB_CNN_BiLSTM/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bc09bb197646b9ea0f534396ebac19e5eb6e5162616bcf6a62ed6941409316e2
+ size 423569368
common/models.py ADDED
@@ -0,0 +1,76 @@
+ # common/models.py
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers import AutoModel
+
+ # Base settings; these must match the values used during training.
+ BASE_MODEL_NAME = "airesearch/wangchanberta-base-att-spm-uncased"
+ POOLING_AFTER_LSTM = "masked_mean"
+
+ class BaseHead(nn.Module):
+     """BiLSTM head with a pooling step and a linear classifier."""
+     def __init__(self, hidden_in, hidden_lstm=128, num_classes=2, dropout=0.3, pooling='masked_mean'):
+         super().__init__()
+         assert pooling in ['cls', 'masked_mean', 'masked_max']
+         self.lstm = nn.LSTM(hidden_in, hidden_lstm, bidirectional=True, batch_first=True)
+         self.dropout = nn.Dropout(dropout)
+         self.fc = nn.Linear(hidden_lstm * 2, num_classes)
+         self.pooling = pooling
+
+     def pool(self, x, mask):
+         if self.pooling == 'cls':
+             return x[:, 0, :]  # first-token representation
+         # Cast so the integer attention mask multiplies cleanly with float activations.
+         mask = mask.unsqueeze(-1).to(x.dtype)
+         if self.pooling == 'masked_mean':
+             # Average over real tokens only; padding contributes nothing.
+             s = (x * mask).sum(1)
+             d = mask.sum(1).clamp(min=1e-6)
+             return s / d
+         # masked_max: push padding positions toward -inf before taking the max.
+         x = x.masked_fill(mask == 0, -1e9)
+         return x.max(1).values
+
+     def forward_after_bert(self, seq, mask):
+         x, _ = self.lstm(seq)
+         x = self.pool(x, mask)
+         return self.fc(self.dropout(x))
+
+ class Model1Baseline(nn.Module):
+     """WangchanBERTa encoder + BiLSTM head."""
+     def __init__(self, name=BASE_MODEL_NAME, hidden=128, dropout=0.3, classes=2, pooling=POOLING_AFTER_LSTM):
+         super().__init__()
+         self.bert = AutoModel.from_pretrained(name)
+         self.head = BaseHead(self.bert.config.hidden_size, hidden, classes, dropout, pooling)
+
+     def forward(self, ids, mask):
+         out = self.bert(input_ids=ids, attention_mask=mask)
+         return self.head.forward_after_bert(out.last_hidden_state, mask)
+
+ class Model2CNNBiLSTM(nn.Module):
+     """WangchanBERTa encoder + two Conv1d layers + BiLSTM head."""
+     def __init__(self, name=BASE_MODEL_NAME, hidden=128, dropout=0.3, classes=2, pooling=POOLING_AFTER_LSTM):
+         super().__init__()
+         self.bert = AutoModel.from_pretrained(name)
+         H = self.bert.config.hidden_size
+         self.c1 = nn.Conv1d(H, 128, 3, padding=1)
+         self.c2 = nn.Conv1d(128, 128, 5, padding=2)
+         self.head = BaseHead(128, hidden, classes, dropout, pooling)
+
+     def forward(self, ids, mask):
+         out = self.bert(input_ids=ids, attention_mask=mask).last_hidden_state
+         # Conv1d expects (batch, channels, seq), hence the transposes.
+         x = F.relu(self.c1(out.transpose(1, 2)))
+         x = F.relu(self.c2(x)).transpose(1, 2)
+         return self.head.forward_after_bert(x, mask)
+
+ def create_model_by_name(model_name):
+     # Accept both the training-script names and the "architecture" strings
+     # written into the exported config.json files. Only these two heads are
+     # defined in this file; the WCB -> Model1 mapping is inferred from the
+     # "best_m1_wcb_*" checkpoint name.
+     if model_name in ("Model1_Baseline", "WCB"):
+         return Model1Baseline()
+     if model_name in ("Model2_CNN_BiLSTM", "WCB_CNN_BiLSTM"):
+         return Model2CNNBiLSTM()
+     raise ValueError(f"Unknown model name: {model_name}")
infer.py ADDED
@@ -0,0 +1,45 @@
+ # infer.py
+ import os, sys, json, torch
+ import torch.nn.functional as F
+ from transformers import AutoTokenizer
+ from safetensors.torch import load_file
+
+ # Reuse the shared architecture definitions.
+ sys.path.append(os.path.join(os.path.dirname(__file__), "common"))
+ from models import create_model_by_name
+
+ def load_model(model_dir: str):
+     cfg_path = os.path.join(model_dir, "config.json")
+     w_path = os.path.join(model_dir, "model.safetensors")
+     if not (os.path.exists(cfg_path) and os.path.exists(w_path)):
+         raise FileNotFoundError("config.json or model.safetensors is missing")
+
+     with open(cfg_path, "r", encoding="utf-8") as f:
+         cfg = json.load(f)
+
+     tok = AutoTokenizer.from_pretrained(cfg["base_model"])
+     model = create_model_by_name(cfg["architecture"])  # key as written in config.json
+     state = load_file(w_path)
+     model.load_state_dict(state)
+     model.eval()
+     return model, tok, cfg
+
+ def predict(texts, model, tok, cfg):
+     enc = tok(texts, padding=True, truncation=True, max_length=cfg["max_length"], return_tensors="pt")
+     with torch.no_grad():
+         logits = model(enc["input_ids"], enc["attention_mask"])
+     prob = F.softmax(logits, dim=1).cpu().numpy()
+     pred = prob.argmax(1)
+     return pred, prob
+
+ if __name__ == "__main__":
+     # Pick a model folder; "WCB" and "WCB_CNN_BiLSTM" map onto the classes in common/models.py.
+     MODEL_DIR = sys.argv[1] if len(sys.argv) > 1 else "WCB_CNN_BiLSTM"
+
+     model, tok, cfg = load_model(MODEL_DIR)
+     # "The food is delicious, great service" / "Not impressed at all, so slow"
+     xs = ["อาหารอร่อยมาก บริการดี", "ไม่ประทับใจเลย ช้ามาก"]
+     y, p = predict(xs, model, tok, cfg)
+     labels = [cfg["id2label"][str(i)] for i in range(cfg["num_labels"])]
+     for t, yy, pp in zip(xs, y, p):
+         print(f"{t} => {labels[yy]} | prob={pp}")
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ torch
+ transformers
+ safetensors
+ sentencepiece  # needed by the SPM-based WangchanBERTa tokenizer