phehvn commited on
Commit
21e7591
·
verified ·
1 Parent(s): 7748096

Upload 37 files

Browse files
installer/installer.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import glob
3
+ import os
4
+ import shutil
5
+ import site
6
+ import subprocess
7
+ import sys
8
+
9
+
10
# Remember the directory the installer was launched from so we can
# return to it before starting the app.
script_dir = os.getcwd()
11
+
12
+
13
def run_cmd(cmd, capture_output=False, env=None):
    """Execute *cmd* through the system shell and return the CompletedProcess."""
    completed = subprocess.run(cmd, shell=True, capture_output=capture_output, env=env)
    return completed
16
+
17
+
18
def check_env():
    """Verify conda is available and a non-base environment is active.

    Exits the process with a non-zero status when either check fails
    (the original used sys.exit() which reported success, status 0).
    """
    # If we have access to conda, we are probably in an environment
    conda_not_exist = run_cmd("conda", capture_output=True).returncode
    if conda_not_exist:
        print("Conda is not installed. Exiting...")
        sys.exit(1)

    # Ensure this is a new environment and not the base environment.
    # .get avoids a KeyError when CONDA_DEFAULT_ENV is not set at all.
    if os.environ.get("CONDA_DEFAULT_ENV", "base") == "base":
        print("Create an environment for this project and activate it. Exiting...")
        sys.exit(1)
29
+
30
+
31
def install_dependencies():
    """Clone the roop-unleashed repo at a pinned commit and install its requirements.

    Only reads the module-level MY_PATH, so no `global` declaration is needed.
    """
    # Install Git and clone repo
    run_cmd("conda install -y -k git")
    run_cmd("git clone https://github.com/C0untFloyd/roop-unleashed.git")
    os.chdir(MY_PATH)
    # Pin to a known-good revision so the install is reproducible
    run_cmd("git checkout 64e227539d70dc9b83953bd230fbd4d26d2759c7")
    # Installs dependencies from requirements.txt
    run_cmd("python -m pip install -r requirements.txt")
41
+
42
+
43
+
44
def update_dependencies():
    """Hard-reset the checkout to origin/main and reinstall requirements.

    Only reads the module-level MY_PATH, so no `global` declaration is needed.
    """
    os.chdir(MY_PATH)
    # do a hard reset to update even if there are local changes
    run_cmd("git fetch --all")
    run_cmd("git reset --hard origin/main")
    run_cmd("git pull")
    # Installs/Updates dependencies from all requirements.txt
    run_cmd("python -m pip install -r requirements.txt")
54
+
55
+
56
def start_app():
    """Launch run.py inside the repo, forwarding all command-line arguments.

    Uses a slice of sys.argv instead of the original destructive
    sys.argv.pop(0), which mutated argv for the rest of the process.
    """
    os.chdir(MY_PATH)
    # forward commandline arguments
    # NOTE(review): arguments containing spaces are not quoted and will be
    # re-split by the shell - confirm callers never pass such arguments.
    args = ' '.join(sys.argv[1:])
    print("Launching App")
    run_cmd(f'python run.py {args}')
65
+
66
+
67
+ if __name__ == "__main__":
68
+ global MY_PATH
69
+
70
+ MY_PATH = "roop-unleashed"
71
+
72
+
73
+ # Verifies we are in a conda environment
74
+ check_env()
75
+
76
+ # If webui has already been installed, skip and run
77
+ if not os.path.exists(MY_PATH):
78
+ install_dependencies()
79
+ else:
80
+ # moved update from batch to here, because of batch limitations
81
+ updatechoice = input("Check for Updates? [y/n]").lower()
82
+ if updatechoice == "y":
83
+ update_dependencies()
84
+
85
+ # Run the model with webui
86
+ os.chdir(script_dir)
87
+ start_app()
installer/windows_run.bat ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
@echo off
@rem Bootstrap for roop-unleashed on Windows:
@rem   1. installs a private Miniconda under .\installer_files (if missing)
@rem   2. creates a Python 3.10 conda env there (if missing)
@rem   3. downloads a static FFmpeg build when ffmpeg is not on PATH
@rem   4. activates the env and hands control to installer.py
REM No CLI arguments supported anymore
set COMMANDLINE_ARGS=

cd /D "%~dp0"

echo "%CD%"| findstr /C:" " >nul && echo This script relies on Miniconda which can not be silently installed under a path with spaces. && goto end

set PATH=%PATH%;%SystemRoot%\system32

@rem config
set INSTALL_DIR=%cd%\installer_files
set CONDA_ROOT_PREFIX=%cd%\installer_files\conda
set INSTALL_ENV_DIR=%cd%\installer_files\env
set MINICONDA_DOWNLOAD_URL=https://repo.anaconda.com/miniconda/Miniconda3-latest-Windows-x86_64.exe
set FFMPEG_DOWNLOAD_URL=https://github.com/GyanD/codexffmpeg/releases/download/2023-06-21-git-1bcb8a7338/ffmpeg-2023-06-21-git-1bcb8a7338-essentials_build.zip
set INSTALL_FFMPEG_DIR=%cd%\installer_files\ffmpeg
set conda_exists=F
set ffmpeg_exists=F


@rem figure out whether git and conda needs to be installed
call "%CONDA_ROOT_PREFIX%\_conda.exe" --version >nul 2>&1
if "%ERRORLEVEL%" EQU "0" set conda_exists=T

@rem Check if FFmpeg is already in PATH
where ffmpeg >nul 2>&1
if "%ERRORLEVEL%" EQU "0" (
	echo FFmpeg is already installed.
	set ffmpeg_exists=T
)

@rem (if necessary) install git and conda into a contained environment
@rem download conda
if "%conda_exists%" == "F" (
	echo Downloading Miniconda from %MINICONDA_DOWNLOAD_URL% to %INSTALL_DIR%\miniconda_installer.exe

	mkdir "%INSTALL_DIR%"
	call curl -Lk "%MINICONDA_DOWNLOAD_URL%" > "%INSTALL_DIR%\miniconda_installer.exe" || ( echo. && echo Miniconda failed to download. && goto end )

	echo Installing Miniconda to %CONDA_ROOT_PREFIX%
	start /wait "" "%INSTALL_DIR%\miniconda_installer.exe" /InstallationType=JustMe /NoShortcuts=1 /AddToPath=0 /RegisterPython=0 /NoRegistry=1 /S /D=%CONDA_ROOT_PREFIX%

	@rem test the conda binary
	echo Miniconda version:
	call "%CONDA_ROOT_PREFIX%\_conda.exe" --version || ( echo. && echo Miniconda not found. && goto end )
)

@rem create the installer env
if not exist "%INSTALL_ENV_DIR%" (
	echo Creating Conda Environment
	call "%CONDA_ROOT_PREFIX%\_conda.exe" create --no-shortcuts -y -k --prefix "%INSTALL_ENV_DIR%" python=3.10 || ( echo. && echo Conda environment creation failed. && goto end )
)

@rem Download and install FFmpeg if not already installed
@rem NOTE(review): the setx below prepends the ffmpeg dir to the *user* PATH;
@rem setx truncates values longer than 1024 characters - verify on systems
@rem with a long PATH.
if "%ffmpeg_exists%" == "F" (
	if not exist "%INSTALL_FFMPEG_DIR%" (
		echo Downloading ffmpeg from %FFMPEG_DOWNLOAD_URL% to %INSTALL_DIR%
		call curl -Lk "%FFMPEG_DOWNLOAD_URL%" > "%INSTALL_DIR%\ffmpeg.zip" || ( echo. && echo ffmpeg failed to download. && goto end )
		call powershell -command "Expand-Archive -Force '%INSTALL_DIR%\ffmpeg.zip' '%INSTALL_DIR%\'"

		cd "installer_files"
		setlocal EnableExtensions EnableDelayedExpansion

		for /f "tokens=*" %%f in ('dir /s /b /ad "ffmpeg*"') do (
			ren "%%f" "ffmpeg"
		)
		endlocal
		setx PATH "%INSTALL_FFMPEG_DIR%\bin\;%PATH%"
		echo To use videos, you need to restart roop after this installation.
		cd ..
	)
) else (
	echo Skipping FFmpeg installation as it is already available.
)


@rem check if conda environment was actually created
if not exist "%INSTALL_ENV_DIR%\python.exe" ( echo. && echo ERROR: Conda environment is empty. && goto end )

@rem activate installer env
call "%CONDA_ROOT_PREFIX%\condabin\conda.bat" activate "%INSTALL_ENV_DIR%" || ( echo. && echo Miniconda hook not found. && goto end )

@rem setup installer env
echo Launching roop unleashed
call python installer.py %COMMANDLINE_ARGS%

echo.
echo Done!

:end
pause
93
+
94
+
95
+
roop/FaceSet.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
class FaceSet:
    """Groups one or more detected faces that are treated as a single identity."""

    def __init__(self):
        # All state lives on the instance; the original declared these as
        # class-level mutable defaults, which would be shared by every FaceSet.
        self.faces = []
        self.ref_images = []
        self.embedding_average = 'None'
        self.embeddings_backup = None

    def AverageEmbeddings(self):
        """Replace faces[0]'s embedding with the mean over all faces' embeddings.

        The original embedding of faces[0] is stashed in ``embeddings_backup``
        (only once) so it can be restored later. No-op when the set holds
        fewer than two faces or a backup already exists.
        """
        if len(self.faces) > 1 and self.embeddings_backup is None:
            self.embeddings_backup = self.faces[0]['embedding']
            embeddings = [face.embedding for face in self.faces]

            self.faces[0]['embedding'] = np.mean(embeddings, axis=0)
            # try median too?
roop/ProcessEntry.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
class ProcessEntry:
    """Describes one media file queued for processing."""

    def __init__(self, filename: str, start: int, end: int, fps: float):
        # Source path and the (initially unknown) output path.
        self.filename = filename
        self.finalname = None
        # Frame range to process and the clip's frame rate.
        self.startframe = start
        self.endframe = end
        self.fps = fps
roop/ProcessMgr.py ADDED
@@ -0,0 +1,644 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import psutil
5
+
6
+ from roop.ProcessOptions import ProcessOptions
7
+
8
+ from roop.face_util import get_first_face, get_all_faces, rotate_image_180, rotate_anticlockwise, rotate_clockwise, clamp_cut_values
9
+ from roop.utilities import compute_cosine_distance, get_device, str_to_class
10
+ import roop.vr_util as vr
11
+
12
+ from typing import Any, List, Callable
13
+ from roop.typing import Frame, Face
14
+ from concurrent.futures import ThreadPoolExecutor, as_completed
15
+ from threading import Thread, Lock
16
+ from queue import Queue
17
+ from tqdm import tqdm
18
+ from roop.ffmpeg_writer import FFMPEG_VideoWriter
19
+ import roop.globals
20
+
21
+
22
def create_queue(temp_frame_paths: List[str]) -> Queue[str]:
    """Return a FIFO queue pre-filled with *temp_frame_paths* in order."""
    frame_queue: Queue[str] = Queue()
    for frame_path in temp_frame_paths:
        frame_queue.put(frame_path)
    return frame_queue
27
+
28
+
29
def pick_queue(queue: Queue[str], queue_per_future: int) -> List[str]:
    """Dequeue up to *queue_per_future* items, stopping early if the queue drains."""
    picked: List[str] = []
    while len(picked) < queue_per_future and not queue.empty():
        picked.append(queue.get())
    return picked
34
+ return queues
35
+
36
+
37
class ProcessMgr():
    """Coordinates face detection, swapping and post-processing across threads."""

    # Source and target face descriptors selected for this run.
    input_face_datas = []
    target_face_datas = []

    # Optional user-drawn mask image; prepared in initialize().
    imagemask = None

    # Active processor plugin instances, in pipeline order.
    # NOTE(review): these are class-level mutable attributes, shared by all
    # ProcessMgr instances - confirm only one instance ever exists.
    processors = []
    options : ProcessOptions = None

    num_threads = 1
    current_index = 0
    processing_threads = 1
    buffer_wait_time = 0.1

    lock = Lock()

    # Per-worker in/out queues used by the in-memory video pipeline.
    frames_queue = None
    processed_queue = None

    videowriter= None

    # Gradio progress callback and total frame count for progress reporting.
    progress_gradio = None
    total_frames = 0




    # Maps UI processor names to the class implementing them under
    # roop.processors.
    plugins = {
    'faceswap' : 'FaceSwapInsightFace',
    'mask_clip2seg' : 'Mask_Clip2Seg',
    'codeformer' : 'Enhance_CodeFormer',
    'gfpgan' : 'Enhance_GFPGAN',
    'dmdnet' : 'Enhance_DMDNet',
    'gpen' : 'Enhance_GPEN',
    'restoreformer++' : 'Enhance_RestoreFormerPPlus',
    }
73
+
74
+ def __init__(self, progress):
75
+ if progress is not None:
76
+ self.progress_gradio = progress
77
+
78
+
79
    def initialize(self, input_faces, target_faces, options):
        """Configure the manager for a run.

        Stores the chosen source/target faces and options, (re)builds the
        processor plugin chain to match options.processors, and prepares the
        optional user-drawn image mask.
        """
        self.input_face_datas = input_faces
        self.target_face_datas = target_faces
        self.options = options

        # Restrict face analysis to what is needed; gender/age detection is
        # only required for the gender-filtered swap modes.
        roop.globals.g_desired_face_analysis=["landmark_3d_68", "landmark_2d_106","detection","recognition"]
        if options.swap_mode == "all_female" or options.swap_mode == "all_male":
            roop.globals.g_desired_face_analysis.append("genderage")

        processornames = options.processors.split(",")
        devicename = get_device()
        if len(self.processors) < 1:
            # First run: instantiate every requested plugin, in order.
            for pn in processornames:
                classname = self.plugins[pn]
                module = 'roop.processors.' + classname
                p = str_to_class(module, classname)
                if p is not None:
                    p.Initialize(devicename)
                    self.processors.append(p)
                else:
                    print(f"Not using {module}")
        else:
            # Subsequent runs: release plugins that are no longer wanted ...
            for i in range(len(self.processors) -1, -1, -1):
                if not self.processors[i].processorname in processornames:
                    self.processors[i].Release()
                    del self.processors[i]

            # ... then insert newly requested ones at their pipeline position.
            for i,pn in enumerate(processornames):
                if i >= len(self.processors) or self.processors[i].processorname != pn:
                    classname = self.plugins[pn]
                    module = 'roop.processors.' + classname
                    p = str_to_class(module, classname)
                    if p is not None:
                        p.Initialize(devicename)
                        self.processors.insert(i, p)
                    else:
                        print(f"Not using {module}")


        # imagemask arrives as a layered dict (presumably from the gradio
        # image editor - verify); keep only the first layer, feather its
        # edges and normalize it to a float RGB mask in [0, 1].
        if isinstance(self.options.imagemask, dict) and self.options.imagemask.get("layers") and len(self.options.imagemask["layers"]) > 0:
            self.options.imagemask = self.options.imagemask.get("layers")[0]
            # Get rid of alpha
            self.options.imagemask = cv2.cvtColor(self.options.imagemask, cv2.COLOR_RGBA2GRAY)
            if np.any(self.options.imagemask):
                mo = self.input_face_datas[0].faces[0].mask_offsets
                self.options.imagemask = self.blur_area(self.options.imagemask, mo[4], mo[5])
                self.options.imagemask = self.options.imagemask.astype(np.float32) / 255
                self.options.imagemask = cv2.cvtColor(self.options.imagemask, cv2.COLOR_GRAY2RGB)
            else:
                self.options.imagemask = None
129
+
130
+
131
+
132
+
133
    def run_batch(self, source_files, target_files, threads:int = 1):
        """Process image files in parallel; source_files[i] maps to target_files[i]."""
        progress_bar_format = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]'
        self.total_frames = len(source_files)
        self.num_threads = threads
        with tqdm(total=self.total_frames, desc='Processing', unit='frame', dynamic_ncols=True, bar_format=progress_bar_format) as progress:
            with ThreadPoolExecutor(max_workers=threads) as executor:
                futures = []
                queue = create_queue(source_files)
                # Split the work into roughly equal chunks per thread.
                queue_per_future = max(len(source_files) // threads, 1)
                while not queue.empty():
                    future = executor.submit(self.process_frames, source_files, target_files, pick_queue(queue, queue_per_future), lambda: self.update_progress(progress))
                    futures.append(future)
                # Propagate any worker exception.
                for future in as_completed(futures):
                    future.result()
147
+
148
+
149
+ def process_frames(self, source_files: List[str], target_files: List[str], current_files, update: Callable[[], None]) -> None:
150
+ for f in current_files:
151
+ if not roop.globals.processing:
152
+ return
153
+
154
+ # Decode the byte array into an OpenCV image
155
+ temp_frame = cv2.imdecode(np.fromfile(f, dtype=np.uint8), cv2.IMREAD_COLOR)
156
+ if temp_frame is not None:
157
+ resimg = self.process_frame(temp_frame)
158
+ if resimg is not None:
159
+ i = source_files.index(f)
160
+ cv2.imwrite(target_files[i], resimg)
161
+ if update:
162
+ update()
163
+
164
+
165
+
166
+ def read_frames_thread(self, cap, frame_start, frame_end, num_threads):
167
+ num_frame = 0
168
+ total_num = frame_end - frame_start
169
+ if frame_start > 0:
170
+ cap.set(cv2.CAP_PROP_POS_FRAMES,frame_start)
171
+
172
+ while True and roop.globals.processing:
173
+ ret, frame = cap.read()
174
+ if not ret:
175
+ break
176
+
177
+ self.frames_queue[num_frame % num_threads].put(frame, block=True)
178
+ num_frame += 1
179
+ if num_frame == total_num:
180
+ break
181
+
182
+ for i in range(num_threads):
183
+ self.frames_queue[i].put(None)
184
+
185
+
186
+
187
    def process_videoframes(self, threadindex, progress) -> None:
        """Consumer: swap faces on frames from this worker's input queue until the
        None sentinel arrives, forwarding results to the matching output queue."""
        while True:
            frame = self.frames_queue[threadindex].get()
            if frame is None:
                # Sentinel: tell the writer this producer is finished.
                self.processing_threads -= 1
                self.processed_queue[threadindex].put((False, None))
                return
            else:
                resimg = self.process_frame(frame)
                # (True, None) is possible: process_frame returns None to skip.
                self.processed_queue[threadindex].put((True, resimg))
                del frame
                progress()
199
+
200
+
201
    def write_frames_thread(self):
        """Writer: pull processed frames from the per-worker queues in round-robin
        order (preserving frame order) and append them to the output video."""
        nextindex = 0
        num_producers = self.num_threads

        while True:
            process, frame = self.processed_queue[nextindex % self.num_threads].get()
            nextindex += 1
            if frame is not None:
                self.videowriter.write_frame(frame)
                del frame
            elif process == False:
                # (True, None) = skipped frame; (False, None) = a producer
                # finished. Stop only once every producer has finished.
                num_producers -= 1
                if num_producers < 1:
                    return
215
+
216
+
217
+
218
    def run_batch_inmem(self, source_video, target_video, frame_start, frame_end, fps, threads:int = 1, skip_audio=False):
        """Swap faces across a video entirely in memory.

        One reader thread feeds per-worker queues, *threads* workers process
        frames, and one writer thread reassembles them in order via ffmpeg.
        NOTE(review): skip_audio is accepted but not used in this method.
        """
        cap = cv2.VideoCapture(source_video)
        # frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_count = (frame_end - frame_start) + 1
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        self.total_frames = frame_count
        self.num_threads = threads

        # Queues of capacity 1 give simple backpressure between the reader,
        # the workers and the writer.
        self.processing_threads = self.num_threads
        self.frames_queue = []
        self.processed_queue = []
        for _ in range(threads):
            self.frames_queue.append(Queue(1))
            self.processed_queue.append(Queue(1))

        self.videowriter = FFMPEG_VideoWriter(target_video, (width, height), fps, codec=roop.globals.video_encoder, crf=roop.globals.video_quality, audiofile=None)

        readthread = Thread(target=self.read_frames_thread, args=(cap, frame_start, frame_end, threads))
        readthread.start()

        writethread = Thread(target=self.write_frames_thread)
        writethread.start()

        progress_bar_format = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]'
        with tqdm(total=self.total_frames, desc='Processing', unit='frames', dynamic_ncols=True, bar_format=progress_bar_format) as progress:
            with ThreadPoolExecutor(thread_name_prefix='swap_proc', max_workers=self.num_threads) as executor:
                futures = []

                for threadindex in range(threads):
                    future = executor.submit(self.process_videoframes, threadindex, lambda: self.update_progress(progress))
                    futures.append(future)

                for future in as_completed(futures):
                    future.result()
        # wait for the task to complete
        readthread.join()
        writethread.join()
        cap.release()
        self.videowriter.close()
        self.frames_queue.clear()
        self.processed_queue.clear()
261
+
262
+
263
+
264
+
265
    def update_progress(self, progress: Any = None) -> None:
        """Advance the tqdm bar by one frame and mirror the state to gradio."""
        process = psutil.Process(os.getpid())
        # Resident set size in GB, shown in the tqdm postfix.
        memory_usage = process.memory_info().rss / 1024 / 1024 / 1024
        progress.set_postfix({
            'memory_usage': '{:.2f}'.format(memory_usage).zfill(5) + 'GB',
            'execution_threads': self.num_threads
        })
        progress.update(1)
        self.progress_gradio((progress.n, self.total_frames), desc='Processing', total=self.total_frames, unit='frames')
274
+
275
+
276
    def on_no_face_action(self, frame:Frame):
        """Decide what to do when no swappable face was found in *frame*.

        no_face_action 0: keep the original frame; 2: drop the frame;
        otherwise re-run detection and return whatever faces were found.
        Returns (faces or None, frame or None).
        """
        if roop.globals.no_face_action == 0:
            return None, frame
        elif roop.globals.no_face_action == 2:
            return None, None


        faces = get_all_faces(frame)
        if faces is not None:
            return faces, frame
        return None, frame
287
+
288
+ # https://github.com/deepinsight/insightface#third-party-re-implementation-of-arcface
289
+ # https://github.com/deepinsight/insightface/blob/master/alignment/coordinate_reg/image_infer.py
290
+ # https://github.com/deepinsight/insightface/issues/1350
291
+ # https://github.com/linghu8812/tensorrt_inference
292
+
293
+
294
    def process_frame(self, frame:Frame):
        """Swap faces in one frame, retrying upside-down when nothing is found.

        Returns the swapped frame, the original frame (no_face_action 0), or
        None to skip the frame entirely (no_face_action 2, in-mem mode only).
        """
        use_original_frame = 0
        skip_frame = 2

        if len(self.input_face_datas) < 1:
            return frame
        temp_frame = frame.copy()
        num_swapped, temp_frame = self.swap_faces(frame, temp_frame)
        if num_swapped > 0:
            return temp_frame
        if roop.globals.no_face_action == use_original_frame:
            return frame
        if roop.globals.no_face_action == skip_frame:
            #This only works with in-mem processing, as it simply skips the frame.
            #For 'extract frames' it simply leaves the unprocessed frame unprocessed and it gets used in the final output by ffmpeg.
            #If we could delete that frame here, that'd work but that might cause ffmpeg to fail unless the frames are renamed, and I don't think we have the info on what frame it actually is?????
            #alternatively, it could mark all the necessary frames for deletion, delete them at the end, then rename the remaining frames that might work?
            return None
        else:
            # Last resort: rotate the frame 180 degrees and retry the swap.
            copyframe = frame.copy()
            copyframe = rotate_image_180(copyframe)
            temp_frame = copyframe.copy()
            num_swapped, temp_frame = self.swap_faces(copyframe, temp_frame)
            if num_swapped == 0:
                return frame
            temp_frame = rotate_image_180(temp_frame)
            return temp_frame
321
+
322
+
323
    def swap_faces(self, frame, temp_frame):
        """Swap every face selected by options.swap_mode from *frame* into *temp_frame*.

        Returns (number of faces swapped, resulting frame); when nothing was
        swapped the original *frame* is returned unchanged.
        """
        num_faces_found = 0

        if self.options.swap_mode == "first":
            face = get_first_face(frame)

            if face is None:
                return num_faces_found, frame

            num_faces_found += 1
            temp_frame = self.process_face(self.options.selected_index, face, temp_frame)
        else:
            faces = get_all_faces(frame)
            if faces is None:
                return num_faces_found, frame

            if self.options.swap_mode == "all":
                for face in faces:
                    num_faces_found += 1
                    temp_frame = self.process_face(self.options.selected_index, face, temp_frame)
                del face

            elif self.options.swap_mode == "selected":
                # Match each chosen target face against the detected faces by
                # embedding distance.
                use_index = len(self.target_face_datas) == 1
                for i,tf in enumerate(self.target_face_datas):
                    for face in faces:
                        if compute_cosine_distance(tf.embedding, face.embedding) <= self.options.face_distance_threshold:
                            if i < len(self.input_face_datas):
                                if use_index:
                                    temp_frame = self.process_face(self.options.selected_index, face, temp_frame)
                                else:
                                    temp_frame = self.process_face(i, face, temp_frame)
                                num_faces_found += 1
                            if not roop.globals.vr_mode:
                                break
                    del face
            elif self.options.swap_mode == "all_female" or self.options.swap_mode == "all_male":
                gender = 'F' if self.options.swap_mode == "all_female" else 'M'
                for face in faces:
                    if face.sex == gender:
                        num_faces_found += 1
                        temp_frame = self.process_face(self.options.selected_index, face, temp_frame)
                del face

        if roop.globals.vr_mode and num_faces_found % 2 > 0:
            # stereo image, there has to be an even number of faces
            num_faces_found = 0
            return num_faces_found, frame
        if num_faces_found == 0:
            return num_faces_found, frame

        # NOTE(review): this looks for processorname 'clip2seg' while the plugin
        # table registers the key 'mask_clip2seg' - confirm the plugin's
        # processorname attribute really is 'clip2seg'.
        maskprocessor = next((x for x in self.processors if x.processorname == 'clip2seg'), None)

        if self.options.imagemask is not None and self.options.imagemask.shape == frame.shape:
            temp_frame = self.simple_blend_with_mask(temp_frame, frame, self.options.imagemask)

        if maskprocessor is not None:
            temp_frame = self.process_mask(maskprocessor, frame, temp_frame)
        return num_faces_found, temp_frame
382
+
383
+
384
+ def rotation_action(self, original_face:Face, frame:Frame):
385
+ (height, width) = frame.shape[:2]
386
+
387
+ bounding_box_width = original_face.bbox[2] - original_face.bbox[0]
388
+ bounding_box_height = original_face.bbox[3] - original_face.bbox[1]
389
+ horizontal_face = bounding_box_width > bounding_box_height
390
+
391
+ center_x = width // 2.0
392
+ start_x = original_face.bbox[0]
393
+ end_x = original_face.bbox[2]
394
+ bbox_center_x = start_x + (bounding_box_width // 2.0)
395
+
396
+ # need to leverage the array of landmarks as decribed here:
397
+ # https://github.com/deepinsight/insightface/tree/master/alignment/coordinate_reg
398
+ # basically, we should be able to check for the relative position of eyes and nose
399
+ # then use that to determine which way the face is actually facing when in a horizontal position
400
+ # and use that to determine the correct rotation_action
401
+
402
+ forehead_x = original_face.landmark_2d_106[72][0]
403
+ chin_x = original_face.landmark_2d_106[0][0]
404
+
405
+ if horizontal_face:
406
+ if chin_x < forehead_x:
407
+ # this is someone lying down with their face like this (:
408
+ return "rotate_anticlockwise"
409
+ elif forehead_x < chin_x:
410
+ # this is someone lying down with their face like this :)
411
+ return "rotate_clockwise"
412
+ if bbox_center_x >= center_x:
413
+ # this is someone lying down with their face in the right hand side of the frame
414
+ return "rotate_anticlockwise"
415
+ if bbox_center_x < center_x:
416
+ # this is someone lying down with their face in the left hand side of the frame
417
+ return "rotate_clockwise"
418
+
419
+ return None
420
+
421
+
422
+ def auto_rotate_frame(self, original_face, frame:Frame):
423
+ target_face = original_face
424
+ original_frame = frame
425
+
426
+ rotation_action = self.rotation_action(original_face, frame)
427
+
428
+ if rotation_action == "rotate_anticlockwise":
429
+ #face is horizontal, rotating frame anti-clockwise and getting face bounding box from rotated frame
430
+ frame = rotate_anticlockwise(frame)
431
+ elif rotation_action == "rotate_clockwise":
432
+ #face is horizontal, rotating frame clockwise and getting face bounding box from rotated frame
433
+ frame = rotate_clockwise(frame)
434
+
435
+ return target_face, frame, rotation_action
436
+
437
+
438
+ def auto_unrotate_frame(self, frame:Frame, rotation_action):
439
+ if rotation_action == "rotate_anticlockwise":
440
+ return rotate_clockwise(frame)
441
+ elif rotation_action == "rotate_clockwise":
442
+ return rotate_anticlockwise(frame)
443
+
444
+ return frame
445
+
446
+
447
+
448
    def process_face(self,face_index, target_face:Face, frame:Frame):
        """Swap input face *face_index* onto *target_face* within *frame*.

        Handles sideways faces by cutting out, rotating, swapping and pasting
        back; runs the swap processor first, then any enhancer processors.
        """
        enhanced_frame = None
        inputface = self.input_face_datas[face_index].faces[0]

        rotation_action = None
        if roop.globals.autorotate_faces:
            # check for sideways rotation of face
            rotation_action = self.rotation_action(target_face, frame)
            if rotation_action is not None:
                (startX, startY, endX, endY) = target_face["bbox"].astype("int")
                width = endX - startX
                height = endY - startY
                # Pad the cutout by 25% so the rotated face keeps context.
                offs = int(max(width,height) * 0.25)
                rotcutframe,startX, startY, endX, endY = self.cutout(frame, startX - offs, startY - offs, endX + offs, endY + offs)
                if rotation_action == "rotate_anticlockwise":
                    rotcutframe = rotate_anticlockwise(rotcutframe)
                elif rotation_action == "rotate_clockwise":
                    rotcutframe = rotate_clockwise(rotcutframe)
                # rotate image and re-detect face to correct wonky landmarks
                rotface = get_first_face(rotcutframe)
                if rotface is None:
                    rotation_action = None
                else:
                    # saved_frame/startX/startY stay defined for the paste-back below.
                    saved_frame = frame.copy()
                    frame = rotcutframe
                    target_face = rotface



        # if roop.globals.vr_mode:
        #     bbox = target_face.bbox
        #     [orig_width, orig_height, _] = frame.shape

        #     # Convert bounding box to ints
        #     x1, y1, x2, y2 = map(int, bbox)

        #     # Determine the center of the bounding box
        #     x_center = (x1 + x2) / 2
        #     y_center = (y1 + y2) / 2

        #     # Normalize coordinates to range [-1, 1]
        #     x_center_normalized = x_center / (orig_width / 2) - 1
        #     y_center_normalized = y_center / (orig_width / 2) - 1

        #     # Convert normalized coordinates to spherical (theta, phi)
        #     theta = x_center_normalized * 180 # Theta ranges from -180 to 180 degrees
        #     phi = -y_center_normalized * 90 # Phi ranges from -90 to 90 degrees

        #     img = vr.GetPerspective(frame, 90, theta, phi, 1280, 1280) # Generate perspective image

        fake_frame = None
        for p in self.processors:
            if p.type == 'swap':
                fake_frame = p.Run(inputface, target_face, frame)
                scale_factor = 0.0
            elif p.type == 'mask':
                # Mask processors run later, in swap_faces().
                continue
            else:
                enhanced_frame, scale_factor = p.Run(self.input_face_datas[face_index], target_face, fake_frame)

        upscale = 512
        orig_width = fake_frame.shape[1]

        fake_frame = cv2.resize(fake_frame, (upscale, upscale), cv2.INTER_CUBIC)
        mask_offsets = inputface.mask_offsets


        if enhanced_frame is None:
            scale_factor = int(upscale / orig_width)
            result = self.paste_upscale(fake_frame, fake_frame, target_face.matrix, frame, scale_factor, mask_offsets)
        else:
            result = self.paste_upscale(fake_frame, enhanced_frame, target_face.matrix, frame, scale_factor, mask_offsets)

        # When we worked on a rotated cutout, rotate it back and paste it into
        # the saved original frame at the cutout position.
        if rotation_action is not None:
            fake_frame = self.auto_unrotate_frame(result, rotation_action)
            return self.paste_simple(fake_frame, saved_frame, startX, startY)

        return result
526
+
527
+
528
+
529
+
530
+ def cutout(self, frame:Frame, start_x, start_y, end_x, end_y):
531
+ if start_x < 0:
532
+ start_x = 0
533
+ if start_y < 0:
534
+ start_y = 0
535
+ if end_x > frame.shape[1]:
536
+ end_x = frame.shape[1]
537
+ if end_y > frame.shape[0]:
538
+ end_y = frame.shape[0]
539
+ return frame[start_y:end_y, start_x:end_x], start_x, start_y, end_x, end_y
540
+
541
    def paste_simple(self, src:Frame, dest:Frame, start_x, start_y):
        """Paste *src* into *dest* at (start_x, start_y), clamped to dest bounds."""
        end_x = start_x + src.shape[1]
        end_y = start_y + src.shape[0]

        # NOTE(review): if clamp_cut_values actually shrinks the box, the
        # destination slice no longer matches src's shape - verify callers
        # always paste fully inside dest.
        start_x, end_x, start_y, end_y = clamp_cut_values(start_x, end_x, start_y, end_y, dest)
        dest[start_y:end_y, start_x:end_x] = src
        return dest
548
+
549
+ def simple_blend_with_mask(self, image1, image2, mask):
550
+ # Blend the images
551
+ blended_image = image1.astype(np.float32) * (1.0 - mask) + image2.astype(np.float32) * mask
552
+ return blended_image.astype(np.uint8)
553
+
554
+
555
    def paste_upscale(self, fake_face, upsk_face, M, target_img, scale_factor, mask_offsets):
        """Warp the (optionally enhanced) swapped face back into *target_img*.

        M is the face alignment matrix; its scaled inverse maps the face crop
        back onto the original image. mask_offsets shrink and feather the
        pasted region. Returns the composited uint8 image.
        """
        M_scale = M * scale_factor
        IM = cv2.invertAffineTransform(M_scale)

        face_matte = np.full((target_img.shape[0],target_img.shape[1]), 255, dtype=np.uint8)
        # Generate white square sized as a upsk_face
        img_matte = np.zeros((upsk_face.shape[0],upsk_face.shape[1]), dtype=np.uint8)

        w = img_matte.shape[1]
        h = img_matte.shape[0]

        # Shrink the white area by the configured top/bottom/left/right offsets.
        top = int(mask_offsets[0] * h)
        bottom = int(h - (mask_offsets[1] * h))
        left = int(mask_offsets[2] * w)
        right = int(w - (mask_offsets[3] * w))
        img_matte[top:bottom,left:right] = 255

        # Transform white square back to target_img
        img_matte = cv2.warpAffine(img_matte, IM, (target_img.shape[1], target_img.shape[0]), flags=cv2.INTER_NEAREST, borderValue=0.0)
        ##Blacken the edges of face_matte by 1 pixels (so the mask in not expanded on the image edges)
        img_matte[:1,:] = img_matte[-1:,:] = img_matte[:,:1] = img_matte[:,-1:] = 0

        img_matte = self.blur_area(img_matte, mask_offsets[4], mask_offsets[5])
        #Normalize images to float values and reshape
        img_matte = img_matte.astype(np.float32)/255
        face_matte = face_matte.astype(np.float32)/255
        img_matte = np.minimum(face_matte, img_matte)
        if self.options.show_mask:
            # Additional steps for green overlay
            green_overlay = np.zeros_like(target_img)
            green_color = [0, 255, 0] # RGB for green
            for i in range(3): # Apply green color where img_matte is not zero
                green_overlay[:, :, i] = np.where(img_matte > 0, green_color[i], 0) ##Transform upcaled face back to target_img
        img_matte = np.reshape(img_matte, [img_matte.shape[0],img_matte.shape[1],1])
        paste_face = cv2.warpAffine(upsk_face, IM, (target_img.shape[1], target_img.shape[0]), borderMode=cv2.BORDER_REPLICATE)
        if upsk_face is not fake_face:
            # An enhancer ran: blend the raw swap with the enhanced face
            # according to blend_ratio.
            fake_face = cv2.warpAffine(fake_face, IM, (target_img.shape[1], target_img.shape[0]), borderMode=cv2.BORDER_REPLICATE)
            paste_face = cv2.addWeighted(paste_face, self.options.blend_ratio, fake_face, 1.0 - self.options.blend_ratio, 0)

        # Re-assemble image
        paste_face = img_matte * paste_face
        paste_face = paste_face + (1-img_matte) * target_img.astype(np.float32)
        if self.options.show_mask:
            # Overlay the green overlay on the final image
            paste_face = cv2.addWeighted(paste_face.astype(np.uint8), 1 - 0.5, green_overlay, 0.5, 0)
        return paste_face.astype(np.uint8)
601
+
602
+
603
    def blur_area(self, img_matte, num_erosion_iterations, blur_amount):
        """Erode then Gaussian-blur a 255-valued mask to feather its edges.

        Kernel sizes scale with the white area's extent; *blur_amount*
        controls how aggressive both steps are.
        """
        # Detect the affine transformed white area
        mask_h_inds, mask_w_inds = np.where(img_matte==255)
        # Calculate the size (and diagonal size) of transformed white area width and height boundaries
        mask_h = np.max(mask_h_inds) - np.min(mask_h_inds)
        mask_w = np.max(mask_w_inds) - np.min(mask_w_inds)
        mask_size = int(np.sqrt(mask_h*mask_w))
        # Calculate the kernel size for eroding img_matte by kernel (insightface empirical guess for best size was max(mask_size//10,10))
        # k = max(mask_size//12, 8)
        # NOTE(review): blur_amount < 2 makes `blur_amount // 2` zero and
        # divides by zero here - confirm callers guarantee a sane minimum.
        k = max(mask_size//(blur_amount // 2) , blur_amount // 2)
        kernel = np.ones((k,k),np.uint8)
        img_matte = cv2.erode(img_matte,kernel,iterations = num_erosion_iterations)
        #Calculate the kernel size for blurring img_matte by blur_size (insightface empirical guess for best size was max(mask_size//20, 5))
        # k = max(mask_size//24, 4)
        k = max(mask_size//blur_amount, blur_amount//5)
        kernel_size = (k, k)
        # GaussianBlur needs odd kernel dimensions.
        blur_size = tuple(2*i+1 for i in kernel_size)
        return cv2.GaussianBlur(img_matte, blur_size, 0)
621
+
622
+
623
    def process_mask(self, processor, frame:Frame, target:Frame):
        """Restore masked regions of the original *frame* on top of the swapped
        *target*, using the mask produced by the text-driven mask processor."""
        img_mask = processor.Run(frame, self.options.masking_text)
        img_mask = cv2.resize(img_mask, (target.shape[1], target.shape[0]))
        img_mask = np.reshape(img_mask, [img_mask.shape[0],img_mask.shape[1],1])

        # Where the mask is 1, keep the original frame's pixels.
        target = target.astype(np.float32)
        result = (1-img_mask) * target
        result += img_mask * frame.astype(np.float32)
        return np.uint8(result)
632
+
633
+
634
+
635
+
636
+ def unload_models():
637
+ pass
638
+
639
+
640
+ def release_resources(self):
641
+ for p in self.processors:
642
+ p.Release()
643
+ self.processors.clear()
644
+
roop/ProcessOptions.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
class ProcessOptions:
    """Bundle of per-run processing settings handed to the process manager."""

    def __init__(self, processors, face_distance, blend_ratio, swap_mode,
                 selected_index, masking_text, imagemask, show_mask=False):
        # comma-separated processor chain, e.g. "faceswap,gfpgan"
        self.processors = processors
        # maximum face-embedding distance accepted as a match
        self.face_distance_threshold = face_distance
        # blend weight between the original and the swapped face
        self.blend_ratio = blend_ratio
        self.swap_mode = swap_mode
        # index of the chosen source face set
        self.selected_index = selected_index
        # CLIP text prompt used for text-based masking
        self.masking_text = masking_text
        # optional user-drawn mask image
        self.imagemask = imagemask
        self.show_mask = show_mask
roop/__init__.py ADDED
File without changes
roop/capturer.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+ import cv2
3
+ import numpy as np
4
+
5
+ from roop.typing import Frame
6
+
7
def get_image_frame(filename: str):
    """Load an image file as a BGR frame, returning None on failure.

    Uses np.fromfile + cv2.imdecode rather than cv2.imread so non-ASCII
    paths work on Windows.
    """
    try:
        return cv2.imdecode(np.fromfile(filename, dtype=np.uint8), cv2.IMREAD_COLOR)
    except Exception as ex:
        # was a bare `except` printing a placeholder; report the file and cause
        print(f"Exception reading {filename}: {ex}")
        return None
13
+
14
+
15
def get_video_frame(video_path: str, frame_number: int = 0) -> Optional[Frame]:
    """Grab one frame from a video, or None if it cannot be read.

    frame_number is treated as 1-based. The seek position is clamped into
    [0, frame_total - 1]; the original `min(frame_total, frame_number - 1)`
    could seek to -1 (for the default 0) or one past the last frame.
    """
    capture = cv2.VideoCapture(video_path)
    frame_total = capture.get(cv2.CAP_PROP_FRAME_COUNT)
    index = min(max(frame_number - 1, 0), max(frame_total - 1, 0))
    capture.set(cv2.CAP_PROP_POS_FRAMES, index)
    has_frame, frame = capture.read()
    capture.release()
    if has_frame:
        return frame
    return None
24
+
25
+
26
def get_video_frame_total(video_path: str) -> int:
    """Return the frame count OpenCV reports for video_path."""
    capture = cv2.VideoCapture(video_path)
    try:
        return int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        # always release the capture handle, even if the property read fails
        capture.release()
roop/core.py ADDED
@@ -0,0 +1,367 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ import os
4
+ import sys
5
+ import shutil
6
+ # single thread doubles cuda performance - needs to be set before torch import
7
+ if any(arg.startswith('--execution-provider') for arg in sys.argv):
8
+ os.environ['OMP_NUM_THREADS'] = '1'
9
+
10
+ import warnings
11
+ from typing import List
12
+ import platform
13
+ import signal
14
+ import torch
15
+ import onnxruntime
16
+ import pathlib
17
+
18
+ from time import time
19
+
20
+ import roop.globals
21
+ import roop.metadata
22
+ import roop.utilities as util
23
+ import roop.util_ffmpeg as ffmpeg
24
+ import ui.main as main
25
+ from settings import Settings
26
+ from roop.face_util import extract_face_images
27
+ from roop.ProcessEntry import ProcessEntry
28
+ from roop.ProcessMgr import ProcessMgr
29
+ from roop.ProcessOptions import ProcessOptions
30
+ from roop.capturer import get_video_frame_total
31
+
32
+
33
+ clip_text = None
34
+
35
+ call_display_ui = None
36
+
37
+ process_mgr = None
38
+
39
+
40
+ if 'ROCMExecutionProvider' in roop.globals.execution_providers:
41
+ del torch
42
+
43
+ warnings.filterwarnings('ignore', category=FutureWarning, module='insightface')
44
+ warnings.filterwarnings('ignore', category=UserWarning, module='torchvision')
45
+
46
+
47
def parse_args() -> None:
    """Install a SIGINT handler and force GUI (non-headless) mode.

    CLI arguments are not supported; if any are present only a hint is
    printed. Both default frame processors are always enabled.
    """
    signal.signal(signal.SIGINT, lambda signal_number, frame: destroy())
    roop.globals.headless = False
    # Always enable all processors when using GUI
    if len(sys.argv) > 1:
        print('No CLI args supported - use Settings Tab instead')
    roop.globals.frame_processors = ['face_swapper', 'face_enhancer']
54
+
55
+
56
def encode_execution_providers(execution_providers: List[str]) -> List[str]:
    """Map onnxruntime provider class names to short lowercase ids
    (e.g. 'CUDAExecutionProvider' -> 'cuda')."""
    encoded = []
    for provider in execution_providers:
        encoded.append(provider.replace('ExecutionProvider', '').lower())
    return encoded
58
+
59
+
60
def decode_execution_providers(execution_providers: List[str]) -> List[str]:
    """Translate short provider ids (e.g. 'cuda') back to the onnxruntime
    provider names available on this machine, preserving onnxruntime's order.

    Matching is by substring, so 'cuda' matches the encoded 'cuda' entry.
    """
    return [provider for provider, encoded_execution_provider in zip(onnxruntime.get_available_providers(), encode_execution_providers(onnxruntime.get_available_providers()))
            if any(execution_provider in encoded_execution_provider for execution_provider in execution_providers)]
63
+
64
+
65
def suggest_max_memory() -> int:
    """Default memory limit in GB: 4 on macOS, 16 on every other platform."""
    return 4 if platform.system().lower() == 'darwin' else 16
69
+
70
+
71
def suggest_execution_providers() -> List[str]:
    """Return the short ids of all onnxruntime providers available here."""
    return encode_execution_providers(onnxruntime.get_available_providers())
73
+
74
+
75
def suggest_execution_threads() -> int:
    """Default worker-thread count: 1 for DirectML/ROCm providers, else 8."""
    if 'DmlExecutionProvider' in roop.globals.execution_providers:
        return 1
    if 'ROCMExecutionProvider' in roop.globals.execution_providers:
        return 1
    return 8
81
+
82
+
83
def limit_resources() -> None:
    """Apply the configured process memory cap, if one is set."""
    # limit memory usage
    if roop.globals.max_memory:
        memory = roop.globals.max_memory * 1024 ** 3
        if platform.system().lower() == 'darwin':
            # NOTE(review): 1024**6 makes the cap astronomically large on macOS —
            # looks like a deliberate "effectively unlimited"; confirm intent
            memory = roop.globals.max_memory * 1024 ** 6
        if platform.system().lower() == 'windows':
            import ctypes
            kernel32 = ctypes.windll.kernel32  # type: ignore[attr-defined]
            kernel32.SetProcessWorkingSetSize(-1, ctypes.c_size_t(memory), ctypes.c_size_t(memory))
        else:
            import resource
            resource.setrlimit(resource.RLIMIT_DATA, (memory, memory))
96
+
97
+
98
+
99
def release_resources() -> None:
    """Tear down the shared ProcessMgr (if any) and run a GC pass."""
    import gc
    global process_mgr

    if process_mgr is not None:
        process_mgr.release_resources()
        process_mgr = None

    gc.collect()
    # if 'CUDAExecutionProvider' in roop.globals.execution_providers and torch.cuda.is_available():
    #     with torch.cuda.device('cuda'):
    #         torch.cuda.empty_cache()
    #         torch.cuda.ipc_collect()
112
+
113
+
114
def pre_check() -> bool:
    """Verify runtime prerequisites and download all required model files.

    Returns False only for an unsupported Python version; a missing ffmpeg
    merely logs a warning and the function still returns True.
    """
    if sys.version_info < (3, 9):
        update_status('Python version is not supported - please upgrade to 3.9 or higher.')
        return False

    download_directory_path = util.resolve_relative_path('../models')
    util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/inswapper_128.onnx'])
    util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/GFPGANv1.4.onnx'])
    util.conditional_download(download_directory_path, ['https://github.com/csxmli2016/DMDNet/releases/download/v1/DMDNet.pth'])
    util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/GPEN-BFR-512.onnx'])
    util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/restoreformer_plus_plus.onnx'])
    download_directory_path = util.resolve_relative_path('../models/CLIP')
    util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/rd64-uni-refined.pth'])
    download_directory_path = util.resolve_relative_path('../models/CodeFormer')
    util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/CodeFormerv0.1.onnx'])

    if not shutil.which('ffmpeg'):
        update_status('ffmpeg is not installed.')
    return True
133
+
134
def set_display_ui(function):
    """Register the callable that update_status uses to mirror messages in the UI."""
    global call_display_ui

    call_display_ui = function
138
+
139
+
140
def update_status(message: str) -> None:
    """Print a status message and forward it to the registered UI callback."""
    global call_display_ui

    print(message)
    ui_callback = call_display_ui
    if ui_callback is not None:
        ui_callback(message)
146
+
147
+
148
+
149
+
150
def start() -> None:
    """Kick off processing; headless mode is unsupported and falls through.

    NOTE(review): batch_process is called with 3 positional args while its
    definition takes 6 required parameters — this path would raise TypeError
    if ever reached; confirm intended arguments.
    """
    if roop.globals.headless:
        print('Headless mode currently unsupported - starting UI!')
    # faces = extract_face_images(roop.globals.source_path, (False, 0))
    # roop.globals.INPUT_FACES.append(faces[roop.globals.source_face_index])
    # faces = extract_face_images(roop.globals.target_path, (False, util.has_image_extension(roop.globals.target_path)))
    # roop.globals.TARGET_FACES.append(faces[roop.globals.target_face_index])
    # if 'face_enhancer' in roop.globals.frame_processors:
    #     roop.globals.selected_enhancer = 'GFPGAN'

    batch_process(None, False, None)
161
+
162
+
163
def get_processing_plugins(use_clip):
    """Build the comma-separated processor chain for the current settings.

    Always starts with "faceswap"; optionally appends the CLIP masker and
    the single enhancer selected in the UI.
    """
    chain = ["faceswap"]
    if use_clip:
        chain.append("mask_clip2seg")

    # UI enhancer name -> processor id
    enhancer_by_name = {
        'GFPGAN': 'gfpgan',
        'Codeformer': 'codeformer',
        'DMDNet': 'dmdnet',
        'GPEN': 'gpen',
        'Restoreformer++': 'restoreformer++',
    }
    enhancer = enhancer_by_name.get(roop.globals.selected_enhancer)
    if enhancer is not None:
        chain.append(enhancer)
    return ",".join(chain)
179
+
180
+
181
def live_swap(frame, swap_mode, use_clip, clip_text, imagemask, show_mask, selected_index = 0):
    """Swap faces in a single frame (camera / preview path).

    Lazily creates the shared ProcessMgr, re-initializes it with the current
    settings and returns the processed frame, or the unmodified input frame
    when processing yields nothing.
    """
    global process_mgr

    if frame is None:
        return frame

    if process_mgr is None:
        process_mgr = ProcessMgr(None)

    # fall back to the first source face set when the index is out of range
    if len(roop.globals.INPUT_FACESETS) <= selected_index:
        selected_index = 0
    options = ProcessOptions(get_processing_plugins(use_clip), roop.globals.distance_threshold, roop.globals.blend_ratio,
                              swap_mode, selected_index, clip_text,imagemask, show_mask)
    process_mgr.initialize(roop.globals.INPUT_FACESETS, roop.globals.TARGET_FACES, options)
    newframe = process_mgr.process_frame(frame)
    if newframe is None:
        return frame
    return newframe
199
+
200
+
201
def preview_mask(frame, clip_text):
    """Render the CLIP text-prompt mask for frame against a black background."""
    import numpy as np
    global process_mgr

    maskimage = np.zeros((frame.shape), np.uint8)
    if process_mgr is None:
        process_mgr = ProcessMgr(None)
    # mask-only pipeline: no swapping, no enhancer
    options = ProcessOptions("mask_clip2seg", roop.globals.distance_threshold, roop.globals.blend_ratio, "None", 0, clip_text, None)
    process_mgr.initialize(roop.globals.INPUT_FACESETS, roop.globals.TARGET_FACES, options)
    maskprocessor = next((x for x in process_mgr.processors if x.processorname == 'clip2seg'), None)
    return process_mgr.process_mask(maskprocessor, frame, maskimage)
212
+
213
+
214
+
215
+
216
+
217
def batch_process(files:list[ProcessEntry], use_clip, new_clip_text, use_new_method, imagemask, progress, selected_index = 0) -> None:
    """Process a batch of image and video entries with the shared ProcessMgr.

    Images are swapped in one run_batch pass; videos are either processed
    frame-by-frame on disk (when keeping frames or use_new_method is False)
    or fully in memory, then re-muxed (with audio / as GIF) to the final
    destination. Honors roop.globals.processing as a cancel flag.
    """
    global clip_text, process_mgr

    roop.globals.processing = True
    release_resources()
    limit_resources()

    # limit threads for some providers
    max_threads = suggest_execution_threads()
    if max_threads == 1:
        roop.globals.execution_threads = 1

    imagefiles:list[ProcessEntry] = []
    videofiles:list[ProcessEntry] = []

    update_status('Sorting videos/images')


    # Split the input into images and videos/GIFs and compute destinations.
    for index, f in enumerate(files):
        fullname = f.filename
        if util.has_image_extension(fullname):
            destination = util.get_destfilename_from_path(fullname, roop.globals.output_path, f'.{roop.globals.CFG.output_image_format}')
            destination = util.replace_template(destination, index=index)
            pathlib.Path(os.path.dirname(destination)).mkdir(parents=True, exist_ok=True)
            f.finalname = destination
            imagefiles.append(f)

        elif util.is_video(fullname) or util.has_extension(fullname, ['gif']):
            destination = util.get_destfilename_from_path(fullname, roop.globals.output_path, f'__temp.{roop.globals.CFG.output_video_format}')
            f.finalname = destination
            videofiles.append(f)


    if process_mgr is None:
        process_mgr = ProcessMgr(progress)
    mask = imagemask["layers"][0] if imagemask is not None else None
    if len(roop.globals.INPUT_FACESETS) <= selected_index:
        selected_index = 0
    options = ProcessOptions(get_processing_plugins(use_clip), roop.globals.distance_threshold, roop.globals.blend_ratio, roop.globals.face_swap_mode, selected_index, new_clip_text, mask)
    process_mgr.initialize(roop.globals.INPUT_FACESETS, roop.globals.TARGET_FACES, options)

    if(len(imagefiles) > 0):
        update_status('Processing image(s)')
        origimages = []
        fakeimages = []
        for f in imagefiles:
            origimages.append(f.filename)
            fakeimages.append(f.finalname)

        process_mgr.run_batch(origimages, fakeimages, roop.globals.execution_threads)
        origimages.clear()
        fakeimages.clear()

    if(len(videofiles) > 0):
        for index,v in enumerate(videofiles):
            if not roop.globals.processing:
                end_processing('Processing stopped!')
                return
            fps = v.fps if v.fps > 0 else util.detect_fps(v.filename)
            if v.endframe == 0:
                v.endframe = get_video_frame_total(v.filename)

            update_status(f'Creating {os.path.basename(v.finalname)} with {fps} FPS...')
            start_processing = time()
            # Disk-based path: extract frames, process them in place, re-encode.
            if roop.globals.keep_frames or not use_new_method:
                util.create_temp(v.filename)
                update_status('Extracting frames...')
                ffmpeg.extract_frames(v.filename,v.startframe,v.endframe, fps)
                if not roop.globals.processing:
                    end_processing('Processing stopped!')
                    return

                temp_frame_paths = util.get_temp_frame_paths(v.filename)
                process_mgr.run_batch(temp_frame_paths, temp_frame_paths, roop.globals.execution_threads)
                if not roop.globals.processing:
                    end_processing('Processing stopped!')
                    return
                if roop.globals.wait_after_extraction:
                    extract_path = os.path.dirname(temp_frame_paths[0])
                    util.open_folder(extract_path)
                    input("Press any key to continue...")
                    print("Resorting frames to create video")
                    util.sort_rename_frames(extract_path)

                ffmpeg.create_video(v.filename, v.finalname, fps)
                if not roop.globals.keep_frames:
                    util.delete_temp_frames(temp_frame_paths[0])
            else:
                # In-memory path: decode, process and encode without temp frames.
                if util.has_extension(v.filename, ['gif']):
                    skip_audio = True
                else:
                    skip_audio = roop.globals.skip_audio
                process_mgr.run_batch_inmem(v.filename, v.finalname, v.startframe, v.endframe, fps,roop.globals.execution_threads, skip_audio)

            if not roop.globals.processing:
                end_processing('Processing stopped!')
                return

            video_file_name = v.finalname
            if os.path.isfile(video_file_name):
                destination = ''
                if util.has_extension(v.filename, ['gif']):
                    gifname = util.get_destfilename_from_path(v.filename, roop.globals.output_path, '.gif')
                    destination = util.replace_template(gifname, index=index)
                    pathlib.Path(os.path.dirname(destination)).mkdir(parents=True, exist_ok=True)

                    update_status('Creating final GIF')
                    ffmpeg.create_gif_from_video(video_file_name, destination)
                    if os.path.isfile(destination):
                        os.remove(video_file_name)
                else:
                    skip_audio = roop.globals.skip_audio
                    destination = util.replace_template(video_file_name, index=index)
                    pathlib.Path(os.path.dirname(destination)).mkdir(parents=True, exist_ok=True)

                    if not skip_audio:
                        ffmpeg.restore_audio(video_file_name, v.filename, v.startframe, v.endframe, destination)
                        if os.path.isfile(destination):
                            os.remove(video_file_name)
                    else:
                        shutil.move(video_file_name, destination)
                update_status(f'\nProcessing {os.path.basename(destination)} took {time() - start_processing} secs')

            else:
                update_status(f'Failed processing {os.path.basename(v.finalname)}!')
    end_processing('Finished')
343
+
344
+
345
def end_processing(msg:str):
    """Report msg, clear the target folder state and release all resources."""
    update_status(msg)
    roop.globals.target_folder_path = None
    release_resources()
349
+
350
+
351
def destroy() -> None:
    """SIGINT handler: clean temp files, release resources and exit."""
    if roop.globals.target_path:
        util.clean_temp(roop.globals.target_path)
    release_resources()
    sys.exit()
356
+
357
+
358
def run() -> None:
    """Program entry: parse args, check prerequisites, load config, launch UI."""
    parse_args()
    if not pre_check():
        return
    roop.globals.CFG = Settings('config.yaml')
    roop.globals.execution_threads = roop.globals.CFG.max_threads
    roop.globals.video_encoder = roop.globals.CFG.output_video_codec
    roop.globals.video_quality = roop.globals.CFG.video_quality
    # a memory_limit of 0 or less means "no cap"
    roop.globals.max_memory = roop.globals.CFG.memory_limit if roop.globals.CFG.memory_limit > 0 else None
    main.run()
roop/face_util.py ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import threading
2
+ from typing import Any
3
+ import insightface
4
+
5
+ import roop.globals
6
+ from roop.typing import Frame, Face
7
+
8
+ import cv2
9
+ import numpy as np
10
+ from skimage import transform as trans
11
+ from roop.capturer import get_video_frame
12
+ from roop.utilities import resolve_relative_path, conditional_download
13
+
14
+ FACE_ANALYSER = None
15
+ THREAD_LOCK_ANALYSER = threading.Lock()
16
+ THREAD_LOCK_SWAPPER = threading.Lock()
17
+ FACE_SWAPPER = None
18
+
19
+
20
def get_face_analyser() -> Any:
    """Return the shared insightface FaceAnalysis instance.

    Rebuilds the analyser when the desired analysis modules changed.
    Thread-safe via THREAD_LOCK_ANALYSER.
    """
    global FACE_ANALYSER

    with THREAD_LOCK_ANALYSER:
        if FACE_ANALYSER is None or roop.globals.g_current_face_analysis != roop.globals.g_desired_face_analysis:
            model_path = resolve_relative_path('..')
            # removed genderage
            allowed_modules = roop.globals.g_desired_face_analysis
            roop.globals.g_current_face_analysis = roop.globals.g_desired_face_analysis
            if roop.globals.CFG.force_cpu:
                print("Forcing CPU for Face Analysis")
                FACE_ANALYSER = insightface.app.FaceAnalysis(
                    name="buffalo_l",
                    root=model_path, providers=["CPUExecutionProvider"],allowed_modules=allowed_modules
                )
            else:
                FACE_ANALYSER = insightface.app.FaceAnalysis(
                    name="buffalo_l", root=model_path, providers=roop.globals.execution_providers,allowed_modules=allowed_modules
                )
            FACE_ANALYSER.prepare(
                ctx_id=0,
                det_size=(640, 640) if roop.globals.default_det_size else (320, 320),
            )
    return FACE_ANALYSER
44
+
45
+
46
+ def get_first_face(frame: Frame) -> Any:
47
+ try:
48
+ faces = get_face_analyser().get(frame)
49
+ return min(faces, key=lambda x: x.bbox[0])
50
+ # return sorted(faces, reverse=True, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]))[0]
51
+ except:
52
+ return None
53
+
54
+
55
+ def get_all_faces(frame: Frame) -> Any:
56
+ try:
57
+ faces = get_face_analyser().get(frame)
58
+ return sorted(faces, key=lambda x: x.bbox[0])
59
+ except:
60
+ return None
61
+
62
+
63
def extract_face_images(source_filename, video_info, extra_padding=-1.0):
    """Detect faces in an image or a single video frame and return crops.

    video_info is (is_video, frame_number). Returns a list of
    [face, cropped_image] pairs. With extra_padding > 0 the crop is retried
    with growing padding until a face is re-detected in the resized cutout.
    """
    face_data = []
    source_image = None

    if video_info[0]:
        frame = get_video_frame(source_filename, video_info[1])
        if frame is not None:
            source_image = frame
        else:
            return face_data
    else:
        # imdecode+fromfile instead of imread so unicode paths work on Windows
        source_image = cv2.imdecode(np.fromfile(source_filename, dtype=np.uint8), cv2.IMREAD_COLOR)

    faces = get_all_faces(source_image)
    if faces is None:
        return face_data

    i = 0
    for face in faces:
        (startX, startY, endX, endY) = face["bbox"].astype("int")
        if extra_padding > 0.0:
            if source_image.shape[:2] == (512, 512):
                i += 1
                face_data.append([face, source_image])
                continue

            found = False
            # NOTE(review): the inner loop reuses `i`, clobbering the outer
            # counter — harmless only because `i` is never read afterwards
            for i in range(1, 3):
                (startX, startY, endX, endY) = face["bbox"].astype("int")
                cutout_padding = extra_padding
                # top needs extra room for detection
                padding = int((endY - startY) * cutout_padding)
                oldY = startY
                startY -= padding

                factor = 0.25 if i == 1 else 0.5
                cutout_padding = factor
                padding = int((endY - oldY) * cutout_padding)
                endY += padding
                padding = int((endX - startX) * cutout_padding)
                startX -= padding
                endX += padding
                startX, endX, startY, endY = clamp_cut_values(
                    startX, endX, startY, endY, source_image
                )
                face_temp = source_image[startY:endY, startX:endX]
                face_temp = resize_image_keep_content(face_temp)
                testfaces = get_all_faces(face_temp)
                if testfaces is not None and len(testfaces) > 0:
                    i += 1
                    face_data.append([testfaces[0], face_temp])
                    found = True
                    break

            if not found:
                print("No face found after resizing, this shouldn't happen!")
            continue

        face_temp = source_image[startY:endY, startX:endX]
        if face_temp.size < 1:
            continue

        i += 1
        face_data.append([face, face_temp])
    return face_data
128
+
129
+
130
def clamp_cut_values(startX, endX, startY, endY, image):
    """Clamp a bounding box so it lies within image
    (height = image.shape[0], width = image.shape[1])."""
    height, width = image.shape[:2]
    startX = max(startX, 0)
    startY = max(startY, 0)
    endX = min(endX, width)
    endY = min(endY, height)
    return startX, endX, startY, endY
140
+
141
+
142
def get_face_swapper() -> Any:
    """Return the shared inswapper_128 model, loading it on first use.
    Thread-safe via THREAD_LOCK_SWAPPER."""
    global FACE_SWAPPER

    with THREAD_LOCK_SWAPPER:
        if FACE_SWAPPER is None:
            model_path = resolve_relative_path("../models/inswapper_128.onnx")
            FACE_SWAPPER = insightface.model_zoo.get_model(
                model_path, providers=roop.globals.execution_providers
            )
    return FACE_SWAPPER
152
+
153
+
154
def pre_check() -> bool:
    """Ensure the inswapper model is downloaded; always returns True."""
    download_directory_path = resolve_relative_path("../models")
    conditional_download(
        download_directory_path,
        ["https://huggingface.co/countfloyd/deepfake/resolve/main/inswapper_128.onnx"],
    )
    return True
161
+
162
+
163
def swap_face(source_face: Face, target_face: Face, temp_frame: Frame) -> Frame:
    """Replace target_face in temp_frame with source_face and return the
    composited frame (paste_back=True)."""
    return get_face_swapper().get(temp_frame, target_face, source_face, paste_back=True)
165
+
166
+
167
def face_offset_top(face: Face, offset):
    """Shift a face's bbox and 106-point landmarks vertically by offset
    pixels (in place) and return the face."""
    face["bbox"][1] += offset
    face["bbox"][3] += offset
    lm106 = face.landmark_2d_106
    add = np.full_like(lm106, [0, offset])
    face["landmark_2d_106"] = lm106 + add
    return face
174
+
175
+
176
def resize_image_keep_content(image, new_width=512, new_height=512):
    """Resize image to fit new_width x new_height preserving aspect ratio,
    centering the result on a black canvas (letterbox/pillarbox)."""
    dim = None
    (h, w) = image.shape[:2]
    if h > w:
        r = new_height / float(h)
        dim = (int(w * r), new_height)
    else:
        # Calculate the ratio of the width and construct the dimensions
        r = new_width / float(w)
        dim = (new_width, int(h * r))
    image = cv2.resize(image, dim, interpolation=cv2.INTER_AREA)
    (h, w) = image.shape[:2]
    if h == new_height and w == new_width:
        return image
    resize_img = np.zeros(shape=(new_height, new_width, 3), dtype=image.dtype)
    # remaining margin, split between both sides (extra pixel goes first)
    offs = (new_width - w) if h == new_height else (new_height - h)
    startoffs = int(offs // 2) if offs % 2 == 0 else int(offs // 2) + 1
    offs = int(offs // 2)

    if h == new_height:
        resize_img[0:new_height, startoffs : new_width - offs] = image
    else:
        resize_img[startoffs : new_height - offs, 0:new_width] = image
    return resize_img
200
+
201
+
202
def rotate_image_90(image, rotate=True):
    """Rotate 90° counter-clockwise when rotate is True, clockwise otherwise."""
    return np.rot90(image) if rotate else np.rot90(image, 1, (1, 0))
207
+
208
+
209
def rotate_anticlockwise(frame):
    """Rotate frame 90° counter-clockwise."""
    return rotate_image_90(frame)
211
+
212
+
213
def rotate_clockwise(frame):
    """Rotate frame 90° clockwise."""
    return rotate_image_90(frame, False)
215
+
216
+
217
def rotate_image_180(image):
    # NOTE(review): np.flip(image, 0) is a vertical mirror, not a true 180°
    # rotation (that would also flip axis 1) — confirm callers expect this.
    return np.flip(image, 0)
219
+
220
+
221
+ # alignment code from insightface https://github.com/deepinsight/insightface/blob/master/python-package/insightface/utils/face_align.py
222
+
223
# Canonical ArcFace destination landmarks (5 points: eyes, nose, mouth
# corners) for a 112x112 aligned crop, taken from insightface.
arcface_dst = np.array(
    [
        [38.2946, 51.6963],
        [73.5318, 51.5014],
        [56.0252, 71.7366],
        [41.5493, 92.3655],
        [70.7299, 92.2041],
    ],
    dtype=np.float32,
)
233
+
234
+
235
def estimate_norm(lmk, image_size=112, mode="arcface"):
    """Estimate the 2x3 similarity transform mapping 5-point landmarks lmk
    onto the canonical ArcFace template scaled to image_size.

    image_size must be a multiple of 112 or 128 (128-multiples get the
    standard 8px-per-ratio horizontal shift).
    """
    assert lmk.shape == (5, 2)
    assert image_size % 112 == 0 or image_size % 128 == 0
    if image_size % 112 == 0:
        ratio = float(image_size) / 112.0
        diff_x = 0
    else:
        ratio = float(image_size) / 128.0
        diff_x = 8.0 * ratio
    dst = arcface_dst * ratio
    dst[:, 0] += diff_x
    tform = trans.SimilarityTransform()
    tform.estimate(lmk, dst)
    M = tform.params[0:2, :]
    return M
250
+
251
+
252
def norm_crop(img, landmark, image_size=112, mode="arcface"):
    """Warp img into an aligned image_size x image_size face crop."""
    M = estimate_norm(landmark, image_size, mode)
    warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0)
    return warped
256
+
257
+
258
+ # aligned, M = norm_crop2(f[1], face.kps, 512)
259
# aligned, M = norm_crop2(f[1], face.kps, 512)
def norm_crop2(img, landmark, image_size=112, mode="arcface"):
    """Like norm_crop, but also returns the 2x3 transform M used to warp."""
    M = estimate_norm(landmark, image_size, mode)
    warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0)
    return warped, M
263
+
264
+
265
def square_crop(im, S):
    """Scale im to fit inside an S x S square (preserving aspect ratio),
    pad with black, and return (square_image, scale_factor)."""
    if im.shape[0] > im.shape[1]:
        height = S
        width = int(float(im.shape[1]) / im.shape[0] * S)
        scale = float(S) / im.shape[0]
    else:
        width = S
        height = int(float(im.shape[0]) / im.shape[1] * S)
        scale = float(S) / im.shape[1]
    resized_im = cv2.resize(im, (width, height))
    det_im = np.zeros((S, S, 3), dtype=np.uint8)
    det_im[: resized_im.shape[0], : resized_im.shape[1], :] = resized_im
    return det_im, scale
278
+
279
+
280
def transform(data, center, output_size, scale, rotation):
    """Scale, rotate (degrees) and recenter data into an output_size square.

    Returns (warped_image, M) where M is the composed 2x3 affine matrix.
    """
    scale_ratio = scale
    rot = float(rotation) * np.pi / 180.0
    # translation = (output_size/2-center[0]*scale_ratio, output_size/2-center[1]*scale_ratio)
    t1 = trans.SimilarityTransform(scale=scale_ratio)
    cx = center[0] * scale_ratio
    cy = center[1] * scale_ratio
    t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy))
    t3 = trans.SimilarityTransform(rotation=rot)
    t4 = trans.SimilarityTransform(translation=(output_size / 2, output_size / 2))
    t = t1 + t2 + t3 + t4
    M = t.params[0:2]
    cropped = cv2.warpAffine(data, M, (output_size, output_size), borderValue=0.0)
    return cropped, M
294
+
295
+
296
def trans_points2d(pts, M):
    """Apply a 2x3 affine matrix M to an (N, 2) array of points.

    Vectorized replacement for the original per-point loop:
    each [x, y] becomes M @ [x, y, 1]. Returns float32 of the same shape.
    """
    pts = np.asarray(pts, dtype=np.float32)
    M = np.asarray(M, dtype=np.float32)
    return (pts @ M[:, :2].T + M[:, 2]).astype(np.float32)
306
+
307
+
308
def trans_points3d(pts, M):
    """Apply a 2x3 affine matrix M to an (N, 3) array of points.

    x/y are transformed as M @ [x, y, 1]; z is multiplied by M's isotropic
    scale (the norm of its first row). Vectorized replacement for the
    original per-point loop. Returns float32 of the same shape.
    """
    pts = np.asarray(pts, dtype=np.float32)
    M = np.asarray(M, dtype=np.float32)
    scale = np.sqrt(M[0, 0] * M[0, 0] + M[0, 1] * M[0, 1])
    new_pts = np.empty(pts.shape, dtype=np.float32)
    new_pts[:, 0:2] = pts[:, 0:2] @ M[:, :2].T + M[:, 2]
    new_pts[:, 2] = pts[:, 2] * scale
    return new_pts
321
+
322
+
323
def trans_points(pts, M):
    """Dispatch to the 2D or 3D point transform based on pts' column count."""
    if pts.shape[1] == 2:
        return trans_points2d(pts, M)
    else:
        return trans_points3d(pts, M)
328
+
329
def create_blank_image(width, height):
    """Return a fully transparent 4-channel (BGRA) image of the given size."""
    # np.zeros already yields all-zero pixels, i.e. transparent black
    return np.zeros((height, width, 4), dtype=np.uint8)
333
+
roop/ffmpeg_writer.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FFMPEG_Writer - write set of frames to video file
3
+
4
+ original from
5
+ https://github.com/Zulko/moviepy/blob/master/moviepy/video/io/ffmpeg_writer.py
6
+
7
+ removed unnecessary dependencies
8
+
9
+ The MIT License (MIT)
10
+
11
+ Copyright (c) 2015 Zulko
12
+ Copyright (c) 2023 Janvarev Vladislav
13
+ """
14
+
15
+ import os
16
+ import subprocess as sp
17
+
18
# Mirrors of the subprocess module's special file-descriptor constants
# (subprocess.PIPE == -1, STDOUT == -2, DEVNULL == -3), so they can be
# passed straight to sp.Popen.
PIPE = -1
STDOUT = -2
DEVNULL = -3

FFMPEG_BINARY = "ffmpeg"  # ffmpeg is expected to be on PATH
23
+
24
class FFMPEG_VideoWriter:
    """ A class for FFMPEG-based video writing.

    A class to write videos using ffmpeg. ffmpeg will write in a large
    choice of formats.

    Parameters
    -----------

    filename
      Any filename like 'video.mp4' etc. but if you want to avoid
      complications it is recommended to use the generic extension
      '.avi' for all your videos.

    size
      Size (width,height) of the output video in pixels.

    fps
      Frames per second in the output video file.

    codec
      FFMPEG codec. It seems that in terms of quality the hierarchy is
      'rawvideo' = 'png' > 'mpeg4' > 'libx264'
      'png' manages the same lossless quality as 'rawvideo' but yields
      smaller files. Type ``ffmpeg -codecs`` in a terminal to get a list
      of accepted codecs.

      Note for default 'libx264': by default the pixel format yuv420p
      is used. If the video dimensions are not both even (e.g. 720x405)
      another pixel format is used, and this can cause problem in some
      video readers.

    audiofile
      Optional: The name of an audio file that will be incorporated
      to the video.

    preset
      Sets the time that FFMPEG will take to compress the video. The slower,
      the better the compression rate. Possibilities are: ultrafast,superfast,
      veryfast, faster, fast, medium (default), slow, slower, veryslow,
      placebo.

    bitrate
      Only relevant for codecs which accept a bitrate. "5000k" offers
      nice results in general.

    """

    def __init__(self, filename, size, fps, codec="libx265", crf=14, audiofile=None,
                 preset="medium", bitrate=None,
                 logfile=None, threads=None, ffmpeg_params=None):

        if logfile is None:
            logfile = sp.PIPE

        self.filename = filename
        self.codec = codec
        self.ext = self.filename.split(".")[-1]
        # even dimensions are required by yuv420p; odd sizes are rounded down
        w = size[0] - 1 if size[0] % 2 != 0 else size[0]
        h = size[1] - 1 if size[1] % 2 != 0 else size[1]


        # order is important
        cmd = [
            FFMPEG_BINARY,
            '-hide_banner',
            '-hwaccel', 'auto',
            '-y',
            '-loglevel', 'error' if logfile == sp.PIPE else 'info',
            '-f', 'rawvideo',
            '-vcodec', 'rawvideo',
            '-s', '%dx%d' % (size[0], size[1]),
            #'-pix_fmt', 'rgba' if withmask else 'rgb24',
            '-pix_fmt', 'bgr24',
            '-r', str(fps),
            '-an', '-i', '-'
        ]

        if audiofile is not None:
            cmd.extend([
                '-i', audiofile,
                '-acodec', 'copy'
            ])

        cmd.extend([
            '-vcodec', codec,
            '-crf', str(crf)
            #'-preset', preset,
        ])
        if ffmpeg_params is not None:
            cmd.extend(ffmpeg_params)
        if bitrate is not None:
            cmd.extend([
                '-b', bitrate
            ])

        # scale to a resolution divisible by 2 if not even
        cmd.extend(['-vf', f'scale={w}:{h}' if w != size[0] or h != size[1] else 'colorspace=bt709:iall=bt601-6-625:fast=1'])

        if threads is not None:
            cmd.extend(["-threads", str(threads)])

        cmd.extend([
            '-pix_fmt', 'yuv420p',

        ])
        cmd.extend([
            filename
        ])

        # NOTE(review): leftover debug output of the full ffmpeg command line
        test = str(cmd)
        print(test)

        popen_params = {"stdout": DEVNULL,
                        "stderr": logfile,
                        "stdin": sp.PIPE}

        # This was added so that no extra unwanted window opens on windows
        # when the child process is created
        if os.name == "nt":
            popen_params["creationflags"] = 0x08000000  # CREATE_NO_WINDOW

        self.proc = sp.Popen(cmd, **popen_params)


    def write_frame(self, img_array):
        """ Writes one frame in the file."""
        try:
            #if PY3:
            self.proc.stdin.write(img_array.tobytes())
            # else:
            #    self.proc.stdin.write(img_array.tostring())
        except IOError as err:
            # ffmpeg died: collect its stderr and raise a descriptive error
            _, ffmpeg_error = self.proc.communicate()
            error = (str(err) + ("\n\nroop unleashed error: FFMPEG encountered "
                                 "the following error while writing file %s:"
                                 "\n\n %s" % (self.filename, str(ffmpeg_error))))

            if b"Unknown encoder" in ffmpeg_error:

                error = error+("\n\nThe video export "
                  "failed because FFMPEG didn't find the specified "
                  "codec for video encoding (%s). Please install "
                  "this codec or change the codec when calling "
                  "write_videofile. For instance:\n"
                  "  >>> clip.write_videofile('myvid.webm', codec='libvpx')")%(self.codec)

            elif b"incorrect codec parameters ?" in ffmpeg_error:

                error = error+("\n\nThe video export "
                  "failed, possibly because the codec specified for "
                  "the video (%s) is not compatible with the given "
                  "extension (%s). Please specify a valid 'codec' "
                  "argument in write_videofile. This would be 'libx264' "
                  "or 'mpeg4' for mp4, 'libtheora' for ogv, 'libvpx for webm. "
                  "Another possible reason is that the audio codec was not "
                  "compatible with the video codec. For instance the video "
                  "extensions 'ogv' and 'webm' only allow 'libvorbis' (default) as a"
                  "video codec."
                  )%(self.codec, self.ext)

            elif b"encoder setup failed" in ffmpeg_error:

                error = error+("\n\nThe video export "
                  "failed, possibly because the bitrate you specified "
                  "was too high or too low for the video codec.")

            elif b"Invalid encoder type" in ffmpeg_error:

                error = error + ("\n\nThe video export failed because the codec "
                  "or file extension you provided is not a video")


            raise IOError(error)

    def close(self):
        """Flush stdin, wait for ffmpeg to finish and drop the process handle."""
        if self.proc:
            self.proc.stdin.close()
            if self.proc.stderr is not None:
                self.proc.stderr.close()
            self.proc.wait()

        self.proc = None

    # Support the Context Manager protocol, to ensure that resources are cleaned up.

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
+
217
+
218
+
roop/filters.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+
4
# 16-entry RGB palette; the values match the classic Commodore 64 color set
# (hence the name).  Rows are [R, G, B] in 0..255.
c64_palette = np.array([
    [0, 0, 0],           # black
    [255, 255, 255],     # white
    [0x81, 0x33, 0x38],  # red
    [0x75, 0xce, 0xc8],  # cyan
    [0x8e, 0x3c, 0x97],  # purple
    [0x56, 0xac, 0x4d],  # green
    [0x2e, 0x2c, 0x9b],  # blue
    [0xed, 0xf1, 0x71],  # yellow
    [0x8e, 0x50, 0x29],  # orange
    [0x55, 0x38, 0x00],  # brown
    [0xc4, 0x6c, 0x71],  # light red
    [0x4a, 0x4a, 0x4a],  # dark grey
    [0x7b, 0x7b, 0x7b],  # grey
    [0xa9, 0xff, 0x9f],  # light green
    [0x70, 0x6d, 0xeb],  # light blue
    [0xb2, 0xb2, 0xb2]   # light grey
])
22
+
23
+ def fast_quantize_to_palette(image):
24
+ # Simply round the color values to the nearest color in the palette
25
+ palette = c64_palette / 255.0 # Normalize palette
26
+ img_normalized = image / 255.0 # Normalize image
27
+
28
+ # Calculate the index in the palette that is closest to each pixel in the image
29
+ indices = np.sqrt(((img_normalized[:, :, None, :] - palette[None, None, :, :]) ** 2).sum(axis=3)).argmin(axis=2)
30
+ # Map the image to the palette colors
31
+ mapped_image = palette[indices]
32
+
33
+ return (mapped_image * 255).astype(np.uint8) # Denormalize and return the image
34
+
35
+
36
+ '''
37
+ knn = None
38
+
39
+ def quantize_to_palette(image, palette):
40
+ global knn
41
+
42
+ NumColors = 16
43
+ quantized_image = None
44
+ cv2.pyrMeanShiftFiltering(image, NumColors / 4, NumColors / 2, quantized_image, 1, cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_MAX_ITER, 5, 1)
45
+
46
+ palette = c64_palette
47
+ X_query = image.reshape(-1, 3).astype(np.float32)
48
+
49
+ if(knn == None):
50
+ X_index = palette.astype(np.float32)
51
+ knn = cv2.ml.KNearest_create()
52
+ knn.train(X_index, cv2.ml.ROW_SAMPLE, np.arange(len(palette)))
53
+
54
+ ret, results, neighbours, dist = knn.findNearest(X_query, 1)
55
+
56
+ quantized_image = np.array([palette[idx] for idx in neighbours.astype(int)])
57
+ quantized_image = quantized_image.reshape(image.shape)
58
+ return quantized_image.astype(np.uint8)
59
+ '''
roop/globals.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Mutable module-level state shared across the roop pipeline.

Values are written by the CLI/UI layers and read by the processors.  All
defaults below are placeholders set before any configuration is loaded.
"""
from settings import Settings
from typing import List

# Input / output paths (set at runtime; None until chosen).
source_path = None
target_path = None
output_path = None
target_folder_path = None

# Processing options.  Most default to None, i.e. "not configured yet".
frame_processors: List[str] = []
keep_fps = None
keep_frames = None
autorotate_faces = None
vr_mode = None
skip_audio = None
wait_after_extraction = None
many_faces = None
use_batch = None
source_face_index = 0
target_face_index = 0
face_position = None
video_encoder = None
video_quality = None
max_memory = None
execution_providers: List[str] = []   # ONNX runtime providers, e.g. CPU/CUDA
execution_threads = None
headless = None
log_level = 'error'
selected_enhancer = None
face_swap_mode = None
blend_ratio = 0.5            # 0..1 mix between original and swapped face
distance_threshold = 0.65    # face-similarity cutoff for matching
default_det_size = True

no_face_action = 0

# True while a swap/enhance job is running.
processing = False

g_current_face_analysis = None
g_desired_face_analysis = None

FACE_ENHANCER = None

# Loaded face sets: sources to swap from, targets to swap onto.
INPUT_FACESETS = []
TARGET_FACES = []


# Lazily-constructed processing chains (image / video / batch).
IMAGE_CHAIN_PROCESSOR = None
VIDEO_CHAIN_PROCESSOR = None
BATCH_IMAGE_CHAIN_PROCESSOR = None

# Global settings object, assigned at startup.
CFG: Settings = None
53
+
roop/metadata.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
# Application identity constants (name and semantic version string).
name = 'roop unleashed'
version = '3.6.7'
roop/processors/Enhance_CodeFormer.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, List, Callable
2
+ import cv2
3
+ import threading
4
+ import numpy as np
5
+ import onnxruntime
6
+ import onnx
7
+ import roop.globals
8
+
9
+ from roop.typing import Face, Frame, FaceSet
10
+ from roop.utilities import resolve_relative_path
11
+
12
+
13
+ # THREAD_LOCK = threading.Lock()
14
+
15
+
16
class Enhance_CodeFormer():
    """Face enhancement processor backed by the CodeFormer v0.1 ONNX model.

    The inference session and its io-binding are created once in
    ``Initialize`` and reused for every ``Run`` call.
    """

    model_codeformer = None
    devicename = None

    processorname = 'codeformer'
    type = 'enhance'


    def Initialize(self, devicename:str):
        """Lazily create the ONNX session and pre-bind static inputs/outputs.

        :param devicename: ORT device string; 'mps' is mapped to 'cpu'.
        """
        if self.model_codeformer is None:
            # replace Mac mps with cpu for the moment
            devicename = devicename.replace('mps', 'cpu')
            self.devicename = devicename
            model_path = resolve_relative_path('../models/CodeFormer/CodeFormerv0.1.onnx')
            self.model_codeformer = onnxruntime.InferenceSession(model_path, None, providers=roop.globals.execution_providers)
            self.model_inputs = self.model_codeformer.get_inputs()
            model_outputs = self.model_codeformer.get_outputs()
            self.io_binding = self.model_codeformer.io_binding()
            # Second model input is a constant scalar weight (0.5); bound once.
            self.io_binding.bind_cpu_input(self.model_inputs[1].name, np.array([0.5]))
            self.io_binding.bind_output(model_outputs[0].name, self.devicename)


    def Run(self, source_faceset: FaceSet, target_face: Face, temp_frame: Frame) -> Frame:
        """Enhance a cropped face frame.

        :param temp_frame: face crop; resized to the model's 512x512 input.
        :return: tuple (enhanced uint8 frame, integer scale factor relative
                 to the input width).
        """
        input_size = temp_frame.shape[1]
        # preprocess: 512x512 RGB in [-1, 1], NCHW layout
        # BUGFIX: cv2.resize's third positional parameter is `dst`, not the
        # interpolation flag, so INTER_CUBIC was previously misrouted; pass
        # it by keyword.
        temp_frame = cv2.resize(temp_frame, (512, 512), interpolation=cv2.INTER_CUBIC)
        temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB)
        temp_frame = temp_frame.astype('float32') / 255.0
        temp_frame = (temp_frame - 0.5) / 0.5
        temp_frame = np.expand_dims(temp_frame, axis=0).transpose(0, 3, 1, 2)

        self.io_binding.bind_cpu_input(self.model_inputs[0].name, temp_frame.astype(np.float32))
        self.model_codeformer.run_with_iobinding(self.io_binding)
        ort_outs = self.io_binding.copy_outputs_to_cpu()
        result = ort_outs[0][0]
        del ort_outs

        # post-process: CHW -> HWC, clip and map [-1, 1] back to [0, 255]
        result = result.transpose((1, 2, 0))

        un_min = -1.0
        un_max = 1.0
        result = np.clip(result, un_min, un_max)
        result = (result - un_min) / (un_max - un_min)

        result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
        result = (result * 255.0).round()
        scale_factor = int(result.shape[1] / input_size)
        return result.astype(np.uint8), scale_factor


    def Release(self):
        """Drop the session and io-binding so their resources can be freed."""
        del self.model_codeformer
        self.model_codeformer = None
        del self.io_binding
        self.io_binding = None
72
+
roop/processors/Enhance_DMDNet.py ADDED
@@ -0,0 +1,893 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, List, Callable
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import torch.nn.utils.spectral_norm as SpectralNorm
8
+ import threading
9
+ from torchvision.ops import roi_align
10
+
11
+ from math import sqrt
12
+
13
+ from torchvision.transforms.functional import normalize
14
+
15
+ from roop.typing import Face, Frame, FaceSet
16
+
17
+
18
+ THREAD_LOCK_DMDNET = threading.Lock()
19
+
20
+
21
class Enhance_DMDNet():
    """Face enhancement processor using DMDNet.

    Enhances a 512x512 aligned face crop.  When the source face set contains
    more than one reference face, a person-specific dictionary is built from
    the references; otherwise the generic dictionary path is used.
    """

    model_dmdnet = None
    torchdevice = None

    processorname = 'dmdnet'
    type = 'enhance'


    def Initialize(self, devicename):
        # Lazy one-time model construction.
        if self.model_dmdnet is None:
            self.model_dmdnet = self.create(devicename)


    # temp_frame already cropped+aligned, bbox not
    def Run(self, source_faceset: FaceSet, target_face: Face, temp_frame: Frame) -> Frame:
        """Enhance temp_frame; returns (uint8 frame, integer upscale factor)."""
        input_size = temp_frame.shape[1]

        result = self.enhance_face(source_faceset, temp_frame, target_face)
        scale_factor = int(result.shape[1] / input_size)
        return result.astype(np.uint8), scale_factor


    def Release(self):
        # BUGFIX: this previously cleared `self.model_gfpgan` (a copy/paste
        # leftover from another processor), so the DMDNet model was never
        # actually released.
        self.model_dmdnet = None


    # https://stackoverflow.com/a/67174339
    def landmarks106_to_68(self, pt106):
        """Reorder insightface's 106-point landmarks into the 68-point layout."""
        map106to68=[1,10,12,14,16,3,5,7,0,23,21,19,32,30,28,26,17,
                43,48,49,51,50,
                102,103,104,105,101,
                72,73,74,86,78,79,80,85,84,
                35,41,42,39,37,36,
                89,95,96,93,91,90,
                52,64,63,71,67,68,61,58,59,53,56,55,65,66,62,70,69,57,60,54
                ]

        pt68 = []
        for i in range(68):
            index = map106to68[i]
            pt68.append(pt106[index])
        return pt68




    def check_bbox(self, imgs, boxes):
        """Debug helper: draw the four component boxes and dump PNGs to disk."""
        boxes = boxes.view(-1, 4, 4)
        colors = [(0, 255, 0), (0, 255, 0), (255, 255, 0), (255, 0, 0)]
        i = 0
        for img, box in zip(imgs, boxes):
            img = (img + 1)/2 * 255
            img2 = img.permute(1, 2, 0).float().cpu().flip(2).numpy().copy()
            for idx, point in enumerate(box):
                cv2.rectangle(img2, (int(point[0]), int(point[1])), (int(point[2]), int(point[3])), color=colors[idx], thickness=2)
            cv2.imwrite('dmdnet_{:02d}.png'.format(i), img2)
            i += 1


    def trans_points2d(self, pts, M):
        """Apply a 2x3 affine matrix M to an array of 2-D points."""
        new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
        for i in range(pts.shape[0]):
            pt = pts[i]
            new_pt = np.array([pt[0], pt[1], 1.0], dtype=np.float32)
            new_pt = np.dot(M, new_pt)
            new_pts[i] = new_pt[0:2]

        return new_pts


    def enhance_face(self, ref_faceset: FaceSet, temp_frame, face: Face):
        """Run DMDNet on temp_frame, using ref_faceset for specific restoration.

        Falls back to returning the (possibly resized) input frame if the
        model raises, e.g. on degenerate component locations.
        """
        # preprocess
        start_x, start_y, end_x, end_y = map(int, face['bbox'])
        lm106 = face.landmark_2d_106
        lq_landmarks = np.asarray(self.landmarks106_to_68(lm106))

        if temp_frame.shape[0] != 512 or temp_frame.shape[1] != 512:
            # scale to 512x512
            scale_factor = 512 / temp_frame.shape[1]

            M = face.matrix * scale_factor

            lq_landmarks = self.trans_points2d(lq_landmarks, M)
            temp_frame = cv2.resize(temp_frame, (512,512), interpolation = cv2.INTER_AREA)

        if temp_frame.ndim == 2:
            temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_GRAY2RGB) # GGG
        # else:
        #     temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB) # RGB

        lq = read_img_tensor(temp_frame)

        LQLocs = get_component_location(lq_landmarks)
        # self.check_bbox(lq, LQLocs.unsqueeze(0))

        # specific, change 1000 to 1 to activate
        if len(ref_faceset.faces) > 1:
            SpecificImgs = []
            SpecificLocs = []
            for i,face in enumerate(ref_faceset.faces):
                lm106 = face.landmark_2d_106
                lq_landmarks = np.asarray(self.landmarks106_to_68(lm106))
                ref_image = ref_faceset.ref_images[i]
                if ref_image.shape[0] != 512 or ref_image.shape[1] != 512:
                    # scale to 512x512
                    scale_factor = 512 / ref_image.shape[1]

                    M = face.matrix * scale_factor

                    lq_landmarks = self.trans_points2d(lq_landmarks, M)
                    ref_image = cv2.resize(ref_image, (512,512), interpolation = cv2.INTER_AREA)

                if ref_image.ndim == 2:
                    temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_GRAY2RGB) # GGG
                # else:
                #     temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB) # RGB

                ref_tensor = read_img_tensor(ref_image)
                ref_locs = get_component_location(lq_landmarks)
                # self.check_bbox(ref_tensor, ref_locs.unsqueeze(0))

                SpecificImgs.append(ref_tensor)
                SpecificLocs.append(ref_locs.unsqueeze(0))

            SpecificImgs = torch.cat(SpecificImgs, dim=0)
            SpecificLocs = torch.cat(SpecificLocs, dim=0)
            # check_bbox(SpecificImgs, SpecificLocs)
            SpMem256, SpMem128, SpMem64 = self.model_dmdnet.generate_specific_dictionary(sp_imgs = SpecificImgs.to(self.torchdevice), sp_locs = SpecificLocs)
            SpMem256Para = {}
            SpMem128Para = {}
            SpMem64Para = {}
            for k, v in SpMem256.items():
                SpMem256Para[k] = v
            for k, v in SpMem128.items():
                SpMem128Para[k] = v
            for k, v in SpMem64.items():
                SpMem64Para[k] = v
        else:
            # generic
            SpMem256Para, SpMem128Para, SpMem64Para = None, None, None

        with torch.no_grad():
            with THREAD_LOCK_DMDNET:
                try:
                    GenericResult, SpecificResult = self.model_dmdnet(lq = lq.to(self.torchdevice), loc = LQLocs.unsqueeze(0), sp_256 = SpMem256Para, sp_128 = SpMem128Para, sp_64 = SpMem64Para)
                except Exception as e:
                    print(f'Error {e} there may be something wrong with the detected component locations.')
                    return temp_frame

        if SpecificResult is not None:
            save_specific = SpecificResult * 0.5 + 0.5
            save_specific = save_specific.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR
            save_specific = np.clip(save_specific.float().cpu().numpy(), 0, 1) * 255.0
            temp_frame = save_specific.astype("uint8")
            if False:
                save_generic = GenericResult * 0.5 + 0.5
                save_generic = save_generic.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR
                save_generic = np.clip(save_generic.float().cpu().numpy(), 0, 1) * 255.0
                check_lq = lq * 0.5 + 0.5
                check_lq = check_lq.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR
                check_lq = np.clip(check_lq.float().cpu().numpy(), 0, 1) * 255.0
                cv2.imwrite('dmdnet_comparison.png', cv2.cvtColor(np.hstack((check_lq, save_generic, save_specific)),cv2.COLOR_RGB2BGR))
        else:
            save_generic = GenericResult * 0.5 + 0.5
            save_generic = save_generic.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR
            save_generic = np.clip(save_generic.float().cpu().numpy(), 0, 1) * 255.0
            temp_frame = save_generic.astype("uint8")
        temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_RGB2BGR) # RGB
        return temp_frame



    def create(self, devicename):
        """Build the DMDNet module, load weights and put it in eval mode."""
        self.torchdevice = torch.device(devicename)
        model_dmdnet = DMDNet().to(self.torchdevice)
        weights = torch.load('./models/DMDNet.pth')
        model_dmdnet.load_state_dict(weights, strict=True)

        model_dmdnet.eval()
        # Parameter count kept for the (commented) diagnostics below.
        num_params = 0
        for param in model_dmdnet.parameters():
            num_params += param.numel()
        return model_dmdnet

    # print('{:>8s} : {}'.format('Using device', device))
    # print('{:>8s} : {:.2f}M'.format('Model params', num_params/1e6))
208
+
209
+
210
+
211
def read_img_tensor(Img=None): #rgb -1~1
    """Convert an HxWx3 array in 0..255 to a normalized 1x3xHxW float tensor.

    Channels are moved to the front, values scaled to [0, 1] and then
    normalized in place with mean/std 0.5 per channel, yielding [-1, 1].
    """
    chw = Img.transpose((2, 0, 1)) / 255.0
    tensor = torch.from_numpy(chw).float()
    normalize(tensor, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True)
    return tensor.unsqueeze(0)
217
+
218
+
219
def get_component_location(Landmarks, re_read=False):
    """Derive component bounding boxes from 68-point facial landmarks.

    Returns a 4x4 float tensor of (x0, y0, x1, y1) boxes for, in order:
    left eye, right eye, nose, mouth.  Coordinates are clamped so the boxes
    stay inside a 512x512 canvas.  If `re_read` is True, `Landmarks` is a
    path to a whitespace-separated landmark file that is loaded first.
    """
    if re_read:
        ReadLandmark = []
        with open(Landmarks,'r') as f:
            for line in f:
                tmp = [float(i) for i in line.split(' ') if i != '\n']
                ReadLandmark.append(tmp)
        ReadLandmark = np.array(ReadLandmark) #
        Landmarks = np.reshape(ReadLandmark, [-1, 2]) # 68*2
    # Index groups of the 68-point layout (B variants include the brow points).
    Map_LE_B = list(np.hstack((range(17,22), range(36,42))))
    Map_RE_B = list(np.hstack((range(22,27), range(42,48))))
    Map_LE = list(range(36,42))
    Map_RE = list(range(42,48))
    Map_NO = list(range(29,36))
    Map_MO = list(range(48,68))

    # Clamp landmarks away from the canvas border before box construction.
    Landmarks[Landmarks>504]=504
    Landmarks[Landmarks<8]=8

    #left eye
    Mean_LE = np.mean(Landmarks[Map_LE],0)
    L_LE1 = Mean_LE[1] - np.min(Landmarks[Map_LE_B,1])
    L_LE1 = L_LE1 * 1.3
    L_LE2 = L_LE1 / 1.9
    L_LE_xy = L_LE1 + L_LE2
    L_LE_lt = [L_LE_xy/2, L_LE1]
    L_LE_rb = [L_LE_xy/2, L_LE2]
    Location_LE = np.hstack((Mean_LE - L_LE_lt + 1, Mean_LE + L_LE_rb)).astype(int)

    #right eye
    Mean_RE = np.mean(Landmarks[Map_RE],0)
    L_RE1 = Mean_RE[1] - np.min(Landmarks[Map_RE_B,1])
    L_RE1 = L_RE1 * 1.3
    L_RE2 = L_RE1 / 1.9
    L_RE_xy = L_RE1 + L_RE2
    L_RE_lt = [L_RE_xy/2, L_RE1]
    L_RE_rb = [L_RE_xy/2, L_RE2]
    Location_RE = np.hstack((Mean_RE - L_RE_lt + 1, Mean_RE + L_RE_rb)).astype(int)

    #nose
    Mean_NO = np.mean(Landmarks[Map_NO],0)
    L_NO1 =( np.max([Mean_NO[0] - Landmarks[31][0], Landmarks[35][0] - Mean_NO[0]])) * 1.25
    L_NO2 = (Landmarks[33][1] - Mean_NO[1]) * 1.1
    L_NO_xy = L_NO1 * 2
    L_NO_lt = [L_NO_xy/2, L_NO_xy - L_NO2]
    L_NO_rb = [L_NO_xy/2, L_NO2]
    Location_NO = np.hstack((Mean_NO - L_NO_lt + 1, Mean_NO + L_NO_rb)).astype(int)

    #mouth
    Mean_MO = np.mean(Landmarks[Map_MO],0)
    # At least a 16px half-size so tiny mouths still produce a usable crop.
    L_MO = np.max((np.max(np.max(Landmarks[Map_MO],0) - np.min(Landmarks[Map_MO],0))/2,16)) * 1.1
    MO_O = Mean_MO - L_MO + 1
    MO_T = Mean_MO + L_MO
    MO_T[MO_T>510]=510
    Location_MO = np.hstack((MO_O, MO_T)).astype(int)
    return torch.cat([torch.FloatTensor(Location_LE).unsqueeze(0), torch.FloatTensor(Location_RE).unsqueeze(0), torch.FloatTensor(Location_NO).unsqueeze(0), torch.FloatTensor(Location_MO).unsqueeze(0)], dim=0)
275
+
276
+
277
+
278
+
279
def calc_mean_std_4D(feat, eps=1e-5):
    """Per-sample, per-channel mean and std of a 4-D (N, C, H, W) tensor.

    ``eps`` is added to the variance so the std is never zero (avoids a
    divide-by-zero in later normalization).  Both results have shape
    (N, C, 1, 1).
    """
    shape = feat.size()
    assert (len(shape) == 4)
    n, c = shape[0], shape[1]
    flat = feat.view(n, c, -1)
    std = (flat.var(dim=2) + eps).sqrt().view(n, c, 1, 1)
    mean = flat.mean(dim=2).view(n, c, 1, 1)
    return mean, std
288
+
289
def adaptive_instance_normalization_4D(content_feat, style_feat): # content_feat is ref feature, style is degradate feature
    """AdaIN: re-express content_feat with style_feat's channel statistics.

    The content tensor is whitened using its own per-channel mean/std and
    then rescaled/shifted by the style tensor's statistics.
    """
    shape = content_feat.size()
    s_mean, s_std = calc_mean_std_4D(style_feat)
    c_mean, c_std = calc_mean_std_4D(content_feat)
    whitened = (content_feat - c_mean.expand(shape)) / c_std.expand(shape)
    return whitened * s_std.expand(shape) + s_mean.expand(shape)
295
+
296
+
297
def convU(in_channels, out_channels, conv_layer, norm_layer, kernel_size=3, stride=1, dilation=1, bias=True):
    """Two spectrally-normalized convolutions with a LeakyReLU(0.2) between.

    Padding is chosen so the spatial size is preserved at stride 1.
    ``norm_layer`` is accepted for signature compatibility but not used here.
    """
    pad = ((kernel_size - 1) // 2) * dilation
    first = conv_layer(in_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=pad, bias=bias)
    second = conv_layer(out_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=pad, bias=bias)
    return nn.Sequential(SpectralNorm(first), nn.LeakyReLU(0.2), SpectralNorm(second))
303
+
304
+
305
class MSDilateBlock(nn.Module):
    """Multi-scale dilated residual block.

    Four parallel convU branches (one dilation rate each) are concatenated,
    fused by a spectrally-normalized conv, and added back onto the input.
    """
    def __init__(self, in_channels, conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d, kernel_size=3, dilation=[1,1,1,1], bias=True):
        super(MSDilateBlock, self).__init__()
        self.conv1 = convU(in_channels, in_channels, conv_layer, norm_layer, kernel_size, dilation=dilation[0], bias=bias)
        self.conv2 = convU(in_channels, in_channels, conv_layer, norm_layer, kernel_size, dilation=dilation[1], bias=bias)
        self.conv3 = convU(in_channels, in_channels, conv_layer, norm_layer, kernel_size, dilation=dilation[2], bias=bias)
        self.conv4 = convU(in_channels, in_channels, conv_layer, norm_layer, kernel_size, dilation=dilation[3], bias=bias)
        self.convi = SpectralNorm(conv_layer(in_channels*4, in_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size-1)//2, bias=bias))

    def forward(self, x):
        branches = [self.conv1(x), self.conv2(x), self.conv3(x), self.conv4(x)]
        return self.convi(torch.cat(branches, 1)) + x
321
+
322
+
323
class AdaptiveInstanceNorm(nn.Module):
    """Instance-normalize the input, then apply the style tensor's
    per-channel scale and shift (statistics via calc_mean_std_4D)."""

    def __init__(self, in_channel):
        super().__init__()
        self.norm = nn.InstanceNorm2d(in_channel)

    def forward(self, input, style):
        s_mean, s_std = calc_mean_std_4D(style)
        normed = self.norm(input)
        shape = input.size()
        return s_std.expand(shape) * normed + s_mean.expand(shape)
334
+
335
class NoiseInjection(nn.Module):
    """Add per-pixel noise scaled by a learnable per-channel weight.

    The weight is initialized to zero, so a freshly constructed module is
    an identity.  When ``noise`` is None, standard-normal noise with one
    channel is sampled on the fly.
    """

    def __init__(self, channel):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1))

    def forward(self, image, noise):
        if noise is None:
            b, _, h, w = image.shape
            noise = image.new_empty(b, 1, h, w).normal_()
        return image + self.weight * noise
344
+
345
class StyledUpBlock(nn.Module):
    """Upsampling block modulated by a style feature map.

    The input passes through conv1 (optionally upsampling), is scaled and
    shifted by features computed from ``style``, optionally gets noise
    injected, and is finally upsampled by convup.
    """
    def __init__(self, in_channel, out_channel, kernel_size=3, padding=1,upsample=False, noise_inject=False):
        super().__init__()

        self.noise_inject = noise_inject
        if upsample:
            self.conv1 = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                SpectralNorm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)),
                nn.LeakyReLU(0.2),
            )
        else:
            self.conv1 = nn.Sequential(
                SpectralNorm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)),
                nn.LeakyReLU(0.2),
                SpectralNorm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
            )
        # Final 2x upsampling stage applied after style modulation.
        self.convup = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            SpectralNorm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
        )
        if self.noise_inject:
            self.noise1 = NoiseInjection(out_channel)

        self.lrelu1 = nn.LeakyReLU(0.2)

        # Style-conditioned multiplicative (Scale) and additive (Shift) maps.
        self.ScaleModel1 = nn.Sequential(
            SpectralNorm(nn.Conv2d(in_channel,out_channel,3, 1, 1)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1))
        )
        self.ShiftModel1 = nn.Sequential(
            SpectralNorm(nn.Conv2d(in_channel,out_channel,3, 1, 1)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)),
        )

    def forward(self, input, style):
        out = self.conv1(input)
        out = self.lrelu1(out)
        Shift1 = self.ShiftModel1(style)
        Scale1 = self.ScaleModel1(style)
        out = out * Scale1 + Shift1
        if self.noise_inject:
            out = self.noise1(out, noise=None)
        outup = self.convup(out)
        return outup
394
+
395
+
396
####################################################################
###############Face Dictionary Generator
####################################################################
def AttentionBlock(in_channel):
    """Channel-preserving head: SN-conv 3x3, LeakyReLU(0.2), SN-conv 3x3."""
    layers = [
        SpectralNorm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)),
        nn.LeakyReLU(0.2),
        SpectralNorm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)),
    ]
    return nn.Sequential(*layers)
405
+
406
class DilateResBlock(nn.Module):
    """Residual block of two dilated, spectrally-normalized 3x3 convolutions.

    Padding matches the dilation so spatial size is preserved and the
    residual addition is well-formed.
    """
    def __init__(self, dim, dilation=[5,3] ):
        super(DilateResBlock, self).__init__()
        self.Res = nn.Sequential(
            SpectralNorm(nn.Conv2d(dim, dim, 3, 1, ((3 - 1) // 2) * dilation[0], dilation[0])),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(dim, dim, 3, 1, ((3 - 1) // 2) * dilation[1], dilation[1])),
        )

    def forward(self, x):
        return x + self.Res(x)
417
+
418
+
419
class KeyValue(nn.Module):
    """Project a feature map into a (key, value) pair for memory lookup."""

    def __init__(self, indim, keydim, valdim):
        super(KeyValue, self).__init__()
        self.Key = self._head(indim, keydim)
        self.Value = self._head(indim, valdim)

    @staticmethod
    def _head(indim, outdim):
        # SN-conv -> LeakyReLU -> SN-conv, 3x3 kernels, size-preserving.
        return nn.Sequential(
            SpectralNorm(nn.Conv2d(indim, outdim, kernel_size=(3,3), padding=(1,1), stride=1)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(outdim, outdim, kernel_size=(3,3), padding=(1,1), stride=1)),
        )

    def forward(self, x):
        return self.Key(x), self.Value(x)
434
+
435
class MaskAttention(nn.Module):
    """Fuse three same-shaped feature maps into one.

    Each input is reduced to indim//3 channels by its own SN-conv head; the
    three results are concatenated and projected back to indim channels.
    """
    def __init__(self, indim):
        super(MaskAttention, self).__init__()
        self.conv1 = nn.Sequential(
            SpectralNorm(nn.Conv2d(indim, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(indim//3, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
        )
        self.conv2 = nn.Sequential(
            SpectralNorm(nn.Conv2d(indim, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(indim//3, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
        )
        self.conv3 = nn.Sequential(
            SpectralNorm(nn.Conv2d(indim, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(indim//3, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
        )
        # Fusion head: 3 * (indim//3) channels back up to indim.
        self.convCat = nn.Sequential(
            SpectralNorm(nn.Conv2d(indim//3 * 3, indim, kernel_size=(3,3), padding=(1,1), stride=1)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(indim, indim, kernel_size=(3,3), padding=(1,1), stride=1)),
        )
    def forward(self, x, y, z):
        c1 = self.conv1(x)
        c2 = self.conv2(y)
        c3 = self.conv3(z)
        return self.convCat(torch.cat([c1,c2,c3], dim=1))
463
+
464
class Query(nn.Module):
    """Project a feature map into the query space used for memory reads."""

    def __init__(self, indim, quedim):
        super(Query, self).__init__()
        proj_in = SpectralNorm(nn.Conv2d(indim, quedim, kernel_size=(3,3), padding=(1,1), stride=1))
        proj_out = SpectralNorm(nn.Conv2d(quedim, quedim, kernel_size=(3,3), padding=(1,1), stride=1))
        self.Query = nn.Sequential(proj_in, nn.LeakyReLU(0.2), proj_out)

    def forward(self, x):
        return self.Query(x)
474
+
475
def roi_align_self(input, location, target_size):
    """Crop each sample's (x0, y0, x1, y1) box and resize it to a square.

    ``target_size`` must expose ``.item()`` (tensor or numpy scalar); the
    per-sample crops are bilinearly resized to (target_size, target_size)
    and concatenated along the batch dimension.
    """
    out_hw = (target_size.item(), target_size.item())
    crops = []
    for i in range(input.size(0)):
        patch = input[i:i+1, :, location[i, 1]:location[i, 3], location[i, 0]:location[i, 2]]
        crops.append(F.interpolate(patch, out_hw, mode='bilinear', align_corners=False))
    return torch.cat(crops, 0)
478
+
479
class FeatureExtractor(nn.Module):
    """Three-stage pyramid encoder with per-component ROI queries.

    Produces feature maps at three resolutions (named 256/128/64), ROI crops
    for left eye, right eye and mouth at each level, and query projections of
    those crops for dictionary lookup.
    """
    def __init__(self, ngf = 64, key_scale = 4):#
        super().__init__()

        # NOTE(review): the key_scale parameter is ignored here -- the value
        # is hard-coded to 4.  Callers in this file always pass 4, so
        # behavior is unaffected, but the parameter is misleading.
        self.key_scale = 4
        self.part_sizes = np.array([80,80,50,110]) #
        self.feature_sizes = np.array([256,128,64]) #

        # Stage 1: stride-2 conv + residual blocks (finest level).
        self.conv1 = nn.Sequential(
            SpectralNorm(nn.Conv2d(3, ngf, 3, 2, 1)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)),
        )
        self.conv2 = nn.Sequential(
            SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1))
        )
        self.res1 = DilateResBlock(ngf, [5,3])
        self.res2 = DilateResBlock(ngf, [5,3])


        # Stage 2: downsample to 2*ngf channels.
        self.conv3 = nn.Sequential(
            SpectralNorm(nn.Conv2d(ngf, ngf*2, 3, 2, 1)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(ngf*2, ngf*2, 3, 1, 1)),
        )
        self.conv4 = nn.Sequential(
            SpectralNorm(nn.Conv2d(ngf*2, ngf*2, 3, 1, 1)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(ngf*2, ngf*2, 3, 1, 1))
        )
        self.res3 = DilateResBlock(ngf*2, [3,1])
        self.res4 = DilateResBlock(ngf*2, [3,1])

        # Stage 3: downsample to 4*ngf channels (coarsest level).
        self.conv5 = nn.Sequential(
            SpectralNorm(nn.Conv2d(ngf*2, ngf*4, 3, 2, 1)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(ngf*4, ngf*4, 3, 1, 1)),
        )
        self.conv6 = nn.Sequential(
            SpectralNorm(nn.Conv2d(ngf*4, ngf*4, 3, 1, 1)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(ngf*4, ngf*4, 3, 1, 1))
        )
        self.res5 = DilateResBlock(ngf*4, [1,1])
        self.res6 = DilateResBlock(ngf*4, [1,1])

        # Query heads per component (LE/RE/MO) per level.
        self.LE_256_Q = Query(ngf, ngf // self.key_scale)
        self.RE_256_Q = Query(ngf, ngf // self.key_scale)
        self.MO_256_Q = Query(ngf, ngf // self.key_scale)
        self.LE_128_Q = Query(ngf * 2, ngf * 2 // self.key_scale)
        self.RE_128_Q = Query(ngf * 2, ngf * 2 // self.key_scale)
        self.MO_128_Q = Query(ngf * 2, ngf * 2 // self.key_scale)
        self.LE_64_Q = Query(ngf * 4, ngf * 4 // self.key_scale)
        self.RE_64_Q = Query(ngf * 4, ngf * 4 // self.key_scale)
        self.MO_64_Q = Query(ngf * 4, ngf * 4 // self.key_scale)


    def forward(self, img, locs):
        # Component boxes in 512-space; scaled per level below (//2, //4, //8).
        le_location = locs[:,0,:].int().cpu().numpy()
        re_location = locs[:,1,:].int().cpu().numpy()
        no_location = locs[:,2,:].int().cpu().numpy()
        mo_location = locs[:,3,:].int().cpu().numpy()


        f1_0 = self.conv1(img)
        f1_1 = self.res1(f1_0)
        f2_0 = self.conv2(f1_1)
        f2_1 = self.res2(f2_0)

        f3_0 = self.conv3(f2_1)
        f3_1 = self.res3(f3_0)
        f4_0 = self.conv4(f3_1)
        f4_1 = self.res4(f4_0)

        f5_0 = self.conv5(f4_1)
        f5_1 = self.res5(f5_0)
        f6_0 = self.conv6(f5_1)
        f6_1 = self.res6(f6_0)


        ####ROI Align
        le_part_256 = roi_align_self(f2_1.clone(), le_location//2, self.part_sizes[0]//2)
        re_part_256 = roi_align_self(f2_1.clone(), re_location//2, self.part_sizes[1]//2)
        mo_part_256 = roi_align_self(f2_1.clone(), mo_location//2, self.part_sizes[3]//2)

        le_part_128 = roi_align_self(f4_1.clone(), le_location//4, self.part_sizes[0]//4)
        re_part_128 = roi_align_self(f4_1.clone(), re_location//4, self.part_sizes[1]//4)
        mo_part_128 = roi_align_self(f4_1.clone(), mo_location//4, self.part_sizes[3]//4)

        le_part_64 = roi_align_self(f6_1.clone(), le_location//8, self.part_sizes[0]//8)
        re_part_64 = roi_align_self(f6_1.clone(), re_location//8, self.part_sizes[1]//8)
        mo_part_64 = roi_align_self(f6_1.clone(), mo_location//8, self.part_sizes[3]//8)


        le_256_q = self.LE_256_Q(le_part_256)
        re_256_q = self.RE_256_Q(re_part_256)
        mo_256_q = self.MO_256_Q(mo_part_256)

        le_128_q = self.LE_128_Q(le_part_128)
        re_128_q = self.RE_128_Q(re_part_128)
        mo_128_q = self.MO_128_Q(mo_part_128)

        le_64_q = self.LE_64_Q(le_part_64)
        re_64_q = self.RE_64_Q(re_part_64)
        mo_64_q = self.MO_64_Q(mo_part_64)

        return {'f256': f2_1, 'f128': f4_1, 'f64': f6_1,\
            'le256': le_part_256, 're256': re_part_256, 'mo256': mo_part_256, \
            'le128': le_part_128, 're128': re_part_128, 'mo128': mo_part_128, \
            'le64': le_part_64, 're64': re_part_64, 'mo64': mo_part_64, \
            'le_256_q': le_256_q, 're_256_q': re_256_q, 'mo_256_q': mo_256_q,\
            'le_128_q': le_128_q, 're_128_q': re_128_q, 'mo_128_q': mo_128_q,\
            'le_64_q': le_64_q, 're_64_q': re_64_q, 'mo_64_q': mo_64_q}
594
+
595
+
596
class DMDNet(nn.Module):
    """Dictionary-guided face restoration network.

    Restores a 512px face crop by reading part features (LE/RE/MO -- presumably
    left eye, right eye and mouth; TODO confirm against the upstream DMDNet
    reference) from a learned "generic" memory dictionary and, optionally, a
    "specific" dictionary built from high-quality images of the same person.
    """

    def __init__(self, ngf = 64, banks_num = 128):
        super().__init__()
        self.part_sizes = np.array([80,80,50,110]) # size for 512
        self.feature_sizes = np.array([256,128,64]) # size for 512

        self.banks_num = banks_num
        self.key_scale = 4

        # Separate feature extractors for low-quality inputs and high-quality references.
        self.E_lq = FeatureExtractor(key_scale = self.key_scale)
        self.E_hq = FeatureExtractor(key_scale = self.key_scale)

        # Key/value projections for each part at each feature scale (256/128/64).
        self.LE_256_KV = KeyValue(ngf, ngf // self.key_scale, ngf)
        self.RE_256_KV = KeyValue(ngf, ngf // self.key_scale, ngf)
        self.MO_256_KV = KeyValue(ngf, ngf // self.key_scale, ngf)

        self.LE_128_KV = KeyValue(ngf * 2 , ngf * 2 // self.key_scale, ngf * 2)
        self.RE_128_KV = KeyValue(ngf * 2 , ngf * 2 // self.key_scale, ngf * 2)
        self.MO_128_KV = KeyValue(ngf * 2 , ngf * 2 // self.key_scale, ngf * 2)

        self.LE_64_KV = KeyValue(ngf * 4 , ngf * 4 // self.key_scale, ngf * 4)
        self.RE_64_KV = KeyValue(ngf * 4 , ngf * 4 // self.key_scale, ngf * 4)
        self.MO_64_KV = KeyValue(ngf * 4 , ngf * 4 // self.key_scale, ngf * 4)

        # Attention blocks used when fusing restored part features back in.
        self.LE_256_Attention = AttentionBlock(64)
        self.RE_256_Attention = AttentionBlock(64)
        self.MO_256_Attention = AttentionBlock(64)

        self.LE_128_Attention = AttentionBlock(128)
        self.RE_128_Attention = AttentionBlock(128)
        self.MO_128_Attention = AttentionBlock(128)

        self.LE_64_Attention = AttentionBlock(256)
        self.RE_64_Attention = AttentionBlock(256)
        self.MO_64_Attention = AttentionBlock(256)

        # Mask blocks that blend specific-dictionary reads with generic reads.
        self.LE_256_Mask = MaskAttention(64)
        self.RE_256_Mask = MaskAttention(64)
        self.MO_256_Mask = MaskAttention(64)

        self.LE_128_Mask = MaskAttention(128)
        self.RE_128_Mask = MaskAttention(128)
        self.MO_128_Mask = MaskAttention(128)

        self.LE_64_Mask = MaskAttention(256)
        self.RE_64_Mask = MaskAttention(256)
        self.MO_64_Mask = MaskAttention(256)

        self.MSDilate = MSDilateBlock(ngf*4, dilation = [4,3,2,1])

        # Decoder: three styled upsampling stages plus a final conv/Tanh head.
        self.up1 = StyledUpBlock(ngf*4, ngf*2, noise_inject=False)
        self.up2 = StyledUpBlock(ngf*2, ngf, noise_inject=False)
        self.up3 = StyledUpBlock(ngf, ngf, noise_inject=False)
        self.up4 = nn.Sequential(
            SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)),
            nn.LeakyReLU(0.2),
            UpResBlock(ngf),
            UpResBlock(ngf),
            SpectralNorm(nn.Conv2d(ngf, 3, kernel_size=3, stride=1, padding=1)),
            nn.Tanh()
        )

        # define generic memory, revise register_buffer to register_parameter for backward update
        self.register_buffer('le_256_mem_key', torch.randn(128,16,40,40))
        self.register_buffer('re_256_mem_key', torch.randn(128,16,40,40))
        self.register_buffer('mo_256_mem_key', torch.randn(128,16,55,55))
        self.register_buffer('le_256_mem_value', torch.randn(128,64,40,40))
        self.register_buffer('re_256_mem_value', torch.randn(128,64,40,40))
        self.register_buffer('mo_256_mem_value', torch.randn(128,64,55,55))

        self.register_buffer('le_128_mem_key', torch.randn(128,32,20,20))
        self.register_buffer('re_128_mem_key', torch.randn(128,32,20,20))
        self.register_buffer('mo_128_mem_key', torch.randn(128,32,27,27))
        self.register_buffer('le_128_mem_value', torch.randn(128,128,20,20))
        self.register_buffer('re_128_mem_value', torch.randn(128,128,20,20))
        self.register_buffer('mo_128_mem_value', torch.randn(128,128,27,27))

        self.register_buffer('le_64_mem_key', torch.randn(128,64,10,10))
        self.register_buffer('re_64_mem_key', torch.randn(128,64,10,10))
        self.register_buffer('mo_64_mem_key', torch.randn(128,64,13,13))
        self.register_buffer('le_64_mem_value', torch.randn(128,256,10,10))
        self.register_buffer('re_64_mem_value', torch.randn(128,256,10,10))
        self.register_buffer('mo_64_mem_value', torch.randn(128,256,13,13))


    def readMem(self, k, v, q):
        """Soft-read a memory bank.

        Correlates query `q` with the bank keys `k` (via conv2d), softmaxes the
        scaled similarities over the bank dimension and returns the
        score-weighted blend of the bank values `v`, plus the argmax bank index.
        """
        sim = F.conv2d(q, k)
        score = F.softmax(sim/sqrt(sim.size(1)), dim=1) #B * S * 1 * 1 6*128
        sb,sn,sw,sh = score.size()
        s_m = score.view(sb, -1).unsqueeze(1)#2*1*M
        vb,vn,vw,vh = v.size()
        v_in = v.view(vb, -1).repeat(sb,1,1)#2*M*(c*w*h)
        mem_out = torch.bmm(s_m, v_in).squeeze(1).view(sb, vn, vw,vh)
        max_inds = torch.argmax(score, dim=1).squeeze()
        return mem_out, max_inds


    def memorize(self, img, locs):
        """Build the "specific" key/value dictionaries from a high-quality image.

        Returns one dict per feature scale (256/128/64) containing the
        projected keys and values for each face part.
        """
        fs = self.E_hq(img, locs)
        LE256_key, LE256_value = self.LE_256_KV(fs['le256'])
        RE256_key, RE256_value = self.RE_256_KV(fs['re256'])
        MO256_key, MO256_value = self.MO_256_KV(fs['mo256'])

        LE128_key, LE128_value = self.LE_128_KV(fs['le128'])
        RE128_key, RE128_value = self.RE_128_KV(fs['re128'])
        MO128_key, MO128_value = self.MO_128_KV(fs['mo128'])

        LE64_key, LE64_value = self.LE_64_KV(fs['le64'])
        RE64_key, RE64_value = self.RE_64_KV(fs['re64'])
        MO64_key, MO64_value = self.MO_64_KV(fs['mo64'])

        Mem256 = {'LE256Key': LE256_key, 'LE256Value': LE256_value, 'RE256Key': RE256_key, 'RE256Value': RE256_value,'MO256Key': MO256_key, 'MO256Value': MO256_value}
        Mem128 = {'LE128Key': LE128_key, 'LE128Value': LE128_value, 'RE128Key': RE128_key, 'RE128Value': RE128_value,'MO128Key': MO128_key, 'MO128Value': MO128_value}
        Mem64 = {'LE64Key': LE64_key, 'LE64Value': LE64_value, 'RE64Key': RE64_key, 'RE64Value': RE64_value,'MO64Key': MO64_key, 'MO64Value': MO64_value}

        # Note: the original also assembled FS256/FS128/FS64 dicts of the raw
        # part features but never used or returned them; removed as dead code.
        return Mem256, Mem128, Mem64

    def enhancer(self, fs_in, sp_256=None, sp_128=None, sp_64=None):
        """Read restored part features from memory for every part and scale.

        When the specific dictionaries (sp_*) are given, their reads are
        blended with the generic-memory reads via the learned mask blocks;
        otherwise only the generic memory is used.  The reads are then
        AdaIN-normalised to the statistics of the input part features.
        """
        le_256_q = fs_in['le_256_q']
        re_256_q = fs_in['re_256_q']
        mo_256_q = fs_in['mo_256_q']

        le_128_q = fs_in['le_128_q']
        re_128_q = fs_in['re_128_q']
        mo_128_q = fs_in['mo_128_q']

        le_64_q = fs_in['le_64_q']
        re_64_q = fs_in['re_64_q']
        mo_64_q = fs_in['mo_64_q']

        ####for 256
        le_256_mem_g, le_256_inds = self.readMem(self.le_256_mem_key, self.le_256_mem_value, le_256_q)
        re_256_mem_g, re_256_inds = self.readMem(self.re_256_mem_key, self.re_256_mem_value, re_256_q)
        mo_256_mem_g, mo_256_inds = self.readMem(self.mo_256_mem_key, self.mo_256_mem_value, mo_256_q)

        le_128_mem_g, le_128_inds = self.readMem(self.le_128_mem_key, self.le_128_mem_value, le_128_q)
        re_128_mem_g, re_128_inds = self.readMem(self.re_128_mem_key, self.re_128_mem_value, re_128_q)
        mo_128_mem_g, mo_128_inds = self.readMem(self.mo_128_mem_key, self.mo_128_mem_value, mo_128_q)

        le_64_mem_g, le_64_inds = self.readMem(self.le_64_mem_key, self.le_64_mem_value, le_64_q)
        re_64_mem_g, re_64_inds = self.readMem(self.re_64_mem_key, self.re_64_mem_value, re_64_q)
        mo_64_mem_g, mo_64_inds = self.readMem(self.mo_64_mem_key, self.mo_64_mem_value, mo_64_q)

        if sp_256 is not None and sp_128 is not None and sp_64 is not None:
            le_256_mem_s, _ = self.readMem(sp_256['LE256Key'], sp_256['LE256Value'], le_256_q)
            re_256_mem_s, _ = self.readMem(sp_256['RE256Key'], sp_256['RE256Value'], re_256_q)
            mo_256_mem_s, _ = self.readMem(sp_256['MO256Key'], sp_256['MO256Value'], mo_256_q)
            le_256_mask = self.LE_256_Mask(fs_in['le256'],le_256_mem_s,le_256_mem_g)
            le_256_mem = le_256_mask*le_256_mem_s + (1-le_256_mask)*le_256_mem_g
            re_256_mask = self.RE_256_Mask(fs_in['re256'],re_256_mem_s,re_256_mem_g)
            re_256_mem = re_256_mask*re_256_mem_s + (1-re_256_mask)*re_256_mem_g
            mo_256_mask = self.MO_256_Mask(fs_in['mo256'],mo_256_mem_s,mo_256_mem_g)
            mo_256_mem = mo_256_mask*mo_256_mem_s + (1-mo_256_mask)*mo_256_mem_g

            le_128_mem_s, _ = self.readMem(sp_128['LE128Key'], sp_128['LE128Value'], le_128_q)
            re_128_mem_s, _ = self.readMem(sp_128['RE128Key'], sp_128['RE128Value'], re_128_q)
            mo_128_mem_s, _ = self.readMem(sp_128['MO128Key'], sp_128['MO128Value'], mo_128_q)
            le_128_mask = self.LE_128_Mask(fs_in['le128'],le_128_mem_s,le_128_mem_g)
            le_128_mem = le_128_mask*le_128_mem_s + (1-le_128_mask)*le_128_mem_g
            re_128_mask = self.RE_128_Mask(fs_in['re128'],re_128_mem_s,re_128_mem_g)
            re_128_mem = re_128_mask*re_128_mem_s + (1-re_128_mask)*re_128_mem_g
            mo_128_mask = self.MO_128_Mask(fs_in['mo128'],mo_128_mem_s,mo_128_mem_g)
            mo_128_mem = mo_128_mask*mo_128_mem_s + (1-mo_128_mask)*mo_128_mem_g

            le_64_mem_s, _ = self.readMem(sp_64['LE64Key'], sp_64['LE64Value'], le_64_q)
            re_64_mem_s, _ = self.readMem(sp_64['RE64Key'], sp_64['RE64Value'], re_64_q)
            mo_64_mem_s, _ = self.readMem(sp_64['MO64Key'], sp_64['MO64Value'], mo_64_q)
            le_64_mask = self.LE_64_Mask(fs_in['le64'],le_64_mem_s,le_64_mem_g)
            le_64_mem = le_64_mask*le_64_mem_s + (1-le_64_mask)*le_64_mem_g
            re_64_mask = self.RE_64_Mask(fs_in['re64'],re_64_mem_s,re_64_mem_g)
            re_64_mem = re_64_mask*re_64_mem_s + (1-re_64_mask)*re_64_mem_g
            mo_64_mask = self.MO_64_Mask(fs_in['mo64'],mo_64_mem_s,mo_64_mem_g)
            mo_64_mem = mo_64_mask*mo_64_mem_s + (1-mo_64_mask)*mo_64_mem_g
        else:
            le_256_mem = le_256_mem_g
            re_256_mem = re_256_mem_g
            mo_256_mem = mo_256_mem_g
            le_128_mem = le_128_mem_g
            re_128_mem = re_128_mem_g
            mo_128_mem = mo_128_mem_g
            le_64_mem = le_64_mem_g
            re_64_mem = re_64_mem_g
            mo_64_mem = mo_64_mem_g

        le_256_mem_norm = adaptive_instance_normalization_4D(le_256_mem, fs_in['le256'])
        re_256_mem_norm = adaptive_instance_normalization_4D(re_256_mem, fs_in['re256'])
        mo_256_mem_norm = adaptive_instance_normalization_4D(mo_256_mem, fs_in['mo256'])

        ####for 128
        le_128_mem_norm = adaptive_instance_normalization_4D(le_128_mem, fs_in['le128'])
        re_128_mem_norm = adaptive_instance_normalization_4D(re_128_mem, fs_in['re128'])
        mo_128_mem_norm = adaptive_instance_normalization_4D(mo_128_mem, fs_in['mo128'])

        ####for 64
        le_64_mem_norm = adaptive_instance_normalization_4D(le_64_mem, fs_in['le64'])
        re_64_mem_norm = adaptive_instance_normalization_4D(re_64_mem, fs_in['re64'])
        mo_64_mem_norm = adaptive_instance_normalization_4D(mo_64_mem, fs_in['mo64'])

        EnMem256 = {'LE256Norm': le_256_mem_norm, 'RE256Norm': re_256_mem_norm, 'MO256Norm': mo_256_mem_norm}
        EnMem128 = {'LE128Norm': le_128_mem_norm, 'RE128Norm': re_128_mem_norm, 'MO128Norm': mo_128_mem_norm}
        EnMem64 = {'LE64Norm': le_64_mem_norm, 'RE64Norm': re_64_mem_norm, 'MO64Norm': mo_64_mem_norm}
        Ind256 = {'LE': le_256_inds, 'RE': re_256_inds, 'MO': mo_256_inds}
        Ind128 = {'LE': le_128_inds, 'RE': re_128_inds, 'MO': mo_128_inds}
        Ind64 = {'LE': le_64_inds, 'RE': re_64_inds, 'MO': mo_64_inds}
        return EnMem256, EnMem128, EnMem64, Ind256, Ind128, Ind64

    def reconstruct(self, fs_in, locs, memstar):
        """Paste the restored part features back into the full feature maps and decode.

        `memstar` holds the AdaIN-normalised memory reads per scale (as produced
        by enhancer()).  Each part patch is resized to its location box at each
        scale and written into a copy of the full feature map, which is then
        decoded through MSDilate and the up1..up4 stages into the output image.
        """
        le_256_mem_norm, re_256_mem_norm, mo_256_mem_norm = memstar[0]['LE256Norm'], memstar[0]['RE256Norm'], memstar[0]['MO256Norm']
        le_128_mem_norm, re_128_mem_norm, mo_128_mem_norm = memstar[1]['LE128Norm'], memstar[1]['RE128Norm'], memstar[1]['MO128Norm']
        le_64_mem_norm, re_64_mem_norm, mo_64_mem_norm = memstar[2]['LE64Norm'], memstar[2]['RE64Norm'], memstar[2]['MO64Norm']

        # Residual-attention fusion of the memory read with the input features.
        le_256_final = self.LE_256_Attention(le_256_mem_norm - fs_in['le256']) * le_256_mem_norm + fs_in['le256']
        re_256_final = self.RE_256_Attention(re_256_mem_norm - fs_in['re256']) * re_256_mem_norm + fs_in['re256']
        mo_256_final = self.MO_256_Attention(mo_256_mem_norm - fs_in['mo256']) * mo_256_mem_norm + fs_in['mo256']

        le_128_final = self.LE_128_Attention(le_128_mem_norm - fs_in['le128']) * le_128_mem_norm + fs_in['le128']
        re_128_final = self.RE_128_Attention(re_128_mem_norm - fs_in['re128']) * re_128_mem_norm + fs_in['re128']
        mo_128_final = self.MO_128_Attention(mo_128_mem_norm - fs_in['mo128']) * mo_128_mem_norm + fs_in['mo128']

        le_64_final = self.LE_64_Attention(le_64_mem_norm - fs_in['le64']) * le_64_mem_norm + fs_in['le64']
        re_64_final = self.RE_64_Attention(re_64_mem_norm - fs_in['re64']) * re_64_mem_norm + fs_in['re64']
        mo_64_final = self.MO_64_Attention(mo_64_mem_norm - fs_in['mo64']) * mo_64_mem_norm + fs_in['mo64']

        # locs rows: 0=LE, 1=RE, 3=MO (row 2 is unused here -- presumably the nose).
        le_location = locs[:,0,:]
        re_location = locs[:,1,:]
        mo_location = locs[:,3,:]

        # Somehow with latest Torch it doesn't like numpy wrappers anymore
        le_location = le_location.cpu().int()
        re_location = re_location.cpu().int()
        mo_location = mo_location.cpu().int()

        up_in_256 = fs_in['f256'].clone()
        up_in_128 = fs_in['f128'].clone()
        up_in_64 = fs_in['f64'].clone()

        for i in range(fs_in['f256'].size(0)):
            up_in_256[i:i+1,:,le_location[i,1]//2:le_location[i,3]//2,le_location[i,0]//2:le_location[i,2]//2] = F.interpolate(le_256_final[i:i+1,:,:,:].clone(), (le_location[i,3]//2-le_location[i,1]//2,le_location[i,2]//2-le_location[i,0]//2),mode='bilinear',align_corners=False)
            up_in_256[i:i+1,:,re_location[i,1]//2:re_location[i,3]//2,re_location[i,0]//2:re_location[i,2]//2] = F.interpolate(re_256_final[i:i+1,:,:,:].clone(), (re_location[i,3]//2-re_location[i,1]//2,re_location[i,2]//2-re_location[i,0]//2),mode='bilinear',align_corners=False)
            up_in_256[i:i+1,:,mo_location[i,1]//2:mo_location[i,3]//2,mo_location[i,0]//2:mo_location[i,2]//2] = F.interpolate(mo_256_final[i:i+1,:,:,:].clone(), (mo_location[i,3]//2-mo_location[i,1]//2,mo_location[i,2]//2-mo_location[i,0]//2),mode='bilinear',align_corners=False)

            up_in_128[i:i+1,:,le_location[i,1]//4:le_location[i,3]//4,le_location[i,0]//4:le_location[i,2]//4] = F.interpolate(le_128_final[i:i+1,:,:,:].clone(), (le_location[i,3]//4-le_location[i,1]//4,le_location[i,2]//4-le_location[i,0]//4),mode='bilinear',align_corners=False)
            up_in_128[i:i+1,:,re_location[i,1]//4:re_location[i,3]//4,re_location[i,0]//4:re_location[i,2]//4] = F.interpolate(re_128_final[i:i+1,:,:,:].clone(), (re_location[i,3]//4-re_location[i,1]//4,re_location[i,2]//4-re_location[i,0]//4),mode='bilinear',align_corners=False)
            up_in_128[i:i+1,:,mo_location[i,1]//4:mo_location[i,3]//4,mo_location[i,0]//4:mo_location[i,2]//4] = F.interpolate(mo_128_final[i:i+1,:,:,:].clone(), (mo_location[i,3]//4-mo_location[i,1]//4,mo_location[i,2]//4-mo_location[i,0]//4),mode='bilinear',align_corners=False)

            up_in_64[i:i+1,:,le_location[i,1]//8:le_location[i,3]//8,le_location[i,0]//8:le_location[i,2]//8] = F.interpolate(le_64_final[i:i+1,:,:,:].clone(), (le_location[i,3]//8-le_location[i,1]//8,le_location[i,2]//8-le_location[i,0]//8),mode='bilinear',align_corners=False)
            up_in_64[i:i+1,:,re_location[i,1]//8:re_location[i,3]//8,re_location[i,0]//8:re_location[i,2]//8] = F.interpolate(re_64_final[i:i+1,:,:,:].clone(), (re_location[i,3]//8-re_location[i,1]//8,re_location[i,2]//8-re_location[i,0]//8),mode='bilinear',align_corners=False)
            up_in_64[i:i+1,:,mo_location[i,1]//8:mo_location[i,3]//8,mo_location[i,0]//8:mo_location[i,2]//8] = F.interpolate(mo_64_final[i:i+1,:,:,:].clone(), (mo_location[i,3]//8-mo_location[i,1]//8,mo_location[i,2]//8-mo_location[i,0]//8),mode='bilinear',align_corners=False)

        ms_in_64 = self.MSDilate(fs_in['f64'].clone())
        fea_up1 = self.up1(ms_in_64, up_in_64)
        fea_up2 = self.up2(fea_up1, up_in_128)
        fea_up3 = self.up3(fea_up2, up_in_256)
        output = self.up4(fea_up3)
        return output

    def generate_specific_dictionary(self, sp_imgs=None, sp_locs=None):
        """Public wrapper over memorize() for building a person-specific dictionary."""
        return self.memorize(sp_imgs, sp_locs)

    def forward(self, lq=None, loc=None, sp_256 = None, sp_128 = None, sp_64 = None):
        """Restore `lq` using the generic memory and, if provided, the specific one.

        Returns (generic_output, specific_output); the latter is None when no
        specific dictionary was supplied.
        """
        try:
            fs_in = self.E_lq(lq, loc) # low quality images
        except Exception as e:
            # BUGFIX: previously the exception was printed and swallowed, after
            # which the undefined `fs_in` produced a confusing NameError below.
            # Log and re-raise so the caller sees the real failure.
            print(e)
            raise

        GeMemNorm256, GeMemNorm128, GeMemNorm64, Ind256, Ind128, Ind64 = self.enhancer(fs_in)
        GeOut = self.reconstruct(fs_in, loc, memstar = [GeMemNorm256, GeMemNorm128, GeMemNorm64])
        if sp_256 is not None and sp_128 is not None and sp_64 is not None:
            GSMemNorm256, GSMemNorm128, GSMemNorm64, _, _, _ = self.enhancer(fs_in, sp_256, sp_128, sp_64)
            GSOut = self.reconstruct(fs_in, loc, memstar = [GSMemNorm256, GSMemNorm128, GSMemNorm64])
        else:
            GSOut = None
        return GeOut, GSOut
882
+
883
class UpResBlock(nn.Module):
    """Residual refinement block: x + conv-lrelu-conv(x).

    The inner stack is kept under the attribute name `Model` (Sequential of
    SpectralNorm conv / LeakyReLU / SpectralNorm conv) so that state-dict keys
    of pretrained checkpoints keep matching.
    """

    def __init__(self, dim, conv_layer = nn.Conv2d, norm_layer = nn.BatchNorm2d):
        super(UpResBlock, self).__init__()
        self.Model = nn.Sequential(
            SpectralNorm(conv_layer(dim, dim, 3, 1, 1)),
            nn.LeakyReLU(0.2),
            SpectralNorm(conv_layer(dim, dim, 3, 1, 1)),
        )

    def forward(self, x):
        # Residual connection around the conv stack.
        residual = self.Model(x)
        return x + residual
roop/processors/Enhance_GFPGAN.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, List, Callable
2
+ import cv2
3
+ import numpy as np
4
+ import onnxruntime
5
+ import roop.globals
6
+
7
+ from roop.typing import Face, Frame, FaceSet
8
+ from roop.utilities import resolve_relative_path
9
+
10
+
11
+ # THREAD_LOCK = threading.Lock()
12
+
13
+
14
class Enhance_GFPGAN():
    """Face enhancement processor backed by the GFPGAN v1.4 ONNX model.

    Follows the processor protocol used by this package: Initialize/Run/Release,
    with `processorname` and `type` identifying it to the pipeline.
    """

    # Lazily created onnxruntime.InferenceSession (shared class attribute until set).
    model_gfpgan = None
    # Name of the model's first input tensor.
    name = None
    # Device string used for output binding.
    devicename = None

    processorname = 'gfpgan'
    type = 'enhance'


    def Initialize(self, devicename):
        """Load the ONNX model on first use and remember the (CPU-coerced) device name."""
        if self.model_gfpgan is None:
            model_path = resolve_relative_path('../models/GFPGANv1.4.onnx')
            self.model_gfpgan = onnxruntime.InferenceSession(model_path, None, providers=roop.globals.execution_providers)
            # replace Mac mps with cpu for the moment
            devicename = devicename.replace('mps', 'cpu')
            self.devicename = devicename

        self.name = self.model_gfpgan.get_inputs()[0].name

    def Run(self, source_faceset: FaceSet, target_face: Face, temp_frame: Frame) -> Frame:
        """Enhance a face crop.

        Returns a tuple (enhanced 512x512 BGR uint8 image, integer scale factor
        relative to the input width).  source_faceset/target_face are unused by
        this enhancer but required by the processor interface.
        """
        # preprocess: the model expects 512x512 RGB in [-1, 1], NCHW layout
        input_size = temp_frame.shape[1]
        # BUGFIX: pass the interpolation flag by keyword -- the third positional
        # argument of cv2.resize is `dst`, so INTER_CUBIC was not applied before.
        temp_frame = cv2.resize(temp_frame, (512, 512), interpolation=cv2.INTER_CUBIC)

        temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB)
        temp_frame = temp_frame.astype('float32') / 255.0
        temp_frame = (temp_frame - 0.5) / 0.5
        temp_frame = np.expand_dims(temp_frame, axis=0).transpose(0, 3, 1, 2)

        io_binding = self.model_gfpgan.io_binding()
        io_binding.bind_cpu_input("input", temp_frame)
        # "1288" is the model's output tensor name in this ONNX export.
        io_binding.bind_output("1288", self.devicename)
        self.model_gfpgan.run_with_iobinding(io_binding)
        ort_outs = io_binding.copy_outputs_to_cpu()
        result = ort_outs[0][0]

        # post-process: [-1, 1] CHW -> [0, 255] HWC BGR
        result = np.clip(result, -1, 1)
        result = (result + 1) / 2
        result = result.transpose(1, 2, 0) * 255.0
        result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
        scale_factor = int(result.shape[1] / input_size)
        return result.astype(np.uint8), scale_factor


    def Release(self):
        """Drop the model reference so the session can be garbage-collected."""
        self.model_gfpgan = None
62
+
63
+
64
+
65
+
66
+
67
+
68
+
69
+
70
+
71
+
72
+
roop/processors/Enhance_GPEN.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, List, Callable
2
+ import cv2
3
+ import numpy as np
4
+ import onnxruntime
5
+ import roop.globals
6
+
7
+ from roop.typing import Face, Frame, FaceSet
8
+ from roop.utilities import resolve_relative_path
9
+
10
+
11
class Enhance_GPEN():
    """Face enhancement processor backed by the GPEN-BFR-512 ONNX model.

    Follows the processor protocol used by this package: Initialize/Run/Release,
    with `processorname` and `type` identifying it to the pipeline.
    """

    # Lazily created onnxruntime.InferenceSession (shared class attribute until set).
    model_gpen = None
    # Name of the model's first input tensor.
    name = None
    # Device string used for output binding.
    devicename = None

    processorname = 'gpen'
    type = 'enhance'


    def Initialize(self, devicename):
        """Load the ONNX model on first use and remember the (CPU-coerced) device name."""
        if self.model_gpen is None:
            model_path = resolve_relative_path('../models/GPEN-BFR-512.onnx')
            self.model_gpen = onnxruntime.InferenceSession(model_path, None, providers=roop.globals.execution_providers)
            # replace Mac mps with cpu for the moment
            devicename = devicename.replace('mps', 'cpu')
            self.devicename = devicename

        self.name = self.model_gpen.get_inputs()[0].name

    def Run(self, source_faceset: FaceSet, target_face: Face, temp_frame: Frame) -> Frame:
        """Enhance a face crop.

        Returns a tuple (enhanced 512x512 BGR uint8 image, integer scale factor
        relative to the input width).  source_faceset/target_face are unused by
        this enhancer but required by the processor interface.
        """
        # preprocess: the model expects 512x512 RGB in [-1, 1], NCHW layout
        input_size = temp_frame.shape[1]
        # BUGFIX: pass the interpolation flag by keyword -- the third positional
        # argument of cv2.resize is `dst`, so INTER_CUBIC was not applied before.
        temp_frame = cv2.resize(temp_frame, (512, 512), interpolation=cv2.INTER_CUBIC)

        temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB)
        temp_frame = temp_frame.astype('float32') / 255.0
        temp_frame = (temp_frame - 0.5) / 0.5
        temp_frame = np.expand_dims(temp_frame, axis=0).transpose(0, 3, 1, 2)

        io_binding = self.model_gpen.io_binding()
        io_binding.bind_cpu_input("input", temp_frame)
        io_binding.bind_output("output", self.devicename)
        self.model_gpen.run_with_iobinding(io_binding)
        ort_outs = io_binding.copy_outputs_to_cpu()
        result = ort_outs[0][0]

        # post-process: [-1, 1] CHW -> [0, 255] HWC BGR
        result = np.clip(result, -1, 1)
        result = (result + 1) / 2
        result = result.transpose(1, 2, 0) * 255.0
        result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
        scale_factor = int(result.shape[1] / input_size)
        return result.astype(np.uint8), scale_factor


    def Release(self):
        """Drop the model reference so the session can be garbage-collected."""
        self.model_gpen = None
roop/processors/Enhance_RestoreFormerPPlus.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, List, Callable
2
+ import cv2
3
+ import numpy as np
4
+ import onnxruntime
5
+ import roop.globals
6
+
7
+ from roop.typing import Face, Frame, FaceSet
8
+ from roop.utilities import resolve_relative_path
9
+
10
class Enhance_RestoreFormerPPlus():
    """Face enhancement processor backed by the RestoreFormer++ ONNX model.

    Unlike the GFPGAN/GPEN processors, this one creates a single reusable
    io_binding at Initialize time and rebinds only the input per Run call.
    """

    # Lazily created onnxruntime.InferenceSession (shared class attribute until set).
    model_restoreformerpplus = None
    # Device string used for output binding.
    devicename = None
    # Unused placeholder kept for interface parity with the other enhancers.
    name = None

    processorname = 'restoreformer++'
    type = 'enhance'


    def Initialize(self, devicename:str):
        """Load the model on first use and pre-bind its output on `devicename`."""
        if self.model_restoreformerpplus is None:
            # replace Mac mps with cpu for the moment
            devicename = devicename.replace('mps', 'cpu')
            self.devicename = devicename
            model_path = resolve_relative_path('../models/restoreformer_plus_plus.onnx')
            self.model_restoreformerpplus = onnxruntime.InferenceSession(model_path, None, providers=roop.globals.execution_providers)
            self.model_inputs = self.model_restoreformerpplus.get_inputs()
            model_outputs = self.model_restoreformerpplus.get_outputs()
            self.io_binding = self.model_restoreformerpplus.io_binding()
            self.io_binding.bind_output(model_outputs[0].name, self.devicename)

    def Run(self, source_faceset: FaceSet, target_face: Face, temp_frame: Frame) -> Frame:
        """Enhance a face crop.

        Returns a tuple (enhanced 512x512 BGR uint8 image, integer scale factor
        relative to the input width).  source_faceset/target_face are unused by
        this enhancer but required by the processor interface.
        """
        # preprocess: the model expects 512x512 RGB in [-1, 1], NCHW layout
        input_size = temp_frame.shape[1]
        # BUGFIX: pass the interpolation flag by keyword -- the third positional
        # argument of cv2.resize is `dst`, so INTER_CUBIC was not applied before.
        temp_frame = cv2.resize(temp_frame, (512, 512), interpolation=cv2.INTER_CUBIC)
        temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB)
        temp_frame = temp_frame.astype('float32') / 255.0
        temp_frame = (temp_frame - 0.5) / 0.5
        temp_frame = np.expand_dims(temp_frame, axis=0).transpose(0, 3, 1, 2)

        self.io_binding.bind_cpu_input(self.model_inputs[0].name, temp_frame)
        self.model_restoreformerpplus.run_with_iobinding(self.io_binding)
        ort_outs = self.io_binding.copy_outputs_to_cpu()
        result = ort_outs[0][0]
        del ort_outs

        # post-process: [-1, 1] CHW -> [0, 255] HWC BGR
        result = np.clip(result, -1, 1)
        result = (result + 1) / 2
        result = result.transpose(1, 2, 0) * 255.0
        result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
        scale_factor = int(result.shape[1] / input_size)
        return result.astype(np.uint8), scale_factor


    def Release(self):
        """Drop the model and io_binding references so they can be garbage-collected."""
        del self.model_restoreformerpplus
        self.model_restoreformerpplus = None
        del self.io_binding
        self.io_binding = None
59
+
roop/processors/FaceSwapInsightFace.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, List, Callable
2
+ import roop.globals
3
+ import insightface
4
+ import cv2
5
+ import numpy as np
6
+
7
+ from roop.typing import Face, Frame
8
+ from roop.utilities import resolve_relative_path
9
+
10
+
11
+
12
class FaceSwapInsightFace():
    """Face swap processor backed by insightface's inswapper_128 ONNX model.

    Follows the processor protocol used by this package (Initialize/Run/Release);
    identified to the pipeline via `processorname` and `type`.
    """

    # Lazily created insightface swapper model (shared class attribute until set).
    model_swap_insightface = None


    processorname = 'faceswap'
    type = 'swap'


    def Initialize(self, devicename):
        """Load the inswapper model on first use.

        `devicename` is unused here; device selection comes from
        roop.globals.execution_providers.
        """
        if self.model_swap_insightface is None:
            model_path = resolve_relative_path('../models/inswapper_128.onnx')
            self.model_swap_insightface = insightface.model_zoo.get_model(model_path, providers=roop.globals.execution_providers)


    def Run(self, source_face: Face, target_face: Face, temp_frame: Frame) -> Frame:
        """Swap source_face onto target_face within temp_frame.

        Returns the swapped face patch (paste_back=False, so NOT composited
        into the frame) and stores the alignment matrix on target_face.matrix
        for later pasting by the caller.
        """
        img_fake, M = self.model_swap_insightface.get(temp_frame, target_face, source_face, paste_back=False)
        target_face.matrix = M
        return img_fake


    def Release(self):
        """Drop the model reference so it can be garbage-collected."""
        del self.model_swap_insightface
        self.model_swap_insightface = None
35
+
36
+
37
+
38
+
39
+
40
+
roop/processors/Mask_Clip2Seg.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ import threading
6
+ from torchvision import transforms
7
+ from clip.clipseg import CLIPDensePredT
8
+ import numpy as np
9
+
10
+ from roop.typing import Frame
11
+
12
+ THREAD_LOCK_CLIP = threading.Lock()
13
+
14
+
15
class Mask_Clip2Seg():
    """Text-prompted soft mask generator backed by CLIPSeg (CLIPDensePredT).

    Given an image and a comma-separated keyword list, produces a 256x256
    float mask in [0, 1] selecting the regions matching the prompts.
    """

    # Lazily created CLIPSeg model (shared class attribute until set).
    model_clip = None

    processorname = 'clip2seg'
    type = 'mask'


    def Initialize(self, devicename):
        """Load the CLIPSeg weights on first use and move the model to `devicename`."""
        if self.model_clip is None:
            self.model_clip = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, complex_trans_conv=True)
            self.model_clip.eval()
            # strict=False: the checkpoint only covers the decoder/refinement weights.
            self.model_clip.load_state_dict(torch.load('models/CLIP/rd64-uni-refined.pth', map_location=torch.device('cpu')), strict=False)

        device = torch.device(devicename)
        self.model_clip.to(device)


    def Run(self, img1, keywords:str) -> Frame:
        """Return a 256x256 float mask for the regions of img1 matching `keywords`.

        Returns img1 unchanged when there is nothing to do (no keywords or no
        image).  Multiple comma-separated prompts are OR-combined.
        """
        if keywords is None or len(keywords) < 1 or img1 is None:
            return img1

        source_image_small = cv2.resize(img1, (256,256))

        # Base mask: full rectangle with slightly blurred borders.
        img_mask = np.full((source_image_small.shape[0],source_image_small.shape[1]), 0, dtype=np.float32)
        mask_border = 1
        l = 0
        t = 0
        r = 1
        b = 1

        mask_blur = 5
        clip_blur = 5

        img_mask = cv2.rectangle(img_mask, (mask_border+int(l), mask_border+int(t)),
                                (256 - mask_border-int(r), 256-mask_border-int(b)), (255, 255, 255), -1)
        img_mask = cv2.GaussianBlur(img_mask, (mask_blur*2+1,mask_blur*2+1), 0)
        img_mask /= 255


        input_image = source_image_small

        # Preprocess with ImageNet normalisation as expected by CLIPSeg.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            transforms.Resize((256, 256)),
        ])
        img = transform(input_image).unsqueeze(0)

        thresh = 0.5
        prompts = keywords.split(',')
        # The model instance is shared between threads; serialize inference.
        with THREAD_LOCK_CLIP:
            with torch.no_grad():
                preds = self.model_clip(img.repeat(len(prompts),1,1,1), prompts)[0]
        # Sum the per-prompt sigmoid maps (logical OR of the prompts).
        clip_mask = torch.sigmoid(preds[0][0])
        for i in range(len(prompts)-1):
            clip_mask += torch.sigmoid(preds[i+1][0])

        clip_mask = clip_mask.data.cpu().numpy()
        # BUGFIX: np.clip returns a new array; the original call discarded the
        # result, so summed multi-prompt masks could exceed 1 before thresholding.
        clip_mask = np.clip(clip_mask, 0, 1)

        # Binarise, then dilate and blur for soft edges.
        clip_mask[clip_mask>thresh] = 1.0
        clip_mask[clip_mask<=thresh] = 0.0
        kernel = np.ones((5, 5), np.float32)
        clip_mask = cv2.dilate(clip_mask, kernel, iterations=1)
        clip_mask = cv2.GaussianBlur(clip_mask, (clip_blur*2+1,clip_blur*2+1), 0)

        img_mask *= clip_mask
        img_mask[img_mask<0.0] = 0.0
        return img_mask



    def Release(self):
        """Drop the model reference so it can be garbage-collected."""
        self.model_clip = None
90
+
roop/processors/__init__.py ADDED
File without changes
roop/processors/frame/__init__.py ADDED
File without changes
roop/processors/frame/face_swapper.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+ import insightface
3
+ import threading
4
+
5
+ import roop.globals
6
+ from roop.utilities import resolve_relative_path
7
+
8
FACE_SWAPPER = None            # process-wide lazily-loaded inswapper model
THREAD_LOCK = threading.Lock() # guards the lazy initialisation below
NAME = 'ROOP.FACE-SWAPPER'

# Not used inside this module; presumably a similarity threshold consumed by
# callers -- TODO confirm.
DIST_THRESHOLD = 0.65


def get_face_swapper() -> Any:
    """Return the shared insightface inswapper model, loading it on first use.

    Thread-safe: the lock ensures the ONNX model is loaded exactly once even
    when multiple worker threads call this concurrently.
    """
    global FACE_SWAPPER

    with THREAD_LOCK:
        if FACE_SWAPPER is None:
            model_path = resolve_relative_path('../models/inswapper_128.onnx')
            FACE_SWAPPER = insightface.model_zoo.get_model(model_path, providers=roop.globals.execution_providers)
        return FACE_SWAPPER
23
+
roop/template_parser.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from datetime import datetime
3
+
4
# Supported output-filename placeholders.  Each function receives the caller's
# data dict and returns the replacement text, or False to leave the
# placeholder untouched (e.g. "{i}" when no "index" was supplied).
template_functions = {
    "timestamp": lambda data: str(int(datetime.now().timestamp())),
    "i": lambda data: data.get("index", False),
    "file": lambda data: data.get("file", False),
    "date": lambda data: datetime.now().strftime("%Y-%m-%d"),
    "time": lambda data: datetime.now().strftime("%H-%M-%S"),
}


def parse(text: str, data: dict) -> str:
    """Expand {placeholder} templates in `text` using `template_functions`.

    Unknown placeholders and placeholders whose function returns False are
    left in the text unchanged.  Replacement values are coerced to str.
    """
    pattern = r"\{([^}]+)\}"

    matches = re.findall(pattern, text)

    for match in matches:
        # BUGFIX: unknown placeholders used to raise KeyError; skip them instead.
        func = template_functions.get(match)
        if func is None:
            continue
        replacement = func(data)
        if replacement is not False:
            # BUGFIX: str.replace requires a string; coerce non-string values
            # (e.g. an integer index) instead of raising TypeError.
            text = text.replace(f"{{{match}}}", str(replacement))

    return text
roop/typing.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
"""Shared type aliases for the roop package."""
from typing import Any

from insightface.app.common import Face
from roop.FaceSet import FaceSet
import numpy

# Re-export insightface's Face and roop's FaceSet so the rest of the code can
# import all face-related types from this single module.
Face = Face
FaceSet = FaceSet
# A raw image/video frame as a numpy array (presumably BGR as produced by
# OpenCV reads -- TODO confirm).
Frame = numpy.ndarray[Any, Any]
roop/util_ffmpeg.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import subprocess
4
+ import roop.globals
5
+ import roop.utilities as util
6
+
7
+ from typing import List, Any
8
+
9
def run_ffmpeg(args: List[str]) -> bool:
    """Run ffmpeg with the project's standard flags plus `args`.

    Returns True on success, False on failure.  On failure the full command
    line and ffmpeg's captured output are printed for diagnosis.
    """
    commands = ['ffmpeg', '-hide_banner', '-hwaccel', 'auto', '-y', '-loglevel', roop.globals.log_level]
    commands.extend(args)
    print("Running ffmpeg")
    try:
        subprocess.check_output(commands, stderr=subprocess.STDOUT)
        return True
    except subprocess.CalledProcessError as e:
        print("Running ffmpeg failed! Commandline:")
        print(" ".join(commands))
        # IMPROVEMENT: surface ffmpeg's own error output (previously discarded).
        if e.output:
            print(e.output.decode(errors='replace'))
        return False
    except Exception as e:
        # e.g. FileNotFoundError when ffmpeg is not installed
        print("Running ffmpeg failed! Commandline:")
        print(" ".join(commands))
        print(e)
        return False
20
+
21
+
22
+
23
def cut_video(original_video: str, cut_video: str, start_frame: int, end_frame: int, reencode: bool):
    """Cut frames [start_frame, end_frame) of original_video into cut_video.

    With reencode=True the clip is re-encoded using the configured video codec
    and AAC audio (frame-accurate); otherwise streams are copied, which is fast
    but cuts only on keyframes.  NOTE: the `cut_video` parameter shadows this
    function's name inside its body.
    """
    fps = util.detect_fps(original_video)
    start_time = start_frame / fps
    num_frames = end_frame - start_frame

    if reencode:
        run_ffmpeg(['-ss', format(start_time, ".2f"), '-i', original_video, '-c:v', roop.globals.video_encoder, '-c:a', 'aac', '-frames:v', str(num_frames), cut_video])
    else:
        run_ffmpeg(['-ss', format(start_time, ".2f"), '-i', original_video, '-frames:v', str(num_frames), '-c:v' ,'copy','-c:a' ,'copy', cut_video])
32
+
33
def join_videos(videos: List[str], dest_filename: str, simple: bool):
    """Concatenate `videos` into dest_filename.

    simple=True uses the concat demuxer with stream copy (fast; inputs must
    share codecs/parameters).  simple=False re-muxes through the concat
    filter, which also joins audio streams.
    """
    if simple:
        txtfilename = util.resolve_relative_path('../temp')
        txtfilename = os.path.join(txtfilename, 'joinvids.txt')
        with open(txtfilename, "w", encoding="utf-8") as f:
            for v in videos:
                # ffmpeg's concat demuxer wants forward slashes
                v = v.replace('\\', '/')
                f.write(f"file {v}\n")
        commands = ['-f', 'concat', '-safe', '0', '-i', f'{txtfilename}', '-vcodec', 'copy', f'{dest_filename}']
        run_ffmpeg(commands)

    else:
        inputs = []
        filter = ''
        for i, v in enumerate(videos):
            inputs.extend(['-i', v])
            filter += f'[{i}:v:0][{i}:a:0]'
        # BUGFIX: the argument list is passed to subprocess WITHOUT a shell, so
        # the previous '" ".join(inputs)' collapsed all -i pairs into a single
        # bogus argv element, and the extra quote characters around the filter
        # and -map labels were passed literally to ffmpeg.  Build a flat,
        # unquoted argv instead.
        run_ffmpeg(inputs + ['-filter_complex', f'{filter}concat=n={len(videos)}:v=1:a=1[outv][outa]', '-map', '[outv]', '-map', '[outa]', dest_filename])
55
+
56
+
57
+
58
def extract_frames(target_path : str, trim_frame_start, trim_frame_end, fps : float) -> bool:
    """Extract frames of target_path into its temp directory as numbered images.

    When both trim bounds are given (may be None), only that frame range is
    extracted at the requested fps.  Output format comes from the global
    config.  Returns True when ffmpeg succeeded.
    """
    util.create_temp(target_path)
    temp_directory_path = util.get_temp_directory_path(target_path)
    commands = ['-i', target_path, '-q:v', '1', '-pix_fmt', 'rgb24', ]
    if trim_frame_start is not None and trim_frame_end is not None:
        commands.extend([ '-vf', 'trim=start_frame=' + str(trim_frame_start) + ':end_frame=' + str(trim_frame_end) + ',fps=' + str(fps) ])
    # -vsync 0 keeps one image per decoded frame (no duplication/drops)
    commands.extend(['-vsync', '0', os.path.join(temp_directory_path, '%06d.' + roop.globals.CFG.output_image_format)])
    return run_ffmpeg(commands)
66
+
67
+
68
def create_video(target_path: str, dest_filename: str, fps: float = 24.0, temp_directory_path: str = None) -> str:
    """Encode the numbered frame images from the temp directory into dest_filename.

    Uses the globally configured encoder/quality; returns dest_filename.
    (Return annotation corrected from None: the function returns the path.)
    """
    if temp_directory_path is None:
        temp_directory_path = util.get_temp_directory_path(target_path)
    run_ffmpeg(['-r', str(fps), '-i', os.path.join(temp_directory_path, f'%06d.{roop.globals.CFG.output_image_format}'), '-c:v', roop.globals.video_encoder, '-crf', str(roop.globals.video_quality), '-pix_fmt', 'yuv420p', '-vf', 'colorspace=bt709:iall=bt601-6-625:fast=1', '-y', dest_filename])
    return dest_filename
73
+
74
+
75
def create_gif_from_video(video_path: str, gif_path):
    """Convert video_path into an animated GIF using a generated palette.

    NOTE(review): the scale argument uses frame.shape[0], which is the frame
    *height* for OpenCV arrays -- confirm whether the width was intended here.
    """
    from roop.capturer import get_video_frame

    fps = util.detect_fps(video_path)
    frame = get_video_frame(video_path)

    run_ffmpeg(['-i', video_path, '-vf', f'fps={fps},scale={frame.shape[0]}:-1:flags=lanczos,split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse', '-loop', '0', gif_path])
82
+
83
+
84
def restore_audio(intermediate_video: str, original_video: str, trim_frame_start, trim_frame_end, final_video : str) -> None:
    """Mux the audio of original_video back onto intermediate_video.

    The processed (silent) video stream comes from intermediate_video; the
    audio stream from original_video, trimmed to [trim_frame_start,
    trim_frame_end] when those bounds are given.  Audio is optional
    ('1:a:0?'), and -shortest keeps the streams aligned.
    """
    fps = util.detect_fps(original_video)
    commands = [ '-i', intermediate_video ]
    if trim_frame_start is None and trim_frame_end is None:
        commands.extend([ '-c:a', 'copy' ])
    else:
        # if trim_frame_start is not None:
        #     start_time = trim_frame_start / fps
        #     commands.extend([ '-ss', format(start_time, ".2f")])
        # else:
        #     commands.extend([ '-ss', '0' ])
        # if trim_frame_end is not None:
        #     end_time = trim_frame_end / fps
        #     commands.extend([ '-to', format(end_time, ".2f")])
        # commands.extend([ '-c:a', 'aac' ])
        # The -ss/-to options below apply to the NEXT input (original_video),
        # selecting the audio segment matching the processed frame range.
        if trim_frame_start is not None:
            start_time = trim_frame_start / fps
            commands.extend([ '-ss', format(start_time, ".2f")])
        else:
            commands.extend([ '-ss', '0' ])
        if trim_frame_end is not None:
            end_time = trim_frame_end / fps
            commands.extend([ '-to', format(end_time, ".2f")])
        commands.extend([ '-i', original_video, "-c", "copy" ])

    commands.extend([ '-map', '0:v:0', '-map', '1:a:0?', '-shortest', final_video ])
    run_ffmpeg(commands)
roop/utilities.py ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import mimetypes
3
+ import os
4
+ import platform
5
+ import shutil
6
+ import ssl
7
+ import subprocess
8
+ import sys
9
+ import urllib
10
+ import torch
11
+ import gradio
12
+ import tempfile
13
+ import cv2
14
+ import zipfile
15
+ import traceback
16
+
17
+ from pathlib import Path
18
+ from typing import List, Any
19
+ from tqdm import tqdm
20
+ from scipy.spatial import distance
21
+
22
+ import roop.template_parser as template_parser
23
+
24
+ import roop.globals
25
+
26
# Name of the intermediate output video written inside each temp directory.
TEMP_FILE = "temp.mp4"
# Root folder name (created next to the target file) for extracted frames.
TEMP_DIRECTORY = "temp"

# monkey patch ssl for mac
# macOS Python builds frequently ship without usable CA certificates; disable
# certificate verification globally so model downloads succeed there.
if platform.system().lower() == "darwin":
    ssl._create_default_https_context = ssl._create_unverified_context
32
+
33
+
34
+ # https://github.com/facefusion/facefusion/blob/master/facefusion
35
def detect_fps(target_path: str) -> float:
    """Return the frame rate of the video at *target_path* (24.0 on failure)."""
    capture = cv2.VideoCapture(target_path)
    result = 24.0
    if capture.isOpened():
        result = capture.get(cv2.CAP_PROP_FPS)
        capture.release()
    return result
42
+
43
+
44
+ # Gradio wants Images in RGB
45
# Gradio wants Images in RGB
def convert_to_gradio(image):
    """Convert a BGR (OpenCV) image to RGB for Gradio; None passes through."""
    return None if image is None else cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
49
+
50
+
51
def sort_filenames_ignore_path(filenames):
    """Sort full paths by their basename only, keeping each original path.

    Args:
        filenames: A list of filenames containing a complete path.

    Returns:
        The same paths, ordered by final path component (stable for ties).
    """
    return sorted(filenames, key=lambda full_path: os.path.split(full_path)[1])
68
+
69
+
70
def sort_rename_frames(path: str):
    """Rename every file in *path* to a 1-based zero-padded frame sequence.

    Files are processed in sorted name order and become
    000001.<fmt>, 000002.<fmt>, ... using the configured image format.
    """
    for index, name in enumerate(sorted(os.listdir(path)), start=1):
        source = os.path.join(path, name)
        target = os.path.join(path, f"{index:06d}." + roop.globals.CFG.output_image_format)
        os.rename(source, target)
80
+
81
+
82
def get_temp_frame_paths(target_path: str) -> List[str]:
    """Glob the extracted frame images belonging to *target_path*'s temp dir."""
    frame_dir = get_temp_directory_path(target_path)
    pattern = os.path.join(
        glob.escape(frame_dir),
        f"*.{roop.globals.CFG.output_image_format}",
    )
    return glob.glob(pattern)
92
+
93
+
94
def get_temp_directory_path(target_path: str) -> str:
    """Return <dir-of-target>/<TEMP_DIRECTORY>/<target-stem> for *target_path*."""
    stem = os.path.splitext(os.path.basename(target_path))[0]
    parent = os.path.dirname(target_path)
    return os.path.join(parent, TEMP_DIRECTORY, stem)
98
+
99
+
100
def get_temp_output_path(target_path: str) -> str:
    """Return the path of the intermediate video inside the target's temp dir."""
    return os.path.join(get_temp_directory_path(target_path), TEMP_FILE)
103
+
104
+
105
def normalize_output_path(source_path: str, target_path: str, output_path: str) -> Any:
    """Resolve the final output path.

    If *output_path* is a directory (and both source and target are given),
    synthesize "<source-stem>-<target-stem><target-ext>" inside it; otherwise
    return *output_path* unchanged.
    """
    if not (source_path and target_path):
        return output_path
    if not os.path.isdir(output_path):
        return output_path
    source_stem = os.path.splitext(os.path.basename(source_path))[0]
    target_stem, target_ext = os.path.splitext(os.path.basename(target_path))
    return os.path.join(output_path, source_stem + "-" + target_stem + target_ext)
114
+
115
+
116
def get_destfilename_from_path(
    srcfilepath: str, destfilepath: str, extension: str
) -> str:
    """Build a destination path from the source's stem.

    If *extension* contains a dot it replaces the original extension entirely;
    otherwise it is inserted as a suffix before the original extension.
    """
    stem, original_ext = os.path.splitext(os.path.basename(srcfilepath))
    suffix = extension if "." in extension else f"{extension}{original_ext}"
    return os.path.join(destfilepath, f"{stem}{suffix}")
123
+
124
+
125
def replace_template(file_path: str, index: int = 0) -> str:
    """Expand the configured output filename template for *file_path*.

    The template may reference {index} and {file}; the "__temp" placeholder
    used for temporary filenames is stripped from the stem first. The
    original extension is kept.
    """
    stem, ext = os.path.splitext(os.path.basename(file_path))

    # Remove the "__temp" placeholder that was used as a temporary filename
    stem = stem.replace("__temp", "")

    expanded = template_parser.parse(
        roop.globals.CFG.output_template, {"index": str(index), "file": stem}
    )
    return os.path.join(roop.globals.output_path, f"{expanded}{ext}")
137
+
138
+
139
def create_temp(target_path: str) -> None:
    """Create the per-target temp frame directory (including parents)."""
    os.makedirs(get_temp_directory_path(target_path), exist_ok=True)
142
+
143
+
144
def move_temp(target_path: str, output_path: str) -> None:
    """Move the intermediate temp video to *output_path*, replacing any file there."""
    temp_output = get_temp_output_path(target_path)
    if not os.path.isfile(temp_output):
        return
    if os.path.isfile(output_path):
        os.remove(output_path)
    shutil.move(temp_output, output_path)
150
+
151
+
152
def clean_temp(target_path: str) -> None:
    """Delete the target's temp frame directory (unless keep_frames is set),
    then remove the now-empty shared temp parent if possible."""
    temp_dir = get_temp_directory_path(target_path)
    parent_dir = os.path.dirname(temp_dir)
    if not roop.globals.keep_frames and os.path.isdir(temp_dir):
        shutil.rmtree(temp_dir)
    if os.path.exists(parent_dir) and not os.listdir(parent_dir):
        os.rmdir(parent_dir)
159
+
160
+
161
def delete_temp_frames(filename: str) -> None:
    """Recursively delete the temp tree two levels above *filename*.

    *filename* is expected to be a frame file inside <root>/<temp>/<stem>/;
    the grandparent directory (the temp root for this target) is removed.
    Renamed the local variable: the original shadowed the builtin `dir`.
    """
    frame_root = os.path.dirname(os.path.dirname(filename))
    shutil.rmtree(frame_root)
164
+
165
+
166
def has_image_extension(image_path: str) -> bool:
    """True if the path ends in a recognized still-image extension."""
    image_suffixes = ("png", "jpg", "jpeg", "webp")
    return image_path.lower().endswith(image_suffixes)
168
+
169
+
170
def has_extension(filepath: str, extensions: List[str]) -> bool:
    """True if *filepath* (case-insensitively) ends with any of *extensions*."""
    lowered = filepath.lower()
    return lowered.endswith(tuple(extensions))
172
+
173
+
174
def is_image(image_path: str) -> bool:
    """True if *image_path* is an existing file with an image/* MIME type."""
    if not image_path or not os.path.isfile(image_path):
        return False
    mime, _ = mimetypes.guess_type(image_path)
    return bool(mime and mime.startswith("image/"))
179
+
180
+
181
def is_video(video_path: str) -> bool:
    """True if *video_path* is an existing file with a video/* MIME type."""
    if not video_path or not os.path.isfile(video_path):
        return False
    mime, _ = mimetypes.guess_type(video_path)
    return bool(mime and mime.startswith("video/"))
186
+
187
+
188
def conditional_download(download_directory_path: str, urls: List[str]) -> None:
    """Download each URL into *download_directory_path* unless already present.

    A preliminary request is made only to read Content-Length for the progress
    bar; the original code never closed that response (socket leak) — it is now
    managed with a context manager. The actual transfer uses urlretrieve with
    a tqdm reporthook.
    """
    os.makedirs(download_directory_path, exist_ok=True)
    for url in urls:
        download_file_path = os.path.join(
            download_directory_path, os.path.basename(url)
        )
        if os.path.exists(download_file_path):
            continue
        # size probe; close the response promptly instead of leaking it
        with urllib.request.urlopen(url) as request:  # type: ignore[attr-defined]
            total = int(request.headers.get("Content-Length", 0))
        with tqdm(
            total=total,
            desc=f"Downloading {url}",
            unit="B",
            unit_scale=True,
            unit_divisor=1024,
        ) as progress:
            urllib.request.urlretrieve(url, download_file_path, reporthook=lambda count, block_size, total_size: progress.update(block_size))  # type: ignore[attr-defined]
206
+
207
+
208
def get_local_files_from_folder(folder: str) -> List[str]:
    """Return full paths of the plain files directly inside *folder*.

    Returns None (not an empty list) when *folder* does not exist or is not
    a directory — callers rely on that sentinel.
    """
    if not os.path.isdir(folder):
        return None
    return [
        os.path.join(folder, entry)
        for entry in os.listdir(folder)
        if os.path.isfile(os.path.join(folder, entry))
    ]
217
+
218
+
219
def resolve_relative_path(path: str) -> str:
    """Resolve *path* relative to this module's directory as an absolute path."""
    module_dir = os.path.dirname(__file__)
    return os.path.abspath(os.path.join(module_dir, path))
221
+
222
+
223
def get_device() -> str:
    """Map the first configured ONNX execution provider to a torch device name.

    Falls back to CPUExecutionProvider when none is configured (and records
    that fallback in roop.globals).
    """
    if not roop.globals.execution_providers:
        roop.globals.execution_providers = ["CPUExecutionProvider"]

    provider = roop.globals.execution_providers[0]
    if "CoreMLExecutionProvider" in provider:
        return "mps"
    if "CUDAExecutionProvider" in provider or "ROCMExecutionProvider" in provider:
        return "cuda"
    if "OpenVINOExecutionProvider" in provider:
        return "mkl"
    return "cpu"
235
+
236
+
237
def str_to_class(module_name, class_name) -> Any:
    """Import *module_name* and return a new instance of *class_name* from it.

    Prints a diagnostic and returns None when the module or class is missing.
    """
    from importlib import import_module

    try:
        module = import_module(module_name)
    except ImportError:
        print(f"Module {module_name} does not exist")
        return None
    try:
        return getattr(module, class_name)()
    except AttributeError:
        print(f"Class {class_name} does not exist")
        return None
250
+
251
def is_installed(name: str) -> bool:
    """Return True if the executable *name* can be found on PATH.

    The original returned shutil.which's Optional[str] (and had a stray
    semicolon) despite the bool annotation; truthiness is unchanged for
    existing callers.
    """
    return shutil.which(name) is not None
253
+
254
+ # Taken from https://stackoverflow.com/a/68842705
255
+ def get_platform() -> str:
256
+ if sys.platform == "linux":
257
+ try:
258
+ proc_version = open("/proc/version").read()
259
+ if "Microsoft" in proc_version:
260
+ return "wsl"
261
+ except:
262
+ pass
263
+ return sys.platform
264
+
265
def open_with_default_app(filename: str):
    """Open *filename* with the operating system's default application.

    No-op for None. Fixes the Linux branch: the original called
    subprocess.call("xdg-open", filename), which passes *filename* as the
    second positional parameter (bufsize) instead of as a command argument.
    """
    if filename is None:
        return
    platform = get_platform()
    if platform == "darwin":
        subprocess.call(("open", filename))
    elif platform in ["win64", "win32"]:
        os.startfile(filename.replace("/", "\\"))
    elif platform == "wsl":
        subprocess.call("cmd.exe /C start".split() + [filename])
    else:  # linux variants
        subprocess.call(["xdg-open", filename])
276
+
277
+
278
def prepare_for_batch(target_files) -> str:
    """Move uploaded temp files into a clean staging folder and return its path.

    *target_files* are upload objects exposing a .name path; any previous
    staging folder is wiped first.
    """
    print("Preparing temp files")
    staging = os.path.join(tempfile.gettempdir(), "rooptmp")
    if os.path.exists(staging):
        shutil.rmtree(staging)
    os.makedirs(staging)
    for uploaded in target_files:
        target_name = os.path.basename(uploaded.name)
        shutil.move(uploaded.name, os.path.join(staging, target_name))
    return staging
288
+
289
+
290
def zip(files, zipname):
    """Write *files* into archive *zipname*, flattened to their basenames.

    NOTE: the name intentionally matches existing callers (util.zip) even
    though it shadows the builtin.
    """
    with zipfile.ZipFile(zipname, "w") as archive:
        for path in files:
            archive.write(path, os.path.basename(path))
294
+
295
+
296
def unzip(zipfilename: str, target_path: str):
    """Extract every member of the archive *zipfilename* into *target_path*."""
    with zipfile.ZipFile(zipfilename, "r") as archive:
        archive.extractall(target_path)
299
+
300
+
301
def mkdir_with_umask(directory):
    """Create *directory* (and parents) with mode 0o775, ignoring the umask.

    The umask is temporarily cleared so the requested mode takes effect; it is
    restored in a finally block — the original left the process umask at 0 if
    os.makedirs raised.
    """
    previous_umask = os.umask(0)
    try:
        # mode needs octal
        os.makedirs(directory, mode=0o775, exist_ok=True)
    finally:
        os.umask(previous_umask)
306
+
307
+
308
def open_folder(path: str):
    """Reveal *path* in the platform's file manager.

    Best-effort: any failure is printed (traceback) rather than raised,
    since this is a convenience action triggered from the UI.
    """
    current = get_platform()
    try:
        if current == "darwin":
            subprocess.call(("open", path))
        elif current in ["win64", "win32"]:
            open_with_default_app(path)
        elif current == "wsl":
            subprocess.call("cmd.exe /C start".split() + [path])
        else:
            # linux variants
            subprocess.Popen(["xdg-open", path])
    except Exception:
        traceback.print_exc()
324
+
325
+
326
def create_version_html() -> str:
    """Build the small HTML snippet shown in the UI header with the running
    python, torch, and gradio versions.

    __long_version__ is preferred for torch when present (some builds expose
    a more detailed version string there); falls back to __version__.
    """
    python_version = ".".join([str(x) for x in sys.version_info[0:3]])
    versions_html = f"""
python: <span title="{sys.version}">{python_version}</span>

torch: {getattr(torch, '__long_version__',torch.__version__)}

gradio: {gradio.__version__}
"""
    return versions_html
336
+
337
+
338
def compute_cosine_distance(emb1, emb2) -> float:
    """Cosine distance (1 - cosine similarity) between two embedding vectors.

    Smaller values mean more similar faces; delegates to scipy.spatial.distance.
    """
    return distance.cosine(emb1, emb2)
roop/virtualcam.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import roop.globals
3
+ import ui.globals
4
+ import pyvirtualcam
5
+ import threading
6
+ import time
7
+
8
+
9
cam_active = False  # True while the capture loop should keep running
cam_thread = None  # background thread executing virtualcamera()
vcam = None  # NOTE(review): not used in the visible code — confirm before removing
12
+
13
def virtualcamera(streamobs, cam_num, width, height):
    """Capture loop: read frames from webcam *cam_num*, face-swap them, and
    forward the result to a pyvirtualcam device (when *streamobs*) and to the
    UI preview via ui.globals.ui_camera_frame.

    Runs until the module-global cam_active is cleared (see stop_virtual_cam).
    Width/height/fps are *requested* from the device; OpenCV does not
    guarantee them.
    """
    from roop.core import live_swap
    from roop.filters import fast_quantize_to_palette

    global cam_active

    #time.sleep(2)
    print('Starting capture')
    # CAP_DSHOW: DirectShow backend — assumes Windows capture; TODO confirm on other OSes
    cap = cv2.VideoCapture(cam_num, cv2.CAP_DSHOW)
    if not cap.isOpened():
        print("Cannot open camera")
        cap.release()
        del cap
        return

    pref_width = width
    pref_height = height
    pref_fps_in = 30
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, pref_width)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, pref_height)
    cap.set(cv2.CAP_PROP_FPS, pref_fps_in)
    cam_active = True

    # native format UYVY

    cam = None
    if streamobs:
        print('Detecting virtual cam devices')
        cam = pyvirtualcam.Camera(width=pref_width, height=pref_height, fps=pref_fps_in, fmt=pyvirtualcam.PixelFormat.BGR, print_fps=False)
    if cam:
        print(f'Using virtual camera: {cam.device}')
        print(f'Using {cam.native_fmt}')
    else:
        print(f'Not streaming to virtual camera!')

    while cam_active:
        ret, frame = cap.read()
        if not ret:
            break

        # only swap when at least one source faceset has been loaded
        if len(roop.globals.INPUT_FACESETS) > 0:
            frame = live_swap(frame, "all", False, None, None, False)
        #frame = fast_quantize_to_palette(frame)
        if cam:
            cam.send(frame)
            cam.sleep_until_next_frame()
        # publish the latest frame for the Gradio preview
        ui.globals.ui_camera_frame = frame

    if cam:
        cam.close()
    cap.release()
    print('Camera stopped')
65
+
66
+
67
+
68
def start_virtual_cam(streamobs, cam_number, resolution):
    """Spawn the virtual-camera capture thread unless one is already active.

    *resolution* is a "WIDTHxHEIGHT" string, e.g. "1280x720".
    """
    global cam_thread, cam_active

    if cam_active:
        return
    width, height = map(int, resolution.split('x'))
    cam_thread = threading.Thread(
        target=virtualcamera, args=[streamobs, cam_number, width, height]
    )
    cam_thread.start()
75
+
76
+
77
+
78
def stop_virtual_cam():
    """Signal the capture loop to stop and wait for its thread to finish."""
    global cam_active, cam_thread

    if not cam_active:
        return
    cam_active = False
    cam_thread.join()
84
+
85
+
roop/vr_util.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+
4
+ # VR Lense Distortion
5
+ # Taken from https://github.com/g0kuvonlange/vrswap
6
+
7
+
8
def get_perspective(img, FOV, THETA, PHI, height, width):
    """Extract a rectilinear (perspective) view from an equirectangular image.

    THETA is the left/right angle, PHI the up/down angle, both in degrees;
    FOV is the horizontal field of view of the desired view. Returns a
    (height, width) perspective image remapped from *img*.
    """
    #
    # THETA is left/right angle, PHI is up/down angle, both in degree
    #
    # NOTE(review): numpy image shape is (rows=height, cols=width, channels),
    # so these two names look swapped. The math below consistently uses
    # equ_w/equ_h from them, so changing the names would alter the mapping —
    # left byte-identical; confirm against upstream vrswap.
    [orig_width, orig_height, _] = img.shape
    equ_h = orig_height
    equ_w = orig_width
    equ_cx = (equ_w - 1) / 2.0
    equ_cy = (equ_h - 1) / 2.0

    wFOV = FOV
    # vertical FOV scaled by the output aspect ratio
    hFOV = float(height) / width * wFOV

    w_len = np.tan(np.radians(wFOV / 2.0))
    h_len = np.tan(np.radians(hFOV / 2.0))

    # rays through the image plane at unit distance along x
    x_map = np.ones([height, width], np.float32)
    y_map = np.tile(np.linspace(-w_len, w_len, width), [height, 1])
    z_map = -np.tile(np.linspace(-h_len, h_len, height), [width, 1]).T

    # normalize each ray to unit length
    D = np.sqrt(x_map**2 + y_map**2 + z_map**2)
    xyz = np.stack((x_map, y_map, z_map), axis=2) / np.repeat(
        D[:, :, np.newaxis], 3, axis=2
    )

    # rotate rays by THETA about z, then by -PHI about the rotated y axis
    y_axis = np.array([0.0, 1.0, 0.0], np.float32)
    z_axis = np.array([0.0, 0.0, 1.0], np.float32)
    [R1, _] = cv2.Rodrigues(z_axis * np.radians(THETA))
    [R2, _] = cv2.Rodrigues(np.dot(R1, y_axis) * np.radians(-PHI))

    xyz = xyz.reshape([height * width, 3]).T
    xyz = np.dot(R1, xyz)
    xyz = np.dot(R2, xyz).T
    # convert rotated rays to latitude/longitude
    lat = np.arcsin(xyz[:, 2])
    lon = np.arctan2(xyz[:, 1], xyz[:, 0])

    lon = lon.reshape([height, width]) / np.pi * 180
    lat = -lat.reshape([height, width]) / np.pi * 180

    # map lat/lon to equirectangular pixel coordinates
    lon = lon / 180 * equ_cx + equ_cx
    lat = lat / 90 * equ_cy + equ_cy

    persp = cv2.remap(
        img,
        lon.astype(np.float32),
        lat.astype(np.float32),
        cv2.INTER_CUBIC,
        borderMode=cv2.BORDER_WRAP,
    )
    return persp
ui/globals.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Shared UI state used across the Gradio tabs and background threads.

ui_restart_server = False  # set True to tear down and relaunch the Gradio server

SELECTION_FACES_DATA = None  # presumably faces extracted for the selection dialog — verify against caller
ui_SELECTED_INPUT_FACE_INDEX = 0  # index of the selected input face

ui_selected_enhancer = None  # enhancer name chosen in the UI (None = none chosen yet)
ui_blend_ratio = None  # enhancer blend ratio chosen in the UI
ui_input_thumbs = []  # thumbnails of source (input) faces
ui_target_thumbs = []  # thumbnails of target faces
ui_camera_frame = None  # latest frame published by the virtual-camera thread
11
+
12
+
13
+
14
+
15
+
ui/main.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import gradio as gr
4
+ import roop.globals
5
+ import roop.metadata
6
+ import roop.utilities as util
7
+ import ui.globals as uii
8
+
9
+ from ui.tabs.faceswap_tab import faceswap_tab
10
+ from ui.tabs.livecam_tab import livecam_tab
11
+ from ui.tabs.facemgr_tab import facemgr_tab
12
+ from ui.tabs.extras_tab import extras_tab
13
+ from ui.tabs.settings_tab import settings_tab
14
+
15
+ roop.globals.keep_fps = None
16
+ roop.globals.keep_frames = None
17
+ roop.globals.skip_audio = None
18
+ roop.globals.use_batch = None
19
+
20
+
21
+ def prepare_environment():
22
+ roop.globals.output_path = os.path.abspath(os.path.join(os.getcwd(), "output"))
23
+ os.makedirs(roop.globals.output_path, exist_ok=True)
24
+ if not roop.globals.CFG.use_os_temp_folder:
25
+ os.environ["TEMP"] = os.environ["TMP"] = os.path.abspath(os.path.join(os.getcwd(), "temp"))
26
+ os.makedirs(os.environ["TEMP"], exist_ok=True)
27
+ os.environ["GRADIO_TEMP_DIR"] = os.environ["TEMP"]
28
+
29
+
30
+ def run():
31
+ from roop.core import decode_execution_providers, set_display_ui
32
+
33
+ prepare_environment()
34
+
35
+ set_display_ui(show_msg)
36
+ roop.globals.execution_providers = decode_execution_providers([roop.globals.CFG.provider])
37
+ print(f'Using provider {roop.globals.execution_providers} - Device:{util.get_device()}')
38
+
39
+ run_server = True
40
+ uii.ui_restart_server = False
41
+ mycss = """
42
+ span {color: var(--block-info-text-color)}
43
+ #fixedheight {
44
+ max-height: 238.4px;
45
+ overflow-y: auto !important;
46
+ }
47
+ .image-container.svelte-1l6wqyv {height: 100%}
48
+
49
+ """
50
+
51
+ while run_server:
52
+ server_name = roop.globals.CFG.server_name
53
+ if server_name is None or len(server_name) < 1:
54
+ server_name = None
55
+ server_port = roop.globals.CFG.server_port
56
+ if server_port <= 0:
57
+ server_port = None
58
+ ssl_verify = False if server_name == '0.0.0.0' else True
59
+ with gr.Blocks(title=f'{roop.metadata.name} {roop.metadata.version}', theme=roop.globals.CFG.selected_theme, css=mycss) as ui:
60
+ with gr.Row(variant='compact'):
61
+ gr.Markdown(f"### [{roop.metadata.name} {roop.metadata.version}](https://github.com/C0untFloyd/roop-unleashed)")
62
+ gr.HTML(util.create_version_html(), elem_id="versions")
63
+ faceswap_tab()
64
+ livecam_tab()
65
+ facemgr_tab()
66
+ extras_tab()
67
+ settings_tab()
68
+
69
+ uii.ui_restart_server = False
70
+ try:
71
+ ui.queue().launch(inbrowser=True, server_name=server_name, server_port=server_port, share=roop.globals.CFG.server_share, ssl_verify=ssl_verify, prevent_thread_lock=True, show_error=True)
72
+ except Exception as e:
73
+ print(f'Exception {e} when launching Gradio Server!')
74
+ uii.ui_restart_server = True
75
+ run_server = False
76
+ try:
77
+ while uii.ui_restart_server == False:
78
+ time.sleep(1.0)
79
+
80
+ except (KeyboardInterrupt, OSError):
81
+ print("Keyboard interruption in main thread... closing server.")
82
+ run_server = False
83
+ ui.close()
84
+
85
+
86
+ def show_msg(msg: str):
87
+ gr.Info(msg)
88
+
ui/tabs/extras_tab.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import roop.utilities as util
4
+ import roop.util_ffmpeg as ffmpeg
5
+ import roop.globals
6
+
7
def extras_tab():
    """Build the "Extras" tab: cut/join/re-encode videos, extract frames,
    and create videos or GIFs from image folders."""
    with gr.Tab("🎉 Extras"):
        with gr.Row():
            files_to_process = gr.Files(label='File(s) to process', file_count="multiple", file_types=["image", "video"])
        # with gr.Row(variant='panel'):
        #     with gr.Accordion(label="Post process", open=False):
        #         with gr.Column():
        #             selected_post_enhancer = gr.Dropdown(["None", "Codeformer", "GFPGAN"], value="None", label="Select post-processing")
        #         with gr.Column():
        #             gr.Button("Start").click(fn=lambda: gr.Info('Not yet implemented...'))
        with gr.Row(variant='panel'):
            with gr.Accordion(label="Video/GIF", open=False):
                with gr.Row(variant='panel'):
                    with gr.Column():
                        gr.Markdown("""
                        # Poor man's video editor
                        Re-encoding uses your configuration from the Settings Tab.
""")
                    with gr.Column():
                        cut_start_time = gr.Slider(0, 1000000, value=0, label="Start Frame", step=1.0, interactive=True)
                    with gr.Column():
                        cut_end_time = gr.Slider(1, 1000000, value=1, label="End Frame", step=1.0, interactive=True)
                    with gr.Column():
                        extras_chk_encode = gr.Checkbox(label='Re-encode videos (necessary for videos with different codecs)', value=False)
                        start_cut_video = gr.Button("Cut video")
                        start_extract_frames = gr.Button("Extract frames")
                        start_join_videos = gr.Button("Join videos")

                with gr.Row(variant='panel'):
                    with gr.Column():
                        gr.Markdown("""
                        # Create video/gif from images
""")
                    with gr.Column():
                        extras_fps = gr.Slider(minimum=0, maximum=120, value=30, label="Video FPS", step=1.0, interactive=True)
                        extras_images_folder = gr.Textbox(show_label=False, placeholder="/content/", interactive=True)
                    with gr.Column():
                        extras_chk_creategif = gr.Checkbox(label='Create GIF from video', value=False)
                        extras_create_video=gr.Button("Create")
        with gr.Row():
            gr.Button("👀 Open Output Folder", size='sm').click(fn=lambda: util.open_folder(roop.globals.output_path))
        with gr.Row():
            extra_files_output = gr.Files(label='Resulting output files', file_count="multiple")

    # wire the buttons to their handlers; results land in the output file list
    start_cut_video.click(fn=on_cut_video, inputs=[files_to_process, cut_start_time, cut_end_time, extras_chk_encode], outputs=[extra_files_output])
    start_extract_frames.click(fn=on_extras_extract_frames, inputs=[files_to_process], outputs=[extra_files_output])
    start_join_videos.click(fn=on_join_videos, inputs=[files_to_process, extras_chk_encode], outputs=[extra_files_output])
    extras_create_video.click(fn=on_extras_create_video, inputs=[extras_images_folder, extras_fps, extras_chk_creategif], outputs=[extra_files_output])
55
+
56
+
57
def on_cut_video(files, cut_start_frame, cut_end_frame, reencode):
    """Cut each uploaded video to [cut_start_frame, cut_end_frame].

    Returns the list of successfully written output files, or None when
    nothing was uploaded.
    """
    if files is None:
        return None

    resultfiles = []
    for uploaded in files:
        source = uploaded.name
        destination = util.get_destfilename_from_path(source, roop.globals.output_path, '_cut')
        ffmpeg.cut_video(source, destination, cut_start_frame, cut_end_frame, reencode)
        if not os.path.isfile(destination):
            gr.Error('Cutting video failed!')
            continue
        resultfiles.append(destination)
    return resultfiles
71
+
72
+
73
def on_join_videos(files, chk_encode):
    """Concatenate the uploaded videos (sorted by basename) into one file.

    chk_encode=True forces re-encoding; otherwise streams are copied.
    Returns a one-element list with the joined file, or None/empty on failure.
    """
    if files is None:
        return None

    filenames = [uploaded.name for uploaded in files]
    destfile = util.get_destfilename_from_path(filenames[0], roop.globals.output_path, '_join')
    ffmpeg.join_videos(util.sort_filenames_ignore_path(filenames), destfile, not chk_encode)
    if os.path.isfile(destfile):
        return [destfile]
    gr.Error('Joining videos failed!')
    return []
89
+
90
+
91
+
92
def on_extras_create_video(images_path, fps, create_gif):
    """Build a video (and optionally a GIF) from the frames in *images_path*.

    Frames are renamed into a sequential pattern first. Returns the produced
    file(s), or None if video creation failed.
    """
    util.sort_rename_frames(os.path.dirname(images_path))
    destfilename = os.path.join(roop.globals.output_path, "img2video." + roop.globals.CFG.output_video_format)
    ffmpeg.create_video('', destfilename, fps, images_path)
    resultfiles = []
    if os.path.isfile(destfilename):
        resultfiles.append(destfilename)
    else:
        return None
    if create_gif:
        gifname = util.get_destfilename_from_path(destfilename, './output', '.gif')
        ffmpeg.create_gif_from_video(destfilename, gifname)
        # verify the GIF itself; the original re-checked destfilename, so a
        # missing GIF was still appended to the results
        if os.path.isfile(gifname):
            resultfiles.append(gifname)
    return resultfiles
107
+
108
+
109
def on_extras_extract_frames(files):
    """Extract all frames from each uploaded video; return the frame paths."""
    if files is None:
        return None

    resultfiles = []
    for uploaded in files:
        frames_folder = ffmpeg.extract_frames(uploaded.name)
        for entry in os.listdir(frames_folder):
            full_path = os.path.join(frames_folder, entry)
            if os.path.isfile(full_path):
                resultfiles.append(full_path)
    return resultfiles
122
+
ui/tabs/facemgr_tab.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import cv2
4
+ import gradio as gr
5
+ import roop.utilities as util
6
+ import roop.globals
7
+ from roop.face_util import extract_face_images
8
+
9
+ selected_face_index = -1
10
+ thumbs = []
11
+ images = []
12
+
13
+
14
def facemgr_tab():
    """Build the "Face Management" tab for creating/editing .fsz faceset files
    (bundles of reference face images used for blending)."""
    with gr.Tab("👨‍👩‍👧‍👦 Face Management"):
        with gr.Row():
            gr.Markdown("""
            # Create blending facesets
            Add multiple reference images into a faceset file.
""")
        with gr.Row():
            fb_facesetfile = gr.Files(label='Faceset', file_count='single', file_types=['.fsz'], interactive=True)
            fb_files = gr.Files(label='Input Files', file_count="multiple", file_types=["image"], interactive=True)
        with gr.Row():
            with gr.Column():
                gr.Button("👀 Open Output Folder", size='sm').click(fn=lambda: util.open_folder(roop.globals.output_path))
            with gr.Column():
                gr.Markdown(' ')
        with gr.Row():
            faces = gr.Gallery(label="Faces in this Faceset", allow_preview=True, preview=True, height=128, object_fit="scale-down")
        with gr.Row():
            fb_remove = gr.Button("Remove selected", variant='secondary')
            fb_update = gr.Button("Create/Update Faceset file", variant='primary')
            fb_clear = gr.Button("Clear all", variant='stop')

        # event wiring: load facesets/images into the gallery, edit, export
        fb_facesetfile.change(fn=on_faceset_changed, inputs=[fb_facesetfile], outputs=[faces])
        fb_files.change(fn=on_fb_files_changed, inputs=[fb_files], outputs=[faces])
        fb_update.click(fn=on_update_clicked, outputs=[fb_facesetfile])
        fb_remove.click(fn=on_remove_clicked, outputs=[faces])
        fb_clear.click(fn=on_clear_clicked, outputs=[faces, fb_files, fb_facesetfile])
        faces.select(fn=on_face_selected)
42
+
43
def on_faceset_changed(faceset, progress=gr.Progress()):
    """Load an uploaded .fsz faceset: unzip it, extract a face from each PNG,
    and append the faces to the module-level thumbs/images lists.

    Returns the (shared, mutable) thumbs list for the gallery.
    NOTE(review): images is appended to but not cleared here, while thumbs is
    — confirm that accumulating images across loads is intended.
    """
    global thumbs, images

    if faceset is None:
        return thumbs

    thumbs.clear()
    filename = faceset.name

    if filename.lower().endswith('fsz'):
        progress(0, desc="Retrieving faces from Faceset File", )
        # unpack into a scratch folder under TEMP (assumes TEMP is set;
        # see prepare_environment in ui/main.py)
        unzipfolder = os.path.join(os.environ["TEMP"], 'faceset')
        if os.path.isdir(unzipfolder):
            shutil.rmtree(unzipfolder)
        util.mkdir_with_umask(unzipfolder)
        util.unzip(filename, unzipfolder)
        for file in os.listdir(unzipfolder):
            if file.endswith(".png"):
                # (False, 0): no rotation hint; 0.5: detection threshold — TODO confirm semantics
                SELECTION_FACES_DATA = extract_face_images(os.path.join(unzipfolder,file), (False, 0), 0.5)
                if len(SELECTION_FACES_DATA) < 1:
                    gr.Warning(f"No face detected in {file}!")
                for f in SELECTION_FACES_DATA:
                    image = f[1]
                    images.append(image)
                    thumbs.append(util.convert_to_gradio(image))

    return thumbs
70
+
71
+
72
def on_fb_files_changed(inputfiles, progress=gr.Progress()):
    """Extract faces from newly uploaded images and append them to the
    module-level thumbs/images lists; returns thumbs for the gallery.

    The original reused the loop variable `f` for both the uploaded file and
    each extracted face, shadowing the outer variable — renamed for clarity.
    """
    global thumbs, images

    if inputfiles is None or len(inputfiles) < 1:
        return thumbs

    progress(0, desc="Retrieving faces from images", )
    for uploaded in inputfiles:
        source_path = uploaded.name
        if not util.has_image_extension(source_path):
            continue
        roop.globals.source_path = source_path
        faces_data = extract_face_images(roop.globals.source_path, (False, 0), 0.5)
        for face_entry in faces_data:
            image = face_entry[1]
            images.append(image)
            thumbs.append(util.convert_to_gradio(image))
    return thumbs
89
+
90
def on_face_selected(evt: gr.SelectData):
    """Remember which gallery face the user clicked."""
    global selected_face_index

    if evt is None:
        return
    selected_face_index = evt.index
95
+
96
def on_remove_clicked():
    """Remove the currently selected face from thumbs and images.

    Guards both bounds: the index starts at -1 (nothing selected), and the
    original's `len(thumbs) > selected_face_index` check let pop(-1) silently
    delete the LAST entry when nothing had been selected yet.
    """
    global thumbs, images, selected_face_index

    if 0 <= selected_face_index < len(thumbs):
        thumbs.pop(selected_face_index)
        images.pop(selected_face_index)
    return thumbs
105
+
106
def on_clear_clicked():
    """Drop every collected face and reset both file inputs."""
    global thumbs, images

    del thumbs[:]
    del images[:]
    return thumbs, None, None
112
+
113
+
114
+
115
+
116
+
117
def on_update_clicked():
    """Write all collected face images to PNGs in the output folder and bundle
    them into faceset.fsz; return the archive path (None if no faces)."""
    if not images:
        gr.Warning(f"No faces to create faceset from!")
        return None

    imgnames = []
    for index, img in enumerate(images):
        filename = os.path.join(roop.globals.output_path, f'{index}.png')
        # faces are written at their native size (resizing was dropped upstream)
        cv2.imwrite(filename, img)
        imgnames.append(filename)

    finalzip = os.path.join(roop.globals.output_path, 'faceset.fsz')
    util.zip(imgnames, finalzip)
    return finalzip
134
+
135
+
ui/tabs/faceswap_tab.py ADDED
@@ -0,0 +1,671 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import pathlib
4
+ import gradio as gr
5
+ import roop.utilities as util
6
+ import roop.globals
7
+ import ui.globals
8
+ from roop.face_util import extract_face_images, create_blank_image
9
+ from roop.capturer import get_video_frame, get_video_frame_total, get_image_frame
10
+ from roop.ProcessEntry import ProcessEntry
11
+ from roop.FaceSet import FaceSet
12
+
13
# Module-level state for the Face Swap tab.

last_image = None  # NOTE(review): appears unused in the visible code — confirm


IS_INPUT = True  # True while face selection targets input (source) faces
SELECTED_FACE_INDEX = 0  # index inside the face-selection dialog

SELECTED_INPUT_FACE_INDEX = 0  # currently selected input face
SELECTED_TARGET_FACE_INDEX = 0  # currently selected target face

# Gradio components, assigned when the tab is built.
input_faces = None
target_faces = None
face_selection = None
previewimage = None

selected_preview_index = 0  # frame index shown in the preview

is_processing = False  # True while a swap job is running

list_files_process : list[ProcessEntry] = []  # queued files to process
no_face_choices = ["Use untouched original frame","Retry rotated", "Skip Frame"]

current_video_fps = 50  # fps of the currently loaded target video

manual_masking = False  # True while manual mask editing is active
37
+
38
+
def faceswap_tab():
    """Build the '🎭 Face Swap' tab.

    Lays out the input/target face galleries, advanced masking controls,
    the preview image / manual mask editor, processing options and the
    start/stop controls, then wires all Gradio event handlers defined in
    this module. Indentation reconstructed from the Gradio context-manager
    nesting; layout follows the upstream roop-unleashed structure.
    """
    global no_face_choices, previewimage

    with gr.Tab("🎭 Face Swap"):
        with gr.Row(variant='panel'):
            # Left half: face galleries, masking options and file pickers.
            with gr.Column(scale=2):
                with gr.Row():
                    with gr.Column(min_width=160):
                        input_faces = gr.Gallery(label="Input faces", allow_preview=False, preview=False, height=128, object_fit="scale-down", columns=8)
                        with gr.Accordion(label="Advanced Masking", open=False):
                            chk_showmaskoffsets = gr.Checkbox(label="Show mask overlay in preview", value=False, interactive=True)
                            mask_top = gr.Slider(0, 1.0, value=0, label="Offset Face Top", step=0.01, interactive=True)
                            mask_bottom = gr.Slider(0, 1.0, value=0, label="Offset Face Bottom", step=0.01, interactive=True)
                            mask_left = gr.Slider(0, 1.0, value=0, label="Offset Face Left", step=0.01, interactive=True)
                            mask_right = gr.Slider(0, 1.0, value=0, label="Offset Face Right", step=0.01, interactive=True)
                            mask_erosion = gr.Slider(1.0, 3.0, value=1.0, label="Erosion Iterations", step=1.00, interactive=True)
                            mask_blur = gr.Slider(10.0, 50.0, value=20.0, label="Blur size", step=1.00, interactive=True)
                            bt_toggle_masking = gr.Button("Toggle manual masking", variant='secondary', size='sm')
                        chk_useclip = gr.Checkbox(label="Use Text Masking", value=False)
                        clip_text = gr.Textbox(label="List of objects to mask and restore back on fake image", value="cup,hands,hair,banana" ,elem_id='tooltip')
                        gr.Dropdown(["Clip2Seg"], value="Clip2Seg", label="Engine")
                        bt_preview_mask = gr.Button("👥 Show Mask Preview", variant='secondary')
                        bt_remove_selected_input_face = gr.Button("❌ Remove selected", size='sm')
                        bt_clear_input_faces = gr.Button("💥 Clear all", variant='stop', size='sm')
                    with gr.Column(min_width=160):
                        target_faces = gr.Gallery(label="Target faces", allow_preview=False, preview=False, height=128, object_fit="scale-down", columns=8)
                        bt_remove_selected_target_face = gr.Button("❌ Remove selected", size='sm')
                        bt_add_local = gr.Button('Add local files from', size='sm')
                        local_folder = gr.Textbox(show_label=False, placeholder="/content/", interactive=True)
                with gr.Row(variant='panel'):
                    bt_srcfiles = gr.Files(label='Source File(s)', file_count="multiple", file_types=["image", ".fsz"], elem_id='filelist', height=233)
                    bt_destfiles = gr.Files(label='Target File(s)', file_count="multiple", file_types=["image", "video"], elem_id='filelist', height=233)
                with gr.Row(variant='panel'):
                    gr.Markdown('')
                    forced_fps = gr.Slider(minimum=0, maximum=120, value=0, label="Video FPS", info='Overrides detected fps if not 0', step=1.0, interactive=True, container=True)

            # Right half: preview image / manual mask editor and frame navigation.
            with gr.Column(scale=2):
                previewimage = gr.Image(label="Preview Image", height=576, interactive=False, visible=True)
                maskimage = gr.ImageEditor(label="Manual mask Image", sources=["clipboard"], transforms="", type="numpy",
                                           brush=gr.Brush(color_mode="fixed", colors=["rgba(255, 255, 255, 1"]), interactive=True, visible=False)
                with gr.Row(variant='panel'):
                    fake_preview = gr.Checkbox(label="Face swap frames", value=False)
                    bt_refresh_preview = gr.Button("🔄 Refresh", variant='secondary', size='sm')
                    bt_use_face_from_preview = gr.Button("Use Face from this Frame", variant='primary', size='sm')
                with gr.Row():
                    preview_frame_num = gr.Slider(1, 1, value=1, label="Frame Number", info='0:00:00', step=1.0, interactive=True)
                with gr.Row():
                    text_frame_clip = gr.Markdown('Processing frame range [0 - 0]')
                    set_frame_start = gr.Button("⬅ Set as Start", size='sm')
                    set_frame_end = gr.Button("➡ Set as End", size='sm')
        # Hidden row shown while the user picks one of several detected faces.
        with gr.Row(visible=False) as dynamic_face_selection:
            with gr.Column(scale=2):
                face_selection = gr.Gallery(label="Detected faces", allow_preview=False, preview=False, height=256, object_fit="cover", columns=8)
            with gr.Column():
                bt_faceselect = gr.Button("☑ Use selected face", size='sm')
                bt_cancelfaceselect = gr.Button("Done", size='sm')
            with gr.Column():
                gr.Markdown(' ')

        # Processing options.
        with gr.Row(variant='panel'):
            with gr.Column(scale=1):
                selected_face_detection = gr.Dropdown(["First found", "All female", "All male", "All faces", "Selected face"], value="First found", label="Specify face selection for swapping")
                max_face_distance = gr.Slider(0.01, 1.0, value=0.65, label="Max Face Similarity Threshold")
                video_swapping_method = gr.Dropdown(["Extract Frames to media","In-Memory processing"], value="In-Memory processing", label="Select video processing method", interactive=True)
                no_face_action = gr.Dropdown(choices=no_face_choices, value=no_face_choices[0], label="Action on no face detected", interactive=True)
                vr_mode = gr.Checkbox(label="VR Mode", value=False)
            with gr.Column(scale=1):
                ui.globals.ui_selected_enhancer = gr.Dropdown(["None", "Codeformer", "DMDNet", "GFPGAN", "GPEN", "Restoreformer++"], value="None", label="Select post-processing")
                ui.globals.ui_blend_ratio = gr.Slider(0.0, 1.0, value=0.65, label="Original/Enhanced image blend ratio")
                with gr.Group():
                    autorotate = gr.Checkbox(label="Auto rotate horizontal Faces", value=True)
                    roop.globals.skip_audio = gr.Checkbox(label="Skip audio", value=False)
                    roop.globals.keep_frames = gr.Checkbox(label="Keep Frames (relevant only when extracting frames)", value=False)
                    roop.globals.wait_after_extraction = gr.Checkbox(label="Wait for user key press before creating video ", value=False)
        with gr.Row(variant='panel'):
            with gr.Column():
                bt_start = gr.Button("▶ Start", variant='primary')
                gr.Button("👀 Open Output Folder", size='sm').click(fn=lambda: util.open_folder(roop.globals.output_path))
            with gr.Column():
                bt_stop = gr.Button("⏹ Stop", variant='secondary', interactive=False)
            with gr.Column(scale=2):
                gr.Markdown(' ')
        with gr.Row(variant='panel'):
            with gr.Column():
                resultfiles = gr.Files(label='Processed File(s)', interactive=False)
            with gr.Column():
                resultimage = gr.Image(type='filepath', label='Final Image', interactive=False )
                resultvideo = gr.Video(label='Final Video', interactive=False, visible=False)

        # ---- event wiring ----
        # Shared input/output lists for every handler that refreshes the preview.
        previewinputs = [preview_frame_num, bt_destfiles, fake_preview, ui.globals.ui_selected_enhancer, selected_face_detection,
                         max_face_distance, ui.globals.ui_blend_ratio, chk_useclip, clip_text, no_face_action, vr_mode, autorotate, maskimage, chk_showmaskoffsets]
        previewoutputs = [previewimage, maskimage, preview_frame_num]
        input_faces.select(on_select_input_face, None, None).then(fn=on_preview_frame_changed, inputs=previewinputs, outputs=previewoutputs)
        bt_remove_selected_input_face.click(fn=remove_selected_input_face, outputs=[input_faces])
        bt_srcfiles.change(fn=on_srcfile_changed, show_progress='full', inputs=bt_srcfiles, outputs=[dynamic_face_selection, face_selection, input_faces])

        # Mask offset sliders update state on release only (no live preview churn).
        mask_top.release(fn=on_mask_top_changed, inputs=[mask_top], show_progress='hidden')
        mask_bottom.release(fn=on_mask_bottom_changed, inputs=[mask_bottom], show_progress='hidden')
        mask_left.release(fn=on_mask_left_changed, inputs=[mask_left], show_progress='hidden')
        mask_right.release(fn=on_mask_right_changed, inputs=[mask_right], show_progress='hidden')
        mask_erosion.release(fn=on_mask_erosion_changed, inputs=[mask_erosion], show_progress='hidden')
        mask_blur.release(fn=on_mask_blur_changed, inputs=[mask_blur], show_progress='hidden')


        target_faces.select(on_select_target_face, None, None)
        bt_remove_selected_target_face.click(fn=remove_selected_target_face, outputs=[target_faces])

        forced_fps.change(fn=on_fps_changed, inputs=[forced_fps], show_progress='hidden')
        bt_destfiles.change(fn=on_destfiles_changed, inputs=[bt_destfiles], outputs=[preview_frame_num, text_frame_clip], show_progress='hidden').then(fn=on_preview_frame_changed, inputs=previewinputs, outputs=previewoutputs, show_progress='hidden')
        bt_destfiles.select(fn=on_destfiles_selected, outputs=[preview_frame_num, text_frame_clip, forced_fps], show_progress='hidden').then(fn=on_preview_frame_changed, inputs=previewinputs, outputs=previewoutputs, show_progress='hidden')
        bt_destfiles.clear(fn=on_clear_destfiles, outputs=[target_faces])
        resultfiles.select(fn=on_resultfiles_selected, inputs=[resultfiles], outputs=[resultimage, resultvideo])

        face_selection.select(on_select_face, None, None)
        bt_faceselect.click(fn=on_selected_face, outputs=[input_faces, target_faces, selected_face_detection])
        bt_cancelfaceselect.click(fn=on_end_face_selection, outputs=[dynamic_face_selection, face_selection])

        bt_clear_input_faces.click(fn=on_clear_input_faces, outputs=[input_faces])


        bt_add_local.click(fn=on_add_local_folder, inputs=[local_folder], outputs=[bt_destfiles])
        bt_preview_mask.click(fn=on_preview_mask, inputs=[preview_frame_num, bt_destfiles, clip_text], outputs=[previewimage])

        # Start runs as a (cancellable) generator; Stop cancels both events.
        start_event = bt_start.click(fn=start_swap,
                                     inputs=[ui.globals.ui_selected_enhancer, selected_face_detection, roop.globals.keep_frames, roop.globals.wait_after_extraction,
                                             roop.globals.skip_audio, max_face_distance, ui.globals.ui_blend_ratio, chk_useclip, clip_text,video_swapping_method, no_face_action, vr_mode, autorotate, maskimage],
                                     outputs=[bt_start, bt_stop, resultfiles], show_progress='full')
        after_swap_event = start_event.then(fn=on_resultfiles_finished, inputs=[resultfiles], outputs=[resultimage, resultvideo])

        bt_stop.click(fn=stop_swap, cancels=[start_event, after_swap_event], outputs=[bt_start, bt_stop], queue=False)

        bt_refresh_preview.click(fn=on_preview_frame_changed, inputs=previewinputs, outputs=previewoutputs)
        bt_toggle_masking.click(fn=on_toggle_masking, inputs=[previewimage, maskimage], outputs=[previewimage, maskimage])
        fake_preview.change(fn=on_preview_frame_changed, inputs=previewinputs, outputs=previewoutputs)
        preview_frame_num.release(fn=on_preview_frame_changed, inputs=previewinputs, outputs=previewoutputs, show_progress='hidden', )
        bt_use_face_from_preview.click(fn=on_use_face_from_selected, show_progress='full', inputs=[bt_destfiles, preview_frame_num], outputs=[dynamic_face_selection, face_selection, target_faces, selected_face_detection])
        set_frame_start.click(fn=on_set_frame, inputs=[set_frame_start, preview_frame_num], outputs=[text_frame_clip])
        set_frame_end.click(fn=on_set_frame, inputs=[set_frame_end, preview_frame_num], outputs=[text_frame_clip])
178
+
179
+
def on_mask_top_changed(mask_offset):
    """Slider handler: store the top face-mask offset (slot 0)."""
    set_mask_offset(0, mask_offset)

def on_mask_bottom_changed(mask_offset):
    """Slider handler: store the bottom face-mask offset (slot 1)."""
    set_mask_offset(1, mask_offset)

def on_mask_left_changed(mask_offset):
    """Slider handler: store the left face-mask offset (slot 2)."""
    set_mask_offset(2, mask_offset)

def on_mask_right_changed(mask_offset):
    """Slider handler: store the right face-mask offset (slot 3)."""
    set_mask_offset(3, mask_offset)

def on_mask_erosion_changed(mask_offset):
    """Slider handler: store the mask erosion iteration count (slot 4)."""
    set_mask_offset(4, mask_offset)

def on_mask_blur_changed(mask_offset):
    """Slider handler: store the mask blur size (slot 5)."""
    set_mask_offset(5, mask_offset)
196
+
197
+
def set_mask_offset(index, mask_offset):
    """Write one value into the selected input face's ``mask_offsets``.

    Slots: 0=top, 1=bottom, 2=left, 3=right, 4=erosion, 5=blur. Opposing
    vertical/horizontal offsets are clamped so they never cover the whole face.
    No-op when no input face is selected.
    """
    global SELECTED_INPUT_FACE_INDEX

    if len(roop.globals.INPUT_FACESETS) > SELECTED_INPUT_FACE_INDEX:
        face = roop.globals.INPUT_FACESETS[SELECTED_INPUT_FACE_INDEX].faces[0]
        # Fix: mask_offsets is initialized as a tuple elsewhere in this module
        # (e.g. (0,0,0,0,1,20)); item assignment on a tuple raises TypeError,
        # so convert to a mutable list first.
        offs = list(face.mask_offsets)
        offs[index] = mask_offset
        if offs[0] + offs[1] > 0.99:
            offs[0] = 0.99
            offs[1] = 0.0
        if offs[2] + offs[3] > 0.99:
            offs[2] = 0.99
            offs[3] = 0.0
        face.mask_offsets = offs
211
+
212
+
213
+
def on_add_local_folder(folder):
    """Collect all media files below *folder* for the target-files widget.

    Returns the file list; returns None (clearing the widget) and shows a
    warning when the folder is missing or empty.
    """
    files = util.get_local_files_from_folder(folder)
    if files is None:
        gr.Warning("Empty folder or folder not found!")
    return files
219
+
220
+
def on_srcfile_changed(srcfiles, progress=gr.Progress()):
    """Import input faces from the selected source file(s).

    Two kinds of sources are supported:
      * ``.fsz`` faceset archives - all contained ``.png`` images are unzipped
        into a temp folder and merged into ONE FaceSet (embeddings averaged
        when more than one face was found).
      * regular images - every detected face becomes its own FaceSet.

    Returns (face-selection row, selection gallery, input thumbnail list)
    matching the outputs wired in faceswap_tab().
    """
    global SELECTION_FACES_DATA, IS_INPUT, input_faces, face_selection, last_image
    # NOTE: the unused local import of norm_crop2 was removed.

    IS_INPUT = True

    if srcfiles is None or len(srcfiles) < 1:
        return gr.Column(visible=False), None, ui.globals.ui_input_thumbs

    for f in srcfiles:
        source_path = f.name
        if source_path.lower().endswith('fsz'):
            progress(0, desc="Retrieving faces from Faceset File")
            # Fix: os.environ["TEMP"] only exists on Windows; fall back to the
            # platform temp directory so .fsz import also works on Linux/macOS.
            import tempfile
            temp_root = os.environ.get("TEMP") or tempfile.gettempdir()
            unzipfolder = os.path.join(temp_root, 'faceset')
            if os.path.isdir(unzipfolder):
                for file in os.listdir(unzipfolder):
                    fullpath = os.path.join(unzipfolder, file)
                    # Robustness: only unlink plain files - a stray
                    # subdirectory previously crashed os.remove().
                    if os.path.isfile(fullpath):
                        os.remove(fullpath)
            else:
                os.makedirs(unzipfolder)
            util.mkdir_with_umask(unzipfolder)
            util.unzip(source_path, unzipfolder)
            is_first = True
            face_set = FaceSet()
            for file in os.listdir(unzipfolder):
                if file.endswith(".png"):
                    filename = os.path.join(unzipfolder,file)
                    progress(0, desc="Extracting faceset")
                    SELECTION_FACES_DATA = extract_face_images(filename, (False, 0))
                    for f in SELECTION_FACES_DATA:
                        face = f[0]
                        face.mask_offsets = (0,0,0,0,1,20)
                        face_set.faces.append(face)
                        # Only the first face of the archive gets a thumbnail.
                        if is_first:
                            image = util.convert_to_gradio(f[1])
                            ui.globals.ui_input_thumbs.append(image)
                            is_first = False
                    face_set.ref_images.append(get_image_frame(filename))
            if len(face_set.faces) > 0:
                if len(face_set.faces) > 1:
                    face_set.AverageEmbeddings()
                roop.globals.INPUT_FACESETS.append(face_set)

        elif util.has_image_extension(source_path):
            progress(0, desc="Retrieving faces from image")
            roop.globals.source_path = source_path
            SELECTION_FACES_DATA = extract_face_images(roop.globals.source_path, (False, 0))
            progress(0.5, desc="Retrieving faces from image")
            # One FaceSet (and thumbnail) per detected face.
            for f in SELECTION_FACES_DATA:
                face_set = FaceSet()
                face = f[0]
                face.mask_offsets = (0,0,0,0,1,20)
                face_set.faces.append(face)
                image = util.convert_to_gradio(f[1])
                ui.globals.ui_input_thumbs.append(image)
                roop.globals.INPUT_FACESETS.append(face_set)

    progress(1.0)

    return gr.Column(visible=False), None, ui.globals.ui_input_thumbs
287
+
288
+
def on_select_input_face(evt: gr.SelectData):
    """Remember which input-face thumbnail the user clicked."""
    global SELECTED_INPUT_FACE_INDEX

    SELECTED_INPUT_FACE_INDEX = evt.index


def remove_selected_input_face():
    """Drop the selected input face from the faceset list and its thumbnail.

    Returns the remaining thumbnails so the gallery refreshes.
    """
    global SELECTED_INPUT_FACE_INDEX

    idx = SELECTED_INPUT_FACE_INDEX
    if idx < len(roop.globals.INPUT_FACESETS):
        roop.globals.INPUT_FACESETS.pop(idx)
    if idx < len(ui.globals.ui_input_thumbs):
        ui.globals.ui_input_thumbs.pop(idx)

    return ui.globals.ui_input_thumbs
306
+
def on_select_target_face(evt: gr.SelectData):
    """Remember which target-face thumbnail the user clicked."""
    global SELECTED_TARGET_FACE_INDEX

    SELECTED_TARGET_FACE_INDEX = evt.index

def remove_selected_target_face():
    """Drop the selected target face and its thumbnail; return the rest."""
    idx = SELECTED_TARGET_FACE_INDEX
    if idx < len(roop.globals.TARGET_FACES):
        roop.globals.TARGET_FACES.pop(idx)
    if idx < len(ui.globals.ui_target_thumbs):
        ui.globals.ui_target_thumbs.pop(idx)
    return ui.globals.ui_target_thumbs
320
+
321
+
322
+
323
+
324
+
def on_use_face_from_selected(files, frame_num):
    """Detect faces on the currently previewed target file.

    For images the whole file is scanned; for videos/GIFs only the frame at
    *frame_num*. A single hit is taken over as target face immediately;
    multiple hits open the dynamic face-selection gallery.
    """
    global IS_INPUT, SELECTION_FACES_DATA

    IS_INPUT = False
    thumbs = []

    roop.globals.target_path = files[selected_preview_index].name
    target = roop.globals.target_path
    is_gif = target.lower().endswith(('gif'))

    scanned = False
    if util.is_image(target) and not is_gif:
        SELECTION_FACES_DATA = extract_face_images(target, (False, 0))
        scanned = True
    elif util.is_video(target) or is_gif:
        SELECTION_FACES_DATA = extract_face_images(target, (True, frame_num))
        scanned = True

    if scanned:
        if len(SELECTION_FACES_DATA) > 0:
            for face_data in SELECTION_FACES_DATA:
                thumbs.append(util.convert_to_gradio(face_data[1]))
        else:
            gr.Info('No faces detected!')
            roop.globals.target_path = None

    if len(thumbs) == 1:
        # Exactly one face -> use it directly, no selection dialog needed.
        roop.globals.TARGET_FACES.append(SELECTION_FACES_DATA[0][0])
        ui.globals.ui_target_thumbs.append(thumbs[0])
        return gr.Row(visible=False), None, ui.globals.ui_target_thumbs, gr.Dropdown(value='Selected face')

    return gr.Row(visible=True), thumbs, gr.Gallery(visible=True), gr.Dropdown(visible=True)
359
+
360
+
361
+
def on_select_face(evt: gr.SelectData): # SelectData is a subclass of EventData
    """Remember which face was clicked in the dynamic selection gallery."""
    global SELECTED_FACE_INDEX
    SELECTED_FACE_INDEX = evt.index
365
+
366
+
def on_selected_face():
    """Confirm the face picked in the detection gallery.

    Depending on IS_INPUT the face is appended either to the input facesets
    (wrapped in a fresh FaceSet with default mask offsets) or to the target
    faces; returns the gallery updates wired in faceswap_tab().
    """
    global IS_INPUT, SELECTED_FACE_INDEX, SELECTION_FACES_DATA

    face_data = SELECTION_FACES_DATA[SELECTED_FACE_INDEX]
    face, thumb = face_data[0], face_data[1]
    image = util.convert_to_gradio(thumb)

    if not IS_INPUT:
        roop.globals.TARGET_FACES.append(face)
        ui.globals.ui_target_thumbs.append(image)
        return gr.Gallery(visible=True), ui.globals.ui_target_thumbs, gr.Dropdown(value='Selected face')

    face.mask_offsets = (0,0,0,0,1,20)
    face_set = FaceSet()
    face_set.faces.append(face)
    roop.globals.INPUT_FACESETS.append(face_set)
    ui.globals.ui_input_thumbs.append(image)
    return ui.globals.ui_input_thumbs, gr.Gallery(visible=True), gr.Dropdown(visible=True)
383
+
384
+ # bt_faceselect.click(fn=on_selected_face, outputs=[dynamic_face_selection, face_selection, input_faces, target_faces])
385
+
def on_end_face_selection():
    """Hide the dynamic face-selection row and clear its gallery."""
    return gr.Column(visible=False), None
388
+
389
+
def on_preview_frame_changed(frame_num, files, fake_preview, enhancer, detection, face_distance, blend_ratio,
                                use_clip, clip_text, no_face_action, vr_mode, auto_rotate, maskimage, show_mask):
    """Render the preview image for the selected target file/frame.

    When *fake_preview* is on and input faces exist, the frame is run through
    live_swap() with the current settings; otherwise the raw frame is shown.
    Returns (preview image, mask editor, frame slider with time info).
    """
    global SELECTED_INPUT_FACE_INDEX, manual_masking, current_video_fps

    from roop.core import live_swap

    manual_masking = False
    # Ensure the selected input face carries mask offsets (older facesets may
    # not have the attribute yet).
    mask_offsets = (0,0,0,0)
    if len(roop.globals.INPUT_FACESETS) > SELECTED_INPUT_FACE_INDEX:
        if not hasattr(roop.globals.INPUT_FACESETS[SELECTED_INPUT_FACE_INDEX].faces[0], 'mask_offsets'):
            roop.globals.INPUT_FACESETS[SELECTED_INPUT_FACE_INDEX].faces[0].mask_offsets = mask_offsets
        mask_offsets = roop.globals.INPUT_FACESETS[SELECTED_INPUT_FACE_INDEX].faces[0].mask_offsets

    timeinfo = '0:00:00'
    if files is None or selected_preview_index >= len(files) or frame_num is None:
        return None,None, gr.Slider(info=timeinfo)

    filename = files[selected_preview_index].name
    if util.is_video(filename) or filename.lower().endswith('gif'):
        current_frame = get_video_frame(filename, frame_num)
        if current_video_fps == 0:
            current_video_fps = 1
        # Convert the (1-based) frame number to hh:mm:ss.mmm for the slider.
        secs = (frame_num - 1) / current_video_fps
        minutes = secs / 60
        secs = secs % 60
        hours = minutes / 60
        minutes = minutes % 60
        milliseconds = (secs - int(secs)) * 1000
        timeinfo = f"{int(hours):0>2}:{int(minutes):0>2}:{int(secs):0>2}.{int(milliseconds):0>3}"
    else:
        current_frame = get_image_frame(filename)
    if current_frame is None:
        return None, None, gr.Slider(info=timeinfo)

    # NOTE(review): the previous code extracted maskimage["layers"][0] into an
    # unused local; live_swap receives the whole editor dict, so the dead
    # extraction was removed. Confirm live_swap expects the full dict.
    if not fake_preview or len(roop.globals.INPUT_FACESETS) < 1:
        return gr.Image(value=util.convert_to_gradio(current_frame), visible=True), gr.ImageEditor(visible=False), gr.Slider(info=timeinfo)

    roop.globals.face_swap_mode = translate_swap_mode(detection)
    roop.globals.selected_enhancer = enhancer
    roop.globals.distance_threshold = face_distance
    roop.globals.blend_ratio = blend_ratio
    roop.globals.no_face_action = index_of_no_face_action(no_face_action)
    roop.globals.vr_mode = vr_mode
    roop.globals.autorotate_faces = auto_rotate

    # Fix: original condition `use_clip and clip_text is None or len(clip_text) < 1`
    # evaluated `len(None)` (TypeError) when clip_text was None and use_clip False;
    # parenthesize so the length check only runs when use_clip is requested.
    if use_clip and (clip_text is None or len(clip_text) < 1):
        use_clip = False

    roop.globals.execution_threads = roop.globals.CFG.max_threads
    current_frame = live_swap(current_frame, roop.globals.face_swap_mode, use_clip, clip_text, maskimage, show_mask, SELECTED_INPUT_FACE_INDEX)
    if current_frame is None:
        return gr.Image(visible=True), None, gr.Slider(info=timeinfo)
    return gr.Image(value=util.convert_to_gradio(current_frame), visible=True), gr.ImageEditor(visible=False), gr.Slider(info=timeinfo)
449
+
def on_toggle_masking(previewimage, mask):
    """Switch between the plain preview image and the manual mask editor."""
    global manual_masking

    manual_masking = not manual_masking
    if not manual_masking:
        return gr.Image(visible=True), gr.ImageEditor(visible=False)

    layers = mask["layers"]
    # NOTE(review): seeds a blank drawing layer when exactly one layer exists;
    # confirm whether '== 1' (vs '== 0') is the intended trigger.
    if len(layers) == 1:
        layers = [create_blank_image(previewimage.shape[1], previewimage.shape[0])]
    editor_value = {"background": previewimage, "layers": layers, "composite": None}
    return gr.Image(visible=False), gr.ImageEditor(value=editor_value, visible=True)
460
+
def gen_processing_text(start, end):
    """Return the markdown status line describing the active frame range."""
    return 'Processing frame range [{} - {}]'.format(start, end)
463
+
def on_set_frame(sender:str, frame_num):
    """Set the processing start or end frame of the selected target file.

    Which bound is set depends on the clicking button's label suffix
    ('...start' vs '...end'); bounds are clamped so start <= end.
    """
    global selected_preview_index, list_files_process

    entry = list_files_process[selected_preview_index]
    if entry.endframe == 0:
        return gen_processing_text(0,0)

    if sender.lower().endswith('start'):
        entry.startframe = min(frame_num, entry.endframe)
    else:
        entry.endframe = max(frame_num, entry.startframe)

    return gen_processing_text(entry.startframe, entry.endframe)
479
+
480
+
481
+
def on_preview_mask(frame_num, files, clip_text):
    """Show the CLIP text-mask for the current preview frame.

    Returns the masked frame as a gradio image, or None when nothing can be
    previewed (run in progress, no file selected, missing text/frame).
    """
    from roop.core import preview_mask
    global is_processing

    unavailable = (
        is_processing
        or files is None
        or selected_preview_index >= len(files)
        or clip_text is None
        or frame_num is None
    )
    if unavailable:
        return None

    filename = files[selected_preview_index].name
    if util.is_video(filename) or filename.lower().endswith('gif'):
        frame = get_video_frame(filename, frame_num)
    else:
        frame = get_image_frame(filename)
    if frame is None:
        return None

    return util.convert_to_gradio(preview_mask(frame, clip_text))
499
+
500
+
def on_clear_input_faces():
    """Remove every input face and thumbnail; return the (empty) gallery."""
    roop.globals.INPUT_FACESETS.clear()
    ui.globals.ui_input_thumbs.clear()
    return ui.globals.ui_input_thumbs

def on_clear_destfiles():
    """Remove every target face and thumbnail; return the (empty) gallery."""
    ui.globals.ui_target_thumbs.clear()
    roop.globals.TARGET_FACES.clear()
    return ui.globals.ui_target_thumbs
510
+
511
+
def index_of_no_face_action(dropdown_text):
    """Translate the 'no face detected' dropdown label into its list index.

    Raises ValueError for labels not present in no_face_choices.
    """
    global no_face_choices

    return no_face_choices.index(dropdown_text)
516
+
def translate_swap_mode(dropdown_text):
    """Map the face-selection dropdown label to the internal swap-mode id.

    Unknown labels (e.g. "All faces") fall back to "all".
    """
    mode_by_label = {
        "Selected face": "selected",
        "First found": "first",
        "All female": "all_female",
        "All male": "all_male",
    }
    return mode_by_label.get(dropdown_text, "all")
528
+
529
+
530
+
def start_swap( enhancer, detection, keep_frames, wait_after_extraction, skip_audio, face_distance, blend_ratio,
                use_clip, clip_text, processing_method, no_face_action, vr_mode, autorotate, imagemask, progress=gr.Progress()):
    """Run the batch face-swap over all queued target files.

    Generator handler for the Start button: first yield disables Start /
    enables Stop, final yield restores the buttons and publishes the files
    found in the output folder (Stop cancels the generator).
    """
    from ui.main import prepare_environment
    from roop.core import batch_process
    global is_processing, list_files_process

    if list_files_process is None or len(list_files_process) <= 0:
        return gr.Button(variant="primary"), None, None

    if roop.globals.CFG.clear_output:
        shutil.rmtree(roop.globals.output_path)

    # Video work needs ffmpeg - warn but keep going (images still work).
    if not util.is_installed("ffmpeg"):
        msg = "ffmpeg is not installed! No video processing possible."
        gr.Warning(msg)

    prepare_environment()

    # Push the UI settings into the shared globals used by the pipeline.
    roop.globals.selected_enhancer = enhancer
    roop.globals.target_path = None
    roop.globals.distance_threshold = face_distance
    roop.globals.blend_ratio = blend_ratio
    roop.globals.keep_frames = keep_frames
    roop.globals.wait_after_extraction = wait_after_extraction
    roop.globals.skip_audio = skip_audio
    roop.globals.face_swap_mode = translate_swap_mode(detection)
    roop.globals.no_face_action = index_of_no_face_action(no_face_action)
    roop.globals.vr_mode = vr_mode
    roop.globals.autorotate_faces = autorotate
    # Fix: the original `use_clip and clip_text is None or len(clip_text) < 1`
    # evaluated `len(None)` (TypeError) when clip_text was None and use_clip
    # False; parenthesize so the length check only runs when CLIP is requested.
    if use_clip and (clip_text is None or len(clip_text) < 1):
        use_clip = False

    if roop.globals.face_swap_mode == 'selected':
        if len(roop.globals.TARGET_FACES) < 1:
            gr.Error('No Target Face selected!')
            return gr.Button(variant="primary"), None, None

    is_processing = True
    # Flip the buttons before the long-running work starts.
    yield gr.Button(variant="secondary", interactive=False), gr.Button(variant="primary", interactive=True), None
    roop.globals.execution_threads = roop.globals.CFG.max_threads
    roop.globals.video_encoder = roop.globals.CFG.output_video_codec
    roop.globals.video_quality = roop.globals.CFG.video_quality
    roop.globals.max_memory = roop.globals.CFG.memory_limit if roop.globals.CFG.memory_limit > 0 else None

    batch_process(list_files_process, use_clip, clip_text, processing_method == "In-Memory processing", imagemask, progress, SELECTED_INPUT_FACE_INDEX)
    is_processing = False
    # Publish everything the run produced in the output folder.
    outdir = pathlib.Path(roop.globals.output_path)
    outfiles = [str(item) for item in outdir.rglob("*") if item.is_file()]
    if len(outfiles) > 0:
        yield gr.Button(variant="primary", interactive=True), gr.Button(variant="secondary", interactive=False), gr.Files(value=outfiles)
    else:
        yield gr.Button(variant="primary", interactive=True), gr.Button(variant="secondary", interactive=False), None
583
+
584
+
def stop_swap():
    """Signal the worker threads to abort and restore the Start/Stop buttons."""
    roop.globals.processing = False
    gr.Info('Aborting processing - please wait for the remaining threads to be stopped')
    return gr.Button(variant="primary", interactive=True),gr.Button(variant="secondary", interactive=False),None
589
+
590
+
def on_fps_changed(fps):
    """Store a forced FPS override on the currently selected target file."""
    global selected_preview_index, list_files_process

    if not list_files_process:
        return
    entry = list_files_process[selected_preview_index]
    if entry.endframe < 1:
        return
    entry.fps = fps
597
+
598
+
def on_destfiles_changed(destfiles):
    """Rebuild the processing queue from the target-files widget.

    Resets the preview to the first file, detects its frame count/FPS and
    updates the frame slider plus the processing-range text.
    """
    global selected_preview_index, list_files_process, current_video_fps

    # Fix: the change event always delivers the COMPLETE current file list,
    # so the queue must be rebuilt from scratch - appending without clearing
    # (as before) accumulated duplicate/stale ProcessEntry items on every
    # change event.
    list_files_process.clear()

    if destfiles is None or len(destfiles) < 1:
        return gr.Slider(value=1, maximum=1, info='0:00:00'), ''

    for f in destfiles:
        list_files_process.append(ProcessEntry(f.name, 0,0, 0))

    selected_preview_index = 0
    idx = selected_preview_index

    filename = list_files_process[idx].filename

    if util.is_video(filename) or filename.lower().endswith('gif'):
        total_frames = get_video_frame_total(filename)
        current_video_fps = util.detect_fps(filename)
    else:
        total_frames = 1
    list_files_process[idx].endframe = total_frames
    if total_frames > 1:
        return gr.Slider(value=1, maximum=total_frames, info='0:00:00'), gen_processing_text(list_files_process[idx].startframe,list_files_process[idx].endframe)
    return gr.Slider(value=1, maximum=total_frames, info='0:00:00'), ''
623
+
624
+
625
+
626
+
def on_destfiles_selected(evt: gr.SelectData):
    """Switch the preview to the clicked target file.

    Detects total frames and FPS for videos/GIFs (initializing the entry's
    endframe if unset) and returns slider, range text and the stored FPS.
    """
    global selected_preview_index, list_files_process, current_video_fps

    if evt is not None:
        selected_preview_index = evt.index
    entry = list_files_process[selected_preview_index]
    filename = entry.filename
    fps = entry.fps

    if util.is_video(filename) or filename.lower().endswith('gif'):
        total_frames = get_video_frame_total(filename)
        current_video_fps = util.detect_fps(filename)
        if entry.endframe == 0:
            entry.endframe = total_frames
    else:
        total_frames = 1

    if total_frames > 1:
        return gr.Slider(value=entry.startframe, maximum=total_frames, info='0:00:00'), gen_processing_text(entry.startframe, entry.endframe), fps
    return gr.Slider(value=1, maximum=total_frames, info='0:00:00'), gen_processing_text(0,0), fps
646
+
647
+
648
+
def on_resultfiles_selected(evt: gr.SelectData, files):
    """Show the clicked output file in the result image/video widget."""
    return display_output(files[evt.index].name)

def on_resultfiles_finished(files):
    """After a finished run, show the first produced output file (if any)."""
    if files is None or len(files) < 1:
        return None, None
    return display_output(files[0].name)
661
+
662
+
def display_output(filename):
    """Route an output file into the result widgets.

    Videos go to the video player when configured; otherwise a single frame
    (first frame for videos/GIFs, the image itself otherwise) is shown.
    """
    if util.is_video(filename) and roop.globals.CFG.output_show_video:
        return gr.Image(visible=False), gr.Video(visible=True, value=filename)

    if util.is_video(filename) or filename.lower().endswith('gif'):
        frame = get_video_frame(filename)
    else:
        frame = get_image_frame(filename)
    return gr.Image(visible=True, value=util.convert_to_gradio(frame)), gr.Video(visible=False)
ui/tabs/livecam_tab.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import roop.globals
3
+ import ui.globals
4
+
5
+
6
+ camera_frame = None
7
+
def livecam_tab():
    """Build the '🎥 Live Cam' tab: webcam face swap with optional
    forwarding to a virtual camera device."""
    with gr.Tab("🎥 Live Cam"):
        with gr.Row(variant='panel'):
            gr.Markdown("""
                This feature will allow you to use your physical webcam and apply the selected faces to the stream. 
                You can also forward the stream to a virtual camera, which can be used in video calls or streaming software.<br />
                Supported are: v4l2loopback (linux), OBS Virtual Camera (macOS/Windows) and unitycapture (Windows).<br />
                **Please note:** to change the face or any other settings you need to stop and restart a running live cam.
        """)

        with gr.Row(variant='panel'):
            with gr.Column():
                bt_start = gr.Button("▶ Start", variant='primary')
            with gr.Column():
                bt_stop = gr.Button("⏹ Stop", variant='secondary', interactive=False)
            with gr.Column():
                camera_num = gr.Slider(0, 2, value=0, label="Camera Number", step=1.0, interactive=True)
                cb_obs = gr.Checkbox(label="Forward stream to virtual camera", interactive=True)
            with gr.Column():
                dd_reso = gr.Dropdown(choices=["640x480","1280x720", "1920x1080"], value="1280x720", label="Fake Camera Resolution", interactive=True)

        with gr.Row():
            fake_cam_image = gr.Image(label='Fake Camera Output', interactive=False)

        # Start runs the endless start_cam generator; Stop cancels it.
        start_event = bt_start.click(fn=start_cam, inputs=[cb_obs, camera_num, dd_reso, ui.globals.ui_selected_enhancer, ui.globals.ui_blend_ratio],outputs=[bt_start, bt_stop,fake_cam_image])
        bt_stop.click(fn=stop_swap, cancels=[start_event], outputs=[bt_start, bt_stop], queue=False)
34
+
35
+
def start_cam(stream_to_obs, cam, reso, enhancer, blend_ratio):
    """Start the virtual-camera worker and stream swapped frames to the UI.

    Runs as an endless generator: each yield disables Start, enables Stop
    and pushes the latest frame from ui.globals.ui_camera_frame; the loop
    ends only when the Stop button cancels this event.
    """
    from roop.virtualcam import start_virtual_cam
    from roop.utilities import convert_to_gradio

    start_virtual_cam(stream_to_obs, cam, reso)
    roop.globals.selected_enhancer = enhancer
    roop.globals.blend_ratio = blend_ratio
    while True:
        yield gr.Button(interactive=False), gr.Button(interactive=True), convert_to_gradio(ui.globals.ui_camera_frame)
45
+
46
+
def stop_swap():
    """Stop the virtual-camera worker and re-enable the Start button."""
    from roop.virtualcam import stop_virtual_cam
    stop_virtual_cam()
    return gr.Button(interactive=True), gr.Button(interactive=False)
51
+
52
+
53
+
54
+
ui/tabs/settings_tab.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import shutil
2
+ import os
3
+ import gradio as gr
4
+ import roop.globals
5
+ import ui.globals
6
+
7
# Gradio themes offered in the Theme dropdown ("Default" = gradio's builtin theme).
available_themes = ["Default", "gradio/glass", "gradio/monochrome", "gradio/seafoam", "gradio/soft", "gstaff/xkcd", "freddyaboulton/dracula_revamped", "ysharma/steampunk"]
# Output formats / ffmpeg codecs offered in the settings dropdowns.
image_formats = ['jpg','png', 'webp']
video_formats = ['avi','mkv', 'mp4', 'webm']
video_codecs = ['libx264', 'libx265', 'libvpx-vp9', 'h264_nvenc', 'hevc_nvenc']
# Populated in settings_tab() from roop.core.suggest_execution_providers().
providerlist = None

# Controls whose elem_id names a roop.globals.CFG attribute; each is wired
# to on_settings_changed() in settings_tab().
settings_controls = []
14
+
15
def settings_tab():
    """Build the "Settings" tab and wire its event handlers.

    Most checkboxes/dropdowns are appended to the module-level
    ``settings_controls`` list; their ``elem_id`` doubles as the name of
    the ``roop.globals.CFG`` attribute they update via
    ``on_settings_changed()``. Theme/server fields are applied explicitly
    through the "Apply Settings" button instead.
    """
    from roop.core import suggest_execution_providers
    global providerlist

    providerlist = suggest_execution_providers()
    with gr.Tab("⚙ Settings"):
        with gr.Row():
            with gr.Column():
                themes = gr.Dropdown(available_themes, label="Theme", info="Change needs complete restart", value=roop.globals.CFG.selected_theme)
            with gr.Column():
                settings_controls.append(gr.Checkbox(label="Public Server", value=roop.globals.CFG.server_share, elem_id='server_share', interactive=True))
                settings_controls.append(gr.Checkbox(label='Clear output folder before each run', value=roop.globals.CFG.clear_output, elem_id='clear_output', interactive=True))
                output_template = gr.Textbox(label="Filename Output Template", info="(file extension is added automatically)", lines=1, placeholder='{file}_{time}', value=roop.globals.CFG.output_template)
            with gr.Column():
                input_server_name = gr.Textbox(label="Server Name", lines=1, info="Leave blank to run locally", value=roop.globals.CFG.server_name)
            with gr.Column():
                input_server_port = gr.Number(label="Server Port", precision=0, info="Leave at 0 to use default", value=roop.globals.CFG.server_port)
        with gr.Row():
            with gr.Column():
                settings_controls.append(gr.Dropdown(providerlist, label="Provider", value=roop.globals.CFG.provider, elem_id='provider', interactive=True))
                # NOTE: elem_id 'default_det_size' targets roop.globals (not CFG),
                # hence the separate on_option_changed handler below.
                chk_det_size = gr.Checkbox(label="Use default Det-Size", value=True, elem_id='default_det_size', interactive=True)
                settings_controls.append(gr.Checkbox(label="Force CPU for Face Analyser", value=roop.globals.CFG.force_cpu, elem_id='force_cpu', interactive=True))
                max_threads = gr.Slider(1, 32, value=roop.globals.CFG.max_threads, label="Max. Number of Threads", info='default: 3', step=1.0, interactive=True)
            with gr.Column():
                memory_limit = gr.Slider(0, 128, value=roop.globals.CFG.memory_limit, label="Max. Memory to use (Gb)", info='0 meaning no limit', step=1.0, interactive=True)
                settings_controls.append(gr.Dropdown(image_formats, label="Image Output Format", info='default: png', value=roop.globals.CFG.output_image_format, elem_id='output_image_format', interactive=True))
            with gr.Column():
                settings_controls.append(gr.Dropdown(video_codecs, label="Video Codec", info='default: libx264', value=roop.globals.CFG.output_video_codec, elem_id='output_video_codec', interactive=True))
                settings_controls.append(gr.Dropdown(video_formats, label="Video Output Format", info='default: mp4', value=roop.globals.CFG.output_video_format, elem_id='output_video_format', interactive=True))
                video_quality = gr.Slider(0, 100, value=roop.globals.CFG.video_quality, label="Video Quality (crf)", info='default: 14', step=1.0, interactive=True)
            with gr.Column():
                with gr.Group():
                    settings_controls.append(gr.Checkbox(label='Use OS temp folder', value=roop.globals.CFG.use_os_temp_folder, elem_id='use_os_temp_folder', interactive=True))
                    settings_controls.append(gr.Checkbox(label='Show video in browser (re-encodes output)', value=roop.globals.CFG.output_show_video, elem_id='output_show_video', interactive=True))
                button_apply_restart = gr.Button("Restart Server", variant='primary')
                button_clean_temp = gr.Button("Clean temp folder")
                button_apply_settings = gr.Button("Apply Settings")

    chk_det_size.select(fn=on_option_changed)

    # Settings: generic elem_id -> CFG attribute wiring for all registered controls.
    for s in settings_controls:
        s.select(fn=on_settings_changed)
    # Sliders fire 'input' (not 'select'), so the CFG attribute name is
    # bound explicitly as a lambda default argument.
    max_threads.input(fn=lambda a,b='max_threads':on_settings_changed_misc(a,b), inputs=[max_threads])
    memory_limit.input(fn=lambda a,b='memory_limit':on_settings_changed_misc(a,b), inputs=[memory_limit])
    video_quality.input(fn=lambda a,b='video_quality':on_settings_changed_misc(a,b), inputs=[video_quality])

    # button_clean_temp.click(fn=clean_temp, outputs=[bt_srcfiles, input_faces, target_faces, bt_destfiles])
    button_clean_temp.click(fn=clean_temp)
    button_apply_settings.click(apply_settings, inputs=[themes, input_server_name, input_server_port, output_template])
    button_apply_restart.click(restart)
66
+
67
+
68
def on_option_changed(evt: gr.SelectData):
    """Mirror a toggled/selected control into the matching roop.globals attribute.

    The control's ``elem_id`` names the attribute to update. Raises
    gr.Error for control types (or unknown attributes) not handled here.
    """
    attribname = evt.target.elem_id
    # Checkboxes report their state via evt.selected, dropdowns via evt.value.
    if isinstance(evt.target, gr.Checkbox) and hasattr(roop.globals, attribname):
        setattr(roop.globals, attribname, evt.selected)
        return
    if isinstance(evt.target, gr.Dropdown) and hasattr(roop.globals, attribname):
        setattr(roop.globals, attribname, evt.value)
        return
    raise gr.Error(f'Unhandled Setting for {evt.target}')
79
+
80
+
81
def on_settings_changed_misc(new_val, attribname):
    """Store a slider value into ``roop.globals.CFG.<attribname>``.

    :param new_val: new value coming from the slider's input event
    :param attribname: name of the CFG attribute to set
    """
    cfg = roop.globals.CFG
    if not hasattr(cfg, attribname):
        print("Didn't find attrib!")
        return
    setattr(cfg, attribname, new_val)
86
+
87
+
88
+
89
def on_settings_changed(evt: gr.SelectData):
    """Persist a changed checkbox/dropdown into the matching CFG attribute.

    Counterpart of on_option_changed(), but targeting roop.globals.CFG
    (the saved configuration) instead of runtime roop.globals state.
    """
    attribname = evt.target.elem_id
    cfg = roop.globals.CFG
    # Checkboxes carry their state in evt.selected, dropdowns in evt.value.
    if isinstance(evt.target, gr.Checkbox) and hasattr(cfg, attribname):
        setattr(cfg, attribname, evt.selected)
        return
    if isinstance(evt.target, gr.Dropdown) and hasattr(cfg, attribname):
        setattr(cfg, attribname, evt.value)
        return

    raise gr.Error(f'Unhandled Setting for {evt.target}')
101
+
102
def clean_temp():
    """Remove temp files, recreate the working folders and reset cached faces.

    Returns four Nones (matching the commented-out gallery outputs in
    settings_tab) so the galleries are cleared if ever wired up again.
    """
    from ui.main import prepare_environment

    if not roop.globals.CFG.use_os_temp_folder:
        # os.environ["TEMP"] is only reliably set on Windows; use .get() so a
        # missing variable doesn't raise KeyError on other platforms.
        temp_dir = os.environ.get("TEMP")
        if temp_dir:
            # Best-effort delete: the OS temp dir is shared, so some entries
            # may be locked by other processes — don't abort the cleanup.
            shutil.rmtree(temp_dir, ignore_errors=True)
    prepare_environment()

    ui.globals.ui_input_thumbs.clear()
    roop.globals.INPUT_FACESETS.clear()
    roop.globals.TARGET_FACES.clear()
    # clear() in place (instead of rebinding to a new list) so any other
    # references to the thumbs list stay valid — consistent with
    # ui_input_thumbs above.
    ui.globals.ui_target_thumbs.clear()
    gr.Info('Temp Files removed')
    return None, None, None, None
115
+
116
+
117
def apply_settings(themes, input_server_name, input_server_port, output_template):
    """Copy the theme/server/template settings from the UI into CFG and save.

    :param themes: selected theme name from the Theme dropdown
    :param input_server_name: server host name ("" = run locally)
    :param input_server_port: server port (0 = gradio default)
    :param output_template: filename template for output files
    """
    from ui.main import show_msg

    cfg = roop.globals.CFG
    cfg.selected_theme = themes
    cfg.server_name = input_server_name
    cfg.server_port = input_server_port
    cfg.output_template = output_template
    cfg.save()
    show_msg('Settings saved')
126
+
127
+
128
def restart():
    # Request a server restart by raising this flag; it is presumably polled
    # by the main UI loop in ui/main.py (not visible here) — confirm there.
    ui.globals.ui_restart_server = True