| import abc | |
| import torch.nn as nn | |
class Sharder(abc.ABC):
    """
    Interface for user-defined, advanced sharding strategies.

    Implement this interface to express sharding strategies that cannot
    easily be composed from a `ShardingSpec`.

    :class:`torch.distributed._shard.sharding_plan.ShardingPlan` could
    take an object of the `Sharder` and call `shard` to shard the module,
    then replace the original module with the sharded module returned.

    Subclasses must override :meth:`shard`; because the method is
    abstract, a subclass that does not override it cannot be instantiated.
    """

    # Declared abstract so that a missing override fails loudly at
    # instantiation time instead of `shard` silently returning ``None``.
    @abc.abstractmethod
    def shard(self, module: nn.Module) -> nn.Module:
        """
        Shard a module based on the implementation of this method, and
        return the sharded version of the module.

        Args:
            module (:class:`torch.nn.Module`):
                The module to apply sharding to.

        Returns:
            A :class:`torch.nn.Module` object that represents a module
            that's already been sharded.
        """