Spaces:
Runtime error
Runtime error
| import torch | |
| import torchvision.transforms as transforms | |
| import criteria.deeplab as deeplab | |
| import PIL.Image as Image | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from configs import paths_config, global_config | |
| import numpy as np | |
class Mask(nn.Module):
    """Face-parsing mask that keeps facial regions and suppresses the rest.

    A DeepLab (ResNet-101) network configured for 19 classes labels every
    pixel of the real image. The class indices are:

    | Class      | Number | Class  | Number |
    |------------|--------|--------|--------|
    | background | 0      | mouth  | 10     |
    | skin       | 1      | u_lip  | 11     |
    | nose       | 2      | l_lip  | 12     |
    | eye_g      | 3      | hair   | 13     |
    | l_eye      | 4      | hat    | 14     |
    | r_eye      | 5      | ear_r  | 15     |
    | l_brow     | 6      | neck_l | 16     |
    | r_brow     | 7      | neck   | 17     |
    | l_ear      | 8      | cloth  | 18     |
    | r_ear      | 9      |        |        |

    Pixels whose class is listed in ``self.labels`` get weight 1, hair (13)
    gets weight 0.1, and everything else (background, hat, cloth) gets 0.
    """

    def __init__(self, device="cpu"):
        """Build the segmentation network and load its checkpoint.

        Args:
            device: Device handed to the DeepLab constructor and used as
                ``map_location`` when reading the checkpoint. The model is
                moved to ``global_config.device`` afterwards.
        """
        super().__init__()
        self.seg_model = (
            deeplab.resnet101(
                path=paths_config.deeplab,
                pretrained=True,
                num_classes=19,
                num_groups=32,
                weight_std=True,
                beta=False,
                device=device,
            )
            .eval()
            .requires_grad_(False)
        )
        ckpt = torch.load(paths_config.deeplab, map_location=device)
        # The first 7 characters of every key are dropped -- presumably the
        # "module." prefix written by DataParallel (TODO confirm against the
        # checkpoint). Entries whose name contains "tracked" are skipped.
        state_dict = {
            k[7:]: v for k, v in ckpt["state_dict"].items() if "tracked" not in k
        }
        self.seg_model.load_state_dict(state_dict)
        self.seg_model = self.seg_model.to(global_config.device)
        # Classes that receive full weight in the mask (see class docstring).
        self.labels = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 16, 17]
        # 25x25 all-ones kernel. It was used by a mask-smoothing convolution
        # that is currently disabled; the attribute is kept so existing code
        # that reads it keeps working.
        self.kernel = torch.ones((1, 1, 25, 25), device=global_config.device)
        # Built once instead of on every get_labels() call.
        self._preprocess = transforms.Compose(
            [
                transforms.Resize((513, 513)),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )

    @staticmethod
    def _expand_for_channels(mask):
        """Make a mask broadcastable against ``(B, C, H, W)`` images.

        A batched ``(B, H, W)`` mask gets a channel axis inserted so that each
        sample's mask is applied to all channels of that sample. Without it,
        broadcasting lines the batch axis of the mask up with the channel axis
        of the images, which fails for most batch sizes and silently mixes
        samples and channels when ``B == C``. Masks of any other rank (the
        ``(H, W)`` single-image case) are returned unchanged.
        """
        return mask.unsqueeze(1) if mask.dim() == 3 else mask

    def get_labels(self, img, size=(256, 256)):
        """Return the per-pixel class label map of ``img``.

        Args:
            img: Image tensor fed to the segmentation model after being
                resized to 513x513 and normalized (assumed to be a batched
                ``(B, 3, H, W)`` tensor -- TODO confirm with callers).
            size: ``(height, width)`` of the returned label map. Defaults to
                the historical 256x256.

        Returns:
            A ``torch.LongTensor`` of class indices with all singleton
            dimensions squeezed out, i.e. ``(H, W)`` for a batch of one and
            ``(B, H, W)`` otherwise.
        """
        img = self._preprocess(img)
        with torch.no_grad():
            out = self.seg_model(img)
        _, label = torch.max(out, 1)
        # interpolate() needs a float (N, C, H, W) input; the batch axis is
        # temporarily treated as the channel axis.
        label = label.unsqueeze(0).type(torch.float32)
        # NOTE: torch.LongTensor is the CPU tensor type, so the label map is
        # returned on the CPU whatever device the model runs on. Kept as is
        # for backward compatibility with callers of get_labels().
        label = (
            F.interpolate(label, size=size, mode="nearest")
            .squeeze()
            .type(torch.LongTensor)
        )
        return label

    def get_mask(self, label):
        """Turn a label map into a float weight mask.

        Args:
            label: Integer class map as returned by :meth:`get_labels`.

        Returns:
            Float tensor on ``global_config.device`` with the same shape as
            ``label``: 1 where the class is in ``self.labels``, 0.1 for hair
            (class 13) and 0 elsewhere.
        """
        mask = torch.zeros_like(label, device=global_config.device, dtype=torch.float)
        for idx in self.labels:
            mask[label == idx] = 1
        # Hair is kept with a small weight rather than being removed entirely.
        mask[label == 13] = 0.1
        return mask

    def forward(self, real_imgs, generated_imgs):
        """Mask both image batches with the face mask of ``real_imgs``.

        The mask is computed from ``real_imgs`` only, at the spatial size of
        ``real_imgs``, and the same mask is applied to ``generated_imgs``.

        Args:
            real_imgs: Reference images; the segmentation is run on these.
            generated_imgs: Images to mask identically (expected to have the
                same shape as ``real_imgs``).

        Returns:
            Tuple ``(masked_real_imgs, masked_generated_imgs)``.
        """
        label = self.get_labels(real_imgs, size=tuple(real_imgs.shape[-2:]))
        mask = self._expand_for_channels(self.get_mask(label))
        return real_imgs * mask, generated_imgs * mask