from typing import List, Optional, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class TransformerDecoder(nn.Module):
    """Transformer decoder with batched beam-search decoding."""

    def __init__(
            self, sos_id, eos_id, pad_id, odim,
            n_layers, n_head, d_model,
            residual_dropout=0.1, pe_maxlen=5000):
        super().__init__()
        self.INF = 1e10

        self.pad_id = pad_id
        self.sos_id = sos_id
        self.eos_id = eos_id
        self.n_layers = n_layers

        # Token embedding, sinusoidal positional encoding, and residual dropout.
        self.tgt_word_emb = nn.Embedding(odim, d_model, padding_idx=self.pad_id)
        self.positional_encoding = PositionalEncoding(d_model, max_len=pe_maxlen)
        self.dropout = nn.Dropout(residual_dropout)

        self.layer_stack = nn.ModuleList(
            [DecoderLayer(d_model, n_head, residual_dropout) for _ in range(n_layers)])

        self.tgt_word_prj = nn.Linear(d_model, odim, bias=False)
        self.layer_norm_out = nn.LayerNorm(d_model)

        # Tie the output projection to the embedding and scale embeddings by sqrt(d_model).
        self.tgt_word_prj.weight = self.tgt_word_emb.weight
        self.scale = d_model ** 0.5

    def batch_beam_search(self, encoder_outputs, src_masks,
                          beam_size=1, nbest=1, decode_max_len=0,
                          softmax_smoothing=1.0, length_penalty=0.0, eos_penalty=1.0):
        """Batched beam search; returns the nbest hypotheses per utterance."""
        B = beam_size
        N, Ti, H = encoder_outputs.size()
        device = encoder_outputs.device
        maxlen = decode_max_len if decode_max_len > 0 else Ti
        assert 0.0 < eos_penalty <= 1.0

        # Give every beam its own copy of the encoder outputs and source mask.
        encoder_outputs = encoder_outputs.unsqueeze(1).repeat(1, B, 1, 1).view(N * B, Ti, H)
        src_mask = src_masks.unsqueeze(1).repeat(1, B, 1, 1).view(N * B, -1, Ti)
        ys = torch.ones(N * B, 1).fill_(self.sos_id).long().to(device)
        caches: List[Optional[Tensor]] = [None for _ in range(self.n_layers)]

        # Only the first beam of each utterance starts at score 0; the others start
        # at -INF so the first expansion selects B distinct continuations.
        scores = torch.tensor([0.0] + [-self.INF] * (B - 1)).float().to(device)
        scores = scores.repeat(N).view(N * B, 1)
        is_finished = torch.zeros_like(scores)

        for t in range(maxlen):
            tgt_mask = self.ignored_target_position_is_0(ys, self.pad_id)

            # Scaled embedding plus positional encoding for the whole prefix.
            dec_output = self.dropout(
                self.tgt_word_emb(ys) * self.scale +
                self.positional_encoding(ys))

            for i, dec_layer in enumerate(self.layer_stack):
                dec_output = dec_layer.forward(
                    dec_output, encoder_outputs,
                    tgt_mask, src_mask,
                    cache=caches[i])
                caches[i] = dec_output

            dec_output = self.layer_norm_out(dec_output)

            # Log-probabilities for the newest position only.
            t_logit = self.tgt_word_prj(dec_output[:, -1])
            t_scores = F.log_softmax(t_logit / softmax_smoothing, dim=-1)

            if eos_penalty != 1.0:
                t_scores[:, self.eos_id] *= eos_penalty

            t_topB_scores, t_topB_ys = torch.topk(t_scores, k=B, dim=1)
            t_topB_scores = self.set_finished_beam_score_to_zero(t_topB_scores, is_finished)
            t_topB_ys = self.set_finished_beam_y_to_eos(t_topB_ys, is_finished)

            # Accumulate scores and keep the best B of the B*B candidates per utterance.
            scores = scores + t_topB_scores
            scores = scores.view(N, B * B)
            scores, topB_score_ids = torch.topk(scores, k=B, dim=1)
            scores = scores.view(-1, 1)

            # Map each surviving candidate back to the beam (row of ys) it extends.
            topB_row_number_in_each_B_rows_of_ys = torch.div(
                topB_score_ids, B, rounding_mode='floor').view(N * B)
            stride = B * torch.arange(N).view(N, 1).repeat(1, B).view(N * B).to(device)
            topB_row_number_in_ys = topB_row_number_in_each_B_rows_of_ys.long() + stride.long()

            # Reorder the partial hypotheses and append the newly selected tokens.
            ys = ys[topB_row_number_in_ys]
            t_ys = torch.gather(t_topB_ys.view(N, B * B), dim=1, index=topB_score_ids).view(N * B, 1)
            ys = torch.cat((ys, t_ys), dim=1)

            # Reorder the per-layer caches so they follow the surviving beams.
            new_caches: List[Optional[Tensor]] = []
            for cache in caches:
                if cache is not None:
                    new_caches.append(cache[topB_row_number_in_ys])
            caches = new_caches

            # Stop early once every beam has emitted <eos>.
            is_finished = t_ys.eq(self.eos_id)
            if is_finished.sum().item() == N * B:
                break

        # Rank the beams, optionally normalising scores by hypothesis length.
        scores = scores.view(N, B)
        ys = ys.view(N, B, -1)
        ys_lengths = self.get_ys_lengths(ys)
        if length_penalty > 0.0:
            penalty = torch.pow((5 + ys_lengths.float()) / (5.0 + 1), length_penalty)
            scores /= penalty
        nbest_scores, nbest_ids = torch.topk(scores, k=int(nbest), dim=1)
        nbest_scores = -1.0 * nbest_scores
        index = nbest_ids + B * torch.arange(N).view(N, 1).to(device).long()
        nbest_ys = ys.view(N * B, -1)[index.view(-1)]
        nbest_ys = nbest_ys.view(N, nbest_ids.size(1), -1)
        nbest_ys_lengths = ys_lengths.view(N * B)[index.view(-1)].view(N, -1)

        # Package the n-best hypotheses, dropping the leading <sos> token.
        nbest_hyps: List[List[Dict[str, Tensor]]] = []
        for n in range(N):
            n_nbest_hyps: List[Dict[str, Tensor]] = []
            for i in range(nbest_ys.size(1)):
                new_hyp = {
                    "yseq": nbest_ys[n, i, 1:nbest_ys_lengths[n, i]]
                }
                n_nbest_hyps.append(new_hyp)
            nbest_hyps.append(n_nbest_hyps)
        return nbest_hyps

    def ignored_target_position_is_0(self, padded_targets, ignore_id):
        # Combine the padding mask with a causal (no-peek-ahead) mask.
        mask = torch.ne(padded_targets, ignore_id)
        mask = mask.unsqueeze(dim=1)
        T = padded_targets.size(-1)
        upper_tri_0_mask = self.upper_triangular_is_0(T).unsqueeze(0).to(mask.device)
        return mask.to(torch.uint8) & upper_tri_0_mask

    def upper_triangular_is_0(self, size):
        # Lower-triangular ones: position (i, j) is 1 iff j <= i, so the strict
        # upper triangle (future positions) is 0.
        ones = torch.ones(size, size)
        tri_left_ones = torch.tril(ones)
        return tri_left_ones.to(torch.uint8)

    def set_finished_beam_score_to_zero(self, scores, is_finished):
        # A finished beam keeps exactly one copy of itself: its first expansion
        # adds 0, the remaining B-1 expansions add -INF.
        NB, B = scores.size()
        is_finished = is_finished.float()
        mask_score = torch.tensor([0.0] + [-self.INF] * (B - 1)).float().to(scores.device)
        mask_score = mask_score.view(1, B).repeat(NB, 1)
        return scores * (1 - is_finished) + mask_score * is_finished

    def set_finished_beam_y_to_eos(self, ys, is_finished):
        # Finished beams keep emitting <eos>.
        is_finished = is_finished.long()
        return ys * (1 - is_finished) + self.eos_id * is_finished

    def get_ys_lengths(self, ys):
        # Hypothesis length = number of non-<eos> tokens (the leading <sos> counts).
        N, B, Tmax = ys.size()
        ys_lengths = torch.sum(torch.ne(ys, self.eos_id), dim=-1)
        return ys_lengths.int()


class DecoderLayer(nn.Module):
    def __init__(self, d_model, n_head, dropout):
        super().__init__()
        self.self_attn_norm = nn.LayerNorm(d_model)
        self.self_attn = DecoderMultiHeadAttention(d_model, n_head, dropout)

        self.cross_attn_norm = nn.LayerNorm(d_model)
        self.cross_attn = DecoderMultiHeadAttention(d_model, n_head, dropout)

        self.mlp_norm = nn.LayerNorm(d_model)
        self.mlp = PositionwiseFeedForward(d_model, d_model * 4, dropout)

    def forward(self, dec_input, enc_output, self_attn_mask, cross_attn_mask,
                cache=None):
        x = dec_input
        residual = x
        x = self.self_attn_norm(x)
        if cache is not None:
            # With a cache, only the newest position is computed; earlier
            # positions are reused from the cached layer output.
            xq = x[:, -1:, :]
            residual = residual[:, -1:, :]
            self_attn_mask = self_attn_mask[:, -1:, :]
        else:
            xq = x
        x = self.self_attn(xq, x, x, mask=self_attn_mask)
        x = residual + x

        residual = x
        x = self.cross_attn_norm(x)
        x = self.cross_attn(x, enc_output, enc_output, mask=cross_attn_mask)
        x = residual + x

        residual = x
        x = self.mlp_norm(x)
        x = residual + self.mlp(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        return x


class DecoderMultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_head, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.n_head = n_head
        self.d_k = d_model // n_head

        self.w_qs = nn.Linear(d_model, n_head * self.d_k)
        self.w_ks = nn.Linear(d_model, n_head * self.d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * self.d_k)

        self.attention = DecoderScaledDotProductAttention(
            temperature=self.d_k ** 0.5)
        self.fc = nn.Linear(n_head * self.d_k, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        bs = q.size(0)

        # Project and split into heads: (batch, head, time, d_k).
        q = self.w_qs(q).view(bs, -1, self.n_head, self.d_k)
        k = self.w_ks(k).view(bs, -1, self.n_head, self.d_k)
        v = self.w_vs(v).view(bs, -1, self.n_head, self.d_k)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        if mask is not None:
            mask = mask.unsqueeze(1)  # broadcast over heads

        output = self.attention(q, k, v, mask=mask)

        # Merge the heads and project back to d_model.
        output = output.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        output = self.fc(output)
        output = self.dropout(output)

        return output


class DecoderScaledDotProductAttention(nn.Module):
    def __init__(self, temperature):
        super().__init__()
        self.temperature = temperature
        self.INF = float("inf")

    def forward(self, q, k, v, mask=None):
        # Scaled dot-product scores: (batch, head, Tq, Tk).
        attn = torch.matmul(q, k.transpose(2, 3)) / self.temperature
        if mask is not None:
            mask = mask.eq(0)
            attn = attn.masked_fill(mask, -self.INF)
            attn = torch.softmax(attn, dim=-1).masked_fill(mask, 0.0)
        else:
            attn = torch.softmax(attn, dim=-1)
        output = torch.matmul(attn, v)
        return output


class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.act = nn.GELU()
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        output = self.w_2(self.act(self.w_1(x)))
        output = self.dropout(output)
        return output


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        assert d_model % 2 == 0
        # Precompute the sinusoidal table once; it is a buffer, not a parameter.
        pe = torch.zeros(max_len, d_model, requires_grad=False)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             -(torch.log(torch.tensor(10000.0)).item() / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # Return the encodings for the first x.size(1) positions.
        length = x.size(1)
        return self.pe[:, :length].clone().detach()
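

# Minimal smoke-test sketch (not part of the original module). The token ids,
# vocabulary size, model dimensions, and input shapes below are illustrative
# assumptions chosen only to exercise batch_beam_search end to end.
if __name__ == "__main__":
    torch.manual_seed(0)
    decoder = TransformerDecoder(
        sos_id=1, eos_id=2, pad_id=0, odim=10,
        n_layers=2, n_head=2, d_model=16)
    decoder.eval()

    N, Ti, d_model = 2, 5, 16                             # assumed sizes
    encoder_outputs = torch.randn(N, Ti, d_model)         # dummy encoder states
    src_masks = torch.ones(N, 1, Ti, dtype=torch.uint8)   # no encoder padding

    with torch.no_grad():
        hyps = decoder.batch_beam_search(
            encoder_outputs, src_masks, beam_size=3, nbest=2)

    for n, n_hyps in enumerate(hyps):
        for i, hyp in enumerate(n_hyps):
            print(f"utt {n} hyp {i}: yseq = {hyp['yseq'].tolist()}")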