enoky commited on
Commit
be50bae
·
verified ·
1 Parent(s): b5cd334

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -135
app.py CHANGED
@@ -7,14 +7,13 @@ from PIL import Image
7
  from torch.autograd import Function
8
  from transformers import AutoModelForDepthEstimation, AutoImageProcessor
9
  from huggingface_hub import hf_hub_download
10
- import os
11
 
12
  # === DEVICE ===
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
  print(f"Running on device: {device}")
15
 
16
  # ==============================================================================
17
- # 1. FIXED FORWARD WARP WITH BILINEAR SPLATTING (Contiguous & Stable)
18
  # ==============================================================================
19
  class ForwardWarpFunction(Function):
20
  @staticmethod
@@ -26,7 +25,6 @@ class ForwardWarpFunction(Function):
26
  B, C, H, W = im0.shape
27
  im1 = torch.zeros_like(im0)
28
 
29
- # Grid: [B, H, W]
30
  grid_y, grid_x = torch.meshgrid(
31
  torch.arange(H, device=im0.device, dtype=torch.float32),
32
  torch.arange(W, device=im0.device, dtype=torch.float32),
@@ -44,51 +42,44 @@ class ForwardWarpFunction(Function):
44
  x1 = x0 + 1
45
  y1 = y0 + 1
46
 
47
- # Bilinear weights
48
- w00 = (x1.float() - x_dest) * (y1.float() - y_dest) # top-left
49
- w10 = (x_dest - x0.float()) * (y1.float() - y_dest) # top-right
50
- w01 = (x1.float() - x_dest) * (y_dest - y0.float()) # bottom-left
51
- w11 = (x_dest - x0.float()) * (y_dest - y0.float()) # bottom-right
52
 
53
- # Clamp coordinates
54
  x0c = x0.clamp(0, W - 1)
55
  y0c = y0.clamp(0, H - 1)
56
  x1c = x1.clamp(0, W - 1)
57
  y1c = y1.clamp(0, H - 1)
58
 
59
- valid = (x0 >= 0) & (x1 < W) & (y0 >= 0) & (y1 < H) # [B, H, W]
60
 
61
- # Ensure contiguous
62
  im0 = im0.contiguous()
63
- valid = valid.unsqueeze(1).float() # [B, 1, H, W]
64
 
65
  def splat(y_idx, x_idx, weight):
66
- weight = (weight.unsqueeze(1) * valid).contiguous() # [B,1,H,W]
67
- values = (im0 * weight).reshape(B * C, -1) # [B*C, H*W]
68
 
69
- # Compute flat indices: B,C,H,W → global index
70
- b_idx = torch.arange(B, device=im0.device).view(B, 1, 1, 1)
71
- c_idx = torch.arange(C, device=im0.device).view(1, C, 1, 1)
72
- base = (b_idx * C * H * W + c_idx * H * W).expand(-1, -1, H, W)
73
 
74
- idx = base + y_idx.unsqueeze(1) * W + x_idx.unsqueeze(1)
75
- idx = idx.reshape(B * C, -1).contiguous()
76
-
77
- im1.view(-1).scatter_add_(0, idx.view(-1), values.view(-1))
78
 
79
  splat(y0c, x0c, w00)
80
  splat(y0c, x1c, w10)
81
  splat(y1c, x0c, w01)
82
  splat(y1c, x1c, w11)
83
 
84
- else: # Nearest neighbor (fallback)
85
  x_nn = torch.round(x_dest).long().clamp(0, W - 1)
86
  y_nn = torch.round(y_dest).long().clamp(0, H - 1)
87
 
88
  b_idx = torch.arange(B, device=im0.device)[:, None, None, None]
89
  c_idx = torch.arange(C, device=im0.device)[None, :, None, None]
90
- idx = (b_idx * C * H * W + c_idx * H * W + y_nn.unsqueeze(1) * W + x_nn.unsqueeze(1))
91
- idx = idx.reshape(-1)
92
 
93
  valid = ((x_nn >= 0) & (x_nn < W) & (y_nn >= 0) & (y_nn < H)).unsqueeze(1)
94
  values = (im0 * valid.float()).reshape(-1)
@@ -112,7 +103,7 @@ class forward_warp(nn.Module):
112
 
113
 
114
  # ==============================================================================
115
- # 2. STEREO WARPER (Soft Z-Buffer Splatting)
116
  # ==============================================================================
117
  class ForwardWarpStereo(nn.Module):
118
  def __init__(self, eps=1e-6):
@@ -121,23 +112,19 @@ class ForwardWarpStereo(nn.Module):
121
  self.fw = forward_warp(interpolation_mode="Bilinear")
122
 
123
  def forward(self, im, disp, convergence, divergence):
124
- disp = disp.squeeze(1) # [B, H, W]
125
  shift = (disp - convergence) * divergence
126
  flow_x = -shift
127
- flow = torch.zeros_like(flow_x)
128
- flow = torch.stack([flow_x, flow_y], dim=-1) # [B, H, W, 2]
129
 
130
- # Soft Z-buffer weights (closer = higher weight)
131
  weights = (1.5) ** (disp - disp.min())
132
 
133
- # Warp color * weight
134
  accum_color = self.fw(im * weights.unsqueeze(1), flow)
135
  accum_weight = self.fw(weights.unsqueeze(1), flow)
136
 
137
- # Normalize
138
  result = accum_color / (accum_weight + self.eps)
139
 
140
- # Occlusion mask (holes)
141
  ones = torch.ones_like(disp)
142
  occupancy = self.fw(ones.unsqueeze(1), flow)
143
  occlusion_mask = (occupancy < self.eps).float()
@@ -146,7 +133,7 @@ class ForwardWarpStereo(nn.Module):
146
 
147
 
148
  # ==============================================================================
149
- # 3. MODELS & PIPELINE
150
  # ==============================================================================
151
  def load_models():
152
  print("Loading Depth Anything V2 Large...")
@@ -158,30 +145,28 @@ def load_models():
158
  "depth-anything/Depth-Anything-V2-Large-hf"
159
  )
160
 
161
- print("Loading LaMa Inpainting Model...")
162
  model_path = hf_hub_download(repo_id="fashn-ai/LaMa", filename="big-lama.pt")
163
  lama_model = torch.jit.load(model_path, map_location=device).eval()
164
 
165
  stereo_warper = ForwardWarpStereo().to(device)
166
-
167
  return depth_model, depth_processor, lama_model, stereo_warper
168
 
169
 
170
- # Load once at startup
171
  depth_model, depth_processor, lama_model, stereo_warper = load_models()
172
 
173
 
 
 
 
174
  @torch.inference_mode()
175
  def estimate_depth(image_pil):
176
- original_size = image_pil.size
177
  inputs = depth_processor(images=image_pil, return_tensors="pt").to(device)
178
- depth = depth_model(**inputs).predicted_depth # [1, H, W]
179
 
180
  depth = torch.nn.functional.interpolate(
181
- depth.unsqueeze(1),
182
- size=(original_size[1], original_size[0]),
183
- mode="bicubic",
184
- align_corners=False,
185
  ).squeeze()
186
 
187
  depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)
@@ -193,148 +178,112 @@ def erode_depth(depth_tensor, kernel_size):
193
  return depth_tensor
194
  k = kernel_size if kernel_size % 2 == 1 else kernel_size + 1
195
  x = depth_tensor.unsqueeze(0).unsqueeze(0)
196
- padding = k // 2
197
- eroded = -torch.nn.functional.max_pool2d(-x, kernel_size=k, stride=1, padding=padding)
198
- return eroded.squeeze()
199
 
200
 
201
  @torch.inference_mode()
202
  def run_local_lama(image_bgr, mask_float):
203
- kernel = np.ones((3, 3), np.uint8)
204
- mask_uint8 = (mask_float * 255).astype(np.uint8)
205
- mask_dilated = cv2.dilate(mask_uint8, kernel, iterations=1)
206
 
207
  h, w = image_bgr.shape[:2]
208
- new_h = (h // 8) * 8
209
- new_w = (w // 8) * 8
210
 
211
  img_resized = cv2.resize(image_bgr, (new_w, new_h))
212
  mask_resized = cv2.resize(mask_dilated, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
213
 
214
- img_t = torch.from_numpy(img_resized).float().permute(2, 0, 1).unsqueeze(0) / 255.0
215
- img_t = img_t[:, [2, 1, 0], :, :] # BGR RGB
216
- mask_t = torch.from_numpy(mask_resized).float().unsqueeze(0).unsqueeze(0) / 255.0
217
- mask_t = (mask_t > 0.5).float()
218
 
219
- img_t = img_t.to(device)
220
- mask_t = mask_t.to(device)
221
 
222
- img_t = img_t * (1 - mask_t)
223
- with torch.no_grad():
224
- inpainted_t = lama_model(img_t, mask_t)
225
-
226
- inpainted = inpainted_t[0].permute(1, 2, 0).cpu().numpy()
227
- inpainted = np.clip(inpainted * 255, 0, 255).astype(np.uint8)
228
- inpainted = cv2.cvtColor(inpainted, cv2.COLOR_RGB2BGR)
229
-
230
- if (new_h != h) or (new_w != w):
231
- inpainted = cv2.resize(inpainted, (w, h), interpolation=cv2.INTER_LANCZOS4)
232
-
233
- return inpainted
234
 
235
 
236
  def make_anaglyph(left, right):
237
  l = np.array(left)
238
  r = np.array(right)
239
- anaglyph = np.zeros_like(l)
240
- anaglyph[:, :, 0] = l[:, :, 0] # Red ← Left
241
- anaglyph[:, :, 1] = r[:, :, 1] # Green ← Right
242
- anaglyph[:, :, 2] = r[:, :, 2] # Blue ← Right
243
- return Image.fromarray(anaglyph)
244
 
245
 
246
- # ==============================================================================
247
- # MAIN PIPELINE
248
- # ==============================================================================
249
  def stereo_pipeline(image_pil, divergence, convergence, edge_erosion):
250
  if image_pil is None:
251
  return None, None, None, None
252
 
253
- # Resize if too large (HF Spaces limit)
254
- w, h = image_pil.size
255
- if w > 1920:
256
- ratio = 1920 / w
257
- image_pil = image_pil.resize((1920, int(h * ratio)), Image.LANCZOS)
258
 
259
- # 1. Depth
260
  depth = estimate_depth(image_pil)
261
  if edge_erosion > 0:
262
  depth = erode_depth(depth, int(edge_erosion))
263
 
264
- depth_vis = Image.fromarray((depth.cpu().numpy() * 255).astype(np.uint8))
265
 
266
- # 2. Prepare tensors
267
- img_tensor = torch.from_numpy(np.array(image_pil)).float().to(device)
268
- img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0) / 255.0 # [1,3,H,W]
269
- depth_tensor = depth.unsqueeze(0).unsqueeze(0) # [1,1,H,W]
270
 
271
- # 3. Stereo warp
272
  with torch.inference_mode():
273
- right_tensor, mask_tensor = stereo_warper(
274
- img_tensor, depth_tensor, float(convergence), float(divergence)
275
- )
276
 
277
- right_np = (right_tensor.squeeze(0).permute(1,2,0).cpu().numpy() * 255).astype(np.uint8)
278
- mask_np = (mask_tensor.squeeze().cpu().numpy() * 255).astype(np.uint8)
279
 
280
- # 4. Inpaint holes
281
  right_bgr = cv2.cvtColor(right_np, cv2.COLOR_RGB2BGR)
282
- mask_float = mask_tensor.squeeze().cpu().numpy()
283
- right_filled_bgr = run_local_lama(right_bgr, mask_float)
284
- right_filled = Image.fromarray(cv2.cvtColor(right_filled_bgr, cv2.COLOR_BGR2RGB))
285
 
286
- # 5. Outputs
287
  w, h = image_pil.size
288
- sbs = Image.new("RGB", (w * 2, h))
289
- sbs.paste(image_pil, (0, 0))
290
- sbs.paste(right_filled, (w, 0))
291
 
292
- anaglyph = make_anaglyph(image_pil, right_filled)
293
 
294
  return sbs, anaglyph, depth_vis, Image.fromarray(mask_np)
295
 
296
 
297
  # ==============================================================================
298
- # GRADIO UI
299
  # ==============================================================================
300
- css = ".gradio-container {max-width: 1400px !important; margin: auto !important;}"
 
301
 
302
- with gr.Blocks(css=css, title="2D → 3D Stereo (Depth Anything + Splatting)") as demo:
303
  gr.Markdown("# 2D to 3D Stereo Generator")
304
- gr.Markdown("High-quality automatic stereo conversion using **Depth Anything V2**, **bilinear splatting with soft Z-buffer**, and **LaMa inpainting**.")
305
 
306
  with gr.Row():
307
- with gr.Column(scale=1):
308
- input_img = gr.Image(type="pil", label="Upload Image", height=400)
309
-
310
- with gr.Accordion("3D Settings", open=True):
311
- divergence_slider = gr.Slider(0, 100, value=30, step=1,
312
- label="3D Strength (Divergence)",
313
- info="Higher = stronger 3D pop-out")
314
- convergence_slider = gr.Slider(0.0, 1.0, value=0.5, step=0.05,
315
- label="Focus Plane (Convergence)",
316
- info="0 = background at screen, 1 = foreground at screen")
317
- erosion_slider = gr.Slider(0, 20, value=3, step=1,
318
- label="Edge Cleanup (Depth Erosion)",
319
- info="Reduces halos, 0 = raw")
320
-
321
- btn = gr.Button("Generate 3D Stereo", variant="primary", size="lg")
322
-
323
- with gr.Column(scale=1):
324
- out_stereo = gr.Image(label="Side-by-Side Stereo Pair", height=400)
325
- out_anaglyph = gr.Image(label="Anaglyph (Red/Cyan Glasses)", height=400)
326
 
327
- with gr.Row():
328
- out_depth = gr.Image(label="Estimated Depth Map", height=200)
329
- out_mask = gr.Image(label="Inpainting Mask (Holes)", height=200)
 
330
 
331
- btn.click(
332
- fn=stereo_pipeline,
333
- inputs=[input_img, divergence_slider, convergence_slider, erosion_slider],
334
- outputs=[out_stereo, out_anaglyph, out_depth, out_mask]
335
- )
 
 
 
 
336
 
337
- gr.Markdown("Made with Depth Anything V2 Bilinear Splatting • LaMa • Gradio")
338
 
339
  if __name__ == "__main__":
340
  demo.launch()
 
7
  from torch.autograd import Function
8
  from transformers import AutoModelForDepthEstimation, AutoImageProcessor
9
  from huggingface_hub import hf_hub_download
 
10
 
11
  # === DEVICE ===
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
  print(f"Running on device: {device}")
14
 
15
  # ==============================================================================
16
+ # 1. FIXED FORWARD WARP (Bilinear Splatting fully contiguous)
17
  # ==============================================================================
18
  class ForwardWarpFunction(Function):
19
  @staticmethod
 
25
  B, C, H, W = im0.shape
26
  im1 = torch.zeros_like(im0)
27
 
 
28
  grid_y, grid_x = torch.meshgrid(
29
  torch.arange(H, device=im0.device, dtype=torch.float32),
30
  torch.arange(W, device=im0.device, dtype=torch.float32),
 
42
  x1 = x0 + 1
43
  y1 = y0 + 1
44
 
45
+ w00 = (x1.float() - x_dest) * (y1.float() - y_dest)
46
+ w10 = (x_dest - x0.float()) * (y1.float() - y_dest)
47
+ w01 = (x1.float() - x_dest) * (y_dest - y0.float())
48
+ w11 = (x_dest - x0.float()) * (y_dest - y0.float())
 
49
 
 
50
  x0c = x0.clamp(0, W - 1)
51
  y0c = y0.clamp(0, H - 1)
52
  x1c = x1.clamp(0, W - 1)
53
  y1c = y1.clamp(0, H - 1)
54
 
55
+ valid = (x0 >= 0) & (x1 < W) & (y0 >= 0) & (y1 < H)
56
 
 
57
  im0 = im0.contiguous()
58
+ valid = valid.unsqueeze(1).float()
59
 
60
  def splat(y_idx, x_idx, weight):
61
+ weight = (weight.unsqueeze(1) * valid).contiguous()
62
+ values = (im0 * weight).reshape(B * C, -1)
63
 
64
+ b_idx = torch.arange(B, device=im0.device)[:, None, None, None]
65
+ c_idx = torch.arange(C, device=im0.device)[None, :, None, None]
66
+ base = b_idx * C * H * W + c_idx * H * W
67
+ idx = (base + y_idx.unsqueeze(1) * W + x_idx.unsqueeze(1)).reshape(B * C, -1)
68
 
69
+ im1.view(-1).scatter_add_(0, idx.reshape(-1), values.reshape(-1))
 
 
 
70
 
71
  splat(y0c, x0c, w00)
72
  splat(y0c, x1c, w10)
73
  splat(y1c, x0c, w01)
74
  splat(y1c, x1c, w11)
75
 
76
+ else: # Nearest neighbor fallback
77
  x_nn = torch.round(x_dest).long().clamp(0, W - 1)
78
  y_nn = torch.round(y_dest).long().clamp(0, H - 1)
79
 
80
  b_idx = torch.arange(B, device=im0.device)[:, None, None, None]
81
  c_idx = torch.arange(C, device=im0.device)[None, :, None, None]
82
+ idx = (b_idx * C * H * W + c_idx * H * W + y_nn.unsqueeze(1) * W + x_nn.unsqueeze(1)).reshape(-1)
 
83
 
84
  valid = ((x_nn >= 0) & (x_nn < W) & (y_nn >= 0) & (y_nn < H)).unsqueeze(1)
85
  values = (im0 * valid.float()).reshape(-1)
 
103
 
104
 
105
  # ==============================================================================
106
+ # 2. STEREO WARPER
107
  # ==============================================================================
108
  class ForwardWarpStereo(nn.Module):
109
  def __init__(self, eps=1e-6):
 
112
  self.fw = forward_warp(interpolation_mode="Bilinear")
113
 
114
  def forward(self, im, disp, convergence, divergence):
115
+ disp = disp.squeeze(1)
116
  shift = (disp - convergence) * divergence
117
  flow_x = -shift
118
+ flow_y = torch.zeros_like(flow_x)
119
+ flow = torch.stack([flow_x, flow_y], dim=-1)
120
 
 
121
  weights = (1.5) ** (disp - disp.min())
122
 
 
123
  accum_color = self.fw(im * weights.unsqueeze(1), flow)
124
  accum_weight = self.fw(weights.unsqueeze(1), flow)
125
 
 
126
  result = accum_color / (accum_weight + self.eps)
127
 
 
128
  ones = torch.ones_like(disp)
129
  occupancy = self.fw(ones.unsqueeze(1), flow)
130
  occlusion_mask = (occupancy < self.eps).float()
 
133
 
134
 
135
  # ==============================================================================
136
+ # 3. MODELS
137
  # ==============================================================================
138
  def load_models():
139
  print("Loading Depth Anything V2 Large...")
 
145
  "depth-anything/Depth-Anything-V2-Large-hf"
146
  )
147
 
148
+ print("Loading LaMa Inpainting...")
149
  model_path = hf_hub_download(repo_id="fashn-ai/LaMa", filename="big-lama.pt")
150
  lama_model = torch.jit.load(model_path, map_location=device).eval()
151
 
152
  stereo_warper = ForwardWarpStereo().to(device)
 
153
  return depth_model, depth_processor, lama_model, stereo_warper
154
 
155
 
 
156
  depth_model, depth_processor, lama_model, stereo_warper = load_models()
157
 
158
 
159
+ # ==============================================================================
160
+ # 4. PIPELINE
161
+ # ==============================================================================
162
  @torch.inference_mode()
163
  def estimate_depth(image_pil):
164
+ w, h = image_pil.size
165
  inputs = depth_processor(images=image_pil, return_tensors="pt").to(device)
166
+ depth = depth_model(**inputs).predicted_depth
167
 
168
  depth = torch.nn.functional.interpolate(
169
+ depth.unsqueeze(1), size=(h, w), mode="bicubic", align_corners=False
 
 
 
170
  ).squeeze()
171
 
172
  depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)
 
178
  return depth_tensor
179
  k = kernel_size if kernel_size % 2 == 1 else kernel_size + 1
180
  x = depth_tensor.unsqueeze(0).unsqueeze(0)
181
+ return -torch.nn.functional.max_pool2d(-x, kernel_size=k, stride=1, padding=k//2).squeeze()
 
 
182
 
183
 
184
  @torch.inference_mode()
185
  def run_local_lama(image_bgr, mask_float):
186
+ kernel = np.ones((3,3), np.uint8)
187
+ mask_dilated = cv2.dilate((mask_float*255).astype(np.uint8), kernel, iterations=1)
 
188
 
189
  h, w = image_bgr.shape[:2]
190
+ new_h, new_w = (h//8)*8, (w//8)*8
 
191
 
192
  img_resized = cv2.resize(image_bgr, (new_w, new_h))
193
  mask_resized = cv2.resize(mask_dilated, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
194
 
195
+ img_t = torch.from_numpy(img_resized).float().permute(2,0,1).unsqueeze(0)/255.0
196
+ img_t = img_t[:, [2,1,0]] # BGR→RGB
197
+ mask_t = torch.from_numpy(mask_resized).float().unsqueeze(0).unsqueeze(0)/255.0 > 0.5
 
198
 
199
+ img_t = img_t.to(device) * (1 - mask_t.to(device))
200
+ inpainted = lama_model(img_t.to(device), mask_t.to(device))
201
 
202
+ out = (inpainted[0].permute(1,2,0).cpu().numpy()*255).clip(0,255).astype(np.uint8)
203
+ out = cv2.cvtColor(out, cv2.COLOR_RGB2BGR)
204
+ if (new_h, new_w) != (h, w):
205
+ out = cv2.resize(out, (w, h), interpolation=cv2.INTER_LANCZOS4)
206
+ return out
 
 
 
 
 
 
 
207
 
208
 
209
  def make_anaglyph(left, right):
210
  l = np.array(left)
211
  r = np.array(right)
212
+ a = np.zeros_like(l)
213
+ a[:,:,0] = l[:,:,0]
214
+ a[:,:,1] = r[:,:,1]
215
+ a[:,:,2] = r[:,:,2]
216
+ return Image.fromarray(a)
217
 
218
 
 
 
 
219
  def stereo_pipeline(image_pil, divergence, convergence, edge_erosion):
220
  if image_pil is None:
221
  return None, None, None, None
222
 
223
+ # Downscale huge images
224
+ if image_pil.width > 1920:
225
+ ratio = 1920 / image_pil.width
226
+ image_pil = image_pil.resize((1920, int(image_pil.height*ratio)), Image.LANCZOS)
 
227
 
 
228
  depth = estimate_depth(image_pil)
229
  if edge_erosion > 0:
230
  depth = erode_depth(depth, int(edge_erosion))
231
 
232
+ depth_vis = Image.fromarray((depth.cpu().numpy()*255).astype(np.uint8))
233
 
234
+ img_t = torch.from_numpy(np.array(image_pil)).float().to(device).permute(2,0,1).unsqueeze(0)/255.0
235
+ depth_t = depth.unsqueeze(0).unsqueeze(0)
 
 
236
 
 
237
  with torch.inference_mode():
238
+ right_t, mask_t = stereo_warper(img_t, depth_t, float(convergence), float(divergence))
 
 
239
 
240
+ right_np = (right_t[0].permute(1,2,0).cpu().numpy()*255).astype(np.uint8)
241
+ mask_np = (mask_t[0,0].cpu().numpy()*255).astype(np.uint8)
242
 
 
243
  right_bgr = cv2.cvtColor(right_np, cv2.COLOR_RGB2BGR)
244
+ right_filled = run_local_lama(right_bgr, mask_t[0,0].cpu().numpy())
245
+ right_pil = Image.fromarray(cv2.cvtColor(right_filled, cv2.COLOR_BGR2RGB))
 
246
 
247
+ # Side-by-side
248
  w, h = image_pil.size
249
+ sbs = Image.new("RGB", (w*2, h))
250
+ sbs.paste(image_pil, (0,0))
251
+ sbs.paste(right_pil, (w,0))
252
 
253
+ anaglyph = make_anaglyph(image_pil, right_pil)
254
 
255
  return sbs, anaglyph, depth_vis, Image.fromarray(mask_np)
256
 
257
 
258
  # ==============================================================================
259
+ # GRADIO UI (compatible with latest Gradio)
260
  # ==============================================================================
261
+ with gr.Blocks(title="2D to 3D Stereo Depth Anything + Splatting") as demo:
262
+ gr.HTML("<style>.gradio-container {max-width: 1400px !important; margin: auto !important;}</style>")
263
 
 
264
  gr.Markdown("# 2D to 3D Stereo Generator")
265
+ gr.Markdown("Depth Anything V2 + Bilinear Splatting + LaMa Inpainting beautiful 3D")
266
 
267
  with gr.Row():
268
+ with gr.Column():
269
+ inp = gr.Image(type="pil", label="Input Image", height=400)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
 
271
+ with gr.Accordion("Settings", open=True):
272
+ div = gr.Slider(0, 100, 30, step=1, label="3D Strength (Divergence)")
273
+ conv = gr.Slider(0.0, 1.0, 0.5, step=0.05, label="Focus Plane (Convergence)")
274
+ ero = gr.Slider(0, 20, 3, step=1, label="Edge Cleanup (Erosion)")
275
 
276
+ btn = gr.Button("Generate 3D", variant="primary")
277
+
278
+ with gr.Column():
279
+ out_sbs = gr.Image(label="Side-by-Side", height=400)
280
+ out_ana = gr.Image(label="Anaglyph (Red/Cyan)", height=400)
281
+
282
+ with gr.Row():
283
+ out_depth = gr.Image(label="Depth Map")
284
+ out_mask = gr.Image(label="Holes Mask")
285
 
286
+ btn.click(stereo_pipeline, [inp, div, conv, ero], [out_sbs, out_ana, out_depth, out_mask])
287
 
288
  if __name__ == "__main__":
289
  demo.launch()