enoky committed on
Commit
66f7927
·
verified ·
1 Parent(s): cd5cadf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +252 -150
app.py CHANGED
@@ -7,11 +7,15 @@ from PIL import Image
7
  from torch.autograd import Function
8
  from transformers import AutoModelForDepthEstimation, AutoImageProcessor
9
  from huggingface_hub import hf_hub_download
 
10
 
 
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
- print(f"Running on {device}")
13
 
14
- # ==================== 1. FORWARD WARP (unchanged) ====================
 
 
15
  class ForwardWarpFunction(Function):
16
  @staticmethod
17
  def forward(ctx, im0, flow, interpolation_mode_int):
@@ -26,202 +30,300 @@ class ForwardWarpFunction(Function):
26
  x_dest = grid_x + flow[:, :, :, 0]
27
  y_dest = grid_y + flow[:, :, :, 1]
28
 
29
- x_f = torch.floor(x_dest).long(); x_c = x_f + 1
30
- y_f = torch.floor(y_dest).long(); y_c = y_f + 1
 
 
31
 
32
  nw_k = (x_c.float() - x_dest) * (y_c.float() - y_dest)
33
  ne_k = (x_dest - x_f.float()) * (y_c.float() - y_dest)
34
  sw_k = (x_c.float() - x_dest) * (y_dest - y_f.float())
35
  se_k = (x_dest - x_f.float()) * (y_dest - y_f.float())
36
 
37
- x_f_clamped = torch.clamp(x_f, 0, W-1); x_c_clamped = torch.clamp(x_c, 0, W-1)
38
- y_f_clamped = torch.clamp(y_f, 0, H-1); y_c_clamped = torch.clamp(y_c, 0, H-1)
 
 
39
 
40
  mask_nw = (x_f >= 0) & (x_f < W) & (y_f >= 0) & (y_f < H)
41
  mask_ne = (x_c >= 0) & (x_c < W) & (y_f >= 0) & (y_f < H)
42
  mask_sw = (x_f >= 0) & (x_f < W) & (y_c >= 0) & (y_c < H)
43
  mask_se = (x_c >= 0) & (x_c < W) & (y_c >= 0) & (y_c < H)
44
 
45
- for w,k,m in [(nw_k,mask_nw),(ne_k,mask_ne),(sw_k,mask_sw),(se_k,mask_se)]:
46
- w.unsqueeze_(1); m.unsqueeze_(1)
47
-
48
- b_idx = torch.arange(B, device=im0.device).view(B,1,1,1).expand(-1,C,H,W)
49
- c_idx = torch.arange(C, device=im0.device).view(1,C,1,1).expand(B,-1,H,W)
50
- base = b_idx * (C*H*W) + c_idx * (H*W)
51
-
52
- def scatter(y_idx, x_idx, weights, mask):
53
- flat = base + y_idx.unsqueeze(1)*W + x_idx.unsqueeze(1)
54
- val = (im0 * weights) * mask.float()
55
- im1.reshape(-1).scatter_add_(0, flat.reshape(-1), val.reshape(-1))
56
-
57
- scatter(y_f_clamped, x_f_clamped, nw_k, mask_nw)
58
- scatter(y_f_clamped, x_c_clamped, ne_k, mask_ne)
59
- scatter(y_c_clamped, x_f_clamped, sw_k, mask_sw)
60
- scatter(y_c_clamped, x_c_clamped, se_k, mask_se)
 
 
 
 
 
 
61
 
62
  return im1
63
 
64
  @staticmethod
65
- def backward(ctx, grad_output): return None,None,None
 
66
 
67
  class forward_warp(nn.Module):
68
- def forward(self, im0, flow): return ForwardWarpFunction.apply(im0, flow, 0)
 
 
69
 
70
- # ==================== 2. STEREO WARPER (fixed + safe dilation) ====================
 
 
71
  class ForwardWarpStereo(nn.Module):
72
- def __init__(self): super().__init__()
73
- def forward(self, im, shift, disp):
74
- flow = torch.stack((-shift, torch.zeros_like(shift)), dim=-1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
- # Fixed linear weights – no more detached arms
77
- weights = disp / (disp.max() + 1e-8) + 0.05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
- warped = forward_warp()(im * weights.unsqueeze(1), flow)
80
- wmap = forward_warp()(weights.unsqueeze(1), flow)
81
- wmap.clamp_(min=1e-6)
82
- res = warped / wmap
83
 
84
- occ = forward_warp()(torch.ones_like(im[:,:1]), flow) < 1e-6
 
85
 
86
- # Smart dilation that never eats foreground
87
- with torch.no_grad():
88
- fg = (disp > disp.quantile(0.88)).float().unsqueeze(0)
89
- dilated = torch.nn.functional.conv2d(occ.float(), torch.ones(1,1,15,15,device=device), padding=7) > 0.1
90
- occ = torch.clamp(occ.float() + dilated * (1-fg), 0, 1)
91
-
92
- return res, occ
93
-
94
- stereo_warper = ForwardWarpStereo().to(device)
95
-
96
- # ==================== 3. MODELS ====================
97
- print("Loading Depth Anything V2 Large...")
98
- depth_model = AutoModelForDepthEstimation.from_pretrained(
99
- "depth-anything/Depth-Anything-V2-Large-hf").to(device)
100
- processor = AutoImageProcessor.from_pretrained(
101
- "depth-anything/Depth-Anything-V2-Large-hf")
102
-
103
- print("Loading LaMa...")
104
- try:
105
- lama_path = hf_hub_download("fashn-ai/LaMa", "big-lama.pt")
106
- lama_model = torch.jit.load(lama_path, map_location=device).eval()
107
- except:
108
- lama_model = None
109
- print("LaMa not available – inpainting will be skipped")
110
-
111
- # ==================== 4. HELPERS ====================
112
- @torch.no_grad()
113
- def estimate_depth(img_pil):
114
- inputs = processor(images=img_pil, return_tensors="pt").to(device)
115
- d = depth_model(**inputs).predicted_depth
116
- d = torch.nn.functional.interpolate(d.unsqueeze(1), size=img_pil.size[::-1],
117
- mode="bicubic", align_corners=False).squeeze()
118
- d = (d - d.min()) / (d.max() - d.min() + 1e-8)
119
- return d
120
-
121
- def safe_dilate(mask_np, k=5, it=2):
122
- if mask_np.sum() == 0: return mask_np
123
- kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k,k))
124
- return cv2.dilate(mask_np, kernel, iterations=it)
125
 
126
- @torch.no_grad()
127
- def lama_inpaint(img_bgr, mask_np):
128
- if lama_model is None or mask_np.sum() == 0:
129
- return img_bgr
 
 
 
 
 
130
 
131
- mask_dil = safe_dilate((mask_np*255).astype(np.uint8), k=7, it=3) / 255.0
 
 
132
 
133
- h, w = img_bgr.shape[:2]
134
- nh, nw = (h//8)*8, (w//8)*8
135
- img_res = cv2.resize(img_bgr, (nw, nh))
136
- mask_res = cv2.resize(mask_dil, (nw, nh), interpolation=cv2.INTER_NEAREST)
137
 
138
- img_t = torch.from_numpy(img_res).float().permute(2,0,1).unsqueeze(0)/255.0
139
- img_t = img_t[:,[2,1,0]].to(device)
140
- mask_t = torch.from_numpy(mask_res > 0.5).float().unsqueeze(0).unsqueeze(0).to(device)
 
141
 
142
  img_t = img_t * (1 - mask_t)
143
- out = lama_model(img_t, mask_t)[0].permute(1,2,0).cpu().numpy()
144
- out = np.clip(out*255, 0, 255).astype(np.uint8)
145
- out = cv2.cvtColor(out, cv2.COLOR_RGB2BGR)
146
- if (nh,nw) != (h,w):
147
- out = cv2.resize(out, (w,h))
148
- return out
149
-
150
- def make_anaglyph(l, r):
151
- l = np.array(l); r = np.array(r)
152
- a = np.zeros_like(l)
153
- a[...,0] = l[...,0]
154
- a[...,1] = r[...,1]
155
- a[...,2] = r[...,2]
156
- return Image.fromarray(a)
157
-
158
- # ==================== 5. MAIN PIPELINE ====================
 
 
 
 
 
159
  @torch.no_grad()
160
- def stereo_pipeline(img_pil, strength=3.2, convergence=0.08):
161
- if img_pil is None: return None,None,None,None
 
162
 
163
- w, h = img_pil.size
164
  if w > 1920:
165
- ratio = 1920/w
166
- img_pil = img_pil.resize((int(w*ratio), int(h*ratio)), Image.LANCZOS)
167
- w, h = img_pil.size
168
 
169
- depth = estimate_depth(img_pil)
170
- disp = torch.clamp(depth**2, max=torch.quantile(depth**2, 0.995))
 
 
171
 
172
- max_shift = w * strength / 100.0
173
- shift = disp * max_shift
174
- shift = shift - shift.min() - convergence * (shift.max() - shift.min())
 
175
 
176
- tensor = torch.from_numpy(np.array(img_pil)).float().to(device)/255.0
177
- tensor = tensor.permute(2,0,1).unsqueeze(0)
 
178
 
179
- right, occ = stereo_warper(tensor, shift.unsqueeze(0), disp.unsqueeze(0))
 
 
180
 
181
- right_np = (right.squeeze(0).permute(1,2,0).cpu().numpy()*255).astype(np.uint8)
182
- right_bgr = cv2.cvtColor(right_np, cv2.COLOR_RGB2BGR)
183
- mask_np = occ.squeeze(0).cpu().numpy()
184
 
185
- # Two-pass LaMa (safe + perfect edges)
186
- right_filled = lama_inpaint(right_bgr, mask_np)
187
- right_filled = lama_inpaint(right_filled, mask_np) # second pass
188
 
189
- right_pil = Image.fromarray(cv2.cvtColor(right_filled, cv2.COLOR_BGR2RGB))
 
190
 
191
- sbs = Image.new("RGB", (w*2, h))
192
- sbs.paste(img_pil, (0,0))
193
- sbs.paste(right_pil, (w,0))
194
 
195
- ana = make_anaglyph(img_pil, right_pil)
196
- depth_vis = Image.fromarray((depth.cpu().numpy()*255).astype(np.uint8))
197
- mask_vis = Image.fromarray((mask_np*255).astype(np.uint8))
 
198
 
199
- return sbs, ana, depth_vis, mask_vis
 
 
200
 
201
- # ==================== 6. GRADIO UI ====================
202
- css = ".gradio-container {max-width: 1450px !important; margin: auto !important;}"
203
- with gr.Blocks() as demo:
204
- gr.HTML(f"<style>{css}</style>")
205
- gr.Markdown("# 2D β†’ 3D Stereo – Pro Quality\nDepth Anything V2 + Forward Warp + Smart LaMa")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
  with gr.Row():
208
- with gr.Column():
209
- inp = gr.Image(type="pil", label="Upload Image", height=520)
210
  with gr.Accordion("Settings", open=True):
211
- strength = gr.Slider(0.5, 8, 3.2, step=0.1, label="3D Strength (%)")
212
- conv = gr.Slider(0, 1, 0.08, step=0.01, label="Convergence (0=pop-out)")
213
- btn = gr.Button("Generate 3D", variant="primary")
214
-
215
- with gr.Column():
216
- out_ana = gr.Image(label="Anaglyph (Red/Cyan)", height=520)
217
- out_sbs = gr.Image(label="Side-by-Side", height=320)
 
 
218
  with gr.Row():
219
- gr.Image(label="Depth Map", height=200)
220
- gr.Image(label="Mask", height=200)
221
 
222
- btn.click(stereo_pipeline, [inp, strength, conv],
223
- [out_sbs, out_ana, gr.Image(), gr.Image()])
 
224
 
225
- gr.Markdown("**Red/Cyan glasses** β†’ anaglyph β€’ **Cross-eye/parallel** β†’ side-by-side")
226
 
227
- demo.launch(share=True)
 
 
7
  from torch.autograd import Function
8
  from transformers import AutoModelForDepthEstimation, AutoImageProcessor
9
  from huggingface_hub import hf_hub_download
10
+ import os
11
 
12
# === DEVICE ===
# Select CUDA when available; every model and tensor below is moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on device: {device}")
15
 
16
+ # ==============================================================================
17
+ # 1. FORWARD WARP (unchanged β€” your version was already excellent)
18
+ # ==============================================================================
19
  class ForwardWarpFunction(Function):
20
  @staticmethod
21
  def forward(ctx, im0, flow, interpolation_mode_int):
 
30
  x_dest = grid_x + flow[:, :, :, 0]
31
  y_dest = grid_y + flow[:, :, :, 1]
32
 
33
+ x_f = torch.floor(x_dest).long()
34
+ y_f = torch.floor(y_dest).long()
35
+ x_c = x_f + 1
36
+ y_c = y_f + 1
37
 
38
  nw_k = (x_c.float() - x_dest) * (y_c.float() - y_dest)
39
  ne_k = (x_dest - x_f.float()) * (y_c.float() - y_dest)
40
  sw_k = (x_c.float() - x_dest) * (y_dest - y_f.float())
41
  se_k = (x_dest - x_f.float()) * (y_dest - y_f.float())
42
 
43
+ x_f_clamped = torch.clamp(x_f, 0, W - 1)
44
+ y_f_clamped = torch.clamp(y_f, 0, H - 1)
45
+ x_c_clamped = torch.clamp(x_c, 0, W - 1)
46
+ y_c_clamped = torch.clamp(y_c, 0, H - 1)
47
 
48
  mask_nw = (x_f >= 0) & (x_f < W) & (y_f >= 0) & (y_f < H)
49
  mask_ne = (x_c >= 0) & (x_c < W) & (y_f >= 0) & (y_f < H)
50
  mask_sw = (x_f >= 0) & (x_f < W) & (y_c >= 0) & (y_c < H)
51
  mask_se = (x_c >= 0) & (x_c < W) & (y_c >= 0) & (y_c < H)
52
 
53
+ nw_k = nw_k.unsqueeze(1)
54
+ ne_k = ne_k.unsqueeze(1)
55
+ sw_k = sw_k.unsqueeze(1)
56
+ se_k = se_k.unsqueeze(1)
57
+ mask_nw = mask_nw.unsqueeze(1)
58
+ mask_ne = mask_ne.unsqueeze(1)
59
+ mask_sw = mask_sw.unsqueeze(1)
60
+ mask_se = mask_se.unsqueeze(1)
61
+
62
+ b_indices = torch.arange(B, device=im0.device).view(B, 1, 1, 1).expand(-1, C, H, W)
63
+ c_indices = torch.arange(C, device=im0.device).view(1, C, 1, 1).expand(B, -1, H, W)
64
+ base_idx = b_indices * (C * H * W) + c_indices * (H * W)
65
+
66
+ def scatter_corner(y_idx, x_idx, weights, mask):
67
+ flat_idx = base_idx + y_idx.unsqueeze(1) * W + x_idx.unsqueeze(1)
68
+ values = (im0 * weights) * mask.float()
69
+ im1.reshape(-1).scatter_add_(0, flat_idx.contiguous().reshape(-1), values.contiguous().reshape(-1))
70
+
71
+ scatter_corner(y_f_clamped, x_f_clamped, nw_k, mask_nw)
72
+ scatter_corner(y_f_clamped, x_c_clamped, ne_k, mask_ne)
73
+ scatter_corner(y_c_clamped, x_f_clamped, sw_k, mask_sw)
74
+ scatter_corner(y_c_clamped, x_c_clamped, se_k, mask_se)
75
 
76
  return im1
77
 
78
    @staticmethod
    def backward(ctx, grad_output):
        # The warp is used inference-only in this app: no gradients are
        # propagated to im0, flow, or the interpolation-mode flag.
        return None, None, None
81
 
82
class forward_warp(nn.Module):
    """Thin nn.Module wrapper around ForwardWarpFunction (bilinear mode 0)."""

    # The explicit no-op ``def __init__(self): super().__init__()`` was
    # removed: nn.Module's default constructor already does exactly that.

    def forward(self, im0, flow):
        """Forward-warp image batch ``im0`` (B,C,H,W) by per-pixel ``flow`` (B,H,W,2)."""
        return ForwardWarpFunction.apply(im0, flow, 0)
86
 
87
+ # ==============================================================================
88
+ # 2. STEREO WARPER β€” FIXED Z-BUFFER + SMART MASK DILATION
89
+ # ==============================================================================
90
class ForwardWarpStereo(nn.Module):
    """Warp the left view into the right view via forward splatting.

    Pixels are shifted horizontally by ``-shift``; overlapping splats are
    blended with disparity-derived weights so nearer content dominates, and
    an occlusion mask marks pixels that received no source contribution.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps          # numerical floor for the weight accumulator
        self.fw = forward_warp()

    def forward(self, im, shift, disp_for_weights):
        """Return ``(right_view, occlusion_mask)`` for a (B,C,H,W) image.

        Args:
            im: source image batch, (B, C, H, W), values in [0, 1].
            shift: per-pixel horizontal shift in pixels, (B, H, W).
            disp_for_weights: disparity used for splat weighting, (B, H, W).
        """
        # Horizontal-only flow: right view = left view shifted by -shift px.
        flow_x = -shift
        flow_y = torch.zeros_like(flow_x)
        flow = torch.stack((flow_x, flow_y), dim=-1)

        # Linear + bias weights: every pixel keeps a minimum weight of 0.05
        # so background is never fully suppressed (no detached arms).
        disp_norm = disp_for_weights / (disp_for_weights.max() + 1e-8)
        weights_map = disp_norm + 0.05

        res_accum = self.fw(im * weights_map.unsqueeze(1), flow)
        mask_accum = self.fw(weights_map.unsqueeze(1), flow)
        mask_accum.clamp_(min=self.eps)
        res = res_accum / mask_accum

        # Occupancy: pixels that nothing splatted onto are occlusions.
        ones = torch.ones_like(im[:, 0:1, :, :])
        occupancy = self.fw(ones, flow)
        occlusion_mask = (occupancy < self.eps).float()

        # Smart, foreground-preserving mask dilation.
        with torch.no_grad():
            # Protect clear foreground (top 12% disparity) from over-dilation.
            # NOTE(review): quantile is taken over the whole batch, not
            # per-sample; fine for the B == 1 usage in this app.
            fg_thresh = torch.quantile(disp_for_weights, 0.88)
            # BUGFIX: unsqueeze(1) makes fg_mask (B, 1, H, W), matching the
            # occlusion mask. The original unsqueeze(0) produced (1, B, H, W),
            # which only broadcast correctly by accident when B == 1.
            fg_mask = (disp_for_weights > fg_thresh).float().unsqueeze(1)

            # Aggressive but safe dilation: 15x15 box, >10% coverage counts.
            k = 15
            dilated = torch.nn.functional.conv2d(
                occlusion_mask,
                torch.ones(1, 1, k, k, device=occlusion_mask.device),
                padding=k // 2) > 0.1
            safe_dilation = dilated.float() * (1 - fg_mask)
            occlusion_mask = torch.clamp(occlusion_mask + safe_dilation, 0, 1)

        return res, occlusion_mask
132
+
133
# ==============================================================================
# 3. MODELS & HELPERS (unchanged except LaMa now runs twice for perfection)
# ==============================================================================
def load_models():
    """Instantiate the depth, inpainting, and warping models on `device`.

    Returns:
        Tuple ``(depth_model, depth_processor, lama_model, stereo_warper)``.
        ``lama_model`` is None when the checkpoint cannot be fetched/loaded.
    """
    repo = "depth-anything/Depth-Anything-V2-Large-hf"
    print("Loading Depth Anything V2 Large...")
    depth_net = AutoModelForDepthEstimation.from_pretrained(repo).to(device)
    depth_proc = AutoImageProcessor.from_pretrained(repo)

    print("Loading LaMa Inpainting Model...")
    try:
        ckpt_path = hf_hub_download(repo_id="fashn-ai/LaMa", filename="big-lama.pt")
        inpainter = torch.jit.load(ckpt_path, map_location=device)
        inpainter.eval()
    except Exception as e:
        # Best-effort: the pipeline degrades gracefully without inpainting.
        print(f"LaMa load failed: {e}")
        inpainter = None

    warper = ForwardWarpStereo().to(device)
    return depth_net, depth_proc, inpainter, warper

depth_model, depth_processor, lama_model, stereo_warper = load_models()
158
 
159
@torch.no_grad()
def estimate_depth(image_pil, model, processor):
    """Predict a depth map for a PIL image, normalized into [0, 1].

    Returns an (H, W) tensor matching the input resolution; an all-zeros
    map is returned when the prediction is constant (avoids divide-by-zero).
    """
    w, h = image_pil.size
    inputs = processor(images=image_pil, return_tensors="pt").to(device)
    raw = model(**inputs).predicted_depth
    # Resample the low-resolution prediction back to the image resolution.
    raw = torch.nn.functional.interpolate(
        raw.unsqueeze(1),
        size=(h, w),
        mode="bicubic",
        align_corners=False,
    ).squeeze()

    lo, hi = raw.min(), raw.max()
    if hi - lo > 0:
        return (raw - lo) / (hi - lo)
    return torch.zeros_like(raw)
177
 
178
@torch.no_grad()
def run_lama_twice(image_bgr, mask_float):
    """Inpaint occluded regions in two passes for cleaner edges.

    The second pass uses a slightly grown mask so any seams left by the
    first pass get painted over as well. Returns the input unchanged when
    the LaMa model is unavailable.
    """
    if lama_model is None:
        return image_bgr

    # Pass 1: fill the raw occlusion mask.
    first = run_local_lama(image_bgr, mask_float)

    # Pass 2: grow the mask (9x9 kernel, 2 iterations) and refine the seams.
    grow_kernel = np.ones((9, 9), np.uint8)
    widened = cv2.dilate(mask_float, grow_kernel, iterations=2)
    return run_local_lama(first, widened)
192
+
193
def run_local_lama(image_bgr, mask_float):
    """Single LaMa inpainting pass over a BGR uint8 image.

    The mask is slightly dilated, both inputs are resized to multiples of 8
    (LaMa's stride requirement), inpainted on `device`, and the result is
    resized back to the original resolution.
    """
    if lama_model is None:
        return image_bgr

    # Dilate a little so thin warp-edge artifacts are covered too.
    mask_u8 = (mask_float * 255).astype(np.uint8)
    mask_u8 = cv2.dilate(mask_u8, np.ones((5, 5), np.uint8), iterations=2)

    # LaMa expects spatial dimensions divisible by 8.
    h, w = image_bgr.shape[:2]
    h8, w8 = (h // 8) * 8, (w // 8) * 8
    img_small = cv2.resize(image_bgr, (w8, h8))
    mask_small = cv2.resize(mask_u8, (w8, h8), interpolation=cv2.INTER_NEAREST)

    # BGR uint8 -> RGB float tensor in [0, 1] on the model device.
    img_t = torch.from_numpy(img_small).float().permute(2, 0, 1).unsqueeze(0) / 255.0
    img_t = img_t[:, [2, 1, 0], :, :].to(device)
    mask_t = torch.from_numpy(mask_small).float().unsqueeze(0).unsqueeze(0) / 255.0
    mask_t = (mask_t > 0.5).float().to(device)

    # Zero out the holes before feeding the network.
    img_t = img_t * (1 - mask_t)
    out_t = lama_model(img_t, mask_t)

    out = out_t[0].permute(1, 2, 0).cpu().numpy()
    out = np.clip(out * 255, 0, 255).astype(np.uint8)
    out = cv2.cvtColor(out, cv2.COLOR_RGB2BGR)
    if (h8, w8) != (h, w):
        out = cv2.resize(out, (w, h))
    return out
222
+
223
def make_anaglyph(left, right):
    """Compose a red/cyan anaglyph: red from the left eye, green+blue from the right."""
    left_arr = np.array(left)
    right_arr = np.array(right)
    out = np.zeros_like(left_arr)
    out[..., 0] = left_arr[..., 0]    # red channel   <- left eye
    out[..., 1] = right_arr[..., 1]   # green channel <- right eye
    out[..., 2] = right_arr[..., 2]   # blue channel  <- right eye
    return Image.fromarray(out)
231
+
232
# ==============================================================================
# 4. MAIN PIPELINE
# ==============================================================================
@torch.no_grad()
def stereo_pipeline(image_pil, divergence_percent, convergence_plane):
    """Turn a single 2D image into stereo outputs.

    Args:
        image_pil: input PIL image (or None when Gradio passes no image).
        divergence_percent: maximum horizontal shift as % of image width.
        convergence_plane: 0..1 position of the zero-parallax plane
            (0 = everything pops out, 1 = everything recedes).

    Returns:
        (side_by_side, anaglyph, depth_visualization, occlusion_mask) as
        PIL images, or four Nones when no input was given.
    """
    if image_pil is None:
        return None, None, None, None

    # Cap width at 1920 px to bound memory and runtime.
    w, h = image_pil.size
    if w > 1920:
        ratio = 1920 / w
        image_pil = image_pil.resize((int(w*ratio), int(h*ratio)), Image.LANCZOS)
        w, h = image_pil.size

    # 1. Depth
    depth_tensor = estimate_depth(image_pil, depth_model, depth_processor)
    depth_vis = (depth_tensor.cpu().numpy() * 255).astype(np.uint8)
    depth_image = Image.fromarray(depth_vis)

    # 2. Disparity (squared for better volume); clip the top 0.5% outliers.
    disp_raw = depth_tensor ** 2
    disp_max = torch.quantile(disp_raw, 0.995)
    disp_clipped = torch.clamp(disp_raw, max=disp_max)

    # 3. Shift calculation
    max_shift_px = w * (divergence_percent / 100.0)
    shift_pixels_raw = disp_clipped * max_shift_px

    # Re-center shifts around the chosen convergence plane.
    shift_min, shift_max = shift_pixels_raw.min(), shift_pixels_raw.max()
    convergence_offset = shift_min + convergence_plane * (shift_max - shift_min)
    final_shift_pixels = shift_pixels_raw - convergence_offset

    # NOTE(review): the arrow glyph below looks mis-encoded in the source;
    # verify the file is saved as UTF-8.
    print(f"Shift range: {final_shift_pixels.min():.1f} β†’ {final_shift_pixels.max():.1f} px")

    # 4. Warp the left view into the right view.
    image_tensor = torch.from_numpy(np.array(image_pil)).float().to(device) / 255.0
    image_tensor = image_tensor.permute(2,0,1).unsqueeze(0)

    shift_input = final_shift_pixels.unsqueeze(0).to(device)
    disp_for_weights = disp_clipped.unsqueeze(0).to(device)

    right_tensor, occlusion_mask = stereo_warper(image_tensor, shift_input, disp_for_weights)

    # 5. Convert to numpy
    right_rgb = (right_tensor.squeeze(0).permute(1,2,0).cpu().numpy() * 255).astype(np.uint8)
    right_bgr = cv2.cvtColor(right_rgb, cv2.COLOR_RGB2BGR)
    # BUGFIX: occlusion_mask comes out of the warper as (1, 1, H, W);
    # a single squeeze(0) left a (1, H, W) array, which breaks cv2.dilate
    # and Image.fromarray downstream. Drop both singleton dims to get the
    # 2-D (H, W) mask those consumers expect.
    mask_np = occlusion_mask.squeeze(0).squeeze(0).cpu().numpy()

    # 6. Two-pass LaMa inpainting for clean edges.
    right_filled_bgr = run_lama_twice(right_bgr, mask_np)
    right_filled = Image.fromarray(cv2.cvtColor(right_filled_bgr, cv2.COLOR_BGR2RGB))

    # 7. Outputs
    mask_vis = Image.fromarray((mask_np * 255).astype(np.uint8))
    combined = Image.new('RGB', (w*2, h))
    combined.paste(image_pil, (0, 0))
    combined.paste(right_filled, (w, 0))
    anaglyph = make_anaglyph(image_pil, right_filled)

    return combined, anaglyph, depth_image, mask_vis
292
+
293
# ==============================================================================
# 5. GRADIO UI β€” Simplified (erosion slider removed)
# ==============================================================================
# Constrain the app width; injected via a raw <style> tag below.
css_style = """
.gradio-container {max-width: 1400px !important; margin: auto !important;}
"""

with gr.Blocks(title="2D β†’ 3D Stereo (Final Pro Version)") as demo:
    gr.HTML(f"<style>{css_style}</style>")
    gr.HTML("<h1 style='text-align:center;'>2D to 3D Stereo β€” Pro Quality</h1>")
    gr.Markdown("Depth Anything V2 + Forward Warp + Smart Inpainting")

    with gr.Row():
        # Left column: input image and the two tuning sliders.
        with gr.Column(scale=1):
            input_img = gr.Image(type="pil", label="Upload Image", height=500)
            with gr.Accordion("Settings", open=True):
                divergence = gr.Slider(0.5, 8.0, value=3.2, step=0.1,
                                       label="3D Strength (%)")
                convergence = gr.Slider(0.0, 1.0, value=0.08, step=0.01,
                                        label="Convergence Plane (0 = pop-out, 1 = deep-in)")
            btn = gr.Button("Generate 3D", variant="primary", size="lg")

        # Right column: main outputs plus small diagnostic views.
        with gr.Column(scale=1):
            out_anaglyph = gr.Image(label="Anaglyph (Red/Cyan Glasses)", height=500)
            out_sbs = gr.Image(label="Side-by-Side Pair", height=300)
            with gr.Row():
                out_depth = gr.Image(label="Depth Map", height=200)
                out_mask = gr.Image(label="Inpainting Mask", height=200)

    # Output order must match stereo_pipeline's return tuple:
    # (side-by-side, anaglyph, depth visualization, occlusion mask).
    btn.click(fn=stereo_pipeline,
              inputs=[input_img, divergence, convergence],
              outputs=[out_sbs, out_anaglyph, out_depth, out_mask])

    gr.Markdown("**Tip:** Red/Cyan glasses β†’ anaglyph β€’ Cross-eye or parallel β†’ side-by-side")

if __name__ == "__main__":
    demo.launch(share=True)