primepake committed on
Commit
997d9c0
·
1 Parent(s): 8387742

release DAC-VAE continuous latent space

Browse files
Files changed (4) hide show
  1. README.md +2 -1
  2. dac-vae/config.yml +128 -0
  3. dac-vae/inference.py +8 -5
  4. dac-vae/model.py +1 -0
README.md CHANGED
@@ -66,7 +66,8 @@ pip install -r requirements.txt
66
 
67
  2. **Extracting DAC-VAE latent**
68
  ```bash
69
- python inference.py
 
70
  ```
71
 
72
  3. **Stage 1: Auto Regressive Transformer**
 
66
 
67
  2. **Extracting DAC-VAE latent**
68
  ```bash
69
+ cd dac-vae
70
+ python inference.py --checkpoint checkpoint.pt --config config.yml
71
  ```
72
 
73
  3. **Stage 1: Auto Regressive Transformer**
dac-vae/config.yml ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Model setup
2
+ vae:
3
+ sample_rate: 24000
4
+ encoder_dim: 64
5
+ latent_dim: 64
6
+ encoder_rates: [2, 4, 5, 8]
7
+ decoder_dim: 1536
8
+ decoder_rates: [8, 5, 4, 2]
9
+ d_in: 1
10
+ d_out: 1
11
+ weight_init: xavier
12
+ activation: snake
13
+ gain: 1.0
14
+
15
+ discriminator:
16
+ sample_rate: 24000
17
+ d_in: 1
18
+ rates: []
19
+ periods: [2, 3, 5, 7, 11]
20
+ fft_sizes: [2048, 1024, 512]
21
+ bands:
22
+ - [0.0, 0.1]
23
+ - [0.1, 0.25]
24
+ - [0.25, 0.5]
25
+ - [0.5, 0.75]
26
+ - [0.75, 1.0]
27
+
28
+
29
+ max_norm: 1000
30
+ max_norm_d: 10
31
+ initial_norm: 1000
32
+ initial_norm_d: 10
33
+
34
+ amp: false
35
+ batch_size: 64
36
+ val_batch_size: 4
37
+ num_workers: 0
38
+ device: cuda
39
+ num_samples: 530000
40
+ gan_start_step: 0
41
+ num_iters: 500000
42
+ save_iters: 1000
43
+ valid_freq: 1000
44
+ sample_freq: 2000
45
+ val_idx: [0, 1, 2, 3, 4, 5, 6, 7]
46
+ seed: 0
47
+ lambdas:
48
+ mel/loss: 15.0
49
+ adv/feat_loss: 2.0
50
+ adv/gen_loss: 1.0
51
+ kl/loss: 0.1
52
+ stft/loss: 0.0
53
+ waveform/loss: 0.0
54
+ logs_penalty: 0.0 #0.02
55
+ grad_penalty: 0.0 #1.0
56
+ lipschitz_penalty: 0.0 #0.001
57
+
58
+ VolumeNorm.db: [lufs, -18]
59
+
60
+ # Transforms
61
+ build_transform.preprocess:
62
+ - Identity
63
+ build_transform.augment_prob: 0.0
64
+ build_transform.augment:
65
+ - Identity
66
+ build_transform.postprocess:
67
+ - Identity
68
+ - Identity
69
+ - Identity
70
+
71
+ # Loss setup
72
+ MultiScaleSTFTLoss:
73
+ window_lengths: [1024, 2048]
74
+
75
+ MelSpectrogramLoss:
76
+ n_mels: [5, 10, 20, 40, 80, 160, 320]
77
+ window_lengths: [32, 64, 128, 256, 512, 1024, 2048]
78
+ mel_fmin: [0, 0, 0, 0, 0, 0, 0]
79
+ mel_fmax: [null, null, null, null, null, null, null]
80
+ pow: 1.0
81
+ clamp_eps: 1.0e-5
82
+ mag_weight: 0.0
83
+
84
+ # optimizer
85
+ optimizer:
86
+ type: Adamw
87
+ weight_decay: 0.001
88
+ lr: 0.0001
89
+ scheduler: linearlr # or constantlr
90
+ warmup_steps: 500
91
+
92
+ disc_optimizer:
93
+ type: Adamw
94
+ weight_decay: 0.001
95
+ lr: 0.0001
96
+ scheduler: linearlr # or constantlr
97
+ warmup_steps: 500
98
+
99
+ # Data
100
+ train:
101
+ duration: 0.38
102
+ n_examples: 10000000
103
+ without_replacement: true
104
+ shuffle_loaders: true
105
+
106
+ val:
107
+ duration: 5.0
108
+ n_examples: 100
109
+ without_replacement: true
110
+ shuffle_loaders: false
111
+
112
+ test:
113
+ duration: 10.0
114
+ n_examples: 1000
115
+ without_replacement: true
116
+ shuffle_loaders: false
117
+
118
+ train_folders:
119
+ Emilia_EN:
120
+ - /home/masuser/minimax-audio/dataset/Emilia/EN
121
+
122
+ val_folders:
123
+ Emilia_EN:
124
+ - /home/masuser/minimax-audio/dataset/libritts
125
+
126
+ test_folders:
127
+ Emilia_EN:
128
+ - /home/masuser/minimax-audio/dataset/libritts
dac-vae/inference.py CHANGED
@@ -137,6 +137,9 @@ class DACVAEInference:
137
 
138
  # Forward pass through model
139
  print("Processing through DACVAE...")
 
 
 
140
  out = self.model(audio_tensor, self.sample_rate)
141
 
142
  # Extract outputs
@@ -146,7 +149,7 @@ class DACVAEInference:
146
  z = out['z']
147
  mu = out['mu']
148
  logs = out['logs']
149
-
150
  # Clamp output
151
  recons_audio = np.clip(recons_audio, -1.0, 1.0)
152
 
@@ -167,13 +170,13 @@ class DACVAEInference:
167
 
168
  def main():
169
  parser = argparse.ArgumentParser(description="DACVAE Audio Inference")
170
- parser.add_argument('--checkpoint', type=str, required=True,
171
  help='Path to model checkpoint')
172
- parser.add_argument('--config', type=str, default=None,
173
  help='Path to config YAML (optional if config is in checkpoint)')
174
- parser.add_argument('--input', type=str, required=True,
175
  help='Path to input audio file')
176
- parser.add_argument('--output', type=str, default=None,
177
  help='Path to save output audio (default: input_reconstructed.wav)')
178
  parser.add_argument('--device', type=str, default='cuda',
179
  choices=['cuda', 'cpu'], help='Device to run on')
 
137
 
138
  # Forward pass through model
139
  print("Processing through DACVAE...")
140
+ audio_tensor = audio_tensor[:, :, :9120]
141
+
142
+ print('audio_tensor shape: ', audio_tensor.shape)
143
  out = self.model(audio_tensor, self.sample_rate)
144
 
145
  # Extract outputs
 
149
  z = out['z']
150
  mu = out['mu']
151
  logs = out['logs']
152
+ print('z shape: ', z.shape)
153
  # Clamp output
154
  recons_audio = np.clip(recons_audio, -1.0, 1.0)
155
 
 
170
 
171
  def main():
172
  parser = argparse.ArgumentParser(description="DACVAE Audio Inference")
173
+ parser.add_argument('--checkpoint', type=str, required=False, default="/mnt/nvme/ckpts/24khz/364k_20250702_043748/checkpoint.pt",
174
  help='Path to model checkpoint')
175
+ parser.add_argument('--config', type=str, default="./config.yml",
176
  help='Path to config YAML (optional if config is in checkpoint)')
177
+ parser.add_argument('--input', type=str, required=False, default='./output.wav',
178
  help='Path to input audio file')
179
+ parser.add_argument('--output', type=str, default='./test.wav',
180
  help='Path to save output audio (default: input_reconstructed.wav)')
181
  parser.add_argument('--device', type=str, default='cuda',
182
  choices=['cuda', 'cpu'], help='Device to run on')
dac-vae/model.py CHANGED
@@ -474,6 +474,7 @@ class DACVAE(BaseModel, CodecMixin):
474
  x = self.encoder(audio_data)
475
  x = F.leaky_relu(x)
476
  x = self.en_conv_post(x)
 
477
  m, logs = torch.split(x, self.latent_dim, dim=1)
478
  logs = torch.clamp(logs, min=-14.0, max=14.0)
479
 
 
474
  x = self.encoder(audio_data)
475
  x = F.leaky_relu(x)
476
  x = self.en_conv_post(x)
477
+ print('x shape: ', x.shape)
478
  m, logs = torch.split(x, self.latent_dim, dim=1)
479
  logs = torch.clamp(logs, min=-14.0, max=14.0)
480