File size: 23,692 Bytes
26225c5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 
525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 | import torch
import logging
from torch_scatter import scatter_mean
from src.utils.scatter import scatter_mean_weighted
from src.utils.output_semantic import SemanticSegmentationOutput
log = logging.getLogger(__name__)
__all__ = ['PanopticSegmentationOutput', 'PartitionParameterSearchStorage']
class PanopticSegmentationOutput(SemanticSegmentationOutput):
    """A simple holder for panoptic segmentation model output, with a
    few helper methods for manipulating the predictions and targets
    (if any).
    """

    def __init__(
            self,
            logits,
            stuff_classes,
            edge_affinity_logits,
            node_size,
            y_hist=None,
            obj=None,
            obj_edge_index=None,
            obj_edge_affinity=None,
            pos=None,
            obj_pos=None,
            obj_index_pred=None,
            semantic_loss=None,
            edge_affinity_loss=None):
        """
        :param logits: Tensor or list of Tensors
            Semantic segmentation logits, forwarded to the parent class
        :param stuff_classes: iterable of int, or None
            Labels of the 'stuff' classes. Stored as a LongTensor on the
            same device as `edge_affinity_logits`; an empty tensor when
            None is passed
        :param edge_affinity_logits: 1D Tensor
            Raw (pre-sigmoid) affinity prediction for each edge of the
            instance graph
        :param node_size: 1D Tensor
            Size (point count) of each node/superpoint
        :param y_hist: Tensor, optional
            Semantic segmentation target, forwarded to the parent class
        :param obj: InstanceData, optional
            Target instance annotations
        :param obj_edge_index: 2xE LongTensor, optional
            Edges of the instance graph
        :param obj_edge_affinity: 1D Tensor, optional
            Target affinity for each edge of the instance graph
        :param pos: Tensor, optional
            Node positions
        :param obj_pos: Tensor, optional
            Target instance center positions
        :param obj_index_pred: LongTensor or list, optional
            Predicted instance index for each node. May also be a list
            of (settings-dict, index-tensor) pairs when multiple
            partition settings are evaluated at once (see
            `has_multi_instance_pred`)
        :param semantic_loss: Tensor, optional
            Semantic segmentation loss
        :param edge_affinity_loss: Tensor, optional
            Edge affinity loss
        """
        # We set the child class attributes before calling the parent
        # class constructor, because the parent constructor calls
        # `self.debug()`, which needs all attributes to be initialized
        device = edge_affinity_logits.device
        self.stuff_classes = torch.tensor(stuff_classes, device=device).long() \
            if stuff_classes is not None \
            else torch.empty(0, device=device).long()
        self.edge_affinity_logits = edge_affinity_logits
        self.node_size = node_size
        self.obj = obj
        self.obj_edge_index = obj_edge_index
        self.obj_edge_affinity = obj_edge_affinity
        self.pos = pos
        self.obj_pos = obj_pos
        self.obj_index_pred = obj_index_pred
        self.semantic_loss = semantic_loss
        self.edge_affinity_loss = edge_affinity_loss
        super().__init__(logits, y_hist=y_hist)

    def debug(self):
        """Sanity-check the shapes and types of the stored predictions
        and targets. Raises AssertionError on any inconsistency.
        """
        # Parent class debugger
        super().debug()

        # Instance predictions: one affinity logit per edge
        assert self.edge_affinity_logits.dim() == 1

        # Node properties
        assert self.node_size.dim() == 1
        assert self.node_size.shape[0] == self.num_nodes

        if self.has_instance_pred:
            if not self.has_multi_instance_pred:
                assert self.obj_index_pred.dim() == 1
                assert self.obj_index_pred.shape[0] == self.num_nodes
            else:
                # Multi-setting predictions come as a list of
                # (settings-dict, index-tensor) pairs, one per
                # partition parameterization
                assert isinstance(self.obj_index_pred, list)
                item = self.obj_index_pred[0]
                assert isinstance(item[0], dict)
                assert isinstance(item[1], torch.Tensor)
                assert item[1].dim() == 1
                assert item[1].shape[0] == self.num_nodes

        # Instance target: either all target items are provided, or none
        items = [
            self.obj_edge_index, self.obj_edge_affinity, self.pos, self.obj_pos]
        without_instance_target = all(x is None for x in items)
        with_instance_target = all(x is not None for x in items)
        assert without_instance_target or with_instance_target
        if without_instance_target:
            return

        # Local import to avoid import loop errors
        from src.data import InstanceData
        assert isinstance(self.obj, InstanceData)
        assert self.obj.num_clusters == self.num_nodes
        assert self.obj_edge_index.dim() == 2
        assert self.obj_edge_index.shape[0] == 2
        assert self.obj_edge_index.shape[1] == self.num_edges
        assert self.obj_edge_affinity.dim() == 1
        assert self.obj_edge_affinity.shape[0] == self.num_edges

    @property
    def has_target(self):
        """Check whether `self` contains target data for panoptic
        segmentation.
        """
        items = [
            self.obj,
            self.obj_edge_index,
            self.obj_edge_affinity,
            self.pos,
            self.obj_pos]
        return super().has_target and all(x is not None for x in items)

    @property
    def has_instance_pred(self):
        """Check whether `self` contains predicted data for panoptic
        segmentation `obj_index_pred`.
        """
        return self.obj_index_pred is not None

    @property
    def has_multi_instance_pred(self):
        """Check whether `self` contains predicted data for panoptic
        segmentation `obj_index_pred` as a list of results for
        performance comparison of partition settings.
        """
        return self.has_instance_pred \
            and not isinstance(self.obj_index_pred, torch.Tensor)

    @property
    def num_edges(self):
        """Number of edges in the instance graph.
        """
        # `edge_affinity_logits` holds one logit per edge and is 1D
        # (see `self.debug()`), so the edge count is its length.
        # NB: indexing `shape[1]` here would raise an IndexError on a
        # 1D tensor
        return self.edge_affinity_logits.shape[0]

    @property
    def edge_affinity_pred(self):
        """Simply applies a sigmoid on `edge_affinity_logits` to produce
        the actual affinity predictions to be used for superpoint
        graph clustering.
        """
        return self.edge_affinity_logits.sigmoid()

    @property
    def void_edge_mask(self):
        """Returns a mask on the edges indicating those connecting two
        void nodes.
        """
        if not self.has_target:
            return

        # An edge is void only if BOTH of its endpoints are void
        mask = self.void_mask[self.obj_edge_index]
        return mask[0] & mask[1]

    def sanitized_edge_affinities(self):
        """Return the predicted and target edge affinities, along with
        masks indicating same-class and same-object edges. The output is
        sanitized for edge affinity loss and metrics computation.

        We return the edge affinity logits to the criterion and not
        the actual sigmoid-normalized predictions used for graph
        clustering. The reason for this is that we expect the edge
        affinity loss to be computed using `BCEWithLogitsLoss`.

        We choose to exclude edges connecting nodes/superpoints with
        more than 50% 'void' points from edge affinity loss and metrics
        computation. This is what the sanitization step consists in.

        To this end, the present function does the following:
          - remove predicted and target edges connecting two 'void'
            nodes (see `self.void_edge_mask`)
        """
        # Identify the sanitized edges
        idx = torch.where(~self.void_edge_mask)[0]

        # Compute the boolean masks indicating same-class and
        # same-object edges. These can be useful for losses with more
        # weights on hard edges
        obj, _, y = self.obj.major(num_classes=self.num_classes)
        is_same_class = y[self.obj_edge_index[0]] == y[self.obj_edge_index[1]]
        is_same_obj = obj[self.obj_edge_index[0]] == obj[self.obj_edge_index[1]]

        # Return sanitized predicted and target affinities, as well as
        # edge masks
        return self.edge_affinity_logits[idx], self.obj_edge_affinity[idx], \
            is_same_class[idx], is_same_obj[idx]

    def weighted_instance_semantic_pred(self):
        """Compute the predicted semantic label, score and logits for
        each predicted instance. This involves computing, for each
        predicted instance, the weighted average of the logits of the
        superpoints it contains.
        """
        if not self.has_instance_pred:
            return None, None, None

        # Compute the mean logits for each predicted object, weighted by
        # the node sizes
        node_logits = self.logits[0] if self.multi_stage else self.logits
        obj_logits = scatter_mean_weighted(
            node_logits, self.obj_index_pred, self.node_size)

        # Compute the predicted semantic label and proba for each node
        obj_semantic_score, obj_y = obj_logits.softmax(dim=1).max(dim=1)

        return obj_y, obj_semantic_score, obj_logits

    def panoptic_pred(self):
        """Panoptic predictions on the level-1 superpoints.

        Return the predicted semantic score and label for each predicted
        instance, along with the InstanceData object summarizing
        predictions.
        """
        if not self.has_instance_pred:
            return None, None, None

        # Merge the InstanceData based on the predicted instances and
        # target instances
        instance_data = self.obj.merge(self.obj_index_pred) if self.has_target \
            else None

        # Compute the semantic prediction for each predicted object,
        # weighted by the node sizes
        obj_y, obj_semantic_score, obj_logits = \
            self.weighted_instance_semantic_pred()

        # TODO: should we take object size into account in the scoring ?

        # Compute, for each predicted object, the mean inter-object and
        # intra-object predicted edge affinity. Each edge contributes to
        # both of its endpoint objects, hence the flatten + repeat
        ie = self.obj_index_pred[self.obj_edge_index]
        intra = ie[0] == ie[1]
        idx = ie.flatten()
        intra = intra.repeat(2)
        a = self.edge_affinity_pred.repeat(2)
        # Cast to a python int: torch_scatter expects `dim_size` to be
        # an int, not a 0-dim tensor
        n = int(self.obj_index_pred.max()) + 1
        obj_mean_intra = scatter_mean(a[intra], idx[intra], dim_size=n)
        obj_mean_inter = scatter_mean(a[~intra], idx[~intra], dim_size=n)

        # Compute the inter-object and intra-object scores
        obj_intra_score = obj_mean_intra
        obj_inter_score = 1 / (1 + obj_mean_inter)

        # Final prediction score. For now only the semantic score is
        # used; the intra/inter-object affinity scores computed above
        # could be multiplied in if desired:
        # obj_score = obj_semantic_score * obj_intra_score * obj_inter_score
        obj_score = obj_semantic_score

        return obj_score, obj_y, instance_data

    def superpoint_panoptic_pred(self):
        """Panoptic predictions on the level-1 nodes. Returns the
        predicted semantic label and instance index for each superpoint,
        along with the voxel-wise InstanceData summarizing predictions.

        Note this differs from `self.panoptic_pred()` which returns
        scores, semantic labels, and InstanceData objects with respect
        to the predicted instances, and not to the superpoint
        themselves.

        Final panoptic segmentation predictions are computed with
        respect to predicted instances, after level-1 superpoint-graph
        clustering.

        The predicted instance semantic labels are computed from the
        average of logits of level-1 superpoints they include, weighted
        by the superpoint sizes. These instance-aggregated semantic
        predictions may (slightly) differ from the per-superpoint
        semantic segmentation prediction obtained from
        `self.semantic_pred()`.
        """
        # Compute the semantic prediction for each predicted object,
        # weighted by the node sizes
        obj_y, _, _ = self.weighted_instance_semantic_pred()

        # Distribute the per-instance predictions to level-1 superpoints
        sp_y = obj_y[self.obj_index_pred]

        # Local import to avoid import loop errors
        from src.data import InstanceData

        # Compute the superpoint-wise InstanceData carrying predictions
        sp_obj_pred = InstanceData(
            torch.arange(self.num_nodes, device=self.device),
            self.obj_index_pred,
            self.node_size,
            sp_y,
            dense=True)

        return sp_y, self.obj_index_pred, sp_obj_pred

    def voxel_panoptic_pred(self, super_index=None, sub=None):
        """Panoptic predictions on the level-0 voxels. Returns the
        predicted semantic label and instance index for each voxel,
        along with the voxel-wise InstanceData summarizing predictions.

        Final panoptic segmentation predictions are computed with
        respect to predicted instances, after level-1 superpoint-graph
        clustering.

        The predicted instance semantic labels are computed from the
        average of logits of level-1 superpoints they include, weighted
        by the superpoint sizes. These instance-aggregated semantic
        predictions may (slightly) differ from the per-superpoint
        semantic segmentation prediction obtained from
        `self.voxel_semantic_pred()`.

        This function then distributes semantic and instance index
        predictions to each level-0 point (ie voxel in our framework).

        :param super_index: LongTensor
            Tensor holding, for each level-0 point (ie voxel), the index
            of the level-1 superpoint it belongs to
        :param sub: Cluster
            Cluster object indicating, for each level-1 superpoint,
            the indices of the level-0 points (ie voxels) it contains
        """
        assert super_index is not None or sub is not None, \
            "Must provide either `super_index` or `sub`"

        # If super_index is not provided, build it from sub
        if super_index is None:
            super_index = sub.to_super_index()

        # Compute the semantic prediction for each predicted object,
        # weighted by the node sizes
        obj_y, _, _ = self.weighted_instance_semantic_pred()

        # Distribute the per-instance predictions to level-1 superpoints
        sp_y = obj_y[self.obj_index_pred]

        # Distribute the level-1 superpoint semantic predictions and
        # instance indices to the voxels
        vox_y = sp_y[super_index]
        vox_index = self.obj_index_pred[super_index]

        # Local import to avoid import loop errors
        from src.data import InstanceData

        # Compute the voxel-wise InstanceData carrying voxel predictions
        # NB: we make an approximation here: each voxel is given a count
        # of 1 point, neglecting the actual number of points in each
        # voxel. This may slightly affect the metrics, compared to
        # the true full-resolution predictions
        num_voxels = super_index.shape[0]
        vox_obj_pred = InstanceData(
            torch.arange(num_voxels, device=self.device),
            vox_index,
            torch.ones(num_voxels, device=self.device, dtype=torch.long),
            vox_y,
            dense=True)

        return vox_y, vox_index, vox_obj_pred

    def full_res_panoptic_pred(
            self,
            super_index_level0_to_level1=None,
            super_index_raw_to_level0=None,
            sub_level1_to_level0=None,
            sub_level0_to_raw=None):
        """Panoptic predictions on the full-resolution input point
        cloud. Returns the predicted semantic label and instance index
        for each point, along with the point-wise InstanceData
        summarizing predictions.

        Final panoptic segmentation predictions are computed with
        respect to predicted instances, after level-1 superpoint-graph
        clustering.

        The predicted instance semantic labels are computed from the
        average of logits of level-1 superpoints they include, weighted
        by the superpoint sizes. These instance-aggregated semantic
        predictions may (slightly) differ from the per-superpoint
        semantic segmentation prediction obtained from
        `self.full_res_semantic_pred()`.

        This function then distributes these predictions to each raw
        point (ie full-resolution point cloud before voxelization in our
        framework).

        :param super_index_level0_to_level1: LongTensor
            Tensor holding, for each level-0 point (ie voxel), the index
            of the level-1 superpoint it belongs to
        :param super_index_raw_to_level0: LongTensor
            Tensor holding, for each raw full-resolution point, the
            index of the level-0 point (ie voxel) it belongs to
        :param sub_level1_to_level0: Cluster
            Cluster object indicating, for each level-1 superpoint,
            the indices of the level-0 points (ie voxels) it contains
        :param sub_level0_to_raw: Cluster
            Cluster object indicating, for each level-0 point (ie
            voxel), the indices of the raw full-resolution points it
            contains
        """
        assert super_index_level0_to_level1 is not None or sub_level1_to_level0 is not None, \
            "Must provide either `super_index_level0_to_level1` or `sub_level1_to_level0`"
        assert super_index_raw_to_level0 is not None or sub_level0_to_raw is not None, \
            "Must provide either `super_index_raw_to_level0` or `sub_level0_to_raw`"

        # If super_index are not provided, build them from sub
        if super_index_level0_to_level1 is None:
            super_index_level0_to_level1 = sub_level1_to_level0.to_super_index()
        if super_index_raw_to_level0 is None:
            super_index_raw_to_level0 = sub_level0_to_raw.to_super_index()

        # Distribute the level-1 superpoint semantic predictions and
        # instance indices to the voxels
        vox_y, vox_index, vox_obj_pred = self.voxel_panoptic_pred(
            super_index=super_index_level0_to_level1)

        # Distribute the level-1 superpoint predictions to the
        # full-resolution points
        raw_y = vox_y[super_index_raw_to_level0]
        raw_index = vox_index[super_index_raw_to_level0]

        # Local import to avoid import loop errors
        from src.data import InstanceData

        # Compute the point-wise InstanceData carrying predictions
        # NB: we make an approximation here: each point is given a count
        # of 1, neglecting the actual voxel sizes. This may slightly
        # affect the metrics, compared to the true full-resolution
        # predictions
        num_points = super_index_raw_to_level0.shape[0]
        raw_obj_pred = InstanceData(
            torch.arange(num_points, device=self.device),
            raw_index,
            torch.ones(num_points, device=self.device, dtype=torch.long),
            raw_y,
            dense=True)

        return raw_y, raw_index, raw_obj_pred
class PartitionParameterSearchStorage:
    """A class to hold the output results of multiple partitions, when
    searching for the optimal partition parameter settings. Since
    metrics are only computed at the end of an epoch, we cannot compute
    the optimal parameter settings at each batch. On the other hand, we
    cannot store the whole content of the `PanopticSegmentationOutput`
    of each batch. This holder is used to store the strict necessary
    from the `PanopticSegmentationOutput` of each batch, to be able to
    call `PanopticSegmentationOutput.panoptic_pred()` at
    the end of an epoch and pass its output to an instance or panoptic
    segmentation metric object.

    NB: make sure the input is detached and on CPU, you do not want to
    blow up your GPU memory. Still, for very large datasets, this
    approach will be RAM-hungry. If this causes CPU memory errors, you
    will need to save your predicted data in temp files on disk.
    """

    def __init__(
            self,
            logits,
            stuff_classes,
            node_size,
            edge_affinity_logits,
            obj,
            obj_index_pred):
        # Keep only what panoptic_pred() needs to rebuild a
        # PanopticSegmentationOutput for any stored partition setting
        self.logits = logits
        self.stuff_classes = stuff_classes
        self.node_size = node_size
        self.edge_affinity_logits = edge_affinity_logits
        self.obj = obj
        self.obj_index_pred = obj_index_pred

    @property
    def settings(self):
        """This assumes all items in `self.obj_index_pred` follow the
        output format of `InstancePartitioner._grid_forward()`.
        """
        return [pair[0] for pair in self.obj_index_pred]

    @property
    def num_settings(self):
        """This assumes all items in `self.obj_index_pred` follow the
        output format of `InstancePartitioner._grid_forward()`.
        """
        return len(self.settings)

    def panoptic_pred(self, setting):
        """Return the predicted InstanceData, and the predicted instance
        semantic label and score, for a given batch item and a given
        partition setting.
        """
        # Recover the index of the setting in the stored results. An
        # int is treated as a direct index, anything else is searched
        # for among the stored settings
        if isinstance(setting, int):
            i_setting = setting
        else:
            i_setting = self.settings.index(setting)

        # Recover the batch's partition results
        output = PanopticSegmentationOutput(
            self.logits,
            self.stuff_classes,
            self.edge_affinity_logits,
            self.node_size,
            obj=self.obj,
            obj_index_pred=self.obj_index_pred[i_setting][1])

        # Compute inputs for an instance or panoptic segmentation metric
        return output.panoptic_pred()
|