import math
from typing import Optional

import torch
from flash_attn.modules.embedding import ParallelGPT2Embeddings
from flash_attn.modules.mlp import ParallelFusedMLP
from torch import nn

from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
from internlm.core.context.parallel_context import global_context as gpc
from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
from internlm.model.embedding import Embedding1D, Embedding1DLVM
from internlm.model.linear import (
    FeedForward,
    RewardModelLinear,
    ScaleColumnParallelLinear,
)
from internlm.model.multi_head_attention import MHA
from internlm.model.utils import (
    gather_forward_split_backward,
    try_import_LayerNorm,
    try_import_RMSNorm,
)
from internlm.solver.pipeline_utils import partition_uniform
from internlm.utils.checkpoint import activation_checkpoint
from internlm.utils.common import filter_kwargs
from internlm.utils.logger import get_logger
from internlm.utils.registry import MODEL_INITIALIZER

MODEL_TYPE = "ViT"

logger = get_logger(__file__)
RMSNorm = try_import_RMSNorm()
LayerNorm = try_import_LayerNorm()


def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
    """Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc. networks; however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate
    paper. See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
    I've opted for changing the layer and argument names to 'drop path' rather than mixing
    DropConnect as a layer name and using 'survival rate' as the argument.
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor
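
# Usage sketch (illustrative only, not executed at import): for a residual-branch
# activation `x` of shape (B, ...), drop_path(x, drop_prob=0.1, training=True) zeroes
# the whole branch for ~10% of the samples and rescales the survivors by 1 / 0.9, so
# the expected output stays equal to x.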


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks)."""

    def __init__(self, drop_prob=0.0, scale_by_keep=True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)


class PackedFlashBaseLayer1D(nn.Module):
    """
    1D Packed Flash Base Layer.

    Args:
        hidden_size (int): The hidden size of the model. 768 by default.
        num_attention_heads (int): The number of attention heads. 12 by default.
        mlp_ratio (int): The expansion ratio of the MLP hidden dimension. 4 by default.
        mlp_bias (bool): Whether the MLP linear layers use bias terms. False by default.
        attn_drop_rate (float): The dropout rate of the attention module. 0 by default.
        drop_path_rate (float): The drop path rate applied to the residual branches. 0.0 by default.
        dtype (torch.dtype): Type of data. torch.float by default.
        layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default.
        checkpoint (bool): Whether to use activation checkpointing to save VRAM. False by default.
        layer_idx (int): The index of the current layer. 0 by default.
        residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
        device (Optional[Union[str, torch.device]]): The device to be used. None by default.
        norm_type (str): Use RMSNorm or LayerNorm. "rmsnorm" by default.
        dropout_selective_checkpoint (bool): Whether to checkpoint only the dropout paths; it can only
            take effect when checkpoint is disabled. True by default.
        use_scaled_init (bool): Whether to use depth-scaled initialization. True by default.
        use_swiglu (bool): Whether the MLP uses SwiGLU. True by default.
        use_flash_attn (bool): Whether to use flash-attn. True by default.
    """

    def __init__(
        self,
        hidden_size: int = 768,
        num_attention_heads: int = 12,
        mlp_ratio: int = 4,
        mlp_bias: bool = False,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        dtype: torch.dtype = torch.float,
        layer_norm_epsilon: float = 1e-6,
        checkpoint: bool = False,
        layer_idx: int = 0,
        residual_in_fp32: bool = False,
        device: Optional[torch.device] = None,
        norm_type: str = "rmsnorm",
        dropout_selective_checkpoint: bool = True,
        use_scaled_init: bool = True,
        use_swiglu: bool = True,
        use_flash_attn: bool = True,
    ):
        super().__init__()
        self.checkpoint = checkpoint
        # Dropout-selective checkpointing only makes sense when full activation
        # checkpointing is off; otherwise the layer is already checkpointed as a whole.
        self.dropout_selective_checkpoint = dropout_selective_checkpoint and not checkpoint
        self.layer_idx = layer_idx
        self.use_flash_attn = use_flash_attn

        head_dim = hidden_size // num_attention_heads
        self.mixer = MHA(
            embed_dim=hidden_size,
            num_heads=num_attention_heads,
            process_group=gpc.get_group(ParallelMode.TENSOR),
            dropout=attn_drop_rate,
            softmax_scale=1 / math.sqrt(head_dim),
            causal=True,
            layer_idx=layer_idx,
            rotary_emb_dim=head_dim,
            rotary_emb_scale_base=0,
            use_flash_attn=use_flash_attn,
            device=device,
            dtype=dtype,
        )
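        # Note: softmax_scale matches the standard 1 / sqrt(head_dim) attention scaling,
        # and rotary_emb_dim=head_dim applies rotary position embedding to the full head
        # dimension.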

        self.dropout1 = DropPath(drop_path_rate)
        if norm_type == "rmsnorm":
            self.norm1 = RMSNorm(hidden_size, eps=layer_norm_epsilon)
            self.norm2 = RMSNorm(hidden_size, eps=layer_norm_epsilon)
        else:
            self.norm1 = LayerNorm(hidden_size, eps=layer_norm_epsilon)
            self.norm2 = LayerNorm(hidden_size, eps=layer_norm_epsilon)

        if use_swiglu:
            # SwiGLU feed-forward (w1/w2/w3 projections), matching the w1/w2/w3
            # initialization scheme in reset_parameters below.
            self.mlp = FeedForward(
                hidden_size,
                int(hidden_size * mlp_ratio),
                out_features=hidden_size,
                process_group=gpc.get_group(ParallelMode.TENSOR),
                bias=False,
                device=device,
                dtype=dtype,
            )
        else:
            self.mlp = ParallelFusedMLP(
                hidden_size,
                int(hidden_size * mlp_ratio),
                out_features=hidden_size,
                activation="gelu_approx",
                process_group=gpc.get_group(ParallelMode.TENSOR),
                bias1=mlp_bias,
                bias2=mlp_bias,
                sequence_parallel=gpc.config.parallel.sequence_parallel,
                checkpoint_lvl=0,
                heuristic="auto",
                device=device,
                dtype=dtype,
            )
        for _, param in self.mlp.named_parameters():
            if gpc.get_world_size(ParallelMode.TENSOR) > 1:
                setattr(param, IS_TENSOR_PARALLEL, True)
        self.dropout2 = DropPath(drop_path_rate)
        self.use_swiglu = use_swiglu
        self.use_scaled_init = use_scaled_init
        self.residual_in_fp32 = residual_in_fp32
        self.return_residual = False
        self.reset_parameters()

    def reset_parameters(self):
        with torch.no_grad():
            # Attention: zero the 1-D params (biases / norm weights), use a small normal
            # init for Wqkv, and a depth-scaled init (Megatron-style sigma / sqrt(2 * N),
            # with N = layer_idx + 1) for the output projection when use_scaled_init is set.
            for name, param in self.mixer.named_parameters():
                if param.ndim == 1:
                    param.data.zero_()
                elif "Wqkv" in name:
                    normal_(std=0.006)(param.data)
                elif self.use_scaled_init:
                    scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data)
                else:
                    normal_(std=0.0015)(param.data)

            for name, param in self.mlp.named_parameters():
                if param.ndim == 1 and "bias" in name:
                    param.data.zero_()
                elif self.use_swiglu:
                    if self.use_scaled_init and "w2" in name:
                        scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data)
                    else:
                        # w1/w3 are the input-side projections; everything else gets the smaller std.
                        normal_(std=0.006 if "w1" in name or "w3" in name else 0.0015)(param.data)
                else:
                    if self.use_scaled_init and "fc1" not in name:
                        scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data)
                    else:
                        normal_(std=0.006 if "fc1" in name else 0.0015)(param.data)

    def forward(self, hidden_states, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None):
        if self.checkpoint and self.training:
            return activation_checkpoint(
                self._forward, False, hidden_states, cu_seqlens, indexes, inference_params, max_seqlen
            )
        else:
            return self._forward(hidden_states, cu_seqlens, indexes, inference_params, max_seqlen)

    def _forward(self, hidden_states=None, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required); used as the residual,
                i.e. hidden_states = residual + Attn/MLP(LN(residual)).
            cu_seqlens: a 1D LongTensor of cumulative sequence lengths; its length is the
                number of packed sequences + 1.
            indexes: the same length as hidden_states; each entry is the token's position
                within its own sequence.
        """
        mixer_kwargs = {
            "cu_seqlens": cu_seqlens,
            "max_seqlen": max_seqlen,
            "indexes": indexes,
            "inference_params": inference_params,
        }
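
        # Pre-norm transformer block: each sublayer computes x + DropPath(Sublayer(Norm(x))).
        # The norms are applied in fp32 (note the .float() casts below) for numerical stability.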
        residual = hidden_states

        hidden_states = self.norm1(residual.float())
        hidden_states = self.mixer(hidden_states, **mixer_kwargs)
        hidden_states = self.dropout1(hidden_states)

        residual = residual + hidden_states

        hidden_states = self.norm2(residual.float())
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.dropout2(hidden_states)

        return hidden_states + residual


class PackedFlashInternLm1D(nn.Module):
    """
    1D Packed Flash InternLM.

    Args:
        num_layers (int): The number of layers. 12 by default.
        hidden_size (int): The size of the hidden state. 768 by default.
        num_attention_heads (int): The number of attention heads. 12 by default.
        vocab_size (int): The size of the vocabulary. 50304 by default.
        mlp_ratio (float): The expansion ratio of the MLP hidden dimension. 4.0 by default.
        mlp_bias (bool): Whether the MLP linear layers use bias terms. False by default.
        attn_drop_rate (float): The dropout rate of the attention module. 0.0 by default.
        drop_path_rate (float): The drop path rate of the input hidden state. 0.0 by default.
        dtype (torch.dtype): The type of data. torch.float by default.
        checkpoint (float): The proportion of layers that need to be checkpointed compared to the
            total number of layers. 0.0 by default.
        layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default.
        first (bool): Whether this stage holds the input embedding layer. False by default.
        last (bool): Whether this stage holds the output head. False by default.
        embed_split_hidden (bool): Split the embedding layer along the hidden-state dimension or the
            vocabulary dimension. False by default.
        embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default.
        parallel_output (bool): Whether to gather the output of parallel computing. True by default.
        start_layer_idx (int): The index of the start layer in the pipeline. 0 by default.
        device (Optional[Union[str, torch.device]]): The device to be used. None by default.
        residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
        norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
        is_reward (bool): Whether to build a reward-model head. False by default.
        dropout_selective_checkpoint (bool): It can only be enabled when checkpoint is disabled.
            True by default.
        use_scaled_init (bool): Whether to use scaled init. True by default.
        use_swiglu (bool): Whether to use SwiGLU. True by default.
        use_flash_attn (bool): Whether to use flash-attn. True by default.
        lvm_config (dict): Optional LVM (vision) embedding configuration; when it enables the LVM
            path, Embedding1DLVM is used instead of the token embedding. None by default.
    """

    def __init__(
        self,
        num_layers: int = 12,
        hidden_size: int = 768,
        num_attention_heads: int = 12,
        vocab_size: int = 50304,
        mlp_ratio: float = 4.0,
        mlp_bias: bool = False,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        dtype: torch.dtype = torch.float,
        checkpoint: float = 0.0,
        layer_norm_epsilon: float = 1e-5,
        first: bool = False,
        last: bool = False,
        embed_split_hidden: bool = False,
        embed_grad_scale: float = 0.1,
        parallel_output: bool = True,
        start_layer_idx: int = 0,
        device: Optional[torch.device] = None,
        residual_in_fp32: bool = False,
        norm_type: str = "rmsnorm",
        is_reward: bool = False,
        dropout_selective_checkpoint: bool = True,
        use_scaled_init: bool = True,
        use_swiglu: bool = True,
        use_flash_attn: bool = True,
        lvm_config: dict = None,
    ):
        super().__init__()
        self.lvm_config = lvm_config

        # `checkpoint` is a proportion: the first `num_layers * checkpoint` layers are checkpointed.
        checkpoint_layer_num = int(num_layers * checkpoint)

        head_cls = RewardModelLinear if is_reward else ScaleColumnParallelLinear
        if first:
            if self.lvm_config is not None and self.lvm_config.get("enable", False):
                self.embedding = Embedding1DLVM(**self.lvm_config.get("embedding_cfg"))
                if self.embedding.embed_proj is not None:
                    for _, param in self.embedding.embed_proj.named_parameters():
                        normal_(std=0.0052)(param)
                        if gpc.get_world_size(ParallelMode.TENSOR) > 1:
                            setattr(param, IS_TENSOR_PARALLEL, True)
            else:
                if embed_split_hidden:
                    self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size)
                else:
                    self.embedding = ParallelGPT2Embeddings(
                        embed_dim=hidden_size,
                        vocab_size=vocab_size,
                        max_position_embeddings=-1,
                        process_group=gpc.get_group(ParallelMode.TENSOR),
                        padding_idx=None,
                        sequence_parallel=gpc.config.parallel.sequence_parallel,
                        device=device,
                        dtype=dtype,
                    )
                for _, param in self.embedding.named_parameters():
                    normal_(std=0.0052)(param)
                    if gpc.get_world_size(ParallelMode.TENSOR) > 1:
                        setattr(param, IS_TENSOR_PARALLEL, True)
        self.embed_grad_scale = embed_grad_scale
        self.blocks = nn.ModuleList(
            [
                PackedFlashBaseLayer1D(
                    hidden_size=hidden_size,
                    num_attention_heads=num_attention_heads,
                    mlp_ratio=mlp_ratio,
                    mlp_bias=mlp_bias,
                    attn_drop_rate=attn_drop_rate,
                    drop_path_rate=drop_path_rate,
                    dtype=dtype,
                    layer_norm_epsilon=layer_norm_epsilon,
                    checkpoint=lid < checkpoint_layer_num,
                    layer_idx=lid + start_layer_idx,  # global layer index across pipeline stages
                    residual_in_fp32=residual_in_fp32,
                    device=device,
                    norm_type=norm_type,
                    dropout_selective_checkpoint=dropout_selective_checkpoint,
                    use_scaled_init=use_scaled_init,
                    use_swiglu=use_swiglu,
                    use_flash_attn=use_flash_attn,
                )
                for lid in range(num_layers)
            ]
        )
        if last:
            if norm_type == "rmsnorm":
                self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon)
            else:
                self.norm = LayerNorm(hidden_size, eps=layer_norm_epsilon)
            self.head = head_cls(
                in_features=hidden_size,
                out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,
                process_group=gpc.get_group(ParallelMode.TENSOR),
                bias=False,
                device=device,
                dtype=dtype,
                weight_scale=embed_grad_scale,
            )
            for _, param in self.head.named_parameters():
                normal_(std=0.0052)(param)
                if gpc.get_world_size(ParallelMode.TENSOR) > 1:
                    setattr(param, IS_TENSOR_PARALLEL, True)
        self.parallel_output = parallel_output

    def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
        if hasattr(self, "embedding"):
            hidden_states = self.embedding(input_ids)
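            # GLM-130B-style embedding gradient shrink: the forward value is unchanged
            # (s * x + (1 - s) * x.detach() == x), but the gradient flowing into the
            # embedding is multiplied by s = embed_grad_scale for training stability.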
            if self.embed_grad_scale != 1:
                hidden_states = (
                    self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach()
                )
        if isinstance(cu_seqlens, list):
            assert len(cu_seqlens) == 1
            cu_seqlens = cu_seqlens[0].to(hidden_states.device)

        if cu_seqlens is not None:
            cu_seqlens = cu_seqlens.squeeze(0)
            hidden_states = hidden_states.squeeze(0)

        if indexes is not None:
            assert len(indexes) == 1
            # The batch dimension of indexes is always 1; take the only entry.
            indexes = indexes[0]
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None
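
        # Packed-sequence bookkeeping (illustrative): if two sequences of lengths 3 and 5
        # are packed together, cu_seqlens == tensor([0, 3, 8]) and max_seqlen == 5, and
        # `indexes` holds each token's position within its own sequence.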
        for block in self.blocks:
            hidden_states = block(
                hidden_states,
                cu_seqlens=cu_seqlens,
                indexes=indexes,
                inference_params=inference_params,
                max_seqlen=max_seqlen,
            )

        if hasattr(self, "norm"):
            hidden_states = self.norm(hidden_states.float())
        if hasattr(self, "head"):
            hidden_states = self.head(hidden_states)

        if not self.parallel_output:
            hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
        return hidden_states


def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"), **kwargs):
    """
    Build a generic 1D model.

    Args:
        num_layers (int): The number of layers.
        num_chunks (int): The number of partitions in pipeline parallel.
        device (Optional[Union[str, torch.device]]): The device to be used. torch.device("cuda") by default.
    """
    pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
    pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)

    all_parts = partition_uniform(num_layers, pipeline_size, num_chunks)
    parts = all_parts[pipeline_rank]
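    # Illustrative example (assuming a uniform split): with num_layers=48, pipeline_size=4
    # and num_chunks=1, partition_uniform returns one (start, end) span per chunk per rank,
    # e.g. all_parts == [[(0, 12)], [(12, 24)], [(24, 36)], [(36, 48)]].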
    if gpc.is_rank_for_log():
        logger.info(f"The layer sharding is {all_parts}.")

    models = []

    for start, end in parts:
        kwargs["num_layers"] = end - start
        kwargs["first"] = start == 0
        # If there is no content in the final layer, assign the last layer.
        kwargs["last"] = end == num_layers and len(all_parts[-1]) != 0
        kwargs["device"] = device
        kwargs["start_layer_idx"] = start
        chunk = PackedFlashInternLm1D(**filter_kwargs(PackedFlashInternLm1D.__init__, kwargs)).to(device)

        models.append(chunk)
    torch.distributed.barrier()
    if len(models) == 1:
        model = models[0]
    else:
        model = nn.ModuleList(models)

    return model


@MODEL_INITIALIZER.register_module(module_name=MODEL_TYPE)
def build_vit_model_with_cfg(
    num_chunks=1,
    checkpoint=0.0,
    dtype=torch.float,
    embed_split_hidden=False,
    num_layers=48,
    hidden_size=2048,
    vocab_size=50304,
    embed_grad_scale=1,
    parallel_output=True,
    num_attention_heads=32,
    mlp_ratio=4.0,
    mlp_bias: bool = False,
    residual_in_fp32=False,
    norm_type="rmsnorm",
    drop_path_rate=0,
    attn_drop_rate=0,
    apply_post_layer_norm=False,
    layer_norm_epsilon=1e-5,
    is_reward=False,
    dropout_selective_checkpoint=True,
    use_scaled_init: bool = True,
    use_swiglu: bool = True,
    use_flash_attn: bool = True,
    lvm_config=None,
):
    """
    Build model with config.

    Args:
        num_chunks (int): The number of partitions in pipeline parallel. 1 by default.
        checkpoint (float): The proportion of layers to activation-checkpoint. 0.0 by default.
        dtype (torch.dtype): The type of data. torch.float by default.
        embed_split_hidden (bool): Split the embedding layer along the hidden-state dimension or the
            vocabulary dimension. False by default.
        num_layers (int): The number of layers. 48 by default.
        hidden_size (int): The size of the hidden state. 2048 by default.
        vocab_size (int): The size of the vocabulary. 50304 by default.
        embed_grad_scale (float): Refer to GLM-130B, for training stability. 1 by default.
        parallel_output (bool): Whether to gather the output of parallel computing. True by default.
        num_attention_heads (int): The number of attention heads. 32 by default.
        mlp_ratio (float): The expansion ratio of the MLP hidden dimension. 4.0 by default.
        mlp_bias (bool): Whether the MLP linear layers use bias terms. False by default.
        residual_in_fp32 (bool): Whether to use residual in fp32. False by default. It cannot be used
            temporarily because this parameter requires inconsistent data types to be passed between
            pipeline stages, which requires significant modifications to internlm.
        norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
        drop_path_rate (float): The drop path rate of the input hidden state. 0 by default.
        attn_drop_rate (float): The dropout rate of the attention module. 0 by default.
        apply_post_layer_norm (bool): Whether to apply post layer norm. False by default.
        layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default.
        is_reward (bool): Whether to use a reward model. False by default.
        dropout_selective_checkpoint (bool): It can only be enabled when checkpoint is disabled.
            True by default.
        use_scaled_init (bool): Whether to use scaled init. True by default.
        use_swiglu (bool): Whether to use SwiGLU. True by default.
        use_flash_attn (bool): Whether to use flash-attn. True by default.
        lvm_config (dict): Optional LVM embedding configuration forwarded to PackedFlashInternLm1D.
            None by default.
    """
    cfg = dict(
        hidden_size=hidden_size,
        num_attention_heads=num_attention_heads,
        checkpoint=checkpoint,
        dtype=dtype,
        embed_split_hidden=embed_split_hidden,
        vocab_size=vocab_size,
        embed_grad_scale=embed_grad_scale,
        parallel_output=parallel_output,
        mlp_ratio=mlp_ratio,
        mlp_bias=mlp_bias,
        residual_in_fp32=residual_in_fp32,
        norm_type=norm_type,
        drop_path_rate=drop_path_rate,
        attn_drop_rate=attn_drop_rate,
        layer_norm_epsilon=layer_norm_epsilon,
        is_reward=is_reward,
        dropout_selective_checkpoint=dropout_selective_checkpoint,
        use_scaled_init=use_scaled_init,
        use_swiglu=use_swiglu,
        use_flash_attn=use_flash_attn,
        lvm_config=lvm_config,
    )

    return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)
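

# Usage sketch (illustrative; assumes an initialized gpc parallel context, and that the
# registry exposes a get_module lookup as in InternLM):
#     builder = MODEL_INITIALIZER.get_module(module_name=MODEL_TYPE)
#     model = builder(num_layers=48, hidden_size=2048, dtype=torch.bfloat16)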