""" |
|
|
Pytorch implementation of basic sequence to Sequence modules. |
|
|
""" |
|
|
|
|
|
import logging
import torch
import torch.nn as nn
import math
import numpy as np

import stanza.models.common.seq2seq_constant as constant

logger = logging.getLogger('stanza')

class BasicAttention(nn.Module):
    """
    A basic MLP attention layer.
    """

    def __init__(self, dim):
        super(BasicAttention, self).__init__()
        self.linear_in = nn.Linear(dim, dim, bias=False)
        self.linear_c = nn.Linear(dim, dim)
        self.linear_v = nn.Linear(dim, 1, bias=False)
        self.linear_out = nn.Linear(dim * 2, dim, bias=False)
        self.tanh = nn.Tanh()
        self.sm = nn.Softmax(dim=1)

    def forward(self, input, context, mask=None, attn_only=False):
        """
        input: batch x dim
        context: batch x sourceL x dim
        """
        batch_size = context.size(0)
        source_len = context.size(1)
        dim = context.size(2)
        target = self.linear_in(input)
        source = self.linear_c(context.contiguous().view(-1, dim)).view(batch_size, source_len, dim)
        attn = target.unsqueeze(1).expand_as(context) + source
        attn = self.tanh(attn)
        attn = self.linear_v(attn.view(-1, dim)).view(batch_size, source_len)

        if mask is not None:
            attn.masked_fill_(mask, -constant.INFINITY_NUMBER)

        attn = self.sm(attn)
        if attn_only:
            return attn

        weighted_context = torch.bmm(attn.unsqueeze(1), context).squeeze(1)
        h_tilde = torch.cat((weighted_context, input), 1)
        h_tilde = self.tanh(self.linear_out(h_tilde))
        return h_tilde, attn
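
# In short: for a query q (batch x dim) and context vectors c_i, BasicAttention
# scores each source position with an MLP,
#   score_i = v^T tanh(W_in q + W_c c_i),
# softmaxes the scores, and returns tanh(W_out [sum_i a_i c_i ; q]) together with
# the attention weights. A hypothetical call (tensor names are illustrative only):
#   layer = BasicAttention(200)
#   h_tilde, weights = layer(query, encoder_states)  # (batch, 200), (batch, src_len)
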
class SoftDotAttention(nn.Module):
    """Soft Dot Attention.

    Ref: http://www.aclweb.org/anthology/D15-1166
    Adapted from PyTorch OpenNMT.
    """

    def __init__(self, dim):
        """Initialize layer."""
        super(SoftDotAttention, self).__init__()
        self.linear_in = nn.Linear(dim, dim, bias=False)
        self.sm = nn.Softmax(dim=1)
        self.linear_out = nn.Linear(dim * 2, dim, bias=False)
        self.tanh = nn.Tanh()
        self.mask = None

    def forward(self, input, context, mask=None, attn_only=False, return_logattn=False):
        """Propagate input through the network.

        input: batch x dim
        context: batch x sourceL x dim
        """
        target = self.linear_in(input).unsqueeze(2)  # batch x dim x 1

        attn = torch.bmm(context, target).squeeze(2)  # batch x sourceL

        if mask is not None:
            # set the scores of padded positions to -inf before normalizing
            assert mask.size() == attn.size(), "Mask size must match the attention size!"
            attn.masked_fill_(mask, -constant.INFINITY_NUMBER)

        if return_logattn:
            attn = torch.log_softmax(attn, 1)
            attn_w = torch.exp(attn)
        else:
            attn = self.sm(attn)
            attn_w = attn
        if attn_only:
            return attn

        attn3 = attn_w.view(attn_w.size(0), 1, attn_w.size(1))  # batch x 1 x sourceL

        weighted_context = torch.bmm(attn3, context).squeeze(1)  # batch x dim
        h_tilde = torch.cat((weighted_context, input), 1)
        h_tilde = self.tanh(self.linear_out(h_tilde))
        return h_tilde, attn
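
# In short: SoftDotAttention uses the bilinear ("general") score of Luong et al.
# (2015), score_i = c_i^T W q, followed by a softmax and the output projection
# tanh(W_out [sum_i a_i c_i ; q]). With return_logattn=True the returned attention
# is the log of the distribution; the weighted sum still uses the probabilities.
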
class LinearAttention(nn.Module):
    """ A linear attention form, inspired by BiDAF:
    a = W (u; v; u o v)
    """

    def __init__(self, dim):
        super(LinearAttention, self).__init__()
        self.linear = nn.Linear(dim * 3, 1, bias=False)
        self.linear_out = nn.Linear(dim * 2, dim, bias=False)
        self.sm = nn.Softmax(dim=1)
        self.tanh = nn.Tanh()
        self.mask = None

    def forward(self, input, context, mask=None, attn_only=False):
        """
        input: batch x dim
        context: batch x sourceL x dim
        """
        batch_size = context.size(0)
        source_len = context.size(1)
        dim = context.size(2)
        u = input.unsqueeze(1).expand_as(context).contiguous().view(-1, dim)
        v = context.contiguous().view(-1, dim)
        attn_in = torch.cat((u, v, u.mul(v)), 1)
        attn = self.linear(attn_in).view(batch_size, source_len)

        if mask is not None:
            assert mask.size() == attn.size(), "Mask size must match the attention size!"
            attn.masked_fill_(mask, -constant.INFINITY_NUMBER)

        attn = self.sm(attn)
        if attn_only:
            return attn

        attn3 = attn.view(batch_size, 1, source_len)

        weighted_context = torch.bmm(attn3, context).squeeze(1)
        h_tilde = torch.cat((weighted_context, input), 1)
        h_tilde = self.tanh(self.linear_out(h_tilde))
        return h_tilde, attn
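
# In short: LinearAttention scores each source position with a single linear layer
# over the concatenation [q ; c_i ; q * c_i] (elementwise product), following the
# BiDAF similarity function, and then combines the weighted context with the query
# exactly as the other attention variants do.
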
class DeepAttention(nn.Module):
    """ A deep attention form, invented by Robert:
    u = ReLU(Wx)
    v = ReLU(Wy)
    a = V.(u o v)
    """

    def __init__(self, dim):
        super(DeepAttention, self).__init__()
        self.linear_in = nn.Linear(dim, dim, bias=False)
        self.linear_v = nn.Linear(dim, 1, bias=False)
        self.linear_out = nn.Linear(dim * 2, dim, bias=False)
        self.relu = nn.ReLU()
        self.sm = nn.Softmax(dim=1)
        self.tanh = nn.Tanh()
        self.mask = None

    def forward(self, input, context, mask=None, attn_only=False):
        """
        input: batch x dim
        context: batch x sourceL x dim
        """
        batch_size = context.size(0)
        source_len = context.size(1)
        dim = context.size(2)
        u = input.unsqueeze(1).expand_as(context).contiguous().view(-1, dim)
        u = self.relu(self.linear_in(u))
        v = self.relu(self.linear_in(context.contiguous().view(-1, dim)))
        attn = self.linear_v(u.mul(v)).view(batch_size, source_len)

        if mask is not None:
            assert mask.size() == attn.size(), "Mask size must match the attention size!"
            attn.masked_fill_(mask, -constant.INFINITY_NUMBER)

        attn = self.sm(attn)
        if attn_only:
            return attn

        attn3 = attn.view(batch_size, 1, source_len)

        weighted_context = torch.bmm(attn3, context).squeeze(1)
        h_tilde = torch.cat((weighted_context, input), 1)
        h_tilde = self.tanh(self.linear_out(h_tilde))
        return h_tilde, attn
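
# In short: DeepAttention projects both sides with the same matrix before scoring,
#   score_i = v^T (ReLU(W q) * ReLU(W c_i)),
# i.e. the two W's in the class docstring share parameters (self.linear_in) in this
# implementation.
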
class LSTMAttention(nn.Module):
    r"""A long short-term memory (LSTM) cell with attention."""

    def __init__(self, input_size, hidden_size, batch_first=True, attn_type='soft'):
        """Initialize params."""
        super(LSTMAttention, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.batch_first = batch_first

        self.lstm_cell = nn.LSTMCell(input_size, hidden_size)

        if attn_type == 'soft':
            self.attention_layer = SoftDotAttention(hidden_size)
        elif attn_type == 'mlp':
            self.attention_layer = BasicAttention(hidden_size)
        elif attn_type == 'linear':
            self.attention_layer = LinearAttention(hidden_size)
        elif attn_type == 'deep':
            self.attention_layer = DeepAttention(hidden_size)
        else:
            raise Exception("Unsupported LSTM attention type: {}".format(attn_type))
        logger.debug("Using {} attention for LSTM.".format(attn_type))

    def forward(self, input, hidden, ctx, ctx_mask=None, return_logattn=False):
        """Propagate input through the network."""
        if self.batch_first:
            input = input.transpose(0, 1)

        output = []
        attn = []
        steps = range(input.size(0))
        for i in steps:
            hidden = self.lstm_cell(input[i], hidden)
            hy, cy = hidden
            if return_logattn:
                # only SoftDotAttention accepts return_logattn; requesting log
                # attention with any other attention type is an error
                h_tilde, alpha = self.attention_layer(hy, ctx, mask=ctx_mask, return_logattn=True)
            else:
                h_tilde, alpha = self.attention_layer(hy, ctx, mask=ctx_mask)
            output.append(h_tilde)
            attn.append(alpha)
        output = torch.cat(output, 0).view(input.size(0), *output[0].size())

        if self.batch_first:
            output = output.transpose(0, 1)

        if return_logattn:
            attn = torch.stack(attn, 0)
            if self.batch_first:
                attn = attn.transpose(0, 1)
            return output, hidden, attn

        return output, hidden
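
if __name__ == '__main__':
    # Minimal smoke test, not part of the library: it builds an LSTMAttention
    # decoder with arbitrary sizes, runs it on random inputs with an all-False
    # padding mask, and prints the output shapes. All tensor names and sizes
    # here are illustrative only.
    batch_size, src_len, tgt_len, input_size, hidden_size = 2, 5, 3, 8, 16
    decoder = LSTMAttention(input_size, hidden_size, batch_first=True, attn_type='soft')
    dec_inputs = torch.randn(batch_size, tgt_len, input_size)
    enc_outputs = torch.randn(batch_size, src_len, hidden_size)
    hidden0 = (torch.zeros(batch_size, hidden_size), torch.zeros(batch_size, hidden_size))
    src_mask = torch.zeros(batch_size, src_len, dtype=torch.bool)  # True marks padding
    output, (hn, cn) = decoder(dec_inputs, hidden0, enc_outputs, ctx_mask=src_mask)
    print("output:", tuple(output.size()), "hidden:", tuple(hn.size()))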