Spaces:
Running on Zero
Running on Zero
| """ | |
| """ | |
| # pyright: reportPrivateImportUsage=false | |
| import gc | |
| import multiprocessing | |
| import os | |
| import shutil | |
| from collections import defaultdict | |
| from concurrent.futures import ProcessPoolExecutor | |
| from concurrent.futures import ThreadPoolExecutor | |
| from contextlib import nullcontext | |
| from contextvars import copy_context | |
| from pathlib import Path | |
| from typing import Any | |
| from typing import Callable | |
| import torch | |
| from torch.overrides import TorchFunctionMode | |
| from torch.overrides import resolve_name | |
| from torch.utils._python_dispatch import is_traceable_wrapper_subclass | |
| from torch.utils._python_dispatch import transform_subclass | |
| from torch.utils._python_dispatch import TorchDispatchMode | |
| from torch.utils._pytree import tree_map_only | |
| from torch.utils.weak import WeakTensorKeyDictionary | |
| from ...config import Config | |
| from ..mmap import unmap_capture | |
| from ..tqdm import tqdm | |
| from ..utils import malloc_trim | |
| from . import cudart | |
| from .packing import ZeroGPUTensorPack | |
| from .packing import pack_tensors | |
| from .packing import pack_to_cuda | |
| from .static import * | |
| from .utils import empty_like_raw_alloc | |
| from .types import AliasId | |
# In _move(), a tensor whose byte size is below this fraction of the total
# in-memory (unpacked) size is copied to CUDA through pinned memory with a
# non-blocking transfer; larger tensors use a plain .cuda() copy.
PINNED_MEMORY_RATIO_LIMIT = 0.1

# Ops whose inputs are device-checked (mixed ZeroGPU/CPU inputs raise) even
# though they do not return a Tensor (Tensor.equal returns a bool).
OPS_INPUTS_CHECK_NO_RETURN = (
    torch.Tensor.equal,
)

# Ops that return their first input but must still be device-checked; the
# generic check is otherwise skipped when `res is self` (e.g. Tensor.copy_).
OPS_INPUT_CHECK_SELF_RETURN = (
    torch.Tensor.set_, # probably never dispatched
    torch.ops.aten.set_.source_Tensor, # pyright: ignore [reportAttributeAccessIssue]
)

# Raised when an op touches a fake whose original has been packed to disk
# (its cuda_aliases entry is None). Formatted with (function name, tensor).
OFFLOADED_ERROR_MESSAGE = "Cannot apply function {} on disk-offloaded Tensor {}"

# Original torch entry points, captured at import time before patch() replaces
# them. unpatch() restores them, and several emulated replacements in this
# module call through to them.
_tensor_make_subclass = torch.Tensor._make_subclass
_asarray = torch.asarray
_device = torch.device
_cuda_init = torch._C._cuda_init
_cuda_exchange_device = torch.cuda._exchange_device
_cuda_available = torch.cuda.is_available
_cuda_device_count = torch.cuda.device_count
_cuda_current_device = torch.cuda.current_device
_cuda_synchronize = torch.cuda.synchronize
_cuda_get_device_capability = torch.cuda.get_device_capability
_cuda_get_device_properties = torch.cuda.get_device_properties
_cuda_get_device_name = torch.cuda.get_device_name
_cuda_memory_stats_as_nested_dict = torch.cuda.memory.memory_stats_as_nested_dict
_cuda_cudart = torch.cuda.cudart
# PyTorch 2.3: `_maybe_exchange_device` only exists on some PyTorch versions;
# None when absent, in which case patch()/unpatch() skip it.
_cuda_maybe_exchange_device = getattr(torch.cuda, '_maybe_exchange_device', None)

# Weakly-keyed map from each fake ("cuda") tensor handed to user code to its
# real backing tensor. The value becomes None once _pack() has offloaded the
# original to disk.
cuda_aliases: dict[torch.Tensor, torch.Tensor | None] = WeakTensorKeyDictionary() # pyright: ignore [reportAssignmentType]
# Disk packs produced by _pack(); moved to CUDA by _move() via pack_to_cuda().
tensor_packs: list[ZeroGPUTensorPack] = []
class ZeroGPUTensor(torch.Tensor):
    """Bare ``torch.Tensor`` subclass with no overrides.

    ``ZeroGPUFunctionMode`` temporarily assigns it as the ``__class__`` of a
    tensor while computing that tensor's ``repr``.
    """
def empty_fake(tensor: torch.Tensor):
    """Allocate a stand-in ("fake") tensor mirroring ``tensor``.

    The storage comes from ``empty_like_raw_alloc`` with the same
    ``requires_grad`` flag. If that helper does not return an instance of
    ``tensor``'s exact class, the result is re-wrapped with the original
    ``torch.Tensor._make_subclass`` so the fake has the same class.
    """
    fake = empty_like_raw_alloc(tensor, requires_grad=tensor.requires_grad)
    if fake.__class__ == tensor.__class__:
        return fake
    return _tensor_make_subclass(tensor.__class__, fake, require_grad=tensor.requires_grad) # pyright: ignore [reportArgumentType]
# Torch 2.5: https://github.com/pytorch/pytorch/issues/144152
def no_int_device(*args, **kwargs):
    """Rewrite bare integer device indices into ``'cuda:<index>'`` strings.

    Only the first positional argument and the ``device`` keyword are
    inspected; everything else is passed through untouched.

    Returns:
        An ``(args, kwargs)`` pair ready to be re-splatted into the original call.
    """
    if args and isinstance(args[0], int):
        first, *rest = args
        args = (f'cuda:{first}', *rest)
    device_kw = kwargs.get('device')
    if isinstance(device_kw, int):
        kwargs['device'] = f'cuda:{device_kw}'
    return args, kwargs
class ZeroGPUFunctionMode(TorchFunctionMode):
    """Torch-function mode that emulates a single CUDA device (cuda:0) on top of CPU tensors.

    Tensors that user code believes live on CUDA are "fakes" registered in
    `cuda_aliases`, which maps each fake to its real backing tensor (or None
    once that tensor has been packed to disk). Every intercepted call swaps
    fakes for their originals, runs the op, and re-registers tensor outputs
    as new fakes when the call is considered a CUDA call.
    """

    def __torch_function__(self, func, types, args=(), kwargs: dict[str, Any] | None = None):

        kwargs = {} if kwargs is None else kwargs

        # Normalise integer device indices before letting _parse_to run.
        if func == torch._C._nn._parse_to: # pyright: ignore [reportAttributeAccessIssue]
            args, kwargs = no_int_device(*args, **kwargs)
            return func(*args, **kwargs)

        # Redispatch: tensor.cuda() -> tensor.to(device='cuda')
        # (tensor.cpu() likewise becomes tensor.to(device='cpu'); only memory_format is forwarded)
        if func == torch.Tensor.cuda or func == torch.Tensor.cpu:
            memory_format = kwargs.get('memory_format')
            return self.__torch_function__(torch.Tensor.to, types, (args[0],), {
                'device': 'cuda' if func == torch.Tensor.cuda else 'cpu',
                **({'memory_format': memory_format} if memory_format is not None else {}),
            })

        # Redispatch: tensor.to('cuda') -> tensor.to(device='cuda')
        # Positional forms of Tensor.to are normalised into explicit keyword arguments.
        if func == torch.Tensor.to and len(args) > 1:
            parse_to_args, parse_to_kwargs = no_int_device(*args[1:], **kwargs)
            # We are using nn._parse_to utility to parse generic Tensor.to but nn does not accept copy kwarg
            copy_kwarg = {'copy': parse_to_kwargs.pop('copy')} if 'copy' in parse_to_kwargs else {}
            device, dtype, _, memory_format = torch._C._nn._parse_to(*parse_to_args, **parse_to_kwargs) # pyright: ignore [reportAttributeAccessIssue]
            return self.__torch_function__(torch.Tensor.to, types, (args[0],), {
                'device': device,
                'dtype': dtype,
                'memory_format': memory_format,
            } | copy_kwarg)

        # `tensor.data = target`: keep cuda_aliases consistent with the new data.
        if func == torch.Tensor.data.__set__: # pyright: ignore [reportAttributeAccessIssue]
            # NOTE: `self` is rebound to the tensor being assigned to (shadowing the mode).
            self, target = args
            if target in cuda_aliases:
                # Assigning a fake: the receiver becomes a fake too, backed by a new
                # wrapper whose .data is the target's original.
                if (target_original := cuda_aliases[target]) is None:
                    raise Exception(OFFLOADED_ERROR_MESSAGE.format(resolve_name(func), target))
                original = empty_fake(self)
                original.data = target_original
                cuda_aliases[self] = original
            elif self in cuda_aliases:
                # Assigning a non-fake to a former fake: it is no longer a CUDA tensor.
                del cuda_aliases[self]
            self.data = target
            return

        # Fakes report cuda:0 as their device.
        if func == torch.Tensor.device.__get__:
            tensor, = args
            if tensor in cuda_aliases:
                return torch.device('cuda', index=0)

        # repr of a fake: render the original (or a meta copy of the fake when the
        # original was offloaded) with its class temporarily swapped to ZeroGPUTensor.
        # NOTE(review): the exact effect of the class swap on the rendered text is
        # not visible here — presumably it controls how the repr is formatted; confirm.
        elif func == torch.Tensor.__repr__:
            tensor, = args
            if tensor in cuda_aliases:
                if (original := cuda_aliases[tensor]) is None:
                    original = tensor.to('meta')
                original_class = original.__class__
                original.__class__ = ZeroGPUTensor
                try:
                    return func(original, **kwargs)
                finally:
                    original.__class__ = original_class

        # untyped_storage() of a fake: return the original's storage, flagged with
        # `_zerogpu` so the patched UntypedStorage.device reports cuda:0.
        elif func == torch.Tensor.untyped_storage:
            tensor, = args
            if tensor in cuda_aliases:
                if (original := cuda_aliases[tensor]) is None:
                    raise Exception(OFFLOADED_ERROR_MESSAGE.format(resolve_name(func), tensor))
                res = func(original, **kwargs)
                res._zerogpu = True
                return res

        # Tri-state: True = outputs become fakes, False = outputs stay CPU,
        # None = undecided (no device kwarg and not all inputs are fakes).
        cuda: bool | None = None

        # Handle device kwarg
        # A CUDA device request is rewritten to CPU; the intent is remembered in `cuda`.
        if (device := kwargs.get('device')) is not None:
            device = torch.device(device)
            if device.type == 'cuda':
                kwargs['device'] = torch.device('cpu')
                cuda = True
            else:
                cuda = False

        # Swap fake inputs with original data
        # `swapped` maps original -> fake; `inputs_are_cuda` records whether fake
        # (True) and/or plain (False) tensor inputs were seen.
        swapped = {}
        inputs_are_cuda = set()
        def swap(tensor: torch.Tensor):
            nonlocal inputs_are_cuda
            if tensor not in cuda_aliases:
                inputs_are_cuda |= {False}
                return tensor
            if (original := cuda_aliases[tensor]) is None:
                raise Exception(OFFLOADED_ERROR_MESSAGE.format(resolve_name(func), tensor))
            swapped[original] = tensor
            inputs_are_cuda |= {True}
            return original
        args_ = tree_map_only(torch.Tensor, swap, args)
        kwargs_ = tree_map_only(torch.Tensor, swap, kwargs)
        # All tensor inputs are fakes: treat as a CUDA call unless a CPU device was requested.
        if inputs_are_cuda == {True}:
            if cuda is not False:
                cuda = True

        # Wrapper tensors special case (torchao quickfix)
        # For single-argument detach/alias/clone on a traceable wrapper subclass, apply
        # func to each inner tensor with this mode re-entered (`with self:`), so the
        # inner calls go through __torch_function__ again.
        if len(args) == 1 and is_traceable_wrapper_subclass(wrapper_tensor := args[0]):
            if func in {
                torch.Tensor.detach,
                torch.ops.aten.alias.default, # pyright: ignore [reportAttributeAccessIssue]
                torch.ops.aten.clone.default, # pyright: ignore [reportAttributeAccessIssue]
            }:
                with self:
                    return transform_subclass(wrapper_tensor, lambda _, t: func(t))

        res = func(*args_, **kwargs_)

        # Re-generate swapped fakes in case of mutation
        # (e.g. an in-place op may have changed the original's shape or dtype).
        for original, fake in swapped.items():
            fake.data = empty_fake(original)

        # Special case for Tensor indexing where only 'self' matters
        # NOTE: `self` is rebound to the indexed tensor here (shadowing the mode).
        if func in {
            torch.ops.aten.index.Tensor, # pyright: ignore [reportAttributeAccessIssue]
            torch.Tensor.__getitem__, # PyTorch 2.4+
        }:
            self = args[0]
            cuda = self in cuda_aliases
            inputs_are_cuda = {cuda}

        # Emulate device check
        # Mixing fake and plain tensor inputs raises, mimicking real CUDA/CPU mixing.
        if isinstance(res, torch.Tensor) or func in OPS_INPUTS_CHECK_NO_RETURN:
            self = None
            if len(args_) >= 1 and isinstance(args_[0], torch.Tensor):
                self = args_[0]
            # Only raise if func does not return its first input (Tensor.copy_)
            if res is not self or func in OPS_INPUT_CHECK_SELF_RETURN:
                if inputs_are_cuda == {True, False}:
                    raise RuntimeError(
                        "Expected all tensors to be on the same device, "
                        "but found at least two devices, cuda:0 (ZeroGPU) and cpu!"
                    )

        # Register output
        # - an output that is one of the swapped originals maps back to its existing
        #   fake (unless the call was explicitly a CPU call);
        # - otherwise, CUDA calls wrap each output tensor in a new fake;
        # - non-CUDA calls return outputs unchanged.
        def register(tensor: torch.Tensor):
            if tensor in swapped and cuda is not False:
                return swapped[tensor]
            if cuda is not True:
                return tensor
            fake = empty_fake(tensor)
            cuda_aliases[fake] = tensor
            return fake

        return tree_map_only(torch.Tensor, register, res)
| # When enabling DispatchMode, some aten ops are dispatched to FunctionMode | |
| # We are using it for aten.alias.default and aten.set_.source_Tensor | |
| class DefaultDispatchMode(TorchDispatchMode): | |
| def __torch_dispatch__(self, func, types, args=(), kwargs: dict[str, Any] | None = None): | |
| return func(*args, **(kwargs or {})) | |
# Module-level mode singletons: patch() enters both (function mode first) and
# unpatch() exits them (dispatch mode first).
function_mode = ZeroGPUFunctionMode()
dispatch_mode = DefaultDispatchMode()
| def _untyped_storage_new_register(*args, **kwargs): | |
| cuda = False | |
| if (device := kwargs.get('device')) is not None: | |
| device = torch.device(device) | |
| if device.type == 'cuda': | |
| cuda = True | |
| del kwargs['device'] | |
| storage = torch._C.StorageBase.__new__(*args, **kwargs) | |
| if cuda: | |
| storage._zerogpu = True | |
| return storage | |
| def _untyped_storage_device(self): | |
| if hasattr(self, '_zerogpu'): | |
| return torch.device('cuda', index=0) | |
| return torch._C.StorageBase.device.__get__(self) # pyright: ignore [reportAttributeAccessIssue] | |
# Force dispatch
def _tensor_make_subclass_function_mode(*args, **kwargs):
    """``torch.Tensor._make_subclass`` replacement that routes through ``function_mode``.

    Torch-function handling is disabled while ``function_mode.__torch_function__``
    is invoked explicitly with the original ``_make_subclass`` as ``func``.
    """
    guard = torch._C.DisableTorchFunction()
    with guard:
        return function_mode.__torch_function__(_tensor_make_subclass, (), args, kwargs)
def _asarray_function_mode(*args, **kwargs):
    """``torch.asarray`` replacement that routes through ``function_mode``.

    Mirrors ``_tensor_make_subclass_function_mode``: torch-function handling is
    disabled and ``function_mode.__torch_function__`` is called explicitly with
    the original ``torch.asarray`` as ``func``.
    """
    guard = torch._C.DisableTorchFunction()
    with guard:
        return function_mode.__torch_function__(_asarray, (), args, kwargs)
class _DeviceStringOnlyMeta(type):
    """Metaclass making ``isinstance(x, <patched torch.device>)`` test against the real ``torch.device``."""

    def __instancecheck__(cls, instance):
        real_device_type = _device
        return isinstance(instance, real_device_type)
class _DeviceStringOnly(metaclass=_DeviceStringOnlyMeta):
    """Stand-in for ``torch.device`` installed by ``patch()``.

    Constructing it rewrites a bare integer device index into ``'cuda:<index>'``
    (see ``no_int_device``) and returns a real ``torch.device`` instance.
    """

    def __new__(cls, *args, **kwargs):
        fixed_args, fixed_kwargs = no_int_device(*args, **kwargs)
        return _device(*fixed_args, **fixed_kwargs)
| def _cuda_init_raise(): | |
| raise RuntimeError( | |
| "Low-level CUDA init (`torch._C._cuda_init`) reached. " | |
| "This means ZeroGPU's PyTorch CUDA emulation mode " | |
| "did not intercept a CUDA operation in your code.\n" | |
| "Check this stacktrace to locate the trigger." | |
| ) | |
| def _cuda_dummy_exchange_device(device): | |
| assert device in {-1, 0} | |
| return device | |
def patch():
    """Switch torch into ZeroGPU CUDA-emulation mode.

    Enters ``function_mode`` and ``dispatch_mode`` on the calling thread, then
    replaces a set of ``torch`` / ``torch.cuda`` entry points with emulated
    versions that report a single CUDA device (index 0). It also patches the
    lazily-imported ``bitsandbytes`` helper. ``unpatch()`` reverses every
    replacement made here.
    """
    function_mode.__enter__()
    dispatch_mode.__enter__()

    # TODO: only patch below methods on current Thread to be consistent with TorchModes
    # (or hijack threading.Thread.__init__ to force Modes on all threads)
    torch.Tensor._make_subclass = _tensor_make_subclass_function_mode # pyright: ignore [reportAttributeAccessIssue]
    torch.UntypedStorage.__new__ = _untyped_storage_new_register
    # `storage.device` is an attribute, not a method: install the getter as a property
    # so it keeps returning a torch.device (a bare function would turn it into a bound
    # method). unpatch() restores the original StorageBase descriptor.
    torch.UntypedStorage.device = property(_untyped_storage_device) # pyright: ignore [reportAttributeAccessIssue]
    torch.asarray = _asarray_function_mode
    torch.device = _DeviceStringOnly
    torch._C._cuda_init = _cuda_init_raise
    torch.cuda._exchange_device = _cuda_dummy_exchange_device
    torch.cuda.is_available = lambda: True
    torch.cuda.device_count = lambda: 1
    torch.cuda.current_device = lambda: 0
    # Also accept keyword arguments: torch.cuda.synchronize(device=...) is a valid call.
    torch.cuda.synchronize = lambda *args, **kwargs: None
    torch.cuda.get_device_capability = lambda *args, **kwargs: CUDA_DEVICE_CAPABILITY
    torch.cuda.get_device_properties = lambda *args, **kwargs: CUDA_DEVICE_PROPERTIES
    torch.cuda.get_device_name = lambda *args, **kwargs: CUDA_DEVICE_NAME
    torch.cuda.memory.memory_stats_as_nested_dict = lambda *args, **kwargs: CUDA_MEMORY_STATS_AS_NESTED_DICT
    torch.cuda.cudart = lambda: cudart

    # PyTorch 2.3
    if _cuda_maybe_exchange_device is not None: # pragma: no cover
        setattr(torch.cuda, '_maybe_exchange_device', _cuda_dummy_exchange_device)

    bitsandbytes().patch()
def unpatch():
    """Leave ZeroGPU CUDA-emulation mode and restore every original torch entry point.

    Exiting the two modes is best-effort: a ``RuntimeError`` from either exit is
    tolerated because ``patch()`` and ``unpatch()`` may run on different threads.
    Each mode is exited independently, so a failure to pop the dispatch mode no
    longer skips the attempt to pop the function mode. All monkey-patched torch
    attributes are then restored from the originals captured at import time.
    """
    # Exit in reverse order of patch(); each exit is attempted on its own.
    try:
        dispatch_mode.__exit__(None, None, None)
    except RuntimeError:
        pass # patch() and unpatch() called from != threads
    try:
        function_mode.__exit__(None, None, None)
    except RuntimeError:
        pass # patch() and unpatch() called from != threads

    torch.Tensor._make_subclass = _tensor_make_subclass
    torch.UntypedStorage.__new__ = torch._C.StorageBase.__new__
    torch.UntypedStorage.device = torch._C.StorageBase.device # pyright: ignore [reportAttributeAccessIssue]
    torch.asarray = _asarray
    torch.device = _device
    torch._C._cuda_init = _cuda_init
    torch.cuda._exchange_device = _cuda_exchange_device
    torch.cuda.is_available = _cuda_available
    torch.cuda.device_count = _cuda_device_count
    torch.cuda.current_device = _cuda_current_device
    torch.cuda.synchronize = _cuda_synchronize
    torch.cuda.get_device_capability = _cuda_get_device_capability
    torch.cuda.get_device_properties = _cuda_get_device_properties
    torch.cuda.get_device_name = _cuda_get_device_name
    torch.cuda.memory.memory_stats_as_nested_dict = _cuda_memory_stats_as_nested_dict
    torch.cuda.cudart = _cuda_cudart

    # PyTorch 2.3
    if _cuda_maybe_exchange_device is not None: # pragma: no cover
        setattr(torch.cuda, '_maybe_exchange_device', _cuda_exchange_device)

    bitsandbytes().unpatch()
def _total_unpacked_size():
    """Return the byte size of all still-in-memory originals in ``cuda_aliases``.

    Originals that share an ``AliasId`` are counted once. Entries whose value is
    None (already packed to disk) are ignored.
    """
    unique: dict[AliasId, torch.Tensor] = {}
    for backing in cuda_aliases.values():
        if backing is not None:
            unique[AliasId.from_tensor(backing)] = backing
    return sum(t.numel() * t.element_size() for t in unique.values())
def _pack(offload_dir: str):
    """Pack every in-memory original tensor into ``offload_dir`` and drop the in-memory copies.

    Originals are deduplicated by ``AliasId``, and each unique original is packed
    together with the list of fakes aliasing it. Progress is reported through
    tqdm when available. The resulting pack is appended to ``tensor_packs``,
    then every packed fake's ``cuda_aliases`` entry is set to None.

    Returns:
        The pre-packing total computed by ``_total_unpacked_size()`` (bytes).
    """
    # Pack to disk
    unique: dict[AliasId, torch.Tensor] = {}
    fakes_by_original: dict[torch.Tensor, list[torch.Tensor]] = defaultdict(list)
    for fake, backing in cuda_aliases.items():
        # TODO filter-out sparse Tensors
        if backing is None:
            continue
        canonical = unique.setdefault(AliasId.from_tensor(backing), backing)
        fakes_by_original[canonical].append(fake)
    originals: set[torch.Tensor] = set(unique.values())

    total_size = _total_unpacked_size()
    if tqdm is not None:
        progress = tqdm(
            total=total_size,
            unit='B',
            unit_scale=True,
            desc="ZeroGPU tensors packing",
        )
    else:
        progress = nullcontext()
    with progress as bar:
        on_bytes = bar.update if bar is not None else lambda _: None
        new_pack = pack_tensors(originals, fakes_by_original, offload_dir, callback=on_bytes)
    tensor_packs.append(new_pack)

    # Free memory: packed fakes keep their entry, but with no in-memory original.
    for group in fakes_by_original.values():
        for fake in group:
            cuda_aliases[fake] = None

    return total_size
def pack():
    """Offload all in-memory ZeroGPU tensors to ``Config.zerogpu_offload_dir``.

    The steps are:
    1. Recreate the offload directory.
    2. Pack the tensors via ``_pack`` inside ``unmap_capture``.
    3. Delete the files reported by ``unmap_capture``. Presumably these are the
       mmap-backed tensor files unmapped while packing; confirm in ``..mmap``.
       A summary line is printed when any bytes were cleaned.
    4. Run ``gc.collect()`` and ``malloc_trim()``.

    Returns:
        The byte total reported by ``_pack``, or 0 when no tensor is registered.
    """
    if len(cuda_aliases) == 0:
        return 0

    # rmtree errors are ignored, so the directory may survive; exist_ok avoids a
    # spurious FileExistsError in that case.
    shutil.rmtree(Config.zerogpu_offload_dir, ignore_errors=True)
    Path(Config.zerogpu_offload_dir).mkdir(parents=True, exist_ok=True)

    mmap_addrs = [tensor.data_ptr() for tensor in cuda_aliases.values() if tensor is not None]
    with unmap_capture(mmap_addrs) as unmapped_paths:
        total_size = _pack(Config.zerogpu_offload_dir)

    total_cleaned_size = 0
    for path in unmapped_paths:
        # A captured file may already be gone. lstat() runs before
        # unlink(missing_ok=True), so handle the missing case here instead of crashing.
        try:
            total_cleaned_size += path.lstat().st_size
        except FileNotFoundError:
            continue
        path.unlink(missing_ok=True)
    if total_cleaned_size > 0:
        print(
            f"Cleaned {total_cleaned_size / 2**30:.2f}GB of tensor files "
            f"from {Config.zerogpu_mmap_autoprune_pattern} after packing"
        )

    gc.collect()
    malloc_trim()
    return total_size
def init(nvidia_uuid: str):
    """Restrict this process to the GPU ``nvidia_uuid`` and force CUDA initialisation.

    ``CUDA_VISIBLE_DEVICES`` is set first. Moving a one-element tensor to CUDA
    then triggers torch's lazy CUDA setup.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = nvidia_uuid
    warmup = torch.Tensor([0])
    warmup.cuda()
def size():
    """Total bytes managed by ZeroGPU: in-memory originals plus disk-packed tensors."""
    in_memory = _total_unpacked_size()
    on_disk = 0
    for tensor_pack in tensor_packs:
        on_disk += tensor_pack.total_size
    return in_memory + on_disk
def _move(callback: Callable[[int], Any] | None = None):
    """Materialise every ZeroGPU tensor on the real CUDA device.

    In-memory originals are copied to CUDA once per ``AliasId``, and every fake
    aliasing them has its ``.data`` rebound to the CUDA copy. Disk packs are then
    loaded through ``pack_to_cuda``. Finally the ``bitsandbytes`` helper's
    ``move()`` is called.

    Args:
        callback: Progress hook called with a byte count per moved tensor, and
            forwarded to ``pack_to_cuda``; a no-op when None.
    """
    callback = callback if callback is not None else lambda _: None

    # CPU -> CUDA
    # Tensors smaller than PINNED_MEMORY_RATIO_LIMIT of the in-memory total go
    # through pinned memory with a non-blocking copy; larger ones use .cuda().
    pinned_limit = _total_unpacked_size() * PINNED_MEMORY_RATIO_LIMIT
    moved: dict[AliasId, torch.Tensor] = {}
    for fake, original in cuda_aliases.items():
        if original is not None:
            original = torch.Tensor(original) # unwrap subclass
            original_id = AliasId.from_tensor(original)
            # Aliases sharing an AliasId are copied (and reported) only once.
            if original_id not in moved:
                if original.numel() * original.element_size() < pinned_limit:
                    original_cuda = original.pin_memory().cuda(non_blocking=True)
                else:
                    original_cuda = original.cuda()
                moved[original_id] = original_cuda
                callback(fake.numel() * fake.element_size())
    # Wait for the non-blocking copies before any fake is pointed at them.
    torch.cuda.synchronize()
    # NOTE(review): this looks up AliasId of the (possibly subclassed) original, while
    # `moved` was keyed on the unwrapped alias — assumes both yield the same AliasId.
    for fake, original in cuda_aliases.items():
        if original is not None:
            fake.data = moved[AliasId.from_tensor(original)]

    # Disk -> CUDA
    for tensor_pack in tensor_packs:
        pack_to_cuda(tensor_pack, callback=callback)

    bitsandbytes().move()
def move(callback: Callable[[int], Any] | None = None):
    """Move every ZeroGPU tensor to CUDA (see ``_move``), then synchronize.

    The work runs on a single worker thread inside a copy of the caller's
    ``contextvars`` context. Presumably a fresh thread is used so that the torch
    modes entered by ``patch()`` on the calling thread do not apply; confirm.
    Any exception raised by ``_move`` propagates to the caller.

    Args:
        callback: Progress hook receiving byte counts; a no-op when None.
    """
    if callback is None:
        callback = lambda _: None
    with ThreadPoolExecutor(1) as worker:
        job = worker.submit(copy_context().run, _move, callback=callback)
        job.result()
    torch.cuda.synchronize()
def is_in_bad_fork():
    """Return ``torch.cuda._is_in_bad_fork()`` as observed from a freshly forked child process."""
    fork_context = multiprocessing.get_context('fork')
    with ProcessPoolExecutor(mp_context=fork_context) as pool:
        return pool.submit(torch.cuda._is_in_bad_fork).result()
def bitsandbytes():
    """Return the sibling ``bitsandbytes`` module, importing it on first use.

    This file calls ``patch()``, ``unpatch()`` and ``move()`` on the returned
    module. The import is deferred to call time rather than module import time.
    """
    from . import bitsandbytes as bnb_module # Lazy import
    return bnb_module