|
|
import copy |
|
|
import h5py |
|
|
import torch |
|
|
import numpy as np |
|
|
from time import time |
|
|
from typing import List, Tuple, Union, Any |
|
|
|
|
|
import src |
|
|
from src.utils import tensor_idx, is_sorted, indices_to_pointers, \ |
|
|
sizes_to_pointers, fast_repeat, save_tensor, load_tensor |
|
|
|
|
|
|
|
|
# Names exported by `from <module> import *`: the CSR container and its
# batched counterpart, both defined below.
__all__ = ['CSRData', 'CSRBatch'] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CSRData:
    """Implements the CSRData format and associated mechanisms in Torch.

    A CSRData holds a 1D `pointers` tensor and an optional list of
    `values`. Group `i` owns the items `pointers[i]:pointers[i + 1]` of
    every value, so `pointers` has `num_groups + 1` entries, starts at
    0 and ends at the total number of items.

    When defining a subclass A of CSRData, it is recommended to create
    an associated CSRBatch subclass by doing the following:
        - ABatch inherits from (A, CSRBatch)
        - A.get_base_class() returns A
        - A.get_batch_class() returns ABatch
    """

    # Labels used by `save()` / `load()`. When the value keys are None,
    # values are labelled '0', '1', ... (see `value_keys`).
    __value_serialization_keys__ = None
    __pointer_serialization_key__ = 'pointers'
    __is_index_value_serialization_key__ = 'is_index_value'

    def __init__(
            self,
            pointers: torch.Tensor,
            *args,
            dense: bool = False,
            is_index_value: List[bool] = None):
        """Initialize the pointers and values.

        Values are passed as args and stored in a list. They are
        expected to all have the same size and support torch tensor
        indexing (i.e. they can be torch tensor or CSRData objects
        themselves).

        If `dense=True`, pointers are treated as a dense tensor of
        indices to be converted into pointer indices.

        Optionally, a list of booleans `is_index_value` can be passed.
        It must be the same size as *args and indicates, for each value,
        whether it holds elements that should be treated as indices when
        stacking CSRData objects into a CSRBatch. If so, the indices
        will be updated wrt the cumulative size of the batched values.
        """
        if dense:
            # The conversion also returns the permutation that sorts the
            # dense indices; values must follow the same permutation.
            self.pointers, order = indices_to_pointers(pointers)
            args = [a[order] for a in args]
        else:
            self.pointers = pointers
        self.values = [*args] if len(args) > 0 else None

        # Explicit type test rather than `is_index_value == []`, which
        # would depend on tensor-vs-list comparison semantics when a
        # tensor is passed (as `CSRBatch.from_list` / `to_list` do).
        no_flags = is_index_value is None or (
            isinstance(is_index_value, list) and len(is_index_value) == 0)
        if no_flags:
            self.is_index_value = torch.zeros(
                self.num_values, dtype=torch.bool)
        else:
            self.is_index_value = torch.BoolTensor(is_index_value)

        if src.is_debug_enabled():
            self.debug()

    def debug(self):
        """Sanity checks on the pointers, values and serialization keys.

        :raises ValueError: if the serialization keys are inconsistent
        :raises AssertionError: if pointers, values or `is_index_value`
            are malformed
        """
        if self.pointer_key in self.value_keys:
            raise ValueError(
                f"Cannot serialize {self.__class__.__name__} object because "
                f"'{self.pointer_key}' is both in `self.pointer_key` and "
                f"`self.value_keys`.")

        if len(self.value_keys) != self.num_values:
            raise ValueError(
                f"Cannot serialize {self.__class__.__name__} object because "
                f"`self.value_keys` has length {len(self.value_keys)} but "
                f"`self.num_values` is {self.num_values}.")

        assert self.pointers[0] == 0, \
            "The first pointer element must always be 0."
        assert torch.all(self.sizes >= 0), \
            "pointer indices must be increasing."

        if self.values is not None:
            assert isinstance(self.values, list), \
                "Values must be held in a list."
            assert all(len(v) == self.num_items for v in self.values), \
                "All value objects must have the same size."
            assert len(self.values[0]) == self.num_items, \
                "pointers must cover the entire range of values."
            # Nested CSRData values are validated recursively
            for v in self.values:
                if isinstance(v, CSRData):
                    v.debug()

        if self.values is not None and self.is_index_value is not None:
            assert isinstance(self.is_index_value, torch.BoolTensor), \
                "is_index_value must be a torch.BoolTensor."
            assert self.is_index_value.dtype == torch.bool, \
                "is_index_value must be an tensor of booleans."
            assert self.is_index_value.ndim == 1, \
                "is_index_value must be a 1D tensor."
            assert self.is_index_value.shape[0] == self.num_values, \
                "is_index_value size must match the number of value tensors."

    def detach(self) -> 'CSRData':
        """Detach all tensors in the CSRData. Operates in place and
        returns self.
        """
        self.pointers = self.pointers.detach()
        # Item assignment (rather than rebuilding the list) so that the
        # list object held in `self.values` keeps its identity
        for i, v in enumerate(self.values or []):
            self.values[i] = v.detach()
        return self

    def to(self, device, **kwargs) -> 'CSRData':
        """Move the CSRData to the specified device. Operates in place
        and returns self. `is_index_value` is left where it is.
        """
        self.pointers = self.pointers.to(device, **kwargs)
        for i, v in enumerate(self.values or []):
            self.values[i] = v.to(device, **kwargs)
        return self

    def cpu(self, **kwargs) -> 'CSRData':
        """Move the CSRData to the CPU."""
        return self.to('cpu', **kwargs)

    def cuda(self, **kwargs) -> 'CSRData':
        """Move the CSRData to the first available GPU."""
        return self.to('cuda', **kwargs)

    @property
    def device(self) -> torch.device:
        """Device of the pointers tensor."""
        return self.pointers.device

    @property
    def num_groups(self):
        """Number of groups, i.e. number of pointers minus one."""
        return self.pointers.shape[0] - 1

    @property
    def num_values(self):
        """Number of value objects held (0 if `values` is None)."""
        return len(self.values) if self.values is not None else 0

    @property
    def num_items(self):
        """Total number of items, i.e. the last pointer. Returned as a
        0-dim tensor, not as a Python int.
        """
        return self.pointers[-1]

    @property
    def sizes(self) -> torch.Tensor:
        """Returns the size of each group (i.e. the pointer jumps).
        """
        return self.pointers[1:] - self.pointers[:-1]

    @property
    def indices(self) -> torch.Tensor:
        """Returns the dense indices corresponding to the pointers.
        """
        group_ids = torch.arange(self.num_groups, device=self.device)
        return fast_repeat(group_ids, self.sizes)

    @classmethod
    def get_base_class(cls) -> type:
        """Helps `self.from_list()` and `self.to_list()` identify which
        classes to use for batch collation and un-collation.
        """
        return CSRData

    @classmethod
    def get_batch_class(cls) -> type:
        """Helps `self.from_list()` and `self.to_list()` identify which
        classes to use for batch collation and un-collation.
        """
        return CSRBatch

    def clone(self) -> 'CSRData':
        """Shallow copy of self. This may cause issues for certain types
        of downstream operations, but it saves time and memory. In
        practice, it shouldn't be problematic in this project.
        """
        out = copy.copy(self)
        out.pointers = copy.copy(self.pointers)
        # New list object, but the value objects themselves are shared
        out.values = copy.copy(self.values)
        return out

    def reindex_groups(
            self,
            group_indices: torch.Tensor,
            order: torch.Tensor = None,
            num_groups: int = None
    ) -> 'CSRData':
        """Returns a copy of self with modified pointers to account for
        new groups. Affects the num_groups and the order of groups.
        Injects 0-length pointers where need be.

        By default, pointers are implicitly linked to the group indices
        in range(0, self.num_groups).

        Here we provide new group_indices for the existing pointers,
        with group_indices[i] corresponding to the position of existing
        group i in the new tensor. The indices missing from
        group_indices account for empty groups to be injected.

        The num_groups specifies the number of groups in the new tensor.
        If not provided, it is inferred from the size of group_indices.
        """
        if order is None:
            order = torch.argsort(group_indices)
        # `self[order]` is a copy, so the in-place insertion below does
        # not touch self
        return self[order].insert_empty_groups(
            group_indices[order], num_groups=num_groups)

    def insert_empty_groups(
            self,
            group_indices: torch.Tensor,
            num_groups: int = None
    ) -> 'CSRData':
        """Method called when in-place reindexing groups.

        The group_indices are assumed to be sorted and group_indices[i]
        corresponds to the position of existing group i in the new
        tensor. The indices missing from group_indices correspond to
        empty groups to be injected.

        The num_groups specifies the number of groups in the new tensor.
        If not provided, it is inferred from the size of group_indices.
        A `num_groups` smaller than `group_indices.max() + 1` is
        silently raised to that value.

        Modifies `self.pointers` in place and returns self.
        """
        assert self.num_groups == group_indices.shape[0], \
            "New group indices must correspond to the existing number " \
            "of groups"
        assert is_sorted(group_indices), "New group indices must be sorted."

        # Work with Python ints so that the comparison with the
        # user-provided `num_groups` and the tensor construction below
        # do not depend on the device/type of `group_indices`. An empty
        # `group_indices` means there is no existing group to place.
        if group_indices.numel() > 0:
            min_num_groups = int(group_indices.max()) + 1
        else:
            min_num_groups = 0
        if num_groups is not None:
            num_groups = max(min_num_groups, int(num_groups))
        else:
            num_groups = min_num_groups

        device = self.device
        group_indices = group_indices.to(device)
        starts = torch.cat([
            torch.tensor([-1], dtype=torch.long, device=device),
            group_indices])
        ends = torch.cat([
            group_indices,
            torch.tensor([num_groups], dtype=torch.long, device=device)])

        # Pointer k is duplicated once per new group index lying between
        # existing groups k-1 and k: each duplicate is an empty group
        repeats = ends - starts
        self.pointers = self.pointers.repeat_interleave(repeats)

        return self

    @staticmethod
    def index_select_pointers(
            pointers: torch.Tensor,
            indices: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Index selection of pointers.

        Returns a new pointer tensor with updated pointers, along with
        an index tensor to be used to update any values tensor
        associated with the input pointers.

        :param pointers: 1D pointer tensor
        :param indices: 1D tensor of group indices to select. May hold
            duplicates and be in any order
        :return: (pointers_new, val_idx)
        """
        device = pointers.device

        # Nothing selected: a single 0 pointer and no value index
        if indices.numel() == 0:
            return (
                torch.zeros(1, dtype=pointers.dtype, device=device),
                torch.zeros(0, dtype=torch.long, device=device))

        assert indices.max() <= pointers.shape[0] - 2

        # New pointers are the cumulated sizes of the selected groups
        group_starts = pointers[indices]
        pointers_new = torch.cat([
            torch.zeros(1, dtype=pointers.dtype, device=device),
            torch.cumsum(pointers[indices + 1] - group_starts, 0)])

        # For each output item: its rank inside its group, shifted by
        # the start of that group in the input values
        sizes = pointers_new[1:] - pointers_new[:-1]
        val_idx = torch.arange(pointers_new[-1], device=device)
        val_idx -= pointers_new[:-1].repeat_interleave(sizes)
        val_idx += group_starts.repeat_interleave(sizes)

        return pointers_new, val_idx

    def __getitem__(
            self,
            idx: Union[int, List[int], torch.Tensor, np.ndarray]
    ) -> 'CSRData':
        """Indexing CSRData format. Supports Numpy and torch indexing
        mechanisms.

        Return a copy of self with updated pointers and values.
        """
        idx = tensor_idx(idx).to(self.device)

        # Shallow copy: pointers and values are replaced below
        out = self.clone()

        if idx.shape[0] == 0:
            # Empty selection. The pointers are created on the same
            # device as the values they describe
            out.pointers = torch.zeros(
                1, dtype=torch.long, device=self.device)
            val_idx = []
        else:
            out.pointers, val_idx = self.__class__.index_select_pointers(
                self.pointers, idx)

        if self.values is not None:
            out.values = [v[val_idx] for v in self.values]

        if src.is_debug_enabled():
            out.debug()

        return out

    def select(
            self,
            idx: Union[int, List[int], torch.Tensor, np.ndarray],
            *args,
            **kwargs
    ) -> 'CSRData':
        """Returns a new CSRData which indexes `self` using entries
        in `idx`. Supports torch and numpy fancy indexing.

        :parameter
        idx: int or 1D torch.LongTensor or numpy.NDArray
            Cluster indices to select from 'self'. Must NOT contain
            duplicates
        """
        # Extra args/kwargs are accepted for subclasses and ignored here
        return self[idx]

    def __len__(self):
        return self.num_groups

    def __repr__(self):
        info = [
            f"{key}={int(getattr(self, key))}"
            for key in ['num_groups', 'num_items']]
        info.append(f"device={self.device}")
        return f"{self.__class__.__name__}({', '.join(info)})"

    def _report_mismatch(self, what: str) -> bool:
        """Print why `__eq__` failed when debugging is enabled. Always
        returns False so it can be used as `return self._report_...`.
        """
        if src.is_debug_enabled():
            print(f'{self.__class__.__name__}.__eq__: {what} differ')
        return False

    def __eq__(self, other: Any) -> bool:
        """Two CSRData are equal if `other` is an instance of self's
        class and they hold equal pointers, `is_index_value` and values.
        """
        if not isinstance(other, self.__class__):
            return self._report_mismatch('classes')
        if not torch.equal(self.pointers, other.pointers):
            return self._report_mismatch('pointers')
        if not torch.equal(self.is_index_value, other.is_index_value):
            return self._report_mismatch('is_index_value')
        if self.num_values != other.num_values:
            return self._report_mismatch('num_values')

        for v1, v2 in zip(self.values or [], other.values or []):
            # Values may be nested CSRData, which `torch.equal` cannot
            # compare: delegate to their own `__eq__`
            if isinstance(v1, CSRData) or isinstance(v2, CSRData):
                same = bool(v1 == v2)
            else:
                same = torch.equal(v1, v2)
            if not same:
                return self._report_mismatch('values')
        return True

    def __hash__(self) -> int:
        """Hashing for an CSRData.

        NOTE(review): the hash is built from the hashes of the held
        tensor objects, not from their content, so two objects that
        compare equal with `__eq__` will generally hash differently.
        """
        return hash((
            self.__class__.__name__, self.pointers, *(self.values or [])))

    @property
    def pointer_key(self) -> str:
        """Key name for pointers. This will be used as labels for
        serialization.
        """
        return self.__pointer_serialization_key__

    @property
    def value_keys(self) -> List[str]:
        """List of names for each value. These will be used as labels
        for serialization.
        """
        if self.__value_serialization_keys__ is None:
            return [str(i) for i in range(self.num_values)]
        return self.__value_serialization_keys__

    @property
    def is_index_value_key(self) -> str:
        """Key name for is_index_value. This will be used as labels for
        serialization.
        """
        return self.__is_index_value_serialization_key__

    def save(
            self,
            f: 'Union[str, h5py.File, h5py.Group]',
            fp_dtype: torch.dtype = torch.float):
        """Save CSRData to HDF5 file.

        :param f: h5 file path of h5py.File or h5py.Group
        :param fp_dtype: torch dtype
            Data type to which floating point tensors will be cast
            before saving
        :return:
        """
        # A path was given: open it (overwriting) and recurse with the
        # opened file
        if not isinstance(f, (h5py.File, h5py.Group)):
            with h5py.File(f, 'w') as file:
                self.save(file, fp_dtype=fp_dtype)
            return

        save_tensor(self.pointers, f, self.pointer_key, fp_dtype=fp_dtype)

        if self.is_index_value_key is not None:
            save_tensor(
                self.is_index_value, f, self.is_index_value_key,
                fp_dtype=fp_dtype)

        if self.values is None:
            return
        for key, value in zip(self.value_keys, self.values):
            save_tensor(value, f, key, fp_dtype=fp_dtype)

    @classmethod
    def load(
            cls,
            f: 'Union[str, h5py.File, h5py.Group]',
            idx: Union[int, List, np.ndarray, torch.Tensor] = None,
            verbose: bool = False
    ) -> 'CSRData':
        """Load CSRData from an HDF5 file. See `CSRData.save`
        for writing such file. Options allow reading only part of the
        clusters.

        :param f: h5 file path of h5py.File or h5py.Group
        :param idx: int, list, numpy.ndarray, torch.Tensor
            Used to select clusters when reading. Supports fancy
            indexing
        :param verbose: bool
            Whether to print the time spent in each loading step
        """
        # A path was given: open it read-only and recurse
        if not isinstance(f, (h5py.File, h5py.Group)):
            with h5py.File(f, 'r') as file:
                out = cls.load(file, idx=idx, verbose=verbose)
            return out

        start = time()
        idx = tensor_idx(idx)
        if verbose:
            print(f'{cls.__name__}.load tensor_idx : {time() - start:0.5f}s')

        # A file holding '__sizes__' was written by a batch. If the
        # whole content is requested, let the batch class load it
        has_sizes = '__sizes__' in f.keys()
        is_not_batch = cls != cls.get_batch_class()
        has_no_indexing = idx is None or idx.shape[0] == 0
        if has_sizes and is_not_batch and has_no_indexing:
            return cls.get_batch_class().load(f, idx=idx, verbose=verbose)

        # Recover the serialization keys from the class attributes
        pointer_key = cls.__pointer_serialization_key__
        value_keys = cls.__value_serialization_keys__
        value_keys = value_keys if value_keys is not None else []
        is_index_value_key = cls.__is_index_value_serialization_key__
        assert pointer_key in f.keys()
        assert all(k in f.keys() for k in value_keys)
        assert is_index_value_key is None or is_index_value_key in f.keys()

        # Without explicit value keys, values were saved under '0',
        # '1', ...: count how many consecutive such keys exist
        if len(value_keys) == 0:
            num_values = 0
            while str(num_values) in f.keys():
                num_values += 1
            value_keys = [str(i) for i in range(num_values)]

        if has_no_indexing:
            # Read everything
            start = time()
            pointers = load_tensor(f[pointer_key])
            values = [load_tensor(f[k]) for k in value_keys]
            if verbose:
                print(f'{cls.__name__}.load read all : {time() - start:0.5f}s')
        else:
            # Only read the start and end pointers of the selected
            # groups
            start = time()
            ptr_start = load_tensor(f[pointer_key], idx=idx)
            ptr_end = load_tensor(f[pointer_key], idx=idx + 1)
            if verbose:
                print(f'{cls.__name__}.load read ptr : {time() - start:0.5f}s')

            # New pointers are the cumulated sizes of the selected
            # groups
            start = time()
            pointers = torch.cat([
                torch.zeros(1, dtype=ptr_start.dtype),
                torch.cumsum(ptr_end - ptr_start, 0)])
            if verbose:
                print(f'{cls.__name__}.load new pointers : {time() - start:0.5f}s')

            # Indices of the items to read: rank of each item inside
            # its group, shifted by the start of the group in the file
            start = time()
            sizes = pointers[1:] - pointers[:-1]
            val_idx = torch.arange(pointers[-1])
            val_idx -= pointers[:-1].repeat_interleave(sizes)
            val_idx += ptr_start.repeat_interleave(sizes)
            if verbose:
                print(f'{cls.__name__}.load val_idx : {time() - start:0.5f}s')

            # Only read the selected items
            start = time()
            values = [load_tensor(f[k], idx=val_idx) for k in value_keys]
            if verbose:
                print(f'{cls.__name__}.load read values : {time() - start:0.5f}s')

        # Build the object, then restore the saved `is_index_value`
        start = time()
        out = cls(pointers, *values)
        if is_index_value_key is not None:
            out.is_index_value = load_tensor(f[is_index_value_key]).bool()
        if verbose:
            print(f'{cls.__name__}.load init : {time() - start:0.5f}s')
        return out
|
|
|
|
|
|
|
|
class CSRBatch(CSRData):
    """
    Wrapper class of CSRData to build a batch from a list of CSRData
    data and reconstruct it afterward.

    The number of groups of each batched item is kept in `__sizes__`,
    which is None unless the batch was built with `from_list()` or
    loaded from a file holding it.

    When defining a subclass A of CSRData, it is recommended to create
    an associated CSRBatch subclass by doing the following:
        - ABatch inherits from (A, CSRBatch)
        - A.get_base_class() returns A
        - A.get_batch_class() returns ABatch
    """

    def __init__(
            self,
            pointers: torch.Tensor,
            *args,
            dense: bool = False,
            is_index_value: List[bool] = None):
        """Basic constructor for a CSRBatch. Batches are rather
        intended to be built using the from_list() method.
        """
        super().__init__(
            pointers, *args, dense=dense, is_index_value=is_index_value)
        self.__sizes__ = None

    @property
    def batch_pointers(self) -> torch.Tensor:
        """Pointers delimiting the groups of each batch item, or None if
        the batch item sizes are unknown.
        """
        if self.__sizes__ is None:
            return None
        return sizes_to_pointers(self.__sizes__)

    @property
    def batch_items_sizes(self) -> torch.Tensor:
        """Number of groups of each batch item, or None if unknown."""
        return self.__sizes__

    @property
    def num_batch_items(self):
        """Number of batch items (0 if the sizes are unknown)."""
        return len(self.__sizes__) if self.__sizes__ is not None else 0

    def detach(self) -> 'CSRBatch':
        """Detach all tensors in the CSRBatch."""
        super().detach()
        if self.__sizes__ is not None:
            self.__sizes__ = self.__sizes__.detach()
        return self

    def to(self, device, **kwargs) -> 'CSRBatch':
        """Move the CSRBatch to the specified device."""
        out = super().to(device, **kwargs)
        if self.__sizes__ is not None:
            out.__sizes__ = self.__sizes__.to(device, **kwargs)
        else:
            out.__sizes__ = None
        return out

    @classmethod
    def from_list(cls, csr_list: List['CSRData']) -> 'CSRBatch':
        """Collate a non-empty list of CSRData into a batch.

        All items must be instances of the class of the first item,
        live on the same device, hold the same number of values and the
        same `is_index_value`. The returned object is an instance of
        the batch class of the first item.
        """
        assert isinstance(csr_list, list) and len(csr_list) > 0
        assert isinstance(csr_list[0], CSRData), \
            "All provided items must be CSRData objects."
        csr_cls = type(csr_list[0])
        assert all(isinstance(csr, csr_cls) for csr in csr_list), \
            "All provided items must have the same class."
        device = csr_list[0].device
        assert all(csr.device == device for csr in csr_list), \
            "All provided items must be on the same device."
        num_values = csr_list[0].num_values
        assert all(csr.num_values == num_values for csr in csr_list), \
            "All provided items must have the same number of values."
        is_index_value = csr_list[0].is_index_value
        if is_index_value is not None:
            assert all(
                np.array_equal(csr.is_index_value, is_index_value)
                for csr in csr_list), \
                "All provided items must have the same is_index_value."
        else:
            assert all(csr.is_index_value is None for csr in csr_list), \
                "All provided items must have the same is_index_value."
        if src.is_debug_enabled():
            for csr in csr_list:
                csr.debug()

        # Each item's pointers are shifted by the number of items held
        # by the batch items before it
        item_counts = [int(csr.num_items) for csr in csr_list[:-1]]
        pointer_offsets = torch.cumsum(
            torch.tensor([0] + item_counts, dtype=torch.long),
            dim=0).to(device)

        # The leading 0 of every item is dropped and a single 0 is put
        # back at the start of the batch pointers
        pointers = torch.cat((
            torch.tensor([0], dtype=torch.long, device=device),
            *[csr.pointers[1:] + offset
              for csr, offset in zip(csr_list, pointer_offsets)]), dim=0)

        values = []
        for i in range(num_values):
            val_list = [csr.values[i] for csr in csr_list]
            if isinstance(val_list[0], CSRData):
                # Nested CSRData values are batched recursively
                val = val_list[0].from_list(val_list)
            elif is_index_value is not None and is_index_value[i]:
                # "Index" values are shifted so that the indices of an
                # item start right after the largest index of the
                # previous items. Empty items do not shift anything
                index_counts = [
                    int(v.max()) + 1 if v.shape[0] > 0 else 0
                    for v in val_list[:-1]]
                index_offsets = torch.cumsum(
                    torch.tensor([0] + index_counts, dtype=torch.long),
                    dim=0).to(device)
                val = torch.cat([
                    v + o for v, o in zip(val_list, index_offsets)], dim=0)
            else:
                val = torch.cat(val_list, dim=0)
            values.append(val)

        batch = csr_list[0].get_batch_class()(
            pointers, *values, dense=False, is_index_value=is_index_value)

        # Number of groups of each item, needed by `to_list()`. Note
        # this tensor is created on the CPU; `to()` moves it along with
        # the rest of the batch
        batch.__sizes__ = torch.LongTensor(
            [csr.num_groups for csr in csr_list])

        return batch

    def to_list(self) -> List['CSRData']:
        """Split the batch back into a list of base-class objects.

        :raises RuntimeError: if the batch item sizes are unknown
        """
        if self.__sizes__ is None:
            raise RuntimeError(
                'Cannot reconstruct CSRData data list from batch because the '
                'CSRBatch was not created using `CSRBatch.from_list()`.')

        # Boundaries of each batch item, expressed in groups and in
        # items respectively
        group_pointers = self.batch_pointers
        item_pointers = self.pointers[group_pointers]
        num_batch_items = self.num_batch_items

        # Pointers of each item, shifted back to start at 0
        pointers = [
            self.pointers[group_pointers[i]:group_pointers[i + 1] + 1]
            - item_pointers[i]
            for i in range(num_batch_items)]

        values = []
        for i in range(self.num_values):
            batch_value = self.values[i]

            if isinstance(batch_value, CSRData):
                # Nested CSRData values are un-batched recursively
                val = batch_value.to_list()

            elif self.is_index_value[i]:
                # "Index" values: undo the shift applied when batching,
                # i.e. subtract (largest index of the previous items
                # + 1). The largest index is tracked incrementally
                # rather than recomputed over the whole prefix for each
                # item, and stays undefined while the prefix is empty
                val = []
                prefix_max = None
                for j in range(num_batch_items):
                    raw = batch_value[item_pointers[j]:item_pointers[j + 1]]
                    offset = prefix_max + 1 if prefix_max is not None else 0
                    item = raw - offset

                    # Safety net: if un-shifting produced negative
                    # indices, shift the item so its smallest index is 0
                    if item.numel() > 0 and item.min() < 0:
                        item -= item.min()
                    val.append(item)

                    if raw.numel() > 0:
                        raw_max = raw.max()
                        prefix_max = raw_max if prefix_max is None \
                            else torch.max(prefix_max, raw_max)

            else:
                val = [batch_value[item_pointers[j]:item_pointers[j + 1]]
                       for j in range(num_batch_items)]

            values.append(val)

        # Transpose from per-value lists to per-item lists. Without any
        # value, zip(*values) would be empty and no item would be
        # returned, so each item explicitly gets an empty value list
        if values:
            values = [list(x) for x in zip(*values)]
        else:
            values = [[] for _ in range(num_batch_items)]

        base_cls = self.get_base_class()
        return [
            base_cls(p, *v, dense=False, is_index_value=self.is_index_value)
            for p, v in zip(pointers, values)]

    def __repr__(self):
        # `num_items` is a 0-dim tensor: cast it so that it is printed
        # as a plain number, as `CSRData.__repr__` does
        info = [
            f"num_batch_items={self.num_batch_items}",
            f"num_groups={self.num_groups}",
            f"num_items={int(self.num_items)}",
            f"device={self.device}"]
        return f"{self.__class__.__name__}({', '.join(info)})"

    def select(
            self,
            idx: Union[int, List[int], torch.Tensor, np.ndarray],
            *args,
            **kwargs
    ) -> 'CSRData':
        """Indexing CSRBatch format. Supports Numpy and torch indexing
        mechanisms.

        Since indexing breaks batching, this will return a CSRData
        object with updated pointers and values.
        """
        # Indexing returns a shallow copy of self, hence a batch
        out_batch = super().select(idx, *args, **kwargs)

        # Convert to the base class: build a minimal placeholder object
        # (one 0 pointer, empty values), then plug in the selected data
        placeholders = [
            torch.empty(0, dtype=v.dtype) for v in (self.values or [])]
        out = self.get_base_class()(torch.arange(1), *placeholders)
        out.pointers = out_batch.pointers
        out.values = out_batch.values
        out.is_index_value = out_batch.is_index_value

        return out

    def save(
            self,
            f: 'Union[str, h5py.File, h5py.Group]',
            fp_dtype: torch.dtype = torch.float):
        """Save CSRBatch to HDF5 file.

        :param f: h5 file path of h5py.File or h5py.Group
        :param fp_dtype: torch dtype
            Data type to which floating point tensors will be cast
            before saving
        :return:
        """
        # A path was given: open it (overwriting) and recurse with the
        # opened file
        if not isinstance(f, (h5py.File, h5py.Group)):
            with h5py.File(f, 'w') as file:
                self.save(file, fp_dtype=fp_dtype)
            return

        # Pointers, is_index_value and values are saved by the parent
        super().save(f, fp_dtype=fp_dtype)

        # Batch item sizes are only saved when known. Without them, the
        # file is read back as a base-class object by `load()`
        if self.__sizes__ is not None:
            save_tensor(self.__sizes__, f, '__sizes__', fp_dtype=fp_dtype)

    @classmethod
    def load(
            cls,
            f: 'Union[str, h5py.File, h5py.Group]',
            idx: Union[int, List, np.ndarray, torch.Tensor] = None,
            verbose: bool = False
    ) -> Union['CSRBatch', 'CSRData']:
        """Load CSRBatch from an HDF5 file. See `CSRData.save`
        for writing such file. Options allow reading only part of the
        clusters.

        A base-class object is returned instead of a batch when `idx`
        selects part of the data, or when the file holds no
        '__sizes__'.

        :param f: h5 file path of h5py.File or h5py.Group
        :param idx: int, list, numpy.ndarray, torch.Tensor
            Used to select clusters when reading. Supports fancy
            indexing
        :param verbose: bool
        """
        # Indexing breaks batching: delegate to the base class
        idx = tensor_idx(idx)
        if idx is not None and idx.shape[0] != 0:
            return cls.get_base_class().load(f, idx=idx, verbose=verbose)

        # A path was given: open it read-only and recurse
        if not isinstance(f, (h5py.File, h5py.Group)):
            with h5py.File(f, 'r') as file:
                out = cls.load(file, idx=idx, verbose=verbose)
            return out

        # Without batch item sizes, the file cannot be read as a batch
        if '__sizes__' not in f.keys():
            return cls.get_base_class().load(f, idx=idx, verbose=verbose)

        # Read pointers and values with the parent loader, then restore
        # the batch item sizes
        out = super().load(f, idx=idx, verbose=verbose)
        out.__sizes__ = load_tensor(f['__sizes__'])
        return out
|
|
|