Spaces:
Running
Running
| import os | |
| import time | |
| import sys | |
| import subprocess | |
| import glob | |
| import copy | |
| import pickle | |
| import shutil | |
| import tempfile | |
| import gradio as gr | |
| import gradio_client.utils as _gc_utils | |
# Keep a handle on gradio_client's original schema-to-type renderer so the
# wrapper below can delegate to it for ordinary (dict) schemas.
_orig_json_schema_to_python_type = _gc_utils._json_schema_to_python_type


def _patched_json_schema_to_python_type(schema, defs=None):
    """Render *schema* as a Python type string, tolerating non-dict schemas.

    JSON Schema allows bare booleans (``true``/``false``) as schemas. Any
    non-dict node is reported as ``"Any"`` instead of being handed to the
    original renderer, which presumably assumes a dict (the reason this patch
    exists).

    Args:
        schema: A JSON-schema node (normally a dict).
        defs: Optional ``$defs`` mapping, forwarded unchanged.

    Returns:
        The type string produced by the original renderer, or ``"Any"``.
    """
    if isinstance(schema, dict):
        return _orig_json_schema_to_python_type(schema, defs)
    return "Any"


# Install the wrapper on the module so code that looks this function up on
# gradio_client.utils gets the tolerant version.
_gc_utils._json_schema_to_python_type = _patched_json_schema_to_python_type
| import numpy as np | |
| import cv2 | |
| import torch | |
| from tqdm import tqdm | |
| from argparse import Namespace | |
| from huggingface_hub import snapshot_download | |
| import requests | |
# Absolute directory that contains this script; every model weight used by
# the app lives in its "models" subdirectory.
ProjectDir = os.path.dirname(os.path.abspath(__file__))
ModelsDir = os.path.join(ProjectDir, "models")
def _snapshot_best_effort(repo_id, local_dir, allow_patterns, label):
    """Run ``snapshot_download`` for *repo_id*; log a warning instead of raising."""
    try:
        snapshot_download(
            repo_id=repo_id,
            local_dir=local_dir,
            max_workers=8,
            local_dir_use_symlinks=True,
            allow_patterns=allow_patterns,
        )
    except Exception as e:
        print(f"Warning: {label} download failed: {e}")


def _fetch_url_best_effort(url, dest, label, timeout=60):
    """Stream *url* into *dest*; log a warning instead of raising on failure.

    The body is written to ``dest + ".part"`` and renamed into place only once
    it is complete. An interrupted download therefore never leaves a truncated
    file that a later ``os.path.exists`` check would accept as finished.
    HTTP error statuses are reported rather than silently ignored. *timeout*
    bounds each socket operation, so a stalled server cannot hang startup.
    """
    part_path = dest + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(part_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        os.replace(part_path, dest)
    except Exception as e:
        print(f"Warning: {label} download failed: {e}")
        try:
            os.remove(part_path)
        except OSError:
            pass


def download_model():
    """Download model weights if not already present (entrypoint.sh handles this in Docker).

    If the four core weight files listed below all exist, nothing is
    downloaded. Otherwise each component is fetched best-effort into
    ``ModelsDir``. Every failure is logged as a warning and the remaining
    downloads still run, so one unavailable source does not block the others.
    The face-parsing weights are checked and fetched individually.
    """
    required_files = [
        os.path.join(ModelsDir, "musetalkV15", "unet.pth"),
        os.path.join(ModelsDir, "sd-vae", "diffusion_pytorch_model.safetensors"),
        os.path.join(ModelsDir, "whisper", "config.json"),
        os.path.join(ModelsDir, "dwpose", "dw-ll_ucoco_384.pth"),
    ]
    if all(os.path.exists(f) for f in required_files):
        print("All model files present — skipping download.")
        return

    print("Some model files missing, attempting download...")
    tic = time.time()
    os.makedirs(ModelsDir, exist_ok=True)

    _snapshot_best_effort(
        "TMElyralab/MuseTalk", ModelsDir,
        ["musetalk/*", "musetalkV15/*"], "MuseTalk model",
    )
    _snapshot_best_effort(
        "stabilityai/sd-vae-ft-mse", os.path.join(ModelsDir, "sd-vae"),
        ["config.json", "diffusion_pytorch_model.*"], "SD VAE",
    )
    _snapshot_best_effort(
        "openai/whisper-tiny", os.path.join(ModelsDir, "whisper"),
        ["config.json", "pytorch_model.bin", "preprocessor_config.json"], "Whisper",
    )
    _snapshot_best_effort(
        "yzd-v/DWPose", os.path.join(ModelsDir, "dwpose"),
        ["dw-ll_ucoco_384.pth"], "DWPose",
    )

    face_parse_dir = os.path.join(ModelsDir, "face-parse-bisent")
    os.makedirs(face_parse_dir, exist_ok=True)

    face_parse_path = os.path.join(face_parse_dir, "79999_iter.pth")
    if not os.path.exists(face_parse_path):
        try:
            import gdown
            gdown.download(
                id="154JgKpzCPW82qINcVieuPH3fZ2e0P812",
                output=face_parse_path,
                quiet=False,
            )
        except Exception as e:
            print(f"Warning: Face parse download failed: {e}")

    resnet_path = os.path.join(face_parse_dir, "resnet18-5c106cde.pth")
    if not os.path.exists(resnet_path):
        _fetch_url_best_effort(
            "https://download.pytorch.org/models/resnet18-5c106cde.pth",
            resnet_path,
            "ResNet",
        )

    toc = time.time()
    print(f"Download completed in {toc - tic:.1f}s")


download_model()
| from transformers import WhisperModel | |
| from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model | |
| from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder | |
| from musetalk.utils.audio_processor import AudioProcessor | |
print("Loading models...")


def get_device():
    """Choose the torch device used for inference.

    Returns ``cuda:0`` when CUDA reports itself available and device 0 can be
    queried by name. Otherwise returns the CPU device. The case where CUDA
    claims availability but querying device 0 raises ``RuntimeError`` is
    logged before falling back.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")
    try:
        torch.cuda.get_device_name(0)
    except RuntimeError:
        print("CUDA reported available but device 0 is invalid, falling back to CPU")
        return torch.device("cpu")
    return torch.device("cuda:0")
device = get_device()
# Half precision is used only on CUDA; the CPU path stays in float32.
weight_dtype = torch.float16 if device.type == "cuda" else torch.float32

# Load from ModelsDir (where download_model() writes) instead of "./models",
# which only resolved correctly when the process was started from the
# project directory.
musetalk_dir = os.path.join(ModelsDir, "musetalkV15")
whisper_dir = os.path.join(ModelsDir, "whisper")

vae, unet, pe = load_all_model(
    unet_model_path=os.path.join(musetalk_dir, "unet.pth"),
    # NOTE(review): the VAE location is derived from vae_type inside
    # load_all_model and may still be cwd-relative; confirm.
    vae_type="sd-vae",
    unet_config=os.path.join(musetalk_dir, "musetalk.json"),
    device=device,
)
if weight_dtype == torch.float16:
    pe = pe.half()
    vae.vae = vae.vae.half()
    unet.model = unet.model.half()
pe = pe.to(device)
vae.vae = vae.vae.to(device)
unet.model = unet.model.to(device)

# Constant timestep tensor passed to every UNet call during inference.
timesteps = torch.tensor([0], device=device)

audio_processor = AudioProcessor(feature_extractor_path=whisper_dir)
whisper = WhisperModel.from_pretrained(whisper_dir)
whisper = whisper.to(device=device, dtype=weight_dtype).eval()
whisper.requires_grad_(False)
print(f"Models loaded on {device} ({weight_dtype}).")
def _ffmpeg_encoder_works(codec):
    """Return True if ffmpeg can encode a tiny synthetic clip with *codec*."""
    probe = subprocess.run(
        f"ffmpeg -f lavfi -i nullsrc=s=2x2:d=0.1 -vcodec {codec} -f null -",
        shell=True,
        capture_output=True,
    )
    return probe.returncode == 0


# First encoder (in preference order) that passes the probe. "mpeg4" is the
# fallback even when its own probe fails, e.g. when ffmpeg is missing.
FFMPEG_VCODEC = next(
    (c for c in ("libx264", "libopenh264", "mpeg4") if _ffmpeg_encoder_works(c)),
    "mpeg4",
)

ffmpeg_version = subprocess.run("ffmpeg -version", shell=True, capture_output=True, text=True)
_ffmpeg_version_line = ffmpeg_version.stdout.partition("\n")[0]
print(f"ffmpeg version: {_ffmpeg_version_line}")
print(f"Using video codec: {FFMPEG_VCODEC}")
def _run_ffmpeg(ffmpeg_args, label):
    """Run ``ffmpeg`` with an argument list and raise ``gr.Error`` on failure.

    Arguments are passed without a shell. Paths containing spaces or shell
    metacharacters (uploaded file names are user-controlled) are therefore
    handed to ffmpeg verbatim.
    """
    try:
        result = subprocess.run(["ffmpeg", *ffmpeg_args], capture_output=True, text=True)
    except FileNotFoundError as e:
        raise gr.Error("ffmpeg executable not found") from e
    print(f"ffmpeg {label} exit={result.returncode}")
    if result.returncode != 0:
        print(f"ffmpeg {label} stderr: {result.stderr[:500]}")
        raise gr.Error(f"ffmpeg {label} failed: {result.stderr[:200]}")
    return result


def _composite_frames(res_frames, coords, frames, out_dir):
    """Paste each generated face crop into its source frame and save it as PNG.

    Output files form a contiguous ``00000000.png, 00000001.png, ...``
    sequence. When a crop cannot be resized to its bbox (e.g. a placeholder or
    degenerate bbox), the untouched source frame is written instead. The
    original code skipped the index, and ffmpeg's image2 input stops at the
    first gap in the numbering, truncating the video.

    Returns:
        The number of PNG files written.
    """
    kept_unmodified = 0
    for i, res_frame in enumerate(tqdm(res_frames)):
        x1, y1, x2, y2 = coords[i % len(coords)]
        ori_frame = copy.deepcopy(frames[i % len(frames)])
        try:
            patch = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
        except Exception:
            kept_unmodified += 1
        else:
            ori_frame[y1:y2, x1:x2] = patch
        cv2.imwrite(os.path.join(out_dir, f"{i:08d}.png"), ori_frame)
    if kept_unmodified:
        print(f"Kept {kept_unmodified} frame(s) unmodified (face crop could not be resized)")
    return len(res_frames)


def inference(audio_path, video_path, bbox_shift, progress=gr.Progress(track_tqdm=True)):
    """Lip-sync the reference video to the driving audio and return the result MP4 path.

    Args:
        audio_path: Path of the driving audio file.
        video_path: Reference video file, or a directory of numerically named
            frame images (25 fps is assumed for a directory).
        bbox_shift: Shift forwarded to ``get_landmark_and_bbox``; ``None``
            (a cleared number field) is treated as 0.
        progress: Gradio progress tracker; it mirrors the tqdm bars below.

    Returns:
        Path of ``<video>_<audio>.mp4`` inside a fresh ``musetalk_*`` directory
        under the system temp dir. Intermediate frames and the temporary video
        are removed whether or not generation succeeds.

    Raises:
        gr.Error: missing inputs, no readable frames, no detected face, or an
            ffmpeg failure.
    """
    if not audio_path or not video_path:
        raise gr.Error("Please provide both a driving audio file and a reference video.")
    if bbox_shift is None:
        bbox_shift = 0
    default_fps = 25
    batch_size = 8

    # mkdtemp guarantees a unique per-request directory and keeps the
    # "musetalk_" prefix used elsewhere in this file.
    result_dir = tempfile.mkdtemp(prefix="musetalk_")
    input_basename = os.path.basename(video_path).split(".")[0]
    audio_basename = os.path.basename(audio_path).split(".")[0]
    output_basename = f"{input_basename}_{audio_basename}"
    result_img_save_path = os.path.join(result_dir, output_basename)
    os.makedirs(result_img_save_path, exist_ok=True)
    output_vid_name = os.path.join(result_dir, output_basename + ".mp4")
    temp_vid = os.path.join(result_dir, "temp.mp4")
    save_dir_full = None

    try:
        if get_file_type(video_path) == "video":
            save_dir_full = tempfile.mkdtemp(prefix="frames_", dir=result_dir)
            _run_ffmpeg(
                ["-v", "fatal", "-i", video_path, "-start_number", "0",
                 os.path.join(save_dir_full, "%08d.png")],
                "extract",
            )
            input_img_list = sorted(glob.glob(os.path.join(save_dir_full, "*.[jpJP][pnPN]*[gG]")))
            fps = get_video_fps(video_path)
        else:
            input_img_list = glob.glob(os.path.join(video_path, "*.[jpJP][pnPN]*[gG]"))
            input_img_list = sorted(
                input_img_list,
                key=lambda x: int(os.path.splitext(os.path.basename(x))[0]),
            )
            fps = default_fps
        if not input_img_list:
            raise gr.Error("No frames could be read from the reference video")

        with torch.no_grad():
            whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
            whisper_chunks = audio_processor.get_whisper_chunk(
                whisper_input_features,
                device,
                weight_dtype,
                whisper,
                librosa_length,
                fps,
            )

        print("Extracting landmarks...")
        coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)

        input_latent_list = []
        with torch.no_grad():
            for bbox, frame in zip(coord_list, frame_list):
                if bbox == coord_placeholder:
                    continue
                x1, y1, x2, y2 = bbox
                crop_frame = cv2.resize(
                    frame[y1:y2, x1:x2], (256, 256), interpolation=cv2.INTER_LANCZOS4
                )
                input_latent_list.append(vae.get_latents_for_unet(crop_frame))
        if not input_latent_list:
            raise gr.Error("No face was detected in the reference video")

        # NOTE(review): latents skip frames with a placeholder bbox while the
        # frame/coord cycles below do not, so indices drift apart when some
        # frames have no detected face. Behavior kept as-is; TODO confirm.
        frame_list_cycle = frame_list + frame_list[::-1]
        coord_list_cycle = coord_list + coord_list[::-1]
        input_latent_list_cycle = input_latent_list + input_latent_list[::-1]

        print("Starting inference...")
        video_num = len(whisper_chunks)
        gen = datagen(whisper_chunks, input_latent_list_cycle, batch_size, device=device)
        res_frame_list = []
        with torch.no_grad():
            for whisper_batch, latent_batch in tqdm(
                gen, total=int(np.ceil(float(video_num) / batch_size))
            ):
                audio_feature_batch = pe(whisper_batch.to(device=unet.device, dtype=weight_dtype))
                pred_latents = unet.model(
                    latent_batch, timesteps, encoder_hidden_states=audio_feature_batch
                ).sample
                res_frame_list.extend(vae.decode_latents(pred_latents))

        print("Compositing frames...")
        frame_count = _composite_frames(
            res_frame_list, coord_list_cycle, frame_list_cycle, result_img_save_path
        )
        print(f"Composited {frame_count} frames in {result_img_save_path}")
        if frame_count == 0:
            raise gr.Error("No frames were composited - check face detection / bbox")

        img2video_args = [
            "-y", "-v", "warning", "-r", str(fps), "-f", "image2",
            "-i", os.path.join(result_img_save_path, "%08d.png"),
            "-vcodec", FFMPEG_VCODEC, "-pix_fmt", "yuv420p",
        ]
        if FFMPEG_VCODEC == "libx264":
            img2video_args += ["-crf", "18"]
        _run_ffmpeg([*img2video_args, temp_vid], "img2video")

        # Map streams explicitly (video from the rendered clip, audio from the
        # driving file) and copy the already-encoded video instead of letting
        # ffmpeg re-encode it with its default encoder.
        _run_ffmpeg(
            ["-y", "-v", "warning", "-i", audio_path, "-i", temp_vid,
             "-map", "1:v:0", "-map", "0:a:0", "-c:v", "copy", output_vid_name],
            "combine",
        )
    finally:
        if os.path.exists(temp_vid):
            os.remove(temp_vid)
        shutil.rmtree(result_img_save_path, ignore_errors=True)
        if save_dir_full is not None:
            shutil.rmtree(save_dir_full, ignore_errors=True)

    exists = os.path.isfile(output_vid_name)
    size_kb = os.path.getsize(output_vid_name) // 1024 if exists else 0
    print(f"Result saved to {output_vid_name} (exists={exists}, size={size_kb}KB)")
    return output_vid_name
def check_video(video):
    """Re-encode an uploaded reference video to 25 fps, yuv420p, without audio.

    Args:
        video: Path of the uploaded video, or None/empty when there is none.

    Returns:
        None for no input. The input unchanged if its file name already
        carries the ``outputxxx_`` prefix, i.e. it is a copy produced here.
        Otherwise the path of ``outputxxx_<name>.mp4`` written next to the
        input.

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits with an error.
    """
    if not video:
        return None
    dir_path, file_name = os.path.split(video)
    # The returned path can be fed back into this handler; don't re-encode
    # a copy this function already produced.
    if file_name.startswith("outputxxx_"):
        return video
    base_name, _ext = os.path.splitext(file_name)
    output_video = os.path.join(dir_path, f"outputxxx_{base_name}.mp4")
    # Argument list instead of a shell string: upload names may contain
    # spaces or shell metacharacters.
    subprocess.run(
        [
            "ffmpeg", "-y", "-i", video, "-r", "25", "-c:v", FFMPEG_VCODEC,
            "-pix_fmt", "yuv420p", "-an", output_video,
        ],
        check=True,
    )
    return output_video
css = """#input_img {max-width: 1024px !important} #output_vid {max-width: 1024px; max-height: 576px}"""

with gr.Blocks(css=css) as demo:
    gr.Markdown(
        "<b>MuseTalk: Real-Time High Quality Lip Synchronization "
        "with Latent Space Inpainting</b>"
    )
    with gr.Row():
        with gr.Column():
            audio = gr.Audio(label="Driven Audio", type="filepath")
            video = gr.Video(label="Reference Video", format="mp4")
            bbox_shift = gr.Number(label="BBox shift [-9, 9]", value=-1)
            btn = gr.Button("Generate")
        # elem_id connects this player to the #output_vid rule in `css`;
        # without it that rule matched no element.
        out1 = gr.Video(elem_id="output_vid")
    # Normalize the uploaded reference video (fps, codec, no audio) on change.
    video.change(fn=check_video, inputs=[video], outputs=[video])
    btn.click(
        fn=inference,
        inputs=[audio, video, bbox_shift],
        outputs=out1,
    )

print(f"GRADIO_TEMP_DIR={os.environ.get('GRADIO_TEMP_DIR', 'NOT SET')}")
print(f"tempfile.gettempdir()={tempfile.gettempdir()}")
# Report the real temp root rather than a hardcoded /tmp.
print(f"Results will be saved in {os.path.join(tempfile.gettempdir(), 'musetalk_*')} dirs")
demo.queue()
| from fastapi import FastAPI, Query | |
| from fastapi.responses import FileResponse, JSONResponse | |
app = FastAPI()


@app.get("/api/download")
async def download_result(path: str = Query(...)):
    """Serve a finished result video from a ``musetalk_*`` temp directory.

    Args:
        path: Path of the file to download (query parameter).

    Returns:
        A ``FileResponse`` streaming the file as ``video/mp4``. Returns a 403
        JSON error when *path* does not resolve to somewhere inside a
        ``musetalk_*`` directory directly under the system temp dir, and a 404
        JSON error when no such file exists.
    """
    temp_root = os.path.realpath(tempfile.gettempdir())
    # Resolve ".." and symlinks before checking. A prefix test on the raw
    # string accepted e.g. /tmp/musetalk_x/../../etc/passwd.
    resolved = os.path.realpath(path)
    try:
        inside_temp = os.path.commonpath([resolved, temp_root]) == temp_root
    except ValueError:  # e.g. paths on different drives
        inside_temp = False
    parts = os.path.relpath(resolved, temp_root).split(os.sep) if inside_temp else []
    if len(parts) < 2 or not parts[0].startswith("musetalk_"):
        return JSONResponse({"error": "forbidden path"}, status_code=403)
    if not os.path.isfile(resolved):
        # No directory listings here: they let any client enumerate other
        # users' results.
        return JSONResponse({"error": "file not found", "requested": path}, status_code=404)
    return FileResponse(resolved, media_type="video/mp4", filename=os.path.basename(resolved))
# Mount the Gradio UI at "/" after the custom routes, so routes already
# declared on `app` are matched first.
app = gr.mount_gradio_app(app, demo, path="/")
print("Custom /api/download endpoint registered")

if __name__ == "__main__":
    # Start the server only when run as a script. Importing the module
    # (e.g. `uvicorn module:app`) then does not block inside uvicorn.run.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)