# fixed_rope/rope/VideoManager.py
import os
import cv2
import tkinter as tk
from PIL import Image, ImageTk
import threading
import time
import numpy as np
from skimage import transform as trans
import subprocess
from math import floor, ceil
import bisect
import onnxruntime
import torchvision
from torchvision.transforms.functional import normalize  # TODO: update to v2
import torch
from torchvision import transforms
torchvision.disable_beta_transforms_warning()
from torchvision.transforms import v2
torch.set_grad_enabled(False)
onnxruntime.set_default_logger_severity(4)
import inspect  # debug tracing helper, e.g. print(inspect.currentframe().f_back.f_code.co_name)
device = 'cuda'
lock=threading.Lock()
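# Rough usage sketch (assumed driver loop; the real GUI/Coordinator wiring
# lives elsewhere in this repo and must populate .parameters and .control first):
#   vm = VideoManager(models)
#   vm.load_target_video('clip.mp4')
#   vm.play_video('play')
#   while vm.play:
#       vm.process()                      # schedules reader threads, drains queues
#       if vm.get_frame_length() > 0:
#           frame, frame_number = vm.get_frame()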
class VideoManager():
def __init__(self, models ):
self.models = models
# Model related
self.swapper_model = [] # insightface swapper model
# self.faceapp_model = [] # insight faceapp model
self.input_names = [] # names of the inswapper.onnx inputs
self.input_size = [] # size of the inswapper.onnx inputs
self.output_names = [] # names of the inswapper.onnx outputs
self.arcface_dst = np.array( [[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366], [41.5493, 92.3655], [70.7299, 92.2041]], dtype=np.float32)
self.video_file = []
self.FFHQ_kps = np.array([[ 192.98138, 239.94708 ], [ 318.90277, 240.1936 ], [ 256.63416, 314.01935 ], [ 201.26117, 371.41043 ], [ 313.08905, 371.15118 ] ])
#Video related
self.capture = [] # cv2 video
self.is_video_loaded = False # flag for video loaded state
self.video_frame_total = None # length of currently loaded video
self.play = False # flag for the play button toggle
self.current_frame = 0 # the current frame of the video
self.create_video = False
self.output_video = []
self.file_name = []
# Play related
# self.set_read_threads = [] # Name of threaded function
self.frame_timer = 0.0 # used to set the framerate during playing
# Queues
self.action_q = [] # queue for sending to the coordinator
self.frame_q = [] # queue for frames that are ready for coordinator
self.r_frame_q = [] # queue for frames that are requested by the GUI
self.read_video_frame_q = []
# swapping related
# self.source_embedding = [] # array with indexed source embeddings
self.found_faces = [] # array that maps the found faces to source faces
self.parameters = []
self.target_video = []
self.fps = 1.0
self.temp_file = []
self.clip_session = []
self.start_time = []
self.record = False
self.output = []
self.image = []
self.saved_video_path = []
self.sp = []
self.timer = []
self.fps_average = []
self.total_thread_time = 0.0
self.start_play_time = []
self.start_play_frame = []
self.rec_thread = []
self.markers = []
self.is_image_loaded = False
self.stop_marker = -1
self.perf_test = False
self.control = []
self.process_q = {
"Thread": [],
"FrameNumber": [],
"ProcessedFrame": [],
"Status": 'clear',
"ThreadTime": []
}
self.process_qs = []
self.rec_q = {
"Thread": [],
"FrameNumber": [],
"Status": 'clear'
}
self.rec_qs = []
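# Both dicts above are templates: play_video() copies process_q once per
# worker thread into self.process_qs. 'Thread' holds the worker,
# 'FrameNumber' the frame it owns ([] when idle), 'ProcessedFrame' the
# finished RGB frame, 'Status' cycles 'clear' -> 'started' -> 'finished',
# and 'ThreadTime' records the per-frame wall time.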
def assign_found_faces(self, found_faces):
self.found_faces = found_faces
def load_target_video( self, file ):
# If we already have a video loaded, release it
if self.capture:
self.capture.release()
# Open file
self.video_file = file
self.capture = cv2.VideoCapture(file)
self.fps = self.capture.get(cv2.CAP_PROP_FPS)
if not self.capture.isOpened():
print("Cannot open file: ", file)
else:
self.target_video = file
self.is_video_loaded = True
self.is_image_loaded = False
self.video_frame_total = int(self.capture.get(cv2.CAP_PROP_FRAME_COUNT))
self.play = False
self.current_frame = 0
self.frame_timer = time.time()
self.frame_q = []
self.r_frame_q = []
self.found_faces = []
self.add_action("set_slider_length",self.video_frame_total-1)
self.capture.set(cv2.CAP_PROP_POS_FRAMES, self.current_frame)
success, image = self.capture.read()
if success:
crop = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # RGB
temp = [crop, False]
self.r_frame_q.append(temp)
self.capture.set(cv2.CAP_PROP_POS_FRAMES, self.current_frame)
def load_target_image(self, file):
if self.capture:
self.capture.release()
self.is_video_loaded = False
self.play = False
self.frame_q = []
self.r_frame_q = []
self.found_faces = []
self.image = cv2.imread(file) # BGR
self.image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB) # RGB
temp = [self.image, False]
self.frame_q.append(temp)
self.is_image_loaded = True
## Action queue
def add_action(self, action, param):
# print(inspect.currentframe().f_back.f_code.co_name, '->add_action: '+action)
temp = [action, param]
self.action_q.append(temp)
def get_action_length(self):
return len(self.action_q)
def get_action(self):
action = self.action_q[0]
self.action_q.pop(0)
return action
## Queues for the Coordinator
def get_frame(self):
frame = self.frame_q[0]
self.frame_q.pop(0)
return frame
def get_frame_length(self):
return len(self.frame_q)
def get_requested_frame(self):
frame = self.r_frame_q[0]
self.r_frame_q.pop(0)
return frame
def get_requested_frame_length(self):
return len(self.r_frame_q)
def get_requested_video_frame(self, frame, marker=True):
temp = []
if self.is_video_loaded:
if self.play == True:
self.play_video("stop")
self.process_qs = []
self.current_frame = int(frame)
self.capture.set(cv2.CAP_PROP_POS_FRAMES, self.current_frame)
success, target_image = self.capture.read() #BGR
if success:
target_image = cv2.cvtColor(target_image, cv2.COLOR_BGR2RGB) #RGB
if not self.control['SwapFacesButton']:
temp = [target_image, self.current_frame] #temp = RGB
else:
temp = [self.swap_video(target_image, self.current_frame, marker), self.current_frame] # temp = RGB
self.r_frame_q.append(temp)
elif self.is_image_loaded:
if not self.control['SwapFacesButton']:
temp = [self.image, self.current_frame] # image = RGB
else:
temp = [self.swap_video(self.image, self.current_frame, False), self.current_frame] # image = RGB
self.r_frame_q.append(temp)
def find_lowest_frame(self, queues):
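# Scan the per-thread queues for the oldest in-flight frame so output stays
# in order; returns (-1, sentinel) when every queue is idle.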
min_frame=999999999
index=-1
for idx, thread in enumerate(queues):
frame = thread['FrameNumber']
if frame != []:
if frame < min_frame:
min_frame = frame
index=idx
return index, min_frame
def play_video(self, command):
# print(inspect.currentframe().f_back.f_code.co_name, '->play_video: ')
if command == "play":
# Initialization
self.play = True
self.fps_average = []
self.process_qs = []
self.capture.set(cv2.CAP_PROP_POS_FRAMES, self.current_frame)
self.frame_timer = time.time()
# Create reusable queue based on number of threads
for i in range(self.parameters['ThreadsSlider']):
new_process_q = self.process_q.copy()
self.process_qs.append(new_process_q)
# Start up audio if requested
if self.control['AudioButton']:
seek_time = (self.current_frame)/self.fps
args = ["ffplay",
'-vn',
'-ss', str(seek_time),
'-nodisp',
'-stats',
'-loglevel', 'quiet',
'-sync', 'audio',
self.video_file]
self.audio_sp = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
# Parse the console to find where the audio started
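# ffplay emits ' nan' until its audio clock starts, then stats lines whose
# first field is the current audio time (sample lines below); we read a
# fixed-size chunk, parse that timestamp, and re-seek the video to match.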
while True:
temp = self.audio_sp.stdout.read(69)
if temp[:7] != b' nan':
sought_time = float(temp[:7])
self.current_frame = int(self.fps*sought_time)
self.capture.set(cv2.CAP_PROP_POS_FRAMES, self.current_frame)
break
#' nan : 0.000
#' 1.25 M-A: 0.000 fd= 0 aq= 12KB vq= 0KB sq= 0B f=0/0'
elif command == "stop":
self.play = False
self.add_action("stop_play", True)
index, min_frame = self.find_lowest_frame(self.process_qs)
if index != -1:
self.current_frame = min_frame-1
if self.control['AudioButton']:
self.audio_sp.terminate()
torch.cuda.empty_cache()
elif command=='stop_from_gui':
self.play = False
# Find the lowest frame in the current render queue and set the current frame to the one before it
index, min_frame = self.find_lowest_frame(self.process_qs)
if index != -1:
self.current_frame = min_frame-1
if self.control['AudioButton']:
self.audio_sp.terminate()
torch.cuda.empty_cache()
elif command == "record":
self.record = True
self.play = True
self.total_thread_time = 0.0
self.process_qs = []
self.capture.set(cv2.CAP_PROP_POS_FRAMES, self.current_frame)
for i in range(self.parameters['ThreadsSlider']):
new_process_q = self.process_q.copy()
self.process_qs.append(new_process_q)
# Initialize
self.timer = time.time()
frame_width = int(self.capture.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(self.capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
self.start_time = float(self.capture.get(cv2.CAP_PROP_POS_FRAMES) / float(self.fps))
self.file_name = os.path.splitext(os.path.basename(self.target_video))
base_filename = self.file_name[0]+"_"+str(time.time())[:10]
self.output = os.path.join(self.saved_video_path, base_filename)
self.temp_file = self.output+"_temp"+self.file_name[1]
if self.parameters['RecordTypeTextSel']=='FFMPEG':
args = ["ffmpeg",
'-hide_banner',
'-loglevel', 'error',
"-an",
"-r", str(self.fps),
"-i", "pipe:",
# '-g', '25',
"-vf", "format=yuvj420p",
"-c:v", "libx264",
"-crf", str(self.parameters['VideoQualSlider']),
"-r", str(self.fps),
"-s", str(frame_width)+"x"+str(frame_height),
self.temp_file]
self.sp = subprocess.Popen(args, stdin=subprocess.PIPE)
elif self.parameters['RecordTypeTextSel']=='OPENCV':
size = (frame_width, frame_height)
self.sp = cv2.VideoWriter(self.temp_file, cv2.VideoWriter_fourcc(*'mp4v') , self.fps, size)
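# Two recording backends: 'FFMPEG' pipes BMP-encoded frames to an ffmpeg
# child process for libx264 encoding at the chosen CRF; 'OPENCV' writes mp4v
# via cv2.VideoWriter. Either way, audio is muxed back in from the source
# file when recording stops (see process()).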
# @profile
def process(self):
process_qs_len = range(len(self.process_qs))
# Add threads to Queue
if self.play == True and self.is_video_loaded == True:
for item in self.process_qs:
if item['Status'] == 'clear' and self.current_frame < self.video_frame_total:
item['Thread'] = threading.Thread(target=self.thread_video_read, args=[self.current_frame])
item['Thread'].start()
item['FrameNumber'] = self.current_frame
item['Status'] = 'started'
item['ThreadTime'] = time.time()
self.current_frame += 1
break
else:
self.play = False
# Always be emptying the queues
time_diff = time.time() - self.frame_timer
if not self.record and time_diff >= 1.0/float(self.fps) and self.play:
index, min_frame = self.find_lowest_frame(self.process_qs)
if index != -1:
if self.process_qs[index]['Status'] == 'finished':
temp = [self.process_qs[index]['ProcessedFrame'], self.process_qs[index]['FrameNumber']]
self.frame_q.append(temp)
# Report fps, other data
self.fps_average.append(1.0/time_diff)
if len(self.fps_average) >= floor(self.fps):
fps = round(np.average(self.fps_average), 2)
msg = "%s fps, %s process time" % (fps, round(self.process_qs[index]['ThreadTime'], 4))
self.fps_average = []
if self.process_qs[index]['FrameNumber'] >= self.video_frame_total-1 or self.process_qs[index]['FrameNumber'] == self.stop_marker:
self.play_video('stop')
self.process_qs[index]['Status'] = 'clear'
self.process_qs[index]['Thread'] = []
self.process_qs[index]['FrameNumber'] = []
self.process_qs[index]['ThreadTime'] = []
self.frame_timer += 1.0/self.fps
elif self.record:
index, min_frame = self.find_lowest_frame(self.process_qs)
if index != -1:
# If the swapper thread has finished generating a frame
if self.process_qs[index]['Status'] == 'finished':
image = self.process_qs[index]['ProcessedFrame']
if self.parameters['RecordTypeTextSel']=='FFMPEG':
pil_image = Image.fromarray(image)
pil_image.save(self.sp.stdin, 'BMP')
elif self.parameters['RecordTypeTextSel']=='OPENCV':
self.sp.write(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
temp = [image, self.process_qs[index]['FrameNumber']]
self.frame_q.append(temp)
# Close video and process
if self.process_qs[index]['FrameNumber'] >= self.video_frame_total-1 or self.process_qs[index]['FrameNumber'] == self.stop_marker or self.play == False:
self.play_video("stop")
stop_time = float(self.capture.get(cv2.CAP_PROP_POS_FRAMES) / float(self.fps))
if stop_time == 0:
stop_time = float(self.video_frame_total) / float(self.fps)
if self.parameters['RecordTypeTextSel']=='FFMPEG':
self.sp.stdin.close()
self.sp.wait()
elif self.parameters['RecordTypeTextSel']=='OPENCV':
self.sp.release()
orig_file = self.target_video
final_file = self.output+self.file_name[1]
print("adding audio...")
args = ["ffmpeg",
'-hide_banner',
'-loglevel', 'error',
"-i", self.temp_file,
"-ss", str(self.start_time), "-to", str(stop_time), "-i", orig_file,
"-c", "copy", # may be c:v
"-map", "0:v:0", "-map", "1:a:0?",
"-shortest",
final_file]
subprocess.run(args)
os.remove(self.temp_file)
timef= time.time() - self.timer
self.record = False
print('Video saved as:', final_file)
msg = "Total time: %s s." % (round(timef,1))
print(msg)
self.total_thread_time = []
self.process_qs[index]['Status'] = 'clear'
self.process_qs[index]['FrameNumber'] = []
self.process_qs[index]['Thread'] = []
self.frame_timer = time.time()
# @profile
def thread_video_read(self, frame_number):
with lock:
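# cv2.VideoCapture is not thread-safe, so reads are serialized under the
# module-level lock; each worker therefore grabs the next sequential frame.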
success, target_image = self.capture.read()
if success:
target_image = cv2.cvtColor(target_image, cv2.COLOR_BGR2RGB)
if not self.control['SwapFacesButton']:
temp = [target_image, frame_number]
else:
temp = [self.swap_video(target_image, frame_number, True), frame_number]
for item in self.process_qs:
if item['FrameNumber'] == frame_number:
item['ProcessedFrame'] = temp[0]
item['Status'] = 'finished'
item['ThreadTime'] = time.time() - item['ThreadTime']
break
# @profile
def swap_video(self, target_image, frame_number, use_markers):
# Grab a local copy of the parameters to prevent threading issues
parameters = self.parameters.copy()
control = self.control.copy()
# Find out if the frame is in a marker zone and copy the parameters if true
if self.markers and use_markers:
temp=[]
for i in range(len(self.markers)):
temp.append(self.markers[i]['frame'])
idx = bisect.bisect(temp, frame_number)
parameters = self.markers[idx-1]['parameters'].copy()
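# Marker lookup above: self.markers is assumed sorted by 'frame'; bisect
# picks the last marker at or before this frame, whose saved parameters then
# apply. E.g. markers at frames [0, 120, 300] with frame_number 150 selects
# the marker at 120.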
# Load frame into VRAM
img = torch.from_numpy(target_image.astype('uint8')).to('cuda') #HxWxc
img = img.permute(2,0,1)#cxHxW
#Scale up frame if it is smaller than 512
img_x = img.size()[2]
img_y = img.size()[1]
if img_x<512 and img_y<512:
# if x is smaller, set x to 512
if img_x <= img_y:
tscale = v2.Resize((int(512*img_y/img_x), 512), antialias=True)
else:
tscale = v2.Resize((512, int(512*img_x/img_y)), antialias=True)
img = tscale(img)
elif img_x<512:
tscale = v2.Resize((int(512*img_y/img_x), 512), antialias=True)
img = tscale(img)
elif img_y<512:
tscale = v2.Resize((512, int(512*img_x/img_y)), antialias=True)
img = tscale(img)
# Rotate the frame
if parameters['OrientSwitch']:
img = v2.functional.rotate(img, angle=parameters['OrientSlider'], interpolation=v2.InterpolationMode.BILINEAR, expand=True)
# Find all faces in frame and return a list of 5-pt kpss
bboxes, kpss = self.func_w_test("detect", self.models.run_detect, img, parameters['DetectTypeTextSel'], max_num=20, score=parameters['DetectScoreSlider']/100.0, use_landmark_detection=parameters['LandmarksDetectionAdjSwitch'], landmark_detect_mode=parameters["LandmarksDetectTypeTextSel"], landmark_score=parameters["LandmarksDetectScoreSlider"]/100.0, from_points=parameters["LandmarksAlignModeFromPointsSwitch"])
# Get embeddings for all faces found in the frame
ret = []
for face_kps in kpss:
face_emb, _ = self.func_w_test('recognize', self.models.run_recognize, img, face_kps)
ret.append([face_kps, face_emb])
if ret:
# Loop through target faces to see if they match our found face embeddings
for fface in ret:
for found_face in self.found_faces:
# sim between face in video and already found face
sim = self.findCosineDistance(fface[1], found_face["Embedding"])
# if the face[i] in the frame matches a found face[j] AND the found face is active (not [])
if sim>=float(parameters["ThresholdSlider"]) and found_face["SourceFaceAssignments"]:
s_e = found_face["AssignedEmbedding"]
# s_e = found_face['ptrdata']
img = self.func_w_test("swap_video", self.swap_core, img, fface[0], s_e, parameters, control)
# img = img.permute(2,0,1)
img = img.permute(1,2,0)
if not control['MaskViewButton'] and parameters['OrientSwitch']:
img = img.permute(2,0,1)
img = transforms.functional.rotate(img, angle=-parameters['OrientSlider'], expand=True)
img = img.permute(1,2,0)
else:
img = img.permute(1,2,0)
if parameters['OrientSwitch']:
img = img.permute(2,0,1)
img = v2.functional.rotate(img, angle=-parameters['OrientSlider'], interpolation=v2.InterpolationMode.BILINEAR, expand=True)
img = img.permute(1,2,0)
if self.perf_test:
print('------------------------')
# Unscale small videos
if img_x <512 or img_y < 512:
tscale = v2.Resize((img_y, img_x), antialias=True)
img = img.permute(2,0,1)
img = tscale(img)
img = img.permute(1,2,0)
img = img.cpu().numpy()
if parameters["ShowLandmarksSwitch"]:
if ret:
if img_y <= 720:
p = 1
else:
p = 2
for face in ret:
for kpoint in face[0]:
for i in range(-1, p):
for j in range(-1, p):
try:
img[int(kpoint[1])+i][int(kpoint[0])+j][0] = 0
img[int(kpoint[1])+i][int(kpoint[0])+j][1] = 255
img[int(kpoint[1])+i][int(kpoint[0])+j][2] = 255
except IndexError:
print("Key-point {} falls outside the image bounds {}.".format(kpoint, (img_x, img_y)))
continue
return img.astype(np.uint8)
def findCosineDistance(self, vector1, vector2):
vector1 = vector1.ravel()
vector2 = vector2.ravel()
cos_dist = 1.0 - np.dot(vector1, vector2)/(np.linalg.norm(vector1)*np.linalg.norm(vector2)) # 2..0
return 100.0-cos_dist*50.0
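# Score mapping above: raw cosine distance d lies in [0, 2], and 100 - 50*d
# rescales it to [0, 100] (identical embeddings -> 100, orthogonal -> 50,
# opposite -> 0); ThresholdSlider is compared against this 0-100 scale.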
'''
vector1 = vector1.ravel()
vector2 = vector2.ravel()
return 1 - np.dot(vector1, vector2)/(np.linalg.norm(vector1)*np.linalg.norm(vector2))
'''
def func_w_test(self, name, func, *args, **argsv):
timing = time.time()
result = func(*args, **argsv)
if self.perf_test:
print(name, round(time.time()-timing, 5), 's')
return result
# @profile
def swap_core(self, img, kps, s_e, parameters, control): # img = RGB
# 512 transforms
dst = self.arcface_dst * 4.0
dst[:,0] += 32.0
# Change the ref points
if parameters['FaceAdjSwitch']:
dst[:,0] += parameters['KPSXSlider']
dst[:,1] += parameters['KPSYSlider']
dst[:,0] -= 255
dst[:,0] *= (1+parameters['KPSScaleSlider']/100)
dst[:,0] += 255
dst[:,1] -= 255
dst[:,1] *= (1+parameters['KPSScaleSlider']/100)
dst[:,1] += 255
tform = trans.SimilarityTransform()
tform.estimate(kps, dst)
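# estimate(kps, dst) solves for the similarity transform (rotation, uniform
# scale, translation) mapping the detected 5-point landmarks onto the
# ArcFace template scaled to a 512x512 crop; tform.inverse later pastes the
# result back into the frame.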
# Scaling Transforms
t512 = v2.Resize((512, 512), interpolation=v2.InterpolationMode.BILINEAR, antialias=False)
t256 = v2.Resize((256, 256), interpolation=v2.InterpolationMode.BILINEAR, antialias=False)
t128 = v2.Resize((128, 128), interpolation=v2.InterpolationMode.BILINEAR, antialias=False)
# Grab 512 face from image and create 256 and 128 copies
original_face_512 = v2.functional.affine(img, tform.rotation*57.2958, (tform.translation[0], tform.translation[1]) , tform.scale, 0, center = (0,0), interpolation=v2.InterpolationMode.BILINEAR )
original_face_512 = v2.functional.crop(original_face_512, 0,0, 512, 512)# 3, 512, 512
original_face_256 = t256(original_face_512)
original_face_128 = t128(original_face_256)
latent = torch.from_numpy(self.models.calc_swapper_latent(s_e)).float().to('cuda')
dim = 1
if parameters['SwapperTypeTextSel'] == '128':
dim = 1
input_face_affined = original_face_128
elif parameters['SwapperTypeTextSel'] == '256':
dim = 2
input_face_affined = original_face_256
elif parameters['SwapperTypeTextSel'] == '512':
dim = 4
input_face_affined = original_face_512
# Optional scaling: adjust the transform matrix
if parameters['FaceAdjSwitch']:
input_face_affined = v2.functional.affine(input_face_affined, 0, (0, 0), 1 + parameters['FaceScaleSlider'] / 100, 0, center=(dim*128-1, dim*128-1), interpolation=v2.InterpolationMode.BILINEAR)
itex = 1
if parameters['StrengthSwitch']:
itex = ceil(parameters['StrengthSlider'] / 100.)
output_size = int(128 * dim)
output = torch.zeros((output_size, output_size, 3), dtype=torch.float32, device='cuda')
input_face_affined = input_face_affined.permute(1, 2, 0)
input_face_affined = torch.div(input_face_affined, 255.0)
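# For 256/512 swappers (dim = 2/4), the strided slicing [j::dim, i::dim]
# below de-interleaves the face into dim*dim pixel-shifted 128x128 sub-grids;
# each is run through the 128 swapper and re-interleaved at full resolution.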
for k in range(itex):
for j in range(dim):
for i in range(dim):
input_face_disc = input_face_affined[j::dim,i::dim]
input_face_disc = input_face_disc.permute(2, 0, 1)
input_face_disc = torch.unsqueeze(input_face_disc, 0).contiguous()
swapper_output = torch.empty((1,3,128,128), dtype=torch.float32, device='cuda').contiguous()
self.models.run_swapper(input_face_disc, latent, swapper_output)
swapper_output = torch.squeeze(swapper_output)
swapper_output = swapper_output.permute(1, 2, 0)
output[j::dim, i::dim] = swapper_output.clone()
prev_face = input_face_affined.clone()
input_face_affined = output.clone()
output = torch.mul(output, 255)
output = torch.clamp(output, 0, 255)
output = output.permute(2, 0, 1)
swap = t512(output)
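# With StrengthSwitch on, the swapper has run itex = ceil(Strength/100) times,
# each pass feeding the previous output back in; below, the final pass is
# blended with the previous one using alpha = (Strength mod 100)/100, treated
# as 1 when Strength is an exact multiple of 100.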
if parameters['StrengthSwitch']:
if itex == 0:
swap = original_face_512.clone()
else:
alpha = np.mod(parameters['StrengthSlider'], 100)*0.01
if alpha==0:
alpha=1
# Blend the images
prev_face = torch.mul(prev_face, 255)
prev_face = torch.clamp(prev_face, 0, 255)
prev_face = prev_face.permute(2, 0, 1)
prev_face = t512(prev_face)
swap = torch.mul(swap, alpha)
prev_face = torch.mul(prev_face, 1-alpha)
swap = torch.add(swap, prev_face)
# swap = torch.squeeze(swap)
# swap = torch.mul(swap, 255)
# swap = torch.clamp(swap, 0, 255)
# # swap_128 = swap
# swap = t256(swap)
# swap = t512(swap)
# Apply color corrections
if parameters['ColorSwitch']:
# print(parameters['ColorGammaSlider'])
swap = torch.unsqueeze(swap,0)
swap = v2.functional.adjust_gamma(swap, parameters['ColorGammaSlider'], 1.0)
swap = torch.squeeze(swap)
swap = swap.permute(1, 2, 0).type(torch.float32)
del_color = torch.tensor([parameters['ColorRedSlider'], parameters['ColorGreenSlider'], parameters['ColorBlueSlider']], device=device)
swap += del_color
swap = torch.clamp(swap, min=0., max=255.)
swap = swap.permute(2, 0, 1).type(torch.uint8)
# Create border mask
border_mask = torch.ones((128, 128), dtype=torch.float32, device=device)
border_mask = torch.unsqueeze(border_mask,0)
# if parameters['BorderState']:
top = parameters['BorderTopSlider']
left = parameters['BorderSidesSlider']
right = 128-parameters['BorderSidesSlider']
bottom = 128-parameters['BorderBottomSlider']
border_mask[:, :top, :] = 0
border_mask[:, bottom:, :] = 0
border_mask[:, :, :left] = 0
border_mask[:, :, right:] = 0
gauss = transforms.GaussianBlur(parameters['BorderBlurSlider']*2+1, (parameters['BorderBlurSlider']+1)*0.2)
border_mask = gauss(border_mask)
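# The border mask zeroes a band of pixels around the 128x128 crop edge (per
# the Border sliders) and is Gaussian-blurred, feathering the seam where the
# swapped face meets the surrounding frame.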
# Create image mask
swap_mask = torch.ones((128, 128), dtype=torch.float32, device=device)
swap_mask = torch.unsqueeze(swap_mask,0)
# Face Diffing
if parameters["DiffSwitch"]:
mask = self.apply_fake_diff(swap, original_face_512, parameters["DiffSlider"])
# mask = t128(mask)
gauss = transforms.GaussianBlur(parameters['BlendSlider']*2+1, (parameters['BlendSlider']+1)*0.2)
mask = gauss(mask.type(torch.float32))
swap = swap*mask + original_face_512*(1-mask)
# Restorer
if parameters["RestorerSwitch"]:
swap = self.func_w_test('Restorer', self.apply_restorer, swap, parameters)
# Occluder
if parameters["OccluderSwitch"]:
mask = self.func_w_test('occluder', self.apply_occlusion , original_face_256, parameters["OccluderSlider"])
mask = t128(mask)
swap_mask = torch.mul(swap_mask, mask)
if parameters["FaceParserSwitch"]:
mask = self.apply_face_parser(swap, parameters["FaceParserSlider"], parameters['MouthParserSlider'])
mask = t128(mask)
swap_mask = torch.mul(swap_mask, mask)
# CLIPs
if parameters["CLIPSwitch"]:
with lock:
mask = self.func_w_test('CLIP', self.apply_CLIPs, original_face_512, parameters["CLIPTextEntry"], parameters["CLIPSlider"])
mask = cv2.resize(mask, (128,128))
mask = torch.from_numpy(mask).to('cuda')
swap_mask *= mask
# Add blur to swap_mask results
gauss = transforms.GaussianBlur(parameters['BlendSlider']*2+1, (parameters['BlendSlider']+1)*0.2)
swap_mask = gauss(swap_mask)
# Combine border and swap mask, scale, and apply to swap
swap_mask = torch.mul(swap_mask, border_mask)
swap_mask = t512(swap_mask)
swap = torch.mul(swap, swap_mask)
if not control['MaskViewButton']:
# Calculate the area to be merged back into the original frame
IM512 = tform.inverse.params[0:2, :]
corners = np.array([[0,0], [0,511], [511, 0], [511, 511]])
x = (IM512[0][0]*corners[:,0] + IM512[0][1]*corners[:,1] + IM512[0][2])
y = (IM512[1][0]*corners[:,0] + IM512[1][1]*corners[:,1] + IM512[1][2])
left = floor(np.min(x))
if left<0:
left=0
top = floor(np.min(y))
if top<0:
top=0
right = ceil(np.max(x))
if right>img.shape[2]:
right=img.shape[2]
bottom = ceil(np.max(y))
if bottom>img.shape[1]:
bottom=img.shape[1]
# Untransform the swap
swap = v2.functional.pad(swap, (0,0,img.shape[2]-512, img.shape[1]-512))
swap = v2.functional.affine(swap, tform.inverse.rotation*57.2958, (tform.inverse.translation[0], tform.inverse.translation[1]), tform.inverse.scale, 0,interpolation=v2.InterpolationMode.BILINEAR, center = (0,0) )
swap = swap[0:3, top:bottom, left:right]
swap = swap.permute(1, 2, 0)
# Untransform the swap mask
swap_mask = v2.functional.pad(swap_mask, (0,0,img.shape[2]-512, img.shape[1]-512))
swap_mask = v2.functional.affine(swap_mask, tform.inverse.rotation*57.2958, (tform.inverse.translation[0], tform.inverse.translation[1]), tform.inverse.scale, 0, interpolation=v2.InterpolationMode.BILINEAR, center = (0,0) )
swap_mask = swap_mask[0:1, top:bottom, left:right]
swap_mask = swap_mask.permute(1, 2, 0)
swap_mask = torch.sub(1, swap_mask)
# Apply the mask to the original image areas
img_crop = img[0:3, top:bottom, left:right]
img_crop = img_crop.permute(1,2,0)
img_crop = torch.mul(swap_mask,img_crop)
#Add the cropped areas and place them back into the original image
swap = torch.add(swap, img_crop)
swap = swap.type(torch.uint8)
swap = swap.permute(2,0,1)
img[0:3, top:bottom, left:right] = swap
else:
# Invert swap mask
swap_mask = torch.sub(1, swap_mask)
# Combine preswapped face with swap
original_face_512 = torch.mul(swap_mask, original_face_512)
original_face_512 = torch.add(swap, original_face_512)
original_face_512 = original_face_512.type(torch.uint8)
original_face_512 = original_face_512.permute(1, 2, 0)
# Uninvert and create image from swap mask
swap_mask = torch.sub(1, swap_mask)
swap_mask = torch.cat((swap_mask,swap_mask,swap_mask),0)
swap_mask = swap_mask.permute(1, 2, 0)
# Place them side by side
img = torch.hstack([original_face_512, swap_mask*255])
img = img.permute(2,0,1)
return img
# @profile
def apply_occlusion(self, img, amount):
img = torch.div(img, 255)
img = torch.unsqueeze(img, 0)
outpred = torch.ones((256,256), dtype=torch.float32, device=device).contiguous()
self.models.run_occluder(img, outpred)
outpred = torch.squeeze(outpred)
outpred = (outpred > 0)
outpred = torch.unsqueeze(outpred, 0).type(torch.float32)
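# Morphology via convolution: each pass of a 3x3 ones kernel followed by
# clamp(0, 1) dilates the binary mask by one pixel; negating before and after
# (the amount < 0 branch) turns the same operation into erosion.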
if amount >0:
kernel = torch.ones((1,1,3,3), dtype=torch.float32, device=device)
for i in range(int(amount)):
outpred = torch.nn.functional.conv2d(outpred, kernel, padding=(1, 1))
outpred = torch.clamp(outpred, 0, 1)
outpred = torch.squeeze(outpred)
if amount <0:
outpred = torch.neg(outpred)
outpred = torch.add(outpred, 1)
kernel = torch.ones((1,1,3,3), dtype=torch.float32, device=device)
for i in range(int(-amount)):
outpred = torch.nn.functional.conv2d(outpred, kernel, padding=(1, 1))
outpred = torch.clamp(outpred, 0, 1)
outpred = torch.squeeze(outpred)
outpred = torch.neg(outpred)
outpred = torch.add(outpred, 1)
outpred = torch.reshape(outpred, (1, 256, 256))
return outpred
def apply_CLIPs(self, img, CLIPText, CLIPAmount):
clip_mask = np.ones((352, 352))
img = img.permute(1,2,0)
img = img.cpu().numpy()
# img = img.to(torch.float)
# img = img.permute(1,2,0)
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
transforms.Resize((352, 352))])
CLIPimg = transform(img).unsqueeze(0)
if CLIPText != "":
prompts = CLIPText.split(',')
with torch.no_grad():
preds = self.clip_session(CLIPimg.repeat(len(prompts),1,1,1), prompts)[0]
# preds = self.clip_session(CLIPimg, maskimg, True)[0]
clip_mask = 1 - torch.sigmoid(preds[0][0])
for i in range(len(prompts)-1):
clip_mask *= 1-torch.sigmoid(preds[i+1][0])
clip_mask = clip_mask.data.cpu().numpy()
thresh = CLIPAmount/100.0
clip_mask[clip_mask>thresh] = 1.0
clip_mask[clip_mask<=thresh] = 0.0
return clip_mask
# @profile
def apply_face_parser(self, img, FaceAmount, MouthAmount):
# atts = [1 'skin', 2 'l_brow', 3 'r_brow', 4 'l_eye', 5 'r_eye', 6 'eye_g', 7 'l_ear', 8 'r_ear', 9 'ear_r', 10 'nose', 11 'mouth', 12 'u_lip', 13 'l_lip', 14 'neck', 15 'neck_l', 16 'cloth', 17 'hair', 18 'hat']
outpred = torch.ones((512,512), dtype=torch.float32, device='cuda').contiguous()
img = torch.div(img, 255)
img = v2.functional.normalize(img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
img = torch.reshape(img, (1, 3, 512, 512))
outpred = torch.empty((1,19,512,512), dtype=torch.float32, device='cuda').contiguous()
self.models.run_faceparser(img, outpred)
outpred = torch.squeeze(outpred)
outpred = torch.argmax(outpred, 0)
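# outpred is now a 512x512 map of per-pixel class labels (argmax over the 19
# channels listed in the atts comment above); the masks below are built with
# torch.isin against selected label sets, then grown or shrunk by conv passes.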
# Mouth Parse
if MouthAmount <0:
mouth_idxs = torch.tensor([11], device='cuda')
iters = int(-MouthAmount)
mouth_parse = torch.isin(outpred, mouth_idxs)
mouth_parse = torch.clamp(~mouth_parse, 0, 1).type(torch.float32)
mouth_parse = torch.reshape(mouth_parse, (1, 1, 512, 512))
mouth_parse = torch.neg(mouth_parse)
mouth_parse = torch.add(mouth_parse, 1)
kernel = torch.ones((1, 1, 3, 3), dtype=torch.float32,
device='cuda')
for i in range(iters):
mouth_parse = torch.nn.functional.conv2d(mouth_parse, kernel,
padding=(1, 1))
mouth_parse = torch.clamp(mouth_parse, 0, 1)
mouth_parse = torch.squeeze(mouth_parse)
mouth_parse = torch.neg(mouth_parse)
mouth_parse = torch.add(mouth_parse, 1)
mouth_parse = torch.reshape(mouth_parse, (1, 512, 512))
elif MouthAmount >0:
mouth_idxs = torch.tensor([11,12,13], device='cuda')
iters = int(MouthAmount)
mouth_parse = torch.isin(outpred, mouth_idxs)
mouth_parse = torch.clamp(~mouth_parse, 0, 1).type(torch.float32)
mouth_parse = torch.reshape(mouth_parse, (1,1,512,512))
mouth_parse = torch.neg(mouth_parse)
mouth_parse = torch.add(mouth_parse, 1)
kernel = torch.ones((1,1,3,3), dtype=torch.float32, device='cuda')
for i in range(iters):
mouth_parse = torch.nn.functional.conv2d(mouth_parse, kernel, padding=(1, 1))
mouth_parse = torch.clamp(mouth_parse, 0, 1)
mouth_parse = torch.squeeze(mouth_parse)
mouth_parse = torch.neg(mouth_parse)
mouth_parse = torch.add(mouth_parse, 1)
mouth_parse = torch.reshape(mouth_parse, (1, 512, 512))
else:
mouth_parse = torch.ones((1, 512, 512), dtype=torch.float32, device='cuda')
# BG Parse
bg_idxs = torch.tensor([0, 14, 15, 16, 17, 18], device=device)
bg_parse = torch.isin(outpred, bg_idxs)
bg_parse = torch.clamp(~bg_parse, 0, 1).type(torch.float32)
bg_parse = torch.reshape(bg_parse, (1, 1, 512, 512))
if FaceAmount > 0:
kernel = torch.ones((1, 1, 3, 3), dtype=torch.float32, device=device)
for i in range(int(FaceAmount)):
bg_parse = torch.nn.functional.conv2d(bg_parse, kernel, padding=(1, 1))
bg_parse = torch.clamp(bg_parse, 0, 1)
bg_parse = torch.squeeze(bg_parse)
elif FaceAmount < 0:
bg_parse = torch.neg(bg_parse)
bg_parse = torch.add(bg_parse, 1)
kernel = torch.ones((1, 1, 3, 3), dtype=torch.float32, device=device)
for i in range(int(-FaceAmount)):
bg_parse = torch.nn.functional.conv2d(bg_parse, kernel, padding=(1, 1))
bg_parse = torch.clamp(bg_parse, 0, 1)
bg_parse = torch.squeeze(bg_parse)
bg_parse = torch.neg(bg_parse)
bg_parse = torch.add(bg_parse, 1)
bg_parse = torch.reshape(bg_parse, (1, 512, 512))
else:
bg_parse = torch.ones((1,512,512), dtype=torch.float32, device='cuda')
out_parse = torch.mul(bg_parse, mouth_parse)
return out_parse
def apply_bg_face_parser(self, img, FaceParserAmount):
# atts = [1 'skin', 2 'l_brow', 3 'r_brow', 4 'l_eye', 5 'r_eye', 6 'eye_g', 7 'l_ear', 8 'r_ear', 9 'ear_r', 10 'nose', 11 'mouth', 12 'u_lip', 13 'l_lip', 14 'neck', 15 'neck_l', 16 'cloth', 17 'hair', 18 'hat']
# out = np.ones((512, 512), dtype=np.float32)
outpred = torch.ones((512,512), dtype=torch.float32, device='cuda').contiguous()
# skip the background parse at 0 so the mouth parser can be used on its own
if FaceParserAmount != 0:
img = torch.div(img, 255)
img = v2.functional.normalize(img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
img = torch.reshape(img, (1, 3, 512, 512))
outpred = torch.empty((1,19,512,512), dtype=torch.float32, device=device).contiguous()
self.models.run_faceparser(img, outpred)
outpred = torch.squeeze(outpred)
outpred = torch.argmax(outpred, 0)
test = torch.tensor([ 0, 14, 15, 16, 17, 18], device=device)
outpred = torch.isin(outpred, test)
outpred = torch.clamp(~outpred, 0, 1).type(torch.float32)
outpred = torch.reshape(outpred, (1,1,512,512))
if FaceParserAmount >0:
kernel = torch.ones((1,1,3,3), dtype=torch.float32, device=device)
for i in range(int(FaceParserAmount)):
outpred = torch.nn.functional.conv2d(outpred, kernel, padding=(1, 1))
outpred = torch.clamp(outpred, 0, 1)
outpred = torch.squeeze(outpred)
if FaceParserAmount <0:
outpred = torch.neg(outpred)
outpred = torch.add(outpred, 1)
kernel = torch.ones((1,1,3,3), dtype=torch.float32, device=device)
for i in range(int(-FaceParserAmount)):
outpred = torch.nn.functional.conv2d(outpred, kernel, padding=(1, 1))
outpred = torch.clamp(outpred, 0, 1)
outpred = torch.squeeze(outpred)
outpred = torch.neg(outpred)
outpred = torch.add(outpred, 1)
outpred = torch.reshape(outpred, (1, 512, 512))
return outpred
def apply_restorer(self, swapped_face_upscaled, parameters):
temp = swapped_face_upscaled
t512 = v2.Resize((512, 512), antialias=False)
t256 = v2.Resize((256, 256), antialias=False)
t1024 = v2.Resize((1024, 1024), antialias=False)
# If using a separate detection mode
if parameters['RestorerDetTypeTextSel'] == 'Blend' or parameters['RestorerDetTypeTextSel'] == 'Reference':
if parameters['RestorerDetTypeTextSel'] == 'Blend':
# Set up Transformation
dst = self.arcface_dst * 4.0
dst[:,0] += 32.0
elif parameters['RestorerDetTypeTextSel'] == 'Reference':
try:
dst = self.models.resnet50(swapped_face_upscaled, score=parameters['DetectScoreSlider']/100.0)
except:
return swapped_face_upscaled
tform = trans.SimilarityTransform()
tform.estimate(dst, self.FFHQ_kps)
# Transform, scale, and normalize
temp = v2.functional.affine(swapped_face_upscaled, tform.rotation*57.2958, (tform.translation[0], tform.translation[1]) , tform.scale, 0, center = (0,0) )
temp = v2.functional.crop(temp, 0,0, 512, 512)
temp = torch.div(temp, 255)
temp = v2.functional.normalize(temp, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=False)
if parameters['RestorerTypeTextSel'] == 'GPEN256':
temp = t256(temp)
temp = torch.unsqueeze(temp, 0).contiguous()
# Bindings
outpred = torch.empty((1,3,512,512), dtype=torch.float32, device=device).contiguous()
if parameters['RestorerTypeTextSel'] == 'GFPGAN':
self.models.run_GFPGAN(temp, outpred)
elif parameters['RestorerTypeTextSel'] == 'CF':
self.models.run_codeformer(temp, outpred)
elif parameters['RestorerTypeTextSel'] == 'GPEN256':
outpred = torch.empty((1,3,256,256), dtype=torch.float32, device=device).contiguous()
self.models.run_GPEN_256(temp, outpred)
elif parameters['RestorerTypeTextSel'] == 'GPEN512':
self.models.run_GPEN_512(temp, outpred)
elif parameters['RestorerTypeTextSel'] == 'GPEN1024':
temp = t1024(temp)
outpred = torch.empty((1, 3, 1024, 1024), dtype=torch.float32, device=device).contiguous()
self.models.run_GPEN_1024(temp, outpred)
# Format back to cxHxW @ 255
outpred = torch.squeeze(outpred)
outpred = torch.clamp(outpred, -1, 1)
outpred = torch.add(outpred, 1)
outpred = torch.div(outpred, 2)
outpred = torch.mul(outpred, 255)
if parameters['RestorerTypeTextSel'] == 'GPEN256':
outpred = t512(outpred)
elif parameters['RestorerTypeTextSel'] == 'GPEN1024':
outpred = t512(outpred)
# Invert Transform
if parameters['RestorerDetTypeTextSel'] == 'Blend' or parameters['RestorerDetTypeTextSel'] == 'Reference':
outpred = v2.functional.affine(outpred, tform.inverse.rotation*57.2958, (tform.inverse.translation[0], tform.inverse.translation[1]), tform.inverse.scale, 0, interpolation=v2.InterpolationMode.BILINEAR, center = (0,0) )
# Blend
alpha = float(parameters["RestorerSlider"])/100.0
outpred = torch.add(torch.mul(outpred, alpha), torch.mul(swapped_face_upscaled, 1-alpha))
return outpred
def apply_fake_diff(self, swapped_face, original_face, DiffAmount):
swapped_face = swapped_face.permute(1,2,0)
original_face = original_face.permute(1,2,0)
# Find the difference between the swap and original, per channel
diff = swapped_face-original_face
diff = torch.abs(diff)
fthresh = DiffAmount*2.55
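# DiffAmount comes from a 0-100 slider; scaling by 2.55 converts it to a
# 0-255 per-channel intensity threshold.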
# Bimodal
diff[diff<fthresh] = 0
diff[diff>=fthresh] = 1
# If any channel exceeded the threshold, add that pixel to the mask
diff = torch.sum(diff, dim=2)
diff = torch.unsqueeze(diff, 2)
diff[diff>0] = 1
diff = diff.permute(2,0,1)
return diff
def clear_mem(self):
del self.swapper_model
del self.GFPGAN_model
del self.occluder_model
del self.face_parsing_model
del self.codeformer_model
del self.GPEN_256_model
del self.GPEN_512_model
del self.GPEN_1024_model
del self.resnet_model
del self.detection_model
del self.recognition_model
self.swapper_model = []
self.GFPGAN_model = []
self.occluder_model = []
self.face_parsing_model = []
self.codeformer_model = []
self.GPEN_256_model = []
self.GPEN_512_model = []
self.GPEN_1024_model = []
self.resnet_model = []
self.detection_model = []
self.recognition_model = []