diff --git a/.gitattributes b/.gitattributes index daedcf899931fe4e542ddfe9fb151b56070362a9..849cfabdd4167a72817582cce9896dffce13ba44 100644 --- a/.gitattributes +++ b/.gitattributes @@ -127,3 +127,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_ .venv/lib/python3.11/site-packages/vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_ops.so.9 filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/serialize.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/torch/nn/__pycache__/functional.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text diff --git a/.venv/lib/python3.11/site-packages/torch/cuda/__init__.py b/.venv/lib/python3.11/site-packages/torch/cuda/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5d7d77dcd4134a8011937b544a3624dd13fe3c9a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/cuda/__init__.py @@ -0,0 +1,1661 @@ +# mypy: allow-untyped-defs +r""" +This package adds support for CUDA tensor types. + +It implements the same function as CPU tensors, but they utilize +GPUs for computation. + +It is lazily initialized, so you can always import it, and use +:func:`is_available()` to determine if your system supports CUDA. + +:ref:`cuda-semantics` has more details about working with CUDA. +""" + +import importlib +import os +import threading +import traceback +import warnings +from functools import lru_cache +from typing import Any, Callable, cast, List, Optional, Tuple, Union + +import torch +import torch._C +from torch import device as _device +from torch._utils import _dummy_type, _LazySeedTracker, classproperty +from torch.types import Device + +from . 
import gds +from ._utils import _get_device_index +from .graphs import ( + CUDAGraph, + graph, + graph_pool_handle, + is_current_stream_capturing, + make_graphed_callables, +) +from .streams import Event, ExternalStream, Stream + + +try: + from torch._C import _cudart # type: ignore[attr-defined] +except ImportError: + _cudart = None + +_initialized = False +_tls = threading.local() +_initialization_lock = threading.Lock() +_queued_calls: List[ + Tuple[Callable[[], None], List[str]] +] = [] # don't invoke these until initialization occurs +_is_in_bad_fork = getattr(torch._C, "_cuda_isInBadFork", lambda: False) +_device_t = Union[_device, str, int, None] + +_HAS_PYNVML = False +_PYNVML_ERR = None +try: + from torch import version as _version + + try: + if not _version.hip: + import pynvml # type: ignore[import] + else: + import amdsmi # type: ignore[import] + + _HAS_PYNVML = True + except ModuleNotFoundError: + pass + finally: + del _version +except ImportError as err: + _PYNVML_ERR = err # sometimes a lib is installed but the import fails for some other reason, so we log the error for later + +_lazy_seed_tracker = _LazySeedTracker() + +# Define dummy _CudaDeviceProperties type if PyTorch was compiled without CUDA +if hasattr(torch._C, "_CudaDeviceProperties"): + _CudaDeviceProperties = torch._C._CudaDeviceProperties +else: + _CudaDeviceProperties = _dummy_type("_CudaDeviceProperties") # type: ignore[assignment, misc] + +if hasattr(torch._C, "_cuda_exchangeDevice"): + _exchange_device = torch._C._cuda_exchangeDevice +else: + + def _exchange_device(device: int) -> int: + if device < 0: + return -1 + raise RuntimeError("PyTorch was compiled without CUDA support") + + +if hasattr(torch._C, "_cuda_maybeExchangeDevice"): + _maybe_exchange_device = torch._C._cuda_maybeExchangeDevice +else: + + def _maybe_exchange_device(device: int) -> int: + if device < 0: + return -1 + raise RuntimeError("PyTorch was compiled without CUDA support") + + +has_half: bool = True +has_magma: 
bool = torch._C._has_magma + +default_generators: Tuple[torch._C.Generator] = () # type: ignore[assignment] + + +def _is_compiled() -> bool: + r"""Return true if compile with CUDA support.""" + return hasattr(torch._C, "_cuda_getDeviceCount") + + +def _nvml_based_avail() -> bool: + return os.getenv("PYTORCH_NVML_BASED_CUDA_CHECK") == "1" + + +def is_available() -> bool: + r"""Return a bool indicating if CUDA is currently available.""" + if not _is_compiled(): + return False + if _nvml_based_avail(): + # The user has set an env variable to request this availability check that attempts to avoid fork poisoning by + # using NVML at the cost of a weaker CUDA availability assessment. Note that if NVML discovery/initialization + # fails, this assessment falls back to the default CUDA Runtime API assessment (`cudaGetDeviceCount`) + return device_count() > 0 + else: + # The default availability inspection never throws and returns 0 if the driver is missing or can't + # be initialized. This uses the CUDA Runtime API `cudaGetDeviceCount` which in turn initializes the CUDA Driver + # API via `cuInit` + return torch._C._cuda_getDeviceCount() > 0 + + +def is_bf16_supported(including_emulation: bool = True): + r"""Return a bool indicating if the current CUDA/ROCm device supports dtype bfloat16.""" + # Check for ROCm, if true return true, no ROCM_VERSION check required, + # since it is supported on AMD GPU archs. + if torch.version.hip: + return True + + # If CUDA is not available, than it does not support bf16 either + if not is_available(): + return False + + device = torch.cuda.current_device() + + # Check for CUDA version and device compute capability. + # This is a fast way to check for it. + cuda_version = torch.version.cuda + if ( + cuda_version is not None + and int(cuda_version.split(".")[0]) >= 11 + and torch.cuda.get_device_properties(device).major >= 8 + ): + return True + + if not including_emulation: + return False + + # Finally try to create a bfloat16 device. 
+ return _check_bf16_tensor_supported(device) + + +@lru_cache(maxsize=16) +def _check_bf16_tensor_supported(device: _device_t): + try: + torch.tensor([1.0], dtype=torch.bfloat16, device=device) + return True + except Exception: + return False + + +def _sleep(cycles): + torch._C._cuda_sleep(cycles) + + +def _extract_arch_version(arch_string: str): + """Extracts the architecture string from a CUDA version""" + base = arch_string.split("_")[1] + if base.endswith("a"): + base = base[:-1] + return int(base) + + +def _check_capability(): + incorrect_binary_warn = """ + Found GPU%d %s which requires CUDA_VERSION >= %d to + work properly, but your PyTorch was compiled + with CUDA_VERSION %d. Please install the correct PyTorch binary + using instructions from https://pytorch.org + """ + + old_gpu_warn = """ + Found GPU%d %s which is of cuda capability %d.%d. + PyTorch no longer supports this GPU because it is too old. + The minimum cuda capability supported by this library is %d.%d. + """ + + if torch.version.cuda is not None: # on ROCm we don't want this check + CUDA_VERSION = torch._C._cuda_getCompiledVersion() + for d in range(device_count()): + capability = get_device_capability(d) + major = capability[0] + minor = capability[1] + name = get_device_name(d) + current_arch = major * 10 + minor + min_arch = min( + (_extract_arch_version(arch) for arch in torch.cuda.get_arch_list()), + default=35, + ) + if current_arch < min_arch: + warnings.warn( + old_gpu_warn + % (d, name, major, minor, min_arch // 10, min_arch % 10) + ) + + +def _check_cubins(): + incompatible_device_warn = """ +{} with CUDA capability sm_{} is not compatible with the current PyTorch installation. +The current PyTorch install supports CUDA capabilities {}. 
+If you want to use the {} GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/ +""" + if torch.version.cuda is None: # on ROCm we don't want this check + return + arch_list = get_arch_list() + if len(arch_list) == 0: + return + supported_sm = [_extract_arch_version(arch) for arch in arch_list if "sm_" in arch] + for idx in range(device_count()): + cap_major, cap_minor = get_device_capability(idx) + # NVIDIA GPU compute architectures are backward compatible within major version + supported = any(sm // 10 == cap_major for sm in supported_sm) + if not supported: + device_name = get_device_name(idx) + capability = cap_major * 10 + cap_minor + warnings.warn( + incompatible_device_warn.format( + device_name, capability, " ".join(arch_list), device_name + ) + ) + + +def is_initialized(): + r"""Return whether PyTorch's CUDA state has been initialized.""" + return _initialized and not _is_in_bad_fork() + + +def _lazy_call(callable, **kwargs): + if is_initialized(): + callable() + else: + # TODO(torch_deploy): this accesses linecache, which attempts to read the + # file system to get traceback info. Patch linecache or do something + # else here if this ends up being important. + global _lazy_seed_tracker + if kwargs.get("seed_all", False): + _lazy_seed_tracker.queue_seed_all(callable, traceback.format_stack()) + elif kwargs.get("seed", False): + _lazy_seed_tracker.queue_seed(callable, traceback.format_stack()) + else: + # Don't store the actual traceback to avoid memory cycle + _queued_calls.append((callable, traceback.format_stack())) + + +_lazy_call(_check_capability) +_lazy_call(_check_cubins) + + +class DeferredCudaCallError(Exception): + pass + + +OutOfMemoryError = torch._C.OutOfMemoryError + + +def init(): + r"""Initialize PyTorch's CUDA state. 
+ + You may need to call this explicitly if you are interacting with + PyTorch via its C API, as Python bindings for CUDA functionality + will not be available until this initialization takes place. + Ordinary users should not need this, as all of PyTorch's CUDA methods + automatically initialize CUDA state on-demand. + + Does nothing if the CUDA state is already initialized. + """ + _lazy_init() + + +def _lazy_init(): + global _initialized, _queued_calls + if is_initialized() or hasattr(_tls, "is_initializing"): + return + with _initialization_lock: + # We be double-checked locking, boys! This is OK because + # the above test was GIL protected anyway. The inner test + # is for when a thread blocked on some other thread which was + # doing the initialization; when they get the lock, they will + # find there is nothing left to do. + if is_initialized(): + return + # It is important to prevent other threads from entering _lazy_init + # immediately, while we are still guaranteed to have the GIL, because some + # of the C calls we make below will release the GIL + if _is_in_bad_fork(): + raise RuntimeError( + "Cannot re-initialize CUDA in forked subprocess. To use CUDA with " + "multiprocessing, you must use the 'spawn' start method" + ) + if not hasattr(torch._C, "_cuda_getDeviceCount"): + raise AssertionError("Torch not compiled with CUDA enabled") + if _cudart is None: + raise AssertionError( + "libcudart functions unavailable. It looks like you have a broken build?" + ) + # This function throws if there's a driver initialization error, no GPUs + # are found or any other error occurs + if "CUDA_MODULE_LOADING" not in os.environ: + os.environ["CUDA_MODULE_LOADING"] = "LAZY" + torch._C._cuda_init() + # Some of the queued calls may reentrantly call _lazy_init(); + # we need to just return without initializing in that case. + # However, we must not let any *other* threads in! 
+ _tls.is_initializing = True + + for calls in _lazy_seed_tracker.get_calls(): + if calls: + _queued_calls.append(calls) + + try: + for queued_call, orig_traceback in _queued_calls: + try: + queued_call() + except Exception as e: + msg = ( + f"CUDA call failed lazily at initialization with error: {str(e)}\n\n" + f"CUDA call was originally invoked at:\n\n{''.join(orig_traceback)}" + ) + raise DeferredCudaCallError(msg) from e + finally: + delattr(_tls, "is_initializing") + _initialized = True + + +def cudart(): + r"""Retrieves the CUDA runtime API module. + + + This function initializes the CUDA runtime environment if it is not already + initialized and returns the CUDA runtime API module (_cudart). The CUDA + runtime API module provides access to various CUDA runtime functions. + + Args: + ``None`` + + Returns: + module: The CUDA runtime API module (_cudart). + + Raises: + RuntimeError: If CUDA cannot be re-initialized in a forked subprocess. + AssertionError: If PyTorch is not compiled with CUDA support or if libcudart functions are unavailable. 
+ + Example of CUDA operations with profiling: + >>> import torch + >>> from torch.cuda import cudart, check_error + >>> import os + >>> + >>> os.environ['CUDA_PROFILE'] = '1' + >>> + >>> def perform_cuda_operations_with_streams(): + >>> stream = torch.cuda.Stream() + >>> with torch.cuda.stream(stream): + >>> x = torch.randn(100, 100, device='cuda') + >>> y = torch.randn(100, 100, device='cuda') + >>> z = torch.mul(x, y) + >>> return z + >>> + >>> torch.cuda.synchronize() + >>> print("====== Start nsys profiling ======") + >>> check_error(cudart().cudaProfilerStart()) + >>> with torch.autograd.profiler.emit_nvtx(): + >>> result = perform_cuda_operations_with_streams() + >>> print("CUDA operations completed.") + >>> check_error(torch.cuda.cudart().cudaProfilerStop()) + >>> print("====== End nsys profiling ======") + + To run this example and save the profiling information, execute: + >>> $ nvprof --profile-from-start off --csv --print-summary -o trace_name.prof -f -- python cudart_test.py + + This command profiles the CUDA operations in the provided script and saves + the profiling information to a file named `trace_name.prof`. + The `--profile-from-start off` option ensures that profiling starts only + after the `cudaProfilerStart` call in the script. + The `--csv` and `--print-summary` options format the profiling output as a + CSV file and print a summary, respectively. + The `-o` option specifies the output file name, and the `-f` option forces the + overwrite of the output file if it already exists. 
+ """ + _lazy_init() + return _cudart + + +class cudaStatus: + SUCCESS: int = 0 + ERROR_NOT_READY: int = 34 + + +class CudaError(RuntimeError): + def __init__(self, code: int) -> None: + msg = _cudart.cudaGetErrorString(_cudart.cudaError(code)) + super().__init__(f"{msg} ({code})") + + +def check_error(res: int) -> None: + if res != _cudart.cudaError.success: + raise CudaError(res) + + +class _DeviceGuard: + def __init__(self, index: int): + self.idx = index + self.prev_idx = -1 + + def __enter__(self): + self.prev_idx = torch.cuda._exchange_device(self.idx) + + def __exit__(self, type: Any, value: Any, traceback: Any): + self.idx = torch.cuda._maybe_exchange_device(self.prev_idx) + return False + + +class device: + r"""Context-manager that changes the selected device. + + Args: + device (torch.device or int): device index to select. It's a no-op if + this argument is a negative integer or ``None``. + """ + + def __init__(self, device: Any): + self.idx = _get_device_index(device, optional=True) + self.prev_idx = -1 + + def __enter__(self): + self.prev_idx = torch.cuda._exchange_device(self.idx) + + def __exit__(self, type: Any, value: Any, traceback: Any): + self.idx = torch.cuda._maybe_exchange_device(self.prev_idx) + return False + + +class device_of(device): + r"""Context-manager that changes the current device to that of given object. + + You can use both tensors and storages as arguments. If a given object is + not allocated on a GPU, this is a no-op. + + Args: + obj (Tensor or Storage): object allocated on the selected device. + """ + + def __init__(self, obj): + idx = obj.get_device() if obj.is_cuda else -1 + super().__init__(idx) + + +def set_device(device: _device_t) -> None: + r"""Set the current device. + + Usage of this function is discouraged in favor of :any:`device`. In most + cases it's better to use ``CUDA_VISIBLE_DEVICES`` environmental variable. + + Args: + device (torch.device or int): selected device. 
This function is a no-op + if this argument is negative. + """ + device = _get_device_index(device) + if device >= 0: + torch._C._cuda_setDevice(device) + + +def get_device_name(device: Optional[_device_t] = None) -> str: + r"""Get the name of a device. + + Args: + device (torch.device or int or str, optional): device for which to return the + name. This function is a no-op if this argument is a negative + integer. It uses the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + + Returns: + str: the name of the device + """ + return get_device_properties(device).name + + +def get_device_capability(device: Optional[_device_t] = None) -> Tuple[int, int]: + r"""Get the cuda capability of a device. + + Args: + device (torch.device or int or str, optional): device for which to return the + device capability. This function is a no-op if this argument is + a negative integer. It uses the current device, given by + :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` + (default). + + Returns: + tuple(int, int): the major and minor cuda capability of the device + """ + prop = get_device_properties(device) + return prop.major, prop.minor + + +def get_device_properties(device: _device_t) -> _CudaDeviceProperties: + r"""Get the properties of a device. + + Args: + device (torch.device or int or str): device for which to return the + properties of the device. 
+ + Returns: + _CudaDeviceProperties: the properties of the device + """ + _lazy_init() # will define _get_device_properties + device = _get_device_index(device, optional=True) + if device < 0 or device >= device_count(): + raise AssertionError("Invalid device id") + return _get_device_properties(device) # type: ignore[name-defined] + + +def can_device_access_peer(device: _device_t, peer_device: _device_t) -> bool: + r"""Check if peer access between two devices is possible.""" + _lazy_init() + device = _get_device_index(device, optional=True) + peer_device = _get_device_index(peer_device) + if device < 0 or device >= device_count(): + raise AssertionError("Invalid device id") + if peer_device < 0 or peer_device >= device_count(): + raise AssertionError("Invalid peer device id") + return torch._C._cuda_canDeviceAccessPeer(device, peer_device) + + +class StreamContext: + r"""Context-manager that selects a given stream. + + All CUDA kernels queued within its context will be enqueued on a selected + stream. + + Args: + Stream (Stream): selected stream. This manager is a no-op if it's + ``None``. + .. note:: Streams are per-device. 
+ """ + cur_stream: Optional["torch.cuda.Stream"] + + def __init__(self, stream: Optional["torch.cuda.Stream"]): + self.stream = stream + self.idx = _get_device_index(None, True) + if not torch.jit.is_scripting(): + if self.idx is None: + self.idx = -1 + + self.src_prev_stream = ( + None if not torch.jit.is_scripting() else torch.cuda.default_stream(None) + ) + self.dst_prev_stream = ( + None if not torch.jit.is_scripting() else torch.cuda.default_stream(None) + ) + + def __enter__(self): + # Local cur_stream variable for type refinement + cur_stream = self.stream + # Return if stream is None or CUDA device not available + if cur_stream is None or self.idx == -1: + return + self.src_prev_stream = torch.cuda.current_stream(None) + + # If the stream is not on the current device, then + # set the current stream on the device + if self.src_prev_stream.device != cur_stream.device: + with device(cur_stream.device): + self.dst_prev_stream = torch.cuda.current_stream(cur_stream.device) + torch.cuda.set_stream(cur_stream) + + def __exit__(self, type: Any, value: Any, traceback: Any): + # Local cur_stream variable for type refinement + cur_stream = self.stream + # If stream is None or no CUDA device available, return + if cur_stream is None or self.idx == -1: + return + + # Reset the stream on the original device + # and destination device + if self.src_prev_stream.device != cur_stream.device: # type: ignore[union-attr] + torch.cuda.set_stream(self.dst_prev_stream) # type: ignore[arg-type] + torch.cuda.set_stream(self.src_prev_stream) # type: ignore[arg-type] + + +def stream(stream: Optional["torch.cuda.Stream"]) -> StreamContext: + r"""Wrap around the Context-manager StreamContext that selects a given stream. + + Arguments: + stream (Stream): selected stream. This manager is a no-op if it's + ``None``. + ..Note:: In eager mode stream is of type Stream class while in JIT it is + an object of the custom class ``torch.classes.cuda.Stream``. 
+ """ + return StreamContext(stream) + + +def _set_stream_by_id(stream_id, device_index, device_type): + r"""set stream specified by the stream id, device index and + device type + + Args: stream_id (int): stream id in stream pool + device_index (int): device index in topo + device_type (int): enum device type + """ + torch._C._cuda_setStream( + stream_id=stream_id, + device_index=device_index, + device_type=device_type, + ) + + +def set_stream(stream: Stream): + r"""Set the current stream.This is a wrapper API to set the stream. + Usage of this function is discouraged in favor of the ``stream`` + context manager. + + Args: + stream (Stream): selected stream. This function is a no-op + if this argument is ``None``. + """ + if stream is None: + return + _set_stream_by_id( + stream_id=stream.stream_id, + device_index=stream.device_index, + device_type=stream.device_type, + ) + + +def _parse_visible_devices() -> Union[List[int], List[str]]: + r"""Parse CUDA_VISIBLE_DEVICES environment variable.""" + var = os.getenv("CUDA_VISIBLE_DEVICES") + + if torch.version.hip: + hip_devices = os.getenv("HIP_VISIBLE_DEVICES") + if hip_devices is not None: + var = hip_devices + + if var is None: + return list(range(64)) + + def _strtoul(s: str) -> int: + """Return -1 or positive integer sequence string starts with.""" + if not s: + return -1 + for idx, c in enumerate(s): + if not (c.isdigit() or (idx == 0 and c in "+-")): + break + if idx + 1 == len(s): + idx += 1 + return int(s[:idx]) if idx > 0 else -1 + + def parse_list_with_prefix(lst: str, prefix: str) -> List[str]: + rcs: List[str] = [] + for elem in lst.split(","): + # Repeated id results in empty set + if elem in rcs: + return cast(List[str], []) + # Anything other but prefix is ignored + if not elem.startswith(prefix): + break + rcs.append(elem) + return rcs + + if var.startswith("GPU-"): + return parse_list_with_prefix(var, "GPU-") + if var.startswith("MIG-"): + return parse_list_with_prefix(var, "MIG-") + # 
CUDA_VISIBLE_DEVICES uses something like strtoul + # which makes `1gpu2,2ampere` is equivalent to `1,2` + rc: List[int] = [] + for elem in var.split(","): + x = _strtoul(elem.strip()) + # Repeated ordinal results in empty set + if x in rc: + return cast(List[int], []) + # Negative value aborts the sequence + if x < 0: + break + rc.append(x) + return rc + + +def _raw_device_count_amdsmi() -> int: + if not _HAS_PYNVML: # If amdsmi is not available + return -1 + try: + amdsmi.amdsmi_init() + except amdsmi.AmdSmiException as e: + warnings.warn(f"Can't initialize amdsmi - Error code: {e.err_code}") + return -1 + socket_handles = amdsmi.amdsmi_get_processor_handles() + return len(socket_handles) + + +def _raw_device_count_nvml() -> int: + r"""Return number of devices as reported by NVML or negative value if NVML discovery/initialization failed.""" + from ctypes import byref, c_int, CDLL + + nvml_h = CDLL("libnvidia-ml.so.1") + rc = nvml_h.nvmlInit() + if rc != 0: + warnings.warn("Can't initialize NVML") + return -1 + dev_count = c_int(-1) + rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count)) + if rc != 0: + warnings.warn("Can't get nvml device count") + return -1 + del nvml_h + return dev_count.value + + +def _raw_device_uuid_amdsmi() -> Optional[List[str]]: + from ctypes import byref, c_int, c_void_p, CDLL, create_string_buffer + + if not _HAS_PYNVML: # If amdsmi is not available + return None + try: + amdsmi.amdsmi_init() + except amdsmi.AmdSmiException: + warnings.warn("Can't initialize amdsmi") + return None + try: + socket_handles = amdsmi.amdsmi_get_processor_handles() + dev_count = len(socket_handles) + except amdsmi.AmdSmiException: + warnings.warn("Can't get amdsmi device count") + return None + uuids: List[str] = [] + for idx in range(dev_count): + try: + handler = amdsmi.amdsmi_get_processor_handles()[idx] + except amdsmi.AmdSmiException: + warnings.warn("Cannot get amd device handler") + return None + try: + uuid = amdsmi.amdsmi_get_gpu_device_uuid(handler) + 
except amdsmi.AmdSmiException: + warnings.warn("Cannot get uuid for amd device") + return None + uuids.append(str(uuid)) + return uuids + + +def _raw_device_uuid_nvml() -> Optional[List[str]]: + r"""Return list of device UUID as reported by NVML or None if NVM discovery/initialization failed.""" + from ctypes import byref, c_int, c_void_p, CDLL, create_string_buffer + + nvml_h = CDLL("libnvidia-ml.so.1") + rc = nvml_h.nvmlInit() + if rc != 0: + warnings.warn("Can't initialize NVML") + return None + dev_count = c_int(-1) + rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count)) + if rc != 0: + warnings.warn("Can't get nvml device count") + return None + uuids: List[str] = [] + for idx in range(dev_count.value): + dev_id = c_void_p() + rc = nvml_h.nvmlDeviceGetHandleByIndex_v2(idx, byref(dev_id)) + if rc != 0: + warnings.warn("Can't get device handle") + return None + buf_len = 96 + buf = create_string_buffer(buf_len) + rc = nvml_h.nvmlDeviceGetUUID(dev_id, buf, buf_len) + if rc != 0: + warnings.warn("Can't get device UUID") + return None + uuids.append(buf.raw.decode("ascii").strip("\0")) + del nvml_h + return uuids + + +def _transform_uuid_to_ordinals(candidates: List[str], uuids: List[str]) -> List[int]: + r"""Given the set of partial uuids and list of known uuids builds a set of ordinals excluding ambiguous partials IDs.""" + + def uuid_to_orinal(candidate: str, uuids: List[str]) -> int: + best_match = -1 + for idx, uuid in enumerate(uuids): + if not uuid.startswith(candidate): + continue + # Ambiguous candidate + if best_match != -1: + return -1 + best_match = idx + return best_match + + rc: List[int] = [] + for candidate in candidates: + idx = uuid_to_orinal(candidate, uuids) + # First invalid ordinal stops parsing + if idx < 0: + break + # Duplicates result in empty set + if idx in rc: + return cast(List[int], []) + rc.append(idx) + return rc + + +def _device_count_amdsmi() -> int: + visible_devices = _parse_visible_devices() + if not visible_devices: + return 0 + 
try: + if type(visible_devices[0]) is str: + return -1 + else: + raw_cnt = _raw_device_count_amdsmi() + if raw_cnt <= 0: + return raw_cnt + # Trim the list up to a maximum available device + for idx, val in enumerate(visible_devices): + if cast(int, val) >= raw_cnt: + return idx + except OSError: + return -1 + except AttributeError: + return -1 + return len(visible_devices) + + +def _device_count_nvml() -> int: + r"""Return number of devices as reported by NVML taking CUDA_VISIBLE_DEVICES into account. + + Negative value is returned if NVML discovery or initialization has failed. + """ + visible_devices = _parse_visible_devices() + if not visible_devices: + return 0 + try: + if type(visible_devices[0]) is str: + # Skip MIG parsing + if visible_devices[0].startswith("MIG-"): + return -1 + uuids = _raw_device_uuid_nvml() + if uuids is None: + return -1 + visible_devices = _transform_uuid_to_ordinals( + cast(List[str], visible_devices), uuids + ) + else: + raw_cnt = _raw_device_count_nvml() + if raw_cnt <= 0: + return raw_cnt + # Trim the list up to a maximum available device + for idx, val in enumerate(visible_devices): + if cast(int, val) >= raw_cnt: + return idx + except OSError: + return -1 + except AttributeError: + return -1 + return len(visible_devices) + + +def _get_nvml_device_index(device: Optional[Union[int, Device]]) -> int: + r"""Return the NVML index of the device, taking CUDA_VISIBLE_DEVICES into account.""" + idx = _get_device_index(device, optional=True) + visible_devices = _parse_visible_devices() + if type(visible_devices[0]) is str: + uuids = _raw_device_uuid_nvml() + if uuids is None: + raise RuntimeError("Can't get device UUIDs") + visible_devices = _transform_uuid_to_ordinals( + cast(List[str], visible_devices), uuids + ) + visible_devices = cast(List[int], visible_devices) + if idx < 0 or idx >= len(visible_devices): + raise RuntimeError( + f"device {idx} is not visible (CUDA_VISIBLE_DEVICES={visible_devices})" + ) + return visible_devices[idx] 
+ + +_cached_device_count: Optional[int] = None + + +def device_count() -> int: + r"""Return the number of GPUs available.""" + global _cached_device_count + if not _is_compiled(): + return 0 + if _cached_device_count is not None: + return _cached_device_count + # bypass _device_count_nvml() if rocm (not supported) + nvml_count = _device_count_amdsmi() if torch.version.hip else _device_count_nvml() + r = torch._C._cuda_getDeviceCount() if nvml_count < 0 else nvml_count + # NB: Do not cache the device count prior to CUDA initialization, because + # the number of devices can change due to changes to CUDA_VISIBLE_DEVICES + # setting prior to CUDA initialization. + if _initialized: + _cached_device_count = r + return r + + +def get_arch_list() -> List[str]: + r"""Return list CUDA architectures this library was compiled for.""" + if not is_available(): + return [] + arch_flags = torch._C._cuda_getArchFlags() + if arch_flags is None: + return [] + return arch_flags.split() + + +def get_gencode_flags() -> str: + r"""Return NVCC gencode flags this library was compiled with.""" + arch_list = get_arch_list() + if len(arch_list) == 0: + return "" + arch_list_ = [arch.split("_") for arch in arch_list] + return " ".join( + [ + f"-gencode compute=compute_{arch},code={kind}_{arch}" + for (kind, arch) in arch_list_ + ] + ) + + +def current_device() -> int: + r"""Return the index of a currently selected device.""" + _lazy_init() + return torch._C._cuda_getDevice() + + +def synchronize(device: _device_t = None) -> None: + r"""Wait for all kernels in all streams on a CUDA device to complete. + + Args: + device (torch.device or int, optional): device for which to synchronize. + It uses the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + """ + _lazy_init() + with torch.cuda.device(device): + return torch._C._cuda_synchronize() + + +def ipc_collect(): + r"""Force collects GPU memory after it has been released by CUDA IPC. + + .. 
note:: + Checks if any sent CUDA tensors could be cleaned from the memory. Force + closes shared memory file used for reference counting if there is no + active counters. Useful when the producer process stopped actively sending + tensors and want to release unused memory. + """ + _lazy_init() + return torch._C._cuda_ipc_collect() + + +def current_stream(device: Optional[_device_t] = None) -> Stream: + r"""Return the currently selected :class:`Stream` for a given device. + + Args: + device (torch.device or int, optional): selected device. Returns + the currently selected :class:`Stream` for the current device, given + by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` + (default). + """ + _lazy_init() + streamdata = torch._C._cuda_getCurrentStream( + _get_device_index(device, optional=True) + ) + return Stream( + stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2] + ) + + +def default_stream(device: Optional[_device_t] = None) -> Stream: + r"""Return the default :class:`Stream` for a given device. + + Args: + device (torch.device or int, optional): selected device. Returns + the default :class:`Stream` for the current device, given by + :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` + (default). + """ + _lazy_init() + streamdata = torch._C._cuda_getDefaultStream( + _get_device_index(device, optional=True) + ) + return Stream( + stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2] + ) + + +def current_blas_handle(): + r"""Return cublasHandle_t pointer to current cuBLAS handle""" + _lazy_init() + return torch._C._cuda_getCurrentBlasHandle() + + +def set_sync_debug_mode(debug_mode: Union[int, str]) -> None: + r"""Set the debug mode for cuda synchronizing operations. + + Args: + debug_mode(str or int): if "default" or 0, don't error or warn on synchronizing operations, + if "warn" or 1, warn on synchronizing operations, if "error" or 2, error out synchronizing operations. 
+ + Warning: + This is an experimental feature, and not all synchronizing operations will trigger warning or error. In + particular, operations in torch.distributed and torch.sparse namespaces are not covered yet. + """ + _lazy_init() + if isinstance(debug_mode, str): + if debug_mode == "default": + debug_mode = 0 + elif debug_mode == "warn": + debug_mode = 1 + elif debug_mode == "error": + debug_mode = 2 + else: + raise RuntimeError( + "invalid value of debug_mode, expected one of `default`, `warn`, `error`" + ) + + torch._C._cuda_set_sync_debug_mode(debug_mode) + + +def get_sync_debug_mode() -> int: + r"""Return current value of debug mode for cuda synchronizing operations.""" + _lazy_init() + return torch._C._cuda_get_sync_debug_mode() + + +def _get_pynvml_handler(device: Optional[Union[Device, int]] = None): + if not _HAS_PYNVML: + raise ModuleNotFoundError( + "pynvml does not seem to be installed or it can't be imported." + ) from _PYNVML_ERR + from pynvml import NVMLError_DriverNotLoaded + + try: + pynvml.nvmlInit() + except NVMLError_DriverNotLoaded as e: + raise RuntimeError("cuda driver can't be loaded, is cuda enabled?") from e + + device = _get_nvml_device_index(device) + handle = pynvml.nvmlDeviceGetHandleByIndex(device) + return handle + + +def _get_amdsmi_handler(device: Optional[Union[Device, int]] = None): + if not _HAS_PYNVML: + raise ModuleNotFoundError( + "amdsmi does not seem to be installed or it can't be imported." 
+ ) from _PYNVML_ERR + try: + amdsmi.amdsmi_init() + except amdsmi.AmdSmiException as e: + raise RuntimeError( + "amdsmi driver can't be loaded, requires >=ROCm5.6 installation" + ) from e + device = _get_amdsmi_device_index(device) + handle = amdsmi.amdsmi_get_processor_handles()[device] + return handle + + +def _get_amdsmi_device_index(device: Optional[Union[int, Device]]) -> int: + r"""Return the amdsmi index of the device, taking visible_devices into account.""" + idx = _get_device_index(device, optional=True) + visible_devices = _parse_visible_devices() + if type(visible_devices[0]) is str: + raise RuntimeError("HIP_VISIBLE_DEVICES should be indices and not strings") + idx_map = dict(enumerate(cast(List[int], visible_devices))) + if idx not in idx_map: + raise RuntimeError( + f"device {idx} is not visible (HIP_VISIBLE_DEVICES={visible_devices})" + ) + return idx_map[idx] + + +def _get_amdsmi_memory_usage(device: Optional[Union[Device, int]] = None) -> int: + handle = _get_amdsmi_handler() + device = _get_amdsmi_device_index(device) + return amdsmi.amdsmi_get_gpu_vram_usage(handle)["vram_used"] + + +def _get_amdsmi_utilization(device: Optional[Union[Device, int]] = None) -> int: + handle = _get_amdsmi_handler() + device = _get_amdsmi_device_index(device) + handle = amdsmi.amdsmi_get_processor_handles()[device] + return amdsmi.amdsmi_get_gpu_activity(handle)["gfx_activity"] + + +def _get_amdsmi_temperature(device: Optional[Union[Device, int]] = None) -> int: + handle = _get_amdsmi_handler(device) + return amdsmi.amdsmi_get_temp_metric( + handle, + amdsmi.AmdSmiTemperatureType.JUNCTION, + amdsmi.AmdSmiTemperatureMetric.CURRENT, + ) + + +def _get_amdsmi_power_draw(device: Optional[Union[Device, int]] = None) -> int: + handle = _get_amdsmi_handler(device) + socket_power = amdsmi.amdsmi_get_power_info(handle)["average_socket_power"] + if socket_power != "N/A": + return socket_power + else: + return amdsmi.amdsmi_get_power_info(handle)["current_socket_power"] + + 
+def _get_amdsmi_clock_rate(device: Optional[Union[Device, int]] = None) -> int: + handle = _get_amdsmi_handler(device) + clock_info = amdsmi.amdsmi_get_clock_info(handle, amdsmi.AmdSmiClkType.GFX) + if "cur_clk" in clock_info: # ROCm 6.2 deprecation + return clock_info["cur_clk"] + else: + return clock_info["clk"] + + +def memory_usage(device: Optional[Union[Device, int]] = None) -> int: + r"""Return the percent of time over the past sample period during which global (device) + memory was being read or written as given by `nvidia-smi`. + + Args: + device (torch.device or int, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + + Warning: Each sample period may be between 1 second and 1/6 second, + depending on the product being queried. + """ + if not torch.version.hip: + handle = _get_pynvml_handler() + device = _get_nvml_device_index(device) + handle = pynvml.nvmlDeviceGetHandleByIndex(device) + return pynvml.nvmlDeviceGetUtilizationRates(handle).memory + else: + return _get_amdsmi_memory_usage(device) + + +def utilization(device: Optional[Union[Device, int]] = None) -> int: + r"""Return the percent of time over the past sample period during which one or + more kernels was executing on the GPU as given by `nvidia-smi`. + + Args: + device (torch.device or int, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + + Warning: Each sample period may be between 1 second and 1/6 second, + depending on the product being queried. 
+ """ + if not torch.version.hip: + handle = _get_pynvml_handler(device) + device = _get_nvml_device_index(device) + handle = pynvml.nvmlDeviceGetHandleByIndex(device) + return pynvml.nvmlDeviceGetUtilizationRates(handle).gpu + else: + return _get_amdsmi_utilization(device) + + +def temperature(device: Optional[Union[Device, int]] = None) -> int: + r"""Return the average temperature of the GPU sensor in Degrees C (Centigrades). + + The average temperature is computed based on past sample period as given by `nvidia-smi`. + + Args: + device (torch.device or int, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + + Warning: Each sample period may be between 1 second and 1/6 second, + depending on the product being queried. + """ + if not torch.version.hip: + handle = _get_pynvml_handler(device) + # 0 refers to the temperature sensor for the GPU die. + return pynvml.nvmlDeviceGetTemperature(handle, 0) + else: + return _get_amdsmi_temperature(device) + + +def power_draw(device: Optional[Union[Device, int]] = None) -> int: + r"""Return the average power draw of the GPU sensor in mW (MilliWatts) + over the past sample period as given by `nvidia-smi` for Fermi or newer fully supported devices. + + Args: + device (torch.device or int, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + + Warning: Each sample period may be between 1 second and 1/6 second, + depending on the product being queried. + """ + if not torch.version.hip: + handle = _get_pynvml_handler(device) + return pynvml.nvmlDeviceGetPowerUsage(handle) + else: + return _get_amdsmi_power_draw(device) + + +def clock_rate(device: Optional[Union[Device, int]] = None) -> int: + r"""Return the clock speed of the GPU SM in Hz Hertz over the past sample period as given by `nvidia-smi`. 
+ + Args: + device (torch.device or int, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + + Warning: Each sample period may be between 1 second and 1/6 second, + depending on the product being queried. + """ + if not torch.version.hip: + handle = _get_pynvml_handler(device) + return pynvml.nvmlDeviceGetClockInfo(handle, 1) + else: + return _get_amdsmi_clock_rate(device) + + +def _get_device(device: Union[int, str, torch.device]) -> torch.device: + r"""Return the torch.device type object from the passed in device. + + Args: + device (torch.device or int): selected device. + """ + if isinstance(device, str): + device = torch.device(device) + elif isinstance(device, int): + device = torch.device("cuda", device) + return device + + +def _get_generator(device: torch.device) -> torch._C.Generator: + r"""Return the CUDA Generator object for the given device. + + Args: + device (torch.device): selected device. + """ + idx = device.index + if idx is None: + idx = current_device() + return torch.cuda.default_generators[idx] + + +def _set_rng_state_offset( + offset: int, device: Union[int, str, torch.device] = "cuda" +) -> None: + r"""Set the random number generator state offset of the specified GPU. + + Args: + offset (int): The desired offset + device (torch.device or int, optional): The device to set the RNG state. + Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device). + """ + final_device = _get_device(device) + + def cb(): + default_generator = _get_generator(final_device) + default_generator.set_offset(offset) + + _lazy_call(cb) + + +def _get_rng_state_offset(device: Union[int, str, torch.device] = "cuda") -> int: + r"""Return the random number generator state offset of the specified GPU. + + Args: + device (torch.device or int, optional): The device to return the RNG state offset of. 
+ Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device). + + .. warning:: + This function eagerly initializes CUDA. + """ + _lazy_init() + final_device = _get_device(device) + default_generator = _get_generator(final_device) + return default_generator.get_offset() + + +from .memory import * # noqa: F403 +from .random import * # noqa: F403 + + +################################################################################ +# Define Storage and Tensor classes +################################################################################ + + +@staticmethod # type: ignore[misc] +def _lazy_new(cls, *args, **kwargs): + _lazy_init() + # We may need to call lazy init again if we are a forked child + # del _CudaBase.__new__ + return super(_CudaBase, cls).__new__(cls, *args, **kwargs) + + +class _CudaBase: + is_cuda = True + is_sparse = False + + def type(self, *args, **kwargs): + # We could use a Protocol here to tell mypy that self has `get_device` method + # but it is only available in the typing module on Python >= 3.8 + # or on typing_extensions module on Python >= 3.6 + with device(self.get_device()): # type: ignore[attr-defined] + return super().type(*args, **kwargs) # type: ignore[misc] + + __new__ = _lazy_new + + +from torch.storage import _LegacyStorage, _warn_typed_storage_removal + + +class _CudaLegacyStorage(_LegacyStorage): + @classmethod + def from_buffer(cls, *args, **kwargs): + _warn_typed_storage_removal() + raise RuntimeError("from_buffer: Not available for CUDA storage") + + @classmethod + def _new_with_weak_ptr(cls, *args, **kwargs): + raise RuntimeError("_new_with_weak_ptr: Not available for CUDA storage") + + @classmethod + def _new_shared_filename(cls, manager, obj, size, *, device=None, dtype=None): + raise RuntimeError("_new_shared_filename: Not available for CUDA storage") + + +class ByteStorage(_CudaLegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + 
@classproperty + def _dtype(self): + return torch.uint8 + + +class DoubleStorage(_CudaLegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): + return torch.double + + +class FloatStorage(_CudaLegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): + return torch.float + + +class HalfStorage(_CudaLegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): + return torch.half + + +class LongStorage(_CudaLegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): + return torch.long + + +class IntStorage(_CudaLegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): + return torch.int + + +class ShortStorage(_CudaLegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): + return torch.short + + +class CharStorage(_CudaLegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): + return torch.int8 + + +class BoolStorage(_CudaLegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): + return torch.bool + + +class BFloat16Storage(_CudaLegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): + return torch.bfloat16 + + +class ComplexDoubleStorage(_CudaLegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): + return 
torch.cdouble + + +class ComplexFloatStorage(_CudaLegacyStorage): + @classproperty + def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): + return torch.cfloat + + +del _LegacyStorage +del _CudaLegacyStorage + +torch._storage_classes.add(DoubleStorage) +torch._storage_classes.add(FloatStorage) +torch._storage_classes.add(LongStorage) +torch._storage_classes.add(IntStorage) +torch._storage_classes.add(ShortStorage) +torch._storage_classes.add(CharStorage) +torch._storage_classes.add(ByteStorage) +torch._storage_classes.add(HalfStorage) +torch._storage_classes.add(BoolStorage) +torch._storage_classes.add(BFloat16Storage) +torch._storage_classes.add(ComplexDoubleStorage) +torch._storage_classes.add(ComplexFloatStorage) + + +class _WrappedTritonKernel: + """Just a simple wrapper to store some metadata for testing purposes.""" + + def __init__(self, kernel): + self.kernel = kernel + self.kernel_invoked = False + + def __call__(self, *args, **kwargs): + res = self.kernel(*args, **kwargs) + self.kernel_invoked = True + return res + + +def _register_triton_kernels(): + if torch._running_with_deploy(): + return + + @_WrappedTritonKernel + def kernel_impl(*args, **kwargs): + from torch.sparse._triton_ops import bsr_dense_mm + + return bsr_dense_mm(*args, skip_checks=True, **kwargs) + + @_WrappedTritonKernel + def addmm_kernel_impl(*args, **kwargs): + from torch.sparse._triton_ops import bsr_dense_addmm + + return bsr_dense_addmm(*args, skip_checks=True, **kwargs) + + has_triton = importlib.util.find_spec("triton") is not None + if has_triton: + torch._TritonLibrary.registerOp( + "_triton_bsr_dense_mm_out", + "_triton_bsr_dense_mm_out(Tensor bsr, Tensor dense, *, Tensor(a!) out) -> Tensor(a!)", + kernel_impl, + "SparseCsrCUDA", + ) + + torch._TritonLibrary.registerOp( + "_triton_bsr_dense_addmm_out", + ( + "_triton_bsr_dense_addmm_out(Tensor input, Tensor bsr, Tensor dense," + " *, Scalar beta, Scalar alpha, Tensor(a!) 
out) -> Tensor(a!)" + ), + addmm_kernel_impl, + "SparseCsrCUDA", + ) + + +_lazy_call(_register_triton_kernels) + + +from . import amp, jiterator, nvtx, profiler, sparse, tunable + + +__all__ = [ + # Typed storage and tensors + "BFloat16Storage", + "BFloat16Tensor", + "BoolStorage", + "BoolTensor", + "ByteStorage", + "ByteTensor", + "CharStorage", + "CharTensor", + "ComplexDoubleStorage", + "ComplexFloatStorage", + "DoubleStorage", + "DoubleTensor", + "FloatStorage", + "FloatTensor", + "HalfStorage", + "HalfTensor", + "IntStorage", + "IntTensor", + "LongStorage", + "LongTensor", + "ShortStorage", + "ShortTensor", + "CUDAGraph", + "CudaError", + "DeferredCudaCallError", + "Event", + "ExternalStream", + "Stream", + "StreamContext", + "amp", + "caching_allocator_alloc", + "caching_allocator_delete", + "can_device_access_peer", + "check_error", + "cudaStatus", + "cudart", + "current_blas_handle", + "current_device", + "current_stream", + "default_generators", + "default_stream", + "device", + "device_count", + "device_of", + "empty_cache", + "get_allocator_backend", + "CUDAPluggableAllocator", + "change_current_allocator", + "get_arch_list", + "get_device_capability", + "get_device_name", + "get_device_properties", + "get_gencode_flags", + "get_rng_state", + "get_rng_state_all", + "get_sync_debug_mode", + "graph", + "graph_pool_handle", + "graphs", + "has_half", + "has_magma", + "init", + "initial_seed", + "ipc_collect", + "is_available", + "is_bf16_supported", + "is_current_stream_capturing", + "is_initialized", + "jiterator", + "list_gpu_processes", + "make_graphed_callables", + "manual_seed", + "manual_seed_all", + "max_memory_allocated", + "max_memory_cached", + "max_memory_reserved", + "mem_get_info", + "memory", + "memory_allocated", + "memory_cached", + "memory_reserved", + "memory_snapshot", + "memory_stats", + "memory_stats_as_nested_dict", + "memory_summary", + "memory_usage", + "MemPool", + "MemPoolContext", + "use_mem_pool", + "temperature", + "power_draw", 
+ "clock_rate", + "nccl", + "nvtx", + "profiler", + "random", + "reset_accumulated_memory_stats", + "reset_max_memory_allocated", + "reset_max_memory_cached", + "reset_peak_memory_stats", + "seed", + "seed_all", + "set_device", + "set_per_process_memory_fraction", + "set_rng_state", + "set_rng_state_all", + "set_stream", + "set_sync_debug_mode", + "sparse", + "stream", + "streams", + "synchronize", + "tunable", + "utilization", +] diff --git a/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/memory.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/memory.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d6aa7f5abf6ebf8e43a33510ba4748c8cea09ca Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/memory.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/cuda/_gpu_trace.py b/.venv/lib/python3.11/site-packages/torch/cuda/_gpu_trace.py new file mode 100644 index 0000000000000000000000000000000000000000..2a738b002d773118eb649e0ce34e608df3227655 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/cuda/_gpu_trace.py @@ -0,0 +1,75 @@ +from typing import Callable + +from torch._utils import CallbackRegistry + + +EventCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry( + "CUDA event creation" +) +EventDeletionCallbacks: "CallbackRegistry[int]" = CallbackRegistry( + "CUDA event deletion" +) +EventRecordCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry( + "CUDA event record" +) +EventWaitCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry( + "CUDA event wait" +) +MemoryAllocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry( + "CUDA memory allocation" +) +MemoryDeallocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry( + "CUDA memory deallocation" +) +StreamCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry( + "CUDA stream creation" +) +DeviceSynchronizationCallbacks: 
"CallbackRegistry[[]]" = CallbackRegistry( + "CUDA device synchronization" +) +StreamSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry( + "CUDA stream synchronization" +) +EventSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry( + "CUDA event synchronization" +) + + +def register_callback_for_event_creation(cb: Callable[[int], None]) -> None: + EventCreationCallbacks.add_callback(cb) + + +def register_callback_for_event_deletion(cb: Callable[[int], None]) -> None: + EventDeletionCallbacks.add_callback(cb) + + +def register_callback_for_event_record(cb: Callable[[int, int], None]) -> None: + EventRecordCallbacks.add_callback(cb) + + +def register_callback_for_event_wait(cb: Callable[[int, int], None]) -> None: + EventWaitCallbacks.add_callback(cb) + + +def register_callback_for_memory_allocation(cb: Callable[[int], None]) -> None: + MemoryAllocationCallbacks.add_callback(cb) + + +def register_callback_for_memory_deallocation(cb: Callable[[int], None]) -> None: + MemoryDeallocationCallbacks.add_callback(cb) + + +def register_callback_for_stream_creation(cb: Callable[[int], None]) -> None: + StreamCreationCallbacks.add_callback(cb) + + +def register_callback_for_device_synchronization(cb: Callable[[], None]) -> None: + DeviceSynchronizationCallbacks.add_callback(cb) + + +def register_callback_for_stream_synchronization(cb: Callable[[int], None]) -> None: + StreamSynchronizationCallbacks.add_callback(cb) + + +def register_callback_for_event_synchronization(cb: Callable[[int], None]) -> None: + EventSynchronizationCallbacks.add_callback(cb) diff --git a/.venv/lib/python3.11/site-packages/torch/cuda/_memory_viz.py b/.venv/lib/python3.11/site-packages/torch/cuda/_memory_viz.py new file mode 100644 index 0000000000000000000000000000000000000000..2047ec4efb28fa31c768b96d65208c694865a587 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/cuda/_memory_viz.py @@ -0,0 +1,632 @@ +# mypy: allow-untyped-defs +import pickle +import sys 
+import os +import io +import subprocess +import json +from functools import lru_cache +from typing import Any +from itertools import groupby +import base64 +import warnings +import operator + +cache = lru_cache(None) + +__all__ = ["format_flamegraph", "segments", "memory", "compare"] + +def _frame_fmt(f, full_filename=False): + i = f['line'] + fname = f['filename'] + if not full_filename: + fname = fname.split('/')[-1] + func = f['name'] + return f'{fname}:{i}:{func}' + +@cache +def _frame_filter(name, filename): + omit_functions = [ + "unwind::unwind", + "CapturedTraceback::gather", + "gather_with_cpp", + "_start", + "__libc_start_main", + "PyEval_", + "PyObject_", + "PyFunction_", + ] + omit_filenames = [ + "core/boxing", + "/Register", + "/Redispatch", + "pythonrun.c", + "Modules/main.c", + "Objects/call.c", + "Objects/methodobject.c", + "pycore_ceval.h", + "ceval.c", + "cpython/abstract.h", + ] + for of in omit_functions: + if of in name: + return False + for of in omit_filenames: + if of in filename: + return False + return True + +def _frames_fmt(frames, full_filename=False, reverse=False): + if reverse: + frames = reversed(frames) + return [_frame_fmt(f, full_filename) for f in frames if _frame_filter(f['name'], f['filename'])] + +def _block_extra_legacy(b): + if 'history' in b: + frames = b['history'][0].get('frames', []) + real_size = b['history'][0]['real_size'] + else: + real_size = b.get('requested_size', b['size']) + frames = [] + return frames, real_size + +def _block_extra(b): + if 'frames' not in b: + # old snapshot format made it more complicated to get frames/allocated size + return _block_extra_legacy(b) + return b['frames'], b['requested_size'] + +def format_flamegraph(flamegraph_lines, flamegraph_script=None): + if flamegraph_script is None: + flamegraph_script = f'/tmp/{os.getuid()}_flamegraph.pl' + if not os.path.exists(flamegraph_script): + import urllib.request + print(f"Downloading flamegraph.pl to: {flamegraph_script}") + 
urllib.request.urlretrieve( + 'https://raw.githubusercontent.com/brendangregg/FlameGraph/master/flamegraph.pl', flamegraph_script) + subprocess.check_call(['chmod', '+x', flamegraph_script]) + args = [flamegraph_script, '--countname', 'bytes'] + p = subprocess.Popen(args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, encoding='utf-8') + assert p.stdin is not None + assert p.stdout is not None + p.stdin.write(flamegraph_lines) + p.stdin.close() + result = p.stdout.read() + p.stdout.close() + p.wait() + assert p.wait() == 0 + return result + +def _write_blocks(f, prefix, blocks): + def frames_fragment(frames): + if not frames: + return "" + return ';'.join(_frames_fmt(frames, reverse=True)) + for b in blocks: + if 'history' not in b: + frames, accounted_for_size = _block_extra(b) + f.write(f'{prefix};{b["state"]};{frames_fragment(frames)} {accounted_for_size}\n') + else: + accounted_for_size = 0 + for h in b['history']: + sz = h['real_size'] + accounted_for_size += sz + if 'frames' in h: + frames = h['frames'] + f.write(f'{prefix};{b["state"]};{frames_fragment(frames)} {sz}\n') + else: + f.write(f'{prefix};{b["state"]}; {sz}\n') + gaps = b['size'] - accounted_for_size + if gaps: + f.write(f'{prefix};{b["state"]}; {gaps}\n') + +def segments(snapshot, format_flamegraph=format_flamegraph): + f = io.StringIO() + for seg in snapshot['segments']: + prefix = f'stream_{seg["stream"]};seg_{seg["address"]}' + _write_blocks(f, prefix, seg['blocks']) + return format_flamegraph(f.getvalue()) + +def memory(snapshot, format_flamegraph=format_flamegraph): + f = io.StringIO() + for seg in snapshot['segments']: + prefix = f'stream_{seg["stream"]}' + _write_blocks(f, prefix, seg['blocks']) + return format_flamegraph(f.getvalue()) + +def compare(before, after, format_flamegraph=format_flamegraph): + def _seg_key(seg): + return (seg['address'], seg['total_size']) + + def _seg_info(seg): + return f'stream_{seg["stream"]};seg_{seg["address"]}' + + f = io.StringIO() + + before_segs = 
{_seg_key(seg) for seg in before} + after_segs = {_seg_key(seg) for seg in after} + + print(f'only_before = {[a for a, _ in (before_segs - after_segs)]}') + print(f'only_after = {[a for a, _ in (after_segs - before_segs)]}') + + for seg in before: + if _seg_key(seg) not in after_segs: + _write_blocks(f, f'only_before;{_seg_info(seg)}', seg['blocks']) + + for seg in after: + if _seg_key(seg) not in before_segs: + _write_blocks(f, f'only_after;{_seg_info(seg)}', seg['blocks']) + + return format_flamegraph(f.getvalue()) + +def _format_size(num): + # https://stackoverflow.com/questions/1094841/get-human-readable-version-of-file-size + for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]: + if abs(num) < 1024.0: + return f"{num:3.1f}{unit}B" + num /= 1024.0 + return f"{num:.1f}YiB" + +class Bytes: + def __init__(self, value): + self.value = value + + def __add__(self, rhs): + return Bytes(self.value + rhs) + + def __repr__(self): + return _format_size(self.value) + +def calc_active(seg): + return sum(b['size'] for b in seg['blocks'] if b['state'] == 'active_allocated') + +def _report_free(free_external, free_internal): + total = free_external + free_internal + suffix = '' + if total != 0: + pct = (free_internal / total) * 100 + suffix = f' ({pct:.1f}% internal)' + return f'{Bytes(total)}{suffix}' + +PAGE_SIZE = 1024 * 1024 * 20 +legend = f"""\ + +Legend: + [a ] - a segment in the allocator + ^-- a page {Bytes(PAGE_SIZE)} of memory in the segment + a-z: pages filled with a single block's content + ' ': page is completely free + *: page if completely full with multiple blocks + 0-9: page is partially full with tensors of multiple blocks (9 == 90% full) + (X% internal) - of the free memory, X% is free because we rounded the size of the allocation. +""" + +def segsum(data): + r"""Visually reports how the allocator has filled its segments. + + This printout can help debug fragmentation issues since free fragments + will appear as gaps in this printout. 
The amount of free space is reported + for each segment. + We distinguish between internal free memory which occurs because the + allocator rounds the allocation size, and external free memory, which are + the gaps between allocations in a segment. + Args: + data: snapshot dictionary created from _snapshot() + """ + segments = [] + out = io.StringIO() + out.write(f"Summary of segments >= {Bytes(PAGE_SIZE)} in size\n") + total_reserved = 0 + total_allocated = 0 + free_external = 0 + free_internal = 0 + for seg in sorted(data['segments'], key=lambda x: (x['total_size'], calc_active(x))): + total_reserved += seg['total_size'] + + seg_free_external = 0 + seg_free_internal = 0 + seg_allocated = 0 + all_ranges = [] + boffset = 0 + for b in seg['blocks']: + active = b['state'] == 'active_allocated' + if active: + _, allocated_size = _block_extra(b) + all_ranges.append((boffset, allocated_size, True)) + seg_allocated += allocated_size + seg_free_internal += b['size'] - allocated_size + else: + seg_free_external += b['size'] + + boffset += b['size'] + + total_allocated += seg_allocated + free_external += seg_free_external + free_internal += seg_free_internal + + nseg = (seg['total_size'] - 1) // PAGE_SIZE + 1 + occupied = [' ' for _ in range(nseg)] + frac = [0.0 for _ in range(nseg)] + active_size = 0 + for i, (start_, size, active) in enumerate(all_ranges): + active_size += size + finish_ = (start_ + size) + start = start_ // PAGE_SIZE + finish = (finish_ - 1) // PAGE_SIZE + 1 + m = chr(ord('a' if active else 'A') + (i % 26)) + for j in range(start, finish): + s = max(start_, j * PAGE_SIZE) + e = min(finish_, (j + 1) * PAGE_SIZE) + frac[j] += (e - s) / PAGE_SIZE + if occupied[j] != ' ': + occupied[j] = '0123456789*'[int(frac[j] * 10)] + else: + occupied[j] = m + stream = '' if seg['stream'] == 0 else f', stream_{seg["stream"]}' + body = ''.join(occupied) + assert seg_free_external + seg_free_internal + seg_allocated == seg['total_size'] + stream = f' 
stream_{seg["stream"]}' if seg['stream'] != 0 else '' + if seg['total_size'] >= PAGE_SIZE: + out.write(f'[{body}] {Bytes(seg["total_size"])} allocated, ' + f'{_report_free(seg_free_external, seg_free_internal)} free{stream}\n') + out.write(f'segments: {len(data["segments"])}\n') + out.write(f'total_reserved: {Bytes(total_reserved)}\n') + out.write(f'total_allocated: {Bytes(total_allocated)}\n') + internal_external = f' ({Bytes(free_internal)} internal + {Bytes(free_external)} external)' if free_internal else '' + out.write(f'total_free: {_report_free(free_external, free_internal)}\n') + out.write(legend) + assert free_internal + free_external + total_allocated == total_reserved + return out.getvalue() + +def trace(data): + out = io.StringIO() + + def format(entries): + segment_intervals : list = [] + segment_addr_to_name = {} + allocation_addr_to_name = {} + + free_names : list = [] + next_name = 0 + + def _name(): + nonlocal next_name + if free_names: + return free_names.pop() + r, m = next_name // 26, next_name % 26 + next_name += 1 + return f'{chr(ord("a") + m)}{"" if r == 0 else r}' + + def find_segment(addr): + for name, saddr, size in segment_intervals: + if addr >= saddr and addr < saddr + size: + return name, saddr + for i, seg in enumerate(data['segments']): + saddr = seg['address'] + size = seg['allocated_size'] + if addr >= saddr and addr < saddr + size: + return f'seg_{i}', saddr + return None, None + count = 0 + out.write(f'{len(entries)} entries\n') + + + total_reserved = 0 + for seg in data['segments']: + total_reserved += seg['total_size'] + + for count, e in enumerate(entries): + if e['action'] == 'alloc': + addr, size = e['addr'], e['size'] + n = _name() + seg_name, seg_addr = find_segment(addr) + if seg_name is None: + seg_name = "MEM" + offset = addr + else: + offset = addr - seg_addr + out.write(f'{n} = {seg_name}[{offset}:{Bytes(size)}]\n') + allocation_addr_to_name[addr] = (n, size, count) + count += size + elif e['action'] == 
'free_requested': + addr, size = e['addr'], e['size'] + name, _, _ = allocation_addr_to_name.get(addr, (addr, None, None)) + out.write(f'del {name} # {Bytes(size)}\n') + elif e['action'] == 'free_completed': + addr, size = e['addr'], e['size'] + count -= size + name, _, _ = allocation_addr_to_name.get(addr, (addr, None, None)) + out.write(f'# free completed for {name} {Bytes(size)}\n') + if name in allocation_addr_to_name: + free_names.append(name) + del allocation_addr_to_name[name] + elif e['action'] == 'segment_alloc': + addr, size = e['addr'], e['size'] + name = _name() + out.write(f'{name} = cudaMalloc({addr}, {Bytes(size)})\n') + segment_intervals.append((name, addr, size)) + segment_addr_to_name[addr] = name + elif e['action'] == 'segment_free': + addr, size = e['addr'], e['size'] + name = segment_addr_to_name.get(addr, addr) + out.write(f'cudaFree({name}) # {Bytes(size)}\n') + if name in segment_addr_to_name: + free_names.append(name) + del segment_addr_to_name[name] + elif e['action'] == 'oom': + size = e['size'] + free = e['device_free'] + out.write(f'raise OutOfMemoryError # {Bytes(size)} requested, {Bytes(free)} free in CUDA\n') + else: + out.write(f'{e}\n') + out.write(f"TOTAL MEM: {Bytes(count)}") + for i, d in enumerate(data['device_traces']): + if d: + out.write(f'Device {i} ----------------\n') + format(d) + return out.getvalue() + + +_memory_viz_template = r""" + + + + + + + +""" + +def _format_viz(data, viz_kind, device): + if device is not None: + warnings.warn( + 'device argument is deprecated, plots now contain all device', + FutureWarning, + stacklevel=3, + ) + buffer = pickle.dumps(data) + buffer += b'\x00' * (3 - len(buffer) % 3) + # Encode the buffer with base64 + encoded_buffer = base64.b64encode(buffer).decode('utf-8') + + json_format = json.dumps([{"name": 'snapshot.pickle', "base64": encoded_buffer}]) + return _memory_viz_template.replace('$VIZ_KIND', repr(viz_kind)) \ + .replace('$SNAPSHOT', json_format) + +def trace_plot(data, 
               device=None, plot_segments=False):
    """Generate a visualization over time of the memory usage recorded by the trace as an html file.

    Args:
        data: Memory snapshot as generated from torch.cuda.memory._snapshot()
        device (torch.device, optional): Generate the trace for this device, needed if multiple devices have allocations.
        plot_segments (bool, optional): Plots memory returned from cudaMalloc, rather than individual allocations.
                                        Defaults to False.

    Returns:
        str: HTML of visualization
    """
    return _format_viz(data, 'Active Memory Timeline' if not plot_segments else 'Active Cached Memory Timeline', device)


def _profile_to_snapshot(profile):
    """Convert a kineto memory profile into the dict format produced by torch.cuda.memory._snapshot()."""
    import torch
    from torch.profiler._memory_profiler import Action, TensorKey
    from torch._C._profiler import _EventType
    memory_profile = profile._memory_profile()

    # Map each allocated TensorKey to the chain of python call nodes active at
    # allocation time (innermost first, reversed relative to the op tree walk).
    allocation_stacks = {}
    for event in memory_profile._op_tree.sorted_nodes:
        if event.tag == _EventType.Allocation:
            parent = event.parent
            python_parents = []
            while parent:
                if parent.tag in (_EventType.PyCall, _EventType.PyCCall):
                    python_parents.append(parent)
                parent = parent.parent
            key = TensorKey.from_allocation(event.extra_fields)

            # Corner case: If allocation doesn't have an ID (can't prove it was used as a Tensor)
            # key will be None. I should add some way to identify these, I just haven't yet.
            if key and event.extra_fields.alloc_size > 0:
                allocation_stacks[key] = python_parents


    device_count = torch.cuda.device_count()
    # One synthetic trace/segment per CUDA device, plus one extra slot
    # (index device_count) that collects non-CUDA allocations.
    snapshot = {
        'device_traces': [[] for _ in range(device_count + 1)],
        'segments': [{'device': device,
                      'address': None,
                      'total_size': 0,
                      'stream': 0,
                      'blocks': []} for device in range(device_count + 1)]
    }

    def to_device(device):
        # CUDA devices map to their index; everything else goes to the extra slot.
        if device.type == 'cuda':
            return device.index
        else:
            return device_count

    def allocate(size, tensor_key, version, during_trace=True):
        # Record an 'alloc' event (and grow the covering segment bounds).
        device = to_device(tensor_key.device)
        addr = tensor_key.storage.ptr

        seg = snapshot['segments'][device]  # type: ignore[index]
        if seg['address'] is None or seg['address'] > addr:
            seg['address'] = addr
        seg['total_size'] = max(seg['total_size'], addr + size)  # record max addr for now, we will make it the size later
        category = memory_profile._categories.get(tensor_key, version)
        category = category.name.lower() if category is not None else "unknown"
        stack = allocation_stacks.get(tensor_key, ())
        stack = [{'filename': 'none', 'line': 0, 'name': p.name} for p in stack]
        r = {'action': 'alloc', 'addr': addr, 'size': size, 'stream': 0, 'frames': stack, 'category': category}
        if during_trace:
            snapshot['device_traces'][device].append(r)  # type: ignore[index]
        return r

    def free(alloc, device):
        # Emit the matching free_requested/free_completed pair for an alloc event.
        for e in ('free_requested', 'free_completed'):
            snapshot['device_traces'][device].append({'action': e,  # type: ignore[index]
                                                      'addr': alloc['addr'],
                                                      'size': alloc['size'],
                                                      'stream': 0,
                                                      'frames': alloc['frames']})

    # Live (tensor_key, version) -> alloc event, used to close out frees and
    # to build the final segment state below.
    kv_to_elem = {}



    # create the device trace
    for time, action, (tensor_key, version), size in memory_profile.timeline:
        if not isinstance(tensor_key, TensorKey):
            continue
        if action == Action.CREATE:
            kv_to_elem[(tensor_key, version)] = allocate(size, tensor_key, version)
        elif action == Action.DESTROY:
            free(kv_to_elem.pop((tensor_key, version)), to_device(tensor_key.device))
        elif action == Action.INCREMENT_VERSION:
            # A version bump is modeled as a free of the old version plus a
            # fresh allocation of the new one.
            free(kv_to_elem.pop((tensor_key, version)), to_device(tensor_key.device))
            kv_to_elem[(tensor_key, version + 1)] = allocate(size, tensor_key, version + 1)
        elif action == Action.PREEXISTING:
            kv_to_elem[(tensor_key, version)] = allocate(size, tensor_key, version, during_trace=False)


    # create the final snapshot state
    blocks_at_end = [(to_device(tensor_key.device), event['addr'], event['size'], event['frames'])
                     for (tensor_key, version), event in kv_to_elem.items()]
    for device, blocks in groupby(sorted(blocks_at_end), key=operator.itemgetter(0)):
        seg = snapshot['segments'][device]  # type: ignore[index]
        last_addr = seg['address']
        for _, addr, size, frames in blocks:
            # Gaps between live blocks become 'inactive' filler blocks.
            if last_addr < addr:
                seg['blocks'].append({'size': addr - last_addr, 'state': 'inactive'})
            seg['blocks'].append({'size': size, 'state': 'active_allocated', 'requested_size': size, 'frames': frames})
            last_addr = addr + size
        if last_addr < seg['total_size']:
            seg['blocks'].append({'size': seg['total_size'] - last_addr, 'state': 'inactive'})

    # Drop devices with no blocks, then convert total_size from max-end-address
    # to an actual size relative to the segment base address.
    snapshot['segments'] = [seg for seg in snapshot['segments'] if seg['blocks']]  # type: ignore[attr-defined]
    for seg in snapshot['segments']:  # type: ignore[attr-defined, name-defined, no-redef]
        seg['total_size'] -= seg['address']
        if not seg['blocks']:
            seg['blocks'].append({'size': seg['total_size'], 'state': 'inactive'})

    return snapshot

def profile_plot(profile, device=None):
    """Generate a visualization over time of the memory usage recorded by kineto memory profiling as an html file.

    Args:
        profile: profile as generated by `torch.profiler.profile(profile_memory=True)`
        device (torch.device, optional): Generate the trace for this device, needed if multiple devices have allocations.
+ + Returns: + str: HTML of visualization + """ + snapshot = _profile_to_snapshot(profile) + return _format_viz(snapshot, 'Active Memory Timeline', device) + + +def segment_plot(data: Any, device=None): + return _format_viz(data, 'Allocator State History', device) + +if __name__ == "__main__": + import os.path + thedir = os.path.realpath(os.path.dirname(__file__)) + if thedir in sys.path: + # otherwise we find cuda/random.py as random... + sys.path.remove(thedir) + import argparse + + fn_name = 'torch.cuda.memory._snapshot()' + pickled = f'pickled memory statistics from {fn_name}' + parser = argparse.ArgumentParser(description=f'Visualize memory dumps produced by {fn_name}') + + subparsers = parser.add_subparsers(dest='action') + + def _output(p): + p.add_argument('-o', '--output', default='output.svg', help='flamegraph svg (default: output.svg)') + + description = 'Prints overall allocation statistics and a visualization of how the allocators segments are currently filled.' + stats_a = subparsers.add_parser('stats', description=description) + stats_a.add_argument('input', help=pickled) + + description = 'Prints buffer of the most recent allocation events embedded in the snapshot in a Pythonic style.' + trace_a = subparsers.add_parser('trace', description=description) + trace_a.add_argument('input', help=pickled) + + description = 'Generate a flamegraph that visualizes what memory is stored in each allocator segment (aka block)' + segments_a = subparsers.add_parser('segments', description=description) + segments_a.add_argument('input', help=pickled) + _output(segments_a) + + description = "Generate a flamegraph the program locations contributing to CUDA memory usage." + memory_a = subparsers.add_parser('memory', description=description) + memory_a.add_argument('input', help=pickled) + _output(memory_a) + + description = 'Generate a flamegraph that shows segments (aka blocks) that have been added ' \ + 'or removed between two different memorys snapshots.' 
+ compare_a = subparsers.add_parser('compare', description=description) + compare_a.add_argument('before', help=pickled) + compare_a.add_argument('after', help=pickled) + _output(compare_a) + + plots = ( + ("trace_plot", "Generate a visualization over time of the memory usage recorded by the trace as an html file."), + ("segment_plot", "Visualize how allocations are packed into allocator segments at each point in a trace as an html file.") + ) + for cmd, description in plots: + trace_plot_a = subparsers.add_parser(cmd, description=description) + trace_plot_a.add_argument('input', help=pickled) + help = 'visualize trace from this device (default: chooses the only device with trace info or errors)' + trace_plot_a.add_argument('-d', '--device', type=int, default=None, help=help) + help = 'path to save the visualization(default: output.html)' + trace_plot_a.add_argument('-o', '--output', default='output.html', help=help) + if cmd == "trace_plot": + help = 'visualize change to segments rather than individual allocations' + trace_plot_a.add_argument('-s', '--segments', action='store_true', help=help) + + + args = parser.parse_args() + + def _read(name): + if name == '-': + f = sys.stdin.buffer + else: + f = open(name, 'rb') + data = pickle.load(f) + if isinstance(data, list): # segments only... 
+ data = {'segments': data, 'traces': []} + return data + + def _write(name, data): + with open(name, 'w') as f: + f.write(data) + + if args.action == 'segments': + data = _read(args.input) + _write(args.output, segments(data)) + elif args.action == 'memory': + data = _read(args.input) + _write(args.output, memory(data)) + elif args.action == 'stats': + data = _read(args.input) + print(segsum(data)) + elif args.action == 'trace': + data = _read(args.input) + print(trace(data)) + elif args.action == 'compare': + before = _read(args.before) + after = _read(args.after) + _write(args.output, compare(before, after)) + elif args.action == 'trace_plot': + data = _read(args.input) + _write(args.output, trace_plot(data, device=args.device, plot_segments=args.segments)) + elif args.action == 'segment_plot': + data = _read(args.input) + _write(args.output, segment_plot(data, device=args.device)) diff --git a/.venv/lib/python3.11/site-packages/torch/cuda/_sanitizer.py b/.venv/lib/python3.11/site-packages/torch/cuda/_sanitizer.py new file mode 100644 index 0000000000000000000000000000000000000000..ab034850858782ae0547c3d5308b24fd376c5fc9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/cuda/_sanitizer.py @@ -0,0 +1,621 @@ +# mypy: allow-untyped-defs +r""" +This module introduces CUDA Sanitizer, a tool for detecting synchronization errors between kernels ran on different streams. + +It stores information on accesses to tensors to determine if they are synchronized +or not. When enabled in a python program and a possible data race is detected, a +detailed warning will be printed and the program will exit. + +It can be enabled either by importing this module and calling +:func:`enable_cuda_sanitizer()` or by exporting the ``TORCH_CUDA_SANITIZER`` +environment variable. 
"""

import enum
import functools
import inspect
import io
import logging
import sys
import textwrap
import traceback
from dataclasses import dataclass, field
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, TypeVar

import torch
import torch.cuda._gpu_trace as gpu_trace
from torch.utils import _pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode


# Stream id used by CUDA's default (null) stream in the trace callbacks.
DEFAULT_STREAM_ID = 0

TK = TypeVar("TK")
TVa = TypeVar("TVa")
TVb = TypeVar("TVb")

# Readability aliases for the integer ids flowing through the sanitizer.
DataPtr = int
StreamId = int
EventId = int
SeqNum = int

logger = logging.getLogger(__name__)


class AccessType(enum.Enum):
    READ = enum.auto()
    WRITE = enum.auto()

    def __str__(self):
        # Phrased so it reads naturally inside the data-race report text.
        return "reading from" if self is AccessType.READ else "writing to"


@dataclass
class Access:
    r"""Stores information about a single access to a tensor by a kernel.

    Args:
        type: either AccessType.READ or AccessType.WRITE.
        seq_num: the sequential number of the kernel performing the access.
        stream: the stream id of the stream executing the kernel.
        operator: the schema of the launched kernel, which lists the
            arguments and return type.
        aliases: the arguments in the schema this access corresponds to.
        is_output: Whether the tensor was an output of the kernel.
        stack_trace: the stack summary object captured during access.
    """

    type: AccessType
    seq_num: SeqNum
    stream: StreamId
    operator: str
    aliases: List[str]
    is_output: bool
    stack_trace: traceback.StackSummary


class SynchronizationError(Exception):
    """Base class for errors detected by CUDA Sanitizer."""


class UnsynchronizedAccessError(SynchronizationError):
    """Stores information about two unsynchronized accesses to one data pointer."""

    def __init__(
        self,
        data_ptr: DataPtr,
        allocation_stack_trace: Optional[traceback.StackSummary],
        current_access: Access,
        previous_access: Access,
    ):
        self.data_ptr = data_ptr
        self.allocation_stack_trace = allocation_stack_trace
        self.current_access = current_access
        self.previous_access = previous_access

    def __str__(self):
        # format_access closes over `message`, which is bound by the `with`
        # block below before the first call.
        def format_access(access: Access):
            message.write(f"{access.operator}\n{access.type}")
            if access.aliases:
                message.write(" argument(s) " + ", ".join(access.aliases))
                if access.is_output:
                    message.write(", and to")
            if access.is_output:
                message.write(" the output")
            message.write(
                f"\nWith stack trace:\n{''.join(access.stack_trace.format())}\n"
            )

        with io.StringIO() as message:
            message.write(
                textwrap.dedent(
                    f"""\
                    ============================
                    CSAN detected a possible data race on tensor with data pointer {self.data_ptr}
                    Access by stream {self.current_access.stream} during kernel:
                    """
                )
            )
            format_access(self.current_access)

            message.write(
                f"Previous access by stream {self.previous_access.stream} during kernel:\n"
            )
            format_access(self.previous_access)

            if self.allocation_stack_trace:
                message.write(
                    "Tensor was allocated with stack trace:\n"
                    f"{''.join(self.allocation_stack_trace.format())}"
                )
            else:
                message.write("Trace for tensor allocation not found.")
            return message.getvalue()


class CUDASanitizerErrors(Exception):
    """Wrapper class for errors reported by CUDA Sanitizer."""

    def __init__(self, errors: List[SynchronizationError]):
        self.errors = errors

    def __str__(self):
        return f"detected {len(self.errors)} errors"


@dataclass
class TensorInfo:
    r"""Stores information about a single tensor and recent accesses to it.

    Args:
        allocation_stack_trace: the stack summary object captured during tensor
            allocation. Can be ``None`` if the allocation wasn't caught by CSAN.
        reads: list of read accesses to the tensor that were performed since
            the last write.
        write: the last write access to the tensor.
    """

    allocation_stack_trace: Optional[traceback.StackSummary]
    reads: List[Access] = field(default_factory=list)
    write: Optional[Access] = None


class _TensorsAccessed:
    """Tracks per-data-pointer access history, backfilling missing trace entries."""

    def __init__(self) -> None:
        self.accesses: Dict[DataPtr, TensorInfo] = {}

    def ensure_tensor_exists(self, data_ptr: DataPtr) -> None:
        if data_ptr not in self.accesses:
            logger.info(
                "Found tensor with pointer: %s, but no matching tensor "
                "allocation in the trace. Backfilling the trace now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
                data_ptr,
            )
            self.create_tensor(data_ptr, None)

    def ensure_tensor_does_not_exist(self, data_ptr: DataPtr) -> None:
        if data_ptr in self.accesses:
            logger.info(
                "Found duplicate tensor allocation in the trace for tensor with "
                "pointer: %s. Assuming the trace for tensor deallocation "
                "wasn't caught and backfilling it now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
                data_ptr,
            )
            self.delete_tensor(data_ptr)

    def create_tensor(
        self, data_ptr: DataPtr, stack_trace: Optional[traceback.StackSummary]
    ) -> None:
        self.accesses[data_ptr] = TensorInfo(stack_trace)

    def delete_tensor(self, data_ptr: DataPtr) -> None:
        del self.accesses[data_ptr]

    def were_there_reads_since_last_write(self, data_ptr: DataPtr) -> bool:
        return True if self.accesses[data_ptr].reads else False

    def get_allocation_stack_trace(
        self, data_ptr: DataPtr
    ) -> Optional[traceback.StackSummary]:
        return self.accesses[data_ptr].allocation_stack_trace

    def get_write(self, data_ptr: DataPtr) -> Optional[Access]:
        return self.accesses[data_ptr].write

    def get_reads(self, data_ptr: DataPtr) -> List[Access]:
        return self.accesses[data_ptr].reads

    def add_read(self, data_ptr: DataPtr, access: Access) -> None:
        self.accesses[data_ptr].reads.append(access)

    def set_write(self, data_ptr: DataPtr, access: Access) -> None:
        # A new write supersedes all reads recorded since the previous write.
        self.accesses[data_ptr].write = access
        self.accesses[data_ptr].reads = []


class StreamSynchronizations:
    """Vector-clock-style bookkeeping of which seq_nums each stream has synchronized with."""

    def __init__(self) -> None:
        # stream -> (stream -> latest seq_num known to have completed).
        self.current_sync_states: Dict[StreamId, Dict[StreamId, SeqNum]] = {}
        # event -> sync state captured when the event was recorded.
        self.recorded_sync_states: Dict[EventId, Dict[StreamId, SeqNum]] = {}
        self.host_sync_state: Dict[StreamId, SeqNum] = {}
        self.create_stream(DEFAULT_STREAM_ID)

    def _ensure_stream_exists(self, stream: StreamId) -> None:
        if stream not in self.current_sync_states:
            logger.info(
                "Found Stream with id: %s, but no matching stream "
                "creation in the trace. Backfilling the trace now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
                stream,
            )
            self.create_stream(stream)

    def _ensure_event_exists(self, event: EventId) -> None:
        if event not in self.recorded_sync_states:
            logger.info(
                "Found Event with id: %s, but no matching event "
                "creation in the trace. Backfilling the trace now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
                event,
            )
            self.create_event(event)

    def _ensure_event_does_not_exist(self, event: EventId) -> None:
        if event in self.recorded_sync_states:
            logger.info(
                "Found duplicate event creation in the trace for event with "
                "id: %s. Assuming the trace for event deletion wasn't caught "
                "and backfilling it now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
                event,
            )
            self.delete_event(event)

    def create_stream(self, stream: StreamId) -> None:
        if stream in self.current_sync_states:
            logger.info(
                "Found duplicate Stream creation in the trace for Stream with "
                "id: %s. PyTorch Streams are only created once, so this "
                "trace entry is ignored.",
                stream,
            )
        else:
            # A new stream starts synchronized with everything the host has seen.
            self.host_sync_state[stream] = 0
            self.current_sync_states[stream] = self.host_sync_state.copy()

    def create_event(self, event: EventId) -> None:
        self._ensure_event_does_not_exist(event)
        self.recorded_sync_states[event] = {}

    def delete_event(self, event: EventId) -> None:
        self._ensure_event_exists(event)
        del self.recorded_sync_states[event]

    def update_seq_num(self, stream: StreamId, seq_num: SeqNum) -> None:
        self._ensure_stream_exists(stream)
        self.current_sync_states[stream][stream] = seq_num

    def record_state(self, event: EventId, stream: StreamId) -> None:
        # Snapshot the stream's sync state into the event (cudaEventRecord analog).
        self._ensure_event_exists(event)
        self._ensure_stream_exists(stream)
        self.recorded_sync_states[event] = self.current_sync_states[stream].copy()

    def _state_wait_for_other(
        self, state: Dict[StreamId, SeqNum], other: Dict[StreamId, SeqNum]
    ) -> None:
        # Element-wise max merge: `state` now knows everything `other` knew.
        for stream, seq_num in other.items():
            state[stream] = max(state.get(stream, -1), seq_num)

    def stream_wait_for_event(self, stream: StreamId, event: EventId) -> None:
        self._ensure_stream_exists(stream)
        self._ensure_event_exists(event)
        self._state_wait_for_other(
            self.current_sync_states[stream], self.recorded_sync_states[event]
        )

    def all_streams_wait_for_event(self, event: EventId) -> None:
        self._ensure_event_exists(event)
        for stream in self.current_sync_states.keys():
            self.stream_wait_for_event(stream, event)

        self._state_wait_for_other(
            self.host_sync_state, self.recorded_sync_states[event]
        )

    def all_streams_wait_for_stream(self, stream: StreamId) -> None:
        self._ensure_stream_exists(stream)
        for state in self.current_sync_states.values():
            self._state_wait_for_other(state, self.current_sync_states[stream])

        self._state_wait_for_other(
            self.host_sync_state, self.current_sync_states[stream]
        )

    def sync_all_streams(self) -> None:
        # Device-wide synchronization: everyone learns everyone's progress.
        for stream, state in self.current_sync_states.items():
            self.host_sync_state[stream] = state[stream]

        for state in self.current_sync_states.values():
            self._state_wait_for_other(state, self.host_sync_state)

    def is_ordered_after(
        self, current_stream: StreamId, seq_num: SeqNum, other_stream: StreamId
    ) -> bool:
        self._ensure_stream_exists(current_stream)
        self._ensure_stream_exists(other_stream)
        return seq_num <= self.current_sync_states[current_stream].get(other_stream, -1)


class EventHandler:
    """Analyzes CSAN trace for synchronization errors.

    Stores information on each stream's synchronizations with other streams as well
    as tensor accesses to determine whether a given kernel launch might cause a
    data race.
    """

    def __init__(self) -> None:
        self.tensors_accessed = _TensorsAccessed()
        self.syncs = StreamSynchronizations()
        self.seq_num: SeqNum = 0

    def _handle_kernel_launch(
        self,
        stream: StreamId,
        read_only: Set[DataPtr],
        read_write: Set[DataPtr],
        outputs: Set[DataPtr],
        operator: str,
        tensor_aliases: Dict[int, List[str]],
    ) -> List[SynchronizationError]:
        # Check one kernel launch's accesses against recorded history and
        # return any data races found (does not raise itself).
        def check_conflict(
            data_ptr: DataPtr, current_access: Access, previous_access: Optional[Access]
        ) -> None:
            if previous_access is None:
                return
            if not self.syncs.is_ordered_after(
                current_access.stream, previous_access.seq_num, previous_access.stream
            ):
                error_list.append(
                    UnsynchronizedAccessError(
                        data_ptr,
                        self.tensors_accessed.get_allocation_stack_trace(data_ptr),
                        current_access,
                        previous_access,
                    )
                )

        error_list: List[SynchronizationError] = []
        self.seq_num += 1
        self.syncs.update_seq_num(stream, self.seq_num)
        stack_trace = traceback.StackSummary.extract(
            traceback.walk_stack(inspect.currentframe()), lookup_lines=False
        )
        # The stack trace generated in this way is in the inverse order, so it must be
        # reversed.
        stack_trace.reverse()

        # Reads conflict only with the last write ...
        for data_ptr in read_only:
            self.tensors_accessed.ensure_tensor_exists(data_ptr)
            current_access = Access(
                AccessType.READ,
                self.seq_num,
                stream,
                operator,
                tensor_aliases[data_ptr],
                data_ptr in outputs,
                stack_trace,
            )
            check_conflict(
                data_ptr, current_access, self.tensors_accessed.get_write(data_ptr)
            )
            self.tensors_accessed.add_read(data_ptr, current_access)

        # ... writes conflict with all reads since the last write, or with the
        # last write itself if there were none.
        for data_ptr in read_write:
            self.tensors_accessed.ensure_tensor_exists(data_ptr)
            current_access = Access(
                AccessType.WRITE,
                self.seq_num,
                stream,
                operator,
                tensor_aliases[data_ptr],
                data_ptr in outputs,
                stack_trace,
            )
            if self.tensors_accessed.were_there_reads_since_last_write(data_ptr):
                for previous_access in self.tensors_accessed.get_reads(data_ptr):
                    check_conflict(data_ptr, current_access, previous_access)
            else:
                check_conflict(
                    data_ptr, current_access, self.tensors_accessed.get_write(data_ptr)
                )
            self.tensors_accessed.set_write(data_ptr, current_access)

        return error_list

    def _handle_event_creation(self, event: EventId) -> None:
        self.syncs.create_event(event)

    def _handle_event_deletion(self, event: EventId) -> None:
        self.syncs.delete_event(event)

    def _handle_event_record(self, event: EventId, stream: StreamId) -> None:
        self.syncs.record_state(event, stream)

    def _handle_event_wait(self, event: EventId, stream: StreamId) -> None:
        self.syncs.stream_wait_for_event(stream, event)

    def _handle_memory_allocation(self, data_ptr: DataPtr) -> None:
        self.tensors_accessed.ensure_tensor_does_not_exist(data_ptr)
        stack_trace = traceback.StackSummary.extract(
            traceback.walk_stack(inspect.currentframe()), lookup_lines=False
        )
        # The stack trace generated in this way is in the inverse order, so it must be
        # reversed.
        stack_trace.reverse()
        self.tensors_accessed.create_tensor(
            data_ptr,
            stack_trace,
        )

    def _handle_memory_deallocation(self, data_ptr: DataPtr) -> None:
        self.tensors_accessed.ensure_tensor_exists(data_ptr)
        self.tensors_accessed.delete_tensor(data_ptr)

    def _handle_stream_creation(self, stream: StreamId) -> None:
        self.syncs.create_stream(stream)

    def _handle_device_synchronization(self) -> None:
        self.syncs.sync_all_streams()

    def _handle_stream_synchronization(self, stream: StreamId) -> None:
        self.syncs.all_streams_wait_for_stream(stream)

    def _handle_event_synchronization(self, event: EventId) -> None:
        self.syncs.all_streams_wait_for_event(event)


def zip_by_key(a: Dict[TK, TVa], b: Dict[TK, TVb]) -> Iterator[Tuple[TK, TVa, TVb]]:
    """Yield (key, a[key], b[key]) for every key present in both dicts."""
    for arg, value in a.items():
        if arg in b:
            yield arg, value, b[arg]


def zip_arguments(
    schema: torch.FunctionSchema, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Iterator[Tuple[torch.Argument, Any]]:
    """Pair each positional/keyword argument value with its schema Argument."""
    # Positional args match the schema in order; the remainder are keyword args.
    schema_args = schema.arguments[: len(args)]
    schema_kwargs = {arg.name: arg for arg in schema.arguments[len(args) :]}

    yield from zip(schema_args, args)

    for _, argument, value in zip_by_key(schema_kwargs, kwargs):
        yield (argument, value)


class ArgumentHandler:
    """Collects the CUDA-tensor data pointers an operator reads, writes, aliases, and outputs."""

    def __init__(self) -> None:
        self.dataptrs_read: Set[DataPtr] = set()
        self.dataptrs_written: Set[DataPtr] = set()
        self.tensor_aliases: Dict[DataPtr, List[str]] = {}
        self.outputs: Set[DataPtr] = set()

    def _handle_argument(
        self,
        value: Any,
        is_write: bool,
        name: Optional[str] = None,
        is_output: bool = False,
    ) -> None:
        # Only CUDA tensors participate in the race analysis.
        if isinstance(value, torch.Tensor) and value.is_cuda:
            data_ptr = value.data_ptr()
            if is_write:
                self.dataptrs_written.add(data_ptr)
            else:
                self.dataptrs_read.add(data_ptr)

            self.tensor_aliases.setdefault(data_ptr, [])
            if name is not None:
                self.tensor_aliases[data_ptr].append(name)
            if is_output:
                self.outputs.add(data_ptr)

    def parse_inputs(
        self,
        schema: torch.FunctionSchema,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> None:
        for argument, value in zip_arguments(schema, args, kwargs):
            # The schema's alias_info marks mutated (write) arguments.
            is_write = argument.alias_info is not None and argument.alias_info.is_write
            pytree.tree_map_(
                functools.partial(
                    self._handle_argument, is_write=is_write, name=argument.name
                ),
                value,
            )

    def parse_outputs(self, outputs: Any) -> None:
        pytree.tree_map_(
            functools.partial(self._handle_argument, is_write=True, is_output=True),
            outputs,
        )


class CUDASanitizerDispatchMode(TorchDispatchMode):
    """Dispatch mode that feeds every op launch and GPU trace event to the EventHandler."""

    def __init__(self) -> None:
        self.event_handler = EventHandler()
        torch._C._activate_gpu_trace()
        gpu_trace.register_callback_for_event_creation(
            self.event_handler._handle_event_creation
        )
        gpu_trace.register_callback_for_event_deletion(
            self.event_handler._handle_event_deletion
        )
        gpu_trace.register_callback_for_event_record(
            self.event_handler._handle_event_record
        )
        gpu_trace.register_callback_for_event_wait(
            self.event_handler._handle_event_wait
        )
        gpu_trace.register_callback_for_memory_allocation(
            self.event_handler._handle_memory_allocation
        )
        gpu_trace.register_callback_for_memory_deallocation(
            self.event_handler._handle_memory_deallocation
        )
        gpu_trace.register_callback_for_stream_creation(
            self.event_handler._handle_stream_creation
        )
        gpu_trace.register_callback_for_device_synchronization(
            self.event_handler._handle_device_synchronization
        )
        gpu_trace.register_callback_for_stream_synchronization(
            self.event_handler._handle_stream_synchronization
        )
        gpu_trace.register_callback_for_event_synchronization(
            self.event_handler._handle_event_synchronization
        )

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        argument_handler = ArgumentHandler()
        argument_handler.parse_inputs(func._schema, args, kwargs)

        outputs = func(*args, **kwargs)

        argument_handler.parse_outputs(outputs)
        errors = self.event_handler._handle_kernel_launch(
            torch.cuda.current_stream().cuda_stream,
            # Pointers both read and written are classified as read-write only.
            argument_handler.dataptrs_read - argument_handler.dataptrs_written,
            argument_handler.dataptrs_written,
            argument_handler.outputs,
            func._schema,
            argument_handler.tensor_aliases,
        )
        if errors:
            for error in errors:
                print(error, file=sys.stderr)
            raise CUDASanitizerErrors(errors)

        return outputs


class CUDASanitizer:
    """Manages the lifetime of a CUDASanitizer dispatch mode object.

    The CUDASanitizer class wraps the entering/exiting functions of the dispatch mode
    context manager in the enable function/destructor, respectively. This is to
    explicitly set the lifetime of the dispatch mode object to that of the application.
    This approach was deemed more elegant than using the atexit module.
    """

    def __init__(self) -> None:
        self.dispatch = CUDASanitizerDispatchMode()
        self.enabled = False

    def enable(self):
        self.dispatch.__enter__()
        self.enabled = True

    def __del__(self):
        # Exit the dispatch mode only if enable() was actually called.
        if self.enabled:
            self.dispatch.__exit__(None, None, None)


def enable_cuda_sanitizer():
    """Enable CUDA Sanitizer.

    The sanitizer will begin to analyze low-level CUDA calls invoked by torch functions
    for synchronization errors. All data races found will be printed to the standard
    error output along with stack traces of suspected causes. For best results, the
    sanitizer should be enabled at the very beginning of the program.
    """
    if isinstance(device, int):
        return device
    if isinstance(device, str):
        device = torch.device(device)
    if isinstance(device, torch.device):
        if allow_cpu:
            if device.type not in ["cuda", "cpu"]:
                raise ValueError(f"Expected a cuda or cpu device, but got: {device}")
        elif device.type != "cuda":
            raise ValueError(f"Expected a cuda device, but got: {device}")
    if not torch.jit.is_scripting():
        # torch.cuda.device context-manager objects carry the index in `.idx`.
        if isinstance(device, torch.cuda.device):
            return device.idx
    return _torch_get_device_index(device, optional, allow_cpu)
diff --git a/.venv/lib/python3.11/site-packages/torch/cuda/comm.py b/.venv/lib/python3.11/site-packages/torch/cuda/comm.py
new file mode 100644
index 0000000000000000000000000000000000000000..2915de5d090fd18c82540beedb9971a0b7b6cc3e
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torch/cuda/comm.py
@@ -0,0 +1,19 @@
# The functions here have been moved to torch.nn.parallel.comm
from torch.nn.parallel.comm import (
    broadcast,
    broadcast_coalesced,
    gather,
    reduce_add,
    reduce_add_coalesced,
    scatter,
)


__all__ = [
    "broadcast",
    "broadcast_coalesced",
    "reduce_add",
    "reduce_add_coalesced",
    "scatter",
    "gather",
]
diff --git a/.venv/lib/python3.11/site-packages/torch/cuda/error.py b/.venv/lib/python3.11/site-packages/torch/cuda/error.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.11/site-packages/torch/cuda/gds.py b/.venv/lib/python3.11/site-packages/torch/cuda/gds.py
new file mode 100644
index 0000000000000000000000000000000000000000..7cd5b8824103d2ebf4e217793f7d42923bb3c2a0
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torch/cuda/gds.py
@@ -0,0 +1,129 @@
import os
import sys
from typing import Callable, List, Optional

import torch
from torch.types import Storage


__all__: List[str] = []


def _dummy_fn(name: str) -> Callable:
    """Return a stand-in that raises when the native GDS binding is unavailable."""
    def fn(*args, **kwargs):  # type: ignore[no-untyped-def]
        raise RuntimeError(f"torch._C.{name} is not supported on this platform")

    return fn


# If PyTorch was built without GDS support, install dummy bindings so the
# wrappers below fail with a clear error instead of AttributeError.
if not hasattr(torch._C, "_gds_register_buffer"):
    assert not hasattr(torch._C, "_gds_deregister_buffer")
    assert not hasattr(torch._C, "_gds_register_handle")
    assert not hasattr(torch._C, "_gds_deregister_handle")
    assert not hasattr(torch._C, "_gds_load_storage")
    assert not hasattr(torch._C, "_gds_save_storage")
    # Define functions
    torch._C.__dict__["_gds_register_buffer"] = _dummy_fn("_gds_register_buffer")
    torch._C.__dict__["_gds_deregister_buffer"] = _dummy_fn("_gds_deregister_buffer")
    torch._C.__dict__["_gds_register_handle"] = _dummy_fn("_gds_register_handle")
    torch._C.__dict__["_gds_deregister_handle"] = _dummy_fn("_gds_deregister_handle")
    torch._C.__dict__["_gds_load_storage"] = _dummy_fn("_gds_load_storage")
    torch._C.__dict__["_gds_save_storage"] = _dummy_fn("_gds_save_storage")


def _gds_register_buffer(s: Storage) -> None:
    """Registers a buffer.

    Args:
        s (Storage): Buffer to register.
    """
    torch._C._gds_register_buffer(s)


def _gds_deregister_buffer(s: Storage) -> None:
    """Deregisters a buffer.

    Args:
        s (Storage): Buffer to deregister.
    """
    torch._C._gds_deregister_buffer(s)


class _GdsFile:
    r"""Wrapper around cuFile.

    cuFile is a file-like interface to the GPUDirect Storage (GDS) API.

    Args:
        filename (str): Name of the file to open.
        flags (int): Flags to pass to ``os.open`` when opening the file. ``os.O_DIRECT`` will
            be added automatically.

    ..
_CUDA GPUDirect Storage Documentation: + https://docs.nvidia.com/gpudirect-storage/api-reference-guide/index.html#cufile-io-api + """ + + def __init__(self, filename: str, flags: int): + if sys.platform == "win32": + raise RuntimeError("GdsFile is not supported on this platform.") + self.filename = filename + self.flags = flags + self.fd = os.open(filename, flags | os.O_DIRECT) + self.handle: Optional[int] = None + self.register_handle() + + def __del__(self) -> None: + if self.handle is not None: + self.deregister_handle() + os.close(self.fd) + + def register_handle(self) -> None: + """Registers file descriptor to cuFile Driver. + + This is a wrapper around ``cuFileHandleRegister``. + """ + assert ( + self.handle is None + ), "Cannot register a handle that is already registered." + self.handle = torch._C._gds_register_handle(self.fd) + + def deregister_handle(self) -> None: + """Deregisters file descriptor from cuFile Driver. + + This is a wrapper around ``cuFileHandleDeregister``. + """ + assert ( + self.handle is not None + ), "Cannot deregister a handle that is not registered." + torch._C._gds_deregister_handle(self.handle) + self.handle = None + + def load_storage(self, storage: Storage, offset: int = 0) -> None: + """Loads data from the file into the storage. + + This is a wrapper around ``cuFileRead``. ``storage.nbytes()`` of data + will be loaded from the file at ``offset`` into the storage. + + Args: + storage (Storage): Storage to load data into. + offset (int, optional): Offset into the file to start loading from. (Default: 0) + """ + assert ( + self.handle is not None + ), "Cannot load data from a file that is not registered." + torch._C._gds_load_storage(self.handle, storage, offset) + + def save_storage(self, storage: Storage, offset: int = 0) -> None: + """Saves data from the storage into the file. + + This is a wrapper around ``cuFileWrite``. All bytes of the storage + will be written to the file at ``offset``. 
+ + Args: + storage (Storage): Storage to save data from. + offset (int, optional): Offset into the file to start saving to. (Default: 0) + """ + assert ( + self.handle is not None + ), "Cannot save data to a file that is not registered." + torch._C._gds_save_storage(self.handle, storage, offset) diff --git a/.venv/lib/python3.11/site-packages/torch/cuda/graphs.py b/.venv/lib/python3.11/site-packages/torch/cuda/graphs.py new file mode 100644 index 0000000000000000000000000000000000000000..b5de9f73df726cc2d6e0fd5bee2bd178dcf6aa89 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/cuda/graphs.py @@ -0,0 +1,491 @@ +# mypy: allow-untyped-defs +import gc +import typing + +import torch + +from .._utils import _dummy_type + + +if not hasattr(torch._C, "_CudaStreamBase"): + # Define dummy base classes + torch._C.__dict__["_CUDAGraph"] = _dummy_type("_CUDAGraph") + torch._C.__dict__["_graph_pool_handle"] = _dummy_type("_graph_pool_handle") + torch._C.__dict__["_cuda_isCurrentStreamCapturing"] = _dummy_type( + "_cuda_isCurrentStreamCapturing" + ) + +from torch._C import ( # noqa: F401 + _cuda_isCurrentStreamCapturing, + _CUDAGraph, + _graph_pool_handle, +) + + +def is_current_stream_capturing(): + r"""Return True if CUDA graph capture is underway on the current CUDA stream, False otherwise. + + If a CUDA context does not exist on the current device, returns False without initializing the context. + """ + return _cuda_isCurrentStreamCapturing() + + +# Python shim helps Sphinx process docstrings more reliably. +def graph_pool_handle(): + r"""Return an opaque token representing the id of a graph memory pool. + + See :ref:`Graph memory management`. + + .. warning:: + This API is in beta and may change in future releases. + """ + return _graph_pool_handle() + + +# Python shim helps Sphinx process docstrings more reliably. +class CUDAGraph(torch._C._CUDAGraph): + r"""Wrapper around a CUDA graph. + + .. warning:: + This API is in beta and may change in future releases. 
+ """ + + def __new__(cls): + return super().__new__(cls) + + def capture_begin(self, pool=None, capture_error_mode="global"): + r"""Begin capturing CUDA work on the current stream. + + Typically, you shouldn't call ``capture_begin`` yourself. + Use :class:`~torch.cuda.graph` or :func:`~torch.cuda.make_graphed_callables`, + which call ``capture_begin`` internally. + + Arguments: + pool (optional): Token (returned by :func:`~torch.cuda.graph_pool_handle` or + :meth:`other_Graph_instance.pool()`) that hints this graph may share memory + with the indicated pool. See :ref:`Graph memory management`. + capture_error_mode (str, optional): specifies the cudaStreamCaptureMode for the graph capture stream. + Can be "global", "thread_local" or "relaxed". During cuda graph capture, some actions, such as cudaMalloc, + may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for + actions in the current thread, and "relaxed" will not error on these actions. Do NOT change this setting + unless you're familiar with `cudaStreamCaptureMode `_ + """ # noqa: B950 + super().capture_begin(pool=pool, capture_error_mode=capture_error_mode) + + def capture_end(self): + r"""End CUDA graph capture on the current stream. + + After ``capture_end``, ``replay`` may be called on this instance. + + Typically, you shouldn't call ``capture_end`` yourself. + Use :class:`~torch.cuda.graph` or :func:`~torch.cuda.make_graphed_callables`, + which call ``capture_end`` internally. + """ + super().capture_end() + + def replay(self): + r"""Replay the CUDA work captured by this graph.""" + super().replay() + + def reset(self): + r"""Delete the graph currently held by this instance.""" + super().reset() + + def pool(self): + r"""Return an opaque token representing the id of this graph's memory pool. + + This id can optionally be passed to another graph's ``capture_begin``, + which hints the other graph may share the same memory pool. 
+ """ + return super().pool() + + def enable_debug_mode(self): + r"""Enable debugging mode for CUDAGraph.debug_dump.""" + return super().enable_debug_mode() + + def debug_dump(self, debug_path): + r""" + Arguments: + debug_path (required): Path to dump the graph to. + + Calls a debugging function to dump the graph if the debugging is + enabled via CUDAGraph.enable_debug_mode() + """ + return super().debug_dump(debug_path) + + +class graph: + r"""Context-manager that captures CUDA work into a :class:`torch.cuda.CUDAGraph` object for later replay. + + See :ref:`CUDA Graphs ` for a general introduction, + detailed use, and constraints. + + Arguments: + cuda_graph (torch.cuda.CUDAGraph): Graph object used for capture. + pool (optional): Opaque token (returned by a call to :func:`~torch.cuda.graph_pool_handle()` or + :meth:`other_Graph_instance.pool()`) hinting this graph's capture + may share memory from the specified pool. See :ref:`Graph memory management`. + stream (torch.cuda.Stream, optional): If supplied, will be set as the current stream in the context. + If not supplied, ``graph`` sets its own internal side stream as the current stream in the context. + capture_error_mode (str, optional): specifies the cudaStreamCaptureMode for the graph capture stream. + Can be "global", "thread_local" or "relaxed". During cuda graph capture, some actions, such as cudaMalloc, + may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for + actions in the current thread, and "relaxed" will not error on actions. Do NOT change this setting + unless you're familiar with `cudaStreamCaptureMode `_ + + .. note:: + For effective memory sharing, if you pass a ``pool`` used by a previous capture and the previous capture + used an explicit ``stream`` argument, you should pass the same ``stream`` argument to this capture. + + .. warning:: + This API is in beta and may change in future releases. + + .. 
_cudaStreamCaptureMode: + https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85 + """ # noqa: B950 + + default_capture_stream: typing.Optional["torch.cuda.Stream"] = None + + def __init__( + self, + cuda_graph, + pool=None, + stream=None, + capture_error_mode: str = "global", + ): + # Lazy-init of default_capture_stream helps avoid circular-import errors. + # Not thread safe, but graphs already have the general (explicitly documented) + # restriction that only one capture may be underway at a time in the process. + if self.__class__.default_capture_stream is None: + self.__class__.default_capture_stream = torch.cuda.Stream() + + self.pool = () if pool is None else (pool,) + self.capture_stream = ( + stream if stream is not None else self.__class__.default_capture_stream + ) + assert self.capture_stream is not None + self.stream_ctx = torch.cuda.stream(self.capture_stream) + self.cuda_graph = cuda_graph + self.capture_error_mode = capture_error_mode + + def __enter__(self): + # Free as much memory as we can for the graph + torch.cuda.synchronize() + gc.collect() + torch.cuda.empty_cache() + + # Stackoverflow seems comfortable with this pattern + # https://stackoverflow.com/questions/26635684/calling-enter-and-exit-manually#39172487 + self.stream_ctx.__enter__() + + self.cuda_graph.capture_begin( + *self.pool, capture_error_mode=self.capture_error_mode + ) + + def __exit__(self, exc_type, exc_value, traceback): + self.cuda_graph.capture_end() + self.stream_ctx.__exit__(exc_type, exc_value, traceback) + # returning None should propagate exceptions from either capture_end or stream_ctx.__exit__() + + +def make_graphed_callables( + callables, sample_args, num_warmup_iters=3, allow_unused_input=False, pool=None +): + r"""Accept callables (functions or :class:`nn.Module`\ s) and returns graphed versions. 
+ + Each graphed callable's forward pass runs its source callable's + forward CUDA work as a CUDA graph inside a single autograd node. + + The graphed callable's forward pass also appends + a backward node to the autograd graph. During backward, this node runs the + callable's backward work as a CUDA graph. + + Therefore, each graphed callable should be a drop-in replacement for its source callable + in an autograd-enabled training loop. + + See :ref:`Partial-network capture` for detailed use and constraints. + + If you pass a tuple of several callables, their captures will use the same memory pool. + See :ref:`Graph memory management` for when this is appropriate. + + Arguments: + callables (torch.nn.Module or Python function, or tuple of these): Callable or callables to graph. + See :ref:`Graph memory management` for when passing a tuple of callables + is appropriate. If you pass a tuple of callables, their order in the tuple must be the same order + they'll run in the live workload. + sample_args (tuple of Tensors, or tuple of tuples of Tensors): Samples args for each callable. + If a single callable was passed, ``sample_args`` must be a single tuple of argument Tensors. + If a tuple of callables was passed, ``sample_args`` must be tuple of tuples of argument Tensors. + num_warmup_iters (int): The number of warmup iterations. Currently, ``DataDistributedParallel`` needs + 11 iterations for warm up. Default: ``3``. + allow_unused_input (bool): If False, specifying inputs that were not used when computing outputs + (and therefore their grad is always zero) is an error. Defaults to False. + pool (optional): Token (returned by :func:`~torch.cuda.graph_pool_handle` or + :meth:`other_Graph_instance.pool()`) that hints this graph may share memory + with the indicated pool. See :ref:`Graph memory management`. + .. 
note:: + The ``requires_grad`` state of each Tensor in ``sample_args`` must match the state + that's expected for the corresponding real input in the training loop. + + .. warning:: + This API is in beta and may change in future releases. + + .. warning:: + ``sample_args`` for each callable must contain only Tensors. Other types are not allowed. + + .. warning:: + Returned callables do not support higher order differentiation (e.g., double backward). + + .. warning:: + In any :class:`~torch.nn.Module` passed to :func:`~make_graphed_callables`, only parameters + may be trainable. Buffers must have ``requires_grad=False``. + + .. warning:: + After you pass a :class:`torch.nn.Module` through :func:`~make_graphed_callables`, + you may not add or remove any of that Module's parameters or buffers. + + .. warning:: + :class:`torch.nn.Module`\s passed to :func:`~torch.cuda.make_graphed_callables` must not have module hooks + registered on them at the time they are passed. However, registering hooks on modules *after* passing them + through :func:`~torch.cuda.make_graphed_callables` is allowed. + + .. warning:: + When running a graphed callable, you must pass its arguments in the same order and format + they appeared in that callable's ``sample_args``. + + .. warning:: + The automatic mixed precision is supported in :func:`~torch.cuda.make_graphed_callables` only with disabled + caching. The context manager `torch.cuda.amp.autocast()` must have `cache_enabled=False`. + """ + if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled(): + raise RuntimeError( + "make_graphed_callables does not support the autocast caching. Please set `cache_enabled=False`." 
+ ) + + just_one_callable = False + + if not isinstance(callables, tuple): + just_one_callable = True + callables = (callables,) + sample_args = (sample_args,) + + flatten_sample_args = [] + + for c, args in zip(callables, sample_args): + if isinstance(c, torch.nn.Module): + assert ( + len(c._backward_hooks) == 0 + and len(c._forward_hooks) == 0 + and len(c._forward_pre_hooks) == 0 + ), ( + "Modules must not have hooks registered at the time they are passed. However, registering hooks " + + "on modules after passing them through make_graphed_callables is allowed." + ) + assert all(b.requires_grad is False for b in c.buffers()), ( + "In any :class:`~torch.nn.Module` passed to " + + ":func:`~make_graphed_callables`, only parameters may be trainable. All buffers must have " + + "``requires_grad=False``." + ) + flatten_arg = torch.utils._pytree.arg_tree_leaves(*args) + flatten_sample_args.append(tuple(flatten_arg)) + assert all(isinstance(arg, torch.Tensor) for arg in flatten_arg), ( + "In the beta API, sample_args " + + "for each callable must contain only Tensors. Other types are not allowed." + ) + + # If a callable is an nn.Module, its graph's full input surface is the args the user explicitly + # passes to forward (ie, its sample_args) AND the module's parameter attributes. + per_callable_len_user_args = [len(args) for args in flatten_sample_args] + per_callable_module_params = [ + tuple(c.parameters()) if isinstance(c, torch.nn.Module) else () + for c in callables + ] + per_callable_static_input_surfaces = [ + flatten_sample_args[i] + per_callable_module_params[i] + for i in range(len(callables)) + ] + + fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))] + bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))] + + mempool = graph_pool_handle() if pool is None else pool + + # Warmup + # Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work + # from ending up in any captures. 
+ torch.cuda.synchronize() + with torch.cuda.stream(torch.cuda.Stream()): + for func, args, static_input_surface in zip( + callables, sample_args, per_callable_static_input_surfaces + ): + grad_inputs, outputs, outputs_grad = None, None, None + for _ in range(num_warmup_iters): + outputs = torch.utils._pytree.tree_leaves(func(*args)) + outputs_grad = tuple(o for o in outputs if o.requires_grad) + if len(outputs_grad) > 0: + grad_inputs = torch.autograd.grad( + outputs=outputs_grad, + inputs=tuple( + i for i in static_input_surface if i.requires_grad + ), + grad_outputs=tuple( + torch.empty_like(o) for o in outputs if o.requires_grad + ), + only_inputs=True, + allow_unused=allow_unused_input, + ) + for v in [outputs, outputs_grad, grad_inputs]: + del v + + torch.cuda.synchronize() + + # All captures here share a mempool. To avoid replays corrupting each other's memory, + # the safest approach is to capture all passes in the same order they'll run: + # fwd 1, fwd 2, ... fwd N, then bwd N, bwd N-1, ... bwd 1. + + # Capture forward graphs + per_callable_static_outputs = [] + per_callable_output_unflatten_spec = [] + for func, args, fwd_graph in zip(callables, sample_args, fwd_graphs): + with torch.cuda.graph(fwd_graph, pool=mempool): + outputs = func(*args) + + flatten_outputs, spec = torch.utils._pytree.tree_flatten(outputs) + per_callable_static_outputs.append(tuple(flatten_outputs)) + per_callable_output_unflatten_spec.append(spec) + + # Capture backward graphs in reverse order + per_callable_static_grad_outputs = [] + per_callable_static_grad_inputs = [] + for static_input_surface, static_outputs, bwd_graph, module_params in zip( + reversed(per_callable_static_input_surfaces), + reversed(per_callable_static_outputs), + reversed(bwd_graphs), + reversed(per_callable_module_params), + ): + # For now, assumes all static_outputs require grad + # assert all(o.requires_grad for o in static_outputs), "Outputs of graphed callables must require grad." 
+ static_grad_outputs = tuple( + torch.empty_like(o) if o.requires_grad else None for o in static_outputs + ) + + outputs_grad = tuple(o for o in static_outputs if o.requires_grad) + grad_inputs = None + if len(outputs_grad) > 0: + with torch.cuda.graph(bwd_graph, pool=mempool): + grad_inputs = torch.autograd.grad( + outputs=outputs_grad, + inputs=tuple(i for i in static_input_surface if i.requires_grad), + grad_outputs=tuple(o for o in static_grad_outputs if o is not None), + only_inputs=True, + allow_unused=allow_unused_input, + ) + + # Constructs a tuple suitable for returning from Graphed.backward: + # Pads out the actually-needed grads with Nones in gradient slots for inputs that don't require grad. + # I couldn't think of a slick one-liner for this pattern. + static_grad_inputs = [] + grad_idx = 0 + for arg in static_input_surface: + if arg.requires_grad and grad_inputs is not None: + static_grad_inputs.append(grad_inputs[grad_idx]) + grad_idx += 1 + else: + static_grad_inputs.append(None) # type: ignore[arg-type] + static_grad_inputs = tuple(static_grad_inputs) # type: ignore[assignment] + + per_callable_static_grad_outputs.append(static_grad_outputs) + per_callable_static_grad_inputs.append(static_grad_inputs) + + # Reverses the most recent two lists + per_callable_static_grad_outputs.reverse() + per_callable_static_grad_inputs.reverse() + # Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable. + + def make_graphed_autograd_function( + fwd_graph, + bwd_graph, + module_params, + len_user_args, + output_unflatten_spec, + static_input_surface, + static_outputs, + static_grad_outputs, + static_grad_inputs, + ): + class Graphed(torch.autograd.Function): + @staticmethod + def forward(ctx, *inputs): + # At this stage, only the user args may (potentially) be new tensors. 
+ for i in range(len_user_args): + if static_input_surface[i].data_ptr() != inputs[i].data_ptr(): + static_input_surface[i].copy_(inputs[i]) + fwd_graph.replay() + assert isinstance(static_outputs, tuple) + return tuple(o.detach() for o in static_outputs) + + @staticmethod + @torch.autograd.function.once_differentiable + def backward(ctx, *grads): + assert len(grads) == len(static_grad_outputs) + for g, grad in zip(static_grad_outputs, grads): + if g is not None: + # don't copy if autograd gods have been kind and the + # incoming grad is already in the right place + if g.data_ptr() != grad.data_ptr(): + g.copy_(grad) + bwd_graph.replay() + + # Input args that didn't require grad expect a None gradient. + assert isinstance(static_grad_inputs, tuple) + return tuple( + b.detach() if b is not None else b for b in static_grad_inputs + ) + + def functionalized(*user_args): + # Runs the autograd function with inputs == all inputs to the graph that might require grad + # (explicit user args + module parameters) + # Assumes module params didn't change since capture. 
+ flatten_user_args = torch.utils._pytree.arg_tree_leaves(*user_args) + out = Graphed.apply(*(tuple(flatten_user_args) + module_params)) + return torch.utils._pytree.tree_unflatten(out, output_unflatten_spec) + + return functionalized + + # Put together the final graphed callables + ret = [] + for i, func in enumerate(callables): + graphed = make_graphed_autograd_function( + fwd_graphs[i], + bwd_graphs[i], + per_callable_module_params[i], + per_callable_len_user_args[i], + per_callable_output_unflatten_spec[i], + per_callable_static_input_surfaces[i], + per_callable_static_outputs[i], + per_callable_static_grad_outputs[i], + per_callable_static_grad_inputs[i], + ) + + if isinstance(func, torch.nn.Module): + + def make_graphed_forward(func, graph_training_state, graphed, orig_fwd): + def new_fwd(*user_args): + # If the module's training-or-eval state matches what we graphed, + # run the graph, otherwise run the original forward method + if func.training == graph_training_state: + return graphed(*user_args) + else: + return orig_fwd(*user_args) + + return new_fwd + + func.forward = make_graphed_forward(func, func.training, graphed, func.forward) # type: ignore[assignment] + ret.append(func) + else: + ret.append(graphed) + + if just_one_callable: + return ret[0] + + return tuple(ret) diff --git a/.venv/lib/python3.11/site-packages/torch/cuda/jiterator.py b/.venv/lib/python3.11/site-packages/torch/cuda/jiterator.py new file mode 100644 index 0000000000000000000000000000000000000000..4f1e6393bac3d8424d501b975cd0b0bd958b6828 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/cuda/jiterator.py @@ -0,0 +1,187 @@ +# mypy: allow-untyped-defs +import re +from typing import Callable, List + +import torch +from torch import Tensor + + +__all__: List[str] = [] + + +class _CodeParser: + def __init__(self, code_string: str): + optional_ws = r"\s*" + required_ws = r"\s+" + template_params = r"(?P\<.+\>)" + return_type = r"(?P\w+)" + function_name = r"(?P\w+)" + 
function_params = r"(?P\(.+\))" + function_body = r"(?P\{.+\})" + + pattern = ( + optional_ws + + "template" + + optional_ws + + template_params + + optional_ws + + return_type + + required_ws + + function_name + + optional_ws + + function_params + + optional_ws + + function_body + + optional_ws + ) + + result = re.match( + pattern, code_string, re.DOTALL + ) # DOTALL for matching multiline + + if result is None: + raise Exception( # noqa: TRY002 + f"Couldn't parse code, please check correctness:\n {code_string}" + ) + + self.template_params = result["template_params"] + self.return_type = result["return_type"] + self.function_name = result["function_name"] + self.function_params = result["function_params"] + self.function_body = result["function_body"] + + +class _JittedFunction: + def __init__( + self, code_string: str, return_by_ref: bool, num_outputs: int, **kwargs + ): + self.code_string = code_string + + assert ( + return_by_ref or num_outputs == 1 + ), "Return by value only works for single output. " + self.return_by_ref = return_by_ref + self.num_outputs = num_outputs + + parsed_code = _CodeParser(code_string) + self.kernel_name = parsed_code.function_name + + self.kwargs_dict = kwargs + self.is_cuda_available = torch.cuda.is_available() + + def __call__(self, *tensors: Tensor, **kwargs): + # Jiterator follow torch.cuda's lazy initialization behavior + # Defer checking cuda's availability at the function invocation time + assert ( + self.is_cuda_available + ), "Jiterator is only supported on CUDA and ROCm GPUs, none are available." + + assert len(tensors) <= 8, "jiterator only supports up to 8 tensor inputs." 
+ + expanded_kwargs = self.kwargs_dict.copy() + for key, value in kwargs.items(): + if key in self.kwargs_dict: + expanded_kwargs[key] = value + else: + raise KeyError(f"{key} is not declared in function definition") + + return torch._C._cuda_jiterator_compile_and_launch_kernel( + self.code_string, + self.kernel_name, + self.return_by_ref, + self.num_outputs, + tensors, + expanded_kwargs, + ) + + +def _create_jit_fn(code_string: str, **kwargs) -> Callable: + """ + Create a jiterator-generated cuda kernel for an elementwise op. + + The code string has to be a valid CUDA function that describes the computation for a single element. The code + string has to follow the c++ template pattern, as shown in the example below. This function will be inlined + into elementwise kernel template, and compiled on the fly. Compiled kernel will be cached in memory, as well as + local temp dir. + + Jiterator-generated kernels accepts noncontiguous tensors, and supports broadcasting and type promotion. + + Args: + code_string (str): CUDA code string to be compiled by jiterator. The entry functor must return by value. + kwargs (Dict, optional): Keyword arguments for generated function + + Example:: + + code_string = "template T my_kernel(T x, T y, T alpha) { return -x + alpha * y; }" + jitted_fn = create_jit_fn(code_string, alpha=1.0) + a = torch.rand(3, device='cuda') + b = torch.rand(3, device='cuda') + # invoke jitted function like a regular python function + result = jitted_fn(a, b, alpha=3.14) + + code_string also allows multiple function definitions, and the last function will be treated as the entry function. 
+ + Example:: + + code_string = "template T util_fn(T x, T y) { return ::sin(x) + ::cos(y); }" + code_string += "template T my_kernel(T x, T y, T val) { return ::min(val, util_fn(x, y)); }" + jitted_fn = create_jit_fn(code_string, val=0.0) + a = torch.rand(3, device='cuda') + b = torch.rand(3, device='cuda') + # invoke jitted function like a regular python function + result = jitted_fn(a, b) # using default val=0.0 + + Jiterator can be used together with python registration to override an operator's cuda kernel. + Following example is overriding gelu's cuda kernel with relu. + + Example:: + + code_string = "template T my_gelu(T a) { return a > 0 ? a : 0; }" + my_gelu = create_jit_fn(code_string) + my_lib = torch.library.Library("aten", "IMPL") + my_lib.impl('aten::gelu', my_gelu, "CUDA") + # torch.nn.GELU and torch.nn.function.gelu are now overridden + a = torch.rand(3, device='cuda') + torch.allclose(torch.nn.functional.gelu(a), torch.nn.functional.relu(a)) + + .. warning:: + This API is in beta and may change in future releases. + + .. warning:: + This API only supports up to 8 inputs and 1 output + + .. warning:: + All input tensors must live in CUDA device + """ + return _JittedFunction(code_string, return_by_ref=False, num_outputs=1, **kwargs) + + +def _create_multi_output_jit_fn( + code_string: str, num_outputs: int, **kwargs +) -> Callable: + """ + Create a jiterator-generated cuda kernel for an elementwise op that supports returning one or more outputs. + + Args: + code_string (str): CUDA code string to be compiled by jiterator. The entry functor must return value by reference. 
+ num_outputs(int): number of outputs return by the kernel + kwargs (Dict, optional): Keyword arguments for generated function + + Example:: + + code_string = "template void my_kernel(T x, T y, T alpha, T& out) { out = -x + alpha * y; }" + jitted_fn = create_jit_fn(code_string, alpha=1.0) + a = torch.rand(3, device='cuda') + b = torch.rand(3, device='cuda') + # invoke jitted function like a regular python function + result = jitted_fn(a, b, alpha=3.14) + + .. warning:: + This API is in beta and may change in future releases. + + .. warning:: + This API only supports up to 8 inputs and 8 outputs + """ + return _JittedFunction( + code_string, return_by_ref=True, num_outputs=num_outputs, **kwargs + ) diff --git a/.venv/lib/python3.11/site-packages/torch/cuda/memory.py b/.venv/lib/python3.11/site-packages/torch/cuda/memory.py new file mode 100644 index 0000000000000000000000000000000000000000..af2c8d480c8345430ddc082a19b1a1f5b771364f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/cuda/memory.py @@ -0,0 +1,1041 @@ +# mypy: allow-untyped-defs +r"""This package adds support for device memory management implemented in CUDA.""" + +import collections +import contextlib +import ctypes +import pickle +import sys +import warnings +from inspect import signature +from typing import Any, Dict, Optional, Tuple, Union +from typing_extensions import deprecated + +import torch +from torch import _C +from torch._utils import _dummy_type +from torch.types import Device + +from . 
import ( + _get_amdsmi_device_index, + _get_device_index, + _get_nvml_device_index, + _lazy_init, + is_initialized, +) +from ._memory_viz import memory as _memory, segments as _segments + + +__all__ = [ + "caching_allocator_alloc", + "caching_allocator_delete", + "set_per_process_memory_fraction", + "empty_cache", + "memory_stats", + "memory_stats_as_nested_dict", + "reset_accumulated_memory_stats", + "reset_peak_memory_stats", + "reset_max_memory_allocated", + "reset_max_memory_cached", + "memory_allocated", + "max_memory_allocated", + "memory_reserved", + "max_memory_reserved", + "memory_cached", + "max_memory_cached", + "memory_snapshot", + "memory_summary", + "list_gpu_processes", + "mem_get_info", + "get_allocator_backend", + "CUDAPluggableAllocator", + "change_current_allocator", + "MemPool", + "MemPoolContext", + "use_mem_pool", +] + + +if not hasattr(torch._C, "_cuda_CUDAAllocator"): + # Define dummy base classes + torch._C.__dict__["_cuda_CUDAAllocator"] = _dummy_type("_cuda_CUDAAllocator") + + +if not hasattr(torch._C, "_MemPool"): + # Define dummy base classes + torch._C.__dict__["_MemPool"] = _dummy_type("_MemPool") + torch._C.__dict__["_MemPoolContext"] = _dummy_type("_MemPoolContext") + torch._C.__dict__["_cuda_beginAllocateToPool"] = _dummy_type( + "_cuda_beginAllocateToPool" + ) + torch._C.__dict__["_cuda_endAllocateCurrentStreamToPool"] = _dummy_type( + "_cuda_endAllocateCurrentStreamToPool" + ) + +from torch._C import ( # noqa: F401 + _cuda_beginAllocateToPool, + _cuda_CUDAAllocator, + _cuda_endAllocateCurrentStreamToPool, + _MemPool, + _MemPoolContext, +) + + +def _host_allocator(): + _lazy_init() + return torch._C._cuda_cudaHostAllocator() + + +@contextlib.contextmanager +def _free_mutex(): + torch._C._cuda_lock_mutex() + try: + yield + finally: + torch._C._cuda_unlock_mutex() + + +def caching_allocator_alloc(size, device: Union[Device, int] = None, stream=None): + r"""Perform a memory allocation using the CUDA memory allocator. 
+
+    Memory is allocated for a given device and a stream, this
+    function is intended to be used for interoperability with other
+    frameworks. Allocated memory is released through
+    :func:`~torch.cuda.caching_allocator_delete`.
+
+    Args:
+        size (int): number of bytes to be allocated.
+        device (torch.device or int, optional): selected device. If it is
+            ``None`` the default CUDA device is used.
+        stream (torch.cuda.Stream or int, optional): selected stream. If it is ``None`` then
+            the default stream for the selected device is used.
+
+    .. note::
+        See :ref:`cuda-memory-management` for more details about GPU memory
+        management.
+    """
+    if device is None:
+        device = torch.cuda.current_device()
+    device = _get_device_index(device)
+    if stream is None:
+        stream = torch.cuda.current_stream(device)
+    if isinstance(stream, torch.cuda.streams.Stream):
+        stream = stream.cuda_stream
+    if not isinstance(stream, int):
+        raise TypeError(
+            "Invalid type for stream argument, must be "
+            "`torch.cuda.Stream` or `int` representing a pointer "
+            "to a existing stream"
+        )
+    with torch.cuda.device(device):
+        return torch._C._cuda_cudaCachingAllocator_raw_alloc(size, stream)
+
+
+def caching_allocator_delete(mem_ptr):
+    r"""Delete memory allocated using the CUDA memory allocator.
+
+    Memory allocated with :func:`~torch.cuda.caching_allocator_alloc`
+    is freed here. The associated device and stream are tracked inside
+    the allocator.
+
+    Args:
+        mem_ptr (int): memory address to be freed by the allocator.
+
+    .. note::
+        See :ref:`cuda-memory-management` for more details about GPU memory
+        management.
+    """
+    torch._C._cuda_cudaCachingAllocator_raw_delete(mem_ptr)
+
+
+def set_per_process_memory_fraction(
+    fraction, device: Union[Device, int] = None
+) -> None:
+    r"""Set memory fraction for a process.
+
+    The fraction is used to limit a caching allocator to the allocated memory on a CUDA device.
+    The allowed value equals the total visible memory multiplied by the fraction. 
+ If trying to allocate more than the allowed value in a process, will raise an out of + memory error in allocator. + + Args: + fraction(float): Range: 0~1. Allowed memory equals total_memory * fraction. + device (torch.device or int, optional): selected device. If it is + ``None`` the default CUDA device is used. + .. note:: + In general, the total available free memory is less than the total capacity. + """ + _lazy_init() + if device is None: + device = torch.cuda.current_device() + device = _get_device_index(device) + if not isinstance(fraction, float): + raise TypeError("Invalid type for fraction argument, must be `float`") + if fraction < 0 or fraction > 1: + raise ValueError(f"Invalid fraction value: {fraction}. Allowed range: 0~1") + + torch._C._cuda_setMemoryFraction(fraction, device) + + +def empty_cache() -> None: + r"""Release all unoccupied cached memory currently held by the caching + allocator so that those can be used in other GPU application and visible in + `nvidia-smi`. + + .. note:: + :func:`~torch.cuda.empty_cache` doesn't increase the amount of GPU + memory available for PyTorch. However, it may help reduce fragmentation + of GPU memory in certain cases. See :ref:`cuda-memory-management` for + more details about GPU memory management. + """ + if is_initialized(): + torch._C._cuda_emptyCache() + + +def memory_stats(device: Union[Device, int] = None) -> Dict[str, Any]: + r"""Return a dictionary of CUDA memory allocator statistics for a given device. + + The return value of this function is a dictionary of statistics, each of + which is a non-negative integer. + + Core statistics: + + - ``"allocated.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + number of allocation requests received by the memory allocator. + - ``"allocated_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + amount of allocated memory. 
+ - ``"segment.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + number of reserved segments from ``cudaMalloc()``. + - ``"reserved_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + amount of reserved memory. + - ``"active.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + number of active memory blocks. + - ``"active_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + amount of active memory. + - ``"inactive_split.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + number of inactive, non-releasable memory blocks. + - ``"inactive_split_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + amount of inactive, non-releasable memory. + + For these core statistics, values are broken down as follows. + + Pool type: + + - ``all``: combined statistics across all memory pools. + - ``large_pool``: statistics for the large allocation pool + (as of October 2019, for size >= 1MB allocations). + - ``small_pool``: statistics for the small allocation pool + (as of October 2019, for size < 1MB allocations). + + Metric type: + + - ``current``: current value of this metric. + - ``peak``: maximum value of this metric. + - ``allocated``: historical total increase in this metric. + - ``freed``: historical total decrease in this metric. + + In addition to the core statistics, we also provide some simple event + counters: + + - ``"num_alloc_retries"``: number of failed ``cudaMalloc`` calls that + result in a cache flush and retry. + - ``"num_ooms"``: number of out-of-memory errors thrown. + - ``"num_sync_all_streams"``: number of ``synchronize_and_free_events`` calls. + - ``"num_device_alloc"``: number of CUDA allocation calls. This includes both + cuMemMap and cudaMalloc. + - ``"num_device_free"``: number of CUDA free calls. This includes both cuMemUnmap + and cudaFree. 
+ + The caching allocator can be configured via ENV to not split blocks larger than a + defined size (see Memory Management section of the Cuda Semantics documentation). + This helps avoid memory fragmentation but may have a performance + penalty. Additional outputs to assist with tuning and evaluating impact: + + - ``"max_split_size"``: blocks above this size will not be split. + - ``"oversize_allocations.{current,peak,allocated,freed}"``: + number of over-size allocation requests received by the memory allocator. + - ``"oversize_segments.{current,peak,allocated,freed}"``: + number of over-size reserved segments from ``cudaMalloc()``. + + The caching allocator can be configured via ENV to round memory allocations in order + to reduce fragmentation. Sometimes the overhead from rounding can be higher than + the fragmentation it helps reduce. The following stat can be used to check if + rounding adds too much overhead: + + - ``"requested_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + memory requested by client code, compare this with allocated_bytes to check if + allocation rounding adds too much overhead. + + Args: + device (torch.device or int, optional): selected device. Returns + statistics for the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + + .. note:: + See :ref:`cuda-memory-management` for more details about GPU memory + management. + + .. note:: + With :ref:`backend:cudaMallocAsync`, some stats are not + meaningful, and are always reported as zero. + """ + result = [] + + def _recurse_add_to_result(prefix, obj): + if isinstance(obj, dict): + if len(prefix) > 0: + prefix += "." 
+ for k, v in obj.items(): + _recurse_add_to_result(prefix + k, v) + else: + result.append((prefix, obj)) + + stats = memory_stats_as_nested_dict(device=device) + _recurse_add_to_result("", stats) + result.sort() + + return collections.OrderedDict(result) + + +def memory_stats_as_nested_dict(device: Union[Device, int] = None) -> Dict[str, Any]: + r"""Return the result of :func:`~torch.cuda.memory_stats` as a nested dictionary.""" + if not is_initialized(): + return {} + device = _get_device_index(device, optional=True) + return torch._C._cuda_memoryStats(device) + + +def reset_accumulated_memory_stats(device: Union[Device, int] = None) -> None: + r"""Reset the "accumulated" (historical) stats tracked by the CUDA memory allocator. + + See :func:`~torch.cuda.memory_stats` for details. Accumulated stats correspond to + the `"allocated"` and `"freed"` keys in each individual stat dict, as well as + `"num_alloc_retries"` and `"num_ooms"`. + + Args: + device (torch.device or int, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + + .. note:: + See :ref:`cuda-memory-management` for more details about GPU memory + management. + """ + device = _get_device_index(device, optional=True) + return torch._C._cuda_resetAccumulatedMemoryStats(device) + + +def reset_peak_memory_stats(device: Union[Device, int] = None) -> None: + r"""Reset the "peak" stats tracked by the CUDA memory allocator. + + See :func:`~torch.cuda.memory_stats` for details. Peak stats correspond to the + `"peak"` key in each individual stat dict. + + Args: + device (torch.device or int, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + + .. note:: + See :ref:`cuda-memory-management` for more details about GPU memory + management. 
+ """ + device = _get_device_index(device, optional=True) + return torch._C._cuda_resetPeakMemoryStats(device) + + +def reset_max_memory_allocated(device: Union[Device, int] = None) -> None: + r"""Reset the starting point in tracking maximum GPU memory occupied by tensors for a given device. + + See :func:`~torch.cuda.max_memory_allocated` for details. + + Args: + device (torch.device or int, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + + .. warning:: + This function now calls :func:`~torch.cuda.reset_peak_memory_stats`, which resets + /all/ peak memory stats. + + .. note:: + See :ref:`cuda-memory-management` for more details about GPU memory + management. + """ + warnings.warn( + "torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, " + "which resets /all/ peak memory stats.", + FutureWarning, + ) + return reset_peak_memory_stats(device=device) + + +def reset_max_memory_cached(device: Union[Device, int] = None) -> None: + r"""Reset the starting point in tracking maximum GPU memory managed by the caching allocator for a given device. + + See :func:`~torch.cuda.max_memory_cached` for details. + + Args: + device (torch.device or int, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + + .. warning:: + This function now calls :func:`~torch.cuda.reset_peak_memory_stats`, which resets + /all/ peak memory stats. + + .. note:: + See :ref:`cuda-memory-management` for more details about GPU memory + management. 
+ """ + warnings.warn( + "torch.cuda.reset_max_memory_cached now calls torch.cuda.reset_peak_memory_stats, " + "which resets /all/ peak memory stats.", + FutureWarning, + ) + return reset_peak_memory_stats(device=device) + + +def memory_allocated(device: Union[Device, int] = None) -> int: + r"""Return the current GPU memory occupied by tensors in bytes for a given device. + + Args: + device (torch.device or int, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + + .. note:: + This is likely less than the amount shown in `nvidia-smi` since some + unused memory can be held by the caching allocator and some context + needs to be created on GPU. See :ref:`cuda-memory-management` for more + details about GPU memory management. + """ + return memory_stats(device=device).get("allocated_bytes.all.current", 0) + + +def max_memory_allocated(device: Union[Device, int] = None) -> int: + r"""Return the maximum GPU memory occupied by tensors in bytes for a given device. + + By default, this returns the peak allocated memory since the beginning of + this program. :func:`~torch.cuda.reset_peak_memory_stats` can be used to + reset the starting point in tracking this metric. For example, these two + functions can measure the peak allocated memory usage of each iteration in a + training loop. + + Args: + device (torch.device or int, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + + .. note:: + See :ref:`cuda-memory-management` for more details about GPU memory + management. + """ + return memory_stats(device=device).get("allocated_bytes.all.peak", 0) + + +def memory_reserved(device: Union[Device, int] = None) -> int: + r"""Return the current GPU memory managed by the caching allocator in bytes for a given device. 
+ + Args: + device (torch.device or int, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + + .. note:: + See :ref:`cuda-memory-management` for more details about GPU memory + management. + """ + return memory_stats(device=device).get("reserved_bytes.all.current", 0) + + +def max_memory_reserved(device: Union[Device, int] = None) -> int: + r"""Return the maximum GPU memory managed by the caching allocator in bytes for a given device. + + By default, this returns the peak cached memory since the beginning of this + program. :func:`~torch.cuda.reset_peak_memory_stats` can be used to reset + the starting point in tracking this metric. For example, these two functions + can measure the peak cached memory amount of each iteration in a training + loop. + + Args: + device (torch.device or int, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + + .. note:: + See :ref:`cuda-memory-management` for more details about GPU memory + management. + """ + return memory_stats(device=device).get("reserved_bytes.all.peak", 0) + + +@deprecated( + "`torch.cuda.memory_cached` has been renamed to `torch.cuda.memory_reserved`", + category=FutureWarning, +) +def memory_cached(device: Union[Device, int] = None) -> int: + r"""Deprecated; see :func:`~torch.cuda.memory_reserved`.""" + return memory_reserved(device=device) + + +@deprecated( + "`torch.cuda.max_memory_cached` has been renamed to `torch.cuda.max_memory_reserved`", + category=FutureWarning, +) +def max_memory_cached(device: Union[Device, int] = None) -> int: + r"""Deprecated; see :func:`~torch.cuda.max_memory_reserved`.""" + return max_memory_reserved(device=device) + + +def memory_snapshot(): + r"""Return a snapshot of the CUDA memory allocator state across all devices. 
+ + Interpreting the output of this function requires familiarity with the + memory allocator internals. + + .. note:: + See :ref:`cuda-memory-management` for more details about GPU memory + management. + """ + return torch._C._cuda_memorySnapshot()["segments"] + + +def memory_summary(device: Union[Device, int] = None, abbreviated: bool = False) -> str: + r"""Return a human-readable printout of the current memory allocator statistics for a given device. + + This can be useful to display periodically during training, or when + handling out-of-memory exceptions. + + Args: + device (torch.device or int, optional): selected device. Returns + printout for the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + abbreviated (bool, optional): whether to return an abbreviated summary + (default: False). + + .. note:: + See :ref:`cuda-memory-management` for more details about GPU memory + management. + """ + device = _get_device_index(device, optional=True) + stats = memory_stats(device=device) + + def _format_size(sz, pref_sz): + prefixes = ["B ", "KiB", "MiB", "GiB", "TiB", "PiB"] + prefix = prefixes[0] + for new_prefix in prefixes[1:]: + if pref_sz < 768 * 1024: + break + prefix = new_prefix + sz //= 1024 + pref_sz /= 1024 + return f"{sz:6d} {prefix}" + + def _format_count(cnt, pref_cnt): + prefixes = [" ", "K", "M"] + prefix = prefixes[0] + for new_prefix in prefixes[1:]: + if pref_cnt < 750 * 1000: + break + prefix = new_prefix + cnt //= 1000 + pref_cnt /= 1000 + return f"{cnt:7d} {prefix} " + + metrics_to_display = [ + ("allocated_bytes", "Allocated memory", _format_size), + ("active_bytes", "Active memory", _format_size), + ("requested_bytes", "Requested memory", _format_size), + ("reserved_bytes", "GPU reserved memory", _format_size), + ("inactive_split_bytes", "Non-releasable memory", _format_size), + ("allocation", "Allocations", _format_count), + ("active", "Active allocs", _format_count), + ("segment", "GPU 
reserved segments", _format_count), + ("inactive_split", "Non-releasable allocs", _format_count), + ] + + lines = [] + lines.append("=" * 75) + lines.append(" {_:16} PyTorch CUDA memory summary, device ID {device:<17d} ") + lines.append("-" * 75) + lines.append( + " {_:9} CUDA OOMs: {num_ooms:<12d} | {_:6} cudaMalloc retries: {num_alloc_retries:<8d} " + ) + lines.append("=" * 75) + lines.append( + " Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed " + ) + + for metric_key, metric_name, formatter in metrics_to_display: + lines.append("-" * 75) + submetrics = [("all", metric_name)] + if not abbreviated: + submetrics.append(("large_pool", " from large pool")) + submetrics.append(("small_pool", " from small pool")) + + current_prefval, peak_prefval, allocated_prefval, freed_prefval = ( + None, + None, + None, + None, + ) + + for submetric_key, submetric_name in submetrics: + prefix = metric_key + "." + submetric_key + "." + + current = stats[prefix + "current"] + peak = stats[prefix + "peak"] + allocated = stats[prefix + "allocated"] + freed = stats[prefix + "freed"] + + if current_prefval is None: + current_prefval = current + peak_prefval = peak + allocated_prefval = allocated + freed_prefval = freed + + lines.append( + f" {submetric_name:<21} | {formatter(current, current_prefval)} | {formatter(peak, peak_prefval)} | " + f"{formatter(allocated, allocated_prefval)} | {formatter(freed, freed_prefval)} ", + ) + + metrics_to_display = [ + ("oversize_allocations", "Oversize allocations", _format_count), + ("oversize_segments", "Oversize GPU segments", _format_count), + ] + + for metric_key, metric_name, formatter in metrics_to_display: + lines.append("-" * 75) + + prefix = metric_key + "." 
+ + current = stats[prefix + "current"] + peak = stats[prefix + "peak"] + allocated = stats[prefix + "allocated"] + freed = stats[prefix + "freed"] + + lines.append( + f" {metric_name:<21} | {formatter(current, current)} | {formatter(peak, peak)} | " + f"{formatter(allocated, allocated)} | {formatter(freed, freed)} ", + ) + + lines.append("=" * 75) + + fmt_dict = {"_": "", "device": device} + for k, v in stats.items(): + fmt_dict[k.replace(".", "-")] = v + return "|" + "|\n|".join(lines).format(**fmt_dict) + "|\n" + + +def list_gpu_processes(device: Union[Device, int] = None) -> str: + r"""Return a human-readable printout of the running processes and their GPU memory use for a given device. + + This can be useful to display periodically during training, or when + handling out-of-memory exceptions. + + Args: + device (torch.device or int, optional): selected device. Returns + printout for the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + """ + if not torch.version.hip: + try: + import pynvml # type: ignore[import] + except ModuleNotFoundError: + return "pynvml module not found, please install pynvml" + from pynvml import NVMLError_DriverNotLoaded + + try: + pynvml.nvmlInit() + except NVMLError_DriverNotLoaded: + return "cuda driver can't be loaded, is cuda enabled?" + + device = _get_nvml_device_index(device) + handle = pynvml.nvmlDeviceGetHandleByIndex(device) + procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle) + else: + try: + import amdsmi # type: ignore[import] + except ModuleNotFoundError: + return "amdsmi module not found, please install amdsmi" + try: + amdsmi.amdsmi_init() # type: ignore[attr-defined] + except amdsmi.AmdSmiException: # type: ignore[attr-defined] + return "amdsmi driver can't be loaded, is ROCm installed?" 
+ + device = _get_amdsmi_device_index(device) + + try: + handle = amdsmi.amdsmi_get_processor_handles()[device] # type: ignore[attr-defined] + procs = amdsmi.amdsmi_get_gpu_process_list(handle) # type: ignore[attr-defined] + except amdsmi.AmdSmiException: # type: ignore[attr-defined] + return "amdsmi cannot list processes from other users" + + lines = [] + lines.append(f"GPU:{device}") + if len(procs) == 0: + lines.append("no processes are running") + for p in procs: + if not torch.version.hip: + mem = p.usedGpuMemory / (1024 * 1024) + pid = p.pid + else: + try: + proc_info = amdsmi.amdsmi_get_gpu_process_info(handle, p) # type: ignore[possibly-undefined] + except AttributeError: + # https://github.com/ROCm/amdsmi/commit/c551c3caedbd903ba828e7fdffa5b56d475a15e7 + # is a BC-breaking change that removes amdsmi_get_gpu_process_info API from amdsmi + proc_info = p + mem = proc_info["memory_usage"]["vram_mem"] / (1024 * 1024) + pid = proc_info["pid"] + lines.append(f"process {pid:>10d} uses {mem:>12.3f} MB GPU memory") + return "\n".join(lines) + + +def mem_get_info(device: Union[Device, int] = None) -> Tuple[int, int]: + r"""Return the global free and total GPU memory for a given device using cudaMemGetInfo. + + Args: + device (torch.device or int or str, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default) or if the device index is not specified. + + .. note:: + See :ref:`cuda-memory-management` for more + details about GPU memory management. 
+ """ + if device is None: + device = torch.cuda.current_device() + # optional=True allows `device = torch.device('cuda')` for which device.index is None + device = _get_device_index(device, optional=True) + return torch.cuda.cudart().cudaMemGetInfo(device) + + +def _record_memory_history_legacy( + enabled: bool, + record_context=True, + trace_alloc_max_entries=1, + trace_alloc_record_context=False, + device: Union[Device, int] = None, + record_context_cpp=False, +): + _C._cuda_record_memory_history_legacy( + enabled, + record_context, + trace_alloc_max_entries, + trace_alloc_record_context, + record_context_cpp, + ) + + +def _record_memory_history(enabled="all", *args, **kwargs): + """Enable recording of stack traces associated with memory + allocations, so you can tell what allocated any piece of memory in + :func:`torch.cuda.memory._snapshot()`. + + In addition too keeping stack traces with each current allocation and free, + this will also enable recording of a history of all alloc/free events. + + Use :func:`torch.cuda.memory._snapshot()` to retrieve this information, + and the tools in `_memory_viz.py` to visualize snapshots. + + The Python trace collection is fast (2us per trace), so you may consider + enabling this on production jobs if you anticipate ever having to debug + memory issues. + + C++ trace collection is also fast (~50ns/frame), which for many typical programs + works out to ~2us per trace, but can vary depending on stack depth. + + Args: + enabled (Literal[None, "state", "all"], optional): + `None`, disable recording memory history. + `"state"`, keep information for currenly allocated memory. + `"all"`, additionally keep a history of all alloc/free calls. + Defaults to "all". + context (Literal[None, "state", "alloc", "all"], optional): + `None`, Do not record any tracebacks. + `"state"`, Record tracebacks for currently allocated memory. + `"alloc"`, additionally keep tracebacks for alloc calls. 
+ `"all"`, additionally keep tracebacks for free calls. + Defaults to "all". + stacks (Literal["python", "all"], optional): + `"python"`, include Python, TorchScript, and inductor frames in tracebacks + `"all"`, additionally include C++ frames + Defaults to "all". + max_entries (int, optional): Keep a maximum of `max_entries` + alloc/free events in the recorded history recorded. + """ + if isinstance(enabled, bool): + return _record_memory_history_legacy(enabled, *args, **kwargs) + else: + return _record_memory_history_impl(enabled, *args, **kwargs) + + +def _record_memory_history_impl( + enabled: Optional[str] = "all", + context: Optional[str] = "all", + stacks: str = "all", + max_entries: int = sys.maxsize, + device: Union[Device, int] = None, +): + _C._cuda_record_memory_history(enabled, context, stacks, max_entries) + + +_record_memory_history.__signature__ = signature(_record_memory_history_impl) # type: ignore[attr-defined] + + +def _snapshot(device: Union[Device, int] = None): + """Save a snapshot of CUDA memory state at the time it was called. + + The state is represented as a dictionary with the following structure. + + .. code-block:: python + + class Snapshot(TypedDict): + segments : List[Segment] + device_traces: List[List[TraceEntry]] + + class Segment(TypedDict): + # Segments are memory returned from a cudaMalloc call. + # The size of reserved memory is the sum of all Segments. + # Segments are cached and reused for future allocations. + # If the reuse is smaller than the segment, the segment + # is split into more then one Block. + # empty_cache() frees Segments that are entirely inactive. 
+            address: int
+            total_size: int  # cudaMalloc'd size of segment
+            stream: int
+            segment_type: Literal['small', 'large']  # 'large' (>1MB)
+            allocated_size: int  # size of memory in use
+            active_size: int  # size of memory in use or in active_awaiting_free state
+            blocks : List[Block]
+
+        class Block(TypedDict):
+            # A piece of memory returned from the allocator, or
+            # currently cached but inactive.
+            size: int
+            requested_size: int  # size requested during malloc, may be smaller than
+                                 # size due to rounding
+            address: int
+            state: Literal['active_allocated',  # used by a tensor
+                           'active_awaiting_free',  # waiting for another stream to finish using
+                                                    # this, then it will become free
+                           'inactive',]  # free for reuse
+            frames: List[Frame]  # stack trace from where the allocation occurred
+
+        class Frame(TypedDict):
+            filename: str
+            line: int
+            name: str
+
+        class TraceEntry(TypedDict):
+            # When `torch.cuda.memory._record_memory_history()` is enabled,
+            # the snapshot will contain TraceEntry objects that record each
+            # action the allocator took.
+            action: Literal[
+                'alloc'  # memory allocated
+                'free_requested',  # the allocator received a call to free memory
+                'free_completed',  # the memory that was requested to be freed is now
+                                   # able to be used in future allocation calls
+                'segment_alloc',  # the caching allocator asks cudaMalloc for more memory
+                                  # and adds it as a segment in its cache
+                'segment_free',  # the caching allocator called cudaFree to return memory
+                                 # to cuda possibly trying to free up memory to
+                                 # allocate more segments or because empty_cache was called
+                'oom',  # the allocator threw an OOM exception. 'size' is
+                        # the requested number of bytes that did not succeed
+                'snapshot'  # the allocator generated a memory snapshot
+                            # useful to correlate a previously taken
+                            # snapshot with this trace
+            ]
+            addr: int  # not present for OOM
+            frames: List[Frame]
+            size: int
+            stream: int
+            device_free: int  # only present for OOM, the amount of
+                              # memory cuda still reports to be free
+
+    Returns:
+        The Snapshot dictionary object
+    """
+    return _C._cuda_memorySnapshot()
+
+
+def _dump_snapshot(filename="dump_snapshot.pickle"):
+    """
+    Save a pickled version of the `torch.memory._snapshot()` dictionary to a file.
+
+    This file can be opened by the interactive snapshot viewer at pytorch.org/memory_viz
+
+    Args:
+        filename (str, optional): Name of the file to create. Defaults to "dump_snapshot.pickle".
+    """
+    s = _snapshot()
+    with open(filename, "wb") as f:
+        pickle.dump(s, f)
+
+
+def _save_segment_usage(filename="output.svg", snapshot=None):
+    if snapshot is None:
+        snapshot = _snapshot()
+    with open(filename, "w") as f:
+        f.write(_segments(snapshot))
+
+
+def _save_memory_usage(filename="output.svg", snapshot=None):
+    if snapshot is None:
+        snapshot = _snapshot()
+    with open(filename, "w") as f:
+        f.write(_memory(snapshot))
+
+
+def _set_allocator_settings(env: str):
+    return torch._C._cuda_cudaCachingAllocator_set_allocator_settings(env)
+
+
+def get_allocator_backend() -> str:
+    r"""Return a string describing the active allocator backend as set by
+    ``PYTORCH_CUDA_ALLOC_CONF``. Currently available backends are
+    ``native`` (PyTorch's native caching allocator) and ``cudaMallocAsync``
+    (CUDA's built-in asynchronous allocator).
+
+    .. note::
+        See :ref:`cuda-memory-management` for details on choosing the allocator backend. 
class _CUDAAllocator:
    r"""Thin wrapper around an internal CUDA memory allocator object."""

    def __init__(self, allocator: torch._C._cuda_CUDAAllocator):
        self._allocator = allocator

    def allocator(self):
        r"""Return the wrapped internal allocator object."""
        return self._allocator


class CUDAPluggableAllocator(_CUDAAllocator):
    r"""CUDA memory allocator loaded from a so file."""

    def __init__(self, path_to_so_file: str, alloc_fn_name: str, free_fn_name: str):
        r"""Load a custom allocator from a shared library using ctypes.

        To make the loaded allocator active, pass it to the
        :func:`torch.memory.cuda.change_current_allocator` function.

        Args:
            path_to_so_file(str): Path in the filesystem to the `.so` file containing
                the allocator functions
            alloc_fn_name(str): Name of the allocation function in the so file.
                Its signature must be:
                void* alloc_fn_name(ssize_t size, int device, cudaStream_t stream);
            free_fn_name(str): Name of the release function in the so file.
                Its signature must be:
                void free_fn_name(void* ptr, size_t size, cudaStream_t stream);

        .. warning::
            This is currently supported only in unix OSs

        .. note::
            See :ref:`cuda-memory-management` for details on creating and using a custom allocator
        """
        library = ctypes.CDLL(path_to_so_file)
        alloc_ptr = ctypes.cast(getattr(library, alloc_fn_name), ctypes.c_void_p).value
        free_ptr = ctypes.cast(getattr(library, free_fn_name), ctypes.c_void_p).value
        assert alloc_ptr is not None
        assert free_ptr is not None
        self._allocator = torch._C._cuda_customAllocator(alloc_ptr, free_ptr)


def change_current_allocator(allocator: _CUDAAllocator) -> None:
    r"""Swap the currently used memory allocator for the one provided.

    If the current allocator has already been used/initialized, this function
    will error.

    Args:
        allocator (torch.cuda.memory._CUDAAllocator): allocator to be set as the active one.

    .. note::
        See :ref:`cuda-memory-management` for details on creating and using a custom allocator
    """
    torch._C._cuda_changeCurrentAllocator(allocator.allocator())


def _get_current_allocator() -> _CUDAAllocator:
    r"""Return the allocator currently in use, wrapped in a :class:`_CUDAAllocator`.

    .. note::
        See :ref:`cuda-memory-management` for details on creating and using a custom allocator
    """
    return _CUDAAllocator(torch._C._cuda_getAllocator())


class MemPool(_MemPool):
    r"""MemPool represents a pool of memory in a caching allocator. Currently,
    it's just the ID of the pool object maintained in the CUDACachingAllocator.

    Args:
        allocator(torch._C._cuda_CUDAAllocator, optional): a
            torch._C._cuda_CUDAAllocator object that can be used to
            define how memory gets allocated in the pool. If :attr:`allocator`
            is ``None`` (default), memory allocation follows the default/
            current configuration of the CUDACachingAllocator.
    """

    def __init__(self, allocator: Optional[_cuda_CUDAAllocator] = None):
        # NOTE(review): the second argument looks like a "user-created pool"
        # flag — confirm against _MemPool's constructor.
        super().__init__(allocator, True)

    @property
    def id(self) -> Tuple[int, int]:
        r"""Returns the ID of this pool as a tuple of two ints."""
        return super().id

    @property
    def allocator(self) -> Optional[_cuda_CUDAAllocator]:
        r"""Returns the allocator this MemPool routes allocations to"""
        return super().allocator


class MemPoolContext(_MemPoolContext):
    r"""MemPoolContext holds the currently active pool and stashes the previous
    pool. On deletion it makes the previous pool active.

    Args:
        pool(torch.cuda.MemPool): a MemPool object to be made active so that
            allocations route to this pool.
    """

    def __init__(self, pool: MemPool):
        super().__init__(pool)

    @staticmethod
    def active_pool() -> Optional[_MemPool]:
        r"""Returns the active MemPool"""
        return _MemPoolContext.active_pool()
+ + """ + + def __init__(self, pool: MemPool): + super().__init__(pool) + + @staticmethod + def active_pool() -> Optional[_MemPool]: + r"""Returns the active MemPool""" + return _MemPoolContext.active_pool() + + +@contextlib.contextmanager +def use_mem_pool(pool: MemPool, device: Union[Device, int] = None): + r"""A context manager that routes allocations to a given pool. + + Args: + pool(torch.cuda.MemPool): a MemPool object to be made active so that + allocations route to this pool. + device (torch.device or int, optional): selected device. Uses MemPool on + the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + + """ + ctx = MemPoolContext(pool) + device_index = ( + torch.cuda.current_device() if device is None else _get_device_index(device) + ) + _cuda_beginAllocateToPool(device_index, pool.id) + try: + yield + finally: + _cuda_endAllocateCurrentStreamToPool(device_index, pool.id) + del ctx diff --git a/.venv/lib/python3.11/site-packages/torch/cuda/nccl.py b/.venv/lib/python3.11/site-packages/torch/cuda/nccl.py new file mode 100644 index 0000000000000000000000000000000000000000..4c28443c9e29fba7edd34924006573497feb2d6e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/cuda/nccl.py @@ -0,0 +1,151 @@ +# mypy: allow-untyped-defs +import collections +import warnings +from typing import Optional, Sequence, Union + +import torch.cuda + + +__all__ = ["all_reduce", "reduce", "broadcast", "all_gather", "reduce_scatter"] + +SUM = 0 # ncclRedOp_t + + +def is_available(tensors): + if not hasattr(torch._C, "_nccl_all_reduce"): + warnings.warn("PyTorch is not compiled with NCCL support") + return False + + devices = set() + for tensor in tensors: + if tensor.is_sparse: + return False + if not tensor.is_contiguous(): + return False + if not tensor.is_cuda: + return False + device = tensor.get_device() + if device in devices: + return False + devices.add(device) + + return True + + +def version(): + """ + Returns 
the version of the NCCL. + + + This function returns a tuple containing the major, minor, and patch version numbers of the NCCL. + The suffix is also included in the tuple if a version suffix exists. + Returns: + tuple: The version information of the NCCL. + """ + ver = torch._C._nccl_version() + major = ver >> 32 + minor = (ver >> 16) & 65535 + patch = ver & 65535 + suffix = torch._C._nccl_version_suffix().decode("utf-8") + if suffix == "": + return (major, minor, patch) + else: + return (major, minor, patch, suffix) + + +def unique_id(): + return torch._C._nccl_unique_id() + + +def init_rank(num_ranks, uid, rank): + return torch._C._nccl_init_rank(num_ranks, uid, rank) + + +def _check_sequence_type(inputs: Union[torch.Tensor, Sequence[torch.Tensor]]) -> None: + if not isinstance(inputs, collections.abc.Container) or isinstance( + inputs, torch.Tensor + ): + raise TypeError("Inputs should be a collection of tensors") + + +def all_reduce(inputs, outputs=None, op=SUM, streams=None, comms=None): + _check_sequence_type(inputs) + if outputs is None: + outputs = inputs + _check_sequence_type(outputs) + torch._C._nccl_all_reduce(inputs, outputs, op, streams, comms) + + +# `output` used to be `outputs`, taking in a list of tensors. So we have two +# arguments for BC reasons. +def reduce( + inputs: Sequence[torch.Tensor], + output: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]] = None, + root: int = 0, + op: int = SUM, + streams: Optional[Sequence[torch.cuda.Stream]] = None, + comms=None, + *, + outputs: Optional[Sequence[torch.Tensor]] = None, +) -> None: + _check_sequence_type(inputs) + _output: torch.Tensor + if outputs is not None: + if output is not None: + raise ValueError( + "'output' and 'outputs' can not be both specified. 'outputs' is deprecated in " + "favor of 'output', taking in a single output tensor. The signature of reduce is: " + "reduce(inputs, output=None, root=0, op=SUM, streams=None, comms=None)." 
def broadcast(
    inputs: Sequence[torch.Tensor], root: int = 0, streams=None, comms=None
) -> None:
    """Broadcast the tensor on device ``root`` to every other tensor in ``inputs``."""
    _check_sequence_type(inputs)
    torch._C._nccl_broadcast(inputs, root, streams, comms)


def all_gather(
    inputs: Sequence[torch.Tensor],
    outputs: Sequence[torch.Tensor],
    streams=None,
    comms=None,
) -> None:
    """Gather ``inputs`` from all devices into each tensor of ``outputs``."""
    _check_sequence_type(inputs)
    _check_sequence_type(outputs)
    torch._C._nccl_all_gather(inputs, outputs, streams, comms)


def reduce_scatter(
    inputs: Sequence[torch.Tensor],
    outputs: Sequence[torch.Tensor],
    op: int = SUM,
    streams=None,
    comms=None,
) -> None:
    """Reduce ``inputs`` across devices and scatter one shard to each ``outputs`` tensor."""
    _check_sequence_type(inputs)
    _check_sequence_type(outputs)
    torch._C._nccl_reduce_scatter(inputs, outputs, op, streams, comms)


# --------------------------------------------------------------------------
# torch/cuda/nvtx.py
# --------------------------------------------------------------------------
# mypy: allow-untyped-defs
r"""This package adds support for NVIDIA Tools Extension (NVTX) used in profiling."""

from contextlib import contextmanager


try:
    from torch._C import _nvtx
except ImportError:

    class _NVTXStub:
        @staticmethod
        def _fail(*args, **kwargs):
            raise RuntimeError(
                "NVTX functions not installed. Are you sure you have a CUDA build?"
            )

        rangePushA = _fail
        rangePop = _fail
        markA = _fail
        # Fix: range_start()/range_end() call these entry points too; without
        # stubs they raised AttributeError instead of the intended RuntimeError
        # on builds that lack NVTX.
        rangeStartA = _fail
        rangeEnd = _fail

    _nvtx = _NVTXStub()  # type: ignore[assignment]
+ + class _NVTXStub: + @staticmethod + def _fail(*args, **kwargs): + raise RuntimeError( + "NVTX functions not installed. Are you sure you have a CUDA build?" + ) + + rangePushA = _fail + rangePop = _fail + markA = _fail + + _nvtx = _NVTXStub() # type: ignore[assignment] + +__all__ = ["range_push", "range_pop", "range_start", "range_end", "mark", "range"] + + +def range_push(msg): + """ + Push a range onto a stack of nested range span. Returns zero-based depth of the range that is started. + + Args: + msg (str): ASCII message to associate with range + """ + return _nvtx.rangePushA(msg) + + +def range_pop(): + """Pop a range off of a stack of nested range spans. Returns the zero-based depth of the range that is ended.""" + return _nvtx.rangePop() + + +def range_start(msg) -> int: + """ + Mark the start of a range with string message. It returns an unique handle + for this range to pass to the corresponding call to rangeEnd(). + + A key difference between this and range_push/range_pop is that the + range_start/range_end version supports range across threads (start on one + thread and end on another thread). + + Returns: A range handle (uint64_t) that can be passed to range_end(). + + Args: + msg (str): ASCII message to associate with the range. + """ + return _nvtx.rangeStartA(msg) + + +def range_end(range_id) -> None: + """ + Mark the end of a range for a given range_id. + + Args: + range_id (int): an unique handle for the start range. + """ + _nvtx.rangeEnd(range_id) + + +def mark(msg): + """ + Describe an instantaneous event that occurred at some point. + + Args: + msg (str): ASCII message to associate with the event. + """ + return _nvtx.markA(msg) + + +@contextmanager +def range(msg, *args, **kwargs): + """ + Context manager / decorator that pushes an NVTX range at the beginning + of its scope, and pops it at the end. If extra arguments are given, + they are passed as arguments to msg.format(). 
# --------------------------------------------------------------------------
# torch/cuda/profiler.py
# --------------------------------------------------------------------------
# mypy: allow-untyped-defs
import contextlib
import tempfile

import torch

from . import check_error, cudart


__all__ = ["init", "start", "stop", "profile"]

# Flag set handed to the legacy (pre-CUDA-12) cudaProfilerInitialize call.
DEFAULT_FLAGS = [
    "gpustarttimestamp",
    "gpuendtimestamp",
    "gridsize3d",
    "threadblocksize",
    "streamid",
    "enableonstart 0",
    "conckerneltrace",
]


def init(output_file, flags=None, output_mode="key_value"):
    """Initialize the legacy CUDA profiler (CUDA < 12, non-HIP builds only)."""
    rt = cudart()
    if not hasattr(rt, "cudaOutputMode"):
        raise AssertionError("HIP does not support profiler initialization!")
    if (
        hasattr(torch.version, "cuda")
        and torch.version.cuda is not None
        and int(torch.version.cuda.split(".")[0]) >= 12
    ):
        # Check https://github.com/pytorch/pytorch/pull/91118
        # cudaProfilerInitialize is no longer needed after CUDA 12
        raise AssertionError("CUDA12+ does not need profiler initialization!")
    if flags is None:
        flags = DEFAULT_FLAGS
    if output_mode == "key_value":
        output_mode_enum = rt.cudaOutputMode.KeyValuePair
    elif output_mode == "csv":
        output_mode_enum = rt.cudaOutputMode.CSV
    else:
        raise RuntimeError(
            "supported CUDA profiler output modes are: key_value and csv"
        )
    with tempfile.NamedTemporaryFile(delete=True) as cfg:
        # A distinct loop variable avoids shadowing the tempfile handle.
        cfg.write(b"\n".join(flag.encode("ascii") for flag in flags))
        cfg.flush()
        check_error(rt.cudaProfilerInitialize(cfg.name, output_file, output_mode_enum))


def start():
    r"""Starts cuda profiler data collection.

    .. warning::
        Raises CudaError in case of it is unable to start the profiler.
    """
    check_error(cudart().cudaProfilerStart())
def stop():
    r"""Stops cuda profiler data collection.

    .. warning::
        Raises CudaError in case of it is unable to stop the profiler.
    """
    check_error(cudart().cudaProfilerStop())


@contextlib.contextmanager
def profile():
    """Enable profiling within the managed scope.

    Context Manager to enabling profile collection by the active profiling tool
    from CUDA backend.

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> import torch
        >>> model = torch.nn.Linear(20, 30).cuda()
        >>> inputs = torch.randn(128, 20).cuda()
        >>> with torch.cuda.profiler.profile() as prof:
        ...     model(inputs)
    """
    try:
        start()
        yield
    finally:
        stop()


# --------------------------------------------------------------------------
# torch/cuda/random.py
# --------------------------------------------------------------------------
# mypy: allow-untyped-defs
from typing import Iterable, List, Union

import torch
from torch import Tensor

from . import _lazy_call, _lazy_init, current_device, device_count


__all__ = [
    "get_rng_state",
    "get_rng_state_all",
    "set_rng_state",
    "set_rng_state_all",
    "manual_seed",
    "manual_seed_all",
    "seed",
    "seed_all",
    "initial_seed",
]


def get_rng_state(device: Union[int, str, torch.device] = "cuda") -> Tensor:
    r"""Return the random number generator state of the specified GPU as a ByteTensor.

    Args:
        device (torch.device or int, optional): The device to return the RNG state of.
            Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).

    .. warning::
        This function eagerly initializes CUDA.
    """
    _lazy_init()
    if isinstance(device, str):
        device = torch.device(device)
    elif isinstance(device, int):
        device = torch.device("cuda", device)
    idx = device.index
    if idx is None:
        idx = current_device()
    return torch.cuda.default_generators[idx].get_state()
def get_rng_state_all() -> List[Tensor]:
    r"""Return a list of ByteTensor representing the random number states of all devices."""
    return [get_rng_state(i) for i in range(device_count())]


def set_rng_state(
    new_state: Tensor, device: Union[int, str, torch.device] = "cuda"
) -> None:
    r"""Set the random number generator state of the specified GPU.

    Args:
        new_state (torch.ByteTensor): The desired state
        device (torch.device or int, optional): The device to set the RNG state.
            Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).
    """
    # Copy eagerly (outside the lazy callback) so a later mutation of
    # `new_state` by the caller cannot change what gets installed.
    with torch._C._DisableFuncTorch():
        new_state_copy = new_state.clone(memory_format=torch.contiguous_format)
    if isinstance(device, str):
        device = torch.device(device)
    elif isinstance(device, int):
        device = torch.device("cuda", device)

    def cb():
        idx = device.index
        if idx is None:
            idx = current_device()
        torch.cuda.default_generators[idx].set_state(new_state_copy)

    _lazy_call(cb)


def set_rng_state_all(new_states: Iterable[Tensor]) -> None:
    r"""Set the random number generator state of all devices.

    Args:
        new_states (Iterable of torch.ByteTensor): The desired state for each device.
    """
    for i, state in enumerate(new_states):
        set_rng_state(state, i)


def manual_seed(seed: int) -> None:
    r"""Set the seed for generating random numbers for the current GPU.

    It's safe to call this function if CUDA is not available; in that
    case, it is silently ignored.

    Args:
        seed (int): The desired seed.

    .. warning::
        If you are working with a multi-GPU model, this function is insufficient
        to get determinism. To seed all GPUs, use :func:`manual_seed_all`.
    """
    seed = int(seed)

    def cb():
        torch.cuda.default_generators[current_device()].manual_seed(seed)

    _lazy_call(cb, seed=True)
def manual_seed_all(seed: int) -> None:
    r"""Set the seed for generating random numbers on all GPUs.

    It's safe to call this function if CUDA is not available; in that
    case, it is silently ignored.

    Args:
        seed (int): The desired seed.
    """
    seed = int(seed)

    def cb():
        for idx in range(device_count()):
            torch.cuda.default_generators[idx].manual_seed(seed)

    _lazy_call(cb, seed_all=True)


def seed() -> None:
    r"""Set the seed for generating random numbers to a random number for the current GPU.

    It's safe to call this function if CUDA is not available; in that
    case, it is silently ignored.

    .. warning::
        If you are working with a multi-GPU model, this function will only initialize
        the seed on one GPU. To initialize all GPUs, use :func:`seed_all`.
    """

    def cb():
        torch.cuda.default_generators[current_device()].seed()

    _lazy_call(cb)


def seed_all() -> None:
    r"""Set the seed for generating random numbers to a random number on all GPUs.

    It's safe to call this function if CUDA is not available; in that
    case, it is silently ignored.
    """

    def cb():
        # Seed the first device randomly, then propagate its seed to every
        # other device so all generators agree.
        random_seed = 0
        seeded = False
        for idx in range(device_count()):
            gen = torch.cuda.default_generators[idx]
            if not seeded:
                gen.seed()
                random_seed = gen.initial_seed()
                seeded = True
            else:
                gen.manual_seed(random_seed)

    _lazy_call(cb)


def initial_seed() -> int:
    r"""Return the current random seed of the current GPU.

    .. warning::
        This function eagerly initializes CUDA.
    """
    _lazy_init()
    return torch.cuda.default_generators[current_device()].initial_seed()
warning:: + This function eagerly initializes CUDA. + """ + _lazy_init() + idx = current_device() + default_generator = torch.cuda.default_generators[idx] + return default_generator.initial_seed() diff --git a/.venv/lib/python3.11/site-packages/torch/cuda/sparse.py b/.venv/lib/python3.11/site-packages/torch/cuda/sparse.py new file mode 100644 index 0000000000000000000000000000000000000000..f37a34118d2d8f73437dee54337a666df1b99a09 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/cuda/sparse.py @@ -0,0 +1 @@ +# The Tensor classes are added to this module by python_tensor.cpp diff --git a/.venv/lib/python3.11/site-packages/torch/cuda/streams.py b/.venv/lib/python3.11/site-packages/torch/cuda/streams.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ee6eb68d6892bc749847826c3c0873d900bbb9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/cuda/streams.py @@ -0,0 +1,242 @@ +# mypy: allow-untyped-defs +import ctypes + +import torch +from torch._streambase import _EventBase, _StreamBase +from torch._utils import _dummy_type + + +if not hasattr(torch._C, "_CudaStreamBase"): + # Define dummy base classes + torch._C.__dict__["_CudaStreamBase"] = _dummy_type("_CudaStreamBase") + torch._C.__dict__["_CudaEventBase"] = _dummy_type("_CudaEventBase") + + +class Stream(torch._C._CudaStreamBase, _StreamBase): + r"""Wrapper around a CUDA stream. + + A CUDA stream is a linear sequence of execution that belongs to a specific + device, independent from other streams. See :ref:`cuda-semantics` for + details. + + Args: + device(torch.device or int, optional): a device on which to allocate + the stream. If :attr:`device` is ``None`` (default) or a negative + integer, this will use the current device. + priority(int, optional): priority of the stream, should be 0 or + negative, where negative numbers indicate higher priority. By default, + streams have priority 0. 
+ + """ + + def __new__(cls, device=None, priority=0, **kwargs): + # setting device manager is expensive, so we avoid it unless necessary + if device is None or ("stream_id" in kwargs and "device_index" in kwargs): + return super().__new__(cls, priority=priority, **kwargs) + else: + with torch.cuda.device(device): + return super().__new__(cls, priority=priority, **kwargs) + + def wait_event(self, event) -> None: + r"""Make all future work submitted to the stream wait for an event. + + Args: + event (torch.cuda.Event): an event to wait for. + + .. note:: This is a wrapper around ``cudaStreamWaitEvent()``: see + `CUDA Stream documentation`_ for more info. + + This function returns without waiting for :attr:`event`: only future + operations are affected. + + .. _CUDA Stream documentation: + https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html + """ + event.wait(self) + + def wait_stream(self, stream) -> None: + r"""Synchronize with another stream. + + All future work submitted to this stream will wait until all kernels + submitted to a given stream at the time of call complete. + + Args: + stream (Stream): a stream to synchronize. + + .. note:: This function returns without waiting for currently enqueued + kernels in :attr:`stream`: only future operations are affected. + """ + self.wait_event(stream.record_event()) + + def record_event(self, event=None): + r"""Record an event. + + Args: + event (torch.cuda.Event, optional): event to record. If not given, a new one + will be allocated. + + Returns: + Recorded event. + """ + if event is None: + event = Event() + event.record(self) + return event + + def query(self) -> bool: + r"""Check if all the work submitted has been completed. + + Returns: + A boolean indicating if all kernels in this stream are completed. + """ + return super().query() + + def synchronize(self) -> None: + r"""Wait for all the kernels in this stream to complete. + + .. 
note:: This is a wrapper around ``cudaStreamSynchronize()``: see + `CUDA Stream documentation`_ for more info. + """ + super().synchronize() + + @property + def _as_parameter_(self): + return ctypes.c_void_p(self.cuda_stream) + + def __eq__(self, o) -> bool: + if isinstance(o, Stream): + return super().__eq__(o) + return False + + def __hash__(self): + return hash((self.cuda_stream, self.device)) + + def __repr__(self): + return f"" + + +class ExternalStream(Stream): + r"""Wrapper around an externally allocated CUDA stream. + + This class is used to wrap streams allocated in other libraries in order + to facilitate data exchange and multi-library interactions. + + .. note:: This class doesn't manage the stream life-cycle, it is the user + responsibility to keep the referenced stream alive while this class is + being used. + + Args: + stream_ptr(int): Integer representation of the `cudaStream_t` value. + allocated externally. + device(torch.device or int, optional): the device where the stream + was originally allocated. If device is specified incorrectly, + subsequent launches using this stream may fail. + """ + + def __new__(cls, stream_ptr, device=None, **kwargs): + with torch.cuda.device(device): + return super().__new__(cls, stream_ptr=stream_ptr, **kwargs) + + +class Event(torch._C._CudaEventBase, _EventBase): + r"""Wrapper around a CUDA event. + + CUDA events are synchronization markers that can be used to monitor the + device's progress, to accurately measure timing, and to synchronize CUDA + streams. + + The underlying CUDA events are lazily initialized when the event is first + recorded or exported to another process. After creation, only streams on the + same device may record the event. However, streams on any device can wait on + the event. 
+ + Args: + enable_timing (bool, optional): indicates if the event should measure time + (default: ``False``) + blocking (bool, optional): if ``True``, :meth:`wait` will be blocking (default: ``False``) + interprocess (bool): if ``True``, the event can be shared between processes + (default: ``False``) + + .. _CUDA Event Documentation: + https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__EVENT.html + """ + + def __new__(cls, enable_timing=False, blocking=False, interprocess=False): + return super().__new__( + cls, + enable_timing=enable_timing, + blocking=blocking, + interprocess=interprocess, + ) + + @classmethod + def from_ipc_handle(cls, device, handle): + r"""Reconstruct an event from an IPC handle on the given device.""" + return super().from_ipc_handle(device, handle) + + def record(self, stream=None): + r"""Record the event in a given stream. + + Uses ``torch.cuda.current_stream()`` if no stream is specified. The + stream's device must match the event's device. + """ + if stream is None: + stream = torch.cuda.current_stream() + super().record(stream) + + def wait(self, stream=None) -> None: + r"""Make all future work submitted to the given stream wait for this event. + + Use ``torch.cuda.current_stream()`` if no stream is specified. + + .. note:: This is a wrapper around ``cudaStreamWaitEvent()``: see + `CUDA Event documentation`_ for more info. + """ + if stream is None: + stream = torch.cuda.current_stream() + super().wait(stream) + + def query(self): + r"""Check if all work currently captured by event has completed. + + Returns: + A boolean indicating if all work currently captured by event has + completed. + """ + return super().query() + + def elapsed_time(self, end_event): + r"""Return the time elapsed. + + Time reported in milliseconds after the event was recorded and + before the end_event was recorded. + """ + return super().elapsed_time(end_event) + + def synchronize(self) -> None: + r"""Wait for the event to complete. 
+ + Waits until the completion of all work currently captured in this event. + This prevents the CPU thread from proceeding until the event completes. + + .. note:: This is a wrapper around ``cudaEventSynchronize()``: see + `CUDA Event documentation`_ for more info. + """ + super().synchronize() + + def ipc_handle(self): + r"""Return an IPC handle of this event. + + If not recorded yet, the event will use the current device. + """ + return super().ipc_handle() + + @property + def _as_parameter_(self): + return ctypes.c_void_p(self.cuda_event) + + def __repr__(self) -> str: + if self.cuda_event: + return f"" + else: + return "" diff --git a/.venv/lib/python3.11/site-packages/torch/cuda/tunable.py b/.venv/lib/python3.11/site-packages/torch/cuda/tunable.py new file mode 100644 index 0000000000000000000000000000000000000000..8b387102b43dcf3abd1a99c1d63f43cdddc36a4a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/cuda/tunable.py @@ -0,0 +1,242 @@ +r""" +This module exposes a TunableOp interface. + +Some operations, such as GEMMs, could be implemented using more than one library +or more than one technique. For example, a GEMM could be implemented for CUDA or +ROCm using either the blas or blasLt libraries. Further, ROCm's rocblas and +hipblaslt libraries allow the user to query for all possible algorithms and then +choose one. How does one know which implementation is the fastest and should be +chosen? That's what TunableOp provides. + +Enabling TunableOp and Tuning Separately +======================================== + +The TunableOp feature is enabled separately from enabling the tuning phase +itself. Enabling TunableOp means that PyTorch will replace any standard +operators with their Tunable implementations. Any call to a TunableOp first +checks whether it has already been tuned for the given operator inputs. If so, +it will immediately call the tuned operation; no further tuning will take place +even when the tuning setting is enabled. 
Instead if no tuning result is found, +and tuning is enabled, the TunableOp will benchmark every registered +implementation of that operator for the given set of inputs and select the +fastest. + +File Input and Output +===================== + +The first time any TunableOp is invoked, the internal database of tuned +operations will be prepared by attempting to read the results from the given +file. The default filename is 'tunableop_results.csv'. To support tuning when +multiple GPUs are used across multiple processes, the GPU device ordinal is +automatically inserted into the filename to avoid multiple processes overwriting +the same file. + +If tuning is enabled and new tunings are discovered during the course of your +workload, it will also write out to this same filename with all tunings, both +the ones it read in at startup as well as the new ones found at runtime. This +can be used, for example, to build up a tunings file across many workloads by +reusing the same file. The output file is automatically created when the +application terminates. This behavior can be controlled by the C++ and Python +APIs but not the environment variables. + +Assuming you specified a filename, you'll end up with a CSV file with contents +like so:: + + Validator,PT_VERSION,2.2.0 + Validator,ROCM_VERSION,6.0.0.0-12969-1544e39 + Validator,HIPBLASLT_VERSION,0.6.0-a9c5cc7 + Validator,ROCBLAS_VERSION,4.0.0-72e57364-dirty + GemmTunableOp_float_NT,nt_25088_4096_64,1219,1.262 + GemmTunableOp_float_NT,nt_4096_4096_64,1216,0.033 + +Note the "Validator" lines. If you change a library verison, or ROCm version, or +PyTorch version, TunableOp will detect this and reject the tunings file because +the prior tunings are likely affected by other software changes. + +The remaining lines are the tuned solutions for each TunableOp encountered +during your execution. Each line consists of 4 comma-separated fields: operator +name, operator parameters, solution name, and average execution time. 
The +execution time is an optional field. The CSV file can be edited, but with +caution. For example, the solution name (field 3) can be changed to "Default" +and it will fall back to the original PyTorch untuned implementation. Or, in the +case of ROCm's hipBLAS or hipBLASLt libraries, if you know the specific solution +index you can override the solution that TunableOp selected by replacing the +value. The operator name and parameters (fields 1 and 2) are internally named +and should not be modified. In the case of GemmTunableOp, field 1 indicates the +datatype and whether the inputs are transposed (T) or not (N) and field 2 +indicates the M, N, K input shapes. + +There is an option to enable verbose output but it is only recommended for +debugging purposes. This will produce a lot of diagnostic messages but may be +useful to see if TunableOp is being used at all. Otherwise, TunableOp is +completely silent, besides file output, unless there is a warning or error +during its use. The verbose option is only available by setting the environment +variable PYTORCH_TUNABLEOP_VEROBSE=1. + +A Note on Tuning Behavior +========================= + +Tuning an operator consists of iterating through the list or registered +implementations and profiling each one. The profile is established by running a +single implementation in a loop multiple times and taking the average execution +time. + +By default, each possible solution for a given operator will be run for either +100 iterations or as many iterations that can be run within 30ms, whichever is +smaller, and its average execution will be calculated. The fastest solution +among all that were successfully profiled will be chosen. A profile might fail +if the given solution doesn't achieve the same accuracy as the default +implementation or if the solution returns an error code. 
+ +Current Tunable Operators +========================= + +TunableGemm for ROCm +-------------------- + +Currently only a TunableGemm for ROCm is implemented. Note that CUDA builds of +PyTorch will function correctly when using TunableOp but the only solution +available to CUDA builds is the 'Default' implementation i.e. the original +cuBLAS default, now called through TunableOp. Any call to at::cuda::blas::gemm() +or ::bgemm() will be routed through TunableOp when enabled. Calling gemm() for a +given set of input arguments (transa, transb, m, n, k) will attempt to use the +fastest available implementation across both rocblas and hipblaslt. + +Tuning Context +============== + +The behavior of TunableOp is currently manipulated through environment +variables, the C++ interface of at::cuda::tunable::getTuningContext(), or the +torch.cuda.tunable python interfaces that wrap the C++ TuningContext. The +environment variables take precedence over any setting you manipulate using the +C++ or Python APIs. + +""" +from typing import Optional, Tuple + +import torch + + +__all__ = [ + "enable", + "is_enabled", + "tuning_enable", + "tuning_is_enabled", + "set_max_tuning_duration", + "get_max_tuning_duration", + "set_max_tuning_iterations", + "get_max_tuning_iterations", + "set_filename", + "get_filename", + "get_results", + "get_validators", + "write_file_on_exit", + "write_file", + "read_file", +] + + +def enable(val: bool = True) -> None: + r"""This is the big on/off switch for all TunableOp implementations.""" + torch._C._cuda_tunableop_enable(val) # type: ignore[attr-defined] + + +def is_enabled() -> bool: + r"""Returns whether the TunableOp feature is enabled.""" + return torch._C._cuda_tunableop_is_enabled() # type: ignore[attr-defined] + + +def tuning_enable(val: bool = True) -> None: + r"""Enable tuning of TunableOp implementations. + + When enabled, if a tuned entry isn't found, run the tuning step and record + the entry. 
+ """ + torch._C._cuda_tunableop_tuning_enable(val) # type: ignore[attr-defined] + + +def tuning_is_enabled() -> bool: + r"""Returns whether TunableOp implementations can be tuned.""" + return torch._C._cuda_tunableop_tuning_is_enabled() # type: ignore[attr-defined] + + +def set_max_tuning_duration(duration: int) -> None: + r"""Set max time in milliseconds to spend tuning a given solution. + + If both max tuning duration and iterations are set, the smaller of the two + will be honored. At minimum 1 tuning iteration will always be run. + """ + torch._C._cuda_tunableop_set_max_tuning_duration(duration) # type: ignore[attr-defined] + + +def get_max_tuning_duration() -> int: + r"""Get max time to spend tuning a given solution.""" + return torch._C._cuda_tunableop_get_max_tuning_duration() # type: ignore[attr-defined] + + +def set_max_tuning_iterations(iterations: int) -> None: + r"""Set max number of iterations to spend tuning a given solution. + + If both max tuning duration and iterations are set, the smaller of the two + will be honored. At minimum 1 tuning iteration will always be run. + """ + torch._C._cuda_tunableop_set_max_tuning_iterations(iterations) # type: ignore[attr-defined] + + +def get_max_tuning_iterations() -> int: + r"""Get max iterations to spend tuning a given solution.""" + return torch._C._cuda_tunableop_get_max_tuning_iterations() # type: ignore[attr-defined] + + +def set_filename(filename: str, insert_device_ordinal: bool = False) -> None: + r"""Set the filename to use for input/output of tuning results. + + If :attr:`insert_device_ordinal` is ``True`` then the current device ordinal + will be added to the given filename automatically. This can be used in a + 1-process-per-gpu cenario to ensure all processes write to a separate file. 
+ """ + torch._C._cuda_tunableop_set_filename(filename, insert_device_ordinal) # type: ignore[attr-defined] + + +def get_filename() -> str: + r"""Get the results filename.""" + return torch._C._cuda_tunableop_get_filename() # type: ignore[attr-defined] + + +def get_results() -> Tuple[str, str, str, float]: + r"""Return all TunableOp results.""" + return torch._C._cuda_tunableop_get_results() # type: ignore[attr-defined] + + +def get_validators() -> Tuple[str, str]: + r"""Return the TunableOp validators.""" + return torch._C._cuda_tunableop_get_validators() # type: ignore[attr-defined] + + +def write_file_on_exit(val: bool) -> None: + r"""During Tuning Context destruction, write file to disk. + + This is useful as a final flush of your results to disk if your application + terminates as result of normal operation or an error. Manual flushing of + your results can be achieved by manually calling ``write_file()``.""" + torch._C._cuda_tunableop_write_file_on_exit(val) # type: ignore[attr-defined] + + +def write_file(filename: Optional[str] = None) -> bool: + r"""Write results to a CSV file. + + If :attr:`filename` is not given, ``get_filename()`` is called. + """ + if filename is None: + filename = get_filename() + return torch._C._cuda_tunableop_write_file(filename) # type: ignore[attr-defined] + + +def read_file(filename: Optional[str] = None) -> bool: + r"""Read results from a TunableOp CSV file. + + If :attr:`filename` is not given, ``get_filename()`` is called. 
+ """ + if filename is None: + filename = get_filename() + return torch._C._cuda_tunableop_read_file(filename) # type: ignore[attr-defined] diff --git a/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1da564fecc14040035789556842a3925dfc3564 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_compatibility.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_compatibility.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8410e1ea90ca4d9c36d3bb866157e4d8cfe7a967 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_compatibility.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_lazy_graph_module.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_lazy_graph_module.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1bf6e6105f81bdee31808f68694ee57779f45efe Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_lazy_graph_module.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_pytree.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_pytree.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47232c6f3acc81d7b3065e86b219b1fbd0110194 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_pytree.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_symbolic_trace.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_symbolic_trace.cpython-311.pyc new 
file mode 100644 index 0000000000000000000000000000000000000000..bcc011e8a2b6989412fcc23ca14cb8af735abf5d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_symbolic_trace.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04c76d67c180ebec7dbd01729bbf4fc305b94661 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/annotate.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/annotate.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d18b9cfd5b2510c58030c9315ca3631191d536e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/annotate.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/config.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb5359ae4969d74cc41f7e716d550a2abc59aba8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/config.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/graph.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/graph.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2135705aa08f28347fc86e17e634acd3f6fb9fdb Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/graph.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/graph_module.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/graph_module.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7bcd9ffceb9b38d020674f145d16bc5cafd7d495 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/graph_module.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/immutable_collections.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/immutable_collections.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1ba3d8779780b8cd553473c0432e7137e30ab44 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/immutable_collections.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/interpreter.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/interpreter.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ba2876540567a614dfacc067b69402a7e336ed8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/interpreter.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/node.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/node.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19e6f9a603c82e10fbd7a611431e7a29d61ec4c4 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/node.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/operator_schemas.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/operator_schemas.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..13271e2c1eb3c89715fbc577d704c7a6922c9875 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/operator_schemas.cpython-311.pyc differ diff 
--git a/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/proxy.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/proxy.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55e6a3042992bf11994f8fd327bc9195c7168e77 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/proxy.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/subgraph_rewriter.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/subgraph_rewriter.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..235460c4b17fb3012eb15f3d6acd1327ec28bfa3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/subgraph_rewriter.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/tensor_type.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/tensor_type.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a27d783aad174d87725d9f5fbf8525a91d866afa Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/tensor_type.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/traceback.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/traceback.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d684c1ed6dc92239ca390c2c3f65ca8fe9ec22a5 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/traceback.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d5cabb742d640024f83520b0cae322b1cbee1a8 Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/annotate_getitem_nodes.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/annotate_getitem_nodes.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..561d47823803456af4d5d723841ad605567295c6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/annotate_getitem_nodes.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/fake_tensor_prop.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/fake_tensor_prop.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9b6a792e8a957de6ede955fa22ff5c10e8567f9 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/fake_tensor_prop.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/graph_drawer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/graph_drawer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1535675ae68adae8919ecf69dc77ce979a1cfb3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/graph_drawer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/graph_manipulation.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/graph_manipulation.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45b02c3b65377fe1a09967350dd48353b74bcb5e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/graph_manipulation.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/graph_transform_observer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/graph_transform_observer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65c4c7f31d38720da0dcd62ff87decc747f653a7 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/graph_transform_observer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/net_min_base.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/net_min_base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..391831b4e063db75e98226a66a75daa791dd24d0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/net_min_base.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/operator_support.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/operator_support.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..393dee04ffd4a74c2fd2cf1ecc0b364f03c1b511 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/operator_support.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/param_fetch.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/param_fetch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10509537a6acd00c84bffe70832bc12b0b2c3d18 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/param_fetch.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/pass_manager.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/pass_manager.cpython-311.pyc new file mode 
100644 index 0000000000000000000000000000000000000000..f14f76cc8404e6fc8ed38ab7b62739057d98f072 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/pass_manager.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/reinplace.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/reinplace.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b48be660d208346c648b744fdb63e7f01b9a173b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/reinplace.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/runtime_assert.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/runtime_assert.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a73a3ec2e2aa56a39080974213f79416518a0aca Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/runtime_assert.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/shape_prop.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/shape_prop.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8189872b7efef71ec3b320317a50d8572b6755c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/shape_prop.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/split_module.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/split_module.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f0e16a86a09f98f537cfdb75833c74c9633e2d2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/split_module.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/split_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/split_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8565d362545e3bd9e452785fd9b219a16a0f8630 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/split_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/splitter_base.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/splitter_base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ea31359d3daac6b01157ce9eac47abe25598277 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/splitter_base.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/tools_common.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/tools_common.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f6400e172c145fd74207cfc59a28e48a8db1ed2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/tools_common.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/backends/__init__.py b/.venv/lib/python3.11/site-packages/torch/fx/passes/backends/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/backends/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/passes/backends/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d10ccde338566cb3509c2a1c10e8130e9103988 Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/torch/fx/passes/backends/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/backends/__pycache__/cudagraphs.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/passes/backends/__pycache__/cudagraphs.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11240fd546e55ba3a7189e8018f686524a92d2c7 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/passes/backends/__pycache__/cudagraphs.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/backends/cudagraphs.py b/.venv/lib/python3.11/site-packages/torch/fx/passes/backends/cudagraphs.py new file mode 100644 index 0000000000000000000000000000000000000000..0f48165b7dab41d9c8dd30ef49a2752868740989 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/passes/backends/cudagraphs.py @@ -0,0 +1,57 @@ +# mypy: allow-untyped-defs +import torch +from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner +from torch.fx.passes.operator_support import OperatorSupport +from torch.fx.passes.tools_common import CALLABLE_NODE_OPS +from torch.fx.passes.fake_tensor_prop import FakeTensorProp +from torch.utils import _pytree as pytree + +import operator + +class CudaGraphsSupport(OperatorSupport): + # TODO: why is submodules passed here + def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: + if node.op not in CALLABLE_NODE_OPS: + return False + + if node.target in [torch.ops.aten.embedding_dense_backward.default]: + return False + + if node.target in [operator.getitem]: + return True + + found_not_cuda = False + + def meta_fk(meta): + return meta["val"] if "val" in meta else meta["fake_result"] + + def find_not_cuda(t): + nonlocal found_not_cuda + if isinstance(t, torch.Tensor) and t.device.type != 'cuda': + found_not_cuda = True + + for n in node.all_input_nodes: + pytree.tree_map_(find_not_cuda, 
meta_fk(n.meta)) + + pytree.tree_map_(find_not_cuda, meta_fk(node.meta)) + + # NB: factory function is accounted for because the result would be + # cpu or cuda + + return not found_not_cuda + +def partition_cudagraphs(gm, inputs): + """ + Partition an FX graph into sub-GraphModules that can be validly run under + CUDA graphs. For a subgraph to be runnable under CUDA, all of the operations + must involve CUDA tensors only/ + """ + + FakeTensorProp(gm).propagate(*inputs) + supported_ops = CudaGraphsSupport() + # TODO: single node partition may be wrong due to the pessimization + # from copying in and out the data. Check in benchmarks, perhaps + partitioner = CapabilityBasedPartitioner(gm, supported_ops, allows_single_node_partition=True) + partitions = partitioner.propose_partitions() + fused_graph = partitioner.fuse_partitions(partitions) + return fused_graph diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c5c82bc1d997814f2ffd01a043d4b0491bcb2dc Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/tests/__init__.py b/.venv/lib/python3.11/site-packages/torch/fx/passes/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/tests/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/passes/tests/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c10d9e18100064e9a9db9fe7a1c20904b8644f9c Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/torch/fx/passes/tests/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/tests/__pycache__/test_pass_manager.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/passes/tests/__pycache__/test_pass_manager.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85c3564e374c9380db7f59655bc24acdf0627f62 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/passes/tests/__pycache__/test_pass_manager.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/tests/test_pass_manager.py b/.venv/lib/python3.11/site-packages/torch/fx/passes/tests/test_pass_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..60ed6671179b2c20fa0be176631d1415009ee87a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/passes/tests/test_pass_manager.py @@ -0,0 +1,58 @@ +import unittest + +from ..pass_manager import ( + inplace_wrapper, + PassManager, + these_before_those_pass_constraint, + this_before_that_pass_constraint, +) + + +class TestPassManager(unittest.TestCase): + def test_pass_manager_builder(self) -> None: + passes = [lambda x: 2 * x for _ in range(10)] + pm = PassManager(passes) + pm.validate() + + def test_this_before_that_pass_constraint(self) -> None: + passes = [lambda x: 2 * x for _ in range(10)] + pm = PassManager(passes) + + # add unfulfillable constraint + pm.add_constraint(this_before_that_pass_constraint(passes[-1], passes[0])) + + self.assertRaises(RuntimeError, pm.validate) + + def test_these_before_those_pass_constraint(self) -> None: + passes = [lambda x: 2 * x for _ in range(10)] + constraint = these_before_those_pass_constraint(passes[-1], passes[0]) + pm = PassManager( + [inplace_wrapper(p) for p in passes] + ) + + # add unfulfillable constraint + pm.add_constraint(constraint) + + self.assertRaises(RuntimeError, pm.validate) + + def 
test_two_pass_managers(self) -> None: + """Make sure we can construct the PassManager twice and not share any + state between them""" + + passes = [lambda x: 2 * x for _ in range(3)] + constraint = these_before_those_pass_constraint(passes[0], passes[1]) + pm1 = PassManager() + for p in passes: + pm1.add_pass(p) + pm1.add_constraint(constraint) + output1 = pm1(1) + self.assertEqual(output1, 2 ** 3) + + passes = [lambda x: 3 * x for _ in range(3)] + constraint = these_before_those_pass_constraint(passes[0], passes[1]) + pm2 = PassManager() + for p in passes: + pm2.add_pass(p) + pm2.add_constraint(constraint) + output2 = pm2(1) + self.assertEqual(output2, 3 ** 3) diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__init__.py b/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2a7970ba4c283e851430ed0025e1ed5c772eb7b1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__init__.py @@ -0,0 +1 @@ +from .common import lift_subgraph_as_module, HolderModule, compare_graphs diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ba9ac9e864cee87f99c2bc06b6cf5a5aa651810 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__pycache__/common.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__pycache__/common.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4dfe4d97f144688480c92938a1f387a4081da75a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__pycache__/common.cpython-311.pyc differ diff 
--git a/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__pycache__/fuser_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__pycache__/fuser_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8df167542f3afb4872c2b4fcb48f7d488bafe04e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__pycache__/fuser_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__pycache__/matcher_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__pycache__/matcher_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43f39e6997b4ed90b949307131453c2efad03f64 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__pycache__/matcher_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__pycache__/matcher_with_name_node_map_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__pycache__/matcher_with_name_node_map_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fffeffb4c99435c1616f517af33287f4dd9690c7 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__pycache__/matcher_with_name_node_map_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__pycache__/source_matcher_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__pycache__/source_matcher_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..242b175067219e9dfa206a88050eb1e5e8d543be Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__pycache__/source_matcher_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/common.py 
b/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/common.py new file mode 100644 index 0000000000000000000000000000000000000000..ba2ae45aabf5dcd26a0baee6b5acc11229cf1b88 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/common.py @@ -0,0 +1,96 @@ +# mypy: allow-untyped-defs +from typing import Dict, Tuple + +from torch.fx._compatibility import compatibility +from torch.fx.graph import Graph + +from torch.fx.graph_module import GraphModule +from torch.fx.passes.utils.matcher_utils import SubgraphMatcher +from torch.nn import Module + + +__all__ = ["HolderModule", "lift_subgraph_as_module", "compare_graphs"] + + +@compatibility(is_backward_compatible=False) +class HolderModule(Module): + """ + HolderModule is used to copy all the attributes from original module to submodules + that uses the attributes + """ + + def __init__(self, d): + super().__init__() + for k, v in d.items(): + self.add_module(k, v) + + +@compatibility(is_backward_compatible=False) +def lift_subgraph_as_module( + gm: GraphModule, + subgraph: Graph, + comp_name: str = "", + class_name: str = "GraphModule", +) -> Tuple[GraphModule, Dict[str, str]]: + """ + Create a GraphModule for subgraph, which copies the necessary attributes from the original parent graph_module. + + Args: + gm (GraphModule): parent graph module + + subgraph (Graph): a valid subgraph that contains copied nodes from the parent graph + + comp_name (str): name for the new component + + class_name (str): name for the submodule + + """ + + # Loop through all module calls (call_module) and param fetches (get_attr) + # in this component, creating HolderModules as necessary to match the path. + # e.g. if in the original module there's a get_attr node fetches "conv.weight". + # We create a HolderModule as root -> add a HolderModule named "conv" -> + # make "weight" a attribute of "conv" HolderModule and point to conv.weight in + # the original module. 
+ submodule = HolderModule({}) + orig_to_split_fqn_mapping: Dict[str, str] = {} + for n in subgraph.nodes: + if n.op not in ("call_module", "get_attr"): + continue + + target = n.target + assert isinstance(target, str) + target_name_parts = target.split(".") + curr = submodule + orig_gm = gm + + for name in target_name_parts[:-1]: + if not hasattr(curr, name): + curr.add_module(name, HolderModule({})) + + curr = getattr(curr, name) + orig_gm = getattr(orig_gm, name) + + leaf_node_name = target_name_parts[-1] + leaf_node = getattr(orig_gm, leaf_node_name) + + orig_to_split_fqn_mapping[target] = f"{comp_name}.{target}" + # Relies on custom __setattr__ magic. + setattr(curr, leaf_node_name, leaf_node) + + return GraphModule(submodule, subgraph, class_name), orig_to_split_fqn_mapping + + +@compatibility(is_backward_compatible=False) +def compare_graphs(left: Graph, right: Graph) -> bool: + """ + Return True if two graphs are identical, i.e they + - have the same number of outputs in the same order + - have the same number of inputs in the same order + - have the same set of nodes, and identical connectivity + """ + + matcher = SubgraphMatcher(left, match_output=True, match_placeholder=True) + matches = matcher.match(right) + + return len(matches) > 0 diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/fuser_utils.py b/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/fuser_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..324e8a67801564cf69fac7fef318951fc34223b7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/fuser_utils.py @@ -0,0 +1,236 @@ +# mypy: allow-untyped-defs +import copy +from queue import SimpleQueue +from typing import List, Dict, Tuple + +import torch.fx +from torch.fx.graph_module import GraphModule +from torch.fx.graph import Graph +from torch.fx.node import Node +from torch.fx.passes.tools_common import NodeList, NodeSet, legalize_graph +from torch.fx.passes.utils import 
lift_subgraph_as_module +from torch.fx._compatibility import compatibility + +@compatibility(is_backward_compatible=False) +def topo_sort(nodes: NodeList) -> NodeList: + # sort nodes according to the topological order + indegree_map = dict.fromkeys(nodes, 0) + candidates: SimpleQueue = SimpleQueue() + + for node in nodes: + for n in node.all_input_nodes: + if n in indegree_map: + indegree_map[node] += 1 + if indegree_map[node] == 0: + candidates.put(node) + + sorted_nodes: NodeList = [] + while not candidates.empty(): + node = candidates.get() + sorted_nodes.append(node) + + for n in node.users: + if n in indegree_map: + indegree_map[n] -= 1 + if indegree_map[n] == 0: + candidates.put(n) + + assert len(nodes) == len(sorted_nodes), "topological sorted nodes doesn't have same length as input nodes" + + return sorted_nodes + + +@compatibility(is_backward_compatible=False) +def validate_partition(partition: NodeList) -> bool: + # verify the partition does't form a dependency cycle in the original graph + # returns True for valid partition, False for invalid + + partition_set = set(partition) + + outputs: NodeList = [] + for node in partition_set: + for user_node in node.users: + if user_node not in partition_set: + # external user node, need to expose as an output + outputs.append(user_node) + + # Perform BFS on the partition outputs. + # If it reaches a node within the partition, then it found a cycle. + # This function takes the ownership of `root_nodes` and may modify it. + def bfs_find_cycle(root_nodes: NodeList) -> bool: + # Set used to exclude nodes that have already been visited. + # If a node has been visited, that node and all its children have + # been checked for cycles. + visited: NodeSet = set() + + # Start with `root_nodes` and traverse through (toward child nodes) + # their connected sub-graph. Nodes in `visited` won't be added + # to `queue` again. 
+ queue: NodeList = root_nodes + while queue: + current = queue.pop() + visited.add(current) + if current in partition_set: + # Started from partition's `output` nodes, and reached + # another node in partition. Cycle! + return True + for user_node in current.users: + if user_node in visited: + continue + queue.append(user_node) + # `root_nodes` don't cause cycle. + return False + + # Use all output nodes as roots to traverse + # the graph to check cycles. + if bfs_find_cycle(outputs): + return False + + return True + + +@compatibility(is_backward_compatible=False) +def fuse_as_graphmodule(gm: GraphModule, + nodes: NodeList, + module_name: str) -> Tuple[GraphModule, Tuple[Node, ...], Tuple[Node, ...]]: + + """ + Fuse nodes in graph_module into a GraphModule. + + Args: + gm (GraphModule): target graph_module + + nodes (List[Node]): list of nodes in `gm` to fuse, where the node must be topologically sorted + + module_name: class name for the fused GraphModule + + Returns: + fused_gm (GraphModule): fused graph module, where its node is a copy of `nodes` in `gm` + + original_inputs (Tuple[Node, ...]): input nodes to `nodes` in original `gm` + + original_outputs (Tuple[Node, ...]): consumer nodes of `nodes` in original `gm` + + """ + + # assumption: nodes are already sorted in topo order + + for node in nodes: + assert node.graph.owning_module is gm, f"{node} doesn't belong to passed in graph module {gm._get_name()}" + assert not node._erased, f"{node} has been removed from owning graph" + assert node in gm.graph.nodes, f"{node} is not found in graph module {gm._get_name()}" + + # validates partition doesn't introduce dependency circles in the graph + assert validate_partition(nodes), "Invalid partition, found dependency cycles" + + subgraph = Graph() + + node_to_placeholder: Dict[Node, Node] = {} # mapping of nodes from old graph to placeholder in new graph + node_map: Dict[Node, Node] = {} # mapping of nodes from old graph to new graph + + # handles inputs through 
graph.node_copy's arg_transform functions + def remap_inputs(x): + if x.op == "get_attr": + # TODO: do we really need copy the get_attr node into the graph? + # do something here + pass + + if x in nodes: + # x is inside subgraph, return the copied node + # the node should have been copied aleady, as we are copying graph in the topological order + return node_map[x] + + if x not in node_to_placeholder: + # x is not in subgraph, create a new placeholder for subgraph + placeholder_node = subgraph.placeholder(x.name, type_expr=x.type) + # copy all meta fields, even if some fields might be irrelvant for the placeholder node + placeholder_node.meta = copy.copy(x.meta) + node_to_placeholder[x] = placeholder_node + + return node_to_placeholder[x] + + # copy nodes in topological order + for node in nodes: + new_node = subgraph.node_copy(node, remap_inputs) + node_map[node] = new_node + + # handles outputs + output_mapping: Dict[Node, Node] = {} # mapping from old output to new outputs + + for node in nodes: + for user_node in node.users: + if user_node not in nodes: + # external user node, need to expose as an output + output_mapping[node] = node_map[node] + + # outs contain nodes in the new subgraph + outs = tuple(output_mapping.values()) + + # Take care of the args of FX output node. If there's a single + # output then the output node args is like (output_single), else + # if there're multiple outputs then the output node args is like + # ((output_0, output_1, ...)). + subgraph.output(outs[0] if len(outs) == 1 else outs) + + # lint to ensure correctness + subgraph.lint() + fused_gm: GraphModule + fused_gm, _ = lift_subgraph_as_module(gm, subgraph, comp_name="", class_name=module_name) + + # sub_gm's input nodes in the original module + original_inputs: Tuple[Node, ...] = tuple(node_to_placeholder.keys()) + + # sub_gm's outputs node in the original module + original_outputs: Tuple[Node, ...] 
= tuple(output_mapping.keys()) + + return fused_gm, original_inputs, original_outputs + + +@compatibility(is_backward_compatible=False) +def insert_subgm(gm: GraphModule, sub_gm: GraphModule, orig_inputs: Tuple[Node, ...], orig_outputs: Tuple[Node, ...]): + # add sub_gm into gm + submodule_name = sub_gm.__class__.__name__ + gm.add_submodule(submodule_name, sub_gm) + + # Create a call_module node in main graph. + module_node = gm.graph.call_module( + submodule_name, + args=orig_inputs, + kwargs=None) + + if len(orig_outputs) == 1: + # main_remapping[comp.orig_outputs[0]] = module_node + orig_outputs[0].replace_all_uses_with(module_node, propagate_meta=True) + else: + for i, orig_output in enumerate(orig_outputs): + # Use Proxy to record getitem access. + proxy_out = torch.fx.Proxy(module_node)[i].node # type: ignore[index] + orig_output.replace_all_uses_with(proxy_out, propagate_meta=True) + + module_node.meta["val"] = tuple(orig_output.meta.get("val", None) for orig_output in orig_outputs) + return gm + +@compatibility(is_backward_compatible=False) +def erase_nodes(gm: GraphModule, nodes: NodeList): + + # erase original nodes in inversed topological order + for node in reversed(nodes): + gm.graph.erase_node(node) + + +@compatibility(is_backward_compatible=False) +def fuse_by_partitions(gm: GraphModule, partitions: List[NodeList], prefix: str = "fused_") -> GraphModule: + for partition_id, nodes in enumerate(partitions): + sorted_nodes = topo_sort(nodes) + + submodule_name = prefix + str(partition_id) + sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(gm, sorted_nodes, submodule_name) + + insert_subgm(gm, sub_gm, orig_inputs, orig_outputs) + + erase_nodes(gm, sorted_nodes) + + # topological sort original gm with newly created sub_gm + legalize_graph(gm) + + return gm diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/matcher_utils.py b/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/matcher_utils.py new file mode 100644 index 
0000000000000000000000000000000000000000..56b9d96348e8de5619f2a78682090d5e4448c5ce --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/matcher_utils.py @@ -0,0 +1,401 @@ +# mypy: allow-untyped-defs +from dataclasses import dataclass, field +from collections import defaultdict +import copy +import torch +from torch.fx import ( + Node, + Graph, +) +from torch.fx._compatibility import compatibility +from typing import Dict, List, Set, Any, Union, Tuple +import logging +import os + +__all__ = ['SubgraphMatcher', 'InternalMatch'] + +# Set`PYTORCH_MATCHER_LOGLEVEL=INFO` to see debug logs +def _init_logger(): + logger = logging.getLogger(__name__) + + level = os.environ.get('PYTORCH_MATCHER_LOGLEVEL', 'WARNING').upper() + logger.setLevel(level) + console = logging.StreamHandler() + formatter = logging.Formatter("%(filename)s > %(message)s") + console.setFormatter(formatter) + console.setLevel(level) + # add the handlers to the logger + logger.addHandler(console) + logger.propagate = False + return logger + +logger = _init_logger() + +@compatibility(is_backward_compatible=False) +@dataclass +class InternalMatch: + # Nodes from which the match was found + anchors: List[Node] + # Maps nodes in the pattern subgraph to nodes in the larger graph + nodes_map: Dict[Node, Node] = field(default_factory=dict) + + # nodes in target graph that are matched placeholder in pattern + placeholder_nodes: List[Node] = field(default_factory=list) + + # nodes in matched subgraph returned by output + returning_nodes: List[Node] = field(default_factory=list) + + # map from a string name to a node in the target graph + # only available if the matcher is `SubgraphMatcherWithNameNodesMap` + name_node_map: Dict[str, Node] = field(default_factory=dict) + + def __copy__(self): + return InternalMatch(anchors=self.anchors, nodes_map=self.nodes_map.copy(), + placeholder_nodes=self.placeholder_nodes.copy(), + returning_nodes=self.returning_nodes.copy()) + 
+@compatibility(is_backward_compatible=False) +class SubgraphMatcher: + def __init__(self, pattern: Graph, + match_output: bool = False, + match_placeholder: bool = False, + remove_overlapping_matches: bool = True, + ignore_literals: bool = False) -> None: + """ + Args: + pattern: the targeted matching pattern, represented in fx.Graph. + match_output: If True, output node in the pattern graph will be treated as a part of the targeted pattern. + If False, output node is ignored during match. + match_placeholder: If True, placeholder node in the pattern graph will be treated as a part of + the targeted pattern. If False, placeholder nodes will be used a wildcard. + remove_overlapping_matches: If True, in the case of overlapping matches, only the first match + will be returned. + ignore_literals: If True, will not check if literals are equal and + will instead treat them as wildcards. + """ + + self.pattern = pattern + self.match_output = match_output + self.match_placeholder = match_placeholder + self.remove_overlapping_matches = remove_overlapping_matches + self.ignore_literals = ignore_literals + + if len(pattern.nodes) == 0: + raise ValueError("SubgraphMatcher cannot be initialized with an empty pattern") + + for node in pattern.nodes: + if node.op != "output": + assert len(node.users) > 0, \ + "SubgraphMatcher cannot be initialized with an pattern with dead code" + + # TODO: assert pattern is a connected graph + + self.pattern_placeholder_nodes = [n for n in pattern.nodes if n.op == "placeholder"] + output_node = next(iter(reversed(pattern.nodes))) + # nodes returned by outputs + self.pattern_returning_nodes: List[Node] = output_node.all_input_nodes + + self.pattern_anchors: List[Node] = [] + if match_output: + self.pattern_anchors = [output_node] + else: + # If a node has output_node as the ONLY user, then this node is a graph sink, + # and should be matched against as an anchor + self.pattern_anchors = [n for n in output_node.all_input_nodes if len(n.users) == 
1] + + def _match_attributes(self, pn: Node, gn: Node) -> bool: + # Attributes matching is complicated. Right now we only support matching constant tensor + assert isinstance(pn.target, str), f"pn.target {pn.target} must be a string." + assert isinstance(gn.target, str), f"gn.target {gn.target} must be a string." + + # TODO(tmanlaibaatar) should probably make this actual API + def _getattr(model: torch.fx.GraphModule, attr_name: str): + *prefix, field = attr_name.split(".") + t = model + for item in prefix: + t = getattr(t, item, None) # type: ignore[assignment] + assert t is not None + + return getattr(t, field) + + pn_value = _getattr(pn.graph.owning_module, pn.target) + gn_value = _getattr(gn.graph.owning_module, gn.target) + + if type(pn_value) != type(gn_value): + return False + + # Don't require exact match on tensor values. + if isinstance(pn_value, torch.Tensor): + return isinstance(gn_value, torch.Tensor) + else: + raise RuntimeError(f"Unsupported type {pn_value} when matching attributes") + return False + + def _nodes_are_equal(self, pn: Node, gn: Node) -> bool: + # if exact match for placeholder is not required, then use placeholder as a wildcard + if not self.match_placeholder and pn.op == "placeholder": + return True + + if pn.op == gn.op: + if pn.op == "placeholder" or pn.op == "output": + return True + elif pn.op == "get_attr": + return self._match_attributes(pn, gn) + return pn.target == gn.target + return False + + def _is_contained(self, nodes_map: Dict[Node, Node]) -> bool: + # `lookup` represents all the nodes in `original_graph` + # that are part of `pattern` + + # Placeholders can be used by other nodes in the graphs + lookup: Dict[Node, Node] = {gn : pn for pn, gn in nodes_map.items() if pn.op != "placeholder"} + + for gn, pn in lookup.items(): + # nodes returned by output are allowed to be used in other areas of the graph + if pn in self.pattern_returning_nodes: + continue + + for user in gn.users: + # If this node has users that were not in 
`lookup`, then it must leak out of the + # pattern subgraph + if user not in lookup: + return False + return True + + def _remove_overlapping_matches(self, matches: List[InternalMatch]) -> List[InternalMatch]: + non_overlapping_matches: List[InternalMatch] = [] + nodes_matched: Set[Node] = set() + + for match in matches: + found_overlap = False + for pn, gn in match.nodes_map.items(): + if pn.op not in {"placeholder", "output"} and gn in nodes_matched: + found_overlap = True + break + + if not found_overlap: + non_overlapping_matches.append(match) + for pn, gn in match.nodes_map.items(): + if pn.op not in {"placeholder", "output"}: + nodes_matched.add(gn) + return non_overlapping_matches + + def _match_literals(self, pn: Any, gn: Any, match: InternalMatch) -> bool: + assert not (isinstance(pn, Node) and isinstance(gn, Node)), "pn and gn cannot both be Node" + + if isinstance(pn, Node) and not isinstance(gn, Node): + if pn.op == "placeholder": + # Check if we've already matched these nodes in the current + # traversal + if pn in match.nodes_map: + return match.nodes_map[pn] == gn + + match.nodes_map[pn] = gn + return True + else: + return False + elif not isinstance(pn, Node) and isinstance(gn, Node): + return False + else: + return type(gn) == type(pn) and gn == pn + + def _match_nodes(self, pn: Node, gn: Node, match: InternalMatch) -> bool: + logger.info(" matching %s to %s", pn, gn) + + assert isinstance(pn, Node) and isinstance(gn, Node), str(f"pn and gn must be Node, pn: {pn}, gn: {gn}") + + # Check if we've already matched these nodes in the current + # traversal + if pn in match.nodes_map: + return match.nodes_map[pn] == gn + + # TODO: use a more efficient way to check if gn is matched before: two-way dict + if gn in match.nodes_map.values(): + return False + + if not self._nodes_are_equal(pn, gn): + return False + + # Optimistically mark `pn` as a match for `gn`, and save a local copy of match + saved_match = copy.copy(match) + match.nodes_map[pn] = gn + + # 
Placeholder is a wildcard and can be matched with any python object + # (including list/tuple) + if pn.op == "placeholder": + return True + + # Recursively traverse upwards to check if `pn` is a true + # match for `gn` + match_found = True + + def _match_args(args1: Union[List, Tuple], args2: Union[List, Tuple]) -> bool: + if len(args1) != len(args2): + return False + + for a1, a2 in zip(args1, args2): + if isinstance(a1, Node) and isinstance(a2, Node): + matched = self._match_nodes(a1, a2, match) + elif isinstance(a1, (list, tuple)) and isinstance(a2, (list, tuple)): + matched = _match_args(a1, a2) + else: + matched = self._match_literals(a1, a2, match) or self.ignore_literals + + if not matched: + return False + + return True + + # Flatten all args/kwargs into 1 list of args + pn_args, gn_args = None, None + if ( + (len(pn.args) != len(gn.args) or list(pn.kwargs.keys()) != list(gn.kwargs.keys())) and + pn.op == "call_function" and + isinstance(pn.target, torch._ops.OpOverload) + ): + args_schema = pn.target._schema.arguments + + def get_all_arguments(orig_args, orig_kwargs): + all_args = [] + for i, schema in enumerate(args_schema): + if schema.name in orig_kwargs: + all_args.append(orig_kwargs[schema.name]) + elif not schema.kwarg_only and i < len(orig_args): + all_args.append(orig_args[i]) + else: + all_args.append(schema.default_value) + return all_args + + pn_args = get_all_arguments(pn.args, pn.kwargs) + gn_args = get_all_arguments(gn.args, gn.kwargs) + + elif len(pn.args) == len(gn.args) and list(pn.kwargs.keys()) == list(gn.kwargs.keys()): + pn_args = list(pn.args) + gn_args = list(gn.args) + pn_args.extend(list(pn.kwargs.values())) + gn_args.extend(list(gn.kwargs.values())) + else: + match_found = False + + match_found = ( + match_found and + pn_args is not None and + gn_args is not None and + _match_args(pn_args, gn_args) + ) + + if not match_found: + # revert to saved_match before matching with current node + match = copy.copy(saved_match) + return 
False + + return True + + def match(self, graph: Graph) -> List[InternalMatch]: + """ + Returns: + The matched subgraphs. + Thre returned subgraph would be fully self-contained, meaning the nodes (except placeholder + and nodes returned by output) can only be consumed by nodes within the matched subgraph. + + Subgraph pattern matcher is implemented with the backtracking style in the following steps: + + 1. We first identify all the anchor nodes in the pattern graph. The anchor nodes + are the "sinks" (nodes with no user other than the output node) of the pattern graph. + One pattern graph could have multiple anchors if it has multiple return values. + + 2. In the target graph, we identify the potential candidate nodes that can be matched + with each anchor. These anchor-candidate pairs are the starting points for + pairwise per-node matching. + + 3. For each anchor-candidate pair, we simultaneously traverse backwards (DFS) in both + pattern and target graphs. For every pattern nodes along traversal path, we compare it + against the target nodes. In case any comparison failed, the match for this anchor-candidate + pair fails. A match is found when DFS completes traversing the graph. See `self._match_nodes` + for more details. + + 4. In the case of multiple anchors, every anchor will need to find a match using step 3. + In addition, the matches found between anchors need to have a common intersection node + in order for the match to be valid. This is implemented with backtracking. See `backtracking` + for more details. + + Notice: graph traversal must be done in the reverser order because a tensor can have multiple + consumers, but can only have a single producer. Only with reverser order, we can we jointly + traverse the pattern and target graph in a deterministic path. + + Warning: In theory, this backtracking algorithm have an **exponential** time complexity. However, + in practice, it's unlikely to blow up. 
+ + """ + from torch.fx.passes.utils.fuser_utils import validate_partition + + # find candidate nodes to match with pattern anchors + match_candidates: Dict[Node, List[Node]] = defaultdict(list) + for pattern_anchor in self.pattern_anchors: + for node in graph.nodes: + if self._nodes_are_equal(pattern_anchor, node): + match_candidates[pattern_anchor].append(node) + match_candidates_list = list(match_candidates.items()) + + logger.info("Initial match_candidates_list: %s\n", match_candidates_list) + + matches: List[InternalMatch] = [] + + def backtracking(anchor_index, match): + if anchor_index == len(match_candidates_list): + match.placeholder_nodes = [match.nodes_map[pn] for pn in self.pattern_placeholder_nodes] + match.returning_nodes = [match.nodes_map[pn] for pn in self.pattern_returning_nodes] + matches.append(match) + + logger.info("Found a match: %s\n", match) + return + + pattern_anchor, candidate_nodes = match_candidates_list[anchor_index] + saved_match = copy.copy(match) + + for node in candidate_nodes: + logger.info("Trying to match anchor %s to %s", pattern_anchor, node) + + match_found = self._match_nodes(pattern_anchor, node, match) + if match_found: + # match next anchor + backtracking(anchor_index + 1, match) + else: + logger.info("Failed to match anchor %s to %s\n", pattern_anchor, node) + + # revert to saved_match before matching with current anchor + match = copy.copy(saved_match) + + match = InternalMatch(anchors=self.pattern_anchors) + if match_candidates_list: + backtracking(0, match) + + # filter out the matches where the subgraph is not fully_contained + before = len(matches) + matches = [match for match in matches if self._is_contained(match.nodes_map)] + after = len(matches) + if before != after: + logger.info("Filtered out %s matches because they are not fully contained", before - after) + + # filter out the matches that form a cycle if the subgraph is fused + valid_matches = [] + for match in matches: + matched_compute_nodes = \ + [gn for 
pn, gn in match.nodes_map.items() if pn.op not in {"placeholder", "output"}] + if validate_partition(matched_compute_nodes): + valid_matches.append(match) + if len(valid_matches) != len(matches): + logger.info("Filtered out %s matches because \ + matched subgraph would form a cycle if fused", len(matches) - len(valid_matches)) + + if self.remove_overlapping_matches: + before = len(valid_matches) + matches = self._remove_overlapping_matches(valid_matches) + after = len(matches) + if before != after: + logger.info("Filtered out %s matches because matched subgraphs are overlapping", before - after) + + logger.info("Matches returned: %s", matches) + + return matches diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/matcher_with_name_node_map_utils.py b/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/matcher_with_name_node_map_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8482dca74b1002ad5f4e4a66b8dd30b800fc2765 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/matcher_with_name_node_map_utils.py @@ -0,0 +1,114 @@ +from typing import Dict, List, Tuple + +from torch.fx import Graph, GraphModule, Node +from torch.fx._compatibility import compatibility + +from .matcher_utils import InternalMatch, SubgraphMatcher + + +__all__ = ["SubgraphMatcherWithNameNodeMap"] + + +def _split_to_graph_and_name_node_map( + gm: GraphModule, +) -> Tuple[GraphModule, Dict[str, Node]]: + from torch.fx.graph import _PyTreeInfo + from torch.utils._pytree import tree_flatten, tree_unflatten + + name_node_map = {} + for n in gm.graph.nodes: + if n.op == "output": + assert gm._out_spec is not None + output = tree_unflatten(n.args[0], gm._out_spec) + assert isinstance( + output, tuple + ), "Expecting the pattern graph to return a tuple" + assert ( + len(output) >= 2 + ), "Expecting the pattern graph to have at least two outputs" + *out, name_node_map = output + flattened, out_spec = tree_flatten(out) + assert 
isinstance( + name_node_map, Dict + ), "Expecting the input graph to have a dict output as the last element" + n.args = (flattened,) + orig_pytree_info = gm._graph._codegen.pytree_info # type: ignore[attr-defined] + gm._graph._codegen.pytree_info = _PyTreeInfo( # type: ignore[attr-defined] + orig_pytree_info.orig_args, orig_pytree_info.in_spec, out_spec + ) + gm.recompile() + return gm, name_node_map + + +@compatibility(is_backward_compatible=False) +class SubgraphMatcherWithNameNodeMap(SubgraphMatcher): + """Extends SubgraphMatcher to support querying the matched subgraph nodes through node name, + this requires pattern to have specific format (returning and additional dictionary at the output, + that has node name as key, and the node in the pattern graph as value, see Example for more details) + + Difference with SubgraphMatcher is that it takes a `pattern_gm` GraphModule as input during + initialization since we need to modify the graph (which requires `recompile` the GraphModule) + + Example:: + def pattern(x, weight): + conv = F.conv2d(x, weight) + relu = F.relu(conv) + return relu, {"conv": conv, "relu": relu} + + def target_graph(x, weight): + conv = F.conv2d(x, weight) + relu = F.relu(conv) + relu *= 2 + return relu + + pattern_gm = capture_pre_autograd_graph(pattern, example_inputs) + target_gm = capture_pre_autograd_graph(target_graph, example_inputs) + matcher = SubgraphMatcherWithNameNodeMap(pattern_gm) + matches = matcher.match(target_gm) + for match in matches: + match.name_node_map["conv"].meta["annotation"] = ... 
+ + """ + + def __init__( + self, + pattern_gm: GraphModule, + match_output: bool = False, + match_placeholder: bool = False, + remove_overlapping_matches: bool = True, + ignore_literals: bool = False, + ) -> None: + pattern_gm, name_node_map = _split_to_graph_and_name_node_map(pattern_gm) + self.name_node_map = name_node_map + super().__init__( + pattern_gm.graph, + match_output, + match_placeholder, + remove_overlapping_matches, + ignore_literals, + ) + + def match(self, graph: Graph) -> List[InternalMatch]: + """The returned InternalMatch will have name_node_map populated with a map + from node name (str) to the target node, e.g. + {"conv": target_conv_ndoe, "relu": target_relu_node} + + this requires the pattern graph returns an additional + output of node name to node, e.g. instead of: + ``` + def pattern(...): + ... + return relu + ``` + we should do: + ``` + def pattern(...): + ... + return relu, {"conv": conv, "relu": relu} + ``` instead + """ + internal_matches = super().match(graph) + for internal_match in internal_matches: + for k, n in self.name_node_map.items(): + internal_match.name_node_map[k] = internal_match.nodes_map[n] + return internal_matches diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/source_matcher_utils.py b/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/source_matcher_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0a4f072644cdbb01edab042ad42a8a617cce8314 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/source_matcher_utils.py @@ -0,0 +1,154 @@ +from dataclasses import dataclass, field +from torch.fx.graph import Graph +from torch.fx.node import Node +from torch.fx._compatibility import compatibility +from typing import Dict, List, Any, Type, Optional, Callable +import logging +import os + + +__all__ = ['get_source_partitions', 'check_subgraphs_connected', 'SourcePartition'] + +# Set`PYTORCH_MATCHER_LOGLEVEL=INFO` to see debug logs +def 
def _init_logger() -> logging.Logger:
    """Build the module logger; level comes from PYTORCH_MATCHER_LOGLEVEL."""
    module_logger = logging.getLogger(__name__)

    level = os.environ.get('PYTORCH_MATCHER_LOGLEVEL', 'WARNING').upper()
    module_logger.setLevel(level)
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter("%(filename)s > %(message)s"))
    handler.setLevel(level)
    # add the handler to the logger
    module_logger.addHandler(handler)
    module_logger.propagate = False
    return module_logger

logger = _init_logger()


@compatibility(is_backward_compatible=False)
@dataclass
class SourcePartition:
    # Nodes in a particular partition
    nodes: List[Node]

    # The source these nodes decomposed from
    source: Any

    # Nodes in the graph that are needed as inputs to the partition
    input_nodes: List[Node] = field(default_factory=list)

    # Nodes in the partition that are being used by nodes outside of the
    # partition
    output_nodes: List[Node] = field(default_factory=list)

    # Parameters that are being used
    params: List[Node] = field(default_factory=list)


@compatibility(is_backward_compatible=False)  # type: ignore[misc]
def get_source_partitions(
    graph: Graph,
    wanted_sources: List[Any],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Dict[Any, List[SourcePartition]]:
    """
    Args:
        graph: The graph we want to partition
        wanted_sources: List of sources of nodes that were decomposed from this
            source. This can be a function (ex. torch.nn.functional.linear) or a
            leaf module type (ex. torch.nn.Linear).

    Returns:
        Dictionary mapping sources that were given to a list of SourcePartitions
        that correspond to the list of nodes that were decomposed from the given
        source.
    """
    # source -> (partition name -> nodes in that partition)
    modules: Dict[Type, Dict[str, List[Node]]] = {}

    for node in graph.nodes:
        # The metadata source_fn should contain a tuple of a unique name for
        # the source, and the source function if the node is decomposed from a
        # function, or the type of module if the node is decomposed from a
        # leaf module.

        # TODO: Bypass "torch_fn" when "source_fn_stack" because now "torch_fn" can
        # be different from "source_fn_stack", for example for the add_ node
        # decomposed from batch norm. We should remove the check on "source_fn_stack"
        # after we fix "torch_fn". T199561090
        if ((source_fn_st := node.meta.get("source_fn_stack", None)) is None and
                (torch_fn := node.meta.get("torch_fn", None)) is not None):
            node_fqn, source_fn = torch_fn
            source_fn_name = source_fn.split(".")[1]
            if source_fn_name in wanted_sources:
                modules.setdefault(source_fn_name, {}).setdefault(node_fqn, []).append(node)

        if (source_fn_st := node.meta.get("source_fn_stack", None)) is not None:
            # The innermost (most recent) stack entry wins.
            partition_name, source = source_fn_st[-1]
            if source in wanted_sources:
                modules.setdefault(source, {}).setdefault(partition_name, []).append(node)

    def make_partition(nodes: List[Node], module_type: Type) -> SourcePartition:
        # Classify each node's edges: external producers become inputs,
        # nodes with external consumers become outputs.
        input_nodes = set()
        output_nodes = set()
        params = set()
        for node in nodes:
            for arg in node.args:
                if isinstance(arg, Node) and arg not in nodes:
                    input_nodes.add(arg)

            if node.op == "get_attr":
                params.add(node)

            for user in node.users.keys():
                if user not in nodes:
                    output_nodes.add(node)

        return SourcePartition(
            nodes,
            module_type,
            list(input_nodes),
            list(output_nodes),
            list(params),  # type: ignore[arg-type]
        )

    ret: Dict[Type[Any], List[SourcePartition]] = {}

    if filter_fn:
        # For each partition, apply filter_fn and drop every partition that
        # doesn't satisfy the filter condition.
        modules = {
            tp: {
                name: partition
                for name, partition in name_to_partition.items()
                if all(map(filter_fn, partition))
            }
            for tp, name_to_partition in modules.items()
        }

    for source, named_partitions in modules.items():
        ret[source] = [make_partition(partition, source) for partition in named_partitions.values()]

    return ret


@compatibility(is_backward_compatible=False)  # type: ignore[misc]
def check_subgraphs_connected(subgraph1: SourcePartition, subgraph2: SourcePartition) -> bool:
    """
    Given two subgraphs A and B (in the form of a list of nodes), checks if
    A has nodes connecting to at least one node in B -- aka there exists a node
    in B that uses a node in A (not the other way around).
    """
    return any(
        user in subgraph2.nodes
        for node in reversed(subgraph1.nodes)
        for user in node.users.keys()
    )