import sys
import os.path as osp
import torch
import numpy as np
from torch_scatter import scatter_sum, scatter_mean
from src.transforms import Transform
from src.data import Data, NAG, Cluster, InstanceData
from src.utils.cpu import available_cpu_count
from src.utils import xy_partition

dependencies_folder = osp.dirname(osp.dirname(osp.abspath(__file__)))
sys.path.append(dependencies_folder)
sys.path.append(osp.join(dependencies_folder, "dependencies/grid_graph/python/bin"))
sys.path.append(osp.join(dependencies_folder, "dependencies/parallel_cut_pursuit/python/wrappers"))

# These compiled extensions are only importable once the paths above have
# been added to `sys.path`
from grid_graph import edge_list_to_forward_star
from cp_d0_dist import cp_d0_dist


__all__ = ['CutPursuitPartition', 'GridPartition']


def _scatter_sum_to_cpu(src, index):
    """Sum the rows of `src` sharing the same value in `index` and return
    the result on CPU.

    The reduction is run on GPU when CUDA is available (this module used
    to do so unconditionally) and falls back to CPU otherwise, so the
    partition no longer crashes on CPU-only machines.

    :param src: Tensor
        Values to aggregate, indexed along dim 0
    :param index: Tensor
        Target row of each row of `src`
    :return: Tensor on CPU
    """
    if torch.cuda.is_available():
        out = scatter_sum(src.cuda(), index.cuda(), dim=0).cpu()
        # Release the cached blocks held by the temporary GPU copies
        torch.cuda.empty_cache()
        return out
    return scatter_sum(src.cpu(), index.cpu(), dim=0)


class CutPursuitPartition(Transform):
    """Partition a graph contained in a `Data` object using cut-pursuit.

    The input `Data` object is assumed to hold the following attributes:
      - `pos` carrying node spatial coordinates
      - `x` carrying node features
      - `edge_index` carrying the adjacency graph edges in Pytorch
        Geometric format (typically generated with `AdjacencyGraph`)
      - `edge_attr` carrying the scalar edge weights in Pytorch
        Geometric format (typically generated with `AdjacencyGraph`)

    The quality of a partition may be assessed in terms of efficiency
    (how much it simplifies the input graph) and accuracy (how well it
    respects the semantic boundaries). We provide two tools for
    assessing these: `NAG.level_ratios`, which computes the ratio of the
    number of elements between successive partition levels, and
    `Data.semantic_segmentation_oracle()`, which computes the semantic
    segmentation metrics of a hypothetical oracle model capable of
    predicting the majority label for each superpoint. See our
    Superpoint Transformer tutorial
    `notebooks/superpoint_transformer_tutorial.ipynb` for more on this.

    :param regularization: float or List(float)
        Regularization strength used for each partition level. This is
        the primary parameter for adjusting cut-pursuit partitions. The
        larger the regularization, the coarser the partition: fewer,
        bigger superpoints with lower semantic purity (i.e. superpoints
        are more likely to bleed across semantic object boundaries), and
        vice versa. If a list is passed, the values are assumed to be
        increasing, and one partition level is computed per value
    :param spatial_weight: float or List(float)
        Weight used to mitigate the impact of the point position in the
        partition. The smaller, the less spatial coordinates matter.
        This can be loosely interpreted as the inverse of a maximum
        superpoint radius. It typically affects the size of superpoints
        in geometrically/radiometrically homogeneous regions such as the
        ground, walls, or ceilings. Setting a large `spatial_weight`
        will have a "voronoi tessellation" effect on the superpoint
        partition, preventing too-large superpoints from being
        constructed in these otherwise-homogeneous regions. Conversely,
        setting a small `spatial_weight` will encourage cut-pursuit to
        create superpoints as large as possible, so long as the features
        of the points inside are homogeneous. In an extreme case, the
        entire floor would then be a single superpoint. If a list is
        passed, it must match the length of `regularization`
    :param cutoff: int or List(int)
        Minimum number of points in each superpoint. The output
        partition will not contain any superpoint smaller than
        `cutoff`. If a list is passed, it must match the length of
        `regularization`
    :param parallel: bool
        Whether cut-pursuit should run in parallel (i.e. on multiple CPU
        threads)
    :param iterations: int
        Maximum number of iterations for the cut-pursuit algorithm. The
        higher, the longer the processing. A value in $[10, 15]$ is
        usually sufficient
    :param k_adjacency: int
        When a node is isolated after a partition, we connect it to its
        nearest nodes. This sets the number of neighbors it should be
        connected to
    :param verbose: bool
        Whether to print progress information, per-level cut-pursuit
        timings, and the uint32 edge-count warning
    """

    _IN_TYPE = Data
    _OUT_TYPE = NAG
    # Largest value representable by cut-pursuit's `uint32` `index_t`
    _MAX_NUM_EDGES = 4294967295
    _NO_REPR = ['verbose', 'parallel']

    def __init__(
            self,
            regularization=5e-2,
            spatial_weight=1,
            cutoff=10,
            parallel=True,
            iterations=10,
            k_adjacency=5,
            verbose=False):
        self.regularization = regularization
        self.spatial_weight = spatial_weight
        self.cutoff = cutoff
        self.parallel = parallel
        self.iterations = iterations
        self.k_adjacency = k_adjacency
        self.verbose = verbose

    def _process(self, data):
        """Run one cut-pursuit partition per regularization level and
        stack the resulting levels into a `NAG`.

        :param data: Data
            Level-0 graph carrying `pos`, `edge_index` and optionally `x`
            and `edge_attr`
        :return: NAG
        """
        # Sanity checks
        assert data.has_edges, \
            "Cannot compute partition, no edges in Data"
        assert data.num_nodes < np.iinfo(np.uint32).max, \
            "Too many nodes for `uint32` indices"
        assert data.num_edges < np.iinfo(np.uint32).max, \
            "Too many edges for `uint32` indices"
        assert isinstance(self.regularization, (int, float, list)), \
            "Expected a scalar or a List"
        assert isinstance(self.cutoff, (int, list)), \
            "Expected an int or a List"
        assert isinstance(self.spatial_weight, (int, float, list)), \
            "Expected a scalar or a List"

        # Trim the graph
        # TODO: calling this on the level-0 adjacency graph is a bit sluggish
        #  but still saves partition time overall. May be worth finding a
        #  quick way of removing self loops and redundant edges...
        data = data.to_trimmed()

        # Initialize the hierarchical partition parameters. In particular,
        # prepare the output as a list of Data objects that will be stored
        # in a NAG structure
        num_threads = available_cpu_count() if self.parallel else 1
        # Level-0 points all have the same importance
        data.node_size = torch.ones(
            data.num_nodes, device=data.device, dtype=torch.long)
        data_list = [data]

        # Broadcast scalar parameters to one value per partition level
        regularization = self.regularization
        if not isinstance(regularization, list):
            regularization = [regularization]
        cutoff = self.cutoff
        if isinstance(cutoff, int):
            cutoff = [cutoff] * len(regularization)
        spatial_weight = self.spatial_weight
        if isinstance(spatial_weight, (float, int)):
            spatial_weight = [spatial_weight] * len(regularization)
        assert len(regularization) == len(cutoff) == len(spatial_weight)

        n_dim = data.pos.shape[1]
        n_feat = data.x.shape[1] if data.x is not None else 0

        # Iteratively run the partition on the previous partition level
        for level, (reg, cut, sw) in enumerate(zip(
                regularization, cutoff, spatial_weight)):

            if self.verbose:
                print(
                    f'Launching partition level={level} reg={reg}, '
                    f'cutoff={cut}')

            # Recover the Data object on which we will run the partition
            d1 = data_list[level]

            # Exit if the graph contains only one node
            if d1.num_nodes < 2:
                break

            # User warning if the number of edges exceeds uint32 limits
            if d1.edge_index.shape[1] > self._MAX_NUM_EDGES and self.verbose:
                print(
                    f"WARNING: number of edges {d1.edge_index.shape[1]} "
                    f"exceeds the uint32 limit {self._MAX_NUM_EDGES}. Please "
                    f"update the cut-pursuit source code to accept a larger "
                    f"data type for `index_t`.")

            # Convert edges to forward-star (or CSR) representation
            source_csr, target, reindex = edge_list_to_forward_star(
                d1.num_nodes, d1.edge_index.T.contiguous().cpu().numpy())
            source_csr = source_csr.astype('uint32')
            target = target.astype('uint32')
            # `reindex` maps the input edge order to the forward-star order
            edge_weights = d1.edge_attr.cpu().numpy()[reindex] * reg \
                if d1.edge_attr is not None else reg

            # Recover attributes features from Data object. Positions are
            # centered to limit floating point precision issues
            pos_offset = d1.pos.mean(dim=0)
            if d1.x is not None:
                x = torch.cat((d1.pos - pos_offset, d1.x), dim=1)
            else:
                x = d1.pos - pos_offset
            # cp_d0_dist expects column-major (n_feat, n_nodes) features
            x = np.asfortranarray(x.cpu().numpy().T)
            node_size = d1.node_size.float().cpu().numpy()
            coor_weights = np.ones(n_dim + n_feat, dtype=np.float32)
            coor_weights[:n_dim] *= sw

            # Partition computation
            super_index, x_c, cluster, edges, times = cp_d0_dist(
                n_dim + n_feat,
                x,
                source_csr,
                target,
                edge_weights=edge_weights,
                vert_weights=node_size,
                coor_weights=coor_weights,
                min_comp_weight=cut,
                cp_dif_tol=1e-2,
                cp_it_max=self.iterations,
                split_damp_ratio=0.7,
                verbose=self.verbose,
                max_num_threads=num_threads,
                balance_parallel_split=True,
                compute_Time=True,
                compute_List=True,
                compute_Graph=True)

            if self.verbose:
                delta_t = (times[1:] - times[:-1]).round(2)
                print(f'Level {level} iteration times: {delta_t}')
                print(f'partition {level} done')

            # Save the super_index for the i-level. This mutates the Data
            # object already stored in `data_list`
            super_index = torch.from_numpy(super_index.astype('int64'))
            d1.super_index = super_index

            # Save cluster information in another Data object. Convert
            # cluster-to-point indices in a CSR format
            size = torch.LongTensor([c.shape[0] for c in cluster])
            pointer = torch.cat([torch.LongTensor([0]), size.cumsum(dim=0)])
            value = torch.cat([
                torch.from_numpy(c.astype('int64')) for c in cluster])
            pos = torch.from_numpy(x_c[:n_dim].T) + pos_offset.cpu()
            x = torch.from_numpy(x_c[n_dim:].T)

            # The reduced graph comes back in forward-star format:
            # `edges[0]` holds the per-source offsets, `edges[1]` the
            # targets and `edges[2]` the (regularization-scaled) weights
            s = torch.arange(edges[0].shape[0] - 1).repeat_interleave(
                torch.from_numpy((edges[0][1:] - edges[0][:-1]).astype("int64")))
            t = torch.from_numpy(edges[1].astype("int64"))
            edge_index = torch.vstack((s, t))
            edge_attr = torch.from_numpy(edges[2] / reg)

            # Each superpoint's size is the total size of its children
            node_size_new = _scatter_sum_to_cpu(
                torch.from_numpy(node_size), super_index).long()

            d2 = Data(
                pos=pos,
                x=x,
                edge_index=edge_index,
                edge_attr=edge_attr,
                sub=Cluster(pointer, value),
                node_size=node_size_new)

            # Merge the lower level's instance annotations, if any
            if d1.obj is not None and isinstance(d1.obj, InstanceData):
                d2.obj = d1.obj.merge(d1.super_index)

            # Trim the graph
            d2 = d2.to_trimmed()

            # If some nodes are isolated in the graph, connect them to
            # their nearest neighbors, so their absence of connectivity
            # does not "pollute" higher levels of partition
            if d2.num_nodes > 1:
                d2 = d2.connect_isolated(k=self.k_adjacency)

            # Aggregate some point attributes into the clusters. This
            # is not performed dynamically since not all attributes can
            # be aggregated (e.g. 'neighbor_index', 'neighbor_distance',
            # 'edge_index', 'edge_attr'...)
            if 'y' in d1.keys:
                assert d1.y.dim() == 2, \
                    "Expected Data.y to hold `(num_nodes, num_classes)` " \
                    "histograms, not single labels"
                d2.y = _scatter_sum_to_cpu(d1.y, d1.super_index)

            if 'semantic_pred' in d1.keys:
                assert d1.semantic_pred.dim() == 2, \
                    "Expected Data.semantic_pred to hold `(num_nodes, num_classes)` " \
                    "histograms, not single labels"
                d2.semantic_pred = _scatter_sum_to_cpu(
                    d1.semantic_pred, d1.super_index)

            # TODO: aggregate other attributes ?

            # TODO: if scatter operations are bottleneck, use scatter_csr

            # Add the l+1-level Data object to data_list
            data_list.append(d2)

            if self.verbose:
                print('\n' + '-' * 64 + '\n')

        # Create the NAG object
        nag = NAG(data_list)

        return nag


class GridPartition(Transform):
    """XY-grid-based hierarchical partition of Data. The nodes are
    aggregated based on their coordinates in a grid of step `size`.

    :param size: int, float or List
        Step of the XY grid used for each partition level. If a list is
        passed, one partition level is computed per value, each level
        being built on top of the previous one
    """

    _IN_TYPE = Data
    _OUT_TYPE = NAG

    def __init__(self, size=2):
        self.size = size

    def _process(self, data):
        """Build one XY-grid partition level per grid step and stack the
        resulting levels into a `NAG`.

        :param data: Data
            Level-0 point cloud carrying `pos`
        :return: NAG
        """
        # Sanity checks
        assert data.num_nodes < np.iinfo(np.uint32).max, \
            "Too many nodes for `uint32` indices"
        assert data.num_edges < np.iinfo(np.uint32).max, \
            "Too many edges for `uint32` indices"
        assert isinstance(self.size, (int, float, list)), \
            "Expected a scalar or a List"

        # Initialize the partition data
        sizes = self.size if isinstance(self.size, list) else [self.size]
        data_list = [data]

        # XY-grid partitions
        for step in sizes:

            # Compute a "manual" partition based on the grid coordinates.
            # The grid step must be forwarded to `xy_partition`, otherwise
            # the `size` parameter has no effect on the partition
            d = data_list[-1]
            super_index = xy_partition(d.pos, step, consecutive=True)

            # Compute the superpoint centroids and Cluster object
            pos = scatter_mean(d.pos, super_index, dim=0)
            cluster = Cluster(
                super_index, torch.arange(d.num_nodes), dense=True)

            # TODO: support more Data attributes and more advanced
            #  grouping, probably by interfacing with
            #  src.transforms.sampling._group_data()

            # Update the super_index of the previous level and create
            # the Data object for the new level
            d.super_index = super_index
            data_list.append(Data(pos=pos, sub=cluster))

        # Create the NAG object
        nag = NAG(data_list)

        return nag