import os
import time
import sys
import subprocess
import glob
import copy
import pickle
import shutil
import tempfile

import gradio as gr
import gradio_client.utils as _gc_utils

# gradio_client's JSON-schema -> python-type conversion can be handed a
# schema that is not a dict; report such schemas as "Any" instead of passing
# them to the original implementation.
_orig_json_schema_to_python_type = _gc_utils._json_schema_to_python_type


def _patched_json_schema_to_python_type(schema, defs=None):
    """Return "Any" for non-dict schemas; delegate everything else."""
    if not isinstance(schema, dict):
        return "Any"
    return _orig_json_schema_to_python_type(schema, defs)


_gc_utils._json_schema_to_python_type = _patched_json_schema_to_python_type

import numpy as np
import cv2
import torch
from tqdm import tqdm
from argparse import Namespace
from huggingface_hub import snapshot_download
import requests

ProjectDir = os.path.abspath(os.path.dirname(__file__))
ModelsDir = os.path.join(ProjectDir, "models")

# Every inference run writes into a fresh directory whose path starts with
# this prefix; /api/download only serves files from such directories.
RESULT_PREFIX = os.path.join(tempfile.gettempdir(), "musetalk_")


def download_model():
    """Download model weights if not already present (entrypoint.sh handles this in Docker).

    Skips everything when all `required_files` exist. Otherwise each source
    is fetched independently and best-effort: a failure is printed as a
    warning and the remaining downloads still run.
    """
    required_files = [
        os.path.join(ModelsDir, "musetalkV15", "unet.pth"),
        os.path.join(ModelsDir, "sd-vae", "diffusion_pytorch_model.safetensors"),
        os.path.join(ModelsDir, "whisper", "config.json"),
        os.path.join(ModelsDir, "dwpose", "dw-ll_ucoco_384.pth"),
    ]
    if all(os.path.exists(f) for f in required_files):
        print("All model files present — skipping download.")
        return

    print("Some model files missing, attempting download...")
    tic = time.time()
    os.makedirs(ModelsDir, exist_ok=True)

    # (label used in the warning message, snapshot_download kwargs)
    hf_downloads = [
        ("MuseTalk model", dict(
            repo_id="TMElyralab/MuseTalk",
            local_dir=ModelsDir,
            allow_patterns=["musetalk/*", "musetalkV15/*"],
        )),
        ("SD VAE", dict(
            repo_id="stabilityai/sd-vae-ft-mse",
            local_dir=os.path.join(ModelsDir, "sd-vae"),
            allow_patterns=["config.json", "diffusion_pytorch_model.*"],
        )),
        ("Whisper", dict(
            repo_id="openai/whisper-tiny",
            local_dir=os.path.join(ModelsDir, "whisper"),
            allow_patterns=["config.json", "pytorch_model.bin", "preprocessor_config.json"],
        )),
        ("DWPose", dict(
            repo_id="yzd-v/DWPose",
            local_dir=os.path.join(ModelsDir, "dwpose"),
            allow_patterns=["dw-ll_ucoco_384.pth"],
        )),
    ]
    for label, kwargs in hf_downloads:
        try:
            snapshot_download(max_workers=8, local_dir_use_symlinks=True, **kwargs)
        except Exception as e:
            print(f"Warning: {label} download failed: {e}")

    face_parse_dir = os.path.join(ModelsDir, "face-parse-bisent")
    os.makedirs(face_parse_dir, exist_ok=True)

    face_parse_path = os.path.join(face_parse_dir, "79999_iter.pth")
    if not os.path.exists(face_parse_path):
        try:
            import gdown  # optional dependency, only needed for this file

            gdown.download(
                id="154JgKpzCPW82qINcVieuPH3fZ2e0P812",
                output=face_parse_path,
                quiet=False,
            )
        except Exception as e:
            print(f"Warning: Face parse download failed: {e}")

    resnet_path = os.path.join(face_parse_dir, "resnet18-5c106cde.pth")
    if not os.path.exists(resnet_path):
        try:
            # A timeout prevents a stalled connection from hanging startup forever.
            response = requests.get(
                "https://download.pytorch.org/models/resnet18-5c106cde.pth",
                timeout=60,
            )
            if response.status_code == 200:
                with open(resnet_path, "wb") as f:
                    f.write(response.content)
            else:
                print(f"Warning: ResNet download failed: HTTP {response.status_code}")
        except Exception as e:
            print(f"Warning: ResNet download failed: {e}")

    toc = time.time()
    print(f"Download completed in {toc - tic:.1f}s")


download_model()

# NOTE(review): these project imports are placed after download_model(),
# presumably because some musetalk modules load weights at import time —
# confirm before moving them to the top of the file.
from transformers import WhisperModel
from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
from musetalk.utils.audio_processor import AudioProcessor

print("Loading models...")


def get_device():
    """Return cuda:0 if it is actually usable, otherwise the CPU device.

    Querying the device name catches setups where CUDA reports itself
    available but device 0 cannot be used.
    """
    if torch.cuda.is_available():
        try:
            torch.cuda.get_device_name(0)
            return torch.device("cuda:0")
        except RuntimeError:
            print("CUDA reported available but device 0 is invalid, falling back to CPU")
    return torch.device("cpu")


device = get_device()
# Half precision only on GPU; CPU inference stays in float32.
weight_dtype = torch.float16 if device.type == "cuda" else torch.float32

# Paths are anchored at ModelsDir (where download_model() writes) rather than
# the current working directory.
vae, unet, pe = load_all_model(
    unet_model_path=os.path.join(ModelsDir, "musetalkV15", "unet.pth"),
    vae_type="sd-vae",
    unet_config=os.path.join(ModelsDir, "musetalkV15", "musetalk.json"),
    device=device,
)
if weight_dtype == torch.float16:
    pe = pe.half()
    vae.vae = vae.vae.half()
    unet.model = unet.model.half()
pe = pe.to(device)
vae.vae = vae.vae.to(device)
unet.model = unet.model.to(device)

timesteps = torch.tensor([0], device=device)

audio_processor = AudioProcessor(feature_extractor_path=os.path.join(ModelsDir, "whisper"))
whisper = WhisperModel.from_pretrained(os.path.join(ModelsDir, "whisper"))
whisper = whisper.to(device=device, dtype=weight_dtype).eval()
whisper.requires_grad_(False)

print(f"Models loaded on {device} ({weight_dtype}).")


def _ffmpeg_succeeds(cmd):
    """Return True if the ffmpeg argv list `cmd` runs and exits with status 0."""
    try:
        return subprocess.run(cmd, capture_output=True).returncode == 0
    except OSError:  # ffmpeg binary missing or not executable
        return False


# Pick the first video encoder this ffmpeg build can actually use, falling
# back to "mpeg4" if none of the probes succeed.
FFMPEG_VCODEC = "mpeg4"
for codec in ["libx264", "libopenh264", "mpeg4"]:
    if _ffmpeg_succeeds(
        ["ffmpeg", "-f", "lavfi", "-i", "nullsrc=s=2x2:d=0.1",
         "-vcodec", codec, "-f", "null", "-"]
    ):
        FFMPEG_VCODEC = codec
        break

try:
    _ffmpeg_version_out = subprocess.run(
        ["ffmpeg", "-version"], capture_output=True, text=True
    ).stdout
except OSError:
    _ffmpeg_version_out = ""
# partition() is safe on empty output, unlike indexing a split/splitlines result.
print(f"ffmpeg version: {_ffmpeg_version_out.partition(chr(10))[0]}")
print(f"Using video codec: {FFMPEG_VCODEC}")


def _run_ffmpeg(cmd, what):
    """Run the ffmpeg argv list `cmd`, logging its exit status.

    Raises:
        gr.Error: if ffmpeg exits non-zero (message includes a stderr excerpt).
    """
    r = subprocess.run(cmd, capture_output=True, text=True)
    print(f"ffmpeg {what} exit={r.returncode}")
    if r.returncode != 0:
        print(f"ffmpeg {what} stderr: {r.stderr[:500]}")
        raise gr.Error(f"ffmpeg {what} failed: {r.stderr[:200]}")
    return r


@torch.no_grad()
def inference(audio_path, video_path, bbox_shift, progress=gr.Progress(track_tqdm=True)):
    """Lip-sync the reference video to the driving audio; return the output path.

    Args:
        audio_path: Path to the driving audio file.
        video_path: Path to a video file, or (when get_file_type() does not
            report "video") a directory of numerically named image frames.
        bbox_shift: Offset forwarded to get_landmark_and_bbox().
        progress: Gradio progress tracker; it mirrors the tqdm loops below.

    Returns:
        Path to the muxed .mp4 inside a fresh RESULT_PREFIX* directory.

    Raises:
        gr.Error: on missing inputs, when no input frames are found, when no
            frames are composited, or when an ffmpeg encode/mux step fails.
    """
    if not audio_path:
        raise gr.Error("Please provide a driving audio file.")
    if not video_path:
        raise gr.Error("Please provide a reference video.")

    result_dir = f"{RESULT_PREFIX}{time.time_ns()}"
    os.makedirs(result_dir, exist_ok=True)
    args = Namespace(
        result_dir=result_dir,
        fps=25,
        batch_size=8,
        output_vid_name="",
        use_saved_coord=False,
    )

    input_basename = os.path.splitext(os.path.basename(video_path))[0]
    audio_basename = os.path.splitext(os.path.basename(audio_path))[0]
    output_basename = f"{input_basename}_{audio_basename}"
    result_img_save_path = os.path.join(result_dir, output_basename)
    crop_coord_save_path = os.path.join(result_img_save_path, input_basename + ".pkl")
    os.makedirs(result_img_save_path, exist_ok=True)

    if args.output_vid_name == "":
        output_vid_name = os.path.join(result_dir, output_basename + ".mp4")
    else:
        output_vid_name = os.path.join(result_dir, args.output_vid_name)

    # --- 1. Collect input frames ------------------------------------------
    save_dir_full = None
    if get_file_type(video_path) == "video":
        save_dir_full = os.path.join(args.result_dir, input_basename)
        os.makedirs(save_dir_full, exist_ok=True)
        # argv list, no shell: uploaded filenames may contain spaces or
        # shell metacharacters.
        extract = subprocess.run(
            ["ffmpeg", "-v", "fatal", "-i", video_path, "-start_number", "0",
             os.path.join(save_dir_full, "%08d.png")],
            capture_output=True,
            text=True,
        )
        if extract.returncode != 0:
            print(f"ffmpeg frame extraction stderr: {extract.stderr[:500]}")
        input_img_list = sorted(glob.glob(os.path.join(save_dir_full, "*.[jpJP][pnPN]*[gG]")))
        fps = get_video_fps(video_path)
    else:
        input_img_list = glob.glob(os.path.join(video_path, "*.[jpJP][pnPN]*[gG]"))
        input_img_list = sorted(
            input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0])
        )
        fps = args.fps

    if not input_img_list:
        raise gr.Error("No frames found in the reference video")

    # --- 2. Audio features ------------------------------------------------
    whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
    whisper_chunks = audio_processor.get_whisper_chunk(
        whisper_input_features, device, weight_dtype, whisper, librosa_length, fps,
    )

    # --- 3. Face boxes and VAE latents ------------------------------------
    if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
        print("Using extracted coordinates")
        # Only ever reads a pickle this function wrote itself into result_dir.
        with open(crop_coord_save_path, "rb") as f:
            coord_list = pickle.load(f)
        frame_list = read_imgs(input_img_list)
    else:
        print("Extracting landmarks...")
        coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
        with open(crop_coord_save_path, "wb") as f:
            pickle.dump(coord_list, f)

    input_latent_list = []
    for bbox, frame in zip(coord_list, frame_list):
        if bbox == coord_placeholder:
            continue
        x1, y1, x2, y2 = bbox
        crop_frame = frame[y1:y2, x1:x2]
        crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
        input_latent_list.append(vae.get_latents_for_unet(crop_frame))

    # Ping-pong cycles (forward then reversed) so audio longer than the video
    # loops back smoothly instead of jumping to the first frame.
    frame_list_cycle = frame_list + frame_list[::-1]
    coord_list_cycle = coord_list + coord_list[::-1]
    input_latent_list_cycle = input_latent_list + input_latent_list[::-1]

    # --- 4. UNet inference ------------------------------------------------
    print("Starting inference...")
    video_num = len(whisper_chunks)
    batch_size = args.batch_size
    gen = datagen(whisper_chunks, input_latent_list_cycle, batch_size, device=device)
    res_frame_list = []
    for whisper_batch, latent_batch in tqdm(gen, total=int(np.ceil(float(video_num) / batch_size))):
        audio_feature_batch = whisper_batch.to(device=unet.device, dtype=weight_dtype)
        audio_feature_batch = pe(audio_feature_batch)
        pred_latents = unet.model(
            latent_batch, timesteps, encoder_hidden_states=audio_feature_batch
        ).sample
        res_frame_list.extend(vae.decode_latents(pred_latents))

    # --- 5. Paste generated faces back into full frames -------------------
    print("Compositing frames...")
    for i, res_frame in enumerate(tqdm(res_frame_list)):
        bbox = coord_list_cycle[i % len(coord_list_cycle)]
        ori_frame = copy.deepcopy(frame_list_cycle[i % len(frame_list_cycle)])
        x1, y1, x2, y2 = bbox
        try:
            res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
        except Exception:
            # Unusable bbox (e.g. no face found): keep the untouched frame
            # rather than skipping it. ffmpeg's image2 demuxer needs a
            # contiguous %08d sequence, and a skipped index would truncate
            # the video and desync it from the audio.
            pass
        else:
            ori_frame[y1:y2, x1:x2] = res_frame
        cv2.imwrite(os.path.join(result_img_save_path, f"{i:08d}.png"), ori_frame)

    frame_count = len(glob.glob(os.path.join(result_img_save_path, "*.png")))
    print(f"Composited {frame_count} frames in {result_img_save_path}")
    if frame_count == 0:
        raise gr.Error("No frames were composited - check face detection / bbox")

    # --- 6. Encode frames, then mux with the driving audio ----------------
    temp_vid = os.path.join(result_dir, "temp.mp4")
    codec_opts = ["-vcodec", FFMPEG_VCODEC, "-pix_fmt", "yuv420p"]
    if FFMPEG_VCODEC == "libx264":
        codec_opts += ["-crf", "18"]
    _run_ffmpeg(
        ["ffmpeg", "-y", "-v", "warning", "-r", str(fps), "-f", "image2",
         "-i", os.path.join(result_img_save_path, "%08d.png"), *codec_opts, temp_vid],
        "img2video",
    )
    _run_ffmpeg(
        ["ffmpeg", "-y", "-v", "warning", "-i", audio_path, "-i", temp_vid, output_vid_name],
        "combine",
    )

    # --- 7. Cleanup intermediates (the output .mp4 stays in result_dir) ---
    if os.path.exists(temp_vid):
        os.remove(temp_vid)
    shutil.rmtree(result_img_save_path, ignore_errors=True)
    # Guard: never remove result_dir itself, which holds the output video.
    if save_dir_full and os.path.normpath(save_dir_full) != os.path.normpath(result_dir):
        shutil.rmtree(save_dir_full, ignore_errors=True)

    exists = os.path.isfile(output_vid_name)
    size_kb = os.path.getsize(output_vid_name) // 1024 if exists else 0
    print(f"Result saved to {output_vid_name} (exists={exists}, size={size_kb}KB)")
    return output_vid_name


def check_video(video):
    """Re-encode an uploaded video to 25 fps, yuv420p, no audio.

    Writes `outputxxx_<name>.mp4` next to the input and returns its path.
    Inputs already carrying the `outputxxx_` prefix are returned unchanged,
    so the `video.change` callback does not re-process its own output.

    Raises:
        subprocess.CalledProcessError: if ffmpeg fails.
    """
    if video is None:
        return None
    dir_path, file_name = os.path.split(video)
    if file_name.startswith("outputxxx_"):
        return video
    base_name, _ext = os.path.splitext(file_name)
    output_video = os.path.join(dir_path, f"outputxxx_{base_name}.mp4")
    # argv list, no shell: upload filenames are user-controlled.
    subprocess.run(
        ["ffmpeg", "-i", video, "-r", "25", "-c:v", FFMPEG_VCODEC,
         "-pix_fmt", "yuv420p", "-an", output_video, "-y"],
        check=True,
    )
    return output_video


css = """#input_img {max-width: 1024px !important} #output_vid {max-width: 1024px; max-height: 576px}"""

with gr.Blocks(css=css) as demo:
    gr.Markdown(
        "MuseTalk: Real-Time High Quality Lip Synchronization "
        "with Latent Space Inpainting"
    )
    with gr.Row():
        with gr.Column():
            audio = gr.Audio(label="Driven Audio", type="filepath")
            video = gr.Video(label="Reference Video", format="mp4")
            bbox_shift = gr.Number(label="BBox shift [-9, 9]", value=-1)
            btn = gr.Button("Generate")
        out1 = gr.Video()

    video.change(fn=check_video, inputs=[video], outputs=[video])
    btn.click(
        fn=inference,
        inputs=[audio, video, bbox_shift],
        outputs=out1,
    )

print(f"GRADIO_TEMP_DIR={os.environ.get('GRADIO_TEMP_DIR', 'NOT SET')}")
print(f"tempfile.gettempdir()={tempfile.gettempdir()}")
print(f"Results will be saved in {RESULT_PREFIX}* dirs")

demo.queue()

from fastapi import FastAPI, Query
from fastapi.responses import FileResponse, JSONResponse

app = FastAPI()


def _is_allowed_result_path(path):
    """Return True if `path` resolves to a location inside a RESULT_PREFIX* dir.

    The path is resolved with realpath first, so `..` segments and symlinks
    cannot escape the results directories. A plain startswith() check on
    the raw string is bypassable with `/tmp/musetalk_x/../../etc/passwd`.
    """
    if not os.path.isabs(path):
        return False
    tmp_root = os.path.realpath(os.path.dirname(RESULT_PREFIX))
    real = os.path.realpath(path)
    if os.path.commonpath([tmp_root, real]) != tmp_root:
        return False
    top_component = os.path.relpath(real, tmp_root).split(os.sep)[0]
    return top_component.startswith(os.path.basename(RESULT_PREFIX))


@app.get("/api/download")
async def download_result(path: str = Query(...)):
    """Serve a generated video from a RESULT_PREFIX* directory.

    Returns 403 for paths outside those directories and 404 (with a short
    listing of existing result files) when the file does not exist.
    """
    if not _is_allowed_result_path(path):
        return JSONResponse({"error": "forbidden path"}, status_code=403)
    real_path = os.path.realpath(path)
    if not os.path.isfile(real_path):
        # NOTE(review): this debug listing reveals other users' result paths
        # to any caller; consider removing it in production.
        existing = glob.glob(glob.escape(RESULT_PREFIX) + "*")
        files_in_dirs = [
            os.path.join(d, f)
            for d in existing
            if os.path.isdir(d)
            for f in os.listdir(d)
        ]
        return JSONResponse({
            "error": "file not found",
            "requested": path,
            "musetalk_dirs": existing[:10],
            "files_in_dirs": files_in_dirs[:20],
        }, status_code=404)
    return FileResponse(real_path, media_type="video/mp4", filename=os.path.basename(real_path))


# Gradio UI is mounted at "/", alongside the custom FastAPI route above.
app = gr.mount_gradio_app(app, demo, path="/")
print("Custom /api/download endpoint registered")

import uvicorn

uvicorn.run(app, host="0.0.0.0", port=7860)