nvpanoptix-3d / visualization.py
tinlam-nv's picture
Update model inference code and environment setup instructions
911b379 verified
raw
history blame
14.8 kB
#!/usr/bin/env python3
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Visualization utilities for Panoptic Recon 3D model outputs.
This module provides functions for:
- 2D segmentation visualization
- Depth map visualization
- 3D mesh extraction and PLY export
"""
from pathlib import Path
from typing import Optional, Tuple, Union
import numpy as np
# Optional imports for visualization
try:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
HAS_MATPLOTLIB = True
except ImportError:
HAS_MATPLOTLIB = False
try:
from PIL import Image
HAS_PIL = True
except ImportError:
HAS_PIL = False
try:
from skimage import measure
HAS_SKIMAGE = True
except ImportError:
HAS_SKIMAGE = False
try:
from scipy.spatial import KDTree
HAS_SCIPY = True
except ImportError:
HAS_SCIPY = False
def create_color_palette() -> np.ndarray:
    """Build the Front3D color table used to paint semantic classes.

    Row ``i`` of the result is the RGB color of class index ``i``. Entries
    annotated with only an index have no class name recorded here.

    Returns:
        Newly allocated array of shape (42, 3), dtype uint8, holding RGB
        values.
    """
    rgb_by_class = (
        (0, 0, 0),        # 0: background
        (174, 199, 232),  # 1: wall
        (152, 223, 138),  # 2: floor
        (31, 119, 180),   # 3: cabinet
        (255, 187, 120),  # 4: bed
        (188, 189, 34),   # 5: chair
        (140, 86, 75),    # 6: sofa
        (255, 152, 150),  # 7: table
        (214, 39, 40),    # 8: door
        (197, 176, 213),  # 9: window
        (148, 103, 189),  # 10: bookshelf
        (196, 156, 148),  # 11: picture
        (23, 190, 207),   # 12: counter
        (178, 76, 76),    # 13
        (247, 182, 210),  # 14: desk
        (66, 188, 102),   # 15
        (219, 219, 141),  # 16: curtain
        (140, 57, 197),   # 17
        (202, 185, 52),   # 18
        (51, 176, 203),   # 19
        (200, 54, 131),   # 20
        (92, 193, 61),    # 21
        (78, 71, 183),    # 22
        (172, 114, 82),   # 23
        (255, 127, 14),   # 24: refrigerator
        (91, 163, 138),   # 25
        (153, 98, 156),   # 26
        (140, 153, 101),  # 27
        (158, 218, 229),  # 28: shower curtain
        (100, 125, 154),  # 29
        (178, 127, 135),  # 30
        (120, 185, 128),  # 31
        (146, 111, 194),  # 32
        (44, 160, 44),    # 33: toilet
        (112, 128, 144),  # 34: sink
        (96, 207, 209),   # 35
        (227, 119, 194),  # 36: bathtub
        (213, 92, 176),   # 37
        (94, 106, 211),   # 38
        (82, 84, 163),    # 39: otherfurn
        (100, 85, 144),   # 40
        (172, 172, 172),  # 41
    )
    return np.array(rgb_by_class, dtype=np.uint8)
def colorize_segmentation(
    segmentation: np.ndarray,
    palette: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Colorize a label map by looking each label up in a palette.

    Labels outside ``[0, len(palette) - 1]`` are clipped into that range, so
    negative labels receive the first palette entry and labels beyond the
    palette receive the last one.

    Args:
        segmentation: Label map (H, W) with class indices. Integer dtypes are
            used directly; any other dtype (e.g. float or bool) is clipped and
            then truncated to integers, with NaN treated as label 0.
        palette: Color palette (N, 3). Uses the default palette from
            ``create_color_palette()`` if None.

    Returns:
        Colorized image (H, W, 3) with the palette's dtype (uint8 for the
        default palette).
    """
    if palette is None:
        palette = create_color_palette()
    palette = np.asarray(palette)

    labels = np.asarray(segmentation)
    if not np.issubdtype(labels.dtype, np.integer):
        # Non-integer arrays cannot be used for fancy indexing (floats raise
        # IndexError, bools are interpreted as masks). NaN has no meaningful
        # integer value, so map it to label 0 before the cast below.
        labels = np.nan_to_num(labels.astype(np.float64), nan=0.0)

    # Clip first so the cast never sees out-of-range values, then convert to
    # the platform index type that NumPy expects for lookups.
    labels = np.clip(labels, 0, len(palette) - 1).astype(np.intp)
    return palette[labels]
def visualize_2d_segmentation(
    image: np.ndarray,
    panoptic_2d: np.ndarray,
    output_path: Optional[Union[str, Path]] = None,
    alpha: float = 0.6,
    figsize: Tuple[int, int] = (18, 6),
    dpi: int = 150,
) -> Optional[np.ndarray]:
    """Visualize 2D panoptic segmentation overlaid on image.

    With ``output_path=None`` the blended overlay is returned as an array.
    Otherwise a three-panel figure (original, segmentation, overlay) is
    written to ``output_path`` and nothing is returned.

    Args:
        image: Original image, (H, W, C) or grayscale (H, W). Single-channel
            input is replicated to three channels and channels beyond the
            third (e.g. alpha) are dropped so it can be blended with the RGB
            segmentation colors.
        panoptic_2d: Panoptic segmentation map (H, W).
        output_path: Path to save visualization. If None, returns array.
        alpha: Blend alpha for overlay; weight of the segmentation colors.
        figsize: Figure size.
        dpi: DPI for saved figure.

    Returns:
        Overlay image (H, W, 3) as uint8 if output_path is None, else None.

    Raises:
        ImportError: If a figure must be saved but matplotlib is missing, or
            if the image must be resized but PIL is missing.
    """
    # The overlay itself is pure NumPy; matplotlib is only needed to render
    # and save the side-by-side figure.
    if output_path is not None and not HAS_MATPLOTLIB:
        raise ImportError("matplotlib required for visualization")

    colored_seg = colorize_segmentation(panoptic_2d, create_color_palette())

    # Bring the image to exactly three channels so it broadcasts against the
    # (H, W, 3) segmentation colors.
    image = np.asarray(image)
    if image.ndim == 2:
        image = np.stack((image,) * 3, axis=-1)
    elif image.ndim == 3 and image.shape[2] == 1:
        image = np.repeat(image, 3, axis=2)
    elif image.ndim == 3 and image.shape[2] > 3:
        image = np.ascontiguousarray(image[:, :, :3])

    # Resize image to match segmentation if needed; PIL is only required here.
    if image.shape[:2] != panoptic_2d.shape:
        if not HAS_PIL:
            raise ImportError("PIL required for visualization")
        # PIL takes (width, height), i.e. the reverse of the array shape.
        target_size = (panoptic_2d.shape[1], panoptic_2d.shape[0])
        image = np.array(Image.fromarray(image).resize(target_size, Image.LANCZOS))

    # Blend in float to avoid uint8 wrap-around, then clamp back to uint8.
    overlay = image.astype(np.float32) * (1 - alpha) + colored_seg.astype(np.float32) * alpha
    overlay = overlay.clip(0, 255).astype(np.uint8)

    if output_path is None:
        return overlay

    # Create side-by-side visualization
    fig, axes = plt.subplots(1, 3, figsize=figsize)
    try:
        panels = (
            (image, 'Original Image'),
            (colored_seg, 'Panoptic Segmentation'),
            (overlay, 'Overlay'),
        )
        for ax, (panel, title) in zip(axes, panels):
            ax.imshow(panel)
            ax.set_title(title, fontsize=14, fontweight='bold')
            ax.axis('off')
        fig.tight_layout()
        fig.savefig(output_path, dpi=dpi, bbox_inches='tight')
    finally:
        # Close this specific figure even if saving fails, so repeated calls
        # do not accumulate open figures.
        plt.close(fig)

    print(f"✓ Saved 2D segmentation visualization to: {output_path}")
    return None
def visualize_depth_map(
    depth_2d: np.ndarray,
    output_path: Optional[Union[str, Path]] = None,
    vmin: float = 0.0,
    vmax: float = 6.0,
    cmap: str = 'viridis',
    figsize: Tuple[int, int] = (10, 8),
    dpi: int = 150,
) -> Optional[np.ndarray]:
    """Visualize depth map.

    With ``output_path=None`` the colorized depth is returned as an array.
    Otherwise a figure with a colorbar is written to ``output_path`` and
    nothing is returned.

    Args:
        depth_2d: Depth map (H, W).
        output_path: Path to save visualization. If None, returns array.
        vmin: Minimum depth for colormap.
        vmax: Maximum depth for colormap. If equal to ``vmin`` every pixel is
            mapped to the low end of the colormap.
        cmap: Matplotlib colormap name.
        figsize: Figure size.
        dpi: DPI for saved figure.

    Returns:
        Colorized depth (H, W, 3) as uint8 if output_path is None, else None.

    Raises:
        ImportError: If matplotlib is not installed.
    """
    if not HAS_MATPLOTLIB:
        raise ImportError("matplotlib required for visualization")

    # Normalize depth to [0, 1]. A zero-width range would divide by zero and
    # fill the map with NaN/inf, so collapse it to all zeros instead.
    depth_range = vmax - vmin
    if depth_range == 0:
        depth_norm = np.zeros(np.shape(depth_2d), dtype=np.float64)
    else:
        depth_norm = np.clip((depth_2d - vmin) / depth_range, 0, 1)

    # The colormap returns RGBA floats in [0, 1]; keep RGB and scale to uint8.
    colormap = plt.get_cmap(cmap)
    depth_colored = (colormap(depth_norm)[:, :, :3] * 255).astype(np.uint8)

    if output_path is None:
        return depth_colored

    fig, ax = plt.subplots(1, 1, figsize=figsize)
    try:
        # Plot the raw depth so the colorbar is labelled in depth units.
        im = ax.imshow(depth_2d, cmap=cmap, vmin=vmin, vmax=vmax)
        ax.set_title('Depth Map', fontsize=14, fontweight='bold')
        ax.axis('off')
        cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        cbar.set_label('Depth (m)', rotation=270, labelpad=20, fontsize=12)
        fig.tight_layout()
        fig.savefig(output_path, dpi=dpi, bbox_inches='tight')
    finally:
        # Close this specific figure even if saving fails, so repeated calls
        # do not accumulate open figures.
        plt.close(fig)

    print(f"✓ Saved depth map visualization to: {output_path}")
    return None
def get_mesh(
    distance_field: np.ndarray,
    iso_value: float = 1.0,
    spacing: Tuple[float, float, float] = (1.0, 1.0, 1.0),
) -> Tuple[np.ndarray, np.ndarray]:
    """Run marching cubes on a distance field and return the surface mesh.

    Args:
        distance_field: 3D distance field (D, H, W).
        iso_value: Level of the iso-surface to extract.
        spacing: Voxel spacing passed through to marching cubes.

    Returns:
        Tuple ``(vertices, faces)``: mesh vertices (N, 3) and mesh faces
        (M, 3).

    Raises:
        ImportError: If scikit-image is not installed.
    """
    if HAS_SKIMAGE:
        # Marching cubes also yields two further outputs (normals and values
        # per the scikit-image docs); callers here only need the geometry.
        verts, tris, _normals, _values = measure.marching_cubes(
            distance_field, level=iso_value, spacing=spacing
        )
        return verts, tris
    raise ImportError("scikit-image required for mesh extraction")
def write_ply(
    vertices: np.ndarray,
    output_file: Union[str, Path],
    colors: Optional[np.ndarray] = None,
    faces: Optional[np.ndarray] = None,
) -> None:
    """Write an ASCII PLY file.

    Args:
        vertices: Vertex positions (N, 3).
        output_file: Output PLY file path.
        colors: Optional vertex colors (N, 3) as uint8; must have one row per
            vertex.
        faces: Optional face indices (M, 3). The face element is omitted from
            the file when this is None or empty.

    Raises:
        ValueError: If ``colors`` is given but its length differs from
            ``vertices``.
    """
    # Validate before touching the file: pairing mismatched arrays would
    # silently drop rows and leave a file whose header declares more vertices
    # than it contains.
    if colors is not None and len(colors) != len(vertices):
        raise ValueError(
            f"colors has {len(colors)} rows but vertices has {len(vertices)}"
        )

    has_faces = faces is not None and len(faces) > 0

    header = [
        "ply",
        "format ascii 1.0",
        f"element vertex {len(vertices)}",
        "property float x",
        "property float y",
        "property float z",
    ]
    if colors is not None:
        header += [
            "property uchar red",
            "property uchar green",
            "property uchar blue",
        ]
    if has_faces:
        header += [
            f"element face {len(faces)}",
            "property list uchar uint vertex_indices",
        ]
    header.append("end_header")

    with open(output_file, "w", encoding="ascii") as f:
        f.write("\n".join(header) + "\n")

        # Write vertices (writelines batches the many small rows).
        if colors is not None:
            f.writelines(
                f"{v[0]} {v[1]} {v[2]} {int(c[0])} {int(c[1])} {int(c[2])}\n"
                for v, c in zip(vertices, colors)
            )
        else:
            f.writelines(f"{v[0]} {v[1]} {v[2]}\n" for v in vertices)

        # Write faces; each row is prefixed with its vertex count (3).
        if has_faces:
            f.writelines(
                f"3 {face[0]} {face[1]} {face[2]}\n" for face in faces
            )
def save_3d_mesh(
    geometry_3d: np.ndarray,
    semantic_3d: np.ndarray,
    output_path: Union[str, Path],
    iso_value: float = 1.0,
    voxel_size: float = 0.03,
) -> bool:
    """Extract and save 3D mesh with semantic colors.

    The mesh is extracted from ``geometry_3d`` with marching cubes. When scipy
    is available and ``semantic_3d`` has any non-zero voxel, each vertex is
    colored with the palette color of its nearest non-zero-labeled voxel;
    otherwise the mesh is written without colors. Any failure is reported as
    a printed warning rather than raised.

    Args:
        geometry_3d: 3D geometry/TSDF (D, H, W).
        semantic_3d: 3D semantic segmentation (D, H, W).
        output_path: Output PLY file path.
        iso_value: Iso-surface value for mesh extraction.
        voxel_size: Voxel size in meters.

    Returns:
        True if successful, False otherwise.
    """
    if not HAS_SKIMAGE:
        print("Warning: scikit-image not installed. Cannot save PLY mesh.")
        return False
    if not HAS_SCIPY:
        print("Warning: scipy not installed. Cannot color mesh by semantics.")

    try:
        # Extract mesh; vertices come back scaled by the voxel size.
        vertices, faces = get_mesh(
            geometry_3d,
            iso_value=iso_value,
            spacing=(voxel_size, voxel_size, voxel_size),
        )

        colors = None
        if HAS_SCIPY:
            labeled = semantic_3d.nonzero()
            if labeled[0].size > 0:
                # Coordinates and labels of the non-zero voxels, in matching
                # order, so a KD-tree hit indexes straight into the labels
                # without building a dense per-voxel color volume.
                labeled_coords = np.stack(labeled, axis=-1)
                labeled_values = semantic_3d[labeled]
                palette = create_color_palette()

                # Scale vertices back to (truncated) voxel indices and find
                # the nearest labeled voxel for each one.
                vertex_voxels = (vertices / voxel_size).astype(int)
                _, nearest = KDTree(labeled_coords).query(vertex_voxels)

                # Clip labels into the palette range before the lookup.
                nearest_labels = np.clip(
                    labeled_values[nearest], 0, len(palette) - 1
                ).astype(np.intp)
                colors = palette[nearest_labels]

        # Write PLY
        write_ply(vertices, output_path, colors, faces)
        print(f"✓ Saved 3D mesh to: {output_path}")
        print(f" Vertices: {len(vertices)}, Faces: {len(faces)}")
        return True
    except Exception as e:
        # Deliberately best-effort: mesh export must not abort the caller.
        print(f"Warning: Failed to save 3D mesh: {e}")
        return False
def save_outputs(
    outputs,
    output_dir: Union[str, Path],
    original_image: Optional[np.ndarray] = None,
    save_mesh: bool = True,
    save_depth: bool = True,
    save_segmentation: bool = True,
    save_numpy: bool = True,
) -> dict:
    """Write the requested model outputs into a directory.

    The directory is created if missing. Each enabled artifact is written in
    turn: raw ``.npy`` arrays, the 2D segmentation figure, the depth figure,
    and the semantic and panoptic PLY meshes.

    Args:
        outputs: PanopticRecon3DOutput from model; its ``to_numpy()`` result
            is used as a name-to-array mapping.
        output_dir: Output directory.
        original_image: Optional original input image. The segmentation
            figure is only produced when this is provided.
        save_mesh: Whether to save 3D mesh PLY files.
        save_depth: Whether to save depth visualization.
        save_segmentation: Whether to save segmentation visualization.
        save_numpy: Whether to save raw numpy arrays.

    Returns:
        Dictionary mapping an artifact key to the path (as str) of each file
        that was saved. Meshes that failed to save are not listed.
    """
    out_root = Path(output_dir)
    out_root.mkdir(parents=True, exist_ok=True)

    # Convert outputs to numpy once; every artifact below reads from this.
    arrays = outputs.to_numpy()
    written = {}

    # Raw arrays, one .npy file per output name.
    if save_numpy:
        for key, value in arrays.items():
            target = out_root / f"{key}.npy"
            np.save(target, value)
            written[f"{key}_npy"] = str(target)

    # 2D segmentation figure needs the original image to overlay onto.
    if save_segmentation and original_image is not None:
        seg_file = out_root / "panoptic_2d_visualization.png"
        visualize_2d_segmentation(original_image, arrays["panoptic_seg_2d"], seg_file)
        written["segmentation_vis"] = str(seg_file)

    # Depth figure.
    if save_depth:
        depth_file = out_root / "depth_visualization.png"
        visualize_depth_map(arrays["depth_2d"], depth_file)
        written["depth_vis"] = str(depth_file)

    # 3D meshes: same geometry, colored by semantic then by panoptic labels.
    if save_mesh:
        mesh_specs = (
            ("semantic_mesh", "semantic_seg_3d", "mesh_semantic.ply"),
            ("panoptic_mesh", "panoptic_seg_3d", "mesh_panoptic.ply"),
        )
        for result_key, labels_key, filename in mesh_specs:
            mesh_file = out_root / filename
            saved = save_3d_mesh(arrays["geometry_3d"], arrays[labels_key], mesh_file)
            # Only record meshes that were actually written.
            if saved:
                written[result_key] = str(mesh_file)

    return written