import re
import torch
from torch_geometric.nn.pool import voxel_grid
from torch_geometric.utils import k_hop_subgraph, to_undirected
from torch_cluster import grid_cluster
from torch_scatter import scatter_mean
from torch_geometric.nn.pool.consecutive import consecutive_cluster
from src.utils import fast_randperm, sparse_sample, scatter_pca, sanitize_keys
from src.transforms import Transform
from src.data import Data, NAG, NAGBatch, CSRData, InstanceData, Cluster
from src.utils.histogram import atomic_to_histogram


__all__ = [
    'Shuffle', 'SaveNodeIndex', 'NAGSaveNodeIndex', 'GridSampling3D',
    'SampleXYTiling', 'SampleRecursiveMainXYAxisTiling', 'SampleSubNodes',
    'SampleKHopSubgraphs', 'SampleRadiusSubgraphs', 'SampleSegments',
    'SampleEdges', 'RestrictSize', 'NAGRestrictSize']


def _unwrap_selected(out):
    """Return the Data object out of a `Data.select()` output.

    Some call sites in this module index the output of `Data.select()`
    with `[0]` while others use it directly as a Data object. To be
    robust to both conventions, a tuple output is unwrapped to its
    first element and anything else is returned untouched.

    NOTE(review): `Data.select()` is not visible from this file, please
    confirm its return convention.
    """
    return out[0] if isinstance(out, tuple) else out


def _segment_sampling_weights(nag, i_level, by_size=False, by_class=False):
    """Compute normalized multinomial sampling weights for the nodes of
    level `i_level` of a NAG.

    All nodes start with the same unit weight. Optional size-based and
    class-based terms (each summing to 1) are added on top, and the
    result is normalized to sum to 1.

    :param nag: NAG
    :param i_level: int
        Level whose nodes are to be weighted
    :param by_size: bool
        If True, larger segments get a larger weight
    :param by_class: bool
        If True, and if the level has a `y` attribute, segments holding
        rare classes get a larger weight
    :return: 1D float Tensor with one weight per node of `i_level`
    """
    # Initialize all segments with the same weights
    weights = torch.ones(nag[i_level].num_nodes, device=nag.device)

    # Compute per-segment weights solely based on the segment size.
    # This is biased towards large segments
    if by_size:
        node_size = nag.get_sub_size(i_level, low=0)
        size_weights = node_size ** 0.333
        size_weights /= size_weights.sum()
        weights += size_weights

    # Compute per-class weights based on class frequencies in the
    # current NAG and give a weight to each segment based on the rarest
    # class it contains. This is biased towards rare classes. The
    # indexing below treats `y` as a 2D (node, class) tensor of counts
    if by_class and nag[i_level].y is not None:
        counts = nag[i_level].y.sum(dim=0).sqrt()
        scores = 1 / (counts + 1)
        scores /= scores.sum()
        mask = nag[i_level].y.gt(0)
        class_weights = (mask * scores.view(1, -1)).max(dim=1).values
        class_weights /= class_weights.sum()
        weights += class_weights.squeeze()

    # Normalize the weights again, in case size or class weights were
    # added
    weights /= weights.sum()

    return weights


class Shuffle(Transform):
    """Shuffle the order of points in a Data object."""

    def _process(self, data):
        # Random permutation of all the points
        idx = fast_randperm(data.num_points, device=data.device)

        # Downstream code (e.g. GridSampling3D) uses the output as a
        # Data object, so make sure we do not return a tuple
        return _unwrap_selected(
            data.select(idx, update_sub=False, update_super=False))


class SaveNodeIndex(Transform):
    """Adds the index of the nodes to the Data object attributes. This
    allows tracking nodes from the output back to the input Data object.

    :param key: str
        Name of the attribute under which the indices are stored.
        Defaults to `DEFAULT_KEY`
    """

    DEFAULT_KEY = 'node_id'

    def __init__(self, key=None):
        self.key = key if key is not None else self.DEFAULT_KEY

    def _process(self, data):
        # One index per row of `pos`, in the current point order
        idx = torch.arange(0, data.pos.shape[0], device=data.device)
        setattr(data, self.key, idx)
        return data


class NAGSaveNodeIndex(SaveNodeIndex):
    """SaveNodeIndex, applied to each NAG level.
    """

    _IN_TYPE = NAG
    _OUT_TYPE = NAG

    def _process(self, nag):
        transform = SaveNodeIndex(key=self.key)
        for i_level in range(nag.num_levels):
            nag._list[i_level] = transform(nag._list[i_level])
        return nag


class GridSampling3D(Transform):
    """ Clusters 3D points into voxels with size :attr:`size`.

    By default, some special keys undergo dedicated grouping mechanisms.
    The `_VOTING_KEYS=['y', 'super_index', 'is_val']` keys are grouped
    by their majority label. The `_INSTANCE_KEYS=['obj', 'obj_pred']`
    keys are grouped into an `InstanceData`, which stores all
    instance/panoptic overlap data values in CSR format. The
    `_CLUSTER_KEYS=['sub']` keys are grouped into a `Cluster` object,
    which stores indices of child elements for parent clusters in CSR
    format. The `_LAST_KEYS=['batch', SaveNodeIndex.DEFAULT_KEY]` keys
    are by default grouped following `mode='last'`.

    Besides, for keys where a more subtle histogram mechanism is needed,
    (e.g. for 'y'), the 'hist_key' and 'hist_size' arguments can be
    used.

    Modified from: https://github.com/torch-points3d/torch-points3d

    Parameters
    ----------
    size: float
        Size of a voxel (in each dimension).
    quantize_coords: bool
        If True, it will convert the points into their associated sparse
        coordinates within the grid and store the value into a new
        `coords` attribute.
    mode: string:
        The mode can be either `last` or `mean`.
        If mode is `mean`, all the points and their features within a
        cell will be averaged. If mode is `last`, one random point per
        cell will be selected with its associated features.
    hist_key: str or List(str)
        Data attributes for which we would like to aggregate values into
        an histogram. This is typically needed when we want to aggregate
        points labels without losing the distribution, as opposed to
        majority voting.
    hist_size: int or List(int)
        Must be of same size as `hist_key`, indicates the number of
        bins for each key-histogram. This is typically needed when we
        want to aggregate points labels without losing the distribution,
        as opposed to majority voting.
    inplace: bool
        Whether the input Data object should be modified in-place
    verbose: bool
        Verbosity
    """

    _NO_REPR = ['verbose', 'inplace']

    def __init__(
            self, size, quantize_coords=False, mode="mean", hist_key=None,
            hist_size=None, inplace=False, verbose=False):

        # Normalize the histogram arguments to lists of same length
        hist_key = [] if hist_key is None else hist_key
        hist_size = [] if hist_size is None else hist_size
        hist_key = [hist_key] if isinstance(hist_key, str) else hist_key
        hist_size = [hist_size] if isinstance(hist_size, int) else hist_size

        assert isinstance(hist_key, list)
        assert isinstance(hist_size, list)
        assert len(hist_key) == len(hist_size)

        self.grid_size = size
        self.quantize_coords = quantize_coords
        self.mode = mode
        self.bins = {k: v for k, v in zip(hist_key, hist_size)}
        self.inplace = inplace

        if verbose:
            print(
                f"If you need to keep track of the position of your points, "
                f"use SaveNodeIndex transform before using "
                f"{self.__class__.__name__}.")

            if self.mode == "last":
                print(
                    "The tensors within data will be shuffled each time this "
                    "transform is applied. Be careful that if an attribute "
                    "doesn't have the size of num_nodes, it won't be shuffled")

    def _process(self, data_in):
        # In-place option will modify the input Data object directly
        data = data_in if self.inplace else data_in.clone()

        # If the aggregation mode is 'last', shuffle the points order.
        # Note that voxelization of point attributes will be stochastic
        if self.mode == 'last':
            data = Shuffle()(data)

        # Convert point coordinates to the voxel grid coordinates
        coords = torch.round((data.pos) / self.grid_size)

        # Match each point with a voxel identifier
        if 'batch' not in data:
            cluster = grid_cluster(coords, torch.ones(3, device=coords.device))
        else:
            cluster = voxel_grid(coords, data.batch, 1)

        # Reindex the clusters to make sure the indices used are
        # consecutive. Basically, we do not want cluster indices to span
        # [0, i_max] without all in-between indices to be used, because
        # this will affect the speed and output size of torch_scatter
        # operations
        cluster, unique_pos_indices = consecutive_cluster(cluster)

        # Perform voxel aggregation
        data = _group_data(
            data, cluster, unique_pos_indices, mode=self.mode, bins=self.bins)

        # Optionally quantize the coordinates. This is useful for
        # sparse convolution models
        if self.quantize_coords:
            data.coords = coords[unique_pos_indices].int()

        # Save the grid size in the Data attributes
        data.grid_size = torch.tensor([self.grid_size])

        return data


def _group_data(
        data, cluster=None, unique_pos_indices=None, mode="mean",
        skip_keys=None, bins=None):
    """Group data based on indices in cluster. The option ``mode``
    controls how data gets aggregated within each cluster.

    By default, some special keys undergo dedicated grouping mechanisms.
    The `_VOTING_KEYS=['y', 'super_index', 'is_val']` keys are grouped
    by their majority label. The `_INSTANCE_KEYS=['obj', 'obj_pred']`
    keys are grouped into an `InstanceData`, which stores all
    instance/panoptic overlap data values in CSR format. The
    `_CLUSTER_KEYS=['sub']` keys are grouped into a `Cluster` object,
    which stores indices of child elements for parent clusters in CSR
    format. The `_LAST_KEYS=['batch', SaveNodeIndex.DEFAULT_KEY]` keys
    are by default grouped following `mode='last'`.

    Besides, for keys where a more subtle histogram mechanism is needed,
    (e.g. for 'y'), the 'bins' argument can be used.

    Warning: this function modifies the input Data object in-place.

    :param data : Data
    :param cluster : Tensor
        Tensor of the same size as the number of points in data. Each
        element is the cluster index of that point.
    :param unique_pos_indices : Tensor
        Tensor containing one index per cluster, this index will be used
        to select features and labels.
    :param mode : str
        Option to select how the features and labels for each voxel is
        computed. Can be ``last`` or ``mean``. ``last`` selects the last
        point falling in a voxel as the representative, ``mean`` takes
        the average.
    :param skip_keys: list
        Keys of attributes to skip in the grouping.
    :param bins: dict
        Dictionary holding ``{'key': n_bins}`` where ``key`` is a Data
        attribute for which we would like to aggregate values into an
        histogram and ``n_bins`` accounts for the corresponding number
        of bins. This is typically needed when we want to aggregate
        point labels without losing the distribution, as opposed to
        majority voting. ``None`` is treated as an empty dictionary.
    :return: the input Data, with aggregated attributes
    :raises ValueError: if ``cluster`` is missing in ``mean`` mode or
        ``unique_pos_indices`` is missing in ``last`` mode
    :raises NotImplementedError: if the Data holds an attribute whose
        key contains 'edge', a `_CLUSTER_KEYS` attribute which is not a
        1D integer Tensor, or any other `CSRData` attribute
    """
    skip_keys = sanitize_keys(skip_keys, default=[])

    # Avoid sharing a mutable default dictionary across calls
    bins = {} if bins is None else bins

    # Keys for which voxel aggregation will be based on majority voting
    _VOTING_KEYS = ['y', 'super_index', 'is_val']

    # Keys for which voxel aggregation will use an InstanceData object,
    # which store all input information in CSR format
    _INSTANCE_KEYS = ['obj', 'obj_pred']

    # Keys for which voxel aggregation will use a Cluster object, which
    # store all input information in CSR format
    _CLUSTER_KEYS = ['sub']

    # Keys for which voxel aggregation will always pick the value of a
    # single point per voxel, regardless of the mode
    _LAST_KEYS = ['batch', SaveNodeIndex.DEFAULT_KEY]

    # Keys to be treated as normal vectors, for which the unit-norm must
    # be preserved
    _NORMAL_KEYS = ['normal']

    # Supported mode for aggregation
    _MODES = ['mean', 'last']
    assert mode in _MODES
    if mode == "mean" and cluster is None:
        raise ValueError(
            "In mean mode the cluster argument needs to be specified")
    if mode == "last" and unique_pos_indices is None:
        raise ValueError(
            "In last mode the unique_pos_indices argument needs to be specified")

    # Save the number of nodes here because the subsequent in-place
    # modifications will affect it
    num_nodes = data.num_nodes

    # Aggregate Data attributes for same-cluster points
    for key, item in data:

        # `skip_keys` are not aggregated
        if key in skip_keys:
            continue

        # Edges cannot be aggregated
        if bool(re.search('edge', key)):
            raise NotImplementedError("Edges not supported. Wrong data type.")

        # For instance labels grouped into an InstanceData. Supports
        # input instance labels either as InstanceData or as a simple
        # index tensor
        if key in _INSTANCE_KEYS:
            if isinstance(item, InstanceData):
                data[key] = item.merge(cluster)
            else:
                count = torch.ones_like(item)
                y = data.y if getattr(data, 'y', None) is not None \
                    else torch.zeros_like(item)
                data[key] = InstanceData(cluster, item, count, y, dense=True)
            continue

        # For point indices to be grouped in Cluster. This allows
        # backtracking full-resolution point indices to the voxels
        if key in _CLUSTER_KEYS:
            if (isinstance(item, torch.Tensor) and item.dim() == 1
                    and not item.is_floating_point()):
                data[key] = Cluster(cluster, item, dense=True)
            else:
                raise NotImplementedError(
                    f"Cannot merge '{key}' with data type: {type(item)} into "
                    f"a Cluster object. Only supports 1D Tensor of integers.")
            continue

        # TODO: adapt to make use of CSRData batching ?
        if isinstance(item, CSRData):
            raise NotImplementedError(
                f"Cannot merge '{key}' with data type: {type(item)}")

        # Only torch.Tensor attributes of size Data.num_nodes are
        # considered for aggregation
        if not torch.is_tensor(item) or item.size(0) != num_nodes:
            continue

        # For 'last' mode, use unique_pos_indices to pick values
        # from a single point within each cluster. The same behavior
        # is expected for the _LAST_KEYS
        if mode == 'last' or key in _LAST_KEYS:
            data[key] = item[unique_pos_indices]
            continue

        # For 'mean' mode, the attributes will be aggregated
        # depending on their nature.

        # If the attribute is a boolean, temporarily convert to integer
        # to facilitate aggregation
        is_item_bool = item.dtype == torch.bool
        if is_item_bool:
            item = item.int()

        # For keys requiring a voting scheme or a histogram. Keys listed
        # in `bins` keep the full histogram, the other ones only keep
        # the majority bin
        if key in _VOTING_KEYS or key in bins.keys():
            voting = key not in bins.keys()
            n_bins = item.max() + 1 if voting else bins[key]
            hist = atomic_to_histogram(item, cluster, n_bins=n_bins)
            data[key] = hist.argmax(dim=-1) if voting else hist

        # Standard behavior, where attributes are simply
        # averaged across the clusters
        else:
            data[key] = scatter_mean(item, cluster, dim=0)

        # For normals, make sure to re-normalize the mean-normal
        if key in _NORMAL_KEYS:
            data[key] = data[key] / data[key].norm(dim=1).view(-1, 1)

        # Convert back to boolean if need be
        if is_item_bool:
            data[key] = data[key].bool()

    return data


class SampleXYTiling(Transform):
    """Tile the input Data along the XY axes and select only a given
    tile. This is useful to reduce the size of very large clouds at
    preprocessing time.

    Points lying exactly on the upper XY boundary of the cloud are
    assigned to the last tile of the corresponding axis.

    :param x: int
        x coordinate of the sample in the tiling grid
    :param y: int
        y coordinate of the sample in the tiling grid
    :param tiling: int or tuple(int, int)
        Number of tiles in the grid in each direction. If a tuple is
        passed, each direction can be tiled independently
    """

    def __init__(self, x=0, y=0, tiling=2):
        tiling = (tiling, tiling) if isinstance(tiling, int) else tiling
        assert 0 <= x < tiling[0]
        assert 0 <= y < tiling[1]
        self.tiling = torch.as_tensor(tiling)
        self.x = x
        self.y = y

    def _process(self, data):
        # Shift the XY coordinates so that the cloud starts at the
        # origin
        xy = data.pos[:, :2].clone().view(-1, 2)
        xy -= xy.min(dim=0).values.view(1, 2)

        # Normalize the coordinates to [0, 1]. An axis with zero extent
        # would produce 0/0=NaN, so we divide by 1 instead, which sends
        # all points to the first tile along that axis
        extent = xy.max(dim=0).values.view(1, 2)
        extent = torch.where(extent > 0, extent, torch.ones_like(extent))
        xy /= extent

        # Compute the xy coordinates in the tiling grid, for each point.
        # The tiling tensor is created on CPU at construction time, so
        # it must follow the device of the data
        tiling = self.tiling.to(xy.device).view(1, 2)
        xy = (xy.clip(min=0, max=1) * tiling).long()

        # Points with a normalized coordinate of exactly 1 get the
        # out-of-range index `tiling`. Clamp them into the last tile so
        # they are not dropped from all tiles
        xy = torch.minimum(xy, tiling.long() - 1)

        # Select only the points in the desired tile
        idx = torch.where((xy[:, 0] == self.x) & (xy[:, 1] == self.y))[0]

        return data.select(idx)[0]


class SampleRecursiveMainXYAxisTiling(Transform):
    """Tile the input Data by recursively splitting the points along
    their principal XY direction and select only a given tile. This is
    useful to reduce the size of very large clouds at preprocessing
    time, when clouds are not XY-aligned or have non-trivial geometries.

    :param x: int
        x coordinate of the sample in the tiling structure. The tiles
        are "lexicographically" ordered, with the points lying below the
        median of each split considered before those above the median
    :param steps: int
        Number of splitting steps. By construction, the total number of
        tiles is 2**steps
    """

    def __init__(self, x=0, steps=2):
        assert 0 <= x < 2 ** steps
        self.steps = steps
        self.x = x

    def _process(self, data):
        # Nothing to do if less than 1 step required
        if self.steps <= 0:
            return data

        # Recursively split the data, keeping the half indicated by the
        # binary path at each step
        for p in self.binary_tree_path:
            data = self.split_by_main_xy_direction(data, left=not p, right=p)

        return data

    @property
    def binary_tree_path(self):
        """List of `steps` booleans describing the path to tile `x` in
        the binary splitting tree. False means 'below the median' and
        True means 'above the median'.
        """
        # Converting x to a binary number gives the solution ! Prepend
        # with zeros to build a path of length steps
        path = bin(self.x)[2:].zfill(self.steps)

        # Convert string of 0 and 1 to list of booleans
        return [bool(int(i)) for i in path]

    @staticmethod
    def split_by_main_xy_direction(data, left=True, right=True):
        """Split the data around the median of its projection on the
        main XY direction.

        :param data: Data
        :param left: bool
            Whether the points below the median should be returned
        :param right: bool
            Whether the points at or above the median should be returned
        :return: a single Data if only one of `left` or `right` is
            required, a `(left, right)` tuple of Data otherwise
        """
        assert left or right, "At least one split must be returned"

        # Find the main XY direction and orient it along the x+
        # halfspace, for repeatability
        v = SampleRecursiveMainXYAxisTiling.compute_main_xy_direction(data)
        if v[0] < 0:
            v *= -1

        # Project points along this direction and split around the
        # median
        proj = (data.pos[:, :2] * v.view(1, -1)).sum(dim=1)
        mask = proj < proj.median()

        if left and not right:
            return data.select(mask)[0]
        if right and not left:
            return data.select(~mask)[0]
        return data.select(mask)[0], data.select(~mask)[0]

    @staticmethod
    def compute_main_xy_direction(data):
        """Compute the first principal direction of the points in the
        XY plane, from a voxelized copy of the positions.
        """
        # Work on local copy, the input positions are left untouched
        data = Data(pos=data.pos.clone())

        # Compute a voxel size to aggressively sample the data. Note
        # that `xy` is a view, so the shift is applied to `data.pos`
        xy = data.pos[:, :2]
        xy -= xy.min(dim=0).values.view(1, -1)
        voxel = xy.max() / 100

        # Set Z to 0, we only want to compute the principal components
        # in XY
        data.pos[:, 2] = 0

        # Voxelize
        data = GridSampling3D(size=voxel)(data)

        # Search first principal component, treating all the voxels as
        # a single group
        idx = torch.zeros_like(data.pos[:, 0], dtype=torch.long)
        v = scatter_pca(data.pos, idx, on_cpu=True)[1][0][:2, -1]

        return v


class SampleSubNodes(Transform):
    """Sample elements at `low`-level, based on which segment they
    belong to at `high`-level.

    The sampling operation is run without replacement and each segment
    is sampled at least `n_min` and at most `n_max` times, within the
    limits allowed by its actual size.

    Optionally, a `mask` can be passed to filter out some `low`-level
    points.

    :param high: int
        Partition level of the segments we want to sample. By default,
        `high=1` to sample the level-1 segments
    :param low: int
        Partition level we will sample from, guided by the `high`
        segments. By default, `low=0` to sample the level-0 points.
        `low=-1` is accepted when level-0 has a `sub` attribute (i.e.
        level-0 points are themselves segments of `-1` level absent
        from the NAG object).
    :param n_max: int
        Maximum number of `low`-level elements to sample in each
        `high`-level segment
    :param n_min: int
        Minimum number of `low`-level elements to sample in each
        `high`-level segment, within the limits of its size (i.e. no
        oversampling)
    :param mask: list, np.ndarray, torch.LongTensor, torch.BoolTensor
        Indicates a subset of `low`-level elements to consider
    """

    _IN_TYPE = NAG
    _OUT_TYPE = NAG

    def __init__(
            self, high=1, low=0, n_max=32, n_min=16, mask=None):
        assert isinstance(high, int)
        assert isinstance(low, int)
        assert isinstance(n_max, int)
        assert isinstance(n_min, int)
        self.high = high
        self.low = low
        self.n_max = n_max
        self.n_min = n_min
        self.mask = mask

    def _process(self, nag):
        # NOTE(review): `self.mask` is stored but not forwarded to
        # `NAG.get_sampling()`, so it currently has no effect here.
        # Confirm whether `get_sampling` accepts a mask before wiring it
        idx = nag.get_sampling(
            high=self.high, low=self.low, n_max=self.n_max, n_min=self.n_min,
            return_pointers=False)
        return nag.select(self.low, idx)


class SampleSegments(Transform):
    """Remove randomly-picked nodes from each level 1+ of the NAG. This
    operation relies on `NAG.select()` to maintain index consistency
    across the NAG levels.

    Note: we do not directly prune level-0 points, see `SampleSubNodes`
    for that. For speed consideration, it is recommended to use
    `SampleSubNodes` first before `SampleSegments`, to minimize the
    number of level-0 points to manipulate.

    :param ratio: float or list(float)
        Portion of nodes to be dropped, in `[0, 1)`. A list may be
        passed to prune NAG 1+ levels with different probabilities, in
        which case `ratio[i - 1]` is used for level `i`
    :param by_size: bool
        If True, the segment size will affect the chances of being
        dropped out. The smaller the segment, the greater its chances
        to be dropped
    :param by_class: bool
        If True, the classes will affect the chances of being
        dropped out. The more frequent the segment class, the greater
        its chances to be dropped
    """

    _IN_TYPE = NAG
    _OUT_TYPE = NAG

    def __init__(self, ratio=0.2, by_size=False, by_class=False):
        # Validate lists and scalars separately, so that an invalid
        # list triggers an AssertionError rather than a TypeError from
        # comparing a list with a number
        if isinstance(ratio, list):
            assert all(0 <= r < 1 for r in ratio)
        else:
            assert 0 <= ratio < 1
        self.ratio = ratio
        self.by_size = by_size
        self.by_class = by_class

    def _process(self, nag):
        if not isinstance(self.ratio, list):
            ratio = [self.ratio] * (nag.num_levels - 1)
        else:
            ratio = self.ratio

        # Drop some nodes from each NAG level. Note that we start
        # dropping from the highest to the lowest level, to accelerate
        # sampling
        for i_level in range(nag.num_levels - 1, 0, -1):

            # Non-positive ratios prevent dropout
            if ratio[i_level - 1] <= 0:
                continue

            # Prepare sampling
            num_nodes = nag[i_level].num_nodes
            num_keep = num_nodes - int(num_nodes * ratio[i_level - 1])

            # Compute the per-segment chances of being preserved
            weights = _segment_sampling_weights(
                nag, i_level, by_size=self.by_size, by_class=self.by_class)

            # Generate sampling indices
            idx = torch.multinomial(weights, num_keep, replacement=False)

            # Select the nodes and update the NAG structure accordingly
            nag = nag.select(i_level, idx)

        return nag


class BaseSampleSubgraphs(Transform):
    """Base class for sampling subgraphs from a NAG. It randomly picks
    `k` seed nodes from `i_level`, from which `k` subgraphs can be
    grown. Child classes must implement `_sample_subgraphs()` to
    describe how these subgraphs are built. Optionally, the seed
    sampling can be driven by their class, or their size, using
    `by_class` and `by_size`, respectively.

    This operation relies on `NAG.select()` to maintain index
    consistency across the NAG levels.

    :param i_level: int
        Partition level we want to pick from. Defaults to 1. Any value
        outside of `[0, num_levels)` (e.g. `i_level=-1`) will sample
        the highest level of the input NAG. `None` turns the transform
        into an identity
    :param k: int
        Number of sub-graphs/seeds to pick. If `k` is not smaller than
        the number of nodes at `i_level`, a single seed is picked.
        `k<=0` turns the transform into an identity
    :param by_size: bool
        If True, the segment size will affect the chances of being
        selected as a seed. The larger the segment, the greater its
        chances to be picked
    :param by_class: bool
        If True, the classes will affect the chances of being
        selected as a seed. The scarcer the segment class, the greater
        its chances to be selected
    :param use_batch: bool
        If True, the 'Data.batch' attribute will be used to guide seed
        sampling across batches. More specifically, if the input NAG is
        a NAGBatch made up of multiple NAGs, the subgraphs will be
        sampled in a way that guarantees each NAG is sampled from.
        Obviously enough, if `k < batch.max() + 1`, not all NAGs will be
        sampled from
    :param disjoint: bool
        If True, subgraphs sampled from the same NAG will be separated
        as distinct NAGs themselves. Instead, when `disjoint=False`,
        subgraphs sampled in the same NAG will belong to the same NAG.
        Hence, if two subgraphs share a node, they will be connected
    """

    _IN_TYPE = NAG
    _OUT_TYPE = NAG

    def __init__(
            self, i_level=1, k=1, by_size=False, by_class=False,
            use_batch=True, disjoint=True):
        self.i_level = i_level
        self.k = k
        self.by_size = by_size
        self.by_class = by_class
        self.use_batch = use_batch
        self.disjoint = disjoint

    def _process(self, nag):
        # Skip if i_level is None or k<=0. This may be useful to turn
        # this transform into an Identity, if need be
        if self.i_level is None or self.k <= 0:
            return nag

        # Initialization
        i_level = self.i_level if 0 <= self.i_level < nag.num_levels \
            else nag.num_levels - 1
        k = self.k if self.k < nag[i_level].num_nodes \
            else 1

        # Compute the per-segment chances of being picked as a seed
        weights = _segment_sampling_weights(
            nag, i_level, by_size=self.by_size, by_class=self.by_class)

        # Generate sampling indices. If the Data object has a 'batch'
        # attribute and 'self.use_batch', use it to guide the sampling
        # across the batches
        batch = getattr(nag[i_level], 'batch', None)
        if batch is not None and self.use_batch:
            idx_list = []
            num_batch = batch.max() + 1
            num_sampled = 0
            k_batch = torch.div(k, num_batch, rounding_mode='floor')
            k_batch = k_batch.maximum(torch.ones_like(k_batch))
            for i_batch in range(num_batch):

                # Try to sample all NAGs in the batch as evenly as
                # possible, within the constraints of k and
                # num_batch. The last NAG gets all remaining seeds
                if i_batch >= num_batch - 1:
                    k_batch = k - num_sampled

                # Compute the sampling indices for the NAG at hand
                mask = torch.where(i_batch == batch)[0]
                idx_ = torch.multinomial(
                    weights[mask], k_batch, replacement=False)
                idx_list.append(mask[idx_])

                # Update number of sampled subgraphs
                num_sampled += k_batch
                if num_sampled >= k:
                    break

            # Aggregate sampling indices
            idx = torch.cat(idx_list)
        else:
            idx = torch.multinomial(weights, k, replacement=False)

        # Sample the NAG and allow subgraphs sharing the same nodes to
        # be connected
        if not self.disjoint:
            return self._sample_subgraphs(nag, i_level, idx)

        # All sampled subgraphs are disjoint
        return NAGBatch.from_nag_list([
            self._sample_subgraphs(nag, i_level, i.view(1))
            for i in idx])

    def _sample_subgraphs(self, nag, i_level, idx):
        """Build the sub-NAG grown from the seed nodes `idx` of
        `i_level`. To be implemented by child classes.
        """
        raise NotImplementedError


class SampleKHopSubgraphs(BaseSampleSubgraphs):
    """Randomly pick segments from `i_level`, along with their `hops`
    neighbors. This can be thought as a spherical sampling in the graph
    of i_level.

    This operation relies on `NAG.select()` to maintain index
    consistency across the NAG levels.

    Note: we do not directly sample level-0 points, see `SampleSubNodes`
    for that. For speed consideration, it is recommended to use
    `SampleSubNodes` first before `SampleKHopSubgraphs`, to minimize the
    number of level-0 points to manipulate.

    :param hops: int
        Number of hops ruling the neighborhood size selected around the
        seed nodes
    :param i_level: int
        Partition level we want to pick from. Defaults to 1. Any value
        outside of `[0, num_levels)` (e.g. `i_level=-1`) will sample
        the highest level of the input NAG
    :param k: int
        Number of sub-graphs/seeds to pick
    :param by_size: bool
        If True, the segment size will affect the chances of being
        selected as a seed. The larger the segment, the greater its
        chances to be picked
    :param by_class: bool
        If True, the classes will affect the chances of being
        selected as a seed. The scarcer the segment class, the greater
        its chances to be selected
    :param use_batch: bool
        If True, the 'Data.batch' attribute will be used to guide seed
        sampling across batches. More specifically, if the input NAG is
        a NAGBatch made up of multiple NAGs, the subgraphs will be
        sampled in a way that guarantees each NAG is sampled from.
        Obviously enough, if `k < batch.max() + 1`, not all NAGs will be
        sampled from
    :param disjoint: bool
        If True, subgraphs sampled from the same NAG will be separated
        as distinct NAGs themselves. Instead, when `disjoint=False`,
        subgraphs sampled in the same NAG will belong to the same NAG.
        Hence, if two subgraphs share a node, they will be connected
    """

    def __init__(
            self, hops=2, i_level=1, k=1, by_size=False, by_class=False,
            use_batch=True, disjoint=False):
        super().__init__(
            i_level=i_level, k=k, by_size=by_size, by_class=by_class,
            use_batch=use_batch, disjoint=disjoint)
        self.hops = hops

    def _sample_subgraphs(self, nag, i_level, idx):
        assert nag[i_level].has_edges, \
            "Expected Data object to have edges for k-hop subgraph sampling"

        # Convert the graph to undirected graph. This is needed because
        # it is likely that the graph has been trimmed (see
        # `src.utils.to_trimmed`), in which case the trimmed edge
        # direction would affect the k-hop search
        edge_index = to_undirected(nag[i_level].edge_index)

        # Search the k-hop neighbors of the sampled nodes
        idx = k_hop_subgraph(
            idx, self.hops, edge_index, num_nodes=nag[i_level].num_nodes)[0]

        # Select the nodes and update the NAG structure accordingly
        return nag.select(i_level, idx)


class SampleRadiusSubgraphs(BaseSampleSubgraphs):
    """Randomly pick segments from `i_level`, along with their
    spherical neighborhood of fixed radius.

    This operation relies on `NAG.select()` to maintain index
    consistency across the NAG levels.

    Note: we do not directly sample level-0 points, see `SampleSubNodes`
    for that. For speed consideration, it is recommended to use
    `SampleSubNodes` first before `SampleRadiusSubgraphs`, to minimize
    the number of level-0 points to manipulate.

    :param r: float
        Radius for spherical sampling. `r<=0` turns the subgraph
        sampling into an identity
    :param i_level: int
        Partition level we want to pick from. Defaults to 1. Any value
        outside of `[0, num_levels)` (e.g. `i_level=-1`) will sample
        the highest level of the input NAG
    :param k: int
        Number of sub-graphs/seeds to pick
    :param by_size: bool
        If True, the segment size will affect the chances of being
        selected as a seed. The larger the segment, the greater its
        chances to be picked
    :param by_class: bool
        If True, the classes will affect the chances of being
        selected as a seed. The scarcer the segment class, the greater
        its chances to be selected
    :param use_batch: bool
        If True, the 'Data.batch' attribute will be used to guide seed
        sampling across batches. More specifically, if the input NAG is
        a NAGBatch made up of multiple NAGs, the subgraphs will be
        sampled in a way that guarantees each NAG is sampled from.
        Obviously enough, if `k < batch.max() + 1`, not all NAGs will be
        sampled from
    :param disjoint: bool
        If True, subgraphs sampled from the same NAG will be separated
        as distinct NAGs themselves. Instead, when `disjoint=False`,
        subgraphs sampled in the same NAG will belong to the same NAG.
        Hence, if two subgraphs share a node, they will be connected
    """

    def __init__(
            self, r=2, i_level=1, k=1, by_size=False, by_class=False,
            use_batch=True, disjoint=False):
        super().__init__(
            i_level=i_level, k=k, by_size=by_size, by_class=by_class,
            use_batch=use_batch, disjoint=disjoint)
        self.r = r

    def _sample_subgraphs(self, nag, i_level, idx):
        # Skip if r<=0. This may be useful to turn this transform into
        # an Identity, if need be
        if self.r <= 0:
            return nag

        # Neighbors are searched using the node coordinates. This is not
        # the optimal search for cluster-cluster distances, but is the
        # fastest for our needs here. If need be, one could make this
        # search more accurate using something like:
        # `src.utils.neighbors.cluster_radius_nn_graph`

        # TODO: searching using knn_2 was sluggish, switching to brute
        #  force for now. If bottleneck, need to investigate alternative
        #  search approaches
        #
        # # Search using radius knn utils
        # search_mask = torch.ones_like(nag[i_level].pos[:, 0], dtype=torch.bool)
        # search_mask[idx] = False
        # x_search = nag[i_level].pos
        # x_query = nag[i_level].pos[idx]
        # k = x_search.shape[0]
        # neighbors = knn_2(x_search, x_query, k, r_max=self.r)[0]
        #
        # # Convert neighborhoods to node indices for `NAG.select()`
        # neighbors = neighbors.flatten()
        # idx = neighbors[neighbors != -1].unique()

        # TODO: Assuming idx.shape[0] is small, we search spherical
        #  samplings one by one, without any fancy KNN search tool,
        #  because it seems faster that way, probably due to the large
        #  number of neighbors
        idx_select_list = []
        pos = nag[i_level].pos
        for i in idx:
            distance = (pos - pos[i].view(1, -1)).norm(dim=1)
            idx_select_list.append(torch.where(distance < self.r)[0])
        idx_select = torch.cat(idx_select_list).unique()

        # Select the nodes and update the NAG structure accordingly
        return nag.select(i_level, idx_select)


class SampleEdges(Transform):
    """Sample edges based on which source node they belong to.

    The sampling operation is run without replacement and each source
    segment has at least `n_min` and at most `n_max` edges, within the
    limits allowed by its actual number of edges.

    :param level: int or str
        Level at which to sample edges. Can be an int or a str. If the
        latter, 'all' will apply on all levels, 'i+' will apply on
        level-i and above, 'i-' will apply on level-i and below.
        NOTE(review): the current implementation of 'i-' only covers
        the levels strictly below `i`, confirm which one is intended
    :param n_min: int or List(int)
        Minimum number of edges for each node, within the limits of its
        input number of edges
    :param n_max: int or List(int)
        Maximum number of edges for each node
    """

    _IN_TYPE = NAG
    _OUT_TYPE = NAG

    def __init__(self, level='1+', n_min=16, n_max=32):
        assert isinstance(level, (int, str))
        assert isinstance(n_min, (int, list))
        assert isinstance(n_max, (int, list))
        self.level = level
        self.n_min = n_min
        self.n_max = n_max

    def _process(self, nag):

        # If 'level' is an int, we only need to process a single level
        if isinstance(self.level, int):
            nag._list[self.level] = self._process_single_level(
                nag[self.level], self.n_min, self.n_max)
            return nag

        # If 'level' covers multiple levels, iteratively process levels.
        # Levels left with negative values will be skipped
        level_n_min = [-1] * nag.num_levels
        level_n_max = [-1] * nag.num_levels

        if self.level == 'all':
            level_n_min = self.n_min if isinstance(self.n_min, list) \
                else [self.n_min] * nag.num_levels
            level_n_max = self.n_max if isinstance(self.n_max, list) \
                else [self.n_max] * nag.num_levels
        elif self.level[-1] == '+':
            i = int(self.level[:-1])
            level_n_min[i:] = self.n_min if isinstance(self.n_min, list) \
                else [self.n_min] * (nag.num_levels - i)
            level_n_max[i:] = self.n_max if isinstance(self.n_max, list) \
                else [self.n_max] * (nag.num_levels - i)
        elif self.level[-1] == '-':
            i = int(self.level[:-1])
            level_n_min[:i] = self.n_min if isinstance(self.n_min, list) \
                else [self.n_min] * i
            level_n_max[:i] = self.n_max if isinstance(self.n_max, list) \
                else [self.n_max] * i
        else:
            raise ValueError(f'Unsupported level={self.level}')

        for i_level, (n_min, n_max) in enumerate(zip(level_n_min, level_n_max)):
            nag._list[i_level] = self._process_single_level(
                nag[i_level], n_min, n_max)

        return nag

    @staticmethod
    def _process_single_level(data, n_min, n_max):
        """Sample the edges of a single Data object, in-place.

        :param data: Data
        :param n_min: int
            Minimum number of edges per source node. Negative values
            skip the sampling
        :param n_max: int
            Maximum number of edges per source node. Negative values
            skip the sampling
        :return: the input Data, with sampled edges
        """
        # Skip process if n_min or n_max is negative or if input Data
        # has no edges
        if n_min < 0 or n_max < 0 or not data.has_edges:
            return data

        # Compute a sampling for the edges, based on the source node
        # they belong to
        idx = sparse_sample(
            data.edge_index[0], n_max=n_max, n_min=n_min,
            return_pointers=False)

        # Select edges and their attributes, if relevant
        data.edge_index = data.edge_index[:, idx]
        if data.has_edge_attr:
            data.edge_attr = data.edge_attr[idx]
        for key in data.edge_keys:
            data[key] = data[key][idx]

        return data


class RestrictSize(Transform):
    """Randomly sample nodes and edges to restrict their number within
    given limits. This is useful for stabilizing memory use of the
    model.

    :param num_nodes: int
        Maximum number of nodes. If the input has more, a subset of
        `num_nodes` nodes will be randomly sampled. No sampling if <=0
    :param num_edges: int
        Maximum number of edges. If the input has more, a subset of
        `num_edges` edges will be randomly sampled. No sampling if <=0
    """

    def __init__(self, num_nodes=0, num_edges=0):
        self.num_nodes = num_nodes
        self.num_edges = num_edges

    def _process(self, data):
        # Uniformly sample the nodes, if too many. The output of
        # `select()` is unwrapped because a Data object is needed below
        if data.num_nodes > self.num_nodes and self.num_nodes > 0:
            weights = torch.ones(data.num_nodes, device=data.device)
            idx = torch.multinomial(weights, self.num_nodes, replacement=False)
            data = _unwrap_selected(data.select(idx))

        # Uniformly sample the edges and their attributes, if too many
        if data.num_edges > self.num_edges and self.num_edges > 0:
            weights = torch.ones(data.num_edges, device=data.device)
            idx = torch.multinomial(weights, self.num_edges, replacement=False)

            data.edge_index = data.edge_index[:, idx]
            if data.has_edge_attr:
                data.edge_attr = data.edge_attr[idx]
            for key in data.edge_keys:
                data[key] = data[key][idx]

        return data


class NAGRestrictSize(Transform):
    """Randomly sample nodes and edges to restrict their number within
    given limits. This is useful for stabilizing memory use of the
    model.

    :param num_nodes: int
        Maximum number of nodes. If the input has more, a subset of
        `num_nodes` nodes will be randomly sampled. No sampling if <=0
    :param num_edges: int
        Maximum number of edges. If the input has more, a subset of
        `num_edges` edges will be randomly sampled. No sampling if <=0
    """

    _IN_TYPE = NAG
    _OUT_TYPE = NAG

    def __init__(self, level='1+', num_nodes=0, num_edges=0):
        assert isinstance(level, (int, str))
        assert isinstance(num_nodes, (int, list))
        assert isinstance(num_edges, (int, list))
        self.level = level
        self.num_nodes = num_nodes
        self.num_edges = num_edges

    def _process(self, nag):

        # If 'level' is an int, we only need to process a single level
        if isinstance(self.level, int):
            return self._restrict_level(
                nag, self.level, self.num_nodes, self.num_edges)

        # If 'level' covers multiple levels, iteratively process levels
        level_num_nodes = [-1] * nag.num_levels
        level_num_edges = [-1] * nag.num_levels

        if self.level == 'all':
            level_num_nodes = self.num_nodes \
                if isinstance(self.num_nodes, list) \
                else [self.num_nodes] * nag.num_levels
            level_num_edges = self.num_edges \
                if isinstance(self.num_edges, list) \
                else [self.num_edges] * nag.num_levels
        elif self.level[-1] == '+':
            i = int(self.level[:-1])
            level_num_nodes[i:] = self.num_nodes \
                if isinstance(self.num_nodes, list) \
                else [self.num_nodes] * (nag.num_levels - i)
            level_num_edges[i:] = self.num_edges \
                if isinstance(self.num_edges, list) \
                else [self.num_edges] * (nag.num_levels - i)
        elif self.level[-1] == '-':
            i = int(self.level[:-1])
            level_num_nodes[:i] = self.num_nodes \
                if isinstance(self.num_nodes, list) \
                else [self.num_nodes] * i
            level_num_edges[:i] = self.num_edges \
                if isinstance(self.num_edges, list) \
                else [self.num_edges] * i
        else:
            raise ValueError(f'Unsupported level={self.level}')

        for i_level, (num_nodes, num_edges) in enumerate(zip(
                level_num_nodes, level_num_edges)):
            nag = self._restrict_level(nag, i_level, num_nodes, num_edges)

        return
nag @staticmethod def _restrict_level(nag, i_level, num_nodes, num_edges): if nag[i_level].num_nodes > num_nodes and num_nodes > 0: weights = torch.ones(nag[i_level].num_nodes, device=nag.device) idx = torch.multinomial(weights, num_nodes, replacement=False) nag = nag.select(i_level, idx) if nag[i_level].num_edges > num_edges and num_edges > 0: weights = torch.ones(nag[i_level].num_edges, device=nag.device) idx = torch.multinomial(weights, num_edges, replacement=False) nag[i_level].edge_index = nag[i_level].edge_index[:, idx] if nag[i_level].has_edge_attr: nag[i_level].edge_attr = nag[i_level].edge_attr[idx] for key in nag[i_level].edge_keys: nag[i_level][key] = nag[i_level][key][idx] return nag