#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import torch

from typing import Dict, Optional, Tuple

from funasr_detach.models.transformer.layer_norm import LayerNorm
from funasr_detach.models.rwkv_bat.rwkv_feed_forward import FeedForward
from funasr_detach.models.rwkv_bat.rwkv_attention import (
    EncoderSelfAttention,
    DecoderSelfAttention,
)


class RWKV(torch.nn.Module):
    """RWKV encoder block module.

    Args:
        size: Input/Output size.
        linear_size: Feed-forward hidden size.
        attention_size: SelfAttention hidden size.
        context_size: Context size for WKV computation.
        block_id: Block index.
        num_blocks: Number of blocks in the architecture.
        att_dropout_rate: Dropout rate for the attention module.
        ffn_dropout_rate: Dropout rate for the feed-forward module.
        dropout_rate: Dropout rate applied to the residual branches.

    """
    def __init__(
        self,
        size: int,
        linear_size: int,
        attention_size: int,
        context_size: int,
        block_id: int,
        num_blocks: int,
        att_dropout_rate: float = 0.0,
        ffn_dropout_rate: float = 0.0,
        dropout_rate: float = 0.0,
    ) -> None:
        """Construct a RWKV object."""
        super().__init__()

        self.layer_norm_att = LayerNorm(size)
        self.layer_norm_ffn = LayerNorm(size)

        self.att = EncoderSelfAttention(
            size, attention_size, context_size, block_id, att_dropout_rate, num_blocks
        )
        self.dropout_att = torch.nn.Dropout(p=dropout_rate)

        self.ffn = FeedForward(
            size, linear_size, block_id, ffn_dropout_rate, num_blocks
        )
        self.dropout_ffn = torch.nn.Dropout(p=dropout_rate)

    def forward(
        self,
        x: torch.Tensor,
        state: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute receptance weighted key value.

        Args:
            x: RWKV input sequences. (B, L, size)
            state: Hidden states. [5 x (B, D_att/size, N)]

        Returns:
            x: RWKV output sequences. (B, L, size)
            state: Updated hidden states. [5 x (B, D_att/size, N)]

        """
        # Time-mix (attention) branch with pre-norm and residual connection.
        att, state = self.att(self.layer_norm_att(x), state=state)
        x = x + self.dropout_att(att)

        # Channel-mix (feed-forward) branch with pre-norm and residual connection.
        ffn, state = self.ffn(self.layer_norm_ffn(x), state=state)
        x = x + self.dropout_ffn(ffn)

        return x, state


class RWKVDecoderLayer(torch.nn.Module):
    """RWKV decoder layer module.

    Args:
        size: Input/Output size.
        linear_size: Feed-forward hidden size.
        attention_size: SelfAttention hidden size.
        context_size: Context size for WKV computation.
        block_id: Block index.
        num_blocks: Number of blocks in the architecture.
        att_dropout_rate: Dropout rate for the attention module.
        ffn_dropout_rate: Dropout rate for the feed-forward module.
        dropout_rate: Dropout rate applied to the residual branches.

    """
    def __init__(
        self,
        size: int,
        linear_size: int,
        attention_size: int,
        context_size: int,
        block_id: int,
        num_blocks: int,
        att_dropout_rate: float = 0.0,
        ffn_dropout_rate: float = 0.0,
        dropout_rate: float = 0.0,
    ) -> None:
        """Construct a RWKVDecoderLayer object."""
        super().__init__()

        self.layer_norm_att = LayerNorm(size)
        self.layer_norm_ffn = LayerNorm(size)

        self.att = DecoderSelfAttention(
            size, attention_size, context_size, block_id, att_dropout_rate, num_blocks
        )
        self.dropout_att = torch.nn.Dropout(p=dropout_rate)

        self.ffn = FeedForward(
            size, linear_size, block_id, ffn_dropout_rate, num_blocks
        )
        self.dropout_ffn = torch.nn.Dropout(p=dropout_rate)

    def forward(
        self,
        x: torch.Tensor,
        state: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute receptance weighted key value.

        Args:
            x: RWKV input sequences. (B, L, size)
            state: Decoder hidden states. [5 x (B, D_att/size, N)]

        Returns:
            x: RWKV output sequences. (B, L, size)
            state: Updated decoder hidden states. [5 x (B, D_att/size, N)]

        """
        # Time-mix (attention) branch with pre-norm and residual connection.
        att, state = self.att(self.layer_norm_att(x), state=state)
        x = x + self.dropout_att(att)

        # Channel-mix (feed-forward) branch with pre-norm and residual connection.
        ffn, state = self.ffn(self.layer_norm_ffn(x), state=state)
        x = x + self.dropout_ffn(ffn)

        return x, state
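

if __name__ == "__main__":
    # Illustrative smoke test, not part of the original FunASR file: a minimal
    # sketch of how the two blocks above might be instantiated and called.
    # The shapes and hyper-parameters (size=256, linear_size=4*size, etc.) are
    # arbitrary assumptions, and the underlying WKV attention may depend on a
    # custom CUDA kernel, so this sketch may only run where that kernel is
    # available.
    batch, length, size = 2, 16, 256

    encoder_block = RWKV(
        size=size,
        linear_size=4 * size,
        attention_size=size,
        context_size=length,
        block_id=0,
        num_blocks=1,
    )
    decoder_layer = RWKVDecoderLayer(
        size=size,
        linear_size=4 * size,
        attention_size=size,
        context_size=length,
        block_id=0,
        num_blocks=1,
    )

    x = torch.randn(batch, length, size)

    # Full-sequence (training-style) call: no recurrent state is provided, so
    # both blocks return the transformed sequence and the (unused) state.
    y, _ = encoder_block(x)
    z, _ = decoder_layer(x)
    print(y.shape, z.shape)  # expected: torch.Size([2, 16, 256]) for both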