# Structura-AI / depth_texture_mask.py
# (Hugging Face page header from AurevinP's upload, commit d13d7e1 verified,
#  "Upload the api endpoint and app." — commented out so the file parses as Python.)
# depth_texture_mask.py
# Modified: lazy MiDaS init and safe for server use.
import os
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
# Globals (initialized by init_midas)
midas = None  # loaded MiDaS model (torch.nn.Module) — None until init_midas runs
midas_transforms = None  # transform bundle from the intel-isl/MiDaS hub repo
transform = None  # preprocessing callable selected for the chosen model variant
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # default; init_midas may override
_midas_initialized = False  # one-shot guard so init_midas loads the model only once
def init_midas(model_name="DPT_Hybrid", device_override=None, force_reload=False):
    """
    Initialize/load the MiDaS model and transforms into global variables.

    Call this once (e.g., at FastAPI startup). Safe to call repeatedly:
    subsequent calls are no-ops unless force_reload=True, except that a
    changed device_override moves the already-loaded model.

    Parameters
    ----------
    model_name : str
        torch.hub model name in the intel-isl/MiDaS repo
        (e.g. "DPT_Hybrid", "DPT_Large", "MiDaS_small").
    device_override : torch.device | str | None
        Explicit device to run on; autodetects CUDA when None.
    force_reload : bool
        Reload the model even if already initialized.

    Raises
    ------
    RuntimeError
        If no usable preprocessing transform is exposed by the hub repo.
    """
    global midas, midas_transforms, transform, device, _midas_initialized

    # Normalize to a torch.device so string overrides ("cuda", "cpu") work too.
    requested_device = (
        torch.device(device_override)
        if device_override is not None
        else torch.device("cuda" if torch.cuda.is_available() else "cpu")
    )

    if _midas_initialized and not force_reload:
        # BUGFIX: previously the global `device` was updated before this early
        # return without moving the model, so later inference tensors landed on
        # a different device than the weights. Honor the change by moving it.
        if requested_device != device:
            device = requested_device
            midas.to(device)
        return

    device = requested_device

    # torch.hub downloads the weights if not already cached locally.
    midas = torch.hub.load("intel-isl/MiDaS", model_name, pretrained=True)
    midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")

    # DPT variants and MiDaS-small expose differently named transforms.
    if hasattr(midas_transforms, "dpt_transform"):
        transform = midas_transforms.dpt_transform
    elif hasattr(midas_transforms, "small_transform"):
        transform = midas_transforms.small_transform
    else:
        # fallback: try a generic 'transform'
        transform = getattr(midas_transforms, "transform", None)
    if transform is None:
        # Fail loudly at startup instead of with a cryptic NoneType error
        # on the first inference request.
        raise RuntimeError(
            "Could not locate a MiDaS preprocessing transform "
            "(expected dpt_transform, small_transform, or transform)."
        )

    midas.to(device).eval()
    _midas_initialized = True
def _ensure_initialized():
    """Lazily load MiDaS on first use; no-op once initialization has run."""
    if _midas_initialized:
        return
    init_midas()
def generate_texture_depth_mask(input_data, mask_only=False):
    """
    Generate a texture + depth structural mask.

    Supports:
      - File paths (.jpg, .png)
      - NumPy arrays: (H,W) grayscale, (H,W,3) RGB, or (H,W,4) RGBA
      - List of inputs (batch mode)

    Returns:
      mask_only=False:
        - Single: (fig, mask)
        - Batch: list of (fig, mask)
      mask_only=True:
        - Single: mask
        - Batch: list of masks

    Raises:
      ValueError: if a file path cannot be read.
      TypeError: if the input is neither a path nor a NumPy array.

    NOTE: the mask includes a random-noise component, so results are not
    deterministic across calls. In visualization mode the caller owns the
    returned figure and should plt.close(fig) it (important in servers).
    """
    _ensure_initialized()

    def _process_single(image_source):
        # --- Load image (array or file path) ---
        if isinstance(image_source, np.ndarray):
            img_rgb = image_source
            if img_rgb.ndim == 2:
                # Grayscale arrays used to crash at COLOR_RGB2BGR; promote to
                # 3-channel RGB so the rest of the pipeline works unchanged.
                img_rgb = cv2.cvtColor(img_rgb, cv2.COLOR_GRAY2RGB)
            elif img_rgb.shape[-1] == 4:
                img_rgb = img_rgb[:, :, :3]  # drop the alpha channel
            img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
        elif isinstance(image_source, str) and os.path.isfile(image_source):
            img_bgr = cv2.imread(image_source)
            if img_bgr is None:
                raise ValueError(f"Could not read {image_source}")
            img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        else:
            raise TypeError("Input must be a file path or NumPy image array.")

        gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
        blurred = cv2.GaussianBlur(gray, (3, 3), 0)

        # --- Depth (MiDaS) ---
        t = transform(img_rgb).to(device)
        if t.ndim == 3:
            t = t.unsqueeze(0)  # add the batch dimension the model expects
        with torch.no_grad():
            depth = midas(t)
            # Resize the depth prediction back to the input resolution.
            depth = torch.nn.functional.interpolate(
                depth.unsqueeze(1),
                size=gray.shape,
                mode="bicubic",
                align_corners=False
            ).squeeze()
        depth = depth.cpu().numpy()
        depth = cv2.normalize(depth, None, 0, 255, cv2.NORM_MINMAX)
        # Invert so near (high-depth-value) regions darken the mask less.
        depth_mask = cv2.convertScaleAbs(255 - depth)

        # --- Texture features: edges, fine detail, corners ---
        canny = cv2.Canny(blurred, 40, 120)
        lap = cv2.convertScaleAbs(cv2.Laplacian(blurred, cv2.CV_64F))
        corners = cv2.cornerHarris(np.float32(blurred), 2, 3, 0.04)
        corners = cv2.dilate(corners, None)
        corner_mask = np.zeros_like(gray)
        corner_mask[corners > 0.01 * corners.max()] = 255

        # --- Blend texture, depth, and a small amount of noise ---
        edges_all = cv2.addWeighted(canny, 0.6, lap, 0.4, 0)
        mask = cv2.bitwise_or(edges_all, corner_mask)
        mask = cv2.addWeighted(mask, 0.8, depth_mask, 0.2, 0)
        noise = np.random.randint(0, 60, gray.shape, dtype=np.uint8)
        mask = cv2.addWeighted(mask, 1.0, noise, 0.2, 0)
        mask = cv2.convertScaleAbs(mask)

        if mask_only:
            return mask

        # --- Visualization mode ---
        fig, ax = plt.subplots(1, 2, figsize=(14, 6))
        ax[0].imshow(img_rgb)
        ax[0].set_title("Original Image")
        ax[0].axis("off")
        ax[1].imshow(mask, cmap="gray")
        ax[1].set_title("Texture + Depth Structural Mask")
        ax[1].axis("off")
        plt.tight_layout()
        return fig, mask

    # Batch support
    if isinstance(input_data, list):
        return [_process_single(item) for item in input_data]
    return _process_single(input_data)
# CLI entrypoint preserved for local use
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--input", type=str, required=True,
                        help="Path to the input image")
    parser.add_argument("--save", type=str, default="./mask_img.png",
                        help="Where to write the resulting mask image")
    parser.add_argument("--mask_only", action="store_true",
                        help="Skip building the matplotlib comparison figure")
    args = parser.parse_args()

    output = generate_texture_depth_mask(args.input, mask_only=args.mask_only)
    if args.mask_only:
        mask = output
    else:
        fig, mask = output
        # The CLI only saves the mask; close the figure so it is not leaked
        # (pyplot keeps figures alive in its global registry otherwise).
        plt.close(fig)
    cv2.imwrite(args.save, mask)
    print(f"[OK] Saved mask to {args.save}")