Nekochu committed
Commit 20c856c · 1 Parent(s): 2b1c753

initial commit
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.mp4 filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,85 @@
 ---
-title: CorridorKeyCPU
-emoji: 🌍
-colorFrom: green
-colorTo: purple
+title: CorridorKey
+emoji: 🎬
+colorFrom: yellow
+colorTo: yellow
 sdk: gradio
 sdk_version: 6.9.0
 app_file: app.py
+python_version: "3.10"
 pinned: false
+tags:
+- green-screen
+- background-removal
+- video-matting
+- alpha-matting
+- vfx
+- corridor-digital
+- transparency
+- onnx
+- mcp-server
+short_description: Remove green background from video, even transparent objects
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# CorridorKey Green Screen Matting (CPU)
+
+Remove green screen backgrounds from video on free CPU. Handles transparent objects (glass, water, cloth) that traditional chroma key cannot.
+
+Based on [CorridorKey](https://github.com/nikopueringer/CorridorKey) by Corridor Digital.
+
+## Pipeline
+
+1. **BiRefNet** - Generates coarse foreground mask
+2. **CorridorKey GreenFormer** - Refines alpha matte + extracts clean foreground
+3. **Compositing** - Despill, despeckle, composite on new background
+
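The despill step of the compositing stage can be sketched as follows. This mirrors the `despill` helper in `app.py` (green limited to the average of red and blue), simplified to its default mode; the pixel value is an illustrative example, not Space output:

```python
import numpy as np

def despill(image, strength=1.0):
    """Suppress green spill: any green above the average of red and blue
    is treated as reflected screen light, removed from G, and
    redistributed half-and-half into R and B."""
    r, g, b = image[..., 0], image[..., 1], image[..., 2]
    limit = (r + b) / 2.0
    spill = np.maximum(g - limit, 0.0)
    out = np.stack([r + spill * 0.5, g - spill, b + spill * 0.5], axis=-1)
    # Blend between original and despilled image by strength
    return image * (1.0 - strength) + out * strength

# A green-tinted pixel: G well above the R/B average
px = np.array([[[0.40, 0.80, 0.40]]], dtype=np.float32)
print(despill(px))  # -> [[[0.6 0.4 0.6]]]: G clamped to 0.4, spill split into R and B
```

Note that total brightness is roughly preserved: the removed green energy is redistributed rather than discarded, which avoids darkening subject edges.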
+## API
+
+### REST API
+
+**Step 1: Submit request**
+```bash
+curl -X POST "https://luminia-corridorkey.hf.space/gradio_api/call/process_video" \
+  -H "Content-Type: application/json" \
+  -d '{"data": ["video.mp4", "1024", 5, "Hybrid (auto)", true, 400, "Composite on checkerboard (MP4)"]}'
+```
+
+**Step 2: Get result**
+```bash
+curl "https://luminia-corridorkey.hf.space/gradio_api/call/process_video/{event_id}"
+```
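The positional `data` array maps onto `process_video`'s parameters in declaration order (taken from the function signature in `app.py`); a small sanity check of how such a payload is built:

```python
import json

# process_video's positional arguments, in the order declared in app.py:
# video_path, resolution, despill_val, mask_mode, auto_despeckle,
# despeckle_size, output_mode
params = ["video.mp4", "1024", 5, "Hybrid (auto)", True, 400,
          "Composite on checkerboard (MP4)"]
payload = json.dumps({"data": params})
print(payload)
```

Python's `True` serializes to JSON `true`, matching the curl example above.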
+
+### MCP (Model Context Protocol)
+
+**Tool schema:**
+```json
+{
+  "name": "process_video",
+  "description": "Remove green screen background from video using CorridorKey AI matting.",
+  "parameters": {
+    "video_path": "Path to green screen video",
+    "resolution": "Processing resolution: 1024 (fast) | 2048 (max quality)",
+    "despill_val": "Despill strength 0-10 (default 5)",
+    "mask_mode": "Hybrid (auto) | AI (BiRefNet) | Fast (classical)",
+    "auto_despeckle": "Remove small artifacts (default true)",
+    "despeckle_size": "Min pixel area to keep (default 400)",
+    "output_mode": "Composite on checkerboard (MP4) | Alpha matte (MP4) | Transparent video (WebM) | PNG sequence (ZIP)"
+  }
+}
+```
+
+**MCP Config:**
+```json
+{
+  "mcpServers": {
+    "corridorkey-cpu": {
+      "url": "https://luminia-corridorkey.hf.space/gradio_api/mcp/"
+    }
+  }
+}
+```
+
+## Credits
+
+- [CorridorKey](https://github.com/nikopueringer/CorridorKey) by Niko Pueringer / Corridor Digital
+- [EZ-CorridorKey](https://github.com/edenaion/EZ-CorridorKey) UI reference by edenaion
+- [BiRefNet](https://github.com/ZhengPeng7/BiRefNet) by ZhengPeng7
app.py ADDED
@@ -0,0 +1,639 @@
+"""CorridorKey Green Screen Matting - HuggingFace Space.
+
+Self-contained Gradio app using ONNX Runtime for inference.
+Supports CPU (free tier) and GPU (community grant).
+
+Usage:
+    python app.py                    # Launch Gradio UI
+    python app.py --input video.mp4  # CLI mode
+"""
+
+import os
+import sys
+import shutil
+import gc
+import time
+import tempfile
+import zipfile
+import subprocess
+import logging
+
+# Thread tuning for CPU (must be set before numpy/cv2/ort import)
+os.environ["OMP_NUM_THREADS"] = "2"
+os.environ["OPENBLAS_NUM_THREADS"] = "2"
+os.environ["MKL_NUM_THREADS"] = "2"
+
+import numpy as np
+import cv2
+import gradio as gr
+import onnxruntime as ort
+from huggingface_hub import hf_hub_download
+
+cv2.setNumThreads(2)
+
+logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
+logger = logging.getLogger(__name__)
+
+# ---------------------------------------------------------------------------
+# Constants
+# ---------------------------------------------------------------------------
+BIREFNET_REPO = "onnx-community/BiRefNet_lite-ONNX"
+BIREFNET_FILE = "onnx/model.onnx"
+
+MODELS_DIR = os.path.join(os.path.dirname(__file__), "models")
+CORRIDORKEY_MODELS = {
+    "1024": os.path.join(MODELS_DIR, "corridorkey_1024.onnx"),
+    "2048": os.path.join(MODELS_DIR, "corridorkey_2048.onnx"),
+}
+
+IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 1, 3)
+IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 1, 3)
+
+MAX_DURATION_CPU = 5
+MAX_DURATION_GPU = 30
+MAX_FRAMES = 150
+
+# GPU auto-detect via ONNX Runtime (no torch dependency)
+HAS_CUDA = "CUDAExecutionProvider" in ort.get_available_providers()
+
+# ---------------------------------------------------------------------------
+# Color utilities (numpy-only, from CorridorKeyModule/core/color_utils.py)
+# ---------------------------------------------------------------------------
+
+def linear_to_srgb(x):
+    x = np.clip(x, 0.0, None)
+    return np.where(x <= 0.0031308, x * 12.92, 1.055 * np.power(x, 1.0 / 2.4) - 0.055)
+
+
+def srgb_to_linear(x):
+    x = np.clip(x, 0.0, None)
+    return np.where(x <= 0.04045, x / 12.92, np.power((x + 0.055) / 1.055, 2.4))
+
+
+def composite_straight(fg, bg, alpha):
+    return fg * alpha + bg * (1.0 - alpha)
+
+
+def despill(image, green_limit_mode="average", strength=1.0):
+    if strength <= 0.0:
+        return image
+    r, g, b = image[..., 0], image[..., 1], image[..., 2]
+    limit = (r + b) / 2.0 if green_limit_mode == "average" else np.maximum(r, b)
+    spill_amount = np.maximum(g - limit, 0.0)
+    g_new = g - spill_amount
+    r_new = r + spill_amount * 0.5
+    b_new = b + spill_amount * 0.5
+    despilled = np.stack([r_new, g_new, b_new], axis=-1)
+    if strength < 1.0:
+        return image * (1.0 - strength) + despilled * strength
+    return despilled
+
+
+def clean_matte(alpha_np, area_threshold=300, dilation=15, blur_size=5):
+    is_3d = alpha_np.ndim == 3
+    if is_3d:
+        alpha_np = alpha_np[:, :, 0]
+    mask_8u = (alpha_np > 0.5).astype(np.uint8) * 255
+    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask_8u, connectivity=8)
+    # Vectorized: find valid labels in one pass
+    valid = np.zeros(num_labels, dtype=bool)
+    valid[1:] = stats[1:, cv2.CC_STAT_AREA] >= area_threshold
+    cleaned = (valid[labels].astype(np.uint8) * 255)
+    if dilation > 0:
+        k = int(dilation * 2 + 1)
+        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
+        cleaned = cv2.dilate(cleaned, kernel)
+    if blur_size > 0:
+        b = int(blur_size * 2 + 1)
+        cleaned = cv2.GaussianBlur(cleaned, (b, b), 0)
+    safe_zone = cleaned.astype(np.float32) / 255.0
+    result = alpha_np * safe_zone
+    return result[:, :, np.newaxis] if is_3d else result
+
+
+def create_checkerboard(w, h, checker_size=64, color1=0.15, color2=0.55):
+    x_tiles = np.arange(w) // checker_size
+    y_tiles = np.arange(h) // checker_size
+    xg, yg = np.meshgrid(x_tiles, y_tiles)
+    checker = ((xg + yg) % 2).astype(np.float32)
+    bg = np.where(checker == 0, color1, color2).astype(np.float32)
+    return np.stack([bg, bg, bg], axis=-1)
+
+
+# ---------------------------------------------------------------------------
+# Fast classical green-screen mask (alternative to BiRefNet)
+# ---------------------------------------------------------------------------
+
+def fast_greenscreen_mask(frame_rgb_f32):
+    """Fast green-screen detection using corner sampling + HSV threshold.
+    Returns (mask_f32, confidence) or (None, 0.0) if not a green screen.
+    """
+    h, w = frame_rgb_f32.shape[:2]
+    ph, pw = max(int(h * 0.05), 4), max(int(w * 0.05), 4)
+    corners = np.concatenate([
+        frame_rgb_f32[:ph, :pw].reshape(-1, 3),
+        frame_rgb_f32[:ph, -pw:].reshape(-1, 3),
+        frame_rgb_f32[-ph:, :pw].reshape(-1, 3),
+        frame_rgb_f32[-ph:, -pw:].reshape(-1, 3),
+    ], axis=0)
+    bg_color = np.median(corners, axis=0)
+
+    # Check if background is green-ish (G channel dominant)
+    if not (bg_color[1] > bg_color[0] + 0.05 and bg_color[1] > bg_color[2] + 0.05):
+        return None, 0.0
+
+    # HSV-based mask (more robust than RGB distance)
+    frame_u8 = (np.clip(frame_rgb_f32, 0, 1) * 255).astype(np.uint8)
+    hsv = cv2.cvtColor(frame_u8, cv2.COLOR_RGB2HSV)
+    # Green hue range in HSV
+    green_mask = cv2.inRange(hsv, (35, 40, 40), (85, 255, 255))
+    # Invert: foreground = NOT green
+    fg_mask = cv2.bitwise_not(green_mask)
+    # Morphological close to fill small holes
+    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
+    fg_mask = cv2.morphologyEx(fg_mask, cv2.MORPH_CLOSE, kernel)
+    fg_mask = cv2.GaussianBlur(fg_mask, (5, 5), 0)
+    mask_f32 = fg_mask.astype(np.float32) / 255.0
+
+    # Confidence: how bimodal is the mask (closer to 0/1 = better)
+    confidence = 1.0 - 2.0 * np.mean(np.minimum(mask_f32, 1.0 - mask_f32))
+
+    return mask_f32, confidence
+
+
+# ---------------------------------------------------------------------------
+# Model loading (lazy singletons)
+# ---------------------------------------------------------------------------
+_birefnet_session = None
+_corridorkey_sessions = {}
+
+
+def _ort_session_opts():
+    opts = ort.SessionOptions()
+    opts.intra_op_num_threads = 2
+    opts.inter_op_num_threads = 1
+    opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
+    opts.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
+    opts.enable_mem_pattern = True
+    return opts
+
+
+def get_birefnet():
+    global _birefnet_session
+    if _birefnet_session is None:
+        logger.info("Downloading BiRefNet-Lite ONNX...")
+        path = hf_hub_download(repo_id=BIREFNET_REPO, filename=BIREFNET_FILE)
+        logger.info("Loading BiRefNet ONNX: %s", path)
+        _birefnet_session = ort.InferenceSession(path, _ort_session_opts(), providers=["CPUExecutionProvider"])
+    return _birefnet_session
+
+
+def get_corridorkey(resolution="1024"):
+    global _corridorkey_sessions
+    if resolution not in _corridorkey_sessions:
+        onnx_path = CORRIDORKEY_MODELS.get(resolution)
+        if not onnx_path or not os.path.exists(onnx_path):
+            raise gr.Error(f"CorridorKey ONNX model for {resolution} not found.")
+        logger.info("Loading CorridorKey ONNX (%s): %s", resolution, onnx_path)
+        _corridorkey_sessions[resolution] = ort.InferenceSession(onnx_path, _ort_session_opts(), providers=["CPUExecutionProvider"])
+    return _corridorkey_sessions[resolution]
+
+
+# ---------------------------------------------------------------------------
+# Per-frame inference
+# ---------------------------------------------------------------------------
+
+def birefnet_frame(session, image_rgb_uint8):
+    """BiRefNet: RGB uint8 [H,W,3] -> float32 [H,W] mask 0-1."""
+    h, w = image_rgb_uint8.shape[:2]
+    inp_info = session.get_inputs()[0]
+    res = (inp_info.shape[2], inp_info.shape[3])
+    img = cv2.resize(image_rgb_uint8, res).astype(np.float32) / 255.0
+    img = (img - IMAGENET_MEAN) / IMAGENET_STD
+    img = img.transpose(2, 0, 1)[np.newaxis, :].astype(np.float32)
+    outputs = session.run(None, {inp_info.name: img})
+    pred = 1.0 / (1.0 + np.exp(-outputs[-1]))  # sigmoid
+    mask = cv2.resize(pred[0, 0], (w, h))
+    return (mask > 0.04).astype(np.float32)
+
+
+def corridorkey_frame(session, image_f32, mask_f32, img_size,
+                      despill_strength=0.5, auto_despeckle=True,
+                      despeckle_size=400):
+    """CorridorKey: image [H,W,3] float32 0-1 + mask [H,W] float32 0-1 -> dict."""
+    h, w = image_f32.shape[:2]
+    img_resized = cv2.resize(image_f32, (img_size, img_size))
+    mask_resized = cv2.resize(mask_f32, (img_size, img_size))[:, :, np.newaxis]
+    img_norm = (img_resized - IMAGENET_MEAN) / IMAGENET_STD
+    inp = np.concatenate([img_norm, mask_resized], axis=-1)
+    inp = inp.transpose(2, 0, 1)[np.newaxis, :].astype(np.float32)
+
+    alpha_raw, fg_raw = session.run(None, {"input": inp})
+
+    alpha = cv2.resize(alpha_raw[0].transpose(1, 2, 0), (w, h), interpolation=cv2.INTER_LANCZOS4)
+    fg = cv2.resize(fg_raw[0].transpose(1, 2, 0), (w, h), interpolation=cv2.INTER_LANCZOS4)
+    if alpha.ndim == 2:
+        alpha = alpha[:, :, np.newaxis]
+
+    if auto_despeckle:
+        alpha = clean_matte(alpha, area_threshold=despeckle_size, dilation=25, blur_size=5)
+    fg = despill(fg, green_limit_mode="average", strength=despill_strength)
+
+    return {"alpha": alpha, "fg": fg}
+
+
+# ---------------------------------------------------------------------------
+# Video stitching via ffmpeg
+# ---------------------------------------------------------------------------
+
+def _stitch_ffmpeg(frame_dir, out_path, fps, pattern="%05d.png", pix_fmt="yuv420p",
+                   codec="libx264", extra_args=None):
+    """Stitch PNG frames into video via ffmpeg subprocess."""
+    cmd = ["ffmpeg", "-y", "-framerate", str(fps),
+           "-i", os.path.join(frame_dir, pattern),
+           "-c:v", codec, "-pix_fmt", pix_fmt]
+    if extra_args:
+        cmd.extend(extra_args)
+    cmd.append(out_path)
+    try:
+        subprocess.run(cmd, capture_output=True, timeout=300, check=True)
+        return True
+    except (FileNotFoundError, subprocess.TimeoutExpired, subprocess.CalledProcessError) as e:
+        logger.warning("ffmpeg failed: %s", e)
+        return False
+
+
+def _stitch_cv2_fallback(frame_dir, out_path, fps, w, h, grayscale=False):
+    """Fallback: stitch via OpenCV VideoWriter if ffmpeg unavailable."""
+    files = sorted([f for f in os.listdir(frame_dir) if f.endswith(".png")])
+    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+    writer = cv2.VideoWriter(out_path, fourcc, fps, (w, h))
+    if not writer.isOpened():
+        logger.warning("mp4v codec unavailable")
+        return False
+    for f in files:
+        img = cv2.imread(os.path.join(frame_dir, f),
+                         cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR)
+        if img is None:
+            continue
+        if grayscale:
+            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+        writer.write(img)
+    writer.release()
+    return True
+
+
+# ---------------------------------------------------------------------------
+# Video processing pipeline (single-pass, streaming)
+# ---------------------------------------------------------------------------
+
+def process_video(video_path, resolution, despill_val, mask_mode,
+                  auto_despeckle, despeckle_size, output_mode, progress=gr.Progress()):
+    """Remove green screen background from video using CorridorKey AI matting.
+    Handles transparent objects (glass, water, cloth) that traditional chroma key cannot.
+    Returns composite video, downloadable file, and status message.
+    """
+    if video_path is None:
+        raise gr.Error("Please upload a video.")
+
+    max_dur = MAX_DURATION_GPU if HAS_CUDA else MAX_DURATION_CPU
+    img_size = int(resolution)
+
+    # Probe video
+    cap = cv2.VideoCapture(video_path)
+    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
+    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+    cap.release()
+
+    if total_frames == 0:
+        raise gr.Error("Could not read video frames. Check file format.")
+
+    duration = total_frames / fps
+    if duration > max_dur:
+        raise gr.Error(f"Video too long ({duration:.1f}s). Max {max_dur}s on {'GPU' if HAS_CUDA else 'free CPU'} tier.")
+
+    frames_to_process = min(total_frames, MAX_FRAMES)
+    logger.info("Processing %d frames (%dx%d @ %.1f fps), resolution=%d, mask=%s",
+                frames_to_process, w, h, fps, img_size, mask_mode)
+
+    # Load models
+    try:
+        birefnet = None
+        if mask_mode != "Fast (classical)":
+            progress(0, desc="Loading BiRefNet...")
+            birefnet = get_birefnet()
+        progress(0.03, desc=f"Loading CorridorKey ({resolution})...")
+        corridorkey = get_corridorkey(resolution)
+    except Exception as e:
+        raise gr.Error(f"Failed to load models: {e}")
+
+    despill_strength = despill_val / 10.0
+
+    # Determine what outputs we need
+    need_comp = output_mode == "Composite on checkerboard (MP4)"
+    need_alpha = output_mode == "Alpha matte (MP4)"
+    need_rgba = output_mode in ("Transparent video (WebM)", "PNG sequence (ZIP)")
+
+    tmpdir = tempfile.mkdtemp(prefix="ck_")
+    try:
+        # Pre-compute checkerboard if needed
+        bg_lin = None
+        if need_comp:
+            bg_lin = srgb_to_linear(create_checkerboard(w, h))
+
+        # For PNG-based outputs, create dirs
+        rgba_dir = None
+        alpha_dir = None
+        comp_dir = None
+        if need_rgba:
+            rgba_dir = os.path.join(tmpdir, "rgba")
+            os.makedirs(rgba_dir, exist_ok=True)
+            if output_mode == "PNG sequence (ZIP)":
+                alpha_dir = os.path.join(tmpdir, "alphas")
+                os.makedirs(alpha_dir, exist_ok=True)
+
+        # For MP4 modes, write directly to VideoWriter via temp PNGs + ffmpeg
+        # (we still need PNGs as ffmpeg input, but only the needed type)
+        if need_comp:
+            comp_dir = os.path.join(tmpdir, "comp")
+            os.makedirs(comp_dir, exist_ok=True)
+        if need_alpha:
+            alpha_dir = os.path.join(tmpdir, "alphas")
+            os.makedirs(alpha_dir, exist_ok=True)
+
+        # Single-pass processing
+        cap = cv2.VideoCapture(video_path)
+        frame_times = []
+
+        for i in range(frames_to_process):
+            t0 = time.time()
+            ret, frame_bgr = cap.read()
+            if not ret:
+                break
+
+            frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
+            frame_f32 = frame_rgb.astype(np.float32) / 255.0
+
+            # Coarse mask: fast classical or BiRefNet
+            if mask_mode == "Fast (classical)":
+                mask, confidence = fast_greenscreen_mask(frame_f32)
+                if mask is None:
+                    raise gr.Error("Fast mask failed: video doesn't appear to have a green screen background. Try 'AI (BiRefNet)' mode.")
+            elif mask_mode == "Hybrid (auto)":
+                mask, confidence = fast_greenscreen_mask(frame_f32)
+                if mask is None or confidence < 0.7:
+                    mask = birefnet_frame(birefnet, frame_rgb)
+            else:  # "AI (BiRefNet)"
+                mask = birefnet_frame(birefnet, frame_rgb)
+
+            # CorridorKey inference
+            result = corridorkey_frame(corridorkey, frame_f32, mask, img_size,
+                                       despill_strength=despill_strength,
+                                       auto_despeckle=auto_despeckle,
+                                       despeckle_size=int(despeckle_size))
+
+            alpha = result["alpha"]
+            fg = result["fg"]
+
+            # Write only the output we need
+            if need_comp:
+                fg_lin = srgb_to_linear(fg)
+                comp = linear_to_srgb(composite_straight(fg_lin, bg_lin, alpha))
+                comp_uint8 = (np.clip(comp, 0, 1) * 255).astype(np.uint8)
+                cv2.imwrite(os.path.join(comp_dir, f"{i:05d}.png"), comp_uint8[:, :, ::-1])
+
+            if need_alpha or alpha_dir:
+                alpha_uint8 = (np.clip(alpha, 0, 1) * 255).astype(np.uint8)
+                if alpha_uint8.ndim == 3:
+                    alpha_uint8 = alpha_uint8[:, :, 0]
+                if alpha_dir:
+                    cv2.imwrite(os.path.join(alpha_dir, f"{i:05d}.png"), alpha_uint8)
+
+            if need_rgba:
+                fg_uint8 = (np.clip(fg, 0, 1) * 255).astype(np.uint8)
+                a_uint8 = (np.clip(alpha, 0, 1) * 255).astype(np.uint8)
+                if a_uint8.ndim == 3:
+                    a_uint8 = a_uint8[:, :, 0]
+                rgba = np.concatenate([fg_uint8[:, :, ::-1], a_uint8[:, :, np.newaxis]], axis=-1)
+                cv2.imwrite(os.path.join(rgba_dir, f"{i:05d}.png"), rgba)
+
+            # Progress with ETA
+            elapsed = time.time() - t0
+            frame_times.append(elapsed)
+            avg_time = np.mean(frame_times[-5:]) if len(frame_times) >= 2 else elapsed
+            remaining = (frames_to_process - i - 1) * avg_time
+            eta = f"{remaining/60:.1f}min" if remaining > 60 else f"{remaining:.0f}s"
+            pct = 0.05 + 0.85 * (i + 1) / frames_to_process
+            progress(pct, desc=f"Frame {i+1}/{frames_to_process} ({elapsed:.1f}s) | ~{eta} left")
+
+        cap.release()
+
+        # Assemble output
+        progress(0.92, desc="Stitching video...")
+        output_video = None
+        output_file = None
+
+        if need_comp:
+            out_path = os.path.join(tmpdir, "composite.mp4")
+            ok = _stitch_ffmpeg(comp_dir, out_path, fps, extra_args=["-crf", "18"])
+            if not ok:
+                ok = _stitch_cv2_fallback(comp_dir, out_path, fps, w, h)
+            if not ok:
+                raise gr.Error("Video encoding failed. No suitable codec found.")
+            output_video = out_path
+
+        elif need_alpha:
+            out_path = os.path.join(tmpdir, "alpha_matte.mp4")
+            ok = _stitch_ffmpeg(alpha_dir, out_path, fps, extra_args=["-crf", "18"])
+            if not ok:
+                ok = _stitch_cv2_fallback(alpha_dir, out_path, fps, w, h, grayscale=True)
+            if not ok:
+                raise gr.Error("Video encoding failed. No suitable codec found.")
+            output_video = out_path
+
+        elif output_mode == "Transparent video (WebM)":
+            out_path = os.path.join(tmpdir, "transparent.webm")
+            ok = _stitch_ffmpeg(rgba_dir, out_path, fps,
+                                codec="libvpx-vp9", pix_fmt="yuva420p",
+                                extra_args=["-crf", "30", "-b:v", "0"])
+            if not ok:
+                raise gr.Error("WebM encoding failed. ffmpeg with libvpx-vp9 required.")
+            output_video = out_path
+
+        elif output_mode == "PNG sequence (ZIP)":
+            zip_path = os.path.join(tmpdir, "rgba_sequence.zip")
+            with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as zf:
+                for f in sorted(os.listdir(rgba_dir)):
+                    zf.write(os.path.join(rgba_dir, f), f"rgba/{f}")
+                if alpha_dir:
+                    for f in sorted(os.listdir(alpha_dir)):
+                        zf.write(os.path.join(alpha_dir, f), f"alpha/{f}")
+            output_file = zip_path
+
+        progress(1.0, desc="Done!")
+        avg = np.mean(frame_times) if frame_times else 0
+        status = f"Processed {len(frame_times)} frames ({w}x{h}) at {img_size}px | {avg:.1f}s/frame avg"
+        return output_video, output_file, status
+
+    except gr.Error:
+        raise
+    except Exception as e:
+        logger.exception("Processing failed")
+        raise gr.Error(f"Processing failed: {e}")
+    finally:
+        # Cleanup intermediate dirs (keep output files in tmpdir root)
+        for d in ["comp", "alphas", "rgba"]:
+            p = os.path.join(tmpdir, d)
+            if os.path.isdir(p):
+                shutil.rmtree(p, ignore_errors=True)
+        gc.collect()
+
+
+# ---------------------------------------------------------------------------
+# Gradio UI
+# ---------------------------------------------------------------------------
+
+def process_example(video_path, resolution, despill, mask_mode, despeckle, despeckle_size, output_mode):
+    return process_video(video_path, resolution, despill, mask_mode, despeckle, despeckle_size, output_mode)
+
+
+if HAS_CUDA:
+    DESCRIPTION = "# CorridorKey Green Screen Matting\nRemove green backgrounds from video. Based on [CorridorKey](https://www.youtube.com/watch?v=3Ploi723hg4) by Corridor Digital. GPU mode: max {max_dur}s / {max_frames} frames.".format(max_dur=MAX_DURATION_GPU, max_frames=MAX_FRAMES)
+else:
+    DESCRIPTION = "# CorridorKey Green Screen Matting\nRemove green backgrounds from video. Based on [CorridorKey](https://www.youtube.com/watch?v=3Ploi723hg4) by Corridor Digital. ~37min for 5s clip on free CPU."
+
+with gr.Blocks(title="CorridorKey") as demo:
+    gr.Markdown(DESCRIPTION)
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            input_video = gr.Video(label="Upload Green Screen Video")
+
+            with gr.Accordion("Settings", open=True):
+                resolution = gr.Radio(
+                    choices=["1024", "2048"],
+                    value="1024",
+                    label="Processing Resolution",
+                    info="1024 = balanced (~8s/frame CPU), 2048 = max quality (trained resolution, fast on GPU)"
+                )
+                mask_mode = gr.Radio(
+                    choices=["Hybrid (auto)", "AI (BiRefNet)", "Fast (classical)"],
+                    value="Hybrid (auto)",
+                    label="Mask Mode",
+                    info="Hybrid = fast green detection + AI fallback. Fast = classical only (~0.01s). AI = always use BiRefNet (~13s/frame)"
+                )
+                despill_slider = gr.Slider(
+                    0, 10, value=5, step=1,
+                    label="Despill Strength",
+                    info="Remove green reflections from subject (0=off, 10=max)"
+                )
+                despeckle_check = gr.Checkbox(
+                    value=True,
+                    label="Auto Despeckle",
+                    info="Remove small disconnected artifacts (tracking markers, noise)"
+                )
+                despeckle_size = gr.Number(
+                    value=400, precision=0,
+                    label="Despeckle Size",
+                    info="Minimum pixel area to keep (smaller = more aggressive cleanup)"
+                )
+
+            output_mode = gr.Dropdown(
+                choices=[
+                    "Composite on checkerboard (MP4)",
+                    "Alpha matte (MP4)",
+                    "Transparent video (WebM)",
+                    "PNG sequence (ZIP)",
+                ],
+                value="Composite on checkerboard (MP4)",
+                label="Output Format"
+            )
+
+            process_btn = gr.Button("Process Video", variant="primary", size="lg")
+
+        with gr.Column(scale=1):
+            output_video = gr.Video(label="Result Preview")
+            output_file = gr.File(label="Download Result")
+            status_text = gr.Textbox(label="Status", interactive=False)
+
+    gr.Examples(
+        examples=[
+            ["examples/corridor_greenscreen_demo.mp4", "1024", 5, "Hybrid (auto)", True, 400, "Composite on checkerboard (MP4)"],
+        ],
+        inputs=[input_video, resolution, despill_slider, mask_mode, despeckle_check, despeckle_size, output_mode],
+        outputs=[output_video, output_file, status_text],
+        fn=process_example,
+        cache_examples=True,
+        cache_mode="lazy",
+        label="Examples (click to load)"
+    )
+
+    process_btn.click(
+        fn=process_video,
+        inputs=[input_video, resolution, despill_slider, mask_mode, despeckle_check, despeckle_size, output_mode],
+        outputs=[output_video, output_file, status_text],
+    )
+
+
+# ---------------------------------------------------------------------------
+# CLI mode
+# ---------------------------------------------------------------------------
+
+def cli_main():
+    """CLI mode: python app.py --input video.mp4 [options]"""
+    import argparse
+    parser = argparse.ArgumentParser(description="CorridorKey Green Screen Matting")
+    parser.add_argument("--input", required=True, help="Input video path")
+    parser.add_argument("--output", default="output", help="Output directory")
+    parser.add_argument("--device", default="auto", choices=["auto", "cpu", "cuda"],
+                        help="Force device (auto=detect GPU/CPU)")
+    parser.add_argument("--resolution", default="1024", choices=["1024", "2048"],
+                        help="Model resolution (1024=fast, 2048=max quality)")
+    parser.add_argument("--mask-mode", default="Hybrid (auto)",
+                        choices=["Hybrid (auto)", "AI (BiRefNet)", "Fast (classical)"])
+    parser.add_argument("--despill", type=int, default=5, help="Despill strength 0-10")
+    parser.add_argument("--no-despeckle", action="store_true")
+    parser.add_argument("--despeckle-size", type=int, default=400)
+    parser.add_argument("--format", default="Composite on checkerboard (MP4)",
+                        choices=["Composite on checkerboard (MP4)", "Alpha matte (MP4)",
+                                 "Transparent video (WebM)", "PNG sequence (ZIP)"])
+    args = parser.parse_args()
+
+    global HAS_CUDA
+    if args.device == "cpu":
+        HAS_CUDA = False
+    elif args.device == "cuda":
+        HAS_CUDA = True
+    print(f"Device: {'CUDA' if HAS_CUDA else 'CPU'}")
+
+    class CLIProgress:
+        def __call__(self, val, desc=""):
+            if desc:
+                print(f"  [{val:.0%}] {desc}")
+
+    video, file, status = process_video(
+        args.input, args.resolution, args.despill, args.mask_mode,
+        not args.no_despeckle, args.despeckle_size, args.format,
+        progress=CLIProgress()
+    )
+    print(f"\n{status}")
+    if video:
+        os.makedirs(args.output, exist_ok=True)
+        dst = os.path.join(args.output, os.path.basename(video))
+        shutil.copy2(video, dst)
+        print(f"Output: {dst}")
+    if file:
+        os.makedirs(args.output, exist_ok=True)
+        dst = os.path.join(args.output, os.path.basename(file))
+        shutil.copy2(file, dst)
+        print(f"Output: {dst}")
+
+
+if __name__ == "__main__":
+    if len(sys.argv) > 1 and "--input" in sys.argv:
+        cli_main()
+    else:
+        demo.queue(default_concurrency_limit=1)
+        demo.launch(ssr_mode=False, mcp_server=True)
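app.py composites in linear light (`srgb_to_linear` before `composite_straight`, `linear_to_srgb` after), which avoids the darkened edge fringing that compositing directly on gamma-encoded values produces. A standalone check that the two transfer functions, copied from the file, invert each other:

```python
import numpy as np

def srgb_to_linear(x):
    # Piecewise sRGB decoding: linear segment below the knee, power curve above
    x = np.clip(x, 0.0, None)
    return np.where(x <= 0.04045, x / 12.92, np.power((x + 0.055) / 1.055, 2.4))

def linear_to_srgb(x):
    # Inverse piecewise sRGB encoding
    x = np.clip(x, 0.0, None)
    return np.where(x <= 0.0031308, x * 12.92, 1.055 * np.power(x, 1.0 / 2.4) - 0.055)

vals = np.linspace(0.0, 1.0, 11)
roundtrip = linear_to_srgb(srgb_to_linear(vals))
print(np.max(np.abs(roundtrip - vals)))  # effectively zero: the curves invert each other
```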
examples/corridor_greenscreen_demo.mp4 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3fd4d122cd932a796fc94cba5be55f208a3e50bf9e9272d42b59a8b21c2a6e96
+size 7342764
models/corridorkey_1024.onnx ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cfcf469e58a5b917352ff75277c3d5d3adc4c3720b8642b1751e6c710f0541fc
+size 312511017
models/corridorkey_2048.onnx ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f376f73b0045420c2e5d391d6d01dc8a9464df38bffe7cfa4350e1cbb63cde25
+size 400592017
requirements.txt ADDED
@@ -0,0 +1,5 @@
+numpy
+opencv-python-headless
+huggingface-hub
+onnxruntime
+gradio[mcp]