import enum
import functools
import os
import subprocess
import sys

import torch

from .numerics import MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5


def assert_equal(ref, tri):
    """Assert exact equality between `ref` and `tri` (element-wise for tensors)."""
    if isinstance(ref, torch.Tensor):
        assert torch.all(ref == tri)
    else:
        assert ref == tri


def assert_close(ref, tri, maxtol=None, rmstol=None, description="--", verbose=True):
    """
    Compare reference values against obtained values, using `maxtol` as the
    maximum relative-error tolerance and `rmstol` as the RMS relative-error
    tolerance.
    """
    if tri.dtype.itemsize == 1:
        # For 8-bit (float8) results, round the reference to the low-precision
        # dtype; if the reference was already in that dtype, require bitwise
        # equality and stop here.
        ref_as_type = ref.to(tri.dtype)
        if ref.dtype == tri.dtype:
            assert torch.all(ref_as_type == tri)
            return
        ref = ref_as_type

    if maxtol is None:
        maxtol = 2e-2
    if rmstol is None:
        rmstol = 4e-3

    ref = ref.to(torch.float32).detach()
    tri = tri.to(torch.float32).detach()
    assert ref.shape == tri.shape, f"Tensors must have the same shape {ref.shape=} {tri.shape=}"

    # Infinities must match exactly; zero them out so they do not pollute the
    # finite-error statistics below.
    inf_mask_ref = torch.isinf(ref)
    inf_mask_tri = torch.isinf(tri)
    assert torch.equal(inf_mask_ref, inf_mask_tri), "Tensors must have the same infinite elements"
    refn = torch.where(inf_mask_ref, 0, ref)
    trin = torch.where(inf_mask_tri, 0, tri)

    # Normalize by the largest finite reference magnitude so the comparison is
    # scale-invariant; eps guards against division by zero for all-zero inputs.
    eps = 1.0e-30
    multiplier = 1.0 / (torch.max(torch.abs(refn)) + eps)
    refn *= multiplier
    trin *= multiplier

    ref_rms = torch.sqrt(torch.square(refn).mean()) + eps

    # Relative error, with small reference entries damped by the RMS magnitude
    # so that noise around zero does not dominate.
    rel_err = torch.abs(refn - trin) / torch.maximum(ref_rms, torch.abs(refn))
    max_err = torch.max(rel_err).item()
    rms_err = torch.sqrt(torch.square(rel_err).mean()).item()

    if verbose:
        print("%s maximum relative error = %s (threshold = %s)" % (description, max_err, maxtol))
        print("%s RMS relative error = %s (threshold = %s)" % (description, rms_err, rmstol))

    if max_err > maxtol:
        bad_idxs = torch.nonzero(rel_err > maxtol)
        num_nonzero = bad_idxs.size(0)
        bad_idxs = bad_idxs[:1000]  # report at most the first 1000 offenders
        print("%d / %d mismatched elements (shape = %s) at coords %s" %
              (num_nonzero, rel_err.numel(), tuple(rel_err.shape), bad_idxs.tolist()))

        bad_idxs = bad_idxs.unbind(-1)
        print("ref values: ", ref[tuple(bad_idxs)].cpu())
        print("tri values: ", tri[tuple(bad_idxs)].cpu())

    assert max_err <= maxtol
    assert rms_err <= rmstol
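

# Example usage of `assert_close` (a minimal sketch; `ref_out` and `tri_out`
# are hypothetical stand-ins for a reference result and a kernel result):
#
#   ref_out = torch.randn(128, 128)
#   tri_out = ref_out + 1e-3 * torch.randn_like(ref_out)
#   assert_close(ref_out, tri_out, maxtol=2e-2, rmstol=4e-3, description="demo")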


class ComputeSanitizerTool(enum.Enum):
    MEMCHECK = "memcheck"
    RACECHECK = "racecheck"
    SYNCCHECK = "synccheck"
    INITCHECK = "initcheck"


def compute_sanitizer(**target_kwargs):
    """
    Decorator that runs a test under compute-sanitizer, with the PyTorch caching
    allocator disabled, to expose potential memory access errors.
    The decorated test must accept a (possibly unused) `request_fixture` argument
    bound to pytest's `request` fixture.
    If a `run_sanitizer` argument is present and set to False, the sanitizer is not run.
    Running a test under compute-sanitizer launches a subprocess and is slow,
    so use it sparingly.
    """

    def decorator(test_fn):

        @functools.wraps(test_fn)
        def wrapper(*args, **kwargs):
            if os.environ.get("SKIP_COMPUTE_SANITIZER") == "1":
                test_fn(*args, **kwargs)
                return

            import psutil

            if target_kwargs.pop("clear_torch_cache", False):
                # Release cached allocations so the sanitized run starts from a
                # clean allocator state.
                torch.cuda.empty_cache()
            tools_to_check = target_kwargs.pop("tools_to_check", [ComputeSanitizerTool.MEMCHECK])
            assert isinstance(tools_to_check, list), f"{tools_to_check=}"
            assert all(tool in ComputeSanitizerTool for tool in tools_to_check), (
                f"{[tool for tool in tools_to_check if tool not in ComputeSanitizerTool]=}")

            # Re-run the test under compute-sanitizer only when the decorator's
            # filter kwargs match the test's parameters and we are not already
            # running inside a sanitized process.
            ppid_name = psutil.Process(os.getppid()).exe()
            run_compute_sanitizer = target_kwargs.items() <= kwargs.items()
            if "run_sanitizer" in kwargs:
                run_compute_sanitizer &= kwargs["run_sanitizer"]
            if run_compute_sanitizer and "compute-sanitizer" not in ppid_name:
                for tool in tools_to_check:
                    path = os.path.realpath(test_fn.__globals__["__file__"])

                    env = {
                        "PATH": os.environ["PATH"],
                        "PYTORCH_NO_CUDA_MEMORY_CACHING": "1",
                        "TORCH_SHOW_CPP_STACKTRACES": "1",
                        "CUDA_LAUNCH_BLOCKING": "1",
                    }
                    if "CUDA_VISIBLE_DEVICES" in os.environ:
                        env["CUDA_VISIBLE_DEVICES"] = os.environ["CUDA_VISIBLE_DEVICES"]
                    assert "request_fixture" in kwargs, (
                        "memcheck'ed test must have a (possibly unused) `request_fixture` fixture")
                    test_id = kwargs["request_fixture"].node.callspec.id
                    test_node = f"{path}::{test_fn.__name__}[{test_id}]"
                    cmd = [
                        "compute-sanitizer",
                        "--target-processes=application-only",
                        "--destroy-on-device-error=context",
                        f"--tool={tool.value}",
                        sys.executable,
                        "-m",
                        "pytest",
                        "-vsx",
                        test_node,
                    ]
                    for opt in ["--update_checksum", "--ignore_checksum_error"]:
                        if opt in sys.argv:
                            cmd.append(opt)
                    out = subprocess.run(
                        cmd,
                        stdout=subprocess.PIPE,
                        stderr=subprocess.STDOUT,
                        env=env,
                    )
                    test_output = out.stdout
                    if isinstance(test_output, bytes):
                        test_output = test_output.decode()
                    sanitizer_ok = ("ERROR SUMMARY: 0 errors" in test_output
                                    or "RACECHECK SUMMARY: 0 hazards displayed" in test_output)

                    fail = False
                    if not sanitizer_ok:
                        print("compute-sanitizer returned an error")
                        fail = True
                    elif out.returncode != 0:
                        print(
                            "The test failed for some other reason: consider running without compute-sanitizer to verify."
                        )
                        print(f"{out.returncode=}")
                        fail = True

                    if fail:
                        print("*****************************************************")
                        print("******************** TEST OUTPUT ********************")
                        print("*****************************************************")
                        print(test_output)
                        print("*****************************************************")
                        print("****************** TEST OUTPUT END ******************")
                        print("*****************************************************")
                        assert False, "compute-sanitizer reported a failure"
            else:
                test_fn(*args, **kwargs)

        return wrapper

    return decorator
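

# Example usage of `compute_sanitizer` (a minimal sketch; the test name and
# parametrization are hypothetical, and `request_fixture` is assumed to be a
# fixture exposing pytest's built-in `request`):
#
#   @pytest.mark.parametrize("n", [128, 256])
#   @compute_sanitizer(tools_to_check=[ComputeSanitizerTool.MEMCHECK])
#   def test_my_kernel(n, request_fixture):
#       ...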


def compute_actual_scale(x, dtype):
    """Return the scale mapping the max magnitude of `x` onto the largest finite value of `dtype`."""
    max_finite = {
        torch.float8_e5m2: MAX_FINITE_FLOAT8E5,
        torch.float8_e4m3fn: MAX_FINITE_FLOAT8E4NV,
        torch.float8_e4m3fnuz: MAX_FINITE_FLOAT8E4B8,
    }[dtype]
    return x.abs().max() / max_finite
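

# Example usage of `compute_actual_scale` (a minimal sketch; the tensor and
# dtype are arbitrary):
#
#   x = torch.randn(1024)
#   scale = compute_actual_scale(x, torch.float8_e4m3fn)
#   x_fp8 = (x / scale).to(torch.float8_e4m3fn)  # now spans the full fp8 range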