# depth_texture_mask.py
# Modified: lazy MiDaS init and safe for server use.
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch

# Globals (initialized by init_midas)
midas = None
midas_transforms = None
transform = None
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
_midas_initialized = False


def init_midas(model_name="DPT_Hybrid", device_override=None, force_reload=False):
    """
    Initialize/load the MiDaS model and transforms into global variables.

    Call this once (e.g., at FastAPI startup). Safe to call repeatedly:
    subsequent calls are no-ops unless force_reload is True, although a
    changed device_override will move the already-loaded model.

    Args:
        model_name: torch.hub MiDaS model variant, e.g. "DPT_Hybrid",
            "DPT_Large", or "MiDaS_small".
        device_override: optional torch.device or device string (e.g.
            "cuda:0") to force; defaults to CUDA when available.
        force_reload: reload model and transforms even if already set up.

    Raises:
        RuntimeError: if the MiDaS transforms module exposes no usable
            input transform.
    """
    global midas, midas_transforms, transform, device, _midas_initialized

    if device_override is not None:
        # Coerce so plain strings like "cuda:0" behave the same as a
        # torch.device instance.
        device = torch.device(device_override)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if _midas_initialized and not force_reload:
        # Keep the already-loaded model in sync with any device change,
        # otherwise the global `device` and the model's actual device
        # would diverge and inference would fail.
        midas.to(device)
        return

    # Use torch.hub to load MiDaS transforms & model.
    # NOTE: this will download if not cached.
    midas = torch.hub.load("intel-isl/MiDaS", model_name, pretrained=True)
    midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")

    # Choose the appropriate transform (DPT / MiDaS-small use different names).
    if hasattr(midas_transforms, "dpt_transform"):
        transform = midas_transforms.dpt_transform
    elif hasattr(midas_transforms, "small_transform"):
        transform = midas_transforms.small_transform
    else:
        # Fallback: try a generic 'transform'.
        transform = getattr(midas_transforms, "transform", None)

    if transform is None:
        # Fail loudly now instead of crashing with an opaque TypeError
        # the first time generate_texture_depth_mask() is called.
        raise RuntimeError(
            "Could not find a usable MiDaS input transform "
            "(expected dpt_transform, small_transform, or transform)."
        )

    midas.to(device).eval()
    _midas_initialized = True


def _ensure_initialized():
    """Lazily initialize MiDaS on first use."""
    if not _midas_initialized:
        init_midas()


def generate_texture_depth_mask(input_data, mask_only=False):
    """
    Generate a texture + depth structural mask.

    Supports:
        - File paths (.jpg, .png)
        - NumPy arrays (H,W,C) RGB or RGBA
        - List of inputs (batch mode)

    Returns:
        mask_only=False:
            - Single: (fig, mask)
            - Batch: list of (fig, mask)
        mask_only=True:
            - Single: mask
            - Batch: list of masks

    Raises:
        TypeError: input is neither a file path nor a NumPy array.
        ValueError: a file path could not be read by OpenCV.
    """
    _ensure_initialized()

    def _process_single(image_source):
        # Load image (array or file path).
        if isinstance(image_source, np.ndarray):
            # Assumes the array is RGB(A), uint8 — TODO confirm callers.
            img_rgb = image_source
            if img_rgb.shape[-1] == 4:
                # Drop the alpha channel.
                img_rgb = img_rgb[:, :, :3]
            img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
        elif isinstance(image_source, str) and os.path.isfile(image_source):
            img_bgr = cv2.imread(image_source)
            if img_bgr is None:
                raise ValueError(f"Could not read {image_source}")
            img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        else:
            raise TypeError("Input must be a file path or NumPy image array.")

        gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
        blurred = cv2.GaussianBlur(gray, (3, 3), 0)

        # Depth (MiDaS): run inference, then resize the prediction back
        # to the original image resolution.
        t = transform(img_rgb).to(device)
        if t.ndim == 3:
            # Some transforms return an unbatched (C,H,W) tensor.
            t = t.unsqueeze(0)
        with torch.no_grad():
            depth = midas(t)
            depth = torch.nn.functional.interpolate(
                depth.unsqueeze(1),
                size=gray.shape,
                mode="bicubic",
                align_corners=False,
            ).squeeze()
        depth = depth.cpu().numpy()
        depth = cv2.normalize(depth, None, 0, 255, cv2.NORM_MINMAX)
        # Invert: near (high MiDaS value) -> dark, far -> bright.
        depth_mask = cv2.convertScaleAbs(255 - depth)

        # Texture features: edges (Canny), fine detail (Laplacian),
        # and corner responses (Harris).
        canny = cv2.Canny(blurred, 40, 120)
        lap = cv2.convertScaleAbs(cv2.Laplacian(blurred, cv2.CV_64F))
        corners = cv2.cornerHarris(np.float32(blurred), 2, 3, 0.04)
        corners = cv2.dilate(corners, None)
        corner_mask = np.zeros_like(gray)
        corner_mask[corners > 0.01 * corners.max()] = 255

        # Blend edge maps, union with corners, then mix in depth and a
        # light random-noise dither.
        edges_all = cv2.addWeighted(canny, 0.6, lap, 0.4, 0)
        mask = cv2.bitwise_or(edges_all, corner_mask)
        mask = cv2.addWeighted(mask, 0.8, depth_mask, 0.2, 0)
        noise = np.random.randint(0, 60, gray.shape, dtype=np.uint8)
        # NOTE: unseeded noise makes output non-deterministic by design.
        mask = cv2.addWeighted(mask, 1.0, noise, 0.2, 0)
        mask = cv2.convertScaleAbs(mask)

        if mask_only:
            return mask

        # Visualization mode: side-by-side original vs. mask.
        fig, ax = plt.subplots(1, 2, figsize=(14, 6))
        ax[0].imshow(img_rgb)
        ax[0].set_title("Original Image")
        ax[0].axis("off")
        ax[1].imshow(mask, cmap="gray")
        ax[1].set_title("Texture + Depth Structural Mask")
        ax[1].axis("off")
        plt.tight_layout()
        return fig, mask

    # Batch support
    if isinstance(input_data, list):
        return [_process_single(item) for item in input_data]
    return _process_single(input_data)


# CLI entrypoint preserved for local use
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--input", type=str, required=True)
    parser.add_argument("--save", type=str, default="./mask_img.png")
    parser.add_argument("--mask_only", action="store_true")
    args = parser.parse_args()

    output = generate_texture_depth_mask(args.input, mask_only=args.mask_only)
    if args.mask_only:
        mask = output
    else:
        fig, mask = output
        # Release the figure; only the mask is saved, and open figures
        # accumulate memory in matplotlib.
        plt.close(fig)

    cv2.imwrite(args.save, mask)
    print(f"[OK] Saved mask to {args.save}")