MogensR committed on
Commit
6bc7492
·
1 Parent(s): b1313ed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +147 -22
app.py CHANGED
@@ -58,11 +58,12 @@ def setup_gpu():
58
 
59
  logger.info(f"Device: {DEVICE} | GPU: {GPU_NAME} | Memory: {GPU_MEMORY:.1f}GB | Type: {GPU_TYPE}")
60
 
61
- # SAM2 Lazy Loader with Enhanced Performance
62
- class SAM2EnhancedLazy:
63
  def __init__(self):
64
  self.predictor = None
65
  self.current_model_size = None
 
66
  self.model_cache_dir = Path(tempfile.gettempdir()) / "sam2_cache"
67
  self.model_cache_dir.mkdir(exist_ok=True)
68
 
@@ -99,10 +100,114 @@ def clear_model(self):
99
  self.predictor = None
100
  self.current_model_size = None
101
 
 
 
 
 
102
  if CUDA_AVAILABLE:
103
  torch.cuda.empty_cache()
104
  gc.collect()
105
- logger.info("SAM2 model cleared from memory")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
  def download_model(self, model_size, progress_fn=None):
108
  """Download model with progress tracking and verification"""
@@ -128,7 +233,7 @@ def download_model(self, model_size, progress_fn=None):
128
  downloaded += len(chunk)
129
  if progress_fn and total_size > 0:
130
  progress = downloaded / total_size * 0.15 # 15% of total progress
131
- progress_fn(progress, f"Downloading SAM2 {model_size} ({downloaded/1024/1024:.1f}MB/{total_size/1024/1024:.1f}MB)")
132
 
133
  logger.info(f"SAM2 {model_size} downloaded successfully")
134
  return model_path
@@ -142,6 +247,9 @@ def download_model(self, model_size, progress_fn=None):
142
  def load_model(self, model_size, progress_fn=None):
143
  """Load SAM2 model with optimization"""
144
  try:
 
 
 
145
  # Import SAM2 (lazy import to avoid import errors if not available)
146
  try:
147
  from sam2.build_sam import build_sam2
@@ -153,7 +261,7 @@ def load_model(self, model_size, progress_fn=None):
153
  model_path = self.download_model(model_size, progress_fn)
154
 
155
  if progress_fn:
156
- progress_fn(0.2, f"Loading SAM2 {model_size} model...")
157
 
158
  # Build model
159
  model_config = self.models[model_size]["config"]
@@ -168,9 +276,9 @@ def load_model(self, model_size, progress_fn=None):
168
  self.current_model_size = model_size
169
 
170
  if progress_fn:
171
- progress_fn(0.25, f"SAM2 {model_size} loaded successfully!")
172
 
173
- logger.info(f"SAM2 {model_size} model loaded and ready")
174
  return self.predictor
175
 
176
  except Exception as e:
@@ -185,26 +293,35 @@ def get_predictor(self, model_size="small", progress_fn=None):
185
  return self.load_model(model_size, progress_fn)
186
  return self.predictor
187
 
188
- def segment_image(self, image, model_size="small", progress_fn=None):
189
- """Segment image with SAM2"""
190
  predictor = self.get_predictor(model_size, progress_fn)
191
 
192
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  predictor.set_image(image)
194
- h, w = image.shape[:2]
195
-
196
- # Smart point selection for better segmentation
197
- center_points = [
198
- [w//2, h//2], # Center
199
- [w//2, h//3], # Upper center
200
- [w//2, 2*h//3], # Lower center
201
- [w//3, h//2], # Left center
202
- [2*w//3, h//2] # Right center
203
- ]
204
 
205
- point_coords = np.array(center_points)
206
  point_labels = np.ones(len(point_coords))
207
 
 
 
 
208
  masks, scores, logits = predictor.predict(
209
  point_coords=point_coords,
210
  point_labels=point_labels,
@@ -216,15 +333,23 @@ def segment_image(self, image, model_size="small", progress_fn=None):
216
  best_mask = masks[best_mask_idx]
217
  best_score = scores[best_mask_idx]
218
 
219
- # Post-process mask for better edges
220
  kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
221
  best_mask = cv2.morphologyEx(best_mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel)
 
 
222
  best_mask = cv2.GaussianBlur(best_mask.astype(np.float32), (3, 3), 1.0)
223
 
 
 
 
 
 
 
224
  return best_mask, float(best_score)
225
 
226
  except Exception as e:
227
- logger.error(f"Segmentation failed: {e}")
228
  return None, 0.0
229
 
230
  # MatAnyone Professional Video Matting
 
58
 
59
  logger.info(f"Device: {DEVICE} | GPU: {GPU_NAME} | Memory: {GPU_MEMORY:.1f}GB | Type: {GPU_TYPE}")
60
 
61
+ # Enhanced SAM2 with Person Detection and Tracking
62
+ class SAM2WithPersonDetection:
63
  def __init__(self):
64
  self.predictor = None
65
  self.current_model_size = None
66
+ self.person_detector = None
67
  self.model_cache_dir = Path(tempfile.gettempdir()) / "sam2_cache"
68
  self.model_cache_dir.mkdir(exist_ok=True)
69
 
 
100
  self.predictor = None
101
  self.current_model_size = None
102
 
103
+ if self.person_detector:
104
+ del self.person_detector
105
+ self.person_detector = None
106
+
107
  if CUDA_AVAILABLE:
108
  torch.cuda.empty_cache()
109
  gc.collect()
110
+ logger.info("SAM2 model and person detector cleared from memory")
111
+
112
+ def load_person_detector(self, progress_fn=None):
113
+ """Load lightweight person detector"""
114
+ if self.person_detector is not None:
115
+ return self.person_detector
116
+
117
+ try:
118
+ if progress_fn:
119
+ progress_fn(0.05, "Loading person detector...")
120
+
121
+ # Use OpenCV DNN with MobileNet for fast person detection
122
+ import cv2
123
+
124
+ # Create a simple person detector using OpenCV's built-in methods
125
+ # This is lightweight and doesn't require additional models
126
+ self.person_detector = cv2.createBackgroundSubtractorMOG2(detectShadows=True)
127
+
128
+ if progress_fn:
129
+ progress_fn(0.1, "Person detector loaded!")
130
+
131
+ logger.info("Person detector loaded successfully")
132
+ return self.person_detector
133
+
134
+ except Exception as e:
135
+ logger.warning(f"Failed to load person detector: {e}")
136
+ self.person_detector = None
137
+ return None
138
+
139
+ def detect_person_bbox(self, image, progress_fn=None):
140
+ """Detect person bounding box in image"""
141
+ try:
142
+ # Method 1: Use simple contour detection for person-like shapes
143
+ gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
144
+
145
+ # Apply GaussianBlur to reduce noise
146
+ blurred = cv2.GaussianBlur(gray, (5, 5), 0)
147
+
148
+ # Use edge detection to find contours
149
+ edges = cv2.Canny(blurred, 50, 150)
150
+
151
+ # Find contours
152
+ contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
153
+
154
+ if not contours:
155
+ return None
156
+
157
+ # Find the largest contour (likely the main subject)
158
+ largest_contour = max(contours, key=cv2.contourArea)
159
+
160
+ # Get bounding box of largest contour
161
+ x, y, w, h = cv2.boundingRect(largest_contour)
162
+
163
+ # Filter out too small or too large bounding boxes
164
+ image_area = image.shape[0] * image.shape[1]
165
+ bbox_area = w * h
166
+
167
+ # Person should be 5-80% of image
168
+ if bbox_area < image_area * 0.05 or bbox_area > image_area * 0.8:
169
+ return None
170
+
171
+ # Ensure reasonable aspect ratio for person (height > width)
172
+ if h < w * 0.8: # Person should be taller than wide
173
+ return None
174
+
175
+ return [x, y, x + w, y + h]
176
+
177
+ except Exception as e:
178
+ logger.warning(f"Person detection failed: {e}")
179
+ return None
180
+
181
+ def get_smart_points_from_bbox(self, bbox, image_shape):
182
+ """Generate smart points within person bounding box"""
183
+ if bbox is None:
184
+ # Fallback to grid points across entire image
185
+ h, w = image_shape[:2]
186
+ return [
187
+ [w//4, h//3], [w//2, h//3], [3*w//4, h//3],
188
+ [w//4, h//2], [w//2, h//2], [3*w//4, h//2],
189
+ [w//4, 2*h//3], [w//2, 2*h//3], [3*w//4, 2*h//3]
190
+ ]
191
+
192
+ x1, y1, x2, y2 = bbox
193
+ center_x = (x1 + x2) // 2
194
+ center_y = (y1 + y2) // 2
195
+ width = x2 - x1
196
+ height = y2 - y1
197
+
198
+ # Generate points within the person's bounding box
199
+ points = [
200
+ [center_x, center_y], # Center of person
201
+ [center_x, y1 + height//4], # Upper torso/head
202
+ [center_x, y1 + height//2], # Mid torso
203
+ [center_x, y1 + 3*height//4], # Lower torso
204
+ [x1 + width//4, center_y], # Left side
205
+ [x2 - width//4, center_y], # Right side
206
+ [center_x - width//6, y1 + height//3], # Left shoulder area
207
+ [center_x + width//6, y1 + height//3], # Right shoulder area
208
+ ]
209
+
210
+ return points
211
 
212
  def download_model(self, model_size, progress_fn=None):
213
  """Download model with progress tracking and verification"""
 
233
  downloaded += len(chunk)
234
  if progress_fn and total_size > 0:
235
  progress = downloaded / total_size * 0.15 # 15% of total progress
236
+ progress_fn(0.1 + progress, f"Downloading SAM2 {model_size} ({downloaded/1024/1024:.1f}MB/{total_size/1024/1024:.1f}MB)")
237
 
238
  logger.info(f"SAM2 {model_size} downloaded successfully")
239
  return model_path
 
247
  def load_model(self, model_size, progress_fn=None):
248
  """Load SAM2 model with optimization"""
249
  try:
250
+ # Load person detector first
251
+ self.load_person_detector(progress_fn)
252
+
253
  # Import SAM2 (lazy import to avoid import errors if not available)
254
  try:
255
  from sam2.build_sam import build_sam2
 
261
  model_path = self.download_model(model_size, progress_fn)
262
 
263
  if progress_fn:
264
+ progress_fn(0.25, f"Loading SAM2 {model_size} model...")
265
 
266
  # Build model
267
  model_config = self.models[model_size]["config"]
 
276
  self.current_model_size = model_size
277
 
278
  if progress_fn:
279
+ progress_fn(0.3, f"SAM2 {model_size} with person detection ready!")
280
 
281
+ logger.info(f"SAM2 {model_size} model with person detection loaded and ready")
282
  return self.predictor
283
 
284
  except Exception as e:
 
293
  return self.load_model(model_size, progress_fn)
294
  return self.predictor
295
 
296
+ def segment_image_smart(self, image, model_size="small", progress_fn=None):
297
+ """Smart segmentation: Find person first, then segment"""
298
  predictor = self.get_predictor(model_size, progress_fn)
299
 
300
  try:
301
+ if progress_fn:
302
+ progress_fn(0.32, "Finding person in image...")
303
+
304
+ # Step 1: Detect person bounding box
305
+ person_bbox = self.detect_person_bbox(image, progress_fn)
306
+
307
+ if progress_fn:
308
+ if person_bbox:
309
+ progress_fn(0.35, f"Person found! Segmenting with high precision...")
310
+ else:
311
+ progress_fn(0.35, f"Using grid search for segmentation...")
312
+
313
+ # Step 2: Generate smart points based on person location
314
+ smart_points = self.get_smart_points_from_bbox(person_bbox, image.shape)
315
+
316
+ # Step 3: Set image and predict with smart points
317
  predictor.set_image(image)
 
 
 
 
 
 
 
 
 
 
318
 
319
+ point_coords = np.array(smart_points)
320
  point_labels = np.ones(len(point_coords))
321
 
322
+ if progress_fn:
323
+ progress_fn(0.38, f"SAM2 segmenting with {len(smart_points)} smart points...")
324
+
325
  masks, scores, logits = predictor.predict(
326
  point_coords=point_coords,
327
  point_labels=point_labels,
 
333
  best_mask = masks[best_mask_idx]
334
  best_score = scores[best_mask_idx]
335
 
336
+ # Enhanced post-processing for better edges
337
  kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
338
  best_mask = cv2.morphologyEx(best_mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel)
339
+
340
+ # Apply gentle blur for smoother edges
341
  best_mask = cv2.GaussianBlur(best_mask.astype(np.float32), (3, 3), 1.0)
342
 
343
+ # If we found a person bbox, boost confidence
344
+ if person_bbox and best_score > 0.3:
345
+ best_score = min(best_score * 1.5, 1.0) # Boost confidence
346
+
347
+ logger.info(f"Smart segmentation complete: confidence={best_score:.3f}, person_detected={person_bbox is not None}")
348
+
349
  return best_mask, float(best_score)
350
 
351
  except Exception as e:
352
+ logger.error(f"Smart segmentation failed: {e}")
353
  return None, 0.0
354
 
355
  # MatAnyone Professional Video Matting