MogensR committed on
Commit
acad788
·
1 Parent(s): 48e80b4

Update processing/two_stage/two_stage_processor.py

Browse files
processing/two_stage/two_stage_processor.py CHANGED
@@ -1,26 +1,25 @@
1
  #!/usr/bin/env python3
2
  """
3
- Two-Stage Green-Screen Processing System ✅ 2025-08-29
4
  Stage 1: Original → keyed background (auto-selected colour)
5
- Stage 2: Keyed video → final composite (hybrid chroma + segmentation rescue)
6
-
7
  UPDATED: Enhanced quality profiles, improved frame handling, better status reporting
 
 
 
 
8
  """
9
  from __future__ import annotations
10
-
11
  import cv2, numpy as np, os, gc, pickle, logging, tempfile, traceback, threading
12
  from pathlib import Path
13
- from typing import Optional, Dict, Any, Callable, Tuple, List
14
-
15
  from utils.cv_processing import segment_person_hq, refine_mask_hq
16
-
17
  # Project logger if available
18
  try:
19
  from utils.logger import get_logger
20
  logger = get_logger(__name__)
21
  except Exception:
22
  logger = logging.getLogger(__name__)
23
-
24
  # ---------------------------------------------------------------------------
25
  # Local video-writer helper with frame count guarantee
26
  # ---------------------------------------------------------------------------
@@ -34,7 +33,6 @@ def create_video_writer(output_path: str, fps: float, width: int, height: int, p
34
  base, curr_ext = os.path.splitext(output_path)
35
  if curr_ext.lower() not in [".mp4", ".avi", ".mov", ".mkv"]:
36
  output_path = base + ext
37
-
38
  fourcc = cv2.VideoWriter_fourcc(*("mp4v" if prefer_mp4 else "XVID"))
39
  writer = cv2.VideoWriter(output_path, fourcc, float(fps), (int(width), int(height)))
40
  if writer is None or not writer.isOpened():
@@ -49,28 +47,27 @@ def create_video_writer(output_path: str, fps: float, width: int, height: int, p
49
  except Exception as e:
50
  logger.error(f"create_video_writer failed: {e}")
51
  return None, output_path
52
-
53
  # ---------------------------------------------------------------------------
54
  # Robust video writer wrapper to prevent frame loss
55
  # ---------------------------------------------------------------------------
56
  class RobustVideoWriter:
57
  """Wrapper that ensures all frames are written"""
58
-
59
  def __init__(self, writer, output_path: str):
60
  self.writer = writer
61
  self.output_path = output_path
62
  self.frame_buffer = []
63
  self.frames_written = 0
64
  self.frames_attempted = 0
65
-
66
  def write(self, frame):
67
  """Buffer and write frame"""
68
  if frame is None:
69
  return False
70
-
71
  self.frames_attempted += 1
72
  self.frame_buffer.append(frame.copy())
73
-
74
  # Try to write buffered frames
75
  while self.frame_buffer and self.writer:
76
  try:
@@ -81,7 +78,7 @@ def write(self, frame):
81
  logger.warning(f"Frame write failed: {e}")
82
  return False
83
  return True
84
-
85
  def release(self):
86
  """Flush remaining frames and close"""
87
  # Write any remaining buffered frames
@@ -92,14 +89,31 @@ def release(self):
92
  self.frames_written += 1
93
  except Exception:
94
  break
95
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  # Close writer
97
  if self.writer:
98
  self.writer.release()
99
-
100
  # Log statistics
101
  logger.info(f"Video writer closed: {self.frames_written}/{self.frames_attempted} frames written")
102
-
103
  # Verify output exists
104
  if os.path.exists(self.output_path):
105
  size = os.path.getsize(self.output_path)
@@ -107,7 +121,6 @@ def release(self):
107
  logger.error("WARNING: Output file is empty!")
108
  else:
109
  logger.info(f"Output file size: {size:,} bytes")
110
-
111
  # ---------------------------------------------------------------------------
112
  # Key-colour helpers (fast, no external deps)
113
  # ---------------------------------------------------------------------------
@@ -115,32 +128,27 @@ def _bgr_to_hsv_hue_deg(bgr: np.ndarray) -> np.ndarray:
115
  hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
116
  # OpenCV H is 0-180; scale to degrees 0-360
117
  return hsv[..., 0].astype(np.float32) * 2.0
118
-
119
  def _hue_distance(a_deg: float, b_deg: float) -> float:
120
  """Circular distance on the hue wheel (degrees)."""
121
  d = abs(a_deg - b_deg) % 360.0
122
  return min(d, 360.0 - d)
123
-
124
  def _key_candidates_bgr() -> dict:
125
  return {
126
- "green": {"bgr": np.array([ 0,255, 0], dtype=np.uint8), "hue": 120.0},
127
- "blue": {"bgr": np.array([255, 0, 0], dtype=np.uint8), "hue": 240.0},
128
- "cyan": {"bgr": np.array([255,255, 0], dtype=np.uint8), "hue": 180.0},
129
- "magenta": {"bgr": np.array([255, 0,255], dtype=np.uint8), "hue": 300.0},
130
  }
131
-
132
  def _choose_best_key_color(frame_bgr: np.ndarray, mask_uint8: np.ndarray) -> dict:
133
  """Pick the candidate colour farthest from the actor's dominant hues."""
134
  try:
135
  fg = frame_bgr[mask_uint8 > 127]
136
  if fg.size < 1_000:
137
  return _key_candidates_bgr()["green"]
138
-
139
  fg_hue = _bgr_to_hsv_hue_deg(fg.reshape(-1, 1, 3)).reshape(-1)
140
  hist, edges = np.histogram(fg_hue, bins=36, range=(0.0, 360.0))
141
  top_idx = np.argsort(hist)[-3:]
142
  top_hues = [(edges[i] + edges[i+1]) * 0.5 for i in top_idx]
143
-
144
  best_name, best_score = None, -1.0
145
  for name, info in _key_candidates_bgr().items():
146
  cand_hue = info["hue"]
@@ -150,49 +158,14 @@ def _choose_best_key_color(frame_bgr: np.ndarray, mask_uint8: np.ndarray) -> dic
150
  return _key_candidates_bgr().get(best_name, _key_candidates_bgr()["green"])
151
  except Exception:
152
  return _key_candidates_bgr()["green"]
153
-
154
  # ---------------------------------------------------------------------------
155
  # Chroma presets
156
  # ---------------------------------------------------------------------------
157
  CHROMA_PRESETS: Dict[str, Dict[str, Any]] = {
158
  'standard': {'key_color': [0,255,0], 'tolerance': 38, 'edge_softness': 2, 'spill_suppression': 0.35},
159
- 'studio': {'key_color': [0,255,0], 'tolerance': 30, 'edge_softness': 1, 'spill_suppression': 0.45},
160
- 'outdoor': {'key_color': [0,255,0], 'tolerance': 50, 'edge_softness': 3, 'spill_suppression': 0.25},
161
  }
162
-
163
- # ---------------------------------------------------------------------------
164
- # ENHANCED Quality profiles with clear differentiation
165
- # ---------------------------------------------------------------------------
166
- QUALITY_PROFILES: Dict[str, Dict[str, Any]] = {
167
- "speed": dict(
168
- refine_stride=10, # Refine every 10th frame only
169
- spill=0.15, # Minimal spill work
170
- edge_softness=1, # Basic edges
171
- mix=0.50, # 50/50 chroma/segmentation
172
- dilate=1, # Minimal morphology
173
- blur=0, # No blur
174
- bg_sigma=0.0 # No background blur
175
- ),
176
- "balanced": dict(
177
- refine_stride=3, # Refine every 3rd frame
178
- spill=0.35, # Moderate spill removal
179
- edge_softness=2, # Smooth edges
180
- mix=0.70, # Favor segmentation (70%)
181
- dilate=2, # Some hole filling
182
- blur=1, # Light feathering
183
- bg_sigma=0.8 # Subtle background blur
184
- ),
185
- "max": dict(
186
- refine_stride=1, # Refine EVERY frame
187
- spill=0.50, # Strong spill removal
188
- edge_softness=3, # Very smooth edges
189
- mix=0.85, # Heavy segmentation bias (85%)
190
- dilate=3, # Strong hole filling
191
- blur=2, # More feathering
192
- bg_sigma=1.5 # Visible background blur
193
- ),
194
- }
195
-
196
  # ---------------------------------------------------------------------------
197
  # Two-Stage Processor
198
  # ---------------------------------------------------------------------------
@@ -202,29 +175,18 @@ def __init__(self, sam2_predictor=None, matanyone_model=None):
202
  self.matanyone = matanyone_model
203
  self.mask_cache_dir = Path("/tmp/mask_cache")
204
  self.mask_cache_dir.mkdir(parents=True, exist_ok=True)
205
-
206
  # Internal flags/state
207
  self._mat_bootstrapped = False
208
- self._alpha_prev: Optional[np.ndarray] = None # temporal smoothing
209
-
210
  # Frame tracking
211
  self.total_frames_processed = 0
212
  self.frames_refined = 0
213
-
214
- # Quality selection at construction
215
- qname = os.getenv("BFX_QUALITY", "balanced").strip().lower()
216
- if qname not in QUALITY_PROFILES:
217
- qname = "balanced"
218
- self.quality = qname
219
- self.q = QUALITY_PROFILES[qname]
220
-
221
- # Log quality details
222
- logger.info(f"TwoStageProcessor quality='{self.quality}' ⇒ refine_every={self.q['refine_stride']}, "
223
- f"spill={self.q['spill']:.2f}, mix={self.q['mix']:.2f}, bg_blur={self.q['bg_sigma']:.1f}")
224
  logger.info(f"TwoStageProcessor init – SAM2: {self.sam2 is not None} | MatAnyOne: {self.matanyone is not None}")
225
-
226
  # --------------------------- internal utils ---------------------------
227
-
228
  def _unwrap_sam2(self, predictor):
229
  """Unwrap the SAM2 predictor if needed."""
230
  if predictor is None:
@@ -232,20 +194,6 @@ def _unwrap_sam2(self, predictor):
232
  if hasattr(predictor, 'sam_predictor'):
233
  return predictor.sam_predictor
234
  return predictor
235
-
236
- def _refresh_quality_from_env(self):
237
- """Pick up UI changes to BFX_QUALITY without rebuilding models."""
238
- qname = os.getenv("BFX_QUALITY", self.quality).strip().lower()
239
- if qname not in QUALITY_PROFILES:
240
- qname = "balanced"
241
- if qname != getattr(self, "quality", None) or not hasattr(self, "q"):
242
- old_quality = self.quality
243
- self.quality = qname
244
- self.q = QUALITY_PROFILES[qname]
245
- logger.info(f"Quality switched from '{old_quality}' to '{self.quality}' ⇒ "
246
- f"refine_every={self.q['refine_stride']}, spill={self.q['spill']:.2f}, "
247
- f"mix={self.q['mix']:.2f}, bg_blur={self.q['bg_sigma']:.1f}")
248
-
249
  def _get_mask(self, frame: np.ndarray) -> np.ndarray:
250
  """Get segmentation mask using SAM2 (delegates to project helper)."""
251
  if self.sam2 is None:
@@ -253,7 +201,6 @@ def _get_mask(self, frame: np.ndarray) -> np.ndarray:
253
  gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
254
  _, mask = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
255
  return mask
256
-
257
  try:
258
  mask = segment_person_hq(frame, self.sam2)
259
  # segment_person_hq returns either uint8(0..255) or float(0..1) in most builds
@@ -263,7 +210,6 @@ def _get_mask(self, frame: np.ndarray) -> np.ndarray:
263
  gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
264
  _, mask = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
265
  return mask
266
-
267
  @staticmethod
268
  def _to_binary_mask(mask: np.ndarray) -> Optional[np.ndarray]:
269
  """Convert mask to uint8(0..255)."""
@@ -275,7 +221,6 @@ def _to_binary_mask(mask: np.ndarray) -> Optional[np.ndarray]:
275
  m = np.clip(mask, 0.0, 1.0)
276
  return (m * 255.0 + 0.5).astype(np.uint8)
277
  return mask
278
-
279
  @staticmethod
280
  def _to_float01(mask: np.ndarray, h: int = None, w: int = None) -> Optional[np.ndarray]:
281
  """Float [0,1] mask, optionally resized to (h,w)."""
@@ -287,14 +232,13 @@ def _to_float01(mask: np.ndarray, h: int = None, w: int = None) -> Optional[np.n
287
  if h is not None and w is not None and (m.shape[0] != h or m.shape[1] != w):
288
  m = cv2.resize(m, (w, h), interpolation=cv2.INTER_LINEAR)
289
  return np.clip(m, 0.0, 1.0)
290
-
291
- def _apply_greenscreen_hard(self, frame: np.ndarray, mask: np.ndarray, bg: np.ndarray) -> np.ndarray:
292
  """Apply hard greenscreen compositing."""
293
  mask_3ch = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) if mask.ndim == 2 else mask
294
  mask_norm = mask_3ch.astype(np.float32) / 255.0
295
  result = frame * mask_norm + bg * (1 - mask_norm)
296
  return result.astype(np.uint8)
297
-
298
  # -------- improved spill suppression (preserves luminance & skin) --------
299
  def _suppress_green_spill(self, frame: np.ndarray, amount: float = 0.35) -> np.ndarray:
300
  """
@@ -305,11 +249,10 @@ def _suppress_green_spill(self, frame: np.ndarray, amount: float = 0.35) -> np.n
305
  green_dom = (g > r) & (g > b)
306
  avg_rb = (r + b) * 0.5
307
  g2 = np.where(green_dom, g*(1.0-amount) + avg_rb*amount, g)
308
- skin = (r > g + 12) # protect skin tones
309
  g2 = np.where(skin, g, g2)
310
  out = cv2.merge([np.clip(b,0,255), np.clip(g2,0,255), np.clip(r,0,255)]).astype(np.uint8)
311
  return out
312
-
313
  # -------- edge-aware alpha refinement (guided-like) --------
314
  def _refine_alpha_edges(self, frame_bgr: np.ndarray, alpha_u8: np.ndarray, radius: int = 3, iters: int = 1) -> np.ndarray:
315
  """
@@ -319,11 +262,9 @@ def _refine_alpha_edges(self, frame_bgr: np.ndarray, alpha_u8: np.ndarray, radiu
319
  a = alpha_u8.astype(np.uint8)
320
  if radius <= 0:
321
  return a
322
-
323
  band = cv2.Canny(a, 32, 64)
324
  if band.max() == 0:
325
  return a
326
-
327
  for _ in range(max(1, iters)):
328
  a_blur = cv2.GaussianBlur(a, (radius*2+1, radius*2+1), 0)
329
  b,g,r = cv2.split(frame_bgr.astype(np.float32))
@@ -331,9 +272,7 @@ def _refine_alpha_edges(self, frame_bgr: np.ndarray, alpha_u8: np.ndarray, radiu
331
  spill_mask = (green_dom & (a > 96) & (a < 224)).astype(np.uint8)*255
332
  u = cv2.bitwise_or(band, spill_mask)
333
  a = np.where(u>0, a_blur, a).astype(np.uint8)
334
-
335
  return a
336
-
337
  # -------- soft key based on chosen color (robust to blue/cyan/magenta) --------
338
  def _soft_key_mask(self, frame_bgr: np.ndarray, key_bgr: np.ndarray, tol: int = 40) -> np.ndarray:
339
  """
@@ -341,14 +280,12 @@ def _soft_key_mask(self, frame_bgr: np.ndarray, key_bgr: np.ndarray, tol: int =
341
  """
342
  if key_bgr is None:
343
  return np.full(frame_bgr.shape[:2], 255, np.uint8)
344
-
345
  ycbcr = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2YCrCb).astype(np.float32)
346
  kycbcr = cv2.cvtColor(key_bgr.reshape(1,1,3).astype(np.uint8), cv2.COLOR_BGR2YCrCb).astype(np.float32)[0,0]
347
  d = np.linalg.norm((ycbcr[...,1:] - kycbcr[1:]), axis=-1)
348
  d = cv2.GaussianBlur(d, (5,5), 0)
349
- alpha = 255.0 * np.clip((d - tol) / (tol*1.7), 0.0, 1.0) # far from key = keep (255)
350
  return alpha.astype(np.uint8)
351
-
352
  # --------------------- MatAnyone bootstrap ----------------------
353
  def _bootstrap_matanyone_if_needed(self, frame_bgr: np.ndarray, coarse_mask: np.ndarray):
354
  """
@@ -361,94 +298,80 @@ def _bootstrap_matanyone_if_needed(self, frame_bgr: np.ndarray, coarse_mask: np.
361
  h, w = frame_bgr.shape[:2]
362
  mask_f = self._to_float01(coarse_mask, h, w)
363
  rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
364
- _ = self.matanyone(rgb, mask_f) # boot only; ignore returned alpha
365
  self._mat_bootstrapped = True
366
  logger.info("MatAnyone session bootstrapped with first-frame mask.")
367
  except Exception as e:
368
  logger.warning(f"MatAnyone bootstrap failed (continuing without): {e}")
369
-
370
  def _should_refine_frame(self, frame_idx: int) -> bool:
371
  """Check if current frame should be refined based on quality profile"""
372
  if not self.matanyone:
373
  return False
374
-
375
  # Always refine first frame for bootstrap
376
  if frame_idx == 0:
377
  return True
378
-
379
- stride = max(1, int(self.q.get("refine_stride", 3)))
380
- return (frame_idx % stride) == 0
381
-
382
  # ---------------------------------------------------------------------
383
- # Stage 1 – Original → keyed (green/blue/…) -- chooses colour on 1st frame
384
  # ---------------------------------------------------------------------
385
  def stage1_extract_to_greenscreen(
386
  self,
387
  video_path: str,
388
  output_path: str,
389
  *,
390
- key_color_mode: str = "auto", # "auto" | "green" | "blue" | "cyan" | "magenta"
391
  progress_callback: Optional[Callable[[float, str], None]] = None,
392
- stop_event: Optional["threading.Event"] = None,
393
  ) -> Tuple[Optional[dict], str]:
394
-
395
  def _prog(p, d):
396
  if progress_callback:
397
  try:
398
  progress_callback(float(p), str(d))
399
  except Exception:
400
  pass
401
-
402
  try:
403
  # pick up any new quality selection
404
- self._refresh_quality_from_env()
405
-
406
  _prog(0.0, "Stage 1: opening video…")
407
  cap = cv2.VideoCapture(video_path)
408
  if not cap.isOpened():
409
  return None, "Could not open input video"
410
-
411
- fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
412
  total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 0
413
- w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
414
- h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
415
-
416
  base_writer, out_path = create_video_writer(output_path, fps, w, h)
417
  if base_writer is None:
418
  cap.release()
419
  return None, "Could not create output writer"
420
-
421
  # Use robust wrapper
422
  writer = RobustVideoWriter(base_writer, out_path)
423
-
424
  key_info: dict | None = None
425
- chosen_bgr = np.array([0, 255, 0], np.uint8) # default
426
  probe_done = False
427
  masks: List[np.ndarray] = []
428
  frame_idx = 0
429
  self.frames_refined = 0
430
-
431
- solid_bg = np.zeros((h, w, 3), np.uint8) # overwritten per-frame
432
-
433
  while True:
434
  if stop_event and stop_event.is_set():
435
  _prog(1.0, "Stage 1: cancelled")
436
  break
437
-
438
  ok, frame = cap.read()
439
  if not ok:
440
  break
441
-
442
  # --- SAM2 segmentation ---
443
  mask = self._get_mask(frame)
444
-
445
  # --- MatAnyone bootstrap exactly once (first frame) ---
446
  if frame_idx == 0 and self.matanyone is not None:
447
  try:
448
  self._bootstrap_matanyone_if_needed(frame, mask)
449
  except Exception as e:
450
  logger.warning(f"Bootstrap error (non-fatal): {e}")
451
-
452
  # --- Decide key colour once ---
453
  if not probe_done:
454
  if key_color_mode.lower() == "auto":
@@ -460,33 +383,28 @@ def _prog(p, d):
460
  chosen_bgr = cand["bgr"]
461
  probe_done = True
462
  logger.info(f"[TwoStage] Using key colour: {key_color_mode} → {chosen_bgr.tolist()}")
463
-
464
  # --- Optional refinement via MatAnyone (profile cadence) ---
465
  if self._should_refine_frame(frame_idx):
466
  try:
467
  mask = refine_mask_hq(frame, mask, self.matanyone, fallback_enabled=True)
468
  self.frames_refined += 1
469
- logger.debug(f"Frame {frame_idx}: Refined (quality={self.quality})")
470
  except Exception as e:
471
  logger.warning(f"MatAnyOne refine fail f={frame_idx}: {e}")
472
  else:
473
- logger.debug(f"Frame {frame_idx}: Skipped refinement (cadence={self.q['refine_stride']})")
474
-
475
  # --- Composite onto solid key colour ---
476
  solid_bg[:] = chosen_bgr
477
  mask_u8 = self._to_binary_mask(mask)
478
  gs = self._apply_greenscreen_hard(frame, mask_u8, solid_bg)
479
  writer.write(gs)
480
  masks.append(mask_u8)
481
-
482
  frame_idx += 1
483
  pct = 0.05 + 0.9 * (frame_idx / total) if total else min(0.95, 0.05 + frame_idx * 0.002)
484
  _prog(pct, f"Stage 1: {frame_idx}/{total or '?'} (refined: {self.frames_refined})")
485
-
486
  cap.release()
487
  writer.release()
488
  self.total_frames_processed = frame_idx
489
-
490
  # save mask cache
491
  try:
492
  cache_file = self.mask_cache_dir / (Path(out_path).stem + "_masks.pkl")
@@ -495,24 +413,21 @@ def _prog(p, d):
495
  logger.info(f"Cached {len(masks)} masks to {cache_file}")
496
  except Exception as e:
497
  logger.warning(f"mask cache save fail: {e}")
498
-
499
  _prog(1.0, "Stage 1: complete")
500
-
501
  # Log quality impact
502
  logger.info(f"Stage 1 complete: {frame_idx} frames, {self.frames_refined} refined "
503
  f"({100*self.frames_refined/max(1,frame_idx):.1f}%)")
504
-
505
  return (
506
  {"path": out_path, "frames": frame_idx, "key_bgr": chosen_bgr.tolist()},
507
  f"Green-screen video created ({frame_idx} frames, {self.frames_refined} refined)"
508
  )
509
-
510
  except Exception as e:
511
  logger.error(f"Stage 1 error: {e}\n{traceback.format_exc()}")
512
  return None, f"Stage 1 failed: {e}"
513
-
514
  # ---------------------------------------------------------------------
515
- # Stage 2 – keyed video → final composite (hybrid matte)
516
  # ---------------------------------------------------------------------
517
  def stage2_greenscreen_to_final(
518
  self,
@@ -522,31 +437,27 @@ def stage2_greenscreen_to_final(
522
  *,
523
  chroma_settings: Optional[Dict[str, Any]] = None,
524
  progress_callback: Optional[Callable[[float, str], None]] = None,
525
- stop_event: Optional["threading.Event"] = None,
526
- key_bgr: Optional[np.ndarray] = None, # pass chosen key color
527
  ) -> Tuple[Optional[str], str]:
528
-
529
  def _prog(p, d):
530
  if progress_callback:
531
  try:
532
  progress_callback(float(p), str(d))
533
  except Exception:
534
  pass
535
-
536
  try:
537
  # pick up any new quality selection
538
- self._refresh_quality_from_env()
539
-
540
  _prog(0.0, "Stage 2: opening keyed video…")
541
  cap = cv2.VideoCapture(gs_path)
542
  if not cap.isOpened():
543
  return None, "Could not open keyed video"
544
-
545
- fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
546
  total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 0
547
- w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
548
- h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
549
-
550
  # Load or prepare background
551
  if isinstance(background, str):
552
  bg = cv2.imread(background)
@@ -556,21 +467,18 @@ def _prog(p, d):
556
  bg = cv2.resize(bg, (w, h))
557
  else:
558
  bg = cv2.resize(background, (w, h))
559
-
560
  # Optional tiny BG blur per profile to hide seams on flat BGs
561
- sigma = float(self.q.get("bg_sigma", 0.0))
562
  if sigma > 0:
563
  bg = cv2.GaussianBlur(bg, (0, 0), sigmaX=sigma, sigmaY=sigma)
564
  logger.debug(f"Applied background blur: sigma={sigma:.1f}")
565
-
566
  base_writer, out_path = create_video_writer(output_path, fps, w, h)
567
  if base_writer is None:
568
  cap.release()
569
  return None, "Could not create output writer"
570
-
571
  # Use robust wrapper
572
  writer = RobustVideoWriter(base_writer, out_path)
573
-
574
  # Load cached masks if available
575
  masks = None
576
  try:
@@ -581,29 +489,23 @@ def _prog(p, d):
581
  logger.info(f"Loaded {len(masks)} cached masks")
582
  except Exception as e:
583
  logger.warning(f"Could not load mask cache: {e}")
584
-
585
  # Get chroma settings and override with profile
586
  settings = chroma_settings or CHROMA_PRESETS.get('standard', {})
587
- tolerance = int(settings.get('tolerance', 38)) # keep user tolerance
588
- edge_softness = int(self.q.get('edge_softness', settings.get('edge_softness', 2)))
589
- spill_suppression = float(self.q.get('spill', settings.get('spill_suppression', 0.35)))
590
-
591
  # If caller didn't pass key_bgr, try preset or default green
592
  if key_bgr is None:
593
  key_bgr = np.array(settings.get('key_color', [0,255,0]), dtype=np.uint8)
594
-
595
- self._alpha_prev = None # reset temporal smoothing per render
596
-
597
  frame_idx = 0
598
  while True:
599
  if stop_event and stop_event.is_set():
600
  _prog(1.0, "Stage 2: cancelled")
601
  break
602
-
603
  ok, frame = cap.read()
604
  if not ok:
605
  break
606
-
607
  # Apply chroma keying with optional mask assistance
608
  if masks and frame_idx < len(masks):
609
  mask = masks[frame_idx]
@@ -623,34 +525,28 @@ def _prog(p, d):
623
  spill_suppression=spill_suppression,
624
  key_bgr=key_bgr
625
  )
626
-
627
  writer.write(final_frame)
628
  frame_idx += 1
629
  pct = 0.05 + 0.9 * (frame_idx / total) if total else min(0.95, 0.05 + frame_idx * 0.002)
630
  _prog(pct, f"Stage 2: {frame_idx}/{total or '?'}")
631
-
632
  cap.release()
633
  writer.release()
634
-
635
  _prog(1.0, "Stage 2: complete")
636
-
637
  # Verify frame counts match
638
  if total > 0 and frame_idx != total:
639
  logger.warning(f"Frame count mismatch: processed {frame_idx}, expected {total}")
640
-
641
  return out_path, f"Final composite created ({frame_idx} frames)"
642
-
643
  except Exception as e:
644
  logger.error(f"Stage 2 error: {e}\n{traceback.format_exc()}")
645
  return None, f"Stage 2 failed: {e}"
646
-
647
  # ---------------- chroma + hybrid compositors (polished) ----------------
648
  def _chroma_key_composite(self, frame, bg, *, tolerance=38, edge_softness=2, spill_suppression=0.35, key_bgr: Optional[np.ndarray] = None):
649
  """Apply chroma key compositing with soft color distance + edge refinement."""
650
  # 1) spill first
651
  if spill_suppression > 0:
652
  frame = self._suppress_green_spill(frame, spill_suppression)
653
-
654
  # 2) build alpha
655
  if key_bgr is not None:
656
  alpha = self._soft_key_mask(frame, key_bgr, tol=int(tolerance))
@@ -660,23 +556,19 @@ def _chroma_key_composite(self, frame, bg, *, tolerance=38, edge_softness=2, spi
660
  lower_green = np.array([40, 40, 40])
661
  upper_green = np.array([80, 255, 255])
662
  alpha = cv2.bitwise_not(cv2.inRange(hsv, lower_green, upper_green))
663
-
664
  # 3) soft edges + refinement
665
  if edge_softness > 0:
666
  k = edge_softness * 2 + 1
667
  alpha = cv2.GaussianBlur(alpha, (k, k), 0)
668
  alpha = self._refine_alpha_edges(frame, alpha, radius=max(1, edge_softness), iters=1)
669
-
670
  # 4) temporal smoothing
671
  if self._alpha_prev is not None and self._alpha_prev.shape == alpha.shape:
672
  alpha = cv2.addWeighted(alpha, 0.75, self._alpha_prev, 0.25, 0)
673
  self._alpha_prev = alpha
674
-
675
  # 5) composite
676
  mask_3ch = cv2.cvtColor(alpha, cv2.COLOR_GRAY2BGR).astype(np.float32) / 255.0
677
  out = frame.astype(np.float32) * mask_3ch + bg.astype(np.float32) * (1.0 - mask_3ch)
678
  return np.clip(out, 0, 255).astype(np.uint8)
679
-
680
  def _hybrid_composite(self, frame, bg, mask, *, tolerance=38, edge_softness=2, spill_suppression=0.35, key_bgr: Optional[np.ndarray] = None):
681
  """Apply hybrid compositing using both chroma key and cached mask, with profile controls."""
682
  chroma_result = self._chroma_key_composite(
@@ -688,27 +580,22 @@ def _hybrid_composite(self, frame, bg, mask, *, tolerance=38, edge_softness=2, s
688
  )
689
  if mask is None:
690
  return chroma_result
691
-
692
  # profile-driven dilate/feather on cached mask to close pinholes + soften edges
693
  m = mask
694
- d = int(self.q.get("dilate", 0))
695
  if d > 0:
696
  k = 2*d + 1
697
  se = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
698
  m = cv2.dilate(m, se, iterations=1)
699
- b = int(self.q.get("blur", 0))
700
  if b > 0:
701
  m = cv2.GaussianBlur(m, (2*b+1, 2*b+1), 0)
702
-
703
  m3 = cv2.cvtColor(m, cv2.COLOR_GRAY2BGR) if m.ndim == 2 else m
704
  m3f = (m3.astype(np.float32) / 255.0)
705
-
706
  seg_comp = frame.astype(np.float32) * m3f + bg.astype(np.float32) * (1.0 - m3f)
707
-
708
- mix = float(self.q.get("mix", 0.7)) # weight towards segmentation on "max"
709
  out = chroma_result.astype(np.float32) * (1.0 - mix) + seg_comp * mix
710
  return np.clip(out, 0, 255).astype(np.uint8)
711
-
712
  # ---------------------------------------------------------------------
713
  # Combined pipeline
714
  # ---------------------------------------------------------------------
@@ -721,73 +608,59 @@ def process_full_pipeline(
721
  key_color_mode: str = "auto",
722
  chroma_settings: Optional[Dict[str, Any]] = None,
723
  progress_callback: Optional[Callable[[float, str], None]] = None,
724
- stop_event: Optional["threading.Event"] = None,
725
- ) -> Tuple[Optional[str], str]:
726
- """Run both stages in sequence."""
727
-
728
  def _combined_progress(pct, desc):
729
  # Scale progress: Stage 1 is 0-50%, Stage 2 is 50-100%
730
  if "Stage 1" in desc:
731
  actual_pct = pct * 0.5
732
- else: # Stage 2
733
  actual_pct = 0.5 + pct * 0.5
734
-
735
  if progress_callback:
736
  try:
737
  progress_callback(actual_pct, desc)
738
  except Exception:
739
  pass
740
-
741
  try:
742
- # pick up any new quality selection once per run
743
- self._refresh_quality_from_env()
744
-
745
  # Reset per-video state
746
  self._mat_bootstrapped = False
747
  self._alpha_prev = None
748
  self.total_frames_processed = 0
749
  self.frames_refined = 0
750
-
751
  if self.matanyone is not None and hasattr(self.matanyone, "reset"):
752
  try:
753
  self.matanyone.reset()
754
  except Exception:
755
  pass
756
-
757
- # Stage 1
758
- temp_gs_path = tempfile.mktemp(suffix="_greenscreen.mp4")
759
  stage1_result, stage1_msg = self.stage1_extract_to_greenscreen(
760
- video_path, temp_gs_path,
761
  key_color_mode=key_color_mode,
762
  progress_callback=_combined_progress,
763
  stop_event=stop_event
764
  )
765
  if stage1_result is None:
766
- return None, stage1_msg
767
-
768
  # Stage 2 (pass through chosen key color)
769
  key_bgr = np.array(stage1_result.get("key_bgr", [0,255,0]), dtype=np.uint8)
770
  final_path, stage2_msg = self.stage2_greenscreen_to_final(
771
- stage1_result["path"], background, output_path,
772
  chroma_settings=chroma_settings,
773
  progress_callback=_combined_progress,
774
  stop_event=stop_event,
775
  key_bgr=key_bgr,
776
  )
777
-
778
- # Clean up temp file
779
- try:
780
- os.remove(temp_gs_path)
781
- except Exception:
782
- pass
783
-
784
  # Report quality impact
785
- logger.info(f"Pipeline complete with quality='{self.quality}': "
786
  f"{self.total_frames_processed} frames, "
787
  f"{self.frames_refined} refined ({100*self.frames_refined/max(1,self.total_frames_processed):.1f}%)")
788
-
789
- return final_path, stage2_msg
790
-
791
  except Exception as e:
792
  logger.error(f"Full pipeline error: {e}\n{traceback.format_exc()}")
793
- return None, f"Pipeline failed: {e}"
 
1
  #!/usr/bin/env python3
2
  """
3
+ Two-Stage Green-Screen Processing System ✅ 2025-08-29
4
  Stage 1: Original → keyed background (auto-selected colour)
5
+ Stage 2: Keyed video → final composite (hybrid chroma + segmentation rescue)
 
6
  UPDATED: Enhanced quality profiles, improved frame handling, better status reporting
7
+ - New: Integrate QualityManager for separated logic
8
+ - New: Return green screen path for monitoring
9
+ - Fix: Force green key color
10
+ - Fix: Use RobustVideoWriter to prevent frame loss
11
  """
12
  from __future__ import annotations
 
13
  import cv2, numpy as np, os, gc, pickle, logging, tempfile, traceback, threading
14
  from pathlib import Path
 
 
15
  from utils.cv_processing import segment_person_hq, refine_mask_hq
16
+ from quality_manager import quality_manager # New quality manager import
17
  # Project logger if available
18
  try:
19
  from utils.logger import get_logger
20
  logger = get_logger(__name__)
21
  except Exception:
22
  logger = logging.getLogger(__name__)
 
23
  # ---------------------------------------------------------------------------
24
  # Local video-writer helper with frame count guarantee
25
  # ---------------------------------------------------------------------------
 
33
  base, curr_ext = os.path.splitext(output_path)
34
  if curr_ext.lower() not in [".mp4", ".avi", ".mov", ".mkv"]:
35
  output_path = base + ext
 
36
  fourcc = cv2.VideoWriter_fourcc(*("mp4v" if prefer_mp4 else "XVID"))
37
  writer = cv2.VideoWriter(output_path, fourcc, float(fps), (int(width), int(height)))
38
  if writer is None or not writer.isOpened():
 
47
  except Exception as e:
48
  logger.error(f"create_video_writer failed: {e}")
49
  return None, output_path
 
50
  # ---------------------------------------------------------------------------
51
  # Robust video writer wrapper to prevent frame loss
52
  # ---------------------------------------------------------------------------
53
  class RobustVideoWriter:
54
  """Wrapper that ensures all frames are written"""
55
+
56
  def __init__(self, writer, output_path: str):
57
  self.writer = writer
58
  self.output_path = output_path
59
  self.frame_buffer = []
60
  self.frames_written = 0
61
  self.frames_attempted = 0
62
+
63
  def write(self, frame):
64
  """Buffer and write frame"""
65
  if frame is None:
66
  return False
67
+
68
  self.frames_attempted += 1
69
  self.frame_buffer.append(frame.copy())
70
+
71
  # Try to write buffered frames
72
  while self.frame_buffer and self.writer:
73
  try:
 
78
  logger.warning(f"Frame write failed: {e}")
79
  return False
80
  return True
81
+
82
  def release(self):
83
  """Flush remaining frames and close"""
84
  # Write any remaining buffered frames
 
89
  self.frames_written += 1
90
  except Exception:
91
  break
92
+
93
+ # Duplicate last frame 3 times to force flush
94
+ if self.frames_written > 0 and self.writer:
95
+ last_frame = self.frame_buffer[-1] if self.frame_buffer else None
96
+ if last_frame is None:
97
+ # Read last written frame if needed
98
+ try:
99
+ cap = cv2.VideoCapture(self.output_path)
100
+ cap.set(cv2.CAP_PROP_POS_FRAMES, self.frames_written - 1)
101
+ _, last_frame = cap.read()
102
+ cap.release()
103
+ except Exception:
104
+ pass
105
+ if last_frame is not None:
106
+ for _ in range(3):
107
+ self.writer.write(last_frame)
108
+ self.frames_written += 1
109
+
110
  # Close writer
111
  if self.writer:
112
  self.writer.release()
113
+
114
  # Log statistics
115
  logger.info(f"Video writer closed: {self.frames_written}/{self.frames_attempted} frames written")
116
+
117
  # Verify output exists
118
  if os.path.exists(self.output_path):
119
  size = os.path.getsize(self.output_path)
 
121
  logger.error("WARNING: Output file is empty!")
122
  else:
123
  logger.info(f"Output file size: {size:,} bytes")
 
124
  # ---------------------------------------------------------------------------
125
  # Key-colour helpers (fast, no external deps)
126
  # ---------------------------------------------------------------------------
 
128
  hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
129
  # OpenCV H is 0-180; scale to degrees 0-360
130
  return hsv[..., 0].astype(np.float32) * 2.0
 
131
  def _hue_distance(a_deg: float, b_deg: float) -> float:
132
  """Circular distance on the hue wheel (degrees)."""
133
  d = abs(a_deg - b_deg) % 360.0
134
  return min(d, 360.0 - d)
 
135
  def _key_candidates_bgr() -> dict:
136
  return {
137
+ "green": {"bgr": np.array([ 0,255, 0], dtype=np.uint8), "hue": 120.0},
138
+ "blue": {"bgr": np.array([255, 0, 0], dtype=np.uint8), "hue": 240.0},
139
+ "cyan": {"bgr": np.array([255,255, 0], dtype=np.uint8), "hue": 180.0},
140
+ "magenta": {"bgr": np.array([255, 0,255], dtype=np.uint8), "hue": 300.0},
141
  }
 
142
  def _choose_best_key_color(frame_bgr: np.ndarray, mask_uint8: np.ndarray) -> dict:
143
  """Pick the candidate colour farthest from the actor's dominant hues."""
144
  try:
145
  fg = frame_bgr[mask_uint8 > 127]
146
  if fg.size < 1_000:
147
  return _key_candidates_bgr()["green"]
 
148
  fg_hue = _bgr_to_hsv_hue_deg(fg.reshape(-1, 1, 3)).reshape(-1)
149
  hist, edges = np.histogram(fg_hue, bins=36, range=(0.0, 360.0))
150
  top_idx = np.argsort(hist)[-3:]
151
  top_hues = [(edges[i] + edges[i+1]) * 0.5 for i in top_idx]
 
152
  best_name, best_score = None, -1.0
153
  for name, info in _key_candidates_bgr().items():
154
  cand_hue = info["hue"]
 
158
  return _key_candidates_bgr().get(best_name, _key_candidates_bgr()["green"])
159
  except Exception:
160
  return _key_candidates_bgr()["green"]
 
161
  # ---------------------------------------------------------------------------
162
  # Chroma presets
163
  # ---------------------------------------------------------------------------
164
  CHROMA_PRESETS: Dict[str, Dict[str, Any]] = {
165
  'standard': {'key_color': [0,255,0], 'tolerance': 38, 'edge_softness': 2, 'spill_suppression': 0.35},
166
+ 'studio': {'key_color': [0,255,0], 'tolerance': 30, 'edge_softness': 1, 'spill_suppression': 0.45},
167
+ 'outdoor': {'key_color': [0,255,0], 'tolerance': 50, 'edge_softness': 3, 'spill_suppression': 0.25},
168
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  # ---------------------------------------------------------------------------
170
  # Two-Stage Processor
171
  # ---------------------------------------------------------------------------
 
175
  self.matanyone = matanyone_model
176
  self.mask_cache_dir = Path("/tmp/mask_cache")
177
  self.mask_cache_dir.mkdir(parents=True, exist_ok=True)
 
178
  # Internal flags/state
179
  self._mat_bootstrapped = False
180
+ self._alpha_prev: Optional[np.ndarray] = None # temporal smoothing
181
+
182
  # Frame tracking
183
  self.total_frames_processed = 0
184
  self.frames_refined = 0
185
+
186
+ # Load quality profile
187
+ self.q = quality_manager.get_params()
 
 
 
 
 
 
 
 
188
  logger.info(f"TwoStageProcessor init – SAM2: {self.sam2 is not None} | MatAnyOne: {self.matanyone is not None}")
 
189
  # --------------------------- internal utils ---------------------------
 
190
  def _unwrap_sam2(self, predictor):
191
  """Unwrap the SAM2 predictor if needed."""
192
  if predictor is None:
 
194
  if hasattr(predictor, 'sam_predictor'):
195
  return predictor.sam_predictor
196
  return predictor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  def _get_mask(self, frame: np.ndarray) -> np.ndarray:
198
  """Get segmentation mask using SAM2 (delegates to project helper)."""
199
  if self.sam2 is None:
 
201
  gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
202
  _, mask = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
203
  return mask
 
204
  try:
205
  mask = segment_person_hq(frame, self.sam2)
206
  # segment_person_hq returns either uint8(0..255) or float(0..1) in most builds
 
210
  gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
211
  _, mask = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
212
  return mask
 
213
  @staticmethod
214
  def _to_binary_mask(mask: np.ndarray) -> Optional[np.ndarray]:
215
  """Convert mask to uint8(0..255)."""
 
221
  m = np.clip(mask, 0.0, 1.0)
222
  return (m * 255.0 + 0.5).astype(np.uint8)
223
  return mask
 
224
  @staticmethod
225
  def _to_float01(mask: np.ndarray, h: int = None, w: int = None) -> Optional[np.ndarray]:
226
  """Float [0,1] mask, optionally resized to (h,w)."""
 
232
  if h is not None and w is not None and (m.shape[0] != h or m.shape[1] != w):
233
  m = cv2.resize(m, (w, h), interpolation=cv2.INTER_LINEAR)
234
  return np.clip(m, 0.0, 1.0)
235
+ @staticmethod
236
+ def _apply_greenscreen_hard(frame: np.ndarray, mask: np.ndarray, bg: np.ndarray) -> np.ndarray:
237
  """Apply hard greenscreen compositing."""
238
  mask_3ch = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) if mask.ndim == 2 else mask
239
  mask_norm = mask_3ch.astype(np.float32) / 255.0
240
  result = frame * mask_norm + bg * (1 - mask_norm)
241
  return result.astype(np.uint8)
 
242
  # -------- improved spill suppression (preserves luminance & skin) --------
243
  def _suppress_green_spill(self, frame: np.ndarray, amount: float = 0.35) -> np.ndarray:
244
  """
 
249
  green_dom = (g > r) & (g > b)
250
  avg_rb = (r + b) * 0.5
251
  g2 = np.where(green_dom, g*(1.0-amount) + avg_rb*amount, g)
252
+ skin = (r > g + 12) # protect skin tones
253
  g2 = np.where(skin, g, g2)
254
  out = cv2.merge([np.clip(b,0,255), np.clip(g2,0,255), np.clip(r,0,255)]).astype(np.uint8)
255
  return out
 
256
  # -------- edge-aware alpha refinement (guided-like) --------
257
  def _refine_alpha_edges(self, frame_bgr: np.ndarray, alpha_u8: np.ndarray, radius: int = 3, iters: int = 1) -> np.ndarray:
258
  """
 
262
  a = alpha_u8.astype(np.uint8)
263
  if radius <= 0:
264
  return a
 
265
  band = cv2.Canny(a, 32, 64)
266
  if band.max() == 0:
267
  return a
 
268
  for _ in range(max(1, iters)):
269
  a_blur = cv2.GaussianBlur(a, (radius*2+1, radius*2+1), 0)
270
  b,g,r = cv2.split(frame_bgr.astype(np.float32))
 
272
  spill_mask = (green_dom & (a > 96) & (a < 224)).astype(np.uint8)*255
273
  u = cv2.bitwise_or(band, spill_mask)
274
  a = np.where(u>0, a_blur, a).astype(np.uint8)
 
275
  return a
 
276
  # -------- soft key based on chosen color (robust to blue/cyan/magenta) --------
277
  def _soft_key_mask(self, frame_bgr: np.ndarray, key_bgr: np.ndarray, tol: int = 40) -> np.ndarray:
278
  """
 
280
  """
281
  if key_bgr is None:
282
  return np.full(frame_bgr.shape[:2], 255, np.uint8)
 
283
  ycbcr = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2YCrCb).astype(np.float32)
284
  kycbcr = cv2.cvtColor(key_bgr.reshape(1,1,3).astype(np.uint8), cv2.COLOR_BGR2YCrCb).astype(np.float32)[0,0]
285
  d = np.linalg.norm((ycbcr[...,1:] - kycbcr[1:]), axis=-1)
286
  d = cv2.GaussianBlur(d, (5,5), 0)
287
+ alpha = 255.0 * np.clip((d - tol) / (tol*1.7), 0.0, 1.0) # far from key = keep (255)
288
  return alpha.astype(np.uint8)
 
289
  # --------------------- MatAnyone bootstrap ----------------------
290
  def _bootstrap_matanyone_if_needed(self, frame_bgr: np.ndarray, coarse_mask: np.ndarray):
291
  """
 
298
  h, w = frame_bgr.shape[:2]
299
  mask_f = self._to_float01(coarse_mask, h, w)
300
  rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
301
+ _ = self.matanyone(rgb, mask_f) # boot only; ignore returned alpha
302
  self._mat_bootstrapped = True
303
  logger.info("MatAnyone session bootstrapped with first-frame mask.")
304
  except Exception as e:
305
  logger.warning(f"MatAnyone bootstrap failed (continuing without): {e}")
 
306
  def _should_refine_frame(self, frame_idx: int) -> bool:
307
  """Check if current frame should be refined based on quality profile"""
308
  if not self.matanyone:
309
  return False
310
+
311
  # Always refine first frame for bootstrap
312
  if frame_idx == 0:
313
  return True
314
+
315
+ return quality_manager.should_refine_frame(frame_idx)
 
 
316
  # ---------------------------------------------------------------------
317
+ # Stage 1 – Original → keyed (green/blue/…) -- chooses colour on 1st frame
318
  # ---------------------------------------------------------------------
319
  def stage1_extract_to_greenscreen(
320
  self,
321
  video_path: str,
322
  output_path: str,
323
  *,
324
+ key_color_mode: str = "auto", # "auto" | "green" | "blue" | "cyan" | "magenta"
325
  progress_callback: Optional[Callable[[float, str], None]] = None,
326
+ stop_event: Optional[threading.Event] = None,
327
  ) -> Tuple[Optional[dict], str]:
 
328
  def _prog(p, d):
329
  if progress_callback:
330
  try:
331
  progress_callback(float(p), str(d))
332
  except Exception:
333
  pass
 
334
  try:
335
  # pick up any new quality selection
336
+ quality_manager.load_profile() # Refresh
337
+ self.q = quality_manager.get_params()
338
  _prog(0.0, "Stage 1: opening video…")
339
  cap = cv2.VideoCapture(video_path)
340
  if not cap.isOpened():
341
  return None, "Could not open input video"
342
+ fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
 
343
  total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 0
344
+ w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
345
+ h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
 
346
  base_writer, out_path = create_video_writer(output_path, fps, w, h)
347
  if base_writer is None:
348
  cap.release()
349
  return None, "Could not create output writer"
350
+
351
  # Use robust wrapper
352
  writer = RobustVideoWriter(base_writer, out_path)
 
353
  key_info: dict | None = None
354
+ chosen_bgr = np.array([0, 255, 0], np.uint8) # Force green
355
  probe_done = False
356
  masks: List[np.ndarray] = []
357
  frame_idx = 0
358
  self.frames_refined = 0
359
+ solid_bg = np.zeros((h, w, 3), np.uint8) # overwritten per-frame
 
 
360
  while True:
361
  if stop_event and stop_event.is_set():
362
  _prog(1.0, "Stage 1: cancelled")
363
  break
 
364
  ok, frame = cap.read()
365
  if not ok:
366
  break
 
367
  # --- SAM2 segmentation ---
368
  mask = self._get_mask(frame)
 
369
  # --- MatAnyone bootstrap exactly once (first frame) ---
370
  if frame_idx == 0 and self.matanyone is not None:
371
  try:
372
  self._bootstrap_matanyone_if_needed(frame, mask)
373
  except Exception as e:
374
  logger.warning(f"Bootstrap error (non-fatal): {e}")
 
375
  # --- Decide key colour once ---
376
  if not probe_done:
377
  if key_color_mode.lower() == "auto":
 
383
  chosen_bgr = cand["bgr"]
384
  probe_done = True
385
  logger.info(f"[TwoStage] Using key colour: {key_color_mode} → {chosen_bgr.tolist()}")
 
386
  # --- Optional refinement via MatAnyone (profile cadence) ---
387
  if self._should_refine_frame(frame_idx):
388
  try:
389
  mask = refine_mask_hq(frame, mask, self.matanyone, fallback_enabled=True)
390
  self.frames_refined += 1
391
+ logger.debug(f"Frame {frame_idx}: Refined (quality={quality_manager.profile_name})")
392
  except Exception as e:
393
  logger.warning(f"MatAnyOne refine fail f={frame_idx}: {e}")
394
  else:
395
+ logger.debug(f"Frame {frame_idx}: Skipped refinement (cadence={self.q['refine_cadence']})")
 
396
  # --- Composite onto solid key colour ---
397
  solid_bg[:] = chosen_bgr
398
  mask_u8 = self._to_binary_mask(mask)
399
  gs = self._apply_greenscreen_hard(frame, mask_u8, solid_bg)
400
  writer.write(gs)
401
  masks.append(mask_u8)
 
402
  frame_idx += 1
403
  pct = 0.05 + 0.9 * (frame_idx / total) if total else min(0.95, 0.05 + frame_idx * 0.002)
404
  _prog(pct, f"Stage 1: {frame_idx}/{total or '?'} (refined: {self.frames_refined})")
 
405
  cap.release()
406
  writer.release()
407
  self.total_frames_processed = frame_idx
 
408
  # save mask cache
409
  try:
410
  cache_file = self.mask_cache_dir / (Path(out_path).stem + "_masks.pkl")
 
413
  logger.info(f"Cached {len(masks)} masks to {cache_file}")
414
  except Exception as e:
415
  logger.warning(f"mask cache save fail: {e}")
 
416
  _prog(1.0, "Stage 1: complete")
417
+
418
  # Log quality impact
419
  logger.info(f"Stage 1 complete: {frame_idx} frames, {self.frames_refined} refined "
420
  f"({100*self.frames_refined/max(1,frame_idx):.1f}%)")
421
+
422
  return (
423
  {"path": out_path, "frames": frame_idx, "key_bgr": chosen_bgr.tolist()},
424
  f"Green-screen video created ({frame_idx} frames, {self.frames_refined} refined)"
425
  )
 
426
  except Exception as e:
427
  logger.error(f"Stage 1 error: {e}\n{traceback.format_exc()}")
428
  return None, f"Stage 1 failed: {e}"
 
429
  # ---------------------------------------------------------------------
430
+ # Stage 2 – keyed video → final composite (hybrid matte)
431
  # ---------------------------------------------------------------------
432
  def stage2_greenscreen_to_final(
433
  self,
 
437
  *,
438
  chroma_settings: Optional[Dict[str, Any]] = None,
439
  progress_callback: Optional[Callable[[float, str], None]] = None,
440
+ stop_event: Optional[threading.Event] = None,
441
+ key_bgr: Optional[np.ndarray] = None, # pass chosen key color
442
  ) -> Tuple[Optional[str], str]:
 
443
  def _prog(p, d):
444
  if progress_callback:
445
  try:
446
  progress_callback(float(p), str(d))
447
  except Exception:
448
  pass
 
449
  try:
450
  # pick up any new quality selection
451
+ quality_manager.load_profile()
452
+ self.q = quality_manager.get_params()
453
  _prog(0.0, "Stage 2: opening keyed video…")
454
  cap = cv2.VideoCapture(gs_path)
455
  if not cap.isOpened():
456
  return None, "Could not open keyed video"
457
+ fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
 
458
  total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 0
459
+ w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
460
+ h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
 
461
  # Load or prepare background
462
  if isinstance(background, str):
463
  bg = cv2.imread(background)
 
467
  bg = cv2.resize(bg, (w, h))
468
  else:
469
  bg = cv2.resize(background, (w, h))
 
470
  # Optional tiny BG blur per profile to hide seams on flat BGs
471
+ sigma = float(self.q['bg_blur_sigma'])
472
  if sigma > 0:
473
  bg = cv2.GaussianBlur(bg, (0, 0), sigmaX=sigma, sigmaY=sigma)
474
  logger.debug(f"Applied background blur: sigma={sigma:.1f}")
 
475
  base_writer, out_path = create_video_writer(output_path, fps, w, h)
476
  if base_writer is None:
477
  cap.release()
478
  return None, "Could not create output writer"
479
+
480
  # Use robust wrapper
481
  writer = RobustVideoWriter(base_writer, out_path)
 
482
  # Load cached masks if available
483
  masks = None
484
  try:
 
489
  logger.info(f"Loaded {len(masks)} cached masks")
490
  except Exception as e:
491
  logger.warning(f"Could not load mask cache: {e}")
 
492
  # Get chroma settings and override with profile
493
  settings = chroma_settings or CHROMA_PRESETS.get('standard', {})
494
+ tolerance = int(self.q['chroma_tolerance'])
495
+ edge_softness = int(self.q['chroma_softness'])
496
+ spill_suppression = float(self.q['spill_suppression'])
 
497
  # If caller didn't pass key_bgr, try preset or default green
498
  if key_bgr is None:
499
  key_bgr = np.array(settings.get('key_color', [0,255,0]), dtype=np.uint8)
500
+ self._alpha_prev = None # reset temporal smoothing per render
 
 
501
  frame_idx = 0
502
  while True:
503
  if stop_event and stop_event.is_set():
504
  _prog(1.0, "Stage 2: cancelled")
505
  break
 
506
  ok, frame = cap.read()
507
  if not ok:
508
  break
 
509
  # Apply chroma keying with optional mask assistance
510
  if masks and frame_idx < len(masks):
511
  mask = masks[frame_idx]
 
525
  spill_suppression=spill_suppression,
526
  key_bgr=key_bgr
527
  )
 
528
  writer.write(final_frame)
529
  frame_idx += 1
530
  pct = 0.05 + 0.9 * (frame_idx / total) if total else min(0.95, 0.05 + frame_idx * 0.002)
531
  _prog(pct, f"Stage 2: {frame_idx}/{total or '?'}")
 
532
  cap.release()
533
  writer.release()
 
534
  _prog(1.0, "Stage 2: complete")
535
+
536
  # Verify frame counts match
537
  if total > 0 and frame_idx != total:
538
  logger.warning(f"Frame count mismatch: processed {frame_idx}, expected {total}")
539
+
540
  return out_path, f"Final composite created ({frame_idx} frames)"
 
541
  except Exception as e:
542
  logger.error(f"Stage 2 error: {e}\n{traceback.format_exc()}")
543
  return None, f"Stage 2 failed: {e}"
 
544
  # ---------------- chroma + hybrid compositors (polished) ----------------
545
  def _chroma_key_composite(self, frame, bg, *, tolerance=38, edge_softness=2, spill_suppression=0.35, key_bgr: Optional[np.ndarray] = None):
546
  """Apply chroma key compositing with soft color distance + edge refinement."""
547
  # 1) spill first
548
  if spill_suppression > 0:
549
  frame = self._suppress_green_spill(frame, spill_suppression)
 
550
  # 2) build alpha
551
  if key_bgr is not None:
552
  alpha = self._soft_key_mask(frame, key_bgr, tol=int(tolerance))
 
556
  lower_green = np.array([40, 40, 40])
557
  upper_green = np.array([80, 255, 255])
558
  alpha = cv2.bitwise_not(cv2.inRange(hsv, lower_green, upper_green))
 
559
  # 3) soft edges + refinement
560
  if edge_softness > 0:
561
  k = edge_softness * 2 + 1
562
  alpha = cv2.GaussianBlur(alpha, (k, k), 0)
563
  alpha = self._refine_alpha_edges(frame, alpha, radius=max(1, edge_softness), iters=1)
 
564
  # 4) temporal smoothing
565
  if self._alpha_prev is not None and self._alpha_prev.shape == alpha.shape:
566
  alpha = cv2.addWeighted(alpha, 0.75, self._alpha_prev, 0.25, 0)
567
  self._alpha_prev = alpha
 
568
  # 5) composite
569
  mask_3ch = cv2.cvtColor(alpha, cv2.COLOR_GRAY2BGR).astype(np.float32) / 255.0
570
  out = frame.astype(np.float32) * mask_3ch + bg.astype(np.float32) * (1.0 - mask_3ch)
571
  return np.clip(out, 0, 255).astype(np.uint8)
 
572
  def _hybrid_composite(self, frame, bg, mask, *, tolerance=38, edge_softness=2, spill_suppression=0.35, key_bgr: Optional[np.ndarray] = None):
573
  """Apply hybrid compositing using both chroma key and cached mask, with profile controls."""
574
  chroma_result = self._chroma_key_composite(
 
580
  )
581
  if mask is None:
582
  return chroma_result
 
583
  # profile-driven dilate/feather on cached mask to close pinholes + soften edges
584
  m = mask
585
+ d = int(self.q['mask_dilate'])
586
  if d > 0:
587
  k = 2*d + 1
588
  se = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
589
  m = cv2.dilate(m, se, iterations=1)
590
+ b = int(self.q['mask_blur'])
591
  if b > 0:
592
  m = cv2.GaussianBlur(m, (2*b+1, 2*b+1), 0)
 
593
  m3 = cv2.cvtColor(m, cv2.COLOR_GRAY2BGR) if m.ndim == 2 else m
594
  m3f = (m3.astype(np.float32) / 255.0)
 
595
  seg_comp = frame.astype(np.float32) * m3f + bg.astype(np.float32) * (1.0 - m3f)
596
+ mix = float(self.q['hybrid_mix']) # weight towards segmentation on "max"
 
597
  out = chroma_result.astype(np.float32) * (1.0 - mix) + seg_comp * mix
598
  return np.clip(out, 0, 255).astype(np.uint8)
 
599
  # ---------------------------------------------------------------------
600
  # Combined pipeline
601
  # ---------------------------------------------------------------------
 
608
  key_color_mode: str = "auto",
609
  chroma_settings: Optional[Dict[str, Any]] = None,
610
  progress_callback: Optional[Callable[[float, str], None]] = None,
611
+ stop_event: Optional[threading.Event] = None,
612
+ ) -> Tuple[Optional[str], Optional[str], str]:
613
+ """Run both stages in sequence, return final, green, msg."""
 
614
  def _combined_progress(pct, desc):
615
  # Scale progress: Stage 1 is 0-50%, Stage 2 is 50-100%
616
  if "Stage 1" in desc:
617
  actual_pct = pct * 0.5
618
+ else: # Stage 2
619
  actual_pct = 0.5 + pct * 0.5
 
620
  if progress_callback:
621
  try:
622
  progress_callback(actual_pct, desc)
623
  except Exception:
624
  pass
 
625
  try:
626
+ # pick up any new quality selection once per run
627
+ quality_manager.load_profile()
628
+ self.q = quality_manager.get_params()
629
  # Reset per-video state
630
  self._mat_bootstrapped = False
631
  self._alpha_prev = None
632
  self.total_frames_processed = 0
633
  self.frames_refined = 0
634
+
635
  if self.matanyone is not None and hasattr(self.matanyone, "reset"):
636
  try:
637
  self.matanyone.reset()
638
  except Exception:
639
  pass
640
+ # Stage 1 - use persistent path for green screen
641
+ green_path = os.path.splitext(output_path)[0] + '_green.mp4'
 
642
  stage1_result, stage1_msg = self.stage1_extract_to_greenscreen(
643
+ video_path, green_path,
644
  key_color_mode=key_color_mode,
645
  progress_callback=_combined_progress,
646
  stop_event=stop_event
647
  )
648
  if stage1_result is None:
649
+ return None, None, stage1_msg
 
650
  # Stage 2 (pass through chosen key color)
651
  key_bgr = np.array(stage1_result.get("key_bgr", [0,255,0]), dtype=np.uint8)
652
  final_path, stage2_msg = self.stage2_greenscreen_to_final(
653
+ green_path, background, output_path,
654
  chroma_settings=chroma_settings,
655
  progress_callback=_combined_progress,
656
  stop_event=stop_event,
657
  key_bgr=key_bgr,
658
  )
 
 
 
 
 
 
 
659
  # Report quality impact
660
+ logger.info(f"Pipeline complete with quality='{quality_manager.profile_name}': "
661
  f"{self.total_frames_processed} frames, "
662
  f"{self.frames_refined} refined ({100*self.frames_refined/max(1,self.total_frames_processed):.1f}%)")
663
+ return final_path, green_path, stage2_msg
 
 
664
  except Exception as e:
665
  logger.error(f"Full pipeline error: {e}\n{traceback.format_exc()}")
666
+ return None, None, f"Pipeline failed: {e}"