# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Optional

import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed.tensor import DTensor, distribute_module
from torch.distributed.tensor.parallel import ParallelStyle
from torch.distributed.tensor.placement_types import Placement, Replicate


class PrepareModuleWeight(ParallelStyle):
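    """
    A :class:`ParallelStyle` that registers the parameters of a module as
    ``DTensor``s over a device mesh instead of sharding them.

    Each local parameter is wrapped via ``DTensor.from_local`` with the given
    placement (``Replicate()`` when none is provided), so the module can be
    included in a ``parallelize_module`` plan alongside sharded modules.

    Keyword Args:
        layouts (Placement, optional):
            Placement used for every parameter of the module.
            Defaults to ``Replicate()``.
    """
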
    def __init__(self, *, layouts: Optional[Placement] = None):
        super().__init__()
        # `DTensor.from_local` expects a list of `Placement`s; fall back to
        # replication when no layout is given instead of passing `[None]`.
        self.layouts = layouts if layouts is not None else Replicate()

    def _replicate_module_fn(
        self,
        name: str,
        module: nn.Module,
        device_mesh: DeviceMesh
    ):
        # Re-register every parameter of `module` as a DTensor built from the
        # local tensor, placed on `device_mesh` with the configured layout
        # (a single placement is passed, i.e. a 1-D mesh is assumed).
        # `run_check=False` skips the cross-rank consistency check and trusts
        # that all ranks already hold matching local tensors.
        for p_name, param in module.named_parameters():
            replicated_param = nn.Parameter(
                DTensor.from_local(param, device_mesh, [self.layouts], run_check=False)
            )
            module.register_parameter(p_name, replicated_param)

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        # `parallelize_module` calls `_apply` for each module matched by the plan;
        # `distribute_module` then runs the partition fn over the module tree.
        return distribute_module(
            module,
            device_mesh,
            partition_fn=self._replicate_module_fn,
            input_fn=None,
            output_fn=None
        )
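

if __name__ == "__main__":
    # A minimal usage sketch, not part of the library: run under e.g.
    # `torchrun --nproc_per_node=2 <this file>`. The toy model, the plan keys
    # ("proj", "head"), the CPU/gloo setup and the 1-D mesh shape are
    # illustrative assumptions only.
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

    dist.init_process_group(backend="gloo")
    mesh = init_device_mesh("cpu", (dist.get_world_size(),))

    model = nn.ModuleDict({
        "proj": nn.Linear(16, 16, bias=False),
        "head": nn.Linear(16, 4, bias=False),
    })
    model = parallelize_module(
        model,
        mesh,
        {
            "proj": ColwiseParallel(),       # shard this weight column-wise
            "head": PrepareModuleWeight(),   # keep this weight replicated as a DTensor
        },
    )

    # Parameters are now DTensor-backed; print their placements.
    for name, p in model.named_parameters():
        print(name, p.placements if isinstance(p, DTensor) else type(p).__name__)

    dist.destroy_process_group()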