lipsync-docker / shared /model_manager.py
naicoi's picture
model-dirs (#2)
f5651ba
"""ModelManager: Lazy loading and caching for ML models"""
import gc
import logging
import os
import torch
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from huggingface_hub import snapshot_download
from omegaconf import OmegaConf
logger = logging.getLogger(__name__)


class ModelManager:
    """Process-wide singleton that lazily loads and caches ML models.

    Every ``load_*`` / ``get_*`` method builds its object on the first call
    and returns the cached object afterwards. Arguments passed on later
    calls (``device``, ``model_path``, ...) are ignored until
    :meth:`clear_cache` drops the cached objects.
    """

    # Singleton instance; created on first construction.
    _instance = None

    # Cache slots. The class-level ``None`` values act as defaults; loaders
    # assign the real objects on the (single) instance.
    _whisper_encoder = None
    _vae = None
    _latentsync_unet = None
    _musetalk_unet = None
    _scheduler = None
    _latentsync_config = None
    _musetalk_pe = None

    def __new__(cls):
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @classmethod
    def get_instance(cls) -> "ModelManager":
        """Return the shared instance (equivalent to calling the class)."""
        # ``__new__`` already implements the create-once logic.
        return cls()

    def load_whisper_encoder(
        self, model_path: str, device: str = "cuda", num_frames: int = 12
    ):
        """Load the Whisper audio encoder (lazy loaded).

        Args:
            model_path: Value forwarded to ``Audio2Feature`` as ``model_path``
                (see :meth:`get_whisper_model_path`).
            device: Device forwarded to ``Audio2Feature``.
            num_frames: Frame count forwarded to ``Audio2Feature``.

        Returns:
            The cached ``Audio2Feature`` instance. The arguments only take
            effect on the call that actually builds the encoder.
        """
        if self._whisper_encoder is None:
            # Project-local imports are deferred so that importing this
            # module does not require them.
            from latentsync.whisper.audio2feature import Audio2Feature
            from config import MODELS_DIR

            logger.info("Loading Whisper encoder from %s...", model_path)
            self._whisper_encoder = Audio2Feature(
                model_path=model_path,
                device=device,
                num_frames=num_frames,
                download_root=os.path.join(MODELS_DIR, "whisper"),
            )
            logger.info("Whisper encoder loaded")
        return self._whisper_encoder

    def load_vae(self, device: str = "cuda"):
        """Load the VAE (lazy loaded).

        Args:
            device: Device the VAE is moved to when first loaded.

        Returns:
            The cached ``AutoencoderKL`` loaded in float16 from
            ``stabilityai/sd-vae-ft-mse``.
        """
        if self._vae is None:
            from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
            from config import MODELS_DIR

            logger.info("Loading VAE...")
            vae = AutoencoderKL.from_pretrained(
                "stabilityai/sd-vae-ft-mse",
                torch_dtype=torch.float16,
                cache_dir=MODELS_DIR,
            )
            # Override the scaling/shift factors stored in the loaded config.
            vae.config.scaling_factor = 0.18215
            vae.config.shift_factor = 0
            self._vae = vae.to(device)
            logger.info("VAE loaded")
        return self._vae

    def get_scheduler(self):
        """Get the DDIMScheduler (lazy loaded).

        Returns:
            The cached ``DDIMScheduler`` built from the relative path
            ``configs``.
        """
        if self._scheduler is None:
            logger.info("Loading DDIMScheduler...")
            self._scheduler = DDIMScheduler.from_pretrained("configs")
            logger.info("DDIMScheduler loaded")
        return self._scheduler

    def load_latentsync_unet(self, device: str = "cuda"):
        """Load the LatentSync UNet (lazy loaded).

        Downloads the ``ByteDance/LatentSync-1.6`` snapshot into
        ``MODELS_DIR`` when ``latentsync_unet.pt`` is not already there.

        Args:
            device: Device the UNet is moved to when first loaded.

        Returns:
            The cached float16 ``UNet3DConditionModel``.
        """
        if self._latentsync_unet is None:
            from diffusers.utils.import_utils import is_xformers_available
            from latentsync.models.unet import UNet3DConditionModel
            from config import MODELS_DIR

            unet_path = os.path.join(MODELS_DIR, "latentsync_unet.pt")
            if not os.path.exists(unet_path):
                logger.info("Downloading LatentSync-1.6 models...")
                os.makedirs(MODELS_DIR, exist_ok=True)
                snapshot_download(
                    repo_id="ByteDance/LatentSync-1.6", local_dir=MODELS_DIR
                )

            logger.info("Loading LatentSync UNet...")
            config = self.get_latentsync_config()
            # Build on CPU first, then cast to fp16 and move to the target
            # device in one step.
            unet, _ = UNet3DConditionModel.from_pretrained(
                OmegaConf.to_container(config.model),
                unet_path,
                device="cpu",
            )
            unet = unet.to(dtype=torch.float16).to(device)

            if is_xformers_available():
                unet.enable_xformers_memory_efficient_attention()

            self._latentsync_unet = unet
            logger.info("LatentSync UNet loaded")
        return self._latentsync_unet

    def get_latentsync_config(self):
        """Get the LatentSync config (lazy loaded).

        Returns:
            The cached OmegaConf object loaded from
            ``configs/unet/stage2_512.yaml``.
        """
        if self._latentsync_config is None:
            logger.info("Loading LatentSync config...")
            self._latentsync_config = OmegaConf.load("configs/unet/stage2_512.yaml")
        return self._latentsync_config

    def load_musetalk_unet(self, device: str = "cuda"):
        """Load the MuseTalk V1.5 UNet and positional encoding (lazy loaded).

        The ``musetalkV15`` files are downloaded from ``TMElyralab/MuseTalk``
        into ``./checkpoints`` only when they are not already on disk.

        Args:
            device: Device the UNet is moved to when first loaded.

        Returns:
            A ``(unet, positional_encoding)`` tuple of cached objects.
        """
        if self._musetalk_unet is None:
            import json

            from diffusers import UNet2DConditionModel
            from musetalk.models.unet import (
                PositionalEncoding as MuseTalkPositionalEncoding,
            )

            unet_config_path = "checkpoints/musetalkV15/musetalk.json"
            unet_model_path = "checkpoints/musetalkV15/unet.pth"

            # Same policy as the LatentSync loader: skip the hub round-trip
            # when the required files already exist locally.
            if not (
                os.path.exists(unet_config_path)
                and os.path.exists(unet_model_path)
            ):
                logger.info("Downloading MuseTalk V1.5 models...")
                os.makedirs("checkpoints", exist_ok=True)
                snapshot_download(
                    repo_id="TMElyralab/MuseTalk",
                    local_dir="./checkpoints",
                    allow_patterns=["musetalkV15/*"],
                )

            logger.info("Loading MuseTalk V1.5 UNet...")
            with open(unet_config_path, "r", encoding="utf-8") as f:
                unet_config = json.load(f)

            unet = UNet2DConditionModel(**unet_config)
            # Load the checkpoint on CPU: the freshly built module lives on
            # CPU, so mapping the weights to the target device first would
            # only create a transient full-precision copy there.
            weights = torch.load(
                unet_model_path, map_location="cpu", weights_only=True
            )
            unet.load_state_dict(weights)
            del weights  # release the state dict before the device transfer
            unet = unet.to(dtype=torch.float16).to(device)

            self._musetalk_unet = unet
            self._musetalk_pe = MuseTalkPositionalEncoding(d_model=384)
            logger.info("MuseTalk UNet loaded")
        return self._musetalk_unet, self._musetalk_pe

    def get_whisper_model_path(self, cross_attention_dim: int) -> str:
        """Get the Whisper model name for a given ``cross_attention_dim``.

        Args:
            cross_attention_dim: ``768`` selects ``"small"``; ``384`` selects
                ``"tiny"``.

        Returns:
            The Whisper model name.

        Raises:
            NotImplementedError: For any other ``cross_attention_dim``.
        """
        if cross_attention_dim == 768:
            return "small"
        if cross_attention_dim == 384:
            return "tiny"
        raise NotImplementedError("cross_attention_dim must be 768 or 384")

    def preload_latentsync_models(self, device: str = "cuda") -> None:
        """Preload all LatentSync models at startup.

        Loads, in order: config, VAE, Whisper encoder, UNet, scheduler.

        Args:
            device: Device passed to every loader (defaults to ``"cuda"``).
        """
        logger.info("Preloading LatentSync models...")
        config = self.get_latentsync_config()
        self.load_vae(device)
        self.load_whisper_encoder(
            self.get_whisper_model_path(config.model.cross_attention_dim),
            device,
            config.data.num_frames,
        )
        self.load_latentsync_unet(device)
        self.get_scheduler()
        logger.info("LatentSync models preloaded successfully")

    def preload_musetalk_models(self, device: str = "cuda") -> None:
        """Preload all MuseTalk models at startup.

        Args:
            device: Device passed to the loader (defaults to ``"cuda"``).
        """
        logger.info("Preloading MuseTalk models...")
        self.load_musetalk_unet(device)
        logger.info("MuseTalk models preloaded successfully")

    def clear_cache(self) -> None:
        """Drop all cached models and release cached GPU memory.

        Only the manager's own references are dropped; objects that callers
        still hold stay alive until those references go away too.
        """
        logger.info("Clearing model cache...")
        self._whisper_encoder = None
        self._vae = None
        self._latentsync_unet = None
        self._musetalk_unet = None
        self._scheduler = None
        self._latentsync_config = None
        self._musetalk_pe = None
        # Collect garbage first so tensors kept alive only by reference
        # cycles are freed before the CUDA caching allocator is asked to
        # hand its unused blocks back.
        gc.collect()
        torch.cuda.empty_cache()
        logger.info("Model cache cleared")