"""Gradio app: single image -> MoGe point cloud -> PLY download + Poisson-meshed GLB download."""

import io
import os
import tempfile
from typing import Optional, Tuple

import cv2  # noqa: F401  (not used below; kept so existing environments/imports are unchanged)
import gradio as gr
import numpy as np
import open3d as o3d
import torch
import trimesh

from moge.model.v2 import MoGeModel

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Keys probed, in order, when the model output has no combined "pcd" entry.
_POINT_KEYS = ("points", "point_cloud")
_COLOR_KEYS = ("colors", "rgb", "point_colors")

# Minimum number of valid points required before anything is exported.
_MIN_POINTS = 100


@torch.no_grad()
def load_model() -> MoGeModel:
    """Load the pretrained "Ruicheng/moge-2-vitl-normal" model onto DEVICE in eval mode."""
    print(f"Loading MoGe model on device: {DEVICE}")
    model = MoGeModel.from_pretrained("Ruicheng/moge-2-vitl-normal")
    model = model.to(DEVICE)
    model.eval()
    return model


# Loaded once at import time so every request reuses the same weights.
MODEL = load_model()


def _tensor_to_numpy(value) -> np.ndarray:
    """Return ``value`` (torch tensor or array-like) as a float32 NumPy array on the CPU."""
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().float().numpy()
    return np.asarray(value, dtype=np.float32)


def _flatten_channels(arr: np.ndarray) -> np.ndarray:
    """Collapse all leading axes so the result is (N, C), where C is the last axis."""
    return arr.reshape(-1, arr.shape[-1])


def _colors_to_uint8(cols: np.ndarray) -> np.ndarray:
    """Convert a (..., >=3) color array to (N, 3) uint8 RGB.

    Arrays whose maximum is <= 1.0 are treated as normalized [0, 1] colors and
    scaled by 255. Values are rounded and clipped to [0, 255] so out-of-range
    inputs saturate instead of wrapping around on the uint8 cast.
    """
    cols = _flatten_channels(np.asarray(cols, dtype=np.float32))[:, :3]
    if cols.size and cols.max() <= 1.0:
        cols = cols * 255.0
    return np.clip(np.rint(cols), 0, 255).astype(np.uint8)


def _extract_points_and_colors(
    out: dict,
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
    """Find the point cloud (and per-point colors, if any) in a model output dict.

    A combined ``"pcd"`` entry (XYZ in the first three channels, RGB in
    channels 3-5 when present) is preferred; otherwise the first key found in
    ``_POINT_KEYS`` supplies points and the first key in ``_COLOR_KEYS``
    supplies colors.

    Returns:
        ``(points, colors)``: ``points`` is (N, 3) float32, or ``None`` when no
        point entry exists; ``colors`` is (M, 3) uint8 or ``None``. M is not
        guaranteed to equal N - the caller checks.
    """
    pcd = out.get("pcd")
    if pcd is not None and pcd.ndim == 3 and pcd.shape[-1] >= 3:
        pcd_np = _flatten_channels(_tensor_to_numpy(pcd))
        colors = _colors_to_uint8(pcd_np[:, 3:6]) if pcd_np.shape[1] >= 6 else None
        return pcd_np[:, :3], colors

    pts = next((out[key] for key in _POINT_KEYS if key in out), None)
    if pts is None:
        return None, None
    points = _tensor_to_numpy(pts).reshape(-1, 3)

    col = next((out[key] for key in _COLOR_KEYS if key in out), None)
    colors = None if col is None else _colors_to_uint8(_tensor_to_numpy(col))
    return points, colors


def _valid_point_mask(out: dict, points: np.ndarray) -> np.ndarray:
    """Return an (N,) bool mask of points that are finite and not masked out.

    If the output has a ``"mask"`` entry with exactly one value per point, it
    is combined with the finiteness test; a mask of any other size is ignored.
    """
    valid = np.isfinite(points).all(axis=1)
    mask = out.get("mask")
    if mask is not None:
        mask_np = _tensor_to_numpy(mask).reshape(-1) > 0
        if mask_np.shape[0] == valid.shape[0]:
            valid &= mask_np
        else:
            print(
                f"Ignoring mask with {mask_np.shape[0]} entries "
                f"for {valid.shape[0]} points"
            )
    return valid


@torch.no_grad()
def run_moge_on_image(image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Run MoGe on one image and return its valid points with RGB colors.

    Args:
        image: HxWx3 RGB uint8 array. A 2-D grayscale image is replicated to
            three channels; any channels beyond the third (e.g. alpha) are dropped.

    Returns:
        points: (N, 3) float32 XYZ, keeping only finite (and mask-valid) points.
        colors: (N, 3) uint8 RGB aligned with ``points``. Colors come from the
            model output when it provides them. Otherwise they come from the input
            pixels when there is exactly one point per pixel. Failing both, they are white.

    Raises:
        RuntimeError: if no point cloud is found in the model output, or fewer
            than ``_MIN_POINTS`` valid points remain.
    """
    if image.ndim == 2:
        image = np.stack([image] * 3, axis=-1)
    image = image[..., :3]
    height, width = image.shape[:2]

    img = image.astype(np.float32) / 255.0
    tensor = (
        torch.from_numpy(img)
        .permute(2, 0, 1)
        .unsqueeze(0)
        .to(DEVICE)  # (1, 3, H, W)
    )
    out = MODEL.infer(tensor)
    print("MoGe output keys:", list(out.keys()))

    points, colors = _extract_points_and_colors(out)
    if points is None:
        raise RuntimeError(
            f"Could not find point cloud in MoGe output (keys: {list(out.keys())})"
        )
    points = points.reshape(-1, 3)
    n_total = points.shape[0]

    if colors is None and n_total == height * width:
        # One point per pixel: assumes the points were laid out as (H, W, 3),
        # so row-major point i corresponds to pixel i. TODO confirm for other models.
        colors = image.reshape(-1, 3).astype(np.uint8)

    if colors is None or colors.shape[0] != n_total:
        if colors is not None:
            print(
                f"Color count {colors.shape[0]} != point count {n_total}; "
                "falling back to white"
            )
        colors = np.full((n_total, 3), 255, dtype=np.uint8)

    # Drop non-finite / masked-out points so neither the PLY nor Poisson sees them.
    valid = _valid_point_mask(out, points)
    points = np.ascontiguousarray(points[valid], dtype=np.float32)
    colors = colors[valid]

    n = points.shape[0]
    print("MoGe point count:", n)
    if n < _MIN_POINTS:
        raise RuntimeError(f"Too few points (N={n}), refusing to export")
    return points, colors


def pointcloud_to_ply_bytes(points: np.ndarray, colors: np.ndarray) -> bytes:
    """Serialize a colored point cloud as an ASCII PLY document.

    Args:
        points: (N, 3) XYZ coordinates, written with 6 decimal places.
        colors: (N, 3) RGB values in 0-255, written as integers.

    Returns:
        The UTF-8 encoded PLY file contents.

    Raises:
        ValueError: if ``points`` and ``colors`` have different lengths.
    """
    points = np.asarray(points, dtype=np.float64).reshape(-1, 3)
    colors = np.asarray(colors).reshape(-1, 3)
    n = points.shape[0]
    if colors.shape[0] != n:
        raise ValueError(f"Got {n} points but {colors.shape[0]} colors")
    print("Writing PLY with", n, "points")

    header = (
        "ply\n"
        "format ascii 1.0\n"
        f"element vertex {n}\n"
        "property float x\n"
        "property float y\n"
        "property float z\n"
        "property uchar red\n"
        "property uchar green\n"
        "property uchar blue\n"
        "end_header\n"
    )
    buf = io.StringIO()
    buf.write(header)
    # One vectorized write instead of a per-point Python loop. The colors are
    # stacked as float64 copies of 0-255 integers, which "%d" prints exactly.
    rows = np.column_stack([points, colors.astype(np.float64)])
    np.savetxt(buf, rows, fmt="%.6f %.6f %.6f %d %d %d")
    return buf.getvalue().encode("utf-8")


def pointcloud_to_mesh_glb_bytes(
    points: np.ndarray,
    colors: np.ndarray,
    *,
    max_points: int = 50000,
    poisson_depth: int = 8,
    density_quantile: float = 0.05,
    seed: int = 0,
) -> bytes:
    """Mesh a colored point cloud with Poisson reconstruction and export it as GLB.

    The cloud is randomly subsampled to at most ``max_points`` points. Normals
    are then estimated from 30 nearest neighbours and oriented consistently.
    The Poisson surface is trimmed of its lowest-density vertices and
    cleaned. Each remaining mesh vertex takes the color of its nearest
    (subsampled) input point.

    Args:
        points: (N, 3) XYZ coordinates.
        colors: (N, 3) uint8 RGB colors aligned with ``points``.
        max_points: subsampling cap that bounds reconstruction time.
        poisson_depth: ``depth`` passed to Open3D Poisson reconstruction.
        density_quantile: vertices with density at or below this quantile are removed.
        seed: seed for the subsampling RNG, so the same input yields the same mesh.

    Returns:
        GLB file contents with per-vertex colors.

    Raises:
        RuntimeError: if reconstruction leaves no vertices or no faces.
    """
    print("Building mesh from point cloud for GLB export")

    # Downsample for speed (seeded so repeated requests are reproducible).
    if points.shape[0] > max_points:
        rng = np.random.default_rng(seed)
        idx = rng.choice(points.shape[0], max_points, replace=False)
        pts_ds, cols_ds = points[idx], colors[idx]
    else:
        pts_ds, cols_ds = points, colors
    cols_ds = np.asarray(cols_ds, dtype=np.uint8)

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(np.asarray(pts_ds, dtype=np.float64))
    pcd.colors = o3d.utility.Vector3dVector(cols_ds.astype(np.float64) / 255.0)

    print("Estimating normals...")
    pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamKNN(knn=30))
    # Orient normals consistently (helps Poisson).
    pcd.orient_normals_consistent_tangent_plane(orientation_k=30)

    print("Running Poisson reconstruction...")
    mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(
        pcd, depth=poisson_depth
    )

    # Remove low-density vertices (cleanup); skipped if Poisson produced nothing,
    # in which case the empty-mesh check below reports the failure.
    densities = np.asarray(densities)
    if densities.size:
        density_thresh = np.quantile(densities, density_quantile)
        mesh = mesh.select_by_index(np.flatnonzero(densities > density_thresh))
    mesh.remove_duplicated_vertices()
    mesh.remove_degenerate_triangles()
    mesh.remove_duplicated_triangles()
    mesh.remove_non_manifold_edges()

    verts = np.asarray(mesh.vertices)
    faces = np.asarray(mesh.triangles)
    print("Mesh verts:", verts.shape, "faces:", faces.shape)
    if verts.shape[0] == 0 or faces.shape[0] == 0:
        raise RuntimeError("Mesh reconstruction failed; got empty mesh")

    # Nearest-neighbour color transfer. Indexing the original uint8 colors
    # avoids the truncation of a [0, 1] -> *255 -> uint8 round trip.
    print("Transferring vertex colors...")
    pcd_tree = o3d.geometry.KDTreeFlann(pcd)
    nearest = np.array(
        [pcd_tree.search_knn_vector_3d(v, 1)[1][0] for v in verts], dtype=np.int64
    )
    vert_colors = cols_ds[nearest]  # (V, 3) uint8

    tm = trimesh.Trimesh(
        vertices=verts,
        faces=faces,
        vertex_colors=vert_colors,
        process=False,
    )
    glb_bytes = tm.export(file_type="glb")
    if isinstance(glb_bytes, str):
        glb_bytes = glb_bytes.encode("utf-8")
    return glb_bytes


def infer_and_export_files(image: np.ndarray) -> Tuple[str, str]:
    """Gradio handler: reconstruct ``image`` and write the PLY and GLB downloads.

    Each call writes into its own fresh temporary directory, so concurrent
    requests cannot overwrite each other's files. The directories are not
    deleted here.

    Returns:
        Paths to the written ``output.ply`` and ``output.glb``.

    Raises:
        gr.Error: if no image was uploaded, or reconstruction/meshing raised a
            RuntimeError (its message is passed through to the user).
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    try:
        points, colors = run_moge_on_image(image)
        ply_bytes = pointcloud_to_ply_bytes(points, colors)
        glb_bytes = pointcloud_to_mesh_glb_bytes(points, colors)
    except RuntimeError as err:
        raise gr.Error(str(err)) from err

    out_dir = tempfile.mkdtemp(prefix="moge_")
    ply_path = os.path.join(out_dir, "output.ply")
    glb_path = os.path.join(out_dir, "output.glb")
    with open(ply_path, "wb") as f:
        f.write(ply_bytes)
    with open(glb_path, "wb") as f:
        f.write(glb_bytes)
    return ply_path, glb_path


title = "MoGe 3D Reconstruction → PLY + GLB"
description = (
    "Upload an image. MoGe reconstructs a 3D point cloud, which is exported as PLY "
    "and meshed into a colored GLB suitable for Three.js."
)

demo = gr.Interface(
    fn=infer_and_export_files,
    inputs=gr.Image(type="numpy", label="Input image"),
    outputs=[
        gr.File(label="Download PLY (point cloud)"),
        gr.File(label="Download GLB (colored mesh)"),
    ],
    title=title,
    description=description,
)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)