File size: 2,648 Bytes
c20c148
 
 
dff0630
b12ac69
a601816
c20c148
 
 
 
 
712380b
ce450be
b12ac69
 
 
a601816
 
 
 
b12ac69
 
 
8c7592a
 
b12ac69
c20c148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
293fac7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import cv2
import numpy as np
import torch
from PIL import Image, ImageFilter
from rembg import remove
from scipy.ndimage import binary_dilation
from torchvision.transforms import Compose

from DPT.dpt.transforms import PrepareForNet, NormalizeImage, Resize


def create_mask(image, blur=0, padding=0):
    """Return a 1024x1024 foreground mask for `image` (a PIL image).

    The background is removed with rembg; the resulting mask can be grown
    outward by `padding` pixels (morphological dilation) and softened with a
    Gaussian blur of radius `blur`.

    Parameters:
        image: PIL.Image input picture.
        blur: Gaussian blur radius applied to the mask; 0 disables blurring.
        padding: number of dilation iterations used to grow the foreground
            region; 0 (default) leaves the mask untouched.

    Returns:
        PIL.Image (mode "L") mask resized to 1024x1024.
    """
    rm_bg = remove(np.array(image), post_process_mask=True, only_mask=True)
    # NOTE(review): assumes rembg returns a 0/1 mask here; if it returns
    # 0-255 uint8 the *255 below overflows — confirm against rembg version.
    rm_bg = Image.fromarray((rm_bg * 255).astype(np.uint8))
    rm_bg = rm_bg.resize(image.size, resample=Image.BILINEAR)

    # Grow the object outward by `padding` pixels. The original code used
    # np.pad here (which pads the array border, not the object) and fell
    # through to binary_dilation(iterations=-padding), which for the default
    # padding=0 means "dilate until convergence" in scipy — both wrong.
    if padding > 0:
        dilated = binary_dilation(np.array(rm_bg), iterations=padding)
        pil_mask = Image.fromarray((dilated * 255).astype(np.uint8))
    else:
        pil_mask = rm_bg

    if blur > 0:
        pil_mask = pil_mask.filter(ImageFilter.GaussianBlur(blur))
    return pil_mask.resize((1024, 1024))


def create_depth_map(image, model):
    """Predict a depth map for `image` with a DPT `model`.

    The image is resized/normalized for the network, inference runs without
    gradients, and the raw prediction is rescaled back to the input's spatial
    size, min-max normalized into a 16-bit range, reduced to 8 bits, and
    returned as a 1024x1024 PIL image.

    Parameters:
        image: numpy array (H, W, C) — `image.shape[:2]` is used as the
            interpolation target, so a PIL image will not work here.
        model: DPT network exposing `.forward(tensor)`.

    Returns:
        PIL.Image, 8-bit grayscale, resized to 1024x1024.
    """
    side = 1024
    preprocess = Compose(
        [
            Resize(
                side,
                side,
                resize_target=None,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method="minimal",
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            NormalizeImage(mean=[0.35, 0.35, 0.35], std=[0.35, 0.35, 0.35]),
            PrepareForNet(),
        ]
    )
    net_input = preprocess({"image": image})["image"]

    with torch.no_grad():
        batch = torch.from_numpy(net_input).unsqueeze(0)
        raw = model.forward(batch)
        # Upsample the prediction back to the original image resolution.
        upsampled = torch.nn.functional.interpolate(
            raw.unsqueeze(1),
            size=image.shape[:2],
            mode="bicubic",
            align_corners=False,
        )
        prediction = upsampled.squeeze().cpu().numpy()

    lo = prediction.min()
    hi = prediction.max()
    max_val = (2 ** 16) - 1  # normalize into the full 16-bit range first

    if hi - lo > np.finfo("float").eps:
        scaled = max_val * (prediction - lo) / (hi - lo)
    else:
        # Degenerate (constant) prediction: emit an all-zero depth map.
        scaled = np.zeros(prediction.shape, dtype=prediction.dtype)

    depth_u8 = (scaled / 256).astype('uint8')  # 16-bit -> 8-bit
    return Image.fromarray(depth_u8).resize((1024, 1024))