"""ModelManager: Lazy loading and caching for ML models""" import gc import logging import os import torch from diffusers.schedulers.scheduling_ddim import DDIMScheduler from huggingface_hub import snapshot_download from omegaconf import OmegaConf logger = logging.getLogger(__name__) class ModelManager: _instance = None _whisper_encoder = None _vae = None _latentsync_unet = None _musetalk_unet = None _scheduler = None _latentsync_config = None _musetalk_pe = None def __new__(cls): if cls._instance is None: cls._instance = super().__new__(cls) return cls._instance @classmethod def get_instance(cls): if cls._instance is None: cls._instance = cls() return cls._instance def load_whisper_encoder( self, model_path: str, device: str = "cuda", num_frames: int = 12 ): """Load Whisper audio encoder (lazy loaded)""" if self._whisper_encoder is None: from latentsync.whisper.audio2feature import Audio2Feature from config import MODELS_DIR logger.info(f"Loading Whisper encoder from {model_path}...") self._whisper_encoder = Audio2Feature( model_path=model_path, device=device, num_frames=num_frames, download_root=f"{MODELS_DIR}/whisper", ) logger.info("Whisper encoder loaded") return self._whisper_encoder def load_vae(self, device: str = "cuda"): """Load VAE (lazy loaded)""" if self._vae is None: from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL logger.info("Loading VAE...") from config import MODELS_DIR vae = AutoencoderKL.from_pretrained( "stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16, cache_dir=MODELS_DIR, ) vae.config.scaling_factor = 0.18215 vae.config.shift_factor = 0 self._vae = vae.to(device) logger.info("VAE loaded") return self._vae def get_scheduler(self): """Get DDIMScheduler (lazy loaded)""" if self._scheduler is None: logger.info("Loading DDIMScheduler...") self._scheduler = DDIMScheduler.from_pretrained("configs") logger.info("DDIMScheduler loaded") return self._scheduler def load_latentsync_unet(self, device: str = "cuda"): """Load LatentSync UNet (lazy loaded)""" if self._latentsync_unet is None: from latentsync.models.unet import UNet3DConditionModel from config import MODELS_DIR unet_path = f"{MODELS_DIR}/latentsync_unet.pt" if not os.path.exists(unet_path): logger.info("Downloading LatentSync-1.6 models...") os.makedirs(MODELS_DIR, exist_ok=True) snapshot_download( repo_id="ByteDance/LatentSync-1.6", local_dir=MODELS_DIR ) logger.info("Loading LatentSync UNet...") config = self.get_latentsync_config() unet, _ = UNet3DConditionModel.from_pretrained( OmegaConf.to_container(config.model), unet_path, device="cpu", ) unet = unet.to(dtype=torch.float16).to(device) from diffusers.utils.import_utils import is_xformers_available if is_xformers_available(): unet.enable_xformers_memory_efficient_attention() self._latentsync_unet = unet logger.info("LatentSync UNet loaded") return self._latentsync_unet def get_latentsync_config(self): """Get LatentSync config""" if self._latentsync_config is None: logger.info("Loading LatentSync config...") unet_config_path = "configs/unet/stage2_512.yaml" config = OmegaConf.load(unet_config_path) self._latentsync_config = config return self._latentsync_config def load_musetalk_unet(self, device: str = "cuda"): """Load MuseTalk V1.5 UNet (lazy loaded)""" if self._musetalk_unet is None: import json from diffusers import UNet2DConditionModel logger.info("Downloading MuseTalk V1.5 models...") os.makedirs("checkpoints", exist_ok=True) snapshot_download( repo_id="TMElyralab/MuseTalk", local_dir="./checkpoints", allow_patterns=["musetalkV15/*"], ) logger.info("Loading MuseTalk V1.5 UNet...") unet_config_path = "checkpoints/musetalkV15/musetalk.json" unet_model_path = "checkpoints/musetalkV15/unet.pth" with open(unet_config_path, "r") as f: unet_config = json.load(f) unet = UNet2DConditionModel(**unet_config) weights = torch.load( unet_model_path, map_location=device, weights_only=True ) unet.load_state_dict(weights) unet = unet.to(dtype=torch.float16).to(device) from musetalk.models.unet import ( PositionalEncoding as MuseTalkPositionalEncoding, ) pe = MuseTalkPositionalEncoding(d_model=384) self._musetalk_unet = unet self._musetalk_pe = pe logger.info("MuseTalk UNet loaded") return self._musetalk_unet, self._musetalk_pe def get_whisper_model_path(self, cross_attention_dim: int): """Get Whisper model path based on cross_attention_dim""" if cross_attention_dim == 768: return "small" elif cross_attention_dim == 384: return "tiny" else: raise NotImplementedError("cross_attention_dim must be 768 or 384") def preload_latentsync_models(self): """Preload all LatentSync models at startup""" logger.info("Preloading LatentSync models...") self.get_latentsync_config() self.load_vae() config = self.get_latentsync_config() self.load_whisper_encoder( self.get_whisper_model_path(config.model.cross_attention_dim), "cuda", config.data.num_frames, ) self.load_latentsync_unet() self.get_scheduler() logger.info("LatentSync models preloaded successfully") def preload_musetalk_models(self): """Preload all MuseTalk models at startup""" logger.info("Preloading MuseTalk models...") self.load_musetalk_unet() logger.info("MuseTalk models preloaded successfully") def clear_cache(self): """Clear GPU cache and unload all models""" logger.info("Clearing model cache...") self._whisper_encoder = None self._vae = None self._latentsync_unet = None self._musetalk_unet = None self._scheduler = None self._latentsync_config = None self._musetalk_pe = None torch.cuda.empty_cache() gc.collect() logger.info("Model cache cleared")