import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.attn import MultiHeadedAttention


class SentencesPairExtract(nn.Module):
    def __init__(self, IPA_embed_dim, max_seq_length, batch_size, IPA_vocab_size=None):
        super().__init__()
        self.IPA_embed_dim = IPA_embed_dim
        self.bsz = batch_size

        '''Forward-training module'''
        # index 0 is PADDING
        # pronunciation (IPA) sequence embedding
        self.embedding_IPA = nn.Embedding(IPA_vocab_size, IPA_embed_dim, padding_idx=0)
        # pronunciation-feature dimension-transform layers
        self.scaling_IPA = nn.Sequential(nn.Linear(IPA_embed_dim * max_seq_length, 64), nn.ELU())
        self.rescaling_IPA = nn.Sequential(nn.Linear(768, 32), nn.ELU())
        # IPA fusion layer
        self.fus_layer = MultiHeadedAttention(head_count=1, model_dim=768)
        # encoder layers
        self.enc_src = nn.Sequential(nn.Linear(768, 768), nn.ELU())
        self.src_relu = nn.ELU()
        self.enc_tgt = nn.Sequential(nn.Linear(800, 768), nn.ELU())
        self.tgt_relu = nn.ELU()
        # multi-layer perceptron: parallel / non-parallel classifier
        self.mlp = nn.Sequential(nn.Linear(768 * 4, 768), nn.ELU(), nn.Linear(768, 1), nn.Sigmoid())
        # loss function (unused here: cal_loss_and_pred computes BCE manually)
        self.loss_func = nn.BCELoss()

        '''Attention representation-alignment module'''
        # attention layer
        self.layer_attns = MultiHeadedAttention(head_count=1, model_dim=768)

    def forward(self, src_vec, tgt_vec, labels, IPA_inputs=None, return_vec=False,
                MODE=None, anchor_vec=None, src_IPA=None, tgt_IPA=None):
        if not MODE:
            '''src_vec and tgt_vec are sentence vectors from E5;
            IPA_inputs is the pronunciation-fragment sequence of tgt, e.g. [pʰaj@1]'''
            if IPA_inputs is not None:
                tgt = self.fus_IPA(tgt_vec, IPA_inputs)
            else:
                tgt = tgt_vec
            src_vec = self.ResidualEnc(src_vec, 1)
            tgt = self.ResidualEnc(tgt, 0)
            # inference head: decide whether the pair is parallel
            logits = self.mlp(torch.cat((tgt, src_vec, tgt - src_vec, tgt * src_vec), 1))
            # compute the loss from labels
            loss, pred = self.cal_loss_and_pred(logits, labels)
            return loss, pred
        else:
            '''src_vec, tgt_vec and anchor_vec are sentence vectors from E5;
            src_IPA and tgt_IPA are IPA pronunciations'''
            src = self.fus_IPA(src_vec, src_IPA)
            tgt = self.fus_IPA(tgt_vec, tgt_IPA)
            # labels = [1 for x in range(self.bsz)]
            src = self.ResidualEnc(src, 0)
            tgt = self.ResidualEnc(tgt, 0)
            anchor_vec = self.ResidualEnc(anchor_vec, 1)
            # attend both sides over the shared anchor representation
            src_fus_anc = self.layer_attns(src, anchor_vec, anchor_vec)
            tgt_fus_anc = self.layer_attns(tgt, anchor_vec, anchor_vec)
            # batch_output = torch.cat((tgt_fus_anc, src_fus_anc, tgt_fus_anc - src_fus_anc, tgt_fus_anc * src_fus_anc), 1)
            # logits = self.mlp(batch_output.resize(self.bsz, 4 * 768))
            # loss, pred = self.cal_loss_and_pred(logits, labels)
            # pull the two anchor-fused views together (target similarity = 1)
            cosine_loss = nn.CosineEmbeddingLoss(margin=0)
            loss = cosine_loss(src_fus_anc.squeeze(1), tgt_fus_anc.squeeze(1),
                               torch.ones(self.bsz, device=src_fus_anc.device))
            pred = '_'  # no discrete prediction in alignment mode
            return loss, pred

    def ResidualEnc(self, vec, ifS):
        '''0: lao/th (+IPA) branch for src/tgt; 1: zh branch for src'''
        if ifS == 0:
            return self.tgt_relu(self.enc_tgt(vec))
        elif ifS == 1:
            # residual connection around the 768 -> 768 encoder
            return self.src_relu(self.enc_src(vec) + vec)

    def cal_loss_and_pred(self, logits, labels):
        matrix_labels = torch.tensor(labels, device=logits.device).float()  # (Batch,)
        poss = logits[matrix_labels == 1] + 1e-4
        negs = logits[matrix_labels == 0] + 1e-4
        p_ls = torch.log(poss).mean()
        n_ls = torch.log(1 - negs).mean()
        # a batch with no positives (or no negatives) gives an empty mean -> NaN;
        # zero those terms out instead of propagating NaN into the loss
        loss = -(torch.where(torch.isnan(p_ls), torch.full_like(p_ls, 0), p_ls)
                 + torch.where(torch.isnan(n_ls), torch.full_like(n_ls, 0), n_ls))
        # loss = - (torch.log(1 - negs).mean() + torch.log(poss).mean())
        predictions = (logits > 0.5).int()  # (Batch, 1)
        return loss, predictions

    def fus_IPA(self, vec, IPA_inputs):
        IPA_embed = self.embedding_IPA(IPA_inputs)
        # flatten each sample's IPA embeddings, project to 64-d, then zero-pad to 768
        list_ = []
        for x in torch.chunk(IPA_embed, self.bsz, dim=0):
            list_.append(F.pad(self.scaling_IPA(x.reshape(-1)), [0, 704]))
        IPA_vec = torch.stack(list_)
        # fuse vec with IPA_vec: the attended IPA signal is down-weighted by (1 - a)
        a = 0.8
        fus_vec = torch.cat([vec, (1 - a) * self.rescaling_IPA(
            self.fus_layer(query=vec, key=IPA_vec, value=IPA_vec).squeeze(1))], dim=1)
        return fus_vec
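

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (illustrative, not part of the model). Assumptions
# beyond what the code above guarantees: utils.attn.MultiHeadedAttention
# accepts (query, key, value) tensors of shape (batch, 768) and returns a
# tensor whose dim 1 can be squeezed, as the calls above rely on; src_vec and
# tgt_vec stand in for 768-d E5 sentence embeddings; the hyperparameters and
# dummy data below are made up for the example.
if __name__ == '__main__':
    batch_size, max_seq_length, ipa_vocab = 4, 16, 100
    model = SentencesPairExtract(IPA_embed_dim=8,
                                 max_seq_length=max_seq_length,
                                 batch_size=batch_size,
                                 IPA_vocab_size=ipa_vocab)
    src_vec = torch.randn(batch_size, 768)  # stand-in for E5 source embeddings
    tgt_vec = torch.randn(batch_size, 768)  # stand-in for E5 target embeddings
    ipa_ids = torch.randint(1, ipa_vocab, (batch_size, max_seq_length))  # padded IPA ids
    labels = [1, 0, 1, 0]  # 1 = parallel pair, 0 = non-parallel
    loss, pred = model(src_vec, tgt_vec, labels, IPA_inputs=ipa_ids)
    print(loss.item(), pred.squeeze(-1).tolist())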