| """ |
| Gradio + Plotly point cloud viewer for .xyz, .ply and .obj files with PI3DETR model integration. |
| |
| Features: |
| - Upload .xyz (ASCII): one point per line: "x y z" (extra columns are ignored). |
| - Upload .ply: Standard PLY format point clouds. |
| - Upload .obj: OBJ format with vertices and faces (triangles). |
| - Interactive 3D view: orbit, pan, zoom with mouse. |
| - Optional: downsample for speed, normalize to unit cube, toggle axes, set point size. |
| - Dual view: Input point cloud and model predictions side-by-side. |
| - PI3DETR model integration for curve detection. |
| - Immediate point cloud rendering on upload. |
| """ |
|
|
| import io |
| import os |
| from typing import List, Dict, Optional |
|
|
| import gradio as gr |
| import numpy as np |
| import plotly.graph_objects as go |
| from plyfile import PlyData |
| import pandas |
| import torch |
| from torch_geometric.data import Data |
| import fpsample |
| import trimesh |
|
|
| |
| from pi3detr import ( |
| build_model, |
| build_model_config, |
| load_args, |
| load_weights, |
| ) |
| from pi3detr.dataset import normalize_and_scale |
|
|
| |
# Cached PI3DETR model; populated by initialize_model() and shared module-wide.
PI3DETR_MODEL = None
# Load status reported by initialize_model(); "loaded" gates lazy re-loading.
MODEL_STATUS = {"loaded": False, "message": "Model not loaded"}

# Plotly appearance settings used by make_figure().
HOVER_FONT_SIZE = 16
FIG_TEMPLATE = "plotly_white"
PLOT_HEIGHT = 800

# Bundled demo clouds keyed by the label shown in the UI ("Demo 1" .. "Demo 5").
DEMO_POINTCLOUDS = {f"Demo {i}": f"demo_inputs/demo{i}.xyz" for i in range(1, 6)}
|
|
|
|
def initialize_model(checkpoint_path="model.ckpt", config_path="configs/pi3detr.yaml"):
    """Build PI3DETR, load its weights and cache it in the module globals.

    Args:
        checkpoint_path: Weights file handed to ``load_weights``.
        config_path: Config file handed to ``load_args``; when falsy, an empty
            argument dict is used to build the model config instead.

    Returns:
        True if the model was built, loaded and cached in ``PI3DETR_MODEL``;
        False otherwise. In both cases ``MODEL_STATUS`` is replaced with a
        dict describing the outcome, and a message is printed.
    """
    global PI3DETR_MODEL, MODEL_STATUS
    try:
        model = build_model(
            build_model_config(load_args(config_path) if config_path else {})
        )
        load_weights(model, checkpoint_path)
        model.eval()
    except Exception as err:
        MODEL_STATUS = {"loaded": False, "message": f"Error loading model: {str(err)}"}
        print(f"Error initializing PI3DETR model: {err}")
        return False

    PI3DETR_MODEL = model
    MODEL_STATUS = {"loaded": True, "message": "Model loaded successfully"}
    print("PI3DETR model initialized successfully")
    return True
|
|
|
|
def read_xyz(file_obj: io.BytesIO) -> np.ndarray:
    """Parse an ASCII .xyz point file into an (N, 3) float32 array.

    Each non-empty line not starting with ``#`` is split on whitespace and
    commas. Non-numeric tokens are skipped, and the first three numeric
    tokens become x, y, z. Lines with fewer than three numeric values, and
    points containing NaN or infinite coordinates, are dropped.

    Args:
        file_obj: File-like object opened in binary mode (bytes) or text mode
            (str), or None.

    Returns:
        (N, 3) float32 array; (0, 3) when there is no input or no valid point.
    """
    if file_obj is None:
        return np.zeros((0, 3), dtype=np.float32)

    raw = file_obj.read()
    if isinstance(raw, str):
        text = raw
    else:
        try:
            # "utf-8-sig" strips a leading BOM that would otherwise glue onto
            # the first number and make the first point unparsable.
            text = raw.decode("utf-8-sig")
        except UnicodeDecodeError:
            # latin-1 maps every byte to a character, so this cannot fail.
            text = raw.decode("latin-1")

    pts = []
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        nums = []
        for token in line.replace(",", " ").split():
            try:
                nums.append(float(token))
            except ValueError:
                # Non-numeric token (label, unit, ...): ignore it.
                continue
            if len(nums) == 3:
                break
        if len(nums) == 3:
            pts.append(nums)

    if not pts:
        return np.zeros((0, 3), dtype=np.float32)

    arr = np.asarray(pts, dtype=np.float32)
    # NaN/inf points would corrupt plot bounds and model normalization.
    return arr[np.isfinite(arr).all(axis=1)]
|
|
|
|
def read_ply(file_obj: io.BytesIO) -> np.ndarray:
    """Read vertex positions from a PLY file.

    Args:
        file_obj: Binary file-like object holding a PLY file, or None.

    Returns:
        (N, 3) float32 array built from the ``x``, ``y`` and ``z`` properties
        of the ``vertex`` element. An empty (0, 3) array is returned for None
        input, or when reading fails; in the failure case the error is printed.
    """
    if file_obj is None:
        return np.zeros((0, 3), dtype=np.float32)

    try:
        vertices = PlyData.read(file_obj)["vertex"]
        columns = [np.asarray(vertices[axis]) for axis in ("x", "y", "z")]
        return np.column_stack(columns).astype(np.float32)
    except Exception as exc:
        print(f"Error reading PLY file: {exc}")
        return np.zeros((0, 3), dtype=np.float32)
|
|
|
|
def read_obj_and_sample(file_obj: io.BytesIO, display_max_points: int):
    """Load an OBJ mesh with trimesh and sample points uniformly over its surface.

    Args:
        file_obj: Binary file-like object with OBJ content, or None.
        display_max_points: Number of surface samples to draw. Values below 1
            (e.g. the "Max points" slider at 0) are clamped to 1.

    Returns:
        Tuple ``(model_pts, display_pts, info)``:
        - ``model_pts`` and ``display_pts`` are independent (N, 3) float32
          arrays holding the same sampled points.
        - ``info`` is a human-readable summary.
        On failure both arrays are empty (0, 3) and ``info`` names the problem.
    """

    def _failure(message):
        empty = np.zeros((0, 3), dtype=np.float32)
        return empty, empty.copy(), message

    if file_obj is None:
        return _failure("OBJ: no file")

    raw = file_obj.read()
    try:
        mesh = trimesh.load(io.BytesIO(raw), file_type="obj", force="mesh")
        if isinstance(mesh, trimesh.Scene):
            geometries = tuple(mesh.geometry.values())
            if not geometries:
                return _failure("OBJ: empty mesh")
            mesh = trimesh.util.concatenate(geometries)
    except Exception as e:
        print(f"trimesh load error: {e}")
        return _failure("OBJ load failure")

    if mesh.is_empty or mesh.vertices.shape[0] == 0:
        return _failure("OBJ: empty mesh")

    # At least one sample. A negative count would also turn the slice below
    # into "drop the last k vertices".
    sample_n = max(1, int(display_max_points))
    try:
        sampled = mesh.sample(sample_n)
    except Exception as e:
        print(f"Sampling error: {e}")
        sampled = mesh.vertices
    sampled = np.asarray(sampled, dtype=np.float32).reshape(-1, 3)
    if sampled.shape[0] == 0:
        # Nothing could be sampled (presumably a mesh without face area):
        # fall back to the raw vertices.
        sampled = np.asarray(mesh.vertices, dtype=np.float32).reshape(-1, 3)
    sampled = sampled[:sample_n]

    num_faces = len(mesh.faces) if mesh.faces is not None else 0
    info = f"OBJ: {mesh.vertices.shape[0]} verts, {num_faces} tris | Surface sampled: {sampled.shape[0]} pts"
    return sampled.copy(), sampled, info
|
|
|
|
def downsample(pts: np.ndarray, max_points: int) -> np.ndarray:
    """Randomly keep at most ``max_points`` rows of ``pts``.

    The subset is drawn without replacement from a generator seeded with 42,
    so the same input always yields the same subset. Inputs that already fit
    are returned unchanged (the same object, not a copy).
    """
    total = pts.shape[0]
    if total > max_points:
        picker = np.random.default_rng(42)
        keep = picker.choice(total, size=max_points, replace=False)
        pts = pts[keep]
    return pts
|
|
|
|
# Default line color per curve class; a curve dict's "display_color" overrides it.
_CURVE_COLORS = {
    "Line": "blue",
    "Circle": "green",
    "Arc": "red",
    "BSpline": "purple",
}


def _finite_xyz(values) -> Optional[np.ndarray]:
    """Return the all-finite rows of ``values`` as an (N, 3) float array, or None if none exist."""
    arr = np.asarray(values, dtype=np.float64)
    if arr.ndim != 2 or arr.shape[0] == 0 or arr.shape[1] < 3:
        return None
    arr = arr[:, :3]
    arr = arr[np.isfinite(arr).all(axis=1)]
    return arr if arr.shape[0] > 0 else None


def _cube_axis_ranges(pts: np.ndarray, polylines: Optional[List[Dict]]) -> List[List[float]]:
    """Return equal-length [low, high] ranges for x, y and z centred on the data.

    The bounding box is taken from:
    1. the finite points of ``pts``, if there are any;
    2. otherwise the finite points of all polylines;
    3. otherwise [-1, 1] on every axis.
    Non-finite coordinates are ignored so a stray NaN cannot make the ranges NaN.
    """
    box = _finite_xyz(pts) if pts.size > 0 else None
    if box is None and polylines:
        parts = [_finite_xyz(curve["points"]) for curve in polylines]
        parts = [p for p in parts if p is not None]
        if parts:
            box = np.vstack(parts)

    if box is None:
        mins = np.full(3, -1.0)
        maxs = np.full(3, 1.0)
    else:
        mins = box.min(axis=0)
        maxs = box.max(axis=0)

    centers = (mins + maxs) / 2.0
    span = float((maxs - mins).max())
    if span <= 0:
        # A single point (or all points coincide): fall back to a unit cube.
        span = 1.0
    half = span / 2.0
    return [[float(c - half), float(c + half)] for c in centers]


def _curve_trace(curve: Dict) -> Optional[go.Scatter3d]:
    """Build the line trace for one predicted curve, or None if it has fewer than two points."""
    points = np.asarray(curve["points"])
    if len(points) < 2:
        return None

    curve_type = curve["type"]
    curve_id = curve["id"]
    score = curve["score"]
    color = curve.get("display_color") or _CURVE_COLORS.get(curve_type, "orange")
    return go.Scatter3d(
        x=points[:, 0],
        y=points[:, 1],
        z=points[:, 2],
        mode="lines",
        line=dict(color=color, width=8),
        name=f"{curve_type} #{curve_id} ({score:.2f})",
        # Callers set "legendonly" to hide curves below their class threshold.
        visible=curve.get("visible_state", True),
        hoverinfo="text",
        text=f"{curve_type} #{curve_id} ({score:.4f})",
        showlegend=False,
    )


def make_figure(
    pts: np.ndarray,
    point_size: int = 2,
    show_axes: bool = True,
    title: str = "",
    polylines: Optional[List[Dict]] = None,
) -> go.Figure:
    """Build a Plotly 3D figure of a point cloud plus optional predicted curves.

    Args:
        pts: (N, 3) point array drawn as semi-transparent gray markers; may be empty.
        point_size: Marker size; values below 1 are clamped to 1.
        show_axes: Whether the x/y/z axes and their titles are visible.
        title: Figure title.
        polylines: Optional curve dicts. Required keys are ``"points"``
            ((M, 3) array-like), ``"type"``, ``"id"`` and ``"score"``. The
            optional keys ``"display_color"`` and ``"visible_state"`` override
            the line color and the Plotly ``visible`` flag. Curves with fewer
            than two points are skipped.

    Returns:
        A ``go.Figure`` whose scene uses equal-length ("cube") axis ranges so
        the geometry is not distorted. With no points and no curves, a blank
        "No data to display" figure of the same height is returned.
    """
    if pts.size == 0 and not polylines:
        fig = go.Figure()
        fig.update_layout(
            title="No data to display",
            template=FIG_TEMPLATE,
            scene=dict(
                xaxis_visible=False,
                yaxis_visible=False,
                zaxis_visible=False,
            ),
            margin=dict(l=0, r=0, t=40, b=0),
            # Same height as populated figures so the plot does not jump in size.
            height=PLOT_HEIGHT,
        )
        return fig

    fig = go.Figure()

    if pts.size > 0:
        fig.add_trace(
            go.Scatter3d(
                x=pts[:, 0],
                y=pts[:, 1],
                z=pts[:, 2],
                mode="markers",
                marker=dict(size=max(1, int(point_size)), color="darkgray", opacity=0.2),
                hoverinfo="skip",
                name="Point cloud",
                showlegend=False,
            )
        )

    for curve in polylines or []:
        trace = _curve_trace(curve)
        if trace is not None:
            fig.add_trace(trace)

    xrange, yrange, zrange = _cube_axis_ranges(pts, polylines)
    scene_axes = dict(
        xaxis=dict(range=xrange, visible=show_axes, title="x" if show_axes else ""),
        yaxis=dict(range=yrange, visible=show_axes, title="y" if show_axes else ""),
        zaxis=dict(range=zrange, visible=show_axes, title="z" if show_axes else ""),
        aspectmode="cube",
    )

    fig.update_layout(
        title=title,
        template=FIG_TEMPLATE,
        showlegend=False,
        scene=scene_axes,
        margin=dict(l=0, r=0, t=40, b=0),
        hoverlabel=dict(font=dict(size=HOVER_FONT_SIZE)),
        height=PLOT_HEIGHT,
    )
    return fig
|
|
|
|
def process_model_predictions(data: Data) -> list:
    """Turn one PI3DETR prediction into a list of curve dicts for plotting.

    Args:
        data: Prediction exposing ``polylines`` (a tensor moved to CPU and
            converted to numpy here), ``polyline_class`` and
            ``polyline_score``, indexed in parallel with ``polylines``.

    Returns:
        One dict per polyline whose class index is not 0 ("None"). Keys:
        - ``type``: class name;
        - ``id``: 1-based polyline index;
        - ``index``: 0-based polyline index;
        - ``score``: ``.item()`` of the score;
        - ``points``: the numpy polyline.
    """
    class_names = ("None", "BSpline", "Line", "Circle", "Arc")
    curves = []
    for idx, polyline in enumerate(data.polylines.cpu().numpy()):
        cls = data.polyline_class[idx].item()
        score = data.polyline_score[idx].item()
        cls_name = class_names[cls]
        if cls != 0:
            curves.append(
                {
                    "type": cls_name,
                    "id": idx + 1,
                    "index": idx,
                    "score": score,
                    "points": polyline,
                }
            )
    return curves
|
|
|
|
def process_data_for_model(
    points: np.ndarray,
    sample: int = 32768,
    sample_mode: str = "fps",
) -> Data:
    """Subsample a point cloud and wrap it as a normalized single-sample model input.

    Follows the preprocessing of ``predict_pi3detr.py``: optional
    subsampling, a single-sample ``batch`` vector, then ``normalize_and_scale``.

    Args:
        points: (N, 3) array of xyz coordinates.
        sample: Target number of points; clouds with at most this many points
            are passed through unchanged.
        sample_mode: One of:
            - "fps": farthest point sampling via ``fpsample``;
            - "random": random subset;
            - "uniform": strided subset;
            - "all": no subsampling.

    Returns:
        ``Data`` with ``pos``, ``batch`` and ``batch_size`` set, normalized by
        ``normalize_and_scale``. ``scale`` and ``center`` (when present) are
        given a leading batch dimension.

    Raises:
        ValueError: If ``sample_mode`` is unknown or ``points`` is empty.
    """
    if sample_mode not in ("fps", "random", "uniform", "all"):
        raise ValueError(f"Unknown sample_mode: {sample_mode!r}")

    points = np.ascontiguousarray(points, dtype=np.float32)
    if points.ndim != 2 or points.shape[0] == 0:
        raise ValueError("process_data_for_model needs a non-empty (N, 3) point array")

    # Slider values may arrive as floats; slicing and fpsample need an int.
    sample = int(sample)
    # torch.tensor copies, so normalization never touches the caller's array.
    pos = torch.tensor(points, dtype=torch.float32)
    num_points = pos.size(0)

    if sample_mode != "all" and num_points > sample:
        if sample_mode == "random":
            pos = pos[torch.randperm(num_points)[:sample]]
        elif sample_mode == "fps":
            # fpsample operates on numpy arrays. Converting its indices to int64
            # keeps torch indexing independent of the index dtype it returns.
            indices = fpsample.bucket_fps_kdline_sampling(points, sample, h=6)
            pos = pos[torch.as_tensor(np.asarray(indices, dtype=np.int64))]
        else:  # "uniform"
            step = max(1, num_points // sample)
            pos = pos[::step][:sample]

    data = Data(pos=pos)

    # Every point belongs to batch element 0.
    data.batch = torch.zeros(data.pos.size(0), dtype=torch.long)
    data.batch_size = 1

    data = normalize_and_scale(data)

    # Give scale/center a leading batch dimension when they come back unbatched.
    if hasattr(data, "scale") and data.scale.dim() == 0:
        data.scale = data.scale.unsqueeze(0)
    if hasattr(data, "center") and data.center.dim() == 1:
        data.center = data.center.unsqueeze(0)

    return data
|
|
|
|
@torch.no_grad()
def run_model_inference(
    model,
    points: np.ndarray,
    max_points: int = 32768,
    sample_mode: str = "fps",
    num_queries: int = 256,
) -> list:
    """Predict curves for a point cloud with PI3DETR (gradients disabled).

    Args:
        model: PI3DETR model; when None, the cached ``PI3DETR_MODEL`` is used.
        points: (N, 3) point cloud.
        max_points: Model input size forwarded to ``process_data_for_model``.
        sample_mode: Subsampling strategy forwarded to ``process_data_for_model``.
        num_queries: Decoder query count. ``set_num_preds`` is called when it
            differs from ``model.num_preds``.

    Returns:
        Curve dicts from ``process_model_predictions`` for the first (only)
        prediction of ``predict_step`` (called with ``reverse_norm=True``).
        An empty list is returned when no model is available, or when any
        step raises; in that case the error is printed.
    """
    net = PI3DETR_MODEL if model is None else model
    if net is None:
        return []

    try:
        batch = process_data_for_model(points, sample=max_points, sample_mode=sample_mode)
        device = next(net.parameters()).device
        batch = batch.to(device)

        if net.num_preds != num_queries:
            net.set_num_preds(num_queries)

        predictions = net.predict_step(batch, reverse_norm=True, thresholds=None)
        return process_model_predictions(predictions[0])
    except Exception as e:
        print(f"Error in model inference: {e}")
        return []
|
|
|
|
def load_and_process_pointcloud(
    file: gr.File,
    max_points: int,
    point_size: int,
    show_axes: bool,
):
    """Read an uploaded .xyz, .ply or .obj file and render it.

    Args:
        file: Value of the upload component. This is either a path string or
            an object whose ``.name`` attribute is the path; None when the
            upload is cleared.
        max_points: Display budget. Point files are randomly downsampled to
            it; meshes are surface-sampled with it.
        point_size: Marker size for the figure.
        show_axes: Whether axes are shown.

    Returns:
        Tuple ``(figure, model_pts, display_pts, file_name)``, matching the
        four outputs wired to the upload's change event:
        - ``model_pts``: the full cloud (or the surface samples for meshes);
        - ``display_pts``: the points actually drawn;
        - ``file_name``: the upload's base name.
        For no file or an unsupported extension, an empty figure, ``None``
        arrays and an empty name are returned.
    """
    if file is None:
        return make_figure(np.zeros((0, 3))), None, None, ""

    # Accept both plain path strings and file wrappers exposing ``.name``.
    path = getattr(file, "name", file)
    file_name = os.path.basename(path)
    file_ext = os.path.splitext(path)[1].lower()

    if file_ext not in (".xyz", ".ply", ".obj"):
        print(f"Unsupported file type {file_ext!r}. Please use .xyz, .ply or .obj.")
        return make_figure(np.zeros((0, 3))), None, None, ""

    with open(path, "rb") as f:
        if file_ext == ".obj":
            model_pts, display_pts, _ = read_obj_and_sample(f, max_points)
        else:
            pts = read_xyz(f) if file_ext == ".xyz" else read_ply(f)
            # Keep an independent full-resolution copy for the model.
            model_pts = pts.copy()
            display_pts = downsample(pts, max_points=max_points)

    fig = make_figure(
        display_pts,
        point_size=point_size,
        show_axes=show_axes,
        title=file_name,
    )
    return fig, model_pts, display_pts, file_name
|
|
|
|
def run_model_prediction(
    model_pts: np.ndarray,
    point_size: int,
    show_axes: bool,
    model_max_points: int,
    sample_mode: str,
    th_bspline: float,
    th_line: float,
    th_circle: float,
    th_arc: float,
    num_queries: int = 256,
):
    """Run inference with no precomputed display points and an empty file name.

    Thin convenience wrapper over ``run_model_prediction_unified``.

    Returns:
        Whatever ``run_model_prediction_unified`` returns: ``(figure, curves)``.
    """
    return run_model_prediction_unified(
        model_pts=model_pts,
        display_pts=None,
        point_size=point_size,
        show_axes=show_axes,
        model_max_points=model_max_points,
        sample_mode=sample_mode,
        th_bspline=th_bspline,
        th_line=th_line,
        th_circle=th_circle,
        th_arc=th_arc,
        file_name="",
        num_queries=num_queries,
    )
|
|
|
|
def run_model_prediction_unified(
    model_pts: np.ndarray,
    display_pts: Optional[np.ndarray],
    point_size: int,
    show_axes: bool,
    model_max_points: int,
    sample_mode: str,
    th_bspline: float,
    th_line: float,
    th_circle: float,
    th_arc: float,
    file_name: str = "",
    num_queries: int = 256,
):
    """Run PI3DETR on ``model_pts`` and plot the detections with per-class thresholds.

    The model is loaded lazily when it is not cached and no successful load
    has been recorded. Curves scoring below their class threshold (0.7 for
    unknown classes) are drawn with ``visible="legendonly"``, i.e. hidden.

    Args:
        model_pts: Full-resolution (N, 3) input cloud, or None.
        display_pts: Points to draw; when None, ``model_pts`` downsampled to
            100000 points is used.
        point_size: Marker size for the figure.
        show_axes: Whether axes are shown.
        model_max_points: Model input size forwarded to inference.
        sample_mode: Subsampling strategy forwarded to inference.
        th_bspline: Score threshold for BSpline curves.
        th_line: Score threshold for Line curves.
        th_circle: Score threshold for Circle curves.
        th_arc: Score threshold for Arc curves.
        file_name: Used in the figure title.
        num_queries: Decoder query count.

    Returns:
        ``(figure, curves)``. ``curves`` is the unfiltered list of detections
        (empty when the model is unavailable or inference failed).
    """
    if model_pts is None:
        return make_figure(np.zeros((0, 3))), []

    curves = []
    try:
        if PI3DETR_MODEL is None and not MODEL_STATUS["loaded"]:
            # Lazy (re)load: covers imports without __main__ and failed startup loads.
            initialize_model()

        if PI3DETR_MODEL is not None:
            curves = run_model_inference(
                PI3DETR_MODEL,
                model_pts,
                max_points=model_max_points,
                sample_mode=sample_mode,
                num_queries=num_queries,
            )
    except Exception as e:
        # Keep the UI usable: report the error and show the bare point cloud.
        print(f"Model prediction failed: {e}")
        curves = []

    thresholds = {
        "BSpline": th_bspline,
        "Line": th_line,
        "Circle": th_circle,
        "Arc": th_arc,
    }
    colored_curves = []
    for c in curves:
        c_disp = dict(c)
        if c["score"] < thresholds.get(c["type"], 0.7):
            c_disp["visible_state"] = "legendonly"
        colored_curves.append(c_disp)

    if display_pts is None:
        display_pts = downsample(model_pts, max_points=100000)
    title = f"{file_name} (curves)" if curves else f"{file_name} (no curves)"
    fig = make_figure(
        display_pts,
        point_size=point_size,
        show_axes=show_axes,
        title=title,
        polylines=colored_curves,
    )
    return fig, curves
|
|
|
|
def apply_pointcloud_display_settings(
    model_pts: np.ndarray,
    curves: List[Dict],
    max_points: int,
    point_size: int,
    show_axes: bool,
    th_bspline: float,
    th_line: float,
    th_circle: float,
    th_arc: float,
    file_name: str,
):
    """Redraw the cloud with new display settings, reusing already detected curves.

    No inference is run. ``model_pts`` is downsampled to ``max_points``. Any
    curve scoring below its class threshold (0.7 for unknown classes) is drawn
    hidden (``visible_state="legendonly"``).

    Returns:
        ``(figure, display_pts)``, or ``(empty_figure, None)`` when
        ``model_pts`` is None.
    """
    if model_pts is None:
        return make_figure(np.zeros((0, 3))), None

    display_pts = downsample(model_pts, max_points=max_points)
    base_title = file_name or "Point Cloud"

    if curves:
        cutoffs = {
            "BSpline": th_bspline,
            "Line": th_line,
            "Circle": th_circle,
            "Arc": th_arc,
        }
        styled = []
        for curve in curves:
            entry = dict(curve)
            if curve["score"] < cutoffs.get(curve["type"], 0.7):
                entry["visible_state"] = "legendonly"
            styled.append(entry)
        fig = make_figure(
            display_pts,
            point_size=point_size,
            show_axes=show_axes,
            title=base_title + " (curves)",
            polylines=styled,
        )
    else:
        fig = make_figure(
            display_pts,
            point_size=point_size,
            show_axes=show_axes,
            title=base_title,
        )
    return fig, display_pts
|
|
|
|
def clear_curves(
    curves: List[Dict],
    display_pts: Optional[np.ndarray],
    model_pts: Optional[np.ndarray],
    point_size: int,
    show_axes: bool,
    file_name: str,
):
    """Remove detected curves from the plot and reset the stored curve state.

    The figure always uses the current point size, axes setting and title,
    whether or not curves were present. ``curves`` and ``model_pts`` are
    accepted for wiring compatibility but are not needed to redraw.

    Returns:
        ``(figure, None)``. The figure shows ``display_pts``, or is empty when
        it is None; ``None`` clears the curves state.
    """
    pts = display_pts if display_pts is not None else np.zeros((0, 3))
    fig = make_figure(
        pts,
        point_size=point_size,
        show_axes=show_axes,
        title=file_name or "Point Cloud",
        polylines=None,
    )
    return fig, None
|
|
|
|
def load_demo_pointcloud(
    label: str,
    max_points: int,
    point_size: int,
    show_axes: bool,
):
    """Load one of the bundled demo point clouds and render it.

    Args:
        label: Key into ``DEMO_POINTCLOUDS`` (e.g. "Demo 1").
        max_points: Display budget. Point files are downsampled to it; meshes
            are surface-sampled with at most ``min(20000, max_points)`` points.
        point_size: Marker size for the figure.
        show_axes: Whether axes are shown.

    Returns:
        Tuple ``(figure, model_pts, display_pts, file_name, curves, file_value)``.
        ``curves`` is always None, which clears previously detected curves.
        ``file_value`` is the demo path, so the upload component shows the file.
        Missing, unsupported or unreadable demos yield an empty figure with
        None/empty values; read errors are printed.
    """

    def _failure():
        return make_figure(np.zeros((0, 3))), None, None, "", None, None

    path = DEMO_POINTCLOUDS.get(label, "")
    if not path or not os.path.isfile(path):
        return _failure()

    name = os.path.basename(path)
    ext = os.path.splitext(path)[1].lower()
    try:
        with open(path, "rb") as f:
            if ext == ".obj":
                model_pts, display_pts, _ = read_obj_and_sample(f, min(20000, max_points))
            elif ext in (".xyz", ".ply"):
                pts = read_xyz(f) if ext == ".xyz" else read_ply(f)
                model_pts = pts.copy()
                display_pts = downsample(pts, max_points=max_points)
            else:
                return _failure()
    except Exception as e:
        print(f"Error loading demo {label!r}: {e}")
        return _failure()

    fig = make_figure(
        display_pts,
        point_size=point_size,
        show_axes=show_axes,
        title=f"{name} (demo)",
    )
    return fig, model_pts, display_pts, name, None, path
|
|
|
|
| |
# One named loader per demo image. Each is a plain function so Gradio gets a
# distinct, named callable per event.
def load_demo1(max_points, point_size, show_axes):
    """Load "Demo 1" via load_demo_pointcloud."""
    return load_demo_pointcloud(
        label="Demo 1", max_points=max_points, point_size=point_size, show_axes=show_axes
    )


def load_demo2(max_points, point_size, show_axes):
    """Load "Demo 2" via load_demo_pointcloud."""
    return load_demo_pointcloud(
        label="Demo 2", max_points=max_points, point_size=point_size, show_axes=show_axes
    )


def load_demo3(max_points, point_size, show_axes):
    """Load "Demo 3" via load_demo_pointcloud."""
    return load_demo_pointcloud(
        label="Demo 3", max_points=max_points, point_size=point_size, show_axes=show_axes
    )


def load_demo4(max_points, point_size, show_axes):
    """Load "Demo 4" via load_demo_pointcloud."""
    return load_demo_pointcloud(
        label="Demo 4", max_points=max_points, point_size=point_size, show_axes=show_axes
    )


def load_demo5(max_points, point_size, show_axes):
    """Load "Demo 5" via load_demo_pointcloud."""
    return load_demo_pointcloud(
        label="Demo 5", max_points=max_points, point_size=point_size, show_axes=show_axes
    )
|
|
|
|
def build_demo_preview(label: str, max_pts: int = 20000) -> go.Figure:
    """Render an axis-free preview (no curves) of a demo point cloud.

    Meshes are surface-sampled and point files downsampled to ``max_pts``.
    A placeholder figure is returned when the demo file is missing, has an
    unsupported extension, or fails to load.
    """
    path = DEMO_POINTCLOUDS.get(label, "")
    if not path or not os.path.isfile(path):
        return make_figure(np.zeros((0, 3)), title=f"{label}: (missing)")

    readers = {
        ".xyz": read_xyz,
        ".ply": read_ply,
        ".obj": lambda handle: read_obj_and_sample(handle, max_pts)[1],
    }
    try:
        ext = os.path.splitext(path)[1].lower()
        with open(path, "rb") as handle:
            reader = readers.get(ext)
            if reader is None:
                return make_figure(np.zeros((0, 3)), title=f"{label}: (unsupported)")
            cloud = reader(handle)
        cloud = downsample(cloud, max_pts)
        return make_figure(cloud, point_size=1, show_axes=False, title=f"{label} preview")
    except Exception:
        return make_figure(np.zeros((0, 3)), title=f"{label}: error")
|
|
|
|
def run_model_with_display(
    model_pts: np.ndarray,
    max_points: int,
    point_size: int,
    show_axes: bool,
    model_max_points: int,
    sample_mode: str,
    th_bspline: float,
    th_line: float,
    th_circle: float,
    th_arc: float,
    file_name: str = "",
    num_queries: int = 256,
):
    """Run PI3DETR, then render with the current display settings and thresholds.

    Inference goes through ``run_model_prediction_unified`` (lazy model
    loading, sampling, query count); its own figure is discarded. The final
    figure is rebuilt by ``apply_pointcloud_display_settings``, so the
    "Max points (display)" budget is respected.

    Returns:
        ``(figure, curves, display_pts)``, matching the outputs ``main_plot``,
        ``curves_state`` and ``display_pts_state``. When ``model_pts`` is None,
        returns an empty figure and two ``None`` values.
    """
    if model_pts is None:
        return make_figure(np.zeros((0, 3))), None, None

    _, curves = run_model_prediction_unified(
        model_pts,
        None,
        point_size,
        show_axes,
        model_max_points,
        sample_mode,
        th_bspline,
        th_line,
        th_circle,
        th_arc,
        file_name,
        num_queries,
    )

    fig_final, display_pts = apply_pointcloud_display_settings(
        model_pts,
        curves,
        max_points,
        point_size,
        show_axes,
        th_bspline,
        th_line,
        th_circle,
        th_arc,
        file_name,
    )
    return fig_final, curves, display_pts
|
|
|
|
def _reset_curves():
    """Return None to clear the detected-curves state."""
    return None


# Gradio UI: layout (header, notes, upload, settings + plot, demo gallery)
# followed by event wiring.
with gr.Blocks(title="PI3DETR") as demo:
    gr.Markdown(
        """
        # 🥧 PI3DETR: Detection of Sharp 3D CAD Edges [CPU-PREVIEW]

        A novel end-to-end deep learning model for **parametric curve inference** in **3D point clouds** and **meshes**.

        <div style="margin-top: 10px;">
        <a href="https://arxiv.org/pdf/2509.03262" target="_blank" style="
            display: inline-block;
            background-color: #4CAF50;
            color: white;
            padding: 8px 16px;
            text-decoration: none;
            border-radius: 5px;
            margin-right: 8px;
            font-weight: bold;
        ">📄 Paper</a>
        <a href="https://fafraob.github.io/pi3detr/" target="_blank" style="
            display: inline-block;
            background-color: #2196F3;
            color: white;
            padding: 8px 16px;
            text-decoration: none;
            border-radius: 5px;
            margin-right: 8px;
            font-weight: bold;
        ">🌐 Website</a>
        <a href="https://github.com/fafraob/pi3detr" target="_blank" style="
            display: inline-block;
            background-color: #333;
            color: white;
            padding: 8px 16px;
            text-decoration: none;
            border-radius: 5px;
            font-weight: bold;
        ">🐙 GitHub</a>
        </div>
        """
    )

    with gr.Row():
        with gr.Column():
            gr.Markdown(
                "### 🧩 Supported Inputs\n"
                "- **Point Clouds:** `.xyz`, `.ply`; **Meshes:** `.obj`\n"
                "- `Mesh` is surface-sampled using **Max Points (display)** slider."
            )
        with gr.Column():
            gr.Markdown(
                "### ⚙️ Point Cloud Settings\n"
                "- Adjust **Max Points**, **point size**, and **axes visibility**.\n"
                "- Controls visualization of point cloud."
            )
        with gr.Column():
            gr.Markdown(
                "### 🎯 Confidence Thresholds\n"
                "- Hover to inspect scores.\n"
                "- Filter curves by **class confidence** interactively"
            )
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                "### 🧠 Model Settings\n"
                "- **Sampling Mode:** Choose downsampling strategy.\n"
                "- **Model Input Size:** Number of model input points.\n"
                "- **Queries:** Transformer decoder queries (max. output curves)."
            )
        with gr.Column():
            gr.Markdown(
                "### ⚡ Performance Notes\n"
                "- Trained on **human-made objects**.\n"
                "- Optimized for **GPU**; this demo runs on **CPU**.\n"
                "- For **full qualitative performance**: \n"
                "[GitHub → PI3DETR](https://github.com/fafraob/pi3detr)"
            )
        with gr.Column():
            gr.Markdown(
                "### ▶️ Run Inference\n"
                "- Click on demo point clouds (from test set) below.\n"
                "- Press **Run PI3DETR** to execute inference and visualize results."
            )

    # Per-session state: full-resolution input, drawn subset, detections, name.
    model_pts_state = gr.State(None)
    display_pts_state = gr.State(None)
    curves_state = gr.State(None)
    file_name_state = gr.State("demo_inputs/demo2.xyz")

    with gr.Row():
        file_in = gr.File(
            label="Upload Point Cloud (auto-renders)",
            file_types=[".xyz", ".ply", ".obj"],
            type="filepath",
        )
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Point Cloud Settings")
            max_points = gr.Slider(
                0,
                500_000,
                value=200_000,
                step=1_000,
                label="Max points (display)",
            )
            point_size = gr.Slider(1, 8, value=1, step=1, label="Point size")
            show_axes = gr.Checkbox(value=False, label="Show axes")

            gr.Markdown("### Model Settings")
            sample_mode = gr.Radio(
                ["fps", "random", "all"],
                value="fps",
                label="Main Sampling Method",
            )
            model_max_points = gr.Slider(
                1_000,
                100_000,
                value=32768,
                step=500,
                label="Downsample to Model Input Size",
            )
            num_queries = gr.Slider(
                32,
                512,
                value=128,
                step=1,
                label="Number of Queries",
            )

            gr.Markdown("#### Confidence Thresholds (per class)")
            th_bspline = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="BSpline ≥")
            th_line = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Line ≥")
            th_circle = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Circle ≥")
            th_arc = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Arc ≥")

        with gr.Column(scale=1):
            main_plot = gr.Plot(label="Point Cloud & Curves")
            run_model_button = gr.Button("Run PI3DETR", variant="primary")
            clear_curves_button = gr.Button("Clear Curves", variant="secondary")

    with gr.Row():
        gr.Markdown("### Demo Point Clouds (click an image to load)")
    with gr.Row():
        demo_image_components = {}
        for label in ["Demo 1", "Demo 2", "Demo 3", "Demo 4", "Demo 5"]:
            png_path = f"demo_inputs/{label.lower().replace(' ', '')}.png"
            demo_image_components[label] = gr.Image(
                value=png_path if os.path.isfile(png_path) else None,
                label=label,
                interactive=False,
            )

    # ---- Event wiring ----

    # Upload (or a demo setting file_in) re-renders the cloud. Curves detected
    # for the previous cloud no longer apply, so clear them afterwards;
    # otherwise the next slider release would overlay stale curves.
    file_in.change(
        load_and_process_pointcloud,
        inputs=[file_in, max_points, point_size, show_axes],
        outputs=[main_plot, model_pts_state, display_pts_state, file_name_state],
    ).then(_reset_curves, inputs=None, outputs=curves_state)

    run_model_button.click(
        run_model_with_display,
        inputs=[
            model_pts_state,
            max_points,
            point_size,
            show_axes,
            model_max_points,
            sample_mode,
            th_bspline,
            th_line,
            th_circle,
            th_arc,
            file_name_state,
            num_queries,
        ],
        outputs=[main_plot, curves_state, display_pts_state],
    )

    # Display-only changes redraw from stored state without re-running the model.
    display_inputs = [
        model_pts_state,
        curves_state,
        max_points,
        point_size,
        show_axes,
        th_bspline,
        th_line,
        th_circle,
        th_arc,
        file_name_state,
    ]
    display_outputs = [main_plot, display_pts_state]
    for control in [max_points, point_size, th_bspline, th_line, th_circle, th_arc]:
        control.release(
            apply_pointcloud_display_settings,
            inputs=display_inputs,
            outputs=display_outputs,
        )
    show_axes.change(
        apply_pointcloud_display_settings,
        inputs=display_inputs,
        outputs=display_outputs,
    )

    clear_curves_button.click(
        clear_curves,
        inputs=[
            curves_state,
            display_pts_state,
            model_pts_state,
            point_size,
            show_axes,
            file_name_state,
        ],
        outputs=[main_plot, curves_state],
    )

    demo_outputs = [
        main_plot,
        model_pts_state,
        display_pts_state,
        file_name_state,
        curves_state,
        file_in,
    ]
    _demo_loaders = {
        "Demo 1": load_demo1,
        "Demo 2": load_demo2,
        "Demo 3": load_demo3,
        "Demo 4": load_demo4,
        "Demo 5": load_demo5,
    }
    for label, comp in demo_image_components.items():
        comp.select(
            _demo_loaders[label],
            inputs=[max_points, point_size, show_axes],
            outputs=demo_outputs,
        )

    # Show Demo 2 when the page first loads.
    demo.load(
        load_demo2,
        inputs=[max_points, point_size, show_axes],
        outputs=demo_outputs,
    )
|
|
|
|
def _launch_app() -> None:
    """Load PI3DETR eagerly, then start the Gradio app."""
    # If this load fails the UI still starts: run_model_prediction_unified
    # retries initialize_model() on the first inference request.
    initialize_model()
    demo.launch()


if __name__ == "__main__":
    _launch_app()
|
|