"""3D stage model for Panoptic Recon 3D."""


import torch
import torch.nn as nn
from torch.nn import functional as F

import MinkowskiEngine as Me

from .blocks import ProjectionBlock
from .mp_occ.back_projection import BackProjection
from .mp_occ.occupancy_aware_lifting import OccupancyAwareLifting
from .reconstruction import SparseProjection, FrustumDecoder
from .utils.coords_transform import \
    transform_feat3d_coordinates, fuse_sparse_tensors, generate_multiscale_feat3d
from .utils.helper import retry_if_cuda_oom
from .utils.sparse_tensor import \
    to_dense, prepare_instance_masks_thicken

class Postprocessor(nn.Module):
    """Post-processes raw 2D network outputs into panoptic + depth predictions.

    Converts per-query classification logits and mask logits into a single
    panoptic segmentation map, a per-class semantic probability volume, and a
    depth map clamped to the dataset's depth range, following the
    Mask2Former-style query-merging procedure.
    """

    def __init__(self, cfg):
        """Initialize the postprocessor.

        Args:
            cfg: Global config node; only scalar thresholds and dataset
                constants are read from it.
        """
        super().__init__()
        self.cfg = cfg
        self.test_topk_per_image = cfg.model.test_topk_per_image
        self.num_classes = cfg.model.sem_seg_head.num_classes
        self.num_queries = cfg.model.mask_former.num_object_queries
        self.object_mask_threshold = cfg.model.object_mask_threshold
        self.overlap_threshold = cfg.model.overlap_threshold
        self.depth_scale = cfg.dataset.depth_scale
        self.num_thing_classes = cfg.dataset.num_thing_classes

    def panoptic_inference(self, mask_cls, mask_pred, depth_pred):
        """Merge per-query masks into a single panoptic segmentation.

        Args:
            mask_cls: (Q, num_classes + 1) classification logits per query;
                the last channel is the "no object" class.
            mask_pred: (Q, H, W) mask logits per query.
            depth_pred: (1, H, W) predicted depth map.

        Returns:
            Tuple ``(panoptic_seg, depth, segments_info, sem_prob_masks)``:
            an int32 (H, W) map of segment ids (0 = unassigned), the (H, W)
            depth clamped to ``[0, depth_scale]``, a list of segment dicts
            (``id`` / ``isthing`` / ``category_id``), and a
            (num_classes, H, W) per-class probability map.
        """
        scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
        mask_pred = mask_pred.sigmoid()

        # Drop "no object" queries and low-confidence detections.
        keep = labels.ne(self.num_classes) & (scores > self.object_mask_threshold)
        cur_scores = scores[keep]
        cur_classes = labels[keep]
        cur_masks = mask_pred[keep]
        cur_mask_cls = mask_cls[keep]
        cur_mask_cls = cur_mask_cls[:, :-1]

        # Weight each mask by its classification score before the argmax merge.
        cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks

        h, w = cur_masks.shape[-2:]
        panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device)
        segments_info = []

        sem_prob_masks = torch.zeros(
            (self.num_classes, h, w), dtype=torch.float32, device=cur_masks.device
        )

        current_segment_id = 0

        if cur_masks.shape[0] == 0:
            # No query survived filtering: return an empty segmentation.
            # BUGFIX: clamp the depth here as well, so the output range is
            # consistent with the non-empty branch below.
            depth_pred = depth_pred[0, :, :].clamp(min=0, max=self.depth_scale)
            return panoptic_seg, depth_pred, segments_info, sem_prob_masks

        # Each pixel is claimed by the query with the highest weighted prob.
        cur_mask_ids = cur_prob_masks.argmax(0)
        stuff_memory_list = {}  # stuff category id -> assigned segment id
        stuff_mask_ids = []     # query indices that produced stuff segments
        for k in range(cur_classes.shape[0]):
            pred_class = cur_classes[k].item()
            # Thing categories occupy ids 1..num_thing_classes.
            isthing = pred_class in list(range(1, self.num_thing_classes + 1))
            mask_area = (cur_mask_ids == k).sum().item()
            original_area = (cur_masks[k] >= 0.5).sum().item()
            mask = (cur_mask_ids == k) & (cur_masks[k] >= 0.5)

            if mask_area > 0 and original_area > 0 and mask.sum().item() > 0:
                # Skip queries whose mask mostly lost the per-pixel argmax.
                if mask_area / original_area < self.overlap_threshold:
                    continue

                if not isthing:
                    stuff_mask_ids.append(k)
                    # Merge stuff masks of the same category into one segment.
                    if int(pred_class) in stuff_memory_list.keys():
                        panoptic_seg[mask] = stuff_memory_list[int(pred_class)]
                        sem_prob_masks[int(pred_class)][mask] = cur_prob_masks[k][mask]
                        continue
                    else:
                        stuff_memory_list[int(pred_class)] = current_segment_id + 1

                current_segment_id += 1
                panoptic_seg[mask] = current_segment_id
                sem_prob_masks[int(pred_class)][mask] = cur_prob_masks[k][mask]

                segments_info.append(
                    {
                        "id": current_segment_id,
                        "isthing": bool(isthing),
                        "category_id": int(pred_class),
                    }
                )

        if stuff_mask_ids:
            # Second pass: fill still-unassigned pixels with their best
            # stuff query, so stuff regions have no holes.
            stuff_mask_ids = torch.tensor(
                stuff_mask_ids, dtype=torch.long, device=cur_prob_masks.device
            )
            cur_stuff_ids = stuff_mask_ids[cur_prob_masks[stuff_mask_ids].argmax(0)]
            empty_pixel_mask = panoptic_seg == 0
            for k in stuff_mask_ids:
                k = k.item()
                pred_class = cur_classes[k].item()
                mask = empty_pixel_mask & (cur_stuff_ids == k)
                panoptic_seg[mask] = stuff_memory_list[int(pred_class)]
                sem_prob_masks[int(pred_class)][mask] = cur_prob_masks[k][mask]

        depth_pred = depth_pred[0, ...].clamp(min=0, max=self.depth_scale)
        return panoptic_seg, depth_pred, segments_info, sem_prob_masks

    @staticmethod
    def sem_seg_postprocess(result, img_size):
        """Return predictions cropped/resized to the original resolution.

        Args:
            result: (C, H_pad, W_pad) prediction tensor (padded resolution).
            img_size: (height, width) of the unpadded image.

        Returns:
            (C, height, width) tensor.
        """
        # Crop away the padding, then add a batch dim for interpolate.
        result = result[:, :img_size[0], :img_size[1]].expand(1, -1, -1, -1)
        result = F.interpolate(
            result,
            size=(img_size[0], img_size[1]),
            mode="bilinear",
            align_corners=False
        )
        return result[0]

    def forward(self, outputs, orig_shape, orig_pad_shape):
        """Post-process a batch of raw 2D predictions.

        Args:
            outputs: Dict with "pred_logits", "pred_masks" and "pred_depths".
            orig_shape: (height, width) of the original (unpadded) image.
            orig_pad_shape: (..., height, width) padded shape the network
                ran on; only the last two entries are used.

        Returns:
            List of per-image dicts with "panoptic_seg", "depth", "sem_seg".

        Raises:
            ValueError: If ``cfg.model.mode`` is not "panoptic".
        """
        mask_cls_results = outputs["pred_logits"]
        mask_pred_results = outputs["pred_masks"]
        depth_pred_results = outputs["pred_depths"]

        # Free the (potentially large) raw outputs dict early.
        del outputs

        # Upsample masks and depth to the padded input resolution.
        mask_pred_results = F.interpolate(
            mask_pred_results,
            size=(orig_pad_shape[-2], orig_pad_shape[-1]),
            mode="bilinear",
            align_corners=False,
        )
        depth_pred_results = F.interpolate(
            depth_pred_results,
            size=(orig_pad_shape[-2], orig_pad_shape[-1]),
            mode="bilinear",
            align_corners=False,
        )

        # Guard clause instead of a trailing else: same behavior, flatter code.
        if self.cfg.model.mode != "panoptic":
            raise ValueError("Only panoptic mode is supported for 2D model.")

        processed_results = []
        for mask_cls_result, mask_pred_result, depth_pred_result in zip(
            mask_cls_results, mask_pred_results, depth_pred_results
        ):
            processed_results.append({})

            # Crop to original image size; retry_if_cuda_oom falls back to
            # CPU on out-of-memory.
            mask_pred_result = retry_if_cuda_oom(self.sem_seg_postprocess)(
                mask_pred_result, orig_shape
            )
            mask_cls_result = mask_cls_result.to(mask_pred_result)

            depth_pred_result = retry_if_cuda_oom(self.sem_seg_postprocess)(
                depth_pred_result, orig_shape
            )

            panoptic_seg, depth_r, segments_info, sem_prob_mask = retry_if_cuda_oom(
                self.panoptic_inference
            )(mask_cls_result, mask_pred_result, depth_pred_result)

            processed_results[-1]["panoptic_seg"] = (panoptic_seg, segments_info)
            processed_results[-1]["depth"] = depth_r
            processed_results[-1]["sem_seg"] = sem_prob_mask

        return processed_results
|
|
|
|
class Panoptic3DModel(nn.Module):
    """3D model: produces 3D geometry and voxel-level panoptic/semantic
    segmentation volumes from the network's 2D and 3D outputs."""

    def __init__(self, cfg):
        """Initialize 3D model.

        Args:
            cfg: Global config node; sub-module constructors and the
                frustum / truncation constants below are read from it.
        """
        super().__init__()
        self.cfg = cfg

        # Sub-modules for 2D->3D lifting and frustum decoding.
        self.reprojection = SparseProjection(self.cfg)
        self.completion = FrustumDecoder(self.cfg)
        self.projector = ProjectionBlock(
            self.cfg.model.projection.depth_feature_dim,
            self.cfg.model.projection.depth_feature_dim
        )
        self.ol = OccupancyAwareLifting(self.cfg)
        self.post_processor = Postprocessor(self.cfg)
        self.back_projection = BackProjection(self.cfg)

        # Scalar constants read from the config.
        self.downsample_factor = cfg.dataset.downsample_factor
        # Cubic frustum: one dimension repeated for all three axes.
        self.frustum_dims = [cfg.model.frustum3d.frustum_dims] * 3
        self.iso_recon_value = cfg.model.frustum3d.iso_recon_value
        self.truncation = cfg.model.frustum3d.truncation
        self.num_classes = cfg.model.sem_seg_head.num_classes
        self.object_mask_threshold = cfg.model.object_mask_threshold
        self.overlap_threshold = cfg.model.overlap_threshold

    def forward(self, x):
        """Forward pass."""
        # NOTE(review): intentionally a stub in this file; inference appears
        # to go through `postprocess` / `panoptic_3d_inference` instead —
        # confirm against the call sites.
        pass

    def panoptic_3d_inference(
        self, geometry, mask_cls, sparse_mask_tuple, min_coordinates, dense_dimensions
    ):
        """Panoptic 3D inference.

        Merges per-query 3D masks into voxel-level panoptic and semantic
        segmentations, mirroring the 2D merging in `Postprocessor`.

        Args:
            geometry: Dense geometry volume; voxels with ``abs() <= 1.5``
                are treated as surface voxels below (presumably a TSDF-like
                representation — confirm against the decoder).
            mask_cls: (Q, num_classes + 1) classification logits; the last
                channel is "no object".
            sparse_mask_tuple: (coordinates, per-voxel per-query mask
                features, tensor_stride) describing a MinkowskiEngine
                sparse tensor.
            min_coordinates: Minimum coordinates used when densifying.
            dense_dimensions: Target dense shape for densifying.

        Returns:
            Tuple ``(panoptic_seg, panoptic_semantic_mapping, semantic_seg)``:
            an int32 volume of segment ids, a dict mapping segment id to
            semantic class, and the derived semantic-label volume.
        """
        panoptic_seg = torch.zeros(geometry.shape, dtype=torch.int32, device=mask_cls.device)
        semantic_seg = torch.zeros_like(panoptic_seg)
        panoptic_semantic_mapping = {}

        scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
        # Unlike the 2D path, class 0 is also dropped here.
        keep = labels.ne(self.num_classes) & \
            labels.ne(0) & \
            (scores > self.object_mask_threshold)

        coords, sparse_masks, stride = sparse_mask_tuple
        cur_scores = scores[keep]
        cur_classes = labels[keep]
        # Apply a sigmoid to the kept per-query masks, then densify them.
        cur_masks = Me.MinkowskiSigmoid()(
            Me.SparseTensor(
                features=sparse_masks[:, keep],
                coordinates=coords,
                tensor_stride=stride
            )
        ).dense(dense_dimensions, min_coordinates)[0].squeeze(0)
        cur_mask_cls = mask_cls[keep]
        cur_mask_cls = cur_mask_cls[:, :-1]

        # Score-weighted masks decide which query claims each voxel.
        cur_prob_masks = cur_scores.view(-1, 1, 1, 1) * cur_masks

        current_segment_id = 0
        if cur_masks.shape[0] > 0:
            cur_mask_ids = cur_prob_masks.argmax(0)
            stuff_memory_list = {}  # stuff class id -> assigned segment id
            query_to_segment_id = {}  # query index -> assigned segment id
            for k in range(cur_classes.shape[0]):
                pred_class = cur_classes[k].item()
                # Thing categories occupy ids 1..num_thing_classes.
                isthing = pred_class in list(range(1, self.post_processor.num_thing_classes + 1))
                mask = (cur_mask_ids == k) & (cur_masks[k] >= 0.5)

                if mask.sum().item() > 0:
                    if not isthing:
                        # Merge stuff segments of the same class into one id.
                        if int(pred_class) in stuff_memory_list.keys():
                            panoptic_seg[mask] = stuff_memory_list[int(pred_class)]
                            query_to_segment_id[k] = stuff_memory_list[int(pred_class)]
                            continue
                        else:
                            stuff_memory_list[int(pred_class)] = current_segment_id + 1

                    current_segment_id += 1
                    panoptic_seg[mask] = current_segment_id
                    query_to_segment_id[k] = current_segment_id
                    panoptic_semantic_mapping[current_segment_id] = int(pred_class)

            # Voxels near the surface (small |geometry|) count as occupied.
            surface_mask = geometry.abs() <= 1.5

            # Second pass: unassigned surface voxels inherit the segment of
            # the query that wins their per-voxel argmax.
            unassigned_mask = surface_mask & (panoptic_seg == 0)
            for k in range(cur_classes.shape[0]):
                mask = (cur_mask_ids == k) & unassigned_mask
                if mask.sum().item() > 0 and k in query_to_segment_id.keys():
                    panoptic_seg[mask] = query_to_segment_id[k]

        # Derive the semantic volume from the panoptic segment ids.
        for segm_id, semantic_label in panoptic_semantic_mapping.items():
            instance_mask = panoptic_seg == segm_id
            semantic_seg[instance_mask] = semantic_label

        return panoptic_seg, panoptic_semantic_mapping, semantic_seg

    def postprocess(self, outputs_3d, outputs_2d, processed_results, frustum_mask):
        """Postprocess 3D results.

        Args:
            outputs_3d: Dict with "pred_geometry" (sparse geometry) and
                "pred_segms" (list whose last entry is a MinkowskiEngine
                sparse tensor of per-query mask features).
            outputs_2d: Dict with per-image "pred_logits".
            processed_results: Per-image 2D results carrying "intrinsic",
                "image_size", "depth" and "panoptic_seg".
            frustum_mask: Per-image frustum validity mask.

        Returns:
            Dict of per-image lists with 3D panoptic/semantic volumes, dense
            geometry, carried-over 2D results and predicted instance info.
        """
        dense_dimensions = torch.Size([1, 1] + self.frustum_dims)
        min_coordinates = torch.IntTensor([0, 0, 0])

        # Densify sparse geometry; empty voxels default to the truncation
        # value (i.e. "far from any surface").
        geometry_results = to_dense(
            outputs_3d["pred_geometry"],
            dense_dimensions,
            min_coordinates,
            default_value=self.truncation
        )[0]
        mask_3d_results = outputs_3d["pred_segms"][-1]
        mask_cls_results = outputs_2d["pred_logits"]

        processed_results_3d = {
            "intrinsic": [],
            "image_size": [],
            "depth": [],
            "panoptic_seg_2d": [],
            "geometry": [],
            "panoptic_seg": [],
            "semantic_seg": [],
            "panoptic_semantic_mapping": [],
            "instance_info_pred": []
        }

        for idx, (geometry_result, mask_cls_result) in enumerate(zip(
            geometry_results,
            mask_cls_results
        )):
            # Re-collate this sample's coordinates/features into batch form
            # so they can be wrapped in a sparse tensor downstream.
            coords, mask_3d = mask_3d_results.coordinates_at(idx), mask_3d_results.features_at(idx)
            coords, mask_3d = Me.utils.sparse_collate([coords], [mask_3d])
            geometry_result = geometry_result.squeeze(0)
            panoptic_seg, panoptic_semantic_mapping, semantic_seg = self.panoptic_3d_inference(
                geometry_result,
                mask_cls_result,
                (coords, mask_3d, mask_3d_results.tensor_stride),
                min_coordinates,
                dense_dimensions,
            )

            # Carry over per-image 2D results alongside the new 3D outputs.
            processed_results_3d["intrinsic"].append(processed_results[idx]["intrinsic"])
            processed_results_3d["image_size"].append(processed_results[idx]["image_size"])
            processed_results_3d["depth"].append(processed_results[idx]["depth"])
            processed_results_3d["panoptic_seg_2d"].append(processed_results[idx]["panoptic_seg"])
            processed_results_3d["geometry"].append(geometry_result)
            processed_results_3d["panoptic_seg"].append(panoptic_seg)
            processed_results_3d["semantic_seg"].append(semantic_seg)
            processed_results_3d["panoptic_semantic_mapping"].append(panoptic_semantic_mapping)
            processed_results_3d["instance_info_pred"].append(prepare_instance_masks_thicken(
                panoptic_seg,
                panoptic_semantic_mapping,
                geometry_result,
                frustum_mask[idx],
                iso_value=self.iso_recon_value,
                downsample_factor=self.downsample_factor
            ))

        return processed_results_3d
|
|