enoky committed
Commit b89295c · verified · 1 Parent(s): 549ff77

Update app.py

Files changed (1)
app.py +120 -134
app.py CHANGED
@@ -6,53 +6,52 @@ import cv2
  from PIL import Image
  from transformers import AutoModelForDepthEstimation, AutoImageProcessor
  from huggingface_hub import hf_hub_download
- import os

  # === DEVICE ===
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  print(f"Running on device: {device}")

  # ==============================================================================
- # 1. SAFE & FAST FORWARD WARPER USING grid_sample (NO MORE BLACK IMAGES!)
  # ==============================================================================
  class SafeForwardWarp(nn.Module):
      def forward(self, img, flow):
          """
-         img: [B, C, H, W] in [0,1]
-         flow: [B, H, W, 2] flow[...,0] = delta_x (positive = right), flow[...,1] = delta_y
          """
          B, C, H, W = img.shape

-         # Create sampling grid in normalized coordinates [-1, 1]
-         grid_x, grid_y = torch.meshgrid(
-             torch.arange(W, device=img.device),
-             torch.arange(H, device=img.device),
-             indexing='ij'
-         )
-         grid_x = grid_x.float().unsqueeze(0).expand(B, -1, -1)  # [B, H, W]
-         grid_y = grid_y.float().unsqueeze(0).expand(B, -1, -1)

-         dest_x = grid_x + flow[..., 0]  # source pixel moves to x + dx
          dest_y = grid_y + flow[..., 1]

          # Normalize to [-1, 1]
-         norm_x = 2.0 * dest_x / (W - 1) - 1.0
-         norm_y = 2.0 * dest_y / (H - 1) - 1.0

-         grid = torch.stack((norm_x, norm_y), dim=-1)  # [B, H, W, 2]
-         grid = grid.clamp(-1, 1)

          warped = torch.nn.functional.grid_sample(
              img,
              grid,
-             mode='bilinear',
-             padding_mode='zeros',
-             align_corners=True
          )
          return warped

  # ==============================================================================
- # 2. STEREO WARPER — Improved weighting + safer dilation
  # ==============================================================================
  class ForwardWarpStereo(nn.Module):
      def __init__(self, eps=1e-6):
@@ -61,46 +60,40 @@ class ForwardWarpStereo(nn.Module):
          self.warp = SafeForwardWarp()

      def forward(self, img, shift, disp_for_weights):
-         # shift: [B, H, W] (positive = shift right-eye left → object pops out)
-         flow_x = -shift  # negative = move pixels left for right eye
          flow_y = torch.zeros_like(flow_x)
-         flow = torch.stack((flow_x, flow_y), dim=-1)  # [B, H, W, 2]

-         # Better weighting: closer pixels contribute more
          weights = 1.0 / (disp_for_weights + 0.1)
          weights = weights / (weights.max() + 1e-8)

-         weighted_img = img * weights.unsqueeze(1)
-         warped_img = self.warp(weighted_img, flow)
-         warped_weights = self.warp(weights.unsqueeze(1), flow)
-
-         # Avoid division by zero
-         warped_weights = torch.clamp(warped_weights, min=self.eps)
-         result = warped_img / warped_weights

-         # Occlusion mask via occupancy count
          ones = torch.ones_like(img[:, :1])
          occupancy = self.warp(ones, flow)
          occlusion = (occupancy < self.eps).float()

-         # Smart dilation — preserve foreground edges
          with torch.no_grad():
-             fg_thresh = torch.quantile(disp_for_weights, 0.90)
-             fg_mask = (disp_for_weights > fg_thresh).float().unsqueeze(0)
-
              k = 9
              dilated = torch.nn.functional.conv2d(
                  occlusion,
-                 torch.ones(1, 1, k, k, device=occlusion.device),
-                 padding=k // 2
              ) > 0.5
-             safe_dilation = dilated.float() * (1 - fg_mask)
-             occlusion = torch.clamp(occlusion + safe_dilation, 0, 1)

          return result, occlusion

  # ==============================================================================
- # 3. MODELS & HELPERS
  # ==============================================================================
  def load_models():
      print("Loading Depth Anything V2 Large...")
@@ -111,82 +104,80 @@ def load_models():
          "depth-anything/Depth-Anything-V2-Large-hf"
      )

-     print("Loading LaMa Inpainting Model...")
      try:
-         model_path = hf_hub_download(repo_id="fashn-ai/LaMa", filename="big-lama.pt")
-         lama_model = torch.jit.load(model_path, map_location=device).eval()
      except Exception as e:
-         print(f"LaMa load failed: {e}")
          lama_model = None

-     stereo_warper = ForwardWarpStereo().to(device)
-
-     return depth_model, depth_processor, lama_model, stereo_warper

  depth_model, depth_processor, lama_model, stereo_warper = load_models()

  @torch.no_grad()
- def estimate_depth(image_pil):
-     original_size = image_pil.size
-     inputs = depth_processor(images=image_pil, return_tensors="pt").to(device)
-     outputs = depth_model(**inputs)
-     depth = outputs.predicted_depth
-
-     depth = torch.nn.functional.interpolate(
-         depth.unsqueeze(1),
-         size=(original_size[1], original_size[0]),
          mode="bicubic",
          align_corners=False,
-     ).squeeze(0).squeeze(0)

-     # Normalize to [0,1]
-     d_min, d_max = depth.min(), depth.max()
-     if d_max > d_min:
-         depth = (depth - d_min) / (d_max - d_min)
-     return depth

  @torch.no_grad()
- def run_lama(image_bgr, mask_float):
      if lama_model is None:
-         return image_bgr
-
-     mask_uint8 = (mask_float * 255).astype(np.uint8)
      kernel = np.ones((7, 7), np.uint8)
-     mask_dilated = cv2.dilate(mask_uint8, kernel, iterations=2)
-
-     h, w = image_bgr.shape[:2]
-     new_h = (h // 8) * 8
-     new_w = (w // 8) * 8
-     img_resized = cv2.resize(image_bgr, (new_w, new_h))
-     mask_resized = cv2.resize(mask_dilated, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
-
-     img_t = torch.from_numpy(img_resized).float().permute(2, 0, 1).unsqueeze(0) / 255.0
-     img_t = img_t[:, [2, 1, 0]].to(device)  # BGR → RGB
-     mask_t = torch.from_numpy(mask_resized).float().unsqueeze(0).unsqueeze(0) / 255.0
-     mask_t = (mask_t > 0.5).float().to(device)
-
-     img_t = img_t * (1 - mask_t)
-     inpainted = lama_model(img_t, mask_t)
-     result = (inpainted[0].permute(1, 2, 0).cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
-     result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
-     if (new_h, new_w) != (h, w):
-         result = cv2.resize(result, (w, h))
-     return result

  def make_anaglyph(left, right):
      l = np.array(left)
      r = np.array(right)
      ana = np.zeros_like(l)
-     ana[:, :, 0] = l[:, :, 0]  # Red ← Left
-     ana[:, :, 1] = r[:, :, 1]  # Green ← Right
-     ana[:, :, 2] = r[:, :, 2]  # Blue ← Right
      return Image.fromarray(ana)

  # ==============================================================================
- # 4. MAIN PIPELINE
  # ==============================================================================
  @torch.no_grad()
- def stereo_pipeline(image_pil, divergence_percent=3.2, convergence_plane=0.08):
      if image_pil is None:
          return None, None, None, None

@@ -196,45 +187,44 @@ def stereo_pipeline(image_pil, divergence_percent=3.2, convergence_plane=0.08):
      image_pil = image_pil.resize((int(w * ratio), int(h * ratio)), Image.LANCZOS)
      w, h = image_pil.size

-     # 1. Depth
-     depth = estimate_depth(image_pil)  # [H, W] in [0,1]
      depth_vis = Image.fromarray((depth.cpu().numpy() * 255).astype(np.uint8))

-     # 2. Disparity (stronger volume with square)
-     disp_raw = depth ** 2
-     disp_clipped = torch.clamp(disp_raw, max=torch.quantile(disp_raw, 0.995))

-     # 3. Shift
      max_shift = w * (divergence_percent / 100.0)
-     shift_raw = disp_clipped * max_shift
      shift_min, shift_max = shift_raw.min(), shift_raw.max()
-     convergence_offset = shift_min + convergence_plane * (shift_max - shift_min)
-     final_shift = shift_raw - convergence_offset

-     print(f"Final shift range: {final_shift.min():.1f} → {final_shift.max():.1f} px")

-     # 4. Warp right eye
-     img_tensor = torch.from_numpy(np.array(image_pil)).float().to(device) / 255.0
-     img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0)  # [1,3,H,W]

-     shift_tensor = final_shift.unsqueeze(0).to(device)  # [1,H,W]
-     disp_tensor = disp_clipped.unsqueeze(0).to(device)

-     right_tensor, occlusion_mask = stereo_warper(img_tensor, shift_tensor, disp_tensor)

-     # 5. To numpy
-     right_np = (right_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
      right_bgr = cv2.cvtColor(right_np, cv2.COLOR_RGB2BGR)
-     mask_np = occlusion_mask.squeeze(0).cpu().numpy()

-     # 6. Inpaint occlusions
      right_filled_bgr = run_lama(right_bgr, mask_np)
      right_filled = Image.fromarray(cv2.cvtColor(right_filled_bgr, cv2.COLOR_BGR2RGB))

-     # 7. Outputs
      mask_vis = Image.fromarray((mask_np * 255).astype(np.uint8))

-     sbs = Image.new('RGB', (w * 2, h))
      sbs.paste(image_pil, (0, 0))
      sbs.paste(right_filled, (w, 0))

@@ -243,35 +233,31 @@ def stereo_pipeline(image_pil, divergence_percent=3.2, convergence_plane=0.08):
      return sbs, anaglyph, depth_vis, mask_vis

  # ==============================================================================
- # 5. GRADIO UI
  # ==============================================================================
- with gr.Blocks(title="2D → 3D Stereo — Pro & Stable") as demo:
-     gr.HTML("<h1 style='text-align:center;'>2D to 3D Stereo — Pro Quality (Fixed & Stable)</h1>")
-     gr.Markdown("Depth Anything V2 + Safe Forward Warping + LaMa Inpainting")

      with gr.Row():
          with gr.Column(scale=1):
-             input_img = gr.Image(type="pil", label="Upload Image", height=520)
              with gr.Accordion("Settings", open=True):
-                 divergence = gr.Slider(0.5, 8.0, value=3.5, step=0.1, label="3D Strength (%)")
-                 convergence = gr.Slider(0.0, 1.0, value=0.08, step=0.01,
-                                         label="Convergence Plane (0 = pop-out, 1 = deep)")
-             btn = gr.Button("Generate 3D", variant="primary", size="lg")

          with gr.Column(scale=1):
-             out_anaglyph = gr.Image(label="Anaglyph (Red/Cyan Glasses)", height=520)
-             out_sbs = gr.Image(label="Side-by-Side (Cross-eye / Parallel)", height=300)
              with gr.Row():
-                 out_depth = gr.Image(label="Depth Map", height=200)
-                 out_mask = gr.Image(label="Occlusion Mask", height=200)

-     btn.click(
-         fn=stereo_pipeline,
-         inputs=[input_img, divergence, convergence],
-         outputs=[out_sbs, out_anaglyph, out_depth, out_mask]
-     )

-     gr.Markdown("**Tip:** Use Red/Cyan glasses for anaglyph • Cross-eye or parallel view for SBS")

  if __name__ == "__main__":
      demo.launch(share=True)
 
  from PIL import Image
  from transformers import AutoModelForDepthEstimation, AutoImageProcessor
  from huggingface_hub import hf_hub_download

  # === DEVICE ===
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  print(f"Running on device: {device}")

  # ==============================================================================
+ # 1. SAFE & FAST FORWARD WARPER (grid_sample)
  # ==============================================================================
  class SafeForwardWarp(nn.Module):
      def forward(self, img, flow):
          """
+         img: [B, C, H, W] float32 in [0,1]
+         flow: [B, H, W, 2] flow[...,0]=dx, flow[...,1]=dy
          """
          B, C, H, W = img.shape

+         grid_y, grid_x = torch.meshgrid(
+             torch.arange(H, device=img.device, dtype=torch.float32),
+             torch.arange(W, device=img.device, dtype=torch.float32),
+             indexing="ij",
+         )  # [H,W] each
+
+         grid_x = grid_x.unsqueeze(0).expand(B, -1, -1)  # [B,H,W]
+         grid_y = grid_y.unsqueeze(0).expand(B, -1, -1)

+         # Gather-style warp: each output pixel samples the input at pos + flow
+         dest_x = grid_x + flow[..., 0]
          dest_y = grid_y + flow[..., 1]

          # Normalize to [-1, 1]
+         norm_x = dest_x / (W - 1) * 2.0 - 1.0
+         norm_y = dest_y / (H - 1) * 2.0 - 1.0

+         grid = torch.stack((norm_x, norm_y), dim=-1)  # [B,H,W,2]
+         grid = grid.clamp(-1.0, 1.0)

          warped = torch.nn.functional.grid_sample(
              img,
              grid,
+             mode="bilinear",
+             padding_mode="zeros",
+             align_corners=True,
          )
          return warped

  # ==============================================================================
+ # 2. STEREO WARPER
  # ==============================================================================
  class ForwardWarpStereo(nn.Module):
      def __init__(self, eps=1e-6):

          self.warp = SafeForwardWarp()

      def forward(self, img, shift, disp_for_weights):
+         flow_x = -shift
          flow_y = torch.zeros_like(flow_x)
+         flow = torch.stack((flow_x, flow_y), dim=-1)  # [B,H,W,2]

+         # Inverse-disparity weights, normalized so the maximum is ~1
          weights = 1.0 / (disp_for_weights + 0.1)
          weights = weights / (weights.max() + 1e-8)

+         warped_img = self.warp(img * weights.unsqueeze(1), flow)
+         warped_w = self.warp(weights.unsqueeze(1), flow)
+         warped_w = torch.clamp(warped_w, min=self.eps)
+         result = warped_img / warped_w

+         # Occupancy → occlusion mask
          ones = torch.ones_like(img[:, :1])
          occupancy = self.warp(ones, flow)
          occlusion = (occupancy < self.eps).float()

+         # Smart dilation (preserve sharp foreground)
          with torch.no_grad():
+             fg = (disp_for_weights > torch.quantile(disp_for_weights, 0.90)).float().unsqueeze(1)  # [B,1,H,W]

              k = 9
              dilated = torch.nn.functional.conv2d(
                  occlusion,
+                 torch.ones(1, 1, k, k, device=occlusion.device),
+                 padding=k // 2,
              ) > 0.5
+             safe_dilate = dilated.float() * (1 - fg)
+             occlusion = torch.clamp(occlusion + safe_dilate, 0, 1)

          return result, occlusion

  # ==============================================================================
+ # 3. MODELS
  # ==============================================================================
  def load_models():
      print("Loading Depth Anything V2 Large...")

          "depth-anything/Depth-Anything-V2-Large-hf"
      )

+     print("Loading LaMa...")
      try:
+         path = hf_hub_download("fashn-ai/LaMa", "big-lama.pt")
+         lama_model = torch.jit.load(path, map_location=device).eval()
      except Exception as e:
+         print("LaMa failed → running without inpainting:", e)
          lama_model = None

+     warper = ForwardWarpStereo().to(device)
+     return depth_model, depth_processor, lama_model, warper

  depth_model, depth_processor, lama_model, stereo_warper = load_models()

+ # ==============================================================================
+ # 4. HELPERS
+ # ==============================================================================
  @torch.no_grad()
+ def estimate_depth(pil_img):
+     w, h = pil_img.size
+     inputs = depth_processor(images=pil_img, return_tensors="pt").to(device)
+     pred = depth_model(**inputs).predicted_depth[0]  # [H,W]
+
+     pred = torch.nn.functional.interpolate(
+         pred.unsqueeze(0).unsqueeze(0),
+         size=(h, w),
          mode="bicubic",
          align_corners=False,
+     )[0, 0]

+     # Normalize to [0,1]
+     mi, ma = pred.min(), pred.max()
+     if ma > mi:
+         pred = (pred - mi) / (ma - mi)
+     return pred

  @torch.no_grad()
+ def run_lama(bgr_img, mask_float):
      if lama_model is None:
+         return bgr_img
+     mask_u8 = (mask_float * 255).astype(np.uint8)
      kernel = np.ones((7, 7), np.uint8)
+     mask_dil = cv2.dilate(mask_u8, kernel, iterations=2)
+
+     # LaMa expects dimensions divisible by 8
+     h, w = bgr_img.shape[:2]
+     nh, nw = (h // 8) * 8, (w // 8) * 8
+     img_res = cv2.resize(bgr_img, (nw, nh))
+     mask_res = cv2.resize(mask_dil, (nw, nh), interpolation=cv2.INTER_NEAREST)
+
+     t = torch.from_numpy(img_res).float().permute(2, 0, 1).unsqueeze(0) / 255.0
+     t = t[:, [2, 1, 0]].to(device)  # BGR → RGB
+     m = torch.from_numpy(mask_res).float().unsqueeze(0).unsqueeze(0) / 255.0
+     m = (m > 0.5).float().to(device)
+
+     t = t * (1 - m)
+     out = lama_model(t, m)
+     out = (out[0].permute(1, 2, 0).cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
+     out = cv2.cvtColor(out, cv2.COLOR_RGB2BGR)
+     if (nh, nw) != (h, w):
+         out = cv2.resize(out, (w, h))
+     return out

  def make_anaglyph(left, right):
      l = np.array(left)
      r = np.array(right)
      ana = np.zeros_like(l)
+     ana[..., 0] = l[..., 0]  # Red ← left eye
+     ana[..., 1] = r[..., 1]  # Green ← right eye
+     ana[..., 2] = r[..., 2]  # Blue ← right eye
      return Image.fromarray(ana)

  # ==============================================================================
+ # 5. MAIN PIPELINE
  # ==============================================================================
  @torch.no_grad()
+ def stereo_pipeline(image_pil, divergence_percent=3.5, convergence_plane=0.08):
      if image_pil is None:
          return None, None, None, None

      image_pil = image_pil.resize((int(w * ratio), int(h * ratio)), Image.LANCZOS)
      w, h = image_pil.size

+     # Depth
+     depth = estimate_depth(image_pil)  # [H,W] in [0,1]
      depth_vis = Image.fromarray((depth.cpu().numpy() * 255).astype(np.uint8))

+     # Disparity (squared for stronger volume, clipped at the 99.5th percentile)
+     disp = torch.clamp(depth ** 2, max=torch.quantile(depth ** 2, 0.995))

+     # Shift
      max_shift = w * (divergence_percent / 100.0)
+     shift_raw = disp * max_shift
      shift_min, shift_max = shift_raw.min(), shift_raw.max()
+     offset = shift_min + convergence_plane * (shift_max - shift_min)
+     final_shift = shift_raw - offset

+     print(f"Final shift range: {final_shift.min():.1f} → {final_shift.max():.1f} px")

+     # Warp right eye
+     img_t = torch.from_numpy(np.array(image_pil)).float().to(device) / 255.0
+     img_t = img_t.permute(2, 0, 1).unsqueeze(0)  # [1,3,H,W]

+     shift_t = final_shift.unsqueeze(0).to(device)  # [1,H,W]
+     disp_t = disp.unsqueeze(0).to(device)

+     right_t, occ_mask = stereo_warper(img_t, shift_t, disp_t)

+     # To numpy
+     right_np = (right_t[0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
      right_bgr = cv2.cvtColor(right_np, cv2.COLOR_RGB2BGR)
+     mask_np = occ_mask[0, 0].cpu().numpy()

+     # Inpaint
      right_filled_bgr = run_lama(right_bgr, mask_np)
      right_filled = Image.fromarray(cv2.cvtColor(right_filled_bgr, cv2.COLOR_BGR2RGB))

+     # Outputs
      mask_vis = Image.fromarray((mask_np * 255).astype(np.uint8))

+     sbs = Image.new("RGB", (w * 2, h))
      sbs.paste(image_pil, (0, 0))
      sbs.paste(right_filled, (w, 0))

      return sbs, anaglyph, depth_vis, mask_vis

  # ==============================================================================
+ # 6. GRADIO UI
  # ==============================================================================
+ with gr.Blocks(title="2D → 3D Stereo — Stable & Fixed") as demo:
+     gr.HTML("<h1 style='text-align:center;'>2D to 3D Stereo — Rock-Solid Version</h1>")
+     gr.Markdown("Depth Anything V2 + Safe Warping + LaMa Inpainting")

      with gr.Row():
          with gr.Column(scale=1):
+             inp = gr.Image(type="pil", label="Upload Image", height=520)
              with gr.Accordion("Settings", open=True):
+                 div = gr.Slider(0.5, 8.0, value=3.5, step=0.1, label="3D Strength (%)")
+                 conv = gr.Slider(0.0, 1.0, value=0.08, step=0.01, label="Convergence (0=pop-out, 1=deep)")
+             btn = gr.Button("Generate 3D", variant="primary")

          with gr.Column(scale=1):
+             out_ana = gr.Image(label="Anaglyph (Red/Cyan)", height=520)
+             out_sbs = gr.Image(label="Side-by-Side", height=300)
              with gr.Row():
+                 out_dep = gr.Image(label="Depth Map", height=200)
+                 out_msk = gr.Image(label="Occlusion Mask", height=200)

+     btn.click(stereo_pipeline, inputs=[inp, div, conv],
+               outputs=[out_sbs, out_ana, out_dep, out_msk])

+     gr.Markdown("**Tip:** Red/Cyan glasses → anaglyph • Cross-eye / parallel → SBS")

  if __name__ == "__main__":
      demo.launch(share=True)
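
A quick way to sanity-check the warp math above: with zero flow, the normalized grid built by SafeForwardWarp is the identity mapping under align_corners=True, so grid_sample should return the input unchanged. A minimal standalone sketch (the grid construction mirrors the diff; the test tensors are illustrative):

import torch
import torch.nn.functional as F

# Rebuild SafeForwardWarp's normalized grid for the zero-flow case.
B, C, H, W = 1, 3, 4, 5
img = torch.rand(B, C, H, W)

grid_y, grid_x = torch.meshgrid(
    torch.arange(H, dtype=torch.float32),
    torch.arange(W, dtype=torch.float32),
    indexing="ij",
)
norm_x = grid_x / (W - 1) * 2.0 - 1.0
norm_y = grid_y / (H - 1) * 2.0 - 1.0
grid = torch.stack((norm_x, norm_y), dim=-1).unsqueeze(0)  # [1,H,W,2]

# align_corners=True maps -1 -> index 0 and +1 -> index W-1 (or H-1),
# so sampling at the identity grid reproduces the input up to float error.
warped = F.grid_sample(img, grid, mode="bilinear",
                       padding_mode="zeros", align_corners=True)
assert torch.allclose(warped, img, atol=1e-6)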
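
To exercise the updated pipeline without the Gradio UI, a hypothetical smoke test (assumes this file is saved as app.py next to an input.jpg; note that importing app runs load_models(), which downloads both models):

from PIL import Image
from app import stereo_pipeline  # hypothetical local import of the file above

sbs, anaglyph, depth_vis, mask_vis = stereo_pipeline(
    Image.open("input.jpg").convert("RGB"),
    divergence_percent=3.5,   # the "3D Strength (%)" slider
    convergence_plane=0.08,   # 0 = pop-out, 1 = deep
)
sbs.save("sbs.png")
anaglyph.save("anaglyph.png")
depth_vis.save("depth.png")
mask_vis.save("occlusion_mask.png")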