# NOTE: file retrieved from a Hugging Face upload (huggingface_hub, commit 26225c5).
from torch import nn
from copy import copy
from itertools import product
from src.utils.instance import instance_cut_pursuit
__all__ = ['InstancePartitioner']
class InstancePartitioner(nn.Module):
    """Partition a graph into instances using cut-pursuit.

    More specifically, this step will group nodes together based on:
      - node offset position
      - node predicted classification logits
      - node size
      - edge affinity

    NB: This operation relies on the parallel cut-pursuit algorithm:
    https://gitlab.com/1a7r0ch3/parallel-cut-pursuit
    Currently, this implementation is non-differentiable and runs on
    CPU.

    :param loss_type: str
        Rules the loss applied on the node features. Accepts one of
        'l2' (L2 loss on node features and probabilities),
        'l2_kl' (L2 loss on node features and Kullback-Leibler
        divergence on node probabilities)
    :param regularization: float
        Regularization parameter for the partition
    :param x_weight: float
        Weight used to mitigate the impact of the node position in the
        partition. The larger, the less spatial coordinates matter
    :param p_weight: float
        Weight used to mitigate the impact of the node probabilities in
        the partition. The larger, the greater the impact
    :param cutoff: float
        Minimum number of points in each cluster
    :param parallel: bool
        Whether cut-pursuit should run in parallel
    :param iterations: int
        Maximum number of iterations for each partition
    :param trim: bool
        Whether the input graph should be trimmed. See `to_trimmed()`
        documentation for more details on this operation
    :param discrepancy_epsilon: float
        Mitigates the maximum discrepancy. More precisely:
        `affinity=1 ⇒ discrepancy=1/discrepancy_epsilon`
    :param temperature: float
        Temperature used in the softmax when converting node logits to
        probabilities
    :param dampening: float
        Dampening applied to the node probabilities to mitigate the
        impact of near-zero probabilities in the Kullback-Leibler
        divergence
    """

    def __init__(
            self,
            loss_type='l2_kl',
            regularization=10,
            x_weight=1e-2,
            p_weight=1,
            cutoff=1,
            parallel=True,
            iterations=10,
            trim=False,
            discrepancy_epsilon=1e-4,
            temperature=1,
            dampening=0):
        super().__init__()
        # All hyperparameters are stored as plain attributes so that
        # `_grid_forward()` can temporarily override them by name
        self.loss_type = loss_type
        self.regularization = regularization
        self.x_weight = x_weight
        self.p_weight = p_weight
        self.cutoff = cutoff
        self.parallel = parallel
        self.iterations = iterations
        self.trim = trim
        self.discrepancy_epsilon = discrepancy_epsilon
        self.temperature = temperature
        self.dampening = dampening

    def forward(
            self,
            batch,
            node_x,
            node_logits,
            stuff_classes,
            node_size,
            edge_index,
            edge_affinity_logits,
            grid=None):
        """The forward step will compute the partition on the instance
        graph, based on the node features, node logits, and edge
        affinities. The partition segments will then be further merged
        so that there is at most one instance of each stuff class per
        batch item (ie per scene).

        :param batch: Tensor of shape [num_nodes]
            Batch index of each node
        :param node_x: Tensor of shape [num_nodes, num_dim]
            Predicted node embeddings
        :param node_logits: Tensor of shape [num_nodes, num_classes]
            Predicted classification logits for each node
        :param stuff_classes: List or Tensor
            List of 'stuff' class labels. These are used for merging
            stuff segments together to ensure there is at most one
            predicted instance of each 'stuff' class per batch item
        :param node_size: Tensor of shape [num_nodes]
            Size of each node
        :param edge_index: Tensor of shape [2, num_edges]
            Edges of the graph, in torch-geometric's format
        :param edge_affinity_logits: Tensor of shape [num_edges]
            Predicted affinity logits (ie in R+, before sigmoid) of each
            edge
        :param grid: Dict
            A dictionary containing settings for grid-searching optimal
            partition parameters

        :return: obj_index: Tensor of shape [num_nodes] (or List(Dict, Tensor))
            Indicates which predicted instance each node belongs to. If
            a grid is passed as input, a list containing partition
            settings and partition index tensors will be returned
        """
        # If a grid is passed, multiple partitions will be computed on
        # the parameter grid
        if grid is not None and len(grid) > 0:
            return self._grid_forward(
                batch,
                node_x,
                node_logits,
                stuff_classes,
                node_size,
                edge_index,
                edge_affinity_logits,
                grid)

        # If not grid searching optimal partition parameters, simply run
        # the partition with the current parameters
        return instance_cut_pursuit(
            batch,
            node_x,
            node_logits,
            stuff_classes,
            node_size,
            edge_index,
            edge_affinity_logits,
            loss_type=self.loss_type,
            regularization=self.regularization,
            x_weight=self.x_weight,
            p_weight=self.p_weight,
            cutoff=self.cutoff,
            parallel=self.parallel,
            iterations=self.iterations,
            trim=self.trim,
            discrepancy_epsilon=self.discrepancy_epsilon,
            temperature=self.temperature,
            dampening=self.dampening)

    def _grid_forward(
            self,
            batch,
            node_x,
            node_logits,
            stuff_classes,
            node_size,
            edge_index,
            edge_affinity_logits,
            grid):
        """Run multiple forward calls for grid-searching optimal
        settings.

        :param grid: Dict
            Maps attribute names of `self` to iterables of candidate
            values. One partition is computed for each combination in
            the Cartesian product of the candidate sets

        :return: List of (Dict, Tensor) tuples, one per grid setting.
            The Dict holds the attribute values used, the Tensor is the
            corresponding partition index
        :raises ValueError: if a grid key is not an attribute of `self`
        """
        # If a grid dictionary was passed, make sure all keys in the
        # grid are supported attributes
        keys = list(grid.keys())
        for k in keys:
            if k not in self.__dict__:
                raise ValueError(
                    f"'{k}' is not {self.__class__.__name__} attribute")

        # Backup the current attributes, to be restored after the grid
        # search so this call leaves `self` unchanged
        attr_bckp = copy(self.__dict__)

        # Compute the grid search on the Cartesian product of the sets
        # of explored values
        grid_outputs = []
        for values in product(*grid.values()):

            # Update self attributes with grid values
            for k, v in zip(keys, values):
                setattr(self, k, v)

            # Compute the partition
            obj_index = self.forward(
                batch,
                node_x,
                node_logits,
                stuff_classes,
                node_size,
                edge_index,
                edge_affinity_logits,
                grid=None)

            # Store the partition index for the current settings. The
            # results are stored in a tuple whose first element is a
            # dictionary of settings for self, and the second is the
            # output partition index.
            # NB: fix over previous version which passed two arguments
            # to `list.append()` and raised a TypeError
            grid_outputs.append((dict(zip(keys, values)), obj_index))

        # Restore the initial attributes
        for k, v in attr_bckp.items():
            setattr(self, k, v)

        return grid_outputs

    def extra_repr(self) -> str:
        keys = [
            'regularization',
            'x_weight',
            'cutoff',
            'parallel',
            'iterations',
            'trim',
            'discrepancy_epsilon']
        return ', '.join([f'{k}={getattr(self, k)}' for k in keys])