"""2D-to-3D stereo generator: depth estimation, forward-warp splatting and LaMa inpainting behind a Gradio UI."""

import os

import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from PIL import Image
from torch.autograd import Function
from transformers import AutoImageProcessor, AutoModelForDepthEstimation

# === DEVICE ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on device: {device}")


# ==============================================================================
# 1. FORWARD WARP IMPLEMENTATION (Native PyTorch)
# ==============================================================================


def _flat_plane_offsets(B, C, H, W, device):
    """Return (B, C, 1, 1) int64 offsets of each (batch, channel) plane inside a flattened (B, C, H, W) tensor."""
    b_idx = torch.arange(B, device=device).view(B, 1, 1, 1)
    c_idx = torch.arange(C, device=device).view(1, C, 1, 1)
    return b_idx * (C * H * W) + c_idx * (H * W)


def _splat_bilinear(im0, x_dest, y_dest, out):
    """Add every source pixel of ``im0`` into the 4 integer neighbours of its destination, bilinearly weighted.

    Args:
        im0: (B, C, H, W) source image.
        x_dest, y_dest: (B, H, W) destination coordinates of each source pixel.
        out: contiguous (B, C, H, W) accumulator, modified in place.

    Contributions whose target corner lies outside the image are dropped.
    """
    B, C, H, W = im0.shape
    x0 = torch.floor(x_dest).long()
    y0 = torch.floor(y_dest).long()
    fx = x_dest - x0.to(x_dest.dtype)  # fractional offsets in [0, 1)
    fy = y_dest - y0.to(y_dest.dtype)

    plane = _flat_plane_offsets(B, C, H, W, im0.device)
    out_flat = out.view(-1)  # safe: ``out`` is contiguous
    corners = (
        (x0, y0, (1 - fx) * (1 - fy)),  # NW
        (x0 + 1, y0, fx * (1 - fy)),  # NE
        (x0, y0 + 1, (1 - fx) * fy),  # SW
        (x0 + 1, y0 + 1, fx * fy),  # SE
    )
    for cx, cy, weight in corners:
        inside = (cx >= 0) & (cx < W) & (cy >= 0) & (cy < H)
        # Out-of-bounds corners get a clamped (legal) index but zero weight,
        # so scatter_add_ adds nothing to the border pixel they map to.
        idx = plane + (cy.clamp(0, H - 1) * W + cx.clamp(0, W - 1)).unsqueeze(1)
        weight = (weight * inside.to(weight.dtype)).unsqueeze(1)
        # Cast to the accumulator dtype: scatter_add_ requires matching dtypes.
        values = (im0 * weight).to(out.dtype)
        out_flat.scatter_add_(0, idx.reshape(-1), values.reshape(-1))


def _splat_nearest(im0, x_dest, y_dest, out):
    """Write every source pixel of ``im0`` to the nearest integer destination pixel in ``out`` (in place).

    Sources landing outside the image are skipped entirely (not written as
    zeros). When several sources hit the same pixel, which one survives is
    unspecified (``Tensor.scatter_`` semantics).
    """
    B, C, H, W = im0.shape
    xn = torch.round(x_dest).long()
    yn = torch.round(y_dest).long()
    inside = ((xn >= 0) & (xn < W) & (yn >= 0) & (yn < H)).unsqueeze(1).expand(B, C, H, W).reshape(-1)
    idx = (_flat_plane_offsets(B, C, H, W, im0.device) + (yn * W + xn).unsqueeze(1)).reshape(-1)
    out.view(-1).scatter_(0, idx[inside], im0.reshape(-1)[inside].to(out.dtype))


class ForwardWarpFunction(Function):
    """Forward (push) warp of an image by a per-pixel flow field.

    Source pixel (x, y) of ``im0`` is sent to (x + flow[..., 0], y + flow[..., 1]).
    Mode 0 splats bilinearly and sums overlapping contributions; mode 1 writes
    to the nearest pixel. No gradient is propagated.
    """

    @staticmethod
    def forward(ctx, im0, flow, interpolation_mode_int):
        """Warp ``im0`` (B, C, H, W) by ``flow`` (B, H, W, 2); mode 0 = bilinear, 1 = nearest.

        Raises:
            ValueError: on mismatched shapes or an unknown interpolation mode.
        """
        if im0.dim() != 4 or flow.dim() != 4:
            raise ValueError("im0 must be (B, C, H, W) and flow must be (B, H, W, 2)")
        if interpolation_mode_int not in (0, 1):
            raise ValueError(f"interpolation_mode_int must be 0 or 1, got {interpolation_mode_int}")
        if im0.shape[0] != flow.shape[0] or im0.shape[-2:] != flow.shape[1:3] or flow.shape[3] != 2:
            raise ValueError(
                f"flow shape {tuple(flow.shape)} does not match image shape {tuple(im0.shape)}"
            )

        B, C, H, W = im0.shape
        # Freshly allocated => contiguous, so the helpers can use view(-1) for in-place scatter.
        im1 = torch.zeros(im0.shape, device=im0.device, dtype=im0.dtype)

        # Compute pixel coordinates in at least fp32: fp16 cannot represent
        # every integer column index of wide images.
        coord_dtype = torch.promote_types(torch.promote_types(im0.dtype, flow.dtype), torch.float32)
        xs = torch.arange(W, device=im0.device, dtype=coord_dtype).view(1, 1, W)
        ys = torch.arange(H, device=im0.device, dtype=coord_dtype).view(1, H, 1)
        x_dest = xs + flow[..., 0].to(coord_dtype)  # (B, H, W)
        y_dest = ys + flow[..., 1].to(coord_dtype)

        splat = _splat_bilinear if interpolation_mode_int == 0 else _splat_nearest
        splat(im0, x_dest, y_dest, im1)
        return im1

    @staticmethod
    def backward(ctx, grad_output):
        # Non-differentiable: no gradient for im0, flow or the mode flag.
        return None, None, None


class forward_warp(nn.Module):
    """nn.Module wrapper around ForwardWarpFunction; "Bilinear" selects splatting, anything else nearest."""

    def __init__(self, interpolation_mode="Bilinear"):
        super().__init__()
        self.interpolation_mode_int = 0 if interpolation_mode == "Bilinear" else 1

    def forward(self, im0, flow):
        return ForwardWarpFunction.apply(im0, flow, self.interpolation_mode_int)


# ==============================================================================
# 2. STEREO WARPER WRAPPER
# ==============================================================================
class ForwardWarpStereo(nn.Module):
    """
    Weighted Splatting wrapper that synthesises a horizontally shifted view.

    Occlusions are resolved with a soft Z-buffer: each source pixel is
    splatted with a weight that grows linearly with its disparity (normalised
    by the maximum disparity, plus a 0.05 floor), and the accumulated colour
    is divided by the accumulated weight, so higher-disparity pixels dominate
    where several land on the same target.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.fw = forward_warp(interpolation_mode="Bilinear")

    def forward(self, im, disp, convergence, divergence):
        """Warp ``im`` (B, C, H, W) by disparity ``disp`` (B, 1, H, W).

        Returns:
            (res, occlusion_mask): the warped image and a (B, 1, H, W) mask that
            is 1.0 where no source pixel landed (holes) and 0.0 elsewhere.
        """
        # Drop the channel dim so disparity can be combined with coordinates: [B, H, W].
        disp_squeeze = disp.squeeze(1)

        # Horizontal shift: pixels at disparity == convergence stay put.
        # For the right-eye view the target is source - shift, hence the negation.
        shift = (disp_squeeze - convergence) * divergence
        flow_x = -shift
        flow_y = torch.zeros_like(flow_x)
        flow = torch.stack((flow_x, flow_y), dim=-1)  # [B, H, W, 2]

        # 1. Soft Z-buffer weights: linear in normalised disparity, with a small
        #    floor so zero-disparity (background) pixels still contribute.
        disp_norm = disp_squeeze / (disp_squeeze.max() + 1e-8)
        weights_map = disp_norm + 0.05
        weights_map = weights_map.unsqueeze(1)  # [B, 1, H, W]

        # 2. Warp Image * Weights (accumulate weighted colour).
        res_accum = self.fw(im * weights_map, flow)

        # 3. Warp Weights (accumulate total weight per target pixel).
        mask_accum = self.fw(weights_map, flow)

        # 4. Normalise (colour / total weight); eps avoids divide-by-zero in empty regions.
        mask_accum.clamp_(min=self.eps)
        res = res_accum / mask_accum

        # 5. Binary occlusion mask for inpainting: splat ones; zero coverage = hole.
        ones = torch.ones_like(disp)
        occupancy = self.fw(ones, flow)
        occlusion_mask = (occupancy < self.eps).float()  # holes = 1.0, filled = 0.0

        return res, occlusion_mask


# ==============================================================================
# 3. APP LOGIC & MODELS
# ==============================================================================


# === LOAD MODELS ===
def load_models():
    """Load Depth Anything V2 Large, the TorchScript LaMa model and the stereo warper onto ``device``."""
    print("Loading Depth Anything V2 Large...")
    depth_model = AutoModelForDepthEstimation.from_pretrained(
        "depth-anything/Depth-Anything-V2-Large-hf"
    ).to(device)
    depth_processor = AutoImageProcessor.from_pretrained(
        "depth-anything/Depth-Anything-V2-Large-hf"
    )

    print("Loading LaMa Inpainting Model...")
    try:
        model_path = hf_hub_download(repo_id="fashn-ai/LaMa", filename="big-lama.pt")
        lama_model = torch.jit.load(model_path, map_location=device)
        lama_model.eval()
    except Exception as e:
        print(f"Error loading LaMa model: {e}")
        raise

    stereo_warper = ForwardWarpStereo().to(device)

    return depth_model, depth_processor, lama_model, stereo_warper


# Load models once at startup
depth_model, depth_processor, lama_model, stereo_warper = load_models()


# === DEPTH ESTIMATION ===
@torch.no_grad()
def estimate_depth(image_pil, model, processor):
    """Predict depth for ``image_pil`` and return an (H, W) tensor min-max normalised to [0, 1].

    The prediction is bicubically resized back to the input image size. A
    constant prediction yields an all-zero map.
    """
    original_size = image_pil.size  # PIL size is (width, height)
    inputs = processor(images=image_pil, return_tensors="pt").to(device)
    depth = model(**inputs).predicted_depth
    depth = torch.nn.functional.interpolate(
        depth.unsqueeze(1),
        size=(original_size[1], original_size[0]),
        mode="bicubic",
        align_corners=False,
    ).squeeze()

    depth_min, depth_max = depth.min(), depth.max()
    if depth_max - depth_min > 0:
        depth = (depth - depth_min) / (depth_max - depth_min)
    else:
        depth = torch.zeros_like(depth)
    return depth


# === DEPTH MANIPULATION ===
def erode_depth(depth_tensor, kernel_size):
    """Apply a min-filter (grey-level erosion) to a depth map; ``kernel_size <= 0`` is a no-op.

    Even kernel sizes are bumped to the next odd size so the window stays
    centred. Assumes a 2-D (H, W) tensor — TODO confirm for other callers.
    """
    if kernel_size <= 0:
        return depth_tensor
    k = kernel_size if kernel_size % 2 == 1 else kernel_size + 1
    x = depth_tensor.unsqueeze(0).unsqueeze(0)
    padding = k // 2
    # min-pool expressed as -maxpool(-x)
    x_eroded = -torch.nn.functional.max_pool2d(-x, kernel_size=k, stride=1, padding=padding)
    return x_eroded.squeeze()


# === LOCAL INPAINTING ===
@torch.no_grad()
def run_local_lama(image_bgr, mask_float):
    """Fill the masked regions of ``image_bgr`` with the LaMa model.

    Args:
        image_bgr: uint8 (H, W, 3) image in BGR order.
        mask_float: (H, W) array, 1.0 at holes and 0.0 elsewhere.

    Returns:
        uint8 (H, W, 3) BGR image. Pixels outside the dilated hole mask are
        copied from ``image_bgr`` unchanged; only masked pixels come from LaMa.
    """
    # 0. Dilate the mask slightly so the fringe of splatting artifacts is repainted too.
    kernel = np.ones((3, 3), np.uint8)
    mask_uint8 = (mask_float * 255).astype(np.uint8)
    mask_dilated = cv2.dilate(mask_uint8, kernel, iterations=3)
    hole = mask_dilated > 127  # equivalent to (mask / 255) > 0.5

    # 1. Pad (rather than resize) to a multiple of 8: valid pixels are never
    #    resampled, and images smaller than 8 px no longer collapse to size 0.
    h, w = image_bgr.shape[:2]
    pad_h = (-h) % 8
    pad_w = (-w) % 8
    img_padded = np.pad(image_bgr, ((0, pad_h), (0, pad_w), (0, 0)), mode="symmetric")
    mask_padded = np.pad(hole, ((0, pad_h), (0, pad_w)), mode="symmetric")

    # 2. Convert to Torch: (1, 3, H, W) RGB in [0, 1] and (1, 1, H, W) binary mask.
    img_t = torch.from_numpy(img_padded).float().permute(2, 0, 1).unsqueeze(0) / 255.0
    img_t = img_t[:, [2, 1, 0], :, :]  # BGR to RGB
    mask_t = torch.from_numpy(mask_padded.astype(np.float32)).unsqueeze(0).unsqueeze(0)

    img_t = img_t.to(device)
    mask_t = mask_t.to(device)

    # 3. Inference: blank hole pixels before handing the image to LaMa.
    img_t = img_t * (1 - mask_t)
    inpainted_t = lama_model(img_t, mask_t)

    # 4. Post-process: back to uint8 BGR, crop padding, keep original pixels outside the mask.
    inpainted = inpainted_t[0].permute(1, 2, 0).cpu().numpy()
    inpainted = np.clip(inpainted * 255, 0, 255).astype(np.uint8)
    inpainted = cv2.cvtColor(inpainted, cv2.COLOR_RGB2BGR)[:h, :w]

    return np.where(hole[:, :, None], inpainted, image_bgr)


def make_anaglyph(left, right):
    """Build a red/cyan anaglyph: red channel from ``left``, green and blue from ``right`` (RGB inputs)."""
    l_arr = np.array(left)
    r_arr = np.array(right)
    anaglyph = np.zeros_like(l_arr)
    anaglyph[:, :, 0] = l_arr[:, :, 0]
    anaglyph[:, :, 1] = r_arr[:, :, 1]
    anaglyph[:, :, 2] = r_arr[:, :, 2]
    return Image.fromarray(anaglyph)


# === PIPELINE ===
def stereo_pipeline(image_pil, divergence, convergence, edge_erosion):
    """Turn one image into a stereo pair.

    Returns:
        (side_by_side, anaglyph, depth_map, hole_mask) PIL images, or four
        ``None`` values when no image is given.
    """
    if image_pil is None:
        return None, None, None, None

    # The warp, the cv2 colour conversions and the anaglyph all assume 3-channel
    # RGB; RGBA / palette / grayscale uploads would otherwise crash.
    image_pil = image_pil.convert("RGB")

    # Resize input if too large
    w, h = image_pil.size
    if w > 1920:
        ratio = 1920 / w
        new_h = int(h * ratio)
        image_pil = image_pil.resize((1920, new_h), Image.LANCZOS)

    # 1. Depth Estimation
    depth_tensor = estimate_depth(image_pil, depth_model, depth_processor)

    # 2. Depth Erosion (optional halo reduction)
    if edge_erosion > 0:
        depth_tensor = erode_depth(depth_tensor, int(edge_erosion))

    # Visualize Depth
    depth_vis = (depth_tensor.cpu().numpy() * 255).astype(np.uint8)
    depth_image = Image.fromarray(depth_vis)

    # 3. Forward Warp (Weighted Bilinear Splatting)
    # Image as (B, C, H, W) in [0, 1]; depth as (B, 1, H, W).
    image_tensor = torch.from_numpy(np.array(image_pil)).float().to(device).permute(2, 0, 1).unsqueeze(0) / 255.0
    depth_input = depth_tensor.unsqueeze(0).unsqueeze(0)

    with torch.no_grad():
        right_img_tensor, mask_tensor = stereo_warper(
            image_tensor,
            depth_input,
            float(convergence),
            float(divergence),
        )

    # Convert results back to CPU/Numpy
    right_img_rgb = (right_img_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
    mask_vis = (mask_tensor.squeeze(0).squeeze(0).cpu().numpy() * 255).astype(np.uint8)
    mask_image = Image.fromarray(mask_vis)

    # 4. Inpainting
    right_img_bgr = cv2.cvtColor(right_img_rgb, cv2.COLOR_RGB2BGR)
    mask_float = mask_tensor.squeeze().cpu().numpy()
    right_filled_bgr = run_local_lama(right_img_bgr, mask_float)

    # 5. Finalize
    left = image_pil
    right = Image.fromarray(cv2.cvtColor(right_filled_bgr, cv2.COLOR_BGR2RGB))

    width, height = left.size
    combined_image = Image.new('RGB', (width * 2, height))
    combined_image.paste(left, (0, 0))
    combined_image.paste(right, (width, 0))

    anaglyph_image = make_anaglyph(left, right)

    return combined_image, anaglyph_image, depth_image, mask_image


# === GRADIO UI ===
with gr.Blocks(title="2D to 3D Stereo") as demo:
    gr.Markdown("## 2D to 3D Stereo Generator (High-Quality Splatting)")
    gr.Markdown("Uses **Depth Anything V2**, **Bilinear Weighted Splatting** (Soft Z-Buffer), and **LaMa Inpainting**.")

    with gr.Row():
        with gr.Column(scale=1):
            input_img = gr.Image(type="pil", label="Input Image", height=320)

            with gr.Group():
                gr.Markdown("### 3D Controls")
                divergence_slider = gr.Slider(
                    minimum=0, maximum=100, value=30, step=1,
                    label="3D Strength (Divergence)",
                    info="Max separation in pixels."
                )
                convergence_slider = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.5, step=0.05,
                    label="Focus Plane (Convergence)",
                    info="0.0 = Background at screen. 1.0 = Foreground at screen."
                )
                erosion_slider = gr.Slider(
                    minimum=0, maximum=20, value=2, step=1,
                    label="Edge Masking (Erosion)",
                    info="Cleanup edges. Set to 0 for raw splatting."
                )

            btn = gr.Button("Generate 3D", variant="primary")

        with gr.Column(scale=1):
            out_anaglyph = gr.Image(label="Anaglyph (Red/Cyan)", height=320)
            out_stereo = gr.Image(label="Side-by-Side Stereo Pair", height=320)
            with gr.Row():
                out_depth = gr.Image(label="Depth Map", height=200)
                out_mask = gr.Image(label="Inpainting Mask (Holes)", height=200)

    btn.click(
        fn=stereo_pipeline,
        inputs=[input_img, divergence_slider, convergence_slider, erosion_slider],
        outputs=[out_stereo, out_anaglyph, out_depth, out_mask]
    )

if __name__ == "__main__":
    demo.launch()