""" """ # pyright: reportPrivateImportUsage=false import gc import multiprocessing import os import shutil from collections import defaultdict from concurrent.futures import ProcessPoolExecutor from concurrent.futures import ThreadPoolExecutor from contextlib import nullcontext from contextvars import copy_context from pathlib import Path from typing import Any from typing import Callable import torch from torch.overrides import TorchFunctionMode from torch.overrides import resolve_name from torch.utils._python_dispatch import is_traceable_wrapper_subclass from torch.utils._python_dispatch import transform_subclass from torch.utils._python_dispatch import TorchDispatchMode from torch.utils._pytree import tree_map_only from torch.utils.weak import WeakTensorKeyDictionary from ...config import Config from ..mmap import unmap_capture from ..tqdm import tqdm from ..utils import malloc_trim from . import cudart from .packing import ZeroGPUTensorPack from .packing import pack_tensors from .packing import pack_to_cuda from .static import * from .utils import empty_like_raw_alloc from .types import AliasId PINNED_MEMORY_RATIO_LIMIT = 0.1 OPS_INPUTS_CHECK_NO_RETURN = ( torch.Tensor.equal, ) OPS_INPUT_CHECK_SELF_RETURN = ( torch.Tensor.set_, # probably never dispatched torch.ops.aten.set_.source_Tensor, # pyright: ignore [reportAttributeAccessIssue] ) OFFLOADED_ERROR_MESSAGE = "Cannot apply function {} on disk-offloaded Tensor {}" _tensor_make_subclass = torch.Tensor._make_subclass _asarray = torch.asarray _device = torch.device _cuda_init = torch._C._cuda_init _cuda_exchange_device = torch.cuda._exchange_device _cuda_available = torch.cuda.is_available _cuda_device_count = torch.cuda.device_count _cuda_current_device = torch.cuda.current_device _cuda_synchronize = torch.cuda.synchronize _cuda_get_device_capability = torch.cuda.get_device_capability _cuda_get_device_properties = torch.cuda.get_device_properties _cuda_get_device_name = torch.cuda.get_device_name _cuda_memory_stats_as_nested_dict = torch.cuda.memory.memory_stats_as_nested_dict _cuda_cudart = torch.cuda.cudart # PyTorch 2.3 _cuda_maybe_exchange_device = getattr(torch.cuda, '_maybe_exchange_device', None) cuda_aliases: dict[torch.Tensor, torch.Tensor | None] = WeakTensorKeyDictionary() # pyright: ignore [reportAssignmentType] tensor_packs: list[ZeroGPUTensorPack] = [] class ZeroGPUTensor(torch.Tensor): pass def empty_fake(tensor: torch.Tensor): fake = empty_like_raw_alloc(tensor, requires_grad=tensor.requires_grad) if fake.__class__ != tensor.__class__: fake = _tensor_make_subclass(tensor.__class__, fake, require_grad=tensor.requires_grad) # pyright: ignore [reportArgumentType] return fake # Torch 2.5: https://github.com/pytorch/pytorch/issues/144152 def no_int_device(*args, **kwargs): if len(args) and isinstance(index := args[0], int): args = (f'cuda:{index}', *args[1:]) if isinstance(index := kwargs.get('device'), int): kwargs['device'] = f'cuda:{index}' return args, kwargs class ZeroGPUFunctionMode(TorchFunctionMode): def __torch_function__(self, func, types, args=(), kwargs: dict[str, Any] | None = None): kwargs = {} if kwargs is None else kwargs if func == torch._C._nn._parse_to: # pyright: ignore [reportAttributeAccessIssue] args, kwargs = no_int_device(*args, **kwargs) return func(*args, **kwargs) # Redispatch: tensor.cuda() -> tensor.to(device='cuda') if func == torch.Tensor.cuda or func == torch.Tensor.cpu: memory_format = kwargs.get('memory_format') return self.__torch_function__(torch.Tensor.to, types, (args[0],), { 'device': 'cuda' if func == torch.Tensor.cuda else 'cpu', **({'memory_format': memory_format} if memory_format is not None else {}), }) # Redispatch: tensor.to('cuda') -> tensor.to(device='cuda') if func == torch.Tensor.to and len(args) > 1: parse_to_args, parse_to_kwargs = no_int_device(*args[1:], **kwargs) # We are using nn._parse_to utility to parse generic Tensor.to but nn does not accept copy kwarg copy_kwarg = {'copy': parse_to_kwargs.pop('copy')} if 'copy' in parse_to_kwargs else {} device, dtype, _, memory_format = torch._C._nn._parse_to(*parse_to_args, **parse_to_kwargs) # pyright: ignore [reportAttributeAccessIssue] return self.__torch_function__(torch.Tensor.to, types, (args[0],), { 'device': device, 'dtype': dtype, 'memory_format': memory_format, } | copy_kwarg) if func == torch.Tensor.data.__set__: # pyright: ignore [reportAttributeAccessIssue] self, target = args if target in cuda_aliases: if (target_original := cuda_aliases[target]) is None: raise Exception(OFFLOADED_ERROR_MESSAGE.format(resolve_name(func), target)) original = empty_fake(self) original.data = target_original cuda_aliases[self] = original elif self in cuda_aliases: del cuda_aliases[self] self.data = target return if func == torch.Tensor.device.__get__: tensor, = args if tensor in cuda_aliases: return torch.device('cuda', index=0) elif func == torch.Tensor.__repr__: tensor, = args if tensor in cuda_aliases: if (original := cuda_aliases[tensor]) is None: original = tensor.to('meta') original_class = original.__class__ original.__class__ = ZeroGPUTensor try: return func(original, **kwargs) finally: original.__class__ = original_class elif func == torch.Tensor.untyped_storage: tensor, = args if tensor in cuda_aliases: if (original := cuda_aliases[tensor]) is None: raise Exception(OFFLOADED_ERROR_MESSAGE.format(resolve_name(func), tensor)) res = func(original, **kwargs) res._zerogpu = True return res cuda: bool | None = None # Handle device kwarg if (device := kwargs.get('device')) is not None: device = torch.device(device) if device.type == 'cuda': kwargs['device'] = torch.device('cpu') cuda = True else: cuda = False # Swap fake inputs with original data swapped = {} inputs_are_cuda = set() def swap(tensor: torch.Tensor): nonlocal inputs_are_cuda if tensor not in cuda_aliases: inputs_are_cuda |= {False} return tensor if (original := cuda_aliases[tensor]) is None: raise Exception(OFFLOADED_ERROR_MESSAGE.format(resolve_name(func), tensor)) swapped[original] = tensor inputs_are_cuda |= {True} return original args_ = tree_map_only(torch.Tensor, swap, args) kwargs_ = tree_map_only(torch.Tensor, swap, kwargs) if inputs_are_cuda == {True}: if cuda is not False: cuda = True # Wrapper tensors special case (torchao quickix) if len(args) == 1 and is_traceable_wrapper_subclass(wrapper_tensor := args[0]): if func in { torch.Tensor.detach, torch.ops.aten.alias.default, # pyright: ignore [reportAttributeAccessIssue] torch.ops.aten.clone.default, # pyright: ignore [reportAttributeAccessIssue] }: with self: return transform_subclass(wrapper_tensor, lambda _, t: func(t)) res = func(*args_, **kwargs_) # Re-generate swapped fakes in case of mutation for original, fake in swapped.items(): fake.data = empty_fake(original) # Special case for Tensor indexing where only 'self' matters if func in { torch.ops.aten.index.Tensor, # pyright: ignore [reportAttributeAccessIssue] torch.Tensor.__getitem__, # PyTorch 2.4+ }: self = args[0] cuda = self in cuda_aliases inputs_are_cuda = {cuda} # Emulate device check if isinstance(res, torch.Tensor) or func in OPS_INPUTS_CHECK_NO_RETURN: self = None if len(args_) >= 1 and isinstance(args_[0], torch.Tensor): self = args_[0] # Only raise if func does not return its first input (Tensor.copy_) if res is not self or func in OPS_INPUT_CHECK_SELF_RETURN: if inputs_are_cuda == {True, False}: raise RuntimeError( "Expected all tensors to be on the same device, " "but found at least two devices, cuda:0 (ZeroGPU) and cpu!" ) # Register output def register(tensor: torch.Tensor): if tensor in swapped and cuda is not False: return swapped[tensor] if cuda is not True: return tensor fake = empty_fake(tensor) cuda_aliases[fake] = tensor return fake return tree_map_only(torch.Tensor, register, res) # When enabling DispatchMode, some aten ops are dispatched to FunctionMode # We are using it for aten.alias.default and aten.set_.source_Tensor class DefaultDispatchMode(TorchDispatchMode): def __torch_dispatch__(self, func, types, args=(), kwargs: dict[str, Any] | None = None): return func(*args, **(kwargs or {})) function_mode = ZeroGPUFunctionMode() dispatch_mode = DefaultDispatchMode() def _untyped_storage_new_register(*args, **kwargs): cuda = False if (device := kwargs.get('device')) is not None: device = torch.device(device) if device.type == 'cuda': cuda = True del kwargs['device'] storage = torch._C.StorageBase.__new__(*args, **kwargs) if cuda: storage._zerogpu = True return storage @property def _untyped_storage_device(self): if hasattr(self, '_zerogpu'): return torch.device('cuda', index=0) return torch._C.StorageBase.device.__get__(self) # pyright: ignore [reportAttributeAccessIssue] # Force dispatch def _tensor_make_subclass_function_mode(*args, **kwargs): with torch._C.DisableTorchFunction(): return function_mode.__torch_function__(_tensor_make_subclass, (), args=args, kwargs=kwargs) def _asarray_function_mode(*args, **kwargs): with torch._C.DisableTorchFunction(): return function_mode.__torch_function__(_asarray, (), args=args, kwargs=kwargs) class _DeviceStringOnlyMeta(type): def __instancecheck__(cls, instance): return isinstance(instance, _device) class _DeviceStringOnly(metaclass=_DeviceStringOnlyMeta): def __new__(cls, *args, **kwargs): args, kwargs = no_int_device(*args, **kwargs) return _device(*args, **kwargs) def _cuda_init_raise(): raise RuntimeError( "Low-level CUDA init (`torch._C._cuda_init`) reached. " "This means ZeroGPU's PyTorch CUDA emulation mode " "did not intercept a CUDA operation in your code.\n" "Check this stacktrace to locate the trigger." ) def _cuda_dummy_exchange_device(device): assert device in {-1, 0} return device def patch(): function_mode.__enter__() dispatch_mode.__enter__() # TODO: only patch bellow methods on current Thread to be consistent with TorchModes # (or hijack threading.Thread.__init__ to force Modes on all threads) torch.Tensor._make_subclass = _tensor_make_subclass_function_mode # pyright: ignore [reportAttributeAccessIssue] torch.UntypedStorage.__new__ = _untyped_storage_new_register torch.UntypedStorage.device = _untyped_storage_device # pyright: ignore [reportAttributeAccessIssue] torch.asarray = _asarray_function_mode torch.device = _DeviceStringOnly torch._C._cuda_init = _cuda_init_raise torch.cuda._exchange_device = _cuda_dummy_exchange_device torch.cuda.is_available = lambda: True torch.cuda.device_count = lambda: 1 torch.cuda.current_device = lambda: 0 torch.cuda.synchronize = lambda *args: None torch.cuda.get_device_capability = lambda *args, **kwargs: CUDA_DEVICE_CAPABILITY torch.cuda.get_device_properties = lambda *args, **kwargs: CUDA_DEVICE_PROPERTIES torch.cuda.get_device_name = lambda *args, **kwargs: CUDA_DEVICE_NAME torch.cuda.memory.memory_stats_as_nested_dict = lambda *args, **kwargs: CUDA_MEMORY_STATS_AS_NESTED_DICT torch.cuda.cudart = lambda: cudart # PyTorch 2.3 if _cuda_maybe_exchange_device is not None: # pragma: no cover setattr(torch.cuda, '_maybe_exchange_device', _cuda_dummy_exchange_device) bitsandbytes().patch() def unpatch(): try: dispatch_mode.__exit__(None, None, None) function_mode.__exit__(None, None, None) except RuntimeError: pass # patch() and unpatch() called from != threads torch.Tensor._make_subclass = _tensor_make_subclass torch.UntypedStorage.__new__ = torch._C.StorageBase.__new__ torch.UntypedStorage.device = torch._C.StorageBase.device # pyright: ignore [reportAttributeAccessIssue] torch.asarray = _asarray torch.device = _device torch._C._cuda_init = _cuda_init torch.cuda._exchange_device = _cuda_exchange_device torch.cuda.is_available = _cuda_available torch.cuda.device_count = _cuda_device_count torch.cuda.current_device = _cuda_current_device torch.cuda.synchronize = _cuda_synchronize torch.cuda.get_device_capability = _cuda_get_device_capability torch.cuda.get_device_properties = _cuda_get_device_properties torch.cuda.get_device_name = _cuda_get_device_name torch.cuda.memory.memory_stats_as_nested_dict = _cuda_memory_stats_as_nested_dict torch.cuda.cudart = _cuda_cudart # PyTorch 2.3 if _cuda_maybe_exchange_device is not None: # pragma: no cover setattr(torch.cuda, '_maybe_exchange_device', _cuda_exchange_device) bitsandbytes().unpatch() def _total_unpacked_size(): tensors = [tensor for tensor in cuda_aliases.values() if tensor is not None] deduped = {AliasId.from_tensor(tensor): tensor for tensor in tensors} return sum([tensor.numel() * tensor.element_size() for tensor in deduped.values()]) def _pack(offload_dir: str): # Pack to disk originals: set[torch.Tensor] = set() originals_dedup: dict[AliasId, torch.Tensor] = {} fakes: dict[torch.Tensor, list[torch.Tensor]] = defaultdict(list) for fake, original in cuda_aliases.items(): # TODO filter-out sparse Tensors if original is not None: original_id = AliasId.from_tensor(original) if original_id not in originals_dedup: originals_dedup[original_id] = original originals |= {original} fakes[originals_dedup[original_id]] += [fake] total_size = _total_unpacked_size() progress = tqdm( total=total_size, unit='B', unit_scale=True, desc="ZeroGPU tensors packing", ) if tqdm is not None else nullcontext() with progress as progress: update = progress.update if progress is not None else lambda _: None pack = pack_tensors(originals, fakes, offload_dir, callback=update) tensor_packs.append(pack) # Free memory for fake_list in fakes.values(): for fake in fake_list: cuda_aliases[fake] = None return total_size def pack(): if len(cuda_aliases) == 0: return 0 shutil.rmtree(Config.zerogpu_offload_dir, ignore_errors=True) Path(Config.zerogpu_offload_dir).mkdir(parents=True) mmap_addrs = [tensor.data_ptr() for tensor in cuda_aliases.values() if tensor is not None] with unmap_capture(mmap_addrs) as unmapped_paths: total_size = _pack(Config.zerogpu_offload_dir) total_cleaned_size = 0 for path in unmapped_paths: total_cleaned_size += path.lstat().st_size path.unlink(missing_ok=True) if total_cleaned_size > 0: print( f"Cleaned {total_cleaned_size / 2**30:.2f}GB of tensor files " f"from {Config.zerogpu_mmap_autoprune_pattern} after packing" ) gc.collect() malloc_trim() return total_size def init(nvidia_uuid: str): os.environ['CUDA_VISIBLE_DEVICES'] = nvidia_uuid torch.Tensor([0]).cuda() def size(): return _total_unpacked_size() + sum([pack.total_size for pack in tensor_packs]) def _move(callback: Callable[[int], Any] | None = None): callback = callback if callback is not None else lambda _: None # CPU -> CUDA pinned_limit = _total_unpacked_size() * PINNED_MEMORY_RATIO_LIMIT moved: dict[AliasId, torch.Tensor] = {} for fake, original in cuda_aliases.items(): if original is not None: original = torch.Tensor(original) # unwrap subclass original_id = AliasId.from_tensor(original) if original_id not in moved: if original.numel() * original.element_size() < pinned_limit: original_cuda = original.pin_memory().cuda(non_blocking=True) else: original_cuda = original.cuda() moved[original_id] = original_cuda callback(fake.numel() * fake.element_size()) torch.cuda.synchronize() for fake, original in cuda_aliases.items(): if original is not None: fake.data = moved[AliasId.from_tensor(original)] # Disk -> CUDA for tensor_pack in tensor_packs: pack_to_cuda(tensor_pack, callback=callback) bitsandbytes().move() def move(callback: Callable[[int], Any] | None = None): callback = callback if callback is not None else lambda _: None with ThreadPoolExecutor(1) as e: e.submit(copy_context().run, _move, callback=callback).result() torch.cuda.synchronize() def is_in_bad_fork(): with ProcessPoolExecutor(mp_context=multiprocessing.get_context('fork')) as e: f = e.submit(torch.cuda._is_in_bad_fork) return f.result() def bitsandbytes(): # Lazy import from . import bitsandbytes return bitsandbytes