# Source: WeNet, wenet/LLM/decoder.py - "First commit" (3c50954) by inoryQwQ, 6.14 kB.
from functools import partial
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint as ckpt
from wenet.transformer.attention import T_CACHE
from wenet.transformer.encoder_layer import TransformerEncoderLayer
from wenet.utils.class_utils import (WENET_ACTIVATION_CLASSES,
WENET_ATTENTION_CLASSES,
WENET_EMB_CLASSES, WENET_MLP_CLASSES,
WENET_NORM_CLASSES)
from wenet.utils.common import mask_to_bias
class DecoderOnly(torch.nn.Module):
    """Decoder-only Transformer stack.

    Input goes through the ``rope_pos`` embedding module, then through
    ``num_blocks`` ``TransformerEncoderLayer`` blocks (each built with the
    ``rope_abs_selfattn`` attention class and the ``mlp_type`` MLP class),
    and finally through a norm layer when ``normalize_before`` is True.

    Args:
        n_kv_head: Number of key/value heads, passed to the attention class.
        head_dim: Per-head dimension, passed to the embedding and attention
            classes.
        hidden_size: Model width.
        attention_heads: Number of attention (query) heads.
        linear_units: Hidden width passed to the MLP class.
        num_blocks: Number of layers.
        dropout_rate: Dropout passed to the MLP class and to each layer.
        positional_dropout_rate: Dropout passed to the positional embedding.
        attention_dropout_rate: Dropout passed to the attention class.
        normalize_before: Passed to each layer; also enables ``final_norm``.
        query_bias: Bias flag passed to the attention class.
        key_bias: Bias flag passed to the attention class.
        value_bias: Bias flag passed to the attention class.
        mlp_bias: Bias flag passed to the MLP class.
        activation_type: Key into ``WENET_ACTIVATION_CLASSES``.
        gelu_approximate: Forwarded as ``approximate`` only when
            ``activation_type == "gelu"`` and this is not None.
        max_position_embeding: ``max_len`` of the positional embedding
            (the misspelled name is kept for backward compatibility).
        mlp_type: Key into ``WENET_MLP_CLASSES``.
        layer_norm_type: Key into ``WENET_NORM_CLASSES``.
        norm_eps: Epsilon for the norm layers.
        rms_norm_offset: Forwarded to each layer, and as ``add_unit_offset``
            to the final norm when ``layer_norm_type == "rms_norm"``.
        selfattention_layer_type: Only ``'rope_abs_selfattn'`` is supported.
        use_sdpa: Passed to the attention class; in ``forward`` the mask is
            also converted with ``mask_to_bias`` when set.
        gradient_checkpointing: In training mode, recompute each layer's
            activations via ``torch.utils.checkpoint``.
        rope_theta: Forwarded to the positional embedding.
        rope_style: Forwarded to the attention class as ``style``.
        scale_embed: Forwarded to the positional embedding as ``scale``.

    Raises:
        ValueError: If ``selfattention_layer_type`` is not supported.
    """

    def __init__(
        self,
        n_kv_head: int,
        head_dim: int,
        hidden_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        normalize_before: bool = True,
        query_bias: bool = False,
        key_bias: bool = False,
        value_bias: bool = False,
        mlp_bias: bool = False,
        activation_type: str = "gelu",
        gelu_approximate: Union[str, None] = None,
        max_position_embeding: int = 8192,
        mlp_type: str = 'gated',
        layer_norm_type: str = 'rms_norm',
        norm_eps: float = 1e-5,
        rms_norm_offset: bool = True,
        selfattention_layer_type: str = "rope_abs_selfattn",
        use_sdpa: bool = False,
        gradient_checkpointing: bool = False,
        rope_theta: float = 10000.0,
        rope_style: str = 'google',
        scale_embed: bool = True,
    ) -> None:
        super().__init__()
        # Raise instead of assert: asserts vanish under `python -O`, which
        # would let an unsupported attention class be built with these args.
        if selfattention_layer_type not in ['rope_abs_selfattn']:
            raise ValueError(
                "DecoderOnly only supports selfattention_layer_type="
                "'rope_abs_selfattn', got {!r}".format(
                    selfattention_layer_type))
        self.pos_enc = WENET_EMB_CLASSES["rope_pos"](
            hidden_size,
            head_dim,
            max_len=max_position_embeding,
            dropout_rate=positional_dropout_rate,
            rope_theta=rope_theta,
            scale=scale_embed)
        # Only GELU takes an `approximate` argument; every other activation
        # is built with no arguments.
        if activation_type == "gelu" and gelu_approximate is not None:
            activation = WENET_ACTIVATION_CLASSES['gelu'](
                approximate=gelu_approximate)
        else:
            activation = WENET_ACTIVATION_CLASSES[activation_type]()
        mlp_class = WENET_MLP_CLASSES[mlp_type]
        self.num_blocks = num_blocks
        # TODO: support lora & refactor lora
        self.decoders = torch.nn.ModuleList([
            TransformerEncoderLayer(
                hidden_size,
                WENET_ATTENTION_CLASSES[selfattention_layer_type](
                    attention_heads,
                    hidden_size,
                    attention_dropout_rate,
                    query_bias,
                    key_bias,
                    value_bias,
                    use_sdpa,
                    n_kv_head,
                    head_dim,
                    style=rope_style),
                mlp_class(hidden_size, linear_units, dropout_rate, activation,
                          mlp_bias),
                dropout_rate,
                normalize_before,
                layer_norm_type=layer_norm_type,
                norm_eps=norm_eps,
                rms_norm_offset=rms_norm_offset,
            ) for _ in range(self.num_blocks)
        ])

        # The final norm exists only in the pre-norm configuration.
        self.pre_norm = normalize_before
        self.final_norm: Optional[torch.nn.Module] = None
        if self.pre_norm:
            norm_class = WENET_NORM_CLASSES[layer_norm_type]
            if layer_norm_type == "rms_norm":
                norm_class = partial(
                    norm_class,
                    add_unit_offset=rms_norm_offset,
                )
            self.final_norm = norm_class(hidden_size, eps=norm_eps)

        self.n_kv_head = n_kv_head
        self.head_dim = head_dim
        self._hidden_size = hidden_size
        self.use_sdpa = use_sdpa
        self.gradient_checkpointing = gradient_checkpointing

    def forward(
        self,
        input: torch.Tensor,
        att_mask: torch.Tensor,
        input_position: Union[int, torch.Tensor] = 0,
        kv_caches: Optional[List[T_CACHE]] = None,
    ) -> Tuple[torch.Tensor, Union[List[T_CACHE], None]]:
        """Run the positional embedding, all layers and the final norm.

        Args:
            input: Tensor fed to ``pos_enc``; presumably
                (batch, time, hidden_size) embeddings -- TODO confirm.
            att_mask: Attention mask. When ``use_sdpa`` is set, it is first
                converted with ``mask_to_bias`` to ``xs.dtype``.
            input_position: Offset forwarded to ``pos_enc``.
            kv_caches: One cache per layer. Required in eval mode; ignored
                and returned as-is in training mode.

        Returns:
            ``(xs, caches)``. ``caches`` holds the updated per-layer caches
            in eval mode, and is the ``kv_caches`` argument unchanged in
            training mode.

        Raises:
            ValueError: In eval mode when ``kv_caches`` is None.
        """
        xs, pos_emb = self.pos_enc(input, offset=input_position)
        if self.use_sdpa:
            att_mask = mask_to_bias(att_mask, xs.dtype)

        if self.gradient_checkpointing and self.training:
            xs = self.forward_layers_checkpointed(xs, att_mask, pos_emb)
        else:
            xs, kv_caches = self.forward_layers(xs, att_mask, pos_emb,
                                                kv_caches)
        # The explicit None check also narrows the Optional type for
        # TorchScript.
        if self.pre_norm and self.final_norm is not None:
            xs = self.final_norm(xs)
        return xs, kv_caches

    def forward_layers(
        self,
        xs: torch.Tensor,
        att_mask: torch.Tensor,
        pos_emb: torch.Tensor,
        kv_caches: Optional[List[T_CACHE]] = None,
    ) -> Tuple[torch.Tensor, Union[List[T_CACHE], None]]:
        """Apply every layer in order, threading caches through in eval mode.

        In training mode the layers run without caches and ``kv_caches`` is
        returned untouched. In eval mode ``kv_caches[i]`` is passed to layer
        ``i`` as ``att_cache``, and each layer's third output is collected as
        its updated cache.

        Raises:
            ValueError: In eval mode when ``kv_caches`` is None.
        """
        if self.training:
            for layer in self.decoders:
                xs, _, _, _ = layer(xs, att_mask, pos_emb)
            new_kv_caches = kv_caches
        else:
            # The default kv_caches=None is only valid in training mode. Fail
            # with an explicit message instead of an assert, which
            # `python -O` would strip.
            if kv_caches is None:
                raise ValueError(
                    "DecoderOnly.forward_layers: kv_caches must be provided "
                    "in eval mode (one cache per layer)")
            new_kv_caches = []
            for (i, layer) in enumerate(self.decoders):
                xs, _, new_kv_cache, _ = layer(xs,
                                               att_mask,
                                               pos_emb,
                                               att_cache=(kv_caches[i][0],
                                                          kv_caches[i][1]))
                new_kv_caches.append(new_kv_cache)
        return xs, new_kv_caches

    @torch.jit.ignore(drop=True)
    def forward_layers_checkpointed(self, xs: torch.Tensor,
                                    att_mask: torch.Tensor,
                                    pos_emb: torch.Tensor) -> torch.Tensor:
        """Apply every layer under activation checkpointing (no caches).

        Dropped from TorchScript via ``torch.jit.ignore(drop=True)``.
        """
        for layer in self.decoders:
            # Pass use_reentrant explicitly: PyTorch warns when it is omitted.
            # True is the value the omitted argument used to resolve to, so
            # checkpointing behaves exactly as before.
            xs, _, _, _ = ckpt.checkpoint(layer.__call__,
                                          xs,
                                          att_mask,
                                          pos_emb,
                                          use_reentrant=True)
        return xs

    @property
    def hidden_size(self):
        """Model width given as ``hidden_size`` at construction."""
        return self._hidden_size