import math
import os

import torch
import torch.distributed as dist
import torch.distributed.nn.functional as dnnF
import torch.nn as nn
from torch.nn import functional as F


def init_dist():
    """Initialize the distributed backend and return ``(rank, world, device)``.

    The process group is created only if one is not already initialized,
    using ``nccl`` when CUDA is available and ``gloo`` otherwise.

    Returns:
        A tuple ``(rank, world, device)``: this process's global rank, the
        world size, and the ``torch.device`` this process should use.
    """
    use_cuda = torch.cuda.is_available()

    if not dist.is_initialized():
        dist.init_process_group(backend="nccl" if use_cuda else "gloo")

    rank = dist.get_rank()
    world = dist.get_world_size()

    if not use_cuda:
        return rank, world, torch.device("cpu")

    local_rank_env = os.environ.get("LOCAL_RANK")
    if local_rank_env is not None:
        local_rank = int(local_rank_env)
    else:
        # Without LOCAL_RANK the global rank is not a valid device ordinal on
        # any node but the first; wrap it into the locally visible devices.
        local_rank = rank % torch.cuda.device_count()

    torch.cuda.set_device(local_rank)
    return rank, world, torch.device(f"cuda:{local_rank}")


class _AllReduceInputGrad(torch.autograd.Function):
    """Identity in forward; sum the incoming gradient over all ranks in backward."""

    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, grad_output):
        # Reduce a private copy: autograd may share grad_output with other
        # consumers, so it must not be modified in place.
        reduced = grad_output.clone(memory_format=torch.contiguous_format)
        dist.all_reduce(reduced, op=dist.ReduceOp.SUM)
        return reduced


class ColumnParallelLinear(nn.Module):
    """Linear layer whose output columns are sharded across ranks.

    - Full weight: [in_features, out_features]
    - Each rank holds a shard: [in_features, out_features / world_size]
    - Forward:
        y_local = x @ W_local (+ b_local)
        if gather_output and world_size > 1:
            concat(all_gather(y_local)) along the last dim
    - Backward:
        Autograd handles the local parameter grads. The gradient this layer
        sends back to ``x`` is summed over all ranks, because each rank only
        sees the contribution of its own weight shard. This assumes ``x`` is
        the same tensor on every rank.

    NOTE(review): the backward of ``torch.distributed.nn.functional.all_gather``
    appears to sum the gradient contributions from every rank, so if all ranks
    compute the same loss on the gathered output the gradients end up scaled
    by the world size -- confirm this is intended.

    Args:
        in_features: size of the last dimension of the input.
        out_features: total output width across all ranks; must be divisible
            by the world size.
        bias: if True, add a learnable per-shard bias.
        gather_output: if True, return the full-width output on every rank
            instead of the local shard.
    """

    def __init__(self, in_features, out_features, bias=True, gather_output=True):
        super().__init__()
        world = dist.get_world_size()
        rank = dist.get_rank()
        assert out_features % world == 0, "out_features must be divisible by world size"

        self.world = world
        self.rank = rank
        self.in_features = in_features
        self.out_per_rank = out_features // world
        self.gather_output = gather_output

        self.weight = nn.Parameter(torch.empty(in_features, self.out_per_rank))
        self.bias = nn.Parameter(torch.empty(self.out_per_rank)) if bias else None

        self.reset_parameters()

    def reset_parameters(self):
        """Initialize weight and bias uniformly in ``+-1/sqrt(in_features)``."""
        # The weight is stored as [in, out_per_rank], so fan-in is dim 0.
        # nn.init.kaiming_uniform_ would read fan-in from dim 1 (the shard
        # width) and make the init scale depend on the world size; compute
        # the bound from the true fan-in instead.
        fan_in = self.weight.size(0)
        bound = 1.0 / math.sqrt(fan_in) if fan_in > 0 else 0.0
        nn.init.uniform_(self.weight, -bound, bound)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the local shard and optionally gather the full output.

        Args:
            x: input whose last dimension is ``in_features``.

        Returns:
            The gathered output (last dim ``out_per_rank * world``) when
            ``gather_output`` is set and there is more than one rank;
            otherwise the local shard (last dim ``out_per_rank``).
        """
        if self.world > 1 and dist.is_initialized():
            # Route x through an identity op whose backward all-reduces only
            # the gradient produced by this layer. A tensor hook on x would
            # accumulate across repeated calls and would also reduce gradient
            # coming from other consumers of x.
            x = _AllReduceInputGrad.apply(x)

        y_local = x @ self.weight
        if self.bias is not None:
            y_local = y_local + self.bias

        if not (self.gather_output and self.world > 1):
            return y_local

        parts = dnnF.all_gather(y_local)
        return torch.cat(parts, dim=-1)


def main():
    """Run one forward/backward pass through a ColumnParallelLinear as a smoke test.

    Rank 0 prints the output shape and whether gradients reached the input,
    the weight and the bias. The process group is torn down on exit, also
    when the pass raises.
    """
    rank, world, device = init_dist()
    try:
        # Per-rank seed so every rank initializes a different weight shard.
        torch.manual_seed(0 + rank)

        B, Din, Dout = 4, 8, 12
        assert Dout % world == 0

        layer = ColumnParallelLinear(Din, Dout, bias=True, gather_output=True).to(device)

        # The layer sums input gradients across ranks, which is only
        # meaningful when every rank feeds the same batch: reseed with a
        # shared value before sampling the inputs and targets.
        torch.manual_seed(0)
        x = torch.randn(B, Din, device=device, requires_grad=True)
        t = torch.randint(0, Dout, (B,), device=device)

        y = layer(x)
        loss = F.cross_entropy(y, t)
        loss.backward()

        if rank == 0:
            summary = dict(
                device=str(device),
                y=y.shape,
                x_grad=(x.grad is not None),
                w_grad=(layer.weight.grad is not None),
                b_grad=(layer.bias.grad is not None if layer.bias is not None else None),
            )
            print("ColumnParallelLinear OK:", summary)
    finally:
        dist.destroy_process_group()


# Script entry point: run the smoke test only when executed directly.
if __name__ == "__main__":
    main()