# File size: 3,509 bytes
# cf3fcd2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# common/models.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel

# Base settings; these must match the values used during training.
# Pretrained encoder identifier passed to AutoModel.from_pretrained by every model below.
BASE_MODEL_NAME = "airesearch/wangchanberta-base-att-spm-uncased"
# Default pooling strategy applied to the BiLSTM outputs inside BaseHead.
POOLING_AFTER_LSTM = "masked_mean"

class BaseHead(nn.Module):
    """Classification head: BiLSTM over token features, sequence pooling, dropout, linear layer.

    Args:
        hidden_in: size of each input token feature vector fed to the LSTM.
        hidden_lstm: hidden size of each LSTM direction; the pooled vector has
            ``2 * hidden_lstm`` features because the LSTM is bidirectional.
        num_classes: number of logits produced by ``fc``.
        dropout: dropout probability applied to the pooled vector.
        pooling: ``'cls'`` (first position), ``'masked_mean'`` or ``'masked_max'``.

    Raises:
        ValueError: if ``pooling`` is not one of the supported strategies.
    """

    POOLING_MODES = ('cls', 'masked_mean', 'masked_max')

    def __init__(self, hidden_in, hidden_lstm=128, num_classes=2, dropout=0.3, pooling='masked_mean'):
        super().__init__()
        self.lstm = nn.LSTM(hidden_in, hidden_lstm, bidirectional=True, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_lstm * 2, num_classes)
        # Explicit check instead of `assert`, which is stripped when Python runs with -O.
        if pooling not in self.POOLING_MODES:
            raise ValueError(f"pooling must be one of {self.POOLING_MODES}, got {pooling!r}")
        self.pooling = pooling

    def pool(self, x, mask):
        """Reduce per-token features ``x`` to one vector per sequence.

        Args:
            x: token features, indexed as ``x[:, t, :]`` (assumed (batch, seq_len, features)).
            mask: attention mask, nonzero for real tokens (assumed (batch, seq_len)).

        Returns:
            The pooled features with the sequence axis (dim 1) removed.
        """
        if self.pooling == 'cls':
            return x[:, 0, :]
        # Work in x's floating dtype so an integer/bool attention mask does not leave
        # the token counts as integers; that way clamp(min=1e-6) is a real float floor
        # and an all-padding row cannot divide by zero.
        mask = mask.unsqueeze(-1).to(x.dtype)
        if self.pooling == 'masked_mean':
            summed = (x * mask).sum(1)
            counts = mask.sum(1).clamp(min=1e-6)
            return summed / counts
        # masked_max: push padded positions far below any real value. -1e9 is kept for
        # float32/float64 (unchanged results) but clamped to the dtype's most negative
        # finite value, because -1e9 overflows float16 under mixed precision.
        fill_value = max(-1e9, torch.finfo(x.dtype).min)
        x = x.masked_fill(mask == 0, fill_value)
        return x.max(1).values

    def forward_after_bert(self, seq, mask):
        """Run the BiLSTM over encoder features ``seq``, pool with ``mask`` and return logits.

        With ``batch_first=True`` the LSTM treats ``seq`` as (batch, seq_len, hidden_in).
        """
        # NOTE(review): the LSTM also runs over padded positions (no packing), so the
        # backward direction starts from padding. This is presumably kept as-is to match
        # how the head was trained; packing would change outputs.
        x, _ = self.lstm(seq)
        x = self.pool(x, mask)
        return self.fc(self.dropout(x))

class Model1Baseline(nn.Module):
    """Pretrained encoder whose last hidden layer feeds straight into a BaseHead.

    Args:
        name: identifier passed to ``AutoModel.from_pretrained``.
        hidden: hidden size of each LSTM direction in the head.
        dropout: dropout probability used by the head.
        classes: number of output logits.
        pooling: pooling strategy forwarded to ``BaseHead``.
    """

    def __init__(self, name=BASE_MODEL_NAME, hidden=128, dropout=0.3, classes=2, pooling=POOLING_AFTER_LSTM):
        super().__init__()
        self.bert = AutoModel.from_pretrained(name)
        encoder_dim = self.bert.config.hidden_size
        self.head = BaseHead(
            encoder_dim,
            hidden_lstm=hidden,
            num_classes=classes,
            dropout=dropout,
            pooling=pooling,
        )

    def forward(self, ids, mask):
        """Encode ``ids`` under ``mask`` and return the head's logits."""
        encoded = self.bert(input_ids=ids, attention_mask=mask)
        token_states = encoded.last_hidden_state
        return self.head.forward_after_bert(token_states, mask)

class Model2CNNBiLSTM(nn.Module):
    """Pretrained encoder -> two ReLU Conv1d layers (128 channels) -> BaseHead.

    Args:
        name: identifier passed to ``AutoModel.from_pretrained``.
        hidden: hidden size of each LSTM direction in the head.
        dropout: dropout probability used by the head.
        classes: number of output logits.
        pooling: pooling strategy forwarded to ``BaseHead``.
    """

    def __init__(self, name=BASE_MODEL_NAME, hidden=128, dropout=0.3, classes=2, pooling=POOLING_AFTER_LSTM):
        super().__init__()
        self.bert = AutoModel.from_pretrained(name)
        encoder_dim = self.bert.config.hidden_size
        conv_channels = 128
        # Stride 1 with kernel 3 / padding 1 and kernel 5 / padding 2 keeps the sequence length unchanged.
        self.c1 = nn.Conv1d(encoder_dim, conv_channels, kernel_size=3, padding=1)
        self.c2 = nn.Conv1d(conv_channels, conv_channels, kernel_size=5, padding=2)
        self.head = BaseHead(
            conv_channels,
            hidden_lstm=hidden,
            num_classes=classes,
            dropout=dropout,
            pooling=pooling,
        )

    def forward(self, ids, mask):
        """Encode ``ids``, convolve along the sequence and return the head's logits."""
        token_states = self.bert(input_ids=ids, attention_mask=mask).last_hidden_state
        # Conv1d convolves over the last axis, so swap axes 1 and 2 before the convolutions
        # and swap them back afterwards (assumes token_states is (batch, seq_len, hidden)).
        # NOTE(review): padded positions are not masked before the convolutions, so they
        # can influence neighbouring real tokens.
        features = token_states.transpose(1, 2)
        features = F.relu(self.c1(features))
        features = F.relu(self.c2(features))
        return self.head.forward_after_bert(features.transpose(1, 2), mask)
        
class Model3PureLast4(nn.Module):
    """Pretrained encoder whose last four hidden-state layers are mixed by learned softmax weights.

    The weighted sum of the four layers is fed to a BaseHead.

    Args:
        name: identifier passed to ``AutoModel.from_pretrained``.
        hidden: hidden size of each LSTM direction in the head.
        dropout: dropout probability used by the head.
        classes: number of output logits.
        pooling: pooling strategy forwarded to ``BaseHead``.
    """

    def __init__(self, name=BASE_MODEL_NAME, hidden=128, dropout=0.3, classes=2, pooling=POOLING_AFTER_LSTM):
        super().__init__()
        # Uses the module-level `torch`, `AutoModel` and `F`; the previous function-local
        # re-imports were redundant, and `torch` itself was never imported (NameError).
        self.bert = AutoModel.from_pretrained(name)
        # One unnormalised score per layer; softmax in forward() turns them into mixing
        # weights. Equal initial values give a uniform average of the four layers at start.
        self.w = nn.Parameter(torch.ones(4))
        encoder_dim = self.bert.config.hidden_size
        self.head = BaseHead(encoder_dim, hidden, classes, dropout, pooling)

    def forward(self, ids, mask):
        """Encode ``ids``, mix the last four hidden-state layers and return the head's logits."""
        out = self.bert(input_ids=ids, attention_mask=mask, output_hidden_states=True)
        last4 = out.hidden_states[-4:]
        weights = F.softmax(self.w, dim=0)
        seq = sum(weight * layer for weight, layer in zip(weights, last4))
        return self.head.forward_after_bert(seq, mask)

def create_model_by_name(model_name):
    """Build one of this module's model classes from its registry name, using constructor defaults.

    Args:
        model_name: ``"Model1_Baseline"``, ``"Model2_CNN_BiLSTM"`` or
            ``"Model3_Pure_Last4Weighted"``.

    Returns:
        A freshly constructed model instance.

    Raises:
        ValueError: if ``model_name`` is not a registered name.
    """
    registry = (
        ("Model1_Baseline", Model1Baseline),
        ("Model2_CNN_BiLSTM", Model2CNNBiLSTM),
        ("Model3_Pure_Last4Weighted", Model3PureLast4),
    )
    for registered_name, model_cls in registry:
        if model_name == registered_name:
            return model_cls()
    raise ValueError(f"Unknown model name: {model_name}")