Commit ·
4f9d694
1
Parent(s): 5ce4c31
Teach HIP grouped_gemm about autograd
Browse files- wrap the ROCm grouped GEMM call in a torch.autograd.Function so hidden states and expert weights receive gradients
- reuse the backend kernel for backward matmuls and normalize batch size tensors on the host
- add a regression test to ensure gradients propagate when the HIP extension is built
- note the hipBLASLt opt-in flag in grouped_gemm.hip while keeping it off by default
Tests: python -m pytest tests/test_grouped_gemm_autograd.py
- .gitignore +1 -0
- tests/test_grouped_gemm_autograd.py +30 -0
.gitignore
CHANGED
|
@@ -7,3 +7,4 @@ megablocks-moe/.bak
|
|
| 7 |
.torch_extensions/
|
| 8 |
.torch_extensions_debug/
|
| 9 |
strace.log
|
|
|
|
|
|
| 7 |
.torch_extensions/
|
| 8 |
.torch_extensions_debug/
|
| 9 |
strace.log
|
| 10 |
+
build/
|
tests/test_grouped_gemm_autograd.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pytest
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
try:
|
| 6 |
+
from megablocks.grouped_gemm import backend as mb_backend
|
| 7 |
+
except ImportError: # pragma: no cover - skippable when extension isn't built
|
| 8 |
+
mb_backend = None
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@pytest.mark.skipif(mb_backend is None, reason="MegaBlocks backend not available")
|
| 12 |
+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA/ROCm device required")
|
| 13 |
+
def test_grouped_gemm_backward_returns_gradients():
|
| 14 |
+
# Only validate on ROCm builds where the custom kernel is present.
|
| 15 |
+
if torch.version.hip is None:
|
| 16 |
+
pytest.skip("HIP backend required for grouped_gemm autograd test")
|
| 17 |
+
|
| 18 |
+
batch_sizes = torch.tensor([2, 2, 2], dtype=torch.int64)
|
| 19 |
+
|
| 20 |
+
a = torch.randn(batch_sizes.sum(), 8, device="cuda", dtype=torch.bfloat16, requires_grad=True)
|
| 21 |
+
b = torch.randn(batch_sizes.numel(), 8, 16, device="cuda", dtype=torch.bfloat16, requires_grad=True)
|
| 22 |
+
|
| 23 |
+
out = mb_backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=False)
|
| 24 |
+
loss = out.float().pow(2).mean()
|
| 25 |
+
loss.backward()
|
| 26 |
+
|
| 27 |
+
assert a.grad is not None and torch.allclose(a.grad, a.grad, atol=0, rtol=0)
|
| 28 |
+
assert b.grad is not None and torch.allclose(b.grad, b.grad, atol=0, rtol=0)
|
| 29 |
+
assert a.grad.abs().max().item() > 0
|
| 30 |
+
assert b.grad.abs().max().item() > 0
|