# Ace-Step-Munk / profile_inference.py
# NOTE: the following lines are Hugging Face page residue accidentally
# captured with the file ("OnyxlMunkey's picture", commit message
# "Add ACE-Step 1.5 Docker app", revision e961681); kept as comments so
# the module remains valid Python.
#!/usr/bin/env python3
"""
ACE-Step 1.5 Inference Profiler & Benchmark
Comprehensive profiling tool that supports all features, devices, and backends.
Uses the high-level inference API and built-in time_costs for accurate timing.
Modes:
profile - Profile a single generation run with detailed timing breakdown
benchmark - Run a matrix of configurations and produce a summary table
tier-test - Auto-test across simulated GPU tiers (4/6/8/12/16/24/48 GB)
understand - Profile the understand_music() API (audio codes -> metadata)
create_sample - Profile the create_sample() API (inspiration/simple mode)
format_sample - Profile the format_sample() API (caption+lyrics -> metadata)
Usage:
# Profile text2music with default settings
python profile_inference.py
# Profile with thinking enabled on MPS
python profile_inference.py --device mps --thinking
# Benchmark across configurations
python profile_inference.py --mode benchmark
# Test all GPU tiers automatically (the key feature!)
python profile_inference.py --mode tier-test
# Test specific tiers only
python profile_inference.py --mode tier-test --tiers 6 8 16
# Test tiers with LM enabled (where supported)
python profile_inference.py --mode tier-test --tier-with-lm
# Profile create_sample (inspiration mode)
python profile_inference.py --mode create_sample --sample-query "a soft Bengali love song"
# Profile understand mode
python profile_inference.py --mode understand
# Full profiling with cProfile
python profile_inference.py --detailed --llm-debug
"""
import time
import argparse
import sys
import os
import json
import tempfile
import traceback
from contextlib import contextmanager
from collections import defaultdict
from typing import Tuple, Dict, Any, List, Optional
# Add project root to path
PROJECT_ROOT = os.path.abspath(os.path.dirname(__file__))
if PROJECT_ROOT not in sys.path:
sys.path.insert(0, PROJECT_ROOT)
import torch
from acestep.inference import (
generate_music,
understand_music,
create_sample,
format_sample,
GenerationParams,
GenerationConfig,
GenerationResult,
)
from acestep.handler import AceStepHandler
from acestep.llm_inference import LLMHandler
from acestep.gpu_config import (
get_gpu_config,
set_global_gpu_config,
get_gpu_tier,
find_best_lm_model_on_disk,
is_lm_model_size_allowed,
GPUConfig,
VRAM_AUTO_OFFLOAD_THRESHOLD_GB,
)
# =============================================================================
# Device / Backend helpers
# =============================================================================
def resolve_device(device: str) -> str:
    """Resolve 'auto' to the most capable available device, else pass through.

    Preference order for 'auto': XPU -> CUDA -> MPS -> CPU.
    """
    if device != "auto":
        return device
    # Probe accelerators from most to least preferred.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    if torch.cuda.is_available():
        return "cuda"
    has_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    return "mps" if has_mps else "cpu"
def auto_detect_backend(device: str) -> str:
    """Pick the preferred LLM backend for an already-resolved device.

    CUDA devices use vLLM; MPS prefers MLX when installed, falling back
    to plain PyTorch; everything else uses the PyTorch ('pt') backend.
    """
    if device.startswith("cuda"):
        return "vllm"
    if device == "mps":
        try:
            import mlx.core  # noqa: F401
        except ImportError:
            return "pt"
        return "mlx"
    return "pt"
def load_env_config() -> Dict[str, str]:
    """Load configuration defaults from the project's .env file.

    Only the four ACESTEP_* keys below are honored; any other line in the
    file is ignored.  Values may be wrapped in matching single or double
    quotes (a standard .env convention); one layer of quotes is stripped.

    Returns:
        Mapping of config key -> value, falling back to the hard-coded
        defaults when the .env file is absent or does not set a key.
    """
    env_config = {
        "ACESTEP_CONFIG_PATH": "acestep-v15-turbo",
        "ACESTEP_LM_MODEL_PATH": "acestep-5Hz-lm-0.6B",
        "ACESTEP_DEVICE": "auto",
        "ACESTEP_LM_BACKEND": "auto",
    }
    env_file = os.path.join(PROJECT_ROOT, ".env")
    if os.path.exists(env_file):
        with open(env_file, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                # Skip blank lines and comments.
                if not line or line.startswith("#"):
                    continue
                if "=" not in line:
                    continue
                key, value = line.split("=", 1)
                key = key.strip()
                value = value.strip()
                # Strip one layer of matching surrounding quotes
                # ('VALUE' or "VALUE"), per common .env tooling.
                if (
                    len(value) >= 2
                    and value[0] == value[-1]
                    and value[0] in ("'", '"')
                ):
                    value = value[1:-1]
                if key in env_config and value:
                    env_config[key] = value
    return env_config
# =============================================================================
# Timer utilities
# =============================================================================
class PreciseTimer:
    """Wall-clock timer that synchronizes accelerator queues before sampling.

    Samples are accumulated per section name, so repeated sections can be
    totalled, averaged and counted after a run.  Setting ``enabled = False``
    turns both timing and synchronization into no-ops.
    """

    def __init__(self, device: str = "cpu"):
        self.device = device
        # section name -> list of elapsed seconds, one entry per timed run
        self.timings: Dict[str, List[float]] = defaultdict(list)
        self.enabled = True

    def sync(self):
        """Block until pending GPU work finishes so timings are honest."""
        if not self.enabled:
            return
        dev = self.device
        if dev.startswith("cuda") and torch.cuda.is_available():
            torch.cuda.synchronize()
        elif dev == "mps" and hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            if hasattr(torch, "mps"):
                torch.mps.synchronize()
        elif dev.startswith("xpu") and hasattr(torch, "xpu"):
            torch.xpu.synchronize()

    @contextmanager
    def time(self, name: str):
        """Context manager recording one timing sample under *name*."""
        if not self.enabled:
            yield
            return
        self.sync()
        begin = time.perf_counter()
        try:
            yield
        finally:
            # Sync again so queued device work is charged to this section.
            self.sync()
            self.timings[name].append(time.perf_counter() - begin)

    def get_total(self, name: str) -> float:
        return sum(self.timings.get(name, []))

    def get_mean(self, name: str) -> float:
        samples = self.timings.get(name, [])
        return sum(samples) / len(samples) if samples else 0.0

    def get_count(self, name: str) -> int:
        return len(self.timings.get(name, []))

    def reset(self):
        self.timings.clear()
# =============================================================================
# Example config loader
# =============================================================================
def load_example_config(
    example_file: str, cli_overrides: argparse.Namespace
) -> Tuple[Optional[GenerationParams], Optional[GenerationConfig]]:
    """Build (params, config) from an example JSON file plus CLI overrides.

    CLI values win over JSON values for duration, inference_steps, seed and
    batch_size; every other generation knob comes either straight from the
    JSON or straight from the CLI.  Returns (None, None) and prints a
    message on any failure.
    """
    try:
        with open(example_file, "r", encoding="utf-8") as f:
            data = json.load(f)

        def pick(cli_value, json_key, fallback):
            # An explicitly-provided CLI value takes priority over JSON.
            return cli_value if cli_value is not None else data.get(json_key, fallback)

        params = GenerationParams(
            caption=data.get("caption", ""),
            lyrics=data.get("lyrics", ""),
            bpm=data.get("bpm"),
            keyscale=data.get("keyscale", ""),
            timesignature=data.get("timesignature", ""),
            vocal_language=data.get("language", "unknown"),
            duration=pick(cli_overrides.duration, "duration", -1.0),
            thinking=cli_overrides.thinking,
            use_cot_metas=cli_overrides.use_cot_metas,
            use_cot_caption=cli_overrides.use_cot_caption,
            use_cot_language=cli_overrides.use_cot_language,
            use_constrained_decoding=cli_overrides.use_constrained_decoding,
            inference_steps=pick(cli_overrides.inference_steps, "inference_steps", 8),
            seed=pick(cli_overrides.seed, "seed", 42),
            task_type=cli_overrides.task_type,
            lm_temperature=cli_overrides.lm_temperature,
            lm_cfg_scale=cli_overrides.lm_cfg_scale,
            guidance_scale=cli_overrides.guidance_scale,
            reference_audio=cli_overrides.reference_audio,
            src_audio=cli_overrides.src_audio,
        )
        config = GenerationConfig(
            batch_size=pick(cli_overrides.batch_size, "batch_size", 1),
            # A negative seed means "randomize"; otherwise pin it.
            seeds=[params.seed] if params.seed >= 0 else None,
            use_random_seed=(params.seed < 0),
            audio_format="flac",
        )
        return params, config
    except Exception as e:
        print(f" Failed to load example: {e}")
        return None, None
# =============================================================================
# Printing helpers
# =============================================================================
def print_time_costs_breakdown(
    time_costs: Dict[str, float], total_wall_time: float
):
    """Print a detailed timing breakdown from result.extra_outputs['time_costs'].

    Args:
        time_costs: Flat mapping of stage name -> seconds as produced by the
            pipeline; keys prefixed ``lm_`` / ``dit_`` / ``pipeline_`` are
            grouped, everything else lands under "Other".
        total_wall_time: End-to-end wall time used as the 100% reference.
    """

    def pct(v: float) -> float:
        # Guard against a zero wall time (instant/mocked runs) so the
        # breakdown never dies with ZeroDivisionError.
        return 100.0 * v / total_wall_time if total_wall_time > 0 else 0.0

    print("\n" + "=" * 100)
    print("PROFILING RESULTS")
    print("=" * 100)
    if not time_costs:
        print("\n (No time_costs data available from the pipeline)")
        print(f"\n Total wall time: {total_wall_time:.3f}s")
        return
    # Categorize keys by prefix; non-numeric values are ignored.
    lm_keys = {
        k: v
        for k, v in time_costs.items()
        if k.startswith("lm_") and isinstance(v, (int, float))
    }
    dit_keys = {
        k: v
        for k, v in time_costs.items()
        if k.startswith("dit_") and isinstance(v, (int, float))
    }
    pipeline_keys = {
        k: v
        for k, v in time_costs.items()
        if k.startswith("pipeline_") and isinstance(v, (int, float))
    }
    other_keys = {
        k: v
        for k, v in time_costs.items()
        if not k.startswith(("lm_", "dit_", "pipeline_"))
        and isinstance(v, (int, float))
    }
    print(f"\n{'COMPONENT':<50} {'TIME (s)':<12} {'% of wall':<10}")
    print("-" * 72)
    # LM timing: the total line, then each sub-stage indented beneath it.
    lm_total = lm_keys.get("lm_total_time", 0.0)
    if lm_keys:
        print(
            f"\n{'LLM (5Hz Language Model)':<50} "
            f"{lm_total:<12.3f} {pct(lm_total):>6.1f}%"
        )
        for k, v in sorted(lm_keys.items()):
            if k != "lm_total_time":
                label = k.replace("lm_", " ")
                print(
                    f" {label:<48} "
                    f"{v:<12.3f} {pct(v):>6.1f}%"
                )
    # DiT timing, same layout as the LM section.
    dit_total = dit_keys.get("dit_total_time_cost", 0.0)
    if dit_keys:
        print(
            f"\n{'DiT (Diffusion Transformer)':<50} "
            f"{dit_total:<12.3f} {pct(dit_total):>6.1f}%"
        )
        for k, v in sorted(dit_keys.items()):
            if k != "dit_total_time_cost":
                label = k.replace("dit_", " ")
                print(
                    f" {label:<48} "
                    f"{v:<12.3f} {pct(v):>6.1f}%"
                )
    # Pipeline-level totals (one line each).
    if pipeline_keys:
        for k, v in sorted(pipeline_keys.items()):
            print(
                f"\n{'Pipeline: ' + k:<50} "
                f"{v:<12.3f} {pct(v):>6.1f}%"
            )
    # Anything not matching a known prefix.
    if other_keys:
        print(f"\n{'Other:':<50}")
        for k, v in sorted(other_keys.items()):
            print(
                f" {k:<48} "
                f"{v:<12.3f} {pct(v):>6.1f}%"
            )
    # Overhead: wall time not attributed to the LM or DiT stages
    # (I/O, audio encoding, Python glue, ...).
    accounted = lm_total + dit_total
    overhead = total_wall_time - accounted
    if overhead > 0.01:
        print(
            f"\n{'Overhead (I/O, audio save, etc.)':<50} "
            f"{overhead:<12.3f} {pct(overhead):>6.1f}%"
        )
    print(f"\n{'TOTAL WALL TIME':<50} {total_wall_time:<12.3f} {'100.0%':>6}")
    # Performance insights: call out whichever stage dominates (>2x the other).
    print("\n" + "=" * 100)
    print("PERFORMANCE INSIGHTS")
    print("=" * 100)
    if lm_total > 0 and dit_total > 0:
        if lm_total > dit_total * 2:
            print(
                f"\n LLM is the bottleneck: {lm_total:.1f}s "
                f"({pct(lm_total):.0f}% of total)"
            )
            print(" Suggestions:")
            print(" 1. Run with --llm-debug for token-level throughput analysis")
            print(" 2. Try --no-constrained-decoding to reduce FSM overhead")
            print(" 3. Compare backends: --lm-backend vllm vs pt vs mlx")
            print(
                " 4. Reduce lm_cfg_scale "
                "(currently doubles forward passes if > 1.0)"
            )
        elif dit_total > lm_total * 2:
            print(
                f"\n DiT is the bottleneck: {dit_total:.1f}s "
                f"({pct(dit_total):.0f}% of total)"
            )
            print(" Suggestions:")
            print(" 1. Reduce --inference-steps (turbo model supports 4-8)")
            print(" 2. Reduce --duration")
            print(" 3. Try --quantization int8_weight_only")
        else:
            print(
                f"\n Balanced pipeline: LLM={lm_total:.1f}s, DiT={dit_total:.1f}s"
            )
    elif dit_total > 0:
        print(f"\n DiT only (no LLM): {dit_total:.1f}s")
        vae_time = dit_keys.get("dit_vae_decode_time_cost", 0.0)
        # Fall back to "everything that isn't VAE" when the pipeline did
        # not report a dedicated diffusion timing.
        diffusion_time = dit_keys.get(
            "dit_diffusion_time_cost", dit_total - vae_time
        )
        if vae_time > 0:
            print(
                f" Diffusion: {diffusion_time:.1f}s, "
                f"VAE decode: {vae_time:.1f}s"
            )
def print_result_summary(result: GenerationResult, mode: str = "profile"):
    """Print a one-line success/failure summary for a generation result."""
    if not result.success:
        print(f"\n FAILED: {result.error}")
        return
    audios = result.audios
    silent_count = sum(1 for item in audios if item.get("silent", False))
    suffix = f" ({silent_count} silent)" if silent_count else ""
    print(f"\n Success! Generated {len(audios)} audio(s){suffix}")
# =============================================================================
# Mode: profile (text2music and other task types)
# =============================================================================
def run_profile_mode(dit_handler, llm_handler, args, timer: PreciseTimer):
    """Run a single profiled generation.

    Loads the example JSON, optionally runs one warmup generation (so
    one-time compilation/caching costs don't pollute the measurement),
    then times a single generate_music() call and prints the per-stage
    time_costs breakdown.

    Args:
        dit_handler: Initialized DiT pipeline handler.
        llm_handler: Initialized (or inert) LLM handler.
        args: Parsed CLI namespace.
        timer: PreciseTimer used for device synchronization around timing.

    Returns:
        Tuple of (GenerationResult, total wall time in seconds).
    """
    example_dir = "text2music"
    example_file = os.path.join(
        PROJECT_ROOT, "examples", example_dir, args.example
    )
    if not os.path.exists(example_file):
        print(f"\n Example not found: {example_file}")
        sys.exit(1)
    print(f"\n Loading example: {args.example}")
    params, config = load_example_config(example_file, args)
    if not params or not config:
        print(" Failed to load example config")
        sys.exit(1)
    # Truncate long captions for display only.
    caption_preview = (
        params.caption[:80] + "..."
        if len(params.caption) > 80
        else params.caption
    )
    print(f" Caption: {caption_preview}")
    print(
        f" Task: {params.task_type}, Batch: {config.batch_size}, "
        f"Steps: {params.inference_steps}"
    )
    print(
        f" Thinking: {params.thinking}, CoT Metas: {params.use_cot_metas}, "
        f"CoT Caption: {params.use_cot_caption}"
    )
    # Use a temporary directory for output (don't pollute project root)
    save_dir = tempfile.mkdtemp(prefix="acestep_profile_")
    # Warmup: clone the real params but pin batch=1 and seed=42 so the
    # warmup is deterministic and cheap.
    if not args.no_warmup:
        print("\n" + "-" * 100)
        print("WARMUP RUN")
        print("-" * 100)
        warmup_params = GenerationParams(
            caption=params.caption,
            lyrics=params.lyrics,
            bpm=params.bpm,
            keyscale=params.keyscale,
            timesignature=params.timesignature,
            vocal_language=params.vocal_language,
            duration=params.duration,
            thinking=params.thinking,
            use_cot_metas=params.use_cot_metas,
            use_cot_caption=params.use_cot_caption,
            use_cot_language=params.use_cot_language,
            use_constrained_decoding=params.use_constrained_decoding,
            inference_steps=params.inference_steps,
            seed=42,
            task_type=params.task_type,
            lm_temperature=params.lm_temperature,
            lm_cfg_scale=params.lm_cfg_scale,
            guidance_scale=params.guidance_scale,
        )
        warmup_config = GenerationConfig(
            batch_size=1, seeds=[42], use_random_seed=False, audio_format="flac"
        )
        warmup_start = time.perf_counter()
        warmup_result = generate_music(
            dit_handler, llm_handler, warmup_params, warmup_config,
            save_dir=save_dir,
        )
        warmup_time = time.perf_counter() - warmup_start
        print(f" Warmup completed: {warmup_time:.2f}s")
        # A failed warmup is non-fatal: the profiled run may still succeed.
        if not warmup_result.success:
            print(f" Warning: warmup failed: {warmup_result.error}")
        # Discard any timings collected during warmup.
        timer.reset()
    # Profiling run
    print("\n" + "=" * 100)
    print("PROFILING RUN")
    print("=" * 100)
    # Optional cProfile for function-level hotspot analysis (--detailed).
    prof = None
    if args.detailed:
        import cProfile
        prof = cProfile.Profile()
        prof.enable()
    # Sync before/after so queued GPU work is charged to this window.
    timer.sync()
    total_start = time.perf_counter()
    result = generate_music(
        dit_handler, llm_handler, params, config, save_dir=save_dir
    )
    timer.sync()
    total_wall_time = time.perf_counter() - total_start
    if args.detailed and prof:
        prof.disable()
        _print_cprofile(prof)
    # Print results
    print_result_summary(result, "profile")
    time_costs = (
        result.extra_outputs.get("time_costs", {}) if result.success else {}
    )
    print_time_costs_breakdown(time_costs, total_wall_time)
    # Cleanup temp dir
    _cleanup_dir(save_dir)
    return result, total_wall_time
# =============================================================================
# Mode: benchmark
# =============================================================================
def run_benchmark_mode(dit_handler, llm_handler, args, timer: PreciseTimer):
    """Run a matrix of configurations and produce a summary table.

    The matrix is the cross product of durations x batch sizes x thinking
    (LM on/off) x inference steps, clamped to the current GPU tier's
    limits.  Each cell runs one generate_music() call and records wall,
    LM, DiT and VAE times; a summary table is printed at the end and the
    raw entries are optionally saved to --benchmark-output as JSON.

    Returns:
        List of per-configuration result dicts.
    """
    example_file = os.path.join(
        PROJECT_ROOT, "examples", "text2music", args.example
    )
    if not os.path.exists(example_file):
        print(f"\n Example not found: {example_file}")
        sys.exit(1)
    with open(example_file, "r", encoding="utf-8") as f:
        example_data = json.load(f)
    save_dir = tempfile.mkdtemp(prefix="acestep_bench_")
    # Define benchmark matrix
    durations = [30, 60, 120]
    batch_sizes = [1, 2]
    # Only benchmark thinking=True when the LLM was actually initialized.
    thinking_options = (
        [False, True] if llm_handler.llm_initialized else [False]
    )
    inference_steps_options = [8]
    # Clamp to GPU limits so the matrix stays feasible on small tiers.
    gpu_config = get_gpu_config()
    max_dur = gpu_config.max_duration_without_lm
    max_batch = gpu_config.max_batch_size_without_lm
    durations = [d for d in durations if d <= max_dur]
    batch_sizes = [b for b in batch_sizes if b <= max_batch]
    # Guarantee at least one configuration even if clamping removed all.
    if not durations:
        durations = [30]
    if not batch_sizes:
        batch_sizes = [1]
    configs = []
    for dur in durations:
        for bs in batch_sizes:
            for think in thinking_options:
                for steps in inference_steps_options:
                    configs.append(
                        {
                            "duration": dur,
                            "batch_size": bs,
                            "thinking": think,
                            "inference_steps": steps,
                        }
                    )
    print(f"\n Running {len(configs)} benchmark configurations...")
    print(f" Durations: {durations}, Batch sizes: {batch_sizes}")
    print(f" Thinking: {thinking_options}, Steps: {inference_steps_options}")
    # Warmup: one short, deterministic run to absorb one-time costs.
    if not args.no_warmup:
        print("\n Warmup run...")
        warmup_params = GenerationParams(
            caption=example_data.get("caption", ""),
            lyrics=example_data.get("lyrics", ""),
            duration=30,
            thinking=False,
            inference_steps=8,
            seed=42,
        )
        warmup_config = GenerationConfig(
            batch_size=1, seeds=[42], use_random_seed=False, audio_format="flac"
        )
        generate_music(
            dit_handler, llm_handler, warmup_params, warmup_config,
            save_dir=save_dir,
        )
        print(" Warmup done.")
    # Run benchmark: one timed generation per matrix cell.
    results = []
    for i, cfg in enumerate(configs):
        label = (
            f"dur={cfg['duration']}s, bs={cfg['batch_size']}, "
            f"think={cfg['thinking']}, steps={cfg['inference_steps']}"
        )
        print(f"\n [{i + 1}/{len(configs)}] {label}")
        params = GenerationParams(
            caption=example_data.get("caption", ""),
            lyrics=example_data.get("lyrics", ""),
            bpm=example_data.get("bpm"),
            keyscale=example_data.get("keyscale", ""),
            timesignature=example_data.get("timesignature", ""),
            vocal_language=example_data.get("language", "unknown"),
            duration=cfg["duration"],
            thinking=cfg["thinking"],
            # CoT flags track the thinking flag for this benchmark cell.
            use_cot_metas=cfg["thinking"],
            use_cot_caption=cfg["thinking"],
            use_cot_language=cfg["thinking"],
            use_constrained_decoding=args.use_constrained_decoding,
            inference_steps=cfg["inference_steps"],
            seed=42,
            lm_temperature=args.lm_temperature,
            lm_cfg_scale=args.lm_cfg_scale,
            guidance_scale=args.guidance_scale,
        )
        config = GenerationConfig(
            batch_size=cfg["batch_size"],
            # Distinct deterministic seed per batch item.
            seeds=[42 + j for j in range(cfg["batch_size"])],
            use_random_seed=False,
            audio_format="flac",
        )
        # Sync before/after so queued GPU work lands inside the window.
        timer.sync()
        t0 = time.perf_counter()
        result = generate_music(
            dit_handler, llm_handler, params, config, save_dir=save_dir
        )
        timer.sync()
        wall_time = time.perf_counter() - t0
        tc = (
            result.extra_outputs.get("time_costs", {})
            if result.success
            else {}
        )
        entry = {
            "config": cfg,
            "wall_time": wall_time,
            "success": result.success,
            "error": result.error,
            "lm_time": tc.get("lm_total_time", 0.0),
            "dit_time": tc.get("dit_total_time_cost", 0.0),
            "vae_time": tc.get("dit_vae_decode_time_cost", 0.0),
            "n_audios": len(result.audios) if result.success else 0,
        }
        results.append(entry)
        status = "OK" if result.success else f"FAIL: {result.error}"
        print(
            f" {status} | wall={wall_time:.1f}s, "
            f"lm={entry['lm_time']:.1f}s, dit={entry['dit_time']:.1f}s"
        )
    # Print summary table
    print("\n" + "=" * 120)
    print("BENCHMARK SUMMARY")
    print("=" * 120)
    header = (
        f"{'Duration':<10} {'Batch':<7} {'Think':<7} {'Steps':<7} "
        f"{'Wall(s)':<10} {'LM(s)':<10} {'DiT(s)':<10} "
        f"{'VAE(s)':<10} {'Status':<10}"
    )
    print(header)
    print("-" * 120)
    for entry in results:
        cfg = entry["config"]
        status = "OK" if entry["success"] else "FAIL"
        print(
            f"{cfg['duration']:<10} {cfg['batch_size']:<7} "
            f"{str(cfg['thinking']):<7} {cfg['inference_steps']:<7} "
            f"{entry['wall_time']:<10.2f} {entry['lm_time']:<10.2f} "
            f"{entry['dit_time']:<10.2f} {entry['vae_time']:<10.2f} "
            f"{status:<10}"
        )
    # Save benchmark results as JSON (default=str handles non-JSON types).
    if args.benchmark_output:
        output_path = args.benchmark_output
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2, default=str)
        print(f"\n Benchmark results saved to: {output_path}")
    _cleanup_dir(save_dir)
    return results
# =============================================================================
# Mode: tier-test (THE KEY FEATURE)
# =============================================================================
def _get_vram_info_str() -> str:
"""Get current VRAM usage string for logging."""
if not torch.cuda.is_available():
return "N/A"
allocated = torch.cuda.memory_allocated() / (1024 ** 3)
reserved = torch.cuda.memory_reserved() / (1024 ** 3)
return f"alloc={allocated:.2f}GB, reserved={reserved:.2f}GB"
def _run_single_tier_test(
    sim_gb: float,
    gpu_config: GPUConfig,
    args,
    example_data: Dict,
    checkpoint_dir: str,
    disk_lm_models: List[str],
    *,
    offload_override: Optional[bool] = None,
    offload_dit_override: Optional[bool] = None,
    quantization_override: Optional[str] = "USE_DEFAULT",
    test_variant: str = "default",
    batch_size_override: Optional[int] = None,
    use_lm_override: Optional[bool] = None,
) -> Dict[str, Any]:
    """
    Run a single tier test with the given configuration.

    Initializes fresh DiT (and optionally LLM) handlers under the tier's
    settings, runs one short generation, and records init/generation
    success, wall time and peak VRAM.  Handlers are always torn down
    before returning.

    Args:
        sim_gb: Simulated VRAM in GB
        gpu_config: GPU configuration for this tier
        args: CLI arguments
        example_data: Example JSON data for generation
        checkpoint_dir: Path to checkpoints directory
        disk_lm_models: List of LM models found on disk
        offload_override: If not None, override offload_to_cpu setting
        offload_dit_override: If not None, override offload_dit_to_cpu setting
        quantization_override: If not "USE_DEFAULT", override quantization setting
            (None means no quantization, "int8_weight_only" etc.)
        test_variant: Label for this test variant ("default", "no-quant", "no-offload")
        batch_size_override: If not None, override batch size (used by batch boundary tests)
        use_lm_override: If not None, force LM on (True) or off (False)

    Returns:
        Result dictionary for this test
    """
    tier = gpu_config.tier
    # Determine test configuration.  LM use is only honored when the tier
    # both defaults to an LM and actually lists compatible LM models.
    if use_lm_override is not None:
        use_lm = use_lm_override and gpu_config.init_lm_default and bool(gpu_config.available_lm_models)
    else:
        use_lm = args.tier_with_lm and gpu_config.init_lm_default and bool(gpu_config.available_lm_models)
    # On XPU, CPU offload is forcibly disabled (overrides any request).
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        offload_override = False
    if offload_override is not None:
        offload = offload_override
    else:
        offload = gpu_config.offload_to_cpu_default
    if offload_dit_override is not None:
        offload_dit = offload_dit_override
    else:
        offload_dit = gpu_config.offload_dit_to_cpu_default
    # "USE_DEFAULT" is a sentinel so callers can explicitly pass None to
    # mean "no quantization".
    if quantization_override != "USE_DEFAULT":
        quantization = quantization_override
    else:
        quantization = "int8_weight_only" if gpu_config.quantization_default else None
    # Find LM model on disk; fall back to no-LM if none is compatible.
    lm_model = None
    lm_backend = gpu_config.recommended_backend
    if use_lm:
        lm_model = find_best_lm_model_on_disk(
            gpu_config.recommended_lm_model, disk_lm_models
        )
        if not lm_model:
            print(f" ⚠️ No compatible LM model on disk for tier {tier}, skipping LM")
            use_lm = False
    # Clamp duration to tier limit (limit differs with/without LM).
    test_duration = args.tier_duration
    max_dur = gpu_config.max_duration_with_lm if use_lm else gpu_config.max_duration_without_lm
    if test_duration > max_dur:
        test_duration = max_dur
        print(f" Duration clamped to {test_duration}s (tier limit)")
    batch_size = batch_size_override if batch_size_override is not None else 1
    print(f"\n Test config [{test_variant}]: duration={test_duration}s, batch={batch_size}, LM={use_lm}")
    if use_lm:
        print(f" LM model: {lm_model}, backend: {lm_backend}")
    print(f" offload={offload}, offload_dit={offload_dit}, quant={quantization}")
    # Enforce VRAM cap: restrict the CUDA caching allocator to a fraction
    # of the real device so the simulated tier's limit is actually felt.
    if torch.cuda.is_available():
        total_bytes = torch.cuda.get_device_properties(0).total_memory
        total_gb = total_bytes / (1024 ** 3)
        if sim_gb < total_gb:
            # Leave ~0.5GB for the CUDA context outside the allocator budget.
            reference_context_gb = 0.5
            allocator_budget_gb = max(0.5, sim_gb - reference_context_gb)
            fraction = max(0.01, min(1.0, allocator_budget_gb / total_gb))
            torch.cuda.set_per_process_memory_fraction(fraction)
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
    # Initialize result entry
    result_entry = {
        "tier_gb": sim_gb,
        "tier": tier,
        "test_variant": test_variant,
        "use_lm": use_lm,
        "lm_model": lm_model,
        "lm_backend": lm_backend,
        "offload": offload,
        "offload_dit": offload_dit,
        "quantization": quantization,
        "duration": test_duration,
        "batch_size": batch_size,
        "init_success": False,
        "gen_success": False,
        "wall_time": 0.0,
        "error": None,
        "peak_vram_gb": 0.0,
    }
    dit_handler = None
    llm_handler = None
    try:
        print(f"\n Initializing DiT handler... ({_get_vram_info_str()})")
        dit_handler = AceStepHandler()
        # Determine flash attention availability
        use_flash_attention = False
        try:
            import flash_attn  # noqa: F401
            use_flash_attention = True
        except ImportError:
            pass
        # compile_model must be True when quantization is used;
        # --tier-skip-compile can skip it for non-quantized tiers to save time
        if quantization:
            compile_model = True
        elif args.tier_skip_compile:
            compile_model = False
        else:
            compile_model = gpu_config.compile_model_default
        status_dit, success_dit = dit_handler.initialize_service(
            project_root=PROJECT_ROOT,
            config_path=args.config_path,
            device="auto",
            use_flash_attention=use_flash_attention,
            compile_model=compile_model,
            offload_to_cpu=offload,
            offload_dit_to_cpu=offload_dit,
            quantization=quantization,
        )
        if not success_dit:
            # DiT is mandatory: abort this variant on init failure.
            result_entry["error"] = f"DiT init failed: {status_dit}"
            print(f" ❌ DiT init failed: {status_dit}")
            _cleanup_handlers(dit_handler, None)
            return result_entry
        print(f" ✅ DiT ready ({_get_vram_info_str()})")
        llm_handler = LLMHandler()
        if use_lm:
            print(f" Initializing LLM handler (backend={lm_backend})... ({_get_vram_info_str()})")
            status_llm, success_llm = llm_handler.initialize(
                checkpoint_dir=checkpoint_dir,
                lm_model_path=lm_model,
                backend=lm_backend,
                device="auto",
                offload_to_cpu=offload,
                dtype=None,
            )
            if success_llm:
                print(f" ✅ LLM ready ({_get_vram_info_str()})")
            else:
                # LM failure is non-fatal: continue the test without it.
                print(f" ⚠️ LLM init failed: {status_llm}")
                use_lm = False
                result_entry["use_lm"] = False
                result_entry["error"] = f"LM init failed (non-fatal): {status_llm}"
        result_entry["init_success"] = True
    except torch.cuda.OutOfMemoryError as e:
        result_entry["error"] = f"Init OOM: {e}"
        print(f" ❌ Init OOM: {e}")
        _cleanup_handlers(dit_handler, llm_handler)
        return result_entry
    except Exception as e:
        result_entry["error"] = f"Init exception: {e}"
        print(f" ❌ Init exception: {e}")
        traceback.print_exc()
        _cleanup_handlers(dit_handler, llm_handler)
        return result_entry
    # Run generation
    try:
        print(f"\n Running generation... ({_get_vram_info_str()})")
        save_dir = tempfile.mkdtemp(prefix=f"acestep_tier{int(sim_gb)}_{test_variant}_")
        params = GenerationParams(
            caption=example_data.get("caption", ""),
            lyrics=example_data.get("lyrics", ""),
            bpm=example_data.get("bpm"),
            keyscale=example_data.get("keyscale", ""),
            timesignature=example_data.get("timesignature", ""),
            vocal_language=example_data.get("language", "unknown"),
            duration=test_duration,
            thinking=use_lm,
            use_cot_metas=use_lm,
            use_cot_caption=False,
            use_cot_language=False,
            use_constrained_decoding=True,
            inference_steps=8,
            seed=42,
            lm_temperature=0.85,
            lm_cfg_scale=2.0,
            guidance_scale=7.0,
        )
        config = GenerationConfig(
            batch_size=batch_size,
            seeds=[42 + j for j in range(batch_size)],
            use_random_seed=False,
            audio_format="flac",
        )
        # When testing batch boundaries, temporarily override the GPU tier config's
        # max_batch limits so that inference.py's clamping doesn't reduce our test
        # batch size. We restore the original values after the test.
        _patched_tier_config = False
        _orig_batch_with_lm = None
        _orig_batch_without_lm = None
        if batch_size_override is not None and batch_size_override > 1:
            from acestep.gpu_config import GPU_TIER_CONFIGS as _tier_configs
            tier = gpu_config.tier
            if tier in _tier_configs:
                _patched_tier_config = True
                _orig_batch_with_lm = _tier_configs[tier]["max_batch_size_with_lm"]
                _orig_batch_without_lm = _tier_configs[tier]["max_batch_size_without_lm"]
                _tier_configs[tier]["max_batch_size_with_lm"] = max(batch_size_override, _orig_batch_with_lm)
                _tier_configs[tier]["max_batch_size_without_lm"] = max(batch_size_override, _orig_batch_without_lm)
        t0 = time.perf_counter()
        try:
            result = generate_music(
                dit_handler, llm_handler, params, config, save_dir=save_dir
            )
        finally:
            # Restore original tier config values
            if _patched_tier_config:
                _tier_configs[tier]["max_batch_size_with_lm"] = _orig_batch_with_lm
                _tier_configs[tier]["max_batch_size_without_lm"] = _orig_batch_without_lm
        wall_time = time.perf_counter() - t0
        result_entry["wall_time"] = wall_time
        result_entry["gen_success"] = result.success
        if result.success:
            tc = result.extra_outputs.get("time_costs", {})
            result_entry["lm_time"] = tc.get("lm_total_time", 0.0)
            result_entry["dit_time"] = tc.get("dit_total_time_cost", 0.0)
            result_entry["vae_time"] = tc.get("dit_vae_decode_time_cost", 0.0)
            n_audios = len(result.audios)
            print(f" ✅ [{test_variant}] Generation OK: {n_audios} audio(s) in {wall_time:.1f}s")
        else:
            result_entry["error"] = result.error
            print(f" ❌ [{test_variant}] Generation FAILED: {result.error}")
        _cleanup_dir(save_dir)
    except torch.cuda.OutOfMemoryError as e:
        result_entry["error"] = f"OOM: {e}"
        print(f" ❌ [{test_variant}] OOM ERROR: {e}")
    except Exception as e:
        result_entry["error"] = f"Generation exception: {e}"
        print(f" ❌ [{test_variant}] Exception: {e}")
        traceback.print_exc()
    # Record peak VRAM observed during this variant (CUDA only).
    if torch.cuda.is_available():
        peak_bytes = torch.cuda.max_memory_allocated()
        result_entry["peak_vram_gb"] = peak_bytes / (1024 ** 3)
        print(f" Peak VRAM: {result_entry['peak_vram_gb']:.2f}GB")
    # Cleanup: always release both handlers and GPU memory.
    _cleanup_handlers(dit_handler, llm_handler)
    return result_entry
def run_tier_test_mode(args):
    """
    Automatically test inference across multiple simulated GPU tiers.

    For each tier:
    1. Set MAX_CUDA_VRAM to simulate the VRAM limit
    2. Initialize gpu_config for that tier
    3. Initialize DiT + (optionally) LLM handlers with tier-appropriate settings
    4. Run a short generation and verify it completes without OOM
    5. Report results

    When --tier-boundary is enabled, each tier is tested with up to 3 configurations:
    - default: tier's default settings (quantization + offload as configured)
    - no-quant: same as default but with quantization disabled
    - no-offload: no quantization AND no CPU offload (all models on GPU)

    This replaces the manual workflow of:
    MAX_CUDA_VRAM=8 uv run acestep → click UI → wait → check

    Returns:
        List of result dicts, one per (tier, variant) combination tested.
    """
    # Determine which tiers to test
    default_tiers = [4, 6, 8, 12, 16, 24, 48]
    tiers_to_test = args.tiers if args.tiers else default_tiers
    # Load example for generation
    example_file = os.path.join(
        PROJECT_ROOT, "examples", "text2music", args.example
    )
    if not os.path.exists(example_file):
        print(f"\n Example not found: {example_file}")
        sys.exit(1)
    with open(example_file, "r", encoding="utf-8") as f:
        example_data = json.load(f)
    # Scan available LM models on disk (directories named acestep-5Hz-lm-*).
    checkpoint_dir = os.path.join(PROJECT_ROOT, "checkpoints")
    disk_lm_models = []
    if os.path.exists(checkpoint_dir):
        for item in sorted(os.listdir(checkpoint_dir)):
            if os.path.isdir(os.path.join(checkpoint_dir, item)) and item.startswith("acestep-5Hz-lm-"):
                disk_lm_models.append(item)
    boundary_mode = getattr(args, "tier_boundary", False)
    batch_boundary_mode = getattr(args, "tier_batch_boundary", False)
    print(f"\n Tiers to test: {tiers_to_test}")
    print(f" LM models on disk: {disk_lm_models}")
    print(f" Test with LM: {args.tier_with_lm}")
    print(f" Test duration: {args.tier_duration}s")
    print(f" Boundary testing: {boundary_mode}")
    print(f" Batch boundary testing: {batch_boundary_mode}")
    print(f" Example: {args.example}")
    # Results collector
    all_results = []
    for sim_gb in tiers_to_test:
        print("\n" + "=" * 120)
        print(f" TIER TEST: {sim_gb}GB simulated VRAM")
        print("=" * 120)
        # Configure GPU simulation via env var (read by downstream code).
        os.environ["MAX_CUDA_VRAM"] = str(sim_gb)
        # Force re-detection of GPU config
        gpu_config = get_gpu_config(gpu_memory_gb=float(sim_gb))
        set_global_gpu_config(gpu_config)
        tier = gpu_config.tier
        print(f" Tier: {tier}")
        print(f" init_lm_default: {gpu_config.init_lm_default}")
        print(f" available_lm_models: {gpu_config.available_lm_models}")
        print(f" recommended_lm_model: {gpu_config.recommended_lm_model}")
        print(f" recommended_backend: {gpu_config.recommended_backend}")
        print(f" lm_backend_restriction: {gpu_config.lm_backend_restriction}")
        print(f" offload_to_cpu: {gpu_config.offload_to_cpu_default}")
        print(f" offload_dit_to_cpu: {gpu_config.offload_dit_to_cpu_default}")
        print(f" quantization: {gpu_config.quantization_default}")
        print(f" max_duration_with_lm: {gpu_config.max_duration_with_lm}s")
        print(f" max_duration_without_lm: {gpu_config.max_duration_without_lm}s")
        print(f" max_batch_with_lm: {gpu_config.max_batch_size_with_lm}")
        print(f" max_batch_without_lm: {gpu_config.max_batch_size_without_lm}")
        # ---- Test 1: Default configuration ----
        print(f"\n --- Variant: default ---")
        result_default = _run_single_tier_test(
            sim_gb, gpu_config, args, example_data,
            checkpoint_dir, disk_lm_models,
            test_variant="default",
        )
        all_results.append(result_default)
        if boundary_mode:
            # ---- Test 2: No quantization (keep offload as default) ----
            # Skip if the tier already doesn't use quantization (no point re-testing)
            if gpu_config.quantization_default:
                print(f"\n --- Variant: no-quant (offload={gpu_config.offload_to_cpu_default}) ---")
                result_no_quant = _run_single_tier_test(
                    sim_gb, gpu_config, args, example_data,
                    checkpoint_dir, disk_lm_models,
                    quantization_override=None,
                    test_variant="no-quant",
                )
                all_results.append(result_no_quant)
            else:
                print(f"\n --- Variant: no-quant — SKIPPED (tier already has quantization=False) ---")
            # ---- Test 3: No quantization AND no offload ----
            # Skip if the tier already has both disabled
            # Also skip if simulated VRAM is too small — the unquantized DiT model
            # alone needs ~6GB; without offload there is no room left for VAE decode,
            # which causes a fallback to CPU VAE with tiny chunk_size and 20+ hour runs.
            MIN_VRAM_FOR_NO_OFFLOAD = 8  # GB — DiT (~6GB) + VAE headroom (~2GB)
            if sim_gb < MIN_VRAM_FOR_NO_OFFLOAD:
                print(f"\n --- Variant: no-offload — SKIPPED (simulated {sim_gb}GB < {MIN_VRAM_FOR_NO_OFFLOAD}GB minimum for no-offload) ---")
            elif gpu_config.quantization_default or gpu_config.offload_to_cpu_default:
                print(f"\n --- Variant: no-offload (quant=None, offload=False) ---")
                result_no_offload = _run_single_tier_test(
                    sim_gb, gpu_config, args, example_data,
                    checkpoint_dir, disk_lm_models,
                    offload_override=False,
                    offload_dit_override=False,
                    quantization_override=None,
                    test_variant="no-offload",
                )
                all_results.append(result_no_offload)
            else:
                print(f"\n --- Variant: no-offload — SKIPPED (tier already has offload=False, quant=False) ---")
        if batch_boundary_mode:
            # ---- Batch boundary tests: escalate batch size until OOM ----
            BATCH_SIZES_TO_TEST = [1, 2, 4, 8]
            # Test WITHOUT LM
            print(f"\n --- Batch boundary: without LM ---")
            for bs in BATCH_SIZES_TO_TEST:
                print(f"\n --- Variant: batch-noLM-{bs} (batch_size={bs}, no LM) ---")
                result_batch = _run_single_tier_test(
                    sim_gb, gpu_config, args, example_data,
                    checkpoint_dir, disk_lm_models,
                    test_variant=f"batch-noLM-{bs}",
                    batch_size_override=bs,
                    use_lm_override=False,
                )
                all_results.append(result_batch)
                # Larger batches will only use more memory — stop escalating.
                if not result_batch["gen_success"]:
                    print(f" ⚠️ Batch size {bs} failed without LM — stopping escalation")
                    break
            # Test WITH LM (if tier supports it)
            if gpu_config.init_lm_default and bool(gpu_config.available_lm_models):
                print(f"\n --- Batch boundary: with LM ---")
                for bs in BATCH_SIZES_TO_TEST:
                    print(f"\n --- Variant: batch-LM-{bs} (batch_size={bs}, with LM) ---")
                    result_batch_lm = _run_single_tier_test(
                        sim_gb, gpu_config, args, example_data,
                        checkpoint_dir, disk_lm_models,
                        test_variant=f"batch-LM-{bs}",
                        batch_size_override=bs,
                        use_lm_override=True,
                    )
                    all_results.append(result_batch_lm)
                    if not result_batch_lm["gen_success"]:
                        print(f" ⚠️ Batch size {bs} failed with LM — stopping escalation")
                        break
    # ---- Print summary ----
    _print_tier_test_summary(all_results)
    if boundary_mode:
        _print_boundary_summary(all_results)
    if batch_boundary_mode:
        _print_batch_boundary_summary(all_results)
    # Save results (default=str handles non-JSON-serializable values).
    if args.benchmark_output:
        with open(args.benchmark_output, "w", encoding="utf-8") as f:
            json.dump(all_results, f, indent=2, default=str)
        print(f"\n Results saved to: {args.benchmark_output}")
    return all_results
def _cleanup_handlers(dit_handler, llm_handler):
"""Clean up handlers and free GPU memory."""
try:
if dit_handler is not None:
if hasattr(dit_handler, 'model') and dit_handler.model is not None:
dit_handler.model = None
if hasattr(dit_handler, 'vae') and dit_handler.vae is not None:
dit_handler.vae = None
if hasattr(dit_handler, 'text_encoder') and dit_handler.text_encoder is not None:
dit_handler.text_encoder = None
del dit_handler
except Exception:
pass
try:
if llm_handler is not None:
if hasattr(llm_handler, 'llm') and llm_handler.llm is not None:
llm_handler.llm = None
del llm_handler
except Exception:
pass
import gc
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
def _print_tier_test_summary(results: List[Dict]):
"""Print a summary table of all tier test results."""
# Detect if any result has a test_variant (boundary mode)
has_variants = any(r.get("test_variant", "default") != "default" for r in results)
print("\n" + "=" * 160)
print("TIER TEST SUMMARY")
print("=" * 160)
if has_variants:
header = (
f"{'VRAM':>6} {'Tier':<10} {'Variant':<12} {'LM':>4} {'LM Model':<24} {'Backend':<8} "
f"{'Offload':<8} {'Quant':<6} {'Init':>5} {'Gen':>5} "
f"{'Wall(s)':>8} {'Peak(GB)':>9} {'Status':<30}"
)
else:
header = (
f"{'VRAM':>6} {'Tier':<10} {'LM':>4} {'LM Model':<28} {'Backend':<8} "
f"{'Offload':<8} {'Quant':<6} {'Init':>5} {'Gen':>5} "
f"{'Wall(s)':>8} {'Peak(GB)':>9} {'Status':<30}"
)
print(header)
print("-" * 160)
pass_count = 0
fail_count = 0
for r in results:
lm_model_short = (r.get("lm_model") or "-")
max_lm_len = 22 if has_variants else 26
if len(lm_model_short) > max_lm_len:
lm_model_short = lm_model_short[:max_lm_len] + ".."
init_ok = "✅" if r["init_success"] else "❌"
gen_ok = "✅" if r["gen_success"] else "❌"
status = "PASS" if r["gen_success"] else (r.get("error", "FAIL") or "FAIL")
if len(status) > 28:
status = status[:28] + ".."
if r["gen_success"]:
pass_count += 1
else:
fail_count += 1
quant = "int8" if r.get("quantization") else "-"
variant = r.get("test_variant", "default")
if has_variants:
print(
f"{r['tier_gb']:5d}GB {r['tier']:<10} {variant:<12} "
f"{'Y' if r['use_lm'] else 'N':>4} {lm_model_short:<24} "
f"{r.get('lm_backend', '-'):<8} "
f"{'Y' if r['offload'] else 'N':<8} {quant:<6} "
f"{init_ok:>5} {gen_ok:>5} "
f"{r['wall_time']:>8.1f} {r.get('peak_vram_gb', 0):>9.2f} "
f"{status:<30}"
)
else:
print(
f"{r['tier_gb']:5d}GB {r['tier']:<10} "
f"{'Y' if r['use_lm'] else 'N':>4} {lm_model_short:<28} "
f"{r.get('lm_backend', '-'):<8} "
f"{'Y' if r['offload'] else 'N':<8} {quant:<6} "
f"{init_ok:>5} {gen_ok:>5} "
f"{r['wall_time']:>8.1f} {r.get('peak_vram_gb', 0):>9.2f} "
f"{status:<30}"
)
print("-" * 160)
print(f" Total: {len(results)} tests run, {pass_count} PASSED, {fail_count} FAILED")
def _print_boundary_summary(results: List[Dict]):
"""
Print a boundary analysis summary showing the minimum tier for each capability.
Analyzes results from boundary testing to determine:
- Minimum tier that works WITHOUT INT8 quantization
- Minimum tier that works WITHOUT CPU offload (and without quantization)
"""
print("\n" + "=" * 100)
print("BOUNDARY ANALYSIS")
print("=" * 100)
print()
print(" This analysis shows the minimum VRAM tier at which each optimization")
print(" can be safely disabled while still completing inference successfully.")
print()
# Collect results by variant
no_quant_results = [r for r in results if r.get("test_variant") == "no-quant"]
no_offload_results = [r for r in results if r.get("test_variant") == "no-offload"]
default_results = [r for r in results if r.get("test_variant") == "default"]
# Also consider default results where the tier already has quant/offload disabled
# (e.g., tier6b default already has quantization=False)
for r in default_results:
if not r.get("quantization") and r not in no_quant_results:
# This tier's default already runs without quantization
no_quant_results.append(r)
if not r.get("offload") and not r.get("quantization") and r not in no_offload_results:
# This tier's default already runs without offload and without quantization
no_offload_results.append(r)
# Sort by VRAM
no_quant_results.sort(key=lambda r: r["tier_gb"])
no_offload_results.sort(key=lambda r: r["tier_gb"])
# Find minimum passing tier for each capability
def _find_min_passing(result_list, capability_name):
passing = [r for r in result_list if r.get("gen_success")]
failing = [r for r in result_list if not r.get("gen_success")]
if passing:
min_pass = passing[0]
print(f" {capability_name}:")
print(f" Minimum tier: {min_pass['tier']} ({min_pass['tier_gb']}GB)")
print(f" Peak VRAM: {min_pass.get('peak_vram_gb', 0):.2f}GB")
if failing:
max_fail = failing[-1]
print(f" Last failure: {max_fail['tier']} ({max_fail['tier_gb']}GB) — {max_fail.get('error', 'unknown')[:60]}")
else:
if failing:
print(f" {capability_name}:")
print(f" ❌ No tier passed this test. All tested tiers failed.")
for r in failing:
err = (r.get("error") or "unknown")[:50]
print(f" {r['tier_gb']}GB ({r['tier']}): {err}")
else:
print(f" {capability_name}:")
print(f" ⚠️ No test results available for this capability.")
print()
return passing[0] if passing else None
min_no_quant = _find_min_passing(no_quant_results, "Without INT8 Quantization")
min_no_offload = _find_min_passing(no_offload_results, "Without CPU Offload (and no quantization)")
# Print compact summary table
print(" " + "-" * 60)
print(f" {'Capability':<45} {'Min Tier':<10} {'VRAM':>6}")
print(" " + "-" * 60)
if min_no_quant:
print(f" {'No INT8 Quantization':<45} {min_no_quant['tier']:<10} {min_no_quant['tier_gb']:>5}GB")
else:
print(f" {'No INT8 Quantization':<45} {'N/A':<10} {'N/A':>6}")
if min_no_offload:
print(f" {'No CPU Offload (all models on GPU)':<45} {min_no_offload['tier']:<10} {min_no_offload['tier_gb']:>5}GB")
else:
print(f" {'No CPU Offload (all models on GPU)':<45} {'N/A':<10} {'N/A':>6}")
print(" " + "-" * 60)
print()
print(" Note: These boundaries are empirical and may vary based on:")
print(" - DiT model variant (turbo vs base)")
print(" - Whether LM is enabled (--tier-with-lm)")
print(" - Generation duration and batch size")
print(" - Flash attention availability")
def _print_batch_boundary_summary(results: List[Dict]):
    """
    Print a batch boundary analysis summary showing the maximum safe batch size per tier.

    Analyzes results from batch boundary testing to determine:
    - Maximum batch size WITHOUT LM for each tier
    - Maximum batch size WITH LM for each tier

    Expects each result dict to carry at least ``test_variant``, ``tier_gb``
    and ``tier``; ``gen_success``, ``batch_size`` and ``peak_vram_gb`` are
    read with defaults when absent.
    """
    print("\n" + "=" * 120)
    print("BATCH BOUNDARY ANALYSIS")
    print("=" * 120)
    print()
    print(" This analysis shows the maximum batch size that completed successfully")
    print(" for each simulated VRAM tier.")
    print()
    # Collect batch boundary results
    # (variants are tagged "batch-noLM-<bs>" / "batch-LM-<bs>" by the test driver)
    batch_no_lm = [r for r in results if r.get("test_variant", "").startswith("batch-noLM-")]
    batch_with_lm = [r for r in results if r.get("test_variant", "").startswith("batch-LM-")]
    # Group by tier_gb
    def _group_by_tier(result_list):
        # Maps tier_gb -> {"tier": <name>, "results": [<result dicts>]},
        # preserving the encounter order of results within each tier.
        groups = {}
        for r in result_list:
            tier_gb = r["tier_gb"]
            if tier_gb not in groups:
                groups[tier_gb] = {"tier": r["tier"], "results": []}
            groups[tier_gb]["results"].append(r)
        return groups
    no_lm_groups = _group_by_tier(batch_no_lm)
    with_lm_groups = _group_by_tier(batch_with_lm)
    # Find max passing batch per tier
    def _max_passing_batch(group_results):
        # Returns (largest successful batch size, peak VRAM at that size).
        # (0, 0.0) means no batch size in this group succeeded.
        max_bs = 0
        peak_vram = 0.0
        for r in group_results:
            if r.get("gen_success"):
                bs = r.get("batch_size", 1)
                if bs > max_bs:
                    max_bs = bs
                    # Peak tracks the VRAM of the largest passing batch only.
                    peak_vram = r.get("peak_vram_gb", 0)
        return max_bs, peak_vram
    # Collect all tier_gb values
    all_tier_gbs = sorted(set(list(no_lm_groups.keys()) + list(with_lm_groups.keys())))
    # Print table
    print(f" {'VRAM':>6} {'Tier':<12} {'Max Batch (no LM)':>18} {'Peak VRAM':>10} {'Max Batch (with LM)':>20} {'Peak VRAM':>10}")
    print(" " + "-" * 90)
    summary_rows = []
    for tier_gb in all_tier_gbs:
        # The with_lm fallback is evaluated eagerly, but every tier_gb in
        # all_tier_gbs is present in at least one group, so the lookup resolves.
        tier_name = no_lm_groups.get(tier_gb, with_lm_groups.get(tier_gb, {})).get("tier", "?")
        no_lm_max, no_lm_peak = (0, 0.0)
        if tier_gb in no_lm_groups:
            no_lm_max, no_lm_peak = _max_passing_batch(no_lm_groups[tier_gb]["results"])
        with_lm_max, with_lm_peak = (0, 0.0)
        if tier_gb in with_lm_groups:
            with_lm_max, with_lm_peak = _max_passing_batch(with_lm_groups[tier_gb]["results"])
        # "N/A" = tier never tested with LM; "FAIL" = tested but no batch passed.
        no_lm_str = str(no_lm_max) if no_lm_max > 0 else "FAIL"
        with_lm_str = str(with_lm_max) if with_lm_max > 0 else ("N/A" if tier_gb not in with_lm_groups else "FAIL")
        no_lm_peak_str = f"{no_lm_peak:.2f}GB" if no_lm_max > 0 else "-"
        with_lm_peak_str = f"{with_lm_peak:.2f}GB" if with_lm_max > 0 else "-"
        print(
            f" {tier_gb:5d}GB {tier_name:<12} {no_lm_str:>18} {no_lm_peak_str:>10} "
            f"{with_lm_str:>20} {with_lm_peak_str:>10}"
        )
        summary_rows.append({
            "tier_gb": tier_gb,
            "tier": tier_name,
            "max_batch_no_lm": no_lm_max,
            "max_batch_with_lm": with_lm_max if tier_gb in with_lm_groups else None,
        })
    print(" " + "-" * 90)
    print()
    # Print comparison with current GPU_TIER_CONFIGS
    print(" Comparison with current GPU_TIER_CONFIGS:")
    print(f" {'VRAM':>6} {'Tier':<12} {'Config (no LM)':>15} {'Tested (no LM)':>15} {'Config (LM)':>12} {'Tested (LM)':>12} {'Recommendation':<30}")
    print(" " + "-" * 110)
    for row in summary_rows:
        tier_gb = row["tier_gb"]
        tier_name = row["tier"]
        # NOTE(review): assumes get_gpu_config (defined elsewhere) accepts a
        # gpu_memory_gb keyword and exposes max_batch_size_without_lm /
        # max_batch_size_with_lm — confirm against acestep's GPU config module.
        cfg = get_gpu_config(gpu_memory_gb=float(tier_gb))
        cfg_no_lm = cfg.max_batch_size_without_lm
        cfg_with_lm = cfg.max_batch_size_with_lm
        tested_no_lm = row["max_batch_no_lm"]
        tested_with_lm = row["max_batch_with_lm"]
        tested_no_lm_str = str(tested_no_lm) if tested_no_lm > 0 else "FAIL"
        tested_with_lm_str = str(tested_with_lm) if tested_with_lm is not None and tested_with_lm > 0 else ("N/A" if tested_with_lm is None else "FAIL")
        # Recommendation: flag any mismatch between configured and measured max.
        rec_parts = []
        if tested_no_lm > 0 and tested_no_lm != cfg_no_lm:
            rec_parts.append(f"no_lm: {cfg_no_lm}→{tested_no_lm}")
        if tested_with_lm is not None and tested_with_lm > 0 and tested_with_lm != cfg_with_lm:
            rec_parts.append(f"lm: {cfg_with_lm}→{tested_with_lm}")
        recommendation = ", ".join(rec_parts) if rec_parts else "OK"
        print(
            f" {tier_gb:5d}GB {tier_name:<12} {cfg_no_lm:>15} {tested_no_lm_str:>15} "
            f"{cfg_with_lm:>12} {tested_with_lm_str:>12} {recommendation:<30}"
        )
    print(" " + "-" * 110)
    print()
    print(" Note: Batch boundary results are empirical and depend on:")
    print(" - DiT model variant (turbo vs base)")
    print(" - Generation duration (longer = more VRAM per batch)")
    print(" - Flash attention availability")
    print(" - LM model size (0.6B vs 1.7B vs 4B)")
    print(" - Quantization and offload settings")
# =============================================================================
# Mode: understand
# =============================================================================
def run_understand_mode(dit_handler, llm_handler, args, timer: PreciseTimer):
    """Profile the understand_music() API (audio codes -> music metadata)."""
    if not llm_handler.llm_initialized:
        print("\n LLM not initialized. understand mode requires LLM.")
        print(" Re-run with --thinking or ensure LLM is available.")
        sys.exit(1)
    audio_codes = args.audio_codes or ""
    codes_label = "<provided>" if audio_codes else "<empty - will generate sample>"
    print(f"\n Audio codes: {codes_label}")
    # Synchronize before/after so wall time covers the full device-side work.
    timer.sync()
    started = time.perf_counter()
    result = understand_music(
        llm_handler=llm_handler,
        audio_codes=audio_codes,
        temperature=args.lm_temperature,
        use_constrained_decoding=args.use_constrained_decoding,
    )
    timer.sync()
    wall_time = time.perf_counter() - started
    print(f"\n Wall time: {wall_time:.3f}s")
    print(f" Success: {result.success}")
    if not result.success:
        print(f" Error: {result.error}")
        return result, wall_time
    print(f" Caption: {result.caption[:100]}...")
    print(f" BPM: {result.bpm}, Duration: {result.duration}, Key: {result.keyscale}")
    print(f" Language: {result.language}, Time Sig: {result.timesignature}")
    if result.lyrics:
        print(f" Lyrics: {result.lyrics[:100]}...")
    return result, wall_time
# =============================================================================
# Mode: create_sample
# =============================================================================
def run_create_sample_mode(dit_handler, llm_handler, args, timer: PreciseTimer):
    """Profile the create_sample() API (inspiration/simple mode)."""
    if not llm_handler.llm_initialized:
        print("\n LLM not initialized. create_sample mode requires LLM.")
        sys.exit(1)
    query = args.sample_query or "a soft love song for a quiet evening"
    print(f"\n Query: {query}")
    print(f" Instrumental: {args.instrumental}")
    # Synchronize before/after so wall time covers the full device-side work.
    timer.sync()
    started = time.perf_counter()
    result = create_sample(
        llm_handler=llm_handler,
        query=query,
        instrumental=args.instrumental,
        temperature=args.lm_temperature,
        use_constrained_decoding=args.use_constrained_decoding,
    )
    timer.sync()
    wall_time = time.perf_counter() - started
    print(f"\n Wall time: {wall_time:.3f}s")
    print(f" Success: {result.success}")
    if not result.success:
        print(f" Error: {result.error}")
        return result, wall_time
    print(f" Caption: {result.caption[:100]}...")
    print(f" BPM: {result.bpm}, Duration: {result.duration}, Key: {result.keyscale}")
    print(f" Language: {result.language}, Time Sig: {result.timesignature}")
    print(f" Instrumental: {result.instrumental}")
    if result.lyrics:
        print(f" Lyrics: {result.lyrics[:100]}...")
    return result, wall_time
# =============================================================================
# Mode: format_sample
# =============================================================================
def run_format_sample_mode(dit_handler, llm_handler, args, timer: PreciseTimer):
    """Profile the format_sample() API (caption + lyrics -> formatted metadata)."""
    if not llm_handler.llm_initialized:
        print("\n LLM not initialized. format_sample mode requires LLM.")
        sys.exit(1)
    example_file = os.path.join(PROJECT_ROOT, "examples", "text2music", args.example)
    if not os.path.exists(example_file):
        print(f"\n Example not found: {example_file}")
        sys.exit(1)
    with open(example_file, "r", encoding="utf-8") as f:
        payload = json.load(f)
    caption = payload.get("caption", "Latin pop, reggaeton")
    lyrics = payload.get("lyrics", "[Verse 1]\nHola mundo")
    print(f"\n Caption: {caption[:80]}...")
    print(f" Lyrics: {lyrics[:80]}...")
    # Synchronize before/after so wall time covers the full device-side work.
    timer.sync()
    started = time.perf_counter()
    result = format_sample(
        llm_handler=llm_handler,
        caption=caption,
        lyrics=lyrics,
        temperature=args.lm_temperature,
        use_constrained_decoding=args.use_constrained_decoding,
    )
    timer.sync()
    wall_time = time.perf_counter() - started
    print(f"\n Wall time: {wall_time:.3f}s")
    print(f" Success: {result.success}")
    if not result.success:
        print(f" Error: {result.error}")
        return result, wall_time
    print(f" Caption: {result.caption[:100]}...")
    print(f" BPM: {result.bpm}, Duration: {result.duration}, Key: {result.keyscale}")
    print(f" Language: {result.language}, Time Sig: {result.timesignature}")
    return result, wall_time
# =============================================================================
# cProfile helper
# =============================================================================
def _print_cprofile(prof):
"""Print cProfile results and save to file."""
import pstats
import io
output_file = "profile_cprofile_detailed.txt"
with open(output_file, "w") as f:
ps = pstats.Stats(prof, stream=f)
ps.sort_stats("cumulative")
ps.print_stats(100)
print("\n" + "=" * 100)
print("TOP 20 FUNCTIONS BY CUMULATIVE TIME (cProfile)")
print("=" * 100)
s = io.StringIO()
ps = pstats.Stats(prof, stream=s)
ps.sort_stats("cumulative")
ps.print_stats(20)
print(s.getvalue())
print(f"Full report saved to: {output_file}")
def _cleanup_dir(path: str):
"""Remove temporary directory silently."""
try:
import shutil
shutil.rmtree(path, ignore_errors=True)
except Exception:
pass
# =============================================================================
# Handler initialization (for non-tier-test modes)
# =============================================================================
def initialize_handlers(
    args, device: str
) -> Tuple[AceStepHandler, LLMHandler]:
    """Initialize DiT and LLM handlers with current API.

    The DiT handler is always initialized (exits on failure). The LLM handler
    is only loaded when a thinking/CoT flag is set or the selected mode is an
    LLM-only API; LLM failure is fatal only for those LLM-only modes.
    """
    dit_handler = AceStepHandler()
    llm_handler = LLMHandler()
    # Flash attention is only attempted on CUDA, and only if the package imports.
    has_flash_attn = False
    if device.startswith("cuda"):
        try:
            import flash_attn  # noqa: F401
            has_flash_attn = True
        except ImportError:
            pass
    compile_flag = os.environ.get("ACESTEP_COMPILE_MODEL", "").strip().lower()
    compile_model = compile_flag in {"1", "true", "yes", "y", "on"}
    print(" Initializing DiT handler...")
    status_dit, success_dit = dit_handler.initialize_service(
        project_root=PROJECT_ROOT,
        config_path=args.config_path,
        device=args.device,  # Pass original device string (handler resolves "auto")
        use_flash_attention=has_flash_attn,
        compile_model=compile_model,
        offload_to_cpu=args.offload_to_cpu,
        offload_dit_to_cpu=args.offload_dit_to_cpu,
        quantization=args.quantization,
    )
    if not success_dit:
        print(f" DiT initialization failed: {status_dit}")
        sys.exit(1)
    print(f" DiT ready (device={dit_handler.device})")
    # The LLM is needed for any CoT/thinking feature and for LLM-only modes.
    llm_modes = ("understand", "create_sample", "format_sample")
    need_llm = (
        args.thinking
        or args.use_cot_metas
        or args.use_cot_caption
        or args.use_cot_language
        or args.mode in llm_modes
    )
    if not need_llm:
        print(
            " LLM not needed for current configuration "
            "(thinking/CoT disabled)"
        )
        return dit_handler, llm_handler
    print(f" Initializing LLM handler (backend={args.lm_backend})...")
    status_llm, success_llm = llm_handler.initialize(
        checkpoint_dir=os.path.join(PROJECT_ROOT, "checkpoints"),
        lm_model_path=args.lm_model,
        backend=args.lm_backend,
        device=args.device,
        offload_to_cpu=args.offload_to_cpu,
        dtype=None,
    )
    if success_llm:
        print(f" LLM ready (backend={llm_handler.llm_backend})")
    else:
        print(f" LLM initialization failed: {status_llm}")
        if args.mode in llm_modes:
            sys.exit(1)
    return dit_handler, llm_handler
# =============================================================================
# CLI argument parser
# =============================================================================
def build_parser() -> argparse.ArgumentParser:
    """Build the argument parser with all options.

    Defaults for device, backend, and model paths come from the environment
    configuration (load_env_config), so CLI flags only need to override them.

    Returns:
        argparse.ArgumentParser: fully configured parser for this tool.
    """
    env_config = load_env_config()
    parser = argparse.ArgumentParser(
        description="ACE-Step 1.5 Inference Profiler & Benchmark",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python profile_inference.py # Profile text2music
python profile_inference.py --thinking --llm-debug # With LLM analysis
python profile_inference.py --mode benchmark # Benchmark matrix
python profile_inference.py --mode tier-test # Test all GPU tiers
python profile_inference.py --mode tier-test --tiers 6 8 16 # Test specific tiers
python profile_inference.py --mode tier-test --tier-with-lm # Test tiers with LM
python profile_inference.py --mode understand # Profile understand API
python profile_inference.py --mode create_sample --sample-query "jazz ballad"
python profile_inference.py --device mps --lm-backend mlx # Apple Silicon
python profile_inference.py --device cuda --lm-backend vllm # NVIDIA GPU
""",
    )
    # Mode
    parser.add_argument(
        "--mode",
        type=str,
        default="profile",
        choices=[
            "profile",
            "benchmark",
            "tier-test",
            "understand",
            "create_sample",
            "format_sample",
        ],
        help="Profiling mode (default: profile)",
    )
    # Device & backend
    parser.add_argument(
        "--device",
        type=str,
        default=env_config["ACESTEP_DEVICE"],
        help=(
            f"Device: auto/cuda/mps/cpu "
            f"(default: {env_config['ACESTEP_DEVICE']})"
        ),
    )
    parser.add_argument(
        "--lm-backend",
        type=str,
        default=env_config["ACESTEP_LM_BACKEND"],
        choices=["auto", "vllm", "pt", "mlx"],
        help=(
            f"LLM backend "
            f"(default: {env_config['ACESTEP_LM_BACKEND']})"
        ),
    )
    # Model paths
    parser.add_argument(
        "--config-path",
        type=str,
        default=env_config["ACESTEP_CONFIG_PATH"],
        help=(
            f"DiT model config "
            f"(default: {env_config['ACESTEP_CONFIG_PATH']})"
        ),
    )
    parser.add_argument(
        "--lm-model",
        type=str,
        default=env_config["ACESTEP_LM_MODEL_PATH"],
        help=(
            f"LLM model path "
            f"(default: {env_config['ACESTEP_LM_MODEL_PATH']})"
        ),
    )
    # Hardware options
    parser.add_argument(
        "--offload-to-cpu",
        action="store_true",
        help="Offload models to CPU when not in use",
    )
    parser.add_argument(
        "--offload-dit-to-cpu",
        action="store_true",
        help="Offload DiT to CPU when not in use",
    )
    parser.add_argument(
        "--quantization",
        type=str,
        default=None,
        choices=["int8_weight_only", "fp8_weight_only", "w8a8_dynamic"],
        help="Quantization mode for DiT model",
    )
    # Example & input
    parser.add_argument(
        "--example",
        type=str,
        default="example_05.json",
        help="Example JSON file from examples/text2music/",
    )
    # Task type
    parser.add_argument(
        "--task-type",
        type=str,
        default="text2music",
        choices=[
            "text2music",
            "cover",
            "repaint",
            "lego",
            "extract",
            "complete",
        ],
        help="Generation task type (default: text2music)",
    )
    parser.add_argument(
        "--reference-audio",
        type=str,
        default=None,
        help="Reference audio path (for cover/style transfer)",
    )
    parser.add_argument(
        "--src-audio",
        type=str,
        default=None,
        help="Source audio path (for audio-to-audio tasks)",
    )
    # Generation parameters
    parser.add_argument(
        "--duration",
        type=float,
        default=None,
        help="Audio duration in seconds (overrides example)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=None,
        help="Batch size (overrides example)",
    )
    parser.add_argument(
        "--inference-steps",
        type=int,
        default=None,
        help="Diffusion inference steps (overrides example)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed (overrides example)",
    )
    parser.add_argument(
        "--guidance-scale",
        type=float,
        default=7.0,
        help="CFG guidance scale for DiT (default: 7.0)",
    )
    # LLM / CoT parameters
    parser.add_argument(
        "--thinking",
        action="store_true",
        help="Enable 5Hz LM Chain-of-Thought reasoning",
    )
    parser.add_argument(
        "--use-cot-metas",
        action="store_true",
        help="Enable LLM to generate music metadata via CoT",
    )
    parser.add_argument(
        "--use-cot-caption",
        action="store_true",
        help="Enable LLM to rewrite/format caption via CoT",
    )
    parser.add_argument(
        "--use-cot-language",
        action="store_true",
        help="Enable LLM to detect vocal language via CoT",
    )
    # Default-on flag; main() flips it off when --no-constrained-decoding is set.
    parser.add_argument(
        "--use-constrained-decoding",
        action="store_true",
        default=True,
        help="Use FSM-based constrained decoding (default: True)",
    )
    parser.add_argument(
        "--no-constrained-decoding",
        action="store_true",
        help="Disable constrained decoding",
    )
    parser.add_argument(
        "--lm-temperature",
        type=float,
        default=0.85,
        help="LLM sampling temperature (default: 0.85)",
    )
    parser.add_argument(
        "--lm-cfg-scale",
        type=float,
        default=2.0,
        help="LLM CFG scale (default: 2.0)",
    )
    # Profiling options
    parser.add_argument(
        "--no-warmup",
        action="store_true",
        help="Skip warmup run (includes compilation overhead)",
    )
    parser.add_argument(
        "--detailed",
        action="store_true",
        help="Enable cProfile function-level analysis",
    )
    parser.add_argument(
        "--llm-debug",
        action="store_true",
        help="Enable deep LLM debugging (token count, throughput)",
    )
    # Benchmark options
    parser.add_argument(
        "--benchmark-output",
        type=str,
        default=None,
        help="Save benchmark results to JSON file",
    )
    # Tier-test options
    parser.add_argument(
        "--tiers",
        type=int,
        nargs="+",
        default=None,
        help="Specific VRAM tiers to test (e.g., --tiers 6 8 16). Default: all tiers",
    )
    parser.add_argument(
        "--tier-with-lm",
        action="store_true",
        help="Enable LM for tiers that support it (default: DiT-only test)",
    )
    # Fix: default was the int 240 while type=float; argparse does not convert
    # defaults, so the unset value's type now matches the declared float type.
    parser.add_argument(
        "--tier-duration",
        type=float,
        default=240.0,
        help="Test generation duration in seconds for tier-test (default: 240)",
    )
    parser.add_argument(
        "--tier-skip-compile",
        action="store_true",
        help="Skip torch.compile for non-quantized tiers (faster testing, less realistic)",
    )
    parser.add_argument(
        "--tier-boundary",
        action="store_true",
        help="Enable boundary testing: for each tier, also test without INT8 quantization "
        "and without CPU offload to find the minimum VRAM tier for each capability",
    )
    parser.add_argument(
        "--tier-batch-boundary",
        action="store_true",
        help="Enable batch size boundary testing: for each tier, progressively test "
        "batch sizes 1, 2, 4, 8 (stop at first OOM) to find the maximum safe batch "
        "size. Tests both with-LM and without-LM configurations.",
    )
    # create_sample / understand options
    parser.add_argument(
        "--sample-query",
        type=str,
        default=None,
        help="Query for create_sample mode",
    )
    parser.add_argument(
        "--instrumental",
        action="store_true",
        help="Generate instrumental music (for create_sample)",
    )
    parser.add_argument(
        "--audio-codes",
        type=str,
        default=None,
        help="Audio codes string (for understand mode)",
    )
    return parser
# =============================================================================
# Main
# =============================================================================
def main():
    """CLI entry point: parse arguments, set up hardware, dispatch to a mode."""
    args = build_parser().parse_args()
    # --no-constrained-decoding overrides the default-on constrained decoding.
    if args.no_constrained_decoding:
        args.use_constrained_decoding = False
    # Tier-test mode has its own initialization flow
    if args.mode == "tier-test":
        banner = "=" * 120
        print(banner)
        print("ACE-Step 1.5 Tier Compatibility Test")
        print(banner)
        run_tier_test_mode(args)
        print("\n" + banner)
        print("DONE")
        print(banner)
        return
    # Resolve device and auto-detect the LLM backend for it.
    device = resolve_device(args.device)
    if args.lm_backend == "auto":
        args.lm_backend = auto_detect_backend(device)
    # Setup GPU config
    gpu_config = get_gpu_config()
    set_global_gpu_config(gpu_config)
    # Small GPUs get CPU offload enabled automatically.
    if (
        0 < gpu_config.gpu_memory_gb < VRAM_AUTO_OFFLOAD_THRESHOLD_GB
        and not args.offload_to_cpu
    ):
        args.offload_to_cpu = True
    # Print header
    wide = "=" * 100
    print(wide)
    print("ACE-Step 1.5 Inference Profiler")
    print(wide)
    print(f"\n Mode: {args.mode}")
    print(f" Device: {device} (requested: {args.device})")
    print(f" LLM Backend: {args.lm_backend}")
    print(f" DiT Config: {args.config_path}")
    print(f" LLM Model: {args.lm_model}")
    print(
        f" GPU Memory: {gpu_config.gpu_memory_gb:.1f} GB "
        f"(tier: {gpu_config.tier})"
    )
    if args.quantization:
        print(f" Quantization: {args.quantization}")
    if args.offload_to_cpu:
        print(" CPU Offload: enabled")
    print(f"\n Thinking: {args.thinking}")
    print(f" CoT Metas: {args.use_cot_metas}")
    print(f" CoT Caption: {args.use_cot_caption}")
    print(f" CoT Language: {args.use_cot_language}")
    print(f" Constrained: {args.use_constrained_decoding}")
    print(f" Warmup: {'disabled' if args.no_warmup else 'enabled'}")
    # Initialize handlers
    print("\n" + "-" * 100)
    print("INITIALIZING MODELS")
    print("-" * 100)
    dit_handler, llm_handler = initialize_handlers(args, device)
    # Create timer with the device the handler actually resolved.
    timer = PreciseTimer(device=getattr(dit_handler, "device", device))
    # Dispatch to mode
    print("\n" + wide)
    print(f"RUNNING MODE: {args.mode.upper()}")
    print(wide)
    runners = {
        "profile": run_profile_mode,
        "benchmark": run_benchmark_mode,
        "understand": run_understand_mode,
        "create_sample": run_create_sample_mode,
        "format_sample": run_format_sample_mode,
    }
    runner = runners.get(args.mode)
    if runner is not None:
        runner(dit_handler, llm_handler, args, timer)
    print("\n" + wide)
    print("DONE")
    print(wide)
if __name__ == "__main__":
main()