enoky committed
Commit db1a689 · verified · 1 Parent(s): 53f760e

Update app.py

Files changed (1)
  1. app.py +208 -199
app.py CHANGED
@@ -1,200 +1,209 @@
-import gradio as gr
-import torch
-import numpy as np
-import cv2
-from PIL import Image
-from transformers import DPTForDepthEstimation, DPTImageProcessor
-from huggingface_hub import hf_hub_download
-import os
-
-# === DEVICE ===
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-print(f"Running on device: {device}")
-
-# === LOAD MODELS ===
-def load_models():
-    print("Loading Depth Model...")
-    # 1. Depth Model
-    depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to(device)
-    depth_processor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
-
-    print("Loading LaMa Inpainting Model...")
-    # 2. LaMa Inpainting Model (TorchScript)
-    # We download the JIT traced model which is self-contained
-    model_path = hf_hub_download(repo_id="smartywu/big-lama", filename="big-lama.pt")
-    lama_model = torch.jit.load(model_path).to(device)
-    lama_model.eval()
-
-    return depth_model, depth_processor, lama_model
-
-# Load models once at startup
-depth_model, depth_processor, lama_model = load_models()
-
-# === DEPTH ESTIMATION ===
-@torch.no_grad()
-def estimate_depth(image_pil, model, processor):
-    original_size = image_pil.size
-    inputs = processor(images=image_pil, return_tensors="pt").to(device)
-    depth = model(**inputs).predicted_depth
-
-    depth = torch.nn.functional.interpolate(
-        depth.unsqueeze(1),
-        size=(original_size[1], original_size[0]),
-        mode="bicubic",
-        align_corners=False,
-    ).squeeze().detach().cpu().numpy()
-
-    depth_min, depth_max = depth.min(), depth.max()
-    if depth_max - depth_min > 0:
-        return (depth - depth_min) / (depth_max - depth_min)
-    return depth
-
-# === STEREO GENERATION LOGIC ===
-def generate_right_and_mask(image, shift_map):
-    height, width = image.shape[:2]
-    x_coords, y_coords = np.meshgrid(np.arange(width), np.arange(height))
-    shift = shift_map.astype(int)
-    target_x = x_coords - shift
-
-    right = np.zeros_like(image)
-    # Mask: 1 (or 255) means HOLE/MISSING info.
-    # Initialize as all holes (255)
-    mask = np.ones((height, width), dtype=np.float32)
-
-    valid_mask = (target_x >= 0) & (target_x < width)
-    flat_y = y_coords[valid_mask]
-    flat_x_target = target_x[valid_mask]
-    flat_x_source = x_coords[valid_mask]
-
-    right[flat_y, flat_x_target] = image[flat_y, flat_x_source]
-    # Mark written pixels as valid (0)
-    mask[flat_y, flat_x_target] = 0.0
-
-    return right, mask
-
-# === LOCAL INPAINTING ===
-@torch.no_grad()
-def run_local_lama(image_bgr, mask_float):
-    """
-    Runs LaMa locally.
-    image_bgr: HxWx3 uint8 numpy array
-    mask_float: HxW float32 numpy array (1.0 = hole, 0.0 = valid)
-    """
-    # 1. Resize to be divisible by 8 (LaMa requirement)
-    h, w = image_bgr.shape[:2]
-    new_h = (h // 8) * 8
-    new_w = (w // 8) * 8
-
-    img_resized = cv2.resize(image_bgr, (new_w, new_h))
-    mask_resized = cv2.resize(mask_float, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
-
-    # 2. Convert to Torch Tensors
-    # Image: (1, 3, H, W), RGB, 0-1
-    img_t = torch.from_numpy(img_resized).float().permute(2, 0, 1).unsqueeze(0) / 255.0
-    # Swap BGR to RGB
-    img_t = img_t[:, [2, 1, 0], :, :]
-
-    # Mask: (1, 1, H, W), 0-1
-    mask_t = torch.from_numpy(mask_resized).float().unsqueeze(0).unsqueeze(0)
-    # Binary threshold just in case
-    mask_t = (mask_t > 0.5).float()
-
-    img_t = img_t.to(device)
-    mask_t = mask_t.to(device)
-
-    # 3. Inference
-    inpainted_t = lama_model(img_t, mask_t)
-
-    # 4. Post-process
-    inpainted = inpainted_t[0].permute(1, 2, 0).cpu().numpy()
-    inpainted = np.clip(inpainted * 255, 0, 255).astype(np.uint8)
-
-    # Swap back RGB to BGR
-    inpainted = cv2.cvtColor(inpainted, cv2.COLOR_RGB2BGR)
-
-    # Resize back to original if needed
-    if new_h != h or new_w != w:
-        inpainted = cv2.resize(inpainted, (w, h))
-
-    return inpainted
-
-def make_anaglyph(left, right):
-    l_arr = np.array(left)
-    r_arr = np.array(right)
-    anaglyph = np.zeros_like(l_arr)
-    anaglyph[:, :, 0] = l_arr[:, :, 0]
-    anaglyph[:, :, 1] = r_arr[:, :, 1]
-    anaglyph[:, :, 2] = r_arr[:, :, 2]
-    return Image.fromarray(anaglyph)
-
-# === PIPELINE ===
-def stereo_pipeline(image_pil, divergence, convergence):
-    if image_pil is None:
-        return None, None
-
-    # Convert to BGR for OpenCV processing
-    image_cv = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
-
-    # 1. Depth
-    depth = estimate_depth(image_pil, depth_model, depth_processor)
-
-    # 2. Shift Map
-    shift = (depth - convergence) * divergence
-
-    # 3. Warping
-    right_img, mask = generate_right_and_mask(image_cv, shift)
-
-    # 4. Inpainting (Local)
-    right_filled = run_local_lama(right_img, mask)
-
-    left = image_pil
-    right = Image.fromarray(cv2.cvtColor(right_filled, cv2.COLOR_BGR2RGB))
-
-    # 5. Composition
-    width, height = left.size
-    combined_image = Image.new('RGB', (width * 2, height))
-    combined_image.paste(left, (0, 0))
-    combined_image.paste(right, (width, 0))
-
-    anaglyph_image = make_anaglyph(left, right)
-
-    return combined_image, anaglyph_image
-
-# === GRADIO UI ===
-with gr.Blocks(title="2D to 3D Stereo") as demo:
-    gr.Markdown("## 2D to 3D Stereo Generator (Fully Local)")
-    gr.Markdown("Generates stereo pairs using Depth Estimation and **Local LaMa Inpainting**. No external APIs required.")
-
-    with gr.Row():
-        with gr.Column(scale=1):
-            input_img = gr.Image(type="pil", label="Input Image", height=480)
-
-            with gr.Group():
-                gr.Markdown("### 3D Controls")
-                divergence_slider = gr.Slider(
-                    minimum=0, maximum=100, value=30, step=1,
-                    label="3D Strength (Divergence)",
-                    info="Max pixel separation."
-                )
-                convergence_slider = gr.Slider(
-                    minimum=0.0, maximum=1.0, value=0.1, step=0.05,
-                    label="Focus Plane (Convergence)",
-                    info="0.0 = Background at screen. 1.0 = Foreground at screen."
-                )
-
-            btn = gr.Button("Generate 3D", variant="primary")
-
-        with gr.Column(scale=1):
-            out_anaglyph = gr.Image(label="Anaglyph (Red/Cyan)", height=480)
-
-    with gr.Row():
-        out_stereo = gr.Image(label="Side-by-Side Stereo Pair", height=400)
-
-    btn.click(
-        fn=stereo_pipeline,
-        inputs=[input_img, divergence_slider, convergence_slider],
-        outputs=[out_stereo, out_anaglyph]
-    )
-
-if __name__ == "__main__":
+import gradio as gr
+import torch
+import numpy as np
+import cv2
+from PIL import Image
+from transformers import DPTForDepthEstimation, DPTImageProcessor
+from huggingface_hub import hf_hub_download
+import os
+
+# === DEVICE ===
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print(f"Running on device: {device}")
+
+# === LOAD MODELS ===
+def load_models():
+    print("Loading Depth Model...")
+    # 1. Depth Model
+    depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to(device)
+    depth_processor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
+
+    print("Loading LaMa Inpainting Model...")
+    # 2. LaMa Inpainting Model (TorchScript)
+    # We download the .pt file directly from a repository that hosts the compiled JIT version.
+    # This avoids dealing with .ckpt files and source code dependencies.
+    try:
+        model_path = hf_hub_download(repo_id="fashn-ai/LaMa", filename="big-lama.pt")
+
+        print(f"Loading LaMa from: {model_path}")
+        # Load the TorchScript model
+        lama_model = torch.jit.load(model_path, map_location=device)
+        lama_model.eval()
+
+    except Exception as e:
+        print(f"Error loading LaMa model: {e}")
+        raise e
+
+    return depth_model, depth_processor, lama_model
+
+# Load models once at startup
+depth_model, depth_processor, lama_model = load_models()
+
+# === DEPTH ESTIMATION ===
+@torch.no_grad()
+def estimate_depth(image_pil, model, processor):
+    original_size = image_pil.size
+    inputs = processor(images=image_pil, return_tensors="pt").to(device)
+    depth = model(**inputs).predicted_depth
+
+    depth = torch.nn.functional.interpolate(
+        depth.unsqueeze(1),
+        size=(original_size[1], original_size[0]),
+        mode="bicubic",
+        align_corners=False,
+    ).squeeze().detach().cpu().numpy()
+
+    depth_min, depth_max = depth.min(), depth.max()
+    if depth_max - depth_min > 0:
+        return (depth - depth_min) / (depth_max - depth_min)
+    return depth
+
+# === STEREO GENERATION LOGIC ===
+def generate_right_and_mask(image, shift_map):
+    height, width = image.shape[:2]
+    x_coords, y_coords = np.meshgrid(np.arange(width), np.arange(height))
+    shift = shift_map.astype(int)
+    target_x = x_coords - shift
+
+    right = np.zeros_like(image)
+    # Mask: 1 (or 255) means HOLE/MISSING info.
+    # Initialize as all holes (255)
+    mask = np.ones((height, width), dtype=np.float32)
+
+    valid_mask = (target_x >= 0) & (target_x < width)
+    flat_y = y_coords[valid_mask]
+    flat_x_target = target_x[valid_mask]
+    flat_x_source = x_coords[valid_mask]
+
+    right[flat_y, flat_x_target] = image[flat_y, flat_x_source]
+    # Mark written pixels as valid (0)
+    mask[flat_y, flat_x_target] = 0.0
+
+    return right, mask
+
+# === LOCAL INPAINTING ===
+@torch.no_grad()
+def run_local_lama(image_bgr, mask_float):
+    """
+    Runs LaMa locally.
+    image_bgr: HxWx3 uint8 numpy array
+    mask_float: HxW float32 numpy array (1.0 = hole, 0.0 = valid)
+    """
+    # 1. Resize to be divisible by 8 (LaMa requirement)
+    h, w = image_bgr.shape[:2]
+    new_h = (h // 8) * 8
+    new_w = (w // 8) * 8
+
+    img_resized = cv2.resize(image_bgr, (new_w, new_h))
+    mask_resized = cv2.resize(mask_float, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
+
+    # 2. Convert to Torch Tensors
+    # Image: (1, 3, H, W), RGB, 0-1
+    img_t = torch.from_numpy(img_resized).float().permute(2, 0, 1).unsqueeze(0) / 255.0
+    # Swap BGR to RGB
+    img_t = img_t[:, [2, 1, 0], :, :]
+
+    # Mask: (1, 1, H, W), 0-1
+    mask_t = torch.from_numpy(mask_resized).float().unsqueeze(0).unsqueeze(0)
+    # Binary threshold just in case
+    mask_t = (mask_t > 0.5).float()
+
+    img_t = img_t.to(device)
+    mask_t = mask_t.to(device)
+
+    # 3. Inference
+    inpainted_t = lama_model(img_t, mask_t)
+
+    # 4. Post-process
+    inpainted = inpainted_t[0].permute(1, 2, 0).cpu().numpy()
+    inpainted = np.clip(inpainted * 255, 0, 255).astype(np.uint8)
+
+    # Swap back RGB to BGR
+    inpainted = cv2.cvtColor(inpainted, cv2.COLOR_RGB2BGR)
+
+    # Resize back to original if needed
+    if new_h != h or new_w != w:
+        inpainted = cv2.resize(inpainted, (w, h))
+
+    return inpainted
+
+def make_anaglyph(left, right):
+    l_arr = np.array(left)
+    r_arr = np.array(right)
+    anaglyph = np.zeros_like(l_arr)
+    anaglyph[:, :, 0] = l_arr[:, :, 0]
+    anaglyph[:, :, 1] = r_arr[:, :, 1]
+    anaglyph[:, :, 2] = r_arr[:, :, 2]
+    return Image.fromarray(anaglyph)
+
+# === PIPELINE ===
+def stereo_pipeline(image_pil, divergence, convergence):
+    if image_pil is None:
+        return None, None
+
+    # Convert to BGR for OpenCV processing
+    image_cv = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
+
+    # 1. Depth
+    depth = estimate_depth(image_pil, depth_model, depth_processor)
+
+    # 2. Shift Map
+    shift = (depth - convergence) * divergence
+
+    # 3. Warping
+    right_img, mask = generate_right_and_mask(image_cv, shift)
+
+    # 4. Inpainting (Local)
+    right_filled = run_local_lama(right_img, mask)
+
+    left = image_pil
+    right = Image.fromarray(cv2.cvtColor(right_filled, cv2.COLOR_BGR2RGB))
+
+    # 5. Composition
+    width, height = left.size
+    combined_image = Image.new('RGB', (width * 2, height))
+    combined_image.paste(left, (0, 0))
+    combined_image.paste(right, (width, 0))
+
+    anaglyph_image = make_anaglyph(left, right)
+
+    return combined_image, anaglyph_image
+
+# === GRADIO UI ===
+with gr.Blocks(title="2D to 3D Stereo") as demo:
+    gr.Markdown("## 2D to 3D Stereo Generator (Fully Local)")
+    gr.Markdown("Generates stereo pairs using Depth Estimation and **Local LaMa Inpainting**. No external APIs required.")
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            input_img = gr.Image(type="pil", label="Input Image", height=480)
+
+            with gr.Group():
+                gr.Markdown("### 3D Controls")
+                divergence_slider = gr.Slider(
+                    minimum=0, maximum=100, value=30, step=1,
+                    label="3D Strength (Divergence)",
+                    info="Max pixel separation."
+                )
+                convergence_slider = gr.Slider(
+                    minimum=0.0, maximum=1.0, value=0.1, step=0.05,
+                    label="Focus Plane (Convergence)",
+                    info="0.0 = Background at screen. 1.0 = Foreground at screen."
+                )
+
+            btn = gr.Button("Generate 3D", variant="primary")
+
+        with gr.Column(scale=1):
+            out_anaglyph = gr.Image(label="Anaglyph (Red/Cyan)", height=480)
+
+    with gr.Row():
+        out_stereo = gr.Image(label="Side-by-Side Stereo Pair", height=400)
+
+    btn.click(
+        fn=stereo_pipeline,
+        inputs=[input_img, divergence_slider, convergence_slider],
+        outputs=[out_stereo, out_anaglyph]
+    )
+
+if __name__ == "__main__":
     demo.launch()
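
For quick verification of this revision, here is a minimal smoke-test sketch that exercises the pipeline outside the Gradio UI. It is not part of the commit: it assumes this version of app.py is importable from the working directory and that a local test image named "input.jpg" exists; the filenames are hypothetical, and the parameter values mirror the slider defaults defined above (divergence=30, convergence=0.1).

    # Hypothetical smoke test for this revision of app.py (not part of the commit).
    # Assumes app.py is importable and a local test image "input.jpg" exists.
    from PIL import Image

    from app import stereo_pipeline  # note: importing app loads both models at startup

    img = Image.open("input.jpg").convert("RGB")

    # Values mirror the Gradio slider defaults: 30 px max separation, focus plane at 0.1.
    side_by_side, anaglyph = stereo_pipeline(img, divergence=30, convergence=0.1)

    side_by_side.save("stereo_sbs.png")      # left|right pair, double width
    anaglyph.save("anaglyph_red_cyan.png")   # view with red/cyan glasses

Both return values are PIL images, matching what stereo_pipeline hands back to the two gr.Image outputs.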