|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""3D stage model for Panoptic Recon 3D.""" |
|
|
|
|
|
import torch |
|
|
import MinkowskiEngine as Me |
|
|
from torch.nn import functional as F |
|
|
import torch.nn as nn |
|
|
|
|
|
from .blocks import ProjectionBlock |
|
|
from .utils.helper import retry_if_cuda_oom |
|
|
|
|
|
from .reconstruction import SparseProjection, FrustumDecoder |
|
|
from .mp_occ.occupancy_aware_lifting import OccupancyAwareLifting |
|
|
from .mp_occ.back_projection import BackProjection |
|
|
|
|
|
from .utils.sparse_tensor import \ |
|
|
to_dense, prepare_instance_masks_thicken |
|
|
from .utils.coords_transform import \ |
|
|
transform_feat3d_coordinates, fuse_sparse_tensors, generate_multiscale_feat3d |
|
|
|
|
|
class Postprocessor(nn.Module):
    """2D model postprocessor.

    Converts per-query class logits, mask logits and depth predictions into a
    per-image panoptic segmentation, per-class probability maps and a depth
    map clamped to ``[0, depth_scale]``.
    """

    def __init__(self, cfg):
        """Initialize the postprocessor.

        Args:
            cfg: Experiment config. Reads ``cfg.model.test_topk_per_image``,
                ``cfg.model.sem_seg_head.num_classes``,
                ``cfg.model.mask_former.num_object_queries``,
                ``cfg.model.object_mask_threshold``,
                ``cfg.model.overlap_threshold``, ``cfg.dataset.depth_scale``
                and ``cfg.dataset.num_thing_classes``; ``cfg.model.mode`` is
                read later in ``forward``.
        """
        super().__init__()
        self.cfg = cfg
        self.test_topk_per_image = cfg.model.test_topk_per_image
        self.num_classes = cfg.model.sem_seg_head.num_classes
        self.num_queries = cfg.model.mask_former.num_object_queries
        self.object_mask_threshold = cfg.model.object_mask_threshold
        self.overlap_threshold = cfg.model.overlap_threshold
        self.depth_scale = cfg.dataset.depth_scale
        self.num_thing_classes = cfg.dataset.num_thing_classes

    def _is_thing(self, pred_class):
        """Return True if ``pred_class`` lies in the thing-id range 1..num_thing_classes."""
        return 1 <= pred_class <= self.num_thing_classes

    def panoptic_inference(self, mask_cls, mask_pred, depth_pred):
        """Post process for panoptic segmentation.

        Args:
            mask_cls: Per-query class logits. The last column is the
                "no object" class (label index ``self.num_classes``).
            mask_pred: Per-query mask logits; the spatial size ``(H, W)`` is
                taken from its last two dims.
            depth_pred: Depth prediction; ``depth_pred[0]`` is returned.

        Returns:
            tuple: ``(panoptic_seg, depth, segments_info, sem_prob_masks)``:
                ``panoptic_seg`` is an int32 ``(H, W)`` segment-id map (0 means
                unassigned); ``depth`` is ``depth_pred[0]`` clamped to
                ``[0, depth_scale]``; ``segments_info`` is a list of
                ``{"id", "isthing", "category_id"}`` dicts; ``sem_prob_masks`` is
                a float32 ``(num_classes, H, W)`` tensor holding, per class, the
                score-weighted mask probabilities of the pixels assigned to it.
        """
        scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
        mask_pred = mask_pred.sigmoid()

        # Drop "no object" predictions and low-confidence queries.
        keep = labels.ne(self.num_classes) & (scores > self.object_mask_threshold)
        cur_scores = scores[keep]
        cur_classes = labels[keep]
        cur_masks = mask_pred[keep]
        cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks

        h, w = cur_masks.shape[-2:]
        device = cur_masks.device
        panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=device)
        sem_prob_masks = torch.zeros(
            (self.num_classes, h, w), dtype=torch.float32, device=device
        )
        segments_info = []

        # Clamp on every path so the depth contract does not depend on whether
        # any query survived the filtering above.
        depth = depth_pred[0, ...].clamp(min=0, max=self.depth_scale)

        if cur_masks.shape[0] == 0:
            return panoptic_seg, depth, segments_info, sem_prob_masks

        # Each pixel goes to the query with the highest score-weighted probability.
        cur_mask_ids = cur_prob_masks.argmax(0)
        stuff_memory_list = {}  # stuff class id -> segment id shared by that class
        stuff_mask_ids = []     # indices of accepted stuff queries
        current_segment_id = 0

        for k in range(cur_classes.shape[0]):
            pred_class = cur_classes[k].item()
            isthing = self._is_thing(pred_class)
            won_pixels = cur_mask_ids == k
            confident_pixels = cur_masks[k] >= 0.5
            mask_area = won_pixels.sum().item()
            original_area = confident_pixels.sum().item()
            mask = won_pixels & confident_pixels

            if not (mask_area > 0 and original_area > 0 and mask.sum().item() > 0):
                continue
            # Skip queries whose confident region is mostly won by other queries.
            if mask_area / original_area < self.overlap_threshold:
                continue

            if not isthing:
                # All queries of one stuff class share a single segment id.
                stuff_mask_ids.append(k)
                if pred_class in stuff_memory_list:
                    panoptic_seg[mask] = stuff_memory_list[pred_class]
                    sem_prob_masks[pred_class][mask] = cur_prob_masks[k][mask]
                    continue
                stuff_memory_list[pred_class] = current_segment_id + 1

            current_segment_id += 1
            panoptic_seg[mask] = current_segment_id
            sem_prob_masks[pred_class][mask] = cur_prob_masks[k][mask]
            segments_info.append(
                {
                    "id": current_segment_id,
                    "isthing": bool(isthing),
                    "category_id": int(pred_class),
                }
            )

        if stuff_mask_ids:
            # Pixels still unassigned go to the accepted stuff query with the
            # highest score-weighted probability at that pixel.
            stuff_ids = torch.tensor(
                stuff_mask_ids, dtype=torch.long, device=cur_prob_masks.device
            )
            cur_stuff_ids = stuff_ids[cur_prob_masks[stuff_ids].argmax(0)]
            empty_pixel_mask = panoptic_seg == 0
            for k in stuff_mask_ids:
                pred_class = cur_classes[k].item()
                mask = empty_pixel_mask & (cur_stuff_ids == k)
                panoptic_seg[mask] = stuff_memory_list[pred_class]
                sem_prob_masks[pred_class][mask] = cur_prob_masks[k][mask]

        return panoptic_seg, depth, segments_info, sem_prob_masks

    @staticmethod
    def sem_seg_postprocess(result, img_size):
        """Return semantic segmentation predictions in the original resolution.

        Crops ``result`` to ``img_size[0] x img_size[1]`` on its last two dims,
        adds a leading batch dim for ``F.interpolate`` (bilinear, to the same
        crop size) and removes it again.
        """
        height, width = img_size[0], img_size[1]
        cropped = result[:, :height, :width].expand(1, -1, -1, -1)
        resized = F.interpolate(
            cropped,
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        )
        return resized[0]

    def forward(self, outputs, orig_shape, orig_pad_shape):
        """Forward pass.

        Args:
            outputs: Dict with ``"pred_logits"``, ``"pred_masks"`` and
                ``"pred_depths"``; masks and depths are batched 4D inputs for
                ``F.interpolate``.
            orig_shape: Size each prediction is cropped to after upsampling.
            orig_pad_shape: Its last two entries give the upsampling size.

        Returns:
            list[dict]: One dict per image with ``"panoptic_seg"`` (a
            ``(panoptic_seg, segments_info)`` tuple), ``"depth"`` and
            ``"sem_seg"``.

        Raises:
            ValueError: If ``cfg.model.mode`` is not ``"panoptic"``.
        """
        # Fail fast before doing any expensive interpolation.
        if self.cfg.model.mode != "panoptic":
            raise ValueError("Only panoptic mode is supported for 2D model.")

        pad_size = (orig_pad_shape[-2], orig_pad_shape[-1])
        mask_cls_results = outputs["pred_logits"]
        mask_pred_results = F.interpolate(
            outputs["pred_masks"],
            size=pad_size,
            mode="bilinear",
            align_corners=False,
        )
        depth_pred_results = F.interpolate(
            outputs["pred_depths"],
            size=pad_size,
            mode="bilinear",
            align_corners=False,
        )

        processed_results = []
        for mask_cls_result, mask_pred_result, depth_pred_result in zip(
            mask_cls_results, mask_pred_results, depth_pred_results
        ):
            mask_pred_result = retry_if_cuda_oom(self.sem_seg_postprocess)(
                mask_pred_result, orig_shape
            )
            # Match dtype/device of the (possibly moved) mask predictions.
            mask_cls_result = mask_cls_result.to(mask_pred_result)
            depth_pred_result = retry_if_cuda_oom(self.sem_seg_postprocess)(
                depth_pred_result, orig_shape
            )

            panoptic_seg, depth_r, segments_info, sem_prob_mask = retry_if_cuda_oom(
                self.panoptic_inference
            )(mask_cls_result, mask_pred_result, depth_pred_result)

            processed_results.append(
                {
                    "panoptic_seg": (panoptic_seg, segments_info),
                    "depth": depth_r,
                    "sem_seg": sem_prob_mask,
                }
            )

        return processed_results
|
|
|
|
|
|
|
|
class Panoptic3DModel(nn.Module):
    """3D model.

    Holds the 3D-stage submodules and implements the panoptic 3D inference and
    post-processing applied to their outputs.
    """

    def __init__(self, cfg):
        """Initialize 3D model.

        Args:
            cfg: Experiment config. Passed to the submodules, plus reads
                ``cfg.model.projection.depth_feature_dim``,
                ``cfg.dataset.downsample_factor``,
                ``cfg.model.frustum3d.{frustum_dims, iso_recon_value, truncation}``,
                ``cfg.model.sem_seg_head.num_classes``,
                ``cfg.model.object_mask_threshold`` and
                ``cfg.model.overlap_threshold``.
        """
        super().__init__()
        self.cfg = cfg

        self.reprojection = SparseProjection(self.cfg)
        self.completion = FrustumDecoder(self.cfg)
        self.projector = ProjectionBlock(
            self.cfg.model.projection.depth_feature_dim,
            self.cfg.model.projection.depth_feature_dim,
        )
        self.ol = OccupancyAwareLifting(self.cfg)
        self.post_processor = Postprocessor(self.cfg)
        self.back_projection = BackProjection(self.cfg)

        self.downsample_factor = cfg.dataset.downsample_factor
        # Cubic frustum: the same size along all three axes.
        self.frustum_dims = [cfg.model.frustum3d.frustum_dims] * 3
        self.iso_recon_value = cfg.model.frustum3d.iso_recon_value
        self.truncation = cfg.model.frustum3d.truncation
        self.num_classes = cfg.model.sem_seg_head.num_classes
        self.object_mask_threshold = cfg.model.object_mask_threshold
        self.overlap_threshold = cfg.model.overlap_threshold

    def forward(self, x):
        """Forward pass.

        The body is empty, so this returns ``None``.
        NOTE(review): presumably the submodules are driven elsewhere — confirm.
        """
        pass

    def panoptic_3d_inference(
        self,
        geometry,
        mask_cls,
        sparse_mask_tuple,
        min_coordinates,
        dense_dimensions,
        surface_threshold=1.5,
    ):
        """Panoptic 3D inference.

        Args:
            geometry: Dense geometry volume. Its shape is the output shape, and
                voxels with ``|geometry| <= surface_threshold`` count as surface.
            mask_cls: Per-query class logits. The last column is the
                "no object" class (label index ``self.num_classes``).
            sparse_mask_tuple: ``(coords, sparse_masks, stride)``; ``sparse_masks``
                holds one feature column per query.
            min_coordinates: Minimum coordinate forwarded to ``SparseTensor.dense``.
            dense_dimensions: Shape forwarded to ``SparseTensor.dense``.
            surface_threshold: Surface-voxel threshold used to fill voxels left
                unassigned. Defaults to 1.5 (the previously hard-coded value).

        Returns:
            tuple: ``(panoptic_seg, panoptic_semantic_mapping, semantic_seg)``:
                int32 segment-id volume (0 means unassigned), dict mapping
                segment id -> class id, and an int32 class-id volume.
        """
        panoptic_seg = torch.zeros(geometry.shape, dtype=torch.int32, device=mask_cls.device)
        semantic_seg = torch.zeros_like(panoptic_seg)
        panoptic_semantic_mapping = {}

        scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
        # Drop "no object" predictions, class 0 and low-confidence queries.
        keep = (
            labels.ne(self.num_classes)
            & labels.ne(0)
            & (scores > self.object_mask_threshold)
        )

        coords, sparse_masks, stride = sparse_mask_tuple
        cur_scores = scores[keep]
        cur_classes = labels[keep]
        # Sigmoid the kept sparse mask logits, densify, then drop the batch dim.
        cur_masks = Me.MinkowskiSigmoid()(
            Me.SparseTensor(
                features=sparse_masks[:, keep],
                coordinates=coords,
                tensor_stride=stride,
            )
        ).dense(dense_dimensions, min_coordinates)[0].squeeze(0)
        cur_prob_masks = cur_scores.view(-1, 1, 1, 1) * cur_masks

        if cur_masks.shape[0] > 0:
            num_thing_classes = self.post_processor.num_thing_classes
            # Each voxel goes to the query with the highest score-weighted probability.
            cur_mask_ids = cur_prob_masks.argmax(0)
            stuff_memory_list = {}    # stuff class id -> shared segment id
            query_to_segment_id = {}  # kept-query index -> assigned segment id
            current_segment_id = 0

            for k in range(cur_classes.shape[0]):
                pred_class = cur_classes[k].item()
                isthing = 1 <= pred_class <= num_thing_classes
                mask = (cur_mask_ids == k) & (cur_masks[k] >= 0.5)
                if not mask.any():
                    continue

                if not isthing:
                    # All queries of one stuff class share a single segment id.
                    if pred_class in stuff_memory_list:
                        panoptic_seg[mask] = stuff_memory_list[pred_class]
                        query_to_segment_id[k] = stuff_memory_list[pred_class]
                        continue
                    stuff_memory_list[pred_class] = current_segment_id + 1

                current_segment_id += 1
                panoptic_seg[mask] = current_segment_id
                query_to_segment_id[k] = current_segment_id
                panoptic_semantic_mapping[current_segment_id] = int(pred_class)

            # Surface voxels still unassigned inherit the segment of their
            # argmax query, even where that query's probability is below 0.5.
            # Argmax regions of different queries are disjoint, so the order of
            # assignment does not matter.
            unassigned_mask = (geometry.abs() <= surface_threshold) & (panoptic_seg == 0)
            for k, segment_id in query_to_segment_id.items():
                panoptic_seg[(cur_mask_ids == k) & unassigned_mask] = segment_id

        for segment_id, semantic_label in panoptic_semantic_mapping.items():
            semantic_seg[panoptic_seg == segment_id] = semantic_label

        return panoptic_seg, panoptic_semantic_mapping, semantic_seg

    def postprocess(self, outputs_3d, outputs_2d, processed_results, frustum_mask):
        """Postprocess 3D results.

        Args:
            outputs_3d: Dict with ``"pred_geometry"`` (densified via ``to_dense``,
                missing voxels filled with ``self.truncation``) and
                ``"pred_segms"`` (the last entry is used as the sparse masks).
            outputs_2d: Dict with ``"pred_logits"`` (per-image query logits).
            processed_results: Per-image dicts providing ``"intrinsic"``,
                ``"image_size"``, ``"depth"`` and ``"panoptic_seg"``.
            frustum_mask: Indexed per image and forwarded to
                ``prepare_instance_masks_thicken``.

        Returns:
            dict: Lists (one entry per image) under ``"intrinsic"``,
            ``"image_size"``, ``"depth"``, ``"panoptic_seg_2d"``, ``"geometry"``,
            ``"panoptic_seg"``, ``"semantic_seg"``, ``"panoptic_semantic_mapping"``
            and ``"instance_info_pred"``.
        """
        dense_dimensions = torch.Size([1, 1] + self.frustum_dims)
        min_coordinates = torch.IntTensor([0, 0, 0])

        geometry_results = to_dense(
            outputs_3d["pred_geometry"],
            dense_dimensions,
            min_coordinates,
            default_value=self.truncation,
        )[0]
        mask_3d_results = outputs_3d["pred_segms"][-1]
        mask_cls_results = outputs_2d["pred_logits"]

        processed_results_3d = {
            "intrinsic": [],
            "image_size": [],
            "depth": [],
            "panoptic_seg_2d": [],
            "geometry": [],
            "panoptic_seg": [],
            "semantic_seg": [],
            "panoptic_semantic_mapping": [],
            "instance_info_pred": [],
        }

        for idx, (geometry_result, mask_cls_result) in enumerate(
            zip(geometry_results, mask_cls_results)
        ):
            # Extract this image's sparse masks and re-collate them as a
            # single-item batch.
            coords = mask_3d_results.coordinates_at(idx)
            mask_3d = mask_3d_results.features_at(idx)
            coords, mask_3d = Me.utils.sparse_collate([coords], [mask_3d])
            geometry_result = geometry_result.squeeze(0)

            panoptic_seg, panoptic_semantic_mapping, semantic_seg = self.panoptic_3d_inference(
                geometry_result,
                mask_cls_result,
                (coords, mask_3d, mask_3d_results.tensor_stride),
                min_coordinates,
                dense_dimensions,
            )

            image_results = processed_results[idx]
            processed_results_3d["intrinsic"].append(image_results["intrinsic"])
            processed_results_3d["image_size"].append(image_results["image_size"])
            processed_results_3d["depth"].append(image_results["depth"])
            processed_results_3d["panoptic_seg_2d"].append(image_results["panoptic_seg"])
            processed_results_3d["geometry"].append(geometry_result)
            processed_results_3d["panoptic_seg"].append(panoptic_seg)
            processed_results_3d["semantic_seg"].append(semantic_seg)
            processed_results_3d["panoptic_semantic_mapping"].append(panoptic_semantic_mapping)
            processed_results_3d["instance_info_pred"].append(
                prepare_instance_masks_thicken(
                    panoptic_seg,
                    panoptic_semantic_mapping,
                    geometry_result,
                    frustum_mask[idx],
                    iso_value=self.iso_recon_value,
                    downsample_factor=self.downsample_factor,
                )
            )

        return processed_results_3d
|
|
|