multimodalart's picture
multimodalart HF Staff
Upload 75 files
a34bca4 verified
"""
"""
# pyright: reportPrivateImportUsage=false
import gc
import multiprocessing
import os
import shutil
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor
from concurrent.futures import ThreadPoolExecutor
from contextlib import nullcontext
from contextvars import copy_context
from pathlib import Path
from typing import Any
from typing import Callable
import torch
from torch.overrides import TorchFunctionMode
from torch.overrides import resolve_name
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from torch.utils._python_dispatch import transform_subclass
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map_only
from torch.utils.weak import WeakTensorKeyDictionary
from ...config import Config
from ..mmap import unmap_capture
from ..tqdm import tqdm
from ..utils import malloc_trim
from . import cudart
from .packing import ZeroGPUTensorPack
from .packing import pack_tensors
from .packing import pack_to_cuda
from .static import *
from .utils import empty_like_raw_alloc
from .types import AliasId
# In _move(), an original whose byte size is below this fraction of the total
# in-memory (unpacked) size is copied to CUDA through pinned memory with
# non_blocking=True; larger originals use a plain .cuda() copy.
PINNED_MEMORY_RATIO_LIMIT = 0.1
# Ops that do not return a Tensor, but for which ZeroGPUFunctionMode still
# emulates torch's "all tensors on the same device" check.
OPS_INPUTS_CHECK_NO_RETURN = (
    torch.Tensor.equal,
)
# Ops that return their first input but must still be device-checked.
# ZeroGPUFunctionMode otherwise skips the check when the result is args[0].
OPS_INPUT_CHECK_SELF_RETURN = (
    torch.Tensor.set_, # probably never dispatched
    torch.ops.aten.set_.source_Tensor, # pyright: ignore [reportAttributeAccessIssue]
)
# Formatted with (function name, tensor) when an op touches a fake whose
# original has already been packed to disk (cuda_aliases value is None).
OFFLOADED_ERROR_MESSAGE = "Cannot apply function {} on disk-offloaded Tensor {}"
# Original torch callables captured at import time. The patched helpers call
# some of them directly, and unpatch() restores all of them.
_tensor_make_subclass = torch.Tensor._make_subclass
_asarray = torch.asarray
_device = torch.device
_cuda_init = torch._C._cuda_init
_cuda_exchange_device = torch.cuda._exchange_device
_cuda_available = torch.cuda.is_available
_cuda_device_count = torch.cuda.device_count
_cuda_current_device = torch.cuda.current_device
_cuda_synchronize = torch.cuda.synchronize
_cuda_get_device_capability = torch.cuda.get_device_capability
_cuda_get_device_properties = torch.cuda.get_device_properties
_cuda_get_device_name = torch.cuda.get_device_name
_cuda_memory_stats_as_nested_dict = torch.cuda.memory.memory_stats_as_nested_dict
_cuda_cudart = torch.cuda.cudart
# PyTorch 2.3
# torch.cuda._maybe_exchange_device only exists on some PyTorch versions
# (2.3+ per the note above); None when it is absent.
_cuda_maybe_exchange_device = getattr(torch.cuda, '_maybe_exchange_device', None)
# Fake tensor handed to user code -> its backing "original" tensor, or None once
# _pack() has offloaded that original to disk. Weakly keyed on the fakes.
cuda_aliases: dict[torch.Tensor, torch.Tensor | None] = WeakTensorKeyDictionary() # pyright: ignore [reportAssignmentType]
# Disk packs created by _pack(); loaded onto CUDA by _move() and counted by size().
tensor_packs: list[ZeroGPUTensorPack] = []
class ZeroGPUTensor(torch.Tensor):
    """Bare ``torch.Tensor`` subclass with no behaviour of its own.

    ``ZeroGPUFunctionMode`` temporarily assigns it as the ``__class__`` of the
    tensor it renders when computing ``repr`` of a fake.
    """
def empty_fake(tensor: torch.Tensor):
    """Allocate a placeholder shaped like *tensor* that keeps its tensor subclass.

    The storage comes from ``empty_like_raw_alloc`` (presumably uninitialised).
    If the result's class differs from ``type(tensor)``, it is re-wrapped with the
    original ``torch.Tensor._make_subclass`` captured at import time.
    ``requires_grad`` is copied from *tensor* in both steps.
    """
    target_cls = tensor.__class__
    placeholder = empty_like_raw_alloc(tensor, requires_grad=tensor.requires_grad)
    if placeholder.__class__ == target_cls:
        return placeholder
    # Restore the caller's subclass on top of the raw allocation
    return _tensor_make_subclass(target_cls, placeholder, require_grad=tensor.requires_grad) # pyright: ignore [reportArgumentType]
# Torch 2.5: https://github.com/pytorch/pytorch/issues/144152
def no_int_device(*args, **kwargs):
    """Rewrite bare integer device ordinals into explicit ``'cuda:N'`` strings.

    Only the first positional argument and the ``device`` keyword are
    inspected. Everything else passes through untouched.

    Returns:
        The (possibly rewritten) ``(args, kwargs)`` pair.
    """
    if args and isinstance(args[0], int):
        args = (f'cuda:{args[0]}',) + args[1:]
    device_kwarg = kwargs.get('device')
    if isinstance(device_kwarg, int):
        kwargs['device'] = f'cuda:{device_kwarg}'
    return args, kwargs
class ZeroGPUFunctionMode(TorchFunctionMode):
    """Torch function mode that emulates a single CUDA device on off-GPU data.

    "Fake" tensors handed to user code are the keys of ``cuda_aliases``. Each
    one maps to its "original" (the tensor actually holding the data), or to
    ``None`` once that original has been packed to disk.

    Every intercepted call swaps fakes for their originals, runs the op on the
    originals, and wraps output tensors back into new fakes when the op is
    deemed to produce CUDA tensors. Some calls get dedicated handling first:
    ``Tensor.to``, ``cuda`` and ``cpu``; ``.data`` assignment; ``.device``;
    ``repr``; and ``untyped_storage``.
    """

    def __torch_function__(self, func, types, args=(), kwargs: dict[str, Any] | None = None):
        kwargs = {} if kwargs is None else kwargs
        # _parse_to itself: only normalise bare int device ordinals (see no_int_device)
        if func == torch._C._nn._parse_to: # pyright: ignore [reportAttributeAccessIssue]
            args, kwargs = no_int_device(*args, **kwargs)
            return func(*args, **kwargs)
        # Redispatch: tensor.cuda() -> tensor.to(device='cuda')
        # Only memory_format is forwarded; other kwargs of cuda()/cpu() are not.
        if func == torch.Tensor.cuda or func == torch.Tensor.cpu:
            memory_format = kwargs.get('memory_format')
            return self.__torch_function__(torch.Tensor.to, types, (args[0],), {
                'device': 'cuda' if func == torch.Tensor.cuda else 'cpu',
                **({'memory_format': memory_format} if memory_format is not None else {}),
            })
        # Redispatch: tensor.to('cuda') -> tensor.to(device='cuda')
        # (a .to() call with only `self` positional falls through to the generic path)
        if func == torch.Tensor.to and len(args) > 1:
            parse_to_args, parse_to_kwargs = no_int_device(*args[1:], **kwargs)
            # We are using nn._parse_to utility to parse generic Tensor.to but nn does not accept copy kwarg
            copy_kwarg = {'copy': parse_to_kwargs.pop('copy')} if 'copy' in parse_to_kwargs else {}
            # The third parsed value (presumably non_blocking) is discarded.
            device, dtype, _, memory_format = torch._C._nn._parse_to(*parse_to_args, **parse_to_kwargs) # pyright: ignore [reportAttributeAccessIssue]
            return self.__torch_function__(torch.Tensor.to, types, (args[0],), {
                'device': device,
                'dtype': dtype,
                'memory_format': memory_format,
            } | copy_kwarg)
        # `tensor.data = target`: keep cuda_aliases consistent with the new data.
        # NOTE: `self` is rebound here to the tensor being assigned to, so it no
        # longer refers to the mode for the rest of this branch.
        if func == torch.Tensor.data.__set__: # pyright: ignore [reportAttributeAccessIssue]
            self, target = args
            if target in cuda_aliases:
                # Target is a fake, so `self` becomes a fake too. It is aliased
                # to a new empty_fake wrapper whose .data is the target's original.
                if (target_original := cuda_aliases[target]) is None:
                    raise Exception(OFFLOADED_ERROR_MESSAGE.format(resolve_name(func), target))
                original = empty_fake(self)
                original.data = target_original
                cuda_aliases[self] = original
            elif self in cuda_aliases:
                # Target is a plain tensor: `self` stops being a fake.
                del cuda_aliases[self]
            self.data = target
            return
        # Fakes always report cuda:0 as their device.
        if func == torch.Tensor.device.__get__:
            tensor, = args
            if tensor in cuda_aliases:
                return torch.device('cuda', index=0)
        # repr of a fake renders its original (or a meta-device copy when the
        # original is offloaded). The rendered tensor is temporarily re-classed
        # as ZeroGPUTensor, and its class is restored in `finally`.
        elif func == torch.Tensor.__repr__:
            tensor, = args
            if tensor in cuda_aliases:
                if (original := cuda_aliases[tensor]) is None:
                    original = tensor.to('meta')
                original_class = original.__class__
                original.__class__ = ZeroGPUTensor
                try:
                    return func(original, **kwargs)
                finally:
                    original.__class__ = original_class
        # A fake's storage is its original's storage, tagged with `_zerogpu`
        # (the attribute checked by _untyped_storage_device).
        elif func == torch.Tensor.untyped_storage:
            tensor, = args
            if tensor in cuda_aliases:
                if (original := cuda_aliases[tensor]) is None:
                    raise Exception(OFFLOADED_ERROR_MESSAGE.format(resolve_name(func), tensor))
                res = func(original, **kwargs)
                res._zerogpu = True
                return res
        # Meaning of `cuda`:
        #   True  -> an explicit CUDA device was requested (the op runs on CPU instead)
        #   False -> an explicit non-CUDA device was requested
        #   None  -> no device kwarg; decided from the inputs below
        cuda: bool | None = None
        # Handle device kwarg
        if (device := kwargs.get('device')) is not None:
            device = torch.device(device)
            if device.type == 'cuda':
                kwargs['device'] = torch.device('cpu')
                cuda = True
            else:
                cuda = False
        # Swap fake inputs with original data
        # swapped: original -> the fake it replaced.
        # inputs_are_cuda: gains True per fake input, False per plain tensor input.
        swapped = {}
        inputs_are_cuda = set()
        def swap(tensor: torch.Tensor):
            nonlocal inputs_are_cuda
            if tensor not in cuda_aliases:
                inputs_are_cuda |= {False}
                return tensor
            if (original := cuda_aliases[tensor]) is None:
                raise Exception(OFFLOADED_ERROR_MESSAGE.format(resolve_name(func), tensor))
            swapped[original] = tensor
            inputs_are_cuda |= {True}
            return original
        args_ = tree_map_only(torch.Tensor, swap, args)
        kwargs_ = tree_map_only(torch.Tensor, swap, kwargs)
        # If every tensor input is a fake and no explicit non-CUDA device was
        # requested, treat this as a CUDA op.
        if inputs_are_cuda == {True}:
            if cuda is not False:
                cuda = True
        # Wrapper tensors special case (torchao quickfix)
        # For detach/alias/clone on a traceable wrapper subclass, rebuild the
        # wrapper by applying func to each inner tensor with this mode re-entered.
        if len(args) == 1 and is_traceable_wrapper_subclass(wrapper_tensor := args[0]):
            if func in {
                torch.Tensor.detach,
                torch.ops.aten.alias.default, # pyright: ignore [reportAttributeAccessIssue]
                torch.ops.aten.clone.default, # pyright: ignore [reportAttributeAccessIssue]
            }:
                with self:
                    return transform_subclass(wrapper_tensor, lambda _, t: func(t))
        res = func(*args_, **kwargs_)
        # Re-generate swapped fakes in case of mutation
        # (each input fake gets a new empty_fake of its possibly-changed original)
        for original, fake in swapped.items():
            fake.data = empty_fake(original)
        # Special case for Tensor indexing where only 'self' matters.
        # NOTE: this rebinds `self` to args[0]; the mode instance is not used after this point.
        if func in {
            torch.ops.aten.index.Tensor, # pyright: ignore [reportAttributeAccessIssue]
            torch.Tensor.__getitem__, # PyTorch 2.4+
        }:
            self = args[0]
            cuda = self in cuda_aliases
            inputs_are_cuda = {cuda}
        # Emulate device check
        if isinstance(res, torch.Tensor) or func in OPS_INPUTS_CHECK_NO_RETURN:
            self = None
            if len(args_) >= 1 and isinstance(args_[0], torch.Tensor):
                self = args_[0]
            # Only raise if func does not return its first input (Tensor.copy_)
            if res is not self or func in OPS_INPUT_CHECK_SELF_RETURN:
                if inputs_are_cuda == {True, False}:
                    raise RuntimeError(
                        "Expected all tensors to be on the same device, "
                        "but found at least two devices, cuda:0 (ZeroGPU) and cpu!"
                    )
        # Register output
        # - An output that is the original of a swapped input is returned as the
        #   caller's fake, unless an explicit non-CUDA device was requested.
        # - Otherwise, for non-CUDA ops the output is returned as-is.
        # - For CUDA ops, the output is wrapped in a new fake registered in cuda_aliases.
        def register(tensor: torch.Tensor):
            if tensor in swapped and cuda is not False:
                return swapped[tensor]
            if cuda is not True:
                return tensor
            fake = empty_fake(tensor)
            cuda_aliases[fake] = tensor
            return fake
        return tree_map_only(torch.Tensor, register, res)
# When enabling DispatchMode, some aten ops are dispatched to FunctionMode
# We are using it for aten.alias.default and aten.set_.source_Tensor
class DefaultDispatchMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs: dict[str, Any] | None = None):
return func(*args, **(kwargs or {}))
# Module-wide mode instances: patch() enters both and unpatch() exits them.
# The patched helpers also call function_mode.__torch_function__ directly.
function_mode = ZeroGPUFunctionMode()
dispatch_mode = DefaultDispatchMode()
def _untyped_storage_new_register(*args, **kwargs):
    """Replacement ``UntypedStorage.__new__`` installed by ``patch()``.

    If the ``device`` keyword resolves to a CUDA device, it is removed, so the
    storage is built with ``StorageBase`` defaults (presumably CPU). The new
    storage is then tagged with ``_zerogpu = True`` so that it reports
    ``cuda:0``. Any other ``device`` value is passed through untouched.
    """
    requested = kwargs.get('device')
    wants_cuda = requested is not None and torch.device(requested).type == 'cuda'
    if wants_cuda:
        del kwargs['device']
    storage = torch._C.StorageBase.__new__(*args, **kwargs)
    if wants_cuda:
        storage._zerogpu = True
    return storage
@property
def _untyped_storage_device(self):
    """Device of an ``UntypedStorage``; ZeroGPU-tagged storages claim ``cuda:0``.

    A storage carrying a ``_zerogpu`` attribute reports ``cuda:0``. Every other
    storage defers to the native ``StorageBase.device`` getter.
    """
    if not hasattr(self, '_zerogpu'):
        return torch._C.StorageBase.device.__get__(self) # pyright: ignore [reportAttributeAccessIssue]
    return torch.device('cuda', index=0)
# Force dispatch
def _tensor_make_subclass_function_mode(*args, **kwargs):
    """Patched ``torch.Tensor._make_subclass`` that routes through ``function_mode``.

    The original ``_make_subclass`` (captured at import) is handed explicitly
    to ``ZeroGPUFunctionMode.__torch_function__`` while torch-function
    handling is disabled. This applies the fake/original swapping to this
    entry point, which presumably does not reach the mode on its own.
    """
    no_torch_function = torch._C.DisableTorchFunction()
    with no_torch_function:
        return function_mode.__torch_function__(_tensor_make_subclass, (), args, kwargs)
def _asarray_function_mode(*args, **kwargs):
    """Patched ``torch.asarray`` that forces dispatch through ``function_mode``.

    The original ``torch.asarray`` (captured at import) is passed explicitly
    to ``ZeroGPUFunctionMode.__torch_function__`` with torch-function
    handling disabled for the duration of the call.
    """
    no_torch_function = torch._C.DisableTorchFunction()
    with no_torch_function:
        return function_mode.__torch_function__(_asarray, (), args, kwargs)
class _DeviceStringOnlyMeta(type):
    """Metaclass that delegates ``isinstance`` checks to the original ``torch.device``.

    ``isinstance(x, C)`` for any class ``C`` using this metaclass is answered
    as ``isinstance(x, <torch.device captured at import>)``. This keeps
    ``isinstance(dev, torch.device)`` working while ``torch.device`` is replaced.
    """

    def __instancecheck__(cls, obj):
        original_device_cls = _device
        return isinstance(obj, original_device_cls)
class _DeviceStringOnly(metaclass=_DeviceStringOnlyMeta):
    """Stand-in for ``torch.device`` installed by ``patch()``.

    Calling it rewrites bare integer ordinals to ``'cuda:N'`` via
    ``no_int_device``, then builds and returns a genuine device object from the
    ``torch.device`` class captured at import. Instances of this class itself
    are never produced.
    """

    def __new__(cls, *args, **kwargs):
        fixed_args, fixed_kwargs = no_int_device(*args, **kwargs)
        return _device(*fixed_args, **fixed_kwargs)
def _cuda_init_raise():
    """Stand-in for ``torch._C._cuda_init`` installed by ``patch()``.

    It always raises. Any code path that slips past the emulation and tries to
    initialise the real CUDA runtime therefore fails loudly at its call site.

    Raises:
        RuntimeError: unconditionally.
    """
    message = (
        "Low-level CUDA init (`torch._C._cuda_init`) reached. "
        "This means ZeroGPU's PyTorch CUDA emulation mode "
        "did not intercept a CUDA operation in your code.\n"
        "Check this stacktrace to locate the trigger."
    )
    raise RuntimeError(message)
def _cuda_dummy_exchange_device(device):
    """Stand-in for ``torch.cuda._exchange_device`` / ``_maybe_exchange_device``.

    ZeroGPU emulates exactly one CUDA device, so only ``-1`` and ``0`` are
    accepted. The value is returned unchanged.

    Raises:
        RuntimeError: if *device* is anything other than ``-1`` or ``0``.
    """
    # Explicit check instead of `assert`, so the validation survives `python -O`.
    if device not in {-1, 0}:
        raise RuntimeError(
            f"ZeroGPU emulates a single CUDA device (index 0); got device index {device!r}"
        )
    return device
def patch():
    """Install ZeroGPU's CUDA emulation into torch.

    Enters ``function_mode`` and ``dispatch_mode``. It then replaces torch /
    torch.cuda module attributes, which are process-wide unlike the modes (see
    the TODO below), so that code sees a single CUDA device with index 0.
    Finally it applies the bitsandbytes helper module's ``patch()``.
    ``unpatch()`` restores the originals captured at import time.
    """
    function_mode.__enter__()
    dispatch_mode.__enter__()
    # TODO: only patch below methods on current Thread to be consistent with TorchModes
    # (or hijack threading.Thread.__init__ to force Modes on all threads)
    # Force these construction entry points through function_mode
    torch.Tensor._make_subclass = _tensor_make_subclass_function_mode # pyright: ignore [reportAttributeAccessIssue]
    # Storages requested on CUDA are tagged `_zerogpu` and then report cuda:0
    torch.UntypedStorage.__new__ = _untyped_storage_new_register
    torch.UntypedStorage.device = _untyped_storage_device # pyright: ignore [reportAttributeAccessIssue]
    torch.asarray = _asarray_function_mode
    # torch.device(<int>) becomes torch.device('cuda:<int>'); isinstance checks still work
    torch.device = _DeviceStringOnly
    # Reaching the real CUDA init now means the emulation was bypassed: raise
    torch._C._cuda_init = _cuda_init_raise
    torch.cuda._exchange_device = _cuda_dummy_exchange_device
    # Pretend exactly one device (index 0) is present; synchronize becomes a no-op.
    # CUDA_DEVICE_* / CUDA_MEMORY_STATS_AS_NESTED_DICT are not defined in this
    # view (presumably provided by the `.static` star import).
    torch.cuda.is_available = lambda: True
    torch.cuda.device_count = lambda: 1
    torch.cuda.current_device = lambda: 0
    torch.cuda.synchronize = lambda *args: None
    torch.cuda.get_device_capability = lambda *args, **kwargs: CUDA_DEVICE_CAPABILITY
    torch.cuda.get_device_properties = lambda *args, **kwargs: CUDA_DEVICE_PROPERTIES
    torch.cuda.get_device_name = lambda *args, **kwargs: CUDA_DEVICE_NAME
    torch.cuda.memory.memory_stats_as_nested_dict = lambda *args, **kwargs: CUDA_MEMORY_STATS_AS_NESTED_DICT
    torch.cuda.cudart = lambda: cudart
    # PyTorch 2.3
    if _cuda_maybe_exchange_device is not None: # pragma: no cover
        setattr(torch.cuda, '_maybe_exchange_device', _cuda_dummy_exchange_device)
    bitsandbytes().patch()
def unpatch():
    """Undo ``patch()``.

    Exits both modes, then restores every torch / torch.cuda attribute from the
    originals captured at import time. Finally it calls the bitsandbytes
    helper module's ``unpatch()``.
    """
    try:
        dispatch_mode.__exit__(None, None, None)
        function_mode.__exit__(None, None, None)
    except RuntimeError:
        # NOTE(review): if dispatch_mode.__exit__ raises, function_mode.__exit__
        # is skipped as well.
        pass # patch() and unpatch() called from != threads
    torch.Tensor._make_subclass = _tensor_make_subclass
    torch.UntypedStorage.__new__ = torch._C.StorageBase.__new__
    torch.UntypedStorage.device = torch._C.StorageBase.device # pyright: ignore [reportAttributeAccessIssue]
    torch.asarray = _asarray
    torch.device = _device
    torch._C._cuda_init = _cuda_init
    torch.cuda._exchange_device = _cuda_exchange_device
    torch.cuda.is_available = _cuda_available
    torch.cuda.device_count = _cuda_device_count
    torch.cuda.current_device = _cuda_current_device
    torch.cuda.synchronize = _cuda_synchronize
    torch.cuda.get_device_capability = _cuda_get_device_capability
    torch.cuda.get_device_properties = _cuda_get_device_properties
    torch.cuda.get_device_name = _cuda_get_device_name
    torch.cuda.memory.memory_stats_as_nested_dict = _cuda_memory_stats_as_nested_dict
    torch.cuda.cudart = _cuda_cudart
    # PyTorch 2.3
    if _cuda_maybe_exchange_device is not None: # pragma: no cover
        setattr(torch.cuda, '_maybe_exchange_device', _cuda_exchange_device)
    bitsandbytes().unpatch()
def _total_unpacked_size():
    """Return the byte size of all originals still held in memory.

    Entries already offloaded to disk (value ``None``) are skipped. Originals
    that map to the same ``AliasId`` (presumably views of the same data) are
    counted only once, with the last one seen kept.
    """
    unique: dict[AliasId, torch.Tensor] = {}
    for original in [t for t in cuda_aliases.values() if t is not None]:
        unique[AliasId.from_tensor(original)] = original
    return sum(t.numel() * t.element_size() for t in unique.values())
def _pack(offload_dir: str):
    """Pack every in-memory original into one new disk pack under *offload_dir*.

    Originals sharing an ``AliasId`` are packed once. The first one seen acts
    as the representative, and every fake referring to any of them is grouped
    under it. Once the pack has been appended to ``tensor_packs``, each
    affected fake's ``cuda_aliases`` entry is set to ``None`` to mark it as
    disk-offloaded.

    Args:
        offload_dir: directory the pack is written to.

    Returns:
        The deduplicated byte size of the originals, as computed by
        ``_total_unpacked_size()`` before packing.
    """
    representatives: dict[AliasId, torch.Tensor] = {}
    fakes: dict[torch.Tensor, list[torch.Tensor]] = defaultdict(list)
    for fake, original in cuda_aliases.items():
        # TODO filter-out sparse Tensors
        if original is None:
            continue
        representative = representatives.setdefault(AliasId.from_tensor(original), original)
        fakes[representative].append(fake)
    originals: set[torch.Tensor] = set(representatives.values())
    total_size = _total_unpacked_size()
    if tqdm is None:
        progress_cm = nullcontext()
    else:
        progress_cm = tqdm(
            total=total_size,
            unit='B',
            unit_scale=True,
            desc="ZeroGPU tensors packing",
        )
    with progress_cm as bar:
        on_progress = (lambda _: None) if bar is None else bar.update
        new_pack = pack_tensors(originals, fakes, offload_dir, callback=on_progress)
        tensor_packs.append(new_pack)
    # Free memory
    for fake_group in fakes.values():
        for fake in fake_group:
            cuda_aliases[fake] = None
    return total_size
def pack():
    """Offload every in-memory original tensor into a disk-backed pack.

    Steps:
    1. Recreate ``Config.zerogpu_offload_dir`` from scratch.
    2. Run ``_pack`` inside ``unmap_capture``, which collects the paths of
       memory maps that disappear during packing (presumably mmapped weight
       files, cf. ``Config.zerogpu_mmap_autoprune_pattern``).
    3. Delete those files.
    4. Run ``gc.collect()`` and ``malloc_trim()``.

    Returns:
        The number of bytes packed, or 0 when no tensor is registered.
    """
    if len(cuda_aliases) == 0:
        return 0
    shutil.rmtree(Config.zerogpu_offload_dir, ignore_errors=True)
    Path(Config.zerogpu_offload_dir).mkdir(parents=True)
    mmap_addrs = [tensor.data_ptr() for tensor in cuda_aliases.values() if tensor is not None]
    with unmap_capture(mmap_addrs) as unmapped_paths:
        total_size = _pack(Config.zerogpu_offload_dir)
    total_cleaned_size = 0
    for path in unmapped_paths:
        try:
            file_size = path.lstat().st_size
        except FileNotFoundError:
            # The mapped file may already be gone (e.g. deleted while still
            # mapped). Nothing is left to clean, so keep going instead of
            # aborting after the pack has already been written.
            continue
        total_cleaned_size += file_size
        path.unlink(missing_ok=True)
    if total_cleaned_size > 0:
        print(
            f"Cleaned {total_cleaned_size / 2**30:.2f}GB of tensor files "
            f"from {Config.zerogpu_mmap_autoprune_pattern} after packing"
        )
    gc.collect()
    malloc_trim()
    return total_size
def init(nvidia_uuid: str):
    """Bind this process to one GPU and eagerly initialise CUDA.

    Sets ``CUDA_VISIBLE_DEVICES`` to *nvidia_uuid*, then moves a one-element
    tensor to CUDA so that CUDA initialisation happens now rather than on
    first use.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = nvidia_uuid
    warmup = torch.Tensor([0])
    warmup.cuda()
def size():
    """Total managed bytes: in-memory originals plus all disk packs."""
    packed_bytes = sum(tensor_pack.total_size for tensor_pack in tensor_packs)
    return _total_unpacked_size() + packed_bytes
def _move(callback: Callable[[int], Any] | None = None):
    """Put the data behind every fake onto the real CUDA device.

    1. In-memory originals are copied to CUDA, once per ``AliasId``.
    2. After a synchronize, each fake's ``.data`` is pointed at the CUDA copy
       of its original.
    3. Disk packs are loaded with ``pack_to_cuda``.
    4. The bitsandbytes helper module's ``move()`` runs last.

    Args:
        callback: called with byte counts as progress is made; no-op if None.
    """
    callback = callback if callback is not None else lambda _: None
    # CPU -> CUDA
    # Originals smaller than this many bytes are copied via pinned memory, asynchronously.
    pinned_limit = _total_unpacked_size() * PINNED_MEMORY_RATIO_LIMIT
    moved: dict[AliasId, torch.Tensor] = {}
    for fake, original in cuda_aliases.items():
        if original is not None:
            original = torch.Tensor(original) # unwrap subclass
            original_id = AliasId.from_tensor(original)
            if original_id not in moved:
                if original.numel() * original.element_size() < pinned_limit:
                    original_cuda = original.pin_memory().cuda(non_blocking=True)
                else:
                    original_cuda = original.cuda()
                moved[original_id] = original_cuda
                # Progress is reported once per unique original
                callback(fake.numel() * fake.element_size())
    # Presumably waits for the non_blocking pinned copies before the results are used
    torch.cuda.synchronize()
    for fake, original in cuda_aliases.items():
        if original is not None:
            fake.data = moved[AliasId.from_tensor(original)]
    # Disk -> CUDA
    for tensor_pack in tensor_packs:
        pack_to_cuda(tensor_pack, callback=callback)
    bitsandbytes().move()
def move(callback: Callable[[int], Any] | None = None):
    """Move every registered tensor onto the real GPU, reporting progress.

    ``_move`` runs on a dedicated single-thread executor inside a copy of the
    caller's ``contextvars`` context. This call blocks until it finishes and
    re-raises any exception it raised; a final ``torch.cuda.synchronize()``
    follows. The separate thread presumably keeps the thread-local torch modes
    entered by ``patch()`` away from the real copies (TODO confirm).

    Args:
        callback: receives byte counts as progress; defaults to a no-op.
    """
    def _ignore(_):
        return None

    on_progress = callback if callback is not None else _ignore
    context = copy_context()
    with ThreadPoolExecutor(max_workers=1) as worker:
        worker.submit(context.run, _move, callback=on_progress).result()
    torch.cuda.synchronize()
def is_in_bad_fork():
    """Ask a freshly forked child whether torch considers it a CUDA "bad fork".

    A throw-away worker is forked with the ``fork`` start method and evaluates
    ``torch.cuda._is_in_bad_fork()``. Presumably this is True when CUDA was
    already initialised in this process before forking (TODO confirm against
    torch's fork-safety tracking).

    Returns:
        The child's ``torch.cuda._is_in_bad_fork()`` result.
    """
    # Only one task is ever submitted, so one worker is enough. Without
    # max_workers the pool defaults to os.cpu_count() workers, and with the
    # 'fork' start method recent CPython forks all of them up front (gh-90622).
    fork_context = multiprocessing.get_context('fork')
    with ProcessPoolExecutor(max_workers=1, mp_context=fork_context) as executor:
        future = executor.submit(torch.cuda._is_in_bad_fork)
        return future.result()
def bitsandbytes():
    """Return the sibling ``bitsandbytes`` helper module, imported on first call.

    The import is deferred to call time (lazy import), so importing this
    module alone does not import the helper.
    """
    from . import bitsandbytes as helper_module
    return helper_module