|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Visualization utilities for Panoptic Recon 3D model outputs. |
|
|
|
|
|
This module provides functions for: |
|
|
- 2D segmentation visualization |
|
|
- Depth map visualization |
|
|
- 3D mesh extraction and PLY export |
|
|
""" |
|
|
|
|
|
from pathlib import Path |
|
|
from typing import Optional, Tuple, Union |
|
|
|
|
|
import numpy as np |
|
|
|
|
|
|
|
|
try: |
|
|
import matplotlib.pyplot as plt |
|
|
import matplotlib.patches as mpatches |
|
|
HAS_MATPLOTLIB = True |
|
|
except ImportError: |
|
|
HAS_MATPLOTLIB = False |
|
|
|
|
|
try: |
|
|
from PIL import Image |
|
|
HAS_PIL = True |
|
|
except ImportError: |
|
|
HAS_PIL = False |
|
|
|
|
|
try: |
|
|
from skimage import measure |
|
|
HAS_SKIMAGE = True |
|
|
except ImportError: |
|
|
HAS_SKIMAGE = False |
|
|
|
|
|
try: |
|
|
from scipy.spatial import KDTree |
|
|
HAS_SCIPY = True |
|
|
except ImportError: |
|
|
HAS_SCIPY = False |
|
|
|
|
|
|
|
|
def create_color_palette() -> np.ndarray:
    """Return the Front3D semantic-class color table.

    Row ``i`` holds the RGB color used for class index ``i``; row 0 is black.

    Returns:
        A ``(42, 3)`` array of ``uint8`` RGB triples. A new array is built on
        every call, so callers may modify the result freely.
    """
    rgb_rows = [
        (0, 0, 0), (174, 199, 232), (152, 223, 138), (31, 119, 180),
        (255, 187, 120), (188, 189, 34), (140, 86, 75), (255, 152, 150),
        (214, 39, 40), (197, 176, 213), (148, 103, 189), (196, 156, 148),
        (23, 190, 207), (178, 76, 76), (247, 182, 210), (66, 188, 102),
        (219, 219, 141), (140, 57, 197), (202, 185, 52), (51, 176, 203),
        (200, 54, 131), (92, 193, 61), (78, 71, 183), (172, 114, 82),
        (255, 127, 14), (91, 163, 138), (153, 98, 156), (140, 153, 101),
        (158, 218, 229), (100, 125, 154), (178, 127, 135), (120, 185, 128),
        (146, 111, 194), (44, 160, 44), (112, 128, 144), (96, 207, 209),
        (227, 119, 194), (213, 92, 176), (94, 106, 211), (82, 84, 163),
        (100, 85, 144), (172, 172, 172),
    ]
    return np.asarray(rgb_rows, dtype=np.uint8)
|
|
|
|
|
|
|
|
def colorize_segmentation(
    segmentation: np.ndarray,
    palette: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Colorize a label map by looking each label up in a palette.

    Labels outside ``[0, len(palette) - 1]`` are clamped to that range, so
    negative labels get ``palette[0]`` and labels past the end of the palette
    get its last color.

    Args:
        segmentation: Label map (H, W) of class indices. Integer arrays are
            used as-is; float/bool arrays are rounded to the nearest integer
            (NaN is treated as 0) before the lookup.
        palette: Color palette (N, 3), array or nested sequence. Uses
            ``create_color_palette()`` if None.

    Returns:
        Colorized image of shape ``segmentation.shape + (3,)`` with the
        palette's dtype (uint8 for the default palette).
    """
    if palette is None:
        palette = create_color_palette()
    palette = np.asarray(palette)

    labels = np.asarray(segmentation)
    if not np.issubdtype(labels.dtype, np.integer):
        # NumPy only accepts integer arrays as fancy indices (a bool array would
        # be treated as a mask), so float/bool label maps must be converted.
        labels = np.rint(np.nan_to_num(labels.astype(np.float64), nan=0.0))
    labels = np.clip(labels, 0, len(palette) - 1).astype(np.intp)
    return palette[labels]
|
|
|
|
|
|
|
|
def visualize_2d_segmentation(
    image: np.ndarray,
    panoptic_2d: np.ndarray,
    output_path: Optional[Union[str, Path]] = None,
    alpha: float = 0.6,
    figsize: Tuple[int, int] = (18, 6),
    dpi: int = 150,
) -> Optional[np.ndarray]:
    """Visualize 2D panoptic segmentation overlaid on an image.

    The segmentation is colorized with the default palette and alpha-blended
    over the image. If the image's height/width differ from the segmentation's,
    the image is resized (Lanczos) to the segmentation's resolution first.

    Args:
        image: Original image, either (H, W) grayscale or (H, W, C) with
            C in {1, 3, 4}. Single-channel input is replicated to RGB and an
            alpha channel is dropped. Values are treated as 0-255 intensities.
        panoptic_2d: Panoptic segmentation map (H, W).
        output_path: Path to save a three-panel figure (image, segmentation,
            overlay). If None, no figure is made and the overlay is returned.
        alpha: Weight of the segmentation colors in the blend.
        figsize: Figure size.
        dpi: DPI for saved figure.

    Returns:
        Overlay image (H, W, 3) as uint8 if ``output_path`` is None, else None.

    Raises:
        ImportError: If matplotlib or PIL is not installed.
    """
    if not HAS_MATPLOTLIB:
        raise ImportError("matplotlib required for visualization")
    if not HAS_PIL:
        raise ImportError("PIL required for visualization")

    palette = create_color_palette()
    colored_seg = colorize_segmentation(panoptic_2d, palette)

    # Normalize to exactly 3 channels so the blend broadcasts against the
    # (H, W, 3) segmentation colors.
    image = np.asarray(image)
    if image.ndim == 2:
        image = np.repeat(image[:, :, np.newaxis], 3, axis=2)
    elif image.shape[2] == 1:
        image = np.repeat(image, 3, axis=2)
    elif image.shape[2] == 4:
        image = image[:, :, :3]

    if image.shape[:2] != panoptic_2d.shape:
        # PIL can only build an RGB image from a uint8 array.
        image_pil = Image.fromarray(np.clip(image, 0, 255).astype(np.uint8))
        image_pil = image_pil.resize((panoptic_2d.shape[1], panoptic_2d.shape[0]), Image.LANCZOS)
        image = np.array(image_pil)

    overlay = image.astype(np.float32) * (1 - alpha) + colored_seg.astype(np.float32) * alpha
    overlay = overlay.clip(0, 255).astype(np.uint8)

    if output_path is None:
        return overlay

    fig, axes = plt.subplots(1, 3, figsize=figsize)
    try:
        panels = (
            (image, 'Original Image'),
            (colored_seg, 'Panoptic Segmentation'),
            (overlay, 'Overlay'),
        )
        for ax, (panel, title) in zip(axes, panels):
            ax.imshow(panel)
            ax.set_title(title, fontsize=14, fontweight='bold')
            ax.axis('off')
        fig.tight_layout()
        fig.savefig(output_path, dpi=dpi, bbox_inches='tight')
    finally:
        # Close this specific figure, even if saving failed, so repeated calls
        # do not accumulate open figures.
        plt.close(fig)

    print(f"✓ Saved 2D segmentation visualization to: {output_path}")
    return None
|
|
|
|
|
|
|
|
def visualize_depth_map(
    depth_2d: np.ndarray,
    output_path: Optional[Union[str, Path]] = None,
    vmin: float = 0.0,
    vmax: float = 6.0,
    cmap: str = 'viridis',
    figsize: Tuple[int, int] = (10, 8),
    dpi: int = 150,
) -> Optional[np.ndarray]:
    """Visualize a depth map with a matplotlib colormap.

    Depths are linearly mapped so that ``vmin`` and ``vmax`` hit the two ends
    of the colormap; values outside that range are clamped.

    Args:
        depth_2d: Depth map (H, W).
        output_path: Path to save a figure with a colorbar. If None, the
            colorized depth array is returned instead.
        vmin: Depth mapped to the low end of the colormap.
        vmax: Depth mapped to the high end of the colormap.
        cmap: Matplotlib colormap name.
        figsize: Figure size.
        dpi: DPI for saved figure.

    Returns:
        Colorized depth (H, W, 3) as uint8 if ``output_path`` is None,
        else None.

    Raises:
        ImportError: If matplotlib is not installed.
        ValueError: If ``vmax == vmin`` (the normalization would divide by 0).
    """
    if not HAS_MATPLOTLIB:
        raise ImportError("matplotlib required for visualization")
    if vmax == vmin:
        raise ValueError(f"vmax ({vmax}) must differ from vmin ({vmin})")

    depth_2d = np.asarray(depth_2d)

    if output_path is None:
        depth_norm = np.clip((depth_2d - vmin) / (vmax - vmin), 0, 1)
        cm = plt.get_cmap(cmap)
        # Colormap returns RGBA floats in [0, 1]; keep RGB and scale to bytes.
        return (cm(depth_norm)[..., :3] * 255).astype(np.uint8)

    fig, ax = plt.subplots(1, 1, figsize=figsize)
    try:
        im = ax.imshow(depth_2d, cmap=cmap, vmin=vmin, vmax=vmax)
        ax.set_title('Depth Map', fontsize=14, fontweight='bold')
        ax.axis('off')

        cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        cbar.set_label('Depth (m)', rotation=270, labelpad=20, fontsize=12)

        fig.tight_layout()
        fig.savefig(output_path, dpi=dpi, bbox_inches='tight')
    finally:
        # Close this specific figure, even if saving failed.
        plt.close(fig)

    print(f"✓ Saved depth map visualization to: {output_path}")
    return None
|
|
|
|
|
|
|
|
def get_mesh(
    distance_field: np.ndarray,
    iso_value: float = 1.0,
    spacing: Tuple[float, float, float] = (1.0, 1.0, 1.0),
) -> Tuple[np.ndarray, np.ndarray]:
    """Triangulate the ``iso_value`` level set of a volume with marching cubes.

    Thin wrapper around ``skimage.measure.marching_cubes`` that keeps only the
    vertices and faces and discards the other two values it returns.

    Args:
        distance_field: 3D distance field (D, H, W).
        iso_value: Level at which the surface is extracted.
        spacing: Voxel spacing along each axis; vertex positions are scaled
            by it.

    Returns:
        A ``(vertices, faces)`` pair: vertex positions (N, 3) and triangle
        vertex indices (M, 3).

    Raises:
        ImportError: If scikit-image is not installed.
    """
    if HAS_SKIMAGE:
        verts, tris, _normals, _values = measure.marching_cubes(
            distance_field,
            level=iso_value,
            spacing=spacing,
        )
        return verts, tris
    raise ImportError("scikit-image required for mesh extraction")
|
|
|
|
|
|
|
|
def write_ply(
    vertices: np.ndarray,
    output_file: Union[str, Path],
    colors: Optional[np.ndarray] = None,
    faces: Optional[np.ndarray] = None,
) -> None:
    """Write vertices (and optionally colors and faces) as an ASCII PLY file.

    Args:
        vertices: Vertex positions (N, 3).
        output_file: Output PLY file path (overwritten if it exists).
        colors: Optional per-vertex colors (N, 3); each component is written
            as an integer ``uchar``.
        faces: Optional face indices (M, K). Each face is written with its
            own vertex count, so triangles (K=3) and general polygons both work.
            An empty array is treated like None.

    Raises:
        ValueError: If ``colors`` is given and its length differs from
            ``vertices`` (the header's vertex count would not match the body).
    """
    if colors is not None and len(colors) != len(vertices):
        raise ValueError(
            f"colors has {len(colors)} entries but vertices has {len(vertices)}"
        )
    has_faces = faces is not None and len(faces) > 0

    header = [
        "ply",
        "format ascii 1.0",
        f"element vertex {len(vertices)}",
        "property float x",
        "property float y",
        "property float z",
    ]
    if colors is not None:
        header += ["property uchar red", "property uchar green", "property uchar blue"]
    if has_faces:
        header += [f"element face {len(faces)}", "property list uchar uint vertex_indices"]
    header.append("end_header")

    # newline="\n" keeps LF line endings on every platform (no CRLF translation).
    with open(output_file, "w", newline="\n") as f:
        f.write("\n".join(header) + "\n")

        if colors is None:
            f.writelines(f"{v[0]} {v[1]} {v[2]}\n" for v in vertices)
        else:
            f.writelines(
                f"{v[0]} {v[1]} {v[2]} {int(c[0])} {int(c[1])} {int(c[2])}\n"
                for v, c in zip(vertices, colors)
            )

        if has_faces:
            f.writelines(
                f"{len(face)} " + " ".join(str(int(i)) for i in face) + "\n"
                for face in faces
            )
|
|
|
|
|
|
|
|
def save_3d_mesh(
    geometry_3d: np.ndarray,
    semantic_3d: np.ndarray,
    output_path: Union[str, Path],
    iso_value: float = 1.0,
    voxel_size: float = 0.03,
) -> bool:
    """Extract a mesh from a volume and save it as PLY, colored by labels.

    The surface at ``iso_value`` is extracted with marching cubes (scaled by
    ``voxel_size``). When scipy is available and ``semantic_3d`` has any
    nonzero voxel, each vertex takes the palette color of the nearest
    nonzero-labelled voxel. Otherwise the mesh is written without colors.
    Failures are reported as warnings rather than raised.

    Args:
        geometry_3d: 3D geometry/TSDF (D, H, W).
        semantic_3d: 3D label volume (D, H, W); 0 is treated as unlabelled.
        output_path: Output PLY file path.
        iso_value: Iso-surface value for mesh extraction.
        voxel_size: Voxel size in meters.

    Returns:
        True if the PLY file was written, False otherwise.
    """
    if not HAS_SKIMAGE:
        print("Warning: scikit-image not installed. Cannot save PLY mesh.")
        return False
    if not HAS_SCIPY:
        print("Warning: scipy not installed. Cannot color mesh by semantics.")

    # Deliberately best-effort: any failure becomes a warning and a False return.
    try:
        vertices, faces = get_mesh(
            geometry_3d,
            iso_value=iso_value,
            spacing=(voxel_size, voxel_size, voxel_size),
        )

        colors = None
        if HAS_SCIPY and np.any(semantic_3d):
            # (K, 3) voxel indices of every labelled voxel; K > 0 given np.any.
            labelled_coords = np.argwhere(semantic_3d)
            labels_kd = KDTree(labelled_coords)

            # marching_cubes scaled vertices by voxel_size; undo it to get
            # continuous voxel coordinates. Querying with these exact positions
            # (rather than truncating to int) avoids biasing lookups toward
            # the origin.
            _, nearest = labels_kd.query(vertices / voxel_size)
            nearest_coords = labelled_coords[nearest]
            labels = semantic_3d[nearest_coords[:, 0], nearest_coords[:, 1], nearest_coords[:, 2]]

            palette = create_color_palette()
            # Look up only the needed labels instead of colorizing the whole volume.
            colors = palette[np.clip(labels, 0, len(palette) - 1).astype(np.intp)]

        write_ply(vertices, output_path, colors, faces)
        print(f"✓ Saved 3D mesh to: {output_path}")
        print(f" Vertices: {len(vertices)}, Faces: {len(faces)}")
        return True

    except Exception as e:
        print(f"Warning: Failed to save 3D mesh: {e}")
        return False
|
|
|
|
|
|
|
|
def save_outputs(
    outputs,
    output_dir: Union[str, Path],
    original_image: Optional[np.ndarray] = None,
    save_mesh: bool = True,
    save_depth: bool = True,
    save_segmentation: bool = True,
    save_numpy: bool = True,
) -> dict:
    """Save all model outputs to a directory.

    Every optional artifact is best-effort. A visualization whose plotting
    dependency is missing, or a mesh that fails to export, is skipped with a
    warning instead of aborting the remaining saves.

    Args:
        outputs: Model output object; only its ``to_numpy()`` method is used,
            which must return a mapping of name -> array. The keys
            "panoptic_seg_2d", "depth_2d", "geometry_3d", "semantic_seg_3d"
            and "panoptic_seg_3d" are read when the matching flag is enabled.
        output_dir: Output directory (created, with parents, if missing).
        original_image: Optional original input image; the 2D segmentation
            figure is only produced when this is given.
        save_mesh: Whether to save semantic and panoptic mesh PLY files.
        save_depth: Whether to save the depth visualization.
        save_segmentation: Whether to save the segmentation visualization.
        save_numpy: Whether to save every output array as ``<name>.npy``.

    Returns:
        Dictionary mapping an artifact key to the path (str) of each file that
        was actually written.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    saved_files = {}
    outputs_np = outputs.to_numpy()

    if save_numpy:
        for name, arr in outputs_np.items():
            npy_path = output_dir / f"{name}.npy"
            np.save(npy_path, arr)
            saved_files[f"{name}_npy"] = str(npy_path)

    if save_segmentation and original_image is not None:
        seg_path = output_dir / "panoptic_2d_visualization.png"
        try:
            visualize_2d_segmentation(original_image, outputs_np["panoptic_seg_2d"], seg_path)
        except ImportError as e:
            # Missing matplotlib/PIL: skip like the mesh export does, rather
            # than losing the remaining outputs.
            print(f"Warning: Skipping 2D segmentation visualization: {e}")
        else:
            saved_files["segmentation_vis"] = str(seg_path)

    if save_depth:
        depth_path = output_dir / "depth_visualization.png"
        try:
            visualize_depth_map(outputs_np["depth_2d"], depth_path)
        except ImportError as e:
            print(f"Warning: Skipping depth visualization: {e}")
        else:
            saved_files["depth_vis"] = str(depth_path)

    if save_mesh:
        # Same geometry, colored once by semantic labels and once by panoptic ids.
        mesh_jobs = (
            ("semantic_seg_3d", "mesh_semantic.ply", "semantic_mesh"),
            ("panoptic_seg_3d", "mesh_panoptic.ply", "panoptic_mesh"),
        )
        for labels_key, filename, result_key in mesh_jobs:
            mesh_path = output_dir / filename
            if save_3d_mesh(outputs_np["geometry_3d"], outputs_np[labels_key], mesh_path):
                saved_files[result_key] = str(mesh_path)

    return saved_files
|
|
|
|
|
|