"""
Speech-X Model Loader - Standalone Version
=========================================
Loads the complete MuseTalk v1.5 inference stack
"""
from __future__ import annotations
import logging
import sys
from dataclasses import dataclass
from pathlib import Path
import torch
# Use relative imports for standalone
import sys
from pathlib import Path
_backend_dir = Path(__file__).parent.parent
if str(_backend_dir) not in sys.path:
sys.path.insert(0, str(_backend_dir))
from config import (
DEVICE,
WEIGHTS_DIR,
MUSETALK_UNET_FP16,
)
log = logging.getLogger(__name__)
def _ensure_musetalk_on_path() -> None:
"""Add backend/ to sys.path so `musetalk.*` imports resolve."""
p = str(Path(__file__).parent.parent)
if p not in sys.path:
sys.path.insert(0, p)
@dataclass
class ModelBundle:
    """
    All models needed by the Speech-X pipeline.

    Field types are left as ``object`` because the concrete classes live
    in the musetalk / transformers packages, which may not be importable
    at annotation time.

    vae : musetalk VAE wrapper (.vae is diffusers AutoencoderKL)
    unet : musetalk UNet wrapper (.model is diffusers UNet)
    pe : PositionalEncoding nn.Module
    audio_processor : musetalk AudioProcessor (HF feature extractor)
    whisper : transformers.WhisperModel (encoder-only)
    timesteps : torch.tensor([0]) on device
    device : torch device string
    weight_dtype : torch.float16 or float32
    """
    vae: object              # musetalk VAE wrapper
    unet: object             # musetalk UNet wrapper
    pe: object               # positional-encoding module
    audio_processor: object  # HF feature extractor wrapper
    whisper: object          # Whisper encoder model
    timesteps: torch.Tensor  # tensor([0]) pinned to `device`
    device: str              # e.g. "cuda:0" / "cpu"
    weight_dtype: torch.dtype  # fp16 when MUSETALK_UNET_FP16, else fp32
def load_all_models(avatar_name: str = "christine") -> ModelBundle:
    """
    Load all models needed for MuseTalk inference.

    This is the main entry point - call once at startup.

    Args:
        avatar_name: Avatar whose prepared assets the worker loads.

    Returns:
        ModelBundle with every component from the musetalk worker plus
        the resolved device string and UNet weight dtype.
    """
    # Imported lazily so that the sys.path fix-up at module import time
    # has already run before `musetalk` is resolved.
    from musetalk.worker import load_musetalk_models

    log.info("Loading all MuseTalk models...")
    bundle = load_musetalk_models(avatar_name=avatar_name, device=DEVICE)
    return ModelBundle(
        vae=bundle.vae,
        unet=bundle.unet,
        pe=bundle.pe,
        audio_processor=bundle.audio_processor,
        whisper=bundle.whisper,
        timesteps=bundle.timesteps,
        # Fix: `whisper.device` is a torch.device object; coerce to str
        # to match the ModelBundle annotation (`device: str`).
        device=str(bundle.whisper.device),
        weight_dtype=torch.float16 if MUSETALK_UNET_FP16 else torch.float32,
    )
def prewarm_models(bundle: ModelBundle) -> None:
    """Run a warmup inference to optimize GPU kernels.

    NOTE(review): currently a stub — the silence buffer is allocated but
    never fed through the whisper encoder.
    """
    import numpy as np

    log.info("Warming up models...")
    # One second of silence at 16 kHz; an actual warmup would run this
    # through the whisper encoder.
    _dummy_audio = np.zeros(16000, dtype=np.float32)
    log.info("Models warmed up")
|