|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
NVPanoptix-3D Model. |
|
|
""" |
|
|
|
|
|
import json
from dataclasses import dataclass, fields
from pathlib import Path
from typing import Optional, Tuple, List, Dict, Any, Union

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
from omegaconf import OmegaConf

from preprocessing import create_frustum_mask, DEFAULT_INTRINSIC
from nvpanoptix_3d.model_3d import Panoptic3DModel
from nvpanoptix_3d.utils.helper import get_kept_mapping, retry_if_cuda_oom
from nvpanoptix_3d.utils.coords_transform import (
    transform_feat3d_coordinates, fuse_sparse_tensors, generate_multiscale_feat3d
)
|
|
|
|
|
|
|
|
|
|
|
WEIGHTS_DIR = "weights" |
|
|
TORCHSCRIPT_2D_FILENAME = "model_2d_fp32.pt" |
|
|
CHECKPOINT_3D_FILENAME = "tao_vggt_front3d.pth" |
|
|
|
|
|
|
|
|
@dataclass
class PanopticRecon3DConfig:
    """Configuration for Panoptic Recon 3D model.

    This config is JSON-serializable and will be saved to config.json
    when using save_pretrained or push_to_hub.
    """

    # --- panoptic segmentation / inference thresholds ---
    num_classes: int = 13
    num_thing_classes: int = 9
    object_mask_threshold: float = 0.8
    overlap_threshold: float = 0.5
    test_topk_per_image: int = 100

    # --- backbone ---
    backbone_type: str = "vggt"

    # --- transformer decoder dimensions ---
    hidden_dim: int = 256
    num_queries: int = 100
    mask_dim: int = 256
    depth_dim: int = 256
    dec_layers: int = 10

    # --- 3D frustum / reconstruction ---
    frustum_dims: int = 256
    truncation: float = 3.0
    iso_recon_value: float = 2.0
    voxel_size: float = 0.03

    # --- depth projection ---
    depth_feature_dim: int = 256
    sign_channel: bool = True

    # --- image / depth preprocessing ---
    target_size: Tuple[int, int] = (320, 240)
    reduced_target_size: Tuple[int, int] = (160, 120)
    depth_size: Tuple[int, int] = (120, 160)
    depth_min: float = 0.4
    depth_max: float = 6.0
    depth_scale: float = 25.0
    pixel_mean: Tuple[float, float, float] = (0.485, 0.456, 0.406)
    pixel_std: Tuple[float, float, float] = (0.229, 0.224, 0.225)
    ignore_label: int = 255
    size_divisibility: int = 32
    downsample_factor: int = 1

    # --- 2D model loading options ---
    torchscript_2d_path: Optional[str] = None
    use_fp16_2d: bool = False

    # --- dataset mode ---
    is_matterport: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """Convert config to a JSON-serializable dictionary.

        Tuple-valued fields are converted to lists so the result can be
        passed directly to ``json.dump``. Keys follow field declaration
        order, matching the dataclass definition above.
        """
        result: Dict[str, Any] = {}
        for f in fields(self):
            value = getattr(self, f.name)
            # JSON has no tuple type; serialize tuples as lists.
            result[f.name] = list(value) if isinstance(value, tuple) else value
        return result

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "PanopticRecon3DConfig":
        """Create config from a dictionary (e.g. a parsed config.json).

        List-valued entries for tuple fields are converted back to tuples.
        The input dictionary is not modified: a shallow copy is taken
        before normalization.
        """
        config_dict = dict(config_dict)  # do not mutate the caller's dict
        for key in ("target_size", "reduced_target_size", "depth_size",
                    "pixel_mean", "pixel_std"):
            if key in config_dict:
                config_dict[key] = tuple(config_dict[key])
        return cls(**config_dict)
|
|
|
|
|
|
|
|
@dataclass
class PanopticRecon3DOutput:
    """Output from Panoptic Recon 3D model."""

    # 3D predictions
    panoptic_seg_3d: torch.Tensor
    geometry_3d: torch.Tensor
    semantic_seg_3d: torch.Tensor

    # 2D predictions
    panoptic_seg_2d: torch.Tensor
    depth_2d: torch.Tensor

    # Optional metadata describing the panoptic segments
    panoptic_semantic_mapping: Optional[Dict[int, int]] = None
    segments_info: Optional[List[Dict]] = None

    def to_numpy(self) -> Dict[str, np.ndarray]:
        """Return the tensor outputs as a dict of CPU numpy arrays.

        Only the tensor-valued fields are included; the optional metadata
        fields are left out.
        """
        tensor_fields = (
            "panoptic_seg_3d",
            "geometry_3d",
            "semantic_seg_3d",
            "panoptic_seg_2d",
            "depth_2d",
        )
        return {name: getattr(self, name).cpu().numpy() for name in tensor_fields}
|
|
|
|
|
|
|
|
class PanopticRecon3DModel( |
|
|
nn.Module, |
|
|
PyTorchModelHubMixin, |
|
|
|
|
|
repo_url="nvidia/nvpanoptix-3d", |
|
|
pipeline_tag="image-segmentation", |
|
|
license="apache-2.0", |
|
|
tags=["panoptic-segmentation", "3d-reconstruction", "depth-estimation", "nvidia"], |
|
|
): |
|
|
""" |
|
|
This model performs panoptic 3D scene reconstruction from a single RGB image. |
|
|
It combines: |
|
|
- 2D panoptic segmentation |
|
|
- Depth estimation |
|
|
- 3D volumetric reconstruction |
|
|
|
|
|
The model architecture uses: |
|
|
- VGGT backbone for feature extraction |
|
|
- MaskFormer head for panoptic segmentation |
|
|
- Occupancy-aware lifting for 2D-to-3D projection |
|
|
- Sparse 3D convolutions for volumetric completion |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
num_classes: int = 13, |
|
|
num_thing_classes: int = 9, |
|
|
object_mask_threshold: float = 0.8, |
|
|
overlap_threshold: float = 0.5, |
|
|
frustum_dims: int = 256, |
|
|
truncation: float = 3.0, |
|
|
iso_recon_value: float = 2.0, |
|
|
voxel_size: float = 0.03, |
|
|
depth_min: float = 0.4, |
|
|
depth_max: float = 6.0, |
|
|
target_size: Tuple[int, int] = (320, 240), |
|
|
reduced_target_size: Tuple[int, int] = (160, 120), |
|
|
size_divisibility: int = 32, |
|
|
downsample_factor: int = 1, |
|
|
is_matterport: bool = False, |
|
|
torchscript_2d_path: Optional[str] = None, |
|
|
use_fp16_2d: bool = False, |
|
|
**kwargs, |
|
|
): |
|
|
"""Initialize Panoptic Recon 3D model. |
|
|
|
|
|
Args: |
|
|
num_classes: Number of semantic classes. |
|
|
num_thing_classes: Number of "thing" (instance) classes. |
|
|
object_mask_threshold: Threshold for object mask confidence. |
|
|
overlap_threshold: Threshold for mask overlap. |
|
|
frustum_dims: Dimensions of 3D frustum volume. |
|
|
truncation: TSDF truncation distance. |
|
|
iso_recon_value: Iso-surface value for mesh extraction. |
|
|
voxel_size: Voxel size in meters. |
|
|
depth_min: Minimum depth value. |
|
|
depth_max: Maximum depth value. |
|
|
target_size: Target image size (width, height). |
|
|
reduced_target_size: Reduced target size for 3D projection. |
|
|
size_divisibility: Size divisibility for padding. |
|
|
downsample_factor: Downsample factor for 3D reconstruction. |
|
|
is_matterport: Whether using Matterport dataset mode. |
|
|
torchscript_2d_path: Path to TorchScript 2D model (optional). |
|
|
use_fp16_2d: Whether to use FP16 for 2D model. |
|
|
""" |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
self.num_classes = num_classes |
|
|
self.num_thing_classes = num_thing_classes |
|
|
self.object_mask_threshold = object_mask_threshold |
|
|
self.overlap_threshold = overlap_threshold |
|
|
self.frustum_dims_val = frustum_dims |
|
|
self.truncation = truncation |
|
|
self.iso_recon_value = iso_recon_value |
|
|
self.voxel_size = voxel_size |
|
|
self.depth_min = depth_min |
|
|
self.depth_max = depth_max |
|
|
self.target_size = target_size |
|
|
self.reduced_target_size = reduced_target_size |
|
|
self.size_divisibility = size_divisibility |
|
|
self.downsample_factor = downsample_factor |
|
|
self.is_matterport = is_matterport |
|
|
self.torchscript_2d_path = torchscript_2d_path |
|
|
self.use_fp16_2d = use_fp16_2d |
|
|
|
|
|
|
|
|
self.frustum_dims = [frustum_dims] * 3 |
|
|
|
|
|
|
|
|
self.model_2d: Optional[torch.jit.ScriptModule] = None |
|
|
self.model_3d_components: Optional[Dict[str, nn.Module]] = None |
|
|
self._initialized = False |
|
|
|
|
|
|
|
|
self.post_processor = None |
|
|
|
|
|
@classmethod |
|
|
def _from_pretrained( |
|
|
cls, |
|
|
*, |
|
|
model_id: str, |
|
|
revision: Optional[str] = None, |
|
|
cache_dir: Optional[str] = None, |
|
|
force_download: bool = False, |
|
|
proxies: Optional[Dict] = None, |
|
|
resume_download: bool = False, |
|
|
local_files_only: bool = False, |
|
|
token: Optional[Union[str, bool]] = None, |
|
|
map_location: str = "cpu", |
|
|
strict: bool = False, |
|
|
**model_kwargs, |
|
|
) -> "PanopticRecon3DModel": |
|
|
"""Load model from HuggingFace Hub or local directory. |
|
|
|
|
|
This method handles loading both the TorchScript 2D model and the 3D checkpoint. |
|
|
|
|
|
Args: |
|
|
model_id: HuggingFace Hub repo ID or local directory path. |
|
|
revision: Git revision (branch, tag, or commit hash). |
|
|
cache_dir: Cache directory for downloaded files. |
|
|
force_download: Force re-download even if cached. |
|
|
proxies: Proxy configuration. |
|
|
resume_download: Resume interrupted downloads. |
|
|
local_files_only: Only use local files, don't download. |
|
|
token: HuggingFace API token. |
|
|
map_location: Device to load model onto. |
|
|
strict: Strict loading (not used for this model). |
|
|
**model_kwargs: Additional model arguments. |
|
|
|
|
|
Returns: |
|
|
Initialized PanopticRecon3DModel with weights loaded. |
|
|
""" |
|
|
|
|
|
device = model_kwargs.pop("device", None) |
|
|
if device is None: |
|
|
device = map_location if map_location != "cpu" else "cuda:0" if torch.cuda.is_available() else "cpu" |
|
|
|
|
|
|
|
|
model_path = Path(model_id) |
|
|
if model_path.exists() and model_path.is_dir(): |
|
|
|
|
|
config_path = model_path / "config.json" |
|
|
weights_dir = model_path / WEIGHTS_DIR |
|
|
torchscript_2d_path = weights_dir / TORCHSCRIPT_2D_FILENAME |
|
|
checkpoint_3d_path = weights_dir / CHECKPOINT_3D_FILENAME |
|
|
|
|
|
|
|
|
if config_path.exists(): |
|
|
with open(config_path, "r") as f: |
|
|
config = json.load(f) |
|
|
|
|
|
for key, value in config.items(): |
|
|
if key not in model_kwargs: |
|
|
model_kwargs[key] = value |
|
|
else: |
|
|
|
|
|
|
|
|
try: |
|
|
config_file = hf_hub_download( |
|
|
repo_id=model_id, |
|
|
filename="config.json", |
|
|
revision=revision, |
|
|
cache_dir=cache_dir, |
|
|
force_download=force_download, |
|
|
proxies=proxies, |
|
|
resume_download=resume_download, |
|
|
local_files_only=local_files_only, |
|
|
token=token, |
|
|
) |
|
|
with open(config_file, "r") as f: |
|
|
config = json.load(f) |
|
|
for key, value in config.items(): |
|
|
if key not in model_kwargs: |
|
|
model_kwargs[key] = value |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
|
|
|
torchscript_2d_path = hf_hub_download( |
|
|
repo_id=model_id, |
|
|
filename=f"{WEIGHTS_DIR}/{TORCHSCRIPT_2D_FILENAME}", |
|
|
revision=revision, |
|
|
cache_dir=cache_dir, |
|
|
force_download=force_download, |
|
|
proxies=proxies, |
|
|
resume_download=resume_download, |
|
|
local_files_only=local_files_only, |
|
|
token=token, |
|
|
) |
|
|
checkpoint_3d_path = hf_hub_download( |
|
|
repo_id=model_id, |
|
|
filename=f"{WEIGHTS_DIR}/{CHECKPOINT_3D_FILENAME}", |
|
|
revision=revision, |
|
|
cache_dir=cache_dir, |
|
|
force_download=force_download, |
|
|
proxies=proxies, |
|
|
resume_download=resume_download, |
|
|
local_files_only=local_files_only, |
|
|
token=token, |
|
|
) |
|
|
|
|
|
|
|
|
model = cls(**model_kwargs) |
|
|
|
|
|
|
|
|
model.load_weights( |
|
|
torchscript_2d_path=str(torchscript_2d_path), |
|
|
checkpoint_3d_path=str(checkpoint_3d_path), |
|
|
device=device, |
|
|
) |
|
|
|
|
|
return model |
|
|
|
|
|
def _save_pretrained(self, save_directory: Path) -> None: |
|
|
"""Save model to directory. |
|
|
|
|
|
This saves the config.json and copies weight files to the directory. |
|
|
|
|
|
Args: |
|
|
save_directory: Directory to save model to. |
|
|
""" |
|
|
save_directory = Path(save_directory) |
|
|
save_directory.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
|
|
|
config = { |
|
|
"num_classes": self.num_classes, |
|
|
"num_thing_classes": self.num_thing_classes, |
|
|
"object_mask_threshold": self.object_mask_threshold, |
|
|
"overlap_threshold": self.overlap_threshold, |
|
|
"frustum_dims": self.frustum_dims_val, |
|
|
"truncation": self.truncation, |
|
|
"iso_recon_value": self.iso_recon_value, |
|
|
"voxel_size": self.voxel_size, |
|
|
"depth_min": self.depth_min, |
|
|
"depth_max": self.depth_max, |
|
|
"target_size": list(self.target_size), |
|
|
"reduced_target_size": list(self.reduced_target_size), |
|
|
"size_divisibility": self.size_divisibility, |
|
|
"downsample_factor": self.downsample_factor, |
|
|
"is_matterport": self.is_matterport, |
|
|
"use_fp16_2d": self.use_fp16_2d, |
|
|
} |
|
|
|
|
|
config_path = save_directory / "config.json" |
|
|
with open(config_path, "w") as f: |
|
|
json.dump(config, f, indent=2) |
|
|
|
|
|
|
|
|
weights_dir = save_directory / WEIGHTS_DIR |
|
|
weights_dir.mkdir(exist_ok=True) |
|
|
|
|
|
|
|
|
|
|
|
if self._initialized and hasattr(self, '_torchscript_2d_path'): |
|
|
import shutil |
|
|
|
|
|
src_2d = Path(self._torchscript_2d_path) |
|
|
if src_2d.exists(): |
|
|
shutil.copy2(src_2d, weights_dir / TORCHSCRIPT_2D_FILENAME) |
|
|
|
|
|
|
|
|
src_3d = Path(self._checkpoint_3d_path) |
|
|
if src_3d.exists(): |
|
|
shutil.copy2(src_3d, weights_dir / CHECKPOINT_3D_FILENAME) |
|
|
|
|
|
    def _build_omegaconf(self) -> Any:
        """Build OmegaConf config for internal model components.

        Returns the nested configuration consumed by ``Panoptic3DModel``.
        Most numeric values are fixed training-time hyperparameters; only
        the entries read from ``self`` vary per model instance.
        """
        return OmegaConf.create({
            "model": {
                "export": True,
                "mode": "panoptic",
                "object_mask_threshold": self.object_mask_threshold,
                "overlap_threshold": self.overlap_threshold,
                "test_topk_per_image": 100,
                # Backbone weights come from the checkpoint, not a pretrained file.
                "backbone": {"type": "vggt", "pretrained_weights": None},
                "sem_seg_head": {
                    "common_stride": 4,
                    "transformer_enc_layers": 6,
                    "convs_dim": 256,
                    "mask_dim": 256,
                    "depth_dim": 256,
                    "ignore_value": 255,
                    "deformable_transformer_encoder_in_features": ["res3", "res4", "res5"],
                    "num_classes": self.num_classes,
                    "norm": "GN",
                    "in_features": ["res2", "res3", "res4", "res5"]
                },
                # MaskFormer decoder settings; the *_weight entries are training
                # loss weights and are unused at inference time — TODO confirm.
                "mask_former": {
                    "dropout": 0.0,
                    "nheads": 8,
                    "num_object_queries": 100,
                    "hidden_dim": 256,
                    "transformer_dim_feedforward": 1024,
                    "dim_feedforward": 2048,
                    "dec_layers": 10,
                    "pre_norm": False,
                    "class_weight": 2.0,
                    "dice_weight": 5.0,
                    "mask_weight": 5.0,
                    "depth_weight": 5.0,
                    "mp_occ_weight": 5.0,
                    "train_num_points": 12544,
                    "oversample_ratio": 3.0,
                    "importance_sample_ratio": 0.75,
                    "deep_supervision": True,
                    "no_object_weight": 0.1,
                    "size_divisibility": self.size_divisibility
                },
                # 3D completion / TSDF settings.
                "frustum3d": {
                    "truncation": self.truncation,
                    "iso_recon_value": self.iso_recon_value,
                    "panoptic_weight": 25.0,
                    "completion_weights": [50.0, 25.0, 10.0],
                    "surface_weight": 5.0,
                    "unet_output_channels": 16,
                    "unet_features": 16,
                    "use_multi_scale": False,
                    "grid_dimensions": self.frustum_dims_val,
                    "frustum_dims": self.frustum_dims_val,
                    "signed_channel": 3
                },
                # 2D-to-3D lifting parameters.
                "projection": {
                    "voxel_size": self.voxel_size,
                    "sign_channel": True,
                    "depth_feature_dim": 256
                }
            },
            "dataset": {
                "contiguous_id": False,
                "label_map": "",
                "name": "",
                "downsample_factor": self.downsample_factor,
                "iso_value": 1.0,
                # ImageNet normalization statistics.
                "pixel_mean": [0.485, 0.456, 0.406],
                "pixel_std": [0.229, 0.224, 0.225],
                "ignore_label": 255,
                "min_instance_pixels": 200,
                "img_format": "RGB",
                "target_size": list(self.target_size),
                "reduced_target_size": list(self.reduced_target_size),
                # NOTE(review): depth_size is (height, width) here while
                # target_size is (width, height) — confirm consumers agree.
                "depth_size": [120, 160],
                "depth_bound": False,
                "depth_min": self.depth_min,
                "depth_max": self.depth_max,
                "frustum_mask_path": "",
                "occ_truncation_lvl": [8.0, 6.0],
                "truncation_range": [0.0, 12.0],
                "enable_3d": False,
                "enable_mp_occ": True,
                "depth_scale": 25.0,
                "num_thing_classes": self.num_thing_classes,
                "augmentation": {"size_divisibility": self.size_divisibility}
            }
        })
|
|
|
|
|
def load_weights( |
|
|
self, |
|
|
torchscript_2d_path: Optional[str] = None, |
|
|
checkpoint_3d_path: Optional[str] = None, |
|
|
device: str = "cuda:0", |
|
|
): |
|
|
"""Load model weights. |
|
|
|
|
|
Args: |
|
|
torchscript_2d_path: Path to TorchScript 2D model file. |
|
|
checkpoint_3d_path: Path to 3D model checkpoint (.pth/.pt). |
|
|
device: Device to load models onto. |
|
|
""" |
|
|
|
|
|
torchscript_2d_path = torchscript_2d_path or self.torchscript_2d_path |
|
|
|
|
|
if torchscript_2d_path is None: |
|
|
raise ValueError("torchscript_2d_path is required") |
|
|
if checkpoint_3d_path is None: |
|
|
raise ValueError("checkpoint_3d_path is required") |
|
|
|
|
|
|
|
|
self._torchscript_2d_path = torchscript_2d_path |
|
|
self._checkpoint_3d_path = checkpoint_3d_path |
|
|
|
|
|
|
|
|
cfg = self._build_omegaconf() |
|
|
|
|
|
|
|
|
self.model_2d = torch.jit.load(torchscript_2d_path, map_location=device) |
|
|
self.model_2d.eval() |
|
|
|
|
|
|
|
|
full_model = Panoptic3DModel(cfg) |
|
|
|
|
|
checkpoint = torch.load(checkpoint_3d_path, map_location="cpu") |
|
|
state_dict = checkpoint.get("state_dict", checkpoint) |
|
|
|
|
|
|
|
|
filtered_state_dict = {} |
|
|
for key, value in state_dict.items(): |
|
|
new_key = key[6:] if key.startswith("model.") else key |
|
|
filtered_state_dict[new_key] = value |
|
|
|
|
|
full_model.load_state_dict(filtered_state_dict, strict=False) |
|
|
full_model.to(device) |
|
|
full_model.eval() |
|
|
|
|
|
|
|
|
self.model_3d_components = { |
|
|
"ol": full_model.ol, |
|
|
"reprojection": full_model.reprojection, |
|
|
"completion": full_model.completion, |
|
|
"projector": full_model.projector, |
|
|
"back_projection": full_model.back_projection, |
|
|
} |
|
|
|
|
|
|
|
|
self.post_processor = full_model.post_processor |
|
|
self.back_projection = full_model.back_projection |
|
|
self._get_kept_mapping = get_kept_mapping |
|
|
self._transform_feat3d_coordinates = transform_feat3d_coordinates |
|
|
self._fuse_sparse_tensors = fuse_sparse_tensors |
|
|
self._generate_multiscale_feat3d = generate_multiscale_feat3d |
|
|
self._retry_if_cuda_oom = retry_if_cuda_oom |
|
|
self._panoptic_3d_inference = full_model.panoptic_3d_inference |
|
|
self._postprocess = full_model.postprocess |
|
|
self._cfg = cfg |
|
|
|
|
|
|
|
|
for module in self.model_3d_components.values(): |
|
|
for param in module.parameters(): |
|
|
param.requires_grad = False |
|
|
module.eval() |
|
|
|
|
|
self._initialized = True |
|
|
self._device = device |
|
|
|
|
|
def _ensure_initialized(self): |
|
|
"""Ensure model is initialized.""" |
|
|
if not self._initialized: |
|
|
raise RuntimeError( |
|
|
"Model weights not loaded. Call load_weights() first, or use " |
|
|
"from_pretrained() to load a pre-trained model." |
|
|
) |
|
|
|
|
|
    def _infer_2d(
        self,
        images: torch.Tensor,
        intrinsic: torch.Tensor,
    ) -> Tuple[Dict[str, torch.Tensor], List[Dict], torch.Tensor]:
        """Run 2D inference using TorchScript model.

        Args:
            images: Input images (B, C, H, W) as uint8 or float.
            intrinsic: Camera intrinsics (B, 4, 4).

        Returns:
            outputs_2d: Dictionary of 2D model outputs.
            processed_results: List of processed results per image.
            occupancy_pred: Occupancy predictions.
        """
        with torch.no_grad():
            # Optionally run the TorchScript model under autocast for FP16.
            if self.use_fp16_2d:
                with torch.cuda.amp.autocast():
                    outputs_dict = self.model_2d(images)
            else:
                outputs_dict = self.model_2d(images)

        # Downstream ops expect fp32 regardless of the 2D model's precision.
        def to_fp32(x):
            return x.float() if isinstance(x, torch.Tensor) and x.dtype != torch.float32 else x

        mask_cls_results = to_fp32(outputs_dict["pred_logits"])
        mask_pred_results = to_fp32(outputs_dict["pred_masks"])
        depth_pred_results = to_fp32(outputs_dict["pred_depths"])

        # Multi-scale encoder feature maps exported by the TorchScript model.
        enc_features = [
            to_fp32(outputs_dict["enc_features_0"]),
            to_fp32(outputs_dict["enc_features_1"]),
            to_fp32(outputs_dict["enc_features_2"]),
            to_fp32(outputs_dict["enc_features_3"]),
        ]
        mask_features = to_fp32(outputs_dict["mask_features"])
        depth_features = to_fp32(outputs_dict["depth_features"])
        segm_decoder_out = to_fp32(outputs_dict["segm_decoder_out"])
        pose_enc = to_fp32(outputs_dict["pose_enc"])
        occupancy_pred = to_fp32(outputs_dict["occupancy_pred"])

        # Original (pre-/post-padding) image sizes are exported as 0-d tensors.
        orig_pad_h = int(outputs_dict["orig_pad_h"].item())
        orig_pad_w = int(outputs_dict["orig_pad_w"].item())
        orig_h = int(outputs_dict["orig_h"].item())
        orig_w = int(outputs_dict["orig_w"].item())

        # Upsample predictions to half the padded input resolution
        # (the // 2 target is the pipeline's working output scale).
        padded_out_h, padded_out_w = orig_pad_h // 2, orig_pad_w // 2
        mask_pred_results = F.interpolate(
            mask_pred_results,
            size=(padded_out_h, padded_out_w),
            mode="bilinear",
            align_corners=False,
        )
        depth_pred_results = F.interpolate(
            depth_pred_results,
            size=(padded_out_h, padded_out_w),
            mode="bilinear",
            align_corners=False,
        )

        processed_results = []
        final_mask_cls_result = None
        final_mask_pred_result = None

        # Per-image panoptic post-processing.
        for idx, (mask_cls_result, mask_pred_result, depth_pred_result, per_image_intrinsic) in enumerate(zip(
            mask_cls_results, mask_pred_results, depth_pred_results, intrinsic
        )):
            out_h, out_w = orig_h // 2, orig_w // 2
            processed_results.append({})

            # Crop the padding back off (padded size >= original size).
            mask_pred_result = mask_pred_result[:, :out_h, :out_w]
            depth_pred_result = depth_pred_result[:, :out_h, :out_w]

            # Panoptic inference can be memory-hungry; retry on CPU on OOM.
            panoptic_seg, depth_r, segments_info, sem_prob_masks = self._retry_if_cuda_oom(
                self.post_processor.panoptic_inference
            )(
                mask_cls_result,
                mask_pred_result,
                depth_pred_result
            )
            depth_r = depth_r[None]

            processed_results[-1]["panoptic_seg"] = (panoptic_seg, segments_info)
            processed_results[-1]["depth"] = depth_r[0]
            # Sizes are stored as (width, height) pairs.
            processed_results[-1]["image_size"] = (orig_w, orig_h)
            processed_results[-1]["padded_size"] = (orig_pad_w, orig_pad_h)
            processed_results[-1]["intrinsic"] = per_image_intrinsic
            processed_results[-1]["sem_seg"] = sem_prob_masks

            final_mask_cls_result = mask_cls_result
            final_mask_pred_result = mask_pred_result

        # NOTE(review): only the LAST image's logits/masks survive into
        # outputs_2d — this effectively assumes batch size 1; confirm callers.
        outputs_2d = {
            "pred_logits": final_mask_cls_result.unsqueeze(0),
            "pred_masks": final_mask_pred_result.unsqueeze(0),
            "enc_features": enc_features,
            "mask_features": mask_features,
            "depth_features": depth_features,
            "segm_decoder_out": segm_decoder_out,
            "pose_enc": pose_enc,
        }

        return outputs_2d, processed_results, occupancy_pred
|
|
|
|
|
    def _forward_3d(
        self,
        batched_inputs: Dict[str, torch.Tensor],
        outputs_2d: Dict[str, torch.Tensor],
        processed_results: List[Dict],
        kept: torch.Tensor,
        mapping: torch.Tensor,
        occupancy_pred: torch.Tensor,
    ) -> Dict[str, Any]:
        """Run 3D reconstruction pipeline.

        Args:
            batched_inputs: Dictionary containing frustum_mask, intrinsic, etc.
            outputs_2d: 2D model outputs.
            processed_results: Processed 2D results.
            kept: Kept voxel indices.
            mapping: Voxel to pixel mapping.
            occupancy_pred: Occupancy predictions.

        Returns:
            Postprocessed 3D results.
        """
        # Matterport scenes additionally carry a room mask for the lifting step.
        room_mask = batched_inputs.get("room_mask_buol") if self.is_matterport else None

        # Occupancy-aware lifting: project 2D results into a sparse 3D volume.
        feat3d, mask3d = self.model_3d_components["ol"](
            processed_results, kept, mapping, occupancy_pred, room_mask
        )
        del occupancy_pred, mask3d
        torch.cuda.empty_cache()

        # Reverse so features go coarsest-first, matching the reprojection input.
        multi_scale_features = list(reversed(outputs_2d["enc_features"]))
        depth_features = self.model_3d_components["projector"](
            outputs_2d["depth_features"],
            outputs_2d["mask_features"].shape[-2:]
        )
        encoder_features = torch.cat([outputs_2d["mask_features"], depth_features], dim=1)

        # Reproject 2D feature maps into sparse 3D tensors.
        sparse_multi_scale_features, sparse_encoder_features = self.model_3d_components["reprojection"](
            multi_scale_features, encoder_features, processed_results
        )

        del multi_scale_features, encoder_features
        torch.cuda.empty_cache()

        segm_queries = outputs_2d["segm_decoder_out"]
        frustum_mask = batched_inputs["frustum_mask"]
        intrinsic = batched_inputs["intrinsic"]

        # Downsample the frustum mask to the completion network's resolution.
        # NOTE(review): kernel_size=2 with stride=4 skips every other 2-voxel
        # window — confirm this (rather than kernel_size=4) is intentional.
        frustum_mask_64 = F.max_pool3d(
            frustum_mask[:, None].float(),
            kernel_size=2,
            stride=4
        ).bool()

        # Move lifted features into the camera/grid coordinate frame.
        transformed_feat3d = self._transform_feat3d_coordinates(feat3d, intrinsic)
        del feat3d

        if not self.is_matterport:
            # Fuse reprojected 2D features with lifted 3D features per scale.
            multi_scale_feat3d = self._generate_multiscale_feat3d(transformed_feat3d)
            fused_multi_scale_features = [
                self._fuse_sparse_tensors(sparse_multi_scale_features[i], multi_scale_feat3d[i])
                for i in range(len(multi_scale_feat3d))
            ]
            del sparse_multi_scale_features, multi_scale_feat3d
        else:
            fused_multi_scale_features = sparse_multi_scale_features

        # Best-effort fusion: fall back to the reprojected features alone if
        # the sparse tensors cannot be fused (deliberate broad except).
        try:
            fused_encoder_features = self._fuse_sparse_tensors(
                sparse_encoder_features, transformed_feat3d
            )
        except Exception:
            fused_encoder_features = sparse_encoder_features

        del sparse_encoder_features, transformed_feat3d
        torch.cuda.empty_cache()

        # Sparse 3D completion network produces the volumetric predictions.
        outputs_3d = self.model_3d_components["completion"](
            fused_multi_scale_features,
            fused_encoder_features,
            segm_queries,
            frustum_mask_64
        )

        # Carry the 2D query classifications/masks alongside the 3D outputs
        # so postprocess can associate 3D segments with 2D classes.
        outputs_3d["pred_logits"] = outputs_2d["pred_logits"]
        outputs_3d["pred_masks"] = outputs_2d["pred_masks"]

        return self._postprocess(outputs_3d, outputs_2d, processed_results, frustum_mask)
|
|
|
|
|
def forward( |
|
|
self, |
|
|
images: torch.Tensor, |
|
|
frustum_mask: torch.Tensor, |
|
|
intrinsic: torch.Tensor, |
|
|
height: Optional[torch.Tensor] = None, |
|
|
width: Optional[torch.Tensor] = None, |
|
|
) -> PanopticRecon3DOutput: |
|
|
"""Run full panoptic 3D reconstruction pipeline. |
|
|
|
|
|
Args: |
|
|
images: Input images (B, C, H, W) as uint8 [0-255] or float [0-1]. |
|
|
frustum_mask: Boolean frustum mask (B, D, H, W). |
|
|
intrinsic: Camera intrinsic matrices (B, 4, 4). |
|
|
height: Optional image heights (B,). |
|
|
width: Optional image widths (B,). |
|
|
|
|
|
Returns: |
|
|
PanopticRecon3DOutput with 2D and 3D predictions. |
|
|
""" |
|
|
self._ensure_initialized() |
|
|
|
|
|
|
|
|
if height is None: |
|
|
height = torch.tensor([images.shape[2]], device=images.device) |
|
|
if width is None: |
|
|
width = torch.tensor([images.shape[3]], device=images.device) |
|
|
|
|
|
batched_inputs = { |
|
|
"image": images, |
|
|
"frustum_mask": frustum_mask.bool(), |
|
|
"intrinsic": intrinsic, |
|
|
"height": height, |
|
|
"width": width, |
|
|
} |
|
|
|
|
|
|
|
|
outputs_2d, processed_results, occupancy_pred = self._infer_2d(images, intrinsic) |
|
|
|
|
|
|
|
|
kept, mapping = self._get_kept_mapping( |
|
|
self, |
|
|
self._cfg, |
|
|
batched_inputs, |
|
|
device=images.device |
|
|
) |
|
|
|
|
|
|
|
|
outputs_3d = self._forward_3d( |
|
|
batched_inputs, outputs_2d, processed_results, kept, mapping, occupancy_pred |
|
|
) |
|
|
|
|
|
|
|
|
return PanopticRecon3DOutput( |
|
|
panoptic_seg_3d=outputs_3d["panoptic_seg"][0], |
|
|
geometry_3d=outputs_3d["geometry"][0], |
|
|
semantic_seg_3d=outputs_3d["semantic_seg"][0], |
|
|
panoptic_seg_2d=outputs_3d["panoptic_seg_2d"][0][0], |
|
|
depth_2d=outputs_3d["depth"][0], |
|
|
panoptic_semantic_mapping=outputs_3d["panoptic_semantic_mapping"][0], |
|
|
segments_info=outputs_3d["panoptic_seg_2d"][0][1] if len(outputs_3d["panoptic_seg_2d"][0]) > 1 else None, |
|
|
) |
|
|
|
|
|
    @torch.no_grad()
    def predict(
        self,
        image: Union[np.ndarray, torch.Tensor],
        frustum_mask: Optional[Union[np.ndarray, torch.Tensor]] = None,
        intrinsic: Optional[Union[np.ndarray, torch.Tensor]] = None,
    ) -> PanopticRecon3DOutput:
        """User-friendly prediction interface.

        Args:
            image: Input RGB image as numpy array. Accepted formats:
                - (H, W, C) HWC format uint8 [0-255]
                - (C, H, W) CHW format uint8 [0-255]
                - (1, C, H, W) batched CHW format (from load_image)
            frustum_mask: Optional frustum mask. If None, auto-generated using default intrinsic.
            intrinsic: Optional camera intrinsic (4x4). If None, uses DEFAULT_INTRINSIC.

        Returns:
            PanopticRecon3DOutput with predictions.
        """
        self._ensure_initialized()

        # Copy so the module-level default is never mutated downstream.
        if intrinsic is None:
            intrinsic = DEFAULT_INTRINSIC.copy()

        # --- normalize the image to a (1, C, H, W) uint8 torch tensor ---
        if isinstance(image, np.ndarray):
            if image.ndim == 4:
                # Already batched CHW.
                pass
            elif image.ndim == 3:
                # A trailing dim of 3 is taken to mean HWC; convert to CHW.
                if image.shape[2] == 3:
                    image = np.ascontiguousarray(image.transpose(2, 0, 1))
                image = image[np.newaxis, ...]

            # Float inputs: treat max <= 1.0 as [0, 1]-scaled, else [0, 255].
            # NOTE(review): an all-dark float image in [0, 255] with max <= 1
            # would be misclassified as [0, 1]-scaled — acceptable edge case?
            if image.dtype != np.uint8:
                if image.max() <= 1.0:
                    image = (image * 255).clip(0, 255).astype(np.uint8)
                else:
                    image = image.clip(0, 255).astype(np.uint8)

            image = torch.from_numpy(image)
        else:
            # Torch tensor path: add a batch dim, then apply the same
            # float-range heuristic as above.
            if image.dim() == 3:
                image = image.unsqueeze(0)

            if image.dtype != torch.uint8:
                if image.max() <= 1.0:
                    image = (image * 255).clamp(0, 255).to(torch.uint8)
                else:
                    image = image.clamp(0, 255).to(torch.uint8)

        image = image.to(self._device)

        # --- frustum mask: auto-generate from intrinsics when absent ---
        if frustum_mask is None:
            intrinsic_np = intrinsic if isinstance(intrinsic, np.ndarray) else intrinsic.cpu().numpy()
            frustum_mask = create_frustum_mask(
                intrinsics=intrinsic_np,
                volume_shape=(self.frustum_dims_val,) * 3,
                depth_range=(self.depth_min, self.depth_max),
                voxel_size=self.voxel_size,
                # create_frustum_mask expects (height, width); target_size is (width, height).
                image_shape=(self.target_size[1], self.target_size[0]),
            )
            frustum_mask = torch.from_numpy(frustum_mask).unsqueeze(0)
        elif isinstance(frustum_mask, np.ndarray):
            frustum_mask = torch.from_numpy(frustum_mask)
            if frustum_mask.dim() == 3:
                frustum_mask = frustum_mask.unsqueeze(0)

        frustum_mask = frustum_mask.to(self._device)

        # --- intrinsics: convert to a batched float tensor on the device ---
        if isinstance(intrinsic, np.ndarray):
            intrinsic = torch.from_numpy(intrinsic)
        if intrinsic.dim() == 2:
            intrinsic = intrinsic.unsqueeze(0)

        intrinsic = intrinsic.float().to(self._device)

        return self.forward(image, frustum_mask, intrinsic)
|
|
|
|
|
|