# CLTMPSE / model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as opt
import numpy as np
from utils.attn import MultiHeadedAttention
class SentencesPairExtract(nn.Module):
    def __init__(self, IPA_embed_dim, max_seq_length, batch_size, IPA_vocab_size=None):
super(SentencesPairExtract, self).__init__()
self.IPA_embed_dim = IPA_embed_dim
self.bsz = batch_size
        '''Forward training module'''
        # Index 0 is reserved for PADDING
        # Embedding layer for the pronunciation (IPA) sequence
        self.embedding_IPA = nn.Embedding(IPA_vocab_size, IPA_embed_dim, padding_idx=0)
        # Projection layers that reshape the pronunciation features
        self.scaling_IPA = nn.Sequential(nn.Linear(IPA_embed_dim * max_seq_length, 64),
                                         nn.ELU())
        self.rescaling_IPA = nn.Sequential(nn.Linear(768, 32),
                                           nn.ELU())
        # IPA fusion layer
        self.fus_layer = MultiHeadedAttention(head_count=1, model_dim=768)
        # Encoding layers
        self.enc_src = nn.Sequential(nn.Linear(768, 768),
                                     nn.ELU())
        self.src_relu = nn.ELU()
        self.enc_tgt = nn.Sequential(nn.Linear(800, 768),
                                     nn.ELU())
        self.tgt_relu = nn.ELU()
        # Multi-layer perceptron classifier
        self.mlp = nn.Sequential(nn.Linear(768 * 4, 768),
                                 nn.ELU(),
                                 nn.Linear(768, 1),
                                 nn.Sigmoid())
        # Loss function
        self.loss_func = nn.BCELoss()
        '''Attention-based representation alignment module'''
        # Attention layer
        self.layer_attns = MultiHeadedAttention(head_count=1, model_dim=768)
def forward(self, src_vec, tgt_vec, labels, IPA_inputs=None, return_vec=False,
MODE=None, anchor_vec=False, src_IPA=False, tgt_IPA=False):
if not MODE:
            '''src_vec and tgt_vec are sentence vectors produced by E5; IPA_inputs is the pronunciation-fragment sequence of tgt, e.g. [pʰaj@1]'''
if IPA_inputs is not None:
tgt = self.fus_IPA(tgt_vec, IPA_inputs)
else:
tgt = tgt_vec
            src_vec = self.ResidualEnc(src_vec, 1)
            tgt = self.ResidualEnc(tgt, 0)
            # Inference module: decide whether the pair is parallel
            logits = self.mlp(torch.cat((tgt, src_vec, tgt - src_vec, tgt * src_vec), 1))
            # Compute the loss according to the labels
            loss, pred = self.cal_loss_and_pred(logits, labels)
return loss, pred
else:
            '''src_vec, tgt_vec, and anchor_vec are sentence vectors produced by E5; src_IPA and tgt_IPA are IPA pronunciations'''
src = self.fus_IPA(src_vec, src_IPA)
tgt = self.fus_IPA(tgt_vec, tgt_IPA)
# labels = [1 for x in range(self.bsz)]
src = self.ResidualEnc(src, 0)
tgt = self.ResidualEnc(tgt, 0)
anchor_vec = self.ResidualEnc(anchor_vec, 1)
            src_fus_anc = self.layer_attns(src, anchor_vec, anchor_vec)
            tgt_fus_anc = self.layer_attns(tgt, anchor_vec, anchor_vec)
# batch_output = torch.cat((tgt_fus_anc, src_fus_anc, tgt_fus_anc - src_fus_anc, tgt_fus_anc * src_fus_anc), 1)
# logits = self.mlp(batch_output.resize(self.bsz, 4 * 768))
# loss, pred = self.cal_loss_and_pred(logits, labels)
            # Align the anchor-fused src/tgt representations with a cosine embedding loss
            cosine_loss = nn.CosineEmbeddingLoss(margin=0).to(src_fus_anc.device)
            loss = cosine_loss(src_fus_anc.squeeze(1), tgt_fus_anc.squeeze(1),
                               torch.ones(self.bsz, device=src_fus_anc.device))
            pred = '_'
return loss, pred
    def ResidualEnc(self, vec, ifS):
        '''ifS == 0: Lao/Thai side fused with IPA (src or tgt); ifS == 1: Chinese src side, with a residual connection'''
        if ifS == 0:
            return self.tgt_relu(self.enc_tgt(vec))
        elif ifS == 1:
            return self.src_relu(self.enc_src(vec) + vec)
    def cal_loss_and_pred(self, logits, labels):
        matrix_labels = torch.tensor(labels).float().to(logits.device)  # (Batch,)
        poss = logits[matrix_labels == 1] + 1e-4
        negs = logits[matrix_labels == 0] + 1e-4
        p_ls = torch.log(poss).mean()
        n_ls = torch.log(1 - negs).mean()
        # If the batch has no positives (or no negatives), the corresponding mean is NaN; replace it with 0
        loss = - (torch.where(torch.isnan(p_ls), torch.full_like(p_ls, 0), p_ls) +
                  torch.where(torch.isnan(n_ls), torch.full_like(n_ls, 0), n_ls))
        # loss = - (torch.log(1 - negs).mean() + torch.log(poss).mean())
        predictions = (logits > 0.5).int()  # (Batch, 1)
        return loss, predictions
    def fus_IPA(self, vec, IPA_inputs):
        IPA_embed = self.embedding_IPA(IPA_inputs)
        list_ = []
        # Flatten each sample's IPA embeddings, project to 64 dims, then zero-pad up to 768 dims
        for x in torch.chunk(IPA_embed, self.bsz, dim=0):
            list_.append(F.pad(self.scaling_IPA(x.reshape(-1)), [0, 704]))
        IPA_vec = torch.stack(list_)
        # Fuse vec with IPA_vec: attend from the sentence vector (query) to the IPA vector,
        # project the attended output down to 32 dims, scale it by (1 - a),
        # and concatenate it with vec to form the 800-dim fused representation
        a = 0.8
        fus_vec = torch.cat([vec, (1 - a) * self.rescaling_IPA(self.fus_layer(query=vec, key=IPA_vec, value=IPA_vec).squeeze(1))], dim=1)
        return fus_vec