"""Tests for megablocks.gg_ops.gmm against a pure-PyTorch grouped-matmul reference."""

import pathlib
import sys

import pytest
import torch
from torch.testing import assert_close


def _ensure_megablocks_importable() -> None:
    """Put the staged MegaBlocks build (and the repo root) on sys.path."""
    repo_root = pathlib.Path(__file__).resolve().parent.parent
    build_dir = repo_root / "build"
    variant = None

    # Prefer the variant name reported by the repo's own helper, if present.
    utils_path = repo_root / "kernels" / "utils.py"
    if utils_path.exists():
        sys.path.insert(0, str(repo_root))
        try:
            from kernels.utils import build_variant

            variant = build_variant()
        except Exception:
            variant = None
        finally:
            sys.path.remove(str(repo_root))

    # Fall back to scanning build/. Each glob must be materialized with
    # sorted() before `or`: Path.glob() returns a generator, which is truthy
    # even when it matches nothing, so `glob(...) or glob(...)` would never
    # reach the second pattern.
    if variant is None:
        candidates = sorted(build_dir.glob("torch*-rocm64-*")) or sorted(build_dir.glob("torch*-cu*"))
        if candidates:
            variant = candidates[0].name

    if variant is None:
        raise RuntimeError("Could not locate staged MegaBlocks build; run build.py before pytest.")

    staged_dir = build_dir / variant
    for path in (staged_dir, repo_root):
        if str(path) not in sys.path:
            sys.path.insert(0, str(path))
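

# Sketch of the directory layout the helper above assumes; the variant name
# below is a hypothetical example matching the torch*-cu* glob:
#
#   build/
#       torch26-cxx11-cu124-x86_64-linux/   <- staged_dir, provides `megablocks`
#   kernels/
#       utils.py                            <- optional build_variant() helper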


_ensure_megablocks_importable()

import megablocks  # noqa: E402  (the sys.path setup above must run first)


def randn(bs, x, y):
    # Centered values in roughly [-1, 1), scaled down by the reduction size
    # so the bf16 matmuls stay well-conditioned.
    out = ((torch.rand(bs, x, y) - 0.5) * 2) / (y * x)
    return out.cuda().to(torch.bfloat16)
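

# For example, randn(2, 4, 8) yields a bf16 CUDA tensor of shape (2, 4, 8)
# whose entries have magnitude at most 1 / (4 * 8).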


def gmm(a, b, batch_sizes, trans_b=False):
    """Reference grouped matmul: one dense matmul per group, concatenated."""
    batch_sizes = batch_sizes.cpu().numpy()

    out = []
    start = 0
    for i, size in enumerate(batch_sizes):
        rhs = b[i, :, :].t() if trans_b else b[i, :, :]
        out.append(a[start : start + size, :] @ rhs)
        start += size
    return torch.cat(out)
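

# Shape example for the reference above: with z=2 groups of m=4 rows each and
# k=n=4, `a` is (8, 4), `b` is (2, 4, 4), batch_sizes is tensor([4, 4]), and
# the result is (8, 4), where row block i equals a[4*i : 4*i + 4] @ b[i].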


@pytest.mark.parametrize(
    "z,m,n,k",
    [
        (1, 4, 4, 4),
        (2, 4, 4, 4),
        (1, 16, 16, 16),
        (4, 16, 16, 16),
        (1, 128, 128, 128),
    ],
)
def test_gmm_forward_backward(z, m, n, k):
    trans_b = False

    torch.manual_seed(0)
    a = randn(z, m, k).view(-1, k)
    b = randn(z, k, n) if not trans_b else randn(z, n, k)
    batch_sizes = torch.tensor([m] * z)

    a.requires_grad_(True)
    b.requires_grad_(True)
    a_ref = a.detach().clone().requires_grad_(True)
    b_ref = b.detach().clone().requires_grad_(True)

    out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)
    expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b)

    assert_close(out, expected_out, rtol=1e-2, atol=1e-2)

    out.sum().backward()
    expected_out.sum().backward()

    # bf16 gradients accumulate enough rounding error that assert_close would
    # be too strict; bound the worst-case element-wise difference instead.
    a_grad_diff = (a.grad - a_ref.grad).abs().max().item()
    b_grad_diff = (b.grad - b_ref.grad).abs().max().item()
    assert a_grad_diff < 0.15, f"a.grad max diff {a_grad_diff:.4f} exceeds tolerance"
    assert b_grad_diff < 0.15, f"b.grad max diff {b_grad_diff:.4f} exceeds tolerance"


def test_gmm_sequence_no_state_contamination():
    """Run several shapes back-to-back to catch state (cached workspaces,
    kernel arguments, etc.) leaking from one gmm call into the next."""
    trans_b = False
    sequences = [
        (1, 4, 4, 4),
        (2, 4, 4, 4),
        (1, 16, 16, 16),
        (4, 16, 16, 16),
    ]

    for z, m, n, k in sequences:
        torch.manual_seed(0)
        a = randn(z, m, k).view(-1, k)
        b = randn(z, k, n) if not trans_b else randn(z, n, k)
        batch_sizes = torch.tensor([m] * z)

        out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)
        expected_out = gmm(a, b, batch_sizes, trans_b)

        assert_close(out, expected_out, rtol=1e-2, atol=1e-2)
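

# A hedged sketch of trans_b coverage: the reference gmm() above already
# handles trans_b=True, and this assumes megablocks.gg_ops.gmm accepts the
# flag the same way (passed positionally, as in the tests above). If the
# staged build lacks the transposed path, mark this xfail instead.
def test_gmm_forward_trans_b():
    torch.manual_seed(0)
    z, m, n, k = 2, 16, 16, 16
    a = randn(z, m, k).view(-1, k)
    b = randn(z, n, k)  # (n, k) layout, consumed transposed via b[i].t()
    batch_sizes = torch.tensor([m] * z)

    out = megablocks.gg_ops.gmm(a, b, batch_sizes, True)
    expected_out = gmm(a, b, batch_sizes, trans_b=True)

    assert_close(out, expected_out, rtol=1e-2, atol=1e-2)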