import torch
import torch.nn as nn
from torch import Tensor

from fireredasr.models.module.conformer_encoder import ConformerEncoder
from fireredasr.models.module.transformer_decoder import (
    TransformerDecoder,
    DecoderLayer,
    DecoderMultiHeadAttention,
    DecoderScaledDotProductAttention,
    PositionalEncoding
)


def DecoderScaledDotProductAttentionForward(
    self: DecoderScaledDotProductAttention,
    q: Tensor,
    k: Tensor,
    v: Tensor,
    mask: Tensor
):
    attn = torch.matmul(q, k.transpose(2, 3)) / self.temperature
    if mask is not None:
        # The mask is additive: 0.0 keeps a position, -inf removes it
        # (see the float mask built in AudioEncoderTensorCache below).
        attn = attn + mask
    attn = torch.softmax(attn, dim=-1)
    output = torch.matmul(attn, v)
    return output


DecoderScaledDotProductAttention.forward = DecoderScaledDotProductAttentionForward
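
# A minimal sketch of the additive-mask convention assumed above (illustrative
# only; `attention` stands for any DecoderScaledDotProductAttention instance):
#
#   q = k = v = torch.randn(1, 1, 2, 4)              # (N, n_head, T, d_k)
#   mask = torch.tensor([[[[0.0, float("-inf")]]]])  # keep position 0, drop position 1
#   out = attention(q, k, v, mask)
#   # positions holding -inf receive zero attention weight after the softmax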


"""
The purpose of this patch is to let the exported ONNX model take only the
token id produced at the previous time step when decoding step by step,
instead of the token ids of all previous time steps.
"""
def PositionalEncodingForward(
    self: PositionalEncoding,
    offset: Tensor
):
    # Return only the positional encoding of the current (last) position.
    return self.pe[:, :offset].clone().detach()[:, -1]


PositionalEncoding.forward = PositionalEncodingForward
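
# A minimal sketch of the patched behaviour (illustrative only; assumes a
# PositionalEncoding instance `pe_module` whose table `pe` has shape
# (1, max_len, d_model)):
#
#   offset = torch.tensor([5])
#   cur_pe = pe_module(offset + 1)   # encoding of position index 5 only
#   # cur_pe has shape (1, d_model) and is added to the embedding of the
#   # single token decoded at this step.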


"""
NOTE(Lianghu): Why do this?

When exporting the ONNX model with the original padding_position_is_0 function,
dynamic batching does not work properly for the exported ONNX model.

The code in the original padding_position_is_0 function is as follows:
```py
def padding_position_is_0(...):
    N, T = padded_input.size()[:2]
    mask = torch.ones((N, T)).to(padded_input.device)
    ...
```

This is because, when exporting to ONNX, N and T obtained this way are treated
as constants. They should be N = padded_input.size(0) and T = padded_input.size(1).
"""
def padding_position_is_0(self: ConformerEncoder,
                          padded_input: Tensor,
                          input_lengths: Tensor):
    N = padded_input.size(0)
    T = padded_input.size(1)
    # Positions before each sequence length are valid (1); the rest are padding (0).
    seq_range = torch.arange(T, device=padded_input.device).unsqueeze(0)
    input_lengths_exp = input_lengths.unsqueeze(1)
    mask = seq_range < input_lengths_exp
    mask = mask.unsqueeze(dim=1)
    return mask.to(torch.uint8)


ConformerEncoder.padding_position_is_0 = padding_position_is_0
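
# A minimal sketch of the mask this produces (illustrative only; `encoder`
# stands for a ConformerEncoder instance and the 80-dim feature layout is an
# assumption):
#
#   lengths = torch.tensor([2, 4])
#   feats = torch.zeros(2, 4, 80)                     # (N, T, feat_dim)
#   mask = encoder.padding_position_is_0(feats, lengths)
#   # mask.shape == (2, 1, 4)
#   # mask[0] == [[1, 1, 0, 0]], mask[1] == [[1, 1, 1, 1]]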


class AudioEncoderTensorCache(nn.Module):
    def __init__(self,
                 encoder: ConformerEncoder,
                 decoder: TransformerDecoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, input: Tensor, input_length: Tensor):
        encoder_output, _, encoder_mask = self.encoder(input, input_length)

        # Pre-compute the cross-attention keys/values of every decoder layer
        # so they only have to be projected once per utterance.
        n_layer_cross_k_list = []
        n_layer_cross_v_list = []
        for layer in self.decoder.layer_stack:
            n_layer_cross_k_list.append(layer.cross_attn.w_ks(encoder_output))
            n_layer_cross_v_list.append(layer.cross_attn.w_vs(encoder_output))

        # Turn the 0/1 padding mask into an additive float mask:
        # 0 (padding) -> -inf, 1 (valid) -> 0.0.
        encoder_mask = encoder_mask.to(torch.float32)
        encoder_mask[encoder_mask == 0] = -torch.inf
        encoder_mask[encoder_mask == 1] = 0.0

        return (torch.stack(n_layer_cross_k_list),
                torch.stack(n_layer_cross_v_list),
                encoder_mask)
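
# A minimal export sketch (illustrative only; `encoder` and `decoder` are
# assumed to be loaded FireRedASR modules, and the (1, 100, 80) feature shape
# is an assumption):
#
#   audio_encoder = AudioEncoderTensorCache(encoder, decoder).eval()
#   feats = torch.randn(1, 100, 80)
#   feat_lengths = torch.tensor([100])
#   torch.onnx.export(
#       audio_encoder,
#       (feats, feat_lengths),
#       "audio_encoder.onnx",
#       input_names=["feats", "feat_lengths"],
#       output_names=["n_layer_cross_k", "n_layer_cross_v", "encoder_mask"],
#       dynamic_axes={"feats": {0: "N", 1: "T"}, "feat_lengths": {0: "N"}},
#   )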


class DecoderMultiHeadSelfAttention(nn.Module):
    def __init__(self, multiHeadSelfAttention: DecoderMultiHeadAttention, loop: bool = False):
        super().__init__()
        self.multiHeadSelfAttention = multiHeadSelfAttention
        self.loop = loop

    def forward(self,
                x: Tensor,
                k_cache: Tensor,
                v_cache: Tensor,
                mask: Tensor):
        bs = x.size(0)

        q = self.multiHeadSelfAttention.w_qs(x)
        k = self.multiHeadSelfAttention.w_ks(x)
        v = self.multiHeadSelfAttention.w_vs(x)

        # Write the keys/values of the current step into the tail of the
        # caller-provided cache; earlier positions keep their cached values.
        k_cache[:, -k.shape[1]:, :] = k
        v_cache[:, -v.shape[1]:, :] = v

        q = q.view(bs, -1, self.multiHeadSelfAttention.n_head, self.multiHeadSelfAttention.d_k)
        k = k_cache.view(bs, -1, self.multiHeadSelfAttention.n_head, self.multiHeadSelfAttention.d_k)
        v = v_cache.view(bs, -1, self.multiHeadSelfAttention.n_head, self.multiHeadSelfAttention.d_k)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        if mask is not None:
            mask = mask.unsqueeze(1)

        output = self.multiHeadSelfAttention.attention(q, k, v, mask)
        output = output.transpose(1, 2).contiguous().view(bs, -1, self.multiHeadSelfAttention.d_model)
        output = self.multiHeadSelfAttention.fc(output)
        output = self.multiHeadSelfAttention.dropout(output)

        return output, k_cache, v_cache
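
# Cache convention of this variant (illustrative only): the caller passes a
# cache slice that already spans every decoded position including the current
# one, e.g. for one new token (see TextDecoderTensorCache below)
#
#   self_k_cache = n_layer_self_k_cache[i, :, : offset[0] + 1, :]   # (N, offset + 1, d_model)
#
# and the key/value of the new token overwrite the last slot in place.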


class DecoderMultiHeadSelfAttentionV2(nn.Module):
    def __init__(self, multiHeadSelfAttention: DecoderMultiHeadAttention, loop: bool = False):
        super().__init__()
        self.multiHeadSelfAttention = multiHeadSelfAttention
        self.loop = loop

    def forward(self,
                x: Tensor,
                k_cache: Tensor,
                v_cache: Tensor,
                mask: Tensor):
        bs = x.size(0)

        q = self.multiHeadSelfAttention.w_qs(x)
        k = self.multiHeadSelfAttention.w_ks(x)
        v = self.multiHeadSelfAttention.w_vs(x)

        if self.loop:
            # Decoding loop: drop the oldest cached position and append the
            # key/value of the current step, keeping the cache length fixed.
            k_cache = torch.cat([k_cache[:, 1:, :], k], 1)
            v_cache = torch.cat([v_cache[:, 1:, :], v], 1)
        else:
            # Non-loop (first) call: re-initialise the cache from the current tokens.
            k_cache = k
            v_cache = v

        q = q.view(bs, -1, self.multiHeadSelfAttention.n_head, self.multiHeadSelfAttention.d_k)
        k = k_cache.view(bs, -1, self.multiHeadSelfAttention.n_head, self.multiHeadSelfAttention.d_k)
        v = v_cache.view(bs, -1, self.multiHeadSelfAttention.n_head, self.multiHeadSelfAttention.d_k)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        if mask is not None:
            mask = mask.unsqueeze(1)

        output = self.multiHeadSelfAttention.attention(q, k, v, mask)
        output = output.transpose(1, 2).contiguous().view(bs, -1, self.multiHeadSelfAttention.d_model)
        output = self.multiHeadSelfAttention.fc(output)
        output = self.multiHeadSelfAttention.dropout(output)

        return output, k_cache, v_cache
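
# A minimal sketch of the fixed-length cache in loop mode (illustrative only;
# `layer` stands for a DecoderLayer and d_model = 4 is an assumption):
#
#   attn = DecoderMultiHeadSelfAttentionV2(layer.self_attn, loop=True)
#   k_cache = torch.zeros(1, 8, 4)            # 8 cached positions
#   v_cache = torch.zeros(1, 8, 4)
#   x = torch.randn(1, 1, 4)                  # one new token per step
#   out, k_cache, v_cache = attn(x, k_cache, v_cache, mask=None)
#   # k_cache / v_cache keep shape (1, 8, 4): the oldest slot is dropped and
#   # the new key/value are appended at the end.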


class DecoderMultiHeadCrossAttention(nn.Module):
    def __init__(self, multiHeadCrossAttention: DecoderMultiHeadAttention):
        super().__init__()
        self.multiHeadCrossAttention = multiHeadCrossAttention

    def forward(self,
                x: Tensor,
                k: Tensor,
                v: Tensor,
                mask: Tensor):
        # k and v are already projected by AudioEncoderTensorCache, so only
        # the query projection is applied here.
        bs = x.size(0)
        x = self.multiHeadCrossAttention.w_qs(x)
        x = x.view(bs, -1, self.multiHeadCrossAttention.n_head, self.multiHeadCrossAttention.d_k)
        k = k.view(bs, -1, self.multiHeadCrossAttention.n_head, self.multiHeadCrossAttention.d_k)
        v = v.view(bs, -1, self.multiHeadCrossAttention.n_head, self.multiHeadCrossAttention.d_k)

        x = x.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        if mask is not None:
            mask = mask.unsqueeze(1)

        output = self.multiHeadCrossAttention.attention(x, k, v, mask)
        output = output.transpose(1, 2).contiguous().view(bs, -1, self.multiHeadCrossAttention.d_model)
        output = self.multiHeadCrossAttention.fc(output)
        output = self.multiHeadCrossAttention.dropout(output)

        return output


class ResidualAttentionBlockTensorCache(nn.Module):
    def __init__(self, decoder_layer: DecoderLayer, loop: bool = False):
        super().__init__()
        self.original_decoder_layer = decoder_layer
        self.self_attn = DecoderMultiHeadSelfAttention(decoder_layer.self_attn, loop)
        self.cross_attn = DecoderMultiHeadCrossAttention(decoder_layer.cross_attn)

    def forward(self,
                x: Tensor,
                self_k_cache: Tensor,
                self_v_cache: Tensor,
                cross_k: Tensor,
                cross_v: Tensor,
                self_attn_mask: Tensor,
                cross_attn_mask: Tensor):
        # Pre-norm residual block: self-attention, cross-attention, then MLP.
        x_self_attn_norm = self.original_decoder_layer.self_attn_norm(x)
        self_attn_x, self_k_cache_updated, self_v_cache_updated = self.self_attn(
            x_self_attn_norm, self_k_cache, self_v_cache, self_attn_mask)

        x = x + self_attn_x

        residual = x
        x_cross_attn_norm = self.original_decoder_layer.cross_attn_norm(x)
        x_cross_attn = self.cross_attn(x_cross_attn_norm, cross_k, cross_v, cross_attn_mask)
        x = residual + x_cross_attn

        x = x + self.original_decoder_layer.mlp(self.original_decoder_layer.mlp_norm(x))

        return x, self_k_cache_updated, self_v_cache_updated


class ResidualAttentionBlockTensorCacheV2(nn.Module):
    def __init__(self, decoder_layer: DecoderLayer, loop: bool = False):
        super().__init__()
        self.original_decoder_layer = decoder_layer
        self.self_attn = DecoderMultiHeadSelfAttentionV2(decoder_layer.self_attn, loop)
        self.cross_attn = DecoderMultiHeadCrossAttention(decoder_layer.cross_attn)

    def forward(self,
                x: Tensor,
                self_k_cache: Tensor,
                self_v_cache: Tensor,
                cross_k: Tensor,
                cross_v: Tensor,
                self_attn_mask: Tensor,
                cross_attn_mask: Tensor):
        # Same pre-norm residual block as above, but with the V2 self-attention cache.
        x_self_attn_norm = self.original_decoder_layer.self_attn_norm(x)
        self_attn_x, self_k_cache_updated, self_v_cache_updated = self.self_attn(
            x_self_attn_norm, self_k_cache, self_v_cache, self_attn_mask)

        x = x + self_attn_x

        residual = x
        x_cross_attn_norm = self.original_decoder_layer.cross_attn_norm(x)
        x_cross_attn = self.cross_attn(x_cross_attn_norm, cross_k, cross_v, cross_attn_mask)
        x = residual + x_cross_attn

        x = x + self.original_decoder_layer.mlp(self.original_decoder_layer.mlp_norm(x))

        return x, self_k_cache_updated, self_v_cache_updated


class TextDecoderTensorCache(nn.Module):
    def __init__(self, decoder: TransformerDecoder):
        super().__init__()
        self.decoder = decoder

        self.blocks = []
        for original_layer in self.decoder.layer_stack:
            self.blocks.append(
                ResidualAttentionBlockTensorCache(original_layer))

    def forward(self,
                tokens: Tensor,
                n_layer_self_k_cache: Tensor,
                n_layer_self_v_cache: Tensor,
                n_layer_cross_k: Tensor,
                n_layer_cross_v: Tensor,
                offset: Tensor,
                self_attn_mask: Tensor,
                cross_attn_mask: Tensor):
        """
        TODO(Lianghu): Integrate self_attn_mask into the model inference process
        instead of passing it in through an external interface.
        """
        x = self.decoder.dropout(
            self.decoder.tgt_word_emb(tokens) * self.decoder.scale +
            self.decoder.positional_encoding(offset + 1)
        )

        for i, block in enumerate(self.blocks):
            # Slice each layer's cache up to and including the positions decoded
            # at this step, then write the updated slice back.
            self_k_cache = n_layer_self_k_cache[i, :, : offset[0] + tokens.shape[-1], :]
            self_v_cache = n_layer_self_v_cache[i, :, : offset[0] + tokens.shape[-1], :]
            x, self_k_cache, self_v_cache = block(
                x,
                self_k_cache,
                self_v_cache,
                n_layer_cross_k[i],
                n_layer_cross_v[i],
                self_attn_mask,
                cross_attn_mask
            )
            n_layer_self_k_cache[i, :, : offset[0] + tokens.shape[-1], :] = self_k_cache
            n_layer_self_v_cache[i, :, : offset[0] + tokens.shape[-1], :] = self_v_cache

        output = self.decoder.layer_norm_out(x)
        logits = self.decoder.tgt_word_prj(output)

        return logits, n_layer_self_k_cache, n_layer_self_v_cache
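
# A minimal single-step decoding sketch (illustrative only; `decoder`, `sos_id`,
# `n_layer`, `max_len`, `d_model`, and the cross-attention tensors/masks are
# assumed to come from the surrounding export/inference code):
#
#   text_decoder = TextDecoderTensorCache(decoder).eval()
#   tokens = torch.tensor([[sos_id]])                       # (N, 1)
#   self_k = torch.zeros(n_layer, 1, max_len, d_model)
#   self_v = torch.zeros(n_layer, 1, max_len, d_model)
#   offset = torch.zeros(1, dtype=torch.int64)
#   logits, self_k, self_v = text_decoder(
#       tokens, self_k, self_v, cross_k, cross_v, offset,
#       self_attn_mask, cross_attn_mask)
#   next_token = logits[:, -1].argmax(-1)
#   # On the next step, pass only `next_token` and offset + 1.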


class TextDecoderTensorCacheV2(nn.Module):
    def __init__(self, decoder: TransformerDecoder, loop: bool = False):
        super().__init__()
        self.decoder = decoder
        self.loop = loop

        self.blocks = []
        for original_layer in self.decoder.layer_stack:
            self.blocks.append(
                ResidualAttentionBlockTensorCacheV2(original_layer, loop))

    def forward(self,
                tokens: Tensor,
                n_layer_self_k_cache: Tensor,
                n_layer_self_v_cache: Tensor,
                n_layer_cross_k: Tensor,
                n_layer_cross_v: Tensor,
                positional_embedding: Tensor,
                self_attn_mask: Tensor,
                cross_attn_mask: Tensor):
        """
        TODO(Lianghu): Integrate self_attn_mask into the model inference process
        instead of passing it in through an external interface.
        """
        x = self.decoder.dropout(
            self.decoder.tgt_word_emb(tokens) * self.decoder.scale +
            positional_embedding
        )

        self_k_cache_out = []
        self_v_cache_out = []
        for i, block in enumerate(self.blocks):
            self_k_cache = n_layer_self_k_cache[i, :, :, :]
            self_v_cache = n_layer_self_v_cache[i, :, :, :]
            if self.loop:
                x, self_k_cache, self_v_cache = block(
                    x,
                    self_k_cache,
                    self_v_cache,
                    n_layer_cross_k[i],
                    n_layer_cross_v[i],
                    self_attn_mask,
                    cross_attn_mask
                )
                self_k_cache_out.append(self_k_cache.unsqueeze(0))
                self_v_cache_out.append(self_v_cache.unsqueeze(0))
            else:
                n_audio, n_text_ctx, ntext_state = self_k_cache.shape

                x, self_k_cache, self_v_cache = block(
                    x,
                    self_k_cache,
                    self_v_cache,
                    n_layer_cross_k[i],
                    n_layer_cross_v[i],
                    self_attn_mask,
                    cross_attn_mask
                )
                # Left-pad the returned caches with zeros so every layer's cache
                # keeps the fixed (n_audio, n_text_ctx, ntext_state) shape.
                self_k_cache_out.append(torch.cat(
                    (torch.zeros([n_audio, n_text_ctx - self_k_cache.shape[1], ntext_state]).to(self_k_cache.device),
                     self_k_cache), 1).unsqueeze(0))
                self_v_cache_out.append(torch.cat(
                    (torch.zeros([n_audio, n_text_ctx - self_v_cache.shape[1], ntext_state]).to(self_v_cache.device),
                     self_v_cache), 1).unsqueeze(0))

        n_layer_self_k_cache = torch.cat(self_k_cache_out, 0)
        n_layer_self_v_cache = torch.cat(self_v_cache_out, 0)

        output = self.decoder.layer_norm_out(x)
        logits = self.decoder.tgt_word_prj(output)

        return logits, n_layer_self_k_cache, n_layer_self_v_cache