Apex-X committed on
Commit
46cef6f
·
verified ·
1 Parent(s): ca333e3

Update roop/core.py

Browse files
Files changed (1) hide show
  1. roop/core.py +58 -52
roop/core.py CHANGED
@@ -1,4 +1,5 @@
1
  #!/usr/bin/env python3
 
2
  import os
3
  import sys
4
  # single thread doubles cuda performance - needs to be set before torch import
@@ -12,10 +13,8 @@ import platform
12
  import signal
13
  import shutil
14
  import argparse
15
- import torch
16
  import onnxruntime
17
  import tensorflow
18
-
19
  import roop.globals
20
  import roop.metadata
21
  import roop.ui as ui
@@ -23,9 +22,6 @@ from roop.predictor import predict_image, predict_video
23
  from roop.processors.frame.core import get_frame_processors_modules
24
  from roop.utilities import has_image_extension, is_image, is_video, detect_fps, create_video, extract_frames, get_temp_frame_paths, restore_audio, create_temp, move_temp, clean_temp, normalize_output_path
25
 
26
- if 'ROCMExecutionProvider' in roop.globals.execution_providers:
27
- del torch
28
-
29
  warnings.filterwarnings('ignore', category=FutureWarning, module='insightface')
30
  warnings.filterwarnings('ignore', category=UserWarning, module='torchvision')
31
 
@@ -37,13 +33,18 @@ def parse_args() -> None:
37
  program.add_argument('-t', '--target', help='select an target image or video', dest='target_path')
38
  program.add_argument('-o', '--output', help='select output file or directory', dest='output_path')
39
  program.add_argument('--frame-processor', help='frame processors (choices: face_swapper, face_enhancer, ...)', dest='frame_processor', default=['face_swapper'], nargs='+')
40
- program.add_argument('--keep-fps', help='keep original fps', dest='keep_fps', action='store_true', default=False)
41
- program.add_argument('--keep-audio', help='keep original audio', dest='keep_audio', action='store_true', default=True)
42
- program.add_argument('--keep-frames', help='keep temporary frames', dest='keep_frames', action='store_true', default=False)
43
- program.add_argument('--many-faces', help='process every face', dest='many_faces', action='store_true', default=False)
44
- program.add_argument('--video-encoder', help='adjust output video encoder', dest='video_encoder', default='libx264', choices=['libx264', 'libx265', 'libvpx-vp9'])
45
- program.add_argument('--video-quality', help='adjust output video quality', dest='video_quality', type=int, default=18, choices=range(52), metavar='[0-51]')
46
- program.add_argument('--max-memory', help='maximum amount of RAM in GB', dest='max_memory', type=int, default=suggest_max_memory())
 
 
 
 
 
47
  program.add_argument('--execution-provider', help='available execution provider (choices: cpu, ...)', dest='execution_provider', default=['cpu'], choices=suggest_execution_providers(), nargs='+')
48
  program.add_argument('--execution-threads', help='number of execution threads', dest='execution_threads', type=int, default=suggest_execution_threads())
49
  program.add_argument('-v', '--version', action='version', version=f'{roop.metadata.name} {roop.metadata.version}')
@@ -53,14 +54,19 @@ def parse_args() -> None:
53
  roop.globals.source_path = args.source_path
54
  roop.globals.target_path = args.target_path
55
  roop.globals.output_path = normalize_output_path(roop.globals.source_path, roop.globals.target_path, args.output_path)
 
56
  roop.globals.frame_processors = args.frame_processor
57
- roop.globals.headless = args.source_path or args.target_path or args.output_path
58
  roop.globals.keep_fps = args.keep_fps
59
- roop.globals.keep_audio = args.keep_audio
60
  roop.globals.keep_frames = args.keep_frames
 
61
  roop.globals.many_faces = args.many_faces
62
- roop.globals.video_encoder = args.video_encoder
63
- roop.globals.video_quality = args.video_quality
 
 
 
 
 
64
  roop.globals.max_memory = args.max_memory
65
  roop.globals.execution_providers = decode_execution_providers(args.execution_provider)
66
  roop.globals.execution_threads = args.execution_threads
@@ -75,22 +81,14 @@ def decode_execution_providers(execution_providers: List[str]) -> List[str]:
75
  if any(execution_provider in encoded_execution_provider for execution_provider in execution_providers)]
76
 
77
 
78
- def suggest_max_memory() -> int:
79
- if platform.system().lower() == 'darwin':
80
- return 4
81
- return 16
82
-
83
-
84
  def suggest_execution_providers() -> List[str]:
85
  return encode_execution_providers(onnxruntime.get_available_providers())
86
 
87
 
88
  def suggest_execution_threads() -> int:
89
- if 'DmlExecutionProvider' in roop.globals.execution_providers:
90
- return 1
91
- if 'ROCMExecutionProvider' in roop.globals.execution_providers:
92
- return 1
93
- return 8
94
 
95
 
96
  def limit_resources() -> None:
@@ -107,18 +105,13 @@ def limit_resources() -> None:
107
  memory = roop.globals.max_memory * 1024 ** 6
108
  if platform.system().lower() == 'windows':
109
  import ctypes
110
- kernel32 = ctypes.windll.kernel32
111
  kernel32.SetProcessWorkingSetSize(-1, ctypes.c_size_t(memory), ctypes.c_size_t(memory))
112
  else:
113
  import resource
114
  resource.setrlimit(resource.RLIMIT_DATA, (memory, memory))
115
 
116
 
117
- def release_resources() -> None:
118
- if 'CUDAExecutionProvider' in roop.globals.execution_providers:
119
- torch.cuda.empty_cache()
120
-
121
-
122
  def pre_check() -> bool:
123
  if sys.version_info < (3, 9):
124
  update_status('Python version is not supported - please upgrade to 3.9 or higher.')
@@ -144,11 +137,12 @@ def start() -> None:
144
  if predict_image(roop.globals.target_path):
145
  destroy()
146
  shutil.copy2(roop.globals.target_path, roop.globals.output_path)
 
147
  for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
148
  update_status('Progressing...', frame_processor.NAME)
149
  frame_processor.process_image(roop.globals.source_path, roop.globals.output_path, roop.globals.output_path)
150
  frame_processor.post_process()
151
- release_resources()
152
  if is_image(roop.globals.target_path):
153
  update_status('Processing to image succeed!')
154
  else:
@@ -157,36 +151,48 @@ def start() -> None:
157
  # process image to videos
158
  if predict_video(roop.globals.target_path):
159
  destroy()
160
- update_status('Creating temp resources...')
161
  create_temp(roop.globals.target_path)
162
- update_status('Extracting frames...')
163
- extract_frames(roop.globals.target_path)
 
 
 
 
 
 
 
164
  temp_frame_paths = get_temp_frame_paths(roop.globals.target_path)
165
- for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
166
- update_status('Progressing...', frame_processor.NAME)
167
- frame_processor.process_video(roop.globals.source_path, temp_frame_paths)
168
- frame_processor.post_process()
169
- release_resources()
170
- # handles fps
 
 
 
171
  if roop.globals.keep_fps:
172
- update_status('Detecting fps...')
173
  fps = detect_fps(roop.globals.target_path)
174
- update_status(f'Creating video with {fps} fps...')
175
  create_video(roop.globals.target_path, fps)
176
  else:
177
- update_status('Creating video with 30.0 fps...')
178
  create_video(roop.globals.target_path)
179
  # handle audio
180
- if roop.globals.keep_audio:
 
 
 
181
  if roop.globals.keep_fps:
182
  update_status('Restoring audio...')
183
  else:
184
  update_status('Restoring audio might cause issues as fps are not kept...')
185
  restore_audio(roop.globals.target_path, roop.globals.output_path)
186
- else:
187
- move_temp(roop.globals.target_path, roop.globals.output_path)
188
- # clean and validate
189
  clean_temp(roop.globals.target_path)
 
190
  if is_video(roop.globals.target_path):
191
  update_status('Processing to video succeed!')
192
  else:
@@ -196,7 +202,7 @@ def start() -> None:
196
  def destroy() -> None:
197
  if roop.globals.target_path:
198
  clean_temp(roop.globals.target_path)
199
- quit()
200
 
201
 
202
  def run() -> None:
@@ -211,4 +217,4 @@ def run() -> None:
211
  start()
212
  else:
213
  window = ui.init(start, destroy)
214
- window.mainloop()
 
1
  #!/usr/bin/env python3
2
+
3
  import os
4
  import sys
5
  # single thread doubles cuda performance - needs to be set before torch import
 
13
  import signal
14
  import shutil
15
  import argparse
 
16
  import onnxruntime
17
  import tensorflow
 
18
  import roop.globals
19
  import roop.metadata
20
  import roop.ui as ui
 
22
  from roop.processors.frame.core import get_frame_processors_modules
23
  from roop.utilities import has_image_extension, is_image, is_video, detect_fps, create_video, extract_frames, get_temp_frame_paths, restore_audio, create_temp, move_temp, clean_temp, normalize_output_path
24
 
 
 
 
25
  warnings.filterwarnings('ignore', category=FutureWarning, module='insightface')
26
  warnings.filterwarnings('ignore', category=UserWarning, module='torchvision')
27
 
 
33
  program.add_argument('-t', '--target', help='select an target image or video', dest='target_path')
34
  program.add_argument('-o', '--output', help='select output file or directory', dest='output_path')
35
  program.add_argument('--frame-processor', help='frame processors (choices: face_swapper, face_enhancer, ...)', dest='frame_processor', default=['face_swapper'], nargs='+')
36
+ program.add_argument('--keep-fps', help='keep target fps', dest='keep_fps', action='store_true')
37
+ program.add_argument('--keep-frames', help='keep temporary frames', dest='keep_frames', action='store_true')
38
+ program.add_argument('--skip-audio', help='skip target audio', dest='skip_audio', action='store_true')
39
+ program.add_argument('--many-faces', help='process every face', dest='many_faces', action='store_true')
40
+ program.add_argument('--reference-face-position', help='position of the reference face', dest='reference_face_position', type=int, default=0)
41
+ program.add_argument('--reference-frame-number', help='number of the reference frame', dest='reference_frame_number', type=int, default=0)
42
+ program.add_argument('--similar-face-distance', help='face distance used for recognition', dest='similar_face_distance', type=float, default=0.85)
43
+ program.add_argument('--temp-frame-format', help='image format used for frame extraction', dest='temp_frame_format', default='png', choices=['jpg', 'png'])
44
+ program.add_argument('--temp-frame-quality', help='image quality used for frame extraction', dest='temp_frame_quality', type=int, default=0, choices=range(101), metavar='[0-100]')
45
+ program.add_argument('--output-video-encoder', help='encoder used for the output video', dest='output_video_encoder', default='libx264', choices=['libx264', 'libx265', 'libvpx-vp9', 'h264_nvenc', 'hevc_nvenc'])
46
+ program.add_argument('--output-video-quality', help='quality used for the output video', dest='output_video_quality', type=int, default=35, choices=range(101), metavar='[0-100]')
47
+ program.add_argument('--max-memory', help='maximum amount of RAM in GB', dest='max_memory', type=int)
48
  program.add_argument('--execution-provider', help='available execution provider (choices: cpu, ...)', dest='execution_provider', default=['cpu'], choices=suggest_execution_providers(), nargs='+')
49
  program.add_argument('--execution-threads', help='number of execution threads', dest='execution_threads', type=int, default=suggest_execution_threads())
50
  program.add_argument('-v', '--version', action='version', version=f'{roop.metadata.name} {roop.metadata.version}')
 
54
  roop.globals.source_path = args.source_path
55
  roop.globals.target_path = args.target_path
56
  roop.globals.output_path = normalize_output_path(roop.globals.source_path, roop.globals.target_path, args.output_path)
57
+ roop.globals.headless = roop.globals.source_path is not None and roop.globals.target_path is not None and roop.globals.output_path is not None
58
  roop.globals.frame_processors = args.frame_processor
 
59
  roop.globals.keep_fps = args.keep_fps
 
60
  roop.globals.keep_frames = args.keep_frames
61
+ roop.globals.skip_audio = args.skip_audio
62
  roop.globals.many_faces = args.many_faces
63
+ roop.globals.reference_face_position = args.reference_face_position
64
+ roop.globals.reference_frame_number = args.reference_frame_number
65
+ roop.globals.similar_face_distance = args.similar_face_distance
66
+ roop.globals.temp_frame_format = args.temp_frame_format
67
+ roop.globals.temp_frame_quality = args.temp_frame_quality
68
+ roop.globals.output_video_encoder = args.output_video_encoder
69
+ roop.globals.output_video_quality = args.output_video_quality
70
  roop.globals.max_memory = args.max_memory
71
  roop.globals.execution_providers = decode_execution_providers(args.execution_provider)
72
  roop.globals.execution_threads = args.execution_threads
 
81
  if any(execution_provider in encoded_execution_provider for execution_provider in execution_providers)]
82
 
83
 
 
 
 
 
 
 
84
  def suggest_execution_providers() -> List[str]:
85
  return encode_execution_providers(onnxruntime.get_available_providers())
86
 
87
 
88
  def suggest_execution_threads() -> int:
89
+ if 'CUDAExecutionProvider' in onnxruntime.get_available_providers():
90
+ return 8
91
+ return 1
 
 
92
 
93
 
94
  def limit_resources() -> None:
 
105
  memory = roop.globals.max_memory * 1024 ** 6
106
  if platform.system().lower() == 'windows':
107
  import ctypes
108
+ kernel32 = ctypes.windll.kernel32 # type: ignore[attr-defined]
109
  kernel32.SetProcessWorkingSetSize(-1, ctypes.c_size_t(memory), ctypes.c_size_t(memory))
110
  else:
111
  import resource
112
  resource.setrlimit(resource.RLIMIT_DATA, (memory, memory))
113
 
114
 
 
 
 
 
 
115
  def pre_check() -> bool:
116
  if sys.version_info < (3, 9):
117
  update_status('Python version is not supported - please upgrade to 3.9 or higher.')
 
137
  if predict_image(roop.globals.target_path):
138
  destroy()
139
  shutil.copy2(roop.globals.target_path, roop.globals.output_path)
140
+ # process frame
141
  for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
142
  update_status('Progressing...', frame_processor.NAME)
143
  frame_processor.process_image(roop.globals.source_path, roop.globals.output_path, roop.globals.output_path)
144
  frame_processor.post_process()
145
+ # validate image
146
  if is_image(roop.globals.target_path):
147
  update_status('Processing to image succeed!')
148
  else:
 
151
  # process image to videos
152
  if predict_video(roop.globals.target_path):
153
  destroy()
154
+ update_status('Creating temporary resources...')
155
  create_temp(roop.globals.target_path)
156
+ # extract frames
157
+ if roop.globals.keep_fps:
158
+ fps = detect_fps(roop.globals.target_path)
159
+ update_status(f'Extracting frames with {fps} FPS...')
160
+ extract_frames(roop.globals.target_path, fps)
161
+ else:
162
+ update_status('Extracting frames with 30 FPS...')
163
+ extract_frames(roop.globals.target_path)
164
+ # process frame
165
  temp_frame_paths = get_temp_frame_paths(roop.globals.target_path)
166
+ if temp_frame_paths:
167
+ for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
168
+ update_status('Progressing...', frame_processor.NAME)
169
+ frame_processor.process_video(roop.globals.source_path, temp_frame_paths)
170
+ frame_processor.post_process()
171
+ else:
172
+ update_status('Frames not found...')
173
+ return
174
+ # create video
175
  if roop.globals.keep_fps:
 
176
  fps = detect_fps(roop.globals.target_path)
177
+ update_status(f'Creating video with {fps} FPS...')
178
  create_video(roop.globals.target_path, fps)
179
  else:
180
+ update_status('Creating video with 30 FPS...')
181
  create_video(roop.globals.target_path)
182
  # handle audio
183
+ if roop.globals.skip_audio:
184
+ move_temp(roop.globals.target_path, roop.globals.output_path)
185
+ update_status('Skipping audio...')
186
+ else:
187
  if roop.globals.keep_fps:
188
  update_status('Restoring audio...')
189
  else:
190
  update_status('Restoring audio might cause issues as fps are not kept...')
191
  restore_audio(roop.globals.target_path, roop.globals.output_path)
192
+ # clean temp
193
+ update_status('Cleaning temporary resources...')
 
194
  clean_temp(roop.globals.target_path)
195
+ # validate video
196
  if is_video(roop.globals.target_path):
197
  update_status('Processing to video succeed!')
198
  else:
 
202
  def destroy() -> None:
203
  if roop.globals.target_path:
204
  clean_temp(roop.globals.target_path)
205
+ sys.exit()
206
 
207
 
208
  def run() -> None:
 
217
  start()
218
  else:
219
  window = ui.init(start, destroy)
220
+ window.mainloop()