File size: 3,846 Bytes
52510e8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 | import os
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.distributed.nn.functional as dnnF
from torch.nn import functional as F
def init_dist():
    """Initialize the distributed backend and return ``(rank, world, device)``.

    Uses NCCL when CUDA is available, otherwise Gloo. No ``init_method`` is
    passed, so torch.distributed performs ``env://`` rendezvous
    (MASTER_ADDR / MASTER_PORT / RANK / WORLD_SIZE must be set, e.g. by
    torchrun). An already-initialized process group is reused as-is.

    On CUDA, the process is pinned to GPU ``LOCAL_RANK``. If that variable is
    absent, the fallback is ``rank % torch.cuda.device_count()``. The raw
    global rank would index past the local GPUs on a multi-node launch.

    Returns:
        (rank, world, device): global rank, world size, and the torch.device
        this process should place its tensors on.
    """
    if not dist.is_initialized():
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        dist.init_process_group(backend=backend)
    rank = dist.get_rank()
    world = dist.get_world_size()

    if not torch.cuda.is_available():
        return rank, world, torch.device("cpu")

    # LOCAL_RANK is the per-node GPU index. The global rank is only a valid
    # device index on a single node, hence the modulo fallback.
    fallback_local_rank = rank % torch.cuda.device_count()
    local_rank = int(os.environ.get("LOCAL_RANK", fallback_local_rank))
    torch.cuda.set_device(local_rank)
    return rank, world, torch.device(f"cuda:{local_rank}")
class _CopyToParallelRegion(torch.autograd.Function):
    """Identity in forward; all-reduce (SUM) of the gradient in backward.

    Applied to the replicated full input before each rank slices out its
    column chunk. The slice's backward yields a gradient that is non-zero
    only on the local chunk, so summing across ranks rebuilds the full input
    gradient on every rank.
    """

    @staticmethod
    def forward(ctx, x):
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Clone so the incoming gradient buffer is not mutated in place, and
        # make it contiguous, which collective backends expect.
        grad = grad_output.clone(memory_format=torch.contiguous_format)
        dist.all_reduce(grad)
        return grad


class _ReduceFromParallelRegion(torch.autograd.Function):
    """All-reduce (SUM) in forward; identity in backward.

    After the forward reduce the output is identical on every rank, and so is
    any loss computed from it, so dL/dy_local == dL/dy on each rank.
    ``torch.distributed.nn.functional.all_reduce`` would also all-reduce the
    gradient. That scales the weight gradients (but not the bias gradient,
    which is added after the reduce) by the world size.
    """

    @staticmethod
    def forward(ctx, y):
        out = y.clone(memory_format=torch.contiguous_format)
        dist.all_reduce(out)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


class RowParallelLinear(nn.Module):
    """
    Row-sharded Linear (tensor parallelism over the input dimension).

    - Full weight: [in_features, out_features]
    - Each rank holds shard: [in_features/world_size, out_features]
    - Input: the FULL input x [..., in_features], replicated on every rank.
      Each rank slices out its own column chunk.
    - Forward:
        x_local = x[..., in_slice]
        y_local = x_local @ W_local
        y = SUM_r(y_local) via all_reduce   (identical on all ranks)
        y = y + bias
    - Backward:
        The output all-reduce passes gradients straight through, so each rank
        gets dL/dW_local for the single replicated loss. The input gradient
        is all-reduced across ranks, so x.grad is complete on every rank
        (each rank only contributes its own slice).

    torch.distributed must be initialized before construction. The bias is
    meant to be identical on every rank. This class does NOT synchronize it,
    so callers must use the same RNG state or broadcast it from one rank.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        world = dist.get_world_size()
        rank = dist.get_rank()
        assert in_features % world == 0, "in_features must be divisible by world size"
        self.world = world
        self.rank = rank
        self.in_features = in_features
        self.in_per_rank = in_features // world
        self.out_features = out_features
        # Local row shard of the weight: [in_per_rank, out_features].
        self.weight = nn.Parameter(torch.empty(self.in_per_rank, out_features))
        # Full bias; expected to be identical across ranks (see class doc).
        self.bias = nn.Parameter(torch.empty(out_features)) if bias else None
        # Columns of the full input that this rank consumes.
        start = rank * self.in_per_rank
        self._in_slice = slice(start, start + self.in_per_rank)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize like ``nn.Linear`` on the full layer: U(-b, b), b = 1/sqrt(in_features).

        ``kaiming_uniform_`` is not used because it takes fan_in from
        ``size(1)``. For this ``[in, out]`` layout that is ``out_features``,
        which gives the wrong bound.
        """
        fan_in_full = self.in_per_rank * self.world
        bound = 1.0 / (fan_in_full ** 0.5)
        nn.init.uniform_(self.weight, -bound, bound)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer to a full input ``x`` of shape [..., in_features].

        Raises:
            ValueError: if the last dimension of ``x`` is not in_features.
        """
        expected = self.in_per_rank * self.world
        if x.size(-1) != expected:
            raise ValueError(
                f"RowParallelLinear expected last dim {expected}, got {x.size(-1)}"
            )
        if self.world > 1:
            # Identity now; all-reduces x's gradient in backward.
            x = _CopyToParallelRegion.apply(x)
        x_local = x[..., self._in_slice]
        # Partial output [..., out_features] from this rank's shard.
        y = x_local @ self.weight
        if self.world > 1:
            # Sum partial outputs across ranks; gradient passes through as-is.
            y = _ReduceFromParallelRegion.apply(y)
        if self.bias is not None:
            y = y + self.bias
        return y
def main():
    """Smoke-test RowParallelLinear with one forward/backward on a toy batch.

    Meant to run with one process per rank (e.g. under torchrun). Rank 0
    prints the output shape and which gradients were populated. The process
    group is always torn down, even if the step raises.
    """
    rank, world, device = init_dist()
    try:
        B, Din, Dout = 4, 8, 12
        assert Din % world == 0

        # Per-rank seed so each rank's weight shard is drawn independently.
        torch.manual_seed(0 + rank)
        layer = RowParallelLinear(Din, Dout, bias=True).to(device)

        # The bias must be identical on all ranks. The per-rank seed above
        # made it differ, so adopt rank 0's copy.
        if world > 1 and layer.bias is not None:
            with torch.no_grad():
                dist.broadcast(layer.bias, src=0)

        # Input and targets must be replicated: every rank slices the SAME
        # full x, and the partial outputs are summed across ranks.
        torch.manual_seed(1234)
        x = torch.randn(B, Din, device=device, requires_grad=True)
        t = torch.randint(0, Dout, (B,), device=device)

        # Forward + Backward
        y = layer(x)
        loss = F.cross_entropy(y, t)
        loss.backward()

        if rank == 0:
            print("RowParallelLinear OK:",
                  dict(device=str(device),
                       y=y.shape,
                       x_grad=(x.grad is not None),
                       w_grad=(layer.weight.grad is not None),
                       b_grad=(layer.bias.grad is not None if layer.bias is not None else None)))
    finally:
        dist.destroy_process_group()
# Script entry point: run the smoke test when executed directly.
if __name__ == "__main__":
    main()