Spaces:
Configuration error
Configuration error
| import torch | |
| import subprocess | |
| import json | |
| import os | |
| import dlib | |
| import gdown | |
| import pickle | |
| import re | |
| from models import Wav2Lip | |
| from base64 import b64encode | |
| from urllib.parse import urlparse | |
| from torch.hub import download_url_to_file, get_dir | |
| from IPython.display import HTML, display | |
# Select the torch device used by the rest of this module: CUDA first, then
# Apple's MPS backend, then CPU. `torch.backends.mps` is absent on older torch
# builds, so probe for it with hasattr instead of raising AttributeError at
# import time.
if torch.cuda.is_available():
    device = 'cuda'
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
def get_video_details(filename):
    """Probe a media file with ffprobe and return its basic video properties.

    Args:
        filename: Path of the media file, passed to ffprobe as-is.

    Returns:
        A ``(width, height, fps, length)`` tuple: width and height of the
        first video stream as ints, its average frame rate as a float, and
        the container duration in seconds as a float.

    Raises:
        ValueError: If ffprobe exits with a non-zero status, or its stdout is
            not valid JSON (``json.JSONDecodeError`` is a ``ValueError``).
        StopIteration: If ffprobe reports no stream with codec_type "video".
        ZeroDivisionError: If the reported frame rate has a zero denominator
            (for example "0/0").
    """
    cmd = [
        "ffprobe",
        "-v",
        "error",
        "-show_format",
        "-show_streams",
        "-of",
        "json",
        filename,
    ]
    # Capture stderr separately: merging it into stdout lets any ffprobe
    # diagnostic corrupt the JSON document we are about to parse.
    result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if result.returncode != 0:
        details = result.stderr.decode(errors="replace").strip()
        raise ValueError(f"ffprobe failed for {filename!r}: {details}")
    info = json.loads(result.stdout)

    # Use the first stream that ffprobe labels as video.
    video_stream = next(
        stream for stream in info["streams"] if stream["codec_type"] == "video"
    )

    width = int(video_stream["width"])
    height = int(video_stream["height"])

    # avg_frame_rate is a "numerator/denominator" string. Parse it with plain
    # arithmetic rather than eval(), which would execute whatever text the
    # probed file managed to get into that field.
    numerator, _, denominator = str(video_stream["avg_frame_rate"]).partition("/")
    fps = float(numerator) / float(denominator) if denominator else float(numerator)

    length = float(info["format"]["duration"])
    return width, height, fps, length
def show_video(file_path):
    """Display a video file inline through IPython's ``display`` (e.g. in Colab).

    The whole file is read into memory, base64-encoded and embedded in a
    ``<video>`` element as a ``data:`` URL. The element's width is the video's
    own width as reported by ``get_video_details``, capped at 1280.

    Args:
        file_path: Path of the video file to show; it is embedded with the
            MIME type ``video/mp4``.
    """
    # Read inside a context manager so the file handle is closed promptly
    # instead of being left for the garbage collector.
    with open(file_path, "rb") as video_file:
        mp4 = video_file.read()
    data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
    width, _, _, _ = get_video_details(file_path)
    display(
        HTML(
            """
            <video controls width=%d>
            <source src="%s" type="video/mp4">
            </video>
            """
            % (min(width, 1280), data_url)
        )
    )
def format_time(seconds):
    """Render a duration in seconds as a compact ``"Xh Ym Zs"`` string.

    Leading units are dropped when they are not positive: the hours part
    appears only when hours > 0, and the minutes part only when hours or
    minutes are > 0. Fractional seconds are truncated.

    Args:
        seconds: Duration in seconds (int or float).

    Returns:
        The formatted string, e.g. ``"1h 2m 5s"``, ``"3m 0s"`` or ``"42s"``.
    """
    whole_hours = int(seconds // 3600)
    whole_minutes = int((seconds % 3600) // 60)
    leftover_seconds = int(seconds % 60)

    if whole_hours > 0:
        return f"{whole_hours}h {whole_minutes}m {leftover_seconds}s"
    if whole_minutes > 0:
        return f"{whole_minutes}m {leftover_seconds}s"
    return f"{leftover_seconds}s"
def _load(checkpoint_path):
    """Load a checkpoint with ``torch.load``, remapping storages when needed.

    When the module-level ``device`` is "cuda" the checkpoint is loaded as
    saved. For every other device (CPU and MPS) all storages are mapped to
    CPU: a checkpoint saved from a CUDA run cannot be restored onto CUDA when
    CUDA is unavailable, so loading it unmapped on an MPS machine would fail.
    ``load_model`` moves the finished model to ``device`` afterwards.

    Args:
        checkpoint_path: Path of the checkpoint file.

    Returns:
        Whatever object ``torch.load`` deserializes from the file.
    """
    if device == "cuda":
        return torch.load(checkpoint_path)
    # Returning the storage unchanged from map_location keeps every tensor on CPU.
    return torch.load(
        checkpoint_path, map_location=lambda storage, loc: storage
    )
def load_model(path):
    """Load a Wav2Lip model from a checkpoint, caching the built model on disk.

    The cache is a pickle stored next to the checkpoint, with the same base
    name and the extension ".pk1". If that file exists it is unpickled and
    returned directly. Otherwise the checkpoint is loaded, its state dict is
    applied to a fresh ``Wav2Lip`` instance, the model is moved to the
    module-level ``device`` and put in eval mode, and the result is pickled
    for next time.

    Args:
        path: Path of the checkpoint file.

    Returns:
        The model in eval mode (freshly built, or as unpickled from the cache).
    """
    folder, filename_with_extension = os.path.split(path)
    filename, _ = os.path.splitext(filename_with_extension)
    results_file = os.path.join(folder, filename + ".pk1")

    if os.path.exists(results_file):
        # NOTE(security): pickle.load runs arbitrary code from the file. This
        # is only safe while the cache directory is writable by trusted users.
        with open(results_file, "rb") as f:
            return pickle.load(f)

    model = Wav2Lip()
    print("Loading {}".format(path))
    checkpoint = _load(path)
    # Remove every "module." occurrence from the parameter names (presumably
    # the prefix added by torch.nn.DataParallel - confirm against training code).
    new_s = {
        k.replace("module.", ""): v for k, v in checkpoint["state_dict"].items()
    }
    model.load_state_dict(new_s)
    model = model.to(device).eval()

    # Cache the ready-to-use model so later calls skip the rebuild.
    with open(results_file, "wb") as f:
        pickle.dump(model, f)
    return model
def get_input_length(filename):
    """Return the duration of a media file in seconds, as reported by ffprobe.

    Args:
        filename: Path of the media file, passed to ffprobe as-is.

    Returns:
        The container duration as a float.

    Raises:
        ValueError: If ffprobe exits with a non-zero status or does not print
            a parseable number.
    """
    result = subprocess.run(
        [
            "ffprobe",
            "-v",
            "error",
            "-show_entries",
            "format=duration",
            "-of",
            "default=noprint_wrappers=1:nokey=1",
            filename,
        ],
        stdout=subprocess.PIPE,
        # Capture stderr separately: merged into stdout, any ffprobe
        # diagnostic would end up in the text handed to float() below.
        stderr=subprocess.PIPE,
    )
    if result.returncode != 0:
        details = result.stderr.decode(errors="replace").strip()
        raise ValueError(f"ffprobe failed for {filename!r}: {details}")
    return float(result.stdout)
def is_url(string):
    """Tell whether ``string`` looks like an http, https or ftp URL.

    The whole string must match: a scheme of ``http``, ``https`` or ``ftp``,
    then ``://``, then at least two further characters with no whitespace
    (the first of which may not be one of ``/ $ . ? #``).

    Args:
        string: Text to test.

    Returns:
        True when the pattern matches, False otherwise.
    """
    pattern = r"^(https?|ftp)://[^\s/$.?#].[^\s]*$"
    return re.match(pattern, string) is not None
def load_predictor():
    """Build dlib's landmark predictor and face detector and pickle both.

    Reads ``checkpoints/shape_predictor_68_face_landmarks_GTX.dat`` into a
    ``dlib.shape_predictor``, creates dlib's frontal face detector, and writes
    them to ``checkpoints/predictor.pkl`` and ``checkpoints/mouth_detector.pkl``
    respectively. Paths are relative to the current working directory. The
    .dat file is left in place. Returns nothing.
    """
    checkpoint_dir = "checkpoints"
    landmarks_file = os.path.join(
        checkpoint_dir, "shape_predictor_68_face_landmarks_GTX.dat"
    )
    predictor = dlib.shape_predictor(landmarks_file)
    mouth_detector = dlib.get_frontal_face_detector()

    # Serialize each object to its own pickle inside the checkpoints folder.
    outputs = (
        ("predictor.pkl", predictor),
        ("mouth_detector.pkl", mouth_detector),
    )
    for pickle_name, obj in outputs:
        with open(os.path.join(checkpoint_dir, pickle_name), "wb") as out_file:
            pickle.dump(obj, out_file)
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    """Return the local path for a URL, downloading the file if it is not cached.

    Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py

    Args:
        url (str): URL to download from.
        model_dir (str): Directory to store the file in; created if missing.
            If None, the "checkpoints" folder under the PyTorch hub directory
            is used. Default: None.
        progress (bool): Whether to show download progress. Default: True.
        file_name (str): Name for the local file. If None, the last path
            component of the URL is used. Default: None.

    Returns:
        str: Absolute path of the (possibly pre-existing) local file.
    """
    if model_dir is None:
        # Fall back to the PyTorch hub cache location.
        model_dir = os.path.join(get_dir(), "checkpoints")
    os.makedirs(model_dir, exist_ok=True)

    url_basename = os.path.basename(urlparse(url).path)
    target_name = url_basename if file_name is None else file_name
    cached_file = os.path.abspath(os.path.join(model_dir, target_name))

    # An existing file counts as a cache hit; nothing is downloaded or verified.
    if os.path.exists(cached_file):
        return cached_file

    print(f'Downloading: "{url}" to {cached_file}\n')
    download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file
def g_colab():
    """Report whether the ``google.colab`` package can be imported.

    Returns:
        True if ``import google.colab`` succeeds, False if it raises
        ImportError.
    """
    try:
        import google.colab  # noqa: F401 - imported only to test availability
    except ImportError:
        return False
    else:
        return True