# 2D-to-Stereo-3D / app.py
# Hugging Face Space source (commit f98a0fe):
# "Add Divergence (3D Strength) and Convergence (Focus Point) sliders"
import gradio as gr
import torch
import numpy as np
import cv2
from PIL import Image
from transformers import DPTForDepthEstimation, DPTImageProcessor
from gradio_client import Client, handle_file
import tempfile
import os
# === DEVICE ===
# Prefer the GPU when one is present; models and tensors are moved here explicitly.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# === DEPTH MODEL ===
def load_depth_model():
    """Load the Intel MiDaS hybrid DPT depth model plus its preprocessor.

    Returns:
        (model, processor): the DPT model moved onto the module-level
        ``device``, and the matching DPTImageProcessor (the modern
        replacement for the deprecated FeatureExtractor classes).
    """
    checkpoint = "Intel/dpt-hybrid-midas"
    processor = DPTImageProcessor.from_pretrained(checkpoint)
    model = DPTForDepthEstimation.from_pretrained(checkpoint).to(device)
    return model, processor
@torch.no_grad()
def estimate_depth(image_pil, model, processor):
    """Predict a per-pixel depth map for *image_pil*.

    Args:
        image_pil: input PIL image of any size (the processor resizes
            internally to the model's expected resolution).
        model: a DPTForDepthEstimation already placed on ``device``.
        processor: the matching DPTImageProcessor.

    Returns:
        2D numpy array at the ORIGINAL image resolution, min-max normalized
        to [0, 1]; a constant (zero-range) prediction is returned as-is.
    """
    width, height = image_pil.size
    batch = processor(images=image_pil, return_tensors="pt").to(device)
    predicted = model(**batch).predicted_depth
    # Upsample the model-resolution prediction back to the source resolution.
    resized = torch.nn.functional.interpolate(
        predicted.unsqueeze(1),
        size=(height, width),  # torch expects (H, W)
        mode="bicubic",
        align_corners=False,
    )
    depth = resized.squeeze().detach().cpu().numpy()
    # Min-max normalization; guard against division by zero on flat maps.
    lo, hi = depth.min(), depth.max()
    if hi - lo > 0:
        return (depth - lo) / (hi - lo)
    return depth
def generate_right_and_mask(image, shift_map):
    """Warp *image* horizontally per-pixel to synthesize the right-eye view.

    Vectorized shift operation.

    Args:
        image: HxWxC (or HxW) numpy array.
        shift_map: HxW array; positive values shift pixels left, negative
            right (fractional shifts are truncated toward zero).

    Returns:
        (right, mask): the warped image, and a uint8 mask in which 255 marks
        disocclusion holes needing inpainting and 0 marks written pixels.
    """
    h, w = image.shape[:2]
    cols, rows = np.meshgrid(np.arange(w), np.arange(h))
    # Destination column for each source pixel (shift left for the right eye).
    dest_cols = cols - shift_map.astype(int)

    right = np.zeros_like(image)
    # Start fully "hole" (255) and clear to 0 wherever a pixel lands.
    mask = np.full((h, w), 255, dtype=np.uint8)

    # Keep only pixels whose destination stays inside the frame.
    in_bounds = (dest_cols >= 0) & (dest_cols < w)
    ys = rows[in_bounds]
    xs_dst = dest_cols[in_bounds]
    xs_src = cols[in_bounds]

    # Later writes overwrite earlier ones — a naive but adequate occlusion rule.
    right[ys, xs_dst] = image[ys, xs_src]
    mask[ys, xs_dst] = 0
    return right, mask
def make_anaglyph(left, right):
    """Compose a Red-Cyan anaglyph from a stereo pair.

    The left image supplies the Red channel; the right image supplies the
    Green and Blue channels.

    Args:
        left, right: RGB PIL images (or HxWx3 arrays) of identical size.

    Returns:
        A PIL Image of the combined anaglyph.
    """
    left_px = np.array(left)
    right_px = np.array(right)
    merged = np.zeros_like(left_px)
    merged[..., 0] = left_px[..., 0]    # Red from the left eye
    merged[..., 1] = right_px[..., 1]   # Green from the right eye
    merged[..., 2] = right_px[..., 2]   # Blue from the right eye
    return Image.fromarray(merged)
# === LAMA INPAINTING (Via Gradio Client) ===
# Note: You need a valid Space that accepts image + mask.
# Connect once at import time. On any failure the client is set to None so
# run_lama_inpainting can degrade gracefully (returns the unfilled image).
try:
    lama_client = Client("asif-k/LaMa-Inpainting")
except Exception as e:
    print(f"Could not connect to external LaMa client: {e}")
    lama_client = None
def run_lama_inpainting(image_bgr, mask):
    """Fill disocclusion holes of a warped frame via the external LaMa Space.

    Args:
        image_bgr: HxWx3 uint8 image in OpenCV BGR order, containing holes.
        mask: HxW uint8 mask where 255 marks the pixels to inpaint.

    Returns:
        The inpainted image in BGR order; if the client is unavailable or the
        remote call fails, the input is returned unchanged (holes included).
    """
    if lama_client is None:
        print("LaMa client unavailable, returning unfilled image.")
        return image_bgr
    # The Gradio client needs files on disk; convert OpenCV BGR to RGB for PIL.
    img_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f_img, \
            tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f_mask:
        img_path, mask_path = f_img.name, f_mask.name
    try:
        # Fix: saves now happen inside try/finally, so the temp files cannot
        # leak if PIL raises while writing them.
        Image.fromarray(img_rgb).save(img_path)
        Image.fromarray(mask).save(mask_path)
        # Predict using the external space
        result_path = lama_client.predict(
            image=handle_file(img_path),
            mask=handle_file(mask_path),
            api_name="/predict"
        )
        # Result is a filepath. Fix: close the PIL handle promptly instead of
        # leaking the file object returned by Image.open.
        with Image.open(result_path) as res_img:
            res_arr = np.array(res_img)
        return cv2.cvtColor(res_arr, cv2.COLOR_RGB2BGR)
    except Exception as e:
        print(f"Inpainting failed: {e}")
        return image_bgr  # Return original with holes if fail
    finally:
        # Cleanup both temp files regardless of success or failure.
        os.remove(img_path)
        os.remove(mask_path)
# === APP LOGIC ===
# Load the depth model once at import time and share it across all requests.
depth_model, depth_processor = load_depth_model()
def stereo_pipeline(image_pil, divergence, convergence):
    """Turn a single 2D image into a side-by-side stereo pair and an anaglyph.

    Args:
        image_pil: input PIL image, or None (yields (None, None)).
        divergence: maximum pixel separation — overall 3D strength.
        convergence: depth value (0.0-1.0) that stays at screen depth; nearer
            content pops out of the screen, farther content recedes into it.

    Returns:
        (side_by_side, anaglyph) as PIL images.
    """
    if image_pil is None:
        return None, None

    bgr_frame = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)

    # 1. Estimate depth (0.0 = far ... 1.0 = near).
    depth = estimate_depth(image_pil, depth_model, depth_processor)

    # 2. Per-pixel shift map:
    #    positive (leftwards)  -> pops out of the screen (near objects)
    #    negative (rightwards) -> goes into the screen (far objects)
    #    Pixels exactly at `convergence` depth do not move.
    shift = (depth - convergence) * divergence

    # 3. Warp to the right-eye view, recording disocclusion holes.
    right_raw, holes = generate_right_and_mask(bgr_frame, shift)

    # 4. Fill the holes (mask value 255 marks them).
    right_bgr = run_lama_inpainting(right_raw, holes)

    left = image_pil
    right = Image.fromarray(cv2.cvtColor(right_bgr, cv2.COLOR_BGR2RGB))

    # === Combine into Side-by-Side === (left eye left, right eye right)
    w, h = left.size
    sbs = Image.new('RGB', (w * 2, h))
    sbs.paste(left, (0, 0))
    sbs.paste(right, (w, 0))

    # === Create Anaglyph ===
    return sbs, make_anaglyph(left, right)
# === GRADIO UI ===
with gr.Blocks(title="2D to 3D Stereo") as demo:
    gr.Markdown("## 2D to 3D Stereo Generator")
    gr.Markdown("Generates a side-by-side stereo pair and anaglyph using Depth Estimation and LaMa Inpainting.")
    with gr.Row():
        # Left column: input image plus the two 3D tuning controls.
        with gr.Column(scale=1):
            source_image = gr.Image(type="pil", label="Input Image", height=480)
            # === Controls ===
            with gr.Group():
                gr.Markdown("### 3D Controls")
                strength_ctrl = gr.Slider(
                    minimum=0, maximum=100, value=30, step=1,
                    label="3D Strength (Divergence)",
                    info="Max pixel separation. Higher = Deeper 3D effect."
                )
                focus_ctrl = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.1, step=0.05,
                    label="Focus Plane (Convergence)",
                    info="0.0 = Background at screen depth. 0.5 = Mid-range at screen. 1.0 = Foreground at screen."
                )
            generate_btn = gr.Button("Generate 3D", variant="primary")
        # Right column: the red/cyan preview.
        with gr.Column(scale=1):
            anaglyph_view = gr.Image(label="Anaglyph (Red/Cyan)", height=480)
    with gr.Row():
        stereo_view = gr.Image(label="Side-by-Side Stereo Pair", height=400)
    # Wire the button: (image, divergence, convergence) -> (SBS, anaglyph).
    generate_btn.click(
        fn=stereo_pipeline,
        inputs=[source_image, strength_ctrl, focus_ctrl],
        outputs=[stereo_view, anaglyph_view]
    )

if __name__ == "__main__":
    demo.launch()