import importlib.metadata
import subprocess  # noqa: F401  # kept in case other code still relies on it being imported here

from packaging import version
import torch

# Optional dependency probe: record whether triton can be imported.
try:
    import triton.language as tl  # noqa: F401

    import triton  # noqa: F401

    triton_available = True
except ImportError:
    triton_available = False

# Both lookup tables live on the same device: XPU when this torch build
# exposes an available XPU backend, otherwise CPU.
# Only cpu/xpu use these tables for now.
_QUANT_TABLE_DEVICE = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cpu"

# 16-entry float32 lookup table registered under the "nf4" key of CODE.
_NF4_QUANT_TABLE = torch.tensor(
    [
        -1.0,
        -0.6961928009986877,
        -0.5250730514526367,
        -0.39491748809814453,
        -0.28444138169288635,
        -0.18477343022823334,
        -0.09105003625154495,
        0.0,
        0.07958029955625534,
        0.16093020141124725,
        0.24611230194568634,
        0.33791524171829224,
        0.44070982933044434,
        0.5626170039176941,
        0.7229568362236023,
        1.0,
    ],
    dtype=torch.float32,
    device=_QUANT_TABLE_DEVICE,
)

# 16-entry float32 lookup table registered under the "fp4" key of CODE.
_FP4_QUANT_TABLE = torch.tensor(
    [
        0.0000,
        0.0052,
        0.6667,
        1.0000,
        0.3333,
        0.5000,
        0.1667,
        0.2500,
        0.0000,
        -0.0052,
        -0.6667,
        -1.0000,
        -0.3333,
        -0.5000,
        -0.1667,
        -0.2500,
    ],
    dtype=torch.float32,
    device=_QUANT_TABLE_DEVICE,
)

# Maps a quantization type name to its lookup table.
CODE = {"nf4": _NF4_QUANT_TABLE, "fp4": _FP4_QUANT_TABLE}


def get_gaudi_sw_version():
    """
    Returns the installed version of Gaudi SW.

    The version is read from the metadata of the ``habana-torch-plugin``
    distribution installed in the *current* interpreter's environment.
    This avoids spawning a ``pip list | grep`` shell pipeline, which is
    slow at import time, needs ``pip`` and ``grep`` on ``PATH``, and may
    inspect a different environment than the one actually running.

    Returns:
        A ``packaging.version.Version`` for the installed plugin, or
        ``None`` if ``habana-torch-plugin`` is not installed.
    """
    try:
        installed = importlib.metadata.version("habana-torch-plugin")
    except importlib.metadata.PackageNotFoundError:
        # Plugin not installed: same "nothing found" result as before.
        return None

    return version.parse(installed)


# Resolved once at import time so callers can compare against it directly.
GAUDI_SW_VER = get_gaudi_sw_version()