Ronald-M commited on
Commit
1df73db
·
verified ·
1 Parent(s): beec552

Upload parallax_gradio_app.py

Browse files
Files changed (1) hide show
  1. parallax_gradio_app.py +310 -0
parallax_gradio_app.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import math
4
+ from PIL import Image
5
+ from transformers import DPTImageProcessor, DPTForDepthEstimation
6
+ import gradio as gr
7
+ import imageio
8
+ import cv2 as cv
9
+ import tempfile
10
+ import os
11

# Initialize depth model globally so the (expensive) download/load happens
# once at import time and every request reuses the same model instance.
print("Loading Intel DPT depth estimation model...")
processor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
model.eval()  # inference-only: disables dropout / training-mode layers

# Prefer GPU when available; all inputs are moved to this device later.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
print(f"Model loaded on {device}")
21
+
22
+
23
def get_depth_map(image):
    """Estimate a normalized depth map for an image using the global DPT model.

    Args:
        image: Input ``PIL.Image`` (RGB).

    Returns:
        Tuple ``(depth_map, image)`` where ``depth_map`` is a float numpy
        array normalized to [0, 1], and ``image`` is the (possibly
        downscaled) PIL image the depth map is aligned with.
    """
    # Downscale large inputs for faster inference. Note the returned depth
    # map matches this resized image, not the caller's original resolution.
    max_size = 640
    if max(image.size) > max_size:
        ratio = max_size / max(image.size)
        new_size = tuple(int(dim * ratio) for dim in image.size)
        image = image.resize(new_size, Image.LANCZOS)

    # Prepare image tensors and move them to the model's device.
    inputs = processor(images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Run depth estimation without tracking gradients.
    with torch.no_grad():
        outputs = model(**inputs)
        predicted_depth = outputs.predicted_depth

    # Interpolate the model output back to the (resized) image resolution.
    # PIL reports size as (w, h); interpolate expects (h, w), hence [::-1].
    prediction = torch.nn.functional.interpolate(
        predicted_depth.unsqueeze(1),
        size=image.size[::-1],
        mode="bicubic",
        align_corners=False,
    )

    # Normalize to [0, 1]. BUG FIX: guard against a constant depth map
    # (flat prediction), which previously divided by zero and produced NaNs.
    depth_map = prediction.squeeze().cpu().numpy()
    depth_range = depth_map.max() - depth_map.min()
    if depth_range > 0:
        depth_map = (depth_map - depth_map.min()) / depth_range
    else:
        depth_map = np.zeros_like(depth_map)

    return depth_map, image
54
+
55
+
56
def separate_layers(depth_map, image):
    """Split the scene into foreground/background binary masks.

    The depth map is stretched to 8-bit range and Otsu-thresholded: pixels
    above the automatic threshold (nearer, per the normalized depth) become
    foreground; the complement is background.

    Returns:
        Tuple ``(foreground_mask, background_mask)`` of uint8 {0, 255} arrays.
    """
    depth_u8 = cv.normalize(
        np.array(depth_map), None, 0, 255, cv.NORM_MINMAX
    ).astype("uint8")

    # Otsu picks the split point automatically; the explicit 0 is ignored.
    _, foreground = cv.threshold(
        depth_u8, 0, 255, cv.THRESH_BINARY + cv.THRESH_OTSU
    )

    return foreground, cv.bitwise_not(foreground)
68
+
69
+
70
def inpaint_background(image_np, foreground_mask, background_mask):
    """Reconstruct the occluded background and build an RGBA foreground layer.

    Args:
        image_np: Source image as a numpy array (at least 3 channels).
        foreground_mask: Grayscale mask of the foreground subject.
        background_mask: Grayscale mask of the background (complement).

    Returns:
        Tuple ``(inpainted_bg, fg_rgba, foreground_mask)``: the hole-filled
        BGR/RGB background, the foreground with a feathered alpha channel,
        and the binarized foreground mask.
    """
    # Binarize masks to strict {0, 255} uint8 so OpenCV ops behave predictably.
    foreground_mask = (foreground_mask > 128).astype(np.uint8) * 255
    background_mask = (background_mask > 128).astype(np.uint8) * 255

    # "Damage" the background: zero out foreground pixels to form the hole.
    damaged_bg = image_np.copy()[:, :, :3]
    damaged_bg[foreground_mask == 255] = 0
    inpainted_bg = damaged_bg.copy()

    # Slightly dilate the hole so inpainting also covers fringe pixels
    # around the subject's silhouette.
    kernel_iter = cv.getStructuringElement(cv.MORPH_ELLIPSE, (7, 7))
    mask_iter = cv.dilate(foreground_mask, cv.getStructuringElement(cv.MORPH_ELLIPSE, (3, 3)), iterations=2)

    # Iterative inpainting: fill the hole from the rim inward, one "ring"
    # (mask minus its erosion) per pass, so each pass has valid neighbors.
    hole_area = np.count_nonzero(mask_iter)
    max_erode = max(1, hole_area // 5000)  # erosion budget scales with hole size
    iterations = 12

    for i in range(iterations):
        erode_steps = max(1, max_erode // (i + 1))  # shrink step size over passes
        eroded = cv.erode(mask_iter, kernel_iter, iterations=erode_steps)
        ring_mask = cv.subtract(mask_iter, eroded)
        ring_mask = (ring_mask > 0).astype(np.uint8) * 255

        # Hole fully consumed — nothing left to fill.
        if np.count_nonzero(ring_mask) == 0:
            break

        # TELEA (fast marching) for the outer rings, Navier-Stokes for the
        # inner ones where smoother propagation looks better.
        method = cv.INPAINT_TELEA if i < iterations // 2 else cv.INPAINT_NS
        inpainted_bg = cv.inpaint(inpainted_bg, ring_mask, 5, method)
        mask_iter = eroded

    # Final refinement: smooth, re-inpaint the whole hole once, and smooth
    # again to blend any ring seams left by the iterative pass.
    inpainted_bg = cv.bilateralFilter(inpainted_bg, d=9, sigmaColor=75, sigmaSpace=75)
    inpainted_bg = cv.inpaint(inpainted_bg, foreground_mask, 5, cv.INPAINT_NS)
    inpainted_bg = cv.bilateralFilter(inpainted_bg, d=9, sigmaColor=75, sigmaSpace=75)

    # Build the foreground layer with a Gaussian-feathered alpha channel so
    # later compositing blends softly at the subject's edge.
    foreground_rgb = image_np.copy()[:, :, :3]
    foreground_rgb[foreground_mask == 0] = 0

    alpha = foreground_mask / 255.0
    alpha_blurred = cv.GaussianBlur(alpha, (9, 9), 0)
    fg_rgba = np.dstack((foreground_rgb, (alpha_blurred * 255).astype(np.uint8)))

    return inpainted_bg, fg_rgba, foreground_mask
116
+
117
+
118
def create_parallax_animation(inpainted_bg, fg_rgba, depth_map, motion_strength, parallax_strength,
                              aperture, speed_multiplier, zoom_base, progress=gr.Progress()):
    """Render a looping parallax GIF with depth-weighted background blur.

    Each frame crops horizontally shifted windows out of zoomed foreground
    and background layers (the foreground shifts twice as far, creating the
    parallax), blends a sharp and a blurred background by depth, and alpha
    composites the foreground on top.

    Args:
        inpainted_bg: Hole-filled background image (H, W, 3).
        fg_rgba: Foreground layer with alpha channel (H, W, 4).
        depth_map: Normalized [0, 1] depth map aligned with the layers.
        motion_strength: Scales overall horizontal camera travel.
        parallax_strength: Extra multiplier on the foreground shift.
        aperture: Controls background blur kernel size (larger = more blur).
        speed_multiplier: Playback speed factor applied to frame duration.
        zoom_base: Scales the zoom-in/out amplitude over the loop.
        progress: Gradio progress reporter.

    Returns:
        Path to the written GIF file (a NamedTemporaryFile, not auto-deleted).
    """
    num_frames = 60
    # Zoom is largest at the center of the oscillation, slightly smaller at
    # the extremes, giving a gentle breathing effect.
    zoom_scale_center = 1.0 + (zoom_base * 0.15)
    zoom_scale_sides = 1.0 + (zoom_base * 0.125)
    fps = 20

    h, w = inpainted_bg.shape[:2]

    progress(0.1, desc="Preparing layers...")

    # Pre-scale both layers to the maximum zoom once; per-frame sizes are
    # derived by downscaling these (cheaper than re-zooming the originals).
    zoom_h_max, zoom_w_max = int(h * zoom_scale_center), int(w * zoom_scale_center)
    zoomed_fg_max = cv.resize(fg_rgba, (zoom_w_max, zoom_h_max), interpolation=cv.INTER_LINEAR)
    zoomed_bg_max = cv.resize(inpainted_bg, (zoom_w_max, zoom_h_max), interpolation=cv.INTER_LINEAR)

    # Pre-compute the blurred background once; kernel must be odd for OpenCV.
    max_kernel = int(aperture * 5)
    max_kernel = max_kernel if max_kernel % 2 == 1 else max_kernel + 1
    zoomed_bg_blurred_max = cv.GaussianBlur(zoomed_bg_max, (max_kernel, max_kernel), 0)

    # Resize and invert the depth map: after inversion, larger values mean
    # farther away, which receive more of the blurred background.
    depth_map_resized = cv.resize(depth_map, (w, h), interpolation=cv.INTER_LINEAR)
    depth_map_resized = 1 - depth_map_resized
    depth_map_3c = np.repeat(depth_map_resized[:, :, None], 3, axis=2)

    frames = []

    progress(0.2, desc="Generating frames...")

    for i in range(num_frames):
        # Oscillation sweeps -1 -> +1 -> -1 over the loop (cosine-eased),
        # so the GIF loops seamlessly.
        t = i / (num_frames - 1)
        oscillation = -math.cos(t * 2 * math.pi) / 2 + 0.5
        oscillation = (oscillation - 0.5) * 2

        # Interpolate zoom between center and side scales by |oscillation|.
        zoom_factor = zoom_scale_center - abs(oscillation) * (zoom_scale_center - zoom_scale_sides)
        current_h, current_w = int(h * zoom_factor), int(w * zoom_factor)

        # Derive this frame's layers by downscaling from the max-zoom images.
        zoomed_fg = cv.resize(zoomed_fg_max, (current_w, current_h), interpolation=cv.INTER_LINEAR)
        zoomed_bg = cv.resize(zoomed_bg_max, (current_w, current_h), interpolation=cv.INTER_LINEAR)
        zoomed_bg_blurred = cv.resize(zoomed_bg_blurred_max, (current_w, current_h), interpolation=cv.INTER_LINEAR)

        # Vertical crop is always centered; only horizontal position moves.
        center_y, center_x = current_h // 2, current_w // 2
        crop_y1 = center_y - h // 2
        crop_y2 = center_y + h // 2

        # Foreground shifts 2x as far as the background — that difference is
        # the parallax illusion.
        shift_x_total = current_w - w
        shift_bg_float = oscillation * shift_x_total * 0.10 * motion_strength
        shift_fg_float = oscillation * shift_x_total * 0.20 * motion_strength * parallax_strength

        crop_bg1 = int(round(center_x - w // 2 + shift_bg_float))
        crop_fg1 = int(round(center_x - w // 2 + shift_fg_float))

        # Clamp so the crop window never leaves the zoomed image.
        crop_bg1 = max(0, min(current_w - w, crop_bg1))
        crop_fg1 = max(0, min(current_w - w, crop_fg1))

        crop_bg2 = crop_bg1 + w
        crop_fg2 = crop_fg1 + w

        # Crop each layer's window for this frame.
        fg_crop = zoomed_fg[crop_y1:crop_y2, crop_fg1:crop_fg2]
        bg_crop = zoomed_bg[crop_y1:crop_y2, crop_bg1:crop_bg2]
        bg_crop_blurred = zoomed_bg_blurred[crop_y1:crop_y2, crop_bg1:crop_bg2]

        # Safety resize: integer rounding of zoom/crop can be off by a pixel.
        if fg_crop.shape[:2] != (h, w):
            fg_crop = cv.resize(fg_crop, (w, h), interpolation=cv.INTER_LINEAR)
        if bg_crop.shape[:2] != (h, w):
            bg_crop = cv.resize(bg_crop, (w, h), interpolation=cv.INTER_LINEAR)
            bg_crop_blurred = cv.resize(bg_crop_blurred, (w, h), interpolation=cv.INTER_LINEAR)

        # Depth-of-field: mix sharp and blurred background per-pixel by depth.
        bg_composite = ((1 - depth_map_3c) * bg_crop + depth_map_3c * bg_crop_blurred).astype(np.uint8)

        # Alpha composite the foreground. Erode-then-blur the alpha to pull
        # the matte in from the edge and soften it, avoiding halo artifacts.
        alpha = fg_crop[:, :, 3] / 255.0
        kernel = np.ones((5, 5), np.uint8)
        alpha_uint8 = (alpha * 255).astype(np.uint8)
        alpha_eroded = cv.erode(alpha_uint8, kernel, iterations=1)
        alpha_smooth = cv.GaussianBlur(alpha_eroded, (5, 5), 0) / 255.0
        alpha_smooth_3c = alpha_smooth[:, :, np.newaxis]

        fg_rgb = fg_crop[:, :, :3].astype(float)
        composite = (fg_rgb * alpha_smooth_3c + bg_composite * (1 - alpha_smooth_3c)).astype(np.uint8)

        frames.append(composite)

        # Report progress every 10 frames to avoid UI spam.
        if i % 10 == 0:
            progress(0.2 + (i / num_frames) * 0.7, desc=f"Rendering frame {i}/{num_frames}...")

    progress(0.95, desc="Saving animation...")

    # Save GIF. delete=False keeps the file on disk so Gradio can serve it;
    # duration is per-frame display time in ms, loop=0 means loop forever.
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.gif')
    imageio.mimsave(temp_file.name, frames, duration=1000/fps/speed_multiplier, loop=0)

    progress(1.0, desc="Complete!")

    return temp_file.name
221
+
222
+
223
def process_image(image, motion, parallax, aperture, speed, zoom, progress=gr.Progress()):
    """Run the full pipeline: depth estimation, layer separation,
    background inpainting, then parallax GIF rendering.

    Returns the path to the generated GIF, or ``None`` when no image was
    provided.
    """
    if image is None:
        return None

    progress(0, desc="Loading image...")

    # Gradio may hand us a numpy array; normalize to an RGB PIL image.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image).convert('RGB')

    progress(0.05, desc="Extracting depth map...")
    depth, pil_img = get_depth_map(image)

    progress(0.3, desc="Separating layers...")
    pixels = np.array(pil_img)
    fg_mask, bg_mask = separate_layers(depth, pil_img)

    progress(0.4, desc="Reconstructing background...")
    bg_filled, fg_layer, _ = inpaint_background(pixels, fg_mask, bg_mask)

    progress(0.5, desc="Creating parallax animation...")
    return create_parallax_animation(
        bg_filled, fg_layer, depth,
        motion, parallax, aperture, speed, zoom,
        progress=progress,
    )
252
+
253
+
254
+ # Create Gradio interface
255
# Build the Gradio interface. All components live inside one Blocks context
# and the Blocks object is kept as `demo` for launching below.
with gr.Blocks(title="🧪 The Parallax Lab", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🧪 The Parallax Lab

    Upload an image to create a stunning depth-based parallax animation with bokeh effects!

    **How it works:**
    1. AI extracts depth information from your image
    2. Separates foreground and background layers
    3. Creates smooth parallax motion with depth-of-field blur
    """)

    with gr.Row():
        with gr.Column(scale=1):
            # BUG FIX: the original passed `Image="./HW4_Dog.jpg"`, which is
            # not a gr.Image parameter (unexpected keyword -> TypeError at
            # startup). The default/example image belongs in `value=`.
            # NOTE(review): assumes ./HW4_Dog.jpg ships alongside the app —
            # verify the file exists in the Space/repo.
            input_image = gr.Image(type="pil", label="Upload Your Image", value="./HW4_Dog.jpg")

            gr.Markdown("### Effect Controls")

            motion = gr.Slider(0.5, 2, value=1, step=0.1, label="Motion Strength",
                               info="How much the camera moves")
            parallax = gr.Slider(0.5, 2, value=1, step=0.1, label="Parallax Strength",
                                 info="Separation between foreground/background")
            aperture = gr.Slider(1.4, 5.6, value=2.8, step=0.2, label="Aperture Size",
                                 info="Blur intensity (lower = more blur)")
            speed = gr.Slider(0.5, 2, value=1, step=0.1, label="Animation Speed",
                              info="Playback speed multiplier")
            zoom = gr.Slider(0.5, 2, value=1, step=0.1, label="Zoom Intensity",
                             info="How much to zoom in/out")

            start_btn = gr.Button("✨ Create Parallax Animation", variant="primary", size="lg")

        with gr.Column(scale=1):
            # Output as a downloadable file rather than an inline image.
            output_gif = gr.File(label="📥 Download Your Animation", file_types=[".gif"])

            gr.Markdown("""
            ### Tips for Best Results:
            - Use images with clear foreground subjects
            - Portraits and objects work especially well
            - Higher motion/parallax = more dramatic effect
            - Lower aperture = stronger bokeh blur
            """)

    # Wire the button to the processing pipeline.
    start_btn.click(
        fn=process_image,
        inputs=[input_image, motion, parallax, aperture, speed, zoom],
        outputs=[output_gif]
    )

    gr.Markdown("""
    ---
    **Note:** Processing may take 1-2 minutes depending on image size and hardware.
    """)
307
+
308
+
309
if __name__ == "__main__":
    # Launch the Gradio server only when run as a script (not on import).
    demo.launch()