File size: 5,095 Bytes
ec1e0cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126

import json
import warnings
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import load_file as safe_load
from transformers import AutoModel, AutoTokenizer

# -------- Base pooling --------
class PoolingLayer(nn.Module):
    """Reduce a sequence of vectors to one vector per row.

    Modes:
        ``"cls"``         -- take position 0 along axis 1 (the mask is ignored).
        ``"masked_mean"`` -- average over axis 1, counting only positions where
                             the mask is non-zero.
        ``"masked_max"``  -- element-wise max over axis 1, ignoring positions
                             where the mask is zero.

    ``forward`` expects ``x`` with three axes (indexed as ``x[:, 0, :]`` and
    reduced over axis 1) and ``mask`` with the first two axes of ``x``; the
    mask may be bool, integer or floating point.
    """

    _MODES = ("cls", "masked_mean", "masked_max")

    def __init__(self, pooling="masked_mean"):
        super().__init__()
        # Raise rather than assert: assertions are stripped under ``python -O``.
        if pooling not in self._MODES:
            raise ValueError(
                f"Unknown pooling {pooling!r}; expected one of {self._MODES}"
            )
        self.pooling = pooling

    def forward(self, x, mask):
        if self.pooling == "cls":
            return x[:, 0, :]
        # Cast to x's dtype so the sums below are floating point even when the
        # tokenizer supplies an integer (or bool) attention mask.
        mask = mask.unsqueeze(-1).to(x.dtype)
        if self.pooling == "masked_mean":
            total = (x * mask).sum(1)
            # The clamp keeps fully-masked rows at 0 instead of dividing by 0.
            count = mask.sum(1).clamp(min=1e-6)
            return total / count
        # Keep the original -1e9 fill, but clip it to the dtype's range:
        # float16 cannot represent -1e9 and masked_fill would raise.
        fill = max(-1e9, torch.finfo(x.dtype).min)
        x = x.masked_fill(mask == 0, fill)
        return x.max(1).values

# -------- Model defs --------
def _base_model(name):
    """Load the pretrained encoder identified by *name*.

    *name* is passed unchanged to ``AutoModel.from_pretrained``.
    """
    encoder = AutoModel.from_pretrained(name)
    return encoder

class Model1_WCB(nn.Module):
    """Classifier: encoder first-token vector -> dropout -> linear head.

    Args:
        name: pretrained encoder identifier forwarded to ``_base_model``.
        num_labels: number of output logits.
        dropout: dropout probability applied before the linear head.
    """

    def __init__(self, name, num_labels=2, dropout=0.3):
        super().__init__()
        self.bert = _base_model(name)
        hidden_size = self.bert.config.hidden_size
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size, num_labels)

    def forward(self, ids, mask):
        """Return logits computed from position 0 of the encoder output."""
        encoded = self.bert(input_ids=ids, attention_mask=mask)
        first_token = encoded.last_hidden_state[:, 0, :]
        regularised = self.dropout(first_token)
        return self.fc(regularised)

class Model2_WCB_BiLSTM(nn.Module):
    """Classifier: encoder -> BiLSTM -> masked pooling -> dropout -> linear head.

    Args:
        name: pretrained encoder identifier forwarded to ``_base_model``.
        num_labels: number of output logits.
        hidden: LSTM hidden size per direction.
        dropout: dropout probability applied before the linear head.
        pooling: ``PoolingLayer`` mode applied to the LSTM outputs.
    """

    def __init__(self, name, num_labels=2, hidden=128, dropout=0.3, pooling="masked_mean"):
        super().__init__()
        self.bert = _base_model(name)
        enc_dim = self.bert.config.hidden_size
        self.lstm = nn.LSTM(enc_dim, hidden, bidirectional=True, batch_first=True)
        self.pool = PoolingLayer(pooling)
        self.dropout = nn.Dropout(dropout)
        # Forward and backward LSTM states are concatenated -> 2 * hidden.
        self.fc = nn.Linear(2 * hidden, num_labels)

    def forward(self, ids, mask):
        """Return logits for token ids ``ids`` with attention ``mask``."""
        token_states = self.bert(input_ids=ids, attention_mask=mask).last_hidden_state
        # NOTE(review): the LSTM also runs over padded positions (no packing).
        # If batches are right-padded, the backward direction reads padding
        # before the real tokens; the mask only affects pooling below.
        recurrent, _ = self.lstm(token_states)
        pooled = self.pool(recurrent, mask)
        return self.fc(self.dropout(pooled))

class Model3_WCB_CNN_BiLSTM(nn.Module):
    """Classifier: encoder -> two Conv1d+ReLU -> BiLSTM -> masked pooling -> head.

    Args:
        name: pretrained encoder identifier forwarded to ``_base_model``.
        num_labels: number of output logits.
        hidden: LSTM hidden size per direction.
        dropout: dropout probability applied before the linear head.
        pooling: ``PoolingLayer`` mode applied to the LSTM outputs.
    """

    def __init__(self, name, num_labels=2, hidden=128, dropout=0.3, pooling="masked_mean"):
        super().__init__()
        self.bert = _base_model(name)
        enc_dim = self.bert.config.hidden_size
        # Kernel 3 / padding 1 and kernel 5 / padding 2 keep the sequence length.
        self.c1 = nn.Conv1d(enc_dim, 128, kernel_size=3, padding=1)
        self.c2 = nn.Conv1d(128, 128, kernel_size=5, padding=2)
        self.lstm = nn.LSTM(128, hidden, bidirectional=True, batch_first=True)
        self.pool = PoolingLayer(pooling)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(2 * hidden, num_labels)

    def forward(self, ids, mask):
        """Return logits for token ids ``ids`` with attention ``mask``."""
        token_states = self.bert(input_ids=ids, attention_mask=mask).last_hidden_state
        # Conv1d wants channels on axis 1, so swap axes 1 and 2 around the convs.
        channels_first = token_states.transpose(1, 2)
        conv_out = F.relu(self.c1(channels_first))
        conv_out = F.relu(self.c2(conv_out))
        seq_first = conv_out.transpose(1, 2)
        recurrent, _ = self.lstm(seq_first)
        pooled = self.pool(recurrent, mask)
        return self.fc(self.dropout(pooled))

class Model4_WCB_4Layer_BiLSTM(nn.Module):
    """Classifier over a learned mix of the encoder's last four hidden states.

    The last four entries of ``hidden_states`` are combined with
    softmax-normalised learnable weights, then fed through
    BiLSTM -> masked pooling -> dropout -> linear head.

    Args:
        name: pretrained encoder identifier forwarded to ``_base_model``.
        num_labels: number of output logits.
        hidden: LSTM hidden size per direction.
        dropout: dropout probability applied before the linear head.
        pooling: ``PoolingLayer`` mode applied to the LSTM outputs.
    """

    def __init__(self, name, num_labels=2, hidden=128, dropout=0.3, pooling="masked_mean"):
        super().__init__()
        self.bert = _base_model(name)
        enc_dim = self.bert.config.hidden_size
        # One logit per mixed layer; starting all-equal gives a uniform mix.
        self.w = nn.Parameter(torch.ones(4))
        self.lstm = nn.LSTM(enc_dim, hidden, bidirectional=True, batch_first=True)
        self.pool = PoolingLayer(pooling)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(2 * hidden, num_labels)

    def _pool_layers(self, hs):
        """Return the softmax(self.w)-weighted sum of the last four items of ``hs``."""
        weights = torch.softmax(self.w, dim=0)
        top = hs[-4:]
        mixed = 0
        for k in range(4):
            mixed = mixed + weights[k] * top[k]
        return mixed

    def forward(self, ids, mask):
        """Return logits for token ids ``ids`` with attention ``mask``."""
        encoded = self.bert(input_ids=ids, attention_mask=mask, output_hidden_states=True)
        blended = self._pool_layers(encoded.hidden_states)
        recurrent, _ = self.lstm(blended)
        pooled = self.pool(recurrent, mask)
        return self.fc(self.dropout(pooled))

# -------- Factory & Loader --------
def _build(arch, base_model, num_labels, pooling):
    """Construct the classifier selected by ``arch`` (checkpoint not yet loaded).

    Args:
        arch: one of ``"WCB"``, ``"WCB_BiLSTM"``, ``"WCB_CNN_BiLSTM"``,
            ``"WCB_4Layer_BiLSTM"``.
        base_model: encoder identifier forwarded to the model constructor.
        num_labels: number of output logits.
        pooling: pooling mode for the LSTM-based models; unused by ``"WCB"``.

    Raises:
        ValueError: if ``arch`` matches none of the names above.
    """
    # Ordered (name, factory) pairs; factories defer construction so only the
    # selected model is instantiated.
    candidates = (
        ("WCB", lambda: Model1_WCB(base_model, num_labels)),
        ("WCB_BiLSTM",
         lambda: Model2_WCB_BiLSTM(base_model, num_labels, pooling=pooling)),
        ("WCB_CNN_BiLSTM",
         lambda: Model3_WCB_CNN_BiLSTM(base_model, num_labels, pooling=pooling)),
        ("WCB_4Layer_BiLSTM",
         lambda: Model4_WCB_4Layer_BiLSTM(base_model, num_labels, pooling=pooling)),
    )
    for key, construct in candidates:
        if arch == key:
            return construct()
    raise ValueError(f"Unknown architecture: {arch}")

def _warn_on_state_dict_mismatch(result, source, limit=10):
    """Emit a RuntimeWarning if a non-strict ``load_state_dict`` left keys unmatched.

    Args:
        result: the value returned by ``nn.Module.load_state_dict`` (anything
            exposing ``missing_keys`` and ``unexpected_keys`` sequences).
        source: checkpoint location, used only in the warning text.
        limit: maximum number of key names of each kind shown in the message.
    """
    missing = list(result.missing_keys)
    unexpected = list(result.unexpected_keys)
    if not missing and not unexpected:
        return
    warnings.warn(
        f"Checkpoint {source} does not match the model exactly: "
        f"{len(missing)} missing key(s) {missing[:limit]}, "
        f"{len(unexpected)} unexpected key(s) {unexpected[:limit]}",
        RuntimeWarning,
        stacklevel=3,  # point at the caller of load_model
    )


def load_model(model_dir: str):
    """Load a classifier from a model directory.

    The directory must contain ``config.json`` and ``model.safetensors``.
    ``config.json`` keys read here (all optional): ``architecture`` (default
    ``"WCB"``), ``base_model`` (default
    ``"airesearch/wangchanberta-base-att-spm-uncased"``), ``num_labels``
    (default 2) and ``pooling_after_lstm`` (default ``"masked_mean"``).

    Args:
        model_dir: path to the model directory.

    Returns:
        ``(tokenizer, model, config)``: the tokenizer of ``base_model``, the
        model in eval mode with the checkpoint weights loaded, and the parsed
        config dict.

    Raises:
        ValueError: if ``architecture`` is not a supported name.
    """
    d = Path(model_dir)
    cfg = json.loads((d / "config.json").read_text(encoding="utf-8"))
    arch = cfg.get("architecture", "WCB")
    base = cfg.get("base_model", "airesearch/wangchanberta-base-att-spm-uncased")
    nlabel = int(cfg.get("num_labels", 2))
    pooling = cfg.get("pooling_after_lstm", "masked_mean")

    model = _build(arch, base, nlabel, pooling)
    weights_path = d / "model.safetensors"
    sd = safe_load(str(weights_path))
    # strict=False is kept so checkpoints with extra/missing keys still load
    # (presumably intentional — TODO confirm), but the mismatch is reported:
    # silently dropped keys could leave the head randomly initialised.
    result = model.load_state_dict(sd, strict=False)
    _warn_on_state_dict_mismatch(result, weights_path)
    model.eval()

    tok = AutoTokenizer.from_pretrained(base, use_fast=True)
    return tok, model, cfg