| """ |
| Speech-X Model Loader - Standalone Version |
| ========================================= |
| Loads the complete MuseTalk v1.5 inference stack |
| """ |
| from __future__ import annotations |
|
|
| import logging |
| import sys |
| from dataclasses import dataclass |
| from pathlib import Path |
|
|
| import torch |
|
|
| |
| import sys |
| from pathlib import Path |
|
|
# Make backend/ importable so the `from config import ...` below (and later
# `musetalk.*` imports) resolve when this module is loaded outside the
# package root.
# NOTE(review): this duplicates _ensure_musetalk_on_path() defined below —
# consider calling that helper here instead of repeating the logic.
_backend_dir = Path(__file__).parent.parent
if str(_backend_dir) not in sys.path:
    sys.path.insert(0, str(_backend_dir))
|
|
| from config import ( |
| DEVICE, |
| WEIGHTS_DIR, |
| MUSETALK_UNET_FP16, |
| ) |
|
|
| log = logging.getLogger(__name__) |
|
|
|
|
def _ensure_musetalk_on_path() -> None:
    """Make `musetalk.*` importable by prepending backend/ to sys.path.

    Idempotent: the backend directory is only inserted when it is not
    already on the path, so repeated calls are harmless.
    """
    backend_root = str(Path(__file__).parent.parent)
    if backend_root in sys.path:
        return
    sys.path.insert(0, backend_root)
|
|
|
|
@dataclass
class ModelBundle:
    """Container for every model the Speech-X pipeline needs.

    Attributes:
        vae: musetalk VAE wrapper (.vae is a diffusers AutoencoderKL).
        unet: musetalk UNet wrapper (.model is a diffusers UNet).
        pe: PositionalEncoding nn.Module.
        audio_processor: musetalk AudioProcessor (HF feature extractor).
        whisper: transformers.WhisperModel (encoder-only).
        timesteps: torch.tensor([0]) placed on the target device.
        device: torch device string.
        weight_dtype: torch.float16 or torch.float32.
    """

    vae: object
    unet: object
    pe: object
    audio_processor: object
    whisper: object
    timesteps: torch.Tensor
    device: str
    weight_dtype: torch.dtype
|
|
|
|
def load_all_models(avatar_name: str = "christine") -> ModelBundle:
    """
    Load all models needed for MuseTalk inference.

    This is the main entry point - call once at startup.

    Args:
        avatar_name: Name of the avatar whose assets to load.

    Returns:
        ModelBundle with every model loaded on the configured DEVICE.
    """
    # Imported lazily so the heavy musetalk dependency chain is only pulled
    # in when models are actually loaded (and after sys.path is set up).
    from musetalk.worker import load_musetalk_models

    log.info("Loading all MuseTalk models...")

    bundle = load_musetalk_models(avatar_name=avatar_name, device=DEVICE)

    return ModelBundle(
        vae=bundle.vae,
        unet=bundle.unet,
        pe=bundle.pe,
        audio_processor=bundle.audio_processor,
        whisper=bundle.whisper,
        timesteps=bundle.timesteps,
        # BUG FIX: `bundle.whisper.device` is a torch.device object, but
        # ModelBundle.device is typed/documented as a device *string*.
        # Use the configured DEVICE the models were loaded onto instead.
        device=DEVICE,
        weight_dtype=torch.float16 if MUSETALK_UNET_FP16 else torch.float32,
    )
|
|
|
|
def prewarm_models(bundle: ModelBundle) -> None:
    """Run a warmup pass so GPU kernels are initialized before first use.

    NOTE(review): currently a placeholder — nothing is actually fed through
    the models, so the first real request still pays any kernel-compilation
    latency.  The previous version built a 1 s silence buffer
    (np.zeros(16000)) but never used it; that dead local (and the numpy
    import it required) has been removed.

    Args:
        bundle: Models returned by load_all_models().
    """
    log.info("Warming up models...")

    # TODO: push ~1 s of silence (np.zeros(16000, dtype=np.float32)) through
    # bundle.audio_processor -> whisper -> unet -> vae so kernels are
    # compiled before the first user request.

    log.info("Models warmed up")
|
|