Spaces:
Configuration error
Configuration error
| import torch | |
| import numpy as np | |
| import cv2 | |
| import torchvision.transforms as transforms | |
| from torch.utils.data import DataLoader | |
| from .simple_extractor_dataset import SimpleFolderDataset | |
| from .transforms import transform_logits | |
| from tqdm import tqdm | |
| from PIL import Image | |
def get_palette(num_cls):
    """Build a flat RGB color map for visualizing a segmentation label image.

    The result has the layout ``[r0, g0, b0, r1, g1, b1, ...]`` with one color per
    class id, which is the form ``PIL.Image.putpalette`` expects. Class 0 is black.
    The bits of each class id are dealt out three at a time to the R, G and B
    channels, starting at the most significant bit of each channel, so that
    neighboring ids get clearly different colors.

    Args:
        num_cls: Number of classes (colors) to generate.

    Returns:
        A list of ``3 * num_cls`` ints in ``[0, 255]``.
    """
    palette = []
    for label in range(num_cls):
        rgb = [0, 0, 0]
        remaining = label
        shift = 7
        while remaining:
            # Bit 0 -> red, bit 1 -> green, bit 2 -> blue, at the current bit position.
            for channel in range(3):
                rgb[channel] |= ((remaining >> channel) & 1) << shift
            shift -= 1
            remaining >>= 3
        palette.extend(rgb)
    return palette
def _largest_contour_center_y(binary_mask):
    """Return the centroid row of the largest contour in *binary_mask*, or None if it has none.

    Args:
        binary_mask: 2-D array whose non-zero pixels form the region(s) of interest.

    Returns:
        The (truncated) y coordinate of the centroid of the largest-area contour,
        or ``None`` when the mask contains no contours.
    """
    # [-2] picks the contour list under both the OpenCV 3.x (3-tuple) and 4.x (2-tuple) return shapes.
    contours = cv2.findContours(binary_mask.astype(np.uint8),
                                cv2.RETR_CCOMP, cv2.CHAIN_APPROX_TC89_L1)[-2]
    if len(contours) == 0:
        return None
    areas = [abs(cv2.contourArea(contour, True)) for contour in contours]
    largest = contours[areas.index(max(areas))]
    moments = cv2.moments(largest)
    if moments["m00"] == 0:
        # Zero-area contour (e.g. a single pixel or a 1-px-wide line): the moment-based
        # centroid is undefined, so fall back to the mean row of the contour's points.
        return int(largest[:, 0, 1].mean())
    return int(moments["m01"] / moments["m00"])


def delete_irregular(logits_result):
    """Suppress clothing labels that are implausible for the garment layout, then re-derive the label map.

    Label 4 is treated as upper clothing and label 7 as a dress (per the variable names
    of the original implementation; confirm against the model's label set). The
    centroid rows of the largest region of each label are compared:

    * the dress centroid lies below the upper-cloth centroid: classes 4, 5 and 6 are
      suppressed everywhere, and the outfit is ``"dresses"``;
    * otherwise: classes 5, 6, 7, 8, 9, 10, 12 and 13 are suppressed in the rows above
      the upper-cloth centroid, and the outfit is ``"cloth_pant"``.

    If there is no upper-cloth region, nothing is suppressed and ``"dresses"`` is returned.

    Args:
        logits_result: Per-pixel class scores with classes on axis 2 (``H x W x C``).
            Modified IN PLACE: suppressed class scores are overwritten with -1.

    Returns:
        ``(parsing_result, wear_type)``. ``parsing_result`` is the argmax label map
        padded with one pixel of label 0 on every side (shape ``(H + 2, W + 2)``).
        ``wear_type`` is ``"dresses"`` or ``"cloth_pant"``.
    """
    parsing_result = np.argmax(logits_result, axis=2)
    cloth_cy = _largest_contour_center_y(np.where(parsing_result == 4, 255, 0))
    dress_cy = _largest_contour_center_y(np.where(parsing_result == 7, 255, 0))

    wear_type = "dresses"
    if cloth_cy is not None:
        # NOTE(review): -1 only suppresses a class if the competing scores exceed -1;
        # raw logits can be lower than that. Kept as-is to preserve behavior.
        if dress_cy is not None and dress_cy > cloth_cy:
            irregular_list = np.array([4, 5, 6])
            logits_result[:, :, irregular_list] = -1
        else:
            irregular_list = np.array([5, 6, 7, 8, 9, 10, 12, 13])
            logits_result[:cloth_cy, :, irregular_list] = -1
            wear_type = "cloth_pant"
        parsing_result = np.argmax(logits_result, axis=2)

    # Pad with a 1-pixel border of label 0.
    parsing_result = np.pad(parsing_result, pad_width=1, mode='constant', constant_values=0)
    return parsing_result, wear_type
def hole_fill(img):
    """Fill the enclosed holes of a binary mask without modifying the input.

    The background is flood-filled from pixel (0, 0) on a copy of *img*. Pixels the
    flood could not reach, and that are not already foreground, are enclosed holes.
    They are OR-ed back into the original mask.

    Args:
        img: Single-channel mask, assumed binary with foreground 255 and background 0
            (typically ``uint8``). Pixel (0, 0) is used as the flood seed, so it is
            assumed to be background. Not modified.

    Returns:
        A new mask equal to *img* with its interior holes set to 255.
    """
    # Flood a copy: cv2.floodFill writes into its image argument in place.
    flooded = img.copy()
    # floodFill requires a mask two pixels larger than the image in each dimension.
    flood_mask = np.zeros((img.shape[0] + 2, img.shape[1] + 2), dtype=np.uint8)
    cv2.floodFill(flooded, flood_mask, (0, 0), 255)
    # Whatever is still 0 after flooding was unreachable from the border, i.e. a hole.
    holes = cv2.bitwise_not(flooded)
    return cv2.bitwise_or(img, holes)
def refine_mask(mask, min_area=2000):
    """Keep the largest region of *mask* plus any other region whose own area exceeds *min_area*.

    Args:
        mask: 2-D mask whose non-zero pixels are foreground.
        min_area: Contour-area threshold (in pixels) above which a non-largest region
            is also kept. Defaults to 2000.

    Returns:
        A ``uint8`` array shaped like *mask*. Kept regions are filled with 255 and
        everything else is 0. It is all zeros when *mask* has no contours.
    """
    # [-2] picks the contour list under both the OpenCV 3.x (3-tuple) and 4.x (2-tuple) return shapes.
    contours = cv2.findContours(mask.astype(np.uint8),
                                cv2.RETR_CCOMP, cv2.CHAIN_APPROX_TC89_L1)[-2]
    refined = np.zeros_like(mask).astype(np.uint8)
    if len(contours) == 0:
        return refined

    areas = [abs(cv2.contourArea(contour, True)) for contour in contours]
    largest = areas.index(max(areas))
    for idx, area in enumerate(areas):
        # Always keep the dominant region. Keep any other region only if it is itself
        # large (the original tested the largest region's area here, which kept every
        # speck whenever the main region exceeded the threshold).
        if idx == largest or area > min_area:
            cv2.drawContours(refined, contours, idx, color=255, thickness=-1)
    return refined
def refine_hole(parsing_result_filled, parsing_result, arm_mask, min_area=2000):
    """Select the large label-4 pixels added by hole filling, then merge in the arm mask.

    A pixel is a candidate if it is label 4 in *parsing_result_filled*, is not label 4
    in *parsing_result*, and is not set in *arm_mask*. Each contour of the candidates
    whose area exceeds *min_area* is filled.

    Args:
        parsing_result_filled: 2-D label map after hole filling.
        parsing_result: 2-D label map before hole filling (same shape).
        arm_mask: 2-D mask (same shape); non-zero pixels are arm pixels.
        min_area: Contour-area threshold in pixels. Defaults to 2000.

    Returns:
        An array shaped like *parsing_result*. It is 255 on kept hole regions, the
        *arm_mask* value elsewhere, and 0 where neither applies.
    """
    # Build a clean 0/255 uint8 mask so findContours receives a CV_8UC1 image,
    # instead of an int64 difference that can go negative.
    candidates = (parsing_result_filled == 4) & (parsing_result != 4) & (arm_mask == 0)
    filled_hole = candidates.astype(np.uint8) * 255
    # [-2] picks the contour list under both the OpenCV 3.x (3-tuple) and 4.x (2-tuple) return shapes.
    contours = cv2.findContours(filled_hole, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_TC89_L1)[-2]

    refine_hole_mask = np.zeros_like(parsing_result).astype(np.uint8)
    for idx, contour in enumerate(contours):
        # Keep only holes larger than min_area pixels.
        if abs(cv2.contourArea(contour, True)) > min_area:
            cv2.drawContours(refine_hole_mask, contours, idx, color=255, thickness=-1)

    # Merge without adding: uint8 255 + 1 wraps to 0. That would drop arm pixels
    # that fall inside a filled contour.
    return np.where(refine_hole_mask > 0, refine_hole_mask, arm_mask)
def onnx_inference(lip_session, input_dir, mask_components=None):
    """Run the human-parsing ONNX model over the images in *input_dir*.

    Args:
        lip_session: Session object exposing ``run(output_names, input_feed)``
            (presumably an ``onnxruntime.InferenceSession``; confirm). The image batch
            is fed as ``"input.1"``, and the second output (index 1) is used as the
            parsing logits.
        input_dir: Directory passed to ``SimpleFolderDataset`` as ``root``.
        mask_components: Label ids to include in the returned mask. Defaults to ``[0]``.

    Returns:
        ``(output_img, mask_image)`` for the LAST image yielded by the dataset (earlier
        results are overwritten):

        * ``output_img``: float32 tensor ``(1, H, W, 3)`` in ``[0, 1]``, the label map
          rendered with ``get_palette(20)``.
        * ``mask_image``: float32 tensor ``(1, H, W, 3)``, 1.0 where the label is in
          *mask_components* and 0.0 elsewhere.

        ``H, W`` is the shape of the label map produced by ``transform_logits``.

    Raises:
        ValueError: If the dataset yields no images.
    """
    if mask_components is None:
        mask_components = [0]

    transform = transforms.Compose([
        transforms.ToTensor(),
        # ImageNet mean/std in reversed (B, G, R) channel order. Presumably this is
        # because the dataset loads images with OpenCV; TODO confirm in SimpleFolderDataset.
        transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229])
    ])
    input_size = [473, 473]
    dataset_lip = SimpleFolderDataset(root=input_dir, input_size=input_size, transform=transform)
    dataloader_lip = DataLoader(dataset_lip)  # default batch_size=1, hence the [0] indexing below
    palette = get_palette(20)
    # Upsample holds no per-image state, so build it once instead of once per image.
    upsample = torch.nn.Upsample(size=input_size, mode='bilinear', align_corners=True)

    output_img = mask_image = None
    with torch.no_grad():
        for image, meta in tqdm(dataloader_lip):
            c = meta['center'].numpy()[0]
            s = meta['scale'].numpy()[0]
            w = meta['width'].numpy()[0]
            h = meta['height'].numpy()[0]

            output = lip_session.run(None, {"input.1": image.numpy().astype(np.float32)})
            # Which output index holds the final parsing logits is model-specific; index 1 is used here.
            logits = torch.from_numpy(output[1][0]).unsqueeze(0)
            upsample_output = upsample(logits).squeeze().permute(1, 2, 0)  # CHW -> HWC
            logits_result_lip = transform_logits(upsample_output.data.cpu().numpy(), c, s, w, h,
                                                 input_size=input_size)
            label_map = np.asarray(np.argmax(logits_result_lip, axis=2), dtype=np.uint8)

            # Colorized label map: L image -> palette image -> RGB float tensor.
            palette_img = Image.fromarray(label_map)
            palette_img.putpalette(palette)
            output_img = torch.from_numpy(
                np.array(palette_img.convert('RGB')).astype(np.float32) / 255.0).unsqueeze(0)

            # Binary mask of the requested labels, replicated to 3 channels.
            mask = np.isin(label_map, mask_components).astype(np.float32)
            mask_image = torch.from_numpy(np.stack([mask] * 3, axis=-1)).unsqueeze(0)

    if output_img is None:
        raise ValueError(f"no images found in {input_dir!r}")
    return output_img, mask_image