import h5py
import torch
import numpy as np
from time import time
from typing import List, Tuple, Union
from torch_geometric.nn.pool.consecutive import consecutive_cluster
from src.data.csr import CSRData, CSRBatch
from src.utils import has_duplicates, tensor_idx, load_tensor
__all__ = ['Cluster', 'ClusterBatch']
class Cluster(CSRData):
    """Child class of CSRData to simplify some common operations
    dedicated to cluster-point indexing.

    For each cluster (i.e. CSR group), the `points` value tensor holds
    the indices of the points (i.e. CSR items) it contains, delimited
    by the CSR `pointers`.
    """

    # Serialization metadata consumed by the CSRData machinery
    __value_serialization_keys__ = ['points']
    __is_index_value_serialization_key__ = None

    def __init__(
            self,
            pointers: torch.Tensor,
            points: torch.Tensor,
            dense: bool = False,
            **kwargs):
        """Build a Cluster holding `points` indices in CSR format.

        :param pointers: torch.Tensor
            CSR pointers delimiting each cluster's span in `points`
        :param points: torch.Tensor
            Point indices, registered as an index value so that CSRData
            machinery treats them as indices (`is_index_value=[True]`)
        :param dense: bool
            Forwarded to `CSRData.__init__`
        """
        super().__init__(
            pointers, points, dense=dense, is_index_value=[True])

    @classmethod
    def get_base_class(cls) -> type:
        """Helps `self.from_list()` and `self.to_list()` identify which
        classes to use for batch collation and un-collation.
        """
        return Cluster

    @classmethod
    def get_batch_class(cls) -> type:
        """Helps `self.from_list()` and `self.to_list()` identify which
        classes to use for batch collation and un-collation.
        """
        return ClusterBatch

    @property
    def points(self) -> torch.Tensor:
        """Point indices, stored as the first (and only) CSR value."""
        return self.values[0]

    @points.setter
    def points(self, points: torch.Tensor):
        # Device mismatch would silently break downstream indexing
        assert points.device == self.device, \
            f"Points is on {points.device} while self is on {self.device}"
        self.values[0] = points

    @property
    def num_clusters(self) -> int:
        """Number of clusters, i.e. number of CSR groups."""
        return self.num_groups

    @property
    def num_points(self) -> int:
        """Number of points, i.e. number of CSR items."""
        return self.num_items

    def to_super_index(self) -> torch.Tensor:
        """Return a 1D tensor of indices converting the CSR-formatted
        clustering structure in 'self' into the 'super_index' format.
        """
        # TODO: this assumes 'self.points' is a permutation, shall we
        #  check this (although it requires sorting) ?
        device = self.device
        out = torch.empty((self.num_items,), dtype=torch.long, device=device)
        cluster_idx = torch.arange(self.num_groups, device=device)
        # Scatter each cluster's index to the positions of its points
        out[self.points] = cluster_idx.repeat_interleave(self.sizes)
        return out

    def select(
            self,
            idx: Union[int, List[int], torch.Tensor, np.ndarray],
            update_sub: bool = True
    ) -> Tuple['Cluster', Tuple[torch.Tensor, torch.Tensor]]:
        """Returns a new Cluster with updated clusters and points, which
        indexes `self` using entries in `idx`. Supports torch and numpy
        fancy indexing. `idx` must NOT contain duplicate entries, as
        this would cause ambiguities in super- and sub- indices.

        NB: if `self` belongs to a NAG, calling this function in
        isolation may break compatibility with point and cluster indices
        in the other hierarchy levels. If consistency matters, prefer
        using NAG indexing instead.

        :param idx: int or 1D torch.LongTensor or numpy.NDArray
            Cluster indices to select from 'self'. Must NOT contain
            duplicates
        :param update_sub: bool
            If True, the point (i.e. subpoint) indices will also be
            updated to maintain dense indices. The output will then
            contain '(idx_sub, sub_super)' which can help apply these
            changes to maintain consistency with lower hierarchy levels
            of a NAG.

        :return: cluster, (idx_sub, sub_super)
            cluster: Cluster
                indexed cluster
            idx_sub: torch.LongTensor
                to be used with 'Data.select()' on the sub-level
            sub_super: torch.LongTensor
                to replace 'Data.super_index' on the sub-level
        """
        # Normal CSRData indexing, creates a new object in memory
        cluster = super().select(idx)

        if not update_sub:
            return cluster, (None, None)

        # Convert subpoint indices, in case some subpoints have
        # disappeared. 'idx_sub' is intended to be used with
        # Data.select() on the level below
        # TODO: IMPORTANT consecutive_cluster is a bottleneck for NAG
        #  and Data indexing, can we do better ?
        new_cluster_points, perm = consecutive_cluster(cluster.points)
        idx_sub = cluster.points[perm]
        cluster.points = new_cluster_points

        # Selecting the subpoints with 'idx_sub' will not be
        # enough to maintain consistency with the current points. We
        # also need to update the sub-level's 'Data.super_index', which
        # can be computed from 'cluster'
        sub_super = cluster.to_super_index()

        return cluster, (idx_sub, sub_super)

    def debug(self):
        """Run CSRData sanity checks, plus Cluster-specific ones:
        point indices must not contain duplicates.
        """
        super().debug()
        assert not has_duplicates(self.points)

    def __repr__(self):
        info = [
            f"{key}={getattr(self, key)}"
            for key in ['num_clusters', 'num_points', 'device']]
        return f"{self.__class__.__name__}({', '.join(info)})"

    @classmethod
    def load(
            cls,
            f: Union[str, h5py.File, h5py.Group],
            idx: Union[int, List, np.ndarray, torch.Tensor, None] = None,
            update_sub: bool = True,
            verbose: bool = False
    ) -> Tuple['Cluster', Tuple[torch.Tensor, torch.Tensor]]:
        """Load Cluster from an HDF5 file. See `Cluster.save` for
        writing such file. Options allow reading only part of the
        clusters.

        This reproduces the behavior of Cluster.select but without
        reading the full pointer data from disk.

        :param f: h5 file path of h5py.File or h5py.Group
        :param idx: int, list, numpy.ndarray, torch.Tensor
            Used to select clusters when reading. Supports fancy
            indexing
        :param update_sub: bool
            If True, the point (i.e. subpoint) indices will also be
            updated to maintain dense indices. The output will then
            contain '(idx_sub, sub_super)' which can help apply these
            changes to maintain consistency with lower hierarchy levels
            of a NAG.
        :param verbose: bool

        :return: cluster, (idx_sub, sub_super)
        """
        # If a path was passed, open the file and recurse with the
        # h5py handle
        if not isinstance(f, (h5py.File, h5py.Group)):
            with h5py.File(f, 'r') as file:
                out = cls.load(
                    file, idx=idx, update_sub=update_sub, verbose=verbose)
            return out

        # CSRData load behavior
        out = super().load(f, idx=idx, verbose=verbose)
        cluster = out[0] if isinstance(out, tuple) else out

        if not update_sub:
            return cluster, (None, None)

        # Convert subpoint indices, in case some subpoints have
        # disappeared. 'idx_sub' is intended to be used with
        # Data.select() on the level below
        # TODO: IMPORTANT consecutive_cluster is a bottleneck for NAG
        #  and Data indexing, can we do better ?
        start = time()
        new_cluster_points, perm = consecutive_cluster(cluster.points)
        idx_sub = cluster.points[perm]
        cluster.points = new_cluster_points
        if verbose:
            print(f'{cls.__name__}.load update_sub : {time() - start:0.5f}s')

        # Selecting the subpoints with 'idx_sub' will not be
        # enough to maintain consistency with the current points. We
        # also need to update the sublevel's 'Data.super_index', which
        # can be computed from 'cluster'
        start = time()
        sub_super = cluster.to_super_index()
        if verbose:
            print(f'{cls.__name__}.load super_index : {time() - start:0.5f}s')

        return cluster, (idx_sub, sub_super)
class ClusterBatch(Cluster, CSRBatch):
    """Wrapper for Cluster batching."""

    @classmethod
    def load(
            cls,
            f: Union[str, h5py.File, h5py.Group],
            idx: Union[int, List, np.ndarray, torch.Tensor, None] = None,
            update_sub: bool = True,
            verbose: bool = False
    ) -> Tuple[
            Union['ClusterBatch', 'Cluster'],
            Tuple[torch.Tensor, torch.Tensor]]:
        """Load ClusterBatch from an HDF5 file. See `Cluster.save` for
        writing such file. Options allow reading only part of the
        clusters.

        This reproduces the behavior of Cluster.select but without
        reading the full pointer data from disk.

        :param f: h5 file path of h5py.File or h5py.Group
        :param idx: int, list, numpy.ndarray, torch.Tensor
            Used to select clusters when reading. Supports fancy
            indexing
        :param update_sub: bool
            If True, the point (i.e. subpoint) indices will also be
            updated to maintain dense indices. The output will then
            contain '(idx_sub, sub_super)' which can help apply these
            changes to maintain consistency with lower hierarchy levels
            of a NAG.
        :param verbose: bool

        :return: cluster, (idx_sub, sub_super)
        """
        # Indexing breaks batching, so we return a base object if
        # indexing is required
        idx = tensor_idx(idx)
        if idx is not None and idx.shape[0] != 0:
            return cls.get_base_class().load(
                f, idx=idx, update_sub=update_sub, verbose=verbose)

        # If a path was passed, open the file and recurse with the
        # h5py handle
        if not isinstance(f, (h5py.File, h5py.Group)):
            with h5py.File(f, 'r') as file:
                out = cls.load(
                    file, idx=idx, update_sub=update_sub, verbose=verbose)
            return out

        # Check if the file actually corresponds to a batch object
        # rather than its corresponding base object
        if '__sizes__' not in f.keys():
            return cls.get_base_class().load(
                f, idx=idx, update_sub=update_sub, verbose=verbose)

        # Cluster load behavior, then restore the batch item sizes,
        # which only exist for batch objects
        out = super().load(f, idx=idx, update_sub=update_sub, verbose=verbose)
        out[0].__sizes__ = load_tensor(f['__sizes__'])
        return out