Spaces:
Runtime error
Runtime error
musetalk model
Browse files- app.py +6 -3
- download_musetalk_models.py +31 -0
- lipsync_processing.py +32 -7
- musetalk.py +219 -0
- musetalk.py.bak +219 -0
- musetalk_integration/__init__.py +0 -0
- musetalk_integration/models/__init__.py +0 -0
- musetalk_integration/models/unet.py +51 -0
- musetalk_integration/models/vae.py +148 -0
- musetalk_integration/utils/__init__.py +0 -0
- musetalk_integration/utils/audio_processor.py +102 -0
- musetalk_integration/utils/blending.py +136 -0
- musetalk_integration/utils/dwpose/__init__.py +0 -0
- musetalk_integration/utils/dwpose/default_runtime.py +54 -0
- musetalk_integration/utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py +257 -0
- musetalk_integration/utils/face_detection/README.md +1 -0
- musetalk_integration/utils/face_detection/__init__.py +7 -0
- musetalk_integration/utils/face_detection/api.py +240 -0
- musetalk_integration/utils/face_detection/detection/__init__.py +1 -0
- musetalk_integration/utils/face_detection/detection/core.py +130 -0
- musetalk_integration/utils/face_detection/detection/sfd/__init__.py +1 -0
- musetalk_integration/utils/face_detection/detection/sfd/bbox.py +129 -0
- musetalk_integration/utils/face_detection/detection/sfd/detect.py +114 -0
- musetalk_integration/utils/face_detection/detection/sfd/net_s3fd.py +129 -0
- musetalk_integration/utils/face_detection/detection/sfd/sfd_detector.py +59 -0
- musetalk_integration/utils/face_detection/models.py +261 -0
- musetalk_integration/utils/face_detection/utils.py +313 -0
- musetalk_integration/utils/face_parsing/__init__.py +117 -0
- musetalk_integration/utils/face_parsing/model.py +283 -0
- musetalk_integration/utils/face_parsing/resnet.py +109 -0
- musetalk_integration/whisper/__init__.py +116 -0
- musetalk_integration/whisper/__main__.py +4 -0
- musetalk_integration/whisper/audio.py +125 -0
- musetalk_integration/whisper/decoding.py +729 -0
- musetalk_integration/whisper/model.py +290 -0
- musetalk_integration/whisper/tokenizer.py +331 -0
- musetalk_integration/whisper/transcribe.py +207 -0
- musetalk_integration/whisper/utils.py +87 -0
- processing.py +15 -5
- requirements.txt +9 -0
app.py
CHANGED
|
@@ -88,8 +88,11 @@ with gr.Blocks(css=css) as demo:
|
|
| 88 |
audio_input = gr.Audio(
|
| 89 |
label="Target Audio (English only)", type="filepath"
|
| 90 |
)
|
| 91 |
-
|
| 92 |
-
label="
|
|
|
|
|
|
|
|
|
|
| 93 |
)
|
| 94 |
lipsync_only_btn = gr.Button("👄 Lipsync", variant="primary", size="lg")
|
| 95 |
|
|
@@ -113,7 +116,7 @@ with gr.Blocks(css=css) as demo:
|
|
| 113 |
|
| 114 |
lipsync_only_btn.click(
|
| 115 |
fn=lipsync_with_audio_target,
|
| 116 |
-
inputs=[video_input, audio_input, session_state,
|
| 117 |
outputs=[
|
| 118 |
final_video,
|
| 119 |
video_normalized_output,
|
|
|
|
| 88 |
audio_input = gr.Audio(
|
| 89 |
label="Target Audio (English only)", type="filepath"
|
| 90 |
)
|
| 91 |
+
model_radio = gr.Radio(
|
| 92 |
+
label="Lipsync Model",
|
| 93 |
+
choices=["LatentSync v1.6", "MuseTalk v1.5"],
|
| 94 |
+
value="LatentSync v1.6",
|
| 95 |
+
interactive=True,
|
| 96 |
)
|
| 97 |
lipsync_only_btn = gr.Button("👄 Lipsync", variant="primary", size="lg")
|
| 98 |
|
|
|
|
| 116 |
|
| 117 |
lipsync_only_btn.click(
|
| 118 |
fn=lipsync_with_audio_target,
|
| 119 |
+
inputs=[video_input, audio_input, session_state, model_radio],
|
| 120 |
outputs=[
|
| 121 |
final_video,
|
| 122 |
video_normalized_output,
|
download_musetalk_models.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Download MuseTalk V1.5 models"""
|
| 2 |
+
|
| 3 |
+
from huggingface_hub import snapshot_download
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
print("Downloading MuseTalk V1.5 models...")
|
| 7 |
+
|
| 8 |
+
os.makedirs("checkpoints/musetalkV15", exist_ok=True)
|
| 9 |
+
snapshot_download(
|
| 10 |
+
repo_id="TMElyralab/MuseTalk",
|
| 11 |
+
local_dir="./checkpoints/musetalkV15",
|
| 12 |
+
allow_patterns=["*.pth", "*.json", "*.pt"],
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
print("✓ MuseTalk V1.5 models downloaded to checkpoints/musetalkV15/")
|
| 16 |
+
|
| 17 |
+
print("Downloading SD-VAE-FT-MSE model...")
|
| 18 |
+
os.makedirs("checkpoints/sd-vae-ft-mse", exist_ok=True)
|
| 19 |
+
snapshot_download(
|
| 20 |
+
repo_id="stabilityai/sd-vae-ft-mse", local_dir="./checkpoints/sd-vae-ft-mse"
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
print("✓ SD-VAE-FT-MSE downloaded to checkpoints/sd-vae-ft-mse/")
|
| 24 |
+
|
| 25 |
+
print("Downloading Whisper-Tiny model...")
|
| 26 |
+
os.makedirs("checkpoints/whisper-tiny", exist_ok=True)
|
| 27 |
+
snapshot_download(repo_id="openai/whisper-tiny", local_dir="./checkpoints/whisper-tiny")
|
| 28 |
+
|
| 29 |
+
print("✓ Whisper-Tiny downloaded to checkpoints/whisper-tiny/")
|
| 30 |
+
|
| 31 |
+
print("\nAll MuseTalk models downloaded successfully!")
|
lipsync_processing.py
CHANGED
|
@@ -45,7 +45,10 @@ def get_video_info(video_path: str) -> dict:
|
|
| 45 |
|
| 46 |
|
| 47 |
def apply_lipsync_to_video(
|
| 48 |
-
video_path: str,
|
|
|
|
|
|
|
|
|
|
| 49 |
) -> tuple:
|
| 50 |
"""Apply lipsync to video using clean 16k audio
|
| 51 |
|
|
@@ -53,17 +56,32 @@ def apply_lipsync_to_video(
|
|
| 53 |
video_path: Path to input video
|
| 54 |
audio_16k_path: Path to 16kHz audio
|
| 55 |
output_dir: Directory to save output
|
| 56 |
-
|
| 57 |
|
| 58 |
Returns:
|
| 59 |
Tuple of (lipsynced_video_path, video_info)
|
| 60 |
"""
|
| 61 |
try:
|
| 62 |
lipsynced_video = os.path.join(output_dir, "output_with_lipsync.mp4")
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
video_info = get_video_info(lipsynced_video)
|
| 69 |
print(
|
|
@@ -78,11 +96,18 @@ def apply_lipsync_to_video(
|
|
| 78 |
)
|
| 79 |
if "face not detected" in str(e).lower():
|
| 80 |
raise RuntimeError(
|
| 81 |
-
"Face detection failed in
|
| 82 |
)
|
| 83 |
print(f"Runtime Error in lipsync processing: {e}")
|
| 84 |
traceback.print_exc()
|
| 85 |
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
except Exception as e:
|
| 87 |
print(f"Error in apply_lipsync_to_video: {e}")
|
| 88 |
traceback.print_exc()
|
|
|
|
| 45 |
|
| 46 |
|
| 47 |
def apply_lipsync_to_video(
|
| 48 |
+
video_path: str,
|
| 49 |
+
audio_16k_path: str,
|
| 50 |
+
output_dir: str,
|
| 51 |
+
model_type: str = "LatentSync v1.6",
|
| 52 |
) -> tuple:
|
| 53 |
"""Apply lipsync to video using clean 16k audio
|
| 54 |
|
|
|
|
| 56 |
video_path: Path to input video
|
| 57 |
audio_16k_path: Path to 16kHz audio
|
| 58 |
output_dir: Directory to save output
|
| 59 |
+
model_type: Model type for lipsync ("LatentSync v1.6" or "MuseTalk v1.5")
|
| 60 |
|
| 61 |
Returns:
|
| 62 |
Tuple of (lipsynced_video_path, video_info)
|
| 63 |
"""
|
| 64 |
try:
|
| 65 |
lipsynced_video = os.path.join(output_dir, "output_with_lipsync.mp4")
|
| 66 |
+
|
| 67 |
+
if model_type == "LatentSync v1.6":
|
| 68 |
+
crop_size = 512
|
| 69 |
+
print(
|
| 70 |
+
f"Using LatentSync v1.6: video={video_path}, audio={audio_16k_path}, crop_size={crop_size}"
|
| 71 |
+
)
|
| 72 |
+
apply_lipsync(video_path, audio_16k_path, lipsynced_video, crop_size)
|
| 73 |
+
|
| 74 |
+
elif model_type == "MuseTalk v1.5":
|
| 75 |
+
crop_size = 256
|
| 76 |
+
print(
|
| 77 |
+
f"Using MuseTalk v1.5: video={video_path}, audio={audio_16k_path}, crop_size={crop_size}"
|
| 78 |
+
)
|
| 79 |
+
from musetalk import apply_musetalk_lipsync
|
| 80 |
+
|
| 81 |
+
apply_musetalk_lipsync(video_path, audio_16k_path, lipsynced_video)
|
| 82 |
+
|
| 83 |
+
else:
|
| 84 |
+
raise ValueError(f"Unknown model_type: {model_type}")
|
| 85 |
|
| 86 |
video_info = get_video_info(lipsynced_video)
|
| 87 |
print(
|
|
|
|
| 96 |
)
|
| 97 |
if "face not detected" in str(e).lower():
|
| 98 |
raise RuntimeError(
|
| 99 |
+
"Face detection failed in lipsync pipeline. Please upload a video with a clear, visible face."
|
| 100 |
)
|
| 101 |
print(f"Runtime Error in lipsync processing: {e}")
|
| 102 |
traceback.print_exc()
|
| 103 |
raise
|
| 104 |
+
except Exception:
|
| 105 |
+
raise
|
| 106 |
+
except Exception as e:
|
| 107 |
+
print(f"Error in apply_lipsync_to_video: {e}")
|
| 108 |
+
traceback.print_exc()
|
| 109 |
+
raise
|
| 110 |
+
|
| 111 |
except Exception as e:
|
| 112 |
print(f"Error in apply_lipsync_to_video: {e}")
|
| 113 |
traceback.print_exc()
|
musetalk.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""MuseTalk V1.5 integration module"""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import cv2
|
| 7 |
+
import copy
|
| 8 |
+
import math
|
| 9 |
+
import subprocess
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
from glob import glob
|
| 12 |
+
|
| 13 |
+
from transformers import WhisperModel
|
| 14 |
+
|
| 15 |
+
from musetalk_integration.models.unet import UNet, PositionalEncoding
|
| 16 |
+
from musetalk_integration.models.vae import VAE
|
| 17 |
+
from musetalk_integration.utils.audio_processor import AudioProcessor
|
| 18 |
+
from musetalk_integration.utils.face_parsing import FaceParsing
|
| 19 |
+
from musetalk_integration.utils.blending import get_image
|
| 20 |
+
from musetalk_integration.utils.preprocessing import (
|
| 21 |
+
get_landmark_and_bbox,
|
| 22 |
+
read_imgs,
|
| 23 |
+
coord_placeholder,
|
| 24 |
+
)
|
| 25 |
+
from musetalk_integration.utils.utils import datagen, get_video_fps
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def load_musetalk_models():
|
| 32 |
+
"""Load MuseTalk V1.5 models"""
|
| 33 |
+
print("Loading MuseTalk V1.5 models...")
|
| 34 |
+
|
| 35 |
+
vae = VAE(model_path="./checkpoints/sd-vae-ft-mse")
|
| 36 |
+
print("✓ VAE loaded")
|
| 37 |
+
|
| 38 |
+
unet = UNet(
|
| 39 |
+
unet_config="./checkpoints/musetalkV15/musetalk.json",
|
| 40 |
+
model_path="./checkpoints/musetalkV15/unet.pth",
|
| 41 |
+
device=device,
|
| 42 |
+
)
|
| 43 |
+
print("✓ UNet loaded")
|
| 44 |
+
|
| 45 |
+
pe = PositionalEncoding(d_model=384)
|
| 46 |
+
print("✓ Positional encoding loaded")
|
| 47 |
+
|
| 48 |
+
audio_processor = AudioProcessor(
|
| 49 |
+
feature_extractor_path="./checkpoints/whisper-tiny"
|
| 50 |
+
)
|
| 51 |
+
print("✓ Audio processor loaded")
|
| 52 |
+
|
| 53 |
+
whisper = WhisperModel.from_pretrained("./checkpoints/whisper-tiny")
|
| 54 |
+
whisper = whisper.to(device=device, dtype=torch.float16).eval()
|
| 55 |
+
whisper.requires_grad_(False)
|
| 56 |
+
print("✓ Whisper model loaded")
|
| 57 |
+
|
| 58 |
+
fp = FaceParsing(left_cheek_width=90, right_cheek_width=90)
|
| 59 |
+
print("✓ Face parser loaded")
|
| 60 |
+
|
| 61 |
+
timesteps = torch.tensor([0], device=device)
|
| 62 |
+
|
| 63 |
+
return vae, unet, pe, audio_processor, whisper, fp, timesteps
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
vae, unet, pe, audio_processor, whisper, fp, timesteps = load_musetalk_models()
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
@torch.no_grad()
|
| 70 |
+
def apply_musetalk_lipsync(
|
| 71 |
+
video_path: str, audio_path: str, video_out_path: str, progress=None
|
| 72 |
+
):
|
| 73 |
+
"""Apply MuseTalk V1.5 lipsync
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
video_path: Path to input video
|
| 77 |
+
audio_path: Path to input audio
|
| 78 |
+
video_out_path: Path to output video
|
| 79 |
+
progress: Progress object
|
| 80 |
+
"""
|
| 81 |
+
print(f"\n{'=' * 60}")
|
| 82 |
+
print(f"MUSETALK V1.5 START")
|
| 83 |
+
print(f"Video: {video_path}")
|
| 84 |
+
print(f"Audio: {audio_path}")
|
| 85 |
+
print(f"Output: {video_out_path}")
|
| 86 |
+
print(f"{'=' * 60}\n")
|
| 87 |
+
|
| 88 |
+
output_dir = os.path.dirname(video_out_path)
|
| 89 |
+
|
| 90 |
+
# 1. Extract frames
|
| 91 |
+
input_basename = os.path.basename(video_path).split(".")[0]
|
| 92 |
+
save_dir_full = os.path.join(output_dir, f"{input_basename}_frames")
|
| 93 |
+
os.makedirs(save_dir_full, exist_ok=True)
|
| 94 |
+
|
| 95 |
+
cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
|
| 96 |
+
os.system(cmd)
|
| 97 |
+
|
| 98 |
+
input_img_list = sorted(glob(os.path.join(save_dir_full, "*.[jpJP][pnPN]*[gG]")))
|
| 99 |
+
fps = get_video_fps(video_path)
|
| 100 |
+
print(f"Extracted {len(input_img_list)} frames at {fps} fps")
|
| 101 |
+
|
| 102 |
+
# 2. Extract audio features
|
| 103 |
+
print("Extracting audio features...")
|
| 104 |
+
whisper_input_features, librosa_length = audio_processor.get_audio_feature(
|
| 105 |
+
audio_path
|
| 106 |
+
)
|
| 107 |
+
whisper_chunks = audio_processor.get_whisper_chunk(
|
| 108 |
+
whisper_input_features,
|
| 109 |
+
device,
|
| 110 |
+
torch.float16,
|
| 111 |
+
whisper,
|
| 112 |
+
librosa_length,
|
| 113 |
+
fps=fps,
|
| 114 |
+
audio_padding_length_left=2,
|
| 115 |
+
audio_padding_length_right=2,
|
| 116 |
+
)
|
| 117 |
+
print(f"Generated {len(whisper_chunks)} audio chunks")
|
| 118 |
+
|
| 119 |
+
# 3. Detect face landmarks
|
| 120 |
+
print("Extracting landmarks...")
|
| 121 |
+
coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift=0)
|
| 122 |
+
print(f"Detected {len(coord_list)} face landmarks")
|
| 123 |
+
|
| 124 |
+
# 4. VAE encode
|
| 125 |
+
print("Encoding frames to latents...")
|
| 126 |
+
input_latent_list = []
|
| 127 |
+
for bbox, frame in zip(coord_list, frame_list):
|
| 128 |
+
if bbox == coord_placeholder:
|
| 129 |
+
continue
|
| 130 |
+
x1, y1, x2, y2 = bbox
|
| 131 |
+
y2 = y2 + 10
|
| 132 |
+
y2 = min(y2, frame.shape[0])
|
| 133 |
+
crop_frame = frame[y1:y2, x1:x2]
|
| 134 |
+
crop_frame = cv2.resize(
|
| 135 |
+
crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4
|
| 136 |
+
)
|
| 137 |
+
latents = vae.get_latents_for_unet(crop_frame)
|
| 138 |
+
input_latent_list.append(latents)
|
| 139 |
+
print(f"Encoded {len(input_latent_list)} frames")
|
| 140 |
+
|
| 141 |
+
# 5. Cycle frames for smoothing
|
| 142 |
+
frame_list_cycle = frame_list + frame_list[::-1]
|
| 143 |
+
coord_list_cycle = coord_list + coord_list[::-1]
|
| 144 |
+
input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
|
| 145 |
+
|
| 146 |
+
# 6. Batch inference
|
| 147 |
+
print("Starting inference...")
|
| 148 |
+
batch_size = 8
|
| 149 |
+
gen = datagen(
|
| 150 |
+
whisper_chunks=whisper_chunks,
|
| 151 |
+
vae_encode_latents=input_latent_list_cycle,
|
| 152 |
+
batch_size=batch_size,
|
| 153 |
+
delay_frame=0,
|
| 154 |
+
device=device,
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
res_frame_list = []
|
| 158 |
+
for whisper_batch, latent_batch in tqdm(
|
| 159 |
+
gen, total=int(math.ceil(len(whisper_chunks) / batch_size))
|
| 160 |
+
):
|
| 161 |
+
audio_feature_batch = pe(whisper_batch)
|
| 162 |
+
latent_batch = latent_batch.to(dtype=torch.float16)
|
| 163 |
+
|
| 164 |
+
pred_latents = unet.model(
|
| 165 |
+
latent_batch, timesteps, encoder_hidden_states=audio_feature_batch
|
| 166 |
+
).sample
|
| 167 |
+
recon = vae.decode_latents(pred_latents)
|
| 168 |
+
for res_frame in recon:
|
| 169 |
+
res_frame_list.append(res_frame)
|
| 170 |
+
print(f"Generated {len(res_frame_list)} frames")
|
| 171 |
+
|
| 172 |
+
# 7. Blend back to original video
|
| 173 |
+
print("Blending...")
|
| 174 |
+
output_frames_dir = os.path.join(output_dir, f"{input_basename}_output")
|
| 175 |
+
os.makedirs(output_frames_dir, exist_ok=True)
|
| 176 |
+
|
| 177 |
+
for i, res_frame in enumerate(tqdm(res_frame_list)):
|
| 178 |
+
bbox = coord_list_cycle[i % len(coord_list_cycle)]
|
| 179 |
+
ori_frame = copy.deepcopy(frame_list_cycle[i % len(frame_list_cycle)])
|
| 180 |
+
x1, y1, x2, y2 = bbox
|
| 181 |
+
y2 = y2 + 10
|
| 182 |
+
y2 = min(y2, ori_frame.shape[0])
|
| 183 |
+
|
| 184 |
+
try:
|
| 185 |
+
res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
|
| 186 |
+
except Exception as e:
|
| 187 |
+
print(f"Warning: Could not resize frame {i}: {e}")
|
| 188 |
+
continue
|
| 189 |
+
|
| 190 |
+
# MuseTalk v1.5 blending with jaw mode (default)
|
| 191 |
+
combine_frame = get_image(
|
| 192 |
+
ori_frame, res_frame, [x1, y1, x2, y2], mode="jaw", fp=fp
|
| 193 |
+
)
|
| 194 |
+
cv2.imwrite(f"{output_frames_dir}/{i:08d}.png", combine_frame)
|
| 195 |
+
|
| 196 |
+
# 8. Encode to video
|
| 197 |
+
print("Encoding video...")
|
| 198 |
+
cmd = f"ffmpeg -y -v warning -r {fps} -f image2 -i {output_frames_dir}/%08d.png -vcodec libx264 -vf format=yuv420p -crf 18 {video_out_path}"
|
| 199 |
+
os.system(cmd)
|
| 200 |
+
|
| 201 |
+
# 9. Add audio
|
| 202 |
+
print("Adding audio...")
|
| 203 |
+
cmd = f"ffmpeg -y -v warning -i {audio_path} -i {video_out_path} -c:v copy -c:a aac {video_out_path.replace('.mp4', '_final.mp4')}"
|
| 204 |
+
os.system(cmd)
|
| 205 |
+
os.replace(video_out_path.replace(".mp4", "_final.mp4"), video_out_path)
|
| 206 |
+
|
| 207 |
+
# 10. Cleanup
|
| 208 |
+
print("Cleaning up...")
|
| 209 |
+
import shutil
|
| 210 |
+
|
| 211 |
+
shutil.rmtree(save_dir_full)
|
| 212 |
+
shutil.rmtree(output_frames_dir)
|
| 213 |
+
|
| 214 |
+
print(f"\n{'=' * 60}")
|
| 215 |
+
print(f"MUSETALK V1.5 SUCCESS")
|
| 216 |
+
print(f"Output: {video_out_path}")
|
| 217 |
+
print(f"{'=' * 60}\n")
|
| 218 |
+
|
| 219 |
+
return video_out_path
|
musetalk.py.bak
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""MuseTalk V1.5 integration module"""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import cv2
|
| 7 |
+
import copy
|
| 8 |
+
import math
|
| 9 |
+
import subprocess
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
from glob import glob
|
| 12 |
+
|
| 13 |
+
from transformers import WhisperModel
|
| 14 |
+
|
| 15 |
+
from musetalk_integration.models.unet import UNet, PositionalEncoding
|
| 16 |
+
from musetalk_integration.models.vae import VAE
|
| 17 |
+
from musetalk_integration.utils.audio_processor import AudioProcessor
|
| 18 |
+
from musetalk_integration.utils.face_parsing import FaceParsing
|
| 19 |
+
from musetalk_integration.utils.blending import get_image
|
| 20 |
+
from musetalk_integration.utils.preprocessing import (
|
| 21 |
+
get_landmark_and_bbox,
|
| 22 |
+
read_imgs,
|
| 23 |
+
coord_placeholder,
|
| 24 |
+
)
|
| 25 |
+
from musetalk_integration.utils.utils import datagen, get_video_fps
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def load_musetalk_models():
|
| 32 |
+
"""Load MuseTalk V1.5 models"""
|
| 33 |
+
print("Loading MuseTalk V1.5 models...")
|
| 34 |
+
|
| 35 |
+
vae = VAE(model_path="./checkpoints/sd-vae-ft-mse")
|
| 36 |
+
print("✓ VAE loaded")
|
| 37 |
+
|
| 38 |
+
unet = UNet(
|
| 39 |
+
unet_config="./checkpoints/musetalkV15/musetalk.json",
|
| 40 |
+
model_path="./checkpoints/musetalkV15/unet.pth",
|
| 41 |
+
device=device,
|
| 42 |
+
)
|
| 43 |
+
print("✓ UNet loaded")
|
| 44 |
+
|
| 45 |
+
pe = PositionalEncoding(d_model=384)
|
| 46 |
+
print("✓ Positional encoding loaded")
|
| 47 |
+
|
| 48 |
+
audio_processor = AudioProcessor(
|
| 49 |
+
feature_extractor_path="./checkpoints/whisper-tiny"
|
| 50 |
+
)
|
| 51 |
+
print("✓ Audio processor loaded")
|
| 52 |
+
|
| 53 |
+
whisper = WhisperModel.from_pretrained("./checkpoints/whisper-tiny")
|
| 54 |
+
whisper = whisper.to(device=device, dtype=torch.float16).eval()
|
| 55 |
+
whisper.requires_grad_(False)
|
| 56 |
+
print("✓ Whisper model loaded")
|
| 57 |
+
|
| 58 |
+
fp = FaceParsing(left_cheek_width=90, right_cheek_width=90)
|
| 59 |
+
print("✓ Face parser loaded")
|
| 60 |
+
|
| 61 |
+
timesteps = torch.tensor([0], device=device)
|
| 62 |
+
|
| 63 |
+
return vae, unet, pe, audio_processor, whisper, fp, timesteps
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
vae, unet, pe, audio_processor, whisper, fp, timesteps = load_musetalk_models()
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
@torch.no_grad()
|
| 70 |
+
def apply_musetalk_lipsync(
|
| 71 |
+
video_path: str, audio_path: str, video_out_path: str, progress=None
|
| 72 |
+
):
|
| 73 |
+
"""Apply MuseTalk V1.5 lipsync
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
video_path: Path to input video
|
| 77 |
+
audio_path: Path to input audio
|
| 78 |
+
video_out_path: Path to output video
|
| 79 |
+
progress: Progress object
|
| 80 |
+
"""
|
| 81 |
+
print(f"\n{'=' * 60}")
|
| 82 |
+
print(f"MUSETALK V1.5 START")
|
| 83 |
+
print(f"Video: {video_path}")
|
| 84 |
+
print(f"Audio: {audio_path}")
|
| 85 |
+
print(f"Output: {video_out_path}")
|
| 86 |
+
print(f"{'=' * 60}\n")
|
| 87 |
+
|
| 88 |
+
output_dir = os.path.dirname(video_out_path)
|
| 89 |
+
|
| 90 |
+
# 1. Extract frames
|
| 91 |
+
input_basename = os.path.basename(video_path).split(".")[0]
|
| 92 |
+
save_dir_full = os.path.join(output_dir, f"{input_basename}_frames")
|
| 93 |
+
os.makedirs(save_dir_full, exist_ok=True)
|
| 94 |
+
|
| 95 |
+
cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
|
| 96 |
+
os.system(cmd)
|
| 97 |
+
|
| 98 |
+
input_img_list = sorted(glob(os.path.join(save_dir_full, "*.[jpJP][pnPN]*[gG]")))
|
| 99 |
+
fps = get_video_fps(video_path)
|
| 100 |
+
print(f"Extracted {len(input_img_list)} frames at {fps} fps")
|
| 101 |
+
|
| 102 |
+
# 2. Extract audio features
|
| 103 |
+
print("Extracting audio features...")
|
| 104 |
+
whisper_input_features, librosa_length = audio_processor.get_audio_feature(
|
| 105 |
+
audio_path
|
| 106 |
+
)
|
| 107 |
+
whisper_chunks = audio_processor.get_whisper_chunk(
|
| 108 |
+
whisper_input_features,
|
| 109 |
+
device,
|
| 110 |
+
torch.float16,
|
| 111 |
+
whisper,
|
| 112 |
+
librosa_length,
|
| 113 |
+
fps=fps,
|
| 114 |
+
audio_padding_length_left=2,
|
| 115 |
+
audio_padding_length_right=2,
|
| 116 |
+
)
|
| 117 |
+
print(f"Generated {len(whisper_chunks)} audio chunks")
|
| 118 |
+
|
| 119 |
+
# 3. Detect face landmarks
|
| 120 |
+
print("Extracting landmarks...")
|
| 121 |
+
coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift=0)
|
| 122 |
+
print(f"Detected {len(coord_list)} face landmarks")
|
| 123 |
+
|
| 124 |
+
# 4. VAE encode
|
| 125 |
+
print("Encoding frames to latents...")
|
| 126 |
+
input_latent_list = []
|
| 127 |
+
for bbox, frame in zip(coord_list, frame_list):
|
| 128 |
+
if bbox == coord_placeholder:
|
| 129 |
+
continue
|
| 130 |
+
x1, y1, x2, y2 = bbox
|
| 131 |
+
y2 = y2 + 10
|
| 132 |
+
y2 = min(y2, frame.shape[0])
|
| 133 |
+
crop_frame = frame[y1:y2, x1:x2]
|
| 134 |
+
crop_frame = cv2.resize(
|
| 135 |
+
crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4
|
| 136 |
+
)
|
| 137 |
+
latents = vae.get_latents_for_unet(crop_frame)
|
| 138 |
+
input_latent_list.append(latents)
|
| 139 |
+
print(f"Encoded {len(input_latent_list)} frames")
|
| 140 |
+
|
| 141 |
+
# 5. Cycle frames for smoothing
|
| 142 |
+
frame_list_cycle = frame_list + frame_list[::-1]
|
| 143 |
+
coord_list_cycle = coord_list + coord_list[::-1]
|
| 144 |
+
input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
|
| 145 |
+
|
| 146 |
+
# 6. Batch inference
|
| 147 |
+
print("Starting inference...")
|
| 148 |
+
batch_size = 8
|
| 149 |
+
gen = datagen(
|
| 150 |
+
whisper_chunks=whisper_chunks,
|
| 151 |
+
vae_encode_latents=input_latent_list_cycle,
|
| 152 |
+
batch_size=batch_size,
|
| 153 |
+
delay_frame=0,
|
| 154 |
+
device=device,
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
res_frame_list = []
|
| 158 |
+
for whisper_batch, latent_batch in tqdm(
|
| 159 |
+
gen, total=int(math.ceil(len(whisper_chunks) / batch_size))
|
| 160 |
+
):
|
| 161 |
+
audio_feature_batch = pe(whisper_batch)
|
| 162 |
+
latent_batch = latent_batch.to(dtype=torch.float16)
|
| 163 |
+
|
| 164 |
+
pred_latents = unet.model(
|
| 165 |
+
latent_batch, timesteps, encoder_hidden_states=audio_feature_batch
|
| 166 |
+
).sample
|
| 167 |
+
recon = vae.decode_latents(pred_latents)
|
| 168 |
+
for res_frame in recon:
|
| 169 |
+
res_frame_list.append(res_frame)
|
| 170 |
+
print(f"Generated {len(res_frame_list)} frames")
|
| 171 |
+
|
| 172 |
+
# 7. Blend back to original video
|
| 173 |
+
print("Blending...")
|
| 174 |
+
output_frames_dir = os.path.join(output_dir, f"{input_basename}_output")
|
| 175 |
+
os.makedirs(output_frames_dir, exist_ok=True)
|
| 176 |
+
|
| 177 |
+
for i, res_frame in enumerate(tqdm(res_frame_list)):
|
| 178 |
+
bbox = coord_list_cycle[i % len(coord_list_cycle)]
|
| 179 |
+
ori_frame = copy.deepcopy(frame_list_cycle[i % len(frame_list_cycle)])
|
| 180 |
+
x1, y1, x2, y2 = bbox
|
| 181 |
+
y2 = y2 + 10
|
| 182 |
+
y2 = min(y2, ori_frame.shape[0])
|
| 183 |
+
|
| 184 |
+
try:
|
| 185 |
+
res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
|
| 186 |
+
except Exception as e:
|
| 187 |
+
print(f"Warning: Could not resize frame {i}: {e}")
|
| 188 |
+
continue
|
| 189 |
+
|
| 190 |
+
# MuseTalk v1.5 blending with jaw mode (default)
|
| 191 |
+
combine_frame = get_image(
|
| 192 |
+
ori_frame, res_frame, [x1, y1, x2, y2], mode="jaw", fp=fp
|
| 193 |
+
)
|
| 194 |
+
cv2.imwrite(f"{output_frames_dir}/{i:08d}.png", combine_frame)
|
| 195 |
+
|
| 196 |
+
# 8. Encode to video
|
| 197 |
+
print("Encoding video...")
|
| 198 |
+
cmd = f"ffmpeg -y -v warning -r {fps} -f image2 -i {output_frames_dir}/%08d.png -vcodec libx264 -vf format=yuv420p -crf 18 {video_out_path}"
|
| 199 |
+
os.system(cmd)
|
| 200 |
+
|
| 201 |
+
# 9. Add audio
|
| 202 |
+
print("Adding audio...")
|
| 203 |
+
cmd = f"ffmpeg -y -v warning -i {audio_path} -i {video_out_path} -c:v copy -c:a aac {video_out_path.replace('.mp4', '_final.mp4')}"
|
| 204 |
+
os.system(cmd)
|
| 205 |
+
os.replace(video_out_path.replace(".mp4", "_final.mp4"), video_out_path)
|
| 206 |
+
|
| 207 |
+
# 10. Cleanup
|
| 208 |
+
print("Cleaning up...")
|
| 209 |
+
import shutil
|
| 210 |
+
|
| 211 |
+
shutil.rmtree(save_dir_full)
|
| 212 |
+
shutil.rmtree(output_frames_dir)
|
| 213 |
+
|
| 214 |
+
print(f"\n{'=' * 60}")
|
| 215 |
+
print(f"MUSETALK V1.5 SUCCESS")
|
| 216 |
+
print(f"Output: {video_out_path}")
|
| 217 |
+
print(f"{'=' * 60}\n")
|
| 218 |
+
|
| 219 |
+
return video_out_path
|
musetalk_integration/__init__.py
ADDED
|
File without changes
|
musetalk_integration/models/__init__.py
ADDED
|
File without changes
|
musetalk_integration/models/unet.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import math
|
| 4 |
+
import json
|
| 5 |
+
|
| 6 |
+
from diffusers import UNet2DConditionModel
|
| 7 |
+
import sys
|
| 8 |
+
import time
|
| 9 |
+
import numpy as np
|
| 10 |
+
import os
|
| 11 |
+
|
| 12 |
+
class PositionalEncoding(nn.Module):
|
| 13 |
+
def __init__(self, d_model=384, max_len=5000):
|
| 14 |
+
super(PositionalEncoding, self).__init__()
|
| 15 |
+
pe = torch.zeros(max_len, d_model)
|
| 16 |
+
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
| 17 |
+
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
| 18 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
| 19 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
| 20 |
+
pe = pe.unsqueeze(0)
|
| 21 |
+
self.register_buffer('pe', pe)
|
| 22 |
+
|
| 23 |
+
def forward(self, x):
|
| 24 |
+
b, seq_len, d_model = x.size()
|
| 25 |
+
pe = self.pe[:, :seq_len, :]
|
| 26 |
+
x = x + pe.to(x.device)
|
| 27 |
+
return x
|
| 28 |
+
|
| 29 |
+
class UNet():
|
| 30 |
+
def __init__(self,
|
| 31 |
+
unet_config,
|
| 32 |
+
model_path,
|
| 33 |
+
use_float16=False,
|
| 34 |
+
device=None
|
| 35 |
+
):
|
| 36 |
+
with open(unet_config, 'r') as f:
|
| 37 |
+
unet_config = json.load(f)
|
| 38 |
+
self.model = UNet2DConditionModel(**unet_config)
|
| 39 |
+
self.pe = PositionalEncoding(d_model=384)
|
| 40 |
+
if device != None:
|
| 41 |
+
self.device = device
|
| 42 |
+
else:
|
| 43 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 44 |
+
weights = torch.load(model_path) if torch.cuda.is_available() else torch.load(model_path, map_location=self.device)
|
| 45 |
+
self.model.load_state_dict(weights)
|
| 46 |
+
if use_float16:
|
| 47 |
+
self.model = self.model.half()
|
| 48 |
+
self.model.to(self.device)
|
| 49 |
+
|
| 50 |
+
if __name__ == "__main__":
|
| 51 |
+
unet = UNet()
|
musetalk_integration/models/vae.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from diffusers import AutoencoderKL
|
| 2 |
+
import torch
|
| 3 |
+
import torchvision.transforms as transforms
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import cv2
|
| 6 |
+
import numpy as np
|
| 7 |
+
from PIL import Image
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
class VAE():
|
| 11 |
+
"""
|
| 12 |
+
VAE (Variational Autoencoder) class for image processing.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
def __init__(self, model_path="./models/sd-vae-ft-mse/", resized_img=256, use_float16=False):
|
| 16 |
+
"""
|
| 17 |
+
Initialize the VAE instance.
|
| 18 |
+
|
| 19 |
+
:param model_path: Path to the trained model.
|
| 20 |
+
:param resized_img: The size to which images are resized.
|
| 21 |
+
:param use_float16: Whether to use float16 precision.
|
| 22 |
+
"""
|
| 23 |
+
self.model_path = model_path
|
| 24 |
+
self.vae = AutoencoderKL.from_pretrained(self.model_path)
|
| 25 |
+
|
| 26 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 27 |
+
self.vae.to(self.device)
|
| 28 |
+
|
| 29 |
+
if use_float16:
|
| 30 |
+
self.vae = self.vae.half()
|
| 31 |
+
self._use_float16 = True
|
| 32 |
+
else:
|
| 33 |
+
self._use_float16 = False
|
| 34 |
+
|
| 35 |
+
self.scaling_factor = self.vae.config.scaling_factor
|
| 36 |
+
self.transform = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
| 37 |
+
self._resized_img = resized_img
|
| 38 |
+
self._mask_tensor = self.get_mask_tensor()
|
| 39 |
+
|
| 40 |
+
def get_mask_tensor(self):
|
| 41 |
+
"""
|
| 42 |
+
Creates a mask tensor for image processing.
|
| 43 |
+
:return: A mask tensor.
|
| 44 |
+
"""
|
| 45 |
+
mask_tensor = torch.zeros((self._resized_img,self._resized_img))
|
| 46 |
+
mask_tensor[:self._resized_img//2,:] = 1
|
| 47 |
+
mask_tensor[mask_tensor< 0.5] = 0
|
| 48 |
+
mask_tensor[mask_tensor>= 0.5] = 1
|
| 49 |
+
return mask_tensor
|
| 50 |
+
|
| 51 |
+
def preprocess_img(self,img_name,half_mask=False):
|
| 52 |
+
"""
|
| 53 |
+
Preprocess an image for the VAE.
|
| 54 |
+
|
| 55 |
+
:param img_name: The image file path or a list of image file paths.
|
| 56 |
+
:param half_mask: Whether to apply a half mask to the image.
|
| 57 |
+
:return: A preprocessed image tensor.
|
| 58 |
+
"""
|
| 59 |
+
window = []
|
| 60 |
+
if isinstance(img_name, str):
|
| 61 |
+
window_fnames = [img_name]
|
| 62 |
+
for fname in window_fnames:
|
| 63 |
+
img = cv2.imread(fname)
|
| 64 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 65 |
+
img = cv2.resize(img, (self._resized_img, self._resized_img),
|
| 66 |
+
interpolation=cv2.INTER_LANCZOS4)
|
| 67 |
+
window.append(img)
|
| 68 |
+
else:
|
| 69 |
+
img = cv2.cvtColor(img_name, cv2.COLOR_BGR2RGB)
|
| 70 |
+
window.append(img)
|
| 71 |
+
|
| 72 |
+
x = np.asarray(window) / 255.
|
| 73 |
+
x = np.transpose(x, (3, 0, 1, 2))
|
| 74 |
+
x = torch.squeeze(torch.FloatTensor(x))
|
| 75 |
+
if half_mask:
|
| 76 |
+
x = x * (self._mask_tensor>0.5)
|
| 77 |
+
x = self.transform(x)
|
| 78 |
+
|
| 79 |
+
x = x.unsqueeze(0) # [1, 3, 256, 256] torch tensor
|
| 80 |
+
x = x.to(self.vae.device)
|
| 81 |
+
|
| 82 |
+
return x
|
| 83 |
+
|
| 84 |
+
def encode_latents(self,image):
|
| 85 |
+
"""
|
| 86 |
+
Encode an image into latent variables.
|
| 87 |
+
|
| 88 |
+
:param image: The image tensor to encode.
|
| 89 |
+
:return: The encoded latent variables.
|
| 90 |
+
"""
|
| 91 |
+
with torch.no_grad():
|
| 92 |
+
init_latent_dist = self.vae.encode(image.to(self.vae.dtype)).latent_dist
|
| 93 |
+
init_latents = self.scaling_factor * init_latent_dist.sample()
|
| 94 |
+
return init_latents
|
| 95 |
+
|
| 96 |
+
def decode_latents(self, latents):
|
| 97 |
+
"""
|
| 98 |
+
Decode latent variables back into an image.
|
| 99 |
+
:param latents: The latent variables to decode.
|
| 100 |
+
:return: A NumPy array representing the decoded image.
|
| 101 |
+
"""
|
| 102 |
+
latents = (1/ self.scaling_factor) * latents
|
| 103 |
+
image = self.vae.decode(latents.to(self.vae.dtype)).sample
|
| 104 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
| 105 |
+
image = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
|
| 106 |
+
image = (image * 255).round().astype("uint8")
|
| 107 |
+
image = image[...,::-1] # RGB to BGR
|
| 108 |
+
return image
|
| 109 |
+
|
| 110 |
+
def get_latents_for_unet(self,img):
|
| 111 |
+
"""
|
| 112 |
+
Prepare latent variables for a U-Net model.
|
| 113 |
+
:param img: The image to process.
|
| 114 |
+
:return: A concatenated tensor of latents for U-Net input.
|
| 115 |
+
"""
|
| 116 |
+
|
| 117 |
+
ref_image = self.preprocess_img(img,half_mask=True) # [1, 3, 256, 256] RGB, torch tensor
|
| 118 |
+
masked_latents = self.encode_latents(ref_image) # [1, 4, 32, 32], torch tensor
|
| 119 |
+
ref_image = self.preprocess_img(img,half_mask=False) # [1, 3, 256, 256] RGB, torch tensor
|
| 120 |
+
ref_latents = self.encode_latents(ref_image) # [1, 4, 32, 32], torch tensor
|
| 121 |
+
latent_model_input = torch.cat([masked_latents, ref_latents], dim=1)
|
| 122 |
+
return latent_model_input
|
| 123 |
+
|
| 124 |
+
if __name__ == "__main__":
|
| 125 |
+
vae_mode_path = "./models/sd-vae-ft-mse/"
|
| 126 |
+
vae = VAE(model_path = vae_mode_path,use_float16=False)
|
| 127 |
+
img_path = "./results/sun001_crop/00000.png"
|
| 128 |
+
|
| 129 |
+
crop_imgs_path = "./results/sun001_crop/"
|
| 130 |
+
latents_out_path = "./results/latents/"
|
| 131 |
+
if not os.path.exists(latents_out_path):
|
| 132 |
+
os.mkdir(latents_out_path)
|
| 133 |
+
|
| 134 |
+
files = os.listdir(crop_imgs_path)
|
| 135 |
+
files.sort()
|
| 136 |
+
files = [file for file in files if file.split(".")[-1] == "png"]
|
| 137 |
+
|
| 138 |
+
for file in files:
|
| 139 |
+
index = file.split(".")[0]
|
| 140 |
+
img_path = crop_imgs_path + file
|
| 141 |
+
latents = vae.get_latents_for_unet(img_path)
|
| 142 |
+
print(img_path,"latents",latents.size())
|
| 143 |
+
#torch.save(latents,os.path.join(latents_out_path,index+".pt"))
|
| 144 |
+
#reload_tensor = torch.load('tensor.pt')
|
| 145 |
+
#print(reload_tensor.size())
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
|
musetalk_integration/utils/__init__.py
ADDED
|
File without changes
|
musetalk_integration/utils/audio_processor.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import librosa
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from einops import rearrange
|
| 8 |
+
from transformers import AutoFeatureExtractor
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class AudioProcessor:
|
| 12 |
+
def __init__(self, feature_extractor_path="openai/whisper-tiny/"):
|
| 13 |
+
self.feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_path)
|
| 14 |
+
|
| 15 |
+
def get_audio_feature(self, wav_path, start_index=0, weight_dtype=None):
|
| 16 |
+
if not os.path.exists(wav_path):
|
| 17 |
+
return None
|
| 18 |
+
librosa_output, sampling_rate = librosa.load(wav_path, sr=16000)
|
| 19 |
+
assert sampling_rate == 16000
|
| 20 |
+
# Split audio into 30s segments
|
| 21 |
+
segment_length = 30 * sampling_rate
|
| 22 |
+
segments = [librosa_output[i:i + segment_length] for i in range(0, len(librosa_output), segment_length)]
|
| 23 |
+
|
| 24 |
+
features = []
|
| 25 |
+
for segment in segments:
|
| 26 |
+
audio_feature = self.feature_extractor(
|
| 27 |
+
segment,
|
| 28 |
+
return_tensors="pt",
|
| 29 |
+
sampling_rate=sampling_rate
|
| 30 |
+
).input_features
|
| 31 |
+
if weight_dtype is not None:
|
| 32 |
+
audio_feature = audio_feature.to(dtype=weight_dtype)
|
| 33 |
+
features.append(audio_feature)
|
| 34 |
+
|
| 35 |
+
return features, len(librosa_output)
|
| 36 |
+
|
| 37 |
+
def get_whisper_chunk(
|
| 38 |
+
self,
|
| 39 |
+
whisper_input_features,
|
| 40 |
+
device,
|
| 41 |
+
weight_dtype,
|
| 42 |
+
whisper,
|
| 43 |
+
librosa_length,
|
| 44 |
+
fps=25,
|
| 45 |
+
audio_padding_length_left=2,
|
| 46 |
+
audio_padding_length_right=2,
|
| 47 |
+
):
|
| 48 |
+
audio_feature_length_per_frame = 2 * (audio_padding_length_left + audio_padding_length_right + 1)
|
| 49 |
+
whisper_feature = []
|
| 50 |
+
# Process multiple 30s mel input features
|
| 51 |
+
for input_feature in whisper_input_features:
|
| 52 |
+
input_feature = input_feature.to(device).to(weight_dtype)
|
| 53 |
+
audio_feats = whisper.encoder(input_feature, output_hidden_states=True).hidden_states
|
| 54 |
+
audio_feats = torch.stack(audio_feats, dim=2)
|
| 55 |
+
whisper_feature.append(audio_feats)
|
| 56 |
+
|
| 57 |
+
whisper_feature = torch.cat(whisper_feature, dim=1)
|
| 58 |
+
# Trim the last segment to remove padding
|
| 59 |
+
sr = 16000
|
| 60 |
+
audio_fps = 50
|
| 61 |
+
fps = int(fps)
|
| 62 |
+
whisper_idx_multiplier = audio_fps / fps
|
| 63 |
+
num_frames = math.floor((librosa_length / sr) * fps)
|
| 64 |
+
actual_length = math.floor((librosa_length / sr) * audio_fps)
|
| 65 |
+
whisper_feature = whisper_feature[:,:actual_length,...]
|
| 66 |
+
|
| 67 |
+
# Calculate padding amount
|
| 68 |
+
padding_nums = math.ceil(whisper_idx_multiplier)
|
| 69 |
+
# Add padding at start and end
|
| 70 |
+
whisper_feature = torch.cat([
|
| 71 |
+
torch.zeros_like(whisper_feature[:, :padding_nums * audio_padding_length_left]),
|
| 72 |
+
whisper_feature,
|
| 73 |
+
# Add extra padding to prevent out of bounds
|
| 74 |
+
torch.zeros_like(whisper_feature[:, :padding_nums * 3 * audio_padding_length_right])
|
| 75 |
+
], 1)
|
| 76 |
+
|
| 77 |
+
audio_prompts = []
|
| 78 |
+
for frame_index in range(num_frames):
|
| 79 |
+
try:
|
| 80 |
+
audio_index = math.floor(frame_index * whisper_idx_multiplier)
|
| 81 |
+
audio_clip = whisper_feature[:, audio_index: audio_index + audio_feature_length_per_frame]
|
| 82 |
+
assert audio_clip.shape[1] == audio_feature_length_per_frame
|
| 83 |
+
audio_prompts.append(audio_clip)
|
| 84 |
+
except Exception as e:
|
| 85 |
+
print(f"Error occurred: {e}")
|
| 86 |
+
print(f"whisper_feature.shape: {whisper_feature.shape}")
|
| 87 |
+
print(f"audio_clip.shape: {audio_clip.shape}")
|
| 88 |
+
print(f"num frames: {num_frames}, fps: {fps}, whisper_idx_multiplier: {whisper_idx_multiplier}")
|
| 89 |
+
print(f"frame_index: {frame_index}, audio_index: {audio_index}-{audio_index + audio_feature_length_per_frame}")
|
| 90 |
+
exit()
|
| 91 |
+
|
| 92 |
+
audio_prompts = torch.cat(audio_prompts, dim=0) # T, 10, 5, 384
|
| 93 |
+
audio_prompts = rearrange(audio_prompts, 'b c h w -> b (c h) w')
|
| 94 |
+
return audio_prompts
|
| 95 |
+
|
| 96 |
+
if __name__ == "__main__":
|
| 97 |
+
audio_processor = AudioProcessor()
|
| 98 |
+
wav_path = "./2.wav"
|
| 99 |
+
audio_feature, librosa_feature_length = audio_processor.get_audio_feature(wav_path)
|
| 100 |
+
print("Audio Feature shape:", audio_feature.shape)
|
| 101 |
+
print("librosa_feature_length:", librosa_feature_length)
|
| 102 |
+
|
musetalk_integration/utils/blending.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PIL import Image
|
| 2 |
+
import numpy as np
|
| 3 |
+
import cv2
|
| 4 |
+
import copy
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def get_crop_box(box, expand):
|
| 8 |
+
x, y, x1, y1 = box
|
| 9 |
+
x_c, y_c = (x+x1)//2, (y+y1)//2
|
| 10 |
+
w, h = x1-x, y1-y
|
| 11 |
+
s = int(max(w, h)//2*expand)
|
| 12 |
+
crop_box = [x_c-s, y_c-s, x_c+s, y_c+s]
|
| 13 |
+
return crop_box, s
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def face_seg(image, mode="raw", fp=None):
|
| 17 |
+
"""
|
| 18 |
+
对图像进行面部解析,生成面部区域的掩码。
|
| 19 |
+
|
| 20 |
+
Args:
|
| 21 |
+
image (PIL.Image): 输入图像。
|
| 22 |
+
|
| 23 |
+
Returns:
|
| 24 |
+
PIL.Image: 面部区域的掩码图像。
|
| 25 |
+
"""
|
| 26 |
+
seg_image = fp(image, mode=mode) # 使用 FaceParsing 模型解析面部
|
| 27 |
+
if seg_image is None:
|
| 28 |
+
print("error, no person_segment") # 如果没有检测到面部,返回错误
|
| 29 |
+
return None
|
| 30 |
+
|
| 31 |
+
seg_image = seg_image.resize(image.size) # 将掩码图像调整为输入图像的大小
|
| 32 |
+
return seg_image
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def get_image(image, face, face_box, upper_boundary_ratio=0.5, expand=1.5, mode="raw", fp=None):
|
| 36 |
+
"""
|
| 37 |
+
将裁剪的面部图像粘贴回原始图像,并进行一些处理。
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
image (numpy.ndarray): 原始图像(身体部分)。
|
| 41 |
+
face (numpy.ndarray): 裁剪的面部图像。
|
| 42 |
+
face_box (tuple): 面部边界框的坐标 (x, y, x1, y1)。
|
| 43 |
+
upper_boundary_ratio (float): 用于控制面部区域的保留比例。
|
| 44 |
+
expand (float): 扩展因子,用于放大裁剪框。
|
| 45 |
+
mode: 融合mask构建方式
|
| 46 |
+
|
| 47 |
+
Returns:
|
| 48 |
+
numpy.ndarray: 处理后的图像。
|
| 49 |
+
"""
|
| 50 |
+
# 将 numpy 数组转换为 PIL 图像
|
| 51 |
+
body = Image.fromarray(image[:, :, ::-1]) # 身体部分图像(整张图)
|
| 52 |
+
face = Image.fromarray(face[:, :, ::-1]) # 面部图像
|
| 53 |
+
|
| 54 |
+
x, y, x1, y1 = face_box # 获取面部边界框的坐标
|
| 55 |
+
crop_box, s = get_crop_box(face_box, expand) # 计算扩展后的裁剪框
|
| 56 |
+
x_s, y_s, x_e, y_e = crop_box # 裁剪框的坐标
|
| 57 |
+
face_position = (x, y) # 面部在原始图像中的位置
|
| 58 |
+
|
| 59 |
+
# 从身体图像中裁剪出扩展后的面部区域(下巴到边界有距离)
|
| 60 |
+
face_large = body.crop(crop_box)
|
| 61 |
+
|
| 62 |
+
ori_shape = face_large.size # 裁剪后图像的原始尺寸
|
| 63 |
+
|
| 64 |
+
# 对裁剪后的面部区域进行面部解析,生成掩码
|
| 65 |
+
mask_image = face_seg(face_large, mode=mode, fp=fp)
|
| 66 |
+
|
| 67 |
+
mask_small = mask_image.crop((x - x_s, y - y_s, x1 - x_s, y1 - y_s)) # 裁剪出面部区域的掩码
|
| 68 |
+
|
| 69 |
+
mask_image = Image.new('L', ori_shape, 0) # 创建一个全黑的掩码图像
|
| 70 |
+
mask_image.paste(mask_small, (x - x_s, y - y_s, x1 - x_s, y1 - y_s)) # 将面部掩码粘贴到全黑图像上
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
# 保留面部区域的上半部分(用于控制说话区域)
|
| 74 |
+
width, height = mask_image.size
|
| 75 |
+
top_boundary = int(height * upper_boundary_ratio) # 计算上半部分的边界
|
| 76 |
+
modified_mask_image = Image.new('L', ori_shape, 0) # 创建一个新的全黑掩码图像
|
| 77 |
+
modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary)) # 粘贴上半部分掩码
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
# 对掩码进行高斯模糊,使边缘更平滑
|
| 81 |
+
blur_kernel_size = int(0.05 * ori_shape[0] // 2 * 2) + 1 # 计算模糊核大小
|
| 82 |
+
mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0) # 高斯模糊
|
| 83 |
+
#mask_array = np.array(modified_mask_image)
|
| 84 |
+
mask_image = Image.fromarray(mask_array) # 将模糊后的掩码转换回 PIL 图像
|
| 85 |
+
|
| 86 |
+
# 将裁剪的面部图像粘贴回扩展后的面部区域
|
| 87 |
+
face_large.paste(face, (x - x_s, y - y_s, x1 - x_s, y1 - y_s))
|
| 88 |
+
|
| 89 |
+
body.paste(face_large, crop_box[:2], mask_image)
|
| 90 |
+
|
| 91 |
+
body = np.array(body) # 将 PIL 图像转换回 numpy 数组
|
| 92 |
+
|
| 93 |
+
return body[:, :, ::-1] # 返回处理后的图像(BGR 转 RGB)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def get_image_blending(image, face, face_box, mask_array, crop_box):
|
| 97 |
+
body = Image.fromarray(image[:,:,::-1])
|
| 98 |
+
face = Image.fromarray(face[:,:,::-1])
|
| 99 |
+
|
| 100 |
+
x, y, x1, y1 = face_box
|
| 101 |
+
x_s, y_s, x_e, y_e = crop_box
|
| 102 |
+
face_large = body.crop(crop_box)
|
| 103 |
+
|
| 104 |
+
mask_image = Image.fromarray(mask_array)
|
| 105 |
+
mask_image = mask_image.convert("L")
|
| 106 |
+
face_large.paste(face, (x-x_s, y-y_s, x1-x_s, y1-y_s))
|
| 107 |
+
body.paste(face_large, crop_box[:2], mask_image)
|
| 108 |
+
body = np.array(body)
|
| 109 |
+
return body[:,:,::-1]
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def get_image_prepare_material(image, face_box, upper_boundary_ratio=0.5, expand=1.5, fp=None, mode="raw"):
|
| 113 |
+
body = Image.fromarray(image[:,:,::-1])
|
| 114 |
+
|
| 115 |
+
x, y, x1, y1 = face_box
|
| 116 |
+
#print(x1-x,y1-y)
|
| 117 |
+
crop_box, s = get_crop_box(face_box, expand)
|
| 118 |
+
x_s, y_s, x_e, y_e = crop_box
|
| 119 |
+
|
| 120 |
+
face_large = body.crop(crop_box)
|
| 121 |
+
ori_shape = face_large.size
|
| 122 |
+
|
| 123 |
+
mask_image = face_seg(face_large, mode=mode, fp=fp)
|
| 124 |
+
mask_small = mask_image.crop((x-x_s, y-y_s, x1-x_s, y1-y_s))
|
| 125 |
+
mask_image = Image.new('L', ori_shape, 0)
|
| 126 |
+
mask_image.paste(mask_small, (x-x_s, y-y_s, x1-x_s, y1-y_s))
|
| 127 |
+
|
| 128 |
+
# keep upper_boundary_ratio of talking area
|
| 129 |
+
width, height = mask_image.size
|
| 130 |
+
top_boundary = int(height * upper_boundary_ratio)
|
| 131 |
+
modified_mask_image = Image.new('L', ori_shape, 0)
|
| 132 |
+
modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary))
|
| 133 |
+
|
| 134 |
+
blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
|
| 135 |
+
mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
|
| 136 |
+
return mask_array, crop_box
|
musetalk_integration/utils/dwpose/__init__.py
ADDED
|
File without changes
|
musetalk_integration/utils/dwpose/default_runtime.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
default_scope = 'mmpose'
|
| 2 |
+
|
| 3 |
+
# hooks
|
| 4 |
+
default_hooks = dict(
|
| 5 |
+
timer=dict(type='IterTimerHook'),
|
| 6 |
+
logger=dict(type='LoggerHook', interval=50),
|
| 7 |
+
param_scheduler=dict(type='ParamSchedulerHook'),
|
| 8 |
+
checkpoint=dict(type='CheckpointHook', interval=10),
|
| 9 |
+
sampler_seed=dict(type='DistSamplerSeedHook'),
|
| 10 |
+
visualization=dict(type='PoseVisualizationHook', enable=False),
|
| 11 |
+
badcase=dict(
|
| 12 |
+
type='BadCaseAnalysisHook',
|
| 13 |
+
enable=False,
|
| 14 |
+
out_dir='badcase',
|
| 15 |
+
metric_type='loss',
|
| 16 |
+
badcase_thr=5))
|
| 17 |
+
|
| 18 |
+
# custom hooks
|
| 19 |
+
custom_hooks = [
|
| 20 |
+
# Synchronize model buffers such as running_mean and running_var in BN
|
| 21 |
+
# at the end of each epoch
|
| 22 |
+
dict(type='SyncBuffersHook')
|
| 23 |
+
]
|
| 24 |
+
|
| 25 |
+
# multi-processing backend
|
| 26 |
+
env_cfg = dict(
|
| 27 |
+
cudnn_benchmark=False,
|
| 28 |
+
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
|
| 29 |
+
dist_cfg=dict(backend='nccl'),
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
# visualizer
|
| 33 |
+
vis_backends = [
|
| 34 |
+
dict(type='LocalVisBackend'),
|
| 35 |
+
# dict(type='TensorboardVisBackend'),
|
| 36 |
+
# dict(type='WandbVisBackend'),
|
| 37 |
+
]
|
| 38 |
+
visualizer = dict(
|
| 39 |
+
type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
|
| 40 |
+
|
| 41 |
+
# logger
|
| 42 |
+
log_processor = dict(
|
| 43 |
+
type='LogProcessor', window_size=50, by_epoch=True, num_digits=6)
|
| 44 |
+
log_level = 'INFO'
|
| 45 |
+
load_from = None
|
| 46 |
+
resume = False
|
| 47 |
+
|
| 48 |
+
# file I/O backend
|
| 49 |
+
backend_args = dict(backend='local')
|
| 50 |
+
|
| 51 |
+
# training/validation/testing progress
|
| 52 |
+
train_cfg = dict(by_epoch=True)
|
| 53 |
+
val_cfg = dict()
|
| 54 |
+
test_cfg = dict()
|
musetalk_integration/utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py
ADDED
|
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#_base_ = ['../../../_base_/default_runtime.py']
|
| 2 |
+
_base_ = ['default_runtime.py']
|
| 3 |
+
|
| 4 |
+
# runtime
|
| 5 |
+
max_epochs = 270
|
| 6 |
+
stage2_num_epochs = 30
|
| 7 |
+
base_lr = 4e-3
|
| 8 |
+
train_batch_size = 32
|
| 9 |
+
val_batch_size = 32
|
| 10 |
+
|
| 11 |
+
train_cfg = dict(max_epochs=max_epochs, val_interval=10)
|
| 12 |
+
randomness = dict(seed=21)
|
| 13 |
+
|
| 14 |
+
# optimizer
|
| 15 |
+
optim_wrapper = dict(
|
| 16 |
+
type='OptimWrapper',
|
| 17 |
+
optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),
|
| 18 |
+
paramwise_cfg=dict(
|
| 19 |
+
norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))
|
| 20 |
+
|
| 21 |
+
# learning rate
|
| 22 |
+
param_scheduler = [
|
| 23 |
+
dict(
|
| 24 |
+
type='LinearLR',
|
| 25 |
+
start_factor=1.0e-5,
|
| 26 |
+
by_epoch=False,
|
| 27 |
+
begin=0,
|
| 28 |
+
end=1000),
|
| 29 |
+
dict(
|
| 30 |
+
# use cosine lr from 150 to 300 epoch
|
| 31 |
+
type='CosineAnnealingLR',
|
| 32 |
+
eta_min=base_lr * 0.05,
|
| 33 |
+
begin=max_epochs // 2,
|
| 34 |
+
end=max_epochs,
|
| 35 |
+
T_max=max_epochs // 2,
|
| 36 |
+
by_epoch=True,
|
| 37 |
+
convert_to_iter_based=True),
|
| 38 |
+
]
|
| 39 |
+
|
| 40 |
+
# automatically scaling LR based on the actual training batch size
|
| 41 |
+
auto_scale_lr = dict(base_batch_size=512)
|
| 42 |
+
|
| 43 |
+
# codec settings
|
| 44 |
+
codec = dict(
|
| 45 |
+
type='SimCCLabel',
|
| 46 |
+
input_size=(288, 384),
|
| 47 |
+
sigma=(6., 6.93),
|
| 48 |
+
simcc_split_ratio=2.0,
|
| 49 |
+
normalize=False,
|
| 50 |
+
use_dark=False)
|
| 51 |
+
|
| 52 |
+
# model settings
|
| 53 |
+
model = dict(
|
| 54 |
+
type='TopdownPoseEstimator',
|
| 55 |
+
data_preprocessor=dict(
|
| 56 |
+
type='PoseDataPreprocessor',
|
| 57 |
+
mean=[123.675, 116.28, 103.53],
|
| 58 |
+
std=[58.395, 57.12, 57.375],
|
| 59 |
+
bgr_to_rgb=True),
|
| 60 |
+
backbone=dict(
|
| 61 |
+
_scope_='mmdet',
|
| 62 |
+
type='CSPNeXt',
|
| 63 |
+
arch='P5',
|
| 64 |
+
expand_ratio=0.5,
|
| 65 |
+
deepen_factor=1.,
|
| 66 |
+
widen_factor=1.,
|
| 67 |
+
out_indices=(4, ),
|
| 68 |
+
channel_attention=True,
|
| 69 |
+
norm_cfg=dict(type='SyncBN'),
|
| 70 |
+
act_cfg=dict(type='SiLU'),
|
| 71 |
+
init_cfg=dict(
|
| 72 |
+
type='Pretrained',
|
| 73 |
+
prefix='backbone.',
|
| 74 |
+
checkpoint='https://download.openmmlab.com/mmpose/v1/projects/'
|
| 75 |
+
'rtmpose/cspnext-l_udp-aic-coco_210e-256x192-273b7631_20230130.pth' # noqa: E501
|
| 76 |
+
)),
|
| 77 |
+
head=dict(
|
| 78 |
+
type='RTMCCHead',
|
| 79 |
+
in_channels=1024,
|
| 80 |
+
out_channels=133,
|
| 81 |
+
input_size=codec['input_size'],
|
| 82 |
+
in_featuremap_size=(9, 12),
|
| 83 |
+
simcc_split_ratio=codec['simcc_split_ratio'],
|
| 84 |
+
final_layer_kernel_size=7,
|
| 85 |
+
gau_cfg=dict(
|
| 86 |
+
hidden_dims=256,
|
| 87 |
+
s=128,
|
| 88 |
+
expansion_factor=2,
|
| 89 |
+
dropout_rate=0.,
|
| 90 |
+
drop_path=0.,
|
| 91 |
+
act_fn='SiLU',
|
| 92 |
+
use_rel_bias=False,
|
| 93 |
+
pos_enc=False),
|
| 94 |
+
loss=dict(
|
| 95 |
+
type='KLDiscretLoss',
|
| 96 |
+
use_target_weight=True,
|
| 97 |
+
beta=10.,
|
| 98 |
+
label_softmax=True),
|
| 99 |
+
decoder=codec),
|
| 100 |
+
test_cfg=dict(flip_test=True, ))
|
| 101 |
+
|
| 102 |
+
# base dataset settings
|
| 103 |
+
dataset_type = 'UBody2dDataset'
|
| 104 |
+
data_mode = 'topdown'
|
| 105 |
+
data_root = 'data/UBody/'
|
| 106 |
+
|
| 107 |
+
backend_args = dict(backend='local')
|
| 108 |
+
|
| 109 |
+
scenes = [
|
| 110 |
+
'Magic_show', 'Entertainment', 'ConductMusic', 'Online_class', 'TalkShow',
|
| 111 |
+
'Speech', 'Fitness', 'Interview', 'Olympic', 'TVShow', 'Singing',
|
| 112 |
+
'SignLanguage', 'Movie', 'LiveVlog', 'VideoConference'
|
| 113 |
+
]
|
| 114 |
+
|
| 115 |
+
train_datasets = [
|
| 116 |
+
dict(
|
| 117 |
+
type='CocoWholeBodyDataset',
|
| 118 |
+
data_root='data/coco/',
|
| 119 |
+
data_mode=data_mode,
|
| 120 |
+
ann_file='annotations/coco_wholebody_train_v1.0.json',
|
| 121 |
+
data_prefix=dict(img='train2017/'),
|
| 122 |
+
pipeline=[])
|
| 123 |
+
]
|
| 124 |
+
|
| 125 |
+
for scene in scenes:
|
| 126 |
+
train_dataset = dict(
|
| 127 |
+
type=dataset_type,
|
| 128 |
+
data_root=data_root,
|
| 129 |
+
data_mode=data_mode,
|
| 130 |
+
ann_file=f'annotations/{scene}/train_annotations.json',
|
| 131 |
+
data_prefix=dict(img='images/'),
|
| 132 |
+
pipeline=[],
|
| 133 |
+
sample_interval=10)
|
| 134 |
+
train_datasets.append(train_dataset)
|
| 135 |
+
|
| 136 |
+
# pipelines
|
| 137 |
+
train_pipeline = [
|
| 138 |
+
dict(type='LoadImage', backend_args=backend_args),
|
| 139 |
+
dict(type='GetBBoxCenterScale'),
|
| 140 |
+
dict(type='RandomFlip', direction='horizontal'),
|
| 141 |
+
dict(type='RandomHalfBody'),
|
| 142 |
+
dict(
|
| 143 |
+
type='RandomBBoxTransform', scale_factor=[0.5, 1.5], rotate_factor=90),
|
| 144 |
+
dict(type='TopdownAffine', input_size=codec['input_size']),
|
| 145 |
+
dict(type='mmdet.YOLOXHSVRandomAug'),
|
| 146 |
+
dict(
|
| 147 |
+
type='Albumentation',
|
| 148 |
+
transforms=[
|
| 149 |
+
dict(type='Blur', p=0.1),
|
| 150 |
+
dict(type='MedianBlur', p=0.1),
|
| 151 |
+
dict(
|
| 152 |
+
type='CoarseDropout',
|
| 153 |
+
max_holes=1,
|
| 154 |
+
max_height=0.4,
|
| 155 |
+
max_width=0.4,
|
| 156 |
+
min_holes=1,
|
| 157 |
+
min_height=0.2,
|
| 158 |
+
min_width=0.2,
|
| 159 |
+
p=1.0),
|
| 160 |
+
]),
|
| 161 |
+
dict(type='GenerateTarget', encoder=codec),
|
| 162 |
+
dict(type='PackPoseInputs')
|
| 163 |
+
]
|
| 164 |
+
val_pipeline = [
|
| 165 |
+
dict(type='LoadImage', backend_args=backend_args),
|
| 166 |
+
dict(type='GetBBoxCenterScale'),
|
| 167 |
+
dict(type='TopdownAffine', input_size=codec['input_size']),
|
| 168 |
+
dict(type='PackPoseInputs')
|
| 169 |
+
]
|
| 170 |
+
|
| 171 |
+
train_pipeline_stage2 = [
|
| 172 |
+
dict(type='LoadImage', backend_args=backend_args),
|
| 173 |
+
dict(type='GetBBoxCenterScale'),
|
| 174 |
+
dict(type='RandomFlip', direction='horizontal'),
|
| 175 |
+
dict(type='RandomHalfBody'),
|
| 176 |
+
dict(
|
| 177 |
+
type='RandomBBoxTransform',
|
| 178 |
+
shift_factor=0.,
|
| 179 |
+
scale_factor=[0.5, 1.5],
|
| 180 |
+
rotate_factor=90),
|
| 181 |
+
dict(type='TopdownAffine', input_size=codec['input_size']),
|
| 182 |
+
dict(type='mmdet.YOLOXHSVRandomAug'),
|
| 183 |
+
dict(
|
| 184 |
+
type='Albumentation',
|
| 185 |
+
transforms=[
|
| 186 |
+
dict(type='Blur', p=0.1),
|
| 187 |
+
dict(type='MedianBlur', p=0.1),
|
| 188 |
+
dict(
|
| 189 |
+
type='CoarseDropout',
|
| 190 |
+
max_holes=1,
|
| 191 |
+
max_height=0.4,
|
| 192 |
+
max_width=0.4,
|
| 193 |
+
min_holes=1,
|
| 194 |
+
min_height=0.2,
|
| 195 |
+
min_width=0.2,
|
| 196 |
+
p=0.5),
|
| 197 |
+
]),
|
| 198 |
+
dict(type='GenerateTarget', encoder=codec),
|
| 199 |
+
dict(type='PackPoseInputs')
|
| 200 |
+
]
|
| 201 |
+
|
| 202 |
+
# data loaders
|
| 203 |
+
train_dataloader = dict(
|
| 204 |
+
batch_size=train_batch_size,
|
| 205 |
+
num_workers=10,
|
| 206 |
+
persistent_workers=True,
|
| 207 |
+
sampler=dict(type='DefaultSampler', shuffle=True),
|
| 208 |
+
dataset=dict(
|
| 209 |
+
type='CombinedDataset',
|
| 210 |
+
metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'),
|
| 211 |
+
datasets=train_datasets,
|
| 212 |
+
pipeline=train_pipeline,
|
| 213 |
+
test_mode=False,
|
| 214 |
+
))
|
| 215 |
+
|
| 216 |
+
val_dataloader = dict(
|
| 217 |
+
batch_size=val_batch_size,
|
| 218 |
+
num_workers=10,
|
| 219 |
+
persistent_workers=True,
|
| 220 |
+
drop_last=False,
|
| 221 |
+
sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
|
| 222 |
+
dataset=dict(
|
| 223 |
+
type='CocoWholeBodyDataset',
|
| 224 |
+
data_root=data_root,
|
| 225 |
+
data_mode=data_mode,
|
| 226 |
+
ann_file='data/coco/annotations/coco_wholebody_val_v1.0.json',
|
| 227 |
+
bbox_file='data/coco/person_detection_results/'
|
| 228 |
+
'COCO_val2017_detections_AP_H_56_person.json',
|
| 229 |
+
data_prefix=dict(img='coco/val2017/'),
|
| 230 |
+
test_mode=True,
|
| 231 |
+
pipeline=val_pipeline,
|
| 232 |
+
))
|
| 233 |
+
test_dataloader = val_dataloader
|
| 234 |
+
|
| 235 |
+
# hooks
|
| 236 |
+
default_hooks = dict(
|
| 237 |
+
checkpoint=dict(
|
| 238 |
+
save_best='coco-wholebody/AP', rule='greater', max_keep_ckpts=1))
|
| 239 |
+
|
| 240 |
+
custom_hooks = [
|
| 241 |
+
dict(
|
| 242 |
+
type='EMAHook',
|
| 243 |
+
ema_type='ExpMomentumEMA',
|
| 244 |
+
momentum=0.0002,
|
| 245 |
+
update_buffers=True,
|
| 246 |
+
priority=49),
|
| 247 |
+
dict(
|
| 248 |
+
type='mmdet.PipelineSwitchHook',
|
| 249 |
+
switch_epoch=max_epochs - stage2_num_epochs,
|
| 250 |
+
switch_pipeline=train_pipeline_stage2)
|
| 251 |
+
]
|
| 252 |
+
|
| 253 |
+
# evaluators
|
| 254 |
+
val_evaluator = dict(
|
| 255 |
+
type='CocoWholeBodyMetric',
|
| 256 |
+
ann_file='data/coco/annotations/coco_wholebody_val_v1.0.json')
|
| 257 |
+
test_evaluator = val_evaluator
|
musetalk_integration/utils/face_detection/README.md
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
The code for Face Detection in this folder has been taken from the wonderful [face_alignment](https://github.com/1adrianb/face-alignment) repository. This has been modified to take batches of faces at a time.
|
musetalk_integration/utils/face_detection/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
|
| 3 |
+
__author__ = """Adrian Bulat"""
|
| 4 |
+
__email__ = 'adrian.bulat@nottingham.ac.uk'
|
| 5 |
+
__version__ = '1.0.1'
|
| 6 |
+
|
| 7 |
+
from .api import FaceAlignment, LandmarksType, NetworkSize, YOLOv8_face
|
musetalk_integration/utils/face_detection/api.py
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import print_function
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
from torch.utils.model_zoo import load_url
|
| 5 |
+
from enum import Enum
|
| 6 |
+
import numpy as np
|
| 7 |
+
import cv2
|
| 8 |
+
try:
|
| 9 |
+
import urllib.request as request_file
|
| 10 |
+
except BaseException:
|
| 11 |
+
import urllib as request_file
|
| 12 |
+
|
| 13 |
+
from .models import FAN, ResNetDepth
|
| 14 |
+
from .utils import *
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class LandmarksType(Enum):
|
| 18 |
+
"""Enum class defining the type of landmarks to detect.
|
| 19 |
+
|
| 20 |
+
``_2D`` - the detected points ``(x,y)`` are detected in a 2D space and follow the visible contour of the face
|
| 21 |
+
``_2halfD`` - this points represent the projection of the 3D points into 3D
|
| 22 |
+
``_3D`` - detect the points ``(x,y,z)``` in a 3D space
|
| 23 |
+
|
| 24 |
+
"""
|
| 25 |
+
_2D = 1
|
| 26 |
+
_2halfD = 2
|
| 27 |
+
_3D = 3
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class NetworkSize(Enum):
|
| 31 |
+
# TINY = 1
|
| 32 |
+
# SMALL = 2
|
| 33 |
+
# MEDIUM = 3
|
| 34 |
+
LARGE = 4
|
| 35 |
+
|
| 36 |
+
def __new__(cls, value):
|
| 37 |
+
member = object.__new__(cls)
|
| 38 |
+
member._value_ = value
|
| 39 |
+
return member
|
| 40 |
+
|
| 41 |
+
def __int__(self):
|
| 42 |
+
return self.value
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class FaceAlignment:
|
| 47 |
+
def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
|
| 48 |
+
device='cuda', flip_input=False, face_detector='sfd', verbose=False):
|
| 49 |
+
self.device = device
|
| 50 |
+
self.flip_input = flip_input
|
| 51 |
+
self.landmarks_type = landmarks_type
|
| 52 |
+
self.verbose = verbose
|
| 53 |
+
|
| 54 |
+
network_size = int(network_size)
|
| 55 |
+
|
| 56 |
+
if 'cuda' in device:
|
| 57 |
+
torch.backends.cudnn.benchmark = True
|
| 58 |
+
# torch.backends.cuda.matmul.allow_tf32 = False
|
| 59 |
+
# torch.backends.cudnn.benchmark = True
|
| 60 |
+
# torch.backends.cudnn.deterministic = False
|
| 61 |
+
# torch.backends.cudnn.allow_tf32 = True
|
| 62 |
+
print('cuda start')
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
# Get the face detector
|
| 66 |
+
face_detector_module = __import__('face_detection.detection.' + face_detector,
|
| 67 |
+
globals(), locals(), [face_detector], 0)
|
| 68 |
+
|
| 69 |
+
self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose)
|
| 70 |
+
|
| 71 |
+
def get_detections_for_batch(self, images):
|
| 72 |
+
images = images[..., ::-1]
|
| 73 |
+
detected_faces = self.face_detector.detect_from_batch(images.copy())
|
| 74 |
+
results = []
|
| 75 |
+
|
| 76 |
+
for i, d in enumerate(detected_faces):
|
| 77 |
+
if len(d) == 0:
|
| 78 |
+
results.append(None)
|
| 79 |
+
continue
|
| 80 |
+
d = d[0]
|
| 81 |
+
d = np.clip(d, 0, None)
|
| 82 |
+
|
| 83 |
+
x1, y1, x2, y2 = map(int, d[:-1])
|
| 84 |
+
results.append((x1, y1, x2, y2))
|
| 85 |
+
|
| 86 |
+
return results
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class YOLOv8_face:
|
| 90 |
+
def __init__(self, path = 'face_detection/weights/yolov8n-face.onnx', conf_thres=0.2, iou_thres=0.5):
|
| 91 |
+
self.conf_threshold = conf_thres
|
| 92 |
+
self.iou_threshold = iou_thres
|
| 93 |
+
self.class_names = ['face']
|
| 94 |
+
self.num_classes = len(self.class_names)
|
| 95 |
+
# Initialize model
|
| 96 |
+
self.net = cv2.dnn.readNet(path)
|
| 97 |
+
self.input_height = 640
|
| 98 |
+
self.input_width = 640
|
| 99 |
+
self.reg_max = 16
|
| 100 |
+
|
| 101 |
+
self.project = np.arange(self.reg_max)
|
| 102 |
+
self.strides = (8, 16, 32)
|
| 103 |
+
self.feats_hw = [(math.ceil(self.input_height / self.strides[i]), math.ceil(self.input_width / self.strides[i])) for i in range(len(self.strides))]
|
| 104 |
+
self.anchors = self.make_anchors(self.feats_hw)
|
| 105 |
+
|
| 106 |
+
def make_anchors(self, feats_hw, grid_cell_offset=0.5):
|
| 107 |
+
"""Generate anchors from features."""
|
| 108 |
+
anchor_points = {}
|
| 109 |
+
for i, stride in enumerate(self.strides):
|
| 110 |
+
h,w = feats_hw[i]
|
| 111 |
+
x = np.arange(0, w) + grid_cell_offset # shift x
|
| 112 |
+
y = np.arange(0, h) + grid_cell_offset # shift y
|
| 113 |
+
sx, sy = np.meshgrid(x, y)
|
| 114 |
+
# sy, sx = np.meshgrid(y, x)
|
| 115 |
+
anchor_points[stride] = np.stack((sx, sy), axis=-1).reshape(-1, 2)
|
| 116 |
+
return anchor_points
|
| 117 |
+
|
| 118 |
+
def softmax(self, x, axis=1):
|
| 119 |
+
x_exp = np.exp(x)
|
| 120 |
+
# 如果是列向量,则axis=0
|
| 121 |
+
x_sum = np.sum(x_exp, axis=axis, keepdims=True)
|
| 122 |
+
s = x_exp / x_sum
|
| 123 |
+
return s
|
| 124 |
+
|
| 125 |
+
def resize_image(self, srcimg, keep_ratio=True):
|
| 126 |
+
top, left, newh, neww = 0, 0, self.input_width, self.input_height
|
| 127 |
+
if keep_ratio and srcimg.shape[0] != srcimg.shape[1]:
|
| 128 |
+
hw_scale = srcimg.shape[0] / srcimg.shape[1]
|
| 129 |
+
if hw_scale > 1:
|
| 130 |
+
newh, neww = self.input_height, int(self.input_width / hw_scale)
|
| 131 |
+
img = cv2.resize(srcimg, (neww, newh), interpolation=cv2.INTER_AREA)
|
| 132 |
+
left = int((self.input_width - neww) * 0.5)
|
| 133 |
+
img = cv2.copyMakeBorder(img, 0, 0, left, self.input_width - neww - left, cv2.BORDER_CONSTANT,
|
| 134 |
+
value=(0, 0, 0)) # add border
|
| 135 |
+
else:
|
| 136 |
+
newh, neww = int(self.input_height * hw_scale), self.input_width
|
| 137 |
+
img = cv2.resize(srcimg, (neww, newh), interpolation=cv2.INTER_AREA)
|
| 138 |
+
top = int((self.input_height - newh) * 0.5)
|
| 139 |
+
img = cv2.copyMakeBorder(img, top, self.input_height - newh - top, 0, 0, cv2.BORDER_CONSTANT,
|
| 140 |
+
value=(0, 0, 0))
|
| 141 |
+
else:
|
| 142 |
+
img = cv2.resize(srcimg, (self.input_width, self.input_height), interpolation=cv2.INTER_AREA)
|
| 143 |
+
return img, newh, neww, top, left
|
| 144 |
+
|
| 145 |
+
def detect(self, srcimg):
|
| 146 |
+
input_img, newh, neww, padh, padw = self.resize_image(cv2.cvtColor(srcimg, cv2.COLOR_BGR2RGB))
|
| 147 |
+
scale_h, scale_w = srcimg.shape[0]/newh, srcimg.shape[1]/neww
|
| 148 |
+
input_img = input_img.astype(np.float32) / 255.0
|
| 149 |
+
|
| 150 |
+
blob = cv2.dnn.blobFromImage(input_img)
|
| 151 |
+
self.net.setInput(blob)
|
| 152 |
+
outputs = self.net.forward(self.net.getUnconnectedOutLayersNames())
|
| 153 |
+
# if isinstance(outputs, tuple):
|
| 154 |
+
# outputs = list(outputs)
|
| 155 |
+
# if float(cv2.__version__[:3])>=4.7:
|
| 156 |
+
# outputs = [outputs[2], outputs[0], outputs[1]] ###opencv4.7需要这一步,opencv4.5不需要
|
| 157 |
+
# Perform inference on the image
|
| 158 |
+
det_bboxes, det_conf, det_classid, landmarks = self.post_process(outputs, scale_h, scale_w, padh, padw)
|
| 159 |
+
return det_bboxes, det_conf, det_classid, landmarks
|
| 160 |
+
|
| 161 |
+
def post_process(self, preds, scale_h, scale_w, padh, padw):
|
| 162 |
+
bboxes, scores, landmarks = [], [], []
|
| 163 |
+
for i, pred in enumerate(preds):
|
| 164 |
+
stride = int(self.input_height/pred.shape[2])
|
| 165 |
+
pred = pred.transpose((0, 2, 3, 1))
|
| 166 |
+
|
| 167 |
+
box = pred[..., :self.reg_max * 4]
|
| 168 |
+
cls = 1 / (1 + np.exp(-pred[..., self.reg_max * 4:-15])).reshape((-1,1))
|
| 169 |
+
kpts = pred[..., -15:].reshape((-1,15)) ### x1,y1,score1, ..., x5,y5,score5
|
| 170 |
+
|
| 171 |
+
# tmp = box.reshape(self.feats_hw[i][0], self.feats_hw[i][1], 4, self.reg_max)
|
| 172 |
+
tmp = box.reshape(-1, 4, self.reg_max)
|
| 173 |
+
bbox_pred = self.softmax(tmp, axis=-1)
|
| 174 |
+
bbox_pred = np.dot(bbox_pred, self.project).reshape((-1,4))
|
| 175 |
+
|
| 176 |
+
bbox = self.distance2bbox(self.anchors[stride], bbox_pred, max_shape=(self.input_height, self.input_width)) * stride
|
| 177 |
+
kpts[:, 0::3] = (kpts[:, 0::3] * 2.0 + (self.anchors[stride][:, 0].reshape((-1,1)) - 0.5)) * stride
|
| 178 |
+
kpts[:, 1::3] = (kpts[:, 1::3] * 2.0 + (self.anchors[stride][:, 1].reshape((-1,1)) - 0.5)) * stride
|
| 179 |
+
kpts[:, 2::3] = 1 / (1+np.exp(-kpts[:, 2::3]))
|
| 180 |
+
|
| 181 |
+
bbox -= np.array([[padw, padh, padw, padh]]) ###合理使用广播法则
|
| 182 |
+
bbox *= np.array([[scale_w, scale_h, scale_w, scale_h]])
|
| 183 |
+
kpts -= np.tile(np.array([padw, padh, 0]), 5).reshape((1,15))
|
| 184 |
+
kpts *= np.tile(np.array([scale_w, scale_h, 1]), 5).reshape((1,15))
|
| 185 |
+
|
| 186 |
+
bboxes.append(bbox)
|
| 187 |
+
scores.append(cls)
|
| 188 |
+
landmarks.append(kpts)
|
| 189 |
+
|
| 190 |
+
bboxes = np.concatenate(bboxes, axis=0)
|
| 191 |
+
scores = np.concatenate(scores, axis=0)
|
| 192 |
+
landmarks = np.concatenate(landmarks, axis=0)
|
| 193 |
+
|
| 194 |
+
bboxes_wh = bboxes.copy()
|
| 195 |
+
bboxes_wh[:, 2:4] = bboxes[:, 2:4] - bboxes[:, 0:2] ####xywh
|
| 196 |
+
classIds = np.argmax(scores, axis=1)
|
| 197 |
+
confidences = np.max(scores, axis=1) ####max_class_confidence
|
| 198 |
+
|
| 199 |
+
mask = confidences>self.conf_threshold
|
| 200 |
+
bboxes_wh = bboxes_wh[mask] ###合理使用广播法则
|
| 201 |
+
confidences = confidences[mask]
|
| 202 |
+
classIds = classIds[mask]
|
| 203 |
+
landmarks = landmarks[mask]
|
| 204 |
+
|
| 205 |
+
indices = cv2.dnn.NMSBoxes(bboxes_wh.tolist(), confidences.tolist(), self.conf_threshold,
|
| 206 |
+
self.iou_threshold).flatten()
|
| 207 |
+
if len(indices) > 0:
|
| 208 |
+
mlvl_bboxes = bboxes_wh[indices]
|
| 209 |
+
confidences = confidences[indices]
|
| 210 |
+
classIds = classIds[indices]
|
| 211 |
+
landmarks = landmarks[indices]
|
| 212 |
+
return mlvl_bboxes, confidences, classIds, landmarks
|
| 213 |
+
else:
|
| 214 |
+
print('nothing detect')
|
| 215 |
+
return np.array([]), np.array([]), np.array([]), np.array([])
|
| 216 |
+
|
| 217 |
+
def distance2bbox(self, points, distance, max_shape=None):
|
| 218 |
+
x1 = points[:, 0] - distance[:, 0]
|
| 219 |
+
y1 = points[:, 1] - distance[:, 1]
|
| 220 |
+
x2 = points[:, 0] + distance[:, 2]
|
| 221 |
+
y2 = points[:, 1] + distance[:, 3]
|
| 222 |
+
if max_shape is not None:
|
| 223 |
+
x1 = np.clip(x1, 0, max_shape[1])
|
| 224 |
+
y1 = np.clip(y1, 0, max_shape[0])
|
| 225 |
+
x2 = np.clip(x2, 0, max_shape[1])
|
| 226 |
+
y2 = np.clip(y2, 0, max_shape[0])
|
| 227 |
+
return np.stack([x1, y1, x2, y2], axis=-1)
|
| 228 |
+
|
| 229 |
+
def draw_detections(self, image, boxes, scores, kpts):
|
| 230 |
+
for box, score, kp in zip(boxes, scores, kpts):
|
| 231 |
+
x, y, w, h = box.astype(int)
|
| 232 |
+
# Draw rectangle
|
| 233 |
+
cv2.rectangle(image, (x, y), (x + w, y + h), (0, 0, 255), thickness=3)
|
| 234 |
+
cv2.putText(image, "face:"+str(round(score,2)), (x, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), thickness=2)
|
| 235 |
+
for i in range(5):
|
| 236 |
+
cv2.circle(image, (int(kp[i * 3]), int(kp[i * 3 + 1])), 4, (0, 255, 0), thickness=-1)
|
| 237 |
+
# cv2.putText(image, str(i), (int(kp[i * 3]), int(kp[i * 3 + 1]) - 10), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), thickness=1)
|
| 238 |
+
return image
|
| 239 |
+
|
| 240 |
+
ROOT = os.path.dirname(os.path.abspath(__file__))
|
musetalk_integration/utils/face_detection/detection/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .core import FaceDetector
|
musetalk_integration/utils/face_detection/detection/core.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import glob
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import cv2
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class FaceDetector(object):
|
| 10 |
+
"""An abstract class representing a face detector.
|
| 11 |
+
|
| 12 |
+
Any other face detection implementation must subclass it. All subclasses
|
| 13 |
+
must implement ``detect_from_image``, that return a list of detected
|
| 14 |
+
bounding boxes. Optionally, for speed considerations detect from path is
|
| 15 |
+
recommended.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
def __init__(self, device, verbose):
|
| 19 |
+
self.device = device
|
| 20 |
+
self.verbose = verbose
|
| 21 |
+
|
| 22 |
+
if verbose:
|
| 23 |
+
if 'cpu' in device:
|
| 24 |
+
logger = logging.getLogger(__name__)
|
| 25 |
+
logger.warning("Detection running on CPU, this may be potentially slow.")
|
| 26 |
+
|
| 27 |
+
if 'cpu' not in device and 'cuda' not in device:
|
| 28 |
+
if verbose:
|
| 29 |
+
logger.error("Expected values for device are: {cpu, cuda} but got: %s", device)
|
| 30 |
+
raise ValueError
|
| 31 |
+
|
| 32 |
+
def detect_from_image(self, tensor_or_path):
|
| 33 |
+
"""Detects faces in a given image.
|
| 34 |
+
|
| 35 |
+
This function detects the faces present in a provided BGR(usually)
|
| 36 |
+
image. The input can be either the image itself or the path to it.
|
| 37 |
+
|
| 38 |
+
Arguments:
|
| 39 |
+
tensor_or_path {numpy.ndarray, torch.tensor or string} -- the path
|
| 40 |
+
to an image or the image itself.
|
| 41 |
+
|
| 42 |
+
Example::
|
| 43 |
+
|
| 44 |
+
>>> path_to_image = 'data/image_01.jpg'
|
| 45 |
+
... detected_faces = detect_from_image(path_to_image)
|
| 46 |
+
[A list of bounding boxes (x1, y1, x2, y2)]
|
| 47 |
+
>>> image = cv2.imread(path_to_image)
|
| 48 |
+
... detected_faces = detect_from_image(image)
|
| 49 |
+
[A list of bounding boxes (x1, y1, x2, y2)]
|
| 50 |
+
|
| 51 |
+
"""
|
| 52 |
+
raise NotImplementedError
|
| 53 |
+
|
| 54 |
+
def detect_from_directory(self, path, extensions=['.jpg', '.png'], recursive=False, show_progress_bar=True):
|
| 55 |
+
"""Detects faces from all the images present in a given directory.
|
| 56 |
+
|
| 57 |
+
Arguments:
|
| 58 |
+
path {string} -- a string containing a path that points to the folder containing the images
|
| 59 |
+
|
| 60 |
+
Keyword Arguments:
|
| 61 |
+
extensions {list} -- list of string containing the extensions to be
|
| 62 |
+
consider in the following format: ``.extension_name`` (default:
|
| 63 |
+
{['.jpg', '.png']}) recursive {bool} -- option wherever to scan the
|
| 64 |
+
folder recursively (default: {False}) show_progress_bar {bool} --
|
| 65 |
+
display a progressbar (default: {True})
|
| 66 |
+
|
| 67 |
+
Example:
|
| 68 |
+
>>> directory = 'data'
|
| 69 |
+
... detected_faces = detect_from_directory(directory)
|
| 70 |
+
{A dictionary of [lists containing bounding boxes(x1, y1, x2, y2)]}
|
| 71 |
+
|
| 72 |
+
"""
|
| 73 |
+
if self.verbose:
|
| 74 |
+
logger = logging.getLogger(__name__)
|
| 75 |
+
|
| 76 |
+
if len(extensions) == 0:
|
| 77 |
+
if self.verbose:
|
| 78 |
+
logger.error("Expected at list one extension, but none was received.")
|
| 79 |
+
raise ValueError
|
| 80 |
+
|
| 81 |
+
if self.verbose:
|
| 82 |
+
logger.info("Constructing the list of images.")
|
| 83 |
+
additional_pattern = '/**/*' if recursive else '/*'
|
| 84 |
+
files = []
|
| 85 |
+
for extension in extensions:
|
| 86 |
+
files.extend(glob.glob(path + additional_pattern + extension, recursive=recursive))
|
| 87 |
+
|
| 88 |
+
if self.verbose:
|
| 89 |
+
logger.info("Finished searching for images. %s images found", len(files))
|
| 90 |
+
logger.info("Preparing to run the detection.")
|
| 91 |
+
|
| 92 |
+
predictions = {}
|
| 93 |
+
for image_path in tqdm(files, disable=not show_progress_bar):
|
| 94 |
+
if self.verbose:
|
| 95 |
+
logger.info("Running the face detector on image: %s", image_path)
|
| 96 |
+
predictions[image_path] = self.detect_from_image(image_path)
|
| 97 |
+
|
| 98 |
+
if self.verbose:
|
| 99 |
+
logger.info("The detector was successfully run on all %s images", len(files))
|
| 100 |
+
|
| 101 |
+
return predictions
|
| 102 |
+
|
| 103 |
+
@property
|
| 104 |
+
def reference_scale(self):
|
| 105 |
+
raise NotImplementedError
|
| 106 |
+
|
| 107 |
+
@property
|
| 108 |
+
def reference_x_shift(self):
|
| 109 |
+
raise NotImplementedError
|
| 110 |
+
|
| 111 |
+
@property
|
| 112 |
+
def reference_y_shift(self):
|
| 113 |
+
raise NotImplementedError
|
| 114 |
+
|
| 115 |
+
@staticmethod
|
| 116 |
+
def tensor_or_path_to_ndarray(tensor_or_path, rgb=True):
|
| 117 |
+
"""Convert path (represented as a string) or torch.tensor to a numpy.ndarray
|
| 118 |
+
|
| 119 |
+
Arguments:
|
| 120 |
+
tensor_or_path {numpy.ndarray, torch.tensor or string} -- path to the image, or the image itself
|
| 121 |
+
"""
|
| 122 |
+
if isinstance(tensor_or_path, str):
|
| 123 |
+
return cv2.imread(tensor_or_path) if not rgb else cv2.imread(tensor_or_path)[..., ::-1]
|
| 124 |
+
elif torch.is_tensor(tensor_or_path):
|
| 125 |
+
# Call cpu in case its coming from cuda
|
| 126 |
+
return tensor_or_path.cpu().numpy()[..., ::-1].copy() if not rgb else tensor_or_path.cpu().numpy()
|
| 127 |
+
elif isinstance(tensor_or_path, np.ndarray):
|
| 128 |
+
return tensor_or_path[..., ::-1].copy() if not rgb else tensor_or_path
|
| 129 |
+
else:
|
| 130 |
+
raise TypeError
|
musetalk_integration/utils/face_detection/detection/sfd/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .sfd_detector import SFDDetector as FaceDetector
|
musetalk_integration/utils/face_detection/detection/sfd/bbox.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import print_function
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
import cv2
|
| 5 |
+
import random
|
| 6 |
+
import datetime
|
| 7 |
+
import time
|
| 8 |
+
import math
|
| 9 |
+
import argparse
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
try:
|
| 14 |
+
from iou import IOU
|
| 15 |
+
except BaseException:
|
| 16 |
+
# IOU cython speedup 10x
|
| 17 |
+
def IOU(ax1, ay1, ax2, ay2, bx1, by1, bx2, by2):
|
| 18 |
+
sa = abs((ax2 - ax1) * (ay2 - ay1))
|
| 19 |
+
sb = abs((bx2 - bx1) * (by2 - by1))
|
| 20 |
+
x1, y1 = max(ax1, bx1), max(ay1, by1)
|
| 21 |
+
x2, y2 = min(ax2, bx2), min(ay2, by2)
|
| 22 |
+
w = x2 - x1
|
| 23 |
+
h = y2 - y1
|
| 24 |
+
if w < 0 or h < 0:
|
| 25 |
+
return 0.0
|
| 26 |
+
else:
|
| 27 |
+
return 1.0 * w * h / (sa + sb - w * h)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def bboxlog(x1, y1, x2, y2, axc, ayc, aww, ahh):
|
| 31 |
+
xc, yc, ww, hh = (x2 + x1) / 2, (y2 + y1) / 2, x2 - x1, y2 - y1
|
| 32 |
+
dx, dy = (xc - axc) / aww, (yc - ayc) / ahh
|
| 33 |
+
dw, dh = math.log(ww / aww), math.log(hh / ahh)
|
| 34 |
+
return dx, dy, dw, dh
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def bboxloginv(dx, dy, dw, dh, axc, ayc, aww, ahh):
|
| 38 |
+
xc, yc = dx * aww + axc, dy * ahh + ayc
|
| 39 |
+
ww, hh = math.exp(dw) * aww, math.exp(dh) * ahh
|
| 40 |
+
x1, x2, y1, y2 = xc - ww / 2, xc + ww / 2, yc - hh / 2, yc + hh / 2
|
| 41 |
+
return x1, y1, x2, y2
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def nms(dets, thresh):
|
| 45 |
+
if 0 == len(dets):
|
| 46 |
+
return []
|
| 47 |
+
x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4]
|
| 48 |
+
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
| 49 |
+
order = scores.argsort()[::-1]
|
| 50 |
+
|
| 51 |
+
keep = []
|
| 52 |
+
while order.size > 0:
|
| 53 |
+
i = order[0]
|
| 54 |
+
keep.append(i)
|
| 55 |
+
xx1, yy1 = np.maximum(x1[i], x1[order[1:]]), np.maximum(y1[i], y1[order[1:]])
|
| 56 |
+
xx2, yy2 = np.minimum(x2[i], x2[order[1:]]), np.minimum(y2[i], y2[order[1:]])
|
| 57 |
+
|
| 58 |
+
w, h = np.maximum(0.0, xx2 - xx1 + 1), np.maximum(0.0, yy2 - yy1 + 1)
|
| 59 |
+
ovr = w * h / (areas[i] + areas[order[1:]] - w * h)
|
| 60 |
+
|
| 61 |
+
inds = np.where(ovr <= thresh)[0]
|
| 62 |
+
order = order[inds + 1]
|
| 63 |
+
|
| 64 |
+
return keep
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def encode(matched, priors, variances):
|
| 68 |
+
"""Encode the variances from the priorbox layers into the ground truth boxes
|
| 69 |
+
we have matched (based on jaccard overlap) with the prior boxes.
|
| 70 |
+
Args:
|
| 71 |
+
matched: (tensor) Coords of ground truth for each prior in point-form
|
| 72 |
+
Shape: [num_priors, 4].
|
| 73 |
+
priors: (tensor) Prior boxes in center-offset form
|
| 74 |
+
Shape: [num_priors,4].
|
| 75 |
+
variances: (list[float]) Variances of priorboxes
|
| 76 |
+
Return:
|
| 77 |
+
encoded boxes (tensor), Shape: [num_priors, 4]
|
| 78 |
+
"""
|
| 79 |
+
|
| 80 |
+
# dist b/t match center and prior's center
|
| 81 |
+
g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]
|
| 82 |
+
# encode variance
|
| 83 |
+
g_cxcy /= (variances[0] * priors[:, 2:])
|
| 84 |
+
# match wh / prior wh
|
| 85 |
+
g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
|
| 86 |
+
g_wh = torch.log(g_wh) / variances[1]
|
| 87 |
+
# return target for smooth_l1_loss
|
| 88 |
+
return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def decode(loc, priors, variances):
|
| 92 |
+
"""Decode locations from predictions using priors to undo
|
| 93 |
+
the encoding we did for offset regression at train time.
|
| 94 |
+
Args:
|
| 95 |
+
loc (tensor): location predictions for loc layers,
|
| 96 |
+
Shape: [num_priors,4]
|
| 97 |
+
priors (tensor): Prior boxes in center-offset form.
|
| 98 |
+
Shape: [num_priors,4].
|
| 99 |
+
variances: (list[float]) Variances of priorboxes
|
| 100 |
+
Return:
|
| 101 |
+
decoded bounding box predictions
|
| 102 |
+
"""
|
| 103 |
+
|
| 104 |
+
boxes = torch.cat((
|
| 105 |
+
priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
|
| 106 |
+
priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
|
| 107 |
+
boxes[:, :2] -= boxes[:, 2:] / 2
|
| 108 |
+
boxes[:, 2:] += boxes[:, :2]
|
| 109 |
+
return boxes
|
| 110 |
+
|
| 111 |
+
def batch_decode(loc, priors, variances):
|
| 112 |
+
"""Decode locations from predictions using priors to undo
|
| 113 |
+
the encoding we did for offset regression at train time.
|
| 114 |
+
Args:
|
| 115 |
+
loc (tensor): location predictions for loc layers,
|
| 116 |
+
Shape: [num_priors,4]
|
| 117 |
+
priors (tensor): Prior boxes in center-offset form.
|
| 118 |
+
Shape: [num_priors,4].
|
| 119 |
+
variances: (list[float]) Variances of priorboxes
|
| 120 |
+
Return:
|
| 121 |
+
decoded bounding box predictions
|
| 122 |
+
"""
|
| 123 |
+
|
| 124 |
+
boxes = torch.cat((
|
| 125 |
+
priors[:, :, :2] + loc[:, :, :2] * variances[0] * priors[:, :, 2:],
|
| 126 |
+
priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * variances[1])), 2)
|
| 127 |
+
boxes[:, :, :2] -= boxes[:, :, 2:] / 2
|
| 128 |
+
boxes[:, :, 2:] += boxes[:, :, :2]
|
| 129 |
+
return boxes
|
musetalk_integration/utils/face_detection/detection/sfd/detect.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
import cv2
|
| 7 |
+
import random
|
| 8 |
+
import datetime
|
| 9 |
+
import math
|
| 10 |
+
import argparse
|
| 11 |
+
import numpy as np
|
| 12 |
+
|
| 13 |
+
import scipy.io as sio
|
| 14 |
+
import zipfile
|
| 15 |
+
from .net_s3fd import s3fd
|
| 16 |
+
from .bbox import *
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def detect(net, img, device):
|
| 20 |
+
img = img - np.array([104, 117, 123])
|
| 21 |
+
img = img.transpose(2, 0, 1)
|
| 22 |
+
img = img.reshape((1,) + img.shape)
|
| 23 |
+
|
| 24 |
+
if 'cuda' in device:
|
| 25 |
+
torch.backends.cudnn.benchmark = True
|
| 26 |
+
|
| 27 |
+
img = torch.from_numpy(img).float().to(device)
|
| 28 |
+
BB, CC, HH, WW = img.size()
|
| 29 |
+
with torch.no_grad():
|
| 30 |
+
olist = net(img)
|
| 31 |
+
|
| 32 |
+
bboxlist = []
|
| 33 |
+
for i in range(len(olist) // 2):
|
| 34 |
+
olist[i * 2] = F.softmax(olist[i * 2], dim=1)
|
| 35 |
+
olist = [oelem.data.cpu() for oelem in olist]
|
| 36 |
+
for i in range(len(olist) // 2):
|
| 37 |
+
ocls, oreg = olist[i * 2], olist[i * 2 + 1]
|
| 38 |
+
FB, FC, FH, FW = ocls.size() # feature map size
|
| 39 |
+
stride = 2**(i + 2) # 4,8,16,32,64,128
|
| 40 |
+
anchor = stride * 4
|
| 41 |
+
poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
|
| 42 |
+
for Iindex, hindex, windex in poss:
|
| 43 |
+
axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
|
| 44 |
+
score = ocls[0, 1, hindex, windex]
|
| 45 |
+
loc = oreg[0, :, hindex, windex].contiguous().view(1, 4)
|
| 46 |
+
priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]])
|
| 47 |
+
variances = [0.1, 0.2]
|
| 48 |
+
box = decode(loc, priors, variances)
|
| 49 |
+
x1, y1, x2, y2 = box[0] * 1.0
|
| 50 |
+
# cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
|
| 51 |
+
bboxlist.append([x1, y1, x2, y2, score])
|
| 52 |
+
bboxlist = np.array(bboxlist)
|
| 53 |
+
if 0 == len(bboxlist):
|
| 54 |
+
bboxlist = np.zeros((1, 5))
|
| 55 |
+
|
| 56 |
+
return bboxlist
|
| 57 |
+
|
| 58 |
+
def batch_detect(net, imgs, device):
|
| 59 |
+
imgs = imgs - np.array([104, 117, 123])
|
| 60 |
+
imgs = imgs.transpose(0, 3, 1, 2)
|
| 61 |
+
|
| 62 |
+
if 'cuda' in device:
|
| 63 |
+
torch.backends.cudnn.benchmark = True
|
| 64 |
+
|
| 65 |
+
imgs = torch.from_numpy(imgs).float().to(device)
|
| 66 |
+
BB, CC, HH, WW = imgs.size()
|
| 67 |
+
with torch.no_grad():
|
| 68 |
+
olist = net(imgs)
|
| 69 |
+
# print(olist)
|
| 70 |
+
|
| 71 |
+
bboxlist = []
|
| 72 |
+
for i in range(len(olist) // 2):
|
| 73 |
+
olist[i * 2] = F.softmax(olist[i * 2], dim=1)
|
| 74 |
+
|
| 75 |
+
olist = [oelem.cpu() for oelem in olist]
|
| 76 |
+
for i in range(len(olist) // 2):
|
| 77 |
+
ocls, oreg = olist[i * 2], olist[i * 2 + 1]
|
| 78 |
+
FB, FC, FH, FW = ocls.size() # feature map size
|
| 79 |
+
stride = 2**(i + 2) # 4,8,16,32,64,128
|
| 80 |
+
anchor = stride * 4
|
| 81 |
+
poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
|
| 82 |
+
for Iindex, hindex, windex in poss:
|
| 83 |
+
axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
|
| 84 |
+
score = ocls[:, 1, hindex, windex]
|
| 85 |
+
loc = oreg[:, :, hindex, windex].contiguous().view(BB, 1, 4)
|
| 86 |
+
priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]).view(1, 1, 4)
|
| 87 |
+
variances = [0.1, 0.2]
|
| 88 |
+
box = batch_decode(loc, priors, variances)
|
| 89 |
+
box = box[:, 0] * 1.0
|
| 90 |
+
# cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
|
| 91 |
+
bboxlist.append(torch.cat([box, score.unsqueeze(1)], 1).cpu().numpy())
|
| 92 |
+
bboxlist = np.array(bboxlist)
|
| 93 |
+
if 0 == len(bboxlist):
|
| 94 |
+
bboxlist = np.zeros((1, BB, 5))
|
| 95 |
+
|
| 96 |
+
return bboxlist
|
| 97 |
+
|
| 98 |
+
def flip_detect(net, img, device):
|
| 99 |
+
img = cv2.flip(img, 1)
|
| 100 |
+
b = detect(net, img, device)
|
| 101 |
+
|
| 102 |
+
bboxlist = np.zeros(b.shape)
|
| 103 |
+
bboxlist[:, 0] = img.shape[1] - b[:, 2]
|
| 104 |
+
bboxlist[:, 1] = b[:, 1]
|
| 105 |
+
bboxlist[:, 2] = img.shape[1] - b[:, 0]
|
| 106 |
+
bboxlist[:, 3] = b[:, 3]
|
| 107 |
+
bboxlist[:, 4] = b[:, 4]
|
| 108 |
+
return bboxlist
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def pts_to_bb(pts):
|
| 112 |
+
min_x, min_y = np.min(pts, axis=0)
|
| 113 |
+
max_x, max_y = np.max(pts, axis=0)
|
| 114 |
+
return np.array([min_x, min_y, max_x, max_y])
|
musetalk_integration/utils/face_detection/detection/sfd/net_s3fd.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class L2Norm(nn.Module):
|
| 7 |
+
def __init__(self, n_channels, scale=1.0):
|
| 8 |
+
super(L2Norm, self).__init__()
|
| 9 |
+
self.n_channels = n_channels
|
| 10 |
+
self.scale = scale
|
| 11 |
+
self.eps = 1e-10
|
| 12 |
+
self.weight = nn.Parameter(torch.Tensor(self.n_channels))
|
| 13 |
+
self.weight.data *= 0.0
|
| 14 |
+
self.weight.data += self.scale
|
| 15 |
+
|
| 16 |
+
def forward(self, x):
|
| 17 |
+
norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps
|
| 18 |
+
x = x / norm * self.weight.view(1, -1, 1, 1)
|
| 19 |
+
return x
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class s3fd(nn.Module):
|
| 23 |
+
def __init__(self):
|
| 24 |
+
super(s3fd, self).__init__()
|
| 25 |
+
self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
|
| 26 |
+
self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
|
| 27 |
+
|
| 28 |
+
self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
|
| 29 |
+
self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
|
| 30 |
+
|
| 31 |
+
self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
|
| 32 |
+
self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
|
| 33 |
+
self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
|
| 34 |
+
|
| 35 |
+
self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
|
| 36 |
+
self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
| 37 |
+
self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
| 38 |
+
|
| 39 |
+
self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
| 40 |
+
self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
| 41 |
+
self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
| 42 |
+
|
| 43 |
+
self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3)
|
| 44 |
+
self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0)
|
| 45 |
+
|
| 46 |
+
self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)
|
| 47 |
+
self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
|
| 48 |
+
|
| 49 |
+
self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0)
|
| 50 |
+
self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
|
| 51 |
+
|
| 52 |
+
self.conv3_3_norm = L2Norm(256, scale=10)
|
| 53 |
+
self.conv4_3_norm = L2Norm(512, scale=8)
|
| 54 |
+
self.conv5_3_norm = L2Norm(512, scale=5)
|
| 55 |
+
|
| 56 |
+
self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
|
| 57 |
+
self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
|
| 58 |
+
self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
|
| 59 |
+
self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
|
| 60 |
+
self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
|
| 61 |
+
self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
|
| 62 |
+
|
| 63 |
+
self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1)
|
| 64 |
+
self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1)
|
| 65 |
+
self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
|
| 66 |
+
self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
|
| 67 |
+
self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1)
|
| 68 |
+
self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
|
| 69 |
+
|
| 70 |
+
def forward(self, x):
|
| 71 |
+
h = F.relu(self.conv1_1(x))
|
| 72 |
+
h = F.relu(self.conv1_2(h))
|
| 73 |
+
h = F.max_pool2d(h, 2, 2)
|
| 74 |
+
|
| 75 |
+
h = F.relu(self.conv2_1(h))
|
| 76 |
+
h = F.relu(self.conv2_2(h))
|
| 77 |
+
h = F.max_pool2d(h, 2, 2)
|
| 78 |
+
|
| 79 |
+
h = F.relu(self.conv3_1(h))
|
| 80 |
+
h = F.relu(self.conv3_2(h))
|
| 81 |
+
h = F.relu(self.conv3_3(h))
|
| 82 |
+
f3_3 = h
|
| 83 |
+
h = F.max_pool2d(h, 2, 2)
|
| 84 |
+
|
| 85 |
+
h = F.relu(self.conv4_1(h))
|
| 86 |
+
h = F.relu(self.conv4_2(h))
|
| 87 |
+
h = F.relu(self.conv4_3(h))
|
| 88 |
+
f4_3 = h
|
| 89 |
+
h = F.max_pool2d(h, 2, 2)
|
| 90 |
+
|
| 91 |
+
h = F.relu(self.conv5_1(h))
|
| 92 |
+
h = F.relu(self.conv5_2(h))
|
| 93 |
+
h = F.relu(self.conv5_3(h))
|
| 94 |
+
f5_3 = h
|
| 95 |
+
h = F.max_pool2d(h, 2, 2)
|
| 96 |
+
|
| 97 |
+
h = F.relu(self.fc6(h))
|
| 98 |
+
h = F.relu(self.fc7(h))
|
| 99 |
+
ffc7 = h
|
| 100 |
+
h = F.relu(self.conv6_1(h))
|
| 101 |
+
h = F.relu(self.conv6_2(h))
|
| 102 |
+
f6_2 = h
|
| 103 |
+
h = F.relu(self.conv7_1(h))
|
| 104 |
+
h = F.relu(self.conv7_2(h))
|
| 105 |
+
f7_2 = h
|
| 106 |
+
|
| 107 |
+
f3_3 = self.conv3_3_norm(f3_3)
|
| 108 |
+
f4_3 = self.conv4_3_norm(f4_3)
|
| 109 |
+
f5_3 = self.conv5_3_norm(f5_3)
|
| 110 |
+
|
| 111 |
+
cls1 = self.conv3_3_norm_mbox_conf(f3_3)
|
| 112 |
+
reg1 = self.conv3_3_norm_mbox_loc(f3_3)
|
| 113 |
+
cls2 = self.conv4_3_norm_mbox_conf(f4_3)
|
| 114 |
+
reg2 = self.conv4_3_norm_mbox_loc(f4_3)
|
| 115 |
+
cls3 = self.conv5_3_norm_mbox_conf(f5_3)
|
| 116 |
+
reg3 = self.conv5_3_norm_mbox_loc(f5_3)
|
| 117 |
+
cls4 = self.fc7_mbox_conf(ffc7)
|
| 118 |
+
reg4 = self.fc7_mbox_loc(ffc7)
|
| 119 |
+
cls5 = self.conv6_2_mbox_conf(f6_2)
|
| 120 |
+
reg5 = self.conv6_2_mbox_loc(f6_2)
|
| 121 |
+
cls6 = self.conv7_2_mbox_conf(f7_2)
|
| 122 |
+
reg6 = self.conv7_2_mbox_loc(f7_2)
|
| 123 |
+
|
| 124 |
+
# max-out background label
|
| 125 |
+
chunk = torch.chunk(cls1, 4, 1)
|
| 126 |
+
bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2])
|
| 127 |
+
cls1 = torch.cat([bmax, chunk[3]], dim=1)
|
| 128 |
+
|
| 129 |
+
return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6]
|
musetalk_integration/utils/face_detection/detection/sfd/sfd_detector.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
from torch.utils.model_zoo import load_url
|
| 4 |
+
|
| 5 |
+
from ..core import FaceDetector
|
| 6 |
+
|
| 7 |
+
from .net_s3fd import s3fd
|
| 8 |
+
from .bbox import *
|
| 9 |
+
from .detect import *
|
| 10 |
+
|
| 11 |
+
models_urls = {
|
| 12 |
+
's3fd': 'https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth',
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class SFDDetector(FaceDetector):
|
| 17 |
+
def __init__(self, device, path_to_detector=os.path.join(os.path.dirname(os.path.abspath(__file__)), 's3fd.pth'), verbose=False):
|
| 18 |
+
super(SFDDetector, self).__init__(device, verbose)
|
| 19 |
+
|
| 20 |
+
# Initialise the face detector
|
| 21 |
+
if not os.path.isfile(path_to_detector):
|
| 22 |
+
model_weights = load_url(models_urls['s3fd'])
|
| 23 |
+
else:
|
| 24 |
+
model_weights = torch.load(path_to_detector)
|
| 25 |
+
|
| 26 |
+
self.face_detector = s3fd()
|
| 27 |
+
self.face_detector.load_state_dict(model_weights)
|
| 28 |
+
self.face_detector.to(device)
|
| 29 |
+
self.face_detector.eval()
|
| 30 |
+
|
| 31 |
+
def detect_from_image(self, tensor_or_path):
|
| 32 |
+
image = self.tensor_or_path_to_ndarray(tensor_or_path)
|
| 33 |
+
|
| 34 |
+
bboxlist = detect(self.face_detector, image, device=self.device)
|
| 35 |
+
keep = nms(bboxlist, 0.3)
|
| 36 |
+
bboxlist = bboxlist[keep, :]
|
| 37 |
+
bboxlist = [x for x in bboxlist if x[-1] > 0.5]
|
| 38 |
+
|
| 39 |
+
return bboxlist
|
| 40 |
+
|
| 41 |
+
def detect_from_batch(self, images):
|
| 42 |
+
bboxlists = batch_detect(self.face_detector, images, device=self.device)
|
| 43 |
+
keeps = [nms(bboxlists[:, i, :], 0.3) for i in range(bboxlists.shape[1])]
|
| 44 |
+
bboxlists = [bboxlists[keep, i, :] for i, keep in enumerate(keeps)]
|
| 45 |
+
bboxlists = [[x for x in bboxlist if x[-1] > 0.5] for bboxlist in bboxlists]
|
| 46 |
+
|
| 47 |
+
return bboxlists
|
| 48 |
+
|
| 49 |
+
@property
|
| 50 |
+
def reference_scale(self):
|
| 51 |
+
return 195
|
| 52 |
+
|
| 53 |
+
@property
|
| 54 |
+
def reference_x_shift(self):
|
| 55 |
+
return 0
|
| 56 |
+
|
| 57 |
+
@property
|
| 58 |
+
def reference_y_shift(self):
|
| 59 |
+
return 0
|
musetalk_integration/utils/face_detection/models.py
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
|
| 8 |
+
"3x3 convolution with padding"
|
| 9 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3,
|
| 10 |
+
stride=strd, padding=padding, bias=bias)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class ConvBlock(nn.Module):
|
| 14 |
+
def __init__(self, in_planes, out_planes):
|
| 15 |
+
super(ConvBlock, self).__init__()
|
| 16 |
+
self.bn1 = nn.BatchNorm2d(in_planes)
|
| 17 |
+
self.conv1 = conv3x3(in_planes, int(out_planes / 2))
|
| 18 |
+
self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
|
| 19 |
+
self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4))
|
| 20 |
+
self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
|
| 21 |
+
self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4))
|
| 22 |
+
|
| 23 |
+
if in_planes != out_planes:
|
| 24 |
+
self.downsample = nn.Sequential(
|
| 25 |
+
nn.BatchNorm2d(in_planes),
|
| 26 |
+
nn.ReLU(True),
|
| 27 |
+
nn.Conv2d(in_planes, out_planes,
|
| 28 |
+
kernel_size=1, stride=1, bias=False),
|
| 29 |
+
)
|
| 30 |
+
else:
|
| 31 |
+
self.downsample = None
|
| 32 |
+
|
| 33 |
+
def forward(self, x):
|
| 34 |
+
residual = x
|
| 35 |
+
|
| 36 |
+
out1 = self.bn1(x)
|
| 37 |
+
out1 = F.relu(out1, True)
|
| 38 |
+
out1 = self.conv1(out1)
|
| 39 |
+
|
| 40 |
+
out2 = self.bn2(out1)
|
| 41 |
+
out2 = F.relu(out2, True)
|
| 42 |
+
out2 = self.conv2(out2)
|
| 43 |
+
|
| 44 |
+
out3 = self.bn3(out2)
|
| 45 |
+
out3 = F.relu(out3, True)
|
| 46 |
+
out3 = self.conv3(out3)
|
| 47 |
+
|
| 48 |
+
out3 = torch.cat((out1, out2, out3), 1)
|
| 49 |
+
|
| 50 |
+
if self.downsample is not None:
|
| 51 |
+
residual = self.downsample(residual)
|
| 52 |
+
|
| 53 |
+
out3 += residual
|
| 54 |
+
|
| 55 |
+
return out3
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class Bottleneck(nn.Module):
|
| 59 |
+
|
| 60 |
+
expansion = 4
|
| 61 |
+
|
| 62 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
| 63 |
+
super(Bottleneck, self).__init__()
|
| 64 |
+
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
| 65 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
| 66 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
|
| 67 |
+
padding=1, bias=False)
|
| 68 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
| 69 |
+
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
|
| 70 |
+
self.bn3 = nn.BatchNorm2d(planes * 4)
|
| 71 |
+
self.relu = nn.ReLU(inplace=True)
|
| 72 |
+
self.downsample = downsample
|
| 73 |
+
self.stride = stride
|
| 74 |
+
|
| 75 |
+
def forward(self, x):
|
| 76 |
+
residual = x
|
| 77 |
+
|
| 78 |
+
out = self.conv1(x)
|
| 79 |
+
out = self.bn1(out)
|
| 80 |
+
out = self.relu(out)
|
| 81 |
+
|
| 82 |
+
out = self.conv2(out)
|
| 83 |
+
out = self.bn2(out)
|
| 84 |
+
out = self.relu(out)
|
| 85 |
+
|
| 86 |
+
out = self.conv3(out)
|
| 87 |
+
out = self.bn3(out)
|
| 88 |
+
|
| 89 |
+
if self.downsample is not None:
|
| 90 |
+
residual = self.downsample(x)
|
| 91 |
+
|
| 92 |
+
out += residual
|
| 93 |
+
out = self.relu(out)
|
| 94 |
+
|
| 95 |
+
return out
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class HourGlass(nn.Module):
|
| 99 |
+
def __init__(self, num_modules, depth, num_features):
|
| 100 |
+
super(HourGlass, self).__init__()
|
| 101 |
+
self.num_modules = num_modules
|
| 102 |
+
self.depth = depth
|
| 103 |
+
self.features = num_features
|
| 104 |
+
|
| 105 |
+
self._generate_network(self.depth)
|
| 106 |
+
|
| 107 |
+
def _generate_network(self, level):
|
| 108 |
+
self.add_module('b1_' + str(level), ConvBlock(self.features, self.features))
|
| 109 |
+
|
| 110 |
+
self.add_module('b2_' + str(level), ConvBlock(self.features, self.features))
|
| 111 |
+
|
| 112 |
+
if level > 1:
|
| 113 |
+
self._generate_network(level - 1)
|
| 114 |
+
else:
|
| 115 |
+
self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features))
|
| 116 |
+
|
| 117 |
+
self.add_module('b3_' + str(level), ConvBlock(self.features, self.features))
|
| 118 |
+
|
| 119 |
+
def _forward(self, level, inp):
|
| 120 |
+
# Upper branch
|
| 121 |
+
up1 = inp
|
| 122 |
+
up1 = self._modules['b1_' + str(level)](up1)
|
| 123 |
+
|
| 124 |
+
# Lower branch
|
| 125 |
+
low1 = F.avg_pool2d(inp, 2, stride=2)
|
| 126 |
+
low1 = self._modules['b2_' + str(level)](low1)
|
| 127 |
+
|
| 128 |
+
if level > 1:
|
| 129 |
+
low2 = self._forward(level - 1, low1)
|
| 130 |
+
else:
|
| 131 |
+
low2 = low1
|
| 132 |
+
low2 = self._modules['b2_plus_' + str(level)](low2)
|
| 133 |
+
|
| 134 |
+
low3 = low2
|
| 135 |
+
low3 = self._modules['b3_' + str(level)](low3)
|
| 136 |
+
|
| 137 |
+
up2 = F.interpolate(low3, scale_factor=2, mode='nearest')
|
| 138 |
+
|
| 139 |
+
return up1 + up2
|
| 140 |
+
|
| 141 |
+
def forward(self, x):
|
| 142 |
+
return self._forward(self.depth, x)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class FAN(nn.Module):
|
| 146 |
+
|
| 147 |
+
def __init__(self, num_modules=1):
|
| 148 |
+
super(FAN, self).__init__()
|
| 149 |
+
self.num_modules = num_modules
|
| 150 |
+
|
| 151 |
+
# Base part
|
| 152 |
+
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
|
| 153 |
+
self.bn1 = nn.BatchNorm2d(64)
|
| 154 |
+
self.conv2 = ConvBlock(64, 128)
|
| 155 |
+
self.conv3 = ConvBlock(128, 128)
|
| 156 |
+
self.conv4 = ConvBlock(128, 256)
|
| 157 |
+
|
| 158 |
+
# Stacking part
|
| 159 |
+
for hg_module in range(self.num_modules):
|
| 160 |
+
self.add_module('m' + str(hg_module), HourGlass(1, 4, 256))
|
| 161 |
+
self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256))
|
| 162 |
+
self.add_module('conv_last' + str(hg_module),
|
| 163 |
+
nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
|
| 164 |
+
self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
|
| 165 |
+
self.add_module('l' + str(hg_module), nn.Conv2d(256,
|
| 166 |
+
68, kernel_size=1, stride=1, padding=0))
|
| 167 |
+
|
| 168 |
+
if hg_module < self.num_modules - 1:
|
| 169 |
+
self.add_module(
|
| 170 |
+
'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
|
| 171 |
+
self.add_module('al' + str(hg_module), nn.Conv2d(68,
|
| 172 |
+
256, kernel_size=1, stride=1, padding=0))
|
| 173 |
+
|
| 174 |
+
def forward(self, x):
|
| 175 |
+
x = F.relu(self.bn1(self.conv1(x)), True)
|
| 176 |
+
x = F.avg_pool2d(self.conv2(x), 2, stride=2)
|
| 177 |
+
x = self.conv3(x)
|
| 178 |
+
x = self.conv4(x)
|
| 179 |
+
|
| 180 |
+
previous = x
|
| 181 |
+
|
| 182 |
+
outputs = []
|
| 183 |
+
for i in range(self.num_modules):
|
| 184 |
+
hg = self._modules['m' + str(i)](previous)
|
| 185 |
+
|
| 186 |
+
ll = hg
|
| 187 |
+
ll = self._modules['top_m_' + str(i)](ll)
|
| 188 |
+
|
| 189 |
+
ll = F.relu(self._modules['bn_end' + str(i)]
|
| 190 |
+
(self._modules['conv_last' + str(i)](ll)), True)
|
| 191 |
+
|
| 192 |
+
# Predict heatmaps
|
| 193 |
+
tmp_out = self._modules['l' + str(i)](ll)
|
| 194 |
+
outputs.append(tmp_out)
|
| 195 |
+
|
| 196 |
+
if i < self.num_modules - 1:
|
| 197 |
+
ll = self._modules['bl' + str(i)](ll)
|
| 198 |
+
tmp_out_ = self._modules['al' + str(i)](tmp_out)
|
| 199 |
+
previous = previous + ll + tmp_out_
|
| 200 |
+
|
| 201 |
+
return outputs
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
class ResNetDepth(nn.Module):
|
| 205 |
+
|
| 206 |
+
def __init__(self, block=Bottleneck, layers=[3, 8, 36, 3], num_classes=68):
|
| 207 |
+
self.inplanes = 64
|
| 208 |
+
super(ResNetDepth, self).__init__()
|
| 209 |
+
self.conv1 = nn.Conv2d(3 + 68, 64, kernel_size=7, stride=2, padding=3,
|
| 210 |
+
bias=False)
|
| 211 |
+
self.bn1 = nn.BatchNorm2d(64)
|
| 212 |
+
self.relu = nn.ReLU(inplace=True)
|
| 213 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
| 214 |
+
self.layer1 = self._make_layer(block, 64, layers[0])
|
| 215 |
+
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
| 216 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
| 217 |
+
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
|
| 218 |
+
self.avgpool = nn.AvgPool2d(7)
|
| 219 |
+
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
| 220 |
+
|
| 221 |
+
for m in self.modules():
|
| 222 |
+
if isinstance(m, nn.Conv2d):
|
| 223 |
+
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
| 224 |
+
m.weight.data.normal_(0, math.sqrt(2. / n))
|
| 225 |
+
elif isinstance(m, nn.BatchNorm2d):
|
| 226 |
+
m.weight.data.fill_(1)
|
| 227 |
+
m.bias.data.zero_()
|
| 228 |
+
|
| 229 |
+
def _make_layer(self, block, planes, blocks, stride=1):
|
| 230 |
+
downsample = None
|
| 231 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
| 232 |
+
downsample = nn.Sequential(
|
| 233 |
+
nn.Conv2d(self.inplanes, planes * block.expansion,
|
| 234 |
+
kernel_size=1, stride=stride, bias=False),
|
| 235 |
+
nn.BatchNorm2d(planes * block.expansion),
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
layers = []
|
| 239 |
+
layers.append(block(self.inplanes, planes, stride, downsample))
|
| 240 |
+
self.inplanes = planes * block.expansion
|
| 241 |
+
for i in range(1, blocks):
|
| 242 |
+
layers.append(block(self.inplanes, planes))
|
| 243 |
+
|
| 244 |
+
return nn.Sequential(*layers)
|
| 245 |
+
|
| 246 |
+
def forward(self, x):
|
| 247 |
+
x = self.conv1(x)
|
| 248 |
+
x = self.bn1(x)
|
| 249 |
+
x = self.relu(x)
|
| 250 |
+
x = self.maxpool(x)
|
| 251 |
+
|
| 252 |
+
x = self.layer1(x)
|
| 253 |
+
x = self.layer2(x)
|
| 254 |
+
x = self.layer3(x)
|
| 255 |
+
x = self.layer4(x)
|
| 256 |
+
|
| 257 |
+
x = self.avgpool(x)
|
| 258 |
+
x = x.view(x.size(0), -1)
|
| 259 |
+
x = self.fc(x)
|
| 260 |
+
|
| 261 |
+
return x
|
musetalk_integration/utils/face_detection/utils.py
ADDED
|
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import print_function
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
import time
|
| 5 |
+
import torch
|
| 6 |
+
import math
|
| 7 |
+
import numpy as np
|
| 8 |
+
import cv2
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def _gaussian(
|
| 12 |
+
size=3, sigma=0.25, amplitude=1, normalize=False, width=None,
|
| 13 |
+
height=None, sigma_horz=None, sigma_vert=None, mean_horz=0.5,
|
| 14 |
+
mean_vert=0.5):
|
| 15 |
+
# handle some defaults
|
| 16 |
+
if width is None:
|
| 17 |
+
width = size
|
| 18 |
+
if height is None:
|
| 19 |
+
height = size
|
| 20 |
+
if sigma_horz is None:
|
| 21 |
+
sigma_horz = sigma
|
| 22 |
+
if sigma_vert is None:
|
| 23 |
+
sigma_vert = sigma
|
| 24 |
+
center_x = mean_horz * width + 0.5
|
| 25 |
+
center_y = mean_vert * height + 0.5
|
| 26 |
+
gauss = np.empty((height, width), dtype=np.float32)
|
| 27 |
+
# generate kernel
|
| 28 |
+
for i in range(height):
|
| 29 |
+
for j in range(width):
|
| 30 |
+
gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / (
|
| 31 |
+
sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0))
|
| 32 |
+
if normalize:
|
| 33 |
+
gauss = gauss / np.sum(gauss)
|
| 34 |
+
return gauss
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def draw_gaussian(image, point, sigma):
|
| 38 |
+
# Check if the gaussian is inside
|
| 39 |
+
ul = [math.floor(point[0] - 3 * sigma), math.floor(point[1] - 3 * sigma)]
|
| 40 |
+
br = [math.floor(point[0] + 3 * sigma), math.floor(point[1] + 3 * sigma)]
|
| 41 |
+
if (ul[0] > image.shape[1] or ul[1] > image.shape[0] or br[0] < 1 or br[1] < 1):
|
| 42 |
+
return image
|
| 43 |
+
size = 6 * sigma + 1
|
| 44 |
+
g = _gaussian(size)
|
| 45 |
+
g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - int(max(1, ul[0])) + int(max(1, -ul[0]))]
|
| 46 |
+
g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - int(max(1, ul[1])) + int(max(1, -ul[1]))]
|
| 47 |
+
img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))]
|
| 48 |
+
img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))]
|
| 49 |
+
assert (g_x[0] > 0 and g_y[1] > 0)
|
| 50 |
+
image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]
|
| 51 |
+
] = image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]]
|
| 52 |
+
image[image > 1] = 1
|
| 53 |
+
return image
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def transform(point, center, scale, resolution, invert=False):
|
| 57 |
+
"""Generate and affine transformation matrix.
|
| 58 |
+
|
| 59 |
+
Given a set of points, a center, a scale and a targer resolution, the
|
| 60 |
+
function generates and affine transformation matrix. If invert is ``True``
|
| 61 |
+
it will produce the inverse transformation.
|
| 62 |
+
|
| 63 |
+
Arguments:
|
| 64 |
+
point {torch.tensor} -- the input 2D point
|
| 65 |
+
center {torch.tensor or numpy.array} -- the center around which to perform the transformations
|
| 66 |
+
scale {float} -- the scale of the face/object
|
| 67 |
+
resolution {float} -- the output resolution
|
| 68 |
+
|
| 69 |
+
Keyword Arguments:
|
| 70 |
+
invert {bool} -- define wherever the function should produce the direct or the
|
| 71 |
+
inverse transformation matrix (default: {False})
|
| 72 |
+
"""
|
| 73 |
+
_pt = torch.ones(3)
|
| 74 |
+
_pt[0] = point[0]
|
| 75 |
+
_pt[1] = point[1]
|
| 76 |
+
|
| 77 |
+
h = 200.0 * scale
|
| 78 |
+
t = torch.eye(3)
|
| 79 |
+
t[0, 0] = resolution / h
|
| 80 |
+
t[1, 1] = resolution / h
|
| 81 |
+
t[0, 2] = resolution * (-center[0] / h + 0.5)
|
| 82 |
+
t[1, 2] = resolution * (-center[1] / h + 0.5)
|
| 83 |
+
|
| 84 |
+
if invert:
|
| 85 |
+
t = torch.inverse(t)
|
| 86 |
+
|
| 87 |
+
new_point = (torch.matmul(t, _pt))[0:2]
|
| 88 |
+
|
| 89 |
+
return new_point.int()
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def crop(image, center, scale, resolution=256.0):
|
| 93 |
+
"""Center crops an image or set of heatmaps
|
| 94 |
+
|
| 95 |
+
Arguments:
|
| 96 |
+
image {numpy.array} -- an rgb image
|
| 97 |
+
center {numpy.array} -- the center of the object, usually the same as of the bounding box
|
| 98 |
+
scale {float} -- scale of the face
|
| 99 |
+
|
| 100 |
+
Keyword Arguments:
|
| 101 |
+
resolution {float} -- the size of the output cropped image (default: {256.0})
|
| 102 |
+
|
| 103 |
+
Returns:
|
| 104 |
+
[type] -- [description]
|
| 105 |
+
""" # Crop around the center point
|
| 106 |
+
""" Crops the image around the center. Input is expected to be an np.ndarray """
|
| 107 |
+
ul = transform([1, 1], center, scale, resolution, True)
|
| 108 |
+
br = transform([resolution, resolution], center, scale, resolution, True)
|
| 109 |
+
# pad = math.ceil(torch.norm((ul - br).float()) / 2.0 - (br[0] - ul[0]) / 2.0)
|
| 110 |
+
if image.ndim > 2:
|
| 111 |
+
newDim = np.array([br[1] - ul[1], br[0] - ul[0],
|
| 112 |
+
image.shape[2]], dtype=np.int32)
|
| 113 |
+
newImg = np.zeros(newDim, dtype=np.uint8)
|
| 114 |
+
else:
|
| 115 |
+
newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int)
|
| 116 |
+
newImg = np.zeros(newDim, dtype=np.uint8)
|
| 117 |
+
ht = image.shape[0]
|
| 118 |
+
wd = image.shape[1]
|
| 119 |
+
newX = np.array(
|
| 120 |
+
[max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32)
|
| 121 |
+
newY = np.array(
|
| 122 |
+
[max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32)
|
| 123 |
+
oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32)
|
| 124 |
+
oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32)
|
| 125 |
+
newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1]
|
| 126 |
+
] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1], :]
|
| 127 |
+
newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)),
|
| 128 |
+
interpolation=cv2.INTER_LINEAR)
|
| 129 |
+
return newImg
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def get_preds_fromhm(hm, center=None, scale=None):
|
| 133 |
+
"""Obtain (x,y) coordinates given a set of N heatmaps. If the center
|
| 134 |
+
and the scale is provided the function will return the points also in
|
| 135 |
+
the original coordinate frame.
|
| 136 |
+
|
| 137 |
+
Arguments:
|
| 138 |
+
hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
|
| 139 |
+
|
| 140 |
+
Keyword Arguments:
|
| 141 |
+
center {torch.tensor} -- the center of the bounding box (default: {None})
|
| 142 |
+
scale {float} -- face scale (default: {None})
|
| 143 |
+
"""
|
| 144 |
+
max, idx = torch.max(
|
| 145 |
+
hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
|
| 146 |
+
idx += 1
|
| 147 |
+
preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
|
| 148 |
+
preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
|
| 149 |
+
preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
|
| 150 |
+
|
| 151 |
+
for i in range(preds.size(0)):
|
| 152 |
+
for j in range(preds.size(1)):
|
| 153 |
+
hm_ = hm[i, j, :]
|
| 154 |
+
pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
|
| 155 |
+
if pX > 0 and pX < 63 and pY > 0 and pY < 63:
|
| 156 |
+
diff = torch.FloatTensor(
|
| 157 |
+
[hm_[pY, pX + 1] - hm_[pY, pX - 1],
|
| 158 |
+
hm_[pY + 1, pX] - hm_[pY - 1, pX]])
|
| 159 |
+
preds[i, j].add_(diff.sign_().mul_(.25))
|
| 160 |
+
|
| 161 |
+
preds.add_(-.5)
|
| 162 |
+
|
| 163 |
+
preds_orig = torch.zeros(preds.size())
|
| 164 |
+
if center is not None and scale is not None:
|
| 165 |
+
for i in range(hm.size(0)):
|
| 166 |
+
for j in range(hm.size(1)):
|
| 167 |
+
preds_orig[i, j] = transform(
|
| 168 |
+
preds[i, j], center, scale, hm.size(2), True)
|
| 169 |
+
|
| 170 |
+
return preds, preds_orig
|
| 171 |
+
|
| 172 |
+
def get_preds_fromhm_batch(hm, centers=None, scales=None):
|
| 173 |
+
"""Obtain (x,y) coordinates given a set of N heatmaps. If the centers
|
| 174 |
+
and the scales is provided the function will return the points also in
|
| 175 |
+
the original coordinate frame.
|
| 176 |
+
|
| 177 |
+
Arguments:
|
| 178 |
+
hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
|
| 179 |
+
|
| 180 |
+
Keyword Arguments:
|
| 181 |
+
centers {torch.tensor} -- the centers of the bounding box (default: {None})
|
| 182 |
+
scales {float} -- face scales (default: {None})
|
| 183 |
+
"""
|
| 184 |
+
max, idx = torch.max(
|
| 185 |
+
hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
|
| 186 |
+
idx += 1
|
| 187 |
+
preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
|
| 188 |
+
preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
|
| 189 |
+
preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
|
| 190 |
+
|
| 191 |
+
for i in range(preds.size(0)):
|
| 192 |
+
for j in range(preds.size(1)):
|
| 193 |
+
hm_ = hm[i, j, :]
|
| 194 |
+
pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
|
| 195 |
+
if pX > 0 and pX < 63 and pY > 0 and pY < 63:
|
| 196 |
+
diff = torch.FloatTensor(
|
| 197 |
+
[hm_[pY, pX + 1] - hm_[pY, pX - 1],
|
| 198 |
+
hm_[pY + 1, pX] - hm_[pY - 1, pX]])
|
| 199 |
+
preds[i, j].add_(diff.sign_().mul_(.25))
|
| 200 |
+
|
| 201 |
+
preds.add_(-.5)
|
| 202 |
+
|
| 203 |
+
preds_orig = torch.zeros(preds.size())
|
| 204 |
+
if centers is not None and scales is not None:
|
| 205 |
+
for i in range(hm.size(0)):
|
| 206 |
+
for j in range(hm.size(1)):
|
| 207 |
+
preds_orig[i, j] = transform(
|
| 208 |
+
preds[i, j], centers[i], scales[i], hm.size(2), True)
|
| 209 |
+
|
| 210 |
+
return preds, preds_orig
|
| 211 |
+
|
| 212 |
+
def shuffle_lr(parts, pairs=None):
|
| 213 |
+
"""Shuffle the points left-right according to the axis of symmetry
|
| 214 |
+
of the object.
|
| 215 |
+
|
| 216 |
+
Arguments:
|
| 217 |
+
parts {torch.tensor} -- a 3D or 4D object containing the
|
| 218 |
+
heatmaps.
|
| 219 |
+
|
| 220 |
+
Keyword Arguments:
|
| 221 |
+
pairs {list of integers} -- [order of the flipped points] (default: {None})
|
| 222 |
+
"""
|
| 223 |
+
if pairs is None:
|
| 224 |
+
pairs = [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
|
| 225 |
+
26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 27, 28, 29, 30, 35,
|
| 226 |
+
34, 33, 32, 31, 45, 44, 43, 42, 47, 46, 39, 38, 37, 36, 41,
|
| 227 |
+
40, 54, 53, 52, 51, 50, 49, 48, 59, 58, 57, 56, 55, 64, 63,
|
| 228 |
+
62, 61, 60, 67, 66, 65]
|
| 229 |
+
if parts.ndimension() == 3:
|
| 230 |
+
parts = parts[pairs, ...]
|
| 231 |
+
else:
|
| 232 |
+
parts = parts[:, pairs, ...]
|
| 233 |
+
|
| 234 |
+
return parts
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def flip(tensor, is_label=False):
|
| 238 |
+
"""Flip an image or a set of heatmaps left-right
|
| 239 |
+
|
| 240 |
+
Arguments:
|
| 241 |
+
tensor {numpy.array or torch.tensor} -- [the input image or heatmaps]
|
| 242 |
+
|
| 243 |
+
Keyword Arguments:
|
| 244 |
+
is_label {bool} -- [denote wherever the input is an image or a set of heatmaps ] (default: {False})
|
| 245 |
+
"""
|
| 246 |
+
if not torch.is_tensor(tensor):
|
| 247 |
+
tensor = torch.from_numpy(tensor)
|
| 248 |
+
|
| 249 |
+
if is_label:
|
| 250 |
+
tensor = shuffle_lr(tensor).flip(tensor.ndimension() - 1)
|
| 251 |
+
else:
|
| 252 |
+
tensor = tensor.flip(tensor.ndimension() - 1)
|
| 253 |
+
|
| 254 |
+
return tensor
|
| 255 |
+
|
| 256 |
+
# From pyzolib/paths.py (https://bitbucket.org/pyzo/pyzolib/src/tip/paths.py)
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def appdata_dir(appname=None, roaming=False):
|
| 260 |
+
""" appdata_dir(appname=None, roaming=False)
|
| 261 |
+
|
| 262 |
+
Get the path to the application directory, where applications are allowed
|
| 263 |
+
to write user specific files (e.g. configurations). For non-user specific
|
| 264 |
+
data, consider using common_appdata_dir().
|
| 265 |
+
If appname is given, a subdir is appended (and created if necessary).
|
| 266 |
+
If roaming is True, will prefer a roaming directory (Windows Vista/7).
|
| 267 |
+
"""
|
| 268 |
+
|
| 269 |
+
# Define default user directory
|
| 270 |
+
userDir = os.getenv('FACEALIGNMENT_USERDIR', None)
|
| 271 |
+
if userDir is None:
|
| 272 |
+
userDir = os.path.expanduser('~')
|
| 273 |
+
if not os.path.isdir(userDir): # pragma: no cover
|
| 274 |
+
userDir = '/var/tmp' # issue #54
|
| 275 |
+
|
| 276 |
+
# Get system app data dir
|
| 277 |
+
path = None
|
| 278 |
+
if sys.platform.startswith('win'):
|
| 279 |
+
path1, path2 = os.getenv('LOCALAPPDATA'), os.getenv('APPDATA')
|
| 280 |
+
path = (path2 or path1) if roaming else (path1 or path2)
|
| 281 |
+
elif sys.platform.startswith('darwin'):
|
| 282 |
+
path = os.path.join(userDir, 'Library', 'Application Support')
|
| 283 |
+
# On Linux and as fallback
|
| 284 |
+
if not (path and os.path.isdir(path)):
|
| 285 |
+
path = userDir
|
| 286 |
+
|
| 287 |
+
# Maybe we should store things local to the executable (in case of a
|
| 288 |
+
# portable distro or a frozen application that wants to be portable)
|
| 289 |
+
prefix = sys.prefix
|
| 290 |
+
if getattr(sys, 'frozen', None):
|
| 291 |
+
prefix = os.path.abspath(os.path.dirname(sys.executable))
|
| 292 |
+
for reldir in ('settings', '../settings'):
|
| 293 |
+
localpath = os.path.abspath(os.path.join(prefix, reldir))
|
| 294 |
+
if os.path.isdir(localpath): # pragma: no cover
|
| 295 |
+
try:
|
| 296 |
+
open(os.path.join(localpath, 'test.write'), 'wb').close()
|
| 297 |
+
os.remove(os.path.join(localpath, 'test.write'))
|
| 298 |
+
except IOError:
|
| 299 |
+
pass # We cannot write in this directory
|
| 300 |
+
else:
|
| 301 |
+
path = localpath
|
| 302 |
+
break
|
| 303 |
+
|
| 304 |
+
# Get path specific for this app
|
| 305 |
+
if appname:
|
| 306 |
+
if path == userDir:
|
| 307 |
+
appname = '.' + appname.lstrip('.') # Make it a hidden directory
|
| 308 |
+
path = os.path.join(path, appname)
|
| 309 |
+
if not os.path.isdir(path): # pragma: no cover
|
| 310 |
+
os.mkdir(path)
|
| 311 |
+
|
| 312 |
+
# Done
|
| 313 |
+
return path
|
musetalk_integration/utils/face_parsing/__init__.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import time
|
| 3 |
+
import os
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from .model import BiSeNet
|
| 8 |
+
import torchvision.transforms as transforms
|
| 9 |
+
|
| 10 |
+
class FaceParsing():
|
| 11 |
+
def __init__(self, left_cheek_width=80, right_cheek_width=80):
|
| 12 |
+
self.net = self.model_init()
|
| 13 |
+
self.preprocess = self.image_preprocess()
|
| 14 |
+
# Ensure all size parameters are integers
|
| 15 |
+
cone_height = 21
|
| 16 |
+
tail_height = 12
|
| 17 |
+
total_size = cone_height + tail_height
|
| 18 |
+
|
| 19 |
+
# Create kernel with explicit integer dimensions
|
| 20 |
+
kernel = np.zeros((total_size, total_size), dtype=np.uint8)
|
| 21 |
+
center_x = total_size // 2 # Ensure center coordinates are integers
|
| 22 |
+
|
| 23 |
+
# Cone part
|
| 24 |
+
for row in range(cone_height):
|
| 25 |
+
if row < cone_height//2:
|
| 26 |
+
continue
|
| 27 |
+
width = int(2 * (row - cone_height//2) + 1)
|
| 28 |
+
start = int(center_x - (width // 2))
|
| 29 |
+
end = int(center_x + (width // 2) + 1)
|
| 30 |
+
kernel[row, start:end] = 1
|
| 31 |
+
|
| 32 |
+
# Vertical extension part
|
| 33 |
+
if cone_height > 0:
|
| 34 |
+
base_width = int(kernel[cone_height-1].sum())
|
| 35 |
+
else:
|
| 36 |
+
base_width = 1
|
| 37 |
+
|
| 38 |
+
for row in range(cone_height, total_size):
|
| 39 |
+
start = max(0, int(center_x - (base_width//2)))
|
| 40 |
+
end = min(total_size, int(center_x + (base_width//2) + 1))
|
| 41 |
+
kernel[row, start:end] = 1
|
| 42 |
+
self.kernel = kernel
|
| 43 |
+
|
| 44 |
+
# Modify cheek erosion kernel to be flatter ellipse
|
| 45 |
+
self.cheek_kernel = cv2.getStructuringElement(
|
| 46 |
+
cv2.MORPH_ELLIPSE, (35, 3))
|
| 47 |
+
|
| 48 |
+
# Add cheek area mask (protect chin area)
|
| 49 |
+
self.cheek_mask = self._create_cheek_mask(left_cheek_width=left_cheek_width, right_cheek_width=right_cheek_width)
|
| 50 |
+
|
| 51 |
+
def _create_cheek_mask(self, left_cheek_width=80, right_cheek_width=80):
|
| 52 |
+
"""Create cheek area mask (1/4 area on both sides)"""
|
| 53 |
+
mask = np.zeros((512, 512), dtype=np.uint8)
|
| 54 |
+
center = 512 // 2
|
| 55 |
+
cv2.rectangle(mask, (0, 0), (center - left_cheek_width, 512), 255, -1) # Left cheek
|
| 56 |
+
cv2.rectangle(mask, (center + right_cheek_width, 0), (512, 512), 255, -1) # Right cheek
|
| 57 |
+
return mask
|
| 58 |
+
|
| 59 |
+
def model_init(self,
|
| 60 |
+
resnet_path='./models/face-parse-bisent/resnet18-5c106cde.pth',
|
| 61 |
+
model_pth='./models/face-parse-bisent/79999_iter.pth'):
|
| 62 |
+
net = BiSeNet(resnet_path)
|
| 63 |
+
if torch.cuda.is_available():
|
| 64 |
+
net.cuda()
|
| 65 |
+
net.load_state_dict(torch.load(model_pth))
|
| 66 |
+
else:
|
| 67 |
+
net.load_state_dict(torch.load(model_pth, map_location=torch.device('cpu')))
|
| 68 |
+
net.eval()
|
| 69 |
+
return net
|
| 70 |
+
|
| 71 |
+
def image_preprocess(self):
|
| 72 |
+
return transforms.Compose([
|
| 73 |
+
transforms.ToTensor(),
|
| 74 |
+
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
| 75 |
+
])
|
| 76 |
+
|
| 77 |
+
def __call__(self, image, size=(512, 512), mode="raw"):
|
| 78 |
+
if isinstance(image, str):
|
| 79 |
+
image = Image.open(image)
|
| 80 |
+
|
| 81 |
+
width, height = image.size
|
| 82 |
+
with torch.no_grad():
|
| 83 |
+
image = image.resize(size, Image.BILINEAR)
|
| 84 |
+
img = self.preprocess(image)
|
| 85 |
+
if torch.cuda.is_available():
|
| 86 |
+
img = torch.unsqueeze(img, 0).cuda()
|
| 87 |
+
else:
|
| 88 |
+
img = torch.unsqueeze(img, 0)
|
| 89 |
+
out = self.net(img)[0]
|
| 90 |
+
parsing = out.squeeze(0).cpu().numpy().argmax(0)
|
| 91 |
+
|
| 92 |
+
# Add 14:neck, remove 10:nose and 7:8:9
|
| 93 |
+
if mode == "neck":
|
| 94 |
+
parsing[np.isin(parsing, [1, 11, 12, 13, 14])] = 255
|
| 95 |
+
parsing[np.where(parsing!=255)] = 0
|
| 96 |
+
elif mode == "jaw":
|
| 97 |
+
face_region = np.isin(parsing, [1])*255
|
| 98 |
+
face_region = face_region.astype(np.uint8)
|
| 99 |
+
original_dilated = cv2.dilate(face_region, self.kernel, iterations=1)
|
| 100 |
+
eroded = cv2.erode(original_dilated, self.cheek_kernel, iterations=2)
|
| 101 |
+
face_region = cv2.bitwise_and(eroded, self.cheek_mask)
|
| 102 |
+
face_region = cv2.bitwise_or(face_region, cv2.bitwise_and(original_dilated, ~self.cheek_mask))
|
| 103 |
+
parsing[(face_region==255) & (~np.isin(parsing, [10]))] = 255
|
| 104 |
+
parsing[np.isin(parsing, [11, 12, 13])] = 255
|
| 105 |
+
parsing[np.where(parsing!=255)] = 0
|
| 106 |
+
else:
|
| 107 |
+
parsing[np.isin(parsing, [1, 11, 12, 13])] = 255
|
| 108 |
+
parsing[np.where(parsing!=255)] = 0
|
| 109 |
+
|
| 110 |
+
parsing = Image.fromarray(parsing.astype(np.uint8))
|
| 111 |
+
return parsing
|
| 112 |
+
|
| 113 |
+
if __name__ == "__main__":
|
| 114 |
+
fp = FaceParsing()
|
| 115 |
+
segmap = fp('154_small.png')
|
| 116 |
+
segmap.save('res.png')
|
| 117 |
+
|
musetalk_integration/utils/face_parsing/model.py
ADDED
|
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python
|
| 2 |
+
# -*- encoding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
import torchvision
|
| 9 |
+
|
| 10 |
+
from .resnet import Resnet18
|
| 11 |
+
# from modules.bn import InPlaceABNSync as BatchNorm2d
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class ConvBNReLU(nn.Module):
|
| 15 |
+
def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
|
| 16 |
+
super(ConvBNReLU, self).__init__()
|
| 17 |
+
self.conv = nn.Conv2d(in_chan,
|
| 18 |
+
out_chan,
|
| 19 |
+
kernel_size = ks,
|
| 20 |
+
stride = stride,
|
| 21 |
+
padding = padding,
|
| 22 |
+
bias = False)
|
| 23 |
+
self.bn = nn.BatchNorm2d(out_chan)
|
| 24 |
+
self.init_weight()
|
| 25 |
+
|
| 26 |
+
def forward(self, x):
|
| 27 |
+
x = self.conv(x)
|
| 28 |
+
x = F.relu(self.bn(x))
|
| 29 |
+
return x
|
| 30 |
+
|
| 31 |
+
def init_weight(self):
|
| 32 |
+
for ly in self.children():
|
| 33 |
+
if isinstance(ly, nn.Conv2d):
|
| 34 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
| 35 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
| 36 |
+
|
| 37 |
+
class BiSeNetOutput(nn.Module):
|
| 38 |
+
def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
|
| 39 |
+
super(BiSeNetOutput, self).__init__()
|
| 40 |
+
self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
|
| 41 |
+
self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
|
| 42 |
+
self.init_weight()
|
| 43 |
+
|
| 44 |
+
def forward(self, x):
|
| 45 |
+
x = self.conv(x)
|
| 46 |
+
x = self.conv_out(x)
|
| 47 |
+
return x
|
| 48 |
+
|
| 49 |
+
def init_weight(self):
|
| 50 |
+
for ly in self.children():
|
| 51 |
+
if isinstance(ly, nn.Conv2d):
|
| 52 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
| 53 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
| 54 |
+
|
| 55 |
+
def get_params(self):
|
| 56 |
+
wd_params, nowd_params = [], []
|
| 57 |
+
for name, module in self.named_modules():
|
| 58 |
+
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
|
| 59 |
+
wd_params.append(module.weight)
|
| 60 |
+
if not module.bias is None:
|
| 61 |
+
nowd_params.append(module.bias)
|
| 62 |
+
elif isinstance(module, nn.BatchNorm2d):
|
| 63 |
+
nowd_params += list(module.parameters())
|
| 64 |
+
return wd_params, nowd_params
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class AttentionRefinementModule(nn.Module):
|
| 68 |
+
def __init__(self, in_chan, out_chan, *args, **kwargs):
|
| 69 |
+
super(AttentionRefinementModule, self).__init__()
|
| 70 |
+
self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
|
| 71 |
+
self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
|
| 72 |
+
self.bn_atten = nn.BatchNorm2d(out_chan)
|
| 73 |
+
self.sigmoid_atten = nn.Sigmoid()
|
| 74 |
+
self.init_weight()
|
| 75 |
+
|
| 76 |
+
def forward(self, x):
|
| 77 |
+
feat = self.conv(x)
|
| 78 |
+
atten = F.avg_pool2d(feat, feat.size()[2:])
|
| 79 |
+
atten = self.conv_atten(atten)
|
| 80 |
+
atten = self.bn_atten(atten)
|
| 81 |
+
atten = self.sigmoid_atten(atten)
|
| 82 |
+
out = torch.mul(feat, atten)
|
| 83 |
+
return out
|
| 84 |
+
|
| 85 |
+
def init_weight(self):
|
| 86 |
+
for ly in self.children():
|
| 87 |
+
if isinstance(ly, nn.Conv2d):
|
| 88 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
| 89 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class ContextPath(nn.Module):
|
| 93 |
+
def __init__(self, resnet_path, *args, **kwargs):
|
| 94 |
+
super(ContextPath, self).__init__()
|
| 95 |
+
self.resnet = Resnet18(resnet_path)
|
| 96 |
+
self.arm16 = AttentionRefinementModule(256, 128)
|
| 97 |
+
self.arm32 = AttentionRefinementModule(512, 128)
|
| 98 |
+
self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
|
| 99 |
+
self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
|
| 100 |
+
self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
|
| 101 |
+
|
| 102 |
+
self.init_weight()
|
| 103 |
+
|
| 104 |
+
def forward(self, x):
|
| 105 |
+
H0, W0 = x.size()[2:]
|
| 106 |
+
feat8, feat16, feat32 = self.resnet(x)
|
| 107 |
+
H8, W8 = feat8.size()[2:]
|
| 108 |
+
H16, W16 = feat16.size()[2:]
|
| 109 |
+
H32, W32 = feat32.size()[2:]
|
| 110 |
+
|
| 111 |
+
avg = F.avg_pool2d(feat32, feat32.size()[2:])
|
| 112 |
+
avg = self.conv_avg(avg)
|
| 113 |
+
avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
|
| 114 |
+
|
| 115 |
+
feat32_arm = self.arm32(feat32)
|
| 116 |
+
feat32_sum = feat32_arm + avg_up
|
| 117 |
+
feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
|
| 118 |
+
feat32_up = self.conv_head32(feat32_up)
|
| 119 |
+
|
| 120 |
+
feat16_arm = self.arm16(feat16)
|
| 121 |
+
feat16_sum = feat16_arm + feat32_up
|
| 122 |
+
feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
|
| 123 |
+
feat16_up = self.conv_head16(feat16_up)
|
| 124 |
+
|
| 125 |
+
return feat8, feat16_up, feat32_up # x8, x8, x16
|
| 126 |
+
|
| 127 |
+
def init_weight(self):
|
| 128 |
+
for ly in self.children():
|
| 129 |
+
if isinstance(ly, nn.Conv2d):
|
| 130 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
| 131 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
| 132 |
+
|
| 133 |
+
def get_params(self):
|
| 134 |
+
wd_params, nowd_params = [], []
|
| 135 |
+
for name, module in self.named_modules():
|
| 136 |
+
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
| 137 |
+
wd_params.append(module.weight)
|
| 138 |
+
if not module.bias is None:
|
| 139 |
+
nowd_params.append(module.bias)
|
| 140 |
+
elif isinstance(module, nn.BatchNorm2d):
|
| 141 |
+
nowd_params += list(module.parameters())
|
| 142 |
+
return wd_params, nowd_params
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
### This is not used, since I replace this with the resnet feature with the same size
|
| 146 |
+
class SpatialPath(nn.Module):
|
| 147 |
+
def __init__(self, *args, **kwargs):
|
| 148 |
+
super(SpatialPath, self).__init__()
|
| 149 |
+
self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
|
| 150 |
+
self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
|
| 151 |
+
self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
|
| 152 |
+
self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
|
| 153 |
+
self.init_weight()
|
| 154 |
+
|
| 155 |
+
def forward(self, x):
|
| 156 |
+
feat = self.conv1(x)
|
| 157 |
+
feat = self.conv2(feat)
|
| 158 |
+
feat = self.conv3(feat)
|
| 159 |
+
feat = self.conv_out(feat)
|
| 160 |
+
return feat
|
| 161 |
+
|
| 162 |
+
def init_weight(self):
|
| 163 |
+
for ly in self.children():
|
| 164 |
+
if isinstance(ly, nn.Conv2d):
|
| 165 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
| 166 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
| 167 |
+
|
| 168 |
+
def get_params(self):
|
| 169 |
+
wd_params, nowd_params = [], []
|
| 170 |
+
for name, module in self.named_modules():
|
| 171 |
+
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
|
| 172 |
+
wd_params.append(module.weight)
|
| 173 |
+
if not module.bias is None:
|
| 174 |
+
nowd_params.append(module.bias)
|
| 175 |
+
elif isinstance(module, nn.BatchNorm2d):
|
| 176 |
+
nowd_params += list(module.parameters())
|
| 177 |
+
return wd_params, nowd_params
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
class FeatureFusionModule(nn.Module):
|
| 181 |
+
def __init__(self, in_chan, out_chan, *args, **kwargs):
|
| 182 |
+
super(FeatureFusionModule, self).__init__()
|
| 183 |
+
self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
|
| 184 |
+
self.conv1 = nn.Conv2d(out_chan,
|
| 185 |
+
out_chan//4,
|
| 186 |
+
kernel_size = 1,
|
| 187 |
+
stride = 1,
|
| 188 |
+
padding = 0,
|
| 189 |
+
bias = False)
|
| 190 |
+
self.conv2 = nn.Conv2d(out_chan//4,
|
| 191 |
+
out_chan,
|
| 192 |
+
kernel_size = 1,
|
| 193 |
+
stride = 1,
|
| 194 |
+
padding = 0,
|
| 195 |
+
bias = False)
|
| 196 |
+
self.relu = nn.ReLU(inplace=True)
|
| 197 |
+
self.sigmoid = nn.Sigmoid()
|
| 198 |
+
self.init_weight()
|
| 199 |
+
|
| 200 |
+
def forward(self, fsp, fcp):
|
| 201 |
+
fcat = torch.cat([fsp, fcp], dim=1)
|
| 202 |
+
feat = self.convblk(fcat)
|
| 203 |
+
atten = F.avg_pool2d(feat, feat.size()[2:])
|
| 204 |
+
atten = self.conv1(atten)
|
| 205 |
+
atten = self.relu(atten)
|
| 206 |
+
atten = self.conv2(atten)
|
| 207 |
+
atten = self.sigmoid(atten)
|
| 208 |
+
feat_atten = torch.mul(feat, atten)
|
| 209 |
+
feat_out = feat_atten + feat
|
| 210 |
+
return feat_out
|
| 211 |
+
|
| 212 |
+
def init_weight(self):
|
| 213 |
+
for ly in self.children():
|
| 214 |
+
if isinstance(ly, nn.Conv2d):
|
| 215 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
| 216 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
| 217 |
+
|
| 218 |
+
def get_params(self):
|
| 219 |
+
wd_params, nowd_params = [], []
|
| 220 |
+
for name, module in self.named_modules():
|
| 221 |
+
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
|
| 222 |
+
wd_params.append(module.weight)
|
| 223 |
+
if not module.bias is None:
|
| 224 |
+
nowd_params.append(module.bias)
|
| 225 |
+
elif isinstance(module, nn.BatchNorm2d):
|
| 226 |
+
nowd_params += list(module.parameters())
|
| 227 |
+
return wd_params, nowd_params
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
class BiSeNet(nn.Module):
|
| 231 |
+
def __init__(self, resnet_path='models/resnet18-5c106cde.pth', n_classes=19, *args, **kwargs):
|
| 232 |
+
super(BiSeNet, self).__init__()
|
| 233 |
+
self.cp = ContextPath(resnet_path)
|
| 234 |
+
## here self.sp is deleted
|
| 235 |
+
self.ffm = FeatureFusionModule(256, 256)
|
| 236 |
+
self.conv_out = BiSeNetOutput(256, 256, n_classes)
|
| 237 |
+
self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
|
| 238 |
+
self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
|
| 239 |
+
self.init_weight()
|
| 240 |
+
|
| 241 |
+
def forward(self, x):
|
| 242 |
+
H, W = x.size()[2:]
|
| 243 |
+
feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature
|
| 244 |
+
feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature
|
| 245 |
+
feat_fuse = self.ffm(feat_sp, feat_cp8)
|
| 246 |
+
|
| 247 |
+
feat_out = self.conv_out(feat_fuse)
|
| 248 |
+
feat_out16 = self.conv_out16(feat_cp8)
|
| 249 |
+
feat_out32 = self.conv_out32(feat_cp16)
|
| 250 |
+
|
| 251 |
+
feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
|
| 252 |
+
feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
|
| 253 |
+
feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
|
| 254 |
+
return feat_out, feat_out16, feat_out32
|
| 255 |
+
|
| 256 |
+
def init_weight(self):
|
| 257 |
+
for ly in self.children():
|
| 258 |
+
if isinstance(ly, nn.Conv2d):
|
| 259 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
| 260 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
| 261 |
+
|
| 262 |
+
def get_params(self):
|
| 263 |
+
wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
|
| 264 |
+
for name, child in self.named_children():
|
| 265 |
+
child_wd_params, child_nowd_params = child.get_params()
|
| 266 |
+
if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):
|
| 267 |
+
lr_mul_wd_params += child_wd_params
|
| 268 |
+
lr_mul_nowd_params += child_nowd_params
|
| 269 |
+
else:
|
| 270 |
+
wd_params += child_wd_params
|
| 271 |
+
nowd_params += child_nowd_params
|
| 272 |
+
return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
if __name__ == "__main__":
|
| 276 |
+
net = BiSeNet(19)
|
| 277 |
+
net.cuda()
|
| 278 |
+
net.eval()
|
| 279 |
+
in_ten = torch.randn(16, 3, 640, 480).cuda()
|
| 280 |
+
out, out16, out32 = net(in_ten)
|
| 281 |
+
print(out.shape)
|
| 282 |
+
|
| 283 |
+
net.get_params()
|
musetalk_integration/utils/face_parsing/resnet.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python
|
| 2 |
+
# -*- encoding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import torch.utils.model_zoo as modelzoo
|
| 8 |
+
|
| 9 |
+
# from modules.bn import InPlaceABNSync as BatchNorm2d
|
| 10 |
+
|
| 11 |
+
resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def conv3x3(in_planes, out_planes, stride=1):
|
| 15 |
+
"""3x3 convolution with padding"""
|
| 16 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
| 17 |
+
padding=1, bias=False)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class BasicBlock(nn.Module):
|
| 21 |
+
def __init__(self, in_chan, out_chan, stride=1):
|
| 22 |
+
super(BasicBlock, self).__init__()
|
| 23 |
+
self.conv1 = conv3x3(in_chan, out_chan, stride)
|
| 24 |
+
self.bn1 = nn.BatchNorm2d(out_chan)
|
| 25 |
+
self.conv2 = conv3x3(out_chan, out_chan)
|
| 26 |
+
self.bn2 = nn.BatchNorm2d(out_chan)
|
| 27 |
+
self.relu = nn.ReLU(inplace=True)
|
| 28 |
+
self.downsample = None
|
| 29 |
+
if in_chan != out_chan or stride != 1:
|
| 30 |
+
self.downsample = nn.Sequential(
|
| 31 |
+
nn.Conv2d(in_chan, out_chan,
|
| 32 |
+
kernel_size=1, stride=stride, bias=False),
|
| 33 |
+
nn.BatchNorm2d(out_chan),
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
def forward(self, x):
|
| 37 |
+
residual = self.conv1(x)
|
| 38 |
+
residual = F.relu(self.bn1(residual))
|
| 39 |
+
residual = self.conv2(residual)
|
| 40 |
+
residual = self.bn2(residual)
|
| 41 |
+
|
| 42 |
+
shortcut = x
|
| 43 |
+
if self.downsample is not None:
|
| 44 |
+
shortcut = self.downsample(x)
|
| 45 |
+
|
| 46 |
+
out = shortcut + residual
|
| 47 |
+
out = self.relu(out)
|
| 48 |
+
return out
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def create_layer_basic(in_chan, out_chan, bnum, stride=1):
|
| 52 |
+
layers = [BasicBlock(in_chan, out_chan, stride=stride)]
|
| 53 |
+
for i in range(bnum-1):
|
| 54 |
+
layers.append(BasicBlock(out_chan, out_chan, stride=1))
|
| 55 |
+
return nn.Sequential(*layers)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class Resnet18(nn.Module):
|
| 59 |
+
def __init__(self, model_path):
|
| 60 |
+
super(Resnet18, self).__init__()
|
| 61 |
+
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
|
| 62 |
+
bias=False)
|
| 63 |
+
self.bn1 = nn.BatchNorm2d(64)
|
| 64 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
| 65 |
+
self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
|
| 66 |
+
self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
|
| 67 |
+
self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
|
| 68 |
+
self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
|
| 69 |
+
self.init_weight(model_path)
|
| 70 |
+
|
| 71 |
+
def forward(self, x):
|
| 72 |
+
x = self.conv1(x)
|
| 73 |
+
x = F.relu(self.bn1(x))
|
| 74 |
+
x = self.maxpool(x)
|
| 75 |
+
|
| 76 |
+
x = self.layer1(x)
|
| 77 |
+
feat8 = self.layer2(x) # 1/8
|
| 78 |
+
feat16 = self.layer3(feat8) # 1/16
|
| 79 |
+
feat32 = self.layer4(feat16) # 1/32
|
| 80 |
+
return feat8, feat16, feat32
|
| 81 |
+
|
| 82 |
+
def init_weight(self, model_path):
|
| 83 |
+
state_dict = torch.load(model_path) #modelzoo.load_url(resnet18_url)
|
| 84 |
+
self_state_dict = self.state_dict()
|
| 85 |
+
for k, v in state_dict.items():
|
| 86 |
+
if 'fc' in k: continue
|
| 87 |
+
self_state_dict.update({k: v})
|
| 88 |
+
self.load_state_dict(self_state_dict)
|
| 89 |
+
|
| 90 |
+
def get_params(self):
|
| 91 |
+
wd_params, nowd_params = [], []
|
| 92 |
+
for name, module in self.named_modules():
|
| 93 |
+
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
| 94 |
+
wd_params.append(module.weight)
|
| 95 |
+
if not module.bias is None:
|
| 96 |
+
nowd_params.append(module.bias)
|
| 97 |
+
elif isinstance(module, nn.BatchNorm2d):
|
| 98 |
+
nowd_params += list(module.parameters())
|
| 99 |
+
return wd_params, nowd_params
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
if __name__ == "__main__":
|
| 103 |
+
net = Resnet18()
|
| 104 |
+
x = torch.randn(16, 3, 224, 224)
|
| 105 |
+
out = net(x)
|
| 106 |
+
print(out[0].size())
|
| 107 |
+
print(out[1].size())
|
| 108 |
+
print(out[2].size())
|
| 109 |
+
net.get_params()
|
musetalk_integration/whisper/__init__.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import hashlib
|
| 2 |
+
import io
|
| 3 |
+
import os
|
| 4 |
+
import urllib
|
| 5 |
+
import warnings
|
| 6 |
+
from typing import List, Optional, Union
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
|
| 11 |
+
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
|
| 12 |
+
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
|
| 13 |
+
from .model import Whisper, ModelDimensions
|
| 14 |
+
from .transcribe import transcribe
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
_MODELS = {
|
| 18 |
+
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
|
| 19 |
+
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
|
| 20 |
+
"base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
|
| 21 |
+
"base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
|
| 22 |
+
"small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
|
| 23 |
+
"small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
|
| 24 |
+
"medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
|
| 25 |
+
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
|
| 26 |
+
"large": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large.pt",
|
| 27 |
+
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
|
| 28 |
+
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
| 29 |
+
"large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
|
| 34 |
+
os.makedirs(root, exist_ok=True)
|
| 35 |
+
|
| 36 |
+
expected_sha256 = url.split("/")[-2]
|
| 37 |
+
download_target = os.path.join(root, os.path.basename(url))
|
| 38 |
+
|
| 39 |
+
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
| 40 |
+
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
| 41 |
+
|
| 42 |
+
if os.path.isfile(download_target):
|
| 43 |
+
model_bytes = open(download_target, "rb").read()
|
| 44 |
+
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
|
| 45 |
+
return model_bytes if in_memory else download_target
|
| 46 |
+
else:
|
| 47 |
+
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
| 48 |
+
|
| 49 |
+
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
| 50 |
+
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
|
| 51 |
+
while True:
|
| 52 |
+
buffer = source.read(8192)
|
| 53 |
+
if not buffer:
|
| 54 |
+
break
|
| 55 |
+
|
| 56 |
+
output.write(buffer)
|
| 57 |
+
loop.update(len(buffer))
|
| 58 |
+
|
| 59 |
+
model_bytes = open(download_target, "rb").read()
|
| 60 |
+
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
|
| 61 |
+
raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model.")
|
| 62 |
+
|
| 63 |
+
return model_bytes if in_memory else download_target
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def available_models() -> List[str]:
|
| 67 |
+
"""Returns the names of available models"""
|
| 68 |
+
return list(_MODELS.keys())
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False) -> Whisper:
|
| 72 |
+
"""
|
| 73 |
+
Load a Whisper ASR model
|
| 74 |
+
|
| 75 |
+
Parameters
|
| 76 |
+
----------
|
| 77 |
+
name : str
|
| 78 |
+
one of the official model names listed by `whisper.available_models()`, or
|
| 79 |
+
path to a model checkpoint containing the model dimensions and the model state_dict.
|
| 80 |
+
device : Union[str, torch.device]
|
| 81 |
+
the PyTorch device to put the model into
|
| 82 |
+
download_root: str
|
| 83 |
+
path to download the model files; by default, it uses "~/.cache/whisper"
|
| 84 |
+
in_memory: bool
|
| 85 |
+
whether to preload the model weights into host memory
|
| 86 |
+
|
| 87 |
+
Returns
|
| 88 |
+
-------
|
| 89 |
+
model : Whisper
|
| 90 |
+
The Whisper ASR model instance
|
| 91 |
+
"""
|
| 92 |
+
|
| 93 |
+
if device is None:
|
| 94 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 95 |
+
if download_root is None:
|
| 96 |
+
download_root = os.getenv(
|
| 97 |
+
"XDG_CACHE_HOME",
|
| 98 |
+
os.path.join(os.path.expanduser("~"), ".cache", "whisper")
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
if name in _MODELS:
|
| 102 |
+
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
|
| 103 |
+
elif os.path.isfile(name):
|
| 104 |
+
checkpoint_file = open(name, "rb").read() if in_memory else name
|
| 105 |
+
else:
|
| 106 |
+
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
|
| 107 |
+
|
| 108 |
+
with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp:
|
| 109 |
+
checkpoint = torch.load(fp, map_location=device)
|
| 110 |
+
del checkpoint_file
|
| 111 |
+
|
| 112 |
+
dims = ModelDimensions(**checkpoint["dims"])
|
| 113 |
+
model = Whisper(dims)
|
| 114 |
+
model.load_state_dict(checkpoint["model_state_dict"])
|
| 115 |
+
|
| 116 |
+
return model.to(device)
|
musetalk_integration/whisper/__main__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .transcribe import cli
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
cli()
|
musetalk_integration/whisper/audio.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from functools import lru_cache
|
| 3 |
+
from typing import Union
|
| 4 |
+
|
| 5 |
+
import ffmpeg
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
from .utils import exact_div
|
| 11 |
+
|
| 12 |
+
# hard-coded audio hyperparameters
|
| 13 |
+
SAMPLE_RATE = 16000
|
| 14 |
+
N_FFT = 400
|
| 15 |
+
N_MELS = 80
|
| 16 |
+
HOP_LENGTH = 160
|
| 17 |
+
CHUNK_LENGTH = 30
|
| 18 |
+
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk
|
| 19 |
+
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def load_audio(file: str, sr: int = SAMPLE_RATE):
|
| 23 |
+
"""
|
| 24 |
+
Open an audio file and read as mono waveform, resampling as necessary
|
| 25 |
+
|
| 26 |
+
Parameters
|
| 27 |
+
----------
|
| 28 |
+
file: str
|
| 29 |
+
The audio file to open
|
| 30 |
+
|
| 31 |
+
sr: int
|
| 32 |
+
The sample rate to resample the audio if necessary
|
| 33 |
+
|
| 34 |
+
Returns
|
| 35 |
+
-------
|
| 36 |
+
A NumPy array containing the audio waveform, in float32 dtype.
|
| 37 |
+
"""
|
| 38 |
+
try:
|
| 39 |
+
# This launches a subprocess to decode audio while down-mixing and resampling as necessary.
|
| 40 |
+
# Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
|
| 41 |
+
out, _ = (
|
| 42 |
+
ffmpeg.input(file, threads=0)
|
| 43 |
+
.output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
|
| 44 |
+
.run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
|
| 45 |
+
)
|
| 46 |
+
except ffmpeg.Error as e:
|
| 47 |
+
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
|
| 48 |
+
|
| 49 |
+
return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
|
| 53 |
+
"""
|
| 54 |
+
Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
|
| 55 |
+
"""
|
| 56 |
+
if torch.is_tensor(array):
|
| 57 |
+
if array.shape[axis] > length:
|
| 58 |
+
array = array.index_select(dim=axis, index=torch.arange(length))
|
| 59 |
+
|
| 60 |
+
if array.shape[axis] < length:
|
| 61 |
+
pad_widths = [(0, 0)] * array.ndim
|
| 62 |
+
pad_widths[axis] = (0, length - array.shape[axis])
|
| 63 |
+
array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
|
| 64 |
+
else:
|
| 65 |
+
if array.shape[axis] > length:
|
| 66 |
+
array = array.take(indices=range(length), axis=axis)
|
| 67 |
+
|
| 68 |
+
if array.shape[axis] < length:
|
| 69 |
+
pad_widths = [(0, 0)] * array.ndim
|
| 70 |
+
pad_widths[axis] = (0, length - array.shape[axis])
|
| 71 |
+
array = np.pad(array, pad_widths)
|
| 72 |
+
|
| 73 |
+
return array
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
@lru_cache(maxsize=None)
|
| 77 |
+
def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
|
| 78 |
+
"""
|
| 79 |
+
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
|
| 80 |
+
Allows decoupling librosa dependency; saved using:
|
| 81 |
+
|
| 82 |
+
np.savez_compressed(
|
| 83 |
+
"mel_filters.npz",
|
| 84 |
+
mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
|
| 85 |
+
)
|
| 86 |
+
"""
|
| 87 |
+
assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
|
| 88 |
+
with np.load(os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")) as f:
|
| 89 |
+
return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):
|
| 93 |
+
"""
|
| 94 |
+
Compute the log-Mel spectrogram of
|
| 95 |
+
|
| 96 |
+
Parameters
|
| 97 |
+
----------
|
| 98 |
+
audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
|
| 99 |
+
The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
|
| 100 |
+
|
| 101 |
+
n_mels: int
|
| 102 |
+
The number of Mel-frequency filters, only 80 is supported
|
| 103 |
+
|
| 104 |
+
Returns
|
| 105 |
+
-------
|
| 106 |
+
torch.Tensor, shape = (80, n_frames)
|
| 107 |
+
A Tensor that contains the Mel spectrogram
|
| 108 |
+
"""
|
| 109 |
+
if not torch.is_tensor(audio):
|
| 110 |
+
if isinstance(audio, str):
|
| 111 |
+
audio = load_audio(audio)
|
| 112 |
+
audio = torch.from_numpy(audio)
|
| 113 |
+
|
| 114 |
+
window = torch.hann_window(N_FFT).to(audio.device)
|
| 115 |
+
stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
|
| 116 |
+
|
| 117 |
+
magnitudes = stft[:, :-1].abs() ** 2
|
| 118 |
+
|
| 119 |
+
filters = mel_filters(audio.device, n_mels)
|
| 120 |
+
mel_spec = filters @ magnitudes
|
| 121 |
+
|
| 122 |
+
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
| 123 |
+
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
| 124 |
+
log_spec = (log_spec + 4.0) / 4.0
|
| 125 |
+
return log_spec
|
musetalk_integration/whisper/decoding.py
ADDED
|
@@ -0,0 +1,729 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass, field
|
| 2 |
+
from typing import Dict, List, Tuple, Iterable, Optional, Sequence, Union, TYPE_CHECKING
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
from torch.distributions import Categorical
|
| 9 |
+
|
| 10 |
+
from .audio import CHUNK_LENGTH
|
| 11 |
+
from .tokenizer import Tokenizer, get_tokenizer
|
| 12 |
+
from .utils import compression_ratio
|
| 13 |
+
|
| 14 |
+
if TYPE_CHECKING:
|
| 15 |
+
from .model import Whisper
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@torch.no_grad()
|
| 19 |
+
def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None) -> Tuple[Tensor, List[dict]]:
|
| 20 |
+
"""
|
| 21 |
+
Detect the spoken language in the audio, and return them as list of strings, along with the ids
|
| 22 |
+
of the most probable language tokens and the probability distribution over all language tokens.
|
| 23 |
+
This is performed outside the main decode loop in order to not interfere with kv-caching.
|
| 24 |
+
|
| 25 |
+
Returns
|
| 26 |
+
-------
|
| 27 |
+
language_tokens : Tensor, shape = (n_audio,)
|
| 28 |
+
ids of the most probable language tokens, which appears after the startoftranscript token.
|
| 29 |
+
language_probs : List[Dict[str, float]], length = n_audio
|
| 30 |
+
list of dictionaries containing the probability distribution over all languages.
|
| 31 |
+
"""
|
| 32 |
+
if tokenizer is None:
|
| 33 |
+
tokenizer = get_tokenizer(model.is_multilingual)
|
| 34 |
+
if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence:
|
| 35 |
+
raise ValueError(f"This model doesn't have language tokens so it can't perform lang id")
|
| 36 |
+
|
| 37 |
+
single = mel.ndim == 2
|
| 38 |
+
if single:
|
| 39 |
+
mel = mel.unsqueeze(0)
|
| 40 |
+
|
| 41 |
+
# skip encoder forward pass if already-encoded audio features were given
|
| 42 |
+
if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
|
| 43 |
+
mel = model.encoder(mel)
|
| 44 |
+
|
| 45 |
+
# forward pass using a single token, startoftranscript
|
| 46 |
+
n_audio = mel.shape[0]
|
| 47 |
+
x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
|
| 48 |
+
logits = model.logits(x, mel)[:, 0]
|
| 49 |
+
|
| 50 |
+
# collect detected languages; suppress all non-language tokens
|
| 51 |
+
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
|
| 52 |
+
mask[list(tokenizer.all_language_tokens)] = False
|
| 53 |
+
logits[:, mask] = -np.inf
|
| 54 |
+
language_tokens = logits.argmax(dim=-1)
|
| 55 |
+
language_token_probs = logits.softmax(dim=-1).cpu()
|
| 56 |
+
language_probs = [
|
| 57 |
+
{
|
| 58 |
+
c: language_token_probs[i, j].item()
|
| 59 |
+
for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
|
| 60 |
+
}
|
| 61 |
+
for i in range(n_audio)
|
| 62 |
+
]
|
| 63 |
+
|
| 64 |
+
if single:
|
| 65 |
+
language_tokens = language_tokens[0]
|
| 66 |
+
language_probs = language_probs[0]
|
| 67 |
+
|
| 68 |
+
return language_tokens, language_probs
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@dataclass(frozen=True)
|
| 72 |
+
class DecodingOptions:
|
| 73 |
+
task: str = "transcribe" # whether to perform X->X "transcribe" or X->English "translate"
|
| 74 |
+
language: Optional[str] = None # language that the audio is in; uses detected language if None
|
| 75 |
+
|
| 76 |
+
# sampling-related options
|
| 77 |
+
temperature: float = 0.0
|
| 78 |
+
sample_len: Optional[int] = None # maximum number of tokens to sample
|
| 79 |
+
best_of: Optional[int] = None # number of independent samples to collect, when t > 0
|
| 80 |
+
beam_size: Optional[int] = None # number of beams in beam search, when t == 0
|
| 81 |
+
patience: Optional[float] = None # patience in beam search (https://arxiv.org/abs/2204.05424)
|
| 82 |
+
|
| 83 |
+
# options for ranking generations (either beams or best-of-N samples)
|
| 84 |
+
length_penalty: Optional[float] = None # "alpha" in Google NMT, None defaults to length norm
|
| 85 |
+
|
| 86 |
+
# prompt, prefix, and token suppression
|
| 87 |
+
prompt: Optional[Union[str, List[int]]] = None # text or tokens for the previous context
|
| 88 |
+
prefix: Optional[Union[str, List[int]]] = None # text or tokens to prefix the current context
|
| 89 |
+
suppress_blank: bool = True # this will suppress blank outputs
|
| 90 |
+
|
| 91 |
+
# list of tokens ids (or comma-separated token ids) to suppress
|
| 92 |
+
# "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
|
| 93 |
+
suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
|
| 94 |
+
|
| 95 |
+
# timestamp sampling options
|
| 96 |
+
without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
|
| 97 |
+
max_initial_timestamp: Optional[float] = 1.0 # the initial timestamp cannot be later than this
|
| 98 |
+
|
| 99 |
+
# implementation details
|
| 100 |
+
fp16: bool = True # use fp16 for most of the calculation
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
@dataclass(frozen=True)
|
| 104 |
+
class DecodingResult:
|
| 105 |
+
audio_features: Tensor
|
| 106 |
+
language: str
|
| 107 |
+
encoder_embeddings: np.ndarray
|
| 108 |
+
decoder_embeddings: np.ndarray
|
| 109 |
+
language_probs: Optional[Dict[str, float]] = None
|
| 110 |
+
tokens: List[int] = field(default_factory=list)
|
| 111 |
+
text: str = ""
|
| 112 |
+
avg_logprob: float = np.nan
|
| 113 |
+
no_speech_prob: float = np.nan
|
| 114 |
+
temperature: float = np.nan
|
| 115 |
+
compression_ratio: float = np.nan
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class Inference:
|
| 119 |
+
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
|
| 120 |
+
"""Perform a forward pass on the decoder and return per-token logits"""
|
| 121 |
+
raise NotImplementedError
|
| 122 |
+
|
| 123 |
+
def rearrange_kv_cache(self, source_indices) -> None:
|
| 124 |
+
"""Update the key-value cache according to the updated beams"""
|
| 125 |
+
raise NotImplementedError
|
| 126 |
+
|
| 127 |
+
def cleanup_caching(self) -> None:
|
| 128 |
+
"""Clean up any resources or hooks after decoding is finished"""
|
| 129 |
+
pass
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class PyTorchInference(Inference):
|
| 133 |
+
def __init__(self, model: "Whisper", initial_token_length: int):
|
| 134 |
+
self.model: "Whisper" = model
|
| 135 |
+
self.initial_token_length = initial_token_length
|
| 136 |
+
self.kv_cache = {}
|
| 137 |
+
self.hooks = []
|
| 138 |
+
|
| 139 |
+
def logits(self, tokens: Tensor, audio_features: Tensor, include_embeddings=False) -> Tensor:
|
| 140 |
+
if not self.kv_cache:
|
| 141 |
+
self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
|
| 142 |
+
|
| 143 |
+
if tokens.shape[-1] > self.initial_token_length:
|
| 144 |
+
# only need to use the last token except in the first forward pass
|
| 145 |
+
tokens = tokens[:, -1:]
|
| 146 |
+
|
| 147 |
+
return_val = self.model.decoder(tokens, audio_features,
|
| 148 |
+
kv_cache=self.kv_cache, include_embeddings=include_embeddings)
|
| 149 |
+
return return_val
|
| 150 |
+
|
| 151 |
+
def cleanup_caching(self):
|
| 152 |
+
for hook in self.hooks:
|
| 153 |
+
hook.remove()
|
| 154 |
+
|
| 155 |
+
self.kv_cache = {}
|
| 156 |
+
self.hooks = []
|
| 157 |
+
|
| 158 |
+
def rearrange_kv_cache(self, source_indices):
|
| 159 |
+
for module, tensor in self.kv_cache.items():
|
| 160 |
+
# update the key/value cache to contain the selected sequences
|
| 161 |
+
self.kv_cache[module] = tensor[source_indices].detach()
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
class SequenceRanker:
|
| 165 |
+
def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]) -> List[int]:
|
| 166 |
+
"""
|
| 167 |
+
Given a list of groups of samples and their cumulative log probabilities,
|
| 168 |
+
return the indices of the samples in each group to select as the final result
|
| 169 |
+
"""
|
| 170 |
+
raise NotImplementedError
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class MaximumLikelihoodRanker(SequenceRanker):
|
| 174 |
+
"""
|
| 175 |
+
Select the sample with the highest log probabilities, penalized using either
|
| 176 |
+
a simple length normalization or Google NMT paper's length penalty
|
| 177 |
+
"""
|
| 178 |
+
|
| 179 |
+
def __init__(self, length_penalty: Optional[float]):
|
| 180 |
+
self.length_penalty = length_penalty
|
| 181 |
+
|
| 182 |
+
def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]):
|
| 183 |
+
def scores(logprobs, lengths):
|
| 184 |
+
result = []
|
| 185 |
+
for logprob, length in zip(logprobs, lengths):
|
| 186 |
+
if self.length_penalty is None:
|
| 187 |
+
penalty = length
|
| 188 |
+
else:
|
| 189 |
+
# from the Google NMT paper
|
| 190 |
+
penalty = ((5 + length) / 6) ** self.length_penalty
|
| 191 |
+
result.append(logprob / penalty)
|
| 192 |
+
return result
|
| 193 |
+
|
| 194 |
+
# get the sequence with the highest score
|
| 195 |
+
lengths = [[len(t) for t in s] for s in tokens]
|
| 196 |
+
return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
class TokenDecoder:
|
| 200 |
+
def reset(self):
|
| 201 |
+
"""Initialize any stateful variables for decoding a new sequence"""
|
| 202 |
+
|
| 203 |
+
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
|
| 204 |
+
"""Specify how to select the next token, based on the current trace and logits
|
| 205 |
+
|
| 206 |
+
Parameters
|
| 207 |
+
----------
|
| 208 |
+
tokens : Tensor, shape = (n_batch, current_sequence_length)
|
| 209 |
+
all tokens in the context so far, including the prefix and sot_sequence tokens
|
| 210 |
+
|
| 211 |
+
logits : Tensor, shape = (n_batch, vocab_size)
|
| 212 |
+
per-token logits of the probability distribution at the current step
|
| 213 |
+
|
| 214 |
+
sum_logprobs : Tensor, shape = (n_batch)
|
| 215 |
+
cumulative log probabilities for each sequence
|
| 216 |
+
|
| 217 |
+
Returns
|
| 218 |
+
-------
|
| 219 |
+
tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
|
| 220 |
+
the tokens, appended with the selected next token
|
| 221 |
+
|
| 222 |
+
completed : bool
|
| 223 |
+
True if all sequences has reached the end of text
|
| 224 |
+
|
| 225 |
+
"""
|
| 226 |
+
raise NotImplementedError
|
| 227 |
+
|
| 228 |
+
def finalize(
|
| 229 |
+
self, tokens: Tensor, sum_logprobs: Tensor
|
| 230 |
+
) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]:
|
| 231 |
+
"""Finalize search and return the final candidate sequences
|
| 232 |
+
|
| 233 |
+
Parameters
|
| 234 |
+
----------
|
| 235 |
+
tokens : Tensor, shape = (n_audio, n_group, current_sequence_length)
|
| 236 |
+
all tokens in the context so far, including the prefix and sot_sequence
|
| 237 |
+
|
| 238 |
+
sum_logprobs : Tensor, shape = (n_audio, n_group)
|
| 239 |
+
cumulative log probabilities for each sequence
|
| 240 |
+
|
| 241 |
+
Returns
|
| 242 |
+
-------
|
| 243 |
+
tokens : Sequence[Sequence[Tensor]], length = n_audio
|
| 244 |
+
sequence of Tensors containing candidate token sequences, for each audio input
|
| 245 |
+
|
| 246 |
+
sum_logprobs : List[List[float]], length = n_audio
|
| 247 |
+
sequence of cumulative log probabilities corresponding to the above
|
| 248 |
+
|
| 249 |
+
"""
|
| 250 |
+
raise NotImplementedError
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
class GreedyDecoder(TokenDecoder):
|
| 254 |
+
def __init__(self, temperature: float, eot: int):
|
| 255 |
+
self.temperature = temperature
|
| 256 |
+
self.eot = eot
|
| 257 |
+
|
| 258 |
+
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
|
| 259 |
+
temperature = self.temperature
|
| 260 |
+
if temperature == 0:
|
| 261 |
+
next_tokens = logits.argmax(dim=-1)
|
| 262 |
+
else:
|
| 263 |
+
next_tokens = Categorical(logits=logits / temperature).sample()
|
| 264 |
+
|
| 265 |
+
logprobs = F.log_softmax(logits.float(), dim=-1)
|
| 266 |
+
current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
|
| 267 |
+
sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)
|
| 268 |
+
|
| 269 |
+
next_tokens[tokens[:, -1] == self.eot] = self.eot
|
| 270 |
+
tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)
|
| 271 |
+
|
| 272 |
+
completed = (tokens[:, -1] == self.eot).all()
|
| 273 |
+
return tokens, completed
|
| 274 |
+
|
| 275 |
+
def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
|
| 276 |
+
# make sure each sequence has at least one EOT token at the end
|
| 277 |
+
tokens = F.pad(tokens, (0, 1), value=self.eot)
|
| 278 |
+
return tokens, sum_logprobs.tolist()
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
class BeamSearchDecoder(TokenDecoder):
|
| 282 |
+
def __init__(self, beam_size: int, eot: int, inference: Inference, patience: Optional[float] = None):
|
| 283 |
+
self.beam_size = beam_size
|
| 284 |
+
self.eot = eot
|
| 285 |
+
self.inference = inference
|
| 286 |
+
self.patience = patience or 1.0
|
| 287 |
+
self.max_candidates: int = round(beam_size * self.patience)
|
| 288 |
+
self.finished_sequences = None
|
| 289 |
+
|
| 290 |
+
assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})"
|
| 291 |
+
|
| 292 |
+
def reset(self):
|
| 293 |
+
self.finished_sequences = None
|
| 294 |
+
|
| 295 |
+
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
|
| 296 |
+
if tokens.shape[0] % self.beam_size != 0:
|
| 297 |
+
raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
|
| 298 |
+
|
| 299 |
+
n_audio = tokens.shape[0] // self.beam_size
|
| 300 |
+
if self.finished_sequences is None: # for the first update
|
| 301 |
+
self.finished_sequences = [{} for _ in range(n_audio)]
|
| 302 |
+
|
| 303 |
+
logprobs = F.log_softmax(logits.float(), dim=-1)
|
| 304 |
+
next_tokens, source_indices, finished_sequences = [], [], []
|
| 305 |
+
for i in range(n_audio):
|
| 306 |
+
scores, sources, finished = {}, {}, {}
|
| 307 |
+
|
| 308 |
+
# STEP 1: calculate the cumulative log probabilities for possible candidates
|
| 309 |
+
for j in range(self.beam_size):
|
| 310 |
+
idx = i * self.beam_size + j
|
| 311 |
+
prefix = tokens[idx].tolist()
|
| 312 |
+
for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
|
| 313 |
+
new_logprob = (sum_logprobs[idx] + logprob).item()
|
| 314 |
+
sequence = tuple(prefix + [token.item()])
|
| 315 |
+
scores[sequence] = new_logprob
|
| 316 |
+
sources[sequence] = idx
|
| 317 |
+
|
| 318 |
+
# STEP 2: rank the candidates and keep the top beam_size sequences for each audio
|
| 319 |
+
saved = 0
|
| 320 |
+
for sequence in sorted(scores, key=scores.get, reverse=True):
|
| 321 |
+
if sequence[-1] == self.eot:
|
| 322 |
+
finished[sequence] = scores[sequence]
|
| 323 |
+
else:
|
| 324 |
+
sum_logprobs[len(next_tokens)] = scores[sequence]
|
| 325 |
+
next_tokens.append(sequence)
|
| 326 |
+
source_indices.append(sources[sequence])
|
| 327 |
+
|
| 328 |
+
saved += 1
|
| 329 |
+
if saved == self.beam_size:
|
| 330 |
+
break
|
| 331 |
+
|
| 332 |
+
finished_sequences.append(finished)
|
| 333 |
+
|
| 334 |
+
tokens = torch.tensor(next_tokens, device=tokens.device)
|
| 335 |
+
self.inference.rearrange_kv_cache(source_indices)
|
| 336 |
+
|
| 337 |
+
# add newly finished sequences to self.finished_sequences
|
| 338 |
+
assert len(self.finished_sequences) == len(finished_sequences)
|
| 339 |
+
for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences):
|
| 340 |
+
for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
|
| 341 |
+
if len(previously_finished) >= self.max_candidates:
|
| 342 |
+
break # the candidate list is full
|
| 343 |
+
previously_finished[seq] = newly_finished[seq]
|
| 344 |
+
|
| 345 |
+
# mark as completed if all audio has enough number of samples
|
| 346 |
+
completed = all(
|
| 347 |
+
len(sequences) >= self.max_candidates for sequences in self.finished_sequences
|
| 348 |
+
)
|
| 349 |
+
return tokens, completed
|
| 350 |
+
|
| 351 |
+
def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
|
| 352 |
+
# collect all finished sequences, including patience, and add unfinished ones if not enough
|
| 353 |
+
sum_logprobs = sum_logprobs.cpu()
|
| 354 |
+
for i, sequences in enumerate(self.finished_sequences):
|
| 355 |
+
if len(sequences) < self.beam_size: # when not enough sequences are finished
|
| 356 |
+
for j in list(np.argsort(sum_logprobs[i]))[::-1]:
|
| 357 |
+
sequence = preceding_tokens[i, j].tolist() + [self.eot]
|
| 358 |
+
sequences[tuple(sequence)] = sum_logprobs[i][j].item()
|
| 359 |
+
if len(sequences) >= self.beam_size:
|
| 360 |
+
break
|
| 361 |
+
|
| 362 |
+
tokens: List[List[Tensor]] = [
|
| 363 |
+
[torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences
|
| 364 |
+
]
|
| 365 |
+
sum_logprobs: List[List[float]] = [
|
| 366 |
+
list(sequences.values()) for sequences in self.finished_sequences
|
| 367 |
+
]
|
| 368 |
+
return tokens, sum_logprobs
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
class LogitFilter:
|
| 372 |
+
def apply(self, logits: Tensor, tokens: Tensor) -> None:
|
| 373 |
+
"""Apply any filtering or masking to logits in-place
|
| 374 |
+
|
| 375 |
+
Parameters
|
| 376 |
+
----------
|
| 377 |
+
logits : Tensor, shape = (n_batch, vocab_size)
|
| 378 |
+
per-token logits of the probability distribution at the current step
|
| 379 |
+
|
| 380 |
+
tokens : Tensor, shape = (n_batch, current_sequence_length)
|
| 381 |
+
all tokens in the context so far, including the prefix and sot_sequence tokens
|
| 382 |
+
|
| 383 |
+
"""
|
| 384 |
+
raise NotImplementedError
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
class SuppressBlank(LogitFilter):
|
| 388 |
+
def __init__(self, tokenizer: Tokenizer, sample_begin: int):
|
| 389 |
+
self.tokenizer = tokenizer
|
| 390 |
+
self.sample_begin = sample_begin
|
| 391 |
+
|
| 392 |
+
def apply(self, logits: Tensor, tokens: Tensor):
|
| 393 |
+
if tokens.shape[1] == self.sample_begin:
|
| 394 |
+
logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
class SuppressTokens(LogitFilter):
|
| 398 |
+
def __init__(self, suppress_tokens: Sequence[int]):
|
| 399 |
+
self.suppress_tokens = list(suppress_tokens)
|
| 400 |
+
|
| 401 |
+
def apply(self, logits: Tensor, tokens: Tensor):
|
| 402 |
+
logits[:, self.suppress_tokens] = -np.inf
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
class ApplyTimestampRules(LogitFilter):
|
| 406 |
+
def __init__(
|
| 407 |
+
self, tokenizer: Tokenizer, sample_begin: int, max_initial_timestamp_index: Optional[int]
|
| 408 |
+
):
|
| 409 |
+
self.tokenizer = tokenizer
|
| 410 |
+
self.sample_begin = sample_begin
|
| 411 |
+
self.max_initial_timestamp_index = max_initial_timestamp_index
|
| 412 |
+
|
| 413 |
+
def apply(self, logits: Tensor, tokens: Tensor):
|
| 414 |
+
# suppress <|notimestamps|> which is handled by without_timestamps
|
| 415 |
+
if self.tokenizer.no_timestamps is not None:
|
| 416 |
+
logits[:, self.tokenizer.no_timestamps] = -np.inf
|
| 417 |
+
|
| 418 |
+
# timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
|
| 419 |
+
for k in range(tokens.shape[0]):
|
| 420 |
+
seq = [t for t in tokens[k, self.sample_begin :].tolist()]
|
| 421 |
+
last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
|
| 422 |
+
penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
|
| 423 |
+
|
| 424 |
+
if last_was_timestamp:
|
| 425 |
+
if penultimate_was_timestamp: # has to be non-timestamp
|
| 426 |
+
logits[k, self.tokenizer.timestamp_begin :] = -np.inf
|
| 427 |
+
else: # cannot be normal text tokens
|
| 428 |
+
logits[k, : self.tokenizer.eot] = -np.inf
|
| 429 |
+
|
| 430 |
+
# apply the `max_initial_timestamp` option
|
| 431 |
+
if tokens.shape[1] == self.sample_begin and self.max_initial_timestamp_index is not None:
|
| 432 |
+
last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
|
| 433 |
+
logits[:, last_allowed + 1 :] = -np.inf
|
| 434 |
+
|
| 435 |
+
# if sum of probability over timestamps is above any other token, sample timestamp
|
| 436 |
+
logprobs = F.log_softmax(logits.float(), dim=-1)
|
| 437 |
+
for k in range(tokens.shape[0]):
|
| 438 |
+
timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(dim=-1)
|
| 439 |
+
max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
|
| 440 |
+
if timestamp_logprob > max_text_token_logprob:
|
| 441 |
+
logits[k, : self.tokenizer.timestamp_begin] = -np.inf
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
class DecodingTask:
|
| 445 |
+
inference: Inference
|
| 446 |
+
sequence_ranker: SequenceRanker
|
| 447 |
+
decoder: TokenDecoder
|
| 448 |
+
logit_filters: List[LogitFilter]
|
| 449 |
+
|
| 450 |
+
def __init__(self, model: "Whisper", options: DecodingOptions):
|
| 451 |
+
self.model = model
|
| 452 |
+
|
| 453 |
+
language = options.language or "en"
|
| 454 |
+
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=options.task)
|
| 455 |
+
self.tokenizer: Tokenizer = tokenizer
|
| 456 |
+
self.options: DecodingOptions = self._verify_options(options)
|
| 457 |
+
|
| 458 |
+
self.n_group: int = options.beam_size or options.best_of or 1
|
| 459 |
+
self.n_ctx: int = model.dims.n_text_ctx
|
| 460 |
+
self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
|
| 461 |
+
|
| 462 |
+
self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
|
| 463 |
+
if self.options.without_timestamps:
|
| 464 |
+
self.sot_sequence = tokenizer.sot_sequence_including_notimestamps
|
| 465 |
+
|
| 466 |
+
self.initial_tokens: Tuple[int] = self._get_initial_tokens()
|
| 467 |
+
self.sample_begin: int = len(self.initial_tokens)
|
| 468 |
+
self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
|
| 469 |
+
|
| 470 |
+
# inference: implements the forward pass through the decoder, including kv caching
|
| 471 |
+
self.inference = PyTorchInference(model, len(self.initial_tokens))
|
| 472 |
+
|
| 473 |
+
# sequence ranker: implements how to rank a group of sampled sequences
|
| 474 |
+
self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
|
| 475 |
+
|
| 476 |
+
# decoder: implements how to select the next tokens, given the autoregressive distribution
|
| 477 |
+
if options.beam_size is not None:
|
| 478 |
+
self.decoder = BeamSearchDecoder(
|
| 479 |
+
options.beam_size, tokenizer.eot, self.inference, options.patience
|
| 480 |
+
)
|
| 481 |
+
else:
|
| 482 |
+
self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
|
| 483 |
+
|
| 484 |
+
# logit filters: applies various rules to suppress or penalize certain tokens
|
| 485 |
+
self.logit_filters = []
|
| 486 |
+
if self.options.suppress_blank:
|
| 487 |
+
self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
|
| 488 |
+
if self.options.suppress_tokens:
|
| 489 |
+
self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
|
| 490 |
+
if not options.without_timestamps:
|
| 491 |
+
precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
|
| 492 |
+
max_initial_timestamp_index = None
|
| 493 |
+
if options.max_initial_timestamp:
|
| 494 |
+
max_initial_timestamp_index = round(self.options.max_initial_timestamp / precision)
|
| 495 |
+
self.logit_filters.append(
|
| 496 |
+
ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index)
|
| 497 |
+
)
|
| 498 |
+
|
| 499 |
+
def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
|
| 500 |
+
if options.beam_size is not None and options.best_of is not None:
|
| 501 |
+
raise ValueError("beam_size and best_of can't be given together")
|
| 502 |
+
if options.temperature == 0:
|
| 503 |
+
if options.best_of is not None:
|
| 504 |
+
raise ValueError("best_of with greedy sampling (T=0) is not compatible")
|
| 505 |
+
if options.patience is not None and options.beam_size is None:
|
| 506 |
+
raise ValueError("patience requires beam_size to be given")
|
| 507 |
+
if options.length_penalty is not None and not (0 <= options.length_penalty <= 1):
|
| 508 |
+
raise ValueError("length_penalty (alpha) should be a value between 0 and 1")
|
| 509 |
+
|
| 510 |
+
return options
|
| 511 |
+
|
| 512 |
+
def _get_initial_tokens(self) -> Tuple[int]:
|
| 513 |
+
tokens = list(self.sot_sequence)
|
| 514 |
+
prefix = self.options.prefix
|
| 515 |
+
prompt = self.options.prompt
|
| 516 |
+
|
| 517 |
+
if prefix:
|
| 518 |
+
prefix_tokens = (
|
| 519 |
+
self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix
|
| 520 |
+
)
|
| 521 |
+
if self.sample_len is not None:
|
| 522 |
+
max_prefix_len = self.n_ctx // 2 - self.sample_len
|
| 523 |
+
prefix_tokens = prefix_tokens[-max_prefix_len:]
|
| 524 |
+
tokens = tokens + prefix_tokens
|
| 525 |
+
|
| 526 |
+
if prompt:
|
| 527 |
+
prompt_tokens = (
|
| 528 |
+
self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt
|
| 529 |
+
)
|
| 530 |
+
tokens = [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2 - 1) :] + tokens
|
| 531 |
+
|
| 532 |
+
return tuple(tokens)
|
| 533 |
+
|
| 534 |
+
def _get_suppress_tokens(self) -> Tuple[int]:
|
| 535 |
+
suppress_tokens = self.options.suppress_tokens
|
| 536 |
+
|
| 537 |
+
if isinstance(suppress_tokens, str):
|
| 538 |
+
suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
|
| 539 |
+
|
| 540 |
+
if -1 in suppress_tokens:
|
| 541 |
+
suppress_tokens = [t for t in suppress_tokens if t >= 0]
|
| 542 |
+
suppress_tokens.extend(self.tokenizer.non_speech_tokens)
|
| 543 |
+
elif suppress_tokens is None or len(suppress_tokens) == 0:
|
| 544 |
+
suppress_tokens = [] # interpret empty string as an empty list
|
| 545 |
+
else:
|
| 546 |
+
assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
|
| 547 |
+
|
| 548 |
+
suppress_tokens.extend(
|
| 549 |
+
[self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
|
| 550 |
+
)
|
| 551 |
+
if self.tokenizer.no_speech is not None:
|
| 552 |
+
# no-speech probability is collected separately
|
| 553 |
+
suppress_tokens.append(self.tokenizer.no_speech)
|
| 554 |
+
|
| 555 |
+
return tuple(sorted(set(suppress_tokens)))
|
| 556 |
+
|
| 557 |
+
def _get_audio_features(self, mel: Tensor, include_embeddings: bool = False):
|
| 558 |
+
if self.options.fp16:
|
| 559 |
+
mel = mel.half()
|
| 560 |
+
|
| 561 |
+
if mel.shape[-2:] == (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
|
| 562 |
+
# encoded audio features are given; skip audio encoding
|
| 563 |
+
audio_features = mel
|
| 564 |
+
else:
|
| 565 |
+
result = self.model.encoder(mel, include_embeddings)
|
| 566 |
+
if include_embeddings:
|
| 567 |
+
audio_features, embeddings = result
|
| 568 |
+
else:
|
| 569 |
+
audio_features = result
|
| 570 |
+
|
| 571 |
+
if audio_features.dtype != (torch.float16 if self.options.fp16 else torch.float32):
|
| 572 |
+
return TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}")
|
| 573 |
+
|
| 574 |
+
if include_embeddings:
|
| 575 |
+
return audio_features, embeddings
|
| 576 |
+
else:
|
| 577 |
+
return audio_features
|
| 578 |
+
|
| 579 |
+
def _detect_language(self, audio_features: Tensor, tokens: Tensor):
|
| 580 |
+
languages = [self.options.language] * audio_features.shape[0]
|
| 581 |
+
lang_probs = None
|
| 582 |
+
|
| 583 |
+
if self.options.language is None or self.options.task == "lang_id":
|
| 584 |
+
lang_tokens, lang_probs = self.model.detect_language(audio_features, self.tokenizer)
|
| 585 |
+
languages = [max(probs, key=probs.get) for probs in lang_probs]
|
| 586 |
+
if self.options.language is None:
|
| 587 |
+
tokens[:, self.sot_index + 1] = lang_tokens # write language tokens
|
| 588 |
+
|
| 589 |
+
return languages, lang_probs
|
| 590 |
+
|
| 591 |
+
def _main_loop(self, audio_features: Tensor, tokens: Tensor):
|
| 592 |
+
assert audio_features.shape[0] == tokens.shape[0]
|
| 593 |
+
n_batch = tokens.shape[0]
|
| 594 |
+
sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
|
| 595 |
+
no_speech_probs = [np.nan] * n_batch
|
| 596 |
+
|
| 597 |
+
try:
|
| 598 |
+
embeddings = []
|
| 599 |
+
for i in range(self.sample_len):
|
| 600 |
+
logits, token_embeddings = self.inference.logits(tokens, audio_features, include_embeddings=True)
|
| 601 |
+
|
| 602 |
+
if i == 0 and self.tokenizer.no_speech is not None: # save no_speech_probs
|
| 603 |
+
probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
|
| 604 |
+
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
|
| 605 |
+
|
| 606 |
+
# now we need to consider the logits at the last token only
|
| 607 |
+
logits = logits[:, -1]
|
| 608 |
+
token_embeddings = token_embeddings[:, :, -1]
|
| 609 |
+
|
| 610 |
+
# Append embeddings together
|
| 611 |
+
embeddings.append(token_embeddings)
|
| 612 |
+
|
| 613 |
+
# apply the logit filters, e.g. for suppressing or applying penalty to
|
| 614 |
+
for logit_filter in self.logit_filters:
|
| 615 |
+
logit_filter.apply(logits, tokens)
|
| 616 |
+
|
| 617 |
+
# expand the tokens tensor with the selected next tokens
|
| 618 |
+
tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
|
| 619 |
+
|
| 620 |
+
if completed or tokens.shape[-1] > self.n_ctx:
|
| 621 |
+
break
|
| 622 |
+
finally:
|
| 623 |
+
if completed:
|
| 624 |
+
embeddings = embeddings[:-1]
|
| 625 |
+
embeddings = np.stack(embeddings, 2)
|
| 626 |
+
self.inference.cleanup_caching()
|
| 627 |
+
|
| 628 |
+
return tokens, sum_logprobs, no_speech_probs, embeddings
|
| 629 |
+
|
| 630 |
+
@torch.no_grad()
|
| 631 |
+
def run(self, mel: Tensor) -> List[DecodingResult]:
|
| 632 |
+
self.decoder.reset()
|
| 633 |
+
tokenizer: Tokenizer = self.tokenizer
|
| 634 |
+
n_audio: int = mel.shape[0]
|
| 635 |
+
|
| 636 |
+
# encoder forward pass
|
| 637 |
+
forward_pass: Tuple[Tensor, np.ndarray] = self._get_audio_features(mel, include_embeddings=True)
|
| 638 |
+
audio_features, encoder_embeddings = forward_pass
|
| 639 |
+
tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
|
| 640 |
+
|
| 641 |
+
# detect language if requested, overwriting the language token
|
| 642 |
+
languages, language_probs = self._detect_language(audio_features, tokens)
|
| 643 |
+
if self.options.task == "lang_id":
|
| 644 |
+
return [
|
| 645 |
+
DecodingResult(audio_features=features, language=language, language_probs=probs)
|
| 646 |
+
for features, language, probs in zip(audio_features, languages, language_probs)
|
| 647 |
+
]
|
| 648 |
+
|
| 649 |
+
# repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
|
| 650 |
+
audio_features = audio_features.repeat_interleave(self.n_group, dim=0)
|
| 651 |
+
tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
|
| 652 |
+
|
| 653 |
+
# call the main sampling loop
|
| 654 |
+
tokens, sum_logprobs, no_speech_probs, decoder_embeddings = self._main_loop(audio_features, tokens)
|
| 655 |
+
|
| 656 |
+
# reshape the tensors to have (n_audio, n_group) as the first two dimensions
|
| 657 |
+
audio_features = audio_features[:: self.n_group]
|
| 658 |
+
no_speech_probs = no_speech_probs[:: self.n_group]
|
| 659 |
+
assert audio_features.shape[0] == len(no_speech_probs) == n_audio
|
| 660 |
+
|
| 661 |
+
tokens = tokens.reshape(n_audio, self.n_group, -1)
|
| 662 |
+
sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
|
| 663 |
+
|
| 664 |
+
# get the final candidates for each group, and slice between the first sampled token and EOT
|
| 665 |
+
tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
|
| 666 |
+
tokens: List[List[Tensor]] = [
|
| 667 |
+
[t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens
|
| 668 |
+
]
|
| 669 |
+
|
| 670 |
+
# select the top-ranked sample in each group
|
| 671 |
+
selected = self.sequence_ranker.rank(tokens, sum_logprobs)
|
| 672 |
+
tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
|
| 673 |
+
texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
|
| 674 |
+
|
| 675 |
+
sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
|
| 676 |
+
avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]
|
| 677 |
+
|
| 678 |
+
fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs)
|
| 679 |
+
if len(set(map(len, fields))) != 1:
|
| 680 |
+
raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
|
| 681 |
+
|
| 682 |
+
return [
|
| 683 |
+
DecodingResult(
|
| 684 |
+
audio_features=features,
|
| 685 |
+
language=language,
|
| 686 |
+
tokens=tokens,
|
| 687 |
+
text=text,
|
| 688 |
+
avg_logprob=avg_logprob,
|
| 689 |
+
no_speech_prob=no_speech_prob,
|
| 690 |
+
temperature=self.options.temperature,
|
| 691 |
+
compression_ratio=compression_ratio(text),
|
| 692 |
+
encoder_embeddings=encoder_embeddings,
|
| 693 |
+
decoder_embeddings=decoder_embeddings
|
| 694 |
+
)
|
| 695 |
+
for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)
|
| 696 |
+
]
|
| 697 |
+
|
| 698 |
+
|
| 699 |
+
@torch.no_grad()
|
| 700 |
+
def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()) -> Union[DecodingResult, List[DecodingResult]]:
|
| 701 |
+
"""
|
| 702 |
+
Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
|
| 703 |
+
|
| 704 |
+
Parameters
|
| 705 |
+
----------
|
| 706 |
+
model: Whisper
|
| 707 |
+
the Whisper model instance
|
| 708 |
+
|
| 709 |
+
mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
|
| 710 |
+
A tensor containing the Mel spectrogram(s)
|
| 711 |
+
|
| 712 |
+
options: DecodingOptions
|
| 713 |
+
A dataclass that contains all necessary options for decoding 30-second segments
|
| 714 |
+
|
| 715 |
+
Returns
|
| 716 |
+
-------
|
| 717 |
+
result: Union[DecodingResult, List[DecodingResult]]
|
| 718 |
+
The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
|
| 719 |
+
"""
|
| 720 |
+
single = mel.ndim == 2
|
| 721 |
+
if single:
|
| 722 |
+
mel = mel.unsqueeze(0)
|
| 723 |
+
|
| 724 |
+
result = DecodingTask(model, options).run(mel)
|
| 725 |
+
|
| 726 |
+
if single:
|
| 727 |
+
result = result[0]
|
| 728 |
+
|
| 729 |
+
return result
|
musetalk_integration/whisper/model.py
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import Dict
|
| 3 |
+
from typing import Iterable, Optional
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
from torch import nn
|
| 10 |
+
|
| 11 |
+
from .transcribe import transcribe as transcribe_function
|
| 12 |
+
from .decoding import detect_language as detect_language_function, decode as decode_function
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
|
| 16 |
+
class ModelDimensions:
|
| 17 |
+
n_mels: int
|
| 18 |
+
n_audio_ctx: int
|
| 19 |
+
n_audio_state: int
|
| 20 |
+
n_audio_head: int
|
| 21 |
+
n_audio_layer: int
|
| 22 |
+
n_vocab: int
|
| 23 |
+
n_text_ctx: int
|
| 24 |
+
n_text_state: int
|
| 25 |
+
n_text_head: int
|
| 26 |
+
n_text_layer: int
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class LayerNorm(nn.LayerNorm):
|
| 30 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 31 |
+
return super().forward(x.float()).type(x.dtype)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class Linear(nn.Linear):
|
| 35 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 36 |
+
return F.linear(
|
| 37 |
+
x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype)
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class Conv1d(nn.Conv1d):
|
| 42 |
+
def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
|
| 43 |
+
return super()._conv_forward(
|
| 44 |
+
x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def sinusoids(length, channels, max_timescale=10000):
|
| 49 |
+
"""Returns sinusoids for positional embedding"""
|
| 50 |
+
assert channels % 2 == 0
|
| 51 |
+
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
|
| 52 |
+
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
|
| 53 |
+
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
|
| 54 |
+
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class MultiHeadAttention(nn.Module):
|
| 58 |
+
def __init__(self, n_state: int, n_head: int):
|
| 59 |
+
super().__init__()
|
| 60 |
+
self.n_head = n_head
|
| 61 |
+
self.query = Linear(n_state, n_state)
|
| 62 |
+
self.key = Linear(n_state, n_state, bias=False)
|
| 63 |
+
self.value = Linear(n_state, n_state)
|
| 64 |
+
self.out = Linear(n_state, n_state)
|
| 65 |
+
|
| 66 |
+
def forward(
|
| 67 |
+
self,
|
| 68 |
+
x: Tensor,
|
| 69 |
+
xa: Optional[Tensor] = None,
|
| 70 |
+
mask: Optional[Tensor] = None,
|
| 71 |
+
kv_cache: Optional[dict] = None,
|
| 72 |
+
):
|
| 73 |
+
q = self.query(x)
|
| 74 |
+
|
| 75 |
+
if kv_cache is None or xa is None:
|
| 76 |
+
# hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
|
| 77 |
+
# otherwise, perform key/value projections for self- or cross-attention as usual.
|
| 78 |
+
k = self.key(x if xa is None else xa)
|
| 79 |
+
v = self.value(x if xa is None else xa)
|
| 80 |
+
else:
|
| 81 |
+
# for cross-attention, calculate keys and values once and reuse in subsequent calls.
|
| 82 |
+
k = kv_cache.get(self.key, self.key(xa))
|
| 83 |
+
v = kv_cache.get(self.value, self.value(xa))
|
| 84 |
+
|
| 85 |
+
wv = self.qkv_attention(q, k, v, mask)
|
| 86 |
+
return self.out(wv)
|
| 87 |
+
|
| 88 |
+
def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
|
| 89 |
+
n_batch, n_ctx, n_state = q.shape
|
| 90 |
+
scale = (n_state // self.n_head) ** -0.25
|
| 91 |
+
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
|
| 92 |
+
k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
|
| 93 |
+
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
| 94 |
+
|
| 95 |
+
qk = q @ k
|
| 96 |
+
if mask is not None:
|
| 97 |
+
qk = qk + mask[:n_ctx, :n_ctx]
|
| 98 |
+
|
| 99 |
+
w = F.softmax(qk.float(), dim=-1).to(q.dtype)
|
| 100 |
+
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class ResidualAttentionBlock(nn.Module):
|
| 104 |
+
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
|
| 105 |
+
super().__init__()
|
| 106 |
+
|
| 107 |
+
self.attn = MultiHeadAttention(n_state, n_head)
|
| 108 |
+
self.attn_ln = LayerNorm(n_state)
|
| 109 |
+
|
| 110 |
+
self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
|
| 111 |
+
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
|
| 112 |
+
|
| 113 |
+
n_mlp = n_state * 4
|
| 114 |
+
self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
|
| 115 |
+
self.mlp_ln = LayerNorm(n_state)
|
| 116 |
+
|
| 117 |
+
def forward(
|
| 118 |
+
self,
|
| 119 |
+
x: Tensor,
|
| 120 |
+
xa: Optional[Tensor] = None,
|
| 121 |
+
mask: Optional[Tensor] = None,
|
| 122 |
+
kv_cache: Optional[dict] = None,
|
| 123 |
+
):
|
| 124 |
+
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
|
| 125 |
+
if self.cross_attn:
|
| 126 |
+
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
|
| 127 |
+
x = x + self.mlp(self.mlp_ln(x))
|
| 128 |
+
return x
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class AudioEncoder(nn.Module):
|
| 132 |
+
def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
|
| 133 |
+
super().__init__()
|
| 134 |
+
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
|
| 135 |
+
self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
|
| 136 |
+
self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
|
| 137 |
+
|
| 138 |
+
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
|
| 139 |
+
[ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
|
| 140 |
+
)
|
| 141 |
+
self.ln_post = LayerNorm(n_state)
|
| 142 |
+
|
| 143 |
+
def forward(self, x: Tensor, include_embeddings: bool = False):
|
| 144 |
+
"""
|
| 145 |
+
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
|
| 146 |
+
the mel spectrogram of the audio
|
| 147 |
+
include_embeddings: bool
|
| 148 |
+
whether to include intermediate steps in the output
|
| 149 |
+
"""
|
| 150 |
+
x = F.gelu(self.conv1(x))
|
| 151 |
+
x = F.gelu(self.conv2(x))
|
| 152 |
+
x = x.permute(0, 2, 1)
|
| 153 |
+
|
| 154 |
+
assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
|
| 155 |
+
x = (x + self.positional_embedding).to(x.dtype)
|
| 156 |
+
|
| 157 |
+
if include_embeddings:
|
| 158 |
+
embeddings = [x.cpu().detach().numpy()]
|
| 159 |
+
|
| 160 |
+
for block in self.blocks:
|
| 161 |
+
x = block(x)
|
| 162 |
+
if include_embeddings:
|
| 163 |
+
embeddings.append(x.cpu().detach().numpy())
|
| 164 |
+
|
| 165 |
+
x = self.ln_post(x)
|
| 166 |
+
|
| 167 |
+
if include_embeddings:
|
| 168 |
+
embeddings = np.stack(embeddings, axis=1)
|
| 169 |
+
return x, embeddings
|
| 170 |
+
else:
|
| 171 |
+
return x
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
class TextDecoder(nn.Module):
|
| 175 |
+
def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
|
| 176 |
+
super().__init__()
|
| 177 |
+
|
| 178 |
+
self.token_embedding = nn.Embedding(n_vocab, n_state)
|
| 179 |
+
self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
|
| 180 |
+
|
| 181 |
+
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
|
| 182 |
+
[ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
|
| 183 |
+
)
|
| 184 |
+
self.ln = LayerNorm(n_state)
|
| 185 |
+
|
| 186 |
+
mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
|
| 187 |
+
self.register_buffer("mask", mask, persistent=False)
|
| 188 |
+
|
| 189 |
+
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None, include_embeddings: bool = False):
|
| 190 |
+
"""
|
| 191 |
+
x : torch.LongTensor, shape = (batch_size, <= n_ctx)
|
| 192 |
+
the text tokens
|
| 193 |
+
xa : torch.Tensor, shape = (batch_size, n_mels, n_audio_ctx)
|
| 194 |
+
the encoded audio features to be attended on
|
| 195 |
+
include_embeddings : bool
|
| 196 |
+
Whether to include intermediate values in the output to this function
|
| 197 |
+
"""
|
| 198 |
+
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
|
| 199 |
+
x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
|
| 200 |
+
x = x.to(xa.dtype)
|
| 201 |
+
|
| 202 |
+
if include_embeddings:
|
| 203 |
+
embeddings = [x.cpu().detach().numpy()]
|
| 204 |
+
|
| 205 |
+
for block in self.blocks:
|
| 206 |
+
x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
|
| 207 |
+
if include_embeddings:
|
| 208 |
+
embeddings.append(x.cpu().detach().numpy())
|
| 209 |
+
|
| 210 |
+
x = self.ln(x)
|
| 211 |
+
logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
|
| 212 |
+
|
| 213 |
+
if include_embeddings:
|
| 214 |
+
embeddings = np.stack(embeddings, axis=1)
|
| 215 |
+
return logits, embeddings
|
| 216 |
+
else:
|
| 217 |
+
return logits
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
class Whisper(nn.Module):
|
| 221 |
+
def __init__(self, dims: ModelDimensions):
|
| 222 |
+
super().__init__()
|
| 223 |
+
self.dims = dims
|
| 224 |
+
self.encoder = AudioEncoder(
|
| 225 |
+
self.dims.n_mels,
|
| 226 |
+
self.dims.n_audio_ctx,
|
| 227 |
+
self.dims.n_audio_state,
|
| 228 |
+
self.dims.n_audio_head,
|
| 229 |
+
self.dims.n_audio_layer,
|
| 230 |
+
)
|
| 231 |
+
self.decoder = TextDecoder(
|
| 232 |
+
self.dims.n_vocab,
|
| 233 |
+
self.dims.n_text_ctx,
|
| 234 |
+
self.dims.n_text_state,
|
| 235 |
+
self.dims.n_text_head,
|
| 236 |
+
self.dims.n_text_layer,
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
def embed_audio(self, mel: torch.Tensor):
|
| 240 |
+
return self.encoder.forward(mel)
|
| 241 |
+
|
| 242 |
+
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
|
| 243 |
+
return self.decoder.forward(tokens, audio_features)
|
| 244 |
+
|
| 245 |
+
def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
|
| 246 |
+
return self.decoder(tokens, self.encoder(mel))
|
| 247 |
+
|
| 248 |
+
@property
|
| 249 |
+
def device(self):
|
| 250 |
+
return next(self.parameters()).device
|
| 251 |
+
|
| 252 |
+
@property
|
| 253 |
+
def is_multilingual(self):
|
| 254 |
+
return self.dims.n_vocab == 51865
|
| 255 |
+
|
| 256 |
+
def install_kv_cache_hooks(self, cache: Optional[dict] = None):
|
| 257 |
+
"""
|
| 258 |
+
The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
|
| 259 |
+
tensors calculated for the previous positions. This method returns a dictionary that stores
|
| 260 |
+
all caches, and the necessary hooks for the key and value projection modules that save the
|
| 261 |
+
intermediate tensors to be reused during later calculations.
|
| 262 |
+
|
| 263 |
+
Returns
|
| 264 |
+
-------
|
| 265 |
+
cache : Dict[nn.Module, torch.Tensor]
|
| 266 |
+
A dictionary object mapping the key/value projection modules to its cache
|
| 267 |
+
hooks : List[RemovableHandle]
|
| 268 |
+
List of PyTorch RemovableHandle objects to stop the hooks to be called
|
| 269 |
+
"""
|
| 270 |
+
cache = {**cache} if cache is not None else {}
|
| 271 |
+
hooks = []
|
| 272 |
+
|
| 273 |
+
def save_to_cache(module, _, output):
|
| 274 |
+
if module not in cache or output.shape[1] > self.decoder.positional_embedding.shape[0]:
|
| 275 |
+
cache[module] = output # save as-is, for the first token or cross attention
|
| 276 |
+
else:
|
| 277 |
+
cache[module] = torch.cat([cache[module], output], dim=1).detach()
|
| 278 |
+
return cache[module]
|
| 279 |
+
|
| 280 |
+
def install_hooks(layer: nn.Module):
|
| 281 |
+
if isinstance(layer, MultiHeadAttention):
|
| 282 |
+
hooks.append(layer.key.register_forward_hook(save_to_cache))
|
| 283 |
+
hooks.append(layer.value.register_forward_hook(save_to_cache))
|
| 284 |
+
|
| 285 |
+
self.decoder.apply(install_hooks)
|
| 286 |
+
return cache, hooks
|
| 287 |
+
|
| 288 |
+
detect_language = detect_language_function
|
| 289 |
+
transcribe = transcribe_function
|
| 290 |
+
decode = decode_function
|
musetalk_integration/whisper/tokenizer.py
ADDED
|
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from functools import lru_cache
|
| 4 |
+
from typing import List, Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
from transformers import GPT2TokenizerFast
|
| 9 |
+
|
| 10 |
+
LANGUAGES = {
|
| 11 |
+
"en": "english",
|
| 12 |
+
"zh": "chinese",
|
| 13 |
+
"de": "german",
|
| 14 |
+
"es": "spanish",
|
| 15 |
+
"ru": "russian",
|
| 16 |
+
"ko": "korean",
|
| 17 |
+
"fr": "french",
|
| 18 |
+
"ja": "japanese",
|
| 19 |
+
"pt": "portuguese",
|
| 20 |
+
"tr": "turkish",
|
| 21 |
+
"pl": "polish",
|
| 22 |
+
"ca": "catalan",
|
| 23 |
+
"nl": "dutch",
|
| 24 |
+
"ar": "arabic",
|
| 25 |
+
"sv": "swedish",
|
| 26 |
+
"it": "italian",
|
| 27 |
+
"id": "indonesian",
|
| 28 |
+
"hi": "hindi",
|
| 29 |
+
"fi": "finnish",
|
| 30 |
+
"vi": "vietnamese",
|
| 31 |
+
"iw": "hebrew",
|
| 32 |
+
"uk": "ukrainian",
|
| 33 |
+
"el": "greek",
|
| 34 |
+
"ms": "malay",
|
| 35 |
+
"cs": "czech",
|
| 36 |
+
"ro": "romanian",
|
| 37 |
+
"da": "danish",
|
| 38 |
+
"hu": "hungarian",
|
| 39 |
+
"ta": "tamil",
|
| 40 |
+
"no": "norwegian",
|
| 41 |
+
"th": "thai",
|
| 42 |
+
"ur": "urdu",
|
| 43 |
+
"hr": "croatian",
|
| 44 |
+
"bg": "bulgarian",
|
| 45 |
+
"lt": "lithuanian",
|
| 46 |
+
"la": "latin",
|
| 47 |
+
"mi": "maori",
|
| 48 |
+
"ml": "malayalam",
|
| 49 |
+
"cy": "welsh",
|
| 50 |
+
"sk": "slovak",
|
| 51 |
+
"te": "telugu",
|
| 52 |
+
"fa": "persian",
|
| 53 |
+
"lv": "latvian",
|
| 54 |
+
"bn": "bengali",
|
| 55 |
+
"sr": "serbian",
|
| 56 |
+
"az": "azerbaijani",
|
| 57 |
+
"sl": "slovenian",
|
| 58 |
+
"kn": "kannada",
|
| 59 |
+
"et": "estonian",
|
| 60 |
+
"mk": "macedonian",
|
| 61 |
+
"br": "breton",
|
| 62 |
+
"eu": "basque",
|
| 63 |
+
"is": "icelandic",
|
| 64 |
+
"hy": "armenian",
|
| 65 |
+
"ne": "nepali",
|
| 66 |
+
"mn": "mongolian",
|
| 67 |
+
"bs": "bosnian",
|
| 68 |
+
"kk": "kazakh",
|
| 69 |
+
"sq": "albanian",
|
| 70 |
+
"sw": "swahili",
|
| 71 |
+
"gl": "galician",
|
| 72 |
+
"mr": "marathi",
|
| 73 |
+
"pa": "punjabi",
|
| 74 |
+
"si": "sinhala",
|
| 75 |
+
"km": "khmer",
|
| 76 |
+
"sn": "shona",
|
| 77 |
+
"yo": "yoruba",
|
| 78 |
+
"so": "somali",
|
| 79 |
+
"af": "afrikaans",
|
| 80 |
+
"oc": "occitan",
|
| 81 |
+
"ka": "georgian",
|
| 82 |
+
"be": "belarusian",
|
| 83 |
+
"tg": "tajik",
|
| 84 |
+
"sd": "sindhi",
|
| 85 |
+
"gu": "gujarati",
|
| 86 |
+
"am": "amharic",
|
| 87 |
+
"yi": "yiddish",
|
| 88 |
+
"lo": "lao",
|
| 89 |
+
"uz": "uzbek",
|
| 90 |
+
"fo": "faroese",
|
| 91 |
+
"ht": "haitian creole",
|
| 92 |
+
"ps": "pashto",
|
| 93 |
+
"tk": "turkmen",
|
| 94 |
+
"nn": "nynorsk",
|
| 95 |
+
"mt": "maltese",
|
| 96 |
+
"sa": "sanskrit",
|
| 97 |
+
"lb": "luxembourgish",
|
| 98 |
+
"my": "myanmar",
|
| 99 |
+
"bo": "tibetan",
|
| 100 |
+
"tl": "tagalog",
|
| 101 |
+
"mg": "malagasy",
|
| 102 |
+
"as": "assamese",
|
| 103 |
+
"tt": "tatar",
|
| 104 |
+
"haw": "hawaiian",
|
| 105 |
+
"ln": "lingala",
|
| 106 |
+
"ha": "hausa",
|
| 107 |
+
"ba": "bashkir",
|
| 108 |
+
"jw": "javanese",
|
| 109 |
+
"su": "sundanese",
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
# language code lookup by name, with a few language aliases
|
| 113 |
+
TO_LANGUAGE_CODE = {
|
| 114 |
+
**{language: code for code, language in LANGUAGES.items()},
|
| 115 |
+
"burmese": "my",
|
| 116 |
+
"valencian": "ca",
|
| 117 |
+
"flemish": "nl",
|
| 118 |
+
"haitian": "ht",
|
| 119 |
+
"letzeburgesch": "lb",
|
| 120 |
+
"pushto": "ps",
|
| 121 |
+
"panjabi": "pa",
|
| 122 |
+
"moldavian": "ro",
|
| 123 |
+
"moldovan": "ro",
|
| 124 |
+
"sinhalese": "si",
|
| 125 |
+
"castilian": "es",
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
@dataclass(frozen=True)
|
| 130 |
+
class Tokenizer:
|
| 131 |
+
"""A thin wrapper around `GPT2TokenizerFast` providing quick access to special tokens"""
|
| 132 |
+
|
| 133 |
+
tokenizer: "GPT2TokenizerFast"
|
| 134 |
+
language: Optional[str]
|
| 135 |
+
sot_sequence: Tuple[int]
|
| 136 |
+
|
| 137 |
+
def encode(self, text, **kwargs):
|
| 138 |
+
return self.tokenizer.encode(text, **kwargs)
|
| 139 |
+
|
| 140 |
+
def decode(self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs):
|
| 141 |
+
return self.tokenizer.decode(token_ids, **kwargs)
|
| 142 |
+
|
| 143 |
+
def decode_with_timestamps(self, tokens) -> str:
|
| 144 |
+
"""
|
| 145 |
+
Timestamp tokens are above the special tokens' id range and are ignored by `decode()`.
|
| 146 |
+
This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
|
| 147 |
+
"""
|
| 148 |
+
outputs = [[]]
|
| 149 |
+
for token in tokens:
|
| 150 |
+
if token >= self.timestamp_begin:
|
| 151 |
+
timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
|
| 152 |
+
outputs.append(timestamp)
|
| 153 |
+
outputs.append([])
|
| 154 |
+
else:
|
| 155 |
+
outputs[-1].append(token)
|
| 156 |
+
outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
|
| 157 |
+
return "".join(outputs)
|
| 158 |
+
|
| 159 |
+
@property
|
| 160 |
+
@lru_cache()
|
| 161 |
+
def eot(self) -> int:
|
| 162 |
+
return self.tokenizer.eos_token_id
|
| 163 |
+
|
| 164 |
+
@property
|
| 165 |
+
@lru_cache()
|
| 166 |
+
def sot(self) -> int:
|
| 167 |
+
return self._get_single_token_id("<|startoftranscript|>")
|
| 168 |
+
|
| 169 |
+
@property
|
| 170 |
+
@lru_cache()
|
| 171 |
+
def sot_lm(self) -> int:
|
| 172 |
+
return self._get_single_token_id("<|startoflm|>")
|
| 173 |
+
|
| 174 |
+
@property
|
| 175 |
+
@lru_cache()
|
| 176 |
+
def sot_prev(self) -> int:
|
| 177 |
+
return self._get_single_token_id("<|startofprev|>")
|
| 178 |
+
|
| 179 |
+
@property
|
| 180 |
+
@lru_cache()
|
| 181 |
+
def no_speech(self) -> int:
|
| 182 |
+
return self._get_single_token_id("<|nospeech|>")
|
| 183 |
+
|
| 184 |
+
@property
|
| 185 |
+
@lru_cache()
|
| 186 |
+
def no_timestamps(self) -> int:
|
| 187 |
+
return self._get_single_token_id("<|notimestamps|>")
|
| 188 |
+
|
| 189 |
+
@property
|
| 190 |
+
@lru_cache()
|
| 191 |
+
def timestamp_begin(self) -> int:
|
| 192 |
+
return self.tokenizer.all_special_ids[-1] + 1
|
| 193 |
+
|
| 194 |
+
@property
|
| 195 |
+
@lru_cache()
|
| 196 |
+
def language_token(self) -> int:
|
| 197 |
+
"""Returns the token id corresponding to the value of the `language` field"""
|
| 198 |
+
if self.language is None:
|
| 199 |
+
raise ValueError(f"This tokenizer does not have language token configured")
|
| 200 |
+
|
| 201 |
+
additional_tokens = dict(
|
| 202 |
+
zip(
|
| 203 |
+
self.tokenizer.additional_special_tokens,
|
| 204 |
+
self.tokenizer.additional_special_tokens_ids,
|
| 205 |
+
)
|
| 206 |
+
)
|
| 207 |
+
candidate = f"<|{self.language}|>"
|
| 208 |
+
if candidate in additional_tokens:
|
| 209 |
+
return additional_tokens[candidate]
|
| 210 |
+
|
| 211 |
+
raise KeyError(f"Language {self.language} not found in tokenizer.")
|
| 212 |
+
|
| 213 |
+
@property
|
| 214 |
+
@lru_cache()
|
| 215 |
+
def all_language_tokens(self) -> Tuple[int]:
|
| 216 |
+
result = []
|
| 217 |
+
for token, token_id in zip(
|
| 218 |
+
self.tokenizer.additional_special_tokens,
|
| 219 |
+
self.tokenizer.additional_special_tokens_ids,
|
| 220 |
+
):
|
| 221 |
+
if token.strip("<|>") in LANGUAGES:
|
| 222 |
+
result.append(token_id)
|
| 223 |
+
return tuple(result)
|
| 224 |
+
|
| 225 |
+
@property
|
| 226 |
+
@lru_cache()
|
| 227 |
+
def all_language_codes(self) -> Tuple[str]:
|
| 228 |
+
return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens)
|
| 229 |
+
|
| 230 |
+
@property
|
| 231 |
+
@lru_cache()
|
| 232 |
+
def sot_sequence_including_notimestamps(self) -> Tuple[int]:
|
| 233 |
+
return tuple(list(self.sot_sequence) + [self.no_timestamps])
|
| 234 |
+
|
| 235 |
+
@property
|
| 236 |
+
@lru_cache()
|
| 237 |
+
def non_speech_tokens(self) -> Tuple[int]:
|
| 238 |
+
"""
|
| 239 |
+
Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
|
| 240 |
+
annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
|
| 241 |
+
|
| 242 |
+
- ♪♪♪
|
| 243 |
+
- ( SPEAKING FOREIGN LANGUAGE )
|
| 244 |
+
- [DAVID] Hey there,
|
| 245 |
+
|
| 246 |
+
keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
|
| 247 |
+
"""
|
| 248 |
+
symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』")
|
| 249 |
+
symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
|
| 250 |
+
|
| 251 |
+
# symbols that may be a single token or multiple tokens depending on the tokenizer.
|
| 252 |
+
# In case they're multiple tokens, suppress the first token, which is safe because:
|
| 253 |
+
# These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
|
| 254 |
+
# in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
|
| 255 |
+
miscellaneous = set("♩♪♫♬♭♮♯")
|
| 256 |
+
assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
|
| 257 |
+
|
| 258 |
+
# allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
|
| 259 |
+
result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]}
|
| 260 |
+
for symbol in symbols + list(miscellaneous):
|
| 261 |
+
for tokens in [self.tokenizer.encode(symbol), self.tokenizer.encode(" " + symbol)]:
|
| 262 |
+
if len(tokens) == 1 or symbol in miscellaneous:
|
| 263 |
+
result.add(tokens[0])
|
| 264 |
+
|
| 265 |
+
return tuple(sorted(result))
|
| 266 |
+
|
| 267 |
+
def _get_single_token_id(self, text) -> int:
|
| 268 |
+
tokens = self.tokenizer.encode(text)
|
| 269 |
+
assert len(tokens) == 1, f"{text} is not encoded as a single token"
|
| 270 |
+
return tokens[0]
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
@lru_cache(maxsize=None)
|
| 274 |
+
def build_tokenizer(name: str = "gpt2"):
|
| 275 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 276 |
+
path = os.path.join(os.path.dirname(__file__), "assets", name)
|
| 277 |
+
tokenizer = GPT2TokenizerFast.from_pretrained(path)
|
| 278 |
+
|
| 279 |
+
specials = [
|
| 280 |
+
"<|startoftranscript|>",
|
| 281 |
+
*[f"<|{lang}|>" for lang in LANGUAGES.keys()],
|
| 282 |
+
"<|translate|>",
|
| 283 |
+
"<|transcribe|>",
|
| 284 |
+
"<|startoflm|>",
|
| 285 |
+
"<|startofprev|>",
|
| 286 |
+
"<|nospeech|>",
|
| 287 |
+
"<|notimestamps|>",
|
| 288 |
+
]
|
| 289 |
+
|
| 290 |
+
tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
|
| 291 |
+
return tokenizer
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
@lru_cache(maxsize=None)
|
| 295 |
+
def get_tokenizer(
|
| 296 |
+
multilingual: bool,
|
| 297 |
+
*,
|
| 298 |
+
task: Optional[str] = None, # Literal["transcribe", "translate", None]
|
| 299 |
+
language: Optional[str] = None,
|
| 300 |
+
) -> Tokenizer:
|
| 301 |
+
if language is not None:
|
| 302 |
+
language = language.lower()
|
| 303 |
+
if language not in LANGUAGES:
|
| 304 |
+
if language in TO_LANGUAGE_CODE:
|
| 305 |
+
language = TO_LANGUAGE_CODE[language]
|
| 306 |
+
else:
|
| 307 |
+
raise ValueError(f"Unsupported language: {language}")
|
| 308 |
+
|
| 309 |
+
if multilingual:
|
| 310 |
+
tokenizer_name = "multilingual"
|
| 311 |
+
task = task or "transcribe"
|
| 312 |
+
language = language or "en"
|
| 313 |
+
else:
|
| 314 |
+
tokenizer_name = "gpt2"
|
| 315 |
+
task = None
|
| 316 |
+
language = None
|
| 317 |
+
|
| 318 |
+
tokenizer = build_tokenizer(name=tokenizer_name)
|
| 319 |
+
all_special_ids: List[int] = tokenizer.all_special_ids
|
| 320 |
+
sot: int = all_special_ids[1]
|
| 321 |
+
translate: int = all_special_ids[-6]
|
| 322 |
+
transcribe: int = all_special_ids[-5]
|
| 323 |
+
|
| 324 |
+
langs = tuple(LANGUAGES.keys())
|
| 325 |
+
sot_sequence = [sot]
|
| 326 |
+
if language is not None:
|
| 327 |
+
sot_sequence.append(sot + 1 + langs.index(language))
|
| 328 |
+
if task is not None:
|
| 329 |
+
sot_sequence.append(transcribe if task == "transcribe" else translate)
|
| 330 |
+
|
| 331 |
+
return Tokenizer(tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence))
|
musetalk_integration/whisper/transcribe.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import warnings
|
| 4 |
+
from typing import List, Optional, Tuple, Union, TYPE_CHECKING
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import tqdm
|
| 9 |
+
|
| 10 |
+
from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
|
| 11 |
+
from .decoding import DecodingOptions, DecodingResult
|
| 12 |
+
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
|
| 13 |
+
from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt
|
| 14 |
+
|
| 15 |
+
if TYPE_CHECKING:
|
| 16 |
+
from .model import Whisper
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def transcribe(
|
| 20 |
+
model: "Whisper",
|
| 21 |
+
audio: Union[str, np.ndarray, torch.Tensor],
|
| 22 |
+
*,
|
| 23 |
+
verbose: Optional[bool] = None,
|
| 24 |
+
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
|
| 25 |
+
compression_ratio_threshold: Optional[float] = 2.4,
|
| 26 |
+
logprob_threshold: Optional[float] = -1.0,
|
| 27 |
+
no_speech_threshold: Optional[float] = 0.6,
|
| 28 |
+
condition_on_previous_text: bool = True,
|
| 29 |
+
force_extraction: bool = False,
|
| 30 |
+
**decode_options,
|
| 31 |
+
):
|
| 32 |
+
"""
|
| 33 |
+
Transcribe an audio file using Whisper
|
| 34 |
+
|
| 35 |
+
Parameters
|
| 36 |
+
----------
|
| 37 |
+
model: Whisper
|
| 38 |
+
The Whisper model instance
|
| 39 |
+
|
| 40 |
+
audio: Union[str, np.ndarray, torch.Tensor]
|
| 41 |
+
The path to the audio file to open, or the audio waveform
|
| 42 |
+
|
| 43 |
+
verbose: bool
|
| 44 |
+
Whether to display the text being decoded to the console. If True, displays all the details,
|
| 45 |
+
If False, displays minimal details. If None, does not display anything
|
| 46 |
+
|
| 47 |
+
temperature: Union[float, Tuple[float, ...]]
|
| 48 |
+
Temperature for sampling. It can be a tuple of temperatures, which will be successfully used
|
| 49 |
+
upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
|
| 50 |
+
|
| 51 |
+
compression_ratio_threshold: float
|
| 52 |
+
If the gzip compression ratio is above this value, treat as failed
|
| 53 |
+
|
| 54 |
+
logprob_threshold: float
|
| 55 |
+
If the average log probability over sampled tokens is below this value, treat as failed
|
| 56 |
+
|
| 57 |
+
no_speech_threshold: float
|
| 58 |
+
If the no_speech probability is higher than this value AND the average log probability
|
| 59 |
+
over sampled tokens is below `logprob_threshold`, consider the segment as silent
|
| 60 |
+
|
| 61 |
+
condition_on_previous_text: bool
|
| 62 |
+
if True, the previous output of the model is provided as a prompt for the next window;
|
| 63 |
+
disabling may make the text inconsistent across windows, but the model becomes less prone to
|
| 64 |
+
getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
|
| 65 |
+
|
| 66 |
+
decode_options: dict
|
| 67 |
+
Keyword arguments to construct `DecodingOptions` instances
|
| 68 |
+
|
| 69 |
+
Returns
|
| 70 |
+
-------
|
| 71 |
+
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
|
| 72 |
+
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
|
| 73 |
+
"""
|
| 74 |
+
dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
|
| 75 |
+
if model.device == torch.device("cpu"):
|
| 76 |
+
if torch.cuda.is_available():
|
| 77 |
+
warnings.warn("Performing inference on CPU when CUDA is available")
|
| 78 |
+
if dtype == torch.float16:
|
| 79 |
+
warnings.warn("FP16 is not supported on CPU; using FP32 instead")
|
| 80 |
+
dtype = torch.float32
|
| 81 |
+
|
| 82 |
+
if dtype == torch.float32:
|
| 83 |
+
decode_options["fp16"] = False
|
| 84 |
+
|
| 85 |
+
mel = log_mel_spectrogram(audio)
|
| 86 |
+
|
| 87 |
+
all_segments = []
|
| 88 |
+
def add_segment(
|
| 89 |
+
*, start: float, end: float, encoder_embeddings
|
| 90 |
+
):
|
| 91 |
+
|
| 92 |
+
all_segments.append(
|
| 93 |
+
{
|
| 94 |
+
"start": start,
|
| 95 |
+
"end": end,
|
| 96 |
+
"encoder_embeddings":encoder_embeddings,
|
| 97 |
+
}
|
| 98 |
+
)
|
| 99 |
+
# show the progress bar when verbose is False (otherwise the transcribed text will be printed)
|
| 100 |
+
num_frames = mel.shape[-1]
|
| 101 |
+
seek = 0
|
| 102 |
+
previous_seek_value = seek
|
| 103 |
+
sample_skip = 3000 #
|
| 104 |
+
with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
|
| 105 |
+
while seek < num_frames:
|
| 106 |
+
# seek是开始的帧数
|
| 107 |
+
end_seek = min(seek + sample_skip, num_frames)
|
| 108 |
+
segment = pad_or_trim(mel[:,seek:seek+sample_skip], N_FRAMES).to(model.device).to(dtype)
|
| 109 |
+
|
| 110 |
+
single = segment.ndim == 2
|
| 111 |
+
if single:
|
| 112 |
+
segment = segment.unsqueeze(0)
|
| 113 |
+
if dtype == torch.float16:
|
| 114 |
+
segment = segment.half()
|
| 115 |
+
audio_features, embeddings = model.encoder(segment, include_embeddings = True)
|
| 116 |
+
|
| 117 |
+
encoder_embeddings = embeddings
|
| 118 |
+
#print(f"encoder_embeddings shape {encoder_embeddings.shape}")
|
| 119 |
+
add_segment(
|
| 120 |
+
start=seek,
|
| 121 |
+
end=end_seek,
|
| 122 |
+
#text_tokens=tokens,
|
| 123 |
+
#result=result,
|
| 124 |
+
encoder_embeddings=encoder_embeddings,
|
| 125 |
+
)
|
| 126 |
+
seek+=sample_skip
|
| 127 |
+
|
| 128 |
+
return dict(segments=all_segments)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def cli():
|
| 132 |
+
from . import available_models
|
| 133 |
+
|
| 134 |
+
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
| 135 |
+
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
|
| 136 |
+
parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
|
| 137 |
+
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
|
| 138 |
+
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
|
| 139 |
+
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
| 140 |
+
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
|
| 141 |
+
|
| 142 |
+
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
|
| 143 |
+
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
|
| 144 |
+
|
| 145 |
+
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
|
| 146 |
+
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
|
| 147 |
+
parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
|
| 148 |
+
parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
|
| 149 |
+
parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
|
| 150 |
+
|
| 151 |
+
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
|
| 152 |
+
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
|
| 153 |
+
parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
|
| 154 |
+
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
|
| 155 |
+
|
| 156 |
+
parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
|
| 157 |
+
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
|
| 158 |
+
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
|
| 159 |
+
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
|
| 160 |
+
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
| 161 |
+
|
| 162 |
+
args = parser.parse_args().__dict__
|
| 163 |
+
model_name: str = args.pop("model")
|
| 164 |
+
model_dir: str = args.pop("model_dir")
|
| 165 |
+
output_dir: str = args.pop("output_dir")
|
| 166 |
+
device: str = args.pop("device")
|
| 167 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 168 |
+
|
| 169 |
+
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
|
| 170 |
+
if args["language"] is not None:
|
| 171 |
+
warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
|
| 172 |
+
args["language"] = "en"
|
| 173 |
+
|
| 174 |
+
temperature = args.pop("temperature")
|
| 175 |
+
temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
|
| 176 |
+
if temperature_increment_on_fallback is not None:
|
| 177 |
+
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
|
| 178 |
+
else:
|
| 179 |
+
temperature = [temperature]
|
| 180 |
+
|
| 181 |
+
threads = args.pop("threads")
|
| 182 |
+
if threads > 0:
|
| 183 |
+
torch.set_num_threads(threads)
|
| 184 |
+
|
| 185 |
+
from . import load_model
|
| 186 |
+
model = load_model(model_name, device=device, download_root=model_dir)
|
| 187 |
+
|
| 188 |
+
for audio_path in args.pop("audio"):
|
| 189 |
+
result = transcribe(model, audio_path, temperature=temperature, **args)
|
| 190 |
+
|
| 191 |
+
audio_basename = os.path.basename(audio_path)
|
| 192 |
+
|
| 193 |
+
# save TXT
|
| 194 |
+
with open(os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8") as txt:
|
| 195 |
+
write_txt(result["segments"], file=txt)
|
| 196 |
+
|
| 197 |
+
# save VTT
|
| 198 |
+
with open(os.path.join(output_dir, audio_basename + ".vtt"), "w", encoding="utf-8") as vtt:
|
| 199 |
+
write_vtt(result["segments"], file=vtt)
|
| 200 |
+
|
| 201 |
+
# save SRT
|
| 202 |
+
with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt:
|
| 203 |
+
write_srt(result["segments"], file=srt)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
if __name__ == '__main__':
|
| 207 |
+
cli()
|
musetalk_integration/whisper/utils.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import zlib
|
| 2 |
+
from typing import Iterator, TextIO
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def exact_div(x, y):
|
| 6 |
+
assert x % y == 0
|
| 7 |
+
return x // y
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def str2bool(string):
|
| 11 |
+
str2val = {"True": True, "False": False}
|
| 12 |
+
if string in str2val:
|
| 13 |
+
return str2val[string]
|
| 14 |
+
else:
|
| 15 |
+
raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def optional_int(string):
|
| 19 |
+
return None if string == "None" else int(string)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def optional_float(string):
|
| 23 |
+
return None if string == "None" else float(string)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def compression_ratio(text) -> float:
|
| 27 |
+
return len(text) / len(zlib.compress(text.encode("utf-8")))
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'):
|
| 31 |
+
assert seconds >= 0, "non-negative timestamp expected"
|
| 32 |
+
milliseconds = round(seconds * 1000.0)
|
| 33 |
+
|
| 34 |
+
hours = milliseconds // 3_600_000
|
| 35 |
+
milliseconds -= hours * 3_600_000
|
| 36 |
+
|
| 37 |
+
minutes = milliseconds // 60_000
|
| 38 |
+
milliseconds -= minutes * 60_000
|
| 39 |
+
|
| 40 |
+
seconds = milliseconds // 1_000
|
| 41 |
+
milliseconds -= seconds * 1_000
|
| 42 |
+
|
| 43 |
+
hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
|
| 44 |
+
return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def write_txt(transcript: Iterator[dict], file: TextIO):
|
| 48 |
+
for segment in transcript:
|
| 49 |
+
print(segment['text'].strip(), file=file, flush=True)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def write_vtt(transcript: Iterator[dict], file: TextIO):
|
| 53 |
+
print("WEBVTT\n", file=file)
|
| 54 |
+
for segment in transcript:
|
| 55 |
+
print(
|
| 56 |
+
f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
|
| 57 |
+
f"{segment['text'].strip().replace('-->', '->')}\n",
|
| 58 |
+
file=file,
|
| 59 |
+
flush=True,
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def write_srt(transcript: Iterator[dict], file: TextIO):
|
| 64 |
+
"""
|
| 65 |
+
Write a transcript to a file in SRT format.
|
| 66 |
+
|
| 67 |
+
Example usage:
|
| 68 |
+
from pathlib import Path
|
| 69 |
+
from whisper.utils import write_srt
|
| 70 |
+
|
| 71 |
+
result = transcribe(model, audio_path, temperature=temperature, **args)
|
| 72 |
+
|
| 73 |
+
# save SRT
|
| 74 |
+
audio_basename = Path(audio_path).stem
|
| 75 |
+
with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
|
| 76 |
+
write_srt(result["segments"], file=srt)
|
| 77 |
+
"""
|
| 78 |
+
for i, segment in enumerate(transcript, start=1):
|
| 79 |
+
# write srt lines
|
| 80 |
+
print(
|
| 81 |
+
f"{i}\n"
|
| 82 |
+
f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
|
| 83 |
+
f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
|
| 84 |
+
f"{segment['text'].strip().replace('-->', '->')}\n",
|
| 85 |
+
file=file,
|
| 86 |
+
flush=True,
|
| 87 |
+
)
|
processing.py
CHANGED
|
@@ -320,7 +320,7 @@ def process_lipsync_with_audio_target_new(
|
|
| 320 |
video_file,
|
| 321 |
audio_file,
|
| 322 |
session_id=None,
|
| 323 |
-
|
| 324 |
progress=gr.Progress(track_tqdm=True),
|
| 325 |
):
|
| 326 |
"""Workflow mới: Chuẩn hóa YouTube rồi lipsync
|
|
@@ -337,7 +337,7 @@ def process_lipsync_with_audio_target_new(
|
|
| 337 |
video_file: Path to video source
|
| 338 |
audio_file: Path to audio target (English only)
|
| 339 |
session_id: Session identifier
|
| 340 |
-
|
| 341 |
progress: Progress tracking object
|
| 342 |
|
| 343 |
Returns:
|
|
@@ -352,6 +352,16 @@ def process_lipsync_with_audio_target_new(
|
|
| 352 |
|
| 353 |
output_dir = setup_output_dir(session_id)
|
| 354 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 355 |
logger.info(f"Memory at start: {get_memory_usage()}")
|
| 356 |
|
| 357 |
audio_duration = get_audio_duration(audio_path)
|
|
@@ -417,7 +427,7 @@ def process_lipsync_with_audio_target_new(
|
|
| 417 |
with timer("Applying lipsync"):
|
| 418 |
try:
|
| 419 |
lipsynced_video, lipsynced_info = apply_lipsync_to_video(
|
| 420 |
-
video_normalized, audio_16k, output_dir,
|
| 421 |
)
|
| 422 |
logger.info(
|
| 423 |
f"Lipsynced video: {lipsynced_video}, size: {lipsynced_info['width']}x{lipsynced_info['height']}"
|
|
@@ -461,7 +471,7 @@ def lipsync_with_audio_target(
|
|
| 461 |
video_file,
|
| 462 |
audio_file,
|
| 463 |
session_id=None,
|
| 464 |
-
|
| 465 |
progress=gr.Progress(track_tqdm=True),
|
| 466 |
):
|
| 467 |
"""Wrapper for Gradio: Lipsync video source with audio target (English only)
|
|
@@ -474,5 +484,5 @@ def lipsync_with_audio_target(
|
|
| 474 |
if audio_file is None:
|
| 475 |
raise gr.Error("Please upload a target audio.")
|
| 476 |
return process_lipsync_with_audio_target_new(
|
| 477 |
-
video_file, audio_file, session_id,
|
| 478 |
)
|
|
|
|
| 320 |
video_file,
|
| 321 |
audio_file,
|
| 322 |
session_id=None,
|
| 323 |
+
model_type="latentsync",
|
| 324 |
progress=gr.Progress(track_tqdm=True),
|
| 325 |
):
|
| 326 |
"""Workflow mới: Chuẩn hóa YouTube rồi lipsync
|
|
|
|
| 337 |
video_file: Path to video source
|
| 338 |
audio_file: Path to audio target (English only)
|
| 339 |
session_id: Session identifier
|
| 340 |
+
model_type: Model type for lipsync ("latentsync" or "musetalk")
|
| 341 |
progress: Progress tracking object
|
| 342 |
|
| 343 |
Returns:
|
|
|
|
| 352 |
|
| 353 |
output_dir = setup_output_dir(session_id)
|
| 354 |
|
| 355 |
+
# Mapping model_type to crop_size
|
| 356 |
+
if model_type == "LatentSync v1.6":
|
| 357 |
+
crop_size = 512
|
| 358 |
+
logger.info("Using LatentSync v1.6 with crop_size=512")
|
| 359 |
+
elif model_type == "MuseTalk v1.5":
|
| 360 |
+
crop_size = 256
|
| 361 |
+
logger.info("Using MuseTalk v1.5 with crop_size=256")
|
| 362 |
+
else:
|
| 363 |
+
raise ValueError(f"Unknown model_type: {model_type}")
|
| 364 |
+
|
| 365 |
logger.info(f"Memory at start: {get_memory_usage()}")
|
| 366 |
|
| 367 |
audio_duration = get_audio_duration(audio_path)
|
|
|
|
| 427 |
with timer("Applying lipsync"):
|
| 428 |
try:
|
| 429 |
lipsynced_video, lipsynced_info = apply_lipsync_to_video(
|
| 430 |
+
video_normalized, audio_16k, output_dir, model_type
|
| 431 |
)
|
| 432 |
logger.info(
|
| 433 |
f"Lipsynced video: {lipsynced_video}, size: {lipsynced_info['width']}x{lipsynced_info['height']}"
|
|
|
|
| 471 |
video_file,
|
| 472 |
audio_file,
|
| 473 |
session_id=None,
|
| 474 |
+
model_type="LatentSync v1.6",
|
| 475 |
progress=gr.Progress(track_tqdm=True),
|
| 476 |
):
|
| 477 |
"""Wrapper for Gradio: Lipsync video source with audio target (English only)
|
|
|
|
| 484 |
if audio_file is None:
|
| 485 |
raise gr.Error("Please upload a target audio.")
|
| 486 |
return process_lipsync_with_audio_target_new(
|
| 487 |
+
video_file, audio_file, session_id, model_type, progress
|
| 488 |
)
|
requirements.txt
CHANGED
|
@@ -45,3 +45,12 @@ psutil
|
|
| 45 |
# Gradio & Spaces
|
| 46 |
gradio==5.24.0
|
| 47 |
spaces
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
# Gradio & Spaces
|
| 46 |
gradio==5.24.0
|
| 47 |
spaces
|
| 48 |
+
|
| 49 |
+
# MuseTalk Dependencies
|
| 50 |
+
mmengine>=0.10.0
|
| 51 |
+
mmcv>=2.0.1
|
| 52 |
+
mmdet>=3.1.0
|
| 53 |
+
mmpose>=1.1.0
|
| 54 |
+
openmim>=0.3.0
|
| 55 |
+
moviepy>=1.0.3
|
| 56 |
+
gdown>=5.1.0
|