"""Face-masking utilities built on facer's RetinaFace detector and FaRL parser.

This module monkey-patches the facer constructors (``FaRLFaceParser.__init__``
and ``RetinaFaceDetector.__init__``) so the pretrained weights are loaded from
local paths declared in ``DefaultPaths`` instead of being downloaded.
"""

import facer  # noqa: F401  -- kept: importing facer registers its submodules
from facer.face_parsing import FaRLFaceParser
from facer.face_detection import RetinaFaceDetector
from facer.face_detection.retinaface import RetinaFace
from configs.paths import DefaultPaths

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch import nn


def my_fp_init(self, model_path=DefaultPaths.farl_path):
    """Replacement ``FaRLFaceParser.__init__``: load a local TorchScript
    checkpoint from *model_path* instead of downloading one."""
    super(FaRLFaceParser, self).__init__()
    self.conf_name = 'lapa/448'
    self.net = torch.jit.load(model_path)
    self.eval()


# Monkey-patch: every FaRLFaceParser created from here on uses the local path.
FaRLFaceParser.__init__ = my_fp_init


def remove_prefix(state_dict, prefix):
    """Strip a shared key *prefix* (e.g. ``'module.'`` left by DataParallel)
    from every key of *state_dict* and return the renamed dict."""
    def strip(key):
        return key.split(prefix, 1)[-1] if key.startswith(prefix) else key

    return {strip(key): value for key, value in state_dict.items()}


def check_keys(model, pretrained_state_dict):
    """Verify the checkpoint shares at least one key with *model*.

    Raises:
        AssertionError: if no checkpoint key matches a model parameter.

    Returns:
        ``True`` on success (kept for interface compatibility).
    """
    ckpt_keys = set(pretrained_state_dict.keys())
    model_keys = set(model.state_dict().keys())
    used_pretrained_keys = model_keys & ckpt_keys
    assert len(used_pretrained_keys) > 0, "load NONE from pretrained checkpoint"
    return True


def load_model(model, pretrained_path, load_to_cpu, network: str):
    """Load the checkpoint at *pretrained_path* into *model*.

    Args:
        model: network whose ``state_dict`` the checkpoint should populate.
        pretrained_path: path to a ``torch.save``-d checkpoint file.
        load_to_cpu: if True, map all tensors to CPU; otherwise map them to
            the current CUDA device.
        network: tag kept for interface compatibility (currently unused).

    Returns:
        *model*, with the checkpoint applied via ``strict=False`` so extra or
        missing keys do not abort loading (``check_keys`` already guarantees a
        non-empty overlap).
    """
    if load_to_cpu:
        pretrained_dict = torch.load(
            pretrained_path, map_location=lambda storage, loc: storage
        )
    else:
        device = torch.cuda.current_device()
        pretrained_dict = torch.load(
            pretrained_path,
            map_location=lambda storage, loc: storage.cuda(device),
        )
    # DataParallel-style checkpoints nest the weights under 'state_dict' and
    # prefix every key with 'module.'.
    if "state_dict" in pretrained_dict:
        pretrained_dict = remove_prefix(pretrained_dict["state_dict"], "module.")
    else:
        pretrained_dict = remove_prefix(pretrained_dict, "module.")
    check_keys(model, pretrained_dict)
    model.load_state_dict(pretrained_dict, strict=False)
    return model


def load_net(model_path):
    """Build a MobileNet-0.25 RetinaFace in test mode on CUDA and load the
    weights stored at *model_path* (always mapped through CPU)."""
    cfg = {
        "name": "mobilenet0.25",
        "min_sizes": [[16, 32], [64, 128], [256, 512]],
        "steps": [8, 16, 32],
        "variance": [0.1, 0.2],
        "clip": False,
        "loc_weight": 2.0,
        "gpu_train": True,
        "batch_size": 32,
        "ngpu": 1,
        "epoch": 250,
        "decay1": 190,
        "decay2": 220,
        "image_size": 640,
        "pretrain": True,
        "return_layers": {"stage1": 1, "stage2": 2, "stage3": 3},
        "in_channel": 32,
        "out_channel": 64,
    }
    net = RetinaFace(cfg=cfg, phase="test").cuda()
    net = load_model(net, model_path, True, network="mobilenet")
    net.eval()
    # Fixed input sizes are expected, so let cuDNN pick the fastest kernels.
    cudnn.benchmark = True
    return net


def my_fd_init(self, model_path=DefaultPaths.mobile_net_pth, trash=0.8):
    """Replacement ``RetinaFaceDetector.__init__``: build the MobileNet
    RetinaFace from a local checkpoint.

    Args:
        model_path: path to the local RetinaFace weights.
        trash: detection confidence threshold (name kept for backward
            compatibility; presumably a typo for "thresh").
    """
    super(RetinaFaceDetector, self).__init__()
    self.conf_name = 'mobilenet'
    self.threshold = trash
    self.net = load_net(model_path)
    self.eval()


# Monkey-patch: every RetinaFaceDetector created from here on uses local weights.
RetinaFaceDetector.__init__ = my_fd_init


def _first_occurrence_indices(all_indexes):
    """Return a tensor with the position of the first occurrence of each image
    id ``0 .. all_indexes[-1]`` inside *all_indexes* (one face per image)."""
    firsts = [
        (all_indexes == image_id).nonzero(as_tuple=True)[0][0]
        for image_id in range(all_indexes[-1] + 1)
    ]
    return torch.tensor(firsts)


class TargetMask(nn.Module):
    """Blend a target image into a source image on face-parsing masks.

    ``forward`` keeps *x* where the parser's class-0 probability is below the
    threshold and takes *y* elsewhere.
    """

    def __init__(self, tfm=True):
        super().__init__()
        self.face_detector = RetinaFaceDetector(trash=0.8).cuda().eval()
        self.face_parser = FaRLFaceParser().cuda().eval()
        # Maps inputs from [-1, 1] to [0, 1]: x -> x / 2 + 0.5.
        self.to_farl = transforms.Compose(
            [
                transforms.Normalize([0., 0., 0.], [2., 2., 2.]),
                transforms.Normalize([-0.5, -0.5, -0.5], [1., 1., 1.]),
            ]
        )
        self.tfm = tfm
        self.sigm = torch.nn.Sigmoid()

    def get_u_idxs(self, all_indexes):
        """First occurrence of each image id (delegates to module helper)."""
        return _first_occurrence_indices(all_indexes)

    def get_mask(self, y, threshold=0.5):
        """Detect faces in *y*, run the parser, and return a (B, 1, H, W)
        hard mask where class-0 probability >= *threshold*.

        NOTE(review): class 0 is presumably the background/skin channel of the
        'lapa/448' label set -- confirm against the LaPa label order.
        """
        y = y.long()  # facer expects integer-valued images
        faces = self.face_detector(y)
        faces = self.face_parser(y, faces)
        seg_logits = faces['seg']['logits']
        seg_probs = seg_logits.softmax(dim=1)
        # Keep exactly one (the first) detected face per image.
        uniq_idx = self.get_u_idxs(faces['image_ids'])
        chroma_mask = (
            (seg_probs[uniq_idx, 0, :, :] >= threshold).to(y.dtype).unsqueeze(1)
        )
        return chroma_mask

    def forward(self, x, y):
        """Composite: mask regions come from *y*, the rest from *x*."""
        if self.tfm:
            mask_y = self.get_mask(255. * self.to_farl(y))
        else:
            mask_y = self.get_mask(255. * y)
        return (1 - mask_y) * x + mask_y * y

    def forward2(self, x, y):
        """Alternative blend using a sigmoid mask at a 0.995 cutoff, retrying
        detection with progressively lower thresholds when no face is found."""
        batch = (255. * self.to_farl(y)).long()
        faces = None
        try:
            faces = self.face_detector(batch)
            assert len(faces['image_ids']) != 0
        except Exception:  # detector failure or no detections: retry below
            for trash in [0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]:
                try:
                    new_masker = Masker(trash=trash)
                    faces = new_masker.face_detector(batch)
                    assert len(faces['image_ids']) != 0
                    break
                except Exception:
                    continue  # try the next, lower threshold
        # Fail loudly (AssertionError) if every threshold came up empty.
        assert faces is not None and len(faces['image_ids']) != 0
        faces = self.face_parser(batch, faces)
        farl_mask = self.sigm(faces['seg']['logits'][:, 0])
        farl_mask = (farl_mask >= 0.995).float()[0]
        return (1 - farl_mask) * x + farl_mask * y


class Masker(nn.Module):
    """Stand-alone face-mask extractor with a configurable detector threshold."""

    def __init__(self, trash=0.8):
        super().__init__()
        self.face_detector = RetinaFaceDetector(trash=trash).cuda().eval()
        self.face_parser = FaRLFaceParser().cuda().eval()

    def get_u_idxs(self, all_indexes):
        """First occurrence of each image id (delegates to module helper)."""
        return _first_occurrence_indices(all_indexes)

    def get_mask(self, y, threshold=0.5):
        """Return a (B, 1, H, W) hard mask of class-0 probability >= *threshold*
        for the first detected face per image."""
        faces = self.face_detector(y)
        faces = self.face_parser(y, faces)
        seg_logits = faces['seg']['logits']
        seg_probs = seg_logits.softmax(dim=1)
        uniq_idx = self.get_u_idxs(faces['image_ids'])
        chroma_mask = (
            (seg_probs[uniq_idx, 0, :, :] >= threshold).to(y.dtype).unsqueeze(1)
        )
        return chroma_mask

    def forward(self, x):
        """Mask for *x* (scaled to 0-255), repeated to 3 channels."""
        return self.get_mask(255. * x).repeat(1, 3, 1, 1)