dev-bjoern committed on
Commit
1ed02fd
·
1 Parent(s): af69327

SAM3 text/click segmentation + SAM 3D Objects reconstruction

Browse files
Files changed (1) hide show
  1. app.py +110 -72
app.py CHANGED
@@ -1,6 +1,8 @@
1
  """
2
  SAM 3D Objects MCP Server
3
- Image + Click → 3D Object (GLB)
 
 
4
  """
5
  import os
6
  import sys
@@ -30,39 +32,38 @@ if not SAM3D_PATH.exists():
30
  ], check=True)
31
  sys.path.insert(0, str(SAM3D_PATH))
32
 
33
- # Add to path
34
  sys.path.insert(0, str(SAM3D_PATH))
35
 
36
  # Global models
37
  SAM3D_MODEL = None
38
- SAM_PREDICTOR = None
39
 
40
 
41
- def load_sam_model():
42
- """Load SAM3 model for segmentation"""
43
- global SAM_PREDICTOR
44
 
45
- if SAM_PREDICTOR is not None:
46
- return SAM_PREDICTOR
47
 
48
  import torch
49
- from sam3 import SAM3ImagePredictor
50
 
51
  print("Loading SAM3 model...")
52
 
53
  device = "cuda" if torch.cuda.is_available() else "cpu"
54
 
55
- SAM_PREDICTOR = SAM3ImagePredictor.from_pretrained(
56
  "facebook/sam3-hiera-large",
57
  device=device,
58
  token=os.environ.get("HF_TOKEN")
59
  )
60
 
61
- print("βœ“ SAM3 model loaded")
62
- return SAM_PREDICTOR
63
 
64
 
65
- def load_sam3d_model():
66
  """Load SAM 3D Objects model"""
67
  global SAM3D_MODEL
68
 
@@ -72,7 +73,6 @@ def load_sam3d_model():
72
  import torch
73
  print("Loading SAM 3D Objects model...")
74
 
75
- # Download checkpoint
76
  checkpoint_dir = snapshot_download(
77
  repo_id="facebook/sam-3d-objects",
78
  token=os.environ.get("HF_TOKEN")
@@ -81,39 +81,60 @@ def load_sam3d_model():
81
  from sam_3d_objects import Sam3dObjects
82
 
83
  device = "cuda" if torch.cuda.is_available() else "cpu"
84
-
85
  SAM3D_MODEL = Sam3dObjects.from_pretrained(checkpoint_dir, device=device)
86
 
87
- print("βœ“ SAM 3D Objects model loaded")
88
  return SAM3D_MODEL
89
 
90
 
91
  @spaces.GPU(duration=60)
92
- def segment_object(image: np.ndarray, evt: gr.SelectData) -> np.ndarray:
93
- """
94
- Segment object at clicked point using SAM2.
 
 
 
95
 
96
- Args:
97
- image: Input RGB image
98
- evt: Click event with coordinates
99
 
100
- Returns:
101
- Image with mask overlay
102
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  if image is None:
104
- return None
105
 
106
  try:
107
- predictor = load_sam_model()
108
 
109
  # Get click coordinates
110
  point = np.array([[evt.index[0], evt.index[1]]])
111
- label = np.array([1]) # 1 = foreground
112
 
113
- # Set image
114
  predictor.set_image(image)
115
-
116
- # Predict mask
117
  masks, scores, _ = predictor.predict(
118
  point_coords=point,
119
  point_labels=label,
@@ -121,28 +142,29 @@ def segment_object(image: np.ndarray, evt: gr.SelectData) -> np.ndarray:
121
  )
122
 
123
  # Use best mask
124
- best_mask = masks[np.argmax(scores)]
 
125
 
126
  # Create overlay
127
  overlay = image.copy()
128
- overlay[best_mask] = overlay[best_mask] * 0.5 + np.array([0, 255, 0]) * 0.5
129
 
130
- return overlay, best_mask.astype(np.uint8) * 255
131
 
132
  except Exception as e:
133
  import traceback
134
  traceback.print_exc()
135
- return image, None
136
 
137
 
138
  @spaces.GPU(duration=120)
139
- def reconstruct_object(image: np.ndarray, mask: np.ndarray) -> tuple:
140
  """
141
  Reconstruct 3D object from image and mask.
142
 
143
  Args:
144
  image: Input RGB image
145
- mask: Binary mask indicating object region
146
 
147
  Returns:
148
  tuple: (glb_path, status)
@@ -150,46 +172,37 @@ def reconstruct_object(image: np.ndarray, mask: np.ndarray) -> tuple:
150
  if image is None:
151
  return None, "❌ No image provided"
152
  if mask is None:
153
- return None, "❌ No mask provided - click on object first"
154
 
155
  try:
156
  import torch
157
  import trimesh
158
- model = load_sam3d_model()
159
 
160
- # Process image
161
- if isinstance(image, Image.Image):
162
- image = np.array(image)
163
 
164
- # Process mask
165
- if isinstance(mask, Image.Image):
166
- mask = np.array(mask)
167
-
168
- # Convert mask to binary if needed
169
  if len(mask.shape) == 3:
170
  mask = mask[:, :, 0]
171
  mask = (mask > 127).astype(np.uint8)
172
 
173
- # Run inference
174
  outputs = model.predict(image, mask)
175
 
176
  if outputs is None:
177
  return None, "⚠️ Reconstruction failed"
178
 
179
- # Export as GLB via trimesh
180
  output_dir = tempfile.mkdtemp()
181
  glb_path = f"{output_dir}/object_{uuid.uuid4().hex[:8]}.glb"
182
 
183
- # Get vertices and faces from gaussian splat
184
- # Convert to mesh and export as GLB
185
  vertices = outputs.get_xyz().cpu().numpy()
186
 
187
- # Create point cloud mesh (gaussian splats don't have faces directly)
188
- # We'll export as a point cloud GLB
189
  cloud = trimesh.PointCloud(vertices)
190
  cloud.export(glb_path, file_type='glb')
191
 
192
- return glb_path, f"βœ“ Object reconstructed ({len(vertices)} points)"
193
 
194
  except Exception as e:
195
  import traceback
@@ -199,39 +212,57 @@ def reconstruct_object(image: np.ndarray, mask: np.ndarray) -> tuple:
199
 
200
  # Gradio Interface
201
  with gr.Blocks(title="SAM 3D Objects MCP") as demo:
202
- gr.Markdown("# πŸ“¦ SAM 3D Objects MCP Server\n**Click on object β†’ 3D Reconstruction (GLB)**")
 
 
 
 
 
 
 
203
 
204
- # State for mask
205
  mask_state = gr.State(None)
206
 
207
  with gr.Row():
208
  with gr.Column():
209
- input_image = gr.Image(label="Input Image (click on object)", type="numpy")
210
- gr.Markdown("*Click on the object you want to reconstruct*")
 
 
 
 
 
 
 
 
 
211
 
212
  with gr.Column():
213
- preview_image = gr.Image(label="Segmentation Preview", type="numpy", interactive=False)
 
214
 
215
  with gr.Row():
216
- btn = gr.Button("🎯 Reconstruct 3D", variant="primary", size="lg")
217
 
218
  with gr.Row():
219
- with gr.Column():
220
- output_model = gr.Model3D(label="3D Object")
221
- output_file = gr.File(label="Download GLB")
222
- with gr.Column():
223
- status = gr.Textbox(label="Status")
 
 
 
 
224
 
225
- # Click to segment
226
  input_image.select(
227
- segment_object,
228
  inputs=[input_image],
229
- outputs=[preview_image, mask_state]
230
  )
231
 
232
- # Reconstruct
233
- btn.click(
234
- reconstruct_object,
235
  inputs=[input_image, mask_state],
236
  outputs=[output_file, status]
237
  )
@@ -240,7 +271,14 @@ with gr.Blocks(title="SAM 3D Objects MCP") as demo:
240
  ---
241
  ### MCP Server
242
  ```json
243
- {"mcpServers": {"sam3d-objects": {"command": "npx", "args": ["mcp-remote", "URL/gradio_api/mcp/sse"]}}}
 
 
 
 
 
 
 
244
  ```
245
  """)
246
 
 
1
  """
2
  SAM 3D Objects MCP Server
3
+ Image + Text/Click → 3D Object (GLB)
4
+
5
+ Uses SAM3 for segmentation and SAM 3D Objects for 3D reconstruction.
6
  """
7
  import os
8
  import sys
 
32
  ], check=True)
33
  sys.path.insert(0, str(SAM3D_PATH))
34
 
 
35
  sys.path.insert(0, str(SAM3D_PATH))
36
 
37
  # Global models
38
  SAM3D_MODEL = None
39
+ SAM3_PREDICTOR = None
40
 
41
 
42
def load_sam3():
    """Return the shared SAM3 predictor, creating it on first use.

    The predictor is cached in the module-level ``SAM3_PREDICTOR`` global,
    so ``from_pretrained`` runs at most once per process; later calls
    return the cached object unchanged.

    Returns:
        The predictor produced by ``SAM3Predictor.from_pretrained``.
    """
    global SAM3_PREDICTOR

    # Cache hit: nothing to import or download.
    if SAM3_PREDICTOR is not None:
        return SAM3_PREDICTOR

    # Imported here rather than at module top, so these only run on a
    # cache miss.
    import torch
    from sam3 import SAM3Predictor

    print("Loading SAM3 model...")

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # HF_TOKEN is read from the environment; None is passed when it is unset.
    # NOTE(review): confirm the `SAM3Predictor` class name and the
    # "facebook/sam3-hiera-large" repo id against the installed sam3 package.
    SAM3_PREDICTOR = SAM3Predictor.from_pretrained(
        "facebook/sam3-hiera-large",
        device=device,
        token=os.environ.get("HF_TOKEN")
    )

    print("✓ SAM3 loaded")
    return SAM3_PREDICTOR
64
 
65
 
66
+ def load_sam3d():
67
  """Load SAM 3D Objects model"""
68
  global SAM3D_MODEL
69
 
 
73
  import torch
74
  print("Loading SAM 3D Objects model...")
75
 
 
76
  checkpoint_dir = snapshot_download(
77
  repo_id="facebook/sam-3d-objects",
78
  token=os.environ.get("HF_TOKEN")
 
81
  from sam_3d_objects import Sam3dObjects
82
 
83
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
84
  SAM3D_MODEL = Sam3dObjects.from_pretrained(checkpoint_dir, device=device)
85
 
86
+ print("βœ“ SAM 3D Objects loaded")
87
  return SAM3D_MODEL
88
 
89
 
90
  @spaces.GPU(duration=60)
91
def segment_with_text(image: np.ndarray, text_prompt: str) -> tuple:
    """Segment the object described by a text prompt with SAM3.

    Args:
        image: Input RGB image
        text_prompt: Free-text description of the object to select

    Returns:
        tuple: (overlay, mask, status). ``overlay`` is a copy of the image
        with the selected pixels blended 50/50 with green, ``mask`` is a
        uint8 array holding 255 for selected pixels and 0 elsewhere, and
        ``status`` is a short message. ``mask`` is None whenever no
        segmentation was produced (missing input, no object, or an error).
    """
    if image is None:
        return None, None, "❌ No image provided"

    # A prompt consisting only of whitespace carries no information, so it
    # is rejected the same way as an empty or missing prompt.
    prompt = text_prompt.strip() if text_prompt else ""
    if not prompt:
        return None, None, "❌ No text prompt provided"

    try:
        predictor = load_sam3()

        # Run SAM3 with text prompt
        predictor.set_image(image)
        masks, scores, _ = predictor.predict(text=prompt)

        if masks is None or len(masks) == 0:
            return image, None, "⚠️ No object found"

        # Use best mask
        best_idx = np.argmax(scores)
        # Boolean indexing below requires a bool array; `> 0` leaves a bool
        # mask unchanged and also accepts 0/1 float or uint8 masks.
        mask = np.asarray(masks[best_idx]) > 0

        # Create overlay
        overlay = image.copy()
        tint = np.array([0, 255, 0])
        overlay[mask] = (overlay[mask] * 0.5 + tint * 0.5).astype(np.uint8)

        return overlay, mask.astype(np.uint8) * 255, f"✓ Found: {prompt}"

    except Exception as e:
        # UI boundary: log the full traceback and report the failure in the
        # status field instead of propagating it to the caller.
        import traceback
        traceback.print_exc()
        return image, None, f"❌ Error: {e}"
122
+
123
+
124
+ @spaces.GPU(duration=60)
125
+ def segment_with_click(image: np.ndarray, evt: gr.SelectData):
126
+ """Segment object at clicked point with SAM3"""
127
  if image is None:
128
+ return None, None, "❌ No image provided"
129
 
130
  try:
131
+ predictor = load_sam3()
132
 
133
  # Get click coordinates
134
  point = np.array([[evt.index[0], evt.index[1]]])
135
+ label = np.array([1]) # foreground
136
 
 
137
  predictor.set_image(image)
 
 
138
  masks, scores, _ = predictor.predict(
139
  point_coords=point,
140
  point_labels=label,
 
142
  )
143
 
144
  # Use best mask
145
+ best_idx = np.argmax(scores)
146
+ mask = masks[best_idx]
147
 
148
  # Create overlay
149
  overlay = image.copy()
150
+ overlay[mask] = (overlay[mask] * 0.5 + np.array([0, 255, 0]) * 0.5).astype(np.uint8)
151
 
152
+ return overlay, mask.astype(np.uint8) * 255, "βœ“ Object selected"
153
 
154
  except Exception as e:
155
  import traceback
156
  traceback.print_exc()
157
+ return image, None, f"❌ Error: {e}"
158
 
159
 
160
  @spaces.GPU(duration=120)
161
+ def reconstruct_3d(image: np.ndarray, mask: np.ndarray):
162
  """
163
  Reconstruct 3D object from image and mask.
164
 
165
  Args:
166
  image: Input RGB image
167
+ mask: Binary mask from SAM3
168
 
169
  Returns:
170
  tuple: (glb_path, status)
 
172
  if image is None:
173
  return None, "❌ No image provided"
174
  if mask is None:
175
+ return None, "❌ No mask - segment object first"
176
 
177
  try:
178
  import torch
179
  import trimesh
 
180
 
181
+ model = load_sam3d()
 
 
182
 
183
+ # Ensure mask is binary
 
 
 
 
184
  if len(mask.shape) == 3:
185
  mask = mask[:, :, 0]
186
  mask = (mask > 127).astype(np.uint8)
187
 
188
+ # Run 3D reconstruction
189
  outputs = model.predict(image, mask)
190
 
191
  if outputs is None:
192
  return None, "⚠️ Reconstruction failed"
193
 
194
+ # Export as GLB
195
  output_dir = tempfile.mkdtemp()
196
  glb_path = f"{output_dir}/object_{uuid.uuid4().hex[:8]}.glb"
197
 
198
+ # Get vertices from gaussian splat
 
199
  vertices = outputs.get_xyz().cpu().numpy()
200
 
201
+ # Export as point cloud GLB
 
202
  cloud = trimesh.PointCloud(vertices)
203
  cloud.export(glb_path, file_type='glb')
204
 
205
+ return glb_path, f"βœ“ Reconstructed ({len(vertices)} points)"
206
 
207
  except Exception as e:
208
  import traceback
 
212
 
213
  # Gradio Interface
214
  with gr.Blocks(title="SAM 3D Objects MCP") as demo:
215
+ gr.Markdown("""
216
+ # 📦 SAM 3D Objects MCP Server
217
+ **Image → 3D Object (GLB)**
218
+
219
+ 1. Upload image
220
+ 2. Segment: Type what to select OR click on object
221
+ 3. Reconstruct 3D
222
+ """)
223
 
 
224
  mask_state = gr.State(None)
225
 
226
  with gr.Row():
227
  with gr.Column():
228
+ input_image = gr.Image(label="Input Image", type="numpy")
229
+
230
+ with gr.Row():
231
+ text_prompt = gr.Textbox(
232
+ label="Text Prompt",
233
+ placeholder="e.g. 'the chair', 'red car', 'coffee mug'",
234
+ scale=3
235
+ )
236
+ segment_btn = gr.Button("🎯 Segment", scale=1)
237
+
238
+ gr.Markdown("*Or click directly on the object in the image*")
239
 
240
  with gr.Column():
241
+ preview = gr.Image(label="Segmentation Preview", type="numpy", interactive=False)
242
+ status = gr.Textbox(label="Status")
243
 
244
  with gr.Row():
245
+ reconstruct_btn = gr.Button("πŸš€ Reconstruct 3D", variant="primary", size="lg")
246
 
247
  with gr.Row():
248
+ output_model = gr.Model3D(label="3D Preview")
249
+ output_file = gr.File(label="Download GLB")
250
+
251
+ # Events
252
+ segment_btn.click(
253
+ segment_with_text,
254
+ inputs=[input_image, text_prompt],
255
+ outputs=[preview, mask_state, status]
256
+ )
257
 
 
258
  input_image.select(
259
+ segment_with_click,
260
  inputs=[input_image],
261
+ outputs=[preview, mask_state, status]
262
  )
263
 
264
+ reconstruct_btn.click(
265
+ reconstruct_3d,
 
266
  inputs=[input_image, mask_state],
267
  outputs=[output_file, status]
268
  )
 
271
  ---
272
  ### MCP Server
273
  ```json
274
+ {
275
+ "mcpServers": {
276
+ "sam3d-objects": {
277
+ "command": "npx",
278
+ "args": ["mcp-remote", "https://dev-bjoern-sam3d-objects-mcp.hf.space/gradio_api/mcp/sse"]
279
+ }
280
+ }
281
+ }
282
  ```
283
  """)
284