Spaces:
Sleeping
Sleeping
# depth_texture_mask.py
# Modified: lazy MiDaS init and safe for server use.
import os
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt

# Module-level state, filled in by init_midas(). Kept as globals so one
# loaded model instance is shared by every caller in the process.
midas = None                # loaded MiDaS model (None until init_midas runs)
midas_transforms = None     # torch.hub "transforms" module for MiDaS
transform = None            # preprocessing transform chosen for the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
_midas_initialized = False  # one-shot guard so the model loads only once
def init_midas(model_name="DPT_Hybrid", device_override=None, force_reload=False):
    """
    Initialize/load the MiDaS model and transforms into global variables.

    Call this once (e.g., at FastAPI startup); subsequent calls are no-ops
    unless force_reload=True.

    Args:
        model_name: torch.hub MiDaS model name, e.g. "DPT_Large",
            "DPT_Hybrid", or "MiDaS_small".
        device_override: optional torch.device to use instead of the
            auto-detected CUDA/CPU device.
        force_reload: reload the model even if already initialized.

    Raises:
        RuntimeError: if no usable preprocessing transform is found.
    """
    global midas, midas_transforms, transform, device, _midas_initialized
    if device_override is not None:
        device = device_override
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if _midas_initialized and not force_reload:
        return
    # Use torch.hub to load MiDaS transforms & model
    # NOTE: this will download if not cached
    midas = torch.hub.load("intel-isl/MiDaS", model_name, pretrained=True)
    midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
    # BUG FIX: the transforms module exposes BOTH dpt_transform and
    # small_transform, so the old hasattr() chain always picked dpt_transform,
    # giving "MiDaS_small" the wrong preprocessing. Select by model name.
    if "small" in model_name.lower():
        transform = getattr(midas_transforms, "small_transform", None)
    else:
        transform = getattr(midas_transforms, "dpt_transform", None)
    if transform is None:
        # fallback: try a generic 'transform'
        transform = getattr(midas_transforms, "transform", None)
    if transform is None:
        # Fail loudly here rather than with a cryptic NoneType call later.
        raise RuntimeError(f"No MiDaS preprocessing transform found for model {model_name!r}")
    midas.to(device).eval()
    _midas_initialized = True
    return
def _ensure_initialized():
    """Lazily initialize MiDaS on first use; no-op when already loaded."""
    if _midas_initialized:
        return
    init_midas()
def generate_texture_depth_mask(input_data, mask_only=False):
    """
    Generate a texture + depth structural mask.

    Blends edge/corner texture features (Canny + Laplacian + Harris) with an
    inverted, normalized MiDaS depth map, then adds random speckle noise.
    NOTE: the noise term makes the output nondeterministic across calls.

    Supports:
        - File paths (.jpg, .png)
        - NumPy arrays (H,W,C) RGB or RGBA
        - List of inputs (batch mode)
    Returns:
        mask_only=False:
            - Single: (fig, mask)
            - Batch: list of (fig, mask)
        mask_only=True:
            - Single: mask
            - Batch: list of masks
    """
    _ensure_initialized()  # lazily loads the global MiDaS model/transform

    def _process_single(image_source):
        # Load image (array or file path)
        if isinstance(image_source, np.ndarray):
            # assumes the array is uint8 RGB/RGBA — TODO confirm with callers
            img_rgb = image_source
            if img_rgb.shape[-1] == 4:
                img_rgb = img_rgb[:, :, :3]  # drop alpha channel
            img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
        elif isinstance(image_source, str) and os.path.isfile(image_source):
            img_bgr = cv2.imread(image_source)
            if img_bgr is None:
                raise ValueError(f"Could not read {image_source}")
            img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        else:
            raise TypeError("Input must be a file path or NumPy image array.")
        gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
        blurred = cv2.GaussianBlur(gray, (3, 3), 0)
        # Depth (MiDaS): preprocess, predict, then resize the prediction back
        # to the original image resolution.
        t = transform(img_rgb).to(device)
        if t.ndim == 3:
            t = t.unsqueeze(0)  # add batch dim if the transform didn't
        with torch.no_grad():
            depth = midas(t)
        depth = torch.nn.functional.interpolate(
            depth.unsqueeze(1),
            size=gray.shape,
            mode="bicubic",
            align_corners=False
        ).squeeze()
        depth = depth.cpu().numpy()
        depth = cv2.normalize(depth, None, 0, 255, cv2.NORM_MINMAX)
        # Invert so that the mask emphasis flips relative to the raw MiDaS
        # output (presumably making far regions brighter — MiDaS predicts
        # inverse/relative depth; verify against model docs).
        depth_mask = cv2.convertScaleAbs(255 - depth)
        # Texture features: edges, high-frequency detail, and corners.
        canny = cv2.Canny(blurred, 40, 120)
        lap = cv2.convertScaleAbs(cv2.Laplacian(blurred, cv2.CV_64F))
        corners = cv2.cornerHarris(np.float32(blurred), 2, 3, 0.04)
        corners = cv2.dilate(corners, None)  # thicken corner responses
        corner_mask = np.zeros_like(gray)
        # Threshold at 1% of the max Harris response (standard heuristic).
        corner_mask[corners > 0.01 * corners.max()] = 255
        # Blend: 60% Canny + 40% Laplacian, OR in corners, then mix in 20%
        # depth and 20% random noise. Weights are tuned constants.
        edges_all = cv2.addWeighted(canny, 0.6, lap, 0.4, 0)
        mask = cv2.bitwise_or(edges_all, corner_mask)
        mask = cv2.addWeighted(mask, 0.8, depth_mask, 0.2, 0)
        noise = np.random.randint(0, 60, gray.shape, dtype=np.uint8)
        mask = cv2.addWeighted(mask, 1.0, noise, 0.2, 0)
        mask = cv2.convertScaleAbs(mask)  # clamp back to uint8
        if mask_only:
            return mask
        # Visualization mode: side-by-side original vs. mask figure.
        # NOTE(review): the caller owns the returned fig and should close it
        # (plt.close) to avoid leaking figures in long-running servers.
        fig, ax = plt.subplots(1, 2, figsize=(14, 6))
        ax[0].imshow(img_rgb)
        ax[0].set_title("Original Image")
        ax[0].axis("off")
        ax[1].imshow(mask, cmap="gray")
        ax[1].set_title("Texture + Depth Structural Mask")
        ax[1].axis("off")
        plt.tight_layout()
        return fig, mask

    # Batch support
    if isinstance(input_data, list):
        return [_process_single(item) for item in input_data]
    return _process_single(input_data)
# CLI entrypoint preserved for local use
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--input", type=str, required=True)
    parser.add_argument("--save", type=str, default="./mask_img.png")
    parser.add_argument("--mask_only", action="store_true")
    args = parser.parse_args()

    result = generate_texture_depth_mask(args.input, mask_only=args.mask_only)
    # mask-only mode returns the mask directly; otherwise (fig, mask).
    mask = result if args.mask_only else result[1]
    cv2.imwrite(args.save, mask)
    print(f"[OK] Saved mask to {args.save}")