# NOTE: repository-scrape residue removed (was: "agkavin / Initial commit:
# speech_to_video project with models via LFS / 249e06d") — kept as a
# comment so the module parses.
"""
Speech-X Model Loader - Standalone Version
=========================================
Loads the complete MuseTalk v1.5 inference stack
"""
from __future__ import annotations
import logging
import sys
from dataclasses import dataclass
from pathlib import Path
import torch
# Use relative imports for standalone
import sys
from pathlib import Path
_backend_dir = Path(__file__).parent.parent
if str(_backend_dir) not in sys.path:
sys.path.insert(0, str(_backend_dir))
from config import (
DEVICE,
WEIGHTS_DIR,
MUSETALK_UNET_FP16,
)
log = logging.getLogger(__name__)
def _ensure_musetalk_on_path() -> None:
"""Add backend/ to sys.path so `musetalk.*` imports resolve."""
p = str(Path(__file__).parent.parent)
if p not in sys.path:
sys.path.insert(0, p)
@dataclass
class ModelBundle:
    """
    Container for every model the Speech-X (MuseTalk) pipeline needs.

    Built once at startup by ``load_all_models`` from the bundle returned by
    ``musetalk.worker.load_musetalk_models``, then passed around by the
    inference code.
    """

    vae: object              # musetalk VAE wrapper (.vae is a diffusers AutoencoderKL)
    unet: object             # musetalk UNet wrapper (.model is a diffusers UNet)
    pe: object               # PositionalEncoding nn.Module
    audio_processor: object  # musetalk AudioProcessor (HF feature extractor)
    whisper: object          # transformers.WhisperModel (encoder-only)
    timesteps: torch.Tensor  # torch.tensor([0]) on the target device
    device: str              # torch device string (e.g. "cuda" / "cpu")
    weight_dtype: torch.dtype  # float16 when MUSETALK_UNET_FP16 is set, else float32
def load_all_models(avatar_name: str = "christine") -> ModelBundle:
    """
    Load the complete MuseTalk v1.5 inference stack.

    Main entry point — call once at startup; loading is expensive.

    Parameters
    ----------
    avatar_name : str
        Name of the avatar whose assets musetalk should load.

    Returns
    -------
    ModelBundle
        All loaded models plus device / dtype bookkeeping.
    """
    # Imported lazily so merely importing this module does not pull in the
    # heavy musetalk dependency tree.
    from musetalk.worker import load_musetalk_models

    log.info("Loading all MuseTalk models...")
    bundle = load_musetalk_models(avatar_name=avatar_name, device=DEVICE)
    return ModelBundle(
        vae=bundle.vae,
        unet=bundle.unet,
        pe=bundle.pe,
        audio_processor=bundle.audio_processor,
        whisper=bundle.whisper,
        timesteps=bundle.timesteps,
        # ModelBundle.device is declared/documented as a string, but
        # `whisper.device` is a torch.device — normalize it here so
        # downstream string comparisons (e.g. `device == "cuda"`) work.
        device=str(bundle.whisper.device),
        weight_dtype=torch.float16 if MUSETALK_UNET_FP16 else torch.float32,
    )
def prewarm_models(bundle: ModelBundle) -> None:
    """Run a warmup pass intended to prime GPU kernels before real traffic.

    NOTE(review): currently a logging-only placeholder — nothing in
    ``bundle`` is actually exercised (the original built a 1-second silent
    float32 buffer but never used it; that dead code is removed here).
    TODO: feed ~1 s of silence through the whisper encoder (and optionally
    the UNet) so kernel compilation happens before the first request.
    """
    log.info("Warming up models...")
    # `bundle` stays in the signature so a real warmup can be added without
    # changing callers.
    log.info("Models warmed up")