# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys
import inspect
import pytest
from bitblas.base import DefaultPolicy, TensorCorePolicy
from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags
from tvm.testing.utils import *


# pytest.main() wrapper to allow running single test file
def main():
    """Run pytest on the calling source file and exit with pytest's status.

    The file to test is resolved from the caller's frame (one level up), and
    any extra command-line arguments in ``sys.argv[1:]`` are forwarded to
    pytest. This function does not return: it calls ``sys.exit``.
    """
    test_file = inspect.getsourcefile(sys._getframe(1))
    sys.exit(pytest.main([test_file] + sys.argv[1:]))


def debug_with_schedule(func, arch, sch_rule):
    """Schedule ``func`` with the first config emitted by a BitBLAS policy.

    A ``DefaultPolicy`` is used unless ``get_tensorized_func_and_tags``
    succeeds and returns non-empty tags, in which case a ``TensorCorePolicy``
    built from the tensorized function is used instead.

    Parameters:
    ----------
    func :
        The function to schedule (presumably a TVM TIR PrimFunc -- confirm
        against callers).
    arch :
        Target architecture object; its ``.target`` attribute is passed to
        ``get_tensorized_func_and_tags``.
    sch_rule :
        Object exposing ``apply_config(func, config)``.

    Returns:
    -------
    Whatever ``sch_rule.apply_config`` returns for the first emitted config.
    """
    policy = DefaultPolicy(func=func, arch=arch)
    try:
        tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target)
    except Exception:
        # Best-effort: any failure to tensorize falls back to the default policy.
        tags = None
    if tags:
        policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags)
    configs = policy.emit_config(1)
    return sch_rule.apply_config(func, configs[0])


def torch_assert_close(tensor_a,
                       tensor_b,
                       rtol=1e-2,
                       atol=1e-3,
                       max_mismatched_ratio=0.001,
                       verbose=False):
    """
    Custom function to assert that two tensors are "close enough," allowing a specified
    percentage of mismatched elements.

    An element pair is considered close when
    ``|tensor_a - tensor_b| <= atol + rtol * |tensor_b|`` (the same criterion that
    ``torch.isclose`` uses). NaN never compares close, and non-finite values are close
    only when exactly equal, so NaN/Inf results are reported as mismatches instead of
    silently passing.

    Parameters:
    ----------
    tensor_a : torch.Tensor
        The first tensor to compare.
    tensor_b : torch.Tensor
        The second tensor to compare (the reference; the relative tolerance is
        scaled by its magnitude).
    rtol : float, optional
        Relative tolerance for comparison. Default is 1e-2.
    atol : float, optional
        Absolute tolerance for comparison. Default is 1e-3.
    max_mismatched_ratio : float, optional
        Maximum ratio of mismatched elements allowed (relative to the total number of
        compared elements). Default is 0.001 (0.1% of total elements).
    verbose : bool, optional
        If True, print the mismatch count and the allowed threshold. Default is False.

    Returns:
    -------
    bool
        True when the number of mismatched elements is within the allowed threshold.

    Raises:
    -------
    AssertionError:
        If the ratio of mismatched elements exceeds `max_mismatched_ratio`.
    """
    import torch

    # Compute the absolute difference between the two tensors
    diff = torch.abs(tensor_a - tensor_b)
    # Compute the maximum allowable difference for each element
    max_diff = atol + rtol * torch.abs(tensor_b)

    # `diff <= max_diff` is False whenever either side is NaN, so NaN results count as
    # mismatches (a `diff > max_diff` test would silently accept them). An infinite
    # reference value makes the tolerance infinite, so a finite tolerance is required
    # here; exact equality below then accepts matching infinities.
    close = (diff <= max_diff) & torch.isfinite(max_diff)
    close |= tensor_a == tensor_b
    mismatched = ~close

    # Count the number of mismatched elements
    num_mismatched = mismatched.sum().item()
    # Count over the (possibly broadcast) comparison shape so the ratio's numerator and
    # denominator agree; identical to tensor_a.numel() when the shapes match.
    total_elements = mismatched.numel()

    # Compute the allowed mismatched elements based on the ratio
    max_allowed_mismatched = int(total_elements * max_mismatched_ratio)

    # Print debug information about the mismatch
    if verbose:
        print(f"Number of mismatched elements: {num_mismatched} / {total_elements} "
              f"(allowed: {max_allowed_mismatched})")

    # Check if the number of mismatched elements exceeds the allowed threshold
    if num_mismatched > max_allowed_mismatched:
        raise AssertionError(
            f"Too many mismatched elements: {num_mismatched} > {max_allowed_mismatched} "
            f"({max_mismatched_ratio * 100:.2f}% allowed, but get {num_mismatched / total_elements * 100:.2f}%). "
            f"Greatest absolute difference: {diff.max().item()}, "
            f"Greatest relative difference: {(diff / (torch.abs(tensor_b) + 1e-12)).max().item()}.")
    return True