File size: 9,585 Bytes
26225c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 |
import torch
from src.data import NAG, InstanceData
from src.transforms import Transform
from src.utils import cluster_radius_nn_graph, knn_1_graph, to_trimmed
from torch_geometric.nn.pool.consecutive import consecutive_cluster
__all__ = ['NAGPropagatePointInstances', 'OnTheFlyInstanceGraph']
class NAGPropagatePointInstances(Transform):
    """Compute the instances contained in each superpoint of each level,
    provided that the first level has an 'obj' attribute holding an
    `InstanceData`.

    :param strict: bool
        If True, will raise an exception if the level-0 Data does not
        carry an `InstanceData` in its 'obj' attribute
    """
    _IN_TYPE = NAG
    _OUT_TYPE = NAG
    _NO_REPR = ['strict']

    def __init__(self, strict=False):
        self.strict = strict

    def _process(self, nag):
        # Read the instances from the first (finest) partition level.
        # NB: isinstance() already returns False on None, so no separate
        # None check is needed
        obj_0 = nag[0].obj
        if not isinstance(obj_0, InstanceData):
            if not self.strict:
                return nag
            raise ValueError("Could not find any InstanceData in `nag[0].obj`")

        # Propagate level-0 instances to each coarser level by merging
        # them along the level-0-to-i_level cluster assignment
        for i_level in range(1, nag.num_levels):
            super_index = nag.get_super_index(i_level)
            nag[i_level].obj = obj_0.merge(super_index)

        return nag
class OnTheFlyInstanceGraph(Transform):
    """Compute the non-oriented graph used for instance and panoptic
    segmentation.

    We choose the following assignment rule:
      - each superpoint is assigned to the instance it overlaps the most

    Importantly, one could think of other assignment rules, such as the
    maximum IoU, for instance. But the latter would not work for
    over-segmented scenes with small superpoints and very large 'stuff'
    instances. We favor the overlap-size rule due to this scenario and
    for its simplicity.

    It is recommended to call this transform only AFTER all geometric
    transformations and sampling have been applied to the batch. This
    is typically important when using samplings of subgraphs, where the
    position of the entire clusters are maintained, even after being
    cropped. Such behavior may damage the instance centroid estimation
    step.

    :param level: int
        Partition level at which to compute the instance graph. Setting
        `level=-1` or `level=None` will skip the present Transform (can
        be useful for integrating this Transform in a pipeline and
        optionally skip it)
    :param num_classes: int
        Number of classes in the dataset. Specifying `num_classes`
        allows identifying 'void' labels. By convention, we assume
        `y ∈ [0, self.num_classes-1]` ARE ALL VALID LABELS (i.e. not
        'ignored', 'void', 'unknown', etc), while `y < 0` AND
        `y >= self.num_classes` ARE VOID LABELS. Void data is dealt
        with following https://arxiv.org/abs/1801.00868 and
        https://arxiv.org/abs/1905.01220
    :param adjacency_mode: str
        Method used to compute search for adjacent nodes. If 'available'
        the already-existing graph in the input's 'edge_index' will be
        used. If 'radius', the `radius` parameter will be used to search
        for all neighboring clusters with points within `radius` of each
        other. If 'radius-centroid', the `radius` parameter will be used
        to search for all neighboring clusters solely based on their
        centroid position. This is likely faster but less accurate than
        'radius'
    :param k_max: int
        Maximum number of neighbors per cluster if `adjacency_mode`
        calls for it
    :param radius: float
        Radius used for neighbor search if `adjacency_mode` calls for it
    :param use_batch: bool
        If True, the 'NAG[level].batch' attribute will be used to
        guide neighbor search if `adjacency_mode` calls for it. More
        specifically, if the input NAG is a NAGBatch made up of multiple
        NAGs, the neighbor search will ensure that clusters from
        different batch items cannot be neighbors. It is recommended to
        keep the default `use_batch=True`
    :param centroid_mode: str
        Method used to estimate the centroids. 'iou' will weigh down
        the centroids of the clusters overlapping each instance by
        their IoU. 'ratio-product' will use the product of the size
        ratios of the overlap wrt the cluster and wrt the instance
    :param centroid_level: int
        Partition level to use to estimate the centroids. The purer the
        partition, the better the estimation. But the larger the
        partition, the slower the computation
    :param smooth_affinity: bool
        If True, the affinity score computed for each edge will
        follow the 'smooth' formulation:
        `(overlap_i_obj_j / size_i + overlap_j_obj_i / size_j) / 2`
        for the edge `(i, j)`, where `obj_i` designates the target
        instance of `i`. If False, the affinity will be computed
        with the simpler formulation: `obj_i == obj_j`
    """
    _IN_TYPE = NAG
    _OUT_TYPE = NAG
    _ADJACENCY_MODES = ['available', 'radius', 'radius-centroid']
    _CENTROID_MODES = ['iou', 'ratio-product']

    def __init__(
            self,
            level=1,
            num_classes=None,
            adjacency_mode='radius',
            k_max=30,
            radius=1,
            use_batch=True,
            centroid_mode='iou',
            centroid_level=1,
            smooth_affinity=True):
        # Error messages name the actual offending parameter
        assert adjacency_mode.lower() in self._ADJACENCY_MODES, \
            f"Expected 'adjacency_mode' to be one of {self._ADJACENCY_MODES}"
        assert centroid_mode.lower() in self._CENTROID_MODES, \
            f"Expected 'centroid_mode' to be one of {self._CENTROID_MODES}"
        self.level = level
        self.num_classes = num_classes
        self.adjacency_mode = adjacency_mode.lower()
        self.k_max = k_max
        self.radius = radius
        self.use_batch = use_batch
        self.centroid_mode = centroid_mode.lower()
        self.centroid_level = centroid_level
        self.smooth_affinity = smooth_affinity

    def _process(self, nag):
        # Skip the transform. This mechanism can be useful for skipping
        # this Transform in a pipeline
        if self.level is None or self.level < 0:
            return nag

        data = nag[self.level]

        # Build the edges on which the graph optimization will be run
        # for instance or panoptic segmentation.
        # Use the already-existing graph in `edge_index`
        if self.adjacency_mode == 'available':
            obj_edge_index = data.edge_index

        # Compute the neighbors based on the distances between the
        # points they hold
        elif self.adjacency_mode == 'radius':
            # TODO: accelerate with subsampling ?
            super_index = nag.get_super_index(self.level, low=0)
            obj_edge_index, _ = cluster_radius_nn_graph(
                nag[0].pos,
                super_index,
                k_max=self.k_max,
                gap=self.radius,
                batch=data.batch if self.use_batch else None)

        # Compute the neighbors solely based on the clusters' centroids
        elif self.adjacency_mode == 'radius-centroid':
            obj_edge_index, _ = knn_1_graph(
                data.pos,
                self.k_max,
                r_max=self.radius,
                batch=data.batch if self.use_batch else None)

        # Guard against unforeseen modes, even though __init__ already
        # validates `adjacency_mode`
        else:
            raise NotImplementedError

        # If, for some reason, the graph is None, we convert it to an
        # empty torch_geometric-friendly `edge_index`-like format
        if obj_edge_index is None:
            obj_edge_index = torch.empty(
                2, 0, dtype=torch.long, device=data.device)

        # If the Data does not contain any InstanceData with instance
        # annotations, set the target object positions and edge
        # affinities to None and save the trimmed instance graph
        if data.obj is None:
            data.obj_edge_index = to_trimmed(obj_edge_index)
            data.obj_edge_affinity = None
            data.obj_pos = None
            nag._list[self.level] = data
            return nag

        # Compute the trimmed graph and the edge affinity scores
        data.obj_edge_index, data.obj_edge_affinity = data.obj.instance_graph(
            obj_edge_index,
            num_classes=self.num_classes,
            smooth_affinity=self.smooth_affinity)

        # Compute the superpoint target instance centroid position
        # NB: this is a proxy method assuming nag[0] is pure-enough
        i_level = min(self.centroid_level, nag.num_levels - 1)
        obj_pos, obj_idx = nag[i_level].estimate_instance_centroid(
            mode=self.centroid_mode)

        # Find the target instance for each superpoint: the instance it
        # has the biggest overlap with
        sp_obj_idx = data.obj.major(num_classes=self.num_classes)[0]

        # Recover, for each superpoint, the instance position. Since the
        # `estimate_instance_centroid()` output is sorted by increasing
        # obj indices, `consecutive_cluster()` allows us to convert
        # `obj_idx` into proper indices to gather object positions from
        # `obj_pos`
        joint_obj_idx = torch.cat((sp_obj_idx, obj_idx))
        joint_obj_idx_consec = consecutive_cluster(joint_obj_idx)[0]
        sp_obj_idx_consec = joint_obj_idx_consec[:sp_obj_idx.numel()]
        data.obj_pos = obj_pos[sp_obj_idx_consec]

        # Save in the data in the NAG structure
        nag._list[self.level] = data

        return nag
|