import gradio as gr
import torch
import torch.nn as nn
import numpy as np
import cv2
from PIL import Image
from transformers import AutoModelForDepthEstimation, AutoImageProcessor
from huggingface_hub import hf_hub_download
import os

# === DEVICE ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on device: {device}")

# ==============================================================================
# 1. SAFE & FAST FORWARD WARPER USING grid_sample (NO MORE BLACK IMAGES!)
# ==============================================================================
class SafeForwardWarp(nn.Module):
    def forward(self, img, flow):
        """
        img:  [B, C, H, W] in [0, 1]
        flow: [B, H, W, 2], flow[..., 0] = delta_x (positive = right),
              flow[..., 1] = delta_y
        """
        B, C, H, W = img.shape
        # Build the sampling grid in pixel coordinates. indexing='xy' yields
        # [H, W] grids matching the image layout (indexing='ij' would give
        # transposed [W, H] grids and break on non-square images).
        grid_x, grid_y = torch.meshgrid(
            torch.arange(W, device=img.device),
            torch.arange(H, device=img.device),
            indexing='xy'
        )
        grid_x = grid_x.float().unsqueeze(0).expand(B, -1, -1)  # [B, H, W]
        grid_y = grid_y.float().unsqueeze(0).expand(B, -1, -1)
        dest_x = grid_x + flow[..., 0]  # sample the source at x + dx
        dest_y = grid_y + flow[..., 1]
        # Normalize to [-1, 1] for grid_sample
        norm_x = 2.0 * dest_x / (W - 1) - 1.0
        norm_y = 2.0 * dest_y / (H - 1) - 1.0
        grid = torch.stack((norm_x, norm_y), dim=-1)  # [B, H, W, 2]
        grid = grid.clamp(-1, 1)
        warped = torch.nn.functional.grid_sample(
            img,
            grid,
            mode='bilinear',
            padding_mode='zeros',
            align_corners=True
        )
        return warped
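
# Quick sanity check for SafeForwardWarp (illustrative sketch, not used by the
# app): with zero flow, bilinear grid_sample with align_corners=True samples
# every pixel at its own centre, so the warp is numerically an identity:
#   warp = SafeForwardWarp()
#   img = torch.rand(1, 3, 32, 64)
#   assert torch.allclose(warp(img, torch.zeros(1, 32, 64, 2)), img, atol=1e-4)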

# ==============================================================================
# 2. STEREO WARPER - Improved weighting + safer dilation
# ==============================================================================
class ForwardWarpStereo(nn.Module):
    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.warp = SafeForwardWarp()

    def forward(self, img, shift, disp_for_weights):
        # shift: [B, H, W] (positive = shift right-eye content left → object pops out)
        flow_x = -shift  # negative = move pixels left for the right eye
        flow_y = torch.zeros_like(flow_x)
        flow = torch.stack((flow_x, flow_y), dim=-1)  # [B, H, W, 2]
        # Disparity-based weighting: warp the weighted colors and the weights
        # themselves, then renormalize the result by the warped weights
        weights = 1.0 / (disp_for_weights + 0.1)
        weights = weights / (weights.max() + 1e-8)
        weighted_img = img * weights.unsqueeze(1)
        warped_img = self.warp(weighted_img, flow)
        warped_weights = self.warp(weights.unsqueeze(1), flow)
        # Avoid division by zero
        warped_weights = torch.clamp(warped_weights, min=self.eps)
        result = warped_img / warped_weights
        # Occlusion mask via occupancy count
        ones = torch.ones_like(img[:, :1])
        occupancy = self.warp(ones, flow)
        occlusion = (occupancy < self.eps).float()
        # Smart dilation: grow the occlusion mask while preserving
        # foreground edges (top 10% of disparity)
        with torch.no_grad():
            fg_thresh = torch.quantile(disp_for_weights, 0.90)
            fg_mask = (disp_for_weights > fg_thresh).float().unsqueeze(1)  # [B, 1, H, W]
            k = 9
            dilated = torch.nn.functional.conv2d(
                occlusion,
                torch.ones(1, 1, k, k, device=occlusion.device),
                padding=k // 2
            ) > 0.5
            safe_dilation = dilated.float() * (1 - fg_mask)
            occlusion = torch.clamp(occlusion + safe_dilation, 0, 1)
        return result, occlusion
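
# Call sketch (shapes mirror how stereo_pipeline invokes the warper below;
# the tensors here are random placeholders):
#   warper = ForwardWarpStereo().to(device)
#   img   = torch.rand(1, 3, 240, 320, device=device)        # RGB in [0, 1]
#   shift = torch.full((1, 240, 320), 8.0, device=device)    # horizontal shift in px
#   disp  = torch.rand(1, 240, 320, device=device)           # disparity proxy for weighting
#   right, occ = warper(img, shift, disp)                    # [1,3,H,W] image, [1,1,H,W] mask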

# ==============================================================================
# 3. MODELS & HELPERS
# ==============================================================================
def load_models():
    print("Loading Depth Anything V2 Large...")
    depth_model = AutoModelForDepthEstimation.from_pretrained(
        "depth-anything/Depth-Anything-V2-Large-hf"
    ).to(device).eval()
    depth_processor = AutoImageProcessor.from_pretrained(
        "depth-anything/Depth-Anything-V2-Large-hf"
    )
    print("Loading LaMa Inpainting Model...")
    try:
        model_path = hf_hub_download(repo_id="fashn-ai/LaMa", filename="big-lama.pt")
        lama_model = torch.jit.load(model_path, map_location=device).eval()
    except Exception as e:
        print(f"LaMa load failed: {e}")
        lama_model = None
    stereo_warper = ForwardWarpStereo().to(device)
    return depth_model, depth_processor, lama_model, stereo_warper


depth_model, depth_processor, lama_model, stereo_warper = load_models()
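
# Note: models are loaded once at import time, so a long-lived process (such
# as a Gradio Space) keeps them resident across requests instead of reloading
# on every click.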

def estimate_depth(image_pil):
    original_size = image_pil.size
    inputs = depth_processor(images=image_pil, return_tensors="pt").to(device)
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        outputs = depth_model(**inputs)
    depth = outputs.predicted_depth
    depth = torch.nn.functional.interpolate(
        depth.unsqueeze(1),
        size=(original_size[1], original_size[0]),  # PIL size is (W, H)
        mode="bicubic",
        align_corners=False,
    ).squeeze(0).squeeze(0)
    # Normalize to [0, 1]
    d_min, d_max = depth.min(), depth.max()
    if d_max > d_min:
        depth = (depth - d_min) / (d_max - d_min)
    return depth
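
# Usage sketch ("photo.jpg" is a placeholder path). Values are relative and
# normalized per image, with larger values on nearer surfaces, which is why
# the pipeline below treats the map as a disparity source:
#   d = estimate_depth(Image.open("photo.jpg").convert("RGB"))
#   print(d.shape)  # torch.Size([H, W]) at the original image resolution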

def run_lama(image_bgr, mask_float):
    if lama_model is None:
        return image_bgr
    mask_uint8 = (mask_float * 255).astype(np.uint8)
    kernel = np.ones((7, 7), np.uint8)
    mask_dilated = cv2.dilate(mask_uint8, kernel, iterations=2)
    # Round dimensions down to multiples of 8 for the LaMa model
    h, w = image_bgr.shape[:2]
    new_h = (h // 8) * 8
    new_w = (w // 8) * 8
    img_resized = cv2.resize(image_bgr, (new_w, new_h))
    mask_resized = cv2.resize(mask_dilated, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
    img_t = torch.from_numpy(img_resized).float().permute(2, 0, 1).unsqueeze(0) / 255.0
    img_t = img_t[:, [2, 1, 0]].to(device)  # BGR → RGB
    mask_t = torch.from_numpy(mask_resized).float().unsqueeze(0).unsqueeze(0) / 255.0
    mask_t = (mask_t > 0.5).float().to(device)
    img_t = img_t * (1 - mask_t)  # zero out the masked region before inpainting
    with torch.no_grad():
        inpainted = lama_model(img_t, mask_t)
    result = (inpainted[0].permute(1, 2, 0).cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
    result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
    if (new_h, new_w) != (h, w):
        result = cv2.resize(result, (w, h))
    return result
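
# Mask convention used above: 1.0 marks pixels for LaMa to fill, 0.0 keeps the
# original content. A minimal call sketch with a hypothetical hole:
#   bgr = cv2.imread("right_eye.png")  # placeholder path
#   hole = np.zeros(bgr.shape[:2], np.float32)
#   hole[100:140, 200:260] = 1.0
#   filled = run_lama(bgr, hole)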

def make_anaglyph(left, right):
    l = np.array(left)
    r = np.array(right)
    ana = np.zeros_like(l)
    ana[:, :, 0] = l[:, :, 0]  # Red ← Left
    ana[:, :, 1] = r[:, :, 1]  # Green ← Right
    ana[:, :, 2] = r[:, :, 2]  # Blue ← Right
    return Image.fromarray(ana)
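
# This is a plain color anaglyph: red from the left view, green/blue (cyan)
# from the right, matching red/cyan glasses. Optimized variants (e.g. Dubois)
# reduce ghosting but require a full 3x3 color mix per eye.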

# ==============================================================================
# 4. MAIN PIPELINE
# ==============================================================================
def stereo_pipeline(image_pil, divergence_percent=3.2, convergence_plane=0.08):
    if image_pil is None:
        return None, None, None, None
    w, h = image_pil.size
    if w > 1920:
        ratio = 1920 / w
        image_pil = image_pil.resize((int(w * ratio), int(h * ratio)), Image.LANCZOS)
        w, h = image_pil.size
    # 1. Depth
    depth = estimate_depth(image_pil)  # [H, W] in [0, 1]
    depth_vis = Image.fromarray((depth.cpu().numpy() * 255).astype(np.uint8))
    # 2. Disparity (squaring exaggerates near-far separation for a stronger volume)
    disp_raw = depth ** 2
    disp_clipped = torch.clamp(disp_raw, max=torch.quantile(disp_raw, 0.995))
    # 3. Shift
    max_shift = w * (divergence_percent / 100.0)
    shift_raw = disp_clipped * max_shift
    shift_min, shift_max = shift_raw.min(), shift_raw.max()
    convergence_offset = shift_min + convergence_plane * (shift_max - shift_min)
    final_shift = shift_raw - convergence_offset
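    # Worked example of the offset above: if shift_raw spans [0, 20] px and
    # convergence_plane is 0.08, the offset is 1.6 px and final_shift spans
    # [-1.6, 18.4] px. Negative shifts recede behind the screen plane;
    # positive shifts pop out in front of it.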
| print(f"Final shift range: {final_shift.min():.1f} β {final_shift.max():.1f anywhere} px") | |
    # 4. Warp right eye
    img_tensor = torch.from_numpy(np.array(image_pil)).float().to(device) / 255.0
    img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0)  # [1, 3, H, W]
    shift_tensor = final_shift.unsqueeze(0).to(device)  # [1, H, W]
    disp_tensor = disp_clipped.unsqueeze(0).to(device)
    right_tensor, occlusion_mask = stereo_warper(img_tensor, shift_tensor, disp_tensor)
    # 5. To numpy
    right_np = (right_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
    right_bgr = cv2.cvtColor(right_np, cv2.COLOR_RGB2BGR)
    mask_np = occlusion_mask.squeeze(0).squeeze(0).cpu().numpy()  # [H, W] for OpenCV
    # 6. Inpaint occlusions
    right_filled_bgr = run_lama(right_bgr, mask_np)
    right_filled = Image.fromarray(cv2.cvtColor(right_filled_bgr, cv2.COLOR_BGR2RGB))
    # 7. Outputs
    mask_vis = Image.fromarray((mask_np * 255).astype(np.uint8))
    sbs = Image.new('RGB', (w * 2, h))
    sbs.paste(image_pil, (0, 0))
    sbs.paste(right_filled, (w, 0))
    anaglyph = make_anaglyph(image_pil, right_filled)
    return sbs, anaglyph, depth_vis, mask_vis
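
# Standalone usage sketch, bypassing the UI ("input.jpg" is a placeholder path):
#   img = Image.open("input.jpg").convert("RGB")
#   sbs, anaglyph, depth_vis, mask_vis = stereo_pipeline(img, divergence_percent=3.5)
#   sbs.save("sbs.png"); anaglyph.save("anaglyph.png")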

# ==============================================================================
# 5. GRADIO UI
# ==============================================================================
with gr.Blocks(title="2D → 3D Stereo - Pro & Stable") as demo:
    gr.HTML("<h1 style='text-align:center;'>2D to 3D Stereo - Pro Quality (Fixed & Stable)</h1>")
    gr.Markdown("Depth Anything V2 + Safe Forward Warping + LaMa Inpainting")
    with gr.Row():
        with gr.Column(scale=1):
            input_img = gr.Image(type="pil", label="Upload Image", height=520)
            with gr.Accordion("Settings", open=True):
                divergence = gr.Slider(0.5, 8.0, value=3.5, step=0.1, label="3D Strength (%)")
                convergence = gr.Slider(0.0, 1.0, value=0.08, step=0.01,
                                        label="Convergence Plane (0 = pop-out, 1 = deep)")
            btn = gr.Button("Generate 3D", variant="primary", size="lg")
        with gr.Column(scale=1):
            out_anaglyph = gr.Image(label="Anaglyph (Red/Cyan Glasses)", height=520)
            out_sbs = gr.Image(label="Side-by-Side (Cross-eye / Parallel)", height=300)
            with gr.Row():
                out_depth = gr.Image(label="Depth Map", height=200)
                out_mask = gr.Image(label="Occlusion Mask", height=200)
    btn.click(
        fn=stereo_pipeline,
        inputs=[input_img, divergence, convergence],
        outputs=[out_sbs, out_anaglyph, out_depth, out_mask]
    )
    gr.Markdown("**Tip:** Use Red/Cyan glasses for anaglyph • Cross-eye or parallel view for SBS")

if __name__ == "__main__":
    demo.launch(share=True)