dev-bjoern commited on
Commit
cd9cd46
Β·
1 Parent(s): c82fe65

Auto object detection: SAM3 finds objects automatically, no click needed

Browse files
Files changed (1) hide show
  1. app.py +55 -173
app.py CHANGED
@@ -1,8 +1,8 @@
1
  """
2
  SAM 3D Objects MCP Server
3
- Image + Text/Click β†’ 3D Object (GLB)
4
 
5
- Uses SAM3 for segmentation and SAM 3D Objects for 3D reconstruction.
6
  """
7
  import os
8
  import sys
@@ -36,27 +36,29 @@ sys.path.insert(0, str(SAM3D_PATH))
36
 
37
  # Global models
38
  SAM3D_MODEL = None
39
- SAM3_PREDICTOR = None
40
 
41
 
42
  def load_sam3():
43
- """Load SAM3 for segmentation"""
44
- global SAM3_PREDICTOR
45
 
46
- if SAM3_PREDICTOR is not None:
47
- return SAM3_PREDICTOR
48
 
49
  import torch
50
- from sam3.model_builder import build_sam3_image_model
51
- from sam3.model.sam3_image_processor import Sam3Processor
52
 
53
  print("Loading SAM3 model...")
54
 
55
- model = build_sam3_image_model()
56
- SAM3_PREDICTOR = Sam3Processor(model)
 
 
57
 
58
  print("βœ“ SAM3 loaded")
59
- return SAM3_PREDICTOR
60
 
61
 
62
  def load_sam3d():
@@ -83,142 +85,58 @@ def load_sam3d():
83
  return SAM3D_MODEL
84
 
85
 
86
- @spaces.GPU(duration=60)
87
- def segment_with_text(image: np.ndarray, text_prompt: str):
88
- """Segment object using text prompt with SAM3"""
89
- if image is None:
90
- return None, None, "❌ No image provided"
91
- if not text_prompt:
92
- return None, None, "❌ No text prompt provided"
93
-
94
- try:
95
- from PIL import Image as PILImage
96
- processor = load_sam3()
97
-
98
- # Convert to PIL
99
- if isinstance(image, np.ndarray):
100
- pil_image = PILImage.fromarray(image)
101
- else:
102
- pil_image = image
103
-
104
- # Run SAM3 with text prompt
105
- state = processor.set_image(pil_image)
106
- output = processor.set_text_prompt(state=state, prompt=text_prompt)
107
-
108
- if output is None or "masks" not in output:
109
- return image, None, "⚠️ No object found"
110
-
111
- masks = output["masks"]
112
- scores = output.get("scores", [1.0])
113
-
114
- if len(masks) == 0:
115
- return image, None, "⚠️ No object found"
116
-
117
- # Use best mask
118
- best_idx = np.argmax(scores) if len(scores) > 0 else 0
119
- mask = np.array(masks[best_idx])
120
-
121
- # Create overlay
122
- overlay = image.copy()
123
- overlay[mask > 0] = (overlay[mask > 0] * 0.5 + np.array([0, 255, 0]) * 0.5).astype(np.uint8)
124
-
125
- return overlay, (mask > 0).astype(np.uint8) * 255, f"βœ“ Found: {text_prompt}"
126
-
127
- except Exception as e:
128
- import traceback
129
- traceback.print_exc()
130
- return image, None, f"❌ Error: {e}"
131
-
132
-
133
- def handle_click(image, evt: gr.SelectData):
134
- """Handle click event and extract coordinates"""
135
- if image is None or evt is None:
136
- return None, None, None, "❌ Click on an image first"
137
- # Store coordinates and pass to GPU function
138
- x, y = evt.index[0], evt.index[1]
139
- return image, x, y, "Processing..."
140
 
 
 
141
 
142
- @spaces.GPU(duration=60)
143
- def segment_with_point(image: np.ndarray, x: int, y: int):
144
- """Segment object at point with SAM3"""
145
  if image is None:
146
  return None, None, "❌ No image provided"
147
- if x is None or y is None:
148
- return None, None, "❌ No point selected"
149
 
150
  try:
 
 
151
  from PIL import Image as PILImage
152
- processor = load_sam3()
153
 
154
- # Convert to PIL
 
 
 
 
155
  if isinstance(image, np.ndarray):
156
  pil_image = PILImage.fromarray(image)
157
  else:
158
  pil_image = image
 
159
 
160
- # Run SAM3 with point prompt
161
- state = processor.set_image(pil_image)
162
- output = processor.set_point_prompt(state=state, points=[[x, y]], labels=[1])
163
 
164
- if output is None or "masks" not in output:
165
- return image, None, "⚠️ No object found"
166
 
167
- masks = output["masks"]
168
- scores = output.get("scores", [1.0])
 
169
 
170
- if len(masks) == 0:
171
- return image, None, "⚠️ No object found"
 
172
 
173
- # Use best mask
174
- best_idx = np.argmax(scores) if len(scores) > 0 else 0
175
- mask = np.array(masks[best_idx])
176
-
177
- # Create overlay
178
- overlay = image.copy()
179
- overlay[mask > 0] = (overlay[mask > 0] * 0.5 + np.array([0, 255, 0]) * 0.5).astype(np.uint8)
180
-
181
- return overlay, (mask > 0).astype(np.uint8) * 255, "βœ“ Object selected"
182
-
183
- except Exception as e:
184
- import traceback
185
- traceback.print_exc()
186
- return image, None, f"❌ Error: {e}"
187
-
188
-
189
- @spaces.GPU(duration=120)
190
- def reconstruct_3d(image: np.ndarray, mask: np.ndarray):
191
- """
192
- Reconstruct 3D object from image and mask.
193
-
194
- Args:
195
- image: Input RGB image
196
- mask: Binary mask from SAM3
197
-
198
- Returns:
199
- tuple: (glb_path, status)
200
- """
201
- if image is None:
202
- return None, "❌ No image provided"
203
- if mask is None:
204
- return None, "❌ No mask - segment object first"
205
-
206
- try:
207
- import torch
208
- import trimesh
209
-
210
- model = load_sam3d()
211
-
212
- # Ensure mask is binary
213
- if len(mask.shape) == 3:
214
- mask = mask[:, :, 0]
215
- mask = (mask > 127).astype(np.uint8)
216
-
217
- # Run 3D reconstruction
218
- outputs = model.predict(image, mask)
219
 
220
  if outputs is None:
221
- return None, "⚠️ Reconstruction failed"
222
 
223
  # Export as GLB
224
  output_dir = tempfile.mkdtemp()
@@ -231,12 +149,12 @@ def reconstruct_3d(image: np.ndarray, mask: np.ndarray):
231
  cloud = trimesh.PointCloud(vertices)
232
  cloud.export(glb_path, file_type='glb')
233
 
234
- return glb_path, f"βœ“ Reconstructed ({len(vertices)} points)"
235
 
236
  except Exception as e:
237
  import traceback
238
  traceback.print_exc()
239
- return None, f"❌ Error: {e}"
240
 
241
 
242
  # Gradio Interface
@@ -245,64 +163,28 @@ with gr.Blocks(title="SAM 3D Objects MCP") as demo:
245
  # πŸ“¦ SAM 3D Objects MCP Server
246
  **Image β†’ 3D Object (GLB)**
247
 
248
- 1. Upload image
249
- 2. Segment: Type what to select OR click on object
250
- 3. Reconstruct 3D
251
  """)
252
 
253
- mask_state = gr.State(None)
254
- click_x = gr.State(None)
255
- click_y = gr.State(None)
256
-
257
  with gr.Row():
258
  with gr.Column():
259
  input_image = gr.Image(label="Input Image", type="numpy")
260
-
261
- with gr.Row():
262
- text_prompt = gr.Textbox(
263
- label="Text Prompt",
264
- placeholder="e.g. 'the chair', 'red car', 'coffee mug'",
265
- scale=3
266
- )
267
- segment_btn = gr.Button("🎯 Segment", scale=1)
268
-
269
- gr.Markdown("*Or click directly on the object in the image*")
270
 
271
  with gr.Column():
272
- preview = gr.Image(label="Segmentation Preview", type="numpy", interactive=False)
273
  status = gr.Textbox(label="Status")
274
 
275
- with gr.Row():
276
- reconstruct_btn = gr.Button("πŸš€ Reconstruct 3D", variant="primary", size="lg")
277
-
278
  with gr.Row():
279
  with gr.Column():
280
  output_model = gr.Model3D(label="3D Preview")
281
  with gr.Column():
282
  output_file = gr.File(label="Download GLB")
283
 
284
- # Events
285
- segment_btn.click(
286
- segment_with_text,
287
- inputs=[input_image, text_prompt],
288
- outputs=[preview, mask_state, status]
289
- )
290
-
291
- # Click handler: first extract coordinates (no GPU), then segment (GPU)
292
- input_image.select(
293
- handle_click,
294
  inputs=[input_image],
295
- outputs=[input_image, click_x, click_y, status]
296
- ).then(
297
- segment_with_point,
298
- inputs=[input_image, click_x, click_y],
299
- outputs=[preview, mask_state, status]
300
- )
301
-
302
- reconstruct_btn.click(
303
- reconstruct_3d,
304
- inputs=[input_image, mask_state],
305
- outputs=[output_model, status]
306
  )
307
  output_model.change(lambda x: x, inputs=[output_model], outputs=[output_file])
308
 
 
1
  """
2
  SAM 3D Objects MCP Server
3
+ Image β†’ 3D Object (GLB)
4
 
5
+ Automatic object detection with SAM3 + 3D reconstruction with SAM 3D Objects.
6
  """
7
  import os
8
  import sys
 
36
 
37
  # Global models
38
  SAM3D_MODEL = None
39
+ SAM3_GENERATOR = None
40
 
41
 
42
  def load_sam3():
43
+ """Load SAM3 automatic mask generator"""
44
+ global SAM3_GENERATOR
45
 
46
+ if SAM3_GENERATOR is not None:
47
+ return SAM3_GENERATOR
48
 
49
  import torch
50
+ from sam3.automatic_mask_generator import SAM3AutomaticMaskGenerator
51
+ from sam3.model_builder import build_sam3
52
 
53
  print("Loading SAM3 model...")
54
 
55
+ device = "cuda" if torch.cuda.is_available() else "cpu"
56
+
57
+ sam3_model = build_sam3(device=device)
58
+ SAM3_GENERATOR = SAM3AutomaticMaskGenerator(sam3_model)
59
 
60
  print("βœ“ SAM3 loaded")
61
+ return SAM3_GENERATOR
62
 
63
 
64
  def load_sam3d():
 
85
  return SAM3D_MODEL
86
 
87
 
88
+ @spaces.GPU(duration=120)
89
+ def reconstruct_objects(image: np.ndarray):
90
+ """
91
+ Automatically detect and reconstruct 3D objects from image.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
+ Args:
94
+ image: Input RGB image
95
 
96
+ Returns:
97
+ tuple: (glb_path, preview_image, status)
98
+ """
99
  if image is None:
100
  return None, None, "❌ No image provided"
 
 
101
 
102
  try:
103
+ import torch
104
+ import trimesh
105
  from PIL import Image as PILImage
 
106
 
107
+ # Load models
108
+ generator = load_sam3()
109
+ sam3d = load_sam3d()
110
+
111
+ # Convert to PIL if needed
112
  if isinstance(image, np.ndarray):
113
  pil_image = PILImage.fromarray(image)
114
  else:
115
  pil_image = image
116
+ image = np.array(pil_image)
117
 
118
+ # Auto-detect all objects
119
+ print("Detecting objects...")
120
+ masks = generator.generate(pil_image)
121
 
122
+ if not masks or len(masks) == 0:
123
+ return None, image, "⚠️ No objects detected"
124
 
125
+ # Sort by area, take largest object
126
+ masks = sorted(masks, key=lambda x: x['area'], reverse=True)
127
+ best_mask = masks[0]['segmentation']
128
 
129
+ # Create preview with mask overlay
130
+ preview = image.copy()
131
+ preview[best_mask] = (preview[best_mask] * 0.5 + np.array([0, 255, 0]) * 0.5).astype(np.uint8)
132
 
133
+ # Run 3D reconstruction on largest object
134
+ print("Reconstructing 3D...")
135
+ mask_uint8 = best_mask.astype(np.uint8)
136
+ outputs = sam3d.predict(image, mask_uint8)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
  if outputs is None:
139
+ return None, preview, "⚠️ 3D reconstruction failed"
140
 
141
  # Export as GLB
142
  output_dir = tempfile.mkdtemp()
 
149
  cloud = trimesh.PointCloud(vertices)
150
  cloud.export(glb_path, file_type='glb')
151
 
152
+ return glb_path, preview, f"βœ“ Detected {len(masks)} objects, reconstructed largest ({len(vertices)} points)"
153
 
154
  except Exception as e:
155
  import traceback
156
  traceback.print_exc()
157
+ return None, None, f"❌ Error: {e}"
158
 
159
 
160
  # Gradio Interface
 
163
  # πŸ“¦ SAM 3D Objects MCP Server
164
  **Image β†’ 3D Object (GLB)**
165
 
166
+ Automatically detects objects and reconstructs the largest one in 3D.
 
 
167
  """)
168
 
 
 
 
 
169
  with gr.Row():
170
  with gr.Column():
171
  input_image = gr.Image(label="Input Image", type="numpy")
172
+ btn = gr.Button("πŸš€ Detect & Reconstruct", variant="primary", size="lg")
 
 
 
 
 
 
 
 
 
173
 
174
  with gr.Column():
175
+ preview = gr.Image(label="Detected Object", type="numpy", interactive=False)
176
  status = gr.Textbox(label="Status")
177
 
 
 
 
178
  with gr.Row():
179
  with gr.Column():
180
  output_model = gr.Model3D(label="3D Preview")
181
  with gr.Column():
182
  output_file = gr.File(label="Download GLB")
183
 
184
+ btn.click(
185
+ reconstruct_objects,
 
 
 
 
 
 
 
 
186
  inputs=[input_image],
187
+ outputs=[output_model, preview, status]
 
 
 
 
 
 
 
 
 
 
188
  )
189
  output_model.change(lambda x: x, inputs=[output_model], outputs=[output_file])
190