# MuseTalk / app.py
# gonefishin1's picture
# Auto-detect available ffmpeg codec (libx264/libopenh264/mpeg4 fallback)
# 4c2f9e8 verified
import os
import time
import sys
import subprocess
import glob
import copy
import pickle
import shutil
import tempfile
import gradio as gr
import gradio_client.utils as _gc_utils
# Monkey-patch gradio_client's private schema converter so that a non-dict
# schema no longer reaches the original implementation.
# NOTE(review): presumably a workaround for boolean JSON-schema nodes (e.g.
# `additionalProperties: true`) crashing gradio_client — confirm against the
# installed gradio version before removing.
_orig_json_schema_to_python_type = _gc_utils._json_schema_to_python_type


def _patched_json_schema_to_python_type(schema, defs=None):
    """Delegate dict schemas to gradio_client; describe anything else as "Any"."""
    if isinstance(schema, dict):
        return _orig_json_schema_to_python_type(schema, defs)
    return "Any"


_gc_utils._json_schema_to_python_type = _patched_json_schema_to_python_type
import numpy as np
import cv2
import torch
from tqdm import tqdm
from argparse import Namespace
from huggingface_hub import snapshot_download
import requests
# Absolute directory containing this file, and the model-weights folder below it.
ProjectDir = os.path.dirname(os.path.abspath(__file__))
ModelsDir = os.path.join(ProjectDir, "models")
def download_model():
    """Download model weights if not already present (entrypoint.sh handles this in Docker).

    The function returns immediately when the four sentinel files listed in
    ``required_files`` exist.  Otherwise every source is attempted
    independently; a failure of one source is logged as a warning and does not
    prevent the remaining downloads.  Nothing is raised to the caller — files
    that are still missing at the end are reported on stdout.
    """
    required_files = [
        os.path.join(ModelsDir, "musetalkV15", "unet.pth"),
        os.path.join(ModelsDir, "sd-vae", "diffusion_pytorch_model.safetensors"),
        os.path.join(ModelsDir, "whisper", "config.json"),
        os.path.join(ModelsDir, "dwpose", "dw-ll_ucoco_384.pth"),
    ]
    if all(os.path.exists(path) for path in required_files):
        print("All model files present — skipping download.")
        return

    print("Some model files missing, attempting download...")
    tic = time.time()
    os.makedirs(ModelsDir, exist_ok=True)

    # (label used in the warning, hub repo id, destination dir, file patterns)
    hub_snapshots = [
        (
            "MuseTalk model",
            "TMElyralab/MuseTalk",
            ModelsDir,
            ["musetalk/*", "musetalkV15/*"],
        ),
        (
            "SD VAE",
            "stabilityai/sd-vae-ft-mse",
            os.path.join(ModelsDir, "sd-vae"),
            ["config.json", "diffusion_pytorch_model.*"],
        ),
        (
            "Whisper",
            "openai/whisper-tiny",
            os.path.join(ModelsDir, "whisper"),
            ["config.json", "pytorch_model.bin", "preprocessor_config.json"],
        ),
        (
            "DWPose",
            "yzd-v/DWPose",
            os.path.join(ModelsDir, "dwpose"),
            ["dw-ll_ucoco_384.pth"],
        ),
    ]
    for label, repo_id, local_dir, patterns in hub_snapshots:
        try:
            snapshot_download(
                repo_id=repo_id,
                local_dir=local_dir,
                max_workers=8,
                local_dir_use_symlinks=True,
                allow_patterns=patterns,
            )
        except Exception as e:  # best effort: keep going with the other sources
            print(f"Warning: {label} download failed: {e}")

    face_parse_dir = os.path.join(ModelsDir, "face-parse-bisent")
    os.makedirs(face_parse_dir, exist_ok=True)

    face_parse_path = os.path.join(face_parse_dir, "79999_iter.pth")
    if not os.path.exists(face_parse_path):
        try:
            # Imported lazily so a missing gdown only costs this one download.
            import gdown

            gdown.download(
                id="154JgKpzCPW82qINcVieuPH3fZ2e0P812",
                output=face_parse_path,
                quiet=False,
            )
        except Exception as e:
            print(f"Warning: Face parse download failed: {e}")

    resnet_path = os.path.join(face_parse_dir, "resnet18-5c106cde.pth")
    if not os.path.exists(resnet_path):
        # Stream into a side file and rename on success, so an interrupted
        # transfer never leaves a truncated file at the final path (which
        # would make the existence check above skip the download next time).
        partial_path = resnet_path + ".part"
        try:
            with requests.get(
                "https://download.pytorch.org/models/resnet18-5c106cde.pth",
                stream=True,
                timeout=60,
            ) as response:
                # Surface HTTP errors instead of silently skipping the file.
                response.raise_for_status()
                with open(partial_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=1 << 20):
                        f.write(chunk)
            os.replace(partial_path, resnet_path)
        except Exception as e:
            print(f"Warning: ResNet download failed: {e}")
            if os.path.exists(partial_path):
                os.remove(partial_path)

    still_missing = [path for path in required_files if not os.path.exists(path)]
    if still_missing:
        print(f"Warning: model files still missing after download: {still_missing}")

    toc = time.time()
    print(f"Download completed in {toc - tic:.1f}s")
# Runs at import time, before the model-loading code further down reads the
# weights from ./models.
download_model()
from transformers import WhisperModel
from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
from musetalk.utils.audio_processor import AudioProcessor
print("Loading models...")


def get_device():
    """Pick the torch device to run on.

    Returns ``cuda:0`` when CUDA is reported available and device 0 can be
    queried; otherwise the CPU device.
    """
    cpu = torch.device("cpu")
    if not torch.cuda.is_available():
        return cpu
    try:
        # Querying the name is used purely as a probe that device 0 is usable.
        torch.cuda.get_device_name(0)
    except RuntimeError:
        print("CUDA reported available but device 0 is invalid, falling back to CPU")
        return cpu
    return torch.device("cuda:0")
# Half precision is used on CUDA only; everything stays float32 on CPU.
device = get_device()
weight_dtype = torch.float32 if device.type != "cuda" else torch.float16

vae, unet, pe = load_all_model(
    unet_model_path="./models/musetalkV15/unet.pth",
    vae_type="sd-vae",
    unet_config="./models/musetalkV15/musetalk.json",
    device=device,
)

# Cast first (when running in fp16), then move all three modules to the device.
if weight_dtype == torch.float16:
    pe, vae.vae, unet.model = pe.half(), vae.vae.half(), unet.model.half()
pe, vae.vae, unet.model = pe.to(device), vae.vae.to(device), unet.model.to(device)

# Constant timestep tensor passed to the UNet on every call.
timesteps = torch.tensor([0], device=device)

# Whisper is used for inference only: eval mode, gradients disabled.
audio_processor = AudioProcessor(feature_extractor_path="./models/whisper")
whisper = (
    WhisperModel.from_pretrained("./models/whisper")
    .to(device=device, dtype=weight_dtype)
    .eval()
)
whisper.requires_grad_(False)

print(f"Models loaded on {device} ({weight_dtype}).")
def _ffmpeg_can_encode(codec_name):
    """Return True when a tiny synthetic encode with *codec_name* exits with 0."""
    probe = subprocess.run(
        f"ffmpeg -f lavfi -i nullsrc=s=2x2:d=0.1 -vcodec {codec_name} -f null -",
        shell=True, capture_output=True,
    )
    return probe.returncode == 0


# First encoder (in preference order) that this ffmpeg build can run;
# "mpeg4" is also the fallback when none of the probes succeeds.
FFMPEG_VCODEC = next(
    (name for name in ("libx264", "libopenh264", "mpeg4") if _ffmpeg_can_encode(name)),
    "mpeg4",
)

ffmpeg_version = subprocess.run("ffmpeg -version", shell=True, capture_output=True, text=True)
# Only the first line of `ffmpeg -version` is logged.
print(f"ffmpeg version: {ffmpeg_version.stdout.partition(chr(10))[0]}")
print(f"Using video codec: {FFMPEG_VCODEC}")
def _run_ffmpeg(arguments):
    """Run ffmpeg with an argument list (no shell) and return the CompletedProcess."""
    try:
        return subprocess.run(["ffmpeg", *arguments], capture_output=True, text=True)
    except FileNotFoundError as exc:
        raise gr.Error("ffmpeg executable not found on PATH") from exc


@torch.no_grad()
def inference(audio_path, video_path, bbox_shift, progress=gr.Progress(track_tqdm=True)):
    """Lip-sync the reference video to the driving audio and return the result path.

    Args:
        audio_path: Path of the driving audio file.
        video_path: Path of the reference video, or of a directory of
            integer-named image frames.
        bbox_shift: Value forwarded unchanged to ``get_landmark_and_bbox``.
        progress: Gradio progress tracker (declared so tqdm bars are mirrored
            in the UI; not used directly).

    Returns:
        Path of the generated ``.mp4`` inside a fresh ``musetalk_*`` directory
        under the system temp dir.

    Raises:
        gr.Error: when an input is missing, no frames can be read, no face is
            found in any frame, nothing could be composited, or ffmpeg fails.
    """
    if not audio_path or not video_path:
        raise gr.Error("Both a driving audio file and a reference video are required")

    result_dir = os.path.join(tempfile.gettempdir(), f"musetalk_{time.time_ns()}")
    os.makedirs(result_dir, exist_ok=True)
    args = Namespace(
        result_dir=result_dir,
        fps=25,
        batch_size=8,
        output_vid_name="",
        use_saved_coord=False,
    )

    input_basename = os.path.basename(video_path).split(".")[0]
    audio_basename = os.path.basename(audio_path).split(".")[0]
    output_basename = f"{input_basename}_{audio_basename}"
    result_img_save_path = os.path.join(result_dir, output_basename)
    crop_coord_save_path = os.path.join(result_img_save_path, input_basename + ".pkl")
    os.makedirs(result_img_save_path, exist_ok=True)
    output_vid_name = os.path.join(
        result_dir, args.output_vid_name or output_basename + ".mp4"
    )

    extracted_dir = None  # set only when frames are extracted from a video file
    temp_vid = os.path.join(result_dir, "temp.mp4")
    frame_pattern = "*.[jpJP][pnPN]*[gG]"
    try:
        # ---- 1. Collect the reference frames -------------------------------
        if get_file_type(video_path) == "video":
            extracted_dir = os.path.join(args.result_dir, input_basename)
            os.makedirs(extracted_dir, exist_ok=True)
            # Argument list instead of a shell string: upload paths may hold
            # spaces or shell metacharacters.
            extraction = _run_ffmpeg([
                "-v", "fatal",
                "-i", video_path,
                "-start_number", "0",
                os.path.join(extracted_dir, "%08d.png"),
            ])
            if extraction.returncode != 0:
                print(f"ffmpeg frame extraction stderr: {extraction.stderr[:500]}")
            input_img_list = sorted(
                glob.glob(os.path.join(glob.escape(extracted_dir), frame_pattern))
            )
            fps = get_video_fps(video_path)
        else:
            input_img_list = sorted(
                glob.glob(os.path.join(glob.escape(video_path), frame_pattern)),
                key=lambda p: int(os.path.splitext(os.path.basename(p))[0]),
            )
            fps = args.fps
        if not input_img_list:
            raise gr.Error("No frames could be read from the reference video")

        # ---- 2. Audio features ---------------------------------------------
        whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
        whisper_chunks = audio_processor.get_whisper_chunk(
            whisper_input_features,
            device,
            weight_dtype,
            whisper,
            librosa_length,
            fps,
        )

        # ---- 3. Face boxes and VAE latents ---------------------------------
        if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
            print("Using extracted coordinates")
            # NOTE(review): pickle.load executes arbitrary code on crafted
            # input. The file is written by this function into its own temp
            # dir, and use_saved_coord is hard-coded False, so this branch is
            # currently unreachable — re-check before enabling it.
            with open(crop_coord_save_path, "rb") as f:
                coord_list = pickle.load(f)
            frame_list = read_imgs(input_img_list)
        else:
            print("Extracting landmarks...")
            coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
            with open(crop_coord_save_path, "wb") as f:
                pickle.dump(coord_list, f)

        # Keep boxes, frames and latents index-aligned: frames marked with the
        # placeholder box are dropped from all three lists. Previously only the
        # latent list skipped them, so latent i could be pasted onto an
        # unrelated frame i.
        face_coords, face_frames, input_latent_list = [], [], []
        for bbox, frame in zip(coord_list, frame_list):
            if bbox == coord_placeholder:
                continue
            x1, y1, x2, y2 = bbox
            crop_frame = cv2.resize(
                frame[y1:y2, x1:x2], (256, 256), interpolation=cv2.INTER_LANCZOS4
            )
            input_latent_list.append(vae.get_latents_for_unet(crop_frame))
            face_coords.append(bbox)
            face_frames.append(frame)
        if not input_latent_list:
            raise gr.Error("No face was detected in the reference video")

        # Forward-then-backward cycle so the reference can loop when the audio
        # needs more frames than the reference provides.
        frame_list_cycle = face_frames + face_frames[::-1]
        coord_list_cycle = face_coords + face_coords[::-1]
        input_latent_list_cycle = input_latent_list + input_latent_list[::-1]

        # ---- 4. UNet inference ---------------------------------------------
        print("Starting inference...")
        video_num = len(whisper_chunks)
        batch_size = args.batch_size
        gen = datagen(whisper_chunks, input_latent_list_cycle, batch_size, device=device)
        res_frame_list = []
        for whisper_batch, latent_batch in tqdm(
            gen, total=int(np.ceil(float(video_num) / batch_size))
        ):
            audio_feature_batch = whisper_batch.to(device=unet.device, dtype=weight_dtype)
            audio_feature_batch = pe(audio_feature_batch)
            pred_latents = unet.model(
                latent_batch, timesteps, encoder_hidden_states=audio_feature_batch
            ).sample
            res_frame_list.extend(vae.decode_latents(pred_latents))

        # ---- 5. Paste generated crops back into the frames -----------------
        print("Compositing frames...")
        frame_count = 0
        for i, res_frame in enumerate(tqdm(res_frame_list)):
            x1, y1, x2, y2 = coord_list_cycle[i % len(coord_list_cycle)]
            ori_frame = copy.deepcopy(frame_list_cycle[i % len(frame_list_cycle)])
            try:
                res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
            except Exception as exc:
                # Best effort: drop the frame, but say so instead of hiding it.
                print(f"Skipping frame {i}: resize failed ({exc})")
                continue
            ori_frame[y1:y2, x1:x2] = res_frame
            # Number by frames actually written, so a skipped frame leaves no
            # gap in the %08d sequence that ffmpeg reads below.
            target = os.path.join(result_img_save_path, f"{str(frame_count).zfill(8)}.png")
            if cv2.imwrite(target, ori_frame):
                frame_count += 1

        print(f"Composited {frame_count} frames in {result_img_save_path}")
        if frame_count == 0:
            raise gr.Error("No frames were composited - check face detection / bbox")

        # ---- 6. Encode the frames, then mux with the audio -----------------
        encode_args = [
            "-y", "-v", "warning",
            "-r", str(fps),
            "-f", "image2",
            "-i", os.path.join(result_img_save_path, "%08d.png"),
            "-vcodec", FFMPEG_VCODEC,
            "-pix_fmt", "yuv420p",
        ]
        if FFMPEG_VCODEC == "libx264":
            encode_args += ["-crf", "18"]
        encode_args.append(temp_vid)
        r1 = _run_ffmpeg(encode_args)
        print(f"ffmpeg img2video exit={r1.returncode}")
        if r1.returncode != 0:
            print(f"ffmpeg img2video stderr: {r1.stderr[:500]}")
            raise gr.Error(f"ffmpeg img2video failed: {r1.stderr[:200]}")

        r2 = _run_ffmpeg(
            ["-y", "-v", "warning", "-i", audio_path, "-i", temp_vid, output_vid_name]
        )
        print(f"ffmpeg combine exit={r2.returncode}")
        if r2.returncode != 0:
            print(f"ffmpeg combine stderr: {r2.stderr[:500]}")
            raise gr.Error(f"ffmpeg combine failed: {r2.stderr[:200]}")
    finally:
        # Remove intermediates on success and on failure; the extracted input
        # frames used to be left behind on every call.
        for leftover_dir in (result_img_save_path, extracted_dir):
            if leftover_dir:
                shutil.rmtree(leftover_dir, ignore_errors=True)
        try:
            if os.path.exists(temp_vid):
                os.remove(temp_vid)
        except OSError:
            pass  # cleanup must not mask the real error

    exists = os.path.isfile(output_vid_name)
    size_kb = os.path.getsize(output_vid_name) // 1024 if exists else 0
    print(f"Result saved to {output_vid_name} (exists={exists}, size={size_kb}KB)")
    return output_vid_name
def check_video(video):
    """Re-encode an uploaded video to 25 fps without audio and return the new path.

    The converted file is written next to the input as
    ``outputxxx_<name>.mp4``.  A file whose name already carries the
    ``outputxxx_`` prefix is returned untouched; this is what stops the
    ``video.change`` handler (which writes the result back into the same
    component) from re-converting its own output.

    Args:
        video: Path of the uploaded video, or ``None`` when the input is cleared.

    Returns:
        Path of the converted video, the input path if already converted, or
        ``None`` when ``video`` is ``None``.

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero status.
    """
    if video is None:
        return None

    dir_path, file_name = os.path.split(video)
    if file_name.startswith("outputxxx_"):
        return video

    base_name, _ext = os.path.splitext(file_name)
    output_video = os.path.join(dir_path, f"outputxxx_{base_name}.mp4")

    # Argument list (no shell): uploaded file names may contain spaces or
    # shell metacharacters, which broke or subverted the former shell string.
    command = [
        "ffmpeg", "-y",
        "-i", video,
        "-r", "25",
        "-c:v", FFMPEG_VCODEC,
        "-pix_fmt", "yuv420p",
        "-an",
        output_video,
    ]
    subprocess.run(command, check=True)
    return output_video
# Size limits for the elements with ids #input_img / #output_vid.
css = (
    "#input_img {max-width: 1024px !important} "
    "#output_vid {max-width: 1024px; max-height: 576px}"
)

# NOTE(review): the layout nesting below was reconstructed (inputs and button
# in the column, output video beside it in the row) — confirm against the
# intended UI.
with gr.Blocks(css=css) as demo:
    gr.Markdown(
        "<b>MuseTalk: Real-Time High Quality Lip Synchronization with Latent Space Inpainting</b>"
    )
    with gr.Row():
        with gr.Column():
            audio = gr.Audio(label="Driven Audio", type="filepath")
            video = gr.Video(label="Reference Video", format="mp4")
            bbox_shift = gr.Number(label="BBox shift [-9, 9]", value=-1)
            btn = gr.Button("Generate")
        out1 = gr.Video()

    # Every change of the reference video is routed through check_video, whose
    # result is written back into the same component.
    video.change(fn=check_video, inputs=[video], outputs=[video])
    btn.click(fn=inference, inputs=[audio, video, bbox_shift], outputs=out1)

# Startup diagnostics about where temporary files end up.
print(f"GRADIO_TEMP_DIR={os.environ.get('GRADIO_TEMP_DIR', 'NOT SET')}")
print(f"tempfile.gettempdir()={tempfile.gettempdir()}")
print("Results will be saved in /tmp/musetalk_* dirs")

demo.queue()
from fastapi import FastAPI, Query
from fastapi.responses import FileResponse, JSONResponse
app = FastAPI()


def _resolve_result_path(path):
    """Return the canonical form of *path* if it lies inside a ``musetalk_*``
    entry of an allowed temp root, otherwise ``None``."""
    if not os.path.isabs(path):
        return None
    try:
        # realpath collapses ".." segments and follows symlinks, so the prefix
        # test below is made on the file that would actually be served.
        resolved = os.path.realpath(path)
    except ValueError:  # e.g. embedded NUL byte
        return None
    # "/tmp" is kept for compatibility with the original check; gettempdir()
    # is where inference() actually writes its results.
    allowed_roots = {os.path.realpath("/tmp"), os.path.realpath(tempfile.gettempdir())}
    for root in allowed_roots:
        if resolved.startswith(os.path.join(root, "musetalk_")):
            return resolved
    return None


@app.get("/api/download")
async def download_result(path: str = Query(...)):
    """Serve a generated video from one of the ``musetalk_*`` temp directories.

    Responses:
        200 — the file, sent as ``video/mp4``.
        403 — ``path`` does not resolve to a location inside a ``musetalk_*``
              directory of the temp root (including ``..`` or symlink escapes).
        404 — the path is allowed but no such file exists; the body lists
              existing result directories and files as a debugging aid.
    """
    resolved = _resolve_result_path(path)
    if resolved is None:
        return JSONResponse({"error": "forbidden path"}, status_code=403)

    if not os.path.isfile(resolved):
        # NOTE(review): this listing exposes other requests' result paths to
        # any caller — consider removing it once debugging is finished.
        existing = glob.glob(
            os.path.join(glob.escape(tempfile.gettempdir()), "musetalk_*")
        )
        files_in_dirs = []
        for d in existing:
            if os.path.isdir(d):
                files_in_dirs.extend(os.path.join(d, f) for f in os.listdir(d))
        return JSONResponse({
            "error": "file not found",
            "requested": path,
            "musetalk_dirs": existing[:10],
            "files_in_dirs": files_in_dirs[:20],
        }, status_code=404)

    return FileResponse(
        resolved, media_type="video/mp4", filename=os.path.basename(resolved)
    )
# Mount the Gradio UI at "/" on the FastAPI app; the /api/download route was
# registered on `app` before this call.
# NOTE(review): presumably the registration order matters for route matching — confirm.
app = gr.mount_gradio_app(app, demo, path="/")
print("Custom /api/download endpoint registered")
import uvicorn
# Start the server at import time: all interfaces, port 7860.
uvicorn.run(app, host="0.0.0.0", port=7860)