enoky committed on
Commit
b5cd334
·
verified ·
1 Parent(s): 366a1de

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +216 -291
app.py CHANGED
@@ -14,402 +14,327 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
  print(f"Running on device: {device}")
15
 
16
  # ==============================================================================
17
- # 1. FORWARD WARP IMPLEMENTATION (Native PyTorch)
18
  # ==============================================================================
class ForwardWarpFunction(Function):
    """Forward-warp (splat) an image along a per-pixel flow field.

    Every source pixel ``(x, y)`` of ``im0`` is pushed to
    ``(x + flow[..., 0], y + flow[..., 1])`` on a zero-initialised canvas with
    the same shape as ``im0``.  The operation is not differentiable:
    ``backward`` returns ``None`` for every input.
    """

    @staticmethod
    def forward(ctx, im0, flow, interpolation_mode_int):
        """Splat ``im0`` onto a new canvas.

        Args:
            ctx: Autograd context (unused).
            im0: Source tensor of shape ``[B, C, H, W]``.
            flow: Displacement in pixels, shape ``[B, H, W, 2]`` ordered (x, y).
            interpolation_mode_int: ``0`` for bilinear splatting (contributions
                are accumulated), ``1`` for nearest neighbour (contributions
                overwrite each other; the winner of a collision is unspecified).

        Returns:
            A contiguous tensor shaped like ``im0``.  Cells that receive no
            contribution stay zero.

        Raises:
            ValueError: If the shapes or the interpolation mode are invalid.
        """
        # Explicit raises instead of ``assert`` so validation survives ``python -O``.
        if im0.dim() != 4 or flow.dim() != 4:
            raise ValueError("im0 must be [B, C, H, W] and flow must be [B, H, W, 2]")
        if interpolation_mode_int not in (0, 1):
            raise ValueError("interpolation_mode_int must be 0 (bilinear) or 1 (nearest)")
        if (
            im0.shape[0] != flow.shape[0]
            or tuple(im0.shape[-2:]) != tuple(flow.shape[1:3])
            or flow.shape[3] != 2
        ):
            raise ValueError(
                f"incompatible shapes: im0 {tuple(im0.shape)}, flow {tuple(flow.shape)}"
            )

        B, C, H, W = im0.shape
        # new_zeros always allocates a standard row-major buffer.  zeros_like
        # would copy a permuted (e.g. channels-last) layout from im0, and the
        # flat view / flat index arithmetic below would then be invalid.
        im1 = im0.new_zeros((B, C, H, W))
        im1_flat = im1.view(-1)

        coord_dtype = flow.dtype if flow.is_floating_point() else torch.float32
        grid_x = torch.arange(W, device=im0.device, dtype=coord_dtype).view(1, 1, W)
        grid_y = torch.arange(H, device=im0.device, dtype=coord_dtype).view(1, H, 1)

        # Destination coordinates, [B, H, W]
        x_dest = grid_x + flow[:, :, :, 0]
        y_dest = grid_y + flow[:, :, :, 1]

        # Flat offset of element (b, c, 0, 0), shape [B, C, 1, 1]
        b_indices = torch.arange(B, device=im0.device).view(B, 1, 1, 1)
        c_indices = torch.arange(C, device=im0.device).view(1, C, 1, 1)
        base_idx = (b_indices * C + c_indices) * (H * W)

        def inside_canvas(y_idx, x_idx):
            return (x_idx >= 0) & (x_idx < W) & (y_idx >= 0) & (y_idx < H)

        def flat_index(y_idx, x_idx):
            # Clamp so the index is always legal; callers neutralise the
            # out-of-canvas entries through the mask / zero weight.
            y_idx = torch.clamp(y_idx, 0, H - 1)
            x_idx = torch.clamp(x_idx, 0, W - 1)
            return (base_idx + (y_idx * W + x_idx).unsqueeze(1)).reshape(-1)

        if interpolation_mode_int == 0:  # Bilinear Splatting
            x_floor = torch.floor(x_dest)
            y_floor = torch.floor(y_dest)
            frac_x = x_dest - x_floor  # in [0, 1)
            frac_y = y_dest - y_floor
            x_f = x_floor.long()
            y_f = y_floor.long()
            x_c = x_f + 1
            y_c = y_f + 1

            corners = (
                (y_f, x_f, (1 - frac_x) * (1 - frac_y)),  # NW
                (y_f, x_c, frac_x * (1 - frac_y)),        # NE
                (y_c, x_f, (1 - frac_x) * frac_y),        # SW
                (y_c, x_c, frac_x * frac_y),              # SE
            )
            for y_idx, x_idx, weights in corners:
                # Validity is decided per corner.  A single mask over all four
                # corners would discard a pixel that lands exactly on the last
                # row/column (its ceil neighbour is outside, but has weight 0).
                weights = torch.where(
                    inside_canvas(y_idx, x_idx), weights, torch.zeros_like(weights)
                )
                values = (im0 * weights.unsqueeze(1)).to(im0.dtype).reshape(-1)
                im1_flat.scatter_add_(0, flat_index(y_idx, x_idx), values)

        else:  # Nearest Neighbor (Legacy fallback)
            x_nearest = torch.round(x_dest).long()
            y_nearest = torch.round(y_dest).long()
            keep = inside_canvas(y_nearest, x_nearest)
            keep = keep.unsqueeze(1).expand(B, C, H, W).reshape(-1)
            dest_idx = flat_index(y_nearest, x_nearest)
            # Scatter only the pixels that land inside the canvas; writing
            # zeros for the others could overwrite valid border cells.
            im1_flat.scatter_(0, dest_idx[keep], im0.reshape(-1)[keep])

        return im1

    @staticmethod
    def backward(ctx, grad_output):
        """No gradient is propagated to any of the three forward inputs."""
        return None, None, None
117
 
 
class forward_warp(nn.Module):
    """Module wrapper that forwards to ``ForwardWarpFunction.apply``.

    The mode string is translated once, at construction time, into the integer
    code passed to the function: ``"Bilinear"`` becomes ``0`` and every other
    value becomes ``1``.
    """

    def __init__(self, interpolation_mode="Bilinear"):
        super().__init__()
        # Exact, case-sensitive match; anything else selects mode 1.
        self.interpolation_mode_int = int(interpolation_mode != "Bilinear")

    def forward(self, im0, flow):
        """Warp ``im0`` along ``flow`` using the mode chosen at construction."""
        mode_code = self.interpolation_mode_int
        return ForwardWarpFunction.apply(im0, flow, mode_code)
 
125
 
126
  # ==============================================================================
127
- # 2. STEREO WARPER WRAPPER
128
  # ==============================================================================
class ForwardWarpStereo(nn.Module):
    """Weighted-splatting wrapper that renders a horizontally shifted view.

    Occlusions are resolved with exponential disparity weights (a soft
    z-buffer): colours are accumulated multiplied by ``1.5 ** disparity`` and
    divided by the accumulated weights afterwards, so larger disparities
    dominate wherever several source pixels land on the same target.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.fw = forward_warp(interpolation_mode="Bilinear")

    def forward(self, im, disp, convergence, divergence):
        """Warp ``im`` according to ``disp``.

        Args:
            im: Image tensor, ``[B, C, H, W]``.
            disp: Disparity/depth tensor, ``[B, 1, H, W]``.
            convergence: Disparity value that receives zero shift.
            divergence: Scale applied to ``disp - convergence`` to get the
                shift in pixels.

        Returns:
            ``(warped, occlusion_mask)`` where ``occlusion_mask`` is 1.0 on
            target cells that received no contribution and 0.0 elsewhere.
        """
        # Horizontal flow is the negated shift; there is no vertical motion.
        per_pixel_shift = (disp.squeeze(1) - convergence) * divergence  # [B, H, W]
        flow_x = -per_pixel_shift
        flow = torch.stack((flow_x, torch.zeros_like(flow_x)), dim=-1)  # [B, H, W, 2]

        # Soft z-buffer weights, kept at [B, 1, H, W] so they broadcast over C.
        weights_map = torch.pow(1.5, disp - disp.min())

        # Accumulate weighted colour and the weights themselves.
        weighted_colour = self.fw(im * weights_map, flow)
        weight_total = self.fw(weights_map, flow)

        # Normalise; the lower bound avoids dividing by zero in empty cells.
        res = weighted_colour / weight_total.clamp(min=self.eps)

        # Splat a field of ones: cells whose sum stays (near) zero are holes.
        occupancy = self.fw(torch.ones_like(disp), flow)
        occlusion_mask = (occupancy < self.eps).float()

        return res, occlusion_mask
 
186
 
187
  # ==============================================================================
188
- # 3. APP LOGIC & MODELS
189
  # ==============================================================================
190
-
191
- # === LOAD MODELS ===
def load_models():
    """Load every model the app needs and return them as a 4-tuple.

    Returns:
        ``(depth_model, depth_processor, lama_model, stereo_warper)``: the
        Depth Anything V2 model and its image processor, the TorchScript LaMa
        inpainting model, and a ``ForwardWarpStereo`` instance.  The models
        are placed on the module-level ``device``.

    Raises:
        Exception: Whatever the LaMa download or TorchScript load raised; the
            error is printed and then re-raised unchanged.
    """
    print("Loading Depth Anything V2 Large...")
    depth_model = AutoModelForDepthEstimation.from_pretrained(
        "depth-anything/Depth-Anything-V2-Large-hf"
    ).to(device)
    depth_processor = AutoImageProcessor.from_pretrained(
        "depth-anything/Depth-Anything-V2-Large-hf"
    )

    print("Loading LaMa Inpainting Model...")
    try:
        model_path = hf_hub_download(repo_id="fashn-ai/LaMa", filename="big-lama.pt")
        lama_model = torch.jit.load(model_path, map_location=device)
        lama_model.eval()
    except Exception as e:
        print(f"Error loading LaMa model: {e}")
        # Bare raise re-raises the active exception with its traceback intact.
        raise

    # Initialize the stereo warper
    stereo_warper = ForwardWarpStereo().to(device)

    return depth_model, depth_processor, lama_model, stereo_warper
214
 
215
- # Load models once at startup
 
216
  # Module-level call: runs load_models() at import time and binds its four
  # return values as globals used by the functions below.
  depth_model, depth_processor, lama_model, stereo_warper = load_models()
217
 
218
- # === DEPTH ESTIMATION ===
@torch.no_grad()
def estimate_depth(image_pil, model, processor):
    """Predict a depth map for ``image_pil``, normalised to ``[0, 1]``.

    Args:
        image_pil: Input PIL image.
        model: Depth model whose output exposes ``predicted_depth``.
        processor: Image processor that builds the model inputs.

    Returns:
        A tensor at the image's resolution (``[H, W]`` for a single image),
        min-max normalised; all zeros if the prediction is constant.
    """
    original_size = image_pil.size  # PIL order: (width, height)
    inputs = processor(images=image_pil, return_tensors="pt").to(device)
    depth = model(**inputs).predicted_depth

    depth = torch.nn.functional.interpolate(
        depth.unsqueeze(1),
        size=(original_size[1], original_size[0]),
        mode="bicubic",
        align_corners=False,
    )
    # Remove only the channel and batch axes.  A blanket squeeze() would also
    # drop a height or width of 1 and hand a 1-D tensor to the callers.
    depth = depth.squeeze(1).squeeze(0)

    depth_min, depth_max = depth.min(), depth.max()
    if depth_max - depth_min > 0:
        depth = (depth - depth_min) / (depth_max - depth_min)
    else:
        depth = torch.zeros_like(depth)
    return depth
238
 
239
- # === DEPTH MANIPULATION ===
def erode_depth(depth_tensor, kernel_size):
    """Apply a square minimum filter (grey-scale erosion) to a 2-D depth map.

    Args:
        depth_tensor: Tensor of shape ``[H, W]``.
        kernel_size: Window size in pixels.  Even sizes are rounded up to the
            next odd number; values ``<= 0`` return the input unchanged.

    Returns:
        A tensor of shape ``[H, W]`` where each value is the minimum of its
        ``k x k`` neighbourhood.
    """
    if kernel_size <= 0:
        return depth_tensor
    k = kernel_size if kernel_size % 2 == 1 else kernel_size + 1
    x = depth_tensor.unsqueeze(0).unsqueeze(0)
    padding = k // 2
    # min(x) == -max(-x); max_pool2d ignores its padding, so the border
    # windows only see real pixels.
    x_eroded = -torch.nn.functional.max_pool2d(-x, kernel_size=k, stride=1, padding=padding)
    # Index away exactly the two axes added above.  squeeze() would also
    # collapse a height or width of 1.
    return x_eroded[0, 0]
247
 
248
- # === LOCAL INPAINTING ===
@torch.no_grad()
def run_local_lama(image_bgr, mask_float):
    """Inpaint the masked regions of a BGR image with the LaMa model.

    Args:
        image_bgr: ``uint8`` image array in BGR channel order, ``[H, W, 3]``.
        mask_float: 2-D float array, non-zero where pixels must be filled.

    Returns:
        The inpainted ``uint8`` BGR image at the input resolution.
    """
    # 0. Dilate the mask slightly to catch edge artifacts from splatting
    kernel = np.ones((3, 3), np.uint8)
    mask_uint8 = (mask_float * 255).astype(np.uint8)
    mask_dilated = cv2.dilate(mask_uint8, kernel, iterations=1)

    # 1. Resize so both sides are divisible by 8.  The lower bound of 8 keeps
    #    images smaller than 8 px from producing a zero-sized resize target.
    h, w = image_bgr.shape[:2]
    new_h = max(8, (h // 8) * 8)
    new_w = max(8, (w // 8) * 8)

    img_resized = cv2.resize(image_bgr, (new_w, new_h))
    mask_resized = cv2.resize(mask_dilated, (new_w, new_h), interpolation=cv2.INTER_NEAREST)

    # 2. Convert to Torch: [1, 3, H, W] RGB in [0, 1] and a binary [1, 1, H, W] mask
    img_t = torch.from_numpy(img_resized).float().permute(2, 0, 1).unsqueeze(0) / 255.0
    img_t = img_t[:, [2, 1, 0], :, :]  # BGR to RGB

    mask_t = torch.from_numpy(mask_resized).float().unsqueeze(0).unsqueeze(0) / 255.0
    mask_t = (mask_t > 0.5).float()

    img_t = img_t.to(device)
    mask_t = mask_t.to(device)

    # 3. Inference on the image with the hole pixels zeroed out
    img_t = img_t * (1 - mask_t)
    inpainted_t = lama_model(img_t, mask_t)

    # 4. Post-process back to uint8 BGR
    inpainted = inpainted_t[0].permute(1, 2, 0).cpu().numpy()
    inpainted = np.clip(inpainted * 255, 0, 255).astype(np.uint8)
    inpainted = cv2.cvtColor(inpainted, cv2.COLOR_RGB2BGR)

    if new_h != h or new_w != w:
        inpainted = cv2.resize(inpainted, (w, h))

    return inpainted
287
 
 
def make_anaglyph(left, right):
    """Combine a stereo pair into one red/cyan anaglyph image.

    Channel 0 is copied from ``left`` and channels 1 and 2 from ``right``;
    any further channel is left at zero.  The result has the shape and dtype
    of ``left``.
    """
    left_px = np.array(left)
    right_px = np.array(right)
    anaglyph = np.zeros_like(left_px)
    anaglyph[..., 0] = left_px[..., 0]
    anaglyph[..., 1:3] = right_px[..., 1:3]
    return Image.fromarray(anaglyph)
296
 
297
- # === PIPELINE ===
 
 
 
def stereo_pipeline(image_pil, divergence, convergence, edge_erosion):
    """Convert one image into a stereo pair plus diagnostic images.

    Args:
        image_pil: Input PIL image, or ``None``.
        divergence: Scale of the horizontal shift passed to the warper.
        convergence: Depth value that receives zero shift.
        edge_erosion: Erosion kernel size for the depth map; ``0`` disables it.

    Returns:
        ``(side_by_side, anaglyph, depth_image, mask_image)`` as PIL images,
        or four ``None`` values when no image was supplied.
    """
    if image_pil is None:
        return None, None, None, None

    # The tensor and colour conversions below assume exactly three channels;
    # RGBA, palette or greyscale uploads would otherwise fail.
    if image_pil.mode != "RGB":
        image_pil = image_pil.convert("RGB")

    # Resize input if too wide (the floor of 1 protects extreme aspect ratios)
    w, h = image_pil.size
    if w > 1920:
        ratio = 1920 / w
        new_h = max(1, int(h * ratio))
        image_pil = image_pil.resize((1920, new_h), Image.LANCZOS)

    # 1. Depth Estimation
    depth_tensor = estimate_depth(image_pil, depth_model, depth_processor)

    # 2. Depth Erosion (optional halo reduction)
    if edge_erosion > 0:
        depth_tensor = erode_depth(depth_tensor, int(edge_erosion))

    # Visualize Depth
    depth_vis = (depth_tensor.cpu().numpy() * 255).astype(np.uint8)
    depth_image = Image.fromarray(depth_vis)

    # 3. Forward Warp: image as [1, 3, H, W] in [0, 1], depth as [1, 1, H, W]
    image_tensor = torch.from_numpy(np.array(image_pil)).float().to(device).permute(2, 0, 1).unsqueeze(0) / 255.0
    depth_input = depth_tensor.unsqueeze(0).unsqueeze(0)

    with torch.no_grad():
        right_img_tensor, mask_tensor = stereo_warper(
            image_tensor,
            depth_input,
            float(convergence),
            float(divergence),
        )

    # Convert results back to CPU/Numpy.  Index the batch and channel axes
    # explicitly so a height or width of 1 is not squeezed away.
    right_img_rgb = (right_img_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
    mask_float = mask_tensor[0, 0].cpu().numpy()
    mask_image = Image.fromarray((mask_float * 255).astype(np.uint8))

    # 4. Inpainting
    right_img_bgr = cv2.cvtColor(right_img_rgb, cv2.COLOR_RGB2BGR)
    right_filled_bgr = run_local_lama(right_img_bgr, mask_float)

    # 5. Finalize
    left = image_pil
    right = Image.fromarray(cv2.cvtColor(right_filled_bgr, cv2.COLOR_BGR2RGB))

    width, height = left.size
    combined_image = Image.new('RGB', (width * 2, height))
    combined_image.paste(left, (0, 0))
    combined_image.paste(right, (width, 0))

    anaglyph_image = make_anaglyph(left, right)

    return combined_image, anaglyph_image, depth_image, mask_image
360
-
361
- # === GRADIO UI ===
# NOTE(review): the nesting of the layout blocks below was reconstructed from
# a flattened diff view — confirm against the rendered UI.
css = """
.gradio-container {
    max-width: 1400px !important;
    margin: auto !important;
}
"""

with gr.Blocks(title="2D to 3D Stereo") as demo:
    # The stylesheet is injected through an inline <style> element.
    gr.HTML(f"<style>{css}</style>")

    gr.Markdown("## 2D to 3D Stereo Generator (High-Quality Splatting)")
    gr.Markdown("Uses **Depth Anything V2**, **Bilinear Weighted Splatting** (Soft Z-Buffer), and **LaMa Inpainting**.")

    with gr.Row():
        # Left column: input image, controls and the trigger button.
        with gr.Column(scale=1):
            input_img = gr.Image(type="pil", label="Input Image", height=320)

            with gr.Group():
                gr.Markdown("### 3D Controls")
                divergence_slider = gr.Slider(
                    label="3D Strength (Divergence)",
                    info="Max separation in pixels.",
                    minimum=0, maximum=100, value=30, step=1,
                )
                convergence_slider = gr.Slider(
                    label="Focus Plane (Convergence)",
                    info="0.0 = Background at screen. 1.0 = Foreground at screen.",
                    minimum=0.0, maximum=1.0, value=0.5, step=0.05,
                )
                erosion_slider = gr.Slider(
                    label="Edge Masking (Erosion)",
                    info="Cleanup edges. Set to 0 for raw splatting.",
                    minimum=0, maximum=20, value=2, step=1,
                )

            btn = gr.Button("Generate 3D", variant="primary")

        # Right column: results.
        with gr.Column(scale=1):
            out_anaglyph = gr.Image(label="Anaglyph (Red/Cyan)", height=320)
            out_stereo = gr.Image(label="Side-by-Side Stereo Pair", height=320)

            with gr.Row():
                out_depth = gr.Image(label="Depth Map", height=200)
                out_mask = gr.Image(label="Inpainting Mask (Holes)", height=200)

    # Output order matches the tuple returned by stereo_pipeline.
    btn.click(
        fn=stereo_pipeline,
        inputs=[input_img, divergence_slider, convergence_slider, erosion_slider],
        outputs=[out_stereo, out_anaglyph, out_depth, out_mask],
    )

if __name__ == "__main__":
    demo.launch()
 
14
  print(f"Running on device: {device}")
15
 
16
  # ==============================================================================
17
+ # 1. FIXED FORWARD WARP WITH BILINEAR SPLATTING (Contiguous & Stable)
18
  # ==============================================================================
class ForwardWarpFunction(Function):
    """Splat ``im0`` forward along ``flow`` onto an empty canvas.

    A source pixel at ``(x, y)`` lands at ``(x + flow[..., 0], y + flow[..., 1])``.
    Gradients are not supported: ``backward`` returns ``None`` for each input.
    """

    @staticmethod
    def forward(ctx, im0, flow, interpolation_mode_int):
        """Run the splat.

        Args:
            ctx: Autograd context (unused).
            im0: Source tensor, ``[B, C, H, W]``.
            flow: Pixel displacements, ``[B, H, W, 2]`` with (x, y) in the last axis.
            interpolation_mode_int: ``0`` selects bilinear splatting
                (accumulating); any other value selects nearest neighbour
                (overwriting; which pixel wins a collision is unspecified).

        Returns:
            A contiguous tensor with the shape and dtype of ``im0``; cells
            that receive nothing remain zero.

        Raises:
            ValueError: If ``im0`` / ``flow`` shapes do not match the layout above.
        """
        # Raised explicitly: ``assert`` statements vanish under ``python -O``.
        if im0.dim() != 4 or flow.dim() != 4 or flow.shape[-1] != 2:
            raise ValueError("expected im0 [B, C, H, W] and flow [B, H, W, 2]")
        if im0.shape[0] != flow.shape[0] or tuple(im0.shape[-2:]) != tuple(flow.shape[-3:-1]):
            raise ValueError(
                f"im0 {tuple(im0.shape)} and flow {tuple(flow.shape)} do not match"
            )

        B, C, H, W = im0.shape
        # zeros_like would inherit a permuted memory layout from im0, making
        # view(-1) fail; new_zeros always returns a row-major buffer, which is
        # what the flat index arithmetic below assumes.
        im1 = im0.new_zeros((B, C, H, W))
        canvas = im1.view(-1)

        coord_dtype = flow.dtype if flow.is_floating_point() else torch.float32
        cols = torch.arange(W, device=im0.device, dtype=coord_dtype).view(1, 1, W)
        rows = torch.arange(H, device=im0.device, dtype=coord_dtype).view(1, H, 1)
        x_dest = cols + flow[..., 0]  # [B, H, W]
        y_dest = rows + flow[..., 1]

        # Flat offset of the first element of every (batch, channel) plane.
        plane = torch.arange(B * C, device=im0.device).view(B, C, 1, 1) * (H * W)

        def target_index(ty, tx):
            # Clamped so the index is always legal; out-of-canvas entries are
            # neutralised by the caller (zero weight or boolean mask).
            pos = ty.clamp(0, H - 1) * W + tx.clamp(0, W - 1)
            return (plane + pos.unsqueeze(1)).reshape(-1)

        def on_canvas(ty, tx):
            return (tx >= 0) & (tx < W) & (ty >= 0) & (ty < H)

        if interpolation_mode_int == 0:  # Bilinear Splatting
            x_base = torch.floor(x_dest)
            y_base = torch.floor(y_dest)
            fx = x_dest - x_base  # fractional parts, in [0, 1)
            fy = y_dest - y_base
            x0 = x_base.long()
            y0 = y_base.long()

            for dy in (0, 1):
                for dx in (0, 1):
                    tx, ty = x0 + dx, y0 + dy
                    weight = (fx if dx else 1 - fx) * (fy if dy else 1 - fy)
                    # Per-corner validity: a shared mask over all four corners
                    # would drop every pixel landing on the last row/column,
                    # whose +1 neighbour is outside but carries zero weight.
                    weight = torch.where(on_canvas(ty, tx), weight, torch.zeros_like(weight))
                    values = (im0 * weight.unsqueeze(1)).to(im0.dtype).reshape(-1)
                    canvas.scatter_add_(0, target_index(ty, tx), values)

        else:  # Nearest neighbor (fallback)
            x_nn = torch.round(x_dest).long()
            y_nn = torch.round(y_dest).long()
            # Validity must be tested before clamping; afterwards every
            # coordinate is in range and the test is always true.
            keep = on_canvas(y_nn, x_nn).unsqueeze(1).expand(B, C, H, W).reshape(-1)
            idx = target_index(y_nn, x_nn)
            canvas.scatter_(0, idx[keep], im0.reshape(-1)[keep])

        return im1

    @staticmethod
    def backward(ctx, grad_output):
        """Return no gradient for any of the three forward inputs."""
        return None, None, None
103
 
104
+
class forward_warp(nn.Module):
    """``nn.Module`` front end for ``ForwardWarpFunction``.

    ``interpolation_mode`` is mapped to an integer code stored in ``self.mode``:
    ``0`` for the exact string ``"Bilinear"``, ``1`` for anything else.
    """

    def __init__(self, interpolation_mode="Bilinear"):
        super().__init__()
        if interpolation_mode == "Bilinear":
            self.mode = 0
        else:
            self.mode = 1

    def forward(self, im0, flow):
        """Warp ``im0`` along ``flow`` with the configured mode."""
        warped = ForwardWarpFunction.apply(im0, flow, self.mode)
        return warped
112
+
113
 
114
  # ==============================================================================
115
+ # 2. STEREO WARPER (Soft Z-Buffer Splatting)
116
  # ==============================================================================
class ForwardWarpStereo(nn.Module):
    """Render a horizontally shifted view with soft z-buffer splatting.

    Colours are splatted multiplied by ``1.5 ** disparity`` and divided by the
    splatted weights afterwards, so pixels with larger disparity dominate
    wherever several sources land on the same target cell.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.fw = forward_warp(interpolation_mode="Bilinear")

    def forward(self, im, disp, convergence, divergence):
        """Warp ``im`` according to ``disp``.

        Args:
            im: Image tensor, ``[B, C, H, W]``.
            disp: Disparity/depth tensor, ``[B, 1, H, W]``.
            convergence: Disparity value that is left unshifted.
            divergence: Factor turning ``disp - convergence`` into a pixel shift.

        Returns:
            ``(result, occlusion_mask)``: the normalised warped image
            ``[B, C, H, W]`` and a ``[B, 1, H, W]`` float mask that is 1.0
            where no source pixel landed.
        """
        disp = disp.squeeze(1)  # [B, H, W]
        shift = (disp - convergence) * divergence
        flow_x = -shift
        # The vertical component is zero.  It must be bound to its own name:
        # assigning it to ``flow`` left ``flow_y`` undefined at the stack call.
        flow_y = torch.zeros_like(flow_x)
        flow = torch.stack([flow_x, flow_y], dim=-1)  # [B, H, W, 2]

        # Soft Z-buffer weights (larger disparity = higher weight), [B, 1, H, W]
        weights = ((1.5) ** (disp - disp.min())).unsqueeze(1)

        # Weighted colour, weights and an all-ones occupancy plane share the
        # same flow, so they are splatted together in one pass.
        planes = torch.cat([im * weights, weights, torch.ones_like(weights)], dim=1)
        warped = self.fw(planes, flow)
        accum_color = warped[:, :-2]
        accum_weight = warped[:, -2:-1]
        occupancy = warped[:, -1:]

        # Normalize
        result = accum_color / (accum_weight + self.eps)

        # Occlusion mask (holes): cells whose occupancy stayed (near) zero
        occlusion_mask = (occupancy < self.eps).float()

        return result, occlusion_mask
146
+
147
 
148
  # ==============================================================================
149
+ # 3. MODELS & PIPELINE
150
  # ==============================================================================
 
 
def load_models():
    """Create the depth model, its processor, the LaMa model and the warper.

    Returns:
        ``(depth_model, depth_processor, lama_model, stereo_warper)``.  Both
        networks are switched to eval mode and, like the warper, placed on
        the module-level ``device``.
    """
    depth_repo = "depth-anything/Depth-Anything-V2-Large-hf"

    print("Loading Depth Anything V2 Large...")
    depth_model = AutoModelForDepthEstimation.from_pretrained(depth_repo)
    depth_model = depth_model.to(device).eval()
    depth_processor = AutoImageProcessor.from_pretrained(depth_repo)

    print("Loading LaMa Inpainting Model...")
    lama_checkpoint = hf_hub_download(repo_id="fashn-ai/LaMa", filename="big-lama.pt")
    lama_model = torch.jit.load(lama_checkpoint, map_location=device).eval()

    stereo_warper = ForwardWarpStereo().to(device)

    return depth_model, depth_processor, lama_model, stereo_warper
168
 
169
+
170
+ # Load once at startup
171
  # Module-level call: runs load_models() at import time and binds its four
  # return values as globals used by the functions below.
  depth_model, depth_processor, lama_model, stereo_warper = load_models()
172
 
173
+
@torch.inference_mode()
def estimate_depth(image_pil):
    """Predict a depth map for ``image_pil``, scaled to roughly ``[0, 1]``.

    Uses the module-level ``depth_processor`` and ``depth_model``.

    Args:
        image_pil: Input PIL image.

    Returns:
        A tensor at the image's resolution (``[H, W]`` for a single image),
        min-max normalised; a constant prediction maps to all zeros.
    """
    original_size = image_pil.size  # PIL order: (width, height)
    inputs = depth_processor(images=image_pil, return_tensors="pt").to(device)
    depth = depth_model(**inputs).predicted_depth  # [1, H, W]

    depth = torch.nn.functional.interpolate(
        depth.unsqueeze(1),
        size=(original_size[1], original_size[0]),
        mode="bicubic",
        align_corners=False,
    )
    # Drop only the channel and batch axes.  A blanket squeeze() would also
    # remove a height or width of 1.
    depth = depth.squeeze(1).squeeze(0)

    # The small constant keeps the division finite for a constant depth map.
    depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)
    return depth
189
 
190
+
def erode_depth(depth_tensor, kernel_size):
    """Erode a 2-D depth map with a square minimum filter.

    Args:
        depth_tensor: Tensor of shape ``[H, W]``.
        kernel_size: Window size; even values are bumped to the next odd
            number and values ``<= 0`` return the input untouched.

    Returns:
        A ``[H, W]`` tensor holding the minimum of each ``k x k`` window.
    """
    if kernel_size <= 0:
        return depth_tensor
    k = kernel_size if kernel_size % 2 == 1 else kernel_size + 1
    # Minimum filter expressed as a negated max-pool over the negated input.
    negated = -depth_tensor[None, None]
    pooled = torch.nn.functional.max_pool2d(negated, kernel_size=k, stride=1, padding=k // 2)
    # Strip exactly the two axes added above.  squeeze() would also collapse
    # a height or width of 1.
    return -pooled[0, 0]
199
 
200
+
@torch.inference_mode()
def run_local_lama(image_bgr, mask_float):
    """Inpaint the masked regions of a BGR image with the LaMa model.

    Args:
        image_bgr: ``uint8`` image array in BGR channel order, ``[H, W, 3]``.
        mask_float: 2-D float array, non-zero where pixels must be filled.

    Returns:
        The inpainted ``uint8`` BGR image at the input resolution.
    """
    # Grow the mask by one pixel before inpainting.
    kernel = np.ones((3, 3), np.uint8)
    mask_uint8 = (mask_float * 255).astype(np.uint8)
    mask_dilated = cv2.dilate(mask_uint8, kernel, iterations=1)

    # Work at a size divisible by 8.  The floor of 8 keeps images smaller
    # than 8 px from producing a zero-sized resize target.
    h, w = image_bgr.shape[:2]
    new_h = max(8, (h // 8) * 8)
    new_w = max(8, (w // 8) * 8)

    img_resized = cv2.resize(image_bgr, (new_w, new_h))
    mask_resized = cv2.resize(mask_dilated, (new_w, new_h), interpolation=cv2.INTER_NEAREST)

    img_t = torch.from_numpy(img_resized).float().permute(2, 0, 1).unsqueeze(0) / 255.0
    img_t = img_t[:, [2, 1, 0], :, :]  # BGR → RGB
    mask_t = torch.from_numpy(mask_resized).float().unsqueeze(0).unsqueeze(0) / 255.0
    mask_t = (mask_t > 0.5).float()

    img_t = img_t.to(device)
    mask_t = mask_t.to(device)

    # Zero the hole pixels before inference.  The decorator already disables
    # gradient tracking, so no nested no_grad() block is needed.
    img_t = img_t * (1 - mask_t)
    inpainted_t = lama_model(img_t, mask_t)

    inpainted = inpainted_t[0].permute(1, 2, 0).cpu().numpy()
    inpainted = np.clip(inpainted * 255, 0, 255).astype(np.uint8)
    inpainted = cv2.cvtColor(inpainted, cv2.COLOR_RGB2BGR)

    if (new_h != h) or (new_w != w):
        inpainted = cv2.resize(inpainted, (w, h), interpolation=cv2.INTER_LANCZOS4)

    return inpainted
234
 
235
+
def make_anaglyph(left, right):
    """Build a red/cyan anaglyph from a stereo pair.

    The first channel comes from ``left`` and the second and third from
    ``right``; any additional channel stays zero.  The output has the shape
    and dtype of ``left``.
    """
    left_arr = np.array(left)
    right_arr = np.array(right)
    anaglyph = np.zeros_like(left_arr)
    # (destination channel, source array): Red ← Left, Green/Blue ← Right
    for channel, source in ((0, left_arr), (1, right_arr), (2, right_arr)):
        anaglyph[:, :, channel] = source[:, :, channel]
    return Image.fromarray(anaglyph)
244
 
245
+
246
+ # ==============================================================================
247
+ # MAIN PIPELINE
248
+ # ==============================================================================
def stereo_pipeline(image_pil, divergence, convergence, edge_erosion):
    """Turn one image into a stereo pair, an anaglyph and two diagnostics.

    Args:
        image_pil: Input PIL image, or ``None``.
        divergence: Scale of the horizontal shift passed to the warper.
        convergence: Depth value that receives zero shift.
        edge_erosion: Erosion kernel size for the depth map; ``0`` disables it.

    Returns:
        ``(side_by_side, anaglyph, depth_image, mask_image)`` as PIL images,
        or four ``None`` values when no image was supplied.
    """
    if image_pil is None:
        return None, None, None, None

    # Everything below assumes exactly three channels; RGBA, palette or
    # greyscale uploads would otherwise fail in the tensor/colour conversions.
    if image_pil.mode != "RGB":
        image_pil = image_pil.convert("RGB")

    # Resize if too wide (the floor of 1 protects extreme aspect ratios)
    w, h = image_pil.size
    if w > 1920:
        ratio = 1920 / w
        image_pil = image_pil.resize((1920, max(1, int(h * ratio))), Image.LANCZOS)

    # 1. Depth
    depth = estimate_depth(image_pil)
    if edge_erosion > 0:
        depth = erode_depth(depth, int(edge_erosion))

    depth_vis = Image.fromarray((depth.cpu().numpy() * 255).astype(np.uint8))

    # 2. Prepare tensors
    img_tensor = torch.from_numpy(np.array(image_pil)).float().to(device)
    img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0) / 255.0  # [1,3,H,W]
    depth_tensor = depth.unsqueeze(0).unsqueeze(0)  # [1,1,H,W]

    # 3. Stereo warp
    with torch.inference_mode():
        right_tensor, mask_tensor = stereo_warper(
            img_tensor, depth_tensor, float(convergence), float(divergence)
        )

    right_np = (right_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
    # Index the batch and channel axes explicitly so a height or width of 1
    # is not squeezed away; the same array feeds the preview and the inpainter.
    mask_float = mask_tensor[0, 0].cpu().numpy()
    mask_np = (mask_float * 255).astype(np.uint8)

    # 4. Inpaint holes
    right_bgr = cv2.cvtColor(right_np, cv2.COLOR_RGB2BGR)
    right_filled_bgr = run_local_lama(right_bgr, mask_float)
    right_filled = Image.fromarray(cv2.cvtColor(right_filled_bgr, cv2.COLOR_BGR2RGB))

    # 5. Outputs
    w, h = image_pil.size
    sbs = Image.new("RGB", (w * 2, h))
    sbs.paste(image_pil, (0, 0))
    sbs.paste(right_filled, (w, 0))

    anaglyph = make_anaglyph(image_pil, right_filled)

    return sbs, anaglyph, depth_vis, Image.fromarray(mask_np)
295
+
296
+
297
+ # ==============================================================================
298
+ # GRADIO UI
299
+ # ==============================================================================
# NOTE(review): the nesting of the layout blocks below was reconstructed from
# a flattened diff view — confirm against the rendered UI.
css = ".gradio-container {max-width: 1400px !important; margin: auto !important;}"

# The arrow in the title and the bullets in the footer were mis-encoded
# (UTF-8 bytes decoded with a legacy code page); they are restored here.
with gr.Blocks(css=css, title="2D → 3D Stereo (Depth Anything + Splatting)") as demo:
    gr.Markdown("# 2D to 3D Stereo Generator")
    gr.Markdown("High-quality automatic stereo conversion using **Depth Anything V2**, **bilinear splatting with soft Z-buffer**, and **LaMa inpainting**.")

    with gr.Row():
        # Left column: input image, settings and the trigger button.
        with gr.Column(scale=1):
            input_img = gr.Image(type="pil", label="Upload Image", height=400)

            with gr.Accordion("3D Settings", open=True):
                divergence_slider = gr.Slider(0, 100, value=30, step=1,
                                              label="3D Strength (Divergence)",
                                              info="Higher = stronger 3D pop-out")
                convergence_slider = gr.Slider(0.0, 1.0, value=0.5, step=0.05,
                                               label="Focus Plane (Convergence)",
                                               info="0 = background at screen, 1 = foreground at screen")
                erosion_slider = gr.Slider(0, 20, value=3, step=1,
                                           label="Edge Cleanup (Depth Erosion)",
                                           info="Reduces halos, 0 = raw")

            btn = gr.Button("Generate 3D Stereo", variant="primary", size="lg")

        # Right column: results.
        with gr.Column(scale=1):
            out_stereo = gr.Image(label="Side-by-Side Stereo Pair", height=400)
            out_anaglyph = gr.Image(label="Anaglyph (Red/Cyan Glasses)", height=400)

            with gr.Row():
                out_depth = gr.Image(label="Estimated Depth Map", height=200)
                out_mask = gr.Image(label="Inpainting Mask (Holes)", height=200)

    # Output order matches the tuple returned by stereo_pipeline.
    btn.click(
        fn=stereo_pipeline,
        inputs=[input_img, divergence_slider, convergence_slider, erosion_slider],
        outputs=[out_stereo, out_anaglyph, out_depth, out_mask]
    )

    gr.Markdown("Made with Depth Anything V2 • Bilinear Splatting • LaMa • Gradio")

if __name__ == "__main__":
    demo.launch()