|
|
|
|
|
import multiprocessing
|
|
|
import os
|
|
|
import threading
|
|
|
from multiprocessing import reduction
|
|
|
from multiprocessing.util import register_after_fork
|
|
|
from typing import Union
|
|
|
|
|
|
import torch
|
|
|
from torch._namedtensor_internals import check_serializing_named_tensor
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import multiprocessing.resource_sharer
|
|
|
except ImportError:
|
|
|
pass
|
|
|
|
|
|
|
|
|
class StorageWeakRef:
    r"""A weak reference to a Storage.

    The cdata member is a Python number containing the integer representation of
    the Storage pointer.
    """

    __slots__ = ["cdata", "_free_weak_ref"]

    def __init__(self, storage):
        self.cdata = storage._weak_ref()
        # Save a direct reference to _free_weak_ref; presumably so __del__ can
        # still reach it late during interpreter shutdown — confirm upstream.
        self._free_weak_ref = torch.Storage._free_weak_ref

    @classmethod
    def from_weakref(cls, cdata):
        # Alternate constructor: adopt an existing weak-ref pointer directly,
        # without going through a live Storage object.
        instance = cls.__new__(cls)
        instance.cdata = cdata
        instance._free_weak_ref = torch.Storage._free_weak_ref
        return instance

    def expired(self):
        # True once the referenced Storage has been deallocated.
        return torch.Storage._expired(self.cdata)

    def __del__(self):
        self._free_weak_ref(self.cdata)

    def __hash__(self):
        return self.cdata

    def __eq__(self, other):
        if id(self) == id(other):
            return True
        if not isinstance(other, StorageWeakRef):
            # Defer to the other operand instead of raising AttributeError on
            # objects that have no `cdata` member.
            return NotImplemented
        return self.cdata == other.cdata
|
|
|
|
|
|
|
|
|
class SharedCache(dict):
    """Dictionary from multiprocessing handles to StorageWeakRef."""

    def __init__(self) -> None:
        # Capacity threshold; exceeding it on insert triggers a sweep of
        # expired references. Rescaled after every sweep.
        self.limit = 128
        # A lock inherited through fork() may be left permanently held by a
        # thread that does not exist in the child, so install a fresh lock
        # now and again after every fork.
        self._after_fork()
        register_after_fork(self, SharedCache._after_fork)

    def _after_fork(self):
        # Replace (or create) the mutex guarding reads and writes.
        self.lock = threading.Lock()

    def get(self, key):
        """Thread-safe lookup; returns None when the key is absent."""
        with self.lock:
            return super().get(key)

    def __setitem__(self, key, storage_ref):
        """Thread-safe insert; sweeps dead entries once over capacity."""
        with self.lock:
            super().__setitem__(key, storage_ref)
            if len(self) > self.limit:
                self.free_dead_references()

    def free_dead_references(self):
        """Drop entries whose storage has died; rescale the capacity limit."""
        alive = 0
        for handle, ref in list(self.items()):
            if ref.expired():
                del self[handle]
            else:
                alive += 1
        # Keep headroom proportional to the surviving population.
        self.limit = max(128, alive * 2)
|
|
|
|
|
|
|
|
|
|
|
|
# Module-wide cache keyed by sharing handles (fd ids, shm filenames, CUDA IPC
# handle tuples) so that a storage received multiple times is deduplicated.
shared_cache = SharedCache()
|
|
|
|
|
|
|
|
|
def rebuild_event(device, handle):
    """Recreate a CUDA event in this process from its device and IPC handle."""
    event = torch.cuda.Event.from_ipc_handle(device, handle)
    return event
|
|
|
|
|
|
|
|
|
def reduce_event(event):
    """Pickle support for CUDA events: capture the device and IPC handle."""
    return (rebuild_event, (event.device, event.ipc_handle()))
|
|
|
|
|
|
|
|
|
def rebuild_tensor(cls, storage, metadata):
    """Reconstruct a tensor of type ``cls`` viewing ``storage``.

    ``metadata`` is the ``(storage_offset, size, stride, requires_grad)``
    tuple produced by ``reduce_tensor``.
    """
    storage_offset, size, stride, requires_grad = metadata
    tensor = torch._utils._rebuild_tensor(storage, storage_offset, size, stride)
    if cls is not torch.nn.parameter.Parameter:
        tensor.requires_grad = requires_grad
        return tensor
    # Parameters must go back through the Parameter constructor so that
    # requires_grad is recorded the way nn.Parameter expects.
    return torch.nn.parameter.Parameter(tensor, requires_grad=requires_grad)
|
|
|
|
|
|
|
|
|
def rebuild_meta_tensor(
    tensor_cls,
    tensor_size,
    tensor_stride,
    tensor_offset,
    dtype,
    storage_size_bytes,
    requires_grad,
):
    """Recreate a meta-device tensor from its serialized description.

    Meta tensors carry no data, so a fresh meta storage of the recorded
    byte size is allocated and viewed with the recorded geometry.
    """
    raw = torch.UntypedStorage(storage_size_bytes, device="meta")
    typed = torch.TypedStorage(wrap_storage=raw, dtype=dtype, _internal=True)
    t = torch._utils._rebuild_tensor(typed, tensor_offset, tensor_size, tensor_stride)
    if tensor_cls is torch.nn.parameter.Parameter:
        # Parameters are rebuilt through the Parameter constructor so that
        # requires_grad is attached correctly.
        return torch.nn.parameter.Parameter(t, requires_grad=requires_grad)
    t.requires_grad = requires_grad
    return t
|
|
|
|
|
|
|
|
|
def rebuild_cuda_tensor(
    tensor_cls,
    tensor_size,
    tensor_stride,
    tensor_offset,
    storage_cls,
    dtype,
    storage_device,
    storage_handle,
    storage_size_bytes,
    storage_offset_bytes,
    requires_grad,
    ref_counter_handle,
    ref_counter_offset,
    event_handle,
    event_sync_required,
):
    """Reconstruct a CUDA tensor in the receiving process from IPC metadata.

    Counterpart of the CUDA branch of ``reduce_tensor``: opens (or reuses)
    the shared CUDA allocation identified by ``storage_handle`` /
    ``storage_offset_bytes`` and views it with the recorded geometry.
    """
    if storage_handle is None or storage_size_bytes == 0:
        # Empty tensors are shared without any IPC handle; just make a
        # fresh zero-sized storage on the right device.
        storage = storage_cls(0, dtype=dtype, device=storage_device, _internal=True)
    else:
        # The cache is keyed on (handle, offset) so distinct sub-allocations
        # of one cudaMalloc'ed block are tracked separately.
        storage = storage_from_cache(
            storage_cls, (storage_handle, storage_offset_bytes)
        )
        if storage is None:
            torch.cuda._lazy_init()
            # First time we see this allocation in this process: open the
            # IPC handle (with its sender-side ref-counter and sync event).
            storage = storage_cls._new_shared_cuda(
                storage_device,
                storage_handle,
                storage_size_bytes,
                storage_offset_bytes,
                ref_counter_handle,
                ref_counter_offset,
                event_handle,
                event_sync_required,
            )
            shared_cache[(storage_handle, storage_offset_bytes)] = StorageWeakRef(
                storage
            )
        else:
            # We already hold this storage; release the extra sender-side
            # reference that was created for this message.
            storage_cls._release_ipc_counter(
                ref_counter_handle, ref_counter_offset, device=storage_device
            )

    # Normalize to the untyped storage before wrapping with dtype below.
    _storage = (
        storage
        if isinstance(storage, torch.UntypedStorage)
        else storage._untyped_storage
    )

    t = torch._utils._rebuild_tensor(
        torch.storage.TypedStorage(wrap_storage=_storage, dtype=dtype, _internal=True),
        tensor_offset,
        tensor_size,
        tensor_stride,
    )

    if tensor_cls == torch.nn.parameter.Parameter:
        # It is crucial for integer tensors to receive
        # the requires_grad flag through the Parameter constructor.
        t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad)
    else:
        t.requires_grad = requires_grad

    return t
|
|
|
|
|
|
|
|
|
def reduce_tensor(tensor):
    """Pickle support for tensors shared across processes.

    Dispatches on the tensor's kind: nested and sparse tensors are reduced
    component-wise; CUDA tensors are shared via CUDA IPC handles; everything
    else is reduced to its (typed) storage plus view metadata. Returns the
    standard ``(rebuild_fn, args)`` reduction tuple.

    Raises:
        RuntimeError: for non-leaf tensors that require grad, since autograd
            graphs cannot cross process boundaries.
    """
    if tensor.requires_grad and not tensor.is_leaf:
        raise RuntimeError(
            "Cowardly refusing to serialize non-leaf tensor which requires_grad, "
            "since autograd does not support crossing process boundaries. "
            "If you just want to transfer the data, call detach() on the tensor "
            "before serializing (e.g., putting it on the queue)."
        )

    check_serializing_named_tensor(tensor)
    torch.utils.hooks.warn_if_has_hooks(tensor)

    # Imported lazily to avoid a circular import at module load time.
    from torch.nested._internal.nested_tensor import NestedTensor

    # Only C++ strided nested tensors take this path; python-subclass
    # NestedTensor falls through to the generic machinery.
    if tensor.is_nested and not isinstance(tensor, NestedTensor):
        return reduce_nested_tensor(tensor)

    if tensor.layout in {
        torch.sparse_coo,
        torch.sparse_csr,
        torch.sparse_bsr,
        torch.sparse_csc,
        torch.sparse_bsc,
    }:
        return reduce_sparse_tensor(tensor)

    storage = tensor._typed_storage()

    if storage._untyped_storage.device.type == "cuda":
        # Share the underlying CUDA allocation via IPC. The extra handles
        # (ref counter, event) let sender and receiver coordinate lifetime
        # and stream synchronization of the shared block.
        (
            device,
            handle,
            storage_size_bytes,
            storage_offset_bytes,
            ref_counter_handle,
            ref_counter_offset,
            event_handle,
            event_sync_required,
        ) = storage._share_cuda_()
        tensor_offset = tensor.storage_offset()
        # Track the shared storage so repeated sends reuse the same handle.
        shared_cache[handle] = StorageWeakRef(storage)
        return (
            rebuild_cuda_tensor,
            (
                type(tensor),
                tensor.size(),
                tensor.stride(),
                tensor_offset,
                type(storage),
                tensor.dtype,
                device,
                handle,
                storage_size_bytes,
                storage_offset_bytes,
                tensor.requires_grad,
                ref_counter_handle,
                ref_counter_offset,
                event_handle,
                event_sync_required,
            ),
        )
    elif storage._untyped_storage.device.type == "meta":
        # Meta tensors have no data; only the geometry needs to travel.
        return (
            rebuild_meta_tensor,
            (
                type(tensor),
                tensor.size(),
                tensor.stride(),
                tensor.storage_offset(),
                tensor.dtype,
                tensor.untyped_storage().size(),
                tensor.requires_grad,
            ),
        )

    # CPU path: the storage itself is reduced separately (see
    # reduce_storage); here we only record how the tensor views it.
    metadata = (
        tensor.storage_offset(),
        tensor.size(),
        tensor.stride(),
        tensor.requires_grad,
    )
    return (rebuild_tensor, (type(tensor), storage, metadata))
|
|
|
|
|
|
|
|
|
def rebuild_nested_tensor(
|
|
|
rebuild_buffer_func,
|
|
|
rebuild_buffer_args,
|
|
|
rebuild_sizes_func,
|
|
|
rebuild_sizes_args,
|
|
|
rebuild_strides_func,
|
|
|
rebuild_strides_args,
|
|
|
rebuild_offsets_func,
|
|
|
rebuild_offsets_args,
|
|
|
):
|
|
|
buffer = rebuild_buffer_func(*rebuild_buffer_args)
|
|
|
sizes = rebuild_sizes_func(*rebuild_sizes_args)
|
|
|
strides = rebuild_strides_func(*rebuild_strides_args)
|
|
|
offsets = rebuild_offsets_func(*rebuild_offsets_args)
|
|
|
return torch._nested_view_from_buffer_copy(buffer, sizes, strides, offsets)
|
|
|
|
|
|
|
|
|
def reduce_nested_tensor(nt):
    """Pickle support for strided nested tensors.

    Reduces the flat value buffer and the three geometry tensors (sizes,
    strides, storage offsets) individually via ``reduce_tensor``.
    """
    buffer_func, buffer_args = reduce_tensor(nt.values())
    sizes_func, sizes_args = reduce_tensor(nt._nested_tensor_size())
    strides_func, strides_args = reduce_tensor(nt._nested_tensor_strides())
    offsets_func, offsets_args = reduce_tensor(nt._nested_tensor_storage_offsets())

    return (
        rebuild_nested_tensor,
        (
            buffer_func,
            buffer_args,
            sizes_func,
            sizes_args,
            strides_func,
            strides_args,
            offsets_func,
            offsets_args,
        ),
    )
|
|
|
|
|
|
|
|
|
def rebuild_sparse_coo_tensor(
|
|
|
rebuild_indices_func,
|
|
|
rebuild_indices_args,
|
|
|
rebuild_values_func,
|
|
|
rebuild_values_args,
|
|
|
shape,
|
|
|
is_coalesced,
|
|
|
):
|
|
|
indices = rebuild_indices_func(*rebuild_indices_args)
|
|
|
values = rebuild_values_func(*rebuild_values_args)
|
|
|
return torch.sparse_coo_tensor(indices, values, shape, is_coalesced=is_coalesced)
|
|
|
|
|
|
|
|
|
def rebuild_sparse_compressed_tensor(
    rebuild_compressed_indices_func,
    rebuild_compressed_indices_args,
    rebuild_plain_indices_func,
    rebuild_plain_indices_args,
    rebuild_values_func,
    rebuild_values_args,
    shape,
    layout,
):
    """Recreate a CSR/CSC/BSR/BSC tensor from its three reduced components."""
    compressed = rebuild_compressed_indices_func(*rebuild_compressed_indices_args)
    plain = rebuild_plain_indices_func(*rebuild_plain_indices_args)
    values = rebuild_values_func(*rebuild_values_args)
    return torch.sparse_compressed_tensor(
        compressed, plain, values, shape, layout=layout
    )
|
|
|
|
|
|
|
|
|
def reduce_sparse_tensor(sparse):
    """Pickle support for sparse tensors (COO and compressed layouts).

    Each constituent dense tensor is reduced separately via ``reduce_tensor``
    so the usual storage-sharing machinery applies to it.
    """
    if sparse.layout is torch.sparse_coo:
        indices_func, indices_args = reduce_tensor(sparse._indices())
        values_func, values_args = reduce_tensor(sparse._values())
        return (
            rebuild_sparse_coo_tensor,
            (
                indices_func,
                indices_args,
                values_func,
                values_args,
                sparse.shape,
                sparse.is_coalesced(),
            ),
        )

    # Compressed layouts: pick the right index accessors per layout family.
    if sparse.layout in (torch.sparse_csr, torch.sparse_bsr):
        compressed_indices = sparse.crow_indices()
        plain_indices = sparse.col_indices()
    elif sparse.layout in (torch.sparse_csc, torch.sparse_bsc):
        compressed_indices = sparse.ccol_indices()
        plain_indices = sparse.row_indices()
    else:
        raise NotImplementedError(sparse.layout)

    compressed_func, compressed_args = reduce_tensor(compressed_indices)
    plain_func, plain_args = reduce_tensor(plain_indices)
    values_func, values_args = reduce_tensor(sparse.values())
    return (
        rebuild_sparse_compressed_tensor,
        (
            compressed_func,
            compressed_args,
            plain_func,
            plain_args,
            values_func,
            values_args,
            sparse.shape,
            sparse.layout,
        ),
    )
|
|
|
|
|
|
|
|
|
def fd_id(fd):
    """Return an ``(inode, device)`` pair identifying the file behind ``fd``.

    Two descriptors referring to the same underlying file yield equal ids,
    which is how duplicated shared-memory fds are matched in the cache.
    """
    info = os.fstat(fd)
    return (info.st_ino, info.st_dev)
|
|
|
|
|
|
|
|
|
def storage_from_cache(cls, key):
    """Return a live storage previously received under ``key``, or None."""
    ref = shared_cache.get(key)
    if ref is None:
        return None
    # Revive a strong storage from the cached weak pointer.
    return torch.UntypedStorage._new_with_weak_ptr(ref.cdata)
|
|
|
|
|
|
|
|
|
def rebuild_storage_fd(cls, df, size):
    """Reattach to a shared-memory storage sent as a duplicated fd.

    The fd is always closed before returning; the mapping itself stays
    alive through the returned storage object.
    """
    fd = df.detach()
    try:
        cached = storage_from_cache(cls, fd_id(fd))
        if cached is not None:
            return cached
        storage = cls._new_shared_fd_cpu(fd, size)
        shared_cache[fd_id(fd)] = StorageWeakRef(storage)
        return storage
    finally:
        os.close(fd)
|
|
|
|
|
|
|
|
|
def rebuild_storage_filename(cls, manager, handle, size, dtype=None):
    """Reattach to a file-backed shared-memory storage by its handle.

    A non-None ``dtype`` means the sender shared a TypedStorage, so ``size``
    counts elements rather than bytes. Always drops the extra shared
    refcount taken by the sender before returning.
    """
    cached = storage_from_cache(cls, handle)
    if cached is not None:
        return cached._shared_decref()
    if dtype is None:
        storage = torch.UntypedStorage._new_shared_filename_cpu(manager, handle, size)
    else:
        nbytes = size * torch._utils._element_size(dtype)
        base = torch.UntypedStorage._new_shared_filename_cpu(manager, handle, nbytes)
        storage = torch.TypedStorage(wrap_storage=base, dtype=dtype, _internal=True)
    shared_cache[handle] = StorageWeakRef(storage)
    return storage._shared_decref()
|
|
|
|
|
|
|
|
|
def rebuild_storage_empty(cls):
    """Recreate a zero-sized storage of the given class."""
    empty = cls()
    return empty
|
|
|
|
|
|
|
|
|
def rebuild_typed_storage(storage, dtype):
    """Wrap a received untyped storage back into a TypedStorage view."""
    return torch.storage.TypedStorage(
        wrap_storage=storage, dtype=dtype, _internal=True
    )
|
|
|
|
|
|
|
|
|
|
|
|
def reduce_typed_storage(storage):
    """Pickle support for TypedStorage: share the untyped buffer plus dtype."""
    args = (storage._untyped_storage, storage.dtype)
    return (rebuild_typed_storage, args)
|
|
|
|
|
|
|
|
|
def rebuild_typed_storage_child(storage, storage_type):
    """Rewrap an untyped storage into the given typed-storage subclass."""
    rebuilt = storage_type(wrap_storage=storage, _internal=True)
    return rebuilt
|
|
|
|
|
|
|
|
|
|
|
|
def reduce_typed_storage_child(storage):
    """Pickle support for legacy TypedStorage subclasses (e.g. FloatStorage)."""
    args = (storage._untyped_storage, type(storage))
    return (rebuild_typed_storage_child, args)
|
|
|
|
|
|
|
|
|
def reduce_storage(storage):
    """Pickle support for CPU storages, dispatching on the sharing strategy.

    CUDA and meta storages must be pickled as tensors; CPU storages are
    shared either through a shm filename ("file_system" strategy) or a
    duplicated file descriptor (default strategy).
    """
    from . import get_sharing_strategy

    if storage.is_cuda:
        raise RuntimeError(
            "Cannot pickle CUDA storage; try pickling a CUDA tensor instead"
        )
    if storage.device.type == "meta":
        raise RuntimeError(
            "Cannot pickle meta storage; try pickling a meta tensor instead"
        )

    if get_sharing_strategy() == "file_system":
        metadata = storage._share_filename_cpu_()
        cache_key = metadata[1]
        rebuild = rebuild_storage_filename
        if isinstance(storage, torch.TypedStorage):
            # Typed storages also need their dtype on the receiving side.
            metadata += (storage.dtype,)
        storage._shared_incref()
    elif storage.size() == 0:
        # Empty storages carry nothing to share.
        return (rebuild_storage_empty, (type(storage),))
    else:
        fd, size = storage._share_fd_cpu_()
        df = multiprocessing.reduction.DupFd(fd)
        cache_key = fd_id(fd)
        metadata = (df, size)
        rebuild = rebuild_storage_fd

    shared_cache[cache_key] = StorageWeakRef(storage)
    return (rebuild, (type(storage),) + metadata)
|
|
|
|
|
|
|
|
|
def init_reductions():
    """Register all custom reducers with multiprocessing's ForkingPickler."""
    reduction.register(torch.cuda.Event, reduce_event)

    # Storages: the untyped class goes through the raw reducer, the legacy
    # typed subclasses through the child reducer.
    for storage_cls in torch._storage_classes:
        if storage_cls.__name__ == "UntypedStorage":
            reduction.register(storage_cls, reduce_storage)
        else:
            reduction.register(storage_cls, reduce_typed_storage_child)
    reduction.register(torch.storage.TypedStorage, reduce_typed_storage)

    # Tensors (including torch.Tensor itself and nn.Parameter) all share
    # one reducer.
    for tensor_cls in torch._tensor_classes:
        reduction.register(tensor_cls, reduce_tensor)
    reduction.register(torch.Tensor, reduce_tensor)

    from torch.nn.parameter import Parameter

    reduction.register(Parameter, reduce_tensor)
|
|
|
|