| """ | |
| MuseTalk HTTP API Server v3 (Fixed) | |
| Optimized with: | |
| 1. Sequential face blending (parallel had overhead) | |
| 2. NVENC hardware video encoding | |
| 3. Batch audio processing | |
| """ | |

import os
import cv2
import copy
import torch
import glob
import shutil
import pickle
import numpy as np
import subprocess
import tempfile
import hashlib
import time
from pathlib import Path
from typing import Optional, List
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import FileResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from tqdm import tqdm
from transformers import WhisperModel
import uvicorn

# MuseTalk imports
from musetalk.utils.blending import get_image
from musetalk.utils.face_parsing import FaceParsing
from musetalk.utils.audio_processor import AudioProcessor
from musetalk.utils.utils import get_file_type, datagen, load_all_model
from musetalk.utils.preprocessing import coord_placeholder


class MuseTalkServerV3:
    def __init__(self):
        self.device = None
        self.vae = None
        self.unet = None
        self.pe = None
        self.whisper = None
        self.audio_processor = None
        self.fp = None
        self.timesteps = None
        self.weight_dtype = None
        self.is_loaded = False
        self.loaded_avatars = {}
        self.avatar_dir = Path("./avatars")
        self.fps = 25
        self.batch_size = 8
        self.use_float16 = True
        self.version = "v15"
        self.extra_margin = 10
        self.parsing_mode = "jaw"
        self.left_cheek_width = 90
        self.right_cheek_width = 90
        self.audio_padding_left = 2
        self.audio_padding_right = 2
        # NVENC encoding settings
        self.use_nvenc = True
        self.nvenc_preset = "p4"  # NVENC preset (p1 = fastest .. p7 = best quality)
        self.crf = 23             # constant-quality target, passed to NVENC as -cq

    def load_models(self, gpu_id: int = 0):
        if self.is_loaded:
            print("Models already loaded!")
            return
        print("=" * 50)
        print("Loading MuseTalk models (v3 Optimized)...")
        print("=" * 50)
        start_time = time.time()
        self.device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
        self.vae, self.unet, self.pe = load_all_model(
            unet_model_path="./models/musetalkV15/unet.pth",
            vae_type="sd-vae",
            unet_config="./models/musetalk/config.json",
            device=self.device
        )
        self.timesteps = torch.tensor([0], device=self.device)
        self.pe = self.pe.half().to(self.device)
        self.vae.vae = self.vae.vae.half().to(self.device)
        self.unet.model = self.unet.model.half().to(self.device)
        self.audio_processor = AudioProcessor(feature_extractor_path="./models/whisper")
        self.weight_dtype = self.unet.model.dtype
        self.whisper = WhisperModel.from_pretrained("./models/whisper")
        self.whisper = self.whisper.to(device=self.device, dtype=self.weight_dtype).eval()
        self.whisper.requires_grad_(False)
        self.fp = FaceParsing(
            left_cheek_width=self.left_cheek_width,
            right_cheek_width=self.right_cheek_width
        )
        self.is_loaded = True
        print(f"Models loaded in {time.time() - start_time:.2f}s")

    def load_avatar(self, avatar_name: str) -> dict:
        if avatar_name in self.loaded_avatars:
            return self.loaded_avatars[avatar_name]
        avatar_path = self.avatar_dir / avatar_name
        if not avatar_path.exists():
            raise FileNotFoundError(f"Avatar not found: {avatar_name}")
        avatar_data = {}
        with open(avatar_path / "metadata.pkl", 'rb') as f:
            avatar_data['metadata'] = pickle.load(f)
        with open(avatar_path / "coords.pkl", 'rb') as f:
            avatar_data['coord_list'] = pickle.load(f)
        with open(avatar_path / "frames.pkl", 'rb') as f:
            avatar_data['frame_list'] = pickle.load(f)
        with open(avatar_path / "latents.pkl", 'rb') as f:
            latents_np = pickle.load(f)
        avatar_data['latent_list'] = [torch.from_numpy(l).to(self.device) for l in latents_np]
        self.loaded_avatars[avatar_name] = avatar_data
        return avatar_data
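
    # Expected on-disk avatar layout (as read by load_avatar above):
    #   avatars/<name>/metadata.pkl  - pickled avatar metadata
    #   avatars/<name>/coords.pkl    - per-frame face bounding boxes
    #   avatars/<name>/frames.pkl    - original frames (numpy images)
    #   avatars/<name>/latents.pkl   - per-frame VAE latents (numpy arrays)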

    def _encode_video_nvenc(self, frames_dir: str, audio_path: str, output_path: str, fps: int) -> float:
        t0 = time.time()
        temp_vid = output_path.replace('.mp4', '_temp.mp4')
        if self.use_nvenc:
            cmd = (
                f"ffmpeg -y -v warning -r {fps} -f image2 -i '{frames_dir}/%08d.png' "
                f"-c:v h264_nvenc -preset {self.nvenc_preset} -cq {self.crf} -pix_fmt yuv420p '{temp_vid}'"
            )
        else:
            cmd = (
                f"ffmpeg -y -v warning -r {fps} -f image2 -i '{frames_dir}/%08d.png' "
                f"-vcodec libx264 -crf 18 -pix_fmt yuv420p '{temp_vid}'"
            )
        os.system(cmd)
        # Mux the rendered video with the original audio; video is stream-copied, not re-encoded
        os.system(f"ffmpeg -y -v warning -i '{audio_path}' -i '{temp_vid}' -c:v copy -c:a aac '{output_path}'")
        if os.path.exists(temp_vid):
            os.remove(temp_vid)
        return time.time() - t0
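
    # Minimal sketch (not part of the original API): probe the local ffmpeg
    # build for h264_nvenc, so use_nvenc could fall back to libx264 on hosts
    # without NVENC support.
    def _nvenc_available(self) -> bool:
        try:
            out = subprocess.run(["ffmpeg", "-hide_banner", "-encoders"],
                                 capture_output=True, text=True, timeout=10)
            return "h264_nvenc" in out.stdout
        except (OSError, subprocess.TimeoutExpired):
            return False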

    def generate_with_avatar(self, avatar_name: str, audio_path: str, output_path: str, fps: int = 25) -> dict:
        if not self.is_loaded:
            raise RuntimeError("Models not loaded!")
        timings = {}
        total_start = time.time()
        t0 = time.time()
        avatar = self.load_avatar(avatar_name)
        timings["avatar_load"] = time.time() - t0
        coord_list = avatar['coord_list']
        frame_list = avatar['frame_list']
        input_latent_list = avatar['latent_list']
        temp_dir = tempfile.mkdtemp()
        try:
            # Whisper audio features
            t0 = time.time()
            whisper_input_features, librosa_length = self.audio_processor.get_audio_feature(audio_path)
            whisper_chunks = self.audio_processor.get_whisper_chunk(
                whisper_input_features, self.device, self.weight_dtype, self.whisper,
                librosa_length, fps=fps,
                audio_padding_length_left=self.audio_padding_left,
                audio_padding_length_right=self.audio_padding_right,
            )
            timings["whisper_features"] = time.time() - t0
            # Cycle lists forward then backward so the avatar loops smoothly
            frame_list_cycle = frame_list + frame_list[::-1]
            coord_list_cycle = coord_list + coord_list[::-1]
            input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
            # UNet inference
            t0 = time.time()
            gen = datagen(whisper_chunks=whisper_chunks, vae_encode_latents=input_latent_list_cycle,
                          batch_size=self.batch_size, delay_frame=0, device=self.device)
            res_frame_list = []
            for whisper_batch, latent_batch in gen:
                audio_feature_batch = self.pe(whisper_batch)
                latent_batch = latent_batch.to(dtype=self.unet.model.dtype)
                pred_latents = self.unet.model(latent_batch, self.timesteps,
                                               encoder_hidden_states=audio_feature_batch).sample
                recon = self.vae.decode_latents(pred_latents)
                res_frame_list.extend(recon)
            timings["unet_inference"] = time.time() - t0
            # Face blending (sequential - faster than parallel due to face-parsing overhead)
            t0 = time.time()
            result_img_path = os.path.join(temp_dir, "results")
            os.makedirs(result_img_path, exist_ok=True)
            for i, res_frame in enumerate(res_frame_list):
                bbox = coord_list_cycle[i % len(coord_list_cycle)]
                ori_frame = copy.deepcopy(frame_list_cycle[i % len(frame_list_cycle)])
                x1, y1, x2, y2 = bbox
                y2 = min(y2 + self.extra_margin, ori_frame.shape[0])
                try:
                    res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
                    combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2],
                                              mode=self.parsing_mode, fp=self.fp)
                    cv2.imwrite(f"{result_img_path}/{str(i).zfill(8)}.png", combine_frame)
                except Exception:
                    # Skip frames where resizing/blending fails (e.g. a degenerate bbox)
                    continue
            timings["face_blending"] = time.time() - t0
            # NVENC encoding
            timings["video_encoding"] = self._encode_video_nvenc(result_img_path, audio_path, output_path, fps)
        finally:
            shutil.rmtree(temp_dir, ignore_errors=True)
        timings["total"] = time.time() - total_start
        timings["frames_generated"] = len(res_frame_list)
        return timings

    def generate_batch(self, avatar_name: str, audio_paths: List[str], output_dir: str, fps: int = 25) -> dict:
        if not self.is_loaded:
            raise RuntimeError("Models not loaded!")
        batch_timings = {"videos": [], "total": 0}
        total_start = time.time()
        t0 = time.time()
        avatar = self.load_avatar(avatar_name)
        batch_timings["avatar_load"] = time.time() - t0
        coord_list = avatar['coord_list']
        frame_list = avatar['frame_list']
        input_latent_list = avatar['latent_list']
        # Cycle lists once up front; they are reused for every audio clip in the batch
        frame_list_cycle = frame_list + frame_list[::-1]
        coord_list_cycle = coord_list + coord_list[::-1]
        input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
        os.makedirs(output_dir, exist_ok=True)
        for idx, audio_path in enumerate(audio_paths):
            video_start = time.time()
            timings = {}
            output_path = os.path.join(output_dir, f"{Path(audio_path).stem}.mp4")
            temp_dir = tempfile.mkdtemp()
            try:
                t0 = time.time()
                whisper_input_features, librosa_length = self.audio_processor.get_audio_feature(audio_path)
                whisper_chunks = self.audio_processor.get_whisper_chunk(
                    whisper_input_features, self.device, self.weight_dtype, self.whisper,
                    librosa_length, fps=fps,
                    audio_padding_length_left=self.audio_padding_left,
                    audio_padding_length_right=self.audio_padding_right,
                )
                timings["whisper"] = time.time() - t0
                t0 = time.time()
                gen = datagen(whisper_chunks=whisper_chunks, vae_encode_latents=input_latent_list_cycle,
                              batch_size=self.batch_size, delay_frame=0, device=self.device)
                res_frame_list = []
                for whisper_batch, latent_batch in gen:
                    audio_feature_batch = self.pe(whisper_batch)
                    latent_batch = latent_batch.to(dtype=self.unet.model.dtype)
                    pred_latents = self.unet.model(latent_batch, self.timesteps,
                                                   encoder_hidden_states=audio_feature_batch).sample
                    res_frame_list.extend(self.vae.decode_latents(pred_latents))
                timings["unet"] = time.time() - t0
                t0 = time.time()
                result_img_path = os.path.join(temp_dir, "results")
                os.makedirs(result_img_path, exist_ok=True)
                for i, res_frame in enumerate(res_frame_list):
                    bbox = coord_list_cycle[i % len(coord_list_cycle)]
                    ori_frame = copy.deepcopy(frame_list_cycle[i % len(frame_list_cycle)])
                    x1, y1, x2, y2 = bbox
                    y2 = min(y2 + self.extra_margin, ori_frame.shape[0])
                    try:
                        res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
                        combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2],
                                                  mode=self.parsing_mode, fp=self.fp)
                        cv2.imwrite(f"{result_img_path}/{str(i).zfill(8)}.png", combine_frame)
                    except Exception:
                        continue
                timings["blending"] = time.time() - t0
                timings["encoding"] = self._encode_video_nvenc(result_img_path, audio_path, output_path, fps)
            finally:
                shutil.rmtree(temp_dir, ignore_errors=True)
            timings["total"] = time.time() - video_start
            timings["frames"] = len(res_frame_list)
            timings["output"] = output_path
            batch_timings["videos"].append(timings)
            print(f" [{idx+1}/{len(audio_paths)}] {Path(audio_path).stem}: {timings['total']:.2f}s")
        batch_timings["total"] = time.time() - total_start
        batch_timings["num_videos"] = len(audio_paths)
        batch_timings["avg_per_video"] = batch_timings["total"] / len(audio_paths) if audio_paths else 0
        return batch_timings


server = MuseTalkServerV3()
app = FastAPI(title="MuseTalk API v3", version="3.0.0")
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True,
                   allow_methods=["*"], allow_headers=["*"])


# API routes (paths are illustrative and may differ from the deployed service)
@app.on_event("startup")
async def startup():
    server.load_models()


@app.get("/health")
async def health():
    return {"status": "ok" if server.is_loaded else "loading", "device": str(server.device),
            "avatars": list(server.loaded_avatars.keys()), "nvenc": server.use_nvenc}


@app.get("/avatars")
async def list_avatars():
    avatars = []
    if server.avatar_dir.exists():
        for p in server.avatar_dir.iterdir():
            if p.is_dir() and (p / "metadata.pkl").exists():
                with open(p / "metadata.pkl", 'rb') as f:
                    avatars.append(pickle.load(f))
    return {"avatars": avatars}


class GenReq(BaseModel):
    avatar_name: str
    audio_path: str
    output_path: str
    fps: int = 25


@app.post("/generate")
async def generate(req: GenReq):
    if not os.path.exists(req.audio_path):
        raise HTTPException(404, f"Audio not found: {req.audio_path}")
    try:
        timings = server.generate_with_avatar(req.avatar_name, req.audio_path, req.output_path, req.fps)
        return {"status": "success", "output_path": req.output_path, "timings": timings}
    except Exception as e:
        raise HTTPException(500, str(e))


class BatchReq(BaseModel):
    avatar_name: str
    audio_paths: List[str]
    output_dir: str
    fps: int = 25


@app.post("/generate_batch")
async def batch(req: BatchReq):
    for p in req.audio_paths:
        if not os.path.exists(p):
            raise HTTPException(404, f"Audio not found: {p}")
    try:
        timings = server.generate_batch(req.avatar_name, req.audio_paths, req.output_dir, req.fps)
        return {"status": "success", "output_dir": req.output_dir, "timings": timings}
    except Exception as e:
        raise HTTPException(500, str(e))


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
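
# Example client call for the batch endpoint (illustrative; "demo" and the
# paths below are hypothetical):
#
#   import requests
#   resp = requests.post("http://localhost:8000/generate_batch", json={
#       "avatar_name": "demo",
#       "audio_paths": ["/data/a.wav", "/data/b.wav"],
#       "output_dir": "/data/out",
#   })
#   print(resp.json()["timings"]["avg_per_video"])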