|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch.autograd import Variable |
|
|
from torch.nn.utils.rnn import pad_sequence |
|
|
import numpy as np, itertools, random, copy, math |
|
|
|
|
|
|
|
|
class DiffLoss(nn.Module):
    """Difference-encouraging loss between two representations.

    Both inputs are flattened per sample, mean-centered over the batch and
    L2-normalised per row; the loss is the reciprocal of the mean row-wise
    L2 distance, so it decreases as the two representations diverge.
    """

    def __init__(self, args):
        # `args` is accepted for interface compatibility; it is not used.
        super(DiffLoss, self).__init__()

    def forward(self, input1, input2):
        """
        input1, input2 -> (batch, ...) tensors of identical shape.

        Returns a scalar tensor; +inf when the normalised representations
        are identical (zero mean distance).
        """
        batch_size = input1.size(0)
        input1 = input1.view(batch_size, -1)
        input2 = input2.view(batch_size, -1)

        # Zero-center each feature dimension over the batch.
        input1 = input1 - torch.mean(input1, dim=0, keepdim=True)
        input2 = input2 - torch.mean(input2, dim=0, keepdim=True)

        # Row-wise L2 normalisation; 1e-6 in the denominator avoids
        # division by zero for all-zero rows.
        input1_l2 = input1 / (torch.norm(input1, p=2, dim=1, keepdim=True) + 1e-6)
        input2_l2 = input2 / (torch.norm(input2, p=2, dim=1, keepdim=True) + 1e-6)

        norm_diff = torch.mean(torch.norm(input1_l2 - input2_l2, p=2, dim=1))
        # Guard the reciprocal: identical normalised inputs would divide by zero.
        if norm_diff.item() == 0:
            return torch.tensor(float('inf'), device=input1.device)
        return 1.0 / norm_diff
|
|
|
|
|
|
|
|
class MaskedNLLLoss(nn.Module):
    """NLL loss over padded sequences, normalised by the valid positions."""

    def __init__(self, weight=None):
        """
        weight -> optional per-class weight tensor, as in nn.NLLLoss.
        """
        super(MaskedNLLLoss, self).__init__()
        self.weight = weight
        self.loss = nn.NLLLoss(weight=weight,
                               reduction='sum')

    def forward(self, pred, target, mask):
        """
        pred -> batch*seq_len, n_classes  (log-probabilities)
        target -> batch*seq_len
        mask -> batch, seq_len  (1 for valid positions, 0 for padding)
        """
        mask_ = mask.view(-1, 1)  # -> batch*seq_len, 1
        if self.weight is None:
            # Plain average over the number of valid positions.
            loss = self.loss(pred * mask_, target) / torch.sum(mask)
        else:
            # Weighted average: normalise by the total class weight of
            # the valid positions.
            loss = self.loss(pred * mask_, target) \
                / torch.sum(self.weight[target] * mask_.squeeze())
        return loss
|
|
|
|
|
|
|
|
class MaskedMSELoss(nn.Module):
    """Sum-of-squares loss over padded sequences, averaged over the number
    of unmasked positions."""

    def __init__(self):
        super(MaskedMSELoss, self).__init__()
        self.loss = nn.MSELoss(reduction='sum')

    def forward(self, pred, target, mask):
        """
        pred -> batch*seq_len
        target -> batch*seq_len
        mask -> batch*seq_len  (1 for valid positions, 0 for padding)
        """
        masked_pred = pred * mask
        total = self.loss(masked_pred, target)
        return total / torch.sum(mask)
|
|
|
|
|
|
|
|
class UnMaskedWeightedNLLLoss(nn.Module):
    """NLL loss without masking; with class weights the sum is normalised
    by the total weight of the targets."""

    def __init__(self, weight=None):
        """
        weight -> optional per-class weight tensor, as in nn.NLLLoss.
        """
        super(UnMaskedWeightedNLLLoss, self).__init__()
        self.weight = weight
        self.loss = nn.NLLLoss(weight=weight,
                               reduction='sum')

    def forward(self, pred, target):
        """
        pred -> batch*seq_len, n_classes  (log-probabilities)
        target -> batch*seq_len
        """
        if self.weight is None:
            # Unnormalised sum, matching the original behaviour.
            loss = self.loss(pred, target)
        else:
            # Normalise by the total class weight of the targets.
            loss = self.loss(pred, target) \
                / torch.sum(self.weight[target])
        return loss
|
|
|
|
|
class GatedSelection(nn.Module):
    """Fuse two hidden vectors through a learned sigmoid gate."""

    def __init__(self, hidden_size):
        super().__init__()
        # NOTE: submodule creation order is kept so parameter
        # initialisation and state_dict keys are unchanged.
        self.context_trans = nn.Linear(hidden_size, hidden_size)
        self.linear1 = nn.Linear(hidden_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.fc = nn.Linear(hidden_size, hidden_size)
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()

    def forward(self, x1, x2):
        """Blend x1 with the transformed context x2 via a per-element gate."""
        context = self.context_trans(x2)
        gate = self.sigmoid(self.linear1(x1) + self.linear2(context))
        fused = gate * x1 + (1 - gate) * context
        return self.relu(self.fc(fused))
|
|
|
|
|
def mask_logic(alpha, adj):
    '''
    Mask attention scores with the adjacency indicator.

    :param alpha: raw attention scores
    :param adj: adjacency indicator (1 = keep, 0 = mask out)
    :return: scores with masked positions pushed towards -inf
             (so they vanish under a subsequent softmax)
    '''
    penalty = (1 - adj) * 1e30
    return alpha - penalty
|
|
|
|
|
class GatLinear(nn.Module):
    """Information gatherer with linear (concat + scorer) attention."""

    def __init__(self, hidden_size):
        super().__init__()
        self.linear = nn.Linear(hidden_size * 2, 1)

    def forward(self, Q, K, V, adj):
        '''
        :param Q: (B, D) # query utterance
        :param K: (B, N, D) # context
        :param V: (B, N, D) # context
        :param adj: (B, N) # the adj matrix of the i th node
        :return: attention weights (B, 1, N) and pooled context (B, D)
        '''
        num_ctx = K.size(1)

        # Pair the query with every context position and score each pair.
        query = Q.unsqueeze(1).expand(-1, num_ctx, -1)
        pair = torch.cat((query, K), dim=2)
        scores = self.linear(pair).permute(0, 2, 1)  # (B, 1, N)

        # Inlined mask_logic: drive adj == 0 positions to -inf pre-softmax.
        scores = scores - (1 - adj.unsqueeze(1)) * 1e30

        attn_weight = F.softmax(scores, dim=2)
        attn_sum = torch.bmm(attn_weight, V).squeeze(1)
        return attn_weight, attn_sum
|
|
|
|
|
class GatDot(nn.Module):
    """Information gatherer with dot-product attention."""

    def __init__(self, hidden_size):
        super().__init__()
        self.linear1 = nn.Linear(hidden_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)

    def forward(self, Q, K, V, adj):
        '''
        :param Q: (B, D) # query utterance
        :param K: (B, N, D) # context
        :param V: (B, N, D) # context
        :param adj: (B, N) # the adj matrix of the i th node
        :return: attention weights (B, 1, N) and pooled context (B, D)
        '''
        N = K.size()[1]  # context length; not used below (kept from original)

        query = self.linear1(Q).unsqueeze(2)   # (B, D, 1)
        keys = self.linear2(K)                 # (B, N, D)
        scores = torch.bmm(keys, query).permute(0, 2, 1)  # (B, 1, N)

        # Inlined mask_logic: drive adj == 0 positions to -inf pre-softmax.
        scores = scores - (1 - adj.unsqueeze(1)) * 1e30

        attn_weight = F.softmax(scores, dim=2)
        attn_sum = torch.bmm(attn_weight, V).squeeze(1)
        return attn_weight, attn_sum
|
|
|
|
|
class GatLinear_rel(nn.Module):
    """Linear attention gatherer that also conditions the score on a
    binary relation embedding."""

    def __init__(self, hidden_size):
        super().__init__()
        self.linear = nn.Linear(hidden_size * 3, 1)
        self.rel_emb = nn.Embedding(2, hidden_size)

    def forward(self, Q, K, V, adj, s_mask):
        '''
        :param Q: (B, D) # query utterance
        :param K: (B, N, D) # context
        :param V: (B, N, D) # context
        :param adj: (B, N) # the adj matrix of the i th node
        :param s_mask: (B, N) # binary relation indices for the embedding
        :return: attention weights (B, 1, N) and pooled context (B, D)
        '''
        rel = self.rel_emb(s_mask)  # (B, N, D)
        num_ctx = K.size(1)

        # Score [query ; key ; relation] triples with a single linear layer.
        query = Q.unsqueeze(1).expand(-1, num_ctx, -1)
        triple = torch.cat((query, K, rel), dim=2)
        scores = self.linear(triple).permute(0, 2, 1)  # (B, 1, N)

        # Inlined mask_logic: drive adj == 0 positions to -inf pre-softmax.
        scores = scores - (1 - adj.unsqueeze(1)) * 1e30

        attn_weight = F.softmax(scores, dim=2)
        attn_sum = torch.bmm(attn_weight, V).squeeze(1)
        return attn_weight, attn_sum
|
|
|
|
|
|
|
|
class GatDot_rel(nn.Module):
    """Dot-product attention gatherer with an additive score term from a
    binary relation embedding."""

    def __init__(self, hidden_size):
        super().__init__()
        self.linear1 = nn.Linear(hidden_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, 1)
        self.rel_emb = nn.Embedding(2, hidden_size)

    def forward(self, Q, K, V, adj, s_mask):
        '''
        :param Q: (B, D) # query utterance
        :param K: (B, N, D) # context
        :param V: (B, N, D) # context
        :param adj: (B, N) # the adj matrix of the i th node
        :param s_mask: (B, N) # relation mask
        :return: attention weights (B, 1, N) and pooled context (B, D)
        '''
        N = K.size()[1]  # context length; not used below (kept from original)

        rel = self.rel_emb(s_mask)              # (B, N, D)
        query = self.linear1(Q).unsqueeze(2)    # (B, D, 1)
        keys = self.linear2(K)                  # (B, N, D)
        rel_score = self.linear3(rel)           # (B, N, 1)

        scores = (torch.bmm(keys, query) + rel_score).permute(0, 2, 1)

        # Inlined mask_logic: drive adj == 0 positions to -inf pre-softmax.
        scores = scores - (1 - adj.unsqueeze(1)) * 1e30

        attn_weight = F.softmax(scores, dim=2)
        attn_sum = torch.bmm(attn_weight, V).squeeze(1)
        return attn_weight, attn_sum
|
|
|
|
|
|
|
|
class GAT_dialoggcn(nn.Module):
    '''
    H_i = alpha_ij(W_rH_j)
    alpha_ij = attention(H_i, H_j)
    '''
    def __init__(self, hidden_size):
        super().__init__()
        # NOTE: creation order (linear, then rel_emb) is kept so random
        # initialisation matches the original under a fixed seed.
        self.hidden_size = hidden_size
        self.linear = nn.Linear(hidden_size * 2, 1)
        # Two candidate relation matrices, selected per edge by a one-hot.
        self.rel_emb = nn.Parameter(torch.randn(2, hidden_size, hidden_size))

    def forward(self, Q, K, V, adj, s_mask_onehot):
        '''
        :param Q: (B, D) # query utterance
        :param K: (B, N, D) # context
        :param V: (B, N, D) # context
        :param adj: (B, N) # the adj matrix of the i th node
        :param s_mask_onehot: (B, N, 2) # one-hot relation indicator
        :return: attention weights (B, 1, N) and pooled context (B, D)
        '''
        B = K.size(0)
        N = K.size(1)

        # Linear attention over [query ; key] pairs.
        query = Q.unsqueeze(1).expand(-1, N, -1)
        pair = torch.cat((query, K), dim=2)
        scores = self.linear(pair).permute(0, 2, 1)  # (B, 1, N)

        # Inlined mask_logic: drive adj == 0 positions to -inf pre-softmax.
        scores = scores - (1 - adj.unsqueeze(1)) * 1e30
        attn_weight = F.softmax(scores, dim=2)

        # Pick a per-edge relation matrix W_r: mix the two candidate
        # matrices with the one-hot indicator via a batched matmul.
        D = self.rel_emb.size(2)
        rel_flat = self.rel_emb.unsqueeze(0).expand(B, -1, -1, -1).reshape((B, 2, D * D))
        Wr = torch.bmm(s_mask_onehot, rel_flat).reshape((B * N, D, D))

        # Apply W_r to each value vector: (1, D) x (D, D) per edge.
        v_rows = V.unsqueeze(2).reshape((B * N, 1, -1))
        v_rel = torch.bmm(v_rows, Wr).reshape((B, N, -1))

        attn_sum = torch.bmm(attn_weight, v_rel).squeeze(1)
        return attn_weight, attn_sum
|
|
|
|
|
|
|
|
class GAT_dialoggcn_v1(nn.Module):
    '''
    use linear to avoid OOM
    H_i = alpha_ij(W_rH_j)
    alpha_ij = attention(H_i, H_j)
    '''
    def __init__(self, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.linear = nn.Linear(hidden_size * 2, 1)
        # Two relation-specific value projections, selected per edge.
        self.Wr0 = nn.Linear(hidden_size, hidden_size, bias = False)
        self.Wr1 = nn.Linear(hidden_size, hidden_size, bias = False)

    def forward(self, Q, K, V, adj, s_mask):
        '''
        :param Q: (B, D) # query utterance
        :param K: (B, N, D) # context
        :param V: (B, N, D) # context
        :param adj: (B, N) # the adj matrix of the i th node
        :param s_mask: (B, N) # binary relation selector
        :return: attention weights (B, 1, N) and pooled context (B, D)
        '''
        B = K.size(0)
        N = K.size(1)

        # Linear attention over [query ; key] pairs.
        query = Q.unsqueeze(1).expand(-1, N, -1)
        scores = self.linear(torch.cat((query, K), dim=2)).permute(0, 2, 1)

        # Inlined mask_logic: drive adj == 0 positions to -inf pre-softmax.
        scores = scores - (1 - adj.unsqueeze(1)) * 1e30
        attn_weight = F.softmax(scores, dim=2)

        # Select Wr0 where s_mask == 1 and Wr1 where s_mask == 0.
        same = self.Wr0(V)
        other = self.Wr1(V)
        sel = s_mask.unsqueeze(2).float()
        v_rel = same * sel + other * (1 - sel)

        attn_sum = torch.bmm(attn_weight, v_rel).squeeze(1)
        return attn_weight, attn_sum
|
|
|
|
|
|
|
|
class GAT_dialoggcn_v2(nn.Module):
    '''
    use linear to avoid OOM
    H_i = alpha_ij(W_rH_j)
    alpha_ij = attention(H_i, H_j, rel)
    '''
    def __init__(self, hidden_size):
        super().__init__()
        # NOTE: submodule creation order kept so initialisation matches
        # the original under a fixed seed.
        self.hidden_size = hidden_size
        self.linear = nn.Linear(hidden_size * 3, 1)
        self.Wr0 = nn.Linear(hidden_size, hidden_size, bias = False)
        self.Wr1 = nn.Linear(hidden_size, hidden_size, bias = False)
        self.rel_emb = nn.Embedding(2, hidden_size)

    def forward(self, Q, K, V, adj, s_mask):
        '''
        :param Q: (B, D) # query utterance
        :param K: (B, N, D) # context
        :param V: (B, N, D) # context
        :param adj: (B, N) # the adj matrix of the i th node
        :param s_mask: (B, N) # binary relation indices
        :return: attention weights (B, 1, N) and pooled context (B, D)
        '''
        rel = self.rel_emb(s_mask)  # (B, N, D)
        B = K.size(0)
        N = K.size(1)

        # Linear attention over [query ; key ; relation] triples.
        query = Q.unsqueeze(1).expand(-1, N, -1)
        scores = self.linear(torch.cat((query, K, rel), dim=2)).permute(0, 2, 1)

        # Inlined mask_logic: drive adj == 0 positions to -inf pre-softmax.
        scores = scores - (1 - adj.unsqueeze(1)) * 1e30
        attn_weight = F.softmax(scores, dim=2)

        # Select Wr0 where s_mask == 1 and Wr1 where s_mask == 0.
        same = self.Wr0(V)
        other = self.Wr1(V)
        sel = s_mask.unsqueeze(2).float()
        v_rel = same * sel + other * (1 - sel)

        attn_sum = torch.bmm(attn_weight, v_rel).squeeze(1)
        return attn_weight, attn_sum
|
|
|
|
|
|
|
|
class attentive_node_features(nn.Module):
    '''
    Method to obtain attentive node features over the graph convoluted features
    '''
    def __init__(self, hidden_size):
        super().__init__()
        self.transform = nn.Linear(hidden_size, hidden_size)

    def forward(self, features, lengths, nodal_att_type):
        '''
        Apply a final self-attention pooling over node features.

        features : (B, N, V)
        lengths : (B, )  valid sequence length per batch element
        nodal_att_type : None (no-op), 'global', or 'past'

        Returns (B, N, V) attention-pooled features, or `features`
        unchanged when nodal_att_type is None.

        Raises ValueError for an unknown nodal_att_type (previously this
        fell through to a NameError on the undefined mask).
        '''
        if nodal_att_type is None:
            return features

        max_seq_len = features.size(1)
        # 1 for valid positions, 0 for padding.
        padding_mask = [l * [1] + (max_seq_len - l) * [0] for l in lengths]
        padding_mask = torch.tensor(padding_mask).to(features)
        # Lower-triangular mask: position i may only attend to j <= i.
        causal_mask = torch.ones(max_seq_len, max_seq_len).to(features)
        causal_mask = torch.tril(causal_mask).unsqueeze(0)

        if nodal_att_type == 'global':
            mask = padding_mask.unsqueeze(1)
        elif nodal_att_type == 'past':
            mask = padding_mask.unsqueeze(1) * causal_mask
        else:
            raise ValueError(f"unknown nodal_att_type: {nodal_att_type!r}")

        x = self.transform(features)
        temp = torch.bmm(x, features.permute(0, 2, 1))

        # Mask the softmax weights, then renormalise over the kept positions.
        alpha = F.softmax(torch.tanh(temp), dim=2)
        alpha_masked = alpha * mask
        alpha_sum = torch.sum(alpha_masked, dim=2, keepdim=True)
        alpha = alpha_masked / alpha_sum

        return torch.bmm(alpha, features)
|
|
|
|
|
|
|
|
|