File size: 3,846 Bytes
52510e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.distributed.nn.functional as dnnF
from torch.nn import functional as F


def init_dist():
    """Initialize the distributed backend and select this process's device.

    If the default process group is not initialized yet, it is created with
    the ``nccl`` backend when CUDA is available and ``gloo`` otherwise. An
    already-initialized process group is reused as is.

    Returns:
        A ``(rank, world, device)`` tuple: the global rank of this process,
        the world size, and the ``torch.device`` assigned to it (a CUDA
        device when CUDA is available, else the CPU).
    """
    cuda_available = torch.cuda.is_available()

    if not dist.is_initialized():
        backend = "nccl" if cuda_available else "gloo"
        dist.init_process_group(backend=backend)

    rank = dist.get_rank()
    world = dist.get_world_size()

    if not cuda_available:
        return rank, world, torch.device("cpu")

    # Prefer the launcher-provided LOCAL_RANK. Without it, wrap the global
    # rank around the number of visible GPUs: on a multi-node job the global
    # rank can be larger than the local device count, and using it directly
    # would make set_device() fail with an invalid device ordinal.
    env_local_rank = os.environ.get("LOCAL_RANK")
    if env_local_rank is not None:
        local_rank = int(env_local_rank)
    else:
        local_rank = rank % torch.cuda.device_count()

    torch.cuda.set_device(local_rank)
    device = torch.device(f"cuda:{local_rank}")

    return rank, world, device


class RowParallelLinear(nn.Module):
    """Linear layer whose weight is sharded along the input dimension.

    The full weight has shape ``[in_features, out_features]``; each rank
    stores the ``[in_features / world_size, out_features]`` slab of rows that
    matches its own slice of the input features.

    Forward:
        x_local = x[..., in_slice]
        y_local = x_local @ W_local
        y       = SUM_r(y_local)      (autograd-aware all_reduce)
        y       = y + bias            (same operation on every rank)

    Backward:
        Autograd produces the gradients for ``W_local`` and for the
        ``in_slice`` columns of ``x``; no input hook is used.

    When no process group is initialized the layer acts as a single-rank
    (unsharded) linear layer.

    Args:
        in_features: Size of the full input feature dimension. Must be
            divisible by the world size.
        out_features: Size of the output feature dimension.
        bias: If true, add a learnable bias of shape ``[out_features]``.

    Raises:
        ValueError: If ``in_features`` is not divisible by the world size.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()

        # Fall back to a single-rank layout when torch.distributed is not
        # set up, so the layer is usable (and testable) without a launcher.
        if dist.is_available() and dist.is_initialized():
            world = dist.get_world_size()
            rank = dist.get_rank()
        else:
            world, rank = 1, 0

        # Explicit raise instead of assert: asserts are stripped under -O and
        # the floor division below would then silently drop input features.
        if in_features % world != 0:
            raise ValueError("in_features must be divisible by world size")

        self.world = world
        self.rank = rank
        self.in_features = in_features
        self.in_per_rank = in_features // world
        self.out_features = out_features

        # Local row shard of the weight, stored as [in, out] so that the
        # forward pass is a plain ``x_local @ weight``.
        self.weight = nn.Parameter(torch.empty(self.in_per_rank, out_features))
        # Full-size bias, meant to hold the same values on every rank.
        self.bias = nn.Parameter(torch.empty(out_features)) if bias else None

        # Columns of the input that belong to this rank.
        start = rank * self.in_per_rank
        self._in_slice = slice(start, start + self.in_per_rank)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize weight and bias from U(-1/sqrt(in_features), +1/sqrt(in_features)).

        This is the bound ``kaiming_uniform_(a=sqrt(5))`` yields for a fan-in
        of ``in_features``. It is computed by hand because the weight is
        stored as ``[in, out]``: ``kaiming_uniform_`` takes ``size(1)`` as the
        fan-in, which would be ``out_features`` for this layout.
        """
        fan_in_full = self.in_features
        bound = 1.0 / (fan_in_full ** 0.5) if fan_in_full > 0 else 0.0

        nn.init.uniform_(self.weight, -bound, bound)
        if self.bias is not None:
            # NOTE: drawn from this rank's default RNG. The bias is only
            # identical across ranks if they share RNG state at this point
            # or the bias is broadcast afterwards.
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the sharded linear map to ``x``.

        Args:
            x: Input whose last dimension holds the full ``in_features``
                features; any number of leading dimensions is accepted.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``out_features``.
        """
        # Slice the feature (last) dimension so inputs with extra leading
        # dimensions work as well as plain [B, Din] ones.
        x_local = x[..., self._in_slice]

        # Partial product of this rank's feature slice with its weight rows.
        partial = x_local @ self.weight

        # Sum the partial outputs of all ranks to obtain the full result.
        # NOTE(review): the backward of this autograd-aware all_reduce also
        # all-reduces the incoming gradient — confirm that this matches how
        # the loss is replicated across ranks.
        if self.world > 1:
            partial = dnnF.all_reduce(partial)

        # Bias is added after the reduction so it is counted exactly once.
        if self.bias is not None:
            partial = partial + self.bias

        return partial


def main():
    """Run one forward/backward pass through ``RowParallelLinear`` as a smoke test.

    Every rank builds its own weight shard, runs the same toy batch through
    the layer, back-propagates a cross-entropy loss, and rank 0 prints
    whether gradients reached the input, the weight and the bias. The
    process group is destroyed on exit, including when the test fails.

    Raises:
        ValueError: If the toy input width is not divisible by the world size.
    """
    rank, world, device = init_dist()

    try:
        B, Din, Dout = 4, 8, 12
        if Din % world != 0:
            raise ValueError(
                f"Din={Din} must be divisible by world size {world}"
            )

        # Rank-specific seed: each rank should draw a different weight shard.
        torch.manual_seed(0 + rank)
        layer = RowParallelLinear(Din, Dout, bias=True).to(device)

        # The bias is replicated, but the per-rank seed above gives every
        # rank different values; copy rank 0's bias to all ranks. Done after
        # .to(device) so the tensor lives where the backend expects it.
        if world > 1 and layer.bias is not None:
            dist.broadcast(layer.bias.detach(), src=0)

        # Shared seed: all ranks must see the same input batch and targets,
        # since each rank only consumes its own column slice of the same x.
        torch.manual_seed(0)

        # Toy input/target
        x = torch.randn(B, Din, device=device, requires_grad=True)
        t = torch.randint(0, Dout, (B,), device=device)

        # Forward + Backward
        y = layer(x)
        loss = F.cross_entropy(y, t)
        loss.backward()

        if rank == 0:
            print("RowParallelLinear OK:",
                  dict(device=str(device),
                       y=y.shape,
                       x_grad=(x.grad is not None),
                       w_grad=(layer.weight.grad is not None),
                       b_grad=(layer.bias.grad is not None if layer.bias is not None else None)))
    finally:
        # Always release the process group, even if the smoke test raised.
        dist.destroy_process_group()


# Script entry point: run the smoke test only when this file is executed
# directly, not when it is imported.
# NOTE(review): presumably launched with one process per rank (e.g. via
# torchrun, which provides LOCAL_RANK) — confirm against the launch scripts.
if __name__ == "__main__":
    main()