# 2D-to-Stereo-3D / app.py
import gradio as gr
import torch
import torch.nn as nn
import numpy as np
import cv2
from PIL import Image
from transformers import AutoModelForDepthEstimation, AutoImageProcessor
from huggingface_hub import hf_hub_download
# === DEVICE ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on device: {device}")
# ==============================================================================
# 1. SAFE & FAST FORWARD WARPER USING grid_sample (NO MORE BLACK IMAGES!)
# ==============================================================================
class SafeForwardWarp(nn.Module):
    def forward(self, img, flow):
        """
        img:  [B, C, H, W] in [0, 1]
        flow: [B, H, W, 2] with flow[..., 0] = delta_x (positive = right)
              and flow[..., 1] = delta_y. Each output pixel samples the
              source at (x + dx, y + dy) via bilinear grid_sample.
        """
        B, C, H, W = img.shape
        # Build the per-pixel sampling grid; note the (H, W) argument order
        # with indexing='ij' so both grids come out [H, W], matching flow
        grid_y, grid_x = torch.meshgrid(
            torch.arange(H, device=img.device),
            torch.arange(W, device=img.device),
            indexing='ij'
        )
        grid_x = grid_x.float().unsqueeze(0).expand(B, -1, -1)  # [B, H, W]
        grid_y = grid_y.float().unsqueeze(0).expand(B, -1, -1)
        dest_x = grid_x + flow[..., 0]  # sample the source at x + dx
        dest_y = grid_y + flow[..., 1]
        # Normalize to [-1, 1] for grid_sample
        norm_x = 2.0 * dest_x / (W - 1) - 1.0
        norm_y = 2.0 * dest_y / (H - 1) - 1.0
        grid = torch.stack((norm_x, norm_y), dim=-1)  # [B, H, W, 2]
        grid = grid.clamp(-1, 1)
        warped = torch.nn.functional.grid_sample(
            img,
            grid,
            mode='bilinear',
            padding_mode='zeros',
            align_corners=True
        )
        return warped
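# Minimal sanity check for SafeForwardWarp (a sketch; the helper below is
# hypothetical and never called by the app). With zero flow the sampling grid
# is the identity, so the warp should return the input image unchanged.
def _check_warp_identity():
    warper = SafeForwardWarp()
    img = torch.rand(1, 3, 6, 8)          # non-square to catch H/W mix-ups
    zero_flow = torch.zeros(1, 6, 8, 2)
    out = warper(img, zero_flow)
    assert torch.allclose(out, img, atol=1e-5), "identity warp should be lossless"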
# ==============================================================================
# 2. STEREO WARPER - improved weighting + safer dilation
# ==============================================================================
class ForwardWarpStereo(nn.Module):
    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.warp = SafeForwardWarp()

    def forward(self, img, shift, disp_for_weights):
        # shift: [B, H, W] per-pixel horizontal disparity in pixels for the right-eye view
        flow_x = -shift  # negate: grid_sample reads the source at x - shift
        flow_y = torch.zeros_like(flow_x)
        flow = torch.stack((flow_x, flow_y), dim=-1)  # [B, H, W, 2]
        # Inverse-disparity weighting, renormalized after warping; the ratio
        # cancels where weights are locally constant and softens depth edges
        weights = 1.0 / (disp_for_weights + 0.1)
        weights = weights / (weights.max() + 1e-8)
        weighted_img = img * weights.unsqueeze(1)
        warped_img = self.warp(weighted_img, flow)
        warped_weights = self.warp(weights.unsqueeze(1), flow)
        # Avoid division by zero
        warped_weights = torch.clamp(warped_weights, min=self.eps)
        result = warped_img / warped_weights
        # Occlusion mask via occupancy count
        ones = torch.ones_like(img[:, :1])
        occupancy = self.warp(ones, flow)
        occlusion = (occupancy < self.eps).float()
        # Dilate the occlusion mask, but keep the dilation off strong
        # foreground (top-decile disparity) pixels to preserve their edges
        with torch.no_grad():
            fg_thresh = torch.quantile(disp_for_weights, 0.90)
            fg_mask = (disp_for_weights > fg_thresh).float().unsqueeze(1)  # [B, 1, H, W]
            k = 9
            dilated = torch.nn.functional.conv2d(
                occlusion,
                torch.ones(1, 1, k, k, device=occlusion.device),
                padding=k // 2
            ) > 0.5
            safe_dilation = dilated.float() * (1 - fg_mask)
            occlusion = torch.clamp(occlusion + safe_dilation, 0, 1)
        return result, occlusion
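# Hedged usage sketch for ForwardWarpStereo (hypothetical helper, not run by
# the app): a uniform shift should reproduce the input up to a horizontal
# translation, since the constant weights cancel in the normalization ratio.
def _demo_constant_shift():
    warper = ForwardWarpStereo()
    img = torch.rand(1, 3, 32, 32)
    shift = torch.full((1, 32, 32), 4.0)   # uniform disparity, in pixels
    disp = torch.full((1, 32, 32), 0.5)    # dummy disparity for the weights
    right, occlusion = warper(img, shift, disp)
    # Interior columns match the source shifted by 4 px; borders are smeared
    # because the sampling grid is clamped to the image
    assert torch.allclose(right[..., 8:24], img[..., 4:20], atol=1e-4)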
# ==============================================================================
# 3. MODELS & HELPERS
# ==============================================================================
def load_models():
    print("Loading Depth Anything V2 Large...")
    depth_model = AutoModelForDepthEstimation.from_pretrained(
        "depth-anything/Depth-Anything-V2-Large-hf"
    ).to(device).eval()
    depth_processor = AutoImageProcessor.from_pretrained(
        "depth-anything/Depth-Anything-V2-Large-hf"
    )
    print("Loading LaMa Inpainting Model...")
    try:
        model_path = hf_hub_download(repo_id="fashn-ai/LaMa", filename="big-lama.pt")
        lama_model = torch.jit.load(model_path, map_location=device).eval()
    except Exception as e:
        print(f"LaMa load failed: {e}")
        lama_model = None
    stereo_warper = ForwardWarpStereo().to(device)
    return depth_model, depth_processor, lama_model, stereo_warper

depth_model, depth_processor, lama_model, stereo_warper = load_models()
@torch.no_grad()
def estimate_depth(image_pil):
    original_size = image_pil.size  # PIL size is (W, H)
    inputs = depth_processor(images=image_pil, return_tensors="pt").to(device)
    outputs = depth_model(**inputs)
    depth = outputs.predicted_depth
    depth = torch.nn.functional.interpolate(
        depth.unsqueeze(1),
        size=(original_size[1], original_size[0]),
        mode="bicubic",
        align_corners=False,
    ).squeeze(0).squeeze(0)
    # Normalize to [0, 1]
    d_min, d_max = depth.min(), depth.max()
    if d_max > d_min:
        depth = (depth - d_min) / (d_max - d_min)
    return depth
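# Hypothetical smoke test (not part of the app flow): estimate_depth returns
# a [H, W] tensor in [0, 1]; the pipeline below treats larger values as
# nearer content when it converts depth to disparity.
def _demo_depth():
    img = Image.fromarray(np.random.randint(0, 255, (240, 320, 3), dtype=np.uint8))
    d = estimate_depth(img)
    assert d.shape == (240, 320) and 0.0 <= d.min() <= d.max() <= 1.0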
@torch.no_grad()
def run_lama(image_bgr, mask_float):
    if lama_model is None:
        return image_bgr
    mask_uint8 = (mask_float * 255).astype(np.uint8)
    kernel = np.ones((7, 7), np.uint8)
    mask_dilated = cv2.dilate(mask_uint8, kernel, iterations=2)
    # LaMa expects spatial dims divisible by 8
    h, w = image_bgr.shape[:2]
    new_h = (h // 8) * 8
    new_w = (w // 8) * 8
    img_resized = cv2.resize(image_bgr, (new_w, new_h))
    mask_resized = cv2.resize(mask_dilated, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
    img_t = torch.from_numpy(img_resized).float().permute(2, 0, 1).unsqueeze(0) / 255.0
    img_t = img_t[:, [2, 1, 0]].to(device)  # BGR -> RGB
    mask_t = torch.from_numpy(mask_resized).float().unsqueeze(0).unsqueeze(0) / 255.0
    mask_t = (mask_t > 0.5).float().to(device)
    img_t = img_t * (1 - mask_t)  # zero out the hole before inpainting
    inpainted = lama_model(img_t, mask_t)
    result = (inpainted[0].permute(1, 2, 0).cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
    result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
    if (new_h, new_w) != (h, w):
        result = cv2.resize(result, (w, h))
    return result
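# Hypothetical smoke test for run_lama (never invoked by the app): inpaint a
# square hole in a synthetic BGR image; shape and dtype are preserved whether
# or not the LaMa checkpoint loaded.
def _demo_inpaint():
    img = np.random.randint(0, 255, (128, 160, 3), dtype=np.uint8)
    mask = np.zeros((128, 160), dtype=np.float32)
    mask[40:80, 60:100] = 1.0
    out = run_lama(img, mask)
    assert out.shape == img.shape and out.dtype == np.uint8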
def make_anaglyph(left, right):
    l = np.array(left)
    r = np.array(right)
    ana = np.zeros_like(l)
    ana[:, :, 0] = l[:, :, 0]  # Red   ← Left
    ana[:, :, 1] = r[:, :, 1]  # Green ← Right
    ana[:, :, 2] = r[:, :, 2]  # Blue  ← Right
    return Image.fromarray(ana)
# ==============================================================================
# 4. MAIN PIPELINE
# ==============================================================================
@torch.no_grad()
def stereo_pipeline(image_pil, divergence_percent=3.2, convergence_plane=0.08):
    if image_pil is None:
        return None, None, None, None
    w, h = image_pil.size
    if w > 1920:
        ratio = 1920 / w
        image_pil = image_pil.resize((int(w * ratio), int(h * ratio)), Image.LANCZOS)
        w, h = image_pil.size
    # 1. Depth
    depth = estimate_depth(image_pil)  # [H, W] in [0, 1]
    depth_vis = Image.fromarray((depth.cpu().numpy() * 255).astype(np.uint8))
    # 2. Disparity (squaring exaggerates near-depth separation for stronger volume)
    disp_raw = depth ** 2
    disp_clipped = torch.clamp(disp_raw, max=torch.quantile(disp_raw, 0.995))
    # 3. Shift
    max_shift = w * (divergence_percent / 100.0)
    shift_raw = disp_clipped * max_shift
    shift_min, shift_max = shift_raw.min(), shift_raw.max()
    convergence_offset = shift_min + convergence_plane * (shift_max - shift_min)
    final_shift = shift_raw - convergence_offset
    print(f"Final shift range: {final_shift.min():.1f} -> {final_shift.max():.1f} px")
    # 4. Warp right eye
    img_tensor = torch.from_numpy(np.array(image_pil)).float().to(device) / 255.0
    img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0)  # [1, 3, H, W]
    shift_tensor = final_shift.unsqueeze(0).to(device)     # [1, H, W]
    disp_tensor = disp_clipped.unsqueeze(0).to(device)
    right_tensor, occlusion_mask = stereo_warper(img_tensor, shift_tensor, disp_tensor)
    # 5. To numpy
    right_np = (right_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy().clip(0, 1) * 255).astype(np.uint8)
    right_bgr = cv2.cvtColor(right_np, cv2.COLOR_RGB2BGR)
    mask_np = occlusion_mask.squeeze().cpu().numpy()  # [H, W] for OpenCV
    # 6. Inpaint occlusions
    right_filled_bgr = run_lama(right_bgr, mask_np)
    right_filled = Image.fromarray(cv2.cvtColor(right_filled_bgr, cv2.COLOR_BGR2RGB))
    # 7. Outputs
    mask_vis = Image.fromarray((mask_np * 255).astype(np.uint8))
    sbs = Image.new('RGB', (w * 2, h))
    sbs.paste(image_pil, (0, 0))
    sbs.paste(right_filled, (w, 0))
    anaglyph = make_anaglyph(image_pil, right_filled)
    return sbs, anaglyph, depth_vis, mask_vis
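# Hedged example of driving the pipeline headlessly (hypothetical paths; the
# Gradio UI below is the primary entry point):
#
#     img = Image.open("input.jpg").convert("RGB")
#     sbs, anaglyph, depth_vis, mask_vis = stereo_pipeline(img, divergence_percent=3.5)
#     sbs.save("output_sbs.png")
#     anaglyph.save("output_anaglyph.png")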
# ==============================================================================
# 5. GRADIO UI
# ==============================================================================
with gr.Blocks(title="2D → 3D Stereo - Pro & Stable") as demo:
    gr.HTML("<h1 style='text-align:center;'>2D to 3D Stereo - Pro Quality (Fixed & Stable)</h1>")
    gr.Markdown("Depth Anything V2 + Safe Forward Warping + LaMa Inpainting")
    with gr.Row():
        with gr.Column(scale=1):
            input_img = gr.Image(type="pil", label="Upload Image", height=520)
            with gr.Accordion("Settings", open=True):
                divergence = gr.Slider(0.5, 8.0, value=3.5, step=0.1, label="3D Strength (%)")
                convergence = gr.Slider(0.0, 1.0, value=0.08, step=0.01,
                                        label="Convergence Plane (0 = pop-out, 1 = deep)")
            btn = gr.Button("Generate 3D", variant="primary", size="lg")
        with gr.Column(scale=1):
            out_anaglyph = gr.Image(label="Anaglyph (Red/Cyan Glasses)", height=520)
            out_sbs = gr.Image(label="Side-by-Side (Cross-eye / Parallel)", height=300)
            with gr.Row():
                out_depth = gr.Image(label="Depth Map", height=200)
                out_mask = gr.Image(label="Occlusion Mask", height=200)
    btn.click(
        fn=stereo_pipeline,
        inputs=[input_img, divergence, convergence],
        outputs=[out_sbs, out_anaglyph, out_depth, out_mask]
    )
    gr.Markdown("**Tip:** Use Red/Cyan glasses for anaglyph • Cross-eye or parallel view for SBS")

if __name__ == "__main__":
    demo.launch(share=True)