|
|
|
|
|
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoModel
|
|
|
|
|
|
|
|
# HuggingFace checkpoint shared by all three models (Thai WangchanBERTa encoder).
BASE_MODEL_NAME = "airesearch/wangchanberta-base-att-spm-uncased"

# Default pooling strategy applied to the BiLSTM output sequence;
# one of 'cls', 'masked_mean', 'masked_max' (see BaseHead.pool).
POOLING_AFTER_LSTM = "masked_mean"
|
|
|
|
|
class BaseHead(nn.Module):
    """Shared classification head: BiLSTM -> pooling -> dropout -> linear.

    Args:
        hidden_in: feature size of the incoming sequence (e.g. BERT hidden size).
        hidden_lstm: hidden size per LSTM direction (bidirectional, so the
            pooled feature size is ``2 * hidden_lstm``).
        num_classes: number of output logits.
        dropout: dropout probability applied before the final linear layer.
        pooling: pooling strategy, one of 'cls', 'masked_mean', 'masked_max'.

    Raises:
        ValueError: if ``pooling`` is not a supported strategy.
    """

    POOLINGS = ('cls', 'masked_mean', 'masked_max')

    def __init__(self, hidden_in, hidden_lstm=128, num_classes=2, dropout=0.3, pooling='masked_mean'):
        super().__init__()
        # Validate eagerly with a real exception: ``assert`` is stripped under -O.
        if pooling not in self.POOLINGS:
            raise ValueError(f"pooling must be one of {self.POOLINGS}, got {pooling!r}")
        self.pooling = pooling
        self.lstm = nn.LSTM(hidden_in, hidden_lstm, bidirectional=True, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_lstm * 2, num_classes)

    def pool(self, x, mask):
        """Collapse ``x`` of shape (batch, seq, feat) to (batch, feat).

        ``mask`` is the attention mask of shape (batch, seq); nonzero marks
        real tokens, zero marks padding.
        """
        if self.pooling == 'cls':
            # First-token ([CLS]) representation; mask is irrelevant here.
            return x[:, 0, :]
        mask = mask.unsqueeze(-1)  # (batch, seq, 1) for broadcasting
        if self.pooling == 'masked_mean':
            summed = (x * mask).sum(1)
            denom = mask.sum(1).clamp(min=1e-6)  # guard against all-padding rows
            return summed / denom
        # masked_max: padding positions get -1e9 so they can never win the max.
        return x.masked_fill(mask == 0, -1e9).max(1).values

    def forward_after_bert(self, seq, mask):
        """Run the head on encoder output ``seq`` (batch, seq, hidden_in)."""
        x, _ = self.lstm(seq)
        x = self.pool(x, mask)
        return self.fc(self.dropout(x))
|
|
|
|
|
class Model1Baseline(nn.Module):
    """Baseline model: pretrained encoder followed directly by the BiLSTM head."""

    def __init__(self, name=BASE_MODEL_NAME, hidden=128, dropout=0.3, classes=2, pooling=POOLING_AFTER_LSTM):
        super().__init__()
        self.bert = AutoModel.from_pretrained(name)
        bert_dim = self.bert.config.hidden_size
        self.head = BaseHead(bert_dim, hidden, classes, dropout, pooling)

    def forward(self, ids, mask):
        """Encode token ids and classify; returns logits of shape (batch, classes)."""
        encoded = self.bert(input_ids=ids, attention_mask=mask)
        sequence = encoded.last_hidden_state
        return self.head.forward_after_bert(sequence, mask)
|
|
|
|
|
class Model2CNNBiLSTM(nn.Module):
    """Encoder + two stacked 1-D convolutions (k=3, k=5) + BiLSTM head."""

    def __init__(self, name=BASE_MODEL_NAME, hidden=128, dropout=0.3, classes=2, pooling=POOLING_AFTER_LSTM):
        super().__init__()
        self.bert = AutoModel.from_pretrained(name)
        bert_dim = self.bert.config.hidden_size
        # Same-length convolutions over the token axis: padding = (kernel - 1) // 2.
        self.c1 = nn.Conv1d(bert_dim, 128, kernel_size=3, padding=1)
        self.c2 = nn.Conv1d(128, 128, kernel_size=5, padding=2)
        self.head = BaseHead(128, hidden, classes, dropout, pooling)

    def forward(self, ids, mask):
        """Encode, convolve over tokens, then classify; returns (batch, classes) logits."""
        sequence = self.bert(input_ids=ids, attention_mask=mask).last_hidden_state
        # Conv1d expects (batch, channels, seq); transpose back before the head.
        feats = sequence.transpose(1, 2)
        feats = F.relu(self.c1(feats))
        feats = F.relu(self.c2(feats))
        return self.head.forward_after_bert(feats.transpose(1, 2), mask)
|
|
|
|
|
class Model3PureLast4(nn.Module):
    """Encoder with a learned softmax-weighted sum of its last 4 hidden layers.

    The four scalar mixing weights (``self.w``) are trained jointly with the
    BiLSTM classification head.
    """

    def __init__(self, name=BASE_MODEL_NAME, hidden=128, dropout=0.3, classes=2, pooling=POOLING_AFTER_LSTM):
        super().__init__()
        # NOTE: the original re-imported AutoModel and F locally here; the
        # module-level imports already provide both, so those are dropped.
        self.bert = AutoModel.from_pretrained(name)
        # One learnable mixing weight per layer, softmax-normalized in forward().
        self.w = nn.Parameter(torch.ones(4))
        self.head = BaseHead(self.bert.config.hidden_size, hidden, classes, dropout, pooling)

    def forward(self, ids, mask):
        """Classify using a learned convex combination of the last 4 encoder layers."""
        out = self.bert(input_ids=ids, attention_mask=mask, output_hidden_states=True)
        last4 = out.hidden_states[-4:]
        weights = F.softmax(self.w, dim=0)
        # Weighted sum over the last four transformer layers.
        seq = sum(weights[i] * layer for i, layer in enumerate(last4))
        return self.head.forward_after_bert(seq, mask)
|
|
|
|
|
def create_model_by_name(model_name):
    """Factory: instantiate one of the three experiment models by its string id.

    Raises:
        ValueError: if ``model_name`` is not a known identifier.
    """
    # Lambdas defer name resolution and construction until a key matches,
    # mirroring the lazy evaluation of the original if/elif chain.
    builders = {
        "Model1_Baseline": lambda: Model1Baseline(),
        "Model2_CNN_BiLSTM": lambda: Model2CNNBiLSTM(),
        "Model3_Pure_Last4Weighted": lambda: Model3PureLast4(),
    }
    builder = builders.get(model_name)
    if builder is None:
        raise ValueError(f"Unknown model name: {model_name}")
    return builder()