from torch import nn
from copy import copy
from itertools import product
from src.utils.instance import instance_cut_pursuit


__all__ = ['InstancePartitioner']


class InstancePartitioner(nn.Module):
    """Partition a graph into instances using cut-pursuit.

    More specifically, this step groups nodes together based on:
      - node offset position
      - node predicted classification logits
      - node size
      - edge affinity

    NB: This operation relies on the parallel cut-pursuit algorithm:
    https://gitlab.com/1a7r0ch3/parallel-cut-pursuit
    Currently, this implementation is non-differentiable and runs on
    CPU.

    :param loss_type: str
        Loss applied to the node features. Accepts one of
        'l2' (L2 loss on node features and probabilities) or
        'l2_kl' (L2 loss on node features and Kullback-Leibler
        divergence on node probabilities)
    :param regularization: float
        Regularization parameter for the partition
    :param x_weight: float
        Weight used to mitigate the impact of the node position in the
        partition. The larger the weight, the less the spatial
        coordinates matter
    :param p_weight: float
        Weight used to mitigate the impact of the node probabilities in
        the partition. The larger the weight, the greater their impact
    :param cutoff: int
        Minimum number of points in each cluster
    :param parallel: bool
        Whether cut-pursuit should run in parallel
    :param iterations: int
        Maximum number of iterations for each partition
    :param trim: bool
        Whether the input graph should be trimmed. See the
        `to_trimmed()` documentation for more details on this operation
    :param discrepancy_epsilon: float
        Mitigates the maximum discrepancy. More precisely:
        `affinity=1 ⇒ discrepancy=1/discrepancy_epsilon`
    :param temperature: float
        Temperature used in the softmax when converting node logits to
        probabilities
    :param dampening: float
        Dampening applied to the node probabilities to mitigate the
        impact of near-zero probabilities in the Kullback-Leibler
        divergence
    """

    def __init__(
            self,
            loss_type='l2_kl',
            regularization=10,
            x_weight=1e-2,
            p_weight=1,
            cutoff=1,
            parallel=True,
            iterations=10,
            trim=False,
            discrepancy_epsilon=1e-4,
            temperature=1,
            dampening=0):
        super().__init__()
        self.loss_type = loss_type
        self.regularization = regularization
        self.x_weight = x_weight
        self.p_weight = p_weight
        self.cutoff = cutoff
        self.parallel = parallel
        self.iterations = iterations
        self.trim = trim
        self.discrepancy_epsilon = discrepancy_epsilon
        self.temperature = temperature
        self.dampening = dampening

    def forward(
            self,
            batch,
            node_x,
            node_logits,
            stuff_classes,
            node_size,
            edge_index,
            edge_affinity_logits,
            grid=None):
        """Compute the partition on the instance graph, based on the
        node features, node logits, and edge affinities. The partition
        segments are then further merged so that there is at most one
        instance of each stuff class per batch item (i.e. per scene).

        :param batch: Tensor of shape [num_nodes]
            Batch index of each node
        :param node_x: Tensor of shape [num_nodes, num_dim]
            Predicted node embeddings
        :param node_logits: Tensor of shape [num_nodes, num_classes]
            Predicted classification logits for each node
        :param stuff_classes: List or Tensor
            List of 'stuff' class labels. These are used for merging
            stuff segments together to ensure there is at most one
            predicted instance of each 'stuff' class per batch item
        :param node_size: Tensor of shape [num_nodes]
            Size of each node
        :param edge_index: Tensor of shape [2, num_edges]
            Edges of the graph, in torch-geometric's format
        :param edge_affinity_logits: Tensor of shape [num_edges]
            Predicted affinity logits (i.e. real-valued, before
            sigmoid) of each edge
        :param grid: Dict
            A dictionary containing settings for grid-searching optimal
            partition parameters

        :return: obj_index: Tensor of shape [num_nodes], or List
            Indicates which predicted instance each node belongs to. If
            a grid is passed as input, a list of (settings, obj_index)
            tuples is returned instead, one per parameter combination
        """

        # If a parameter grid is provided, grid-search the partition
        # settings and return one result per parameter combination
        if grid is not None and len(grid) > 0:
            return self._grid_forward(
                batch,
                node_x,
                node_logits,
                stuff_classes,
                node_size,
                edge_index,
                edge_affinity_logits,
                grid)

        # Actual partition, computed with the cut-pursuit algorithm
        return instance_cut_pursuit(
            batch,
            node_x,
            node_logits,
            stuff_classes,
            node_size,
            edge_index,
            edge_affinity_logits,
            loss_type=self.loss_type,
            regularization=self.regularization,
            x_weight=self.x_weight,
            p_weight=self.p_weight,
            cutoff=self.cutoff,
            parallel=self.parallel,
            iterations=self.iterations,
            trim=self.trim,
            discrepancy_epsilon=self.discrepancy_epsilon,
            temperature=self.temperature,
            dampening=self.dampening)

    def _grid_forward(
            self,
            batch,
            node_x,
            node_logits,
            stuff_classes,
            node_size,
            edge_index,
            edge_affinity_logits,
            grid):
        """Run multiple forward calls for grid-searching optimal
        settings. Returns a list of (settings, obj_index) tuples, one
        per parameter combination in the grid.
        """
        # Make sure each grid key matches an existing attribute
        keys = list(grid.keys())
        for k in keys:
            if k not in self.__dict__:
                raise ValueError(
                    f"'{k}' is not a {self.__class__.__name__} attribute")

        # Backup the attributes before the grid search modifies them
        attr_bckp = copy(self.__dict__)
|
        # Iterate over all combinations of the grid values and compute
        # the corresponding partition
        grid_outputs = []
        for values in product(*grid.values()):

            # Set the attributes for the current combination
            for k, v in zip(keys, values):
                setattr(self, k, v)

            # Compute the partition with the current settings
            obj_index = self.forward(
                batch,
                node_x,
                node_logits,
                stuff_classes,
                node_size,
                edge_index,
                edge_affinity_logits,
                grid=None)

            # Store the settings along with the resulting partition
            # index, as a (Dict, Tensor) tuple
            grid_outputs.append((dict(zip(keys, values)), obj_index))

        # Restore the attributes to their original values
        for k, v in attr_bckp.items():
            setattr(self, k, v)

        return grid_outputs

    def extra_repr(self) -> str:
        keys = [
            'regularization',
            'x_weight',
            'cutoff',
            'parallel',
            'iterations',
            'trim',
            'discrepancy_epsilon']
        return ', '.join([f'{k}={getattr(self, k)}' for k in keys])
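

# ---------------------------------------------------------------------------
# Minimal usage sketch, for illustration only. It assumes torch is available
# and that the parallel cut-pursuit dependency wrapped by
# `instance_cut_pursuit` is installed; the tensors below are synthetic toy
# values, not meaningful data.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch

    num_nodes, num_dim, num_classes = 6, 3, 4

    partitioner = InstancePartitioner()

    # A single batch item with random node embeddings, logits and unit sizes
    batch = torch.zeros(num_nodes, dtype=torch.long)
    node_x = torch.rand(num_nodes, num_dim)
    node_logits = torch.rand(num_nodes, num_classes)
    node_size = torch.ones(num_nodes)

    # A simple chain graph with random edge affinity logits
    edge_index = torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]])
    edge_affinity_logits = torch.randn(num_nodes - 1)

    # Class 0 is arbitrarily treated as 'stuff' for the sake of the example
    obj_index = partitioner(
        batch,
        node_x,
        node_logits,
        [0],
        node_size,
        edge_index,
        edge_affinity_logits)
    print(obj_index)

    # Grid-searching partition settings: each (settings, obj_index) tuple
    # holds one parameter combination and the corresponding partition
    grid = {'regularization': [1, 10], 'x_weight': [1e-3, 1e-2]}
    for settings, index in partitioner(
            batch,
            node_x,
            node_logits,
            [0],
            node_size,
            edge_index,
            edge_affinity_logits,
            grid=grid):
        print(settings, index)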