"""
MuseTalk HTTP API Server v3 (Fixed)

Optimized with:
1. Sequential face blending (parallel blending added more overhead than it saved)
2. NVENC hardware video encoding
3. Batch audio processing
"""

import copy
import os
import pickle
import shutil
import subprocess
import tempfile
import time
from pathlib import Path
from typing import List

import cv2
import numpy as np
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import WhisperModel

from musetalk.utils.audio_processor import AudioProcessor
from musetalk.utils.blending import get_image
from musetalk.utils.face_parsing import FaceParsing
from musetalk.utils.utils import datagen, load_all_model

class MuseTalkServerV3:
    def __init__(self):
        # Models (populated by load_models)
        self.device = None
        self.vae = None
        self.unet = None
        self.pe = None
        self.whisper = None
        self.audio_processor = None
        self.fp = None
        self.timesteps = None
        self.weight_dtype = None
        self.is_loaded = False

        # Avatar cache: name -> preloaded frames/coords/latents
        self.loaded_avatars = {}
        self.avatar_dir = Path("./avatars")

        # Inference settings
        self.fps = 25
        self.batch_size = 8
        self.use_float16 = True
        self.version = "v15"
        self.extra_margin = 10
        self.parsing_mode = "jaw"
        self.left_cheek_width = 90
        self.right_cheek_width = 90
        self.audio_padding_left = 2
        self.audio_padding_right = 2

        # Video encoding settings
        self.use_nvenc = True
        self.nvenc_preset = "p4"  # NVENC presets range p1 (fastest) .. p7 (best quality)
        self.crf = 23             # quality target (passed as -cq on the NVENC path)

    def load_models(self, gpu_id: int = 0):
        if self.is_loaded:
            print("Models already loaded!")
            return

        print("=" * 50)
        print("Loading MuseTalk models (v3 Optimized)...")
        print("=" * 50)

        start_time = time.time()
        self.device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")

        # Load VAE, UNet and positional encoder
        self.vae, self.unet, self.pe = load_all_model(
            unet_model_path="./models/musetalkV15/unet.pth",
            vae_type="sd-vae",
            unet_config="./models/musetalk/config.json",
            device=self.device,
        )
        self.timesteps = torch.tensor([0], device=self.device)

        # Optionally convert to half precision for faster inference
        if self.use_float16:
            self.pe = self.pe.half()
            self.vae.vae = self.vae.vae.half()
            self.unet.model = self.unet.model.half()
        self.pe = self.pe.to(self.device)
        self.vae.vae = self.vae.vae.to(self.device)
        self.unet.model = self.unet.model.to(self.device)

        # Whisper feature extractor and encoder (frozen, inference only)
        self.audio_processor = AudioProcessor(feature_extractor_path="./models/whisper")
        self.weight_dtype = self.unet.model.dtype
        self.whisper = WhisperModel.from_pretrained("./models/whisper")
        self.whisper = self.whisper.to(device=self.device, dtype=self.weight_dtype).eval()
        self.whisper.requires_grad_(False)

        # Face parser used when blending the generated mouth region
        self.fp = FaceParsing(
            left_cheek_width=self.left_cheek_width,
            right_cheek_width=self.right_cheek_width,
        )

        self.is_loaded = True
        print(f"Models loaded in {time.time() - start_time:.2f}s")

    def load_avatar(self, avatar_name: str) -> dict:
        if avatar_name in self.loaded_avatars:
            return self.loaded_avatars[avatar_name]

        avatar_path = self.avatar_dir / avatar_name
        if not avatar_path.exists():
            raise FileNotFoundError(f"Avatar not found: {avatar_name}")

        avatar_data = {}
        with open(avatar_path / "metadata.pkl", 'rb') as f:
            avatar_data['metadata'] = pickle.load(f)
        with open(avatar_path / "coords.pkl", 'rb') as f:
            avatar_data['coord_list'] = pickle.load(f)
        with open(avatar_path / "frames.pkl", 'rb') as f:
            avatar_data['frame_list'] = pickle.load(f)
        with open(avatar_path / "latents.pkl", 'rb') as f:
            latents_np = pickle.load(f)
            avatar_data['latent_list'] = [torch.from_numpy(l).to(self.device) for l in latents_np]

        self.loaded_avatars[avatar_name] = avatar_data
        return avatar_data
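    # Expected avatar directory layout, inferred from the loader above (the
    # pickles are produced by a separate avatar-preparation step not shown here):
    #
    #   avatars/<avatar_name>/
    #       metadata.pkl   # avatar metadata dict (also returned by GET /avatars)
    #       coords.pkl     # per-frame face bounding boxes (x1, y1, x2, y2)
    #       frames.pkl     # original full frames as numpy arrays
    #       latents.pkl    # per-frame VAE latents as numpy arrays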
    def _encode_video_nvenc(self, frames_dir: str, audio_path: str, output_path: str, fps: int) -> float:
        t0 = time.time()
        temp_vid = output_path.replace('.mp4', '_temp.mp4')

        # Encode the PNG frame sequence to H.264, on the GPU when NVENC is enabled
        if self.use_nvenc:
            encode_cmd = [
                "ffmpeg", "-y", "-v", "warning",
                "-r", str(fps), "-f", "image2", "-i", f"{frames_dir}/%08d.png",
                "-c:v", "h264_nvenc", "-preset", self.nvenc_preset,
                "-cq", str(self.crf), "-pix_fmt", "yuv420p", temp_vid,
            ]
        else:
            encode_cmd = [
                "ffmpeg", "-y", "-v", "warning",
                "-r", str(fps), "-f", "image2", "-i", f"{frames_dir}/%08d.png",
                "-vcodec", "libx264", "-crf", "18", "-pix_fmt", "yuv420p", temp_vid,
            ]
        subprocess.run(encode_cmd, check=True)

        # Mux the audio track in without re-encoding the video
        subprocess.run([
            "ffmpeg", "-y", "-v", "warning",
            "-i", temp_vid, "-i", audio_path,
            "-c:v", "copy", "-c:a", "aac", output_path,
        ], check=True)
        if os.path.exists(temp_vid):
            os.remove(temp_vid)

        return time.time() - t0
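    # Note: h264_nvenc requires an NVIDIA GPU and an ffmpeg build with NVENC
    # support; if either is missing, set use_nvenc = False to fall back to the
    # libx264 software path above.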
    @torch.no_grad()
    def generate_with_avatar(self, avatar_name: str, audio_path: str, output_path: str, fps: int = 25) -> dict:
        if not self.is_loaded:
            raise RuntimeError("Models not loaded!")

        timings = {}
        total_start = time.time()

        t0 = time.time()
        avatar = self.load_avatar(avatar_name)
        timings["avatar_load"] = time.time() - t0

        coord_list = avatar['coord_list']
        frame_list = avatar['frame_list']
        input_latent_list = avatar['latent_list']

        temp_dir = tempfile.mkdtemp()

        try:
            # Step 1: extract Whisper audio features and chunk them per video frame
            t0 = time.time()
            whisper_input_features, librosa_length = self.audio_processor.get_audio_feature(audio_path)
            whisper_chunks = self.audio_processor.get_whisper_chunk(
                whisper_input_features, self.device, self.weight_dtype, self.whisper,
                librosa_length, fps=fps,
                audio_padding_length_left=self.audio_padding_left,
                audio_padding_length_right=self.audio_padding_right,
            )
            timings["whisper_features"] = time.time() - t0

            # Mirror the sequences so the avatar loops back and forth smoothly
            # when the audio is longer than the source video
            frame_list_cycle = frame_list + frame_list[::-1]
            coord_list_cycle = coord_list + coord_list[::-1]
            input_latent_list_cycle = input_latent_list + input_latent_list[::-1]

            # Step 2: batched UNet inference, decoding latents back to images
            t0 = time.time()
            gen = datagen(whisper_chunks=whisper_chunks, vae_encode_latents=input_latent_list_cycle,
                          batch_size=self.batch_size, delay_frame=0, device=self.device)

            res_frame_list = []
            for whisper_batch, latent_batch in gen:
                audio_feature_batch = self.pe(whisper_batch)
                latent_batch = latent_batch.to(dtype=self.unet.model.dtype)
                pred_latents = self.unet.model(latent_batch, self.timesteps,
                                               encoder_hidden_states=audio_feature_batch).sample
                recon = self.vae.decode_latents(pred_latents)
                res_frame_list.extend(recon)
            timings["unet_inference"] = time.time() - t0

            # Step 3: blend each generated mouth region back into its source frame
            t0 = time.time()
            result_img_path = os.path.join(temp_dir, "results")
            os.makedirs(result_img_path, exist_ok=True)

            for i, res_frame in enumerate(res_frame_list):
                bbox = coord_list_cycle[i % len(coord_list_cycle)]
                ori_frame = copy.deepcopy(frame_list_cycle[i % len(frame_list_cycle)])
                x1, y1, x2, y2 = bbox
                y2 = min(y2 + self.extra_margin, ori_frame.shape[0])

                try:
                    res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
                    combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2],
                                              mode=self.parsing_mode, fp=self.fp)
                    cv2.imwrite(f"{result_img_path}/{str(i).zfill(8)}.png", combine_frame)
                except Exception:
                    # Skip frames that fail to resize/blend (e.g. a degenerate bbox)
                    continue
            timings["face_blending"] = time.time() - t0

            # Step 4: encode the frames and mux in the audio
            timings["video_encoding"] = self._encode_video_nvenc(result_img_path, audio_path, output_path, fps)

        finally:
            shutil.rmtree(temp_dir, ignore_errors=True)

        timings["total"] = time.time() - total_start
        timings["frames_generated"] = len(res_frame_list)
        return timings
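    # The returned timings dict has roughly this shape (values illustrative):
    #   {"avatar_load": 0.0, "whisper_features": 0.5, "unet_inference": 1.8,
    #    "face_blending": 1.1, "video_encoding": 0.4,
    #    "total": 3.9, "frames_generated": 98}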
    @torch.no_grad()
    def generate_batch(self, avatar_name: str, audio_paths: List[str], output_dir: str, fps: int = 25) -> dict:
        if not self.is_loaded:
            raise RuntimeError("Models not loaded!")

        batch_timings = {"videos": [], "total": 0}
        total_start = time.time()

        # Load the avatar once and reuse its cycled frames/coords/latents
        # across every audio clip in the batch
        t0 = time.time()
        avatar = self.load_avatar(avatar_name)
        batch_timings["avatar_load"] = time.time() - t0

        coord_list = avatar['coord_list']
        frame_list = avatar['frame_list']
        input_latent_list = avatar['latent_list']
        frame_list_cycle = frame_list + frame_list[::-1]
        coord_list_cycle = coord_list + coord_list[::-1]
        input_latent_list_cycle = input_latent_list + input_latent_list[::-1]

        os.makedirs(output_dir, exist_ok=True)

        for idx, audio_path in enumerate(audio_paths):
            video_start = time.time()
            timings = {}
            output_path = os.path.join(output_dir, f"{Path(audio_path).stem}.mp4")
            temp_dir = tempfile.mkdtemp()

            try:
                # Whisper features for this clip
                t0 = time.time()
                whisper_input_features, librosa_length = self.audio_processor.get_audio_feature(audio_path)
                whisper_chunks = self.audio_processor.get_whisper_chunk(
                    whisper_input_features, self.device, self.weight_dtype, self.whisper,
                    librosa_length, fps=fps,
                    audio_padding_length_left=self.audio_padding_left,
                    audio_padding_length_right=self.audio_padding_right,
                )
                timings["whisper"] = time.time() - t0

                # Batched UNet inference
                t0 = time.time()
                gen = datagen(whisper_chunks=whisper_chunks, vae_encode_latents=input_latent_list_cycle,
                              batch_size=self.batch_size, delay_frame=0, device=self.device)
                res_frame_list = []
                for whisper_batch, latent_batch in gen:
                    audio_feature_batch = self.pe(whisper_batch)
                    latent_batch = latent_batch.to(dtype=self.unet.model.dtype)
                    pred_latents = self.unet.model(latent_batch, self.timesteps,
                                                   encoder_hidden_states=audio_feature_batch).sample
                    res_frame_list.extend(self.vae.decode_latents(pred_latents))
                timings["unet"] = time.time() - t0

                # Face blending
                t0 = time.time()
                result_img_path = os.path.join(temp_dir, "results")
                os.makedirs(result_img_path, exist_ok=True)
                for i, res_frame in enumerate(res_frame_list):
                    bbox = coord_list_cycle[i % len(coord_list_cycle)]
                    ori_frame = copy.deepcopy(frame_list_cycle[i % len(frame_list_cycle)])
                    x1, y1, x2, y2 = bbox
                    y2 = min(y2 + self.extra_margin, ori_frame.shape[0])
                    try:
                        res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
                        combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2],
                                                  mode=self.parsing_mode, fp=self.fp)
                        cv2.imwrite(f"{result_img_path}/{str(i).zfill(8)}.png", combine_frame)
                    except Exception:
                        continue
                timings["blending"] = time.time() - t0

                # Encode and mux
                timings["encoding"] = self._encode_video_nvenc(result_img_path, audio_path, output_path, fps)

            finally:
                shutil.rmtree(temp_dir, ignore_errors=True)

            timings["total"] = time.time() - video_start
            timings["frames"] = len(res_frame_list)
            timings["output"] = output_path
            batch_timings["videos"].append(timings)
            print(f"  [{idx+1}/{len(audio_paths)}] {Path(audio_path).stem}: {timings['total']:.2f}s")

        batch_timings["total"] = time.time() - total_start
        batch_timings["num_videos"] = len(audio_paths)
        batch_timings["avg_per_video"] = batch_timings["total"] / len(audio_paths) if audio_paths else 0
        return batch_timings

server = MuseTalkServerV3()
app = FastAPI(title="MuseTalk API v3", version="3.0.0")
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True,
                   allow_methods=["*"], allow_headers=["*"])

@app.on_event("startup")
async def startup():
    server.load_models()


@app.get("/health")
async def health():
    return {"status": "ok" if server.is_loaded else "loading", "device": str(server.device),
            "avatars": list(server.loaded_avatars.keys()), "nvenc": server.use_nvenc}


@app.get("/avatars")
async def list_avatars():
    avatars = []
    # Guard against a missing avatar directory so the endpoint returns an
    # empty list instead of a 500
    if server.avatar_dir.exists():
        for p in server.avatar_dir.iterdir():
            if p.is_dir() and (p / "metadata.pkl").exists():
                with open(p / "metadata.pkl", 'rb') as f:
                    avatars.append(pickle.load(f))
    return {"avatars": avatars}
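# Example (a sketch, assuming the server's default host/port):
#   import requests
#   requests.get("http://localhost:8000/avatars").json()
#   # -> {"avatars": [<metadata dict per prepared avatar>]}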
class GenReq(BaseModel):
    avatar_name: str
    audio_path: str
    output_path: str
    fps: int = 25


@app.post("/generate/avatar")
async def generate(req: GenReq):
    if not os.path.exists(req.audio_path):
        raise HTTPException(404, f"Audio not found: {req.audio_path}")
    try:
        timings = server.generate_with_avatar(req.avatar_name, req.audio_path, req.output_path, req.fps)
        return {"status": "success", "output_path": req.output_path, "timings": timings}
    except Exception as e:
        raise HTTPException(500, str(e))
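# Example request (a sketch; the audio/output paths are illustrative and refer
# to the server's filesystem, not the client's):
#   import requests
#   requests.post("http://localhost:8000/generate/avatar", json={
#       "avatar_name": "demo",
#       "audio_path": "/data/audio/hello.wav",
#       "output_path": "/data/out/hello.mp4",
#   }).json()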
class BatchReq(BaseModel):
    avatar_name: str
    audio_paths: List[str]
    output_dir: str
    fps: int = 25


@app.post("/generate/batch")
async def batch(req: BatchReq):
    # Validate every audio path up front so a bad entry fails fast
    for p in req.audio_paths:
        if not os.path.exists(p):
            raise HTTPException(404, f"Audio not found: {p}")
    try:
        timings = server.generate_batch(req.avatar_name, req.audio_paths, req.output_dir, req.fps)
        return {"status": "success", "output_dir": req.output_dir, "timings": timings}
    except Exception as e:
        raise HTTPException(500, str(e))
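# Example batch request (a sketch; paths illustrative, resolved on the server):
#   import requests
#   requests.post("http://localhost:8000/generate/batch", json={
#       "avatar_name": "demo",
#       "audio_paths": ["/data/audio/a.wav", "/data/audio/b.wav"],
#       "output_dir": "/data/out",
#   }).json()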
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
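# Quick smoke test once the server is running (assumes the default port above):
#   import requests
#   print(requests.get("http://localhost:8000/health").json())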
|
|
|