Spaces:
Configuration error
Configuration error
| """ | |
| MuseTalk HTTP API Server v2 | |
| Optimized for repeated use of the same avatar. | |
| """ | |
| import os | |
| import cv2 | |
| import copy | |
| import torch | |
| import glob | |
| import shutil | |
| import pickle | |
| import numpy as np | |
| import subprocess | |
| import tempfile | |
| import hashlib | |
| import time | |
| from pathlib import Path | |
| from typing import Optional | |
| from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks | |
| from fastapi.responses import FileResponse, JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from tqdm import tqdm | |
| from omegaconf import OmegaConf | |
| from transformers import WhisperModel | |
| import uvicorn | |
| # MuseTalk imports | |
| from musetalk.utils.blending import get_image | |
| from musetalk.utils.face_parsing import FaceParsing | |
| from musetalk.utils.audio_processor import AudioProcessor | |
| from musetalk.utils.utils import get_file_type, datagen, load_all_model | |
| from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder | |
class MuseTalkServerV2:
    """Server optimized for pre-processed avatars.

    Models are loaded once by ``load_models``; preprocessed avatars are read
    from ``avatar_dir`` and cached in ``loaded_avatars`` so repeated requests
    for the same avatar skip disk I/O.
    """

    def __init__(self):
        self.device = None
        self.vae = None
        self.unet = None
        self.pe = None
        self.whisper = None
        self.audio_processor = None
        self.fp = None
        self.timesteps = None
        self.weight_dtype = None
        self.is_loaded = False
        # Avatar cache (in-memory), keyed by avatar directory name.
        self.loaded_avatars = {}
        self.avatar_dir = Path("./avatars")
        # Config
        self.fps = 25
        self.batch_size = 8
        self.use_float16 = True
        self.version = "v15"
        self.extra_margin = 10
        self.parsing_mode = "jaw"
        self.left_cheek_width = 90
        self.right_cheek_width = 90
        self.audio_padding_left = 2
        self.audio_padding_right = 2

    def load_models(
        self,
        gpu_id: int = 0,
        unet_model_path: str = "./models/musetalkV15/unet.pth",
        unet_config: str = "./models/musetalk/config.json",
        vae_type: str = "sd-vae",
        whisper_dir: str = "./models/whisper",
        use_float16: bool = True,
        version: str = "v15"
    ):
        """Load VAE, UNet, positional encoder, Whisper and face parser.

        Idempotent: returns immediately if models are already loaded. Uses
        ``cuda:{gpu_id}`` when CUDA is available, otherwise the CPU.
        """
        if self.is_loaded:
            print("Models already loaded!")
            return
        print("=" * 50)
        print("Loading MuseTalk models into GPU memory...")
        print("=" * 50)
        start_time = time.time()
        self.device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        print("Loading VAE, UNet, PE...")
        self.vae, self.unet, self.pe = load_all_model(
            unet_model_path=unet_model_path,
            vae_type=vae_type,
            unet_config=unet_config,
            device=self.device
        )
        self.timesteps = torch.tensor([0], device=self.device)
        self.use_float16 = use_float16
        if use_float16:
            print("Converting to float16...")
            self.pe = self.pe.half()
            self.vae.vae = self.vae.vae.half()
            self.unet.model = self.unet.model.half()
        self.pe = self.pe.to(self.device)
        self.vae.vae = self.vae.vae.to(self.device)
        self.unet.model = self.unet.model.to(self.device)
        print("Loading Whisper model...")
        self.audio_processor = AudioProcessor(feature_extractor_path=whisper_dir)
        # Whisper runs in the same dtype as the UNet so features need no cast.
        self.weight_dtype = self.unet.model.dtype
        self.whisper = WhisperModel.from_pretrained(whisper_dir)
        self.whisper = self.whisper.to(device=self.device, dtype=self.weight_dtype).eval()
        self.whisper.requires_grad_(False)
        self.version = version
        if version == "v15":
            self.fp = FaceParsing(
                left_cheek_width=self.left_cheek_width,
                right_cheek_width=self.right_cheek_width
            )
        else:
            self.fp = FaceParsing()
        self.is_loaded = True
        print(f"Models loaded in {time.time() - start_time:.2f}s")
        print("=" * 50)

    def _avatar_path(self, avatar_name: str) -> Path:
        """Return ``avatar_dir / avatar_name`` after validating the name.

        Raises:
            FileNotFoundError: if ``avatar_name`` is not a plain directory
                name (empty, ``.``/``..``, contains a path separator, or is
                absolute) or the directory does not exist.
        """
        # avatar_name comes from HTTP requests: refuse anything that could
        # resolve outside avatar_dir. FileNotFoundError keeps callers' 404
        # mapping intact.
        if avatar_name in ("", ".", "..") or Path(avatar_name).name != avatar_name:
            raise FileNotFoundError(f"Avatar not found: {avatar_name}")
        avatar_path = self.avatar_dir / avatar_name
        if not avatar_path.exists():
            raise FileNotFoundError(f"Avatar not found: {avatar_name}")
        return avatar_path

    def load_avatar(self, avatar_name: str) -> dict:
        """Load a preprocessed avatar into memory, caching it by name.

        Reads ``metadata.pkl``, ``coords.pkl``, ``frames.pkl``,
        ``latents.pkl``, ``crop_info.pkl`` and, if present, ``parsing.pkl``
        from the avatar directory. Latents are converted to tensors and moved
        to ``self.device``.

        Returns:
            Dict with keys ``metadata``, ``coord_list``, ``frame_list``,
            ``latent_list``, ``crop_info`` and optionally ``parsing_data``.

        Raises:
            FileNotFoundError: if the name is invalid or the avatar is missing.
        """
        if avatar_name in self.loaded_avatars:
            return self.loaded_avatars[avatar_name]
        avatar_path = self._avatar_path(avatar_name)
        print(f"Loading avatar '{avatar_name}' into memory...")
        t0 = time.time()

        def _load(filename):
            # NOTE(security): pickle.load can execute arbitrary code; the
            # avatar directory must only contain trusted, locally produced files.
            with open(avatar_path / filename, 'rb') as f:
                return pickle.load(f)

        avatar_data = {
            'metadata': _load("metadata.pkl"),
            'coord_list': _load("coords.pkl"),
            'frame_list': _load("frames.pkl"),
            'latent_list': [
                torch.from_numpy(latent).to(self.device)
                for latent in _load("latents.pkl")
            ],
            'crop_info': _load("crop_info.pkl"),
        }
        if (avatar_path / "parsing.pkl").exists():
            avatar_data['parsing_data'] = _load("parsing.pkl")
        self.loaded_avatars[avatar_name] = avatar_data
        print(f"Avatar loaded in {time.time() - t0:.2f}s")
        return avatar_data

    def unload_avatar(self, avatar_name: str):
        """Drop an avatar from the cache (no-op if absent) and release cached GPU memory."""
        self.loaded_avatars.pop(avatar_name, None)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def generate_with_avatar(
        self,
        avatar_name: str,
        audio_path: str,
        output_path: str,
        fps: Optional[int] = None
    ) -> dict:
        """Generate a lip-synced video for ``audio_path`` using a cached avatar.

        Args:
            avatar_name: Directory name of a preprocessed avatar under
                ``self.avatar_dir``.
            audio_path: Audio that drives the mouth motion; it is also muxed
                into the output video.
            output_path: Destination video file (overwritten).
            fps: Frame rate; ``None``/``0`` fall back to ``self.fps``.

        Returns:
            Per-stage timings in seconds plus ``total`` and ``frames_generated``.

        Raises:
            RuntimeError: models not loaded, or ffmpeg not on PATH.
            FileNotFoundError: avatar missing or invalid name.
            subprocess.CalledProcessError: an ffmpeg step failed.
        """
        if not self.is_loaded:
            raise RuntimeError("Models not loaded!")
        fps = fps or self.fps
        timings = {}
        total_start = time.time()

        # Load avatar (cached in memory)
        t0 = time.time()
        avatar = self.load_avatar(avatar_name)
        timings["avatar_load"] = time.time() - t0
        coord_list = avatar['coord_list']
        frame_list = avatar['frame_list']
        input_latent_list = avatar['latent_list']

        # Play the clip forward then backward; presumably so it can repeat
        # for long audio without a visible jump at the wrap-around.
        frame_list_cycle = frame_list + frame_list[::-1]
        coord_list_cycle = coord_list + coord_list[::-1]
        input_latent_list_cycle = input_latent_list + input_latent_list[::-1]

        res_frame_list = []
        temp_dir = tempfile.mkdtemp()
        try:
            # Pure inference: without no_grad autograd would record the
            # UNet/VAE/PE graphs for every batch and waste device memory.
            with torch.no_grad():
                # 1. Extract audio features (only audio-dependent heavy step)
                t0 = time.time()
                whisper_input_features, librosa_length = self.audio_processor.get_audio_feature(audio_path)
                whisper_chunks = self.audio_processor.get_whisper_chunk(
                    whisper_input_features,
                    self.device,
                    self.weight_dtype,
                    self.whisper,
                    librosa_length,
                    fps=fps,
                    audio_padding_length_left=self.audio_padding_left,
                    audio_padding_length_right=self.audio_padding_right,
                )
                timings["whisper_features"] = time.time() - t0

                # 2. UNet inference + VAE decode
                t0 = time.time()
                gen = datagen(
                    whisper_chunks=whisper_chunks,
                    vae_encode_latents=input_latent_list_cycle,
                    batch_size=self.batch_size,
                    delay_frame=0,
                    device=self.device,
                )
                for whisper_batch, latent_batch in gen:
                    audio_feature_batch = self.pe(whisper_batch)
                    latent_batch = latent_batch.to(dtype=self.unet.model.dtype)
                    pred_latents = self.unet.model(
                        latent_batch, self.timesteps,
                        encoder_hidden_states=audio_feature_batch
                    ).sample
                    res_frame_list.extend(self.vae.decode_latents(pred_latents))
                timings["unet_inference"] = time.time() - t0

            # 3. Face blending
            t0 = time.time()
            result_img_path = os.path.join(temp_dir, "results")
            os.makedirs(result_img_path, exist_ok=True)
            self._blend_frames(res_frame_list, frame_list_cycle, coord_list_cycle, result_img_path)
            timings["face_blending"] = time.time() - t0

            # 4. Encode video
            t0 = time.time()
            self._encode_video(result_img_path, audio_path, output_path, fps, temp_dir)
            timings["video_encoding"] = time.time() - t0
        finally:
            shutil.rmtree(temp_dir, ignore_errors=True)

        timings["total"] = time.time() - total_start
        timings["frames_generated"] = len(res_frame_list)
        return timings

    def _blend_frames(self, res_frame_list, frame_list_cycle, coord_list_cycle, result_img_path):
        """Paste each generated face crop onto its source frame; write frame i as ``%08d.png``."""
        for i, res_frame in enumerate(res_frame_list):
            x1, y1, x2, y2 = coord_list_cycle[i % len(coord_list_cycle)]
            ori_frame = copy.deepcopy(frame_list_cycle[i % len(frame_list_cycle)])
            if self.version == "v15":
                # v15 extends the crop downward by extra_margin, clamped to the frame height.
                y2 = min(y2 + self.extra_margin, ori_frame.shape[0])
            try:
                res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
            except Exception:
                # Unusable bbox (presumably a frame with no detected face):
                # keep the original frame rather than skipping the index. A gap
                # in the PNG sequence makes ffmpeg's image2 reader stop early,
                # truncating the video and desyncing it from the audio.
                combine_frame = ori_frame
            else:
                if self.version == "v15":
                    combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2],
                                              mode=self.parsing_mode, fp=self.fp)
                else:
                    combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=self.fp)
            cv2.imwrite(os.path.join(result_img_path, f"{i:08d}.png"), combine_frame)

    @staticmethod
    def _run_ffmpeg(*args: str) -> None:
        """Run ``ffmpeg -y -v warning <args>`` without a shell; raise on failure."""
        try:
            subprocess.run(["ffmpeg", "-y", "-v", "warning", *args], check=True)
        except FileNotFoundError as e:
            # Callers map FileNotFoundError to "avatar/audio not found";
            # a missing ffmpeg binary is a server error instead.
            raise RuntimeError("ffmpeg executable not found on PATH") from e

    def _encode_video(self, result_img_path, audio_path, output_path, fps, temp_dir):
        """Encode the PNG sequence to H.264 and mux it with the audio into ``output_path``."""
        temp_vid = os.path.join(temp_dir, "temp.mp4")
        self._run_ffmpeg(
            "-r", str(fps), "-f", "image2",
            "-i", os.path.join(result_img_path, "%08d.png"),
            "-vcodec", "libx264", "-vf", "format=yuv420p", "-crf", "18",
            temp_vid,
        )
        # Argument list (no shell) prevents injection via client-supplied
        # paths; absolute paths stop names starting with "-" being read as options.
        self._run_ffmpeg(
            "-i", os.path.abspath(audio_path),
            "-i", temp_vid,
            os.path.abspath(output_path),
        )
# Single process-wide server; every request handler below reads and mutates
# this instance (model handles plus the avatar cache).
server = MuseTalkServerV2()

# FastAPI app
app = FastAPI(
    version="2.0.0",
    title="MuseTalk API v2",
    description="Optimized API for repeated avatar usage",
)

# NOTE(review): allow_origins=["*"] together with allow_credentials=True makes
# Starlette echo the caller's Origin back with credentials allowed, i.e. any
# website may call this API from a user's browser. Confirm that is intended.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
async def startup_event():
    """Load every MuseTalk model into the global ``server`` before serving."""
    # NOTE(review): no startup registration (on_event/lifespan decorator) is
    # visible for this coroutine; if it is never invoked, server.is_loaded stays
    # False and the generate endpoints always answer 503. Confirm the wiring.
    return server.load_models()
async def health_check():
    """Report whether models are loaded, the compute device, and cached avatars."""
    ready = server.is_loaded
    device = server.device
    return {
        "status": "ok" if ready else "loading",
        "models_loaded": ready,
        # device stays None until load_models() has run.
        "device": None if not device else str(device),
        "loaded_avatars": [name for name in server.loaded_avatars],
    }
async def list_avatars():
    """List all available preprocessed avatars.

    An avatar is any sub-directory of ``server.avatar_dir`` that contains a
    ``metadata.pkl``. Each entry is that unpickled metadata plus a ``loaded``
    flag saying whether the avatar is currently cached in memory. Entries are
    sorted by directory name; an absent avatar directory yields an empty list.
    """
    avatars = []
    # iterdir() raises FileNotFoundError on a fresh install without ./avatars;
    # report "no avatars" instead of a 500.
    if not server.avatar_dir.is_dir():
        return {"avatars": avatars}
    for p in sorted(server.avatar_dir.iterdir()):
        metadata_path = p / "metadata.pkl"
        if not (p.is_dir() and metadata_path.exists()):
            continue
        # NOTE(security): pickle.load executes code embedded in the file; the
        # avatar directory must only hold trusted, locally produced data.
        with open(metadata_path, 'rb') as f:
            metadata = pickle.load(f)
        metadata['loaded'] = p.name in server.loaded_avatars
        avatars.append(metadata)
    return {"avatars": avatars}
async def load_avatar(avatar_name: str):
    """Pre-load an avatar into GPU memory.

    Raises:
        HTTPException(503): models are not loaded yet.
        HTTPException(404): the avatar does not exist.
    """
    # Loading caches latents on server.device, which is None until
    # load_models() runs; caching them then would leave them off the model
    # device for every later request. Same guard as the generate endpoints.
    if not server.is_loaded:
        raise HTTPException(status_code=503, detail="Models not loaded")
    try:
        server.load_avatar(avatar_name)
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
    return {"status": "loaded", "avatar_name": avatar_name}
async def unload_avatar(avatar_name: str):
    """Evict an avatar from the in-memory cache; unknown names are ignored."""
    server.unload_avatar(avatar_name)
    response = {"status": "unloaded"}
    response["avatar_name"] = avatar_name
    return response
class GenerateWithAvatarRequest(BaseModel):
    # JSON body for the path-based generate endpoint. Both paths refer to the
    # server's own filesystem, not to client-side files.
    avatar_name: str  # avatar directory name under server.avatar_dir
    audio_path: str  # existing audio file on the server
    output_path: str  # where the final video is written on the server
    fps: Optional[int] = 25  # None (or 0) falls back to server.fps
async def generate_with_avatar(request: GenerateWithAvatarRequest):
    """Generate video using pre-processed avatar. FAST!

    Returns ``{"status", "output_path", "timings"}`` on success.

    Raises:
        HTTPException(503): models are not loaded.
        HTTPException(404): audio file or avatar not found.
        HTTPException(500): any other generation failure.
    """
    if not server.is_loaded:
        raise HTTPException(status_code=503, detail="Models not loaded")
    # isfile (not exists) so a directory is not accepted as audio input.
    if not os.path.isfile(request.audio_path):
        raise HTTPException(status_code=404, detail=f"Audio not found: {request.audio_path}")
    # NOTE(security): audio_path/output_path are client-chosen server paths,
    # letting any caller read or overwrite files the process can access.
    # Restrict them to an allowed directory if this API is reachable by others.
    # NOTE(review): generation runs synchronously inside this async handler
    # and blocks the event loop (including /health) until it finishes.
    try:
        timings = server.generate_with_avatar(
            avatar_name=request.avatar_name,
            audio_path=request.audio_path,
            output_path=request.output_path,
            fps=request.fps
        )
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {
        "status": "success",
        "output_path": request.output_path,
        "timings": timings
    }
async def generate_with_avatar_upload(
    avatar_name: str = Form(...),
    audio: UploadFile = File(...),
    fps: int = Form(25)
):
    """Generate video from uploaded audio using pre-processed avatar.

    The upload and rendered video live in a private temp directory. It is
    deleted after the response body has been sent, or immediately on error.
    The ``X-Timings`` header carries the stage timings.

    Raises:
        HTTPException(503): models are not loaded.
        HTTPException(404): the avatar does not exist.
        HTTPException(500): any other generation failure.
    """
    if not server.is_loaded:
        raise HTTPException(status_code=503, detail="Models not loaded")
    temp_dir = tempfile.mkdtemp()
    try:
        # Keep only the extension of the client-supplied filename. Joining the
        # raw name would let "../x" or an absolute path escape temp_dir.
        suffix = Path(audio.filename or "").suffix
        audio_path = os.path.join(temp_dir, f"input_audio{suffix}")
        output_path = os.path.join(temp_dir, "output.mp4")
        with open(audio_path, "wb") as f:
            f.write(await audio.read())
        timings = server.generate_with_avatar(
            avatar_name=avatar_name,
            audio_path=audio_path,
            output_path=output_path,
            fps=fps
        )
    except FileNotFoundError as e:
        # Missing avatar: 404, consistent with the path-based endpoint.
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise HTTPException(status_code=404, detail=str(e)) from e
    except Exception as e:
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
    # The file must outlive this handler (it is streamed afterwards), so the
    # temp dir is removed by a background task that runs once the response is sent.
    cleanup = BackgroundTasks()
    cleanup.add_task(shutil.rmtree, temp_dir, ignore_errors=True)
    return FileResponse(
        output_path,
        media_type="video/mp4",
        filename="result.mp4",
        headers={"X-Timings": str(timings)},
        background=cleanup,
    )
if __name__ == "__main__":
    # Script entry point: serve `app` with uvicorn on the given host/port.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--host", type=str, default="0.0.0.0")
    cli.add_argument("--port", type=int, default=8000)
    opts = cli.parse_args()
    uvicorn.run(app, port=opts.port, host=opts.host)