dev-bjoern commited on
Commit
0e828b5
Β·
1 Parent(s): cd9cd46

Replace SAM3 with SAM2 for automatic mask generation

Browse files
Files changed (2) hide show
  1. app.py +13 -18
  2. requirements.txt +1 -3
app.py CHANGED
@@ -2,7 +2,7 @@
2
  SAM 3D Objects MCP Server
3
  Image β†’ 3D Object (GLB)
4
 
5
- Automatic object detection with SAM3 + 3D reconstruction with SAM 3D Objects.
6
  """
7
  import os
8
  import sys
@@ -36,29 +36,24 @@ sys.path.insert(0, str(SAM3D_PATH))
36
 
37
  # Global models
38
  SAM3D_MODEL = None
39
- SAM3_GENERATOR = None
40
 
41
 
42
- def load_sam3():
43
- """Load SAM3 automatic mask generator"""
44
- global SAM3_GENERATOR
45
 
46
- if SAM3_GENERATOR is not None:
47
- return SAM3_GENERATOR
48
 
49
- import torch
50
- from sam3.automatic_mask_generator import SAM3AutomaticMaskGenerator
51
- from sam3.model_builder import build_sam3
52
-
53
- print("Loading SAM3 model...")
54
 
55
- device = "cuda" if torch.cuda.is_available() else "cpu"
56
 
57
- sam3_model = build_sam3(device=device)
58
- SAM3_GENERATOR = SAM3AutomaticMaskGenerator(sam3_model)
59
 
60
- print("βœ“ SAM3 loaded")
61
- return SAM3_GENERATOR
62
 
63
 
64
  def load_sam3d():
@@ -105,7 +100,7 @@ def reconstruct_objects(image: np.ndarray):
105
  from PIL import Image as PILImage
106
 
107
  # Load models
108
- generator = load_sam3()
109
  sam3d = load_sam3d()
110
 
111
  # Convert to PIL if needed
 
2
  SAM 3D Objects MCP Server
3
  Image β†’ 3D Object (GLB)
4
 
5
+ Automatic object detection with SAM2 + 3D reconstruction with SAM 3D Objects.
6
  """
7
  import os
8
  import sys
 
36
 
37
  # Global models
38
  SAM3D_MODEL = None
39
+ SAM2_GENERATOR = None
40
 
41
 
42
+ def load_sam2():
43
+ """Load SAM2 automatic mask generator"""
44
+ global SAM2_GENERATOR
45
 
46
+ if SAM2_GENERATOR is not None:
47
+ return SAM2_GENERATOR
48
 
49
+ from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
 
 
 
 
50
 
51
+ print("Loading SAM2 model...")
52
 
53
+ SAM2_GENERATOR = SAM2AutomaticMaskGenerator.from_pretrained("facebook/sam2-hiera-large")
 
54
 
55
+ print("βœ“ SAM2 loaded")
56
+ return SAM2_GENERATOR
57
 
58
 
59
  def load_sam3d():
 
100
  from PIL import Image as PILImage
101
 
102
  # Load models
103
+ generator = load_sam2()
104
  sam3d = load_sam3d()
105
 
106
  # Convert to PIL if needed
requirements.txt CHANGED
@@ -20,6 +20,4 @@ jaxtyping
20
  rich
21
  kaolin==0.17.0
22
  gsplat
23
- sam3>=0.1.2
24
- open_clip_torch>=2.24.0
25
- ftfy
 
20
  rich
21
  kaolin==0.17.0
22
  gsplat
23
+ sam2>=1.1.0