Apex-X commited on
Commit
81c1343
·
verified ·
1 Parent(s): 8fe0d35

Update roop/core.py

Browse files
Files changed (1) hide show
  1. roop/core.py +120 -145
roop/core.py CHANGED
@@ -1,38 +1,26 @@
1
  #!/usr/bin/env python3
 
2
  import os
3
  import sys
 
 
 
 
 
4
  import warnings
5
  from typing import List
6
  import platform
7
  import signal
8
  import shutil
9
  import argparse
10
- import torch
11
  import onnxruntime
12
  import tensorflow
13
-
14
  import roop.globals
15
  import roop.metadata
16
  import roop.ui as ui
17
  from roop.predictor import predict_image, predict_video
18
  from roop.processors.frame.core import get_frame_processors_modules
19
- from roop.utilities import (
20
- has_image_extension, is_image, is_video,
21
- detect_fps, create_video, extract_frames,
22
- get_temp_frame_paths, restore_audio,
23
- create_temp, move_temp, clean_temp,
24
- normalize_output_path
25
- )
26
-
27
- # single thread doubles cuda performance - needs to be set before torch import
28
- if any(arg.startswith('--execution-provider') for arg in sys.argv):
29
- os.environ['OMP_NUM_THREADS'] = '1'
30
-
31
- # reduce tensorflow log level
32
- os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
33
-
34
- if 'ROCMExecutionProvider' in roop.globals.execution_providers:
35
- del torch
36
 
37
  warnings.filterwarnings('ignore', category=FutureWarning, module='insightface')
38
  warnings.filterwarnings('ignore', category=UserWarning, module='torchvision')
@@ -45,13 +33,18 @@ def parse_args() -> None:
45
  program.add_argument('-t', '--target', help='select an target image or video', dest='target_path')
46
  program.add_argument('-o', '--output', help='select output file or directory', dest='output_path')
47
  program.add_argument('--frame-processor', help='frame processors (choices: face_swapper, face_enhancer, ...)', dest='frame_processor', default=['face_swapper'], nargs='+')
48
- program.add_argument('--keep-fps', help='keep original fps', dest='keep_fps', action='store_true', default=False)
49
- program.add_argument('--keep-audio', help='keep original audio', dest='keep_audio', action='store_true', default=True)
50
- program.add_argument('--keep-frames', help='keep temporary frames', dest='keep_frames', action='store_true', default=False)
51
- program.add_argument('--many-faces', help='process every face', dest='many_faces', action='store_true', default=False)
52
- program.add_argument('--video-encoder', help='adjust output video encoder', dest='video_encoder', default='libx264', choices=['libx264', 'libx265', 'libvpx-vp9'])
53
- program.add_argument('--video-quality', help='adjust output video quality', dest='video_quality', type=int, default=18, choices=range(52), metavar='[0-51]')
54
- program.add_argument('--max-memory', help='maximum amount of RAM in GB', dest='max_memory', type=int, default=suggest_max_memory())
 
 
 
 
 
55
  program.add_argument('--execution-provider', help='available execution provider (choices: cpu, ...)', dest='execution_provider', default=['cpu'], choices=suggest_execution_providers(), nargs='+')
56
  program.add_argument('--execution-threads', help='number of execution threads', dest='execution_threads', type=int, default=suggest_execution_threads())
57
  program.add_argument('-v', '--version', action='version', version=f'{roop.metadata.name} {roop.metadata.version}')
@@ -61,14 +54,19 @@ def parse_args() -> None:
61
  roop.globals.source_path = args.source_path
62
  roop.globals.target_path = args.target_path
63
  roop.globals.output_path = normalize_output_path(roop.globals.source_path, roop.globals.target_path, args.output_path)
 
64
  roop.globals.frame_processors = args.frame_processor
65
- roop.globals.headless = args.source_path or args.target_path or args.output_path
66
  roop.globals.keep_fps = args.keep_fps
67
- roop.globals.keep_audio = args.keep_audio
68
  roop.globals.keep_frames = args.keep_frames
 
69
  roop.globals.many_faces = args.many_faces
70
- roop.globals.video_encoder = args.video_encoder
71
- roop.globals.video_quality = args.video_quality
 
 
 
 
 
72
  roop.globals.max_memory = args.max_memory
73
  roop.globals.execution_providers = decode_execution_providers(args.execution_provider)
74
  roop.globals.execution_threads = args.execution_threads
@@ -79,16 +77,8 @@ def encode_execution_providers(execution_providers: List[str]) -> List[str]:
79
 
80
 
81
  def decode_execution_providers(execution_providers: List[str]) -> List[str]:
82
- return [provider for provider, encoded_execution_provider in zip(
83
- onnxruntime.get_available_providers(),
84
- encode_execution_providers(onnxruntime.get_available_providers())
85
- ) if any(execution_provider in encoded_execution_provider for execution_provider in execution_providers)]
86
-
87
-
88
- def suggest_max_memory() -> int:
89
- if platform.system().lower() == 'darwin':
90
- return 4
91
- return 16
92
 
93
 
94
  def suggest_execution_providers() -> List[str]:
@@ -96,47 +86,30 @@ def suggest_execution_providers() -> List[str]:
96
 
97
 
98
  def suggest_execution_threads() -> int:
99
- if 'DmlExecutionProvider' in roop.globals.execution_providers:
100
- return 1
101
- if 'ROCMExecutionProvider' in roop.globals.execution_providers:
102
- return 1
103
- return 8
104
 
105
 
106
  def limit_resources() -> None:
107
- try:
108
- gpus = tensorflow.config.experimental.list_physical_devices('GPU')
109
- for gpu in gpus:
110
- tensorflow.config.experimental.set_virtual_device_configuration(gpu, [
111
- tensorflow.config.experimental.VirtualDeviceConfiguration(memory_limit=1024)
112
- ])
113
- except Exception as e:
114
- update_status(f'Failed to limit TensorFlow GPU memory: {e}')
115
-
116
  if roop.globals.max_memory:
117
- memory = roop.globals.max_memory * 1024 ** 3 # GB to bytes
118
- system_name = platform.system().lower()
119
- if system_name == 'darwin':
120
- # Fixed multiplication from 1024**6 to 1024**3 for GB to bytes
121
- memory = roop.globals.max_memory * 1024 ** 3
122
- if system_name == 'windows':
123
- try:
124
- import ctypes
125
- kernel32 = ctypes.windll.kernel32
126
- kernel32.SetProcessWorkingSetSize(-1, ctypes.c_size_t(memory), ctypes.c_size_t(memory))
127
- except Exception as e:
128
- update_status(f'Failed to limit memory on Windows: {e}')
129
  else:
130
- try:
131
- import resource
132
- resource.setrlimit(resource.RLIMIT_DATA, (memory, memory))
133
- except Exception as e:
134
- update_status(f'Failed to limit memory on POSIX system: {e}')
135
-
136
-
137
- def release_resources() -> None:
138
- if 'CUDAExecutionProvider' in roop.globals.execution_providers:
139
- torch.cuda.empty_cache()
140
 
141
 
142
  def pre_check() -> bool:
@@ -156,90 +129,92 @@ def update_status(message: str, scope: str = 'ROOP.CORE') -> None:
156
 
157
 
158
  def start() -> None:
159
- try:
160
- for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
161
- if not frame_processor.pre_start():
162
- return
163
- if has_image_extension(roop.globals.target_path):
164
- if predict_image(roop.globals.target_path):
165
- destroy()
166
- shutil.copy2(roop.globals.target_path, roop.globals.output_path)
167
- for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
168
- update_status('Progressing...', frame_processor.NAME)
169
- frame_processor.process_image(roop.globals.source_path, roop.globals.output_path, roop.globals.output_path)
170
- frame_processor.post_process()
171
- release_resources()
172
- if is_image(roop.globals.target_path):
173
- update_status('Processing to image succeed!')
174
- else:
175
- update_status('Processing to image failed!')
176
  return
177
-
178
- if predict_video(roop.globals.target_path):
 
179
  destroy()
180
- update_status('Creating temp resources...')
181
- create_temp(roop.globals.target_path)
182
-
183
- update_status('Extracting frames...')
184
- if roop.globals.temp_frame_quality is None:
185
- roop.globals.temp_frame_quality = 75 # Default value if not set
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  extract_frames(roop.globals.target_path)
187
-
188
- temp_frame_paths = get_temp_frame_paths(roop.globals.target_path)
 
189
  for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
190
  update_status('Progressing...', frame_processor.NAME)
191
  frame_processor.process_video(roop.globals.source_path, temp_frame_paths)
192
  frame_processor.post_process()
193
- release_resources()
194
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  if roop.globals.keep_fps:
196
- update_status('Detecting fps...')
197
- fps = detect_fps(roop.globals.target_path)
198
- update_status(f'Creating video with {fps} fps...')
199
- create_video(roop.globals.target_path, fps)
200
- else:
201
- update_status('Creating video with 30.0 fps...')
202
- create_video(roop.globals.target_path)
203
-
204
- if roop.globals.keep_audio:
205
- if roop.globals.keep_fps:
206
- update_status('Restoring audio...')
207
- else:
208
- update_status('Restoring audio might cause issues as fps are not kept...')
209
- restore_audio(roop.globals.target_path, roop.globals.output_path)
210
  else:
211
- move_temp(roop.globals.target_path, roop.globals.output_path)
212
-
213
- clean_temp(roop.globals.target_path)
214
- if is_video(roop.globals.target_path):
215
- update_status('Processing to video succeed!')
216
- else:
217
- update_status('Processing to video failed!')
218
- except Exception as e:
219
- update_status(f'Error during processing: {e}')
220
- destroy()
221
 
222
 
223
  def destroy() -> None:
224
  if roop.globals.target_path:
225
  clean_temp(roop.globals.target_path)
226
- quit()
227
 
228
 
229
  def run() -> None:
230
- try:
231
- parse_args()
232
- if not pre_check():
 
 
233
  return
234
- for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
235
- if not frame_processor.pre_check():
236
- return
237
- limit_resources()
238
- if roop.globals.headless:
239
- start()
240
- else:
241
- window = ui.init(start, destroy)
242
- window.mainloop()
243
- except Exception as e:
244
- update_status(f'Fatal error: {e}')
245
- destroy()
 
1
  #!/usr/bin/env python3
2
+
3
  import os
4
  import sys
5
+ # single thread doubles cuda performance - needs to be set before torch import
6
+ if any(arg.startswith('--execution-provider') for arg in sys.argv):
7
+ os.environ['OMP_NUM_THREADS'] = '1'
8
+ # reduce tensorflow log level
9
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
10
  import warnings
11
  from typing import List
12
  import platform
13
  import signal
14
  import shutil
15
  import argparse
 
16
  import onnxruntime
17
  import tensorflow
 
18
  import roop.globals
19
  import roop.metadata
20
  import roop.ui as ui
21
  from roop.predictor import predict_image, predict_video
22
  from roop.processors.frame.core import get_frame_processors_modules
23
+ from roop.utilities import has_image_extension, is_image, is_video, detect_fps, create_video, extract_frames, get_temp_frame_paths, restore_audio, create_temp, move_temp, clean_temp, normalize_output_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  warnings.filterwarnings('ignore', category=FutureWarning, module='insightface')
26
  warnings.filterwarnings('ignore', category=UserWarning, module='torchvision')
 
33
  program.add_argument('-t', '--target', help='select an target image or video', dest='target_path')
34
  program.add_argument('-o', '--output', help='select output file or directory', dest='output_path')
35
  program.add_argument('--frame-processor', help='frame processors (choices: face_swapper, face_enhancer, ...)', dest='frame_processor', default=['face_swapper'], nargs='+')
36
+ program.add_argument('--keep-fps', help='keep target fps', dest='keep_fps', action='store_true')
37
+ program.add_argument('--keep-frames', help='keep temporary frames', dest='keep_frames', action='store_true')
38
+ program.add_argument('--skip-audio', help='skip target audio', dest='skip_audio', action='store_true')
39
+ program.add_argument('--many-faces', help='process every face', dest='many_faces', action='store_true')
40
+ program.add_argument('--reference-face-position', help='position of the reference face', dest='reference_face_position', type=int, default=0)
41
+ program.add_argument('--reference-frame-number', help='number of the reference frame', dest='reference_frame_number', type=int, default=0)
42
+ program.add_argument('--similar-face-distance', help='face distance used for recognition', dest='similar_face_distance', type=float, default=0.85)
43
+ program.add_argument('--temp-frame-format', help='image format used for frame extraction', dest='temp_frame_format', default='png', choices=['jpg', 'png'])
44
+ program.add_argument('--temp-frame-quality', help='image quality used for frame extraction', dest='temp_frame_quality', type=int, default=0, choices=range(101), metavar='[0-100]')
45
+ program.add_argument('--output-video-encoder', help='encoder used for the output video', dest='output_video_encoder', default='libx264', choices=['libx264', 'libx265', 'libvpx-vp9', 'h264_nvenc', 'hevc_nvenc'])
46
+ program.add_argument('--output-video-quality', help='quality used for the output video', dest='output_video_quality', type=int, default=35, choices=range(101), metavar='[0-100]')
47
+ program.add_argument('--max-memory', help='maximum amount of RAM in GB', dest='max_memory', type=int)
48
  program.add_argument('--execution-provider', help='available execution provider (choices: cpu, ...)', dest='execution_provider', default=['cpu'], choices=suggest_execution_providers(), nargs='+')
49
  program.add_argument('--execution-threads', help='number of execution threads', dest='execution_threads', type=int, default=suggest_execution_threads())
50
  program.add_argument('-v', '--version', action='version', version=f'{roop.metadata.name} {roop.metadata.version}')
 
54
  roop.globals.source_path = args.source_path
55
  roop.globals.target_path = args.target_path
56
  roop.globals.output_path = normalize_output_path(roop.globals.source_path, roop.globals.target_path, args.output_path)
57
+ roop.globals.headless = roop.globals.source_path is not None and roop.globals.target_path is not None and roop.globals.output_path is not None
58
  roop.globals.frame_processors = args.frame_processor
 
59
  roop.globals.keep_fps = args.keep_fps
 
60
  roop.globals.keep_frames = args.keep_frames
61
+ roop.globals.skip_audio = args.skip_audio
62
  roop.globals.many_faces = args.many_faces
63
+ roop.globals.reference_face_position = args.reference_face_position
64
+ roop.globals.reference_frame_number = args.reference_frame_number
65
+ roop.globals.similar_face_distance = args.similar_face_distance
66
+ roop.globals.temp_frame_format = args.temp_frame_format
67
+ roop.globals.temp_frame_quality = args.temp_frame_quality
68
+ roop.globals.output_video_encoder = args.output_video_encoder
69
+ roop.globals.output_video_quality = args.output_video_quality
70
  roop.globals.max_memory = args.max_memory
71
  roop.globals.execution_providers = decode_execution_providers(args.execution_provider)
72
  roop.globals.execution_threads = args.execution_threads
 
77
 
78
 
79
  def decode_execution_providers(execution_providers: List[str]) -> List[str]:
80
+ return [provider for provider, encoded_execution_provider in zip(onnxruntime.get_available_providers(), encode_execution_providers(onnxruntime.get_available_providers()))
81
+ if any(execution_provider in encoded_execution_provider for execution_provider in execution_providers)]
 
 
 
 
 
 
 
 
82
 
83
 
84
  def suggest_execution_providers() -> List[str]:
 
86
 
87
 
88
  def suggest_execution_threads() -> int:
89
+ if 'CUDAExecutionProvider' in onnxruntime.get_available_providers():
90
+ return 8
91
+ return 1
 
 
92
 
93
 
94
  def limit_resources() -> None:
95
+ # prevent tensorflow memory leak
96
+ gpus = tensorflow.config.experimental.list_physical_devices('GPU')
97
+ for gpu in gpus:
98
+ tensorflow.config.experimental.set_virtual_device_configuration(gpu, [
99
+ tensorflow.config.experimental.VirtualDeviceConfiguration(memory_limit=1024)
100
+ ])
101
+ # limit memory usage
 
 
102
  if roop.globals.max_memory:
103
+ memory = roop.globals.max_memory * 1024 ** 3
104
+ if platform.system().lower() == 'darwin':
105
+ memory = roop.globals.max_memory * 1024 ** 6
106
+ if platform.system().lower() == 'windows':
107
+ import ctypes
108
+ kernel32 = ctypes.windll.kernel32 # type: ignore[attr-defined]
109
+ kernel32.SetProcessWorkingSetSize(-1, ctypes.c_size_t(memory), ctypes.c_size_t(memory))
 
 
 
 
 
110
  else:
111
+ import resource
112
+ resource.setrlimit(resource.RLIMIT_DATA, (memory, memory))
 
 
 
 
 
 
 
 
113
 
114
 
115
  def pre_check() -> bool:
 
129
 
130
 
131
  def start() -> None:
132
+ for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
133
+ if not frame_processor.pre_start():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  return
135
+ # process image to image
136
+ if has_image_extension(roop.globals.target_path):
137
+ if predict_image(roop.globals.target_path):
138
  destroy()
139
+ shutil.copy2(roop.globals.target_path, roop.globals.output_path)
140
+ # process frame
141
+ for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
142
+ update_status('Progressing...', frame_processor.NAME)
143
+ frame_processor.process_image(roop.globals.source_path, roop.globals.output_path, roop.globals.output_path)
144
+ frame_processor.post_process()
145
+ # validate image
146
+ if is_image(roop.globals.target_path):
147
+ update_status('Processing to image succeed!')
148
+ else:
149
+ update_status('Processing to image failed!')
150
+ return
151
+ # process image to videos
152
+ if predict_video(roop.globals.target_path):
153
+ destroy()
154
+ update_status('Creating temporary resources...')
155
+ create_temp(roop.globals.target_path)
156
+ # extract frames
157
+ if roop.globals.keep_fps:
158
+ fps = detect_fps(roop.globals.target_path)
159
+ update_status(f'Extracting frames with {fps} FPS...')
160
+ extract_frames(roop.globals.target_path, fps)
161
+ else:
162
+ update_status('Extracting frames with 30 FPS...')
163
  extract_frames(roop.globals.target_path)
164
+ # process frame
165
+ temp_frame_paths = get_temp_frame_paths(roop.globals.target_path)
166
+ if temp_frame_paths:
167
  for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
168
  update_status('Progressing...', frame_processor.NAME)
169
  frame_processor.process_video(roop.globals.source_path, temp_frame_paths)
170
  frame_processor.post_process()
171
+ else:
172
+ update_status('Frames not found...')
173
+ return
174
+ # create video
175
+ if roop.globals.keep_fps:
176
+ fps = detect_fps(roop.globals.target_path)
177
+ update_status(f'Creating video with {fps} FPS...')
178
+ create_video(roop.globals.target_path, fps)
179
+ else:
180
+ update_status('Creating video with 30 FPS...')
181
+ create_video(roop.globals.target_path)
182
+ # handle audio
183
+ if roop.globals.skip_audio:
184
+ move_temp(roop.globals.target_path, roop.globals.output_path)
185
+ update_status('Skipping audio...')
186
+ else:
187
  if roop.globals.keep_fps:
188
+ update_status('Restoring audio...')
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  else:
190
+ update_status('Restoring audio might cause issues as fps are not kept...')
191
+ restore_audio(roop.globals.target_path, roop.globals.output_path)
192
+ # clean temp
193
+ update_status('Cleaning temporary resources...')
194
+ clean_temp(roop.globals.target_path)
195
+ # validate video
196
+ if is_video(roop.globals.target_path):
197
+ update_status('Processing to video succeed!')
198
+ else:
199
+ update_status('Processing to video failed!')
200
 
201
 
202
  def destroy() -> None:
203
  if roop.globals.target_path:
204
  clean_temp(roop.globals.target_path)
205
+ sys.exit()
206
 
207
 
208
  def run() -> None:
209
+ parse_args()
210
+ if not pre_check():
211
+ return
212
+ for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
213
+ if not frame_processor.pre_check():
214
  return
215
+ limit_resources()
216
+ if roop.globals.headless:
217
+ start()
218
+ else:
219
+ window = ui.init(start, destroy)
220
+ window.mainloop()