#!/usr/bin/env python3
"""
VRAM Profiling Script for ACE-Step 1.5

Measures actual GPU memory consumption of each model component at different
configurations. Results are used to calibrate the empirical VRAM constants
in gpu_config.py.

Usage:
    python scripts/profile_vram.py                        # Profile all components
    python scripts/profile_vram.py --component dit        # Profile DiT only
    python scripts/profile_vram.py --component lm         # Profile LM only
    python scripts/profile_vram.py --component vae        # Profile VAE only
    python scripts/profile_vram.py --output results.json  # Save results to JSON

Requirements:
    - CUDA GPU with sufficient memory
    - All model checkpoints downloaded
"""
import argparse
import gc
import json
import os
import sys
import time
from typing import Any, Dict, List, Optional

# Add project root to path
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

import torch


def get_memory_stats() -> Dict[str, float]:
    """Get current CUDA memory statistics in GB."""
    if not torch.cuda.is_available():
        return {"allocated": 0, "reserved": 0, "free": 0, "total": 0, "max_allocated": 0}
    allocated = torch.cuda.memory_allocated() / (1024**3)
    reserved = torch.cuda.memory_reserved() / (1024**3)
    free, total = torch.cuda.mem_get_info()
    free_gb = free / (1024**3)
    total_gb = total / (1024**3)
    max_allocated = torch.cuda.max_memory_allocated() / (1024**3)
    return {
        "allocated": round(allocated, 3),
        "reserved": round(reserved, 3),
        "free": round(free_gb, 3),
        "total": round(total_gb, 3),
        "max_allocated": round(max_allocated, 3),
    }
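
# Reading these stats: "allocated" counts live tensors, "reserved" also counts
# blocks PyTorch's caching allocator keeps for reuse (so reserved >= allocated),
# and "free"/"total" come from the driver, so they reflect every process on the
# GPU. E.g., a hypothetical reading of {allocated: 4.1, reserved: 4.5,
# free: 18.2, total: 24.0} means ~0.4 GB is cached-but-unused inside PyTorch
# and ~1.3 GB is held by the CUDA context and/or other processes.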


def reset_memory():
    """Reset CUDA memory stats and free caches."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        gc.collect()
        torch.cuda.empty_cache()
        # Wait for pending GPU work to finish before measuring
        torch.cuda.synchronize()


def measure_cuda_context() -> Dict[str, float]:
    """Measure CUDA context overhead."""
    print("\n" + "=" * 60)
    print("Measuring CUDA context overhead...")
    print("=" * 60)
    reset_memory()
    before = get_memory_stats()
    # Force CUDA context initialization
    _ = torch.zeros(1, device="cuda")
    del _
    torch.cuda.synchronize()
    after = get_memory_stats()
    # Driver-reported usage (total - free) minus PyTorch's own allocations
    # approximates the context. This also counts memory held by other
    # processes sharing the GPU, so run on an otherwise idle device.
    context_overhead = after["total"] - after["free"] - before.get("allocated", 0)
    result = {
        "cuda_context_gb": round(context_overhead, 3),
        "total_gpu_gb": after["total"],
        "free_after_context_gb": after["free"],
    }
    print(f"  CUDA context overhead: {result['cuda_context_gb']:.3f} GB")
    print(f"  Total GPU memory: {result['total_gpu_gb']:.3f} GB")
    print(f"  Free after context: {result['free_after_context_gb']:.3f} GB")
    return result
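
# Worked example for measure_cuda_context (hypothetical numbers): on a 24.0 GB
# card reporting 23.2 GB free right after context creation, with ~0 GB
# allocated by PyTorch beforehand, the overhead is 24.0 - 23.2 - 0 = 0.8 GB.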


def profile_dit(checkpoint_dir: str, config_path: str = "acestep-v15-turbo") -> Dict[str, Any]:
    """Profile DiT model memory consumption."""
    print("\n" + "=" * 60)
    print(f"Profiling DiT model: {config_path}")
    print("=" * 60)
    from transformers import AutoModel

    model_path = os.path.join(checkpoint_dir, config_path)
    if not os.path.exists(model_path):
        print(f"  Model not found: {model_path}")
        return {}
    reset_memory()
    before = get_memory_stats()
    # Load model weights
    print("  Loading DiT model weights...")
    model = AutoModel.from_pretrained(
        model_path,
        trust_remote_code=True,
        attn_implementation="sdpa",
        torch_dtype=torch.bfloat16,
    )
    model = model.to("cuda").to(torch.bfloat16)
    model.eval()
    torch.cuda.synchronize()
    after_load = get_memory_stats()
    weights_gb = after_load["allocated"] - before["allocated"]
    print(f"  DiT model weights: {weights_gb:.3f} GB")
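    # Sanity check (rule of thumb, not a measurement): bf16 weights take
    # ~2 bytes per parameter, so weights_gb / 2 approximates the parameter
    # count in billions -- e.g., a hypothetical 4.8 GB reading implies a
    # ~2.4B-parameter model.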

    # Load silence latent
    silence_path = os.path.join(model_path, "silence_latent.pt")
    silence_latent = None
    if os.path.exists(silence_path):
        silence_latent = torch.load(silence_path, weights_only=True).transpose(1, 2)
        silence_latent = silence_latent.to("cuda").to(torch.bfloat16)

    # Determine whether the model uses CFG (base models do, turbo models don't)
    has_cfg = "turbo" not in config_path.lower()

    # Profile inference at different batch sizes and durations.
    # For 5 Hz models, latent_length = duration * 5; at 48 kHz audio that is
    # 48000 / 5 = 9600 audio samples per latent frame.
    inference_results = []
    durations = [60, 120, 240]
    batch_sizes = [1, 2, 4]
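    # For the profiled settings this gives, e.g.:
    #   60 s  ->  60 * 5 =  300 latent frames ( 300 * 9600 =  2,880,000 samples)
    #   240 s -> 240 * 5 = 1200 latent frames (1200 * 9600 = 11,520,000 samples)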
    for duration in durations:
        for batch_size in batch_sizes:
            reset_memory()
            torch.cuda.reset_peak_memory_stats()
            # Reload model to GPU if needed
            model = model.to("cuda")
            torch.cuda.synchronize()
            mem_before_inference = get_memory_stats()
            latent_length = int(duration * 5)  # 5 Hz latent rate
            latent_dim = 64  # Standard latent dim
            try:
                with torch.inference_mode():
                    # Allocate the DiT inference inputs. Note: this measures
                    # input-buffer allocation only; no forward pass is run,
                    # so activation memory is not captured here.
                    noise = torch.randn(batch_size, latent_length, latent_dim, device="cuda", dtype=torch.bfloat16)
                    # Simulated text encoder output
                    text_hidden = torch.randn(batch_size, 512, 768, device="cuda", dtype=torch.bfloat16)
                    text_mask = torch.ones(batch_size, 512, device="cuda", dtype=torch.long)
                    # Classifier-free guidance doubles the effective batch
                    # (conditional + unconditional), so batch=2 with CFG
                    # allocates inputs for 4 sequences.
                    if has_cfg:
                        noise_cfg = torch.cat([noise, noise], dim=0)
                        text_hidden_cfg = torch.cat([text_hidden, text_hidden], dim=0)
                        text_mask_cfg = torch.cat([text_mask, text_mask], dim=0)
                        del noise_cfg, text_hidden_cfg, text_mask_cfg
                    del noise, text_hidden, text_mask
                torch.cuda.synchronize()
                mem_after_inference = get_memory_stats()
                peak_gb = mem_after_inference["max_allocated"] - mem_before_inference["allocated"]
                result_entry = {
                    "duration_s": duration,
                    "batch_size": batch_size,
                    "has_cfg": has_cfg,
                    "peak_inference_gb": round(peak_gb, 3),
                    "latent_length": latent_length,
                }
                inference_results.append(result_entry)
                print(f"  batch={batch_size}, dur={duration}s: peak={peak_gb:.3f} GB (cfg={has_cfg})")
            except RuntimeError as e:
                if "out of memory" in str(e).lower():
                    print(f"  batch={batch_size}, dur={duration}s: OOM")
                    inference_results.append({
                        "duration_s": duration,
                        "batch_size": batch_size,
                        "has_cfg": has_cfg,
                        "peak_inference_gb": -1,
                        "error": "OOM",
                    })
                    torch.cuda.empty_cache()
                else:
                    raise

    # Cleanup
    del model
    if silence_latent is not None:
        del silence_latent
    torch.cuda.empty_cache()
    gc.collect()

    return {
        "config_path": config_path,
        "weights_gb": round(weights_gb, 3),
        "has_cfg": has_cfg,
        "inference_results": inference_results,
    }


def profile_vae(checkpoint_dir: str) -> Dict[str, Any]:
    """Profile VAE model memory consumption."""
    print("\n" + "=" * 60)
    print("Profiling VAE model")
    print("=" * 60)
    from diffusers.models import AutoencoderOobleck

    vae_path = os.path.join(checkpoint_dir, "vae")
    if not os.path.exists(vae_path):
        print(f"  VAE not found: {vae_path}")
        return {}
    reset_memory()
    before = get_memory_stats()
    # Load VAE
    print("  Loading VAE model weights...")
    vae = AutoencoderOobleck.from_pretrained(vae_path)
    vae = vae.to("cuda").to(torch.float16)
    vae.eval()
    torch.cuda.synchronize()
    after_load = get_memory_stats()
    weights_gb = after_load["allocated"] - before["allocated"]
    print(f"  VAE model weights: {weights_gb:.3f} GB")

    # Profile decode at different chunk sizes
    decode_results = []
    chunk_sizes = [256, 512, 1024]
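    # For reference, at the 5 Hz latent rate a chunk of 512 frames covers
    # 512 / 5 = 102.4 s of audio; assuming stereo 48 kHz output (9600 samples
    # per frame), the decoded fp16 waveform alone is roughly
    # 512 * 9600 * 2 channels * 2 bytes ~= 18.75 MiB.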
    for chunk_size in chunk_sizes:
        reset_memory()
        torch.cuda.reset_peak_memory_stats()
        vae = vae.to("cuda")
        torch.cuda.synchronize()
        mem_before = get_memory_stats()
        try:
            with torch.inference_mode():
                # Simulate latent input: [batch=1, channels=64, length=chunk_size]
                latent = torch.randn(1, 64, chunk_size, device="cuda", dtype=torch.float16)
                decoder_output = vae.decode(latent)
                audio = decoder_output.sample
                del decoder_output, audio, latent
            torch.cuda.synchronize()
            mem_after = get_memory_stats()
            peak_gb = mem_after["max_allocated"] - mem_before["allocated"]
            decode_results.append({
                "chunk_size": chunk_size,
                "peak_decode_gb": round(peak_gb, 3),
            })
            print(f"  chunk_size={chunk_size}: peak={peak_gb:.3f} GB")
        except RuntimeError as e:
            if "out of memory" in str(e).lower():
                print(f"  chunk_size={chunk_size}: OOM")
                decode_results.append({
                    "chunk_size": chunk_size,
                    "peak_decode_gb": -1,
                    "error": "OOM",
                })
                torch.cuda.empty_cache()
            else:
                raise

    # Cleanup
    del vae
    torch.cuda.empty_cache()
    gc.collect()

    return {
        "weights_gb": round(weights_gb, 3),
        "decode_results": decode_results,
    }


def profile_text_encoder(checkpoint_dir: str) -> Dict[str, Any]:
    """Profile text encoder memory consumption."""
    print("\n" + "=" * 60)
    print("Profiling Text Encoder")
    print("=" * 60)
    from transformers import AutoModel, AutoTokenizer

    encoder_path = os.path.join(checkpoint_dir, "text_encoder")
    if not os.path.exists(encoder_path):
        print(f"  Text encoder not found: {encoder_path}")
        return {}
    reset_memory()
    before = get_memory_stats()
    # Load text encoder
    print("  Loading text encoder weights...")
    tokenizer = AutoTokenizer.from_pretrained(encoder_path)
    model = AutoModel.from_pretrained(encoder_path)
    model = model.to("cuda").to(torch.bfloat16)
    model.eval()
    torch.cuda.synchronize()
    after_load = get_memory_stats()
    weights_gb = after_load["allocated"] - before["allocated"]
    print(f"  Text encoder weights: {weights_gb:.3f} GB")

    # Cleanup
    del model, tokenizer
    torch.cuda.empty_cache()
    gc.collect()

    return {
        "weights_gb": round(weights_gb, 3),
    }


def profile_lm(checkpoint_dir: str, lm_models: Optional[List[str]] = None) -> Dict[str, Any]:
    """Profile LM model memory consumption."""
    print("\n" + "=" * 60)
    print("Profiling 5Hz LM models")
    print("=" * 60)
    from transformers import AutoModelForCausalLM

    if lm_models is None:
        # Auto-detect available LM models
        lm_models = []
        for name in os.listdir(checkpoint_dir):
            if "5Hz-lm" in name and os.path.isdir(os.path.join(checkpoint_dir, name)):
                lm_models.append(name)
    if not lm_models:
        print("  No LM models found")
        return {}
    lm_models.sort()

    results = {}
    for lm_name in lm_models:
        lm_path = os.path.join(checkpoint_dir, lm_name)
        if not os.path.exists(lm_path):
            print(f"  LM model not found: {lm_path}")
            continue
        print(f"\n  Profiling LM: {lm_name}")
        reset_memory()
        before = get_memory_stats()
        # Load model weights
        print("    Loading model weights...")
        model = AutoModelForCausalLM.from_pretrained(
            lm_path,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        model = model.to("cuda")
        model.eval()
        torch.cuda.synchronize()
        after_load = get_memory_stats()
        weights_gb = after_load["allocated"] - before["allocated"]
        print(f"    Model weights: {weights_gb:.3f} GB")

        # Estimate KV cache memory for different max_model_len values.
        # Per-token bytes = 2 (K and V) * num_layers * num_kv_heads
        #                   * head_dim * dtype_size
        config = model.config
        num_layers = config.num_hidden_layers
        num_kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        dtype_size = 2  # bfloat16 = 2 bytes

        kv_cache_estimates = {}
        for max_len in [2048, 4096]:
            per_token_bytes = 2 * num_layers * num_kv_heads * head_dim * dtype_size
            total_bytes = per_token_bytes * max_len
            total_gb = total_bytes / (1024**3)
            kv_cache_estimates[str(max_len)] = round(total_gb, 3)
            print(f"    KV cache ({max_len} tokens): {total_gb:.3f} GB")
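        # Worked example with hypothetical config values (32 layers, 8 KV
        # heads, head_dim 128, bf16): per-token = 2 * 32 * 8 * 128 * 2
        # = 131,072 bytes (128 KiB); at max_len 4096 that is
        # 131072 * 4096 / 1024**3 = 0.5 GB.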

        results[lm_name] = {
            "weights_gb": round(weights_gb, 3),
            "kv_cache_estimates": kv_cache_estimates,
            "num_layers": num_layers,
            "num_kv_heads": num_kv_heads,
            "head_dim": head_dim,
        }

        # Cleanup
        del model
        torch.cuda.empty_cache()
        gc.collect()

    return results


def main():
    parser = argparse.ArgumentParser(description="VRAM Profiling for ACE-Step 1.5")
    parser.add_argument("--component", type=str, default="all",
                        choices=["all", "cuda_context", "dit", "vae", "text_encoder", "lm"],
                        help="Component to profile (default: all)")
    parser.add_argument("--checkpoint-dir", type=str, default=None,
                        help="Checkpoint directory (default: auto-detect)")
    parser.add_argument("--dit-config", type=str, default="acestep-v15-turbo",
                        help="DiT model config name (default: acestep-v15-turbo)")
    parser.add_argument("--lm-models", type=str, nargs="*", default=None,
                        help="LM models to profile (default: auto-detect)")
    parser.add_argument("--output", type=str, default=None,
                        help="Output JSON file path")
    args = parser.parse_args()

    if not torch.cuda.is_available():
        print("ERROR: CUDA is not available. This script requires a CUDA GPU.")
        sys.exit(1)

    # Auto-detect checkpoint directory
    if args.checkpoint_dir is None:
        args.checkpoint_dir = os.path.join(PROJECT_ROOT, "checkpoints")
    if not os.path.exists(args.checkpoint_dir):
        print(f"ERROR: Checkpoint directory not found: {args.checkpoint_dir}")
        sys.exit(1)

    device_name = torch.cuda.get_device_name(0)
    total_mem = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    print("=" * 60)
    print("ACE-Step 1.5 VRAM Profiler")
    print("=" * 60)
    print(f"  GPU: {device_name}")
    print(f"  Total VRAM: {total_mem:.2f} GB")
    print(f"  Checkpoint dir: {args.checkpoint_dir}")
    print(f"  Component: {args.component}")

    results = {
        "gpu_name": device_name,
        "total_vram_gb": round(total_mem, 3),
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
    }
    components = [args.component] if args.component != "all" else [
        "cuda_context", "dit", "vae", "text_encoder", "lm"
    ]
    for component in components:
        if component == "cuda_context":
            results["cuda_context"] = measure_cuda_context()
        elif component == "dit":
            results["dit"] = profile_dit(args.checkpoint_dir, args.dit_config)
        elif component == "vae":
            results["vae"] = profile_vae(args.checkpoint_dir)
        elif component == "text_encoder":
            results["text_encoder"] = profile_text_encoder(args.checkpoint_dir)
        elif component == "lm":
            results["lm"] = profile_lm(args.checkpoint_dir, args.lm_models)

    # Print summary
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    if "cuda_context" in results:
        print(f"  CUDA context: {results['cuda_context'].get('cuda_context_gb', 'N/A')} GB")
    if "dit" in results and results["dit"]:
        print(f"  DiT weights ({results['dit'].get('config_path', '')}): {results['dit'].get('weights_gb', 'N/A')} GB")
    if "vae" in results and results["vae"]:
        print(f"  VAE weights: {results['vae'].get('weights_gb', 'N/A')} GB")
    if "text_encoder" in results and results["text_encoder"]:
        print(f"  Text encoder weights: {results['text_encoder'].get('weights_gb', 'N/A')} GB")
    if "lm" in results and results["lm"]:
        for lm_name, lm_data in results["lm"].items():
            print(f"  LM {lm_name} weights: {lm_data.get('weights_gb', 'N/A')} GB")

    # Calculate total base VRAM (all models loaded simultaneously)
    base_total = 0
    if "cuda_context" in results:
        base_total += results["cuda_context"].get("cuda_context_gb", 0)
    if "dit" in results and results["dit"]:
        base_total += results["dit"].get("weights_gb", 0)
    if "vae" in results and results["vae"]:
        base_total += results["vae"].get("weights_gb", 0)
    if "text_encoder" in results and results["text_encoder"]:
        base_total += results["text_encoder"].get("weights_gb", 0)
    print(f"\n  Base VRAM (DiT+VAE+TextEnc+CUDA): {base_total:.3f} GB")
    print(f"  Remaining for LM + inference: {total_mem - base_total:.3f} GB")
    # Save results
    if args.output:
        output_path = args.output
    else:
        output_path = os.path.join(PROJECT_ROOT, "scripts", "vram_profile_results.json")
    # os.makedirs("") raises, so only create the directory when one is given
    # (e.g. --output results.json has an empty dirname)
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(output_path, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\n  Results saved to: {output_path}")


if __name__ == "__main__":
    main()