# Copyright 2025 Bytedance Ltd. and/or its affiliates # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass from typing import Dict, Union import torch import torch.nn as nn from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard from ..utils import logging from .utils import check_fqn_match, get_module_from_path, set_module_from_path logger = logging.get_logger(__name__) @dataclass class SpecInfo: ep_fsdp_mesh: DeviceMesh placement: Union[Shard, Replicate] fqn: str @property def ep_mesh(self): if self.ep_fsdp_mesh is not None: return self.ep_fsdp_mesh["ep"] else: return None class ParallelPlan: def __init__(self, ep_plan: Dict[str, Shard]): self.ep_plan = ep_plan self.ep_param_suffix = {k.split(".")[-1] for k in ep_plan.keys()} self.fsdp_no_shard_module = {".".join(list(ep_plan.keys())[0].split(".")[:-1])} def apply(self, model: nn.Module, ep_fsdp_mesh: DeviceMesh): """ ep_fsdp_mesh: [replicate, replicate, ... , shard] """ ep_mesh = ep_fsdp_mesh["ep"] # ep_plan fqn2spec_info = {} if self.ep_plan: ep_size = ep_mesh.size(-1) ep_replicate = [Replicate() for _ in range(ep_mesh.ndim)] for fqn, param in model.named_parameters(): for fqn_pattern, shard in self.ep_plan.items(): if check_fqn_match(fqn_pattern, fqn): assert param.size(shard.dim) % ep_size == 0 ep_placement = ep_replicate[:-1] + [shard] dtensor = DTensor.from_local( local_tensor=param.data, device_mesh=ep_mesh, placements=ep_replicate ) dtensor = dtensor.redistribute(device_mesh=ep_mesh, placements=ep_placement) local_chunk = torch.nn.Parameter(dtensor.to_local(), requires_grad=param.requires_grad) local_chunk.spec_info = SpecInfo(ep_fsdp_mesh=ep_fsdp_mesh, placement=shard, fqn=fqn) set_module_from_path(model, fqn, local_chunk) fqn2spec_info[fqn] = SpecInfo(ep_fsdp_mesh=ep_fsdp_mesh, placement=shard, fqn=fqn) break if fqn not in fqn2spec_info: # not sharded param.spec_info = SpecInfo(ep_fsdp_mesh=ep_fsdp_mesh, placement=Replicate(), fqn=fqn) fqn2spec_info[fqn] = SpecInfo(ep_fsdp_mesh=ep_fsdp_mesh, placement=Replicate(), fqn=fqn) for param in model.parameters(): assert hasattr(param, "spec_info"), f"Internal Error: {param} is omitted" return fqn2spec_info def get_fsdp_no_shard_info(self, model: nn.Module): if self.fsdp_no_shard_module is None: return None fsdp_no_shard_states_fqn_to_module = {} for fqn, param in model.named_modules(): for no_shard_pattern in self.fsdp_no_shard_module: if check_fqn_match(no_shard_pattern, fqn): fsdp_no_shard_states_fqn_to_module[fqn] = get_module_from_path(model, fqn) assert len(fsdp_no_shard_states_fqn_to_module) > 0, "no module in model match `fsdp_no_shard_module`" return fsdp_no_shard_states_fqn_to_module def update_prefix(self, prefix: str): """ Update ep_plan when model is wrappered. """ self.ep_plan = {prefix + "." + k: v for k, v in self.ep_plan.items()} self.ep_param_suffix = {k.split(".")[-1] for k in self.ep_plan.keys()} self.fsdp_no_shard_module = {".".join(list(self.ep_plan.keys())[0].split(".")[:-1])}