|
|
import h5py |
|
|
import torch |
|
|
import numpy as np |
|
|
from time import time |
|
|
from typing import List, Tuple, Union |
|
|
from torch_geometric.nn.pool.consecutive import consecutive_cluster |
|
|
|
|
|
from src.data.csr import CSRData, CSRBatch |
|
|
from src.utils import has_duplicates, tensor_idx, load_tensor |
|
|
|
|
|
|
|
|
__all__ = ['Cluster', 'ClusterBatch'] |
|
|
|
|
|
|
|
|
class Cluster(CSRData):
    """CSRData specialization providing convenience accessors and
    operations dedicated to cluster-to-point indexing.
    """

    __value_serialization_keys__ = ['points']
    __is_index_value_serialization_key__ = None

    def __init__(
            self,
            pointers: torch.Tensor,
            points: torch.Tensor,
            dense: bool = False,
            **kwargs):
        # Clusters always index points, hence is_index_value=[True]
        super().__init__(
            pointers, points, dense=dense, is_index_value=[True])

    @classmethod
    def get_base_class(cls) -> type:
        """Tell `self.from_list()` and `self.to_list()` which class to
        use when un-collating a batch.
        """
        return Cluster

    @classmethod
    def get_batch_class(cls) -> type:
        """Tell `self.from_list()` and `self.to_list()` which class to
        use when collating a batch.
        """
        return ClusterBatch

    @property
    def points(self) -> torch.Tensor:
        # Point indices live in the single value tensor of the CSR data
        return self.values[0]

    @points.setter
    def points(self, points: torch.Tensor):
        assert points.device == self.device, \
            f"Points is on {points.device} while self is on {self.device}"
        self.values[0] = points

    @property
    def num_clusters(self):
        # One CSR group per cluster
        return self.num_groups

    @property
    def num_points(self):
        # One CSR item per point
        return self.num_items

    def to_super_index(self) -> torch.Tensor:
        """Convert the CSR-formatted clustering held in 'self' into a
        1D 'super_index' tensor mapping each point to its cluster.
        """
        device = self.device
        super_index = torch.empty(
            self.num_items, dtype=torch.long, device=device)
        cluster_ids = torch.arange(self.num_groups, device=device)
        # Scatter each cluster id to the positions of its points
        super_index[self.points] = cluster_ids.repeat_interleave(self.sizes)
        return super_index

    def select(
            self,
            idx: Union[int, List[int], torch.Tensor, np.ndarray],
            update_sub: bool = True
    ) -> Tuple['Cluster', Tuple[torch.Tensor, torch.Tensor]]:
        """Index 'self' with the clusters in `idx` and return the
        resulting new Cluster. Torch and numpy fancy indexing are
        supported, but `idx` must NOT hold duplicate entries, which
        would make super- and sub- indices ambiguous.

        NB: when 'self' is part of a NAG, calling this method on its
        own may break index consistency with the other hierarchy
        levels. Prefer NAG indexing if consistency matters.

        :parameter
        idx: int or 1D torch.LongTensor or numpy.NDArray
            Cluster indices to select from 'self'. Must NOT contain
            duplicates
        update_sub: bool
            If True, the point (i.e. subpoint) indices will also be
            updated to maintain dense indices. The output will then
            contain '(idx_sub, sub_super)' which can help apply these
            changes to maintain consistency with lower hierarchy levels
            of a NAG.

        :return: cluster, (idx_sub, sub_super)
            clusters: Cluster
                indexed cluster
            idx_sub: torch.LongTensor
                to be used with 'Data.select()' on the sub-level
            sub_super: torch.LongTensor
                to replace 'Data.super_index' on the sub-level
        """
        cluster = super().select(idx)

        if not update_sub:
            return cluster, (None, None)

        # Re-densify the point indices and record which original points
        # survive (idx_sub), then rebuild the point-to-cluster mapping
        dense_points, perm = consecutive_cluster(cluster.points)
        idx_sub = cluster.points[perm]
        cluster.points = dense_points
        sub_super = cluster.to_super_index()

        return cluster, (idx_sub, sub_super)

    def debug(self):
        super().debug()
        # A point may belong to at most one cluster
        assert not has_duplicates(self.points)

    def __repr__(self):
        attributes = ', '.join(
            f"{key}={getattr(self, key)}"
            for key in ['num_clusters', 'num_points', 'device'])
        return f"{self.__class__.__name__}({attributes})"

    @classmethod
    def load(
            cls,
            f: Union[str, h5py.File, h5py.Group],
            idx: Union[int, List, np.ndarray, torch.Tensor] = None,
            update_sub: bool = True,
            verbose: bool = False
    ) -> 'Cluster':
        """Read a Cluster written to HDF5 by `Cluster.save`. Options
        allow loading only a subset of the clusters.

        This mirrors Cluster.select, without pulling the full pointer
        data from disk.

        :param f: h5 file path of h5py.File or h5py.Group
        :param idx: int, list, numpy.ndarray, torch.Tensor
            Used to select clusters when reading. Supports fancy
            indexing
        :param update_sub: bool
            If True, the point (i.e. subpoint) indices will also be
            updated to maintain dense indices. The output will then
            contain '(idx_sub, sub_super)' which can help apply these
            changes to maintain consistency with lower hierarchy levels
            of a NAG.
        :param verbose: bool

        :return: cluster, (idx_sub, sub_super)
        """
        # Open the file ourselves when given a path, then recurse
        if not isinstance(f, (h5py.File, h5py.Group)):
            with h5py.File(f, 'r') as h5_file:
                return cls.load(
                    h5_file, idx=idx, update_sub=update_sub, verbose=verbose)

        loaded = super().load(f, idx=idx, verbose=verbose)
        cluster = loaded[0] if isinstance(loaded, tuple) else loaded

        if not update_sub:
            return cluster, (None, None)

        # Densify point indices and keep track of the surviving points,
        # mirroring what `select` does in memory
        start = time()
        dense_points, perm = consecutive_cluster(cluster.points)
        idx_sub = cluster.points[perm]
        cluster.points = dense_points
        if verbose:
            print(f'{cls.__name__}.load update_sub : {time() - start:0.5f}s')

        # Rebuild the point-to-cluster mapping for the sub-level
        start = time()
        sub_super = cluster.to_super_index()
        if verbose:
            print(f'{cls.__name__}.load super_index : {time() - start:0.5f}s')

        return cluster, (idx_sub, sub_super)
|
|
|
|
|
|
|
|
class ClusterBatch(Cluster, CSRBatch):
    """Wrapper for Cluster batching."""

    @classmethod
    def load(
            cls,
            f: Union[str, h5py.File, h5py.Group],
            idx: Union[int, List, np.ndarray, torch.Tensor] = None,
            update_sub: bool = True,
            verbose: bool = False
    ) -> Union['ClusterBatch', 'Cluster']:
        """Read a ClusterBatch written to HDF5 by `Cluster.save`.
        Options allow loading only a subset of the clusters.

        This mirrors Cluster.select, without pulling the full pointer
        data from disk.

        :param f: h5 file path of h5py.File or h5py.Group
        :param idx: int, list, numpy.ndarray, torch.Tensor
            Used to select clusters when reading. Supports fancy
            indexing
        :param update_sub: bool
            If True, the point (i.e. subpoint) indices will also be
            updated to maintain dense indices. The output will then
            contain '(idx_sub, sub_super)' which can help apply these
            changes to maintain consistency with lower hierarchy levels
            of a NAG.
        :param verbose: bool

        :return: cluster, (idx_sub, sub_super)
        """
        # Indexed reads cannot preserve batch structure, so delegate
        # them to the base Cluster loader
        idx = tensor_idx(idx)
        if idx is not None and idx.shape[0] != 0:
            return cls.get_base_class().load(
                f, idx=idx, update_sub=update_sub, verbose=verbose)

        # Open the file ourselves when given a path, then recurse
        if not isinstance(f, (h5py.File, h5py.Group)):
            with h5py.File(f, 'r') as h5_file:
                return cls.load(
                    h5_file, idx=idx, update_sub=update_sub, verbose=verbose)

        # Without batching metadata on disk, fall back to a plain
        # Cluster read
        if '__sizes__' not in f.keys():
            return cls.get_base_class().load(
                f, idx=idx, update_sub=update_sub, verbose=verbose)

        out = super().load(f, idx=idx, update_sub=update_sub, verbose=verbose)
        out[0].__sizes__ = load_tensor(f['__sizes__'])
        return out
|
|
|