""" MuseTalk HTTP API Server v3 (Fixed) Optimized with: 1. Sequential face blending (parallel had overhead) 2. NVENC hardware video encoding 3. Batch audio processing """ import os import cv2 import copy import torch import glob import shutil import pickle import numpy as np import subprocess import tempfile import hashlib import time from pathlib import Path from typing import Optional, List from fastapi import FastAPI, File, UploadFile, Form, HTTPException from fastapi.responses import FileResponse, JSONResponse from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from tqdm import tqdm from transformers import WhisperModel import uvicorn # MuseTalk imports from musetalk.utils.blending import get_image from musetalk.utils.face_parsing import FaceParsing from musetalk.utils.audio_processor import AudioProcessor from musetalk.utils.utils import get_file_type, datagen, load_all_model from musetalk.utils.preprocessing import coord_placeholder class MuseTalkServerV3: def __init__(self): self.device = None self.vae = None self.unet = None self.pe = None self.whisper = None self.audio_processor = None self.fp = None self.timesteps = None self.weight_dtype = None self.is_loaded = False self.loaded_avatars = {} self.avatar_dir = Path("./avatars") self.fps = 25 self.batch_size = 8 self.use_float16 = True self.version = "v15" self.extra_margin = 10 self.parsing_mode = "jaw" self.left_cheek_width = 90 self.right_cheek_width = 90 self.audio_padding_left = 2 self.audio_padding_right = 2 # NVENC self.use_nvenc = True self.nvenc_preset = "p4" self.crf = 23 def load_models(self, gpu_id: int = 0): if self.is_loaded: print("Models already loaded!") return print("=" * 50) print("Loading MuseTalk models (v3 Optimized)...") print("=" * 50) start_time = time.time() self.device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu") self.vae, self.unet, self.pe = load_all_model( unet_model_path="./models/musetalkV15/unet.pth", vae_type="sd-vae", unet_config="./models/musetalk/config.json", device=self.device ) self.timesteps = torch.tensor([0], device=self.device) self.pe = self.pe.half().to(self.device) self.vae.vae = self.vae.vae.half().to(self.device) self.unet.model = self.unet.model.half().to(self.device) self.audio_processor = AudioProcessor(feature_extractor_path="./models/whisper") self.weight_dtype = self.unet.model.dtype self.whisper = WhisperModel.from_pretrained("./models/whisper") self.whisper = self.whisper.to(device=self.device, dtype=self.weight_dtype).eval() self.whisper.requires_grad_(False) self.fp = FaceParsing( left_cheek_width=self.left_cheek_width, right_cheek_width=self.right_cheek_width ) self.is_loaded = True print(f"Models loaded in {time.time() - start_time:.2f}s") def load_avatar(self, avatar_name: str) -> dict: if avatar_name in self.loaded_avatars: return self.loaded_avatars[avatar_name] avatar_path = self.avatar_dir / avatar_name if not avatar_path.exists(): raise FileNotFoundError(f"Avatar not found: {avatar_name}") avatar_data = {} with open(avatar_path / "metadata.pkl", 'rb') as f: avatar_data['metadata'] = pickle.load(f) with open(avatar_path / "coords.pkl", 'rb') as f: avatar_data['coord_list'] = pickle.load(f) with open(avatar_path / "frames.pkl", 'rb') as f: avatar_data['frame_list'] = pickle.load(f) with open(avatar_path / "latents.pkl", 'rb') as f: latents_np = pickle.load(f) avatar_data['latent_list'] = [torch.from_numpy(l).to(self.device) for l in latents_np] self.loaded_avatars[avatar_name] = avatar_data return avatar_data 
    def _encode_video_nvenc(self, frames_dir: str, audio_path: str, output_path: str, fps: int) -> float:
        """Encode PNG frames to video, mux in the audio track, and return elapsed seconds."""
        t0 = time.time()
        temp_vid = output_path.replace('.mp4', '_temp.mp4')

        if self.use_nvenc:
            # GPU hardware encode; the CRF setting is reused as NVENC's
            # constant-quality (-cq) value
            encode_cmd = [
                "ffmpeg", "-y", "-v", "warning",
                "-r", str(fps), "-f", "image2", "-i", f"{frames_dir}/%08d.png",
                "-c:v", "h264_nvenc", "-preset", self.nvenc_preset,
                "-cq", str(self.crf), "-pix_fmt", "yuv420p", temp_vid,
            ]
        else:
            # CPU fallback
            encode_cmd = [
                "ffmpeg", "-y", "-v", "warning",
                "-r", str(fps), "-f", "image2", "-i", f"{frames_dir}/%08d.png",
                "-vcodec", "libx264", "-crf", "18", "-pix_fmt", "yuv420p", temp_vid,
            ]
        # List-form subprocess.run avoids shell-quoting issues with paths
        # containing spaces; check=True surfaces encoder failures
        subprocess.run(encode_cmd, check=True)

        # Mux the original audio with the freshly encoded video
        subprocess.run([
            "ffmpeg", "-y", "-v", "warning",
            "-i", audio_path, "-i", temp_vid,
            "-c:v", "copy", "-c:a", "aac", output_path,
        ], check=True)

        if os.path.exists(temp_vid):
            os.remove(temp_vid)
        return time.time() - t0

    @torch.no_grad()
    def generate_with_avatar(self, avatar_name: str, audio_path: str, output_path: str, fps: int = 25) -> dict:
        if not self.is_loaded:
            raise RuntimeError("Models not loaded!")

        timings = {}
        total_start = time.time()

        t0 = time.time()
        avatar = self.load_avatar(avatar_name)
        timings["avatar_load"] = time.time() - t0

        coord_list = avatar['coord_list']
        frame_list = avatar['frame_list']
        input_latent_list = avatar['latent_list']

        temp_dir = tempfile.mkdtemp()
        res_frame_list = []
        try:
            # Whisper audio features
            t0 = time.time()
            whisper_input_features, librosa_length = self.audio_processor.get_audio_feature(audio_path)
            whisper_chunks = self.audio_processor.get_whisper_chunk(
                whisper_input_features,
                self.device,
                self.weight_dtype,
                self.whisper,
                librosa_length,
                fps=fps,
                audio_padding_length_left=self.audio_padding_left,
                audio_padding_length_right=self.audio_padding_right,
            )
            timings["whisper_features"] = time.time() - t0

            # Cycle the lists forward then backward so the avatar loops
            # smoothly when the audio is longer than the source clip
            frame_list_cycle = frame_list + frame_list[::-1]
            coord_list_cycle = coord_list + coord_list[::-1]
            input_latent_list_cycle = input_latent_list + input_latent_list[::-1]

            # UNet inference
            t0 = time.time()
            gen = datagen(whisper_chunks=whisper_chunks,
                          vae_encode_latents=input_latent_list_cycle,
                          batch_size=self.batch_size,
                          delay_frame=0,
                          device=self.device)
            for whisper_batch, latent_batch in gen:
                audio_feature_batch = self.pe(whisper_batch)
                latent_batch = latent_batch.to(dtype=self.unet.model.dtype)
                pred_latents = self.unet.model(latent_batch, self.timesteps,
                                               encoder_hidden_states=audio_feature_batch).sample
                recon = self.vae.decode_latents(pred_latents)
                res_frame_list.extend(recon)
            timings["unet_inference"] = time.time() - t0

            # Face blending (sequential; parallel was slower due to face-parsing overhead)
            t0 = time.time()
            result_img_path = os.path.join(temp_dir, "results")
            os.makedirs(result_img_path, exist_ok=True)
            for i, res_frame in enumerate(res_frame_list):
                bbox = coord_list_cycle[i % len(coord_list_cycle)]
                ori_frame = copy.deepcopy(frame_list_cycle[i % len(frame_list_cycle)])
                x1, y1, x2, y2 = bbox
                y2 = min(y2 + self.extra_margin, ori_frame.shape[0])
                try:
                    res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
                    combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2],
                                              mode=self.parsing_mode, fp=self.fp)
                    cv2.imwrite(f"{result_img_path}/{str(i).zfill(8)}.png", combine_frame)
                except Exception:
                    # Skip frames where the face crop is degenerate
                    continue
            timings["face_blending"] = time.time() - t0

            # NVENC encoding
            timings["video_encoding"] = self._encode_video_nvenc(result_img_path, audio_path, output_path, fps)
        finally:
            shutil.rmtree(temp_dir, ignore_errors=True)

        timings["total"] = time.time() - total_start
        timings["frames_generated"] = len(res_frame_list)
        return timings

    @torch.no_grad()
    def generate_batch(self, avatar_name: str, audio_paths: List[str], output_dir: str, fps: int = 25) -> dict:
        if not self.is_loaded:
            raise RuntimeError("Models not loaded!")
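        # Batch strategy: the avatar (frames, coords, latents) is loaded and
        # cycled once below, then reused for every clip; only the per-clip
        # Whisper / UNet / blending / encoding stages repeat.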
batch_timings = {"videos": [], "total": 0} total_start = time.time() t0 = time.time() avatar = self.load_avatar(avatar_name) batch_timings["avatar_load"] = time.time() - t0 coord_list = avatar['coord_list'] frame_list = avatar['frame_list'] input_latent_list = avatar['latent_list'] frame_list_cycle = frame_list + frame_list[::-1] coord_list_cycle = coord_list + coord_list[::-1] input_latent_list_cycle = input_latent_list + input_latent_list[::-1] os.makedirs(output_dir, exist_ok=True) for idx, audio_path in enumerate(audio_paths): video_start = time.time() timings = {} output_path = os.path.join(output_dir, f"{Path(audio_path).stem}.mp4") temp_dir = tempfile.mkdtemp() try: t0 = time.time() whisper_input_features, librosa_length = self.audio_processor.get_audio_feature(audio_path) whisper_chunks = self.audio_processor.get_whisper_chunk( whisper_input_features, self.device, self.weight_dtype, self.whisper, librosa_length, fps=fps, audio_padding_length_left=self.audio_padding_left, audio_padding_length_right=self.audio_padding_right, ) timings["whisper"] = time.time() - t0 t0 = time.time() gen = datagen(whisper_chunks=whisper_chunks, vae_encode_latents=input_latent_list_cycle, batch_size=self.batch_size, delay_frame=0, device=self.device) res_frame_list = [] for whisper_batch, latent_batch in gen: audio_feature_batch = self.pe(whisper_batch) latent_batch = latent_batch.to(dtype=self.unet.model.dtype) pred_latents = self.unet.model(latent_batch, self.timesteps, encoder_hidden_states=audio_feature_batch).sample res_frame_list.extend(self.vae.decode_latents(pred_latents)) timings["unet"] = time.time() - t0 t0 = time.time() result_img_path = os.path.join(temp_dir, "results") os.makedirs(result_img_path, exist_ok=True) for i, res_frame in enumerate(res_frame_list): bbox = coord_list_cycle[i % len(coord_list_cycle)] ori_frame = copy.deepcopy(frame_list_cycle[i % len(frame_list_cycle)]) x1, y1, x2, y2 = bbox y2 = min(y2 + self.extra_margin, ori_frame.shape[0]) try: res_frame = cv2.resize(res_frame.astype(np.uint8), (x2-x1, y2-y1)) combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], mode=self.parsing_mode, fp=self.fp) cv2.imwrite(f"{result_img_path}/{str(i).zfill(8)}.png", combine_frame) except: continue timings["blending"] = time.time() - t0 timings["encoding"] = self._encode_video_nvenc(result_img_path, audio_path, output_path, fps) finally: shutil.rmtree(temp_dir, ignore_errors=True) timings["total"] = time.time() - video_start timings["frames"] = len(res_frame_list) timings["output"] = output_path batch_timings["videos"].append(timings) print(f" [{idx+1}/{len(audio_paths)}] {Path(audio_path).stem}: {timings['total']:.2f}s") batch_timings["total"] = time.time() - total_start batch_timings["num_videos"] = len(audio_paths) batch_timings["avg_per_video"] = batch_timings["total"] / len(audio_paths) if audio_paths else 0 return batch_timings server = MuseTalkServerV3() app = FastAPI(title="MuseTalk API v3", version="3.0.0") app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"]) @app.on_event("startup") async def startup(): server.load_models() @app.get("/health") async def health(): return {"status": "ok" if server.is_loaded else "loading", "device": str(server.device), "avatars": list(server.loaded_avatars.keys()), "nvenc": server.use_nvenc} @app.get("/avatars") async def list_avatars(): avatars = [] for p in server.avatar_dir.iterdir(): if p.is_dir() and (p / "metadata.pkl").exists(): with open(p / "metadata.pkl", 'rb') 
                avatars.append(pickle.load(f))
    return {"avatars": avatars}


class GenReq(BaseModel):
    avatar_name: str
    audio_path: str
    output_path: str
    fps: int = 25


@app.post("/generate/avatar")
async def generate(req: GenReq):
    if not os.path.exists(req.audio_path):
        raise HTTPException(404, f"Audio not found: {req.audio_path}")
    try:
        timings = server.generate_with_avatar(req.avatar_name, req.audio_path, req.output_path, req.fps)
        return {"status": "success", "output_path": req.output_path, "timings": timings}
    except Exception as e:
        raise HTTPException(500, str(e))


class BatchReq(BaseModel):
    avatar_name: str
    audio_paths: List[str]
    output_dir: str
    fps: int = 25


@app.post("/generate/batch")
async def batch(req: BatchReq):
    for p in req.audio_paths:
        if not os.path.exists(p):
            raise HTTPException(404, f"Audio not found: {p}")
    try:
        timings = server.generate_batch(req.avatar_name, req.audio_paths, req.output_dir, req.fps)
        return {"status": "success", "output_dir": req.output_dir, "timings": timings}
    except Exception as e:
        raise HTTPException(500, str(e))


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
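
# --- Usage sketch ---
# A minimal client example, assuming the server is running on localhost:8000
# and the `requests` library is installed (it is not a dependency of this
# file). The avatar name and paths below are illustrative placeholders.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/generate/avatar",
#       json={
#           "avatar_name": "my_avatar",
#           "audio_path": "/data/audio/hello.wav",
#           "output_path": "/data/out/hello.mp4",
#           "fps": 25,
#       },
#   )
#   print(resp.json()["timings"])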