MogensR committed on
Commit
7b9fcf8
·
1 Parent(s): fb41e40

Update processing/two_stage/two_stage_processor.py

Browse files
processing/two_stage/two_stage_processor.py CHANGED
@@ -15,6 +15,11 @@
15
  * Ensures MatAnyone receives a valid first-frame mask (bootstraps the session
16
  with the first SAM2 mask). This prevents "First frame arrived without a mask"
17
  warnings and shape mismatches inside the stateful refiner.
 
 
 
 
 
18
  """
19
  from __future__ import annotations
20
 
@@ -115,6 +120,15 @@ def _choose_best_key_color(frame_bgr: np.ndarray, mask_uint8: np.ndarray) -> dic
115
  'outdoor': {'key_color': [0,255,0], 'tolerance': 50, 'edge_softness': 3, 'spill_suppression': 0.25},
116
  }
117
 
 
 
 
 
 
 
 
 
 
118
  # ---------------------------------------------------------------------------
119
  # Two-Stage Processor
120
  # ---------------------------------------------------------------------------
@@ -129,6 +143,14 @@ def __init__(self, sam2_predictor=None, matanyone_model=None):
129
  self._mat_bootstrapped = False
130
  self._alpha_prev: Optional[np.ndarray] = None # temporal smoothing
131
 
 
 
 
 
 
 
 
 
132
  logger.info(f"TwoStageProcessor init – SAM2: {self.sam2 is not None} | MatAnyOne: {self.matanyone is not None}")
133
 
134
  # --------------------------- internal utils ---------------------------
@@ -197,12 +219,10 @@ def _suppress_green_spill(self, frame: np.ndarray, amount: float = 0.35) -> np.n
197
  amount: 0..1
198
  """
199
  b, g, r = cv2.split(frame.astype(np.float32))
200
- y = 0.299*r + 0.587*g + 0.114*b # luminance (unused directly but good for future tuning)
201
  green_dom = (g > r) & (g > b)
202
  avg_rb = (r + b) * 0.5
203
  g2 = np.where(green_dom, g*(1.0-amount) + avg_rb*amount, g)
204
- # protect skin tones (red significantly above green)
205
- skin = (r > g + 12)
206
  g2 = np.where(skin, g, g2)
207
  out = cv2.merge([np.clip(b,0,255), np.clip(g2,0,255), np.clip(r,0,255)]).astype(np.uint8)
208
  return out
@@ -237,7 +257,6 @@ def _soft_key_mask(self, frame_bgr: np.ndarray, key_bgr: np.ndarray, tol: int =
237
  Soft chroma mask (uint8 0..255, 255=keep subject) using CbCr distance.
238
  """
239
  if key_bgr is None:
240
- # fallback to keep-all (no key)
241
  return np.full(frame_bgr.shape[:2], 255, np.uint8)
242
 
243
  ycbcr = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2YCrCb).astype(np.float32)
@@ -340,8 +359,9 @@ def _prog(p, d):
340
  probe_done = True
341
  logger.info(f"[TwoStage] Using key colour: {key_color_mode} → {chosen_bgr.tolist()}")
342
 
343
- # --- Optional refinement via MatAnyone every few frames ---
344
- if self.matanyone and (frame_idx % 3 == 0):
 
345
  try:
346
  mask = refine_mask_hq(frame, mask, self.matanyone, fallback_enabled=True)
347
  except Exception as e:
@@ -391,7 +411,7 @@ def stage2_greenscreen_to_final(
391
  chroma_settings: Optional[Dict[str, Any]] = None,
392
  progress_callback: Optional[Callable[[float, str], None]] = None,
393
  stop_event: Optional["threading.Event"] = None,
394
- key_bgr: Optional[np.ndarray] = None, # <-- NEW: pass chosen key color
395
  ) -> Tuple[Optional[str], str]:
396
 
397
  def _prog(p, d):
@@ -422,6 +442,11 @@ def _prog(p, d):
422
  else:
423
  bg = cv2.resize(background, (w, h))
424
 
 
 
 
 
 
425
  writer, out_path = create_video_writer(output_path, fps, w, h)
426
  if writer is None:
427
  cap.release()
@@ -438,11 +463,11 @@ def _prog(p, d):
438
  except Exception as e:
439
  logger.warning(f"Could not load mask cache: {e}")
440
 
441
- # Get chroma settings
442
  settings = chroma_settings or CHROMA_PRESETS.get('standard', {})
443
- tolerance = int(settings.get('tolerance', 38))
444
- edge_softness = int(settings.get('edge_softness', 2))
445
- spill_suppression = float(settings.get('spill_suppression', 0.35))
446
 
447
  # If caller didn't pass key_bgr, try preset or default green
448
  if key_bgr is None:
@@ -529,7 +554,7 @@ def _chroma_key_composite(self, frame, bg, *, tolerance=38, edge_softness=2, spi
529
  return np.clip(out, 0, 255).astype(np.uint8)
530
 
531
  def _hybrid_composite(self, frame, bg, mask, *, tolerance=38, edge_softness=2, spill_suppression=0.35, key_bgr: Optional[np.ndarray] = None):
532
- """Apply hybrid compositing using both chroma key and cached mask."""
533
  chroma_result = self._chroma_key_composite(
534
  frame, bg,
535
  tolerance=tolerance,
@@ -537,13 +562,28 @@ def _hybrid_composite(self, frame, bg, mask, *, tolerance=38, edge_softness=2, s
537
  spill_suppression=spill_suppression,
538
  key_bgr=key_bgr
539
  )
540
- if mask is not None:
541
- mask_3ch = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) if mask.ndim == 2 else mask
542
- mask_norm = mask_3ch.astype(np.float32) / 255.0
543
- guided = frame.astype(np.float32) * mask_norm + bg.astype(np.float32) * (1.0 - mask_norm)
544
- result = chroma_result.astype(np.float32) * 0.7 + guided * 0.3
545
- return np.clip(result, 0, 255).astype(np.uint8)
546
- return chroma_result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
547
 
548
  # ---------------------------------------------------------------------
549
  # Combined pipeline
 
15
  * Ensures MatAnyone receives a valid first-frame mask (bootstraps the session
16
  with the first SAM2 mask). This prevents "First frame arrived without a mask"
17
  warnings and shape mismatches inside the stateful refiner.
18
+
19
+ Quality profiles (set via env BFX_QUALITY = speed | balanced | max):
20
+ * refine cadence, spill suppression, edge softness
21
+ * hybrid matte mix (segmentation vs chroma), small dilate/blur on mask
22
+ * optional tiny background blur to hide seams on very flat backgrounds
23
  """
24
  from __future__ import annotations
25
 
 
120
  'outdoor': {'key_color': [0,255,0], 'tolerance': 50, 'edge_softness': 3, 'spill_suppression': 0.25},
121
  }
122
 
123
+ # ---------------------------------------------------------------------------
124
+ # Quality profiles (env: BFX_QUALITY = speed | balanced | max)
125
+ # ---------------------------------------------------------------------------
126
+ QUALITY_PROFILES: Dict[str, Dict[str, Any]] = {
127
+ "speed": dict(refine_stride=4, spill=0.30, edge_softness=2, mix=0.60, dilate=0, blur=0, bg_sigma=0.0),
128
+ "balanced": dict(refine_stride=2, spill=0.40, edge_softness=2, mix=0.75, dilate=1, blur=1, bg_sigma=0.6),
129
+ "max": dict(refine_stride=1, spill=0.45, edge_softness=3, mix=0.85, dilate=2, blur=2, bg_sigma=1.0),
130
+ }
131
+
132
  # ---------------------------------------------------------------------------
133
  # Two-Stage Processor
134
  # ---------------------------------------------------------------------------
 
143
  self._mat_bootstrapped = False
144
  self._alpha_prev: Optional[np.ndarray] = None # temporal smoothing
145
 
146
+ # Quality selection
147
+ qname = os.getenv("BFX_QUALITY", "balanced").strip().lower()
148
+ if qname not in QUALITY_PROFILES:
149
+ qname = "balanced"
150
+ self.quality = qname
151
+ self.q = QUALITY_PROFILES[qname]
152
+ logger.info(f"TwoStageProcessor quality='{self.quality}' ⇒ {self.q}")
153
+
154
  logger.info(f"TwoStageProcessor init – SAM2: {self.sam2 is not None} | MatAnyOne: {self.matanyone is not None}")
155
 
156
  # --------------------------- internal utils ---------------------------
 
219
  amount: 0..1
220
  """
221
  b, g, r = cv2.split(frame.astype(np.float32))
 
222
  green_dom = (g > r) & (g > b)
223
  avg_rb = (r + b) * 0.5
224
  g2 = np.where(green_dom, g*(1.0-amount) + avg_rb*amount, g)
225
+ skin = (r > g + 12) # protect skin tones
 
226
  g2 = np.where(skin, g, g2)
227
  out = cv2.merge([np.clip(b,0,255), np.clip(g2,0,255), np.clip(r,0,255)]).astype(np.uint8)
228
  return out
 
257
  Soft chroma mask (uint8 0..255, 255=keep subject) using CbCr distance.
258
  """
259
  if key_bgr is None:
 
260
  return np.full(frame_bgr.shape[:2], 255, np.uint8)
261
 
262
  ycbcr = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2YCrCb).astype(np.float32)
 
359
  probe_done = True
360
  logger.info(f"[TwoStage] Using key colour: {key_color_mode} → {chosen_bgr.tolist()}")
361
 
362
+ # --- Optional refinement via MatAnyone (profile cadence) ---
363
+ stride = int(self.q.get("refine_stride", 3))
364
+ if self.matanyone and (frame_idx % max(1, stride) == 0):
365
  try:
366
  mask = refine_mask_hq(frame, mask, self.matanyone, fallback_enabled=True)
367
  except Exception as e:
 
411
  chroma_settings: Optional[Dict[str, Any]] = None,
412
  progress_callback: Optional[Callable[[float, str], None]] = None,
413
  stop_event: Optional["threading.Event"] = None,
414
+ key_bgr: Optional[np.ndarray] = None, # pass chosen key color
415
  ) -> Tuple[Optional[str], str]:
416
 
417
  def _prog(p, d):
 
442
  else:
443
  bg = cv2.resize(background, (w, h))
444
 
445
+ # Optional tiny BG blur per profile to hide seams on flat BGs
446
+ sigma = float(self.q.get("bg_sigma", 0.0))
447
+ if sigma > 0:
448
+ bg = cv2.GaussianBlur(bg, (0, 0), sigmaX=sigma, sigmaY=sigma)
449
+
450
  writer, out_path = create_video_writer(output_path, fps, w, h)
451
  if writer is None:
452
  cap.release()
 
463
  except Exception as e:
464
  logger.warning(f"Could not load mask cache: {e}")
465
 
466
+ # Get chroma settings and override with profile
467
  settings = chroma_settings or CHROMA_PRESETS.get('standard', {})
468
+ tolerance = int(settings.get('tolerance', 38)) # keep user tolerance
469
+ edge_softness = int(self.q.get('edge_softness', settings.get('edge_softness', 2)))
470
+ spill_suppression = float(self.q.get('spill', settings.get('spill_suppression', 0.35)))
471
 
472
  # If caller didn't pass key_bgr, try preset or default green
473
  if key_bgr is None:
 
554
  return np.clip(out, 0, 255).astype(np.uint8)
555
 
556
  def _hybrid_composite(self, frame, bg, mask, *, tolerance=38, edge_softness=2, spill_suppression=0.35, key_bgr: Optional[np.ndarray] = None):
557
+ """Apply hybrid compositing using both chroma key and cached mask, with profile controls."""
558
  chroma_result = self._chroma_key_composite(
559
  frame, bg,
560
  tolerance=tolerance,
 
562
  spill_suppression=spill_suppression,
563
  key_bgr=key_bgr
564
  )
565
+ if mask is None:
566
+ return chroma_result
567
+
568
+ # profile-driven dilate/feather on cached mask to close pinholes + soften edges
569
+ m = mask
570
+ d = int(self.q.get("dilate", 0))
571
+ if d > 0:
572
+ k = 2*d + 1
573
+ se = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
574
+ m = cv2.dilate(m, se, iterations=1)
575
+ b = int(self.q.get("blur", 0))
576
+ if b > 0:
577
+ m = cv2.GaussianBlur(m, (2*b+1, 2*b+1), 0)
578
+
579
+ m3 = cv2.cvtColor(m, cv2.COLOR_GRAY2BGR) if m.ndim == 2 else m
580
+ m3f = (m3.astype(np.float32) / 255.0)
581
+
582
+ seg_comp = frame.astype(np.float32) * m3f + bg.astype(np.float32) * (1.0 - m3f)
583
+
584
+ mix = float(self.q.get("mix", 0.7)) # weight towards segmentation on "max"
585
+ out = chroma_result.astype(np.float32) * (1.0 - mix) + seg_comp * mix
586
+ return np.clip(out, 0, 255).astype(np.uint8)
587
 
588
  # ---------------------------------------------------------------------
589
  # Combined pipeline