| | import re |
| | import torch |
| | from torch_geometric.nn.pool import voxel_grid |
| | from torch_geometric.utils import k_hop_subgraph, to_undirected |
| | from torch_cluster import grid_cluster |
| | from torch_scatter import scatter_mean |
| | from torch_geometric.nn.pool.consecutive import consecutive_cluster |
| | from src.utils import fast_randperm, sparse_sample, scatter_pca, sanitize_keys |
| | from src.transforms import Transform |
| | from src.data import Data, NAG, NAGBatch, CSRData, InstanceData, Cluster |
| | from src.utils.histogram import atomic_to_histogram |
| |
|
| |
|
# Public transforms exported by this module
__all__ = [
    'Shuffle',
    'SaveNodeIndex',
    'NAGSaveNodeIndex',
    'GridSampling3D',
    'SampleXYTiling',
    'SampleRecursiveMainXYAxisTiling',
    'SampleSubNodes',
    'SampleKHopSubgraphs',
    'SampleRadiusSubgraphs',
    'SampleSegments',
    'SampleEdges',
    'RestrictSize',
    'NAGRestrictSize',
]
| |
|
| |
|
class Shuffle(Transform):
    """Shuffle the order of points in a Data object.

    A random permutation is drawn with `fast_randperm` and applied with
    `Data.select()`, passing `update_sub=False` and `update_super=False`.
    """

    def _process(self, data):
        idx = fast_randperm(data.num_points, device=data.device)
        out = data.select(idx, update_sub=False, update_super=False)
        # NOTE(review): other call sites in this module index the result
        # of `Data.select()` with `[0]`, suggesting it may return a tuple
        # whose first element is the selected Data. Unpack defensively so
        # this transform always returns a Data object.
        return out[0] if isinstance(out, tuple) else out
| |
|
| |
|
class SaveNodeIndex(Transform):
    """Store the index of every node as a new Data attribute, so that
    nodes of a transformed output can be traced back to the input Data
    object.

    :param key: str
        Name of the attribute receiving the indices. When None,
        `SaveNodeIndex.DEFAULT_KEY` is used
    """

    DEFAULT_KEY = 'node_id'

    def __init__(self, key=None):
        if key is None:
            key = self.DEFAULT_KEY
        self.key = key

    def _process(self, data):
        # One index per row of `pos`, on the same device as the data
        num_nodes = data.pos.shape[0]
        node_index = torch.arange(num_nodes, device=data.device)
        setattr(data, self.key, node_index)
        return data
| |
|
| |
|
class NAGSaveNodeIndex(SaveNodeIndex):
    """Apply `SaveNodeIndex` independently to every level of a NAG,
    using the same attribute key on each level.
    """

    _IN_TYPE = NAG
    _OUT_TYPE = NAG

    def _process(self, nag):
        level_transform = SaveNodeIndex(key=self.key)
        levels = nag._list
        for i_level in range(nag.num_levels):
            levels[i_level] = level_transform(levels[i_level])
        return nag
| |
|
| |
|
class GridSampling3D(Transform):
    """Clusters 3D points into voxels with size :attr:`size`.

    By default, some special keys undergo dedicated grouping mechanisms.
    The `_VOTING_KEYS=['y', 'super_index', 'is_val']` keys are grouped
    by their majority label. The `_INSTANCE_KEYS=['obj', 'obj_pred']`
    keys are grouped into an `InstanceData`, which stores all
    instance/panoptic overlap data values in CSR format. The
    `_CLUSTER_KEYS=['sub']` keys are grouped into a `Cluster` object,
    which stores indices of child elements for parent clusters in CSR
    format. The `_LAST_KEYS=['batch', SaveNodeIndex.DEFAULT_KEY]` keys
    are by default grouped following `mode='last'`.

    Besides, for keys where a more subtle histogram mechanism is needed
    (e.g. for 'y'), the `hist_key` and `hist_size` arguments can be
    used.

    Modified from: https://github.com/torch-points3d/torch-points3d

    Parameters
    ----------
    size: float
        Size of a voxel (in each dimension).
    quantize_coords: bool
        If True, the integer voxel coordinates (`round(pos / size)`) of
        the point kept for each voxel are stored in a new `coords`
        attribute.
    mode: str
        Either `last` or `mean`. If `mean`, all the points and their
        features within a voxel are averaged. If `last`, the points are
        shuffled first and one point per voxel is kept, along with its
        features.
    hist_key: str or List(str)
        Data attributes for which we would like to aggregate values into
        a histogram. This is typically needed when we want to aggregate
        point labels without losing the distribution, as opposed to
        majority voting.
    hist_size: int or List(int)
        Must be of same size as `hist_key`, indicates the number of
        bins for each key-histogram.
    inplace: bool
        Whether the input Data object should be modified in-place
    verbose: bool
        If True, print usage hints when the transform is built
    """

    _NO_REPR = ['verbose', 'inplace']

    def __init__(
            self, size, quantize_coords=False, mode="mean", hist_key=None,
            hist_size=None, inplace=False, verbose=False):

        # Normalize `hist_key` and `hist_size` to equal-length lists
        hist_key = [] if hist_key is None else hist_key
        hist_size = [] if hist_size is None else hist_size
        hist_key = [hist_key] if isinstance(hist_key, str) else hist_key
        hist_size = [hist_size] if isinstance(hist_size, int) else hist_size

        assert isinstance(hist_key, list)
        assert isinstance(hist_size, list)
        assert len(hist_key) == len(hist_size)

        self.grid_size = size
        self.quantize_coords = quantize_coords
        self.mode = mode
        self.bins = dict(zip(hist_key, hist_size))
        self.inplace = inplace

        if verbose:
            print(
                f"If you need to keep track of the position of your points, "
                f"use SaveNodeIndex transform before using "
                f"{self.__class__.__name__}.")

            if self.mode == "last":
                print(
                    "The tensors within data will be shuffled each time this "
                    "transform is applied. Be careful that if an attribute "
                    "doesn't have the size of num_nodes, it won't be shuffled")

    def _process(self, data_in):
        # Work on a copy unless in-place modification was requested
        data = data_in if self.inplace else data_in.clone()

        # In 'last' mode, shuffle the points so that the point kept for
        # each voxel is a random one
        if self.mode == 'last':
            data = Shuffle()(data)

        # Integer voxel coordinates of each point
        coords = torch.round(data.pos / self.grid_size)

        # Voxel identifier of each point, computed on unit-sized voxels in
        # `coords` space. When a `batch` attribute is present, voxel_grid
        # also receives it so batch items are handled separately
        if 'batch' not in data:
            cluster = grid_cluster(coords, torch.ones(3, device=coords.device))
        else:
            # Keyword arguments on purpose: the positional order of
            # voxel_grid changed across torch_geometric versions
            # (`(pos, batch, size)` vs `(pos, size, batch)`)
            cluster = voxel_grid(coords, size=1, batch=data.batch)

        # Make the voxel indices consecutive and get one point index per
        # voxel, used as the voxel representative in 'last' mode
        cluster, unique_pos_indices = consecutive_cluster(cluster)

        # Aggregate every node-wise attribute within each voxel
        data = _group_data(
            data, cluster, unique_pos_indices, mode=self.mode, bins=self.bins)

        # Optionally store the integer voxel coordinates of each voxel
        if self.quantize_coords:
            data.coords = coords[unique_pos_indices].int()

        # Keep track of the voxel size used
        data.grid_size = torch.tensor([self.grid_size])

        return data
| |
|
| |
|
def _group_data(
        data, cluster=None, unique_pos_indices=None, mode="mean",
        skip_keys=None, bins=None):
    """Group data based on indices in cluster. The option ``mode``
    controls how data gets aggregated within each cluster.

    By default, some special keys undergo dedicated grouping mechanisms.
    The `_VOTING_KEYS=['y', 'super_index', 'is_val']` keys are grouped
    by their majority label. The `_INSTANCE_KEYS=['obj', 'obj_pred']`
    keys are grouped into an `InstanceData`, which stores all
    instance/panoptic overlap data values in CSR format. The
    `_CLUSTER_KEYS=['sub']` keys are grouped into a `Cluster` object,
    which stores indices of child elements for parent clusters in CSR
    format. The `_LAST_KEYS=['batch', SaveNodeIndex.DEFAULT_KEY]` keys
    are by default grouped following `mode='last'`.

    Besides, for keys where a more subtle histogram mechanism is needed
    (e.g. for 'y'), the 'bins' argument can be used.

    Warning: this function modifies the input Data object in-place.

    :param data : Data
    :param cluster : Tensor
        Tensor of the same size as the number of points in data. Each
        element is the cluster index of that point. Required in ``mean``
        mode and for instance/cluster keys
    :param unique_pos_indices : Tensor
        Tensor containing one index per cluster, this index will be used
        to select features and labels in ``last`` mode and for the
        `_LAST_KEYS`
    :param mode : str
        Option to select how the features and labels for each voxel are
        computed. Can be ``last`` or ``mean``. ``last`` keeps, for each
        cluster, the point at ``unique_pos_indices``; ``mean`` takes the
        average.
    :param skip_keys: list
        Keys of attributes to skip in the grouping.
    :param bins: dict
        Dictionary holding ``{'key': n_bins}`` where ``key`` is a Data
        attribute for which we would like to aggregate values into a
        histogram and ``n_bins`` accounts for the corresponding number
        of bins. This is typically needed when we want to aggregate
        point labels without losing the distribution, as opposed to
        majority voting. Defaults to no histogram keys.
    :return: Data
        The input Data object, with its node-wise attributes grouped
    """
    skip_keys = sanitize_keys(skip_keys, default=[])
    bins = {} if bins is None else bins

    # Keys aggregated by majority voting
    _VOTING_KEYS = ['y', 'super_index', 'is_val']

    # Keys aggregated into an InstanceData object
    _INSTANCE_KEYS = ['obj', 'obj_pred']

    # Keys aggregated into a Cluster object
    _CLUSTER_KEYS = ['sub']

    # Keys always aggregated following mode='last'
    _LAST_KEYS = ['batch', SaveNodeIndex.DEFAULT_KEY]

    # Keys re-normalized to unit norm after averaging
    _NORMAL_KEYS = ['normal']

    _MODES = ['mean', 'last']
    assert mode in _MODES
    if mode == "mean" and cluster is None:
        raise ValueError(
            "In mean mode the cluster argument needs to be specified")
    if mode == "last" and unique_pos_indices is None:
        raise ValueError(
            "In last mode the unique_pos_indices argument needs to be specified")

    # Read the number of nodes once: node-wise attributes shrink to one
    # value per cluster as the loop below progresses
    num_nodes = data.num_nodes

    # Snapshot the point-level labels before any attribute is grouped.
    # 'y' may itself be aggregated during the loop (depending on the key
    # iteration order), whereas the instance keys need the point-level 'y'
    y_points = getattr(data, 'y', None)

    for key, item in data:

        if key in skip_keys:
            continue

        # Edge attributes cannot be grouped by this function
        if 'edge' in key:
            raise NotImplementedError("Edges not supported. Wrong data type.")

        # Instance keys are merged into an InstanceData (CSR) holding the
        # instance/panoptic overlaps of each cluster
        if key in _INSTANCE_KEYS:
            if isinstance(item, InstanceData):
                data[key] = item.merge(cluster)
            else:
                count = torch.ones_like(item)
                y = y_points if y_points is not None \
                    else torch.zeros_like(item)
                data[key] = InstanceData(cluster, item, count, y, dense=True)
            continue

        # Cluster keys hold 1D integer indices, merged into a Cluster (CSR)
        if key in _CLUSTER_KEYS:
            if (isinstance(item, torch.Tensor) and item.dim() == 1
                    and not item.is_floating_point()):
                data[key] = Cluster(cluster, item, dense=True)
            else:
                raise NotImplementedError(
                    f"Cannot merge '{key}' with data type: {type(item)} into "
                    f"a Cluster object. Only supports 1D Tensor of integers.")
            continue

        # Other CSR structures are not supported
        if isinstance(item, CSRData):
            raise NotImplementedError(
                f"Cannot merge '{key}' with data type: {type(item)}")

        # Only node-wise tensors are grouped, everything else is kept as-is
        if not torch.is_tensor(item) or item.size(0) != num_nodes:
            continue

        # 'last' aggregation: keep the representative point of each cluster
        if mode == 'last' or key in _LAST_KEYS:
            data[key] = item[unique_pos_indices]
            continue

        # From here on, mode == 'mean'. Booleans are aggregated as ints and
        # converted back at the end
        is_item_bool = item.dtype == torch.bool
        if is_item_bool:
            item = item.int()

        # Voting keys use the argmax of the per-cluster histogram, while
        # `bins` keys keep the full histogram
        if key in _VOTING_KEYS or key in bins:
            voting = key not in bins
            n_bins = item.max() + 1 if voting else bins[key]
            hist = atomic_to_histogram(item, cluster, n_bins=n_bins)
            data[key] = hist.argmax(dim=-1) if voting else hist

        # Any other node-wise attribute is averaged within each cluster
        else:
            data[key] = scatter_mean(item, cluster, dim=0)

        # Averaged normals are re-normalized to unit length
        if key in _NORMAL_KEYS:
            data[key] = data[key] / data[key].norm(dim=1).view(-1, 1)

        # Restore the boolean dtype
        if is_item_bool:
            data[key] = data[key].bool()

    return data
| |
|
| |
|
class SampleXYTiling(Transform):
    """Tile the input Data along the XY axes and select only a given
    tile. This is useful to reduce the size of very large clouds at
    preprocessing time.

    :param x: int
        x coordinate of the sample in the tiling grid
    :param y: int
        y coordinate of the sample in the tiling grid
    :param tiling: int or tuple(int, int)
        Number of tiles in the grid in each direction. If a tuple is
        passed, each direction can be tiled independently
    """

    def __init__(self, x=0, y=0, tiling=2):
        tiling = (tiling, tiling) if isinstance(tiling, int) else tiling
        assert 0 <= x < tiling[0]
        assert 0 <= y < tiling[1]
        self.tiling = torch.as_tensor(tiling)
        self.x = x
        self.y = y

    def _process(self, data):
        # Normalize the XY coordinates to [0, 1] along each axis
        xy = data.pos[:, :2].clone().view(-1, 2)
        xy -= xy.min(dim=0).values.view(1, 2)
        extent = xy.max(dim=0).values.view(1, 2)
        # A zero extent (all points sharing a coordinate) would produce
        # NaNs; dividing by 1 instead puts every point in tile 0 of that
        # axis
        extent = torch.where(extent > 0, extent, torch.ones_like(extent))
        xy /= extent

        # Convert to integer tile indices. `tiling` is moved to the data
        # device so the product also works for GPU data
        tiling = self.tiling.to(device=xy.device).view(1, 2)
        xy = (xy.clip(min=0, max=1) * tiling).long()

        # Points lying exactly on the max border map to index `tiling`,
        # which belongs to no tile: fold them into the last tile so they
        # are not silently dropped
        xy = torch.minimum(xy, tiling.long() - 1)

        # Select the points in the requested tile
        idx = torch.where((xy[:, 0] == self.x) & (xy[:, 1] == self.y))[0]

        return data.select(idx)[0]
| |
|
| |
|
class SampleRecursiveMainXYAxisTiling(Transform):
    """Select one tile of the input Data, obtained by recursively
    halving the points along their principal XY direction. Useful to
    reduce very large clouds at preprocessing time when they are not
    XY-aligned or have non-trivial geometries.

    :param x: int
        Index of the tile to keep. Tiles are "lexicographically" ordered:
        at each split, points below the median come before points above
        it
    :param steps: int
        Number of halving steps. There are 2**steps tiles in total
    """

    def __init__(self, x=0, steps=2):
        assert 0 <= x < 2 ** steps
        self.steps = steps
        self.x = x

    def _process(self, data):
        # No split requested: return the input untouched
        if self.steps <= 0:
            return data

        # Walk down the binary tree, keeping at each level the half given
        # by the corresponding bit of `x`
        for go_right in self.binary_tree_path:
            data = self.split_by_main_xy_direction(
                data, left=not go_right, right=go_right)

        return data

    @property
    def binary_tree_path(self):
        """Bits of `x`, most significant first, left-padded with zeros
        to `steps` digits, as a list of booleans (True = upper half).
        """
        bits = bin(self.x)[2:].rjust(self.steps, '0')
        return [digit == '1' for digit in bits]

    @staticmethod
    def split_by_main_xy_direction(data, left=True, right=True):
        """Split `data` at the median of its projection on its main XY
        direction. Returns the lower part if only `left`, the upper part
        if only `right`, and both as a tuple otherwise.
        """
        assert left or right, "At least one split must be returned"

        # Main XY direction, flipped so its x component is non-negative
        direction = \
            SampleRecursiveMainXYAxisTiling.compute_main_xy_direction(data)
        if direction[0] < 0:
            direction *= -1

        # Project the XY positions and compare them to the median
        projection = (data.pos[:, :2] * direction.view(1, -1)).sum(dim=1)
        below = projection < projection.median()

        if left and not right:
            return data.select(below)[0]
        if right and not left:
            return data.select(~below)[0]
        return data.select(below)[0], data.select(~below)[0]

    @staticmethod
    def compute_main_xy_direction(data):
        """Return the 2D XY part of the last eigenvector returned by
        `scatter_pca` on a flattened, voxelized copy of the positions.
        """
        # Position-only copy, so the input Data is left untouched
        flat = Data(pos=data.pos.clone())

        # Shift XY to start at 0; the voxel size is 1/100th of the
        # largest XY extent
        xy = flat.pos[:, :2]
        xy -= xy.min(dim=0).values.view(1, -1)
        voxel = xy.max() / 100

        # Flatten the cloud onto the XY plane
        flat.pos[:, 2] = 0

        # Voxelize (presumably to limit the influence of point density on
        # the PCA below)
        flat = GridSampling3D(size=voxel)(flat)

        # PCA over all points treated as a single group
        group = torch.zeros_like(flat.pos[:, 0], dtype=torch.long)
        eigenvectors = scatter_pca(flat.pos, group, on_cpu=True)[1]

        return eigenvectors[0][:2, -1]
| |
|
| |
|
class SampleSubNodes(Transform):
    """Sample elements at `low`-level, based on which segment they
    belong to at `high`-level.

    The sampling operation is run without replacement and each segment
    is sampled at least `n_min` and at most `n_max` times, within the
    limits allowed by its actual size.

    Optionally, a `mask` can be passed to filter out some `low`-level
    points.

    :param high: int
        Partition level of the segments we want to sample. By default,
        `high=1` to sample the level-1 segments
    :param low: int
        Partition level we will sample from, guided by the `high`
        segments. By default, `low=0` to sample the level-0 points.
        `low=-1` is accepted when level-0 has a `sub` attribute (i.e.
        level-0 points are themselves segments of `-1` level absent
        from the NAG object).
    :param n_max: int
        Maximum number of `low`-level elements to sample in each
        `high`-level segment
    :param n_min: int
        Minimum number of `low`-level elements to sample in each
        `high`-level segment, within the limits of its size (i.e. no
        oversampling)
    :param mask: list, np.ndarray, torch.LongTensor, torch.BoolTensor
        Indicates a subset of `low`-level elements to consider. This
        allows ignoring some `low`-level elements during sampling
    """

    _IN_TYPE = NAG
    _OUT_TYPE = NAG

    def __init__(
            self, high=1, low=0, n_max=32, n_min=16, mask=None):
        assert isinstance(high, int)
        assert isinstance(low, int)
        assert isinstance(n_max, int)
        assert isinstance(n_min, int)
        self.high = high
        self.low = low
        self.n_max = n_max
        self.n_min = n_min
        self.mask = mask

    def _process(self, nag):
        # `mask` used to be stored but silently ignored. It is forwarded
        # only when provided, so the mask-less call is unchanged.
        # NOTE(review): assumes `NAG.get_sampling()` accepts a `mask`
        # keyword; confirm against the NAG implementation
        mask_kwargs = {} if self.mask is None else {'mask': self.mask}
        idx = nag.get_sampling(
            high=self.high, low=self.low, n_max=self.n_max, n_min=self.n_min,
            return_pointers=False, **mask_kwargs)
        return nag.select(self.low, idx)
| |
|
| |
|
class SampleSegments(Transform):
    """Remove randomly-picked nodes from each level 1+ of the NAG. This
    operation relies on `NAG.select()` to maintain index consistency
    across the NAG levels.

    Note: we do not directly prune level-0 points, see `SampleSubNodes`
    for that. For speed consideration, it is recommended to use
    `SampleSubNodes` first before `SampleSegments`, to minimize the
    number of level-0 points to manipulate.

    :param ratio: float or list(float)
        Portion of nodes to be dropped, in [0, 1). A list may be passed
        to prune NAG 1+ levels with different probabilities (one entry
        per level, starting at level 1)
    :param by_size: bool
        If True, the segment size will affect the chances of being
        dropped out. The smaller the segment, the greater its chances
        to be dropped
    :param by_class: bool
        If True, the classes will affect the chances of being
        dropped out. The more frequent the segment class, the greater
        its chances to be dropped
    """

    _IN_TYPE = NAG
    _OUT_TYPE = NAG

    def __init__(self, ratio=0.2, by_size=False, by_class=False):
        # Validate scalar and list ratios alike. The previous
        # `A and B or C` expression raised a TypeError instead of an
        # AssertionError for a list holding an invalid ratio
        ratios = ratio if isinstance(ratio, list) else [ratio]
        assert all(0 <= r < 1 for r in ratios), \
            f"Expected ratio(s) in [0, 1), got {ratio}"
        self.ratio = ratio
        self.by_size = by_size
        self.by_class = by_class

    def _process(self, nag):
        if not isinstance(self.ratio, list):
            ratio = [self.ratio] * (nag.num_levels - 1)
        else:
            ratio = self.ratio

        # Prune from the highest level down, so that each `NAG.select()`
        # propagates to the lower levels before they are pruned in turn
        device = nag.device
        for i_level in range(nag.num_levels - 1, 0, -1):
            level_ratio = ratio[i_level - 1]

            # Nothing to drop at this level
            if level_ratio <= 0:
                continue

            # An empty level has nothing to drop, and torch.multinomial
            # would raise on it
            num_nodes = nag[i_level].num_nodes
            if num_nodes == 0:
                continue

            # Number of nodes kept at this level
            num_keep = num_nodes - int(num_nodes * level_ratio)

            # Uniform base weight for every node
            weights = torch.ones(num_nodes, device=device)

            # Larger segments get a larger keep-weight, growing with the
            # cube root of their level-0 size
            if self.by_size:
                node_size = nag.get_sub_size(i_level, low=0)
                size_weights = node_size ** 0.333
                size_weights /= size_weights.sum()
                weights += size_weights

            # Segments containing rare classes get a larger keep-weight.
            # Assumes `y` holds per-class counts, one column per class
            y = getattr(nag[i_level], 'y', None)
            if self.by_class and y is not None:
                counts = y.sum(dim=0).sqrt()
                scores = 1 / (counts + 1)
                scores /= scores.sum()
                mask = y.gt(0)
                class_weights = (mask * scores.view(1, -1)).max(dim=1).values
                class_weights /= class_weights.sum()
                weights += class_weights.squeeze()

            # Normalize the weights into a probability distribution
            weights /= weights.sum()

            # Draw the nodes to keep, without replacement
            idx = torch.multinomial(weights, num_keep, replacement=False)

            # Select the kept nodes, updating the other levels accordingly
            nag = nag.select(i_level, idx)

        return nag
| |
|
| |
|
class BaseSampleSubgraphs(Transform):
    """Base class for sampling subgraphs from a NAG. It randomly picks
    `k` seed nodes from `i_level`, from which `k` subgraphs can be
    grown. Child classes must implement `_sample_subgraphs()` to
    describe how these subgraphs are built. Optionally, the seed
    sampling can be driven by their class, or their size, using
    `by_class` and `by_size`, respectively.

    This operation relies on `NAG.select()` to maintain index
    consistency across the NAG levels.

    :param i_level: int
        Partition level we want to pick from. By default, `i_level=1`.
        An out-of-range value (e.g. `i_level=-1`) will sample the
        highest level of the input NAG
    :param k: int
        Number of sub-graphs/seeds to pick
    :param by_size: bool
        If True, the segment size will affect the chances of being
        selected as a seed. The larger the segment, the greater its
        chances to be picked
    :param by_class: bool
        If True, the classes will affect the chances of being
        selected as a seed. The scarcer the segment class, the greater
        its chances to be selected
    :param use_batch: bool
        If True, the 'Data.batch' attribute will be used to guide seed
        sampling across batches. More specifically, if the input NAG is
        a NAGBatch made up of multiple NAGs, the subgraphs will be
        sampled in a way that guarantees each NAG is sampled from.
        Obviously enough, if `k < batch.max() + 1`, not all NAGs will be
        sampled from
    :param disjoint: bool
        If True, subgraphs sampled from the same NAG will be separated
        as distinct NAGs themselves. Instead, when `disjoint=False`,
        subgraphs sampled in the same NAG will belong to the same NAG.
        Hence, if two subgraphs share a node, they will be connected
    """

    _IN_TYPE = NAG
    _OUT_TYPE = NAG

    def __init__(
            self, i_level=1, k=1, by_size=False, by_class=False,
            use_batch=True, disjoint=True):
        self.i_level = i_level
        self.k = k
        self.by_size = by_size
        self.by_class = by_class
        self.use_batch = use_batch
        self.disjoint = disjoint

    def _process(self, nag):
        device = nag.device

        # Sampling disabled
        if self.i_level is None or self.k <= 0:
            return nag

        # An out-of-range i_level falls back to the highest level
        i_level = self.i_level if 0 <= self.i_level < nag.num_levels \
            else nag.num_levels - 1

        # Nothing to sample from an empty level
        num_nodes = nag[i_level].num_nodes
        if num_nodes == 0:
            return nag

        # NOTE(review): when `k` is not strictly smaller than the number
        # of nodes, a single seed is drawn (not all nodes). Behavior kept
        # as-is; confirm it is intended
        k = self.k if self.k < num_nodes else 1

        # Uniform base weight for every node
        weights = torch.ones(num_nodes, device=device)

        # Larger segments get a larger weight, growing with the cube root
        # of their level-0 size
        if self.by_size:
            node_size = nag.get_sub_size(i_level, low=0)
            size_weights = node_size ** 0.333
            size_weights /= size_weights.sum()
            weights += size_weights

        # Segments containing rare classes get a larger weight. Assumes
        # `y` holds per-class counts, one column per class
        y = getattr(nag[i_level], 'y', None)
        if self.by_class and y is not None:
            counts = y.sum(dim=0).sqrt()
            scores = 1 / (counts + 1)
            scores /= scores.sum()
            mask = y.gt(0)
            class_weights = (mask * scores.view(1, -1)).max(dim=1).values
            class_weights /= class_weights.sum()
            weights += class_weights.squeeze()

        # Normalize the weights into a probability distribution
        weights /= weights.sum()

        # Draw the seeds, spread across batch items when requested
        batch = getattr(nag[i_level], 'batch', None)
        if batch is not None and self.use_batch:
            idx = self._sample_seeds_by_batch(weights, batch, k)
        else:
            idx = torch.multinomial(weights, k, replacement=False)

        # Grow all subgraphs within the same NAG
        if not self.disjoint:
            return self._sample_subgraphs(nag, i_level, idx)

        # Grow each subgraph separately and gather them in a NAGBatch
        return NAGBatch.from_nag_list([
            self._sample_subgraphs(nag, i_level, i.view(1)) for i in idx])

    @staticmethod
    def _sample_seeds_by_batch(weights, batch, k):
        """Draw `k` seeds without replacement, spread across batch items.

        Each batch item gets `max(k // num_batch, 1)` seeds, the last
        item gets the remainder, and sampling stops as soon as `k` seeds
        are drawn. An item never contributes more seeds than it has
        nodes.
        """
        num_batch = int(batch.max()) + 1
        k_batch = max(k // num_batch, 1)
        idx_list = []
        num_sampled = 0
        for i_batch in range(num_batch):

            # The last batch item takes whatever remains to reach k
            if i_batch >= num_batch - 1:
                k_batch = k - num_sampled

            # Cap the request at the item size: torch.multinomial without
            # replacement cannot draw more samples than candidates
            candidates = torch.where(batch == i_batch)[0]
            n = min(k_batch, candidates.numel())
            if n > 0:
                picked = torch.multinomial(
                    weights[candidates], n, replacement=False)
                idx_list.append(candidates[picked])

            # Stop once enough seeds have been drawn
            num_sampled += n
            if num_sampled >= k:
                break

        return torch.cat(idx_list)

    def _sample_subgraphs(self, nag, i_level, idx):
        raise NotImplementedError
| |
|
| |
|
class SampleKHopSubgraphs(BaseSampleSubgraphs):
    """Randomly pick segments from `i_level`, along with their `hops`
    neighbors. This can be thought of as a spherical sampling in the
    graph of i_level.

    This operation relies on `NAG.select()` to maintain index
    consistency across the NAG levels.

    Note: we do not directly sample level-0 points, see `SampleSubNodes`
    for that. For speed consideration, it is recommended to use
    `SampleSubNodes` first before `SampleKHopSubgraphs`, to minimize the
    number of level-0 points to manipulate.

    :param hops: int
        Number of hops ruling the neighborhood size selected around the
        seed nodes
    :param i_level: int
        Partition level we want to pick from. By default, `i_level=1`.
        An out-of-range value (e.g. `i_level=-1`) will sample the
        highest level of the input NAG
    :param k: int
        Number of sub-graphs/seeds to pick
    :param by_size: bool
        If True, the segment size will affect the chances of being
        selected as a seed. The larger the segment, the greater its
        chances to be picked
    :param by_class: bool
        If True, the classes will affect the chances of being
        selected as a seed. The scarcer the segment class, the greater
        its chances to be selected
    :param use_batch: bool
        If True, the 'Data.batch' attribute will be used to guide seed
        sampling across batches. More specifically, if the input NAG is
        a NAGBatch made up of multiple NAGs, the subgraphs will be
        sampled in a way that guarantees each NAG is sampled from.
        Obviously enough, if `k < batch.max() + 1`, not all NAGs will be
        sampled from
    :param disjoint: bool
        If True, subgraphs sampled from the same NAG will be separated
        as distinct NAGs themselves. Instead, when `disjoint=False`,
        subgraphs sampled in the same NAG will belong to the same NAG.
        Hence, if two subgraphs share a node, they will be connected
    """

    def __init__(
            self, hops=2, i_level=1, k=1, by_size=False, by_class=False,
            use_batch=True, disjoint=False):
        super().__init__(
            i_level=i_level, k=k, by_size=by_size, by_class=by_class,
            use_batch=use_batch, disjoint=disjoint)
        self.hops = hops

    def _sample_subgraphs(self, nag, i_level, idx):
        assert nag[i_level].has_edges, \
            "Expected Data object to have edges for k-hop subgraph sampling"

        # Make the graph undirected so that the k-hop neighborhood does
        # not depend on the direction of the stored edges
        edge_index = to_undirected(nag[i_level].edge_index)

        # Nodes within `hops` hops of the seeds (first output of
        # k_hop_subgraph)
        idx = k_hop_subgraph(
            idx, self.hops, edge_index, num_nodes=nag[i_level].num_nodes)[0]

        # Select those nodes, updating the other levels accordingly
        return nag.select(i_level, idx)
| |
|
| |
|
class SampleRadiusSubgraphs(BaseSampleSubgraphs):
    """Randomly pick segments from `i_level`, along with their
    spherical neighborhood of fixed radius.

    This operation relies on `NAG.select()` to maintain index
    consistency across the NAG levels.

    Note: we do not directly sample level-0 points, see `SampleSubNodes`
    for that. For speed consideration, it is recommended to use
    `SampleSubNodes` first before `SampleRadiusSubgraphs`, to minimize
    the number of level-0 points to manipulate.

    :param r: float
        Radius for spherical sampling. If `r <= 0`, no sampling is done
        and the NAG is returned as-is
    :param i_level: int
        Partition level we want to pick from. By default, `i_level=1`.
        An out-of-range value (e.g. `i_level=-1`) will sample the
        highest level of the input NAG
    :param k: int
        Number of sub-graphs/seeds to pick
    :param by_size: bool
        If True, the segment size will affect the chances of being
        selected as a seed. The larger the segment, the greater its
        chances to be picked
    :param by_class: bool
        If True, the classes will affect the chances of being
        selected as a seed. The scarcer the segment class, the greater
        its chances to be selected
    :param use_batch: bool
        If True, the 'Data.batch' attribute will be used to guide seed
        sampling across batches. More specifically, if the input NAG is
        a NAGBatch made up of multiple NAGs, the subgraphs will be
        sampled in a way that guarantees each NAG is sampled from.
        Obviously enough, if `k < batch.max() + 1`, not all NAGs will be
        sampled from
    :param disjoint: bool
        If True, subgraphs sampled from the same NAG will be separated
        as distinct NAGs themselves. Instead, when `disjoint=False`,
        subgraphs sampled in the same NAG will belong to the same NAG.
        Hence, if two subgraphs share a node, they will be connected
    """

    def __init__(
            self, r=2, i_level=1, k=1, by_size=False, by_class=False,
            use_batch=True, disjoint=False):
        super().__init__(
            i_level=i_level, k=k, by_size=by_size, by_class=by_class,
            use_batch=use_batch, disjoint=disjoint)
        self.r = r

    def _sample_subgraphs(self, nag, i_level, idx):
        # A non-positive radius disables the sampling
        if self.r <= 0:
            return nag

        # Gather every i_level node lying strictly within distance `r` of
        # at least one seed. Distances are computed seed by seed, so each
        # iteration only materializes one distance per node
        idx_select_list = []
        pos = nag[i_level].pos
        for i in idx:
            distance = (pos - pos[i].view(1, -1)).norm(dim=1)
            idx_select_list.append(torch.where(distance < self.r)[0])
        idx_select = torch.cat(idx_select_list).unique()

        # Select those nodes, updating the other levels accordingly
        return nag.select(i_level, idx_select)
| |
|
| |
|
class SampleEdges(Transform):
    """Sample edges based on which source node they belong to.

    The sampling operation is run without replacement and each source
    segment has at least `n_min` and at most `n_max` edges, within the
    limits allowed by its actual number of edges.

    :param level: int or str
        Level at which to sample edges. Can be an int or a str. If the
        latter, 'all' will apply on all levels, 'i+' will apply on
        level-i and above, 'i-' will apply on level-i and below
    :param n_min: int or List(int)
        Minimum number of edges for each node, within the limits of its
        input number of edges
    :param n_max: int or List(int)
        Maximum number of edges for each node
    """

    _IN_TYPE = NAG
    _OUT_TYPE = NAG

    def __init__(self, level='1+', n_min=16, n_max=32):
        assert isinstance(level, (int, str))
        assert isinstance(n_min, (int, list))
        assert isinstance(n_max, (int, list))
        self.level = level
        self.n_min = n_min
        self.n_max = n_max

    def _process(self, nag):

        # Single level given as an int
        if isinstance(self.level, int):
            nag._list[self.level] = self._process_single_level(
                nag[self.level], self.n_min, self.n_max)
            return nag

        # Per-level parameters. A negative value leaves the level
        # untouched (see `_process_single_level`)
        level_n_min = [-1] * nag.num_levels
        level_n_max = [-1] * nag.num_levels

        if self.level == 'all':
            level_n_min = self.n_min if isinstance(self.n_min, list) \
                else [self.n_min] * nag.num_levels
            level_n_max = self.n_max if isinstance(self.n_max, list) \
                else [self.n_max] * nag.num_levels
        elif self.level[-1] == '+':
            i = int(self.level[:-1])
            level_n_min[i:] = self.n_min if isinstance(self.n_min, list) \
                else [self.n_min] * (nag.num_levels - i)
            level_n_max[i:] = self.n_max if isinstance(self.n_max, list) \
                else [self.n_max] * (nag.num_levels - i)
        elif self.level[-1] == '-':
            # 'i-' covers levels 0 to i inclusive, as documented. The
            # previous `[:i]` slice excluded level i itself, making '0-'
            # a no-op. Clamp to the existing levels
            num = min(int(self.level[:-1]) + 1, nag.num_levels)
            level_n_min[:num] = self.n_min if isinstance(self.n_min, list) \
                else [self.n_min] * num
            level_n_max[:num] = self.n_max if isinstance(self.n_max, list) \
                else [self.n_max] * num
        else:
            raise ValueError(f'Unsupported level={self.level}')

        for i_level, (n_min, n_max) in enumerate(zip(level_n_min, level_n_max)):
            nag._list[i_level] = self._process_single_level(
                nag[i_level], n_min, n_max)

        return nag

    @staticmethod
    def _process_single_level(data, n_min, n_max):
        # Negative bounds or missing edges: leave the level untouched
        if n_min < 0 or n_max < 0 or not data.has_edges:
            return data

        # Sample edges grouped by their source node (first row of
        # `edge_index`)
        idx = sparse_sample(
            data.edge_index[0], n_max=n_max, n_min=n_min, return_pointers=False)

        # Apply the same edge selection to every edge-wise attribute
        data.edge_index = data.edge_index[:, idx]
        if data.has_edge_attr:
            data.edge_attr = data.edge_attr[idx]
        for key in data.edge_keys:
            data[key] = data[key][idx]

        return data
| |
|
| |
|
class RestrictSize(Transform):
    """Randomly sample nodes and edges to restrict their number within
    given limits. This is useful for stabilizing memory use of the
    model.

    Nodes are sampled first, then edges. Subsets are drawn uniformly at
    random without replacement.

    :param num_nodes: int
        Maximum number of nodes. If the input has more, a subset of
        `num_nodes` nodes will be randomly sampled. No sampling if <=0
    :param num_edges: int
        Maximum number of edges. If the input has more, a subset of
        `num_edges` edges will be randomly sampled. No sampling if <=0
    """

    def __init__(self, num_nodes=0, num_edges=0):
        self.num_nodes = num_nodes
        self.num_edges = num_edges

    def _process(self, data):
        # Uniform subsets are drawn with `randperm(n)[:k]` rather than
        # `torch.multinomial` on uniform weights: the distribution is the
        # same, but `multinomial` rejects more than 2^24 categories,
        # which large point clouds or edge sets can exceed

        # Node sampling goes through `Data.select`
        # NOTE(review): `select` is called with its default
        # `update_sub`/`update_super` flags; confirm it returns a single
        # Data object in that case
        if self.num_nodes > 0 and data.num_nodes > self.num_nodes:
            idx = torch.randperm(
                data.num_nodes, device=data.device)[:self.num_nodes]
            data = data.select(idx)

        # Edge sampling only rewrites the edge-wise attributes; the
        # nodes are left untouched
        if self.num_edges > 0 and data.num_edges > self.num_edges:
            idx = torch.randperm(
                data.num_edges, device=data.device)[:self.num_edges]
            data.edge_index = data.edge_index[:, idx]
            if data.has_edge_attr:
                data.edge_attr = data.edge_attr[idx]
            for key in data.edge_keys:
                data[key] = data[key][idx]

        return data
| |
|
| |
|
class NAGRestrictSize(Transform):
    """Randomly sample nodes and edges to restrict their number within
    given limits, on selected levels of a NAG. This is useful for
    stabilizing memory use of the model.

    :param level: int or str
        Level(s) to restrict. An int selects a single level; a str can
        be 'all', 'i+' (level i and above) or 'i-' (levels strictly
        below i)
    :param num_nodes: int or List[int]
        Maximum number of nodes. If a level has more, a subset of
        `num_nodes` nodes will be randomly sampled. No sampling if <=0.
        With a str `level`, a list gives one value per selected level,
        in increasing level order; missing entries mean no sampling and
        surplus entries are ignored
    :param num_edges: int or List[int]
        Maximum number of edges. If a level has more, a subset of
        `num_edges` edges will be randomly sampled. No sampling if <=0.
        Same list conventions as `num_nodes`
    """

    _IN_TYPE = NAG
    _OUT_TYPE = NAG

    def __init__(self, level='1+', num_nodes=0, num_edges=0):
        assert isinstance(level, (int, str))
        assert isinstance(num_nodes, (int, list))
        assert isinstance(num_edges, (int, list))
        self.level = level
        self.num_nodes = num_nodes
        self.num_edges = num_edges

    def _process(self, nag):
        # Single integer level: restrict it with the raw limits
        if isinstance(self.level, int):
            return self._restrict_level(
                nag, self.level, self.num_nodes, self.num_edges)

        selected = self._selected_levels(self.level, nag.num_levels)
        level_num_nodes = self._spread(
            self.num_nodes, selected, nag.num_levels)
        level_num_edges = self._spread(
            self.num_edges, selected, nag.num_levels)

        for i_level, (num_nodes, num_edges) in enumerate(zip(
                level_num_nodes, level_num_edges)):
            nag = self._restrict_level(nag, i_level, num_nodes, num_edges)

        return nag

    @staticmethod
    def _selected_levels(level, num_levels):
        """Return the range of level indices targeted by a str `level`
        ('all', 'i+' or 'i-'). Slicing a `range` follows list-slicing
        rules, so out-of-bound values simply select fewer levels.

        :raises ValueError: if `level` is not a supported specification
        """
        levels = range(num_levels)
        if level == 'all':
            return levels
        if level.endswith('+'):
            return levels[int(level[:-1]):]
        if level.endswith('-'):
            return levels[:int(level[:-1])]
        raise ValueError(f'Unsupported level={level}')

    @staticmethod
    def _spread(value, selected, num_levels):
        """Expand a scalar or per-selected-level list into a list with
        exactly one entry per NAG level. Unselected levels, and selected
        levels beyond the end of a short list, get -1 (no sampling).
        Surplus list entries are ignored.
        """
        values = value if isinstance(value, list) \
            else [value] * len(selected)
        out = [-1] * num_levels
        for i_level, v in zip(selected, values):
            out[i_level] = v
        return out

    @staticmethod
    def _restrict_level(nag, i_level, num_nodes, num_edges):
        """Restrict the number of nodes, then of edges, of level
        `i_level`. Returns the resulting NAG.
        """
        # Uniform subsets are drawn with `randperm(n)[:k]` rather than
        # `torch.multinomial` on uniform weights: the distribution is the
        # same, but `multinomial` rejects more than 2^24 categories

        # Node sampling goes through `NAG.select`, whose output replaces
        # `nag`
        n = nag[i_level].num_nodes
        if num_nodes > 0 and n > num_nodes:
            idx = torch.randperm(n, device=nag.device)[:num_nodes]
            nag = nag.select(i_level, idx)

        # Edge sampling rewrites the edge-wise attributes of the level
        # in place
        data = nag[i_level]
        if num_edges > 0 and data.num_edges > num_edges:
            idx = torch.randperm(
                data.num_edges, device=nag.device)[:num_edges]
            data.edge_index = data.edge_index[:, idx]
            if data.has_edge_attr:
                data.edge_attr = data.edge_attr[idx]
            for key in data.edge_keys:
                data[key] = data[key][idx]

        return nag