File size: 14,732 Bytes
26225c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 |
import sys
import os.path as osp
import torch
import numpy as np
from torch_scatter import scatter_sum, scatter_mean
from src.transforms import Transform
from src.data import Data, NAG, Cluster, InstanceData
from src.utils.cpu import available_cpu_count
from src.utils import xy_partition

# Make the compiled `grid_graph` and `parallel_cut_pursuit` Python wrappers
# importable: they live under `dependencies/` two levels above this file
# (presumably built in-place rather than pip-installed — hence the manual
# sys.path extension)
dependencies_folder = osp.dirname(osp.dirname(osp.abspath(__file__)))
sys.path.append(dependencies_folder)
sys.path.append(osp.join(dependencies_folder, "dependencies/grid_graph/python/bin"))
sys.path.append(osp.join(dependencies_folder, "dependencies/parallel_cut_pursuit/python/wrappers"))
from grid_graph import edge_list_to_forward_star
from cp_d0_dist import cp_d0_dist

# Public API of this module
__all__ = ['CutPursuitPartition', 'GridPartition']
class CutPursuitPartition(Transform):
    """Partition a graph contained in a `Data` object using cut-pursuit.

    The input `Data` object is assumed to hold the following attributes:
      - `pos` carrying node spatial coordinates
      - `x` carrying node features
      - `edge_index` carrying the adjacency graph edges in Pytorch
        Geometric format (typically generated with `AdjacencyGraph`)
      - `edge_attr` carrying the scalar edge weights in Pytorch
        Geometric format (typically generated with `AdjacencyGraph`)

    The quality of a partition may be assessed in terms of efficiency
    (how much it simplifies the input graph) and accuracy (how well it
    respects the semantic boundaries). We provide two tools for
    assessing these: `NAG.level_ratios` which computes the ratio of the
    number of elements between successive partition levels, and
    `Data.semantic_segmentation_oracle()` which computes the semantic
    segmentation metrics of a hypothetical oracle model capable of
    predicting the majority label for each superpoint. See our
    Superpoint Transformer tutorial
    `notebooks/superpoint_transformer_tutorial.ipynb` for more on this.

    :param regularization: float or List(float)
        Regularization strength used for each partition level. This is
        the primary parameter for adjusting cut-pursuit partitions. The
        larger the regularization, the coarser the partition, the fewer
        the superpoints, the bigger the superpoints, the lower their
        semantic purity (ie superpoints are more likely to bleed across
        semantic object boundaries). And vice versa. If a list is
        passed, the values are assumed to be increasing
    :param spatial_weight: float or List(float)
        Weight used to mitigate the impact of the point position in the
        partition. The smaller, the less spatial coordinates matter.
        This can be loosely interpreted as the inverse of a maximum
        superpoint radius. If a list is passed, it must match the
        length of `regularization`
    :param cutoff: float or List(float)
        Minimum number of points in each superpoint. The output
        partition will not contain any superpoint smaller than `cutoff`.
        If a list is passed, it must match the length of
        `regularization`
    :param parallel: bool
        Whether cut-pursuit should run in parallel (ie on multiple CPU
        threads)
    :param iterations: int
        Maximum number of iterations for the cut-pursuit algorithm. The
        higher, the longer the processing. A value in $[10, 15]$ is
        usually sufficient
    :param k_adjacency: int
        When a node is isolated after a partition, we connect it to the
        nearest nodes. This rules the number of neighbors it should be
        connected to
    :param verbose: bool
    """
    _IN_TYPE = Data
    _OUT_TYPE = NAG
    _MAX_NUM_EDGES = 4294967295
    _NO_REPR = ['verbose', 'parallel']

    def __init__(
            self, regularization=5e-2, spatial_weight=1, cutoff=10,
            parallel=True, iterations=10, k_adjacency=5, verbose=False):
        self.regularization = regularization
        self.spatial_weight = spatial_weight
        self.cutoff = cutoff
        self.parallel = parallel
        self.iterations = iterations
        self.k_adjacency = k_adjacency
        self.verbose = verbose

    def _process(self, data):
        """Compute the hierarchical cut-pursuit partition of `data` and
        return it as a `NAG` holding one `Data` object per level.
        """
        # Sanity checks
        assert data.has_edges, \
            "Cannot compute partition, no edges in Data"
        assert data.num_nodes < np.iinfo(np.uint32).max, \
            "Too many nodes for `uint32` indices"
        assert data.num_edges < np.iinfo(np.uint32).max, \
            "Too many edges for `uint32` indices"
        assert isinstance(self.regularization, (int, float, list)), \
            "Expected a scalar or a List"
        assert isinstance(self.cutoff, (int, list)), \
            "Expected an int or a List"
        assert isinstance(self.spatial_weight, (int, float, list)), \
            "Expected a scalar or a List"

        # Trim the graph
        # TODO: calling this on the level-0 adjacency graph is a bit sluggish
        #  but still saves partition time overall. May be worth finding a
        #  quick way of removing self loops and redundant edges...
        data = data.to_trimmed()

        # BUG FIX: the scatter aggregations below used to call `.cuda()`
        # unconditionally, crashing on CPU-only machines. Run them on CUDA
        # only when it is actually available
        scatter_device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # Initialize the hierarchical partition parameters. In particular,
        # prepare the output as list of Data objects that will be stored in
        # a NAG structure
        num_threads = available_cpu_count() if self.parallel else 1
        data.node_size = torch.ones(
            data.num_nodes, device=data.device,
            dtype=torch.long)  # level-0 points all have the same importance
        data_list = [data]

        # Normalize the per-level parameters to same-length lists, one
        # entry per partition level
        regularization = self.regularization
        if not isinstance(regularization, list):
            regularization = [regularization]
        cutoff = self.cutoff
        if isinstance(cutoff, int):
            cutoff = [cutoff] * len(regularization)
        spatial_weight = self.spatial_weight
        if isinstance(spatial_weight, (float, int)):
            spatial_weight = [spatial_weight] * len(regularization)
        assert len(regularization) == len(cutoff) == len(spatial_weight)

        n_dim = data.pos.shape[1]
        n_feat = data.x.shape[1] if data.x is not None else 0

        # Iteratively run the partition on the previous partition level
        for level, (reg, cut, sw) in enumerate(zip(
                regularization, cutoff, spatial_weight)):

            if self.verbose:
                print(
                    f'Launching partition level={level} reg={reg}, '
                    f'cutoff={cut}')

            # Recover the Data object on which we will run the partition
            d1 = data_list[level]

            # Exit if the graph contains only one node
            if d1.num_nodes < 2:
                break

            # User warning if the number of edges exceeds uint32 limits
            # (BUG FIX: a missing space used to render "Pleaseupdate")
            if d1.edge_index.shape[1] > self._MAX_NUM_EDGES and self.verbose:
                print(
                    f"WARNING: number of edges {d1.edge_index.shape[1]} "
                    f"exceeds the uint32 limit {self._MAX_NUM_EDGES}. Please "
                    f"update the cut-pursuit source code to accept a larger "
                    f"data type for `index_t`.")

            # Convert edges to forward-star (or CSR) representation
            source_csr, target, reindex = edge_list_to_forward_star(
                d1.num_nodes, d1.edge_index.T.contiguous().cpu().numpy())
            source_csr = source_csr.astype('uint32')
            target = target.astype('uint32')

            # Edge weights are scaled by the level regularization; when no
            # edge weights are present, the regularization alone is used
            edge_weights = d1.edge_attr.cpu().numpy()[reindex] * reg \
                if d1.edge_attr is not None else reg

            # Recover attributes features from Data object. Positions are
            # centered to mitigate float precision issues, and stacked with
            # the features in a Fortran-ordered array, as cp_d0_dist expects
            pos_offset = d1.pos.mean(dim=0)
            if d1.x is not None:
                x = torch.cat((d1.pos - pos_offset, d1.x), dim=1)
            else:
                x = d1.pos - pos_offset
            x = np.asfortranarray(x.cpu().numpy().T)
            node_size = d1.node_size.float().cpu().numpy()

            # The spatial weight only applies to the coordinate dimensions
            coor_weights = np.ones(n_dim + n_feat, dtype=np.float32)
            coor_weights[:n_dim] *= sw

            # Partition computation
            super_index, x_c, cluster, edges, times = cp_d0_dist(
                n_dim + n_feat,
                x,
                source_csr,
                target,
                edge_weights=edge_weights,
                vert_weights=node_size,
                coor_weights=coor_weights,
                min_comp_weight=cut,
                cp_dif_tol=1e-2,
                cp_it_max=self.iterations,
                split_damp_ratio=0.7,
                verbose=self.verbose,
                max_num_threads=num_threads,
                balance_parallel_split=True,
                compute_Time=True,
                compute_List=True,
                compute_Graph=True)

            if self.verbose:
                delta_t = (times[1:] - times[:-1]).round(2)
                print(f'Level {level} iteration times: {delta_t}')
                print(f'partition {level} done')

            # Save the super_index for the i-level
            super_index = torch.from_numpy(super_index.astype('int64'))
            d1.super_index = super_index

            # Save cluster information in another Data object. Convert
            # cluster-to-point indices in a CSR format
            size = torch.LongTensor([c.shape[0] for c in cluster])
            pointer = torch.cat([torch.LongTensor([0]), size.cumsum(dim=0)])
            value = torch.cat([
                torch.from_numpy(x.astype('int64')) for x in cluster])

            # Superpoint centroids are translated back by the position
            # offset removed above
            pos = torch.from_numpy(x_c[:n_dim].T) + pos_offset.cpu()
            x = torch.from_numpy(x_c[n_dim:].T)

            # Build the superpoint graph edges from the CSR structure
            # returned by cp_d0_dist, and undo the regularization scaling
            # on the edge weights
            s = torch.arange(edges[0].shape[0] - 1).repeat_interleave(
                torch.from_numpy((edges[0][1:] - edges[0][:-1]).astype("int64")))
            t = torch.from_numpy(edges[1].astype("int64"))
            edge_index = torch.vstack((s, t))
            edge_attr = torch.from_numpy(edges[2] / reg)

            # Superpoint sizes are the sum of the sizes of their children
            node_size = torch.from_numpy(node_size)
            node_size_new = scatter_sum(
                node_size.to(scatter_device), super_index.to(scatter_device),
                dim=0).cpu().long()
            d2 = Data(
                pos=pos, x=x, edge_index=edge_index, edge_attr=edge_attr,
                sub=Cluster(pointer, value), node_size=node_size_new)

            # Merge the lower level's instance annotations, if any
            if d1.obj is not None and isinstance(d1.obj, InstanceData):
                d2.obj = d1.obj.merge(d1.super_index)

            # Trim the graph
            d2 = d2.to_trimmed()

            # If some nodes are isolated in the graph, connect them to
            # their nearest neighbors, so their absence of connectivity
            # does not "pollute" higher levels of partition
            if d2.num_nodes > 1:
                d2 = d2.connect_isolated(k=self.k_adjacency)

            # Aggregate some point attributes into the clusters. This
            # is not performed dynamically since not all attributes can
            # be aggregated (e.g. 'neighbor_index', 'neighbor_distance',
            # 'edge_index', 'edge_attr'...)
            if 'y' in d1.keys:
                assert d1.y.dim() == 2, \
                    "Expected Data.y to hold `(num_nodes, num_classes)` " \
                    "histograms, not single labels"
                d2.y = scatter_sum(
                    d1.y.to(scatter_device),
                    d1.super_index.to(scatter_device), dim=0).cpu()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            if 'semantic_pred' in d1.keys:
                assert d1.semantic_pred.dim() == 2, \
                    "Expected Data.semantic_pred to hold `(num_nodes, num_classes)` " \
                    "histograms, not single labels"
                d2.semantic_pred = scatter_sum(
                    d1.semantic_pred.to(scatter_device),
                    d1.super_index.to(scatter_device), dim=0).cpu()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            # TODO: aggregate other attributes ?
            # TODO: if scatter operations are bottleneck, use scatter_csr

            # Add the l+1-level Data object to data_list and update the
            # l-level after super_index has been changed
            data_list[level] = d1
            data_list.append(d2)

            if self.verbose:
                print('\n' + '-' * 64 + '\n')

        # Create the NAG object
        nag = NAG(data_list)

        return nag
class GridPartition(Transform):
    """XY-grid-based hierarchical partition of Data. The nodes are
    aggregated based on their coordinates in a grid of step `size`.

    :param size: int, float or List
        Grid step of each partition level. If a list is passed, one
        partition level is built per element
    """
    _IN_TYPE = Data
    _OUT_TYPE = NAG

    def __init__(self, size=2):
        self.size = size

    def _process(self, data):
        """Build one grid-partition level per entry in `self.size` and
        return the resulting hierarchy as a `NAG`.
        """
        # Sanity checks
        assert data.num_nodes < np.iinfo(np.uint32).max, \
            "Too many nodes for `uint32` indices"
        assert data.num_edges < np.iinfo(np.uint32).max, \
            "Too many edges for `uint32` indices"
        assert isinstance(self.size, (int, float, list)), \
            "Expected a scalar or a List"

        # Initialize the partition data
        size = self.size
        if not isinstance(size, list):
            size = [size]
        data_list = [data]

        # XY-grid partitions
        for w in size:
            # Compute a "manual" partition based on the grid coordinates
            d = data_list[-1]

            # BUG FIX: the grid step `w` was previously never forwarded to
            # `xy_partition`, so every level silently used the helper's
            # default grid size and `self.size` had no effect
            # NOTE(review): keyword assumed to be `grid` — confirm against
            # the `xy_partition` signature in `src.utils`
            super_index = xy_partition(d.pos, grid=w, consecutive=True)

            # Compute the superpoint centroids and Cluster object
            pos = scatter_mean(d.pos, super_index, dim=0)
            cluster = Cluster(
                super_index, torch.arange(d.num_nodes), dense=True)

            # TODO: support more Data attributes and more advanced
            #  grouping, probably by interfacing with
            #  src.transforms.sampling._group_data()

            # Update the super_index of the previous level and create
            # the Data object for the new level
            data_list[-1].super_index = super_index
            data_list.append(Data(pos=pos, sub=cluster))

        # Create the NAG object
        nag = NAG(data_list)

        return nag
|