| import os
|
| import torch
|
| import torch.nn as nn
|
| import torch.distributed as dist
|
| import torch.distributed.nn.functional as dnnF
|
| from torch.nn import functional as F
|
|
|
|
|
def init_dist():
    """Initialize the distributed backend and return ``(rank, world, device)``.

    Uses NCCL when CUDA is available and Gloo otherwise. An already
    initialized default process group is reused as-is. ``init_process_group``
    is called without an ``init_method``, so rendezvous settings come from the
    environment (PyTorch's ``env://`` default).

    Returns:
        rank: global rank of this process.
        world: number of processes in the default group.
        device: ``cuda:<local_rank>`` (also made the current CUDA device)
            when CUDA is available, otherwise ``cpu``.
    """
    if not dist.is_initialized():
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        dist.init_process_group(backend=backend)

    rank = dist.get_rank()
    world = dist.get_world_size()

    if torch.cuda.is_available():
        # Prefer the launcher-provided LOCAL_RANK. Without it, the global rank
        # can exceed the number of GPUs on this node (multi-node jobs), so wrap
        # it onto the visible devices instead of using it as a device index.
        fallback = rank % torch.cuda.device_count()
        local_rank = int(os.environ.get("LOCAL_RANK", fallback))
        torch.cuda.set_device(local_rank)
        device = torch.device(f"cuda:{local_rank}")
    else:
        device = torch.device("cpu")

    return rank, world, device
|
|
|
|
|
class _AllReduceSumIdentityGrad(torch.autograd.Function):
    """Sum a tensor across ranks in forward; pass the gradient through in backward.

    In a tensor-parallel layer every rank evaluates the same downstream loss
    on the reduced output, so the incoming gradient is already the full
    gradient on each rank. ``torch.distributed.nn.functional.all_reduce``
    all-reduces the gradient in backward as well, which would multiply the
    local weight gradients by the world size.
    """

    @staticmethod
    def forward(ctx, tensor):
        # all_reduce works in place; clone so the autograd input is untouched.
        out = tensor.clone()
        dist.all_reduce(out, op=dist.ReduceOp.SUM)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


class RowParallelLinear(nn.Module):
    """Linear layer whose weight is sharded along the input dimension.

    - Full weight: ``[in_features, out_features]`` (transposed relative to
      ``nn.Linear``; the layer computes ``x @ W``).
    - Each rank holds the shard ``[in_features / world_size, out_features]``
      covering input columns ``[rank * in_per_rank, (rank + 1) * in_per_rank)``.
    - The optional bias is replicated (not sharded) and added after the
      reduction.

    Forward (``x`` is the full, unsharded input):
        x_local = x[..., in_slice]
        y_local = x_local @ W_local          # partial sum
        y       = SUM_r(y_local) + b         # all_reduce across ranks

    Backward:
        The reduction passes the gradient through unchanged, so autograd
        produces local grads for ``W_local`` and for this rank's slice of
        ``x``. The other columns of ``x.grad`` are zero on this rank.

    The caller must feed the same ``x`` on every rank. The caller must also
    keep the replicated bias identical across ranks: it is initialized from
    each rank's own RNG state.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        world = dist.get_world_size()
        rank = dist.get_rank()
        assert in_features % world == 0, "in_features must be divisible by world size"

        self.world = world
        self.rank = rank
        self.in_per_rank = in_features // world
        self.out_features = out_features

        # Local shard of the [in_features, out_features] weight.
        self.weight = nn.Parameter(torch.empty(self.in_per_rank, out_features))
        # Replicated on all ranks.
        self.bias = nn.Parameter(torch.empty(out_features)) if bias else None

        # Columns of the full input that multiply this rank's shard.
        start = rank * self.in_per_rank
        end = start + self.in_per_rank
        self._in_slice = slice(start, end)

        self.reset_parameters()

    def reset_parameters(self):
        """Initialize as a full ``nn.Linear(in_features, out_features)`` would.

        ``nn.Linear`` draws weight and bias from U(-1/sqrt(fan_in),
        1/sqrt(fan_in)) with fan_in = in_features. ``nn.init.kaiming_uniform_``
        cannot be applied to the shard directly: it reads fan_in from
        ``size(1)``, which in this ``[in_per_rank, out_features]`` layout is
        out_features.
        """
        fan_in_full = self.in_per_rank * self.world
        bound = 1.0 / (fan_in_full ** 0.5)
        nn.init.uniform_(self.weight, -bound, bound)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the sharded linear map.

        Args:
            x: full input whose last dimension is ``in_features``, e.g.
                ``(batch, in_features)`` or ``(batch, seq, in_features)``.

        Returns:
            Tensor of shape ``(..., out_features)``: the rank-summed partial
            products plus the bias (if any).
        """
        # Slice the last dimension so inputs with extra leading dims work too.
        x_local = x[..., self._in_slice]
        y = x_local @ self.weight

        if self.world > 1:
            y = _AllReduceSumIdentityGrad.apply(y)

        if self.bias is not None:
            y = y + self.bias

        return y
|
|
|
|
|
def main():
    """Smoke-test RowParallelLinear with one forward/backward pass per rank.

    Sets up the process group through ``init_dist`` and always tears it down,
    even if the test raises.
    """
    rank, world, device = init_dist()
    try:
        # Per-rank seed so each rank initializes a different weight shard.
        torch.manual_seed(0 + rank)

        B, Din, Dout = 4, 8, 12
        assert Din % world == 0

        layer = RowParallelLinear(Din, Dout, bias=True).to(device)

        # The bias is replicated, not sharded: every rank must add the same
        # values after the reduction, so adopt rank 0's copy. Done after
        # .to(device) so the tensor lives where the backend (e.g. NCCL) needs it.
        if world > 1 and layer.bias is not None:
            dist.broadcast(layer.bias.data, src=0)

        # Row parallelism assumes every rank sees the same full input (each
        # rank slices its own columns). Draw x and t from a generator seeded
        # identically on all ranks rather than from the per-rank global RNG.
        gen = torch.Generator(device=device).manual_seed(1234)
        x = torch.randn(B, Din, device=device, generator=gen, requires_grad=True)
        t = torch.randint(0, Dout, (B,), device=device, generator=gen)

        y = layer(x)
        loss = F.cross_entropy(y, t)
        loss.backward()

        if rank == 0:
            print("RowParallelLinear OK:",
                  dict(device=str(device),
                       y=y.shape,
                       x_grad=(x.grad is not None),
                       w_grad=(layer.weight.grad is not None),
                       b_grad=(layer.bias.grad is not None if layer.bias is not None else None)))
    finally:
        dist.destroy_process_group()
|
|
|
|
|
if __name__ == "__main__":
    # Run the demo only when executed as a script, never on import.
    main()