feat: windows infer & gradio (#312)
* fix: windows infer
* docs: update readme
* docs: update readme
* feat: v1.5 gradio for windows&linux
* fix: dependencies
* feat: windows infer & gradio
---------
Co-authored-by: NeRF-Factory <zzhizhou66@gmail.com>
- .gitignore +5 -2
- README.md +118 -45
- app.py +311 -167
- assets/figs/gradio.png +3 -0
- assets/figs/gradio_2.png +3 -0
- download_weights.bat +45 -0
- download_weights.sh +37 -0
- musetalk/utils/audio_processor.py +3 -2
- musetalk/utils/utils.py +4 -12
- requirements.txt +6 -8
- scripts/inference.py +15 -3
- scripts/preprocess.py +15 -3
- scripts/realtime_inference.py +19 -1
- test_ffmpeg.py +33 -0
.gitignore
CHANGED

@@ -5,11 +5,14 @@
 *.pyc
 .ipynb_checkpoints
 results/
-
+models/
 **/__pycache__/
 *.py[cod]
 *$py.class
 dataset/
 ffmpeg*
+ffmprobe*
+ffplay*
 debug
 exp_out
+.gradio
README.md
CHANGED

(Removed lines, condensed; some were truncated by extraction and are kept as-is:)
-mim install "mmcv
-mim install "mmdet
-mim install "mmpose
-Download the ffmpeg-static
-export HF_ENDPOINT=https://hf-mirror.com
-huggingface-cli download TMElyralab/MuseTalk --local-dir models/
-- [face-parse-bisent](https://github.com/zllrunning/face-parsing.PyTorch)
-- [resnet18](https://download.pytorch.org/models/resnet18-5c106cde.pth)
-- For MuseTalk 1.5: Use the command above with the V1.5 model path
-- For MuseTalk 1.0: Use the same script but point to the V1.0 model path
@@ -146,50 +146,87 @@
 ## Installation
 To prepare the Python environment and install additional packages such as opencv, diffusers, mmcv, etc., please follow the steps below:
+
+### Build environment
+We recommend Python 3.10 and CUDA 11.7. Set up your environment as follows:
+
+```shell
+conda create -n MuseTalk python==3.10
+conda activate MuseTalk
+```
 
+### Install PyTorch 2.0.1
+Choose one of the following installation methods:
+
+```shell
+# Option 1: Using pip
+pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118
+
+# Option 2: Using conda
+conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 -c pytorch -c nvidia
+```
+
+### Install Dependencies
+Install the remaining required packages:
 
 ```shell
 pip install -r requirements.txt
 ```
 
+### Install MMLab Packages
+Install the MMLab ecosystem packages:
+
 ```bash
+pip install --no-cache-dir -U openmim
+mim install mmengine
+mim install "mmcv==2.0.1"
+mim install "mmdet==3.1.0"
+mim install "mmpose==1.1.0"
 ```
 
+### Setup FFmpeg
+1. [Download](https://github.com/BtbN/FFmpeg-Builds/releases) the ffmpeg-static package
+
+2. Configure FFmpeg based on your operating system:
+
+For Linux:
+```bash
 export FFMPEG_PATH=/path/to/ffmpeg
+# Example:
 export FFMPEG_PATH=/musetalk/ffmpeg-4.4-amd64-static
 ```
+
+For Windows:
+Add the `ffmpeg-xxx\bin` directory to your system's PATH environment variable. Verify the installation by running `ffmpeg -version` in the command prompt; it should display the ffmpeg version information.
+
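The PR also ships a `test_ffmpeg.py` helper (+33 lines in the file list above). A minimal sketch of such a check is shown below; this is illustrative code, not the PR's actual file, and it only assumes ffmpeg is reachable either through `FFMPEG_PATH` or the system PATH:

```python
import os
import shutil
import subprocess

def find_ffmpeg():
    """Locate an ffmpeg binary, honoring FFMPEG_PATH before falling back to PATH."""
    ffmpeg_dir = os.environ.get("FFMPEG_PATH")
    if ffmpeg_dir:
        for name in ("ffmpeg", "ffmpeg.exe"):
            candidate = os.path.join(ffmpeg_dir, name)
            if os.path.isfile(candidate):
                return candidate
    return shutil.which("ffmpeg")

def check_ffmpeg():
    """Return (ok, detail): whether ffmpeg runs, plus its version line or an error hint."""
    path = find_ffmpeg()
    if path is None:
        return False, "ffmpeg not found; set FFMPEG_PATH or add ffmpeg to the system PATH"
    result = subprocess.run([path, "-version"], capture_output=True, text=True)
    first_line = result.stdout.splitlines()[0] if result.stdout else ""
    return result.returncode == 0, first_line

if __name__ == "__main__":
    ok, detail = check_ffmpeg()
    print("OK:" if ok else "MISSING:", detail)
```

Running this before inference gives the same signal as `ffmpeg -version`, but works on both the Linux `FFMPEG_PATH` setup and the Windows PATH setup.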
 ### Download weights
+You can download weights in two ways:
+
+#### Option 1: Using Download Scripts
+We provide two scripts for automatic downloading:
 
+For Linux:
 ```bash
+sh ./download_weights.sh
 ```
 
+For Windows:
+```batch
+# Run the script
+download_weights.bat
+```
+
+#### Option 2: Manual Download
+You can also download the weights manually from the following links:
+
+1. Download our trained [weights](https://huggingface.co/TMElyralab/MuseTalk/tree/main)
 2. Download the weights of other components:
+   - [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse/tree/main)
   - [whisper](https://huggingface.co/openai/whisper-tiny/tree/main)
   - [dwpose](https://huggingface.co/yzd-v/DWPose/tree/main)
   - [syncnet](https://huggingface.co/ByteDance/LatentSync/tree/main)
+   - [face-parse-bisent](https://drive.google.com/file/d/154JgKpzCPW82qINcVieuPH3fZ2e0P812/view?pli=1)
+   - [resnet18](https://download.pytorch.org/models/resnet18-5c106cde.pth)
 
 Finally, these weights should be organized in `models` as follows:
 ```
@@ -207,7 +244,7 @@
 ├── face-parse-bisent
 │   ├── 79999_iter.pth
 │   └── resnet18-5c106cde.pth
+├── sd-vae
 │   ├── config.json
 │   └── diffusion_pytorch_model.bin
 └── whisper
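The new `app.py` verifies this layout at startup. As a standalone sanity check, a short script along these lines (a hypothetical helper, not part of the PR; the path list mirrors the layout above and the checks in `app.py`) can report anything missing:

```python
import os

# Expected files under models/ (mirrors the tree above and app.py's startup checks)
REQUIRED_FILES = [
    "musetalkV15/unet.pth",
    "musetalkV15/musetalk.json",
    "face-parse-bisent/79999_iter.pth",
    "face-parse-bisent/resnet18-5c106cde.pth",
    "sd-vae/config.json",
    "sd-vae/diffusion_pytorch_model.bin",
    "whisper/config.json",
]

def missing_weights(models_dir="models"):
    """Return the expected weight files that are not present under models_dir."""
    return [rel for rel in REQUIRED_FILES
            if not os.path.isfile(os.path.join(models_dir, rel))]

if __name__ == "__main__":
    missing = missing_weights()
    if missing:
        print("Missing weight files:")
        for rel in missing:
            print(" -", rel)
    else:
        print("All expected weight files found.")
```

If anything is listed, re-run `download_weights.sh` (Linux) or `download_weights.bat` (Windows) as described above.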
@@ -221,42 +258,66 @@
 ### Inference
 We provide inference scripts for both versions of MuseTalk:
 
+#### Prerequisites
+Before running inference, please ensure ffmpeg is installed and accessible:
 ```bash
+# Check ffmpeg installation
+ffmpeg -version
 ```
+If ffmpeg is not found, please install it first:
+- Windows: Download from [ffmpeg-static](https://github.com/BtbN/FFmpeg-Builds/releases) and add to PATH
+- Linux: `sudo apt-get install ffmpeg`
 
+#### Normal Inference
+##### Linux Environment
 ```bash
+# MuseTalk 1.5 (Recommended)
+sh inference.sh v1.5 normal
+
+# MuseTalk 1.0
 sh inference.sh v1.0 normal
 ```
 
+##### Windows Environment
 
+Please ensure that you set the `ffmpeg_path` to match the actual location of your FFmpeg installation.
 
+```bash
+# MuseTalk 1.5 (Recommended)
+python -m scripts.inference --inference_config configs\inference\test.yaml --result_dir results\test --unet_model_path models\musetalkV15\unet.pth --unet_config models\musetalkV15\musetalk.json --version v15 --ffmpeg_path ffmpeg-master-latest-win64-gpl-shared\bin
+
+# For MuseTalk 1.0, change:
+# - models\musetalkV15 -> models\musetalk
+# - unet.pth -> pytorch_model.bin
+# - --version v15 -> --version v1
+```
 
 #### Real-time Inference
+##### Linux Environment
 ```bash
+# MuseTalk 1.5 (Recommended)
+sh inference.sh v1.5 realtime
+
+# MuseTalk 1.0
+sh inference.sh v1.0 realtime
 ```
 
+##### Windows Environment
+```bash
+# MuseTalk 1.5 (Recommended)
+python -m scripts.realtime_inference --inference_config configs\inference\realtime.yaml --result_dir results\realtime --unet_model_path models\musetalkV15\unet.pth --unet_config models\musetalkV15\musetalk.json --version v15 --fps 25 --ffmpeg_path ffmpeg-master-latest-win64-gpl-shared\bin
+
+# For MuseTalk 1.0, change:
+# - models\musetalkV15 -> models\musetalk
+# - unet.pth -> pytorch_model.bin
+# - --version v15 -> --version v1
+```
+
+The configuration file `configs/inference/test.yaml` contains the inference settings, including:
+- `video_path`: Path to the input video, image file, or directory of images
+- `audio_path`: Path to the input audio file
+
+Note: For optimal results, we recommend using input videos with 25fps, which is the same fps used during model training. If your video has a lower frame rate, you can use frame interpolation or convert it to 25fps using ffmpeg.
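The fps conversion mentioned in the note can be scripted. The sketch below is a hypothetical helper (not part of the PR) that assembles a standard ffmpeg re-encode command, similar to the commented-out command in the repo's `check_video`:

```python
import subprocess

def build_fps_cmd(src, dst, fps=25):
    """Build an ffmpeg command that re-encodes src at the target frame rate."""
    return ["ffmpeg", "-i", src, "-r", str(fps),
            "-vcodec", "libx264", "-pix_fmt", "yuv420p", "-y", dst]

def convert_to_25fps(src, dst):
    """Run the conversion; requires ffmpeg on PATH (see the Setup FFmpeg section)."""
    subprocess.run(build_fps_cmd(src, dst), check=True)

if __name__ == "__main__":
    print(" ".join(build_fps_cmd("input.mp4", "input_25fps.mp4")))
```

Passing the argument list to `subprocess.run` (rather than a shell string) avoids quoting problems with paths that contain spaces, which is common on Windows.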
 Important notes for real-time inference:
 1. Set `preparation` to `True` when processing a new avatar

@@ -269,6 +330,18 @@
 python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --skip_save_images
 ```
 
+## Gradio Demo
+We provide an intuitive web interface through Gradio for users to easily adjust input parameters. To optimize inference time, users can generate only the **first frame** to fine-tune the best lip-sync parameters, which helps reduce facial artifacts in the final output.
+
+For minimum hardware requirements, we tested the system on a Windows environment using an NVIDIA GeForce RTX 3050 Ti Laptop GPU with 4GB VRAM. In fp16 mode, generating an 8-second video takes approximately 5 minutes.
+
+Both Linux and Windows users can launch the demo using the following command. Please ensure that the `ffmpeg_path` parameter matches your actual FFmpeg installation path:
+
+```bash
+# You can remove --use_float16 for better quality, but it will increase VRAM usage and inference time
+python app.py --use_float16 --ffmpeg_path ffmpeg-master-latest-win64-gpl-shared\bin
+```
+
 ## Training
 
 ### Data Preparation
app.py
CHANGED

@@ -4,7 +4,6 @@ import pdb
 import re
 
 import gradio as gr
-import spaces
 import numpy as np
 import sys
 import subprocess

@@ -28,11 +27,101 @@ import gdown
 import imageio
 import ffmpeg
 from moviepy.editor import *
-
 
 ProjectDir = os.path.abspath(os.path.dirname(__file__))
 CheckpointsDir = os.path.join(ProjectDir, "models")

@@ -40,119 +129,107 @@
     print(child_path)
 
 def download_model():
-    ...  # old body removed: it downloaded all weights at startup via a huggingface_hub
-    ...  # download (max_workers=8, local_dir_use_symlinks=True, force_download=True,
-    ...  # resume_download=False), fetched whisper tiny.pt and resnet18-5c106cde.pth
-    ...  # with requests, and face-parse-bisent/79999_iter.pth with gdown; its status
-    ...  # messages were in Chinese, e.g. "请求失败，状态码" ("request failed, status code")
-    toc = time.time()
-    print(f"download cost {toc-tic} seconds")
-    print_directory_contents(CheckpointsDir)
 
 download_model()  # for huggingface deployment.
 
-from musetalk.utils.utils import get_file_type,get_video_fps,datagen
-from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder,get_bbox_range
 from musetalk.utils.blending import get_image
 
-@spaces.GPU(duration=600)
 @torch.no_grad()
-def inference(audio_path,video_path,bbox_shift,
     args = Namespace(**args_dict)
     input_basename = os.path.basename(video_path).split('.')[0]
     output_basename = f"{input_basename}_{audio_basename}"
-    if args.output_vid_name=="":
-        output_vid_name = os.path.join(
     else:
-        output_vid_name = os.path.join(
     ############################################## extract frames from source video ##############################################
-    if get_file_type(video_path)=="video":
-        save_dir_full = os.path.join(
-        os.makedirs(save_dir_full,
-        # read the video
     reader = imageio.get_reader(video_path)
     for i, im in enumerate(reader):
         imageio.imwrite(f"{save_dir_full}/{i:08d}.png", im)
     input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))

@@ -161,10 +238,21 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T
     input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
     input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
     fps = args.fps
     ############################################## extract audio feature ##############################################

@@ -176,13 +264,22 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T
     coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
     with open(crop_coord_save_path, 'wb') as f:
         pickle.dump(coord_list, f)
-    bbox_shift_text=get_bbox_range(input_img_list, bbox_shift)
     i = 0
     input_latent_list = []
     for bbox, frame in zip(coord_list, frame_list):
         if bbox == coord_placeholder:
             continue
         x1, y1, x2, y2 = bbox
         crop_frame = frame[y1:y2, x1:x2]
         crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
         latents = vae.get_latents_for_unet(crop_frame)

@@ -192,17 +289,23 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T
     frame_list_cycle = frame_list + frame_list[::-1]
     coord_list_cycle = coord_list + coord_list[::-1]
     input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
     ############################################## inference batch by batch ##############################################
     print("start inference")
     video_num = len(whisper_chunks)
     batch_size = args.batch_size
-    gen = datagen(
     res_frame_list = []
     for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))):
-        audio_feature_batch = pe(audio_feature_batch)
         pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
         recon = vae.decode_latents(pred_latents)

@@ -215,25 +318,24 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T
         bbox = coord_list_cycle[i%(len(coord_list_cycle))]
         ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))])
         x1, y1, x2, y2 = bbox
         try:
             res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
         except:
-            # print(bbox)
             continue
         cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png",combine_frame)
 
-    # os.system(cmd_img2video)
-    # frame rate
     fps = 25
-    # output video path
     output_video = 'temp.mp4'
 
     def is_valid_image(file):
         pattern = re.compile(r'\d{8}\.png')
         return pattern.match(file)

@@ -247,13 +349,9 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T
         images.append(imageio.imread(filename))
 
     imageio.mimwrite(output_video, images, 'FFMPEG', fps=fps, codec='libx264', pixelformat='yuv420p')
 
-    # cmd_combine_audio = f"ffmpeg -y -v fatal -i {audio_path} -i temp.mp4 {output_vid_name}"
-    # print(cmd_combine_audio)
-    # os.system(cmd_combine_audio)
     input_video = './temp.mp4'
     # Check if the input_video and audio_path exist
     if not os.path.exists(input_video):

@@ -261,40 +359,15 @@ def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=T
     if not os.path.exists(audio_path):
         raise FileNotFoundError(f"Audio file not found: {audio_path}")
 
     reader = imageio.get_reader(input_video)
-    fps = reader.get_meta_data()['fps']
-    reader.close()
     frames = images
-    # save the video and add the audio
-    # (several commented-out imageio / ffmpeg-python variants removed)
     print(len(frames))
     # Load the video
     video_clip = VideoFileClip(input_video)

@@ -315,11 +388,45 @@
 # load model weights
-audio_processor,vae,unet,pe = load_all_model()
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 def check_video(video):

@@ -340,9 +447,6 @@
     output_video = os.path.join('./results/input', output_file_name)
 
-    # # Run the ffmpeg command to change the frame rate to 25fps
-    # command = f"ffmpeg -i {video} -r 25 -vcodec libx264 -vtag hvc1 -pix_fmt yuv420p crf 18 {output_video} -y"
     # read video
     reader = imageio.get_reader(video)
     fps = reader.get_meta_data()['fps'] # get fps from original video

@@ -374,34 +478,45 @@ css = """#input_img {max-width: 1024px !important} #output_vid {max-width: 1024p
 
 with gr.Blocks(css=css) as demo:
     gr.Markdown(
-        "<div align='center'> <h1>MuseTalk: Real-Time High
         <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
         </br>\
-        Yue Zhang <sup>
         Zhaokang Chen,\
         Bin Wu<sup>†</sup>,\
         Yingjie He,\
-        Wenjiang Zhou\
         (<sup>*</sup>Equal Contribution, <sup>†</sup>Corresponding Author, benbinwu@tencent.com)\
         Lyra Lab, Tencent Music Entertainment\
         </h2> \
         <a style='font-size:18px;color: #000000' href='https://github.com/TMElyralab/MuseTalk'>[Github Repo]</a>\
         <a style='font-size:18px;color: #000000' href='https://github.com/TMElyralab/MuseTalk'>[Huggingface]</a>\
-        <a style='font-size:18px;color: #000000' href=''> [Technical report
-        <a style='font-size:18px;color: #000000' href=''> [Project Page(Coming Soon)] </a> </div>"
     )
 
     with gr.Row():
         with gr.Column():
-            audio = gr.Audio(label="
             video = gr.Video(label="Reference Video",sources=['upload'])
             bbox_shift = gr.Number(label="BBox_shift value, px", value=0)
 
     video.change(
         fn=check_video, inputs=[video], outputs=[video]

@@ -412,15 +527,44 @@ with gr.Blocks(css=css) as demo:
             audio,
             video,
             bbox_shift,
         ],
         outputs=[out1,bbox_shift_scale]
     )
 
 demo.queue().launch(
-    share=
 )
 import re
 
 import gradio as gr
 import numpy as np
 import sys
 import subprocess
 
 import imageio
 import ffmpeg
 from moviepy.editor import *
+from transformers import WhisperModel
 
 ProjectDir = os.path.abspath(os.path.dirname(__file__))
 CheckpointsDir = os.path.join(ProjectDir, "models")
 
+@torch.no_grad()
+def debug_inpainting(video_path, bbox_shift, extra_margin=10, parsing_mode="jaw",
+                     left_cheek_width=90, right_cheek_width=90):
+    """Debug inpainting parameters, only process the first frame"""
+    # Set default parameters
+    args_dict = {
+        "result_dir": './results/debug',
+        "fps": 25,
+        "batch_size": 1,
+        "output_vid_name": '',
+        "use_saved_coord": False,
+        "audio_padding_length_left": 2,
+        "audio_padding_length_right": 2,
+        "version": "v15",
+        "extra_margin": extra_margin,
+        "parsing_mode": parsing_mode,
+        "left_cheek_width": left_cheek_width,
+        "right_cheek_width": right_cheek_width
+    }
+    args = Namespace(**args_dict)
+
+    # Create debug directory
+    os.makedirs(args.result_dir, exist_ok=True)
+
+    # Read first frame
+    if get_file_type(video_path) == "video":
+        reader = imageio.get_reader(video_path)
+        first_frame = reader.get_data(0)
+        reader.close()
+    else:
+        first_frame = cv2.imread(video_path)
+        first_frame = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
+
+    # Save first frame
+    debug_frame_path = os.path.join(args.result_dir, "debug_frame.png")
+    cv2.imwrite(debug_frame_path, cv2.cvtColor(first_frame, cv2.COLOR_RGB2BGR))
+
+    # Get face coordinates
+    coord_list, frame_list = get_landmark_and_bbox([debug_frame_path], bbox_shift)
+    bbox = coord_list[0]
+    frame = frame_list[0]
+
+    if bbox == coord_placeholder:
+        return None, "No face detected, please adjust bbox_shift parameter"
+
+    # Initialize face parser
+    fp = FaceParsing(
+        left_cheek_width=args.left_cheek_width,
+        right_cheek_width=args.right_cheek_width
+    )
+
+    # Process first frame
+    x1, y1, x2, y2 = bbox
+    y2 = y2 + args.extra_margin
+    y2 = min(y2, frame.shape[0])
+    crop_frame = frame[y1:y2, x1:x2]
+    crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
+
+    # Generate random audio features
+    random_audio = torch.randn(1, 50, 384, device=device, dtype=weight_dtype)
+    audio_feature = pe(random_audio)
+
+    # Get latents
+    latents = vae.get_latents_for_unet(crop_frame)
+    latents = latents.to(dtype=weight_dtype)
+
+    # Generate prediction results
+    pred_latents = unet.model(latents, timesteps, encoder_hidden_states=audio_feature).sample
+    recon = vae.decode_latents(pred_latents)
+
+    # Inpaint back to original image
+    res_frame = recon[0]
+    res_frame = cv2.resize(res_frame.astype(np.uint8), (x2-x1, y2-y1))
+    combine_frame = get_image(frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)
+
+    # Save results (no need to convert color space again since get_image already returns RGB format)
+    debug_result_path = os.path.join(args.result_dir, "debug_result.png")
+    cv2.imwrite(debug_result_path, combine_frame)
+
+    # Create information text
+    info_text = f"Parameter information:\n" + \
+                f"bbox_shift: {bbox_shift}\n" + \
+                f"extra_margin: {extra_margin}\n" + \
+                f"parsing_mode: {parsing_mode}\n" + \
+                f"left_cheek_width: {left_cheek_width}\n" + \
+                f"right_cheek_width: {right_cheek_width}\n" + \
+                f"Detected face coordinates: [{x1}, {y1}, {x2}, {y2}]"
+
+    return cv2.cvtColor(combine_frame, cv2.COLOR_RGB2BGR), info_text
+
 def print_directory_contents(path):
     for child in os.listdir(path):
         child_path = os.path.join(path, child)
         print(child_path)
 
 def download_model():
+    # Check that the required model files exist
+    required_models = {
+        "MuseTalk UNet": f"{CheckpointsDir}/musetalkV15/unet.pth",
+        "MuseTalk config": f"{CheckpointsDir}/musetalkV15/musetalk.json",
+        "SD VAE": f"{CheckpointsDir}/sd-vae/config.json",
+        "Whisper": f"{CheckpointsDir}/whisper/config.json",
+        "DWPose": f"{CheckpointsDir}/dwpose/dw-ll_ucoco_384.pth",
+        "SyncNet": f"{CheckpointsDir}/syncnet/latentsync_syncnet.pt",
+        "Face Parse": f"{CheckpointsDir}/face-parse-bisent/79999_iter.pth",
+        "ResNet": f"{CheckpointsDir}/face-parse-bisent/resnet18-5c106cde.pth"
+    }
+
+    missing_models = []
+    for model_name, model_path in required_models.items():
+        if not os.path.exists(model_path):
+            missing_models.append(model_name)
+
+    if missing_models:
+        print("The following required model files are missing:")
+        for model in missing_models:
+            print(f"- {model}")
+        print("\nPlease run the download script to download the missing models:")
+        if sys.platform == "win32":
+            print("Windows: Run download_weights.bat")
+        else:
+            print("Linux/Mac: Run ./download_weights.sh")
+        sys.exit(1)
+    else:
+        print("All required model files exist.")
 
 download_model()  # for huggingface deployment.
 
 from musetalk.utils.blending import get_image
+from musetalk.utils.face_parsing import FaceParsing
+from musetalk.utils.audio_processor import AudioProcessor
+from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
+from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder, get_bbox_range
| 175 |
+
def fast_check_ffmpeg():
|
| 176 |
+
try:
|
| 177 |
+
subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
|
| 178 |
+
return True
|
| 179 |
+
except:
|
| 180 |
+
return False
|
| 181 |
|
| 182 |
|
|
|
|
|
|
|
|
|
|
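The `fast_check_ffmpeg` helper above is duplicated across the app and the scripts in this PR. A standalone sketch of the same probe, with the bare `except:` narrowed to the two failure modes (binary not on PATH, binary exiting with an error) — the function name here is illustrative:

```python
import subprocess

def ffmpeg_available() -> bool:
    # Returns True only if an `ffmpeg` binary is on PATH and exits cleanly.
    try:
        subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
        return True
    except (OSError, subprocess.CalledProcessError):
        return False

print(ffmpeg_available())
```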
@torch.no_grad()
def inference(audio_path, video_path, bbox_shift, extra_margin=10, parsing_mode="jaw",
              left_cheek_width=90, right_cheek_width=90, progress=gr.Progress(track_tqdm=True)):
    # Set default parameters, aligned with inference.py
    args_dict = {
        "result_dir": './results/output',
        "fps": 25,
        "batch_size": 8,
        "output_vid_name": '',
        "use_saved_coord": False,
        "audio_padding_length_left": 2,
        "audio_padding_length_right": 2,
        "version": "v15",  # Fixed to the v15 version
        "extra_margin": extra_margin,
        "parsing_mode": parsing_mode,
        "left_cheek_width": left_cheek_width,
        "right_cheek_width": right_cheek_width
    }
    args = Namespace(**args_dict)

    # Check ffmpeg
    if not fast_check_ffmpeg():
        print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")

    input_basename = os.path.basename(video_path).split('.')[0]
    audio_basename = os.path.basename(audio_path).split('.')[0]
    output_basename = f"{input_basename}_{audio_basename}"

    # Create temporary directory
    temp_dir = os.path.join(args.result_dir, f"{args.version}")
    os.makedirs(temp_dir, exist_ok=True)

    # Set result save path
    result_img_save_path = os.path.join(temp_dir, output_basename)
    crop_coord_save_path = os.path.join(args.result_dir, "../", input_basename + ".pkl")
    os.makedirs(result_img_save_path, exist_ok=True)

    if args.output_vid_name == "":
        output_vid_name = os.path.join(temp_dir, output_basename + ".mp4")
    else:
        output_vid_name = os.path.join(temp_dir, args.output_vid_name)

    ############################################## extract frames from source video ##############################################
    if get_file_type(video_path) == "video":
        save_dir_full = os.path.join(temp_dir, input_basename)
        os.makedirs(save_dir_full, exist_ok=True)
        # Read video
        reader = imageio.get_reader(video_path)

        # Save images
        for i, im in enumerate(reader):
            imageio.imwrite(f"{save_dir_full}/{i:08d}.png", im)
        input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
        ...
        input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
        input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
        fps = args.fps

    ############################################## extract audio feature ##############################################
    # Extract audio features
    whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
    whisper_chunks = audio_processor.get_whisper_chunk(
        whisper_input_features,
        device,
        weight_dtype,
        whisper,
        librosa_length,
        fps=fps,
        audio_padding_length_left=args.audio_padding_length_left,
        audio_padding_length_right=args.audio_padding_length_right,
    )

    ############################################## preprocess input image ##############################################
    if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
        print("using extracted coordinates")
    ...
        coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
        with open(crop_coord_save_path, 'wb') as f:
            pickle.dump(coord_list, f)
    bbox_shift_text = get_bbox_range(input_img_list, bbox_shift)

    # Initialize face parser
    fp = FaceParsing(
        left_cheek_width=args.left_cheek_width,
        right_cheek_width=args.right_cheek_width
    )

    i = 0
    input_latent_list = []
    for bbox, frame in zip(coord_list, frame_list):
        if bbox == coord_placeholder:
            continue
        x1, y1, x2, y2 = bbox
        y2 = y2 + args.extra_margin
        y2 = min(y2, frame.shape[0])
        crop_frame = frame[y1:y2, x1:x2]
        crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
        latents = vae.get_latents_for_unet(crop_frame)
        ...
    frame_list_cycle = frame_list + frame_list[::-1]
    coord_list_cycle = coord_list + coord_list[::-1]
    input_latent_list_cycle = input_latent_list + input_latent_list[::-1]

    ############################################## inference batch by batch ##############################################
    print("start inference")
    video_num = len(whisper_chunks)
    batch_size = args.batch_size
    gen = datagen(
        whisper_chunks=whisper_chunks,
        vae_encode_latents=input_latent_list_cycle,
        batch_size=batch_size,
        delay_frame=0,
        device=device,
    )
    res_frame_list = []
    for i, (whisper_batch, latent_batch) in enumerate(tqdm(gen, total=int(np.ceil(float(video_num) / batch_size)))):
        audio_feature_batch = pe(whisper_batch)
        # Ensure latent_batch is consistent with the model weight type
        latent_batch = latent_batch.to(dtype=weight_dtype)

        pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
        recon = vae.decode_latents(pred_latents)
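The `frame_list + frame_list[::-1]` lines above build a ping-pong cycle: when the driving audio has more chunks than the video has frames, the modulo indexing used later in the loop plays the video forwards and then backwards instead of jumping from the last frame to the first. A minimal sketch of the idea:

```python
# Ping-pong cycling: index a forward+reversed copy of the frame list
# with a modulo so playback reverses instead of cutting back to frame 0.
frames = ["f0", "f1", "f2"]
cycle = frames + frames[::-1]  # f0 f1 f2 f2 f1 f0
picked = [cycle[i % len(cycle)] for i in range(8)]
print(picked)
```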
|
|
|
        ...
        bbox = coord_list_cycle[i % (len(coord_list_cycle))]
        ori_frame = copy.deepcopy(frame_list_cycle[i % (len(frame_list_cycle))])
        x1, y1, x2, y2 = bbox
        y2 = y2 + args.extra_margin
        y2 = min(y2, frame.shape[0])
        try:
            res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
        except:
            continue

        # Use v15 version blending
        combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)

        cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png", combine_frame)

    # Frame rate
    fps = 25
    # Output video path
    output_video = 'temp.mp4'

    # Read images
    def is_valid_image(file):
        pattern = re.compile(r'\d{8}\.png')
        return pattern.match(file)
    ...
        images.append(imageio.imread(filename))

    # Save video
    imageio.mimwrite(output_video, images, 'FFMPEG', fps=fps, codec='libx264', pixelformat='yuv420p')

    input_video = './temp.mp4'
    # Check if the input_video and audio_path exist
    if not os.path.exists(input_video):
        ...
    if not os.path.exists(audio_path):
        raise FileNotFoundError(f"Audio file not found: {audio_path}")

    # Read video
    reader = imageio.get_reader(input_video)
    fps = reader.get_meta_data()['fps']  # Get the original video frame rate
    reader.close()  # Otherwise, on Windows 11: PermissionError: [WinError 32] The process cannot access the file because it is being used by another process: 'temp.mp4'
    # Store frames in a list
    frames = images

    print(len(frames))

    # Load the video
    video_clip = VideoFileClip(input_video)
    ...
# load model weights
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vae, unet, pe = load_all_model(
    unet_model_path="./models/musetalkV15/unet.pth",
    vae_type="sd-vae",
    unet_config="./models/musetalkV15/musetalk.json",
    device=device
)

# Parse command line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--ffmpeg_path", type=str, default=r"ffmpeg-master-latest-win64-gpl-shared\bin", help="Path to ffmpeg executable")
parser.add_argument("--ip", type=str, default="127.0.0.1", help="IP address to bind to")
parser.add_argument("--port", type=int, default=7860, help="Port to bind to")
parser.add_argument("--share", action="store_true", help="Create a public link")
parser.add_argument("--use_float16", action="store_true", help="Use float16 for faster inference")
args = parser.parse_args()

# Set data type
if args.use_float16:
    # Convert models to half precision for better performance
    pe = pe.half()
    vae.vae = vae.vae.half()
    unet.model = unet.model.half()
    weight_dtype = torch.float16
else:
    weight_dtype = torch.float32

# Move models to the specified device
pe = pe.to(device)
vae.vae = vae.vae.to(device)
unet.model = unet.model.to(device)

timesteps = torch.tensor([0], device=device)

# Initialize audio processor and Whisper model
audio_processor = AudioProcessor(feature_extractor_path="./models/whisper")
whisper = WhisperModel.from_pretrained("./models/whisper")
whisper = whisper.to(device=device, dtype=weight_dtype).eval()
whisper.requires_grad_(False)
def check_video(video):
    ...
    output_video = os.path.join('./results/input', output_file_name)

    # read video
    reader = imageio.get_reader(video)
    fps = reader.get_meta_data()['fps']  # get fps from the original video
    ...
with gr.Blocks(css=css) as demo:
    gr.Markdown(
        """<div align='center'> <h1>MuseTalk: Real-Time High-Fidelity Video Dubbing via Spatio-Temporal Sampling</h1> \
            <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
            </br>\
            Yue Zhang <sup>*</sup>,\
            Zhizhou Zhong <sup>*</sup>,\
            Minhao Liu<sup>*</sup>,\
            Zhaokang Chen,\
            Bin Wu<sup>†</sup>,\
            Yubin Zeng,\
            Chao Zhang,\
            Yingjie He,\
            Junxin Huang,\
            Wenjiang Zhou <br>\
            (<sup>*</sup>Equal Contribution, <sup>†</sup>Corresponding Author, benbinwu@tencent.com)\
            Lyra Lab, Tencent Music Entertainment\
            </h2> \
            <a style='font-size:18px;color: #000000' href='https://github.com/TMElyralab/MuseTalk'>[Github Repo]</a>\
            <a style='font-size:18px;color: #000000' href='https://github.com/TMElyralab/MuseTalk'>[Huggingface]</a>\
            <a style='font-size:18px;color: #000000' href='https://arxiv.org/abs/2410.10122'> [Technical report] </a>"""
    )

    with gr.Row():
        with gr.Column():
            audio = gr.Audio(label="Driving Audio", type="filepath")
            video = gr.Video(label="Reference Video", sources=['upload'])
            bbox_shift = gr.Number(label="BBox_shift value, px", value=0)
            extra_margin = gr.Slider(label="Extra Margin", minimum=0, maximum=40, value=10, step=1)
            parsing_mode = gr.Radio(label="Parsing Mode", choices=["jaw", "raw"], value="jaw")
            left_cheek_width = gr.Slider(label="Left Cheek Width", minimum=20, maximum=160, value=90, step=5)
            right_cheek_width = gr.Slider(label="Right Cheek Width", minimum=20, maximum=160, value=90, step=5)
            bbox_shift_scale = gr.Textbox(label="When parsing mode is 'jaw', the 'left_cheek_width' and 'right_cheek_width' parameters determine how much of the left and right cheeks can be edited, and 'extra_margin' determines the movement range of the jaw. Users can freely adjust these three parameters to obtain better inpainting results.")

            with gr.Row():
                debug_btn = gr.Button("1. Test Inpainting")
                btn = gr.Button("2. Generate")
        with gr.Column():
            debug_image = gr.Image(label="Test Inpainting Result (First Frame)")
            debug_info = gr.Textbox(label="Parameter Information", lines=5)
            out1 = gr.Video()

    video.change(
        fn=check_video, inputs=[video], outputs=[video]
    )
    ...
            audio,
            video,
            bbox_shift,
            extra_margin,
            parsing_mode,
            left_cheek_width,
            right_cheek_width
        ],
        outputs=[out1, bbox_shift_scale]
    )
    debug_btn.click(
        fn=debug_inpainting,
        inputs=[
            video,
            bbox_shift,
            extra_margin,
            parsing_mode,
            left_cheek_width,
            right_cheek_width
        ],
        outputs=[debug_image, debug_info]
    )

# Check ffmpeg and add it to PATH
if not fast_check_ffmpeg():
    print(f"Adding ffmpeg to PATH: {args.ffmpeg_path}")
    # Choose the path separator according to the operating system
    path_separator = ';' if sys.platform == 'win32' else ':'
    os.environ["PATH"] = f"{args.ffmpeg_path}{path_separator}{os.environ['PATH']}"
    if not fast_check_ffmpeg():
        print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")

# Work around asynchronous IO issues on Windows
if sys.platform == 'win32':
    import asyncio
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

# Start the Gradio application
demo.queue().launch(
    share=args.share,
    debug=True,
    server_name=args.ip,
    server_port=args.port
)
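The separator branch repeated throughout this PR (`';' if sys.platform == 'win32' else ':'`) can also be written with the stdlib constant `os.pathsep`, which is `';'` on Windows and `':'` on POSIX. A small sketch of that alternative (the helper name is illustrative, not part of the PR):

```python
import os

def prepend_to_path(directory: str) -> None:
    # os.pathsep is the PATH-entry separator for the current platform,
    # so no explicit sys.platform branch is needed.
    os.environ["PATH"] = f"{directory}{os.pathsep}{os.environ['PATH']}"

prepend_to_path("/opt/ffmpeg/bin")
```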
assets/figs/gradio.png
ADDED (Git LFS)

assets/figs/gradio_2.png
ADDED (Git LFS)
download_weights.bat
ADDED

@@ -0,0 +1,45 @@
@echo off
setlocal

:: Set the checkpoints directory
set CheckpointsDir=models

:: Create necessary directories
mkdir %CheckpointsDir%\musetalk
mkdir %CheckpointsDir%\musetalkV15
mkdir %CheckpointsDir%\syncnet
mkdir %CheckpointsDir%\dwpose
mkdir %CheckpointsDir%\face-parse-bisent
mkdir %CheckpointsDir%\sd-vae-ft-mse
mkdir %CheckpointsDir%\whisper

:: Install required packages
pip install -U "huggingface_hub[cli]"
pip install gdown

:: Set HuggingFace endpoint
set HF_ENDPOINT=https://hf-mirror.com

:: Download MuseTalk weights
huggingface-cli download TMElyralab/MuseTalk --local-dir %CheckpointsDir%

:: Download SD VAE weights
huggingface-cli download stabilityai/sd-vae-ft-mse --local-dir %CheckpointsDir%\sd-vae --include "config.json" "diffusion_pytorch_model.bin"

:: Download Whisper weights
huggingface-cli download openai/whisper-tiny --local-dir %CheckpointsDir%\whisper --include "config.json" "pytorch_model.bin" "preprocessor_config.json"

:: Download DWPose weights
huggingface-cli download yzd-v/DWPose --local-dir %CheckpointsDir%\dwpose --include "dw-ll_ucoco_384.pth"

:: Download SyncNet weights
huggingface-cli download ByteDance/LatentSync --local-dir %CheckpointsDir%\syncnet --include "latentsync_syncnet.pt"

:: Download Face Parse Bisent weights (using gdown)
gdown --id 154JgKpzCPW82qINcVieuPH3fZ2e0P812 -O %CheckpointsDir%\face-parse-bisent\79999_iter.pth

:: Download ResNet weights
curl -L https://download.pytorch.org/models/resnet18-5c106cde.pth -o %CheckpointsDir%\face-parse-bisent\resnet18-5c106cde.pth

echo All weights have been downloaded successfully!
endlocal
download_weights.sh
ADDED

@@ -0,0 +1,37 @@
#!/bin/bash

# Set the checkpoints directory
CheckpointsDir="models"

# Create necessary directories
mkdir -p $CheckpointsDir/{musetalk,musetalkV15,syncnet,dwpose,face-parse-bisent,sd-vae-ft-mse,whisper}

# Install required packages
pip install -U "huggingface_hub[cli]"
pip install gdown

# Set HuggingFace endpoint
export HF_ENDPOINT=https://hf-mirror.com

# Download MuseTalk weights
huggingface-cli download TMElyralab/MuseTalk --local-dir $CheckpointsDir

# Download SD VAE weights
huggingface-cli download stabilityai/sd-vae-ft-mse --local-dir $CheckpointsDir/sd-vae --include "config.json" "diffusion_pytorch_model.bin"

# Download Whisper weights
huggingface-cli download openai/whisper-tiny --local-dir $CheckpointsDir/whisper --include "config.json" "pytorch_model.bin" "preprocessor_config.json"

# Download DWPose weights
huggingface-cli download yzd-v/DWPose --local-dir $CheckpointsDir/dwpose --include "dw-ll_ucoco_384.pth"

# Download SyncNet weights
huggingface-cli download ByteDance/LatentSync --local-dir $CheckpointsDir/syncnet --include "latentsync_syncnet.pt"

# Download Face Parse Bisent weights (using gdown)
gdown --id 154JgKpzCPW82qINcVieuPH3fZ2e0P812 -O $CheckpointsDir/face-parse-bisent/79999_iter.pth

# Download ResNet weights
curl -L https://download.pytorch.org/models/resnet18-5c106cde.pth -o $CheckpointsDir/face-parse-bisent/resnet18-5c106cde.pth

echo "All weights have been downloaded successfully!"
musetalk/utils/audio_processor.py
CHANGED

@@ -49,8 +49,9 @@ class AudioProcessor:
         whisper_feature = []
         # Process multiple 30s mel input features
         for input_feature in whisper_input_features:
-
-            audio_feats =
+            input_feature = input_feature.to(device).to(weight_dtype)
+            audio_feats = whisper.encoder(input_feature, output_hidden_states=True).hidden_states
+            audio_feats = torch.stack(audio_feats, dim=2)
             whisper_feature.append(audio_feats)
 
         whisper_feature = torch.cat(whisper_feature, dim=1)
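The new code stacks all of the Whisper encoder's per-layer hidden states along a new "layer" axis before the chunks are concatenated in time. A shape-only sketch of that stacking with NumPy stand-ins (the layer count and dimensions here are illustrative assumptions, not Whisper's actual sizes):

```python
import numpy as np

# Stand-in for whisper.encoder(...).hidden_states: one feature map per
# encoder layer, each shaped (batch, frames, dim).
hidden_states = [np.zeros((1, 1500, 384)) for _ in range(5)]

# Stack along a new layer axis (dim=2 in the torch code above),
# giving (batch, frames, layers, dim).
stacked = np.stack(hidden_states, axis=2)
print(stacked.shape)
```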
musetalk/utils/utils.py
CHANGED

@@ -8,26 +8,18 @@ from einops import rearrange
 import shutil
 import os.path as osp
 
-ffmpeg_path = os.getenv('FFMPEG_PATH')
-if ffmpeg_path is None:
-    print("please download ffmpeg-static and export to FFMPEG_PATH. \nFor example: export FFMPEG_PATH=/musetalk/ffmpeg-4.4-amd64-static")
-elif ffmpeg_path not in os.getenv('PATH'):
-    print("add ffmpeg to path")
-    os.environ["PATH"] = f"{ffmpeg_path}:{os.environ['PATH']}"
-
-
 from musetalk.models.vae import VAE
 from musetalk.models.unet import UNet,PositionalEncoding
 
 
 def load_all_model(
-    unet_model_path=
-    vae_type="sd-vae
-    unet_config=
+    unet_model_path=os.path.join("models", "musetalkV15", "unet.pth"),
+    vae_type="sd-vae",
+    unet_config=os.path.join("models", "musetalkV15", "musetalk.json"),
     device=None,
 ):
     vae = VAE(
-        model_path =
+        model_path = os.path.join("models", vae_type),
     )
     print(f"load unet model from {unet_model_path}")
     unet = UNet(
requirements.txt
CHANGED

@@ -1,15 +1,15 @@
-torch==2.0.1
-torchvision==0.15.2
-torchaudio==2.0.2
-diffusers==0.27.2
+diffusers==0.30.2
 accelerate==0.28.0
+numpy==1.23.5
 tensorflow==2.12.0
 tensorboard==2.12.0
 opencv-python==4.9.0.80
 soundfile==0.12.1
 transformers==4.39.2
-huggingface_hub==0.
+huggingface_hub==0.30.2
+librosa==0.11.0
+einops==0.8.1
+gradio==5.24.0
 
 gdown
 requests
@@ -17,6 +17,4 @@ imageio[ffmpeg]
 
 omegaconf
 ffmpeg-python
-gradio
-spaces
 moviepy
scripts/inference.py
CHANGED

@@ -8,9 +8,11 @@ import shutil
 import pickle
 import argparse
 import numpy as np
+import subprocess
 from tqdm import tqdm
 from omegaconf import OmegaConf
 from transformers import WhisperModel
+import sys
 
 from musetalk.utils.blending import get_image
 from musetalk.utils.face_parsing import FaceParsing
@@ -18,16 +20,26 @@ from musetalk.utils.audio_processor import AudioProcessor
 from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
 from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
 
+def fast_check_ffmpeg():
+    try:
+        subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
+        return True
+    except:
+        return False
+
 @torch.no_grad()
 def main(args):
     # Configure ffmpeg path
-    if
+    if not fast_check_ffmpeg():
         print("Adding ffmpeg to PATH")
-
+        # Choose the path separator based on the operating system
+        path_separator = ';' if sys.platform == 'win32' else ':'
+        os.environ["PATH"] = f"{args.ffmpeg_path}{path_separator}{os.environ['PATH']}"
+        if not fast_check_ffmpeg():
+            print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")
 
     # Set computing device
     device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
     # Load model weights
     vae, unet, pe = load_all_model(
         unet_model_path=args.unet_model_path,
scripts/preprocess.py
CHANGED

@@ -12,11 +12,23 @@ from mmpose.structures import merge_data_samples
 import torch
 import numpy as np
 from tqdm import tqdm
+import sys
+
+def fast_check_ffmpeg():
+    try:
+        subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
+        return True
+    except:
+        return False
 
 ffmpeg_path = "./ffmpeg-4.4-amd64-static/"
-if
-print("
-
+if not fast_check_ffmpeg():
+    print("Adding ffmpeg to PATH")
+    # Choose the path separator based on the operating system
+    path_separator = ';' if sys.platform == 'win32' else ':'
+    os.environ["PATH"] = f"{args.ffmpeg_path}{path_separator}{os.environ['PATH']}"
+    if not fast_check_ffmpeg():
+        print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")
 
 class AnalyzeFace:
     def __init__(self, device: Union[str, torch.device], config_file: str, checkpoint_file: str):
scripts/realtime_inference.py
CHANGED

@@ -23,6 +23,15 @@ import shutil
 import threading
 import queue
 import time
+import subprocess
+
+
+def fast_check_ffmpeg():
+    try:
+        subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
+        return True
+    except:
+        return False
 
 
 def video2imgs(vid_path, save_path, ext='.png', cut_frame=10000000):
@@ -318,7 +327,7 @@ if __name__ == "__main__":
     parser.add_argument("--fps", type=int, default=25, help="Video frames per second")
     parser.add_argument("--audio_padding_length_left", type=int, default=2, help="Left padding length for audio")
     parser.add_argument("--audio_padding_length_right", type=int, default=2, help="Right padding length for audio")
-    parser.add_argument("--batch_size", type=int, default=
+    parser.add_argument("--batch_size", type=int, default=20, help="Batch size for inference")
     parser.add_argument("--output_vid_name", type=str, default=None, help="Name of output video file")
     parser.add_argument("--use_saved_coord", action="store_true", help='Use saved coordinates to save time')
     parser.add_argument("--saved_coord", action="store_true", help='Save coordinates for future use')
@@ -332,6 +341,15 @@ if __name__ == "__main__":
 
     args = parser.parse_args()
 
+    # Configure ffmpeg path
+    if not fast_check_ffmpeg():
+        print("Adding ffmpeg to PATH")
+        # Choose the path separator based on the operating system
+        path_separator = ';' if sys.platform == 'win32' else ':'
+        os.environ["PATH"] = f"{args.ffmpeg_path}{path_separator}{os.environ['PATH']}"
+        if not fast_check_ffmpeg():
+            print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")
+
     # Set computing device
     device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
test_ffmpeg.py
ADDED

@@ -0,0 +1,33 @@
import os
import subprocess
import sys

def test_ffmpeg(ffmpeg_path):
    print(f"Testing ffmpeg path: {ffmpeg_path}")

    # Choose the path separator based on the operating system
    path_separator = ';' if sys.platform == 'win32' else ':'

    # Add the ffmpeg path to the PATH environment variable
    os.environ["PATH"] = f"{ffmpeg_path}{path_separator}{os.environ['PATH']}"

    try:
        # Try to run ffmpeg
        result = subprocess.run(["ffmpeg", "-version"], capture_output=True, text=True)
        print("FFmpeg test successful!")
        print("FFmpeg version information:")
        print(result.stdout)
        return True
    except Exception as e:
        print("FFmpeg test failed!")
        print(f"Error message: {str(e)}")
        return False

if __name__ == "__main__":
    # Default ffmpeg path, can be modified as needed
    default_path = r"ffmpeg-master-latest-win64-gpl-shared\bin"

    # Use the command line argument if provided, otherwise use the default path
    ffmpeg_path = sys.argv[1] if len(sys.argv) > 1 else default_path

    test_ffmpeg(ffmpeg_path)