|
|
import facer |
|
|
from facer.face_parsing import FaRLFaceParser |
|
|
from facer.face_detection import RetinaFaceDetector |
|
|
from facer.face_detection.retinaface import RetinaFace |
|
|
from configs.paths import DefaultPaths |
|
|
import torch.backends.cudnn as cudnn |
|
|
import numpy as np |
|
|
|
|
|
import torch |
|
|
from torch import nn |
|
|
|
|
|
import torchvision.transforms as transforms |
|
|
|
|
|
import torch.nn.functional as F |
|
|
|
|
|
def my_fp_init(self, model_path=DefaultPaths.farl_path):
    """Replacement constructor installed on ``FaRLFaceParser``.

    Loads the parser network with ``torch.jit.load(model_path)``, fixes
    ``conf_name`` to ``'lapa/448'`` and switches the module to eval mode.

    Note that ``super(FaRLFaceParser, self)`` initialises the *parent* of
    ``FaRLFaceParser`` directly, so ``FaRLFaceParser``'s own original
    ``__init__`` body never runs.
    """
    parent_init = super(FaRLFaceParser, self).__init__
    parent_init()
    self.conf_name = 'lapa/448'
    scripted_parser = torch.jit.load(model_path)
    self.net = scripted_parser
    self.eval()


# Monkey-patch: every FaRLFaceParser() constructed after import uses my_fp_init.
FaRLFaceParser.__init__ = my_fp_init
|
|
|
|
|
def remove_prefix(state_dict, prefix):
    """Return a new dict whose keys have a leading ``prefix`` stripped.

    Old-style checkpoints store every parameter name with a common prefix
    (typically ``'module.'``). Only a prefix at the *start* of a key is removed.
    Keys that do not start with ``prefix`` are copied unchanged.

    An empty ``prefix`` leaves every key unchanged. The previous
    ``str.split``-based version raised ``ValueError: empty separator``.

    If stripping makes two keys collide, the entry that comes later in
    ``state_dict`` iteration order wins.

    Args:
        state_dict: mapping from parameter names to values.
        prefix: string to strip from the beginning of each key.

    Returns:
        A new ``dict``; ``state_dict`` is not modified.
    """
    def strip(key):
        # Slicing (rather than split) is safe for an empty prefix.
        return key[len(prefix):] if key.startswith(prefix) else key

    return {strip(key): value for key, value in state_dict.items()}
|
|
|
|
|
def check_keys(model, pretrained_state_dict):
    """Sanity-check that a checkpoint shares at least one key with ``model``.

    Args:
        model: object exposing ``state_dict()`` (e.g. an ``nn.Module``).
        pretrained_state_dict: mapping of parameter names to tensors.

    Returns:
        ``True`` when at least one key is shared.

    Raises:
        AssertionError: if no checkpoint key matches a model key. This is an
            ``assert`` statement, so the check is skipped under ``python -O``.
    """
    checkpoint_names = set(pretrained_state_dict.keys())
    shared_names = checkpoint_names.intersection(model.state_dict().keys())
    assert shared_names, "load NONE from pretrained checkpoint"
    return True
|
|
|
|
|
def load_model(model, pretrained_path, load_to_cpu, network: str):
    """Load weights from ``pretrained_path`` into ``model`` (non-strict).

    Args:
        model: module whose ``load_state_dict`` receives the weights.
        pretrained_path: checkpoint path handed to ``torch.load``.
        load_to_cpu: if true, storages stay on the CPU. Otherwise each storage
            is moved to the current CUDA device via ``storage.cuda(device)``.
        network: unused; kept for signature compatibility with callers.

    Returns:
        ``model``, after ``load_state_dict(..., strict=False)``.

    Notes:
        - If the checkpoint has a ``"state_dict"`` entry, that entry is used.
        - A leading ``"module."`` is stripped from every key.
        - ``check_keys`` must find at least one overlapping key. Beyond that,
          missing or unexpected keys are ignored (``strict=False``).
        - ``torch.load`` unpickles the file, so only load trusted checkpoints.
    """
    if load_to_cpu:
        def _place(storage, loc):
            return storage
    else:
        target_device = torch.cuda.current_device()

        def _place(storage, loc):
            return storage.cuda(target_device)

    checkpoint = torch.load(pretrained_path, map_location=_place)

    # Some checkpoints nest the weights under a "state_dict" entry.
    if "state_dict" in checkpoint.keys():
        raw_weights = checkpoint["state_dict"]
    else:
        raw_weights = checkpoint
    weights = remove_prefix(raw_weights, "module.")

    check_keys(model, weights)
    model.load_state_dict(weights, strict=False)
    return model
|
|
|
|
|
def load_net(model_path):
    """Build a test-phase ``RetinaFace`` (mobilenet0.25 config) and load weights.

    The network is moved to CUDA first. Weights are then read onto the CPU
    (``load_to_cpu=True``) and copied into the CUDA parameters by
    ``load_state_dict``.

    Side effect: sets the global ``torch.backends.cudnn.benchmark = True``.

    Args:
        model_path: checkpoint path passed to ``load_model``.

    Returns:
        The loaded network, in eval mode.
    """
    # Training-oriented entries (batch_size, epoch, decay*, ...) are kept
    # verbatim. Presumably RetinaFace ignores them in test phase; confirm.
    cfg = dict(
        name="mobilenet0.25",
        min_sizes=[[16, 32], [64, 128], [256, 512]],
        steps=[8, 16, 32],
        variance=[0.1, 0.2],
        clip=False,
        loc_weight=2.0,
        gpu_train=True,
        batch_size=32,
        ngpu=1,
        epoch=250,
        decay1=190,
        decay2=220,
        image_size=640,
        pretrain=True,
        return_layers={"stage1": 1, "stage2": 2, "stage3": 3},
        in_channel=32,
        out_channel=64,
    )

    detector_net = RetinaFace(cfg=cfg, phase="test")
    detector_net = detector_net.cuda()
    detector_net = load_model(detector_net, model_path, load_to_cpu=True, network="mobilenet")
    detector_net.eval()
    cudnn.benchmark = True
    return detector_net
|
|
|
|
|
def my_fd_init(self, model_path=DefaultPaths.mobile_net_pth, trash=0.8):
    """Replacement constructor installed on ``RetinaFaceDetector``.

    Args:
        model_path: checkpoint path for the mobilenet RetinaFace (see ``load_net``).
        trash: detection threshold, stored as ``self.threshold``.
            (The name is kept because callers pass it by keyword.)

    Like the FaRL patch, this initialises the parent of ``RetinaFaceDetector``
    directly. ``RetinaFaceDetector``'s own original ``__init__`` never runs.
    """
    super(RetinaFaceDetector, self).__init__()
    detection_net = load_net(model_path)
    self.conf_name, self.threshold = 'mobilenet', trash
    self.net = detection_net
    self.eval()


# Monkey-patch: every RetinaFaceDetector() constructed after import uses my_fd_init.
RetinaFaceDetector.__init__ = my_fd_init
|
|
|
|
|
class TargetMask(nn.Module):
    """Blend ``y`` into ``x`` where FaRL face parsing of ``y`` marks channel 0.

    ``to_farl`` computes ``y / 2 + 0.5`` (two ``Normalize`` steps). This maps
    [-1, 1] to [0, 1], so ``y`` is presumably expected in [-1, 1] when
    ``tfm`` is true. Confirm against callers.
    """

    # Detection thresholds tried, in order, after the detector's own
    # threshold finds no face (or fails) in ``forward2``.
    _FALLBACK_THRESHOLDS = (0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1)

    def __init__(self, tfm=True):
        super().__init__()
        self.face_detector = RetinaFaceDetector(trash=0.8).cuda().eval()
        self.face_parser = FaRLFaceParser().cuda().eval()
        # x -> (x - 0) / 2 -> (x/2 + 0.5) / 1, i.e. [-1, 1] -> [0, 1].
        self.to_farl = transforms.Compose(
            [
                transforms.Normalize([0., 0., 0.], [2., 2., 2.]),
                transforms.Normalize([-0.5, -0.5, -0.5], [1., 1., 1.]),
            ]
        )
        self.tfm = tfm
        self.sigm = torch.nn.Sigmoid()

    def get_u_idxs(self, all_indexes):
        """Return, for each image id ``0..all_indexes[-1]``, its first row index.

        Assumes ``all_indexes`` is sorted so that its last entry is the largest
        id (TODO confirm). Raises ``IndexError`` if any id in that range has no
        detected face, or if ``all_indexes`` is empty.
        """
        res = []
        for i in range(all_indexes[-1] + 1):
            res.append(((all_indexes == i).nonzero(as_tuple=True)[0][0]))
        return torch.tensor(res)

    def get_mask(self, y, threshold=0.5):
        """Return a ``(num_images, 1, H, W)`` mask of parsing channel 0.

        ``y`` is cast to ``long`` before detection and parsing, so the
        returned mask also has ``long`` dtype. Only the first parsed face of
        each image is used.
        """
        y = y.long()
        faces = self.face_detector(y)
        faces = self.face_parser(y, faces)
        seg_logits = faces['seg']['logits']
        seg_probs = seg_logits.softmax(dim=1)
        uniq_idx = self.get_u_idxs(faces['image_ids'])
        chroma_mask = (seg_probs[uniq_idx, 0, :, :] >= threshold).to(y.dtype).unsqueeze(1)
        return chroma_mask

    def forward(self, x, y):
        """Take ``y`` where the channel-0 mask of ``y`` is set, ``x`` elsewhere."""
        if self.tfm:
            mask_y = self.get_mask(255. * self.to_farl(y))
        else:
            mask_y = self.get_mask(255. * y)
        return (1 - mask_y) * x + mask_y * y

    def _detect_with_fallback(self, batch):
        """Run the face detector, lowering its threshold until a face is found.

        The detector's current threshold is tried first, then each value in
        ``_FALLBACK_THRESHOLDS``. The same detector (and its already loaded
        weights) is reused for every attempt. Previously a whole new ``Masker``
        was built per retry, reloading both networks.

        ``self.face_detector.threshold`` is changed temporarily and always
        restored, so this method must not run concurrently on one instance.

        Raises:
            AssertionError: if no threshold yields at least one face. This type
                matches the previous ``assert`` so existing handlers still
                work, but the check no longer disappears under ``python -O``.
        """
        detector = self.face_detector
        original_threshold = detector.threshold
        thresholds = (original_threshold, *self._FALLBACK_THRESHOLDS)
        try:
            for threshold in thresholds:
                detector.threshold = threshold
                try:
                    faces = detector(batch)
                    found = len(faces['image_ids']) != 0
                except Exception:
                    # Best-effort retry: a failing attempt just moves on to
                    # the next threshold (previously a bare ``except``).
                    continue
                if found:
                    return faces
        finally:
            detector.threshold = original_threshold
        raise AssertionError(f'no face detected at any threshold in {thresholds}')

    def forward2(self, x, y):
        """Blend ``y`` into ``x`` using a strict sigmoid mask of parsing channel 0.

        Detection falls back to progressively lower thresholds (see
        ``_detect_with_fallback``). The mask keeps pixels where
        ``sigmoid(logit_0) >= 0.995``. Only the first parsed face's mask
        (``[0]``) is used, broadcast across the whole batch. Presumably this
        is meant for single-image batches; confirm.
        """
        batch = (255. * self.to_farl(y)).long()
        faces = self._detect_with_fallback(batch)
        faces = self.face_parser(batch, faces)
        farl_mask = self.sigm(faces['seg']['logits'][:, 0])
        farl_mask = (farl_mask >= 0.995).float()[0]
        return (1 - farl_mask) * x + farl_mask * y
|
|
|
|
|
|
|
|
class Masker(nn.Module):
    """Produce a 3-channel binary mask from FaRL face-parsing channel 0.

    Args:
        trash: detection threshold forwarded to ``RetinaFaceDetector``.
    """

    def __init__(self, trash=0.8):
        super().__init__()
        detector = RetinaFaceDetector(trash=trash)
        self.face_detector = detector.cuda().eval()
        parser = FaRLFaceParser()
        self.face_parser = parser.cuda().eval()

    def get_u_idxs(self, all_indexes):
        """Map each image id ``0..all_indexes[-1]`` to the row of its first face.

        Raises ``IndexError`` if some id in that range has no entry, or if
        ``all_indexes`` is empty.
        """
        first_rows = [
            (all_indexes == image_id).nonzero(as_tuple=True)[0][0]
            for image_id in range(all_indexes[-1] + 1)
        ]
        return torch.tensor(first_rows)

    def get_mask(self, y, threshold=0.5):
        """Return a ``(num_images, 1, H, W)`` mask in ``y``'s dtype.

        A pixel is set where the softmax probability of parsing channel 0 is
        ``>= threshold``. Channel 0 is presumably background for the LaPa
        label set; confirm. Only the first parsed face per image is used.
        """
        detections = self.face_detector(y)
        parsed = self.face_parser(y, detections)
        probabilities = parsed['seg']['logits'].softmax(dim=1)
        rows = self.get_u_idxs(parsed['image_ids'])
        channel0_hits = probabilities[rows, 0, :, :] >= threshold
        return channel0_hits.to(y.dtype).unsqueeze(1)

    def forward(self, x):
        """Scale ``x`` by 255, build the mask and repeat it to 3 channels."""
        single_channel = self.get_mask(x * 255.)
        return single_channel.repeat(1, 3, 1, 1)
|
|
|