| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| from dataclasses import dataclass |
| from typing import Dict, Union |
|
|
| import torch |
| import torch.nn as nn |
| from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard |
|
|
| from ..utils import logging |
| from .utils import check_fqn_match, get_module_from_path, set_module_from_path |
|
|
|
|
# Module-level logger, obtained from the project's `logging` helper and named after this module.
logger = logging.get_logger(__name__)
|
|
|
|
@dataclass
class SpecInfo:
    """Placement record attached to a parameter by ``ParallelPlan.apply``.

    Attributes are documented inline below; ``ep_mesh`` is a derived, read-only view.
    """

    # Mesh that is indexed by the "ep" name to obtain the expert-parallel sub-mesh.
    # The ``ep_mesh`` property tolerates ``None`` here.
    ep_fsdp_mesh: DeviceMesh
    # How the parameter is laid out: a ``Shard`` placement or ``Replicate``.
    placement: Union[Shard, Replicate]
    # Fully-qualified name of the parameter this record describes.
    fqn: str

    @property
    def ep_mesh(self):
        """Return ``ep_fsdp_mesh["ep"]``, or ``None`` when no mesh is set."""
        mesh = self.ep_fsdp_mesh
        return None if mesh is None else mesh["ep"]
|
|
|
|
class ParallelPlan:
    """Expert-parallel (EP) sharding plan for a model's parameters.

    ``ep_plan`` maps parameter-name patterns (matched with ``check_fqn_match``) to the
    ``Shard`` placement applied along the "ep" sub-mesh. Two lookup sets are derived
    from the patterns:

    * ``ep_param_suffix``: the last dotted component of every pattern.
    * ``fsdp_no_shard_module``: every pattern with its last dotted component removed
      (the path of the module that owns the parameter), or ``None`` for an empty plan.
    """

    def __init__(self, ep_plan: Dict[str, Shard]):
        """Store ``ep_plan`` and compute the derived lookup sets.

        Args:
            ep_plan: Mapping from a parameter FQN pattern to the ``Shard`` placement
                to apply. May be empty, in which case nothing is sharded.
        """
        self.ep_plan = ep_plan
        self._refresh_derived()

    def _refresh_derived(self):
        """Recompute ``ep_param_suffix`` and ``fsdp_no_shard_module`` from ``ep_plan``."""
        patterns = list(self.ep_plan or ())
        self.ep_param_suffix = {pattern.rsplit(".", 1)[-1] for pattern in patterns}
        if patterns:
            # Take the parent path of *every* pattern, not only the first one, so that
            # plans covering several distinct parent modules are fully represented.
            self.fsdp_no_shard_module = {pattern.rpartition(".")[0] for pattern in patterns}
        else:
            # An empty plan has no owning module; `get_fsdp_no_shard_info` treats
            # `None` as "nothing configured".
            self.fsdp_no_shard_module = None

    @staticmethod
    def _shard_param(param, shard, ep_mesh, ep_replicate):
        """Return the local shard of ``param`` on ``ep_mesh`` as a new ``nn.Parameter``."""
        # Wrap the local data as a DTensor replicated on the EP mesh, then
        # redistribute so that the last mesh dimension uses `shard`.
        replicated = DTensor.from_local(local_tensor=param.data, device_mesh=ep_mesh, placements=ep_replicate)
        sharded = replicated.redistribute(device_mesh=ep_mesh, placements=ep_replicate[:-1] + [shard])
        return torch.nn.Parameter(sharded.to_local(), requires_grad=param.requires_grad)

    def apply(self, model: nn.Module, ep_fsdp_mesh: DeviceMesh):
        """
        Shard the parameters selected by ``ep_plan`` and tag every parameter with a ``SpecInfo``.

        ep_fsdp_mesh: [replicate, replicate, ... , shard]

        Parameters whose FQN matches a plan pattern are replaced on ``model`` by a new
        ``nn.Parameter`` holding the local shard; all other parameters are tagged as
        ``Replicate()``. With an empty plan the model is left untouched.

        Returns:
            Dict mapping each parameter FQN to its ``SpecInfo`` (empty for an empty plan).

        Raises:
            AssertionError: if a matched parameter's sharded dimension is not divisible
                by the EP size, or if a parameter ends up without ``spec_info``.
        """
        fqn2spec_info = {}
        if not self.ep_plan:
            return fqn2spec_info

        ep_mesh = ep_fsdp_mesh["ep"]
        ep_size = ep_mesh.size(-1)
        ep_replicate = [Replicate() for _ in range(ep_mesh.ndim)]

        # Snapshot the parameters first: matched ones are replaced on the model
        # inside the loop, and the live generator should not observe that.
        for fqn, param in list(model.named_parameters()):
            for fqn_pattern, shard in self.ep_plan.items():
                if not check_fqn_match(fqn_pattern, fqn):
                    continue
                assert param.size(shard.dim) % ep_size == 0, (
                    f"dim {shard.dim} of `{fqn}` has size {param.size(shard.dim)}, "
                    f"which is not divisible by ep size {ep_size}"
                )
                local_chunk = self._shard_param(param, shard, ep_mesh, ep_replicate)
                local_chunk.spec_info = SpecInfo(ep_fsdp_mesh=ep_fsdp_mesh, placement=shard, fqn=fqn)
                set_module_from_path(model, fqn, local_chunk)
                fqn2spec_info[fqn] = SpecInfo(ep_fsdp_mesh=ep_fsdp_mesh, placement=shard, fqn=fqn)
                break  # first matching pattern wins
            else:
                # No pattern matched: the parameter stays as-is and is marked replicated.
                param.spec_info = SpecInfo(ep_fsdp_mesh=ep_fsdp_mesh, placement=Replicate(), fqn=fqn)
                fqn2spec_info[fqn] = SpecInfo(ep_fsdp_mesh=ep_fsdp_mesh, placement=Replicate(), fqn=fqn)

        # Sanity check on the (possibly updated) model: nothing may be left untagged.
        for param in model.parameters():
            assert hasattr(param, "spec_info"), f"Internal Error: {param} is omitted"

        return fqn2spec_info

    def get_fsdp_no_shard_info(self, model: nn.Module):
        """Collect the modules of ``model`` matching ``fsdp_no_shard_module``.

        Returns:
            ``None`` when ``fsdp_no_shard_module`` is ``None``; otherwise a dict mapping
            each matching module FQN to the module returned by ``get_module_from_path``.

        Raises:
            AssertionError: if patterns are configured but no module matches.
        """
        if self.fsdp_no_shard_module is None:
            return None

        fsdp_no_shard_states_fqn_to_module = {
            fqn: get_module_from_path(model, fqn)
            for fqn, _ in model.named_modules()
            if any(check_fqn_match(pattern, fqn) for pattern in self.fsdp_no_shard_module)
        }
        assert len(fsdp_no_shard_states_fqn_to_module) > 0, "no module in model match `fsdp_no_shard_module`"

        return fsdp_no_shard_states_fqn_to_module

    def update_prefix(self, prefix: str):
        """
        Update ep_plan when the model is wrapped.

        Every pattern is re-keyed as ``"<prefix>.<pattern>"`` and the derived lookup
        sets are recomputed. An empty plan is left unchanged.
        """
        if not self.ep_plan:
            return
        self.ep_plan = {prefix + "." + k: v for k, v in self.ep_plan.items()}
        self._refresh_derived()
|
|