Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| SAM 3D Objects MCP Server | |
| Image → 3D Object (GLB) | |
| Automatic object detection with SAM2 + 3D reconstruction with SAM 3D Objects. | |
| """ | |
| import os | |
| import sys | |
| import subprocess | |
| import tempfile | |
| import uuid | |
| from pathlib import Path | |
| import gradio as gr | |
| import numpy as np | |
| import spaces | |
| from huggingface_hub import snapshot_download, login | |
| from PIL import Image | |
# Authenticate with the Hugging Face Hub when HF_TOKEN is set (and non-empty).
if os.environ.get("HF_TOKEN"):
    login(token=os.environ["HF_TOKEN"])

# sam-3d-objects expects a conda-style environment, which is not used here;
# provide the two variables it looks for unless they are already defined.
os.environ.setdefault("CUDA_HOME", "/usr/local/cuda")
os.environ.setdefault("CONDA_PREFIX", "/usr/local")

# Fetch the sam-3d-objects sources the first time the app starts.
SAM3D_PATH = Path("/home/user/app/sam-3d-objects")
if not SAM3D_PATH.exists():
    print("Cloning sam-3d-objects repository...")
    subprocess.run(
        [
            "git",
            "clone",
            "https://github.com/facebookresearch/sam-3d-objects.git",
            str(SAM3D_PATH),
        ],
        check=True,
    )

# Put the notebook folder and the repo root at the front of sys.path
# (notebook folder first, then the repo root).
sys.path[:0] = [str(SAM3D_PATH / "notebook"), str(SAM3D_PATH)]

# Model singletons; they stay None until load_sam3d() / load_sam2() fill them.
SAM3D_MODEL = None
SAM2_GENERATOR = None
def load_sam2():
    """Return the shared SAM2 automatic mask generator, creating it on first use.

    The generator is stored in the module-level ``SAM2_GENERATOR`` so later
    calls return the same object without loading the model again.
    """
    global SAM2_GENERATOR
    if SAM2_GENERATOR is None:
        # Imported inside the function, so sam2 is only required once a
        # generator actually has to be built.
        from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

        print("Loading SAM2 model...")
        SAM2_GENERATOR = SAM2AutomaticMaskGenerator.from_pretrained(
            "facebook/sam2-hiera-large"
        )
        print("β SAM2 loaded")
    return SAM2_GENERATOR
def load_sam3d():
    """Return the shared SAM 3D Objects inference object, building it on first use.

    On the first call the checkpoints are downloaded from the Hub, the
    ``Inference`` class from the cloned repo's notebook folder is instantiated
    with the repo's default config, and the result is cached in the
    module-level ``SAM3D_MODEL``.
    """
    global SAM3D_MODEL
    if SAM3D_MODEL is None:
        import torch  # name is not used below; import kept as in the original

        print("Loading SAM 3D Objects model...")

        # Fetch (or reuse the cached copy of) the model checkpoints.
        weights_dir = snapshot_download(
            repo_id="facebook/sam-3d-objects",
            token=os.environ.get("HF_TOKEN"),
        )

        # Resolved through the notebook folder that was added to sys.path.
        from inference import Inference

        yaml_path = SAM3D_PATH / "sam3d_objects" / "configs" / "default.yaml"
        SAM3D_MODEL = Inference(str(yaml_path), compile=False)

        # Record where the downloaded checkpoints live on the model object.
        SAM3D_MODEL.checkpoint_dir = weights_dir
        print("β SAM 3D Objects loaded")
    return SAM3D_MODEL
def _result_to_point_cloud(result, work_dir):
    """Convert a reconstruction result into an Open3D point cloud, or None.

    Three result shapes are handled, in this order: an object exposing
    ``save_ply``, a mapping with a ``'gaussians'`` entry exposing ``save_ply``,
    and a mapping carrying raw points under ``'xyz'`` or ``'points'``.
    """
    import open3d as o3d

    ply_source = None
    if hasattr(result, "save_ply"):
        ply_source = result
    elif "gaussians" in result:
        ply_source = result["gaussians"]

    if ply_source is not None:
        # Round-trip through a temporary PLY file, then drop it again so only
        # the final GLB remains in the output directory.
        ply_path = os.path.join(work_dir, "temp.ply")
        ply_source.save_ply(ply_path)
        try:
            return o3d.io.read_point_cloud(ply_path)
        finally:
            try:
                os.remove(ply_path)
            except OSError:
                pass

    vertices = result.get("xyz", result.get("points", None))
    if vertices is None:
        return None

    import torch

    if torch.is_tensor(vertices):
        # detach() so tensors that still track gradients can be converted.
        vertices = vertices.detach().cpu().numpy()

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(np.asarray(vertices, dtype=np.float64))
    return pcd


def _write_poisson_glb(pcd, glb_path):
    """Mesh ``pcd`` via Poisson reconstruction and write it to ``glb_path``.

    Returns True only when Open3D reports that the file was written.
    """
    import open3d as o3d

    pcd.estimate_normals()
    mesh, _ = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=8)
    return bool(o3d.io.write_triangle_mesh(glb_path, mesh))


def reconstruct_objects(image: np.ndarray):
    """
    Automatically detect and reconstruct 3D objects from image.

    The largest SAM2 mask (by ``area``) is selected, highlighted in green on a
    preview copy of the input, and passed to the SAM 3D Objects model. The
    resulting points are meshed with Poisson reconstruction and exported.

    Args:
        image: Input RGB image (numpy array or PIL image). Other colour modes
            are converted to RGB before processing.

    Returns:
        tuple: (glb_path, preview_image, status). ``glb_path`` is None whenever
        no model file was produced; ``preview_image`` is None when the input is
        missing or an unexpected error occurred.

    Any exception is caught, its traceback printed, and reported through the
    status string instead of being raised.

    NOTE(review): the status prefixes ("β", "β οΈ") look like mis-encoded
    symbols; they are kept exactly as found — confirm the intended characters.
    NOTE(review): this file imports ``spaces`` but no ``@spaces.GPU`` decorator
    is applied here — confirm whether the deployment needs one.
    """
    if image is None:
        return None, None, "β No image provided"

    try:
        import shutil
        from PIL import Image as PILImage

        # Load models
        generator = load_sam2()
        inference = load_sam3d()

        # Normalise the input to an RGB PIL image plus a matching array so the
        # 3-channel overlay below also works for RGBA or grayscale input.
        pil_image = PILImage.fromarray(image) if isinstance(image, np.ndarray) else image
        pil_image = pil_image.convert("RGB")
        image = np.array(pil_image)

        # Auto-detect all objects with SAM2
        print("Detecting objects...")
        masks = generator.generate(image)
        if not masks:
            return None, image, "β οΈ No objects detected"

        # Largest object only; coerce to bool so indexing is a true mask.
        largest = max(masks, key=lambda m: m["area"])
        best_mask = np.asarray(largest["segmentation"], dtype=bool)

        # Preview: blend the selected region 50/50 with pure green.
        preview = image.copy()
        preview[best_mask] = (
            preview[best_mask] * 0.5 + np.array([0, 255, 0]) * 0.5
        ).astype(np.uint8)

        # Binary 0/255 mask image for the reconstruction model.
        mask_pil = PILImage.fromarray(best_mask.astype(np.uint8) * 255)

        # Run 3D reconstruction
        print("Reconstructing 3D...")
        result = inference(image=pil_image, mask=mask_pil)
        if result is None:
            return None, preview, "β οΈ 3D reconstruction failed"

        # Export as GLB. The directory is kept only when a file was written;
        # every failure path (including exceptions) removes it again.
        output_dir = tempfile.mkdtemp()
        exported = False
        try:
            pcd = _result_to_point_cloud(result, output_dir)
            if pcd is None or not pcd.has_points():
                return None, preview, "β οΈ Could not extract 3D data"

            glb_path = os.path.join(output_dir, f"object_{uuid.uuid4().hex[:8]}.glb")
            if not _write_poisson_glb(pcd, glb_path):
                return None, preview, "β οΈ 3D reconstruction failed"
            exported = True
        finally:
            if not exported:
                shutil.rmtree(output_dir, ignore_errors=True)

        return glb_path, preview, f"β Detected {len(masks)} objects, reconstructed largest"
    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, None, f"β Error: {e}"
# Gradio Interface
# Builds the page: a header, an input/preview row, an output row, the event
# wiring, and a footer with the MCP client configuration snippet.
# NOTE(review): several user-facing strings below ("π¦", "β", "π") look like
# mis-encoded symbols; they are left untouched here — confirm the intended text.
with gr.Blocks(title="SAM 3D Objects MCP") as demo:
    gr.Markdown("""
# π¦ SAM 3D Objects MCP Server
**Image β 3D Object (GLB)**
Automatically detects objects and reconstructs the largest one in 3D.
""")
    # Row 1: image upload + trigger button on the left, mask preview and
    # status text on the right.
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="numpy")
            btn = gr.Button("π Detect & Reconstruct", variant="primary", size="lg")
        with gr.Column():
            preview = gr.Image(label="Detected Object", type="numpy", interactive=False)
            status = gr.Textbox(label="Status")
    # Row 2: 3D viewer next to a download slot for the same file.
    with gr.Row():
        with gr.Column():
            output_model = gr.Model3D(label="3D Preview")
        with gr.Column():
            output_file = gr.File(label="Download GLB")
    # The three values returned by reconstruct_objects map, in order, to the
    # 3D viewer, the preview image and the status box.
    btn.click(
        reconstruct_objects,
        inputs=[input_image],
        outputs=[output_model, preview, status]
    )
    # Whenever the viewer's value changes, pass that same value on to the
    # download component.
    output_model.change(lambda x: x, inputs=[output_model], outputs=[output_file])
    # Footer: example client configuration pointing at this app's MCP endpoint.
    gr.Markdown("""
---
### MCP Server
```json
{
"mcpServers": {
"sam3d-objects": {
"url": "https://dev-bjoern-sam3d-objects-mcp.hf.space/gradio_api/mcp/sse"
}
}
}
```
""")
# Start the app with the mcp_server launch option enabled when run as a script.
if __name__ == "__main__":
    demo.launch(mcp_server=True)