import os

import cv2
import numpy as np
import torch
from PIL import Image
from skimage import io
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms

from .u2net import RescaleT, ToTensorLab, SalObjDataset, normPRED, load_human_segm_model


def pred_to_image(predictions, image_path):
    """Resize a normalized prediction map to the original image size as an RGB mask."""
    pred_np = (predictions.squeeze().detach().cpu().numpy() * 255).astype(np.uint8)
    im = Image.fromarray(pred_np).convert('RGB')
    image = io.imread(image_path)
    return im.resize((image.shape[1], image.shape[0]), resample=Image.BILINEAR)


def segment_human(image_path, output_dir):
    """
    Segment a human from an image using U-2-Net and save the result as a PNG.

    :param image_path: path to the input image
    :param output_dir: directory where the segmented PNG is written
    """
    model_name = "u2net"
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    images = [image_path]

    # 1. dataloader
    test_salobj_dataset = SalObjDataset(img_name_list=images,
                                        lbl_name_list=[],
                                        transform=transforms.Compose([RescaleT(320),
                                                                      ToTensorLab(flag=0)])
                                        )
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    net = load_human_segm_model(device, model_name)
    net.eval()  # inference mode: disable dropout, use running batch-norm stats

    # 2. inference
    for i_test, data_test in enumerate(test_salobj_dataloader):
        print("inferencing:", os.path.basename(images[i_test]))

        inputs_test = data_test['image'].type(torch.FloatTensor).to(device)

        # no_grad avoids building the autograd graph during inference
        with torch.no_grad():
            d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

        # normalization
        pred = d1[:, 0, :, :]
        pred = normPRED(pred)

        mask = pred_to_image(pred, images[i_test])
        mask_cv2 = cv2.cvtColor(np.array(mask), cv2.COLOR_RGB2BGR)

        subimage = cv2.subtract(mask_cv2, cv2.imread(images[i_test]))
        original = Image.open(images[i_test])
        subimage = Image.fromarray(cv2.cvtColor(subimage, cv2.COLOR_BGR2RGB))

        subimage = subimage.convert("RGBA")
        original = original.convert("RGBA")

        newdata = []
        for sub_px, og_px in zip(subimage.getdata(), original.getdata()):
            # Pixels left black by the subtraction are background:
            # fill them with light grey, keep the original pixel elsewhere.
            if sub_px[:3] == (0, 0, 0):
                newdata.append((231, 231, 231, 231))
            else:
                newdata.append(og_px)
        subimage.putdata(newdata)

        filename = os.path.splitext(os.path.basename(images[i_test]))[0]
        subimage.save(os.path.join(output_dir, f"{filename}.png"))

        del d1, d2, d3, d4, d5, d6, d7