MogensR committed on
Commit
9e03b6b
·
1 Parent(s): b3a57d5

Update models/loaders/sam2_loader.py

Browse files
Files changed (1) hide show
  1. models/loaders/sam2_loader.py +8 -75
models/loaders/sam2_loader.py CHANGED
@@ -99,7 +99,7 @@ def _determine_optimal_size(self) -> str:
99
  return "tiny" # Conservative default
100
 
101
  def _load_official(self) -> Optional[Any]:
102
- """Load using official SAM2 API"""
103
  from sam2.sam2_image_predictor import SAM2ImagePredictor
104
 
105
  predictor = SAM2ImagePredictor.from_pretrained(
@@ -107,7 +107,6 @@ def _load_official(self) -> Optional[Any]:
107
  cache_dir=self.cache_dir,
108
  local_files_only=False,
109
  trust_remote_code=True,
110
- device=self.device, # Pass device directly
111
  )
112
 
113
  # Move to device and set to eval mode
@@ -115,79 +114,13 @@ def _load_official(self) -> Optional[Any]:
115
  predictor.model = predictor.model.to(self.device)
116
  predictor.model.eval()
117
 
118
- # Set device attribute for the predictor
119
- predictor.device = self.device
120
-
121
- # Wrap to ensure proper automatic mask generation
122
- class SAM2Wrapper:
123
- def __init__(self, predictor, device):
124
- self.predictor = predictor
125
- self.device = device
126
- self._image_set = False
127
-
128
- def set_image(self, image):
129
- """Set image for processing"""
130
- self.predictor.set_image(image)
131
- self._image_set = True
132
-
133
- def predict(self, point_coords=None, point_labels=None, box=None, **kwargs):
134
- """Generate masks with automatic detection if no prompts given"""
135
- if not self._image_set:
136
- # Auto-set image if not already done
137
- logger.warning("Image not set, returning empty mask")
138
- return {
139
- "masks": np.zeros((1, 512, 512), dtype=np.float32),
140
- "scores": np.array([0.0]),
141
- "logits": np.zeros((1, 512, 512), dtype=np.float32),
142
- }
143
-
144
- # If no prompts, generate automatic mask
145
- if point_coords is None and box is None:
146
- # Use center point as default
147
- h, w = 512, 512 # Default size
148
- point_coords = np.array([[w//2, h//2]], dtype=np.float32)
149
- point_labels = np.array([1], dtype=np.int32)
150
-
151
- return self.predictor.predict(
152
- point_coords=point_coords,
153
- point_labels=point_labels,
154
- box=box,
155
- **kwargs
156
- )
157
-
158
- def generate_automatic_masks(self, image):
159
- """Generate masks automatically for the entire image"""
160
- self.set_image(image)
161
- # Generate with points in a grid
162
- h, w = image.shape[:2]
163
- points = []
164
- labels = []
165
-
166
- # Create a grid of points
167
- for y in range(h//4, h, h//2):
168
- for x in range(w//4, w, w//2):
169
- points.append([x, y])
170
- labels.append(1)
171
-
172
- if points:
173
- masks, scores, logits = self.predictor.predict(
174
- point_coords=np.array(points, dtype=np.float32),
175
- point_labels=np.array(labels, dtype=np.int32),
176
- multimask_output=True
177
- )
178
-
179
- # Return best mask
180
- if len(scores) > 0:
181
- best_idx = scores.argmax()
182
- return masks[best_idx], scores[best_idx]
183
-
184
- return np.ones((h, w), dtype=np.float32), 1.0
185
-
186
- def __getattr__(self, name):
187
- """Forward other attributes to predictor"""
188
- return getattr(self.predictor, name)
189
-
190
- return SAM2Wrapper(predictor, self.device)
191
 
192
  def _load_transformers(self) -> Optional[Any]:
193
  """Load using transformers library"""
 
99
  return "tiny" # Conservative default
100
 
101
  def _load_official(self) -> Optional[Any]:
102
+ """Load using official SAM2 API - return directly without wrapper"""
103
  from sam2.sam2_image_predictor import SAM2ImagePredictor
104
 
105
  predictor = SAM2ImagePredictor.from_pretrained(
 
107
  cache_dir=self.cache_dir,
108
  local_files_only=False,
109
  trust_remote_code=True,
 
110
  )
111
 
112
  # Move to device and set to eval mode
 
114
  predictor.model = predictor.model.to(self.device)
115
  predictor.model.eval()
116
 
117
+ # Set device attribute if it exists
118
+ if hasattr(predictor, "device"):
119
+ predictor.device = self.device
120
+
121
+ # Return the predictor directly - no wrapper!
122
+ # The calling code expects the standard SAM2 interface
123
+ return predictor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
  def _load_transformers(self) -> Optional[Any]:
126
  """Load using transformers library"""