File size: 3,330 Bytes
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys
import inspect
import pytest
from bitblas.base import DefaultPolicy, TensorCorePolicy
from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags
from tvm.testing.utils import *
# Convenience entry point so that a single test file can be executed directly.
def main():
    """Run pytest on the file of whoever called this function, then exit.

    The path of the calling module is looked up from the caller's stack frame
    and handed to ``pytest.main`` together with any extra command-line
    arguments. The process then exits with the value ``pytest.main`` returned.
    """
    # Frame 1 is the direct caller of main(); keep this lookup inline so the
    # frame depth stays correct.
    caller_path = inspect.getsourcefile(sys._getframe(1))
    pytest_args = [caller_path, *sys.argv[1:]]
    exit_code = pytest.main(pytest_args)
    sys.exit(exit_code)
def debug_with_schedule(func, arch, sch_rule):
    """Apply ``sch_rule`` to ``func`` using the first config a policy emits.

    A ``DefaultPolicy`` is built first. If ``get_tensorized_func_and_tags``
    succeeds and returns truthy tags, a ``TensorCorePolicy`` built from the
    tensorized function replaces it. Any exception raised by the analysis is
    treated as "no tags", so the default policy is kept.

    Returns whatever ``sch_rule.apply_config(func, config)`` returns for the
    first emitted config.
    """
    fallback_policy = DefaultPolicy(func=func, arch=arch)

    try:
        analysis = get_tensorized_func_and_tags(func, arch.target)
        tensorized_func, tags = analysis
    except Exception:
        # Best effort: a failed analysis just means the default policy is used.
        tags = None

    if tags:
        chosen_policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags)
    else:
        chosen_policy = fallback_policy

    # Ask for a single candidate and schedule with it.
    candidate_configs = chosen_policy.emit_config(1)
    return sch_rule.apply_config(func, candidate_configs[0])
def torch_assert_close(tensor_a,
                       tensor_b,
                       rtol=1e-2,
                       atol=1e-3,
                       max_mismatched_ratio=0.001,
                       verbose=False):
    """
    Custom function to assert that two tensors are "close enough," allowing a specified
    percentage of mismatched elements.

    An element matches when ``|a - b|`` is finite and ``|a - b| <= atol + rtol * |b|``,
    or when the two values are exactly equal (so identical infinities match).
    NaN values never match, and a finite value never matches an infinite one.

    Parameters:
    ----------
    tensor_a : torch.Tensor
        The first tensor to compare.
    tensor_b : torch.Tensor
        The second tensor to compare.
    rtol : float, optional
        Relative tolerance for comparison. Default is 1e-2.
    atol : float, optional
        Absolute tolerance for comparison. Default is 1e-3.
    max_mismatched_ratio : float, optional
        Maximum ratio of mismatched elements allowed (relative to the total number of elements).
        Default is 0.001 (0.1% of total elements).
    verbose : bool, optional
        If True, print the mismatch count, the total element count and the allowed
        number of mismatches. Default is False.

    Returns:
    -------
    bool
        True when the number of mismatched elements is within the allowed budget.

    Raises:
    -------
    AssertionError:
        If the ratio of mismatched elements exceeds `max_mismatched_ratio`.
    """
    import torch

    # Element-wise absolute difference and the per-element tolerance.
    diff = torch.abs(tensor_a - tensor_b)
    max_diff = atol + rtol * torch.abs(tensor_b)

    # A plain `diff > max_diff` is False whenever `diff` is NaN, which would let
    # NaN outputs slip through unnoticed, and `inf > inf` is False as well, so a
    # finite value compared against an infinite reference would also pass.
    # Require a finite difference within tolerance, or exact equality (which
    # keeps matching infinities acceptable).
    within_tolerance = torch.isfinite(diff) & (diff <= max_diff)
    exactly_equal = tensor_a == tensor_b
    mismatched = ~(within_tolerance | exactly_equal)

    # Count the number of mismatched elements
    num_mismatched = mismatched.sum().item()

    # Calculate the total number of elements in the tensor
    total_elements = tensor_a.numel()

    # Compute the allowed mismatched elements based on the ratio
    max_allowed_mismatched = int(total_elements * max_mismatched_ratio)

    # Print debug information about the mismatch
    if verbose:
        print(f"Number of mismatched elements: {num_mismatched} / {total_elements} "
              f"(allowed: {max_allowed_mismatched})")

    # Check if the number of mismatched elements exceeds the allowed threshold
    if num_mismatched > max_allowed_mismatched:
        raise AssertionError(
            f"Too many mismatched elements: {num_mismatched} > {max_allowed_mismatched} "
            f"({max_mismatched_ratio * 100:.2f}% allowed, but get {num_mismatched / total_elements * 100:.2f}%). "
            f"Greatest absolute difference: {diff.max().item()}, "
            f"Greatest relative difference: {(diff / (torch.abs(tensor_b) + 1e-12)).max().item()}.")
    return True
|