MogensR committed on
Commit
345218c
·
1 Parent(s): 2eef9e8

Update models/loaders/sam2_loader.py

Browse files
Files changed (1) hide show
  1. models/loaders/sam2_loader.py +73 -4
models/loaders/sam2_loader.py CHANGED
@@ -107,6 +107,7 @@ def _load_official(self) -> Optional[Any]:
107
  cache_dir=self.cache_dir,
108
  local_files_only=False,
109
  trust_remote_code=True,
 
110
  )
111
 
112
  # Move to device and set to eval mode
@@ -115,10 +116,78 @@ def _load_official(self) -> Optional[Any]:
115
  predictor.model.eval()
116
 
117
  # Set device attribute for the predictor
118
- if hasattr(predictor, "device"):
119
- predictor.device = self.device
120
-
121
- return predictor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
  def _load_transformers(self) -> Optional[Any]:
124
  """Load using transformers library"""
 
107
  cache_dir=self.cache_dir,
108
  local_files_only=False,
109
  trust_remote_code=True,
110
+ device=self.device, # Pass device directly
111
  )
112
 
113
  # Move to device and set to eval mode
 
116
  predictor.model.eval()
117
 
118
  # Set device attribute for the predictor
119
+ predictor.device = self.device
120
+
121
+ # Wrap to ensure proper automatic mask generation
122
class SAM2Wrapper:
    """Wrap a SAM2 predictor and supply default prompts when none are given.

    The wrapper forwards ``set_image`` and ``predict`` to the wrapped
    predictor. When ``predict`` is called without point or box prompts it
    substitutes a single positive point at the centre of the image that
    was last passed to ``set_image``. Any attribute not defined here is
    looked up on the wrapped predictor.
    """

    # (height, width) assumed only when the real image size is unknown.
    _FALLBACK_SIZE = (512, 512)

    def __init__(self, predictor, device):
        """Store the wrapped predictor and the device it was loaded on."""
        self.predictor = predictor
        self.device = device
        self._image_set = False
        # (height, width) of the last image given to set_image, if known.
        self._image_hw = None

    @staticmethod
    def _infer_hw(image):
        """Return ``(height, width)`` of *image*, or ``None`` if unknown."""
        # Array-like images expose ``shape`` as (H, W[, C]).
        shape = getattr(image, "shape", None)
        if shape is not None and len(shape) >= 2:
            return int(shape[0]), int(shape[1])
        # Objects exposing ``size`` as a (width, height) pair, e.g. PIL images.
        size = getattr(image, "size", None)
        if isinstance(size, (tuple, list)) and len(size) == 2:
            return int(size[1]), int(size[0])
        return None

    def set_image(self, image):
        """Set the image on the wrapped predictor and remember its size."""
        self.predictor.set_image(image)
        self._image_set = True
        self._image_hw = self._infer_hw(image)

    def predict(self, point_coords=None, point_labels=None, box=None, **kwargs):
        """Run the predictor, defaulting to a centre-point prompt.

        If no image has been set, a warning is logged and a dict holding
        an all-zero placeholder mask is returned. Otherwise the call is
        forwarded to the wrapped predictor and its result is returned
        unchanged.
        """
        if not self._image_set:
            # NOTE(review): this path returns a dict while the normal path
            # returns whatever the predictor returns; kept as-is so existing
            # callers that rely on the dict are not broken.
            logger.warning("Image not set, returning empty mask")
            fallback_h, fallback_w = self._FALLBACK_SIZE
            return {
                "masks": np.zeros((1, fallback_h, fallback_w), dtype=np.float32),
                "scores": np.array([0.0]),
                "logits": np.zeros((1, fallback_h, fallback_w), dtype=np.float32),
            }

        # With no prompts at all, click the centre of the actual image
        # rather than the centre of an assumed 512x512 frame.
        if point_coords is None and box is None:
            h, w = self._image_hw or self._FALLBACK_SIZE
            point_coords = np.array([[w // 2, h // 2]], dtype=np.float32)
            point_labels = np.array([1], dtype=np.int32)

        return self.predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            box=box,
            **kwargs,
        )

    def generate_automatic_masks(self, image):
        """Prompt with a coarse grid of positive points and return the best mask.

        Returns a ``(mask, score)`` pair. If no points could be generated
        or the predictor returns no scores, an all-ones mask with score
        ``1.0`` is returned.
        """
        self.set_image(image)
        h, w = image.shape[:2]

        # Guard the step: h // 2 is 0 for a 1-pixel dimension and
        # range() rejects a zero step.
        y_step = max(1, h // 2)
        x_step = max(1, w // 2)
        points = [
            [x, y]
            for y in range(h // 4, h, y_step)
            for x in range(w // 4, w, x_step)
        ]
        labels = [1] * len(points)

        if points:
            # NOTE(review): all grid points go into a single predict call as
            # positive prompts — confirm this is the intended prompting.
            masks, scores, logits = self.predictor.predict(
                point_coords=np.array(points, dtype=np.float32),
                point_labels=np.array(labels, dtype=np.int32),
                multimask_output=True,
            )

            # Return best mask
            if len(scores) > 0:
                best_idx = int(np.argmax(scores))
                return masks[best_idx], scores[best_idx]

        return np.ones((h, w), dtype=np.float32), 1.0

    def __getattr__(self, name):
        """Forward other attributes to predictor"""
        # __getattr__ only runs after normal lookup fails. If "predictor"
        # itself is missing (instance built via __new__, copy or unpickle),
        # reading self.predictor below would re-enter here forever.
        if name == "predictor":
            raise AttributeError(name)
        return getattr(self.predictor, name)
189
+
190
+ return SAM2Wrapper(predictor, self.device)
191
 
192
  def _load_transformers(self) -> Optional[Any]:
193
  """Load using transformers library"""