enoky committed (verified)
Commit b4c58d3 · Parent(s): b89295c

Update app.py

Files changed (1):
  1. app.py +286 -214
app.py CHANGED
@@ -4,260 +4,332 @@ import torch.nn as nn
import numpy as np
import cv2
from PIL import Image
from transformers import AutoModelForDepthEstimation, AutoImageProcessor
from huggingface_hub import hf_hub_download
-
# === DEVICE ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on device: {device}")
-
# ==============================================================================
-# 1. SAFE & FAST FORWARD WARPER (grid_sample)
# ==============================================================================
-class SafeForwardWarp(nn.Module):
-    def forward(self, img, flow):
-        """
-        img: [B, C, H, W] float32 in [0,1]
-        flow: [B, H, W, 2] flow[...,0]=dx, flow[...,1]=dy
-        """
-        B, C, H, W = img.shape
-
-        grid_y, grid_x = torch.meshgrid(
-            torch.arange(H, device=img.device, dtype=torch.float32),
-            torch.arange(W, device=img.device, dtype=torch.float32),
-            indexing="ij",
-        )  # [H,W] each
-
-        grid_x = grid_x.unsqueeze(0).expand(B, -1, -1)  # [B,H,W]
-        grid_y = grid_y.unsqueeze(0).expand(B, -1, -1)
-
-        dest_x = grid_x + flow[..., 0]
-        dest_y = grid_y + flow[..., 1]
-
-        # Normalize to [-1, 1]
-        norm_x = dest_x / (W - 1) * 2.0 - 1.0
-        norm_y = dest_y / (H - 1) * 2.0 - 1.0
-
-        grid = torch.stack((norm_x, norm_y), dim=-1)  # [B,H,W,2]
-        grid = grid.clamp(-1.0, 1.0)
-
-        warped = torch.nn.functional.grid_sample(
-            img,
-            grid,
-            mode="bilinear",
-            padding_mode="zeros",
-            align_corners=True,
        )
-        return warped
-
# ==============================================================================
-# 2. STEREO WARPER
# ==============================================================================
class ForwardWarpStereo(nn.Module):
    def __init__(self, eps=1e-6):
-        super().__init__()
        self.eps = eps
-        self.warp = SafeForwardWarp()
-
-    def forward(self, img, shift, disp_for_weights):
        flow_x = -shift
        flow_y = torch.zeros_like(flow_x)
-        flow = torch.stack((flow_x, flow_y), dim=-1)  # [B,H,W,2]
-
-        # Weighting: nearer = stronger contribution
-        weights = 1.0 / (disp_for_weights + 0.1)
-        weights = weights / (weights.max() + 1e-8)
-
-        warped_img = self.warp(img * weights.unsqueeze(1), flow)
-        warped_w = self.warp(weights.unsqueeze(1), flow)
-        warped_w = torch.clamp(warped_w, min=self.eps)
-        result = warped_img / warped_w
-
-        # Occupancy occlusion mask
-        ones = torch.ones_like(img[:, :1])
-        occupancy = self.warp(ones, flow)
-        occlusion = (occupancy < self.eps).float()
-
-        # Smart dilation (preserve sharp foreground)
-        with torch.no_grad():
-            fg = (disp_for_weights > torch.quantile(disp_for_weights, 0.90)).float().unsqueeze(0)
-            k = 9
-            dilated = torch.nn.functional.conv2d(
-                occlusion,
-                torch.ones(1, 1, k, k, device=device),
-                padding=k // 2,
-            ) > 0.5
-            safe_dilate = dilated.float() * (1 - fg)
-            occlusion = torch.clamp(occlusion + safe_dilate, 0, 1)
-
-        return result, occlusion
-
# ==============================================================================
-# 3. MODELS
# ==============================================================================
def load_models():
    print("Loading Depth Anything V2 Large...")
    depth_model = AutoModelForDepthEstimation.from_pretrained(
        "depth-anything/Depth-Anything-V2-Large-hf"
-    ).to(device).eval()
    depth_processor = AutoImageProcessor.from_pretrained(
        "depth-anything/Depth-Anything-V2-Large-hf"
    )
-
-    print("Loading LaMa...")
    try:
-        path = hf_hub_download("fashn-ai/LaMa", "big-lama.pt")
-        lama_model = torch.jit.load(path, map_location=device).eval()
    except Exception as e:
-        print("LaMa failed, running without inpainting:", e)
-        lama_model = None
-
-    warper = ForwardWarpStereo().to(device)
-    return depth_model, depth_processor, lama_model, warper
-
depth_model, depth_processor, lama_model, stereo_warper = load_models()
-
-# ==============================================================================
-# 4. HELPERS
-# ==============================================================================
@torch.no_grad()
-def estimate_depth(pil_img):
-    w, h = pil_img.size
-    inputs = depth_processor(images=pil_img, return_tensors="pt").to(device)
-    pred = depth_model(**inputs).predicted_depth[0]  # [H,W]
-
-    pred = torch.nn.functional.interpolate(
-        pred.unsqueeze(0).unsqueeze(0),
-        size=(h, w),
        mode="bicubic",
        align_corners=False,
-    )[0, 0]
-
-    mi, ma = pred.min(), pred.max()
-    if ma > mi:
-        pred = (pred - mi) / (ma - mi)
-    return pred
-
@torch.no_grad()
-def run_lama(bgr_img, mask_float):
-    if lama_model is None:
-        return bgr_img
-    mask_u8 = (mask_float * 255).astype(np.uint8)
-    kernel = np.ones((7, 7), np.uint8)
-    mask_dil = cv2.dilate(mask_u8, kernel, iterations=2)
-
-    h, w = bgr_img.shape[:2]
-    nh, nw = (h // 8) * 8, (w // 8) * 8
-    img_res = cv2.resize(bgr_img, (nw, nh))
-    mask_res = cv2.resize(mask_dil, (nw, nh), interpolation=cv2.INTER_NEAREST)
-
-    t = torch.from_numpy(img_res).float().permute(2, 0, 1).unsqueeze(0) / 255.0
-    t = t[:, [2, 1, 0]].to(device)  # BGR → RGB
-    m = torch.from_numpy(mask_res).float().unsqueeze(0).unsqueeze(0) / 255.0
-    m = (m > 0.5).float().to(device)
-
-    t = t * (1 - m)
-    out = lama_model(t, m)
-    out = (out[0].permute(1, 2, 0).cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
-    out = cv2.cvtColor(out, cv2.COLOR_RGB2BGR)
-    if (nh, nw) != (h, w):
-        out = cv2.resize(out, (w, h))
-    return out
-
def make_anaglyph(left, right):
-    l = np.array(left)
-    r = np.array(right)
-    ana = np.zeros_like(l)
-    ana[..., 0] = l[..., 0]  # Red ← left eye
-    ana[..., 1] = r[..., 1]  # Green ← right eye
-    ana[..., 2] = r[..., 2]  # Blue ← right eye
-    return Image.fromarray(ana)
-
-# ==============================================================================
-# 5. MAIN PIPELINE
-# ==============================================================================
-@torch.no_grad()
-def stereo_pipeline(image_pil, divergence_percent=3.5, convergence_plane=0.08):
    if image_pil is None:
        return None, None, None, None
-
    w, h = image_pil.size
    if w > 1920:
        ratio = 1920 / w
-        image_pil = image_pil.resize((int(w * ratio), int(h * ratio)), Image.LANCZOS)
-        w, h = image_pil.size
-
-    # Depth
-    depth = estimate_depth(image_pil)  # [H,W] in [0,1]
-    depth_vis = Image.fromarray((depth.cpu().numpy() * 255).astype(np.uint8))
-
-    # Disparity
-    disp = torch.clamp(depth ** 2, max=torch.quantile(depth ** 2, 0.995))
-
-    # Shift
-    max_shift = w * (divergence_percent / 100.0)
-    shift_raw = disp * max_shift
-    shift_min, shift_max = shift_raw.min(), shift_raw.max()
-    offset = shift_min + convergence_plane * (shift_max - shift_min)
-    final_shift = shift_raw - offset
-
-    print(f"Final shift range: {final_shift.min():.1f} → {final_shift.max():.1f} px")
-
-    # Warp right eye
-    img_t = torch.from_numpy(np.array(image_pil)).float().to(device) / 255.0
-    img_t = img_t.permute(2, 0, 1).unsqueeze(0)  # [1,3,H,W]
-
-    shift_t = final_shift.unsqueeze(0).to(device)  # [1,H,W]
-    disp_t = disp.unsqueeze(0).to(device)
-
-    right_t, occ_mask = stereo_warper(img_t, shift_t, disp_t)
-
-    # To numpy
-    right_np = (right_t[0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
-    right_bgr = cv2.cvtColor(right_np, cv2.COLOR_RGB2BGR)
-    mask_np = occ_mask[0, 0].cpu().numpy()
-
-    # Inpaint
-    right_filled_bgr = run_lama(right_bgr, mask_np)
-    right_filled = Image.fromarray(cv2.cvtColor(right_filled_bgr, cv2.COLOR_BGR2RGB))
-
-    # Outputs
-    mask_vis = Image.fromarray((mask_np * 255).astype(np.uint8))
-
-    sbs = Image.new("RGB", (w * 2, h))
-    sbs.paste(image_pil, (0, 0))
-    sbs.paste(right_filled, (w, 0))
-
-    anaglyph = make_anaglyph(image_pil, right_filled)
-
-    return sbs, anaglyph, depth_vis, mask_vis
-
-# ==============================================================================
-# 6. GRADIO UI
-# ==============================================================================
-with gr.Blocks(title="2D → 3D Stereo — Stable & Fixed") as demo:
-    gr.HTML("<h1 style='text-align:center;'>2D to 3D Stereo — Rock-Solid Version</h1>")
-    gr.Markdown("Depth Anything V2 + Safe Warping + LaMa Inpainting")
-
    with gr.Row():
        with gr.Column(scale=1):
-            inp = gr.Image(type="pil", label="Upload Image", height=520)
-            with gr.Accordion("Settings", open=True):
-                div = gr.Slider(0.5, 8.0, value=3.5, step=0.1, label="3D Strength (%)")
-                conv = gr.Slider(0.0, 1.0, value=0.08, step=0.01, label="Convergence (0=pop-out, 1=deep)")
            btn = gr.Button("Generate 3D", variant="primary")
-
        with gr.Column(scale=1):
-            out_ana = gr.Image(label="Anaglyph (Red/Cyan)", height=520)
-            out_sbs = gr.Image(label="Side-by-Side", height=300)
            with gr.Row():
-                out_dep = gr.Image(label="Depth Map", height=200)
-                out_msk = gr.Image(label="Occlusion Mask", height=200)
-
-    btn.click(stereo_pipeline, inputs=[inp, div, conv],
-              outputs=[out_sbs, out_ana, out_dep, out_msk])
-
-    gr.Markdown("**Tip:** Red/Cyan glasses → anaglyph • Cross-eye / parallel → SBS")
-
if __name__ == "__main__":
-    demo.launch(share=True)

import numpy as np
import cv2
from PIL import Image
+from torch.autograd import Function
from transformers import AutoModelForDepthEstimation, AutoImageProcessor
from huggingface_hub import hf_hub_download
+import os
# === DEVICE ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on device: {device}")
# ==============================================================================
+# 1. FORWARD WARP IMPLEMENTATION (Native PyTorch)
# ==============================================================================
+class ForwardWarpFunction(Function):
+    @staticmethod
+    def forward(ctx, im0, flow, interpolation_mode_int):
+        # Input validation
+        assert (len(im0.shape) == len(flow.shape) == 4)
+        assert (interpolation_mode_int == 0 or interpolation_mode_int == 1)
+        assert (im0.shape[0] == flow.shape[0])
+        assert (im0.shape[-2:] == flow.shape[1:3])
+        assert (flow.shape[3] == 2)
+        B, C, H, W = im0.shape
+        # Create a contiguous output tensor to prevent view/reshape errors
+        im1 = torch.zeros(im0.shape, device=im0.device, dtype=im0.dtype).contiguous()
+        # Grid creation
+        grid_x, grid_y = torch.meshgrid(
+            torch.arange(W, device=im0.device, dtype=im0.dtype),
+            torch.arange(H, device=im0.device, dtype=im0.dtype),
+            indexing='xy'
        )
+        grid_x = grid_x.unsqueeze(0).expand(B, -1, -1)
+        grid_y = grid_y.unsqueeze(0).expand(B, -1, -1)
+        # Destination coordinates
+        x_dest = grid_x + flow[:, :, :, 0]
+        y_dest = grid_y + flow[:, :, :, 1]
+        if interpolation_mode_int == 0:  # Bilinear Splatting
+            x_f = torch.floor(x_dest).long()
+            y_f = torch.floor(y_dest).long()
+            x_c = x_f + 1
+            y_c = y_f + 1
+            # Weights
+            nw_k = (x_c.float() - x_dest) * (y_c.float() - y_dest)
+            ne_k = (x_dest - x_f.float()) * (y_c.float() - y_dest)
+            sw_k = (x_c.float() - x_dest) * (y_dest - y_f.float())
+            se_k = (x_dest - x_f.float()) * (y_dest - y_f.float())
+            # Clamp coords
+            x_f_clamped = torch.clamp(x_f, 0, W - 1)
+            y_f_clamped = torch.clamp(y_f, 0, H - 1)
+            x_c_clamped = torch.clamp(x_c, 0, W - 1)
+            y_c_clamped = torch.clamp(y_c, 0, H - 1)
+            # Per-corner validity masks
+            mask_nw = (x_f >= 0) & (x_f < W) & (y_f >= 0) & (y_f < H)
+            mask_ne = (x_c >= 0) & (x_c < W) & (y_f >= 0) & (y_f < H)
+            mask_sw = (x_f >= 0) & (x_f < W) & (y_c >= 0) & (y_c < H)
+            mask_se = (x_c >= 0) & (x_c < W) & (y_c >= 0) & (y_c < H)
+            # Reshape for broadcasting
+            nw_k = nw_k.unsqueeze(1)
+            ne_k = ne_k.unsqueeze(1)
+            sw_k = sw_k.unsqueeze(1)
+            se_k = se_k.unsqueeze(1)
+            mask_nw = mask_nw.unsqueeze(1)
+            mask_ne = mask_ne.unsqueeze(1)
+            mask_sw = mask_sw.unsqueeze(1)
+            mask_se = mask_se.unsqueeze(1)
+            # Flatten indices for scatter_add
+            b_indices = torch.arange(B, device=im0.device).view(B, 1, 1, 1).expand(-1, C, H, W)
+            c_indices = torch.arange(C, device=im0.device).view(1, C, 1, 1).expand(B, -1, H, W)
+            base_idx = b_indices * (C * H * W) + c_indices * (H * W)
+            # Scatter to 4 neighbors (Accumulate/Splat)
+            def scatter_corner(y_idx, x_idx, weights, mask):
+                flat_idx = base_idx + y_idx.unsqueeze(1) * W + x_idx.unsqueeze(1)
+                values = (im0 * weights) * mask.float()
+                # Since im1 is contiguous, we can safely use view() for in-place scatter
+                im1_flat = im1.view(-1)
+                idx_flat = flat_idx.contiguous().view(-1)
+                val_flat = values.contiguous().view(-1)
+                im1_flat.scatter_add_(0, idx_flat, val_flat)
+            scatter_corner(y_f_clamped, x_f_clamped, nw_k, mask_nw)  # NW
+            scatter_corner(y_f_clamped, x_c_clamped, ne_k, mask_ne)  # NE
+            scatter_corner(y_c_clamped, x_f_clamped, sw_k, mask_sw)  # SW
+            scatter_corner(y_c_clamped, x_c_clamped, se_k, mask_se)  # SE
+        else:  # Nearest Neighbor (Legacy fallback)
+            x_nearest = torch.round(x_dest).long()
+            y_nearest = torch.round(y_dest).long()
+            valid_mask = (x_nearest >= 0) & (x_nearest < W) & (y_nearest >= 0) & (y_nearest < H)
+            valid_mask = valid_mask.unsqueeze(1)
+            x_clamped = torch.clamp(x_nearest, 0, W - 1)
+            y_clamped = torch.clamp(y_nearest, 0, H - 1)
+            b_indices = torch.arange(B, device=im0.device).view(B, 1, 1, 1).expand(-1, C, H, W)
+            c_indices = torch.arange(C, device=im0.device).view(1, C, 1, 1).expand(B, -1, H, W)
+            dest_idx = b_indices * (C * H * W) + c_indices * (H * W) + y_clamped.unsqueeze(1) * W + x_clamped.unsqueeze(1)
+            source_values = im0 * valid_mask.float()
+            # Since im1 is contiguous, we can safely use view()
+            im1.view(-1).scatter_(0, dest_idx.contiguous().view(-1), source_values.contiguous().view(-1))
+        return im1
+    @staticmethod
+    def backward(ctx, grad_output):
+        return None, None, None
+class forward_warp(nn.Module):
+    def __init__(self, interpolation_mode="Bilinear"):
+        super(forward_warp, self).__init__()
+        self.interpolation_mode_int = 0 if interpolation_mode == "Bilinear" else 1
+    def forward(self, im0, flow):
+        return ForwardWarpFunction.apply(im0, flow, self.interpolation_mode_int)
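The bilinear branch splats each source pixel into its four destination neighbors, and the four corner weights always sum to 1, so splatting conserves each pixel's total contribution. A minimal standalone sketch of the same weight arithmetic (the coordinate values are purely illustrative):

```python
import torch

# Hypothetical destination coordinate produced by grid + flow
x_dest, y_dest = torch.tensor(2.3), torch.tensor(5.75)
x_f, y_f = torch.floor(x_dest), torch.floor(y_dest)  # floor corner (2, 5)
x_c, y_c = x_f + 1, y_f + 1                          # ceil corner  (3, 6)

# The same four corner weights ForwardWarpFunction computes
nw = (x_c - x_dest) * (y_c - y_dest)  # 0.7 * 0.25 = 0.175
ne = (x_dest - x_f) * (y_c - y_dest)  # 0.3 * 0.25 = 0.075
sw = (x_c - x_dest) * (y_dest - y_f)  # 0.7 * 0.75 = 0.525
se = (x_dest - x_f) * (y_dest - y_f)  # 0.3 * 0.75 = 0.225
print(nw + ne + sw + se)  # ~1.0, up to float rounding
```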
# ==============================================================================
+# 2. STEREO WARPER WRAPPER
# ==============================================================================
class ForwardWarpStereo(nn.Module):
+    """
+    Weighted-splatting wrapper.
+    Handles occlusions using depth-based weights (soft z-buffering).
+    """
    def __init__(self, eps=1e-6):
+        super(ForwardWarpStereo, self).__init__()
        self.eps = eps
+        self.fw = forward_warp(interpolation_mode="Bilinear")
+    def forward(self, im, disp, convergence, divergence):
+        # disp comes in as [B, 1, H, W] or [1, 1, H, W].
+        # Squeeze the channel dim to do math with coordinates: [B, H, W]
+        disp_squeeze = disp.squeeze(1)  # Shape [B, H, W]
+        # Create flow from disparity:
+        #   shift = (depth - convergence) * divergence
+        # Negated because standard flow is source -> dest; for the right-eye
+        # view the target is source - shift, so flow = -shift.
+        shift = (disp_squeeze - convergence) * divergence
        flow_x = -shift
        flow_y = torch.zeros_like(flow_x)
+        # Stack along last dim: [B, H, W] + [B, H, W] -> [B, H, W, 2]
+        flow = torch.stack((flow_x, flow_y), dim=-1)
+        # 1. Calculate weights (soft z-buffer).
+        # Closer objects (higher disparity) get higher weight, which lets the
+        # foreground dominate the background during accumulation; the +0.05
+        # floor keeps far background from vanishing entirely.
+        disp_norm = disp_squeeze / (disp_squeeze.max() + 1e-8)
+        weights_map = disp_norm + 0.05
+        weights_map = weights_map.unsqueeze(1)
+        # 2. Warp image * weights (accumulate weighted color).
+        # Input im is (B, C, H, W), weights is (B, 1, H, W)
+        res_accum = self.fw(im * weights_map, flow)
+        # 3. Warp weights (accumulate weights)
+        mask_accum = self.fw(weights_map, flow)
+        # 4. Normalize (color / total weight).
+        # Clamp by epsilon to avoid divide-by-zero in empty regions
+        mask_accum.clamp_(min=self.eps)
+        res = res_accum / mask_accum
+        # 5. Generate binary occlusion mask (for inpainting).
+        # Splat a grid of ones; where the sum is ~0, we have a hole.
+        ones = torch.ones_like(disp)
+        occupancy = self.fw(ones, flow)
+        # Valid pixels have occupancy > 0.
+        # We want holes = 1.0, filled = 0.0
+        occlusion_mask = (occupancy < self.eps).float()
+        return res, occlusion_mask
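When a foreground and a background pixel splat onto the same destination, the normalize step resolves the collision as a weighted average biased toward the nearer surface. A toy calculation with the weight scheme above (disp_norm + 0.05; numbers purely illustrative):

```python
# Foreground pixel: color 1.0, weight 1.0 + 0.05 (nearest depth)
# Background pixel: color 0.0, weight 0.0 + 0.05 (farthest depth)
res_accum = 1.0 * 1.05 + 0.0 * 0.05  # accumulated weighted color = 1.05
mask_accum = 1.05 + 0.05             # accumulated weight = 1.10
print(res_accum / mask_accum)        # ~0.95: the foreground mostly wins
```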
# ==============================================================================
+# 3. APP LOGIC & MODELS
# ==============================================================================
+# === LOAD MODELS ===
def load_models():
    print("Loading Depth Anything V2 Large...")
    depth_model = AutoModelForDepthEstimation.from_pretrained(
        "depth-anything/Depth-Anything-V2-Large-hf"
+    ).to(device)
    depth_processor = AutoImageProcessor.from_pretrained(
        "depth-anything/Depth-Anything-V2-Large-hf"
    )
+    print("Loading LaMa Inpainting Model...")
    try:
+        model_path = hf_hub_download(repo_id="fashn-ai/LaMa", filename="big-lama.pt")
+        lama_model = torch.jit.load(model_path, map_location=device)
+        lama_model.eval()
    except Exception as e:
+        print(f"Error loading LaMa model: {e}")
+        raise e
+    # Initialize the new stereo warper
+    stereo_warper = ForwardWarpStereo().to(device)
+    return depth_model, depth_processor, lama_model, stereo_warper
+# Load models once at startup
depth_model, depth_processor, lama_model, stereo_warper = load_models()
+# === DEPTH ESTIMATION ===
@torch.no_grad()
+def estimate_depth(image_pil, model, processor):
+    original_size = image_pil.size
+    inputs = processor(images=image_pil, return_tensors="pt").to(device)
+    depth = model(**inputs).predicted_depth
+    depth = torch.nn.functional.interpolate(
+        depth.unsqueeze(1),
+        size=(original_size[1], original_size[0]),
        mode="bicubic",
        align_corners=False,
+    ).squeeze()
+    depth_min, depth_max = depth.min(), depth.max()
+    if depth_max - depth_min > 0:
+        depth = (depth - depth_min) / (depth_max - depth_min)
+    else:
+        depth = torch.zeros_like(depth)
+    return depth
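The index swap in size=(original_size[1], original_size[0]) is needed because PIL reports (width, height) while torch's interpolate expects (height, width). A quick check of the convention:

```python
from PIL import Image
import numpy as np

img = Image.new("RGB", (640, 480))  # PIL size is (W, H)
print(img.size)                     # (640, 480)
print(np.array(img).shape)          # (480, 640, 3): arrays are (H, W, C)
```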
+# === DEPTH MANIPULATION ===
+def erode_depth(depth_tensor, kernel_size):
+    if kernel_size <= 0:
+        return depth_tensor
+    k = kernel_size if kernel_size % 2 == 1 else kernel_size + 1
+    x = depth_tensor.unsqueeze(0).unsqueeze(0)
+    padding = k // 2
+    # Erosion = sliding-window minimum = negated max-pool of the negated signal
+    x_eroded = -torch.nn.functional.max_pool2d(-x, kernel_size=k, stride=1, padding=padding)
+    return x_eroded.squeeze()
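The negated max-pool is a standard trick for grayscale erosion (a sliding-window minimum), which pulls bright foreground edges in the depth map inward and thus reduces halo artifacts. A small demonstration on a toy tensor:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[0.9, 0.9, 0.9],
                  [0.9, 0.1, 0.9],
                  [0.9, 0.9, 0.9]]).view(1, 1, 3, 3)

# min-pool(x) == -max-pool(-x): the dark value spreads to its neighbors
eroded = -F.max_pool2d(-x, kernel_size=3, stride=1, padding=1)
print(eroded[0, 0])  # every pixel whose 3x3 window touches the 0.1 becomes 0.1
```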
+# === LOCAL INPAINTING ===
@torch.no_grad()
+def run_local_lama(image_bgr, mask_float):
+    # 0. Dilate mask slightly to catch edge artifacts from splatting
+    kernel = np.ones((3, 3), np.uint8)
+    mask_uint8 = (mask_float * 255).astype(np.uint8)
+    mask_dilated = cv2.dilate(mask_uint8, kernel, iterations=3)
+    # 1. Resize to be divisible by 8
+    h, w = image_bgr.shape[:2]
+    new_h = (h // 8) * 8
+    new_w = (w // 8) * 8
+    img_resized = cv2.resize(image_bgr, (new_w, new_h))
+    mask_resized = cv2.resize(mask_dilated, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
+    # 2. Convert to torch
+    img_t = torch.from_numpy(img_resized).float().permute(2, 0, 1).unsqueeze(0) / 255.0
+    img_t = img_t[:, [2, 1, 0], :, :]  # BGR to RGB
+    mask_t = torch.from_numpy(mask_resized).float().unsqueeze(0).unsqueeze(0) / 255.0
+    mask_t = (mask_t > 0.5).float()
+    img_t = img_t.to(device)
+    mask_t = mask_t.to(device)
+    # 3. Inference
+    img_t = img_t * (1 - mask_t)
+    inpainted_t = lama_model(img_t, mask_t)
+    # 4. Post-process
+    inpainted = inpainted_t[0].permute(1, 2, 0).cpu().numpy()
+    inpainted = np.clip(inpainted * 255, 0, 255).astype(np.uint8)
+    inpainted = cv2.cvtColor(inpainted, cv2.COLOR_RGB2BGR)
+    if new_h != h or new_w != w:
+        inpainted = cv2.resize(inpainted, (w, h))
+    return inpainted
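The resize step exists because the traced big-lama.pt expects spatial dimensions divisible by 8 (plausibly because the network downsamples the input several times; this is an assumption, not documented in the diff). The floor arithmetic trims at most 7 px per axis, which the final resize restores:

```python
h, w = 1013, 1920
new_h, new_w = (h // 8) * 8, (w // 8) * 8
print(new_h, new_w)  # 1008 1920
```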
def make_anaglyph(left, right):
+    l_arr = np.array(left)
+    r_arr = np.array(right)
+    anaglyph = np.zeros_like(l_arr)
+    anaglyph[:, :, 0] = l_arr[:, :, 0]  # Red ← left eye
+    anaglyph[:, :, 1] = r_arr[:, :, 1]  # Green ← right eye
+    anaglyph[:, :, 2] = r_arr[:, :, 2]  # Blue ← right eye
+    return Image.fromarray(anaglyph)
+# === PIPELINE ===
+def stereo_pipeline(image_pil, divergence, convergence, edge_erosion):
    if image_pil is None:
        return None, None, None, None
+    # Resize input if too large
    w, h = image_pil.size
    if w > 1920:
        ratio = 1920 / w
+        new_h = int(h * ratio)
+        image_pil = image_pil.resize((1920, new_h), Image.LANCZOS)
+    # 1. Depth estimation
+    depth_tensor = estimate_depth(image_pil, depth_model, depth_processor)
+    # 2. Depth erosion (optional halo reduction)
+    if edge_erosion > 0:
+        depth_tensor = erode_depth(depth_tensor, int(edge_erosion))
+    # Visualize depth
+    depth_vis = (depth_tensor.cpu().numpy() * 255).astype(np.uint8)
+    depth_image = Image.fromarray(depth_vis)
+    # 3. Forward warp (weighted bilinear splatting)
+    # Convert image to tensor (B, C, H, W)
+    image_tensor = torch.from_numpy(np.array(image_pil)).float().to(device).permute(2, 0, 1).unsqueeze(0) / 255.0
+    # Prepare depth tensor (B, 1, H, W)
+    depth_input = depth_tensor.unsqueeze(0).unsqueeze(0)
+    # Run the new stereo warper
+    with torch.no_grad():
+        right_img_tensor, mask_tensor = stereo_warper(
+            image_tensor,
+            depth_input,
+            float(convergence),
+            float(divergence)
+        )
+    # Convert results back to CPU/NumPy
+    right_img_rgb = (right_img_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
+    mask_vis = (mask_tensor.squeeze(0).squeeze(0).cpu().numpy() * 255).astype(np.uint8)
+    mask_image = Image.fromarray(mask_vis)
+    # 4. Inpainting
+    right_img_bgr = cv2.cvtColor(right_img_rgb, cv2.COLOR_RGB2BGR)
+    mask_float = mask_tensor.squeeze().cpu().numpy()
+    right_filled_bgr = run_local_lama(right_img_bgr, mask_float)
+    # 5. Finalize
+    left = image_pil
+    right = Image.fromarray(cv2.cvtColor(right_filled_bgr, cv2.COLOR_BGR2RGB))
+    width, height = left.size
+    combined_image = Image.new('RGB', (width * 2, height))
+    combined_image.paste(left, (0, 0))
+    combined_image.paste(right, (width, 0))
+    anaglyph_image = make_anaglyph(left, right)
+    return combined_image, anaglyph_image, depth_image, mask_image
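With the default slider values defined below (divergence 30 px, convergence 0.5), the shift formula inside ForwardWarpStereo maps normalized depth to parallax as follows; a hypothetical trace of the same arithmetic:

```python
divergence, convergence = 30.0, 0.5  # UI defaults below

for depth in (0.0, 0.5, 1.0):        # normalized output of estimate_depth
    shift = (depth - convergence) * divergence
    print(f"depth={depth:.1f}: shift={shift:+.1f} px")
# depth=0.0: shift=-15.0 px (background; negated into rightward flow)
# depth=0.5: shift=+0.0 px  (focus plane sits on the screen)
# depth=1.0: shift=+15.0 px (foreground; negated into leftward flow, pops out)
```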
+# === GRADIO UI ===
+with gr.Blocks(title="2D to 3D Stereo") as demo:
+    gr.Markdown("## 2D to 3D Stereo Generator (High-Quality Splatting)")
+    gr.Markdown("Uses **Depth Anything V2**, **Bilinear Weighted Splatting** (Soft Z-Buffer), and **LaMa Inpainting**.")
    with gr.Row():
        with gr.Column(scale=1):
+            input_img = gr.Image(type="pil", label="Input Image", height=320)
+            with gr.Group():
+                gr.Markdown("### 3D Controls")
+                divergence_slider = gr.Slider(
+                    minimum=0, maximum=100, value=30, step=1,
+                    label="3D Strength (Divergence)",
+                    info="Max separation in pixels."
+                )
+                convergence_slider = gr.Slider(
+                    minimum=0.0, maximum=1.0, value=0.5, step=0.05,
+                    label="Focus Plane (Convergence)",
+                    info="0.0 = Background at screen. 1.0 = Foreground at screen."
+                )
+                erosion_slider = gr.Slider(
+                    minimum=0, maximum=20, value=2, step=1,
+                    label="Edge Masking (Erosion)",
+                    info="Cleanup edges. Set to 0 for raw splatting."
+                )
            btn = gr.Button("Generate 3D", variant="primary")
        with gr.Column(scale=1):
+            out_anaglyph = gr.Image(label="Anaglyph (Red/Cyan)", height=320)
+            out_stereo = gr.Image(label="Side-by-Side Stereo Pair", height=320)
            with gr.Row():
+                out_depth = gr.Image(label="Depth Map", height=200)
+                out_mask = gr.Image(label="Inpainting Mask (Holes)", height=200)
+    btn.click(
+        fn=stereo_pipeline,
+        inputs=[input_img, divergence_slider, convergence_slider, erosion_slider],
+        outputs=[out_stereo, out_anaglyph, out_depth, out_mask]
+    )
if __name__ == "__main__":
+    demo.launch()