enoky committed · verified
Commit fe2b283 · 1 Parent(s): 79bdec3

switch to Depth Anything V2 Large

Files changed (1):
  app.py  +25 -26
app.py CHANGED
@@ -3,7 +3,7 @@ import torch
 import numpy as np
 import cv2
 from PIL import Image
-from transformers import DPTForDepthEstimation, DPTImageProcessor
+from transformers import AutoModelForDepthEstimation, AutoImageProcessor
 from huggingface_hub import hf_hub_download
 import os
 
@@ -13,22 +13,23 @@ print(f"Running on device: {device}")
 
 # === LOAD MODELS ===
 def load_models():
-    print("Loading Depth Model...")
-    # 1. Depth Model
-    depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to(device)
-    depth_processor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
+    print("Loading Depth Anything V2 Large...")
+    # 1. Depth Model (Depth Anything V2 Large)
+    # We use AutoModel to automatically load the correct architecture
+    depth_model = AutoModelForDepthEstimation.from_pretrained(
+        "depth-anything/Depth-Anything-V2-Large-hf"
+    ).to(device)
+    depth_processor = AutoImageProcessor.from_pretrained(
+        "depth-anything/Depth-Anything-V2-Large-hf"
+    )
 
     print("Loading LaMa Inpainting Model...")
     # 2. LaMa Inpainting Model (TorchScript)
-    # We download the .pt file directly from a repository that hosts the compiled JIT version.
     try:
         model_path = hf_hub_download(repo_id="fashn-ai/LaMa", filename="big-lama.pt")
-
         print(f"Loading LaMa from: {model_path}")
-        # Load the TorchScript model
         lama_model = torch.jit.load(model_path, map_location=device)
         lama_model.eval()
-
     except Exception as e:
         print(f"Error loading LaMa model: {e}")
         raise e
@@ -42,9 +43,14 @@ depth_model, depth_processor, lama_model = load_models()
 @torch.no_grad()
 def estimate_depth(image_pil, model, processor):
     original_size = image_pil.size
+
+    # Preprocess image
     inputs = processor(images=image_pil, return_tensors="pt").to(device)
+
+    # Inference
     depth = model(**inputs).predicted_depth
 
+    # Interpolate depth back to ORIGINAL image size
     depth = torch.nn.functional.interpolate(
         depth.unsqueeze(1),
         size=(original_size[1], original_size[0]),
@@ -52,6 +58,7 @@ def estimate_depth(image_pil, model, processor):
         align_corners=False,
     ).squeeze().detach().cpu().numpy()
 
+    # Normalize depth to 0-1 range
     depth_min, depth_max = depth.min(), depth.max()
     if depth_max - depth_min > 0:
         return (depth - depth_min) / (depth_max - depth_min)
@@ -65,8 +72,7 @@ def generate_right_and_mask(image, shift_map):
     target_x = x_coords - shift
 
     right = np.zeros_like(image)
-    # Mask: 1 (or 255) means HOLE/MISSING info.
-    # Initialize as all holes (255)
+    # Mask: 1.0 means HOLE/MISSING info
    mask = np.ones((height, width), dtype=np.float32)
 
     valid_mask = (target_x >= 0) & (target_x < width)
@@ -75,7 +81,7 @@ def generate_right_and_mask(image, shift_map):
     flat_x_source = x_coords[valid_mask]
 
     right[flat_y, flat_x_target] = image[flat_y, flat_x_source]
-    # Mark written pixels as valid (0)
+    # Mark written pixels as valid (0.0)
     mask[flat_y, flat_x_target] = 0.0
 
     return right, mask
@@ -89,8 +95,7 @@ def run_local_lama(image_bgr, mask_float):
     mask_float: HxW float32 numpy array (1.0 = hole, 0.0 = valid)
     """
     # 0. Dilate Mask (Fixes smearing/streaking)
-    # We expand the "hole" area (values of 1) to cover the jagged edges
-    # created by the pixel shift. This forces LaMa to regenerate the boundary.
+    # We expand the "hole" area to cover jagged edges
     kernel = np.ones((5, 5), np.uint8)
     mask_uint8 = (mask_float * 255).astype(np.uint8)
     mask_dilated = cv2.dilate(mask_uint8, kernel, iterations=1)
@@ -104,23 +109,18 @@ def run_local_lama(image_bgr, mask_float):
     mask_resized = cv2.resize(mask_dilated, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
 
     # 2. Convert to Torch Tensors
-    # Image: (1, 3, H, W), RGB, 0-1
     img_t = torch.from_numpy(img_resized).float().permute(2, 0, 1).unsqueeze(0) / 255.0
     # Swap BGR to RGB
     img_t = img_t[:, [2, 1, 0], :, :]
 
-    # Mask: (1, 1, H, W), 0-1
     mask_t = torch.from_numpy(mask_resized).float().unsqueeze(0).unsqueeze(0) / 255.0
-    # Binary threshold just in case
     mask_t = (mask_t > 0.5).float()
 
     img_t = img_t.to(device)
     mask_t = mask_t.to(device)
 
     # 3. Inference
-    # LaMa expects the image to be masked (zeroed out) in the hole regions for best results
-    img_t = img_t * (1 - mask_t)
-
+    img_t = img_t * (1 - mask_t) # Zero out holes
     inpainted_t = lama_model(img_t, mask_t)
 
     # 4. Post-process
@@ -130,7 +130,7 @@ def run_local_lama(image_bgr, mask_float):
     # Swap back RGB to BGR
     inpainted = cv2.cvtColor(inpainted, cv2.COLOR_RGB2BGR)
 
-    # Resize back to original if needed
+    # Resize back to original
     if new_h != h or new_w != w:
         inpainted = cv2.resize(inpainted, (w, h))
 
@@ -150,10 +150,9 @@ def stereo_pipeline(image_pil, divergence, convergence):
     if image_pil is None:
         return None, None
 
-    # Convert to BGR for OpenCV processing
     image_cv = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
 
-    # 1. Depth
+    # 1. Depth (Using Depth Anything V2)
     depth = estimate_depth(image_pil, depth_model, depth_processor)
 
     # 2. Shift Map
@@ -162,7 +161,7 @@ def stereo_pipeline(image_pil, divergence, convergence):
     # 3. Warping
     right_img, mask = generate_right_and_mask(image_cv, shift)
 
-    # 4. Inpainting (Local)
+    # 4. Inpainting (Local LaMa)
     right_filled = run_local_lama(right_img, mask)
 
     left = image_pil
@@ -180,8 +179,8 @@ def stereo_pipeline(image_pil, divergence, convergence):
 
 # === GRADIO UI ===
 with gr.Blocks(title="2D to 3D Stereo") as demo:
-    gr.Markdown("## 2D to 3D Stereo Generator (Fully Local)")
-    gr.Markdown("Generates stereo pairs using Depth Estimation and **Local LaMa Inpainting**. No external APIs required.")
+    gr.Markdown("## 2D to 3D Stereo Generator (Depth Anything V2)")
+    gr.Markdown("Generates stereo pairs using **Depth Anything V2 Large** and Local LaMa Inpainting.")
 
     with gr.Row():
         with gr.Column(scale=1):
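For quick verification outside the Space, here is a minimal standalone sketch of the new depth path. It mirrors load_models() and estimate_depth() above; the "input.jpg" filename is a placeholder, and mode="bicubic" is an assumption, since app.py's mode argument falls outside the diff context.

import torch
from PIL import Image
from transformers import AutoModelForDepthEstimation, AutoImageProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"

model = AutoModelForDepthEstimation.from_pretrained(
    "depth-anything/Depth-Anything-V2-Large-hf"
).to(device)
processor = AutoImageProcessor.from_pretrained(
    "depth-anything/Depth-Anything-V2-Large-hf"
)

image = Image.open("input.jpg").convert("RGB")  # placeholder path

with torch.no_grad():
    inputs = processor(images=image, return_tensors="pt").to(device)
    depth = model(**inputs).predicted_depth  # (1, H', W') at model resolution

# Resize to the original image size and normalize to 0-1, as in estimate_depth()
depth = torch.nn.functional.interpolate(
    depth.unsqueeze(1),
    size=(image.height, image.width),
    mode="bicubic",  # assumed; app.py's mode line is not shown in this diff
    align_corners=False,
).squeeze().cpu().numpy()

depth_min, depth_max = depth.min(), depth.max()
if depth_max - depth_min > 0:
    depth = (depth - depth_min) / (depth_max - depth_min)
print(depth.shape, float(depth.min()), float(depth.max()))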
 
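Likewise, a minimal sketch of driving the TorchScript LaMa checkpoint on its own, using only the calls that appear in run_local_lama(). The random image, the square hole, and the 512x512 size are stand-ins; 512 is chosen to be divisible by 8, which the resize step elided from this diff handles inside app.py.

import torch
from huggingface_hub import hf_hub_download

device = "cuda" if torch.cuda.is_available() else "cpu"

model_path = hf_hub_download(repo_id="fashn-ai/LaMa", filename="big-lama.pt")
lama = torch.jit.load(model_path, map_location=device)
lama.eval()

# Stand-in inputs: image (1, 3, H, W), RGB, 0-1; mask (1, 1, H, W), 1.0 = hole
img_t = torch.rand(1, 3, 512, 512, device=device)
mask_t = torch.zeros(1, 1, 512, 512, device=device)
mask_t[:, :, 200:300, 200:300] = 1.0

with torch.no_grad():
    img_t = img_t * (1 - mask_t)  # zero out holes, as run_local_lama() does
    inpainted_t = lama(img_t, mask_t)

print(inpainted_t.shape)  # (1, 3, 512, 512)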