naicoi commited on
Commit
759d98d
·
1 Parent(s): 64a2ea3

musetalk model

Browse files
Files changed (40) hide show
  1. app.py +6 -3
  2. download_musetalk_models.py +31 -0
  3. lipsync_processing.py +32 -7
  4. musetalk.py +219 -0
  5. musetalk.py.bak +219 -0
  6. musetalk_integration/__init__.py +0 -0
  7. musetalk_integration/models/__init__.py +0 -0
  8. musetalk_integration/models/unet.py +51 -0
  9. musetalk_integration/models/vae.py +148 -0
  10. musetalk_integration/utils/__init__.py +0 -0
  11. musetalk_integration/utils/audio_processor.py +102 -0
  12. musetalk_integration/utils/blending.py +136 -0
  13. musetalk_integration/utils/dwpose/__init__.py +0 -0
  14. musetalk_integration/utils/dwpose/default_runtime.py +54 -0
  15. musetalk_integration/utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py +257 -0
  16. musetalk_integration/utils/face_detection/README.md +1 -0
  17. musetalk_integration/utils/face_detection/__init__.py +7 -0
  18. musetalk_integration/utils/face_detection/api.py +240 -0
  19. musetalk_integration/utils/face_detection/detection/__init__.py +1 -0
  20. musetalk_integration/utils/face_detection/detection/core.py +130 -0
  21. musetalk_integration/utils/face_detection/detection/sfd/__init__.py +1 -0
  22. musetalk_integration/utils/face_detection/detection/sfd/bbox.py +129 -0
  23. musetalk_integration/utils/face_detection/detection/sfd/detect.py +114 -0
  24. musetalk_integration/utils/face_detection/detection/sfd/net_s3fd.py +129 -0
  25. musetalk_integration/utils/face_detection/detection/sfd/sfd_detector.py +59 -0
  26. musetalk_integration/utils/face_detection/models.py +261 -0
  27. musetalk_integration/utils/face_detection/utils.py +313 -0
  28. musetalk_integration/utils/face_parsing/__init__.py +117 -0
  29. musetalk_integration/utils/face_parsing/model.py +283 -0
  30. musetalk_integration/utils/face_parsing/resnet.py +109 -0
  31. musetalk_integration/whisper/__init__.py +116 -0
  32. musetalk_integration/whisper/__main__.py +4 -0
  33. musetalk_integration/whisper/audio.py +125 -0
  34. musetalk_integration/whisper/decoding.py +729 -0
  35. musetalk_integration/whisper/model.py +290 -0
  36. musetalk_integration/whisper/tokenizer.py +331 -0
  37. musetalk_integration/whisper/transcribe.py +207 -0
  38. musetalk_integration/whisper/utils.py +87 -0
  39. processing.py +15 -5
  40. requirements.txt +9 -0
app.py CHANGED
@@ -88,8 +88,11 @@ with gr.Blocks(css=css) as demo:
88
  audio_input = gr.Audio(
89
  label="Target Audio (English only)", type="filepath"
90
  )
91
- crop_size_radio = gr.Radio(
92
- label="Crop Size", choices=[256, 512], value=512, interactive=True
 
 
 
93
  )
94
  lipsync_only_btn = gr.Button("👄 Lipsync", variant="primary", size="lg")
95
 
@@ -113,7 +116,7 @@ with gr.Blocks(css=css) as demo:
113
 
114
  lipsync_only_btn.click(
115
  fn=lipsync_with_audio_target,
116
- inputs=[video_input, audio_input, session_state, crop_size_radio],
117
  outputs=[
118
  final_video,
119
  video_normalized_output,
 
88
  audio_input = gr.Audio(
89
  label="Target Audio (English only)", type="filepath"
90
  )
91
+ model_radio = gr.Radio(
92
+ label="Lipsync Model",
93
+ choices=["LatentSync v1.6", "MuseTalk v1.5"],
94
+ value="LatentSync v1.6",
95
+ interactive=True,
96
  )
97
  lipsync_only_btn = gr.Button("👄 Lipsync", variant="primary", size="lg")
98
 
 
116
 
117
  lipsync_only_btn.click(
118
  fn=lipsync_with_audio_target,
119
+ inputs=[video_input, audio_input, session_state, model_radio],
120
  outputs=[
121
  final_video,
122
  video_normalized_output,
download_musetalk_models.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Download MuseTalk V1.5 models"""
2
+
3
+ from huggingface_hub import snapshot_download
4
+ import os
5
+
6
+ print("Downloading MuseTalk V1.5 models...")
7
+
8
+ os.makedirs("checkpoints/musetalkV15", exist_ok=True)
9
+ snapshot_download(
10
+ repo_id="TMElyralab/MuseTalk",
11
+ local_dir="./checkpoints/musetalkV15",
12
+ allow_patterns=["*.pth", "*.json", "*.pt"],
13
+ )
14
+
15
+ print("✓ MuseTalk V1.5 models downloaded to checkpoints/musetalkV15/")
16
+
17
+ print("Downloading SD-VAE-FT-MSE model...")
18
+ os.makedirs("checkpoints/sd-vae-ft-mse", exist_ok=True)
19
+ snapshot_download(
20
+ repo_id="stabilityai/sd-vae-ft-mse", local_dir="./checkpoints/sd-vae-ft-mse"
21
+ )
22
+
23
+ print("✓ SD-VAE-FT-MSE downloaded to checkpoints/sd-vae-ft-mse/")
24
+
25
+ print("Downloading Whisper-Tiny model...")
26
+ os.makedirs("checkpoints/whisper-tiny", exist_ok=True)
27
+ snapshot_download(repo_id="openai/whisper-tiny", local_dir="./checkpoints/whisper-tiny")
28
+
29
+ print("✓ Whisper-Tiny downloaded to checkpoints/whisper-tiny/")
30
+
31
+ print("\nAll MuseTalk models downloaded successfully!")
lipsync_processing.py CHANGED
@@ -45,7 +45,10 @@ def get_video_info(video_path: str) -> dict:
45
 
46
 
47
  def apply_lipsync_to_video(
48
- video_path: str, audio_16k_path: str, output_dir: str, crop_size: int = 256
 
 
 
49
  ) -> tuple:
50
  """Apply lipsync to video using clean 16k audio
51
 
@@ -53,17 +56,32 @@ def apply_lipsync_to_video(
53
  video_path: Path to input video
54
  audio_16k_path: Path to 16kHz audio
55
  output_dir: Directory to save output
56
- crop_size: Size of crop region for lipsync (default: 256)
57
 
58
  Returns:
59
  Tuple of (lipsynced_video_path, video_info)
60
  """
61
  try:
62
  lipsynced_video = os.path.join(output_dir, "output_with_lipsync.mp4")
63
- print(
64
- f"Lipsync params: video={video_path}, audio={audio_16k_path}, crop_size={crop_size}"
65
- )
66
- apply_lipsync(video_path, audio_16k_path, lipsynced_video, crop_size)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
  video_info = get_video_info(lipsynced_video)
69
  print(
@@ -78,11 +96,18 @@ def apply_lipsync_to_video(
78
  )
79
  if "face not detected" in str(e).lower():
80
  raise RuntimeError(
81
- "Face detection failed in LatentSync pipeline. Please upload a video with a clear, visible face."
82
  )
83
  print(f"Runtime Error in lipsync processing: {e}")
84
  traceback.print_exc()
85
  raise
 
 
 
 
 
 
 
86
  except Exception as e:
87
  print(f"Error in apply_lipsync_to_video: {e}")
88
  traceback.print_exc()
 
45
 
46
 
47
  def apply_lipsync_to_video(
48
+ video_path: str,
49
+ audio_16k_path: str,
50
+ output_dir: str,
51
+ model_type: str = "LatentSync v1.6",
52
  ) -> tuple:
53
  """Apply lipsync to video using clean 16k audio
54
 
 
56
  video_path: Path to input video
57
  audio_16k_path: Path to 16kHz audio
58
  output_dir: Directory to save output
59
+ model_type: Model type for lipsync ("LatentSync v1.6" or "MuseTalk v1.5")
60
 
61
  Returns:
62
  Tuple of (lipsynced_video_path, video_info)
63
  """
64
  try:
65
  lipsynced_video = os.path.join(output_dir, "output_with_lipsync.mp4")
66
+
67
+ if model_type == "LatentSync v1.6":
68
+ crop_size = 512
69
+ print(
70
+ f"Using LatentSync v1.6: video={video_path}, audio={audio_16k_path}, crop_size={crop_size}"
71
+ )
72
+ apply_lipsync(video_path, audio_16k_path, lipsynced_video, crop_size)
73
+
74
+ elif model_type == "MuseTalk v1.5":
75
+ crop_size = 256
76
+ print(
77
+ f"Using MuseTalk v1.5: video={video_path}, audio={audio_16k_path}, crop_size={crop_size}"
78
+ )
79
+ from musetalk import apply_musetalk_lipsync
80
+
81
+ apply_musetalk_lipsync(video_path, audio_16k_path, lipsynced_video)
82
+
83
+ else:
84
+ raise ValueError(f"Unknown model_type: {model_type}")
85
 
86
  video_info = get_video_info(lipsynced_video)
87
  print(
 
96
  )
97
  if "face not detected" in str(e).lower():
98
  raise RuntimeError(
99
+ "Face detection failed in lipsync pipeline. Please upload a video with a clear, visible face."
100
  )
101
  print(f"Runtime Error in lipsync processing: {e}")
102
  traceback.print_exc()
103
  raise
104
+ except Exception:
105
+ raise
106
+ except Exception as e:
107
+ print(f"Error in apply_lipsync_to_video: {e}")
108
+ traceback.print_exc()
109
+ raise
110
+
111
  except Exception as e:
112
  print(f"Error in apply_lipsync_to_video: {e}")
113
  traceback.print_exc()
musetalk.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """MuseTalk V1.5 integration module"""
2
+
3
+ import os
4
+ import numpy as np
5
+ import torch
6
+ import cv2
7
+ import copy
8
+ import math
9
+ import subprocess
10
+ from tqdm import tqdm
11
+ from glob import glob
12
+
13
+ from transformers import WhisperModel
14
+
15
+ from musetalk_integration.models.unet import UNet, PositionalEncoding
16
+ from musetalk_integration.models.vae import VAE
17
+ from musetalk_integration.utils.audio_processor import AudioProcessor
18
+ from musetalk_integration.utils.face_parsing import FaceParsing
19
+ from musetalk_integration.utils.blending import get_image
20
+ from musetalk_integration.utils.preprocessing import (
21
+ get_landmark_and_bbox,
22
+ read_imgs,
23
+ coord_placeholder,
24
+ )
25
+ from musetalk_integration.utils.utils import datagen, get_video_fps
26
+
27
+
28
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
+
30
+
31
+ def load_musetalk_models():
32
+ """Load MuseTalk V1.5 models"""
33
+ print("Loading MuseTalk V1.5 models...")
34
+
35
+ vae = VAE(model_path="./checkpoints/sd-vae-ft-mse")
36
+ print("✓ VAE loaded")
37
+
38
+ unet = UNet(
39
+ unet_config="./checkpoints/musetalkV15/musetalk.json",
40
+ model_path="./checkpoints/musetalkV15/unet.pth",
41
+ device=device,
42
+ )
43
+ print("✓ UNet loaded")
44
+
45
+ pe = PositionalEncoding(d_model=384)
46
+ print("✓ Positional encoding loaded")
47
+
48
+ audio_processor = AudioProcessor(
49
+ feature_extractor_path="./checkpoints/whisper-tiny"
50
+ )
51
+ print("✓ Audio processor loaded")
52
+
53
+ whisper = WhisperModel.from_pretrained("./checkpoints/whisper-tiny")
54
+ whisper = whisper.to(device=device, dtype=torch.float16).eval()
55
+ whisper.requires_grad_(False)
56
+ print("✓ Whisper model loaded")
57
+
58
+ fp = FaceParsing(left_cheek_width=90, right_cheek_width=90)
59
+ print("✓ Face parser loaded")
60
+
61
+ timesteps = torch.tensor([0], device=device)
62
+
63
+ return vae, unet, pe, audio_processor, whisper, fp, timesteps
64
+
65
+
66
+ vae, unet, pe, audio_processor, whisper, fp, timesteps = load_musetalk_models()
67
+
68
+
69
+ @torch.no_grad()
70
+ def apply_musetalk_lipsync(
71
+ video_path: str, audio_path: str, video_out_path: str, progress=None
72
+ ):
73
+ """Apply MuseTalk V1.5 lipsync
74
+
75
+ Args:
76
+ video_path: Path to input video
77
+ audio_path: Path to input audio
78
+ video_out_path: Path to output video
79
+ progress: Progress object
80
+ """
81
+ print(f"\n{'=' * 60}")
82
+ print(f"MUSETALK V1.5 START")
83
+ print(f"Video: {video_path}")
84
+ print(f"Audio: {audio_path}")
85
+ print(f"Output: {video_out_path}")
86
+ print(f"{'=' * 60}\n")
87
+
88
+ output_dir = os.path.dirname(video_out_path)
89
+
90
+ # 1. Extract frames
91
+ input_basename = os.path.basename(video_path).split(".")[0]
92
+ save_dir_full = os.path.join(output_dir, f"{input_basename}_frames")
93
+ os.makedirs(save_dir_full, exist_ok=True)
94
+
95
+ cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
96
+ os.system(cmd)
97
+
98
+ input_img_list = sorted(glob(os.path.join(save_dir_full, "*.[jpJP][pnPN]*[gG]")))
99
+ fps = get_video_fps(video_path)
100
+ print(f"Extracted {len(input_img_list)} frames at {fps} fps")
101
+
102
+ # 2. Extract audio features
103
+ print("Extracting audio features...")
104
+ whisper_input_features, librosa_length = audio_processor.get_audio_feature(
105
+ audio_path
106
+ )
107
+ whisper_chunks = audio_processor.get_whisper_chunk(
108
+ whisper_input_features,
109
+ device,
110
+ torch.float16,
111
+ whisper,
112
+ librosa_length,
113
+ fps=fps,
114
+ audio_padding_length_left=2,
115
+ audio_padding_length_right=2,
116
+ )
117
+ print(f"Generated {len(whisper_chunks)} audio chunks")
118
+
119
+ # 3. Detect face landmarks
120
+ print("Extracting landmarks...")
121
+ coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift=0)
122
+ print(f"Detected {len(coord_list)} face landmarks")
123
+
124
+ # 4. VAE encode
125
+ print("Encoding frames to latents...")
126
+ input_latent_list = []
127
+ for bbox, frame in zip(coord_list, frame_list):
128
+ if bbox == coord_placeholder:
129
+ continue
130
+ x1, y1, x2, y2 = bbox
131
+ y2 = y2 + 10
132
+ y2 = min(y2, frame.shape[0])
133
+ crop_frame = frame[y1:y2, x1:x2]
134
+ crop_frame = cv2.resize(
135
+ crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4
136
+ )
137
+ latents = vae.get_latents_for_unet(crop_frame)
138
+ input_latent_list.append(latents)
139
+ print(f"Encoded {len(input_latent_list)} frames")
140
+
141
+ # 5. Cycle frames for smoothing
142
+ frame_list_cycle = frame_list + frame_list[::-1]
143
+ coord_list_cycle = coord_list + coord_list[::-1]
144
+ input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
145
+
146
+ # 6. Batch inference
147
+ print("Starting inference...")
148
+ batch_size = 8
149
+ gen = datagen(
150
+ whisper_chunks=whisper_chunks,
151
+ vae_encode_latents=input_latent_list_cycle,
152
+ batch_size=batch_size,
153
+ delay_frame=0,
154
+ device=device,
155
+ )
156
+
157
+ res_frame_list = []
158
+ for whisper_batch, latent_batch in tqdm(
159
+ gen, total=int(math.ceil(len(whisper_chunks) / batch_size))
160
+ ):
161
+ audio_feature_batch = pe(whisper_batch)
162
+ latent_batch = latent_batch.to(dtype=torch.float16)
163
+
164
+ pred_latents = unet.model(
165
+ latent_batch, timesteps, encoder_hidden_states=audio_feature_batch
166
+ ).sample
167
+ recon = vae.decode_latents(pred_latents)
168
+ for res_frame in recon:
169
+ res_frame_list.append(res_frame)
170
+ print(f"Generated {len(res_frame_list)} frames")
171
+
172
+ # 7. Blend back to original video
173
+ print("Blending...")
174
+ output_frames_dir = os.path.join(output_dir, f"{input_basename}_output")
175
+ os.makedirs(output_frames_dir, exist_ok=True)
176
+
177
+ for i, res_frame in enumerate(tqdm(res_frame_list)):
178
+ bbox = coord_list_cycle[i % len(coord_list_cycle)]
179
+ ori_frame = copy.deepcopy(frame_list_cycle[i % len(frame_list_cycle)])
180
+ x1, y1, x2, y2 = bbox
181
+ y2 = y2 + 10
182
+ y2 = min(y2, ori_frame.shape[0])
183
+
184
+ try:
185
+ res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
186
+ except Exception as e:
187
+ print(f"Warning: Could not resize frame {i}: {e}")
188
+ continue
189
+
190
+ # MuseTalk v1.5 blending with jaw mode (default)
191
+ combine_frame = get_image(
192
+ ori_frame, res_frame, [x1, y1, x2, y2], mode="jaw", fp=fp
193
+ )
194
+ cv2.imwrite(f"{output_frames_dir}/{i:08d}.png", combine_frame)
195
+
196
+ # 8. Encode to video
197
+ print("Encoding video...")
198
+ cmd = f"ffmpeg -y -v warning -r {fps} -f image2 -i {output_frames_dir}/%08d.png -vcodec libx264 -vf format=yuv420p -crf 18 {video_out_path}"
199
+ os.system(cmd)
200
+
201
+ # 9. Add audio
202
+ print("Adding audio...")
203
+ cmd = f"ffmpeg -y -v warning -i {audio_path} -i {video_out_path} -c:v copy -c:a aac {video_out_path.replace('.mp4', '_final.mp4')}"
204
+ os.system(cmd)
205
+ os.replace(video_out_path.replace(".mp4", "_final.mp4"), video_out_path)
206
+
207
+ # 10. Cleanup
208
+ print("Cleaning up...")
209
+ import shutil
210
+
211
+ shutil.rmtree(save_dir_full)
212
+ shutil.rmtree(output_frames_dir)
213
+
214
+ print(f"\n{'=' * 60}")
215
+ print(f"MUSETALK V1.5 SUCCESS")
216
+ print(f"Output: {video_out_path}")
217
+ print(f"{'=' * 60}\n")
218
+
219
+ return video_out_path
musetalk.py.bak ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """MuseTalk V1.5 integration module"""
2
+
3
+ import os
4
+ import numpy as np
5
+ import torch
6
+ import cv2
7
+ import copy
8
+ import math
9
+ import subprocess
10
+ from tqdm import tqdm
11
+ from glob import glob
12
+
13
+ from transformers import WhisperModel
14
+
15
+ from musetalk_integration.models.unet import UNet, PositionalEncoding
16
+ from musetalk_integration.models.vae import VAE
17
+ from musetalk_integration.utils.audio_processor import AudioProcessor
18
+ from musetalk_integration.utils.face_parsing import FaceParsing
19
+ from musetalk_integration.utils.blending import get_image
20
+ from musetalk_integration.utils.preprocessing import (
21
+ get_landmark_and_bbox,
22
+ read_imgs,
23
+ coord_placeholder,
24
+ )
25
+ from musetalk_integration.utils.utils import datagen, get_video_fps
26
+
27
+
28
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
+
30
+
31
+ def load_musetalk_models():
32
+ """Load MuseTalk V1.5 models"""
33
+ print("Loading MuseTalk V1.5 models...")
34
+
35
+ vae = VAE(model_path="./checkpoints/sd-vae-ft-mse")
36
+ print("✓ VAE loaded")
37
+
38
+ unet = UNet(
39
+ unet_config="./checkpoints/musetalkV15/musetalk.json",
40
+ model_path="./checkpoints/musetalkV15/unet.pth",
41
+ device=device,
42
+ )
43
+ print("✓ UNet loaded")
44
+
45
+ pe = PositionalEncoding(d_model=384)
46
+ print("✓ Positional encoding loaded")
47
+
48
+ audio_processor = AudioProcessor(
49
+ feature_extractor_path="./checkpoints/whisper-tiny"
50
+ )
51
+ print("✓ Audio processor loaded")
52
+
53
+ whisper = WhisperModel.from_pretrained("./checkpoints/whisper-tiny")
54
+ whisper = whisper.to(device=device, dtype=torch.float16).eval()
55
+ whisper.requires_grad_(False)
56
+ print("✓ Whisper model loaded")
57
+
58
+ fp = FaceParsing(left_cheek_width=90, right_cheek_width=90)
59
+ print("✓ Face parser loaded")
60
+
61
+ timesteps = torch.tensor([0], device=device)
62
+
63
+ return vae, unet, pe, audio_processor, whisper, fp, timesteps
64
+
65
+
66
+ vae, unet, pe, audio_processor, whisper, fp, timesteps = load_musetalk_models()
67
+
68
+
69
+ @torch.no_grad()
70
+ def apply_musetalk_lipsync(
71
+ video_path: str, audio_path: str, video_out_path: str, progress=None
72
+ ):
73
+ """Apply MuseTalk V1.5 lipsync
74
+
75
+ Args:
76
+ video_path: Path to input video
77
+ audio_path: Path to input audio
78
+ video_out_path: Path to output video
79
+ progress: Progress object
80
+ """
81
+ print(f"\n{'=' * 60}")
82
+ print(f"MUSETALK V1.5 START")
83
+ print(f"Video: {video_path}")
84
+ print(f"Audio: {audio_path}")
85
+ print(f"Output: {video_out_path}")
86
+ print(f"{'=' * 60}\n")
87
+
88
+ output_dir = os.path.dirname(video_out_path)
89
+
90
+ # 1. Extract frames
91
+ input_basename = os.path.basename(video_path).split(".")[0]
92
+ save_dir_full = os.path.join(output_dir, f"{input_basename}_frames")
93
+ os.makedirs(save_dir_full, exist_ok=True)
94
+
95
+ cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
96
+ os.system(cmd)
97
+
98
+ input_img_list = sorted(glob(os.path.join(save_dir_full, "*.[jpJP][pnPN]*[gG]")))
99
+ fps = get_video_fps(video_path)
100
+ print(f"Extracted {len(input_img_list)} frames at {fps} fps")
101
+
102
+ # 2. Extract audio features
103
+ print("Extracting audio features...")
104
+ whisper_input_features, librosa_length = audio_processor.get_audio_feature(
105
+ audio_path
106
+ )
107
+ whisper_chunks = audio_processor.get_whisper_chunk(
108
+ whisper_input_features,
109
+ device,
110
+ torch.float16,
111
+ whisper,
112
+ librosa_length,
113
+ fps=fps,
114
+ audio_padding_length_left=2,
115
+ audio_padding_length_right=2,
116
+ )
117
+ print(f"Generated {len(whisper_chunks)} audio chunks")
118
+
119
+ # 3. Detect face landmarks
120
+ print("Extracting landmarks...")
121
+ coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift=0)
122
+ print(f"Detected {len(coord_list)} face landmarks")
123
+
124
+ # 4. VAE encode
125
+ print("Encoding frames to latents...")
126
+ input_latent_list = []
127
+ for bbox, frame in zip(coord_list, frame_list):
128
+ if bbox == coord_placeholder:
129
+ continue
130
+ x1, y1, x2, y2 = bbox
131
+ y2 = y2 + 10
132
+ y2 = min(y2, frame.shape[0])
133
+ crop_frame = frame[y1:y2, x1:x2]
134
+ crop_frame = cv2.resize(
135
+ crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4
136
+ )
137
+ latents = vae.get_latents_for_unet(crop_frame)
138
+ input_latent_list.append(latents)
139
+ print(f"Encoded {len(input_latent_list)} frames")
140
+
141
+ # 5. Cycle frames for smoothing
142
+ frame_list_cycle = frame_list + frame_list[::-1]
143
+ coord_list_cycle = coord_list + coord_list[::-1]
144
+ input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
145
+
146
+ # 6. Batch inference
147
+ print("Starting inference...")
148
+ batch_size = 8
149
+ gen = datagen(
150
+ whisper_chunks=whisper_chunks,
151
+ vae_encode_latents=input_latent_list_cycle,
152
+ batch_size=batch_size,
153
+ delay_frame=0,
154
+ device=device,
155
+ )
156
+
157
+ res_frame_list = []
158
+ for whisper_batch, latent_batch in tqdm(
159
+ gen, total=int(math.ceil(len(whisper_chunks) / batch_size))
160
+ ):
161
+ audio_feature_batch = pe(whisper_batch)
162
+ latent_batch = latent_batch.to(dtype=torch.float16)
163
+
164
+ pred_latents = unet.model(
165
+ latent_batch, timesteps, encoder_hidden_states=audio_feature_batch
166
+ ).sample
167
+ recon = vae.decode_latents(pred_latents)
168
+ for res_frame in recon:
169
+ res_frame_list.append(res_frame)
170
+ print(f"Generated {len(res_frame_list)} frames")
171
+
172
+ # 7. Blend back to original video
173
+ print("Blending...")
174
+ output_frames_dir = os.path.join(output_dir, f"{input_basename}_output")
175
+ os.makedirs(output_frames_dir, exist_ok=True)
176
+
177
+ for i, res_frame in enumerate(tqdm(res_frame_list)):
178
+ bbox = coord_list_cycle[i % len(coord_list_cycle)]
179
+ ori_frame = copy.deepcopy(frame_list_cycle[i % len(frame_list_cycle)])
180
+ x1, y1, x2, y2 = bbox
181
+ y2 = y2 + 10
182
+ y2 = min(y2, ori_frame.shape[0])
183
+
184
+ try:
185
+ res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
186
+ except Exception as e:
187
+ print(f"Warning: Could not resize frame {i}: {e}")
188
+ continue
189
+
190
+ # MuseTalk v1.5 blending with jaw mode (default)
191
+ combine_frame = get_image(
192
+ ori_frame, res_frame, [x1, y1, x2, y2], mode="jaw", fp=fp
193
+ )
194
+ cv2.imwrite(f"{output_frames_dir}/{i:08d}.png", combine_frame)
195
+
196
+ # 8. Encode to video
197
+ print("Encoding video...")
198
+ cmd = f"ffmpeg -y -v warning -r {fps} -f image2 -i {output_frames_dir}/%08d.png -vcodec libx264 -vf format=yuv420p -crf 18 {video_out_path}"
199
+ os.system(cmd)
200
+
201
+ # 9. Add audio
202
+ print("Adding audio...")
203
+ cmd = f"ffmpeg -y -v warning -i {audio_path} -i {video_out_path} -c:v copy -c:a aac {video_out_path.replace('.mp4', '_final.mp4')}"
204
+ os.system(cmd)
205
+ os.replace(video_out_path.replace(".mp4", "_final.mp4"), video_out_path)
206
+
207
+ # 10. Cleanup
208
+ print("Cleaning up...")
209
+ import shutil
210
+
211
+ shutil.rmtree(save_dir_full)
212
+ shutil.rmtree(output_frames_dir)
213
+
214
+ print(f"\n{'=' * 60}")
215
+ print(f"MUSETALK V1.5 SUCCESS")
216
+ print(f"Output: {video_out_path}")
217
+ print(f"{'=' * 60}\n")
218
+
219
+ return video_out_path
musetalk_integration/__init__.py ADDED
File without changes
musetalk_integration/models/__init__.py ADDED
File without changes
musetalk_integration/models/unet.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+ import json
5
+
6
+ from diffusers import UNet2DConditionModel
7
+ import sys
8
+ import time
9
+ import numpy as np
10
+ import os
11
+
12
+ class PositionalEncoding(nn.Module):
13
+ def __init__(self, d_model=384, max_len=5000):
14
+ super(PositionalEncoding, self).__init__()
15
+ pe = torch.zeros(max_len, d_model)
16
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
17
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
18
+ pe[:, 0::2] = torch.sin(position * div_term)
19
+ pe[:, 1::2] = torch.cos(position * div_term)
20
+ pe = pe.unsqueeze(0)
21
+ self.register_buffer('pe', pe)
22
+
23
+ def forward(self, x):
24
+ b, seq_len, d_model = x.size()
25
+ pe = self.pe[:, :seq_len, :]
26
+ x = x + pe.to(x.device)
27
+ return x
28
+
29
+ class UNet():
30
+ def __init__(self,
31
+ unet_config,
32
+ model_path,
33
+ use_float16=False,
34
+ device=None
35
+ ):
36
+ with open(unet_config, 'r') as f:
37
+ unet_config = json.load(f)
38
+ self.model = UNet2DConditionModel(**unet_config)
39
+ self.pe = PositionalEncoding(d_model=384)
40
+ if device != None:
41
+ self.device = device
42
+ else:
43
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
44
+ weights = torch.load(model_path) if torch.cuda.is_available() else torch.load(model_path, map_location=self.device)
45
+ self.model.load_state_dict(weights)
46
+ if use_float16:
47
+ self.model = self.model.half()
48
+ self.model.to(self.device)
49
+
50
+ if __name__ == "__main__":
51
+ unet = UNet()
musetalk_integration/models/vae.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import AutoencoderKL
2
+ import torch
3
+ import torchvision.transforms as transforms
4
+ import torch.nn.functional as F
5
+ import cv2
6
+ import numpy as np
7
+ from PIL import Image
8
+ import os
9
+
10
+ class VAE():
11
+ """
12
+ VAE (Variational Autoencoder) class for image processing.
13
+ """
14
+
15
+ def __init__(self, model_path="./models/sd-vae-ft-mse/", resized_img=256, use_float16=False):
16
+ """
17
+ Initialize the VAE instance.
18
+
19
+ :param model_path: Path to the trained model.
20
+ :param resized_img: The size to which images are resized.
21
+ :param use_float16: Whether to use float16 precision.
22
+ """
23
+ self.model_path = model_path
24
+ self.vae = AutoencoderKL.from_pretrained(self.model_path)
25
+
26
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
+ self.vae.to(self.device)
28
+
29
+ if use_float16:
30
+ self.vae = self.vae.half()
31
+ self._use_float16 = True
32
+ else:
33
+ self._use_float16 = False
34
+
35
+ self.scaling_factor = self.vae.config.scaling_factor
36
+ self.transform = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
37
+ self._resized_img = resized_img
38
+ self._mask_tensor = self.get_mask_tensor()
39
+
40
+ def get_mask_tensor(self):
41
+ """
42
+ Creates a mask tensor for image processing.
43
+ :return: A mask tensor.
44
+ """
45
+ mask_tensor = torch.zeros((self._resized_img,self._resized_img))
46
+ mask_tensor[:self._resized_img//2,:] = 1
47
+ mask_tensor[mask_tensor< 0.5] = 0
48
+ mask_tensor[mask_tensor>= 0.5] = 1
49
+ return mask_tensor
50
+
51
+ def preprocess_img(self,img_name,half_mask=False):
52
+ """
53
+ Preprocess an image for the VAE.
54
+
55
+ :param img_name: The image file path or a list of image file paths.
56
+ :param half_mask: Whether to apply a half mask to the image.
57
+ :return: A preprocessed image tensor.
58
+ """
59
+ window = []
60
+ if isinstance(img_name, str):
61
+ window_fnames = [img_name]
62
+ for fname in window_fnames:
63
+ img = cv2.imread(fname)
64
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
65
+ img = cv2.resize(img, (self._resized_img, self._resized_img),
66
+ interpolation=cv2.INTER_LANCZOS4)
67
+ window.append(img)
68
+ else:
69
+ img = cv2.cvtColor(img_name, cv2.COLOR_BGR2RGB)
70
+ window.append(img)
71
+
72
+ x = np.asarray(window) / 255.
73
+ x = np.transpose(x, (3, 0, 1, 2))
74
+ x = torch.squeeze(torch.FloatTensor(x))
75
+ if half_mask:
76
+ x = x * (self._mask_tensor>0.5)
77
+ x = self.transform(x)
78
+
79
+ x = x.unsqueeze(0) # [1, 3, 256, 256] torch tensor
80
+ x = x.to(self.vae.device)
81
+
82
+ return x
83
+
84
+ def encode_latents(self,image):
85
+ """
86
+ Encode an image into latent variables.
87
+
88
+ :param image: The image tensor to encode.
89
+ :return: The encoded latent variables.
90
+ """
91
+ with torch.no_grad():
92
+ init_latent_dist = self.vae.encode(image.to(self.vae.dtype)).latent_dist
93
+ init_latents = self.scaling_factor * init_latent_dist.sample()
94
+ return init_latents
95
+
96
+ def decode_latents(self, latents):
97
+ """
98
+ Decode latent variables back into an image.
99
+ :param latents: The latent variables to decode.
100
+ :return: A NumPy array representing the decoded image.
101
+ """
102
+ latents = (1/ self.scaling_factor) * latents
103
+ image = self.vae.decode(latents.to(self.vae.dtype)).sample
104
+ image = (image / 2 + 0.5).clamp(0, 1)
105
+ image = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
106
+ image = (image * 255).round().astype("uint8")
107
+ image = image[...,::-1] # RGB to BGR
108
+ return image
109
+
110
+ def get_latents_for_unet(self,img):
111
+ """
112
+ Prepare latent variables for a U-Net model.
113
+ :param img: The image to process.
114
+ :return: A concatenated tensor of latents for U-Net input.
115
+ """
116
+
117
+ ref_image = self.preprocess_img(img,half_mask=True) # [1, 3, 256, 256] RGB, torch tensor
118
+ masked_latents = self.encode_latents(ref_image) # [1, 4, 32, 32], torch tensor
119
+ ref_image = self.preprocess_img(img,half_mask=False) # [1, 3, 256, 256] RGB, torch tensor
120
+ ref_latents = self.encode_latents(ref_image) # [1, 4, 32, 32], torch tensor
121
+ latent_model_input = torch.cat([masked_latents, ref_latents], dim=1)
122
+ return latent_model_input
123
+
124
+ if __name__ == "__main__":
125
+ vae_mode_path = "./models/sd-vae-ft-mse/"
126
+ vae = VAE(model_path = vae_mode_path,use_float16=False)
127
+ img_path = "./results/sun001_crop/00000.png"
128
+
129
+ crop_imgs_path = "./results/sun001_crop/"
130
+ latents_out_path = "./results/latents/"
131
+ if not os.path.exists(latents_out_path):
132
+ os.mkdir(latents_out_path)
133
+
134
+ files = os.listdir(crop_imgs_path)
135
+ files.sort()
136
+ files = [file for file in files if file.split(".")[-1] == "png"]
137
+
138
+ for file in files:
139
+ index = file.split(".")[0]
140
+ img_path = crop_imgs_path + file
141
+ latents = vae.get_latents_for_unet(img_path)
142
+ print(img_path,"latents",latents.size())
143
+ #torch.save(latents,os.path.join(latents_out_path,index+".pt"))
144
+ #reload_tensor = torch.load('tensor.pt')
145
+ #print(reload_tensor.size())
146
+
147
+
148
+
musetalk_integration/utils/__init__.py ADDED
File without changes
musetalk_integration/utils/audio_processor.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+
4
+ import librosa
5
+ import numpy as np
6
+ import torch
7
+ from einops import rearrange
8
+ from transformers import AutoFeatureExtractor
9
+
10
+
11
+ class AudioProcessor:
12
+ def __init__(self, feature_extractor_path="openai/whisper-tiny/"):
13
+ self.feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_path)
14
+
15
+ def get_audio_feature(self, wav_path, start_index=0, weight_dtype=None):
16
+ if not os.path.exists(wav_path):
17
+ return None
18
+ librosa_output, sampling_rate = librosa.load(wav_path, sr=16000)
19
+ assert sampling_rate == 16000
20
+ # Split audio into 30s segments
21
+ segment_length = 30 * sampling_rate
22
+ segments = [librosa_output[i:i + segment_length] for i in range(0, len(librosa_output), segment_length)]
23
+
24
+ features = []
25
+ for segment in segments:
26
+ audio_feature = self.feature_extractor(
27
+ segment,
28
+ return_tensors="pt",
29
+ sampling_rate=sampling_rate
30
+ ).input_features
31
+ if weight_dtype is not None:
32
+ audio_feature = audio_feature.to(dtype=weight_dtype)
33
+ features.append(audio_feature)
34
+
35
+ return features, len(librosa_output)
36
+
37
+ def get_whisper_chunk(
38
+ self,
39
+ whisper_input_features,
40
+ device,
41
+ weight_dtype,
42
+ whisper,
43
+ librosa_length,
44
+ fps=25,
45
+ audio_padding_length_left=2,
46
+ audio_padding_length_right=2,
47
+ ):
48
+ audio_feature_length_per_frame = 2 * (audio_padding_length_left + audio_padding_length_right + 1)
49
+ whisper_feature = []
50
+ # Process multiple 30s mel input features
51
+ for input_feature in whisper_input_features:
52
+ input_feature = input_feature.to(device).to(weight_dtype)
53
+ audio_feats = whisper.encoder(input_feature, output_hidden_states=True).hidden_states
54
+ audio_feats = torch.stack(audio_feats, dim=2)
55
+ whisper_feature.append(audio_feats)
56
+
57
+ whisper_feature = torch.cat(whisper_feature, dim=1)
58
+ # Trim the last segment to remove padding
59
+ sr = 16000
60
+ audio_fps = 50
61
+ fps = int(fps)
62
+ whisper_idx_multiplier = audio_fps / fps
63
+ num_frames = math.floor((librosa_length / sr) * fps)
64
+ actual_length = math.floor((librosa_length / sr) * audio_fps)
65
+ whisper_feature = whisper_feature[:,:actual_length,...]
66
+
67
+ # Calculate padding amount
68
+ padding_nums = math.ceil(whisper_idx_multiplier)
69
+ # Add padding at start and end
70
+ whisper_feature = torch.cat([
71
+ torch.zeros_like(whisper_feature[:, :padding_nums * audio_padding_length_left]),
72
+ whisper_feature,
73
+ # Add extra padding to prevent out of bounds
74
+ torch.zeros_like(whisper_feature[:, :padding_nums * 3 * audio_padding_length_right])
75
+ ], 1)
76
+
77
+ audio_prompts = []
78
+ for frame_index in range(num_frames):
79
+ try:
80
+ audio_index = math.floor(frame_index * whisper_idx_multiplier)
81
+ audio_clip = whisper_feature[:, audio_index: audio_index + audio_feature_length_per_frame]
82
+ assert audio_clip.shape[1] == audio_feature_length_per_frame
83
+ audio_prompts.append(audio_clip)
84
+ except Exception as e:
85
+ print(f"Error occurred: {e}")
86
+ print(f"whisper_feature.shape: {whisper_feature.shape}")
87
+ print(f"audio_clip.shape: {audio_clip.shape}")
88
+ print(f"num frames: {num_frames}, fps: {fps}, whisper_idx_multiplier: {whisper_idx_multiplier}")
89
+ print(f"frame_index: {frame_index}, audio_index: {audio_index}-{audio_index + audio_feature_length_per_frame}")
90
+ exit()
91
+
92
+ audio_prompts = torch.cat(audio_prompts, dim=0) # T, 10, 5, 384
93
+ audio_prompts = rearrange(audio_prompts, 'b c h w -> b (c h) w')
94
+ return audio_prompts
95
+
96
+ if __name__ == "__main__":
97
+ audio_processor = AudioProcessor()
98
+ wav_path = "./2.wav"
99
+ audio_feature, librosa_feature_length = audio_processor.get_audio_feature(wav_path)
100
+ print("Audio Feature shape:", audio_feature.shape)
101
+ print("librosa_feature_length:", librosa_feature_length)
102
+
musetalk_integration/utils/blending.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import numpy as np
3
+ import cv2
4
+ import copy
5
+
6
+
7
+ def get_crop_box(box, expand):
8
+ x, y, x1, y1 = box
9
+ x_c, y_c = (x+x1)//2, (y+y1)//2
10
+ w, h = x1-x, y1-y
11
+ s = int(max(w, h)//2*expand)
12
+ crop_box = [x_c-s, y_c-s, x_c+s, y_c+s]
13
+ return crop_box, s
14
+
15
+
16
+ def face_seg(image, mode="raw", fp=None):
17
+ """
18
+ 对图像进行面部解析,生成面部区域的掩码。
19
+
20
+ Args:
21
+ image (PIL.Image): 输入图像。
22
+
23
+ Returns:
24
+ PIL.Image: 面部区域的掩码图像。
25
+ """
26
+ seg_image = fp(image, mode=mode) # 使用 FaceParsing 模型解析面部
27
+ if seg_image is None:
28
+ print("error, no person_segment") # 如果没有检测到面部,返回错误
29
+ return None
30
+
31
+ seg_image = seg_image.resize(image.size) # 将掩码图像调整为输入图像的大小
32
+ return seg_image
33
+
34
+
35
+ def get_image(image, face, face_box, upper_boundary_ratio=0.5, expand=1.5, mode="raw", fp=None):
36
+ """
37
+ 将裁剪的面部图像粘贴回原始图像,并进行一些处理。
38
+
39
+ Args:
40
+ image (numpy.ndarray): 原始图像(身体部分)。
41
+ face (numpy.ndarray): 裁剪的面部图像。
42
+ face_box (tuple): 面部边界框的坐标 (x, y, x1, y1)。
43
+ upper_boundary_ratio (float): 用于控制面部区域的保留比例。
44
+ expand (float): 扩展因子,用于放大裁剪框。
45
+ mode: 融合mask构建方式
46
+
47
+ Returns:
48
+ numpy.ndarray: 处理后的图像。
49
+ """
50
+ # 将 numpy 数组转换为 PIL 图像
51
+ body = Image.fromarray(image[:, :, ::-1]) # 身体部分图像(整张图)
52
+ face = Image.fromarray(face[:, :, ::-1]) # 面部图像
53
+
54
+ x, y, x1, y1 = face_box # 获取面部边界框的坐标
55
+ crop_box, s = get_crop_box(face_box, expand) # 计算扩展后的裁剪框
56
+ x_s, y_s, x_e, y_e = crop_box # 裁剪框的坐标
57
+ face_position = (x, y) # 面部在原始图像中的位置
58
+
59
+ # 从身体图像中裁剪出扩展后的面部区域(下巴到边界有距离)
60
+ face_large = body.crop(crop_box)
61
+
62
+ ori_shape = face_large.size # 裁剪后图像的原始尺寸
63
+
64
+ # 对裁剪后的面部区域进行面部解析,生成掩码
65
+ mask_image = face_seg(face_large, mode=mode, fp=fp)
66
+
67
+ mask_small = mask_image.crop((x - x_s, y - y_s, x1 - x_s, y1 - y_s)) # 裁剪出面部区域的掩码
68
+
69
+ mask_image = Image.new('L', ori_shape, 0) # 创建一个全黑的掩码图像
70
+ mask_image.paste(mask_small, (x - x_s, y - y_s, x1 - x_s, y1 - y_s)) # 将面部掩码粘贴到全黑图像上
71
+
72
+
73
+ # 保留面部区域的上半部分(用于控制说话区域)
74
+ width, height = mask_image.size
75
+ top_boundary = int(height * upper_boundary_ratio) # 计算上半部分的边界
76
+ modified_mask_image = Image.new('L', ori_shape, 0) # 创建一个新的全黑掩码图像
77
+ modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary)) # 粘贴上半部分掩码
78
+
79
+
80
+ # 对掩码进行高斯模糊,使边缘更平滑
81
+ blur_kernel_size = int(0.05 * ori_shape[0] // 2 * 2) + 1 # 计算模糊核大小
82
+ mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0) # 高斯模糊
83
+ #mask_array = np.array(modified_mask_image)
84
+ mask_image = Image.fromarray(mask_array) # 将模糊后的掩码转换回 PIL 图像
85
+
86
+ # 将裁剪的面部图像粘贴回扩展后的面部区域
87
+ face_large.paste(face, (x - x_s, y - y_s, x1 - x_s, y1 - y_s))
88
+
89
+ body.paste(face_large, crop_box[:2], mask_image)
90
+
91
+ body = np.array(body) # 将 PIL 图像转换回 numpy 数组
92
+
93
+ return body[:, :, ::-1] # 返回处理后的图像(BGR 转 RGB)
94
+
95
+
96
+ def get_image_blending(image, face, face_box, mask_array, crop_box):
97
+ body = Image.fromarray(image[:,:,::-1])
98
+ face = Image.fromarray(face[:,:,::-1])
99
+
100
+ x, y, x1, y1 = face_box
101
+ x_s, y_s, x_e, y_e = crop_box
102
+ face_large = body.crop(crop_box)
103
+
104
+ mask_image = Image.fromarray(mask_array)
105
+ mask_image = mask_image.convert("L")
106
+ face_large.paste(face, (x-x_s, y-y_s, x1-x_s, y1-y_s))
107
+ body.paste(face_large, crop_box[:2], mask_image)
108
+ body = np.array(body)
109
+ return body[:,:,::-1]
110
+
111
+
112
+ def get_image_prepare_material(image, face_box, upper_boundary_ratio=0.5, expand=1.5, fp=None, mode="raw"):
113
+ body = Image.fromarray(image[:,:,::-1])
114
+
115
+ x, y, x1, y1 = face_box
116
+ #print(x1-x,y1-y)
117
+ crop_box, s = get_crop_box(face_box, expand)
118
+ x_s, y_s, x_e, y_e = crop_box
119
+
120
+ face_large = body.crop(crop_box)
121
+ ori_shape = face_large.size
122
+
123
+ mask_image = face_seg(face_large, mode=mode, fp=fp)
124
+ mask_small = mask_image.crop((x-x_s, y-y_s, x1-x_s, y1-y_s))
125
+ mask_image = Image.new('L', ori_shape, 0)
126
+ mask_image.paste(mask_small, (x-x_s, y-y_s, x1-x_s, y1-y_s))
127
+
128
+ # keep upper_boundary_ratio of talking area
129
+ width, height = mask_image.size
130
+ top_boundary = int(height * upper_boundary_ratio)
131
+ modified_mask_image = Image.new('L', ori_shape, 0)
132
+ modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary))
133
+
134
+ blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
135
+ mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
136
+ return mask_array, crop_box
musetalk_integration/utils/dwpose/__init__.py ADDED
File without changes
musetalk_integration/utils/dwpose/default_runtime.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ default_scope = 'mmpose'
2
+
3
+ # hooks
4
+ default_hooks = dict(
5
+ timer=dict(type='IterTimerHook'),
6
+ logger=dict(type='LoggerHook', interval=50),
7
+ param_scheduler=dict(type='ParamSchedulerHook'),
8
+ checkpoint=dict(type='CheckpointHook', interval=10),
9
+ sampler_seed=dict(type='DistSamplerSeedHook'),
10
+ visualization=dict(type='PoseVisualizationHook', enable=False),
11
+ badcase=dict(
12
+ type='BadCaseAnalysisHook',
13
+ enable=False,
14
+ out_dir='badcase',
15
+ metric_type='loss',
16
+ badcase_thr=5))
17
+
18
+ # custom hooks
19
+ custom_hooks = [
20
+ # Synchronize model buffers such as running_mean and running_var in BN
21
+ # at the end of each epoch
22
+ dict(type='SyncBuffersHook')
23
+ ]
24
+
25
+ # multi-processing backend
26
+ env_cfg = dict(
27
+ cudnn_benchmark=False,
28
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
29
+ dist_cfg=dict(backend='nccl'),
30
+ )
31
+
32
+ # visualizer
33
+ vis_backends = [
34
+ dict(type='LocalVisBackend'),
35
+ # dict(type='TensorboardVisBackend'),
36
+ # dict(type='WandbVisBackend'),
37
+ ]
38
+ visualizer = dict(
39
+ type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
40
+
41
+ # logger
42
+ log_processor = dict(
43
+ type='LogProcessor', window_size=50, by_epoch=True, num_digits=6)
44
+ log_level = 'INFO'
45
+ load_from = None
46
+ resume = False
47
+
48
+ # file I/O backend
49
+ backend_args = dict(backend='local')
50
+
51
+ # training/validation/testing progress
52
+ train_cfg = dict(by_epoch=True)
53
+ val_cfg = dict()
54
+ test_cfg = dict()
musetalk_integration/utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #_base_ = ['../../../_base_/default_runtime.py']
2
+ _base_ = ['default_runtime.py']
3
+
4
+ # runtime
5
+ max_epochs = 270
6
+ stage2_num_epochs = 30
7
+ base_lr = 4e-3
8
+ train_batch_size = 32
9
+ val_batch_size = 32
10
+
11
+ train_cfg = dict(max_epochs=max_epochs, val_interval=10)
12
+ randomness = dict(seed=21)
13
+
14
+ # optimizer
15
+ optim_wrapper = dict(
16
+ type='OptimWrapper',
17
+ optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),
18
+ paramwise_cfg=dict(
19
+ norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))
20
+
21
+ # learning rate
22
+ param_scheduler = [
23
+ dict(
24
+ type='LinearLR',
25
+ start_factor=1.0e-5,
26
+ by_epoch=False,
27
+ begin=0,
28
+ end=1000),
29
+ dict(
30
+ # use cosine lr from 150 to 300 epoch
31
+ type='CosineAnnealingLR',
32
+ eta_min=base_lr * 0.05,
33
+ begin=max_epochs // 2,
34
+ end=max_epochs,
35
+ T_max=max_epochs // 2,
36
+ by_epoch=True,
37
+ convert_to_iter_based=True),
38
+ ]
39
+
40
+ # automatically scaling LR based on the actual training batch size
41
+ auto_scale_lr = dict(base_batch_size=512)
42
+
43
+ # codec settings
44
+ codec = dict(
45
+ type='SimCCLabel',
46
+ input_size=(288, 384),
47
+ sigma=(6., 6.93),
48
+ simcc_split_ratio=2.0,
49
+ normalize=False,
50
+ use_dark=False)
51
+
52
+ # model settings
53
+ model = dict(
54
+ type='TopdownPoseEstimator',
55
+ data_preprocessor=dict(
56
+ type='PoseDataPreprocessor',
57
+ mean=[123.675, 116.28, 103.53],
58
+ std=[58.395, 57.12, 57.375],
59
+ bgr_to_rgb=True),
60
+ backbone=dict(
61
+ _scope_='mmdet',
62
+ type='CSPNeXt',
63
+ arch='P5',
64
+ expand_ratio=0.5,
65
+ deepen_factor=1.,
66
+ widen_factor=1.,
67
+ out_indices=(4, ),
68
+ channel_attention=True,
69
+ norm_cfg=dict(type='SyncBN'),
70
+ act_cfg=dict(type='SiLU'),
71
+ init_cfg=dict(
72
+ type='Pretrained',
73
+ prefix='backbone.',
74
+ checkpoint='https://download.openmmlab.com/mmpose/v1/projects/'
75
+ 'rtmpose/cspnext-l_udp-aic-coco_210e-256x192-273b7631_20230130.pth' # noqa: E501
76
+ )),
77
+ head=dict(
78
+ type='RTMCCHead',
79
+ in_channels=1024,
80
+ out_channels=133,
81
+ input_size=codec['input_size'],
82
+ in_featuremap_size=(9, 12),
83
+ simcc_split_ratio=codec['simcc_split_ratio'],
84
+ final_layer_kernel_size=7,
85
+ gau_cfg=dict(
86
+ hidden_dims=256,
87
+ s=128,
88
+ expansion_factor=2,
89
+ dropout_rate=0.,
90
+ drop_path=0.,
91
+ act_fn='SiLU',
92
+ use_rel_bias=False,
93
+ pos_enc=False),
94
+ loss=dict(
95
+ type='KLDiscretLoss',
96
+ use_target_weight=True,
97
+ beta=10.,
98
+ label_softmax=True),
99
+ decoder=codec),
100
+ test_cfg=dict(flip_test=True, ))
101
+
102
+ # base dataset settings
103
+ dataset_type = 'UBody2dDataset'
104
+ data_mode = 'topdown'
105
+ data_root = 'data/UBody/'
106
+
107
+ backend_args = dict(backend='local')
108
+
109
+ scenes = [
110
+ 'Magic_show', 'Entertainment', 'ConductMusic', 'Online_class', 'TalkShow',
111
+ 'Speech', 'Fitness', 'Interview', 'Olympic', 'TVShow', 'Singing',
112
+ 'SignLanguage', 'Movie', 'LiveVlog', 'VideoConference'
113
+ ]
114
+
115
+ train_datasets = [
116
+ dict(
117
+ type='CocoWholeBodyDataset',
118
+ data_root='data/coco/',
119
+ data_mode=data_mode,
120
+ ann_file='annotations/coco_wholebody_train_v1.0.json',
121
+ data_prefix=dict(img='train2017/'),
122
+ pipeline=[])
123
+ ]
124
+
125
+ for scene in scenes:
126
+ train_dataset = dict(
127
+ type=dataset_type,
128
+ data_root=data_root,
129
+ data_mode=data_mode,
130
+ ann_file=f'annotations/{scene}/train_annotations.json',
131
+ data_prefix=dict(img='images/'),
132
+ pipeline=[],
133
+ sample_interval=10)
134
+ train_datasets.append(train_dataset)
135
+
136
+ # pipelines
137
+ train_pipeline = [
138
+ dict(type='LoadImage', backend_args=backend_args),
139
+ dict(type='GetBBoxCenterScale'),
140
+ dict(type='RandomFlip', direction='horizontal'),
141
+ dict(type='RandomHalfBody'),
142
+ dict(
143
+ type='RandomBBoxTransform', scale_factor=[0.5, 1.5], rotate_factor=90),
144
+ dict(type='TopdownAffine', input_size=codec['input_size']),
145
+ dict(type='mmdet.YOLOXHSVRandomAug'),
146
+ dict(
147
+ type='Albumentation',
148
+ transforms=[
149
+ dict(type='Blur', p=0.1),
150
+ dict(type='MedianBlur', p=0.1),
151
+ dict(
152
+ type='CoarseDropout',
153
+ max_holes=1,
154
+ max_height=0.4,
155
+ max_width=0.4,
156
+ min_holes=1,
157
+ min_height=0.2,
158
+ min_width=0.2,
159
+ p=1.0),
160
+ ]),
161
+ dict(type='GenerateTarget', encoder=codec),
162
+ dict(type='PackPoseInputs')
163
+ ]
164
+ val_pipeline = [
165
+ dict(type='LoadImage', backend_args=backend_args),
166
+ dict(type='GetBBoxCenterScale'),
167
+ dict(type='TopdownAffine', input_size=codec['input_size']),
168
+ dict(type='PackPoseInputs')
169
+ ]
170
+
171
+ train_pipeline_stage2 = [
172
+ dict(type='LoadImage', backend_args=backend_args),
173
+ dict(type='GetBBoxCenterScale'),
174
+ dict(type='RandomFlip', direction='horizontal'),
175
+ dict(type='RandomHalfBody'),
176
+ dict(
177
+ type='RandomBBoxTransform',
178
+ shift_factor=0.,
179
+ scale_factor=[0.5, 1.5],
180
+ rotate_factor=90),
181
+ dict(type='TopdownAffine', input_size=codec['input_size']),
182
+ dict(type='mmdet.YOLOXHSVRandomAug'),
183
+ dict(
184
+ type='Albumentation',
185
+ transforms=[
186
+ dict(type='Blur', p=0.1),
187
+ dict(type='MedianBlur', p=0.1),
188
+ dict(
189
+ type='CoarseDropout',
190
+ max_holes=1,
191
+ max_height=0.4,
192
+ max_width=0.4,
193
+ min_holes=1,
194
+ min_height=0.2,
195
+ min_width=0.2,
196
+ p=0.5),
197
+ ]),
198
+ dict(type='GenerateTarget', encoder=codec),
199
+ dict(type='PackPoseInputs')
200
+ ]
201
+
202
+ # data loaders
203
+ train_dataloader = dict(
204
+ batch_size=train_batch_size,
205
+ num_workers=10,
206
+ persistent_workers=True,
207
+ sampler=dict(type='DefaultSampler', shuffle=True),
208
+ dataset=dict(
209
+ type='CombinedDataset',
210
+ metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'),
211
+ datasets=train_datasets,
212
+ pipeline=train_pipeline,
213
+ test_mode=False,
214
+ ))
215
+
216
+ val_dataloader = dict(
217
+ batch_size=val_batch_size,
218
+ num_workers=10,
219
+ persistent_workers=True,
220
+ drop_last=False,
221
+ sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
222
+ dataset=dict(
223
+ type='CocoWholeBodyDataset',
224
+ data_root=data_root,
225
+ data_mode=data_mode,
226
+ ann_file='data/coco/annotations/coco_wholebody_val_v1.0.json',
227
+ bbox_file='data/coco/person_detection_results/'
228
+ 'COCO_val2017_detections_AP_H_56_person.json',
229
+ data_prefix=dict(img='coco/val2017/'),
230
+ test_mode=True,
231
+ pipeline=val_pipeline,
232
+ ))
233
+ test_dataloader = val_dataloader
234
+
235
+ # hooks
236
+ default_hooks = dict(
237
+ checkpoint=dict(
238
+ save_best='coco-wholebody/AP', rule='greater', max_keep_ckpts=1))
239
+
240
+ custom_hooks = [
241
+ dict(
242
+ type='EMAHook',
243
+ ema_type='ExpMomentumEMA',
244
+ momentum=0.0002,
245
+ update_buffers=True,
246
+ priority=49),
247
+ dict(
248
+ type='mmdet.PipelineSwitchHook',
249
+ switch_epoch=max_epochs - stage2_num_epochs,
250
+ switch_pipeline=train_pipeline_stage2)
251
+ ]
252
+
253
+ # evaluators
254
+ val_evaluator = dict(
255
+ type='CocoWholeBodyMetric',
256
+ ann_file='data/coco/annotations/coco_wholebody_val_v1.0.json')
257
+ test_evaluator = val_evaluator
musetalk_integration/utils/face_detection/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ The code for Face Detection in this folder has been taken from the wonderful [face_alignment](https://github.com/1adrianb/face-alignment) repository. This has been modified to take batches of faces at a time.
musetalk_integration/utils/face_detection/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ __author__ = """Adrian Bulat"""
4
+ __email__ = 'adrian.bulat@nottingham.ac.uk'
5
+ __version__ = '1.0.1'
6
+
7
+ from .api import FaceAlignment, LandmarksType, NetworkSize, YOLOv8_face
musetalk_integration/utils/face_detection/api.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ import os
3
+ import torch
4
+ from torch.utils.model_zoo import load_url
5
+ from enum import Enum
6
+ import numpy as np
7
+ import cv2
8
+ try:
9
+ import urllib.request as request_file
10
+ except BaseException:
11
+ import urllib as request_file
12
+
13
+ from .models import FAN, ResNetDepth
14
+ from .utils import *
15
+
16
+
17
+ class LandmarksType(Enum):
18
+ """Enum class defining the type of landmarks to detect.
19
+
20
+ ``_2D`` - the detected points ``(x,y)`` are detected in a 2D space and follow the visible contour of the face
21
+ ``_2halfD`` - this points represent the projection of the 3D points into 3D
22
+ ``_3D`` - detect the points ``(x,y,z)``` in a 3D space
23
+
24
+ """
25
+ _2D = 1
26
+ _2halfD = 2
27
+ _3D = 3
28
+
29
+
30
+ class NetworkSize(Enum):
31
+ # TINY = 1
32
+ # SMALL = 2
33
+ # MEDIUM = 3
34
+ LARGE = 4
35
+
36
+ def __new__(cls, value):
37
+ member = object.__new__(cls)
38
+ member._value_ = value
39
+ return member
40
+
41
+ def __int__(self):
42
+ return self.value
43
+
44
+
45
+
46
+ class FaceAlignment:
47
+ def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
48
+ device='cuda', flip_input=False, face_detector='sfd', verbose=False):
49
+ self.device = device
50
+ self.flip_input = flip_input
51
+ self.landmarks_type = landmarks_type
52
+ self.verbose = verbose
53
+
54
+ network_size = int(network_size)
55
+
56
+ if 'cuda' in device:
57
+ torch.backends.cudnn.benchmark = True
58
+ # torch.backends.cuda.matmul.allow_tf32 = False
59
+ # torch.backends.cudnn.benchmark = True
60
+ # torch.backends.cudnn.deterministic = False
61
+ # torch.backends.cudnn.allow_tf32 = True
62
+ print('cuda start')
63
+
64
+
65
+ # Get the face detector
66
+ face_detector_module = __import__('face_detection.detection.' + face_detector,
67
+ globals(), locals(), [face_detector], 0)
68
+
69
+ self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose)
70
+
71
+ def get_detections_for_batch(self, images):
72
+ images = images[..., ::-1]
73
+ detected_faces = self.face_detector.detect_from_batch(images.copy())
74
+ results = []
75
+
76
+ for i, d in enumerate(detected_faces):
77
+ if len(d) == 0:
78
+ results.append(None)
79
+ continue
80
+ d = d[0]
81
+ d = np.clip(d, 0, None)
82
+
83
+ x1, y1, x2, y2 = map(int, d[:-1])
84
+ results.append((x1, y1, x2, y2))
85
+
86
+ return results
87
+
88
+
89
+ class YOLOv8_face:
90
+ def __init__(self, path = 'face_detection/weights/yolov8n-face.onnx', conf_thres=0.2, iou_thres=0.5):
91
+ self.conf_threshold = conf_thres
92
+ self.iou_threshold = iou_thres
93
+ self.class_names = ['face']
94
+ self.num_classes = len(self.class_names)
95
+ # Initialize model
96
+ self.net = cv2.dnn.readNet(path)
97
+ self.input_height = 640
98
+ self.input_width = 640
99
+ self.reg_max = 16
100
+
101
+ self.project = np.arange(self.reg_max)
102
+ self.strides = (8, 16, 32)
103
+ self.feats_hw = [(math.ceil(self.input_height / self.strides[i]), math.ceil(self.input_width / self.strides[i])) for i in range(len(self.strides))]
104
+ self.anchors = self.make_anchors(self.feats_hw)
105
+
106
+ def make_anchors(self, feats_hw, grid_cell_offset=0.5):
107
+ """Generate anchors from features."""
108
+ anchor_points = {}
109
+ for i, stride in enumerate(self.strides):
110
+ h,w = feats_hw[i]
111
+ x = np.arange(0, w) + grid_cell_offset # shift x
112
+ y = np.arange(0, h) + grid_cell_offset # shift y
113
+ sx, sy = np.meshgrid(x, y)
114
+ # sy, sx = np.meshgrid(y, x)
115
+ anchor_points[stride] = np.stack((sx, sy), axis=-1).reshape(-1, 2)
116
+ return anchor_points
117
+
118
+ def softmax(self, x, axis=1):
119
+ x_exp = np.exp(x)
120
+ # 如果是列向量,则axis=0
121
+ x_sum = np.sum(x_exp, axis=axis, keepdims=True)
122
+ s = x_exp / x_sum
123
+ return s
124
+
125
+ def resize_image(self, srcimg, keep_ratio=True):
126
+ top, left, newh, neww = 0, 0, self.input_width, self.input_height
127
+ if keep_ratio and srcimg.shape[0] != srcimg.shape[1]:
128
+ hw_scale = srcimg.shape[0] / srcimg.shape[1]
129
+ if hw_scale > 1:
130
+ newh, neww = self.input_height, int(self.input_width / hw_scale)
131
+ img = cv2.resize(srcimg, (neww, newh), interpolation=cv2.INTER_AREA)
132
+ left = int((self.input_width - neww) * 0.5)
133
+ img = cv2.copyMakeBorder(img, 0, 0, left, self.input_width - neww - left, cv2.BORDER_CONSTANT,
134
+ value=(0, 0, 0)) # add border
135
+ else:
136
+ newh, neww = int(self.input_height * hw_scale), self.input_width
137
+ img = cv2.resize(srcimg, (neww, newh), interpolation=cv2.INTER_AREA)
138
+ top = int((self.input_height - newh) * 0.5)
139
+ img = cv2.copyMakeBorder(img, top, self.input_height - newh - top, 0, 0, cv2.BORDER_CONSTANT,
140
+ value=(0, 0, 0))
141
+ else:
142
+ img = cv2.resize(srcimg, (self.input_width, self.input_height), interpolation=cv2.INTER_AREA)
143
+ return img, newh, neww, top, left
144
+
145
+ def detect(self, srcimg):
146
+ input_img, newh, neww, padh, padw = self.resize_image(cv2.cvtColor(srcimg, cv2.COLOR_BGR2RGB))
147
+ scale_h, scale_w = srcimg.shape[0]/newh, srcimg.shape[1]/neww
148
+ input_img = input_img.astype(np.float32) / 255.0
149
+
150
+ blob = cv2.dnn.blobFromImage(input_img)
151
+ self.net.setInput(blob)
152
+ outputs = self.net.forward(self.net.getUnconnectedOutLayersNames())
153
+ # if isinstance(outputs, tuple):
154
+ # outputs = list(outputs)
155
+ # if float(cv2.__version__[:3])>=4.7:
156
+ # outputs = [outputs[2], outputs[0], outputs[1]] ###opencv4.7需要这一步,opencv4.5不需要
157
+ # Perform inference on the image
158
+ det_bboxes, det_conf, det_classid, landmarks = self.post_process(outputs, scale_h, scale_w, padh, padw)
159
+ return det_bboxes, det_conf, det_classid, landmarks
160
+
161
+ def post_process(self, preds, scale_h, scale_w, padh, padw):
162
+ bboxes, scores, landmarks = [], [], []
163
+ for i, pred in enumerate(preds):
164
+ stride = int(self.input_height/pred.shape[2])
165
+ pred = pred.transpose((0, 2, 3, 1))
166
+
167
+ box = pred[..., :self.reg_max * 4]
168
+ cls = 1 / (1 + np.exp(-pred[..., self.reg_max * 4:-15])).reshape((-1,1))
169
+ kpts = pred[..., -15:].reshape((-1,15)) ### x1,y1,score1, ..., x5,y5,score5
170
+
171
+ # tmp = box.reshape(self.feats_hw[i][0], self.feats_hw[i][1], 4, self.reg_max)
172
+ tmp = box.reshape(-1, 4, self.reg_max)
173
+ bbox_pred = self.softmax(tmp, axis=-1)
174
+ bbox_pred = np.dot(bbox_pred, self.project).reshape((-1,4))
175
+
176
+ bbox = self.distance2bbox(self.anchors[stride], bbox_pred, max_shape=(self.input_height, self.input_width)) * stride
177
+ kpts[:, 0::3] = (kpts[:, 0::3] * 2.0 + (self.anchors[stride][:, 0].reshape((-1,1)) - 0.5)) * stride
178
+ kpts[:, 1::3] = (kpts[:, 1::3] * 2.0 + (self.anchors[stride][:, 1].reshape((-1,1)) - 0.5)) * stride
179
+ kpts[:, 2::3] = 1 / (1+np.exp(-kpts[:, 2::3]))
180
+
181
+ bbox -= np.array([[padw, padh, padw, padh]]) ###合理使用广播法则
182
+ bbox *= np.array([[scale_w, scale_h, scale_w, scale_h]])
183
+ kpts -= np.tile(np.array([padw, padh, 0]), 5).reshape((1,15))
184
+ kpts *= np.tile(np.array([scale_w, scale_h, 1]), 5).reshape((1,15))
185
+
186
+ bboxes.append(bbox)
187
+ scores.append(cls)
188
+ landmarks.append(kpts)
189
+
190
+ bboxes = np.concatenate(bboxes, axis=0)
191
+ scores = np.concatenate(scores, axis=0)
192
+ landmarks = np.concatenate(landmarks, axis=0)
193
+
194
+ bboxes_wh = bboxes.copy()
195
+ bboxes_wh[:, 2:4] = bboxes[:, 2:4] - bboxes[:, 0:2] ####xywh
196
+ classIds = np.argmax(scores, axis=1)
197
+ confidences = np.max(scores, axis=1) ####max_class_confidence
198
+
199
+ mask = confidences>self.conf_threshold
200
+ bboxes_wh = bboxes_wh[mask] ###合理使用广播法则
201
+ confidences = confidences[mask]
202
+ classIds = classIds[mask]
203
+ landmarks = landmarks[mask]
204
+
205
+ indices = cv2.dnn.NMSBoxes(bboxes_wh.tolist(), confidences.tolist(), self.conf_threshold,
206
+ self.iou_threshold).flatten()
207
+ if len(indices) > 0:
208
+ mlvl_bboxes = bboxes_wh[indices]
209
+ confidences = confidences[indices]
210
+ classIds = classIds[indices]
211
+ landmarks = landmarks[indices]
212
+ return mlvl_bboxes, confidences, classIds, landmarks
213
+ else:
214
+ print('nothing detect')
215
+ return np.array([]), np.array([]), np.array([]), np.array([])
216
+
217
+ def distance2bbox(self, points, distance, max_shape=None):
218
+ x1 = points[:, 0] - distance[:, 0]
219
+ y1 = points[:, 1] - distance[:, 1]
220
+ x2 = points[:, 0] + distance[:, 2]
221
+ y2 = points[:, 1] + distance[:, 3]
222
+ if max_shape is not None:
223
+ x1 = np.clip(x1, 0, max_shape[1])
224
+ y1 = np.clip(y1, 0, max_shape[0])
225
+ x2 = np.clip(x2, 0, max_shape[1])
226
+ y2 = np.clip(y2, 0, max_shape[0])
227
+ return np.stack([x1, y1, x2, y2], axis=-1)
228
+
229
+ def draw_detections(self, image, boxes, scores, kpts):
230
+ for box, score, kp in zip(boxes, scores, kpts):
231
+ x, y, w, h = box.astype(int)
232
+ # Draw rectangle
233
+ cv2.rectangle(image, (x, y), (x + w, y + h), (0, 0, 255), thickness=3)
234
+ cv2.putText(image, "face:"+str(round(score,2)), (x, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), thickness=2)
235
+ for i in range(5):
236
+ cv2.circle(image, (int(kp[i * 3]), int(kp[i * 3 + 1])), 4, (0, 255, 0), thickness=-1)
237
+ # cv2.putText(image, str(i), (int(kp[i * 3]), int(kp[i * 3 + 1]) - 10), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), thickness=1)
238
+ return image
239
+
240
+ ROOT = os.path.dirname(os.path.abspath(__file__))
musetalk_integration/utils/face_detection/detection/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .core import FaceDetector
musetalk_integration/utils/face_detection/detection/core.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import glob
3
+ from tqdm import tqdm
4
+ import numpy as np
5
+ import torch
6
+ import cv2
7
+
8
+
9
+ class FaceDetector(object):
10
+ """An abstract class representing a face detector.
11
+
12
+ Any other face detection implementation must subclass it. All subclasses
13
+ must implement ``detect_from_image``, that return a list of detected
14
+ bounding boxes. Optionally, for speed considerations detect from path is
15
+ recommended.
16
+ """
17
+
18
+ def __init__(self, device, verbose):
19
+ self.device = device
20
+ self.verbose = verbose
21
+
22
+ if verbose:
23
+ if 'cpu' in device:
24
+ logger = logging.getLogger(__name__)
25
+ logger.warning("Detection running on CPU, this may be potentially slow.")
26
+
27
+ if 'cpu' not in device and 'cuda' not in device:
28
+ if verbose:
29
+ logger.error("Expected values for device are: {cpu, cuda} but got: %s", device)
30
+ raise ValueError
31
+
32
+ def detect_from_image(self, tensor_or_path):
33
+ """Detects faces in a given image.
34
+
35
+ This function detects the faces present in a provided BGR(usually)
36
+ image. The input can be either the image itself or the path to it.
37
+
38
+ Arguments:
39
+ tensor_or_path {numpy.ndarray, torch.tensor or string} -- the path
40
+ to an image or the image itself.
41
+
42
+ Example::
43
+
44
+ >>> path_to_image = 'data/image_01.jpg'
45
+ ... detected_faces = detect_from_image(path_to_image)
46
+ [A list of bounding boxes (x1, y1, x2, y2)]
47
+ >>> image = cv2.imread(path_to_image)
48
+ ... detected_faces = detect_from_image(image)
49
+ [A list of bounding boxes (x1, y1, x2, y2)]
50
+
51
+ """
52
+ raise NotImplementedError
53
+
54
+ def detect_from_directory(self, path, extensions=['.jpg', '.png'], recursive=False, show_progress_bar=True):
55
+ """Detects faces from all the images present in a given directory.
56
+
57
+ Arguments:
58
+ path {string} -- a string containing a path that points to the folder containing the images
59
+
60
+ Keyword Arguments:
61
+ extensions {list} -- list of string containing the extensions to be
62
+ consider in the following format: ``.extension_name`` (default:
63
+ {['.jpg', '.png']}) recursive {bool} -- option wherever to scan the
64
+ folder recursively (default: {False}) show_progress_bar {bool} --
65
+ display a progressbar (default: {True})
66
+
67
+ Example:
68
+ >>> directory = 'data'
69
+ ... detected_faces = detect_from_directory(directory)
70
+ {A dictionary of [lists containing bounding boxes(x1, y1, x2, y2)]}
71
+
72
+ """
73
+ if self.verbose:
74
+ logger = logging.getLogger(__name__)
75
+
76
+ if len(extensions) == 0:
77
+ if self.verbose:
78
+ logger.error("Expected at list one extension, but none was received.")
79
+ raise ValueError
80
+
81
+ if self.verbose:
82
+ logger.info("Constructing the list of images.")
83
+ additional_pattern = '/**/*' if recursive else '/*'
84
+ files = []
85
+ for extension in extensions:
86
+ files.extend(glob.glob(path + additional_pattern + extension, recursive=recursive))
87
+
88
+ if self.verbose:
89
+ logger.info("Finished searching for images. %s images found", len(files))
90
+ logger.info("Preparing to run the detection.")
91
+
92
+ predictions = {}
93
+ for image_path in tqdm(files, disable=not show_progress_bar):
94
+ if self.verbose:
95
+ logger.info("Running the face detector on image: %s", image_path)
96
+ predictions[image_path] = self.detect_from_image(image_path)
97
+
98
+ if self.verbose:
99
+ logger.info("The detector was successfully run on all %s images", len(files))
100
+
101
+ return predictions
102
+
103
+ @property
104
+ def reference_scale(self):
105
+ raise NotImplementedError
106
+
107
+ @property
108
+ def reference_x_shift(self):
109
+ raise NotImplementedError
110
+
111
+ @property
112
+ def reference_y_shift(self):
113
+ raise NotImplementedError
114
+
115
+ @staticmethod
116
+ def tensor_or_path_to_ndarray(tensor_or_path, rgb=True):
117
+ """Convert path (represented as a string) or torch.tensor to a numpy.ndarray
118
+
119
+ Arguments:
120
+ tensor_or_path {numpy.ndarray, torch.tensor or string} -- path to the image, or the image itself
121
+ """
122
+ if isinstance(tensor_or_path, str):
123
+ return cv2.imread(tensor_or_path) if not rgb else cv2.imread(tensor_or_path)[..., ::-1]
124
+ elif torch.is_tensor(tensor_or_path):
125
+ # Call cpu in case its coming from cuda
126
+ return tensor_or_path.cpu().numpy()[..., ::-1].copy() if not rgb else tensor_or_path.cpu().numpy()
127
+ elif isinstance(tensor_or_path, np.ndarray):
128
+ return tensor_or_path[..., ::-1].copy() if not rgb else tensor_or_path
129
+ else:
130
+ raise TypeError
musetalk_integration/utils/face_detection/detection/sfd/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .sfd_detector import SFDDetector as FaceDetector
musetalk_integration/utils/face_detection/detection/sfd/bbox.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ import os
3
+ import sys
4
+ import cv2
5
+ import random
6
+ import datetime
7
+ import time
8
+ import math
9
+ import argparse
10
+ import numpy as np
11
+ import torch
12
+
13
+ try:
14
+ from iou import IOU
15
+ except BaseException:
16
+ # IOU cython speedup 10x
17
+ def IOU(ax1, ay1, ax2, ay2, bx1, by1, bx2, by2):
18
+ sa = abs((ax2 - ax1) * (ay2 - ay1))
19
+ sb = abs((bx2 - bx1) * (by2 - by1))
20
+ x1, y1 = max(ax1, bx1), max(ay1, by1)
21
+ x2, y2 = min(ax2, bx2), min(ay2, by2)
22
+ w = x2 - x1
23
+ h = y2 - y1
24
+ if w < 0 or h < 0:
25
+ return 0.0
26
+ else:
27
+ return 1.0 * w * h / (sa + sb - w * h)
28
+
29
+
30
+ def bboxlog(x1, y1, x2, y2, axc, ayc, aww, ahh):
31
+ xc, yc, ww, hh = (x2 + x1) / 2, (y2 + y1) / 2, x2 - x1, y2 - y1
32
+ dx, dy = (xc - axc) / aww, (yc - ayc) / ahh
33
+ dw, dh = math.log(ww / aww), math.log(hh / ahh)
34
+ return dx, dy, dw, dh
35
+
36
+
37
+ def bboxloginv(dx, dy, dw, dh, axc, ayc, aww, ahh):
38
+ xc, yc = dx * aww + axc, dy * ahh + ayc
39
+ ww, hh = math.exp(dw) * aww, math.exp(dh) * ahh
40
+ x1, x2, y1, y2 = xc - ww / 2, xc + ww / 2, yc - hh / 2, yc + hh / 2
41
+ return x1, y1, x2, y2
42
+
43
+
44
+ def nms(dets, thresh):
45
+ if 0 == len(dets):
46
+ return []
47
+ x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4]
48
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
49
+ order = scores.argsort()[::-1]
50
+
51
+ keep = []
52
+ while order.size > 0:
53
+ i = order[0]
54
+ keep.append(i)
55
+ xx1, yy1 = np.maximum(x1[i], x1[order[1:]]), np.maximum(y1[i], y1[order[1:]])
56
+ xx2, yy2 = np.minimum(x2[i], x2[order[1:]]), np.minimum(y2[i], y2[order[1:]])
57
+
58
+ w, h = np.maximum(0.0, xx2 - xx1 + 1), np.maximum(0.0, yy2 - yy1 + 1)
59
+ ovr = w * h / (areas[i] + areas[order[1:]] - w * h)
60
+
61
+ inds = np.where(ovr <= thresh)[0]
62
+ order = order[inds + 1]
63
+
64
+ return keep
65
+
66
+
67
+ def encode(matched, priors, variances):
68
+ """Encode the variances from the priorbox layers into the ground truth boxes
69
+ we have matched (based on jaccard overlap) with the prior boxes.
70
+ Args:
71
+ matched: (tensor) Coords of ground truth for each prior in point-form
72
+ Shape: [num_priors, 4].
73
+ priors: (tensor) Prior boxes in center-offset form
74
+ Shape: [num_priors,4].
75
+ variances: (list[float]) Variances of priorboxes
76
+ Return:
77
+ encoded boxes (tensor), Shape: [num_priors, 4]
78
+ """
79
+
80
+ # dist b/t match center and prior's center
81
+ g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]
82
+ # encode variance
83
+ g_cxcy /= (variances[0] * priors[:, 2:])
84
+ # match wh / prior wh
85
+ g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
86
+ g_wh = torch.log(g_wh) / variances[1]
87
+ # return target for smooth_l1_loss
88
+ return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
89
+
90
+
91
+ def decode(loc, priors, variances):
92
+ """Decode locations from predictions using priors to undo
93
+ the encoding we did for offset regression at train time.
94
+ Args:
95
+ loc (tensor): location predictions for loc layers,
96
+ Shape: [num_priors,4]
97
+ priors (tensor): Prior boxes in center-offset form.
98
+ Shape: [num_priors,4].
99
+ variances: (list[float]) Variances of priorboxes
100
+ Return:
101
+ decoded bounding box predictions
102
+ """
103
+
104
+ boxes = torch.cat((
105
+ priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
106
+ priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
107
+ boxes[:, :2] -= boxes[:, 2:] / 2
108
+ boxes[:, 2:] += boxes[:, :2]
109
+ return boxes
110
+
111
+ def batch_decode(loc, priors, variances):
112
+ """Decode locations from predictions using priors to undo
113
+ the encoding we did for offset regression at train time.
114
+ Args:
115
+ loc (tensor): location predictions for loc layers,
116
+ Shape: [num_priors,4]
117
+ priors (tensor): Prior boxes in center-offset form.
118
+ Shape: [num_priors,4].
119
+ variances: (list[float]) Variances of priorboxes
120
+ Return:
121
+ decoded bounding box predictions
122
+ """
123
+
124
+ boxes = torch.cat((
125
+ priors[:, :, :2] + loc[:, :, :2] * variances[0] * priors[:, :, 2:],
126
+ priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * variances[1])), 2)
127
+ boxes[:, :, :2] -= boxes[:, :, 2:] / 2
128
+ boxes[:, :, 2:] += boxes[:, :, :2]
129
+ return boxes
musetalk_integration/utils/face_detection/detection/sfd/detect.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ import os
5
+ import sys
6
+ import cv2
7
+ import random
8
+ import datetime
9
+ import math
10
+ import argparse
11
+ import numpy as np
12
+
13
+ import scipy.io as sio
14
+ import zipfile
15
+ from .net_s3fd import s3fd
16
+ from .bbox import *
17
+
18
+
19
+ def detect(net, img, device):
20
+ img = img - np.array([104, 117, 123])
21
+ img = img.transpose(2, 0, 1)
22
+ img = img.reshape((1,) + img.shape)
23
+
24
+ if 'cuda' in device:
25
+ torch.backends.cudnn.benchmark = True
26
+
27
+ img = torch.from_numpy(img).float().to(device)
28
+ BB, CC, HH, WW = img.size()
29
+ with torch.no_grad():
30
+ olist = net(img)
31
+
32
+ bboxlist = []
33
+ for i in range(len(olist) // 2):
34
+ olist[i * 2] = F.softmax(olist[i * 2], dim=1)
35
+ olist = [oelem.data.cpu() for oelem in olist]
36
+ for i in range(len(olist) // 2):
37
+ ocls, oreg = olist[i * 2], olist[i * 2 + 1]
38
+ FB, FC, FH, FW = ocls.size() # feature map size
39
+ stride = 2**(i + 2) # 4,8,16,32,64,128
40
+ anchor = stride * 4
41
+ poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
42
+ for Iindex, hindex, windex in poss:
43
+ axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
44
+ score = ocls[0, 1, hindex, windex]
45
+ loc = oreg[0, :, hindex, windex].contiguous().view(1, 4)
46
+ priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]])
47
+ variances = [0.1, 0.2]
48
+ box = decode(loc, priors, variances)
49
+ x1, y1, x2, y2 = box[0] * 1.0
50
+ # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
51
+ bboxlist.append([x1, y1, x2, y2, score])
52
+ bboxlist = np.array(bboxlist)
53
+ if 0 == len(bboxlist):
54
+ bboxlist = np.zeros((1, 5))
55
+
56
+ return bboxlist
57
+
58
+ def batch_detect(net, imgs, device):
59
+ imgs = imgs - np.array([104, 117, 123])
60
+ imgs = imgs.transpose(0, 3, 1, 2)
61
+
62
+ if 'cuda' in device:
63
+ torch.backends.cudnn.benchmark = True
64
+
65
+ imgs = torch.from_numpy(imgs).float().to(device)
66
+ BB, CC, HH, WW = imgs.size()
67
+ with torch.no_grad():
68
+ olist = net(imgs)
69
+ # print(olist)
70
+
71
+ bboxlist = []
72
+ for i in range(len(olist) // 2):
73
+ olist[i * 2] = F.softmax(olist[i * 2], dim=1)
74
+
75
+ olist = [oelem.cpu() for oelem in olist]
76
+ for i in range(len(olist) // 2):
77
+ ocls, oreg = olist[i * 2], olist[i * 2 + 1]
78
+ FB, FC, FH, FW = ocls.size() # feature map size
79
+ stride = 2**(i + 2) # 4,8,16,32,64,128
80
+ anchor = stride * 4
81
+ poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
82
+ for Iindex, hindex, windex in poss:
83
+ axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
84
+ score = ocls[:, 1, hindex, windex]
85
+ loc = oreg[:, :, hindex, windex].contiguous().view(BB, 1, 4)
86
+ priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]).view(1, 1, 4)
87
+ variances = [0.1, 0.2]
88
+ box = batch_decode(loc, priors, variances)
89
+ box = box[:, 0] * 1.0
90
+ # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
91
+ bboxlist.append(torch.cat([box, score.unsqueeze(1)], 1).cpu().numpy())
92
+ bboxlist = np.array(bboxlist)
93
+ if 0 == len(bboxlist):
94
+ bboxlist = np.zeros((1, BB, 5))
95
+
96
+ return bboxlist
97
+
98
+ def flip_detect(net, img, device):
99
+ img = cv2.flip(img, 1)
100
+ b = detect(net, img, device)
101
+
102
+ bboxlist = np.zeros(b.shape)
103
+ bboxlist[:, 0] = img.shape[1] - b[:, 2]
104
+ bboxlist[:, 1] = b[:, 1]
105
+ bboxlist[:, 2] = img.shape[1] - b[:, 0]
106
+ bboxlist[:, 3] = b[:, 3]
107
+ bboxlist[:, 4] = b[:, 4]
108
+ return bboxlist
109
+
110
+
111
+ def pts_to_bb(pts):
112
+ min_x, min_y = np.min(pts, axis=0)
113
+ max_x, max_y = np.max(pts, axis=0)
114
+ return np.array([min_x, min_y, max_x, max_y])
musetalk_integration/utils/face_detection/detection/sfd/net_s3fd.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class L2Norm(nn.Module):
7
+ def __init__(self, n_channels, scale=1.0):
8
+ super(L2Norm, self).__init__()
9
+ self.n_channels = n_channels
10
+ self.scale = scale
11
+ self.eps = 1e-10
12
+ self.weight = nn.Parameter(torch.Tensor(self.n_channels))
13
+ self.weight.data *= 0.0
14
+ self.weight.data += self.scale
15
+
16
+ def forward(self, x):
17
+ norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps
18
+ x = x / norm * self.weight.view(1, -1, 1, 1)
19
+ return x
20
+
21
+
22
+ class s3fd(nn.Module):
23
+ def __init__(self):
24
+ super(s3fd, self).__init__()
25
+ self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
26
+ self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
27
+
28
+ self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
29
+ self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
30
+
31
+ self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
32
+ self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
33
+ self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
34
+
35
+ self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
36
+ self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
37
+ self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
38
+
39
+ self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
40
+ self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
41
+ self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
42
+
43
+ self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3)
44
+ self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0)
45
+
46
+ self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)
47
+ self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
48
+
49
+ self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0)
50
+ self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
51
+
52
+ self.conv3_3_norm = L2Norm(256, scale=10)
53
+ self.conv4_3_norm = L2Norm(512, scale=8)
54
+ self.conv5_3_norm = L2Norm(512, scale=5)
55
+
56
+ self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
57
+ self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
58
+ self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
59
+ self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
60
+ self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
61
+ self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
62
+
63
+ self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1)
64
+ self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1)
65
+ self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
66
+ self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
67
+ self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1)
68
+ self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
69
+
70
+ def forward(self, x):
71
+ h = F.relu(self.conv1_1(x))
72
+ h = F.relu(self.conv1_2(h))
73
+ h = F.max_pool2d(h, 2, 2)
74
+
75
+ h = F.relu(self.conv2_1(h))
76
+ h = F.relu(self.conv2_2(h))
77
+ h = F.max_pool2d(h, 2, 2)
78
+
79
+ h = F.relu(self.conv3_1(h))
80
+ h = F.relu(self.conv3_2(h))
81
+ h = F.relu(self.conv3_3(h))
82
+ f3_3 = h
83
+ h = F.max_pool2d(h, 2, 2)
84
+
85
+ h = F.relu(self.conv4_1(h))
86
+ h = F.relu(self.conv4_2(h))
87
+ h = F.relu(self.conv4_3(h))
88
+ f4_3 = h
89
+ h = F.max_pool2d(h, 2, 2)
90
+
91
+ h = F.relu(self.conv5_1(h))
92
+ h = F.relu(self.conv5_2(h))
93
+ h = F.relu(self.conv5_3(h))
94
+ f5_3 = h
95
+ h = F.max_pool2d(h, 2, 2)
96
+
97
+ h = F.relu(self.fc6(h))
98
+ h = F.relu(self.fc7(h))
99
+ ffc7 = h
100
+ h = F.relu(self.conv6_1(h))
101
+ h = F.relu(self.conv6_2(h))
102
+ f6_2 = h
103
+ h = F.relu(self.conv7_1(h))
104
+ h = F.relu(self.conv7_2(h))
105
+ f7_2 = h
106
+
107
+ f3_3 = self.conv3_3_norm(f3_3)
108
+ f4_3 = self.conv4_3_norm(f4_3)
109
+ f5_3 = self.conv5_3_norm(f5_3)
110
+
111
+ cls1 = self.conv3_3_norm_mbox_conf(f3_3)
112
+ reg1 = self.conv3_3_norm_mbox_loc(f3_3)
113
+ cls2 = self.conv4_3_norm_mbox_conf(f4_3)
114
+ reg2 = self.conv4_3_norm_mbox_loc(f4_3)
115
+ cls3 = self.conv5_3_norm_mbox_conf(f5_3)
116
+ reg3 = self.conv5_3_norm_mbox_loc(f5_3)
117
+ cls4 = self.fc7_mbox_conf(ffc7)
118
+ reg4 = self.fc7_mbox_loc(ffc7)
119
+ cls5 = self.conv6_2_mbox_conf(f6_2)
120
+ reg5 = self.conv6_2_mbox_loc(f6_2)
121
+ cls6 = self.conv7_2_mbox_conf(f7_2)
122
+ reg6 = self.conv7_2_mbox_loc(f7_2)
123
+
124
+ # max-out background label
125
+ chunk = torch.chunk(cls1, 4, 1)
126
+ bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2])
127
+ cls1 = torch.cat([bmax, chunk[3]], dim=1)
128
+
129
+ return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6]
musetalk_integration/utils/face_detection/detection/sfd/sfd_detector.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ from torch.utils.model_zoo import load_url
4
+
5
+ from ..core import FaceDetector
6
+
7
+ from .net_s3fd import s3fd
8
+ from .bbox import *
9
+ from .detect import *
10
+
11
+ models_urls = {
12
+ 's3fd': 'https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth',
13
+ }
14
+
15
+
16
+ class SFDDetector(FaceDetector):
17
+ def __init__(self, device, path_to_detector=os.path.join(os.path.dirname(os.path.abspath(__file__)), 's3fd.pth'), verbose=False):
18
+ super(SFDDetector, self).__init__(device, verbose)
19
+
20
+ # Initialise the face detector
21
+ if not os.path.isfile(path_to_detector):
22
+ model_weights = load_url(models_urls['s3fd'])
23
+ else:
24
+ model_weights = torch.load(path_to_detector)
25
+
26
+ self.face_detector = s3fd()
27
+ self.face_detector.load_state_dict(model_weights)
28
+ self.face_detector.to(device)
29
+ self.face_detector.eval()
30
+
31
+ def detect_from_image(self, tensor_or_path):
32
+ image = self.tensor_or_path_to_ndarray(tensor_or_path)
33
+
34
+ bboxlist = detect(self.face_detector, image, device=self.device)
35
+ keep = nms(bboxlist, 0.3)
36
+ bboxlist = bboxlist[keep, :]
37
+ bboxlist = [x for x in bboxlist if x[-1] > 0.5]
38
+
39
+ return bboxlist
40
+
41
+ def detect_from_batch(self, images):
42
+ bboxlists = batch_detect(self.face_detector, images, device=self.device)
43
+ keeps = [nms(bboxlists[:, i, :], 0.3) for i in range(bboxlists.shape[1])]
44
+ bboxlists = [bboxlists[keep, i, :] for i, keep in enumerate(keeps)]
45
+ bboxlists = [[x for x in bboxlist if x[-1] > 0.5] for bboxlist in bboxlists]
46
+
47
+ return bboxlists
48
+
49
+ @property
50
+ def reference_scale(self):
51
+ return 195
52
+
53
+ @property
54
+ def reference_x_shift(self):
55
+ return 0
56
+
57
+ @property
58
+ def reference_y_shift(self):
59
+ return 0
musetalk_integration/utils/face_detection/models.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+
6
+
7
+ def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
8
+ "3x3 convolution with padding"
9
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3,
10
+ stride=strd, padding=padding, bias=bias)
11
+
12
+
13
+ class ConvBlock(nn.Module):
14
+ def __init__(self, in_planes, out_planes):
15
+ super(ConvBlock, self).__init__()
16
+ self.bn1 = nn.BatchNorm2d(in_planes)
17
+ self.conv1 = conv3x3(in_planes, int(out_planes / 2))
18
+ self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
19
+ self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4))
20
+ self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
21
+ self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4))
22
+
23
+ if in_planes != out_planes:
24
+ self.downsample = nn.Sequential(
25
+ nn.BatchNorm2d(in_planes),
26
+ nn.ReLU(True),
27
+ nn.Conv2d(in_planes, out_planes,
28
+ kernel_size=1, stride=1, bias=False),
29
+ )
30
+ else:
31
+ self.downsample = None
32
+
33
+ def forward(self, x):
34
+ residual = x
35
+
36
+ out1 = self.bn1(x)
37
+ out1 = F.relu(out1, True)
38
+ out1 = self.conv1(out1)
39
+
40
+ out2 = self.bn2(out1)
41
+ out2 = F.relu(out2, True)
42
+ out2 = self.conv2(out2)
43
+
44
+ out3 = self.bn3(out2)
45
+ out3 = F.relu(out3, True)
46
+ out3 = self.conv3(out3)
47
+
48
+ out3 = torch.cat((out1, out2, out3), 1)
49
+
50
+ if self.downsample is not None:
51
+ residual = self.downsample(residual)
52
+
53
+ out3 += residual
54
+
55
+ return out3
56
+
57
+
58
+ class Bottleneck(nn.Module):
59
+
60
+ expansion = 4
61
+
62
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
63
+ super(Bottleneck, self).__init__()
64
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
65
+ self.bn1 = nn.BatchNorm2d(planes)
66
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
67
+ padding=1, bias=False)
68
+ self.bn2 = nn.BatchNorm2d(planes)
69
+ self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
70
+ self.bn3 = nn.BatchNorm2d(planes * 4)
71
+ self.relu = nn.ReLU(inplace=True)
72
+ self.downsample = downsample
73
+ self.stride = stride
74
+
75
+ def forward(self, x):
76
+ residual = x
77
+
78
+ out = self.conv1(x)
79
+ out = self.bn1(out)
80
+ out = self.relu(out)
81
+
82
+ out = self.conv2(out)
83
+ out = self.bn2(out)
84
+ out = self.relu(out)
85
+
86
+ out = self.conv3(out)
87
+ out = self.bn3(out)
88
+
89
+ if self.downsample is not None:
90
+ residual = self.downsample(x)
91
+
92
+ out += residual
93
+ out = self.relu(out)
94
+
95
+ return out
96
+
97
+
98
+ class HourGlass(nn.Module):
99
+ def __init__(self, num_modules, depth, num_features):
100
+ super(HourGlass, self).__init__()
101
+ self.num_modules = num_modules
102
+ self.depth = depth
103
+ self.features = num_features
104
+
105
+ self._generate_network(self.depth)
106
+
107
+ def _generate_network(self, level):
108
+ self.add_module('b1_' + str(level), ConvBlock(self.features, self.features))
109
+
110
+ self.add_module('b2_' + str(level), ConvBlock(self.features, self.features))
111
+
112
+ if level > 1:
113
+ self._generate_network(level - 1)
114
+ else:
115
+ self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features))
116
+
117
+ self.add_module('b3_' + str(level), ConvBlock(self.features, self.features))
118
+
119
+ def _forward(self, level, inp):
120
+ # Upper branch
121
+ up1 = inp
122
+ up1 = self._modules['b1_' + str(level)](up1)
123
+
124
+ # Lower branch
125
+ low1 = F.avg_pool2d(inp, 2, stride=2)
126
+ low1 = self._modules['b2_' + str(level)](low1)
127
+
128
+ if level > 1:
129
+ low2 = self._forward(level - 1, low1)
130
+ else:
131
+ low2 = low1
132
+ low2 = self._modules['b2_plus_' + str(level)](low2)
133
+
134
+ low3 = low2
135
+ low3 = self._modules['b3_' + str(level)](low3)
136
+
137
+ up2 = F.interpolate(low3, scale_factor=2, mode='nearest')
138
+
139
+ return up1 + up2
140
+
141
+ def forward(self, x):
142
+ return self._forward(self.depth, x)
143
+
144
+
145
+ class FAN(nn.Module):
146
+
147
+ def __init__(self, num_modules=1):
148
+ super(FAN, self).__init__()
149
+ self.num_modules = num_modules
150
+
151
+ # Base part
152
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
153
+ self.bn1 = nn.BatchNorm2d(64)
154
+ self.conv2 = ConvBlock(64, 128)
155
+ self.conv3 = ConvBlock(128, 128)
156
+ self.conv4 = ConvBlock(128, 256)
157
+
158
+ # Stacking part
159
+ for hg_module in range(self.num_modules):
160
+ self.add_module('m' + str(hg_module), HourGlass(1, 4, 256))
161
+ self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256))
162
+ self.add_module('conv_last' + str(hg_module),
163
+ nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
164
+ self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
165
+ self.add_module('l' + str(hg_module), nn.Conv2d(256,
166
+ 68, kernel_size=1, stride=1, padding=0))
167
+
168
+ if hg_module < self.num_modules - 1:
169
+ self.add_module(
170
+ 'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
171
+ self.add_module('al' + str(hg_module), nn.Conv2d(68,
172
+ 256, kernel_size=1, stride=1, padding=0))
173
+
174
+ def forward(self, x):
175
+ x = F.relu(self.bn1(self.conv1(x)), True)
176
+ x = F.avg_pool2d(self.conv2(x), 2, stride=2)
177
+ x = self.conv3(x)
178
+ x = self.conv4(x)
179
+
180
+ previous = x
181
+
182
+ outputs = []
183
+ for i in range(self.num_modules):
184
+ hg = self._modules['m' + str(i)](previous)
185
+
186
+ ll = hg
187
+ ll = self._modules['top_m_' + str(i)](ll)
188
+
189
+ ll = F.relu(self._modules['bn_end' + str(i)]
190
+ (self._modules['conv_last' + str(i)](ll)), True)
191
+
192
+ # Predict heatmaps
193
+ tmp_out = self._modules['l' + str(i)](ll)
194
+ outputs.append(tmp_out)
195
+
196
+ if i < self.num_modules - 1:
197
+ ll = self._modules['bl' + str(i)](ll)
198
+ tmp_out_ = self._modules['al' + str(i)](tmp_out)
199
+ previous = previous + ll + tmp_out_
200
+
201
+ return outputs
202
+
203
+
204
+ class ResNetDepth(nn.Module):
205
+
206
+ def __init__(self, block=Bottleneck, layers=[3, 8, 36, 3], num_classes=68):
207
+ self.inplanes = 64
208
+ super(ResNetDepth, self).__init__()
209
+ self.conv1 = nn.Conv2d(3 + 68, 64, kernel_size=7, stride=2, padding=3,
210
+ bias=False)
211
+ self.bn1 = nn.BatchNorm2d(64)
212
+ self.relu = nn.ReLU(inplace=True)
213
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
214
+ self.layer1 = self._make_layer(block, 64, layers[0])
215
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
216
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
217
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
218
+ self.avgpool = nn.AvgPool2d(7)
219
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
220
+
221
+ for m in self.modules():
222
+ if isinstance(m, nn.Conv2d):
223
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
224
+ m.weight.data.normal_(0, math.sqrt(2. / n))
225
+ elif isinstance(m, nn.BatchNorm2d):
226
+ m.weight.data.fill_(1)
227
+ m.bias.data.zero_()
228
+
229
+ def _make_layer(self, block, planes, blocks, stride=1):
230
+ downsample = None
231
+ if stride != 1 or self.inplanes != planes * block.expansion:
232
+ downsample = nn.Sequential(
233
+ nn.Conv2d(self.inplanes, planes * block.expansion,
234
+ kernel_size=1, stride=stride, bias=False),
235
+ nn.BatchNorm2d(planes * block.expansion),
236
+ )
237
+
238
+ layers = []
239
+ layers.append(block(self.inplanes, planes, stride, downsample))
240
+ self.inplanes = planes * block.expansion
241
+ for i in range(1, blocks):
242
+ layers.append(block(self.inplanes, planes))
243
+
244
+ return nn.Sequential(*layers)
245
+
246
+ def forward(self, x):
247
+ x = self.conv1(x)
248
+ x = self.bn1(x)
249
+ x = self.relu(x)
250
+ x = self.maxpool(x)
251
+
252
+ x = self.layer1(x)
253
+ x = self.layer2(x)
254
+ x = self.layer3(x)
255
+ x = self.layer4(x)
256
+
257
+ x = self.avgpool(x)
258
+ x = x.view(x.size(0), -1)
259
+ x = self.fc(x)
260
+
261
+ return x
musetalk_integration/utils/face_detection/utils.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ import os
3
+ import sys
4
+ import time
5
+ import torch
6
+ import math
7
+ import numpy as np
8
+ import cv2
9
+
10
+
11
+ def _gaussian(
12
+ size=3, sigma=0.25, amplitude=1, normalize=False, width=None,
13
+ height=None, sigma_horz=None, sigma_vert=None, mean_horz=0.5,
14
+ mean_vert=0.5):
15
+ # handle some defaults
16
+ if width is None:
17
+ width = size
18
+ if height is None:
19
+ height = size
20
+ if sigma_horz is None:
21
+ sigma_horz = sigma
22
+ if sigma_vert is None:
23
+ sigma_vert = sigma
24
+ center_x = mean_horz * width + 0.5
25
+ center_y = mean_vert * height + 0.5
26
+ gauss = np.empty((height, width), dtype=np.float32)
27
+ # generate kernel
28
+ for i in range(height):
29
+ for j in range(width):
30
+ gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / (
31
+ sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0))
32
+ if normalize:
33
+ gauss = gauss / np.sum(gauss)
34
+ return gauss
35
+
36
+
37
+ def draw_gaussian(image, point, sigma):
38
+ # Check if the gaussian is inside
39
+ ul = [math.floor(point[0] - 3 * sigma), math.floor(point[1] - 3 * sigma)]
40
+ br = [math.floor(point[0] + 3 * sigma), math.floor(point[1] + 3 * sigma)]
41
+ if (ul[0] > image.shape[1] or ul[1] > image.shape[0] or br[0] < 1 or br[1] < 1):
42
+ return image
43
+ size = 6 * sigma + 1
44
+ g = _gaussian(size)
45
+ g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - int(max(1, ul[0])) + int(max(1, -ul[0]))]
46
+ g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - int(max(1, ul[1])) + int(max(1, -ul[1]))]
47
+ img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))]
48
+ img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))]
49
+ assert (g_x[0] > 0 and g_y[1] > 0)
50
+ image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]
51
+ ] = image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]]
52
+ image[image > 1] = 1
53
+ return image
54
+
55
+
56
+ def transform(point, center, scale, resolution, invert=False):
57
+ """Generate and affine transformation matrix.
58
+
59
+ Given a set of points, a center, a scale and a targer resolution, the
60
+ function generates and affine transformation matrix. If invert is ``True``
61
+ it will produce the inverse transformation.
62
+
63
+ Arguments:
64
+ point {torch.tensor} -- the input 2D point
65
+ center {torch.tensor or numpy.array} -- the center around which to perform the transformations
66
+ scale {float} -- the scale of the face/object
67
+ resolution {float} -- the output resolution
68
+
69
+ Keyword Arguments:
70
+ invert {bool} -- define wherever the function should produce the direct or the
71
+ inverse transformation matrix (default: {False})
72
+ """
73
+ _pt = torch.ones(3)
74
+ _pt[0] = point[0]
75
+ _pt[1] = point[1]
76
+
77
+ h = 200.0 * scale
78
+ t = torch.eye(3)
79
+ t[0, 0] = resolution / h
80
+ t[1, 1] = resolution / h
81
+ t[0, 2] = resolution * (-center[0] / h + 0.5)
82
+ t[1, 2] = resolution * (-center[1] / h + 0.5)
83
+
84
+ if invert:
85
+ t = torch.inverse(t)
86
+
87
+ new_point = (torch.matmul(t, _pt))[0:2]
88
+
89
+ return new_point.int()
90
+
91
+
92
+ def crop(image, center, scale, resolution=256.0):
93
+ """Center crops an image or set of heatmaps
94
+
95
+ Arguments:
96
+ image {numpy.array} -- an rgb image
97
+ center {numpy.array} -- the center of the object, usually the same as of the bounding box
98
+ scale {float} -- scale of the face
99
+
100
+ Keyword Arguments:
101
+ resolution {float} -- the size of the output cropped image (default: {256.0})
102
+
103
+ Returns:
104
+ [type] -- [description]
105
+ """ # Crop around the center point
106
+ """ Crops the image around the center. Input is expected to be an np.ndarray """
107
+ ul = transform([1, 1], center, scale, resolution, True)
108
+ br = transform([resolution, resolution], center, scale, resolution, True)
109
+ # pad = math.ceil(torch.norm((ul - br).float()) / 2.0 - (br[0] - ul[0]) / 2.0)
110
+ if image.ndim > 2:
111
+ newDim = np.array([br[1] - ul[1], br[0] - ul[0],
112
+ image.shape[2]], dtype=np.int32)
113
+ newImg = np.zeros(newDim, dtype=np.uint8)
114
+ else:
115
+ newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int)
116
+ newImg = np.zeros(newDim, dtype=np.uint8)
117
+ ht = image.shape[0]
118
+ wd = image.shape[1]
119
+ newX = np.array(
120
+ [max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32)
121
+ newY = np.array(
122
+ [max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32)
123
+ oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32)
124
+ oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32)
125
+ newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1]
126
+ ] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1], :]
127
+ newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)),
128
+ interpolation=cv2.INTER_LINEAR)
129
+ return newImg
130
+
131
+
132
+ def get_preds_fromhm(hm, center=None, scale=None):
133
+ """Obtain (x,y) coordinates given a set of N heatmaps. If the center
134
+ and the scale is provided the function will return the points also in
135
+ the original coordinate frame.
136
+
137
+ Arguments:
138
+ hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
139
+
140
+ Keyword Arguments:
141
+ center {torch.tensor} -- the center of the bounding box (default: {None})
142
+ scale {float} -- face scale (default: {None})
143
+ """
144
+ max, idx = torch.max(
145
+ hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
146
+ idx += 1
147
+ preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
148
+ preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
149
+ preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
150
+
151
+ for i in range(preds.size(0)):
152
+ for j in range(preds.size(1)):
153
+ hm_ = hm[i, j, :]
154
+ pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
155
+ if pX > 0 and pX < 63 and pY > 0 and pY < 63:
156
+ diff = torch.FloatTensor(
157
+ [hm_[pY, pX + 1] - hm_[pY, pX - 1],
158
+ hm_[pY + 1, pX] - hm_[pY - 1, pX]])
159
+ preds[i, j].add_(diff.sign_().mul_(.25))
160
+
161
+ preds.add_(-.5)
162
+
163
+ preds_orig = torch.zeros(preds.size())
164
+ if center is not None and scale is not None:
165
+ for i in range(hm.size(0)):
166
+ for j in range(hm.size(1)):
167
+ preds_orig[i, j] = transform(
168
+ preds[i, j], center, scale, hm.size(2), True)
169
+
170
+ return preds, preds_orig
171
+
172
+ def get_preds_fromhm_batch(hm, centers=None, scales=None):
173
+ """Obtain (x,y) coordinates given a set of N heatmaps. If the centers
174
+ and the scales is provided the function will return the points also in
175
+ the original coordinate frame.
176
+
177
+ Arguments:
178
+ hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
179
+
180
+ Keyword Arguments:
181
+ centers {torch.tensor} -- the centers of the bounding box (default: {None})
182
+ scales {float} -- face scales (default: {None})
183
+ """
184
+ max, idx = torch.max(
185
+ hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
186
+ idx += 1
187
+ preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
188
+ preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
189
+ preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
190
+
191
+ for i in range(preds.size(0)):
192
+ for j in range(preds.size(1)):
193
+ hm_ = hm[i, j, :]
194
+ pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
195
+ if pX > 0 and pX < 63 and pY > 0 and pY < 63:
196
+ diff = torch.FloatTensor(
197
+ [hm_[pY, pX + 1] - hm_[pY, pX - 1],
198
+ hm_[pY + 1, pX] - hm_[pY - 1, pX]])
199
+ preds[i, j].add_(diff.sign_().mul_(.25))
200
+
201
+ preds.add_(-.5)
202
+
203
+ preds_orig = torch.zeros(preds.size())
204
+ if centers is not None and scales is not None:
205
+ for i in range(hm.size(0)):
206
+ for j in range(hm.size(1)):
207
+ preds_orig[i, j] = transform(
208
+ preds[i, j], centers[i], scales[i], hm.size(2), True)
209
+
210
+ return preds, preds_orig
211
+
212
+ def shuffle_lr(parts, pairs=None):
213
+ """Shuffle the points left-right according to the axis of symmetry
214
+ of the object.
215
+
216
+ Arguments:
217
+ parts {torch.tensor} -- a 3D or 4D object containing the
218
+ heatmaps.
219
+
220
+ Keyword Arguments:
221
+ pairs {list of integers} -- [order of the flipped points] (default: {None})
222
+ """
223
+ if pairs is None:
224
+ pairs = [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
225
+ 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 27, 28, 29, 30, 35,
226
+ 34, 33, 32, 31, 45, 44, 43, 42, 47, 46, 39, 38, 37, 36, 41,
227
+ 40, 54, 53, 52, 51, 50, 49, 48, 59, 58, 57, 56, 55, 64, 63,
228
+ 62, 61, 60, 67, 66, 65]
229
+ if parts.ndimension() == 3:
230
+ parts = parts[pairs, ...]
231
+ else:
232
+ parts = parts[:, pairs, ...]
233
+
234
+ return parts
235
+
236
+
237
+ def flip(tensor, is_label=False):
238
+ """Flip an image or a set of heatmaps left-right
239
+
240
+ Arguments:
241
+ tensor {numpy.array or torch.tensor} -- [the input image or heatmaps]
242
+
243
+ Keyword Arguments:
244
+ is_label {bool} -- [denote wherever the input is an image or a set of heatmaps ] (default: {False})
245
+ """
246
+ if not torch.is_tensor(tensor):
247
+ tensor = torch.from_numpy(tensor)
248
+
249
+ if is_label:
250
+ tensor = shuffle_lr(tensor).flip(tensor.ndimension() - 1)
251
+ else:
252
+ tensor = tensor.flip(tensor.ndimension() - 1)
253
+
254
+ return tensor
255
+
256
+ # From pyzolib/paths.py (https://bitbucket.org/pyzo/pyzolib/src/tip/paths.py)
257
+
258
+
259
+ def appdata_dir(appname=None, roaming=False):
260
+ """ appdata_dir(appname=None, roaming=False)
261
+
262
+ Get the path to the application directory, where applications are allowed
263
+ to write user specific files (e.g. configurations). For non-user specific
264
+ data, consider using common_appdata_dir().
265
+ If appname is given, a subdir is appended (and created if necessary).
266
+ If roaming is True, will prefer a roaming directory (Windows Vista/7).
267
+ """
268
+
269
+ # Define default user directory
270
+ userDir = os.getenv('FACEALIGNMENT_USERDIR', None)
271
+ if userDir is None:
272
+ userDir = os.path.expanduser('~')
273
+ if not os.path.isdir(userDir): # pragma: no cover
274
+ userDir = '/var/tmp' # issue #54
275
+
276
+ # Get system app data dir
277
+ path = None
278
+ if sys.platform.startswith('win'):
279
+ path1, path2 = os.getenv('LOCALAPPDATA'), os.getenv('APPDATA')
280
+ path = (path2 or path1) if roaming else (path1 or path2)
281
+ elif sys.platform.startswith('darwin'):
282
+ path = os.path.join(userDir, 'Library', 'Application Support')
283
+ # On Linux and as fallback
284
+ if not (path and os.path.isdir(path)):
285
+ path = userDir
286
+
287
+ # Maybe we should store things local to the executable (in case of a
288
+ # portable distro or a frozen application that wants to be portable)
289
+ prefix = sys.prefix
290
+ if getattr(sys, 'frozen', None):
291
+ prefix = os.path.abspath(os.path.dirname(sys.executable))
292
+ for reldir in ('settings', '../settings'):
293
+ localpath = os.path.abspath(os.path.join(prefix, reldir))
294
+ if os.path.isdir(localpath): # pragma: no cover
295
+ try:
296
+ open(os.path.join(localpath, 'test.write'), 'wb').close()
297
+ os.remove(os.path.join(localpath, 'test.write'))
298
+ except IOError:
299
+ pass # We cannot write in this directory
300
+ else:
301
+ path = localpath
302
+ break
303
+
304
+ # Get path specific for this app
305
+ if appname:
306
+ if path == userDir:
307
+ appname = '.' + appname.lstrip('.') # Make it a hidden directory
308
+ path = os.path.join(path, appname)
309
+ if not os.path.isdir(path): # pragma: no cover
310
+ os.mkdir(path)
311
+
312
+ # Done
313
+ return path
musetalk_integration/utils/face_parsing/__init__.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import time
3
+ import os
4
+ import cv2
5
+ import numpy as np
6
+ from PIL import Image
7
+ from .model import BiSeNet
8
+ import torchvision.transforms as transforms
9
+
10
+ class FaceParsing():
11
+ def __init__(self, left_cheek_width=80, right_cheek_width=80):
12
+ self.net = self.model_init()
13
+ self.preprocess = self.image_preprocess()
14
+ # Ensure all size parameters are integers
15
+ cone_height = 21
16
+ tail_height = 12
17
+ total_size = cone_height + tail_height
18
+
19
+ # Create kernel with explicit integer dimensions
20
+ kernel = np.zeros((total_size, total_size), dtype=np.uint8)
21
+ center_x = total_size // 2 # Ensure center coordinates are integers
22
+
23
+ # Cone part
24
+ for row in range(cone_height):
25
+ if row < cone_height//2:
26
+ continue
27
+ width = int(2 * (row - cone_height//2) + 1)
28
+ start = int(center_x - (width // 2))
29
+ end = int(center_x + (width // 2) + 1)
30
+ kernel[row, start:end] = 1
31
+
32
+ # Vertical extension part
33
+ if cone_height > 0:
34
+ base_width = int(kernel[cone_height-1].sum())
35
+ else:
36
+ base_width = 1
37
+
38
+ for row in range(cone_height, total_size):
39
+ start = max(0, int(center_x - (base_width//2)))
40
+ end = min(total_size, int(center_x + (base_width//2) + 1))
41
+ kernel[row, start:end] = 1
42
+ self.kernel = kernel
43
+
44
+ # Modify cheek erosion kernel to be flatter ellipse
45
+ self.cheek_kernel = cv2.getStructuringElement(
46
+ cv2.MORPH_ELLIPSE, (35, 3))
47
+
48
+ # Add cheek area mask (protect chin area)
49
+ self.cheek_mask = self._create_cheek_mask(left_cheek_width=left_cheek_width, right_cheek_width=right_cheek_width)
50
+
51
+ def _create_cheek_mask(self, left_cheek_width=80, right_cheek_width=80):
52
+ """Create cheek area mask (1/4 area on both sides)"""
53
+ mask = np.zeros((512, 512), dtype=np.uint8)
54
+ center = 512 // 2
55
+ cv2.rectangle(mask, (0, 0), (center - left_cheek_width, 512), 255, -1) # Left cheek
56
+ cv2.rectangle(mask, (center + right_cheek_width, 0), (512, 512), 255, -1) # Right cheek
57
+ return mask
58
+
59
+ def model_init(self,
60
+ resnet_path='./models/face-parse-bisent/resnet18-5c106cde.pth',
61
+ model_pth='./models/face-parse-bisent/79999_iter.pth'):
62
+ net = BiSeNet(resnet_path)
63
+ if torch.cuda.is_available():
64
+ net.cuda()
65
+ net.load_state_dict(torch.load(model_pth))
66
+ else:
67
+ net.load_state_dict(torch.load(model_pth, map_location=torch.device('cpu')))
68
+ net.eval()
69
+ return net
70
+
71
+ def image_preprocess(self):
72
+ return transforms.Compose([
73
+ transforms.ToTensor(),
74
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
75
+ ])
76
+
77
+ def __call__(self, image, size=(512, 512), mode="raw"):
78
+ if isinstance(image, str):
79
+ image = Image.open(image)
80
+
81
+ width, height = image.size
82
+ with torch.no_grad():
83
+ image = image.resize(size, Image.BILINEAR)
84
+ img = self.preprocess(image)
85
+ if torch.cuda.is_available():
86
+ img = torch.unsqueeze(img, 0).cuda()
87
+ else:
88
+ img = torch.unsqueeze(img, 0)
89
+ out = self.net(img)[0]
90
+ parsing = out.squeeze(0).cpu().numpy().argmax(0)
91
+
92
+ # Add 14:neck, remove 10:nose and 7:8:9
93
+ if mode == "neck":
94
+ parsing[np.isin(parsing, [1, 11, 12, 13, 14])] = 255
95
+ parsing[np.where(parsing!=255)] = 0
96
+ elif mode == "jaw":
97
+ face_region = np.isin(parsing, [1])*255
98
+ face_region = face_region.astype(np.uint8)
99
+ original_dilated = cv2.dilate(face_region, self.kernel, iterations=1)
100
+ eroded = cv2.erode(original_dilated, self.cheek_kernel, iterations=2)
101
+ face_region = cv2.bitwise_and(eroded, self.cheek_mask)
102
+ face_region = cv2.bitwise_or(face_region, cv2.bitwise_and(original_dilated, ~self.cheek_mask))
103
+ parsing[(face_region==255) & (~np.isin(parsing, [10]))] = 255
104
+ parsing[np.isin(parsing, [11, 12, 13])] = 255
105
+ parsing[np.where(parsing!=255)] = 0
106
+ else:
107
+ parsing[np.isin(parsing, [1, 11, 12, 13])] = 255
108
+ parsing[np.where(parsing!=255)] = 0
109
+
110
+ parsing = Image.fromarray(parsing.astype(np.uint8))
111
+ return parsing
112
+
113
+ if __name__ == "__main__":
114
+ fp = FaceParsing()
115
+ segmap = fp('154_small.png')
116
+ segmap.save('res.png')
117
+
musetalk_integration/utils/face_parsing/model.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torchvision
9
+
10
+ from .resnet import Resnet18
11
+ # from modules.bn import InPlaceABNSync as BatchNorm2d
12
+
13
+
14
+ class ConvBNReLU(nn.Module):
15
+ def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
16
+ super(ConvBNReLU, self).__init__()
17
+ self.conv = nn.Conv2d(in_chan,
18
+ out_chan,
19
+ kernel_size = ks,
20
+ stride = stride,
21
+ padding = padding,
22
+ bias = False)
23
+ self.bn = nn.BatchNorm2d(out_chan)
24
+ self.init_weight()
25
+
26
+ def forward(self, x):
27
+ x = self.conv(x)
28
+ x = F.relu(self.bn(x))
29
+ return x
30
+
31
+ def init_weight(self):
32
+ for ly in self.children():
33
+ if isinstance(ly, nn.Conv2d):
34
+ nn.init.kaiming_normal_(ly.weight, a=1)
35
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
36
+
37
+ class BiSeNetOutput(nn.Module):
38
+ def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
39
+ super(BiSeNetOutput, self).__init__()
40
+ self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
41
+ self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
42
+ self.init_weight()
43
+
44
+ def forward(self, x):
45
+ x = self.conv(x)
46
+ x = self.conv_out(x)
47
+ return x
48
+
49
+ def init_weight(self):
50
+ for ly in self.children():
51
+ if isinstance(ly, nn.Conv2d):
52
+ nn.init.kaiming_normal_(ly.weight, a=1)
53
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
54
+
55
+ def get_params(self):
56
+ wd_params, nowd_params = [], []
57
+ for name, module in self.named_modules():
58
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
59
+ wd_params.append(module.weight)
60
+ if not module.bias is None:
61
+ nowd_params.append(module.bias)
62
+ elif isinstance(module, nn.BatchNorm2d):
63
+ nowd_params += list(module.parameters())
64
+ return wd_params, nowd_params
65
+
66
+
67
+ class AttentionRefinementModule(nn.Module):
68
+ def __init__(self, in_chan, out_chan, *args, **kwargs):
69
+ super(AttentionRefinementModule, self).__init__()
70
+ self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
71
+ self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
72
+ self.bn_atten = nn.BatchNorm2d(out_chan)
73
+ self.sigmoid_atten = nn.Sigmoid()
74
+ self.init_weight()
75
+
76
+ def forward(self, x):
77
+ feat = self.conv(x)
78
+ atten = F.avg_pool2d(feat, feat.size()[2:])
79
+ atten = self.conv_atten(atten)
80
+ atten = self.bn_atten(atten)
81
+ atten = self.sigmoid_atten(atten)
82
+ out = torch.mul(feat, atten)
83
+ return out
84
+
85
+ def init_weight(self):
86
+ for ly in self.children():
87
+ if isinstance(ly, nn.Conv2d):
88
+ nn.init.kaiming_normal_(ly.weight, a=1)
89
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
90
+
91
+
92
+ class ContextPath(nn.Module):
93
+ def __init__(self, resnet_path, *args, **kwargs):
94
+ super(ContextPath, self).__init__()
95
+ self.resnet = Resnet18(resnet_path)
96
+ self.arm16 = AttentionRefinementModule(256, 128)
97
+ self.arm32 = AttentionRefinementModule(512, 128)
98
+ self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
99
+ self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
100
+ self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
101
+
102
+ self.init_weight()
103
+
104
+ def forward(self, x):
105
+ H0, W0 = x.size()[2:]
106
+ feat8, feat16, feat32 = self.resnet(x)
107
+ H8, W8 = feat8.size()[2:]
108
+ H16, W16 = feat16.size()[2:]
109
+ H32, W32 = feat32.size()[2:]
110
+
111
+ avg = F.avg_pool2d(feat32, feat32.size()[2:])
112
+ avg = self.conv_avg(avg)
113
+ avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
114
+
115
+ feat32_arm = self.arm32(feat32)
116
+ feat32_sum = feat32_arm + avg_up
117
+ feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
118
+ feat32_up = self.conv_head32(feat32_up)
119
+
120
+ feat16_arm = self.arm16(feat16)
121
+ feat16_sum = feat16_arm + feat32_up
122
+ feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
123
+ feat16_up = self.conv_head16(feat16_up)
124
+
125
+ return feat8, feat16_up, feat32_up # x8, x8, x16
126
+
127
+ def init_weight(self):
128
+ for ly in self.children():
129
+ if isinstance(ly, nn.Conv2d):
130
+ nn.init.kaiming_normal_(ly.weight, a=1)
131
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
132
+
133
+ def get_params(self):
134
+ wd_params, nowd_params = [], []
135
+ for name, module in self.named_modules():
136
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
137
+ wd_params.append(module.weight)
138
+ if not module.bias is None:
139
+ nowd_params.append(module.bias)
140
+ elif isinstance(module, nn.BatchNorm2d):
141
+ nowd_params += list(module.parameters())
142
+ return wd_params, nowd_params
143
+
144
+
145
+ ### This is not used, since I replace this with the resnet feature with the same size
146
+ class SpatialPath(nn.Module):
147
+ def __init__(self, *args, **kwargs):
148
+ super(SpatialPath, self).__init__()
149
+ self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
150
+ self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
151
+ self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
152
+ self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
153
+ self.init_weight()
154
+
155
+ def forward(self, x):
156
+ feat = self.conv1(x)
157
+ feat = self.conv2(feat)
158
+ feat = self.conv3(feat)
159
+ feat = self.conv_out(feat)
160
+ return feat
161
+
162
+ def init_weight(self):
163
+ for ly in self.children():
164
+ if isinstance(ly, nn.Conv2d):
165
+ nn.init.kaiming_normal_(ly.weight, a=1)
166
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
167
+
168
+ def get_params(self):
169
+ wd_params, nowd_params = [], []
170
+ for name, module in self.named_modules():
171
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
172
+ wd_params.append(module.weight)
173
+ if not module.bias is None:
174
+ nowd_params.append(module.bias)
175
+ elif isinstance(module, nn.BatchNorm2d):
176
+ nowd_params += list(module.parameters())
177
+ return wd_params, nowd_params
178
+
179
+
180
+ class FeatureFusionModule(nn.Module):
181
+ def __init__(self, in_chan, out_chan, *args, **kwargs):
182
+ super(FeatureFusionModule, self).__init__()
183
+ self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
184
+ self.conv1 = nn.Conv2d(out_chan,
185
+ out_chan//4,
186
+ kernel_size = 1,
187
+ stride = 1,
188
+ padding = 0,
189
+ bias = False)
190
+ self.conv2 = nn.Conv2d(out_chan//4,
191
+ out_chan,
192
+ kernel_size = 1,
193
+ stride = 1,
194
+ padding = 0,
195
+ bias = False)
196
+ self.relu = nn.ReLU(inplace=True)
197
+ self.sigmoid = nn.Sigmoid()
198
+ self.init_weight()
199
+
200
+ def forward(self, fsp, fcp):
201
+ fcat = torch.cat([fsp, fcp], dim=1)
202
+ feat = self.convblk(fcat)
203
+ atten = F.avg_pool2d(feat, feat.size()[2:])
204
+ atten = self.conv1(atten)
205
+ atten = self.relu(atten)
206
+ atten = self.conv2(atten)
207
+ atten = self.sigmoid(atten)
208
+ feat_atten = torch.mul(feat, atten)
209
+ feat_out = feat_atten + feat
210
+ return feat_out
211
+
212
+ def init_weight(self):
213
+ for ly in self.children():
214
+ if isinstance(ly, nn.Conv2d):
215
+ nn.init.kaiming_normal_(ly.weight, a=1)
216
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
217
+
218
+ def get_params(self):
219
+ wd_params, nowd_params = [], []
220
+ for name, module in self.named_modules():
221
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
222
+ wd_params.append(module.weight)
223
+ if not module.bias is None:
224
+ nowd_params.append(module.bias)
225
+ elif isinstance(module, nn.BatchNorm2d):
226
+ nowd_params += list(module.parameters())
227
+ return wd_params, nowd_params
228
+
229
+
230
+ class BiSeNet(nn.Module):
231
+ def __init__(self, resnet_path='models/resnet18-5c106cde.pth', n_classes=19, *args, **kwargs):
232
+ super(BiSeNet, self).__init__()
233
+ self.cp = ContextPath(resnet_path)
234
+ ## here self.sp is deleted
235
+ self.ffm = FeatureFusionModule(256, 256)
236
+ self.conv_out = BiSeNetOutput(256, 256, n_classes)
237
+ self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
238
+ self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
239
+ self.init_weight()
240
+
241
+ def forward(self, x):
242
+ H, W = x.size()[2:]
243
+ feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature
244
+ feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature
245
+ feat_fuse = self.ffm(feat_sp, feat_cp8)
246
+
247
+ feat_out = self.conv_out(feat_fuse)
248
+ feat_out16 = self.conv_out16(feat_cp8)
249
+ feat_out32 = self.conv_out32(feat_cp16)
250
+
251
+ feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
252
+ feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
253
+ feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
254
+ return feat_out, feat_out16, feat_out32
255
+
256
+ def init_weight(self):
257
+ for ly in self.children():
258
+ if isinstance(ly, nn.Conv2d):
259
+ nn.init.kaiming_normal_(ly.weight, a=1)
260
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
261
+
262
+ def get_params(self):
263
+ wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
264
+ for name, child in self.named_children():
265
+ child_wd_params, child_nowd_params = child.get_params()
266
+ if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):
267
+ lr_mul_wd_params += child_wd_params
268
+ lr_mul_nowd_params += child_nowd_params
269
+ else:
270
+ wd_params += child_wd_params
271
+ nowd_params += child_nowd_params
272
+ return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
273
+
274
+
275
+ if __name__ == "__main__":
276
+ net = BiSeNet(19)
277
+ net.cuda()
278
+ net.eval()
279
+ in_ten = torch.randn(16, 3, 640, 480).cuda()
280
+ out, out16, out32 = net(in_ten)
281
+ print(out.shape)
282
+
283
+ net.get_params()
musetalk_integration/utils/face_parsing/resnet.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import torch.utils.model_zoo as modelzoo
8
+
9
+ # from modules.bn import InPlaceABNSync as BatchNorm2d
10
+
11
+ resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
12
+
13
+
14
+ def conv3x3(in_planes, out_planes, stride=1):
15
+ """3x3 convolution with padding"""
16
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
17
+ padding=1, bias=False)
18
+
19
+
20
+ class BasicBlock(nn.Module):
21
+ def __init__(self, in_chan, out_chan, stride=1):
22
+ super(BasicBlock, self).__init__()
23
+ self.conv1 = conv3x3(in_chan, out_chan, stride)
24
+ self.bn1 = nn.BatchNorm2d(out_chan)
25
+ self.conv2 = conv3x3(out_chan, out_chan)
26
+ self.bn2 = nn.BatchNorm2d(out_chan)
27
+ self.relu = nn.ReLU(inplace=True)
28
+ self.downsample = None
29
+ if in_chan != out_chan or stride != 1:
30
+ self.downsample = nn.Sequential(
31
+ nn.Conv2d(in_chan, out_chan,
32
+ kernel_size=1, stride=stride, bias=False),
33
+ nn.BatchNorm2d(out_chan),
34
+ )
35
+
36
+ def forward(self, x):
37
+ residual = self.conv1(x)
38
+ residual = F.relu(self.bn1(residual))
39
+ residual = self.conv2(residual)
40
+ residual = self.bn2(residual)
41
+
42
+ shortcut = x
43
+ if self.downsample is not None:
44
+ shortcut = self.downsample(x)
45
+
46
+ out = shortcut + residual
47
+ out = self.relu(out)
48
+ return out
49
+
50
+
51
+ def create_layer_basic(in_chan, out_chan, bnum, stride=1):
52
+ layers = [BasicBlock(in_chan, out_chan, stride=stride)]
53
+ for i in range(bnum-1):
54
+ layers.append(BasicBlock(out_chan, out_chan, stride=1))
55
+ return nn.Sequential(*layers)
56
+
57
+
58
+ class Resnet18(nn.Module):
59
+ def __init__(self, model_path):
60
+ super(Resnet18, self).__init__()
61
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
62
+ bias=False)
63
+ self.bn1 = nn.BatchNorm2d(64)
64
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
65
+ self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
66
+ self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
67
+ self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
68
+ self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
69
+ self.init_weight(model_path)
70
+
71
+ def forward(self, x):
72
+ x = self.conv1(x)
73
+ x = F.relu(self.bn1(x))
74
+ x = self.maxpool(x)
75
+
76
+ x = self.layer1(x)
77
+ feat8 = self.layer2(x) # 1/8
78
+ feat16 = self.layer3(feat8) # 1/16
79
+ feat32 = self.layer4(feat16) # 1/32
80
+ return feat8, feat16, feat32
81
+
82
+ def init_weight(self, model_path):
83
+ state_dict = torch.load(model_path) #modelzoo.load_url(resnet18_url)
84
+ self_state_dict = self.state_dict()
85
+ for k, v in state_dict.items():
86
+ if 'fc' in k: continue
87
+ self_state_dict.update({k: v})
88
+ self.load_state_dict(self_state_dict)
89
+
90
+ def get_params(self):
91
+ wd_params, nowd_params = [], []
92
+ for name, module in self.named_modules():
93
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
94
+ wd_params.append(module.weight)
95
+ if not module.bias is None:
96
+ nowd_params.append(module.bias)
97
+ elif isinstance(module, nn.BatchNorm2d):
98
+ nowd_params += list(module.parameters())
99
+ return wd_params, nowd_params
100
+
101
+
102
+ if __name__ == "__main__":
103
+ net = Resnet18()
104
+ x = torch.randn(16, 3, 224, 224)
105
+ out = net(x)
106
+ print(out[0].size())
107
+ print(out[1].size())
108
+ print(out[2].size())
109
+ net.get_params()
musetalk_integration/whisper/__init__.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import io
3
+ import os
4
+ import urllib
5
+ import warnings
6
+ from typing import List, Optional, Union
7
+
8
+ import torch
9
+ from tqdm import tqdm
10
+
11
+ from .audio import load_audio, log_mel_spectrogram, pad_or_trim
12
+ from .decoding import DecodingOptions, DecodingResult, decode, detect_language
13
+ from .model import Whisper, ModelDimensions
14
+ from .transcribe import transcribe
15
+
16
+
17
+ _MODELS = {
18
+ "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
19
+ "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
20
+ "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
21
+ "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
22
+ "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
23
+ "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
24
+ "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
25
+ "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
26
+ "large": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large.pt",
27
+ "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
28
+ "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
29
+ "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
30
+ }
31
+
32
+
33
+ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
34
+ os.makedirs(root, exist_ok=True)
35
+
36
+ expected_sha256 = url.split("/")[-2]
37
+ download_target = os.path.join(root, os.path.basename(url))
38
+
39
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
40
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
41
+
42
+ if os.path.isfile(download_target):
43
+ model_bytes = open(download_target, "rb").read()
44
+ if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
45
+ return model_bytes if in_memory else download_target
46
+ else:
47
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
48
+
49
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
50
+ with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
51
+ while True:
52
+ buffer = source.read(8192)
53
+ if not buffer:
54
+ break
55
+
56
+ output.write(buffer)
57
+ loop.update(len(buffer))
58
+
59
+ model_bytes = open(download_target, "rb").read()
60
+ if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
61
+ raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model.")
62
+
63
+ return model_bytes if in_memory else download_target
64
+
65
+
66
+ def available_models() -> List[str]:
67
+ """Returns the names of available models"""
68
+ return list(_MODELS.keys())
69
+
70
+
71
+ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False) -> Whisper:
72
+ """
73
+ Load a Whisper ASR model
74
+
75
+ Parameters
76
+ ----------
77
+ name : str
78
+ one of the official model names listed by `whisper.available_models()`, or
79
+ path to a model checkpoint containing the model dimensions and the model state_dict.
80
+ device : Union[str, torch.device]
81
+ the PyTorch device to put the model into
82
+ download_root: str
83
+ path to download the model files; by default, it uses "~/.cache/whisper"
84
+ in_memory: bool
85
+ whether to preload the model weights into host memory
86
+
87
+ Returns
88
+ -------
89
+ model : Whisper
90
+ The Whisper ASR model instance
91
+ """
92
+
93
+ if device is None:
94
+ device = "cuda" if torch.cuda.is_available() else "cpu"
95
+ if download_root is None:
96
+ download_root = os.getenv(
97
+ "XDG_CACHE_HOME",
98
+ os.path.join(os.path.expanduser("~"), ".cache", "whisper")
99
+ )
100
+
101
+ if name in _MODELS:
102
+ checkpoint_file = _download(_MODELS[name], download_root, in_memory)
103
+ elif os.path.isfile(name):
104
+ checkpoint_file = open(name, "rb").read() if in_memory else name
105
+ else:
106
+ raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
107
+
108
+ with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp:
109
+ checkpoint = torch.load(fp, map_location=device)
110
+ del checkpoint_file
111
+
112
+ dims = ModelDimensions(**checkpoint["dims"])
113
+ model = Whisper(dims)
114
+ model.load_state_dict(checkpoint["model_state_dict"])
115
+
116
+ return model.to(device)
musetalk_integration/whisper/__main__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .transcribe import cli
2
+
3
+
4
+ cli()
musetalk_integration/whisper/audio.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from functools import lru_cache
3
+ from typing import Union
4
+
5
+ import ffmpeg
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ from .utils import exact_div
11
+
12
+ # hard-coded audio hyperparameters
13
+ SAMPLE_RATE = 16000
14
+ N_FFT = 400
15
+ N_MELS = 80
16
+ HOP_LENGTH = 160
17
+ CHUNK_LENGTH = 30
18
+ N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk
19
+ N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input
20
+
21
+
22
+ def load_audio(file: str, sr: int = SAMPLE_RATE):
23
+ """
24
+ Open an audio file and read as mono waveform, resampling as necessary
25
+
26
+ Parameters
27
+ ----------
28
+ file: str
29
+ The audio file to open
30
+
31
+ sr: int
32
+ The sample rate to resample the audio if necessary
33
+
34
+ Returns
35
+ -------
36
+ A NumPy array containing the audio waveform, in float32 dtype.
37
+ """
38
+ try:
39
+ # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
40
+ # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
41
+ out, _ = (
42
+ ffmpeg.input(file, threads=0)
43
+ .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
44
+ .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
45
+ )
46
+ except ffmpeg.Error as e:
47
+ raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
48
+
49
+ return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
50
+
51
+
52
+ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
53
+ """
54
+ Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
55
+ """
56
+ if torch.is_tensor(array):
57
+ if array.shape[axis] > length:
58
+ array = array.index_select(dim=axis, index=torch.arange(length))
59
+
60
+ if array.shape[axis] < length:
61
+ pad_widths = [(0, 0)] * array.ndim
62
+ pad_widths[axis] = (0, length - array.shape[axis])
63
+ array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
64
+ else:
65
+ if array.shape[axis] > length:
66
+ array = array.take(indices=range(length), axis=axis)
67
+
68
+ if array.shape[axis] < length:
69
+ pad_widths = [(0, 0)] * array.ndim
70
+ pad_widths[axis] = (0, length - array.shape[axis])
71
+ array = np.pad(array, pad_widths)
72
+
73
+ return array
74
+
75
+
76
+ @lru_cache(maxsize=None)
77
+ def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
78
+ """
79
+ load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
80
+ Allows decoupling librosa dependency; saved using:
81
+
82
+ np.savez_compressed(
83
+ "mel_filters.npz",
84
+ mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
85
+ )
86
+ """
87
+ assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
88
+ with np.load(os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")) as f:
89
+ return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
90
+
91
+
92
+ def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):
93
+ """
94
+ Compute the log-Mel spectrogram of
95
+
96
+ Parameters
97
+ ----------
98
+ audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
99
+ The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
100
+
101
+ n_mels: int
102
+ The number of Mel-frequency filters, only 80 is supported
103
+
104
+ Returns
105
+ -------
106
+ torch.Tensor, shape = (80, n_frames)
107
+ A Tensor that contains the Mel spectrogram
108
+ """
109
+ if not torch.is_tensor(audio):
110
+ if isinstance(audio, str):
111
+ audio = load_audio(audio)
112
+ audio = torch.from_numpy(audio)
113
+
114
+ window = torch.hann_window(N_FFT).to(audio.device)
115
+ stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
116
+
117
+ magnitudes = stft[:, :-1].abs() ** 2
118
+
119
+ filters = mel_filters(audio.device, n_mels)
120
+ mel_spec = filters @ magnitudes
121
+
122
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10()
123
+ log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
124
+ log_spec = (log_spec + 4.0) / 4.0
125
+ return log_spec
musetalk_integration/whisper/decoding.py ADDED
@@ -0,0 +1,729 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import Dict, List, Tuple, Iterable, Optional, Sequence, Union, TYPE_CHECKING
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import Tensor
8
+ from torch.distributions import Categorical
9
+
10
+ from .audio import CHUNK_LENGTH
11
+ from .tokenizer import Tokenizer, get_tokenizer
12
+ from .utils import compression_ratio
13
+
14
+ if TYPE_CHECKING:
15
+ from .model import Whisper
16
+
17
+
18
+ @torch.no_grad()
19
+ def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None) -> Tuple[Tensor, List[dict]]:
20
+ """
21
+ Detect the spoken language in the audio, and return them as list of strings, along with the ids
22
+ of the most probable language tokens and the probability distribution over all language tokens.
23
+ This is performed outside the main decode loop in order to not interfere with kv-caching.
24
+
25
+ Returns
26
+ -------
27
+ language_tokens : Tensor, shape = (n_audio,)
28
+ ids of the most probable language tokens, which appears after the startoftranscript token.
29
+ language_probs : List[Dict[str, float]], length = n_audio
30
+ list of dictionaries containing the probability distribution over all languages.
31
+ """
32
+ if tokenizer is None:
33
+ tokenizer = get_tokenizer(model.is_multilingual)
34
+ if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence:
35
+ raise ValueError(f"This model doesn't have language tokens so it can't perform lang id")
36
+
37
+ single = mel.ndim == 2
38
+ if single:
39
+ mel = mel.unsqueeze(0)
40
+
41
+ # skip encoder forward pass if already-encoded audio features were given
42
+ if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
43
+ mel = model.encoder(mel)
44
+
45
+ # forward pass using a single token, startoftranscript
46
+ n_audio = mel.shape[0]
47
+ x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
48
+ logits = model.logits(x, mel)[:, 0]
49
+
50
+ # collect detected languages; suppress all non-language tokens
51
+ mask = torch.ones(logits.shape[-1], dtype=torch.bool)
52
+ mask[list(tokenizer.all_language_tokens)] = False
53
+ logits[:, mask] = -np.inf
54
+ language_tokens = logits.argmax(dim=-1)
55
+ language_token_probs = logits.softmax(dim=-1).cpu()
56
+ language_probs = [
57
+ {
58
+ c: language_token_probs[i, j].item()
59
+ for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
60
+ }
61
+ for i in range(n_audio)
62
+ ]
63
+
64
+ if single:
65
+ language_tokens = language_tokens[0]
66
+ language_probs = language_probs[0]
67
+
68
+ return language_tokens, language_probs
69
+
70
+
71
+ @dataclass(frozen=True)
72
+ class DecodingOptions:
73
+ task: str = "transcribe" # whether to perform X->X "transcribe" or X->English "translate"
74
+ language: Optional[str] = None # language that the audio is in; uses detected language if None
75
+
76
+ # sampling-related options
77
+ temperature: float = 0.0
78
+ sample_len: Optional[int] = None # maximum number of tokens to sample
79
+ best_of: Optional[int] = None # number of independent samples to collect, when t > 0
80
+ beam_size: Optional[int] = None # number of beams in beam search, when t == 0
81
+ patience: Optional[float] = None # patience in beam search (https://arxiv.org/abs/2204.05424)
82
+
83
+ # options for ranking generations (either beams or best-of-N samples)
84
+ length_penalty: Optional[float] = None # "alpha" in Google NMT, None defaults to length norm
85
+
86
+ # prompt, prefix, and token suppression
87
+ prompt: Optional[Union[str, List[int]]] = None # text or tokens for the previous context
88
+ prefix: Optional[Union[str, List[int]]] = None # text or tokens to prefix the current context
89
+ suppress_blank: bool = True # this will suppress blank outputs
90
+
91
+ # list of tokens ids (or comma-separated token ids) to suppress
92
+ # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
93
+ suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
94
+
95
+ # timestamp sampling options
96
+ without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
97
+ max_initial_timestamp: Optional[float] = 1.0 # the initial timestamp cannot be later than this
98
+
99
+ # implementation details
100
+ fp16: bool = True # use fp16 for most of the calculation
101
+
102
+
103
+ @dataclass(frozen=True)
104
+ class DecodingResult:
105
+ audio_features: Tensor
106
+ language: str
107
+ encoder_embeddings: np.ndarray
108
+ decoder_embeddings: np.ndarray
109
+ language_probs: Optional[Dict[str, float]] = None
110
+ tokens: List[int] = field(default_factory=list)
111
+ text: str = ""
112
+ avg_logprob: float = np.nan
113
+ no_speech_prob: float = np.nan
114
+ temperature: float = np.nan
115
+ compression_ratio: float = np.nan
116
+
117
+
118
+ class Inference:
119
+ def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
120
+ """Perform a forward pass on the decoder and return per-token logits"""
121
+ raise NotImplementedError
122
+
123
+ def rearrange_kv_cache(self, source_indices) -> None:
124
+ """Update the key-value cache according to the updated beams"""
125
+ raise NotImplementedError
126
+
127
+ def cleanup_caching(self) -> None:
128
+ """Clean up any resources or hooks after decoding is finished"""
129
+ pass
130
+
131
+
132
+ class PyTorchInference(Inference):
133
+ def __init__(self, model: "Whisper", initial_token_length: int):
134
+ self.model: "Whisper" = model
135
+ self.initial_token_length = initial_token_length
136
+ self.kv_cache = {}
137
+ self.hooks = []
138
+
139
+ def logits(self, tokens: Tensor, audio_features: Tensor, include_embeddings=False) -> Tensor:
140
+ if not self.kv_cache:
141
+ self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
142
+
143
+ if tokens.shape[-1] > self.initial_token_length:
144
+ # only need to use the last token except in the first forward pass
145
+ tokens = tokens[:, -1:]
146
+
147
+ return_val = self.model.decoder(tokens, audio_features,
148
+ kv_cache=self.kv_cache, include_embeddings=include_embeddings)
149
+ return return_val
150
+
151
+ def cleanup_caching(self):
152
+ for hook in self.hooks:
153
+ hook.remove()
154
+
155
+ self.kv_cache = {}
156
+ self.hooks = []
157
+
158
+ def rearrange_kv_cache(self, source_indices):
159
+ for module, tensor in self.kv_cache.items():
160
+ # update the key/value cache to contain the selected sequences
161
+ self.kv_cache[module] = tensor[source_indices].detach()
162
+
163
+
164
+ class SequenceRanker:
165
+ def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]) -> List[int]:
166
+ """
167
+ Given a list of groups of samples and their cumulative log probabilities,
168
+ return the indices of the samples in each group to select as the final result
169
+ """
170
+ raise NotImplementedError
171
+
172
+
173
+ class MaximumLikelihoodRanker(SequenceRanker):
174
+ """
175
+ Select the sample with the highest log probabilities, penalized using either
176
+ a simple length normalization or Google NMT paper's length penalty
177
+ """
178
+
179
+ def __init__(self, length_penalty: Optional[float]):
180
+ self.length_penalty = length_penalty
181
+
182
+ def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]):
183
+ def scores(logprobs, lengths):
184
+ result = []
185
+ for logprob, length in zip(logprobs, lengths):
186
+ if self.length_penalty is None:
187
+ penalty = length
188
+ else:
189
+ # from the Google NMT paper
190
+ penalty = ((5 + length) / 6) ** self.length_penalty
191
+ result.append(logprob / penalty)
192
+ return result
193
+
194
+ # get the sequence with the highest score
195
+ lengths = [[len(t) for t in s] for s in tokens]
196
+ return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
197
+
198
+
199
+ class TokenDecoder:
200
+ def reset(self):
201
+ """Initialize any stateful variables for decoding a new sequence"""
202
+
203
+ def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
204
+ """Specify how to select the next token, based on the current trace and logits
205
+
206
+ Parameters
207
+ ----------
208
+ tokens : Tensor, shape = (n_batch, current_sequence_length)
209
+ all tokens in the context so far, including the prefix and sot_sequence tokens
210
+
211
+ logits : Tensor, shape = (n_batch, vocab_size)
212
+ per-token logits of the probability distribution at the current step
213
+
214
+ sum_logprobs : Tensor, shape = (n_batch)
215
+ cumulative log probabilities for each sequence
216
+
217
+ Returns
218
+ -------
219
+ tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
220
+ the tokens, appended with the selected next token
221
+
222
+ completed : bool
223
+ True if all sequences has reached the end of text
224
+
225
+ """
226
+ raise NotImplementedError
227
+
228
+ def finalize(
229
+ self, tokens: Tensor, sum_logprobs: Tensor
230
+ ) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]:
231
+ """Finalize search and return the final candidate sequences
232
+
233
+ Parameters
234
+ ----------
235
+ tokens : Tensor, shape = (n_audio, n_group, current_sequence_length)
236
+ all tokens in the context so far, including the prefix and sot_sequence
237
+
238
+ sum_logprobs : Tensor, shape = (n_audio, n_group)
239
+ cumulative log probabilities for each sequence
240
+
241
+ Returns
242
+ -------
243
+ tokens : Sequence[Sequence[Tensor]], length = n_audio
244
+ sequence of Tensors containing candidate token sequences, for each audio input
245
+
246
+ sum_logprobs : List[List[float]], length = n_audio
247
+ sequence of cumulative log probabilities corresponding to the above
248
+
249
+ """
250
+ raise NotImplementedError
251
+
252
+
253
+ class GreedyDecoder(TokenDecoder):
254
+ def __init__(self, temperature: float, eot: int):
255
+ self.temperature = temperature
256
+ self.eot = eot
257
+
258
+ def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
259
+ temperature = self.temperature
260
+ if temperature == 0:
261
+ next_tokens = logits.argmax(dim=-1)
262
+ else:
263
+ next_tokens = Categorical(logits=logits / temperature).sample()
264
+
265
+ logprobs = F.log_softmax(logits.float(), dim=-1)
266
+ current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
267
+ sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)
268
+
269
+ next_tokens[tokens[:, -1] == self.eot] = self.eot
270
+ tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)
271
+
272
+ completed = (tokens[:, -1] == self.eot).all()
273
+ return tokens, completed
274
+
275
+ def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
276
+ # make sure each sequence has at least one EOT token at the end
277
+ tokens = F.pad(tokens, (0, 1), value=self.eot)
278
+ return tokens, sum_logprobs.tolist()
279
+
280
+
281
+ class BeamSearchDecoder(TokenDecoder):
282
+ def __init__(self, beam_size: int, eot: int, inference: Inference, patience: Optional[float] = None):
283
+ self.beam_size = beam_size
284
+ self.eot = eot
285
+ self.inference = inference
286
+ self.patience = patience or 1.0
287
+ self.max_candidates: int = round(beam_size * self.patience)
288
+ self.finished_sequences = None
289
+
290
+ assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})"
291
+
292
+ def reset(self):
293
+ self.finished_sequences = None
294
+
295
+ def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
296
+ if tokens.shape[0] % self.beam_size != 0:
297
+ raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
298
+
299
+ n_audio = tokens.shape[0] // self.beam_size
300
+ if self.finished_sequences is None: # for the first update
301
+ self.finished_sequences = [{} for _ in range(n_audio)]
302
+
303
+ logprobs = F.log_softmax(logits.float(), dim=-1)
304
+ next_tokens, source_indices, finished_sequences = [], [], []
305
+ for i in range(n_audio):
306
+ scores, sources, finished = {}, {}, {}
307
+
308
+ # STEP 1: calculate the cumulative log probabilities for possible candidates
309
+ for j in range(self.beam_size):
310
+ idx = i * self.beam_size + j
311
+ prefix = tokens[idx].tolist()
312
+ for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
313
+ new_logprob = (sum_logprobs[idx] + logprob).item()
314
+ sequence = tuple(prefix + [token.item()])
315
+ scores[sequence] = new_logprob
316
+ sources[sequence] = idx
317
+
318
+ # STEP 2: rank the candidates and keep the top beam_size sequences for each audio
319
+ saved = 0
320
+ for sequence in sorted(scores, key=scores.get, reverse=True):
321
+ if sequence[-1] == self.eot:
322
+ finished[sequence] = scores[sequence]
323
+ else:
324
+ sum_logprobs[len(next_tokens)] = scores[sequence]
325
+ next_tokens.append(sequence)
326
+ source_indices.append(sources[sequence])
327
+
328
+ saved += 1
329
+ if saved == self.beam_size:
330
+ break
331
+
332
+ finished_sequences.append(finished)
333
+
334
+ tokens = torch.tensor(next_tokens, device=tokens.device)
335
+ self.inference.rearrange_kv_cache(source_indices)
336
+
337
+ # add newly finished sequences to self.finished_sequences
338
+ assert len(self.finished_sequences) == len(finished_sequences)
339
+ for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences):
340
+ for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
341
+ if len(previously_finished) >= self.max_candidates:
342
+ break # the candidate list is full
343
+ previously_finished[seq] = newly_finished[seq]
344
+
345
+ # mark as completed if all audio has enough number of samples
346
+ completed = all(
347
+ len(sequences) >= self.max_candidates for sequences in self.finished_sequences
348
+ )
349
+ return tokens, completed
350
+
351
+ def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
352
+ # collect all finished sequences, including patience, and add unfinished ones if not enough
353
+ sum_logprobs = sum_logprobs.cpu()
354
+ for i, sequences in enumerate(self.finished_sequences):
355
+ if len(sequences) < self.beam_size: # when not enough sequences are finished
356
+ for j in list(np.argsort(sum_logprobs[i]))[::-1]:
357
+ sequence = preceding_tokens[i, j].tolist() + [self.eot]
358
+ sequences[tuple(sequence)] = sum_logprobs[i][j].item()
359
+ if len(sequences) >= self.beam_size:
360
+ break
361
+
362
+ tokens: List[List[Tensor]] = [
363
+ [torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences
364
+ ]
365
+ sum_logprobs: List[List[float]] = [
366
+ list(sequences.values()) for sequences in self.finished_sequences
367
+ ]
368
+ return tokens, sum_logprobs
369
+
370
+
371
+ class LogitFilter:
372
+ def apply(self, logits: Tensor, tokens: Tensor) -> None:
373
+ """Apply any filtering or masking to logits in-place
374
+
375
+ Parameters
376
+ ----------
377
+ logits : Tensor, shape = (n_batch, vocab_size)
378
+ per-token logits of the probability distribution at the current step
379
+
380
+ tokens : Tensor, shape = (n_batch, current_sequence_length)
381
+ all tokens in the context so far, including the prefix and sot_sequence tokens
382
+
383
+ """
384
+ raise NotImplementedError
385
+
386
+
387
+ class SuppressBlank(LogitFilter):
388
+ def __init__(self, tokenizer: Tokenizer, sample_begin: int):
389
+ self.tokenizer = tokenizer
390
+ self.sample_begin = sample_begin
391
+
392
+ def apply(self, logits: Tensor, tokens: Tensor):
393
+ if tokens.shape[1] == self.sample_begin:
394
+ logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
395
+
396
+
397
+ class SuppressTokens(LogitFilter):
398
+ def __init__(self, suppress_tokens: Sequence[int]):
399
+ self.suppress_tokens = list(suppress_tokens)
400
+
401
+ def apply(self, logits: Tensor, tokens: Tensor):
402
+ logits[:, self.suppress_tokens] = -np.inf
403
+
404
+
405
+ class ApplyTimestampRules(LogitFilter):
406
+ def __init__(
407
+ self, tokenizer: Tokenizer, sample_begin: int, max_initial_timestamp_index: Optional[int]
408
+ ):
409
+ self.tokenizer = tokenizer
410
+ self.sample_begin = sample_begin
411
+ self.max_initial_timestamp_index = max_initial_timestamp_index
412
+
413
+ def apply(self, logits: Tensor, tokens: Tensor):
414
+ # suppress <|notimestamps|> which is handled by without_timestamps
415
+ if self.tokenizer.no_timestamps is not None:
416
+ logits[:, self.tokenizer.no_timestamps] = -np.inf
417
+
418
+ # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
419
+ for k in range(tokens.shape[0]):
420
+ seq = [t for t in tokens[k, self.sample_begin :].tolist()]
421
+ last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
422
+ penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
423
+
424
+ if last_was_timestamp:
425
+ if penultimate_was_timestamp: # has to be non-timestamp
426
+ logits[k, self.tokenizer.timestamp_begin :] = -np.inf
427
+ else: # cannot be normal text tokens
428
+ logits[k, : self.tokenizer.eot] = -np.inf
429
+
430
+ # apply the `max_initial_timestamp` option
431
+ if tokens.shape[1] == self.sample_begin and self.max_initial_timestamp_index is not None:
432
+ last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
433
+ logits[:, last_allowed + 1 :] = -np.inf
434
+
435
+ # if sum of probability over timestamps is above any other token, sample timestamp
436
+ logprobs = F.log_softmax(logits.float(), dim=-1)
437
+ for k in range(tokens.shape[0]):
438
+ timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(dim=-1)
439
+ max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
440
+ if timestamp_logprob > max_text_token_logprob:
441
+ logits[k, : self.tokenizer.timestamp_begin] = -np.inf
442
+
443
+
444
+ class DecodingTask:
445
+ inference: Inference
446
+ sequence_ranker: SequenceRanker
447
+ decoder: TokenDecoder
448
+ logit_filters: List[LogitFilter]
449
+
450
+ def __init__(self, model: "Whisper", options: DecodingOptions):
451
+ self.model = model
452
+
453
+ language = options.language or "en"
454
+ tokenizer = get_tokenizer(model.is_multilingual, language=language, task=options.task)
455
+ self.tokenizer: Tokenizer = tokenizer
456
+ self.options: DecodingOptions = self._verify_options(options)
457
+
458
+ self.n_group: int = options.beam_size or options.best_of or 1
459
+ self.n_ctx: int = model.dims.n_text_ctx
460
+ self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
461
+
462
+ self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
463
+ if self.options.without_timestamps:
464
+ self.sot_sequence = tokenizer.sot_sequence_including_notimestamps
465
+
466
+ self.initial_tokens: Tuple[int] = self._get_initial_tokens()
467
+ self.sample_begin: int = len(self.initial_tokens)
468
+ self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
469
+
470
+ # inference: implements the forward pass through the decoder, including kv caching
471
+ self.inference = PyTorchInference(model, len(self.initial_tokens))
472
+
473
+ # sequence ranker: implements how to rank a group of sampled sequences
474
+ self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
475
+
476
+ # decoder: implements how to select the next tokens, given the autoregressive distribution
477
+ if options.beam_size is not None:
478
+ self.decoder = BeamSearchDecoder(
479
+ options.beam_size, tokenizer.eot, self.inference, options.patience
480
+ )
481
+ else:
482
+ self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
483
+
484
+ # logit filters: applies various rules to suppress or penalize certain tokens
485
+ self.logit_filters = []
486
+ if self.options.suppress_blank:
487
+ self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
488
+ if self.options.suppress_tokens:
489
+ self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
490
+ if not options.without_timestamps:
491
+ precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
492
+ max_initial_timestamp_index = None
493
+ if options.max_initial_timestamp:
494
+ max_initial_timestamp_index = round(self.options.max_initial_timestamp / precision)
495
+ self.logit_filters.append(
496
+ ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index)
497
+ )
498
+
499
+ def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
500
+ if options.beam_size is not None and options.best_of is not None:
501
+ raise ValueError("beam_size and best_of can't be given together")
502
+ if options.temperature == 0:
503
+ if options.best_of is not None:
504
+ raise ValueError("best_of with greedy sampling (T=0) is not compatible")
505
+ if options.patience is not None and options.beam_size is None:
506
+ raise ValueError("patience requires beam_size to be given")
507
+ if options.length_penalty is not None and not (0 <= options.length_penalty <= 1):
508
+ raise ValueError("length_penalty (alpha) should be a value between 0 and 1")
509
+
510
+ return options
511
+
512
+ def _get_initial_tokens(self) -> Tuple[int]:
513
+ tokens = list(self.sot_sequence)
514
+ prefix = self.options.prefix
515
+ prompt = self.options.prompt
516
+
517
+ if prefix:
518
+ prefix_tokens = (
519
+ self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix
520
+ )
521
+ if self.sample_len is not None:
522
+ max_prefix_len = self.n_ctx // 2 - self.sample_len
523
+ prefix_tokens = prefix_tokens[-max_prefix_len:]
524
+ tokens = tokens + prefix_tokens
525
+
526
+ if prompt:
527
+ prompt_tokens = (
528
+ self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt
529
+ )
530
+ tokens = [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2 - 1) :] + tokens
531
+
532
+ return tuple(tokens)
533
+
534
+ def _get_suppress_tokens(self) -> Tuple[int]:
535
+ suppress_tokens = self.options.suppress_tokens
536
+
537
+ if isinstance(suppress_tokens, str):
538
+ suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
539
+
540
+ if -1 in suppress_tokens:
541
+ suppress_tokens = [t for t in suppress_tokens if t >= 0]
542
+ suppress_tokens.extend(self.tokenizer.non_speech_tokens)
543
+ elif suppress_tokens is None or len(suppress_tokens) == 0:
544
+ suppress_tokens = [] # interpret empty string as an empty list
545
+ else:
546
+ assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
547
+
548
+ suppress_tokens.extend(
549
+ [self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
550
+ )
551
+ if self.tokenizer.no_speech is not None:
552
+ # no-speech probability is collected separately
553
+ suppress_tokens.append(self.tokenizer.no_speech)
554
+
555
+ return tuple(sorted(set(suppress_tokens)))
556
+
557
+ def _get_audio_features(self, mel: Tensor, include_embeddings: bool = False):
558
+ if self.options.fp16:
559
+ mel = mel.half()
560
+
561
+ if mel.shape[-2:] == (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
562
+ # encoded audio features are given; skip audio encoding
563
+ audio_features = mel
564
+ else:
565
+ result = self.model.encoder(mel, include_embeddings)
566
+ if include_embeddings:
567
+ audio_features, embeddings = result
568
+ else:
569
+ audio_features = result
570
+
571
+ if audio_features.dtype != (torch.float16 if self.options.fp16 else torch.float32):
572
+ return TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}")
573
+
574
+ if include_embeddings:
575
+ return audio_features, embeddings
576
+ else:
577
+ return audio_features
578
+
579
+ def _detect_language(self, audio_features: Tensor, tokens: Tensor):
580
+ languages = [self.options.language] * audio_features.shape[0]
581
+ lang_probs = None
582
+
583
+ if self.options.language is None or self.options.task == "lang_id":
584
+ lang_tokens, lang_probs = self.model.detect_language(audio_features, self.tokenizer)
585
+ languages = [max(probs, key=probs.get) for probs in lang_probs]
586
+ if self.options.language is None:
587
+ tokens[:, self.sot_index + 1] = lang_tokens # write language tokens
588
+
589
+ return languages, lang_probs
590
+
591
+ def _main_loop(self, audio_features: Tensor, tokens: Tensor):
592
+ assert audio_features.shape[0] == tokens.shape[0]
593
+ n_batch = tokens.shape[0]
594
+ sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
595
+ no_speech_probs = [np.nan] * n_batch
596
+
597
+ try:
598
+ embeddings = []
599
+ for i in range(self.sample_len):
600
+ logits, token_embeddings = self.inference.logits(tokens, audio_features, include_embeddings=True)
601
+
602
+ if i == 0 and self.tokenizer.no_speech is not None: # save no_speech_probs
603
+ probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
604
+ no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
605
+
606
+ # now we need to consider the logits at the last token only
607
+ logits = logits[:, -1]
608
+ token_embeddings = token_embeddings[:, :, -1]
609
+
610
+ # Append embeddings together
611
+ embeddings.append(token_embeddings)
612
+
613
+ # apply the logit filters, e.g. for suppressing or applying penalty to
614
+ for logit_filter in self.logit_filters:
615
+ logit_filter.apply(logits, tokens)
616
+
617
+ # expand the tokens tensor with the selected next tokens
618
+ tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
619
+
620
+ if completed or tokens.shape[-1] > self.n_ctx:
621
+ break
622
+ finally:
623
+ if completed:
624
+ embeddings = embeddings[:-1]
625
+ embeddings = np.stack(embeddings, 2)
626
+ self.inference.cleanup_caching()
627
+
628
+ return tokens, sum_logprobs, no_speech_probs, embeddings
629
+
630
+ @torch.no_grad()
631
+ def run(self, mel: Tensor) -> List[DecodingResult]:
632
+ self.decoder.reset()
633
+ tokenizer: Tokenizer = self.tokenizer
634
+ n_audio: int = mel.shape[0]
635
+
636
+ # encoder forward pass
637
+ forward_pass: Tuple[Tensor, np.ndarray] = self._get_audio_features(mel, include_embeddings=True)
638
+ audio_features, encoder_embeddings = forward_pass
639
+ tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
640
+
641
+ # detect language if requested, overwriting the language token
642
+ languages, language_probs = self._detect_language(audio_features, tokens)
643
+ if self.options.task == "lang_id":
644
+ return [
645
+ DecodingResult(audio_features=features, language=language, language_probs=probs)
646
+ for features, language, probs in zip(audio_features, languages, language_probs)
647
+ ]
648
+
649
+ # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
650
+ audio_features = audio_features.repeat_interleave(self.n_group, dim=0)
651
+ tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
652
+
653
+ # call the main sampling loop
654
+ tokens, sum_logprobs, no_speech_probs, decoder_embeddings = self._main_loop(audio_features, tokens)
655
+
656
+ # reshape the tensors to have (n_audio, n_group) as the first two dimensions
657
+ audio_features = audio_features[:: self.n_group]
658
+ no_speech_probs = no_speech_probs[:: self.n_group]
659
+ assert audio_features.shape[0] == len(no_speech_probs) == n_audio
660
+
661
+ tokens = tokens.reshape(n_audio, self.n_group, -1)
662
+ sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
663
+
664
+ # get the final candidates for each group, and slice between the first sampled token and EOT
665
+ tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
666
+ tokens: List[List[Tensor]] = [
667
+ [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens
668
+ ]
669
+
670
+ # select the top-ranked sample in each group
671
+ selected = self.sequence_ranker.rank(tokens, sum_logprobs)
672
+ tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
673
+ texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
674
+
675
+ sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
676
+ avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]
677
+
678
+ fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs)
679
+ if len(set(map(len, fields))) != 1:
680
+ raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
681
+
682
+ return [
683
+ DecodingResult(
684
+ audio_features=features,
685
+ language=language,
686
+ tokens=tokens,
687
+ text=text,
688
+ avg_logprob=avg_logprob,
689
+ no_speech_prob=no_speech_prob,
690
+ temperature=self.options.temperature,
691
+ compression_ratio=compression_ratio(text),
692
+ encoder_embeddings=encoder_embeddings,
693
+ decoder_embeddings=decoder_embeddings
694
+ )
695
+ for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)
696
+ ]
697
+
698
+
699
+ @torch.no_grad()
700
+ def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()) -> Union[DecodingResult, List[DecodingResult]]:
701
+ """
702
+ Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
703
+
704
+ Parameters
705
+ ----------
706
+ model: Whisper
707
+ the Whisper model instance
708
+
709
+ mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
710
+ A tensor containing the Mel spectrogram(s)
711
+
712
+ options: DecodingOptions
713
+ A dataclass that contains all necessary options for decoding 30-second segments
714
+
715
+ Returns
716
+ -------
717
+ result: Union[DecodingResult, List[DecodingResult]]
718
+ The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
719
+ """
720
+ single = mel.ndim == 2
721
+ if single:
722
+ mel = mel.unsqueeze(0)
723
+
724
+ result = DecodingTask(model, options).run(mel)
725
+
726
+ if single:
727
+ result = result[0]
728
+
729
+ return result
musetalk_integration/whisper/model.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Dict
3
+ from typing import Iterable, Optional
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import Tensor
9
+ from torch import nn
10
+
11
+ from .transcribe import transcribe as transcribe_function
12
+ from .decoding import detect_language as detect_language_function, decode as decode_function
13
+
14
+
15
+ @dataclass
16
+ class ModelDimensions:
17
+ n_mels: int
18
+ n_audio_ctx: int
19
+ n_audio_state: int
20
+ n_audio_head: int
21
+ n_audio_layer: int
22
+ n_vocab: int
23
+ n_text_ctx: int
24
+ n_text_state: int
25
+ n_text_head: int
26
+ n_text_layer: int
27
+
28
+
29
+ class LayerNorm(nn.LayerNorm):
30
+ def forward(self, x: Tensor) -> Tensor:
31
+ return super().forward(x.float()).type(x.dtype)
32
+
33
+
34
+ class Linear(nn.Linear):
35
+ def forward(self, x: Tensor) -> Tensor:
36
+ return F.linear(
37
+ x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype)
38
+ )
39
+
40
+
41
+ class Conv1d(nn.Conv1d):
42
+ def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
43
+ return super()._conv_forward(
44
+ x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
45
+ )
46
+
47
+
48
+ def sinusoids(length, channels, max_timescale=10000):
49
+ """Returns sinusoids for positional embedding"""
50
+ assert channels % 2 == 0
51
+ log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
52
+ inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
53
+ scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
54
+ return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
55
+
56
+
57
+ class MultiHeadAttention(nn.Module):
58
+ def __init__(self, n_state: int, n_head: int):
59
+ super().__init__()
60
+ self.n_head = n_head
61
+ self.query = Linear(n_state, n_state)
62
+ self.key = Linear(n_state, n_state, bias=False)
63
+ self.value = Linear(n_state, n_state)
64
+ self.out = Linear(n_state, n_state)
65
+
66
+ def forward(
67
+ self,
68
+ x: Tensor,
69
+ xa: Optional[Tensor] = None,
70
+ mask: Optional[Tensor] = None,
71
+ kv_cache: Optional[dict] = None,
72
+ ):
73
+ q = self.query(x)
74
+
75
+ if kv_cache is None or xa is None:
76
+ # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
77
+ # otherwise, perform key/value projections for self- or cross-attention as usual.
78
+ k = self.key(x if xa is None else xa)
79
+ v = self.value(x if xa is None else xa)
80
+ else:
81
+ # for cross-attention, calculate keys and values once and reuse in subsequent calls.
82
+ k = kv_cache.get(self.key, self.key(xa))
83
+ v = kv_cache.get(self.value, self.value(xa))
84
+
85
+ wv = self.qkv_attention(q, k, v, mask)
86
+ return self.out(wv)
87
+
88
+ def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
89
+ n_batch, n_ctx, n_state = q.shape
90
+ scale = (n_state // self.n_head) ** -0.25
91
+ q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
92
+ k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
93
+ v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
94
+
95
+ qk = q @ k
96
+ if mask is not None:
97
+ qk = qk + mask[:n_ctx, :n_ctx]
98
+
99
+ w = F.softmax(qk.float(), dim=-1).to(q.dtype)
100
+ return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
101
+
102
+
103
+ class ResidualAttentionBlock(nn.Module):
104
+ def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
105
+ super().__init__()
106
+
107
+ self.attn = MultiHeadAttention(n_state, n_head)
108
+ self.attn_ln = LayerNorm(n_state)
109
+
110
+ self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
111
+ self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
112
+
113
+ n_mlp = n_state * 4
114
+ self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
115
+ self.mlp_ln = LayerNorm(n_state)
116
+
117
+ def forward(
118
+ self,
119
+ x: Tensor,
120
+ xa: Optional[Tensor] = None,
121
+ mask: Optional[Tensor] = None,
122
+ kv_cache: Optional[dict] = None,
123
+ ):
124
+ x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
125
+ if self.cross_attn:
126
+ x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
127
+ x = x + self.mlp(self.mlp_ln(x))
128
+ return x
129
+
130
+
131
+ class AudioEncoder(nn.Module):
132
+ def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
133
+ super().__init__()
134
+ self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
135
+ self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
136
+ self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
137
+
138
+ self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
139
+ [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
140
+ )
141
+ self.ln_post = LayerNorm(n_state)
142
+
143
+ def forward(self, x: Tensor, include_embeddings: bool = False):
144
+ """
145
+ x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
146
+ the mel spectrogram of the audio
147
+ include_embeddings: bool
148
+ whether to include intermediate steps in the output
149
+ """
150
+ x = F.gelu(self.conv1(x))
151
+ x = F.gelu(self.conv2(x))
152
+ x = x.permute(0, 2, 1)
153
+
154
+ assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
155
+ x = (x + self.positional_embedding).to(x.dtype)
156
+
157
+ if include_embeddings:
158
+ embeddings = [x.cpu().detach().numpy()]
159
+
160
+ for block in self.blocks:
161
+ x = block(x)
162
+ if include_embeddings:
163
+ embeddings.append(x.cpu().detach().numpy())
164
+
165
+ x = self.ln_post(x)
166
+
167
+ if include_embeddings:
168
+ embeddings = np.stack(embeddings, axis=1)
169
+ return x, embeddings
170
+ else:
171
+ return x
172
+
173
+
174
+ class TextDecoder(nn.Module):
175
+ def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
176
+ super().__init__()
177
+
178
+ self.token_embedding = nn.Embedding(n_vocab, n_state)
179
+ self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
180
+
181
+ self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
182
+ [ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
183
+ )
184
+ self.ln = LayerNorm(n_state)
185
+
186
+ mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
187
+ self.register_buffer("mask", mask, persistent=False)
188
+
189
+ def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None, include_embeddings: bool = False):
190
+ """
191
+ x : torch.LongTensor, shape = (batch_size, <= n_ctx)
192
+ the text tokens
193
+ xa : torch.Tensor, shape = (batch_size, n_mels, n_audio_ctx)
194
+ the encoded audio features to be attended on
195
+ include_embeddings : bool
196
+ Whether to include intermediate values in the output to this function
197
+ """
198
+ offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
199
+ x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
200
+ x = x.to(xa.dtype)
201
+
202
+ if include_embeddings:
203
+ embeddings = [x.cpu().detach().numpy()]
204
+
205
+ for block in self.blocks:
206
+ x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
207
+ if include_embeddings:
208
+ embeddings.append(x.cpu().detach().numpy())
209
+
210
+ x = self.ln(x)
211
+ logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
212
+
213
+ if include_embeddings:
214
+ embeddings = np.stack(embeddings, axis=1)
215
+ return logits, embeddings
216
+ else:
217
+ return logits
218
+
219
+
220
+ class Whisper(nn.Module):
221
+ def __init__(self, dims: ModelDimensions):
222
+ super().__init__()
223
+ self.dims = dims
224
+ self.encoder = AudioEncoder(
225
+ self.dims.n_mels,
226
+ self.dims.n_audio_ctx,
227
+ self.dims.n_audio_state,
228
+ self.dims.n_audio_head,
229
+ self.dims.n_audio_layer,
230
+ )
231
+ self.decoder = TextDecoder(
232
+ self.dims.n_vocab,
233
+ self.dims.n_text_ctx,
234
+ self.dims.n_text_state,
235
+ self.dims.n_text_head,
236
+ self.dims.n_text_layer,
237
+ )
238
+
239
+ def embed_audio(self, mel: torch.Tensor):
240
+ return self.encoder.forward(mel)
241
+
242
+ def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
243
+ return self.decoder.forward(tokens, audio_features)
244
+
245
+ def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
246
+ return self.decoder(tokens, self.encoder(mel))
247
+
248
+ @property
249
+ def device(self):
250
+ return next(self.parameters()).device
251
+
252
+ @property
253
+ def is_multilingual(self):
254
+ return self.dims.n_vocab == 51865
255
+
256
+ def install_kv_cache_hooks(self, cache: Optional[dict] = None):
257
+ """
258
+ The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
259
+ tensors calculated for the previous positions. This method returns a dictionary that stores
260
+ all caches, and the necessary hooks for the key and value projection modules that save the
261
+ intermediate tensors to be reused during later calculations.
262
+
263
+ Returns
264
+ -------
265
+ cache : Dict[nn.Module, torch.Tensor]
266
+ A dictionary object mapping the key/value projection modules to its cache
267
+ hooks : List[RemovableHandle]
268
+ List of PyTorch RemovableHandle objects to stop the hooks to be called
269
+ """
270
+ cache = {**cache} if cache is not None else {}
271
+ hooks = []
272
+
273
+ def save_to_cache(module, _, output):
274
+ if module not in cache or output.shape[1] > self.decoder.positional_embedding.shape[0]:
275
+ cache[module] = output # save as-is, for the first token or cross attention
276
+ else:
277
+ cache[module] = torch.cat([cache[module], output], dim=1).detach()
278
+ return cache[module]
279
+
280
+ def install_hooks(layer: nn.Module):
281
+ if isinstance(layer, MultiHeadAttention):
282
+ hooks.append(layer.key.register_forward_hook(save_to_cache))
283
+ hooks.append(layer.value.register_forward_hook(save_to_cache))
284
+
285
+ self.decoder.apply(install_hooks)
286
+ return cache, hooks
287
+
288
+ detect_language = detect_language_function
289
+ transcribe = transcribe_function
290
+ decode = decode_function
musetalk_integration/whisper/tokenizer.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dataclasses import dataclass
3
+ from functools import lru_cache
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ from transformers import GPT2TokenizerFast
9
+
10
+ LANGUAGES = {
11
+ "en": "english",
12
+ "zh": "chinese",
13
+ "de": "german",
14
+ "es": "spanish",
15
+ "ru": "russian",
16
+ "ko": "korean",
17
+ "fr": "french",
18
+ "ja": "japanese",
19
+ "pt": "portuguese",
20
+ "tr": "turkish",
21
+ "pl": "polish",
22
+ "ca": "catalan",
23
+ "nl": "dutch",
24
+ "ar": "arabic",
25
+ "sv": "swedish",
26
+ "it": "italian",
27
+ "id": "indonesian",
28
+ "hi": "hindi",
29
+ "fi": "finnish",
30
+ "vi": "vietnamese",
31
+ "iw": "hebrew",
32
+ "uk": "ukrainian",
33
+ "el": "greek",
34
+ "ms": "malay",
35
+ "cs": "czech",
36
+ "ro": "romanian",
37
+ "da": "danish",
38
+ "hu": "hungarian",
39
+ "ta": "tamil",
40
+ "no": "norwegian",
41
+ "th": "thai",
42
+ "ur": "urdu",
43
+ "hr": "croatian",
44
+ "bg": "bulgarian",
45
+ "lt": "lithuanian",
46
+ "la": "latin",
47
+ "mi": "maori",
48
+ "ml": "malayalam",
49
+ "cy": "welsh",
50
+ "sk": "slovak",
51
+ "te": "telugu",
52
+ "fa": "persian",
53
+ "lv": "latvian",
54
+ "bn": "bengali",
55
+ "sr": "serbian",
56
+ "az": "azerbaijani",
57
+ "sl": "slovenian",
58
+ "kn": "kannada",
59
+ "et": "estonian",
60
+ "mk": "macedonian",
61
+ "br": "breton",
62
+ "eu": "basque",
63
+ "is": "icelandic",
64
+ "hy": "armenian",
65
+ "ne": "nepali",
66
+ "mn": "mongolian",
67
+ "bs": "bosnian",
68
+ "kk": "kazakh",
69
+ "sq": "albanian",
70
+ "sw": "swahili",
71
+ "gl": "galician",
72
+ "mr": "marathi",
73
+ "pa": "punjabi",
74
+ "si": "sinhala",
75
+ "km": "khmer",
76
+ "sn": "shona",
77
+ "yo": "yoruba",
78
+ "so": "somali",
79
+ "af": "afrikaans",
80
+ "oc": "occitan",
81
+ "ka": "georgian",
82
+ "be": "belarusian",
83
+ "tg": "tajik",
84
+ "sd": "sindhi",
85
+ "gu": "gujarati",
86
+ "am": "amharic",
87
+ "yi": "yiddish",
88
+ "lo": "lao",
89
+ "uz": "uzbek",
90
+ "fo": "faroese",
91
+ "ht": "haitian creole",
92
+ "ps": "pashto",
93
+ "tk": "turkmen",
94
+ "nn": "nynorsk",
95
+ "mt": "maltese",
96
+ "sa": "sanskrit",
97
+ "lb": "luxembourgish",
98
+ "my": "myanmar",
99
+ "bo": "tibetan",
100
+ "tl": "tagalog",
101
+ "mg": "malagasy",
102
+ "as": "assamese",
103
+ "tt": "tatar",
104
+ "haw": "hawaiian",
105
+ "ln": "lingala",
106
+ "ha": "hausa",
107
+ "ba": "bashkir",
108
+ "jw": "javanese",
109
+ "su": "sundanese",
110
+ }
111
+
112
+ # language code lookup by name, with a few language aliases
113
+ TO_LANGUAGE_CODE = {
114
+ **{language: code for code, language in LANGUAGES.items()},
115
+ "burmese": "my",
116
+ "valencian": "ca",
117
+ "flemish": "nl",
118
+ "haitian": "ht",
119
+ "letzeburgesch": "lb",
120
+ "pushto": "ps",
121
+ "panjabi": "pa",
122
+ "moldavian": "ro",
123
+ "moldovan": "ro",
124
+ "sinhalese": "si",
125
+ "castilian": "es",
126
+ }
127
+
128
+
129
+ @dataclass(frozen=True)
130
+ class Tokenizer:
131
+ """A thin wrapper around `GPT2TokenizerFast` providing quick access to special tokens"""
132
+
133
+ tokenizer: "GPT2TokenizerFast"
134
+ language: Optional[str]
135
+ sot_sequence: Tuple[int]
136
+
137
+ def encode(self, text, **kwargs):
138
+ return self.tokenizer.encode(text, **kwargs)
139
+
140
+ def decode(self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs):
141
+ return self.tokenizer.decode(token_ids, **kwargs)
142
+
143
+ def decode_with_timestamps(self, tokens) -> str:
144
+ """
145
+ Timestamp tokens are above the special tokens' id range and are ignored by `decode()`.
146
+ This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
147
+ """
148
+ outputs = [[]]
149
+ for token in tokens:
150
+ if token >= self.timestamp_begin:
151
+ timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
152
+ outputs.append(timestamp)
153
+ outputs.append([])
154
+ else:
155
+ outputs[-1].append(token)
156
+ outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
157
+ return "".join(outputs)
158
+
159
+ @property
160
+ @lru_cache()
161
+ def eot(self) -> int:
162
+ return self.tokenizer.eos_token_id
163
+
164
+ @property
165
+ @lru_cache()
166
+ def sot(self) -> int:
167
+ return self._get_single_token_id("<|startoftranscript|>")
168
+
169
+ @property
170
+ @lru_cache()
171
+ def sot_lm(self) -> int:
172
+ return self._get_single_token_id("<|startoflm|>")
173
+
174
+ @property
175
+ @lru_cache()
176
+ def sot_prev(self) -> int:
177
+ return self._get_single_token_id("<|startofprev|>")
178
+
179
+ @property
180
+ @lru_cache()
181
+ def no_speech(self) -> int:
182
+ return self._get_single_token_id("<|nospeech|>")
183
+
184
+ @property
185
+ @lru_cache()
186
+ def no_timestamps(self) -> int:
187
+ return self._get_single_token_id("<|notimestamps|>")
188
+
189
+ @property
190
+ @lru_cache()
191
+ def timestamp_begin(self) -> int:
192
+ return self.tokenizer.all_special_ids[-1] + 1
193
+
194
+ @property
195
+ @lru_cache()
196
+ def language_token(self) -> int:
197
+ """Returns the token id corresponding to the value of the `language` field"""
198
+ if self.language is None:
199
+ raise ValueError(f"This tokenizer does not have language token configured")
200
+
201
+ additional_tokens = dict(
202
+ zip(
203
+ self.tokenizer.additional_special_tokens,
204
+ self.tokenizer.additional_special_tokens_ids,
205
+ )
206
+ )
207
+ candidate = f"<|{self.language}|>"
208
+ if candidate in additional_tokens:
209
+ return additional_tokens[candidate]
210
+
211
+ raise KeyError(f"Language {self.language} not found in tokenizer.")
212
+
213
+ @property
214
+ @lru_cache()
215
+ def all_language_tokens(self) -> Tuple[int]:
216
+ result = []
217
+ for token, token_id in zip(
218
+ self.tokenizer.additional_special_tokens,
219
+ self.tokenizer.additional_special_tokens_ids,
220
+ ):
221
+ if token.strip("<|>") in LANGUAGES:
222
+ result.append(token_id)
223
+ return tuple(result)
224
+
225
+ @property
226
+ @lru_cache()
227
+ def all_language_codes(self) -> Tuple[str]:
228
+ return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens)
229
+
230
+ @property
231
+ @lru_cache()
232
+ def sot_sequence_including_notimestamps(self) -> Tuple[int]:
233
+ return tuple(list(self.sot_sequence) + [self.no_timestamps])
234
+
235
+ @property
236
+ @lru_cache()
237
+ def non_speech_tokens(self) -> Tuple[int]:
238
+ """
239
+ Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
240
+ annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
241
+
242
+ - ♪♪♪
243
+ - ( SPEAKING FOREIGN LANGUAGE )
244
+ - [DAVID] Hey there,
245
+
246
+ keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
247
+ """
248
+ symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』")
249
+ symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
250
+
251
+ # symbols that may be a single token or multiple tokens depending on the tokenizer.
252
+ # In case they're multiple tokens, suppress the first token, which is safe because:
253
+ # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
254
+ # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
255
+ miscellaneous = set("♩♪♫♬♭♮♯")
256
+ assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
257
+
258
+ # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
259
+ result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]}
260
+ for symbol in symbols + list(miscellaneous):
261
+ for tokens in [self.tokenizer.encode(symbol), self.tokenizer.encode(" " + symbol)]:
262
+ if len(tokens) == 1 or symbol in miscellaneous:
263
+ result.add(tokens[0])
264
+
265
+ return tuple(sorted(result))
266
+
267
+ def _get_single_token_id(self, text) -> int:
268
+ tokens = self.tokenizer.encode(text)
269
+ assert len(tokens) == 1, f"{text} is not encoded as a single token"
270
+ return tokens[0]
271
+
272
+
273
+ @lru_cache(maxsize=None)
274
+ def build_tokenizer(name: str = "gpt2"):
275
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
276
+ path = os.path.join(os.path.dirname(__file__), "assets", name)
277
+ tokenizer = GPT2TokenizerFast.from_pretrained(path)
278
+
279
+ specials = [
280
+ "<|startoftranscript|>",
281
+ *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
282
+ "<|translate|>",
283
+ "<|transcribe|>",
284
+ "<|startoflm|>",
285
+ "<|startofprev|>",
286
+ "<|nospeech|>",
287
+ "<|notimestamps|>",
288
+ ]
289
+
290
+ tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
291
+ return tokenizer
292
+
293
+
294
+ @lru_cache(maxsize=None)
295
+ def get_tokenizer(
296
+ multilingual: bool,
297
+ *,
298
+ task: Optional[str] = None, # Literal["transcribe", "translate", None]
299
+ language: Optional[str] = None,
300
+ ) -> Tokenizer:
301
+ if language is not None:
302
+ language = language.lower()
303
+ if language not in LANGUAGES:
304
+ if language in TO_LANGUAGE_CODE:
305
+ language = TO_LANGUAGE_CODE[language]
306
+ else:
307
+ raise ValueError(f"Unsupported language: {language}")
308
+
309
+ if multilingual:
310
+ tokenizer_name = "multilingual"
311
+ task = task or "transcribe"
312
+ language = language or "en"
313
+ else:
314
+ tokenizer_name = "gpt2"
315
+ task = None
316
+ language = None
317
+
318
+ tokenizer = build_tokenizer(name=tokenizer_name)
319
+ all_special_ids: List[int] = tokenizer.all_special_ids
320
+ sot: int = all_special_ids[1]
321
+ translate: int = all_special_ids[-6]
322
+ transcribe: int = all_special_ids[-5]
323
+
324
+ langs = tuple(LANGUAGES.keys())
325
+ sot_sequence = [sot]
326
+ if language is not None:
327
+ sot_sequence.append(sot + 1 + langs.index(language))
328
+ if task is not None:
329
+ sot_sequence.append(transcribe if task == "transcribe" else translate)
330
+
331
+ return Tokenizer(tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence))
musetalk_integration/whisper/transcribe.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import warnings
4
+ from typing import List, Optional, Tuple, Union, TYPE_CHECKING
5
+
6
+ import numpy as np
7
+ import torch
8
+ import tqdm
9
+
10
+ from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
11
+ from .decoding import DecodingOptions, DecodingResult
12
+ from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
13
+ from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt
14
+
15
+ if TYPE_CHECKING:
16
+ from .model import Whisper
17
+
18
+
19
+ def transcribe(
20
+ model: "Whisper",
21
+ audio: Union[str, np.ndarray, torch.Tensor],
22
+ *,
23
+ verbose: Optional[bool] = None,
24
+ temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
25
+ compression_ratio_threshold: Optional[float] = 2.4,
26
+ logprob_threshold: Optional[float] = -1.0,
27
+ no_speech_threshold: Optional[float] = 0.6,
28
+ condition_on_previous_text: bool = True,
29
+ force_extraction: bool = False,
30
+ **decode_options,
31
+ ):
32
+ """
33
+ Transcribe an audio file using Whisper
34
+
35
+ Parameters
36
+ ----------
37
+ model: Whisper
38
+ The Whisper model instance
39
+
40
+ audio: Union[str, np.ndarray, torch.Tensor]
41
+ The path to the audio file to open, or the audio waveform
42
+
43
+ verbose: bool
44
+ Whether to display the text being decoded to the console. If True, displays all the details,
45
+ If False, displays minimal details. If None, does not display anything
46
+
47
+ temperature: Union[float, Tuple[float, ...]]
48
+ Temperature for sampling. It can be a tuple of temperatures, which will be successfully used
49
+ upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
50
+
51
+ compression_ratio_threshold: float
52
+ If the gzip compression ratio is above this value, treat as failed
53
+
54
+ logprob_threshold: float
55
+ If the average log probability over sampled tokens is below this value, treat as failed
56
+
57
+ no_speech_threshold: float
58
+ If the no_speech probability is higher than this value AND the average log probability
59
+ over sampled tokens is below `logprob_threshold`, consider the segment as silent
60
+
61
+ condition_on_previous_text: bool
62
+ if True, the previous output of the model is provided as a prompt for the next window;
63
+ disabling may make the text inconsistent across windows, but the model becomes less prone to
64
+ getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
65
+
66
+ decode_options: dict
67
+ Keyword arguments to construct `DecodingOptions` instances
68
+
69
+ Returns
70
+ -------
71
+ A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
72
+ the spoken language ("language"), which is detected when `decode_options["language"]` is None.
73
+ """
74
+ dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
75
+ if model.device == torch.device("cpu"):
76
+ if torch.cuda.is_available():
77
+ warnings.warn("Performing inference on CPU when CUDA is available")
78
+ if dtype == torch.float16:
79
+ warnings.warn("FP16 is not supported on CPU; using FP32 instead")
80
+ dtype = torch.float32
81
+
82
+ if dtype == torch.float32:
83
+ decode_options["fp16"] = False
84
+
85
+ mel = log_mel_spectrogram(audio)
86
+
87
+ all_segments = []
88
+ def add_segment(
89
+ *, start: float, end: float, encoder_embeddings
90
+ ):
91
+
92
+ all_segments.append(
93
+ {
94
+ "start": start,
95
+ "end": end,
96
+ "encoder_embeddings":encoder_embeddings,
97
+ }
98
+ )
99
+ # show the progress bar when verbose is False (otherwise the transcribed text will be printed)
100
+ num_frames = mel.shape[-1]
101
+ seek = 0
102
+ previous_seek_value = seek
103
+ sample_skip = 3000 #
104
+ with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
105
+ while seek < num_frames:
106
+ # seek是开始的帧数
107
+ end_seek = min(seek + sample_skip, num_frames)
108
+ segment = pad_or_trim(mel[:,seek:seek+sample_skip], N_FRAMES).to(model.device).to(dtype)
109
+
110
+ single = segment.ndim == 2
111
+ if single:
112
+ segment = segment.unsqueeze(0)
113
+ if dtype == torch.float16:
114
+ segment = segment.half()
115
+ audio_features, embeddings = model.encoder(segment, include_embeddings = True)
116
+
117
+ encoder_embeddings = embeddings
118
+ #print(f"encoder_embeddings shape {encoder_embeddings.shape}")
119
+ add_segment(
120
+ start=seek,
121
+ end=end_seek,
122
+ #text_tokens=tokens,
123
+ #result=result,
124
+ encoder_embeddings=encoder_embeddings,
125
+ )
126
+ seek+=sample_skip
127
+
128
+ return dict(segments=all_segments)
129
+
130
+
131
+ def cli():
132
+ from . import available_models
133
+
134
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
135
+ parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
136
+ parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
137
+ parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
138
+ parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
139
+ parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
140
+ parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
141
+
142
+ parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
143
+ parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
144
+
145
+ parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
146
+ parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
147
+ parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
148
+ parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
149
+ parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
150
+
151
+ parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
152
+ parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
153
+ parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
154
+ parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
155
+
156
+ parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
157
+ parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
158
+ parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
159
+ parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
160
+ parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
161
+
162
+ args = parser.parse_args().__dict__
163
+ model_name: str = args.pop("model")
164
+ model_dir: str = args.pop("model_dir")
165
+ output_dir: str = args.pop("output_dir")
166
+ device: str = args.pop("device")
167
+ os.makedirs(output_dir, exist_ok=True)
168
+
169
+ if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
170
+ if args["language"] is not None:
171
+ warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
172
+ args["language"] = "en"
173
+
174
+ temperature = args.pop("temperature")
175
+ temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
176
+ if temperature_increment_on_fallback is not None:
177
+ temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
178
+ else:
179
+ temperature = [temperature]
180
+
181
+ threads = args.pop("threads")
182
+ if threads > 0:
183
+ torch.set_num_threads(threads)
184
+
185
+ from . import load_model
186
+ model = load_model(model_name, device=device, download_root=model_dir)
187
+
188
+ for audio_path in args.pop("audio"):
189
+ result = transcribe(model, audio_path, temperature=temperature, **args)
190
+
191
+ audio_basename = os.path.basename(audio_path)
192
+
193
+ # save TXT
194
+ with open(os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8") as txt:
195
+ write_txt(result["segments"], file=txt)
196
+
197
+ # save VTT
198
+ with open(os.path.join(output_dir, audio_basename + ".vtt"), "w", encoding="utf-8") as vtt:
199
+ write_vtt(result["segments"], file=vtt)
200
+
201
+ # save SRT
202
+ with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt:
203
+ write_srt(result["segments"], file=srt)
204
+
205
+
206
+ if __name__ == '__main__':
207
+ cli()
musetalk_integration/whisper/utils.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import zlib
2
+ from typing import Iterator, TextIO
3
+
4
+
5
+ def exact_div(x, y):
6
+ assert x % y == 0
7
+ return x // y
8
+
9
+
10
+ def str2bool(string):
11
+ str2val = {"True": True, "False": False}
12
+ if string in str2val:
13
+ return str2val[string]
14
+ else:
15
+ raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
16
+
17
+
18
+ def optional_int(string):
19
+ return None if string == "None" else int(string)
20
+
21
+
22
+ def optional_float(string):
23
+ return None if string == "None" else float(string)
24
+
25
+
26
+ def compression_ratio(text) -> float:
27
+ return len(text) / len(zlib.compress(text.encode("utf-8")))
28
+
29
+
30
+ def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'):
31
+ assert seconds >= 0, "non-negative timestamp expected"
32
+ milliseconds = round(seconds * 1000.0)
33
+
34
+ hours = milliseconds // 3_600_000
35
+ milliseconds -= hours * 3_600_000
36
+
37
+ minutes = milliseconds // 60_000
38
+ milliseconds -= minutes * 60_000
39
+
40
+ seconds = milliseconds // 1_000
41
+ milliseconds -= seconds * 1_000
42
+
43
+ hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
44
+ return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
45
+
46
+
47
+ def write_txt(transcript: Iterator[dict], file: TextIO):
48
+ for segment in transcript:
49
+ print(segment['text'].strip(), file=file, flush=True)
50
+
51
+
52
+ def write_vtt(transcript: Iterator[dict], file: TextIO):
53
+ print("WEBVTT\n", file=file)
54
+ for segment in transcript:
55
+ print(
56
+ f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
57
+ f"{segment['text'].strip().replace('-->', '->')}\n",
58
+ file=file,
59
+ flush=True,
60
+ )
61
+
62
+
63
+ def write_srt(transcript: Iterator[dict], file: TextIO):
64
+ """
65
+ Write a transcript to a file in SRT format.
66
+
67
+ Example usage:
68
+ from pathlib import Path
69
+ from whisper.utils import write_srt
70
+
71
+ result = transcribe(model, audio_path, temperature=temperature, **args)
72
+
73
+ # save SRT
74
+ audio_basename = Path(audio_path).stem
75
+ with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
76
+ write_srt(result["segments"], file=srt)
77
+ """
78
+ for i, segment in enumerate(transcript, start=1):
79
+ # write srt lines
80
+ print(
81
+ f"{i}\n"
82
+ f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
83
+ f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
84
+ f"{segment['text'].strip().replace('-->', '->')}\n",
85
+ file=file,
86
+ flush=True,
87
+ )
processing.py CHANGED
@@ -320,7 +320,7 @@ def process_lipsync_with_audio_target_new(
320
  video_file,
321
  audio_file,
322
  session_id=None,
323
- crop_size=256,
324
  progress=gr.Progress(track_tqdm=True),
325
  ):
326
  """Workflow mới: Chuẩn hóa YouTube rồi lipsync
@@ -337,7 +337,7 @@ def process_lipsync_with_audio_target_new(
337
  video_file: Path to video source
338
  audio_file: Path to audio target (English only)
339
  session_id: Session identifier
340
- crop_size: Size cho lipsync (256 or 512)
341
  progress: Progress tracking object
342
 
343
  Returns:
@@ -352,6 +352,16 @@ def process_lipsync_with_audio_target_new(
352
 
353
  output_dir = setup_output_dir(session_id)
354
 
 
 
 
 
 
 
 
 
 
 
355
  logger.info(f"Memory at start: {get_memory_usage()}")
356
 
357
  audio_duration = get_audio_duration(audio_path)
@@ -417,7 +427,7 @@ def process_lipsync_with_audio_target_new(
417
  with timer("Applying lipsync"):
418
  try:
419
  lipsynced_video, lipsynced_info = apply_lipsync_to_video(
420
- video_normalized, audio_16k, output_dir, crop_size
421
  )
422
  logger.info(
423
  f"Lipsynced video: {lipsynced_video}, size: {lipsynced_info['width']}x{lipsynced_info['height']}"
@@ -461,7 +471,7 @@ def lipsync_with_audio_target(
461
  video_file,
462
  audio_file,
463
  session_id=None,
464
- crop_size=256,
465
  progress=gr.Progress(track_tqdm=True),
466
  ):
467
  """Wrapper for Gradio: Lipsync video source with audio target (English only)
@@ -474,5 +484,5 @@ def lipsync_with_audio_target(
474
  if audio_file is None:
475
  raise gr.Error("Please upload a target audio.")
476
  return process_lipsync_with_audio_target_new(
477
- video_file, audio_file, session_id, crop_size, progress
478
  )
 
320
  video_file,
321
  audio_file,
322
  session_id=None,
323
+ model_type="latentsync",
324
  progress=gr.Progress(track_tqdm=True),
325
  ):
326
  """Workflow mới: Chuẩn hóa YouTube rồi lipsync
 
337
  video_file: Path to video source
338
  audio_file: Path to audio target (English only)
339
  session_id: Session identifier
340
+ model_type: Model type for lipsync ("latentsync" or "musetalk")
341
  progress: Progress tracking object
342
 
343
  Returns:
 
352
 
353
  output_dir = setup_output_dir(session_id)
354
 
355
+ # Mapping model_type to crop_size
356
+ if model_type == "LatentSync v1.6":
357
+ crop_size = 512
358
+ logger.info("Using LatentSync v1.6 with crop_size=512")
359
+ elif model_type == "MuseTalk v1.5":
360
+ crop_size = 256
361
+ logger.info("Using MuseTalk v1.5 with crop_size=256")
362
+ else:
363
+ raise ValueError(f"Unknown model_type: {model_type}")
364
+
365
  logger.info(f"Memory at start: {get_memory_usage()}")
366
 
367
  audio_duration = get_audio_duration(audio_path)
 
427
  with timer("Applying lipsync"):
428
  try:
429
  lipsynced_video, lipsynced_info = apply_lipsync_to_video(
430
+ video_normalized, audio_16k, output_dir, model_type
431
  )
432
  logger.info(
433
  f"Lipsynced video: {lipsynced_video}, size: {lipsynced_info['width']}x{lipsynced_info['height']}"
 
471
  video_file,
472
  audio_file,
473
  session_id=None,
474
+ model_type="LatentSync v1.6",
475
  progress=gr.Progress(track_tqdm=True),
476
  ):
477
  """Wrapper for Gradio: Lipsync video source with audio target (English only)
 
484
  if audio_file is None:
485
  raise gr.Error("Please upload a target audio.")
486
  return process_lipsync_with_audio_target_new(
487
+ video_file, audio_file, session_id, model_type, progress
488
  )
requirements.txt CHANGED
@@ -45,3 +45,12 @@ psutil
45
  # Gradio & Spaces
46
  gradio==5.24.0
47
  spaces
 
 
 
 
 
 
 
 
 
 
45
  # Gradio & Spaces
46
  gradio==5.24.0
47
  spaces
48
+
49
+ # MuseTalk Dependencies
50
+ mmengine>=0.10.0
51
+ mmcv>=2.0.1
52
+ mmdet>=3.1.0
53
+ mmpose>=1.1.0
54
+ openmim>=0.3.0
55
+ moviepy>=1.0.3
56
+ gdown>=5.1.0