Spaces:
Running
Running
fix: resolve NameError, align model paths to models/, remove @spaces.GPU
Browse files
app.py
CHANGED
|
@@ -1,294 +1,283 @@
|
|
| 1 |
import os
|
| 2 |
import time
|
| 3 |
-
import pdb
|
| 4 |
-
|
| 5 |
-
import gradio as gr
|
| 6 |
-
import spaces
|
| 7 |
-
import numpy as np
|
| 8 |
import sys
|
| 9 |
import subprocess
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
-
|
| 12 |
-
import requests
|
| 13 |
-
|
| 14 |
-
import argparse
|
| 15 |
-
import os
|
| 16 |
-
from omegaconf import OmegaConf
|
| 17 |
import numpy as np
|
| 18 |
import cv2
|
| 19 |
import torch
|
| 20 |
-
import glob
|
| 21 |
-
import pickle
|
| 22 |
from tqdm import tqdm
|
| 23 |
-
import copy
|
| 24 |
from argparse import Namespace
|
| 25 |
-
import
|
| 26 |
-
import
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
|
| 29 |
def download_model():
|
| 30 |
-
if not
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
snapshot_download(
|
| 35 |
repo_id="TMElyralab/MuseTalk",
|
| 36 |
-
local_dir=
|
| 37 |
max_workers=8,
|
| 38 |
local_dir_use_symlinks=True,
|
|
|
|
| 39 |
)
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
| 41 |
snapshot_download(
|
| 42 |
repo_id="stabilityai/sd-vae-ft-mse",
|
| 43 |
-
local_dir=
|
| 44 |
max_workers=8,
|
| 45 |
local_dir_use_symlinks=True,
|
|
|
|
| 46 |
)
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
snapshot_download(
|
| 49 |
repo_id="yzd-v/DWPose",
|
| 50 |
-
local_dir=
|
| 51 |
max_workers=8,
|
| 52 |
local_dir_use_symlinks=True,
|
|
|
|
| 53 |
)
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
response = requests.get(url)
|
| 57 |
-
# 确保请求成功
|
| 58 |
-
if response.status_code == 200:
|
| 59 |
-
# 指定文件保存的位置
|
| 60 |
-
file_path = f"{CheckpointsDir}/whisper/tiny.pt"
|
| 61 |
-
os.makedirs(f"{CheckpointsDir}/whisper/")
|
| 62 |
-
# 将文件内容写入指定位置
|
| 63 |
-
with open(file_path, "wb") as f:
|
| 64 |
-
f.write(response.content)
|
| 65 |
-
else:
|
| 66 |
-
print(f"请求失败,状态码:{response.status_code}")
|
| 67 |
-
#gdown face parse
|
| 68 |
-
url = "https://drive.google.com/uc?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812"
|
| 69 |
-
os.makedirs(f"{CheckpointsDir}/face-parse-bisent/")
|
| 70 |
-
file_path = f"{CheckpointsDir}/face-parse-bisent/79999_iter.pth"
|
| 71 |
-
gdown.download(url, output, quiet=False)
|
| 72 |
-
#resnet
|
| 73 |
-
url = "https://download.pytorch.org/models/resnet18-5c106cde.pth"
|
| 74 |
-
response = requests.get(url)
|
| 75 |
-
# 确保请求成功
|
| 76 |
-
if response.status_code == 200:
|
| 77 |
-
# 指定文件保存的位置
|
| 78 |
-
file_path = f"{CheckpointsDir}/face-parse-bisent/resnet18-5c106cde.pth"
|
| 79 |
-
# 将文件内容写入指定位置
|
| 80 |
-
with open(file_path, "wb") as f:
|
| 81 |
-
f.write(response.content)
|
| 82 |
-
else:
|
| 83 |
-
print(f"请求失败,状态码:{response.status_code}")
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
toc = time.time()
|
| 87 |
-
|
| 88 |
-
print(f"download cost {toc-tic} seconds")
|
| 89 |
-
else:
|
| 90 |
-
print("Already download the model.")
|
| 91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
|
|
|
|
|
|
| 93 |
|
| 94 |
-
download_model() # for huggingface deployment.
|
| 95 |
|
|
|
|
| 96 |
|
| 97 |
-
from musetalk.utils.utils import get_file_type,get_video_fps,datagen
|
| 98 |
-
from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder
|
| 99 |
from musetalk.utils.blending import get_image
|
| 100 |
from musetalk.utils.utils import load_all_model
|
| 101 |
|
| 102 |
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
|
| 107 |
|
| 108 |
-
@spaces.GPU(duration=600)
|
| 109 |
@torch.no_grad()
|
| 110 |
-
def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=True)):
|
| 111 |
-
args_dict=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
args = Namespace(**args_dict)
|
| 113 |
|
| 114 |
-
input_basename = os.path.basename(video_path).split(
|
| 115 |
-
audio_basename
|
| 116 |
output_basename = f"{input_basename}_{audio_basename}"
|
| 117 |
-
result_img_save_path = os.path.join(args.result_dir, output_basename)
|
| 118 |
-
crop_coord_save_path = os.path.join(result_img_save_path, input_basename+".pkl")
|
| 119 |
-
os.makedirs(result_img_save_path,
|
| 120 |
|
| 121 |
-
if args.output_vid_name=="":
|
| 122 |
-
output_vid_name = os.path.join(args.result_dir, output_basename+".mp4")
|
| 123 |
else:
|
| 124 |
output_vid_name = os.path.join(args.result_dir, args.output_vid_name)
|
| 125 |
-
|
| 126 |
-
if get_file_type(video_path)=="video":
|
| 127 |
save_dir_full = os.path.join(args.result_dir, input_basename)
|
| 128 |
-
os.makedirs(save_dir_full,
|
| 129 |
cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
|
| 130 |
os.system(cmd)
|
| 131 |
-
input_img_list = sorted(glob.glob(os.path.join(save_dir_full,
|
| 132 |
fps = get_video_fps(video_path)
|
| 133 |
-
else:
|
| 134 |
-
input_img_list = glob.glob(os.path.join(video_path,
|
| 135 |
input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
|
| 136 |
fps = args.fps
|
| 137 |
-
|
| 138 |
-
############################################## extract audio feature ##############################################
|
| 139 |
whisper_feature = audio_processor.audio2feat(audio_path)
|
| 140 |
-
whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
|
| 141 |
-
|
| 142 |
if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
|
| 143 |
-
print("
|
| 144 |
-
with open(crop_coord_save_path,
|
| 145 |
coord_list = pickle.load(f)
|
| 146 |
frame_list = read_imgs(input_img_list)
|
| 147 |
else:
|
| 148 |
-
print("
|
| 149 |
coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
|
| 150 |
-
with open(crop_coord_save_path,
|
| 151 |
pickle.dump(coord_list, f)
|
| 152 |
-
|
| 153 |
-
i = 0
|
| 154 |
input_latent_list = []
|
| 155 |
for bbox, frame in zip(coord_list, frame_list):
|
| 156 |
if bbox == coord_placeholder:
|
| 157 |
continue
|
| 158 |
x1, y1, x2, y2 = bbox
|
| 159 |
crop_frame = frame[y1:y2, x1:x2]
|
| 160 |
-
crop_frame = cv2.resize(crop_frame,(256,256),
|
| 161 |
latents = vae.get_latents_for_unet(crop_frame)
|
| 162 |
input_latent_list.append(latents)
|
| 163 |
|
| 164 |
-
# to smooth the first and the last frame
|
| 165 |
frame_list_cycle = frame_list + frame_list[::-1]
|
| 166 |
coord_list_cycle = coord_list + coord_list[::-1]
|
| 167 |
input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
|
| 168 |
-
|
| 169 |
-
print("
|
| 170 |
video_num = len(whisper_chunks)
|
| 171 |
batch_size = args.batch_size
|
| 172 |
-
gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size)
|
| 173 |
res_frame_list = []
|
| 174 |
-
|
| 175 |
-
|
|
|
|
|
|
|
| 176 |
tensor_list = [torch.FloatTensor(arr) for arr in whisper_batch]
|
| 177 |
-
audio_feature_batch = torch.stack(tensor_list).to(unet.device)
|
| 178 |
audio_feature_batch = pe(audio_feature_batch)
|
| 179 |
-
|
| 180 |
-
|
|
|
|
| 181 |
recon = vae.decode_latents(pred_latents)
|
| 182 |
for res_frame in recon:
|
| 183 |
res_frame_list.append(res_frame)
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
print("pad talking image to original video")
|
| 187 |
for i, res_frame in enumerate(tqdm(res_frame_list)):
|
| 188 |
-
bbox = coord_list_cycle[i%
|
| 189 |
-
ori_frame = copy.deepcopy(frame_list_cycle[i%
|
| 190 |
x1, y1, x2, y2 = bbox
|
| 191 |
try:
|
| 192 |
-
res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
|
| 193 |
-
except:
|
| 194 |
-
# print(bbox)
|
| 195 |
continue
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
|
|
|
|
|
|
| 202 |
os.system(cmd_img2video)
|
| 203 |
|
| 204 |
cmd_combine_audio = f"ffmpeg -y -v fatal -i {audio_path} -i temp.mp4 {output_vid_name}"
|
| 205 |
-
print(cmd_combine_audio)
|
| 206 |
os.system(cmd_combine_audio)
|
| 207 |
|
| 208 |
-
os.
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
return output_vid_name
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
# load model weights
|
| 216 |
-
audio_processor,vae,unet,pe = load_all_model()
|
| 217 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 218 |
-
timesteps = torch.tensor([0], device=device)
|
| 219 |
-
|
| 220 |
|
|
|
|
|
|
|
| 221 |
|
| 222 |
|
| 223 |
def check_video(video):
|
| 224 |
-
# Define the output video file name
|
| 225 |
dir_path, file_name = os.path.split(video)
|
|
|
|
| 226 |
if file_name.startswith("outputxxx_"):
|
| 227 |
return video
|
| 228 |
-
# Add the output prefix to the file name
|
| 229 |
-
output_file_name = "outputxxx_" + file_name
|
| 230 |
|
| 231 |
-
|
| 232 |
output_video = os.path.join(dir_path, output_file_name)
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
# Run the ffmpeg command to change the frame rate to 25fps
|
| 236 |
-
command = f"ffmpeg -i {video} -r 25 {output_video} -y"
|
| 237 |
subprocess.run(command, shell=True, check=True)
|
| 238 |
-
return output_video
|
| 239 |
-
|
| 240 |
|
|
|
|
| 241 |
|
| 242 |
|
| 243 |
css = """#input_img {max-width: 1024px !important} #output_vid {max-width: 1024px; max-height: 576px}"""
|
| 244 |
|
| 245 |
with gr.Blocks(css=css) as demo:
|
| 246 |
gr.Markdown(
|
| 247 |
-
"<
|
| 248 |
-
|
| 249 |
-
</br>\
|
| 250 |
-
Yue Zhang <sup>\*</sup>,\
|
| 251 |
-
Minhao Liu<sup>\*</sup>,\
|
| 252 |
-
Zhaokang Chen,\
|
| 253 |
-
Bin Wu<sup>†</sup>,\
|
| 254 |
-
Yingjie He,\
|
| 255 |
-
Chao Zhan,\
|
| 256 |
-
Wenjiang Zhou\
|
| 257 |
-
(<sup>*</sup>Equal Contribution, <sup>†</sup>Corresponding Author, benbinwu@tencent.com)\
|
| 258 |
-
Lyra Lab, Tencent Music Entertainment\
|
| 259 |
-
</h2> \
|
| 260 |
-
<a style='font-size:18px;color: #000000' href='https://github.com/TMElyralab/MuseTalk'>[Github Repo]</a>\
|
| 261 |
-
<a style='font-size:18px;color: #000000' href='https://github.com/TMElyralab/MuseTalk'>[Huggingface]</a>\
|
| 262 |
-
<a style='font-size:18px;color: #000000' href=''> [Technical report(Coming Soon)] </a>\
|
| 263 |
-
<a style='font-size:18px;color: #000000' href=''> [Project Page(Coming Soon)] </a> </div>"
|
| 264 |
)
|
| 265 |
|
| 266 |
with gr.Row():
|
| 267 |
with gr.Column():
|
| 268 |
-
audio = gr.Audio(label="Driven Audio",type="filepath")
|
| 269 |
video = gr.Video(label="Reference Video")
|
| 270 |
-
bbox_shift = gr.Number(label="
|
| 271 |
btn = gr.Button("Generate")
|
| 272 |
out1 = gr.Video()
|
| 273 |
-
|
| 274 |
-
video.change(
|
| 275 |
-
fn=check_video, inputs=[video], outputs=[video]
|
| 276 |
-
)
|
| 277 |
btn.click(
|
| 278 |
fn=inference,
|
| 279 |
-
inputs=[
|
| 280 |
-
audio,
|
| 281 |
-
video,
|
| 282 |
-
bbox_shift,
|
| 283 |
-
],
|
| 284 |
outputs=out1,
|
| 285 |
)
|
| 286 |
|
| 287 |
-
# Set the IP and port
|
| 288 |
-
ip_address = "0.0.0.0" # Replace with your desired IP address
|
| 289 |
-
port_number = 7860 # Replace with your desired port number
|
| 290 |
-
|
| 291 |
-
|
| 292 |
demo.queue().launch(
|
| 293 |
-
share=False
|
|
|
|
|
|
|
|
|
|
| 294 |
)
|
|
|
|
| 1 |
import os
|
| 2 |
import time
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
import sys
|
| 4 |
import subprocess
|
| 5 |
+
import glob
|
| 6 |
+
import copy
|
| 7 |
+
import pickle
|
| 8 |
+
import shutil
|
| 9 |
|
| 10 |
+
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
import numpy as np
|
| 12 |
import cv2
|
| 13 |
import torch
|
|
|
|
|
|
|
| 14 |
from tqdm import tqdm
|
|
|
|
| 15 |
from argparse import Namespace
|
| 16 |
+
from huggingface_hub import snapshot_download
|
| 17 |
+
import requests
|
| 18 |
+
|
| 19 |
+
# Absolute path of the directory that contains this file.
ProjectDir = os.path.abspath(os.path.dirname(__file__))
# Root directory for model weights; download_model() checks and fills it.
ModelsDir = os.path.join(ProjectDir, "models")
|
| 21 |
|
| 22 |
|
| 23 |
def _download_file(url, dest_path, timeout=60):
    """Stream *url* to *dest_path*, creating the destination only on success.

    The payload is written to a sibling ``.part`` file and renamed into place
    once the transfer completes, so an interrupted download never leaves a
    truncated file that a later ``os.path.exists`` check would accept.

    Args:
        url: HTTP(S) URL to fetch.
        dest_path: Final path of the downloaded file.
        timeout: Timeout in seconds handed to ``requests.get``.

    Raises:
        requests.RequestException: On connection errors, timeouts or a
            non-success HTTP status.
        OSError: If the file cannot be written or renamed.
    """
    partial_path = dest_path + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(partial_path, "wb") as f:
                # Write chunk by chunk instead of holding the whole file in memory.
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    f.write(chunk)
        os.replace(partial_path, dest_path)
    finally:
        # Only still present when the transfer failed part-way.
        if os.path.exists(partial_path):
            os.remove(partial_path)


def download_model():
    """Download model weights if not already present (entrypoint.sh handles this in Docker).

    The core weights (MuseTalk UNet, SD VAE, Whisper, DWPose) and the
    face-parsing weights are checked independently, so a face-parsing download
    that failed on an earlier start is retried even when the core weights are
    already on disk.

    Every download is best-effort: a failure is reported as a warning on
    stdout and the remaining downloads are still attempted.
    """
    core_files = [
        os.path.join(ModelsDir, "musetalkV15", "unet.pth"),
        os.path.join(ModelsDir, "sd-vae", "diffusion_pytorch_model.safetensors"),
        os.path.join(ModelsDir, "whisper", "config.json"),
        os.path.join(ModelsDir, "dwpose", "dw-ll_ucoco_384.pth"),
    ]
    face_parse_dir = os.path.join(ModelsDir, "face-parse-bisent")
    face_parse_path = os.path.join(face_parse_dir, "79999_iter.pth")
    resnet_path = os.path.join(face_parse_dir, "resnet18-5c106cde.pth")

    core_present = all(os.path.exists(f) for f in core_files)
    face_parse_present = os.path.exists(face_parse_path) and os.path.exists(resnet_path)

    if core_present and face_parse_present:
        print("All model files present — skipping download.")
        return

    print("Some model files missing, attempting download...")
    tic = time.time()

    os.makedirs(ModelsDir, exist_ok=True)

    if not core_present:
        # (label used in the warning message, repo-specific snapshot_download arguments)
        snapshots = [
            (
                "MuseTalk model",
                {
                    "repo_id": "TMElyralab/MuseTalk",
                    "local_dir": ModelsDir,
                    "allow_patterns": ["musetalk/*", "musetalkV15/*"],
                },
            ),
            (
                "SD VAE",
                {
                    "repo_id": "stabilityai/sd-vae-ft-mse",
                    "local_dir": os.path.join(ModelsDir, "sd-vae"),
                    "allow_patterns": ["config.json", "diffusion_pytorch_model.*"],
                },
            ),
            (
                "Whisper",
                {
                    "repo_id": "openai/whisper-tiny",
                    "local_dir": os.path.join(ModelsDir, "whisper"),
                    "allow_patterns": ["config.json", "pytorch_model.bin", "preprocessor_config.json"],
                },
            ),
            (
                "DWPose",
                {
                    "repo_id": "yzd-v/DWPose",
                    "local_dir": os.path.join(ModelsDir, "dwpose"),
                    "allow_patterns": ["dw-ll_ucoco_384.pth"],
                },
            ),
        ]
        for label, kwargs in snapshots:
            try:
                snapshot_download(max_workers=8, local_dir_use_symlinks=True, **kwargs)
            except Exception as e:
                # Best-effort: keep going so one unavailable repo does not block the rest.
                print(f"Warning: {label} download failed: {e}")

    os.makedirs(face_parse_dir, exist_ok=True)

    if not os.path.exists(face_parse_path):
        try:
            # Imported lazily so a missing gdown package only affects this download.
            import gdown

            downloaded = gdown.download(
                id="154JgKpzCPW82qINcVieuPH3fZ2e0P812",
                output=face_parse_path,
                quiet=False,
            )
            # NOTE(review): some gdown versions appear to return None instead of
            # raising when the download is refused; confirm against the pinned version.
            if downloaded is None:
                print("Warning: Face parse download failed: gdown returned no file")
        except Exception as e:
            print(f"Warning: Face parse download failed: {e}")

    if not os.path.exists(resnet_path):
        try:
            _download_file(
                "https://download.pytorch.org/models/resnet18-5c106cde.pth",
                resnet_path,
            )
        except Exception as e:
            print(f"Warning: ResNet download failed: {e}")

    toc = time.time()
    print(f"Download completed in {toc - tic:.1f}s")
|
| 114 |
|
|
|
|
| 115 |
|
| 116 |
+
download_model()
|
| 117 |
|
| 118 |
+
from musetalk.utils.utils import get_file_type, get_video_fps, datagen
|
| 119 |
+
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
|
| 120 |
from musetalk.utils.blending import get_image
|
| 121 |
from musetalk.utils.utils import load_all_model
|
| 122 |
|
| 123 |
|
| 124 |
+
# Load the model components once at import time; inference() reads these
# module-level names directly.
audio_processor, vae, unet, pe = load_all_model()
# Use CUDA when torch reports it as available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# One-element timestep tensor (value 0) passed to every UNet call in inference().
timesteps = torch.tensor([0], device=device)
|
| 127 |
|
| 128 |
|
|
|
|
| 129 |
def _run_ffmpeg(arguments):
    """Invoke ffmpeg with an argument list and surface failures to the UI.

    Args:
        arguments: Command-line arguments that follow the ``ffmpeg`` executable.

    Raises:
        gr.Error: If ffmpeg cannot be started or exits with a non-zero status.
    """
    try:
        # An argument list (no shell) keeps paths containing spaces or shell
        # metacharacters intact and rules out command injection via file names.
        subprocess.run(["ffmpeg", *arguments], check=True)
    except (OSError, subprocess.CalledProcessError) as e:
        raise gr.Error(f"ffmpeg failed: {e}") from e


@torch.no_grad()
def inference(audio_path, video_path, bbox_shift, progress=gr.Progress(track_tqdm=True)):
    """Generate a lip-synced video from a driving audio file and a reference video.

    Pipeline: extract the reference frames, compute audio features, locate the
    face box in every frame, encode the face crops to latents, run the UNet in
    batches, paste the decoded faces back into the original frames, and mux
    the result with the audio track.

    Relies on the module-level ``audio_processor``, ``vae``, ``unet``, ``pe``
    and ``timesteps`` objects.

    Args:
        audio_path: Path to the driving audio file.
        video_path: Path to the reference video, or to a directory of frames
            whose file names (without extension) are integers.
        bbox_shift: Forwarded unchanged to ``get_landmark_and_bbox``.
        progress: Gradio progress tracker; the default mirrors tqdm bars in the UI.

    Returns:
        Path of the generated ``.mp4`` under ``./results``.

    Raises:
        gr.Error: If an input is missing, no frames can be read, or an ffmpeg
            step fails.
    """
    # Gradio passes None when the user presses "Generate" without uploading.
    if not audio_path or not video_path:
        raise gr.Error("Please provide both a driving audio file and a reference video.")

    args = Namespace(
        result_dir="./results",
        fps=25,
        batch_size=8,
        output_vid_name="",
        use_saved_coord=False,
    )

    input_basename = os.path.basename(video_path).split(".")[0]
    audio_basename = os.path.basename(audio_path).split(".")[0]
    output_basename = f"{input_basename}_{audio_basename}"
    result_img_save_path = os.path.join(args.result_dir, output_basename)
    crop_coord_save_path = os.path.join(result_img_save_path, input_basename + ".pkl")
    os.makedirs(result_img_save_path, exist_ok=True)
    # Frames left behind by an aborted run with the same names would otherwise
    # be encoded into this run's video.
    for stale_frame in glob.glob(os.path.join(result_img_save_path, "*.png")):
        os.remove(stale_frame)

    if args.output_vid_name == "":
        output_vid_name = os.path.join(args.result_dir, output_basename + ".mp4")
    else:
        output_vid_name = os.path.join(args.result_dir, args.output_vid_name)

    if get_file_type(video_path) == "video":
        save_dir_full = os.path.join(args.result_dir, input_basename)
        # Start from an empty directory: a previous, longer video with the same
        # basename would otherwise contribute its trailing frames to the glob.
        shutil.rmtree(save_dir_full, ignore_errors=True)
        os.makedirs(save_dir_full, exist_ok=True)
        _run_ffmpeg(
            [
                "-v", "fatal",
                "-i", video_path,
                "-start_number", "0",
                os.path.join(save_dir_full, "%08d.png"),
            ]
        )
        input_img_list = sorted(glob.glob(os.path.join(save_dir_full, "*.[jpJP][pnPN]*[gG]")))
        fps = get_video_fps(video_path)
    else:
        input_img_list = glob.glob(os.path.join(video_path, "*.[jpJP][pnPN]*[gG]"))
        input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
        fps = args.fps

    if not input_img_list:
        raise gr.Error("No frames could be read from the reference video.")

    whisper_feature = audio_processor.audio2feat(audio_path)
    whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature, fps=fps)

    if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
        print("Using extracted coordinates")
        # The pickle is written by this function itself (else-branch below),
        # never taken from user uploads.
        with open(crop_coord_save_path, "rb") as f:
            coord_list = pickle.load(f)
        frame_list = read_imgs(input_img_list)
    else:
        print("Extracting landmarks...")
        coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
        with open(crop_coord_save_path, "wb") as f:
            pickle.dump(coord_list, f)

    input_latent_list = []
    for bbox, frame in zip(coord_list, frame_list):
        # Frames whose box equals the placeholder are skipped.
        if bbox == coord_placeholder:
            continue
        x1, y1, x2, y2 = bbox
        crop_frame = frame[y1:y2, x1:x2]
        crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
        latents = vae.get_latents_for_unet(crop_frame)
        input_latent_list.append(latents)

    # Append the reversed sequence so that cycling through the frames plays
    # forward then backward instead of jumping from the last frame to the first.
    frame_list_cycle = frame_list + frame_list[::-1]
    coord_list_cycle = coord_list + coord_list[::-1]
    input_latent_list_cycle = input_latent_list + input_latent_list[::-1]

    print("Starting inference...")
    video_num = len(whisper_chunks)
    batch_size = args.batch_size
    gen = datagen(whisper_chunks, input_latent_list_cycle, batch_size)
    res_frame_list = []

    total_batches = int(np.ceil(float(video_num) / batch_size))
    for whisper_batch, latent_batch in tqdm(gen, total=total_batches):
        tensor_list = [torch.FloatTensor(arr) for arr in whisper_batch]
        audio_feature_batch = torch.stack(tensor_list).to(unet.device)
        audio_feature_batch = pe(audio_feature_batch)
        pred_latents = unet.model(
            latent_batch, timesteps, encoder_hidden_states=audio_feature_batch
        ).sample
        recon = vae.decode_latents(pred_latents)
        for res_frame in recon:
            res_frame_list.append(res_frame)

    print("Compositing frames...")
    for i, res_frame in enumerate(tqdm(res_frame_list)):
        bbox = coord_list_cycle[i % len(coord_list_cycle)]
        ori_frame = copy.deepcopy(frame_list_cycle[i % len(frame_list_cycle)])
        x1, y1, x2, y2 = bbox
        try:
            res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
        except Exception:
            # Resize fails for boxes with no area; such frames are skipped.
            # NOTE(review): a skipped index leaves a gap in the numbered PNG
            # sequence, which presumably truncates the ffmpeg image2 input at
            # that point; confirm whether that is acceptable.
            continue
        combine_frame = get_image(ori_frame, res_frame, bbox)
        cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png", combine_frame)

    # Per-output intermediate file instead of a fixed name in the working
    # directory, so unrelated requests do not overwrite each other's video.
    temp_video_path = os.path.join(args.result_dir, f"{output_basename}_temp.mp4")
    try:
        _run_ffmpeg(
            [
                "-y", "-v", "fatal",
                "-r", str(fps),
                "-f", "image2",
                "-i", os.path.join(result_img_save_path, "%08d.png"),
                "-vcodec", "libx264",
                "-vf", "format=rgb24,scale=out_color_matrix=bt709,format=yuv420p",
                "-crf", "18",
                temp_video_path,
            ]
        )
        _run_ffmpeg(["-y", "-v", "fatal", "-i", audio_path, "-i", temp_video_path, output_vid_name])
    finally:
        # Remove the intermediates whether or not encoding succeeded.
        if os.path.exists(temp_video_path):
            os.remove(temp_video_path)
        shutil.rmtree(result_img_save_path, ignore_errors=True)

    print(f"Result saved to {output_vid_name}")
    return output_vid_name
|
| 239 |
|
| 240 |
|
| 241 |
def check_video(video):
    """Re-encode an uploaded video to 25 fps and return the new file's path.

    The converted file is written next to the input with an ``outputxxx_``
    prefix. A path that already carries the prefix is returned unchanged;
    this stops the ``video.change`` handler from re-converting its own output.

    Args:
        video: Path of the uploaded video, or ``None``/empty when the
            component has been cleared.

    Returns:
        Path of the 25 fps video, or ``video`` itself when it is empty or
        already converted.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
    """
    # The change event also fires when the video is cleared; nothing to convert.
    if not video:
        return video

    dir_path, file_name = os.path.split(video)

    if file_name.startswith("outputxxx_"):
        return video

    output_video = os.path.join(dir_path, "outputxxx_" + file_name)
    # An argument list (no shell) keeps uploaded file names containing spaces
    # or shell metacharacters from breaking or hijacking the command.
    subprocess.run(
        ["ffmpeg", "-y", "-i", video, "-r", "25", output_video],
        check=True,
    )

    return output_video
|
| 253 |
|
| 254 |
|
| 255 |
# Page-level CSS handed to gr.Blocks.
css = """#input_img {max-width: 1024px !important} #output_vid {max-width: 1024px; max-height: 576px}"""

with gr.Blocks(css=css) as demo:
    # Page heading.
    title_markup = (
        "<b>MuseTalk: Real-Time High Quality Lip Synchronization "
        "with Latent Space Inpainting</b>"
    )
    gr.Markdown(title_markup)

    # Inputs and the Generate button in one column, result video beside it.
    with gr.Row():
        with gr.Column():
            audio = gr.Audio(label="Driven Audio", type="filepath")
            video = gr.Video(label="Reference Video")
            bbox_shift = gr.Number(label="BBox shift [-9, 9]", value=-1)
            btn = gr.Button("Generate")
        out1 = gr.Video()

    # Run every newly selected reference video through check_video.
    video.change(
        fn=check_video,
        inputs=[video],
        outputs=[video],
    )
    # Generate: run inference and show the resulting video.
    btn.click(fn=inference, inputs=[audio, video, bbox_shift], outputs=out1)

# Listen on all interfaces, port 7860, with request queueing enabled.
launch_options = {
    "share": False,
    "debug": True,
    "server_name": "0.0.0.0",
    "server_port": 7860,
}
demo.queue().launch(**launch_options)
|