dev-bjoern's picture
Initial SAM 3D Objects MCP server
30aba9f
raw
history blame
3.72 kB
"""
SAM 3D Objects MCP Server
Image + Mask → 3D Object (PLY)
"""
import os
import sys
import subprocess
import tempfile
import uuid
from pathlib import Path
import gradio as gr
import numpy as np
import spaces
from huggingface_hub import snapshot_download, login
from PIL import Image
# Log in to the Hugging Face Hub when a token is configured. An unset or
# empty HF_TOKEN skips the login entirely.
if hf_token := os.environ.get("HF_TOKEN"):
    login(token=hf_token)

# Local checkout of the sam-3d-objects sources; cloned on first start only.
SAM3D_PATH = Path("/home/user/app/sam-3d-objects")
if not SAM3D_PATH.exists():
    print("Cloning sam-3d-objects repository...")
    # check=True raises CalledProcessError when git exits non-zero, so a
    # failed clone stops startup instead of continuing without the sources.
    subprocess.run(
        [
            "git", "clone",
            "https://github.com/facebookresearch/sam-3d-objects.git",
            str(SAM3D_PATH),
        ],
        check=True,
    )

# Put the checkout at the front of the import path. This is done exactly
# once, whether or not the clone ran during this start.
sys.path.insert(0, str(SAM3D_PATH))

# Global model; stays None until load_model() fills it in.
MODEL = None
def load_model():
    """Load the SAM 3D Objects model once and cache it in ``MODEL``.

    The first call downloads the ``facebook/sam-3d-objects`` snapshot from
    the Hugging Face Hub and builds the model on CUDA when available,
    otherwise on CPU. Later calls return the cached instance.

    Returns:
        The object returned by ``Sam3dObjects.from_pretrained``.
    """
    global MODEL
    if MODEL is not None:
        return MODEL

    # torch is imported here rather than at module import time.
    import torch

    print("Loading SAM 3D Objects model...")

    # Download checkpoint. An empty HF_TOKEN is passed as None so it is
    # treated as "no token", matching the module-level login check, instead
    # of being sent to the Hub as an explicit (invalid) token.
    checkpoint_dir = snapshot_download(
        repo_id="facebook/sam-3d-objects",
        token=os.environ.get("HF_TOKEN") or None,
    )

    # NOTE(review): presumably resolved through the sys.path entry added for
    # the cloned repository -- confirm the checkout exposes this name.
    from sam_3d_objects import Sam3dObjects

    device = "cuda" if torch.cuda.is_available() else "cpu"
    MODEL = Sam3dObjects.from_pretrained(checkpoint_dir, device=device)
    print("✓ Model loaded")
    return MODEL
def _binarize_mask(mask: np.ndarray) -> np.ndarray:
    """Return *mask* as a single-channel uint8 array of 0/1 values."""
    # Multi-channel masks: only the first channel is used.
    if mask.ndim == 3:
        mask = mask[:, :, 0]
    if mask.dtype == np.bool_:
        return mask.astype(np.uint8)
    # Masks on a 0..255 scale are cut at the midpoint. Masks whose values
    # never exceed 1 (0/1 integers, 0..1 floats) would be wiped out by that
    # threshold, so they are cut at 0.5 instead.
    threshold = 127 if mask.size and mask.max() > 1 else 0.5
    return (mask > threshold).astype(np.uint8)


@spaces.GPU(duration=120)
def reconstruct_object(image: np.ndarray, mask: np.ndarray) -> tuple:
    """
    Reconstruct 3D object from image and mask.

    Args:
        image: Input RGB image
        mask: Binary mask indicating object region

    Returns:
        tuple: (ply_path, status). ply_path is None whenever the status
        reports a problem.
    """
    if image is None:
        return None, "❌ No image provided"
    if mask is None:
        return None, "❌ No mask provided"

    try:
        # Process image
        if isinstance(image, Image.Image):
            image = np.array(image)

        # Process mask
        if isinstance(mask, Image.Image):
            mask = np.array(mask)
        mask = _binarize_mask(mask)

        # Cheap input checks run before the model is loaded so bad input is
        # reported without paying for the download / initialisation.
        if mask.shape[:2] != image.shape[:2]:
            return None, (
                f"❌ Mask shape {mask.shape} does not match "
                f"image shape {image.shape[:2]}"
            )
        if not mask.any():
            return None, "❌ Mask is empty"

        model = load_model()

        # Run inference
        # NOTE(review): predict()/save_ply() are used as the original code
        # used them -- confirm against the sam-3d-objects API.
        outputs = model.predict(image, mask)
        if outputs is None:
            return None, "⚠️ Reconstruction failed"

        # Export as PLY into a fresh temporary directory. The directory is
        # not removed here because the returned path is handed to the
        # gr.File output after this function returns.
        output_dir = Path(tempfile.mkdtemp())
        ply_path = str(output_dir / f"object_{uuid.uuid4().hex[:8]}.ply")

        # Save gaussian splat as PLY
        outputs.save_ply(ply_path)
        return ply_path, "✓ Object reconstructed"
    except Exception as e:
        # UI boundary: log the full traceback and surface the message in the
        # status box instead of letting the handler raise.
        import traceback

        traceback.print_exc()
        return None, f"❌ Error: {e}"
# Gradio Interface
with gr.Blocks(title="SAM 3D Objects MCP") as demo:
    gr.Markdown("# 📦 SAM 3D Objects MCP Server\n**Image + Mask → 3D Object (PLY)**")

    with gr.Row():
        # Left column: the two inputs and the trigger button.
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="numpy")
            input_mask = gr.Image(label="Object Mask", type="numpy")
            btn = gr.Button("🎯 Reconstruct", variant="primary")
        # Right column: the generated file and the status message.
        with gr.Column():
            output_file = gr.File(label="3D Object (PLY)")
            status = gr.Textbox(label="Status")

    # reconstruct_object returns (ply_path, status), mapped in that order
    # onto the file and status components.
    btn.click(reconstruct_object, inputs=[input_image, input_mask], outputs=[output_file, status])

    # Client configuration snippet; "URL" is a placeholder for the address
    # this app is served from.
    gr.Markdown("""
---
### MCP Server
```json
{"mcpServers": {"sam3d-objects": {"command": "npx", "args": ["mcp-remote", "URL/gradio_api/mcp/sse"]}}}
```
""")

if __name__ == "__main__":
    # mcp_server=True asks Gradio to serve its MCP endpoint alongside the UI.
    demo.launch(mcp_server=True)