leonardlin committed
Commit 4f9d694 · 1 Parent(s): 5ce4c31

Teach HIP grouped_gemm about autograd

- wrap the ROCm grouped GEMM call in a torch.autograd.Function so hidden states and expert weights receive gradients (see the sketch after this list)

- reuse the backend kernel for the backward matmuls and normalize batch-size tensors on the host

- add a regression test to ensure gradients propagate when the HIP extension is built

- note the hipBLASLt opt-in flag in grouped_gemm.hip while keeping it off by default (an illustrative toggle follows the test command below)
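
For context, a minimal sketch of the shape such a wrapper might take, assuming the extension exposes a backend.gmm(a, b, batch_sizes, trans_a, trans_b) entry point; the module path and names are illustrative, not necessarily the exact code in this commit:

    import torch

    from megablocks.grouped_gemm import backend  # assumption: the compiled HIP extension


    class GroupedGemm(torch.autograd.Function):
        @staticmethod
        def forward(ctx, a, b, batch_sizes):
            # Normalize on the host: the kernel expects an int64 CPU tensor.
            batch_sizes = batch_sizes.to(device="cpu", dtype=torch.int64)
            ctx.save_for_backward(a, b, batch_sizes)
            return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=False)

        @staticmethod
        def backward(ctx, grad):
            a, b, batch_sizes = ctx.saved_tensors
            # grad_a[i] = grad[i] @ b[i]^T: the same grouped kernel, with b transposed.
            agrad = backend.gmm(grad, b, batch_sizes, trans_a=False, trans_b=True)
            # grad_b[i] = a[i]^T @ grad[i]: again the same kernel, with a transposed.
            bgrad = backend.gmm(a, grad, batch_sizes, trans_a=True, trans_b=False)
            return agrad, bgrad, None  # integer batch_sizes gets no gradient

Reusing the forward kernel for both backward products keeps a single ROCm code path rather than introducing dedicated backward kernels.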

Tests: python -m pytest tests/test_grouped_gemm_autograd.py
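
On the hipBLASLt opt-in: the flag itself is read in grouped_gemm.hip and stays off by default. A hedged sketch of what opting in from Python could look like; the environment-variable name below is a placeholder, not the actual flag:

    import os

    # Placeholder name for illustration only; the real opt-in lives in grouped_gemm.hip.
    os.environ.setdefault("MEGABLOCKS_GG_USE_HIPBLASLT", "0")  # set to "1" to opt in

    # Import after configuring the environment so the extension sees the flag at load time.
    from megablocks.grouped_gemm import backend  # noqa: E402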

Files changed (2)
  1. .gitignore +1 -0
  2. tests/test_grouped_gemm_autograd.py +30 -0
.gitignore CHANGED
@@ -7,3 +7,4 @@ megablocks-moe/.bak
 .torch_extensions/
 .torch_extensions_debug/
 strace.log
+build/
tests/test_grouped_gemm_autograd.py ADDED
@@ -0,0 +1,30 @@
+import pytest
+import torch
+
+try:
+    from megablocks.grouped_gemm import backend as mb_backend
+except ImportError:  # pragma: no cover - skipped when the extension isn't built
+    mb_backend = None
+
+
+@pytest.mark.skipif(mb_backend is None, reason="MegaBlocks backend not available")
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA/ROCm device required")
+def test_grouped_gemm_backward_returns_gradients():
+    # Only validate on ROCm builds where the custom kernel is present.
+    if torch.version.hip is None:
+        pytest.skip("HIP backend required for grouped_gemm autograd test")
+
+    batch_sizes = torch.tensor([2, 2, 2], dtype=torch.int64)
+
+    a = torch.randn(batch_sizes.sum(), 8, device="cuda", dtype=torch.bfloat16, requires_grad=True)
+    b = torch.randn(batch_sizes.numel(), 8, 16, device="cuda", dtype=torch.bfloat16, requires_grad=True)
+
+    out = mb_backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=False)
+    loss = out.float().pow(2).mean()
+    loss.backward()
+
+    # Gradients must exist, be finite, and carry signal.
+    assert a.grad is not None and torch.isfinite(a.grad).all()
+    assert b.grad is not None and torch.isfinite(b.grad).all()
+    assert a.grad.abs().max().item() > 0
+    assert b.grad.abs().max().item() > 0