enoky committed
Commit 549ff77 · verified · 1 Parent(s): 18f40ab

Update app.py

Files changed (1):
  1. app.py +139 -186
app.py CHANGED
@@ -4,7 +4,6 @@ import torch.nn as nn
  import numpy as np
  import cv2
  from PIL import Image
- from torch.autograd import Function
  from transformers import AutoModelForDepthEstimation, AutoImageProcessor
  from huggingface_hub import hf_hub_download
  import os
@@ -14,130 +13,100 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  print(f"Running on device: {device}")

  # ==============================================================================
- # 1. FORWARD WARP (unchanged - your version was already excellent)
  # ==============================================================================
- class ForwardWarpFunction(Function):
-     @staticmethod
-     def forward(ctx, im0, flow, interpolation_mode_int):
-         B, C, H, W = im0.shape
-         im1 = torch.zeros_like(im0, device=im0.device, dtype=im0.dtype).contiguous()
-
-         grid_x = torch.arange(W, device=im0.device, dtype=im0.dtype).unsqueeze(0).expand(H, W)
-         grid_y = torch.arange(H, device=im0.device, dtype=im0.dtype).unsqueeze(1).expand(H, W)
-         grid_x = grid_x.unsqueeze(0).expand(B, H, W)
-         grid_y = grid_y.unsqueeze(0).expand(B, H, W)
-
-         x_dest = grid_x + flow[:, :, :, 0]
-         y_dest = grid_y + flow[:, :, :, 1]
-
-         x_f = torch.floor(x_dest).long()
-         y_f = torch.floor(y_dest).long()
-         x_c = x_f + 1
-         y_c = y_f + 1
-
-         nw_k = (x_c.float() - x_dest) * (y_c.float() - y_dest)
-         ne_k = (x_dest - x_f.float()) * (y_c.float() - y_dest)
-         sw_k = (x_c.float() - x_dest) * (y_dest - y_f.float())
-         se_k = (x_dest - x_f.float()) * (y_dest - y_f.float())
-
-         x_f_clamped = torch.clamp(x_f, 0, W - 1)
-         y_f_clamped = torch.clamp(y_f, 0, H - 1)
-         x_c_clamped = torch.clamp(x_c, 0, W - 1)
-         y_c_clamped = torch.clamp(y_c, 0, H - 1)
-
-         mask_nw = (x_f >= 0) & (x_f < W) & (y_f >= 0) & (y_f < H)
-         mask_ne = (x_c >= 0) & (x_c < W) & (y_f >= 0) & (y_f < H)
-         mask_sw = (x_f >= 0) & (x_f < W) & (y_c >= 0) & (y_c < H)
-         mask_se = (x_c >= 0) & (x_c < W) & (y_c >= 0) & (y_c < H)
-
-         nw_k = nw_k.unsqueeze(1)
-         ne_k = ne_k.unsqueeze(1)
-         sw_k = sw_k.unsqueeze(1)
-         se_k = se_k.unsqueeze(1)
-         mask_nw = mask_nw.unsqueeze(1)
-         mask_ne = mask_ne.unsqueeze(1)
-         mask_sw = mask_sw.unsqueeze(1)
-         mask_se = mask_se.unsqueeze(1)
-
-         b_indices = torch.arange(B, device=im0.device).view(B, 1, 1, 1).expand(-1, C, H, W)
-         c_indices = torch.arange(C, device=im0.device).view(1, C, 1, 1).expand(B, -1, H, W)
-         base_idx = b_indices * (C * H * W) + c_indices * (H * W)
-
-         def scatter_corner(y_idx, x_idx, weights, mask):
-             flat_idx = base_idx + y_idx.unsqueeze(1) * W + x_idx.unsqueeze(1)
-             values = (im0 * weights) * mask.float()
-             im1.reshape(-1).scatter_add_(0, flat_idx.contiguous().reshape(-1), values.contiguous().reshape(-1))
-
-         scatter_corner(y_f_clamped, x_f_clamped, nw_k, mask_nw)
-         scatter_corner(y_f_clamped, x_c_clamped, ne_k, mask_ne)
-         scatter_corner(y_c_clamped, x_f_clamped, sw_k, mask_sw)
-         scatter_corner(y_c_clamped, x_c_clamped, se_k, mask_se)
-
-         return im1
-
-     @staticmethod
-     def backward(ctx, grad_output):
-         return None, None, None
-
- class forward_warp(nn.Module):
-     def __init__(self): super().__init__()
-     def forward(self, im0, flow):
-         return ForwardWarpFunction.apply(im0, flow, 0)

  # ==============================================================================
- # 2. STEREO WARPER - FIXED Z-BUFFER + SMART MASK DILATION
  # ==============================================================================
  class ForwardWarpStereo(nn.Module):
      def __init__(self, eps=1e-6):
          super().__init__()
          self.eps = eps
-         self.fw = forward_warp()

-     def forward(self, im, shift, disp_for_weights):
-         flow_x = -shift
          flow_y = torch.zeros_like(flow_x)
-         flow = torch.stack((flow_x, flow_y), dim=-1)

-         # ────── FIXED: Linear + bias weights (no more detached arms) ──────
-         disp_norm = disp_for_weights / (disp_for_weights.max() + 1e-8)
-         weights_map = disp_norm + 0.05
-         # ─────────────────────────────────────────────────────────────────────

-         res_accum = self.fw(im * weights_map.unsqueeze(1), flow)
-         mask_accum = self.fw(weights_map.unsqueeze(1), flow)
-         mask_accum.clamp_(min=self.eps)
-         res = res_accum / mask_accum

-         # Occupancy for occlusion detection
-         ones = torch.ones_like(im[:, 0:1, :, :])
-         occupancy = self.fw(ones, flow)
-         occlusion_mask = (occupancy < self.eps).float()

-         # ────── NEW: Smart, foreground-preserving mask dilation ──────
          with torch.no_grad():
-             # Protect clear foreground from over-dilation
-             fg_thresh = torch.quantile(disp_for_weights, 0.88)
              fg_mask = (disp_for_weights > fg_thresh).float().unsqueeze(0)

-             # Aggressive but safe dilation
-             k = 15
              dilated = torch.nn.functional.conv2d(
-                 occlusion_mask, torch.ones(1, 1, k, k, device=occlusion_mask.device),
-                 padding=k // 2) > 0.1
              safe_dilation = dilated.float() * (1 - fg_mask)
-             occlusion_mask = torch.clamp(occlusion_mask + safe_dilation, 0, 1)
-         # ─────────────────────────────────────────────────────────────────

-         return res, occlusion_mask

  # ==============================================================================
- # 3. MODELS & HELPERS (unchanged except LaMa now runs twice for perfection)
  # ==============================================================================
  def load_models():
      print("Loading Depth Anything V2 Large...")
      depth_model = AutoModelForDepthEstimation.from_pretrained(
          "depth-anything/Depth-Anything-V2-Large-hf"
-     ).to(device)
      depth_processor = AutoImageProcessor.from_pretrained(
          "depth-anything/Depth-Anything-V2-Large-hf"
      )
@@ -145,180 +114,164 @@ def load_models():
      print("Loading LaMa Inpainting Model...")
      try:
          model_path = hf_hub_download(repo_id="fashn-ai/LaMa", filename="big-lama.pt")
-         lama_model = torch.jit.load(model_path, map_location=device)
-         lama_model.eval()
      except Exception as e:
          print(f"LaMa load failed: {e}")
          lama_model = None

      stereo_warper = ForwardWarpStereo().to(device)
      return depth_model, depth_processor, lama_model, stereo_warper

  depth_model, depth_processor, lama_model, stereo_warper = load_models()

  @torch.no_grad()
- def estimate_depth(image_pil, model, processor):
      original_size = image_pil.size
-     inputs = processor(images=image_pil, return_tensors="pt").to(device)
-     depth = model(**inputs).predicted_depth
      depth = torch.nn.functional.interpolate(
          depth.unsqueeze(1),
          size=(original_size[1], original_size[0]),
          mode="bicubic",
          align_corners=False,
-     ).squeeze()

      d_min, d_max = depth.min(), depth.max()
-     if d_max - d_min > 0:
          depth = (depth - d_min) / (d_max - d_min)
-     else:
-         depth = torch.zeros_like(depth)
      return depth

  @torch.no_grad()
- def run_lama_twice(image_bgr, mask_float):
      if lama_model is None:
          return image_bgr

-     # First pass
-     img1 = run_local_lama(image_bgr, mask_float)
-
-     # Second pass with slightly larger mask
-     kernel = np.ones((9, 9), np.uint8)
-     mask_dilated = cv2.dilate(mask_float, kernel, iterations=2)
-     img2 = run_local_lama(img1, mask_dilated)
-
-     return img2
-
- def run_local_lama(image_bgr, mask_float):
-     if lama_model is None:
-         return image_bgr
-
-     kernel = np.ones((5, 5), np.uint8)
      mask_uint8 = (mask_float * 255).astype(np.uint8)
      mask_dilated = cv2.dilate(mask_uint8, kernel, iterations=2)

      h, w = image_bgr.shape[:2]
      new_h = (h // 8) * 8
      new_w = (w // 8) * 8
-
      img_resized = cv2.resize(image_bgr, (new_w, new_h))
      mask_resized = cv2.resize(mask_dilated, (new_w, new_h), interpolation=cv2.INTER_NEAREST)

-     img_t = torch.from_numpy(img_resized).float().permute(2, 0, 1).unsqueeze(0) / 255.0
-     img_t = img_t[:, [2, 1, 0], :, :].to(device)
-     mask_t = torch.from_numpy(mask_resized).float().unsqueeze(0).unsqueeze(0) / 255.0
      mask_t = (mask_t > 0.5).float().to(device)

      img_t = img_t * (1 - mask_t)
-     inpainted_t = lama_model(img_t, mask_t)
-
-     inpainted = inpainted_t[0].permute(1, 2, 0).cpu().numpy()
-     inpainted = np.clip(inpainted * 255, 0, 255).astype(np.uint8)
-     inpainted = cv2.cvtColor(inpainted, cv2.COLOR_RGB2BGR)
-     if new_h != h or new_w != w:
-         inpainted = cv2.resize(inpainted, (w, h))
-     return inpainted

  def make_anaglyph(left, right):
      l = np.array(left)
      r = np.array(right)
      ana = np.zeros_like(l)
-     ana[:, :, 0] = l[:, :, 0]  # Red ← Left eye
-     ana[:, :, 1] = r[:, :, 1]  # Green ← Right eye
-     ana[:, :, 2] = r[:, :, 2]  # Blue ← Right eye
      return Image.fromarray(ana)

  # ==============================================================================
- # 4. MAIN PIPELINE - FINAL CLEAN VERSION
  # ==============================================================================
  @torch.no_grad()
- def stereo_pipeline(image_pil, divergence_percent, convergence_plane):
      if image_pil is None:
          return None, None, None, None

      w, h = image_pil.size
      if w > 1920:
          ratio = 1920 / w
-         image_pil = image_pil.resize((int(w * ratio), int(h * ratio)), Image.LANCZOS)
          w, h = image_pil.size

      # 1. Depth
-     depth_tensor = estimate_depth(image_pil, depth_model, depth_processor)
-     depth_vis = (depth_tensor.cpu().numpy() * 255).astype(np.uint8)
-     depth_image = Image.fromarray(depth_vis)
-
-     # 2. Disparity (square for better volume)
-     disp_raw = depth_tensor ** 2
-     disp_max = torch.quantile(disp_raw, 0.995)
-     disp_clipped = torch.clamp(disp_raw, max=disp_max)

-     # 3. Shift calculation
-     max_shift_px = w * (divergence_percent / 100.0)
-     shift_pixels_raw = disp_clipped * max_shift_px

-     shift_min, shift_max = shift_pixels_raw.min(), shift_pixels_raw.max()
      convergence_offset = shift_min + convergence_plane * (shift_max - shift_min)
-     final_shift_pixels = shift_pixels_raw - convergence_offset

-     print(f"Shift range: {final_shift_pixels.min():.1f} → {final_shift_pixels.max():.1f} px")

-     # 4. Warp
-     image_tensor = torch.from_numpy(np.array(image_pil)).float().to(device) / 255.0
-     image_tensor = image_tensor.permute(2, 0, 1).unsqueeze(0)

-     shift_input = final_shift_pixels.unsqueeze(0).to(device)
-     disp_for_weights = disp_clipped.unsqueeze(0).to(device)

-     right_tensor, occlusion_mask = stereo_warper(image_tensor, shift_input, disp_for_weights)

-     # 5. Convert to numpy
-     right_rgb = (right_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
-     right_bgr = cv2.cvtColor(right_rgb, cv2.COLOR_RGB2BGR)
-     mask_np = occlusion_mask.squeeze().cpu().numpy()

-     # 6. Two-pass LaMa (perfect edges)
-     right_filled_bgr = run_lama_twice(right_bgr, mask_np)
      right_filled = Image.fromarray(cv2.cvtColor(right_filled_bgr, cv2.COLOR_BGR2RGB))

      # 7. Outputs
      mask_vis = Image.fromarray((mask_np * 255).astype(np.uint8))
-     combined = Image.new('RGB', (w * 2, h))
-     combined.paste(image_pil, (0, 0))
-     combined.paste(right_filled, (w, 0))
      anaglyph = make_anaglyph(image_pil, right_filled)

-     return combined, anaglyph, depth_image, mask_vis

  # ==============================================================================
- # 5. GRADIO UI - Simplified (erosion slider removed)
  # ==============================================================================
- with gr.Blocks(title="2D → 3D Stereo (Final Pro Version)") as demo:
-     gr.HTML("<h1 style='text-align:center;'>2D to 3D Stereo - Pro Quality</h1>")
-     gr.Markdown("Depth Anything V2 + Forward Warp + Smart Inpainting")

      with gr.Row():
          with gr.Column(scale=1):
-             input_img = gr.Image(type="pil", label="Upload Image", height=500)
              with gr.Accordion("Settings", open=True):
-                 divergence = gr.Slider(0.5, 8.0, value=3.2, step=0.1,
-                                        label="3D Strength (%)")
                  convergence = gr.Slider(0.0, 1.0, value=0.08, step=0.01,
-                                         label="Convergence Plane (0 = pop-out, 1 = deep-in)")
              btn = gr.Button("Generate 3D", variant="primary", size="lg")

          with gr.Column(scale=1):
-             out_anaglyph = gr.Image(label="Anaglyph (Red/Cyan Glasses)", height=500)
-             out_sbs = gr.Image(label="Side-by-Side Pair", height=300)
              with gr.Row():
                  out_depth = gr.Image(label="Depth Map", height=200)
-                 out_mask = gr.Image(label="Inpainting Mask", height=200)

-     btn.click(fn=stereo_pipeline,
-               inputs=[input_img, divergence, convergence],
-               outputs=[out_sbs, out_anaglyph, out_depth, out_mask])

-     gr.Markdown("**Tip:** Red/Cyan glasses → anaglyph • Cross-eye or parallel → side-by-side")

  if __name__ == "__main__":
      demo.launch(share=True)
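
Everything above is app.py before this commit; the updated file follows below. The removed ForwardWarpFunction implemented forward splatting: every source pixel scatters its value to its four destination neighbors with bilinear weights via scatter_add_. Here is a minimal 1-D sketch of that splatting idea, using a hypothetical splat_1d helper that is not part of the commit, for comparison with the grid_sample replacement below:

import torch

def splat_1d(src, flow):
    """src: [N] pixel values; flow: [N] destination offsets in pixels."""
    N = src.shape[0]
    dest = torch.arange(N, dtype=src.dtype) + flow
    x_f = torch.floor(dest).long()   # left neighbor
    x_c = x_f + 1                    # right neighbor
    w_c = dest - x_f.to(src.dtype)   # bilinear split of each pixel's mass
    w_f = 1.0 - w_c
    out = torch.zeros_like(src)
    for idx, w in ((x_f, w_f), (x_c, w_c)):
        ok = (idx >= 0) & (idx < N)  # drop splats that land off-image
        out.scatter_add_(0, idx[ok], (src * w)[ok])
    return out

print(splat_1d(torch.ones(5), torch.full((5,), 0.5)))
# tensor([0.5000, 1.0000, 1.0000, 1.0000, 1.0000]); pixel 4 loses half its mass off-image

The removed 2-D version does the same thing per corner (nw/ne/sw/se) over flattened [B, C, H, W] indices, which is why destinations that no source pixel reaches could stay black.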
 
  import numpy as np
  import cv2
  from PIL import Image
  from transformers import AutoModelForDepthEstimation, AutoImageProcessor
  from huggingface_hub import hf_hub_download
  import os

  print(f"Running on device: {device}")

  # ==============================================================================
+ # 1. SAFE & FAST FORWARD WARPER USING grid_sample (NO MORE BLACK IMAGES!)
  # ==============================================================================
+ class SafeForwardWarp(nn.Module):
+     def forward(self, img, flow):
+         """
+         img: [B, C, H, W] in [0,1]
+         flow: [B, H, W, 2] flow[...,0] = delta_x (positive = right), flow[...,1] = delta_y
+         """
+         B, C, H, W = img.shape
+
+         # Create the sampling grid in pixel coordinates, shaped [H, W]
+         # (y-range first so both grids come out [H, W], matching flow's layout)
+         grid_y, grid_x = torch.meshgrid(
+             torch.arange(H, device=img.device),
+             torch.arange(W, device=img.device),
+             indexing='ij'
+         )
+         grid_x = grid_x.float().unsqueeze(0).expand(B, -1, -1)  # [B, H, W]
+         grid_y = grid_y.float().unsqueeze(0).expand(B, -1, -1)
+
+         dest_x = grid_x + flow[..., 0]  # each output pixel samples the source at x + dx
+         dest_y = grid_y + flow[..., 1]
+
+         # Normalize to [-1, 1]
+         norm_x = 2.0 * dest_x / (W - 1) - 1.0
+         norm_y = 2.0 * dest_y / (H - 1) - 1.0
+
+         grid = torch.stack((norm_x, norm_y), dim=-1)  # [B, H, W, 2]
+         grid = grid.clamp(-1, 1)
+
+         warped = torch.nn.functional.grid_sample(
+             img,
+             grid,
+             mode='bilinear',
+             padding_mode='zeros',
+             align_corners=True
+         )
+         return warped

  # ==============================================================================
+ # 2. STEREO WARPER - Improved weighting + safer dilation
  # ==============================================================================
  class ForwardWarpStereo(nn.Module):
      def __init__(self, eps=1e-6):
          super().__init__()
          self.eps = eps
+         self.warp = SafeForwardWarp()

+     def forward(self, img, shift, disp_for_weights):
+         # shift: [B, H, W] (positive = shift right-eye left → object pops out)
+         flow_x = -shift  # negative = move pixels left for right eye
          flow_y = torch.zeros_like(flow_x)
+         flow = torch.stack((flow_x, flow_y), dim=-1)  # [B, H, W, 2]

+         # Better weighting: closer pixels (larger disparity) contribute more
+         weights = disp_for_weights + 0.1
+         weights = weights / (weights.max() + 1e-8)

+         weighted_img = img * weights.unsqueeze(1)
+         warped_img = self.warp(weighted_img, flow)
+         warped_weights = self.warp(weights.unsqueeze(1), flow)

+         # Avoid division by zero
+         warped_weights = torch.clamp(warped_weights, min=self.eps)
+         result = warped_img / warped_weights

+         # Occlusion mask via occupancy count
+         ones = torch.ones_like(img[:, :1])
+         occupancy = self.warp(ones, flow)
+         occlusion = (occupancy < self.eps).float()
+
+         # Smart dilation - preserve foreground edges
          with torch.no_grad():
+             fg_thresh = torch.quantile(disp_for_weights, 0.90)
              fg_mask = (disp_for_weights > fg_thresh).float().unsqueeze(0)

+             k = 9
              dilated = torch.nn.functional.conv2d(
+                 occlusion,
+                 torch.ones(1, 1, k, k, device=occlusion.device),
+                 padding=k // 2
+             ) > 0.5
              safe_dilation = dilated.float() * (1 - fg_mask)
+             occlusion = torch.clamp(occlusion + safe_dilation, 0, 1)

+         return result, occlusion

  # ==============================================================================
+ # 3. MODELS & HELPERS
  # ==============================================================================
  def load_models():
      print("Loading Depth Anything V2 Large...")
      depth_model = AutoModelForDepthEstimation.from_pretrained(
          "depth-anything/Depth-Anything-V2-Large-hf"
+     ).to(device).eval()
      depth_processor = AutoImageProcessor.from_pretrained(
          "depth-anything/Depth-Anything-V2-Large-hf"
      )

      print("Loading LaMa Inpainting Model...")
      try:
          model_path = hf_hub_download(repo_id="fashn-ai/LaMa", filename="big-lama.pt")
+         lama_model = torch.jit.load(model_path, map_location=device).eval()
      except Exception as e:
          print(f"LaMa load failed: {e}")
          lama_model = None

      stereo_warper = ForwardWarpStereo().to(device)
+
      return depth_model, depth_processor, lama_model, stereo_warper

  depth_model, depth_processor, lama_model, stereo_warper = load_models()

  @torch.no_grad()
+ def estimate_depth(image_pil):
      original_size = image_pil.size
+     inputs = depth_processor(images=image_pil, return_tensors="pt").to(device)
+     outputs = depth_model(**inputs)
+     depth = outputs.predicted_depth
+
      depth = torch.nn.functional.interpolate(
          depth.unsqueeze(1),
          size=(original_size[1], original_size[0]),
          mode="bicubic",
          align_corners=False,
+     ).squeeze(0).squeeze(0)

+     # Normalize to [0,1]
      d_min, d_max = depth.min(), depth.max()
+     if d_max > d_min:
          depth = (depth - d_min) / (d_max - d_min)
      return depth

  @torch.no_grad()
+ def run_lama(image_bgr, mask_float):
      if lama_model is None:
          return image_bgr

      mask_uint8 = (mask_float * 255).astype(np.uint8)
+     kernel = np.ones((7, 7), np.uint8)
      mask_dilated = cv2.dilate(mask_uint8, kernel, iterations=2)

      h, w = image_bgr.shape[:2]
      new_h = (h // 8) * 8
      new_w = (w // 8) * 8
      img_resized = cv2.resize(image_bgr, (new_w, new_h))
      mask_resized = cv2.resize(mask_dilated, (new_w, new_h), interpolation=cv2.INTER_NEAREST)

+     img_t = torch.from_numpy(img_resized).float().permute(2, 0, 1).unsqueeze(0) / 255.0
+     img_t = img_t[:, [2, 1, 0]].to(device)  # BGR → RGB
+     mask_t = torch.from_numpy(mask_resized).float().unsqueeze(0).unsqueeze(0) / 255.0
      mask_t = (mask_t > 0.5).float().to(device)

      img_t = img_t * (1 - mask_t)
+     inpainted = lama_model(img_t, mask_t)
+     result = (inpainted[0].permute(1, 2, 0).cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
+     result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
+     if (new_h, new_w) != (h, w):
+         result = cv2.resize(result, (w, h))
+     return result

  def make_anaglyph(left, right):
      l = np.array(left)
      r = np.array(right)
      ana = np.zeros_like(l)
+     ana[:, :, 0] = l[:, :, 0]  # Red ← Left
+     ana[:, :, 1] = r[:, :, 1]  # Green ← Right
+     ana[:, :, 2] = r[:, :, 2]  # Blue ← Right
      return Image.fromarray(ana)

  # ==============================================================================
+ # 4. MAIN PIPELINE
  # ==============================================================================
  @torch.no_grad()
+ def stereo_pipeline(image_pil, divergence_percent=3.2, convergence_plane=0.08):
      if image_pil is None:
          return None, None, None, None

      w, h = image_pil.size
      if w > 1920:
          ratio = 1920 / w
+         image_pil = image_pil.resize((int(w * ratio), int(h * ratio)), Image.LANCZOS)
          w, h = image_pil.size

      # 1. Depth
+     depth = estimate_depth(image_pil)  # [H, W] in [0,1]
+     depth_vis = Image.fromarray((depth.cpu().numpy() * 255).astype(np.uint8))

+     # 2. Disparity (stronger volume with square)
+     disp_raw = depth ** 2
+     disp_clipped = torch.clamp(disp_raw, max=torch.quantile(disp_raw, 0.995))

+     # 3. Shift
+     max_shift = w * (divergence_percent / 100.0)
+     shift_raw = disp_clipped * max_shift
+     shift_min, shift_max = shift_raw.min(), shift_raw.max()
      convergence_offset = shift_min + convergence_plane * (shift_max - shift_min)
+     final_shift = shift_raw - convergence_offset

+     print(f"Final shift range: {final_shift.min():.1f} → {final_shift.max():.1f} px")

+     # 4. Warp right eye
+     img_tensor = torch.from_numpy(np.array(image_pil)).float().to(device) / 255.0
+     img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0)  # [1,3,H,W]

+     shift_tensor = final_shift.unsqueeze(0).to(device)  # [1,H,W]
+     disp_tensor = disp_clipped.unsqueeze(0).to(device)

+     right_tensor, occlusion_mask = stereo_warper(img_tensor, shift_tensor, disp_tensor)

+     # 5. To numpy
+     right_np = (right_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
+     right_bgr = cv2.cvtColor(right_np, cv2.COLOR_RGB2BGR)
+     mask_np = occlusion_mask.squeeze().cpu().numpy()  # mask is [1,1,H,W]; squeeze to [H,W]

+     # 6. Inpaint occlusions
+     right_filled_bgr = run_lama(right_bgr, mask_np)
      right_filled = Image.fromarray(cv2.cvtColor(right_filled_bgr, cv2.COLOR_BGR2RGB))

      # 7. Outputs
      mask_vis = Image.fromarray((mask_np * 255).astype(np.uint8))
+
+     sbs = Image.new('RGB', (w * 2, h))
+     sbs.paste(image_pil, (0, 0))
+     sbs.paste(right_filled, (w, 0))
+
      anaglyph = make_anaglyph(image_pil, right_filled)

+     return sbs, anaglyph, depth_vis, mask_vis

  # ==============================================================================
+ # 5. GRADIO UI
  # ==============================================================================
+ with gr.Blocks(title="2D → 3D Stereo - Pro & Stable") as demo:
+     gr.HTML("<h1 style='text-align:center;'>2D to 3D Stereo - Pro Quality (Fixed & Stable)</h1>")
+     gr.Markdown("Depth Anything V2 + Safe Forward Warping + LaMa Inpainting")

      with gr.Row():
          with gr.Column(scale=1):
+             input_img = gr.Image(type="pil", label="Upload Image", height=520)
              with gr.Accordion("Settings", open=True):
+                 divergence = gr.Slider(0.5, 8.0, value=3.5, step=0.1, label="3D Strength (%)")
                  convergence = gr.Slider(0.0, 1.0, value=0.08, step=0.01,
+                                         label="Convergence Plane (0 = pop-out, 1 = deep)")
              btn = gr.Button("Generate 3D", variant="primary", size="lg")

          with gr.Column(scale=1):
+             out_anaglyph = gr.Image(label="Anaglyph (Red/Cyan Glasses)", height=520)
+             out_sbs = gr.Image(label="Side-by-Side (Cross-eye / Parallel)", height=300)
              with gr.Row():
                  out_depth = gr.Image(label="Depth Map", height=200)
+                 out_mask = gr.Image(label="Occlusion Mask", height=200)

+     btn.click(
+         fn=stereo_pipeline,
+         inputs=[input_img, divergence, convergence],
+         outputs=[out_sbs, out_anaglyph, out_depth, out_mask]
+     )

+     gr.Markdown("**Tip:** Use Red/Cyan glasses for anaglyph • Cross-eye or parallel view for SBS")

  if __name__ == "__main__":
      demo.launch(share=True)
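
For a standalone sanity check of the grid_sample technique this commit adopts, the sketch below warps a horizontal ramp image by a uniform 2-pixel disparity. The warp_right_eye helper and the test values are illustrative assumptions, not part of the committed app:

import torch
import torch.nn.functional as F

def warp_right_eye(img, shift):
    """img: [B,3,H,W] in [0,1]; shift: [B,H,W] disparity in pixels."""
    B, _, H, W = img.shape
    # Pixel-coordinate grids shaped [H, W] (y-range first so both come out [H, W])
    grid_y, grid_x = torch.meshgrid(
        torch.arange(H, dtype=img.dtype),
        torch.arange(W, dtype=img.dtype),
        indexing='ij',
    )
    # The right eye samples the source at x - shift, mirroring flow_x = -shift above
    dest_x = grid_x.unsqueeze(0) - shift
    dest_y = grid_y.unsqueeze(0).expand(B, -1, -1)
    # Normalize to [-1, 1] for grid_sample (align_corners=True convention)
    norm_x = 2.0 * dest_x / (W - 1) - 1.0
    norm_y = 2.0 * dest_y / (H - 1) - 1.0
    grid = torch.stack((norm_x, norm_y), dim=-1).clamp(-1, 1)  # [B,H,W,2]
    return F.grid_sample(img, grid, mode='bilinear',
                         padding_mode='zeros', align_corners=True)

B, H, W = 1, 4, 8
ramp = torch.arange(W, dtype=torch.float32).expand(H, W) / (W - 1)  # value = x / (W-1)
img = ramp.unsqueeze(0).unsqueeze(0).repeat(B, 3, 1, 1)
right = warp_right_eye(img, torch.full((B, H, W), 2.0))
print(right[0, 0, 0])
# tensor([0.0000, 0.0000, 0.0000, 0.1429, 0.2857, 0.4286, 0.5714, 0.7143])
# Interior pixels look 2 px to the left; the clamped border repeats column 0.

Because the coordinates are clamped, this sampling never leaves unwritten (black) pixels; the trade-off is that true occlusions must still be detected and inpainted separately, which the occupancy mask and LaMa pass above are meant to handle.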