""" MuseTalk Fast Engine - Optimized for real-time inference Implements: Pre-loaded models, Avatar caching, Parallel chunk processing MXFP4/Blackwell Optimizations (RTX 5090): - torch.compile with max-autotune mode - FP16 precision for all models - TF32 enabled for matmuls - Optional CUDA Graphs for fixed batch sizes """ import os import sys import torch import numpy as np import cv2 import pickle import glob import copy import time import uuid import threading from pathlib import Path from concurrent.futures import ThreadPoolExecutor, as_completed from queue import Queue from typing import List, Tuple, Optional, Dict import subprocess # Add parent directory to path sys.path.insert(0, str(Path(__file__).parent.parent)) # Suppress warnings import warnings warnings.filterwarnings('ignore') # Check for Blackwell GPU optimizations def _check_blackwell_gpu() -> bool: """Check if running on Blackwell GPU (compute capability 12.x)""" if not torch.cuda.is_available(): return False cc = torch.cuda.get_device_capability(0) return cc[0] >= 12 IS_BLACKWELL = _check_blackwell_gpu() class FastMuseTalkEngine: """ Optimized MuseTalk engine with: 1. Pre-loaded models in memory 2. Pre-processed avatar cache 3. Parallel chunk processing """ _instance = None _lock = threading.Lock() def __new__(cls): if cls._instance is None: with cls._lock: if cls._instance is None: cls._instance = super().__new__(cls) cls._instance._initialized = False return cls._instance def __init__(self): if self._initialized: return print("=" * 50) print("Initializing FastMuseTalkEngine...") print("=" * 50) self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.dtype = torch.float16 if torch.cuda.is_available() else torch.float32 # Paths self.base_dir = Path(__file__).parent.parent self.server_dir = Path(__file__).parent # New video paths - separate videos for talking and idle self.avatar_video = self.server_dir / "avatar_videos" / "falando.mp4" # For lip-sync self.idle_video = self.server_dir / "avatar_videos" / "idle.mp4" # For idle animation # Cache directory inside server folder self.cache_dir = self.server_dir / "avatar_cache" self.results_dir = self.base_dir / "results" / "server" self.cache_dir.mkdir(parents=True, exist_ok=True) self.results_dir.mkdir(parents=True, exist_ok=True) # Models (loaded once) self.vae = None self.unet = None self.pe = None self.audio_processor = None self.timesteps = None # Avatar cache for lip-sync (falando.mp4) self.avatar_frame_paths = [] # paths to frame images self.avatar_frames = [] # actual frame numpy arrays self.avatar_latents = [] # pre-computed VAE latents self.avatar_coords = [] # bounding box tuples (x1,y1,x2,y2) self.avatar_fps = 30 self.avatar_loaded = False # Idle animation cache (idle.mp4) self.idle_frames = [] # frames for idle animation self.idle_fps = 30 self.idle_loaded = False # Thread pool for parallel processing self.executor = ThreadPoolExecutor(max_workers=4) self._initialized = True def load_models(self, use_blackwell_optimizations: bool = True): """Load all models into GPU memory with optional Blackwell optimizations""" if self.vae is not None: print("Models already loaded") return print("\n[1/4] Loading VAE, UNet, PE models...") start = time.time() from musetalk.utils.utils import load_all_model # Blackwell optimizations are now applied automatically in load_all_model self.vae, self.unet, self.pe = load_all_model( use_blackwell_optimizations=use_blackwell_optimizations ) # Track optimization state self._blackwell_optimized = use_blackwell_optimizations and IS_BLACKWELL print(f" Models loaded in {time.time()-start:.2f}s") print("\n[2/4] Loading Whisper audio processor...") start = time.time() from musetalk.whisper.audio2feature import Audio2Feature whisper_path = self.base_dir / "models" / "whisper" / "tiny.pt" self.audio_processor = Audio2Feature(model_path=str(whisper_path)) print(f" Whisper loaded in {time.time()-start:.2f}s") self.timesteps = torch.tensor([0], device=self.device) print("\n✓ All models loaded and ready!") def preprocess_avatar(self, force=False): """Pre-process avatar video (falando.mp4) and cache everything""" if self.avatar_loaded and not force: print("Avatar already preprocessed") return cache_file = self.cache_dir / "falando_cache.pkl" # Try to load from cache if cache_file.exists() and not force: print("\n[3/4] Loading avatar (falando) from cache...") start = time.time() with open(cache_file, 'rb') as f: cache = pickle.load(f) self.avatar_frame_paths = cache['frame_paths'] self.avatar_frames = cache['frames'] # actual numpy arrays self.avatar_coords = cache['coords'] # list of (x1,y1,x2,y2) tuples self.avatar_fps = cache['fps'] # Recompute latents (can't pickle CUDA tensors easily) self._compute_latents() self.avatar_loaded = True print(f" Avatar (falando) loaded from cache in {time.time()-start:.2f}s") return print("\n[3/4] Pre-processing avatar video - falando.mp4 (first time only)...") start = time.time() # Extract frames print(" Extracting frames from falando.mp4...") import imageio frames_dir = self.cache_dir / "falando_frames" frames_dir.mkdir(exist_ok=True) reader = imageio.get_reader(str(self.avatar_video)) meta = reader.get_meta_data() self.avatar_fps = meta.get('fps', 30) frame_paths = [] for i, frame in enumerate(reader): frame_path = frames_dir / f"{i:08d}.png" if not frame_path.exists(): imageio.imwrite(str(frame_path), frame) frame_paths.append(str(frame_path)) self.avatar_frame_paths = frame_paths print(f" Extracted {len(frame_paths)} frames") # Get landmarks and bounding boxes # Returns: coords_list (list of (x1,y1,x2,y2) tuples), frames (list of numpy arrays) print(" Computing face landmarks...") from musetalk.utils.preprocessing import get_landmark_and_bbox self.avatar_coords, self.avatar_frames = get_landmark_and_bbox(frame_paths, 0) # Save to cache cache = { 'frame_paths': self.avatar_frame_paths, 'frames': self.avatar_frames, 'coords': self.avatar_coords, 'fps': self.avatar_fps } with open(cache_file, 'wb') as f: pickle.dump(cache, f) # Compute latents self._compute_latents() self.avatar_loaded = True print(f" Avatar (falando) preprocessed in {time.time()-start:.2f}s") def load_idle_video(self, force=False): """Load idle video (idle.mp4) frames for idle animation""" if self.idle_loaded and not force: print("Idle video already loaded") return cache_file = self.cache_dir / "idle_cache.pkl" # Try to load from cache if cache_file.exists() and not force: print("\n[5/5] Loading idle video from cache...") start = time.time() with open(cache_file, 'rb') as f: cache = pickle.load(f) self.idle_frames = cache['frames'] self.idle_fps = cache['fps'] self.idle_loaded = True print(f" Idle video loaded from cache: {len(self.idle_frames)} frames in {time.time()-start:.2f}s") return print("\n[5/5] Loading idle video - idle.mp4 (first time only)...") start = time.time() # Extract all frames from idle.mp4 cap = cv2.VideoCapture(str(self.idle_video)) self.idle_fps = cap.get(cv2.CAP_PROP_FPS) self.idle_frames = [] while True: ret, frame = cap.read() if not ret: break # Convert BGR to RGB frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) self.idle_frames.append(frame_rgb) cap.release() # Save to cache cache = { 'frames': self.idle_frames, 'fps': self.idle_fps } with open(cache_file, 'wb') as f: pickle.dump(cache, f) self.idle_loaded = True print(f" Idle video loaded: {len(self.idle_frames)} frames @ {self.idle_fps} fps in {time.time()-start:.2f}s") def _compute_latents(self): """Pre-compute VAE latents for all avatar frames""" print("\n[4/4] Pre-computing VAE latents...") start = time.time() # coord_placeholder is a tuple (0.0, 0.0, 0.0, 0.0) used when no face detected coord_placeholder = (0.0, 0.0, 0.0, 0.0) self.avatar_latents = [] for i, (bbox, frame) in enumerate(zip(self.avatar_coords, self.avatar_frames)): # Skip invalid bboxes if bbox == coord_placeholder or bbox is None: if self.avatar_latents: self.avatar_latents.append(self.avatar_latents[-1]) else: # Create a placeholder latent self.avatar_latents.append(None) continue x1, y1, x2, y2 = [int(c) for c in bbox] # Ensure valid crop coordinates h, w = frame.shape[:2] x1, y1 = max(0, x1), max(0, y1) x2, y2 = min(w, x2), min(h, y2) if x2 <= x1 or y2 <= y1: if self.avatar_latents: self.avatar_latents.append(self.avatar_latents[-1]) else: self.avatar_latents.append(None) continue crop = frame[y1:y2, x1:x2] crop = cv2.resize(crop, (256, 256), interpolation=cv2.INTER_LANCZOS4) latent = self.vae.get_latents_for_unet(crop) self.avatar_latents.append(latent) # Replace any None latents with the first valid one first_valid = None for lat in self.avatar_latents: if lat is not None: first_valid = lat break if first_valid is not None: self.avatar_latents = [lat if lat is not None else first_valid for lat in self.avatar_latents] print(f" Computed {len(self.avatar_latents)} latents in {time.time()-start:.2f}s") def initialize(self): """Full initialization - call once at startup""" total_start = time.time() print("\n" + "=" * 50) print("FAST ENGINE INITIALIZATION") print("=" * 50) self.load_models() self.preprocess_avatar() # Load falando.mp4 for lip-sync self.load_idle_video() # Load idle.mp4 for idle animation print("\n" + "=" * 50) print(f"INITIALIZATION COMPLETE in {time.time()-total_start:.2f}s") print("=" * 50 + "\n") def process_audio_chunk(self, audio_path: str, chunk_idx: int, start_frame: int, end_frame: int) -> List[np.ndarray]: """Process a single audio chunk and generate frames""" # Get audio features for this chunk whisper_feature = self.audio_processor.audio2feat(audio_path) whisper_chunks = self.audio_processor.feature2chunks( feature_array=whisper_feature, fps=self.avatar_fps ) num_frames = min(len(whisper_chunks), end_frame - start_frame) generated_frames = [] # Process in batches batch_size = 8 for i in range(0, num_frames, batch_size): batch_end = min(i + batch_size, num_frames) # Get whisper features for batch whisper_batch = whisper_chunks[i:batch_end] # Get corresponding avatar latents (cycle if needed) latent_batch = [] for j in range(len(whisper_batch)): frame_idx = (start_frame + i + j) % len(self.avatar_latents) latent_batch.append(self.avatar_latents[frame_idx]) # Convert to tensors audio_tensor = torch.stack([ torch.FloatTensor(w).to(self.device) for w in whisper_batch ]) latent_tensor = torch.cat(latent_batch, dim=0).to(self.device) # Convert to FP16 if Blackwell optimizations enabled if self._blackwell_optimized: audio_tensor = audio_tensor.half() latent_tensor = latent_tensor.half() # Generate with UNet with torch.no_grad(): pred = self.unet( latent_tensor, self.timesteps, encoder_hidden_states=audio_tensor ).sample recon = self.vae.decode_latents(pred) for frame in recon: generated_frames.append(frame) return generated_frames def compose_video_chunk(self, generated_frames: List[np.ndarray], start_frame: int) -> List[np.ndarray]: """Compose generated frames with original avatar""" from musetalk.utils.blending import get_image coord_placeholder = (0.0, 0.0, 0.0, 0.0) composed_frames = [] for i, gen_frame in enumerate(generated_frames): frame_idx = (start_frame + i) % len(self.avatar_frames) bbox = self.avatar_coords[frame_idx] orig_frame = self.avatar_frames[frame_idx].copy() if bbox != coord_placeholder and bbox is not None: x1, y1, x2, y2 = [int(c) for c in bbox] gen_resized = cv2.resize(gen_frame, (x2-x1, y2-y1)) # Simple paste instead of blending for speed orig_frame[y1:y2, x1:x2] = gen_resized composed_frames.append(orig_frame) else: composed_frames.append(orig_frame) return composed_frames def generate_video_fast(self, audio_path: str, output_path: str = None, resolution: int = 256, batch_size: int = 8, callback=None) -> str: """ Generate lip-sync video with all optimizations. resolution: output face size (128, 192, 256, 320) - affects quality/speed batch_size: frames per batch (4, 8, 16) - affects VRAM usage/speed callback(progress, message) is called with progress updates. """ if output_path is None: output_path = str(self.results_dir / f"fast_{uuid.uuid4().hex[:8]}.mp4") print(f"\n[VIDEO GENERATION] resolution={resolution}, batch_size={batch_size}") print(f"video in {self.avatar_fps} FPS, audio idx in 50FPS") total_start = time.time() if callback: callback(0, "Extraindo features de áudio...") # Get total audio length and features audio_start = time.time() whisper_feature = self.audio_processor.audio2feat(audio_path) whisper_chunks = self.audio_processor.feature2chunks( feature_array=whisper_feature, fps=self.avatar_fps ) total_frames = len(whisper_chunks) if callback: callback(10, f"Processando {total_frames} frames...") print(f"Audio processing: {time.time()-audio_start:.2f}s, {total_frames} frames") # Process in parallel chunks chunk_size = 50 # frames per chunk num_chunks = (total_frames + chunk_size - 1) // chunk_size all_generated = [None] * total_frames gen_start = time.time() # Process chunks for chunk_idx in range(num_chunks): start_frame = chunk_idx * chunk_size end_frame = min(start_frame + chunk_size, total_frames) # Get whisper features for this chunk chunk_whisper = whisper_chunks[start_frame:end_frame] # Process in batches (using batch_size parameter) for i in range(0, len(chunk_whisper), batch_size): batch_end = min(i + batch_size, len(chunk_whisper)) whisper_batch = chunk_whisper[i:batch_end] # Get latents latent_batch = [] for j in range(len(whisper_batch)): frame_idx = (start_frame + i + j) % len(self.avatar_latents) latent_batch.append(self.avatar_latents[frame_idx]) # Convert to tensors audio_tensor = torch.stack([ torch.FloatTensor(w).to(self.device) for w in whisper_batch ]) latent_tensor = torch.cat(latent_batch, dim=0).to(self.device) # Convert to FP16 if Blackwell optimizations enabled if self._blackwell_optimized: audio_tensor = audio_tensor.half() latent_tensor = latent_tensor.half() # Generate with torch.no_grad(): pred = self.unet.model( latent_tensor, self.timesteps, encoder_hidden_states=audio_tensor ).sample recon = self.vae.decode_latents(pred) for j, frame in enumerate(recon): all_generated[start_frame + i + j] = frame progress = int(10 + (chunk_idx + 1) / num_chunks * 60) if callback: callback(progress, f"Gerado chunk {chunk_idx+1}/{num_chunks}") print(f"Generation: {time.time()-gen_start:.2f}s") if callback: callback(70, "Compondo frames...") # Compose all frames compose_start = time.time() coord_placeholder = (0.0, 0.0, 0.0, 0.0) # Scale factor based on resolution (256 is default/full quality) scale_factor = resolution / 256.0 final_frames = [] for i, gen_frame in enumerate(all_generated): if gen_frame is None: continue frame_idx = i % len(self.avatar_frames) bbox = self.avatar_coords[frame_idx] orig_frame = self.avatar_frames[frame_idx].copy() if bbox != coord_placeholder and bbox is not None: x1, y1, x2, y2 = [int(c) for c in bbox] # Ensure valid coords h, w = orig_frame.shape[:2] x1, y1 = max(0, x1), max(0, y1) x2, y2 = min(w, x2), min(h, y2) if x2 > x1 and y2 > y1: # Apply resolution scaling for quality/speed tradeoff target_w = max(32, int((x2-x1) * scale_factor)) target_h = max(32, int((y2-y1) * scale_factor)) gen_resized = cv2.resize(gen_frame, (target_w, target_h), interpolation=cv2.INTER_LINEAR) gen_resized = cv2.resize(gen_resized, (x2-x1, y2-y1), interpolation=cv2.INTER_LINEAR) orig_frame[y1:y2, x1:x2] = gen_resized final_frames.append(orig_frame) else: final_frames.append(orig_frame) print(f"Composition: {time.time()-compose_start:.2f}s") if callback: callback(85, "Escrevendo vídeo...") # Write video write_start = time.time() temp_video = output_path.replace('.mp4', '_temp.mp4') h, w = final_frames[0].shape[:2] fourcc = cv2.VideoWriter_fourcc(*'mp4v') out = cv2.VideoWriter(temp_video, fourcc, self.avatar_fps, (w, h)) for frame in final_frames: out.write(frame) out.release() print(f"Video write: {time.time()-write_start:.2f}s") if callback: callback(95, "Adicionando áudio...") # Add audio cmd = [ "ffmpeg", "-y", "-v", "quiet", "-i", temp_video, "-i", audio_path, "-c:v", "libx264", "-preset", "fast", "-c:a", "aac", "-shortest", output_path ] subprocess.run(cmd, capture_output=True) try: os.remove(temp_video) except: pass total_time = time.time() - total_start print(f"\nTotal generation time: {total_time:.2f}s for {total_frames} frames") print(f"Speed: {total_frames/total_time:.1f} fps") if callback: callback(100, "Concluído!") return output_path def generate_frames_streaming(self, audio_path: str, resolution: int = 256, batch_size: int = 8): """ Generator that yields frames one by one as they're generated. Perfect for WebSocket streaming. Yields: {"type": "info", "total_frames": N, "fps": 30} {"type": "frame", "frame": numpy_array, "index": N} """ print(f"\n[STREAMING] Starting frame generation: resolution={resolution}, batch_size={batch_size}") # Get audio features whisper_feature = self.audio_processor.audio2feat(audio_path) whisper_chunks = self.audio_processor.feature2chunks( feature_array=whisper_feature, fps=self.avatar_fps ) total_frames = len(whisper_chunks) print(f"[STREAMING] Total frames to generate: {total_frames}") # Send info first yield {"type": "info", "total_frames": total_frames, "fps": self.avatar_fps} coord_placeholder = (0.0, 0.0, 0.0, 0.0) scale_factor = resolution / 256.0 frame_index = 0 # Process in batches but yield frames one by one for batch_start in range(0, total_frames, batch_size): batch_end = min(batch_start + batch_size, total_frames) whisper_batch = whisper_chunks[batch_start:batch_end] # Get latents for batch latent_batch = [] for j in range(len(whisper_batch)): idx = (batch_start + j) % len(self.avatar_latents) latent_batch.append(self.avatar_latents[idx]) # Convert to tensors audio_tensor = torch.stack([ torch.FloatTensor(w).to(self.device) for w in whisper_batch ]) latent_tensor = torch.cat(latent_batch, dim=0).to(self.device) # Convert to FP16 if Blackwell optimizations enabled if self._blackwell_optimized: audio_tensor = audio_tensor.half() latent_tensor = latent_tensor.half() # Generate batch with torch.no_grad(): pred = self.unet.model( latent_tensor, self.timesteps, encoder_hidden_states=audio_tensor ).sample recon = self.vae.decode_latents(pred) # Yield each frame from the batch for i, gen_frame in enumerate(recon): global_idx = batch_start + i if global_idx >= total_frames: break # Compose with original frame avatar_idx = global_idx % len(self.avatar_frames) bbox = self.avatar_coords[avatar_idx] orig_frame = self.avatar_frames[avatar_idx].copy() if bbox != coord_placeholder and bbox is not None: x1, y1, x2, y2 = [int(c) for c in bbox] h, w = orig_frame.shape[:2] x1, y1 = max(0, x1), max(0, y1) x2, y2 = min(w, x2), min(h, y2) if x2 > x1 and y2 > y1: target_w = max(32, int((x2-x1) * scale_factor)) target_h = max(32, int((y2-y1) * scale_factor)) gen_resized = cv2.resize(gen_frame, (target_w, target_h), interpolation=cv2.INTER_LINEAR) gen_resized = cv2.resize(gen_resized, (x2-x1, y2-y1), interpolation=cv2.INTER_LINEAR) orig_frame[y1:y2, x1:x2] = gen_resized yield {"type": "frame", "frame": orig_frame, "index": frame_index} frame_index += 1 print(f"[STREAMING] Completed: {frame_index} frames") def get_idle_frames(self) -> Tuple[List[np.ndarray], float]: """ Get all frames from idle.mp4 for idle animation. Returns: (frames_list, fps) """ if not self.idle_loaded or not self.idle_frames: return [], 30.0 return self.idle_frames, self.idle_fps def generate_video_streaming(self, audio_path: str, chunk_callback=None) -> List[str]: """ Generate video in streaming chunks. chunk_callback(chunk_idx, chunk_path) is called as each chunk is ready. Returns list of chunk video paths. """ # Split audio into chunks chunk_duration = 3 # seconds # Get audio duration import wave try: with wave.open(audio_path, 'rb') as wav: duration = wav.getnframes() / wav.getframerate() except: # Fallback for non-wav duration = 10 # assume 10 seconds num_chunks = max(1, int(np.ceil(duration / chunk_duration))) chunk_paths = [] for i in range(num_chunks): start_time = i * chunk_duration # Extract audio chunk chunk_audio = str(self.results_dir / f"chunk_{i}_audio.wav") cmd = [ "ffmpeg", "-y", "-v", "quiet", "-i", audio_path, "-ss", str(start_time), "-t", str(chunk_duration), "-ar", "16000", "-ac", "1", chunk_audio ] subprocess.run(cmd, capture_output=True) # Generate video for this chunk chunk_video = str(self.results_dir / f"chunk_{i}_{uuid.uuid4().hex[:6]}.mp4") start_frame = int(start_time * self.avatar_fps) self.generate_video_fast(chunk_audio, chunk_video) chunk_paths.append(chunk_video) if chunk_callback: chunk_callback(i, chunk_video) # Cleanup audio chunk try: os.remove(chunk_audio) except: pass return chunk_paths # Global engine instance _engine: Optional[FastMuseTalkEngine] = None def get_engine() -> FastMuseTalkEngine: """Get or create the singleton engine""" global _engine if _engine is None: _engine = FastMuseTalkEngine() return _engine def initialize_engine(): """Initialize the engine (call at server startup)""" engine = get_engine() engine.initialize() return engine if __name__ == "__main__": # Test the engine engine = initialize_engine() # Test with sample audio test_audio = "/workspace/MuseTalk1.5/data/audio/teacher_intro.wav" if os.path.exists(test_audio): print("\nTesting video generation...") output = engine.generate_video_fast(test_audio) print(f"Output: {output}")