import contextlib
import functools
import os
from enum import Enum
from functools import lru_cache
from typing import Any, Callable, Dict, Literal, Optional, Tuple

import torch
import triton
from packaging import version


def tensor_cache(
    fn: Callable[..., torch.Tensor]
) -> Callable[..., torch.Tensor]:
    """
    A decorator that caches the most recent result of a function with tensor inputs.

    This decorator stores the output of the decorated function for the most recent set of input tensors.
    If the function is called again with the exact same tensor objects (compared by identity, not by value),
    it returns the cached result instead of recomputing.

    Args:
        fn (Callable[..., torch.Tensor]):
            The function to be decorated. It should take tensor inputs and return tensor outputs.

    Returns:
        Callable[..., torch.Tensor]:
            A wrapped version of the input function with single-entry caching.
    """
    last_args: Optional[Tuple] = None
    last_kwargs: Optional[Dict] = None
    last_result: Any = None

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        nonlocal last_args, last_kwargs, last_result

        if last_args is not None and last_kwargs is not None:
            if len(args) == len(last_args) and len(kwargs) == len(last_kwargs):
                if all(a is b for a, b in zip(args, last_args)) and \
                        all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()):
                    return last_result

        result = fn(*args, **kwargs)
        last_args, last_kwargs, last_result = args, kwargs, result
        return result

    return wrapper
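
# Illustrative usage sketch (the `prepare_lens` helper below is hypothetical,
# not part of this module). Note the identity-based cache semantics:
#
#     @tensor_cache
#     def prepare_lens(cu_seqlens: torch.Tensor) -> torch.Tensor:
#         return cu_seqlens[1:] - cu_seqlens[:-1]
#
#     cu_seqlens = torch.tensor([0, 4, 10])
#     lens = prepare_lens(cu_seqlens)          # computed
#     lens = prepare_lens(cu_seqlens)          # cache hit: same tensor object
#     lens = prepare_lens(cu_seqlens.clone())  # recomputed: equal values, different object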


def input_guard(
    fn: Callable[..., torch.Tensor]
) -> Callable[..., torch.Tensor]:
    """
    A decorator that makes all input tensors contiguous and enters the device context
    of the first tensor argument before calling the wrapped function.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        contiguous_args = (i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args)
        contiguous_kwargs = {k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()}

        # Find the first tensor among positional args, then keyword args,
        # to determine which device context to enter.
        tensor = None
        for arg in args:
            if isinstance(arg, torch.Tensor):
                tensor = arg
                break
        if tensor is None:
            for value in kwargs.values():
                if isinstance(value, torch.Tensor):
                    tensor = value
                    break

        if tensor is not None:
            # `custom_device_ctx` is defined at the bottom of this module;
            # it is resolved at call time, so the forward reference is safe.
            ctx = custom_device_ctx(tensor.device.index)
        else:
            ctx = contextlib.nullcontext()

        with ctx:
            return fn(*contiguous_args, **contiguous_kwargs)

    return wrapper


# Backward-compatible alias: `contiguous` is the historical name of `input_guard`.
contiguous = input_guard
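
# Illustrative usage sketch (the `fused_op` wrapper below is hypothetical,
# not part of this module):
#
#     @input_guard
#     def fused_op(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
#         # x and y are guaranteed contiguous here, and the active device
#         # matches the device of the first tensor argument.
#         return x + y
#
#     out = fused_op(torch.randn(4, 4).t(), torch.randn(4, 4))  # .t() yields a non-contiguous view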


def require_version(version, hint):
    """
    Perform a runtime check of the dependency versions, using the exact same syntax used by pip.
    """
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(ctx, *args, **kwargs):
            from transformers.utils.versions import require_version
            require_version(version, hint)
            return fn(ctx,
                      *(i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args),
                      **{k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()})
        return wrapper
    return decorator
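
# Illustrative usage sketch: the wrapper expects `ctx` as its first argument,
# so the decorator is meant for the static methods of a custom autograd
# Function (the version spec below is an arbitrary example):
#
#     class MyOp(torch.autograd.Function):
#         @staticmethod
#         @require_version('triton>=3.0', 'Triton >= 3.0 is required for this op.')
#         def forward(ctx, x):
#             return x * 2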


def checkpoint(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return torch.utils.checkpoint.checkpoint(fn, *args, **kwargs)
    return wrapper
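
# Illustrative usage sketch: trading compute for memory by recomputing the
# block during backward (the `mlp_block` function below is hypothetical):
#
#     @checkpoint
#     def mlp_block(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
#         return torch.relu(x @ w) @ w.t()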


@lru_cache(maxsize=None)
def check_pytorch_version(version_s: str = '2.4') -> bool:
    return version.parse(torch.__version__) >= version.parse(version_s)


def _cpu_device_warning():
    import warnings
    warnings.warn('Triton is not supported on the current platform; falling back to CPU.', stacklevel=1)


@lru_cache(maxsize=None)
def get_multiprocessor_count(tensor_idx: int = 0) -> int:
    try:
        # Query the SM count of the device with the given index.
        return torch.cuda.get_device_properties(tensor_idx).multi_processor_count
    except BaseException:
        _cpu_device_warning()
        return -1
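
# Illustrative sketch: the SM count is typically used to cap a kernel's grid
# size (the heuristic below is hypothetical):
#
#     NUM_SM = get_multiprocessor_count()
#     grid = min(triton.cdiv(n, BLOCK), NUM_SM * 4)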


@lru_cache(maxsize=None)
def get_available_device() -> str:
    try:
        return triton.runtime.driver.active.get_current_target().backend
    except BaseException:
        _cpu_device_warning()
        return 'cpu'


@lru_cache(maxsize=None)
def _check_platform() -> Literal['nvidia', 'amd', 'intel', 'musa']:
    device = get_available_device()
    if device == 'cuda':
        return 'nvidia'
    elif device == 'hip':
        return 'amd'
    elif device == 'xpu':
        return 'intel'
    else:
        return device


# Triton reports the backend as 'hip' on AMD GPUs, but ROCm builds of PyTorch
# still expose the device under `torch.cuda`, so map 'hip' back to 'cuda' here.
device = get_available_device() if get_available_device() != 'hip' else 'cuda'
device_torch_lib = getattr(torch, device)
device_platform = _check_platform()

is_amd = (device_platform == 'amd')
is_intel = (device_platform == 'intel')
is_nvidia = (device_platform == 'nvidia')
is_intel_alchemist = (is_intel and 'Intel(R) Arc(TM) A' in torch.xpu.get_device_name(0))
is_nvidia_hopper = (is_nvidia and ('NVIDIA H' in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9))
use_cuda_graph = (is_nvidia and os.environ.get('FLA_USE_CUDA_GRAPH', '0') == '1')

# TF32 tensor cores require compute capability >= 8.0 (Ampere or newer).
is_tf32_supported = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 8)
is_gather_supported = hasattr(triton.language, 'gather')
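
# Illustrative sketch of how these capability flags are typically consumed
# when selecting a kernel path (both functions below are hypothetical):
#
#     if is_gather_supported:
#         out = gather_kernel(x, indices)    # tl.gather-based fast path
#     else:
#         out = fallback_gather(x, indices)  # pure-torch fallback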


def get_all_max_shared_mem():
    try:
        return [
            triton.runtime.driver.active.utils.get_device_properties(i)['max_shared_mem']
            for i in range(device_torch_lib.device_count())
        ]
    except BaseException:
        _cpu_device_warning()
        return [-1]


class Backend(Enum):
    # Maximum shared memory per thread block, in bytes, for each architecture.
    ADA = 101376      # e.g. RTX 4090
    AMPERE = 166912   # e.g. A100
    HOPPER = 232448   # e.g. H100
    DEFAULT = 102400

    @classmethod
    def get_shared_memory(cls, arch: str) -> int:
        try:
            return cls[arch.upper()].value
        except KeyError:
            return cls.DEFAULT.value


@lru_cache(maxsize=None)
def check_shared_mem(arch: str = 'none', tensor_idx: int = 0) -> bool:
    try:
        device_shared_mem_list = get_all_max_shared_mem()
        max_shared_memory = device_shared_mem_list[tensor_idx]
        return max_shared_memory >= Backend.get_shared_memory(arch)
    except Exception:
        return False
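
# Illustrative usage sketch: picking a larger tile size only when the device
# has at least Ampere-class shared memory (the block sizes are arbitrary):
#
#     BLOCK_SIZE = 128 if check_shared_mem('ampere') else 64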


if check_pytorch_version('2.4'):
    device = 'cuda' if device == 'cpu' else device
    autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=device)
    autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=device)

    def custom_device_ctx(index: int):
        return device_torch_lib.device(index)
else:
    assert device == 'cuda', 'Only cuda device is supported for PyTorch version < 2.4.0.'
    # `torch.cuda.amp.custom_fwd/bwd` are the pre-2.4 equivalents of
    # `torch.amp.custom_fwd/bwd(device_type=...)`.
    autocast_custom_fwd = device_torch_lib.amp.custom_fwd
    autocast_custom_bwd = device_torch_lib.amp.custom_bwd

    def custom_device_ctx(index: int):
        return torch.cuda.device(index)
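
# Illustrative usage sketch: the two decorators are applied to the static
# methods of a custom autograd Function so that autocast handles inputs
# consistently on both the 2.4+ and pre-2.4 code paths:
#
#     class ScaledAdd(torch.autograd.Function):
#         @staticmethod
#         @input_guard
#         @autocast_custom_fwd
#         def forward(ctx, x, y):
#             return x + 2 * y
#
#         @staticmethod
#         @autocast_custom_bwd
#         def backward(ctx, grad_output):
#             return grad_output, 2 * grad_output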