dev-bjoern committed on
Commit
af69327
·
1 Parent(s): 696b8f4

Add SAM3 for auto-segmentation, GLB export

Browse files
Files changed (2) hide show
  1. app.py +125 -24
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,6 +1,6 @@
1
  """
2
  SAM 3D Objects MCP Server
3
- Image + Mask → 3D Object (PLY)
4
  """
5
  import os
6
  import sys
@@ -33,16 +33,41 @@ if not SAM3D_PATH.exists():
33
  # Add to path
34
  sys.path.insert(0, str(SAM3D_PATH))
35
 
36
- # Global model
37
- MODEL = None
 
38
 
39
 
40
- def load_model():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  """Load SAM 3D Objects model"""
42
- global MODEL
43
 
44
- if MODEL is not None:
45
- return MODEL
46
 
47
  import torch
48
  print("Loading SAM 3D Objects model...")
@@ -57,10 +82,57 @@ def load_model():
57
 
58
  device = "cuda" if torch.cuda.is_available() else "cpu"
59
 
60
- MODEL = Sam3dObjects.from_pretrained(checkpoint_dir, device=device)
61
 
62
- print("✓ Model loaded")
63
- return MODEL
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
 
66
  @spaces.GPU(duration=120)
@@ -73,17 +145,17 @@ def reconstruct_object(image: np.ndarray, mask: np.ndarray) -> tuple:
73
  mask: Binary mask indicating object region
74
 
75
  Returns:
76
- tuple: (ply_path, status)
77
  """
78
  if image is None:
79
  return None, "❌ No image provided"
80
  if mask is None:
81
- return None, "❌ No mask provided"
82
 
83
  try:
84
  import torch
85
  import trimesh
86
- model = load_model()
87
 
88
  # Process image
89
  if isinstance(image, Image.Image):
@@ -104,14 +176,20 @@ def reconstruct_object(image: np.ndarray, mask: np.ndarray) -> tuple:
104
  if outputs is None:
105
  return None, "⚠️ Reconstruction failed"
106
 
107
- # Export as PLY
108
  output_dir = tempfile.mkdtemp()
109
- ply_path = f"{output_dir}/object_{uuid.uuid4().hex[:8]}.ply"
110
 
111
- # Save gaussian splat as PLY
112
- outputs.save_ply(ply_path)
 
113
 
114
- return ply_path, "✓ Object reconstructed"
 
 
 
 
 
115
 
116
  except Exception as e:
117
  import traceback
@@ -121,19 +199,42 @@ def reconstruct_object(image: np.ndarray, mask: np.ndarray) -> tuple:
121
 
122
  # Gradio Interface
123
  with gr.Blocks(title="SAM 3D Objects MCP") as demo:
124
- gr.Markdown("# 📦 SAM 3D Objects MCP Server\n**Image + Mask → 3D Object (PLY)**")
 
 
 
125
 
126
  with gr.Row():
127
  with gr.Column():
128
- input_image = gr.Image(label="Input Image", type="numpy")
129
- input_mask = gr.Image(label="Object Mask", type="numpy")
130
- btn = gr.Button("🎯 Reconstruct", variant="primary")
131
 
132
  with gr.Column():
133
- output_file = gr.File(label="3D Object (PLY)")
 
 
 
 
 
 
 
 
 
134
  status = gr.Textbox(label="Status")
135
 
136
- btn.click(reconstruct_object, inputs=[input_image, input_mask], outputs=[output_file, status])
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
  gr.Markdown("""
139
  ---
 
1
  """
2
  SAM 3D Objects MCP Server
3
+ Image + Click → 3D Object (GLB)
4
  """
5
  import os
6
  import sys
 
33
  # Add to path
34
  sys.path.insert(0, str(SAM3D_PATH))
35
 
36
+ # Global models
37
+ SAM3D_MODEL = None
38
+ SAM_PREDICTOR = None
39
 
40
 
41
+ def load_sam_model():
42
+ """Load SAM3 model for segmentation"""
43
+ global SAM_PREDICTOR
44
+
45
+ if SAM_PREDICTOR is not None:
46
+ return SAM_PREDICTOR
47
+
48
+ import torch
49
+ from sam3 import SAM3ImagePredictor
50
+
51
+ print("Loading SAM3 model...")
52
+
53
+ device = "cuda" if torch.cuda.is_available() else "cpu"
54
+
55
+ SAM_PREDICTOR = SAM3ImagePredictor.from_pretrained(
56
+ "facebook/sam3-hiera-large",
57
+ device=device,
58
+ token=os.environ.get("HF_TOKEN")
59
+ )
60
+
61
+ print("✓ SAM3 model loaded")
62
+ return SAM_PREDICTOR
63
+
64
+
65
+ def load_sam3d_model():
66
  """Load SAM 3D Objects model"""
67
+ global SAM3D_MODEL
68
 
69
+ if SAM3D_MODEL is not None:
70
+ return SAM3D_MODEL
71
 
72
  import torch
73
  print("Loading SAM 3D Objects model...")
 
82
 
83
  device = "cuda" if torch.cuda.is_available() else "cpu"
84
 
85
+ SAM3D_MODEL = Sam3dObjects.from_pretrained(checkpoint_dir, device=device)
86
 
87
+ print("✓ SAM 3D Objects model loaded")
88
+ return SAM3D_MODEL
89
+
90
+
91
+ @spaces.GPU(duration=60)
92
+ def segment_object(image: np.ndarray, evt: gr.SelectData) -> np.ndarray:
93
+ """
94
+ Segment object at clicked point using SAM3.
95
+
96
+ Args:
97
+ image: Input RGB image
98
+ evt: Click event with coordinates
99
+
100
+ Returns:
101
+ Image with mask overlay
102
+ """
103
+ if image is None:
104
+ return None
105
+
106
+ try:
107
+ predictor = load_sam_model()
108
+
109
+ # Get click coordinates
110
+ point = np.array([[evt.index[0], evt.index[1]]])
111
+ label = np.array([1]) # 1 = foreground
112
+
113
+ # Set image
114
+ predictor.set_image(image)
115
+
116
+ # Predict mask
117
+ masks, scores, _ = predictor.predict(
118
+ point_coords=point,
119
+ point_labels=label,
120
+ multimask_output=True
121
+ )
122
+
123
+ # Use best mask
124
+ best_mask = masks[np.argmax(scores)]
125
+
126
+ # Create overlay
127
+ overlay = image.copy()
128
+ overlay[best_mask] = overlay[best_mask] * 0.5 + np.array([0, 255, 0]) * 0.5
129
+
130
+ return overlay, best_mask.astype(np.uint8) * 255
131
+
132
+ except Exception as e:
133
+ import traceback
134
+ traceback.print_exc()
135
+ return image, None
136
 
137
 
138
  @spaces.GPU(duration=120)
 
145
  mask: Binary mask indicating object region
146
 
147
  Returns:
148
+ tuple: (glb_path, status)
149
  """
150
  if image is None:
151
  return None, "❌ No image provided"
152
  if mask is None:
153
+ return None, "❌ No mask provided - click on object first"
154
 
155
  try:
156
  import torch
157
  import trimesh
158
+ model = load_sam3d_model()
159
 
160
  # Process image
161
  if isinstance(image, Image.Image):
 
176
  if outputs is None:
177
  return None, "⚠️ Reconstruction failed"
178
 
179
+ # Export as GLB via trimesh
180
  output_dir = tempfile.mkdtemp()
181
+ glb_path = f"{output_dir}/object_{uuid.uuid4().hex[:8]}.glb"
182
 
183
+ # Get vertices and faces from gaussian splat
184
+ # Convert to mesh and export as GLB
185
+ vertices = outputs.get_xyz().cpu().numpy()
186
 
187
+ # Create point cloud mesh (gaussian splats don't have faces directly)
188
+ # We'll export as a point cloud GLB
189
+ cloud = trimesh.PointCloud(vertices)
190
+ cloud.export(glb_path, file_type='glb')
191
+
192
+ return glb_path, f"✓ Object reconstructed ({len(vertices)} points)"
193
 
194
  except Exception as e:
195
  import traceback
 
199
 
200
  # Gradio Interface
201
  with gr.Blocks(title="SAM 3D Objects MCP") as demo:
202
+ gr.Markdown("# 📦 SAM 3D Objects MCP Server\n**Click on object → 3D Reconstruction (GLB)**")
203
+
204
+ # State for mask
205
+ mask_state = gr.State(None)
206
 
207
  with gr.Row():
208
  with gr.Column():
209
+ input_image = gr.Image(label="Input Image (click on object)", type="numpy")
210
+ gr.Markdown("*Click on the object you want to reconstruct*")
 
211
 
212
  with gr.Column():
213
+ preview_image = gr.Image(label="Segmentation Preview", type="numpy", interactive=False)
214
+
215
+ with gr.Row():
216
+ btn = gr.Button("🎯 Reconstruct 3D", variant="primary", size="lg")
217
+
218
+ with gr.Row():
219
+ with gr.Column():
220
+ output_model = gr.Model3D(label="3D Object")
221
+ output_file = gr.File(label="Download GLB")
222
+ with gr.Column():
223
  status = gr.Textbox(label="Status")
224
 
225
+ # Click to segment
226
+ input_image.select(
227
+ segment_object,
228
+ inputs=[input_image],
229
+ outputs=[preview_image, mask_state]
230
+ )
231
+
232
+ # Reconstruct
233
+ btn.click(
234
+ reconstruct_object,
235
+ inputs=[input_image, mask_state],
236
+ outputs=[output_file, status]
237
+ )
238
 
239
  gr.Markdown("""
240
  ---
requirements.txt CHANGED
@@ -20,3 +20,4 @@ jaxtyping
20
  rich
21
  kaolin==0.17.0
22
  gsplat
 
 
20
  rich
21
  kaolin==0.17.0
22
  gsplat
23
+ sam3 @ git+https://github.com/facebookresearch/sam3.git