Spaces:
Running
on
Zero
Running
on
Zero
"""
SAM 3D Objects MCP Server
Image + Text/Click → 3D Object (GLB)
Uses SAM3 for segmentation and SAM 3D Objects for 3D reconstruction.
The Gradio app is launched with ``mcp_server=True``, so its handlers are also
exposed as MCP tools.
"""
| import os | |
| import sys | |
| import subprocess | |
| import tempfile | |
| import uuid | |
| from pathlib import Path | |
| import gradio as gr | |
| import numpy as np | |
| import spaces | |
| from huggingface_hub import snapshot_download, login | |
| from PIL import Image | |
# Log in to the Hugging Face Hub when a token is configured, so authenticated
# Hub downloads (e.g. the checkpoint fetched in load_sam3d) work.
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)

# Fetch the sam-3d-objects source tree on first start so its modules
# (e.g. sam_3d_objects) can be imported from the checkout.
SAM3D_PATH = Path("/home/user/app/sam-3d-objects")
if not SAM3D_PATH.exists():
    print("Cloning sam-3d-objects repository...")
    subprocess.run(
        [
            "git", "clone",
            "https://github.com/facebookresearch/sam-3d-objects.git",
            str(SAM3D_PATH),
        ],
        check=True,
    )
# Put the checkout at the front of the import path (a single entry suffices).
sys.path.insert(0, str(SAM3D_PATH))

# Lazily-initialised model singletons, filled in by load_sam3d() / load_sam3().
# NOTE(review): `spaces` is imported but no handler is decorated with
# `@spaces.GPU`; if this Space runs on ZeroGPU hardware the inference handlers
# presumably need that decorator (and a model-loading strategy compatible with
# it) — confirm.
SAM3D_MODEL = None
SAM3_PREDICTOR = None
def load_sam3():
    """Return the SAM3 image processor, building it on the first call.

    The processor wraps the model returned by ``build_sam3_image_model()`` and
    is cached in the module-level ``SAM3_PREDICTOR`` so later calls reuse it.

    Returns:
        The cached ``Sam3Processor`` instance.
    """
    global SAM3_PREDICTOR
    if SAM3_PREDICTOR is not None:
        return SAM3_PREDICTOR
    # Imported here so the SAM3 package is only loaded when first needed.
    from sam3.model_builder import build_sam3_image_model
    from sam3.model.sam3_image_processor import Sam3Processor

    print("Loading SAM3 model...")
    model = build_sam3_image_model()
    SAM3_PREDICTOR = Sam3Processor(model)
    print("β SAM3 loaded")
    return SAM3_PREDICTOR
def load_sam3d():
    """Return the SAM 3D Objects model, loading it once on first use.

    The checkpoint is downloaded from the ``facebook/sam-3d-objects`` Hub repo
    (passing ``HF_TOKEN`` from the environment, if set). The model is placed on
    CUDA when ``torch.cuda.is_available()``, otherwise on CPU. The instance is
    cached in the module-level ``SAM3D_MODEL``.

    Returns:
        The cached ``Sam3dObjects`` model.
    """
    global SAM3D_MODEL
    if SAM3D_MODEL is None:
        import torch

        print("Loading SAM 3D Objects model...")
        weights_dir = snapshot_download(
            repo_id="facebook/sam-3d-objects",
            token=os.environ.get("HF_TOKEN"),
        )
        from sam_3d_objects import Sam3dObjects

        if torch.cuda.is_available():
            target_device = "cuda"
        else:
            target_device = "cpu"
        SAM3D_MODEL = Sam3dObjects.from_pretrained(weights_dir, device=target_device)
        print("β SAM 3D Objects loaded")
    return SAM3D_MODEL
def _to_numpy(value):
    """Return *value* as a NumPy array, copying tensor-like objects to the CPU.

    SAM3 may hand back torch tensors (assumption — TODO confirm against
    ``Sam3Processor``), possibly on the GPU or in a reduced-precision dtype,
    which ``np.array`` cannot read directly. Anything exposing ``detach`` is
    therefore detached, cast to float32 and moved to host memory first; the
    callers below only threshold or ``argmax`` the result, so the cast is safe.
    """
    if hasattr(value, "detach"):
        value = value.detach().float().cpu().numpy()
    return np.asarray(value)


def _to_pil(image):
    """Return *image* as a PIL image, converting NumPy arrays via ``fromarray``."""
    from PIL import Image as PILImage

    if isinstance(image, np.ndarray):
        return PILImage.fromarray(image)
    return image


def _mask_result(image, output, success_message):
    """Turn a SAM3 prompt result into the ``(overlay, mask, status)`` UI triple.

    Args:
        image: The input image the prompt was run on.
        output: Mapping returned by the SAM3 processor; expected to contain
            ``"masks"`` and optionally ``"scores"``.
        success_message: Status text returned when a mask is found.

    Returns:
        ``(overlay, mask, status)``. ``overlay`` is *image* with the
        highest-scoring mask tinted green. ``mask`` is a ``uint8`` array that
        is 255 inside the object and 0 elsewhere. When no mask is available,
        ``(image, None, <warning>)`` is returned instead.
    """
    if output is None or "masks" not in output:
        return image, None, "β οΈ No object found"
    masks = output["masks"]
    if len(masks) == 0:
        return image, None, "β οΈ No object found"

    # Flatten so argmax yields a single index into ``masks``.
    scores = _to_numpy(output.get("scores", [1.0])).reshape(-1)
    best_idx = int(np.argmax(scores)) if scores.size > 0 else 0

    mask = _to_numpy(masks[best_idx]) > 0
    # Drop singleton leading dims such as a (1, H, W) mask so it indexes (H, W).
    while mask.ndim > 2 and mask.shape[0] == 1:
        mask = mask[0]

    # 50/50 blend of the selected pixels with pure green (colour channels only,
    # so an alpha channel, if present, is left untouched).
    overlay = np.array(image, copy=True)
    tint = np.array([0, 255, 0])
    overlay[mask, :3] = (overlay[mask, :3] * 0.5 + tint * 0.5).astype(np.uint8)
    return overlay, mask.astype(np.uint8) * 255, success_message


def segment_with_text(image: np.ndarray, text_prompt: str):
    """Segment the object described by *text_prompt* using SAM3.

    Args:
        image: Input image from the Gradio ``Image`` component.
        text_prompt: Free-text description of the object; surrounding
            whitespace is ignored.

    Returns:
        ``(overlay, mask, status)`` for the preview image, mask state and
        status textbox. Errors are reported in ``status`` rather than raised.
    """
    if image is None:
        return None, None, "β No image provided"
    prompt = (text_prompt or "").strip()
    if not prompt:
        return None, None, "β No text prompt provided"
    try:
        processor = load_sam3()
        state = processor.set_image(_to_pil(image))
        output = processor.set_text_prompt(state=state, prompt=prompt)
        return _mask_result(image, output, f"β Found: {prompt}")
    except Exception as e:
        import traceback

        traceback.print_exc()
        return image, None, f"β Error: {e}"


def segment_with_click(image: np.ndarray, evt: gr.SelectData):
    """Segment the object under the clicked pixel using SAM3.

    Args:
        image: Input image from the Gradio ``Image`` component.
        evt: Gradio select event; ``evt.index`` holds the clicked pixel
            (``[x, y]`` for an ``Image`` component, per the Gradio docs).

    Returns:
        ``(overlay, mask, status)`` as for ``segment_with_text``. Errors are
        reported in ``status`` rather than raised.
    """
    if image is None:
        return None, None, "β No image provided"
    try:
        processor = load_sam3()
        point = [evt.index[0], evt.index[1]]
        state = processor.set_image(_to_pil(image))
        # Single click prompt; label 1 presumably marks a foreground point.
        # NOTE(review): confirm Sam3Processor exposes set_point_prompt with
        # this signature.
        output = processor.set_point_prompt(state=state, points=[point], labels=[1])
        return _mask_result(image, output, "β Object selected")
    except Exception as e:
        import traceback

        traceback.print_exc()
        return image, None, f"β Error: {e}"
def reconstruct_3d(image: np.ndarray, mask: np.ndarray):
    """
    Reconstruct a 3D object from an image and its segmentation mask.

    Args:
        image: Input RGB image (presumably (H, W, 3) from the numpy-typed
            gr.Image).
        mask: Mask from the SAM3 step. Values > 127 are treated as the object.
            A 3-D mask is reduced to its first channel.

    Returns:
        tuple: (glb_path, status). ``glb_path`` is the exported point-cloud
        GLB file, or None on failure. ``status`` is a message for the UI.
    """
    if image is None:
        return None, "β No image provided"
    if mask is None:
        return None, "β No mask - segment object first"
    try:
        # Reduce to a single-channel 0/1 mask.
        mask = np.asarray(mask)
        if mask.ndim == 3:
            mask = mask[:, :, 0]
        mask = (mask > 127).astype(np.uint8)
        # An all-background mask has no object to reconstruct.
        if not mask.any():
            return None, "β No mask - segment object first"
        # The mask is kept in UI state and can outlive the image it was made
        # for; refuse one whose size does not match the current image.
        if mask.shape != np.asarray(image).shape[:2]:
            return None, "β Mask does not match image - segment object again"

        import trimesh

        model = load_sam3d()
        outputs = model.predict(image, mask)
        if outputs is None:
            return None, "β οΈ Reconstruction failed"

        # Vertex positions of the reconstructed Gaussian splat, exported as a
        # point cloud.
        vertices = outputs.get_xyz().cpu().numpy()
        if len(vertices) == 0:
            return None, "β οΈ Reconstruction failed"

        glb_path = Path(tempfile.mkdtemp()) / f"object_{uuid.uuid4().hex[:8]}.glb"
        trimesh.PointCloud(vertices).export(str(glb_path), file_type="glb")
        return str(glb_path), f"β Reconstructed ({len(vertices)} points)"
    except Exception as e:
        import traceback

        traceback.print_exc()
        return None, f"β Error: {e}"
# Gradio interface: upload an image, segment it (text or click), reconstruct.
with gr.Blocks(title="SAM 3D Objects MCP") as demo:
    gr.Markdown("""
    # π¦ SAM 3D Objects MCP Server
    **Image β 3D Object (GLB)**
    1. Upload image
    2. Segment: Type what to select OR click on object
    3. Reconstruct 3D
    """)

    # Mask of the current selection, written by the segmentation handlers and
    # read by the reconstruct step.
    mask_state = gr.State(None)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="numpy")
            with gr.Row():
                text_prompt = gr.Textbox(
                    label="Text Prompt",
                    placeholder="e.g. 'the chair', 'red car', 'coffee mug'",
                    scale=3
                )
                segment_btn = gr.Button("π― Segment", scale=1)
            gr.Markdown("*Or click directly on the object in the image*")
        with gr.Column():
            preview = gr.Image(label="Segmentation Preview", type="numpy", interactive=False)
            status = gr.Textbox(label="Status")

    with gr.Row():
        reconstruct_btn = gr.Button("π Reconstruct 3D", variant="primary", size="lg")

    with gr.Row():
        with gr.Column():
            output_model = gr.Model3D(label="3D Preview")
        with gr.Column():
            output_file = gr.File(label="Download GLB")

    # Events
    segment_btn.click(
        segment_with_text,
        inputs=[input_image, text_prompt],
        outputs=[preview, mask_state, status]
    )
    input_image.select(
        segment_with_click,
        inputs=[input_image],
        outputs=[preview, mask_state, status]
    )
    # A new or cleared input image invalidates the previous selection, so the
    # stale preview and mask are dropped before they can be reconstructed.
    # NOTE(review): like the lambda listener below, this may be exposed as an
    # API/MCP endpoint; consider hiding it for the Gradio version in use.
    input_image.change(
        lambda: (None, None),
        outputs=[preview, mask_state]
    )
    reconstruct_btn.click(
        reconstruct_3d,
        inputs=[input_image, mask_state],
        outputs=[output_model, status]
    )
    # Mirror the GLB shown in the 3D viewer into the download widget.
    output_model.change(lambda x: x, inputs=[output_model], outputs=[output_file])

    gr.Markdown("""
    ---
    ### MCP Server
    ```json
    {
      "mcpServers": {
        "sam3d-objects": {
          "url": "https://dev-bjoern-sam3d-objects-mcp.hf.space/gradio_api/mcp/sse"
        }
      }
    }
    ```
    """)
def _main():
    """Launch the Gradio app with its MCP server endpoint enabled."""
    demo.launch(mcp_server=True)


if __name__ == "__main__":
    _main()