|
|
import copy |
|
|
import h5py |
|
|
import torch |
|
|
import warnings |
|
|
import numpy as np |
|
|
from time import time |
|
|
from typing import List, Tuple, Optional, Union, Any |
|
|
from torch_geometric.data import Data as PyGData |
|
|
from torch_geometric.data import Batch as PyGBatch |
|
|
from torch_geometric.nn.pool.consecutive import consecutive_cluster |
|
|
|
|
|
import src |
|
|
from src.data.cluster import CSRData |
|
|
from src.data.cluster import Cluster |
|
|
from src.data.instance import InstanceData |
|
|
from src.metrics import SemanticMetricResults, PanopticMetricResults, \ |
|
|
InstanceMetricResults |
|
|
from src.utils import tensor_idx, is_dense, has_duplicates, \ |
|
|
isolated_nodes, knn_2, save_tensor, load_tensor, save_tensor_dict, \ |
|
|
load_tensor_dict, save_dense_to_csr, load_csr_to_dense, to_trimmed, \ |
|
|
to_float_rgb, to_byte_rgb |
|
|
|
|
|
|
|
|
__all__ = ['Data', 'Batch'] |
|
|
|
|
|
|
|
|
class Data(PyGData): |
|
|
"""Inherit from torch_geometric.Data with extensions tailored to our |
|
|
specific needs. |
|
|
""" |
|
|
|
|
|
_NOT_INDEXABLE = [ |
|
|
'_csr_', '_cluster_', '_instance_data_', 'edge_index', 'edge_attr', |
|
|
'_slice_dict', '_inc_dict', '_num_graphs'] |
|
|
|
|
|
|
|
|
def __init__(self, **kwargs): |
|
|
super().__init__(**kwargs) |
|
|
if src.is_debug_enabled(): |
|
|
self.debug() |
|
|
|
|
|
@property |
|
|
def pos(self): |
|
|
return self['pos'] if 'pos' in self._store else None |
|
|
|
|
|
@property |
|
|
def rgb(self): |
|
|
return self['rgb'] if 'rgb' in self._store else None |
|
|
|
|
|
@property |
|
|
def obj(self) -> InstanceData: |
|
|
"""InstanceData object indicating the instance indices for each |
|
|
node/point/superpoint in the Data. |
|
|
""" |
|
|
return self['obj'] if 'obj' in self._store else None |
|
|
|
|
|
@property |
|
|
def semantic_pred(self): |
|
|
return self['semantic_pred'] if 'semantic_pred' in self._store else None |
|
|
|
|
|
@property |
|
|
def neighbor_index(self): |
|
|
return self['neighbor_index'] if 'neighbor_index' in self._store \ |
|
|
else None |
|
|
|
|
|
@property |
|
|
def sub(self) -> Cluster: |
|
|
"""Cluster object indicating subpoint indices for each point.""" |
|
|
return self['sub'] if 'sub' in self._store else None |
|
|
|
|
|
@property |
|
|
def super_index(self): |
|
|
"""Index of the superpoint each point belongs to.""" |
|
|
return self['super_index'] if 'super_index' in self._store else None |
|
|
|
|
|
@property |
|
|
def v_edge_attr(self): |
|
|
"""Vertical edge features.""" |
|
|
return self['v_edge_attr'] if 'v_edge_attr' in self._store else None |
|
|
|
|
|
def norm_index(self, mode: str = 'graph') -> torch.Tensor: |
|
|
"""Index to be used for LayerNorm. |
|
|
|
|
|
:param mode: str |
|
|
Normalization mode. 'graph' will normalize per graph (i.e. |
|
|
per cloud, i.e. per batch). 'node' will normalize per node |
|
|
(i.e. per point). 'segment' will normalize per segment |
|
|
(i.e. per cluster) |
|
|
""" |
|
|
if getattr(self, 'batch', None) is not None: |
|
|
batch = self.batch |
|
|
else: |
|
|
batch = torch.zeros( |
|
|
self.num_nodes, device=self.device, dtype=torch.long) |
|
|
if self.super_index is not None: |
|
|
super_index = self.super_index |
|
|
else: |
|
|
super_index = torch.zeros( |
|
|
self.num_nodes, device=self.device, dtype=torch.long) |
|
|
if mode == 'graph': |
|
|
return batch |
|
|
elif mode == 'node': |
|
|
return torch.arange(self.num_nodes, device=self.device) |
|
|
elif mode == 'segment': |
|
|
num_batches = batch.max() + 1 |
|
|
return super_index * num_batches + batch |
|
|
else: |
|
|
raise NotImplementedError(f"Unknown mode='{mode}'") |
|
|
|
|
|
    @property
    def is_super(self):
        """Whether the points are superpoints for a denser sub-graph."""
        return self.sub is not None

    @property
    def is_sub(self):
        """Whether the points belong to a coarser super-graph."""
        return self.super_index is not None

    @property
    def has_neighbors(self):
        """Whether the points have neighbors."""
        # NOTE(review): assumes 'neighbor_index' is 2D with one column
        # per neighbor — confirm against the neighbor search producer
        return self.neighbor_index is not None and self.neighbor_index.shape[1] > 0

    @property
    def has_edges(self):
        """Whether the points have edges."""
        return self.edge_index is not None and self.edge_index.shape[1] > 0

    @property
    def has_edge_attr(self):
        """Whether the edges have features in `edge_attr`."""
        return self.edge_attr is not None and self.edge_attr.shape[0] > 0

    @property
    def edge_keys(self) -> List[str]:
        """All keys starting with `edge_`, apart from `edge_index` and
        `edge_attr`.
        """
        return [
            k for k in self.keys
            if k.startswith('edge_') and k not in ['edge_index', 'edge_attr']]

    def raise_if_edge_keys(self):
        """This is a TEMPORARY, HACKY method to be called wherever
        edge_keys may cause an issue.

        :raises NotImplementedError: if any `edge_` key other than
            `edge_index` or `edge_attr` is present
        """
        if len(self.edge_keys) > 0:
            raise NotImplementedError(
                "Edge keys are not fully supported yet, please consider "
                "stacking all your `edge_` attributes in `edge_attr` for the "
                "time being. This error was triggered by the presence of the "
                f"following attributes: {self.edge_keys}")

    @property
    def v_edge_keys(self) -> List[str]:
        """All keys starting with `v_edge_`."""
        return [k for k in self.keys if k.startswith('v_edge_')]

    @property
    def num_edges(self):
        """Overwrite the torch_geometric initial definition, which
        somehow returns incorrect results, like:
            data.num_edges != data.edge_index.shape[1]
        """
        return self.edge_index.shape[1] if self.has_edges else 0

    @property
    def num_points(self):
        """Alias for `num_nodes`."""
        return self.num_nodes

    @property
    def num_super(self):
        """Number of superpoints in the level above (i.e. largest
        `super_index` + 1), or 0 if `self` has no `super_index`.
        """
        return self.super_index.max().item() + 1 if self.is_sub else 0

    @property
    def num_sub(self):
        """Number of subpoints in the level below (i.e. largest index in
        `sub.points` + 1), or 0 if `self` has no `sub`.
        """
        return self.sub.points.max().item() + 1 if self.is_super else 0
|
|
|
|
|
def detach(self) -> 'Data': |
|
|
"""Extend `torch_geometric.Data.detach` to handle Cluster and |
|
|
InstanceData attributes. |
|
|
""" |
|
|
self = super().detach() |
|
|
for k in self.keys: |
|
|
if isinstance(self[k], CSRData): |
|
|
self[k] = self[k].detach() |
|
|
return self |
|
|
|
|
|
def to(self, device, **kwargs) -> 'Data': |
|
|
"""Extend `torch_geometric.Data.to` to handle Cluster and |
|
|
InstanceData attributes. |
|
|
""" |
|
|
self = super().to(device, **kwargs) |
|
|
for k in self.keys: |
|
|
if isinstance(self[k], CSRData): |
|
|
self[k] = self[k].to(device, **kwargs) |
|
|
return self |
|
|
|
|
|
def cpu(self, **kwargs) -> 'Data': |
|
|
"""Move the NAG with all Data in it to CPU.""" |
|
|
return self.to('cpu', **kwargs) |
|
|
|
|
|
def cuda(self, **kwargs) -> 'Data': |
|
|
"""Move the NAG with all Data in it to CUDA.""" |
|
|
return self.to('cuda', **kwargs) |
|
|
|
|
|
@property |
|
|
def device(self) -> torch.device: |
|
|
"""Device of the first-encountered tensor in 'self'.""" |
|
|
for key, item in self: |
|
|
if torch.is_tensor(item): |
|
|
return item.device |
|
|
return torch.tensor([]).device |
|
|
|
|
|
    def debug(self):
        """Sanity checks."""
        # PyG's own validation (index ranges, num_nodes consistency, ...)
        self.validate()

        # Superpoint data must describe its subpoints with a Cluster
        # object and, if labeled, hold 2D label histograms
        if self.is_super:
            assert isinstance(self.sub, Cluster), \
                "Clusters in 'sub' must be expressed using a Cluster object"
            assert self.y is None or self.y.dim() == 2, \
                "Clusters in 'sub' must hold label histograms"

        # Instance annotations must be expressed with an InstanceData
        if self.obj is not None:
            assert isinstance(self.obj, InstanceData), \
                "Instance labels in 'obj' must be expressed using an " \
                "InstanceData object"

        # Non-dense super_index is tolerated but suspicious: warn only
        if self.is_sub:
            if not is_dense(self.super_index):
                print(
                    "WARNING: super_index indices are generally expected to be "
                    "dense (i.e. all indices in [0, super_index.max()] are used),"
                    " which is not the case here. This may be because you are "
                    "creating a Data object after applying a selection of "
                    "points without updating the cluster indices.")

        # Edge indices must point to valid nodes
        if self.has_edges:
            assert self.edge_index.max() < self.num_points
            assert 0 <= self.edge_index.min()
|
|
|
|
|
def __inc__(self, key: str, value: Any, *args, **kwargs) -> Any: |
|
|
"""Extend the PyG.Data.__inc__ behavior on '*index*' and |
|
|
'face' attributes to our 'super_index'. This is needed for |
|
|
maintaining clusters when batching Data objects together. |
|
|
""" |
|
|
if 'super_index' in key: |
|
|
return self.num_super |
|
|
return super().__inc__(key, value, *args, **kwargs) |
|
|
|
|
|
def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any: |
|
|
"""Extend the PyG.Data.__cat_dim__ behavior on '*index*' and |
|
|
'face' attributes to our 'neighbor_index'. This is needed for |
|
|
maintaining neighbors when batching Data objects together. |
|
|
""" |
|
|
return 0 if key == 'neighbor_index' \ |
|
|
else super().__cat_dim__(key, value, *args, **kwargs) |
|
|
|
|
|
|
|
|
|
|
|
    def select(
            self,
            idx: Union[int, List[int], torch.Tensor, np.ndarray],
            update_sub: bool = True,
            update_super: bool = True
    ) -> Tuple['Data', Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, 'Cluster']]:
        """Returns a new Data with updated clusters, which indexes
        `self` using entries in `idx`. Supports torch and numpy fancy
        indexing. `idx` must not contain duplicate entries, as this
        would cause ambiguities in edges and super- and sub- indices.

        This operation breaks neighborhoods, so if 'self.has_neighbors'
        the output Data will not.

        NB: if `self` belongs to a NAG, calling this function in
        isolation may break compatibility with point and cluster indices
        in the other hierarchy levels. If consistency matters, prefer
        using NAG indexing instead.

        :parameter
        idx: int or 1D torch.LongTensor or numpy.NDArray
            Data indices to select from 'self'. Must NOT contain
            duplicates
        update_sub: bool
            If True, the point (i.e. subpoint) indices will also be
            updated to maintain dense indices. The output will then
            contain '(idx_sub, sub_super)' which can help apply these
            changes to maintain consistency with lower hierarchy levels
            of a NAG.
        update_super: bool
            If True, the cluster (i.e. superpoint) indices will also be
            updated to maintain dense indices. The output will then
            contain '(idx_super, super_sub)' which can help apply these
            changes to maintain consistency with higher hierarchy levels
            of a NAG.

        :return: data, (idx_sub, sub_super), (idx_super, super_sub)
            data: Data
                indexed data
            idx_sub: torch.LongTensor
                to be used with 'Data.select()' on the sub-level
            sub_super: torch.LongTensor
                to replace 'Data.super_index' on the sub-level
            idx_super: torch.LongTensor
                to be used with 'Data.select()' on the super-level
            super_sub: Cluster
                to replace 'Data.sub' on the super-level
        """
        device = self.device

        # Convert idx to a 1D tensor index on the right device
        idx = tensor_idx(idx).to(device)

        # Duplicates would make the edge re-indexing below ambiguous
        if src.is_debug_enabled():
            assert not has_duplicates(idx), \
                "Duplicate indices are not supported. This would cause " \
                "ambiguities in edges and super- and sub- indices."

        data = Data()

        if self.has_edges:
            # Build a node re-indexing array: selected nodes receive
            # their new (dense) index, all others receive -1
            reindex = torch.full(
                (self.num_nodes,), -1, dtype=torch.int64, device=device)
            reindex = reindex.scatter_(
                0, idx, torch.arange(idx.shape[0], device=device))
            edge_index = reindex[self.edge_index]

            # Keep only the edges whose both ends were selected
            idx_edge = torch.where((edge_index != -1).all(dim=0))[0]
            data.edge_index = edge_index[:, idx_edge]

        # Select the subpoint clusters, possibly producing indices to
        # propagate the selection to the lower hierarchy level
        out_sub = (None, None)
        if self.is_super:
            data.sub, out_sub = self.sub.select(idx, update_sub=update_sub)

        out_super = (None, None)
        if self.is_sub:
            data.super_index = self.super_index[idx]

        if self.is_sub and update_super:
            # Make the selected super indices dense again and recover,
            # for each new cluster, its former superpoint index
            new_super_index, perm = consecutive_cluster(data.super_index)
            idx_super = data.super_index[perm]
            data.super_index = new_super_index

            # Cluster object to replace 'sub' on the super-level
            super_sub = Cluster(
                data.super_index, torch.arange(idx.shape[0], device=device),
                dense=True)

            out_super = (idx_super, super_sub)

        # Index the remaining attributes. Neighbor attributes are
        # dropped (with a debug-mode warning), structural keys handled
        # above are skipped
        warn_keys = ['neighbor_index', 'neighbor_distance']
        skip_keys = ['edge_index', 'sub', 'super_index'] + warn_keys
        for key, item in self:
            if key in warn_keys and src.is_debug_enabled():
                print(
                    f"WARNING: Data.select does not support '{key}', this "
                    f"attribute will be absent from the output")
            if key in skip_keys:
                continue

            # CSRData-like attributes support fancy indexing directly
            if isinstance(item, CSRData):
                data[key] = item[idx]
                continue

            # NOTE(review): 'item.shape[0]' assumes every remaining
            # attribute has a 'shape' (tensor or ndarray) — a scalar or
            # string attribute would raise here. TODO confirm
            is_tensor = torch.is_tensor(item)
            is_node_size = item.shape[0] == self.num_nodes
            is_edge_size = item.shape[0] == self.num_edges

            # Node-sized tensors are indexed by 'idx', edge-sized ones
            # by 'idx_edge'; anything else is deep-copied as-is
            if is_tensor and is_node_size and key in self.v_edge_keys:
                data[key] = item[idx]
            elif self.has_edges and is_tensor and is_edge_size and \
                    key in ['edge_attr'] + self.edge_keys:
                data[key] = item[idx_edge]
            elif is_tensor and is_node_size:
                data[key] = item[idx]
            else:
                data[key] = copy.deepcopy(item)

        # Make sure the output node count matches the selection, even
        # when no node-sized attribute was carried over
        if data.num_nodes != idx.shape[0]:
            data.num_nodes = idx.shape[0]

        return data, out_sub, out_super
|
|
|
|
|
    def is_isolated(self):
        """Returns a boolean tensor of size self.num_nodes indicating
        which nodes are absent from self.edge_index.

        If self.has_edges is False, an empty edge index is used and all
        nodes are reported as isolated.
        """
        # Fall back to an empty (2, 0) edge index when the graph has no
        # edges, in which case every node is isolated
        edge_index = self.edge_index if self.has_edges \
            else torch.zeros(2, 0, dtype=torch.long, device=self.device)
        return isolated_nodes(edge_index, num_nodes=self.num_nodes)
|
|
|
|
|
def connect_isolated(self, k: int = 1) -> 'Data': |
|
|
"""Search for nodes with no edges in the graph and connect them |
|
|
to their k nearest neighbors. Update self.edge_index and |
|
|
self.edge_attr accordingly. |
|
|
|
|
|
Will raise an error if self has no edges or no pos. |
|
|
|
|
|
Returns self updated with the newly-created edges. |
|
|
""" |
|
|
assert self.pos is not None |
|
|
|
|
|
|
|
|
if not self.has_edges: |
|
|
self.edge_attr = None |
|
|
|
|
|
self.raise_if_edge_keys() |
|
|
|
|
|
|
|
|
is_isolated = self.is_isolated() |
|
|
is_out = torch.where(is_isolated)[0] |
|
|
if not is_isolated.any(): |
|
|
return self |
|
|
|
|
|
|
|
|
|
|
|
high = self.pos.max(dim=0).values |
|
|
low = self.pos.min(dim=0).values |
|
|
r_max = (high - low).norm() |
|
|
neighbors, distances = knn_2( |
|
|
self.pos, |
|
|
self.pos[is_out], |
|
|
k + 1, |
|
|
r_max=r_max, |
|
|
batch_search=self.batch, |
|
|
batch_query=self.batch[is_out] if self.batch is not None else None) |
|
|
distances = distances[:, 1:] |
|
|
neighbors = neighbors[:, 1:] |
|
|
|
|
|
|
|
|
source = is_out.repeat_interleave(k) |
|
|
target = neighbors.flatten() |
|
|
edge_index_new = torch.vstack((source, target)) |
|
|
edge_index_old = self.edge_index |
|
|
self.edge_index = torch.cat((edge_index_old, edge_index_new), dim=1) |
|
|
|
|
|
|
|
|
if self.edge_attr is None: |
|
|
return self |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
w = self.edge_attr |
|
|
s = edge_index_old[0] |
|
|
t = edge_index_old[1] |
|
|
d = (self.pos[s] - self.pos[t]).norm(dim=1) |
|
|
d_1 = torch.vstack((d, torch.ones_like(d))).T |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
a, b = torch.linalg.lstsq(d_1, w).solution |
|
|
except: |
|
|
if src.is_debug_enabled(): |
|
|
print( |
|
|
'\nWarning: torch.linalg.lstsq failed, trying again ' |
|
|
'on CPU') |
|
|
a, b = torch.linalg.lstsq(d_1.cpu(), w.cpu()).solution |
|
|
a = a.to(self.device) |
|
|
b = b.to(self.device) |
|
|
|
|
|
|
|
|
edge_attr_new = distances.flatten() * a + b |
|
|
|
|
|
|
|
|
self.edge_attr = torch.cat((self.edge_attr, edge_attr_new)) |
|
|
|
|
|
return self |
|
|
|
|
|
def to_trimmed(self, reduce: str = 'mean') -> 'Data': |
|
|
"""Convert to 'trimmed' graph: same as coalescing with the |
|
|
additional constraint that (i, j) and (j, i) edges are duplicates. |
|
|
|
|
|
If edge attributes are passed, 'reduce' will indicate how to fuse |
|
|
duplicate edges' attributes. |
|
|
|
|
|
NB: returned edges are expressed with i<j by default. |
|
|
""" |
|
|
assert self.has_edges |
|
|
|
|
|
self.raise_if_edge_keys() |
|
|
|
|
|
if self.edge_attr is not None: |
|
|
edge_index, edge_attr = to_trimmed( |
|
|
self.edge_index, edge_attr=self.edge_attr, reduce=reduce) |
|
|
else: |
|
|
edge_index = to_trimmed(self.edge_index) |
|
|
edge_attr = None |
|
|
|
|
|
self.edge_index = edge_index |
|
|
self.edge_attr = edge_attr |
|
|
|
|
|
return self |
|
|
|
|
|
def __eq__(self, other: Any) -> bool: |
|
|
if not isinstance(other, self.__class__): |
|
|
if src.is_debug_enabled(): |
|
|
print(f'{self.__class__.__name__}.__eq__: classes differ') |
|
|
return False |
|
|
if sorted(self.keys) != sorted(other.keys): |
|
|
if src.is_debug_enabled(): |
|
|
print(f'{self.__class__.__name__}.__eq__: keys differ') |
|
|
return False |
|
|
for k, v in self.items(): |
|
|
if isinstance(v, torch.Tensor): |
|
|
if not torch.equal(v, other[k]): |
|
|
if src.is_debug_enabled(): |
|
|
print(f'{self.__class__.__name__}.__eq__: {k} differ') |
|
|
return False |
|
|
continue |
|
|
if isinstance(v, np.ndarray): |
|
|
if not np.array_equal(v, other[k]): |
|
|
if src.is_debug_enabled(): |
|
|
print(f'{self.__class__.__name__}.__eq__: {k} differ') |
|
|
return False |
|
|
continue |
|
|
if v != other[k]: |
|
|
if src.is_debug_enabled(): |
|
|
print(f'{self.__class__.__name__}.__eq__: {k} differ') |
|
|
return False |
|
|
return True |
|
|
|
|
|
    def save(
            self,
            f: Union[str, h5py.File, h5py.Group],
            y_to_csr: bool = True,
            pos_dtype: torch.dtype = torch.float,
            fp_dtype: torch.dtype = torch.float):
        """Save Data to HDF5 file.

        :param f: h5 file path of h5py.File or h5py.Group
        :param y_to_csr: bool
            Convert 'y' to CSR format before saving. Only applies if
            'y' is a 2D histogram
        :param pos_dtype: torch dtype
            Data type to which 'pos' should be cast before saving. The
            reason for this separate treatment of 'pos' is that global
            coordinates may be too large and casting to 'fp_dtype' may
            result in hurtful precision loss
        :param fp_dtype: torch dtype
            Data type to which floating point tensors should be cast
            before saving
        :return:
        """
        # If a path was passed, open the file and recurse with the
        # h5py.File handle
        if not isinstance(f, (h5py.File, h5py.Group)):
            with h5py.File(f, 'w') as file:
                self.save(
                    file,
                    y_to_csr=y_to_csr,
                    pos_dtype=pos_dtype,
                    fp_dtype=fp_dtype)
            return

        assert isinstance(f, (h5py.File, h5py.Group))

        for k, val in self.items():
            if k == 'pos_offset':
                # Offsets carry global coordinates: keep full precision
                save_tensor(val, f, k, fp_dtype=torch.double)
            elif k == 'pos':
                save_tensor(val, f, k, fp_dtype=pos_dtype)
            elif k == 'y' and val.dim() > 1 and y_to_csr:
                # 2D label histograms are stored in CSR form under '_csr_'
                sg = f.create_group(f"{f.name}/_csr_/{k}")
                save_dense_to_csr(val, sg, fp_dtype=fp_dtype)
            elif k in ['rgb', 'mean_rgb']:
                # Colors are stored as bytes; float colors are assumed
                # to live in [0, 1] and are rescaled to [0, 255]
                if val.is_floating_point():
                    save_tensor((val * 255).byte(), f, k, fp_dtype=fp_dtype)
                else:
                    save_tensor(val.byte(), f, k, fp_dtype=fp_dtype)
            elif isinstance(val, Cluster):
                sg = f.create_group(f"{f.name}/_cluster_/{k}")
                val.save(sg, fp_dtype=fp_dtype)
            elif isinstance(val, InstanceData):
                sg = f.create_group(f"{f.name}/_instance_data_/{k}")
                val.save(sg, fp_dtype=fp_dtype)
            elif isinstance(val, CSRData):
                sg = f.create_group(f"{f.name}/_csr_/{k}")
                val.save(sg, fp_dtype=fp_dtype)
            elif isinstance(val, torch.Tensor):
                save_tensor(val, f, k, fp_dtype=fp_dtype)
            else:
                raise NotImplementedError(
                    f"Cannot save attribute {k} with unsupported type "
                    f"{type(val)}")

        # Record which keys are not node attributes, so load() knows
        # which datasets must not be sliced by a node index
        f['_not_indexable_'] = list(set(self.keys) - set(self.node_attrs()))
|
|
|
|
|
@classmethod |
|
|
def load( |
|
|
cls, |
|
|
f: Union[h5py.File, h5py.Group], |
|
|
idx: Union[int, List, np.ndarray, torch.Tensor] = None, |
|
|
keys_idx: List[str] = None, |
|
|
keys: List[str] = None, |
|
|
update_sub: bool = True, |
|
|
verbose: bool = False, |
|
|
rgb_to_float: bool = False |
|
|
) -> 'Data': |
|
|
"""Read an HDF5 file and return its content as a Data object. |
|
|
|
|
|
NB: if relevant, a Batch object will be returned. |
|
|
|
|
|
:param f: h5 file path of h5py.File or h5py.Group |
|
|
:param idx: int, list, numpy.ndarray, torch.Tensor |
|
|
Used to select the elements in `keys_idx`. Supports fancy |
|
|
indexing |
|
|
:param keys_idx: List(str) |
|
|
Keys on which the indexing should be applied |
|
|
:param keys: List(str) |
|
|
Keys should be loaded from the file, ignoring the rest |
|
|
:param update_sub: bool |
|
|
If True, the point (i.e. subpoint) indices will also be |
|
|
updated to maintain dense indices. The output will then |
|
|
contain '(idx_sub, sub_super)' which can help apply these |
|
|
changes to maintain consistency with lower hierarchy levels |
|
|
of a NAG. |
|
|
:param verbose: bool |
|
|
:param rgb_to_float: bool |
|
|
If True and an integer 'rgb' or 'mean_rgb' attribute is |
|
|
loaded, it will be cast to float |
|
|
:return: |
|
|
""" |
|
|
if not isinstance(f, (h5py.File, h5py.Group)): |
|
|
with h5py.File(f, 'r') as file: |
|
|
out = cls.load( |
|
|
file, idx=idx, keys_idx=keys_idx, keys=keys, |
|
|
update_sub=update_sub, verbose=verbose, |
|
|
rgb_to_float=rgb_to_float) |
|
|
return out |
|
|
|
|
|
|
|
|
_not_indexable = cls._NOT_INDEXABLE |
|
|
if '_not_indexable_' in f.keys(): |
|
|
_not_indexable += [s.decode("utf-8") for s in f['_not_indexable_']] |
|
|
|
|
|
idx = tensor_idx(idx) |
|
|
if idx.shape[0] == 0: |
|
|
keys_idx = [] |
|
|
elif keys_idx is None: |
|
|
keys_idx = list(set(f.keys()) - set(_not_indexable)) |
|
|
|
|
|
if keys is None: |
|
|
all_keys = list(f.keys()) |
|
|
for k in ['_csr_', '_cluster_', '_instance_data_']: |
|
|
if k in all_keys: |
|
|
all_keys.remove(k) |
|
|
all_keys += list(f[k].keys()) |
|
|
keys = all_keys |
|
|
|
|
|
d_dict = {} |
|
|
csr_keys = [] |
|
|
cluster_keys = [] |
|
|
instance_data_keys = [] |
|
|
|
|
|
|
|
|
|
|
|
has_slice_and_inc = '_slice_dict' in f.keys() and '_inc_dict' in f.keys() |
|
|
has_batch = 'batch' in f.keys() |
|
|
has_no_indexing = idx is None or idx.shape[0] == 0 |
|
|
if (has_slice_and_inc or has_batch) and has_no_indexing: |
|
|
cls = Batch |
|
|
else: |
|
|
cls = Data |
|
|
|
|
|
|
|
|
for k in f.keys(): |
|
|
start = time() |
|
|
if k == '_not_indexable_': |
|
|
continue |
|
|
if k == '_csr_': |
|
|
csr_keys = list(f[k].keys()) |
|
|
continue |
|
|
if k == '_cluster_': |
|
|
cluster_keys = list(f[k].keys()) |
|
|
continue |
|
|
if k == '_instance_data_': |
|
|
instance_data_keys = list(f[k].keys()) |
|
|
continue |
|
|
if k in ['_slice_dict', '_inc_dict']: |
|
|
if cls == Batch: |
|
|
d_dict[k] = load_tensor_dict(f[k]) |
|
|
continue |
|
|
if k == '_num_graphs': |
|
|
if cls == Batch: |
|
|
d_dict[k] = f['_num_graphs'][0] |
|
|
continue |
|
|
if k in keys_idx: |
|
|
d_dict[k] = load_tensor(f[k], idx=idx) |
|
|
elif k in keys: |
|
|
d_dict[k] = load_tensor(f[k]) |
|
|
if verbose and k in d_dict.keys(): |
|
|
print(f'{cls.__name__}.load {k:<22}: {time() - start:0.5f}s') |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if '_slice_dict' in d_dict.keys(): |
|
|
for k in set(d_dict['_slice_dict']) - set(d_dict['_inc_dict']): |
|
|
d_dict['_inc_dict'][k] = None |
|
|
|
|
|
|
|
|
|
|
|
if idx.shape[0] != 0: |
|
|
keys_idx = list(set(keys_idx).union(set(csr_keys))) |
|
|
keys_idx = list(set(keys_idx).union(set(cluster_keys))) |
|
|
keys_idx = list(set(keys_idx).union(set(instance_data_keys))) |
|
|
|
|
|
|
|
|
for k in csr_keys: |
|
|
start = time() |
|
|
if k in keys_idx: |
|
|
d_dict[k] = load_csr_to_dense( |
|
|
f['_csr_'][k], idx=idx, verbose=verbose) |
|
|
elif k in keys: |
|
|
d_dict[k] = load_csr_to_dense(f['_csr_'][k], verbose=verbose) |
|
|
if verbose and k in d_dict.keys(): |
|
|
print(f'{cls.__name__}.load {k:<22}: {time() - start:0.5f}s') |
|
|
|
|
|
|
|
|
for k in cluster_keys: |
|
|
start = time() |
|
|
if k in keys_idx: |
|
|
d_dict[k] = Cluster.load( |
|
|
f['_cluster_'][k], idx=idx, update_sub=update_sub, |
|
|
verbose=verbose)[0] |
|
|
elif k in keys: |
|
|
d_dict[k] = Cluster.load( |
|
|
f['_cluster_'][k], update_sub=update_sub, |
|
|
verbose=verbose)[0] |
|
|
if verbose and k in d_dict.keys(): |
|
|
print(f'{cls.__name__}.load {k:<22}: {time() - start:0.5f}s') |
|
|
|
|
|
|
|
|
for k in instance_data_keys: |
|
|
start = time() |
|
|
if k in keys_idx: |
|
|
d_dict[k] = InstanceData.load( |
|
|
f['_instance_data_'][k], idx=idx, verbose=verbose) |
|
|
elif k in keys: |
|
|
d_dict[k] = InstanceData.load( |
|
|
f['_instance_data_'][k], verbose=verbose) |
|
|
if verbose and k in d_dict.keys(): |
|
|
print(f'{cls.__name__}.load {k:<22}: {time() - start:0.5f}s') |
|
|
|
|
|
|
|
|
|
|
|
for k in ['rgb', 'mean_rgb']: |
|
|
if k in d_dict.keys(): |
|
|
d_dict[k] = to_float_rgb(d_dict[k]) if rgb_to_float \ |
|
|
else to_byte_rgb(d_dict[k]) |
|
|
|
|
|
return cls(**d_dict) |
|
|
|
|
|
    def estimate_instance_centroid(
            self,
            mode: str = 'iou'
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Estimate the centroid position of each target instance
        object, based on the position of the clusters.

        Based on the hypothesis that clusters are relatively
        instance-pure, we approximate the centroid of each object by
        taking the barycenter of the centroids of the clusters
        overlapping with each object, weighed down by their respective
        IoUs.

        NB: This is a proxy and one could design failure cases, when
        clusters are not pure enough.

        :param mode: str
            Method used to estimate the centroids. 'iou' will weigh down
            the centroids of the clusters overlapping each instance by
            their IoU. 'ratio-product' will use the product of the size
            ratios of the overlap wrt the cluster and wrt the instance.
            'overlap' will use the size of the overlap between the
            cluster and the instance.

        :return obj_pos, obj_idx
            obj_pos: Tensor
                Estimated position for each object
            obj_idx: Tensor
                Corresponding object indices
            Both are None when `self` has no instance annotations
        """
        # Without instance annotations, nothing can be estimated
        if self.obj is None:
            return None, None

        return self.obj.estimate_centroid(self.pos, mode=mode)
|
|
|
|
|
    def semantic_segmentation_oracle(
            self,
            num_classes: int,
            *metric_args,
            **metric_kwargs
    ) -> SemanticMetricResults:
        """Compute the oracle performance for semantic segmentation,
        when all nodes predict the dominant label among their points.
        This corresponds to the highest achievable performance with the
        partition at hand.

        This expects one of the following attributes:
          - `Data.obj`: holding node overlaps with instance annotations
          - `Data.y`: holding node label histograms

        Returns None when neither attribute is present.

        :param num_classes: int
            Number of valid classes. By convention, we assume
            `y ∈ [0, num_classes-1]` are VALID LABELS, while
            `y < 0` AND `y >= num_classes` ARE VOID LABELS
        :param metric_args:
            Args for the metrics computation
        :param metric_kwargs:
            Kwargs for the metrics computation

        :return: mIoU, pre-class IoU, OA, mAcc
        """
        # Preferred path: delegate to the instance annotations
        if self.obj is not None:
            return self.obj.semantic_segmentation_oracle(
                num_classes, *metric_args, **metric_kwargs)

        # Without label histograms, the oracle cannot be computed
        if getattr(self, 'y', None) is None:
            return

        # Oracle prediction: dominant VALID label of each node's
        # histogram (void columns beyond num_classes are ignored)
        pred = self.y[:, :num_classes].argmax(dim=1)
        # NOTE(review): the full histogram is passed as target —
        # presumably ConfusionMatrix accepts histogram targets; confirm
        # against src.metrics.ConfusionMatrix
        target = self.y

        # Local import to avoid a circular import at module load time
        from src.metrics import ConfusionMatrix
        cm = ConfusionMatrix(num_classes, *metric_args, **metric_kwargs)
        cm(pred.cpu(), target.cpu())
        metrics = cm.all_metrics()

        return metrics
|
|
|
|
|
def instance_segmentation_oracle( |
|
|
self, |
|
|
*metric_args, |
|
|
**metric_kwargs |
|
|
) -> InstanceMetricResults: |
|
|
"""Compute the oracle performance for instance segmentation. |
|
|
This is a proxy for the highest achievable performance with the |
|
|
cluster partition at hand. |
|
|
|
|
|
More precisely, for the oracle prediction: |
|
|
- each cluster is assigned to the instance it shares the most |
|
|
points with |
|
|
- clusters assigned to the same instance are merged into a |
|
|
single prediction |
|
|
- each predicted instance has a score equal to its IoU with |
|
|
the assigned target instance |
|
|
|
|
|
This expects the following attributes: |
|
|
- `Data.obj`: holding node overlaps with instance annotations |
|
|
|
|
|
:param metric_args: |
|
|
Args for the metrics computation |
|
|
:param metric_kwargs: |
|
|
Kwargs for the metrics computation |
|
|
|
|
|
:return: InstanceMetricResults |
|
|
""" |
|
|
|
|
|
if self.obj is not None: |
|
|
return self.obj.instance_segmentation_oracle( |
|
|
*metric_args, **metric_kwargs) |
|
|
return |
|
|
|
|
|
def panoptic_segmentation_oracle( |
|
|
self, |
|
|
*metric_args, |
|
|
**metric_kwargs |
|
|
) -> PanopticMetricResults: |
|
|
"""Compute the oracle performance for panoptic segmentation. |
|
|
This is a proxy for the highest achievable performance with the |
|
|
cluster partition at hand. |
|
|
|
|
|
More precisely, for the oracle prediction: |
|
|
- each cluster is assigned to the instance it shares the most |
|
|
points with |
|
|
- clusters assigned to the same instance are merged into a |
|
|
single prediction |
|
|
|
|
|
This expects the following attributes: |
|
|
- `Data.obj`: holding node overlaps with instance annotations |
|
|
|
|
|
:param metric_args: |
|
|
Args for the metrics computation |
|
|
:param metric_kwargs: |
|
|
Kwargs for the metrics computation |
|
|
|
|
|
:return: PanopticMetricResults |
|
|
""" |
|
|
|
|
|
if self.obj is not None: |
|
|
return self.obj.panoptic_segmentation_oracle( |
|
|
*metric_args, **metric_kwargs) |
|
|
return |
|
|
|
|
|
    def show(self, **kwargs):
        """Visualize the Data object. See `src.visualization.show` for
        the accepted keyword arguments.
        """
        # NOTE(review): local import presumably avoids a circular import
        # with src.visualization — confirm
        from src.visualization import show
        return show(self, **kwargs)
|
|
|
|
|
|
|
|
class Batch(PyGBatch, Data): |
|
|
"""Inherit from torch_geometric.Batch with extensions tailored to |
|
|
our specific needs. |
|
|
|
|
|
NB: contrary to PyGBatch's dynamic inheritance behavior, we force |
|
|
the explicit inheritance to our Data class, to ensure Batch objects |
|
|
share all attributes and methods of our Data class throughout the |
|
|
codebase. |
|
|
""" |
|
|
|
|
|
    @classmethod
    def from_data_list(
            cls,
            data_list: List[Data],
            follow_batch: Optional[List[str]] = None,
            exclude_keys: Optional[List[str]] = None
    ) -> 'Batch':
        """Overwrite torch_geometric from_data_list to be able to handle
        Cluster and InstanceData objects batching.
        """
        # Silence warnings raised by PyG's collation
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # `edge_` keys other than edge_index/edge_attr are not
            # supported at collation time
            for d in data_list:
                d.raise_if_edge_keys()

            # When some Data have edges and others do not, give the
            # edgeless ones an empty (2, 0) edge_index (and a matching
            # empty edge_attr) so PyG collates them consistently
            has = [
                i for i, d in enumerate(data_list) if d.edge_index is not None]
            has_not = [
                i for i, d in enumerate(data_list) if d.edge_index is None]

            if len(has) > 0 and len(has_not) > 0:
                device = data_list[0].device
                edge_index = torch.empty((2, 0), device=device).long()

                if data_list[has[0]].edge_attr is not None:
                    dim = data_list[has[0]].edge_attr.shape[1]
                    edge_attr = torch.empty((0, dim), device=device).long()
                else:
                    edge_attr = None

                for i in has_not:
                    data_list[i].edge_index = edge_index
                    data_list[i].edge_attr = edge_attr

            # Let PyG collate everything it knows how to handle
            batch = super().from_data_list(
                data_list, follow_batch=follow_batch, exclude_keys=exclude_keys)

        # PyG collates CSRData attributes into plain Python lists:
        # convert those to their dedicated batch class
        for k, v in data_list[0].to_dict().items():
            if isinstance(v, CSRData):
                batch[k] = v.get_batch_class().from_list(batch[k])

        return batch
|
|
|
|
|
def to_data_list(self, strict: bool = False) -> List['Data']: |
|
|
"""Reconstruct the list of `Data` objects that were passed to |
|
|
`Batch.from_data_list()`. |
|
|
|
|
|
This extends the behavior of PyG's `to_data_list()` by also |
|
|
attempting to infer how to un-collate some attributes that may |
|
|
have been added to the `Batch` even after its initial |
|
|
construction with `from_data_list()`. |
|
|
""" |
|
|
return [ |
|
|
self.get_example(i, strict=strict) for i in range(self.num_graphs)] |
|
|
|
|
|
def get_example(self, idx: int, strict: bool = False) -> List['Data']: |
|
|
"""Overwrite torch_geometric get_example to be able to handle |
|
|
Cluster and InstanceData objects batching. |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
self._infer_collation(strict=strict) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bckp_dict = {} |
|
|
for k in self.keys: |
|
|
if isinstance(self[k], CSRData): |
|
|
bckp_dict[k] = self[k].clone() |
|
|
self[k] = self[k].to_list() |
|
|
|
|
|
data = super().get_example(idx) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
data = Data(**data.to_dict()) |
|
|
|
|
|
|
|
|
for k, v in bckp_dict.items(): |
|
|
self[k] = v |
|
|
|
|
|
return data |
|
|
|
|
|
def _infer_collation(self, strict: bool = False): |
|
|
"""Populate `self._slice_dict` and `self._inc_dict` with |
|
|
inferred collation for missing keys. |
|
|
|
|
|
Unlike PyG, we want to handle attributes that may have been |
|
|
added to the Batch object even if they were not present yet |
|
|
when `Batch.from_data_list()` was initially called. To this |
|
|
end, we actively search for attributes that are absent from |
|
|
`self._slice_dict` and `self._inc_dict` and, check whether |
|
|
they are attributes of type node, edge, or other. Then, based |
|
|
on the rules of `self.__cat_dim__` and `self.__inc__`, we can |
|
|
identify the desirable batching behavior. Finally, we search |
|
|
for other node and edge attributes in `self._slice_dict` and |
|
|
`self._inc_dict` to infer the number of subgraphs, and the |
|
|
number nodes and edges in each, in order to update |
|
|
`self._slice_dict` and `self._inc_dict` with appropriate |
|
|
values for the missing keys. |
|
|
""" |
|
|
if not hasattr(self, '_slice_dict'): |
|
|
raise RuntimeError( |
|
|
("Cannot reconstruct 'Data' object from 'Batch' because " |
|
|
"'Batch' was not created via 'Batch.from_data_list()'")) |
|
|
|
|
|
|
|
|
all_keys = set(self.keys) |
|
|
all_node_keys = set(self.node_attrs()) |
|
|
all_edge_keys = set(self.edge_attrs()) |
|
|
all_node_csr_keys = {k for k in all_keys if ( |
|
|
isinstance(self[k], CSRData) |
|
|
and self[k].num_groups == self.num_nodes)} |
|
|
all_other_keys = ( |
|
|
all_keys - all_node_keys - all_edge_keys - all_node_csr_keys) |
|
|
slice_keys = set(self._slice_dict.keys()) |
|
|
slice_node_keys = slice_keys.intersection(all_node_keys) |
|
|
slice_edge_keys = slice_keys.intersection(all_edge_keys) |
|
|
slice_node_csr_keys = slice_keys.intersection(all_node_csr_keys) |
|
|
slice_other_keys = slice_keys.intersection(all_other_keys) |
|
|
special_keys = set(['_num_graphs', 'ptr', 'batch']) |
|
|
missing_keys = all_keys - slice_keys - special_keys |
|
|
missing_node_keys = missing_keys.intersection(all_node_keys) |
|
|
missing_edge_keys = missing_keys.intersection(all_edge_keys) |
|
|
missing_node_csr_keys = missing_keys.intersection(all_node_csr_keys) |
|
|
missing_other_keys = missing_keys.intersection(all_other_keys) |
|
|
|
|
|
|
|
|
all_keys = list(all_keys) |
|
|
all_node_keys = list(all_node_keys) |
|
|
all_edge_keys = list(all_edge_keys) |
|
|
all_node_csr_keys = list(all_node_csr_keys) |
|
|
all_other_keys = list(all_other_keys) |
|
|
slice_keys = list(slice_keys) |
|
|
slice_node_keys = list(slice_node_keys) |
|
|
slice_edge_keys = list(slice_edge_keys) |
|
|
slice_node_csr_keys = list(slice_node_csr_keys) |
|
|
slice_other_keys = list(slice_other_keys) |
|
|
special_keys = list(special_keys) |
|
|
missing_keys = list(missing_keys) |
|
|
missing_node_keys = list(missing_node_keys) |
|
|
missing_edge_keys = list(missing_edge_keys) |
|
|
missing_node_csr_keys = list(missing_node_csr_keys) |
|
|
missing_other_keys = list(missing_other_keys) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if len(slice_node_keys) == 0 and len(missing_node_keys) > 0: |
|
|
if strict: |
|
|
raise ValueError( |
|
|
f"Cannot infer how to un-collate the " |
|
|
f"{self.__class__.__name__} object because none of the " |
|
|
f"node attributes {self.node_attrs()} could be found in " |
|
|
f"`self._slice_dict`. Make sure your `Data` objects have " |
|
|
f"at least one node attribute before collating them with " |
|
|
f"`Batch.from_data_list()`.") |
|
|
else: |
|
|
missing_node_keys = [] |
|
|
|
|
|
if len(slice_edge_keys) == 0 and len(missing_edge_keys) > 0: |
|
|
if strict: |
|
|
raise ValueError( |
|
|
f"Cannot infer how to un-collate the " |
|
|
f"{self.__class__.__name__} object because none of the " |
|
|
f"edge attributes {self.edge_attrs()} could be found in " |
|
|
f"`self._slice_dict`. Make sure your `Data` objects have " |
|
|
f"at least one node attribute before collating them with " |
|
|
f"`Batch.from_data_list()`.") |
|
|
else: |
|
|
missing_edge_keys = [] |
|
|
|
|
|
if len(slice_node_keys) == 0 and len(missing_node_csr_keys) > 0: |
|
|
if strict: |
|
|
raise ValueError( |
|
|
f"Cannot infer how to un-collate the " |
|
|
f"{self.__class__.__name__} object because none of the " |
|
|
f"node attributes {self.node_attrs()} could be found in " |
|
|
f"`self._slice_dict`, which prevents inferring the " |
|
|
f"collation for node attributes carrying CSRData objects: " |
|
|
f"{missing_node_csr_keys}. Make sure your `Data` objects " |
|
|
f"have at least one node attribute before collating them " |
|
|
f"with `Batch.from_data_list()`.") |
|
|
else: |
|
|
missing_node_keys = [] |
|
|
|
|
|
if len(missing_other_keys) > 0: |
|
|
if strict: |
|
|
raise ValueError( |
|
|
f"Cannot infer how to un-collate the " |
|
|
f"{self.__class__.__name__} object because some attributes " |
|
|
f"of type 'other' {missing_other_keys} could be not found " |
|
|
f"in `self._slice_dict`. Make sure all your 'other' " |
|
|
f"attributes (i.e. neither node nor edge attributes) are" |
|
|
f"in your `Data` objects before collating them with " |
|
|
f"`Batch.from_data_list()`.") |
|
|
else: |
|
|
missing_other_keys = [] |
|
|
|
|
|
|
|
|
num_graphs = self.num_graphs |
|
|
if len(slice_node_keys) > 0: |
|
|
node_ptr = self._slice_dict[slice_node_keys[0]] |
|
|
else: |
|
|
node_ptr = None |
|
|
if len(slice_edge_keys) > 0: |
|
|
edge_ptr = self._slice_dict[slice_edge_keys[0]] |
|
|
else: |
|
|
edge_ptr = None |
|
|
|
|
|
|
|
|
|
|
|
for k in missing_node_keys: |
|
|
if self.__inc__(k, self[k]) == 0: |
|
|
self._slice_dict[k] = node_ptr |
|
|
self._inc_dict[k] = torch.zeros( |
|
|
num_graphs, dtype=torch.long, device=self.device) |
|
|
continue |
|
|
elif strict: |
|
|
raise ValueError( |
|
|
f"Cannot infer how to un-collate the '{k}' attribute " |
|
|
f"because `self.__inc__('{k}') != 0`. Guessing how to " |
|
|
f"restore the increments for each batch item is ambiguous. " |
|
|
f"To collate and un-collate '{k}', make sure your `Data` " |
|
|
f"objects have the '{k}' attribute before collating them " |
|
|
f"with `Batch.from_data_list()`.") |
|
|
|
|
|
|
|
|
|
|
|
for k in missing_edge_keys: |
|
|
if self.__inc__(k, self[k]) == 0: |
|
|
self._slice_dict[k] = edge_ptr |
|
|
self._inc_dict[k] = torch.zeros( |
|
|
num_graphs, dtype=torch.long, device=self.device) |
|
|
continue |
|
|
elif strict: |
|
|
raise ValueError( |
|
|
f"Cannot infer how to un-collate the '{k}' attribute " |
|
|
f"because `self.__inc__('{k}') != 0`. Guessing how to " |
|
|
f"restore the increments for each batch item is ambiguous. " |
|
|
f"To collate and un-collate '{k}', make sure your `Data` " |
|
|
f"objects have the '{k}' attribute before collating them " |
|
|
f"with `Batch.from_data_list()`.") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for k in missing_node_csr_keys: |
|
|
self._slice_dict[k] = torch.arange( |
|
|
self.num_graphs + 1, device=self.device) |
|
|
self._inc_dict[k] = None |
|
|
|
|
|
self[k] = self[k].get_batch_class()( |
|
|
self[k].pointers, |
|
|
*self[k].values, |
|
|
dense=False, |
|
|
is_index_value=self[k].is_index_value) |
|
|
|
|
|
ref_k = list(set(self._slice_dict).intersection(all_node_keys))[0] |
|
|
ptr = self._slice_dict[ref_k] |
|
|
self[k].__sizes__ = ptr[1:] - ptr[:-1] |
|
|
|
|
|
def save( |
|
|
self, |
|
|
f: Union[h5py.File, h5py.Group], |
|
|
y_to_csr: bool = True, |
|
|
pos_dtype: torch.dtype = torch.float, |
|
|
fp_dtype: torch.dtype = torch.float): |
|
|
"""Save Batch to HDF5 file. |
|
|
|
|
|
:param f: h5 file path of h5py.File or h5py.Group |
|
|
:param y_to_csr: bool |
|
|
Convert 'y' to CSR format before saving. Only applies if |
|
|
'y' is a 2D histogram |
|
|
:param pos_dtype: torch dtype |
|
|
Data type to which 'pos' should be cast before saving. The |
|
|
reason for this separate treatment of 'pos' is that global |
|
|
coordinates may be too large and casting to 'fp_dtype' may |
|
|
result in hurtful precision loss |
|
|
:param fp_dtype: torch dtype |
|
|
Data type to which floating point tensors should be cast |
|
|
before saving |
|
|
:return: |
|
|
""" |
|
|
if not isinstance(f, (h5py.File, h5py.Group)): |
|
|
with h5py.File(f, 'w') as file: |
|
|
self.save( |
|
|
file, |
|
|
y_to_csr=y_to_csr, |
|
|
pos_dtype=pos_dtype, |
|
|
fp_dtype=fp_dtype) |
|
|
return |
|
|
|
|
|
assert isinstance(f, (h5py.File, h5py.Group)) |
|
|
|
|
|
|
|
|
|
|
|
super().save( |
|
|
f, |
|
|
y_to_csr=y_to_csr, |
|
|
pos_dtype=pos_dtype, |
|
|
fp_dtype=fp_dtype) |
|
|
|
|
|
|
|
|
|
|
|
if hasattr(self, '_slice_dict'): |
|
|
save_tensor_dict(self._slice_dict, f, '_slice_dict', fp_dtype=fp_dtype) |
|
|
if hasattr(self, '_inc_dict'): |
|
|
save_tensor_dict(self._inc_dict, f, '_inc_dict', fp_dtype=fp_dtype) |
|
|
if hasattr(self, '_num_graphs'): |
|
|
f.create_dataset('_num_graphs', data=np.array([self._num_graphs])) |
|
|
|
|
|
@classmethod |
|
|
def load(cls, *args, **kwargs) -> Union['Batch', 'Data']: |
|
|
"""Read an HDF5 file and return its content as a Batch object. |
|
|
|
|
|
See Data.load() |
|
|
""" |
|
|
return Data.load(*args, **kwargs) |
|
|
|