| from functools import partial |
| from typing import List, Optional, Tuple, Union |
| import torch |
| import torch.utils.checkpoint as ckpt |
| from wenet.transformer.attention import T_CACHE |
|
|
| from wenet.transformer.encoder_layer import TransformerEncoderLayer |
| from wenet.utils.class_utils import (WENET_ACTIVATION_CLASSES, |
| WENET_ATTENTION_CLASSES, |
| WENET_EMB_CLASSES, WENET_MLP_CLASSES, |
| WENET_NORM_CLASSES) |
| from wenet.utils.common import mask_to_bias |
|
|
|
|
class DecoderOnly(torch.nn.Module):
    """Decoder-only stack of ``TransformerEncoderLayer`` blocks.

    Sub-modules are registered in this order:

    * ``pos_enc``: the ``"rope_pos"`` entry of ``WENET_EMB_CLASSES``.
    * ``decoders``: ``num_blocks`` layers, each combining an attention module
      from ``WENET_ATTENTION_CLASSES`` with an MLP from ``WENET_MLP_CLASSES``.
    * ``final_norm``: a norm from ``WENET_NORM_CLASSES``; only created when
      ``normalize_before`` is true, otherwise it stays ``None``.
    """

    def __init__(
        self,
        n_kv_head: int,
        head_dim: int,
        hidden_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        normalize_before: bool = True,
        query_bias: bool = False,
        key_bias: bool = False,
        value_bias: bool = False,
        mlp_bias: bool = False,
        activation_type: str = "gelu",
        gelu_approximate: Union[str, None] = None,
        max_position_embeding: int = 8192,
        mlp_type: str = 'gated',
        layer_norm_type: str = 'rms_norm',
        norm_eps: float = 1e-5,
        rms_norm_offset: bool = True,
        selfattention_layer_type: str = "rope_abs_selfattn",
        use_sdpa: bool = False,
        gradient_checkpointing: bool = False,
        rope_theta: float = 10000.0,
        rope_style: str = 'google',
        scale_embed: bool = True,
    ) -> None:
        """Build the positional module, the layer stack and the final norm.

        Args:
            n_kv_head: Forwarded to the attention class; also stored on self.
            head_dim: Forwarded to the positional and attention classes.
            hidden_size: Model width given to every sub-module.
            attention_heads: Forwarded to the attention class.
            linear_units: Forwarded to the MLP class.
            num_blocks: Number of layers in ``self.decoders``.
            dropout_rate: Forwarded to the MLP and to each layer.
            positional_dropout_rate: ``dropout_rate`` of the positional module.
            attention_dropout_rate: Forwarded to the attention class.
            normalize_before: Forwarded to each layer; when true a final norm
                is created as well.
            query_bias: Forwarded to the attention class.
            key_bias: Forwarded to the attention class.
            value_bias: Forwarded to the attention class.
            mlp_bias: Forwarded to the MLP class.
            activation_type: Key into ``WENET_ACTIVATION_CLASSES``.
            gelu_approximate: When not ``None`` and ``activation_type`` is
                ``"gelu"``, passed as ``approximate`` to the activation.
            max_position_embeding: ``max_len`` of the positional module.
            mlp_type: Key into ``WENET_MLP_CLASSES``.
            layer_norm_type: Key into ``WENET_NORM_CLASSES``; also forwarded
                to each layer.
            norm_eps: ``eps`` for the final norm; also forwarded to each layer.
            rms_norm_offset: Forwarded to each layer, and passed as
                ``add_unit_offset`` to the final norm when
                ``layer_norm_type`` is ``"rms_norm"``.
            selfattention_layer_type: Key into ``WENET_ATTENTION_CLASSES``;
                asserted to be ``'rope_abs_selfattn'``.
            use_sdpa: Forwarded to the attention class; also makes ``forward``
                convert the attention mask with ``mask_to_bias``.
            gradient_checkpointing: Use the checkpointed layer loop while the
                module is in training mode.
            rope_theta: Forwarded to the positional module.
            rope_style: Passed as ``style`` to the attention class.
            scale_embed: Passed as ``scale`` to the positional module.
        """
        super().__init__()

        assert selfattention_layer_type in ['rope_abs_selfattn']
        self.pos_enc = WENET_EMB_CLASSES["rope_pos"](
            hidden_size,
            head_dim,
            max_len=max_position_embeding,
            dropout_rate=positional_dropout_rate,
            rope_theta=rope_theta,
            scale=scale_embed,
        )

        # Only GELU receives the optional ``approximate`` argument; every
        # other activation is built without arguments.
        activation_kwargs = {}
        if activation_type == "gelu" and gelu_approximate is not None:
            activation_kwargs["approximate"] = gelu_approximate
        # One activation instance is shared by the MLP of every layer.
        activation = WENET_ACTIVATION_CLASSES[activation_type](
            **activation_kwargs)

        mlp_class = WENET_MLP_CLASSES[mlp_type]
        self.num_blocks = num_blocks

        # Per layer the attention module is constructed first, then the MLP,
        # then the enclosing layer; this order is kept fixed on purpose.
        layers = []
        for _ in range(self.num_blocks):
            self_attn = WENET_ATTENTION_CLASSES[selfattention_layer_type](
                attention_heads,
                hidden_size,
                attention_dropout_rate,
                query_bias,
                key_bias,
                value_bias,
                use_sdpa,
                n_kv_head,
                head_dim,
                style=rope_style,
            )
            feed_forward = mlp_class(
                hidden_size,
                linear_units,
                dropout_rate,
                activation,
                mlp_bias,
            )
            layers.append(
                TransformerEncoderLayer(
                    hidden_size,
                    self_attn,
                    feed_forward,
                    dropout_rate,
                    normalize_before,
                    layer_norm_type=layer_norm_type,
                    norm_eps=norm_eps,
                    rms_norm_offset=rms_norm_offset,
                ))
        self.decoders = torch.nn.ModuleList(layers)

        self.pre_norm = normalize_before
        # Declared as Optional up front; it is only replaced by a real module
        # in the pre-norm configuration.
        self.final_norm: Optional[torch.nn.Module] = None
        if self.pre_norm:
            final_norm_kwargs = {"eps": norm_eps}
            if layer_norm_type == "rms_norm":
                final_norm_kwargs["add_unit_offset"] = rms_norm_offset
            self.final_norm = WENET_NORM_CLASSES[layer_norm_type](
                hidden_size, **final_norm_kwargs)

        self.n_kv_head = n_kv_head
        self.head_dim = head_dim
        self._hidden_size = hidden_size
        self.use_sdpa = use_sdpa
        self.gradient_checkpointing = gradient_checkpointing

    def forward(
        self,
        input: torch.Tensor,
        att_mask: torch.Tensor,
        input_position: Union[int, torch.Tensor] = 0,
        kv_caches: Optional[List[T_CACHE]] = None,
    ) -> Tuple[torch.Tensor, Union[List[T_CACHE], None]]:
        """Run the positional module, all layers and the optional final norm.

        Args:
            input: Tensor handed to ``self.pos_enc``.
            att_mask: Attention mask given to every layer; converted with
                ``mask_to_bias`` first when ``use_sdpa`` is set.
            input_position: Passed to ``self.pos_enc`` as ``offset``.
            kv_caches: Per-layer caches consumed by ``forward_layers``.

        Returns:
            A pair ``(output, kv_caches)``. On the gradient-checkpointing
            path ``kv_caches`` is returned exactly as it was passed in;
            otherwise it is whatever ``forward_layers`` returned.
        """
        hidden, pos_emb = self.pos_enc(input, offset=input_position)
        if self.use_sdpa:
            att_mask = mask_to_bias(att_mask, hidden.dtype)

        # Checkpointing is a training-only path and does not touch the caches.
        checkpointing = self.gradient_checkpointing and self.training
        if checkpointing:
            hidden = self.forward_layers_checkpointed(hidden, att_mask,
                                                      pos_emb)
        else:
            hidden, kv_caches = self.forward_layers(hidden, att_mask, pos_emb,
                                                    kv_caches)

        if self.pre_norm and self.final_norm is not None:
            hidden = self.final_norm(hidden)
        return hidden, kv_caches

    def forward_layers(
        self,
        xs: torch.Tensor,
        att_mask: torch.Tensor,
        pos_emb: torch.Tensor,
        kv_caches: Optional[List[T_CACHE]] = None,
    ) -> Tuple[torch.Tensor, Union[List[T_CACHE], None]]:
        """Apply every layer in order, using the caches outside training.

        Args:
            xs: Input to the first layer.
            att_mask: Attention mask given to every layer.
            pos_emb: Positional embedding given to every layer.
            kv_caches: One cache pair per layer. Ignored in training mode;
                must not be ``None`` otherwise (checked with ``assert``).

        Returns:
            ``(xs, caches)``. In training mode ``caches`` is the ``kv_caches``
            argument unchanged; otherwise it is a new list holding the cache
            returned by each layer.
        """
        if self.training:
            # Training runs the layers without caches and passes the given
            # caches straight through.
            for layer in self.decoders:
                xs, _, _, _ = layer(xs, att_mask, pos_emb)
            return xs, kv_caches

        assert kv_caches is not None
        updated_caches = []
        for (idx, layer) in enumerate(self.decoders):
            layer_cache = kv_caches[idx]
            xs, _, layer_new_cache, _ = layer(
                xs,
                att_mask,
                pos_emb,
                att_cache=(layer_cache[0], layer_cache[1]),
            )
            updated_caches.append(layer_new_cache)
        return xs, updated_caches

    @torch.jit.ignore(drop=True)
    def forward_layers_checkpointed(self, xs: torch.Tensor,
                                    att_mask: torch.Tensor,
                                    pos_emb: torch.Tensor) -> torch.Tensor:
        """Apply every layer through ``torch.utils.checkpoint.checkpoint``.

        No caches are passed to the layers here, and only the first element
        of each layer's 4-tuple result is kept.

        Args:
            xs: Input to the first layer.
            att_mask: Attention mask given to every layer.
            pos_emb: Positional embedding given to every layer.

        Returns:
            The output of the last layer.
        """
        for layer in self.decoders:
            xs, _, _, _ = ckpt.checkpoint(layer.__call__, xs, att_mask,
                                          pos_emb)
        return xs

    @property
    def hidden_size(self):
        """Return the ``hidden_size`` the module was constructed with."""
        return self._hidden_size
|
|