|
|
import sys |
|
|
import os.path as osp |
|
|
|
|
|
import torch |
|
|
import numpy as np |
|
|
from torch_scatter import scatter_sum, scatter_mean |
|
|
from src.transforms import Transform |
|
|
from src.data import Data, NAG, Cluster, InstanceData |
|
|
from src.utils.cpu import available_cpu_count |
|
|
from src.utils import xy_partition |
|
|
|
|
|
dependencies_folder = osp.dirname(osp.dirname(osp.abspath(__file__))) |
|
|
sys.path.append(dependencies_folder) |
|
|
sys.path.append(osp.join(dependencies_folder, "dependencies/grid_graph/python/bin")) |
|
|
sys.path.append(osp.join(dependencies_folder, "dependencies/parallel_cut_pursuit/python/wrappers")) |
|
|
|
|
|
from grid_graph import edge_list_to_forward_star |
|
|
from cp_d0_dist import cp_d0_dist |
|
|
|
|
|
__all__ = ['CutPursuitPartition', 'GridPartition'] |
|
|
|
|
|
|
|
|
class CutPursuitPartition(Transform):
    """Partition a graph contained in a `Data` object using cut-pursuit.

    The input `Data` object is assumed to hold the following attributes:
      - `pos` carrying node spatial coordinates
      - `x` carrying node features
      - `edge_index` carrying the adjacency graph edges in Pytorch
        Geometric format (typically generated with `AdjacencyGraph`)
      - `edge_attr` carrying the scalar edge weights in Pytorch
        Geometric format (typically generated with `AdjacencyGraph`)

    The quality of a partition may be assessed in terms of efficiency
    (how much it simplifies the input graph) and accuracy (how well it
    respects the semantic boundaries). We provide two tools for
    assessing these: `NAG.level_ratios` which computes the ratio of the
    number of elements between successive partition levels, and
    `Data.semantic_segmentation_oracle()` which computes the semantic
    segmentation metrics of a hypothetical oracle model capable of
    predicting the majority label for each superpoint. See our
    Superpoint Transformer tutorial
    `notebooks/superpoint_transformer_tutorial.ipynb` for more on this.

    :param regularization: float or List(float)
        Regularization strength used for each partition level. This is
        the primary parameter for adjusting cut-pursuit partitions. The
        larger the regularization, the coarser the partition, the fewer
        the superpoints, the bigger the superpoints, the lower their
        semantic purity (ie superpoints are more likely to bleed across
        semantic object boundaries). And vice versa. If a list is
        passed, the values are assumed to be increasing
    :param spatial_weight: float or List(float)
        Weight used to mitigate the impact of the point position in the
        partition. The smaller, the less spatial coordinates matter.
        This can be loosely interpreted as the inverse of a maximum
        superpoint radius. It typically affects the size of superpoints
        in geometrically/radiometrically homogeneous regions such as the
        ground, walls, or ceilings. Setting a large `spatial_weight`
        will have a "voronoi tessellation" effect on the superpoint
        partition, preventing too-large superpoints from being
        constructed in these otherwise-homogeneous regions. Inversely,
        setting a small `spatial_weight` will encourage cut-pursuit to
        create superpoints as large as possible, so long as the features
        of the points inside are homogeneous. In an extreme case: the
        entire floor would then be a single superpoint. If a list is
        passed, it must match the length of `regularization`
    :param cutoff: float or List(float)
        Minimum number of points in each superpoint. The output
        partition will not contain any superpoint smaller than `cutoff`.
        If a list is passed, it must match the length of
        `regularization`
    :param parallel: bool
        Whether cut-pursuit should run in parallel (ie on multiple CPU
        threads)
    :param iterations: int
        Maximum number of iterations for the cut-pursuit algorithm. The
        higher, the longer the processing. A value in $[10, 15]$ is
        usually sufficient
    :param k_adjacency: int
        When a node is isolated after a partition, we connect it to the
        nearest nodes. This rules the number of neighbors it should be
        connected to
    :param verbose: bool
    """

    _IN_TYPE = Data
    _OUT_TYPE = NAG

    # Cut-pursuit is compiled with `uint32` edge indices, which
    # upper-bounds the number of edges it can process
    _MAX_NUM_EDGES = 4294967295

    _NO_REPR = ['verbose', 'parallel']

    def __init__(
            self, regularization=5e-2, spatial_weight=1, cutoff=10,
            parallel=True, iterations=10, k_adjacency=5, verbose=False):
        self.regularization = regularization
        self.spatial_weight = spatial_weight
        self.cutoff = cutoff
        self.parallel = parallel
        self.iterations = iterations
        self.k_adjacency = k_adjacency
        self.verbose = verbose

    @staticmethod
    def _scatter_sum_cpu_out(src, index):
        """Scatter-sum `src` along `index`, on GPU when one is
        available (faster for large graphs), falling back to CPU
        otherwise. The result is always returned on CPU.

        FIX: the previous implementation called `.cuda()`
        unconditionally and crashed on CPU-only machines.
        """
        if torch.cuda.is_available():
            out = scatter_sum(src.cuda(), index.cuda(), dim=0).cpu()
            torch.cuda.empty_cache()
            return out
        return scatter_sum(src, index, dim=0)

    def _process(self, data):
        # Sanity checks on the input graph and parameters
        assert data.has_edges, \
            "Cannot compute partition, no edges in Data"
        assert data.num_nodes < np.iinfo(np.uint32).max, \
            "Too many nodes for `uint32` indices"
        assert data.num_edges < np.iinfo(np.uint32).max, \
            "Too many edges for `uint32` indices"
        assert isinstance(self.regularization, (int, float, list)), \
            "Expected a scalar or a List"
        # FIX: also accept float cutoffs, as documented in the class
        # docstring (previously only int was allowed)
        assert isinstance(self.cutoff, (int, float, list)), \
            "Expected a scalar or a List"
        assert isinstance(self.spatial_weight, (int, float, list)), \
            "Expected a scalar or a List"

        # Trim the graph before running the partition
        data = data.to_trimmed()

        # Initialize the level-0 state: thread count, unit node sizes
        # (level-0 nodes are single points), and the list of partition
        # levels to be built
        num_threads = available_cpu_count() if self.parallel else 1
        data.node_size = torch.ones(
            data.num_nodes, device=data.device, dtype=torch.long)
        data_list = [data]

        # Broadcast scalar parameters into per-level lists of matching
        # lengths
        regularization = self.regularization
        if not isinstance(regularization, list):
            regularization = [regularization]
        cutoff = self.cutoff
        if isinstance(cutoff, (int, float)):
            cutoff = [cutoff] * len(regularization)
        spatial_weight = self.spatial_weight
        if isinstance(spatial_weight, (float, int)):
            spatial_weight = [spatial_weight] * len(regularization)
        assert len(regularization) == len(cutoff) == len(spatial_weight)

        n_dim = data.pos.shape[1]
        n_feat = data.x.shape[1] if data.x is not None else 0

        # Iteratively run cut-pursuit: each level partitions the graph
        # produced by the previous level
        for level, (reg, cut, sw) in enumerate(zip(
                regularization, cutoff, spatial_weight)):

            if self.verbose:
                print(
                    f'Launching partition level={level} reg={reg}, '
                    f'cutoff={cut}')

            # Recover the graph of the previous level
            d1 = data_list[level]

            # Cannot further partition a graph with fewer than 2 nodes
            if d1.num_nodes < 2:
                break

            # FIX: restored the missing space between "Please" and
            # "update" in the warning message
            if d1.edge_index.shape[1] > self._MAX_NUM_EDGES and self.verbose:
                print(
                    f"WARNING: number of edges {d1.edge_index.shape[1]} "
                    f"exceeds the uint32 limit {self._MAX_NUM_EDGES}. Please "
                    f"update the cut-pursuit source code to accept a larger "
                    f"data type for `index_t`.")

            # Convert the edge list to the forward-star (CSR-like)
            # representation expected by cut-pursuit, with uint32
            # indices. `reindex` maps edge attributes to the new edge
            # ordering
            source_csr, target, reindex = edge_list_to_forward_star(
                d1.num_nodes, d1.edge_index.T.contiguous().cpu().numpy())
            source_csr = source_csr.astype('uint32')
            target = target.astype('uint32')
            edge_weights = d1.edge_attr.cpu().numpy()[reindex] * reg \
                if d1.edge_attr is not None else reg

            # Mean-center the positions and concatenate them with the
            # node features, then weight the spatial coordinates with
            # `spatial_weight`. Cut-pursuit expects Fortran-ordered
            # (dim, num_nodes) float arrays
            pos_offset = d1.pos.mean(dim=0)
            if d1.x is not None:
                x = torch.cat((d1.pos - pos_offset, d1.x), dim=1)
            else:
                x = d1.pos - pos_offset
            x = np.asfortranarray(x.cpu().numpy().T)
            node_size = d1.node_size.float().cpu().numpy()
            coor_weights = np.ones(n_dim + n_feat, dtype=np.float32)
            coor_weights[:n_dim] *= sw

            # Run the actual cut-pursuit partition
            super_index, x_c, cluster, edges, times = cp_d0_dist(
                n_dim + n_feat,
                x,
                source_csr,
                target,
                edge_weights=edge_weights,
                vert_weights=node_size,
                coor_weights=coor_weights,
                min_comp_weight=cut,
                cp_dif_tol=1e-2,
                cp_it_max=self.iterations,
                split_damp_ratio=0.7,
                verbose=self.verbose,
                max_num_threads=num_threads,
                balance_parallel_split=True,
                compute_Time=True,
                compute_List=True,
                compute_Graph=True)

            if self.verbose:
                delta_t = (times[1:] - times[:-1]).round(2)
                print(f'Level {level} iteration times: {delta_t}')
                print(f'partition {level} done')

            # Each node of the current level now points to its parent
            # superpoint in the next level
            super_index = torch.from_numpy(super_index.astype('int64'))
            d1.super_index = super_index

            # Build the next level's Data: superpoint→point clusters
            # (CSR pointer/value), centroids, mean features, the
            # superpoint adjacency graph, and aggregated node sizes
            size = torch.LongTensor([c.shape[0] for c in cluster])
            pointer = torch.cat([torch.LongTensor([0]), size.cumsum(dim=0)])
            value = torch.cat([
                torch.from_numpy(x.astype('int64')) for x in cluster])
            pos = torch.from_numpy(x_c[:n_dim].T) + pos_offset.cpu()
            x = torch.from_numpy(x_c[n_dim:].T)
            s = torch.arange(edges[0].shape[0] - 1).repeat_interleave(
                torch.from_numpy((edges[0][1:] - edges[0][:-1]).astype("int64")))
            t = torch.from_numpy(edges[1].astype("int64"))
            edge_index = torch.vstack((s, t))
            # Undo the `reg` scaling applied to the input edge weights
            edge_attr = torch.from_numpy(edges[2] / reg)
            node_size = torch.from_numpy(node_size)
            node_size_new = self._scatter_sum_cpu_out(
                node_size, super_index).long()
            d2 = Data(
                pos=pos, x=x, edge_index=edge_index, edge_attr=edge_attr,
                sub=Cluster(pointer, value), node_size=node_size_new)

            # Merge instance annotations into the new level, if any
            if d1.obj is not None and isinstance(d1.obj, InstanceData):
                d2.obj = d1.obj.merge(d1.super_index)

            # Trim the new level's graph
            d2 = d2.to_trimmed()

            # Reconnect nodes isolated by the partition to their
            # `k_adjacency` nearest neighbors
            if d2.num_nodes > 1:
                d2 = d2.connect_isolated(k=self.k_adjacency)

            # Aggregate per-node semantic label histograms into
            # per-superpoint histograms
            if 'y' in d1.keys:
                assert d1.y.dim() == 2, \
                    "Expected Data.y to hold `(num_nodes, num_classes)` " \
                    "histograms, not single labels"
                d2.y = self._scatter_sum_cpu_out(d1.y, d1.super_index)

            if 'semantic_pred' in d1.keys:
                assert d1.semantic_pred.dim() == 2, \
                    "Expected Data.semantic_pred to hold `(num_nodes, num_classes)` " \
                    "histograms, not single labels"
                d2.semantic_pred = self._scatter_sum_cpu_out(
                    d1.semantic_pred, d1.super_index)

            # Save the updated level and append the new one
            data_list[level] = d1
            data_list.append(d2)

            if self.verbose:
                print('\n' + '-' * 64 + '\n')

        # Gather all levels into a hierarchical NAG structure
        nag = NAG(data_list)

        return nag
|
|
|
|
|
|
|
|
class GridPartition(Transform):
    """XY-grid-based hierarchical partition of Data. The nodes are
    aggregated based on their coordinates in a grid of step `size`.

    :param size: int or List(int)
        Grid step used for each partition level. If a list is passed,
        one partition level is built per step
    """

    _IN_TYPE = Data
    _OUT_TYPE = NAG

    def __init__(self, size=2):
        self.size = size

    def _process(self, data):
        # Sanity checks on the input graph and parameters
        assert data.num_nodes < np.iinfo(np.uint32).max, \
            "Too many nodes for `uint32` indices"
        assert data.num_edges < np.iinfo(np.uint32).max, \
            "Too many edges for `uint32` indices"
        assert isinstance(self.size, (int, float, list)), \
            "Expected a scalar or a List"

        # Broadcast a scalar step into a single-level list
        size = self.size
        if not isinstance(size, list):
            size = [size]
        data_list = [data]

        # Build one partition level per grid step, each one coarsening
        # the previous level
        for w in size:

            d = data_list[-1]

            # Assign each node to an XY-grid cell of step `w`.
            # FIX: `w` was previously not passed to `xy_partition`, so
            # the grid step was silently ignored and every level used
            # the function's default step
            super_index = xy_partition(d.pos, w, consecutive=True)

            # Superpoint centroids and the superpoint→node cluster
            # mapping for the new level
            pos = scatter_mean(d.pos, super_index, dim=0)
            cluster = Cluster(
                super_index, torch.arange(d.num_nodes), dense=True)

            # Save the parent index on the current level and append the
            # new, coarser level
            data_list[-1].super_index = super_index
            data_list.append(Data(pos=pos, sub=cluster))

        # Gather all levels into a hierarchical NAG structure
        nag = NAG(data_list)

        return nag
|
|
|