#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

from typing import Dict, List, Optional, Tuple

import torch

from funasr_detach.models.rwkv_bat.rwkv import RWKV
from funasr_detach.models.rwkv_bat.rwkv_subsampling import RWKVConvInput
from funasr_detach.models.transformer.layer_norm import LayerNorm
from funasr_detach.models.transformer.utils.nets_utils import make_source_mask
from funasr_detach.register import tables
class RWKVEncoder(torch.nn.Module):
    """RWKV encoder module.

    Based on https://arxiv.org/pdf/2305.13048.pdf.

    Args:
        input_size: Input feature size.
        output_size: Encoder hidden / output size.
        context_size: Maximum context size for the WKV computation.
        linear_size: FeedForward hidden size (defaults to ``output_size * 4``).
        attention_size: SelfAttention hidden size (defaults to ``output_size``).
        num_blocks: Number of RWKV blocks.
        att_dropout_rate: Dropout rate for the attention module.
        ffn_dropout_rate: Dropout rate for the feed-forward module.
        dropout_rate: Dropout rate passed to each RWKV block.
        subsampling_factor: Time subsampling factor of the conv frontend.
        time_reduction_factor: Extra time subsampling applied to the output.
        kernel: Convolution kernel size of the frontend.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 512,
        context_size: int = 1024,
        linear_size: Optional[int] = None,
        attention_size: Optional[int] = None,
        num_blocks: int = 4,
        att_dropout_rate: float = 0.0,
        ffn_dropout_rate: float = 0.0,
        dropout_rate: float = 0.0,
        subsampling_factor: int = 4,
        time_reduction_factor: int = 1,
        kernel: int = 3,
        **kwargs,
    ) -> None:
        """Construct a RWKVEncoder object."""
        super().__init__()

        # Convolutional frontend: subsamples time and projects features
        # up to output_size through three stages.
        self.embed = RWKVConvInput(
            input_size,
            [output_size // 4, output_size // 2, output_size],
            subsampling_factor,
            conv_kernel_size=kernel,
            output_size=output_size,
        )
        self.subsampling_factor = subsampling_factor

        # Fall back to the conventional RWKV defaults when not given.
        linear_size = output_size * 4 if linear_size is None else linear_size
        attention_size = output_size if attention_size is None else attention_size

        self.rwkv_blocks = torch.nn.ModuleList(
            [
                RWKV(
                    output_size,
                    linear_size,
                    attention_size,
                    context_size,
                    block_id,
                    num_blocks,
                    att_dropout_rate=att_dropout_rate,
                    ffn_dropout_rate=ffn_dropout_rate,
                    dropout_rate=dropout_rate,
                )
                for block_id in range(num_blocks)
            ]
        )

        self.embed_norm = LayerNorm(output_size)
        self.final_norm = LayerNorm(output_size)

        self._output_size = output_size
        self.context_size = context_size
        self.num_blocks = num_blocks
        self.time_reduction_factor = time_reduction_factor

    def output_size(self) -> int:
        """Return the encoder output dimension."""
        return self._output_size

    def forward(self, x: torch.Tensor, x_len) -> torch.Tensor:
        """Encode input feature sequences.

        Args:
            x: Encoder input features. (B, T, F)
            x_len: Valid lengths of each sequence in the batch. (B,)

        Returns:
            x: Encoder output sequences. (B, U, D)
            olens: Output sequence lengths after subsampling. (B,)
            None: Placeholder kept for interface compatibility.
        """
        _, length, _ = x.size()

        # WKV recurrence only covers context_size steps after subsampling.
        assert (
            length <= self.context_size * self.subsampling_factor
        ), "Context size is too short for current length: %d versus %d" % (
            length,
            self.context_size * self.subsampling_factor,
        )

        mask = make_source_mask(x_len).to(x.device)
        x, mask = self.embed(x, mask, None)
        x = self.embed_norm(x)
        olens = mask.eq(0).sum(1)

        if self.training:
            # Training runs the blocks fully parallel over time.
            for block in self.rwkv_blocks:
                x, _ = block(x)
        else:
            # Inference runs the stateful step-by-step recurrence.
            x = self.rwkv_infer(x)

        x = self.final_norm(x)

        if self.time_reduction_factor > 1:
            x = x[:, :: self.time_reduction_factor, :]
            olens = torch.floor_divide(olens - 1, self.time_reduction_factor) + 1

        return x, olens, None

    def rwkv_infer(self, xs_pad):
        """Run the RWKV blocks recurrently, one time step at a time.

        Args:
            xs_pad: Embedded input sequences. (B, U, D)

        Returns:
            xs_out: Encoded sequences. (B, U, D)
        """
        batch_size = xs_pad.shape[0]

        # Five recurrent state tensors per time step (att/ffn shift states and
        # the WKV numerator/denominator/max accumulators), each sized
        # (B, 1, D, num_blocks).
        hidden_sizes = [self._output_size for i in range(5)]
        state = [
            torch.zeros(
                (batch_size, 1, hidden_sizes[i], self.num_blocks),
                dtype=torch.float32,
                device=xs_pad.device,
            )
            for i in range(5)
        ]
        # Initialize the max-state accumulator slightly below zero for
        # numerical stability of the WKV softmax-style normalization.
        state[4] -= 1e-30

        xs_out = []
        for t in range(xs_pad.shape[1]):
            # Slice with t:t+1 to keep the time dimension: each step input is
            # (B, 1, D), matching the (B, 1, D, num_blocks) state layout.
            # Integer indexing would drop the dim and make the final
            # cat(dim=1) produce (B, U*D) instead of (B, U, D).
            x_t = xs_pad[:, t : t + 1, :]

            for block in self.rwkv_blocks:
                x_t, state = block(x_t, state=state)

            xs_out.append(x_t)

        xs_out = torch.cat(xs_out, dim=1)

        return xs_out