Spaces:
Runtime error
Runtime error
| """ModelManager: Lazy loading and caching for ML models""" | |
| import gc | |
| import logging | |
| import os | |
| import torch | |
| from diffusers.schedulers.scheduling_ddim import DDIMScheduler | |
| from huggingface_hub import snapshot_download | |
| from omegaconf import OmegaConf | |
logger = logging.getLogger(__name__)


class ModelManager:
    """Process-wide singleton that lazily loads and caches ML models.

    Each ``load_*`` / ``get_*`` method builds its object on the first call and
    returns the cached object on every later call. The cache is keyed only on
    the attribute: arguments passed on later calls (``model_path``, ``device``,
    ...) are ignored once something has been cached. Use :meth:`clear_cache`
    to drop every cached object.
    """

    # Shared singleton returned by both ``ModelManager()`` and ``get_instance()``.
    _instance = None
    # Cached models / configs; ``None`` means "not loaded yet".
    _whisper_encoder = None
    _vae = None
    _latentsync_unet = None
    _musetalk_unet = None
    _scheduler = None
    _latentsync_config = None
    _musetalk_pe = None

    def __new__(cls):
        # Always hand back the same object so cached models are shared.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @classmethod
    def get_instance(cls):
        """Return the shared ``ModelManager``, creating it on first use.

        Works both as ``ModelManager.get_instance()`` and on an instance.
        """
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def load_whisper_encoder(
        self, model_path: str, device: str = "cuda", num_frames: int = 12
    ):
        """Load the Whisper audio encoder on first call and cache it.

        Args:
            model_path: Whisper model identifier/path forwarded to
                ``Audio2Feature`` (see ``get_whisper_model_path``).
            device: Device forwarded to ``Audio2Feature``.
            num_frames: Forwarded to ``Audio2Feature``.

        Returns:
            The cached ``Audio2Feature`` instance. All arguments are ignored
            once an encoder has been cached.
        """
        if self._whisper_encoder is None:
            # Project imports are deferred until a model is actually needed.
            from latentsync.whisper.audio2feature import Audio2Feature
            from config import MODELS_DIR

            logger.info("Loading Whisper encoder from %s...", model_path)
            self._whisper_encoder = Audio2Feature(
                model_path=model_path,
                device=device,
                num_frames=num_frames,
                download_root=f"{MODELS_DIR}/whisper",
            )
            logger.info("Whisper encoder loaded")
        return self._whisper_encoder

    def load_vae(self, device: str = "cuda"):
        """Load the ``stabilityai/sd-vae-ft-mse`` VAE in fp16 and cache it.

        Weights are cached under ``MODELS_DIR``. ``device`` is ignored once
        the VAE has been cached.
        """
        if self._vae is None:
            from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL

            logger.info("Loading VAE...")
            from config import MODELS_DIR

            vae = AutoencoderKL.from_pretrained(
                "stabilityai/sd-vae-ft-mse",
                torch_dtype=torch.float16,
                cache_dir=MODELS_DIR,
            )
            # Pin the latent scaling constants explicitly instead of relying on
            # whatever the downloaded config contains.
            vae.config.scaling_factor = 0.18215
            vae.config.shift_factor = 0
            self._vae = vae.to(device)
            logger.info("VAE loaded")
        return self._vae

    def get_scheduler(self):
        """Return a cached ``DDIMScheduler`` built from the ``configs`` dir.

        ``"configs"`` is resolved relative to the current working directory.
        """
        if self._scheduler is None:
            logger.info("Loading DDIMScheduler...")
            self._scheduler = DDIMScheduler.from_pretrained("configs")
            logger.info("DDIMScheduler loaded")
        return self._scheduler

    def load_latentsync_unet(self, device: str = "cuda"):
        """Load the LatentSync ``UNet3DConditionModel`` (fp16) and cache it.

        If ``{MODELS_DIR}/latentsync_unet.pt`` is missing, the
        ``ByteDance/LatentSync-1.6`` snapshot is downloaded into ``MODELS_DIR``
        first. The model is built on CPU from ``config.model``, cast to fp16,
        moved to ``device``, and switched to xformers attention when xformers
        is available. ``device`` is ignored once the UNet has been cached.
        """
        if self._latentsync_unet is None:
            from latentsync.models.unet import UNet3DConditionModel
            from config import MODELS_DIR

            unet_path = f"{MODELS_DIR}/latentsync_unet.pt"
            if not os.path.exists(unet_path):
                logger.info("Downloading LatentSync-1.6 models...")
                os.makedirs(MODELS_DIR, exist_ok=True)
                snapshot_download(
                    repo_id="ByteDance/LatentSync-1.6", local_dir=MODELS_DIR
                )

            logger.info("Loading LatentSync UNet...")
            config = self.get_latentsync_config()
            unet, _ = UNet3DConditionModel.from_pretrained(
                OmegaConf.to_container(config.model),
                unet_path,
                device="cpu",
            )
            unet = unet.to(dtype=torch.float16).to(device)

            from diffusers.utils.import_utils import is_xformers_available

            if is_xformers_available():
                unet.enable_xformers_memory_efficient_attention()
            self._latentsync_unet = unet
            logger.info("LatentSync UNet loaded")
        return self._latentsync_unet

    def get_latentsync_config(self):
        """Load ``configs/unet/stage2_512.yaml`` with OmegaConf and cache it."""
        if self._latentsync_config is None:
            logger.info("Loading LatentSync config...")
            unet_config_path = "configs/unet/stage2_512.yaml"
            self._latentsync_config = OmegaConf.load(unet_config_path)
        return self._latentsync_config

    def load_musetalk_unet(self, device: str = "cuda"):
        """Load the MuseTalk V1.5 UNet (fp16) and its positional encoding.

        ``musetalkV15/*`` is downloaded from ``TMElyralab/MuseTalk`` into
        ``./checkpoints`` only when the config or weights file is missing.

        Returns:
            ``(unet, pe)``: the cached ``UNet2DConditionModel`` and the MuseTalk
            ``PositionalEncoding`` module. ``device`` is ignored once they have
            been cached.
        """
        if self._musetalk_unet is None:
            import json

            from diffusers import UNet2DConditionModel

            unet_config_path = "checkpoints/musetalkV15/musetalk.json"
            unet_model_path = "checkpoints/musetalkV15/unet.pth"
            # Only contact the Hub when files are missing (same policy as
            # load_latentsync_unet), so a warm checkpoint dir works offline.
            if not (
                os.path.exists(unet_config_path)
                and os.path.exists(unet_model_path)
            ):
                logger.info("Downloading MuseTalk V1.5 models...")
                os.makedirs("checkpoints", exist_ok=True)
                snapshot_download(
                    repo_id="TMElyralab/MuseTalk",
                    local_dir="./checkpoints",
                    allow_patterns=["musetalkV15/*"],
                )

            logger.info("Loading MuseTalk V1.5 UNet...")
            with open(unet_config_path, "r", encoding="utf-8") as f:
                unet_config = json.load(f)
            unet = UNet2DConditionModel(**unet_config)
            # Load the state dict on CPU: the freshly constructed model lives on
            # the default (CPU) device, so loading straight to ``device`` would
            # only add a temporary extra copy of the weights on the GPU.
            weights = torch.load(
                unet_model_path, map_location="cpu", weights_only=True
            )
            unet.load_state_dict(weights)
            del weights
            unet = unet.to(dtype=torch.float16).to(device)

            from musetalk.models.unet import (
                PositionalEncoding as MuseTalkPositionalEncoding,
            )

            # NOTE(review): pe is never moved/cast here; confirm callers put it
            # on ``device`` in a dtype compatible with the fp16 UNet.
            pe = MuseTalkPositionalEncoding(d_model=384)
            self._musetalk_unet = unet
            self._musetalk_pe = pe
            logger.info("MuseTalk UNet loaded")
        return self._musetalk_unet, self._musetalk_pe

    def get_whisper_model_path(self, cross_attention_dim: int):
        """Map the UNet's ``cross_attention_dim`` to a Whisper model name.

        Returns ``"small"`` for 768 and ``"tiny"`` for 384.

        Raises:
            NotImplementedError: for any other dimension.
        """
        if cross_attention_dim == 768:
            return "small"
        elif cross_attention_dim == 384:
            return "tiny"
        else:
            raise NotImplementedError("cross_attention_dim must be 768 or 384")

    def preload_latentsync_models(self, device: str = "cuda"):
        """Eagerly load every LatentSync model (config, VAE, Whisper, UNet,
        scheduler) onto ``device``."""
        logger.info("Preloading LatentSync models...")
        config = self.get_latentsync_config()
        self.load_vae(device)
        self.load_whisper_encoder(
            self.get_whisper_model_path(config.model.cross_attention_dim),
            device,
            config.data.num_frames,
        )
        self.load_latentsync_unet(device)
        self.get_scheduler()
        logger.info("LatentSync models preloaded successfully")

    def preload_musetalk_models(self, device: str = "cuda"):
        """Eagerly load the MuseTalk UNet and positional encoding."""
        logger.info("Preloading MuseTalk models...")
        self.load_musetalk_unet(device)
        logger.info("MuseTalk models preloaded successfully")

    def clear_cache(self):
        """Drop every cached model/config and release cached GPU memory.

        Memory is only reclaimed for objects no longer referenced elsewhere;
        callers still holding a model keep it alive.
        """
        logger.info("Clearing model cache...")
        self._whisper_encoder = None
        self._vae = None
        self._latentsync_unet = None
        self._musetalk_unet = None
        self._scheduler = None
        self._latentsync_config = None
        self._musetalk_pe = None
        # Collect first so tensors kept alive only by reference cycles are
        # freed *before* the CUDA caching allocator returns blocks to the driver.
        gc.collect()
        torch.cuda.empty_cache()
        logger.info("Model cache cleared")