"""Foreground-mask (rembg) and depth-map (DPT) helpers."""
import cv2
import numpy as np
import torch
from PIL import Image, ImageFilter
from rembg import remove
from scipy.ndimage import binary_dilation
from torchvision.transforms import Compose

from DPT.dpt.transforms import NormalizeImage, PrepareForNet, Resize


def _mask_to_uint8(mask):
    """Convert a raw rembg mask to a uint8 array on a 0..255 scale.

    Integer masks are taken to already be on a 0..255 scale (multiplying a
    uint8 array by 255 would wrap around modulo 256). Boolean masks map
    True -> 255. Float masks are taken to be on a 0..1 scale unless their
    maximum exceeds 1, in which case they are treated as 0..255.
    """
    arr = np.asarray(mask)
    if arr.dtype == np.bool_:
        return np.where(arr, 255, 0).astype(np.uint8)
    if np.issubdtype(arr.dtype, np.integer):
        return np.clip(arr, 0, 255).astype(np.uint8)
    arr = arr.astype(np.float64)
    if arr.size and np.nanmax(arr) > 1.0:
        arr = arr / 255.0
    return (np.clip(arr, 0.0, 1.0) * 255.0).round().astype(np.uint8)


def create_mask(image, blur=0, padding=0, output_size=(1024, 1024)):
    """Build a foreground mask for *image* using rembg background removal.

    Args:
        image: Source picture as a PIL ``Image``. Its ``.size`` tuple is used
            to resize the raw mask back to the input resolution.
        blur: Gaussian blur radius applied to the final mask; ``<= 0`` skips
            blurring.
        padding: The sign selects the operation:
            * ``> 0``: add a border of ``padding`` background (0) pixels on
              every side of the mask canvas. The canvas grows; the object
              itself does not.
            * ``< 0``: binarize the mask and dilate the foreground with
              ``-padding`` iterations of ``scipy.ndimage.binary_dilation``.
            * ``0``: leave the mask unchanged.
            NOTE(review): this sign convention is unusual; confirm it is what
            callers expect.
        output_size: ``(width, height)`` of the returned mask.

    Returns:
        A uint8 PIL image with values in 0..255, resized to ``output_size``.
    """
    raw_mask = remove(np.array(image), post_process_mask=True, only_mask=True)
    mask_img = Image.fromarray(_mask_to_uint8(raw_mask))
    mask = np.asarray(mask_img.resize(image.size, resample=Image.BILINEAR))

    if padding > 0:
        # Grow the canvas by `padding` background pixels on each side.
        mask = np.pad(mask, pad_width=padding, mode="constant", constant_values=0)
    elif padding < 0:
        # Binarize at mid-gray so the soft bilinear edge is not all treated
        # as foreground, then dilate by |padding| iterations.
        dilated = binary_dilation(mask >= 128, iterations=-padding)
        mask = np.where(dilated, 255, 0).astype(np.uint8)
    # padding == 0 is deliberately a no-op: binary_dilation with
    # iterations < 1 repeats until convergence and would flood the frame.

    pil_mask = Image.fromarray(mask)
    if blur > 0:
        pil_mask = pil_mask.filter(ImageFilter.GaussianBlur(blur))
    return pil_mask.resize(output_size)


def _model_device(model):
    """Return the device of *model*'s first parameter, or CPU if it has none."""
    try:
        return next(model.parameters()).device
    except (AttributeError, StopIteration):
        return torch.device("cpu")


def _depth_to_uint8(prediction):
    """Min-max normalize *prediction* to an 8-bit array.

    Values are first stretched to the 16-bit range 0..65535, then divided by
    256 and truncated. A constant prediction (range not above float64 eps)
    yields all zeros.
    """
    depth_min = prediction.min()
    depth_max = prediction.max()
    bits = 2
    max_val = (2 ** (8 * bits)) - 1
    if depth_max - depth_min > np.finfo("float").eps:
        out = max_val * (prediction - depth_min) / (depth_max - depth_min)
    else:
        out = np.zeros(prediction.shape, dtype=prediction.dtype)
    return (out / 256).astype("uint8")


def create_depth_map(image, model, output_size=(1024, 1024)):
    """Predict a relative depth map for *image* with a DPT-style model.

    Args:
        image: Input as a NumPy array. ``image.shape[:2]`` is used as the
            (height, width) to resample the prediction back to. Presumably an
            RGB float image in [0, 1] as DPT's transforms expect; TODO
            confirm against callers.
        model: A torch module. The input is moved to the device of its first
            parameter (CPU if it has none). Its output is assumed to be
            (N, H', W'), since ``unsqueeze(1)`` must yield the 4-D input
            that bicubic interpolation requires; confirm.
        output_size: ``(width, height)`` of the returned image.

    Returns:
        An 8-bit PIL image in which the smallest predicted value maps to 0
        and the largest to 255 (all zeros if the prediction is constant),
        resized to ``output_size``.
    """
    net_w = net_h = 1024
    # NOTE(review): mean/std of 0.35 are unusual; confirm they match the
    # normalization the checkpoint was trained with.
    normalization = NormalizeImage(mean=[0.35, 0.35, 0.35], std=[0.35, 0.35, 0.35])
    transform = Compose(
        [
            Resize(
                net_w,
                net_h,
                resize_target=None,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method="minimal",
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            normalization,
            PrepareForNet(),
        ]
    )
    img_input = transform({"image": image})["image"]

    with torch.no_grad():
        # Put the input on the same device as the model so GPU models work.
        sample = torch.from_numpy(img_input).unsqueeze(0).to(_model_device(model))
        prediction = model(sample)
        prediction = (
            torch.nn.functional.interpolate(
                prediction.unsqueeze(1),
                size=image.shape[:2],
                mode="bicubic",
                align_corners=False,
            )
            .squeeze()
            .cpu()
            .numpy()
        )

    return Image.fromarray(_depth_to_uint8(prediction)).resize(output_size)