#!/usr/bin/env python3
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
NVPanoptix-3D Model.
"""

import json
import logging
import shutil
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
from omegaconf import OmegaConf
from torch.nn import functional as F

from preprocessing import create_frustum_mask, DEFAULT_INTRINSIC
from nvpanoptix_3d.model_3d import Panoptic3DModel
from nvpanoptix_3d.utils.helper import get_kept_mapping, retry_if_cuda_oom
from nvpanoptix_3d.utils.coords_transform import (
    transform_feat3d_coordinates,
    fuse_sparse_tensors,
    generate_multiscale_feat3d,
)

logger = logging.getLogger(__name__)

# Weight file names (stored in weights/ subdirectory)
WEIGHTS_DIR = "weights"
TORCHSCRIPT_2D_FILENAME = "model_2d_fp32.pt"
CHECKPOINT_3D_FILENAME = "tao_vggt_front3d.pth"

# PanopticRecon3DConfig fields that are tuples in Python but lists in JSON.
_CONFIG_TUPLE_FIELDS = (
    "target_size",
    "reduced_target_size",
    "depth_size",
    "pixel_mean",
    "pixel_std",
)


@dataclass
class PanopticRecon3DConfig:
    """Configuration for Panoptic Recon 3D model.

    This config is JSON-serializable and will be saved to config.json
    when using save_pretrained or push_to_hub.
    """

    # Model architecture
    num_classes: int = 13
    num_thing_classes: int = 9
    object_mask_threshold: float = 0.8
    overlap_threshold: float = 0.5
    test_topk_per_image: int = 100

    # Backbone
    backbone_type: str = "vggt"

    # Mask Former
    hidden_dim: int = 256
    num_queries: int = 100
    mask_dim: int = 256
    depth_dim: int = 256
    dec_layers: int = 10

    # 3D Frustum
    frustum_dims: int = 256
    truncation: float = 3.0
    iso_recon_value: float = 2.0
    voxel_size: float = 0.03

    # Projection
    depth_feature_dim: int = 256
    sign_channel: bool = True

    # Dataset/preprocessing
    target_size: Tuple[int, int] = (320, 240)
    reduced_target_size: Tuple[int, int] = (160, 120)
    depth_size: Tuple[int, int] = (120, 160)
    depth_min: float = 0.4
    depth_max: float = 6.0
    depth_scale: float = 25.0
    pixel_mean: Tuple[float, float, float] = (0.485, 0.456, 0.406)
    pixel_std: Tuple[float, float, float] = (0.229, 0.224, 0.225)
    ignore_label: int = 255
    size_divisibility: int = 32
    downsample_factor: int = 1

    # Model paths
    torchscript_2d_path: Optional[str] = None
    use_fp16_2d: bool = False

    # Dataset mode
    is_matterport: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """Convert the config to a JSON-serializable dictionary.

        Every dataclass field is included, in declaration order. Tuple-valued
        fields are emitted as lists so the result round-trips through
        ``json.dump``/``json.load`` and ``from_dict``.
        """
        config = asdict(self)
        for key in _CONFIG_TUPLE_FIELDS:
            config[key] = list(config[key])
        return config

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "PanopticRecon3DConfig":
        """Create a config from a dictionary (e.g. one produced by ``to_dict``).

        The input mapping is not modified. Lists stored for tuple-typed
        fields are converted back to tuples. Unknown keys raise ``TypeError``
        (from the dataclass constructor).
        """
        kwargs = dict(config_dict)  # copy: never mutate the caller's dict
        for key in _CONFIG_TUPLE_FIELDS:
            if key in kwargs:
                kwargs[key] = tuple(kwargs[key])
        return cls(**kwargs)


@dataclass
class PanopticRecon3DOutput:
    """Output from Panoptic Recon 3D model."""

    # 3D outputs
    panoptic_seg_3d: torch.Tensor  # (D, H, W) int32 - panoptic segmentation
    geometry_3d: torch.Tensor  # (D, H, W) float32 - TSDF/geometry
    semantic_seg_3d: torch.Tensor  # (D, H, W) int32 - semantic segmentation

    # 2D outputs
    panoptic_seg_2d: torch.Tensor  # (H, W) int32 - 2D panoptic segmentation
    depth_2d: torch.Tensor  # (H, W) float32 - depth map

    # Optional metadata
    panoptic_semantic_mapping: Optional[Dict[int, int]] = None
    segments_info: Optional[List[Dict]] = None

    def to_numpy(self) -> Dict[str, np.ndarray]:
        """Convert the tensor outputs to numpy arrays on the CPU.

        Tensors are detached first so this also works on outputs that are
        still attached to an autograd graph. The optional metadata fields are
        not included.
        """
        tensors = {
            "panoptic_seg_3d": self.panoptic_seg_3d,
            "geometry_3d": self.geometry_3d,
            "semantic_seg_3d": self.semantic_seg_3d,
            "panoptic_seg_2d": self.panoptic_seg_2d,
            "depth_2d": self.depth_2d,
        }
        return {name: t.detach().cpu().numpy() for name, t in tensors.items()}


class PanopticRecon3DModel(
    nn.Module,
    PyTorchModelHubMixin,
    # HuggingFace Hub metadata
    repo_url="nvidia/nvpanoptix-3d",
    pipeline_tag="image-segmentation",
    license="apache-2.0",
    tags=["panoptic-segmentation", "3d-reconstruction", "depth-estimation", "nvidia"],
):
    """
    This model performs panoptic 3D scene reconstruction from a single RGB image.

    It combines:
    - 2D panoptic segmentation
    - Depth estimation
    - 3D volumetric reconstruction

    The model architecture uses:
    - VGGT backbone for feature extraction
    - MaskFormer head for panoptic segmentation
    - Occupancy-aware lifting for 2D-to-3D projection
    - Sparse 3D convolutions for volumetric completion
    """

    def __init__(
        self,
        num_classes: int = 13,
        num_thing_classes: int = 9,
        object_mask_threshold: float = 0.8,
        overlap_threshold: float = 0.5,
        frustum_dims: int = 256,
        truncation: float = 3.0,
        iso_recon_value: float = 2.0,
        voxel_size: float = 0.03,
        depth_min: float = 0.4,
        depth_max: float = 6.0,
        target_size: Tuple[int, int] = (320, 240),
        reduced_target_size: Tuple[int, int] = (160, 120),
        size_divisibility: int = 32,
        downsample_factor: int = 1,
        is_matterport: bool = False,
        torchscript_2d_path: Optional[str] = None,
        use_fp16_2d: bool = False,
        **kwargs,
    ):
        """Initialize Panoptic Recon 3D model.

        Weights are NOT loaded here; call ``load_weights`` or use
        ``from_pretrained``.

        Args:
            num_classes: Number of semantic classes.
            num_thing_classes: Number of "thing" (instance) classes.
            object_mask_threshold: Threshold for object mask confidence.
            overlap_threshold: Threshold for mask overlap.
            frustum_dims: Dimensions of 3D frustum volume.
            truncation: TSDF truncation distance.
            iso_recon_value: Iso-surface value for mesh extraction.
            voxel_size: Voxel size in meters.
            depth_min: Minimum depth value.
            depth_max: Maximum depth value.
            target_size: Target image size (width, height).
            reduced_target_size: Reduced target size for 3D projection.
            size_divisibility: Size divisibility for padding.
            downsample_factor: Downsample factor for 3D reconstruction.
            is_matterport: Whether using Matterport dataset mode.
            torchscript_2d_path: Path to TorchScript 2D model (optional).
            use_fp16_2d: Whether to use FP16 for 2D model.
            **kwargs: Extra keys (e.g. from a full config.json) are accepted
                and ignored.
        """
        super().__init__()

        # Store config as attributes (for PyTorchModelHubMixin serialization)
        self.num_classes = num_classes
        self.num_thing_classes = num_thing_classes
        self.object_mask_threshold = object_mask_threshold
        self.overlap_threshold = overlap_threshold
        self.frustum_dims_val = frustum_dims
        self.truncation = truncation
        self.iso_recon_value = iso_recon_value
        self.voxel_size = voxel_size
        self.depth_min = depth_min
        self.depth_max = depth_max
        # Normalize to tuples: values loaded from config.json arrive as lists.
        self.target_size = tuple(target_size)
        self.reduced_target_size = tuple(reduced_target_size)
        self.size_divisibility = size_divisibility
        self.downsample_factor = downsample_factor
        self.is_matterport = is_matterport
        self.torchscript_2d_path = torchscript_2d_path
        self.use_fp16_2d = use_fp16_2d

        # Derived values
        self.frustum_dims = [frustum_dims] * 3

        # Populated by load_weights().
        self.model_2d: Optional[torch.jit.ScriptModule] = None
        self.model_3d_components: Optional[Dict[str, nn.Module]] = None
        self._initialized = False
        self._device: Optional[str] = None
        # Source weight paths, remembered so _save_pretrained can copy them.
        self._torchscript_2d_path: Optional[str] = None
        self._checkpoint_3d_path: Optional[str] = None

        # Placeholder for post processor
        self.post_processor = None

    @staticmethod
    def _merge_config_file(config_path: Union[str, Path], model_kwargs: Dict[str, Any]) -> None:
        """Merge keys from a JSON config file into ``model_kwargs`` in place.

        Keys already present in ``model_kwargs`` take precedence over the file.
        """
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        for key, value in config.items():
            model_kwargs.setdefault(key, value)

    @classmethod
    def _from_pretrained(
        cls,
        *,
        model_id: str,
        revision: Optional[str] = None,
        cache_dir: Optional[str] = None,
        force_download: bool = False,
        proxies: Optional[Dict] = None,
        resume_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        map_location: str = "cpu",
        strict: bool = False,
        **model_kwargs,
    ) -> "PanopticRecon3DModel":
        """Load model from HuggingFace Hub or local directory.

        This method handles loading both the TorchScript 2D model and the 3D checkpoint.

        Args:
            model_id: HuggingFace Hub repo ID or local directory path.
            revision: Git revision (branch, tag, or commit hash).
            cache_dir: Cache directory for downloaded files.
            force_download: Force re-download even if cached.
            proxies: Proxy configuration.
            resume_download: Resume interrupted downloads.
            local_files_only: Only use local files, don't download.
            token: HuggingFace API token.
            map_location: Device to load model onto. The default "cpu" is
                treated as "not specified": CUDA is used when available.
                Pass ``device=`` in model_kwargs to force a device.
            strict: Strict loading (not used for this model).
            **model_kwargs: Additional model arguments. Explicit kwargs take
                precedence over values read from config.json.

        Returns:
            Initialized PanopticRecon3DModel with weights loaded.
        """
        # Determine device
        device = model_kwargs.pop("device", None)
        if device is None:
            if map_location != "cpu":
                device = map_location
            elif torch.cuda.is_available():
                device = "cuda:0"
            else:
                device = "cpu"

        model_path = Path(model_id)
        if model_path.is_dir():
            # Local directory
            config_path = model_path / "config.json"
            weights_dir = model_path / WEIGHTS_DIR
            torchscript_2d_path = weights_dir / TORCHSCRIPT_2D_FILENAME
            checkpoint_3d_path = weights_dir / CHECKPOINT_3D_FILENAME

            if config_path.exists():
                cls._merge_config_file(config_path, model_kwargs)
        else:
            # HuggingFace Hub - download files
            download_kwargs = dict(
                repo_id=model_id,
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                token=token,
            )

            # config.json is optional: fall back to explicit/default arguments,
            # but say so, since defaults may not match the downloaded weights.
            try:
                config_file = hf_hub_download(filename="config.json", **download_kwargs)
                cls._merge_config_file(config_file, model_kwargs)
            except Exception as err:
                logger.warning(
                    "Could not load config.json for %s (%s); using explicit/default arguments.",
                    model_id,
                    err,
                )

            # Download weight files
            torchscript_2d_path = hf_hub_download(
                filename=f"{WEIGHTS_DIR}/{TORCHSCRIPT_2D_FILENAME}", **download_kwargs
            )
            checkpoint_3d_path = hf_hub_download(
                filename=f"{WEIGHTS_DIR}/{CHECKPOINT_3D_FILENAME}", **download_kwargs
            )

        # Create model instance
        model = cls(**model_kwargs)

        # Load weights
        model.load_weights(
            torchscript_2d_path=str(torchscript_2d_path),
            checkpoint_3d_path=str(checkpoint_3d_path),
            device=device,
        )

        return model

    @staticmethod
    def _copy_weight_file(src: Union[str, Path], dst: Path) -> None:
        """Copy one weight file to ``dst``, tolerating a missing or identical source."""
        src_path = Path(src)
        if not src_path.exists():
            logger.warning("Weight file %s not found; not copied to %s", src_path, dst)
            return
        if dst.exists() and src_path.samefile(dst):
            # Saving into the directory the weights were loaded from:
            # shutil.copy2 would raise SameFileError, and nothing needs copying.
            return
        shutil.copy2(src_path, dst)

    def _save_pretrained(self, save_directory: Path) -> None:
        """Save model to directory.

        This saves the config.json and copies weight files to the directory.
        Weight files are copied from the paths recorded by ``load_weights``;
        if no weights were loaded, only config.json is written (and a warning
        is logged).

        Args:
            save_directory: Directory to save model to.
        """
        save_directory = Path(save_directory)
        save_directory.mkdir(parents=True, exist_ok=True)

        # Save config
        config = {
            "num_classes": self.num_classes,
            "num_thing_classes": self.num_thing_classes,
            "object_mask_threshold": self.object_mask_threshold,
            "overlap_threshold": self.overlap_threshold,
            "frustum_dims": self.frustum_dims_val,
            "truncation": self.truncation,
            "iso_recon_value": self.iso_recon_value,
            "voxel_size": self.voxel_size,
            "depth_min": self.depth_min,
            "depth_max": self.depth_max,
            "target_size": list(self.target_size),
            "reduced_target_size": list(self.reduced_target_size),
            "size_divisibility": self.size_divisibility,
            "downsample_factor": self.downsample_factor,
            "is_matterport": self.is_matterport,
            "use_fp16_2d": self.use_fp16_2d,
        }

        config_path = save_directory / "config.json"
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)

        # Create weights directory and copy weights
        weights_dir = save_directory / WEIGHTS_DIR
        weights_dir.mkdir(exist_ok=True)

        src_2d = getattr(self, "_torchscript_2d_path", None)
        src_3d = getattr(self, "_checkpoint_3d_path", None)
        if not self._initialized or src_2d is None or src_3d is None:
            logger.warning(
                "Model weights are not loaded; only config.json was written to %s",
                save_directory,
            )
            return

        self._copy_weight_file(src_2d, weights_dir / TORCHSCRIPT_2D_FILENAME)
        self._copy_weight_file(src_3d, weights_dir / CHECKPOINT_3D_FILENAME)

    def _build_omegaconf(self) -> Any:
        """Build the OmegaConf config used to construct ``Panoptic3DModel``.

        Values not exposed as constructor arguments are fixed constants.
        """
        return OmegaConf.create({
            "model": {
                "export": True,
                "mode": "panoptic",
                "object_mask_threshold": self.object_mask_threshold,
                "overlap_threshold": self.overlap_threshold,
                "test_topk_per_image": 100,
                "backbone": {"type": "vggt", "pretrained_weights": None},
                "sem_seg_head": {
                    "common_stride": 4,
                    "transformer_enc_layers": 6,
                    "convs_dim": 256,
                    "mask_dim": 256,
                    "depth_dim": 256,
                    "ignore_value": 255,
                    "deformable_transformer_encoder_in_features": ["res3", "res4", "res5"],
                    "num_classes": self.num_classes,
                    "norm": "GN",
                    "in_features": ["res2", "res3", "res4", "res5"],
                },
                "mask_former": {
                    "dropout": 0.0,
                    "nheads": 8,
                    "num_object_queries": 100,
                    "hidden_dim": 256,
                    "transformer_dim_feedforward": 1024,
                    "dim_feedforward": 2048,
                    "dec_layers": 10,
                    "pre_norm": False,
                    "class_weight": 2.0,
                    "dice_weight": 5.0,
                    "mask_weight": 5.0,
                    "depth_weight": 5.0,
                    "mp_occ_weight": 5.0,
                    "train_num_points": 12544,
                    "oversample_ratio": 3.0,
                    "importance_sample_ratio": 0.75,
                    "deep_supervision": True,
                    "no_object_weight": 0.1,
                    "size_divisibility": self.size_divisibility,
                },
                "frustum3d": {
                    "truncation": self.truncation,
                    "iso_recon_value": self.iso_recon_value,
                    "panoptic_weight": 25.0,
                    "completion_weights": [50.0, 25.0, 10.0],
                    "surface_weight": 5.0,
                    "unet_output_channels": 16,
                    "unet_features": 16,
                    "use_multi_scale": False,
                    "grid_dimensions": self.frustum_dims_val,
                    "frustum_dims": self.frustum_dims_val,
                    "signed_channel": 3,
                },
                "projection": {
                    "voxel_size": self.voxel_size,
                    "sign_channel": True,
                    "depth_feature_dim": 256,
                },
            },
            "dataset": {
                "contiguous_id": False,
                "label_map": "",
                "name": "",  # Empty string to match Triton behavior (triggers adjust_intrinsic)
                "downsample_factor": self.downsample_factor,
                "iso_value": 1.0,
                "pixel_mean": [0.485, 0.456, 0.406],
                "pixel_std": [0.229, 0.224, 0.225],
                "ignore_label": 255,
                "min_instance_pixels": 200,
                "img_format": "RGB",
                "target_size": list(self.target_size),
                "reduced_target_size": list(self.reduced_target_size),
                "depth_size": [120, 160],
                "depth_bound": False,
                "depth_min": self.depth_min,
                "depth_max": self.depth_max,
                "frustum_mask_path": "",
                "occ_truncation_lvl": [8.0, 6.0],
                "truncation_range": [0.0, 12.0],
                "enable_3d": False,
                "enable_mp_occ": True,
                "depth_scale": 25.0,
                "num_thing_classes": self.num_thing_classes,
                "augmentation": {"size_divisibility": self.size_divisibility},
            },
        })

    def load_weights(
        self,
        torchscript_2d_path: Optional[str] = None,
        checkpoint_3d_path: Optional[str] = None,
        device: str = "cuda:0",
    ):
        """Load model weights.

        Args:
            torchscript_2d_path: Path to TorchScript 2D model file. Falls back
                to the path given at construction time.
            checkpoint_3d_path: Path to 3D model checkpoint (.pth/.pt).
            device: Device to load models onto.

        Raises:
            ValueError: If either weight path is missing.
        """
        # Use stored path if not provided
        torchscript_2d_path = torchscript_2d_path or self.torchscript_2d_path
        if torchscript_2d_path is None:
            raise ValueError("torchscript_2d_path is required")
        if checkpoint_3d_path is None:
            raise ValueError("checkpoint_3d_path is required")

        # Store paths for save_pretrained
        self._torchscript_2d_path = torchscript_2d_path
        self._checkpoint_3d_path = checkpoint_3d_path

        cfg = self._build_omegaconf()

        # Load 2D TorchScript model
        self.model_2d = torch.jit.load(torchscript_2d_path, map_location=device)
        self.model_2d.eval()

        # Load 3D model from checkpoint.
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        full_model = Panoptic3DModel(cfg)
        checkpoint = torch.load(checkpoint_3d_path, map_location="cpu")
        state_dict = checkpoint.get("state_dict", checkpoint)

        # Remove 'model.' prefix if present (wrapper-module checkpoints)
        prefix = "model."
        filtered_state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

        # strict=False: mismatched keys are tolerated; record them for debugging.
        incompatible = full_model.load_state_dict(filtered_state_dict, strict=False)
        logger.debug("3D checkpoint loaded with strict=False: %s", incompatible)
        full_model.to(device)
        full_model.eval()

        # Extract 3D components
        self.model_3d_components = {
            "ol": full_model.ol,
            "reprojection": full_model.reprojection,
            "completion": full_model.completion,
            "projector": full_model.projector,
            "back_projection": full_model.back_projection,
        }

        # Store post processor and helper functions
        self.post_processor = full_model.post_processor
        self.back_projection = full_model.back_projection  # Required for get_kept_mapping
        self._get_kept_mapping = get_kept_mapping
        self._transform_feat3d_coordinates = transform_feat3d_coordinates
        self._fuse_sparse_tensors = fuse_sparse_tensors
        self._generate_multiscale_feat3d = generate_multiscale_feat3d
        self._retry_if_cuda_oom = retry_if_cuda_oom
        self._panoptic_3d_inference = full_model.panoptic_3d_inference
        self._postprocess = full_model.postprocess
        self._cfg = cfg

        # Disable gradients for all components
        for module in self.model_3d_components.values():
            for param in module.parameters():
                param.requires_grad = False
            module.eval()

        self._initialized = True
        self._device = device

    def _ensure_initialized(self):
        """Raise RuntimeError if ``load_weights`` has not completed."""
        if not self._initialized:
            raise RuntimeError(
                "Model weights not loaded. Call load_weights() first, or use "
                "from_pretrained() to load a pre-trained model."
            )

    def _infer_2d(
        self,
        images: torch.Tensor,
        intrinsic: torch.Tensor,
    ) -> Tuple[Dict[str, torch.Tensor], List[Dict], torch.Tensor]:
        """Run 2D inference using TorchScript model.

        Args:
            images: Input images (B, C, H, W) as uint8 or float.
            intrinsic: Camera intrinsics (B, 4, 4).

        Returns:
            outputs_2d: Dictionary of 2D model outputs. ``pred_logits`` and
                ``pred_masks`` hold only the LAST image's (cropped) results,
                matching the Triton model's behavior.
            processed_results: List of processed results per image.
            occupancy_pred: Occupancy predictions.

        Raises:
            ValueError: If no image/intrinsic pair was processed (empty batch).
        """
        # Run 2D model
        with torch.no_grad():
            if self.use_fp16_2d:
                with torch.autocast(device_type="cuda"):
                    outputs_dict = self.model_2d(images)
            else:
                outputs_dict = self.model_2d(images)

        # Normalize to FP32
        def to_fp32(x):
            return x.float() if isinstance(x, torch.Tensor) and x.dtype != torch.float32 else x

        # Extract outputs
        mask_cls_results = to_fp32(outputs_dict["pred_logits"])
        mask_pred_results = to_fp32(outputs_dict["pred_masks"])
        depth_pred_results = to_fp32(outputs_dict["pred_depths"])
        enc_features = [to_fp32(outputs_dict[f"enc_features_{i}"]) for i in range(4)]
        mask_features = to_fp32(outputs_dict["mask_features"])
        depth_features = to_fp32(outputs_dict["depth_features"])
        segm_decoder_out = to_fp32(outputs_dict["segm_decoder_out"])
        pose_enc = to_fp32(outputs_dict["pose_enc"])
        occupancy_pred = to_fp32(outputs_dict["occupancy_pred"])

        orig_pad_h = int(outputs_dict["orig_pad_h"].item())
        orig_pad_w = int(outputs_dict["orig_pad_w"].item())
        orig_h = int(outputs_dict["orig_h"].item())
        orig_w = int(outputs_dict["orig_w"].item())

        # Masks and depths are resized to half the padded input size, then
        # cropped to half the unpadded size below.
        padded_out_h, padded_out_w = orig_pad_h // 2, orig_pad_w // 2
        out_h, out_w = orig_h // 2, orig_w // 2

        mask_pred_results = F.interpolate(
            mask_pred_results,
            size=(padded_out_h, padded_out_w),
            mode="bilinear",
            align_corners=False,
        )
        depth_pred_results = F.interpolate(
            depth_pred_results,
            size=(padded_out_h, padded_out_w),
            mode="bilinear",
            align_corners=False,
        )

        # Postprocess each image
        # NOTE: We need to track the CROPPED mask_pred_result for outputs_2d
        # (matching the Triton model behavior)
        processed_results: List[Dict[str, Any]] = []
        final_mask_cls_result = None
        final_mask_pred_result = None

        for mask_cls_result, mask_pred_result, depth_pred_result, per_image_intrinsic in zip(
            mask_cls_results, mask_pred_results, depth_pred_results, intrinsic
        ):
            # Remove padding - OVERWRITE the variable like Triton does
            mask_pred_result = mask_pred_result[:, :out_h, :out_w]
            depth_pred_result = depth_pred_result[:, :out_h, :out_w]

            # Panoptic inference (retried on CUDA OOM by the helper)
            panoptic_seg, depth_r, segments_info, sem_prob_masks = self._retry_if_cuda_oom(
                self.post_processor.panoptic_inference
            )(mask_cls_result, mask_pred_result, depth_pred_result)

            processed_results.append({
                "panoptic_seg": (panoptic_seg, segments_info),
                "depth": depth_r,
                "image_size": (orig_w, orig_h),
                "padded_size": (orig_pad_w, orig_pad_h),
                "intrinsic": per_image_intrinsic,
                "sem_seg": sem_prob_masks,
            })

            # Store last iteration's results for outputs_2d (matching Triton behavior)
            final_mask_cls_result = mask_cls_result
            final_mask_pred_result = mask_pred_result

        if final_mask_cls_result is None:
            raise ValueError("2D inference received an empty batch (no images/intrinsics).")

        # Reconstruct outputs_2d - use CROPPED mask_pred_result from last iteration
        # This matches the Triton model's behavior exactly
        outputs_2d = {
            "pred_logits": final_mask_cls_result.unsqueeze(0),
            "pred_masks": final_mask_pred_result.unsqueeze(0),
            "enc_features": enc_features,
            "mask_features": mask_features,
            "depth_features": depth_features,
            "segm_decoder_out": segm_decoder_out,
            "pose_enc": pose_enc,
        }

        return outputs_2d, processed_results, occupancy_pred

    def _forward_3d(
        self,
        batched_inputs: Dict[str, torch.Tensor],
        outputs_2d: Dict[str, torch.Tensor],
        processed_results: List[Dict],
        kept: torch.Tensor,
        mapping: torch.Tensor,
        occupancy_pred: torch.Tensor,
    ) -> Dict[str, Any]:
        """Run 3D reconstruction pipeline.

        Args:
            batched_inputs: Dictionary containing frustum_mask, intrinsic, etc.
            outputs_2d: 2D model outputs.
            processed_results: Processed 2D results.
            kept: Kept voxel indices.
            mapping: Voxel to pixel mapping.
            occupancy_pred: Occupancy predictions.

        Returns:
            Postprocessed 3D results.
        """
        room_mask = batched_inputs.get("room_mask_buol") if self.is_matterport else None

        # Occupancy-aware lifting
        feat3d, mask3d = self.model_3d_components["ol"](
            processed_results, kept, mapping, occupancy_pred, room_mask
        )
        # Free intermediates early to keep peak GPU memory down.
        del occupancy_pred, mask3d
        torch.cuda.empty_cache()

        # Project features
        multi_scale_features = list(reversed(outputs_2d["enc_features"]))
        depth_features = self.model_3d_components["projector"](
            outputs_2d["depth_features"], outputs_2d["mask_features"].shape[-2:]
        )
        encoder_features = torch.cat([outputs_2d["mask_features"], depth_features], dim=1)

        sparse_multi_scale_features, sparse_encoder_features = self.model_3d_components["reprojection"](
            multi_scale_features, encoder_features, processed_results
        )
        del multi_scale_features, encoder_features
        torch.cuda.empty_cache()

        # Prepare 3D inputs
        segm_queries = outputs_2d["segm_decoder_out"]
        frustum_mask = batched_inputs["frustum_mask"]
        intrinsic = batched_inputs["intrinsic"]
        frustum_mask_64 = F.max_pool3d(
            frustum_mask[:, None].float(), kernel_size=2, stride=4
        ).bool()

        # Transform 3D coordinates
        transformed_feat3d = self._transform_feat3d_coordinates(feat3d, intrinsic)
        del feat3d

        # Fuse features
        if not self.is_matterport:
            multi_scale_feat3d = self._generate_multiscale_feat3d(transformed_feat3d)
            fused_multi_scale_features = [
                self._fuse_sparse_tensors(sparse_feat, feat3d_level)
                for sparse_feat, feat3d_level in zip(sparse_multi_scale_features, multi_scale_feat3d)
            ]
            del sparse_multi_scale_features, multi_scale_feat3d
        else:
            fused_multi_scale_features = sparse_multi_scale_features

        # Best-effort fusion: fall back to the un-fused encoder features on failure.
        try:
            fused_encoder_features = self._fuse_sparse_tensors(
                sparse_encoder_features, transformed_feat3d
            )
        except Exception:
            logger.debug("Encoder feature fusion failed; using un-fused features.", exc_info=True)
            fused_encoder_features = sparse_encoder_features
        del sparse_encoder_features, transformed_feat3d
        torch.cuda.empty_cache()

        # Run 3D completion
        outputs_3d = self.model_3d_components["completion"](
            fused_multi_scale_features, fused_encoder_features, segm_queries, frustum_mask_64
        )
        outputs_3d["pred_logits"] = outputs_2d["pred_logits"]
        outputs_3d["pred_masks"] = outputs_2d["pred_masks"]

        return self._postprocess(outputs_3d, outputs_2d, processed_results, frustum_mask)

    def forward(
        self,
        images: torch.Tensor,
        frustum_mask: torch.Tensor,
        intrinsic: torch.Tensor,
        height: Optional[torch.Tensor] = None,
        width: Optional[torch.Tensor] = None,
    ) -> PanopticRecon3DOutput:
        """Run full panoptic 3D reconstruction pipeline.

        Only the first batch element's results are returned.

        Args:
            images: Input images (B, C, H, W) as uint8 [0-255] or float [0-1].
            frustum_mask: Boolean frustum mask (B, D, H, W).
            intrinsic: Camera intrinsic matrices (B, 4, 4).
            height: Optional image heights (B,). Defaults to a one-element
                tensor holding ``images.shape[2]``.
            width: Optional image widths (B,). Defaults to a one-element
                tensor holding ``images.shape[3]``.

        Returns:
            PanopticRecon3DOutput with 2D and 3D predictions.

        Raises:
            RuntimeError: If weights have not been loaded.
        """
        self._ensure_initialized()

        # Prepare inputs
        if height is None:
            height = torch.tensor([images.shape[2]], device=images.device)
        if width is None:
            width = torch.tensor([images.shape[3]], device=images.device)

        batched_inputs = {
            "image": images,
            "frustum_mask": frustum_mask.bool(),
            "intrinsic": intrinsic,
            "height": height,
            "width": width,
        }

        # Run 2D inference
        outputs_2d, processed_results, occupancy_pred = self._infer_2d(images, intrinsic)

        # Compute kept and mapping (self has back_projection attribute)
        kept, mapping = self._get_kept_mapping(
            self, self._cfg, batched_inputs, device=images.device
        )

        # Run 3D inference
        outputs_3d = self._forward_3d(
            batched_inputs, outputs_2d, processed_results, kept, mapping, occupancy_pred
        )

        panoptic_2d = outputs_3d["panoptic_seg_2d"][0]
        return PanopticRecon3DOutput(
            panoptic_seg_3d=outputs_3d["panoptic_seg"][0],
            geometry_3d=outputs_3d["geometry"][0],
            semantic_seg_3d=outputs_3d["semantic_seg"][0],
            panoptic_seg_2d=panoptic_2d[0],
            depth_2d=outputs_3d["depth"][0],
            panoptic_semantic_mapping=outputs_3d["panoptic_semantic_mapping"][0],
            segments_info=panoptic_2d[1] if len(panoptic_2d) > 1 else None,
        )

    @torch.no_grad()
    def predict(
        self,
        image: Union[np.ndarray, torch.Tensor],
        frustum_mask: Optional[Union[np.ndarray, torch.Tensor]] = None,
        intrinsic: Optional[Union[np.ndarray, torch.Tensor]] = None,
    ) -> PanopticRecon3DOutput:
        """User-friendly prediction interface.

        Args:
            image: Input RGB image as a numpy array or torch tensor. Accepted formats:
                - (H, W, C) HWC format uint8 [0-255]
                - (C, H, W) CHW format uint8 [0-255]
                - (1, C, H, W) batched CHW format (from load_image)
                Non-uint8 inputs with max <= 1.0 are scaled by 255; others are
                clipped to [0, 255]. A 3-D input whose last axis has size 3 is
                treated as HWC (for both numpy arrays and tensors).
            frustum_mask: Optional frustum mask. If None, auto-generated using default intrinsic.
            intrinsic: Optional camera intrinsic (4x4). If None, uses DEFAULT_INTRINSIC.

        Returns:
            PanopticRecon3DOutput with predictions.

        Raises:
            RuntimeError: If weights have not been loaded.
            ValueError: If ``image`` is not 3-D or 4-D.
        """
        self._ensure_initialized()

        # Use default intrinsic if not provided
        if intrinsic is None:
            intrinsic = DEFAULT_INTRINSIC.copy()

        # Process image - match test_triton_server.py preprocessing exactly
        if isinstance(image, np.ndarray):
            if image.ndim == 3:
                if image.shape[2] == 3:
                    # HWC format -> CHW format
                    image = np.ascontiguousarray(image.transpose(2, 0, 1))
                # Now it's CHW, add batch dimension
                image = image[np.newaxis, ...]
            elif image.ndim != 4:
                raise ValueError(
                    f"Expected image with 3 (HWC/CHW) or 4 (1, C, H, W) dims, got shape {image.shape}"
                )

            # Ensure uint8
            if image.dtype != np.uint8:
                if image.max() <= 1.0:
                    image = (image * 255).clip(0, 255).astype(np.uint8)
                else:
                    image = image.clip(0, 255).astype(np.uint8)

            image = torch.from_numpy(image)
        else:
            # Tensor input: same layout handling as the numpy path.
            if image.dim() == 3:
                if image.shape[2] == 3:
                    # HWC format -> CHW format
                    image = image.permute(2, 0, 1).contiguous()
                image = image.unsqueeze(0)
            elif image.dim() != 4:
                raise ValueError(
                    f"Expected image with 3 (HWC/CHW) or 4 (1, C, H, W) dims, got shape {tuple(image.shape)}"
                )

            # Ensure uint8
            if image.dtype != torch.uint8:
                if image.max() <= 1.0:
                    image = (image * 255).clamp(0, 255).to(torch.uint8)
                else:
                    image = image.clamp(0, 255).to(torch.uint8)

        image = image.to(self._device)

        # Generate frustum mask if not provided
        if frustum_mask is None:
            intrinsic_np = intrinsic if isinstance(intrinsic, np.ndarray) else intrinsic.cpu().numpy()
            frustum_mask = create_frustum_mask(
                intrinsics=intrinsic_np,
                volume_shape=(self.frustum_dims_val,) * 3,
                depth_range=(self.depth_min, self.depth_max),
                voxel_size=self.voxel_size,
                # target_size is (width, height); the mask wants (height, width).
                image_shape=(self.target_size[1], self.target_size[0]),
            )
            frustum_mask = torch.from_numpy(frustum_mask).unsqueeze(0)
        elif isinstance(frustum_mask, np.ndarray):
            frustum_mask = torch.from_numpy(frustum_mask)

        if frustum_mask.dim() == 3:
            frustum_mask = frustum_mask.unsqueeze(0)
        frustum_mask = frustum_mask.to(self._device)

        # Convert intrinsic to tensor
        if isinstance(intrinsic, np.ndarray):
            intrinsic = torch.from_numpy(intrinsic)
        if intrinsic.dim() == 2:
            intrinsic = intrinsic.unsqueeze(0)
        intrinsic = intrinsic.float().to(self._device)

        return self.forward(image, frustum_mask, intrinsic)