|
|
import os |
|
|
import sys |
|
|
import cv2 |
|
|
import re |
|
|
from shared import path_manager |
|
|
import modules.async_worker as worker |
|
|
from tqdm import tqdm |
|
|
|
|
|
from modules.settings import default_settings |
|
|
from modules.util import generate_temp_filename |
|
|
|
|
|
from PIL import Image |
|
|
import imageio.v3 as iio |
|
|
import numpy as np |
|
|
import torch |
|
|
import insightface |
|
|
|
|
|
from importlib.abc import MetaPathFinder, Loader |
|
|
from importlib.util import spec_from_loader, module_from_spec |
|
|
|
|
|
class ImportRedirector(MetaPathFinder): |
|
|
def __init__(self, redirect_map): |
|
|
self.redirect_map = redirect_map |
|
|
|
|
|
def find_spec(self, fullname, path, target=None): |
|
|
if fullname in self.redirect_map: |
|
|
return spec_from_loader(fullname, ImportLoader(self.redirect_map[fullname])) |
|
|
return None |
|
|
|
|
|
class ImportLoader(Loader): |
|
|
def __init__(self, redirect): |
|
|
self.redirect = redirect |
|
|
|
|
|
def create_module(self, spec): |
|
|
return None |
|
|
|
|
|
def exec_module(self, module): |
|
|
import importlib |
|
|
redirected = importlib.import_module(self.redirect) |
|
|
module.__dict__.update(redirected.__dict__) |
|
|
|
|
|
|
|
|
redirect_map = { |
|
|
'torchvision.transforms.functional_tensor': 'torchvision.transforms.functional' |
|
|
} |
|
|
|
|
|
sys.meta_path.insert(0, ImportRedirector(redirect_map)) |
|
|
|
|
|
import gfpgan |
|
|
from facexlib.utils.face_restoration_helper import FaceRestoreHelper |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class pipeline: |
|
|
pipeline_type = ["faceswap"] |
|
|
|
|
|
analyser_model = None |
|
|
analyser_hash = "" |
|
|
swapper_model = None |
|
|
swapper_hash = "" |
|
|
gfpgan_model = None |
|
|
|
|
|
def parse_gen_data(self, gen_data): |
|
|
gen_data["original_image_number"] = gen_data["image_number"] |
|
|
gen_data["image_number"] = 1 |
|
|
gen_data["show_preview"] = False |
|
|
return gen_data |
|
|
|
|
|
def load_base_model(self, name): |
|
|
model_name = "inswapper_128.onnx" |
|
|
if not self.swapper_hash == model_name: |
|
|
print(f"Loading swapper model: {model_name}") |
|
|
model_path = os.path.join(path_manager.model_paths["faceswap_path"], model_name) |
|
|
try: |
|
|
with open(os.devnull, "w") as sys.stdout: |
|
|
self.swapper_model = insightface.model_zoo.get_model( |
|
|
model_path, |
|
|
download=False, |
|
|
download_zip=False, |
|
|
) |
|
|
self.swapper_hash = model_name |
|
|
sys.stdout = sys.__stdout__ |
|
|
except: |
|
|
print(f"Failed loading model! {model_path}") |
|
|
|
|
|
model_name = "buffalo_l" |
|
|
det_thresh = 0.5 |
|
|
if not self.analyser_hash == model_name: |
|
|
print(f"Loading analyser model: {model_name}") |
|
|
try: |
|
|
with open(os.devnull, "w") as sys.stdout: |
|
|
self.analyser_model = insightface.app.FaceAnalysis(name=model_name) |
|
|
self.analyser_model.prepare( |
|
|
ctx_id=0, det_thresh=det_thresh, det_size=(640, 640) |
|
|
) |
|
|
self.analyser_hash = model_name |
|
|
sys.stdout = sys.__stdout__ |
|
|
except: |
|
|
print(f"Failed loading model! {model_name}") |
|
|
|
|
|
def load_gfpgan_model(self): |
|
|
if self.gfpgan_model is None: |
|
|
channel_multiplier = 2 |
|
|
|
|
|
model_name = "GFPGANv1.4.pth" |
|
|
model_path = os.path.join(path_manager.model_paths["faceswap_path"], model_name) |
|
|
|
|
|
|
|
|
self.gfpgan_model = gfpgan.GFPGANer |
|
|
self.gfpgan_model.bg_upsampler = None |
|
|
|
|
|
self.gfpgan_model.device = torch.device( |
|
|
"cuda" if torch.cuda.is_available() else "cpu" |
|
|
) |
|
|
|
|
|
upscale = 2 |
|
|
self.gfpgan_model.face_helper = FaceRestoreHelper( |
|
|
upscale, |
|
|
det_model="retinaface_resnet50", |
|
|
model_rootpath=path_manager.model_paths["faceswap_path"], |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.gfpgan_model.gfpgan = gfpgan.GFPGANv1Clean( |
|
|
out_size=512, |
|
|
num_style_feat=512, |
|
|
channel_multiplier=channel_multiplier, |
|
|
decoder_load_path=None, |
|
|
fix_decoder=False, |
|
|
num_mlp=8, |
|
|
input_is_latent=True, |
|
|
different_w=True, |
|
|
narrow=1, |
|
|
sft_half=True, |
|
|
) |
|
|
|
|
|
loadnet = torch.load(model_path) |
|
|
if "params_ema" in loadnet: |
|
|
keyname = "params_ema" |
|
|
else: |
|
|
keyname = "params" |
|
|
self.gfpgan_model.gfpgan.load_state_dict(loadnet[keyname], strict=True) |
|
|
self.gfpgan_model.gfpgan.eval() |
|
|
self.gfpgan_model.gfpgan = self.gfpgan_model.gfpgan.to( |
|
|
self.gfpgan_model.device |
|
|
) |
|
|
|
|
|
def load_keywords(self, lora): |
|
|
return "" |
|
|
|
|
|
def load_loras(self, loras): |
|
|
return |
|
|
|
|
|
def refresh_controlnet(self, name=None): |
|
|
return |
|
|
|
|
|
def clean_prompt_cond_caches(self): |
|
|
return |
|
|
|
|
|
def swap_faces(self, original_image, input_faces, out_faces): |
|
|
idx = 0 |
|
|
for out_face in out_faces: |
|
|
original_image = self.swapper_model.get( |
|
|
original_image, |
|
|
out_face, |
|
|
input_faces[idx % len(input_faces)], |
|
|
paste_back=True, |
|
|
) |
|
|
idx += 1 |
|
|
return original_image |
|
|
|
|
|
def restore_faces(self, image): |
|
|
self.load_gfpgan_model() |
|
|
|
|
|
image_bgr = image[:, :, ::-1] |
|
|
_cropped_faces, _restored_faces, gfpgan_output_bgr = self.gfpgan_model.enhance( |
|
|
self.gfpgan_model, |
|
|
image_bgr, |
|
|
has_aligned=False, |
|
|
only_center_face=False, |
|
|
paste_back=True, |
|
|
weight=0.5, |
|
|
) |
|
|
image = gfpgan_output_bgr[:, :, ::-1] |
|
|
|
|
|
return image |
|
|
|
|
|
def process( |
|
|
self, |
|
|
gen_data=None, |
|
|
callback=None, |
|
|
): |
|
|
worker.add_result( |
|
|
gen_data["task_id"], |
|
|
"preview", |
|
|
(-1, f"Generating ...", None) |
|
|
) |
|
|
|
|
|
input_image = gen_data["input_image"] |
|
|
input_image = cv2.cvtColor(np.asarray(input_image), cv2.COLOR_RGB2BGR) |
|
|
input_faces = sorted( |
|
|
self.analyser_model.get(input_image), key=lambda x: x.bbox[0] |
|
|
) |
|
|
|
|
|
prompt = gen_data["prompt"].strip() |
|
|
if re.fullmatch("https?://.*\.gif", prompt, re.IGNORECASE) is not None: |
|
|
x = iio.immeta(prompt) |
|
|
duration = x["duration"] |
|
|
loop = x["loop"] |
|
|
gif = cv2.VideoCapture(prompt) |
|
|
|
|
|
|
|
|
in_imgs = [] |
|
|
out_imgs = [] |
|
|
while True: |
|
|
ret, frame = gif.read() |
|
|
if not ret: |
|
|
break |
|
|
in_imgs.append(frame) |
|
|
|
|
|
with tqdm(total=len(in_imgs), desc="Groop", unit="frames") as progress: |
|
|
i=0 |
|
|
steps=len(in_imgs) |
|
|
for frame in in_imgs: |
|
|
out_faces = sorted( |
|
|
self.analyser_model.get(frame), key=lambda x: x.bbox[0] |
|
|
) |
|
|
frame = self.swap_faces(frame, input_faces, out_faces) |
|
|
out_imgs.append( |
|
|
Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) |
|
|
) |
|
|
i+=1 |
|
|
callback(i, 0, 0, steps, out_imgs[-1]) |
|
|
progress.update(1) |
|
|
images = generate_temp_filename( |
|
|
folder=path_manager.model_paths["temp_outputs_path"], extension="gif" |
|
|
) |
|
|
os.makedirs(os.path.dirname(images), exist_ok=True) |
|
|
out_imgs[0].save( |
|
|
images, |
|
|
save_all=True, |
|
|
append_images=out_imgs[1:], |
|
|
optimize=True, |
|
|
duration=duration, |
|
|
loop=loop, |
|
|
) |
|
|
else: |
|
|
output_image = cv2.imread(gen_data["main_view"]) |
|
|
if output_image is None: |
|
|
images = "html/error.png" |
|
|
else: |
|
|
output_faces = sorted( |
|
|
self.analyser_model.get(output_image), key=lambda x: x.bbox[0] |
|
|
) |
|
|
result_image = self.swap_faces(output_image, input_faces, output_faces) |
|
|
result_image = self.restore_faces(result_image) |
|
|
images = Image.fromarray(cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB)) |
|
|
|
|
|
return [images] |
|
|
|