"""
@Author  : Peike Li
@Contact : peike.li@yahoo.com
@File    : simple_extractor.py
@Time    : 8/30/19 8:59 PM
@Desc    : Simple Extractor (modified for single image input)
"""

import os
from collections import OrderedDict

import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

import networks
from preprocess.utils.transforms import transform_logits, get_affine_transform


class SimpleFileDataset(Dataset):
    """Dataset that wraps a single image file for parsing inference."""

    def __init__(self, image_path, input_size=[512, 512], transform=None):
        self.image_path = image_path
        self.input_size = np.asarray(input_size)
        self.transform = transform
        self.aspect_ratio = input_size[1] * 1.0 / input_size[0]
        self.img_name = os.path.basename(image_path)

    def __len__(self):
        return 1

    def _box2cs(self, box):
        x, y, w, h = box[:4]
        return self._xywh2cs(x, y, w, h)

    def _xywh2cs(self, x, y, w, h):
        # Convert an (x, y, w, h) box to a center point and a scale padded
        # to match the target aspect ratio.
        center = np.zeros((2), dtype=np.float32)
        center[0] = x + w * 0.5
        center[1] = y + h * 0.5
        if w > self.aspect_ratio * h:
            h = w * 1.0 / self.aspect_ratio
        elif w < self.aspect_ratio * h:
            w = h * self.aspect_ratio
        scale = np.array([w, h], dtype=np.float32)
        return center, scale

    def __getitem__(self, index):
        img = cv2.imread(self.image_path, cv2.IMREAD_COLOR)
        if img is None:
            raise FileNotFoundError("Could not read image: {}".format(self.image_path))
        h, w, _ = img.shape

        # Use the whole image as the person box and warp it to the network
        # input size with an affine transform.
        person_center, s = self._box2cs([0, 0, w - 1, h - 1])
        r = 0
        trans = get_affine_transform(person_center, s, r, self.input_size)
        input_img = cv2.warpAffine(
            img,
            trans,
            (int(self.input_size[1]), int(self.input_size[0])),
            flags=cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_CONSTANT,
            borderValue=(0, 0, 0))
        input_img = self.transform(input_img)

        meta = {
            'name': self.img_name,
            'center': person_center,
            'height': h,
            'width': w,
            'scale': s,
            'rotation': r
        }
        return input_img, meta


dataset_settings = {
    'atr': {
        'input_size': [512, 512],
        'num_classes': 18,
        'label': ['Background', 'Hat', 'Hair', 'Sunglasses', 'Upper-clothes', 'Skirt', 'Pants', 'Dress', 'Belt',
                  'Left-shoe', 'Right-shoe', 'Face', 'Left-leg', 'Right-leg', 'Left-arm', 'Right-arm', 'Bag', 'Scarf']
    }
}


def get_palette(num_cls):
    """Return a binary palette for a PIL 'P' image: the class with index
    `num_cls` is white (255, 255, 255) and every other class is black."""
    n = 18  # total number of ATR classes
    palette = [0] * (n * 3)
    if num_cls:
        palette[num_cls * 3 + 0] = 255
        palette[num_cls * 3 + 1] = 255
        palette[num_cls * 3 + 2] = 255
    return palette
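

# Illustrative note: with the binary palette above, get_palette(4) is all zeros
# except indices 12-14 (value 255), so a 'P' image using it renders class 4
# ('Upper-clothes' in the ATR label list) as white and every other class as black.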


def masking(image_path, class_num=0):
    num_classes = dataset_settings['atr']['num_classes']
    input_size = dataset_settings['atr']['input_size']
    label = dataset_settings['atr']['label']
    print("Evaluating {} classes: {}".format(num_classes, label))

    # Build the parsing network and load the pretrained ATR checkpoint.
    model = networks.init_model('resnet101', num_classes=num_classes, pretrained=None)
    state_dict = torch.load('./ckpts/exp-schp-201908301523-atr.pth')['state_dict']

    # The checkpoint was saved from a DataParallel model, so strip the
    # 'module.' prefix from every parameter name before loading.
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:]  # remove 'module.'
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)
    model.cuda()
    model.eval()

    # The image is read by OpenCV in BGR order, so the normalization
    # statistics are listed in BGR as well.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229])
    ])
    dataset = SimpleFileDataset(image_path=image_path, input_size=input_size, transform=transform)
    dataloader = DataLoader(dataset)

    os.makedirs('./outputs', exist_ok=True)

    palette = get_palette(class_num)
    with torch.no_grad():
        for idx, batch in enumerate(tqdm(dataloader)):
            image, meta = batch
            img_name = meta['name'][0]
            c = meta['center'].numpy()[0]
            s = meta['scale'].numpy()[0]
            w = meta['width'].numpy()[0]
            h = meta['height'].numpy()[0]

            # Forward pass; take the fused parsing branch and upsample it
            # back to the network input size.
            output = model(image.cuda())
            upsample = torch.nn.Upsample(size=input_size, mode='bilinear', align_corners=True)
            upsample_output = upsample(output[0][-1][0].unsqueeze(0))
            upsample_output = upsample_output.squeeze()
            upsample_output = upsample_output.permute(1, 2, 0)  # CHW -> HWC

            # Undo the affine crop so the logits align with the original image,
            # then take the per-pixel argmax to get parsing labels.
            logits_result = transform_logits(upsample_output.data.cpu().numpy(), c, s, w, h, input_size=input_size)
            parsing_result = np.argmax(logits_result, axis=2)
            parsing_result_path = os.path.join('./outputs', os.path.splitext(img_name)[0] + '.png')
            output_img = Image.fromarray(np.asarray(parsing_result, dtype=np.uint8))
            output_img.putpalette(palette)
            output_img.save(parsing_result_path)
            # With the binary palette, converting to 'L' gives a mask that is
            # 255 for the selected class and 0 elsewhere.
            gray_img = output_img.convert('L')

    return gray_img
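

# Example usage: a minimal sketch, assuming the checkpoint under ./ckpts is in
# place and CUDA is available. './person.jpg' and the output filename are
# hypothetical; class_num=4 selects 'Upper-clothes' in the ATR label list, so
# the returned mask is white wherever upper-body clothing was detected.
if __name__ == '__main__':
    mask = masking('./person.jpg', class_num=4)
    mask.save('./outputs/upper_clothes_mask.png')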