MogensR committed on
Commit
7930b59
·
1 Parent(s): e7b3257

Update processing/two_stage/two_stage_processor.py

Browse files
processing/two_stage/two_stage_processor.py CHANGED
@@ -1,658 +1,346 @@
1
  #!/usr/bin/env python3
2
  """
3
- Two-Stage Green-Screen Processing System ✅ 2025-08-29
4
- Stage 1: Original → keyed background (auto-selected colour)
5
- Stage 2: Keyed video → final composite (hybrid chroma + segmentation rescue)
6
-
7
- Aligned with current project layout:
8
- * uses helpers from utils.cv_processing (segment_person_hq, refine_mask_hq)
9
- * safe local create_video_writer (no core.app dependency)
10
- * cancel support via stop_event
11
- * progress_callback(pct, desc)
12
- * fully self-contained – just drop in and import TwoStageProcessor
13
-
14
- Additional safety:
15
- * Ensures MatAnyone receives a valid first-frame mask (bootstraps the session
16
- with the first SAM2 mask). This prevents "First frame arrived without a mask"
17
- warnings and shape mismatches inside the stateful refiner.
18
-
19
- Quality profiles (set via env BFX_QUALITY = speed | balanced | max):
20
- * refine cadence, spill suppression, edge softness
21
- * hybrid matte mix (segmentation vs chroma), small dilate/blur on mask
22
- * optional tiny background blur to hide seams on very flat backgrounds
23
  """
 
24
  from __future__ import annotations
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- import cv2, numpy as np, os, gc, pickle, logging, tempfile, traceback, threading
27
- from pathlib import Path
28
- from typing import Optional, Dict, Any, Callable, Tuple, List
29
-
30
- from utils.cv_processing import segment_person_hq, refine_mask_hq
31
-
32
- # Project logger if available
33
- try:
34
- from utils.logger import get_logger
35
- logger = get_logger(__name__)
36
- except Exception:
37
- logger = logging.getLogger(__name__)
38
-
39
- # ---------------------------------------------------------------------------
40
- # Local video-writer helper
41
- # ---------------------------------------------------------------------------
42
- def create_video_writer(output_path: str, fps: float, width: int, height: int, prefer_mp4: bool = True):
43
- try:
44
- ext = ".mp4" if prefer_mp4 else ".avi"
45
- if not output_path:
46
- output_path = tempfile.mktemp(suffix=ext)
47
- else:
48
- base, curr_ext = os.path.splitext(output_path)
49
- if curr_ext.lower() not in [".mp4", ".avi", ".mov", ".mkv"]:
50
- output_path = base + ext
51
-
52
- fourcc = cv2.VideoWriter_fourcc(*("mp4v" if prefer_mp4 else "XVID"))
53
- writer = cv2.VideoWriter(output_path, fourcc, float(fps), (int(width), int(height)))
54
- if writer is None or not writer.isOpened():
55
- alt_ext = ".avi" if prefer_mp4 else ".mp4"
56
- alt_fourcc = cv2.VideoWriter_fourcc(*("XVID" if prefer_mp4 else "mp4v"))
57
- alt_path = os.path.splitext(output_path)[0] + alt_ext
58
- writer = cv2.VideoWriter(alt_path, alt_fourcc, float(fps), (int(width), int(height)))
59
- if writer is None or not writer.isOpened():
60
- return None, output_path
61
- return writer, alt_path
62
- return writer, output_path
63
- except Exception as e:
64
- logger.error(f"create_video_writer failed: {e}")
65
- return None, output_path
66
-
67
- # ---------------------------------------------------------------------------
68
- # Key-colour helpers (fast, no external deps)
69
- # ---------------------------------------------------------------------------
70
- def _bgr_to_hsv_hue_deg(bgr: np.ndarray) -> np.ndarray:
71
- hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
72
- # OpenCV H is 0-180; scale to degrees 0-360
73
- return hsv[..., 0].astype(np.float32) * 2.0
74
-
75
-
76
- def _hue_distance(a_deg: float, b_deg: float) -> float:
77
- """Circular distance on the hue wheel (degrees)."""
78
- d = abs(a_deg - b_deg) % 360.0
79
- return min(d, 360.0 - d)
80
-
81
-
82
- def _key_candidates_bgr() -> dict:
83
- return {
84
- "green": {"bgr": np.array([ 0,255, 0], dtype=np.uint8), "hue": 120.0},
85
- "blue": {"bgr": np.array([255, 0, 0], dtype=np.uint8), "hue": 240.0},
86
- "cyan": {"bgr": np.array([255,255, 0], dtype=np.uint8), "hue": 180.0},
87
- "magenta": {"bgr": np.array([255, 0,255], dtype=np.uint8), "hue": 300.0},
88
- }
89
-
90
-
91
- def _choose_best_key_color(frame_bgr: np.ndarray, mask_uint8: np.ndarray) -> dict:
92
- """Pick the candidate colour farthest from the actor's dominant hues."""
93
- try:
94
- fg = frame_bgr[mask_uint8 > 127]
95
- if fg.size < 1_000:
96
- return _key_candidates_bgr()["green"]
97
-
98
- fg_hue = _bgr_to_hsv_hue_deg(fg.reshape(-1, 1, 3)).reshape(-1)
99
- hist, edges = np.histogram(fg_hue, bins=36, range=(0.0, 360.0))
100
- top_idx = np.argsort(hist)[-3:]
101
- top_hues = [(edges[i] + edges[i+1]) * 0.5 for i in top_idx]
102
-
103
- best_name, best_score = None, -1.0
104
- for name, info in _key_candidates_bgr().items():
105
- cand_hue = info["hue"]
106
- score = min(abs((cand_hue - th + 180) % 360 - 180) for th in top_hues)
107
- if score > best_score:
108
- best_name, best_score = name, score
109
- return _key_candidates_bgr().get(best_name, _key_candidates_bgr()["green"])
110
- except Exception:
111
- return _key_candidates_bgr()["green"]
112
-
113
-
114
- # ---------------------------------------------------------------------------
115
- # Chroma presets
116
- # ---------------------------------------------------------------------------
117
- CHROMA_PRESETS: Dict[str, Dict[str, Any]] = {
118
- 'standard': {'key_color': [0,255,0], 'tolerance': 38, 'edge_softness': 2, 'spill_suppression': 0.35},
119
- 'studio': {'key_color': [0,255,0], 'tolerance': 30, 'edge_softness': 1, 'spill_suppression': 0.45},
120
- 'outdoor': {'key_color': [0,255,0], 'tolerance': 50, 'edge_softness': 3, 'spill_suppression': 0.25},
121
  }
122
 
123
- # ---------------------------------------------------------------------------
124
- # Quality profiles (env: BFX_QUALITY = speed | balanced | max)
125
- # ---------------------------------------------------------------------------
126
- QUALITY_PROFILES: Dict[str, Dict[str, Any]] = {
127
- "speed": dict(refine_stride=4, spill=0.30, edge_softness=2, mix=0.60, dilate=0, blur=0, bg_sigma=0.0),
128
- "balanced": dict(refine_stride=2, spill=0.40, edge_softness=2, mix=0.75, dilate=1, blur=1, bg_sigma=0.6),
129
- "max": dict(refine_stride=1, spill=0.45, edge_softness=3, mix=0.85, dilate=2, blur=2, bg_sigma=1.0),
130
  }
131
 
132
- # ---------------------------------------------------------------------------
133
- # Two-Stage Processor
134
- # ---------------------------------------------------------------------------
135
- class TwoStageProcessor:
136
- def __init__(self, sam2_predictor=None, matanyone_model=None):
137
- self.sam2 = self._unwrap_sam2(sam2_predictor)
138
- self.matanyone = matanyone_model
139
- self.mask_cache_dir = Path("/tmp/mask_cache")
140
- self.mask_cache_dir.mkdir(parents=True, exist_ok=True)
141
-
142
- # Internal flags/state
143
- self._mat_bootstrapped = False
144
- self._alpha_prev: Optional[np.ndarray] = None # temporal smoothing
145
-
146
- # Quality selection
147
- qname = os.getenv("BFX_QUALITY", "balanced").strip().lower()
148
- if qname not in QUALITY_PROFILES:
149
- qname = "balanced"
150
- self.quality = qname
151
- self.q = QUALITY_PROFILES[qname]
152
- logger.info(f"TwoStageProcessor quality='{self.quality}' ⇒ {self.q}")
153
-
154
- logger.info(f"TwoStageProcessor init – SAM2: {self.sam2 is not None} | MatAnyOne: {self.matanyone is not None}")
155
-
156
- # --------------------------- internal utils ---------------------------
157
-
158
- def _unwrap_sam2(self, predictor):
159
- """Unwrap the SAM2 predictor if needed."""
160
- if predictor is None:
161
- return None
162
- if hasattr(predictor, 'sam_predictor'):
163
- return predictor.sam_predictor
164
- return predictor
165
-
166
- def _get_mask(self, frame: np.ndarray) -> np.ndarray:
167
- """Get segmentation mask using SAM2 (delegates to project helper)."""
168
- if self.sam2 is None:
169
- # Fallback: simple luminance threshold (kept to avoid breaking callers)
170
- gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
171
- _, mask = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
172
- return mask
173
-
174
- try:
175
- mask = segment_person_hq(frame, self.sam2)
176
- # segment_person_hq returns either uint8(0..255) or float(0..1) in most builds
177
- return mask
178
- except Exception as e:
179
- logger.warning(f"SAM2 segmentation failed: {e}")
180
- gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
181
- _, mask = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
182
- return mask
183
-
184
- @staticmethod
185
- def _to_binary_mask(mask: np.ndarray) -> Optional[np.ndarray]:
186
- """Convert mask to uint8(0..255)."""
187
- if mask is None:
188
- return None
189
- if mask.dtype == bool:
190
- return mask.astype(np.uint8) * 255
191
- if np.issubdtype(mask.dtype, np.floating):
192
- m = np.clip(mask, 0.0, 1.0)
193
- return (m * 255.0 + 0.5).astype(np.uint8)
194
- return mask
195
-
196
- @staticmethod
197
- def _to_float01(mask: np.ndarray, h: int = None, w: int = None) -> Optional[np.ndarray]:
198
- """Float [0,1] mask, optionally resized to (h,w)."""
199
- if mask is None:
200
- return None
201
- m = mask.astype(np.float32)
202
- if m.max() > 1.0:
203
- m = m / 255.0
204
- if h is not None and w is not None and (m.shape[0] != h or m.shape[1] != w):
205
- m = cv2.resize(m, (w, h), interpolation=cv2.INTER_LINEAR)
206
- return np.clip(m, 0.0, 1.0)
207
-
208
- def _apply_greenscreen_hard(self, frame: np.ndarray, mask: np.ndarray, bg: np.ndarray) -> np.ndarray:
209
- """Apply hard greenscreen compositing."""
210
- mask_3ch = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) if mask.ndim == 2 else mask
211
- mask_norm = mask_3ch.astype(np.float32) / 255.0
212
- result = frame * mask_norm + bg * (1 - mask_norm)
213
- return result.astype(np.uint8)
214
-
215
- # -------- improved spill suppression (preserves luminance & skin) --------
216
- def _suppress_green_spill(self, frame: np.ndarray, amount: float = 0.35) -> np.ndarray:
217
- """
218
- Desaturate green dominance while preserving luminance and red skin hues.
219
- amount: 0..1
220
- """
221
- b, g, r = cv2.split(frame.astype(np.float32))
222
- green_dom = (g > r) & (g > b)
223
- avg_rb = (r + b) * 0.5
224
- g2 = np.where(green_dom, g*(1.0-amount) + avg_rb*amount, g)
225
- skin = (r > g + 12) # protect skin tones
226
- g2 = np.where(skin, g, g2)
227
- out = cv2.merge([np.clip(b,0,255), np.clip(g2,0,255), np.clip(r,0,255)]).astype(np.uint8)
228
- return out
229
-
230
- # -------- edge-aware alpha refinement (guided-like) --------
231
- def _refine_alpha_edges(self, frame_bgr: np.ndarray, alpha_u8: np.ndarray, radius: int = 3, iters: int = 1) -> np.ndarray:
232
- """
233
- Fast, dependency-free, guided-like refinement on the alpha border.
234
- Returns: uint8 alpha
235
- """
236
- a = alpha_u8.astype(np.uint8)
237
- if radius <= 0:
238
- return a
239
-
240
- band = cv2.Canny(a, 32, 64)
241
- if band.max() == 0:
242
- return a
243
-
244
- for _ in range(max(1, iters)):
245
- a_blur = cv2.GaussianBlur(a, (radius*2+1, radius*2+1), 0)
246
- b,g,r = cv2.split(frame_bgr.astype(np.float32))
247
- green_dom = (g > r) & (g > b)
248
- spill_mask = (green_dom & (a > 96) & (a < 224)).astype(np.uint8)*255
249
- u = cv2.bitwise_or(band, spill_mask)
250
- a = np.where(u>0, a_blur, a).astype(np.uint8)
251
-
252
- return a
253
-
254
- # -------- soft key based on chosen color (robust to blue/cyan/magenta) --------
255
- def _soft_key_mask(self, frame_bgr: np.ndarray, key_bgr: np.ndarray, tol: int = 40) -> np.ndarray:
256
- """
257
- Soft chroma mask (uint8 0..255, 255=keep subject) using CbCr distance.
258
- """
259
- if key_bgr is None:
260
- return np.full(frame_bgr.shape[:2], 255, np.uint8)
261
-
262
- ycbcr = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2YCrCb).astype(np.float32)
263
- kycbcr = cv2.cvtColor(key_bgr.reshape(1,1,3).astype(np.uint8), cv2.COLOR_BGR2YCrCb).astype(np.float32)[0,0]
264
- d = np.linalg.norm((ycbcr[...,1:] - kycbcr[1:]), axis=-1)
265
- d = cv2.GaussianBlur(d, (5,5), 0)
266
- alpha = 255.0 * np.clip((d - tol) / (tol*1.7), 0.0, 1.0) # far from key = keep (255)
267
- return alpha.astype(np.uint8)
268
-
269
- # --------------------- NEW: MatAnyone bootstrap ----------------------
270
- def _bootstrap_matanyone_if_needed(self, frame_bgr: np.ndarray, coarse_mask: np.ndarray):
271
- """
272
- Call the MatAnyone session ONCE with the first coarse mask to initialize
273
- its memory. This guarantees downstream calls never hit "first frame without a mask".
274
- """
275
- if self.matanyone is None or self._mat_bootstrapped:
276
- return
277
- try:
278
- h, w = frame_bgr.shape[:2]
279
- mask_f = self._to_float01(coarse_mask, h, w)
280
- rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
281
- _ = self.matanyone(rgb, mask_f) # boot only; ignore returned alpha
282
- self._mat_bootstrapped = True
283
- logger.info("MatAnyone session bootstrapped with first-frame mask.")
284
- except Exception as e:
285
- logger.warning(f"MatAnyone bootstrap failed (continuing without): {e}")
286
-
287
- # ---------------------------------------------------------------------
288
- # Stage 1 – Original → keyed (green/blue/…) -- chooses colour on 1st frame
289
- # ---------------------------------------------------------------------
290
- def stage1_extract_to_greenscreen(
291
- self,
292
- video_path: str,
293
- output_path: str,
294
- *,
295
- key_color_mode: str = "auto", # "auto" | "green" | "blue" | "cyan" | "magenta"
296
- progress_callback: Optional[Callable[[float, str], None]] = None,
297
- stop_event: Optional["threading.Event"] = None,
298
- ) -> Tuple[Optional[dict], str]:
299
-
300
- def _prog(p, d):
301
- if progress_callback:
302
- try:
303
- progress_callback(float(p), str(d))
304
- except Exception:
305
- pass
306
-
307
- try:
308
- _prog(0.0, "Stage 1: opening video…")
309
- cap = cv2.VideoCapture(video_path)
310
- if not cap.isOpened():
311
- return None, "Could not open input video"
312
-
313
- fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
314
- total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 0
315
- w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
316
- h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
317
-
318
- writer, out_path = create_video_writer(output_path, fps, w, h)
319
- if writer is None:
320
- cap.release()
321
- return None, "Could not create output writer"
322
-
323
- key_info: dict | None = None
324
- chosen_bgr = np.array([0, 255, 0], np.uint8) # default
325
- probe_done = False
326
- masks: List[np.ndarray] = []
327
- frame_idx = 0
328
-
329
- solid_bg = np.zeros((h, w, 3), np.uint8) # overwritten per-frame
330
-
331
- while True:
332
- if stop_event and stop_event.is_set():
333
- _prog(1.0, "Stage 1: cancelled")
334
- break
335
-
336
- ok, frame = cap.read()
337
- if not ok:
338
- break
339
-
340
- # --- SAM2 segmentation ---
341
- mask = self._get_mask(frame)
342
-
343
- # --- MatAnyone bootstrap exactly once (first frame) ---
344
- if frame_idx == 0 and self.matanyone is not None:
345
- try:
346
- self._bootstrap_matanyone_if_needed(frame, mask)
347
- except Exception as e:
348
- logger.warning(f"Bootstrap error (non-fatal): {e}")
349
-
350
- # --- Decide key colour once ---
351
- if not probe_done:
352
- if key_color_mode.lower() == "auto":
353
- key_info = _choose_best_key_color(frame, self._to_binary_mask(mask))
354
- chosen_bgr = key_info["bgr"]
355
- else:
356
- cand = _key_candidates_bgr().get(key_color_mode.lower())
357
- if cand is not None:
358
- chosen_bgr = cand["bgr"]
359
- probe_done = True
360
- logger.info(f"[TwoStage] Using key colour: {key_color_mode} → {chosen_bgr.tolist()}")
361
-
362
- # --- Optional refinement via MatAnyone (profile cadence) ---
363
- stride = int(self.q.get("refine_stride", 3))
364
- if self.matanyone and (frame_idx % max(1, stride) == 0):
365
- try:
366
- mask = refine_mask_hq(frame, mask, self.matanyone, fallback_enabled=True)
367
- except Exception as e:
368
- logger.warning(f"MatAnyOne refine fail f={frame_idx}: {e}")
369
-
370
- # --- Composite onto solid key colour ---
371
- solid_bg[:] = chosen_bgr
372
- mask_u8 = self._to_binary_mask(mask)
373
- gs = self._apply_greenscreen_hard(frame, mask_u8, solid_bg)
374
- writer.write(gs)
375
- masks.append(mask_u8)
376
-
377
- frame_idx += 1
378
- pct = 0.05 + 0.9 * (frame_idx / total) if total else min(0.95, 0.05 + frame_idx * 0.002)
379
- _prog(pct, f"Stage 1: {frame_idx}/{total or '?'}")
380
-
381
- cap.release()
382
- writer.release()
383
-
384
- # save mask cache
385
- try:
386
- cache_file = self.mask_cache_dir / (Path(out_path).stem + "_masks.pkl")
387
- with open(cache_file, "wb") as f:
388
- pickle.dump(masks, f)
389
- except Exception as e:
390
- logger.warning(f"mask cache save fail: {e}")
391
-
392
- _prog(1.0, "Stage 1: complete")
393
- return (
394
- {"path": out_path, "frames": frame_idx, "key_bgr": chosen_bgr.tolist()},
395
- f"Green-screen video created ({frame_idx} frames)"
396
  )
397
 
398
- except Exception as e:
399
- logger.error(f"Stage 1 error: {e}\n{traceback.format_exc()}")
400
- return None, f"Stage 1 failed: {e}"
401
-
402
- # ---------------------------------------------------------------------
403
- # Stage 2 – keyed video → final composite (hybrid matte)
404
- # ---------------------------------------------------------------------
405
- def stage2_greenscreen_to_final(
406
- self,
407
- gs_path: str,
408
- background: np.ndarray | str,
409
- output_path: str,
410
- *,
411
- chroma_settings: Optional[Dict[str, Any]] = None,
412
- progress_callback: Optional[Callable[[float, str], None]] = None,
413
- stop_event: Optional["threading.Event"] = None,
414
- key_bgr: Optional[np.ndarray] = None, # pass chosen key color
415
- ) -> Tuple[Optional[str], str]:
416
-
417
- def _prog(p, d):
418
- if progress_callback:
419
- try:
420
- progress_callback(float(p), str(d))
421
- except Exception:
422
- pass
423
-
424
- try:
425
- _prog(0.0, "Stage 2: opening keyed video…")
426
- cap = cv2.VideoCapture(gs_path)
427
- if not cap.isOpened():
428
- return None, "Could not open keyed video"
429
-
430
- fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
431
- total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 0
432
- w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
433
- h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
434
-
435
- # Load or prepare background
436
- if isinstance(background, str):
437
- bg = cv2.imread(background)
438
- if bg is None:
439
- cap.release()
440
- return None, "Could not load background image"
441
- bg = cv2.resize(bg, (w, h))
442
- else:
443
- bg = cv2.resize(background, (w, h))
444
-
445
- # Optional tiny BG blur per profile to hide seams on flat BGs
446
- sigma = float(self.q.get("bg_sigma", 0.0))
447
- if sigma > 0:
448
- bg = cv2.GaussianBlur(bg, (0, 0), sigmaX=sigma, sigmaY=sigma)
449
-
450
- writer, out_path = create_video_writer(output_path, fps, w, h)
451
- if writer is None:
452
- cap.release()
453
- return None, "Could not create output writer"
454
-
455
- # Load cached masks if available
456
- masks = None
457
- try:
458
- cache_file = self.mask_cache_dir / (Path(gs_path).stem + "_masks.pkl")
459
- if cache_file.exists():
460
- with open(cache_file, "rb") as f:
461
- masks = pickle.load(f)
462
- logger.info(f"Loaded {len(masks)} cached masks")
463
- except Exception as e:
464
- logger.warning(f"Could not load mask cache: {e}")
465
-
466
- # Get chroma settings and override with profile
467
- settings = chroma_settings or CHROMA_PRESETS.get('standard', {})
468
- tolerance = int(settings.get('tolerance', 38)) # keep user tolerance
469
- edge_softness = int(self.q.get('edge_softness', settings.get('edge_softness', 2)))
470
- spill_suppression = float(self.q.get('spill', settings.get('spill_suppression', 0.35)))
471
-
472
- # If caller didn't pass key_bgr, try preset or default green
473
- if key_bgr is None:
474
- key_bgr = np.array(settings.get('key_color', [0,255,0]), dtype=np.uint8)
475
-
476
- self._alpha_prev = None # reset temporal smoothing per render
477
-
478
- frame_idx = 0
479
- while True:
480
- if stop_event and stop_event.is_set():
481
- _prog(1.0, "Stage 2: cancelled")
482
- break
483
-
484
- ok, frame = cap.read()
485
- if not ok:
486
- break
487
-
488
- # Apply chroma keying with optional mask assistance
489
- if masks and frame_idx < len(masks):
490
- mask = masks[frame_idx]
491
- final_frame = self._hybrid_composite(
492
- frame, bg, mask,
493
- tolerance=tolerance,
494
- edge_softness=edge_softness,
495
- spill_suppression=spill_suppression,
496
- key_bgr=key_bgr
497
  )
498
- else:
499
- # Pure chroma key
500
- final_frame = self._chroma_key_composite(
501
- frame, bg,
502
- tolerance=tolerance,
503
- edge_softness=edge_softness,
504
- spill_suppression=spill_suppression,
505
- key_bgr=key_bgr
 
 
 
 
 
 
 
506
  )
507
 
508
- writer.write(final_frame)
509
- frame_idx += 1
510
- pct = 0.05 + 0.9 * (frame_idx / total) if total else min(0.95, 0.05 + frame_idx * 0.002)
511
- _prog(pct, f"Stage 2: {frame_idx}/{total or '?'}")
512
-
513
- cap.release()
514
- writer.release()
515
-
516
- _prog(1.0, "Stage 2: complete")
517
- return out_path, f"Final composite created ({frame_idx} frames)"
518
-
519
- except Exception as e:
520
- logger.error(f"Stage 2 error: {e}\n{traceback.format_exc()}")
521
- return None, f"Stage 2 failed: {e}"
522
-
523
- # ---------------- chroma + hybrid compositors (polished) ----------------
524
- def _chroma_key_composite(self, frame, bg, *, tolerance=38, edge_softness=2, spill_suppression=0.35, key_bgr: Optional[np.ndarray] = None):
525
- """Apply chroma key compositing with soft color distance + edge refinement."""
526
- # 1) spill first
527
- if spill_suppression > 0:
528
- frame = self._suppress_green_spill(frame, spill_suppression)
529
-
530
- # 2) build alpha
531
- if key_bgr is not None:
532
- alpha = self._soft_key_mask(frame, key_bgr, tol=int(tolerance))
533
- else:
534
- # Fallback: HSV green range
535
- hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
536
- lower_green = np.array([40, 40, 40])
537
- upper_green = np.array([80, 255, 255])
538
- alpha = cv2.bitwise_not(cv2.inRange(hsv, lower_green, upper_green))
539
-
540
- # 3) soft edges + refinement
541
- if edge_softness > 0:
542
- k = edge_softness * 2 + 1
543
- alpha = cv2.GaussianBlur(alpha, (k, k), 0)
544
- alpha = self._refine_alpha_edges(frame, alpha, radius=max(1, edge_softness), iters=1)
545
-
546
- # 4) temporal smoothing
547
- if self._alpha_prev is not None and self._alpha_prev.shape == alpha.shape:
548
- alpha = cv2.addWeighted(alpha, 0.75, self._alpha_prev, 0.25, 0)
549
- self._alpha_prev = alpha
550
-
551
- # 5) composite
552
- mask_3ch = cv2.cvtColor(alpha, cv2.COLOR_GRAY2BGR).astype(np.float32) / 255.0
553
- out = frame.astype(np.float32) * mask_3ch + bg.astype(np.float32) * (1.0 - mask_3ch)
554
- return np.clip(out, 0, 255).astype(np.uint8)
555
-
556
- def _hybrid_composite(self, frame, bg, mask, *, tolerance=38, edge_softness=2, spill_suppression=0.35, key_bgr: Optional[np.ndarray] = None):
557
- """Apply hybrid compositing using both chroma key and cached mask, with profile controls."""
558
- chroma_result = self._chroma_key_composite(
559
- frame, bg,
560
- tolerance=tolerance,
561
- edge_softness=edge_softness,
562
- spill_suppression=spill_suppression,
563
- key_bgr=key_bgr
564
- )
565
- if mask is None:
566
- return chroma_result
567
-
568
- # profile-driven dilate/feather on cached mask to close pinholes + soften edges
569
- m = mask
570
- d = int(self.q.get("dilate", 0))
571
- if d > 0:
572
- k = 2*d + 1
573
- se = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
574
- m = cv2.dilate(m, se, iterations=1)
575
- b = int(self.q.get("blur", 0))
576
- if b > 0:
577
- m = cv2.GaussianBlur(m, (2*b+1, 2*b+1), 0)
578
-
579
- m3 = cv2.cvtColor(m, cv2.COLOR_GRAY2BGR) if m.ndim == 2 else m
580
- m3f = (m3.astype(np.float32) / 255.0)
581
-
582
- seg_comp = frame.astype(np.float32) * m3f + bg.astype(np.float32) * (1.0 - m3f)
583
-
584
- mix = float(self.q.get("mix", 0.7)) # weight towards segmentation on "max"
585
- out = chroma_result.astype(np.float32) * (1.0 - mix) + seg_comp * mix
586
- return np.clip(out, 0, 255).astype(np.uint8)
587
-
588
- # ---------------------------------------------------------------------
589
- # Combined pipeline
590
- # ---------------------------------------------------------------------
591
- def process_full_pipeline(
592
- self,
593
- video_path: str,
594
- background: np.ndarray | str,
595
- output_path: str,
596
- *,
597
- key_color_mode: str = "auto",
598
- chroma_settings: Optional[Dict[str, Any]] = None,
599
- progress_callback: Optional[Callable[[float, str], None]] = None,
600
- stop_event: Optional["threading.Event"] = None,
601
- ) -> Tuple[Optional[str], str]:
602
- """Run both stages in sequence."""
603
-
604
- def _combined_progress(pct, desc):
605
- # Scale progress: Stage 1 is 0-50%, Stage 2 is 50-100%
606
- if "Stage 1" in desc:
607
- actual_pct = pct * 0.5
608
- else: # Stage 2
609
- actual_pct = 0.5 + pct * 0.5
610
-
611
- if progress_callback:
612
- try:
613
- progress_callback(actual_pct, desc)
614
- except Exception:
615
- pass
616
-
617
- try:
618
- # Reset per-video state
619
- self._mat_bootstrapped = False
620
- self._alpha_prev = None
621
- if self.matanyone is not None and hasattr(self.matanyone, "reset"):
622
- try:
623
- self.matanyone.reset()
624
- except Exception:
625
- pass
626
-
627
- # Stage 1
628
- temp_gs_path = tempfile.mktemp(suffix="_greenscreen.mp4")
629
- stage1_result, stage1_msg = self.stage1_extract_to_greenscreen(
630
- video_path, temp_gs_path,
631
- key_color_mode=key_color_mode,
632
- progress_callback=_combined_progress,
633
- stop_event=stop_event
634
  )
635
- if stage1_result is None:
636
- return None, stage1_msg
637
-
638
- # Stage 2 (pass through chosen key color)
639
- key_bgr = np.array(stage1_result.get("key_bgr", [0,255,0]), dtype=np.uint8)
640
- final_path, stage2_msg = self.stage2_greenscreen_to_final(
641
- stage1_result["path"], background, output_path,
642
- chroma_settings=chroma_settings,
643
- progress_callback=_combined_progress,
644
- stop_event=stop_event,
645
- key_bgr=key_bgr,
 
646
  )
647
 
648
- # Clean up temp file
649
- try:
650
- os.remove(temp_gs_path)
651
- except Exception:
652
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
653
 
654
- return final_path, stage2_msg
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
655
 
656
- except Exception as e:
657
- logger.error(f"Full pipeline error: {e}\n{traceback.format_exc()}")
658
- return None, f"Pipeline failed: {e}"
 
1
  #!/usr/bin/env python3
2
  """
3
+ UI Components for BackgroundFX Pro (forced Two-Stage)
4
+ -----------------------------------------------------
5
+ * Pure layout (tiny wrapper to set env for quality)
6
+ * All heavy logic stays in ui/callbacks.py
7
+ * Two-stage mode is always active (checkbox removed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  """
9
+
10
  from __future__ import annotations
11
+ import os
12
+ import gradio as gr
13
+
14
+ from ui.callbacks import (
15
+ cb_load_models,
16
+ cb_process_video,
17
+ cb_cancel,
18
+ cb_status,
19
+ cb_clear,
20
+ cb_generate_bg,
21
+ cb_use_gen_bg,
22
+ cb_preset_bg_preview,
23
+ )
24
+
25
+ # Typography & UI polish: sharper text + cleaner cards
26
+ CSS = """
27
+ :root {
28
+ --radius: 16px;
29
+ --font-sans: 'Inter', system-ui, -apple-system, 'Segoe UI', Roboto,
30
+ 'Helvetica Neue', Arial, sans-serif;
31
+ }
32
 
33
+ /* Global crisp text */
34
+ html, body, .gradio-container, .gradio-container * {
35
+ font-family: var(--font-sans) !important;
36
+ -webkit-font-smoothing: antialiased !important;
37
+ -moz-osx-font-smoothing: grayscale !important;
38
+ text-rendering: optimizeLegibility !important;
39
+ font-synthesis-weight: none;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  }
41
 
42
+ /* Headings tighter & bolder */
43
+ .gradio-container h1, .gradio-container h2, .gradio-container h3 {
44
+ letter-spacing: -0.01em;
45
+ font-weight: 700;
 
 
 
46
  }
47
 
48
+ /* Body copy slightly tighter */
49
+ #hero .prose, .gr-markdown, .gr-text {
50
+ letter-spacing: -0.003em;
51
+ }
52
+
53
+ /* Card look */
54
+ .card {
55
+ border-radius: var(--radius);
56
+ border: 1px solid rgba(0,0,0,.08);
57
+ padding: 16px;
58
+ background: linear-gradient(180deg, rgba(255,255,255,.94), rgba(248,250,252,.94));
59
+ box-shadow: 0 10px 30px rgba(0,0,0,.06);
60
+ }
61
+
62
+ .footer-note { opacity: 0.7; font-size: 12px; }
63
+ .sm { font-size: 13px; opacity: 0.85; }
64
+ #statusbox { min-height: 120px; }
65
+ .preview-img { border-radius: var(--radius); border: 1px solid rgba(0,0,0,.08); }
66
+
67
+ /* Buttons get a tiny weight bump for clarity */
68
+ button, .gr-button { font-weight: 600; }
69
+
70
+ /* Inline Quality select between buttons */
71
+ .inline-quality .wrap-inner { min-width: 170px; }
72
+ """
73
+
74
+ # Keep in sync with utils/cv_processing.PROFESSIONAL_BACKGROUNDS
75
+ _BG_CHOICES = [
76
+ "minimalist",
77
+ "office_modern",
78
+ "studio_blue",
79
+ "studio_green",
80
+ "warm_gradient",
81
+ "tech_dark",
82
+ ]
83
+ PRO_IMAGE_CHOICES = ["minimalist", "office_modern", "studio_blue", "studio_green"]
84
+ GRADIENT_COLOR_CHOICES = ["warm_gradient", "tech_dark"]
85
+
86
+
87
+ def create_interface() -> gr.Blocks:
88
+ with gr.Blocks(
89
+ title="🎬 BackgroundFX Pro",
90
+ css=CSS,
91
+ analytics_enabled=False,
92
+ theme=gr.themes.Soft()
93
+ ) as demo:
94
+
95
+ # ------------------------------------------------------------------
96
+ # HERO
97
+ # ------------------------------------------------------------------
98
+ with gr.Row(elem_id="hero"):
99
+ gr.Markdown(
100
+ "## 🎬 BackgroundFX Pro (CSP-Safe)\n"
101
+ "Replace your video background with cinema-quality AI matting. "
102
+ "Built for Hugging Face Spaces CSP.\n\n"
103
+ "_Tip: press **Load Models** once after the Space spins up._"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  )
105
 
106
+ # ------------------------------------------------------------------
107
+ # TAB – Quick Start
108
+ # ------------------------------------------------------------------
109
+ with gr.Tab("🏁 Quick Start"):
110
+
111
+ with gr.Row():
112
+ # ── Left column ────────────────────────────────────────────
113
+ with gr.Column(scale=1):
114
+ video = gr.Video(label="Upload Video", interactive=True)
115
+
116
+ # Hidden: effective preset key (still used by callbacks / defaults)
117
+ bg_style = gr.Dropdown(
118
+ label="Background Style (hidden)",
119
+ choices=_BG_CHOICES,
120
+ value="minimalist",
121
+ visible=False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  )
123
+
124
+ # =======================
125
+ # Background Source block
126
+ # =======================
127
+ gr.Markdown("### 🖼️ Background Source")
128
+
129
+ bg_method = gr.Radio(
130
+ label="Choose method",
131
+ choices=[
132
+ "Upload file",
133
+ "Pre-loaded professional images",
134
+ "Pre-loaded Gradients / Colors",
135
+ "AI generated background",
136
+ ],
137
+ value="Upload file",
138
  )
139
 
140
+ # a) Upload file option
141
+ with gr.Group(visible=True) as grp_upload:
142
+ custom_bg = gr.Image(
143
+ label="Upload Background Image",
144
+ interactive=True,
145
+ type="filepath", # returns file path
146
+ elem_classes=["preview-img"]
147
+ )
148
+
149
+ # b) Pre-loaded professional images
150
+ with gr.Group(visible=False) as grp_pro_images:
151
+ pro_image_dd = gr.Dropdown(
152
+ label="Professional Images",
153
+ choices=PRO_IMAGE_CHOICES,
154
+ value=PRO_IMAGE_CHOICES[0],
155
+ info="Pre-defined photo-like backgrounds",
156
+ )
157
+ gr.Markdown(
158
+ "<span class='sm'>Selecting a preset updates the preview below.</span>"
159
+ )
160
+
161
+ # c) Pre-loaded gradients & full colors
162
+ with gr.Group(visible=False) as grp_gradients:
163
+ gradient_dd = gr.Dropdown(
164
+ label="Gradients & Full Colors",
165
+ choices=GRADIENT_COLOR_CHOICES,
166
+ value=GRADIENT_COLOR_CHOICES[0],
167
+ info="Clean gradients or solid color styles",
168
+ )
169
+ gr.Markdown(
170
+ "<span class='sm'>Selecting a preset updates the preview below.</span>"
171
+ )
172
+
173
+ # d) AI-generated background (inline, lightweight)
174
+ with gr.Group(visible=False) as grp_ai:
175
+ prompt = gr.Textbox(
176
+ label="Describe vibe",
177
+ value="modern office",
178
+ info="e.g. 'soft sunset studio', 'cool tech dark', 'forest ambience'"
179
+ )
180
+ with gr.Row():
181
+ gen_width = gr.Slider(640, 1920, 1280, step=10, label="Width")
182
+ gen_height = gr.Slider(360, 1080, 720, step=10, label="Height")
183
+ with gr.Row():
184
+ bokeh = gr.Slider(0, 30, 8, step=1, label="Bokeh Blur")
185
+ vignette = gr.Slider(0, 0.6, 0.15, step=0.01, label="Vignette")
186
+ contrast = gr.Slider(0.8, 1.4, 1.05, step=0.01, label="Contrast")
187
+ with gr.Row():
188
+ btn_gen_bg_inline = gr.Button("✨ Generate Background", variant="primary")
189
+ use_gen_as_custom_inline = gr.Button("πŸ“Œ Use as Custom Background", variant="secondary")
190
+ gen_preview = gr.Image(
191
+ label="Generated Background",
192
+ interactive=False,
193
+ elem_classes=["preview-img"]
194
+ )
195
+ gen_path = gr.Textbox(label="Saved Path", interactive=False)
196
+
197
+ # ── Advanced options accordion ───────────────────────
198
+ with gr.Accordion("Advanced", open=False):
199
+ chroma_preset = gr.Dropdown(
200
+ label="Chroma Preset",
201
+ choices=["standard"], # can add 'studio', 'outdoor' later
202
+ value="standard"
203
+ )
204
+ key_color_mode = gr.Dropdown(
205
+ label="Key-Colour Mode",
206
+ choices=["auto", "green", "blue", "cyan", "magenta"],
207
+ value="auto",
208
+ info="Auto picks a colour far from your clothes; override if needed."
209
+ )
210
+ preview_mask = gr.Checkbox(
211
+ label="Preview Mask only (mute audio)",
212
+ value=False
213
+ )
214
+ preview_greenscreen = gr.Checkbox(
215
+ label="Preview Green-screen only (mute audio)",
216
+ value=False
217
+ )
218
+
219
+ # ── Controls row: Load β†’ Quality β†’ Process / Cancel ──
220
+ with gr.Row():
221
+ btn_load = gr.Button("πŸ”„ Load Models", variant="secondary")
222
+
223
+ quality = gr.Dropdown(
224
+ label="Quality",
225
+ choices=["speed", "balanced", "max"],
226
+ value=os.getenv("BFX_QUALITY", "balanced"),
227
+ info="Speed = fastest; Max = best edges & spill control.",
228
+ elem_classes=["inline-quality"],
229
+ )
230
+
231
+ btn_run = gr.Button("🎬 Process Video", variant="primary")
232
+ btn_cancel = gr.Button("⏹️ Cancel", variant="secondary")
233
+
234
+ # ── Right column ──────────────────────────────────────────
235
+ with gr.Column(scale=1):
236
+ out_video = gr.Video(label="Processed Output", interactive=False)
237
+ statusbox = gr.Textbox(label="Status", lines=8, elem_id="statusbox")
238
+ with gr.Row():
239
+ btn_refresh = gr.Button("πŸ” Refresh Status", variant="secondary")
240
+ btn_clear = gr.Button("🧹 Clear", variant="secondary")
241
+
242
+ # ------------------------------------------------------------------
243
+ # TAB – Status & settings
244
+ # ------------------------------------------------------------------
245
+ with gr.Tab("πŸ“ˆ Status & Settings"):
246
+ with gr.Row():
247
+ with gr.Column(scale=1, elem_classes=["card"]):
248
+ model_status = gr.JSON(label="Model Status")
249
+ with gr.Column(scale=1, elem_classes=["card"]):
250
+ cache_status = gr.JSON(label="Cache / System Status")
251
+
252
+ gr.Markdown(
253
+ "<div class='footer-note'>If models fail to load, fallbacks keep the UI responsive. "
254
+ "Check the runtime log for details.</div>"
 
 
 
 
 
 
 
 
 
 
 
255
  )
256
+
257
+ # ------------------------------------------------------------------
258
+ # Callback wiring
259
+ # ------------------------------------------------------------------
260
+
261
+ # Toggle which background sub-section is visible
262
def _toggle_bg_sections(choice: str):
    """Build visibility updates for the four background-source sections.

    Args:
        choice: The label currently selected in the background-method radio.

    Returns:
        A 4-tuple of ``gr.update`` objects, one per section label in the
        order listed below. Each is visible only when ``choice`` equals
        that label, so at most one section is shown.
    """
    # Order matters: the caller maps these positionally onto its outputs.
    section_labels = (
        "Upload file",
        "Pre-loaded professional images",
        "Pre-loaded Gradients / Colors",
        "AI generated background",
    )
    return tuple(
        gr.update(visible=(choice == label)) for label in section_labels
    )
269
 
270
+ bg_method.change(
271
+ _toggle_bg_sections,
272
+ inputs=[bg_method],
273
+ outputs=[grp_upload, grp_pro_images, grp_gradients, grp_ai],
274
+ )
275
+
276
+ # Load models
277
+ btn_load.click(cb_load_models, outputs=statusbox)
278
+
279
+ # Tiny wrapper to set env for quality before calling the existing callback
280
def _run_with_quality(video_pth, bg_style_val, custom_bg_pth,
                      use_two_stage_state, chroma_p, key_mode,
                      prev_mask, prev_gs, quality_val):
    """Publish the selected quality, then delegate to ``cb_process_video``.

    The quality is written to the ``BFX_QUALITY`` environment variable
    before the processing callback runs; a falsy ``quality_val`` (``None``
    or an empty string) falls back to ``"balanced"``. The remaining eight
    arguments are forwarded to ``cb_process_video`` unchanged and in the
    same order.

    Returns:
        Whatever ``cb_process_video`` returns.
    """
    selected_quality = quality_val if quality_val else "balanced"
    # NOTE: os.environ is process-wide, so this affects the whole process,
    # not just this call.
    os.environ["BFX_QUALITY"] = selected_quality

    forwarded_args = (
        video_pth,
        bg_style_val,
        custom_bg_pth,
        use_two_stage_state,
        chroma_p,
        key_mode,
        prev_mask,
        prev_gs,
    )
    return cb_process_video(*forwarded_args)
288
+
289
+ # Always two-stage: pass use_two_stage=True to callback via State
290
+ btn_run.click(
291
+ _run_with_quality,
292
+ inputs=[
293
+ video,
294
+ bg_style,
295
+ custom_bg,
296
+ gr.State(value=True), # Always two-stage
297
+ chroma_preset,
298
+ key_color_mode,
299
+ preview_mask,
300
+ preview_greenscreen,
301
+ quality, # <-- the new control
302
+ ],
303
+ outputs=[out_video, statusbox],
304
+ )
305
+
306
+ # Cancel / Status / Clear
307
+ btn_cancel.click(cb_cancel, outputs=statusbox)
308
+ btn_refresh.click(cb_status, outputs=[model_status, cache_status])
309
+
310
+ btn_clear.click(
311
+ cb_clear,
312
+ outputs=[out_video, statusbox, gen_preview, gen_path, custom_bg]
313
+ )
314
 
315
+ # Preloaded presets → update preview (write into custom_bg)
316
+ pro_image_dd.change(
317
+ cb_preset_bg_preview,
318
+ inputs=[pro_image_dd],
319
+ outputs=[custom_bg],
320
+ )
321
+ gradient_dd.change(
322
+ cb_preset_bg_preview,
323
+ inputs=[gradient_dd],
324
+ outputs=[custom_bg],
325
+ )
326
+
327
+ # AI background generation (inline)
328
+ btn_gen_bg_inline.click(
329
+ cb_generate_bg,
330
+ inputs=[prompt, gen_width, gen_height, bokeh, vignette, contrast],
331
+ outputs=[gen_preview, gen_path],
332
+ )
333
+ use_gen_as_custom_inline.click(
334
+ cb_use_gen_bg,
335
+ inputs=[gen_path],
336
+ outputs=[custom_bg],
337
+ )
338
+
339
+ # Initialize with a default preset preview on load
340
+ demo.load(
341
+ cb_preset_bg_preview,
342
+ inputs=[bg_style],
343
+ outputs=[custom_bg]
344
+ )
345
 
346
+ return demo