Nekochu commited on
Commit
c2d53e4
·
1 Parent(s): 16d8be4

professional output package (Comp+FG+Matte+Processed)

Browse files
Files changed (1) hide show
  1. app.py +140 -295
app.py CHANGED
@@ -29,17 +29,16 @@ import gradio as gr
29
  import onnxruntime as ort
30
 
31
  # Workaround: Gradio cache_examples bug with None outputs.
32
- # CSVLogger.flag() writes "" for None, read_from_flag("") calls json.loads("") -> crash.
33
  _original_read_from_flag = gr.components.Component.read_from_flag
34
  def _patched_read_from_flag(self, payload):
35
  if payload is None or (isinstance(payload, str) and payload.strip() == ""):
36
  return None
37
  return _original_read_from_flag(self, payload)
38
  gr.components.Component.read_from_flag = _patched_read_from_flag
 
39
  from huggingface_hub import hf_hub_download
40
 
41
  cv2.setNumThreads(2)
42
-
43
  logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
44
  logger = logging.getLogger(__name__)
45
 
@@ -48,55 +47,40 @@ logger = logging.getLogger(__name__)
48
  # ---------------------------------------------------------------------------
49
  BIREFNET_REPO = "onnx-community/BiRefNet_lite-ONNX"
50
  BIREFNET_FILE = "onnx/model.onnx"
51
-
52
  MODELS_DIR = os.path.join(os.path.dirname(__file__), "models")
53
  CORRIDORKEY_MODELS = {
54
  "1024": os.path.join(MODELS_DIR, "corridorkey_1024.onnx"),
55
  "2048": os.path.join(MODELS_DIR, "corridorkey_2048.onnx"),
56
  }
57
-
58
  IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 1, 3)
59
  IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 1, 3)
60
-
61
  MAX_DURATION_CPU = 5
62
  MAX_DURATION_GPU = 30
63
  MAX_FRAMES = 150
64
-
65
- # GPU auto-detect via ONNX Runtime (no torch dependency)
66
  HAS_CUDA = "CUDAExecutionProvider" in ort.get_available_providers()
67
 
68
  # ---------------------------------------------------------------------------
69
- # Color utilities (numpy-only, from CorridorKeyModule/core/color_utils.py)
70
  # ---------------------------------------------------------------------------
71
-
72
  def linear_to_srgb(x):
73
  x = np.clip(x, 0.0, None)
74
  return np.where(x <= 0.0031308, x * 12.92, 1.055 * np.power(x, 1.0 / 2.4) - 0.055)
75
 
76
-
77
  def srgb_to_linear(x):
78
  x = np.clip(x, 0.0, None)
79
  return np.where(x <= 0.04045, x / 12.92, np.power((x + 0.055) / 1.055, 2.4))
80
 
81
-
82
  def composite_straight(fg, bg, alpha):
83
  return fg * alpha + bg * (1.0 - alpha)
84
 
85
-
86
  def despill(image, green_limit_mode="average", strength=1.0):
87
  if strength <= 0.0:
88
  return image
89
  r, g, b = image[..., 0], image[..., 1], image[..., 2]
90
  limit = (r + b) / 2.0 if green_limit_mode == "average" else np.maximum(r, b)
91
- spill_amount = np.maximum(g - limit, 0.0)
92
- g_new = g - spill_amount
93
- r_new = r + spill_amount * 0.5
94
- b_new = b + spill_amount * 0.5
95
- despilled = np.stack([r_new, g_new, b_new], axis=-1)
96
- if strength < 1.0:
97
- return image * (1.0 - strength) + despilled * strength
98
- return despilled
99
-
100
 
101
  def clean_matte(alpha_np, area_threshold=300, dilation=15, blur_size=5):
102
  is_3d = alpha_np.ndim == 3
@@ -104,39 +88,30 @@ def clean_matte(alpha_np, area_threshold=300, dilation=15, blur_size=5):
104
  alpha_np = alpha_np[:, :, 0]
105
  mask_8u = (alpha_np > 0.5).astype(np.uint8) * 255
106
  num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask_8u, connectivity=8)
107
- # Vectorized: find valid labels in one pass
108
  valid = np.zeros(num_labels, dtype=bool)
109
  valid[1:] = stats[1:, cv2.CC_STAT_AREA] >= area_threshold
110
  cleaned = (valid[labels].astype(np.uint8) * 255)
111
  if dilation > 0:
112
  k = int(dilation * 2 + 1)
113
- kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
114
- cleaned = cv2.dilate(cleaned, kernel)
115
  if blur_size > 0:
116
  b = int(blur_size * 2 + 1)
117
  cleaned = cv2.GaussianBlur(cleaned, (b, b), 0)
118
- safe_zone = cleaned.astype(np.float32) / 255.0
119
- result = alpha_np * safe_zone
120
  return result[:, :, np.newaxis] if is_3d else result
121
 
122
-
123
  def create_checkerboard(w, h, checker_size=64, color1=0.15, color2=0.55):
124
- x_tiles = np.arange(w) // checker_size
125
- y_tiles = np.arange(h) // checker_size
126
- xg, yg = np.meshgrid(x_tiles, y_tiles)
127
- checker = ((xg + yg) % 2).astype(np.float32)
128
- bg = np.where(checker == 0, color1, color2).astype(np.float32)
129
  return np.stack([bg, bg, bg], axis=-1)
130
 
 
 
131
 
132
  # ---------------------------------------------------------------------------
133
- # Fast classical green-screen mask (alternative to BiRefNet)
134
  # ---------------------------------------------------------------------------
135
-
136
  def fast_greenscreen_mask(frame_rgb_f32):
137
- """Fast green-screen detection using corner sampling + HSV threshold.
138
- Returns (mask_f32, confidence) or (None, 0.0) if not a green screen.
139
- """
140
  h, w = frame_rgb_f32.shape[:2]
141
  ph, pw = max(int(h * 0.05), 4), max(int(w * 0.05), 4)
142
  corners = np.concatenate([
@@ -146,38 +121,25 @@ def fast_greenscreen_mask(frame_rgb_f32):
146
  frame_rgb_f32[-ph:, -pw:].reshape(-1, 3),
147
  ], axis=0)
148
  bg_color = np.median(corners, axis=0)
149
-
150
- # Check if background is green-ish (G channel dominant)
151
  if not (bg_color[1] > bg_color[0] + 0.05 and bg_color[1] > bg_color[2] + 0.05):
152
  return None, 0.0
153
-
154
- # HSV-based mask (more robust than RGB distance)
155
  frame_u8 = (np.clip(frame_rgb_f32, 0, 1) * 255).astype(np.uint8)
156
  hsv = cv2.cvtColor(frame_u8, cv2.COLOR_RGB2HSV)
157
- # Green hue range in HSV
158
  green_mask = cv2.inRange(hsv, (35, 40, 40), (85, 255, 255))
159
- # Invert: foreground = NOT green
160
  fg_mask = cv2.bitwise_not(green_mask)
161
- # Morphological close to fill small holes
162
- kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
163
- fg_mask = cv2.morphologyEx(fg_mask, cv2.MORPH_CLOSE, kernel)
164
  fg_mask = cv2.GaussianBlur(fg_mask, (5, 5), 0)
165
  mask_f32 = fg_mask.astype(np.float32) / 255.0
166
-
167
- # Confidence: how bimodal is the mask (closer to 0/1 = better)
168
  confidence = 1.0 - 2.0 * np.mean(np.minimum(mask_f32, 1.0 - mask_f32))
169
-
170
  return mask_f32, confidence
171
 
172
-
173
  # ---------------------------------------------------------------------------
174
- # Model loading (lazy singletons)
175
  # ---------------------------------------------------------------------------
176
  _birefnet_session = None
177
  _corridorkey_sessions = {}
178
 
179
-
180
- def _ort_session_opts():
181
  opts = ort.SessionOptions()
182
  opts.intra_op_num_threads = 2
183
  opts.inter_op_num_threads = 1
@@ -186,17 +148,15 @@ def _ort_session_opts():
186
  opts.enable_mem_pattern = True
187
  return opts
188
 
189
-
190
  def get_birefnet():
191
  global _birefnet_session
192
  if _birefnet_session is None:
193
  logger.info("Downloading BiRefNet-Lite ONNX...")
194
  path = hf_hub_download(repo_id=BIREFNET_REPO, filename=BIREFNET_FILE)
195
  logger.info("Loading BiRefNet ONNX: %s", path)
196
- _birefnet_session = ort.InferenceSession(path, _ort_session_opts(), providers=["CPUExecutionProvider"])
197
  return _birefnet_session
198
 
199
-
200
  def get_corridorkey(resolution="1024"):
201
  global _corridorkey_sessions
202
  if resolution not in _corridorkey_sessions:
@@ -204,62 +164,44 @@ def get_corridorkey(resolution="1024"):
204
  if not onnx_path or not os.path.exists(onnx_path):
205
  raise gr.Error(f"CorridorKey ONNX model for {resolution} not found.")
206
  logger.info("Loading CorridorKey ONNX (%s): %s", resolution, onnx_path)
207
- _corridorkey_sessions[resolution] = ort.InferenceSession(onnx_path, _ort_session_opts(), providers=["CPUExecutionProvider"])
208
  return _corridorkey_sessions[resolution]
209
 
210
-
211
  # ---------------------------------------------------------------------------
212
  # Per-frame inference
213
  # ---------------------------------------------------------------------------
214
-
215
  def birefnet_frame(session, image_rgb_uint8):
216
- """BiRefNet: RGB uint8 [H,W,3] -> float32 [H,W] mask 0-1."""
217
  h, w = image_rgb_uint8.shape[:2]
218
- inp_info = session.get_inputs()[0]
219
- res = (inp_info.shape[2], inp_info.shape[3])
220
  img = cv2.resize(image_rgb_uint8, res).astype(np.float32) / 255.0
221
- img = (img - IMAGENET_MEAN) / IMAGENET_STD
222
- img = img.transpose(2, 0, 1)[np.newaxis, :].astype(np.float32)
223
- outputs = session.run(None, {inp_info.name: img})
224
- pred = 1.0 / (1.0 + np.exp(-outputs[-1])) # sigmoid
225
- mask = cv2.resize(pred[0, 0], (w, h))
226
- return (mask > 0.04).astype(np.float32)
227
-
228
 
229
  def corridorkey_frame(session, image_f32, mask_f32, img_size,
230
- despill_strength=0.5, auto_despeckle=True,
231
- despeckle_size=400):
232
- """CorridorKey: image [H,W,3] float32 0-1 + mask [H,W] float32 0-1 -> dict."""
233
  h, w = image_f32.shape[:2]
234
- img_resized = cv2.resize(image_f32, (img_size, img_size))
235
- mask_resized = cv2.resize(mask_f32, (img_size, img_size))[:, :, np.newaxis]
236
- img_norm = (img_resized - IMAGENET_MEAN) / IMAGENET_STD
237
- inp = np.concatenate([img_norm, mask_resized], axis=-1)
238
  inp = inp.transpose(2, 0, 1)[np.newaxis, :].astype(np.float32)
239
-
240
  alpha_raw, fg_raw = session.run(None, {"input": inp})
241
-
242
  alpha = cv2.resize(alpha_raw[0].transpose(1, 2, 0), (w, h), interpolation=cv2.INTER_LANCZOS4)
243
  fg = cv2.resize(fg_raw[0].transpose(1, 2, 0), (w, h), interpolation=cv2.INTER_LANCZOS4)
244
  if alpha.ndim == 2:
245
  alpha = alpha[:, :, np.newaxis]
246
-
247
  if auto_despeckle:
248
  alpha = clean_matte(alpha, area_threshold=despeckle_size, dilation=25, blur_size=5)
249
  fg = despill(fg, green_limit_mode="average", strength=despill_strength)
250
-
251
  return {"alpha": alpha, "fg": fg}
252
 
253
-
254
  # ---------------------------------------------------------------------------
255
- # Video stitching via ffmpeg
256
  # ---------------------------------------------------------------------------
257
-
258
  def _stitch_ffmpeg(frame_dir, out_path, fps, pattern="%05d.png", pix_fmt="yuv420p",
259
  codec="libx264", extra_args=None):
260
- """Stitch PNG frames into video via ffmpeg subprocess."""
261
- cmd = ["ffmpeg", "-y", "-framerate", str(fps),
262
- "-i", os.path.join(frame_dir, pattern),
263
  "-c:v", codec, "-pix_fmt", pix_fmt]
264
  if extra_args:
265
  cmd.extend(extra_args)
@@ -271,36 +213,13 @@ def _stitch_ffmpeg(frame_dir, out_path, fps, pattern="%05d.png", pix_fmt="yuv420
271
  logger.warning("ffmpeg failed: %s", e)
272
  return False
273
 
274
-
275
- def _stitch_cv2_fallback(frame_dir, out_path, fps, w, h, grayscale=False):
276
- """Fallback: stitch via OpenCV VideoWriter if ffmpeg unavailable."""
277
- files = sorted([f for f in os.listdir(frame_dir) if f.endswith(".png")])
278
- fourcc = cv2.VideoWriter_fourcc(*"mp4v")
279
- writer = cv2.VideoWriter(out_path, fourcc, fps, (w, h))
280
- if not writer.isOpened():
281
- logger.warning("mp4v codec unavailable")
282
- return False
283
- for f in files:
284
- img = cv2.imread(os.path.join(frame_dir, f),
285
- cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR)
286
- if img is None:
287
- continue
288
- if grayscale:
289
- img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
290
- writer.write(img)
291
- writer.release()
292
- return True
293
-
294
-
295
  # ---------------------------------------------------------------------------
296
- # Video processing pipeline (single-pass, streaming)
297
  # ---------------------------------------------------------------------------
298
-
299
  def process_video(video_path, resolution, despill_val, mask_mode,
300
- auto_despeckle, despeckle_size, output_mode, progress=gr.Progress()):
301
  """Remove green screen background from video using CorridorKey AI matting.
302
- Handles transparent objects (glass, water, cloth) that traditional chroma key cannot.
303
- Returns composite video, downloadable file, and status message.
304
  """
305
  if video_path is None:
306
  raise gr.Error("Please upload a video.")
@@ -308,7 +227,6 @@ def process_video(video_path, resolution, despill_val, mask_mode,
308
  max_dur = MAX_DURATION_GPU if HAS_CUDA else MAX_DURATION_CPU
309
  img_size = int(resolution)
310
 
311
- # Probe video
312
  cap = cv2.VideoCapture(video_path)
313
  fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
314
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
@@ -318,7 +236,6 @@ def process_video(video_path, resolution, despill_val, mask_mode,
318
 
319
  if total_frames == 0:
320
  raise gr.Error("Could not read video frames. Check file format.")
321
-
322
  duration = total_frames / fps
323
  if duration > max_dur:
324
  raise gr.Error(f"Video too long ({duration:.1f}s). Max {max_dur}s on {'GPU' if HAS_CUDA else 'free CPU'} tier.")
@@ -327,7 +244,6 @@ def process_video(video_path, resolution, despill_val, mask_mode,
327
  logger.info("Processing %d frames (%dx%d @ %.1f fps), resolution=%d, mask=%s",
328
  frames_to_process, w, h, fps, img_size, mask_mode)
329
 
330
- # Load models
331
  try:
332
  birefnet = None
333
  if mask_mode != "Fast (classical)":
@@ -339,42 +255,22 @@ def process_video(video_path, resolution, despill_val, mask_mode,
339
  raise gr.Error(f"Failed to load models: {e}")
340
 
341
  despill_strength = despill_val / 10.0
342
-
343
- # Determine what outputs we need
344
- need_comp = output_mode == "Composite on checkerboard (MP4)"
345
- need_alpha = output_mode == "Alpha matte (MP4)"
346
- need_rgba = output_mode in ("Transparent video (WebM)", "PNG sequence (ZIP)")
347
-
348
  tmpdir = tempfile.mkdtemp(prefix="ck_")
 
349
  try:
350
- # Pre-compute checkerboard if needed
351
- bg_lin = None
352
- if need_comp:
353
- bg_lin = srgb_to_linear(create_checkerboard(w, h))
354
-
355
- # For PNG-based outputs, create dirs
356
- rgba_dir = None
357
- alpha_dir = None
358
- comp_dir = None
359
- if need_rgba:
360
- rgba_dir = os.path.join(tmpdir, "rgba")
361
- os.makedirs(rgba_dir, exist_ok=True)
362
- if output_mode == "PNG sequence (ZIP)":
363
- alpha_dir = os.path.join(tmpdir, "alphas")
364
- os.makedirs(alpha_dir, exist_ok=True)
365
-
366
- # For MP4 modes, write directly to VideoWriter via temp PNGs + ffmpeg
367
- # (we still need PNGs as ffmpeg input, but only the needed type)
368
- if need_comp:
369
- comp_dir = os.path.join(tmpdir, "comp")
370
- os.makedirs(comp_dir, exist_ok=True)
371
- if need_alpha:
372
- alpha_dir = os.path.join(tmpdir, "alphas")
373
- os.makedirs(alpha_dir, exist_ok=True)
374
-
375
- # Single-pass processing
376
  cap = cv2.VideoCapture(video_path)
377
  frame_times = []
 
378
 
379
  for i in range(frames_to_process):
380
  t0 = time.time()
@@ -385,16 +281,16 @@ def process_video(video_path, resolution, despill_val, mask_mode,
385
  frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
386
  frame_f32 = frame_rgb.astype(np.float32) / 255.0
387
 
388
- # Coarse mask: fast classical or BiRefNet
389
  if mask_mode == "Fast (classical)":
390
- mask, confidence = fast_greenscreen_mask(frame_f32)
391
  if mask is None:
392
- raise gr.Error("Fast mask failed: video doesn't appear to have a green screen background. Try 'AI (BiRefNet)' mode.")
393
  elif mask_mode == "Hybrid (auto)":
394
- mask, confidence = fast_greenscreen_mask(frame_f32)
395
- if mask is None or confidence < 0.7:
396
  mask = birefnet_frame(birefnet, frame_rgb)
397
- else: # "AI (BiRefNet)"
398
  mask = birefnet_frame(birefnet, frame_rgb)
399
 
400
  # CorridorKey inference
@@ -402,92 +298,76 @@ def process_video(video_path, resolution, despill_val, mask_mode,
402
  despill_strength=despill_strength,
403
  auto_despeckle=auto_despeckle,
404
  despeckle_size=int(despeckle_size))
405
-
406
  alpha = result["alpha"]
407
  fg = result["fg"]
408
 
409
- # Write only the output we need
410
- if need_comp:
411
- fg_lin = srgb_to_linear(fg)
412
- comp = linear_to_srgb(composite_straight(fg_lin, bg_lin, alpha))
413
- comp_uint8 = (np.clip(comp, 0, 1) * 255).astype(np.uint8)
414
- cv2.imwrite(os.path.join(comp_dir, f"{i:05d}.png"), comp_uint8[:, :, ::-1])
415
-
416
- if need_alpha or alpha_dir:
417
- alpha_uint8 = (np.clip(alpha, 0, 1) * 255).astype(np.uint8)
418
- if alpha_uint8.ndim == 3:
419
- alpha_uint8 = alpha_uint8[:, :, 0]
420
- if alpha_dir:
421
- cv2.imwrite(os.path.join(alpha_dir, f"{i:05d}.png"), alpha_uint8)
422
-
423
- if need_rgba:
424
- fg_uint8 = (np.clip(fg, 0, 1) * 255).astype(np.uint8)
425
- a_uint8 = (np.clip(alpha, 0, 1) * 255).astype(np.uint8)
426
- if a_uint8.ndim == 3:
427
- a_uint8 = a_uint8[:, :, 0]
428
- rgba = np.concatenate([fg_uint8[:, :, ::-1], a_uint8[:, :, np.newaxis]], axis=-1)
429
- cv2.imwrite(os.path.join(rgba_dir, f"{i:05d}.png"), rgba)
 
 
 
 
 
430
 
431
  # Progress with ETA
432
  elapsed = time.time() - t0
433
  frame_times.append(elapsed)
434
- avg_time = np.mean(frame_times[-5:]) if len(frame_times) >= 2 else elapsed
435
- remaining = (frames_to_process - i - 1) * avg_time
436
  eta = f"{remaining/60:.1f}min" if remaining > 60 else f"{remaining:.0f}s"
437
  pct = 0.05 + 0.85 * (i + 1) / frames_to_process
438
  progress(pct, desc=f"Frame {i+1}/{frames_to_process} ({elapsed:.1f}s) | ~{eta} left")
439
 
440
  cap.release()
441
-
442
- # Assemble output
443
- progress(0.92, desc="Stitching video...")
444
- output_video = None
445
- output_file = None
446
-
447
- if need_comp:
448
- out_path = os.path.join(tmpdir, "composite.mp4")
449
- ok = _stitch_ffmpeg(comp_dir, out_path, fps, extra_args=["-crf", "18"])
450
- if not ok:
451
- ok = _stitch_cv2_fallback(comp_dir, out_path, fps, w, h)
452
- if not ok:
453
- raise gr.Error("Video encoding failed. No suitable codec found.")
454
- output_video = out_path
455
- output_file = out_path
456
-
457
- elif need_alpha:
458
- out_path = os.path.join(tmpdir, "alpha_matte.mp4")
459
- ok = _stitch_ffmpeg(alpha_dir, out_path, fps, extra_args=["-crf", "18"])
460
- if not ok:
461
- ok = _stitch_cv2_fallback(alpha_dir, out_path, fps, w, h, grayscale=True)
462
- if not ok:
463
- raise gr.Error("Video encoding failed. No suitable codec found.")
464
- output_video = out_path
465
- output_file = out_path
466
-
467
- elif output_mode == "Transparent video (WebM)":
468
- out_path = os.path.join(tmpdir, "transparent.webm")
469
- ok = _stitch_ffmpeg(rgba_dir, out_path, fps,
470
- codec="libvpx-vp9", pix_fmt="yuva420p",
471
- extra_args=["-crf", "30", "-b:v", "0"])
472
- if not ok:
473
- raise gr.Error("WebM encoding failed. ffmpeg with libvpx-vp9 required.")
474
- output_video = out_path
475
- output_file = out_path
476
-
477
- elif output_mode == "PNG sequence (ZIP)":
478
- zip_path = os.path.join(tmpdir, "rgba_sequence.zip")
479
- with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as zf:
480
- for f in sorted(os.listdir(rgba_dir)):
481
- zf.write(os.path.join(rgba_dir, f), f"rgba/{f}")
482
- if alpha_dir:
483
- for f in sorted(os.listdir(alpha_dir)):
484
- zf.write(os.path.join(alpha_dir, f), f"alpha/{f}")
485
- output_file = zip_path
486
 
487
  progress(1.0, desc="Done!")
 
488
  avg = np.mean(frame_times) if frame_times else 0
489
- status = f"Processed {len(frame_times)} frames ({w}x{h}) at {img_size}px | {avg:.1f}s/frame avg"
490
- return output_video, output_file, status
 
 
 
 
 
 
491
 
492
  except gr.Error:
493
  raise
@@ -495,8 +375,7 @@ def process_video(video_path, resolution, despill_val, mask_mode,
495
  logger.exception("Processing failed")
496
  raise gr.Error(f"Processing failed: {e}")
497
  finally:
498
- # Cleanup intermediate dirs (keep output files in tmpdir root)
499
- for d in ["comp", "alphas", "rgba"]:
500
  p = os.path.join(tmpdir, d)
501
  if os.path.isdir(p):
502
  shutil.rmtree(p, ignore_errors=True)
@@ -506,10 +385,8 @@ def process_video(video_path, resolution, despill_val, mask_mode,
506
  # ---------------------------------------------------------------------------
507
  # Gradio UI
508
  # ---------------------------------------------------------------------------
509
-
510
- def process_example(video_path, resolution, despill, mask_mode, despeckle, despeckle_size, output_mode):
511
- return process_video(video_path, resolution, despill, mask_mode, despeckle, despeckle_size, output_mode)
512
-
513
 
514
  if HAS_CUDA:
515
  DESCRIPTION = "# CorridorKey Green Screen Matting\nRemove green backgrounds from video. Based on [CorridorKey](https://www.youtube.com/watch?v=3Ploi723hg4) by Corridor Digital. GPU mode: max {max_dur}s / {max_frames} frames.".format(max_dur=MAX_DURATION_GPU, max_frames=MAX_FRAMES)
@@ -522,60 +399,44 @@ with gr.Blocks(title="CorridorKey") as demo:
522
  with gr.Row():
523
  with gr.Column(scale=1):
524
  input_video = gr.Video(label="Upload Green Screen Video")
525
-
526
  with gr.Accordion("Settings", open=True):
527
  resolution = gr.Radio(
528
- choices=["1024", "2048"],
529
- value="1024",
530
  label="Processing Resolution",
531
- info="1024 = balanced (~8s/frame CPU), 2048 = max quality (trained resolution, fast on GPU)"
532
  )
533
  mask_mode = gr.Radio(
534
  choices=["Hybrid (auto)", "AI (BiRefNet)", "Fast (classical)"],
535
- value="Hybrid (auto)",
536
- label="Mask Mode",
537
- info="Hybrid = fast green detection + AI fallback. Fast = classical only (~0.01s). AI = always use BiRefNet (~13s/frame)"
538
  )
539
  despill_slider = gr.Slider(
540
- 0, 10, value=5, step=1,
541
- label="Despill Strength",
542
- info="Remove green reflections from subject (0=off, 10=max)"
543
  )
544
  despeckle_check = gr.Checkbox(
545
- value=True,
546
- label="Auto Despeckle",
547
- info="Remove small disconnected artifacts (tracking markers, noise)"
548
  )
549
  despeckle_size = gr.Number(
550
- value=400, precision=0,
551
- label="Despeckle Size",
552
- info="Minimum pixel area to keep (smaller = more aggressive cleanup)"
553
  )
554
-
555
- output_mode = gr.Dropdown(
556
- choices=[
557
- "Composite on checkerboard (MP4)",
558
- "Alpha matte (MP4)",
559
- "Transparent video (WebM)",
560
- "PNG sequence (ZIP)",
561
- ],
562
- value="Composite on checkerboard (MP4)",
563
- label="Output Format"
564
- )
565
-
566
  process_btn = gr.Button("Process Video", variant="primary", size="lg")
567
 
568
  with gr.Column(scale=1):
569
- output_video = gr.Video(label="Result Preview")
570
- output_file = gr.File(label="Download Result")
 
 
571
  status_text = gr.Textbox(label="Status", interactive=False)
572
 
573
  gr.Examples(
574
  examples=[
575
- ["examples/corridor_greenscreen_demo.mp4", "1024", 5, "Hybrid (auto)", True, 400, "Composite on checkerboard (MP4)"],
576
  ],
577
- inputs=[input_video, resolution, despill_slider, mask_mode, despeckle_check, despeckle_size, output_mode],
578
- outputs=[output_video, output_file, status_text],
579
  fn=process_example,
580
  cache_examples=True,
581
  cache_mode="lazy",
@@ -584,62 +445,46 @@ with gr.Blocks(title="CorridorKey") as demo:
584
 
585
  process_btn.click(
586
  fn=process_video,
587
- inputs=[input_video, resolution, despill_slider, mask_mode, despeckle_check, despeckle_size, output_mode],
588
- outputs=[output_video, output_file, status_text],
589
  )
590
 
591
 
592
  # ---------------------------------------------------------------------------
593
  # CLI mode
594
  # ---------------------------------------------------------------------------
595
-
596
  def cli_main():
597
- """CLI mode: python app.py --input video.mp4 [options]"""
598
  import argparse
599
  parser = argparse.ArgumentParser(description="CorridorKey Green Screen Matting")
600
- parser.add_argument("--input", required=True, help="Input video path")
601
- parser.add_argument("--output", default="output", help="Output directory")
602
- parser.add_argument("--device", default="auto", choices=["auto", "cpu", "cuda"],
603
- help="Force device (auto=detect GPU/CPU)")
604
- parser.add_argument("--resolution", default="1024", choices=["1024", "2048"],
605
- help="Model resolution (1024=fast, 2048=max quality)")
606
  parser.add_argument("--mask-mode", default="Hybrid (auto)",
607
  choices=["Hybrid (auto)", "AI (BiRefNet)", "Fast (classical)"])
608
- parser.add_argument("--despill", type=int, default=5, help="Despill strength 0-10")
609
  parser.add_argument("--no-despeckle", action="store_true")
610
  parser.add_argument("--despeckle-size", type=int, default=400)
611
- parser.add_argument("--format", default="Composite on checkerboard (MP4)",
612
- choices=["Composite on checkerboard (MP4)", "Alpha matte (MP4)",
613
- "Transparent video (WebM)", "PNG sequence (ZIP)"])
614
  args = parser.parse_args()
615
 
616
  global HAS_CUDA
617
- if args.device == "cpu":
618
- HAS_CUDA = False
619
- elif args.device == "cuda":
620
- HAS_CUDA = True
621
  print(f"Device: {'CUDA' if HAS_CUDA else 'CPU'}")
622
 
623
  class CLIProgress:
624
  def __call__(self, val, desc=""):
625
- if desc:
626
- print(f" [{val:.0%}] {desc}")
627
 
628
- video, file, status = process_video(
629
  args.input, args.resolution, args.despill, args.mask_mode,
630
- not args.no_despeckle, args.despeckle_size, args.format,
631
- progress=CLIProgress()
632
  )
633
  print(f"\n{status}")
634
- if video:
635
- os.makedirs(args.output, exist_ok=True)
636
- dst = os.path.join(args.output, os.path.basename(video))
637
- shutil.copy2(video, dst)
638
- print(f"Output: {dst}")
639
- if file:
640
- os.makedirs(args.output, exist_ok=True)
641
- dst = os.path.join(args.output, os.path.basename(file))
642
- shutil.copy2(file, dst)
643
  print(f"Output: {dst}")
644
 
645
 
 
29
  import onnxruntime as ort
30
 
31
# Workaround for a Gradio cache_examples bug with None outputs:
# CSVLogger.flag() writes "" for None, and the stock read_from_flag("")
# calls json.loads("") which raises. Patch it to map empty payloads to None.
_original_read_from_flag = gr.components.Component.read_from_flag

def _patched_read_from_flag(self, payload):
    """Return None for None/blank cached payloads; defer everything else."""
    is_blank_str = isinstance(payload, str) and payload.strip() == ""
    if payload is None or is_blank_str:
        return None
    return _original_read_from_flag(self, payload)

gr.components.Component.read_from_flag = _patched_read_from_flag
38
+
39
  from huggingface_hub import hf_hub_download
40
 
41
  cv2.setNumThreads(2)
 
42
  logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
43
  logger = logging.getLogger(__name__)
44
 
 
47
  # ---------------------------------------------------------------------------
48
  BIREFNET_REPO = "onnx-community/BiRefNet_lite-ONNX"
49
  BIREFNET_FILE = "onnx/model.onnx"
 
50
  MODELS_DIR = os.path.join(os.path.dirname(__file__), "models")
51
  CORRIDORKEY_MODELS = {
52
  "1024": os.path.join(MODELS_DIR, "corridorkey_1024.onnx"),
53
  "2048": os.path.join(MODELS_DIR, "corridorkey_2048.onnx"),
54
  }
 
55
  IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 1, 3)
56
  IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 1, 3)
 
57
  MAX_DURATION_CPU = 5
58
  MAX_DURATION_GPU = 30
59
  MAX_FRAMES = 150
 
 
60
  HAS_CUDA = "CUDAExecutionProvider" in ort.get_available_providers()
61
 
62
  # ---------------------------------------------------------------------------
63
+ # Color utilities (numpy-only)
64
  # ---------------------------------------------------------------------------
 
65
  def linear_to_srgb(x):
66
  x = np.clip(x, 0.0, None)
67
  return np.where(x <= 0.0031308, x * 12.92, 1.055 * np.power(x, 1.0 / 2.4) - 0.055)
68
 
 
69
def srgb_to_linear(x):
    """Convert sRGB-encoded values to linear light (inverse sRGB transfer).

    Negative inputs are clipped to 0 before decoding.
    """
    clipped = np.clip(x, 0.0, None)
    linear_part = clipped / 12.92
    gamma_part = np.power((clipped + 0.055) / 1.055, 2.4)
    return np.where(clipped <= 0.04045, linear_part, gamma_part)
72
 
 
73
def composite_straight(fg, bg, alpha):
    """Straight (un-premultiplied) alpha composite: lerp from bg to fg by alpha."""
    inverse = 1.0 - alpha
    return fg * alpha + bg * inverse
75
 
 
76
def despill(image, green_limit_mode="average", strength=1.0):
    """Suppress green spill: clamp G to a limit derived from R and B.

    The removed green energy is redistributed half to R and half to B to keep
    apparent luminance roughly constant. ``strength`` in [0, 1] blends between
    the untouched image (0) and the fully despilled one (1); <= 0 is a no-op.
    """
    if strength <= 0.0:
        return image
    red = image[..., 0]
    green = image[..., 1]
    blue = image[..., 2]
    if green_limit_mode == "average":
        limit = (red + blue) / 2.0
    else:
        limit = np.maximum(red, blue)
    excess = np.maximum(green - limit, 0.0)
    corrected = np.stack(
        [red + excess * 0.5, green - excess, blue + excess * 0.5],
        axis=-1,
    )
    if strength < 1.0:
        # Partial strength: linear blend toward the corrected image.
        return image * (1.0 - strength) + corrected * strength
    return corrected
 
 
 
 
 
 
84
 
85
  def clean_matte(alpha_np, area_threshold=300, dilation=15, blur_size=5):
86
  is_3d = alpha_np.ndim == 3
 
88
  alpha_np = alpha_np[:, :, 0]
89
  mask_8u = (alpha_np > 0.5).astype(np.uint8) * 255
90
  num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask_8u, connectivity=8)
 
91
  valid = np.zeros(num_labels, dtype=bool)
92
  valid[1:] = stats[1:, cv2.CC_STAT_AREA] >= area_threshold
93
  cleaned = (valid[labels].astype(np.uint8) * 255)
94
  if dilation > 0:
95
  k = int(dilation * 2 + 1)
96
+ cleaned = cv2.dilate(cleaned, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k)))
 
97
  if blur_size > 0:
98
  b = int(blur_size * 2 + 1)
99
  cleaned = cv2.GaussianBlur(cleaned, (b, b), 0)
100
+ result = alpha_np * (cleaned.astype(np.float32) / 255.0)
 
101
  return result[:, :, np.newaxis] if is_3d else result
102
 
 
103
  def create_checkerboard(w, h, checker_size=64, color1=0.15, color2=0.55):
104
+ xg, yg = np.meshgrid(np.arange(w) // checker_size, np.arange(h) // checker_size)
105
+ bg = np.where(((xg + yg) % 2) == 0, color1, color2).astype(np.float32)
 
 
 
106
  return np.stack([bg, bg, bg], axis=-1)
107
 
108
def premultiply(fg, alpha):
    """Return the foreground multiplied by its alpha (premultiplied-alpha form)."""
    premultiplied = fg * alpha
    return premultiplied
110
 
111
  # ---------------------------------------------------------------------------
112
+ # Fast classical green-screen mask
113
  # ---------------------------------------------------------------------------
 
114
  def fast_greenscreen_mask(frame_rgb_f32):
 
 
 
115
  h, w = frame_rgb_f32.shape[:2]
116
  ph, pw = max(int(h * 0.05), 4), max(int(w * 0.05), 4)
117
  corners = np.concatenate([
 
121
  frame_rgb_f32[-ph:, -pw:].reshape(-1, 3),
122
  ], axis=0)
123
  bg_color = np.median(corners, axis=0)
 
 
124
  if not (bg_color[1] > bg_color[0] + 0.05 and bg_color[1] > bg_color[2] + 0.05):
125
  return None, 0.0
 
 
126
  frame_u8 = (np.clip(frame_rgb_f32, 0, 1) * 255).astype(np.uint8)
127
  hsv = cv2.cvtColor(frame_u8, cv2.COLOR_RGB2HSV)
 
128
  green_mask = cv2.inRange(hsv, (35, 40, 40), (85, 255, 255))
 
129
  fg_mask = cv2.bitwise_not(green_mask)
130
+ fg_mask = cv2.morphologyEx(fg_mask, cv2.MORPH_CLOSE, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)))
 
 
131
  fg_mask = cv2.GaussianBlur(fg_mask, (5, 5), 0)
132
  mask_f32 = fg_mask.astype(np.float32) / 255.0
 
 
133
  confidence = 1.0 - 2.0 * np.mean(np.minimum(mask_f32, 1.0 - mask_f32))
 
134
  return mask_f32, confidence
135
 
 
136
  # ---------------------------------------------------------------------------
137
+ # Model loading
138
  # ---------------------------------------------------------------------------
139
  _birefnet_session = None
140
  _corridorkey_sessions = {}
141
 
142
+ def _ort_opts():
 
143
  opts = ort.SessionOptions()
144
  opts.intra_op_num_threads = 2
145
  opts.inter_op_num_threads = 1
 
148
  opts.enable_mem_pattern = True
149
  return opts
150
 
 
151
def get_birefnet():
    """Lazily download, load, and cache the BiRefNet-Lite ONNX session.

    CPU execution provider only; the cached session is a module-level
    singleton so repeated calls are cheap.
    """
    global _birefnet_session
    if _birefnet_session is not None:
        return _birefnet_session
    logger.info("Downloading BiRefNet-Lite ONNX...")
    model_path = hf_hub_download(repo_id=BIREFNET_REPO, filename=BIREFNET_FILE)
    logger.info("Loading BiRefNet ONNX: %s", model_path)
    _birefnet_session = ort.InferenceSession(
        model_path, _ort_opts(), providers=["CPUExecutionProvider"]
    )
    return _birefnet_session
159
 
 
160
  def get_corridorkey(resolution="1024"):
161
  global _corridorkey_sessions
162
  if resolution not in _corridorkey_sessions:
 
164
  if not onnx_path or not os.path.exists(onnx_path):
165
  raise gr.Error(f"CorridorKey ONNX model for {resolution} not found.")
166
  logger.info("Loading CorridorKey ONNX (%s): %s", resolution, onnx_path)
167
+ _corridorkey_sessions[resolution] = ort.InferenceSession(onnx_path, _ort_opts(), providers=["CPUExecutionProvider"])
168
  return _corridorkey_sessions[resolution]
169
 
 
170
  # ---------------------------------------------------------------------------
171
  # Per-frame inference
172
  # ---------------------------------------------------------------------------
 
173
def birefnet_frame(session, image_rgb_uint8):
    """Run BiRefNet on one RGB frame and return a binary foreground mask.

    Args:
        session: ONNX Runtime session; input 0 is an NCHW float tensor.
        image_rgb_uint8: HxWx3 uint8 RGB frame.

    Returns:
        HxW float32 mask with values in {0.0, 1.0}, at the original
        frame resolution.
    """
    h, w = image_rgb_uint8.shape[:2]
    inp = session.get_inputs()[0]
    # BUGFIX: the ONNX input is NCHW, so shape[2] is height and shape[3]
    # is width, while cv2.resize expects dsize=(width, height).  The old
    # code passed (H, W), which only worked for square models.
    net_w, net_h = inp.shape[3], inp.shape[2]
    img = cv2.resize(image_rgb_uint8, (net_w, net_h)).astype(np.float32) / 255.0
    img = ((img - IMAGENET_MEAN) / IMAGENET_STD).transpose(2, 0, 1)[np.newaxis, :].astype(np.float32)
    # Last output is the final prediction; apply sigmoid to get [0, 1].
    logits = session.run(None, {inp.name: img})[-1]
    pred = 1.0 / (1.0 + np.exp(-logits))
    # Very low threshold keeps faint foreground; CorridorKey refines it later.
    return (cv2.resize(pred[0, 0], (w, h)) > 0.04).astype(np.float32)
 
 
 
 
181
 
182
def corridorkey_frame(session, image_f32, mask_f32, img_size,
                      despill_strength=0.5, auto_despeckle=True, despeckle_size=400):
    """Run one frame through the CorridorKey ONNX matting model.

    Args:
        session: ONNX Runtime session whose "input" is a 4-channel
            (normalized RGB + coarse mask) NCHW float32 tensor.
        image_f32: HxWx3 float32 RGB frame in [0, 1].
        mask_f32: HxW float32 coarse foreground mask in [0, 1].
        img_size: square side length the model expects.
        despill_strength: strength forwarded to despill().
        auto_despeckle: when True, run clean_matte() on the alpha.
        despeckle_size: minimum component area kept by clean_matte().

    Returns:
        dict with "alpha" (HxWx1 float32) and "fg" (HxWx3 float32),
        both resized back to the original frame size.
    """
    orig_h, orig_w = image_f32.shape[:2]

    # Build the 4-channel network input: ImageNet-normalized RGB + mask.
    rgb_small = cv2.resize(image_f32, (img_size, img_size))
    mask_small = cv2.resize(mask_f32, (img_size, img_size))[:, :, np.newaxis]
    net_in = np.concatenate([(rgb_small - IMAGENET_MEAN) / IMAGENET_STD, mask_small], axis=-1)
    net_in = net_in.transpose(2, 0, 1)[np.newaxis, :].astype(np.float32)

    alpha_raw, fg_raw = session.run(None, {"input": net_in})

    # CHW -> HWC, then upsample to the frame size; Lanczos keeps edges crisp.
    alpha = cv2.resize(alpha_raw[0].transpose(1, 2, 0), (orig_w, orig_h),
                       interpolation=cv2.INTER_LANCZOS4)
    fg = cv2.resize(fg_raw[0].transpose(1, 2, 0), (orig_w, orig_h),
                    interpolation=cv2.INTER_LANCZOS4)
    # Single-channel resizes drop the channel axis; restore it.
    if alpha.ndim == 2:
        alpha = alpha[:, :, np.newaxis]

    if auto_despeckle:
        alpha = clean_matte(alpha, area_threshold=despeckle_size, dilation=25, blur_size=5)
    fg = despill(fg, green_limit_mode="average", strength=despill_strength)

    return {"alpha": alpha, "fg": fg}
198
 
 
199
  # ---------------------------------------------------------------------------
200
+ # Video stitching
201
  # ---------------------------------------------------------------------------
 
202
  def _stitch_ffmpeg(frame_dir, out_path, fps, pattern="%05d.png", pix_fmt="yuv420p",
203
  codec="libx264", extra_args=None):
204
+ cmd = ["ffmpeg", "-y", "-framerate", str(fps), "-i", os.path.join(frame_dir, pattern),
 
 
205
  "-c:v", codec, "-pix_fmt", pix_fmt]
206
  if extra_args:
207
  cmd.extend(extra_args)
 
213
  logger.warning("ffmpeg failed: %s", e)
214
  return False
215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  # ---------------------------------------------------------------------------
217
+ # Main pipeline: generates ALL professional outputs
218
  # ---------------------------------------------------------------------------
 
219
  def process_video(video_path, resolution, despill_val, mask_mode,
220
+ auto_despeckle, despeckle_size, progress=gr.Progress()):
221
  """Remove green screen background from video using CorridorKey AI matting.
222
+ Returns: comp_video, matte_video, download_zip, status
 
223
  """
224
  if video_path is None:
225
  raise gr.Error("Please upload a video.")
 
227
  max_dur = MAX_DURATION_GPU if HAS_CUDA else MAX_DURATION_CPU
228
  img_size = int(resolution)
229
 
 
230
  cap = cv2.VideoCapture(video_path)
231
  fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
232
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
 
236
 
237
  if total_frames == 0:
238
  raise gr.Error("Could not read video frames. Check file format.")
 
239
  duration = total_frames / fps
240
  if duration > max_dur:
241
  raise gr.Error(f"Video too long ({duration:.1f}s). Max {max_dur}s on {'GPU' if HAS_CUDA else 'free CPU'} tier.")
 
244
  logger.info("Processing %d frames (%dx%d @ %.1f fps), resolution=%d, mask=%s",
245
  frames_to_process, w, h, fps, img_size, mask_mode)
246
 
 
247
  try:
248
  birefnet = None
249
  if mask_mode != "Fast (classical)":
 
255
  raise gr.Error(f"Failed to load models: {e}")
256
 
257
  despill_strength = despill_val / 10.0
 
 
 
 
 
 
258
  tmpdir = tempfile.mkdtemp(prefix="ck_")
259
+
260
  try:
261
+ # Output dirs matching original CorridorKey structure
262
+ comp_dir = os.path.join(tmpdir, "Comp")
263
+ fg_dir = os.path.join(tmpdir, "FG")
264
+ matte_dir = os.path.join(tmpdir, "Matte")
265
+ processed_dir = os.path.join(tmpdir, "Processed")
266
+ for d in [comp_dir, fg_dir, matte_dir, processed_dir]:
267
+ os.makedirs(d, exist_ok=True)
268
+
269
+ bg_lin = srgb_to_linear(create_checkerboard(w, h))
270
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
  cap = cv2.VideoCapture(video_path)
272
  frame_times = []
273
+ total_start = time.time()
274
 
275
  for i in range(frames_to_process):
276
  t0 = time.time()
 
281
  frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
282
  frame_f32 = frame_rgb.astype(np.float32) / 255.0
283
 
284
+ # Coarse mask
285
  if mask_mode == "Fast (classical)":
286
+ mask, _ = fast_greenscreen_mask(frame_f32)
287
  if mask is None:
288
+ raise gr.Error("Fast mask failed: no green screen detected. Try 'AI (BiRefNet)' mode.")
289
  elif mask_mode == "Hybrid (auto)":
290
+ mask, conf = fast_greenscreen_mask(frame_f32)
291
+ if mask is None or conf < 0.7:
292
  mask = birefnet_frame(birefnet, frame_rgb)
293
+ else:
294
  mask = birefnet_frame(birefnet, frame_rgb)
295
 
296
  # CorridorKey inference
 
298
  despill_strength=despill_strength,
299
  auto_despeckle=auto_despeckle,
300
  despeckle_size=int(despeckle_size))
 
301
  alpha = result["alpha"]
302
  fg = result["fg"]
303
 
304
+ # Ensure alpha is [H,W,1] and get 2D version
305
+ if alpha.ndim == 2:
306
+ alpha = alpha[:, :, np.newaxis]
307
+ alpha_2d = alpha[:, :, 0]
308
+
309
+ # -- Comp: composite on checkerboard (sRGB PNG) --
310
+ fg_lin = srgb_to_linear(fg)
311
+ comp = linear_to_srgb(composite_straight(fg_lin, bg_lin, alpha))
312
+ cv2.imwrite(os.path.join(comp_dir, f"{i:05d}.png"),
313
+ (np.clip(comp, 0, 1) * 255).astype(np.uint8)[:, :, ::-1])
314
+
315
+ # -- FG: straight foreground, 100% opaque (sRGB PNG) --
316
+ cv2.imwrite(os.path.join(fg_dir, f"{i:05d}.png"),
317
+ (np.clip(fg, 0, 1) * 255).astype(np.uint8)[:, :, ::-1])
318
+
319
+ # -- Matte: alpha channel (grayscale PNG) --
320
+ cv2.imwrite(os.path.join(matte_dir, f"{i:05d}.png"),
321
+ (np.clip(alpha_2d, 0, 1) * 255).astype(np.uint8))
322
+
323
+ # -- Processed: premultiplied RGBA (PNG with transparency) --
324
+ fg_premul_lin = premultiply(fg_lin, alpha)
325
+ fg_premul_srgb = linear_to_srgb(fg_premul_lin)
326
+ fg_premul_u8 = (np.clip(fg_premul_srgb, 0, 1) * 255).astype(np.uint8)
327
+ alpha_u8 = (np.clip(alpha_2d, 0, 1) * 255).astype(np.uint8)
328
+ rgba = np.concatenate([fg_premul_u8[:, :, ::-1], alpha_u8[:, :, np.newaxis]], axis=-1)
329
+ cv2.imwrite(os.path.join(processed_dir, f"{i:05d}.png"), rgba)
330
 
331
  # Progress with ETA
332
  elapsed = time.time() - t0
333
  frame_times.append(elapsed)
334
+ avg_t = np.mean(frame_times[-5:]) if len(frame_times) >= 2 else elapsed
335
+ remaining = (frames_to_process - i - 1) * avg_t
336
  eta = f"{remaining/60:.1f}min" if remaining > 60 else f"{remaining:.0f}s"
337
  pct = 0.05 + 0.85 * (i + 1) / frames_to_process
338
  progress(pct, desc=f"Frame {i+1}/{frames_to_process} ({elapsed:.1f}s) | ~{eta} left")
339
 
340
  cap.release()
341
+ total_elapsed = time.time() - total_start
342
+ total_min = total_elapsed / 60
343
+
344
+ # Stitch preview videos
345
+ progress(0.92, desc="Stitching videos...")
346
+ comp_video = os.path.join(tmpdir, "comp_preview.mp4")
347
+ matte_video = os.path.join(tmpdir, "matte_preview.mp4")
348
+ _stitch_ffmpeg(comp_dir, comp_video, fps, extra_args=["-crf", "18"])
349
+ _stitch_ffmpeg(matte_dir, matte_video, fps, extra_args=["-crf", "18"])
350
+
351
+ # Package full professional ZIP
352
+ progress(0.96, desc="Packaging ZIP...")
353
+ zip_path = os.path.join(tmpdir, "CorridorKey_Output.zip")
354
+ with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as zf:
355
+ for folder in ["Comp", "FG", "Matte", "Processed"]:
356
+ src = os.path.join(tmpdir, folder)
357
+ for f in sorted(os.listdir(src)):
358
+ zf.write(os.path.join(src, f), f"Output/{folder}/{f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
359
 
360
  progress(1.0, desc="Done!")
361
+ n = len(frame_times)
362
  avg = np.mean(frame_times) if frame_times else 0
363
+ status = f"Processed {n} frames in {total_min:.1f}min ({w}x{h}) at {img_size}px | {avg:.1f}s/frame"
364
+
365
+ return (
366
+ comp_video if os.path.exists(comp_video) else None,
367
+ matte_video if os.path.exists(matte_video) else None,
368
+ zip_path,
369
+ status,
370
+ )
371
 
372
  except gr.Error:
373
  raise
 
375
  logger.exception("Processing failed")
376
  raise gr.Error(f"Processing failed: {e}")
377
  finally:
378
+ for d in ["Comp", "FG", "Matte", "Processed"]:
 
379
  p = os.path.join(tmpdir, d)
380
  if os.path.isdir(p):
381
  shutil.rmtree(p, ignore_errors=True)
 
385
  # ---------------------------------------------------------------------------
386
  # Gradio UI
387
  # ---------------------------------------------------------------------------
388
def process_example(video_path, resolution, despill, mask_mode, despeckle, despeckle_size):
    """Example-cache entry point: forward all settings to process_video."""
    settings = (video_path, resolution, despill, mask_mode, despeckle, despeckle_size)
    return process_video(*settings)
 
 
390
 
391
  if HAS_CUDA:
392
  DESCRIPTION = "# CorridorKey Green Screen Matting\nRemove green backgrounds from video. Based on [CorridorKey](https://www.youtube.com/watch?v=3Ploi723hg4) by Corridor Digital. GPU mode: max {max_dur}s / {max_frames} frames.".format(max_dur=MAX_DURATION_GPU, max_frames=MAX_FRAMES)
 
399
  with gr.Row():
400
  with gr.Column(scale=1):
401
  input_video = gr.Video(label="Upload Green Screen Video")
 
402
  with gr.Accordion("Settings", open=True):
403
  resolution = gr.Radio(
404
+ choices=["1024", "2048"], value="1024",
 
405
  label="Processing Resolution",
406
+ info="1024 = balanced (~8s/frame CPU), 2048 = max quality (fast on GPU)"
407
  )
408
  mask_mode = gr.Radio(
409
  choices=["Hybrid (auto)", "AI (BiRefNet)", "Fast (classical)"],
410
+ value="Hybrid (auto)", label="Mask Mode",
411
+ info="Hybrid = fast green detection + AI fallback. Fast = classical only. AI = always BiRefNet"
 
412
  )
413
  despill_slider = gr.Slider(
414
+ 0, 10, value=5, step=1, label="Despill Strength",
415
+ info="Remove green reflections (0=off, 10=max)"
 
416
  )
417
  despeckle_check = gr.Checkbox(
418
+ value=True, label="Auto Despeckle",
419
+ info="Remove small disconnected artifacts"
 
420
  )
421
  despeckle_size = gr.Number(
422
+ value=400, precision=0, label="Despeckle Size",
423
+ info="Min pixel area to keep"
 
424
  )
 
 
 
 
 
 
 
 
 
 
 
 
425
  process_btn = gr.Button("Process Video", variant="primary", size="lg")
426
 
427
  with gr.Column(scale=1):
428
+ with gr.Row():
429
+ comp_video = gr.Video(label="Composite Preview")
430
+ matte_video = gr.Video(label="Alpha Matte")
431
+ download_zip = gr.File(label="Download Full Package (Comp + FG + Matte + Processed)")
432
  status_text = gr.Textbox(label="Status", interactive=False)
433
 
434
  gr.Examples(
435
  examples=[
436
+ ["examples/corridor_greenscreen_demo.mp4", "1024", 5, "Hybrid (auto)", True, 400],
437
  ],
438
+ inputs=[input_video, resolution, despill_slider, mask_mode, despeckle_check, despeckle_size],
439
+ outputs=[comp_video, matte_video, download_zip, status_text],
440
  fn=process_example,
441
  cache_examples=True,
442
  cache_mode="lazy",
 
445
 
446
  process_btn.click(
447
  fn=process_video,
448
+ inputs=[input_video, resolution, despill_slider, mask_mode, despeckle_check, despeckle_size],
449
+ outputs=[comp_video, matte_video, download_zip, status_text],
450
  )
451
 
452
 
453
  # ---------------------------------------------------------------------------
454
  # CLI mode
455
  # ---------------------------------------------------------------------------
 
456
def cli_main():
    """Command-line entry point: process one video and copy the result ZIP.

    Parses CLI flags, optionally overrides the auto-detected device,
    runs process_video with a console progress callback, and copies the
    produced ZIP package into the --output directory.
    """
    import argparse
    parser = argparse.ArgumentParser(description="CorridorKey Green Screen Matting")
    parser.add_argument("--input", required=True)
    parser.add_argument("--output", default="output")
    parser.add_argument("--device", default="auto", choices=["auto", "cpu", "cuda"])
    parser.add_argument("--resolution", default="1024", choices=["1024", "2048"])
    parser.add_argument("--mask-mode", default="Hybrid (auto)",
                        choices=["Hybrid (auto)", "AI (BiRefNet)", "Fast (classical)"])
    parser.add_argument("--despill", type=int, default=5)
    parser.add_argument("--no-despeckle", action="store_true")
    parser.add_argument("--despeckle-size", type=int, default=400)
    args = parser.parse_args()

    # Allow forcing the device; "auto" keeps the ONNX Runtime detection.
    global HAS_CUDA
    if args.device == "cpu":
        HAS_CUDA = False
    elif args.device == "cuda":
        HAS_CUDA = True
    print(f"Device: {'CUDA' if HAS_CUDA else 'CPU'}")

    class CLIProgress:
        # Minimal gr.Progress stand-in: prints percentage + description.
        def __call__(self, val, desc=""):
            if desc:
                print(f" [{val:.0%}] {desc}")

    comp, matte, zipf, status = process_video(
        args.input, args.resolution, args.despill, args.mask_mode,
        not args.no_despeckle, args.despeckle_size, progress=CLIProgress()
    )
    print(f"\n{status}")
    os.makedirs(args.output, exist_ok=True)
    if zipf:
        dst = os.path.join(args.output, os.path.basename(zipf))
        shutil.copy2(zipf, dst)
        # BUGFIX: only report the path when a ZIP was actually produced;
        # previously `dst` could be referenced while unbound.
        print(f"Output: {dst}")
    else:
        print("No output package was produced.")
489
 
490