| """ | |
| VRAM Memory Manager for LongCat-AudioDiT + Whisper. | |
| Orchestrates loading and unloading of: | |
| - AudioDiT TTS models (1B / 3.5B) | |
| - Whisper STT models (turbo / large-v3) | |
| Modes: | |
| "auto" – probe available VRAM; keep both loaded if possible, else sequential | |
| "simultaneous"– always keep both loaded (fails if VRAM too small) | |
| "sequential" – always unload one before loading the other (safest for ≤12GB) | |
| """ | |
import gc
import logging
from enum import Enum
from typing import Dict, Optional

import torch

logger = logging.getLogger(__name__)

# Estimated peak VRAM (GB) per model in fp16 / int8 on a single GPU
AUDIODIT_VRAM = {
    "1B": 4.0,
    "3.5B": 10.0,
}
WHISPER_VRAM = {
    "turbo": 1.6,
    "large-v3": 3.0,
}

# Leave this much headroom free for activations, KV-cache, and the OS
VRAM_HEADROOM_GB = 2.0
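
# A worked example of the budget these tables imply (using the estimates
# above): keeping AudioDiT 3.5B and Whisper large-v3 resident together needs
# roughly
#   AUDIODIT_VRAM["3.5B"] + WHISPER_VRAM["large-v3"] + VRAM_HEADROOM_GB
#   = 10.0 + 3.0 + 2.0 = 15.0 GB,
# so "simultaneous" mode for that pair wants a 16 GB+ card, while
# 1B + turbo (4.0 + 1.6 + 2.0 = 7.6 GB) fits on an 8 GB card.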


class LoadMode(str, Enum):
    AUTO = "auto"
    SIMULTANEOUS = "simultaneous"
    SEQUENTIAL = "sequential"
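
# Because LoadMode mixes in str, plain strings compare equal to members
# (LoadMode.SEQUENTIAL == "sequential" is True), and an unknown string passed
# to ModelMemoryManager(mode=...) raises ValueError at construction time.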


def _available_vram_gb() -> float:
    """Return free VRAM in GB on the default CUDA device, or 0 if no GPU."""
    if not torch.cuda.is_available():
        return 0.0
    free, _ = torch.cuda.mem_get_info()
    return free / (1024 ** 3)


def _total_vram_gb() -> float:
    """Return total VRAM in GB on the default CUDA device, or 0 if no GPU."""
    if not torch.cuda.is_available():
        return 0.0
    _, total = torch.cuda.mem_get_info()
    return total / (1024 ** 3)


def _used_vram_gb() -> float:
    """Return VRAM in GB that PyTorch currently holds (allocated or reserved)."""
    if not torch.cuda.is_available():
        return 0.0
    allocated = torch.cuda.memory_allocated()
    reserved = torch.cuda.memory_reserved()
    return max(allocated, reserved) / (1024 ** 3)
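
# Note: on a CPU-only machine all three probes return 0.0, so the AUTO-mode
# fit checks in ModelMemoryManager always fall back to sequential behavior.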


class ModelMemoryManager:
    """
    Coordinates the AudioDiT + Whisper model lifecycle.

    Typical usage::

        mgr = ModelMemoryManager(mode="auto")
        tts_model, tokenizer = mgr.get_tts(audiodit_size="1B", device="cuda")
        # ... generate audio ...
        whisper = mgr.get_whisper(whisper_size="turbo")
        text, lang = whisper.transcribe("audio.wav")
        mgr.release_all()
    """

    def __init__(self, mode: str = "auto"):
        self.mode = LoadMode(mode)
        self._tts_model = None
        self._tts_tokenizer = None
        self._tts_size: Optional[str] = None
        self._tts_device: Optional[str] = None
        self._whisper: Optional[object] = None  # WhisperHelper instance
        self._whisper_size: Optional[str] = None

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def get_tts(self, audiodit_size: str = "1B", device: str = "cuda"):
        """
        Return (AudioDiTModel, tokenizer), loading if necessary.

        In sequential mode (or auto mode with tight VRAM), a loaded Whisper
        model is unloaded first.
        """
        if self._tts_model is not None and self._tts_size == audiodit_size:
            return self._tts_model, self._tts_tokenizer

        # A different TTS model is resident: drop it before loading
        if self._tts_model is not None:
            self._unload_tts()

        # Sequential (or tight-VRAM auto): unload Whisper first
        if self._should_unload_whisper_for_tts(audiodit_size):
            logger.info("Unloading Whisper before loading AudioDiT %s (mode=%s)", audiodit_size, self.mode.value)
            self._unload_whisper()

        self._load_tts(audiodit_size, device)
        return self._tts_model, self._tts_tokenizer

    def get_whisper(self, whisper_size: str = "turbo"):
        """
        Return the WhisperHelper, loading if necessary.

        In sequential mode (or auto mode with tight VRAM), a loaded TTS model
        is unloaded first.
        """
        # Lazy import so Whisper dependencies are only pulled in when needed
        from whisper_helper import WhisperHelper

        if self._whisper is not None and self._whisper_size == whisper_size:
            return self._whisper

        # A different Whisper size is resident: drop it before loading
        if self._whisper is not None:
            self._unload_whisper()

        # Sequential (or tight-VRAM auto): unload TTS first
        if self._should_unload_tts_for_whisper(whisper_size):
            logger.info("Unloading AudioDiT before loading Whisper %s (mode=%s)", whisper_size, self.mode.value)
            self._unload_tts()

        device = self._tts_device or ("cuda" if torch.cuda.is_available() else "cpu")
        self._whisper = WhisperHelper(model_size=whisper_size, device=device)
        self._whisper.load()
        self._whisper_size = whisper_size
        return self._whisper

    def release_tts(self):
        """Explicitly unload the TTS model."""
        self._unload_tts()

    def release_whisper(self):
        """Explicitly unload the Whisper model."""
        self._unload_whisper()

    def release_all(self):
        """Unload everything and free VRAM."""
        self._unload_tts()
        self._unload_whisper()
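
    # Callers that load models per request may want an explicit try/finally
    # so VRAM is reclaimed even if generation raises; a usage sketch, not
    # part of the API:
    #
    #   mgr = ModelMemoryManager(mode="sequential")
    #   try:
    #       model, tokenizer = mgr.get_tts("1B")
    #       ...
    #   finally:
    #       mgr.release_all()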

    # ------------------------------------------------------------------
    # Status helpers
    # ------------------------------------------------------------------
    def status(self) -> Dict:
        """Return a machine-readable snapshot of loaded models and VRAM usage."""
        tts_loaded = self._tts_model is not None
        whisper_loaded = self._whisper is not None and getattr(self._whisper, "is_loaded", False)
        return {
            "mode": self.mode.value,
            "tts_loaded": tts_loaded,
            "tts_size": self._tts_size if tts_loaded else None,
            "whisper_loaded": whisper_loaded,
            "whisper_size": self._whisper_size if whisper_loaded else None,
            "vram_used_gb": round(_used_vram_gb(), 2),
            "vram_total_gb": round(_total_vram_gb(), 2),
            "vram_free_gb": round(_available_vram_gb(), 2),
        }

    def status_str(self) -> str:
        """Return a human-readable, multi-line status summary."""
        s = self.status()
        lines = [
            f"Mode: {s['mode']}",
            f"TTS: {'[ON] ' + s['tts_size'] if s['tts_loaded'] else '[OFF] not loaded'}",
            f"Whisper: {'[ON] ' + s['whisper_size'] if s['whisper_loaded'] else '[OFF] not loaded'}",
        ]
        if torch.cuda.is_available():
            lines.append(
                f"VRAM: {s['vram_used_gb']:.1f} / {s['vram_total_gb']:.1f} GB "
                f"({s['vram_free_gb']:.1f} GB free)"
            )
        return "\n".join(lines)

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _should_unload_whisper_for_tts(self, audiodit_size: str) -> bool:
        if self._whisper is None:
            return False
        if self.mode == LoadMode.SEQUENTIAL:
            return True
        if self.mode == LoadMode.SIMULTANEOUS:
            return False
        # AUTO: unload unless both models fit together. Whisper's estimated
        # footprint is added back to the free figure because it is already
        # resident and would be reclaimed by unloading it.
        needed = AUDIODIT_VRAM.get(audiodit_size, 10.0) + WHISPER_VRAM.get(self._whisper_size, 3.0)
        available = _available_vram_gb() + WHISPER_VRAM.get(self._whisper_size, 3.0)
        return needed + VRAM_HEADROOM_GB > available

    def _should_unload_tts_for_whisper(self, whisper_size: str) -> bool:
        if self._tts_model is None:
            return False
        if self.mode == LoadMode.SEQUENTIAL:
            return True
        if self.mode == LoadMode.SIMULTANEOUS:
            return False
        # AUTO: same fit check as above, with the resident TTS model's
        # footprint counted as reclaimable.
        needed = AUDIODIT_VRAM.get(self._tts_size, 10.0) + WHISPER_VRAM.get(whisper_size, 3.0)
        available = _available_vram_gb() + AUDIODIT_VRAM.get(self._tts_size, 10.0)
        return needed + VRAM_HEADROOM_GB > available

    def _load_tts(self, audiodit_size: str, device: str):
        import audiodit  # noqa: F401 – registers AutoConfig / AutoModel
        from audiodit import AudioDiTModel
        from transformers import AutoTokenizer
        from pathlib import Path
        from safetensors import safe_open

        # Prefer a local model dir; fall back to the HF Hub id
        local_dir_map = {
            "1B": Path(__file__).parent / "models" / "audiodit" / "1B",
            "3.5B": Path(__file__).parent / "models" / "audiodit" / "3.5B",
        }
        hf_id_map = {
            "1B": "meituan-longcat/LongCat-AudioDiT-1B",
            "3.5B": "meituan-longcat/LongCat-AudioDiT-3.5B",
        }
        local_dir = local_dir_map.get(audiodit_size)
        if local_dir and (local_dir / "config.json").exists():
            model_id = str(local_dir)
            safetensors_path = local_dir / "model.safetensors"
        else:
            model_id = hf_id_map.get(audiodit_size, audiodit_size)
            safetensors_path = None

        logger.info("Loading AudioDiT %s from %s on %s …", audiodit_size, model_id, device)
        torch_device = torch.device(device)
        model = AudioDiTModel.from_pretrained(model_id).to(torch_device)

        # Transformers 5.x uses meta-device init, which breaks weight_norm
        # parameters in the VAE (weight_g stays zero → NaN output). Fix:
        # reload the VAE weights directly from safetensors, bypassing the
        # meta-device path. When loading from the HF Hub, locate the cached
        # safetensors file first.
        if safetensors_path is None:
            try:
                from huggingface_hub import try_to_load_from_cache

                cached = try_to_load_from_cache(model_id, "model.safetensors")
                if cached and Path(cached).exists():
                    safetensors_path = Path(cached)
            except Exception:
                pass

        if safetensors_path and Path(safetensors_path).exists():
            logger.info("Reloading VAE weights from safetensors (meta-device fix) …")
            vae_sd = {}
            with safe_open(str(safetensors_path), framework="pt", device="cpu") as f:
                for k in f.keys():
                    if k.startswith("vae."):
                        vae_sd[k[4:]] = f.get_tensor(k)  # strip the "vae." prefix
            model.vae.load_state_dict(vae_sd, strict=True)
            logger.info("VAE weights reloaded OK.")
        else:
            logger.warning("Could not find safetensors for VAE fix — output may be silence.")

        model.vae.to_half()
        model.eval()
        tokenizer = AutoTokenizer.from_pretrained(model.config.text_encoder_model)

        self._tts_model = model
        self._tts_tokenizer = tokenizer
        self._tts_size = audiodit_size
        self._tts_device = device
        logger.info("AudioDiT %s loaded.", audiodit_size)

    def _unload_tts(self):
        if self._tts_model is None:
            return
        logger.info("Unloading AudioDiT %s …", self._tts_size)
        del self._tts_model
        del self._tts_tokenizer
        self._tts_model = None
        self._tts_tokenizer = None
        self._tts_size = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.info("AudioDiT unloaded.")

    def _unload_whisper(self):
        if self._whisper is None:
            return
        logger.info("Unloading Whisper %s …", self._whisper_size)
        self._whisper.unload()
        self._whisper = None
        self._whisper_size = None
        logger.info("Whisper unloaded.")