MogensR committed on
Commit
6fdc616
·
1 Parent(s): 0d03497

Create utils/segmentation.py

Browse files
Files changed (1) hide show
  1. utils/segmentation.py +293 -0
utils/segmentation.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ utils.segmentation
4
+ ─────────────────────────────────────────────────────────────────────────────
5
+ All high-quality person-segmentation code for BackgroundFX Pro.
6
+
7
+ Exports
8
+ -------
9
+ segment_person_hq(image, predictor, fallback_enabled=True) → np.ndarray
10
+ segment_person_hq_original(image, predictor, fallback_enabled=True) → np.ndarray
11
+
12
+ Everything else is prefixed “_” and considered private.
13
+ """
14
+
15
+ from __future__ import annotations
16
+ from typing import Any, Tuple, Optional, Dict
17
+ import logging, os, math
18
+
19
+ import cv2
20
+ import numpy as np
21
+ import torch
22
+
23
+ log = logging.getLogger(__name__)
24
+
25
+ # ============================================================================
26
+ # TUNABLE CONSTANTS
27
+ # ============================================================================
28
+ USE_ENHANCED_SEGMENTATION = True
29
+ USE_INTELLIGENT_PROMPTING = True
30
+ USE_ITERATIVE_REFINEMENT = True
31
+
32
+ MIN_AREA_RATIO = 0.015
33
+ MAX_AREA_RATIO = 0.97
34
+ SALIENCY_THRESH = 0.65
35
+ GRABCUT_ITERS = 3
36
+
37
+ # ----------------------------------------------------------------------------
38
+ # Public -- main entry-points
39
+ # ----------------------------------------------------------------------------
40
+ __all__ = [
41
+ "segment_person_hq",
42
+ "segment_person_hq_original",
43
+ ]
44
+
45
+ # ============================================================================
46
+ # MAIN API
47
+ # ============================================================================
48
+
49
def segment_person_hq(image: np.ndarray, predictor: Any, fallback_enabled: bool = True) -> np.ndarray:
    """
    High-quality person segmentation.

    Tries SAM-2 with smart prompts first, then a classical CV cascade,
    then a geometric fallback.

    Parameters
    ----------
    image : np.ndarray
        Input frame (assumed BGR, as the cv2 helpers downstream convert
        with COLOR_BGR2GRAY — confirm against callers).
    predictor : Any
        Duck-typed SAM-2 predictor exposing ``set_image`` and ``predict``;
        may be None, in which case the SAM-2 path is skipped.
    fallback_enabled : bool
        When True (default), the function never raises after input
        validation — it degrades through the classical cascade down to a
        geometric ellipse. When False, failure of the SAM-2 path raises.

    Returns
    -------
    np.ndarray
        uint8 mask (0/255) with the same height/width as ``image``.

    Raises
    ------
    ValueError
        If ``image`` is None or empty.
    RuntimeError
        If the SAM-2 path fails and ``fallback_enabled`` is False.
    """
    if not USE_ENHANCED_SEGMENTATION:
        return segment_person_hq_original(image, predictor, fallback_enabled)

    if image is None or image.size == 0:
        raise ValueError("Invalid input image")

    # 1) — SAM-2 path -------------------------------------------------------
    if predictor and hasattr(predictor, "set_image") and hasattr(predictor, "predict"):
        try:
            predictor.set_image(image)
            mask = (
                _segment_with_intelligent_prompts(image, predictor)
                if USE_INTELLIGENT_PROMPTING
                else _segment_with_basic_prompts(image, predictor)
            )
            if USE_ITERATIVE_REFINEMENT:
                mask = _auto_refine_mask_iteratively(image, mask, predictor)
            if _validate_mask_quality(mask, image.shape[:2]):
                return mask
            log.warning("SAM2 mask failed validation → fallback")
        except Exception as e:
            log.warning(f"SAM2 path failed: {e}")

    # BUGFIX: the original ignored fallback_enabled here and silently fell
    # back even when the caller had disabled fallbacks.
    if not fallback_enabled:
        raise RuntimeError("SAM2 failed and fallback disabled")

    # 2) — Classical cascade ----------------------------------------------
    try:
        mask = _classical_segmentation_cascade(image)
        if _validate_mask_quality(mask, image.shape[:2]):
            return mask
        log.warning("Classical cascade weak → geometric fallback")
    except Exception as e:
        log.debug(f"Classical cascade error: {e}")

    # 3) — Last-chance geometric ellipse ----------------------------------
    return _geometric_person_mask(image)
89
+
90
+
91
def segment_person_hq_original(image: np.ndarray, predictor: Any, fallback_enabled: bool = True) -> np.ndarray:
    """
    Very first implementation kept for rollback. Fewer smarts, still robust.

    Prompts SAM-2 with a fixed cross of five centre points (assumes the
    subject is roughly centred), validates the best-scoring mask, and falls
    back to the classical cascade when allowed.

    Raises
    ------
    ValueError
        If ``image`` is None or empty.
    RuntimeError
        If SAM-2 fails and ``fallback_enabled`` is False.
    """
    if image is None or image.size == 0:
        raise ValueError("Invalid input image")

    try:
        if predictor and hasattr(predictor, "set_image") and hasattr(predictor, "predict"):
            h, w = image.shape[:2]
            predictor.set_image(image)

            # Fixed cross of prompt points around the frame centre.
            points = np.array([
                [w//2, h//4],
                [w//2, h//2],
                [w//2, 3*h//4],
                [w//3, h//2],
                [2*w//3, h//2],
            ], dtype=np.float32)
            labels = np.ones(len(points), np.int32)

            with torch.no_grad():
                masks, scores, _ = predictor.predict(
                    point_coords=points,
                    point_labels=labels,
                    multimask_output=True,
                )
            if masks is not None and len(masks):
                mask = _process_mask(masks[int(np.argmax(scores))])
                if _validate_mask_quality(mask, image.shape[:2]):
                    return mask
    except Exception as e:
        log.warning(f"segment_person_hq_original error: {e}")

    # BUGFIX: the original raised "fallback disabled" *inside* the try block,
    # so its own `except Exception` swallowed the error and fell back anyway.
    # Deciding outside the try honours the flag.
    if fallback_enabled:
        return _classical_segmentation_cascade(image)
    raise RuntimeError("SAM2 failed and fallback disabled")
128
+
129
+
130
+ # ============================================================================
131
+ # INTELLIGENT + BASIC PROMPTING
132
+ # ============================================================================
133
+
134
def _segment_with_intelligent_prompts(image: np.ndarray, predictor: Any) -> np.ndarray:
    """Run SAM-2 using saliency-derived positive/negative point prompts."""
    positives, negatives = _generate_smart_prompts(image)
    return _sam2_predict(image, predictor, positives, negatives)
137
+
138
+
139
def _segment_with_basic_prompts(image: np.ndarray, predictor: Any) -> np.ndarray:
    """Run SAM-2 with a fixed centre column of positives and corner negatives."""
    h, w = image.shape[:2]
    centre_x = w // 2
    positives = np.array(
        [[centre_x, h // 3], [centre_x, h // 2], [centre_x, 2 * h // 3]],
        np.float32,
    )
    # Four near-corner points mark likely background.
    corner_pts = [[10, 10], [w - 10, 10], [10, h - 10], [w - 10, h - 10]]
    negatives = np.array(corner_pts, np.float32)
    return _sam2_predict(image, predictor, positives, negatives)
144
+
145
+
146
def _sam2_predict(image: np.ndarray, predictor: Any,
                  pos_points: np.ndarray, neg_points: np.ndarray) -> np.ndarray:
    """
    Invoke ``predictor.predict`` with labelled point prompts and return the
    best-scoring mask, normalised to uint8 0/255 via ``_process_mask``.

    Raises RuntimeError when the predictor yields no masks.
    """
    if pos_points.size == 0:
        # Degenerate prompt set — fall back to a single centre point.
        centre = [[image.shape[1] // 2, image.shape[0] // 2]]
        pos_points = np.array(centre, np.float32)
    all_points = np.vstack([pos_points, neg_points])
    pos_labels = np.ones(len(pos_points))
    neg_labels = np.zeros(len(neg_points))
    all_labels = np.hstack([pos_labels, neg_labels]).astype(np.int32)
    with torch.no_grad():
        masks, scores, _ = predictor.predict(
            point_coords=all_points,
            point_labels=all_labels,
            multimask_output=True,
        )
    if masks is None or len(masks) == 0:
        raise RuntimeError("SAM2 produced no masks")
    if scores is not None:
        best = masks[int(np.argmax(scores))]
    else:
        best = masks[0]
    return _process_mask(best)
162
+
163
+
164
def _generate_smart_prompts(image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Auto-place positive/negative point prompts from a saliency map.

    Positives are the centroids of up to three largest salient blobs;
    negatives are four near-corner points. Falls back to the image centre
    when saliency is unavailable or yields no usable blob.
    """
    h, w = image.shape[:2]
    saliency = _compute_saliency(image)
    positives: list = []
    if saliency is not None:
        # Slightly looser threshold than SALIENCY_THRESH to keep blobs intact.
        salient = ((saliency > (SALIENCY_THRESH - .1)) * 255).astype(np.uint8)
        contours, _ = cv2.findContours(salient, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        biggest = sorted(contours, key=cv2.contourArea, reverse=True)[:3]
        for contour in biggest:
            moments = cv2.moments(contour)
            if moments["m00"]:
                cx = int(moments["m10"] / moments["m00"])
                cy = int(moments["m01"] / moments["m00"])
                positives.append([cx, cy])
    if not positives:
        positives = [[w // 2, h // 2]]
    negatives = [[10, 10], [w - 10, 10], [10, h - 10], [w - 10, h - 10]]
    return np.asarray(positives, np.float32), np.asarray(negatives, np.float32)
182
+
183
+ # ============================================================================
184
+ # CLASSICAL SEGMENTATION CASCADE
185
+ # ============================================================================
186
+
187
def _classical_segmentation_cascade(image: np.ndarray) -> np.ndarray:
    """
    Pure-CV fallback chain: edge-median background subtraction, then
    saliency flood-fill, then GrabCut, then a geometric ellipse.
    Each stage is accepted as soon as it passes the area-ratio check.
    """
    frame_hw = image.shape[:2]
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    # Background intensity estimate = median of the frame border pixels.
    border = np.concatenate([gray[0], gray[-1], gray[:, 0], gray[:, -1]])
    distance = np.abs(gray.astype(float) - np.median(border))
    mask = (distance > 30).astype(np.uint8) * 255
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
    if _validate_mask_quality(mask, frame_hw):
        return mask

    # Saliency + flood-fill
    mask = _refine_with_saliency(image, mask)
    if _validate_mask_quality(mask, frame_hw):
        return mask

    # GrabCut
    mask = _refine_with_grabcut(image, mask)
    if _validate_mask_quality(mask, frame_hw):
        return mask

    # Geometric fallback
    return _geometric_person_mask(image)
209
+
210
+ # Saliency, GrabCut helpers --------------------------------------------------
211
+
212
def _compute_saliency(image: np.ndarray) -> Optional[np.ndarray]:
    """
    Spectral-residual saliency map normalised to [0, 1], or None when the
    cv2.saliency contrib module is missing or computation fails.
    """
    if not hasattr(cv2, "saliency"):
        return None
    try:
        detector = cv2.saliency.StaticSaliencySpectralResidual_create()
        ok, saliency_map = detector.computeSaliency(image)
    except Exception:
        return None
    if not ok:
        return None
    # Min-max normalise; epsilon guards a constant map.
    span = max(1e-6, saliency_map.max() - saliency_map.min())
    return (saliency_map - saliency_map.min()) / span
223
+
224
def _auto_person_rect(image):
    """
    Bounding rect (x, y, w, h) of the largest salient blob, padded by 5%
    of the frame size and clipped to the image. None when saliency is
    unavailable or finds nothing.
    """
    saliency = _compute_saliency(image)
    if saliency is None:
        return None
    binary = (saliency > SALIENCY_THRESH).astype(np.uint8) * 255
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return None
    largest = max(contours, key=cv2.contourArea)
    x, y, w, h = cv2.boundingRect(largest)
    frame_h, frame_w = image.shape[:2]
    pad = 0.05
    x = max(0, int(x - frame_w * pad))
    y = max(0, int(y - frame_h * pad))
    w = min(frame_w - x, int(w * (1 + 2 * pad)))
    h = min(frame_h - y, int(h * (1 + 2 * pad)))
    return x, y, w, h
238
+
239
def _refine_with_grabcut(image: np.ndarray, seed: np.ndarray) -> np.ndarray:
    """
    GrabCut refinement seeded from a binary mask: pixels >200 in ``seed``
    become hard foreground, everything else probable background.
    Returns a uint8 0/255 mask of (probable-)foreground pixels.
    """
    h, w = image.shape[:2]
    gc_mask = np.full((h, w), cv2.GC_PR_BGD, np.uint8)
    gc_mask[seed > 200] = cv2.GC_FGD
    # Saliency-derived rect when available, otherwise a centred default box.
    rect = _auto_person_rect(image) or (w // 4, h // 6, w // 2, int(h * 0.7))
    bgd_model = np.zeros((1, 65), np.float64)
    fgd_model = np.zeros((1, 65), np.float64)
    cv2.grabCut(image, gc_mask, rect, bgd_model, fgd_model, GRABCUT_ITERS, cv2.GC_INIT_WITH_MASK)
    foreground = (gc_mask == cv2.GC_FGD) | (gc_mask == cv2.GC_PR_FGD)
    return np.where(foreground, 255, 0).astype(np.uint8)
247
+
248
def _refine_with_saliency(image: np.ndarray, seed: np.ndarray) -> np.ndarray:
    """
    Flood-fill the thresholded saliency map starting at the centroid of the
    seed mask (image centre when the seed is empty). Returns the seed
    unchanged when saliency is unavailable.
    """
    saliency = _compute_saliency(image)
    if saliency is None:
        return seed
    high = (saliency > SALIENCY_THRESH).astype(np.uint8) * 255
    ys, xs = np.where(seed > 127)
    cy = int(np.mean(ys)) if len(ys) else image.shape[0] // 2
    cx = int(np.mean(xs)) if len(xs) else image.shape[1] // 2
    filled = high.copy()
    # floodFill takes the seed point in (x, y) order.
    cv2.floodFill(filled, None, (cx, cy), 255, loDiff=5, upDiff=5)
    return filled
258
+
259
+ # ============================================================================
260
+ # QUALITY / HELPER FUNCTIONS
261
+ # ============================================================================
262
+
263
def _validate_mask_quality(mask: np.ndarray, shape: Tuple[int, int]) -> bool:
    """True iff the foreground fraction lies within the sane area band."""
    height, width = shape
    foreground = int(np.sum(mask > 127))
    coverage = foreground / (height * width)
    return MIN_AREA_RATIO <= coverage <= MAX_AREA_RATIO
267
+
268
+ def _process_mask(mask: np.ndarray) -> np.ndarray:
269
+ if mask.dtype in (np.float32, np.float64):
270
+ if mask.max() <= 1.0:
271
+ mask = (mask*255).astype(np.uint8)
272
+ if mask.dtype != np.uint8:
273
+ mask = mask.astype(np.uint8)
274
+ if mask.ndim == 3:
275
+ mask = mask.squeeze()
276
+ if mask.ndim == 3: # multi-channel mask → collapse
277
+ mask = mask[:,:,0]
278
+ _,mask = cv2.threshold(mask,127,255,cv2.THRESH_BINARY)
279
+ return mask
280
+
281
def _geometric_person_mask(image: np.ndarray) -> np.ndarray:
    """Last-resort mask: a filled, centred ellipse covering a person-ish area."""
    height, width = image.shape[:2]
    canvas = np.zeros((height, width), np.uint8)
    centre = (width // 2, height // 2)
    axes = (width // 3, int(height / 2.5))
    cv2.ellipse(canvas, centre, axes, 0, 0, 360, 255, -1)
    return canvas
286
+
287
+ # ============================================================================
288
+ # OPTIONAL: Iterative auto-refinement (lightweight)
289
+ # ============================================================================
290
+
291
+ def _auto_refine_mask_iteratively(image, mask, predictor, max_iterations=1):
292
+ # Simple one-pass hook (full version lives in refinement.py)
293
+ return mask