# Source: DeLVM/InternLM/internlm/model/modeling_vit.py
# Uploaded by jirong via huggingface_hub (commit ee3e701, verified).
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
from typing import Optional
import torch
from flash_attn.modules.embedding import ParallelGPT2Embeddings
from flash_attn.modules.mlp import ParallelFusedMLP
from torch import nn
from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
from internlm.core.context.parallel_context import global_context as gpc
from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
from internlm.model.embedding import Embedding1D, Embedding1DLVM
from internlm.model.linear import (
FeedForward,
RewardModelLinear,
ScaleColumnParallelLinear,
)
from internlm.model.multi_head_attention import MHA
from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm, try_import_LayerNorm
from internlm.solver.pipeline_utils import partition_uniform
from internlm.utils.checkpoint import activation_checkpoint
from internlm.utils.common import filter_kwargs
from internlm.utils.logger import get_logger
from internlm.utils.registry import MODEL_INITIALIZER
# Name under which ``build_vit_model_with_cfg`` is registered in MODEL_INITIALIZER.
MODEL_TYPE = "ViT"
logger = get_logger(__file__)
# Norm classes resolved once at import time through the project's try_import helpers.
# NOTE(review): presumably these pick a fused kernel when available and fall back otherwise — confirm.
RMSNorm = try_import_RMSNorm()
LayerNorm = try_import_LayerNorm()
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.

    Each sample along the first dimension of ``x`` is independently zeroed with probability
    ``drop_prob``. Surviving samples are optionally rescaled by ``1 / (1 - drop_prob)``.

    Args:
        x (torch.Tensor): Input tensor; its first dimension is treated as the sample dimension.
        drop_prob (float | None): Probability of dropping a sample. ``None`` is treated as ``0.0``.
        training (bool): Samples are only dropped when True; otherwise ``x`` is returned unchanged.
        scale_by_keep (bool): Whether kept samples are divided by the keep probability.

    Returns:
        torch.Tensor: ``x`` itself when nothing is dropped, otherwise ``x`` multiplied by a per-sample mask.
    """
    # ``not drop_prob`` also covers ``None`` (the ``DropPath`` default). Previously ``None`` fell
    # through to ``1 - None`` and raised ``TypeError`` in training mode.
    if not drop_prob or not training:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dimensions
    # (works for tensors of any rank, not just 2D ConvNets).
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Module wrapper around ``drop_path``; dropping is active only while the module is in training mode.

    Args:
        drop_prob (float | None): Probability of dropping a sample. ``None`` (the default) is
            normalized to ``0.0``, which makes the module an identity.
        scale_by_keep (bool): Whether kept samples are rescaled by ``1 / (1 - drop_prob)``.
    """

    def __init__(self, drop_prob=None, scale_by_keep=True):
        super(DropPath, self).__init__()
        # Normalize ``None`` so the forward pass never evaluates ``1 - None`` in training mode.
        self.drop_prob = 0.0 if drop_prob is None else drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        return f"drop_prob={self.drop_prob}, scale_by_keep={self.scale_by_keep}"
class PackedFlashBaseLayer1D(nn.Module):
    """
    1D Packed Flash Base Layer.

    A pre-norm transformer block: ``norm1 -> attention -> drop path -> residual`` followed by
    ``norm2 -> MLP -> drop path -> residual``.

    Args:
        hidden_size (int): The hidden size of model. 768 by default.
        num_attention_heads (int): The number of attention heads; must divide ``hidden_size``. 12 by default.
        mlp_ratio (float): Ratio of the MLP intermediate size to ``hidden_size``. 4 by default.
        mlp_bias (bool): Whether both MLP linear layers use a bias. False by default.
        attn_drop_rate (float): The dropout rate of attention module. 0 by default.
        drop_path_rate (float): Drop path rate applied to both the attention and the MLP branch. 0.0 by default.
        dtype (torch.dtype): Type of data. torch.float by default.
        layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default.
        checkpoint (bool): Whether to use activation checkpointing (only while training). False by default.
        layer_idx (int): The index of current layer. 0 by default.
        residual_in_fp32 (bool): Stored on the module; not read by this class. False by default.
        device (Optional[Union[str, torch.device]]): The device will be used.
        norm_type (str): "rmsnorm" selects RMSNorm; any other value selects LayerNorm. "rmsnorm" by default.
        dropout_selective_checkpoint (bool): Stored (forced to False when ``checkpoint`` is True); not read
            by this class. True by default.
        use_scaled_init (bool): Whether ``reset_parameters`` uses ``scaled_init_method_normal`` for the
            non-QKV attention weights and for some MLP weights. True by default.
        use_swiglu (bool): Only selects the MLP initialization scheme in ``reset_parameters``; the MLP
            itself is always a ``ParallelFusedMLP`` with "gelu_approx" activation. True by default.
        use_flash_attn (bool): Whether use flash-attn. True by default.

    Raises:
        ValueError: If ``hidden_size`` is not divisible by ``num_attention_heads``.
    """

    def __init__(
        self,
        hidden_size: int = 768,
        num_attention_heads: int = 12,
        mlp_ratio: float = 4,
        mlp_bias: bool = False,
        attn_drop_rate: float = 0,
        drop_path_rate: float = 0.0,
        dtype: torch.dtype = torch.float,
        layer_norm_epsilon: float = 1e-6,
        checkpoint: bool = False,
        layer_idx: int = 0,
        residual_in_fp32: bool = False,
        device: Optional[torch.device] = None,
        norm_type: str = "rmsnorm",
        dropout_selective_checkpoint: bool = True,
        use_scaled_init: bool = True,
        use_swiglu: bool = True,
        use_flash_attn: bool = True,
    ):
        super().__init__()
        # Fail early with a clear message instead of a shape error inside the attention forward pass.
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by num_attention_heads ({num_attention_heads})"
            )
        self.checkpoint = checkpoint
        # dropout selective checkpoint can only be enabled when checkpoint is disabled.
        self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False
        self.layer_idx = layer_idx
        self.use_flash_attn = use_flash_attn

        head_dim = hidden_size // num_attention_heads
        tp_group = gpc.get_group(ParallelMode.TENSOR)

        self.mixer = MHA(
            embed_dim=hidden_size,
            num_heads=num_attention_heads,
            process_group=tp_group,
            dropout=attn_drop_rate,
            softmax_scale=1 / math.sqrt(head_dim),
            causal=True,
            layer_idx=layer_idx,
            rotary_emb_dim=head_dim,
            rotary_emb_scale_base=0,
            use_flash_attn=use_flash_attn,
            device=device,
            dtype=dtype,
        )
        self.dropout1 = DropPath(drop_path_rate)

        norm_cls = RMSNorm if norm_type == "rmsnorm" else LayerNorm
        self.norm1 = norm_cls(hidden_size, eps=layer_norm_epsilon)
        self.norm2 = norm_cls(hidden_size, eps=layer_norm_epsilon)

        self.mlp = ParallelFusedMLP(
            hidden_size,
            int(hidden_size * mlp_ratio),
            out_features=hidden_size,
            activation="gelu_approx",
            process_group=tp_group,
            bias1=mlp_bias,
            bias2=mlp_bias,
            sequence_parallel=gpc.config.parallel.sequence_parallel,
            checkpoint_lvl=0,
            heuristic="auto",
            device=device,
            dtype=dtype,
        )
        # The world size is loop-invariant: check it once rather than once per parameter.
        if gpc.get_world_size(ParallelMode.TENSOR) > 1:
            for param in self.mlp.parameters():
                setattr(param, IS_TENSOR_PARALLEL, True)
        self.dropout2 = DropPath(drop_path_rate)

        self.use_swiglu = use_swiglu
        self.use_scaled_init = use_scaled_init
        self.residual_in_fp32 = residual_in_fp32  # only make sense when using prenorm
        self.return_residual = False
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the attention and MLP parameters in place (under ``torch.no_grad``).

        Attention: 1-D params are zeroed, params whose name contains "Wqkv" use N(0, 0.006), and the
        rest use the scaled init (``use_scaled_init``) or N(0, 0.0015).
        MLP: 1-D params named "bias" are zeroed; other params follow the "w1"/"w2" naming scheme when
        ``use_swiglu`` is set, otherwise the "fc1" scheme.
        """
        with torch.no_grad():
            for name, param in self.mixer.named_parameters():
                if param.ndim == 1:
                    param.data.zero_()
                elif "Wqkv" in name:
                    normal_(std=0.006)(param.data)
                elif self.use_scaled_init:
                    scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data)
                else:
                    normal_(std=0.0015)(param.data)
            for name, param in self.mlp.named_parameters():
                if param.ndim == 1 and "bias" in name:
                    param.data.zero_()
                elif self.use_swiglu:
                    # NOTE(review): ``self.mlp`` is a ParallelFusedMLP. If its parameters are named
                    # "fc1"/"fc2" rather than "w1"/"w2"/"w3", every weight here falls through to
                    # N(0, 0.0015) and no scaled init is applied. Kept unchanged to preserve existing
                    # initialization — confirm whether use_swiglu should be False for this MLP.
                    if self.use_scaled_init and "w2" in name:
                        scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data)
                    else:
                        normal_(std=0.006 if "w1" in name or "w2" in name else 0.0015)(param.data)
                else:
                    if self.use_scaled_init and "fc1" not in name:
                        scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data)
                    else:
                        normal_(std=0.006 if "fc1" in name else 0.0015)(param.data)

    def forward(self, hidden_states, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None):
        """Run the block, through ``activation_checkpoint`` when checkpointing is on and training."""
        if self.checkpoint and self.training:
            return activation_checkpoint(
                self._forward, False, hidden_states, cu_seqlens, indexes, inference_params, max_seqlen
            )
        else:
            return self._forward(hidden_states, cu_seqlens, indexes, inference_params, max_seqlen)

    def _forward(self, hidden_states=None, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None):
        r"""Pass the input through the block.

        Args:
            hidden_states: Input to the block; both norms are applied to float32 copies of the residual stream.
            cu_seqlens: Cumulative sequence lengths of the packed sequences, forwarded to the attention mixer.
            indexes: Position ids of each token (same length as the hidden states), forwarded to the mixer.
            inference_params: Forwarded to the attention mixer.
            max_seqlen: Length of the longest packed sequence, forwarded to the attention mixer.

        Returns:
            ``residual + dropout2(mlp(norm2(residual)))`` where
            ``residual = hidden_states + dropout1(mixer(norm1(hidden_states)))``.
        """
        mixer_kwargs = {
            "cu_seqlens": cu_seqlens,
            "max_seqlen": max_seqlen,
            "indexes": indexes,
            "inference_params": inference_params,
        }
        residual = hidden_states
        hidden_states = self.norm1(residual.float())
        hidden_states = self.mixer(hidden_states, **mixer_kwargs)
        hidden_states = self.dropout1(hidden_states)
        residual = residual + hidden_states
        hidden_states = self.norm2(residual.float())
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.dropout2(hidden_states)
        return hidden_states + residual
class PackedFlashInternLm1D(nn.Module):
    """
    1D Packed Flash InternLm.

    One pipeline chunk of the model. It holds an optional input embedding (``first``), a stack of
    ``PackedFlashBaseLayer1D`` blocks, and an optional final norm plus output head (``last``).

    Args:
        num_layers (int): The number of layers in this chunk. 12 by default.
        hidden_size (int): The size of hidden state. 768 by default.
        num_attention_heads (int): The number of attention head. 12 by default.
        vocab_size (int): The size of vocabulary. 50304 by default.
        mlp_ratio (float): The ratio of MLP layers. 4.0 by default.
        mlp_bias (bool): Whether MLP layers use a bias. False by default.
        attn_drop_rate (float): The dropout rate of attention module. 0.0 by default.
        drop_path_rate (float): The drop path rate of input hidden state. 0.0 by default.
        dtype (torch.dtype): The type of data. torch.float by default.
        checkpoint (float): The proportion of layers that need to be checkpointed compared to the total number
            of layers. 0.0 by default.
        layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default.
        first (bool): Whether this chunk owns the input embedding. False by default.
        last (bool): Whether this chunk owns the final norm and output head. False by default.
        embed_split_hidden (bool): For the non-LVM embedding, use ``Embedding1D`` (True) or
            ``ParallelGPT2Embeddings`` (False). False by default.
        embed_grad_scale (float): Refer to GLM-130B, for training stability. Scales the gradient reaching the
            embedding output and is passed as ``weight_scale`` to the head. 0.1 by default.
        parallel_output (bool): If False, the output is gathered across the tensor-parallel group along the
            last dimension. True by default.
        start_layer_idx (int): The index of start layer in the pipeline. 0 by default.
        device (Optional[Union[str, torch.device]]): The device will be used. None by default.
        residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
        norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
        is_reward (bool): Build a ``RewardModelLinear`` head whose output size is the tensor-parallel world
            size instead of ``vocab_size``. False by default.
        dropout_selective_checkpoint (bool): Forwarded to every block. True by default.
        use_scaled_init (bool): Forwarded to every block. True by default.
        use_swiglu (bool): Forwarded to every block. True by default.
        use_flash_attn (bool): Whether to use flash-attn. True by default.
        lvm_config (dict): LVM options. When ``lvm_config["enable"]`` is truthy, the embedding is
            ``Embedding1DLVM(**lvm_config["embedding_cfg"])``. ``None`` is treated as an empty dict
            (LVM disabled). None by default.
    """

    def __init__(
        self,
        num_layers: int = 12,
        hidden_size: int = 768,
        num_attention_heads: int = 12,
        vocab_size: int = 50304,
        mlp_ratio: float = 4.0,
        mlp_bias: bool = False,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        dtype: torch.dtype = torch.float,
        checkpoint: float = 0.0,
        layer_norm_epsilon: float = 1e-5,
        first: bool = False,
        last: bool = False,
        embed_split_hidden: bool = False,
        embed_grad_scale: float = 0.1,
        parallel_output: bool = True,
        start_layer_idx: int = 0,
        device: Optional[torch.device] = None,
        residual_in_fp32: bool = False,
        norm_type: str = "rmsnorm",
        is_reward: bool = False,
        dropout_selective_checkpoint: bool = True,
        use_scaled_init: bool = True,
        use_swiglu: bool = True,
        use_flash_attn: bool = True,
        lvm_config: dict = None,
    ):
        super().__init__()
        self.lvm_config = lvm_config
        # ``lvm_config`` defaults to None. Treat that as "LVM disabled" instead of failing with
        # AttributeError on ``None.get`` when this chunk owns the embedding.
        lvm_cfg = lvm_config if lvm_config is not None else {}
        checkpoint_layer_num = int(num_layers * checkpoint)
        mark_tensor_parallel = gpc.get_world_size(ParallelMode.TENSOR) > 1
        # Reward models use the dedicated reward head (RewardModelLinear was imported but never used).
        head_cls = RewardModelLinear if is_reward else ScaleColumnParallelLinear

        if first:
            if lvm_cfg.get("enable", False):
                self.embedding = Embedding1DLVM(**lvm_cfg.get("embedding_cfg"))
                if self.embedding.embed_proj is not None:
                    self._normal_init_params(self.embedding.embed_proj.parameters(), mark_tensor_parallel)
            else:
                if embed_split_hidden:
                    self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size)
                else:
                    self.embedding = ParallelGPT2Embeddings(
                        embed_dim=hidden_size,
                        vocab_size=vocab_size,
                        max_position_embeddings=-1,
                        process_group=gpc.get_group(ParallelMode.TENSOR),
                        padding_idx=None,
                        sequence_parallel=gpc.config.parallel.sequence_parallel,
                        device=device,
                        dtype=dtype,
                    )
                # NOTE(review): indentation was reconstructed from a flattened source; this init is
                # assumed to apply only to the non-LVM embedding — confirm.
                self._normal_init_params(self.embedding.parameters(), mark_tensor_parallel)
            self.embed_grad_scale = embed_grad_scale

        self.blocks = nn.ModuleList(
            [
                PackedFlashBaseLayer1D(
                    hidden_size=hidden_size,
                    num_attention_heads=num_attention_heads,
                    mlp_ratio=mlp_ratio,
                    mlp_bias=mlp_bias,
                    attn_drop_rate=attn_drop_rate,
                    drop_path_rate=drop_path_rate,
                    dtype=dtype,
                    layer_norm_epsilon=layer_norm_epsilon,
                    checkpoint=lid < checkpoint_layer_num,
                    layer_idx=lid + start_layer_idx,  # This parameter is used for caching during generation
                    residual_in_fp32=residual_in_fp32,
                    device=device,
                    norm_type=norm_type,
                    dropout_selective_checkpoint=dropout_selective_checkpoint,
                    use_scaled_init=use_scaled_init,
                    use_swiglu=use_swiglu,
                    use_flash_attn=use_flash_attn,
                )
                for lid in range(num_layers)
            ]
        )

        if last:
            norm_cls = RMSNorm if norm_type == "rmsnorm" else LayerNorm
            self.norm = norm_cls(hidden_size, eps=layer_norm_epsilon)
            self.head = head_cls(
                in_features=hidden_size,
                out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,
                process_group=gpc.get_group(ParallelMode.TENSOR),
                bias=False,
                device=device,
                dtype=dtype,
                weight_scale=embed_grad_scale,
            )
            self._normal_init_params(self.head.parameters(), mark_tensor_parallel)
        self.parallel_output = parallel_output

    @staticmethod
    def _normal_init_params(params, mark_tensor_parallel):
        """Initialize each parameter from N(0, 0.0052) and optionally tag it as tensor-parallel."""
        for param in params:
            normal_(std=0.0052)(param)
            if mark_tensor_parallel:
                setattr(param, IS_TENSOR_PARALLEL, True)

    def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
        """Run this pipeline chunk.

        Args:
            hidden_states: Output of the previous chunk. It is replaced by the embedding output when
                this chunk owns the embedding.
            cu_seqlens: Cumulative sequence lengths of the packed batch (a tensor, or a one-element list of one).
            input_ids: Embedding input; only used when this chunk owns the embedding.
            indexes: One-element sequence whose item holds the position ids of each packed token.
            inference_params: Forwarded to every block.

        Returns:
            Hidden states, or head outputs when this chunk owns the head.
        """
        if hasattr(self, "embedding"):
            hidden_states = self.embedding(input_ids)
            if self.embed_grad_scale != 1:
                # Identity in the forward pass; scales the gradient reaching the embedding by embed_grad_scale.
                hidden_states = (
                    self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach()
                )
        if isinstance(cu_seqlens, list):
            assert len(cu_seqlens) == 1
            cu_seqlens = cu_seqlens[0].to(hidden_states.device)
        if cu_seqlens is not None:
            cu_seqlens = cu_seqlens.squeeze(0)
            # If cu_seqlens is passed in, the input is packed and the batch dimension of size 1
            # should be squeezed off.
            hidden_states = hidden_states.squeeze(0)
        if indexes is not None:
            assert len(indexes) == 1
            # The indexes are used to indicate the actual position IDs of each token in the packed input.
            indexes = indexes[0]
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None

        for block in self.blocks:
            hidden_states = block(
                hidden_states,
                cu_seqlens=cu_seqlens,
                indexes=indexes,
                inference_params=inference_params,
                max_seqlen=max_seqlen,
            )

        if hasattr(self, "norm"):
            hidden_states = self.norm(hidden_states.float())
        if hasattr(self, "head"):
            hidden_states = self.head(hidden_states)
        if not self.parallel_output:
            hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
        return hidden_states
def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"), **kwargs):
    """
    Build the model chunks that belong to the current pipeline rank.

    The layers are split with ``partition_uniform``. Each ``(start, end)`` range assigned to this rank
    becomes one ``PackedFlashInternLm1D`` moved to ``device``. A barrier is then entered on the
    default process group.

    Args:
        num_layers (int): The number of layer.
        num_chunks (int): The number of partitions in pipeline parallel.
        device (Optional[Union[str, torch.device]]): The device will be used. torch.device("cuda") by default.
        **kwargs: Remaining ``PackedFlashInternLm1D`` arguments; unknown keys are dropped by ``filter_kwargs``.

    Returns:
        The single chunk, or an ``nn.ModuleList`` when this rank owns several chunks.
    """
    pp_world_size = gpc.get_world_size(ParallelMode.PIPELINE)
    pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)

    sharding = partition_uniform(num_layers, pp_world_size, num_chunks)
    local_ranges = sharding[pp_rank]
    if gpc.is_rank_for_log():
        logger.info(f"The layer sharding is {sharding}.")

    # The head is only placed on the chunk ending at num_layers if the final partition is non-empty.
    final_part_has_layers = len(sharding[-1]) != 0

    chunks = []
    for layer_begin, layer_stop in local_ranges:
        kwargs.update(
            num_layers=layer_stop - layer_begin,
            first=layer_begin == 0,
            last=layer_stop == num_layers and final_part_has_layers,
            device=device,
            start_layer_idx=layer_begin,
        )
        chunk_kwargs = filter_kwargs(PackedFlashInternLm1D.__init__, kwargs)
        chunks.append(PackedFlashInternLm1D(**chunk_kwargs).to(device))

    torch.distributed.barrier()
    return chunks[0] if len(chunks) == 1 else nn.ModuleList(chunks)
@MODEL_INITIALIZER.register_module(module_name=MODEL_TYPE)
def build_vit_model_with_cfg(
    num_chunks=1,
    checkpoint=0.0,
    dtype=torch.float,
    embed_split_hidden=False,
    num_layers=48,
    hidden_size=2048,
    vocab_size=50304,
    embed_grad_scale=1,
    parallel_output=True,
    num_attention_heads=32,
    mlp_ratio=4.0,
    mlp_bias: bool = False,
    residual_in_fp32=False,
    norm_type="rmsnorm",
    drop_path_rate=0,
    attn_drop_rate=0,
    apply_post_layer_norm=False,  # pylint: disable=W0613
    layer_norm_epsilon=1e-5,
    is_reward=False,
    dropout_selective_checkpoint=True,
    use_scaled_init: bool = True,
    use_swiglu: bool = True,
    use_flash_attn: bool = True,
    lvm_config=None,
):
    """
    Build model with config.

    Args:
        num_chunks (int): The number of partitions in pipeline parallel. 1 by default.
        checkpoint (float): Proportion of layers in each chunk that use activation checkpointing
            (``True`` behaves like 1.0). 0.0 by default.
        dtype (torch.dtype): The type of data. torch.float by default.
        embed_split_hidden (bool): Split the embedding layer in the hidden state dimension or vocabulary dimension.
            False by default.
        num_layers (int): The number of layer. 48 by default.
        hidden_size (int): The size of hidden state. 2048 by default.
        vocab_size (int): The size of vocabulary. 50304 by default.
        embed_grad_scale (float): Refer to GLM-130B, for training stability. 1 by default.
        parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default.
        num_attention_heads (int): The number of attention head. 32 by default.
        mlp_ratio (float): The ratio of MLP layers. 4.0 by default.
        mlp_bias (bool): Whether MLP layers use a bias. False by default.
        residual_in_fp32 (bool): Whether to use residual in fp32. False by default. It cannot be used temporarily
            because this parameter requires inconsistent data types to be passed between pipelines,
            which requires significant modifications to internlm.
        norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
        drop_path_rate (float): The drop path rate of input hidden state. 0 by default.
        attn_drop_rate (float): The dropout rate of attention module. 0 by default.
        apply_post_layer_norm (bool): Accepted for config compatibility; not used by this builder. False by default.
        layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default.
        is_reward (bool): Whether to use reward model. False by default.
        dropout_selective_checkpoint (bool): It can only be enabled when checkpoint is disabled. True by default.
        use_scaled_init (bool): Whether to use scaled init. True by default.
        use_swiglu (bool): Whether to use swiglu. True by default.
        use_flash_attn (bool): Whether to use flash-attn. True by default.
        lvm_config (dict): LVM options forwarded to every chunk (``enable``, ``embedding_cfg``).
            ``None`` is replaced by an empty dict, i.e. LVM disabled. None by default.
    """
    # The first pipeline chunk calls ``lvm_config.get(...)``, which fails on the None default.
    if lvm_config is None:
        lvm_config = {}

    cfg = dict(
        hidden_size=hidden_size,
        num_attention_heads=num_attention_heads,
        checkpoint=checkpoint,
        dtype=dtype,
        embed_split_hidden=embed_split_hidden,
        vocab_size=vocab_size,
        embed_grad_scale=embed_grad_scale,
        parallel_output=parallel_output,
        mlp_ratio=mlp_ratio,
        mlp_bias=mlp_bias,
        residual_in_fp32=residual_in_fp32,
        norm_type=norm_type,
        drop_path_rate=drop_path_rate,
        attn_drop_rate=attn_drop_rate,
        layer_norm_epsilon=layer_norm_epsilon,
        is_reward=is_reward,
        dropout_selective_checkpoint=dropout_selective_checkpoint,
        use_scaled_init=use_scaled_init,
        use_swiglu=use_swiglu,
        use_flash_attn=use_flash_attn,
        lvm_config=lvm_config,
    )
    return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)