import os
import subprocess
import threading
import time
import bisect
import inspect
from math import floor, ceil

import cv2
import numpy as np
import tkinter as tk
from PIL import Image, ImageTk
from skimage import transform as trans

import onnxruntime
import torch
import torchvision
from torchvision import transforms
from torchvision.transforms.functional import normalize
torchvision.disable_beta_transforms_warning()
from torchvision.transforms import v2

torch.set_grad_enabled(False)
onnxruntime.set_default_logger_severity(4)

device = 'cuda'
lock = threading.Lock()

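# Note: `lock` serializes access to the shared cv2.VideoCapture across the
# worker threads spawned in process(); VideoCapture.read() is not thread-safe,
# and serial reads also keep decoded frames arriving in order.
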
|
class VideoManager():
    def __init__(self, models):
        self.models = models

        # Model handles and ONNX session metadata
        self.swapper_model = []
        self.input_names = []
        self.input_size = []
        self.output_names = []

        # ArcFace 5-point alignment template (112x112 space)
        self.arcface_dst = np.array([[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366], [41.5493, 92.3655], [70.7299, 92.2041]], dtype=np.float32)

        self.video_file = []

        # FFHQ-aligned 5-point template used by the restorer alignment
        self.FFHQ_kps = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935], [201.26117, 371.41043], [313.08905, 371.15118]])

        # Playback state
        self.capture = []
        self.is_video_loaded = False
        self.video_frame_total = None
        self.play = False
        self.current_frame = 0
        self.create_video = False
        self.output_video = []
        self.file_name = []

        self.frame_timer = 0.0

        # Queues shared with the GUI
        self.action_q = []
        self.frame_q = []
        self.r_frame_q = []
        self.read_video_frame_q = []

        self.found_faces = []
        self.parameters = []

        self.target_video = []
        self.fps = 1.0
        self.temp_file = []

        self.clip_session = []

        self.start_time = []
        self.record = False
        self.output = []
        self.image = []

        self.saved_video_path = []
        self.sp = []
        self.timer = []
        self.fps_average = []
        self.total_thread_time = 0.0

        self.start_play_time = []
        self.start_play_frame = []

        self.rec_thread = []
        self.markers = []
        self.is_image_loaded = False
        self.stop_marker = -1
        self.perf_test = False

        self.control = []

        # Per-thread bookkeeping template for frame processing
        self.process_q = {
            "Thread": [],
            "FrameNumber": [],
            "ProcessedFrame": [],
            "Status": 'clear',
            "ThreadTime": []
        }
        self.process_qs = []
        self.rec_q = {
            "Thread": [],
            "FrameNumber": [],
            "Status": 'clear'
        }
        self.rec_qs = []

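    # A process_qs slot cycles through, e.g. (illustrative values only):
    #   {'Thread': <Thread>, 'FrameNumber': 42, 'ProcessedFrame': [],
    #    'Status': 'started', 'ThreadTime': 1690000000.0}   # dispatched
    #   {'Thread': <Thread>, 'FrameNumber': 42, 'ProcessedFrame': <ndarray>,
    #    'Status': 'finished', 'ThreadTime': 0.031}         # worker done
    #   {'Thread': [], 'FrameNumber': [], 'ProcessedFrame': <ndarray>,
    #    'Status': 'clear', 'ThreadTime': []}               # drained by process()
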
|
    def assign_found_faces(self, found_faces):
        self.found_faces = found_faces

|
    def load_target_video(self, file):
        # Release any previously loaded video first
        if self.capture:
            self.capture.release()

        self.video_file = file
        self.capture = cv2.VideoCapture(file)
        self.fps = self.capture.get(cv2.CAP_PROP_FPS)

        if not self.capture.isOpened():
            print("Cannot open file: ", file)
        else:
            self.target_video = file
            self.is_video_loaded = True
            self.is_image_loaded = False
            self.video_frame_total = int(self.capture.get(cv2.CAP_PROP_FRAME_COUNT))
            self.play = False
            self.current_frame = 0
            self.frame_timer = time.time()
            self.frame_q = []
            self.r_frame_q = []
            self.found_faces = []
            self.add_action("set_slider_length", self.video_frame_total-1)

        self.capture.set(cv2.CAP_PROP_POS_FRAMES, self.current_frame)
        success, image = self.capture.read()

        if success:
            crop = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            temp = [crop, False]
            self.r_frame_q.append(temp)
            self.capture.set(cv2.CAP_PROP_POS_FRAMES, self.current_frame)

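    # A minimal driver sketch (hypothetical names; in the app the GUI supplies
    # the `parameters` and `control` dicts before any frame is requested):
    #   vm = VideoManager(models)
    #   vm.parameters, vm.control = gui_parameters, gui_control
    #   vm.load_target_video('clip.mp4')
    #   vm.get_requested_video_frame(0)
    #   frame, frame_number = vm.get_requested_frame()
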
|
    def load_target_image(self, file):
        if self.capture:
            self.capture.release()
        self.is_video_loaded = False
        self.play = False
        self.frame_q = []
        self.r_frame_q = []
        self.found_faces = []
        self.image = cv2.imread(file)
        self.image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)
        temp = [self.image, False]
        self.frame_q.append(temp)

        self.is_image_loaded = True

|
    def add_action(self, action, param):
        temp = [action, param]
        self.action_q.append(temp)

    def get_action_length(self):
        return len(self.action_q)

    def get_action(self):
        action = self.action_q[0]
        self.action_q.pop(0)
        return action

    def get_frame(self):
        frame = self.frame_q[0]
        self.frame_q.pop(0)
        return frame

    def get_frame_length(self):
        return len(self.frame_q)

    def get_requested_frame(self):
        frame = self.r_frame_q[0]
        self.r_frame_q.pop(0)
        return frame

    def get_requested_frame_length(self):
        return len(self.r_frame_q)

|
    def get_requested_video_frame(self, frame, marker=True):
        temp = []
        if self.is_video_loaded:
            if self.play == True:
                self.play_video("stop")
                self.process_qs = []

            self.current_frame = int(frame)
            self.capture.set(cv2.CAP_PROP_POS_FRAMES, self.current_frame)
            success, target_image = self.capture.read()

            if success:
                target_image = cv2.cvtColor(target_image, cv2.COLOR_BGR2RGB)
                if not self.control['SwapFacesButton']:
                    temp = [target_image, self.current_frame]
                else:
                    temp = [self.swap_video(target_image, self.current_frame, marker), self.current_frame]
                self.r_frame_q.append(temp)

        elif self.is_image_loaded:
            if not self.control['SwapFacesButton']:
                temp = [self.image, self.current_frame]
            else:
                temp = [self.swap_video(self.image, self.current_frame, False), self.current_frame]
            self.r_frame_q.append(temp)

|
    def find_lowest_frame(self, queues):
        min_frame = 999999999
        index = -1

        for idx, thread in enumerate(queues):
            frame = thread['FrameNumber']
            if frame != []:
                if frame < min_frame:
                    min_frame = frame
                    index = idx
        return index, min_frame

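    # Example: with slot FrameNumbers [17, [], 15, 16], find_lowest_frame
    # returns (2, 15): slot 2 holds the next frame in display order.
    # If every slot is empty it returns (-1, 999999999).
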
|
    def play_video(self, command):
        if command == "play":
            self.play = True
            self.fps_average = []
            self.process_qs = []
            self.capture.set(cv2.CAP_PROP_POS_FRAMES, self.current_frame)
            self.frame_timer = time.time()

            # One bookkeeping slot per worker thread
            for i in range(self.parameters['ThreadsSlider']):
                new_process_q = self.process_q.copy()
                self.process_qs.append(new_process_q)

            if self.control['AudioButton']:
                seek_time = self.current_frame/self.fps
                args = ["ffplay",
                        '-vn',
                        '-ss', str(seek_time),
                        '-nodisp',
                        '-stats',
                        '-loglevel', 'quiet',
                        '-sync', 'audio',
                        self.video_file]

                self.audio_sp = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)

                # Parse ffplay's stats output to find where audio playback
                # actually started, then re-seek the video to match it.
                while True:
                    temp = self.audio_sp.stdout.read(69)
                    if temp[:7] != b'    nan':
                        sought_time = float(temp[:7])
                        self.current_frame = int(self.fps*sought_time)
                        self.capture.set(cv2.CAP_PROP_POS_FRAMES, self.current_frame)
                        break

        elif command == "stop":
            self.play = False
            self.add_action("stop_play", True)

            index, min_frame = self.find_lowest_frame(self.process_qs)
            if index != -1:
                self.current_frame = min_frame-1

            if self.control['AudioButton']:
                self.audio_sp.terminate()

            torch.cuda.empty_cache()

        elif command == 'stop_from_gui':
            self.play = False

            index, min_frame = self.find_lowest_frame(self.process_qs)
            if index != -1:
                self.current_frame = min_frame-1

            if self.control['AudioButton']:
                self.audio_sp.terminate()

            torch.cuda.empty_cache()

        elif command == "record":
            self.record = True
            self.play = True
            self.total_thread_time = 0.0
            self.process_qs = []
            self.capture.set(cv2.CAP_PROP_POS_FRAMES, self.current_frame)

            for i in range(self.parameters['ThreadsSlider']):
                new_process_q = self.process_q.copy()
                self.process_qs.append(new_process_q)

            self.timer = time.time()
            frame_width = int(self.capture.get(cv2.CAP_PROP_FRAME_WIDTH))
            frame_height = int(self.capture.get(cv2.CAP_PROP_FRAME_HEIGHT))

            self.start_time = float(self.capture.get(cv2.CAP_PROP_POS_FRAMES) / float(self.fps))

            self.file_name = os.path.splitext(os.path.basename(self.target_video))
            base_filename = self.file_name[0]+"_"+str(time.time())[:10]
            self.output = os.path.join(self.saved_video_path, base_filename)
            self.temp_file = self.output+"_temp"+self.file_name[1]

            if self.parameters['RecordTypeTextSel'] == 'FFMPEG':
                args = ["ffmpeg",
                        '-hide_banner',
                        '-loglevel', 'error',
                        "-an",
                        "-r", str(self.fps),
                        "-i", "pipe:",
                        "-vf", "format=yuvj420p",
                        "-c:v", "libx264",
                        "-crf", str(self.parameters['VideoQualSlider']),
                        "-r", str(self.fps),
                        "-s", str(frame_width)+"x"+str(frame_height),
                        self.temp_file]

                self.sp = subprocess.Popen(args, stdin=subprocess.PIPE)

            elif self.parameters['RecordTypeTextSel'] == 'OPENCV':
                size = (frame_width, frame_height)
                self.sp = cv2.VideoWriter(self.temp_file, cv2.VideoWriter_fourcc(*'mp4v'), self.fps, size)

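    # For reference, the audio branch launches the equivalent of:
    #   ffplay -vn -ss <seek> -nodisp -stats -loglevel quiet -sync audio <file>
    # and FFMPEG recording pipes BMP frames into the equivalent of:
    #   ffmpeg -an -r <fps> -i pipe: -vf format=yuvj420p -c:v libx264
    #          -crf <qual> -r <fps> -s <w>x<h> <temp_file>
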
|
    def process(self):
        # Dispatch read/swap work to idle worker slots
        if self.play == True and self.is_video_loaded == True:
            for item in self.process_qs:
                if item['Status'] == 'clear' and self.current_frame < self.video_frame_total:
                    # Populate the slot before starting the worker so that
                    # thread_video_read can find its entry by frame number
                    # (Thread.start() returns None, so store the object first).
                    item['FrameNumber'] = self.current_frame
                    item['Status'] = 'started'
                    item['ThreadTime'] = time.time()
                    item['Thread'] = threading.Thread(target=self.thread_video_read, args=[self.current_frame])
                    item['Thread'].start()

                    self.current_frame += 1
                    break
        else:
            self.play = False

        time_diff = time.time() - self.frame_timer

        # Playback: release finished frames at the video's native rate
        if not self.record and time_diff >= 1.0/float(self.fps) and self.play:
            index, min_frame = self.find_lowest_frame(self.process_qs)

            if index != -1:
                if self.process_qs[index]['Status'] == 'finished':
                    temp = [self.process_qs[index]['ProcessedFrame'], self.process_qs[index]['FrameNumber']]
                    self.frame_q.append(temp)

                    self.fps_average.append(1.0/time_diff)
                    if len(self.fps_average) >= floor(self.fps):
                        fps = round(np.average(self.fps_average), 2)
                        msg = "%s fps, %s process time" % (fps, round(self.process_qs[index]['ThreadTime'], 4))
                        self.fps_average = []

                    if self.process_qs[index]['FrameNumber'] >= self.video_frame_total-1 or self.process_qs[index]['FrameNumber'] == self.stop_marker:
                        self.play_video('stop')

                    self.process_qs[index]['Status'] = 'clear'
                    self.process_qs[index]['Thread'] = []
                    self.process_qs[index]['FrameNumber'] = []
                    self.process_qs[index]['ThreadTime'] = []
                    self.frame_timer += 1.0/self.fps

        # Recording: drain frames in order and feed the encoder
        elif self.record:
            index, min_frame = self.find_lowest_frame(self.process_qs)

            if index != -1:
                if self.process_qs[index]['Status'] == 'finished':
                    image = self.process_qs[index]['ProcessedFrame']

                    if self.parameters['RecordTypeTextSel'] == 'FFMPEG':
                        pil_image = Image.fromarray(image)
                        pil_image.save(self.sp.stdin, 'BMP')
                    elif self.parameters['RecordTypeTextSel'] == 'OPENCV':
                        self.sp.write(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

                    temp = [image, self.process_qs[index]['FrameNumber']]
                    self.frame_q.append(temp)

                    if self.process_qs[index]['FrameNumber'] >= self.video_frame_total-1 or self.process_qs[index]['FrameNumber'] == self.stop_marker or self.play == False:
                        self.play_video("stop")
                        stop_time = float(self.capture.get(cv2.CAP_PROP_POS_FRAMES) / float(self.fps))
                        if stop_time == 0:
                            stop_time = float(self.video_frame_total) / float(self.fps)

                        if self.parameters['RecordTypeTextSel'] == 'FFMPEG':
                            self.sp.stdin.close()
                            self.sp.wait()
                        elif self.parameters['RecordTypeTextSel'] == 'OPENCV':
                            self.sp.release()

                        # Mux the original audio back into the encoded video
                        orig_file = self.target_video
                        final_file = self.output+self.file_name[1]
                        print("adding audio...")
                        args = ["ffmpeg",
                                '-hide_banner',
                                '-loglevel', 'error',
                                "-i", self.temp_file,
                                "-ss", str(self.start_time), "-to", str(stop_time), "-i", orig_file,
                                "-c", "copy",
                                "-map", "0:v:0", "-map", "1:a:0?",
                                "-shortest",
                                final_file]

                        subprocess.run(args)
                        os.remove(self.temp_file)

                        timef = time.time() - self.timer
                        self.record = False
                        print('Video saved as:', final_file)
                        msg = "Total time: %s s." % (round(timef, 1))
                        print(msg)

                    self.total_thread_time = []
                    self.process_qs[index]['Status'] = 'clear'
                    self.process_qs[index]['FrameNumber'] = []
                    self.process_qs[index]['Thread'] = []
                    self.frame_timer = time.time()

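    # Pacing note: on playback, frame_timer advances by exactly 1/fps per
    # emitted frame instead of resetting to time.time(), so a late frame is
    # absorbed by releasing the following frames slightly early.
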
|
    def thread_video_read(self, frame_number):
        with lock:
            success, target_image = self.capture.read()

        if success:
            target_image = cv2.cvtColor(target_image, cv2.COLOR_BGR2RGB)
            if not self.control['SwapFacesButton']:
                temp = [target_image, frame_number]
            else:
                temp = [self.swap_video(target_image, frame_number, True), frame_number]

            for item in self.process_qs:
                if item['FrameNumber'] == frame_number:
                    item['ProcessedFrame'] = temp[0]
                    item['Status'] = 'finished'
                    item['ThreadTime'] = time.time() - item['ThreadTime']
                    break

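    # Only capture.read() runs under the lock: decoding stays sequential (and
    # in order) while the expensive swap_video() calls overlap across threads.
    # Results are matched back to their slot by FrameNumber, not by thread.
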
|
    def swap_video(self, target_image, frame_number, use_markers):
        # Snapshot the GUI state so mid-frame changes don't tear
        parameters = self.parameters.copy()
        control = self.control.copy()

        # With timeline markers enabled, use the parameters of the most
        # recent marker at or before this frame
        if self.markers and use_markers:
            temp = []
            for i in range(len(self.markers)):
                temp.append(self.markers[i]['frame'])
            idx = bisect.bisect(temp, frame_number)
            parameters = self.markers[idx-1]['parameters'].copy()

        img = torch.from_numpy(target_image.astype('uint8')).to('cuda')  # HxWxC RGB
        img = img.permute(2, 0, 1)  # CxHxW

        img_x = img.size()[2]
        img_y = img.size()[1]

        # Upscale small frames so the smaller dimension reaches 512
        if img_x < 512 and img_y < 512:
            if img_x <= img_y:
                tscale = v2.Resize((int(512*img_y/img_x), 512), antialias=True)
            else:
                tscale = v2.Resize((512, int(512*img_x/img_y)), antialias=True)
            img = tscale(img)
        elif img_x < 512:
            tscale = v2.Resize((int(512*img_y/img_x), 512), antialias=True)
            img = tscale(img)
        elif img_y < 512:
            tscale = v2.Resize((512, int(512*img_x/img_y)), antialias=True)
            img = tscale(img)

        # Optionally rotate the frame before detection
        if parameters['OrientSwitch']:
            img = v2.functional.rotate(img, angle=parameters['OrientSlider'], interpolation=v2.InterpolationMode.BILINEAR, expand=True)

        # Detect faces and their keypoints
        bboxes, kpss = self.func_w_test("detect", self.models.run_detect, img, parameters['DetectTypeTextSel'], max_num=20, score=parameters['DetectScoreSlider']/100.0, use_landmark_detection=parameters['LandmarksDetectionAdjSwitch'], landmark_detect_mode=parameters["LandmarksDetectTypeTextSel"], landmark_score=parameters["LandmarksDetectScoreSlider"]/100.0, from_points=parameters["LandmarksAlignModeFromPointsSwitch"])

        # Compute an embedding for every detected face
        ret = []
        for face_kps in kpss:
            face_emb, _ = self.func_w_test('recognize', self.models.run_recognize, img, face_kps)
            ret.append([face_kps, face_emb])

        if ret:
            # Swap every detected face that matches an assigned source face
            for fface in ret:
                for found_face in self.found_faces:
                    sim = self.findCosineDistance(fface[1], found_face["Embedding"])
                    if sim >= float(parameters["ThresholdSlider"]) and found_face["SourceFaceAssignments"]:
                        s_e = found_face["AssignedEmbedding"]
                        img = self.func_w_test("swap_video", self.swap_core, img, fface[0], s_e, parameters, control)

            img = img.permute(1, 2, 0)
            if not control['MaskViewButton'] and parameters['OrientSwitch']:
                img = img.permute(2, 0, 1)
                img = transforms.functional.rotate(img, angle=-parameters['OrientSlider'], expand=True)
                img = img.permute(1, 2, 0)
        else:
            img = img.permute(1, 2, 0)
            if parameters['OrientSwitch']:
                img = img.permute(2, 0, 1)
                img = v2.functional.rotate(img, angle=-parameters['OrientSlider'], interpolation=v2.InterpolationMode.BILINEAR, expand=True)
                img = img.permute(1, 2, 0)

        if self.perf_test:
            print('------------------------')

        # Scale small frames back to their original size
        if img_x < 512 or img_y < 512:
            tscale = v2.Resize((img_y, img_x), antialias=True)
            img = img.permute(2, 0, 1)
            img = tscale(img)
            img = img.permute(1, 2, 0)

        img = img.cpu().numpy()

        if parameters["ShowLandmarksSwitch"]:
            if ret:
                if img_y <= 720:
                    p = 1
                else:
                    p = 2

                # Paint a small cyan square at each keypoint
                for face in ret:
                    for kpoint in face[0]:
                        for i in range(-1, p):
                            for j in range(-1, p):
                                try:
                                    img[int(kpoint[1])+i][int(kpoint[0])+j][0] = 0
                                    img[int(kpoint[1])+i][int(kpoint[0])+j][1] = 255
                                    img[int(kpoint[1])+i][int(kpoint[0])+j][2] = 255
                                except IndexError:
                                    print("Keypoint {} lies outside the image bounds {}.".format(kpoint, (img_x, img_y)))
                                    continue

        return img.astype(np.uint8)

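    # Marker lookup example: with markers at frames [0, 120, 300],
    # bisect.bisect([0, 120, 300], 150) == 2, so frame 150 uses the
    # parameters saved on the marker at frame 120.
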
|
    def findCosineDistance(self, vector1, vector2):
        # Cosine similarity mapped onto a 0-100 scale for ThresholdSlider
        vector1 = vector1.ravel()
        vector2 = vector2.ravel()
        cos_dist = 1.0 - np.dot(vector1, vector2)/(np.linalg.norm(vector1)*np.linalg.norm(vector2))
        return 100.0-cos_dist*50.0

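    # Worked example: the value returned is 50 + 50*cos_sim, so identical
    # embeddings score 100, orthogonal ones 50, and opposites 0, matching
    # the 0-100 range of ThresholdSlider.
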
|
    def func_w_test(self, name, func, *args, **argsv):
        timing = time.time()
        result = func(*args, **argsv)
        if self.perf_test:
            print(name, round(time.time()-timing, 5), 's')
        return result

|
    def swap_core(self, img, kps, s_e, parameters, control):
        # Scale the ArcFace template from 112 to 448 and center it in a 512 crop
        dst = self.arcface_dst * 4.0
        dst[:, 0] += 32.0

        # Optional keypoint nudging/scaling about the crop center
        if parameters['FaceAdjSwitch']:
            dst[:, 0] += parameters['KPSXSlider']
            dst[:, 1] += parameters['KPSYSlider']
            dst[:, 0] -= 255
            dst[:, 0] *= (1+parameters['KPSScaleSlider']/100)
            dst[:, 0] += 255
            dst[:, 1] -= 255
            dst[:, 1] *= (1+parameters['KPSScaleSlider']/100)
            dst[:, 1] += 255

        tform = trans.SimilarityTransform()
        tform.estimate(kps, dst)

        t512 = v2.Resize((512, 512), interpolation=v2.InterpolationMode.BILINEAR, antialias=False)
        t256 = v2.Resize((256, 256), interpolation=v2.InterpolationMode.BILINEAR, antialias=False)
        t128 = v2.Resize((128, 128), interpolation=v2.InterpolationMode.BILINEAR, antialias=False)

        # Align and crop the face at three working resolutions.
        # tform.rotation is radians; v2 affine expects degrees (57.2958 = 180/pi).
        original_face_512 = v2.functional.affine(img, tform.rotation*57.2958, (tform.translation[0], tform.translation[1]), tform.scale, 0, center=(0, 0), interpolation=v2.InterpolationMode.BILINEAR)
        original_face_512 = v2.functional.crop(original_face_512, 0, 0, 512, 512)
        original_face_256 = t256(original_face_512)
        original_face_128 = t128(original_face_256)

        latent = torch.from_numpy(self.models.calc_swapper_latent(s_e)).float().to('cuda')

        # The swapper natively works on 128x128 inputs; dim is the tiling factor
        dim = 1
        if parameters['SwapperTypeTextSel'] == '128':
            dim = 1
            input_face_affined = original_face_128
        elif parameters['SwapperTypeTextSel'] == '256':
            dim = 2
            input_face_affined = original_face_256
        elif parameters['SwapperTypeTextSel'] == '512':
            dim = 4
            input_face_affined = original_face_512

        if parameters['FaceAdjSwitch']:
            input_face_affined = v2.functional.affine(input_face_affined, 0, (0, 0), 1 + parameters['FaceScaleSlider'] / 100, 0, center=(dim*128-1, dim*128-1), interpolation=v2.InterpolationMode.BILINEAR)

        # Strength > 100% repeats the swap; the remainder is blended in below
        itex = 1
        if parameters['StrengthSwitch']:
            itex = ceil(parameters['StrengthSlider'] / 100.)

        output_size = int(128 * dim)
        output = torch.zeros((output_size, output_size, 3), dtype=torch.float32, device='cuda')
        input_face_affined = input_face_affined.permute(1, 2, 0)
        input_face_affined = torch.div(input_face_affined, 255.0)

        # Run the 128px swapper over interleaved sub-grids of the input
        for k in range(itex):
            for j in range(dim):
                for i in range(dim):
                    input_face_disc = input_face_affined[j::dim, i::dim]
                    input_face_disc = input_face_disc.permute(2, 0, 1)
                    input_face_disc = torch.unsqueeze(input_face_disc, 0).contiguous()

                    swapper_output = torch.empty((1, 3, 128, 128), dtype=torch.float32, device='cuda').contiguous()
                    self.models.run_swapper(input_face_disc, latent, swapper_output)

                    swapper_output = torch.squeeze(swapper_output)
                    swapper_output = swapper_output.permute(1, 2, 0)

                    output[j::dim, i::dim] = swapper_output.clone()
            prev_face = input_face_affined.clone()
            input_face_affined = output.clone()
            output = torch.mul(output, 255)
            output = torch.clamp(output, 0, 255)

        output = output.permute(2, 0, 1)
        swap = t512(output)

        # Blend the final partial iteration according to the strength remainder
        if parameters['StrengthSwitch']:
            if itex == 0:
                swap = original_face_512.clone()
            else:
                alpha = np.mod(parameters['StrengthSlider'], 100)*0.01
                if alpha == 0:
                    alpha = 1

                prev_face = torch.mul(prev_face, 255)
                prev_face = torch.clamp(prev_face, 0, 255)
                prev_face = prev_face.permute(2, 0, 1)
                prev_face = t512(prev_face)
                swap = torch.mul(swap, alpha)
                prev_face = torch.mul(prev_face, 1-alpha)
                swap = torch.add(swap, prev_face)

        # Optional gamma and per-channel color adjustments
        if parameters['ColorSwitch']:
            swap = torch.unsqueeze(swap, 0)
            swap = v2.functional.adjust_gamma(swap, parameters['ColorGammaSlider'], 1.0)
            swap = torch.squeeze(swap)
            swap = swap.permute(1, 2, 0).type(torch.float32)

            del_color = torch.tensor([parameters['ColorRedSlider'], parameters['ColorGreenSlider'], parameters['ColorBlueSlider']], device=device)
            swap += del_color
            swap = torch.clamp(swap, min=0., max=255.)
            swap = swap.permute(2, 0, 1).type(torch.uint8)

        # Border mask: feathers the crop edges back into the frame
        border_mask = torch.ones((128, 128), dtype=torch.float32, device=device)
        border_mask = torch.unsqueeze(border_mask, 0)

        top = parameters['BorderTopSlider']
        left = parameters['BorderSidesSlider']
        right = 128-parameters['BorderSidesSlider']
        bottom = 128-parameters['BorderBottomSlider']

        border_mask[:, :top, :] = 0
        border_mask[:, bottom:, :] = 0
        border_mask[:, :, :left] = 0
        border_mask[:, :, right:] = 0

        gauss = transforms.GaussianBlur(parameters['BorderBlurSlider']*2+1, (parameters['BorderBlurSlider']+1)*0.2)
        border_mask = gauss(border_mask)

        # Swap mask: accumulates the occluder/parser/CLIP masks
        swap_mask = torch.ones((128, 128), dtype=torch.float32, device=device)
        swap_mask = torch.unsqueeze(swap_mask, 0)

        # Restrict the swap to regions that differ enough from the original
        if parameters["DiffSwitch"]:
            mask = self.apply_fake_diff(swap, original_face_512, parameters["DiffSlider"])
            gauss = transforms.GaussianBlur(parameters['BlendSlider']*2+1, (parameters['BlendSlider']+1)*0.2)
            mask = gauss(mask.type(torch.float32))
            swap = swap*mask + original_face_512*(1-mask)

        if parameters["RestorerSwitch"]:
            swap = self.func_w_test('Restorer', self.apply_restorer, swap, parameters)

        if parameters["OccluderSwitch"]:
            mask = self.func_w_test('occluder', self.apply_occlusion, original_face_256, parameters["OccluderSlider"])
            mask = t128(mask)
            swap_mask = torch.mul(swap_mask, mask)

        if parameters["FaceParserSwitch"]:
            mask = self.apply_face_parser(swap, parameters["FaceParserSlider"], parameters['MouthParserSlider'])
            mask = t128(mask)
            swap_mask = torch.mul(swap_mask, mask)

        if parameters["CLIPSwitch"]:
            with lock:
                mask = self.func_w_test('CLIP', self.apply_CLIPs, original_face_512, parameters["CLIPTextEntry"], parameters["CLIPSlider"])
            mask = cv2.resize(mask, (128, 128))
            mask = torch.from_numpy(mask).to('cuda')
            swap_mask *= mask

        gauss = transforms.GaussianBlur(parameters['BlendSlider']*2+1, (parameters['BlendSlider']+1)*0.2)
        swap_mask = gauss(swap_mask)

        # Combine the masks and apply them to the swapped face
        swap_mask = torch.mul(swap_mask, border_mask)
        swap_mask = t512(swap_mask)
        swap = torch.mul(swap, swap_mask)

        if not control['MaskViewButton']:
            # Paste the swapped crop back into the full frame: invert the
            # alignment transform and find the frame-space bounding box of
            # the 512x512 crop, clipped to the frame.
            IM512 = tform.inverse.params[0:2, :]
            corners = np.array([[0, 0], [0, 511], [511, 0], [511, 511]])

            x = (IM512[0][0]*corners[:, 0] + IM512[0][1]*corners[:, 1] + IM512[0][2])
            y = (IM512[1][0]*corners[:, 0] + IM512[1][1]*corners[:, 1] + IM512[1][2])

            left = floor(np.min(x))
            if left < 0:
                left = 0
            top = floor(np.min(y))
            if top < 0:
                top = 0
            right = ceil(np.max(x))
            if right > img.shape[2]:
                right = img.shape[2]
            bottom = ceil(np.max(y))
            if bottom > img.shape[1]:
                bottom = img.shape[1]

            swap = v2.functional.pad(swap, (0, 0, img.shape[2]-512, img.shape[1]-512))
            swap = v2.functional.affine(swap, tform.inverse.rotation*57.2958, (tform.inverse.translation[0], tform.inverse.translation[1]), tform.inverse.scale, 0, interpolation=v2.InterpolationMode.BILINEAR, center=(0, 0))
            swap = swap[0:3, top:bottom, left:right]
            swap = swap.permute(1, 2, 0)

            swap_mask = v2.functional.pad(swap_mask, (0, 0, img.shape[2]-512, img.shape[1]-512))
            swap_mask = v2.functional.affine(swap_mask, tform.inverse.rotation*57.2958, (tform.inverse.translation[0], tform.inverse.translation[1]), tform.inverse.scale, 0, interpolation=v2.InterpolationMode.BILINEAR, center=(0, 0))
            swap_mask = swap_mask[0:1, top:bottom, left:right]
            swap_mask = swap_mask.permute(1, 2, 0)
            swap_mask = torch.sub(1, swap_mask)

            img_crop = img[0:3, top:bottom, left:right]
            img_crop = img_crop.permute(1, 2, 0)
            img_crop = torch.mul(swap_mask, img_crop)

            swap = torch.add(swap, img_crop)
            swap = swap.type(torch.uint8)
            swap = swap.permute(2, 0, 1)
            img[0:3, top:bottom, left:right] = swap

        else:
            # Mask view: show the composited crop next to the mask itself
            swap_mask = torch.sub(1, swap_mask)

            original_face_512 = torch.mul(swap_mask, original_face_512)
            original_face_512 = torch.add(swap, original_face_512)
            original_face_512 = original_face_512.type(torch.uint8)
            original_face_512 = original_face_512.permute(1, 2, 0)

            swap_mask = torch.sub(1, swap_mask)
            swap_mask = torch.cat((swap_mask, swap_mask, swap_mask), 0)
            swap_mask = swap_mask.permute(1, 2, 0)

            img = torch.hstack([original_face_512, swap_mask*255])
            img = img.permute(2, 0, 1)

        return img

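    # Tiling note: for the 256/512 swapper modes, the strided slices
    # [j::dim, i::dim] split the crop into dim*dim interleaved 128x128
    # sub-images, each covering the whole face. Every sub-image runs through
    # the native 128px swapper, and writing back with the same strides
    # re-interleaves the outputs into the full-resolution result.
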
|
    def apply_occlusion(self, img, amount):
        img = torch.div(img, 255)
        img = torch.unsqueeze(img, 0)
        outpred = torch.ones((256, 256), dtype=torch.float32, device=device).contiguous()

        self.models.run_occluder(img, outpred)

        outpred = torch.squeeze(outpred)
        outpred = (outpred > 0)
        outpred = torch.unsqueeze(outpred, 0).type(torch.float32)

        # Positive amounts dilate the mask; negative amounts erode it
        if amount > 0:
            kernel = torch.ones((1, 1, 3, 3), dtype=torch.float32, device=device)

            for i in range(int(amount)):
                outpred = torch.nn.functional.conv2d(outpred, kernel, padding=(1, 1))
                outpred = torch.clamp(outpred, 0, 1)

            outpred = torch.squeeze(outpred)

        if amount < 0:
            outpred = torch.neg(outpred)
            outpred = torch.add(outpred, 1)
            kernel = torch.ones((1, 1, 3, 3), dtype=torch.float32, device=device)

            for i in range(int(-amount)):
                outpred = torch.nn.functional.conv2d(outpred, kernel, padding=(1, 1))
                outpred = torch.clamp(outpred, 0, 1)

            outpred = torch.squeeze(outpred)
            outpred = torch.neg(outpred)
            outpred = torch.add(outpred, 1)

        outpred = torch.reshape(outpred, (1, 256, 256))
        return outpred

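    # Morphology note: each conv2d with a 3x3 ones kernel followed by
    # clamp(0, 1) dilates the binary mask by one pixel; erosion is the same
    # operation applied to the inverted mask, then inverted back.
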
|
    def apply_CLIPs(self, img, CLIPText, CLIPAmount):
        clip_mask = np.ones((352, 352))
        img = img.permute(1, 2, 0)
        img = img.cpu().numpy()

        # CLIPSeg expects a normalized 352x352 input
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                                        transforms.Resize((352, 352))])
        CLIPimg = transform(img).unsqueeze(0)

        if CLIPText != "":
            prompts = CLIPText.split(',')

            with torch.no_grad():
                preds = self.clip_session(CLIPimg.repeat(len(prompts), 1, 1, 1), prompts)[0]

            # Multiply together the inverted per-prompt masks
            clip_mask = 1 - torch.sigmoid(preds[0][0])
            for i in range(len(prompts)-1):
                clip_mask *= 1-torch.sigmoid(preds[i+1][0])
            clip_mask = clip_mask.data.cpu().numpy()

            # Binarize at the slider threshold
            thresh = CLIPAmount/100.0
            clip_mask[clip_mask > thresh] = 1.0
            clip_mask[clip_mask <= thresh] = 0.0
        return clip_mask

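    # Multi-prompt example: CLIPText "glasses,hand" yields one CLIPSeg
    # prediction per prompt; the keep-mask is the product of the inverted
    # sigmoids, so a pixel is suppressed if any prompt claims it.
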
|
    def apply_face_parser(self, img, FaceAmount, MouthAmount):
        # Segment the face into 19 classes, then build keep-masks for the
        # background and (optionally) the mouth region
        img = torch.div(img, 255)
        img = v2.functional.normalize(img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        img = torch.reshape(img, (1, 3, 512, 512))
        outpred = torch.empty((1, 19, 512, 512), dtype=torch.float32, device='cuda').contiguous()

        self.models.run_faceparser(img, outpred)

        outpred = torch.squeeze(outpred)
        outpred = torch.argmax(outpred, 0)

        # Negative amounts erode the mouth-interior cutout; positive amounts
        # also include the lip classes and dilate the cutout
        if MouthAmount < 0:
            mouth_idxs = torch.tensor([11], device='cuda')
            iters = int(-MouthAmount)

            mouth_parse = torch.isin(outpred, mouth_idxs)
            mouth_parse = torch.clamp(~mouth_parse, 0, 1).type(torch.float32)
            mouth_parse = torch.reshape(mouth_parse, (1, 1, 512, 512))
            mouth_parse = torch.neg(mouth_parse)
            mouth_parse = torch.add(mouth_parse, 1)

            kernel = torch.ones((1, 1, 3, 3), dtype=torch.float32, device='cuda')

            for i in range(iters):
                mouth_parse = torch.nn.functional.conv2d(mouth_parse, kernel, padding=(1, 1))
                mouth_parse = torch.clamp(mouth_parse, 0, 1)

            mouth_parse = torch.squeeze(mouth_parse)
            mouth_parse = torch.neg(mouth_parse)
            mouth_parse = torch.add(mouth_parse, 1)
            mouth_parse = torch.reshape(mouth_parse, (1, 512, 512))

        elif MouthAmount > 0:
            mouth_idxs = torch.tensor([11, 12, 13], device='cuda')
            iters = int(MouthAmount)

            mouth_parse = torch.isin(outpred, mouth_idxs)
            mouth_parse = torch.clamp(~mouth_parse, 0, 1).type(torch.float32)
            mouth_parse = torch.reshape(mouth_parse, (1, 1, 512, 512))
            mouth_parse = torch.neg(mouth_parse)
            mouth_parse = torch.add(mouth_parse, 1)

            kernel = torch.ones((1, 1, 3, 3), dtype=torch.float32, device='cuda')

            for i in range(iters):
                mouth_parse = torch.nn.functional.conv2d(mouth_parse, kernel, padding=(1, 1))
                mouth_parse = torch.clamp(mouth_parse, 0, 1)

            mouth_parse = torch.squeeze(mouth_parse)
            mouth_parse = torch.neg(mouth_parse)
            mouth_parse = torch.add(mouth_parse, 1)
            mouth_parse = torch.reshape(mouth_parse, (1, 512, 512))

        else:
            mouth_parse = torch.ones((1, 512, 512), dtype=torch.float32, device='cuda')

        # Background keep-mask: classes 0 and 14-18 count as non-face
        bg_idxs = torch.tensor([0, 14, 15, 16, 17, 18], device=device)
        bg_parse = torch.isin(outpred, bg_idxs)
        bg_parse = torch.clamp(~bg_parse, 0, 1).type(torch.float32)
        bg_parse = torch.reshape(bg_parse, (1, 1, 512, 512))

        if FaceAmount > 0:
            kernel = torch.ones((1, 1, 3, 3), dtype=torch.float32, device=device)

            for i in range(int(FaceAmount)):
                bg_parse = torch.nn.functional.conv2d(bg_parse, kernel, padding=(1, 1))
                bg_parse = torch.clamp(bg_parse, 0, 1)

            bg_parse = torch.squeeze(bg_parse)

        elif FaceAmount < 0:
            bg_parse = torch.neg(bg_parse)
            bg_parse = torch.add(bg_parse, 1)

            kernel = torch.ones((1, 1, 3, 3), dtype=torch.float32, device=device)

            for i in range(int(-FaceAmount)):
                bg_parse = torch.nn.functional.conv2d(bg_parse, kernel, padding=(1, 1))
                bg_parse = torch.clamp(bg_parse, 0, 1)

            bg_parse = torch.squeeze(bg_parse)
            bg_parse = torch.neg(bg_parse)
            bg_parse = torch.add(bg_parse, 1)
            bg_parse = torch.reshape(bg_parse, (1, 512, 512))
        else:
            bg_parse = torch.ones((1, 512, 512), dtype=torch.float32, device='cuda')

        out_parse = torch.mul(bg_parse, mouth_parse)
        return out_parse

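    # The indices above follow the usual 19-class BiSeNet/CelebAMask-HQ
    # layout: 11 = mouth interior, 12/13 = upper/lower lip, and
    # 0, 14-18 = background, neck, necklace, cloth, hair, hat.
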
|
    def apply_bg_face_parser(self, img, FaceParserAmount):
        outpred = torch.ones((512, 512), dtype=torch.float32, device='cuda').contiguous()

        # 0 means no adjustment, so the all-ones mask passes through
        if FaceParserAmount != 0:
            img = torch.div(img, 255)
            img = v2.functional.normalize(img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
            img = torch.reshape(img, (1, 3, 512, 512))
            outpred = torch.empty((1, 19, 512, 512), dtype=torch.float32, device=device).contiguous()

            self.models.run_faceparser(img, outpred)

            outpred = torch.squeeze(outpred)
            outpred = torch.argmax(outpred, 0)

            test = torch.tensor([0, 14, 15, 16, 17, 18], device=device)
            outpred = torch.isin(outpred, test)
            outpred = torch.clamp(~outpred, 0, 1).type(torch.float32)
            outpred = torch.reshape(outpred, (1, 1, 512, 512))

            if FaceParserAmount > 0:
                kernel = torch.ones((1, 1, 3, 3), dtype=torch.float32, device=device)

                for i in range(int(FaceParserAmount)):
                    outpred = torch.nn.functional.conv2d(outpred, kernel, padding=(1, 1))
                    outpred = torch.clamp(outpred, 0, 1)

                outpred = torch.squeeze(outpred)

            if FaceParserAmount < 0:
                outpred = torch.neg(outpred)
                outpred = torch.add(outpred, 1)

                kernel = torch.ones((1, 1, 3, 3), dtype=torch.float32, device=device)

                for i in range(int(-FaceParserAmount)):
                    outpred = torch.nn.functional.conv2d(outpred, kernel, padding=(1, 1))
                    outpred = torch.clamp(outpred, 0, 1)

                outpred = torch.squeeze(outpred)
                outpred = torch.neg(outpred)
                outpred = torch.add(outpred, 1)

            outpred = torch.reshape(outpred, (1, 512, 512))

        return outpred

|
    def apply_restorer(self, swapped_face_upscaled, parameters):
        temp = swapped_face_upscaled
        t512 = v2.Resize((512, 512), antialias=False)
        t256 = v2.Resize((256, 256), antialias=False)
        t1024 = v2.Resize((1024, 1024), antialias=False)

        # 'Blend' aligns via the ArcFace template; 'Reference' re-detects
        # landmarks for the alignment; anything else uses the face as-is
        if parameters['RestorerDetTypeTextSel'] == 'Blend' or parameters['RestorerDetTypeTextSel'] == 'Reference':
            if parameters['RestorerDetTypeTextSel'] == 'Blend':
                dst = self.arcface_dst * 4.0
                dst[:, 0] += 32.0
            elif parameters['RestorerDetTypeTextSel'] == 'Reference':
                try:
                    dst = self.models.resnet50(swapped_face_upscaled, score=parameters['DetectScoreSlider']/100.0)
                except Exception:
                    return swapped_face_upscaled

            tform = trans.SimilarityTransform()
            tform.estimate(dst, self.FFHQ_kps)

            temp = v2.functional.affine(swapped_face_upscaled, tform.rotation*57.2958, (tform.translation[0], tform.translation[1]), tform.scale, 0, center=(0, 0))
            temp = v2.functional.crop(temp, 0, 0, 512, 512)

        temp = torch.div(temp, 255)
        temp = v2.functional.normalize(temp, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=False)
        if parameters['RestorerTypeTextSel'] == 'GPEN256':
            temp = t256(temp)
        temp = torch.unsqueeze(temp, 0).contiguous()

        outpred = torch.empty((1, 3, 512, 512), dtype=torch.float32, device=device).contiguous()

        if parameters['RestorerTypeTextSel'] == 'GFPGAN':
            self.models.run_GFPGAN(temp, outpred)
        elif parameters['RestorerTypeTextSel'] == 'CF':
            self.models.run_codeformer(temp, outpred)
        elif parameters['RestorerTypeTextSel'] == 'GPEN256':
            outpred = torch.empty((1, 3, 256, 256), dtype=torch.float32, device=device).contiguous()
            self.models.run_GPEN_256(temp, outpred)
        elif parameters['RestorerTypeTextSel'] == 'GPEN512':
            self.models.run_GPEN_512(temp, outpred)
        elif parameters['RestorerTypeTextSel'] == 'GPEN1024':
            temp = t1024(temp)
            outpred = torch.empty((1, 3, 1024, 1024), dtype=torch.float32, device=device).contiguous()
            self.models.run_GPEN_1024(temp, outpred)

        outpred = torch.squeeze(outpred)
        outpred = torch.clamp(outpred, -1, 1)
        outpred = torch.add(outpred, 1)
        outpred = torch.div(outpred, 2)
        outpred = torch.mul(outpred, 255)
        if parameters['RestorerTypeTextSel'] == 'GPEN256':
            outpred = t512(outpred)
        elif parameters['RestorerTypeTextSel'] == 'GPEN1024':
            outpred = t512(outpred)

        if parameters['RestorerDetTypeTextSel'] == 'Blend' or parameters['RestorerDetTypeTextSel'] == 'Reference':
            outpred = v2.functional.affine(outpred, tform.inverse.rotation*57.2958, (tform.inverse.translation[0], tform.inverse.translation[1]), tform.inverse.scale, 0, interpolation=v2.InterpolationMode.BILINEAR, center=(0, 0))

        # Blend the restored face with the input by the slider amount
        alpha = float(parameters["RestorerSlider"])/100.0
        outpred = torch.add(torch.mul(outpred, alpha), torch.mul(swapped_face_upscaled, 1-alpha))

        return outpred

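    # The restorers emit images in [-1, 1]; clamp, +1, /2, *255 maps them
    # back to [0, 255] before the alpha blend with the unrestored input.
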
|
    def apply_fake_diff(self, swapped_face, original_face, DiffAmount):
        swapped_face = swapped_face.permute(1, 2, 0)
        original_face = original_face.permute(1, 2, 0)

        diff = swapped_face-original_face
        diff = torch.abs(diff)

        # Scale the 0-100 slider onto the 0-255 pixel range
        fthresh = DiffAmount*2.55

        # Per-channel threshold, then flag pixels where any channel changed
        diff[diff < fthresh] = 0
        diff[diff >= fthresh] = 1

        diff = torch.sum(diff, dim=2)
        diff = torch.unsqueeze(diff, 2)
        diff[diff > 0] = 1

        diff = diff.permute(2, 0, 1)
        return diff

|
    def clear_mem(self):
        # Rebinding to [] drops the only references to the loaded models,
        # letting their (GPU) memory be reclaimed; a plain `del` would raise
        # AttributeError for any model that was never loaded.
        self.swapper_model = []
        self.GFPGAN_model = []
        self.occluder_model = []
        self.face_parsing_model = []
        self.codeformer_model = []
        self.GPEN_256_model = []
        self.GPEN_512_model = []
        self.GPEN_1024_model = []
        self.resnet_model = []
        self.detection_model = []
        self.recognition_model = []