File size: 5,430 Bytes
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
from typing import TYPE_CHECKING, List, Optional, Tuple
import torch
from ..utils import logging
from .parallel_state import get_parallel_state
if TYPE_CHECKING:
from torch import nn
from vescale import DeviceMesh
logger = logging.get_logger(__name__)
def build_parallelize_model(
    model: "nn.Module",
    dp_mode: str,
    hf_weight_path: Optional[str] = None,
    enable_full_shard: bool = True,
    enable_fsdp_offload: bool = False,
    enable_mixed_precision: bool = True,
    enable_gradient_checkpointing: bool = True,
    basic_modules: Optional[List[str]] = None,
    enable_reentrant: bool = True,
    use_pin_mem_for_offload: bool = True,
) -> Tuple["nn.Module", "DeviceMesh"]:
    """
    Build a parallelized model with Vescale.

    Every submodule whose class name is listed in ``basic_modules`` or in
    ``model._no_split_modules`` is wrapped with ``vescale.fully_shard`` on its
    own, then the root ``model`` is wrapped as well.

    Args:
        model: The model to shard.
        dp_mode: Either ``"fsdp2"`` or ``"fsdp2-vescale"``. ``"fsdp2"`` passes
            ``params_stored_in_dtensor=True`` to ``fully_shard``; the other
            mode passes ``False``.
        hf_weight_path: Optional path handed to vescale's
            ``parallel_load_safetensors``. When given, each wrapped module
            (and finally the root) is initialised through the function
            returned by ``parallel_init_module_fn``; otherwise the per-module
            init step is a no-op.
        enable_full_shard: Forwarded as ``reshard_after_forward``.
        enable_fsdp_offload: When true, a ``CPUOffloadPolicy`` is used and
            each wrapped submodule is moved to CPU before it is sharded.
        enable_mixed_precision: When true, the model is first cast with
            ``model.float()`` and sharded with bf16 params/outputs and fp32
            reduction.
        enable_gradient_checkpointing: When true and the model exposes
            ``gradient_checkpointing_enable``, that method is called.
        basic_modules: Extra class names of submodules to shard individually.
            ``None`` is treated as "no extra names".
        enable_reentrant: Forwarded as ``use_reentrant`` to gradient
            checkpointing.
        use_pin_mem_for_offload: Forwarded as ``pin_memory`` to
            ``CPUOffloadPolicy``.

    Returns:
        A ``(model, mesh)`` tuple: the value returned by the root
        ``fully_shard`` call and the FSDP device mesh taken from the
        parallel state.

    Raises:
        AssertionError: If ``dp_mode`` is not one of the supported values.
    """
    logger.info_rank0("Apply vescale parallel to the model.")
    parallel_state = get_parallel_state()

    assert dp_mode in ["fsdp2", "fsdp2-vescale"], f"Unsupported dp_mode: {dp_mode!r}"
    params_stored_in_dtensor = dp_mode == "fsdp2"
    mesh = parallel_state.fsdp_mesh

    if enable_mixed_precision:
        model.float()

    def _identity_init(sub_mod, *_):
        """Default per-module init: leave the module untouched."""
        return sub_mod

    module_init_fn = _identity_init
    if hf_weight_path is not None:
        from vescale.initialize.hf_utils import parallel_init_module_fn, parallel_load_safetensors

        shard_states = parallel_load_safetensors(hf_weight_path)
        module_init_fn = parallel_init_module_fn(model, shard_states)

    from vescale import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy, fully_shard

    if enable_gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"):
        logger.info_rank0("Enable gradient checkpointing.")
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": enable_reentrant})

    # mp policy
    if enable_mixed_precision:
        mp_policy = MixedPrecisionPolicy(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.float32,
            output_dtype=torch.bfloat16,
        )
    else:
        mp_policy = MixedPrecisionPolicy()

    # cpu offload policy
    if enable_fsdp_offload:
        cpu_offload_policy = CPUOffloadPolicy(pin_memory=use_pin_mem_for_offload)
    else:
        cpu_offload_policy = OffloadPolicy()

    # The same sharding options are used for every submodule and for the root.
    shard_kwargs = {
        "mesh": mesh,
        "reshard_after_forward": enable_full_shard,
        "mp_policy": mp_policy,
        "params_stored_in_dtensor": params_stored_in_dtensor,
        "offload_policy": cpu_offload_policy,
    }

    # Class names that get their own FSDP unit. Both sources may be None (the
    # `basic_modules` default, or a model that does not declare
    # `_no_split_modules`), so normalise them to empty collections instead of
    # failing on `x in None`.
    shard_cls_names = set(basic_modules or ())
    shard_cls_names.update(getattr(model, "_no_split_modules", None) or ())

    last_fsdp_module = None
    for module in model.modules():
        if module.__class__.__name__ not in shard_cls_names:
            continue

        module_init_fn(module)
        if enable_fsdp_offload:
            module.cpu()
            gc.collect()
            torch.cuda.empty_cache()
        else:
            # NOTE(review): this moves the whole `model`, not just `module`,
            # on every matched submodule. Kept as in the original; confirm
            # whether `module.cuda()` was intended.
            model.cuda()

        fully_shard(module, **shard_kwargs)

        # explicit prefetch: chain each wrapped module with the previous one
        if last_fsdp_module is not None:
            last_fsdp_module.set_modules_to_forward_prefetch([module])
            module.set_modules_to_backward_prefetch([last_fsdp_module])
        last_fsdp_module = module

    module_init_fn(model)
    model = fully_shard(model, **shard_kwargs)
    gc.collect()
    torch.cuda.empty_cache()

    # NOTE: async unshard is enabled; per the original note this is meant to
    # reduce memory fragmentation.
    model._set_unshard_async_op(True)

    # For the root module one might want to skip resharding after backward,
    # since forward will immediately use it:
    #     model.set_reshard_after_backward(False, recurse=False)
    # NOTE: the above line is WRONG in torch-native fsdp2's semantics, because:
    # -) after backward, gradient clipping normalizes model.parameters()'s grad
    # -) at that time model.parameters() is the unsharded param, whose .grad has
    #    already been moved to sharded_param.grad, so unsharded param.grad is
    #    always None
    # -) a None grad disables gradient clipping, which is WRONG!
    # -) even without gradient clipping, the optimizer step only updates the
    #    sharded param, not the unsharded param, so the updated weight is never
    #    used in the next forward
    # -) the next forward of the root is already in the unsharded state, so it
    #    never all-gathers from the updated sharded param, which is WRONG again!

    if not hasattr(mesh, "ndevice"):
        # bytecheckpoint's vescale ckpt uses the vescale device mesh, but here
        # we have a torch-native DeviceMesh, which has no `ndevice` attribute.
        # Patch it onto the mesh class as a read-only property.
        def _ndevice(self):
            return torch.numel(self.mesh)

        mesh.__class__.ndevice = property(_ndevice)

    return model, mesh
|