import math

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint


class MultiHeadAttention(nn.Module):
    def __init__(self, heads, hidden_dim):
        super(MultiHeadAttention, self).__init__()
        assert hidden_dim % heads == 0
        self.heads = heads
        head_dim = hidden_dim // heads
        # Scaling factor for the scaled dot-product attention scores.
        self.alpha = 1 / math.sqrt(head_dim)
        # Per-head query projection and the final output projection.
        self.nn_Q = nn.Parameter(torch.Tensor(heads, hidden_dim, head_dim))
        self.nn_O = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, q, K, V, mask):
        # q: (batch, query_num, hidden_dim); K, V: (heads, batch, value_num, head_dim);
        # mask: boolean (batch, query_num, value_num), True at positions to exclude.
        batch_size, query_num, hidden_dim = q.size()
        size = (self.heads, batch_size, query_num, -1)
        # Project the queries for all heads with a single batched matmul.
        q = q.reshape(-1, hidden_dim)
        Q = torch.matmul(q, self.nn_Q).view(size)
        value_num = V.size(2)
        heads_batch = self.heads * batch_size
        Q = Q.view(heads_batch, query_num, -1)
        K = K.view(heads_batch, value_num, -1).transpose(1, 2)
        # Start from -inf at the masked positions, then add the scaled scores in place.
        S = masked_tensor(mask, self.heads)
        S = S.view(heads_batch, query_num, value_num)
        S.baddbmm_(Q, K, alpha=self.alpha)
        S = S.view(self.heads, batch_size, query_num, value_num)
        S = F.softmax(S, dim=-1)
        # Weighted sum of the values, then concatenate the heads and project.
        x = torch.matmul(S, V).permute(1, 2, 0, 3)
        x = x.reshape(batch_size, query_num, -1)
        x = torch.matmul(x, self.nn_O)
        return x
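
# A minimal usage sketch for MultiHeadAttention (not part of the original module).
# The shapes are inferred from the forward pass above: q is (batch, query_num,
# hidden_dim), K and V are already projected per head to
# (heads, batch, value_num, head_dim), and mask is boolean (batch, query_num, value_num).
#
#     heads, hidden_dim, head_dim = 4, 64, 16
#     mha = MultiHeadAttention(heads, hidden_dim)
#     q = torch.randn(2, 3, hidden_dim)
#     K = torch.randn(heads, 2, 5, head_dim)
#     V = torch.randn(heads, 2, 5, head_dim)
#     mask = torch.zeros(2, 3, 5, dtype=torch.bool)
#     out = mha(q, K, V, mask)  # -> (2, 3, hidden_dim)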


class Decode(nn.Module):
    def __init__(self, nn_args):
        super(Decode, self).__init__()
        self.nn_args = nn_args
        heads = nn_args['decode_atten_heads']
        hidden_dim = nn_args['decode_hidden_dim']
        self.heads = heads
        self.alpha = 1 / math.sqrt(hidden_dim)
        if heads > 0:
            assert hidden_dim % heads == 0
            head_dim = hidden_dim // heads
            # Per-head key/value projections consumed by the multi-head attention block.
            self.nn_K = nn.Parameter(torch.Tensor(heads, hidden_dim, head_dim))
            self.nn_V = nn.Parameter(torch.Tensor(heads, hidden_dim, head_dim))
            self.nn_mha = MultiHeadAttention(heads, hidden_dim)
        decode_rnn = nn_args.setdefault('decode_rnn', 'LSTM')
        assert decode_rnn in ('GRU', 'LSTM', 'NONE')
        if decode_rnn == 'GRU':
            self.nn_rnn_cell = nn.GRUCell(hidden_dim, hidden_dim)
        elif decode_rnn == 'LSTM':
            self.nn_rnn_cell = nn.LSTMCell(hidden_dim, hidden_dim)
        else:
            self.nn_rnn_cell = None
        self.vars_dim = sum(nn_args['variable_dim'].values())
        if self.vars_dim > 0:
            atten_type = nn_args.setdefault('decode_atten_type', 'add')
            assert atten_type == 'add', "must be additive attention when vars_dim > 0, got {}".format(atten_type)
            # Projection of the variable features into the hidden space.
            self.nn_A = nn.Parameter(torch.Tensor(self.vars_dim, hidden_dim))
            self.nn_B = nn.Parameter(torch.Tensor(hidden_dim))
        else:
            atten_type = nn_args.setdefault('decode_atten_type', 'prod')
        if atten_type == 'add':
            self.nn_W = nn.Parameter(torch.Tensor(hidden_dim))
        else:
            self.nn_W = None
        for param in self.parameters():
            stdv = 1 / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)
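
    # A sketch of the nn_args dictionary read above; the keys come from __init__
    # and the concrete values here are placeholders, not from the original source:
    #
    #     nn_args = {
    #         'decode_atten_heads': 4,       # 0 disables the multi-head attention block
    #         'decode_hidden_dim': 64,
    #         'decode_rnn': 'LSTM',          # 'GRU', 'LSTM' or 'NONE'
    #         'variable_dim': {},            # per-variable feature dims; sum > 0 forces 'add'
    #         'decode_atten_type': 'prod',   # 'add' or 'prod'
    #     }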

    def forward(self, X, K, V, query, state1, state2, varfeat, mask, chosen, sample_p, clip, mode, memopt=0):
        # memopt > 0 enables gradient checkpointing of progressively more sub-steps
        # during training to trade compute for memory.
        if self.training and memopt > 2:
            state1, state2 = checkpoint(self.rnn_step, query, state1, state2)
        else:
            state1, state2 = self.rnn_step(query, state1, state2)
        query = state1
        NP = X.size(0)
        NR = query.size(0) // NP
        batch_size = query.size(0)
        if self.heads > 0:
            # Refine the query with multi-head attention over the precomputed K, V.
            query = query.view(NP, NR, -1)
            if self.training and memopt > 1:
                query = checkpoint(self.nn_mha, query, K, V, mask)
            else:
                query = self.nn_mha(query, K, V, mask)
            query = query.view(batch_size, -1)
        if self.nn_W is None:
            # Dot-product ('prod') attention over the node embeddings X.
            query = query.view(NP, NR, -1)
            logit = masked_tensor(mask, 1)
            logit = logit.view(NP, NR, -1)
            X = X.permute(0, 2, 1)
            logit.baddbmm_(query, X, alpha=self.alpha)
            logit = logit.view(batch_size, -1)
        else:
            # Additive ('add') attention, optionally conditioned on variable features.
            if self.training and self.vars_dim > 0 and memopt > 0:
                logit = checkpoint(self.atten, query, X, varfeat, mask)
            else:
                logit = self.atten(query, X, varfeat, mask)
        chosen_p = choose(logit, chosen, sample_p, clip, mode)
        return state1, state2, chosen_p

    def rnn_step(self, query, state1, state2):
        if isinstance(self.nn_rnn_cell, nn.GRUCell):
            state1 = self.nn_rnn_cell(query, state1)
        elif isinstance(self.nn_rnn_cell, nn.LSTMCell):
            state1, state2 = self.nn_rnn_cell(query, (state1, state2))
        return state1, state2

    def atten(self, query, keyvalue, varfeat, mask):
        if self.vars_dim > 0:
            # Project the variable features into the hidden space first.
            varfeat = vfaddmm(varfeat, mask, self.nn_A, self.nn_B)
        return atten(query, keyvalue, varfeat, mask, self.nn_W)


def choose(logit, chosen, sample_p, clip, mode):
    # Clip the logits with tanh while keeping masked entries at -inf.
    # sample_p is not used in this function.
    mask = logit == -math.inf
    logit = torch.tanh(logit) * clip
    logit[mask] = -math.inf
    if mode == 0:
        # mode 0: score the indices already stored in `chosen`.
        pass
    elif mode == 1:
        # mode 1: greedy decoding.
        chosen[:] = logit.argmax(1)
    elif mode == 2:
        # mode 2: sample proportionally to the exponentiated clipped logits.
        p = logit.exp()
        chosen[:] = torch.multinomial(p, 1).squeeze(1)
    else:
        raise ValueError('unknown decode mode: {}'.format(mode))
    # Log-probability of the chosen index in each row.
    logp = logit.log_softmax(1)
    logp = logp.gather(1, chosen[:, None])
    logp = logp.squeeze(1)
    return logp
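
# A small illustration (not in the original code) of how `choose` behaves:
# for modes 1 (greedy) and 2 (sampling) the `chosen` tensor is filled in place,
# and the return value is the per-row log-probability of the chosen index.
#
#     logit = torch.tensor([[0.5, -math.inf, 1.2]])
#     chosen = torch.zeros(1, dtype=torch.long)
#     logp = choose(logit, chosen, None, 10.0, 1)  # greedy -> chosen == tensor([2])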


def atten(query, keyvalue, varfeat, mask, weight):
    # Additive attention: score = tanh(keyvalue + varfeat + query) . weight,
    # with -inf kept at the masked positions.
    batch_size = query.size(0)
    NP, NK, ND = keyvalue.size()
    query = query.view(NP, -1, 1, ND)
    varfeat = varfeat.view(NP, -1, NK, ND)
    keyvalue = keyvalue[:, None, :, :]
    keyvalue = keyvalue + varfeat + query
    keyvalue = torch.tanh(keyvalue)
    keyvalue = keyvalue.view(-1, ND)
    logit = masked_tensor(mask, 1).view(-1)
    logit.addmv_(keyvalue, weight)
    return logit.view(batch_size, -1)


def masked_tensor(mask, heads):
    # Return a float tensor of shape (heads, *mask.shape) that is zero everywhere
    # and -inf where `mask` is True, for use as an additive attention mask.
    size = list(mask.size())
    size.insert(0, heads)
    mask = mask[None].expand(size)
    result = mask.new_zeros(size, dtype=torch.float32)
    result[mask] = -math.inf
    return result


def vfaddmm(varfeat, mask, A, B):
    # Project variable features from (batch, vars_dim, NK) to (batch, NK, hidden_dim);
    # `mask` is accepted but not used here.
    varfeat = varfeat.permute(0, 2, 1)
    return F.linear(varfeat, A.permute(1, 0), B)
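

if __name__ == '__main__':
    # Minimal smoke test, not part of the original module; the sizes and the
    # nn_args values below are made up purely for illustration.
    heads, hidden_dim, head_dim = 2, 8, 4
    mha = MultiHeadAttention(heads, hidden_dim)
    q = torch.randn(3, 5, hidden_dim)
    K = torch.randn(heads, 3, 7, head_dim)
    V = torch.randn(heads, 3, 7, head_dim)
    mask = torch.zeros(3, 5, 7, dtype=torch.bool)
    print(mha(q, K, V, mask).shape)  # torch.Size([3, 5, 8])

    # masked_tensor: zeros with -inf at masked positions, one copy per head.
    print(masked_tensor(torch.tensor([[False, True]]), 2))

    # Construct the decoder with a hypothetical configuration.
    decoder = Decode({
        'decode_atten_heads': heads,
        'decode_hidden_dim': hidden_dim,
        'decode_rnn': 'LSTM',
        'variable_dim': {},
    })
    print(decoder)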