import torch import torch.nn as nn import torch.nn.functional as F from torch.autograd import Variable from torch.nn.utils.rnn import pad_sequence import numpy as np, itertools, random, copy, math class DiffLoss(nn.Module): def __init__(self, args): super(DiffLoss, self).__init__() def forward(self, input1, input2): # input1 (B,N,D) input2 (B,N,D) batch_size = input1.size(0) N = input1.size(1) input1 = input1.view(batch_size, -1) # (B,N*D) input2 = input2.view(batch_size, -1) # (B, N*D) # print('input1:', input1) # print('input2:', input2) # Zero mean input1_mean = torch.mean(input1, dim=0, keepdim=True) # (1,N*D) input2_mean = torch.mean(input2, dim=0, keepdim=True) # (1,N*D) input1 = input1 - input1_mean # (B,N*D) input2 = input2 - input2_mean # (B,N*D) input1_l2_norm = torch.norm(input1, p=2, dim=1, keepdim=True) # (B,1) input1_l2 = input1.div(input1_l2_norm.expand_as(input1) + 1e-6) # (B,N*D) input2_l2_norm = torch.norm(input2, p=2, dim=1, keepdim=True) # (B,1) input2_l2 = input2.div(input2_l2_norm.expand_as(input2) + 1e-6) # (B,N*D) # print("input1_l2_norm:", input1_l2_norm.detach().cpu().numpy()) # print("input2_l2_norm:", input2_l2_norm.detach().cpu().numpy()) # print("input1_l2:", input1_l2.detach().cpu().numpy()) # print("input2_l2:", input2_l2.detach().cpu().numpy()) norm_diff = torch.mean(torch.norm(input1_l2 - input2_l2, p=2, dim=1)) if norm_diff.item() == 0: return torch.tensor(float('inf'), device=input1.device) diff_loss = 1.0 / norm_diff # print('loss:', diff_loss) return diff_loss class MaskedNLLLoss(nn.Module): def __init__(self, weight=None): super(MaskedNLLLoss, self).__init__() self.weight = weight self.loss = nn.NLLLoss(weight=weight, reduction='sum') def forward(self, pred, target, mask): """ pred -> batch*seq_len, n_classes target -> batch*seq_len mask -> batch, seq_len """ mask_ = mask.view(-1, 1) # batch*seq_len, 1 if type(self.weight) == type(None): loss = self.loss(pred * mask_, target) / torch.sum(mask) else: loss = self.loss(pred * mask_, target) \ / torch.sum(self.weight[target] * mask_.squeeze()) return loss class MaskedMSELoss(nn.Module): def __init__(self): super(MaskedMSELoss, self).__init__() self.loss = nn.MSELoss(reduction='sum') def forward(self, pred, target, mask): """ pred -> batch*seq_len target -> batch*seq_len mask -> batch*seq_len """ loss = self.loss(pred * mask, target) / torch.sum(mask) return loss class UnMaskedWeightedNLLLoss(nn.Module): def __init__(self, weight=None): super(UnMaskedWeightedNLLLoss, self).__init__() self.weight = weight self.loss = nn.NLLLoss(weight=weight, reduction='sum') def forward(self, pred, target): """ pred -> batch*seq_len, n_classes target -> batch*seq_len """ if type(self.weight) == type(None): loss = self.loss(pred, target) else: loss = self.loss(pred, target) \ / torch.sum(self.weight[target]) return loss class GatedSelection(nn.Module): def __init__(self, hidden_size): super().__init__() self.context_trans = nn.Linear(hidden_size, hidden_size) self.linear1 = nn.Linear(hidden_size, hidden_size) self.linear2 = nn.Linear(hidden_size, hidden_size) self.fc = nn.Linear(hidden_size, hidden_size) self.sigmoid = nn.Sigmoid() self.relu = nn.ReLU() def forward(self, x1, x2): x2 = self.context_trans(x2) s = self.sigmoid(self.linear1(x1)+self.linear2(x2)) h = s * x1 + (1 - s) * x2 return self.relu(self.fc(h)) def mask_logic(alpha, adj): ''' performing mask logic with adj :param alpha: :param adj: :return: ''' return alpha - (1 - adj) * 1e30 class GatLinear(nn.Module): def __init__(self, hidden_size): super().__init__() self.linear = nn.Linear(hidden_size * 2, 1) def forward(self, Q, K, V, adj): ''' imformation gatherer with linear attention :param Q: (B, D) # query utterance :param K: (B, N, D) # context :param V: (B, N, D) # context :param adj: (B, N) # the adj matrix of the i th node :return: ''' N = K.size()[1] # print('Q',Q.size()) Q = Q.unsqueeze(1).expand(-1, N, -1) # (B, N, D) # print('K',K.size()) X = torch.cat((Q,K), dim = 2) # (B, N, 2D) # print('X',X.size()) alpha = self.linear(X).permute(0,2,1) #(B, 1, N) # print('alpha',alpha.size()) # print(alpha) adj = adj.unsqueeze(1) alpha = mask_logic(alpha, adj) # (B, 1, N) # print('alpha after mask',alpha.size()) # print(alpha) attn_weight = F.softmax(alpha, dim = 2) # (B, 1, N) # print('attn_weight',attn_weight.size()) # print(attn_weight) attn_sum = torch.bmm(attn_weight, V).squeeze(1) # (B, D) # print('attn_sum',attn_sum.size()) return attn_weight, attn_sum class GatDot(nn.Module): def __init__(self, hidden_size): super().__init__() self.linear1 = nn.Linear(hidden_size, hidden_size) self.linear2 = nn.Linear(hidden_size, hidden_size) def forward(self, Q, K, V, adj): ''' imformation gatherer with dot product attention :param Q: (B, D) # query utterance :param K: (B, N, D) # context :param V: (B, N, D) # context :param adj: (B, N) # the adj matrix of the i th node :return: ''' N = K.size()[1] Q = self.linear1(Q).unsqueeze(2) # (B,D,1) # K = self.linear2(Q) # (B, N, D) K = self.linear2(K) # (B, N, D) alpha = torch.bmm(K, Q).permute(0, 2, 1) # (B, 1, N) adj = adj.unsqueeze(1) alpha = mask_logic(alpha, adj) # (B, 1, N) attn_weight = F.softmax(alpha, dim=2) # (B, 1, N) attn_sum = torch.bmm(attn_weight, V).squeeze(1) # (B, D) return attn_weight, attn_sum class GatLinear_rel(nn.Module): def __init__(self, hidden_size): super().__init__() self.linear = nn.Linear(hidden_size * 3, 1) self.rel_emb = nn.Embedding(2, hidden_size) def forward(self, Q, K, V, adj, s_mask): ''' imformation gatherer with linear attention :param Q: (B, D) # query utterance :param K: (B, N, D) # context :param V: (B, N, D) # context :param adj: (B, N) # the adj matrix of the i th node :param s_mask: (B, N) # :return: ''' rel_emb = self.rel_emb(s_mask) # (B, N, D) N = K.size()[1] # print('Q',Q.size()) Q = Q.unsqueeze(1).expand(-1, N, -1) # (B, N, D) # print('K',K.size()) # print('rel_emb', rel_emb.size()) X = torch.cat((Q,K, rel_emb), dim = 2) # (B, N, 2D)? (B, N, 3D) # print('X',X.size()) alpha = self.linear(X).permute(0,2,1) #(B, 1, N) # print('alpha',alpha.size()) # print(alpha) adj = adj.unsqueeze(1) alpha = mask_logic(alpha, adj) # (B, 1, N) # print('alpha after mask',alpha.size()) # print(alpha) attn_weight = F.softmax(alpha, dim = 2) # (B, 1, N) # print('attn_weight',attn_weight.size()) # print(attn_weight) attn_sum = torch.bmm(attn_weight, V).squeeze(1) # (B, D) # print('attn_sum',attn_sum.size()) return attn_weight, attn_sum class GatDot_rel(nn.Module): def __init__(self, hidden_size): super().__init__() self.linear1 = nn.Linear(hidden_size, hidden_size) self.linear2 = nn.Linear(hidden_size, hidden_size) self.linear3 = nn.Linear(hidden_size, 1) self.rel_emb = nn.Embedding(2, hidden_size) def forward(self, Q, K, V, adj, s_mask): ''' imformation gatherer with dot product attention :param Q: (B, D) # query utterance :param K: (B, N, D) # context :param V: (B, N, D) # context :param adj: (B, N) # the adj matrix of the i th node :param s_mask: (B, N) # relation mask :return: ''' N = K.size()[1] rel_emb = self.rel_emb(s_mask) Q = self.linear1(Q).unsqueeze(2) # (B,D,1) K = self.linear2(K) # (B, N, D) y = self.linear3(rel_emb) # (B, N, 1) alpha = (torch.bmm(K, Q) + y).permute(0, 2, 1) # (B, 1, N) adj = adj.unsqueeze(1) alpha = mask_logic(alpha, adj) # (B, 1, N) attn_weight = F.softmax(alpha, dim=2) # (B, 1, N) attn_sum = torch.bmm(attn_weight, V).squeeze(1) # (B, D) return attn_weight, attn_sum class GAT_dialoggcn(nn.Module): ''' H_i = alpha_ij(W_rH_j) alpha_ij = attention(H_i, H_j) ''' def __init__(self, hidden_size): super().__init__() self.hidden_size = hidden_size self.linear = nn.Linear(hidden_size * 2, 1) self.rel_emb = nn.Parameter(torch.randn(2, hidden_size, hidden_size)) def forward(self, Q, K, V, adj, s_mask_onehot): ''' imformation gatherer with linear attention :param Q: (B, D) # query utterance :param K: (B, N, D) # context :param V: (B, N, D) # context :param adj: (B, N) # the adj matrix of the i th node :param s_mask: (B, N, 2) # :return: ''' B = K.size()[0] N = K.size()[1] # print('Q',Q.size()) Q = Q.unsqueeze(1).expand(-1, N, -1) # (B, N, D); # print('K',K.size()) X = torch.cat((Q,K), dim = 2) # (B, N, 2D) # print('X',X.size()) alpha = self.linear(X).permute(0,2,1) #(B, 1, N) # print('alpha',alpha.size()) # print(alpha) adj = adj.unsqueeze(1) alpha = mask_logic(alpha, adj) # (B, 1, N) # print('alpha after mask',alpha.size()) # print(alpha) attn_weight = F.softmax(alpha, dim = 2) # (B, 1, N) # print('attn_weight',attn_weight.size()) # print(attn_weight) # print('s_mask_onehot', s_mask_onehot.size()) D = self.rel_emb.size()[2] # print('rel_emb', self.rel_emb.size()) rel_emb = self.rel_emb.unsqueeze(0).expand(B,-1,-1,-1) # rel_emb = self.rel_emb.unsqueeze(0).repeat(B, 1, 1, 1) # print('rel_emb expand', rel_emb.size()) rel_emb = rel_emb.reshape((B, 2, D*D)) # print('rel_emb resize', rel_emb.size()) Wr = torch.bmm(s_mask_onehot, rel_emb).reshape((B, N, D, D)) # (B, N, D, D) # print('Wr', Wr.size()) # (B, N, D, D) Wr = Wr.reshape((B*N, D, D)) # print('Wr after reshape', Wr.size()) V = V.unsqueeze(2).reshape((B*N, 1, -1)) # (B*N, 1, D) # print('V after reshape', V.size()) V = torch.bmm(V, Wr).unsqueeze(1) #(B * N, D) # print('V after transform', V.size()) V = V.reshape((B,N,-1)) # print('Final V', V.size()) attn_sum = torch.bmm(attn_weight, V).squeeze(1) # (B, D) # print('attn_sum',attn_sum.size()) return attn_weight, attn_sum class GAT_dialoggcn_v1(nn.Module): ''' use linear to avoid OOM H_i = alpha_ij(W_rH_j) alpha_ij = attention(H_i, H_j) ''' def __init__(self, hidden_size): super().__init__() self.hidden_size = hidden_size self.linear = nn.Linear(hidden_size * 2, 1) self.Wr0 = nn.Linear(hidden_size, hidden_size, bias = False) self.Wr1 = nn.Linear(hidden_size, hidden_size, bias = False) def forward(self, Q, K, V, adj, s_mask): ''' imformation gatherer with linear attention :param Q: (B, D) # query utterance :param K: (B, N, D) # context :param V: (B, N, D) # context :param adj: (B, N) # the adj matrix of the i th node :param s_mask: (B, N) # :return: ''' B = K.size()[0] N = K.size()[1] # print('Q',Q.size()) Q = Q.unsqueeze(1).expand(-1, N, -1) # (B, N, D); # print('K',K.size()) X = torch.cat((Q,K), dim = 2) # (B, N, 2D) # print('X',X.size()) alpha = self.linear(X).permute(0,2,1) #(B, 1, N) #alpha = F.leaky_relu(alpha) # print('alpha',alpha.size()) # print(alpha) adj = adj.unsqueeze(1) # (B, 1, N) alpha = mask_logic(alpha, adj) # (B, 1, N) # print('alpha after mask',alpha.size()) # print(alpha) attn_weight = F.softmax(alpha, dim = 2) # (B, 1, N) # print('attn_weight',attn_weight.size()) # print(attn_weight) V0 = self.Wr0(V) # (B, N, D) V1 = self.Wr1(V) # (B, N, D) s_mask = s_mask.unsqueeze(2).float() # (B, N, 1) V = V0 * s_mask + V1 * (1 - s_mask) attn_sum = torch.bmm(attn_weight, V).squeeze(1) # (B, D) # print('attn_sum',attn_sum.size()) return attn_weight, attn_sum class GAT_dialoggcn_v2(nn.Module): ''' use linear to avoid OOM H_i = alpha_ij(W_rH_j) alpha_ij = attention(H_i, H_j, rel) ''' def __init__(self, hidden_size): super().__init__() self.hidden_size = hidden_size self.linear = nn.Linear(hidden_size * 3, 1) self.Wr0 = nn.Linear(hidden_size, hidden_size, bias = False) self.Wr1 = nn.Linear(hidden_size, hidden_size, bias = False) self.rel_emb = nn.Embedding(2, hidden_size) def forward(self, Q, K, V, adj, s_mask): ''' imformation gatherer with linear attention :param Q: (B, D) # query utterance :param K: (B, N, D) # context :param V: (B, N, D) # context :param adj: (B, N) # the adj matrix of the i th node :param s_mask: (B, N) # :return: ''' rel_emb = self.rel_emb(s_mask) # (B, N, D) B = K.size()[0] N = K.size()[1] # print('Q',Q.size()) Q = Q.unsqueeze(1).expand(-1, N, -1) # (B, N, D); # print('K',K.size()) X = torch.cat((Q,K,rel_emb), dim = 2) # (B, N, 3D) # print('X',X.size()) alpha = self.linear(X).permute(0,2,1) #(B, 1, N) # print('alpha',alpha.size()) # print(alpha) adj = adj.unsqueeze(1) alpha = mask_logic(alpha, adj) # (B, 1, N) # print('alpha after mask',alpha.size()) # print(alpha) attn_weight = F.softmax(alpha, dim = 2) # (B, 1, N) # print('attn_weight',attn_weight.size()) # print(attn_weight) V0 = self.Wr0(V) # (B, N,D) V1 = self.Wr1(V) # (B, N, D) s_mask = s_mask.unsqueeze(2).float() V = V0 * s_mask + V1 * (1 - s_mask) attn_sum = torch.bmm(attn_weight, V).squeeze(1) # (B, D) # print('attn_sum',attn_sum.size()) return attn_weight, attn_sum class attentive_node_features(nn.Module): ''' Method to obtain attentive node features over the graph convoluted features ''' def __init__(self, hidden_size): super().__init__() self.transform = nn.Linear(hidden_size, hidden_size) def forward(self,features, lengths, nodal_att_type): ''' features : (B, N, V) lengths : (B, ) nodal_att_type : type of the final nodal attention ''' if nodal_att_type==None: return features batch_size = features.size(0) max_seq_len = features.size(1) padding_mask = [l*[1]+(max_seq_len-l)*[0] for l in lengths] padding_mask = torch.tensor(padding_mask).to(features) # (B, N) causal_mask = torch.ones(max_seq_len, max_seq_len).to(features) # (N, N) causal_mask = torch.tril(causal_mask).unsqueeze(0) # (1, N, N) if nodal_att_type=='global': mask = padding_mask.unsqueeze(1) elif nodal_att_type=='past': mask = padding_mask.unsqueeze(1)*causal_mask x = self.transform(features) # (B, N, V) temp = torch.bmm(x, features.permute(0,2,1)) #print(temp) alpha = F.softmax(torch.tanh(temp), dim=2) # (B, N, N) alpha_masked = alpha*mask # (B, N, N) alpha_sum = torch.sum(alpha_masked, dim=2, keepdim=True) # (B, N, 1) #print(alpha_sum) alpha = alpha_masked / alpha_sum # (B, N, N) attn_pool = torch.bmm(alpha, features) # (B, N, V) return attn_pool