Distributed-Transformer-Framework / examples /column_parallel_linear_demo.py
Aravindhan11's picture
Deploy Intelligent Distributed LLaMA Framework
52510e8 verified
import os
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.distributed.nn.functional as dnnF
from torch.nn import functional as F
def init_dist():
"""Initialize distributed backend and return (rank, world, device)."""
if not dist.is_initialized():
backend = "nccl" if torch.cuda.is_available() else "gloo"
dist.init_process_group(backend=backend)
rank = dist.get_rank()
world = dist.get_world_size()
# Assign each rank its device
if torch.cuda.is_available():
local_rank = int(os.environ.get("LOCAL_RANK", rank)) # fallback to rank
torch.cuda.set_device(local_rank)
device = torch.device(f"cuda:{local_rank}")
else:
device = torch.device("cpu")
return rank, world, device
class ColumnParallelLinear(nn.Module):
"""
Column-sharded Linear:
- Full weight: [in_features, out_features]
- Each rank holds shard: [in_features, out_features/world_size]
- Forward:
y_local = x @ W_local (+ b_local)
if gather_output → concat(all_gather(y_local)) along last dim
- Backward:
Autograd handles local param grads.
Register hook to all_reduce input grads for DDP-style correctness.
"""
def __init__(self, in_features, out_features, bias=True, gather_output=True):
super().__init__()
world = dist.get_world_size()
rank = dist.get_rank()
assert out_features % world == 0, "out_features must be divisible by world size"
self.world = world
self.rank = rank
self.in_features = in_features
self.out_per_rank = out_features // world
self.gather_output = gather_output
# Local column shard
self.weight = nn.Parameter(torch.empty(in_features, self.out_per_rank))
self.bias = nn.Parameter(torch.empty(self.out_per_rank)) if bias else None
self.reset_parameters()
def reset_parameters(self):
nn.init.kaiming_uniform_(self.weight, a=5**0.5)
if self.bias is not None:
fan_in = self.weight.size(0)
bound = 1.0 / (fan_in**0.5)
nn.init.uniform_(self.bias, -bound, bound)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Hook to all-reduce input grad → ensures identical gradients across ranks
def _allreduce_input_grad(grad):
if dist.is_initialized() and self.world > 1:
dist.all_reduce(grad, op=dist.ReduceOp.SUM)
return grad
if x.requires_grad:
x.register_hook(_allreduce_input_grad)
# Local matmul: partial output [B, out_per_rank]
y_local = x @ self.weight
if self.bias is not None:
y_local = y_local + self.bias
if self.gather_output and self.world > 1:
# Autograd-aware gather (differentiable)
parts = dnnF.all_gather(y_local)
y = torch.cat(parts, dim=-1) # [B, out_features]
return y
else:
return y_local
def main():
rank, world, device = init_dist()
torch.manual_seed(0 + rank)
B, Din, Dout = 4, 8, 12
assert Dout % world == 0
layer = ColumnParallelLinear(Din, Dout, bias=True, gather_output=True).to(device)
# Toy input/target
x = torch.randn(B, Din, device=device, requires_grad=True)
t = torch.randint(0, Dout, (B,), device=device)
# Forward + Backward
y = layer(x)
loss = F.cross_entropy(y, t)
loss.backward()
if rank == 0:
print("ColumnParallelLinear OK:",
dict(device=str(device),
y=y.shape,
x_grad=(x.grad is not None),
w_grad=(layer.weight.grad is not None),
b_grad=(layer.bias.grad is not None if layer.bias is not None else None)))
dist.destroy_process_group()
if __name__ == "__main__":
main()