root commited on
Commit
2cfffe0
·
1 Parent(s): aab5d4e
Files changed (1) hide show
  1. roop/utilities.py +51 -249
roop/utilities.py CHANGED
@@ -5,90 +5,64 @@ import platform
5
  import shutil
6
  import ssl
7
  import subprocess
8
- import sys
9
  import urllib
10
- import torch
11
- import gradio
12
- import tempfile
13
- import cv2
14
- import zipfile
15
- import traceback
16
-
17
  from pathlib import Path
18
  from typing import List, Any
19
  from tqdm import tqdm
20
- from scipy.spatial import distance
21
-
22
- import roop.template_parser as template_parser
23
 
24
  import roop.globals
25
 
26
- TEMP_FILE = "temp.mp4"
27
- TEMP_DIRECTORY = "temp"
28
 
29
  # monkey patch ssl for mac
30
- if platform.system().lower() == "darwin":
31
  ssl._create_default_https_context = ssl._create_unverified_context
32
 
33
 
34
- # https://github.com/facefusion/facefusion/blob/master/facefusion
 
 
 
 
 
 
 
 
 
 
35
  def detect_fps(target_path: str) -> float:
36
- fps = 24.0
37
- cap = cv2.VideoCapture(target_path)
38
- if cap.isOpened():
39
- fps = cap.get(cv2.CAP_PROP_FPS)
40
- cap.release()
41
- return fps
42
-
43
-
44
- # Gradio wants Images in RGB
45
- def convert_to_gradio(image):
46
- if image is None:
47
- return None
48
- return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
49
-
50
-
51
- def sort_filenames_ignore_path(filenames):
52
- """Sorts a list of filenames containing a complete path by their filename,
53
- while retaining their original path.
54
-
55
- Args:
56
- filenames: A list of filenames containing a complete path.
57
-
58
- Returns:
59
- A sorted list of filenames containing a complete path.
60
- """
61
- filename_path_tuples = [
62
- (os.path.split(filename)[1], filename) for filename in filenames
63
- ]
64
- sorted_filename_path_tuples = sorted(filename_path_tuples, key=lambda x: x[0])
65
- return [
66
- filename_path_tuple[1] for filename_path_tuple in sorted_filename_path_tuples
67
- ]
68
-
69
-
70
- def sort_rename_frames(path: str):
71
- filenames = os.listdir(path)
72
- filenames.sort()
73
- for i in range(len(filenames)):
74
- of = os.path.join(path, filenames[i])
75
- newidx = i + 1
76
- new_filename = os.path.join(
77
- path, f"{newidx:06d}." + roop.globals.CFG.output_image_format
78
- )
79
- os.rename(of, new_filename)
80
 
81
 
82
  def get_temp_frame_paths(target_path: str) -> List[str]:
83
  temp_directory_path = get_temp_directory_path(target_path)
84
- return glob.glob(
85
- (
86
- os.path.join(
87
- glob.escape(temp_directory_path),
88
- f"*.{roop.globals.CFG.output_image_format}",
89
- )
90
- )
91
- )
92
 
93
 
94
  def get_temp_directory_path(target_path: str) -> str:
@@ -107,35 +81,10 @@ def normalize_output_path(source_path: str, target_path: str, output_path: str)
107
  source_name, _ = os.path.splitext(os.path.basename(source_path))
108
  target_name, target_extension = os.path.splitext(os.path.basename(target_path))
109
  if os.path.isdir(output_path):
110
- return os.path.join(
111
- output_path, source_name + "-" + target_name + target_extension
112
- )
113
  return output_path
114
 
115
 
116
- def get_destfilename_from_path(
117
- srcfilepath: str, destfilepath: str, extension: str
118
- ) -> str:
119
- fn, ext = os.path.splitext(os.path.basename(srcfilepath))
120
- if "." in extension:
121
- return os.path.join(destfilepath, f"{fn}{extension}")
122
- return os.path.join(destfilepath, f"{fn}{extension}{ext}")
123
-
124
-
125
- def replace_template(file_path: str, index: int = 0) -> str:
126
- fn, ext = os.path.splitext(os.path.basename(file_path))
127
-
128
- # Remove the "__temp" placeholder that was used as a temporary filename
129
- fn = fn.replace("__temp", "")
130
-
131
- template = roop.globals.CFG.output_template
132
- replaced_filename = template_parser.parse(
133
- template, {"index": str(index), "file": fn}
134
- )
135
-
136
- return os.path.join(roop.globals.output_path, f"{replaced_filename}{ext}")
137
-
138
-
139
  def create_temp(target_path: str) -> None:
140
  temp_directory_path = get_temp_directory_path(target_path)
141
  Path(temp_directory_path).mkdir(parents=True, exist_ok=True)
@@ -158,30 +107,21 @@ def clean_temp(target_path: str) -> None:
158
  os.rmdir(parent_directory_path)
159
 
160
 
161
- def delete_temp_frames(filename: str) -> None:
162
- dir = os.path.dirname(os.path.dirname(filename))
163
- shutil.rmtree(dir)
164
-
165
-
166
  def has_image_extension(image_path: str) -> bool:
167
- return image_path.lower().endswith(("png", "jpg", "jpeg", "webp"))
168
-
169
-
170
- def has_extension(filepath: str, extensions: List[str]) -> bool:
171
- return filepath.lower().endswith(tuple(extensions))
172
 
173
 
174
  def is_image(image_path: str) -> bool:
175
  if image_path and os.path.isfile(image_path):
176
  mimetype, _ = mimetypes.guess_type(image_path)
177
- return bool(mimetype and mimetype.startswith("image/"))
178
  return False
179
 
180
 
181
  def is_video(video_path: str) -> bool:
182
  if video_path and os.path.isfile(video_path):
183
  mimetype, _ = mimetypes.guess_type(video_path)
184
- return bool(mimetype and mimetype.startswith("video/"))
185
  return False
186
 
187
 
@@ -189,151 +129,13 @@ def conditional_download(download_directory_path: str, urls: List[str]) -> None:
189
  if not os.path.exists(download_directory_path):
190
  os.makedirs(download_directory_path)
191
  for url in urls:
192
- download_file_path = os.path.join(
193
- download_directory_path, os.path.basename(url)
194
- )
195
  if not os.path.exists(download_file_path):
196
- request = urllib.request.urlopen(url) # type: ignore[attr-defined]
197
- total = int(request.headers.get("Content-Length", 0))
198
- with tqdm(
199
- total=total,
200
- desc=f"Downloading {url}",
201
- unit="B",
202
- unit_scale=True,
203
- unit_divisor=1024,
204
- ) as progress:
205
- urllib.request.urlretrieve(url, download_file_path, reporthook=lambda count, block_size, total_size: progress.update(block_size)) # type: ignore[attr-defined]
206
-
207
-
208
- def get_local_files_from_folder(folder: str) -> List[str]:
209
- if not os.path.exists(folder) or not os.path.isdir(folder):
210
- return None
211
- files = [
212
- os.path.join(folder, f)
213
- for f in os.listdir(folder)
214
- if os.path.isfile(os.path.join(folder, f))
215
- ]
216
- return files
217
 
218
 
219
  def resolve_relative_path(path: str) -> str:
220
- return os.path.abspath(os.path.join(os.path.dirname(__file__), path))
221
-
222
-
223
- def get_device() -> str:
224
- if len(roop.globals.execution_providers) < 1:
225
- roop.globals.execution_providers = ["CPUExecutionProvider"]
226
-
227
- prov = roop.globals.execution_providers[0]
228
- if "CoreMLExecutionProvider" in prov:
229
- return "mps"
230
- if "CUDAExecutionProvider" in prov or "ROCMExecutionProvider" in prov:
231
- return "cuda"
232
- if "OpenVINOExecutionProvider" in prov:
233
- return "mkl"
234
- return "cpu"
235
-
236
-
237
- def str_to_class(module_name, class_name) -> Any:
238
- from importlib import import_module
239
-
240
- class_ = None
241
- try:
242
- module_ = import_module(module_name)
243
- try:
244
- class_ = getattr(module_, class_name)()
245
- except AttributeError:
246
- print(f"Class {class_name} does not exist")
247
- except ImportError:
248
- print(f"Module {module_name} does not exist")
249
- return class_
250
-
251
- def is_installed(name:str) -> bool:
252
- return shutil.which(name);
253
-
254
- # Taken from https://stackoverflow.com/a/68842705
255
- def get_platform() -> str:
256
- if sys.platform == "linux":
257
- try:
258
- proc_version = open("/proc/version").read()
259
- if "Microsoft" in proc_version:
260
- return "wsl"
261
- except:
262
- pass
263
- return sys.platform
264
-
265
- def open_with_default_app(filename:str):
266
- if filename == None:
267
- return
268
- platform = get_platform()
269
- if platform == "darwin":
270
- subprocess.call(("open", filename))
271
- elif platform in ["win64", "win32"]: os.startfile(filename.replace("/", "\\"))
272
- elif platform == "wsl":
273
- subprocess.call("cmd.exe /C start".split() + [filename])
274
- else: # linux variants
275
- subprocess.call("xdg-open", filename)
276
-
277
-
278
- def prepare_for_batch(target_files) -> str:
279
- print("Preparing temp files")
280
- tempfolder = os.path.join(tempfile.gettempdir(), "rooptmp")
281
- if os.path.exists(tempfolder):
282
- shutil.rmtree(tempfolder)
283
- Path(tempfolder).mkdir(parents=True, exist_ok=True)
284
- for f in target_files:
285
- newname = os.path.basename(f.name)
286
- shutil.move(f.name, os.path.join(tempfolder, newname))
287
- return tempfolder
288
-
289
-
290
- def zip(files, zipname):
291
- with zipfile.ZipFile(zipname, "w") as zip_file:
292
- for f in files:
293
- zip_file.write(f, os.path.basename(f))
294
-
295
-
296
- def unzip(zipfilename: str, target_path: str):
297
- with zipfile.ZipFile(zipfilename, "r") as zip_file:
298
- zip_file.extractall(target_path)
299
-
300
-
301
- def mkdir_with_umask(directory):
302
- oldmask = os.umask(0)
303
- # mode needs octal
304
- os.makedirs(directory, mode=0o775, exist_ok=True)
305
- os.umask(oldmask)
306
-
307
-
308
- def open_folder(path: str):
309
- platform = get_platform()
310
- try:
311
- if platform == "darwin":
312
- subprocess.call(("open", path))
313
- elif platform in ["win64", "win32"]:
314
- open_with_default_app(path)
315
- elif platform == "wsl":
316
- subprocess.call("cmd.exe /C start".split() + [path])
317
- else: # linux variants
318
- subprocess.Popen(["xdg-open", path])
319
- except Exception as e:
320
- traceback.print_exc()
321
- pass
322
- # import webbrowser
323
- # webbrowser.open(url)
324
-
325
-
326
- def create_version_html() -> str:
327
- python_version = ".".join([str(x) for x in sys.version_info[0:3]])
328
- versions_html = f"""
329
- python: <span title="{sys.version}">{python_version}</span>
330
-
331
- torch: {getattr(torch, '__long_version__',torch.__version__)}
332
-
333
- gradio: {gradio.__version__}
334
- """
335
- return versions_html
336
-
337
-
338
- def compute_cosine_distance(emb1, emb2) -> float:
339
- return distance.cosine(emb1, emb2)
 
5
  import shutil
6
  import ssl
7
  import subprocess
 
8
  import urllib
 
 
 
 
 
 
 
9
  from pathlib import Path
10
  from typing import List, Any
11
  from tqdm import tqdm
 
 
 
12
 
13
  import roop.globals
14
 
15
+ TEMP_FILE = 'temp.mp4'
16
+ TEMP_DIRECTORY = 'temp'
17
 
18
  # monkey patch ssl for mac
19
+ if platform.system().lower() == 'darwin':
20
  ssl._create_default_https_context = ssl._create_unverified_context
21
 
22
 
23
+ def run_ffmpeg(args: List[str]) -> bool:
24
+ commands = ['ffmpeg', '-hide_banner', '-hwaccel', 'auto', '-loglevel', roop.globals.log_level]
25
+ commands.extend(args)
26
+ try:
27
+ subprocess.check_output(commands, stderr=subprocess.STDOUT)
28
+ return True
29
+ except Exception:
30
+ pass
31
+ return False
32
+
33
+
34
  def detect_fps(target_path: str) -> float:
35
+ command = ['ffprobe', '-v', 'error', '-select_streams', 'v:0', '-show_entries', 'stream=r_frame_rate', '-of', 'default=noprint_wrappers=1:nokey=1', target_path]
36
+ output = subprocess.check_output(command).decode().strip().split('/')
37
+ try:
38
+ numerator, denominator = map(int, output)
39
+ return numerator / denominator
40
+ except Exception:
41
+ pass
42
+ return 30.0
43
+
44
+
45
+ def extract_frames(target_path: str) -> None:
46
+ temp_directory_path = get_temp_directory_path(target_path)
47
+ run_ffmpeg(['-i', target_path, '-pix_fmt', 'rgb24', os.path.join(temp_directory_path, '%04d.png')])
48
+
49
+
50
+ def create_video(target_path: str, fps: float = 30.0) -> None:
51
+ temp_output_path = get_temp_output_path(target_path)
52
+ temp_directory_path = get_temp_directory_path(target_path)
53
+ run_ffmpeg(['-r', str(fps), '-i', os.path.join(temp_directory_path, '%04d.png'), '-c:v', roop.globals.video_encoder, '-crf', str(roop.globals.video_quality), '-pix_fmt', 'yuv420p', '-vf', 'colorspace=bt709:iall=bt601-6-625:fast=1', '-y', temp_output_path])
54
+
55
+
56
+ def restore_audio(target_path: str, output_path: str) -> None:
57
+ temp_output_path = get_temp_output_path(target_path)
58
+ done = run_ffmpeg(['-i', temp_output_path, '-i', target_path, '-c:v', 'copy', '-map', '0:v:0', '-map', '1:a:0', '-y', output_path])
59
+ if not done:
60
+ move_temp(target_path, output_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
 
63
  def get_temp_frame_paths(target_path: str) -> List[str]:
64
  temp_directory_path = get_temp_directory_path(target_path)
65
+ return glob.glob((os.path.join(glob.escape(temp_directory_path), '*.png')))
 
 
 
 
 
 
 
66
 
67
 
68
  def get_temp_directory_path(target_path: str) -> str:
 
81
  source_name, _ = os.path.splitext(os.path.basename(source_path))
82
  target_name, target_extension = os.path.splitext(os.path.basename(target_path))
83
  if os.path.isdir(output_path):
84
+ return os.path.join(output_path, source_name + '-' + target_name + target_extension)
 
 
85
  return output_path
86
 
87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  def create_temp(target_path: str) -> None:
89
  temp_directory_path = get_temp_directory_path(target_path)
90
  Path(temp_directory_path).mkdir(parents=True, exist_ok=True)
 
107
  os.rmdir(parent_directory_path)
108
 
109
 
 
 
 
 
 
110
  def has_image_extension(image_path: str) -> bool:
111
+ return image_path.lower().endswith(('png', 'jpg', 'jpeg', 'webp'))
 
 
 
 
112
 
113
 
114
  def is_image(image_path: str) -> bool:
115
  if image_path and os.path.isfile(image_path):
116
  mimetype, _ = mimetypes.guess_type(image_path)
117
+ return bool(mimetype and mimetype.startswith('image/'))
118
  return False
119
 
120
 
121
  def is_video(video_path: str) -> bool:
122
  if video_path and os.path.isfile(video_path):
123
  mimetype, _ = mimetypes.guess_type(video_path)
124
+ return bool(mimetype and mimetype.startswith('video/'))
125
  return False
126
 
127
 
 
129
  if not os.path.exists(download_directory_path):
130
  os.makedirs(download_directory_path)
131
  for url in urls:
132
+ download_file_path = os.path.join(download_directory_path, os.path.basename(url))
 
 
133
  if not os.path.exists(download_file_path):
134
+ request = urllib.request.urlopen(url) # type: ignore[attr-defined]
135
+ total = int(request.headers.get('Content-Length', 0))
136
+ with tqdm(total=total, desc='Downloading', unit='B', unit_scale=True, unit_divisor=1024) as progress:
137
+ urllib.request.urlretrieve(url, download_file_path, reporthook=lambda count, block_size, total_size: progress.update(block_size)) # type: ignore[attr-defined]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
 
140
  def resolve_relative_path(path: str) -> str:
141
+ return os.path.abspath(os.path.join(os.path.dirname(__file__), path))