Spaces:
Sleeping
Sleeping
| import os | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from skimage import io | |
| from torch.autograd import Variable | |
| from torch.utils.data import DataLoader | |
| from torchvision import transforms | |
| from .u2net import RescaleT, ToTensorLab, SalObjDataset, normPRED, load_human_segm_model | |
def pred_to_image(predictions, image_path):
    """Turn a saliency prediction into an RGB mask the size of the source image.

    :param predictions: prediction tensor; it is squeezed, moved to the CPU and
        multiplied by 255 before conversion (values presumably in [0, 1], e.g.
        the output of ``normPRED`` -- TODO confirm for other callers)
    :param image_path: path of the image the prediction belongs to; it is read
        only to obtain its height and width
    :return: ``PIL.Image`` in RGB mode, bilinearly resized to the source size
    """
    scaled = predictions.squeeze().cpu().data.numpy() * 255
    mask = Image.fromarray(scaled).convert('RGB')
    # Only the first two axes (rows, columns) of the decoded array matter here.
    height, width = io.imread(image_path).shape[:2]
    return mask.resize((width, height), resample=Image.BILINEAR)
def _predict_saliency(net, inputs, device):
    """Run the network on one batch and return the normalized first output map.

    :param net: model returned by ``load_human_segm_model``; it is called with
        one batch and indexed as a sequence whose first element is used
        (NOTE(review): assumes it is already in eval mode -- confirm in
        ``load_human_segm_model``)
    :param inputs: image batch taken from the dataloader's ``'image'`` entry
    :param device: torch device the batch is moved to
    :return: ``normPRED`` of channel 0 of the first output
    """
    # no_grad: pure inference, so building an autograd graph only wastes memory.
    with torch.no_grad():
        # Same float32 cast as before; `device` already selects CUDA when
        # available, so the separate .cuda()/Variable branches are unnecessary.
        inputs = inputs.to(device=device, dtype=torch.float32)
        first_output = net(inputs)[0]
        # Remaining side outputs are released when this function returns.
        return normPRED(first_output[:, 0, :, :])


def _composite_on_background(mask, image_path, background):
    """Replace pixels the mask marks as background; keep the rest of the original.

    A pixel counts as background when the saturating uint8 difference
    ``mask - image`` is zero in every colour channel (i.e. the mask value does
    not exceed the image value in any channel).

    :param mask: RGB ``PIL.Image`` with the same size as the image on disk
    :param image_path: path of the original image
    :param background: RGBA 4-tuple written into background pixels
    :return: RGBA ``PIL.Image``
    :raises ValueError: if the image cannot be decoded or sizes disagree
    """
    mask_bgr = cv2.cvtColor(np.array(mask), cv2.COLOR_RGB2BGR)
    # cv2.imread rotates according to EXIF orientation by default, while
    # Image.open below does not; ignoring it keeps both reads aligned.
    source_bgr = cv2.imread(
        image_path, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION
    )
    if source_bgr is None:
        raise ValueError(f"could not read image: {image_path}")
    if source_bgr.shape != mask_bgr.shape:
        raise ValueError(
            f"mask shape {mask_bgr.shape} does not match image shape "
            f"{source_bgr.shape} for {image_path}"
        )
    difference = cv2.subtract(mask_bgr, source_bgr)
    is_background = ~difference.any(axis=2)

    with Image.open(image_path) as original:
        # np.array copies, so the result is writable.
        rgba = np.array(original.convert("RGBA"))
    if rgba.shape[:2] != is_background.shape:
        raise ValueError(
            f"decoded image size {rgba.shape[:2]} does not match mask size "
            f"{is_background.shape} for {image_path}"
        )
    # Vectorized equivalent of the former per-pixel getdata()/putdata() loop.
    rgba[is_background] = background
    return Image.fromarray(rgba)


def segment_human(image_path, output_dir, *, background=(231, 231, 231, 231)):
    """
    Segment human using U-2-Net and save the result as ``<output_dir>/<stem>.png``.

    Pixels classified as background are replaced by ``background``; all other
    pixels keep their original RGBA value.

    :param image_path: image path
    :param output_dir: output directory (must already exist)
    :param background: RGBA 4-tuple used for background pixels; the default
        has alpha 231, i.e. not fully opaque
    """
    model_name = "u2net"
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    images = [image_path]

    # 1. dataloader
    dataset = SalObjDataset(
        img_name_list=images,
        lbl_name_list=[],
        transform=transforms.Compose([RescaleT(320), ToTensorLab(flag=0)]),
    )
    # num_workers=0: loading a single image does not justify a worker process.
    loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
    net = load_human_segm_model(device, model_name)

    # 2. inference
    for index, batch in enumerate(loader):
        path = images[index]
        # basename/splitext handle either separator on Windows and keep
        # multi-dot names ("a.b.jpg" -> "a.b") intact.
        file_name = os.path.basename(path)
        print("inferencing:", file_name)

        pred = _predict_saliency(net, batch['image'], device)
        mask = pred_to_image(pred, path)
        result = _composite_on_background(mask, path, background)

        stem = os.path.splitext(file_name)[0]
        result.save(os.path.join(output_dir, f"{stem}.png"))