import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.attn import MultiHeadedAttention


class SentencesPairExtract(nn.Module):

    def __init__(self, IPA_embed_dim, max_seq_length, batch_size, IPA_vocab_size=None):
        super().__init__()
        self.IPA_embed_dim = IPA_embed_dim
        self.bsz = batch_size

        '''Forward training module'''
        # IPA (pronunciation) token embedding; index 0 is reserved for padding.
        self.embedding_IPA = nn.Embedding(IPA_vocab_size, IPA_embed_dim, padding_idx=0)

        # Projects a flattened IPA embedding sequence down to a 64-dim vector.
        self.scaling_IPA = nn.Sequential(nn.Linear(IPA_embed_dim * max_seq_length, 64),
                                         nn.ELU())

        # Projects the attention-fused IPA representation down to 32 dims before concatenation.
        self.rescaling_IPA = nn.Sequential(nn.Linear(768, 32),
                                           nn.ELU())

        # Single-head attention used to fuse the sentence vector with the IPA vectors.
        self.fus_layer = MultiHeadedAttention(head_count=1, model_dim=768)

        # Source (zh) encoder: 768 -> 768, used with a residual connection.
        self.enc_src = nn.Sequential(nn.Linear(768, 768),
                                     nn.ELU())
        self.src_relu = nn.ELU()

        # Target (lo/th + IPA) encoder: 800 -> 768 (768-dim sentence vector + 32-dim IPA vector).
        self.enc_tgt = nn.Sequential(nn.Linear(800, 768),
                                     nn.ELU())
        self.tgt_relu = nn.ELU()

        # Pair classifier over [tgt; src; tgt - src; tgt * src].
        self.mlp = nn.Sequential(nn.Linear(768 * 4, 768),
                                 nn.ELU(),
                                 nn.Linear(768, 1),
                                 nn.Sigmoid())

        self.loss_func = nn.BCELoss()

        '''Attention-based representation alignment module'''
        self.layer_attns = MultiHeadedAttention(head_count=1, model_dim=768)

    def forward(self, src_vec, tgt_vec, labels, IPA_inputs=None, return_vec=False,
                MODE=None, anchor_vec=None, src_IPA=None, tgt_IPA=None):
        if not MODE:
            '''src_vec and tgt_vec are sentence vectors output by E5; IPA_inputs is the
            pronunciation-fragment sequence of tgt, e.g. [pʰaj@1].'''
            if IPA_inputs is not None:
                tgt = self.fus_IPA(tgt_vec, IPA_inputs)
            else:
                tgt = tgt_vec

            src_vec = self.ResidualEnc(src_vec, 1)
            tgt = self.ResidualEnc(tgt, 0)

            # Classify the pair from [tgt; src; tgt - src; tgt * src].
            logits = self.mlp(torch.cat((tgt, src_vec, tgt - src_vec, tgt * src_vec), 1))

            loss, pred = self.cal_loss_and_pred(logits, labels)
            return loss, pred

        else:
            '''src_vec, tgt_vec and anchor_vec are sentence vectors output by E5;
            src_IPA and tgt_IPA are IPA pronunciations.'''
            src = self.fus_IPA(src_vec, src_IPA)
            tgt = self.fus_IPA(tgt_vec, tgt_IPA)

            src = self.ResidualEnc(src, 0)
            tgt = self.ResidualEnc(tgt, 0)
            anchor_vec = self.ResidualEnc(anchor_vec, 1)

            # Attend src and tgt over the shared anchor representation, then pull
            # the two attended vectors together with a cosine embedding loss.
            src_fus_anc = self.layer_attns(src, anchor_vec, anchor_vec)
            tgt_fus_anc = self.layer_attns(tgt, anchor_vec, anchor_vec)

            cosine_loss = nn.CosineEmbeddingLoss(margin=0)
            target = torch.ones(self.bsz, device=src_fus_anc.device)
            loss, pred = cosine_loss(src_fus_anc.squeeze(1), tgt_fus_anc.squeeze(1), target), '_'
            return loss, pred

    def ResidualEnc(self, vec, ifS):
        '''ifS == 0: Lao/Thai (+IPA) src or tgt; ifS == 1: Chinese src.'''
        if ifS == 0:
            # 800 -> 768 projection for the IPA-augmented side (no residual: dims differ).
            return self.tgt_relu(self.enc_tgt(vec))
        elif ifS == 1:
            # Residual encoder for the Chinese side.
            return self.src_relu(self.enc_src(vec) + vec)

    def cal_loss_and_pred(self, logits, labels):
        matrix_labels = torch.as_tensor(labels, device=logits.device).float()
        eps = 1e-4
        poss = logits[matrix_labels == 1]
        negs = logits[matrix_labels == 0]
        p_ls = torch.log(poss + eps).mean()
        n_ls = torch.log(1 - negs + eps).mean()
        # If the batch has no positives (or no negatives), the corresponding mean
        # is NaN; treat it as a zero contribution to the loss.
        loss = - (torch.where(torch.isnan(p_ls), torch.full_like(p_ls, 0), p_ls) +
                  torch.where(torch.isnan(n_ls), torch.full_like(n_ls, 0), n_ls))

        predictions = (logits > 0.5).int()
        return loss, predictions

    def fus_IPA(self, vec, IPA_inputs):
        IPA_embed = self.embedding_IPA(IPA_inputs)
        list_ = []
        for x in torch.chunk(IPA_embed, self.bsz, dim=0):
            # Flatten one sentence's IPA embeddings, project to 64 dims, then
            # zero-pad to 768 so it matches the attention model dimension.
            list_.append(F.pad(self.scaling_IPA(x.reshape(-1)), [0, 704]))
        IPA_vec = torch.stack(list_)

        # Attention-fuse the sentence vector with the IPA vectors, down-project the
        # result to 32 dims, scale it by (1 - a), and concatenate it to the original
        # 768-dim sentence vector, giving an 800-dim fused representation.
        a = 0.8
        fus_vec = torch.cat([vec, (1 - a) * self.rescaling_IPA(
            self.fus_layer(query=vec, key=IPA_vec, value=IPA_vec).squeeze(1))], dim=1)
        return fus_vec
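

# --- Usage sketch (illustrative only, not part of the model) ------------------
# Shows the two forward modes with dummy inputs, assuming 768-dim E5 sentence
# vectors and a toy IPA vocabulary. The shapes expected and returned by
# utils.attn.MultiHeadedAttention are an assumption here (it is treated as
# returning an attended tensor whose dim-1 can be squeezed); adapt the call
# sites if the real interface differs.
if __name__ == '__main__':
    bsz, max_len, ipa_dim, ipa_vocab = 4, 16, 8, 100
    model = SentencesPairExtract(IPA_embed_dim=ipa_dim, max_seq_length=max_len,
                                 batch_size=bsz, IPA_vocab_size=ipa_vocab)

    src_vec = torch.randn(bsz, 768)                         # E5 vectors, zh side
    tgt_vec = torch.randn(bsz, 768)                         # E5 vectors, lo/th side
    ipa_ids = torch.randint(1, ipa_vocab, (bsz, max_len))   # IPA fragment ids
    labels = torch.tensor([1, 0, 1, 0])

    # Mode 1 (MODE=None): sentence-pair classification.
    loss, pred = model(src_vec, tgt_vec, labels, IPA_inputs=ipa_ids)

    # Mode 2 (any truthy MODE): anchor-based alignment with a cosine loss.
    anchor = torch.randn(bsz, 768)
    loss2, _ = model(src_vec, tgt_vec, labels, MODE='align',
                     anchor_vec=anchor, src_IPA=ipa_ids, tgt_IPA=ipa_ids)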