Spaces:
Configuration error
Configuration error
| """ | |
| MuseTalk HTTP API Server | |
| Keeps models loaded in GPU memory for fast inference. | |
| """ | |
| import os | |
| import cv2 | |
| import copy | |
| import torch | |
| import glob | |
| import shutil | |
| import pickle | |
| import numpy as np | |
| import subprocess | |
| import tempfile | |
| import hashlib | |
| import time | |
| from pathlib import Path | |
| from typing import Optional | |
| from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks | |
| from fastapi.responses import FileResponse, JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from tqdm import tqdm | |
| from omegaconf import OmegaConf | |
| from transformers import WhisperModel | |
| import uvicorn | |
| # MuseTalk imports | |
| from musetalk.utils.blending import get_image | |
| from musetalk.utils.face_parsing import FaceParsing | |
| from musetalk.utils.audio_processor import AudioProcessor | |
| from musetalk.utils.utils import get_file_type, datagen, load_all_model | |
| from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder | |
class MuseTalkServer:
    """Keep the MuseTalk models resident in memory and run lip-sync jobs.

    The module is meant to hold a single shared instance: ``load_models()``
    is called once, after which ``generate()`` can be called repeatedly
    without paying the model-loading cost again.
    """

    def __init__(self):
        """Set default configuration and create the on-disk cache folders."""
        # Model handles; all of these are populated by load_models().
        self.device = None
        self.vae = None
        self.unet = None
        self.pe = None
        self.whisper = None
        self.audio_processor = None
        self.fp = None
        self.timesteps = None
        self.weight_dtype = None
        self.is_loaded = False

        # Cache directories (relative to the current working directory).
        self.cache_dir = Path("./cache")
        self.cache_dir.mkdir(exist_ok=True)
        self.landmarks_cache = self.cache_dir / "landmarks"
        self.latents_cache = self.cache_dir / "latents"
        self.whisper_cache = self.cache_dir / "whisper_features"
        self.landmarks_cache.mkdir(exist_ok=True)
        self.latents_cache.mkdir(exist_ok=True)
        self.whisper_cache.mkdir(exist_ok=True)

        # Inference configuration.
        self.fps = 25
        self.batch_size = 8
        self.use_float16 = True
        self.version = "v15"
        self.extra_margin = 10
        self.parsing_mode = "jaw"
        self.left_cheek_width = 90
        self.right_cheek_width = 90
        self.audio_padding_left = 2
        self.audio_padding_right = 2

    def load_models(
        self,
        gpu_id: int = 0,
        unet_model_path: str = "./models/musetalkV15/unet.pth",
        unet_config: str = "./models/musetalk/config.json",
        vae_type: str = "sd-vae",
        whisper_dir: str = "./models/whisper",
        use_float16: bool = True,
        version: str = "v15"
    ):
        """Load all models onto the selected device.

        Does nothing (apart from printing a notice) when the models are
        already loaded.

        Args:
            gpu_id: CUDA device index; the CPU is used when CUDA is unavailable.
            unet_model_path: Path to the UNet weights.
            unet_config: Path to the UNet configuration file.
            vae_type: VAE identifier forwarded to ``load_all_model``.
            whisper_dir: Directory holding the Whisper model and its
                feature extractor.
            use_float16: Convert PE, VAE and UNet to half precision.
            version: Model version; ``"v15"`` builds the face parser with
                the configured cheek widths.
        """
        if self.is_loaded:
            print("Models already loaded!")
            return

        print("=" * 50)
        print("Loading MuseTalk models into GPU memory...")
        print("=" * 50)
        start_time = time.time()

        # Set device
        self.device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Load model weights
        print("Loading VAE, UNet, PE...")
        self.vae, self.unet, self.pe = load_all_model(
            unet_model_path=unet_model_path,
            vae_type=vae_type,
            unet_config=unet_config,
            device=self.device
        )
        self.timesteps = torch.tensor([0], device=self.device)

        # Convert to float16 if enabled
        self.use_float16 = use_float16
        if use_float16:
            print("Converting to float16...")
            self.pe = self.pe.half()
            self.vae.vae = self.vae.vae.half()
            self.unet.model = self.unet.model.half()

        # Move to device
        self.pe = self.pe.to(self.device)
        self.vae.vae = self.vae.vae.to(self.device)
        self.unet.model = self.unet.model.to(self.device)

        # Initialize audio processor and Whisper. Whisper is cast to the
        # UNet's dtype so both models run at the same precision.
        print("Loading Whisper model...")
        self.audio_processor = AudioProcessor(feature_extractor_path=whisper_dir)
        self.weight_dtype = self.unet.model.dtype
        self.whisper = WhisperModel.from_pretrained(whisper_dir)
        self.whisper = self.whisper.to(device=self.device, dtype=self.weight_dtype).eval()
        self.whisper.requires_grad_(False)

        # Initialize face parser
        self.version = version
        if version == "v15":
            self.fp = FaceParsing(
                left_cheek_width=self.left_cheek_width,
                right_cheek_width=self.right_cheek_width
            )
        else:
            self.fp = FaceParsing()

        self.is_loaded = True
        load_time = time.time() - start_time
        print(f"Models loaded in {load_time:.2f}s")
        print("=" * 50)
        print("Server ready for inference!")
        print("=" * 50)

    def _get_file_hash(self, file_path: str) -> str:
        """Return the first 16 hex digits of the file's MD5 digest.

        The digest is used only as a cache key, not for security.
        """
        hash_md5 = hashlib.md5()
        with open(file_path, "rb") as f:
            # Read in 1 MiB chunks; the digest does not depend on chunk size.
            for chunk in iter(lambda: f.read(1 << 20), b""):
                hash_md5.update(chunk)
        return hash_md5.hexdigest()[:16]

    # NOTE(review): the three cache readers below are disabled and always
    # return None, while the matching writers still write a pickle on every
    # cached run. Re-enable the readers or stop writing once the tensor
    # comparison issue is resolved. When re-enabling, remember that
    # unpickling is only safe for files this server wrote itself.

    def _get_cached_landmarks(self, video_hash: str, bbox_shift: int):
        """Get cached landmarks if available (currently always ``None``)."""
        # Disabled due to tensor comparison issues
        return None

    def _save_landmarks_cache(self, video_hash: str, bbox_shift: int, coord_list, frame_list):
        """Pickle ``(coord_list, frame_list)`` into the landmarks cache."""
        cache_file = self.landmarks_cache / f"{video_hash}_shift{bbox_shift}.pkl"
        with open(cache_file, 'wb') as f:
            pickle.dump((coord_list, frame_list), f)

    def _get_cached_latents(self, video_hash: str):
        """Get cached VAE latents if available (currently always ``None``)."""
        # Disabled due to tensor comparison issues
        return None

    def _save_latents_cache(self, video_hash: str, latent_list):
        """Pickle the list of VAE latents into the latents cache."""
        cache_file = self.latents_cache / f"{video_hash}.pkl"
        with open(cache_file, 'wb') as f:
            pickle.dump(latent_list, f)

    def _get_cached_whisper(self, audio_hash: str):
        """Get cached Whisper features if available (currently always ``None``)."""
        # Disabled due to tensor comparison issues
        return None

    def _save_whisper_cache(self, audio_hash: str, whisper_data):
        """Pickle the Whisper feature chunks into the whisper cache."""
        cache_file = self.whisper_cache / f"{audio_hash}.pkl"
        with open(cache_file, 'wb') as f:
            pickle.dump(whisper_data, f)

    @staticmethod
    def _run_ffmpeg(args):
        """Run ffmpeg with an argument list, raising if it exits non-zero.

        Passing a list (no shell) keeps paths containing spaces or shell
        metacharacters from being interpreted by a shell.
        """
        subprocess.run(["ffmpeg", *args], check=True)

    def generate(
        self,
        video_path: str,
        audio_path: str,
        output_path: str,
        fps: Optional[int] = None,
        use_cache: bool = True
    ) -> dict:
        """Generate a lip-synced video and write it to ``output_path``.

        Args:
            video_path: Source video or single image.
            audio_path: Driving audio file.
            output_path: Destination of the muxed result.
            fps: Frame rate for extraction and encoding; falls back to
                ``self.fps`` when ``None`` or 0.
            use_cache: Consult and write the on-disk caches.

        Returns:
            Dict of per-stage wall-clock timings in seconds, the
            ``*_source`` markers ("cache" or "computed"), ``total`` and
            ``frames_generated``.

        Raises:
            RuntimeError: If the models have not been loaded.
            ValueError: If ``video_path`` is neither a video nor an image.
            subprocess.CalledProcessError: If an ffmpeg step fails.
            FileNotFoundError: If input files or the ffmpeg binary are missing.
        """
        if not self.is_loaded:
            raise RuntimeError("Models not loaded! Call load_models() first.")

        fps = fps or self.fps
        timings = {"total": 0}
        total_start = time.time()

        # Content hashes are used as cache keys.
        video_hash = self._get_file_hash(video_path)
        audio_hash = self._get_file_hash(audio_path)

        # Scratch space for extracted and blended frames; always removed.
        temp_dir = tempfile.mkdtemp()
        try:
            # 1. Extract frames
            t0 = time.time()
            save_dir_full = os.path.join(temp_dir, "frames")
            os.makedirs(save_dir_full, exist_ok=True)
            file_type = get_file_type(video_path)
            if file_type == "video":
                self._run_ffmpeg([
                    "-v", "fatal",
                    "-i", video_path,
                    "-vf", f"fps={fps}",
                    "-start_number", "0",
                    os.path.join(save_dir_full, "%08d.png"),
                ])
                input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
            elif file_type == "image":
                input_img_list = [video_path]
            else:
                raise ValueError(f"Unsupported video type: {video_path}")
            timings["frame_extraction"] = time.time() - t0

            # 2. Extract audio features (with caching)
            t0 = time.time()
            cached_whisper = self._get_cached_whisper(audio_hash) if use_cache else None
            # Compare against None: truthiness of a tensor/array is ambiguous.
            if cached_whisper is not None:
                whisper_chunks = cached_whisper
                timings["whisper_source"] = "cache"
            else:
                whisper_input_features, librosa_length = self.audio_processor.get_audio_feature(audio_path)
                whisper_chunks = self.audio_processor.get_whisper_chunk(
                    whisper_input_features,
                    self.device,
                    self.weight_dtype,
                    self.whisper,
                    librosa_length,
                    fps=fps,
                    audio_padding_length_left=self.audio_padding_left,
                    audio_padding_length_right=self.audio_padding_right,
                )
                if use_cache:
                    self._save_whisper_cache(audio_hash, whisper_chunks)
                timings["whisper_source"] = "computed"
            timings["whisper_features"] = time.time() - t0

            # 3. Get landmarks (with caching)
            t0 = time.time()
            # Both model versions use a zero shift here.
            bbox_shift = 0
            cache_key = f"{video_hash}_{fps}"
            cached_landmarks = self._get_cached_landmarks(cache_key, bbox_shift) if use_cache else None
            if cached_landmarks is not None:
                coord_list, frame_list = cached_landmarks
                timings["landmarks_source"] = "cache"
            else:
                coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
                if use_cache:
                    self._save_landmarks_cache(cache_key, bbox_shift, coord_list, frame_list)
                timings["landmarks_source"] = "computed"
            timings["landmarks"] = time.time() - t0

            # 4. Compute VAE latents (with caching)
            t0 = time.time()
            latent_cache_key = f"{video_hash}_{fps}_{self.version}"
            cached_latents = self._get_cached_latents(latent_cache_key) if use_cache else None
            if cached_latents is not None:
                input_latent_list = cached_latents
                timings["latents_source"] = "cache"
            else:
                input_latent_list = []
                for bbox, frame in zip(coord_list, frame_list):
                    # Skip frames whose bbox is the placeholder value
                    # (presumably "no face found" - confirm in preprocessing).
                    if isinstance(bbox, (list, tuple)) and list(bbox) == list(coord_placeholder):
                        continue
                    x1, y1, x2, y2 = bbox
                    if self.version == "v15":
                        # Extend the crop downwards, clamped to the frame height.
                        y2 = y2 + self.extra_margin
                        y2 = min(y2, frame.shape[0])
                    crop_frame = frame[y1:y2, x1:x2]
                    crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
                    latents = self.vae.get_latents_for_unet(crop_frame)
                    input_latent_list.append(latents)
                if use_cache:
                    self._save_latents_cache(latent_cache_key, input_latent_list)
                timings["latents_source"] = "computed"
            timings["vae_encoding"] = time.time() - t0

            # 5. Prepare cycled lists: forward then reversed, indexed modulo
            #    their length below so audio longer than the clip keeps going.
            # NOTE(review): placeholder frames are skipped for the latents
            # but kept in coord/frame lists, so indices can drift apart when
            # placeholders occur.
            frame_list_cycle = frame_list + frame_list[::-1]
            coord_list_cycle = coord_list + coord_list[::-1]
            input_latent_list_cycle = input_latent_list + input_latent_list[::-1]

            # 6. UNet inference
            t0 = time.time()
            gen = datagen(
                whisper_chunks=whisper_chunks,
                vae_encode_latents=input_latent_list_cycle,
                batch_size=self.batch_size,
                delay_frame=0,
                device=self.device,
            )
            res_frame_list = []
            # Inference only: disable autograd so no graph is recorded.
            with torch.no_grad():
                for whisper_batch, latent_batch in gen:
                    audio_feature_batch = self.pe(whisper_batch)
                    latent_batch = latent_batch.to(dtype=self.unet.model.dtype)
                    pred_latents = self.unet.model(
                        latent_batch, self.timesteps,
                        encoder_hidden_states=audio_feature_batch
                    ).sample
                    recon = self.vae.decode_latents(pred_latents)
                    res_frame_list.extend(recon)
            timings["unet_inference"] = time.time() - t0

            # 7. Face blending
            t0 = time.time()
            result_img_path = os.path.join(temp_dir, "results")
            os.makedirs(result_img_path, exist_ok=True)
            for i, res_frame in enumerate(res_frame_list):
                bbox = coord_list_cycle[i % len(coord_list_cycle)]
                ori_frame = copy.deepcopy(frame_list_cycle[i % len(frame_list_cycle)])
                x1, y1, x2, y2 = bbox
                if self.version == "v15":
                    y2 = y2 + self.extra_margin
                    y2 = min(y2, ori_frame.shape[0])
                try:
                    res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
                except Exception:
                    # Best effort: drop frames that cannot be resized
                    # (e.g. a degenerate bbox) instead of failing the job.
                    continue
                if self.version == "v15":
                    combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2],
                                              mode=self.parsing_mode, fp=self.fp)
                else:
                    combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=self.fp)
                cv2.imwrite(f"{result_img_path}/{str(i).zfill(8)}.png", combine_frame)
            timings["face_blending"] = time.time() - t0

            # 8. Encode video, then mux it with the driving audio.
            t0 = time.time()
            temp_vid = os.path.join(temp_dir, "temp.mp4")
            self._run_ffmpeg([
                "-y", "-v", "warning",
                "-r", str(fps),
                "-f", "image2",
                "-i", os.path.join(result_img_path, "%08d.png"),
                "-vcodec", "libx264",
                "-vf", "format=yuv420p",
                "-crf", "18",
                temp_vid,
            ])
            self._run_ffmpeg([
                "-y", "-v", "warning",
                "-i", audio_path,
                "-i", temp_vid,
                output_path,
            ])
            timings["video_encoding"] = time.time() - t0
        finally:
            # Cleanup
            shutil.rmtree(temp_dir, ignore_errors=True)

        timings["total"] = time.time() - total_start
        timings["frames_generated"] = len(res_frame_list)
        return timings
# Shared model holder used by every request handler in this module.
server = MuseTalkServer()

# HTTP application object.
app = FastAPI(title="MuseTalk API", description="HTTP API for MuseTalk lip-sync generation", version="1.0.0")

# Cross-origin access: every origin, method and header is accepted.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive - confirm this is intended before exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# NOTE(review): no decorator registering this hook on `app` is visible in
# this view of the file - confirm it is wired up.
async def startup_event():
    """Make sure the models are loaded before requests are handled."""
    # load_models() returns early when the models are already loaded.
    server.load_models()
# NOTE(review): no route decorator is visible on this handler in this view of
# the file - confirm it is registered on `app`.
async def health_check():
    """Report readiness: load state of the models and the device in use."""
    status = "ok" if server.is_loaded else "loading"
    device_name = str(server.device) if server.device else None
    return {
        "status": status,
        "models_loaded": server.is_loaded,
        "device": device_name,
    }
# NOTE(review): no route decorator is visible on this handler in this view of
# the file - confirm it is registered on `app`.
async def cache_stats():
    """Return how many ``*.pkl`` entries each cache folder holds."""

    def _count_entries(folder):
        # Count matches without materialising the whole listing.
        return sum(1 for _ in folder.glob("*.pkl"))

    return {
        "landmarks_cached": _count_entries(server.landmarks_cache),
        "latents_cached": _count_entries(server.latents_cache),
        "whisper_features_cached": _count_entries(server.whisper_cache),
    }
# NOTE(review): no route decorator is visible on this handler in this view of
# the file - confirm it is registered on `app`.
async def clear_cache():
    """Delete every ``*.pkl`` file from the three cache folders."""
    folders = (
        server.landmarks_cache,
        server.latents_cache,
        server.whisper_cache,
    )
    for folder in folders:
        for entry in folder.glob("*.pkl"):
            entry.unlink()
    return {"status": "cleared"}
class GenerateRequest(BaseModel):
    """Request body for generating a video from files already on the server."""

    # Server-side locations of the inputs and of the file to write.
    video_path: str
    audio_path: str
    output_path: str
    # Frame rate handed to the generator; defaults to 25.
    fps: Optional[int] = 25
    # Whether the generator should use its on-disk caches.
    use_cache: bool = True
# NOTE(review): no route decorator is visible on this handler in this view of
# the file - confirm it is registered on `app`.
async def generate_from_paths(request: GenerateRequest):
    """
    Run lip-sync generation on files that already live on the server.

    Responds with 503 while the models are not loaded, 404 when either input
    path does not exist, and 500 (carrying the error text) when generation
    raises. On success returns the status, the output path and the timings.
    """
    if not server.is_loaded:
        raise HTTPException(status_code=503, detail="Models not loaded yet")

    # Check both inputs up front, video first, so the 404 names the culprit.
    required_inputs = (("Video", request.video_path), ("Audio", request.audio_path))
    for label, location in required_inputs:
        if not os.path.exists(location):
            raise HTTPException(status_code=404, detail=f"{label} not found: {location}")

    try:
        timings = server.generate(
            video_path=request.video_path,
            audio_path=request.audio_path,
            output_path=request.output_path,
            fps=request.fps,
            use_cache=request.use_cache,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    return {
        "status": "success",
        "output_path": request.output_path,
        "timings": timings,
    }
def _upload_target_name(filename, stem):
    """Build a safe on-disk name for an upload: fixed stem + original extension.

    Only the extension of the client-supplied name is kept, and only when it
    is purely alphanumeric. This prevents directory traversal or absolute
    paths in the filename, keeps the video and audio uploads from colliding
    when both carry the same name, and avoids shell-special characters.
    """
    ext = os.path.splitext(os.path.basename(filename or ""))[1]
    if not ext[1:].isalnum():
        ext = ""
    return stem + ext


# NOTE(review): no route decorator is visible on this handler in this view of
# the file - confirm it is registered on `app`.
async def generate_from_upload(
    video: UploadFile = File(...),
    audio: UploadFile = File(...),
    fps: int = Form(25),
    use_cache: bool = Form(True)
):
    """
    Generate lip-synced video from uploaded files.

    Returns the generated video file, with the stage timings in the
    ``X-Timings`` response header. Responds with 503 while the models are not
    loaded and with 500 (carrying the error text) when anything fails.
    """
    if not server.is_loaded:
        raise HTTPException(status_code=503, detail="Models not loaded yet")

    # Save uploaded files into a private scratch directory.
    temp_dir = tempfile.mkdtemp()
    try:
        # Never trust the client-supplied file names for the on-disk path.
        video_path = os.path.join(temp_dir, _upload_target_name(video.filename, "video"))
        audio_path = os.path.join(temp_dir, _upload_target_name(audio.filename, "audio"))
        output_path = os.path.join(temp_dir, "output.mp4")

        with open(video_path, "wb") as f:
            f.write(await video.read())
        with open(audio_path, "wb") as f:
            f.write(await audio.read())

        timings = server.generate(
            video_path=video_path,
            audio_path=audio_path,
            output_path=output_path,
            fps=fps,
            use_cache=use_cache
        )

        # Fail here (and clean up below) rather than while streaming.
        if not os.path.isfile(output_path):
            raise RuntimeError("Generation finished but no output file was produced")

        # The response streams from temp_dir, so remove the directory only
        # after the file has been sent.
        cleanup = BackgroundTasks()
        cleanup.add_task(shutil.rmtree, temp_dir, ignore_errors=True)

        # Return the video file
        return FileResponse(
            output_path,
            media_type="video/mp4",
            filename="result.mp4",
            headers={"X-Timings": str(timings)},
            background=cleanup,
        )
    except Exception as e:
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    import argparse

    # Command-line options for running the server as a script.
    cli = argparse.ArgumentParser(description="MuseTalk API Server")
    cli.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind")
    cli.add_argument("--port", type=int, default=8000, help="Port to bind")
    cli.add_argument("--gpu_id", type=int, default=0, help="GPU ID")
    cli.add_argument("--unet_model_path", type=str, default="./models/musetalkV15/unet.pth")
    cli.add_argument("--unet_config", type=str, default="./models/musetalk/config.json")
    cli.add_argument("--whisper_dir", type=str, default="./models/whisper")
    cli.add_argument("--no_float16", action="store_true", help="Disable float16")
    opts = cli.parse_args()

    # Load the models with the CLI settings before the server starts.
    model_options = {
        "gpu_id": opts.gpu_id,
        "unet_model_path": opts.unet_model_path,
        "unet_config": opts.unet_config,
        "whisper_dir": opts.whisper_dir,
        "use_float16": not opts.no_float16,
    }
    server.load_models(**model_options)

    # Serve the API.
    uvicorn.run(app, host=opts.host, port=opts.port)