# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import types
from functools import partial
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel, MixedPrecision, ShardingStrategy
from torch.distributed.fsdp._common_utils import _get_module_fsdp_state_if_fully_sharded_module
from torch.distributed.fsdp._runtime_utils import _lazy_init
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.checkpoint import create_selective_checkpoint_contexts, noop_context_fn

from ..models import load_model_weights
from ..utils import logging
from ..utils.import_utils import is_torch_version_greater_than
from .checkpoint import CheckpointFunction
from .fsdp import (
    clip_grad_norm_,
    init_fsdp_fn,
    parallel_init_fsdp_fn,
    parallel_load_safetensors,
    register_checkpoint_extension,
)
from .parallel_state import get_parallel_state
from .utils import get_module_from_path, set_module_from_path


if is_torch_version_greater_than("2.4"):
    from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard
    from torch.distributed.tensor.parallel import parallelize_module
# Module-level logger from the project's logging wrapper (provides the `*_rank0` helpers used below).
logger = logging.get_logger(__name__)
def verbose_fsdp_grouping(model, prefix="", depth=0):
    """
    Debug-log the FSDP1 wrapping layout below ``model`` (via ``logger.debug_rank0``).

    Every child that is a ``FullyShardedDataParallel`` instance is reported as an
    "[FSDP Group]" line with its dotted path, followed by its sharding strategy,
    its mixed-precision config and the names of all modules it contains. The walk
    then continues inside that wrapper one indentation level deeper. Children that
    are not FSDP wrappers are walked at the same depth.

    Args:
        model: Module whose children are inspected.
        prefix: Dotted path of ``model``, prepended to child names in the log.
        depth: Nesting level of FSDP groups; controls the log indentation.
    """
    indent = " " * depth
    for name, child in model.named_children():
        if isinstance(child, FullyShardedDataParallel):
            # [1:] drops the first named_modules() entry, which is the wrapper itself.
            module_names = [m_name for m_name, _ in child.named_modules()][1:]
            strategy = child.sharding_strategy
            logger.debug_rank0(f"{indent}├── [FSDP Group] {prefix}{name}")
            logger.debug_rank0(
                f"{indent}│   ├── Sharding Strategy: {strategy}, Mixed Precision: {child.mixed_precision}"
            )
            logger.debug_rank0(f"{indent}│   └── Contains Modules: {module_names}")
            verbose_fsdp_grouping(child, prefix=f"{prefix}{name}.", depth=depth + 1)
        else:
            verbose_fsdp_grouping(child, prefix=f"{prefix}{name}.", depth=depth)
def build_parallelize_model(
    model: "nn.Module",
    weights_path: Optional[str] = None,
    sharding_plan: Optional[Dict[str, Any]] = None,
    enable_full_shard: bool = True,
    enable_mixed_precision: bool = True,
    enable_fp32: bool = False,
    enable_gradient_checkpointing: bool = True,
    basic_modules: Optional[List[str]] = None,
    fsdp_llm_blocks: bool = True,
    ignore_norm: bool = False,
    use_depth_align: bool = False,
    ignore_depth: bool = False,
    **kwargs,
) -> "nn.Module":
    """
    Applies parallel strategies to the model.

    Steps, in order:
    1. Optional gradient checkpointing.
    2. Tensor parallel (``tp_enabled``).
    3. Expert parallel (``ep_enabled``).
    4. Data parallel, chosen by ``parallel_state``: FSDP2 (``fully_shard``),
       FSDP1 (``FullyShardedDataParallel``) or, when FSDP is disabled, ``DDP``.

    Args:
        model: Module to parallelize.
        weights_path: Checkpoint to load when ``init_device == "meta"``. If None:
            FSDP2 materializes on cuda and calls ``model.init_weights()``; FSDP1
            initializes from ``kwargs["state_dict"]`` (default ``{}``).
        sharding_plan: Accepted for API compatibility; not read here.
        enable_full_shard: FSDP2: ``reshard_after_forward``. FSDP1: FULL/HYBRID
            shard instead of ``NO_SHARD``.
        enable_mixed_precision: bf16 params/compute with fp32 gradient reduction.
            The model is first upcast with ``model.float()``.
        enable_fp32: FSDP2: all-fp32 policy for the per-block shards. DDP: fp32
            ``MixedPrecision`` (only together with ``enable_mixed_precision``).
            Not read by the FSDP1 path.
        enable_gradient_checkpointing: Call ``model.gradient_checkpointing_enable``
            when the model defines it.
        basic_modules: Class names wrapped/sharded as individual units.
        fsdp_llm_blocks: FSDP2 only, used when ``basic_modules`` is empty. Shard
            each layer of ``qwenvl...language_model.model.layers`` and of
            ``qwen_expert.model.layers``.
        ignore_norm: FSDP2 only. Pass the ``input_layernorm`` /
            ``post_attention_layernorm`` weights of those layers as
            ``ignored_params``.
        use_depth_align, ignore_depth: FSDP2 only. When both are set,
            ``dav2_backbone`` / ``dav2_head`` are cast to bf16, set to eval mode,
            frozen (``requires_grad=False``) and passed as root ``ignored_params``.
        **kwargs: Extra options read here: ``init_device``,
            ``enable_fsdp_offload``, ``enable_reentrant``, ``ops_to_save``,
            ``fsdp_kwargs`` (user overrides for the FSDP constructors),
            ``state_dict``, ``enable_forward_prefetch``.

    Returns:
        The parallelized model: the same module sharded in place (FSDP2), a
        ``FullyShardedDataParallel`` wrapper (FSDP1) or a ``DDP`` wrapper.

    Raises:
        ValueError: FSDP is disabled and ``init_device`` is not ``"cuda"``, or
            ``enable_fsdp_offload`` is requested.
        TypeError: the LLM / expert ``layers`` containers are not iterable.
    """
    parallel_state = get_parallel_state()

    if not parallel_state.fsdp_enabled:
        if kwargs.get("init_device") != "cuda":
            raise ValueError("Only FSDP training supports `init_device=cpu` or `init_device=meta`.")
        if kwargs.pop("enable_fsdp_offload", False):
            raise ValueError("Only FSDP training supports `enable_fsdp_offload`.")

    if enable_mixed_precision:
        # Upcast to float32 so the optimizer works on fp32 parameters; the lower
        # compute precision is configured on the data-parallel wrapper below.
        model = model.float()

    if enable_gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"):
        logger.info_rank0("Enable gradient checkpointing.")
        use_reentrant = kwargs.pop("enable_reentrant", False)
        if use_reentrant:
            # Process-wide monkey patch: reentrant checkpointing now uses the project's CheckpointFunction.
            torch.utils.checkpoint.CheckpointFunction = CheckpointFunction
        ops_to_save = kwargs.pop("ops_to_save", None)
        # Selective checkpointing driven by `ops_to_save`; otherwise torch's no-op context.
        context_fn = (
            partial(create_selective_checkpoint_contexts, ops_to_save) if ops_to_save is not None else noop_context_fn
        )
        model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs={"use_reentrant": use_reentrant, "context_fn": context_fn}
        )

    if parallel_state.tp_enabled:
        logger.info_rank0("Apply tensor parallel to the model.")
        # NOTE(review): no `parallelize_plan` is passed; depending on the torch
        # version this may shard nothing (or fail) — confirm TP is really applied.
        model = parallelize_module(
            model,
            device_mesh=parallel_state.tp_mesh,
        )

    if parallel_state.ep_enabled:
        parallel_plan = model.get_parallel_plan()
        ep_param_suffix = parallel_plan.ep_param_suffix
        fqn2spec_info = parallel_plan.apply(model, parallel_state.ep_fsdp_device_mesh)
        # EP modules are kept out of the main FSDP1 wrap and wrapped separately below.
        fsdp_no_shard_states_fqn_to_module = parallel_plan.get_fsdp_no_shard_info(model)
        fsdp_no_shard_states = list(fsdp_no_shard_states_fqn_to_module.values())
        fsdp_no_shard_states_fqn = list(fsdp_no_shard_states_fqn_to_module.keys())
        logger.info_rank0(f"Apply expert parallel to the model successfully.\nEP modules: {fsdp_no_shard_states_fqn}.")
    else:
        fqn2spec_info = None
        ep_param_suffix = None
        fsdp_no_shard_states = None
        fsdp_no_shard_states_fqn = None

    if parallel_state.fsdp_enabled:
        logger.info_rank0(f"Apply data parallel to the model: {parallel_state.dp_mode}.")
        if parallel_state.dp_mode == "fsdp2":
            # Pop the user overrides ONCE. They must reach both the per-block and the
            # root `fully_shard` calls; popping twice silently dropped them for the root.
            user_fsdp_kwargs = kwargs.pop("fsdp_kwargs", {})
            fsdp_kwargs = {
                "mesh": parallel_state.fsdp_mesh,
                "reshard_after_forward": enable_full_shard,
                **user_fsdp_kwargs,
            }
            if enable_mixed_precision and not enable_fp32:
                logger.info_rank0("Enable mixed precision training.")
                fsdp_kwargs["mp_policy"] = MixedPrecisionPolicy(
                    param_dtype=torch.bfloat16,
                    reduce_dtype=torch.float32,
                    output_dtype=torch.bfloat16,
                )
            elif enable_fp32:
                fsdp_kwargs["mp_policy"] = MixedPrecisionPolicy(
                    param_dtype=torch.float32,
                    reduce_dtype=torch.float32,
                    output_dtype=torch.float32,
                )

            if ignore_norm:
                # Layer-norm weights of the LLM and expert decoder layers are handed to
                # `fully_shard` as `ignored_params` for the per-block shards.
                norm_params = set()
                for layer in model.model.qwenvl_with_expert.qwenvl.language_model.model.layers:
                    norm_params.add(layer.input_layernorm.weight)
                    norm_params.add(layer.post_attention_layernorm.weight)
                for expert_layer in model.model.qwenvl_with_expert.qwen_expert.model.layers:
                    norm_params.add(expert_layer.input_layernorm.weight)
                    norm_params.add(expert_layer.post_attention_layernorm.weight)
                fsdp_kwargs["ignored_params"] = norm_params

            # Root-level `fully_shard` kwargs (no per-block ignored norms).
            mp_fsdp_kwargs = {
                "mesh": parallel_state.fsdp_mesh,
                "reshard_after_forward": enable_full_shard,
                **user_fsdp_kwargs,
            }
            if use_depth_align and ignore_depth:
                # Frozen depth model: bf16, eval mode, no grads, excluded from the root shard.
                model.model.dav2_backbone.to(torch.bfloat16)
                model.model.dav2_head.to(torch.bfloat16)
                model.model.dav2_backbone.eval()
                model.model.dav2_head.eval()
                depth_params = set()
                for param in model.model.dav2_backbone.parameters():
                    param.requires_grad = False
                    depth_params.add(param)
                for param in model.model.dav2_head.parameters():
                    param.requires_grad = False
                    depth_params.add(param)
                mp_fsdp_kwargs["ignored_params"] = depth_params
                # NOTE(review): the root gets a bf16 policy only in the depth-align
                # case; if it should always get one, move this out of the `if` —
                # confirm intent.
                mp_fsdp_kwargs["mp_policy"] = MixedPrecisionPolicy(
                    param_dtype=torch.bfloat16,
                    reduce_dtype=torch.float32,
                    output_dtype=torch.bfloat16,
                )

            ignore_modules_in_mixed_precision = tuple()
            if hasattr(model, "get_ignore_modules_in_mixed_precision"):
                ignore_modules_in_mixed_precision = model.get_ignore_modules_in_mixed_precision()

            def apply_fsdp_to_decoder_blocks(module: "nn.Module") -> None:
                # Shard modules named in `basic_modules`, plus modules whose class is excluded
                # from mixed precision; the latter are sharded without `mp_policy`.
                if module.__class__.__name__ in basic_modules or module.__class__ in ignore_modules_in_mixed_precision:
                    logger.debug(f"Apply FSDP2 to {module.__class__.__name__}.")
                    if module.__class__ in ignore_modules_in_mixed_precision:
                        fully_shard(module, **{k: v for k, v in fsdp_kwargs.items() if k != "mp_policy"})
                    else:
                        fully_shard(module, **fsdp_kwargs)

            if basic_modules:
                model.apply(apply_fsdp_to_decoder_blocks)
            elif fsdp_llm_blocks:
                layers = model.model.qwenvl_with_expert.qwenvl.language_model.model.layers
                expert_layers = model.model.qwenvl_with_expert.qwen_expert.model.layers
                if not hasattr(layers, "__iter__") or not hasattr(expert_layers, "__iter__"):
                    raise TypeError("Expected 'layers' to be a module list or container.")
                logger.info_rank0(f"Applying FSDP to {len(layers)} transformer layers in Paligemma and Gemma decoder.")
                for i, layer in enumerate(layers):
                    logger.debug(f"Sharding layer {i} ({layer.__class__.__name__})")
                    fully_shard(layer, **fsdp_kwargs)
                for i, layer in enumerate(expert_layers):
                    logger.debug(f"Sharding layer {i} ({layer.__class__.__name__})")
                    fully_shard(layer, **fsdp_kwargs)
            # Root shard last: it manages every parameter not already owned by a block shard.
            fully_shard(model, **mp_fsdp_kwargs)

            if kwargs.get("init_device") == "meta":
                if weights_path is None:
                    # Materialize the sharded meta model on cuda and run the model's own init.
                    model.to_empty(device="cuda")
                    model.init_weights()
                else:
                    from torch.distributed.tensor import distribute_tensor

                    load_model_weights(model, weights_path, "cuda", dtensor_factory=distribute_tensor)

        elif parallel_state.dp_mode == "fsdp1":
            # `basic_modules` defaults to None; treat that as "no class is auto-wrapped"
            # instead of raising TypeError inside the wrap policy.
            basic_module_names = set(basic_modules or ())
            wrap_policy = partial(
                lambda_auto_wrap_policy, lambda_fn=lambda module: module.__class__.__name__ in basic_module_names
            )
            # HSDP when the FSDP mesh is multi-dimensional and non-trivial, else plain FSDP.
            if parallel_state.fsdp_mesh.ndim > 1 and parallel_state.fsdp_mesh.size() > 1:
                strategy = ShardingStrategy.HYBRID_SHARD
            else:
                strategy = ShardingStrategy.FULL_SHARD
            fsdp_kwargs = {
                "auto_wrap_policy": wrap_policy,
                "ignored_states": fsdp_no_shard_states,
                "device_id": torch.cuda.current_device(),
                "sharding_strategy": strategy if enable_full_shard else ShardingStrategy.NO_SHARD,
                "use_orig_params": True,
                "device_mesh": parallel_state.fsdp_mesh,
            }
            fsdp_kwargs.update(kwargs.pop("fsdp_kwargs", {}))

            if enable_mixed_precision:
                logger.info_rank0("Enable mixed precision training.")
                mixed_precision = MixedPrecision(
                    param_dtype=torch.bfloat16,
                    reduce_dtype=torch.float32,
                    buffer_dtype=torch.float32,
                )
                if hasattr(model, "get_ignore_modules_in_mixed_precision"):
                    # Concatenate as tuples so a list/set returned by the model also works.
                    extra_ignored = tuple(model.get_ignore_modules_in_mixed_precision())
                    mixed_precision._module_classes_to_ignore = (
                        tuple(mixed_precision._module_classes_to_ignore) + extra_ignored
                    )
                fsdp_kwargs["mixed_precision"] = mixed_precision

            if kwargs.get("init_device") == "cpu":
                logger.info_rank0("Enable rank0-only initialization.")
                # Rank 0 holds the real weights and broadcasts them; other ranks only allocate.
                fsdp_kwargs["sync_module_states"] = True
                if parallel_state.global_rank != 0:
                    fsdp_kwargs["param_init_fn"] = init_fsdp_fn(model, device="cuda")
            elif kwargs.get("init_device") == "meta":
                logger.info_rank0("Enable meta initialization.")
                if weights_path is None:
                    logger.info_rank0("weights_path is None during meta initialization.")
                # EP parameters are loaded per EP module below, so skip them here.
                ignore_param_names = (
                    [".".join([fqn, k]) for fqn in fsdp_no_shard_states_fqn for k in ep_param_suffix]
                    if fsdp_no_shard_states_fqn is not None
                    else None
                )
                shard_states = (
                    parallel_load_safetensors(weights_path, ignore_param_name=ignore_param_names)
                    if weights_path
                    else kwargs.get("state_dict", {})
                )
                fsdp_kwargs["param_init_fn"] = parallel_init_fsdp_fn(
                    model, shard_states, ignore_param_name=ignore_param_names
                )

            if kwargs.pop("enable_fsdp_offload", False):
                logger.info_rank0("Enable offloading for parameters & gradients & optimizer states.")
                fsdp_kwargs["cpu_offload"] = CPUOffload(offload_params=True)

            if kwargs.pop("enable_forward_prefetch", False):
                fsdp_kwargs["forward_prefetch"] = True
            else:
                # Without forward prefetch, backward prefetch is disabled as well.
                fsdp_kwargs["forward_prefetch"] = False
                fsdp_kwargs["backward_prefetch"] = None

            # Wrap everything except the EP modules (passed as `ignored_states`) first.
            model = FullyShardedDataParallel(model, **fsdp_kwargs)

            if fsdp_no_shard_states is not None:
                # Wrap each EP module in its own FSDP instance: NO_SHARD (DDP-like) on the
                # FSDP mesh when the `ep_fsdp` dim has size 1, else FULL_SHARD over that dim.
                if parallel_state.ep_fsdp_mesh["ep_fsdp"].size() == 1:
                    moe_sharding_strategy = ShardingStrategy.NO_SHARD
                    ep_fsdp_device_mesh = parallel_state.fsdp_mesh
                else:
                    moe_sharding_strategy = ShardingStrategy.FULL_SHARD
                    ep_fsdp_device_mesh = parallel_state.ep_fsdp_mesh["ep_fsdp"]
                logger.info_rank0(f"Apply {moe_sharding_strategy} states on '{fsdp_no_shard_states_fqn}'.")
                fsdp_kwargs.pop("ignored_states", None)
                fsdp_kwargs.pop("auto_wrap_policy", None)
                fsdp_kwargs["sharding_strategy"] = moe_sharding_strategy
                fsdp_kwargs["device_mesh"] = ep_fsdp_device_mesh
                logger.info_rank0(f"{ep_fsdp_device_mesh=}")
                for fqn in fsdp_no_shard_states_fqn:
                    no_shard_module = get_module_from_path(model, fqn)
                    if kwargs.get("init_device") == "meta":
                        specific_param_name = [".".join([fqn, k]) for k in ep_param_suffix]
                        shard_states = (
                            parallel_load_safetensors(weights_path, specific_param_name=specific_param_name)
                            if weights_path
                            else {}
                        )
                        if weights_path:
                            # Re-key from the full FQN to the suffix relative to this module.
                            for suffix in ep_param_suffix:
                                shard_states[suffix] = shard_states.pop(".".join([fqn, suffix]))
                        fsdp_kwargs["param_init_fn"] = parallel_init_fsdp_fn(
                            no_shard_module, shard_states, specific_param_name=ep_param_suffix
                        )
                    fsdp_module = FullyShardedDataParallel(no_shard_module, **fsdp_kwargs)
                    fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(fsdp_module)
                    # Scale the post-reduce gradient divisor by ep_size (presumably to average
                    # expert gradients over EP ranks as well — confirm).
                    fsdp_state._gradient_postdivide_factor *= parallel_state.ep_size
                    set_module_from_path(model, fqn, fsdp_module)
                # Re-run root lazy init now that nested FSDP instances were swapped in.
                _lazy_init(model, model)

            # Apply fsdp extension to FSDP model
            save_hook_mesh = parallel_state.ep_fsdp_device_mesh if parallel_state.ep_enabled else None
            logger.info_rank0("Register Checkpoints Extension hook to the model")
            register_checkpoint_extension(
                fsdp_model=model,
                save_hook_mesh=save_hook_mesh,
                fqn2spec_info=fqn2spec_info,
            )
            if parallel_state.ep_enabled:
                # Bind the project's `clip_grad_norm_` (from `.fsdp`) onto this FSDP instance.
                model.clip_grad_norm_ = types.MethodType(clip_grad_norm_, model)
            verbose_fsdp_grouping(model)
    else:
        ddp_kwargs = {"device_ids": [parallel_state.local_rank]}
        if enable_mixed_precision:
            logger.info_rank0("Enable mixed precision training.")
            if enable_fp32:
                mixed_precision = MixedPrecision(
                    param_dtype=torch.float32,
                    reduce_dtype=torch.float32,
                    buffer_dtype=torch.float32,
                )
            else:
                mixed_precision = MixedPrecision(
                    param_dtype=torch.bfloat16,
                    reduce_dtype=torch.float32,
                    buffer_dtype=torch.bfloat16,
                )
            ddp_kwargs["mixed_precision"] = mixed_precision
        model = DDP(model, **ddp_kwargs)

    return model