import collections
import os

import torch
from torch.utils.data import get_worker_info
from torch.utils.data._utils.collate import (
    default_collate_err_msg_format,
    np_str_obj_array_pattern,
)

from lightning_fabric.utilities.seed import pl_worker_init_function

def collate(batch):
    """Difference from PyTorch default_collate: it can stack tensor-like
    objects other than tensors and silently drops ``None`` samples.

    Adapted from PixLoc, Paul-Edouard Sarlin, ETH Zurich
    https://github.com/cvg/pixloc
    Released under the Apache License 2.0
    """
    if not isinstance(batch, list):
        # Nothing to batch: pass single elements through unchanged.
        return batch
    # Drop samples that the dataset returned as None (e.g. failed loads).
    batch = [elem for elem in batch if elem is not None]
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # In a background worker, stack directly into shared memory to
            # avoid an extra copy when the batch is sent to the main process.
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel, device=elem.device)
            out = elem.new(storage).resize_(len(batch), *list(elem.size()))
        return torch.stack(batch, 0, out=out)
    elif (
        elem_type.__module__ == "numpy"
        and elem_type.__name__ != "str_"
        and elem_type.__name__ != "string_"
    ):
        if elem_type.__name__ in ("ndarray", "memmap"):
            # Reject arrays of strings or objects, matching default_collate.
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))
            return collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # numpy scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, (str, bytes)):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        # Collate each key of a dict recursively.
        return {key: collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, "_fields"):  # namedtuple
        return elem_type(*(collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # Check that all sequences in the batch have the same length.
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError("each element in list of batch should be of equal size")
        transposed = zip(*batch)
        return [collate(samples) for samples in transposed]
    else:
        # Fall back to stacking, which handles other tensor-like objects;
        # return the list unchanged if the elements cannot be stacked.
        try:
            return torch.stack(batch, 0)
        except TypeError as e:
            if "expected Tensor as element" in str(e):
                return batch
            else:
                raise e
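
# Illustrative sketch (not part of the original module): `collate` recursively
# stacks dict-of-tensor samples and drops None entries, unlike the stock
# default_collate. The helper name below is hypothetical.
def _collate_demo():
    samples = [
        {"image": torch.zeros(3, 4, 4), "label": 0},
        None,  # e.g. a failed load, silently dropped
        {"image": torch.ones(3, 4, 4), "label": 1},
    ]
    out = collate(samples)
    assert out["image"].shape == (2, 3, 4, 4)
    assert out["label"].tolist() == [0, 1]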

def set_num_threads(nt):
    """Force numpy and other libraries to use a limited number of threads."""
    try:
        import mkl
    except ImportError:
        pass
    else:
        mkl.set_num_threads(nt)
    # Pin PyTorch intra-op parallelism to a single thread; the BLAS/OpenMP
    # limits below are set through environment variables instead.
    torch.set_num_threads(1)
    os.environ["IPC_ENABLE"] = "1"
    for o in [
        "OPENBLAS_NUM_THREADS",
        "NUMEXPR_NUM_THREADS",
        "OMP_NUM_THREADS",
        "MKL_NUM_THREADS",
    ]:
        os.environ[o] = str(nt)
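
# Minimal usage sketch (not original code): call this early in a process,
# before heavy numerical work, since the *_NUM_THREADS variables are read
# when the BLAS/OpenMP runtimes initialize. The helper name is hypothetical.
def _set_num_threads_demo():
    set_num_threads(2)
    assert os.environ["OMP_NUM_THREADS"] == "2"
    assert torch.get_num_threads() == 1  # torch intra-op stays pinned to 1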

def worker_init_fn(i):
    # Seed Python, numpy, and torch consistently in each DataLoader worker.
    info = get_worker_info()
    pl_worker_init_function(info.id)
    # The dataset is expected to expose a `cfg` mapping; limit per-worker
    # threading if `num_threads` is configured.
    num_threads = info.dataset.cfg.get("num_threads")
    if num_threads is not None:
        set_num_threads(num_threads)
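
# Wiring sketch (not part of the original module): how these utilities fit
# together in a DataLoader. The toy dataset below is hypothetical; it only
# needs to expose the `cfg` mapping that worker_init_fn reads from.
def _dataloader_demo():
    class _ToyDataset(torch.utils.data.Dataset):
        cfg = {"num_threads": 1}

        def __len__(self):
            return 16

        def __getitem__(self, idx):
            return {"image": torch.zeros(3, 4, 4), "label": idx}

    loader = torch.utils.data.DataLoader(
        _ToyDataset(),
        batch_size=8,
        num_workers=2,
        collate_fn=collate,
        worker_init_fn=worker_init_fn,
    )
    batch = next(iter(loader))
    assert batch["image"].shape == (8, 3, 4, 4)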