| | |
| | """Enhanced numerical diagnostic for megablocks.gg_ops.gmm on ROCm builds.""" |
| |
|
| | import pathlib |
| | import sys |
| | from typing import Optional |
| |
|
| | import torch |
| |
|
| |
|
def detect_variant(root: pathlib.Path) -> str:
    """Return the name of the staged build directory under ``root / "build"``.

    Resolution order:

    1. If ``root/kernels/utils.py`` exists, import ``kernels.utils`` (with
       ``root`` temporarily at the front of ``sys.path``) and call its
       ``build_variant()``. The answer is used only if ``build/<variant>``
       is an existing directory.
    2. Otherwise glob ``build/`` for ``torch*-rocm64-*`` directories, then for
       ``torch*-cu*`` directories, taking the lexicographically first match.

    Args:
        root: Repository root that contains the ``build`` directory.

    Returns:
        The directory name of the selected build variant.

    Raises:
        SystemExit: If no variant can be determined.
    """
    build_dir = root / "build"
    variant: Optional[str] = None

    if (root / "kernels" / "utils.py").exists():
        root_entry = str(root)
        sys.path.insert(0, root_entry)
        try:
            from kernels.utils import build_variant as _build_variant

            variant = _build_variant()
        except Exception:
            # Deliberate best effort: any failure in the helper falls back to
            # globbing the build directory below.
            variant = None
        finally:
            # Remove our entry by value rather than position, in case the
            # import itself pushed something onto the front of sys.path.
            try:
                sys.path.remove(root_entry)
            except ValueError:
                pass

    # A variant name that was never staged would make the later import pick up
    # whatever other `megablocks` is on sys.path; ignore it and glob instead.
    if variant is not None and not (build_dir / variant).is_dir():
        variant = None

    if variant is None:
        # Path.glob returns a (always truthy) generator, so each pattern must be
        # materialized and checked on its own for the fallback to work.
        for pattern in ("torch*-rocm64-*", "torch*-cu*"):
            candidates = sorted(p for p in build_dir.glob(pattern) if p.is_dir())
            if candidates:
                variant = candidates[0].name
                break

    if variant is None:
        raise SystemExit("Could not determine build variant; run build.py first.")

    return variant
| |
|
| |
|
def _print_range(label: str, tensor: torch.Tensor) -> None:
    """Print the min/max over the finite entries of ``tensor``.

    NaN entries propagate through ``min``/``max``, so they are excluded here.
    This lets a partially-NaN tensor still report a meaningful range.
    """
    values = tensor.detach()
    finite = values[torch.isfinite(values)]
    if finite.numel() == 0:
        print(f"{label} has no finite values")
        return
    print(f"{label} range: [{finite.min().item():.6f}, {finite.max().item():.6f}]")


def _max_diffs(actual: torch.Tensor, expected: torch.Tensor) -> tuple[float, float]:
    """Return ``(max abs diff, max rel diff)`` of ``actual`` vs ``expected``.

    Both tensors are upcast to float32 first, so low-precision inputs (e.g.
    bf16) do not round the differences themselves. The result is
    ``(nan, nan)`` if any elementwise difference is NaN, and ``(0.0, 0.0)``
    for empty tensors. The relative diff divides by ``|expected| + 1e-9``.
    """
    actual32 = actual.detach().float()
    expected32 = expected.detach().float()
    delta = (actual32 - expected32).abs()
    if delta.numel() == 0:
        return 0.0, 0.0
    if torch.isnan(delta).any():
        return float("nan"), float("nan")
    rel = delta / (expected32.abs() + 1e-9)
    return delta.max().item(), rel.max().item()


def main() -> None:
    """Compare ``megablocks.gg_ops.gmm`` against the reference ``gmm`` from tests.

    Steps:

    - Locate the staged build via ``detect_variant()``.
    - Put it on ``sys.path`` ahead of the repository root.
    - Run the forward and backward passes of both implementations on identical
      inputs.
    - Print NaN checks, value ranges, input-mutation checks and max abs/rel
      differences.

    Raises:
        SystemExit: If the staged build directory does not exist.
    """
    repo_root = pathlib.Path(__file__).resolve().parent.parent
    variant = detect_variant(repo_root)
    staged_dir = repo_root / "build" / variant
    if not staged_dir.is_dir():
        raise SystemExit(f"Staged build directory not found: {staged_dir}")

    # insert(0) places each entry at the front, so the last one inserted wins.
    # Insert repo_root first so staged_dir ends up ahead of it. The staged build
    # then shadows any other `megablocks` reachable from the repo root, while
    # repo_root stays importable for `tests.test_gg`.
    for entry in (str(repo_root), str(staged_dir)):
        if entry not in sys.path:
            sys.path.insert(0, entry)

    import megablocks
    from tests.test_gg import gmm, randn

    print(f"Using staged variant: {variant}")
    print(f"megablocks module: {megablocks.__file__}")

    torch.manual_seed(0)

    z = m = n = k = 128
    trans_b = False

    a = randn(z, m, k).view(-1, k)
    b = randn(z, k, n) if not trans_b else randn(z, n, k)
    batch_sizes = torch.tensor([m] * z, device="cpu")

    print(f"Input a has NaNs: {torch.isnan(a).any().item()}")
    print(f"Input b has NaNs: {torch.isnan(b).any().item()}")
    _print_range("Input a", a)
    _print_range("Input b", b)

    # Pristine snapshots for the input-mutation check. They are kept separate
    # from a_ref/b_ref because those are handed to the reference implementation.
    a_before = a.detach().clone()
    b_before = b.detach().clone()

    a.requires_grad_(True)
    b.requires_grad_(True)
    a_ref = a.detach().clone().requires_grad_(True)
    b_ref = b.detach().clone().requires_grad_(True)

    ref = gmm(a_ref, b_ref, batch_sizes.cpu(), trans_b)
    print("Reference computation completed")
    print(f"ref has NaNs: {torch.isnan(ref).any().item()}")
    _print_range("ref", ref)

    print("Running megablocks.gg_ops.gmm...")
    out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)
    print("megablocks computation completed")

    print(f"out has NaNs: {torch.isnan(out).any().item()}")
    _print_range("out", out)

    # Compare whole tensors, not just a corner, to catch in-place writes anywhere.
    print(f"Input a modified: {not torch.equal(a.detach(), a_before)}")
    print(f"Input b modified: {not torch.equal(b.detach(), b_before)}")

    if out.shape != ref.shape:
        print(f"shape mismatch: out {tuple(out.shape)} vs ref {tuple(ref.shape)}")
    else:
        forward_abs, forward_rel = _max_diffs(out, ref)
        print(f"forward max abs diff: {forward_abs:.6e}")
        print(f"forward max rel diff: {forward_rel:.6e}")

    if out.requires_grad:
        out.sum().backward()
    else:
        print("out does not require grad; skipping megablocks backward")
    ref.sum().backward()

    for label, got, want in (("a", a.grad, a_ref.grad), ("b", b.grad, b_ref.grad)):
        if got is None or want is None:
            print(
                f"{label} grad missing "
                f"(megablocks: {got is not None}, reference: {want is not None})"
            )
            continue
        print(f"{label}.grad has NaNs: {torch.isnan(got).any().item()}")
        grad_abs, _ = _max_diffs(got, want)
        print(f"{label} grad max abs diff: {grad_abs:.6e}")
|
| |
|
if __name__ == "__main__":
    # Script entry point. main() returns None, so the process exits with status 0
    # unless main() raises (e.g. SystemExit from variant detection).
    sys.exit(main())