Spaces:
Running
Running
File size: 16,009 Bytes
db1a689 ed6a23d db1a689 b4c58d3 fe2b283 db1a689 b4c58d3 66f7927 db1a689 66f7927 b4c58d3 66f7927 b4c58d3 549ff77 b4c58d3 66f7927 b4c58d3 66f7927 ed6a23d b4c58d3 66f7927 b4c58d3 66f7927 b4c58d3 b89295c b4c58d3 66f7927 b4c58d3 66f7927 b4c58d3 66f7927 b4c58d3 66f7927 b4c58d3 66f7927 b4c58d3 66f7927 b4c58d3 66f7927 b4c58d3 66f7927 b4c58d3 66f7927 b4c58d3 66f7927 b4c58d3 66f7927 b4c58d3 66f7927 b4c58d3 66f7927 b4c58d3 66f7927 8b101f9 66f7927 b4c58d3 db1a689 66f7927 b4c58d3 b89295c 66f7927 b4c58d3 be50bae b4c58d3 66f7927 b4c58d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 |
import gradio as gr
import torch
import torch.nn as nn
import numpy as np
import cv2
from PIL import Image
from torch.autograd import Function
from transformers import AutoModelForDepthEstimation, AutoImageProcessor
from huggingface_hub import hf_hub_download
import os
# === DEVICE ===
# Pick the compute device once at import time: CUDA when it is available,
# otherwise the CPU. Every model and tensor in this file is moved to it.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Running on device: {device}")
# ==============================================================================
# 1. FORWARD WARP IMPLEMENTATION (Native PyTorch)
# ==============================================================================
class ForwardWarpFunction(Function):
    """Forward-warp ("splat") an image along a per-pixel flow field.

    Every source pixel at ``(x, y)`` is pushed to ``(x + dx, y + dy)`` in a
    zero-initialised output of the same shape as the input. Output pixels that
    receive no contribution stay zero. No gradients are produced:
    ``backward`` returns ``None`` for every input.
    """

    @staticmethod
    def forward(ctx, im0, flow, interpolation_mode_int):
        """Splat ``im0`` into a new tensor according to ``flow``.

        Args:
            ctx: Autograd context (unused).
            im0: Source tensor of shape (B, C, H, W).
            flow: Displacement tensor of shape (B, H, W, 2); the last axis
                holds (dx, dy) in pixels.
            interpolation_mode_int: 0 for bilinear splatting (contributions
                are accumulated), 1 for nearest-neighbour (values are
                written, not accumulated).

        Returns:
            A tensor with the shape, dtype and device of ``im0``.

        Raises:
            AssertionError: If the shapes or the mode do not match the
                contract above.
        """
        # Input validation
        assert len(im0.shape) == len(flow.shape) == 4
        assert interpolation_mode_int == 0 or interpolation_mode_int == 1
        assert im0.shape[0] == flow.shape[0]
        assert im0.shape[-2:] == flow.shape[1:3]
        assert flow.shape[3] == 2

        B, C, H, W = im0.shape

        # Freshly allocated (hence contiguous) output, so a flat view of it
        # can be scattered into in place.
        im1 = torch.zeros(im0.shape, device=im0.device, dtype=im0.dtype).contiguous()
        im1_flat = im1.view(-1)

        # Pixel-coordinate grid; with indexing='xy' both grids are (H, W).
        grid_x, grid_y = torch.meshgrid(
            torch.arange(W, device=im0.device, dtype=im0.dtype),
            torch.arange(H, device=im0.device, dtype=im0.dtype),
            indexing='xy'
        )
        grid_x = grid_x.unsqueeze(0).expand(B, -1, -1)
        grid_y = grid_y.unsqueeze(0).expand(B, -1, -1)

        # Destination coordinates, shape (B, H, W)
        x_dest = grid_x + flow[:, :, :, 0]
        y_dest = grid_y + flow[:, :, :, 1]

        # Flat offset of every (batch, channel) plane in the output; shared
        # by both interpolation modes.
        b_indices = torch.arange(B, device=im0.device).view(B, 1, 1, 1).expand(-1, C, H, W)
        c_indices = torch.arange(C, device=im0.device).view(1, C, 1, 1).expand(B, -1, H, W)
        base_idx = b_indices * (C * H * W) + c_indices * (H * W)

        if interpolation_mode_int == 0:  # Bilinear Splatting
            x_f = torch.floor(x_dest).long()
            y_f = torch.floor(y_dest).long()
            x_c = x_f + 1
            y_c = y_f + 1

            # Distances to the surrounding integer coordinates, kept in the
            # coordinate dtype so the weights do not silently change dtype.
            coord_dtype = x_dest.dtype
            wx_f = x_c.to(coord_dtype) - x_dest
            wx_c = x_dest - x_f.to(coord_dtype)
            wy_f = y_c.to(coord_dtype) - y_dest
            wy_c = y_dest - y_f.to(coord_dtype)

            # Per-axis validity; combined per corner below.
            x_f_ok = (x_f >= 0) & (x_f < W)
            x_c_ok = (x_c >= 0) & (x_c < W)
            y_f_ok = (y_f >= 0) & (y_f < H)
            y_c_ok = (y_c >= 0) & (y_c < H)

            def scatter_corner(y_idx, x_idx, weights, mask):
                # Out-of-bounds corners are clamped onto the border but carry
                # a zero value, so accumulating them changes nothing.
                y_safe = torch.clamp(y_idx, 0, H - 1)
                x_safe = torch.clamp(x_idx, 0, W - 1)
                flat_idx = base_idx + (y_safe * W + x_safe).unsqueeze(1)
                values = im0 * (weights * mask.to(weights.dtype)).unsqueeze(1)
                # scatter_add_ needs the source in the output dtype.
                im1_flat.scatter_add_(
                    0, flat_idx.reshape(-1), values.to(im1.dtype).reshape(-1)
                )

            scatter_corner(y_f, x_f, wx_f * wy_f, x_f_ok & y_f_ok)  # NW
            scatter_corner(y_f, x_c, wx_c * wy_f, x_c_ok & y_f_ok)  # NE
            scatter_corner(y_c, x_f, wx_f * wy_c, x_f_ok & y_c_ok)  # SW
            scatter_corner(y_c, x_c, wx_c * wy_c, x_c_ok & y_c_ok)  # SE
        else:  # Nearest Neighbor (Legacy fallback)
            x_nearest = torch.round(x_dest).long()
            y_nearest = torch.round(y_dest).long()
            valid = (x_nearest >= 0) & (x_nearest < W) & (y_nearest >= 0) & (y_nearest < H)
            valid = valid.unsqueeze(1).expand(-1, C, -1, -1)
            dest_idx = base_idx + (y_nearest * W + x_nearest).unsqueeze(1)
            # Write only the pixels that land inside the image. Clamping
            # out-of-bounds pixels to the border and writing them as zeros
            # could overwrite a valid pixel that landed on the border.
            im1_flat.scatter_(0, dest_idx[valid], im0[valid])
        return im1

    @staticmethod
    def backward(ctx, grad_output):
        """Return no gradient for any of the three forward inputs."""
        return None, None, None
class forward_warp(nn.Module):
    """Module wrapper that calls ``ForwardWarpFunction`` with a fixed mode."""

    def __init__(self, interpolation_mode="Bilinear"):
        super().__init__()
        # Only the exact string "Bilinear" selects mode 0 (bilinear
        # splatting); every other value selects mode 1 (nearest neighbour).
        if interpolation_mode == "Bilinear":
            self.interpolation_mode_int = 0
        else:
            self.interpolation_mode_int = 1

    def forward(self, im0, flow):
        """Warp ``im0`` along ``flow`` using the mode chosen at construction."""
        mode = self.interpolation_mode_int
        return ForwardWarpFunction.apply(im0, flow, mode)
# ==============================================================================
# 2. STEREO WARPER WRAPPER
# ==============================================================================
class ForwardWarpStereo(nn.Module):
    """
    Depth-weighted splatting wrapper that synthesises a horizontally shifted view.

    Colours are splatted together with a per-pixel weight that grows linearly
    with the normalised disparity, and the accumulated colour is divided by
    the accumulated weight, so higher-disparity pixels dominate wherever
    several sources land on the same target (a soft z-buffer).
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        # Lower bound for the accumulated weight and threshold for "empty".
        self.eps = eps
        self.fw = forward_warp(interpolation_mode="Bilinear")

    def forward(self, im, disp, convergence, divergence):
        """
        Warp ``im`` horizontally according to ``disp``.

        Args:
            im: Image tensor, (B, C, H, W).
            disp: Disparity tensor with a singleton channel axis, (B, 1, H, W).
            convergence: Disparity value that receives zero shift.
            divergence: Scale applied to ``disp - convergence`` to obtain the
                shift in pixels.

        Returns:
            ``(res, occlusion_mask)``: the weight-normalised warped image and
            a float mask that is 1.0 where nothing was splatted and 0.0
            elsewhere.
        """
        # Drop the channel axis so disparity lines up with (B, H, W) coordinates.
        disparity = disp.squeeze(1)

        # Horizontal flow is the negated shift (the original author's note
        # describes this as the right-eye view); vertical flow is zero.
        flow_x = -((disparity - convergence) * divergence)
        flow = torch.stack((flow_x, torch.zeros_like(flow_x)), dim=-1)  # (B, H, W, 2)

        # 1. Soft z-buffer weights: disparity divided by its maximum plus a
        #    0.05 floor, so pixels at zero disparity still contribute.
        normalised = disparity / (disparity.max() + 1e-8)
        weights_map = (normalised + 0.05).unsqueeze(1)  # (B, 1, H, W)

        # 2. Accumulate the weighted colour.
        colour_sum = self.fw(im * weights_map, flow)

        # 3. Accumulate the weights themselves.
        weight_sum = self.fw(weights_map, flow)

        # 4. Normalise; the clamp avoids dividing by zero where nothing landed.
        weight_sum.clamp_(min=self.eps)
        res = colour_sum / weight_sum

        # 5. Hole mask for inpainting: splat ones and mark targets whose
        #    accumulated occupancy is below eps.
        occupancy = self.fw(torch.ones_like(disp), flow)
        occlusion_mask = (occupancy < self.eps).float()
        return res, occlusion_mask
# ==============================================================================
# 3. APP LOGIC & MODELS
# ==============================================================================
# === LOAD MODELS ===
def load_models():
    """Load every model used by the app and return them together.

    Returns:
        ``(depth_model, depth_processor, lama_model, stereo_warper)``. The
        depth model, the LaMa model and the stereo warper are placed on the
        module-level ``device``.

    Raises:
        Exception: Whatever downloading or loading the LaMa model raised; it
            is printed first and then re-raised.
    """
    depth_repo = "depth-anything/Depth-Anything-V2-Large-hf"

    print("Loading Depth Anything V2 Large...")
    depth_model = AutoModelForDepthEstimation.from_pretrained(depth_repo).to(device)
    depth_processor = AutoImageProcessor.from_pretrained(depth_repo)

    print("Loading LaMa Inpainting Model...")
    try:
        lama_path = hf_hub_download(repo_id="fashn-ai/LaMa", filename="big-lama.pt")
        lama_model = torch.jit.load(lama_path, map_location=device)
        lama_model.eval()
    except Exception as e:
        print(f"Error loading LaMa model: {e}")
        raise e

    # Initialize the new Stereo Warper
    stereo_warper = ForwardWarpStereo().to(device)
    return depth_model, depth_processor, lama_model, stereo_warper
# Load models once at startup. This call runs at import time, so importing the
# module performs all the loading done by load_models(); the four results are
# kept as module-level globals.
depth_model, depth_processor, lama_model, stereo_warper = load_models()
# === DEPTH ESTIMATION ===
@torch.no_grad()
def estimate_depth(image_pil, model, processor):
    """Predict a depth map for ``image_pil``, normalised to [0, 1].

    Args:
        image_pil: Input image; its ``size`` attribute is read as
            (width, height).
        model: Depth model whose output exposes ``predicted_depth``.
        processor: Callable that turns the image into model inputs.

    Returns:
        A tensor at the input image's resolution, min-max normalised to
        [0, 1]. It is all zeros when the prediction is constant.
    """
    width, height = image_pil.size
    inputs = processor(images=image_pil, return_tensors="pt").to(device)
    depth = model(**inputs).predicted_depth

    # Resize the prediction back to the input resolution.
    depth = torch.nn.functional.interpolate(
        depth.unsqueeze(1),
        size=(height, width),
        mode="bicubic",
        align_corners=False,
    )
    # Remove only the channel axis added above and the batch axis. A bare
    # squeeze() would also drop the height or width axis of a 1-pixel-tall
    # or 1-pixel-wide image.
    depth = depth.squeeze(1).squeeze(0)

    depth_min, depth_max = depth.min(), depth.max()
    if depth_max - depth_min > 0:
        depth = (depth - depth_min) / (depth_max - depth_min)
    else:
        # Constant prediction: avoid dividing by zero.
        depth = torch.zeros_like(depth)
    return depth
# === DEPTH MANIPULATION ===
def erode_depth(depth_tensor, kernel_size):
    """Apply a square minimum filter (erosion) to a 2-D depth map.

    Args:
        depth_tensor: 2-D tensor (H, W).
        kernel_size: Window size in pixels. Values <= 0 disable the filter;
            even values are rounded up to the next odd size.

    Returns:
        ``depth_tensor`` itself when ``kernel_size <= 0``, otherwise a new
        tensor of the same (H, W) shape.
    """
    if kernel_size <= 0:
        return depth_tensor
    k = kernel_size if kernel_size % 2 == 1 else kernel_size + 1
    x = depth_tensor.unsqueeze(0).unsqueeze(0)
    # A min filter expressed as a negated max-pool of the negated input;
    # stride 1 with padding k // 2 keeps the spatial size.
    x_eroded = -torch.nn.functional.max_pool2d(-x, kernel_size=k, stride=1, padding=k // 2)
    # Remove only the two axes added above, so a (1, W) or (H, 1) map keeps
    # both of its dimensions.
    return x_eroded.squeeze(0).squeeze(0)
# === LOCAL INPAINTING ===
@torch.no_grad()
def run_local_lama(image_bgr, mask_float):
    """Fill the masked regions of a BGR image with the module-level LaMa model.

    Args:
        image_bgr: Image array in BGR channel order, (H, W, 3). Presumably
            uint8, since it is divided by 255 before inference.
        mask_float: 2-D float array; it is scaled by 255, dilated, and
            thresholded at 0.5, so pixels near 1.0 are treated as holes.

    Returns:
        A uint8 BGR array with the same height and width as ``image_bgr``.
    """
    # 0. Grow the mask a little to catch edge artifacts from splatting.
    hole_mask = (mask_float * 255).astype(np.uint8)
    hole_mask = cv2.dilate(hole_mask, np.ones((3, 3), np.uint8), iterations=3)

    # 1. Work at a resolution whose sides are multiples of 8.
    h, w = image_bgr.shape[:2]
    new_h, new_w = (h // 8) * 8, (w // 8) * 8
    target_size = (new_w, new_h)
    img_resized = cv2.resize(image_bgr, target_size)
    mask_resized = cv2.resize(hole_mask, target_size, interpolation=cv2.INTER_NEAREST)

    # 2. To tensors: image as (1, 3, H, W) RGB in [0, 1], mask as (1, 1, H, W) in {0, 1}.
    img_t = torch.from_numpy(img_resized).float().permute(2, 0, 1).unsqueeze(0) / 255.0
    img_t = img_t[:, [2, 1, 0], :, :].to(device)  # BGR to RGB
    mask_t = torch.from_numpy(mask_resized).float().unsqueeze(0).unsqueeze(0) / 255.0
    mask_t = (mask_t > 0.5).float().to(device)

    # 3. Blank out the holes before handing the image to the model.
    inpainted_t = lama_model(img_t * (1 - mask_t), mask_t)

    # 4. Back to a uint8 BGR array at the original resolution.
    inpainted = inpainted_t[0].permute(1, 2, 0).cpu().numpy()
    inpainted = np.clip(inpainted * 255, 0, 255).astype(np.uint8)
    inpainted = cv2.cvtColor(inpainted, cv2.COLOR_RGB2BGR)
    if (new_h, new_w) != (h, w):
        inpainted = cv2.resize(inpainted, (w, h))
    return inpainted
def make_anaglyph(left, right):
    """Combine a stereo pair into a single anaglyph image.

    Channel 0 is taken from ``left`` and channels 1 and 2 from ``right``
    (red from the left eye, green and blue from the right eye for RGB inputs).

    Args:
        left: Left-eye image (anything ``np.array`` accepts).
        right: Right-eye image with matching height and width.

    Returns:
        A PIL image built from the combined array.
    """
    left_px = np.array(left)
    right_px = np.array(right)
    merged = np.zeros_like(left_px)
    merged[:, :, 0] = left_px[:, :, 0]
    for channel in (1, 2):
        merged[:, :, channel] = right_px[:, :, channel]
    return Image.fromarray(merged)
# === PIPELINE ===
def stereo_pipeline(image_pil, divergence, convergence, edge_erosion):
    """Turn a single image into a stereo pair plus diagnostic images.

    Steps: depth estimation, optional depth erosion, forward warp to the
    second view, inpainting of the holes, and composition of the outputs.

    Args:
        image_pil: Input PIL image, or None.
        divergence: Scale of the horizontal shift (passed to the warper).
        convergence: Depth value that receives zero shift.
        edge_erosion: Erosion kernel size for the depth map; 0 disables it.

    Returns:
        ``(combined_image, anaglyph_image, depth_image, mask_image)`` as PIL
        images, or four ``None`` values when ``image_pil`` is None.
    """
    if image_pil is None:
        return None, None, None, None

    # Everything downstream assumes three colour channels: the tensor
    # conversion permutes an (H, W, C) array and the anaglyph indexes
    # channels 0-2. Grayscale, palette, or RGBA inputs are converted up front.
    if image_pil.mode != "RGB":
        image_pil = image_pil.convert("RGB")

    # Resize input if too large
    w, h = image_pil.size
    if w > 1920:
        ratio = 1920 / w
        # Keep at least one row for extremely wide images.
        new_h = max(1, int(h * ratio))
        image_pil = image_pil.resize((1920, new_h), Image.LANCZOS)

    # 1. Depth Estimation
    depth_tensor = estimate_depth(image_pil, depth_model, depth_processor)

    # 2. Depth Erosion (optional halo reduction)
    if edge_erosion > 0:
        depth_tensor = erode_depth(depth_tensor, int(edge_erosion))

    # Visualize Depth
    depth_vis = (depth_tensor.cpu().numpy() * 255).astype(np.uint8)
    depth_image = Image.fromarray(depth_vis)

    # 3. Forward Warp (Weighted Bilinear Splatting)
    # Convert image to tensor (B, C, H, W)
    image_tensor = torch.from_numpy(np.array(image_pil)).float().to(device).permute(2, 0, 1).unsqueeze(0) / 255.0
    # Prepare depth tensor (B, 1, H, W)
    depth_input = depth_tensor.unsqueeze(0).unsqueeze(0)
    with torch.no_grad():
        right_img_tensor, mask_tensor = stereo_warper(
            image_tensor,
            depth_input,
            float(convergence),
            float(divergence)
        )

    # Convert results back to CPU/Numpy; the mask is transferred once and
    # reused for both the preview and the inpainting step.
    right_img_rgb = (right_img_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
    mask_float = mask_tensor.squeeze(0).squeeze(0).cpu().numpy()
    mask_image = Image.fromarray((mask_float * 255).astype(np.uint8))

    # 4. Inpainting
    right_img_bgr = cv2.cvtColor(right_img_rgb, cv2.COLOR_RGB2BGR)
    right_filled_bgr = run_local_lama(right_img_bgr, mask_float)

    # 5. Finalize
    left = image_pil
    right = Image.fromarray(cv2.cvtColor(right_filled_bgr, cv2.COLOR_BGR2RGB))
    width, height = left.size
    combined_image = Image.new('RGB', (width * 2, height))
    combined_image.paste(left, (0, 0))
    combined_image.paste(right, (width, 0))
    anaglyph_image = make_anaglyph(left, right)
    return combined_image, anaglyph_image, depth_image, mask_image
# === GRADIO UI ===
# NOTE(review): the layout nesting below was reconstructed from a copy of this
# file that had lost its indentation - confirm it against the intended UI.
with gr.Blocks(title="2D to 3D Stereo") as demo:
    # Page header
    gr.Markdown("## 2D to 3D Stereo Generator (High-Quality Splatting)")
    gr.Markdown("Uses **Depth Anything V2**, **Bilinear Weighted Splatting** (Soft Z-Buffer), and **LaMa Inpainting**.")
    with gr.Row():
        # Left column: input image and controls
        with gr.Column(scale=1):
            input_img = gr.Image(type="pil", label="Input Image", height=320)
            with gr.Group():
                gr.Markdown("### 3D Controls")
                divergence_slider = gr.Slider(
                    minimum=0, maximum=100, value=30, step=1,
                    label="3D Strength (Divergence)",
                    info="Max separation in pixels."
                )
                convergence_slider = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.5, step=0.05,
                    label="Focus Plane (Convergence)",
                    info="0.0 = Background at screen. 1.0 = Foreground at screen."
                )
                erosion_slider = gr.Slider(
                    minimum=0, maximum=20, value=2, step=1,
                    label="Edge Masking (Erosion)",
                    info="Cleanup edges. Set to 0 for raw splatting."
                )
            btn = gr.Button("Generate 3D", variant="primary")
        # Right column: results and diagnostics
        with gr.Column(scale=1):
            out_anaglyph = gr.Image(label="Anaglyph (Red/Cyan)", height=320)
            out_stereo = gr.Image(label="Side-by-Side Stereo Pair", height=320)
            with gr.Row():
                out_depth = gr.Image(label="Depth Map", height=200)
                out_mask = gr.Image(label="Inpainting Mask (Holes)", height=200)
    # The output order matches stereo_pipeline's return order:
    # (side-by-side pair, anaglyph, depth map, hole mask).
    btn.click(
        fn=stereo_pipeline,
        inputs=[input_img, divergence_slider, convergence_slider, erosion_slider],
        outputs=[out_stereo, out_anaglyph, out_depth, out_mask]
    )

if __name__ == "__main__":
    demo.launch()