dev-bjoern committed on
Commit
98d11cc
·
1 Parent(s): 822ace9

Fix SAM3 imports and API usage

Browse files
Files changed (1) hide show
  1. app.py +50 -29
app.py CHANGED
@@ -47,17 +47,13 @@ def load_sam3():
47
  return SAM3_PREDICTOR
48
 
49
  import torch
50
- from sam3 import SAM3Predictor
 
51
 
52
  print("Loading SAM3 model...")
53
 
54
- device = "cuda" if torch.cuda.is_available() else "cpu"
55
-
56
- SAM3_PREDICTOR = SAM3Predictor.from_pretrained(
57
- "facebook/sam3-hiera-large",
58
- device=device,
59
- token=os.environ.get("HF_TOKEN")
60
- )
61
 
62
  print("✓ SAM3 loaded")
63
  return SAM3_PREDICTOR
@@ -96,24 +92,37 @@ def segment_with_text(image: np.ndarray, text_prompt: str):
96
  return None, None, "❌ No text prompt provided"
97
 
98
  try:
99
- predictor = load_sam3()
 
 
 
 
 
 
 
100
 
101
  # Run SAM3 with text prompt
102
- predictor.set_image(image)
103
- masks, scores, _ = predictor.predict(text=text_prompt)
 
 
 
104
 
105
- if masks is None or len(masks) == 0:
 
 
 
106
  return image, None, "⚠️ No object found"
107
 
108
  # Use best mask
109
- best_idx = np.argmax(scores)
110
- mask = masks[best_idx]
111
 
112
  # Create overlay
113
  overlay = image.copy()
114
- overlay[mask] = (overlay[mask] * 0.5 + np.array([0, 255, 0]) * 0.5).astype(np.uint8)
115
 
116
- return overlay, mask.astype(np.uint8) * 255, f"✓ Found: {text_prompt}"
117
 
118
  except Exception as e:
119
  import traceback
@@ -128,28 +137,40 @@ def segment_with_click(image: np.ndarray, evt: gr.SelectData):
128
  return None, None, "❌ No image provided"
129
 
130
  try:
131
- predictor = load_sam3()
 
 
 
 
 
 
 
132
 
133
  # Get click coordinates
134
- point = np.array([[evt.index[0], evt.index[1]]])
135
- label = np.array([1]) # foreground
136
 
137
- predictor.set_image(image)
138
- masks, scores, _ = predictor.predict(
139
- point_coords=point,
140
- point_labels=label,
141
- multimask_output=True
142
- )
 
 
 
 
 
 
143
 
144
  # Use best mask
145
- best_idx = np.argmax(scores)
146
- mask = masks[best_idx]
147
 
148
  # Create overlay
149
  overlay = image.copy()
150
- overlay[mask] = (overlay[mask] * 0.5 + np.array([0, 255, 0]) * 0.5).astype(np.uint8)
151
 
152
- return overlay, mask.astype(np.uint8) * 255, "✓ Object selected"
153
 
154
  except Exception as e:
155
  import traceback
 
47
  return SAM3_PREDICTOR
48
 
49
  import torch
50
+ from sam3.model_builder import build_sam3_image_model
51
+ from sam3.model.sam3_image_processor import Sam3Processor
52
 
53
  print("Loading SAM3 model...")
54
 
55
+ model = build_sam3_image_model()
56
+ SAM3_PREDICTOR = Sam3Processor(model)
 
 
 
 
 
57
 
58
  print("✓ SAM3 loaded")
59
  return SAM3_PREDICTOR
 
92
  return None, None, "❌ No text prompt provided"
93
 
94
  try:
95
+ from PIL import Image as PILImage
96
+ processor = load_sam3()
97
+
98
+ # Convert to PIL
99
+ if isinstance(image, np.ndarray):
100
+ pil_image = PILImage.fromarray(image)
101
+ else:
102
+ pil_image = image
103
 
104
  # Run SAM3 with text prompt
105
+ state = processor.set_image(pil_image)
106
+ output = processor.set_text_prompt(state=state, prompt=text_prompt)
107
+
108
+ if output is None or "masks" not in output:
109
+ return image, None, "⚠️ No object found"
110
 
111
+ masks = output["masks"]
112
+ scores = output.get("scores", [1.0])
113
+
114
+ if len(masks) == 0:
115
  return image, None, "⚠️ No object found"
116
 
117
  # Use best mask
118
+ best_idx = np.argmax(scores) if len(scores) > 0 else 0
119
+ mask = np.array(masks[best_idx])
120
 
121
  # Create overlay
122
  overlay = image.copy()
123
+ overlay[mask > 0] = (overlay[mask > 0] * 0.5 + np.array([0, 255, 0]) * 0.5).astype(np.uint8)
124
 
125
+ return overlay, (mask > 0).astype(np.uint8) * 255, f"✓ Found: {text_prompt}"
126
 
127
  except Exception as e:
128
  import traceback
 
137
  return None, None, "❌ No image provided"
138
 
139
  try:
140
+ from PIL import Image as PILImage
141
+ processor = load_sam3()
142
+
143
+ # Convert to PIL
144
+ if isinstance(image, np.ndarray):
145
+ pil_image = PILImage.fromarray(image)
146
+ else:
147
+ pil_image = image
148
 
149
  # Get click coordinates
150
+ point = [evt.index[0], evt.index[1]]
 
151
 
152
+ # Run SAM3 with point prompt
153
+ state = processor.set_image(pil_image)
154
+ output = processor.set_point_prompt(state=state, points=[point], labels=[1])
155
+
156
+ if output is None or "masks" not in output:
157
+ return image, None, "⚠️ No object found"
158
+
159
+ masks = output["masks"]
160
+ scores = output.get("scores", [1.0])
161
+
162
+ if len(masks) == 0:
163
+ return image, None, "⚠️ No object found"
164
 
165
  # Use best mask
166
+ best_idx = np.argmax(scores) if len(scores) > 0 else 0
167
+ mask = np.array(masks[best_idx])
168
 
169
  # Create overlay
170
  overlay = image.copy()
171
+ overlay[mask > 0] = (overlay[mask > 0] * 0.5 + np.array([0, 255, 0]) * 0.5).astype(np.uint8)
172
 
173
+ return overlay, (mask > 0).astype(np.uint8) * 255, "✓ Object selected"
174
 
175
  except Exception as e:
176
  import traceback