Update processing/two_stage/two_stage_processor.py
Browse files
processing/two_stage/two_stage_processor.py
CHANGED
|
@@ -15,6 +15,11 @@
|
|
| 15 |
* Ensures MatAnyone receives a valid first-frame mask (bootstraps the session
|
| 16 |
with the first SAM2 mask). This prevents "First frame arrived without a mask"
|
| 17 |
warnings and shape mismatches inside the stateful refiner.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
"""
|
| 19 |
from __future__ import annotations
|
| 20 |
|
|
@@ -115,6 +120,15 @@ def _choose_best_key_color(frame_bgr: np.ndarray, mask_uint8: np.ndarray) -> dic
|
|
| 115 |
'outdoor': {'key_color': [0,255,0], 'tolerance': 50, 'edge_softness': 3, 'spill_suppression': 0.25},
|
| 116 |
}
|
| 117 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
# ---------------------------------------------------------------------------
|
| 119 |
# Two-Stage Processor
|
| 120 |
# ---------------------------------------------------------------------------
|
|
@@ -129,6 +143,14 @@ def __init__(self, sam2_predictor=None, matanyone_model=None):
|
|
| 129 |
self._mat_bootstrapped = False
|
| 130 |
self._alpha_prev: Optional[np.ndarray] = None # temporal smoothing
|
| 131 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
logger.info(f"TwoStageProcessor init – SAM2: {self.sam2 is not None} | MatAnyOne: {self.matanyone is not None}")
|
| 133 |
|
| 134 |
# --------------------------- internal utils ---------------------------
|
|
@@ -197,12 +219,10 @@ def _suppress_green_spill(self, frame: np.ndarray, amount: float = 0.35) -> np.n
|
|
| 197 |
amount: 0..1
|
| 198 |
"""
|
| 199 |
b, g, r = cv2.split(frame.astype(np.float32))
|
| 200 |
-
y = 0.299*r + 0.587*g + 0.114*b # luminance (unused directly but good for future tuning)
|
| 201 |
green_dom = (g > r) & (g > b)
|
| 202 |
avg_rb = (r + b) * 0.5
|
| 203 |
g2 = np.where(green_dom, g*(1.0-amount) + avg_rb*amount, g)
|
| 204 |
-
|
| 205 |
-
skin = (r > g + 12)
|
| 206 |
g2 = np.where(skin, g, g2)
|
| 207 |
out = cv2.merge([np.clip(b,0,255), np.clip(g2,0,255), np.clip(r,0,255)]).astype(np.uint8)
|
| 208 |
return out
|
|
@@ -237,7 +257,6 @@ def _soft_key_mask(self, frame_bgr: np.ndarray, key_bgr: np.ndarray, tol: int =
|
|
| 237 |
Soft chroma mask (uint8 0..255, 255=keep subject) using CbCr distance.
|
| 238 |
"""
|
| 239 |
if key_bgr is None:
|
| 240 |
-
# fallback to keep-all (no key)
|
| 241 |
return np.full(frame_bgr.shape[:2], 255, np.uint8)
|
| 242 |
|
| 243 |
ycbcr = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2YCrCb).astype(np.float32)
|
|
@@ -340,8 +359,9 @@ def _prog(p, d):
|
|
| 340 |
probe_done = True
|
| 341 |
logger.info(f"[TwoStage] Using key colour: {key_color_mode} → {chosen_bgr.tolist()}")
|
| 342 |
|
| 343 |
-
# --- Optional refinement via MatAnyone
|
| 344 |
-
|
|
|
|
| 345 |
try:
|
| 346 |
mask = refine_mask_hq(frame, mask, self.matanyone, fallback_enabled=True)
|
| 347 |
except Exception as e:
|
|
@@ -391,7 +411,7 @@ def stage2_greenscreen_to_final(
|
|
| 391 |
chroma_settings: Optional[Dict[str, Any]] = None,
|
| 392 |
progress_callback: Optional[Callable[[float, str], None]] = None,
|
| 393 |
stop_event: Optional["threading.Event"] = None,
|
| 394 |
-
key_bgr: Optional[np.ndarray] = None, #
|
| 395 |
) -> Tuple[Optional[str], str]:
|
| 396 |
|
| 397 |
def _prog(p, d):
|
|
@@ -422,6 +442,11 @@ def _prog(p, d):
|
|
| 422 |
else:
|
| 423 |
bg = cv2.resize(background, (w, h))
|
| 424 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 425 |
writer, out_path = create_video_writer(output_path, fps, w, h)
|
| 426 |
if writer is None:
|
| 427 |
cap.release()
|
|
@@ -438,11 +463,11 @@ def _prog(p, d):
|
|
| 438 |
except Exception as e:
|
| 439 |
logger.warning(f"Could not load mask cache: {e}")
|
| 440 |
|
| 441 |
-
# Get chroma settings
|
| 442 |
settings = chroma_settings or CHROMA_PRESETS.get('standard', {})
|
| 443 |
-
tolerance = int(settings.get('tolerance', 38))
|
| 444 |
-
edge_softness = int(settings.get('edge_softness', 2))
|
| 445 |
-
spill_suppression = float(settings.get('spill_suppression', 0.35))
|
| 446 |
|
| 447 |
# If caller didn't pass key_bgr, try preset or default green
|
| 448 |
if key_bgr is None:
|
|
@@ -529,7 +554,7 @@ def _chroma_key_composite(self, frame, bg, *, tolerance=38, edge_softness=2, spi
|
|
| 529 |
return np.clip(out, 0, 255).astype(np.uint8)
|
| 530 |
|
| 531 |
def _hybrid_composite(self, frame, bg, mask, *, tolerance=38, edge_softness=2, spill_suppression=0.35, key_bgr: Optional[np.ndarray] = None):
|
| 532 |
-
"""Apply hybrid compositing using both chroma key and cached mask."""
|
| 533 |
chroma_result = self._chroma_key_composite(
|
| 534 |
frame, bg,
|
| 535 |
tolerance=tolerance,
|
|
@@ -537,13 +562,28 @@ def _hybrid_composite(self, frame, bg, mask, *, tolerance=38, edge_softness=2, s
|
|
| 537 |
spill_suppression=spill_suppression,
|
| 538 |
key_bgr=key_bgr
|
| 539 |
)
|
| 540 |
-
if mask is
|
| 541 |
-
|
| 542 |
-
|
| 543 |
-
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 547 |
|
| 548 |
# ---------------------------------------------------------------------
|
| 549 |
# Combined pipeline
|
|
|
|
| 15 |
* Ensures MatAnyone receives a valid first-frame mask (bootstraps the session
|
| 16 |
with the first SAM2 mask). This prevents "First frame arrived without a mask"
|
| 17 |
warnings and shape mismatches inside the stateful refiner.
|
| 18 |
+
|
| 19 |
+
Quality profiles (set via env BFX_QUALITY = speed | balanced | max):
|
| 20 |
+
* refine cadence, spill suppression, edge softness
|
| 21 |
+
* hybrid matte mix (segmentation vs chroma), small dilate/blur on mask
|
| 22 |
+
* optional tiny background blur to hide seams on very flat backgrounds
|
| 23 |
"""
|
| 24 |
from __future__ import annotations
|
| 25 |
|
|
|
|
| 120 |
'outdoor': {'key_color': [0,255,0], 'tolerance': 50, 'edge_softness': 3, 'spill_suppression': 0.25},
|
| 121 |
}
|
| 122 |
|
| 123 |
+
# ---------------------------------------------------------------------------
|
| 124 |
+
# Quality profiles (env: BFX_QUALITY = speed | balanced | max)
|
| 125 |
+
# ---------------------------------------------------------------------------
|
| 126 |
+
QUALITY_PROFILES: Dict[str, Dict[str, Any]] = {
|
| 127 |
+
"speed": dict(refine_stride=4, spill=0.30, edge_softness=2, mix=0.60, dilate=0, blur=0, bg_sigma=0.0),
|
| 128 |
+
"balanced": dict(refine_stride=2, spill=0.40, edge_softness=2, mix=0.75, dilate=1, blur=1, bg_sigma=0.6),
|
| 129 |
+
"max": dict(refine_stride=1, spill=0.45, edge_softness=3, mix=0.85, dilate=2, blur=2, bg_sigma=1.0),
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
# ---------------------------------------------------------------------------
|
| 133 |
# Two-Stage Processor
|
| 134 |
# ---------------------------------------------------------------------------
|
|
|
|
| 143 |
self._mat_bootstrapped = False
|
| 144 |
self._alpha_prev: Optional[np.ndarray] = None # temporal smoothing
|
| 145 |
|
| 146 |
+
# Quality selection
|
| 147 |
+
qname = os.getenv("BFX_QUALITY", "balanced").strip().lower()
|
| 148 |
+
if qname not in QUALITY_PROFILES:
|
| 149 |
+
qname = "balanced"
|
| 150 |
+
self.quality = qname
|
| 151 |
+
self.q = QUALITY_PROFILES[qname]
|
| 152 |
+
logger.info(f"TwoStageProcessor quality='{self.quality}' ⇒ {self.q}")
|
| 153 |
+
|
| 154 |
logger.info(f"TwoStageProcessor init – SAM2: {self.sam2 is not None} | MatAnyOne: {self.matanyone is not None}")
|
| 155 |
|
| 156 |
# --------------------------- internal utils ---------------------------
|
|
|
|
| 219 |
amount: 0..1
|
| 220 |
"""
|
| 221 |
b, g, r = cv2.split(frame.astype(np.float32))
|
|
|
|
| 222 |
green_dom = (g > r) & (g > b)
|
| 223 |
avg_rb = (r + b) * 0.5
|
| 224 |
g2 = np.where(green_dom, g*(1.0-amount) + avg_rb*amount, g)
|
| 225 |
+
skin = (r > g + 12) # protect skin tones
|
|
|
|
| 226 |
g2 = np.where(skin, g, g2)
|
| 227 |
out = cv2.merge([np.clip(b,0,255), np.clip(g2,0,255), np.clip(r,0,255)]).astype(np.uint8)
|
| 228 |
return out
|
|
|
|
| 257 |
Soft chroma mask (uint8 0..255, 255=keep subject) using CbCr distance.
|
| 258 |
"""
|
| 259 |
if key_bgr is None:
|
|
|
|
| 260 |
return np.full(frame_bgr.shape[:2], 255, np.uint8)
|
| 261 |
|
| 262 |
ycbcr = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2YCrCb).astype(np.float32)
|
|
|
|
| 359 |
probe_done = True
|
| 360 |
logger.info(f"[TwoStage] Using key colour: {key_color_mode} → {chosen_bgr.tolist()}")
|
| 361 |
|
| 362 |
+
# --- Optional refinement via MatAnyone (profile cadence) ---
|
| 363 |
+
stride = int(self.q.get("refine_stride", 3))
|
| 364 |
+
if self.matanyone and (frame_idx % max(1, stride) == 0):
|
| 365 |
try:
|
| 366 |
mask = refine_mask_hq(frame, mask, self.matanyone, fallback_enabled=True)
|
| 367 |
except Exception as e:
|
|
|
|
| 411 |
chroma_settings: Optional[Dict[str, Any]] = None,
|
| 412 |
progress_callback: Optional[Callable[[float, str], None]] = None,
|
| 413 |
stop_event: Optional["threading.Event"] = None,
|
| 414 |
+
key_bgr: Optional[np.ndarray] = None, # pass chosen key color
|
| 415 |
) -> Tuple[Optional[str], str]:
|
| 416 |
|
| 417 |
def _prog(p, d):
|
|
|
|
| 442 |
else:
|
| 443 |
bg = cv2.resize(background, (w, h))
|
| 444 |
|
| 445 |
+
# Optional tiny BG blur per profile to hide seams on flat BGs
|
| 446 |
+
sigma = float(self.q.get("bg_sigma", 0.0))
|
| 447 |
+
if sigma > 0:
|
| 448 |
+
bg = cv2.GaussianBlur(bg, (0, 0), sigmaX=sigma, sigmaY=sigma)
|
| 449 |
+
|
| 450 |
writer, out_path = create_video_writer(output_path, fps, w, h)
|
| 451 |
if writer is None:
|
| 452 |
cap.release()
|
|
|
|
| 463 |
except Exception as e:
|
| 464 |
logger.warning(f"Could not load mask cache: {e}")
|
| 465 |
|
| 466 |
+
# Get chroma settings and override with profile
|
| 467 |
settings = chroma_settings or CHROMA_PRESETS.get('standard', {})
|
| 468 |
+
tolerance = int(settings.get('tolerance', 38)) # keep user tolerance
|
| 469 |
+
edge_softness = int(self.q.get('edge_softness', settings.get('edge_softness', 2)))
|
| 470 |
+
spill_suppression = float(self.q.get('spill', settings.get('spill_suppression', 0.35)))
|
| 471 |
|
| 472 |
# If caller didn't pass key_bgr, try preset or default green
|
| 473 |
if key_bgr is None:
|
|
|
|
| 554 |
return np.clip(out, 0, 255).astype(np.uint8)
|
| 555 |
|
| 556 |
def _hybrid_composite(self, frame, bg, mask, *, tolerance=38, edge_softness=2, spill_suppression=0.35, key_bgr: Optional[np.ndarray] = None):
|
| 557 |
+
"""Apply hybrid compositing using both chroma key and cached mask, with profile controls."""
|
| 558 |
chroma_result = self._chroma_key_composite(
|
| 559 |
frame, bg,
|
| 560 |
tolerance=tolerance,
|
|
|
|
| 562 |
spill_suppression=spill_suppression,
|
| 563 |
key_bgr=key_bgr
|
| 564 |
)
|
| 565 |
+
if mask is None:
|
| 566 |
+
return chroma_result
|
| 567 |
+
|
| 568 |
+
# profile-driven dilate/feather on cached mask to close pinholes + soften edges
|
| 569 |
+
m = mask
|
| 570 |
+
d = int(self.q.get("dilate", 0))
|
| 571 |
+
if d > 0:
|
| 572 |
+
k = 2*d + 1
|
| 573 |
+
se = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
|
| 574 |
+
m = cv2.dilate(m, se, iterations=1)
|
| 575 |
+
b = int(self.q.get("blur", 0))
|
| 576 |
+
if b > 0:
|
| 577 |
+
m = cv2.GaussianBlur(m, (2*b+1, 2*b+1), 0)
|
| 578 |
+
|
| 579 |
+
m3 = cv2.cvtColor(m, cv2.COLOR_GRAY2BGR) if m.ndim == 2 else m
|
| 580 |
+
m3f = (m3.astype(np.float32) / 255.0)
|
| 581 |
+
|
| 582 |
+
seg_comp = frame.astype(np.float32) * m3f + bg.astype(np.float32) * (1.0 - m3f)
|
| 583 |
+
|
| 584 |
+
mix = float(self.q.get("mix", 0.7)) # weight towards segmentation on "max"
|
| 585 |
+
out = chroma_result.astype(np.float32) * (1.0 - mix) + seg_comp * mix
|
| 586 |
+
return np.clip(out, 0, 255).astype(np.uint8)
|
| 587 |
|
| 588 |
# ---------------------------------------------------------------------
|
| 589 |
# Combined pipeline
|