#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import math
from importlib.util import find_spec
from pathlib import Path
from typing import List, Optional, Tuple, Union

import torch

wkv_kernel_encoder = None
wkv_kernel_decoder = None
class WKVLinearAttentionEncoder(torch.autograd.Function):
    """WKVLinearAttention function definition."""

    @staticmethod
    def forward(
        ctx,
        time_decay: torch.Tensor,
        time_first: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """WKVLinearAttention function forward pass.

        Args:
            time_decay: Channel-wise time decay vector. (D_att)
            time_first: Channel-wise time first vector. (D_att)
            key: Key tensor. (B, U, D_att)
            value: Value tensor. (B, U, D_att)

        Returns:
            out: Weighted Key-Value tensor. (B, U, D_att)

        """
        batch, length, dim = key.size()

        assert length <= wkv_kernel_encoder.context_size, (
            f"Cannot process key of length {length} while context_size "
            f"is ({wkv_kernel_encoder.context_size}). Limit should be increased."
        )
        assert batch * dim % min(dim, 32) == 0, (
            f"batch size ({batch}) by dimension ({dim}) should be a multiple of "
            f"{min(dim, 32)}"
        )

        ctx.input_dtype = key.dtype

        # The CUDA kernel expects contiguous float32 inputs.
        time_decay = -torch.exp(time_decay.float().contiguous())
        time_first = time_first.float().contiguous()
        key = key.float().contiguous()
        value = value.float().contiguous()

        out = torch.empty_like(key, memory_format=torch.contiguous_format)

        wkv_kernel_encoder.forward(time_decay, time_first, key, value, out)
        ctx.save_for_backward(time_decay, time_first, key, value, out)

        return out

    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """WKVLinearAttention function backward pass.

        Args:
            grad_output: Output gradient. (B, U, D_att)

        Returns:
            grad_time_decay: Gradient for channel-wise time decay vector. (D_att)
            grad_time_first: Gradient for channel-wise time first vector. (D_att)
            grad_key: Gradient for key tensor. (B, U, D_att)
            grad_value: Gradient for value tensor. (B, U, D_att)

        """
        time_decay, time_first, key, value, output = ctx.saved_tensors

        grad_dtype = ctx.input_dtype
        batch, _, dim = key.size()

        grad_time_decay = torch.empty(
            (batch, dim),
            memory_format=torch.contiguous_format,
            dtype=time_decay.dtype,
            device=time_decay.device,
        )
        grad_time_first = torch.empty(
            (batch, dim),
            memory_format=torch.contiguous_format,
            dtype=time_decay.dtype,
            device=time_decay.device,
        )
        grad_key = torch.empty_like(key, memory_format=torch.contiguous_format)
        grad_value = torch.empty_like(value, memory_format=torch.contiguous_format)

        wkv_kernel_encoder.backward(
            time_decay,
            time_first,
            key,
            value,
            output,
            grad_output.contiguous(),
            grad_time_decay,
            grad_time_first,
            grad_key,
            grad_value,
        )

        # The per-channel gradients are accumulated over the batch dimension.
        grad_time_decay = torch.sum(grad_time_decay, dim=0)
        grad_time_first = torch.sum(grad_time_first, dim=0)

        return (
            grad_time_decay,
            grad_time_first,
            grad_key,
            grad_value,
        )
class WKVLinearAttentionDecoder(torch.autograd.Function):
    """WKVLinearAttention function definition."""

    @staticmethod
    def forward(
        ctx,
        time_decay: torch.Tensor,
        time_first: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """WKVLinearAttention function forward pass.

        Args:
            time_decay: Channel-wise time decay vector. (D_att)
            time_first: Channel-wise time first vector. (D_att)
            key: Key tensor. (B, U, D_att)
            value: Value tensor. (B, U, D_att)

        Returns:
            out: Weighted Key-Value tensor. (B, U, D_att)

        """
        batch, length, dim = key.size()

        assert length <= wkv_kernel_decoder.context_size, (
            f"Cannot process key of length {length} while context_size "
            f"is ({wkv_kernel_decoder.context_size}). Limit should be increased."
        )
        assert batch * dim % min(dim, 32) == 0, (
            f"batch size ({batch}) by dimension ({dim}) should be a multiple of "
            f"{min(dim, 32)}"
        )

        ctx.input_dtype = key.dtype

        # The CUDA kernel expects contiguous float32 inputs.
        time_decay = -torch.exp(time_decay.float().contiguous())
        time_first = time_first.float().contiguous()
        key = key.float().contiguous()
        value = value.float().contiguous()

        out = torch.empty_like(key, memory_format=torch.contiguous_format)

        wkv_kernel_decoder.forward(time_decay, time_first, key, value, out)
        ctx.save_for_backward(time_decay, time_first, key, value, out)

        return out

    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """WKVLinearAttention function backward pass.

        Args:
            grad_output: Output gradient. (B, U, D_att)

        Returns:
            grad_time_decay: Gradient for channel-wise time decay vector. (D_att)
            grad_time_first: Gradient for channel-wise time first vector. (D_att)
            grad_key: Gradient for key tensor. (B, U, D_att)
            grad_value: Gradient for value tensor. (B, U, D_att)

        """
        time_decay, time_first, key, value, output = ctx.saved_tensors

        grad_dtype = ctx.input_dtype
        batch, _, dim = key.size()

        grad_time_decay = torch.empty(
            (batch, dim),
            memory_format=torch.contiguous_format,
            dtype=time_decay.dtype,
            device=time_decay.device,
        )
        grad_time_first = torch.empty(
            (batch, dim),
            memory_format=torch.contiguous_format,
            dtype=time_decay.dtype,
            device=time_decay.device,
        )
        grad_key = torch.empty_like(key, memory_format=torch.contiguous_format)
        grad_value = torch.empty_like(value, memory_format=torch.contiguous_format)

        wkv_kernel_decoder.backward(
            time_decay,
            time_first,
            key,
            value,
            output,
            grad_output.contiguous(),
            grad_time_decay,
            grad_time_first,
            grad_key,
            grad_value,
        )

        # The per-channel gradients are accumulated over the batch dimension.
        grad_time_decay = torch.sum(grad_time_decay, dim=0)
        grad_time_first = torch.sum(grad_time_first, dim=0)

        return (
            grad_time_decay,
            grad_time_first,
            grad_key,
            grad_value,
        )
def load_encoder_wkv_kernel(context_size: int) -> None:
    """Load the encoder WKV CUDA kernel.

    Args:
        context_size: Context size.

    """
    from torch.utils.cpp_extension import load

    global wkv_kernel_encoder

    if (
        wkv_kernel_encoder is not None
        and wkv_kernel_encoder.context_size == context_size
    ):
        return

    if find_spec("ninja") is None:
        raise ImportError(
            "Ninja package was not found. WKV kernel module can't be loaded "
            "for training. Please run 'pip install ninja' in your environment."
        )

    if not torch.cuda.is_available():
        raise ImportError(
            "CUDA is currently a requirement for WKV kernel loading. "
            "Please set your devices properly and launch again."
        )

    kernel_folder = Path(__file__).resolve().parent / "cuda_encoder"
    kernel_files = [kernel_folder / f for f in ["wkv_op.cpp", "wkv_cuda.cu"]]

    kernel_cflags = [
        "-res-usage",
        "--maxrregcount 60",
        "--use_fast_math",
        "-O3",
        "-Xptxas -O3",
        f"-DTmax={context_size}",
    ]

    wkv_kernel_encoder = load(
        name=f"encoder_wkv_{context_size}",
        sources=kernel_files,
        verbose=True,
        extra_cuda_cflags=kernel_cflags,
    )
    wkv_kernel_encoder.context_size = context_size
def load_decoder_wkv_kernel(context_size: int) -> None:
    """Load the decoder WKV CUDA kernel.

    Args:
        context_size: Context size.

    """
    from torch.utils.cpp_extension import load

    global wkv_kernel_decoder

    if (
        wkv_kernel_decoder is not None
        and wkv_kernel_decoder.context_size == context_size
    ):
        return

    if find_spec("ninja") is None:
        raise ImportError(
            "Ninja package was not found. WKV kernel module can't be loaded "
            "for training. Please run 'pip install ninja' in your environment."
        )

    if not torch.cuda.is_available():
        raise ImportError(
            "CUDA is currently a requirement for WKV kernel loading. "
            "Please set your devices properly and launch again."
        )

    kernel_folder = Path(__file__).resolve().parent / "cuda_decoder"
    kernel_files = [kernel_folder / f for f in ["wkv_op.cpp", "wkv_cuda.cu"]]

    kernel_cflags = [
        "-res-usage",
        "--maxrregcount 60",
        "--use_fast_math",
        "-O3",
        "-Xptxas -O3",
        f"-DTmax={context_size}",
    ]

    wkv_kernel_decoder = load(
        name=f"decoder_wkv_{context_size}",
        sources=kernel_files,
        verbose=True,
        extra_cuda_cflags=kernel_cflags,
    )
    wkv_kernel_decoder.context_size = context_size
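
# Illustrative note (an assumption about intended usage, not part of the original
# file): the loaders above are expected to be called once, on a CUDA-capable
# machine, before the first forward pass that reaches WKVLinearAttention*.apply,
# e.g. load_encoder_wkv_kernel(512) and load_decoder_wkv_kernel(512). Calling a
# loader again with the same context_size is a no-op. The pure-PyTorch,
# state-based path in SelfAttention.wkv_linear_attention below does not require
# the compiled kernels.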
class SelfAttention(torch.nn.Module):
    """SelfAttention module definition.

    Args:
        size: Input/Output size.
        attention_size: Attention hidden size.
        block_id: Block index.
        dropout_rate: Dropout rate.
        num_blocks: Number of blocks in the architecture.

    """

    def __init__(
        self,
        size: int,
        attention_size: int,
        block_id: int,
        dropout_rate: float,
        num_blocks: int,
    ) -> None:
        """Construct a SelfAttention object."""
        super().__init__()

        # Shift the sequence one step along the time axis (RWKV token shift).
        self.time_shift = torch.nn.ZeroPad2d((0, 0, 1, -1))

        self.time_decay = torch.nn.Parameter(torch.empty(attention_size))
        self.time_first = torch.nn.Parameter(torch.empty(attention_size))

        self.time_mix_key = torch.nn.Parameter(torch.empty(1, 1, size))
        self.time_mix_value = torch.nn.Parameter(torch.empty(1, 1, size))
        self.time_mix_receptance = torch.nn.Parameter(torch.empty(1, 1, size))

        self.proj_key = torch.nn.Linear(size, attention_size, bias=True)
        self.proj_value = torch.nn.Linear(size, attention_size, bias=True)
        self.proj_receptance = torch.nn.Linear(size, attention_size, bias=True)

        self.proj_output = torch.nn.Linear(attention_size, size, bias=True)

        self.block_id = block_id

        self.reset_parameters(size, attention_size, block_id, num_blocks)
        self.dropout = torch.nn.Dropout(p=dropout_rate)

    def reset_parameters(
        self, size: int, attention_size: int, block_id: int, num_blocks: int
    ) -> None:
        """Reset module parameters.

        Args:
            size: Block size.
            attention_size: Attention hidden size.
            block_id: Block index.
            num_blocks: Number of blocks in the architecture.

        """
        ratio_0_to_1 = block_id / (num_blocks - 1)
        ratio_1_to_almost0 = 1.0 - (block_id / num_blocks)

        time_weight = torch.ones(1, 1, size)

        for i in range(size):
            time_weight[0, 0, i] = i / size

        decay_speed = [
            -5 + 8 * (h / (attention_size - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
            for h in range(attention_size)
        ]
        decay_speed = torch.tensor(
            decay_speed, dtype=self.time_decay.dtype, device=self.time_decay.device
        )

        zigzag = (
            torch.tensor(
                [(i + 1) % 3 - 1 for i in range(attention_size)],
                dtype=self.time_first.dtype,
                device=self.time_first.device,
            )
            * 0.5
        )

        with torch.no_grad():
            self.time_decay.data = decay_speed
            self.time_first.data = (
                torch.ones_like(self.time_first) * math.log(0.3) + zigzag
            )

            self.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
            self.time_mix_value.data = (
                torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
            )
            self.time_mix_receptance.data = torch.pow(
                time_weight, 0.5 * ratio_1_to_almost0
            )

    def wkv_linear_attention(
        self,
        time_decay: torch.Tensor,
        time_first: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        state: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """Compute WKV with state (i.e.: for inference).

        Args:
            time_decay: Channel-wise time decay vector. (D_att)
            time_first: Channel-wise time first vector. (D_att)
            key: Key tensor. (B, 1, D_att)
            value: Value tensor. (B, 1, D_att)
            state: Decoder hidden states. [3 x (B, D_att)]

        Returns:
            output: Weighted Key-Value. (B, 1, D_att)
            state: Decoder hidden states. [3 x (B, 1, D_att)]

        """
        num_state, den_state, max_state = state

        time_decay = -torch.exp(time_decay)

        # Numerically stable update: keep a running maximum (max_state) so that
        # every exponential below is evaluated at a non-positive argument.
        max_for_output = torch.maximum(max_state, (time_first + key))

        e1 = torch.exp(max_state - max_for_output)
        e2 = torch.exp((time_first + key) - max_for_output)

        numerator = e1 * num_state + e2 * value
        denominator = e1 * den_state + e2

        max_for_state = torch.maximum(key, (max_state + time_decay))

        e1 = torch.exp((max_state + time_decay) - max_for_state)
        e2 = torch.exp(key - max_for_state)

        wkv = numerator / denominator

        state = [e1 * num_state + e2 * value, e1 * den_state + e2, max_for_state]

        return wkv, state
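
# Reference recurrence (a documentation sketch, not executed code): with
# u = time_first, w = -exp(time_decay), and per-step key/value k_t, v_t,
# wkv_linear_attention above computes
#
#     wkv_t = (a_{t-1} + exp(u + k_t) * v_t) / (b_{t-1} + exp(u + k_t))
#     a_t   = exp(w) * a_{t-1} + exp(k_t) * v_t
#     b_t   = exp(w) * b_{t-1} + exp(k_t)
#
# where num_state and den_state hold a and b rescaled by exp(-max_state), so
# the rescaling factor cancels in wkv_t and the exponentials cannot overflow.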
class DecoderSelfAttention(SelfAttention):
    """DecoderSelfAttention module definition.

    Args:
        size: Input/Output size.
        attention_size: Attention hidden size.
        context_size: Context size for WKV kernel.
        block_id: Block index.
        dropout_rate: Dropout rate.
        num_blocks: Number of blocks in the architecture.

    """

    def __init__(
        self,
        size: int,
        attention_size: int,
        context_size: int,
        block_id: int,
        dropout_rate: float,
        num_blocks: int,
    ) -> None:
        """Construct a DecoderSelfAttention object."""
        super().__init__(size, attention_size, block_id, dropout_rate, num_blocks)

        # load_decoder_wkv_kernel(context_size)

    def forward(
        self,
        x: torch.Tensor,
        state: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
        """Compute time mixing.

        Args:
            x: SelfAttention input sequences. (B, U, size)
            state: Decoder hidden states. [5 x (B, 1, D_att, N)]

        Returns:
            x: SelfAttention output sequences. (B, U, size)
            state: Updated decoder hidden states. [5 x (B, 1, D_att, N)]

        """
        # Token shift: mix the current input with the previous time step
        # (taken from the cached state during step-by-step inference).
        shifted_x = (
            self.time_shift(x) if state is None else state[1][..., self.block_id]
        )

        key = x * self.time_mix_key + shifted_x * (1 - self.time_mix_key)
        value = x * self.time_mix_value + shifted_x * (1 - self.time_mix_value)
        receptance = x * self.time_mix_receptance + shifted_x * (
            1 - self.time_mix_receptance
        )

        key = self.proj_key(key)
        value = self.proj_value(value)
        receptance = torch.sigmoid(self.proj_receptance(receptance))

        if state is not None:
            state[1][..., self.block_id] = x

            wkv, att_state = self.wkv_linear_attention(
                self.time_decay,
                self.time_first,
                key,
                value,
                tuple(s[..., self.block_id] for s in state[2:]),
            )

            state[2][..., self.block_id] = att_state[0]
            state[3][..., self.block_id] = att_state[1]
            state[4][..., self.block_id] = att_state[2]
        else:
            wkv = WKVLinearAttentionDecoder.apply(
                self.time_decay, self.time_first, key, value
            )

        wkv = self.dropout(wkv)
        x = self.proj_output(receptance * wkv)

        return x, state
class EncoderSelfAttention(SelfAttention):
    """EncoderSelfAttention module definition.

    Args:
        size: Input/Output size.
        attention_size: Attention hidden size.
        context_size: Context size for WKV kernel.
        block_id: Block index.
        dropout_rate: Dropout rate.
        num_blocks: Number of blocks in the architecture.

    """

    def __init__(
        self,
        size: int,
        attention_size: int,
        context_size: int,
        block_id: int,
        dropout_rate: float,
        num_blocks: int,
    ) -> None:
        """Construct an EncoderSelfAttention object."""
        super().__init__(size, attention_size, block_id, dropout_rate, num_blocks)

        # load_encoder_wkv_kernel(context_size)

    def forward(
        self,
        x: torch.Tensor,
        state: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
        """Compute time mixing.

        Args:
            x: SelfAttention input sequences. (B, U, size)
            state: Decoder hidden states. [5 x (B, 1, D_att, N)]

        Returns:
            x: SelfAttention output sequences. (B, U, size)
            state: Updated decoder hidden states. [5 x (B, 1, D_att, N)]

        """
        # Token shift: mix the current input with the previous time step
        # (taken from the cached state during step-by-step inference).
        shifted_x = (
            self.time_shift(x) if state is None else state[1][..., self.block_id]
        )

        key = x * self.time_mix_key + shifted_x * (1 - self.time_mix_key)
        value = x * self.time_mix_value + shifted_x * (1 - self.time_mix_value)
        receptance = x * self.time_mix_receptance + shifted_x * (
            1 - self.time_mix_receptance
        )

        key = self.proj_key(key)
        value = self.proj_value(value)
        receptance = torch.sigmoid(self.proj_receptance(receptance))

        if state is not None:
            state[1][..., self.block_id] = x

            wkv, att_state = self.wkv_linear_attention(
                self.time_decay,
                self.time_first,
                key,
                value,
                tuple(s[..., self.block_id] for s in state[2:]),
            )

            state[2][..., self.block_id] = att_state[0]
            state[3][..., self.block_id] = att_state[1]
            state[4][..., self.block_id] = att_state[2]
        else:
            wkv = WKVLinearAttentionEncoder.apply(
                self.time_decay, self.time_first, key, value
            )

        wkv = self.dropout(wkv)
        x = self.proj_output(receptance * wkv)

        return x, state
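
# Minimal smoke test (a hedged sketch, not part of the original API): run one
# step of stateful inference through DecoderSelfAttention. This exercises only
# the pure-PyTorch wkv_linear_attention path, so no CUDA kernel is required.
# The state layout follows the docstrings above: a list of 5 tensors of shape
# (B, 1, D_att, N); state[0] is untouched by this module, and the running
# maximum (state[4]) starts at -inf.
if __name__ == "__main__":
    torch.manual_seed(0)

    batch, size, att_size, num_blocks = 2, 8, 8, 2
    attention = DecoderSelfAttention(
        size=size,
        attention_size=att_size,
        context_size=64,
        block_id=0,
        dropout_rate=0.0,
        num_blocks=num_blocks,
    ).eval()

    x = torch.randn(batch, 1, size)
    state = [torch.zeros(batch, 1, att_size, num_blocks) for _ in range(5)]
    state[4] = torch.full((batch, 1, att_size, num_blocks), float("-inf"))

    with torch.no_grad():
        out, state = attention(x, state=state)

    print(out.shape)  # expected: torch.Size([2, 1, 8])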