File size: 4,249 Bytes
fb11af9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from dataclasses import dataclass
from typing import Dict, Union

import torch
import torch.nn as nn
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard

from ..utils import logging
from .utils import check_fqn_match, get_module_from_path, set_module_from_path


# Module-level logger named after this module, created through the package's own
# `logging` helper (imported from `..utils`), not the stdlib `logging` module.
logger = logging.get_logger(__name__)


@dataclass
class SpecInfo:
    """Per-parameter record of the mesh, placement and name used for sharding."""

    # Mesh that is indexed with "ep" to obtain the expert-parallel sub-mesh; may be None.
    ep_fsdp_mesh: DeviceMesh
    # Either a `Shard` (parameter is split) or `Replicate` (parameter is kept whole).
    placement: Union[Shard, Replicate]
    # Fully-qualified name of the parameter this record describes.
    fqn: str

    @property
    def ep_mesh(self):
        """Return ``ep_fsdp_mesh["ep"]``, or ``None`` when no mesh is set."""
        mesh = self.ep_fsdp_mesh
        return None if mesh is None else mesh["ep"]


class ParallelPlan:
    """Expert-parallel (EP) sharding plan keyed by parameter FQN patterns.

    ``ep_plan`` maps an FQN pattern (matched with ``check_fqn_match``) to the
    ``Shard`` placement that matching parameters receive on the ``"ep"`` mesh.

    Two attributes are derived from the plan's keys and kept in sync with it:

    * ``ep_param_suffix``: set of the last dotted component of every pattern.
    * ``fsdp_no_shard_module``: one-element set holding the parent path of the
      FIRST pattern (everything before its last dot), or ``None`` when the
      plan is empty.
    """

    def __init__(self, ep_plan: Dict[str, Shard]):
        """Store ``ep_plan`` and compute the attributes derived from its keys.

        An empty plan is accepted; it yields an empty ``ep_param_suffix`` and
        ``fsdp_no_shard_module = None``.
        """
        self.ep_plan = ep_plan
        self._refresh_derived_fields()

    def _refresh_derived_fields(self) -> None:
        """Recompute ``ep_param_suffix`` and ``fsdp_no_shard_module`` from ``ep_plan``."""
        patterns = list(self.ep_plan)
        self.ep_param_suffix = {pattern.split(".")[-1] for pattern in patterns}
        if patterns:
            # Only the first pattern's parent path is used.
            self.fsdp_no_shard_module = {".".join(patterns[0].split(".")[:-1])}
        else:
            # Nothing is sharded by EP; `get_fsdp_no_shard_info` treats None as
            # "no such module" instead of failing on an index into an empty list.
            self.fsdp_no_shard_module = None

    def _match_shard(self, fqn: str) -> Union[Shard, None]:
        """Return the placement of the first plan pattern matching ``fqn``, else ``None``."""
        for fqn_pattern, shard in self.ep_plan.items():
            if check_fqn_match(fqn_pattern, fqn):
                return shard
        return None

    def apply(self, model: nn.Module, ep_fsdp_mesh: DeviceMesh):
        """Shard the parameters selected by the plan and tag every parameter.

        For each parameter whose FQN matches a plan pattern, the full tensor is
        wrapped as a replicated ``DTensor`` on the ``"ep"`` sub-mesh,
        redistributed so that the last mesh dimension uses the plan's ``Shard``
        (placements: ``[replicate, replicate, ..., shard]``), and the resulting
        local chunk replaces the parameter in ``model``. All other parameters
        are left in place and tagged as replicated.

        Every parameter ends up with a ``spec_info`` attribute.

        Args:
            model: module whose parameters are sharded in place.
            ep_fsdp_mesh: mesh exposing an ``"ep"`` sub-mesh. It is only
                indexed when the plan is non-empty.

        Returns:
            Dict mapping each parameter FQN to a ``SpecInfo`` describing it.

        Raises:
            AssertionError: if a matched parameter's sharded dimension is not
                divisible by the EP size, or if a parameter ends up untagged.
        """
        fqn2spec_info = {}

        ep_mesh = None
        ep_size = None
        ep_replicate = []
        if self.ep_plan:
            ep_mesh = ep_fsdp_mesh["ep"]
            ep_size = ep_mesh.size(-1)
            ep_replicate = [Replicate() for _ in range(ep_mesh.ndim)]

        # Iterate over a snapshot: parameters are replaced on the model inside the
        # loop, and the live `named_parameters()` generator walks those same
        # containers.
        for fqn, param in list(model.named_parameters()):
            shard = self._match_shard(fqn)

            if shard is None:  # not sharded
                param.spec_info = SpecInfo(ep_fsdp_mesh=ep_fsdp_mesh, placement=Replicate(), fqn=fqn)
                fqn2spec_info[fqn] = SpecInfo(ep_fsdp_mesh=ep_fsdp_mesh, placement=Replicate(), fqn=fqn)
                continue

            assert param.size(shard.dim) % ep_size == 0, (
                f"cannot shard `{fqn}` with shape {tuple(param.shape)} along dim {shard.dim}: "
                f"size is not divisible by ep_size={ep_size}"
            )
            ep_placement = ep_replicate[:-1] + [shard]
            # Treat the local tensor as a full replica first, then let DTensor
            # slice out this rank's chunk.
            dtensor = DTensor.from_local(local_tensor=param.data, device_mesh=ep_mesh, placements=ep_replicate)
            dtensor = dtensor.redistribute(device_mesh=ep_mesh, placements=ep_placement)
            local_chunk = torch.nn.Parameter(dtensor.to_local(), requires_grad=param.requires_grad)
            local_chunk.spec_info = SpecInfo(ep_fsdp_mesh=ep_fsdp_mesh, placement=shard, fqn=fqn)
            set_module_from_path(model, fqn, local_chunk)
            fqn2spec_info[fqn] = SpecInfo(ep_fsdp_mesh=ep_fsdp_mesh, placement=shard, fqn=fqn)

        for param in model.parameters():
            assert hasattr(param, "spec_info"), f"Internal Error: {param} is omitted"

        return fqn2spec_info

    def get_fsdp_no_shard_info(self, model: nn.Module):
        """Collect the modules of ``model`` matching ``fsdp_no_shard_module``.

        Presumably the caller uses these to keep the modules out of FSDP
        sharding -- confirm against the call site.

        Returns:
            ``None`` when ``fsdp_no_shard_module`` is ``None``; otherwise a dict
            mapping each matching module FQN to the module returned by
            ``get_module_from_path``.

        Raises:
            AssertionError: if patterns exist but no module of ``model`` matches.
        """
        if self.fsdp_no_shard_module is None:
            return None

        fsdp_no_shard_states_fqn_to_module = {}
        for fqn, _ in model.named_modules():
            if any(check_fqn_match(pattern, fqn) for pattern in self.fsdp_no_shard_module):
                fsdp_no_shard_states_fqn_to_module[fqn] = get_module_from_path(model, fqn)
        assert len(fsdp_no_shard_states_fqn_to_module) > 0, "no module in model match `fsdp_no_shard_module`"

        return fsdp_no_shard_states_fqn_to_module

    def update_prefix(self, prefix: str):
        """
        Update ep_plan when model is wrapped.

        Every pattern is rewritten to ``prefix + "." + pattern`` and the derived
        attributes are recomputed from the new keys.
        """
        self.ep_plan = {prefix + "." + k: v for k, v in self.ep_plan.items()}
        self._refresh_derived_fields()