File size: 3,333 Bytes
52510e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
import torch.nn as nn
import torch.nn.functional as F


class ColumnParallelLinearDemo(nn.Module):
    """
    Simplified Column Parallel Linear for a single-process demo.

    The logical layer computes ``y = x @ W + b`` where ``W`` has shape
    ``(in_features, out_features)``. ``W`` is split along its output
    (column) dimension into ``num_shards`` equal pieces of shape
    ``(in_features, out_features // num_shards)``. Each shard produces its
    own slice of the output and the slices are concatenated. In real
    distributed training each rank holds exactly one shard, and the
    concatenation is an all-gather.

    Note: the weight layout here is ``(in, out)``, i.e. transposed relative
    to ``nn.Linear.weight``.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Total number of output features across all shards.
            Must be divisible by ``num_shards``.
        num_shards: Number of simulated tensor-parallel ranks (>= 1).
        bias: If True, each shard also owns a bias slice.

    Raises:
        ValueError: If ``num_shards < 1`` or if ``out_features`` is not
            divisible by ``num_shards``.
    """

    def __init__(self, in_features, out_features, num_shards=2, bias=True):
        super().__init__()
        if num_shards < 1:
            raise ValueError(f"num_shards must be >= 1, got {num_shards}")
        if out_features % num_shards != 0:
            # Floor division would silently drop the remainder columns and the
            # layer would emit fewer than out_features outputs.
            raise ValueError(
                f"out_features ({out_features}) must be divisible by "
                f"num_shards ({num_shards})"
            )
        self.num_shards = num_shards
        self.in_features = in_features
        self.out_features = out_features
        self.out_per_shard = out_features // num_shards

        # Simulate sharded weights (in real TP, each rank holds one shard).
        self.weight_shards = nn.ParameterList([
            nn.Parameter(torch.empty(in_features, self.out_per_shard))
            for _ in range(num_shards)
        ])

        self.bias_shards = nn.ParameterList([
            nn.Parameter(torch.empty(self.out_per_shard))
            for _ in range(num_shards)
        ]) if bias else None

        self.reset_parameters()

    def reset_parameters(self):
        """Initialize shards like an unsharded ``nn.Linear(in_features, out_features)``.

        ``nn.Linear`` draws both weight and bias from
        U(-1/sqrt(in_features), 1/sqrt(in_features)). The bound is computed
        explicitly because ``kaiming_uniform_`` infers fan_in from
        ``tensor.size(1)``. For the ``(in, out)`` layout used here, that is
        the shard width, not ``in_features``, which would make the
        initialization depend on ``num_shards``.
        """
        # Same zero-fan-in guard as nn.Linear.
        bound = self.in_features ** -0.5 if self.in_features > 0 else 0.0
        for w in self.weight_shards:
            nn.init.uniform_(w, -bound, bound)
        if self.bias_shards is not None:
            for b in self.bias_shards:
                nn.init.uniform_(b, -bound, bound)

    def forward(self, x):
        """Apply every shard to ``x`` and concatenate along the last dim.

        Args:
            x: Tensor whose last dimension is ``in_features``.

        Returns:
            Tensor with the same leading dims as ``x`` and a last dimension
            of ``out_features``.
        """
        # Each simulated rank computes its own column slice of the output.
        partial_outputs = [x @ w for w in self.weight_shards]
        if self.bias_shards is not None:
            partial_outputs = [
                y_local + b for y_local, b in zip(partial_outputs, self.bias_shards)
            ]
        # Gather: concatenation stands in for all_gather across ranks.
        return torch.cat(partial_outputs, dim=-1)


def main():
    """Run one forward/backward pass through the sharded demo layer.

    Prints the configuration, tensor shapes and cross-entropy loss. It then
    reports whether gradients reached the input, every weight shard and
    (when the layer has bias) every bias shard.
    """
    print("=== mini-trainer Simple Demo ===\n")
    print("This demonstrates Column Parallel Linear (Tensor Parallelism)\n")

    B, Din, Dout = 4, 8, 12
    num_shards = 2

    layer = ColumnParallelLinearDemo(Din, Dout, num_shards=num_shards, bias=True)

    print(f"Batch size: {B}, Input dim: {Din}, Output dim: {Dout}")
    print(f"Number of shards (simulating GPUs): {num_shards}")
    # Report the shard width the layer actually uses rather than re-deriving it.
    print(f"Each shard handles: {layer.out_per_shard} output features\n")

    # Input and integer class targets in [0, Dout) for cross-entropy.
    x = torch.randn(B, Din, requires_grad=True)
    t = torch.randint(0, Dout, (B,))

    print(f"Input shape: {x.shape}")

    # Forward
    y = layer(x)
    print(f"Output shape: {y.shape}")

    # Loss and backward
    loss = F.cross_entropy(y, t)
    print(f"Loss: {loss.item():.4f}")

    loss.backward()

    print("\nGradients computed:")
    print(f"  Input grad: {x.grad is not None}")
    print(f"  Weight grads: {all(w.grad is not None for w in layer.weight_shards)}")
    # bias_shards is None when the layer is built with bias=False.
    if layer.bias_shards is not None:
        print(f"  Bias grads: {all(b.grad is not None for b in layer.bias_shards)}")

    print("\n=== Demo Complete ===")
    print("\nIn real distributed training:")
    print("  - Each GPU would hold one weight shard")
    print("  - Communication happens via all_gather/all_reduce")
    print("  - This enables training larger models across multiple GPUs")


if __name__ == "__main__":
    main()