# megablocks-hip/tests/test_gg.py
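"""Correctness tests for megablocks.gg_ops.gmm (grouped GEMM) on the ROCm fallback.

Each test compares the fused op against a pure-PyTorch per-group matmul
reference, in bf16 on the GPU, for both the forward and backward passes.
"""
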
import pathlib
import sys
import torch
import pytest
from torch.testing import assert_close


def _ensure_megablocks_importable() -> None:
    """Put the staged MegaBlocks build (and the repo root) on sys.path."""
    repo_root = pathlib.Path(__file__).resolve().parent.parent
    build_dir = repo_root / "build"
    variant = None
    # Prefer the variant name reported by the build helpers, if available.
    utils_path = repo_root / "kernels" / "utils.py"
    if utils_path.exists():
        sys.path.insert(0, str(repo_root))
        try:
            from kernels.utils import build_variant  # type: ignore
            variant = build_variant()
        except Exception:
            variant = None
        finally:
            sys.path.remove(str(repo_root))
    if variant is None:
        # Scan build/ for a staged variant directory: ROCm first, then CUDA.
        # Each glob is materialized with sorted() so the `or` can actually
        # fall through (Path.glob returns a generator, which is always truthy).
        candidates = sorted(build_dir.glob("torch*-rocm64-*")) or sorted(
            build_dir.glob("torch*-cu*")
        )
        if candidates:
            variant = candidates[0].name
    if variant is None:
        raise RuntimeError(
            "Could not locate staged MegaBlocks build; run build.py before pytest."
        )
    staged_dir = build_dir / variant
    for path in (staged_dir, repo_root):
        if str(path) not in sys.path:
            sys.path.insert(0, str(path))


_ensure_megablocks_importable()

import megablocks  # noqa: E402  (must come after the sys.path fixup above)


def randn(bs, x, y):
    # Small zero-centered values in [-1, 1] / (x * y), cast to bf16 on the GPU.
    out = (torch.rand(bs, x, y) - 0.5) * 2 / (y * x)
    return out.cuda().to(torch.bfloat16)


def gmm(a, b, batch_sizes, trans_b=False):
    """Pure-PyTorch reference for grouped GEMM: one matmul per group."""
    batch_sizes = batch_sizes.cpu().numpy()
    out = []
    start = 0
    for i, size in enumerate(batch_sizes):
        rhs = b[i, :, :].t() if trans_b else b[i, :, :]
        out.append(a[start : start + size, :] @ rhs)
        start += size
    return torch.cat(out)


@pytest.mark.parametrize(
    "z,m,n,k",
    [
        (1, 4, 4, 4),
        (2, 4, 4, 4),
        (1, 16, 16, 16),
        (4, 16, 16, 16),
        (1, 128, 128, 128),
    ],
)
def test_gmm_forward_backward(z, m, n, k):
    trans_b = False
    torch.manual_seed(0)
    a = randn(z, m, k).view(-1, k)
    b = randn(z, k, n) if not trans_b else randn(z, n, k)
    batch_sizes = torch.tensor([m] * z)

    a.requires_grad_(True)
    b.requires_grad_(True)
    a_ref = a.detach().clone().requires_grad_(True)
    b_ref = b.detach().clone().requires_grad_(True)

    out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)
    expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b)
    # Loose tolerances: inputs are bf16, so bitwise agreement is not expected.
    assert_close(out, expected_out, rtol=1e-2, atol=1e-2)

    out.sum().backward()
    expected_out.sum().backward()
    a_grad_diff = (a.grad - a_ref.grad).abs().max().item()
    b_grad_diff = (b.grad - b_ref.grad).abs().max().item()
    assert a_grad_diff < 0.15, f"a.grad max diff {a_grad_diff:.4f} exceeds tolerance"
    assert b_grad_diff < 0.15, f"b.grad max diff {b_grad_diff:.4f} exceeds tolerance"


def test_gmm_sequence_no_state_contamination():
    """Run several shapes back-to-back to ensure no cached state leaks between calls."""
    trans_b = False
    sequences = [
        (1, 4, 4, 4),
        (2, 4, 4, 4),
        (1, 16, 16, 16),
        (4, 16, 16, 16),
    ]
    for z, m, n, k in sequences:
        torch.manual_seed(0)
        a = randn(z, m, k).view(-1, k)
        b = randn(z, k, n) if not trans_b else randn(z, n, k)
        batch_sizes = torch.tensor([m] * z)
        out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)
        expected_out = gmm(a, b, batch_sizes, trans_b)
        assert_close(out, expected_out, rtol=1e-2, atol=1e-2)
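

# A minimal extra forward check for the transposed-b layout. This is a hedged
# sketch: it assumes the ROCm fallback implements trans_b=True with b shaped
# (num_groups, n, k), as upstream grouped_gemm does and as the pure-PyTorch
# reference above already handles; the tests above deliberately pin
# trans_b=False, so skip or drop this if the fallback lacks that path.
def test_gmm_forward_trans_b():
    torch.manual_seed(0)
    z, m, n, k = 2, 16, 16, 16
    a = randn(z, m, k).view(-1, k)
    b = randn(z, n, k)  # (z, n, k): transposed layout exercised via trans_b=True
    batch_sizes = torch.tensor([m] * z)
    out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b=True)
    expected_out = gmm(a, b, batch_sizes, trans_b=True)
    assert_close(out, expected_out, rtol=1e-2, atol=1e-2)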