Zhizhou Zhong (zzgoogle) committed
Commit b6de2a7 · unverified · 1 Parent(s): 5485b83

feat: windows infer & gradio (#312)

* fix: windows infer

* docs: update readme

* docs: update readme

* feat: v1.5 gradio for windows&linux

* fix: dependencies

* feat: windows infer & gradio

---------

Co-authored-by: NeRF-Factory <zzhizhou66@gmail.com>

.gitignore CHANGED
@@ -5,11 +5,14 @@
  *.pyc
  .ipynb_checkpoints
  results/
- ./models
+ models/
  **/__pycache__/
  *.py[cod]
  *$py.class
  dataset/
  ffmpeg*
+ ffmprobe*
+ ffplay*
  debug
- exp_out
+ exp_out
+ .gradio
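As a quick sanity check of the glob-style entries above, they can be exercised with Python's `fnmatch`. This is only a sketch: real gitignore matching also handles directory entries such as `models/` and nested paths, which plain `fnmatch` does not.

```python
import fnmatch

# Glob-style entries added or kept by this commit (directory entries simplified)
patterns = ["*.pyc", "ffmpeg*", "ffmprobe*", "ffplay*", "exp_out", ".gradio"]

def looks_ignored(name):
    """Rough check: does a top-level file name match any ignore pattern?"""
    return any(fnmatch.fnmatch(name, p) for p in patterns)

print(looks_ignored("ffplay.exe"))  # True
print(looks_ignored("app.py"))      # False
```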
README.md CHANGED
@@ -146,50 +146,87 @@ We also hope you note that we have not verified, maintained, or updated third-pa
  
  ## Installation
  To prepare the Python environment and install additional packages such as opencv, diffusers, mmcv, etc., please follow the steps below:
+ 
  ### Build environment
- We recommend a python version >=3.10 and cuda version =11.7. Then build environment as follows:
+ We recommend Python 3.10 and CUDA 11.7. Set up your environment as follows:
+ 
+ ```shell
+ conda create -n MuseTalk python==3.10
+ conda activate MuseTalk
+ ```
  
+ ### Install PyTorch 2.0.1
+ Choose one of the following installation methods:
+ 
+ ```shell
+ # Option 1: Using pip
+ pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118
+ 
+ # Option 2: Using conda
+ conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 -c pytorch -c nvidia
+ ```
+ 
+ ### Install Dependencies
+ Install the remaining required packages:
  
  ```shell
  pip install -r requirements.txt
  ```
  
- ### mmlab packages
+ ### Install MMLab Packages
+ Install the MMLab ecosystem packages:
+ 
  ```bash
- pip install --no-cache-dir -U openmim
- mim install mmengine
- mim install "mmcv>=2.0.1"
- mim install "mmdet>=3.1.0"
- mim install "mmpose>=1.1.0"
+ pip install --no-cache-dir -U openmim
+ mim install mmengine
+ mim install "mmcv==2.0.1"
+ mim install "mmdet==3.1.0"
+ mim install "mmpose==1.1.0"
  ```
  
- ### Download ffmpeg-static
- Download the ffmpeg-static and
- ```
+ ### Setup FFmpeg
+ 1. [Download](https://github.com/BtbN/FFmpeg-Builds/releases) the ffmpeg-static package
+ 
+ 2. Configure FFmpeg based on your operating system:
+ 
+ For Linux:
+ ```bash
  export FFMPEG_PATH=/path/to/ffmpeg
- ```
- for example:
- ```
+ # Example:
  export FFMPEG_PATH=/musetalk/ffmpeg-4.4-amd64-static
  ```
+ 
+ For Windows:
+ Add the `ffmpeg-xxx\bin` directory to your system's PATH environment variable. Verify the installation by running `ffmpeg -version` in the command prompt - it should display the ffmpeg version information.
+ 
  ### Download weights
- You can download weights manually as follows:
+ You can download weights in two ways:
+ 
+ #### Option 1: Using Download Scripts
+ We provide two scripts for automatic downloading:
  
- 1. Download our trained [weights](https://huggingface.co/TMElyralab/MuseTalk).
+ For Linux:
  ```bash
- # !pip install -U "huggingface_hub[cli]"
- export HF_ENDPOINT=https://hf-mirror.com
- huggingface-cli download TMElyralab/MuseTalk --local-dir models/
+ sh ./download_weights.sh
  ```
  
+ For Windows:
+ ```batch
+ # Run the script
+ download_weights.bat
+ ```
+ 
+ #### Option 2: Manual Download
+ You can also download the weights manually from the following links:
+ 
+ 1. Download our trained [weights](https://huggingface.co/TMElyralab/MuseTalk/tree/main)
  2. Download the weights of other components:
- - [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse)
+ - [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse/tree/main)
  - [whisper](https://huggingface.co/openai/whisper-tiny/tree/main)
  - [dwpose](https://huggingface.co/yzd-v/DWPose/tree/main)
- - [face-parse-bisent](https://github.com/zllrunning/face-parsing.PyTorch)
- - [resnet18](https://download.pytorch.org/models/resnet18-5c106cde.pth)
  - [syncnet](https://huggingface.co/ByteDance/LatentSync/tree/main)
-
+ - [face-parse-bisent](https://drive.google.com/file/d/154JgKpzCPW82qINcVieuPH3fZ2e0P812/view?pli=1)
+ - [resnet18](https://download.pytorch.org/models/resnet18-5c106cde.pth)
  
  Finally, these weights should be organized in `models` as follows:
  ```
@@ -207,7 +244,7 @@ Finally, these weights should be organized in `models` as follows:
  ├── face-parse-bisent
  │   ├── 79999_iter.pth
  │   └── resnet18-5c106cde.pth
- ├── sd-vae-ft-mse
+ ├── sd-vae
  │   ├── config.json
  │   └── diffusion_pytorch_model.bin
  └── whisper
@@ -221,42 +258,66 @@ Finally, these weights should be organized in `models` as follows:
  ### Inference
  We provide inference scripts for both versions of MuseTalk:
  
- #### MuseTalk 1.5 (Recommended)
+ #### Prerequisites
+ Before running inference, please ensure ffmpeg is installed and accessible:
  ```bash
- # Run MuseTalk 1.5 inference
- sh inference.sh v1.5 normal
+ # Check ffmpeg installation
+ ffmpeg -version
  ```
+ If ffmpeg is not found, please install it first:
+ - Windows: Download from [ffmpeg-static](https://github.com/BtbN/FFmpeg-Builds/releases) and add to PATH
+ - Linux: `sudo apt-get install ffmpeg`
  
- #### MuseTalk 1.0
+ #### Normal Inference
+ ##### Linux Environment
  ```bash
- # Run MuseTalk 1.0 inference
+ # MuseTalk 1.5 (Recommended)
+ sh inference.sh v1.5 normal
+ 
+ # MuseTalk 1.0
  sh inference.sh v1.0 normal
  ```
  
- The inference script supports both MuseTalk 1.5 and 1.0 models:
- - For MuseTalk 1.5: Use the command above with the V1.5 model path
- - For MuseTalk 1.0: Use the same script but point to the V1.0 model path
+ ##### Windows Environment
  
- The configuration file `configs/inference/test.yaml` contains the inference settings, including:
- - `video_path`: Path to the input video, image file, or directory of images
- - `audio_path`: Path to the input audio file
+ Please ensure that you set the `ffmpeg_path` to match the actual location of your FFmpeg installation.
  
- Note: For optimal results, we recommend using input videos with 25fps, which is the same fps used during model training. If your video has a lower frame rate, you can use frame interpolation or convert it to 25fps using ffmpeg.
+ ```bash
+ # MuseTalk 1.5 (Recommended)
+ python -m scripts.inference --inference_config configs\inference\test.yaml --result_dir results\test --unet_model_path models\musetalkV15\unet.pth --unet_config models\musetalkV15\musetalk.json --version v15 --ffmpeg_path ffmpeg-master-latest-win64-gpl-shared\bin
+ 
+ # For MuseTalk 1.0, change:
+ # - models\musetalkV15 -> models\musetalk
+ # - unet.pth -> pytorch_model.bin
+ # - --version v15 -> --version v1
+ ```
  
  #### Real-time Inference
- For real-time inference, use the following command:
+ ##### Linux Environment
  ```bash
- # Run real-time inference
- sh inference.sh v1.5 realtime # For MuseTalk 1.5
- # or
- sh inference.sh v1.0 realtime # For MuseTalk 1.0
+ # MuseTalk 1.5 (Recommended)
+ sh inference.sh v1.5 realtime
+ 
+ # MuseTalk 1.0
+ sh inference.sh v1.0 realtime
  ```
  
- The real-time inference configuration is in `configs/inference/realtime.yaml`, which includes:
- - `preparation`: Set to `True` for new avatar preparation
- - `video_path`: Path to the input video
- - `bbox_shift`: Adjustable parameter for mouth region control
- - `audio_clips`: List of audio clips for generation
+ ##### Windows Environment
+ ```bash
+ # MuseTalk 1.5 (Recommended)
+ python -m scripts.realtime_inference --inference_config configs\inference\realtime.yaml --result_dir results\realtime --unet_model_path models\musetalkV15\unet.pth --unet_config models\musetalkV15\musetalk.json --version v15 --fps 25 --ffmpeg_path ffmpeg-master-latest-win64-gpl-shared\bin
+ 
+ # For MuseTalk 1.0, change:
+ # - models\musetalkV15 -> models\musetalk
+ # - unet.pth -> pytorch_model.bin
+ # - --version v15 -> --version v1
+ ```
+ 
+ The configuration file `configs/inference/test.yaml` contains the inference settings, including:
+ - `video_path`: Path to the input video, image file, or directory of images
+ - `audio_path`: Path to the input audio file
+ 
+ Note: For optimal results, we recommend using input videos with 25fps, which is the same fps used during model training. If your video has a lower frame rate, you can use frame interpolation or convert it to 25fps using ffmpeg.
  
  Important notes for real-time inference:
  1. Set `preparation` to `True` when processing a new avatar
@@ -269,6 +330,18 @@ For faster generation without saving images, you can use:
  python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --skip_save_images
  ```
  
+ ## Gradio Demo
+ We provide an intuitive web interface through Gradio for users to easily adjust input parameters. To optimize inference time, users can generate only the **first frame** to fine-tune the best lip-sync parameters, which helps reduce facial artifacts in the final output.
+ ![para](assets/figs/gradio_2.png)
+ For minimum hardware requirements, we tested the system on a Windows environment using an NVIDIA GeForce RTX 3050 Ti Laptop GPU with 4GB VRAM. In fp16 mode, generating an 8-second video takes approximately 5 minutes. ![speed](assets/figs/gradio.png)
+ 
+ Both Linux and Windows users can launch the demo using the following command. Please ensure that the `ffmpeg_path` parameter matches your actual FFmpeg installation path:
+ 
+ ```bash
+ # You can remove --use_float16 for better quality, but it will increase VRAM usage and inference time
+ python app.py --use_float16 --ffmpeg_path ffmpeg-master-latest-win64-gpl-shared\bin
+ ```
+ 
  ## Training
  
  ### Data Preparation
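The README note above recommends 25 fps input and suggests converting lower frame rates with ffmpeg. A minimal sketch of that conversion step (assuming `ffmpeg` is on the PATH; the helper names and file paths here are illustrative, not part of the repo):

```python
import subprocess

def build_25fps_cmd(src, dst):
    # -r 25 resamples the video to the 25 fps the models were trained on
    return ["ffmpeg", "-y", "-i", src, "-r", "25", "-c:v", "libx264", dst]

def convert_to_25fps(src, dst):
    """Re-encode a video at 25 fps (requires ffmpeg installed)."""
    subprocess.run(build_25fps_cmd(src, dst), check=True)

# Example usage:
# convert_to_25fps("input_30fps.mp4", "input_25fps.mp4")
```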
app.py CHANGED
@@ -4,7 +4,6 @@ import pdb
  import re
  
  import gradio as gr
- import spaces
  import numpy as np
  import sys
  import subprocess
@@ -28,11 +27,101 @@ import gdown
  import imageio
  import ffmpeg
  from moviepy.editor import *
- 
  
  ProjectDir = os.path.abspath(os.path.dirname(__file__))
  CheckpointsDir = os.path.join(ProjectDir, "models")
  
  def print_directory_contents(path):
      for child in os.listdir(path):
          child_path = os.path.join(path, child)
@@ -40,119 +129,107 @@ def print_directory_contents(path):
          print(child_path)
  
  def download_model():
-     if not os.path.exists(CheckpointsDir):
-         os.makedirs(CheckpointsDir)
-         print("Checkpoint Not Downloaded, start downloading...")
-         tic = time.time()
-         snapshot_download(
-             repo_id="TMElyralab/MuseTalk",
-             local_dir=CheckpointsDir,
-             max_workers=8,
-             local_dir_use_symlinks=True,
-             force_download=True, resume_download=False
-         )
-         # sd-vae weights
-         os.makedirs(f"{CheckpointsDir}/sd-vae-ft-mse/")
-         snapshot_download(
-             repo_id="stabilityai/sd-vae-ft-mse",
-             local_dir=CheckpointsDir+'/sd-vae-ft-mse',
-             max_workers=8,
-             local_dir_use_symlinks=True,
-             force_download=True, resume_download=False
-         )
-         # dwpose
-         os.makedirs(f"{CheckpointsDir}/dwpose/")
-         snapshot_download(
-             repo_id="yzd-v/DWPose",
-             local_dir=CheckpointsDir+'/dwpose',
-             max_workers=8,
-             local_dir_use_symlinks=True,
-             force_download=True, resume_download=False
-         )
-         # whisper tiny.pt
-         url = "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt"
-         response = requests.get(url)
-         # make sure the request succeeded
-         if response.status_code == 200:
-             # target path for the file
-             file_path = f"{CheckpointsDir}/whisper/tiny.pt"
-             os.makedirs(f"{CheckpointsDir}/whisper/")
-             # write the response content to the target path
-             with open(file_path, "wb") as f:
-                 f.write(response.content)
-         else:
-             print(f"Request failed with status code: {response.status_code}")
-         # gdown face parse
-         url = "https://drive.google.com/uc?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812"
-         os.makedirs(f"{CheckpointsDir}/face-parse-bisent/")
-         file_path = f"{CheckpointsDir}/face-parse-bisent/79999_iter.pth"
-         gdown.download(url, file_path, quiet=False)
-         # resnet
-         url = "https://download.pytorch.org/models/resnet18-5c106cde.pth"
-         response = requests.get(url)
-         # make sure the request succeeded
-         if response.status_code == 200:
-             # target path for the file
-             file_path = f"{CheckpointsDir}/face-parse-bisent/resnet18-5c106cde.pth"
-             # write the response content to the target path
-             with open(file_path, "wb") as f:
-                 f.write(response.content)
          else:
-             print(f"Request failed with status code: {response.status_code}")
- 
-         toc = time.time()
- 
-         print(f"download cost {toc-tic} seconds")
-         print_directory_contents(CheckpointsDir)
- 
      else:
-         print("Already download the model.")
- 
  
  download_model() # for huggingface deployment.
  
- 
- from musetalk.utils.utils import get_file_type,get_video_fps,datagen
- from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder,get_bbox_range
  from musetalk.utils.blending import get_image
- from musetalk.utils.utils import load_all_model
  
- 
- @spaces.GPU(duration=600)
  @torch.no_grad()
- def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=True)):
-     args_dict={"result_dir":'./results/output', "fps":25, "batch_size":8, "output_vid_name":'', "use_saved_coord":False} # same as the inference script
      args = Namespace(**args_dict)
  
      input_basename = os.path.basename(video_path).split('.')[0]
-     audio_basename = os.path.basename(audio_path).split('.')[0]
      output_basename = f"{input_basename}_{audio_basename}"
-     result_img_save_path = os.path.join(args.result_dir, output_basename)  # related to video & audio inputs
-     crop_coord_save_path = os.path.join(result_img_save_path, input_basename+".pkl")  # only related to video input
-     os.makedirs(result_img_save_path, exist_ok=True)
  
-     if args.output_vid_name == "":
-         output_vid_name = os.path.join(args.result_dir, output_basename+".mp4")
      else:
-         output_vid_name = os.path.join(args.result_dir, args.output_vid_name)
      ############################################## extract frames from source video ##############################################
-     if get_file_type(video_path) == "video":
-         save_dir_full = os.path.join(args.result_dir, input_basename)
-         os.makedirs(save_dir_full, exist_ok=True)
-         # cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
-         # os.system(cmd)
-         # read the video
          reader = imageio.get_reader(video_path)
  
-         # save frames as images
          for i, im in enumerate(reader):
              imageio.imwrite(f"{save_dir_full}/{i:08d}.png", im)
          input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
@@ -161,10 +238,21 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T
          input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
          input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
          fps = args.fps
-     # print(input_img_list)
      ############################################## extract audio feature ##############################################
-     whisper_feature = audio_processor.audio2feat(audio_path)
-     whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature, fps=fps)
      ############################################## preprocess input image ##############################################
      if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
          print("using extracted coordinates")
@@ -176,13 +264,22 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T
          coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
          with open(crop_coord_save_path, 'wb') as f:
              pickle.dump(coord_list, f)
-     bbox_shift_text = get_bbox_range(input_img_list, bbox_shift)
      i = 0
      input_latent_list = []
      for bbox, frame in zip(coord_list, frame_list):
          if bbox == coord_placeholder:
              continue
          x1, y1, x2, y2 = bbox
          crop_frame = frame[y1:y2, x1:x2]
          crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
          latents = vae.get_latents_for_unet(crop_frame)
@@ -192,17 +289,23 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T
      frame_list_cycle = frame_list + frame_list[::-1]
      coord_list_cycle = coord_list + coord_list[::-1]
      input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
      ############################################## inference batch by batch ##############################################
      print("start inference")
      video_num = len(whisper_chunks)
      batch_size = args.batch_size
-     gen = datagen(whisper_chunks, input_latent_list_cycle, batch_size)
      res_frame_list = []
      for i, (whisper_batch, latent_batch) in enumerate(tqdm(gen, total=int(np.ceil(float(video_num)/batch_size)))):
- 
-         tensor_list = [torch.FloatTensor(arr) for arr in whisper_batch]
-         audio_feature_batch = torch.stack(tensor_list).to(unet.device)  # torch, B, 5*N, 384
-         audio_feature_batch = pe(audio_feature_batch)
  
          pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
          recon = vae.decode_latents(pred_latents)
@@ -215,25 +318,24 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T
          bbox = coord_list_cycle[i%(len(coord_list_cycle))]
          ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))])
          x1, y1, x2, y2 = bbox
          try:
              res_frame = cv2.resize(res_frame.astype(np.uint8), (x2-x1, y2-y1))
          except:
-             # print(bbox)
              continue
  
-         combine_frame = get_image(ori_frame, res_frame, bbox)
          cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png", combine_frame)
  
-     # cmd_img2video = f"ffmpeg -y -v fatal -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p temp.mp4"
-     # print(cmd_img2video)
-     # os.system(cmd_img2video)
-     # frame rate
      fps = 25
-     # image path
-     # output video path
      output_video = 'temp.mp4'
  
-     # read images
      def is_valid_image(file):
          pattern = re.compile(r'\d{8}\.png')
          return pattern.match(file)
@@ -247,13 +349,9 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T
          images.append(imageio.imread(filename))
  
-     # save video
      imageio.mimwrite(output_video, images, 'FFMPEG', fps=fps, codec='libx264', pixelformat='yuv420p')
  
-     # cmd_combine_audio = f"ffmpeg -y -v fatal -i {audio_path} -i temp.mp4 {output_vid_name}"
-     # print(cmd_combine_audio)
-     # os.system(cmd_combine_audio)
- 
      input_video = './temp.mp4'
      # Check if the input_video and audio_path exist
      if not os.path.exists(input_video):
@@ -261,40 +359,15 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T
      if not os.path.exists(audio_path):
          raise FileNotFoundError(f"Audio file not found: {audio_path}")
  
-     # read the video
      reader = imageio.get_reader(input_video)
-     fps = reader.get_meta_data()['fps']  # get the fps of the original video
-     reader.close()  # otherwise on Windows 11 this raises: PermissionError: [WinError 32] The process cannot access the file because it is being used by another process: 'temp.mp4'
-     # store the frames in a list
      frames = images
- 
-     # save the video and add audio
-     # imageio.mimwrite(output_vid_name, frames, 'FFMPEG', fps=fps, codec='libx264', audio_codec='aac', input_params=['-i', audio_path])
- 
-     # input_video = ffmpeg.input(input_video)
- 
-     # input_audio = ffmpeg.input(audio_path)
  
      print(len(frames))
  
-     # imageio.mimwrite(
-     #     output_video,
-     #     frames,
-     #     'FFMPEG',
-     #     fps=25,
-     #     codec='libx264',
-     #     audio_codec='aac',
-     #     input_params=['-i', audio_path],
-     #     output_params=['-y'],  # Add the '-y' flag to overwrite the output file if it exists
-     # )
-     # writer = imageio.get_writer(output_vid_name, fps=25, codec='libx264', quality=10, pixelformat='yuvj444p')
-     # for im in frames:
-     #     writer.append_data(im)
-     # writer.close()
- 
      # Load the video
      video_clip = VideoFileClip(input_video)
  
@@ -315,11 +388,45 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T
  
  # load model weights
- audio_processor, vae, unet, pe = load_all_model()
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- timesteps = torch.tensor([0], device=device)
  
  def check_video(video):
@@ -340,9 +447,6 @@ def check_video(video):
      output_video = os.path.join('./results/input', output_file_name)
  
-     # # Run the ffmpeg command to change the frame rate to 25fps
-     # command = f"ffmpeg -i {video} -r 25 -vcodec libx264 -vtag hvc1 -pix_fmt yuv420p crf 18 {output_video} -y"
- 
      # read video
      reader = imageio.get_reader(video)
      fps = reader.get_meta_data()['fps']  # get fps from original video
@@ -374,34 +478,45 @@ css = """#input_img {max-width: 1024px !important} #output_vid {max-width: 1024p
  
  with gr.Blocks(css=css) as demo:
      gr.Markdown(
-         "<div align='center'> <h1>MuseTalk: Real-Time High Quality Lip Synchronization with Latent Space Inpainting </span> </h1> \
          <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
          </br>\
-         Yue Zhang <sup>\*</sup>,\
-         Minhao Liu<sup>\*</sup>,\
          Zhaokang Chen,\
          Bin Wu<sup>†</sup>,\
          Yingjie He,\
-         Chao Zhan,\
-         Wenjiang Zhou\
          (<sup>*</sup>Equal Contribution, <sup>†</sup>Corresponding Author, benbinwu@tencent.com)\
          Lyra Lab, Tencent Music Entertainment\
          </h2> \
          <a style='font-size:18px;color: #000000' href='https://github.com/TMElyralab/MuseTalk'>[Github Repo]</a>\
          <a style='font-size:18px;color: #000000' href='https://github.com/TMElyralab/MuseTalk'>[Huggingface]</a>\
-         <a style='font-size:18px;color: #000000' href=''> [Technical report(Coming Soon)] </a>\
-         <a style='font-size:18px;color: #000000' href=''> [Project Page(Coming Soon)] </a> </div>"
      )
  
      with gr.Row():
          with gr.Column():
-             audio = gr.Audio(label="Driven Audio", type="filepath")
              video = gr.Video(label="Reference Video", sources=['upload'])
              bbox_shift = gr.Number(label="BBox_shift value, px", value=0)
-             bbox_shift_scale = gr.Textbox(label="BBox_shift recommended lower bound. The corresponding bbox range is generated after the initial result. If the result is not good, it can be adjusted according to this reference value", value="", interactive=False)
- 
-             btn = gr.Button("Generate")
-             out1 = gr.Video()
  
      video.change(
          fn=check_video, inputs=[video], outputs=[video]
@@ -412,15 +527,44 @@ with gr.Blocks(css=css) as demo:
              audio,
              video,
              bbox_shift,
          ],
          outputs=[out1, bbox_shift_scale]
      )
  
- # Set the IP and port
- ip_address = "0.0.0.0"  # Replace with your desired IP address
- port_number = 7860  # Replace with your desired port number
- 
  demo.queue().launch(
-     share=False, debug=True, server_name=ip_address, server_port=port_number
  )
 
4
  import re
5
 
6
  import gradio as gr
 
7
  import numpy as np
8
  import sys
9
  import subprocess
 
27
  import imageio
28
  import ffmpeg
29
  from moviepy.editor import *
30
+ from transformers import WhisperModel
31
 
32
  ProjectDir = os.path.abspath(os.path.dirname(__file__))
33
  CheckpointsDir = os.path.join(ProjectDir, "models")
34
 
35
+ @torch.no_grad()
36
+ def debug_inpainting(video_path, bbox_shift, extra_margin=10, parsing_mode="jaw",
37
+ left_cheek_width=90, right_cheek_width=90):
38
+ """Debug inpainting parameters, only process the first frame"""
39
+ # Set default parameters
40
+ args_dict = {
41
+ "result_dir": './results/debug',
42
+ "fps": 25,
43
+ "batch_size": 1,
44
+ "output_vid_name": '',
45
+ "use_saved_coord": False,
46
+ "audio_padding_length_left": 2,
47
+ "audio_padding_length_right": 2,
48
+ "version": "v15",
49
+ "extra_margin": extra_margin,
50
+ "parsing_mode": parsing_mode,
51
+ "left_cheek_width": left_cheek_width,
52
+ "right_cheek_width": right_cheek_width
53
+ }
54
+ args = Namespace(**args_dict)
55
+
56
+ # Create debug directory
57
+ os.makedirs(args.result_dir, exist_ok=True)
58
+
59
+ # Read first frame
60
+ if get_file_type(video_path) == "video":
61
+ reader = imageio.get_reader(video_path)
62
+ first_frame = reader.get_data(0)
63
+ reader.close()
64
+ else:
65
+ first_frame = cv2.imread(video_path)
66
+ first_frame = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
67
+
68
+ # Save first frame
69
+ debug_frame_path = os.path.join(args.result_dir, "debug_frame.png")
70
+ cv2.imwrite(debug_frame_path, cv2.cvtColor(first_frame, cv2.COLOR_RGB2BGR))
71
+
72
+ # Get face coordinates
73
+ coord_list, frame_list = get_landmark_and_bbox([debug_frame_path], bbox_shift)
74
+ bbox = coord_list[0]
75
+ frame = frame_list[0]
76
+
77
+ if bbox == coord_placeholder:
78
+ return None, "No face detected, please adjust bbox_shift parameter"
79
+
80
+ # Initialize face parser
81
+ fp = FaceParsing(
82
+ left_cheek_width=args.left_cheek_width,
83
+ right_cheek_width=args.right_cheek_width
84
+ )
85
+
86
+ # Process first frame
87
+ x1, y1, x2, y2 = bbox
88
+ y2 = y2 + args.extra_margin
89
+ y2 = min(y2, frame.shape[0])
90
+ crop_frame = frame[y1:y2, x1:x2]
91
+ crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
92
+
93
+ # Generate random audio features
94
+ random_audio = torch.randn(1, 50, 384, device=device, dtype=weight_dtype)
95
+ audio_feature = pe(random_audio)
96
+
97
+ # Get latents
98
+ latents = vae.get_latents_for_unet(crop_frame)
99
+ latents = latents.to(dtype=weight_dtype)
100
+
101
+ # Generate prediction results
102
+ pred_latents = unet.model(latents, timesteps, encoder_hidden_states=audio_feature).sample
103
+ recon = vae.decode_latents(pred_latents)
104
+
105
+ # Inpaint back to original image
106
+ res_frame = recon[0]
107
+ res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
108
+ combine_frame = get_image(frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)
109
+
110
+ # Save results (no need to convert color space again since get_image already returns RGB format)
111
+ debug_result_path = os.path.join(args.result_dir, "debug_result.png")
112
+ cv2.imwrite(debug_result_path, combine_frame)
113
+
114
+ # Create information text
115
+ info_text = f"Parameter information:\n" + \
116
+ f"bbox_shift: {bbox_shift}\n" + \
117
+ f"extra_margin: {extra_margin}\n" + \
118
+ f"parsing_mode: {parsing_mode}\n" + \
119
+ f"left_cheek_width: {left_cheek_width}\n" + \
120
+ f"right_cheek_width: {right_cheek_width}\n" + \
121
+ f"Detected face coordinates: [{x1}, {y1}, {x2}, {y2}]"
122
+
123
+ return cv2.cvtColor(combine_frame, cv2.COLOR_RGB2BGR), info_text
124
+
125
  def print_directory_contents(path):
126
  for child in os.listdir(path):
127
  child_path = os.path.join(path, child)
 
129
  print(child_path)
130
 
131
  def download_model():
132
+ # 检查必需的模型文件是否存在
133
+ required_models = {
134
+ "MuseTalk": f"{CheckpointsDir}/musetalkV15/unet.pth",
135
+ "MuseTalk": f"{CheckpointsDir}/musetalkV15/musetalk.json",
136
+ "SD VAE": f"{CheckpointsDir}/sd-vae/config.json",
137
+ "Whisper": f"{CheckpointsDir}/whisper/config.json",
138
+ "DWPose": f"{CheckpointsDir}/dwpose/dw-ll_ucoco_384.pth",
139
+ "SyncNet": f"{CheckpointsDir}/syncnet/latentsync_syncnet.pt",
140
+ "Face Parse": f"{CheckpointsDir}/face-parse-bisent/79999_iter.pth",
141
+ "ResNet": f"{CheckpointsDir}/face-parse-bisent/resnet18-5c106cde.pth"
142
+ }
143
+
144
+ missing_models = []
145
+ for model_name, model_path in required_models.items():
146
+ if not os.path.exists(model_path):
147
+ missing_models.append(model_name)
148
+
149
+ if missing_models:
150
+ # 全用英文
151
+ print("The following required model files are missing:")
152
+ for model in missing_models:
153
+ print(f"- {model}")
154
+ print("\nPlease run the download script to download the missing models:")
155
+ if sys.platform == "win32":
156
+ print("Windows: Run download_weights.bat")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  else:
158
+ print("Linux/Mac: Run ./download_weights.sh")
159
+ sys.exit(1)
 
 
 
 
 
 
160
  else:
161
+ print("All required model files exist.")
 
162
 
163
 
164
 
165
 
166
  download_model() # for huggingface deployment.
  from musetalk.utils.blending import get_image
+ from musetalk.utils.face_parsing import FaceParsing
+ from musetalk.utils.audio_processor import AudioProcessor
+ from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
+ from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder, get_bbox_range

+ def fast_check_ffmpeg():
+     try:
+         subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
+         return True
+     except Exception:
+         return False
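As an editorial aside (not part of the commit): `fast_check_ffmpeg` spawns a process on every call, even when the binary is plainly absent. A lighter check resolves the binary with `shutil.which` first; `check_ffmpeg` is a hypothetical helper name used for illustration.

```python
import shutil
import subprocess

def check_ffmpeg(binary: str = "ffmpeg") -> bool:
    """Return True if `binary` is on PATH and actually runs."""
    path = shutil.which(binary)  # cheap PATH lookup, no process spawn
    if path is None:
        return False
    # Confirm the resolved file executes (it could be a broken symlink).
    result = subprocess.run([path, "-version"], capture_output=True)
    return result.returncode == 0

# A missing binary resolves to False without raising:
print(check_ffmpeg("definitely-not-a-real-binary"))  # → False
```

The two-step check keeps the common failure mode (nothing on PATH) free of subprocess overhead.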
  @torch.no_grad()
+ def inference(audio_path, video_path, bbox_shift, extra_margin=10, parsing_mode="jaw",
+               left_cheek_width=90, right_cheek_width=90, progress=gr.Progress(track_tqdm=True)):
+     # Set default parameters, aligned with inference.py
+     args_dict = {
+         "result_dir": './results/output',
+         "fps": 25,
+         "batch_size": 8,
+         "output_vid_name": '',
+         "use_saved_coord": False,
+         "audio_padding_length_left": 2,
+         "audio_padding_length_right": 2,
+         "version": "v15",  # fixed to the v15 pipeline
+         "extra_margin": extra_margin,
+         "parsing_mode": parsing_mode,
+         "left_cheek_width": left_cheek_width,
+         "right_cheek_width": right_cheek_width
+     }
      args = Namespace(**args_dict)

+     # Check ffmpeg
+     if not fast_check_ffmpeg():
+         print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")
+
      input_basename = os.path.basename(video_path).split('.')[0]
+     audio_basename = os.path.basename(audio_path).split('.')[0]
      output_basename = f"{input_basename}_{audio_basename}"
+
+     # Create temporary directory
+     temp_dir = os.path.join(args.result_dir, f"{args.version}")
+     os.makedirs(temp_dir, exist_ok=True)
+
+     # Set result save path
+     result_img_save_path = os.path.join(temp_dir, output_basename)
+     crop_coord_save_path = os.path.join(args.result_dir, "../", input_basename+".pkl")
+     os.makedirs(result_img_save_path, exist_ok=True)

+     if args.output_vid_name == "":
+         output_vid_name = os.path.join(temp_dir, output_basename+".mp4")
      else:
+         output_vid_name = os.path.join(temp_dir, args.output_vid_name)
+
      ############################################## extract frames from source video ##############################################
+     if get_file_type(video_path) == "video":
+         save_dir_full = os.path.join(temp_dir, input_basename)
+         os.makedirs(save_dir_full, exist_ok=True)
+         # Read video
          reader = imageio.get_reader(video_path)

+         # Save images
          for i, im in enumerate(reader):
              imageio.imwrite(f"{save_dir_full}/{i:08d}.png", im)
          input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))

          input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
          input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
          fps = args.fps
+
      ############################################## extract audio feature ##############################################
+     # Extract audio features
+     whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
+     whisper_chunks = audio_processor.get_whisper_chunk(
+         whisper_input_features,
+         device,
+         weight_dtype,
+         whisper,
+         librosa_length,
+         fps=fps,
+         audio_padding_length_left=args.audio_padding_length_left,
+         audio_padding_length_right=args.audio_padding_length_right,
+     )
+
      ############################################## preprocess input image ##############################################
      if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
          print("using extracted coordinates")

          coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
          with open(crop_coord_save_path, 'wb') as f:
              pickle.dump(coord_list, f)
+     bbox_shift_text = get_bbox_range(input_img_list, bbox_shift)
+
+     # Initialize face parser
+     fp = FaceParsing(
+         left_cheek_width=args.left_cheek_width,
+         right_cheek_width=args.right_cheek_width
+     )
+
      i = 0
      input_latent_list = []
      for bbox, frame in zip(coord_list, frame_list):
          if bbox == coord_placeholder:
              continue
          x1, y1, x2, y2 = bbox
+         y2 = y2 + args.extra_margin
+         y2 = min(y2, frame.shape[0])
          crop_frame = frame[y1:y2, x1:x2]
          crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
          latents = vae.get_latents_for_unet(crop_frame)

      frame_list_cycle = frame_list + frame_list[::-1]
      coord_list_cycle = coord_list + coord_list[::-1]
      input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
+
      ############################################## inference batch by batch ##############################################
      print("start inference")
      video_num = len(whisper_chunks)
      batch_size = args.batch_size
+     gen = datagen(
+         whisper_chunks=whisper_chunks,
+         vae_encode_latents=input_latent_list_cycle,
+         batch_size=batch_size,
+         delay_frame=0,
+         device=device,
+     )
      res_frame_list = []
      for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))):
+         audio_feature_batch = pe(whisper_batch)
+         # Ensure latent_batch is consistent with model weight type
+         latent_batch = latent_batch.to(dtype=weight_dtype)

          pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
          recon = vae.decode_latents(pred_latents)
 
          bbox = coord_list_cycle[i%(len(coord_list_cycle))]
          ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))])
          x1, y1, x2, y2 = bbox
+         y2 = y2 + args.extra_margin
+         y2 = min(y2, frame.shape[0])
          try:
              res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
          except:
              continue

+         # Use v15 version blending
+         combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)
+
          cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png",combine_frame)

+     # Frame rate
      fps = 25
+     # Output video path
      output_video = 'temp.mp4'

+     # Read images
      def is_valid_image(file):
          pattern = re.compile(r'\d{8}\.png')
          return pattern.match(file)

          images.append(imageio.imread(filename))

+     # Save video
      imageio.mimwrite(output_video, images, 'FFMPEG', fps=fps, codec='libx264', pixelformat='yuv420p')

      input_video = './temp.mp4'
      # Check if the input_video and audio_path exist
      if not os.path.exists(input_video):

      if not os.path.exists(audio_path):
          raise FileNotFoundError(f"Audio file not found: {audio_path}")

+     # Read video
      reader = imageio.get_reader(input_video)
+     fps = reader.get_meta_data()['fps']  # Get the original video frame rate
+     reader.close()  # Otherwise Windows 11 raises: PermissionError: [WinError 32] The process cannot access the file because it is being used by another process: 'temp.mp4'
+     # Store frames in a list
      frames = images

      print(len(frames))

      # Load the video
      video_clip = VideoFileClip(input_video)
  # load model weights
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ vae, unet, pe = load_all_model(
+     unet_model_path="./models/musetalkV15/unet.pth",
+     vae_type="sd-vae",
+     unet_config="./models/musetalkV15/musetalk.json",
+     device=device
+ )

+ # Parse command line arguments
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--ffmpeg_path", type=str, default=r"ffmpeg-master-latest-win64-gpl-shared\bin", help="Path to ffmpeg executable")
+ parser.add_argument("--ip", type=str, default="127.0.0.1", help="IP address to bind to")
+ parser.add_argument("--port", type=int, default=7860, help="Port to bind to")
+ parser.add_argument("--share", action="store_true", help="Create a public link")
+ parser.add_argument("--use_float16", action="store_true", help="Use float16 for faster inference")
+ args = parser.parse_args()
+
+ # Set data type
+ if args.use_float16:
+     # Convert models to half precision for better performance
+     pe = pe.half()
+     vae.vae = vae.vae.half()
+     unet.model = unet.model.half()
+     weight_dtype = torch.float16
+ else:
+     weight_dtype = torch.float32
+
+ # Move models to the specified device
+ pe = pe.to(device)
+ vae.vae = vae.vae.to(device)
+ unet.model = unet.model.to(device)

+ timesteps = torch.tensor([0], device=device)
+
+ # Initialize audio processor and Whisper model
+ audio_processor = AudioProcessor(feature_extractor_path="./models/whisper")
+ whisper = WhisperModel.from_pretrained("./models/whisper")
+ whisper = whisper.to(device=device, dtype=weight_dtype).eval()
+ whisper.requires_grad_(False)
431
 
432
  def check_video(video):
 
447
  output_video = os.path.join('./results/input', output_file_name)
448
 
449
 
 
 
 
450
  # read video
451
  reader = imageio.get_reader(video)
452
  fps = reader.get_meta_data()['fps'] # get fps from original video
 
  with gr.Blocks(css=css) as demo:
      gr.Markdown(
+         """<div align='center'> <h1>MuseTalk: Real-Time High-Fidelity Video Dubbing via Spatio-Temporal Sampling</h1> \
          <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
          </br>\
+         Yue Zhang <sup>*</sup>,\
+         Zhizhou Zhong <sup>*</sup>,\
+         Minhao Liu<sup>*</sup>,\
          Zhaokang Chen,\
          Bin Wu<sup>†</sup>,\
+         Yubin Zeng,\
+         Chao Zhang,\
          Yingjie He,\
+         Junxin Huang,\
+         Wenjiang Zhou <br>\
          (<sup>*</sup>Equal Contribution, <sup>†</sup>Corresponding Author, benbinwu@tencent.com)\
          Lyra Lab, Tencent Music Entertainment\
          </h2> \
          <a style='font-size:18px;color: #000000' href='https://github.com/TMElyralab/MuseTalk'>[Github Repo]</a>\
          <a style='font-size:18px;color: #000000' href='https://github.com/TMElyralab/MuseTalk'>[Huggingface]</a>\
+         <a style='font-size:18px;color: #000000' href='https://arxiv.org/abs/2410.10122'> [Technical report] </a>"""
      )

      with gr.Row():
          with gr.Column():
+             audio = gr.Audio(label="Driving Audio", type="filepath")
              video = gr.Video(label="Reference Video",sources=['upload'])
              bbox_shift = gr.Number(label="BBox_shift value, px", value=0)
+             extra_margin = gr.Slider(label="Extra Margin", minimum=0, maximum=40, value=10, step=1)
+             parsing_mode = gr.Radio(label="Parsing Mode", choices=["jaw", "raw"], value="jaw")
+             left_cheek_width = gr.Slider(label="Left Cheek Width", minimum=20, maximum=160, value=90, step=5)
+             right_cheek_width = gr.Slider(label="Right Cheek Width", minimum=20, maximum=160, value=90, step=5)
+             bbox_shift_scale = gr.Textbox(label="When the parsing mode is 'jaw', 'left_cheek_width' and 'right_cheek_width' set the editable range of the left and right cheeks, and 'extra_margin' sets the movement range of the jaw. Adjust these three parameters freely to obtain better inpainting results.")
+
+             with gr.Row():
+                 debug_btn = gr.Button("1. Test Inpainting")
+                 btn = gr.Button("2. Generate")
+         with gr.Column():
+             debug_image = gr.Image(label="Test Inpainting Result (First Frame)")
+             debug_info = gr.Textbox(label="Parameter Information", lines=5)
+             out1 = gr.Video()
      video.change(
          fn=check_video, inputs=[video], outputs=[video]

              audio,
              video,
              bbox_shift,
+             extra_margin,
+             parsing_mode,
+             left_cheek_width,
+             right_cheek_width
          ],
          outputs=[out1,bbox_shift_scale]
      )
+     debug_btn.click(
+         fn=debug_inpainting,
+         inputs=[
+             video,
+             bbox_shift,
+             extra_margin,
+             parsing_mode,
+             left_cheek_width,
+             right_cheek_width
+         ],
+         outputs=[debug_image, debug_info]
+     )

+ # Check ffmpeg and add it to PATH
+ if not fast_check_ffmpeg():
+     print(f"Adding ffmpeg to PATH: {args.ffmpeg_path}")
+     # Choose the path separator for the operating system
+     path_separator = ';' if sys.platform == 'win32' else ':'
+     os.environ["PATH"] = f"{args.ffmpeg_path}{path_separator}{os.environ['PATH']}"
+     if not fast_check_ffmpeg():
+         print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")
+
+ # Work around asynchronous IO issues on Windows
+ if sys.platform == 'win32':
+     import asyncio
+     asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
+
+ # Start the Gradio application
  demo.queue().launch(
+     share=args.share,
+     debug=True,
+     server_name=args.ip,
+     server_port=args.port
  )
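The Windows event-loop workaround above can be wrapped in a small, testable helper. This is an editorial sketch, not part of the commit: `configure_event_loop` is a hypothetical name, and the `hasattr` guard keeps the call safe on non-Windows interpreters, where `WindowsSelectorEventLoopPolicy` does not exist.

```python
import asyncio
import sys

def configure_event_loop(platform: str = sys.platform) -> bool:
    """Apply the selector-based event loop policy on Windows.

    Returns True when the policy was changed, False otherwise.
    Mirrors the platform guard used before demo.queue().launch().
    """
    if platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
        return True
    return False

print(configure_event_loop("linux"))  # → False: no change needed off Windows
```

The selector policy avoids `ProactorEventLoop` incompatibilities that some libraries hit on Windows.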
assets/figs/gradio.png ADDED

Git LFS Details

  • SHA256: 75202ee94d490eb2340a1e5461f39ca7dbdaa4405d09a0486ce0e267fcd07201
  • Pointer size: 130 Bytes
  • Size of remote file: 13.9 kB
assets/figs/gradio_2.png ADDED

Git LFS Details

  • SHA256: 83ea8834de20562b85d6e1def453dfb608a200c5db3eb26ff8e4604bd86cc3be
  • Pointer size: 130 Bytes
  • Size of remote file: 75.2 kB
download_weights.bat ADDED
@@ -0,0 +1,45 @@
+ @echo off
+ setlocal
+
+ :: Set the checkpoints directory
+ set CheckpointsDir=models
+
+ :: Create necessary directories
+ mkdir %CheckpointsDir%\musetalk
+ mkdir %CheckpointsDir%\musetalkV15
+ mkdir %CheckpointsDir%\syncnet
+ mkdir %CheckpointsDir%\dwpose
+ mkdir %CheckpointsDir%\face-parse-bisent
+ mkdir %CheckpointsDir%\sd-vae
+ mkdir %CheckpointsDir%\whisper
+
+ :: Install required packages
+ pip install -U "huggingface_hub[cli]"
+ pip install gdown
+
+ :: Set HuggingFace endpoint
+ set HF_ENDPOINT=https://hf-mirror.com
+
+ :: Download MuseTalk weights
+ huggingface-cli download TMElyralab/MuseTalk --local-dir %CheckpointsDir%
+
+ :: Download SD VAE weights
+ huggingface-cli download stabilityai/sd-vae-ft-mse --local-dir %CheckpointsDir%\sd-vae --include "config.json" "diffusion_pytorch_model.bin"
+
+ :: Download Whisper weights
+ huggingface-cli download openai/whisper-tiny --local-dir %CheckpointsDir%\whisper --include "config.json" "pytorch_model.bin" "preprocessor_config.json"
+
+ :: Download DWPose weights
+ huggingface-cli download yzd-v/DWPose --local-dir %CheckpointsDir%\dwpose --include "dw-ll_ucoco_384.pth"
+
+ :: Download SyncNet weights
+ huggingface-cli download ByteDance/LatentSync --local-dir %CheckpointsDir%\syncnet --include "latentsync_syncnet.pt"
+
+ :: Download Face Parse Bisent weights (using gdown)
+ gdown --id 154JgKpzCPW82qINcVieuPH3fZ2e0P812 -O %CheckpointsDir%\face-parse-bisent\79999_iter.pth
+
+ :: Download ResNet weights
+ curl -L https://download.pytorch.org/models/resnet18-5c106cde.pth -o %CheckpointsDir%\face-parse-bisent\resnet18-5c106cde.pth
+
+ echo All weights have been downloaded successfully!
+ endlocal
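The CLI calls above map directly onto the `huggingface_hub` Python API. A hedged sketch follows: the repo IDs and include patterns are copied from the script, while `WEIGHT_PLAN` and `fetch_weights` are illustrative names; calling `fetch_weights()` performs real network downloads, so the import is deferred until then.

```python
# Editorial sketch: the HuggingFace downloads expressed via the Python API
# instead of huggingface-cli. Only three of the repos are listed for brevity.
WEIGHT_PLAN = {
    "TMElyralab/MuseTalk": {"local_dir": "models", "allow_patterns": None},
    "stabilityai/sd-vae-ft-mse": {
        "local_dir": "models/sd-vae",
        "allow_patterns": ["config.json", "diffusion_pytorch_model.bin"],
    },
    "openai/whisper-tiny": {
        "local_dir": "models/whisper",
        "allow_patterns": ["config.json", "pytorch_model.bin", "preprocessor_config.json"],
    },
}

def fetch_weights(plan=WEIGHT_PLAN):
    """Download every repo in the plan (network access required)."""
    from huggingface_hub import snapshot_download  # deferred heavy import
    for repo_id, kwargs in plan.items():
        # Drop None values so snapshot_download keeps its own defaults.
        snapshot_download(repo_id=repo_id,
                          **{k: v for k, v in kwargs.items() if v is not None})
```

`allow_patterns` plays the role of the CLI's `--include` flags, fetching only the listed files from each repo.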
download_weights.sh ADDED
@@ -0,0 +1,37 @@
+ #!/bin/bash
+
+ # Set the checkpoints directory
+ CheckpointsDir="models"
+
+ # Create necessary directories
+ mkdir -p $CheckpointsDir/{musetalk,musetalkV15,syncnet,dwpose,face-parse-bisent,sd-vae,whisper}
+
+ # Install required packages
+ pip install -U "huggingface_hub[cli]"
+ pip install gdown
+
+ # Set HuggingFace endpoint
+ export HF_ENDPOINT=https://hf-mirror.com
+
+ # Download MuseTalk weights
+ huggingface-cli download TMElyralab/MuseTalk --local-dir $CheckpointsDir
+
+ # Download SD VAE weights
+ huggingface-cli download stabilityai/sd-vae-ft-mse --local-dir $CheckpointsDir/sd-vae --include "config.json" "diffusion_pytorch_model.bin"
+
+ # Download Whisper weights
+ huggingface-cli download openai/whisper-tiny --local-dir $CheckpointsDir/whisper --include "config.json" "pytorch_model.bin" "preprocessor_config.json"
+
+ # Download DWPose weights
+ huggingface-cli download yzd-v/DWPose --local-dir $CheckpointsDir/dwpose --include "dw-ll_ucoco_384.pth"
+
+ # Download SyncNet weights
+ huggingface-cli download ByteDance/LatentSync --local-dir $CheckpointsDir/syncnet --include "latentsync_syncnet.pt"
+
+ # Download Face Parse Bisent weights (using gdown)
+ gdown --id 154JgKpzCPW82qINcVieuPH3fZ2e0P812 -O $CheckpointsDir/face-parse-bisent/79999_iter.pth
+
+ # Download ResNet weights
+ curl -L https://download.pytorch.org/models/resnet18-5c106cde.pth -o $CheckpointsDir/face-parse-bisent/resnet18-5c106cde.pth
+
+ echo "All weights have been downloaded successfully!"
musetalk/utils/audio_processor.py CHANGED
@@ -49,8 +49,9 @@ class AudioProcessor:
          whisper_feature = []
          # Process multiple 30s mel input features
          for input_feature in whisper_input_features:
-             audio_feats = whisper.encoder(input_feature.to(device), output_hidden_states=True).hidden_states
-             audio_feats = torch.stack(audio_feats, dim=2).to(weight_dtype)
+             input_feature = input_feature.to(device).to(weight_dtype)
+             audio_feats = whisper.encoder(input_feature, output_hidden_states=True).hidden_states
+             audio_feats = torch.stack(audio_feats, dim=2)
              whisper_feature.append(audio_feats)

          whisper_feature = torch.cat(whisper_feature, dim=1)
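The hunk above casts the input features to the encoder's weight dtype before the forward pass instead of casting the output afterwards, which keeps float32 inputs from ever meeting float16 weights. The principle can be sketched with NumPy arrays standing in for torch tensors (in PyTorch such a mismatch typically raises rather than silently promoting):

```python
import numpy as np

def align_dtype(x, weights):
    """Cast the input to the weights' dtype before the matmul,
    mirroring input_feature.to(device).to(weight_dtype) above."""
    return x.astype(weights.dtype, copy=False)

weights = np.ones((4, 4), dtype=np.float16)  # half-precision "model"
x = np.ones((1, 4), dtype=np.float32)        # full-precision input

y = align_dtype(x, weights) @ weights
print(y.dtype)  # float16: the product stays in the weights' precision
```

Casting the input once is also cheaper than casting every hidden state after the stack, as the pre-fix code did.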
musetalk/utils/utils.py CHANGED
@@ -8,26 +8,18 @@ from einops import rearrange
  import shutil
  import os.path as osp

- ffmpeg_path = os.getenv('FFMPEG_PATH')
- if ffmpeg_path is None:
-     print("please download ffmpeg-static and export to FFMPEG_PATH. \nFor example: export FFMPEG_PATH=/musetalk/ffmpeg-4.4-amd64-static")
- elif ffmpeg_path not in os.getenv('PATH'):
-     print("add ffmpeg to path")
-     os.environ["PATH"] = f"{ffmpeg_path}:{os.environ['PATH']}"
-
-
  from musetalk.models.vae import VAE
  from musetalk.models.unet import UNet,PositionalEncoding


  def load_all_model(
-     unet_model_path="./models/musetalk/pytorch_model.bin",
-     vae_type="sd-vae-ft-mse",
-     unet_config="./models/musetalk/musetalk.json",
+     unet_model_path=os.path.join("models", "musetalkV15", "unet.pth"),
+     vae_type="sd-vae",
+     unet_config=os.path.join("models", "musetalkV15", "musetalk.json"),
      device=None,
  ):
      vae = VAE(
-         model_path = f"./models/{vae_type}/",
+         model_path = os.path.join("models", vae_type),
      )
      print(f"load unet model from {unet_model_path}")
      unet = UNet(
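Replacing the hard-coded `./models/...` strings with `os.path.join` is what makes these defaults portable to the Windows inference path this commit adds: the join inserts the platform's own separator instead of baking in forward slashes.

```python
import os

# The same default path the diff introduces; os.path.join builds it with
# os.sep, so the string is "models\\musetalkV15\\unet.pth" on Windows and
# "models/musetalkV15/unet.pth" everywhere else.
unet_model_path = os.path.join("models", "musetalkV15", "unet.pth")
print(unet_model_path)
```

Relative paths like this still resolve against the current working directory, so the scripts must be run from the repo root either way.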
requirements.txt CHANGED
@@ -1,15 +1,15 @@
- --extra-index-url https://download.pytorch.org/whl/cu118
- torch==2.0.1
- torchvision==0.15.2
- torchaudio==2.0.2
- diffusers==0.27.2
+ diffusers==0.30.2
  accelerate==0.28.0
+ numpy==1.23.5
  tensorflow==2.12.0
  tensorboard==2.12.0
  opencv-python==4.9.0.80
  soundfile==0.12.1
  transformers==4.39.2
- huggingface_hub==0.25.0
+ huggingface_hub==0.30.2
+ librosa==0.11.0
+ einops==0.8.1
+ gradio==5.24.0

  gdown
  requests
@@ -17,6 +17,4 @@ imageio[ffmpeg]

  omegaconf
  ffmpeg-python
- gradio
- spaces
  moviepy
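Pins like the ones above can be verified programmatically against the running environment. A minimal sketch using only the standard library; `check_pin` is an illustrative helper, not part of the commit:

```python
from importlib.metadata import version, PackageNotFoundError

def check_pin(package: str, expected: str) -> bool:
    """Return True iff `package` is installed at exactly `expected`."""
    try:
        return version(package) == expected
    except PackageNotFoundError:
        return False

# A package that is certainly absent reports False rather than raising:
print(check_pin("definitely-not-installed-pkg", "1.0.0"))  # → False
```

Running such a check at startup gives a clearer error than a deep import failure when, say, `numpy==1.23.5` has been upgraded underneath `tensorflow==2.12.0`.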
scripts/inference.py CHANGED
@@ -8,9 +8,11 @@ import shutil
  import pickle
  import argparse
  import numpy as np
+ import subprocess
  from tqdm import tqdm
  from omegaconf import OmegaConf
  from transformers import WhisperModel
+ import sys

  from musetalk.utils.blending import get_image
  from musetalk.utils.face_parsing import FaceParsing
@@ -18,16 +20,26 @@ from musetalk.utils.audio_processor import AudioProcessor
  from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
  from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder

+ def fast_check_ffmpeg():
+     try:
+         subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
+         return True
+     except Exception:
+         return False
+
  @torch.no_grad()
  def main(args):
      # Configure ffmpeg path
-     if args.ffmpeg_path not in os.getenv('PATH'):
+     if not fast_check_ffmpeg():
          print("Adding ffmpeg to PATH")
-         os.environ["PATH"] = f"{args.ffmpeg_path}:{os.environ['PATH']}"
+         # Choose path separator based on operating system
+         path_separator = ';' if sys.platform == 'win32' else ':'
+         os.environ["PATH"] = f"{args.ffmpeg_path}{path_separator}{os.environ['PATH']}"
+         if not fast_check_ffmpeg():
+             print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")

      # Set computing device
      device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
      # Load model weights
      vae, unet, pe = load_all_model(
          unet_model_path=args.unet_model_path,
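A small editorial note on the hunk above: the hand-rolled `path_separator = ';' if sys.platform == 'win32' else ':'` duplicates what the standard library already exposes as `os.pathsep`. A sketch of the one-line alternative; `prepend_to_path` is a hypothetical helper:

```python
import os

def prepend_to_path(directory, env=None):
    """Prepend `directory` to PATH using os.pathsep
    (';' on Windows, ':' elsewhere) and return the new value."""
    env = os.environ if env is None else env
    env["PATH"] = f"{directory}{os.pathsep}{env.get('PATH', '')}"
    return env["PATH"]

# Demonstrated on a throwaway dict so the real environment is untouched:
fake_env = {"PATH": "/usr/bin"}
print(prepend_to_path("/opt/ffmpeg/bin", fake_env))
```

Using `os.pathsep` keeps the Windows and Linux branches of the fix identical, which is the point of the commit.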
scripts/preprocess.py CHANGED
@@ -12,11 +12,23 @@ from mmpose.structures import merge_data_samples
  import torch
  import numpy as np
  from tqdm import tqdm
+ import sys
+ import subprocess
+
+ def fast_check_ffmpeg():
+     try:
+         subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
+         return True
+     except Exception:
+         return False

  ffmpeg_path = "./ffmpeg-4.4-amd64-static/"
- if ffmpeg_path not in os.getenv('PATH'):
-     print("add ffmpeg to path")
-     os.environ["PATH"] = f"{ffmpeg_path}:{os.environ['PATH']}"
+ if not fast_check_ffmpeg():
+     print("Adding ffmpeg to PATH")
+     # Choose path separator based on operating system
+     path_separator = ';' if sys.platform == 'win32' else ':'
+     os.environ["PATH"] = f"{ffmpeg_path}{path_separator}{os.environ['PATH']}"
+     if not fast_check_ffmpeg():
+         print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")

  class AnalyzeFace:
      def __init__(self, device: Union[str, torch.device], config_file: str, checkpoint_file: str):
scripts/realtime_inference.py CHANGED
@@ -23,6 +23,15 @@ import shutil
  import threading
  import queue
  import time
+ import subprocess
+
+
+ def fast_check_ffmpeg():
+     try:
+         subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
+         return True
+     except Exception:
+         return False


  def video2imgs(vid_path, save_path, ext='.png', cut_frame=10000000):
@@ -318,7 +327,7 @@ if __name__ == "__main__":
      parser.add_argument("--fps", type=int, default=25, help="Video frames per second")
      parser.add_argument("--audio_padding_length_left", type=int, default=2, help="Left padding length for audio")
      parser.add_argument("--audio_padding_length_right", type=int, default=2, help="Right padding length for audio")
-     parser.add_argument("--batch_size", type=int, default=25, help="Batch size for inference")
+     parser.add_argument("--batch_size", type=int, default=20, help="Batch size for inference")
      parser.add_argument("--output_vid_name", type=str, default=None, help="Name of output video file")
      parser.add_argument("--use_saved_coord", action="store_true", help='Use saved coordinates to save time')
      parser.add_argument("--saved_coord", action="store_true", help='Save coordinates for future use')
@@ -332,6 +341,15 @@ if __name__ == "__main__":

      args = parser.parse_args()

+     # Configure ffmpeg path
+     if not fast_check_ffmpeg():
+         print("Adding ffmpeg to PATH")
+         # Choose path separator based on operating system
+         path_separator = ';' if sys.platform == 'win32' else ':'
+         os.environ["PATH"] = f"{args.ffmpeg_path}{path_separator}{os.environ['PATH']}"
+         if not fast_check_ffmpeg():
+             print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")
+
      # Set computing device
      device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
 
test_ffmpeg.py ADDED
@@ -0,0 +1,33 @@
+ import os
+ import subprocess
+ import sys
+
+ def test_ffmpeg(ffmpeg_path):
+     print(f"Testing ffmpeg path: {ffmpeg_path}")
+
+     # Choose path separator based on operating system
+     path_separator = ';' if sys.platform == 'win32' else ':'
+
+     # Add ffmpeg path to the PATH environment variable
+     os.environ["PATH"] = f"{ffmpeg_path}{path_separator}{os.environ['PATH']}"
+
+     try:
+         # Try to run ffmpeg
+         result = subprocess.run(["ffmpeg", "-version"], capture_output=True, text=True)
+         print("FFmpeg test successful!")
+         print("FFmpeg version information:")
+         print(result.stdout)
+         return True
+     except Exception as e:
+         print("FFmpeg test failed!")
+         print(f"Error message: {str(e)}")
+         return False
+
+ if __name__ == "__main__":
+     # Default ffmpeg path, can be modified as needed
+     default_path = r"ffmpeg-master-latest-win64-gpl-shared\bin"
+
+     # Use command line argument if provided, otherwise use default path
+     ffmpeg_path = sys.argv[1] if len(sys.argv) > 1 else default_path
+
+     test_ffmpeg(ffmpeg_path)
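The script's pattern, running a binary with `capture_output=True, text=True` and inspecting the result, works for any executable. The sketch below probes the Python interpreter itself so it runs even where ffmpeg is not installed; `probe_version` is an illustrative name, not part of the commit.

```python
import subprocess
import sys

def probe_version(cmd):
    """Run `cmd`, returning (success, first line of stdout).

    Same shape as test_ffmpeg() above, generalized to any binary.
    """
    result = subprocess.run(cmd, capture_output=True, text=True)
    first_line = result.stdout.splitlines()[0] if result.stdout else ""
    return result.returncode == 0, first_line

ok, line = probe_version([sys.executable, "--version"])
print(ok, line)  # True and a line such as "Python 3.x.y"
```

`text=True` decodes stdout/stderr to str, so the version banner can be split and printed directly.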