File size: 4,040 Bytes
52510e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.distributed.nn.functional as dnnF
from torch.nn import functional as F


def init_dist():
    """Initialize the distributed backend (if needed) and return ``(rank, world, device)``.

    Uses NCCL when CUDA is available, otherwise Gloo. When no process group exists
    yet, ``init_process_group`` is called with its default (``env://``) rendezvous,
    so launcher-provided environment variables are expected.

    Returns:
        rank: global rank of this process.
        world: total number of processes.
        device: ``cuda:<local_rank>`` (also made the current CUDA device) or ``cpu``.
    """
    if not dist.is_initialized():
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        dist.init_process_group(backend=backend)

    rank = dist.get_rank()
    world = dist.get_world_size()

    if not torch.cuda.is_available():
        return rank, world, torch.device("cpu")

    # Prefer the launcher's LOCAL_RANK. Without it, wrap the global rank onto the
    # visible GPUs so ranks >= device_count (e.g. on a second node) still map to a
    # valid device. NOTE(review): assumes ranks are packed per node with equal GPU counts.
    fallback_local_rank = rank % torch.cuda.device_count()
    local_rank = int(os.environ.get("LOCAL_RANK", fallback_local_rank))
    torch.cuda.set_device(local_rank)
    return rank, world, torch.device(f"cuda:{local_rank}")


class _CopyToTensorParallelRegion(torch.autograd.Function):
    """Identity in forward; sum-all-reduce of the gradient in backward.

    Each rank only computes a partial input gradient ``dy_local @ W_local^T``; the
    full ``dL/dx`` is the sum of those partials over all ranks. Only the gradient
    flowing through this op is reduced, not other uses of the same input tensor.
    """

    @staticmethod
    def forward(ctx, x):
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Reduce into a fresh contiguous buffer instead of mutating grad_output in place.
        grad = grad_output.clone(memory_format=torch.contiguous_format)
        dist.all_reduce(grad, op=dist.ReduceOp.SUM)
        return grad


class _GatherFromTensorParallelRegion(torch.autograd.Function):
    """All-gather shards along the last dim in forward; keep the local slice in backward."""

    @staticmethod
    def forward(ctx, y_local, world, rank):
        ctx.world, ctx.rank = world, rank
        y_local = y_local.contiguous()  # all_gather requires contiguous tensors
        parts = [torch.empty_like(y_local) for _ in range(world)]
        dist.all_gather(parts, y_local)
        return torch.cat(parts, dim=-1)

    @staticmethod
    def backward(ctx, grad_output):
        # Assumes the computation after the gather is replicated on every rank (the
        # tensor-parallel contract), so each rank already holds the full dL/dy and its
        # shard's gradient is just its own slice. Summing across ranks instead (as a
        # reduce-scatter would) inflates gradients by world_size.
        grad_local = grad_output.chunk(ctx.world, dim=-1)[ctx.rank].contiguous()
        return grad_local, None, None


class ColumnParallelLinear(nn.Module):
    """

    Column-sharded Linear (Megatron-style tensor parallelism):

      - Full weight: [in_features, out_features] (transposed w.r.t. nn.Linear,
        so the forward is ``x @ W``)

      - Each rank holds shard: [in_features, out_features/world_size]

      - Forward:

          y_local = x @ W_local (+ b_local)

          if gather_output → concat(all_gather(y_local)) along last dim

      - Backward:

          Autograd handles local param grads.

          The input gradient is all-reduced across ranks (each rank only holds the
          partial dx from its shard); the gathered output's gradient is sliced back
          to the local shard without cross-rank summation.

    Every rank must receive the same input ``x``. When no process group is
    initialized, the layer acts as an ordinary single-rank linear layer.

    """
    def __init__(self, in_features, out_features, bias=True, gather_output=True):
        super().__init__()
        # Fall back to a single rank when there is no process group (e.g. unit tests).
        if dist.is_available() and dist.is_initialized():
            world = dist.get_world_size()
            rank = dist.get_rank()
        else:
            world, rank = 1, 0
        assert out_features % world == 0, "out_features must be divisible by world size"

        self.world = world
        self.rank = rank
        self.in_features = in_features
        self.out_per_rank = out_features // world
        self.gather_output = gather_output

        # Local column shard
        self.weight = nn.Parameter(torch.empty(in_features, self.out_per_rank))
        self.bias = nn.Parameter(torch.empty(self.out_per_rank)) if bias else None

        self.reset_parameters()

    def reset_parameters(self):
        """Initialize like nn.Linear: weight and bias uniform in ±1/sqrt(in_features)."""
        # weight is [in_features, out_per_rank]; PyTorch's "fan_in" is size(1), which
        # here is out_per_rank. The real fan-in is size(0), i.e. PyTorch's "fan_out".
        nn.init.kaiming_uniform_(self.weight, a=5**0.5, mode="fan_out")
        if self.bias is not None:
            fan_in = self.weight.size(0)
            bound = 1.0 / (fan_in**0.5)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the local shard; optionally gather the full output along the last dim."""
        parallel = self.world > 1
        if parallel:
            # Identity forward; all-reduces only this layer's input-gradient contribution.
            x = _CopyToTensorParallelRegion.apply(x)

        # Local matmul: partial output [..., out_per_rank]
        y_local = x @ self.weight
        if self.bias is not None:
            y_local = y_local + self.bias

        if self.gather_output and parallel:
            # [..., out_features]
            return _GatherFromTensorParallelRegion.apply(y_local, self.world, self.rank)
        return y_local


def main():
    """Smoke-test ColumnParallelLinear: one forward/backward pass, summary on rank 0."""
    rank, world, device = init_dist()
    try:
        B, Din, Dout = 4, 8, 12
        assert Dout % world == 0

        # Per-rank seed: each rank initializes its own, different weight shard.
        torch.manual_seed(0 + rank)
        layer = ColumnParallelLinear(Din, Dout, bias=True, gather_output=True).to(device)

        # Tensor parallelism needs identical input/target on every rank, so draw
        # them from a dedicated generator with a rank-independent seed.
        gen = torch.Generator().manual_seed(1234)
        x = torch.randn(B, Din, generator=gen).to(device).requires_grad_()
        t = torch.randint(0, Dout, (B,), generator=gen).to(device)

        # Forward + Backward
        y = layer(x)
        loss = F.cross_entropy(y, t)
        loss.backward()

        if rank == 0:
            print("ColumnParallelLinear OK:",
                  dict(device=str(device),
                       y=y.shape,
                       x_grad=(x.grad is not None),
                       w_grad=(layer.weight.grad is not None),
                       b_grad=(layer.bias.grad is not None if layer.bias is not None else None)))
    finally:
        # Tear down the process group even if the forward/backward raised.
        dist.destroy_process_group()


if __name__ == "__main__":
    # Run the demo only when executed as a script, not on import.
    main()