# Origin: Diffusers / MuseTalk1.5 / server/fast_engine.py (27.9 kB)
# Last commit 32bba92 by Marcos:
#   "Add H.264 WebSocket streaming and React.js web interface"
"""
MuseTalk Fast Engine - Optimized for real-time inference
Implements: Pre-loaded models, Avatar caching, Parallel chunk processing
MXFP4/Blackwell Optimizations (RTX 5090):
- torch.compile with max-autotune mode
- FP16 precision for all models
- TF32 enabled for matmuls
- Optional CUDA Graphs for fixed batch sizes
"""
import os
import sys
import torch
import numpy as np
import cv2
import pickle
import glob
import copy
import time
import uuid
import threading
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from queue import Queue
from typing import List, Tuple, Optional, Dict
import subprocess
# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))
# Suppress warnings
import warnings
warnings.filterwarnings('ignore')
# Check for Blackwell GPU optimizations
def _check_blackwell_gpu() -> bool:
    """Return True when CUDA device 0 reports a compute-capability major of 12 or more.

    Returns False when CUDA is not available at all.
    """
    if torch.cuda.is_available():
        major, _minor = torch.cuda.get_device_capability(0)
        return major >= 12
    return False


# Evaluated once at import time; read by the engine when loading models.
IS_BLACKWELL = _check_blackwell_gpu()
class FastMuseTalkEngine:
    """
    Process-wide singleton that keeps MuseTalk ready for low-latency lip-sync.

    Expensive work is done once, up front:
    1. Models (VAE, UNet, PE, Whisper) are loaded and kept in memory.
    2. The talking-avatar video is pre-processed (frames, face boxes, VAE
       latents) and the frames/boxes are cached on disk.
    3. Audio is then turned into frames in small inference batches.

    Call ``initialize()`` once at startup before any ``generate_*`` method.
    """

    _instance = None
    _lock = threading.Lock()

    # Bounding box used when no face was detected in a frame.
    _COORD_PLACEHOLDER = (0.0, 0.0, 0.0, 0.0)

    def __new__(cls):
        # Check, lock, check again: the lock is only taken while the
        # instance does not exist yet.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        """Set up paths and empty caches; no model or video is loaded here."""
        # __init__ runs on every FastMuseTalkEngine() call, but the singleton
        # must only be configured once.
        if self._initialized:
            return

        print("=" * 50)
        print("Initializing FastMuseTalkEngine...")
        print("=" * 50)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = torch.float16 if torch.cuda.is_available() else torch.float32

        # Paths
        self.base_dir = Path(__file__).parent.parent
        self.server_dir = Path(__file__).parent

        # Separate source videos for talking and idle
        self.avatar_video = self.server_dir / "avatar_videos" / "falando.mp4"  # For lip-sync
        self.idle_video = self.server_dir / "avatar_videos" / "idle.mp4"  # For idle animation

        # Cache directory inside server folder
        self.cache_dir = self.server_dir / "avatar_cache"
        self.results_dir = self.base_dir / "results" / "server"
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.results_dir.mkdir(parents=True, exist_ok=True)

        # Models (loaded once by load_models)
        self.vae = None
        self.unet = None
        self.pe = None
        self.audio_processor = None
        self.timesteps = None
        # Set for real in load_models(); defined here so the inference paths
        # never hit an AttributeError when called before models are loaded.
        self._blackwell_optimized = False

        # Avatar cache for lip-sync (falando.mp4)
        self.avatar_frame_paths = []  # paths to frame images
        self.avatar_frames = []  # actual frame numpy arrays
        self.avatar_latents = []  # pre-computed VAE latents
        self.avatar_coords = []  # bounding box tuples (x1,y1,x2,y2)
        self.avatar_fps = 30
        self.avatar_loaded = False

        # Idle animation cache (idle.mp4)
        self.idle_frames = []  # frames for idle animation
        self.idle_fps = 30
        self.idle_loaded = False

        # Thread pool reserved for parallel processing (not used by the
        # methods of this class as written).
        self.executor = ThreadPoolExecutor(max_workers=4)

        self._initialized = True

    # ------------------------------------------------------------------
    # Loading / caching
    # ------------------------------------------------------------------

    def load_models(self, use_blackwell_optimizations: bool = True):
        """Load VAE, UNet, PE and the Whisper audio processor (idempotent).

        Args:
            use_blackwell_optimizations: forwarded to ``load_all_model``; the
                FP16 tensor conversion during inference is only enabled when
                this is true AND the module-level ``IS_BLACKWELL`` is true.
        """
        if self.vae is not None:
            print("Models already loaded")
            return

        print("\n[1/4] Loading VAE, UNet, PE models...")
        start = time.time()
        from musetalk.utils.utils import load_all_model

        # Blackwell optimizations are applied inside load_all_model
        self.vae, self.unet, self.pe = load_all_model(
            use_blackwell_optimizations=use_blackwell_optimizations
        )
        # Track optimization state
        self._blackwell_optimized = use_blackwell_optimizations and IS_BLACKWELL
        print(f"  Models loaded in {time.time()-start:.2f}s")

        print("\n[2/4] Loading Whisper audio processor...")
        start = time.time()
        from musetalk.whisper.audio2feature import Audio2Feature

        whisper_path = self.base_dir / "models" / "whisper" / "tiny.pt"
        self.audio_processor = Audio2Feature(model_path=str(whisper_path))
        print(f"  Whisper loaded in {time.time()-start:.2f}s")

        self.timesteps = torch.tensor([0], device=self.device)
        print("\n✓ All models loaded and ready!")

    @staticmethod
    def _load_cache(cache_file: Path, required_keys: Tuple[str, ...]) -> Optional[dict]:
        """Return the pickled cache dict, or None if it is unreadable or incomplete."""
        # NOTE(security): pickle.load executes arbitrary code embedded in the
        # file. Acceptable only because the cache directory is written by this
        # engine itself; never point it at untrusted files.
        try:
            with open(cache_file, 'rb') as f:
                cache = pickle.load(f)
        except (OSError, EOFError, pickle.UnpicklingError,
                AttributeError, ImportError, IndexError) as exc:
            # e.g. a write that was interrupted half-way on a previous run
            print(f"  Cache {cache_file.name} is unreadable ({exc!r}); rebuilding")
            return None
        if not isinstance(cache, dict) or any(key not in cache for key in required_keys):
            print(f"  Cache {cache_file.name} is incomplete; rebuilding")
            return None
        return cache

    @staticmethod
    def _save_cache(cache_file: Path, cache: dict) -> None:
        """Pickle *cache* to *cache_file* via a temp file + rename."""
        # Writing to a sibling file first means a crash mid-write cannot
        # leave a truncated pickle under the real cache name.
        tmp_file = cache_file.with_name(cache_file.name + ".tmp")
        with open(tmp_file, 'wb') as f:
            pickle.dump(cache, f)
        os.replace(tmp_file, cache_file)

    def preprocess_avatar(self, force=False):
        """Pre-process the talking avatar video (falando.mp4) and cache it.

        Populates ``avatar_frame_paths``, ``avatar_frames``, ``avatar_coords``,
        ``avatar_fps`` and ``avatar_latents``. Frames, boxes and fps are cached
        on disk; latents are recomputed on every load.

        Args:
            force: ignore both the in-memory state and the on-disk cache, and
                re-extract frames even if PNGs from an earlier run exist.

        Raises:
            RuntimeError: if models have not been loaded yet (latents need the VAE).
        """
        if self.avatar_loaded and not force:
            print("Avatar already preprocessed")
            return

        cache_file = self.cache_dir / "falando_cache.pkl"

        # Try to load from cache
        if cache_file.exists() and not force:
            print("\n[3/4] Loading avatar (falando) from cache...")
            start = time.time()
            cache = self._load_cache(cache_file, ('frame_paths', 'frames', 'coords', 'fps'))
            if cache is not None:
                self.avatar_frame_paths = cache['frame_paths']
                self.avatar_frames = cache['frames']  # actual numpy arrays
                self.avatar_coords = cache['coords']  # list of (x1,y1,x2,y2) tuples
                self.avatar_fps = cache['fps']
                # Latents are not stored in the cache; recompute them
                self._compute_latents()
                self.avatar_loaded = True
                print(f"  Avatar (falando) loaded from cache in {time.time()-start:.2f}s")
                return

        print("\n[3/4] Pre-processing avatar video - falando.mp4 (first time only)...")
        start = time.time()

        # Extract frames
        print("  Extracting frames from falando.mp4...")
        import imageio

        frames_dir = self.cache_dir / "falando_frames"
        frames_dir.mkdir(exist_ok=True)

        frame_paths = []
        reader = imageio.get_reader(str(self.avatar_video))
        try:
            self.avatar_fps = reader.get_meta_data().get('fps', 30)
            for i, frame in enumerate(reader):
                frame_path = frames_dir / f"{i:08d}.png"
                # With force=True, overwrite PNGs left over from an earlier
                # (possibly different) source video.
                if force or not frame_path.exists():
                    imageio.imwrite(str(frame_path), frame)
                frame_paths.append(str(frame_path))
        finally:
            reader.close()

        self.avatar_frame_paths = frame_paths
        print(f"  Extracted {len(frame_paths)} frames")

        # Get landmarks and bounding boxes
        # Returns: coords_list (list of (x1,y1,x2,y2) tuples), frames (list of numpy arrays)
        print("  Computing face landmarks...")
        from musetalk.utils.preprocessing import get_landmark_and_bbox

        self.avatar_coords, self.avatar_frames = get_landmark_and_bbox(frame_paths, 0)

        # Save to cache
        self._save_cache(cache_file, {
            'frame_paths': self.avatar_frame_paths,
            'frames': self.avatar_frames,
            'coords': self.avatar_coords,
            'fps': self.avatar_fps,
        })

        # Compute latents
        self._compute_latents()
        self.avatar_loaded = True
        print(f"  Avatar (falando) preprocessed in {time.time()-start:.2f}s")

    def load_idle_video(self, force=False):
        """Load idle video (idle.mp4) frames, as RGB arrays, for idle animation.

        If the video cannot be opened or contains no frames, a warning is
        printed, nothing is cached and the engine state is left unchanged, so
        ``get_idle_frames`` keeps returning its empty default.

        Args:
            force: ignore the in-memory state and the on-disk cache.
        """
        if self.idle_loaded and not force:
            print("Idle video already loaded")
            return

        cache_file = self.cache_dir / "idle_cache.pkl"

        # Try to load from cache
        if cache_file.exists() and not force:
            print("\n[5/5] Loading idle video from cache...")
            start = time.time()
            cache = self._load_cache(cache_file, ('frames', 'fps'))
            # An empty frame list means an earlier run cached a failed read;
            # treat it as a miss and try the video again.
            if cache is not None and cache['frames']:
                self.idle_frames = cache['frames']
                self.idle_fps = cache['fps'] or 30.0
                self.idle_loaded = True
                print(f"  Idle video loaded from cache: {len(self.idle_frames)} frames in {time.time()-start:.2f}s")
                return

        print("\n[5/5] Loading idle video - idle.mp4 (first time only)...")
        start = time.time()

        # Extract all frames from idle.mp4
        frames = []
        cap = cv2.VideoCapture(str(self.idle_video))
        try:
            if not cap.isOpened():
                print(f"  WARNING: could not open idle video {self.idle_video}; idle animation disabled")
                return
            fps = cap.get(cv2.CAP_PROP_FPS)
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                # Convert BGR to RGB
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        finally:
            cap.release()

        if not frames:
            print(f"  WARNING: idle video {self.idle_video} has no readable frames; idle animation disabled")
            return

        self.idle_frames = frames
        # The capture reports 0 when the container has no frame-rate info
        self.idle_fps = fps if fps and fps > 0 else 30.0

        # Save to cache
        self._save_cache(cache_file, {
            'frames': self.idle_frames,
            'fps': self.idle_fps,
        })

        self.idle_loaded = True
        print(f"  Idle video loaded: {len(self.idle_frames)} frames @ {self.idle_fps} fps in {time.time()-start:.2f}s")

    @classmethod
    def _clamped_box(cls, bbox, frame_shape) -> Optional[Tuple[int, int, int, int]]:
        """Return bbox as ints clipped to the frame, or None if unusable.

        Unusable means: None, the no-face placeholder, or an empty box after
        clipping to ``frame_shape`` (height, width, ...).
        """
        # `is None` first: comparing None (or an array) with == is the fragile part
        if bbox is None or tuple(bbox) == cls._COORD_PLACEHOLDER:
            return None
        x1, y1, x2, y2 = (int(c) for c in bbox)
        frame_h, frame_w = frame_shape[:2]
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(frame_w, x2), min(frame_h, y2)
        if x2 <= x1 or y2 <= y1:
            return None
        return x1, y1, x2, y2

    def _compute_latents(self):
        """Pre-compute VAE latents for all avatar frames.

        Frames without a usable face box reuse the previous frame's latent;
        frames before the first usable one reuse the first valid latent.

        Raises:
            RuntimeError: if the VAE has not been loaded.
        """
        if self.vae is None:
            raise RuntimeError("Models must be loaded (load_models) before computing avatar latents")

        print("\n[4/4] Pre-computing VAE latents...")
        start = time.time()

        latents = []
        last_valid = None
        for bbox, frame in zip(self.avatar_coords, self.avatar_frames):
            box = self._clamped_box(bbox, frame.shape)
            if box is None:
                # Forward-fill; stays None until the first valid frame
                latents.append(last_valid)
                continue
            x1, y1, x2, y2 = box
            crop = cv2.resize(frame[y1:y2, x1:x2], (256, 256), interpolation=cv2.INTER_LANCZOS4)
            last_valid = self.vae.get_latents_for_unet(crop)
            latents.append(last_valid)

        # Back-fill the leading gap with the first valid latent
        first_valid = next((lat for lat in latents if lat is not None), None)
        if first_valid is not None:
            latents = [lat if lat is not None else first_valid for lat in latents]
        elif latents:
            print("  WARNING: no face detected in any avatar frame; lip-sync generation will be unavailable")

        self.avatar_latents = latents
        print(f"  Computed {len(self.avatar_latents)} latents in {time.time()-start:.2f}s")

    def initialize(self):
        """Full initialization - call once at startup"""
        total_start = time.time()
        print("\n" + "=" * 50)
        print("FAST ENGINE INITIALIZATION")
        print("=" * 50)

        self.load_models()
        self.preprocess_avatar()  # Load falando.mp4 for lip-sync
        self.load_idle_video()  # Load idle.mp4 for idle animation

        print("\n" + "=" * 50)
        print(f"INITIALIZATION COMPLETE in {time.time()-total_start:.2f}s")
        print("=" * 50 + "\n")

    # ------------------------------------------------------------------
    # Shared inference helpers
    # ------------------------------------------------------------------

    def _require_ready(self) -> None:
        """Raise RuntimeError unless models and avatar latents are available."""
        if self.vae is None or self.unet is None or self.audio_processor is None:
            raise RuntimeError("Models are not loaded; call load_models() or initialize() first")
        if not self.avatar_latents or not self.avatar_frames or self.avatar_latents[0] is None:
            raise RuntimeError("Avatar is not preprocessed; call preprocess_avatar() or initialize() first")

    def _extract_whisper_chunks(self, audio_path: str):
        """Run the audio processor on *audio_path* and return per-frame feature chunks."""
        whisper_feature = self.audio_processor.audio2feat(audio_path)
        return self.audio_processor.feature2chunks(
            feature_array=whisper_feature,
            fps=self.avatar_fps
        )

    def _infer_batch(self, whisper_batch, first_frame: int):
        """Generate face images for one batch of audio feature chunks.

        Args:
            whisper_batch: audio feature chunks, one per output frame.
            first_frame: avatar frame index paired with the first chunk; the
                avatar latents are cycled when the index runs past the end.

        Returns:
            Whatever ``vae.decode_latents`` returns for the batch (iterated
            frame by frame by the callers).
        """
        cycle = len(self.avatar_latents)
        latent_batch = [
            self.avatar_latents[(first_frame + k) % cycle]
            for k in range(len(whisper_batch))
        ]

        audio_tensor = torch.stack([
            torch.FloatTensor(w).to(self.device)
            for w in whisper_batch
        ])
        latent_tensor = torch.cat(latent_batch, dim=0).to(self.device)

        # Convert to FP16 if Blackwell optimizations enabled
        if self._blackwell_optimized:
            audio_tensor = audio_tensor.half()
            latent_tensor = latent_tensor.half()

        with torch.no_grad():
            # The network is invoked through `self.unet.model` on every path.
            pred = self.unet.model(
                latent_tensor,
                self.timesteps,
                encoder_hidden_states=audio_tensor
            ).sample
            return self.vae.decode_latents(pred)

    def _compose_frame(self, gen_frame: np.ndarray, avatar_idx: int,
                       scale_factor: float = 1.0) -> np.ndarray:
        """Paste a generated face into a copy of avatar frame *avatar_idx*.

        The face is pasted directly (no blending) into the frame's face box.
        When the frame has no usable face box the untouched copy is returned.

        Args:
            gen_frame: generated face image.
            avatar_idx: index into ``avatar_frames`` / ``avatar_coords``.
            scale_factor: values other than 1.0 first resize the face to
                ``box * scale_factor`` (at least 32 px per side) and then to
                the box size, trading quality for speed.
        """
        composed = self.avatar_frames[avatar_idx].copy()
        box = self._clamped_box(self.avatar_coords[avatar_idx], composed.shape)
        if box is None:
            return composed

        x1, y1, x2, y2 = box
        box_w, box_h = x2 - x1, y2 - y1
        face = gen_frame
        if scale_factor != 1.0:
            target_w = max(32, int(box_w * scale_factor))
            target_h = max(32, int(box_h * scale_factor))
            face = cv2.resize(face, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
        face = cv2.resize(face, (box_w, box_h), interpolation=cv2.INTER_LINEAR)
        composed[y1:y2, x1:x2] = face
        return composed

    # ------------------------------------------------------------------
    # Public generation API
    # ------------------------------------------------------------------

    def process_audio_chunk(self, audio_path: str, chunk_idx: int,
                            start_frame: int, end_frame: int) -> List[np.ndarray]:
        """Process a single audio chunk and generate (un-composed) face frames.

        Args:
            audio_path: audio file holding just this chunk.
            chunk_idx: unused; kept for interface compatibility.
            start_frame: avatar frame index paired with the first output frame.
            end_frame: exclusive upper bound; at most ``end_frame - start_frame``
                frames are produced.

        Raises:
            RuntimeError: if models or avatar are not loaded.
        """
        self._require_ready()
        whisper_chunks = self._extract_whisper_chunks(audio_path)
        num_frames = min(len(whisper_chunks), end_frame - start_frame)

        generated_frames = []
        batch_size = 8
        for i in range(0, num_frames, batch_size):
            whisper_batch = whisper_chunks[i:min(i + batch_size, num_frames)]
            generated_frames.extend(self._infer_batch(whisper_batch, start_frame + i))
        return generated_frames

    def compose_video_chunk(self, generated_frames: List[np.ndarray],
                            start_frame: int) -> List[np.ndarray]:
        """Compose generated frames with the original avatar frames.

        Frame ``i`` is pasted into avatar frame ``(start_frame + i)`` (cycled).
        """
        cycle = len(self.avatar_frames)
        return [
            self._compose_frame(gen_frame, (start_frame + i) % cycle)
            for i, gen_frame in enumerate(generated_frames)
        ]

    def generate_video_fast(self, audio_path: str, output_path: str = None,
                            resolution: int = 256, batch_size: int = 8,
                            callback=None, start_frame: int = 0) -> str:
        """
        Generate a lip-sync video (with audio) for *audio_path*.

        Args:
            audio_path: input audio file.
            output_path: destination .mp4; a random name under ``results_dir``
                is used when omitted.
            resolution: output face size (128, 192, 256, 320) - affects quality/speed
            batch_size: frames per batch (4, 8, 16) - affects VRAM usage/speed
            callback: optional ``callback(progress, message)`` for progress updates.
            start_frame: avatar frame index paired with the first output
                frame, so consecutive clips can continue the avatar motion.

        Returns:
            The path of the written video.

        Raises:
            RuntimeError: if the engine is not initialized, the video writer
                cannot be opened, or ffmpeg fails to mux the audio.
            ValueError: if batch_size < 1 or the audio yields no frames.
        """
        self._require_ready()
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        if output_path is None:
            output_path = str(self.results_dir / f"fast_{uuid.uuid4().hex[:8]}.mp4")

        print(f"\n[VIDEO GENERATION] resolution={resolution}, batch_size={batch_size}")
        print(f"video in {self.avatar_fps} FPS, audio idx in 50FPS")
        total_start = time.time()

        if callback:
            callback(0, "Extraindo features de áudio...")

        # Get total audio length and features
        audio_start = time.time()
        whisper_chunks = self._extract_whisper_chunks(audio_path)
        total_frames = len(whisper_chunks)
        if callback:
            callback(10, f"Processando {total_frames} frames...")
        print(f"Audio processing: {time.time()-audio_start:.2f}s, {total_frames} frames")

        if total_frames == 0:
            raise ValueError(f"No frames could be generated from audio: {audio_path}")

        # Chunks only pace the progress callback; batches are the unit of inference
        chunk_size = 50  # frames per chunk
        num_chunks = (total_frames + chunk_size - 1) // chunk_size
        generated = []
        gen_start = time.time()
        for chunk_idx in range(num_chunks):
            chunk_start = chunk_idx * chunk_size
            chunk_whisper = whisper_chunks[chunk_start:chunk_start + chunk_size]
            for i in range(0, len(chunk_whisper), batch_size):
                whisper_batch = chunk_whisper[i:i + batch_size]
                generated.extend(
                    self._infer_batch(whisper_batch, start_frame + chunk_start + i)
                )
            progress = int(10 + (chunk_idx + 1) / num_chunks * 60)
            if callback:
                callback(progress, f"Gerado chunk {chunk_idx+1}/{num_chunks}")
        print(f"Generation: {time.time()-gen_start:.2f}s")

        if callback:
            callback(70, "Compondo frames...")

        # Compose all frames
        compose_start = time.time()
        # Scale factor based on resolution (256 is default/full quality)
        scale_factor = resolution / 256.0
        cycle = len(self.avatar_frames)
        final_frames = [
            self._compose_frame(gen_frame, (start_frame + i) % cycle, scale_factor)
            for i, gen_frame in enumerate(generated)
        ]
        print(f"Composition: {time.time()-compose_start:.2f}s")

        if callback:
            callback(85, "Escrevendo vídeo...")

        # Derive the temp name from the path components: a plain
        # str.replace('.mp4', ...) returns the output path itself when the
        # suffix is missing, and the cleanup below would then delete the result.
        out_path = Path(output_path)
        temp_video = str(out_path.with_name(f"{out_path.stem}_temp{out_path.suffix or '.mp4'}"))

        try:
            # Write silent video
            write_start = time.time()
            h, w = final_frames[0].shape[:2]
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(temp_video, fourcc, self.avatar_fps, (w, h))
            try:
                if not out.isOpened():
                    raise RuntimeError(f"Could not open video writer for {temp_video}")
                for frame in final_frames:
                    out.write(frame)
            finally:
                out.release()
            print(f"Video write: {time.time()-write_start:.2f}s")

            if callback:
                callback(95, "Adicionando áudio...")

            # Add audio ("-v error" so a failure leaves something to report)
            cmd = [
                "ffmpeg", "-y", "-v", "error",
                "-i", temp_video,
                "-i", audio_path,
                "-c:v", "libx264", "-preset", "fast",
                "-c:a", "aac",
                "-shortest",
                output_path
            ]
            result = subprocess.run(cmd, capture_output=True)
            if result.returncode != 0:
                detail = result.stderr.decode(errors="replace").strip()
                raise RuntimeError(f"ffmpeg failed to mux audio (exit {result.returncode}): {detail}")
        finally:
            try:
                os.remove(temp_video)
            except OSError:
                pass  # temp file was never created or is already gone

        total_time = time.time() - total_start
        print(f"\nTotal generation time: {total_time:.2f}s for {total_frames} frames")
        print(f"Speed: {total_frames/total_time:.1f} fps")

        if callback:
            callback(100, "Concluído!")
        return output_path

    def generate_frames_streaming(self, audio_path: str, resolution: int = 256, batch_size: int = 8):
        """
        Generator that yields composed frames one by one as they're generated.
        Intended for WebSocket streaming.

        Yields:
            {"type": "info", "total_frames": N, "fps": <avatar fps>}  (first)
            {"type": "frame", "frame": numpy_array, "index": N}       (per frame)

        Raises (on first iteration):
            RuntimeError: if the engine is not initialized.
            ValueError: if batch_size < 1.
        """
        self._require_ready()
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        print(f"\n[STREAMING] Starting frame generation: resolution={resolution}, batch_size={batch_size}")

        # Get audio features
        whisper_chunks = self._extract_whisper_chunks(audio_path)
        total_frames = len(whisper_chunks)
        print(f"[STREAMING] Total frames to generate: {total_frames}")

        # Send info first
        yield {"type": "info", "total_frames": total_frames, "fps": self.avatar_fps}

        scale_factor = resolution / 256.0
        cycle = len(self.avatar_frames)
        frame_index = 0

        # Infer in batches but yield frames one by one
        for batch_start in range(0, total_frames, batch_size):
            whisper_batch = whisper_chunks[batch_start:batch_start + batch_size]
            for gen_frame in self._infer_batch(whisper_batch, batch_start):
                composed = self._compose_frame(gen_frame, frame_index % cycle, scale_factor)
                yield {"type": "frame", "frame": composed, "index": frame_index}
                frame_index += 1

        print(f"[STREAMING] Completed: {frame_index} frames")

    def get_idle_frames(self) -> Tuple[List[np.ndarray], float]:
        """
        Get all frames from idle.mp4 for idle animation.

        Returns: (frames_list, fps); ``([], 30.0)`` when no idle video is loaded.
        """
        if not self.idle_loaded or not self.idle_frames:
            return [], 30.0
        return self.idle_frames, self.idle_fps

    @staticmethod
    def _probe_audio_duration(audio_path: str, default: float = 10.0) -> float:
        """Return the audio duration in seconds, or *default* if it cannot be determined.

        WAV headers are read directly; anything else is handed to ffprobe.
        """
        import wave

        try:
            with wave.open(audio_path, 'rb') as wav:
                rate = wav.getframerate()
                if rate > 0:
                    return wav.getnframes() / rate
        except (wave.Error, EOFError, OSError):
            pass  # not a readable WAV; try ffprobe below

        try:
            probe = subprocess.run(
                ["ffprobe", "-v", "error",
                 "-show_entries", "format=duration",
                 "-of", "default=noprint_wrappers=1:nokey=1",
                 audio_path],
                capture_output=True, text=True,
            )
            return float(probe.stdout.strip())
        except (OSError, ValueError):
            # ffprobe missing, or it printed nothing parseable
            return default

    def generate_video_streaming(self, audio_path: str,
                                 chunk_callback=None) -> List[str]:
        """
        Generate video in streaming chunks of 3 seconds of audio each.

        chunk_callback(chunk_idx, chunk_path) is called as each chunk is ready.
        Each chunk continues the avatar motion where the previous one ended.

        Returns list of chunk video paths.

        Raises:
            RuntimeError: if ffmpeg cannot extract a chunk, or for any reason
                ``generate_video_fast`` raises it.
        """
        chunk_duration = 3  # seconds
        duration = self._probe_audio_duration(audio_path)
        num_chunks = max(1, int(np.ceil(duration / chunk_duration)))

        chunk_paths = []
        for i in range(num_chunks):
            start_time = i * chunk_duration
            # Unique name so concurrent requests cannot overwrite each other
            chunk_audio = str(self.results_dir / f"chunk_{i}_{uuid.uuid4().hex[:6]}_audio.wav")
            chunk_video = str(self.results_dir / f"chunk_{i}_{uuid.uuid4().hex[:6]}.mp4")
            cmd = [
                "ffmpeg", "-y", "-v", "error",
                "-i", audio_path,
                "-ss", str(start_time),
                "-t", str(chunk_duration),
                "-ar", "16000", "-ac", "1",
                chunk_audio
            ]
            try:
                extract = subprocess.run(cmd, capture_output=True)
                if extract.returncode != 0:
                    detail = extract.stderr.decode(errors="replace").strip()
                    raise RuntimeError(f"ffmpeg failed to extract audio chunk {i}: {detail}")

                # Offset into the avatar cycle so chunk i picks up where
                # chunk i-1 stopped instead of jumping back to frame 0.
                start_frame = int(start_time * self.avatar_fps)
                self.generate_video_fast(chunk_audio, chunk_video, start_frame=start_frame)
            finally:
                # Cleanup audio chunk
                try:
                    os.remove(chunk_audio)
                except OSError:
                    pass  # extraction never produced the file

            chunk_paths.append(chunk_video)
            if chunk_callback:
                chunk_callback(i, chunk_video)

        return chunk_paths
# Module-level handle to the shared engine; populated lazily by get_engine()
_engine: Optional[FastMuseTalkEngine] = None


def get_engine() -> FastMuseTalkEngine:
    """Return the shared engine, constructing it on the first call."""
    global _engine
    if _engine is not None:
        return _engine
    _engine = FastMuseTalkEngine()
    return _engine
def initialize_engine():
    """Fetch the shared engine, run its full initialization and return it.

    Meant to be called once at server startup.
    """
    ready_engine = get_engine()
    ready_engine.initialize()
    return ready_engine
if __name__ == "__main__":
    # Manual smoke test: initialize the engine, then render the sample
    # audio clip if it is present on this machine.
    sample_audio = "/workspace/MuseTalk1.5/data/audio/teacher_intro.wav"
    engine = initialize_engine()

    if os.path.exists(sample_audio):
        print("\nTesting video generation...")
        result_path = engine.generate_video_fast(sample_audio)
        print(f"Output: {result_path}")