Nekochu commited on
Commit
0b6961f
·
1 Parent(s): c2d53e4

add ZeroGPU GPU inference (FP16, flash-attn, batch=32@1024/16@2048)

Browse files
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ *.onnx
CorridorKeyModule/__init__.py ADDED
File without changes
CorridorKeyModule/core/__init__.py ADDED
File without changes
CorridorKeyModule/core/model_transformer.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+
5
+ import timm
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
class MLP(nn.Module):
    """Linear token embedding: projects C_in features to C_out.

    Despite the name, this is a single ``nn.Linear`` applied to flattened
    spatial tokens (SegFormer-style per-scale projection).
    """

    def __init__(self, input_dim: int = 2048, embed_dim: int = 768) -> None:
        super().__init__()
        self.proj = nn.Linear(input_dim, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [..., input_dim] -> [..., embed_dim]
        embedded = self.proj(x)
        return embedded
22
+
23
+
24
class DecoderHead(nn.Module):
    """SegFormer-style all-MLP decoder head.

    Projects four multi-scale feature maps to a shared embedding width,
    upsamples the coarser three to the c1 (highest) resolution, fuses the
    concatenation with a 1x1 conv + BN + ReLU, and predicts ``output_dim``
    channels at c1 resolution (H/4 of the network input).
    """

    def __init__(
        self, feature_channels: list[int] | None = None, embedding_dim: int = 256, output_dim: int = 1
    ) -> None:
        super().__init__()
        channels = [112, 224, 448, 896] if feature_channels is None else feature_channels

        # Per-scale linear projections to a common embedding width.
        self.linear_c4 = MLP(input_dim=channels[3], embed_dim=embedding_dim)
        self.linear_c3 = MLP(input_dim=channels[2], embed_dim=embedding_dim)
        self.linear_c2 = MLP(input_dim=channels[1], embed_dim=embedding_dim)
        self.linear_c1 = MLP(input_dim=channels[0], embed_dim=embedding_dim)

        # Fuse the four concatenated scales back to embedding_dim.
        self.linear_fuse = nn.Conv2d(embedding_dim * 4, embedding_dim, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(embedding_dim)
        self.relu = nn.ReLU(inplace=True)

        # Prediction head.
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Conv2d(embedding_dim, output_dim, kernel_size=1)

    def forward(self, features: list[torch.Tensor]) -> torch.Tensor:
        c1, c2, c3, c4 = features
        batch = c4.shape[0]
        target_hw = c1.shape[2:]  # all scales resized to the c1 grid (H/4)

        def project(mlp: MLP, feat: torch.Tensor) -> torch.Tensor:
            # [B, C, H, W] -> tokens [B, HW, C] -> MLP -> back to [B, E, H, W].
            tokens = feat.flatten(2).transpose(1, 2)
            embedded = mlp(tokens).transpose(1, 2)
            return embedded.view(batch, -1, feat.shape[2], feat.shape[3])

        _c4 = F.interpolate(project(self.linear_c4, c4), size=target_hw, mode="bilinear", align_corners=False)
        _c3 = F.interpolate(project(self.linear_c3, c3), size=target_hw, mode="bilinear", align_corners=False)
        _c2 = F.interpolate(project(self.linear_c2, c2), size=target_hw, mode="bilinear", align_corners=False)
        _c1 = project(self.linear_c1, c1)

        fused = self.relu(self.bn(self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1))))
        return self.classifier(self.dropout(fused))
72
+
73
+
74
class RefinerBlock(nn.Module):
    """Dilated residual block normalized with GroupNorm.

    GroupNorm statistics are independent of batch size, so the block stays
    numerically stable even at very small batches (e.g. batch size 2).
    """

    def __init__(self, channels: int, dilation: int = 1) -> None:
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=dilation, dilation=dilation)
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.gn1 = nn.GroupNorm(8, channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.gn2 = nn.GroupNorm(8, channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # conv -> norm -> relu -> conv -> norm, then the identity skip.
        out = self.relu(self.gn1(self.conv1(x)))
        out = self.gn2(self.conv2(out))
        out = out + x
        return self.relu(out)
97
+
98
+
99
class CNNRefinerModule(nn.Module):
    """Dilated residual refiner (receptive field ~65 px).

    Designed to remove macroblocking artifacts from the Hiera backbone's
    coarse predictions.
    Structure: Stem -> Res(d1) -> Res(d2) -> Res(d4) -> Res(d8) -> Projection.
    """

    def __init__(self, in_channels: int = 7, hidden_channels: int = 64, out_channels: int = 4) -> None:
        super().__init__()

        # Stem: lift the concatenated RGB + coarse-prediction input.
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
            nn.GroupNorm(8, hidden_channels),
            nn.ReLU(inplace=True),
        )

        # Dilated residual blocks: each doubling of dilation expands the
        # receptive field without extra downsampling.
        self.res1 = RefinerBlock(hidden_channels, dilation=1)
        self.res2 = RefinerBlock(hidden_channels, dilation=2)
        self.res3 = RefinerBlock(hidden_channels, dilation=4)
        self.res4 = RefinerBlock(hidden_channels, dilation=8)

        # Final projection: no activation — the output is purely additive logits.
        self.final = nn.Conv2d(hidden_channels, out_channels, kernel_size=1)

        # "Whisper" init: tiny noise gives usable gradients without shocking
        # the pretrained coarse predictions at the start of training.
        nn.init.normal_(self.final.weight, mean=0.0, std=1e-3)
        nn.init.constant_(self.final.bias, 0)

    def forward(self, img: torch.Tensor, coarse_pred: torch.Tensor) -> torch.Tensor:
        # img: [B, 3, H, W]; coarse_pred: [B, 4, H, W]
        feat = self.stem(torch.cat([img, coarse_pred], dim=1))
        for block in (self.res1, self.res2, self.res3, self.res4):
            feat = block(feat)

        # 10x output scaling: lets the refiner predict small stable values
        # (e.g. 0.5) that act as strong logit corrections (5.0).
        return self.final(feat) * 10.0
143
+
144
+
145
class GreenFormer(nn.Module):
    """Hiera-backed matting network producing an alpha matte and clean foreground.

    Pipeline: timm Hiera encoder (features_only) -> two SegFormer-style decoder
    heads (alpha: 1 channel, foreground: 3 channels) -> optional dilated CNN
    refiner that predicts residual corrections in logit space.
    """

    def __init__(
        self,
        encoder_name: str = "hiera_base_plus_224.mae_in1k_ft_in1k",
        in_channels: int = 4,
        img_size: int = 512,
        use_refiner: bool = True,
    ) -> None:
        super().__init__()

        # --- Encoder ---
        # Create the target model (img_size grid, random weights).
        # features_only=True wraps it in timm's FeatureGetterNet.
        logger.info("Initializing %s (img_size=%d)", encoder_name, img_size)
        self.encoder = timm.create_model(encoder_name, pretrained=False, features_only=True, img_size=img_size)
        # Base weights are intentionally not downloaded: the custom checkpoint
        # loaded immediately after construction contains all weights, including
        # correctly trained/sized pos-embeds. Keeps the project offline-capable.
        logger.info("Skipped downloading base weights (relying on custom checkpoint)")

        # Patch the first layer when the input is not plain RGB (e.g. RGB + mask).
        if in_channels != 3:
            self._patch_input_layer(in_channels)

        # Feature channel widths; Hiera Base Plus is [112, 224, 448, 896].
        # Fetch dynamically when timm exposes feature_info, else fall back.
        try:
            feature_channels = self.encoder.feature_info.channels()
        except (AttributeError, TypeError):
            feature_channels = [112, 224, 448, 896]
        logger.info("Feature channels: %s", feature_channels)

        # --- Decoders ---
        embedding_dim = 256

        # Alpha decoder (1 channel) and foreground decoder (3 channels).
        self.alpha_decoder = DecoderHead(feature_channels, embedding_dim, output_dim=1)
        self.fg_decoder = DecoderHead(feature_channels, embedding_dim, output_dim=3)

        # --- Refiner ---
        # Refiner input: 3 (RGB) + 4 (coarse alpha+fg probabilities) = 7 channels.
        self.use_refiner = use_refiner
        if self.use_refiner:
            self.refiner = CNNRefinerModule(in_channels=7, hidden_channels=64, out_channels=4)
        else:
            self.refiner = None
            logger.info("Refiner module DISABLED (backbone-only mode)")

    def _patch_input_layer(self, in_channels: int) -> None:
        """
        Modifies the first convolution layer to accept `in_channels`.
        Copies existing RGB weights and initializes extras to zero.
        """
        # Hiera: self.encoder.model.patch_embed.proj
        try:
            patch_embed = self.encoder.model.patch_embed.proj
        except AttributeError:
            # Fallback if timm changes structure or for other models
            patch_embed = self.encoder.patch_embed.proj
        weight = patch_embed.weight.data  # [Out, C_orig, kH, kW]
        bias = patch_embed.bias.data if patch_embed.bias is not None else None

        # BUGFIX: unpack kernel height/width separately. The original bound the
        # same name twice (`out_channels, _, k, k = ...`), silently assuming a
        # square kernel and using the width for both dimensions.
        out_channels, orig_in, kh, kw = weight.shape

        # Create the replacement conv with the same geometry.
        new_conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(kh, kw),
            stride=patch_embed.stride,
            padding=patch_embed.padding,
            bias=(bias is not None),
        )

        # Copy as many of the original input channels as fit; any extra new
        # channels start at zero so the patched layer initially ignores them.
        copy_c = min(orig_in, in_channels)
        new_conv.weight.data[:, :copy_c] = weight[:, :copy_c]
        if in_channels > copy_c:
            new_conv.weight.data[:, copy_c:] = 0.0

        if bias is not None:
            new_conv.bias.data = bias

        # Replace in module
        try:
            self.encoder.model.patch_embed.proj = new_conv
        except AttributeError:
            self.encoder.patch_embed.proj = new_conv

        logger.info("Patched input layer: 3 → %d channels (extra initialized to 0)", in_channels)

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """Run the coarse -> refine matting pipeline.

        Args:
            x: [B, 4, H, W] input (RGB plus one auxiliary channel).

        Returns:
            Dict with "alpha" ([B, 1, H, W]) and "fg" ([B, 3, H, W]),
            both sigmoid-activated into [0, 1].
        """
        input_size = x.shape[2:]

        # Encode: list of multi-scale feature maps.
        features = self.encoder(x)

        # Decode streams at H/4 resolution.
        alpha_logits = self.alpha_decoder(features)  # [B, 1, H/4, W/4]
        fg_logits = self.fg_decoder(features)  # [B, 3, H/4, W/4]

        # Upsample the coarse LOGITS to full resolution (bilinear).
        alpha_logits_up = F.interpolate(alpha_logits, size=input_size, mode="bilinear", align_corners=False)
        fg_logits_up = F.interpolate(fg_logits, size=input_size, mode="bilinear", align_corners=False)

        # NOTE: no "humility clamp" here (removed in Phase 3) — the refiner sees
        # raw logits (-inf, +inf) to preserve all backbone detail.

        # Coarse probabilities: used for the loss and as refiner input features,
        # because they are normalized to [0, 1].
        alpha_coarse = torch.sigmoid(alpha_logits_up)
        fg_coarse = torch.sigmoid(fg_logits_up)

        # --- Refinement (CNN hybrid) ---
        # Refiner input: RGB image (first 3 channels of x) + coarse predictions.
        rgb = x[:, :3, :, :]
        coarse_pred = torch.cat([alpha_coarse, fg_coarse], dim=1)  # [B, 4, H, W]

        # The refiner outputs DELTA logits: corrections in score space (-inf, inf).
        if self.use_refiner and self.refiner is not None:
            delta_logits = self.refiner(rgb, coarse_pred)
        else:
            # Backbone-only mode: zero deltas.
            delta_logits = torch.zeros_like(coarse_pred)

        delta_alpha = delta_logits[:, 0:1]
        delta_fg = delta_logits[:, 1:4]

        # Residual addition in logit space: unlimited correction capability and
        # no saturation blocking at the sigmoid extremes.
        alpha_final_logits = alpha_logits_up + delta_alpha
        fg_final_logits = fg_logits_up + delta_fg

        # Final activation.
        alpha_final = torch.sigmoid(alpha_final_logits)
        fg_final = torch.sigmoid(fg_final_logits)

        return {"alpha": alpha_final, "fg": fg_final}
README.md CHANGED
@@ -17,20 +17,27 @@ tags:
17
  - corridor-digital
18
  - transparency
19
  - onnx
 
 
20
  - mcp-server
21
  short_description: Remove green background from video, even transparent objects
22
  ---
23
 
24
- # CorridorKey Green Screen Matting (CPU)
25
 
26
- Remove green screen backgrounds from video on free CPU. Handles transparent objects (glass, water, cloth) that traditional chroma key cannot.
27
 
28
  Based on [CorridorKey](https://github.com/nikopueringer/CorridorKey) by Corridor Digital.
29
 
 
 
 
 
 
30
  ## Pipeline
31
 
32
- 1. **BiRefNet** - Generates coarse foreground mask
33
- 2. **CorridorKey GreenFormer** - Refines alpha matte + extracts clean foreground
34
  3. **Compositing** - Despill, despeckle, composite on new background
35
 
36
  ## API
@@ -41,7 +48,7 @@ Based on [CorridorKey](https://github.com/nikopueringer/CorridorKey) by Corridor
41
  ```bash
42
  curl -X POST "https://luminia-corridorkey.hf.space/gradio_api/call/process_video" \
43
  -H "Content-Type: application/json" \
44
- -d '{"data": ["video.mp4", 5, 10, true, 400, "Composite on checkerboard (MP4)"]}'
45
  ```
46
 
47
  **Step 2: Get result**
@@ -51,27 +58,11 @@ curl "https://luminia-corridorkey.hf.space/gradio_api/call/process_video/{event_
51
 
52
  ### MCP (Model Context Protocol)
53
 
54
- **Tool schema:**
55
- ```json
56
- {
57
- "name": "process_video",
58
- "description": "Remove green screen background from video using CorridorKey AI matting.",
59
- "parameters": {
60
- "video_path": "Path to green screen video",
61
- "despill_val": "Despill strength 0-10 (default 5)",
62
- "refiner_val": "Refiner scale 0-20 (default 10)",
63
- "auto_despeckle": "Remove small artifacts (default true)",
64
- "despeckle_size": "Min pixel area to keep (default 400)",
65
- "output_mode": "Composite on checkerboard (MP4) | Alpha matte (MP4) | Transparent video (WebM) | PNG sequence (ZIP)"
66
- }
67
- }
68
- ```
69
-
70
  **MCP Config:**
71
  ```json
72
  {
73
  "mcpServers": {
74
- "corridorkey-cpu": {
75
  "url": "https://luminia-corridorkey.hf.space/gradio_api/mcp/"
76
  }
77
  }
 
17
  - corridor-digital
18
  - transparency
19
  - onnx
20
+ - pytorch
21
+ - zerogpu
22
  - mcp-server
23
  short_description: Remove green background from video, even transparent objects
24
  ---
25
 
26
+ # CorridorKey Green Screen Matting
27
 
28
+ Remove green screen backgrounds from video. Handles transparent objects (glass, water, cloth) that traditional chroma key cannot.
29
 
30
  Based on [CorridorKey](https://github.com/nikopueringer/CorridorKey) by Corridor Digital.
31
 
32
+ ## Inference Paths
33
+
34
+ - **GPU (ZeroGPU H200)**: PyTorch GreenFormer with batched inference (batch 32 at 1024, batch 8 at 2048)
35
+ - **CPU (fallback)**: ONNX Runtime sequential inference (batch 1)
36
+
37
  ## Pipeline
38
 
39
+ 1. **BiRefNet** - Generates coarse foreground mask (ONNX)
40
+ 2. **CorridorKey GreenFormer** - Refines alpha matte + extracts clean foreground (PyTorch on GPU, ONNX on CPU)
41
  3. **Compositing** - Despill, despeckle, composite on new background
42
 
43
  ## API
 
48
  ```bash
49
  curl -X POST "https://luminia-corridorkey.hf.space/gradio_api/call/process_video" \
50
  -H "Content-Type: application/json" \
51
+ -d '{"data": ["video.mp4", "1024", 5, "Hybrid (auto)", true, 400]}'
52
  ```
53
 
54
  **Step 2: Get result**
 
58
 
59
  ### MCP (Model Context Protocol)
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  **MCP Config:**
62
  ```json
63
  {
64
  "mcpServers": {
65
+ "corridorkey": {
66
  "url": "https://luminia-corridorkey.hf.space/gradio_api/mcp/"
67
  }
68
  }
app.py CHANGED
@@ -1,7 +1,8 @@
1
  """CorridorKey Green Screen Matting - HuggingFace Space.
2
 
3
- Self-contained Gradio app using ONNX Runtime for inference.
4
- Supports CPU (free tier) and GPU (community grant).
 
5
 
6
  Usage:
7
  python app.py # Launch Gradio UI
@@ -10,6 +11,7 @@ Usage:
10
 
11
  import os
12
  import sys
 
13
  import shutil
14
  import gc
15
  import time
@@ -28,6 +30,12 @@ import cv2
28
  import gradio as gr
29
  import onnxruntime as ort
30
 
 
 
 
 
 
 
31
  # Workaround: Gradio cache_examples bug with None outputs.
32
  _original_read_from_flag = gr.components.Component.read_from_flag
33
  def _patched_read_from_flag(self, payload):
@@ -52,13 +60,35 @@ CORRIDORKEY_MODELS = {
52
  "1024": os.path.join(MODELS_DIR, "corridorkey_1024.onnx"),
53
  "2048": os.path.join(MODELS_DIR, "corridorkey_2048.onnx"),
54
  }
 
 
55
  IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 1, 3)
56
  IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 1, 3)
57
  MAX_DURATION_CPU = 5
58
- MAX_DURATION_GPU = 30
59
- MAX_FRAMES = 150
60
  HAS_CUDA = "CUDAExecutionProvider" in ort.get_available_providers()
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  # ---------------------------------------------------------------------------
63
  # Color utilities (numpy-only)
64
  # ---------------------------------------------------------------------------
@@ -134,41 +164,167 @@ def fast_greenscreen_mask(frame_rgb_f32):
134
  return mask_f32, confidence
135
 
136
  # ---------------------------------------------------------------------------
137
- # Model loading
138
  # ---------------------------------------------------------------------------
139
  _birefnet_session = None
140
  _corridorkey_sessions = {}
 
 
 
 
 
 
 
 
141
 
142
  def _ort_opts():
143
  opts = ort.SessionOptions()
144
- opts.intra_op_num_threads = 2
145
- opts.inter_op_num_threads = 1
 
 
 
 
146
  opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
147
  opts.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
148
  opts.enable_mem_pattern = True
149
  return opts
150
 
151
- def get_birefnet():
 
 
 
 
 
 
 
 
 
 
152
  global _birefnet_session
153
- if _birefnet_session is None:
154
- logger.info("Downloading BiRefNet-Lite ONNX...")
155
- path = hf_hub_download(repo_id=BIREFNET_REPO, filename=BIREFNET_FILE)
156
- logger.info("Loading BiRefNet ONNX: %s", path)
157
- _birefnet_session = ort.InferenceSession(path, _ort_opts(), providers=["CPUExecutionProvider"])
 
 
 
 
158
  return _birefnet_session
159
 
160
- def get_corridorkey(resolution="1024"):
161
  global _corridorkey_sessions
162
  if resolution not in _corridorkey_sessions:
163
  onnx_path = CORRIDORKEY_MODELS.get(resolution)
164
  if not onnx_path or not os.path.exists(onnx_path):
165
  raise gr.Error(f"CorridorKey ONNX model for {resolution} not found.")
166
- logger.info("Loading CorridorKey ONNX (%s): %s", resolution, onnx_path)
167
- _corridorkey_sessions[resolution] = ort.InferenceSession(onnx_path, _ort_opts(), providers=["CPUExecutionProvider"])
 
168
  return _corridorkey_sessions[resolution]
169
 
170
  # ---------------------------------------------------------------------------
171
- # Per-frame inference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  # ---------------------------------------------------------------------------
173
  def birefnet_frame(session, image_rgb_uint8):
174
  h, w = image_rgb_uint8.shape[:2]
@@ -179,8 +335,9 @@ def birefnet_frame(session, image_rgb_uint8):
179
  pred = 1.0 / (1.0 + np.exp(-session.run(None, {inp.name: img})[-1]))
180
  return (cv2.resize(pred[0, 0], (w, h)) > 0.04).astype(np.float32)
181
 
182
- def corridorkey_frame(session, image_f32, mask_f32, img_size,
183
- despill_strength=0.5, auto_despeckle=True, despeckle_size=400):
 
184
  h, w = image_f32.shape[:2]
185
  img_r = cv2.resize(image_f32, (img_size, img_size))
186
  mask_r = cv2.resize(mask_f32, (img_size, img_size))[:, :, np.newaxis]
@@ -196,6 +353,70 @@ def corridorkey_frame(session, image_f32, mask_f32, img_size,
196
  fg = despill(fg, green_limit_mode="average", strength=despill_strength)
197
  return {"alpha": alpha, "fg": fg}
198
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  # ---------------------------------------------------------------------------
200
  # Video stitching
201
  # ---------------------------------------------------------------------------
@@ -213,20 +434,112 @@ def _stitch_ffmpeg(frame_dir, out_path, fps, pattern="%05d.png", pix_fmt="yuv420
213
  logger.warning("ffmpeg failed: %s", e)
214
  return False
215
 
 
216
  # ---------------------------------------------------------------------------
217
- # Main pipeline: generates ALL professional outputs
218
  # ---------------------------------------------------------------------------
219
- def process_video(video_path, resolution, despill_val, mask_mode,
220
- auto_despeckle, despeckle_size, progress=gr.Progress()):
221
- """Remove green screen background from video using CorridorKey AI matting.
222
- Returns: comp_video, matte_video, download_zip, status
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
  """
224
  if video_path is None:
225
  raise gr.Error("Please upload a video.")
226
 
227
- max_dur = MAX_DURATION_GPU if HAS_CUDA else MAX_DURATION_CPU
 
 
 
 
 
 
 
 
 
 
228
  img_size = int(resolution)
 
 
229
 
 
230
  cap = cv2.VideoCapture(video_path)
231
  fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
232
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
@@ -235,132 +548,325 @@ def process_video(video_path, resolution, despill_val, mask_mode,
235
  cap.release()
236
 
237
  if total_frames == 0:
238
- raise gr.Error("Could not read video frames. Check file format.")
239
  duration = total_frames / fps
240
  if duration > max_dur:
241
- raise gr.Error(f"Video too long ({duration:.1f}s). Max {max_dur}s on {'GPU' if HAS_CUDA else 'free CPU'} tier.")
242
-
243
  frames_to_process = min(total_frames, MAX_FRAMES)
244
- logger.info("Processing %d frames (%dx%d @ %.1f fps), resolution=%d, mask=%s",
245
- frames_to_process, w, h, fps, img_size, mask_mode)
246
 
247
- try:
248
- birefnet = None
249
- if mask_mode != "Fast (classical)":
250
- progress(0, desc="Loading BiRefNet...")
251
- birefnet = get_birefnet()
252
- progress(0.03, desc=f"Loading CorridorKey ({resolution})...")
253
- corridorkey = get_corridorkey(resolution)
254
- except Exception as e:
255
- raise gr.Error(f"Failed to load models: {e}")
 
 
 
 
 
 
256
 
257
- despill_strength = despill_val / 10.0
 
 
 
 
 
 
 
 
 
 
 
258
  tmpdir = tempfile.mkdtemp(prefix="ck_")
 
 
259
 
260
  try:
261
- # Output dirs matching original CorridorKey structure
262
- comp_dir = os.path.join(tmpdir, "Comp")
263
- fg_dir = os.path.join(tmpdir, "FG")
264
- matte_dir = os.path.join(tmpdir, "Matte")
265
- processed_dir = os.path.join(tmpdir, "Processed")
266
- for d in [comp_dir, fg_dir, matte_dir, processed_dir]:
267
- os.makedirs(d, exist_ok=True)
268
 
269
- bg_lin = srgb_to_linear(create_checkerboard(w, h))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
 
271
- cap = cv2.VideoCapture(video_path)
272
- frame_times = []
273
- total_start = time.time()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
 
275
- for i in range(frames_to_process):
276
- t0 = time.time()
277
- ret, frame_bgr = cap.read()
278
- if not ret:
279
- break
280
 
281
- frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
282
- frame_f32 = frame_rgb.astype(np.float32) / 255.0
283
-
284
- # Coarse mask
285
- if mask_mode == "Fast (classical)":
286
- mask, _ = fast_greenscreen_mask(frame_f32)
287
- if mask is None:
288
- raise gr.Error("Fast mask failed: no green screen detected. Try 'AI (BiRefNet)' mode.")
289
- elif mask_mode == "Hybrid (auto)":
290
- mask, conf = fast_greenscreen_mask(frame_f32)
291
- if mask is None or conf < 0.7:
292
- mask = birefnet_frame(birefnet, frame_rgb)
293
- else:
294
- mask = birefnet_frame(birefnet, frame_rgb)
295
-
296
- # CorridorKey inference
297
- result = corridorkey_frame(corridorkey, frame_f32, mask, img_size,
298
- despill_strength=despill_strength,
299
- auto_despeckle=auto_despeckle,
300
- despeckle_size=int(despeckle_size))
301
- alpha = result["alpha"]
302
- fg = result["fg"]
303
-
304
- # Ensure alpha is [H,W,1] and get 2D version
305
- if alpha.ndim == 2:
306
- alpha = alpha[:, :, np.newaxis]
307
- alpha_2d = alpha[:, :, 0]
308
-
309
- # -- Comp: composite on checkerboard (sRGB PNG) --
310
- fg_lin = srgb_to_linear(fg)
311
- comp = linear_to_srgb(composite_straight(fg_lin, bg_lin, alpha))
312
- cv2.imwrite(os.path.join(comp_dir, f"{i:05d}.png"),
313
- (np.clip(comp, 0, 1) * 255).astype(np.uint8)[:, :, ::-1])
314
-
315
- # -- FG: straight foreground, 100% opaque (sRGB PNG) --
316
- cv2.imwrite(os.path.join(fg_dir, f"{i:05d}.png"),
317
- (np.clip(fg, 0, 1) * 255).astype(np.uint8)[:, :, ::-1])
318
-
319
- # -- Matte: alpha channel (grayscale PNG) --
320
- cv2.imwrite(os.path.join(matte_dir, f"{i:05d}.png"),
321
- (np.clip(alpha_2d, 0, 1) * 255).astype(np.uint8))
322
-
323
- # -- Processed: premultiplied RGBA (PNG with transparency) --
324
- fg_premul_lin = premultiply(fg_lin, alpha)
325
- fg_premul_srgb = linear_to_srgb(fg_premul_lin)
326
- fg_premul_u8 = (np.clip(fg_premul_srgb, 0, 1) * 255).astype(np.uint8)
327
- alpha_u8 = (np.clip(alpha_2d, 0, 1) * 255).astype(np.uint8)
328
- rgba = np.concatenate([fg_premul_u8[:, :, ::-1], alpha_u8[:, :, np.newaxis]], axis=-1)
329
- cv2.imwrite(os.path.join(processed_dir, f"{i:05d}.png"), rgba)
330
-
331
- # Progress with ETA
332
- elapsed = time.time() - t0
333
- frame_times.append(elapsed)
334
- avg_t = np.mean(frame_times[-5:]) if len(frame_times) >= 2 else elapsed
335
- remaining = (frames_to_process - i - 1) * avg_t
336
- eta = f"{remaining/60:.1f}min" if remaining > 60 else f"{remaining:.0f}s"
337
- pct = 0.05 + 0.85 * (i + 1) / frames_to_process
338
- progress(pct, desc=f"Frame {i+1}/{frames_to_process} ({elapsed:.1f}s) | ~{eta} left")
339
-
340
- cap.release()
341
- total_elapsed = time.time() - total_start
342
- total_min = total_elapsed / 60
343
-
344
- # Stitch preview videos
345
- progress(0.92, desc="Stitching videos...")
 
 
 
 
 
 
 
 
 
346
  comp_video = os.path.join(tmpdir, "comp_preview.mp4")
347
  matte_video = os.path.join(tmpdir, "matte_preview.mp4")
348
- _stitch_ffmpeg(comp_dir, comp_video, fps, extra_args=["-crf", "18"])
349
- _stitch_ffmpeg(matte_dir, matte_video, fps, extra_args=["-crf", "18"])
 
350
 
351
- # Package full professional ZIP
 
352
  progress(0.96, desc="Packaging ZIP...")
353
  zip_path = os.path.join(tmpdir, "CorridorKey_Output.zip")
354
  with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as zf:
355
  for folder in ["Comp", "FG", "Matte", "Processed"]:
356
  src = os.path.join(tmpdir, folder)
357
- for f in sorted(os.listdir(src)):
358
- zf.write(os.path.join(src, f), f"Output/{folder}/{f}")
 
359
 
360
  progress(1.0, desc="Done!")
 
361
  n = len(frame_times)
362
  avg = np.mean(frame_times) if frame_times else 0
363
- status = f"Processed {n} frames in {total_min:.1f}min ({w}x{h}) at {img_size}px | {avg:.1f}s/frame"
 
 
 
364
 
365
  return (
366
  comp_video if os.path.exists(comp_video) else None,
@@ -372,8 +878,8 @@ def process_video(video_path, resolution, despill_val, mask_mode,
372
  except gr.Error:
373
  raise
374
  except Exception as e:
375
- logger.exception("Processing failed")
376
- raise gr.Error(f"Processing failed: {e}")
377
  finally:
378
  for d in ["Comp", "FG", "Matte", "Processed"]:
379
  p = os.path.join(tmpdir, d)
@@ -388,10 +894,9 @@ def process_video(video_path, resolution, despill_val, mask_mode,
388
  def process_example(video_path, resolution, despill, mask_mode, despeckle, despeckle_size):
389
  return process_video(video_path, resolution, despill, mask_mode, despeckle, despeckle_size)
390
 
391
- if HAS_CUDA:
392
- DESCRIPTION = "# CorridorKey Green Screen Matting\nRemove green backgrounds from video. Based on [CorridorKey](https://www.youtube.com/watch?v=3Ploi723hg4) by Corridor Digital. GPU mode: max {max_dur}s / {max_frames} frames.".format(max_dur=MAX_DURATION_GPU, max_frames=MAX_FRAMES)
393
- else:
394
- DESCRIPTION = "# CorridorKey Green Screen Matting\nRemove green backgrounds from video. Based on [CorridorKey](https://www.youtube.com/watch?v=3Ploi723hg4) by Corridor Digital. ~37min for 5s clip on free CPU."
395
 
396
  with gr.Blocks(title="CorridorKey") as demo:
397
  gr.Markdown(DESCRIPTION)
@@ -403,7 +908,7 @@ with gr.Blocks(title="CorridorKey") as demo:
403
  resolution = gr.Radio(
404
  choices=["1024", "2048"], value="1024",
405
  label="Processing Resolution",
406
- info="1024 = balanced (~8s/frame CPU), 2048 = max quality (fast on GPU)"
407
  )
408
  mask_mode = gr.Radio(
409
  choices=["Hybrid (auto)", "AI (BiRefNet)", "Fast (classical)"],
 
1
  """CorridorKey Green Screen Matting - HuggingFace Space.
2
 
3
+ Self-contained Gradio app with dual inference paths:
4
+ - GPU (ZeroGPU H200): PyTorch batched inference via GreenFormer
5
+ - CPU (fallback): ONNX Runtime sequential inference
6
 
7
  Usage:
8
  python app.py # Launch Gradio UI
 
11
 
12
  import os
13
  import sys
14
+ import math
15
  import shutil
16
  import gc
17
  import time
 
30
  import gradio as gr
31
  import onnxruntime as ort
32
 
33
+ try:
34
+ import spaces
35
+ HAS_SPACES = True
36
+ except ImportError:
37
+ HAS_SPACES = False
38
+
39
  # Workaround: Gradio cache_examples bug with None outputs.
40
  _original_read_from_flag = gr.components.Component.read_from_flag
41
  def _patched_read_from_flag(self, payload):
 
60
  "1024": os.path.join(MODELS_DIR, "corridorkey_1024.onnx"),
61
  "2048": os.path.join(MODELS_DIR, "corridorkey_2048.onnx"),
62
  }
63
+ CORRIDORKEY_PTH_REPO = "nikopueringer/CorridorKey_v1.0"
64
+ CORRIDORKEY_PTH_FILE = "CorridorKey_v1.0.pth"
65
  IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 1, 3)
66
  IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 1, 3)
67
  MAX_DURATION_CPU = 5
68
+ MAX_DURATION_GPU = 60
69
+ MAX_FRAMES = 1800
70
  HAS_CUDA = "CUDAExecutionProvider" in ort.get_available_providers()
71
 
72
# ---------------------------------------------------------------------------
# Preload model files at startup (OUTSIDE GPU function — don't waste GPU time on downloads)
# ---------------------------------------------------------------------------
# Both downloads are best-effort: on failure we keep None and the loaders
# (get_birefnet / _load_greenformer) retry hf_hub_download lazily.
logger.info("Preloading model files at startup...")
_preloaded_birefnet_path = None   # local cache path of the BiRefNet ONNX file, or None
_preloaded_pth_path = None        # local cache path of the CorridorKey .pth checkpoint, or None
try:
    _preloaded_birefnet_path = hf_hub_download(repo_id=BIREFNET_REPO, filename=BIREFNET_FILE)
    logger.info("BiRefNet cached: %s", _preloaded_birefnet_path)
except Exception as e:
    # Network/auth hiccups at startup must not kill the Space; retried on first use.
    logger.warning("BiRefNet preload failed (will retry later): %s", e)
try:
    _preloaded_pth_path = hf_hub_download(repo_id=CORRIDORKEY_PTH_REPO, filename=CORRIDORKEY_PTH_FILE)
    logger.info("CorridorKey.pth cached: %s", _preloaded_pth_path)
except Exception as e:
    logger.warning("CorridorKey.pth preload failed (will retry later): %s", e)

# Batch sizes for GPU inference (conservative for H200 80GB)
GPU_BATCH_SIZES = {"1024": 32, "2048": 16}  # 2048 uses only 5.7GB/batch=2, so 16 easily fits in 69.8GB
91
+
92
  # ---------------------------------------------------------------------------
93
  # Color utilities (numpy-only)
94
  # ---------------------------------------------------------------------------
 
164
  return mask_f32, confidence
165
 
166
  # ---------------------------------------------------------------------------
167
+ # ONNX model loading (CPU fallback + BiRefNet)
168
  # ---------------------------------------------------------------------------
169
  _birefnet_session = None
170
  _corridorkey_sessions = {}
171
+ _sessions_on_gpu = False
172
+
173
def _get_providers():
    """Return the preferred ONNX Runtime provider list for this process.

    Inside a @spaces.GPU call CUDA becomes visible, so prefer it (with a
    CPU fallback entry); otherwise run purely on CPU.
    """
    cuda_available = "CUDAExecutionProvider" in ort.get_available_providers()
    if not cuda_available:
        return ["CPUExecutionProvider"]
    return ["CUDAExecutionProvider", "CPUExecutionProvider"]
179
 
180
def _ort_opts():
    """Build ONNX Runtime session options tuned for the current host.

    When CUDA is available the thread counts are left at 0 so ORT picks
    its own defaults; on CPU-only hosts threads are capped to fit the
    small free-tier machines.
    """
    cuda_present = "CUDAExecutionProvider" in ort.get_available_providers()
    opts = ort.SessionOptions()
    # 0 = let ORT decide; small fixed counts on shared CPU hosts.
    opts.intra_op_num_threads = 0 if cuda_present else 2
    opts.inter_op_num_threads = 0 if cuda_present else 1
    opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    opts.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
    opts.enable_mem_pattern = True
    return opts
192
 
193
def _ensure_gpu_sessions():
    """Invalidate cached CPU ONNX sessions once CUDA appears (ZeroGPU attach).

    ZeroGPU only grants CUDA inside @spaces.GPU calls; any session built
    before that is CPU-bound. Dropping the caches here lets the getter
    functions rebuild with the GPU providers on next use.
    """
    global _birefnet_session, _corridorkey_sessions, _sessions_on_gpu
    if _sessions_on_gpu:
        return  # already migrated once; nothing to do
    if "CUDAExecutionProvider" not in ort.get_available_providers():
        return  # still CPU-only; keep existing sessions
    logger.info("CUDA available! Reloading ONNX sessions on GPU...")
    _birefnet_session = None
    _corridorkey_sessions = {}
    _sessions_on_gpu = True
202
+
203
def get_birefnet(force_cpu=False):
    """Return a cached BiRefNet ONNX ``InferenceSession``.

    Args:
        force_cpu: when True, return a session pinned to the CPU provider
            regardless of CUDA availability.

    Fix vs. previous version: ``force_cpu=True`` used to rebuild a new
    session on *every* call AND overwrite the shared cache, so a later
    ``force_cpu=False`` call silently received a CPU-bound session. The
    forced-CPU session is now cached separately (on the function object),
    and the download fallback only runs when a session is actually built.
    """
    global _birefnet_session
    if force_cpu:
        session = getattr(get_birefnet, "_cpu_session", None)
        if session is None:
            path = _preloaded_birefnet_path or hf_hub_download(repo_id=BIREFNET_REPO, filename=BIREFNET_FILE)
            logger.info("Loading BiRefNet ONNX: %s (providers: %s)", path, ["CPUExecutionProvider"])
            opts = _ort_opts()
            # Conservative thread counts for the CPU-only fallback session.
            opts.intra_op_num_threads = 2
            opts.inter_op_num_threads = 1
            session = ort.InferenceSession(path, opts, providers=["CPUExecutionProvider"])
            get_birefnet._cpu_session = session
        return session
    if _birefnet_session is None:
        path = _preloaded_birefnet_path or hf_hub_download(repo_id=BIREFNET_REPO, filename=BIREFNET_FILE)
        providers = _get_providers()
        logger.info("Loading BiRefNet ONNX: %s (providers: %s)", path, providers)
        _birefnet_session = ort.InferenceSession(path, _ort_opts(), providers=providers)
    return _birefnet_session
215
 
216
def get_corridorkey_onnx(resolution="1024"):
    """Return the CorridorKey ONNX session for *resolution*, creating it lazily.

    Sessions are cached per resolution key ("1024" / "2048").

    Raises:
        gr.Error: when no exported ONNX file exists for *resolution*.
    """
    global _corridorkey_sessions
    session = _corridorkey_sessions.get(resolution)
    if session is None:
        onnx_path = CORRIDORKEY_MODELS.get(resolution)
        if not (onnx_path and os.path.exists(onnx_path)):
            raise gr.Error(f"CorridorKey ONNX model for {resolution} not found.")
        providers = _get_providers()
        logger.info("Loading CorridorKey ONNX (%s): %s (providers: %s)", resolution, onnx_path, providers)
        session = ort.InferenceSession(onnx_path, _ort_opts(), providers=providers)
        _corridorkey_sessions[resolution] = session
    return session
226
 
227
  # ---------------------------------------------------------------------------
228
+ # PyTorch model loading (GPU path)
229
+ # ---------------------------------------------------------------------------
230
+ _pytorch_model = None
231
+ _pytorch_model_size = None
232
+
233
def _load_greenformer(img_size):
    """Load the GreenFormer PyTorch model for GPU inference.

    Downloads (or reuses the preloaded) .pth checkpoint, adapts its state
    dict to the current img_size, and returns the model on CUDA in FP16.

    Args:
        img_size: square model input resolution (1024 or 2048).

    Returns:
        GreenFormer module in eval mode on CUDA, FP16 weights.
    """
    import torch
    import torch.nn.functional as F
    from CorridorKeyModule.core.model_transformer import GreenFormer

    # Prefer the path cached at startup; fall back to a fresh download.
    checkpoint_path = _preloaded_pth_path or hf_hub_download(repo_id=CORRIDORKEY_PTH_REPO, filename=CORRIDORKEY_PTH_FILE)
    logger.info("Using checkpoint: %s", checkpoint_path)

    logger.info("Initializing GreenFormer (img_size=%d)...", img_size)
    model = GreenFormer(
        encoder_name="hiera_base_plus_224.mae_in1k_ft_in1k",
        img_size=img_size,
        use_refiner=True,
    )

    # Load weights. weights_only=True avoids arbitrary-code pickle execution.
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    state_dict = checkpoint.get("state_dict", checkpoint)

    # Fix compiled model prefix & handle PosEmbed mismatch
    new_state_dict = {}
    model_state = model.state_dict()
    for k, v in state_dict.items():
        if k.startswith("_orig_mod."):
            # Checkpoint was saved from a torch.compile()d model; strip the
            # wrapper prefix (len("_orig_mod.") == 10).
            k = k[10:]
        if "pos_embed" in k and k in model_state:
            if v.shape != model_state[k].shape:
                # Positional embeddings were trained at a different grid size;
                # bicubically resample [1, N, C] -> [1, N', C] via a square image.
                logger.info("Resizing %s from %s to %s", k, v.shape, model_state[k].shape)
                N_src = v.shape[1]
                C = v.shape[2]
                # assumes token counts are perfect squares (square patch grid) — TODO confirm
                grid_src = int(math.sqrt(N_src))
                grid_dst = int(math.sqrt(model_state[k].shape[1]))
                v_img = v.permute(0, 2, 1).view(1, C, grid_src, grid_src)
                v_resized = F.interpolate(v_img, size=(grid_dst, grid_dst), mode="bicubic", align_corners=False)
                v = v_resized.flatten(2).transpose(1, 2)
        new_state_dict[k] = v

    # strict=False tolerates refiner/aux keys absent from the checkpoint.
    missing, unexpected = model.load_state_dict(new_state_dict, strict=False)
    if missing:
        logger.warning("Missing keys: %s", missing)
    if unexpected:
        logger.warning("Unexpected keys: %s", unexpected)

    model.eval()
    model = model.cuda().half()  # FP16 for speed on H200

    logger.info("Model loaded as FP16")
    # Log which attention backend will be used (informational only).
    try:
        import flash_attn
        logger.info("flash-attn v%s installed (prebuilt wheel)", getattr(flash_attn, '__version__', '?'))
    except ImportError:
        logger.info("flash-attn not available (using PyTorch SDPA)")
    logger.info("SDPA backends: flash=%s, mem_efficient=%s, math=%s",
                torch.backends.cuda.flash_sdp_enabled(),
                torch.backends.cuda.mem_efficient_sdp_enabled(),
                torch.backends.cuda.math_sdp_enabled())

    # Skip torch.compile on ZeroGPU — the 37s warmup eats too much of the 120s budget.
    if not HAS_SPACES and sys.platform in ("linux", "win32"):
        try:
            compiled = torch.compile(model)
            # One dummy forward to trigger compilation up front, outside the
            # timed inference path. Input is [B, 4, H, W]: RGB + mask channel.
            dummy = torch.zeros(1, 4, img_size, img_size, dtype=torch.float16, device="cuda")
            with torch.inference_mode():
                compiled(dummy)
            model = compiled
            logger.info("torch.compile() succeeded")
        except Exception as e:
            # Compilation is a best-effort optimization; eager mode is fine.
            logger.warning("torch.compile() failed, using eager mode: %s", e)
            torch.cuda.empty_cache()
    else:
        logger.info("Skipping torch.compile() (ZeroGPU: saving GPU time for inference)")

    logger.info("GreenFormer loaded on CUDA (img_size=%d)", img_size)
    return model
308
+
309
+
310
def get_pytorch_model(img_size):
    """Return the cached GreenFormer for *img_size*, (re)loading on a miss.

    When the requested resolution differs from the cached one, the old
    model is freed first so its VRAM is returned before the new load.
    """
    global _pytorch_model, _pytorch_model_size
    cache_hit = _pytorch_model is not None and _pytorch_model_size == img_size
    if not cache_hit:
        if _pytorch_model is not None:
            # Release the previous model's CUDA memory before loading anew.
            import torch
            del _pytorch_model
            _pytorch_model = None
            torch.cuda.empty_cache()
            gc.collect()
        _pytorch_model = _load_greenformer(img_size)
        _pytorch_model_size = img_size
    return _pytorch_model
324
+
325
+
326
+ # ---------------------------------------------------------------------------
327
+ # Per-frame inference: ONNX (CPU fallback)
328
  # ---------------------------------------------------------------------------
329
  def birefnet_frame(session, image_rgb_uint8):
330
  h, w = image_rgb_uint8.shape[:2]
 
335
  pred = 1.0 / (1.0 + np.exp(-session.run(None, {inp.name: img})[-1]))
336
  return (cv2.resize(pred[0, 0], (w, h)) > 0.04).astype(np.float32)
337
 
338
+ def corridorkey_frame_onnx(session, image_f32, mask_f32, img_size,
339
+ despill_strength=0.5, auto_despeckle=True, despeckle_size=400):
340
+ """ONNX inference for a single frame (CPU path)."""
341
  h, w = image_f32.shape[:2]
342
  img_r = cv2.resize(image_f32, (img_size, img_size))
343
  mask_r = cv2.resize(mask_f32, (img_size, img_size))[:, :, np.newaxis]
 
353
  fg = despill(fg, green_limit_mode="average", strength=despill_strength)
354
  return {"alpha": alpha, "fg": fg}
355
 
356
+
357
+ # ---------------------------------------------------------------------------
358
+ # Batched inference: PyTorch (GPU path)
359
+ # ---------------------------------------------------------------------------
360
def corridorkey_batch_pytorch(model, images_f32, masks_f32, img_size,
                              despill_strength=0.5, auto_despeckle=True, despeckle_size=400):
    """PyTorch batched inference for multiple frames on GPU.

    Args:
        model: GreenFormer model on CUDA (FP16 — see _load_greenformer)
        images_f32: list of [H, W, 3] float32 numpy arrays (0-1, sRGB)
        masks_f32: list of [H, W] float32 numpy arrays (0-1)
        img_size: model input resolution (1024 or 2048)
        despill_strength: 0-1 strength passed to despill()
        auto_despeckle: when True, clean_matte() removes small alpha islands
        despeckle_size: minimum connected-component area kept by clean_matte

    Returns:
        list of dicts with 'alpha' [H,W,1] and 'fg' [H,W,3], one per input,
        in the same order, resized back to each frame's original size.
    """
    import torch

    batch_size = len(images_f32)
    if batch_size == 0:
        return []

    # Store original sizes per frame (frames may, in principle, differ).
    orig_sizes = [(img.shape[1], img.shape[0]) for img in images_f32]  # (w, h)

    # Preprocess: resize, normalize, concatenate into batch tensor.
    # Input layout per frame is [4, H, W]: ImageNet-normalized RGB + raw mask.
    batch_inputs = []
    for img, mask in zip(images_f32, masks_f32):
        img_r = cv2.resize(img, (img_size, img_size))
        mask_r = cv2.resize(mask, (img_size, img_size))[:, :, np.newaxis]
        inp = np.concatenate([(img_r - IMAGENET_MEAN) / IMAGENET_STD, mask_r], axis=-1)
        batch_inputs.append(inp.transpose(2, 0, 1))  # HWC -> CHW: [4, H, W]

    batch_np = np.stack(batch_inputs, axis=0).astype(np.float32)  # [B, 4, H, W]
    batch_tensor = torch.from_numpy(batch_np).cuda().half()  # FP16 input

    # Forward pass — model is FP16, input is FP16, no autocast needed
    with torch.inference_mode():
        out = model(batch_tensor)

    # Extract results; upcast to float32 before the CPU round-trip.
    alphas_gpu = out["alpha"].float().cpu().numpy()  # [B, 1, H, W]
    fgs_gpu = out["fg"].float().cpu().numpy()        # [B, 3, H, W]

    del batch_tensor
    # Don't empty cache per batch - too expensive. Let PyTorch manage.

    # Postprocess each frame: resize back, optional despeckle, despill.
    results = []
    for i in range(batch_size):
        w, h = orig_sizes[i]
        # CHW -> HWC for cv2; Lanczos keeps matte edges sharp on upscale.
        alpha = cv2.resize(alphas_gpu[i].transpose(1, 2, 0), (w, h), interpolation=cv2.INTER_LANCZOS4)
        fg = cv2.resize(fgs_gpu[i].transpose(1, 2, 0), (w, h), interpolation=cv2.INTER_LANCZOS4)
        if alpha.ndim == 2:
            # cv2.resize drops a trailing singleton channel; restore [H, W, 1].
            alpha = alpha[:, :, np.newaxis]
        if auto_despeckle:
            alpha = clean_matte(alpha, area_threshold=despeckle_size, dilation=25, blur_size=5)
        fg = despill(fg, green_limit_mode="average", strength=despill_strength)
        results.append({"alpha": alpha, "fg": fg})

    return results
418
+
419
+
420
  # ---------------------------------------------------------------------------
421
  # Video stitching
422
  # ---------------------------------------------------------------------------
 
434
  logger.warning("ffmpeg failed: %s", e)
435
  return False
436
 
437
+
438
  # ---------------------------------------------------------------------------
439
+ # Output writing helper
440
  # ---------------------------------------------------------------------------
441
+ # Fastest PNG params: compression 1 (instead of default 3)
442
+ _PNG_FAST = [cv2.IMWRITE_PNG_COMPRESSION, 1]
443
+ # JPEG for opaque outputs (comp/fg) 10x faster than PNG at 4K
444
+ _JPG_QUALITY = [cv2.IMWRITE_JPEG_QUALITY, 95]
445
+
446
+
447
def _write_frame_fast(i, alpha, fg, w, h, bg_lin, comp_dir, matte_dir, fg_dir):
    """Fast write for one frame: Comp (JPEG) + FG (JPEG) + Matte (PNG).

    Skips the heavy premultiplied-RGBA PNG; used inside the GPU window
    where wall-clock time is budgeted. ``w``/``h`` are accepted for
    signature parity with the other writers — the arrays carry their own
    dimensions.
    """
    if alpha.ndim == 2:
        alpha = alpha[:, :, np.newaxis]

    def as_u8(img):
        # float [0, 1] -> uint8 [0, 255]
        return (np.clip(img, 0, 1) * 255).astype(np.uint8)

    stem = f"{i:05d}"
    fg_lin = srgb_to_linear(fg)
    # Composite over the checkerboard in linear light, preview in sRGB.
    comp = linear_to_srgb(composite_straight(fg_lin, bg_lin, alpha))
    # [:, :, ::-1] flips RGB -> BGR for OpenCV.
    cv2.imwrite(os.path.join(comp_dir, stem + ".jpg"), as_u8(comp)[:, :, ::-1], _JPG_QUALITY)
    cv2.imwrite(os.path.join(fg_dir, stem + ".jpg"), as_u8(fg)[:, :, ::-1], _JPG_QUALITY)
    cv2.imwrite(os.path.join(matte_dir, stem + ".png"), as_u8(alpha[:, :, 0]), _PNG_FAST)
460
+
461
+
462
def _write_frame_deferred(i, raw_path, w, h, bg_lin, fg_dir, processed_dir):
    """Deferred write: FG (JPEG) + Processed (premultiplied RGBA PNG).

    Runs after the GPU has been released. Reads one frame's raw arrays
    from the ``.npz`` at *raw_path* and deletes the file when done.

    Args:
        i: frame index, used for zero-padded filenames.
        raw_path: path to an ``.npz`` containing "alpha" and "fg" arrays.
        w, h, bg_lin: accepted for writer-signature parity; unused here.
        fg_dir, processed_dir: existing output directories.

    Fix: the npz handle is now closed (context manager) before
    ``os.remove`` — previously it was left open, leaking a file handle
    and making the deletion fail on Windows.
    """
    with np.load(raw_path) as d:
        alpha, fg = d["alpha"], d["fg"]
    if alpha.ndim == 2:
        alpha = alpha[:, :, np.newaxis]
    alpha_2d = alpha[:, :, 0]
    # Straight (un-premultiplied) foreground as JPEG; RGB -> BGR for OpenCV.
    cv2.imwrite(os.path.join(fg_dir, f"{i:05d}.jpg"),
                (np.clip(fg, 0, 1) * 255).astype(np.uint8)[:, :, ::-1], _JPG_QUALITY)
    # Premultiply in linear light, convert back to sRGB for the RGBA PNG.
    fg_lin = srgb_to_linear(fg)
    fg_premul = premultiply(fg_lin, alpha)
    fg_premul_srgb = linear_to_srgb(fg_premul)
    fg_u8 = (np.clip(fg_premul_srgb, 0, 1) * 255).astype(np.uint8)
    a_u8 = (np.clip(alpha_2d, 0, 1) * 255).astype(np.uint8)
    rgba = np.concatenate([fg_u8[:, :, ::-1], a_u8[:, :, np.newaxis]], axis=-1)
    cv2.imwrite(os.path.join(processed_dir, f"{i:05d}.png"), rgba, _PNG_FAST)
    os.remove(raw_path)  # cleanup: reclaim temp disk as we go
479
+
480
+
481
def _write_frame_outputs(i, alpha, fg, w, h, bg_lin, comp_dir, fg_dir, matte_dir, processed_dir):
    """Full write for one frame: Comp + FG (JPEG), Matte + Processed RGBA (PNG).

    Used by the CPU path, where there is no GPU-time budget to respect.
    ``w``/``h`` are accepted for signature parity; the arrays carry their
    own dimensions.
    """
    if alpha.ndim == 2:
        alpha = alpha[:, :, np.newaxis]
    alpha_2d = alpha[:, :, 0]
    stem = f"{i:05d}"

    def as_u8(arr):
        # float [0, 1] -> uint8 [0, 255]
        return (np.clip(arr, 0, 1) * 255).astype(np.uint8)

    # Composite over checkerboard in linear light, preview as sRGB JPEG.
    fg_lin = srgb_to_linear(fg)
    comp = linear_to_srgb(composite_straight(fg_lin, bg_lin, alpha))
    # [:, :, ::-1] flips RGB -> BGR for OpenCV.
    cv2.imwrite(os.path.join(comp_dir, stem + ".jpg"), as_u8(comp)[:, :, ::-1], _JPG_QUALITY)
    cv2.imwrite(os.path.join(fg_dir, stem + ".jpg"), as_u8(fg)[:, :, ::-1], _JPG_QUALITY)
    cv2.imwrite(os.path.join(matte_dir, stem + ".png"), as_u8(alpha_2d), _PNG_FAST)

    # Premultiplied RGBA PNG for compositing applications.
    fg_u8 = as_u8(linear_to_srgb(premultiply(fg_lin, alpha)))
    a_u8 = as_u8(alpha_2d)
    rgba = np.concatenate([fg_u8[:, :, ::-1], a_u8[:, :, np.newaxis]], axis=-1)
    cv2.imwrite(os.path.join(processed_dir, stem + ".png"), rgba, _PNG_FAST)
500
+
501
+
502
+ # ---------------------------------------------------------------------------
503
+ # Shared storage: GPU function stores results here instead of returning them.
504
+ # This avoids ZeroGPU serializing gigabytes of numpy arrays on return.
505
+ # ---------------------------------------------------------------------------
506
+ _shared_results = {"data": None}
507
+
508
+ # ---------------------------------------------------------------------------
509
+ # Main pipeline
510
+ # ---------------------------------------------------------------------------
511
def _gpu_decorator(fn):
    """Wrap *fn* with ``spaces.GPU(duration=120)`` when on a HF Space.

    Outside Spaces (the ``spaces`` package is absent) the function is
    returned unchanged so the CPU fallback path still works.
    """
    return spaces.GPU(duration=120)(fn) if HAS_SPACES else fn
515
+
516
+
517
+ @_gpu_decorator
518
+ def _gpu_phase(video_path, resolution, despill_val, mask_mode,
519
+ auto_despeckle, despeckle_size, progress=gr.Progress(),
520
+ precompute_dir=None, precompute_count=0):
521
+ """ALL GPU work: load models, read video, generate masks, run inference.
522
+ Returns raw numpy results in RAM. No disk I/O.
523
  """
524
  if video_path is None:
525
  raise gr.Error("Please upload a video.")
526
 
527
+ _ensure_gpu_sessions()
528
+
529
+ try:
530
+ import torch
531
+ has_torch_cuda = torch.cuda.is_available()
532
+ except ImportError:
533
+ has_torch_cuda = False
534
+ use_gpu = has_torch_cuda
535
+ logger.info("[GPU phase] CUDA=%s, mode=%s", has_torch_cuda,
536
+ "PyTorch batched" if use_gpu else "ONNX sequential")
537
+
538
  img_size = int(resolution)
539
+ max_dur = MAX_DURATION_GPU if use_gpu else MAX_DURATION_CPU
540
+ despill_strength = despill_val / 10.0
541
 
542
+ # Read video metadata
543
  cap = cv2.VideoCapture(video_path)
544
  fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
545
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
 
548
  cap.release()
549
 
550
  if total_frames == 0:
551
+ raise gr.Error("Could not read video frames.")
552
  duration = total_frames / fps
553
  if duration > max_dur:
554
+ raise gr.Error(f"Video too long ({duration:.1f}s). Max {max_dur}s.")
 
555
  frames_to_process = min(total_frames, MAX_FRAMES)
 
 
556
 
557
+ # Load BiRefNet only if masks need it (skip if all precomputed)
558
+ birefnet = None
559
+ needs_birefnet = precompute_dir is None or precompute_count == 0
560
+ if not needs_birefnet and mask_mode != "Fast (classical)":
561
+ # Check if any frames need BiRefNet (missing mask files)
562
+ for i in range(min(frames_to_process, precompute_count)):
563
+ if not os.path.exists(os.path.join(precompute_dir, f"mask_{i:05d}.npy")):
564
+ needs_birefnet = True
565
+ break
566
+ if needs_birefnet:
567
+ progress(0.02, desc="Loading BiRefNet...")
568
+ birefnet = get_birefnet()
569
+ logger.info("BiRefNet loaded (needed for some frames)")
570
+ else:
571
+ logger.info("Skipping BiRefNet load (all masks precomputed)")
572
 
573
+ batch_size = GPU_BATCH_SIZES.get(resolution, 16) if use_gpu else 1
574
+ if use_gpu:
575
+ progress(0.05, desc=f"Loading GreenFormer ({resolution})...")
576
+ pytorch_model = get_pytorch_model(img_size)
577
+ else:
578
+ progress(0.05, desc=f"Loading CorridorKey ONNX ({resolution})...")
579
+ corridorkey_onnx = get_corridorkey_onnx(resolution)
580
+
581
+ logger.info("[GPU phase] %d frames (%dx%d @ %.1ffps), res=%d, mask=%s, batch=%d",
582
+ frames_to_process, w, h, fps, img_size, mask_mode, batch_size)
583
+
584
+ # Read all frames + generate masks + run inference
585
  tmpdir = tempfile.mkdtemp(prefix="ck_")
586
+ frame_times = []
587
+ total_start = time.time()
588
 
589
  try:
590
+ cap = cv2.VideoCapture(video_path)
 
 
 
 
 
 
591
 
592
+ if use_gpu:
593
+ import torch
594
+ vram_total = torch.cuda.get_device_properties(0).total_memory / 1024**3
595
+ logger.info("VRAM: %.1f/%.1fGB",
596
+ torch.cuda.memory_allocated() / 1024**3, vram_total)
597
+
598
+ all_results = []
599
+ frame_idx = 0
600
+
601
+ # Load precomputed frames from disk (no serialization overhead)
602
+ use_precomputed = precompute_dir is not None and precompute_count > 0
603
+
604
+ while frame_idx < frames_to_process:
605
+ t_batch = time.time()
606
+
607
+ batch_images, batch_masks, batch_indices = [], [], []
608
+ t_mask = 0
609
+ fast_n, biref_n = 0, 0
610
+
611
+ for _ in range(batch_size):
612
+ if frame_idx >= frames_to_process:
613
+ break
614
+
615
+ if use_precomputed:
616
+ frame_f32 = np.load(os.path.join(precompute_dir, f"frame_{frame_idx:05d}.npy"))
617
+ mask_path = os.path.join(precompute_dir, f"mask_{frame_idx:05d}.npy")
618
+ if os.path.exists(mask_path):
619
+ mask = np.load(mask_path)
620
+ fast_n += 1
621
+ else:
622
+ # BiRefNet fallback — load original RGB, run on GPU
623
+ rgb_path = os.path.join(precompute_dir, f"rgb_{frame_idx:05d}.npy")
624
+ frame_rgb = np.load(rgb_path)
625
+ tm = time.time()
626
+ mask = birefnet_frame(birefnet, frame_rgb)
627
+ t_mask += time.time() - tm
628
+ biref_n += 1
629
+ else:
630
+ ret, frame_bgr = cap.read()
631
+ if not ret:
632
+ break
633
+ frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
634
+ frame_f32 = frame_rgb.astype(np.float32) / 255.0
635
+ tm = time.time()
636
+ if mask_mode == "Fast (classical)":
637
+ mask, _ = fast_greenscreen_mask(frame_f32)
638
+ fast_n += 1
639
+ elif mask_mode == "Hybrid (auto)":
640
+ mask, conf = fast_greenscreen_mask(frame_f32)
641
+ if mask is None or conf < 0.7:
642
+ mask = birefnet_frame(birefnet, frame_rgb)
643
+ biref_n += 1
644
+ else:
645
+ fast_n += 1
646
+ else:
647
+ mask = birefnet_frame(birefnet, frame_rgb)
648
+ biref_n += 1
649
+ t_mask += time.time() - tm
650
+
651
+ batch_images.append(frame_f32)
652
+ batch_masks.append(mask)
653
+ batch_indices.append(frame_idx)
654
+ frame_idx += 1
655
+
656
+ if not batch_images:
657
+ break
658
+
659
+ # Batched GPU inference
660
+ t_inf = time.time()
661
+ results = corridorkey_batch_pytorch(
662
+ pytorch_model, batch_images, batch_masks, img_size,
663
+ despill_strength=despill_strength,
664
+ auto_despeckle=auto_despeckle,
665
+ despeckle_size=int(despeckle_size),
666
+ )
667
+ t_inf = time.time() - t_inf
668
+
669
+ for j, result in enumerate(results):
670
+ all_results.append((batch_indices[j], result["alpha"], result["fg"]))
671
+
672
+ n = len(batch_images)
673
+ elapsed = time.time() - t_batch
674
+ vram_peak = torch.cuda.max_memory_allocated() / 1024**3
675
+ logger.info("Batch %d: mask=%.1fs(fast=%d,biref=%d) infer=%.1fs total=%.1fs(%.2fs/fr) VRAM=%.1fGB",
676
+ n, t_mask, fast_n, biref_n, t_inf, elapsed, elapsed/n, vram_peak)
677
+
678
+ per_frame = elapsed / n
679
+ frame_times.extend([per_frame] * n)
680
+ remaining = (frames_to_process - frame_idx) * (np.mean(frame_times[-20:]) if len(frame_times) > 1 else per_frame)
681
+ progress(0.10 + 0.75 * frame_idx / frames_to_process,
682
+ desc=f"Frame {frame_idx}/{frames_to_process} ({per_frame:.2f}s/fr) ~{remaining:.0f}s left")
683
+
684
+ cap.release()
685
+ gpu_elapsed = time.time() - total_start
686
+ logger.info("[GPU phase] done: %d frames in %.1fs (%.2fs/fr)",
687
+ len(all_results), gpu_elapsed, gpu_elapsed / max(len(all_results), 1))
688
+
689
+ # FAST WRITE inside GPU: only comp (JPEG) + matte (PNG) + raw numpy.
690
+ # FG + Processed written AFTER GPU release (deferred).
691
+ from concurrent.futures import ThreadPoolExecutor
692
+ bg_lin = srgb_to_linear(create_checkerboard(w, h))
693
+ comp_dir = os.path.join(tmpdir, "Comp")
694
+ matte_dir = os.path.join(tmpdir, "Matte")
695
+ fg_dir = os.path.join(tmpdir, "FG")
696
+ processed_dir = os.path.join(tmpdir, "Processed")
697
+ for d in [comp_dir, fg_dir, matte_dir, processed_dir]:
698
+ os.makedirs(d, exist_ok=True)
699
+
700
+ t_write = time.time()
701
+ progress(0.86, desc="Writing preview frames...")
702
+ with ThreadPoolExecutor(max_workers=os.cpu_count() or 4) as pool:
703
+ futs = [pool.submit(_write_frame_fast, idx, alpha, fg, w, h, bg_lin,
704
+ comp_dir, matte_dir, fg_dir)
705
+ for idx, alpha, fg in all_results]
706
+ for f in futs:
707
+ f.result()
708
+ del all_results
709
+ gc.collect()
710
+ logger.info("[GPU phase] Fast write in %.1fs", time.time() - t_write)
711
+
712
+ return {
713
+ "results": "written", "frame_times": frame_times,
714
+ "use_gpu": True, "batch_size": batch_size,
715
+ "w": w, "h": h, "fps": fps, "tmpdir": tmpdir,
716
+ }
717
+
718
+ else:
719
+ # CPU PATH: sequential ONNX + inline writes (no GPU budget concern)
720
+ bg_lin = srgb_to_linear(create_checkerboard(w, h))
721
+ comp_dir, fg_dir = os.path.join(tmpdir, "Comp"), os.path.join(tmpdir, "FG")
722
+ matte_dir, processed_dir = os.path.join(tmpdir, "Matte"), os.path.join(tmpdir, "Processed")
723
+ for d in [comp_dir, fg_dir, matte_dir, processed_dir]:
724
+ os.makedirs(d, exist_ok=True)
725
+
726
+ for i in range(frames_to_process):
727
+ t0 = time.time()
728
+ ret, frame_bgr = cap.read()
729
+ if not ret:
730
+ break
731
+ frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
732
+ frame_f32 = frame_rgb.astype(np.float32) / 255.0
733
+
734
+ if mask_mode == "Fast (classical)":
735
+ mask, _ = fast_greenscreen_mask(frame_f32)
736
+ if mask is None:
737
+ raise gr.Error("Fast mask failed. Try 'AI (BiRefNet)' mode.")
738
+ elif mask_mode == "Hybrid (auto)":
739
+ mask, conf = fast_greenscreen_mask(frame_f32)
740
+ if mask is None or conf < 0.7:
741
+ mask = birefnet_frame(birefnet, frame_rgb)
742
+ else:
743
+ mask = birefnet_frame(birefnet, frame_rgb)
744
 
745
+ result = corridorkey_frame_onnx(corridorkey_onnx, frame_f32, mask, img_size,
746
+ despill_strength=despill_strength,
747
+ auto_despeckle=auto_despeckle,
748
+ despeckle_size=int(despeckle_size))
749
+ _write_frame_outputs(i, result["alpha"], result["fg"],
750
+ w, h, bg_lin, comp_dir, fg_dir, matte_dir, processed_dir)
751
+
752
+ elapsed = time.time() - t0
753
+ frame_times.append(elapsed)
754
+ remaining = (frames_to_process - i - 1) * (np.mean(frame_times[-5:]) if len(frame_times) > 1 else elapsed)
755
+ progress(0.10 + 0.80 * (i+1) / frames_to_process,
756
+ desc=f"Frame {i+1}/{frames_to_process} ({elapsed:.1f}s) ~{remaining:.0f}s left")
757
+
758
+ cap.release()
759
+ return {
760
+ "results": None, "frame_times": frame_times,
761
+ "use_gpu": False, "batch_size": 1,
762
+ "w": w, "h": h, "fps": fps, "tmpdir": tmpdir,
763
+ }
764
 
765
+ except gr.Error:
766
+ raise
767
+ except Exception as e:
768
+ logger.exception("Inference failed")
769
+ raise gr.Error(f"Inference failed: {e}")
770
 
771
+
772
+ def process_video(video_path, resolution, despill_val, mask_mode,
773
+ auto_despeckle, despeckle_size, progress=gr.Progress()):
774
+ """Orchestrator: precompute fast masks (CPU) → GPU inference → CPU I/O."""
775
+ if video_path is None:
776
+ raise gr.Error("Please upload a video.")
777
+
778
+ # Phase 0: Precompute fast masks on CPU and save to disk.
779
+ # IMPORTANT: Can't pass large data as args to @spaces.GPU (ZeroGPU serializes args).
780
+ # Save to a numpy file, pass only the path.
781
+ logger.info("[Phase 0] Precomputing fast masks on CPU")
782
+ t_mask = time.time()
783
+ precompute_dir = tempfile.mkdtemp(prefix="ck_pre_")
784
+ cap = cv2.VideoCapture(video_path)
785
+ frame_count = 0
786
+ needs_birefnet = False
787
+ while True:
788
+ ret, frame_bgr = cap.read()
789
+ if not ret:
790
+ break
791
+ frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
792
+ frame_f32 = frame_rgb.astype(np.float32) / 255.0
793
+ if mask_mode == "Fast (classical)":
794
+ mask, _ = fast_greenscreen_mask(frame_f32)
795
+ if mask is None:
796
+ raise gr.Error("Fast mask failed. Try 'Hybrid' or 'AI' mode.")
797
+ elif mask_mode == "Hybrid (auto)":
798
+ mask, conf = fast_greenscreen_mask(frame_f32)
799
+ if mask is None or conf < 0.7:
800
+ mask = None
801
+ needs_birefnet = True
802
+ else:
803
+ mask = None
804
+ needs_birefnet = True
805
+ # Save as compressed numpy (fast to load, no serialization overhead)
806
+ np.save(os.path.join(precompute_dir, f"frame_{frame_count:05d}.npy"), frame_f32)
807
+ if mask is not None:
808
+ np.save(os.path.join(precompute_dir, f"mask_{frame_count:05d}.npy"), mask)
809
+ if mask is None:
810
+ np.save(os.path.join(precompute_dir, f"rgb_{frame_count:05d}.npy"), frame_rgb)
811
+ frame_count += 1
812
+ cap.release()
813
+ logger.info("[Phase 0] %d frames saved to %s in %.1fs (needs_birefnet=%s)",
814
+ frame_count, precompute_dir, time.time() - t_mask, needs_birefnet)
815
+
816
+ # Phase 1: GPU inference pass only paths (tiny strings), not data
817
+ logger.info("[Phase 1] Starting GPU phase")
818
+ t0 = time.time()
819
+ data = _gpu_phase(video_path, resolution, despill_val, mask_mode,
820
+ auto_despeckle, despeckle_size, progress,
821
+ precompute_dir=precompute_dir, precompute_count=frame_count)
822
+ logger.info("[process_video] GPU phase done in %.1fs", time.time() - t0)
823
+
824
+ tmpdir = data["tmpdir"]
825
+ w, h, fps = data["w"], data["h"], data["fps"]
826
+ frame_times = data["frame_times"]
827
+ use_gpu = data["use_gpu"]
828
+ batch_size = data["batch_size"]
829
+
830
+ comp_dir = os.path.join(tmpdir, "Comp")
831
+ fg_dir = os.path.join(tmpdir, "FG")
832
+ matte_dir = os.path.join(tmpdir, "Matte")
833
+ processed_dir = os.path.join(tmpdir, "Processed")
834
+ for d in [comp_dir, fg_dir, matte_dir, processed_dir]:
835
+ os.makedirs(d, exist_ok=True)
836
+
837
+ try:
838
+ from concurrent.futures import ThreadPoolExecutor
839
+
840
+ logger.info("[Phase 2] Frames written by GPU/CPU phase (comp+fg+matte)")
841
+
842
+ # Phase 3: stitch videos from written frames
843
+ logger.info("[Phase 3] Stitching videos")
844
+ progress(0.93, desc="Stitching videos...")
845
  comp_video = os.path.join(tmpdir, "comp_preview.mp4")
846
  matte_video = os.path.join(tmpdir, "matte_preview.mp4")
847
+ # Comp uses JPEG, Matte uses PNG
848
+ _stitch_ffmpeg(comp_dir, comp_video, fps, pattern="%05d.jpg", extra_args=["-crf", "18"])
849
+ _stitch_ffmpeg(matte_dir, matte_video, fps, pattern="%05d.png", extra_args=["-crf", "18"])
850
 
851
+ # Phase 4: ZIP (no GPU)
852
+ logger.info("[Phase 4] Packaging ZIP")
853
  progress(0.96, desc="Packaging ZIP...")
854
  zip_path = os.path.join(tmpdir, "CorridorKey_Output.zip")
855
  with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as zf:
856
  for folder in ["Comp", "FG", "Matte", "Processed"]:
857
  src = os.path.join(tmpdir, folder)
858
+ if os.path.isdir(src):
859
+ for f in sorted(os.listdir(src)):
860
+ zf.write(os.path.join(src, f), f"Output/{folder}/{f}")
861
 
862
  progress(1.0, desc="Done!")
863
+ total_elapsed = sum(frame_times) if frame_times else 0
864
  n = len(frame_times)
865
  avg = np.mean(frame_times) if frame_times else 0
866
+ engine = "PyTorch GPU" if use_gpu else "ONNX CPU"
867
+ status = (f"Processed {n} frames ({w}x{h}) at {resolution}px | "
868
+ f"{avg:.2f}s/frame | {engine}" +
869
+ (f" batch={batch_size}" if use_gpu else ""))
870
 
871
  return (
872
  comp_video if os.path.exists(comp_video) else None,
 
878
  except gr.Error:
879
  raise
880
  except Exception as e:
881
+ logger.exception("Output writing failed")
882
+ raise gr.Error(f"Output failed: {e}")
883
  finally:
884
  for d in ["Comp", "FG", "Matte", "Processed"]:
885
  p = os.path.join(tmpdir, d)
 
894
  def process_example(video_path, resolution, despill, mask_mode, despeckle, despeckle_size):
895
  return process_video(video_path, resolution, despill, mask_mode, despeckle, despeckle_size)
896
 
897
+ DESCRIPTION = """# CorridorKey Green Screen Matting
898
+ Remove green backgrounds from video. Based on [CorridorKey](https://www.youtube.com/watch?v=3Ploi723hg4) by Corridor Digital.
899
+ ZeroGPU H200: batched PyTorch inference (up to 32 frames at once). CPU fallback via ONNX."""
 
900
 
901
  with gr.Blocks(title="CorridorKey") as demo:
902
  gr.Markdown(DESCRIPTION)
 
908
  resolution = gr.Radio(
909
  choices=["1024", "2048"], value="1024",
910
  label="Processing Resolution",
911
+ info="1024 = fast (batch 32 on GPU), 2048 = max quality (batch 8 on GPU)"
912
  )
913
  mask_mode = gr.Radio(
914
  choices=["Hybrid (auto)", "AI (BiRefNet)", "Fast (classical)"],
requirements.txt CHANGED
@@ -1,5 +1,10 @@
1
  numpy
2
  opencv-python-headless
3
  huggingface-hub
4
- onnxruntime
 
5
  gradio[mcp]
 
 
 
 
 
1
  numpy
2
  opencv-python-headless
3
  huggingface-hub
4
+ onnxruntime-gpu
5
+ spaces
6
  gradio[mcp]
7
+ torch
8
+ torchvision
9
+ timm
10
+ https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.9.0/flash_attn-2.8.3+cu126torch2.9-cp310-cp310-linux_x86_64.whl