| |
| |
| import sys |
| import inspect |
| import pytest |
| from bitblas.base import DefaultPolicy, TensorCorePolicy |
| from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags |
| from tvm.testing.utils import * |
|
|
|
|
| |
def main():
    """Run pytest on the source file of whoever called ``main()`` and exit.

    Extra command-line arguments (``sys.argv[1:]``) are forwarded to pytest.
    The process then exits with pytest's return code.
    """
    # Depth 1 is the frame that invoked main(); its source file is the test target.
    caller_frame = sys._getframe(1)
    caller_path = inspect.getsourcefile(caller_frame)
    pytest_args = [caller_path, *sys.argv[1:]]
    sys.exit(pytest.main(pytest_args))
|
|
|
|
def debug_with_schedule(func, arch, sch_rule):
    """Schedule ``func`` with the first config proposed by a BitBLAS policy.

    A ``DefaultPolicy`` is always constructed first. If tensorizing ``func``
    for ``arch.target`` succeeds and yields non-empty tags, a
    ``TensorCorePolicy`` over the tensorized function is used instead. One
    config is requested from the chosen policy. The first one is passed to
    ``sch_rule.apply_config`` together with the original ``func``.

    Parameters
    ----------
    func :
        The function to schedule (presumably a TIR PrimFunc -- confirm with callers).
    arch :
        Architecture object handed to the policies; ``arch.target`` is read here.
    sch_rule :
        Object exposing ``apply_config(func, config)``.

    Returns
    -------
    Whatever ``sch_rule.apply_config`` returns.
    """
    fallback_policy = DefaultPolicy(func=func, arch=arch)
    try:
        tensorized, tensor_tags = get_tensorized_func_and_tags(func, arch.target)
    except Exception:
        # Tensorization is best-effort: any failure falls back to DefaultPolicy.
        tensor_tags = None

    if not tensor_tags:
        chosen_policy = fallback_policy
    else:
        chosen_policy = TensorCorePolicy(func=tensorized, arch=arch, tags=tensor_tags)

    first_config = chosen_policy.emit_config(1)[0]
    return sch_rule.apply_config(func, first_config)
|
|
|
|
def torch_assert_close(tensor_a,
                       tensor_b,
                       rtol=1e-2,
                       atol=1e-3,
                       max_mismatched_ratio=0.001,
                       verbose=False,
                       equal_nan=True):
    """
    Custom function to assert that two tensors are "close enough," allowing a specified
    percentage of mismatched elements.

    An element counts as close in three cases:

    - ``|a - b| <= atol + rtol * |b|`` with a finite error. This is the same
      asymmetric criterion as ``torch.isclose``, with ``tensor_b`` as the reference.
    - ``a == b`` exactly. This covers matching infinities.
    - ``equal_nan`` is true and both values are NaN.

    Every other element is a mismatch. That includes NaN against a number,
    infinity against a finite value, and ``+inf`` against ``-inf``.

    Parameters:
    ----------
    tensor_a : torch.Tensor
        The first tensor to compare.
    tensor_b : torch.Tensor
        The second tensor to compare; it is the reference for the relative tolerance.
    rtol : float, optional
        Relative tolerance for comparison. Default is 1e-2.
    atol : float, optional
        Absolute tolerance for comparison. Default is 1e-3.
    max_mismatched_ratio : float, optional
        Maximum ratio of mismatched elements allowed, relative to the total number
        of compared elements. The allowed count is truncated to an integer.
        Default is 0.001 (0.1% of total elements).
    verbose : bool, optional
        If True, print the number of mismatched and allowed elements. Default is False.
    equal_nan : bool, optional
        If True (default), positions where both tensors are NaN count as matches.

    Returns:
    -------
    bool
        True when the number of mismatched elements is within the allowed count.

    Raises:
    -------
    AssertionError:
        If the ratio of mismatched elements exceeds `max_mismatched_ratio`.
    """
    import torch

    # Element-wise absolute error and per-element tolerance (tensor_b is the reference).
    diff = torch.abs(tensor_a - tensor_b)
    max_diff = atol + rtol * torch.abs(tensor_b)

    # Test the "close" side explicitly rather than `diff > max_diff`.
    # - Every comparison with NaN is False, so the inverted form accepted NaN outputs.
    # - `inf > inf` is False, so a finite value against an infinite reference passed.
    # Requiring a finite error fixes both. Exact equality keeps matching infinities
    # close, since their difference is NaN.
    close = (tensor_a == tensor_b) | (torch.isfinite(diff) & (diff <= max_diff))
    if equal_nan:
        close = close | (torch.isnan(tensor_a) & torch.isnan(tensor_b))
    mismatched = ~close

    num_mismatched = mismatched.sum().item()

    # Count over the (possibly broadcast) comparison shape so the ratio stays in
    # [0, 1]. This equals tensor_a.numel() whenever the inputs share a shape.
    total_elements = mismatched.numel()

    max_allowed_mismatched = int(total_elements * max_mismatched_ratio)

    if verbose:
        print(f"Number of mismatched elements: {num_mismatched} / {total_elements} "
              f"(allowed: {max_allowed_mismatched})")

    if num_mismatched > max_allowed_mismatched:
        raise AssertionError(
            f"Too many mismatched elements: {num_mismatched} > {max_allowed_mismatched} "
            f"({max_mismatched_ratio * 100:.2f}% allowed, but get {num_mismatched / total_elements * 100:.2f}%). "
            f"Greatest absolute difference: {diff.max().item()}, "
            f"Greatest relative difference: {(diff / (torch.abs(tensor_b) + 1e-12)).max().item()}.")
    return True
|
|