"""ModelManager: Lazy loading and caching for ML models"""
import gc
import logging
import os
import torch
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from huggingface_hub import snapshot_download
from omegaconf import OmegaConf
logger = logging.getLogger(__name__)


class ModelManager:
    """Process-wide singleton that lazily loads and caches ML models.

    Each ``load_*`` / ``get_*`` method builds its model on the first call and
    returns the cached object on every later call. ``clear_cache`` drops all
    cached references and asks PyTorch to release cached GPU memory.
    """

    # Shared singleton instance plus lazily populated model slots. The loaders
    # assign ``self._...``; since only one instance ever exists, that instance
    # attribute is the effective cache.
    _instance = None
    _whisper_encoder = None
    _vae = None
    _latentsync_unet = None
    _musetalk_unet = None
    _scheduler = None
    _latentsync_config = None
    _musetalk_pe = None

    def __new__(cls):
        # Always return the same object so every caller shares one cache.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @classmethod
    def get_instance(cls):
        """Return the shared ``ModelManager`` instance, creating it if needed."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def load_whisper_encoder(
        self, model_path: str, device: str = "cuda", num_frames: int = 12
    ):
        """Create the Whisper ``Audio2Feature`` encoder on first call and cache it.

        Args:
            model_path: Whisper model name/path forwarded to ``Audio2Feature``
                (e.g. the value returned by ``get_whisper_model_path``).
            device: Device passed to ``Audio2Feature``.
            num_frames: Frame count passed to ``Audio2Feature``.

        Returns:
            The cached ``Audio2Feature`` instance.

        Note:
            The cache is not keyed on the arguments. Once an encoder is loaded,
            later calls return it even if they pass a different
            ``model_path``, ``device`` or ``num_frames``.
        """
        if self._whisper_encoder is None:
            from latentsync.whisper.audio2feature import Audio2Feature
            from config import MODELS_DIR

            logger.info(f"Loading Whisper encoder from {model_path}...")
            self._whisper_encoder = Audio2Feature(
                model_path=model_path,
                device=device,
                num_frames=num_frames,
                download_root=f"{MODELS_DIR}/whisper",
            )
            logger.info("Whisper encoder loaded")
        return self._whisper_encoder

    def load_vae(self, device: str = "cuda"):
        """Load the ``stabilityai/sd-vae-ft-mse`` VAE in fp16 and cache it.

        Args:
            device: Device the VAE is moved to after loading.

        Returns:
            The cached ``AutoencoderKL`` instance.
        """
        if self._vae is None:
            from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL

            logger.info("Loading VAE...")
            from config import MODELS_DIR

            vae = AutoencoderKL.from_pretrained(
                "stabilityai/sd-vae-ft-mse",
                torch_dtype=torch.float16,
                cache_dir=MODELS_DIR,
            )
            # Pin the latent scaling constants explicitly instead of relying on
            # whatever the downloaded config contains.
            vae.config.scaling_factor = 0.18215
            vae.config.shift_factor = 0
            self._vae = vae.to(device)
            logger.info("VAE loaded")
        return self._vae

    def get_scheduler(self):
        """Load a ``DDIMScheduler`` from the relative ``configs`` directory and cache it."""
        if self._scheduler is None:
            logger.info("Loading DDIMScheduler...")
            self._scheduler = DDIMScheduler.from_pretrained("configs")
            logger.info("DDIMScheduler loaded")
        return self._scheduler

    def load_latentsync_unet(self, device: str = "cuda"):
        """Load the LatentSync UNet on first call and cache it.

        If ``{MODELS_DIR}/latentsync_unet.pt`` is missing, the
        ``ByteDance/LatentSync-1.6`` repository is first downloaded into
        ``MODELS_DIR``. The UNet is built on CPU from the LatentSync config,
        cast to fp16 and moved to ``device``. xformers attention is enabled
        when available.

        Args:
            device: Device the UNet is moved to.

        Returns:
            The cached ``UNet3DConditionModel``.
        """
        if self._latentsync_unet is None:
            from latentsync.models.unet import UNet3DConditionModel
            from config import MODELS_DIR

            unet_path = f"{MODELS_DIR}/latentsync_unet.pt"
            if not os.path.exists(unet_path):
                logger.info("Downloading LatentSync-1.6 models...")
                os.makedirs(MODELS_DIR, exist_ok=True)
                snapshot_download(
                    repo_id="ByteDance/LatentSync-1.6", local_dir=MODELS_DIR
                )

            logger.info("Loading LatentSync UNet...")
            config = self.get_latentsync_config()
            unet, _ = UNet3DConditionModel.from_pretrained(
                OmegaConf.to_container(config.model),
                unet_path,
                device="cpu",
            )
            unet = unet.to(dtype=torch.float16).to(device)

            from diffusers.utils.import_utils import is_xformers_available

            if is_xformers_available():
                unet.enable_xformers_memory_efficient_attention()

            self._latentsync_unet = unet
            logger.info("LatentSync UNet loaded")
        return self._latentsync_unet

    def get_latentsync_config(self):
        """Load ``configs/unet/stage2_512.yaml`` (relative to the CWD) with OmegaConf and cache it."""
        if self._latentsync_config is None:
            logger.info("Loading LatentSync config...")
            unet_config_path = "configs/unet/stage2_512.yaml"
            self._latentsync_config = OmegaConf.load(unet_config_path)
        return self._latentsync_config

    def load_musetalk_unet(self, device: str = "cuda"):
        """Load the MuseTalk V1.5 UNet and positional encoding on first call.

        The ``musetalkV15/*`` files of ``TMElyralab/MuseTalk`` are downloaded
        into ``./checkpoints`` only if the config or weights file is missing.

        Args:
            device: Device the UNet is moved to.

        Returns:
            A ``(unet, positional_encoding)`` tuple of the cached objects.
            The positional encoding is constructed but not moved to ``device``
            here; presumably the caller handles placement (verify).
        """
        if self._musetalk_unet is None:
            import json

            from diffusers import UNet2DConditionModel

            unet_config_path = "checkpoints/musetalkV15/musetalk.json"
            unet_model_path = "checkpoints/musetalkV15/unet.pth"

            # Mirror load_latentsync_unet: only hit the Hub when the
            # checkpoint is not already on disk.
            if not (
                os.path.exists(unet_config_path)
                and os.path.exists(unet_model_path)
            ):
                logger.info("Downloading MuseTalk V1.5 models...")
                os.makedirs("checkpoints", exist_ok=True)
                snapshot_download(
                    repo_id="TMElyralab/MuseTalk",
                    local_dir="./checkpoints",
                    allow_patterns=["musetalkV15/*"],
                )

            logger.info("Loading MuseTalk V1.5 UNet...")
            with open(unet_config_path, "r") as f:
                unet_config = json.load(f)
            unet = UNet2DConditionModel(**unet_config)

            # The freshly built model lives on CPU. Mapping the state dict
            # straight to ``device`` would only stage an extra full-precision
            # copy on the GPU before copying it back, so load on CPU instead.
            weights = torch.load(
                unet_model_path, map_location="cpu", weights_only=True
            )
            unet.load_state_dict(weights)
            del weights  # release the CPU copy before the fp16 cast/move
            unet = unet.to(dtype=torch.float16).to(device)

            from musetalk.models.unet import (
                PositionalEncoding as MuseTalkPositionalEncoding,
            )

            pe = MuseTalkPositionalEncoding(d_model=384)
            self._musetalk_unet = unet
            self._musetalk_pe = pe
            logger.info("MuseTalk UNet loaded")
        return self._musetalk_unet, self._musetalk_pe

    def get_whisper_model_path(self, cross_attention_dim: int):
        """Map a UNet ``cross_attention_dim`` to a Whisper model name.

        Returns:
            ``"small"`` for 768 and ``"tiny"`` for 384.

        Raises:
            NotImplementedError: For any other dimension.
        """
        if cross_attention_dim == 768:
            return "small"
        elif cross_attention_dim == 384:
            return "tiny"
        else:
            raise NotImplementedError("cross_attention_dim must be 768 or 384")

    def preload_latentsync_models(self):
        """Eagerly load the LatentSync config, VAE, Whisper encoder, UNet and scheduler."""
        logger.info("Preloading LatentSync models...")
        config = self.get_latentsync_config()
        self.load_vae()
        self.load_whisper_encoder(
            self.get_whisper_model_path(config.model.cross_attention_dim),
            "cuda",
            config.data.num_frames,
        )
        self.load_latentsync_unet()
        self.get_scheduler()
        logger.info("LatentSync models preloaded successfully")

    def preload_musetalk_models(self):
        """Eagerly load the MuseTalk UNet and positional encoding."""
        logger.info("Preloading MuseTalk models...")
        self.load_musetalk_unet()
        logger.info("MuseTalk models preloaded successfully")

    def clear_cache(self):
        """Drop every cached model and release cached GPU memory."""
        logger.info("Clearing model cache...")
        self._whisper_encoder = None
        self._vae = None
        self._latentsync_unet = None
        self._musetalk_unet = None
        self._scheduler = None
        self._latentsync_config = None
        self._musetalk_pe = None
        # Collect first so tensors kept alive only by reference cycles are
        # actually freed. Only then can empty_cache return their blocks.
        gc.collect()
        torch.cuda.empty_cache()
        logger.info("Model cache cleared")
|