# LongCat-AudioDiT-Enhanced / memory_manager.py
"""
VRAM Memory Manager for LongCat-AudioDiT + Whisper.
Orchestrates loading and unloading of:
- AudioDiT TTS models (1B / 3.5B)
- Whisper STT models (turbo / large-v3)
Modes:
"auto" – probe available VRAM; keep both loaded if possible, else sequential
"simultaneous"– always keep both loaded (fails if VRAM too small)
"sequential" – always unload one before loading the other (safest for ≤12GB)
"""
import gc
import logging
from enum import Enum
from typing import Dict, Optional
import torch
logger = logging.getLogger(__name__)
# Estimated peak VRAM (GB) per model in fp16 / int8 on 1 GPU
AUDIODIT_VRAM = {
"1B": 4.0,
"3.5B": 10.0,
}
WHISPER_VRAM = {
"turbo": 1.6,
"large-v3": 3.0,
}
# Leave this headroom free for activations, KV-cache, OS
VRAM_HEADROOM_GB = 2.0
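# Worked example of the AUTO-mode arithmetic (estimates only; real usage
# varies with precision and sequence length): with ~11 GB free, "1B" +
# "turbo" needs 4.0 + 1.6 + 2.0 (headroom) = 7.6 GB, so both models can stay
# resident; "3.5B" + "large-v3" needs 10.0 + 3.0 + 2.0 = 15.0 GB, so AUTO
# falls back to sequential swapping.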
class LoadMode(str, Enum):
AUTO = "auto"
SIMULTANEOUS = "simultaneous"
SEQUENTIAL = "sequential"
def _available_vram_gb() -> float:
"""Return free VRAM in GB on the default CUDA device, or 0 if no GPU."""
if not torch.cuda.is_available():
return 0.0
free, _ = torch.cuda.mem_get_info()
return free / (1024 ** 3)
def _total_vram_gb() -> float:
    """Return total VRAM in GB on the default CUDA device, or 0 if no GPU."""
    if not torch.cuda.is_available():
        return 0.0
    _, total = torch.cuda.mem_get_info()
    return total / (1024 ** 3)
def _used_vram_gb() -> float:
    """Return VRAM (GB) held by this process's PyTorch allocator."""
    if not torch.cuda.is_available():
        return 0.0
    allocated = torch.cuda.memory_allocated()
    reserved = torch.cuda.memory_reserved()
    return max(allocated, reserved) / (1024 ** 3)
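# Note: _used_vram_gb() above only sees this process's PyTorch allocator.
# For a device-wide figure (including other processes), a helper like the
# following sketch could be used instead — hypothetical, not called below:
def _device_used_vram_gb() -> float:
    """Device-wide VRAM usage in GB, derived from total minus free."""
    if not torch.cuda.is_available():
        return 0.0
    free, total = torch.cuda.mem_get_info()
    return (total - free) / (1024 ** 3)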
class ModelMemoryManager:
"""
Coordinates AudioDiT + Whisper model lifecycle.
Typical usage::
mgr = ModelMemoryManager(mode="auto")
tts_model, tokenizer = mgr.get_tts(audiodit_size="1B", device="cuda")
# ... generate audio ...
whisper = mgr.get_whisper(whisper_size="turbo")
text, lang = whisper.transcribe("audio.wav")
mgr.release_all()
"""
def __init__(self, mode: str = "auto"):
self.mode = LoadMode(mode)
self._tts_model = None
self._tts_tokenizer = None
self._tts_size: Optional[str] = None
self._tts_device: Optional[str] = None
self._whisper: Optional[object] = None # WhisperHelper
self._whisper_size: Optional[str] = None
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def get_tts(self, audiodit_size: str = "1B", device: str = "cuda"):
"""
Return (AudioDiTModel, tokenizer), loading if necessary.
If mode is sequential and Whisper is loaded, Whisper is unloaded first.
"""
        same_config = self._tts_size == audiodit_size and self._tts_device == device
        if self._tts_model is not None and same_config:
            return self._tts_model, self._tts_tokenizer
# Need to load a (potentially different) TTS model
if self._tts_model is not None:
self._unload_tts()
# Sequential: unload Whisper first
if self._should_unload_whisper_for_tts(audiodit_size):
logger.info("Sequential mode: unloading Whisper before loading AudioDiT %s", audiodit_size)
self._unload_whisper()
self._load_tts(audiodit_size, device)
return self._tts_model, self._tts_tokenizer
def get_whisper(self, whisper_size: str = "turbo"):
"""
Return WhisperHelper, loading if necessary.
If mode is sequential and TTS is loaded, TTS is unloaded first.
"""
from whisper_helper import WhisperHelper
if self._whisper is not None and self._whisper_size == whisper_size:
return self._whisper
# Need to (re)load
if self._whisper is not None:
self._unload_whisper()
# Sequential: unload TTS first
if self._should_unload_tts_for_whisper(whisper_size):
logger.info("Sequential mode: unloading AudioDiT before loading Whisper %s", whisper_size)
self._unload_tts()
        # Reuse the last TTS device if known; otherwise prefer CUDA when available.
        device = self._tts_device or ("cuda" if torch.cuda.is_available() else "cpu")
self._whisper = WhisperHelper(model_size=whisper_size, device=device)
self._whisper.load()
self._whisper_size = whisper_size
return self._whisper
def release_tts(self):
"""Explicitly unload TTS model."""
self._unload_tts()
def release_whisper(self):
"""Explicitly unload Whisper model."""
self._unload_whisper()
def release_all(self):
"""Unload everything and free VRAM."""
self._unload_tts()
self._unload_whisper()
# ------------------------------------------------------------------
# Status helpers
# ------------------------------------------------------------------
def status(self) -> Dict:
tts_loaded = self._tts_model is not None
whisper_loaded = self._whisper is not None and getattr(self._whisper, "is_loaded", False)
return {
"mode": self.mode.value,
"tts_loaded": tts_loaded,
"tts_size": self._tts_size if tts_loaded else None,
"whisper_loaded": whisper_loaded,
"whisper_size": self._whisper_size if whisper_loaded else None,
"vram_used_gb": round(_used_vram_gb(), 2),
"vram_total_gb": round(_total_vram_gb(), 2),
"vram_free_gb": round(_available_vram_gb(), 2),
}
def status_str(self) -> str:
s = self.status()
lines = [
f"Mode: {s['mode']}",
f"TTS: {'[ON] ' + s['tts_size'] if s['tts_loaded'] else '[OFF] not loaded'}",
f"Whisper: {'[ON] ' + s['whisper_size'] if s['whisper_loaded'] else '[OFF] not loaded'}",
]
if torch.cuda.is_available():
lines.append(
f"VRAM: {s['vram_used_gb']:.1f} / {s['vram_total_gb']:.1f} GB "
f"({s['vram_free_gb']:.1f} GB free)"
)
return "\n".join(lines)
# ------------------------------------------------------------------
# Private helpers
# ------------------------------------------------------------------
def _should_unload_whisper_for_tts(self, audiodit_size: str) -> bool:
if self._whisper is None:
return False
if self.mode == LoadMode.SEQUENTIAL:
return True
if self.mode == LoadMode.SIMULTANEOUS:
return False
        # AUTO: would both models fit if Whisper's VRAM were reclaimed first?
        needed = AUDIODIT_VRAM.get(audiodit_size, 10.0) + WHISPER_VRAM.get(self._whisper_size, 3.0)
        # Credit back Whisper's estimated footprint, since unloading would free it.
        available = _available_vram_gb() + WHISPER_VRAM.get(self._whisper_size, 3.0)
        return needed + VRAM_HEADROOM_GB > available
def _should_unload_tts_for_whisper(self, whisper_size: str) -> bool:
if self._tts_model is None:
return False
if self.mode == LoadMode.SEQUENTIAL:
return True
if self.mode == LoadMode.SIMULTANEOUS:
return False
        # AUTO: would both models fit if the TTS model's VRAM were reclaimed first?
        needed = AUDIODIT_VRAM.get(self._tts_size, 10.0) + WHISPER_VRAM.get(whisper_size, 3.0)
        # Credit back the TTS model's estimated footprint, since unloading would free it.
        available = _available_vram_gb() + AUDIODIT_VRAM.get(self._tts_size, 10.0)
        return needed + VRAM_HEADROOM_GB > available
def _load_tts(self, audiodit_size: str, device: str):
import audiodit # noqa: F401 – registers AutoConfig / AutoModel
from audiodit import AudioDiTModel
from transformers import AutoTokenizer
from pathlib import Path
from safetensors import safe_open
# Prefer local model dir; fall back to HF Hub id
local_dir_map = {
"1B": Path(__file__).parent / "models" / "audiodit" / "1B",
"3.5B": Path(__file__).parent / "models" / "audiodit" / "3.5B",
}
hf_id_map = {
"1B": "meituan-longcat/LongCat-AudioDiT-1B",
"3.5B": "meituan-longcat/LongCat-AudioDiT-3.5B",
}
local_dir = local_dir_map.get(audiodit_size)
if local_dir and (local_dir / "config.json").exists():
model_id = str(local_dir)
safetensors_path = local_dir / "model.safetensors"
else:
model_id = hf_id_map.get(audiodit_size, audiodit_size)
safetensors_path = None
logger.info("Loading AudioDiT %s from %s on %s …", audiodit_size, model_id, device)
torch_device = torch.device(device)
model = AudioDiTModel.from_pretrained(model_id).to(torch_device)
# Transformers 5.x uses meta-device init which breaks weight_norm parameters
# in the VAE (weight_g stays zero → NaN output). Fix: reload VAE weights
# directly from safetensors, bypassing the meta-device path.
# When loading from HF Hub, find the cached safetensors file.
        if safetensors_path is None:
            try:
                from huggingface_hub import try_to_load_from_cache
                cached = try_to_load_from_cache(model_id, "model.safetensors")
                # try_to_load_from_cache can return a path string, None, or a
                # "cached non-existence" sentinel; accept only a real path.
                if isinstance(cached, str) and Path(cached).exists():
                    safetensors_path = Path(cached)
            except Exception:
                pass
if safetensors_path and Path(safetensors_path).exists():
logger.info("Reloading VAE weights from safetensors (meta-device fix) …")
vae_sd = {}
with safe_open(str(safetensors_path), framework="pt", device="cpu") as f:
for k in f.keys():
if k.startswith("vae."):
vae_sd[k[4:]] = f.get_tensor(k)
model.vae.load_state_dict(vae_sd, strict=True)
logger.info("VAE weights reloaded OK.")
else:
logger.warning("Could not find safetensors for VAE fix — output may be silence.")
model.vae.to_half()
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model.config.text_encoder_model)
self._tts_model = model
self._tts_tokenizer = tokenizer
self._tts_size = audiodit_size
self._tts_device = device
logger.info("AudioDiT %s loaded.", audiodit_size)
def _unload_tts(self):
if self._tts_model is None:
return
logger.info("Unloading AudioDiT %s …", self._tts_size)
del self._tts_model
del self._tts_tokenizer
self._tts_model = None
self._tts_tokenizer = None
        self._tts_size = None
        # _tts_device is left set so get_whisper() can reuse the last device.
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.info("AudioDiT unloaded.")
def _unload_whisper(self):
if self._whisper is None:
return
logger.info("Unloading Whisper %s …", self._whisper_size)
self._whisper.unload()
self._whisper = None
self._whisper_size = None
logger.info("Whisper unloaded.")