Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| CREStereo Gradio Demo with ZeroGPU Integration | |
| This demo showcases the CREStereo model for stereo depth estimation. | |
| Optimized for Hugging Face Spaces with ZeroGPU support. | |
| Key ZeroGPU optimizations: | |
| - @spaces.GPU decorators for GPU-intensive functions | |
| - CUDA operations only within GPU context | |
| - Memory-efficient inference with cleanup | |
| - Safe CUDA initialization patterns | |
| """ | |
| import os | |
| import sys | |
| import logging | |
| import tempfile | |
| import gc | |
| from pathlib import Path | |
| from typing import Optional, Tuple, Union | |
| import numpy as np | |
| import cv2 | |
| import gradio as gr | |
| import imageio | |
| # Import spaces BEFORE torch to ensure proper ZeroGPU initialization | |
| import spaces | |
| # Import torch after spaces - avoid any CUDA calls during import | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.cuda.amp import autocast | |
| # Completely avoid CUDA operations during import phase | |
| # Do not set default tensor type or modify CUDA settings outside GPU context | |
| # torch.set_default_tensor_type('torch.FloatTensor') # Commented out - causes CUDA init | |
| # Do not modify CUDA settings during import - this can trigger CUDA initialization | |
| # torch.backends.cudnn.enabled = False # Commented out | |
| # torch.backends.cudnn.benchmark = False # Commented out | |
# Directory that contains this script. Model checkpoints and example assets are
# looked up relative to it further down in the file.
current_dir = base_dir = os.path.dirname(os.path.abspath(__file__))

# Give this directory top priority on the module search path so the local
# imports that follow resolve against it.
sys.path.insert(0, current_dir)
| # Import local modules | |
| from nets import Model | |
# Open3D is deliberately NOT imported at module load time; the functions that
# need it import it themselves. Here we only request CPU rendering through the
# environment and optimistically mark Open3D as available (the real check
# happens when it is imported later).
try:
    os.environ['OPEN3D_CPU_RENDERING'] = '1'
except Exception as e:
    logging.warning(f"Open3D setup failed: {e}")
    OPEN3D_AVAILABLE = False
else:
    OPEN3D_AVAILABLE = True  # Assume available, will check later
# Root logger: INFO level, timestamped records.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Known checkpoints, keyed by variant id. ``model_file`` is a path relative to
# the script directory; ``max_disp`` is passed to the network constructor.
MODEL_VARIANTS = {
    "crestereo_eth3d": {
        "display_name": "CREStereo ETH3D (Pre-trained model)",
        "model_file": "models/crestereo_eth3d.pth",
        "max_disp": 256,
    },
}

# Single-entry model cache: the loaded network, the device it lives on, and the
# dropdown selection it was loaded for.
_cached_model = _cached_device = _cached_model_selection = None
class InputPadder:
    """Replicate-pad tensors so their last two dimensions are multiples of ``divis_by``.

    The padding is split as evenly as possible between both sides of each
    dimension (the extra pixel, if any, goes to the right/bottom). With
    ``force_square=True`` the padded height and width are additionally made
    equal to the larger of the two.
    """

    def __init__(self, dims, divis_by=8, force_square=False):
        # Only the trailing (height, width) entries of ``dims`` are used.
        self.ht, self.wd = dims[-2:]

        # Smallest non-negative amount that rounds each side up to a multiple.
        extra_ht = -self.ht % divis_by
        extra_wd = -self.wd % divis_by

        if force_square:
            side = max(self.ht + extra_ht, self.wd + extra_wd)
            extra_ht = side - self.ht
            extra_wd = side - self.wd

        left, top = extra_wd // 2, extra_ht // 2
        # Stored in F.pad order for the last two dims: left, right, top, bottom.
        self._pad = [left, extra_wd - left, top, extra_ht - top]

    def pad(self, *inputs):
        """Return a list with every input tensor padded by edge replication."""
        padded = []
        for tensor in inputs:
            padded.append(F.pad(tensor, self._pad, mode='replicate'))
        return padded

    def unpad(self, x):
        """Crop away the padding added by :meth:`pad` from the last two dimensions."""
        left, right, top, bottom = self._pad
        rows, cols = x.shape[-2:]
        return x[..., top:rows - bottom, left:cols - right]
def aggressive_cleanup():
    """Run one garbage-collection pass and log it.

    Intentionally performs no CUDA calls, so it can be used outside a GPU
    context.
    """
    gc.collect()
    logging.info("Performed basic memory cleanup")
def initialize_gpu_context():
    """Switch PyTorch to CUDA defaults and report whether a GPU is usable.

    Sets the default tensor type to CUDA float, enables cuDNN and its
    autotuner, then queries device 0. Intended to be called only while a GPU
    is attached.

    Returns:
        True when CUDA is available (device name and total memory are logged),
        False when it is not or when any step raises.
    """
    try:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True

        if not torch.cuda.is_available():
            logging.error("CUDA not available after GPU context initialization")
            return False

        gpu_name = torch.cuda.get_device_name(0)
        total_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
        logging.info(f"GPU initialized: {gpu_name}, Total memory: {total_gb:.2f}GB")
        return True
    except Exception as e:
        logging.error(f"GPU context initialization failed: {e}")
        return False
def check_gpu_memory():
    """Log and return memory statistics for CUDA device 0.

    Only meaningful while a GPU is attached.

    Returns:
        ``(allocated, reserved, max_allocated, total)`` in GiB, or
        ``(None, None, None, None)`` when the CUDA runtime cannot be queried.
    """
    gib = 1024**3
    try:
        allocated = torch.cuda.memory_allocated(0) / gib
        reserved = torch.cuda.memory_reserved(0) / gib
        max_allocated = torch.cuda.max_memory_allocated(0) / gib
        total = torch.cuda.get_device_properties(0).total_memory / gib
    except (RuntimeError, AssertionError) as e:
        # CPU-only PyTorch builds signal "no CUDA" with AssertionError rather
        # than RuntimeError; treat both as "no information available".
        logging.warning(f"Failed to get GPU memory info: {e}")
        return None, None, None, None

    logging.info(f"GPU Memory - Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB, Max: {max_allocated:.2f}GB, Total: {total:.2f}GB")
    return allocated, reserved, max_allocated, total
def get_available_models() -> dict:
    """Map display names to metadata for every configured model whose file exists.

    Each entry of ``MODEL_VARIANTS`` is resolved against the script directory;
    variants whose checkpoint file is missing on disk are left out.

    Returns:
        ``{display_name: {"model_path", "variant", "max_disp", "source"}}``.
    """
    resolved = (
        (variant, info, os.path.join(current_dir, info["model_file"]))
        for variant, info in MODEL_VARIANTS.items()
    )
    return {
        info["display_name"]: {
            "model_path": checkpoint,
            "variant": variant,
            "max_disp": info["max_disp"],
            "source": "local",
        }
        for variant, info, checkpoint in resolved
        if os.path.exists(checkpoint)
    }
def get_model_paths_from_selection(model_selection: str) -> Tuple[Optional[str], Optional[dict]]:
    """Resolve a dropdown selection to its checkpoint path and metadata.

    Args:
        model_selection: Display name as produced by ``get_available_models``.

    Returns:
        ``(model_path, model_info)`` for a known selection, ``(None, None)``
        otherwise.
    """
    entry = get_available_models().get(model_selection)
    if entry is None:
        return None, None

    logging.info(f"π Using local model: {model_selection}")
    return entry["model_path"], entry
def load_model_for_inference(model_path: str, model_info: dict):
    """Build the CREStereo network, load its checkpoint and place it on the GPU.

    Must be called while a GPU is attached. Besides loading the model this
    function changes process-wide PyTorch state: the default tensor type is
    switched to CUDA float, cuDNN is configured, the CUDA RNGs are seeded with
    0 and gradient tracking is disabled globally.

    Args:
        model_path: Path to the ``.pth`` state-dict file.
        model_info: Model metadata; ``max_disp`` is read (default 256).

    Returns:
        ``(model, device)`` with the model in eval mode on ``device``.

    Raises:
        RuntimeError: If CUDA is unavailable, or if building the model or
            loading the checkpoint fails (the original error is chained).
    """
    # Check for CUDA *before* touching any global settings: switching the
    # default tensor type to CUDA without a usable GPU either fails with an
    # unrelated error or leaves the process in a broken state, and would keep
    # the intended message below from ever being reported.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available. ZeroGPU initialization may have failed.")

    # NOTE(review): presumably the network code creates tensors without an
    # explicit device and relies on this default - confirm before replacing
    # this (deprecated) call.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    torch.backends.cudnn.enabled = True
    # Only stays in effect if the seeding block below fails; on success the
    # autotuner is switched off again in favour of deterministic kernels.
    torch.backends.cudnn.benchmark = True

    device = torch.device("cuda")

    try:
        torch.cuda.manual_seed_all(0)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    except Exception as e:
        logging.warning(f"Could not set CUDA seed: {e}")

    try:
        max_disp = model_info.get("max_disp", 256)
        model = Model(max_disp=max_disp, mixed_precision=False, test_mode=True)

        # SECURITY NOTE: torch.load unpickles the file, so only trusted
        # checkpoints must be placed in the models directory.
        ckpt = torch.load(model_path, map_location=device)
        model.load_state_dict(ckpt, strict=True)
        model.to(device)
        model.eval()
        logging.info("Loaded CREStereo model weights")

        # Inference only: turn autograd off for the whole process.
        torch.set_grad_enabled(False)
        logging.info("Applied memory optimizations")
        return model, device
    except Exception as e:
        logging.error(f"Model loading failed: {e}")
        raise RuntimeError(f"Failed to load model: {e}") from e
def get_cached_model(model_selection: str):
    """Return the model for ``model_selection``, loading it only when needed.

    A single model is kept in module-level globals. It is reused while the
    selection stays the same and replaced when the selection changes.

    Args:
        model_selection: Display name chosen in the UI.

    Returns:
        ``(model, device)``.

    Raises:
        ValueError: If the selection does not match an available model.
        RuntimeError: Propagated from ``load_model_for_inference``.
    """
    global _cached_model, _cached_device, _cached_model_selection

    model_path, model_info = get_model_paths_from_selection(model_selection)
    if model_path is None or model_info is None:
        raise ValueError(f"Selected model not found: {model_selection}")

    if _cached_model is not None and _cached_model_selection == model_selection:
        logging.info(f"β Using cached model: {model_selection}")
        return _cached_model, _cached_device

    if _cached_model is not None:
        # Rebind to None rather than ``del``: deleting the global would remove
        # the name itself, so a failed load below would make every later call
        # crash with NameError. Resetting all three keeps the cache consistent
        # if loading raises.
        _cached_model = None
        _cached_device = None
        _cached_model_selection = None
        gc.collect()
        torch.cuda.empty_cache()

    logging.info(f"π Loading model: {model_selection}")
    _cached_model, _cached_device = load_model_for_inference(model_path, model_info)
    _cached_model_selection = model_selection
    logging.info(f"β Model loaded successfully: {model_selection}")
    return _cached_model, _cached_device
def clear_model_cache():
    """Drop the cached model (if any) and release the memory it held.

    Resets the three cache globals, runs the garbage collector and empties the
    CUDA caching allocator. Logs whether anything was cleared.
    """
    global _cached_model, _cached_device, _cached_model_selection

    if _cached_model is None:
        logging.info("No model in cache to clear")
        return

    logging.info("Clearing model cache...")
    _cached_model = _cached_device = _cached_model_selection = None

    gc.collect()
    torch.cuda.empty_cache()
    logging.info("Model cache cleared")
def inference(left, right, model, device, n_iter=20):
    """Run the two-pass (half resolution, then full resolution) stereo forward pass.

    Args:
        left: Left image as an H x W x C array (channels last).
        right: Right image, same layout as ``left``.
        model: Callable network taking ``(left, right, iters=..., flow_init=...)``.
        device: Torch device the input tensors are moved to.
        n_iter: Number of refinement iterations passed to the model.

    Returns:
        Channel 0 of the model output, cropped back to the input size, as a
        NumPy array on the CPU.
    """
    print("Model Forwarding...")

    def as_batch(image):
        # HWC -> 1xCxHxW float32 tensor on the target device.
        chw = image.transpose(2, 0, 1)
        batched = np.ascontiguousarray(chw[None, :, :, :])
        return torch.tensor(batched.astype("float32")).to(device)

    imgL = as_batch(left)
    imgR = as_batch(right)

    # Pad to a multiple of 8 so arbitrary input sizes are accepted.
    padder = InputPadder(imgL.shape, divis_by=8)
    imgL_padded, imgR_padded = padder.pad(imgL, imgR)

    # Half-resolution copies feed the coarse first pass.
    half_size = (imgL_padded.shape[2] // 2, imgL_padded.shape[3] // 2)
    imgL_dw2 = F.interpolate(imgL_padded, size=half_size, mode="bilinear", align_corners=True)
    imgR_dw2 = F.interpolate(imgR_padded, size=half_size, mode="bilinear", align_corners=True)

    with torch.inference_mode():
        coarse_flow = model(imgL_dw2, imgR_dw2, iters=n_iter, flow_init=None)
        # The coarse result initialises the full-resolution pass.
        full_flow = model(imgL_padded, imgR_padded, iters=n_iter, flow_init=coarse_flow)

    full_flow = padder.unpad(full_flow)
    return torch.squeeze(full_flow[:, 0, :, :]).cpu().detach().numpy()
def vis_disparity(disparity_map, max_val=None):
    """Render a disparity (or depth) map as an RGB image using the INFERNO colormap.

    Args:
        disparity_map: 2-D float array.
        max_val: When given, values are scaled so that ``max_val`` maps to 255
            and clipped to ``[0, 255]``. When ``None``, the map is min-max
            normalised over its finite values.

    Returns:
        ``uint8`` RGB image with the same height and width as the input.
    """
    if max_val is None:
        finite = np.isfinite(disparity_map)
        if finite.any():
            lo = disparity_map[finite].min()
            hi = disparity_map[finite].max()
        else:
            lo = hi = 0.0
        span = hi - lo
        if span > 0:
            disp_vis = (disparity_map - lo) / span * 255.0
        else:
            # Constant (or entirely non-finite) map: avoid dividing by zero.
            disp_vis = np.zeros(np.shape(disparity_map), dtype=np.float64)
    else:
        disp_vis = disparity_map / max_val * 255.0

    # Casting NaN/inf to uint8 is undefined, so pin them to the range ends first.
    disp_vis = np.nan_to_num(disp_vis, nan=0.0, posinf=255.0, neginf=0.0)
    disp_vis = np.clip(disp_vis, 0, 255).astype("uint8")

    disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO)
    # OpenCV colormaps are BGR; the caller displays RGB.
    return cv2.cvtColor(disp_vis, cv2.COLOR_BGR2RGB)
def _load_rgb_image(image_path: str) -> np.ndarray:
    """Read ``image_path`` as an H x W x 3 RGB array (OpenCV first, imageio as fallback)."""
    bgr = cv2.imread(image_path)
    if bgr is not None:
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    # OpenCV returns None for files it cannot decode; let imageio try.
    image = imageio.imread(image_path)
    if image.ndim == 2:
        # Grayscale: replicate to three channels, since inference expects H x W x C.
        image = np.stack([image] * 3, axis=-1)
    elif image.ndim == 3 and image.shape[2] == 4:
        # Drop the alpha channel.
        image = image[:, :, :3]
    return image


# On ZeroGPU a GPU is attached only while a @spaces.GPU function is running,
# so every entry point that touches CUDA needs the decorator.
# Static 60 seconds for basic processing.
@spaces.GPU(duration=60)
def process_stereo_pair(model_selection: str, left_image: str, right_image: str,
                        progress: gr.Progress = gr.Progress()) -> Tuple[Optional[np.ndarray], str]:
    """Compute and visualise the disparity of a rectified stereo pair.

    Args:
        model_selection: Display name of the model chosen in the UI.
        left_image: File path of the left image.
        right_image: File path of the right image.
        progress: Gradio progress reporter.

    Returns:
        ``(image, status)`` where ``image`` is the colour-mapped disparity as
        an RGB array, or ``None`` on failure, and ``status`` is a
        human-readable message. Errors are reported through ``status`` rather
        than raised.
    """
    logging.info("Starting stereo pair processing...")
    if left_image is None or right_image is None:
        return None, "β Please upload both left and right images."

    logging.info(f"Loading images: left={left_image}, right={right_image}")
    loaded = {}
    try:
        for side, path in (("Left", left_image), ("Right", right_image)):
            if not os.path.exists(path):
                logging.error(f"{side} image file does not exist: {path}")
                return None, f"β {side} image file not found: {path}"
            logging.info(f"Loading {side.lower()} image from: {path}")
            loaded[side] = _load_rgb_image(path)
    except Exception as e:
        logging.error(f"Failed to load images: {e}")
        return None, f"β Failed to load images: {str(e)}"

    left_img, right_img = loaded["Left"], loaded["Right"]
    logging.info(f"Images loaded successfully - Left: {left_img.shape}, Right: {right_img.shape}")

    # Validate before loading the model so a bad pair fails fast.
    if left_img.shape != right_img.shape:
        return None, "β Left and right images must have the same dimensions."

    try:
        variant_name = model_selection.split('(')[0].strip() if '(' in model_selection else model_selection
        progress(0.1, desc=f"Loading cached model ({variant_name})...")
        logging.info("π Getting cached model...")
        model, device = get_cached_model(model_selection)
        logging.info("β Cached model loaded successfully")

        progress(0.2, desc="Preprocessing images...")
        H, W = left_img.shape[:2]

        progress(0.5, desc="Running inference...")
        torch.cuda.empty_cache()  # Clear any cached memory before inference
        disp_cpu = inference(left_img, right_img, model, device, n_iter=20)

        progress(0.8, desc="Creating visualization...")
        result_image = vis_disparity(disp_cpu)
        progress(1.0, desc="Complete!")

        # Statistics over the non-infinite disparities.
        valid_mask = ~np.isinf(disp_cpu)
        if valid_mask.any():
            valid_disp = disp_cpu[valid_mask]
            min_disp, max_disp, mean_disp = valid_disp.min(), valid_disp.max(), valid_disp.mean()
        else:
            min_disp = max_disp = mean_disp = 0

        variant = variant_name

        # Memory figures are best-effort; the status is still useful without them.
        try:
            current_memory = torch.cuda.memory_allocated(0) / 1024**3
            max_memory = torch.cuda.max_memory_allocated(0) / 1024**3
            memory_info = f" | GPU: {current_memory:.2f}GB/{max_memory:.2f}GB peak"
        except Exception:
            memory_info = ""

        status = f"""β Processing successful!
π§ Model: {variant}{memory_info}
π Disparity Statistics:
β’ Range: {min_disp:.2f} - {max_disp:.2f}
β’ Mean: {mean_disp:.2f}
β’ Input size: {W}Γ{H}
β’ Valid pixels: {valid_mask.sum()}/{valid_mask.size}"""
        return result_image, status
    except Exception as e:
        logging.error(f"Processing failed: {e}")
        # Clean up GPU memory
        torch.cuda.empty_cache()
        gc.collect()
        return None, f"β Error: {str(e)}"
# On ZeroGPU a GPU is attached only while a @spaces.GPU function is running.
# Static 120 seconds for depth processing.
@spaces.GPU(duration=120)
def process_with_depth(model_selection: str, left_image: str, right_image: str,
                       camera_matrix: str, baseline: float,
                       progress: gr.Progress = gr.Progress()) -> Tuple[Optional[np.ndarray], Optional[str], Optional[str], str]:
    """Estimate disparity, convert it to metric depth and export a coloured point cloud.

    Depth is ``fx * baseline / disparity``. Points with depth in ``(0, 10]``
    are kept, at most 100000 of them (random subsample), and written to a
    temporary ``.ply`` file.

    Args:
        model_selection: Display name of the model chosen in the UI.
        left_image: File path of the left image.
        right_image: File path of the right image.
        camera_matrix: Nine whitespace-separated numbers, row-major 3x3
            intrinsics (``fx 0 cx 0 fy cy 0 0 1``).
        baseline: Stereo baseline; must be positive.
        progress: Gradio progress reporter.

    Returns:
        ``(depth_image, ply_path, viewer_value, status)``. ``depth_image`` is
        an RGB visualisation or ``None``; ``ply_path`` is the exported point
        cloud or ``None``; ``viewer_value`` is currently always ``None``;
        ``status`` is a human-readable message. Errors are reported through
        ``status`` rather than raised.
    """
    # Open3D is imported lazily so module import stays light.
    global OPEN3D_AVAILABLE
    try:
        import open3d as o3d
        OPEN3D_AVAILABLE = True
    except ImportError as e:
        logging.warning(f"Open3D not available: {e}")
        OPEN3D_AVAILABLE = False
        return None, None, None, "β Open3D not available. Point cloud generation disabled."

    if left_image is None or right_image is None:
        return None, None, None, "β Please upload both left and right images."

    try:
        progress(0.1, desc="Parsing camera parameters...")
        try:
            K_values = [float(token) for token in (camera_matrix or "").split()]
        except ValueError:
            return None, None, None, "β Invalid camera matrix format. Use space-separated numbers."
        if len(K_values) != 9:
            return None, None, None, "β Camera matrix must contain exactly 9 values."
        K = np.array(K_values).reshape(3, 3)

        # The Number widget yields None when left empty.
        if baseline is None or baseline <= 0:
            return None, None, None, "β Baseline must be positive."

        # Load both images once; the raw disparity is needed here, so
        # inference is run directly instead of going through the
        # visualisation-only entry point and then inferring a second time.
        rgb = {}
        for side, path in (("Left", left_image), ("Right", right_image)):
            if not os.path.exists(path):
                return None, None, None, f"β {side} image file not found: {path}"
            bgr = cv2.imread(path)
            if bgr is None:
                return None, None, None, f"β Failed to load images: could not decode {path}"
            rgb[side] = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        left_img, right_img = rgb["Left"], rgb["Right"]
        if left_img.shape != right_img.shape:
            return None, None, None, "β Left and right images must have the same dimensions."

        progress(0.3, desc="Running inference...")
        model, device = get_cached_model(model_selection)
        torch.cuda.empty_cache()  # Clear any cached memory before inference
        disp_cpu = inference(left_img, right_img, model, device, n_iter=20)

        progress(0.6, desc="Converting to depth...")
        H, W = disp_cpu.shape
        # v = row index, u = column index of every pixel.
        v, u = np.meshgrid(np.arange(H), np.arange(W), indexing='ij')

        # A pixel whose match would lie left of the right image border is not
        # visible there; infinite disparity turns into zero depth below.
        disp_cpu[(u - disp_cpu) < 0] = np.inf

        fx, fy = K[0, 0], K[1, 1]
        cx, cy = K[0, 2], K[1, 2]

        # Zero disparity yields infinite depth, filtered out further down.
        with np.errstate(divide='ignore', invalid='ignore'):
            depth = fx * baseline / disp_cpu

        depth_vis = vis_disparity(depth, max_val=10.0)

        progress(0.8, desc="Generating point cloud...")
        valid_depth = np.isfinite(depth)
        z = depth[valid_depth]
        x = (u[valid_depth] - cx) * z / fx
        y = (v[valid_depth] - cy) * z / fy
        points = np.stack([x, y, z], axis=-1)
        colors = left_img[valid_depth]

        # Keep points in front of the camera and within the 10-unit range
        # used for the depth visualisation.
        depth_mask = (z > 0) & (z <= 10.0)
        valid_points = points[depth_mask]
        valid_colors = colors[depth_mask]
        if len(valid_points) == 0:
            return depth_vis, None, None, "β οΈ No valid points generated for point cloud."

        # Report the range of the points that are actually exported.
        depth_min = valid_points[:, 2].min()
        depth_max = valid_points[:, 2].max()

        # Subsample points for better performance
        if len(valid_points) > 100000:
            indices = np.random.choice(len(valid_points), 100000, replace=False)
            valid_points = valid_points[indices]
            valid_colors = valid_colors[indices]

        # Flip Y and Z for display.
        transformed_points = valid_points.copy()
        transformed_points[:, 1] = -transformed_points[:, 1]
        transformed_points[:, 2] = -transformed_points[:, 2]

        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(transformed_points)
        pcd.colors = o3d.utility.Vector3dVector(valid_colors / 255.0)

        # Persist the cloud so the download output has something to serve.
        with tempfile.NamedTemporaryFile(suffix=".ply", delete=False) as tmp:
            ply_path = tmp.name
        if o3d.io.write_point_cloud(ply_path, pcd):
            export_line = "β’ Point cloud saved as .ply for download"
        else:
            logging.warning(f"Failed to write point cloud to {ply_path}")
            ply_path = None
            export_line = "β’ Point cloud export failed"

        progress(1.0, desc="Complete!")

        # Memory figures are best-effort; the status is still useful without them.
        try:
            current_memory = torch.cuda.memory_allocated(0) / 1024**3
            max_memory = torch.cuda.max_memory_allocated(0) / 1024**3
            memory_info = f" | GPU: {current_memory:.2f}GB/{max_memory:.2f}GB peak"
        except Exception:
            memory_info = ""

        variant = model_selection.split('(')[0].strip() if '(' in model_selection else model_selection
        status = f"""β Depth processing successful!
π§ Model: {variant}{memory_info}
π Statistics:
β’ Valid points: {len(valid_points):,}
β’ Depth range: {depth_min:.2f} - {depth_max:.2f} m
β’ Baseline: {baseline} m
β’ Point cloud generated with {len(valid_points)} points
{export_line}"""

        # NOTE(review): the 3D viewer output is left empty, as before. Confirm
        # whether gr.Model3D can render a plain point-cloud .ply before
        # returning ply_path for it as well.
        return depth_vis, ply_path, None, status
    except Exception as e:
        logging.error(f"Depth processing failed: {e}")
        torch.cuda.empty_cache()
        gc.collect()
        return None, None, None, f"β Error: {str(e)}"
def _read_example_camera_params(example_dir, fallback_matrix, fallback_baseline):
    """Return ``(camera_matrix_str, baseline)`` for an example folder.

    ``K.txt`` is expected to hold the camera matrix on its first line and the
    baseline on its second. Whatever is missing, empty or unreadable is
    replaced by the corresponding fallback.
    """
    k_file = os.path.join(example_dir, "K.txt")
    try:
        with open(k_file, 'r') as f:
            lines = f.read().splitlines()
    except OSError:
        return fallback_matrix, fallback_baseline

    matrix = lines[0].strip() if lines else ""
    baseline = fallback_baseline
    if len(lines) >= 2:
        try:
            baseline = float(lines[1].strip())
        except ValueError:
            logging.warning(f"Ignoring invalid baseline in {k_file}")
    return matrix or fallback_matrix, baseline


def create_app() -> gr.Blocks:
    """Build the Gradio Blocks UI.

    The app has a model selector and two tabs: basic disparity estimation and
    advanced depth / point cloud processing. Bundled example pairs are offered
    when present on disk. If no model file is found, the buttons are wired to
    handlers that only report the problem.

    Returns:
        The assembled, not yet launched, ``gr.Blocks`` app.
    """
    try:
        available_models = get_available_models()
        logging.info(f"Successfully got available models: {len(available_models)} found")
    except Exception as e:
        logging.error(f"Failed to get available models: {e}")
        available_models = {}

    # Bundled examples: (folder, fallback camera matrix, fallback baseline).
    example_specs = [
        ("example1",
         "754.6680908203125 0.0 489.3794860839844 0.0 754.6680908203125 265.16162109375 0.0 0.0 1.0",
         0.063),
        ("example2",
         "1733.74 0.0 792.27 0.0 1733.74 541.89 0.0 0.0 1.0",
         0.537),
    ]
    examples_list = []
    examples_advanced_list = []
    for folder, fallback_matrix, fallback_baseline in example_specs:
        example_dir = os.path.join(current_dir, "assets", folder)
        left_path = os.path.join(example_dir, "left.png")
        if not os.path.exists(left_path):
            continue
        right_path = os.path.join(example_dir, "right.png")
        examples_list.append([left_path, right_path])
        camera_matrix_str, baseline_val = _read_example_camera_params(
            example_dir, fallback_matrix, fallback_baseline
        )
        examples_advanced_list.append([left_path, right_path, camera_matrix_str, baseline_val])

    with gr.Blocks(
        title="CREStereo - Stereo Depth Estimation",
        theme=gr.themes.Soft(),
        css="footer {visibility: hidden}",
        delete_cache=(60, 60)
    ) as app:
        gr.Markdown("""
        # π CREStereo: Practical Stereo Matching
        Upload a pair of **rectified** stereo images to get disparity estimation using CREStereo.
        β οΈ **Important**: Images should be rectified (epipolar lines are horizontal) and undistorted.
        β‘ **GPU Powered**: Runs on CUDA-enabled GPUs for fast inference.
        """)

        # Instructions section
        with gr.Accordion("π Instructions", open=False):
            gr.Markdown("""
            ## π How to Use This Demo
            ### πΌοΈ Input Requirements
            1. **Image Format**: Upload images in JPEG or PNG format.
            2. **Image Size**: Images should be of the same size and resolution.
            3. **Rectification**: Ensure images are rectified (epipolar lines are horizontal) and undistorted.
            4. **Camera Parameters**: For depth processing, provide camera matrix and baseline distance.
            ### π Using the Demo
            1. **Select Model**: Choose the CREStereo model variant
            2. **Upload Images**: Provide rectified stereo image pairs
            3. **Basic Processing**: Get disparity visualization
            4. **Advanced Processing**: Generate depth maps and 3D point clouds (requires camera parameters)
            ### π Original Work
            This demo is based on CREStereo: Practical Stereo Matching via Cascaded Recurrent Network.
            - **Paper**: [CREStereo: Practical Stereo Matching via Cascaded Recurrent Network](https://arxiv.org/abs/2203.11483)
            - **Official Repository**: [https://github.com/megvii-research/CREStereo](https://github.com/megvii-research/CREStereo)
            """)

        # Model selection
        with gr.Row():
            all_choices = list(available_models.keys())
            if not all_choices:
                all_choices = ["No models found - Please ensure crestereo_eth3d.pth is in models/ directory"]
            model_selector = gr.Dropdown(
                choices=all_choices,
                value=all_choices[0],
                label="π― Select Model",
                info="Choose the CREStereo model variant.",
                interactive=True
            )

        with gr.Tabs():
            # Basic stereo processing tab
            with gr.TabItem("πΌοΈ Basic Stereo Processing"):
                with gr.Row():
                    with gr.Column():
                        left_input = gr.Image(
                            label="π· Left Image",
                            type="filepath",
                            height=300
                        )
                        right_input = gr.Image(
                            label="π· Right Image",
                            type="filepath",
                            height=300
                        )
                        process_btn = gr.Button(
                            "π Process Stereo Pair",
                            variant="primary",
                            size="lg"
                        )
                    with gr.Column():
                        output_image = gr.Image(
                            label="π Disparity Visualization",
                            height=400
                        )
                        status_text = gr.Textbox(
                            label="Status",
                            interactive=False,
                            lines=8
                        )
                if examples_list:
                    gr.Examples(
                        examples=examples_list,
                        inputs=[left_input, right_input],
                        label="π Example Images"
                    )

            # Advanced processing with depth
            with gr.TabItem("π Advanced Processing (Depth & Point Cloud)"):
                with gr.Row():
                    with gr.Column():
                        left_input_adv = gr.Image(
                            label="π· Left Image",
                            type="filepath",
                            height=250
                        )
                        right_input_adv = gr.Image(
                            label="π· Right Image",
                            type="filepath",
                            height=250
                        )
                        # Camera parameters
                        with gr.Group():
                            gr.Markdown("### πΉ Camera Parameters")
                            camera_matrix_input = gr.Textbox(
                                label="Camera Matrix (9 values: fx 0 cx 0 fy cy 0 0 1)",
                                value="",
                            )
                            baseline_input = gr.Number(
                                label="Baseline (meters)",
                                value=None,
                                minimum=0.001,
                                maximum=10.0,
                                step=0.001
                            )
                        process_depth_btn = gr.Button(
                            "π¬ Process with Depth",
                            variant="primary",
                            size="lg"
                        )
                    with gr.Column():
                        depth_output = gr.Image(
                            label="π Depth Visualization",
                            height=300
                        )
                        pointcloud_output = gr.File(
                            label="βοΈ Point Cloud Download (.ply)",
                            file_types=[".ply"]
                        )
                        status_depth = gr.Textbox(
                            label="Status",
                            interactive=False,
                            lines=6
                        )
                # 3D Point Cloud Visualization
                with gr.Row():
                    pointcloud_3d = gr.Model3D(
                        label="π 3D Point Cloud Viewer",
                        clear_color=[0.0, 0.0, 0.0, 0.0],
                        height=400
                    )
                if examples_advanced_list:
                    gr.Examples(
                        examples=examples_advanced_list,
                        inputs=[left_input_adv, right_input_adv, camera_matrix_input, baseline_input],
                        label="π Example Images with Camera Parameters"
                    )

        # Event handlers
        basic_inputs = [model_selector, left_input, right_input]
        basic_outputs = [output_image, status_text]
        depth_inputs = [model_selector, left_input_adv, right_input_adv, camera_matrix_input, baseline_input]
        depth_outputs = [depth_output, pointcloud_output, pointcloud_3d, status_depth]

        if available_models:
            process_btn.click(
                fn=process_stereo_pair,
                inputs=basic_inputs,
                outputs=basic_outputs,
                show_progress=True
            )
            if OPEN3D_AVAILABLE:
                process_depth_btn.click(
                    fn=process_with_depth,
                    inputs=depth_inputs,
                    outputs=depth_outputs,
                    show_progress=True
                )
            else:
                process_depth_btn.click(
                    fn=lambda *args: (None, None, None, "β Open3D not available. Install with: pip install open3d"),
                    inputs=depth_inputs,
                    outputs=depth_outputs
                )
        else:
            # No models available: both buttons only report the problem.
            process_btn.click(
                fn=lambda *args: (None, "β No models available. Please ensure crestereo_eth3d.pth is in models/ directory."),
                inputs=basic_inputs,
                outputs=basic_outputs
            )
            process_depth_btn.click(
                fn=lambda *args: (None, None, None, "β No models available. Please ensure crestereo_eth3d.pth is in models/ directory."),
                inputs=depth_inputs,
                outputs=depth_outputs
            )

        # Citation section at the bottom
        with gr.Accordion("π Citation", open=False):
            gr.Markdown("""
            ### π Please Cite the Original Paper
            If you use this work in your research, please cite:
            ```bibtex
            @article{li2022practical,
            title={Practical Stereo Matching via Cascaded Recurrent Network with Adaptive Correlation},
            author={Li, Jiankun and Wang, Peisen and Xiong, Pengfei and Cai, Tao and Yan, Ziwei and Yang, Lei and Liu, Jiangyu and Fan, Haoqiang and Liu, Shuaicheng},
            journal={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
            pages={16263--16272},
            year={2022}
            }
            ```
            """)

        # Footer
        gr.Markdown("""
        ---
        ### π Notes:
        - **Input images must be rectified stereo pairs** (epipolar lines are horizontal)
        - **β‘ GPU Acceleration**: Requires CUDA-compatible GPU
        - **π¦ Model Caching**: Models are cached for efficient repeated usage
        - For best results, use high-quality rectified stereo pairs
        - Model works on RGB images and supports various resolutions
        ### π References:
        - [CREStereo Paper](https://arxiv.org/abs/2203.11483)
        - [Original GitHub Repository](https://github.com/megvii-research/CREStereo)
        - [This PyTorch Implementation](https://github.com/ibaiGorordo/CREStereo-Pytorch)
        """)
    return app
def main():
    """Parse command-line options, build the Gradio app and launch it.

    Options: ``--host`` (default ``0.0.0.0``), ``--port`` (default 7860),
    ``--share`` and ``--debug`` (sets the root logger to DEBUG).

    Raises:
        Exception: Any error raised while creating or launching the app is
            logged and re-raised.
    """
    import argparse

    # CUDA is deliberately not probed here: the startup notes in this file say
    # its status must only be queried inside the GPU-decorated handlers.
    logging.info("π Starting CREStereo Gradio App...")

    parser = argparse.ArgumentParser(description="CREStereo Gradio App")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
    parser.add_argument("--port", type=int, default=7860, help="Port to bind to")
    parser.add_argument("--share", action="store_true", help="Create shareable link")
    parser.add_argument("--debug", action="store_true", help="Enable debug mode")
    args = parser.parse_args()

    if args.debug:
        logging.getLogger().setLevel(logging.DEBUG)

    try:
        logging.info("Creating Gradio app...")
        app = create_app()
        logging.info("β Gradio app created successfully")

        logging.info(f"Launching app on {args.host}:{args.port}")
        if args.share:
            logging.info("Share link will be created")

        # For ZeroGPU compatibility, launch with appropriate settings
        app.launch(
            server_name=args.host,
            server_port=args.port,
            share=args.share,
            show_error=True,
            favicon_path=None,
            ssr_mode=False,  # Disable SSR for ZeroGPU compatibility
            allowed_paths=["./"]  # Allow access to local files
        )
    except Exception as e:
        logging.error(f"Failed to launch app: {e}")
        raise
if __name__ == "__main__":
    # Script entry point. CUDA status is intentionally not queried here; it is
    # left to the GPU handlers, because probing it at startup can trigger CUDA
    # initialization.
    running_on_spaces = os.environ.get('SPACE_ID') is not None
    if running_on_spaces:
        logging.info("Running in Hugging Face Spaces environment")
    logging.info("β CUDA status will be checked within GPU-decorated functions")
    main()