primepake committed on
Commit
2279ae0
·
1 Parent(s): 997d9c0

add reconstruction for audio

flowae/audio_dito_inference.py ADDED
@@ -0,0 +1,332 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchaudio
4
+ import numpy as np
5
+ from pathlib import Path
6
+ import argparse
7
+ import soundfile as sf
8
+ from omegaconf import OmegaConf
9
+ import matplotlib.pyplot as plt
10
+
11
+ # Import models
12
+ import models
13
+ from models.ldm.dac.audiotools import AudioSignal
14
+
15
+
16
+ class AudioDiToInference:
17
+ def __init__(self, checkpoint_path, device='cuda'):
18
+ """Initialize Audio DiTo model from checkpoint"""
19
+ self.device = device
20
+
21
+ # Load checkpoint
22
+ print(f"Loading checkpoint from {checkpoint_path}")
23
+ ckpt = torch.load(checkpoint_path, map_location='cpu')
24
+
25
+ # Extract config
26
+ self.config = OmegaConf.create(ckpt['config'])
27
+
28
+ # Create model
29
+ self.model = models.make(self.config['model'])
30
+
31
+ # Load state dict
32
+ self.model.load_state_dict(ckpt['model']['sd'])
33
+
34
+ # Move to device and set to eval
35
+ self.model = self.model.to(device)
36
+ self.model.eval()
37
+
38
+ # Get audio parameters from config
39
+ self.sample_rate = self.config.get('sample_rate', 24000)
40
+ self.mono = self.config.get('mono', True)
41
+
42
+ print(f"Model loaded successfully!")
43
+ print(f"Sample rate: {self.sample_rate} Hz")
44
+ print(f"Mono: {self.mono}")
45
+
46
+ def load_audio(self, audio_path, duration=None, offset=0.0):
47
+ """Load audio file using AudioSignal
48
+
49
+ Args:
50
+ audio_path: Path to audio file
51
+ duration: Duration in seconds (None for full audio)
52
+ offset: Start offset in seconds
53
+ """
54
+ # Load audio using AudioSignal
55
+ if duration is not None:
56
+ signal = AudioSignal(
57
+ str(audio_path),
58
+ duration=duration,
59
+ offset=offset,
60
+ )
61
+ else:
62
+ # Load full audio
63
+ signal = AudioSignal(str(audio_path))
64
+
65
+ # Convert to mono if needed
66
+ if self.mono and signal.num_channels > 1:
67
+ signal = signal.to_mono()
68
+
69
+ # Resample to model sample rate
70
+ if signal.sample_rate != self.sample_rate:
71
+ signal = signal.resample(self.sample_rate)
72
+
73
+ # Normalize
74
+ signal = signal.normalize()
75
+
76
+ # Clamp to [-1, 1]
77
+ signal.audio_data = signal.audio_data.clamp(-1.0, 1.0)
78
+
79
+ return signal
80
+
81
+ def save_audio(self, reconstructed, output_path):
82
+ """Save a reconstructed waveform (numpy array) to file"""
83
+ # Get audio data
84
+ print('shape of reconstructed: ', reconstructed.shape)
85
+ sf.write(output_path, reconstructed, self.sample_rate)
86
+ print(f"Saved audio to {output_path}")
87
+
88
+ def reconstruct_audio(self, audio_path, num_steps=50, save_latent=False):
89
+ """Reconstruct entire audio file at once
90
+
91
+ Args:
92
+ audio_path: Path to audio file
93
+ num_steps: Number of diffusion steps
94
+ save_latent: Whether to return the latent representation
95
+ """
96
+ # Load full audio without duration limit
97
+ signal = self.load_audio(audio_path, duration=None, offset=0.0)
98
+
99
+ # Get audio tensor
100
+ audio_tensor = signal.audio_data # [batch, channels, samples]
101
+ if audio_tensor.dim() == 2:
102
+ audio_tensor = audio_tensor.squeeze(0) # [samples] for mono
103
+
104
+ # Move to device (AudioSignal already provides the leading batch dimension)
105
+ audio_tensor = audio_tensor.to(self.device) # [batch, channels, samples]
106
+
107
+ print(f"Input shape: {audio_tensor.shape}")
108
+ print(f"Full audio duration: {audio_tensor.shape[-1] / self.sample_rate:.2f}s")
109
+
110
+ with torch.no_grad():
111
+ # Prepare data dict
112
+ data = {'inp': audio_tensor}
113
+
114
+ # Step 1: Encode to latent
115
+ print('shape of audio_tensor: ', audio_tensor.shape)
116
+ z = self.model.encode(audio_tensor)
117
+ print(f"Latent shape: {z.shape}")
118
+
119
+ # Step 2: Decode latent (if model has separate decode step)
120
+ if hasattr(self.model, 'decode'):
121
+ z_dec = self.model.decode(z)
122
+ else:
123
+ z_dec = z
124
+ print(f"Decoded latent shape: {z_dec.shape}")
125
+
126
+ # Step 3: Prepare dummy coordinates (based on training code)
127
+ b, *_ = audio_tensor.shape
128
+
129
+
130
+ # Step 4: Render using diffusion
131
+ if hasattr(self.model, 'render'):
132
+ # Render expects z_dec, coord, scale
133
+ print('using render diffusion model')
134
+ reconstructed = self.model.render(z_dec)
135
+ else:
136
+ # Alternative: direct decode if render not available
137
+ reconstructed = self.model(data, mode='pred')
138
+
139
+ # Remove batch dimension
140
+ reconstructed = reconstructed.squeeze(0).squeeze(0).cpu().numpy() # [samples]
141
+
142
+ print('shape of reconstructed: ', reconstructed.shape)
143
+
144
+
145
+ if save_latent:
146
+ return reconstructed, z.cpu()
147
+ else:
148
+ return reconstructed
149
+
150
+ def save_reconstruction(self, audio_path, output_path, num_steps=50):
151
+ """Reconstruct and save entire audio file"""
152
+ reconstructed = self.reconstruct_audio(audio_path, num_steps)
153
+ self.save_audio(reconstructed, output_path)
154
+
155
+ def compare_reconstruction(self, audio_path, output_path, num_steps=50):
156
+ """Save original and reconstruction concatenated"""
157
+ # Load original full audio
158
+ original = self.load_audio(audio_path, duration=None, offset=0.0)
159
+
160
+ # Get reconstruction of full audio
161
+ reconstructed = self.reconstruct_audio(audio_path, num_steps)
162
+
163
+ # Add 0.5 second silence between clips
164
+ silence_samples = int(0.5 * self.sample_rate)
165
+ silence_data = torch.zeros(silence_samples)
+
+ # Concatenate: original -> silence -> reconstruction (all flattened to 1-D)
+ original_data = original.audio_data.reshape(-1).cpu()
+ reconstructed_data = torch.from_numpy(reconstructed).reshape(-1).float()
+ concat_data = torch.cat([original_data, silence_data, reconstructed_data], dim=0)
+
+ self.save_audio(concat_data.numpy(), output_path)
181
+ print(f"Saved comparison (original + reconstruction) to {output_path}")
182
+
183
+ def visualize_latent(self, audio_path, output_path):
184
+ """Visualize the latent representation of full audio"""
185
+ # Get latent
186
+ _, z = self.reconstruct_audio(audio_path, save_latent=True)
187
+
188
+ z_np = z.squeeze(0).numpy() # Remove batch dimension
189
+
190
+ # Create visualization
191
+ if z_np.ndim == 2: # [channels, frames]
192
+ n_channels = z_np.shape[0]
193
+ fig, axes = plt.subplots(n_channels, 1, figsize=(12, 2*n_channels))
194
+
195
+ if n_channels == 1:
196
+ axes = [axes]
197
+
198
+ for i in range(n_channels):
199
+ im = axes[i].imshow(
200
+ z_np[i:i+1],
201
+ aspect='auto',
202
+ cmap='coolwarm',
203
+ interpolation='nearest'
204
+ )
205
+ axes[i].set_title(f'Latent Channel {i+1}')
206
+ axes[i].set_xlabel('Time Frames')
207
+ axes[i].set_ylabel('Feature')
208
+ plt.colorbar(im, ax=axes[i])
209
+ else: # 1D latent
210
+ plt.figure(figsize=(12, 4))
211
+ plt.plot(z_np.T)
212
+ plt.title('Latent Representation')
213
+ plt.xlabel('Time Frames')
214
+ plt.ylabel('Value')
215
+
216
+ plt.tight_layout()
217
+ plt.savefig(output_path, dpi=150)
218
+ plt.close()
219
+
220
+ print(f"Saved latent visualization to {output_path}")
221
+
222
+ def batch_reconstruct(self, audio_folder, output_folder, max_files=None, num_steps=50):
223
+ """Reconstruct all audio files in a folder (full audio)"""
224
+ audio_folder = Path(audio_folder)
225
+ output_folder = Path(output_folder)
226
+ output_folder.mkdir(exist_ok=True, parents=True)
227
+
228
+ # Get all audio files
229
+ audio_extensions = ['.wav', '.mp3', '.flac', '.m4a', '.ogg']
230
+ audio_paths = []
231
+ for ext in audio_extensions:
232
+ audio_paths.extend(audio_folder.glob(f'*{ext}'))
233
+ audio_paths.extend(audio_folder.glob(f'*{ext.upper()}'))
234
+
235
+ if max_files:
236
+ audio_paths = audio_paths[:max_files]
237
+
238
+ print(f"Processing {len(audio_paths)} audio files...")
239
+
240
+ for audio_path in audio_paths:
241
+ output_path = output_folder / f"recon_{audio_path.stem}.wav"
242
+ try:
243
+ self.save_reconstruction(
244
+ str(audio_path), str(output_path),
245
+ num_steps=num_steps
246
+ )
247
+ except Exception as e:
248
+ print(f"Error processing {audio_path}: {e}")
249
+ continue
250
+
251
+ print("Batch reconstruction complete!")
252
+
253
+
254
+ def main():
255
+ parser = argparse.ArgumentParser(description='Audio DiTo Inference')
256
+ parser.add_argument('--checkpoint', type=str, required=True,
257
+ help='Path to Audio DiTo checkpoint')
258
+ parser.add_argument('--input', type=str, required=True,
259
+ help='Input audio path or folder')
260
+ parser.add_argument('--output', type=str, required=True,
261
+ help='Output path')
262
+ parser.add_argument('--compare', action='store_true',
263
+ help='Save comparison with original')
264
+ parser.add_argument('--batch', action='store_true',
265
+ help='Process entire folder')
266
+ parser.add_argument('--visualize', action='store_true',
267
+ help='Visualize latent representation')
268
+ parser.add_argument('--steps', type=int, default=50,
269
+ help='Number of diffusion steps')
270
+ parser.add_argument('--device', type=str, default='cuda',
271
+ help='Device to use (cuda/cpu)')
272
+ parser.add_argument('--max-files', type=int, default=None,
273
+ help='Maximum files to process in batch mode')
274
+
275
+ args = parser.parse_args()
276
+
277
+ # Initialize model
278
+ audio_dito = AudioDiToInference(args.checkpoint, device=args.device)
279
+
280
+ # Process based on mode
281
+ if args.batch:
282
+ # Batch processing
283
+ audio_dito.batch_reconstruct(
284
+ args.input, args.output,
285
+ max_files=args.max_files,
286
+ num_steps=args.steps
287
+ )
288
+ elif args.visualize:
289
+ # Visualize latent
290
+ audio_dito.visualize_latent(
291
+ args.input, args.output
292
+ )
293
+ elif args.compare:
294
+ # Save comparison
295
+ audio_dito.compare_reconstruction(
296
+ args.input, args.output,
297
+ num_steps=args.steps
298
+ )
299
+ else:
300
+ # Single reconstruction
301
+ audio_dito.save_reconstruction(
302
+ args.input, args.output,
303
+ num_steps=args.steps
304
+ )
305
+
306
+
307
+ # Example usage function for direct Python use
308
+ def reconstruct_single_audio(checkpoint_path, audio_path, output_path):
309
+ """Simple function to reconstruct a single audio file"""
310
+ audio_dito = AudioDiToInference(checkpoint_path)
311
+ audio_dito.save_reconstruction(audio_path, output_path)
312
+
313
+
314
+ if __name__ == "__main__":
315
+ main()
316
+
317
+
318
+ # Usage examples:
319
+ # 1. Single audio reconstruction (full audio):
320
+ # python audio_dito_inference.py --checkpoint ckpt-best.pth --input audio.wav --output recon.wav
321
+ #
322
+ # 2. Save comparison (original + reconstruction):
323
+ # python audio_dito_inference.py --checkpoint ckpt-best.pth --input audio.wav --output compare.wav --compare
324
+ #
325
+ # 3. Batch processing (reconstruct all audio files in folder):
326
+ # python audio_dito_inference.py --checkpoint ckpt-best.pth --input audio_folder/ --output output_folder/ --batch
327
+ #
328
+ # 4. Visualize latent representation:
329
+ # python audio_dito_inference.py --checkpoint ckpt-best.pth --input audio.wav --output latent.png --visualize
330
+ #
331
+ # 5. Use fewer diffusion steps for faster inference:
332
+ # python audio_dito_inference.py --checkpoint ckpt-best.pth --input audio.wav --output recon.wav --steps 25
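A minimal programmatic sketch of the same workflow, for use from Python instead of the CLI (paths are placeholders; the import assumes the script is run from the flowae/ directory):

    from audio_dito_inference import AudioDiToInference

    # Load the trained model once and reuse it (checkpoint path is a placeholder)
    dito = AudioDiToInference("ckpt-best.pth", device="cuda")

    # Reconstruct a file and keep the latent for inspection
    recon, z = dito.reconstruct_audio("audio.wav", num_steps=25, save_latent=True)
    print(recon.shape, z.shape)   # 1-D waveform array, latent [1, z_channels, n_frames]

    # Write the reconstruction and an original-vs-reconstruction comparison
    dito.save_audio(recon, "recon.wav")
    dito.compare_reconstruction("audio.wav", "compare.wav", num_steps=25)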
flowae/configs/datasets/dae.yaml CHANGED
@@ -4,22 +4,19 @@ datasets:
4
  name: wrapper_audio_cae
5
  args:
6
  dataset:
7
- name: audio_dataset_from_folders
8
  args:
9
- folders:
10
- Emilia_EN: ["/home/masuser/minimax-audio/dataset/Emilia/EN"]
11
  sample_rate: 24000
12
  duration: 0.38
13
- n_examples: 10000000
14
  shuffle: true
15
- mono: true
16
  sample_rate: 24000
17
  duration: 0.38
18
  mono: true
19
  normalize: true
20
- return_coords: true
21
  loader:
22
- batch_size: 64
23
  num_workers: 8
24
  drop_last: true
25
 
@@ -27,20 +24,17 @@ datasets:
27
  name: wrapper_audio_cae
28
  args:
29
  dataset:
30
- name: audio_dataset_from_folders
31
  args:
32
- folders:
33
- Emilia_EN: ["/home/masuser/minimax-audio/dataset/libritts"]
34
  sample_rate: 24000
35
  duration: 5.0
36
- n_examples: 100
37
  shuffle: false
38
- mono: true
39
  sample_rate: 24000
40
  duration: 5.0
41
  mono: true
42
  normalize: true
43
- return_coords: true
44
  loader:
45
  batch_size: 4
46
  num_workers: 8
@@ -50,20 +44,17 @@ datasets:
50
  name: wrapper_audio_cae
51
  args:
52
  dataset:
53
- name: audio_dataset_from_folders
54
  args:
55
- folders:
56
- Emilia_EN: ["/home/masuser/minimax-audio/dataset/libritts"]
57
  sample_rate: 24000
58
- duration: 10.0
59
- n_examples: 1000
60
  shuffle: false
61
- mono: true
62
  sample_rate: 24000
63
- duration: 10.0
64
  mono: true
65
  normalize: true
66
- return_coords: true
67
  loader:
68
  batch_size: 1
69
  num_workers: 8
 
4
  name: wrapper_audio_cae
5
  args:
6
  dataset:
7
+ name: class_folder_audio
8
  args:
9
+ root_path: "/home/masuser/minimax-audio/dataset/Emilia/EN"
 
10
  sample_rate: 24000
11
  duration: 0.38
 
12
  shuffle: true
13
+ num_channels: 1
14
  sample_rate: 24000
15
  duration: 0.38
16
  mono: true
17
  normalize: true
 
18
  loader:
19
+ batch_size: 52
20
  num_workers: 8
21
  drop_last: true
22
 
 
24
  name: wrapper_audio_cae
25
  args:
26
  dataset:
27
+ name: class_folder_audio
28
  args:
29
+ root_path: "/home/masuser/minimax-audio/dataset/libritts"
 
30
  sample_rate: 24000
31
  duration: 5.0
 
32
  shuffle: false
33
+ num_channels: 1
34
  sample_rate: 24000
35
  duration: 5.0
36
  mono: true
37
  normalize: true
 
38
  loader:
39
  batch_size: 4
40
  num_workers: 8
 
44
  name: wrapper_audio_cae
45
  args:
46
  dataset:
47
+ name: class_folder_audio
48
  args:
49
+ root_path: "/home/masuser/minimax-audio/dataset/libritts"
 
50
  sample_rate: 24000
51
+ duration: 5.0
 
52
  shuffle: false
53
+ num_channels: 1
54
  sample_rate: 24000
55
+ duration: 5.0
56
  mono: true
57
  normalize: true
 
58
  loader:
59
  batch_size: 1
60
  num_workers: 8
flowae/configs/experiments/dito-B-audio.yaml CHANGED
@@ -8,12 +8,16 @@ model:
8
  # Encoder
9
  encoder:
10
  name: dac_encoder
11
- args: {config_name: snakebeta}
12
 
13
  # Latent configuration - now fully convolutional
14
  z_channels: 64 # Number of latent channels
15
- z_downsample_factor: 320 # Product of encoder_rates: 2*4*5*8
16
- z_layernorm: true
 
 
18
  # Decoder (identity for DiTo)
19
  decoder:
@@ -21,10 +25,10 @@ model:
21
 
22
  # Renderer - Fully convolutional for dynamic duration
23
  renderer:
24
- name: audio_renderer_wrapper
25
  args:
26
  net:
27
- name: consistency_decoder_unet # Fully Convolutional Network
28
  args:
29
  in_channels: 1
30
  z_dec_channels: 64
@@ -39,6 +43,6 @@ model:
39
  name: fm
40
  args: {timescale: 1000.0}
41
 
42
- render_sampler: {name: fm_euler_sampler}
43
  render_n_steps: 50
44
 
 
8
  # Encoder
9
  encoder:
10
  name: dac_encoder
11
+ args: {config_name: snake}
12
 
13
  # Latent configuration - now fully convolutional
14
  z_channels: 64 # Number of latent channels
15
+
16
+ zaug_p: 0.1
17
+ zaug_decoding_loss_type: suffix
18
+ zaug_zdm_diffusion:
19
+ name: fm
20
+ args: {timescale: 1000.0}
21
 
22
  # Decoder (identity for DiTo)
23
  decoder:
 
25
 
26
  # Renderer - Fully convolutional for dynamic duration
27
  renderer:
28
+ name: fixres_renderer_wrapper
29
  args:
30
  net:
31
+ name: audio_diffusion_unet
32
  args:
33
  in_channels: 1
34
  z_dec_channels: 64
 
43
  name: fm
44
  args: {timescale: 1000.0}
45
 
46
+ render_sampler: {name: fm_euler_sampler_audio}
47
  render_n_steps: 50
48
 
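A small sketch of how this experiment config is consumed, assuming the flowae model registry (models.make) and that the file is loaded with OmegaConf as elsewhere in the repo; the path is illustrative:

    from omegaconf import OmegaConf
    import models  # flowae model registry

    cfg = OmegaConf.load("configs/experiments/dito-B-audio.yaml")
    model = models.make(cfg["model"])   # builds dac_encoder, identity decoder, fixres_renderer_wrapper
    model.eval()

    # With encoder strides 2*4*5*8 = 320, one latent frame covers 320 waveform samples,
    # so a 0.38 s clip at 24 kHz (9120 samples) yields roughly 29 latent frames.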
flowae/datasets/__init__.py CHANGED
@@ -1,3 +1,3 @@
1
  from .datasets import register, make
2
- from . import image_folder, class_folder, webdataset
3
- from . import wrapper_cae
 
1
  from .datasets import register, make
2
+ from . import image_folder, class_folder, webdataset, class_folder_audio
3
+ from . import wrapper_cae, wrapper_audio_cae
flowae/datasets/class_folder.py CHANGED
@@ -6,6 +6,8 @@ from datasets import register
6
  from torch.utils.data import Dataset
7
  from torchvision import transforms
8
 
 
 
9
 
10
  Image.MAX_IMAGE_PIXELS = 933120000
11
  ImageFile.LOAD_TRUNCATED_IMAGES = True
 
6
  from torch.utils.data import Dataset
7
  from torchvision import transforms
8
 
9
+ import os
10
+ import random
11
 
12
  Image.MAX_IMAGE_PIXELS = 933120000
13
  ImageFile.LOAD_TRUNCATED_IMAGES = True
flowae/datasets/class_folder_audio.py ADDED
@@ -0,0 +1,196 @@
1
+ import os
2
+ import random
3
+ from PIL import Image, ImageFile
4
+
5
+ from datasets import register
6
+ from torch.utils.data import Dataset
7
+ from torchvision import transforms
8
+
9
+ import os
10
+ import random
11
+ from pathlib import Path
12
+ from typing import Optional, Callable
13
+
14
+ from models.ldm.dac.audiotools import AudioSignal
15
+ from models.ldm.dac.audiotools.core import util
16
+ # Audio file extensions (from audiotools)
17
+ AUDIO_EXTS = ('.wav', '.WAV', '.flac', '.FLAC', '.mp3', '.MP3', '.mp4', '.MP4', '.m4a', '.M4A')
18
+
19
+ @register('class_folder_audio')
20
+ class AudioFolder(Dataset):
21
+ """
22
+ Audio dataset that loads audio files from a folder structure.
23
+ Similar to ClassFolder but for audio files.
24
+
25
+ Expected folder structure:
26
+ root_path/
27
+ ├── class1/
28
+ │ ├── audio1.wav
29
+ │ ├── audio2.wav
30
+ │ └── ...
31
+ ├── class2/
32
+ │ ├── audio1.wav
33
+ │ └── ...
34
+ └── ...
35
+
36
+ Or for single class (no subfolders):
37
+ root_path/
38
+ ├── audio1.wav
39
+ ├── audio2.wav
40
+ └── ...
41
+ """
42
+
43
+ def __init__(
44
+ self,
45
+ root_path: str,
46
+ sample_rate: int = 24000,
47
+ duration: float = 2.0,
48
+ num_channels: int = 1,
49
+ random_crop: bool = True,
50
+ loudness_cutoff: float = -40,
51
+ audio_only: bool = False,
52
+ drop_label_p: float = 0.0,
53
+ shuffle: bool = True,
54
+ shuffle_state: int = 0,
55
+ transform: Optional[Callable] = None,
56
+ normalize: bool = True,
57
+ trim_silence: bool = False,
58
+ ):
59
+ """
60
+ Args:
61
+ root_path: Path to audio files
62
+ sample_rate: Target sample rate for audio
63
+ duration: Duration in seconds for audio clips
64
+ num_channels: Number of channels (1 for mono, 2 for stereo)
65
+ random_crop: Whether to randomly crop audio (vs deterministic)
66
+ loudness_cutoff: Minimum loudness threshold for audio selection
67
+ audio_only: If True, return only audio signal. If False, return dict with labels
68
+ drop_label_p: Probability of dropping labels (for unconditional training)
69
+ shuffle: Whether to shuffle files
70
+ shuffle_state: Random state for shuffling
71
+ transform: Additional audio transforms
72
+ normalize: Whether to normalize audio amplitude
73
+ trim_silence: Whether to trim silence from audio
74
+ """
75
+ self.root_path = root_path
76
+ self.sample_rate = sample_rate
77
+ self.duration = duration
78
+ self.num_channels = num_channels
79
+ self.random_crop = random_crop
80
+ self.loudness_cutoff = loudness_cutoff
81
+ self.audio_only = audio_only
82
+ self.drop_label_p = drop_label_p
83
+ self.transform = transform
84
+ self.normalize = normalize
85
+ self.trim_silence = trim_silence
86
+
87
+ print(f'Audio root_path: {root_path}')
88
+
89
+ # Find audio files and labels
90
+ self.files = []
91
+
92
+ # Find all audio files recursively under root_path
93
+ for root, dirs, files in os.walk(self.root_path):
94
+ for file in files:
95
+ if file.lower().endswith(AUDIO_EXTS):
96
+ self.files.append(os.path.join(root, file))
97
+
98
+
99
+ print(f'Found {len(self.files)} audio files')
100
+
101
+ # Shuffle files if requested
102
+ if shuffle:
103
+ state = util.random_state(shuffle_state)
104
+ combined = self.files
105
+ state.shuffle(combined)
106
+ self.files = combined
107
+
108
+ def __len__(self):
109
+ return len(self.files)
110
+
111
+ def __getitem__(self, idx):
112
+ try:
113
+ file_path = self.files[idx]
114
+
115
+ # Load audio using AudioSignal
116
+ if self.random_crop:
117
+ # Use salient excerpt for random cropping with loudness filtering
118
+ signal = AudioSignal.salient_excerpt(
119
+ str(file_path),
120
+ duration=self.duration,
121
+ loudness_cutoff=self.loudness_cutoff,
122
+ )
123
+ else:
124
+ # Load from beginning or deterministic offset
125
+ signal = AudioSignal(
126
+ str(file_path),
127
+ duration=self.duration,
128
+ offset=0.0,
129
+ )
130
+
131
+ # Convert to mono/stereo as needed
132
+ if self.num_channels == 1:
133
+ signal = signal.to_mono()
134
+
135
+ # Resample to target sample rate
136
+ signal = signal.resample(self.sample_rate)
137
+
138
+ # Ensure duration by padding or trimming
139
+ target_samples = int(self.duration * self.sample_rate)
140
+ if signal.length < target_samples:
141
+ signal = signal.zero_pad_to(target_samples)
142
+ elif signal.length > target_samples:
143
+ signal = signal.truncate_samples(target_samples)
144
+
145
+ # Optional audio processing
146
+ if self.trim_silence:
147
+ signal = signal.trim_silence()
148
+ # Re-pad if trimming made it too short
149
+ if signal.length < target_samples:
150
+ signal = signal.zero_pad_to(target_samples)
151
+
152
+ if self.normalize:
153
+ signal = signal.normalize()
154
+
155
+ # Clamp audio to [-1, 1] range
156
+ signal.audio_data = signal.audio_data.clamp(-1.0, 1.0)
157
+
158
+ # Apply additional transforms if provided
159
+ if self.transform is not None:
160
+ # Create a random state for transforms
161
+ state = util.random_state(idx)
162
+ transform_args = self.transform.instantiate(state, signal=signal)
163
+ signal = self.transform(signal, **transform_args)
164
+
165
+ # print('before process: ', signal.audio_data.shape)
166
+ # Store metadata
167
+ signal.metadata.update(
168
+ {
169
+ 'file_path': str(file_path),
170
+ 'original_sr': signal.sample_rate,
171
+ 'duration': self.duration,
172
+ }
173
+ )
174
+
175
+ if self.audio_only:
176
+ return signal
177
+ else:
178
+ return {
179
+ 'signal': signal,
180
+ 'file_path': str(file_path),
181
+ 'idx': idx,
182
+ }
183
+
184
+ except Exception as e:
185
+ print(f'Error loading audio file {self.files[idx]}: {e}')
186
+ # Return next file on error to avoid crashing training
187
+ return self.__getitem__((idx + 1) % len(self))
188
+
189
+ def collate(self, batch):
190
+ """Collate function for DataLoader"""
191
+ if self.audio_only:
192
+ # Batch AudioSignals
193
+ return AudioSignal.batch(batch)
194
+ else:
195
+ # Collate dictionary batch
196
+ return util.collate(batch)
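A quick sketch of iterating the new dataset directly with its collate function (the root path is a placeholder; assumes the flowae/ package root is on the Python path):

    from torch.utils.data import DataLoader
    from datasets.class_folder_audio import AudioFolder

    ds = AudioFolder(
        root_path="/path/to/audio",   # any folder tree of .wav/.flac/.mp3 files
        sample_rate=24000,
        duration=0.38,
        num_channels=1,
        audio_only=True,              # return bare AudioSignal objects
    )
    loader = DataLoader(ds, batch_size=4, collate_fn=ds.collate, num_workers=0)
    batch = next(iter(loader))        # batched AudioSignal
    print(batch.audio_data.shape)     # expected [4, 1, 9120] for 0.38 s at 24 kHz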
flowae/datasets/wrapper_audio_cae.py ADDED
@@ -0,0 +1,89 @@
1
+ import random
2
+ from PIL import Image
3
+
4
+ import torch
5
+ from torch.utils.data import Dataset, IterableDataset
6
+
7
+ from datasets import register
8
+ import datasets
9
+
10
+ class BaseWrapperAudioCAE:
11
+ """Base wrapper for audio Convolutional Autoencoder (CAE) training.
12
+
13
+ Similar to the image wrapper, but for audio data.
14
+ """
15
+
16
+ def __init__(
17
+ self,
18
+ dataset,
19
+ sample_rate=24000,
20
+ duration=0.38, # Duration in seconds
21
+ n_samples=None, # Alternative: specify exact number of samples
22
+ return_gt=True,
23
+ gt_sample_rate=None, # Ground truth sample rate (if different)
24
+ mono=True,
25
+ normalize=True,
26
+ return_coords=True, # Whether to return coordinate grids
27
+ ):
28
+ self.dataset = datasets.make(dataset)
29
+ self.sample_rate = sample_rate
30
+ self.duration = duration
31
+ self.n_samples = n_samples or int(duration * sample_rate)
32
+ self.return_gt = return_gt
33
+ self.gt_sample_rate = gt_sample_rate or sample_rate
34
+ self.mono = mono
35
+ self.normalize = normalize
36
+ self.return_coords = return_coords
37
+
38
+ def process(self, audio_data):
39
+ """Process audio data for DiTo training.
40
+
41
+ Args:
42
+ audio_data: Dictionary with 'signal' key containing AudioSignal
43
+ or AudioSignal directly
44
+ """
45
+ ret = {}
46
+
47
+ # Extract AudioSignal
48
+ if isinstance(audio_data, dict):
49
+ signal = audio_data['signal']
50
+ else:
51
+ signal = audio_data
52
+
53
+ # Normalize audio
54
+ audio_tensor = signal.audio_data # Shape: [channels, samples]
55
+
56
+ audio_tensor = audio_tensor.squeeze(0)
57
+
58
+ # Create input tensor
59
+ ret['inp'] = audio_tensor
60
+
61
+ if not self.return_gt:
62
+ return ret
63
+
64
+
65
+ ret['gt'] = audio_tensor
66
+ # print('audio_tensor shape: ', audio_tensor.shape)
67
+
68
+ return ret
69
+
70
+
71
+ @register('wrapper_audio_cae')
72
+ class WrapperAudioCAE(BaseWrapperAudioCAE, Dataset):
73
+ """Dataset wrapper for audio CAE training."""
74
+
75
+ def __len__(self):
76
+ return len(self.dataset)
77
+
78
+ def __getitem__(self, idx):
79
+ data = self.dataset[idx]
80
+ return self.process(data)
81
+
82
+
83
+ @register('wrapper_audio_cae_iterable')
84
+ class WrapperAudioCAEIterable(BaseWrapperAudioCAE, IterableDataset):
85
+ """Iterable dataset wrapper for audio CAE training."""
86
+
87
+ def __iter__(self):
88
+ for data in self.dataset:
89
+ yield self.process(data)
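To mirror the dae.yaml entries, a sketch of how the wrapper composes with the dataset through the registry (assuming datasets.make follows the same name/args convention used for models; the root path is a placeholder):

    import datasets  # flowae dataset registry

    train_ds = datasets.make({
        'name': 'wrapper_audio_cae',
        'args': {
            'dataset': {
                'name': 'class_folder_audio',
                'args': {'root_path': '/path/to/audio', 'sample_rate': 24000,
                         'duration': 0.38, 'num_channels': 1},
            },
            'sample_rate': 24000, 'duration': 0.38, 'mono': True, 'normalize': True,
        },
    })
    item = train_ds[0]
    print(item['inp'].shape, item['gt'].shape)  # both [1, 9120] mono waveform tensors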
flowae/datasets/wrapper_cae.py CHANGED
@@ -113,196 +113,4 @@ class WrapperCAE(BaseWrapperCAE, IterableDataset):
113
  ret.update(data)
114
  yield ret
115
  else:
116
- yield self.process(data)
117
-
118
-
119
-
120
-
121
-
122
-
123
- class BaseWrapperAudioCAE:
124
- """Base wrapper for audio Convolutional Autoencoder (CAE) training.
125
-
126
- Similar to the image wrapper, but for audio data.
127
- """
128
-
129
- def __init__(
130
- self,
131
- dataset,
132
- sample_rate=24000,
133
- duration=0.38, # Duration in seconds
134
- n_samples=None, # Alternative: specify exact number of samples
135
- return_gt=True,
136
- gt_sample_rate=None, # Ground truth sample rate (if different)
137
- mono=True,
138
- normalize=True,
139
- return_coords=True, # Whether to return coordinate grids
140
- ):
141
- self.dataset = dataset
142
- self.sample_rate = sample_rate
143
- self.duration = duration
144
- self.n_samples = n_samples or int(duration * sample_rate)
145
- self.return_gt = return_gt
146
- self.gt_sample_rate = gt_sample_rate or sample_rate
147
- self.mono = mono
148
- self.normalize = normalize
149
- self.return_coords = return_coords
150
-
151
- def process(self, audio_data):
152
- """Process audio data for DiTo training.
153
-
154
- Args:
155
- audio_data: Dictionary with 'signal' key containing AudioSignal
156
- or AudioSignal directly
157
- """
158
- ret = {}
159
-
160
- # Extract AudioSignal
161
- if isinstance(audio_data, dict):
162
- signal = audio_data['signal']
163
- else:
164
- signal = audio_data
165
-
166
- # Convert to mono if needed
167
- if self.mono and signal.num_channels > 1:
168
- signal = signal.to_mono()
169
-
170
- # Resample to target sample rate
171
- if signal.sample_rate != self.sample_rate:
172
- signal = signal.resample(self.sample_rate)
173
-
174
- # Extract fixed duration
175
- if signal.duration < self.duration:
176
- # Pad if too short
177
- signal = signal.zero_pad_to(self.n_samples)
178
- else:
179
- # Take random excerpt if too long
180
- max_start = signal.num_samples - self.n_samples
181
- if max_start > 0:
182
- start_idx = random.randint(0, max_start)
183
- signal = signal[..., start_idx:start_idx + self.n_samples]
184
- else:
185
- signal = signal[..., :self.n_samples]
186
-
187
- # Normalize audio
188
- audio_tensor = signal.audio_data # Shape: [channels, samples]
189
- if self.normalize:
190
- # Normalize to [-1, 1]
191
- max_val = audio_tensor.abs().max()
192
- if max_val > 0:
193
- audio_tensor = audio_tensor / max_val
194
-
195
- # Create input tensor
196
- ret['inp'] = audio_tensor
197
-
198
- if not self.return_gt:
199
- return ret
200
-
201
-
202
- ret['gt'] = audio_tensor
203
-
204
- return ret
205
-
206
-
207
- @register('wrapper_audio_cae')
208
- class WrapperAudioCAE(BaseWrapperAudioCAE, Dataset):
209
- """Dataset wrapper for audio CAE training."""
210
-
211
- def __len__(self):
212
- return len(self.dataset)
213
-
214
- def __getitem__(self, idx):
215
- data = self.dataset[idx]
216
- return self.process(data)
217
-
218
-
219
- @register('wrapper_audio_cae_iterable')
220
- class WrapperAudioCAEIterable(BaseWrapperAudioCAE, IterableDataset):
221
- """Iterable dataset wrapper for audio CAE training."""
222
-
223
- def __iter__(self):
224
- for data in self.dataset:
225
- yield self.process(data)
226
-
227
-
228
- # Example usage with your existing AudioDataset
229
- def create_dito_audio_dataset(config):
230
- """Create DiTo audio dataset from config."""
231
-
232
- # Create base audio dataset using audiotools
233
-
234
- # Setup audio loaders
235
- train_folders = config.get("train_folders", {})
236
-
237
- loader = AudioLoader(
238
- sources=list(train_folders.values()),
239
- transform=tfm.Compose(
240
- tfm.VolumeNorm(("uniform", -20, -10)),
241
- tfm.RescaleAudio(),
242
- ),
243
- ext=['.wav', '.flac', '.mp3'],
244
- )
245
-
246
- # Create base dataset
247
- base_dataset = AudioDataset(
248
- loaders=loader,
249
- sample_rate=config['sample_rate'],
250
- duration=config['duration'],
251
- n_examples=config['n_examples'],
252
- num_channels=1 if config.get('mono', True) else 2,
253
- )
254
-
255
- # Wrap with DiTo wrapper
256
- dito_dataset = WrapperAudioCAE(
257
- dataset=base_dataset,
258
- sample_rate=config['sample_rate'],
259
- duration=config['duration'],
260
- mono=config.get('mono', True),
261
- normalize=True,
262
- return_coords=True,
263
- )
264
-
265
- return dito_dataset
266
-
267
-
268
- # For your training config, you would use it like:
269
- """
270
- datasets:
271
- train:
272
- name: wrapper_audio_cae
273
- args:
274
- dataset:
275
- name: audio_dataset # Your base audio dataset
276
- args:
277
- sources: ["/path/to/audio/files"]
278
- sample_rate: 44100
279
- duration: 2.0
280
- n_examples: 10000
281
- sample_rate: 44100
282
- duration: 2.0
283
- mono: true
284
- normalize: true
285
- return_coords: true
286
- loader:
287
- batch_size: 16
288
- num_workers: 8
289
-
290
- val:
291
- name: wrapper_audio_cae
292
- args:
293
- dataset:
294
- name: audio_dataset
295
- args:
296
- sources: ["/path/to/val/audio/files"]
297
- sample_rate: 44100
298
- duration: 2.0
299
- n_examples: 1000
300
- sample_rate: 44100
301
- duration: 2.0
302
- mono: true
303
- normalize: true
304
- return_coords: true
305
- loader:
306
- batch_size: 16
307
- num_workers: 8
308
- """
 
113
  ret.update(data)
114
  yield ret
115
  else:
116
+ yield self.process(data)
 
flowae/{reconstruction.py → image_dito_inference.py} RENAMED
File without changes
flowae/models/diffusion/fm.py CHANGED
@@ -22,6 +22,21 @@ class FM:
22
 
23
  def B(self, t):
24
  return -(1.0 - self.sigma_min)
25
 
26
  def get_betas(self, n_timesteps):
27
  return torch.zeros(n_timesteps) # Not VP and not supported
@@ -38,17 +53,20 @@ class FM:
38
 
39
  if t is None:
40
  t = torch.rand(x.shape[0], device=x.device)
41
- print('x shape: ', x.shape)
42
  x_t, noise = self.add_noise(x, t)
43
- print('x_t shape: ', x_t.shape)
44
  pred = net(x_t, t=t * self.timescale, **net_kwargs)
45
- print('pred shape: ', pred.shape)
46
 
47
  target = self.A(t) * x + self.B(t) * noise # -dxt/dt
48
- print('target shape: ', target.shape)
49
- print('return_loss_unreduced: ', return_loss_unreduced, 'return_all: ', return_all)
50
  if return_loss_unreduced:
51
- loss = ((pred.float() - target.float()) ** 2).mean(dim=[1, 2, 3])
 
 
 
52
  if return_all:
53
  return loss, t, x_t, pred
54
  else:
 
22
 
23
  def B(self, t):
24
  return -(1.0 - self.sigma_min)
25
+
26
+ def _get_reduction_dims(self, x):
27
+ """Get appropriate dimensions for loss reduction based on tensor shape"""
28
+ if x.dim() == 4:
29
+ # Images: [batch, channels, height, width]
30
+ return [1, 2, 3]
31
+ elif x.dim() == 3:
32
+ # Audio: [batch, channels, samples] or [batch, latent_dim, time_frames]
33
+ return [1, 2]
34
+ elif x.dim() == 2:
35
+ # 1D signals: [batch, samples]
36
+ return [1]
37
+ else:
38
+ # Fallback: reduce over all non-batch dimensions
39
+ return list(range(1, x.dim()))
40
 
41
  def get_betas(self, n_timesteps):
42
  return torch.zeros(n_timesteps) # Not VP and not supported
 
53
 
54
  if t is None:
55
  t = torch.rand(x.shape[0], device=x.device)
56
+ # print('x shape: ', x.shape)
57
  x_t, noise = self.add_noise(x, t)
58
+ # print('x_t shape: ', x_t.shape)
59
  pred = net(x_t, t=t * self.timescale, **net_kwargs)
60
+ # print('pred shape: ', pred.shape)
61
 
62
  target = self.A(t) * x + self.B(t) * noise # -dxt/dt
63
+ # print('target shape: ', target.shape)
64
+ # print('return_loss_unreduced: ', return_loss_unreduced, 'return_all: ', return_all)
65
  if return_loss_unreduced:
66
+ print('pred shape: ', pred.shape, 'target shape: ', target.shape)
67
+ reduce_dims = self._get_reduction_dims(x)
68
+ loss = ((pred.float() - target.float()) ** 2).mean(dim=reduce_dims)
69
+ # loss = ((pred.float() - target.float()) ** 2).mean(dim=[1, 2, 3])
70
  if return_all:
71
  return loss, t, x_t, pred
72
  else:
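The new _get_reduction_dims helper makes the per-sample FM loss work for both image and audio tensors; a small standalone sketch of the behaviour it mirrors:

    import torch

    def reduction_dims(x):
        # mirrors FM._get_reduction_dims: average over every axis except the batch axis
        return list(range(1, x.dim()))

    img   = torch.randn(8, 3, 64, 64)   # images -> dims [1, 2, 3]
    audio = torch.randn(8, 1, 9120)     # audio  -> dims [1, 2]
    flat  = torch.randn(8, 9120)        # 1-D    -> dims [1]

    for x in (img, audio, flat):
        per_sample = (x ** 2).mean(dim=reduction_dims(x))
        print(per_sample.shape)          # torch.Size([8]) in every case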
flowae/models/ldm/dac/layers.py CHANGED
@@ -74,7 +74,7 @@ def get_activation(activation, channels, alpha):
74
  return nn.LeakyReLU()
75
  elif activation == "tanh":
76
  return nn.Tanh()
77
- elif activation == "snake_beta":
78
  return SnakeBeta(channels, alpha)
79
  else:
80
  raise ValueError(f"Activation {activation} not supported")
 
74
  return nn.LeakyReLU()
75
  elif activation == "tanh":
76
  return nn.Tanh()
77
+ elif activation == "snakebeta":
78
  return SnakeBeta(channels, alpha)
79
  else:
80
  raise ValueError(f"Activation {activation} not supported")
flowae/models/ldm/dac/model.py CHANGED
@@ -236,7 +236,8 @@ class Encoder(nn.Module):
236
 
237
  def forward(self, x):
238
  x = F.leaky_relu(x)
239
- return self.block(x)
 
240
 
241
 
242
  class DecoderBlock(nn.Module):
@@ -478,6 +479,7 @@ class DACVAE(BaseModel, CodecMixin):
478
  ):
479
  x = self.encoder(audio_data)
480
  x = self.en_conv_post(x)
 
481
  m, logs = torch.split(x, self.latent_dim, dim=1)
482
  logs = torch.clamp(logs, min=-14.0, max=14.0)
483
 
 
236
 
237
  def forward(self, x):
238
  x = F.leaky_relu(x)
239
+ x = self.block(x)
240
+ return x
241
 
242
 
243
  class DecoderBlock(nn.Module):
 
479
  ):
480
  x = self.encoder(audio_data)
481
  x = self.en_conv_post(x)
482
+ print('x shape: ', x.shape)
483
  m, logs = torch.split(x, self.latent_dim, dim=1)
484
  logs = torch.clamp(logs, min=-14.0, max=14.0)
485
 
flowae/models/ldm/dac/utils.py CHANGED
@@ -7,16 +7,16 @@ from .model import Encoder, Decoder, WNConv1d
7
 
8
  default_configs = {
9
  'snake': dict(
10
- encoder_dim=64,
11
- encoder_rates=[2, 4, 5, 8],
12
- latent_dim=64,
13
  d_in=1,
14
  activation='snake',
15
  ),
16
- 'snake': dict(
17
- encoder_dim=64,
18
- encoder_rates=[2, 4, 5, 8],
19
- latent_dim=64,
20
  d_in=1,
21
  activation='snakebeta',
22
  ),
@@ -27,10 +27,10 @@ default_configs = {
27
  def make_dac_encoder(config_name, **kwargs):
28
  encoder_kwargs = default_configs[config_name]
29
  encoder_kwargs.update(kwargs)
30
- latent_dim = encoder_kwargs['latent_dim']
31
  return nn.Sequential(
32
  Encoder(**encoder_kwargs),
33
- WNConv1d(latent_dim, latent_dim, kernel_size=1),
34
  )
35
 
36
 
@@ -38,8 +38,8 @@ def make_dac_encoder(config_name, **kwargs):
38
  def make_vqgan_decoder(config_name, **kwargs):
39
  decoder_kwargs = default_configs[config_name]
40
  decoder_kwargs.update(kwargs)
41
- latent_dim = decoder_kwargs['latent_dim']
42
  return nn.Sequential(
43
- WNConv1d(latent_dim, latent_dim, kernel_size=1),
44
  Decoder(**decoder_kwargs),
45
  )
 
7
 
8
  default_configs = {
9
  'snake': dict(
10
+ d_model=64,
11
+ strides=[2, 4, 5, 8],
12
+ d_latent=64,
13
  d_in=1,
14
  activation='snake',
15
  ),
16
+ 'snakebeta': dict(
17
+ d_model=64,
18
+ strides=[2, 4, 5, 8],
19
+ d_latent=64,
20
  d_in=1,
21
  activation='snakebeta',
22
  ),
 
27
  def make_dac_encoder(config_name, **kwargs):
28
  encoder_kwargs = default_configs[config_name]
29
  encoder_kwargs.update(kwargs)
30
+ d_model = encoder_kwargs['d_model']
31
  return nn.Sequential(
32
  Encoder(**encoder_kwargs),
33
+ WNConv1d(d_model, d_model, kernel_size=1),
34
  )
35
 
36
 
 
38
  def make_vqgan_decoder(config_name, **kwargs):
39
  decoder_kwargs = default_configs[config_name]
40
  decoder_kwargs.update(kwargs)
41
+ d_model = decoder_kwargs['d_model']
42
  return nn.Sequential(
43
+ WNConv1d(d_model, d_model, kernel_size=1),
44
  Decoder(**decoder_kwargs),
45
  )
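A shape-check sketch for the renamed encoder configs (assumes the flowae models package is importable and that the Encoder maps d_model/strides/d_latent as configured above):

    import torch
    from models.ldm.dac.utils import make_dac_encoder

    enc = make_dac_encoder('snakebeta')   # d_model=64, strides [2, 4, 5, 8], d_latent=64
    wav = torch.randn(1, 1, 9120)         # 0.38 s of mono audio at 24 kHz
    z = enc(wav)
    print(z.shape)                        # expected [1, 64, ~29]: one latent frame per 320 samples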
flowae/models/ldm/dito.py CHANGED
@@ -6,7 +6,8 @@ import torch
6
  import models
7
  from omegaconf import OmegaConf
8
  from models import register
9
- from models.ldm.ldm_base import LDMBase
 
10
  from models.ldm.vqgan.lpips import LPIPS
11
 
12
 
@@ -178,3 +179,143 @@ class DiTo(LDMBase):
178
  dae_loss_w = loss_config.get('dae_loss', 1)
179
  ret['loss'] = ret['loss'] + dae_loss * dae_loss_w
180
  return ret
 
6
  import models
7
  from omegaconf import OmegaConf
8
  from models import register
9
+
10
+ from models.ldm.ldm_base import LDMBase, LDMBaseAudio
11
  from models.ldm.vqgan.lpips import LPIPS
12
 
13
 
 
179
  dae_loss_w = loss_config.get('dae_loss', 1)
180
  ret['loss'] = ret['loss'] + dae_loss * dae_loss_w
181
  return ret
182
+
183
+
184
+
185
+ @register('dito_audio')
186
+ class DiToAudio(LDMBaseAudio):
187
+
188
+ def __init__(self, render_diffusion, render_sampler, render_n_steps, renderer_guidance=1,**kwargs):
189
+ super().__init__(**kwargs)
190
+ self.render_diffusion = models.make(render_diffusion)
191
+
192
+ if OmegaConf.is_config(render_sampler):
193
+ render_sampler = OmegaConf.to_container(render_sampler, resolve=True)
194
+ render_sampler = copy.deepcopy(render_sampler)
195
+ if render_sampler.get('args') is None:
196
+ render_sampler['args'] = {}
197
+ render_sampler['args']['diffusion'] = self.render_diffusion
198
+ self.render_sampler = models.make(render_sampler)
199
+ self.render_n_steps = render_n_steps
200
+ self.renderer_guidance = renderer_guidance
201
+
202
+ self.t_loss_monitor_v = [0 for _ in range(10)]
203
+ self.t_loss_monitor_n = [0 for _ in range(10)]
204
+ self.t_loss_monitor_decay = 0.99
205
+
206
+
207
+ def render(self, z_dec):
208
+ net_kwargs = {'z_dec': z_dec}
209
+ n_frames = z_dec.size(2) * 320
210
+ shape = (z_dec.size(0), 1, n_frames) # [batch, audio channels, waveform samples]
211
+
212
+ if self.renderer_guidance > 1:
213
+ uncond_z_dec = self.drop_z_emb.unsqueeze(0).expand(z_dec.shape[0], -1, -1, -1)
214
+ uncond_net_kwargs = {'z_dec': uncond_z_dec}
215
+ else:
216
+ uncond_net_kwargs = None
217
+
218
+ ret = self.render_sampler.sample(
219
+ net=self.renderer,
220
+ n_steps=self.render_n_steps,
221
+ shape=shape,
222
+ net_kwargs=net_kwargs,
223
+ uncond_net_kwargs=uncond_net_kwargs,
224
+ guidance=self.renderer_guidance,
225
+ )
226
+
227
+ # if self.use_ema_renderer:
228
+ # self.swap_ema_renderer()
229
+
230
+ return ret
231
+
232
+ def forward(self, data, mode, has_optimizer=None):
233
+ if mode in ['z', 'z_dec']:
234
+ ret_z, _ = super().forward(data, mode=mode, has_optimizer=has_optimizer)
235
+ return ret_z
236
+
237
+ grad = self.get_grad_plan(has_optimizer)
238
+ loss_config = self.loss_config
239
+ if mode == 'pred':
240
+ z_dec, ret = super().forward(data, mode='z_dec', has_optimizer=has_optimizer)
241
+
242
+ gt_patch = data['gt']
243
+
244
+ if grad['renderer']:
245
+ return self.render(z_dec)
246
+ else:
247
+ with torch.no_grad():
248
+ return self.render(z_dec)
249
+
250
+ elif mode == 'loss':
251
+ if not grad['renderer']: # Only training zdm
252
+ _, ret = super().forward(data, mode='z', has_optimizer=has_optimizer)
253
+ return ret
254
+
255
+ gt_patch = data['gt']
256
+
257
+
258
+ z_dec, ret = super().forward(data, mode='z_dec', has_optimizer=has_optimizer)
259
+ net_kwargs = {'z_dec': z_dec}
260
+
261
+ # print('latent z_dec shape: ', z_dec.shape)
262
+
263
+ t = torch.rand(gt_patch.shape[0], device=gt_patch.device)
264
+
265
+
266
+ # print('self.zaug_p:', self.zaug_p)
267
+ # print('self.training:', self.training)
268
+
269
+ if (self.zaug_p is not None) and self.training:
270
+ tz = self._tz
271
+ mask_aug = self._mask_aug
272
+
273
+ typ = self.zaug_decoding_loss_type
274
+ if typ == 'all':
275
+ tmin = torch.ones_like(tz) * 0
276
+ tmax = torch.ones_like(tz) * 1
277
+ elif typ == 'suffix':
278
+ tmin = tz
279
+ tmax = torch.ones_like(tz) * 1
280
+ elif typ == 'tz':
281
+ tmin = tz
282
+ tmax = tz
283
+ elif typ == 'tmax':
284
+ tmin = torch.ones_like(tz) * 1
285
+ tmax = torch.ones_like(tz) * 1
286
+ else:
287
+ raise NotImplementedError
288
+ t_aug = tmin + (tmax - tmin) * torch.rand_like(tmin)
289
+
290
+ t = mask_aug * t_aug + (1 - mask_aug) * t
291
+
292
+ loss, t = self.render_diffusion.loss(
293
+ net=self.renderer,
294
+ x=gt_patch,
295
+ t=t,
296
+ net_kwargs=net_kwargs,
297
+ return_loss_unreduced=True
298
+ )
299
+
300
+ # Visualize diffusion network loss for different timesteps #
301
+ if self.training:
302
+ m = len(self.t_loss_monitor_v)
303
+ for i in range(len(loss)):
304
+ q = min(math.floor(t[i].item() * m), m - 1)
305
+ self.t_loss_monitor_v[q] = self.t_loss_monitor_v[q] * self.t_loss_monitor_decay + loss[i].item() * (1 - self.t_loss_monitor_decay)
306
+ self.t_loss_monitor_n[q] += 1
307
+ for q in range(m):
308
+ if self.t_loss_monitor_n[q] > 0:
309
+ if self.t_loss_monitor_n[q] < 500:
310
+ r = 1 - math.pow(self.t_loss_monitor_decay, self.t_loss_monitor_n[q])
311
+ else:
312
+ r = 1
313
+ ret[f'_loss_t{q}'] = self.t_loss_monitor_v[q] / r
314
+ # - #
315
+
316
+ dae_loss = loss.mean()
317
+
318
+ ret['dae_loss'] = dae_loss.item()
319
+ dae_loss_w = loss_config.get('dae_loss', 1)
320
+ ret['loss'] = ret['loss'] + dae_loss * dae_loss_w
321
+ return ret
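The zaug_decoding_loss_type branch above chooses the window from which the renderer's diffusion time t is redrawn for z-augmented samples; a tiny sketch of the four modes:

    import torch

    def pick_t(tz, typ):
        # mirrors the tmin/tmax selection in DiToAudio.forward (mode='loss')
        if typ == 'all':
            tmin, tmax = torch.zeros_like(tz), torch.ones_like(tz)
        elif typ == 'suffix':
            tmin, tmax = tz, torch.ones_like(tz)      # t in [tz, 1], the config's default
        elif typ == 'tz':
            tmin, tmax = tz, tz                       # t fixed at the augmentation time
        elif typ == 'tmax':
            tmin, tmax = torch.ones_like(tz), torch.ones_like(tz)
        else:
            raise NotImplementedError(typ)
        return tmin + (tmax - tmin) * torch.rand_like(tmin)

    tz = torch.tensor([0.3, 0.7])
    print(pick_t(tz, 'suffix'))   # each entry lies in [tz_i, 1]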
flowae/models/ldm/ldm_base.py CHANGED
@@ -47,6 +47,39 @@ class LDMBase(nn.Module):
47
  use_ema_decoder=False,
48
  use_ema_renderer=False,
49
  ):
50
  super().__init__()
51
  self.loss_config = loss_config if loss_config is not None else dict()
52
 
@@ -442,3 +475,194 @@ class DiagonalGaussianDistribution(object):
442
 
443
  def mode(self):
444
  return self.mean
 
47
  use_ema_decoder=False,
48
  use_ema_renderer=False,
49
  ):
50
+ print('print all the args ')
51
+ print("encoder: ", encoder)
52
+ print("z_shape: ",z_shape)
53
+ print("decoder: ",decoder)
54
+ print("renderer: ",renderer)
55
+ print("encoder_ema_rate: ",encoder_ema_rate)
56
+ print("decoder_ema_rate: ",decoder_ema_rate)
57
+ print("renderer_ema_rate: ",renderer_ema_rate)
58
+ print("z_gaussian: ",z_gaussian)
59
+ print("z_gaussian_sample: ",z_gaussian_sample)
60
+ print("z_quantizer: ",z_quantizer)
61
+ print("z_quantizer_n_embed: ",z_quantizer_n_embed)
62
+ print("z_quantizer_beta: ",z_quantizer_beta)
63
+ print("z_layernorm: ",z_layernorm)
64
+ print("zaug_p: ",zaug_p)
65
+ print("zaug_tmax: ",zaug_tmax)
66
+ print("zaug_tmax_always: ",zaug_tmax_always)
67
+ print("zaug_decoding_loss_type: ",zaug_decoding_loss_type)
68
+ print("zaug_zdm_diffusion: ",zaug_zdm_diffusion)
69
+ print("gt_noise_lb: ",gt_noise_lb)
70
+ print("drop_z_p: ",drop_z_p)
71
+ print("zdm_net: ",zdm_net)
72
+ print("zdm_diffusion: ",zdm_diffusion)
73
+ print("zdm_sampler: ",zdm_sampler)
74
+ print("zdm_n_steps: ",zdm_n_steps)
75
+ print("zdm_ema_rate: ",zdm_ema_rate)
76
+ print("zdm_train_normalize: ",zdm_train_normalize)
77
+ print("zdm_class_cond: ",zdm_class_cond)
78
+ print("zdm_force_guidance: ",zdm_force_guidance)
79
+ print("loss_config: ",loss_config)
80
+ print("use_ema_encoder: ",use_ema_encoder)
81
+ print("use_ema_decoder: ",use_ema_decoder)
82
+ print("use_ema_renderer: ",use_ema_renderer)
83
  super().__init__()
84
  self.loss_config = loss_config if loss_config is not None else dict()
85
 
 
475
 
476
  def mode(self):
477
  return self.mean
478
+
479
+
480
+ class LDMBaseAudio(nn.Module):
481
+ def __init__(
482
+ self,
483
+ encoder,
484
+ z_channels,
485
+ decoder,
486
+ renderer,
487
+ zaug_p=0.1,
488
+ zaug_tmax=1.0,
489
+ zaug_tmax_always=False,
490
+ zaug_decoding_loss_type='all',
491
+ zaug_zdm_diffusion={'name': 'fm', 'args': {'timescale': 1000.0}},
492
+ zdm_ema_rate=0.9999,
493
+ loss_config={},
494
+ encoder_ema_rate=None,
495
+ decoder_ema_rate=None,
496
+ renderer_ema_rate=None,
497
+ ):
498
+ super().__init__()
499
+ self.loss_config = loss_config
500
+
501
+ self.encoder = models.make(encoder)
502
+ self.decoder = models.make(decoder)
503
+ self.renderer = models.make(renderer)
504
+
505
+
506
+ self.z_layernorm = nn.LayerNorm(
507
+ z_channels, # e.g., 64
508
+ elementwise_affine=False
509
+ )
510
+
511
+ self.zaug_p = zaug_p
512
+ self.zaug_tmax = zaug_tmax
513
+ self.zaug_tmax_always = zaug_tmax_always
514
+ self.zaug_decoding_loss_type = zaug_decoding_loss_type
515
+ if zaug_zdm_diffusion is not None:
516
+ self.zaug_zdm_diffusion = models.make(zaug_zdm_diffusion)
517
+
518
+ # EMA models #
519
+ self.encoder_ema_rate = encoder_ema_rate
520
+ if self.encoder_ema_rate is not None:
521
+ self.encoder_ema = copy.deepcopy(self.encoder)
522
+ for p in self.encoder_ema.parameters():
523
+ p.requires_grad = False
524
+
525
+ self.decoder_ema_rate = decoder_ema_rate
526
+ if self.decoder_ema_rate is not None:
527
+ self.decoder_ema = copy.deepcopy(self.decoder)
528
+ for p in self.decoder_ema.parameters():
529
+ p.requires_grad = False
530
+
531
+ self.renderer_ema_rate = renderer_ema_rate
532
+ if self.renderer_ema_rate is not None:
533
+ self.renderer_ema = copy.deepcopy(self.renderer)
534
+ for p in self.renderer_ema.parameters():
535
+ p.requires_grad = False
536
+ #
537
+
538
+ def get_grad_plan(self, has_optimizer):
539
+ if has_optimizer is None:
540
+ has_optimizer = dict()
541
+ grad = dict()
542
+ grad['encoder'] = has_optimizer.get('encoder', False)
543
+ grad['decoder'] = grad['encoder'] or has_optimizer.get('decoder', False)
544
+ grad['renderer'] = grad['decoder'] or has_optimizer.get('renderer', False)
545
+ return grad
546
+
547
+ def normalize_latents(self, z):
548
+ # z shape: [batch, latent_dim, n_frames] - n_frames can vary!
549
+ # print('bef z shape: ', z.shape)
550
+ z = z.transpose(-2, -1) # [batch, latent_dim, n_frames]
551
+ # print('z shape: ', z.shape)
552
+ z = self.z_layernorm(z) # Normalize over latent_dim for each time step
553
+ # print('z shape: ', z.shape)
554
+ z = z.transpose(-2, -1) # [batch, latent_dim, n_frames]
555
+ # print('z shape: ', z.shape)
556
+ return z
557
+
558
+ def update_ema(self):
559
+ if self.encoder_ema_rate is not None:
560
+ self.update_ema_fn(self.encoder_ema, self.encoder, self.encoder_ema_rate)
561
+ if self.decoder_ema_rate is not None:
562
+ self.update_ema_fn(self.decoder_ema, self.decoder, self.decoder_ema_rate)
563
+ if self.renderer_ema_rate is not None:
564
+ self.update_ema_fn(self.renderer_ema, self.renderer, self.renderer_ema_rate)
565
+
566
+ def get_parameters(self, name):
567
+ if name == 'encoder':
568
+ return self.encoder.parameters()
569
+ elif name == 'decoder':
570
+ p = list(self.decoder.parameters())
571
+ if getattr(self, 'z_quantizer', None) is not None:
572
+ p += list(self.z_quantizer.parameters())
573
+ return p
574
+ elif name == 'renderer':
575
+ return self.renderer.parameters()
576
+ elif name == 'zdm':
577
+ return self.zdm_net.parameters()
578
+
579
+ def encode(self, x):
580
+
581
+ z = self.encoder(x)
582
+ # print('z shape: ', z.shape)
583
+ z = self.normalize_latents(z)
584
+ # print('after norm z shape: ', z.shape)
585
+
586
+ if (self.zaug_p is not None) and self.training:
587
+ assert self.z_layernorm is not None # ensure 0 mean 1 std
588
+ if self.zaug_tmax_always:
589
+ tz = torch.ones(z.shape[0], device=z.device) * self.zaug_tmax
590
+ else:
591
+ tz = torch.rand(z.shape[0], device=z.device) * self.zaug_tmax
592
+
593
+ zt, _ = self.zaug_zdm_diffusion.add_noise(z, tz)
594
+ mask_aug = (torch.rand(z.shape[0], device=z.device) < self.zaug_p).float()
595
+ if z.dim() == 4: # Image: [batch, channels, height, width]
596
+ mask_shape = (-1, 1, 1, 1)
597
+ elif z.dim() == 3: # Audio: [batch, channels, n_frames]
598
+ mask_shape = (-1, 1, 1)
599
+ else:
600
+ raise ValueError(f"Unsupported tensor dimension: {z.dim()}")
601
+
602
+ z = mask_aug.view(*mask_shape) * zt + (1 - mask_aug).view(*mask_shape) * z
603
+ # z = mask_aug.view(-1, 1, 1, 1) * zt + (1 - mask_aug).view(-1, 1, 1, 1) * z
604
+ self._tz = tz
605
+ self._mask_aug = mask_aug
606
+
607
+ # print('after zaug z shape: ', z.shape)
608
+
609
+ return z
610
+
611
+
612
+ def decode(self, z):
613
+ z_dec = self.decoder(z)
614
+ return z_dec
615
+
616
+ def render(self, z_dec):
617
+ raise NotImplementedError
618
+
619
+ def forward(self, data, mode, has_optimizer=None):
620
+ loss = torch.tensor(0., device=data['inp'].device)
621
+ ret = dict()
622
+ # print("data['inp'] shape: ", data['inp'].shape)
623
+ z = self.encode(data['inp'])
624
+
625
+ z_dec = self.decode(z)
626
+
627
+
628
+ ret['loss'] = loss
629
+ return z_dec, ret
630
+
631
+ def generate_samples(
632
+ self,
633
+ batch_size,
634
+ n_steps,
635
+ net_kwargs=None,
636
+ uncond_net_kwargs=None,
637
+ ema=False,
638
+ guidance=1.0,
639
+ noise=None,
640
+ return_z=False,
641
+ ):
642
+ if self.zdm_force_guidance is not None:
643
+ guidance = self.zdm_force_guidance
644
+
645
+ shape = (batch_size,) + self.z_shape
646
+ net = self.zdm_net if not ema else self.zdm_net_ema
647
+
648
+ z = self.zdm_sampler.sample(
649
+ net,
650
+ shape,
651
+ n_steps,
652
+ net_kwargs=net_kwargs,
653
+ uncond_net_kwargs=uncond_net_kwargs,
654
+ guidance=guidance,
655
+ noise=noise,
656
+ )
657
+
658
+ if return_z:
659
+ return z
660
+
661
+ if (self.zaug_p is not None) and self.zaug_tmax_always:
662
+ tz = torch.ones(z.shape[0], device=z.device) * self.zaug_tmax
663
+ z, _ = self.zaug_zdm_diffusion.add_noise(z, tz)
664
+
665
+ z = self.denormalize_for_zdm(z)
666
+ z_dec = self.decode(z)
667
+
668
+ return self.render(z_dec)
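LDMBaseAudio.normalize_latents applies LayerNorm over the channel axis by transposing the time axis to the end; a small sketch confirming each latent frame comes out zero-mean, unit-variance:

    import torch
    import torch.nn as nn

    z = torch.randn(2, 64, 29)                       # [batch, z_channels, n_frames]
    ln = nn.LayerNorm(64, elementwise_affine=False)  # same setup as self.z_layernorm

    z_norm = ln(z.transpose(-2, -1)).transpose(-2, -1)
    print(z_norm.mean(dim=1).abs().max())            # ~0 for every frame
    print(z_norm.std(dim=1).mean())                  # ~1 for every frame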
flowae/models/networks/__init__.py CHANGED
@@ -1,2 +1,3 @@
1
  from . import consistency_decoder_unet
2
- from . import dit
 
 
1
  from . import consistency_decoder_unet
2
+ from . import dit
3
+ from . import consistency_audio_decoder_unet
flowae/models/networks/consistency_audio_decoder_unet.py ADDED
@@ -0,0 +1,322 @@
1
+ # https://gist.github.com/mrsteyk/74ad3ec2f6f823111ae4c90e168505ac
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import torch.nn as nn
6
+
7
+ from models import register
8
+
9
+ class PositionalEmbedding(nn.Module):
10
+ def __init__(self, pe_dim=320, out_dim=1280, max_positions=10000, endpoint=True):
11
+ super().__init__()
12
+ self.num_channels = pe_dim
13
+ self.max_positions = max_positions
14
+ self.endpoint = endpoint
15
+ self.f_1 = nn.Linear(pe_dim, out_dim)
16
+ self.f_2 = nn.Linear(out_dim, out_dim)
17
+
18
+ def forward(self, x):
19
+ freqs = torch.arange(start=0, end=self.num_channels//2, dtype=torch.float32, device=x.device)
20
+ freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0))
21
+ freqs = (1 / self.max_positions) ** freqs
22
+ x = x.ger(freqs.to(x.dtype))
23
+ x = torch.cat([x.cos(), x.sin()], dim=1)
24
+
25
+ x = self.f_1(x)
26
+ x = F.silu(x)
27
+ return self.f_2(x)
28
+
29
+
30
+
31
+ class AudioEmbedding(nn.Module):
32
+ """1D convolution for audio input embedding"""
33
+ def __init__(self, in_channels, out_channels=320, kernel_size=3) -> None:
34
+ super().__init__()
35
+ self.f = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size//2)
36
+
37
+ def forward(self, x) -> torch.Tensor:
38
+ return self.f(x)
39
+
40
+ class AudioUnembedding(nn.Module):
41
+ """1D convolution for audio output"""
42
+ def __init__(self, in_channels=320, out_channels=1, kernel_size=3) -> None:
43
+ super().__init__()
44
+ self.gn = nn.GroupNorm(32, in_channels)
45
+ self.f = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size//2)
46
+
47
+ def forward(self, x) -> torch.Tensor:
48
+ return self.f(F.silu(self.gn(x)))
49
+
50
+
51
+ class AudioConvResblock(nn.Module):
52
+ """1D Residual block for audio"""
53
+ def __init__(self, in_features, out_features, t_dim, kernel_size=3) -> None:
54
+ super().__init__()
55
+ self.f_t = nn.Linear(t_dim, out_features * 2)
56
+
57
+ self.gn_1 = nn.GroupNorm(32, in_features)
58
+ self.f_1 = nn.Conv1d(in_features, out_features, kernel_size=kernel_size, padding=kernel_size//2)
59
+
60
+ self.gn_2 = nn.GroupNorm(32, out_features)
61
+ self.f_2 = nn.Conv1d(out_features, out_features, kernel_size=kernel_size, padding=kernel_size//2)
62
+
63
+ skip_conv = in_features != out_features
64
+ self.f_s = (
65
+ nn.Conv1d(in_features, out_features, kernel_size=1, padding=0)
66
+ if skip_conv
67
+ else nn.Identity()
68
+ )
69
+
70
+ def forward(self, x, t):
71
+ x_skip = x
72
+ t = self.f_t(F.silu(t))
73
+ t = t.chunk(2, dim=1)
74
+ t_1 = t[0].unsqueeze(dim=2) + 1 # [batch, channels, 1]
75
+ t_2 = t[1].unsqueeze(dim=2) # [batch, channels, 1]
76
+
77
+ gn_1 = F.silu(self.gn_1(x))
78
+ f_1 = self.f_1(gn_1)
79
+
80
+ gn_2 = self.gn_2(f_1)
81
+
82
+ return self.f_s(x_skip) + self.f_2(F.silu(gn_2 * t_1 + t_2))
83
+
84
+ class AudioDownsample(nn.Module):
+     """1D downsampling for audio"""
+     def __init__(self, in_channels, t_dim, downsample_factor=2) -> None:
+         super().__init__()
+         self.f_t = nn.Linear(t_dim, in_channels * 2)
+         self.downsample_factor = downsample_factor
+
+         self.gn_1 = nn.GroupNorm(32, in_channels)
+         self.f_1 = nn.Conv1d(in_channels, in_channels, kernel_size=3, padding=1)
+         self.gn_2 = nn.GroupNorm(32, in_channels)
+
+         self.f_2 = nn.Conv1d(in_channels, in_channels, kernel_size=3, padding=1)
+
+     def forward(self, x, t) -> torch.Tensor:
+         x_skip = x
+
+         t = self.f_t(F.silu(t))
+         t_1, t_2 = t.chunk(2, dim=1)
+         t_1 = t_1.unsqueeze(2) + 1
+         t_2 = t_2.unsqueeze(2)
+
+         gn_1 = F.silu(self.gn_1(x))
+         # 1D average pooling
+         avg_pool1d = F.avg_pool1d(gn_1, kernel_size=self.downsample_factor)
+         f_1 = self.f_1(avg_pool1d)
+         gn_2 = self.gn_2(f_1)
+
+         f_2 = self.f_2(F.silu(t_2 + (t_1 * gn_2)))
+
+         return f_2 + F.avg_pool1d(x_skip, kernel_size=self.downsample_factor)
+
+
+ class AudioUpsample(nn.Module):
+     """1D upsampling for audio"""
+     def __init__(self, in_channels, t_dim, upsample_factor=2) -> None:
+         super().__init__()
+         self.f_t = nn.Linear(t_dim, in_channels * 2)
+         self.upsample_factor = upsample_factor
+
+         self.gn_1 = nn.GroupNorm(32, in_channels)
+         self.f_1 = nn.Conv1d(in_channels, in_channels, kernel_size=3, padding=1)
+         self.gn_2 = nn.GroupNorm(32, in_channels)
+
+         self.f_2 = nn.Conv1d(in_channels, in_channels, kernel_size=3, padding=1)
+
+     def forward(self, x, t) -> torch.Tensor:
+         x_skip = x
+
+         t = self.f_t(F.silu(t))
+         t_1, t_2 = t.chunk(2, dim=1)
+         t_1 = t_1.unsqueeze(2) + 1
+         t_2 = t_2.unsqueeze(2)
+
+         gn_1 = F.silu(self.gn_1(x))
+         # 1D nearest-neighbor upsampling
+         upsample = F.interpolate(gn_1, scale_factor=self.upsample_factor, mode='nearest')
+         f_1 = self.f_1(upsample)
+         gn_2 = self.gn_2(f_1)
+
+         f_2 = self.f_2(F.silu(t_2 + (t_1 * gn_2)))
+
+         return f_2 + F.interpolate(x_skip, scale_factor=self.upsample_factor, mode='nearest')
+
+
+ @register('audio_diffusion_unet')
+ class AudioDiffusionUNet(nn.Module):
+     """
+     1D UNet for audio diffusion with dynamic latent conditioning
+
+     Handles:
+         - x: [batch, 1, samples] - audio waveform (dynamic length)
+         - z_dec: [batch, 64, n_frames] - latent conditioning (dynamic length)
+     """
+
+     def __init__(
+         self,
+         in_channels=1,           # Audio channels (mono=1, stereo=2)
+         z_dec_channels=64,       # Latent conditioning channels
+         c0=128, c1=256, c2=512,  # Channel progression (smaller than image version)
+         pe_dim=320,
+         t_dim=1280,
+         kernel_size=3
+     ) -> None:
+         super().__init__()
+
+         # Store for dynamic conditioning
+         self.z_dec_channels = z_dec_channels
+
+         # Audio input embedding
+         self.embed_audio = AudioEmbedding(
+             in_channels=in_channels,
+             out_channels=c0,
+             kernel_size=kernel_size
+         )
+
+         # Time embedding
+         self.embed_time = PositionalEmbedding(pe_dim=pe_dim, out_dim=t_dim)
+
+         # Latent conditioning projection
+         if z_dec_channels is not None:
+             self.z_dec_proj = nn.Conv1d(z_dec_channels, c0, kernel_size=1)
+
+         # Downsampling path
+         down_0 = nn.ModuleList([
+             AudioConvResblock(c0, c0, t_dim, kernel_size),
+             AudioConvResblock(c0, c0, t_dim, kernel_size),
+             AudioConvResblock(c0, c0, t_dim, kernel_size),
+             AudioDownsample(c0, t_dim),
+         ])
+         down_1 = nn.ModuleList([
+             AudioConvResblock(c0, c1, t_dim, kernel_size),
+             AudioConvResblock(c1, c1, t_dim, kernel_size),
+             AudioConvResblock(c1, c1, t_dim, kernel_size),
+             AudioDownsample(c1, t_dim),
+         ])
+         down_2 = nn.ModuleList([
+             AudioConvResblock(c1, c2, t_dim, kernel_size),
+             AudioConvResblock(c2, c2, t_dim, kernel_size),
+             AudioConvResblock(c2, c2, t_dim, kernel_size),
+             AudioDownsample(c2, t_dim),
+         ])
+         down_3 = nn.ModuleList([
+             AudioConvResblock(c2, c2, t_dim, kernel_size),
+             AudioConvResblock(c2, c2, t_dim, kernel_size),
+             AudioConvResblock(c2, c2, t_dim, kernel_size),
+         ])
+         self.down = nn.ModuleList([down_0, down_1, down_2, down_3])
+
+         # Middle layers
+         self.mid = nn.ModuleList([
+             AudioConvResblock(c2, c2, t_dim, kernel_size),
+             AudioConvResblock(c2, c2, t_dim, kernel_size),
+         ])
+
+         # Upsampling path
+         up_3 = nn.ModuleList([
+             AudioConvResblock(c2 * 2, c2, t_dim, kernel_size),
+             AudioConvResblock(c2 * 2, c2, t_dim, kernel_size),
+             AudioConvResblock(c2 * 2, c2, t_dim, kernel_size),
+             AudioConvResblock(c2 * 2, c2, t_dim, kernel_size),
+             AudioUpsample(c2, t_dim),
+         ])
+         up_2 = nn.ModuleList([
+             AudioConvResblock(c2 * 2, c2, t_dim, kernel_size),
+             AudioConvResblock(c2 * 2, c2, t_dim, kernel_size),
+             AudioConvResblock(c2 * 2, c2, t_dim, kernel_size),
+             AudioConvResblock(c2 + c1, c2, t_dim, kernel_size),
+             AudioUpsample(c2, t_dim),
+         ])
+         up_1 = nn.ModuleList([
+             AudioConvResblock(c2 + c1, c1, t_dim, kernel_size),
+             AudioConvResblock(c1 * 2, c1, t_dim, kernel_size),
+             AudioConvResblock(c1 * 2, c1, t_dim, kernel_size),
+             AudioConvResblock(c0 + c1, c1, t_dim, kernel_size),
+             AudioUpsample(c1, t_dim),
+         ])
+         up_0 = nn.ModuleList([
+             AudioConvResblock(c0 + c1, c0, t_dim, kernel_size),
+             AudioConvResblock(c0 * 2, c0, t_dim, kernel_size),
+             AudioConvResblock(c0 * 2, c0, t_dim, kernel_size),
+             AudioConvResblock(c0 * 2, c0, t_dim, kernel_size),
+         ])
+         self.up = nn.ModuleList([up_0, up_1, up_2, up_3])
+
+         # Output layer
+         self.output = AudioUnembedding(in_channels=c0, out_channels=in_channels)
+
+     def get_last_layer_weight(self):
+         return self.output.f.weight
+
+     def condition_with_latents(self, x, z_dec):
+         """
+         Add latent conditioning to audio features
+
+         Args:
+             x: [batch, c0, audio_samples] - audio features
+             z_dec: [batch, 64, n_frames] - latent conditioning
+
+         Returns:
+             x: [batch, c0, audio_samples] - conditioned features
+         """
+         if z_dec is None:
+             return x
+
+         # Project latents to same channel dimension as audio features
+         z_proj = self.z_dec_proj(z_dec)  # [batch, c0, n_frames]
+
+         # Interpolate latents to match audio length
+         if z_proj.shape[-1] != x.shape[-1]:
+             z_proj = F.interpolate(
+                 z_proj,
+                 size=x.shape[-1],
+                 mode='nearest'  # or 'linear' for smoother interpolation
+             )
+
+         # Add latent conditioning to audio features
+         return x + z_proj
+
+     def forward(self, x, t=None, z_dec=None) -> torch.Tensor:
+         """
+         Forward pass
+
+         Args:
+             x: [batch, 1, samples] - audio waveform (any length)
+             t: [batch] - diffusion timesteps
+             z_dec: [batch, 64, n_frames] - latent conditioning (any length)
+         """
+         # Embed audio input
+         x = self.embed_audio(x)  # [batch, c0, samples]
+
+         # Add latent conditioning
+         if z_dec is not None:
+             x = self.condition_with_latents(x, z_dec)
+
+         # Embed timestep
+         if t is None:
+             t = torch.zeros(x.shape[0], device=x.device)
+         t = self.embed_time(t)  # [batch, t_dim]
+
+         # Downsampling with skip connections
+         skips = [x]
+         for down in self.down:
+             for block in down:
+                 x = block(x, t)
+                 skips.append(x)
+
+         # Middle layers
+         for mid in self.mid:
+             x = mid(x, t)
+
+         # Upsampling with skip connections
+         for up in self.up[::-1]:
+             for block in up:
+                 if isinstance(block, AudioConvResblock):
+                     x = torch.cat([x, skips.pop()], dim=1)
+                 x = block(x, t)
+
+         # Output
+         return self.output(x)
+
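The UNet above preserves the input length as long as the sample count is divisible by the total downsampling factor (three AudioDownsample stages, so 2^3 = 8). A minimal shape check (illustrative only, not part of the diff; the batch size, clip length, and latent frame count are made-up example values, and the class is assumed importable alongside the usual `register` and `torch.nn.functional as F` imports at the top of the file):

import torch

unet = AudioDiffusionUNet(in_channels=1, z_dec_channels=64)
x = torch.randn(2, 1, 24000)     # 1 s of 24 kHz mono audio; 24000 is divisible by 8
t = torch.rand(2)                # one diffusion timestep per batch item
z_dec = torch.randn(2, 64, 75)   # latent frames; interpolated to the audio length inside the model
with torch.no_grad():
    out = unet(x, t=t, z_dec=z_dec)
assert out.shape == x.shape      # expected: torch.Size([2, 1, 24000])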
flowae/models/networks/consistency_decoder_unet.py CHANGED
@@ -239,6 +239,7 @@ class ConsistencyDecoderUNet(nn.Module):
 
     def forward(self, x, t=None, z_dec=None) -> torch.Tensor:
         if z_dec is not None:
+            print('shape of x and z_dec: ', x.shape, z_dec.shape)
             if z_dec.shape[-2] != x.shape[-2] or z_dec.shape[-1] != x.shape[-1]:
                 assert x.shape[-2] // z_dec.shape[-2] == x.shape[-1] // z_dec.shape[-1]
                 z_dec = F.upsample_nearest(z_dec, scale_factor=x.shape[-2] // z_dec.shape[-2])
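For context, this forward only supports integer, isotropic scale factors between z_dec and x, and the new print exposes that relationship at runtime. A small illustration (not part of the diff; the tensor sizes are invented, and F.interpolate is used as the non-deprecated equivalent of F.upsample_nearest):

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 256, 256)    # decoder input at full resolution
z_dec = torch.randn(1, 4, 32, 32)  # latent grid, 8x smaller on both axes
assert x.shape[-2] // z_dec.shape[-2] == x.shape[-1] // z_dec.shape[-1]
z_dec = F.interpolate(z_dec, scale_factor=x.shape[-2] // z_dec.shape[-2], mode='nearest')
print(z_dec.shape)                 # torch.Size([1, 4, 256, 256])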
flowae/run.sh ADDED
@@ -0,0 +1,2 @@
+ torchrun --nnodes=1 --nproc-per-node=1 run.py --config configs/experiments/dito-B-f8c4-noise-sync.yaml --save-root /mnt/nvme/dito
+ torchrun --nnodes=1 --nproc-per-node=1 run.py --config configs/experiments/dito-B-audio.yaml --save-root /mnt/nvme/dito
flowae/upload.sh ADDED
@@ -0,0 +1,2 @@
+ az storage blob upload-batch \
+     --connection-string ""