LogicGoInfotechSpaces's picture
Bundle StyleFeatureEditor code packages in Space to fix ModuleNotFoundError
95b1715
import facer
from facer.face_parsing import FaRLFaceParser
from facer.face_detection import RetinaFaceDetector
from facer.face_detection.retinaface import RetinaFace
from configs.paths import DefaultPaths
import torch.backends.cudnn as cudnn
import numpy as np
import torch
from torch import nn
import torchvision.transforms as transforms
import torch.nn.functional as F
def my_fp_init(self, model_path=DefaultPaths.farl_path):
    """Replacement ``FaRLFaceParser.__init__`` that loads a local TorchScript net.

    The configuration name is fixed to 'lapa/448'. The scripted network is
    read with ``torch.jit.load`` from ``model_path``, which defaults to
    ``DefaultPaths.farl_path``. The module is left in eval mode.
    """
    # super(FaRLFaceParser, self) skips FaRLFaceParser's own __init__ in the
    # MRO and runs only its parent's initializer.
    super(FaRLFaceParser, self).__init__()
    scripted_net = torch.jit.load(model_path)
    self.conf_name = 'lapa/448'
    self.net = scripted_net
    self.eval()


# Patch facer so that every FaRLFaceParser() built after this import runs
# my_fp_init instead of the library's initializer.
FaRLFaceParser.__init__ = my_fp_init
def remove_prefix(state_dict, prefix):
    """Return a new dict with one leading ``prefix`` stripped from matching keys.

    Old-style checkpoints store every parameter name with a shared prefix
    such as 'module.'. Keys that start with ``prefix`` lose exactly that one
    leading occurrence. Other keys are copied unchanged. An empty ``prefix``
    leaves every key as-is.

    Args:
        state_dict: mapping whose keys are strings.
        prefix: text to strip from the start of each key.

    Returns:
        A new dict with the same values.
    """
    # Slicing by length avoids str.split, which raises ValueError for "".
    cut = len(prefix)
    return {
        (key[cut:] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
def check_keys(model, pretrained_state_dict):
    """Verify that a checkpoint shares at least one key with ``model``.

    Args:
        model: object exposing ``state_dict()``.
        pretrained_state_dict: mapping of parameter names to values.

    Returns:
        True when at least one key is present in both.

    Raises:
        AssertionError: if no key overlaps. It is raised explicitly rather
            than via ``assert``, so the check still runs under ``python -O``.
            Otherwise a non-strict load would silently load nothing.
    """
    model_keys = set(model.state_dict().keys())
    if model_keys.isdisjoint(pretrained_state_dict.keys()):
        raise AssertionError("load NONE from pretrained checkpoint")
    return True
def load_model(model, pretrained_path, load_to_cpu, network: str):
    """Load checkpoint weights into ``model`` (non-strict) and return it.

    Args:
        model: module that receives the weights via ``load_state_dict``.
        pretrained_path: file passed to ``torch.load``.
        load_to_cpu: if True, storages stay on CPU. Otherwise each storage
            is moved to ``torch.cuda.current_device()``.
        network: not used by this function; kept for signature compatibility.

    Returns:
        The same ``model`` object, after loading.
    """
    if load_to_cpu:
        def relocate(storage, loc):
            return storage
    else:
        device = torch.cuda.current_device()

        def relocate(storage, loc):
            return storage.cuda(device)

    checkpoint = torch.load(pretrained_path, map_location=relocate)

    # Weights are either nested under "state_dict" or are the checkpoint
    # itself. In both cases the shared "module." name prefix is stripped.
    has_nested = "state_dict" in checkpoint.keys()
    weights = checkpoint["state_dict"] if has_nested else checkpoint
    weights = remove_prefix(weights, "module.")

    check_keys(model, weights)
    model.load_state_dict(weights, strict=False)
    return model
def load_net(model_path):
    """Build the mobilenet0.25 RetinaFace on CUDA, load weights, return it in eval mode.

    Weights come from ``model_path`` through ``load_model`` with
    ``load_to_cpu=True``. Side effect: sets ``cudnn.benchmark = True``
    process-wide.
    """
    # Keys such as batch_size/epoch/decay1/decay2 look like training settings
    # carried over from an upstream config; presumably unused in the "test"
    # phase -- TODO confirm.
    cfg = dict(
        name="mobilenet0.25",
        min_sizes=[[16, 32], [64, 128], [256, 512]],
        steps=[8, 16, 32],
        variance=[0.1, 0.2],
        clip=False,
        loc_weight=2.0,
        gpu_train=True,
        batch_size=32,
        ngpu=1,
        epoch=250,
        decay1=190,
        decay2=220,
        image_size=640,
        pretrain=True,
        return_layers={"stage1": 1, "stage2": 2, "stage3": 3},
        in_channel=32,
        out_channel=64,
    )
    detector_net = load_model(
        RetinaFace(cfg=cfg, phase="test").cuda(),
        model_path,
        True,
        network="mobilenet",
    )
    detector_net.eval()
    cudnn.benchmark = True
    return detector_net
def my_fd_init(self, model_path=DefaultPaths.mobile_net_pth, trash=0.8):
    """Replacement ``RetinaFaceDetector.__init__`` that uses a local checkpoint.

    Args:
        model_path: checkpoint handed to ``load_net``. Defaults to
            ``DefaultPaths.mobile_net_pth``.
        trash: stored as ``self.threshold`` (presumably the detection score
            cutoff used by facer -- confirm). The name is kept because
            callers pass ``trash=``.
    """
    # super(RetinaFaceDetector, self) bypasses the library's own __init__
    # and runs only its parent's initializer.
    super(RetinaFaceDetector, self).__init__()
    network = load_net(model_path)
    self.conf_name = 'mobilenet'
    self.threshold = trash
    self.net = network
    self.eval()


# Patch facer so that every RetinaFaceDetector(...) built after this import
# runs my_fd_init.
RetinaFaceDetector.__init__ = my_fd_init
class TargetMask(nn.Module):
    """Composite two image batches using a FaRL face-parsing mask of ``y``.

    Where the parser's channel 0 passes a threshold, pixels are taken from
    ``y``. Everywhere else they come from ``x``. Channel 0 is presumably the
    background label of the 'lapa/448' parser -- TODO confirm.
    Construction builds both networks and moves them to CUDA.
    """

    # Detection thresholds tried, in order, by ``forward2`` when the default
    # detector raises or finds no face.
    FALLBACK_THRESHOLDS = (0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1)

    def __init__(self, tfm=True):
        """Create the CUDA face detector / face parser pair.

        Args:
            tfm: if True, ``forward`` passes ``y`` through ``self.to_farl``
                before building the mask. Otherwise ``y`` is only scaled by 255.
        """
        super().__init__()
        self.face_detector = RetinaFaceDetector(trash=0.8).cuda().eval()
        self.face_parser = FaRLFaceParser().cuda().eval()
        # Normalize(0, 2) followed by Normalize(-0.5, 1) computes v / 2 + 0.5,
        # i.e. it maps values in [-1, 1] to [0, 1].
        self.to_farl = transforms.Compose(
            [
                transforms.Normalize([0., 0., 0.], [2., 2., 2.]),
                transforms.Normalize([-0.5, -0.5, -0.5], [1., 1., 1.]),
            ]
        )
        self.tfm = tfm
        self.sigm = torch.nn.Sigmoid()

    def get_u_idxs(self, all_indexes):
        """Return, for each image id ``0 .. all_indexes[-1]``, its first position.

        Ids above ``all_indexes[-1]`` are not covered. An empty tensor, or an
        id below it that never occurs, raises ``IndexError``.
        """
        res = []
        for i in range(all_indexes[-1] + 1):
            res.append((all_indexes == i).nonzero(as_tuple=True)[0][0])
        return torch.tensor(res)

    def get_mask(self, y, threshold=0.5):
        """Return a mask of pixels whose channel-0 softmax probability >= ``threshold``.

        ``y`` is cast to ``long`` before detection and parsing. Only the first
        detected face of each image id is kept. The result has shape
        ``(last_image_id + 1, 1, H, W)`` and the ``long`` dtype of the cast ``y``.
        """
        y = y.long()
        faces = self.face_detector(y)
        faces = self.face_parser(y, faces)
        seg_probs = faces['seg']['logits'].softmax(dim=1)
        uniq_idx = self.get_u_idxs(faces['image_ids'])
        chroma_mask = (seg_probs[uniq_idx, 0, :, :] >= threshold).to(y.dtype).unsqueeze(1)
        return chroma_mask

    def forward(self, x, y):
        """Return ``y`` where the mask of ``y`` is set and ``x`` elsewhere."""
        source = self.to_farl(y) if self.tfm else y
        mask_y = self.get_mask(255. * source)
        return (1 - mask_y) * x + mask_y * y

    def _make_detector(self, threshold):
        """Build a fresh CUDA face detector with the given threshold."""
        return RetinaFaceDetector(trash=threshold).cuda().eval()

    @staticmethod
    def _detect_or_none(detector, batch):
        """Run ``detector`` on ``batch``; return its result, or None on failure/no face."""
        try:
            faces = detector(batch)
            found = len(faces['image_ids']) != 0
        except Exception:  # best-effort: any failure means "try the next threshold"
            return None
        return faces if found else None

    def forward2(self, x, y, thresholds=None):
        """Blend ``x`` and ``y`` with a sigmoid-thresholded FaRL mask of ``y``.

        Detection first uses ``self.face_detector``. If that raises or finds
        no face, fresh detectors with progressively lower thresholds are
        tried. Only the first detected face's mask (index ``[0]``) is used,
        so this presumably targets single-image batches -- TODO confirm.

        Args:
            x, y: image batches. ``y`` is passed through ``self.to_farl``,
                scaled by 255 and cast to ``long`` before detection.
            thresholds: fallback detection thresholds, tried in order.
                Defaults to ``FALLBACK_THRESHOLDS``.

        Raises:
            AssertionError: if no face is detected at any threshold. This is
                the same exception type the earlier ``assert`` produced.
        """
        if thresholds is None:
            thresholds = self.FALLBACK_THRESHOLDS
        batch = (255. * self.to_farl(y)).long()

        faces = self._detect_or_none(self.face_detector, batch)
        for threshold in thresholds:
            if faces is not None:
                break
            try:
                detector = self._make_detector(threshold)
            except Exception:  # e.g. weights failed to load: keep falling back
                continue
            faces = self._detect_or_none(detector, batch)
        if faces is None:
            raise AssertionError("no face detected at any detection threshold")

        faces = self.face_parser(batch, faces)
        farl_mask = self.sigm(faces['seg']['logits'][:, 0])
        # Hard mask from the first face only; sigmoid(channel 0) >= 0.995.
        farl_mask = (farl_mask >= 0.995).float()[0]
        return (1 - farl_mask) * x + farl_mask * y
class Masker(nn.Module):
    """Produce a 3-channel mask from FaRL face parsing of an image batch.

    Construction builds a CUDA RetinaFace detector with threshold ``trash``
    and a CUDA FaRL parser.
    """

    def __init__(self, trash=0.8):
        super().__init__()
        self.face_detector = RetinaFaceDetector(trash=trash).cuda().eval()
        self.face_parser = FaRLFaceParser().cuda().eval()

    def get_u_idxs(self, all_indexes):
        """Collect the first position of every image id from 0 up to ``all_indexes[-1]``."""
        last_id = all_indexes[-1]
        first_hits = [
            (all_indexes == image_id).nonzero(as_tuple=True)[0][0]
            for image_id in range(last_id + 1)
        ]
        return torch.tensor(first_hits)

    def get_mask(self, y, threshold=0.5):
        """Mask pixels whose channel-0 softmax probability reaches ``threshold``.

        One mask per image id (its first detected face) is returned, with
        shape ``(last_image_id + 1, 1, H, W)`` and the dtype of ``y``.
        """
        detections = self.face_detector(y)
        parsed = self.face_parser(y, detections)
        probs = parsed['seg']['logits'].softmax(dim=1)
        first_faces = self.get_u_idxs(parsed['image_ids'])
        hit = probs[first_faces, 0, :, :] >= threshold
        return hit.to(y.dtype).unsqueeze(1)

    def forward(self, x):
        """Return the mask of ``255 * x`` repeated across 3 channels."""
        single_channel = self.get_mask(255. * x)
        return single_channel.repeat(1, 3, 1, 1)