enoky committed
Commit cd5cadf · verified · 1 Parent(s): 04b39f7

Update app.py

Files changed (1):
  app.py +146 -218
app.py CHANGED
@@ -7,15 +7,11 @@ from PIL import Image
 from torch.autograd import Function
 from transformers import AutoModelForDepthEstimation, AutoImageProcessor
 from huggingface_hub import hf_hub_download
-import os
 
-# === DEVICE ===
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-print(f"Running on device: {device}")
 
-# ==============================================================================
-# 1. FORWARD WARP (unchanged)
-# ==============================================================================
 class ForwardWarpFunction(Function):
     @staticmethod
     def forward(ctx, im0, flow, interpolation_mode_int):
@@ -30,270 +26,202 @@ class ForwardWarpFunction(Function):
         x_dest = grid_x + flow[:, :, :, 0]
         y_dest = grid_y + flow[:, :, :, 1]
 
-        x_f = torch.floor(x_dest).long()
-        y_f = torch.floor(y_dest).long()
-        x_c = x_f + 1
-        y_c = y_f + 1
 
         nw_k = (x_c.float() - x_dest) * (y_c.float() - y_dest)
         ne_k = (x_dest - x_f.float()) * (y_c.float() - y_dest)
         sw_k = (x_c.float() - x_dest) * (y_dest - y_f.float())
         se_k = (x_dest - x_f.float()) * (y_dest - y_f.float())
 
-        x_f_clamped = torch.clamp(x_f, 0, W - 1)
-        y_f_clamped = torch.clamp(y_f, 0, H - 1)
-        x_c_clamped = torch.clamp(x_c, 0, W - 1)
-        y_c_clamped = torch.clamp(y_c, 0, H - 1)
 
         mask_nw = (x_f >= 0) & (x_f < W) & (y_f >= 0) & (y_f < H)
         mask_ne = (x_c >= 0) & (x_c < W) & (y_f >= 0) & (y_f < H)
         mask_sw = (x_f >= 0) & (x_f < W) & (y_c >= 0) & (y_c < H)
         mask_se = (x_c >= 0) & (x_c < W) & (y_c >= 0) & (y_c < H)
 
-        nw_k = nw_k.unsqueeze(1); ne_k = ne_k.unsqueeze(1)
-        sw_k = sw_k.unsqueeze(1); se_k = se_k.unsqueeze(1)
-        mask_nw = mask_nw.unsqueeze(1); mask_ne = mask_ne.unsqueeze(1)
-        mask_sw = mask_sw.unsqueeze(1); mask_se = mask_se.unsqueeze(1)
 
-        b_indices = torch.arange(B, device=im0.device).view(B, 1, 1, 1).expand(-1, C, H, W)
-        c_indices = torch.arange(C, device=im0.device).view(1, C, 1, 1).expand(B, -1, H, W)
-        base_idx = b_indices * (C * H * W) + c_indices * (H * W)
 
-        def scatter_corner(y_idx, x_idx, weights, mask):
-            flat_idx = base_idx + y_idx.unsqueeze(1) * W + x_idx.unsqueeze(1)
-            values = (im0 * weights) * mask.float()
-            im1.reshape(-1).scatter_add_(0, flat_idx.contiguous().reshape(-1), values.contiguous().reshape(-1))
 
-        scatter_corner(y_f_clamped, x_f_clamped, nw_k, mask_nw)
-        scatter_corner(y_f_clamped, x_c_clamped, ne_k, mask_ne)
-        scatter_corner(y_c_clamped, x_f_clamped, sw_k, mask_sw)
-        scatter_corner(y_c_clamped, x_c_clamped, se_k, mask_se)
 
         return im1
 
     @staticmethod
-    def backward(ctx, grad_output):
-        return None, None, None
 
 class forward_warp(nn.Module):
-    def __init__(self): super().__init__()
-    def forward(self, im0, flow):
-        return ForwardWarpFunction.apply(im0, flow, 0)
 
-# ==============================================================================
-# 2. STEREO WARPER – FIXED + SMART DILATION
-# ==============================================================================
 class ForwardWarpStereo(nn.Module):
-    def __init__(self, eps=1e-6):
-        super().__init__()
-        self.eps = eps
-        self.fw = forward_warp()
-
-    def forward(self, im, shift, disp_for_weights):
-        flow_x = -shift
-        flow_y = torch.zeros_like(flow_x)
-        flow = torch.stack((flow_x, flow_y), dim=-1)
-
-        # Fixed z-buffer weights (no detached limbs)
-        disp_norm = disp_for_weights / (disp_for_weights.max() + 1e-8)
-        weights_map = disp_norm + 0.05
-
-        res_accum = self.fw(im * weights_map.unsqueeze(1), flow)
-        mask_accum = self.fw(weights_map.unsqueeze(1), flow)
-        mask_accum.clamp_(min=self.eps)
-        res = res_accum / mask_accum
-
-        ones = torch.ones_like(im[:, 0:1, :, :])
-        occupancy = self.fw(ones, flow)
-        occlusion_mask = (occupancy < self.eps).float()
-
-        # Smart foreground-preserving dilation
-        with torch.no_grad():
-            fg_thresh = torch.quantile(disp_for_weights, 0.88)
-            fg_mask = (disp_for_weights > fg_thresh).float().unsqueeze(0)
-            k = 15
-            dilated = torch.nn.functional.conv2d(
-                occlusion_mask, torch.ones(1, 1, k, k, device=occlusion_mask.device),
-                padding=k // 2) > 0.1
-            safe_dilation = dilated.float() * (1 - fg_mask)
-            occlusion_mask = torch.clamp(occlusion_mask + safe_dilation, 0, 1)
-
-        return res, occlusion_mask
-
-# ==============================================================================
-# 3. MODELS & HELPERS
-# ==============================================================================
-def load_models():
-    print("Loading Depth Anything V2 Large...")
-    depth_model = AutoModelForDepthEstimation.from_pretrained(
-        "depth-anything/Depth-Anything-V2-Large-hf"
-    ).to(device)
-    depth_processor = AutoImageProcessor.from_pretrained(
-        "depth-anything/Depth-Anything-V2-Large-hf"
-    )
-
-    print("Loading LaMa Inpainting Model...")
-    try:
-        model_path = hf_hub_download(repo_id="fashn-ai/LaMa", filename="big-lama.pt")
-        lama_model = torch.jit.load(model_path, map_location=device)
-        lama_model.eval()
-    except Exception as e:
-        print(f"LaMa load failed: {e}")
-        lama_model = None
-
-    stereo_warper = ForwardWarpStereo().to(device)
-    return depth_model, depth_processor, lama_model, stereo_warper
-
-depth_model, depth_processor, lama_model, stereo_warper = load_models()
 @torch.no_grad()
-def estimate_depth(image_pil, model, processor):
-    original_size = image_pil.size
-    inputs = processor(images=image_pil, return_tensors="pt").to(device)
-    depth = model(**inputs).predicted_depth
-    depth = torch.nn.functional.interpolate(
-        depth.unsqueeze(1),
-        size=(original_size[1], original_size[0]),
-        mode="bicubic",
-        align_corners=False,
-    ).squeeze()
-
-    d_min, d_max = depth.min(), depth.max()
-    depth = (depth - d_min) / (d_max - d_min + 1e-8) if d_max > d_min else torch.zeros_like(depth)
-    return depth
 
 @torch.no_grad()
-def run_local_lama(image_bgr, mask_float):
-    if lama_model is None: return image_bgr
-    kernel = np.ones((5, 5), np.uint8)
-    mask_uint8 = (mask_float * 255).astype(np.uint8)
-    mask_dilated = cv2.dilate(mask_uint8, kernel, iterations=2)
-
-    h, w = image_bgr.shape[:2]
-    new_h, new_w = (h // 8) * 8, (w // 8) * 8
-    img_resized = cv2.resize(image_bgr, (new_w, new_h))
-    mask_resized = cv2.resize(mask_dilated, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
-
-    img_t = torch.from_numpy(img_resized).float().permute(2, 0, 1).unsqueeze(0) / 255.0
-    img_t = img_t[:, [2, 1, 0], :, :].to(device)
-    mask_t = torch.from_numpy(mask_resized).float().unsqueeze(0).unsqueeze(0) / 255.0
-    mask_t = (mask_t > 0.5).float().to(device)
-
-    img_t = img_t * (1 - mask_t)
-    inpainted_t = lama_model(img_t, mask_t)
-
-    inpainted = inpainted_t[0].permute(1, 2, 0).cpu().numpy()
-    inpainted = np.clip(inpainted * 255, 0, 255).astype(np.uint8)
-    inpainted = cv2.cvtColor(inpainted, cv2.COLOR_RGB2BGR)
-    if (new_h, new_w) != (h, w):
-        inpainted = cv2.resize(inpainted, (w, h))
-    return inpainted
 
-@torch.no_grad()
-def run_lama_twice(image_bgr, mask_float):
-    if lama_model is None: return image_bgr
-    img1 = run_local_lama(image_bgr, mask_float)
-    kernel = np.ones((9, 9), np.uint8)
-    mask_dilated = cv2.dilate(mask_float, kernel, iterations=2)
-    return run_local_lama(img1, mask_dilated)
-
-def make_anaglyph(left, right):
-    l = np.array(left); r = np.array(right)
     a = np.zeros_like(l)
-    a[:, :, 0] = l[:, :, 0]  # Red ← left eye
-    a[:, :, 1] = r[:, :, 1]  # Green ← right eye
-    a[:, :, 2] = r[:, :, 2]  # Blue ← right eye
     return Image.fromarray(a)
 
-# ==============================================================================
-# 4. MAIN PIPELINE
-# ==============================================================================
 @torch.no_grad()
-def stereo_pipeline(image_pil, divergence_percent, convergence_plane):
-    if image_pil is None:
-        return None, None, None, None
 
-    w, h = image_pil.size
     if w > 1920:
-        ratio = 1920 / w
-        image_pil = image_pil.resize((int(w * ratio), int(h * ratio)), Image.LANCZOS)
-        w, h = image_pil.size
-
-    depth_tensor = estimate_depth(image_pil, depth_model, depth_processor)
-    depth_vis = Image.fromarray((depth_tensor.cpu().numpy() * 255).astype(np.uint8))
 
-    disp_raw = depth_tensor ** 2
-    disp_max = torch.quantile(disp_raw, 0.995)
-    disp_clipped = torch.clamp(disp_raw, max=disp_max)
 
-    max_shift_px = w * (divergence_percent / 100.0)
-    shift_pixels_raw = disp_clipped * max_shift_px
-    shift_min, shift_max = shift_pixels_raw.min(), shift_pixels_raw.max()
-    convergence_offset = shift_min + convergence_plane * (shift_max - shift_min)
-    final_shift_pixels = shift_pixels_raw - convergence_offset
 
-    image_tensor = torch.from_numpy(np.array(image_pil)).float().to(device) / 255.0
-    image_tensor = image_tensor.permute(2, 0, 1).unsqueeze(0)
 
-    right_tensor, occlusion_mask = stereo_warper(
-        image_tensor,
-        final_shift_pixels.unsqueeze(0).to(device),
-        disp_clipped.unsqueeze(0).to(device)
-    )
 
-    right_rgb = (right_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
-    right_bgr = cv2.cvtColor(right_rgb, cv2.COLOR_RGB2BGR)
-    mask_np = occlusion_mask.squeeze(0).cpu().numpy()
 
-    right_filled_bgr = run_lama_twice(right_bgr, mask_np)
-    right_filled = Image.fromarray(cv2.cvtColor(right_filled_bgr, cv2.COLOR_BGR2RGB))
 
-    mask_vis = Image.fromarray((mask_np * 255).astype(np.uint8))
 
-    combined = Image.new('RGB', (w * 2, h))
-    combined.paste(image_pil, (0, 0))
-    combined.paste(right_filled, (w, 0))
 
-    anaglyph = make_anaglyph(image_pil, right_filled)
 
-    return combined, anaglyph, depth_vis, mask_vis
 
-# ==============================================================================
-# 5. GRADIO UI – COMPATIBLE WITH CURRENT GRADIO
-# ==============================================================================
-css = """
-.gradio-container {max-width: 1450px !important; margin: auto !important;}
-"""
-
-with gr.Blocks() as demo:  # ← removed css= argument
-    gr.HTML(f"<style>{css}</style>")  # ← inject CSS here instead
-    gr.HTML("<h1 style='text-align:center;'>2D → 3D Stereo – Pro Quality</h1>")
-    gr.Markdown("Depth Anything V2 + Forward Warp + Smart LaMa Inpainting")
 
     with gr.Row():
-        with gr.Column(scale=1):
-            input_img = gr.Image(type="pil", label="Upload Image", height=520)
-
             with gr.Accordion("Settings", open=True):
-                divergence = gr.Slider(0.5, 8.0, value=3.2, step=0.1,
-                                       label="3D Strength (%)")
-                convergence = gr.Slider(0.0, 1.0, value=0.08, step=0.01,
-                                        label="Convergence Plane (0 = pop-out, 1 = deep-in)")
-
-            btn = gr.Button("Generate 3D", variant="primary", size="lg")
 
-        with gr.Column(scale=1):
-            out_anaglyph = gr.Image(label="Anaglyph (Red/Cyan Glasses)", height=520)
-            out_sbs = gr.Image(label="Side-by-Side Pair", height=320)
             with gr.Row():
-                out_depth = gr.Image(label="Depth Map", height=200)
-                out_mask = gr.Image(label="Inpainting Mask", height=200)
 
-    btn.click(fn=stereo_pipeline,
-              inputs=[input_img, divergence, convergence],
-              outputs=[out_sbs, out_anaglyph, out_depth, out_mask])
 
-    gr.Markdown("**Tip:** Red/Cyan glasses → anaglyph • Cross-eye or parallel → side-by-side")
 
-if __name__ == "__main__":
-    demo.launch(share=True)

 from torch.autograd import Function
 from transformers import AutoModelForDepthEstimation, AutoImageProcessor
 from huggingface_hub import hf_hub_download
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print(f"Running on {device}")
 
+# ==================== 1. FORWARD WARP (unchanged) ====================
 class ForwardWarpFunction(Function):
     @staticmethod
     def forward(ctx, im0, flow, interpolation_mode_int):
 
         x_dest = grid_x + flow[:, :, :, 0]
         y_dest = grid_y + flow[:, :, :, 1]
 
+        x_f = torch.floor(x_dest).long(); x_c = x_f + 1
+        y_f = torch.floor(y_dest).long(); y_c = y_f + 1
 
         nw_k = (x_c.float() - x_dest) * (y_c.float() - y_dest)
         ne_k = (x_dest - x_f.float()) * (y_c.float() - y_dest)
         sw_k = (x_c.float() - x_dest) * (y_dest - y_f.float())
         se_k = (x_dest - x_f.float()) * (y_dest - y_f.float())
 
+        x_f_clamped = torch.clamp(x_f, 0, W - 1); x_c_clamped = torch.clamp(x_c, 0, W - 1)
+        y_f_clamped = torch.clamp(y_f, 0, H - 1); y_c_clamped = torch.clamp(y_c, 0, H - 1)
 
         mask_nw = (x_f >= 0) & (x_f < W) & (y_f >= 0) & (y_f < H)
         mask_ne = (x_c >= 0) & (x_c < W) & (y_f >= 0) & (y_f < H)
         mask_sw = (x_f >= 0) & (x_f < W) & (y_c >= 0) & (y_c < H)
         mask_se = (x_c >= 0) & (x_c < W) & (y_c >= 0) & (y_c < H)
 
+        for t, m in [(nw_k, mask_nw), (ne_k, mask_ne), (sw_k, mask_sw), (se_k, mask_se)]:
+            t.unsqueeze_(1); m.unsqueeze_(1)
 
+        b_idx = torch.arange(B, device=im0.device).view(B, 1, 1, 1).expand(-1, C, H, W)
+        c_idx = torch.arange(C, device=im0.device).view(1, C, 1, 1).expand(B, -1, H, W)
+        base = b_idx * (C * H * W) + c_idx * (H * W)
+
+        def scatter(y_idx, x_idx, weights, mask):
+            flat = base + y_idx.unsqueeze(1) * W + x_idx.unsqueeze(1)
+            val = (im0 * weights) * mask.float()
+            im1.reshape(-1).scatter_add_(0, flat.reshape(-1), val.reshape(-1))
+
+        scatter(y_f_clamped, x_f_clamped, nw_k, mask_nw)
+        scatter(y_f_clamped, x_c_clamped, ne_k, mask_ne)
+        scatter(y_c_clamped, x_f_clamped, sw_k, mask_sw)
+        scatter(y_c_clamped, x_c_clamped, se_k, mask_se)
 
         return im1
 
     @staticmethod
+    def backward(ctx, grad_output): return None, None, None
 
 class forward_warp(nn.Module):
+    def forward(self, im0, flow): return ForwardWarpFunction.apply(im0, flow, 0)
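The four `*_k` weights above are plain bilinear splat coefficients: each source pixel distributes its value to the four integer neighbours of its fractional destination. A minimal standalone sketch with toy values (not part of the commit):

```python
import torch

# A source pixel lands at destination (x, y) = (2.3, 4.6).
x, y = torch.tensor(2.3), torch.tensor(4.6)
x_f, y_f = torch.floor(x).long(), torch.floor(y).long()   # NW corner (2, 4)
x_c, y_c = x_f + 1, y_f + 1                               # SE corner (3, 5)

nw = (x_c - x) * (y_c - y)   # 0.7 * 0.4 = 0.28
ne = (x - x_f) * (y_c - y)   # 0.3 * 0.4 = 0.12
sw = (x_c - x) * (y - y_f)   # 0.7 * 0.6 = 0.42
se = (x - x_f) * (y - y_f)   # 0.3 * 0.6 = 0.18
assert torch.isclose(nw + ne + sw + se, torch.tensor(1.0))  # weights sum to 1
```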
 
 
 
+# ==================== 2. STEREO WARPER (fixed + safe dilation) ====================
 class ForwardWarpStereo(nn.Module):
+    def __init__(self): super().__init__()
+    def forward(self, im, shift, disp):
+        flow = torch.stack((-shift, torch.zeros_like(shift)), dim=-1)
+
+        # Fixed linear weights – no more detached arms
+        weights = disp / (disp.max() + 1e-8) + 0.05
+
+        warped = forward_warp()(im * weights.unsqueeze(1), flow)
+        wmap = forward_warp()(weights.unsqueeze(1), flow)
+        wmap.clamp_(min=1e-6)
+        res = warped / wmap
+
+        occ = forward_warp()(torch.ones_like(im[:, :1]), flow) < 1e-6
+
+        # Smart dilation that never eats foreground
+        with torch.no_grad():
+            fg = (disp > disp.quantile(0.88)).float().unsqueeze(0)
+            dilated = torch.nn.functional.conv2d(occ.float(),
+                torch.ones(1, 1, 15, 15, device=device), padding=7) > 0.1
+            occ = torch.clamp(occ.float() + dilated * (1 - fg), 0, 1)
+
+        return res, occ
+
+stereo_warper = ForwardWarpStereo().to(device)
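The warper implements a soft z-buffer: colors are accumulated pre-multiplied by a disparity-derived weight, then normalized by the accumulated weight, so when a near and a far pixel splat onto the same target, the near one dominates. A toy one-channel illustration (numbers are hypothetical, not from the commit):

```python
# weights = disp / disp.max() + 0.05, so near ≈ 1.05 and far ≈ 0.05.
near_w, far_w = 1.05, 0.05
near_rgb, far_rgb = 0.9, 0.1      # stand-in colors for the colliding pixels
blended = (near_rgb * near_w + far_rgb * far_w) / (near_w + far_w)
print(round(blended, 3))          # 0.864 -> dominated by the near pixel
```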
+
+# ==================== 3. MODELS ====================
+print("Loading Depth Anything V2 Large...")
+depth_model = AutoModelForDepthEstimation.from_pretrained(
+    "depth-anything/Depth-Anything-V2-Large-hf").to(device)
+processor = AutoImageProcessor.from_pretrained(
+    "depth-anything/Depth-Anything-V2-Large-hf")
+
+print("Loading LaMa...")
+try:
+    lama_path = hf_hub_download("fashn-ai/LaMa", "big-lama.pt")
+    lama_model = torch.jit.load(lama_path, map_location=device).eval()
+except Exception as e:
+    lama_model = None
+    print(f"LaMa not available – inpainting will be skipped ({e})")
+
+# ==================== 4. HELPERS ====================
 @torch.no_grad()
+def estimate_depth(img_pil):
+    inputs = processor(images=img_pil, return_tensors="pt").to(device)
+    d = depth_model(**inputs).predicted_depth
+    d = torch.nn.functional.interpolate(d.unsqueeze(1), size=img_pil.size[::-1],
+                                        mode="bicubic", align_corners=False).squeeze()
+    d = (d - d.min()) / (d.max() - d.min() + 1e-8)
+    return d
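The `[::-1]` above is there because PIL reports `size` as `(width, height)` while `interpolate` expects `(height, width)`. A quick check with a hypothetical 640x480 image (not app code):

```python
from PIL import Image

img = Image.new("RGB", (640, 480))    # img.size == (640, 480) == (w, h)
assert img.size[::-1] == (480, 640)   # (h, w), the order interpolate wants
```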
+
+def safe_dilate(mask_np, k=5, it=2):
+    if mask_np.sum() == 0: return mask_np
+    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
+    return cv2.dilate(mask_np, kernel, iterations=it)
 
 @torch.no_grad()
+def lama_inpaint(img_bgr, mask_np):
+    if lama_model is None or mask_np.sum() == 0:
+        return img_bgr
+
+    mask_dil = safe_dilate((mask_np * 255).astype(np.uint8), k=7, it=3) / 255.0
+
+    h, w = img_bgr.shape[:2]
+    nh, nw = (h // 8) * 8, (w // 8) * 8
+    img_res = cv2.resize(img_bgr, (nw, nh))
+    mask_res = cv2.resize(mask_dil, (nw, nh), interpolation=cv2.INTER_NEAREST)
+
+    img_t = torch.from_numpy(img_res).float().permute(2, 0, 1).unsqueeze(0) / 255.0
+    img_t = img_t[:, [2, 1, 0]].to(device)
+    mask_t = torch.from_numpy(mask_res > 0.5).float().unsqueeze(0).unsqueeze(0).to(device)
+
+    img_t = img_t * (1 - mask_t)
+    out = lama_model(img_t, mask_t)[0].permute(1, 2, 0).cpu().numpy()
+    out = np.clip(out * 255, 0, 255).astype(np.uint8)
+    out = cv2.cvtColor(out, cv2.COLOR_RGB2BGR)
+    if (nh, nw) != (h, w):
+        out = cv2.resize(out, (w, h))
+    return out
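The resize to `(nh, nw)` rounds both dimensions down to a multiple of 8, assuming the big-lama TorchScript export behaves like typical LaMa checkpoints that require /8-divisible inputs; the result is resized back afterwards. The rounding itself, on hypothetical dimensions:

```python
h, w = 1080, 1919
nh, nw = (h // 8) * 8, (w // 8) * 8
print(nh, nw)   # 1080 1912 -- only the non-divisible side shrinks
```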
+
+def make_anaglyph(l, r):
+    l = np.array(l); r = np.array(r)
     a = np.zeros_like(l)
+    a[..., 0] = l[..., 0]
+    a[..., 1] = r[..., 1]
+    a[..., 2] = r[..., 2]
     return Image.fromarray(a)
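The channel assignment takes red from the left view and green/blue (cyan) from the right view, so red/cyan glasses route each view to one eye. The same mix as a standalone sketch on random data (not app code):

```python
import numpy as np

left = np.random.randint(0, 255, (4, 4, 3), dtype=np.uint8)
right = np.random.randint(0, 255, (4, 4, 3), dtype=np.uint8)
# Red <- left eye; green and blue <- right eye.
ana = np.dstack([left[..., 0], right[..., 1], right[..., 2]])
assert ana.shape == (4, 4, 3)
```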
 
+# ==================== 5. MAIN PIPELINE ====================
 @torch.no_grad()
+def stereo_pipeline(img_pil, strength=3.2, convergence=0.08):
+    if img_pil is None: return None, None, None, None
+
+    w, h = img_pil.size
     if w > 1920:
+        ratio = 1920 / w
+        img_pil = img_pil.resize((int(w * ratio), int(h * ratio)), Image.LANCZOS)
+        w, h = img_pil.size
+
+    depth = estimate_depth(img_pil)
+    disp = torch.clamp(depth**2, max=torch.quantile(depth**2, 0.995))
+
+    max_shift = w * strength / 100.0
+    shift = disp * max_shift
+    shift = shift - shift.min() - convergence * (shift.max() - shift.min())
+
+    tensor = torch.from_numpy(np.array(img_pil)).float().to(device) / 255.0
+    tensor = tensor.permute(2, 0, 1).unsqueeze(0)
+
+    right, occ = stereo_warper(tensor, shift.unsqueeze(0), disp.unsqueeze(0))
+
+    right_np = (right.squeeze(0).permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
+    right_bgr = cv2.cvtColor(right_np, cv2.COLOR_RGB2BGR)
+    mask_np = occ.squeeze().cpu().numpy()
+
+    # Two-pass LaMa (safe + perfect edges)
+    right_filled = lama_inpaint(right_bgr, mask_np)
+    right_filled = lama_inpaint(right_filled, mask_np)  # second pass
+
+    right_pil = Image.fromarray(cv2.cvtColor(right_filled, cv2.COLOR_BGR2RGB))
+
+    sbs = Image.new("RGB", (w * 2, h))
+    sbs.paste(img_pil, (0, 0))
+    sbs.paste(right_pil, (w, 0))
+
+    ana = make_anaglyph(img_pil, right_pil)
+    depth_vis = Image.fromarray((depth.cpu().numpy() * 255).astype(np.uint8))
+    mask_vis = Image.fromarray((mask_np * 255).astype(np.uint8))
+
+    return sbs, ana, depth_vis, mask_vis
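The convergence line remaps the shift range so a fraction of the nearest depths lands in front of the screen plane (negative shift) and the rest behind it. A worked example with a toy shift tensor (values hypothetical):

```python
import torch

shift = torch.tensor([0.0, 5.0, 20.0])   # raw per-pixel shifts in pixels
conv = 0.08
out = shift - shift.min() - conv * (shift.max() - shift.min())
print(out)   # tensor([-1.6000,  3.4000, 18.4000]) -> nearest 8% pops out
```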
 
+# ==================== 6. GRADIO UI ====================
+css = ".gradio-container {max-width: 1450px !important; margin: auto !important;}"
+with gr.Blocks() as demo:
+    gr.HTML(f"<style>{css}</style>")
+    gr.Markdown("# 2D → 3D Stereo – Pro Quality\nDepth Anything V2 + Forward Warp + Smart LaMa")
 
     with gr.Row():
+        with gr.Column():
+            inp = gr.Image(type="pil", label="Upload Image", height=520)
             with gr.Accordion("Settings", open=True):
+                strength = gr.Slider(0.5, 8, 3.2, step=0.1, label="3D Strength (%)")
+                conv = gr.Slider(0, 1, 0.08, step=0.01, label="Convergence (0=pop-out)")
+            btn = gr.Button("Generate 3D", variant="primary")
 
+        with gr.Column():
+            out_ana = gr.Image(label="Anaglyph (Red/Cyan)", height=520)
+            out_sbs = gr.Image(label="Side-by-Side", height=320)
             with gr.Row():
+                out_depth = gr.Image(label="Depth Map", height=200)
+                out_mask = gr.Image(label="Mask", height=200)
 
+    btn.click(stereo_pipeline, [inp, strength, conv],
+              [out_sbs, out_ana, out_depth, out_mask])
 
+    gr.Markdown("**Red/Cyan glasses** → anaglyph • **Cross-eye/parallel** → side-by-side")
 
+demo.launch(share=True)
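Since `stereo_pipeline` is a plain function over PIL images, it can also be driven without the UI. A headless usage sketch, assuming the script is importable as a module named `app` (the module name and input file are hypothetical, and `demo.launch` would need an `if __name__ == "__main__":` guard for this to work cleanly):

```python
from PIL import Image
import app  # hypothetical: app.py imported as a module

img = Image.open("photo.jpg").convert("RGB")   # hypothetical input file
sbs, ana, depth_vis, mask_vis = app.stereo_pipeline(img, strength=3.2, convergence=0.08)
ana.save("anaglyph.png")
sbs.save("side_by_side.png")
```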