Apex-X commited on
Commit
9d84e0e
·
verified ·
1 Parent(s): 3d2113a

Update roop/utilities.py

Browse files
Files changed (1) hide show
  1. roop/utilities.py +14 -22
roop/utilities.py CHANGED
@@ -7,13 +7,13 @@ import ssl
7
  import subprocess
8
  import urllib
9
  from pathlib import Path
10
- from typing import List, Optional
11
  from tqdm import tqdm
12
 
13
  import roop.globals
14
 
 
15
  TEMP_DIRECTORY = 'temp'
16
- TEMP_VIDEO_FILE = 'temp.mp4'
17
 
18
  # monkey patch ssl for mac
19
  if platform.system().lower() == 'darwin':
@@ -21,7 +21,7 @@ if platform.system().lower() == 'darwin':
21
 
22
 
23
  def run_ffmpeg(args: List[str]) -> bool:
24
- commands = ['ffmpeg', '-hide_banner', '-loglevel', roop.globals.log_level]
25
  commands.extend(args)
26
  try:
27
  subprocess.check_output(commands, stderr=subprocess.STDOUT)
@@ -39,26 +39,18 @@ def detect_fps(target_path: str) -> float:
39
  return numerator / denominator
40
  except Exception:
41
  pass
42
- return 30
43
 
44
 
45
- def extract_frames(target_path: str, fps: float = 30) -> bool:
46
  temp_directory_path = get_temp_directory_path(target_path)
47
- temp_frame_quality = roop.globals.temp_frame_quality * 31 // 100
48
- return run_ffmpeg(['-hwaccel', 'auto', '-i', target_path, '-q:v', str(temp_frame_quality), '-pix_fmt', 'rgb24', '-vf', 'fps=' + str(fps), os.path.join(temp_directory_path, '%04d.' + roop.globals.temp_frame_format)])
49
 
50
 
51
- def create_video(target_path: str, fps: float = 30) -> bool:
52
  temp_output_path = get_temp_output_path(target_path)
53
  temp_directory_path = get_temp_directory_path(target_path)
54
- output_video_quality = (roop.globals.output_video_quality + 1) * 51 // 100
55
- commands = ['-hwaccel', 'auto', '-r', str(fps), '-i', os.path.join(temp_directory_path, '%04d.' + roop.globals.temp_frame_format), '-c:v', roop.globals.output_video_encoder]
56
- if roop.globals.output_video_encoder in ['libx264', 'libx265', 'libvpx']:
57
- commands.extend(['-crf', str(output_video_quality)])
58
- if roop.globals.output_video_encoder in ['h264_nvenc', 'hevc_nvenc']:
59
- commands.extend(['-cq', str(output_video_quality)])
60
- commands.extend(['-pix_fmt', 'yuv420p', '-vf', 'colorspace=bt709:iall=bt601-6-625:fast=1', '-y', temp_output_path])
61
- return run_ffmpeg(commands)
62
 
63
 
64
  def restore_audio(target_path: str, output_path: str) -> None:
@@ -70,7 +62,7 @@ def restore_audio(target_path: str, output_path: str) -> None:
70
 
71
  def get_temp_frame_paths(target_path: str) -> List[str]:
72
  temp_directory_path = get_temp_directory_path(target_path)
73
- return glob.glob((os.path.join(glob.escape(temp_directory_path), '*.' + roop.globals.temp_frame_format)))
74
 
75
 
76
  def get_temp_directory_path(target_path: str) -> str:
@@ -81,11 +73,11 @@ def get_temp_directory_path(target_path: str) -> str:
81
 
82
  def get_temp_output_path(target_path: str) -> str:
83
  temp_directory_path = get_temp_directory_path(target_path)
84
- return os.path.join(temp_directory_path, TEMP_VIDEO_FILE)
85
 
86
 
87
- def normalize_output_path(source_path: str, target_path: str, output_path: str) -> Optional[str]:
88
- if source_path and target_path and output_path:
89
  source_name, _ = os.path.splitext(os.path.basename(source_path))
90
  target_name, target_extension = os.path.splitext(os.path.basename(target_path))
91
  if os.path.isdir(output_path):
@@ -139,10 +131,10 @@ def conditional_download(download_directory_path: str, urls: List[str]) -> None:
139
  for url in urls:
140
  download_file_path = os.path.join(download_directory_path, os.path.basename(url))
141
  if not os.path.exists(download_file_path):
142
- request = urllib.request.urlopen(url) # type: ignore[attr-defined]
143
  total = int(request.headers.get('Content-Length', 0))
144
  with tqdm(total=total, desc='Downloading', unit='B', unit_scale=True, unit_divisor=1024) as progress:
145
- urllib.request.urlretrieve(url, download_file_path, reporthook=lambda count, block_size, total_size: progress.update(block_size)) # type: ignore[attr-defined]
146
 
147
 
148
  def resolve_relative_path(path: str) -> str:
 
7
  import subprocess
8
  import urllib
9
  from pathlib import Path
10
+ from typing import List, Any
11
  from tqdm import tqdm
12
 
13
  import roop.globals
14
 
15
+ TEMP_FILE = 'temp.mp4'
16
  TEMP_DIRECTORY = 'temp'
 
17
 
18
  # monkey patch ssl for mac
19
  if platform.system().lower() == 'darwin':
 
21
 
22
 
23
  def run_ffmpeg(args: List[str]) -> bool:
24
+ commands = ['ffmpeg', '-hide_banner', '-hwaccel', 'auto', '-loglevel', roop.globals.log_level]
25
  commands.extend(args)
26
  try:
27
  subprocess.check_output(commands, stderr=subprocess.STDOUT)
 
39
  return numerator / denominator
40
  except Exception:
41
  pass
42
+ return 30.0
43
 
44
 
45
+ def extract_frames(target_path: str) -> None:
46
  temp_directory_path = get_temp_directory_path(target_path)
47
+ run_ffmpeg(['-i', target_path, '-pix_fmt', 'rgb24', os.path.join(temp_directory_path, '%04d.png')])
 
48
 
49
 
50
+ def create_video(target_path: str, fps: float = 30.0) -> None:
51
  temp_output_path = get_temp_output_path(target_path)
52
  temp_directory_path = get_temp_directory_path(target_path)
53
+ run_ffmpeg(['-r', str(fps), '-i', os.path.join(temp_directory_path, '%04d.png'), '-c:v', roop.globals.video_encoder, '-crf', str(roop.globals.video_quality), '-pix_fmt', 'yuv420p', '-vf', 'colorspace=bt709:iall=bt601-6-625:fast=1', '-y', temp_output_path])
 
 
 
 
 
 
 
54
 
55
 
56
  def restore_audio(target_path: str, output_path: str) -> None:
 
62
 
63
  def get_temp_frame_paths(target_path: str) -> List[str]:
64
  temp_directory_path = get_temp_directory_path(target_path)
65
+ return glob.glob((os.path.join(glob.escape(temp_directory_path), '*.png')))
66
 
67
 
68
  def get_temp_directory_path(target_path: str) -> str:
 
73
 
74
  def get_temp_output_path(target_path: str) -> str:
75
  temp_directory_path = get_temp_directory_path(target_path)
76
+ return os.path.join(temp_directory_path, TEMP_FILE)
77
 
78
 
79
+ def normalize_output_path(source_path: str, target_path: str, output_path: str) -> Any:
80
+ if source_path and target_path:
81
  source_name, _ = os.path.splitext(os.path.basename(source_path))
82
  target_name, target_extension = os.path.splitext(os.path.basename(target_path))
83
  if os.path.isdir(output_path):
 
131
  for url in urls:
132
  download_file_path = os.path.join(download_directory_path, os.path.basename(url))
133
  if not os.path.exists(download_file_path):
134
+ request = urllib.request.urlopen(url) # type: ignore[attr-defined]
135
  total = int(request.headers.get('Content-Length', 0))
136
  with tqdm(total=total, desc='Downloading', unit='B', unit_scale=True, unit_divisor=1024) as progress:
137
+ urllib.request.urlretrieve(url, download_file_path, reporthook=lambda count, block_size, total_size: progress.update(block_size)) # type: ignore[attr-defined]
138
 
139
 
140
  def resolve_relative_path(path: str) -> str: