import math

import torch
from torch import nn
|
|
|
|
class MultiHeadAttentionLayer(nn.Module):
    """Multi-head scaled dot-product attention.

    ``query``/``key``/``value`` are projected with separate linear layers, split
    into ``n_heads`` heads of size ``hid_dim // n_heads``, combined per head as
    ``softmax(Q K^T / sqrt(head_dim)) V``, merged back together and passed
    through the output projection ``fc_o``.

    Args:
        hid_dim: Model (hidden) dimension; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
        device: Accepted for backward compatibility and otherwise unused; the
            scaling factor is a plain Python float, so it works on whatever
            device and dtype the module and its inputs live on.

    Raises:
        ValueError: If ``hid_dim`` is not divisible by ``n_heads``.
    """

    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()

        # Raise instead of assert: asserts are stripped under ``python -O``.
        if hid_dim % n_heads != 0:
            raise ValueError(
                f"hid_dim ({hid_dim}) must be divisible by n_heads ({n_heads})"
            )

        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads

        # Query/key/value input projections, then the output projection.
        # Creation order is unchanged so seeded initialisation is reproducible.
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        self.fc_o = nn.Linear(hid_dim, hid_dim)

        # Dropout applied to the attention weights.
        self.dropout = nn.Dropout(dropout)

        # sqrt(head_dim) as a Python float. A plain tensor attribute is not
        # moved by .to()/.cuda(), and dividing bf16/fp16 scores by a float32
        # tensor upcasts them to float32, which then breaks the matmul with V.
        self.scale = math.sqrt(self.head_dim)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: Expected shape (batch, query_len, hid_dim).
            key: Expected shape (batch, key_len, hid_dim).
            value: Expected shape (batch, key_len, hid_dim).
            mask: Optional tensor broadcastable to
                (batch, n_heads, query_len, key_len); positions where
                ``mask == 0`` receive zero attention weight.

        Returns:
            ``(x, attention)`` where ``x`` has shape (batch, query_len, hid_dim)
            and ``attention`` has shape (batch, n_heads, query_len, key_len).
        """
        batch_size = query.shape[0]

        # Project, then split the hidden dimension into heads:
        # (batch, len, hid_dim) -> (batch, n_heads, len, head_dim).
        Q = self.fc_q(query).view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        K = self.fc_k(key).view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = self.fc_v(value).view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)

        # Scaled scores: (batch, n_heads, query_len, key_len). Dividing by a
        # Python float keeps the scores in the inputs' dtype.
        energy = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        if mask is not None:
            # The most negative finite value of the scores' dtype: a fixed
            # constant like -1e10 overflows float16 and makes masked_fill raise.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        attention = torch.softmax(energy, dim=-1)

        # Weighted sum of values, then merge heads:
        # (batch, n_heads, query_len, head_dim) -> (batch, query_len, hid_dim).
        x = torch.matmul(self.dropout(attention), V)
        x = x.permute(0, 2, 1, 3).contiguous().view(batch_size, -1, self.hid_dim)

        x = self.fc_o(x)

        return x, attention
class PositionwiseFeedforwardLayer(nn.Module):
    """Position-wise feed-forward sub-layer: Linear -> ReLU -> Dropout -> Linear.

    Both linear maps act on the last dimension only, so every position of a
    sequence is transformed independently with the same weights.

    Args:
        hid_dim: Input and output feature size.
        pf_dim: Inner (expanded) feature size.
        dropout: Dropout probability applied after the ReLU.
    """

    def __init__(self, hid_dim, pf_dim, dropout):
        super().__init__()
        # fc_1 is created before fc_2 so seeded initialisation and
        # state_dict key order stay the same.
        self.fc_1 = nn.Linear(hid_dim, pf_dim)
        self.fc_2 = nn.Linear(pf_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map ``x`` of shape (..., hid_dim) to a tensor of the same shape."""
        activated = torch.relu(self.fc_1(x))
        return self.fc_2(self.dropout(activated))
class EncoderLayer(nn.Module):
    """One post-norm Transformer encoder block.

    Runs self-attention and then a position-wise feed-forward network. Each
    sub-layer's output goes through dropout, is added back to that sub-layer's
    input, and is normalised by its own LayerNorm.

    Args:
        hid_dim: Model (hidden) dimension.
        n_heads: Number of attention heads.
        pf_dim: Inner dimension of the feed-forward sub-layer.
        dropout: Dropout probability shared by all sub-modules.
        device: Forwarded to ``MultiHeadAttentionLayer``.
    """

    def __init__(self, hid_dim, n_heads, pf_dim, dropout, device):
        super().__init__()
        # Registration order is unchanged so state_dict key order and seeded
        # parameter initialisation match.
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask):
        """Encode ``src`` (expected (batch, src_len, hid_dim)); output keeps its shape.

        ``src_mask`` is passed straight to the self-attention sub-layer.
        """
        # Sub-layer 1: self-attention + residual + LayerNorm (weights discarded).
        attended, _ = self.self_attention(src, src, src, src_mask)
        normed = self.self_attn_layer_norm(src + self.dropout(attended))

        # Sub-layer 2: feed-forward + residual + LayerNorm.
        projected = self.positionwise_feedforward(normed)
        return self.ff_layer_norm(normed + self.dropout(projected))
|
|
class Encoder(nn.Module):
    """Token + learned positional embeddings followed by a stack of EncoderLayer blocks.

    Args:
        input_dim: Vocabulary size of the token embedding.
        hid_dim: Model (hidden) dimension.
        n_layers: Number of stacked ``EncoderLayer`` modules.
        n_heads: Forwarded to each ``EncoderLayer``.
        pf_dim: Forwarded to each ``EncoderLayer``.
        dropout: Forwarded to each ``EncoderLayer`` and also applied to the
            summed embeddings.
        device: Forwarded to each ``EncoderLayer`` and stored as
            ``self.device``; ``forward`` builds positions on ``src.device``.
        max_length: Size of the positional-embedding table, i.e. the longest
            sequence ``forward`` accepts.
    """

    def __init__(self,
                 input_dim,
                 hid_dim,
                 n_layers,
                 n_heads,
                 pf_dim,
                 dropout,
                 device,
                 max_length=1024):
        super().__init__()

        self.device = device

        self.tok_embedding = nn.Embedding(input_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)

        self.layers = nn.ModuleList([EncoderLayer(hid_dim,
                                                  n_heads,
                                                  pf_dim,
                                                  dropout,
                                                  device)
                                     for _ in range(n_layers)])

        self.dropout = nn.Dropout(dropout)

        # Token-embedding scale sqrt(hid_dim) as a Python float. A float32
        # tensor attribute would not follow .to() and would upcast bf16/fp16
        # embeddings to float32.
        self.scale = math.sqrt(hid_dim)

    def forward(self, src, src_mask):
        """Embed and encode a batch of token ids.

        Args:
            src: Integer token ids of shape (batch, src_len).
            src_mask: Passed unchanged to every ``EncoderLayer``.

        Returns:
            Tensor of shape (batch, src_len, hid_dim).

        Raises:
            ValueError: If ``src_len`` exceeds the positional-embedding size.
        """
        src_len = src.shape[1]
        max_len = self.pos_embedding.num_embeddings
        # An out-of-range position id otherwise fails deep inside the
        # embedding lookup (a device-side assert on CUDA).
        if src_len > max_len:
            raise ValueError(
                f"sequence length {src_len} exceeds max_length {max_len}"
            )

        # Positions 0..src_len-1 with shape (1, src_len), created on the input's
        # device; broadcasting over the batch replaces an explicit repeat.
        pos = torch.arange(src_len, device=src.device).unsqueeze(0)

        src = self.dropout(self.tok_embedding(src) * self.scale + self.pos_embedding(pos))

        for layer in self.layers:
            src = layer(src, src_mask)

        return src
|
|
class BuSTv2(nn.Module):
    """Classify a pair of token sequences with a shared encoder.

    Both ``src`` and ``trg`` are encoded by the same ``encoder``. Each encoding
    is averaged over its non-padding positions, the two pooled vectors are
    concatenated, and a linear layer maps them to ``num_classes`` logits.

    Args:
        encoder: Module called as ``encoder(tokens, mask)`` that returns a
            tensor of shape (batch, len, d_model).
        src_pad_idx: Token id treated as padding in both sequences.
        d_model: Feature size of the encoder output.
        device: Stored as ``self.device``; not used within this class.
        num_classes: Number of output logits.
        dropout: Probability for ``self.dropout``.
    """

    def __init__(self,
                 encoder,
                 src_pad_idx,
                 d_model,
                 device,
                 num_classes=2, dropout=0.3):
        super().__init__()

        self.encoder = encoder
        self.src_pad_idx = src_pad_idx
        self.device = device
        # NOTE(review): neither self.dropout nor self.sigmoid is used in
        # forward(); both are kept so existing attribute access keeps working.
        # Confirm whether dropout was meant to be applied before the classifier.
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(d_model * 2, num_classes)
        self.sigmoid = nn.Sigmoid()

    def make_src_mask(self, src):
        """Return a bool mask of shape (batch, 1, 1, len), True at non-padding tokens.

        The two singleton axes let it broadcast over heads and query positions
        in the attention scores.
        """
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        return src_mask

    def _masked_mean(self, encoded, tokens):
        """Average ``encoded`` (batch, len, d_model) over positions whose token is not padding."""
        keep = (tokens != self.src_pad_idx).unsqueeze(-1)  # (batch, len, 1)
        summed = (encoded * keep).sum(dim=1)  # (batch, d_model)
        # clamp avoids 0/0 (NaN) for an all-padding row; such rows pool to zeros.
        counts = keep.sum(dim=1).clamp(min=1)  # (batch, 1)
        return summed / counts

    def forward(self, src, trg):
        """Return logits of shape (batch, num_classes) for the pair (``src``, ``trg``).

        ``src`` and ``trg`` are token-id tensors of shape (batch, len); their
        lengths may differ. Padding positions are excluded from pooling, so the
        result does not depend on how much padding a batch contains.
        """
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_src_mask(trg)

        enc_src = self.encoder(src, src_mask)
        enc_trg = self.encoder(trg, trg_mask)

        # A plain .mean(dim=1) would also average the encodings of padding tokens.
        enc_src_pooled = self._masked_mean(enc_src, src)
        enc_trg_pooled = self._masked_mean(enc_trg, trg)

        combined = torch.cat((enc_src_pooled, enc_trg_pooled), dim=1)

        logits = self.classifier(combined)

        return logits
|
|