#!/usr/bin/env python
# -*- encoding: utf-8 -*-

"""
@Author  :   Peike Li
@Contact :   peike.li@yahoo.com
@File    :   simple_extractor.py
@Time    :   8/30/19 8:59 PM
@Desc    :   Simple Extractor (modified for single image input)
"""

import os
import argparse
from collections import OrderedDict

import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

import networks
from preprocess.utils.transforms import transform_logits, get_affine_transform


class SimpleFileDataset(Dataset):
    """Single-image dataset: wraps one image file so it can be fed through a DataLoader."""

    def __init__(self, image_path, input_size=[512, 512], transform=None):
        self.image_path = image_path
        self.input_size = np.asarray(input_size)
        self.transform = transform
        self.aspect_ratio = input_size[1] * 1.0 / input_size[0]
        self.img_name = os.path.basename(image_path)

    def __len__(self):
        return 1

    def _box2cs(self, box):
        x, y, w, h = box[:4]
        return self._xywh2cs(x, y, w, h)

    def _xywh2cs(self, x, y, w, h):
        # Convert an (x, y, w, h) box to a center/scale pair, padding the box
        # so it matches the network's input aspect ratio.
        center = np.zeros((2), dtype=np.float32)
        center[0] = x + w * 0.5
        center[1] = y + h * 0.5
        if w > self.aspect_ratio * h:
            h = w * 1.0 / self.aspect_ratio
        elif w < self.aspect_ratio * h:
            w = h * self.aspect_ratio
        scale = np.array([w, h], dtype=np.float32)
        return center, scale

    def __getitem__(self, index):
        img = cv2.imread(self.image_path, cv2.IMREAD_COLOR)
        h, w, _ = img.shape

        # Treat the whole image as the person box, then warp it to the model input size.
        person_center, s = self._box2cs([0, 0, w - 1, h - 1])
        r = 0
        trans = get_affine_transform(person_center, s, r, self.input_size)
        input = cv2.warpAffine(
            img,
            trans,
            (int(self.input_size[1]), int(self.input_size[0])),
            flags=cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_CONSTANT,
            borderValue=(0, 0, 0))

        input = self.transform(input)
        meta = {
            'name': self.img_name,
            'center': person_center,
            'height': h,
            'width': w,
            'scale': s,
            'rotation': r
        }

        return input, meta


dataset_settings = {
    'atr': {
        'input_size': [512, 512],
        'num_classes': 18,
        'label': ['Background', 'Hat', 'Hair', 'Sunglasses', 'Upper-clothes', 'Skirt', 'Pants', 'Dress', 'Belt',
                  'Left-shoe', 'Right-shoe', 'Face', 'Left-leg', 'Right-leg', 'Left-arm', 'Right-arm', 'Bag', 'Scarf']
    }
}


def get_palette(num_cls):
    """Return a binary palette: class `num_cls` is rendered white, every other class black
    (for num_cls == 0 the mask stays entirely black)."""
    n = 18  # number of ATR classes
    palette = [0] * (n * 3)
    j = num_cls
    lab = num_cls
    palette[j * 3 + 0] = 0
    palette[j * 3 + 1] = 0
    palette[j * 3 + 2] = 0
    i = 0
    while lab:
        palette[j * 3 + 0] = 255
        palette[j * 3 + 1] = 255
        palette[j * 3 + 2] = 255
        i += 1
        lab >>= 3
    return palette


def masking(image_path, class_num=0):
    num_classes = dataset_settings['atr']['num_classes']
    input_size = dataset_settings['atr']['input_size']
    label = dataset_settings['atr']['label']
    print("Evaluating total class number {} with {}".format(num_classes, label))

    model = networks.init_model('resnet101', num_classes=num_classes, pretrained=None)

    # The checkpoint was saved from a DataParallel model, so strip the `module.` prefix.
    state_dict = torch.load('./ckpts/exp-schp-201908301523-atr.pth')['state_dict']
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:]  # remove `module.`
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)

    model.cuda()
    model.eval()

    # Images are read with cv2 (BGR), hence the reversed ImageNet mean/std ordering.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229])
    ])
    dataset = SimpleFileDataset(image_path=image_path, input_size=input_size, transform=transform)
    dataloader = DataLoader(dataset)

    if not os.path.exists('./outputs'):
        os.makedirs('./outputs')

    palette = get_palette(class_num)
    with torch.no_grad():
        for idx, batch in enumerate(tqdm(dataloader)):
            image, meta = batch
            img_name = meta['name'][0]
            c = meta['center'].numpy()[0]
            s = meta['scale'].numpy()[0]
            w = meta['width'].numpy()[0]
            h = meta['height'].numpy()[0]

            output = model(image.cuda())
            # Take the fused parsing logits, upsample them to the network input size,
            # then map them back to the original image geometry before the argmax.
            upsample = torch.nn.Upsample(size=input_size, mode='bilinear', align_corners=True)
            upsample_output = upsample(output[0][-1][0].unsqueeze(0))
            upsample_output = upsample_output.squeeze()
            upsample_output = upsample_output.permute(1, 2, 0)  # CHW -> HWC

            logits_result = transform_logits(upsample_output.data.cpu().numpy(), c, s, w, h, input_size=input_size)
            parsing_result = np.argmax(logits_result, axis=2)
            parsing_result_path = os.path.join('./outputs', img_name[:-4] + '.png')
            output_img = Image.fromarray(np.asarray(parsing_result, dtype=np.uint8))
            output_img.putpalette(palette)
            output_img.save(parsing_result_path)
            gray_img = output_img.convert('L')
    return gray_img
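

# A minimal usage sketch, assuming the script is run from the project root with the
# ATR checkpoint in ./ckpts and a CUDA device available. The flag names below
# (--image, --class-num) and the output filename are illustrative assumptions,
# not part of the original interface.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Single-image ATR parsing mask extractor')
    parser.add_argument('--image', type=str, required=True, help='path to the input image')
    parser.add_argument('--class-num', type=int, default=4,
                        help='ATR class index rendered white in the mask (e.g. 4 = Upper-clothes)')
    args = parser.parse_args()

    mask = masking(args.image, class_num=args.class_num)
    # masking() already writes the palettized parsing result to ./outputs;
    # additionally save the grayscale mask it returns.
    mask.save(os.path.join('./outputs', 'mask_' + os.path.basename(args.image)[:-4] + '.png'))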