Spaces:
Runtime error
Runtime error
import platform
import shlex
import subprocess
from pathlib import Path
from typing import Dict, List, Tuple, TypedDict, Union

import cv2
import numpy as np
import yt_dlp
from loguru import logger
from tqdm import tqdm

from nota_wav2lip.util import FFMPEG_LOGGING_MODE
class LabelInfo(TypedDict, total=True):
    """Parsed contents of a single LRS3 label file (built by ``parse_lrs3_label``)."""

    # Transcript taken from the ``Text:`` line.
    text: str
    # Integer value taken from the ``Conf:`` line.
    conf: int
    # YouTube video id taken from the ``Ref:`` line.
    url: str
    # Frame index -> (x, y, w, h) box, normalized by the frame size.
    bbox_xywhn: Dict[int, Tuple[float, float, float, float]]
def frame_to_time(frame_id: int, fps=25) -> str:
    """Convert a frame index into an ``HH:MM:SS.mmm`` timestamp string.

    The time is ``frame_id / fps`` seconds, rounded once to the nearest
    millisecond. All further splitting into hours/minutes/seconds is done with
    integer arithmetic. This avoids float round-off such as frame 29 at 25 fps
    becoming ``00:00:01.159`` instead of ``00:00:01.160``.

    Args:
        frame_id: Zero-based frame index.
        fps: Frame rate used to convert frames to seconds.

    Returns:
        Timestamp formatted as ``HH:MM:SS.mmm``.
    """
    total_ms = round(frame_id * 1000 / fps)
    total_seconds, millis = divmod(total_ms, 1000)
    total_minutes, seconds = divmod(total_seconds, 60)
    hours, minutes = divmod(total_minutes, 60)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d}.{millis:03d}"  # HH:MM:SS.mmm
def save_audio_file(input_path, start_frame_id, to_frame_id, output_path=None):
    """Extract a mono 16 kHz 16-bit PCM audio clip between two frame indices.

    The frame indices are converted to timestamps with ``frame_to_time`` (its
    default fps). They are passed to ffmpeg's ``-ss`` / ``-to`` options.

    Args:
        input_path: Media file to take the audio from.
        start_frame_id: Frame index where the clip starts (``-ss``).
        to_frame_id: Frame index where the clip ends (``-to``).
        output_path: Destination file. Defaults to ``input_path`` with a
            ``.wav`` suffix.

    Returns:
        The ffmpeg exit code (0 on success).
    """
    input_path = Path(input_path)
    output_path = output_path if output_path is not None else input_path.with_suffix('.wav')
    ss = frame_to_time(start_frame_id)
    to = frame_to_time(to_frame_id)
    # Build an argv list instead of a shell string. Paths (which embed the
    # label-file-derived video name) are then never interpreted by a shell, so
    # spaces or shell metacharacters cannot break or hijack the command.
    command = [
        "ffmpeg",
        # NOTE(review): assumes FFMPEG_LOGGING_MODE values are strings of CLI flags.
        *shlex.split(FFMPEG_LOGGING_MODE['ERROR']),
        "-y",
        "-i", str(input_path),
        "-vn",
        "-acodec", "pcm_s16le",
        "-ss", ss,
        "-to", to,
        "-ar", "16000",
        "-ac", "1",
        str(output_path),
    ]
    return subprocess.call(command)
def merge_video_audio(video_path, audio_path, output_path):
    """Combine a video file and an audio file into ``output_path`` with ffmpeg.

    Both files are passed as inputs (``-i``) with ``-strict experimental``.
    Stream selection and codecs are left to ffmpeg's defaults.

    Returns:
        The ffmpeg exit code (0 on success).
    """
    # argv list rather than a shell string: paths are passed verbatim, so
    # spaces or shell metacharacters in them cannot break the command.
    command = [
        "ffmpeg",
        # NOTE(review): assumes FFMPEG_LOGGING_MODE values are strings of CLI flags.
        *shlex.split(FFMPEG_LOGGING_MODE['ERROR']),
        "-y",
        "-i", str(video_path),
        "-i", str(audio_path),
        "-strict", "experimental",
        str(output_path),
    ]
    return subprocess.call(command)
def parse_lrs3_label(label_path) -> LabelInfo:
    """Parse an LRS3 label text file into a ``LabelInfo`` dict.

    Layout checked here (1-based line numbers):

    1. ``Text: <transcript>``
    2. ``Conf: <integer>``
    3. ``Ref: <YouTube video id>``
    4. (not inspected)
    5. ``FRAME ...`` header of the per-frame table
    6+. ``<frame> <x> <y> <w> <h>`` rows, tab-separated

    Table rows that do not split into exactly five tab-separated fields
    (e.g. trailing blank lines) are skipped.

    Args:
        label_path: Path to the label text file (read as UTF-8).

    Returns:
        ``LabelInfo`` whose ``bbox_xywhn`` maps frame index to an
        ``(x, y, w, h)`` tuple of floats.

    Raises:
        ValueError: if a header line is missing or does not start with the
            expected field name, or if a number cannot be parsed.
    """
    lines = Path(label_path).read_text(encoding="utf-8").split('\n')

    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    for line_index, prefix in ((0, "Text:"), (1, "Conf:"), (2, "Ref:"), (4, "FRAME")):
        if line_index >= len(lines) or not lines[line_index].startswith(prefix):
            raise ValueError(
                f"Malformed LRS3 label {label_path}: line {line_index + 1} must start with {prefix!r}"
            )

    bbox_xywhn: Dict[int, Tuple[float, float, float, float]] = {}
    for row in lines[5:]:
        fields = [field.strip() for field in row.split('\t')]
        if len(fields) != 5:
            continue
        bbox_xywhn[int(fields[0])] = tuple(map(float, fields[1:]))

    # Take everything after the first ':' so that a missing space after the
    # field name does not swallow the first character of the value.
    return LabelInfo(
        text=lines[0].split(':', 1)[1].strip(),
        conf=int(lines[1].split(':', 1)[1]),
        url=lines[2].split(':', 1)[1].strip(),
        bbox_xywhn=bbox_xywhn,
    )
| def _get_cropped_bbox(bbox_info_xywhn, original_width, original_height): | |
| bbox_info = bbox_info_xywhn | |
| x = bbox_info[0] * original_width | |
| y = bbox_info[1] * original_height | |
| w = bbox_info[2] * original_width | |
| h = bbox_info[3] * original_height | |
| x_min = max(0, int(x - 0.5 * w)) | |
| y_min = max(0, int(y)) | |
| x_max = min(original_width, int(x + 1.5 * w)) | |
| y_max = min(original_height, int(y + 1.5 * h)) | |
| cropped_width = x_max - x_min | |
| cropped_height = y_max - y_min | |
| if cropped_height > cropped_width: | |
| offset = cropped_height - cropped_width | |
| offset_low = min(x_min, offset // 2) | |
| offset_high = min(offset - offset_low, original_width - x_max) | |
| x_min -= offset_low | |
| x_max += offset_high | |
| else: | |
| offset = cropped_width - cropped_height | |
| offset_low = min(y_min, offset // 2) | |
| offset_high = min(offset - offset_low, original_width - y_max) | |
| y_min -= offset_low | |
| y_max += offset_high | |
| return x_min, y_min, x_max, y_max | |
| def _get_smoothened_boxes(bbox_dict, bbox_smoothen_window): | |
| boxes = [np.array(bbox_dict[frame_id]) for frame_id in sorted(bbox_dict)] | |
| for i in range(len(boxes)): | |
| window = boxes[len(boxes) - bbox_smoothen_window:] if i + bbox_smoothen_window > len(boxes) else boxes[i:i + bbox_smoothen_window] | |
| boxes[i] = np.mean(window, axis=0) | |
| for idx, frame_id in enumerate(sorted(bbox_dict)): | |
| bbox_dict[frame_id] = (np.rint(boxes[idx])).astype(int).tolist() | |
| return bbox_dict | |
def download_video_from_youtube(youtube_ref, output_path):
    """Download a YouTube video with yt-dlp.

    Requests mp4 video of at most 720p plus m4a audio, falling back to the
    best single mp4 file of at most 720p.

    Args:
        youtube_ref: Video id appended to ``https://www.youtube.com/watch?v=``.
        output_path: Passed to yt-dlp as the ``outtmpl`` output template.
    """
    options = dict(
        format='bestvideo[ext=mp4][height<=720]+bestaudio[ext=m4a]/best[ext=mp4][height<=720]',
        outtmpl=str(output_path),
    )
    watch_url = f"https://www.youtube.com/watch?v={youtube_ref}"
    with yt_dlp.YoutubeDL(options) as downloader:
        downloader.download([watch_url])
def resample_video(input_path, output_path):
    """Re-encode ``input_path`` at 25 fps into ``output_path`` with ffmpeg.

    Uses ``-r 25`` and ``-preset veryfast``.

    Returns:
        The ffmpeg exit code (0 on success).
    """
    # argv list rather than a shell string: paths are passed verbatim, so
    # spaces or shell metacharacters in them cannot break the command.
    command = [
        "ffmpeg",
        # NOTE(review): assumes FFMPEG_LOGGING_MODE values are strings of CLI flags.
        *shlex.split(FFMPEG_LOGGING_MODE['INFO']),
        "-y",
        "-i", str(input_path),
        "-r", "25",
        "-preset", "veryfast",
        str(output_path),
    ]
    return subprocess.call(command)
def _get_smoothen_xyxy_bbox(
    label_bbox_xywhn: Dict[int, Tuple[float, float, float, float]],
    original_width: int,
    original_height: int,
    bbox_smoothen_window: int = 5
) -> Dict[int, List[int]]:
    """Convert normalized per-frame boxes into smoothed pixel crops.

    Each box goes through ``_get_cropped_bbox``, and the resulting crops are
    temporally smoothed with ``_get_smoothened_boxes``.

    Args:
        label_bbox_xywhn: Frame index -> normalized ``(x, y, w, h)`` box.
        original_width: Frame width in pixels.
        original_height: Frame height in pixels.
        bbox_smoothen_window: Moving-average window length in frames.

    Returns:
        Frame index -> ``[x_min, y_min, x_max, y_max]`` as a list of ints.
        (The previous annotation claimed float tuples, but the smoother
        returns int lists.)
    """
    label_bbox_xyxy = {
        frame_id: _get_cropped_bbox(label_bbox_xywhn[frame_id], original_width, original_height)
        for frame_id in sorted(label_bbox_xywhn)
    }
    return _get_smoothened_boxes(label_bbox_xyxy, bbox_smoothen_window=bbox_smoothen_window)
def get_start_end_frame_id(
    label_bbox_xywhn: Dict[int, Tuple[float, float, float, float]],
) -> Tuple[int, int]:
    """Return ``(first, last)``: the smallest and largest labelled frame index.

    Raises:
        ValueError: if ``label_bbox_xywhn`` is empty.
    """
    labelled_frames = [*label_bbox_xywhn]
    return min(labelled_frames), max(labelled_frames)
def crop_video_with_bbox(
    input_path,
    label_bbox_xywhn: Dict[int, Tuple[float, float, float, float]],
    start_frame_id,
    to_frame_id,
    output_path,
    bbox_smoothen_window=5,
    frame_width=224,
    frame_height=224,
    fps=25,
    interpolation=cv2.INTER_CUBIC,
):
    """Crop the labelled face region of frames ``start_frame_id..to_frame_id`` into a new video.

    How each frame is produced:
    - Boxes are converted to pixel crops and smoothed over time
      (``_get_smoothen_xyxy_bbox``).
    - Each crop is resized to ``(frame_width, frame_height)`` with
      ``interpolation``.
    - The result is written with the ``mp4v`` fourcc at ``fps``.

    Both ends of the frame range are inclusive.

    Raises:
        IOError: if OpenCV cannot open ``input_path``.
        KeyError: if a frame inside the range has no box label.
    """
    cap = cv2.VideoCapture(str(input_path))
    # Fail before creating (and truncating) the output file.
    if not cap.isOpened():
        cap.release()
        raise IOError("Error: Could not open video.")

    out = None
    try:
        original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        label_bbox_xyxy = _get_smoothen_xyxy_bbox(
            label_bbox_xywhn, original_width, original_height, bbox_smoothen_window=bbox_smoothen_window
        )

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(str(output_path), fourcc, fps, (frame_width, frame_height))

        # Frames after `to_frame_id` are never used, so stop there instead of
        # decoding the rest of the (possibly long) source video.
        for frame_id in tqdm(range(to_frame_id + 1)):
            if frame_id < start_frame_id:
                # Advance past frames before the range without retrieving them.
                if not cap.grab():
                    break
                continue
            ret, frame = cap.read()
            if not ret:
                break
            x_min, y_min, x_max, y_max = label_bbox_xyxy[frame_id]
            frame_cropped = frame[y_min:y_max, x_min:x_max]
            frame_cropped = cv2.resize(frame_cropped, (frame_width, frame_height), interpolation=interpolation)
            out.write(frame_cropped)
    finally:
        # Release on every path so the output file is finalized and the
        # capture handle is not leaked if cropping fails part-way.
        if out is not None:
            out.release()
        cap.release()
def get_cropped_face_from_lrs3_label(
    label_text_path: Union[Path, str],
    video_root_dir: Union[Path, str],
    bbox_smoothen_window: int = 5,
    frame_width: int = 224,
    frame_height: int = 224,
    fps: int = 25,
    interpolation=cv2.INTER_CUBIC,
    ignore_cache: bool = False,
) -> Path:
    """Build a face-cropped clip with audio from one LRS3 label file.

    Steps:
    1. Download the referenced YouTube video.
    2. Resample it to 25 fps.
    3. Cut the audio that matches the labelled frame range.
    4. Crop the face region.
    5. Merge the cropped video and audio.

    Intermediate files are stored in ``<video_root_dir>/.cache``. Every step
    is skipped when its output already exists, unless ``ignore_cache`` is set.

    NOTE(review): timestamps are computed at 25 fps (``resample_video`` uses
    ``-r 25`` and ``frame_to_time`` is called with its default). ``fps`` only
    sets the frame rate written for the cropped video.

    Args:
        label_text_path: LRS3 label text file.
        video_root_dir: Directory for the final clip; its ``.cache``
            subdirectory holds intermediate files.
        bbox_smoothen_window: Moving-average window for the face boxes.
        frame_width: Width of the cropped output frames.
        frame_height: Height of the cropped output frames.
        fps: Frame rate written for the cropped video.
        interpolation: OpenCV interpolation flag used when resizing crops.
        ignore_cache: Recompute every step even if its output exists.

    Returns:
        Path of the merged clip, ``<video_root_dir>/<video id>-<label stem>.mp4``.
    """
    label_text_path = Path(label_text_path)
    label_info = parse_lrs3_label(label_text_path)
    start_frame_id, to_frame_id = get_start_end_frame_id(label_info['bbox_xywhn'])

    video_root_dir = Path(video_root_dir)
    video_cache_dir = video_root_dir / ".cache"
    video_cache_dir.mkdir(parents=True, exist_ok=True)

    # NOTE(review): `url` comes straight from the label file and is used as a
    # file name; a value containing path separators would escape the cache dir.
    output_video: Path = video_cache_dir / f"{label_info['url']}.mp4"
    clip_stem = f"{output_video.stem}-{label_text_path.stem}"
    output_resampled_video: Path = output_video.with_name(f"{output_video.stem}-25fps.mp4")
    output_cropped_audio: Path = output_video.with_name(f"{clip_stem}-cropped.wav")
    output_cropped_video: Path = output_video.with_name(f"{clip_stem}-cropped.mp4")
    output_cropped_with_audio: Path = video_root_dir / f"{clip_stem}.mp4"

    if not output_video.exists() or ignore_cache:
        youtube_ref = label_info['url']
        logger.info(f"Download Youtube video(https://www.youtube.com/watch?v={youtube_ref}) ... will be saved at {output_video}")
        download_video_from_youtube(youtube_ref, output_path=output_video)

    if not output_resampled_video.exists() or ignore_cache:
        logger.info(f"Resampling video to 25 FPS ... will be saved at {output_resampled_video}")
        resample_video(input_path=output_video, output_path=output_resampled_video)

    if not output_cropped_audio.exists() or ignore_cache:
        logger.info(f"Cut audio file with the given timestamps ... will be saved at {output_cropped_audio}")
        save_audio_file(
            output_resampled_video,
            start_frame_id=start_frame_id,
            # The crop keeps frames start..to *inclusive*, so the audio must
            # run to the end of `to_frame_id`, i.e. the start of the next frame.
            to_frame_id=to_frame_id + 1,
            output_path=output_cropped_audio
        )

    # Cached like every other step (previously recomputed on every call).
    if not output_cropped_video.exists() or ignore_cache:
        logger.info(f"Naive crop the face region with the given frame labels ... will be saved at {output_cropped_video}")
        crop_video_with_bbox(
            output_resampled_video,
            label_info['bbox_xywhn'],
            start_frame_id,
            to_frame_id,
            output_path=output_cropped_video,
            bbox_smoothen_window=bbox_smoothen_window,
            frame_width=frame_width,
            frame_height=frame_height,
            fps=fps,
            interpolation=interpolation
        )

    if not output_cropped_with_audio.exists() or ignore_cache:
        logger.info(f"Merge an audio track with the cropped face sequence ... will be saved at {output_cropped_with_audio}")
        merge_video_audio(output_cropped_video, output_cropped_audio, output_cropped_with_audio)

    return output_cropped_with_audio