""" Speech-X Model Loader - Standalone Version ========================================= Loads the complete MuseTalk v1.5 inference stack """ from __future__ import annotations import logging import sys from dataclasses import dataclass from pathlib import Path import torch # Use relative imports for standalone import sys from pathlib import Path _backend_dir = Path(__file__).parent.parent if str(_backend_dir) not in sys.path: sys.path.insert(0, str(_backend_dir)) from config import ( DEVICE, WEIGHTS_DIR, MUSETALK_UNET_FP16, ) log = logging.getLogger(__name__) def _ensure_musetalk_on_path() -> None: """Add backend/ to sys.path so `musetalk.*` imports resolve.""" p = str(Path(__file__).parent.parent) if p not in sys.path: sys.path.insert(0, p) @dataclass class ModelBundle: """ All models needed by the Speech-X pipeline. vae : musetalk VAE wrapper (.vae is diffusers AutoencoderKL) unet : musetalk UNet wrapper (.model is diffusers UNet) pe : PositionalEncoding nn.Module audio_processor : musetalk AudioProcessor (HF feature extractor) whisper : transformers.WhisperModel (encoder-only) timesteps : torch.tensor([0]) on device device : torch device string weight_dtype : torch.float16 or float32 """ vae: object unet: object pe: object audio_processor: object whisper: object timesteps: torch.Tensor device: str weight_dtype: torch.dtype def load_all_models(avatar_name: str = "christine") -> ModelBundle: """ Load all models needed for MuseTalk inference. This is the main entry point - call once at startup. """ from musetalk.worker import load_musetalk_models log.info("Loading all MuseTalk models...") bundle = load_musetalk_models(avatar_name=avatar_name, device=DEVICE) return ModelBundle( vae=bundle.vae, unet=bundle.unet, pe=bundle.pe, audio_processor=bundle.audio_processor, whisper=bundle.whisper, timesteps=bundle.timesteps, device=bundle.whisper.device, weight_dtype=torch.float16 if MUSETALK_UNET_FP16 else torch.float32, ) def prewarm_models(bundle: ModelBundle) -> None: """Run a warmup inference to optimize GPU kernels.""" import numpy as np log.info("Warming up models...") # Warmup whisper dummy_audio = np.zeros(16000, dtype=np.float32) # Note: actual warmup would call whisper encoder log.info("Models warmed up")