# FireRedASR-AED: fireredasr/models/module/transformer_decoder.py
# Author: inoryQwQ
# Commit 90f0b29: "Shorten kv cache"
from typing import List, Optional, Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
class TransformerDecoder(nn.Module):
    """Transformer decoder over encoder outputs, with batched beam-search decoding.

    The output projection ``tgt_word_prj`` shares its weight with the token
    embedding ``tgt_word_emb``.

    Args:
        sos_id: Token id that starts every hypothesis.
        eos_id: Token id that ends a hypothesis.
        pad_id: Padding token id (embedding ``padding_idx`` and target-mask id).
        odim: Vocabulary size.
        n_layers: Number of ``DecoderLayer`` blocks.
        n_head: Attention heads per layer.
        d_model: Model width.
        residual_dropout: Dropout probability applied to the input embedding
            and passed to every ``DecoderLayer``.
        pe_maxlen: Length of the positional-encoding table.
    """

    def __init__(
            self, sos_id, eos_id, pad_id, odim,
            n_layers, n_head, d_model,
            residual_dropout=0.1, pe_maxlen=5000):
        super().__init__()
        # Large finite stand-in for infinity in beam scores.
        self.INF = 1e10
        # parameters
        self.pad_id = pad_id
        self.sos_id = sos_id
        self.eos_id = eos_id
        self.n_layers = n_layers
        # Components
        self.tgt_word_emb = nn.Embedding(odim, d_model, padding_idx=self.pad_id)
        self.positional_encoding = PositionalEncoding(d_model, max_len=pe_maxlen)
        self.dropout = nn.Dropout(residual_dropout)
        self.layer_stack = nn.ModuleList()
        for _ in range(n_layers):
            block = DecoderLayer(d_model, n_head, residual_dropout)
            self.layer_stack.append(block)
        self.tgt_word_prj = nn.Linear(d_model, odim, bias=False)
        self.layer_norm_out = nn.LayerNorm(d_model)
        # Weight tying: the output projection reuses the embedding matrix.
        self.tgt_word_prj.weight = self.tgt_word_emb.weight
        self.scale = (d_model ** 0.5)

    def batch_beam_search(self, encoder_outputs, src_masks,
                          beam_size=1, nbest=1, decode_max_len=0,
                          softmax_smoothing=1.0, length_penalty=0.0, eos_penalty=1.0):
        """Decode ``N`` utterances in parallel with beam search.

        Args:
            encoder_outputs: Encoder states of shape ``(N, Ti, H)``.
            src_masks: Encoder mask. It is tiled with
                ``unsqueeze(1).repeat(1, B, 1, 1)``, so it must be 3-D with last
                dim ``Ti`` (presumably ``(N, 1, Ti)`` -- TODO confirm with caller).
            beam_size: Number of beams ``B`` kept per utterance.
            nbest: Hypotheses returned per utterance. It must not exceed
                ``beam_size``; ``torch.topk`` raises otherwise.
            decode_max_len: Maximum number of decoding steps; ``<= 0`` means ``Ti``.
            softmax_smoothing: Logits are divided by this before ``log_softmax``.
            length_penalty: GNMT exponent ``alpha``. When ``> 0``, final scores are
                divided by ``((5 + len) / 6) ** alpha``, where ``len`` counts the
                non-EOS tokens (SOS included).
            eos_penalty: Multiplier in ``(0, 1]`` applied to the EOS log-prob.
                Log-probs are ``<= 0``, so values below 1 move it toward 0,
                which makes EOS *more* likely.

        Returns:
            A list of ``N`` lists. Each inner list holds ``nbest`` dicts
            ``{"yseq": ids}``, best first, where ``ids`` is a 1-D LongTensor of
            the tokens after SOS and before the first EOS. This assumes
            ``sos_id != eos_id``; lengths are counted as non-EOS tokens.
        """
        B = beam_size
        N, Ti, H = encoder_outputs.size()
        device = encoder_outputs.device
        maxlen = decode_max_len if decode_max_len > 0 else Ti
        assert eos_penalty > 0.0 and eos_penalty <= 1.0

        # Init: give each utterance B consecutive rows, one per beam.
        encoder_outputs = encoder_outputs.unsqueeze(1).repeat(1, B, 1, 1).view(N*B, Ti, H)
        src_mask = src_masks.unsqueeze(1).repeat(1, B, 1, 1).view(N*B, -1, Ti)
        ys = torch.full((N*B, 1), self.sos_id, dtype=torch.long, device=device)
        caches: List[Optional[Tensor]] = []
        for _ in range(self.n_layers):
            caches.append(None)
        # Only beam 0 of each utterance starts alive. The others start at -INF,
        # so the first pruning step does not keep B copies of the same prefix.
        scores = torch.tensor([0.0] + [-self.INF]*(B-1)).float().to(device)
        scores = scores.repeat(N).view(N*B, 1)
        is_finished = torch.zeros_like(scores)
        # Flattened row index of each utterance's first beam. It is
        # loop-invariant, so it is computed once.
        stride = B * torch.arange(N, device=device).view(N, 1).repeat(1, B).view(N*B)

        # Autoregressive Prediction
        for _ in range(maxlen):
            tgt_mask = self.ignored_target_position_is_0(ys, self.pad_id)
            dec_output = self.dropout(
                self.tgt_word_emb(ys) * self.scale +
                self.positional_encoding(ys))
            # Each layer's full output is kept as that layer's cache for the
            # next step.
            for i, dec_layer in enumerate(self.layer_stack):
                dec_output = dec_layer(
                    dec_output, encoder_outputs,
                    tgt_mask, src_mask,
                    cache=caches[i])
                caches[i] = dec_output
            dec_output = self.layer_norm_out(dec_output)
            t_logit = self.tgt_word_prj(dec_output[:, -1])
            t_scores = F.log_softmax(t_logit / softmax_smoothing, dim=-1)
            if eos_penalty != 1.0:
                t_scores[:, self.eos_id] *= eos_penalty
            t_topB_scores, t_topB_ys = torch.topk(t_scores, k=B, dim=1)
            # A finished beam may only extend with EOS, at zero extra cost.
            t_topB_scores = self.set_finished_beam_score_to_zero(t_topB_scores, is_finished)
            t_topB_ys = self.set_finished_beam_y_to_eos(t_topB_ys, is_finished)

            # Accumulated
            scores = scores + t_topB_scores

            # Pruning: keep the best B of the B*B candidates of each utterance.
            scores = scores.view(N, B*B)
            scores, topB_score_ids = torch.topk(scores, k=B, dim=1)
            scores = scores.view(-1, 1)

            # Candidate c in [0, B*B) extends beam c // B of its utterance.
            # Integer floor division avoids a float round-trip.
            topB_row_number_in_each_B_rows_of_ys = torch.div(
                topB_score_ids, B, rounding_mode="floor").view(N*B)
            topB_row_number_in_ys = topB_row_number_in_each_B_rows_of_ys.long() + stride.long()

            # Update ys
            ys = ys[topB_row_number_in_ys]
            t_ys = torch.gather(t_topB_ys.view(N, B*B), dim=1, index=topB_score_ids).view(N*B, 1)
            ys = torch.cat((ys, t_ys), dim=1)

            # Update caches so each layer's cache follows its surviving beam.
            # None entries are kept so caches[i] stays aligned with layer i.
            new_caches: List[Optional[Tensor]] = []
            for cache in caches:
                if cache is not None:
                    new_caches.append(cache[topB_row_number_in_ys])
                else:
                    new_caches.append(None)
            caches = new_caches

            # Update finished state
            is_finished = t_ys.eq(self.eos_id)
            if is_finished.sum().item() == N*B:
                break

        # Length penalty (follow GNMT)
        scores = scores.view(N, B)
        ys = ys.view(N, B, -1)
        ys_lengths = self.get_ys_lengths(ys)
        if length_penalty > 0.0:
            penalty = torch.pow((5+ys_lengths.float())/(5.0+1), length_penalty)
            scores /= penalty
        _, nbest_ids = torch.topk(scores, k=int(nbest), dim=1)
        index = nbest_ids + B * torch.arange(N).view(N, 1).to(device).long()
        nbest_ys = ys.view(N*B, -1)[index.view(-1)]
        nbest_ys = nbest_ys.view(N, nbest_ids.size(1), -1)
        nbest_ys_lengths = ys_lengths.view(N*B)[index.view(-1)].view(N, -1)

        # result
        nbest_hyps: List[List[Dict[str, Tensor]]] = []
        for n in range(N):
            n_nbest_hyps: List[Dict[str, Tensor]] = []
            for i in range(nbest_ids.size(1)):
                new_hyp = {
                    # Drop the leading SOS and the trailing EOS padding.
                    "yseq": nbest_ys[n, i, 1:nbest_ys_lengths[n, i]]
                }
                n_nbest_hyps.append(new_hyp)
            nbest_hyps.append(n_nbest_hyps)
        return nbest_hyps

    def ignored_target_position_is_0(self, padded_targets, ignore_id):
        """Build a causal self-attention mask that also hides ``ignore_id`` keys.

        Args:
            padded_targets: Token ids of shape ``(batch, T)``.
            ignore_id: Token id whose key positions are masked out.

        Returns:
            A ``uint8`` tensor of shape ``(batch, T, T)``. Entry ``[b, i, j]`` is 1
            iff ``j <= i`` and ``padded_targets[b, j] != ignore_id``.
        """
        T = padded_targets.size(-1)
        key_ok = torch.ne(padded_targets, ignore_id).unsqueeze(dim=1).to(torch.uint8)
        causal = self.upper_triangular_is_0(T).unsqueeze(0).to(key_ok.device)
        return key_ok & causal

    def upper_triangular_is_0(self, size):
        """Return a ``(size, size)`` uint8 matrix with ones on and below the diagonal."""
        ones = torch.ones(size, size)
        tri_left_ones = torch.tril(ones)
        return tri_left_ones.to(torch.uint8)

    def set_finished_beam_score_to_zero(self, scores, is_finished):
        """Replace the step scores of finished rows with ``[0, -INF, ..., -INF]``.

        Args:
            scores: Step scores of shape ``(N*B, B)``.
            is_finished: Per-row flags of shape ``(N*B, 1)`` (bool or 0/1).

        Returns:
            Scores with unfinished rows unchanged. Finished rows keep one
            zero-cost continuation and suppress the rest.
        """
        NB, B = scores.size()
        mask_score = torch.tensor([0.0] + [-self.INF]*(B-1)).float().to(scores.device)
        mask_score = mask_score.view(1, B).repeat(NB, 1)
        # torch.where (rather than multiply-by-mask) cannot produce NaN from -inf * 0.
        return torch.where(is_finished.bool(), mask_score, scores)

    def set_finished_beam_y_to_eos(self, ys, is_finished):
        """Return ``ys`` (``(N*B, B)``) with every candidate of a finished row set to EOS."""
        return ys.masked_fill(is_finished.bool(), self.eos_id)

    def get_ys_lengths(self, ys):
        """Count the non-EOS tokens (SOS included) of each ``(N, B, T)`` hypothesis, as int32."""
        ys_lengths = torch.sum(torch.ne(ys, self.eos_id), dim=-1)
        return ys_lengths.int()
class DecoderLayer(nn.Module):
    """Pre-norm decoder layer: self-attention, cross-attention, then feed-forward.

    Every sub-layer is applied as ``h + sublayer(norm(h))``.
    """

    def __init__(self, d_model, n_head, dropout):
        super().__init__()
        self.self_attn_norm = nn.LayerNorm(d_model)
        self.self_attn = DecoderMultiHeadAttention(d_model, n_head, dropout)
        self.cross_attn_norm = nn.LayerNorm(d_model)
        self.cross_attn = DecoderMultiHeadAttention(d_model, n_head, dropout)
        self.mlp_norm = nn.LayerNorm(d_model)
        self.mlp = PositionwiseFeedForward(d_model, d_model*4, dropout)

    def forward(self, dec_input, enc_output, self_attn_mask, cross_attn_mask,
                cache=None):
        """Run the layer, optionally only for the newest position.

        Args:
            dec_input: Layer input over all positions decoded so far.
            enc_output: Keys/values for cross-attention.
            self_attn_mask: Self-attention mask; only its last query row is
                used when ``cache`` is given.
            cross_attn_mask: Cross-attention mask.
            cache: This layer's previous output for all earlier positions, or
                ``None``.

        Returns:
            The layer output. With ``cache``, only the last position is
            computed and it is concatenated after ``cache`` along dim 1.
        """
        normed = self.self_attn_norm(dec_input)
        if cache is None:
            query, skip = normed, dec_input
        else:
            # Only the newest position is queried; keys/values still span
            # every position.
            query = normed[:, -1:, :]
            skip = dec_input[:, -1:, :]
            self_attn_mask = self_attn_mask[:, -1:, :]

        h = skip + self.self_attn(query, normed, normed, mask=self_attn_mask)
        h = h + self.cross_attn(self.cross_attn_norm(h), enc_output, enc_output,
                                mask=cross_attn_mask)
        h = h + self.mlp(self.mlp_norm(h))

        if cache is None:
            return h
        return torch.cat([cache, h], dim=1)
class DecoderMultiHeadAttention(nn.Module):
    """Multi-head attention with separate Q/K/V projections; the K projection has no bias.

    Args:
        d_model: Input/output width.
        n_head: Number of heads. Each head has width ``d_model // n_head``.
        dropout: Dropout probability applied to the output projection.
    """

    def __init__(self, d_model, n_head, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.n_head = n_head
        self.d_k = d_model // n_head
        self.w_qs = nn.Linear(d_model, n_head * self.d_k)
        self.w_ks = nn.Linear(d_model, n_head * self.d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * self.d_k)
        self.attention = DecoderScaledDotProductAttention(
            temperature=self.d_k ** 0.5)
        self.fc = nn.Linear(n_head * self.d_k, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` to ``k``/``v``.

        Args:
            q, k, v: Tensors of shape ``(batch, T, d_model)``. ``k`` and ``v``
                share their length; ``q`` may be shorter.
            mask: Optional mask, non-zero where attention is allowed. A head
                axis is inserted at dim 1 so it broadcasts across heads.

        Returns:
            Tensor of shape ``(batch, Tq, d_model)``.
        """
        bs = q.size(0)
        q = self.w_qs(q).view(bs, -1, self.n_head, self.d_k).transpose(1, 2)
        k = self.w_ks(k).view(bs, -1, self.n_head, self.d_k).transpose(1, 2)
        v = self.w_vs(v).view(bs, -1, self.n_head, self.d_k).transpose(1, 2)
        if mask is not None:
            mask = mask.unsqueeze(1)
        output = self.attention(q, k, v, mask=mask)
        # Merge heads back to width n_head * d_k (the fc input size). This
        # differs from d_model when d_model % n_head != 0.
        output = output.transpose(1, 2).contiguous().view(bs, -1, self.n_head * self.d_k)
        output = self.fc(output)
        output = self.dropout(output)
        return output
class DecoderScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(q @ k^T / temperature) @ v``.

    Args:
        temperature: Divisor applied to the raw attention logits.
    """

    def __init__(self, temperature):
        super().__init__()
        self.temperature = temperature
        self.INF = float("inf")

    def forward(self, q, k, v, mask=None):
        """Compute attention over the last two dims of 4-D ``q``/``k``/``v``.

        Args:
            q: ``(..., Tq, d)``; ``k``, ``v``: ``(..., Tk, d)``. They are 4-D,
                because ``k.transpose(2, 3)`` is used.
            mask: Optional mask broadcastable to ``(..., Tq, Tk)``; zero entries
                are blocked.

        Returns:
            Tensor of shape ``(..., Tq, d)``. Rows whose keys are all masked
            come out as zeros.
        """
        attn = torch.matmul(q, k.transpose(2, 3)) / self.temperature
        if mask is not None:
            mask = mask.eq(0)
            attn = attn.masked_fill(mask, -self.INF)
            # A fully masked row softmaxes to NaN; zeroing masked entries
            # afterwards turns such rows into all-zero weights.
            attn = torch.softmax(attn, dim=-1).masked_fill(mask, 0.0)
        else:
            attn = torch.softmax(attn, dim=-1)
        output = torch.matmul(attn, v)
        return output
class PositionwiseFeedForward(nn.Module):
    """Two-layer GELU feed-forward network, followed by dropout.

    Args:
        d_model: Input/output width.
        d_ff: Hidden width.
        dropout: Dropout probability applied to the output.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.act = nn.GELU()
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``dropout(w_2(gelu(w_1(x))))``."""
        hidden = self.act(self.w_1(x))
        return self.dropout(self.w_2(hidden))
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal position table, stored as buffer ``pe`` of shape ``(1, max_len, d_model)``.

    Even columns hold ``sin(pos * w_i)`` and odd columns ``cos(pos * w_i)``,
    with ``w_i = 10000 ** (-2i / d_model)``.

    Args:
        d_model: Table width; must be even.
        max_len: Number of positions in the table.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        assert d_model % 2 == 0
        log_base = torch.log(torch.tensor(10000.0)).item()
        inv_freq = torch.exp(torch.arange(0, d_model, 2).float() * -(log_base / d_model))
        positions = torch.arange(0, max_len).unsqueeze(1).float()
        angles = positions * inv_freq
        table = torch.zeros(max_len, d_model, requires_grad=False)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x):
        """Return a detached copy of the first ``x.size(1)`` rows, shape ``(1, T, d_model)``."""
        seq_len = x.size(1)
        return self.pe[:, :seq_len].clone().detach()