# WaveGen / EMS-superquadric_fitting_inference / process_viser_hierarchical.py
# Author: FangSen9000
# Commit: "Upload EMS-superquadric_fitting_inference" (7f585cf, verified)
#!/usr/bin/env python3
"""
Hierarchical multi-superquadric fitting with viser visualization
Based on the hierarchical_ems algorithm from multiquadric_test.py
"""
import numpy as np
import sys
import os
import time
import viser
from sklearn.cluster import DBSCAN
# Add the src directory to Python path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
from EMS.EMS_recovery import EMS_recovery
def hierarchical_ems(
    point,
    OutlierRatio=0.5,  # Reduced for better initial fit
    MaxIterationEM=20,
    ToleranceEM=1e-3,
    RelativeToleranceEM=2e-1,
    MaxOptiIterations=2,
    Sigma=0.3,
    MaxiSwitch=2,
    AdaptiveUpperBound=True,
    Rescale=False,
    MaxLayer=3,  # Reduced for faster processing
    Eps=0.1,  # Adjusted for normalized point clouds
    MinPoints=50,  # Minimum points to form a cluster
    InlierThreshold=0.5,
    OutlierThreshold=0.1,
    MinInlierRatio=0.3,
):
    """
    Hierarchical EMS for extracting multiple superquadrics from a point cloud.

    Layer 0 holds the whole cloud. At each layer, every segment with at least
    ``2 * MinPoints`` points is fitted with ``EMS_recovery``. The fit is kept
    when more than ``MinInlierRatio`` of the per-point values ``p_raw``
    returned by ``EMS_recovery`` exceed ``InlierThreshold``. Whether or not
    the fit is kept, points with ``p_raw < OutlierThreshold`` are clustered
    with DBSCAN. Each cluster with more than ``MinPoints`` points becomes a
    segment of the next layer. At most ``MaxLayer`` layers are processed.

    Args:
        point: Point cloud. It is converted with ``np.asarray`` and indexed
            row-wise with boolean masks (presumably shape (N, 3); confirm with
            callers).
        OutlierRatio, MaxIterationEM, ToleranceEM, RelativeToleranceEM,
        MaxOptiIterations, Sigma, MaxiSwitch, AdaptiveUpperBound, Rescale:
            Forwarded positionally to ``EMS_recovery``.
        MaxLayer: Maximum number of hierarchy layers to process.
        Eps: DBSCAN neighbourhood radius used to split outliers.
        MinPoints: DBSCAN ``min_samples``. Also the minimum cluster size, and
            half the minimum segment size, that is worth fitting.
        InlierThreshold: ``p_raw`` value above which a point counts as an
            inlier (default 0.5).
        OutlierThreshold: ``p_raw`` value below which a point is passed on to
            clustering (default 0.1).
        MinInlierRatio: Minimum inlier fraction for a fit to be accepted
            (default 0.3).

    Returns:
        ``(list_quadrics, quadric_info)``. The first element holds the accepted
        superquadrics as returned by ``EMS_recovery``. The second is an
        index-aligned list of dicts with keys 'layer', 'segment',
        'inlier_ratio', 'num_points' and 'inlier_points'.

    A segment whose processing raises is reported on stdout and skipped, so
    one bad segment does not abort the whole hierarchy.
    """
    # point_seg[h] lists the point subsets still to be fitted at layer h.
    point_seg = [[] for _ in range(MaxLayer + 1)]
    point_seg[0].append(np.asarray(point))
    list_quadrics = []
    quadric_info = []  # Additional info about each accepted quadric (index-aligned)

    for h in range(MaxLayer):
        if not point_seg[h]:
            break  # Nothing was split off at the previous layer.
        for c, current_points in enumerate(point_seg[h]):
            if len(current_points) < MinPoints * 2:
                continue
            print(f" Layer {h}, Segment {c}: Processing {len(current_points)} points")
            try:
                # Fit superquadric
                x_raw, p_raw = EMS_recovery(
                    current_points,
                    OutlierRatio,
                    MaxIterationEM,
                    ToleranceEM,
                    RelativeToleranceEM,
                    MaxOptiIterations,
                    Sigma,
                    MaxiSwitch,
                    AdaptiveUpperBound,
                    Rescale,
                )
                p_raw = np.asarray(p_raw)

                # Calculate fitting quality
                inlier_mask = p_raw > InlierThreshold
                inlier_ratio = np.sum(inlier_mask) / len(p_raw)
                if inlier_ratio > MinInlierRatio:
                    # Build the record before appending anything, so a failure
                    # here cannot leave the two result lists misaligned.
                    info = {
                        'layer': h,
                        'segment': c,
                        'inlier_ratio': inlier_ratio,
                        'num_points': len(current_points),
                        'inlier_points': current_points[inlier_mask],
                    }
                    list_quadrics.append(x_raw)
                    quadric_info.append(info)
                    print(f" → Fitted superquadric with {inlier_ratio:.1%} inliers")

                # Separate outliers for next layer
                outlier = current_points[p_raw < OutlierThreshold]
                # The last layer's children would never be processed, so skip
                # clustering there.
                if len(outlier) > MinPoints * 2 and h < MaxLayer - 1:
                    labels = DBSCAN(eps=Eps, min_samples=MinPoints).fit(outlier).labels_
                    # DBSCAN labels noise points -1; only real clusters are kept.
                    cluster_ids = sorted(set(labels.tolist()) - {-1})
                    if cluster_ids:
                        print(f" → Found {len(cluster_ids)} clusters in outliers")
                    for cluster_id in cluster_ids:
                        cluster_points = outlier[labels == cluster_id]
                        if len(cluster_points) > MinPoints:
                            point_seg[h + 1].append(cluster_points)
            except Exception as e:
                print(f" → Error: {e}")

    return list_quadrics, quadric_info
def generate_superquadric_mesh(sq, num_samples=25):
    """Generate mesh vertices and triangle faces for a superquadric surface.

    The surface is sampled on a ``num_samples x num_samples`` grid. ``eta``
    spans [-pi/2, pi/2] and ``omega`` spans [-pi, pi], and the standard
    signed-power parametrization is used:

        x = a1 * cos(eta)^e1 * cos(omega)^e2
        y = a2 * cos(eta)^e1 * sin(omega)^e2
        z = a3 * sin(eta)^e1

    Here ``(e1, e2) = sq.shape`` and ``(a1, a2, a3) = sq.scale``. Local points
    are then mapped to world space as ``sq.RotM @ p + sq.translation``.

    Args:
        sq: Object exposing ``shape`` (2 exponents), ``scale`` (3 semi-axes),
            ``RotM`` (3x3 matrix applied as ``RotM @ p``) and ``translation``
            (3-vector).
        num_samples: Grid resolution along each parameter.

    Returns:
        ``(vertices, faces)``. ``vertices`` has shape
        ``(num_samples**2, 3)``; vertex ``i * num_samples + j`` corresponds to
        ``(eta[i], omega[j])``. ``faces`` has shape
        ``(2 * (num_samples - 1)**2, 3)`` with two triangles per grid quad.
        The omega seam is closed by duplicated vertices (omega = -pi and pi),
        not by wrapping indices.
    """
    def signed_pow(base, exponent):
        # sign(b) * |b|^e keeps the octant of the sample for fractional exponents.
        return np.sign(base) * np.abs(base) ** exponent

    eta, omega = np.meshgrid(
        np.linspace(-np.pi / 2, np.pi / 2, num_samples),
        np.linspace(-np.pi, np.pi, num_samples),
        indexing="ij",  # row i <-> eta[i], column j <-> omega[j]
    )
    eta = eta.ravel()
    omega = omega.ravel()

    cos_eta = signed_pow(np.cos(eta), sq.shape[0])
    sin_eta = signed_pow(np.sin(eta), sq.shape[0])
    cos_omega = signed_pow(np.cos(omega), sq.shape[1])
    sin_omega = signed_pow(np.sin(omega), sq.shape[1])

    local = np.stack(
        [
            sq.scale[0] * cos_eta * cos_omega,
            sq.scale[1] * cos_eta * sin_omega,
            sq.scale[2] * sin_eta,
        ],
        axis=1,
    )
    # Row-vector form of RotM @ p + translation, applied to every vertex at once.
    vertices = local @ np.asarray(sq.RotM).T + np.asarray(sq.translation)

    # Top-left corner index of every grid quad, in row-major order.
    rows, cols = np.meshgrid(
        np.arange(num_samples - 1), np.arange(num_samples - 1), indexing="ij"
    )
    idx1 = (rows * num_samples + cols).ravel()
    idx2 = idx1 + 1
    idx3 = idx1 + num_samples
    idx4 = idx3 + 1
    first = np.stack([idx1, idx2, idx3], axis=1)
    second = np.stack([idx2, idx4, idx3], axis=1)
    # Interleave so each quad contributes its two triangles consecutively.
    faces = np.stack([first, second], axis=1).reshape(-1, 3)

    return vertices, faces
# Default data locations from the original experiments; override via main() arguments.
DEFAULT_EXAMPLE_DATA_DIR = "/research/cbim/vast/sf895/code/EMS-superquadric_fitting/MATLAB/example_scripts/data"
DEFAULT_NORMALIZED_DIR = "/research/cbim/vast/sf895/code/EMS-superquadric_fitting/20250811_231035_step10_stage0_waves1"

# Single superquadric examples (relative to the example data directory)
SINGLE_PLY_FILES = [
    "single_superquadric/noisy_pointCloud_example_1.ply",
    "single_superquadric/noisy_pointCloud_example_2.ply",
    "single_superquadric/partial_pointCloud_example_1.ply",
]

# Multi superquadric examples (relative to the example data directory)
MULTI_PLY_FILES = [
    "multi_superquadrics/cat.ply",
    "multi_superquadrics/dog.ply",
    "multi_superquadrics/turtle.ply",
]

# Colors for different superquadrics (RGB, cycled by quadric index)
QUADRIC_COLORS = [
    (255, 0, 0),  # Red
    (0, 255, 0),  # Green
    (0, 0, 255),  # Blue
    (255, 255, 0),  # Yellow
    (255, 0, 255),  # Magenta
    (0, 255, 255),  # Cyan
]


def _load_samples(example_data_dir, normalized_dir, num_normalized_samples):
    """Load the example point clouds that exist on disk and fit superquadrics to each.

    Returns:
        A list of sample dicts with keys 'name', 'idx' (position in the list),
        'points', 'quadrics' and 'quadric_info'. Missing files are skipped.
        Files whose loading or fitting raises are reported on stdout and
        skipped.
    """
    # Import utilities for reading PLY files
    from EMS.utilities import read_ply

    samples = []

    def add_sample(name, points, quadrics, quadric_info):
        samples.append({
            'name': name,
            'idx': len(samples),
            'points': points,
            'quadrics': quadrics,
            'quadric_info': quadric_info,
        })

    # 1a. Single superquadric PLY examples: one EMS fit per cloud.
    for ply_file in SINGLE_PLY_FILES:
        file_path = os.path.join(example_data_dir, ply_file)
        if not os.path.exists(file_path):
            continue
        print(f"\nProcessing {ply_file}...")
        try:
            point_cloud = read_ply(file_path)
            sq, p = EMS_recovery(point_cloud, OutlierRatio=0.2, AdaptiveUpperBound=True)
            inlier_mask = p > 0.5
            add_sample(os.path.basename(ply_file), point_cloud, [sq], [{
                'layer': 0,
                'segment': 0,
                'inlier_ratio': np.sum(inlier_mask) / len(p),
                'num_points': len(point_cloud),
                'inlier_points': point_cloud[inlier_mask],
            }])
            print(f" Success! Shape: {sq.shape}, Scale: {sq.scale}")
        except Exception as e:
            print(f" Failed: {e}")

    # 1b. Multi superquadric PLY examples: hierarchical fitting.
    for ply_file in MULTI_PLY_FILES:
        file_path = os.path.join(example_data_dir, ply_file)
        if not os.path.exists(file_path):
            continue
        print(f"\nProcessing {ply_file} (multi-quadric)...")
        try:
            point_cloud = read_ply(file_path)
            # Parameters adjusted for these specific (non-normalized) examples
            quadrics, quadric_info = hierarchical_ems(
                point_cloud,
                OutlierRatio=0.9,  # Higher for multi-object scenes
                Eps=1.7,  # Larger for non-normalized data
                MinPoints=60,  # Standard minimum
                Rescale=True,  # Enable rescaling for raw PLY data
            )
            add_sample(os.path.basename(ply_file), point_cloud, quadrics, quadric_info)
            print(f"Summary: Found {len(quadrics)} superquadrics")
            for j, (sq, info) in enumerate(zip(quadrics, quadric_info)):
                print(f" SQ{j+1}: Shape={sq.shape}, Scale={sq.scale}, "
                      f"Inliers={info['inlier_ratio']:.1%}")
        except Exception as e:
            print(f" Failed: {e}")

    # 2. Normalized point cloud samples (first frame of each .npz), if present.
    if os.path.exists(normalized_dir):
        print("\n--- Processing normalized point cloud samples ---")
        for i in range(num_normalized_samples):
            sample_name = f"sample_{i}_normalized_points.npz"
            sample_path = os.path.join(normalized_dir, sample_name)
            if not os.path.exists(sample_path):
                continue
            print(f"\nProcessing {sample_name}...")
            try:
                # np.load on an .npz keeps the archive open until closed.
                with np.load(sample_path) as data:
                    point_cloud = data['points'][0]  # First frame
                quadrics, quadric_info = hierarchical_ems(point_cloud)
                add_sample(sample_name, point_cloud, quadrics, quadric_info)
                print(f"Summary: Found {len(quadrics)} superquadrics")
            except Exception as e:
                print(f" Failed: {e}")

    return samples


def _format_sample_info(sample):
    """Build the markdown text shown in the info panel for one sample."""
    parts = [
        f"**{sample['name']}**\n\n",
        # Blank lines keep these on separate lines in rendered markdown.
        f"Total points: {len(sample['points'])}\n\n",
        f"Superquadrics found: {len(sample['quadrics'])}\n\n",
    ]
    if sample['quadrics']:
        parts.append("**Superquadric Details:**\n")
        for i, (sq, info) in enumerate(zip(sample['quadrics'], sample['quadric_info'])):
            parts.append(f"\n**SQ{i+1}** (Layer {info['layer']}):\n")
            parts.append(f"- Shape: ε₁={sq.shape[0]:.3f}, ε₂={sq.shape[1]:.3f}\n")
            parts.append(f"- Scale: ({sq.scale[0]:.2f}, {sq.scale[1]:.2f}, {sq.scale[2]:.2f})\n")
            parts.append(f"- Inliers: {info['inlier_ratio']:.1%} ({info['num_points']} points)\n")
    return "".join(parts)


def _serve_viewer(all_samples, port=8080):
    """Serve an interactive viser viewer for the fitted samples until Ctrl+C.

    If no sample has point data, this prints a message and returns without
    starting a server.
    """
    sample_names = [s['name'] for s in all_samples if s['points'] is not None]
    if not sample_names:
        print("No samples with point data to visualize; not starting the viser server.")
        return

    server = viser.ViserServer(port=port)
    print(f"\n{'='*60}")
    print(f"Viser server started at http://localhost:{port}")
    print("Open this URL in your browser to view the 3D visualization")
    print("Press Ctrl+C to stop the server")
    print('='*60)

    # Create GUI
    with server.gui.add_folder("Controls"):
        current_sample = server.gui.add_dropdown(
            "Select Sample",
            options=sample_names,
            initial_value=sample_names[0],
        )
        # Visibility controls
        show_points = server.gui.add_checkbox("Show Points", initial_value=True)
        show_all_quadrics = server.gui.add_checkbox("Show All Quadrics", initial_value=True)
        show_labels = server.gui.add_checkbox("Show Labels", initial_value=True)
        # Per-quadric checkboxes are recreated inside this folder for each sample
        quadric_toggles_folder = server.gui.add_folder("Individual Quadrics")
        # Visual parameters
        point_size = server.gui.add_slider(
            "Point Size", min=0.001, max=0.02, step=0.001, initial_value=0.003
        )
        mesh_opacity = server.gui.add_slider(
            "Mesh Opacity", min=0.0, max=1.0, step=0.1, initial_value=0.5
        )
        info_display = server.gui.add_markdown("**Sample Info:**\n\nSelect a sample to view")

    # Handles for the current sample. Meshes and labels are keyed by quadric
    # index so they stay matched with their toggles even if one mesh fails.
    current_viz = {
        'points': None,
        'meshes': {},
        'labels': {},
        'quadric_toggles': [],
    }

    def quadric_visible(i):
        """Return whether quadric i passes both the global and its own toggle."""
        toggles = current_viz['quadric_toggles']
        return show_all_quadrics.value and (i >= len(toggles) or toggles[i].value)

    def refresh_visibility():
        """Apply the current checkbox states to every existing scene node."""
        if current_viz['points'] is not None:
            current_viz['points'].visible = show_points.value
        for i, mesh in current_viz['meshes'].items():
            mesh.visible = quadric_visible(i)
        for i, label in current_viz['labels'].items():
            label.visible = show_labels.value and quadric_visible(i)

    def clear_scene():
        """Remove the previous sample's scene nodes and per-quadric toggles."""
        if current_viz['points'] is not None:
            current_viz['points'].remove()
            current_viz['points'] = None
        for handle in [
            *current_viz['meshes'].values(),
            *current_viz['labels'].values(),
            *current_viz['quadric_toggles'],
        ]:
            handle.remove()
        current_viz['meshes'] = {}
        current_viz['labels'] = {}
        current_viz['quadric_toggles'] = []

    def update_scene():
        """Rebuild the scene and info panel for the sample chosen in the dropdown."""
        clear_scene()
        selected = next(
            (s for s in all_samples if s['name'] == current_sample.value), None
        )
        if selected is None or selected['points'] is None:
            info_display.value = "**No valid sample selected**"
            return
        info_display.value = _format_sample_info(selected)

        # Nodes are always created and hidden via `visible`, so re-checking a
        # box after switching samples can show them again.
        points = selected['points']
        current_viz['points'] = server.scene.add_point_cloud(
            "/current/points",
            points=points,
            colors=np.full((len(points), 3), 128, dtype=np.uint8),
            point_size=point_size.value,
        )

        with quadric_toggles_folder:
            for i in range(len(selected['quadrics'])):
                toggle = server.gui.add_checkbox(f"Quadric {i+1}", initial_value=True)
                toggle.on_update(lambda _event: refresh_visibility())
                current_viz['quadric_toggles'].append(toggle)

        for i, sq in enumerate(selected['quadrics']):
            color = QUADRIC_COLORS[i % len(QUADRIC_COLORS)]
            try:
                vertices, faces = generate_superquadric_mesh(sq, num_samples=20)
                current_viz['meshes'][i] = server.scene.add_mesh_simple(
                    f"/current/mesh_{i}",
                    vertices=vertices,
                    faces=faces,
                    color=color,
                    opacity=mesh_opacity.value,
                )
                current_viz['labels'][i] = server.scene.add_label(
                    f"/current/label_{i}",
                    text=f"SQ{i+1}: ε₁={sq.shape[0]:.2f}, ε₂={sq.shape[1]:.2f}",
                    position=sq.translation,
                )
            except Exception as e:
                print(f"Error visualizing quadric {i}: {e}")

        refresh_visibility()

    # Set up callbacks
    @current_sample.on_update
    def _(_event):
        update_scene()

    def _on_visibility_change(_event):
        refresh_visibility()

    for checkbox in (show_points, show_all_quadrics, show_labels):
        checkbox.on_update(_on_visibility_change)

    @point_size.on_update
    def _(event):
        if current_viz['points'] is not None:
            current_viz['points'].point_size = event.target.value

    @mesh_opacity.on_update
    def _(event):
        for mesh in current_viz['meshes'].values():
            mesh.opacity = event.target.value

    # Initial scene
    update_scene()

    # Keep server running
    try:
        while True:
            time.sleep(0.1)
    except KeyboardInterrupt:
        print("\nShutting down server...")
    finally:
        server.stop()


def main(
    example_data_dir=DEFAULT_EXAMPLE_DATA_DIR,
    normalized_dir=DEFAULT_NORMALIZED_DIR,
    num_normalized_samples=2,
    port=8080,
):
    """Fit superquadrics to the available example clouds and show them in viser.

    Args:
        example_data_dir: Directory containing the ``single_superquadric/`` and
            ``multi_superquadrics/`` PLY examples.
        normalized_dir: Directory with ``sample_<i>_normalized_points.npz``
            files; skipped if it does not exist.
        num_normalized_samples: Number of normalized samples to try
            (``sample_0``, ``sample_1``, ...).
        port: Port for the viser server.
    """
    print("Loading and processing samples with hierarchical multi-quadric fitting...")
    all_samples = _load_samples(example_data_dir, normalized_dir, num_normalized_samples)
    _serve_viewer(all_samples, port)


if __name__ == "__main__":
    main()