enoky committed
Commit ed6a23d · verified · 1 Parent(s): 4cc8594

Improve Forward Warp

Files changed (1)
  1. app.py +227 -161
app.py CHANGED
@@ -1,8 +1,10 @@
 import gradio as gr
 import torch
+import torch.nn as nn
 import numpy as np
 import cv2
 from PIL import Image
+from torch.autograd import Function
 from transformers import AutoModelForDepthEstimation, AutoImageProcessor
 from huggingface_hub import hf_hub_download
 import os
@@ -11,11 +13,172 @@ import os
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 print(f"Running on device: {device}")
 
+# ==============================================================================
+# 1. FORWARD WARP IMPLEMENTATION (From forward_warp_pytorch.py)
+# ==============================================================================
+class ForwardWarpFunction(Function):
+    @staticmethod
+    def forward(ctx, im0, flow, interpolation_mode_int):
+        # Input validation
+        assert(len(im0.shape) == len(flow.shape) == 4)
+        assert(interpolation_mode_int == 0 or interpolation_mode_int == 1)
+        assert(im0.shape[0] == flow.shape[0])
+        assert(im0.shape[-2:] == flow.shape[1:3])
+        assert(flow.shape[3] == 2)
+
+        B, C, H, W = im0.shape
+        im1 = torch.zeros_like(im0, device=im0.device, dtype=im0.dtype)
+
+        # Grid creation
+        grid_x, grid_y = torch.meshgrid(
+            torch.arange(W, device=im0.device, dtype=im0.dtype),
+            torch.arange(H, device=im0.device, dtype=im0.dtype),
+            indexing='xy'
+        )
+        grid_x = grid_x.unsqueeze(0).expand(B, -1, -1)
+        grid_y = grid_y.unsqueeze(0).expand(B, -1, -1)
+
+        # Destination coordinates
+        x_dest = grid_x + flow[:, :, :, 0]
+        y_dest = grid_y + flow[:, :, :, 1]
+
+        if interpolation_mode_int == 0:  # Bilinear Splatting
+            x_f = torch.floor(x_dest).long()
+            y_f = torch.floor(y_dest).long()
+            x_c = x_f + 1
+            y_c = y_f + 1
+
+            # Weights
+            nw_k = (x_c.float() - x_dest) * (y_c.float() - y_dest)
+            ne_k = (x_dest - x_f.float()) * (y_c.float() - y_dest)
+            sw_k = (x_c.float() - x_dest) * (y_dest - y_f.float())
+            se_k = (x_dest - x_f.float()) * (y_dest - y_f.float())
+
+            # Clamp coords
+            x_f_clamped = torch.clamp(x_f, 0, W - 1)
+            y_f_clamped = torch.clamp(y_f, 0, H - 1)
+            x_c_clamped = torch.clamp(x_c, 0, W - 1)
+            y_c_clamped = torch.clamp(y_c, 0, H - 1)
+
+            # Valid mask (source pixels that land inside canvas)
+            valid_mask = (x_f >= 0) & (x_c < W) & (y_f >= 0) & (y_c < H)
+
+            # Reshape for broadcasting
+            nw_k = nw_k.unsqueeze(1)
+            ne_k = ne_k.unsqueeze(1)
+            sw_k = sw_k.unsqueeze(1)
+            se_k = se_k.unsqueeze(1)
+            valid_mask = valid_mask.unsqueeze(1)
+
+            # Flatten indices for scatter_add
+            b_indices = torch.arange(B, device=im0.device).view(B, 1, 1, 1).expand(-1, C, H, W)
+            c_indices = torch.arange(C, device=im0.device).view(1, C, 1, 1).expand(B, -1, H, W)
+            base_idx = b_indices * (C * H * W) + c_indices * (H * W)
+
+            # Scatter to 4 neighbors (Accumulate/Splat)
+            def scatter_corner(y_idx, x_idx, weights):
+                flat_idx = base_idx + y_idx.unsqueeze(1) * W + x_idx.unsqueeze(1)
+                values = (im0 * weights) * valid_mask.float()
+                im1.view(-1).scatter_add_(0, flat_idx.view(-1), values.view(-1))
+
+            scatter_corner(y_f_clamped, x_f_clamped, nw_k)  # NW
+            scatter_corner(y_f_clamped, x_c_clamped, ne_k)  # NE
+            scatter_corner(y_c_clamped, x_f_clamped, sw_k)  # SW
+            scatter_corner(y_c_clamped, x_c_clamped, se_k)  # SE
+
+        else:  # Nearest Neighbor (Legacy fallback)
+            x_nearest = torch.round(x_dest).long()
+            y_nearest = torch.round(y_dest).long()
+            valid_mask = (x_nearest >= 0) & (x_nearest < W) & (y_nearest >= 0) & (y_nearest < H)
+            valid_mask = valid_mask.unsqueeze(1)
+
+            x_clamped = torch.clamp(x_nearest, 0, W - 1)
+            y_clamped = torch.clamp(y_nearest, 0, H - 1)
+
+            b_indices = torch.arange(B, device=im0.device).view(B, 1, 1, 1).expand(-1, C, H, W)
+            c_indices = torch.arange(C, device=im0.device).view(1, C, 1, 1).expand(B, -1, H, W)
+            dest_idx = b_indices*(C*H*W) + c_indices*(H*W) + y_clamped.unsqueeze(1)*W + x_clamped.unsqueeze(1)
+
+            source_values = im0 * valid_mask.float()
+            im1.view(-1).scatter_(0, dest_idx.view(-1), source_values.view(-1))
+
+        return im1
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        # No backward needed for inference; skipped for speed/simplicity
+        return None, None, None
+
+class forward_warp(nn.Module):
+    def __init__(self, interpolation_mode="Bilinear"):
+        super(forward_warp, self).__init__()
+        self.interpolation_mode_int = 0 if interpolation_mode == "Bilinear" else 1
+
+    def forward(self, im0, flow):
+        return ForwardWarpFunction.apply(im0, flow, self.interpolation_mode_int)
+
+# ==============================================================================
+# 2. STEREO WARPER (From splatting_gui.py)
+# ==============================================================================
+class ForwardWarpStereo(nn.Module):
+    """
+    Weighted Splatting wrapper.
+    Handles occlusions using exponential depth weights (Soft Z-Buffering).
+    """
+    def __init__(self, eps=1e-6):
+        super(ForwardWarpStereo, self).__init__()
+        self.eps = eps
+        self.fw = forward_warp(interpolation_mode="Bilinear")
+
+    def forward(self, im, disp, convergence, divergence):
+        # Create flow from disparity:
+        # Shift = (Depth - Convergence) * Divergence
+        # We negate it because standard flow is source->dest, but disparity logic varies.
+        # For the right-eye view: Target = Source - Shift, so Flow = -Shift.
+        shift = (disp - convergence) * divergence
+        flow_x = -shift
+
+        # Stack flow (x, y=0) -> (B, H, W, 2), squeezing the channel dim first
+        flow_y = torch.zeros_like(flow_x)
+        flow = torch.stack((flow_x.squeeze(1), flow_y.squeeze(1)), dim=-1)  # (B, H, W, 2)
+
+        # 1. Calculate Weights (Soft Z-Buffer)
+        # Closer objects (higher disparity) get exponentially higher weight.
+        # This allows foreground to overwrite background during accumulation.
+        # Using 1.414^disp (or a similar base) is a common heuristic.
+        weights_map = disp - disp.min()
+        weights_map = (1.5) ** weights_map  # Tuned base for separation
+
+        # 2. Warp Image * Weights (Accumulate Weighted Color)
+        # Input im is (B, C, H, W), weights is (B, 1, H, W)
+        res_accum = self.fw(im * weights_map, flow)
+
+        # 3. Warp Weights (Accumulate Weights)
+        mask_accum = self.fw(weights_map, flow)
+
+        # 4. Normalize (Color / TotalWeight)
+        # Add epsilon to avoid divide-by-zero in empty regions
+        mask_accum.clamp_(min=self.eps)
+        res = res_accum / mask_accum
+
+        # 5. Generate Binary Occlusion Mask (for Inpainting)
+        # Splat a grid of ones; where the sum is 0, we have a hole.
+        ones = torch.ones_like(disp)
+        occupancy = self.fw(ones, flow)
+
+        # Valid pixels have occupancy > 0.
+        # We want holes = 1.0, filled = 0.0
+        occlusion_mask = (occupancy < self.eps).float()
+
+        return res, occlusion_mask
+
+# ==============================================================================
+# 3. APP LOGIC
+# ==============================================================================
+
 # === LOAD MODELS ===
 def load_models():
     print("Loading Depth Anything V2 Large...")
-    # 1. Depth Model (Depth Anything V2 Large)
-    # We use AutoModel to automatically load the correct architecture
     depth_model = AutoModelForDepthEstimation.from_pretrained(
         "depth-anything/Depth-Anything-V2-Large-hf"
     ).to(device)
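
Note: two quick numeric checks of the weighting schemes introduced above; a standalone sketch assuming only PyTorch, not part of the commit. The four bilinear splat weights sum to 1 for any fractional destination, so splatting conserves energy; and because disparity is normalized to [0, 1] before warping, the soft z-buffer base of 1.5 gives the nearest plane only 1.5x the weight of the farthest, a deliberately soft occlusion rather than a hard overwrite.

    import torch

    # Bilinear splat weights for a fractional destination (x, y) = (0.3, 0.6)
    x_dest, y_dest = torch.tensor(0.3), torch.tensor(0.6)
    x_f, y_f = x_dest.floor(), y_dest.floor()
    x_c, y_c = x_f + 1, y_f + 1
    nw = (x_c - x_dest) * (y_c - y_dest)   # 0.7 * 0.4 = 0.28
    ne = (x_dest - x_f) * (y_c - y_dest)   # 0.3 * 0.4 = 0.12
    sw = (x_c - x_dest) * (y_dest - y_f)   # 0.7 * 0.6 = 0.42
    se = (x_dest - x_f) * (y_dest - y_f)   # 0.3 * 0.6 = 0.18
    assert torch.isclose(nw + ne + sw + se, torch.tensor(1.0))  # weights sum to 1

    # Soft z-buffer weights for normalized disparity [far, mid, near]
    disp = torch.tensor([0.0, 0.5, 1.0])
    w = 1.5 ** (disp - disp.min())
    print(w)            # tensor([1.0000, 1.2247, 1.5000])
    print(w[-1] / w[0]) # 1.5: the near/far weight ratio stays mild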
@@ -24,173 +187,61 @@ def load_models():
     )
 
     print("Loading LaMa Inpainting Model...")
-    # 2. LaMa Inpainting Model (TorchScript)
     try:
         model_path = hf_hub_download(repo_id="fashn-ai/LaMa", filename="big-lama.pt")
-        print(f"Loading LaMa from: {model_path}")
         lama_model = torch.jit.load(model_path, map_location=device)
         lama_model.eval()
     except Exception as e:
         print(f"Error loading LaMa model: {e}")
         raise e
 
-    return depth_model, depth_processor, lama_model
+    # Initialize the new Stereo Warper
+    stereo_warper = ForwardWarpStereo().to(device)
+
+    return depth_model, depth_processor, lama_model, stereo_warper
 
 # Load models once at startup
-depth_model, depth_processor, lama_model = load_models()
+depth_model, depth_processor, lama_model, stereo_warper = load_models()
 
 # === DEPTH ESTIMATION ===
 @torch.no_grad()
 def estimate_depth(image_pil, model, processor):
     original_size = image_pil.size
-
-    # Preprocess image
     inputs = processor(images=image_pil, return_tensors="pt").to(device)
-
-    # Inference
     depth = model(**inputs).predicted_depth
 
-    # Interpolate depth back to ORIGINAL image size
     depth = torch.nn.functional.interpolate(
         depth.unsqueeze(1),
         size=(original_size[1], original_size[0]),
         mode="bicubic",
         align_corners=False,
-    ).squeeze() # Shape: (H, W)
+    ).squeeze()
 
-    # Normalize depth to 0-1 range
     depth_min, depth_max = depth.min(), depth.max()
     if depth_max - depth_min > 0:
         depth = (depth - depth_min) / (depth_max - depth_min)
     else:
         depth = torch.zeros_like(depth)
-
     return depth
 
 # === DEPTH MANIPULATION ===
 def erode_depth(depth_tensor, kernel_size):
-    """
-    Shrinks the foreground (bright areas) of the depth map to reduce halos.
-    Uses -MaxPool2d(-x) to simulate Erosion on GPU.
-    """
-    if kernel_size <= 0:
-        return depth_tensor
-
-    # Ensure odd kernel size for symmetry
+    if kernel_size <= 0: return depth_tensor
     k = kernel_size if kernel_size % 2 == 1 else kernel_size + 1
-
-    # Reshape for pooling: (H, W) -> (1, 1, H, W)
     x = depth_tensor.unsqueeze(0).unsqueeze(0)
-
-    # Erosion = -MaxPool(-x)
-    # Padding = k // 2 ensures output size matches input size
     padding = k // 2
     x_eroded = -torch.nn.functional.max_pool2d(-x, kernel_size=k, stride=1, padding=padding)
-
     return x_eroded.squeeze()
 
-# === PYTORCH FORWARD WARP ===
-@torch.no_grad()
-def generate_right_and_mask_torch(image_pil, depth_tensor, divergence, convergence):
-    """
-    High-performance PyTorch Forward Warp implementation.
-    Mimics the behavior of custom CUDA forward warp kernels but uses standard PyTorch.
-
-    Args:
-        image_pil: Input PIL image
-        depth_tensor: Normalized depth tensor (H, W) on GPU
-        divergence: float (pixels)
-        convergence: float (0-1)
-    """
-    # 1. Prepare Data
-    w, h = image_pil.size
-
-    # Convert image to tensor (H, W, 3) -> (N, 3)
-    # We do this on GPU to stay fast
-    image_tensor = torch.from_numpy(np.array(image_pil)).to(device).float()
-
-    # Calculate Shift Map (N,)
-    # Shift = (Depth - Convergence) * Divergence
-    # Positive shift = Leftwards (Pop-out)
-    shift = (depth_tensor - convergence) * divergence
-
-    # 2. Create Grid Coordinates
-    y_coords, x_coords = torch.meshgrid(
-        torch.arange(h, device=device),
-        torch.arange(w, device=device),
-        indexing='ij'
-    )
-
-    # 3. Calculate Target Coordinates
-    # Target X = Source X - Shift
-    target_x = x_coords - shift.round() # Round to nearest pixel for sharp mapping
-
-    # 4. Flatten for advanced indexing
-    flat_y = y_coords.reshape(-1).long()
-    flat_x_target = target_x.reshape(-1).long()
-    flat_x_source = x_coords.reshape(-1).long()
-
-    # 5. Filter Invalid Points (Out of bounds)
-    valid_mask = (flat_x_target >= 0) & (flat_x_target < w)
-
-    flat_y = flat_y[valid_mask]
-    flat_x_target = flat_x_target[valid_mask]
-    flat_x_source = flat_x_source[valid_mask]
-    flat_shift = shift.reshape(-1)[valid_mask]
-
-    # 6. Z-BUFFERING / PAINTER'S ALGORITHM (Crucial for correct occlusion)
-    # We sort pixels by shift (depth).
-    # Less shift = Background (draw first)
-    # More shift = Foreground (draw last)
-    # This ensures foreground objects overwrite background objects at collision points.
-    sort_idx = torch.argsort(flat_shift)
-
-    flat_y = flat_y[sort_idx]
-    flat_x_target = flat_x_target[sort_idx]
-    flat_x_source = flat_x_source[sort_idx]
-
-    # 7. Write to Output
-    # Create output canvas (Black)
-    right_tensor = torch.zeros_like(image_tensor)
-
-    # Create mask (1.0 = hole, 0.0 = filled)
-    mask_tensor = torch.ones((h, w), device=device, dtype=torch.float32)
-
-    # Compute linear indices for target positions
-    # target_idx = y * w + x
-    target_indices = flat_y * w + flat_x_target
-    source_indices = flat_y * w + flat_x_source
-
-    # Flatten image for indexing
-    image_flat = image_tensor.reshape(-1, 3)
-    right_flat = right_tensor.reshape(-1, 3)
-    mask_flat = mask_tensor.reshape(-1)
-
-    # Perform the Warp
-    # Since we sorted by depth, the last write to any index wins (Foreground wins)
-    right_flat[target_indices] = image_flat[source_indices]
-    mask_flat[target_indices] = 0.0
-
-    # Reshape back
-    right_img = right_flat.reshape(h, w, 3).cpu().numpy().astype(np.uint8)
-    mask_img = mask_flat.reshape(h, w).cpu().numpy()
-
-    return right_img, mask_img
-
 # === LOCAL INPAINTING ===
 @torch.no_grad()
 def run_local_lama(image_bgr, mask_float):
-    """
-    Runs LaMa locally.
-    image_bgr: HxWx3 uint8 numpy array
-    mask_float: HxW float32 numpy array (1.0 = hole, 0.0 = valid)
-    """
-    # 0. Dilate Mask (Fixes smearing/streaking)
-    kernel = np.ones((5, 5), np.uint8)
+    # 0. Dilate Mask slightly to catch edge artifacts from splatting
+    kernel = np.ones((3, 3), np.uint8)
     mask_uint8 = (mask_float * 255).astype(np.uint8)
     mask_dilated = cv2.dilate(mask_uint8, kernel, iterations=1)
 
-    # 1. Resize to be divisible by 8 (LaMa requirement)
+    # 1. Resize to be divisible by 8
     h, w = image_bgr.shape[:2]
     new_h = (h // 8) * 8
    new_w = (w // 8) * 8
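
Note: erode_depth above relies on the morphological identity erode(x) = -maxpool(-x): max-pooling the negated map selects the local minimum, which shrinks bright (foreground) regions. A minimal check on a toy depth map, assuming only PyTorch; not part of the commit:

    import torch
    import torch.nn.functional as F

    d = torch.tensor([[0., 0., 0.],
                      [0., 1., 0.],
                      [0., 0., 0.]])       # one bright "foreground" pixel
    x = d.unsqueeze(0).unsqueeze(0)        # (1, 1, H, W), as the pooling op expects
    eroded = -F.max_pool2d(-x, kernel_size=3, stride=1, padding=1)
    print(eroded.squeeze())                # all zeros: a 3x3 erosion removes the pixel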
@@ -198,7 +249,7 @@ def run_local_lama(image_bgr, mask_float):
     img_resized = cv2.resize(image_bgr, (new_w, new_h))
     mask_resized = cv2.resize(mask_dilated, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
 
-    # 2. Convert to Torch Tensors
+    # 2. Convert to Torch
     img_t = torch.from_numpy(img_resized).float().permute(2, 0, 1).unsqueeze(0) / 255.0
     img_t = img_t[:, [2, 1, 0], :, :] # BGR to RGB
 
@@ -209,17 +260,14 @@ def run_local_lama(image_bgr, mask_float):
     mask_t = mask_t.to(device)
 
     # 3. Inference
-    img_t = img_t * (1 - mask_t) # Zero out holes
+    img_t = img_t * (1 - mask_t)
     inpainted_t = lama_model(img_t, mask_t)
 
     # 4. Post-process
     inpainted = inpainted_t[0].permute(1, 2, 0).cpu().numpy()
     inpainted = np.clip(inpainted * 255, 0, 255).astype(np.uint8)
-
-    # Swap back RGB to BGR
     inpainted = cv2.cvtColor(inpainted, cv2.COLOR_RGB2BGR)
 
-    # Resize back to original
     if new_h != h or new_w != w:
         inpainted = cv2.resize(inpainted, (w, h))
 
  inpainted = cv2.resize(inpainted, (w, h))
225
 
@@ -237,41 +285,60 @@ def make_anaglyph(left, right):
 # === PIPELINE ===
 def stereo_pipeline(image_pil, divergence, convergence, edge_erosion):
     if image_pil is None:
-        return None, None
+        return None, None, None, None
 
-    # Resize input if too large (Max Width: 1920)
+    # Resize input if too large
     w, h = image_pil.size
     if w > 1920:
         ratio = 1920 / w
         new_h = int(h * ratio)
-        print(f"Resizing input from {w}x{h} to 1920x{new_h}")
         image_pil = image_pil.resize((1920, new_h), Image.LANCZOS)
 
-    # 1. Depth (Using Depth Anything V2)
-    # Now returns a Tensor on GPU
+    # 1. Depth Estimation
     depth_tensor = estimate_depth(image_pil, depth_model, depth_processor)
 
-    # 2. Depth Manipulation (Erosion)
-    # This shrinks the foreground depth mask slightly to prevent "halo" pixels
-    # from being pulled along with the object.
+    # 2. Depth Erosion (optional halo reduction)
     if edge_erosion > 0:
         depth_tensor = erode_depth(depth_tensor, int(edge_erosion))
 
-    # 3. Forward Warp (PyTorch)
-    # Replaces the old NumPy warp + sorting
-    right_img_rgb, mask = generate_right_and_mask_torch(image_pil, depth_tensor, divergence, convergence)
+    # Visualize Depth
+    depth_vis = (depth_tensor.cpu().numpy() * 255).astype(np.uint8)
+    depth_image = Image.fromarray(depth_vis)
+
+    # 3. Forward Warp (Weighted Bilinear Splatting)
+    # Convert image to tensor (B, C, H, W)
+    image_tensor = torch.from_numpy(np.array(image_pil)).float().to(device).permute(2, 0, 1).unsqueeze(0) / 255.0
 
-    # Convert to BGR for Inpainting
+    # Prepare depth tensor (B, 1, H, W)
+    depth_input = depth_tensor.unsqueeze(0).unsqueeze(0)
+
+    # Run the new Stereo Warper.
+    # Note: divergence stays in raw pixels here; it could instead be scaled by
+    # the image width to make the slider roughly percentage-based.
+    with torch.no_grad():
+        right_img_tensor, mask_tensor = stereo_warper(
+            image_tensor,
+            depth_input,
+            float(convergence),
+            float(divergence)
+        )
+
+    # Convert results back to CPU/NumPy
+    right_img_rgb = (right_img_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
+    mask_vis = (mask_tensor.squeeze(0).squeeze(0).cpu().numpy() * 255).astype(np.uint8)
+
+    mask_image = Image.fromarray(mask_vis)
+
+    # 4. Inpainting
     right_img_bgr = cv2.cvtColor(right_img_rgb, cv2.COLOR_RGB2BGR)
+    mask_float = mask_tensor.squeeze().cpu().numpy()
 
-    # 4. Inpainting (Local LaMa)
-    right_filled_bgr = run_local_lama(right_img_bgr, mask)
-
-    # 5. Final Processing
+    right_filled_bgr = run_local_lama(right_img_bgr, mask_float)
+
+    # 5. Finalize
     left = image_pil
     right = Image.fromarray(cv2.cvtColor(right_filled_bgr, cv2.COLOR_BGR2RGB))
-
-    # 6. Composition
+
     width, height = left.size
     combined_image = Image.new('RGB', (width * 2, height))
     combined_image.paste(left, (0, 0))
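
Note: with the new slider defaults (divergence = 30 px, convergence = 0.5), the per-pixel shift (depth - convergence) * divergence puts the mid-depth plane exactly on screen: the nearest point moves 15 px one way and the farthest 15 px the other. A worked check, assuming only PyTorch; not part of the commit:

    import torch

    divergence, convergence = 30.0, 0.5
    depth = torch.tensor([0.0, 0.5, 1.0])        # background, focus plane, foreground
    shift = (depth - convergence) * divergence
    print(shift)   # tensor([-15.,   0.,  15.]); ForwardWarpStereo uses flow_x = -shift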
@@ -279,11 +346,9 @@ def stereo_pipeline(image_pil, divergence, convergence, edge_erosion):
 
     anaglyph_image = make_anaglyph(left, right)
 
-    return combined_image, anaglyph_image
+    return combined_image, anaglyph_image, depth_image, mask_image
 
 # === GRADIO UI ===
-
-# Custom CSS to limit width on large screens
 css = """
 .gradio-container {
     max-width: 1400px !important;
@@ -291,15 +356,13 @@ css = """
 }
 """
 
-with gr.Blocks(title="2D to 3D Stereo") as demo:
-    # WORKAROUND: Inject CSS via HTML to avoid "unexpected keyword argument" error
+with gr.Blocks(title="2D to 3D Stereo", css=css) as demo:
     gr.HTML(f"<style>{css}</style>")
 
-    gr.Markdown("## 2D to 3D Stereo Generator (Depth Anything V2)")
-    gr.Markdown("Generates stereo pairs using **Depth Anything V2 Large**, **PyTorch Forward Warp**, and Local LaMa Inpainting.")
+    gr.Markdown("## 2D to 3D Stereo Generator (High-Quality Splatting)")
+    gr.Markdown("Uses **Depth Anything V2**, **Bilinear Weighted Splatting** (Soft Z-Buffer), and **LaMa Inpainting**.")
 
     with gr.Row():
-        # --- LEFT COLUMN: INPUT & CONTROLS ---
         with gr.Column(scale=1):
             input_img = gr.Image(type="pil", label="Input Image", height=320)
 
  input_img = gr.Image(type="pil", label="Input Image", height=320)
305
 
@@ -308,30 +371,33 @@ with gr.Blocks(title="2D to 3D Stereo") as demo:
             divergence_slider = gr.Slider(
                 minimum=0, maximum=100, value=30, step=1,
                 label="3D Strength (Divergence)",
-                info="Max pixel separation."
+                info="Max separation in pixels."
             )
             convergence_slider = gr.Slider(
-                minimum=0.0, maximum=1.0, value=0.1, step=0.05,
+                minimum=0.0, maximum=1.0, value=0.5, step=0.05,
                 label="Focus Plane (Convergence)",
                 info="0.0 = Background at screen. 1.0 = Foreground at screen."
             )
             erosion_slider = gr.Slider(
-                minimum=0, maximum=20, value=5, step=1,
+                minimum=0, maximum=20, value=2, step=1,
                 label="Edge Masking (Erosion)",
-                info="Shrinks foreground depth to prevent halos/ghosting. Increase if edges look messy."
+                info="Cleanup edges. Set to 0 for raw splatting."
             )
 
             btn = gr.Button("Generate 3D", variant="primary")
 
-        # --- RIGHT COLUMN: OUTPUTS ---
         with gr.Column(scale=1):
             out_anaglyph = gr.Image(label="Anaglyph (Red/Cyan)", height=320)
             out_stereo = gr.Image(label="Side-by-Side Stereo Pair", height=320)
+
+    with gr.Row():
+        out_depth = gr.Image(label="Depth Map", height=200)
+        out_mask = gr.Image(label="Inpainting Mask (Holes)", height=200)
 
     btn.click(
         fn=stereo_pipeline,
         inputs=[input_img, divergence_slider, convergence_slider, erosion_slider],
-        outputs=[out_stereo, out_anaglyph]
+        outputs=[out_stereo, out_anaglyph, out_depth, out_mask]
     )
 
 if __name__ == "__main__":