# Author: vpraveen-nv
# Update model inference code and environment setup instructions (#4)
# Commit: f4a0919 (verified)
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""3D stage model for Panoptic Recon 3D."""
import torch
import MinkowskiEngine as Me
from torch.nn import functional as F
import torch.nn as nn
from .blocks import ProjectionBlock
from .utils.helper import retry_if_cuda_oom
# from .model_2d import MaskFormerModel
from .reconstruction import SparseProjection, FrustumDecoder
from .mp_occ.occupancy_aware_lifting import OccupancyAwareLifting
from .mp_occ.back_projection import BackProjection
from .utils.sparse_tensor import \
to_dense, prepare_instance_masks_thicken
from .utils.coords_transform import \
transform_feat3d_coordinates, fuse_sparse_tensors, generate_multiscale_feat3d
class Postprocessor(nn.Module):
    """2D model postprocessor.

    Converts the raw outputs of the 2D head (per-query class logits, mask
    logits and depth maps) into per-image panoptic segmentation, per-class
    probability maps and a clamped depth map.
    """

    def __init__(self, cfg):
        """Initialize the postprocessor.

        Args:
            cfg: Experiment config. Reads thresholds and head sizes from
                ``cfg.model`` and ``depth_scale`` / ``num_thing_classes`` from
                ``cfg.dataset``.
        """
        super().__init__()
        self.cfg = cfg
        self.test_topk_per_image = cfg.model.test_topk_per_image
        self.num_classes = cfg.model.sem_seg_head.num_classes
        self.num_queries = cfg.model.mask_former.num_object_queries
        self.object_mask_threshold = cfg.model.object_mask_threshold
        self.overlap_threshold = cfg.model.overlap_threshold
        self.depth_scale = cfg.dataset.depth_scale
        self.num_thing_classes = cfg.dataset.num_thing_classes

    def panoptic_inference(self, mask_cls, mask_pred, depth_pred):
        """Post process for panoptic segmentation.

        Args:
            mask_cls: Per-query class logits. Column ``num_classes`` (the
                last one) is the "no object" class; assumed shape
                (Q, num_classes + 1).
            mask_pred: Per-query mask logits, (Q, H, W).
            depth_pred: Depth prediction; only channel 0 is used.

        Returns:
            Tuple ``(panoptic_seg, depth, segments_info, sem_prob_masks)``:
            ``panoptic_seg`` is an int32 (H, W) map of segment ids (0 means
            unassigned). ``depth`` is channel 0 of ``depth_pred`` clamped to
            ``[0, depth_scale]``. ``segments_info`` is a list of dicts with
            keys ``id``, ``isthing`` and ``category_id``. ``sem_prob_masks``
            is a float32 (num_classes, H, W) map of the winning query's
            probability per class.
        """
        scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
        mask_pred = mask_pred.sigmoid()
        # Drop "no object" predictions and low-confidence queries.
        keep = labels.ne(self.num_classes) & (scores > self.object_mask_threshold)
        cur_scores = scores[keep]
        cur_classes = labels[keep]
        cur_masks = mask_pred[keep]
        cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks

        h, w = cur_masks.shape[-2:]
        device = cur_masks.device
        panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=device)
        sem_prob_masks = torch.zeros(
            (self.num_classes, h, w), dtype=torch.float32, device=device
        )
        segments_info = []
        # Clamp depth on every return path so callers always get values in
        # [0, depth_scale], even when no query survives filtering.
        depth = depth_pred[0, ...].clamp(min=0, max=self.depth_scale)

        if cur_masks.shape[0] == 0:
            return panoptic_seg, depth, segments_info, sem_prob_masks

        # Each pixel is owned by the query with the highest score-weighted probability.
        cur_mask_ids = cur_prob_masks.argmax(0)
        stuff_memory_list = {}  # stuff class id -> its (single) segment id
        stuff_mask_ids = []  # query indices that produced stuff segments
        current_segment_id = 0
        for k in range(cur_classes.shape[0]):
            pred_class = int(cur_classes[k].item())
            # Class ids 1..num_thing_classes are things; everything else is stuff.
            isthing = 1 <= pred_class <= self.num_thing_classes
            owned = cur_mask_ids == k
            confident = cur_masks[k] >= 0.5
            mask_area = owned.sum().item()
            original_area = confident.sum().item()
            mask = owned & confident
            if mask_area == 0 or original_area == 0 or mask.sum().item() == 0:
                continue
            # Skip queries that lost too much of their mask to stronger queries.
            if mask_area / original_area < self.overlap_threshold:
                continue
            if not isthing:
                stuff_mask_ids.append(k)
                if pred_class in stuff_memory_list:
                    # Merge all queries of one stuff class into a single segment.
                    panoptic_seg[mask] = stuff_memory_list[pred_class]
                    sem_prob_masks[pred_class][mask] = cur_prob_masks[k][mask]
                    continue
                stuff_memory_list[pred_class] = current_segment_id + 1
            current_segment_id += 1
            panoptic_seg[mask] = current_segment_id
            sem_prob_masks[pred_class][mask] = cur_prob_masks[k][mask]
            segments_info.append(
                {
                    "id": current_segment_id,
                    "isthing": bool(isthing),
                    "category_id": pred_class,
                }
            )

        if stuff_mask_ids:
            # Recover void pixels: give each still-unassigned pixel to the
            # stuff query with the highest probability there.
            stuff_ids = torch.tensor(
                stuff_mask_ids, dtype=torch.long, device=cur_prob_masks.device
            )
            cur_stuff_ids = stuff_ids[cur_prob_masks[stuff_ids].argmax(0)]
            empty_pixel_mask = panoptic_seg == 0
            for k in stuff_mask_ids:
                pred_class = int(cur_classes[k].item())
                mask = empty_pixel_mask & (cur_stuff_ids == k)
                panoptic_seg[mask] = stuff_memory_list[pred_class]
                sem_prob_masks[pred_class][mask] = cur_prob_masks[k][mask]

        return panoptic_seg, depth, segments_info, sem_prob_masks

    @staticmethod
    def sem_seg_postprocess(result, img_size):
        """Return semantic segmentation predictions in the original resolution.

        Args:
            result: (C, H_pad, W_pad) tensor.
            img_size: ``(height, width)`` of the unpadded image.

        Returns:
            (C, height, width) tensor.
        """
        # Crop away the padding, then add a batch dim for F.interpolate.
        result = result[:, :img_size[0], :img_size[1]].unsqueeze(0)
        result = F.interpolate(
            result,
            size=(img_size[0], img_size[1]),
            mode="bilinear",
            align_corners=False,
        )
        return result[0]

    def forward(self, outputs, orig_shape, orig_pad_shape):
        """Post-process a batch of 2D outputs.

        Args:
            outputs: Dict with ``pred_logits``, ``pred_masks`` and ``pred_depths``.
            orig_shape: ``(height, width)`` of the unpadded images.
            orig_pad_shape: Shape whose last two entries are the padded
                ``(height, width)``.

        Returns:
            One dict per image with keys ``panoptic_seg`` (a
            ``(seg, segments_info)`` tuple), ``depth`` and ``sem_seg``.

        Raises:
            ValueError: If ``cfg.model.mode`` is not ``"panoptic"``.
        """
        # Fail fast before doing any interpolation work.
        if self.cfg.model.mode != "panoptic":
            raise ValueError("Only panoptic mode is supported for 2D model.")

        mask_cls_results = outputs["pred_logits"]
        mask_pred_results = outputs["pred_masks"]
        depth_pred_results = outputs["pred_depths"]
        del outputs

        pad_size = (orig_pad_shape[-2], orig_pad_shape[-1])
        mask_pred_results = F.interpolate(
            mask_pred_results, size=pad_size, mode="bilinear", align_corners=False
        )
        depth_pred_results = F.interpolate(
            depth_pred_results, size=pad_size, mode="bilinear", align_corners=False
        )

        processed_results = []
        for mask_cls_result, mask_pred_result, depth_pred_result in zip(
            mask_cls_results, mask_pred_results, depth_pred_results
        ):
            mask_pred_result = retry_if_cuda_oom(self.sem_seg_postprocess)(
                mask_pred_result, orig_shape
            )
            mask_cls_result = mask_cls_result.to(mask_pred_result)
            depth_pred_result = retry_if_cuda_oom(self.sem_seg_postprocess)(
                depth_pred_result, orig_shape
            )
            panoptic_seg, depth_r, segments_info, sem_prob_mask = retry_if_cuda_oom(
                self.panoptic_inference
            )(mask_cls_result, mask_pred_result, depth_pred_result)
            processed_results.append(
                {
                    "panoptic_seg": (panoptic_seg, segments_info),
                    "depth": depth_r,
                    "sem_seg": sem_prob_mask,
                }
            )
        return processed_results
class Panoptic3DModel(nn.Module):
    """3D model.

    Holds the 3D reconstruction modules and turns their outputs, together with
    the 2D class logits, into per-sample 3D panoptic and semantic volumes.
    """

    def __init__(self, cfg):
        """Initialize 3D model.

        Args:
            cfg: Experiment config; reads ``cfg.model`` and ``cfg.dataset`` fields.
        """
        super().__init__()
        self.cfg = cfg
        # 3D modules
        self.reprojection = SparseProjection(self.cfg)
        self.completion = FrustumDecoder(self.cfg)
        self.projector = ProjectionBlock(
            self.cfg.model.projection.depth_feature_dim,
            self.cfg.model.projection.depth_feature_dim,
        )
        self.ol = OccupancyAwareLifting(self.cfg)
        self.post_processor = Postprocessor(self.cfg)
        self.back_projection = BackProjection(self.cfg)
        # 3D model parameters
        self.downsample_factor = cfg.dataset.downsample_factor
        self.frustum_dims = [cfg.model.frustum3d.frustum_dims] * 3
        self.iso_recon_value = cfg.model.frustum3d.iso_recon_value
        self.truncation = cfg.model.frustum3d.truncation
        self.num_classes = cfg.model.sem_seg_head.num_classes
        self.object_mask_threshold = cfg.model.object_mask_threshold
        self.overlap_threshold = cfg.model.overlap_threshold

    def forward(self, x):
        """Forward pass.

        NOTE(review): placeholder that returns ``None``; inference appears to
        be driven through ``postprocess`` and the submodules directly.
        """
        pass

    def panoptic_3d_inference(
        self, geometry, mask_cls, sparse_mask_tuple, min_coordinates, dense_dimensions
    ):
        """Panoptic 3D inference for a single sample.

        Args:
            geometry: Dense geometry volume; the output volumes take its
                shape. Presumably a truncated distance field, since
                ``|geometry| <= 1.5`` is used to find surface voxels.
            mask_cls: Per-query class logits. Column ``num_classes`` (the last
                one) is the "no object" class.
            sparse_mask_tuple: ``(coords, sparse_masks, stride)``. These are
                sparse coordinates, per-query mask logits as features (one
                column per query) and the tensor stride.
            min_coordinates: Passed to ``SparseTensor.dense``.
            dense_dimensions: Passed to ``SparseTensor.dense``.

        Returns:
            ``(panoptic_seg, panoptic_semantic_mapping, semantic_seg)``.
            ``panoptic_seg`` holds int32 segment ids (0 means unassigned).
            ``panoptic_semantic_mapping`` maps segment id to class id.
            ``semantic_seg`` holds the per-voxel class id.
        """
        panoptic_seg = torch.zeros(geometry.shape, dtype=torch.int32, device=mask_cls.device)
        semantic_seg = torch.zeros_like(panoptic_seg)
        panoptic_semantic_mapping = {}

        scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
        # Drop "no object", class 0 and low-confidence queries.
        keep = (
            labels.ne(self.num_classes)
            & labels.ne(0)
            & (scores > self.object_mask_threshold)
        )
        # Nothing survived: skip building a sparse tensor with zero channels.
        if not bool(keep.any()):
            return panoptic_seg, panoptic_semantic_mapping, semantic_seg

        coords, sparse_masks, stride = sparse_mask_tuple
        cur_scores = scores[keep]
        cur_classes = labels[keep]
        cur_masks = Me.MinkowskiSigmoid()(
            Me.SparseTensor(
                features=sparse_masks[:, keep],
                coordinates=coords,
                tensor_stride=stride,
            )
        ).dense(dense_dimensions, min_coordinates)[0].squeeze(0)
        cur_prob_masks = cur_scores.view(-1, 1, 1, 1) * cur_masks

        # Each voxel is owned by the query with the highest weighted probability.
        cur_mask_ids = cur_prob_masks.argmax(0)
        stuff_memory_list = {}  # stuff class id -> its (single) segment id
        query_to_segment_id = {}  # query index -> assigned segment id
        current_segment_id = 0
        num_thing_classes = self.post_processor.num_thing_classes
        for k in range(cur_classes.shape[0]):
            pred_class = int(cur_classes[k].item())
            # Class ids 1..num_thing_classes are things; the rest are stuff.
            isthing = 1 <= pred_class <= num_thing_classes
            mask = (cur_mask_ids == k) & (cur_masks[k] >= 0.5)
            if mask.sum().item() == 0:
                continue
            if not isthing:
                if pred_class in stuff_memory_list:
                    # Merge all queries of one stuff class into one segment.
                    panoptic_seg[mask] = stuff_memory_list[pred_class]
                    query_to_segment_id[k] = stuff_memory_list[pred_class]
                    continue
                stuff_memory_list[pred_class] = current_segment_id + 1
            current_segment_id += 1
            panoptic_seg[mask] = current_segment_id
            query_to_segment_id[k] = current_segment_id
            panoptic_semantic_mapping[current_segment_id] = pred_class

        # Fill surface voxels left unassigned by the 0.5 mask threshold with the
        # segment of the query that owns them.
        # NOTE(review): 1.5 is a hard-coded surface band; confirm whether it
        # should follow iso_recon_value or truncation.
        surface_mask = geometry.abs() <= 1.5
        unassigned_mask = surface_mask & (panoptic_seg == 0)
        for k, segment_id in query_to_segment_id.items():
            panoptic_seg[(cur_mask_ids == k) & unassigned_mask] = segment_id

        for segm_id, semantic_label in panoptic_semantic_mapping.items():
            semantic_seg[panoptic_seg == segm_id] = semantic_label
        return panoptic_seg, panoptic_semantic_mapping, semantic_seg

    def postprocess(self, outputs_3d, outputs_2d, processed_results, frustum_mask):
        """Postprocess 3D results.

        Args:
            outputs_3d: Dict with sparse ``pred_geometry`` and a list
                ``pred_segms`` whose last entry holds the per-query 3D mask logits.
            outputs_2d: Dict with ``pred_logits`` (per-sample query class logits).
            processed_results: Per-sample dicts providing ``intrinsic``,
                ``image_size``, ``depth`` and ``panoptic_seg``.
            frustum_mask: Per-sample frustum masks.

        Returns:
            Dict of lists, one entry per sample, for each output key.
        """
        # NOTE(review): batch dimension is fixed to 1 here; confirm that
        # to_dense handles batched pred_geometry with this shape.
        dense_dimensions = torch.Size([1, 1] + self.frustum_dims)
        min_coordinates = torch.IntTensor([0, 0, 0])
        geometry_results = to_dense(
            outputs_3d["pred_geometry"],
            dense_dimensions,
            min_coordinates,
            default_value=self.truncation,
        )[0]
        mask_3d_results = outputs_3d["pred_segms"][-1]
        mask_cls_results = outputs_2d["pred_logits"]
        processed_results_3d = {
            "intrinsic": [],
            "image_size": [],
            "depth": [],
            "panoptic_seg_2d": [],
            "geometry": [],
            "panoptic_seg": [],
            "semantic_seg": [],
            "panoptic_semantic_mapping": [],
            "instance_info_pred": [],
        }
        for idx, (geometry_result, mask_cls_result) in enumerate(
            zip(geometry_results, mask_cls_results)
        ):
            # Re-collate this sample's sparse masks as a batch of one.
            coords, mask_3d = mask_3d_results.coordinates_at(idx), mask_3d_results.features_at(idx)
            coords, mask_3d = Me.utils.sparse_collate([coords], [mask_3d])
            geometry_result = geometry_result.squeeze(0)
            panoptic_seg, panoptic_semantic_mapping, semantic_seg = self.panoptic_3d_inference(
                geometry_result,
                mask_cls_result,
                (coords, mask_3d, mask_3d_results.tensor_stride),
                min_coordinates,
                dense_dimensions,
            )
            processed_results_3d["intrinsic"].append(processed_results[idx]["intrinsic"])
            processed_results_3d["image_size"].append(processed_results[idx]["image_size"])
            processed_results_3d["depth"].append(processed_results[idx]["depth"])
            processed_results_3d["panoptic_seg_2d"].append(processed_results[idx]["panoptic_seg"])
            processed_results_3d["geometry"].append(geometry_result)
            processed_results_3d["panoptic_seg"].append(panoptic_seg)
            processed_results_3d["semantic_seg"].append(semantic_seg)
            processed_results_3d["panoptic_semantic_mapping"].append(panoptic_semantic_mapping)
            processed_results_3d["instance_info_pred"].append(prepare_instance_masks_thicken(
                panoptic_seg,
                panoptic_semantic_mapping,
                geometry_result,
                frustum_mask[idx],
                iso_value=self.iso_recon_value,
                downsample_factor=self.downsample_factor,
            ))
        return processed_results_3d