File size: 7,213 Bytes
98eeefd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d4381f
f5651ba
98eeefd
 
 
f5651ba
 
 
 
98eeefd
 
 
 
 
 
 
 
 
 
f5651ba
 
98eeefd
f5651ba
 
 
98eeefd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f5651ba
98eeefd
f5651ba
 
 
 
 
 
 
 
98eeefd
 
 
 
 
 
f5651ba
98eeefd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dec7f5a
 
a020bd3
 
 
 
 
 
 
 
98eeefd
 
 
 
 
dec7f5a
 
98eeefd
dec7f5a
 
 
98eeefd
dec7f5a
 
 
0d4381f
 
 
98eeefd
f3e388f
98eeefd
 
 
 
 
 
 
 
 
7f008b7
98eeefd
7f008b7
98eeefd
 
 
dc296d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98eeefd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
"""ModelManager: Lazy loading and caching for ML models"""

import gc
import logging
import os
import torch
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from huggingface_hub import snapshot_download
from omegaconf import OmegaConf

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


class ModelManager:
    """Process-wide singleton that lazily loads and caches ML models.

    Each ``load_*`` / ``get_*`` method builds its object on the first call
    and stores it on the manager; later calls return the stored object
    unchanged. As a consequence, arguments such as ``device`` or
    ``model_path`` only take effect on the call that actually performs the
    load. Call :meth:`clear_cache` to drop every cached object so the next
    call loads again.
    """

    # The single shared instance, created on first construction.
    _instance = None

    # Cached objects; ``None`` means "not loaded yet".
    _whisper_encoder = None
    _vae = None
    _latentsync_unet = None
    _musetalk_unet = None
    _scheduler = None
    _latentsync_config = None
    _musetalk_pe = None

    def __new__(cls):
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @classmethod
    def get_instance(cls) -> "ModelManager":
        """Return the shared instance (equivalent to calling the class)."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def load_whisper_encoder(
        self, model_path: str, device: str = "cuda", num_frames: int = 12
    ):
        """Return the Whisper audio encoder, loading it on first call.

        Args:
            model_path: Value forwarded to ``Audio2Feature`` as ``model_path``.
            device: Device forwarded to ``Audio2Feature``.
            num_frames: Frame count forwarded to ``Audio2Feature``.

        Returns:
            The cached ``Audio2Feature`` instance. Arguments are ignored when
            an encoder is already cached.
        """
        if self._whisper_encoder is None:
            # Project imports are deferred so importing this module stays cheap.
            from latentsync.whisper.audio2feature import Audio2Feature
            from config import MODELS_DIR

            logger.info("Loading Whisper encoder from %s...", model_path)
            self._whisper_encoder = Audio2Feature(
                model_path=model_path,
                device=device,
                num_frames=num_frames,
                download_root=f"{MODELS_DIR}/whisper",
            )
            logger.info("Whisper encoder loaded")
        return self._whisper_encoder

    def load_vae(self, device: str = "cuda"):
        """Return the ``stabilityai/sd-vae-ft-mse`` VAE, loading it on first call.

        The VAE is loaded in float16, its ``scaling_factor`` and
        ``shift_factor`` config values are overridden, and it is moved to
        ``device``.

        Args:
            device: Target device; ignored when a VAE is already cached.

        Returns:
            The cached ``AutoencoderKL`` instance.
        """
        if self._vae is None:
            from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
            from config import MODELS_DIR

            logger.info("Loading VAE...")
            vae = AutoencoderKL.from_pretrained(
                "stabilityai/sd-vae-ft-mse",
                torch_dtype=torch.float16,
                cache_dir=MODELS_DIR,
            )
            vae.config.scaling_factor = 0.18215
            vae.config.shift_factor = 0
            self._vae = vae.to(device)
            logger.info("VAE loaded")
        return self._vae

    def get_scheduler(self):
        """Return the ``DDIMScheduler``, loading it on first call.

        The scheduler is read from the relative path ``"configs"``, so the
        result depends on the process's current working directory.
        """
        if self._scheduler is None:
            logger.info("Loading DDIMScheduler...")
            self._scheduler = DDIMScheduler.from_pretrained("configs")
            logger.info("DDIMScheduler loaded")
        return self._scheduler

    def load_latentsync_unet(self, device: str = "cuda"):
        """Return the LatentSync UNet, downloading and loading it if needed.

        If ``latentsync_unet.pt`` is missing from ``MODELS_DIR`` the
        ``ByteDance/LatentSync-1.6`` repository is downloaded there first.
        The model is built on CPU, cast to float16 and moved to ``device``;
        xformers attention is enabled when xformers is available.

        Args:
            device: Target device; ignored when a UNet is already cached.

        Returns:
            The cached ``UNet3DConditionModel`` instance.
        """
        if self._latentsync_unet is None:
            from latentsync.models.unet import UNet3DConditionModel
            from config import MODELS_DIR

            unet_path = f"{MODELS_DIR}/latentsync_unet.pt"

            if not os.path.exists(unet_path):
                logger.info("Downloading LatentSync-1.6 models...")
                os.makedirs(MODELS_DIR, exist_ok=True)
                snapshot_download(
                    repo_id="ByteDance/LatentSync-1.6", local_dir=MODELS_DIR
                )

            logger.info("Loading LatentSync UNet...")
            config = self.get_latentsync_config()

            # from_pretrained returns a pair here; only the model is kept.
            unet, _ = UNet3DConditionModel.from_pretrained(
                OmegaConf.to_container(config.model),
                unet_path,
                device="cpu",
            )
            unet = unet.to(dtype=torch.float16).to(device)

            from diffusers.utils.import_utils import is_xformers_available

            if is_xformers_available():
                unet.enable_xformers_memory_efficient_attention()

            self._latentsync_unet = unet
            logger.info("LatentSync UNet loaded")
        return self._latentsync_unet

    def get_latentsync_config(self):
        """Return the LatentSync config, loading it on first call.

        The config is read with ``OmegaConf.load`` from the relative path
        ``configs/unet/stage2_512.yaml``.
        """
        if self._latentsync_config is None:
            logger.info("Loading LatentSync config...")
            self._latentsync_config = OmegaConf.load("configs/unet/stage2_512.yaml")
        return self._latentsync_config

    def load_musetalk_unet(self, device: str = "cuda"):
        """Return the MuseTalk V1.5 UNet and its positional encoding.

        The checkpoint files are downloaded from ``TMElyralab/MuseTalk`` into
        ``./checkpoints`` only when they are not already on disk. Weights are
        read onto the CPU, loaded into the model, and the model is then cast
        to float16 and moved to ``device``.

        Args:
            device: Target device; ignored when a UNet is already cached.

        Returns:
            A ``(unet, positional_encoding)`` tuple of the cached objects.
        """
        if self._musetalk_unet is None:
            import json
            from diffusers import UNet2DConditionModel

            unet_config_path = "checkpoints/musetalkV15/musetalk.json"
            unet_model_path = "checkpoints/musetalkV15/unet.pth"

            # Skip the network round-trip when both files are already
            # present, mirroring the check in load_latentsync_unet.
            if not (
                os.path.exists(unet_config_path)
                and os.path.exists(unet_model_path)
            ):
                logger.info("Downloading MuseTalk V1.5 models...")
                os.makedirs("checkpoints", exist_ok=True)
                snapshot_download(
                    repo_id="TMElyralab/MuseTalk",
                    local_dir="./checkpoints",
                    allow_patterns=["musetalkV15/*"],
                )

            logger.info("Loading MuseTalk V1.5 UNet...")
            with open(unet_config_path, "r", encoding="utf-8") as f:
                unet_config = json.load(f)

            unet = UNet2DConditionModel(**unet_config)
            # Read the weights onto the CPU: the freshly built model lives on
            # the CPU, so mapping them to the target device first would only
            # hold an extra full-precision copy there.
            weights = torch.load(
                unet_model_path, map_location="cpu", weights_only=True
            )
            unet.load_state_dict(weights)
            unet = unet.to(dtype=torch.float16).to(device)

            from musetalk.models.unet import (
                PositionalEncoding as MuseTalkPositionalEncoding,
            )

            # NOTE(review): the positional encoding is not moved to `device`
            # here; presumably callers handle placement - confirm.
            pe = MuseTalkPositionalEncoding(d_model=384)

            self._musetalk_unet = unet
            self._musetalk_pe = pe
            logger.info("MuseTalk UNet loaded")
        return self._musetalk_unet, self._musetalk_pe

    def get_whisper_model_path(self, cross_attention_dim: int) -> str:
        """Map a UNet ``cross_attention_dim`` to a Whisper model name.

        Args:
            cross_attention_dim: Either ``768`` or ``384``.

        Returns:
            ``"small"`` for 768, ``"tiny"`` for 384.

        Raises:
            NotImplementedError: For any other value.
        """
        if cross_attention_dim == 768:
            return "small"
        if cross_attention_dim == 384:
            return "tiny"
        raise NotImplementedError(
            f"cross_attention_dim must be 768 or 384, got {cross_attention_dim!r}"
        )

    def preload_latentsync_models(self, device: str = "cuda") -> None:
        """Load the LatentSync config, VAE, Whisper encoder, UNet and scheduler.

        Args:
            device: Device passed to each loader that accepts one.
        """
        logger.info("Preloading LatentSync models...")
        config = self.get_latentsync_config()
        self.load_vae(device)
        self.load_whisper_encoder(
            self.get_whisper_model_path(config.model.cross_attention_dim),
            device,
            config.data.num_frames,
        )
        self.load_latentsync_unet(device)
        self.get_scheduler()
        logger.info("LatentSync models preloaded successfully")

    def preload_musetalk_models(self, device: str = "cuda") -> None:
        """Load the MuseTalk UNet and positional encoding.

        Args:
            device: Device passed to :meth:`load_musetalk_unet`.
        """
        logger.info("Preloading MuseTalk models...")
        self.load_musetalk_unet(device)
        logger.info("MuseTalk models preloaded successfully")

    def clear_cache(self) -> None:
        """Drop every cached object and release cached GPU memory."""
        logger.info("Clearing model cache...")

        self._whisper_encoder = None
        self._vae = None
        self._latentsync_unet = None
        self._musetalk_unet = None
        self._scheduler = None
        self._latentsync_config = None
        self._musetalk_pe = None

        # Collect garbage first so memory held by the dropped objects is
        # actually freed before the CUDA caching allocator is asked to
        # release its unused blocks.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.info("Model cache cleared")