import math
import os
import tempfile

import cv2 as cv
import gradio as gr
import imageio
import numpy as np
import torch
from PIL import Image
from transformers import DPTForDepthEstimation, DPTImageProcessor

# Example image preloaded into the upload widget when it ships next to the app.
DEFAULT_EXAMPLE_IMAGE = "HW4_Dog.jpg"

# Initialize the depth model once at import time so every request reuses it.
print("Loading Intel DPT depth estimation model...")
processor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
print(f"Model loaded on {device}")


def get_depth_map(image, max_size=640):
    """Estimate a relative depth map for ``image`` with the DPT model.

    The image is first downscaled so its longest side is at most ``max_size``
    pixels (to keep inference fast); the depth map is aligned with that
    possibly-resized image.

    Args:
        image: RGB ``PIL.Image``.
        max_size: Longest-side limit in pixels applied before inference.

    Returns:
        ``(depth_map, image)``: ``depth_map`` is a 2-D float array of shape
        ``(height, width)`` min-max scaled to ``[0, 1]`` (all zeros if the
        prediction is flat); ``image`` is the PIL image it is aligned with.
        DPT predicts relative inverse depth, so larger values are nearer.
    """
    if max(image.size) > max_size:
        ratio = max_size / max(image.size)
        # Clamp to 1 px so extreme aspect ratios never yield a zero-sized image.
        new_size = tuple(max(1, int(dim * ratio)) for dim in image.size)
        image = image.resize(new_size, Image.LANCZOS)

    inputs = processor(images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        predicted_depth = model(**inputs).predicted_depth

    # Upsample the model output to the input resolution. PIL's size is
    # (width, height) while interpolate expects (height, width).
    prediction = torch.nn.functional.interpolate(
        predicted_depth.unsqueeze(1),
        size=image.size[::-1],
        mode="bicubic",
        align_corners=False,
    )

    # Index explicitly instead of squeeze() so a 1-px dimension is never dropped.
    depth_map = prediction[0, 0].cpu().numpy()
    depth_min = float(depth_map.min())
    depth_range = float(depth_map.max()) - depth_min
    if depth_range > 0:
        depth_map = (depth_map - depth_min) / depth_range
    else:
        # A perfectly flat prediction would divide by zero and fill the map with
        # NaNs; treat it as uniform depth instead.
        depth_map = np.zeros_like(depth_map)
    return depth_map, image


def separate_layers(depth_map, image):
    """Split the scene into foreground/background masks by thresholding depth.

    Otsu's method chooses the threshold on the 8-bit min-max normalized depth
    map, so no manual tuning is needed. Pixels above the threshold (nearer,
    given DPT's inverse-depth output) become foreground.

    Args:
        depth_map: 2-D array of relative depth values.
        image: Unused; kept for interface compatibility with callers.

    Returns:
        ``(foreground_mask, background_mask)`` as uint8 arrays of 0/255; the
        background mask is the bitwise complement of the foreground mask.
    """
    depth_np = np.asarray(depth_map, dtype=np.float32)
    depth_u8 = cv.normalize(depth_np, None, 0, 255, cv.NORM_MINMAX).astype(np.uint8)

    _, foreground_mask = cv.threshold(depth_u8, 0, 255, cv.THRESH_BINARY + cv.THRESH_OTSU)
    background_mask = cv.bitwise_not(foreground_mask)
    return foreground_mask, background_mask


def inpaint_background(image_np, foreground_mask, background_mask):
    """Reconstruct the background behind the foreground and cut the foreground out.

    The foreground region is blanked and refilled from the surrounding
    background: first in concentric rings peeled from the outside in (each pass
    can sample pixels filled by the previous one), then by a full-mask
    Navier-Stokes inpaint between two edge-preserving bilateral smoothing passes.

    Args:
        image_np: ``H x W x C`` uint8 image (``C >= 3``; only the first three
            channels are used; ``cv.inpaint`` requires 8-bit input).
        foreground_mask: ``H x W`` mask; values > 128 mark foreground.
        background_mask: Unused; kept for interface compatibility with callers.

    Returns:
        ``(inpainted_bg, fg_rgba, foreground_mask)``: the filled ``H x W x 3``
        background, the ``H x W x 4`` foreground cut-out whose alpha channel is
        the Gaussian-feathered mask, and the binarized 0/255 foreground mask.
    """
    foreground_mask = (foreground_mask > 128).astype(np.uint8) * 255

    # Contiguous RGB view; it is only ever copied, never written to.
    rgb = np.ascontiguousarray(image_np[:, :, :3])

    # Blank the foreground so subject pixels cannot bleed into the reconstruction.
    inpainted_bg = rgb.copy()
    inpainted_bg[foreground_mask == 255] = 0

    ring_kernel = cv.getStructuringElement(cv.MORPH_ELLIPSE, (7, 7))
    grow_kernel = cv.getStructuringElement(cv.MORPH_ELLIPSE, (3, 3))
    # Grow the hole slightly so fringe pixels around the subject are rebuilt too.
    hole = cv.dilate(foreground_mask, grow_kernel, iterations=2)

    # Larger holes get thicker rings: erosion depth scales with hole area and
    # shrinks on later passes.
    max_erode = max(1, np.count_nonzero(hole) // 5000)
    passes = 12
    for i in range(passes):
        erode_steps = max(1, max_erode // (i + 1))
        inner = cv.erode(hole, ring_kernel, iterations=erode_steps)
        ring = (cv.subtract(hole, inner) > 0).astype(np.uint8) * 255
        if not np.any(ring):
            break
        # Telea (fast marching) for the first half of the passes, Navier-Stokes after.
        method = cv.INPAINT_TELEA if i < passes // 2 else cv.INPAINT_NS
        # NOTE(review): pixels inside `inner` are still black but unmasked, so
        # cv.inpaint presumably treats them as valid samples along the ring's
        # inner edge; the full-mask pass below refills the original hole anyway.
        inpainted_bg = cv.inpaint(inpainted_bg, ring, 5, method)
        hole = inner

    # Final refinement: smooth, refill the original foreground hole, smooth again.
    inpainted_bg = cv.bilateralFilter(inpainted_bg, d=9, sigmaColor=75, sigmaSpace=75)
    inpainted_bg = cv.inpaint(inpainted_bg, foreground_mask, 5, cv.INPAINT_NS)
    inpainted_bg = cv.bilateralFilter(inpainted_bg, d=9, sigmaColor=75, sigmaSpace=75)

    # Foreground cut-out with a feathered alpha edge for smooth compositing.
    foreground_rgb = rgb.copy()
    foreground_rgb[foreground_mask == 0] = 0
    alpha_blurred = cv.GaussianBlur(foreground_mask / 255.0, (9, 9), 0)
    fg_rgba = np.dstack((foreground_rgb, (alpha_blurred * 255).astype(np.uint8)))
    return inpainted_bg, fg_rgba, foreground_mask


def _dof_kernel_size(aperture):
    """Map an f-number to an odd Gaussian kernel size for the background blur."""
    # As with a real lens, blur scales with 1 / f-number, matching the UI's
    # "lower = more blur": f/1.4 -> 29 px, f/2.8 (UI default) -> 15 px,
    # f/5.6 -> 7 px. 39.2 = 2.8 * 14.
    size = max(1, int(round(39.2 / max(float(aperture), 1e-3))))
    return size if size % 2 == 1 else size + 1


def create_parallax_animation(inpainted_bg, fg_rgba, depth_map, motion_strength, parallax_strength,
                              aperture, speed_multiplier, zoom_base, progress=gr.Progress(),
                              num_frames=60, fps=20):
    """Render a looping parallax GIF with depth-of-field blur on the background.

    Each frame zooms and pans both layers horizontally along a cosine sweep. The
    foreground pans farther than the background, which produces the parallax.
    The background is blended per pixel between a sharp and a Gaussian-blurred
    copy using the depth map, so far regions blur and near regions stay sharp.

    Args:
        inpainted_bg: ``H x W x 3`` uint8 reconstructed background.
        fg_rgba: ``H x W x 4`` uint8 foreground cut-out with alpha.
        depth_map: 2-D depth map in ``[0, 1]`` (larger = nearer); resized to ``H x W``.
        motion_strength: Scales the horizontal pan of both layers.
        parallax_strength: Extra scale on the foreground pan only.
        aperture: f-number; lower values blur the distant background more.
        speed_multiplier: Playback speed; divides the per-frame duration.
        zoom_base: Zoom intensity. 0 disables zoom, and with it the pan, which
            needs the zoom margin to move within.
        progress: Gradio progress tracker.
        num_frames: Frames in one loop (clamped to at least 1).
        fps: Base frame rate before ``speed_multiplier`` is applied.

    Returns:
        Path of the GIF, written to a temporary file that is not auto-deleted.
    """
    num_frames = max(1, int(num_frames))
    zoom_scale_center = 1.0 + (zoom_base * 0.15)
    zoom_scale_sides = 1.0 + (zoom_base * 0.125)
    h, w = inpainted_bg.shape[:2]

    progress(0.1, desc="Preparing layers...")

    # Scale the layers once to the largest zoom; each frame only downsizes from here.
    max_h, max_w = int(h * zoom_scale_center), int(w * zoom_scale_center)
    zoomed_fg_max = cv.resize(fg_rgba, (max_w, max_h), interpolation=cv.INTER_LINEAR)
    zoomed_bg_max = cv.resize(inpainted_bg, (max_w, max_h), interpolation=cv.INTER_LINEAR)

    dof_kernel = _dof_kernel_size(aperture)
    zoomed_bg_blurred_max = cv.GaussianBlur(zoomed_bg_max, (dof_kernel, dof_kernel), 0)

    # Depth is larger for nearer pixels, so 1 - depth is large for far pixels:
    # it is the per-pixel weight of the blurred background.
    blur_weight = 1 - cv.resize(depth_map, (w, h), interpolation=cv.INTER_LINEAR)
    blur_weight_3c = np.repeat(blur_weight[:, :, None], 3, axis=2)

    erode_kernel = np.ones((5, 5), np.uint8)
    frames = []
    progress(0.2, desc="Generating frames...")
    for i in range(num_frames):
        # t runs over [0, 1) so the final frame stops one step short of repeating
        # frame 0; otherwise the looping GIF shows a duplicated frame at the seam.
        t = i / num_frames
        # Sweep in [-1, 1]: -1 at t=0, +1 at t=0.5.
        oscillation = -math.cos(2 * math.pi * t)
        # Zoom peaks mid-sweep (oscillation 0) and relaxes toward the extremes.
        zoom_factor = zoom_scale_center - abs(oscillation) * (zoom_scale_center - zoom_scale_sides)
        current_h, current_w = int(h * zoom_factor), int(w * zoom_factor)

        zoomed_fg = cv.resize(zoomed_fg_max, (current_w, current_h), interpolation=cv.INTER_LINEAR)
        zoomed_bg = cv.resize(zoomed_bg_max, (current_w, current_h), interpolation=cv.INTER_LINEAR)
        zoomed_bg_blurred = cv.resize(zoomed_bg_blurred_max, (current_w, current_h),
                                      interpolation=cv.INTER_LINEAR)

        # Centered vertical crop; crop_y1 + h yields exactly h rows even for odd h.
        crop_y1 = current_h // 2 - h // 2
        crop_y2 = crop_y1 + h

        # Horizontal pan within the zoom margin; the foreground moves 2x as far
        # (times parallax_strength) as the background.
        slack_x = current_w - w
        shift_bg = oscillation * slack_x * 0.10 * motion_strength
        shift_fg = oscillation * slack_x * 0.20 * motion_strength * parallax_strength
        left = current_w // 2 - w // 2
        crop_bg1 = max(0, min(slack_x, int(round(left + shift_bg))))
        crop_fg1 = max(0, min(slack_x, int(round(left + shift_fg))))
        crop_bg2 = crop_bg1 + w
        crop_fg2 = crop_fg1 + w

        fg_crop = zoomed_fg[crop_y1:crop_y2, crop_fg1:crop_fg2]
        bg_crop = zoomed_bg[crop_y1:crop_y2, crop_bg1:crop_bg2]
        bg_crop_blurred = zoomed_bg_blurred[crop_y1:crop_y2, crop_bg1:crop_bg2]

        # Safety net for zoom < 1 (layer smaller than the frame), where the crop
        # cannot reach the full size.
        if fg_crop.shape[:2] != (h, w):
            fg_crop = cv.resize(fg_crop, (w, h), interpolation=cv.INTER_LINEAR)
        if bg_crop.shape[:2] != (h, w):
            bg_crop = cv.resize(bg_crop, (w, h), interpolation=cv.INTER_LINEAR)
            bg_crop_blurred = cv.resize(bg_crop_blurred, (w, h), interpolation=cv.INTER_LINEAR)

        # Depth-of-field: far pixels take the blurred background, near ones stay sharp.
        bg_composite = ((1 - blur_weight_3c) * bg_crop
                        + blur_weight_3c * bg_crop_blurred).astype(np.uint8)

        # Shrink, then soften, the foreground alpha to hide halo pixels at the cut edge.
        alpha_u8 = np.ascontiguousarray(fg_crop[:, :, 3])
        alpha_eroded = cv.erode(alpha_u8, erode_kernel, iterations=1)
        alpha_smooth = (cv.GaussianBlur(alpha_eroded, (5, 5), 0) / 255.0)[:, :, np.newaxis]

        fg_rgb = fg_crop[:, :, :3].astype(np.float64)
        composite = (fg_rgb * alpha_smooth + bg_composite * (1 - alpha_smooth)).astype(np.uint8)
        frames.append(composite)

        if i % 10 == 0:
            progress(0.2 + (i / num_frames) * 0.7, desc=f"Rendering frame {i}/{num_frames}...")

    progress(0.95, desc="Saving animation...")

    # Reserve a unique path, then close the handle immediately (keeping the
    # file). This avoids leaking it, and lets imageio reopen the path on
    # platforms that forbid a second open handle (e.g. Windows).
    with tempfile.NamedTemporaryFile(delete=False, suffix=".gif") as tmp:
        gif_path = tmp.name
    # NOTE(review): `duration` is per-frame milliseconds for imageio's
    # Pillow-based GIF writer (legacy v2 writers took seconds); confirm against
    # the pinned imageio version.
    imageio.mimsave(gif_path, frames, duration=1000 / fps / speed_multiplier, loop=0)

    progress(1.0, desc="Complete!")
    return gif_path


def process_image(image, motion, parallax, aperture, speed, zoom, progress=gr.Progress()):
    """Run the full pipeline: depth -> layers -> inpainting -> parallax GIF.

    Args:
        image: Uploaded image (``PIL.Image`` or array), or ``None``.
        motion, parallax, aperture, speed, zoom: Slider values, forwarded to
            ``create_parallax_animation``.
        progress: Gradio progress tracker.

    Returns:
        ``(gif_path, gif_path)`` for the preview and download outputs, or
        ``(None, None)`` when no image was provided.
    """
    if image is None:
        return None, None

    progress(0, desc="Loading image...")
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # Normalize every input (RGBA, grayscale, palette, ...) to 3-channel RGB so
    # downstream slicing and OpenCV calls always see H x W x 3.
    image = image.convert("RGB")

    progress(0.05, desc="Extracting depth map...")
    depth_map, processed_image = get_depth_map(image)

    progress(0.3, desc="Separating layers...")
    image_np = np.array(processed_image)
    foreground_mask, background_mask = separate_layers(depth_map, processed_image)

    progress(0.4, desc="Reconstructing background...")
    inpainted_bg, fg_rgba, _ = inpaint_background(image_np, foreground_mask, background_mask)

    progress(0.5, desc="Creating parallax animation...")
    gif_path = create_parallax_animation(
        inpainted_bg, fg_rgba, depth_map,
        motion, parallax, aperture, speed, zoom,
        progress=progress,
    )
    return gif_path, gif_path


# Create Gradio interface
with gr.Blocks(title="🧪 The Parallax Lab", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🧪 The Parallax Lab

    Upload an image to create a stunning depth-based parallax animation with bokeh effects!

    **How it works:**
    1. AI extracts depth information from your image
    2. Separates foreground and background layers
    3. Creates smooth parallax motion with depth-of-field blur
    """)

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(
                type="pil",
                label="Upload Your Image",
                # Preload the example only if it exists, so the app still starts without it.
                value=DEFAULT_EXAMPLE_IMAGE if os.path.exists(DEFAULT_EXAMPLE_IMAGE) else None,
            )

            gr.Markdown("### Effect Controls")
            motion = gr.Slider(0.5, 2, value=1, step=0.1, label="Motion Strength",
                               info="How much the camera moves")
            parallax = gr.Slider(0.5, 2, value=1, step=0.1, label="Parallax Strength",
                                 info="Separation between foreground/background")
            aperture = gr.Slider(1.4, 5.6, value=2.8, step=0.2, label="Aperture Size",
                                 info="Blur intensity (lower = more blur)")
            speed = gr.Slider(0.5, 2, value=1, step=0.1, label="Animation Speed",
                              info="Playback speed multiplier")
            zoom = gr.Slider(0.5, 2, value=1, step=0.1, label="Zoom Intensity",
                             info="How much to zoom in/out")

            start_btn = gr.Button("✨ Create Parallax Animation", variant="primary", size="lg")

        with gr.Column(scale=1):
            output_gif = gr.Image(label="🎬 Your Parallax Animation", type="filepath", format="gif")
            download_file = gr.File(label="📥 Download GIF", file_types=[".gif"])

            gr.Markdown("""
            ### Tips for Best Results:
            - Use images with clear foreground subjects
            - Portraits and objects work especially well
            - Higher motion/parallax = more dramatic effect
            - Lower aperture = stronger bokeh blur
            """)

    start_btn.click(
        fn=process_image,
        inputs=[input_image, motion, parallax, aperture, speed, zoom],
        outputs=[output_gif, download_file],
    )

    gr.Markdown("""
    ---
    **Note:** Processing may take 1-2 minutes depending on image size and hardware.
    """)


if __name__ == "__main__":
    demo.launch()