| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
|
|
|
|
class ColumnParallelLinearDemo(nn.Module):
    """
    Simplified Column Parallel Linear for single-process demo.

    In real distributed training, each rank holds a shard of the weight matrix.
    Here every shard lives in the same process. The output dimension is split
    into ``num_shards`` equal column blocks. Each shard computes
    ``x @ W_i (+ b_i)`` for its block, and the partial outputs are concatenated
    along the last dimension. That concatenation stands in for the all-gather
    used across ranks.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Total number of output features across all shards.
            Must be divisible by ``num_shards``.
        num_shards: Number of column shards (simulated ranks). Must be >= 1.
        bias: If True, each shard owns a bias slice of length
            ``out_features // num_shards``. If False, ``bias_shards`` is None.

    Raises:
        ValueError: If ``num_shards < 1`` or ``out_features`` is not divisible
            by ``num_shards``.
    """

    def __init__(self, in_features, out_features, num_shards=2, bias=True):
        super().__init__()
        if num_shards < 1:
            raise ValueError(f"num_shards must be >= 1, got {num_shards}")
        if out_features % num_shards != 0:
            # Floor division would silently drop the remainder columns and the
            # layer would emit fewer output features than requested.
            raise ValueError(
                f"out_features ({out_features}) must be divisible by "
                f"num_shards ({num_shards})"
            )

        self.num_shards = num_shards
        self.in_features = in_features
        self.out_per_shard = out_features // num_shards

        # Each shard is stored as (in_features, out_per_shard) so the forward
        # pass is a plain right-multiplication ``x @ W_i``.
        self.weight_shards = nn.ParameterList([
            nn.Parameter(torch.empty(in_features, self.out_per_shard))
            for _ in range(num_shards)
        ])

        if bias:
            self.bias_shards = nn.ParameterList([
                nn.Parameter(torch.empty(self.out_per_shard))
                for _ in range(num_shards)
            ])
        else:
            self.bias_shards = None

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise all shards the way ``nn.Linear`` initialises its parameters.

        Weights and biases are drawn from U(-1/sqrt(in_features),
        1/sqrt(in_features)). Because every shard sees the full input, this is
        the same distribution a single un-sharded layer would use.
        """
        for w in self.weight_shards:
            # The shard layout is (in, out), the transpose of nn.Linear's
            # (out, in). PyTorch treats dim 1 as fan_in and dim 0 as fan_out,
            # so mode="fan_out" makes the bound depend on in_features.
            nn.init.kaiming_uniform_(w, a=5**0.5, mode="fan_out")
        if self.bias_shards is not None:
            fan_in = self.in_features
            bound = 1.0 / (fan_in**0.5) if fan_in > 0 else 0.0
            for b in self.bias_shards:
                nn.init.uniform_(b, -bound, bound)

    def forward(self, x):
        """Apply every shard to ``x`` and concatenate the partial results.

        Args:
            x: Tensor whose last dimension is ``in_features``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and last dimension
            ``num_shards * out_per_shard``.
        """
        if self.bias_shards is None:
            biases = [None] * self.num_shards
        else:
            biases = self.bias_shards

        partial_outputs = []
        for w, b in zip(self.weight_shards, biases):
            y_local = x @ w
            if b is not None:
                y_local = y_local + b
            partial_outputs.append(y_local)

        # Concatenating along features stands in for the cross-rank all-gather.
        return torch.cat(partial_outputs, dim=-1)
|
|
|
|
|
def main():
    """Run one forward/backward pass through ``ColumnParallelLinearDemo``.

    Builds a small sharded layer and feeds it a random batch. It then computes
    a cross-entropy loss against random integer targets and backpropagates.
    Finally it prints the shapes, the loss, and whether each group of
    gradients was populated.
    """
    print("=== mini-trainer Simple Demo ===\n")
    print("This demonstrates Column Parallel Linear (Tensor Parallelism)\n")

    B, Din, Dout = 4, 8, 12
    num_shards = 2

    print(f"Batch size: {B}, Input dim: {Din}, Output dim: {Dout}")
    print(f"Number of shards (simulating GPUs): {num_shards}")
    print(f"Each shard handles: {Dout // num_shards} output features\n")

    layer = ColumnParallelLinearDemo(Din, Dout, num_shards=num_shards, bias=True)

    x = torch.randn(B, Din, requires_grad=True)
    # Random class indices in [0, Dout); the layer output is used as logits.
    t = torch.randint(0, Dout, (B,))

    print(f"Input shape: {x.shape}")

    y = layer(x)
    print(f"Output shape: {y.shape}")

    loss = F.cross_entropy(y, t)
    print(f"Loss: {loss.item():.4f}")

    loss.backward()

    weight_grads_ok = all(w.grad is not None for w in layer.weight_shards)
    # bias_shards is None when the layer is built with bias=False.
    bias_grads_ok = layer.bias_shards is not None and all(
        b.grad is not None for b in layer.bias_shards
    )

    print("\nGradients computed:")
    print(f" Input grad: {x.grad is not None}")
    print(f" Weight grads: {weight_grads_ok}")
    print(f" Bias grads: {bias_grads_ok}")

    print("\n=== Demo Complete ===")
    print("\nIn real distributed training:")
    print(" - Each GPU would hold one weight shard")
    print(" - Communication happens via all_gather/all_reduce")
    print(" - This enables training larger models across multiple GPUs")
|
|
|
|
|
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
|