MogensR committed on
Commit
235ab01
·
1 Parent(s): d1fd07a

Update core/app.py

Browse files
Files changed (1) hide show
  1. core/app.py +416 -546
core/app.py CHANGED
@@ -1,587 +1,457 @@
1
  #!/usr/bin/env python3
2
  """
3
- BackgroundFX Pro – Main Application Entry Point
4
- Refactored modular architecture – orchestrates specialised components
5
-
6
- 2025-08-27 update:
7
- - Robust Two-Stage importer: tries package import, then direct file import (bypasses
8
- any side-effect errors in processing/__init__.py). Clear logging of where it loaded.
9
- - Config hardening: adds safe defaults for fields like max_model_size/use_nvenc/etc.
10
- - Defensive error logs with tracebacks for easier diagnosis.
 
 
11
  """
12
 
13
  from __future__ import annotations
14
 
15
- # ─────────────────────────────────────────────────────────────────────────────
16
- # 0) Early env/threading hygiene (must run first)
17
- # ─────────────────────────────────────────────────────────────────────────────
18
- import early_env # sets OMP/MKL/OPENBLAS + torch threads safely
19
-
20
- import logging
21
- import threading
22
- import traceback
23
- import sys
24
- import os
25
- import importlib
26
- import importlib.util
27
  from pathlib import Path
28
- from typing import Optional, Tuple, Dict, Any, Callable
29
-
30
- # ─────────────────────────────────────────────────────────────────────────────
31
- # 1) Logging
32
- # ─────────────────────────────────────────────────────────────────────────────
33
- logging.basicConfig(
34
- level=logging.INFO,
35
- format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
36
- )
37
- logger = logging.getLogger("core.app")
38
-
39
- # Ensure project root is importable (helps HF Spaces and local runs)
40
- PROJECT_FILE = Path(__file__).resolve()
41
- CORE_DIR = PROJECT_FILE.parent
42
- ROOT = CORE_DIR.parent
43
- if str(ROOT) not in sys.path:
44
- sys.path.insert(0, str(ROOT))
45
-
46
- # ─────────────────────────────────────────────────────────────────────────────
47
- # 2) Patch Gradio schema early (HF Spaces quirk)
48
- # ─────────────────────────────────────────────────────────────────────────────
49
- try:
50
- import gradio_client.utils as gc_utils
51
-
52
- _orig_get_type = gc_utils.get_type
53
-
54
- def _patched_get_type(schema):
55
- if not isinstance(schema, dict):
56
- if isinstance(schema, bool):
57
- return "boolean"
58
- if isinstance(schema, str):
59
- return "string"
60
- if isinstance(schema, (int, float)):
61
- return "number"
62
- return "string"
63
- return _orig_get_type(schema)
64
-
65
- gc_utils.get_type = _patched_get_type
66
- logger.info("Gradio schema patch applied")
67
- except Exception as e:
68
- logger.warning(f"Gradio patch failed: {e}")
69
-
70
- # ─────────────────────────────────────────────────────────────────────────────
71
- # 3) Core config + components
72
- # ─────────────────────────────────────────────────────────────────────────────
73
- from config.app_config import get_config
74
- from core.exceptions import ModelLoadingError, VideoProcessingError
75
- from utils.hardware.device_manager import DeviceManager
76
- from utils.system.memory_manager import MemoryManager
77
- from models.loaders.model_loader import ModelLoader
78
- from processing.video.video_processor import CoreVideoProcessor
79
- from processing.audio.audio_processor import AudioProcessor
80
- from utils.monitoring.progress_tracker import ProgressTracker
81
- from utils.cv_processing import validate_video_file
82
-
83
- # ─────────────────────────────────────────────────────────────────────────────
84
- # 3.1) Optional: Two-stage import with package and file fallbacks
85
- # ─────────────────────────────────────────────────────────────────────────────
86
- def _import_two_stage() -> tuple[bool, Any, Dict[str, Any], str, str]:
87
- """
88
- Returns (available, TwoStageProcessor|None, CHROMA_PRESETS|{}, import_origin, error_text)
89
-
90
- Tries:
91
- 1) package imports (preferred)
92
- 2) direct file imports (bypass processing/__init__.py side-effects)
93
- """
94
- pkg_paths = [
95
- "processing.two_stage.two_stage_processor",
96
- "two_stage_processor",
97
- "processing.two_stage_processor",
98
- ]
99
- fs_paths = [
100
- ROOT / "processing" / "two_stage" / "two_stage_processor.py",
101
- ROOT / "processing" / "two_stage_processor.py",
102
- ROOT / "two_stage_processor.py",
103
- ]
104
-
105
- # Package imports
106
- last_err_text = ""
107
- for mod_path in pkg_paths:
108
- try:
109
- mod = importlib.import_module(mod_path)
110
- TSP = getattr(mod, "TwoStageProcessor")
111
- PRESETS = getattr(mod, "CHROMA_PRESETS", {"standard": {}})
112
- logger.info(f"Two-stage import OK from package '{mod_path}'")
113
- return True, TSP, PRESETS, f"pkg:{mod_path}", ""
114
- except Exception as e:
115
- tb = traceback.format_exc()
116
- last_err_text += f"[pkg:{mod_path}] {repr(e)}\n{tb}\n"
117
-
118
- # Direct file imports (bypass __init__.py)
119
- for path in fs_paths:
120
- try:
121
- if not path.exists():
122
- continue
123
- spec = importlib.util.spec_from_file_location("bx_two_stage", str(path))
124
- if not spec or not spec.loader:
125
- continue
126
- mod = importlib.util.module_from_spec(spec)
127
- sys.modules["bx_two_stage"] = mod
128
- spec.loader.exec_module(mod) # type: ignore[attr-defined]
129
- TSP = getattr(mod, "TwoStageProcessor")
130
- PRESETS = getattr(mod, "CHROMA_PRESETS", {"standard": {}})
131
- logger.info(f"Two-stage import OK from file '{path}'")
132
- return True, TSP, PRESETS, f"file:{path}", ""
133
- except Exception as e:
134
- tb = traceback.format_exc()
135
- last_err_text += f"[file:{path}] {repr(e)}\n{tb}\n"
136
 
137
- logger.error("Two-stage import failed from all paths:\n%s", last_err_text or "(no traceback)")
138
- return False, None, {"standard": {}}, "", last_err_text or "(no traceback)"
139
 
140
- TWO_STAGE_AVAILABLE, _TwoStageProcessor, CHROMA_PRESETS, TWO_STAGE_IMPORT_ORIGIN, TWO_STAGE_IMPORT_ERROR = _import_two_stage()
141
-
142
- # ╔══════════════════════════════════════════════════════════════════════════╗
143
- # β•‘ VideoProcessor class β•‘
144
- # β•šβ•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•
145
- class VideoProcessor:
146
- """
147
- Main orchestrator – coordinates all specialised components.
148
- """
149
 
150
- def __init__(self):
151
- self.config = get_config()
152
- self._patch_config_defaults(self.config) # ensure missing attrs exist
153
 
154
- self.device_manager = DeviceManager()
155
- self.memory_manager = MemoryManager(self.device_manager.get_optimal_device())
156
- self.model_loader = ModelLoader(self.device_manager, self.memory_manager)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
- self.audio_processor = AudioProcessor()
159
- self.core_processor: CoreVideoProcessor | None = None
160
- self.two_stage_processor: Any | None = None # instance of TwoStageProcessor if available
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
 
162
- self.models_loaded = False
163
- self.loading_lock = threading.Lock()
164
- self.cancel_event = threading.Event()
165
- self.progress_tracker: ProgressTracker | None = None
 
166
 
167
- logger.info(f"VideoProcessor on device: {self.device_manager.get_optimal_device()}")
 
 
168
 
169
- # ─────────────────────────────────────────────────────────────────────
170
- # Config hardening – add missing attributes to avoid AttributeError
171
- # ─────────────────────────────────────────────────────────────────────
172
- @staticmethod
173
- def _patch_config_defaults(cfg: Any) -> None:
174
- """
175
- Some downstream modules may read fields that older configs don't define.
176
- Add safe defaults here so AttributeErrors never occur.
177
- """
178
- defaults = {
179
- # video i/o & writer
180
- "use_nvenc": False,
181
- "prefer_mp4": True,
182
- "video_codec": "mp4v",
183
- "audio_copy": True,
184
- "ffmpeg_path": "ffmpeg",
185
- # model/resource guards
186
- "max_model_size": 0, # bytes or MB; downstream should treat 0/None as "no limit"
187
- "max_model_size_bytes": 0,
188
- # housekeeping
189
- "output_dir": str(ROOT / "outputs"),
190
- }
191
- for k, v in defaults.items():
192
- if not hasattr(cfg, k):
193
- setattr(cfg, k, v)
194
-
195
- # Ensure output dir exists
196
- try:
197
- Path(cfg.output_dir).mkdir(parents=True, exist_ok=True)
198
- except Exception:
199
- out = ROOT / "outputs"
200
- out.mkdir(parents=True, exist_ok=True)
201
- cfg.output_dir = str(out)
202
-
203
- # ─────────────────────────────────────────────────────────────────────
204
- # Progress helper
205
- # ─────────────────────────────────────────────────────────────────────
206
- def _init_progress(self, video_path: str, cb: Optional[Callable] = None):
207
- try:
208
- import cv2
209
- cap = cv2.VideoCapture(video_path)
210
- total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
211
  cap.release()
212
- if total <= 0:
213
- total = 100
214
- self.progress_tracker = ProgressTracker(total, cb)
215
- except Exception as e:
216
- logger.warning(f"Progress init failed: {e}")
217
- self.progress_tracker = ProgressTracker(100, cb)
218
-
219
- # ─────────────────────────────────────────────────────────────────────
220
- # Model loading
221
- # ─────────────────────────────────────────────────────────────────────
222
- def load_models(self, progress_callback: Optional[Callable] = None) -> str:
223
- with self.loading_lock:
224
- if self.models_loaded:
225
- return "Models already loaded and validated"
226
 
 
227
  try:
228
- self.cancel_event.clear()
229
- if progress_callback:
230
- progress_callback(
231
- 0.0, f"Loading on {self.device_manager.get_optimal_device()}"
232
- )
233
-
234
- sam2_loaded, mat_loaded = self.model_loader.load_all_models(
235
- progress_callback=progress_callback, cancel_event=self.cancel_event
236
- )
237
-
238
- if self.cancel_event.is_set():
239
- return "Model loading cancelled"
240
-
241
- # Unwrap actual predictor / model objects
242
- sam2_predictor = getattr(sam2_loaded, "model", None) if sam2_loaded else None
243
- mat_model = getattr(mat_loaded, "model", None) if mat_loaded else None
244
-
245
- # Core single-stage processor
246
- self.core_processor = CoreVideoProcessor(
247
- config=self.config, models=self.model_loader
248
- )
249
-
250
- # Two-stage processor (optional)
251
- self.two_stage_processor = None
252
- if TWO_STAGE_AVAILABLE and (_TwoStageProcessor is not None) and (sam2_predictor or mat_model):
253
- try:
254
- self.two_stage_processor = _TwoStageProcessor(
255
- sam2_predictor=sam2_predictor, matanyone_model=mat_model
256
- )
257
- logger.info("Two-stage processor initialised (%s)", TWO_STAGE_IMPORT_ORIGIN or "unknown")
258
- except Exception as e:
259
- logger.warning("Two-stage init failed: %r\n%s", e, traceback.format_exc())
260
- self.two_stage_processor = None
261
-
262
- self.models_loaded = True
263
- msg = self.model_loader.get_load_summary()
264
- msg += (
265
- "\n✅ Two-stage processor ready"
266
- if self.two_stage_processor
267
- else "\n⚠️ Two-stage processor not available"
268
- )
269
- logger.info(msg)
270
- return msg
271
-
272
- except (AttributeError, ModelLoadingError) as e:
273
- self.models_loaded = False
274
- err = f"Model loading failed: {e}"
275
- logger.error(err)
276
- return err
277
  except Exception as e:
278
- self.models_loaded = False
279
- err = f"Unexpected error during model loading: {e}"
280
- logger.error(err)
281
- logger.debug("Traceback:\n%s", traceback.format_exc())
282
- return err
283
-
284
- # ─────────────────────────────────────────────────────────────────────
285
- # Public entry – process video
286
- # ─────────────────────────────────────────────────────────────────────
287
- def process_video(
 
 
 
 
 
 
288
  self,
289
- video_path: str,
290
- background_choice: str,
291
- custom_background_path: Optional[str] = None,
292
- progress_callback: Optional[Callable] = None,
293
- use_two_stage: bool = False,
294
- chroma_preset: str = "standard",
295
- key_color_mode: str = "auto",
296
- preview_mask: bool = False,
297
- preview_greenscreen: bool = False,
298
  ) -> Tuple[Optional[str], str]:
299
- if not self.models_loaded or not self.core_processor:
300
- return None, "Models not loaded. Please click “Load Models” first."
301
- if self.cancel_event.is_set():
302
- return None, "Processing cancelled"
303
-
304
- self._init_progress(video_path, progress_callback)
305
 
306
- ok, why = validate_video_file(video_path)
307
- if not ok:
308
- return None, f"Invalid video: {why}"
 
 
 
309
 
310
  try:
311
- if use_two_stage:
312
- if not TWO_STAGE_AVAILABLE or self.two_stage_processor is None:
313
- return None, "Two-stage processing not available on this build"
314
- return self._process_two_stage(
315
- video_path,
316
- background_choice,
317
- custom_background_path,
318
- progress_callback,
319
- chroma_preset,
320
- key_color_mode,
321
- )
 
 
 
 
 
 
 
 
 
 
 
322
  else:
323
- return self._process_single_stage(
324
- video_path,
325
- background_choice,
326
- custom_background_path,
327
- progress_callback,
328
- preview_mask,
329
- preview_greenscreen,
330
- )
331
-
332
- except VideoProcessingError as e:
333
- logger.error(f"Processing failed: {e}")
334
- return None, f"Processing failed: {e}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
335
  except Exception as e:
336
- logger.error(f"Unexpected processing error: {e}")
337
- logger.debug("Traceback:\n%s", traceback.format_exc())
338
- return None, f"Unexpected error: {e}"
339
-
340
- # ─────────────────────────────────────────────────────────────────────
341
- # Private – single-stage
342
- # ─────────────────────────────────────────────────────────────────────
343
- def _process_single_stage(
344
- self,
345
- video_path: str,
346
- background_choice: str,
347
- custom_background_path: Optional[str],
348
- progress_callback: Optional[Callable],
349
- preview_mask: bool,
350
- preview_greenscreen: bool,
351
- ) -> Tuple[Optional[str], str]:
352
- import time
353
-
354
- ts = int(time.time())
355
- out_dir = Path(self.config.output_dir) / "single_stage"
356
- out_dir.mkdir(parents=True, exist_ok=True)
357
- out_path = str(out_dir / f"processed_{ts}.mp4")
358
-
359
- result = self.core_processor.process_video(
360
- input_path=video_path,
361
- output_path=out_path,
362
- bg_config={
363
- "background_choice": background_choice,
364
- "custom_path": custom_background_path,
365
- },
366
- )
367
- if not result:
368
- return None, "Video processing failed"
369
-
370
- # Mux original audio back unless preview flags request otherwise
371
- if not (preview_mask or preview_greenscreen):
372
- try:
373
- final_path = self.audio_processor.add_audio_to_video(
374
- original_video=video_path, processed_video=out_path
375
- )
376
- except Exception as e:
377
- logger.warning("Audio mux failed, returning video without audio: %r", e)
378
- final_path = out_path
379
- else:
380
- final_path = out_path
381
-
382
- msg = (
383
- "Processing completed.\n"
384
- f"Frames: {result.get('frames', 'unknown')}\n"
385
- f"Background: {background_choice}\n"
386
- f"Mode: Single-stage\n"
387
- f"Device: {self.device_manager.get_optimal_device()}"
388
- )
389
- return final_path, msg
390
-
391
- # ─────────────────────────────────────────────────────────────────────
392
- # Private – two-stage
393
- # ─────────────────────────────────────────────────────────────────────
394
- def _process_two_stage(
395
  self,
396
  video_path: str,
397
- background_choice: str,
398
- custom_background_path: Optional[str],
399
- progress_callback: Optional[Callable],
400
- chroma_preset: str,
401
- key_color_mode: str,
 
 
402
  ) -> Tuple[Optional[str], str]:
403
- if self.two_stage_processor is None:
404
- return None, "Two-stage processor not available"
405
-
406
- import cv2, time
 
 
 
 
 
407
 
408
- # Determine output geometry from source
409
- cap = cv2.VideoCapture(video_path)
410
- if not cap.isOpened():
411
- return None, "Could not open input video"
412
- w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) or 1280
413
- h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) or 720
414
- cap.release()
415
 
416
- # Prepare background
417
- try:
418
- background = self.core_processor.prepare_background(
419
- background_choice, custom_background_path, w, h
420
  )
421
- except Exception as e:
422
- logger.error("Background preparation failed: %r", e)
423
- return None, f"Failed to prepare background: {e}"
424
-
425
- if background is None:
426
- return None, "Failed to prepare background"
427
-
428
- ts = int(time.time())
429
- out_dir = Path(self.config.output_dir) / "two_stage"
430
- out_dir.mkdir(parents=True, exist_ok=True)
431
- final_out = str(out_dir / f"final_{ts}.mp4")
432
-
433
- chroma_cfg = CHROMA_PRESETS.get(chroma_preset, CHROMA_PRESETS.get("standard", {}))
434
- logger.info(
435
- "Two-stage with preset: %s | key_color_mode=%s | origin=%s",
436
- chroma_preset, key_color_mode, TWO_STAGE_IMPORT_ORIGIN or "unknown"
437
- )
438
-
439
- result, message = self.two_stage_processor.process_full_pipeline(
440
- video_path,
441
- background,
442
- final_out,
443
- key_color_mode=key_color_mode,
444
- chroma_settings=chroma_cfg,
445
- progress_callback=progress_callback,
446
- )
447
- if result is None:
448
- return None, message
449
-
450
- # Mux audio from original (same logic as single-stage)
451
  try:
452
- final_path = self.audio_processor.add_audio_to_video(
453
- original_video=video_path, processed_video=result
454
- )
 
 
 
 
 
455
  except Exception as e:
456
- logger.warning("Audio mux failed for two-stage; returning video without audio: %r", e)
457
- final_path = result
458
-
459
- msg = (
460
- "Two-stage processing completed.\n"
461
- f"Background: {background_choice}\n"
462
- f"Chroma Preset: {chroma_preset}\n"
463
- f"Device: {self.device_manager.get_optimal_device()}"
464
- )
465
- return final_path, msg
466
-
467
- # ─────────────────────────────────────────────────────────────────────
468
- # Status helpers
469
- # ─────────────────────────────────────────────────────────────────────
470
- def get_status(self) -> Dict[str, Any]:
471
- status = {
472
- "models_loaded": self.models_loaded,
473
- "two_stage_available": bool(TWO_STAGE_AVAILABLE and (self.two_stage_processor is not None)),
474
- "two_stage_origin": TWO_STAGE_IMPORT_ORIGIN or "",
475
- "two_stage_error": TWO_STAGE_IMPORT_ERROR[:5000] if TWO_STAGE_IMPORT_ERROR else "",
476
- "device": str(self.device_manager.get_optimal_device()),
477
- "core_processor_loaded": self.core_processor is not None,
478
- "config": self._safe_config_dict(),
479
- "memory_usage": self._safe_memory_usage(),
480
- }
481
- try:
482
- status["sam2_loaded"] = self.model_loader.get_sam2() is not None
483
- status["matanyone_loaded"] = self.model_loader.get_matanyone() is not None
484
- except Exception:
485
- status["sam2_loaded"] = False
486
- status["matanyone_loaded"] = False
487
-
488
- if self.progress_tracker:
489
- status["progress"] = self.progress_tracker.get_all_progress()
490
- return status
491
 
492
- def _safe_config_dict(self) -> Dict[str, Any]:
493
  try:
494
- return self.config.to_dict()
495
- except Exception:
496
- keys = ["use_nvenc", "prefer_mp4", "video_codec", "audio_copy",
497
- "ffmpeg_path", "max_model_size", "max_model_size_bytes", "output_dir"]
498
- return {k: getattr(self.config, k, None) for k in keys}
 
 
 
 
 
 
 
 
499
 
500
- def _safe_memory_usage(self) -> Dict[str, Any]:
 
 
 
 
 
 
 
 
 
501
  try:
502
- return self.memory_manager.get_memory_usage()
 
 
503
  except Exception:
504
- return {}
505
-
506
- def cancel_processing(self):
507
- self.cancel_event.set()
508
- logger.info("Cancellation requested")
509
 
510
- def cleanup_resources(self):
511
- try:
512
- self.memory_manager.cleanup_aggressive()
513
- except Exception:
514
- pass
 
 
515
  try:
516
- self.model_loader.cleanup()
517
- except Exception:
518
- pass
519
- logger.info("Resources cleaned up")
520
-
521
-
522
- # ╔══════════════════════════════════════════════════════════════════════════╗
523
- # β•‘ Singleton instance + wrappers β•‘
524
- # β•šβ•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•
525
- processor = VideoProcessor()
526
-
527
- def load_models_with_validation(progress_callback: Optional[Callable] = None) -> str:
528
- return processor.load_models(progress_callback)
529
-
530
- def process_video_fixed(
531
- video_path: str,
532
- background_choice: str,
533
- custom_background_path: Optional[str],
534
- progress_callback: Optional[Callable] = None,
535
- use_two_stage: bool = False,
536
- chroma_preset: str = "standard",
537
- key_color_mode: str = "auto",
538
- preview_mask: bool = False,
539
- preview_greenscreen: bool = False,
540
- ) -> Tuple[Optional[str], str]:
541
- return processor.process_video(
542
- video_path,
543
- background_choice,
544
- custom_background_path,
545
- progress_callback,
546
- use_two_stage,
547
- chroma_preset,
548
- key_color_mode,
549
- preview_mask,
550
- preview_greenscreen,
551
- )
552
-
553
- def get_model_status() -> Dict[str, Any]:
554
- return processor.get_status()
555
-
556
- def get_cache_status() -> Dict[str, Any]:
557
- return processor.get_status()
558
-
559
- PROCESS_CANCELLED = processor.cancel_event
560
-
561
-
562
- # ╔══════════════════════════════════════════════════════════════════════════╗
563
- # β•‘ CLI β•‘
564
- # β•šβ•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•
565
- def main():
566
- try:
567
- logger.info("Starting BackgroundFX Pro")
568
- logger.info(f"Device: {processor.device_manager.get_optimal_device()}")
569
- logger.info(
570
- "Two-stage available at import-time: %s (origin='%s')",
571
- TWO_STAGE_AVAILABLE, TWO_STAGE_IMPORT_ORIGIN or "unknown"
572
- )
573
-
574
- from ui.ui_components import create_interface
575
- demo = create_interface()
576
- demo.queue().launch(
577
- server_name="0.0.0.0",
578
- server_port=7860,
579
- show_error=True,
580
- debug=False,
581
- )
582
- finally:
583
- processor.cleanup_resources()
584
-
585
-
586
- if __name__ == "__main__":
587
- main()
 
1
  #!/usr/bin/env python3
2
  """
3
+ Two-Stage Green-Screen Processing System ✅ 2025-08-26
4
+ Stage 1: Original → keyed background (auto-selected colour)
5
+ Stage 2: Keyed video → final composite (hybrid chroma + segmentation rescue)
6
+
7
+ Aligned with current project layout:
8
+ * uses helpers from utils.cv_processing (segment_person_hq, refine_mask_hq)
9
+ * safe local create_video_writer (no core.app dependency)
10
+ * cancel support via stop_event
11
+ * progress_callback(pct, desc)
12
+ * fully self-contained – just drop in and import TwoStageProcessor
13
  """
14
 
15
  from __future__ import annotations
16
 
17
+ import cv2, numpy as np, os, gc, pickle, logging, tempfile, traceback, threading
 
 
 
 
 
 
 
 
 
 
 
18
  from pathlib import Path
19
+ from typing import Optional, Dict, Any, Callable, Tuple, List
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
+ from utils.cv_processing import segment_person_hq, refine_mask_hq
 
22
 
23
+ # Project logger if available
24
+ try:
25
+ from utils.logger import get_logger
26
+ logger = get_logger(__name__)
27
+ except Exception:
28
+ logger = logging.getLogger(__name__)
 
 
 
29
 
 
 
 
30
 
31
+ # ---------------------------------------------------------------------------
32
+ # Local video-writer helper
33
+ # ---------------------------------------------------------------------------
34
+ def create_video_writer(output_path: str, fps: float, width: int, height: int, prefer_mp4: bool = True):
35
+ try:
36
+ ext = ".mp4" if prefer_mp4 else ".avi"
37
+ if not output_path:
38
+ output_path = tempfile.mktemp(suffix=ext)
39
+ else:
40
+ base, curr_ext = os.path.splitext(output_path)
41
+ if curr_ext.lower() not in [".mp4", ".avi", ".mov", ".mkv"]:
42
+ output_path = base + ext
43
+
44
+ fourcc = cv2.VideoWriter_fourcc(*("mp4v" if prefer_mp4 else "XVID"))
45
+ writer = cv2.VideoWriter(output_path, fourcc, float(fps), (int(width), int(height)))
46
+ if writer is None or not writer.isOpened():
47
+ alt_ext = ".avi" if prefer_mp4 else ".mp4"
48
+ alt_fourcc = cv2.VideoWriter_fourcc(*("XVID" if prefer_mp4 else "mp4v"))
49
+ alt_path = os.path.splitext(output_path)[0] + alt_ext
50
+ writer = cv2.VideoWriter(alt_path, alt_fourcc, float(fps), (int(width), int(height)))
51
+ if writer is None or not writer.isOpened():
52
+ return None, output_path
53
+ return writer, alt_path
54
+ return writer, output_path
55
+ except Exception as e:
56
+ logger.error(f"create_video_writer failed: {e}")
57
+ return None, output_path
58
+
59
+
60
+ # ---------------------------------------------------------------------------
61
+ # Key-colour helpers (fast, no external deps)
62
+ # ---------------------------------------------------------------------------
63
+ def _bgr_to_hsv_hue_deg(bgr: np.ndarray) -> np.ndarray:
64
+ hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
65
+ # OpenCV H is 0-180; scale to degrees 0-360
66
+ return hsv[..., 0].astype(np.float32) * 2.0
67
+
68
+
69
+ def _hue_distance(a_deg: float, b_deg: float) -> float:
70
+ """Circular distance on the hue wheel (degrees)."""
71
+ d = abs(a_deg - b_deg) % 360.0
72
+ return min(d, 360.0 - d)
73
+
74
+
75
+ def _key_candidates_bgr() -> dict:
76
+ return {
77
+ "green": {"bgr": np.array([ 0,255, 0], dtype=np.uint8), "hue": 120.0},
78
+ "blue": {"bgr": np.array([255, 0, 0], dtype=np.uint8), "hue": 240.0},
79
+ "cyan": {"bgr": np.array([255,255, 0], dtype=np.uint8), "hue": 180.0},
80
+ "magenta": {"bgr": np.array([255, 0,255], dtype=np.uint8), "hue": 300.0},
81
+ }
82
+
83
+
84
+ def _choose_best_key_color(frame_bgr: np.ndarray, mask_uint8: np.ndarray) -> dict:
85
+ """Pick the candidate colour farthest from the actor’s dominant hues."""
86
+ try:
87
+ fg = frame_bgr[mask_uint8 > 127]
88
+ if fg.size < 1_000:
89
+ return _key_candidates_bgr()["green"]
90
+
91
+ fg_hue = _bgr_to_hsv_hue_deg(fg.reshape(-1, 1, 3)).reshape(-1)
92
+ hist, edges = np.histogram(fg_hue, bins=36, range=(0.0, 360.0))
93
+ top_idx = np.argsort(hist)[-3:]
94
+ top_hues = [(edges[i] + edges[i+1]) * 0.5 for i in top_idx]
95
+
96
+ best_name, best_score = None, -1.0
97
+ for name, info in _key_candidates_bgr().items():
98
+ cand_hue = info["hue"]
99
+ score = min(abs((cand_hue - th + 180) % 360 - 180) for th in top_hues)
100
+ if score > best_score:
101
+ best_name, best_score = name, score
102
+ return _key_candidates_bgr().get(best_name, _key_candidates_bgr()["green"])
103
+ except Exception:
104
+ return _key_candidates_bgr()["green"]
105
+
106
+
107
+ # ---------------------------------------------------------------------------
108
+ # Chroma presets
109
+ # ---------------------------------------------------------------------------
110
+ CHROMA_PRESETS: Dict[str, Dict[str, Any]] = {
111
+ 'standard': {'key_color': [0,255,0], 'tolerance': 38, 'edge_softness': 2, 'spill_suppression': 0.35},
112
+ 'studio': {'key_color': [0,255,0], 'tolerance': 30, 'edge_softness': 1, 'spill_suppression': 0.45},
113
+ 'outdoor': {'key_color': [0,255,0], 'tolerance': 50, 'edge_softness': 3, 'spill_suppression': 0.25},
114
+ }
115
+
116
+
117
+ # ---------------------------------------------------------------------------
118
+ # Two-Stage Processor
119
+ # ---------------------------------------------------------------------------
120
+ class TwoStageProcessor:
121
+ def __init__(self, sam2_predictor=None, matanyone_model=None):
122
+ self.sam2 = self._unwrap_sam2(sam2_predictor)
123
+ self.matanyone = matanyone_model
124
+ self.mask_cache_dir = Path("/tmp/mask_cache")
125
+ self.mask_cache_dir.mkdir(parents=True, exist_ok=True)
126
+ logger.info(f"TwoStageProcessor init – SAM2: {self.sam2 is not None} | MatAnyOne: {self.matanyone is not None}")
127
+
128
+ # ---------------------------------------------------------------------
129
+ # Stage 1 – Original β†’ keyed (green/blue/…) -- chooses colour on 1st frame
130
+ # ---------------------------------------------------------------------
131
+ def stage1_extract_to_greenscreen(
132
+ self,
133
+ video_path: str,
134
+ output_path: str,
135
+ *,
136
+ key_color_mode: str = "auto", # "auto" | "green" | "blue" | "cyan" | "magenta"
137
+ progress_callback: Optional[Callable[[float, str], None]] = None,
138
+ stop_event: Optional["threading.Event"] = None,
139
+ ) -> Tuple[Optional[dict], str]:
140
+
141
+ def _prog(p, d):
142
+ if progress_callback:
143
+ try:
144
+ progress_callback(float(p), str(d))
145
+ except Exception:
146
+ pass
147
 
148
+ try:
149
+ _prog(0.0, "Stage 1: opening video…")
150
+ cap = cv2.VideoCapture(video_path)
151
+ if not cap.isOpened():
152
+ return None, "Could not open input video"
153
+
154
+ fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
155
+ total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 0
156
+ w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
157
+ h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
158
+
159
+ writer, out_path = create_video_writer(output_path, fps, w, h)
160
+ if writer is None:
161
+ cap.release()
162
+ return None, "Could not create output writer"
163
+
164
+ key_info: dict | None = None
165
+ chosen_bgr = np.array([0, 255, 0], np.uint8) # default
166
+ probe_done = False
167
+ masks: List[np.ndarray] = []
168
+ frame_idx = 0
169
+
170
+ green_bg_template = np.zeros((h, w, 3), np.uint8) # overwritten per-frame
171
+
172
+ while True:
173
+ if stop_event and stop_event.is_set():
174
+ _prog(1.0, "Stage 1: cancelled")
175
+ break
176
+
177
+ ok, frame = cap.read()
178
+ if not ok:
179
+ break
180
+
181
+ mask = self._get_mask(frame)
182
+
183
+ # decide key colour once
184
+ if not probe_done:
185
+ if key_color_mode.lower() == "auto":
186
+ key_info = _choose_best_key_color(frame, mask)
187
+ chosen_bgr = key_info["bgr"]
188
+ else:
189
+ cand = _key_candidates_bgr().get(key_color_mode.lower())
190
+ if cand is not None:
191
+ chosen_bgr = cand["bgr"]
192
+ probe_done = True
193
+ logger.info(f"[TwoStage] Using key colour: {key_color_mode} → {chosen_bgr.tolist()}")
194
+
195
+ # optional refine
196
+ if self.matanyone and frame_idx % 3 == 0:
197
+ try:
198
+ mask = refine_mask_hq(frame, mask, self.matanyone, fallback_enabled=True)
199
+ except Exception as e:
200
+ logger.warning(f"MatAnyOne refine fail f={frame_idx}: {e}")
201
 
202
+ # composite
203
+ green_bg_template[:] = chosen_bgr
204
+ gs = self._apply_greenscreen_hard(frame, mask, green_bg_template)
205
+ writer.write(gs)
206
+ masks.append(self._to_binary_mask(mask))
207
 
208
+ frame_idx += 1
209
+ pct = 0.05 + 0.9 * (frame_idx / total) if total else min(0.95, 0.05 + frame_idx * 0.002)
210
+ _prog(pct, f"Stage 1: {frame_idx}/{total or '?'}")
211
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  cap.release()
213
+ writer.release()
 
 
 
 
 
 
 
 
 
 
 
 
 
214
 
215
+ # save mask cache
216
  try:
217
+ cache_file = self.mask_cache_dir / (Path(out_path).stem + "_masks.pkl")
218
+ with open(cache_file, "wb") as f:
219
+ pickle.dump(masks, f)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  except Exception as e:
221
+ logger.warning(f"mask cache save fail: {e}")
222
+
223
+ _prog(1.0, "Stage 1: complete")
224
+ return (
225
+ {"path": out_path, "frames": frame_idx, "key_bgr": chosen_bgr.tolist()},
226
+ f"Green-screen video created ({frame_idx} frames)"
227
+ )
228
+
229
+ except Exception as e:
230
+ logger.error(f"Stage 1 error: {e}\n{traceback.format_exc()}")
231
+ return None, f"Stage 1 failed: {e}"
232
+
233
+ # ---------------------------------------------------------------------
234
+ # Stage 2 – keyed video β†’ final composite (hybrid matte)
235
+ # ---------------------------------------------------------------------
236
+ def stage2_greenscreen_to_final(
237
  self,
238
+ gs_path: str,
239
+ background: np.ndarray | str,
240
+ output_path: str,
241
+ *,
242
+ chroma_settings: Optional[Dict[str, Any]] = None,
243
+ progress_callback: Optional[Callable[[float, str], None]] = None,
244
+ stop_event: Optional["threading.Event"] = None,
 
 
245
  ) -> Tuple[Optional[str], str]:
 
 
 
 
 
 
246
 
247
+ def _prog(p, d):
248
+ if progress_callback:
249
+ try:
250
+ progress_callback(float(p), str(d))
251
+ except Exception:
252
+ pass
253
 
254
  try:
255
+ _prog(0.0, "Stage 2: opening keyed video…")
256
+ cap = cv2.VideoCapture(gs_path)
257
+ if not cap.isOpened():
258
+ return None, "Could not open keyed video"
259
+
260
+ fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
261
+ total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 0
262
+ w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
263
+ h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
264
+
265
+ writer, out_path = create_video_writer(output_path, fps, w, h)
266
+ if writer is None:
267
+ cap.release()
268
+ return None, "Could not create output writer"
269
+
270
+ # background
271
+ if isinstance(background, str):
272
+ bg = cv2.imread(background, cv2.IMREAD_COLOR)
273
+ if bg is None:
274
+ cap.release()
275
+ writer.release()
276
+ return None, "Could not load background"
277
  else:
278
+ bg = background
279
+ bg = cv2.resize(bg, (w, h), interpolation=cv2.INTER_LANCZOS4).astype(np.uint8)
280
+
281
+ # settings
282
+ settings = dict(CHROMA_PRESETS['standard'])
283
+ if chroma_settings:
284
+ settings.update(chroma_settings)
285
+
286
+ # load cached masks if any
287
+ cache_file = self.mask_cache_dir / (Path(gs_path).stem + "_masks.pkl")
288
+ cached_masks = None
289
+ if cache_file.exists():
290
+ try:
291
+ with open(cache_file, 'rb') as f:
292
+ cached_masks = pickle.load(f)
293
+ except Exception as e:
294
+ logger.warning(f"mask cache load fail: {e}")
295
+
296
+ frame_idx = 0
297
+ while True:
298
+ if stop_event and stop_event.is_set():
299
+ _prog(1.0, "Stage 2: cancelled")
300
+ break
301
+ ok, frame = cap.read()
302
+ if not ok:
303
+ break
304
+
305
+ if cached_masks and frame_idx < len(cached_masks):
306
+ seg_mask = cached_masks[frame_idx]
307
+ else:
308
+ seg_mask = self._segmentation_mask_on_stage2(frame)
309
+
310
+ composite = self._chroma_key_advanced(frame, bg, settings, seg_mask)
311
+
312
+ writer.write(composite)
313
+ frame_idx += 1
314
+ pct = 0.05 + 0.9 * (frame_idx / total) if total else min(0.95, 0.05 + frame_idx * 0.002)
315
+ _prog(pct, f"Stage 2: {frame_idx}/{total or '?'}")
316
+
317
+ cap.release()
318
+ writer.release()
319
+ _prog(1.0, "Stage 2: complete")
320
+ return out_path, f"Final video created ({frame_idx} frames)"
321
  except Exception as e:
322
+ logger.error(f"Stage 2 error: {e}\n{traceback.format_exc()}")
323
+ return None, f"Stage 2 failed: {e}"
324
+
325
+ # ---------------------------------------------------------------------
326
+ # Full pipeline – now passes chosen key into Stage 2
327
+ # ---------------------------------------------------------------------
328
+ def process_full_pipeline(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
329
  self,
330
  video_path: str,
331
+ background: np.ndarray | str,
332
+ final_output: str,
333
+ *,
334
+ key_color_mode: str = "auto",
335
+ chroma_settings: Optional[Dict[str, Any]] = None,
336
+ progress_callback: Optional[Callable[[float, str], None]] = None,
337
+ stop_event: Optional["threading.Event"] = None,
338
  ) -> Tuple[Optional[str], str]:
339
+ gs_tmp = tempfile.mktemp(suffix="_gs.mp4")
340
+ try:
341
+ gs_info, msg1 = self.stage1_extract_to_greenscreen(
342
+ video_path, gs_tmp,
343
+ key_color_mode=key_color_mode,
344
+ progress_callback=progress_callback, stop_event=stop_event
345
+ )
346
+ if gs_info is None:
347
+ return None, msg1
348
 
349
+ # inject key colour into chroma settings for Stage 2
350
+ chosen_key = gs_info.get("key_bgr", [0, 255, 0])
351
+ cs = dict(chroma_settings or CHROMA_PRESETS['standard'])
352
+ cs['key_color'] = chosen_key
 
 
 
353
 
354
+ result, msg2 = self.stage2_greenscreen_to_final(
355
+ gs_info["path"], background, final_output,
356
+ chroma_settings=cs, progress_callback=progress_callback, stop_event=stop_event
 
357
  )
358
+ return result, msg2
359
+ finally:
360
+ try:
361
+ os.remove(gs_tmp)
362
+ except Exception:
363
+ pass
364
+ gc.collect()
365
+
366
+ # ---------------------------------------------------------------------
367
+ # Internal helpers
368
+ # ---------------------------------------------------------------------
369
+ def _unwrap_sam2(self, obj):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
  try:
371
+ if obj is None:
372
+ return None
373
+ if all(hasattr(obj, attr) for attr in ("set_image", "predict")):
374
+ return obj
375
+ for attr in ("model", "predictor"):
376
+ inner = getattr(obj, attr, None)
377
+ if inner and all(hasattr(inner, a) for a in ("set_image", "predict")):
378
+ return inner
379
  except Exception as e:
380
+ logger.warning(f"SAM2 unwrap fail: {e}")
381
+ return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
382
 
383
+ def _get_mask(self, frame: np.ndarray) -> np.ndarray:
384
  try:
385
+ return segment_person_hq(frame, self.sam2, fallback_enabled=True)
386
+ except Exception as e:
387
+ logger.warning(f"Segmentation fallback: {e}")
388
+ h, w = frame.shape[:2]
389
+ m = np.zeros((h, w), np.uint8)
390
+ m[h//6:5*h//6, w//4:3*w//4] = 255
391
+ return m
392
+
393
+ def _apply_greenscreen_hard(self, frame, mask, green_bg):
394
+ mask_u8 = self._to_binary_mask(mask)
395
+ mk = cv2.cvtColor(mask_u8, cv2.COLOR_GRAY2BGR).astype(np.float32) / 255.0
396
+ out = frame.astype(np.float32) * mk + green_bg.astype(np.float32) * (1.0 - mk)
397
+ return np.clip(out, 0, 255).astype(np.uint8)
398
 
399
+ @staticmethod
400
+ def _to_binary_mask(mask: np.ndarray) -> np.ndarray:
401
+ if mask.ndim == 3:
402
+ mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
403
+ if mask.dtype != np.uint8:
404
+ mask = (np.clip(mask, 0, 1) * 255).astype(np.uint8) if mask.max() <= 1.0 else np.clip(mask, 0, 255).astype(np.uint8)
405
+ _, binm = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
406
+ return binm
407
+
408
+ def _segmentation_mask_on_stage2(self, frame_bgr: np.ndarray) -> Optional[np.ndarray]:
409
  try:
410
+ if self.sam2 is None:
411
+ return None
412
+ return self._get_mask(frame_bgr)
413
  except Exception:
414
+ return None
 
 
 
 
415
 
416
+ def _chroma_key_advanced(
417
+ self,
418
+ frame_bgr: np.ndarray,
419
+ bg_bgr: np.ndarray,
420
+ settings: Dict[str, Any],
421
+ seg_mask: Optional[np.ndarray] = None,
422
+ ) -> np.ndarray:
423
  try:
424
+ key = np.array(settings.get("key_color", [0, 255, 0]), dtype=np.float32)
425
+ tol = float(settings.get("tolerance", 40))
426
+ soft = int(settings.get("edge_softness", 2))
427
+ spill= float(settings.get("spill_suppression", 0.3))
428
+
429
+ f = frame_bgr.astype(np.float32)
430
+ b = bg_bgr.astype(np.float32)
431
+
432
+ diff = np.linalg.norm(f - key, axis=2)
433
+ alpha = np.clip((diff - tol * 0.6) / max(1e-6, tol * 0.4), 0.0, 1.0)
434
+ if soft > 0:
435
+ k = soft * 2 + 1
436
+ alpha = cv2.GaussianBlur(alpha, (k, k), soft)
437
+
438
+ # segmentation rescue
439
+ if seg_mask is not None:
440
+ if seg_mask.ndim == 3:
441
+ seg_mask = cv2.cvtColor(seg_mask, cv2.COLOR_BGR2GRAY)
442
+ seg = seg_mask.astype(np.float32) / 255.0
443
+ seg = cv2.GaussianBlur(seg, (5, 5), 1.0)
444
+ alpha = np.clip(np.maximum(alpha, seg * 0.85), 0.0, 1.0)
445
+
446
+ # spill suppression
447
+ if spill > 0:
448
+ zone = 1.0 - alpha
449
+ g = f[:, :, 1]
450
+ f[:, :, 1] = np.clip(g - g * zone * spill, 0, 255)
451
+
452
+ mask3 = np.stack([alpha] * 3, axis=2)
453
+ out = f * mask3 + b * (1.0 - mask3)
454
+ return np.clip(out, 0, 255).astype(np.uint8)
455
+ except Exception as e:
456
+ logger.error(f"Chroma key error: {e}")
457
+ return frame_bgr