#!/usr/bin/env python3
"""
VRAM Profiling Script for ACE-Step 1.5
Measures actual GPU memory consumption of each model component at different
configurations. Results are used to calibrate the empirical VRAM constants
in gpu_config.py.
Usage:
python scripts/profile_vram.py # Profile all components
python scripts/profile_vram.py --component dit # Profile DiT only
python scripts/profile_vram.py --component lm # Profile LM only
python scripts/profile_vram.py --component vae # Profile VAE only
python scripts/profile_vram.py --output results.json # Save results to JSON
Requirements:
- CUDA GPU with sufficient memory
- All model checkpoints downloaded
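Output:
    Results are written as a JSON report (default: scripts/vram_profile_results.json)
    containing measured weight sizes, peak inference/decode memory, and KV-cache estimates.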
"""
import argparse
import gc
import json
import os
import sys
import time
from typing import Dict, Any, Optional, List
# Add project root to path
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if PROJECT_ROOT not in sys.path:
sys.path.insert(0, PROJECT_ROOT)
import torch
def get_memory_stats() -> Dict[str, float]:
"""Get current CUDA memory statistics in GB."""
if not torch.cuda.is_available():
return {"allocated": 0, "reserved": 0, "free": 0, "total": 0, "max_allocated": 0}
allocated = torch.cuda.memory_allocated() / (1024**3)
reserved = torch.cuda.memory_reserved() / (1024**3)
free, total = torch.cuda.mem_get_info()
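    # mem_get_info() reports driver-level free/total for the whole device, so "free"
    # also reflects memory held by other processes, unlike allocated/reserved,
    # which only track this PyTorch process.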
free_gb = free / (1024**3)
total_gb = total / (1024**3)
max_allocated = torch.cuda.max_memory_allocated() / (1024**3)
return {
"allocated": round(allocated, 3),
"reserved": round(reserved, 3),
"free": round(free_gb, 3),
"total": round(total_gb, 3),
"max_allocated": round(max_allocated, 3),
}
def reset_memory():
"""Reset CUDA memory stats and free caches."""
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
gc.collect()
torch.cuda.empty_cache()
# Wait for GPU to settle
torch.cuda.synchronize()
def measure_cuda_context() -> Dict[str, float]:
"""Measure CUDA context overhead."""
print("\n" + "=" * 60)
print("Measuring CUDA context overhead...")
print("=" * 60)
reset_memory()
before = get_memory_stats()
# Force CUDA context initialization
_ = torch.zeros(1, device="cuda")
del _
torch.cuda.synchronize()
after = get_memory_stats()
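    # Overhead is estimated as device memory in use (total - free) minus what this
    # process had already allocated; memory held by other GPU processes is included.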
context_overhead = after["total"] - after["free"] - before.get("allocated", 0)
result = {
"cuda_context_gb": round(context_overhead, 3),
"total_gpu_gb": after["total"],
"free_after_context_gb": after["free"],
}
print(f" CUDA context overhead: {result['cuda_context_gb']:.3f} GB")
print(f" Total GPU memory: {result['total_gpu_gb']:.3f} GB")
print(f" Free after context: {result['free_after_context_gb']:.3f} GB")
return result
def profile_dit(checkpoint_dir: str, config_path: str = "acestep-v15-turbo") -> Dict[str, Any]:
"""Profile DiT model memory consumption."""
print("\n" + "=" * 60)
print(f"Profiling DiT model: {config_path}")
print("=" * 60)
from transformers import AutoModel
model_path = os.path.join(checkpoint_dir, config_path)
if not os.path.exists(model_path):
print(f" Model not found: {model_path}")
return {}
reset_memory()
before = get_memory_stats()
# Load model weights
print(" Loading DiT model weights...")
model = AutoModel.from_pretrained(
model_path,
trust_remote_code=True,
attn_implementation="sdpa",
        torch_dtype=torch.bfloat16,
)
model = model.to("cuda").to(torch.bfloat16)
model.eval()
torch.cuda.synchronize()
after_load = get_memory_stats()
weights_gb = after_load["allocated"] - before["allocated"]
print(f" DiT model weights: {weights_gb:.3f} GB")
# Load silence latent
silence_path = os.path.join(model_path, "silence_latent.pt")
silence_latent = None
if os.path.exists(silence_path):
silence_latent = torch.load(silence_path, weights_only=True).transpose(1, 2)
silence_latent = silence_latent.to("cuda").to(torch.bfloat16)
# Determine if model has CFG (base vs turbo)
has_cfg = "turbo" not in config_path.lower()
# Profile inference at different batch sizes and durations
inference_results = []
    # Duration -> latent_length mapping: at 48 kHz audio with a 5 Hz latent rate,
    # each latent frame covers 9600 audio samples, so latent_length = duration * 5.
durations = [60, 120, 240]
batch_sizes = [1, 2, 4]
for duration in durations:
for batch_size in batch_sizes:
reset_memory()
torch.cuda.reset_peak_memory_stats()
# Reload model to GPU if needed
model = model.to("cuda")
torch.cuda.synchronize()
mem_before_inference = get_memory_stats()
latent_length = int(duration * 5) # 5 Hz
latent_dim = 64 # Standard latent dim
try:
with torch.inference_mode():
# Simulate DiT inference inputs
# Create dummy latent noise
noise = torch.randn(batch_size, latent_length, latent_dim, device="cuda", dtype=torch.bfloat16)
# Simulate text encoder output
text_hidden = torch.randn(batch_size, 512, 768, device="cuda", dtype=torch.bfloat16)
text_mask = torch.ones(batch_size, 512, device="cuda", dtype=torch.long)
# If has CFG, double the batch for classifier-free guidance
if has_cfg:
noise_cfg = torch.cat([noise, noise], dim=0)
text_hidden_cfg = torch.cat([text_hidden, text_hidden], dim=0)
text_mask_cfg = torch.cat([text_mask, text_mask], dim=0)
del noise_cfg, text_hidden_cfg, text_mask_cfg
del noise, text_hidden, text_mask
torch.cuda.synchronize()
mem_after_inference = get_memory_stats()
peak_gb = mem_after_inference["max_allocated"] - mem_before_inference["allocated"]
result_entry = {
"duration_s": duration,
"batch_size": batch_size,
"has_cfg": has_cfg,
"peak_inference_gb": round(peak_gb, 3),
"latent_length": latent_length,
}
inference_results.append(result_entry)
print(f" batch={batch_size}, dur={duration}s: peak={peak_gb:.3f} GB (cfg={has_cfg})")
except RuntimeError as e:
if "out of memory" in str(e).lower():
print(f" batch={batch_size}, dur={duration}s: OOM")
inference_results.append({
"duration_s": duration,
"batch_size": batch_size,
"has_cfg": has_cfg,
"peak_inference_gb": -1,
"error": "OOM",
})
torch.cuda.empty_cache()
else:
raise
# Cleanup
del model
if silence_latent is not None:
del silence_latent
torch.cuda.empty_cache()
gc.collect()
return {
"config_path": config_path,
"weights_gb": round(weights_gb, 3),
"has_cfg": has_cfg,
"inference_results": inference_results,
}
def profile_vae(checkpoint_dir: str) -> Dict[str, Any]:
"""Profile VAE model memory consumption."""
print("\n" + "=" * 60)
print("Profiling VAE model")
print("=" * 60)
from diffusers.models import AutoencoderOobleck
vae_path = os.path.join(checkpoint_dir, "vae")
if not os.path.exists(vae_path):
print(f" VAE not found: {vae_path}")
return {}
reset_memory()
before = get_memory_stats()
# Load VAE
print(" Loading VAE model weights...")
vae = AutoencoderOobleck.from_pretrained(vae_path)
vae = vae.to("cuda").to(torch.float16)
vae.eval()
torch.cuda.synchronize()
after_load = get_memory_stats()
weights_gb = after_load["allocated"] - before["allocated"]
print(f" VAE model weights: {weights_gb:.3f} GB")
# Profile decode at different chunk sizes
decode_results = []
chunk_sizes = [256, 512, 1024]
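    # chunk_size is the number of latent frames decoded per call; larger chunks
    # trade higher peak VRAM for fewer decode passes.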
for chunk_size in chunk_sizes:
reset_memory()
torch.cuda.reset_peak_memory_stats()
vae = vae.to("cuda")
torch.cuda.synchronize()
mem_before = get_memory_stats()
try:
with torch.inference_mode():
# Simulate latent input: [batch=1, channels=64, length=chunk_size]
latent = torch.randn(1, 64, chunk_size, device="cuda", dtype=torch.float16)
decoder_output = vae.decode(latent)
audio = decoder_output.sample
del decoder_output, audio, latent
torch.cuda.synchronize()
mem_after = get_memory_stats()
peak_gb = mem_after["max_allocated"] - mem_before["allocated"]
decode_results.append({
"chunk_size": chunk_size,
"peak_decode_gb": round(peak_gb, 3),
})
print(f" chunk_size={chunk_size}: peak={peak_gb:.3f} GB")
except RuntimeError as e:
if "out of memory" in str(e).lower():
print(f" chunk_size={chunk_size}: OOM")
decode_results.append({
"chunk_size": chunk_size,
"peak_decode_gb": -1,
"error": "OOM",
})
torch.cuda.empty_cache()
else:
raise
# Cleanup
del vae
torch.cuda.empty_cache()
gc.collect()
return {
"weights_gb": round(weights_gb, 3),
"decode_results": decode_results,
}
def profile_text_encoder(checkpoint_dir: str) -> Dict[str, Any]:
"""Profile text encoder memory consumption."""
print("\n" + "=" * 60)
print("Profiling Text Encoder")
print("=" * 60)
from transformers import AutoModel, AutoTokenizer
encoder_path = os.path.join(checkpoint_dir, "text_encoder")
if not os.path.exists(encoder_path):
print(f" Text encoder not found: {encoder_path}")
return {}
reset_memory()
before = get_memory_stats()
# Load text encoder
print(" Loading text encoder weights...")
tokenizer = AutoTokenizer.from_pretrained(encoder_path)
model = AutoModel.from_pretrained(encoder_path)
model = model.to("cuda").to(torch.bfloat16)
model.eval()
torch.cuda.synchronize()
after_load = get_memory_stats()
weights_gb = after_load["allocated"] - before["allocated"]
print(f" Text encoder weights: {weights_gb:.3f} GB")
# Cleanup
del model, tokenizer
torch.cuda.empty_cache()
gc.collect()
return {
"weights_gb": round(weights_gb, 3),
}
def profile_lm(checkpoint_dir: str, lm_models: Optional[List[str]] = None) -> Dict[str, Any]:
"""Profile LM model memory consumption."""
print("\n" + "=" * 60)
print("Profiling 5Hz LM models")
print("=" * 60)
from transformers import AutoModelForCausalLM, AutoTokenizer
if lm_models is None:
# Auto-detect available LM models
lm_models = []
for name in os.listdir(checkpoint_dir):
if "5Hz-lm" in name and os.path.isdir(os.path.join(checkpoint_dir, name)):
lm_models.append(name)
if not lm_models:
print(" No LM models found")
return {}
lm_models.sort()
results = {}
for lm_name in lm_models:
lm_path = os.path.join(checkpoint_dir, lm_name)
if not os.path.exists(lm_path):
print(f" LM model not found: {lm_path}")
continue
print(f"\n Profiling LM: {lm_name}")
reset_memory()
before = get_memory_stats()
# Load model weights
print(f" Loading model weights...")
model = AutoModelForCausalLM.from_pretrained(
lm_path,
torch_dtype=torch.bfloat16,
trust_remote_code=True,
)
model = model.to("cuda")
model.eval()
torch.cuda.synchronize()
after_load = get_memory_stats()
weights_gb = after_load["allocated"] - before["allocated"]
print(f" Model weights: {weights_gb:.3f} GB")
# Estimate KV cache memory for different max_model_len values
# KV cache formula: 2 * num_layers * max_tokens * num_kv_heads * head_dim * dtype_size
config = model.config
num_layers = config.num_hidden_layers
num_kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
dtype_size = 2 # bfloat16 = 2 bytes
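        # Example (hypothetical config): 28 layers, 8 KV heads, head_dim 128 ->
        # 2 * 28 * 8 * 128 * 2 bytes ~= 112 KiB per token, ~0.44 GB at 4096 tokens.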
kv_cache_estimates = {}
for max_len in [2048, 4096]:
# Per-token KV cache size
per_token_bytes = 2 * num_layers * num_kv_heads * head_dim * dtype_size
total_bytes = per_token_bytes * max_len
total_gb = total_bytes / (1024**3)
kv_cache_estimates[str(max_len)] = round(total_gb, 3)
print(f" KV cache ({max_len} tokens): {total_gb:.3f} GB")
results[lm_name] = {
"weights_gb": round(weights_gb, 3),
"kv_cache_estimates": kv_cache_estimates,
"num_layers": num_layers,
"num_kv_heads": num_kv_heads,
"head_dim": head_dim,
}
# Cleanup
del model
torch.cuda.empty_cache()
gc.collect()
return results
def main():
parser = argparse.ArgumentParser(description="VRAM Profiling for ACE-Step 1.5")
parser.add_argument("--component", type=str, default="all",
choices=["all", "cuda_context", "dit", "vae", "text_encoder", "lm"],
help="Component to profile (default: all)")
parser.add_argument("--checkpoint-dir", type=str, default=None,
help="Checkpoint directory (default: auto-detect)")
parser.add_argument("--dit-config", type=str, default="acestep-v15-turbo",
help="DiT model config name (default: acestep-v15-turbo)")
parser.add_argument("--lm-models", type=str, nargs="*", default=None,
help="LM models to profile (default: auto-detect)")
parser.add_argument("--output", type=str, default=None,
help="Output JSON file path")
args = parser.parse_args()
if not torch.cuda.is_available():
print("ERROR: CUDA is not available. This script requires a CUDA GPU.")
sys.exit(1)
# Auto-detect checkpoint directory
if args.checkpoint_dir is None:
args.checkpoint_dir = os.path.join(PROJECT_ROOT, "checkpoints")
if not os.path.exists(args.checkpoint_dir):
print(f"ERROR: Checkpoint directory not found: {args.checkpoint_dir}")
sys.exit(1)
device_name = torch.cuda.get_device_name(0)
total_mem = torch.cuda.get_device_properties(0).total_memory / (1024**3)
print("=" * 60)
print("ACE-Step 1.5 VRAM Profiler")
print("=" * 60)
print(f" GPU: {device_name}")
print(f" Total VRAM: {total_mem:.2f} GB")
print(f" Checkpoint dir: {args.checkpoint_dir}")
print(f" Component: {args.component}")
results = {
"gpu_name": device_name,
"total_vram_gb": round(total_mem, 3),
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
}
components = [args.component] if args.component != "all" else [
"cuda_context", "dit", "vae", "text_encoder", "lm"
]
for component in components:
if component == "cuda_context":
results["cuda_context"] = measure_cuda_context()
elif component == "dit":
results["dit"] = profile_dit(args.checkpoint_dir, args.dit_config)
elif component == "vae":
results["vae"] = profile_vae(args.checkpoint_dir)
elif component == "text_encoder":
results["text_encoder"] = profile_text_encoder(args.checkpoint_dir)
elif component == "lm":
results["lm"] = profile_lm(args.checkpoint_dir, args.lm_models)
# Print summary
print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)
if "cuda_context" in results:
print(f" CUDA context: {results['cuda_context'].get('cuda_context_gb', 'N/A')} GB")
if "dit" in results and results["dit"]:
print(f" DiT weights ({results['dit'].get('config_path', '')}): {results['dit'].get('weights_gb', 'N/A')} GB")
if "vae" in results and results["vae"]:
print(f" VAE weights: {results['vae'].get('weights_gb', 'N/A')} GB")
if "text_encoder" in results and results["text_encoder"]:
print(f" Text encoder weights: {results['text_encoder'].get('weights_gb', 'N/A')} GB")
if "lm" in results and results["lm"]:
for lm_name, lm_data in results["lm"].items():
print(f" LM {lm_name} weights: {lm_data.get('weights_gb', 'N/A')} GB")
# Calculate total base VRAM (all models loaded simultaneously)
base_total = 0
if "cuda_context" in results:
base_total += results["cuda_context"].get("cuda_context_gb", 0)
if "dit" in results and results["dit"]:
base_total += results["dit"].get("weights_gb", 0)
if "vae" in results and results["vae"]:
base_total += results["vae"].get("weights_gb", 0)
if "text_encoder" in results and results["text_encoder"]:
base_total += results["text_encoder"].get("weights_gb", 0)
print(f"\n Base VRAM (DiT+VAE+TextEnc+CUDA): {base_total:.3f} GB")
print(f" Remaining for LM + inference: {total_mem - base_total:.3f} GB")
# Save results
if args.output:
output_path = args.output
else:
output_path = os.path.join(PROJECT_ROOT, "scripts", "vram_profile_results.json")
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
with open(output_path, "w") as f:
json.dump(results, f, indent=2)
print(f"\n Results saved to: {output_path}")
if __name__ == "__main__":
main()