| |
| """ |
| Hierarchical multi-superquadric fitting with viser visualization |
| Based on the hierarchical_ems algorithm from multiquadric_test.py |
| """ |
|
|
| import numpy as np |
| import sys |
| import os |
| import time |
| import viser |
| from sklearn.cluster import DBSCAN |
|
|
| |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src')) |
|
|
| from EMS.EMS_recovery import EMS_recovery |
|
|
|
|
def hierarchical_ems(
    point,
    OutlierRatio=0.5,
    MaxIterationEM=20,
    ToleranceEM=1e-3,
    RelativeToleranceEM=2e-1,
    MaxOptiIterations=2,
    Sigma=0.3,
    MaxiSwitch=2,
    AdaptiveUpperBound=True,
    Rescale=False,
    MaxLayer=3,
    Eps=0.1,
    MinPoints=50,
    InlierThreshold=0.5,
    MinInlierRatio=0.3,
    OutlierThreshold=0.1,
):
    """
    Hierarchically extract multiple superquadrics from a point cloud.

    Layer 0 fits one superquadric to the whole cloud with ``EMS_recovery``.
    Points that the fit gives a low probability are split into spatial
    clusters with DBSCAN, and each sufficiently large cluster is fitted again
    at the next layer. Processing stops after ``MaxLayer`` layers or as soon
    as a layer has no queued segments.

    Args:
        point: point array indexed with boolean masks and clustered by DBSCAN
            (presumably shape (N, 3) -- confirm with callers).
        OutlierRatio, MaxIterationEM, ToleranceEM, RelativeToleranceEM,
        MaxOptiIterations, Sigma, MaxiSwitch, AdaptiveUpperBound, Rescale:
            forwarded positionally, in this order, to ``EMS_recovery``.
        MaxLayer: maximum number of layers to process.
        Eps: DBSCAN neighbourhood radius used to split the outliers.
        MinPoints: DBSCAN ``min_samples``. Segments with fewer than
            ``2 * MinPoints`` points are not fitted.
        InlierThreshold: per-point probability above which a point counts as
            an inlier of the fitted superquadric.
        MinInlierRatio: a fit is kept only if its inlier fraction is strictly
            greater than this value.
        OutlierThreshold: per-point probability below which a point is handed
            to the next layer for re-clustering.

    Returns:
        ``(list_quadrics, quadric_info)``: the accepted superquadric objects
        and, in the same order, one dict per fit with keys ``layer``,
        ``segment``, ``inlier_ratio``, ``num_points`` and ``inlier_points``.
    """
    # Smallest segment worth fitting; never try to fit an empty segment.
    min_segment_size = max(MinPoints * 2, 1)

    # point_seg[h] holds the point subsets queued for fitting at layer h.
    point_seg = {layer: [] for layer in range(MaxLayer + 1)}
    point_seg[0] = [point]
    list_quadrics = []
    quadric_info = []

    for h in range(MaxLayer):
        if not point_seg[h]:
            break

        for c, current_points in enumerate(point_seg[h]):
            if len(current_points) < min_segment_size:
                continue

            print(f" Layer {h}, Segment {c}: Processing {len(current_points)} points")

            try:
                x_raw, p_raw = EMS_recovery(
                    current_points,
                    OutlierRatio,
                    MaxIterationEM,
                    ToleranceEM,
                    RelativeToleranceEM,
                    MaxOptiIterations,
                    Sigma,
                    MaxiSwitch,
                    AdaptiveUpperBound,
                    Rescale,
                )

                # p_raw is thresholded per point, as a per-point probability.
                inlier_mask = p_raw > InlierThreshold
                inlier_ratio = np.mean(inlier_mask)

                if inlier_ratio > MinInlierRatio:
                    list_quadrics.append(x_raw)
                    quadric_info.append({
                        'layer': h,
                        'segment': c,
                        'inlier_ratio': inlier_ratio,
                        'num_points': len(current_points),
                        'inlier_points': current_points[inlier_mask],
                    })
                    print(f" → Fitted superquadric with {inlier_ratio:.1%} inliers")

                # Points the fit clearly does not explain are clustered and
                # re-fitted one layer deeper. Nothing is queued past the last
                # layer because it would never be processed.
                outlier = current_points[p_raw < OutlierThreshold]
                if len(outlier) > MinPoints * 2 and h < MaxLayer - 1:
                    clustering = DBSCAN(eps=Eps, min_samples=MinPoints).fit(outlier)
                    # Label -1 marks DBSCAN noise. np.unique keeps the cluster
                    # order deterministic.
                    labels = [lab for lab in np.unique(clustering.labels_) if lab >= 0]

                    if labels:
                        print(f" → Found {len(labels)} clusters in outliers")
                        for lab in labels:
                            cluster_points = outlier[clustering.labels_ == lab]
                            # Only queue clusters that the next layer will fit.
                            if len(cluster_points) >= min_segment_size:
                                point_seg[h + 1].append(cluster_points)

            except Exception as e:
                # Best effort: report the failed segment and move on.
                print(f" → Error: {e}")

    return list_quadrics, quadric_info
|
|
|
|
def generate_superquadric_mesh(sq, num_samples=25):
    """
    Sample a superquadric surface on a regular (eta, omega) grid and triangulate it.

    Uses the signed-power parameterisation, with f(t, e) = sign(t) * |t| ** e::

        x = a1 * f(cos eta, e1) * f(cos omega, e2)
        y = a2 * f(cos eta, e1) * f(sin omega, e2)
        z = a3 * f(sin eta, e1)

    Here eta is in [-pi/2, pi/2] and omega is in [-pi, pi]. Local points are
    mapped to the world frame as ``RotM @ p + translation``.

    Args:
        sq: object exposing ``shape`` (e1, e2), ``scale`` (a1, a2, a3),
            ``RotM`` (3x3) and ``translation`` (3,).
        num_samples: grid resolution along both eta and omega.

    Returns:
        vertices: ``(num_samples**2, 3)`` float array. Eta is the outer index
            and omega the inner one.
        faces: ``(2 * (num_samples - 1)**2, 3)`` int64 array of vertex indices.
            It has shape ``(0, 3)`` when ``num_samples < 2``.
    """
    def signed_pow(t, exponent):
        # Signed power keeps the octant of each coordinate.
        return np.sign(t) * np.abs(t) ** exponent

    n = num_samples
    eta = np.linspace(-np.pi / 2, np.pi / 2, n)
    omega = np.linspace(-np.pi, np.pi, n)
    eta_grid, omega_grid = np.meshgrid(eta, omega, indexing="ij")

    cos_eta = signed_pow(np.cos(eta_grid), sq.shape[0])
    sin_eta = signed_pow(np.sin(eta_grid), sq.shape[0])
    cos_omega = signed_pow(np.cos(omega_grid), sq.shape[1])
    sin_omega = signed_pow(np.sin(omega_grid), sq.shape[1])

    local = np.stack(
        [
            sq.scale[0] * cos_eta * cos_omega,
            sq.scale[1] * cos_eta * sin_omega,
            sq.scale[2] * sin_eta,
        ],
        axis=-1,
    ).reshape(-1, 3)
    # Row-vector form of RotM @ p + t, applied to every point at once.
    rot = np.asarray(sq.RotM, dtype=float)
    vertices = local @ rot.T + np.asarray(sq.translation, dtype=float)

    # Each grid quad (i, j)..(i+1, j+1) becomes two triangles. The omega grid
    # already contains both -pi and pi, so the seam closes without
    # index wrap-around.
    rows, cols = np.meshgrid(np.arange(max(n - 1, 0)), np.arange(max(n - 1, 0)), indexing="ij")
    idx1 = (rows * n + cols).ravel().astype(np.int64)
    idx2 = idx1 + 1
    idx3 = idx1 + n
    idx4 = idx3 + 1

    faces = np.empty((2 * idx1.size, 3), dtype=np.int64)
    faces[0::2] = np.stack([idx1, idx2, idx3], axis=1)
    faces[1::2] = np.stack([idx2, idx4, idx3], axis=1)

    return vertices, faces
|
|
|
|
def _load_samples(example_data_dir, normalized_dir, num_normalized_samples=2):
    """Load the available example clouds, fit superquadrics, and return sample dicts."""
    from EMS.utilities import read_ply

    single_ply_files = [
        "single_superquadric/noisy_pointCloud_example_1.ply",
        "single_superquadric/noisy_pointCloud_example_2.ply",
        "single_superquadric/partial_pointCloud_example_1.ply",
    ]
    multi_ply_files = [
        "multi_superquadrics/cat.ply",
        "multi_superquadrics/dog.ply",
        "multi_superquadrics/turtle.ply",
    ]

    samples = []

    def add_sample(name, points, quadrics, quadric_info):
        samples.append({
            'name': name,
            'idx': len(samples),
            'points': points,
            'quadrics': quadrics,
            'quadric_info': quadric_info,
        })

    # Single-superquadric examples: one EMS fit per cloud.
    for ply_file in single_ply_files:
        file_path = os.path.join(example_data_dir, ply_file)
        if not os.path.exists(file_path):
            continue
        print(f"\nProcessing {ply_file}...")
        try:
            point_cloud = read_ply(file_path)
            sq, p = EMS_recovery(point_cloud, OutlierRatio=0.2, AdaptiveUpperBound=True)
            inlier_mask = p > 0.5
            add_sample(os.path.basename(ply_file), point_cloud, [sq], [{
                'layer': 0,
                'segment': 0,
                'inlier_ratio': np.sum(inlier_mask) / len(p),
                'num_points': len(point_cloud),
                'inlier_points': point_cloud[inlier_mask],
            }])
            print(f" Success! Shape: {sq.shape}, Scale: {sq.scale}")
        except Exception as e:
            print(f" Failed: {e}")

    # Multi-part examples: hierarchical decomposition into several superquadrics.
    for ply_file in multi_ply_files:
        file_path = os.path.join(example_data_dir, ply_file)
        if not os.path.exists(file_path):
            continue
        print(f"\nProcessing {ply_file} (multi-quadric)...")
        try:
            point_cloud = read_ply(file_path)
            quadrics, quadric_info = hierarchical_ems(
                point_cloud,
                OutlierRatio=0.9,
                Eps=1.7,
                MinPoints=60,
                Rescale=True,
            )
            add_sample(os.path.basename(ply_file), point_cloud, quadrics, quadric_info)

            print(f"Summary: Found {len(quadrics)} superquadrics")
            for j, (sq, info) in enumerate(zip(quadrics, quadric_info)):
                print(f" SQ{j+1}: Shape={sq.shape}, Scale={sq.scale}, "
                      f"Inliers={info['inlier_ratio']:.1%}")
        except Exception as e:
            print(f" Failed: {e}")

    # Pre-normalised clouds stored as .npz files with a 'points' array.
    if os.path.exists(normalized_dir):
        print("\n--- Processing normalized point cloud samples ---")
        for i in range(num_normalized_samples):
            sample_name = f"sample_{i}_normalized_points.npz"
            sample_path = os.path.join(normalized_dir, sample_name)
            if not os.path.exists(sample_path):
                continue
            print(f"\nProcessing {sample_name}...")
            try:
                # The context manager closes the underlying .npz file handle.
                # Element [0] is taken along the first axis, presumably a batch
                # dimension (TODO confirm).
                with np.load(sample_path) as data:
                    point_cloud = data['points'][0]
                quadrics, quadric_info = hierarchical_ems(point_cloud)
                add_sample(sample_name, point_cloud, quadrics, quadric_info)
                print(f"Summary: Found {len(quadrics)} superquadrics")
            except Exception as e:
                print(f" Failed: {e}")

    return samples


def _format_sample_info(sample):
    """Build the markdown summary of one sample shown in the GUI info panel."""
    parts = [
        f"**{sample['name']}**\n\n",
        # Blank lines so markdown renders the two counts on separate lines.
        f"Total points: {len(sample['points'])}\n\n",
        f"Superquadrics found: {len(sample['quadrics'])}\n\n",
    ]
    if sample['quadrics']:
        parts.append("**Superquadric Details:**\n")
        for i, (sq, info) in enumerate(zip(sample['quadrics'], sample['quadric_info'])):
            parts.append(
                f"\n**SQ{i+1}** (Layer {info['layer']}):\n"
                f"- Shape: ε₁={sq.shape[0]:.3f}, ε₂={sq.shape[1]:.3f}\n"
                f"- Scale: ({sq.scale[0]:.2f}, {sq.scale[1]:.2f}, {sq.scale[2]:.2f})\n"
                f"- Inliers: {info['inlier_ratio']:.1%} ({info['num_points']} points)\n"
            )
    return "".join(parts)


def _serve_viewer(all_samples, port=8080):
    """Serve an interactive viser scene for browsing the fitted samples until Ctrl+C."""
    sample_names = [s['name'] for s in all_samples if s['points'] is not None]
    if not sample_names:
        # With no options, viser's dropdown cannot pick an initial value.
        print("No samples with point data were loaded; nothing to visualize.")
        return

    server = viser.ViserServer(port=port)
    print(f"\n{'='*60}")
    print(f"Viser server started at http://localhost:{port}")
    print("Open this URL in your browser to view the 3D visualization")
    print("Press Ctrl+C to stop the server")
    print('='*60)

    quadric_colors = [
        (255, 0, 0),
        (0, 255, 0),
        (0, 0, 255),
        (255, 255, 0),
        (255, 0, 255),
        (0, 255, 255),
    ]

    with server.gui.add_folder("Controls"):
        current_sample = server.gui.add_dropdown(
            "Select Sample",
            options=sample_names,
            initial_value=sample_names[0],
        )
        show_points = server.gui.add_checkbox("Show Points", initial_value=True)
        show_all_quadrics = server.gui.add_checkbox("Show All Quadrics", initial_value=True)
        show_labels = server.gui.add_checkbox("Show Labels", initial_value=True)

        # Per-quadric checkboxes are rebuilt inside this folder for every sample.
        quadric_toggles_folder = server.gui.add_folder("Individual Quadrics")

        point_size = server.gui.add_slider(
            "Point Size",
            min=0.001,
            max=0.02,
            step=0.001,
            initial_value=0.003,
        )
        mesh_opacity = server.gui.add_slider(
            "Mesh Opacity",
            min=0.0,
            max=1.0,
            step=0.1,
            initial_value=0.5,
        )
        info_display = server.gui.add_markdown("**Sample Info:**\n\nSelect a sample to view")

    # The first sample wins on duplicate names, like the original linear search.
    samples_by_name = {}
    for sample in all_samples:
        samples_by_name.setdefault(sample['name'], sample)

    # Scene and GUI handles for the currently displayed sample. Meshes and
    # labels are keyed by quadric index, so they stay aligned with the toggles
    # even when a mesh fails to build.
    current_viz = {
        'points': None,
        'meshes': {},
        'labels': {},
        'quadric_toggles': [],
    }

    def quadric_visible(i):
        toggles = current_viz['quadric_toggles']
        return show_all_quadrics.value and (i >= len(toggles) or toggles[i].value)

    def refresh_visibility(_event=None):
        """Apply the current checkbox states to every node of the displayed sample."""
        if current_viz['points'] is not None:
            current_viz['points'].visible = show_points.value
        for i, mesh in current_viz['meshes'].items():
            mesh.visible = quadric_visible(i)
        for i, label in current_viz['labels'].items():
            label.visible = show_labels.value and quadric_visible(i)

    def clear_scene():
        if current_viz['points'] is not None:
            current_viz['points'].remove()
            current_viz['points'] = None
        for handle in (*current_viz['meshes'].values(),
                       *current_viz['labels'].values(),
                       *current_viz['quadric_toggles']):
            handle.remove()
        current_viz['meshes'] = {}
        current_viz['labels'] = {}
        current_viz['quadric_toggles'] = []

    def update_scene():
        """Rebuild the 3D scene and per-quadric toggles for the selected sample."""
        clear_scene()

        selected = samples_by_name.get(current_sample.value)
        if selected is None or selected['points'] is None:
            info_display.value = "**No valid sample selected**"
            return

        info_display.value = _format_sample_info(selected)

        # Every node is created even if hidden, so the checkboxes can show it later.
        points = selected['points']
        current_viz['points'] = server.scene.add_point_cloud(
            "/current/points",
            points=points,
            colors=np.full((len(points), 3), 128, dtype=np.uint8),
            point_size=point_size.value,
        )

        with quadric_toggles_folder:
            for i in range(len(selected['quadrics'])):
                toggle = server.gui.add_checkbox(f"Quadric {i+1}", initial_value=True)
                toggle.on_update(refresh_visibility)
                current_viz['quadric_toggles'].append(toggle)

        for i, sq in enumerate(selected['quadrics']):
            color = quadric_colors[i % len(quadric_colors)]
            try:
                vertices, faces = generate_superquadric_mesh(sq, num_samples=20)
                current_viz['meshes'][i] = server.scene.add_mesh_simple(
                    f"/current/mesh_{i}",
                    vertices=vertices,
                    faces=faces,
                    color=color,
                    opacity=mesh_opacity.value,
                )
                current_viz['labels'][i] = server.scene.add_label(
                    f"/current/label_{i}",
                    text=f"SQ{i+1}: ε₁={sq.shape[0]:.2f}, ε₂={sq.shape[1]:.2f}",
                    position=sq.translation,
                )
            except Exception as e:
                print(f"Error visualizing quadric {i}: {e}")

        refresh_visibility()

    @current_sample.on_update
    def _(_):
        update_scene()

    show_points.on_update(refresh_visibility)
    show_all_quadrics.on_update(refresh_visibility)
    show_labels.on_update(refresh_visibility)

    @point_size.on_update
    def _(event):
        if current_viz['points'] is not None:
            current_viz['points'].point_size = event.target.value

    @mesh_opacity.on_update
    def _(event):
        for mesh in current_viz['meshes'].values():
            mesh.opacity = event.target.value

    update_scene()

    try:
        while True:
            time.sleep(0.1)
    except KeyboardInterrupt:
        print("\nShutting down server...")
        server.stop()


def main(
    example_data_dir="/research/cbim/vast/sf895/code/EMS-superquadric_fitting/MATLAB/example_scripts/data",
    normalized_dir="/research/cbim/vast/sf895/code/EMS-superquadric_fitting/20250811_231035_step10_stage0_waves1",
    port=8080,
    num_normalized_samples=2,
):
    """
    Fit superquadrics to the example point clouds and browse them in viser.

    Args:
        example_data_dir: root directory containing the ``single_superquadric``
            and ``multi_superquadrics`` example ``.ply`` files.
        normalized_dir: directory with ``sample_<i>_normalized_points.npz``
            files. It is skipped if it does not exist.
        port: port for the viser web server.
        num_normalized_samples: number of ``.npz`` samples to try to load.

    Missing files are skipped, and per-file failures are printed rather than
    raised. If nothing loads, the function returns without starting a server.
    """
    print("Loading and processing samples with hierarchical multi-quadric fitting...")
    all_samples = _load_samples(example_data_dir, normalized_dir, num_normalized_samples)
    _serve_viewer(all_samples, port=port)
|
|
|
|
# Run the demo only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()