from torch import nn
from copy import copy
from itertools import product
from src.utils.instance import instance_cut_pursuit


__all__ = ['InstancePartitioner']


class InstancePartitioner(nn.Module):
    """Partition a graph into instances using cut-pursuit.
    More specifically, this step will group nodes together based on:
        - node offset position
        - node predicted classification logits
        - node size
        - edge affinity

    NB: This operation relies on the parallel cut-pursuit algorithm:
        https://gitlab.com/1a7r0ch3/parallel-cut-pursuit
        Currently, this implementation is non-differentiable and runs on
        CPU.

    :param loss_type: str
        Rules the loss applied on the node features. Accepts one of
        'l2' (L2 loss on node features and probabilities),
        'l2_kl' (L2 loss on node features and Kullback-Leibler
        divergence on node probabilities)
    :param regularization: float
        Regularization parameter for the partition
    :param x_weight: float
        Weight used to mitigate the impact of the node position in the
        partition. The larger, the less spatial coordinates matter
    :param p_weight: float
        Weight used to mitigate the impact of the node probabilities in
        the partition. The larger, the greater the impact
    :param cutoff: float
        Minimum number of points in each cluster
    :param parallel: bool
        Whether cut-pursuit should run in parallel
    :param iterations: int
        Maximum number of iterations for each partition
    :param trim: bool
        Whether the input graph should be trimmed. See `to_trimmed()`
        documentation for more details on this operation
    :param discrepancy_epsilon: float
        Mitigates the maximum discrepancy. More precisely:
        `affinity=1 ⇒ discrepancy=1/discrepancy_epsilon`
    :param temperature: float
        Temperature used in the softmax when converting node logits to
        probabilities
    :param dampening: float
        Dampening applied to the node probabilities to mitigate the
        impact of near-zero probabilities in the Kullback-Leibler
        divergence
    """

    def __init__(
            self,
            loss_type='l2_kl',
            regularization=10,
            x_weight=1e-2,
            p_weight=1,
            cutoff=1,
            parallel=True,
            iterations=10,
            trim=False,
            discrepancy_epsilon=1e-4,
            temperature=1,
            dampening=0):
        super().__init__()
        # All settings are stored as plain attributes, which is what
        # allows `_grid_forward()` to temporarily override them
        self.loss_type = loss_type
        self.regularization = regularization
        self.x_weight = x_weight
        self.p_weight = p_weight
        self.cutoff = cutoff
        self.parallel = parallel
        self.iterations = iterations
        self.trim = trim
        self.discrepancy_epsilon = discrepancy_epsilon
        self.temperature = temperature
        self.dampening = dampening

    def forward(
            self,
            batch,
            node_x,
            node_logits,
            stuff_classes,
            node_size,
            edge_index,
            edge_affinity_logits,
            grid=None):
        """The forward step will compute the partition on the instance
        graph, based on the node features, node logits, and edge
        affinities. The partition segments will then be further merged
        so that there is at most one instance of each stuff class per
        batch item (ie per scene).

        :param batch: Tensor of shape [num_nodes]
            Batch index of each node
        :param node_x: Tensor of shape [num_nodes, num_dim]
            Predicted node embeddings
        :param node_logits: Tensor of shape [num_nodes, num_classes]
            Predicted classification logits for each node
        :param stuff_classes: List or Tensor
            List of 'stuff' class labels. These are used for merging
            stuff segments together to ensure there is at most one
            predicted instance of each 'stuff' class per batch item
        :param node_size: Tensor of shape [num_nodes]
            Size of each node
        :param edge_index: Tensor of shape [2, num_edges]
            Edges of the graph, in torch-geometric's format
        :param edge_affinity_logits: Tensor of shape [num_edges]
            Predicted affinity logits (ie in R+, before sigmoid) of each
            edge
        :param grid: Dict
            A dictionary containing settings for grid-searching optimal
            partition parameters. Each key must be an attribute of this
            module and each value an iterable of candidate values

        :return: obj_index: Tensor of shape [num_nodes]
            Indicates which predicted instance each node belongs to. If
            a non-empty grid is passed as input, a list of
            `(settings, obj_index)` tuples is returned instead, with one
            tuple per combination of grid values. `settings` is the
            dictionary of values used for that partition
        """
        # If grid is passed, multiple partition will be computed on the
        # parameter grid
        if grid is not None and len(grid) > 0:
            return self._grid_forward(
                batch,
                node_x,
                node_logits,
                stuff_classes,
                node_size,
                edge_index,
                edge_affinity_logits,
                grid)

        # If not grid searching optimal partition parameters, simply run
        # the partition with the current parameters
        return instance_cut_pursuit(
            batch,
            node_x,
            node_logits,
            stuff_classes,
            node_size,
            edge_index,
            edge_affinity_logits,
            loss_type=self.loss_type,
            regularization=self.regularization,
            x_weight=self.x_weight,
            p_weight=self.p_weight,
            cutoff=self.cutoff,
            parallel=self.parallel,
            iterations=self.iterations,
            trim=self.trim,
            discrepancy_epsilon=self.discrepancy_epsilon,
            temperature=self.temperature,
            dampening=self.dampening)

    def _grid_forward(
            self,
            batch,
            node_x,
            node_logits,
            stuff_classes,
            node_size,
            edge_index,
            edge_affinity_logits,
            grid):
        """Run multiple forward calls for grid-searching optimal
        settings.

        The partition is computed once for each element of the Cartesian
        product of the values in `grid`. The attributes of `self` are
        temporarily overridden for each run and restored to their
        initial values before returning, even if a partition fails.

        :param grid: Dict
            Maps attribute names of `self` to iterables of values to
            explore. The other parameters are forwarded to `forward()`
            unchanged

        :return: List of `(settings, obj_index)` tuples, where
            `settings` is the dictionary of grid values used for the
            partition and `obj_index` is the corresponding output of
            `forward()`

        :raises ValueError: if a key of `grid` is not an attribute of
            `self`
        """
        # If a grid dictionary was passed, make sure all keys in the
        # grid are supported attributes
        keys = list(grid.keys())
        for k in keys:
            if k not in self.__dict__:
                raise ValueError(
                    f"'{k}' is not {self.__class__.__name__} attribute")

        # Backup the current attributes
        attr_bckp = copy(self.__dict__)

        # Compute the grid search on the Cartesian product of the sets
        # of explored values
        grid_outputs = []
        try:
            for values in product(*grid.values()):
                settings = dict(zip(keys, values))

                # Update self attributes with grid values
                for k, v in settings.items():
                    setattr(self, k, v)

                # Compute the partition
                obj_index = self.forward(
                    batch,
                    node_x,
                    node_logits,
                    stuff_classes,
                    node_size,
                    edge_index,
                    edge_affinity_logits,
                    grid=None)

                # Store the partition index for the current settings.
                # The results are stored in a tuple whose first element
                # is a dictionary of settings for self, and the second
                # is the output partition index. NB: `list.append()`
                # takes a single argument, hence the explicit tuple
                grid_outputs.append((settings, obj_index))
        finally:
            # Restore the initial value of the attributes overridden by
            # the grid. This runs even if a partition raised an error,
            # so that the module is never left in a modified state
            for k in keys:
                setattr(self, k, attr_bckp[k])

        return grid_outputs

    def extra_repr(self) -> str:
        """Return the comma-separated `key=value` settings displayed
        when printing the module.
        """
        keys = [
            'regularization',
            'x_weight',
            'cutoff',
            'parallel',
            'iterations',
            'trim',
            'discrepancy_epsilon']
        return ', '.join([f'{k}={getattr(self, k)}' for k in keys])