import torch
import torch.nn as nn
from torch import Tensor

from fireredasr.models.module.conformer_encoder import ConformerEncoder
from fireredasr.models.module.transformer_decoder import (
    TransformerDecoder,
    DecoderLayer,
    DecoderMultiHeadAttention,
    DecoderScaledDotProductAttention,
    PositionalEncoding,
)


def DecoderScaledDotProductAttentionForward(
    self: DecoderScaledDotProductAttention,
    q: Tensor,
    k: Tensor,
    v: Tensor,
    mask: Tensor,
) -> Tensor:
    """Scaled dot-product attention that takes an *additive* mask.

    Installed below as ``DecoderScaledDotProductAttention.forward``.

    Args:
        q, k, v: per-head tensors. ``k`` is transposed on dims 2 and 3, so they
            must be rank 4; the wrappers in this file pass
            (batch, n_head, length, d_k).
        mask: ``None`` or an additive mask broadcastable to the attention
            scores. It holds 0.0 where attention is allowed and ``-inf`` where
            it is blocked, e.g. ``[[[0, 0, 0, 0, ..., -inf, -inf]]]``.

    Returns:
        ``softmax(q @ k^T / temperature + mask) @ v``.
    """
    attn = torch.matmul(q, k.transpose(2, 3)) / self.temperature
    if mask is not None:
        # Blocked positions become -inf and get zero weight after the softmax.
        attn = attn + mask
    attn = torch.softmax(attn, dim=-1)
    return torch.matmul(attn, v)


DecoderScaledDotProductAttention.forward = DecoderScaledDotProductAttentionForward


def PositionalEncodingForward(self: PositionalEncoding, offset: Tensor) -> Tensor:
    """Return the positional encoding of the single position ``offset - 1``.

    Installed below as ``PositionalEncoding.forward``. It lets the exported
    ONNX decoder take, at each decoding step, only the token id decoded at the
    previous step instead of the token ids of all previous steps.

    Args:
        offset: number of leading positions of ``self.pe`` to consider. The
            last of them (index ``offset - 1``) is returned, so it must be >= 1.

    Returns:
        A detached copy of ``self.pe[:, offset - 1]``.
    """
    # NOTE(review): the slice-then-index form is kept unchanged; presumably it
    # exports more reliably with a tensor ``offset`` than direct indexing.
    # Confirm before simplifying.
    return self.pe[:, :offset].clone().detach()[:, -1]


PositionalEncoding.forward = PositionalEncodingForward


def padding_position_is_0(
    self: ConformerEncoder, padded_input: Tensor, input_lengths: Tensor
) -> Tensor:
    """Return a uint8 mask of shape (N, 1, T): 1 on valid frames, 0 on padding.

    Installed below as ``ConformerEncoder.padding_position_is_0``.

    NOTE(Lianghu): Why is this needed? When the ONNX model is exported with the
    original ``padding_position_is_0`` function, dynamic batch does not work
    properly in the exported model. The original function looks like this:

    ```py
    def padding_position_is_0(...):
        N, T = padded_input.size()[:2]
        mask = torch.ones((N, T)).to(padded_input.device)
        ...
    ```

    During ONNX export, N and T obtained this way are treated as constants.
    They should be read as ``N = padded_input.size(0)`` and
    ``T = padded_input.size(1)``.

    The version below reads T with ``padded_input.size(1)``. It builds the mask
    with a broadcast comparison instead of filling a fixed-shape tensor.

    Args:
        padded_input: padded batch whose dim 1 is the time axis (T).
        input_lengths: valid length of each sequence, shape (N,).
    """
    T = padded_input.size(1)
    seq_range = torch.arange(T, device=padded_input.device).unsqueeze(0)  # (1, T)
    mask = seq_range < input_lengths.unsqueeze(1)  # (N, T)
    return mask.unsqueeze(dim=1).to(torch.uint8)  # (N, 1, T)


ConformerEncoder.padding_position_is_0 = padding_position_is_0


class AudioEncoderTensorCache(nn.Module):
    """Encoder wrapper that also precomputes the decoder's cross-attention K/V.

    Each decoder layer's ``cross_attn.w_ks`` / ``cross_attn.w_vs`` is applied to
    the encoder output once, so the step-wise decoder can reuse the results.
    """

    def __init__(self, encoder: ConformerEncoder, decoder: TransformerDecoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, input: Tensor, input_length: Tensor):
        """Encode audio and build per-layer cross-attention keys/values.

        Args:
            input: padded acoustic features, passed straight to ``self.encoder``.
            input_length: per-utterance valid lengths, passed to ``self.encoder``.

        Returns:
            A tuple ``(n_layer_cross_k, n_layer_cross_v, encoder_mask)``:
            - ``n_layer_cross_k`` / ``n_layer_cross_v``: every decoder layer's
              cross-attention keys and values, stacked on a new leading axis.
            - ``encoder_mask``: the encoder mask as a float32 additive mask,
              with 0.0 where it was 1 and ``-inf`` where it was 0.
        """
        encoder_output, _, encoder_mask = self.encoder(input, input_length)

        layers = self.decoder.layer_stack
        n_layer_cross_k = torch.stack(
            [layer.cross_attn.w_ks(encoder_output) for layer in layers]
        )
        n_layer_cross_v = torch.stack(
            [layer.cross_attn.w_vs(encoder_output) for layer in layers]
        )

        # Convert the (presumably 0/1) encoder mask into the additive form
        # expected by DecoderScaledDotProductAttentionForward.
        # Order matters: doing `== 1` first would create new zeros that
        # `== 0` would then wrongly turn into -inf.
        encoder_mask = encoder_mask.to(torch.float32)
        encoder_mask[encoder_mask == 0] = -torch.inf
        encoder_mask[encoder_mask == 1] = 0.0
        return n_layer_cross_k, n_layer_cross_v, encoder_mask


def _attend(
    mha: DecoderMultiHeadAttention, q: Tensor, k: Tensor, v: Tensor, mask: Tensor
) -> Tensor:
    """Split heads, attend, merge heads, then apply ``mha.fc`` and ``mha.dropout``.

    ``q``, ``k`` and ``v`` must already be projected (``w_qs`` / ``w_ks`` /
    ``w_vs``). Their leading dim is the batch and their last dim is
    ``n_head * d_k``.
    """
    bs = q.size(0)
    q = q.view(bs, -1, mha.n_head, mha.d_k).transpose(1, 2)
    k = k.view(bs, -1, mha.n_head, mha.d_k).transpose(1, 2)
    v = v.view(bs, -1, mha.n_head, mha.d_k).transpose(1, 2)
    if mask is not None:
        # Scores are (bs, n_head, Lq, Lk) after the transposes: add a head axis.
        mask = mask.unsqueeze(1)
    output = mha.attention(q, k, v, mask)
    output = output.transpose(1, 2).contiguous().view(bs, -1, mha.d_model)
    output = mha.fc(output)
    return mha.dropout(output)


class DecoderMultiHeadSelfAttention(nn.Module):
    """Self-attention that writes the new K/V into the caller's cache in place."""

    def __init__(self, multiHeadSelfAttention: DecoderMultiHeadAttention, loop: bool = False):
        super().__init__()
        self.multiHeadSelfAttention = multiHeadSelfAttention
        # Kept for signature parity with DecoderMultiHeadSelfAttentionV2; unused here.
        self.loop = loop

    def forward(self, x: Tensor, k_cache: Tensor, v_cache: Tensor, mask: Tensor):
        """Attend from the current tokens ``x`` over the cached and new K/V.

        Let the current time step be t. ``k_cache`` / ``v_cache`` cache the
        self-attention keys/values of time steps [0, t-1]. Their trailing
        ``x.size(1)`` positions are overwritten in place with the current
        step's keys/values.

        Returns:
            ``(output, k_cache, v_cache)``; the caches are the same
            (mutated) tensors that were passed in.
        """
        mha = self.multiHeadSelfAttention
        q = mha.w_qs(x)
        k = mha.w_ks(x)
        v = mha.w_vs(x)
        k_cache[:, -k.shape[1]:, :] = k
        v_cache[:, -v.shape[1]:, :] = v
        output = _attend(mha, q, k_cache, v_cache, mask)
        return output, k_cache, v_cache


class DecoderMultiHeadSelfAttentionV2(nn.Module):
    """Self-attention that returns a new cache instead of writing in place."""

    def __init__(self, multiHeadSelfAttention: DecoderMultiHeadAttention, loop: bool = False):
        super().__init__()
        self.multiHeadSelfAttention = multiHeadSelfAttention
        self.loop = loop

    def forward(self, x: Tensor, k_cache: Tensor, v_cache: Tensor, mask: Tensor):
        """Attend from the current tokens ``x``.

        Let the current time step be t. ``k_cache`` / ``v_cache`` cache the
        self-attention keys/values of time steps [0, t-1].

        - ``loop=True``: the first cache position is dropped and the new K/V
          are appended, and attention runs over that updated cache.
        - ``loop=False``: the incoming cache is ignored; attention runs only
          over the current tokens' K/V, which are also what is returned.

        Returns:
            ``(output, k_cache, v_cache)`` with the new caches.
        """
        mha = self.multiHeadSelfAttention
        q = mha.w_qs(x)
        k = mha.w_ks(x)
        v = mha.w_vs(x)
        if self.loop:
            k_cache = torch.cat([k_cache[:, 1:, :], k], 1)
            v_cache = torch.cat([v_cache[:, 1:, :], v], 1)
        else:
            k_cache = k
            v_cache = v
        output = _attend(mha, q, k_cache, v_cache, mask)
        return output, k_cache, v_cache


class DecoderMultiHeadCrossAttention(nn.Module):
    """Cross-attention against precomputed (already projected) encoder K/V."""

    def __init__(self, multiHeadCrossAttention: DecoderMultiHeadAttention):
        super().__init__()
        self.multiHeadCrossAttention = multiHeadCrossAttention

    def forward(self, x: Tensor, k: Tensor, v: Tensor, mask: Tensor) -> Tensor:
        """Project ``x`` to queries and attend over ``k`` / ``v``.

        ``k`` and ``v`` are used as-is, with no ``w_ks`` / ``w_vs`` applied
        here (see AudioEncoderTensorCache).
        """
        mha = self.multiHeadCrossAttention
        return _attend(mha, mha.w_qs(x), k, v, mask)


class _ResidualAttentionBlockBase(nn.Module):
    """Shared body of the cached decoder layers: self-attn, cross-attn, MLP.

    Each sub-layer is applied to a normalized copy of ``x`` and its output is
    added back to ``x``. The norms and MLP come from ``original_decoder_layer``.
    """

    def __init__(self, decoder_layer: DecoderLayer, self_attn: nn.Module):
        super().__init__()
        # Registration order (original_decoder_layer, self_attn, cross_attn)
        # matches the original classes, so state_dict keys are unchanged.
        self.original_decoder_layer = decoder_layer
        self.self_attn = self_attn
        self.cross_attn = DecoderMultiHeadCrossAttention(decoder_layer.cross_attn)

    def forward(self, x: Tensor, self_k_cache: Tensor, self_v_cache: Tensor,
                cross_k: Tensor, cross_v: Tensor,
                self_attn_mask: Tensor, cross_attn_mask: Tensor):
        """Return ``(x, self_k_cache_updated, self_v_cache_updated)``."""
        layer = self.original_decoder_layer
        self_attn_x, self_k_cache_updated, self_v_cache_updated = self.self_attn(
            layer.self_attn_norm(x), self_k_cache, self_v_cache, self_attn_mask)
        x = x + self_attn_x
        x = x + self.cross_attn(layer.cross_attn_norm(x), cross_k, cross_v, cross_attn_mask)
        x = x + layer.mlp(layer.mlp_norm(x))
        return x, self_k_cache_updated, self_v_cache_updated


class ResidualAttentionBlockTensorCache(_ResidualAttentionBlockBase):
    """Decoder layer whose self-attention updates the K/V cache in place."""

    def __init__(self, decoder_layer: DecoderLayer, loop: bool = False):
        super().__init__(
            decoder_layer, DecoderMultiHeadSelfAttention(decoder_layer.self_attn, loop))


class ResidualAttentionBlockTensorCacheV2(_ResidualAttentionBlockBase):
    """Decoder layer whose self-attention returns a new K/V cache."""

    def __init__(self, decoder_layer: DecoderLayer, loop: bool = False):
        super().__init__(
            decoder_layer, DecoderMultiHeadSelfAttentionV2(decoder_layer.self_attn, loop))


class TextDecoderTensorCache(nn.Module):
    """Step-wise decoder driven by a preallocated, in-place updated K/V cache."""

    def __init__(self, decoder: TransformerDecoder):
        super().__init__()
        self.decoder = decoder
        # Plain list (not nn.ModuleList): each wrapper reuses a layer that is
        # already registered under self.decoder, so parameters and state_dict
        # keys live only under `decoder.*`.
        self.blocks = [
            ResidualAttentionBlockTensorCache(layer) for layer in self.decoder.layer_stack
        ]

    def forward(self, tokens: Tensor,
                n_layer_self_k_cache: Tensor, n_layer_self_v_cache: Tensor,
                n_layer_cross_k: Tensor, n_layer_cross_v: Tensor,
                offset: Tensor, self_attn_mask: Tensor, cross_attn_mask: Tensor):
        """Run one decoding step.

        TODO(Lianghu): Integrate self_attn_mask into the model inference
        process instead of passing it in through an external interface.

        Args:
            tokens: token ids; ``tokens.shape[-1]`` new positions are decoded.
            n_layer_self_k_cache / n_layer_self_v_cache: indexed as
                [layer, batch, position, feature]. They are updated in place.
            n_layer_cross_k / n_layer_cross_v: per-layer cross-attention K/V,
                indexed by layer (see AudioEncoderTensorCache).
            offset: number of already-decoded positions. Only ``offset[0]`` is
                used to slice the caches, so all batch entries share it.
            self_attn_mask, cross_attn_mask: additive attention masks.

        Returns:
            ``(logits, n_layer_self_k_cache, n_layer_self_v_cache)``; the caches
            are the input tensors, modified in place.
        """
        # With the patched PositionalEncoding.forward, `offset + 1` selects the
        # encoding of position `offset` (a single position for all tokens).
        x = self.decoder.dropout(
            self.decoder.tgt_word_emb(tokens) * self.decoder.scale
            + self.decoder.positional_encoding(offset + 1)
        )
        end = offset[0] + tokens.shape[-1]
        for i, block in enumerate(self.blocks):
            self_k_cache = n_layer_self_k_cache[i, :, :end, :]
            self_v_cache = n_layer_self_v_cache[i, :, :end, :]
            x, self_k_cache, self_v_cache = block(
                x, self_k_cache, self_v_cache,
                n_layer_cross_k[i], n_layer_cross_v[i],
                self_attn_mask, cross_attn_mask,
            )
            # Explicit write-back keeps the cache update visible in the traced
            # graph rather than relying only on view aliasing.
            n_layer_self_k_cache[i, :, :end, :] = self_k_cache
            n_layer_self_v_cache[i, :, :end, :] = self_v_cache

        output = self.decoder.layer_norm_out(x)
        logits = self.decoder.tgt_word_prj(output)
        return logits, n_layer_self_k_cache, n_layer_self_v_cache


class TextDecoderTensorCacheV2(nn.Module):
    """Step-wise decoder that returns freshly built K/V caches (no in-place writes)."""

    def __init__(self, decoder: TransformerDecoder, loop: bool = False):
        super().__init__()
        self.decoder = decoder
        self.loop = loop
        # Plain list on purpose; see TextDecoderTensorCache.__init__.
        self.blocks = [
            ResidualAttentionBlockTensorCacheV2(layer, loop)
            for layer in self.decoder.layer_stack
        ]

    def forward(self, tokens: Tensor,
                n_layer_self_k_cache: Tensor, n_layer_self_v_cache: Tensor,
                n_layer_cross_k: Tensor, n_layer_cross_v: Tensor,
                positional_embedding: Tensor,
                self_attn_mask: Tensor, cross_attn_mask: Tensor):
        """Run one decoding step.

        TODO(Lianghu): Integrate self_attn_mask into the model inference
        process instead of passing it in through an external interface.

        Args:
            tokens: token ids fed to ``decoder.tgt_word_emb``.
            n_layer_self_k_cache / n_layer_self_v_cache: indexed as
                [layer, batch, position, feature].
            n_layer_cross_k / n_layer_cross_v: per-layer cross-attention K/V.
            positional_embedding: added to the scaled token embeddings as-is.
                The caller chooses the positions.
            self_attn_mask, cross_attn_mask: additive attention masks.

        Returns:
            ``(logits, n_layer_self_k_cache, n_layer_self_v_cache)`` with new
            cache tensors:
            - ``loop=True``: each layer's cache drops its first position and
              appends the new K/V.
            - ``loop=False``: the new K/V are left-padded with zeros back to
              the input cache length.
        """
        x = self.decoder.dropout(
            self.decoder.tgt_word_emb(tokens) * self.decoder.scale + positional_embedding
        )

        self_k_cache_out = []
        self_v_cache_out = []
        for i, block in enumerate(self.blocks):
            layer_k_cache = n_layer_self_k_cache[i]
            layer_v_cache = n_layer_self_v_cache[i]
            x, new_k, new_v = block(
                x, layer_k_cache, layer_v_cache,
                n_layer_cross_k[i], n_layer_cross_v[i],
                self_attn_mask, cross_attn_mask,
            )
            if not self.loop:
                # Only the current tokens' K/V come back; restore the cache length.
                n_audio, n_text_ctx, ntext_state = layer_k_cache.shape
                new_k = _left_zero_pad(new_k, n_audio, n_text_ctx, ntext_state)
                new_v = _left_zero_pad(new_v, n_audio, n_text_ctx, ntext_state)
            self_k_cache_out.append(new_k)
            self_v_cache_out.append(new_v)

        n_layer_self_k_cache = torch.stack(self_k_cache_out)
        n_layer_self_v_cache = torch.stack(self_v_cache_out)

        output = self.decoder.layer_norm_out(x)
        logits = self.decoder.tgt_word_prj(output)
        return logits, n_layer_self_k_cache, n_layer_self_v_cache


def _left_zero_pad(cache: Tensor, n_audio: int, n_text_ctx: int, ntext_state: int) -> Tensor:
    """Prepend zero positions on dim 1 so ``cache`` has ``n_text_ctx`` positions.

    The padding takes ``cache``'s dtype and device. Default-dtype zeros would
    make ``torch.cat`` type-promote a half-precision cache to float32.
    """
    padding = torch.zeros(
        [n_audio, n_text_ctx - cache.shape[1], ntext_state],
        dtype=cache.dtype,
        device=cache.device,
    )
    return torch.cat((padding, cache), 1)