| |
| """ |
| Point Cloud Registration Demo using Open3D and 3DMatch RedKitchen |
| Interactive Gradio app for pairwise point cloud registration |
| """ |
|
|
| import json |
| import os |
| import copy |
| import numpy as np |
| import gradio as gr |
| from pathlib import Path |
| import tempfile |
| from typing import Tuple, Dict, List |
| import plotly.graph_objects as go |
|
|
| try: |
| import open3d as o3d |
| except ImportError: |
| raise ImportError("Please install open3d: pip install open3d") |
|
|
|
|
def get_file_path(file_obj):
    """Convert a Gradio file value (or a plain path) into a filesystem path.

    Inputs are checked in this order:
        * ``None`` -> ``None``
        * ``str`` -> returned unchanged
        * ``os.PathLike`` (e.g. ``pathlib.Path``) -> ``os.fspath(file_obj)``.
          This must run before the ``.name`` check below, because
          ``Path.name`` is only the final path component, not the full path.
        * objects with a string ``.path`` attribute -> that path
        * objects with a ``.name`` attribute (e.g. tempfile wrappers) -> ``.name``
        * ``dict`` -> its ``"path"`` entry, falling back to ``"name"``
        * anything else -> ``str(file_obj)``
    """
    if file_obj is None:
        return None

    if isinstance(file_obj, str):
        return file_obj

    if isinstance(file_obj, os.PathLike):
        return os.fspath(file_obj)

    # File-data style objects expose the real location as ``.path``.
    path_attr = getattr(file_obj, "path", None)
    if isinstance(path_attr, str):
        return path_attr

    if hasattr(file_obj, "name"):
        return file_obj.name

    if isinstance(file_obj, dict):
        return file_obj.get("path") or file_obj.get("name")

    return str(file_obj)
|
|
|
|
def point_cloud_to_arrays(pcd, max_points=30000, rng=None):
    """Return the points of ``pcd`` as a NumPy array, randomly subsampled for plotting.

    Args:
        pcd: An Open3D point cloud, or any object whose ``points`` attribute
            converts via ``np.asarray`` (expected ``(N, 3)``).
        max_points: Upper bound on the number of returned rows. Larger clouds
            are subsampled uniformly at random, without replacement.
        rng: Optional ``numpy.random.Generator`` used for the subsampling.
            ``None`` (the default) uses the legacy global NumPy RNG, as before.
            Pass a seeded generator to get reproducible plots.

    Returns:
        The point array. It is returned unchanged (not copied) when it is
        empty or already has at most ``max_points`` rows.
    """
    points = np.asarray(pcd.points)
    n_points = len(points)

    if n_points == 0 or n_points <= max_points:
        return points

    choose = np.random.choice if rng is None else rng.choice
    idx = choose(n_points, max_points, replace=False)
    return points[idx]
|
|
|
|
class RegistrationDemo:
    """Backend for the Gradio demo: load, preprocess and register a pair of point clouds.

    State kept between pipeline steps:
        current_source / current_target: clouds as read from disk.
        current_source_down / current_target_down: cleaned, voxel-downsampled
            copies that RANSAC / ICP operate on. Normals are estimated on these
            in place.

    NOTE(review): this state lives on a single instance shared by the UI, so
    overlapping requests would overwrite each other's clouds. Confirm the app
    runs ``register`` for one request at a time.
    """

    def __init__(self, examples_dir: str = "examples"):
        """Load demo-pair metadata from ``<examples_dir>/pair_metadata.json`` if present."""
        self.examples_dir = Path(examples_dir)
        self.metadata = self._load_metadata()
        self.current_source = None
        self.current_target = None
        self.current_source_down = None
        self.current_target_down = None

    def _load_metadata(self) -> Dict:
        """Return the parsed ``pair_metadata.json``, or ``{}`` when the file does not exist."""
        metadata_file = self.examples_dir / "pair_metadata.json"
        if metadata_file.exists():
            with open(metadata_file, encoding="utf-8") as f:
                return json.load(f)
        return {}

    def get_pair_choices(self) -> List[str]:
        """Return the demo pair keys from the metadata, sorted alphabetically."""
        if not self.metadata:
            return []
        return sorted(self.metadata.keys())

    def load_demo_pair(self, pair_key: str) -> Tuple[str, str]:
        """Resolve a demo pair key to ``(source_path, target_path)``.

        Returns ``(None, None)`` in three cases: the key is unknown, its entry
        lacks ``source_file`` / ``target_file``, or either file is missing on disk.
        """
        if pair_key not in self.metadata:
            return None, None

        pair_info = self.metadata[pair_key]
        # A malformed entry is treated like a missing pair. Raising KeyError here
        # would escape ``register``, which calls this outside its try block.
        if not isinstance(pair_info, dict):
            return None, None
        source_name = pair_info.get("source_file")
        target_name = pair_info.get("target_file")
        if not source_name or not target_name:
            return None, None

        source_file = self.examples_dir / source_name
        target_file = self.examples_dir / target_name

        if source_file.exists() and target_file.exists():
            return str(source_file), str(target_file)

        return None, None

    def load_point_clouds(self, source_path: str, target_path: str) -> Tuple[bool, str]:
        """Read both clouds from disk into ``current_source`` / ``current_target``.

        Returns ``(success, message)``. An empty cloud counts as a failure.
        NOTE(review): ``o3d.io.read_point_cloud`` may return an empty cloud
        instead of raising for a missing or unreadable file. Catching that here
        gives a clearer message than the later "too few points" error.
        """
        try:
            self.current_source = o3d.io.read_point_cloud(source_path)
            self.current_target = o3d.io.read_point_cloud(target_path)
        except Exception as e:
            return False, f"✗ Error loading point clouds: {e}"

        n_source = len(self.current_source.points)
        n_target = len(self.current_target.points)

        if n_source == 0 or n_target == 0:
            return False, (
                "✗ Error loading point clouds: empty or unreadable file "
                f"({n_source} source points, {n_target} target points)"
            )

        return True, f"✓ Loaded: {n_source} source points, {n_target} target points"

    def preprocess_point_clouds(self, voxel_size: float) -> Tuple[bool, str]:
        """Clean both loaded clouds and voxel-downsample them into ``current_*_down``.

        Cleaning drops points with any non-finite coordinate, then duplicated
        points. Only points and (if present) normals are carried over. The
        step fails when either downsampled cloud has fewer than 10 points.
        """
        if self.current_source is None or self.current_target is None:
            return False, "✗ No point clouds loaded"

        try:

            def remove_non_finite(pcd):
                points = np.asarray(pcd.points)
                # Keep rows whose coordinates are all finite (no NaN / inf).
                mask = np.isfinite(points).all(axis=1)
                pcd_clean = o3d.geometry.PointCloud()
                pcd_clean.points = o3d.utility.Vector3dVector(points[mask])
                if pcd.has_normals():
                    pcd_clean.normals = o3d.utility.Vector3dVector(np.asarray(pcd.normals)[mask])
                pcd_clean.remove_duplicated_points()
                return pcd_clean

            source_clean = remove_non_finite(self.current_source)
            target_clean = remove_non_finite(self.current_target)

            self.current_source_down = source_clean.voxel_down_sample(voxel_size)
            self.current_target_down = target_clean.voxel_down_sample(voxel_size)

            n_source = len(self.current_source_down.points)
            n_target = len(self.current_target_down.points)

            if n_source < 10 or n_target < 10:
                return False, f"✗ Too few points after downsampling: {n_source}, {n_target}"

            return True, f"✓ Preprocessed: {n_source} source, {n_target} target points"

        except Exception as e:
            return False, f"✗ Preprocessing error: {e}"

    def estimate_normals_and_features(
        self,
        voxel_size: float,
        normal_radius_mult: float,
        fpfh_radius_mult: float,
    ) -> Tuple[bool, str, object, object]:
        """Estimate normals on the downsampled clouds (in place) and compute FPFH features.

        Both search radii are ``voxel_size`` times the matching multiplier.
        Returns ``(success, message, source_fpfh, target_fpfh)``; the features
        are ``None`` on failure.
        """
        if self.current_source_down is None or self.current_target_down is None:
            return False, "✗ No preprocessed point clouds", None, None

        try:
            normal_radius = voxel_size * normal_radius_mult
            fpfh_radius = voxel_size * fpfh_radius_mult

            for cloud in (self.current_source_down, self.current_target_down):
                cloud.estimate_normals(
                    o3d.geometry.KDTreeSearchParamRadius(radius=normal_radius)
                )

            source_fpfh = o3d.pipelines.registration.compute_fpfh_feature(
                self.current_source_down,
                o3d.geometry.KDTreeSearchParamRadius(radius=fpfh_radius)
            )
            target_fpfh = o3d.pipelines.registration.compute_fpfh_feature(
                self.current_target_down,
                o3d.geometry.KDTreeSearchParamRadius(radius=fpfh_radius)
            )

            return True, "✓ Features computed", source_fpfh, target_fpfh

        except Exception as e:
            return False, f"✗ Feature computation error: {e}", None, None

    def ransac_registration(
        self,
        voxel_size: float,
        normal_radius_mult: float = 2.0,
        fpfh_radius_mult: float = 5.0,
        distance_mult: float = 1.5,
        max_iterations: int = 50000,
    ) -> Tuple[bool, str, np.ndarray]:
        """Global registration: RANSAC over FPFH feature correspondences.

        The correspondence / checker distance is ``voxel_size * distance_mult``.
        Returns ``(success, message, transform)``; ``transform`` is ``None`` on failure.
        """
        if self.current_source_down is None or self.current_target_down is None:
            return False, "✗ No preprocessed point clouds", None

        try:
            success, msg, source_fpfh, target_fpfh = self.estimate_normals_and_features(
                voxel_size,
                normal_radius_mult,
                fpfh_radius_mult,
            )
            if not success:
                return False, msg, None

            distance_threshold = voxel_size * distance_mult

            result = o3d.pipelines.registration.registration_ransac_based_on_feature_matching(
                self.current_source_down, self.current_target_down,
                source_fpfh, target_fpfh,
                mutual_filter=False,
                max_correspondence_distance=distance_threshold,
                estimation_method=o3d.pipelines.registration.TransformationEstimationPointToPoint(False),
                ransac_n=3,
                checkers=[
                    o3d.pipelines.registration.CorrespondenceCheckerBasedOnEdgeLength(0.9),
                    o3d.pipelines.registration.CorrespondenceCheckerBasedOnDistance(distance_threshold)
                ],
                criteria=o3d.pipelines.registration.RANSACConvergenceCriteria(
                    int(max_iterations),
                    0.999,
                )
            )

            msg = f"✓ RANSAC: fitness={result.fitness:.4f}, RMSE={result.inlier_rmse:.6f}"
            return True, msg, result.transformation

        except Exception as e:
            return False, f"✗ RANSAC error: {e}", None

    def icp_registration(self, init_transform: np.ndarray, voxel_size: float,
                         distance_mult: float = 0.4, max_iterations: int = 50,
                         normal_radius_mult: float = 2.0) -> Tuple[bool, str, np.ndarray]:
        """Refine ``init_transform`` with point-to-plane ICP on the downsampled clouds.

        Args:
            init_transform: 4x4 initial transform applied to the source.
            voxel_size: Preprocessing voxel size; scales the radii below.
            distance_mult: Max correspondence distance = ``voxel_size * distance_mult``.
            max_iterations: ICP iteration cap. It is coerced to ``int``, as
                ``ransac_registration`` does, in case the UI passes a float.
            normal_radius_mult: Normal-estimation radius = ``voxel_size * normal_radius_mult``.
                It is only used for a cloud that has no normals yet, e.g. in
                "ICP only" mode.

        Returns:
            ``(success, message, transform)``; ``transform`` is ``None`` on failure.
        """
        if self.current_source_down is None or self.current_target_down is None:
            return False, "✗ No preprocessed point clouds", None

        try:
            normal_radius = voxel_size * normal_radius_mult

            # Point-to-plane needs normals. Reuse the ones RANSAC estimated, if any.
            for cloud in (self.current_source_down, self.current_target_down):
                if not cloud.has_normals():
                    cloud.estimate_normals(
                        o3d.geometry.KDTreeSearchParamRadius(radius=normal_radius)
                    )

            distance_threshold = voxel_size * distance_mult

            result = o3d.pipelines.registration.registration_icp(
                self.current_source_down, self.current_target_down,
                max_correspondence_distance=distance_threshold,
                init=init_transform,
                criteria=o3d.pipelines.registration.ICPConvergenceCriteria(
                    max_iteration=int(max_iterations)
                ),
                estimation_method=o3d.pipelines.registration.TransformationEstimationPointToPlane()
            )

            msg = f"✓ ICP: fitness={result.fitness:.4f}, RMSE={result.inlier_rmse:.6f}"
            return True, msg, result.transformation

        except Exception as e:
            return False, f"✗ ICP error: {e}", None

    @staticmethod
    def _two_cloud_figure(src: np.ndarray, tgt: np.ndarray, src_color: str, src_name: str) -> go.Figure:
        """Build the 3D scatter plot shared by the before/after views (target always cyan)."""
        fig = go.Figure()
        for pts, color, name in ((src, src_color, src_name), (tgt, "cyan", "Target")):
            fig.add_trace(go.Scatter3d(
                x=pts[:, 0], y=pts[:, 1], z=pts[:, 2],
                mode="markers",
                marker=dict(size=1.5, color=color),
                name=name
            ))
        fig.update_layout(
            scene=dict(aspectmode="data"),
            margin=dict(l=0, r=0, t=30, b=0),
            height=520,
            legend=dict(x=0, y=1)
        )
        return fig

    def create_before_visualization(self) -> Tuple[go.Figure, str]:
        """Plot the loaded clouds as-is: source in orange, target in cyan.

        Each cloud is subsampled by ``point_cloud_to_arrays``. Returns
        ``(figure, message)``; ``figure`` is ``None`` on failure.
        """
        if self.current_source is None or self.current_target is None:
            return None, "✗ No point clouds loaded"

        try:
            fig = self._two_cloud_figure(
                point_cloud_to_arrays(self.current_source),
                point_cloud_to_arrays(self.current_target),
                src_color="orange",
                src_name="Source",
            )
            return fig, "✓ Before visualization ready"

        except Exception as e:
            return None, f"✗ Visualization error: {e}"

    def create_after_visualization(self, transform: np.ndarray) -> Tuple[go.Figure, str]:
        """Plot a transformed copy of the source (lime) against the target (cyan).

        The loaded source cloud itself is not modified. Returns
        ``(figure, message)``; ``figure`` is ``None`` on failure.
        """
        if self.current_source is None or self.current_target is None:
            return None, "✗ No point clouds loaded"

        if transform is None:
            return None, "✗ No transformation available"

        try:
            source_transformed = copy.deepcopy(self.current_source)
            source_transformed.transform(transform)

            fig = self._two_cloud_figure(
                point_cloud_to_arrays(source_transformed),
                point_cloud_to_arrays(self.current_target),
                src_color="lime",
                src_name="Aligned Source",
            )
            return fig, "✓ After visualization ready"

        except Exception as e:
            return None, f"✗ Visualization error: {e}"

    def get_metrics_and_transform(self, fitness: float, rmse: float,
                                  transform: np.ndarray) -> Tuple[str, str]:
        """Format the metrics and the transformation matrix as display text."""
        metrics_text = f"Fitness: {fitness:.6f}\nRMSE: {rmse:.6f}"

        if transform is None:
            return metrics_text, "No transformation available"

        rows = (" ".join(f"{x:12.6f}" for x in row) + "\n" for row in transform)
        transform_text = "Transformation Matrix (4x4):\n" + "".join(rows)
        return metrics_text, transform_text

    def save_aligned_source(self, transform: np.ndarray) -> Tuple[str, str]:
        """Write a transformed copy of the full-resolution source to a temporary ``.ply`` file.

        Returns ``(path, message)``; ``path`` is ``None`` on failure. The file is
        created with ``delete=False``, so the caller owns its lifetime.
        """
        if self.current_source is None:
            return None, "✗ No source point cloud loaded"

        if transform is None:
            return None, "✗ No transformation available"

        try:
            aligned_source = copy.deepcopy(self.current_source)
            aligned_source.transform(transform)

            # Reserve a unique name, then close our handle before Open3D opens the
            # path itself. Reopening a NamedTemporaryFile while it is open fails on Windows.
            with tempfile.NamedTemporaryFile(suffix=".ply", delete=False) as f:
                out_path = f.name

            # write_point_cloud reports failure through its boolean return value.
            if not o3d.io.write_point_cloud(out_path, aligned_source):
                Path(out_path).unlink(missing_ok=True)
                return None, f"✗ Save aligned source error: could not write {out_path}"

            return out_path, "✓ Aligned source saved"

        except Exception as e:
            return None, f"✗ Save aligned source error: {e}"

    @staticmethod
    def _failure(message: str) -> Tuple:
        """Build the 6-tuple ``register`` returns when the pipeline stops early.

        The last slot feeds a file output, so it is ``None`` (no file). An empty
        string is not a valid file path.
        """
        return None, None, message, "", "", None

    def _alignment_metrics(self, transform: np.ndarray, inlier_threshold: float) -> Tuple[float, float]:
        """Score ``transform`` on the downsampled clouds.

        For each transformed source point, take its distance to the target cloud.
        Returns ``(fitness, rmse)``:
            fitness: fraction of those distances below ``inlier_threshold``.
            rmse: root-mean-square of all the distances, not only the inliers.
        """
        source_aligned = copy.deepcopy(self.current_source_down)
        source_aligned.transform(transform)

        distances = np.asarray(
            source_aligned.compute_point_cloud_distance(self.current_target_down)
        )
        fitness = np.sum(distances < inlier_threshold) / len(distances)
        rmse = np.sqrt(np.mean(distances ** 2))
        return fitness, rmse

    def register(self, pair_selection: str, mode: str, voxel_size: float,
                 normal_radius_mult: float, fpfh_radius_mult: float,
                 ransac_dist_mult: float, ransac_iter: int,
                 icp_dist_mult: float, icp_iter: int,
                 source_upload, target_upload) -> Tuple:
        """Run the full pipeline: load, preprocess, register, visualize, score, save.

        Args:
            pair_selection: A demo pair key, or "Upload" / empty to use the uploads.
            mode: "RANSAC + ICP", "RANSAC only" or "ICP only".
            Remaining arguments are the UI parameters forwarded to the stages.

        Returns:
            ``(before_fig, after_fig, status, metrics_text, transform_text, aligned_file)``.
            On failure the figures and the file are ``None`` and ``status``
            holds the error.
        """
        # Resolve input paths: either a demo pair from the metadata or the two uploads.
        if pair_selection and pair_selection != "Upload":
            source_path, target_path = self.load_demo_pair(pair_selection)
            if source_path is None:
                return self._failure("✗ Failed to load demo pair")
        else:
            if source_upload is None or target_upload is None:
                return self._failure("✗ Please upload both source and target")
            source_path = get_file_path(source_upload)
            target_path = get_file_path(target_upload)

        success, msg = self.load_point_clouds(source_path, target_path)
        if not success:
            return self._failure(msg)

        success, msg = self.preprocess_point_clouds(voxel_size)
        if not success:
            return self._failure(msg)

        final_transform = np.eye(4)
        stage_msgs = []

        try:
            if mode in ("RANSAC + ICP", "RANSAC only"):
                success, ransac_msg, transform = self.ransac_registration(
                    voxel_size,
                    normal_radius_mult,
                    fpfh_radius_mult,
                    ransac_dist_mult,
                    int(ransac_iter),
                )
                if not success:
                    return self._failure(ransac_msg)
                final_transform = transform
                stage_msgs.append(ransac_msg)

            if mode in ("RANSAC + ICP", "ICP only"):
                # ICP-only starts from identity; otherwise refine the RANSAC estimate.
                init = np.eye(4) if mode == "ICP only" else final_transform
                success, icp_msg, transform = self.icp_registration(
                    init, voxel_size, icp_dist_mult, icp_iter,
                    normal_radius_mult=normal_radius_mult,
                )
                if not success:
                    return self._failure(icp_msg)
                final_transform = transform
                stage_msgs.append(icp_msg)

            # Report the message of whichever visualization actually failed.
            before_fig, before_msg = self.create_before_visualization()
            if before_fig is None:
                return self._failure(before_msg)
            after_fig, after_msg = self.create_after_visualization(final_transform)
            if after_fig is None:
                return self._failure(after_msg)

            fitness, rmse = self._alignment_metrics(final_transform, voxel_size * 1.5)
            metrics_text, transform_text = self.get_metrics_and_transform(fitness, rmse, final_transform)

            # Only include stages that ran, so "ICP only" has no blank line.
            status_msg = "\n".join(["✓ Registration complete!"] + stage_msgs)

            aligned_file, aligned_msg = self.save_aligned_source(final_transform)
            if aligned_file is None:
                status_msg += "\n" + aligned_msg

            return before_fig, after_fig, status_msg, metrics_text, transform_text, aligned_file

        except Exception as e:
            return self._failure(f"✗ Registration error: {e}")
|
|
|
|
def create_gradio_interface():
    """Build the Gradio Blocks UI and wire the Run button to ``RegistrationDemo.register``.

    Returns:
        The ``gr.Blocks`` app; launching it is left to the caller.
    """
    demo = RegistrationDemo(examples_dir="examples")

    # Select the first demo pair when metadata was found, otherwise "Upload".
    # The "no pairs" hint goes in the dropdown's info text rather than in a
    # choice: selecting such a choice would only produce "Failed to load demo pair".
    pair_choices = demo.get_pair_choices()
    default_pair = pair_choices[0] if pair_choices else "Upload"
    pair_info = None if pair_choices else "(No demo pairs loaded)"

    with gr.Blocks(title="Point Cloud Registration Demo") as app:
        gr.Markdown("""
        # Point Cloud Registration Demo

        An interactive demo for pairwise 3D point cloud registration using **Open3D**.

        This demo includes pre-selected point cloud pairs from the **3DMatch Geometric Registration Benchmark**, specifically the **RedKitchen** scene.
        These point clouds are 3D fragments reconstructed from depth frames using TSDF fusion.

        Below, you can choose a provided demo pair or upload your own source and target point clouds for registration.

        The registration pipeline uses:
        - **RANSAC + FPFH** for global registration to estimate the initial transformation.
        - **ICP** for local refinement to improve the alignment.
        - Supported modes include **RANSAC + ICP**, **RANSAC only**, and **ICP only**.
        """)

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Input")

                pair_selection = gr.Dropdown(
                    choices=["Upload"] + pair_choices,
                    value=default_pair,
                    label="Dataset Pair",
                    info=pair_info,
                )

                source_upload = gr.File(
                    label="Source Point Cloud (PLY)",
                    file_types=[".ply"],
                    type="filepath",
                )

                target_upload = gr.File(
                    label="Target Point Cloud (PLY)",
                    file_types=[".ply"],
                    type="filepath",
                )

            with gr.Column(scale=1):
                gr.Markdown("### Registration Settings")

                mode = gr.Radio(
                    choices=["RANSAC + ICP", "RANSAC only", "ICP only"],
                    value="RANSAC + ICP",
                    label="Registration Mode"
                )

                voxel_size = gr.Slider(0.01, 0.2, value=0.05, step=0.01, label="Voxel Size")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### RANSAC Parameters")
                ransac_dist_mult = gr.Slider(0.5, 3.0, value=1.5, step=0.1, label="Distance Multiplier")
                ransac_iter = gr.Slider(1000, 100000, value=50000, step=5000, label="Max Iterations")

            with gr.Column(scale=1):
                gr.Markdown("### ICP Parameters")
                icp_dist_mult = gr.Slider(0.1, 1.0, value=0.4, step=0.1, label="Distance Multiplier")
                icp_iter = gr.Slider(10, 100, value=50, step=5, label="Max Iterations")

            with gr.Column(scale=1):
                gr.Markdown("### Feature Parameters")
                normal_radius_mult = gr.Slider(1.0, 5.0, value=2.0, step=0.5, label="Normal Radius Mult")
                fpfh_radius_mult = gr.Slider(2.0, 10.0, value=5.0, step=0.5, label="FPFH Radius Mult")

        run_button = gr.Button("Run Registration", variant="primary", size="lg")

        with gr.Row():
            status = gr.Textbox(label="Status", lines=2, interactive=False)

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Before Registration")
                before_viewer = gr.Plot(label="Before Registration")

            with gr.Column(scale=1):
                gr.Markdown("### After Registration")
                after_viewer = gr.Plot(label="After Registration")

        download_aligned = gr.File(
            label="Download Aligned Source",
            type="filepath",
        )

        with gr.Row():
            with gr.Column(scale=1):
                metrics = gr.Textbox(label="Metrics", lines=3, interactive=False)

            with gr.Column(scale=1):
                transform_matrix = gr.Textbox(label="Transformation Matrix", lines=5, interactive=False)

        # The order of inputs/outputs must match RegistrationDemo.register's
        # parameters and its 6-tuple return value.
        run_button.click(
            fn=demo.register,
            inputs=[
                pair_selection, mode, voxel_size,
                normal_radius_mult, fpfh_radius_mult,
                ransac_dist_mult, ransac_iter,
                icp_dist_mult, icp_iter,
                source_upload, target_upload
            ],
            outputs=[before_viewer, after_viewer, status, metrics, transform_matrix, download_aligned]
        )

    return app
|
|
|
|
if __name__ == "__main__":
    # Serve on port 7860 on all network interfaces, with no public share link
    # and server-side rendering turned off.
    launch_kwargs = dict(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        ssr_mode=False,
    )
    app = create_gradio_interface()
    app.launch(**launch_kwargs)
|
|