ChuxiJ committed on
Commit 3a962ef · verified · 1 Parent(s): d44f408

Initial upload: ACE-Step v1.5 1D VAE (stable-audio-tools format)

Files changed (4)
  1. README.md +130 -0
  2. checkpoint.ckpt +3 -0
  3. config.json +123 -0
  4. stable_audio_vae.py +205 -0
README.md ADDED
---
library_name: stable-audio-tools
license: mit
pipeline_tag: text-to-audio
tags:
- audio
- music
- vae
- autoencoder
- ace-step
- stable-audio-tools
---

<h1 align="center">ACE-Step v1.5 1D VAE</h1>
<h1 align="center">Stable Audio Tools Format</h1>
<p align="center">
    <a href="https://github.com/ACE-Step/ACE-Step-1.5">GitHub</a> |
    <a href="https://ace-step.github.io/ace-step-v1.5.github.io/">Project</a> |
    <a href="https://huggingface.co/collections/ACE-Step/ace-step-15">Hugging Face</a> |
    <a href="https://huggingface.co/spaces/ACE-Step/Ace-Step-v1.5">Space Demo</a> |
    <a href="https://discord.gg/PeWDxrkdj7">Discord</a> |
    <a href="https://arxiv.org/abs/2602.00744">Tech Report</a>
</p>

## Model Details

This is the 1D Variational Autoencoder (VAE) used in [ACE-Step v1.5](https://github.com/ACE-Step/ACE-Step-1.5) for music generation. The weights are provided in a **[stable-audio-tools](https://github.com/Stability-AI/stable-audio-tools)**-compatible format, making it easy to load, fine-tune, and integrate into your own training pipelines.

- **Developed by:** [ACE-STEP](https://github.com/ACE-Step)
- **Model type:** Audio VAE (Oobleck Autoencoder)
- **License:** [MIT](https://opensource.org/licenses/MIT)

| Parameter | Value |
|-----------|-------|
| Architecture | Oobleck Autoencoder (VAE) |
| Audio Channels | 2 (Stereo) |
| Sampling Rate | 48,000 Hz |
| Latent Dim | 64 |
| Encoder Latent Dim | 128 |
| Downsampling Ratio | 1,920 |
| Encoder/Decoder Channels | 128 |
| Channel Multipliers | [1, 2, 4, 8, 16] |
| Strides | [2, 4, 4, 6, 10] |
| Activation | Snake |

## 🏗️ Architecture

The VAE is a core component of the ACE-Step v1.5 pipeline. It compresses raw stereo audio at 48 kHz into a compact latent representation with a 1920x downsampling ratio and a 64-dimensional latent space; the DiT then operates in this latent space to generate music.

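As a quick sanity check on these numbers, the downsampling ratio is the product of the encoder strides, and it fixes the latent frame rate. A dependency-free sketch (no model required):

```python
import math

SAMPLE_RATE = 48_000        # Hz, from the model config
STRIDES = [2, 4, 4, 6, 10]  # encoder strides from config.json

# The overall downsampling ratio is the product of the strides.
downsample = math.prod(STRIDES)
print(downsample)  # 1920

# Latent frame rate: 48000 / 1920 = 25 latent frames per second.
print(SAMPLE_RATE / downsample)  # 25.0

# A 10-second clip maps to ceil(480000 / 1920) = 250 frames of width 64.
print(math.ceil(10 * SAMPLE_RATE / downsample))  # 250
```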
## Quick Start

### Installation

```bash
pip install stable-audio-tools torchaudio
```

### Load and Use

```python
from stable_audio_vae import StableAudioVAE

# Load model
vae = StableAudioVAE(
    config_path="config.json",
    checkpoint_path="checkpoint.ckpt",
)
vae = vae.cuda().eval()

# Encode audio
wav = vae.load_wav("input.wav")
wav = wav.cuda()
latent = vae.encode(wav)
print(f"Latent shape: {latent.shape}")  # [batch, 64, time/1920]

# Decode back to audio
output = vae.decode(latent)
```

### Command Line

```bash
python stable_audio_vae.py -i input.wav -o output.wav

# For long audio, use chunked processing
python stable_audio_vae.py -i input.wav -o output.wav --chunked
```

## Fine-Tuning

This checkpoint is compatible with [stable-audio-tools](https://github.com/Stability-AI/stable-audio-tools) training pipelines. The `config.json` includes the full training configuration (optimizer, loss, and discriminator settings) that you can use as a starting point for fine-tuning.

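For example, the training block can be loaded and adjusted before launching a run. The sketch below inlines a small fragment of the shipped `config.json` (same keys and values); the lowered fine-tuning learning rate is purely an illustrative choice, not a recommendation from the authors:

```python
import json

# Inline fragment mirroring the "training" block of config.json.
fragment = """
{
  "training": {
    "learning_rate": 1.5e-4,
    "use_ema": true,
    "optimizer_configs": {
      "autoencoder": {
        "optimizer": {"type": "Muon", "config": {"lr": 1.5e-4, "weight_decay": 1e-3}}
      }
    }
  }
}
"""
config = json.loads(fragment)

# For fine-tuning, drop the LR below the from-scratch value (illustrative number).
config["training"]["learning_rate"] = 3e-5
config["training"]["optimizer_configs"]["autoencoder"]["optimizer"]["config"]["lr"] = 3e-5

print(config["training"]["learning_rate"])  # 3e-05
```

The edited dict can then be dumped back out with `json.dump` and passed to your training pipeline in place of the shipped file.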
## File Structure

```
.
├── config.json          # Model architecture and training config
├── checkpoint.ckpt      # Model weights (PyTorch checkpoint)
├── stable_audio_vae.py  # Inference script with StableAudioVAE wrapper
└── README.md
```

## 🦁 Related Models

| Model | Description | Hugging Face |
|-------|-------------|--------------|
| `acestep-v15-base` | DiT base model (CFG, 50 steps) | [Link](https://huggingface.co/ACE-Step/acestep-v15-base) |
| `acestep-v15-sft` | DiT SFT model (CFG, 50 steps) | [Link](https://huggingface.co/ACE-Step/acestep-v15-sft) |
| `acestep-v15-turbo` | DiT turbo model (8 steps) | [Link](https://huggingface.co/ACE-Step/Ace-Step1.5) |
| `acestep-v15-xl-base` | XL DiT base (4B, CFG, 50 steps) | [Link](https://huggingface.co/ACE-Step/acestep-v15-xl-base) |
| `acestep-v15-xl-sft` | XL DiT SFT (4B, CFG, 50 steps) | [Link](https://huggingface.co/ACE-Step/acestep-v15-xl-sft) |
| `acestep-v15-xl-turbo` | XL DiT turbo (4B, 8 steps) | [Link](https://huggingface.co/ACE-Step/acestep-v15-xl-turbo) |

## 🙏 Acknowledgements

This project is co-led by ACE Studio and StepFun.

## 📖 Citation

If you find this project useful for your research, please consider citing:

```bibtex
@misc{gong2026acestep,
  title={ACE-Step 1.5: Pushing the Boundaries of Open-Source Music Generation},
  author={Junmin Gong and Yulin Song and Wenxiao Zhao and Sen Wang and Shengyuan Xu and Jing Guo},
  howpublished={\url{https://github.com/ace-step/ACE-Step-1.5}},
  year={2026},
  note={GitHub repository}
}
```
checkpoint.ckpt ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:1575959a062145b8a36e4db420431d38748c82c7ba53ebe6742b073b9abf58b5
size 674902910
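Note that this is a Git LFS pointer, not the weights themselves (`git lfs pull` fetches the real checkpoint). The pointer's key-value fields are trivially machine-readable; a small sketch using the pointer text verbatim:

```python
# Parse the Git LFS pointer shown above into a dict of its fields.
pointer = """\
version https://git-lfs.github.com/spec/v1
oid sha256:1575959a062145b8a36e4db420431d38748c82c7ba53ebe6742b073b9abf58b5
size 674902910
"""

fields = dict(line.split(" ", 1) for line in pointer.strip().splitlines())
size_mb = int(fields["size"]) / 1024 ** 2

print(fields["version"])   # https://git-lfs.github.com/spec/v1
print(round(size_mb, 1))   # 643.6 -> the checkpoint is ~644 MB on disk
```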
config.json ADDED
{
  "model_type": "autoencoder",
  "sample_size": 122880,
  "sample_rate": 48000,
  "audio_channels": 2,
  "model": {
    "encoder": {
      "type": "oobleck",
      "config": {
        "in_channels": 2,
        "channels": 128,
        "c_mults": [1, 2, 4, 8, 16],
        "strides": [2, 4, 4, 6, 10],
        "latent_dim": 128,
        "use_snake": true
      }
    },
    "decoder": {
      "type": "oobleck",
      "config": {
        "out_channels": 2,
        "channels": 128,
        "c_mults": [1, 2, 4, 8, 16],
        "strides": [2, 4, 4, 6, 10],
        "latent_dim": 64,
        "use_snake": true,
        "final_tanh": false
      }
    },
    "bottleneck": {
      "type": "vae"
    },
    "latent_dim": 64,
    "downsampling_ratio": 1920,
    "io_channels": 2
  },
  "training": {
    "learning_rate": 1.5e-4,
    "warmup_steps": 0,
    "encoder_freeze_on_warmup": true,
    "use_ema": true,
    "optimizer_configs": {
      "autoencoder": {
        "optimizer": {
          "type": "Muon",
          "config": {
            "betas": [0.8, 0.99],
            "lr": 1.5e-4,
            "weight_decay": 1e-3
          }
        },
        "scheduler": {
          "type": "InverseLR",
          "config": {
            "inv_gamma": 200000,
            "power": 0.5,
            "warmup": 0.999
          }
        }
      },
      "discriminator": {
        "optimizer": {
          "type": "Muon",
          "config": {
            "betas": [0.8, 0.99],
            "lr": 3e-4,
            "weight_decay": 1e-3
          }
        },
        "scheduler": {
          "type": "InverseLR",
          "config": {
            "inv_gamma": 200000,
            "power": 0.5,
            "warmup": 0.999
          }
        }
      }
    },
    "loss_configs": {
      "discriminator": {
        "type": "encodec",
        "config": {
          "filters": 64,
          "n_ffts": [2048, 1024, 512, 256, 128],
          "hop_lengths": [512, 256, 128, 64, 32],
          "win_lengths": [2048, 1024, 512, 256, 128]
        },
        "weights": {
          "adversarial": 0.5,
          "feature_matching": 5.0
        }
      },
      "spectral": {
        "type": "mrstft",
        "config": {
          "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32],
          "hop_sizes": [512, 256, 128, 64, 32, 16, 8],
          "win_lengths": [2048, 1024, 512, 256, 128, 64, 32],
          "perceptual_weighting": true
        },
        "weights": {
          "mrstft": 1.0
        }
      },
      "time": {
        "type": "l1",
        "weights": {
          "l1": 0.0
        }
      },
      "bottleneck": {
        "type": "kl",
        "weights": {
          "kl": 0
        }
      }
    },
    "demo": {
      "demo_every": 2000
    }
  }
}
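One detail of this config worth calling out: the encoder's `latent_dim` is 128 while the model-level `latent_dim` is 64. With a `vae` bottleneck, stable-audio-tools splits the encoder output into mean and scale halves, so the encoder emits twice the final latent width (that split is my reading of the library's VAE bottleneck, not something stated in this repo). A sketch over an inline fragment of the config:

```python
import json
import math

# Inline fragment with the relevant keys from config.json above.
fragment = json.loads("""
{
  "encoder": {"config": {"strides": [2, 4, 4, 6, 10], "latent_dim": 128}},
  "bottleneck": {"type": "vae"},
  "latent_dim": 64,
  "downsampling_ratio": 1920
}
""")

# The stated downsampling ratio equals the product of the encoder strides.
assert math.prod(fragment["encoder"]["config"]["strides"]) == fragment["downsampling_ratio"]

# With a VAE bottleneck, the encoder emits mean and scale: 2x the latent width.
assert fragment["encoder"]["config"]["latent_dim"] == 2 * fragment["latent_dim"]
print("config is internally consistent")
```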
stable_audio_vae.py ADDED
import os
import json
import torch
import torch.nn as nn
from torch.nn.utils import remove_weight_norm, weight_norm
import torchaudio

from stable_audio_tools.models.autoencoders import create_autoencoder_from_config


DEFAULT_ROOT = "./"
DEFAULT_CONFIG_PATH = os.path.join(DEFAULT_ROOT, "config.json")
DEFAULT_CHECKPOINT_PATH = os.path.join(DEFAULT_ROOT, "checkpoint.ckpt")


def remove_weight_norm_(module):
    """Recursively remove weight normalization from all submodules."""
    for name, child in module.named_children():
        if hasattr(child, "weight"):
            try:
                remove_weight_norm(child)
            except ValueError:
                pass
        remove_weight_norm_(child)


def add_weight_norm_(module):
    """Recursively add weight normalization to all submodules."""
    for name, child in module.named_children():
        if hasattr(child, "weight"):
            weight_norm(child)
        add_weight_norm_(child)


def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
    """Resample, pad/crop, and set audio channels."""
    audio = audio.to(device)

    if in_sr != target_sr:
        audio = torchaudio.functional.resample(
            audio, orig_freq=in_sr, new_freq=target_sr
        )
    if target_length is None:
        target_length = audio.shape[-1]
    audio = PadCrop(target_length, randomize=False)(audio)

    if audio.dim() == 1:
        audio = audio.unsqueeze(0).unsqueeze(0)
    elif audio.dim() == 2:
        audio = audio.unsqueeze(0)

    audio = set_audio_channels(audio, target_channels)
    return audio

class PadCrop(torch.nn.Module):
    def __init__(self, n_samples, randomize=True):
        super().__init__()
        self.n_samples = n_samples
        self.randomize = randomize

    def __call__(self, signal):
        n, s = signal.shape
        start = 0 if (not self.randomize) else torch.randint(
            0, max(0, s - self.n_samples) + 1, []
        ).item()
        end = start + self.n_samples
        output = signal.new_zeros([n, self.n_samples])
        output[:, :min(s, self.n_samples)] = signal[:, start:end]
        return output


def set_audio_channels(audio, target_channels):
    if target_channels == 1:
        audio = audio.mean(1, keepdim=True)
    elif target_channels == 2:
        if audio.shape[1] == 1:
            audio = audio.repeat(1, 2, 1)
        elif audio.shape[1] > 2:
            audio = audio[:, :2, :]
    return audio


class StableAudioVAE(nn.Module):
    def __init__(
        self,
        sampling_rate=48000,
        config_path=DEFAULT_CONFIG_PATH,
        checkpoint_path=DEFAULT_CHECKPOINT_PATH,
        scale_factor=1.0,
        shift_factor=0.0,
        remove_norm=False,
        overlap=32,
        chunk_size=128,
    ):
        super().__init__()
        with open(config_path, "r") as f:
            self.config = json.load(f)
        self.vae = create_autoencoder_from_config(self.config)

        # Load checkpoint - support both .ckpt (PyTorch) and .safetensors
        if checkpoint_path.endswith(".safetensors"):
            from safetensors.torch import load_file
            checkpoints = load_file(checkpoint_path)
        else:
            checkpoints = torch.load(
                checkpoint_path, map_location=torch.device("cpu")
            )
            if "state_dict" in checkpoints:
                checkpoints = checkpoints["state_dict"]

        # Strip "autoencoder." prefix if present
        has_autoencoder = any(
            k.startswith("autoencoder.") for k in checkpoints.keys()
        )
        if has_autoencoder:
            checkpoints = {
                k.replace("autoencoder.", ""): v
                for k, v in checkpoints.items()
                if k.startswith("autoencoder.")
            }
        self.vae.load_state_dict(checkpoints)

        if remove_norm:
            remove_weight_norm_(self.vae)

        self.scale_factor = scale_factor
        self.shift_factor = shift_factor
        self.sampling_rate = sampling_rate
        self.io_channels = self.config["audio_channels"]
        self.overlap = overlap
        self.chunk_size = chunk_size
        self.downsampling_ratio = self.vae.downsampling_ratio
        self.latent_dim = self.vae.latent_dim

    def load_wav(self, path):
        wav, sr = torchaudio.load(path)
        wav = prepare_audio(
            wav,
            in_sr=sr,
            target_sr=self.sampling_rate,
            target_length=None,
            target_channels=self.io_channels,
            device="cpu",
        )
        return wav

    @torch.no_grad()
    def encode(self, wav, chunked=False):
        # Compare the time axis (not the channel axis) against the chunk size.
        if wav.shape[-1] <= self.chunk_size * self.vae.downsampling_ratio:
            chunked = False
        latent = self.vae.encode_audio(wav, chunked=chunked)
        latent = self.scale_factor * (latent - self.shift_factor)
        return latent

    @torch.no_grad()
    def decode(self, z, chunked=False):
        z = z / self.scale_factor + self.shift_factor
        if z.shape[-1] <= self.chunk_size:
            chunked = False
        output = self.vae.decode_audio(z, chunked=chunked)
        return output

    @torch.no_grad()
    def forward(self, wav, chunked=False):
        """Encode and decode audio (reconstruction)."""
        latent = self.vae.encode_audio(wav, chunked=chunked)
        latent = self.scale_factor * (latent - self.shift_factor)
        latent = latent / self.scale_factor + self.shift_factor
        output = self.vae.decode_audio(latent, chunked=chunked)
        return output


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Encode and decode audio with StableAudioVAE")
    parser.add_argument("-m", "--model", type=str, default=DEFAULT_CHECKPOINT_PATH, help="path to checkpoint")
    parser.add_argument("-c", "--config", type=str, default=DEFAULT_CONFIG_PATH, help="path to config.json")
    parser.add_argument("-i", "--input", type=str, required=True, help="input audio file")
    parser.add_argument("-o", "--output", type=str, required=True, help="output audio file")
    parser.add_argument("-sr", "--sampling_rate", type=int, default=48000, help="sampling rate")
    parser.add_argument("--chunked", action="store_true", help="use chunked processing for long audio")
    args = parser.parse_args()

    pipeline = StableAudioVAE(
        sampling_rate=args.sampling_rate,
        config_path=args.config,
        checkpoint_path=args.model,
    )
    pipeline = pipeline.cuda()

    wav = pipeline.load_wav(args.input)
    wav = wav.cuda()
    print(f"Input shape: {wav.shape}")

    z = pipeline.encode(wav, chunked=args.chunked)
    print(f"Latent shape: {z.shape}")

    output = pipeline.decode(z, chunked=args.chunked)
    print(f"Output shape: {output.shape}")

    output = output[0].cpu()
    torchaudio.save(args.output, output, pipeline.sampling_rate)
    print(f"Saved to {args.output}")
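Two numeric details of the wrapper above are easy to check without loading the model: `encode` maps latents through `z = scale_factor * (x - shift_factor)` and `decode` applies the exact inverse, and chunked processing only engages once the input exceeds `chunk_size * downsampling_ratio` samples. A dependency-free sketch (the 0.5 / 0.1 factors are made-up illustrative values; the shipped defaults are 1.0 / 0.0):

```python
# Affine latent scaling used by encode()/decode(); values here are illustrative.
scale_factor, shift_factor = 0.5, 0.1

def encode_scale(x):
    return scale_factor * (x - shift_factor)

def decode_scale(z):
    return z / scale_factor + shift_factor

x = 2.0
assert abs(decode_scale(encode_scale(x)) - x) < 1e-12  # exact round trip

# encode() falls back to non-chunked mode at or below this many samples:
threshold = 128 * 1920       # chunk_size * downsampling_ratio = 245760
print(threshold / 48_000)    # 5.12 seconds of 48 kHz audio
```

So the `--chunked` flag only changes behavior for inputs longer than about 5 seconds; shorter clips are always processed in one pass.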