Commit 2279ae0 by primepake
Parent: 997d9c0

add reconstruction for audio
Changed files:
- flowae/audio_dito_inference.py +332 -0
- flowae/configs/datasets/dae.yaml +12 -21
- flowae/configs/experiments/dito-B-audio.yaml +10 -6
- flowae/datasets/__init__.py +2 -2
- flowae/datasets/class_folder.py +2 -0
- flowae/datasets/class_folder_audio.py +196 -0
- flowae/datasets/wrapper_audio_cae.py +89 -0
- flowae/datasets/wrapper_cae.py +1 -193
- flowae/{reconstruction.py → image_dito_inference.py} +0 -0
- flowae/models/diffusion/fm.py +24 -6
- flowae/models/ldm/dac/layers.py +1 -1
- flowae/models/ldm/dac/model.py +3 -1
- flowae/models/ldm/dac/utils.py +11 -11
- flowae/models/ldm/dito.py +142 -1
- flowae/models/ldm/ldm_base.py +224 -0
- flowae/models/networks/__init__.py +2 -1
- flowae/models/networks/consistency_audio_decoder_unet.py +322 -0
- flowae/models/networks/consistency_decoder_unet.py +1 -0
- flowae/run.sh +2 -0
- flowae/upload.sh +2 -0
flowae/audio_dito_inference.py
ADDED
@@ -0,0 +1,332 @@
import torch
import torch.nn as nn
import torchaudio
import numpy as np
from pathlib import Path
import argparse
import soundfile as sf
from omegaconf import OmegaConf
import matplotlib.pyplot as plt

# Import models
import models
from models.ldm.dac.audiotools import AudioSignal


class AudioDiToInference:
    def __init__(self, checkpoint_path, device='cuda'):
        """Initialize Audio DiTo model from checkpoint"""
        self.device = device

        # Load checkpoint
        print(f"Loading checkpoint from {checkpoint_path}")
        ckpt = torch.load(checkpoint_path, map_location='cpu')

        # Extract config
        self.config = OmegaConf.create(ckpt['config'])

        # Create model
        self.model = models.make(self.config['model'])

        # Load state dict
        self.model.load_state_dict(ckpt['model']['sd'])

        # Move to device and set to eval
        self.model = self.model.to(device)
        self.model.eval()

        # Get audio parameters from config
        self.sample_rate = self.config.get('sample_rate', 24000)
        self.mono = self.config.get('mono', True)

        print(f"Model loaded successfully!")
        print(f"Sample rate: {self.sample_rate} Hz")
        print(f"Mono: {self.mono}")

    def load_audio(self, audio_path, duration=None, offset=0.0):
        """Load audio file using AudioSignal

        Args:
            audio_path: Path to audio file
            duration: Duration in seconds (None for full audio)
            offset: Start offset in seconds
        """
        # Load audio using AudioSignal
        if duration is not None:
            signal = AudioSignal(
                str(audio_path),
                duration=duration,
                offset=offset,
            )
        else:
            # Load full audio
            signal = AudioSignal(str(audio_path))

        # Convert to mono if needed
        if self.mono and signal.num_channels > 1:
            signal = signal.to_mono()

        # Resample to model sample rate
        if signal.sample_rate != self.sample_rate:
            signal = signal.resample(self.sample_rate)

        # Normalize
        signal = signal.normalize()

        # Clamp to [-1, 1]
        signal.audio_data = signal.audio_data.clamp(-1.0, 1.0)

        return signal

    def save_audio(self, reconstructed, output_path):
        """Save AudioSignal to file"""
        # Get audio data
        print('shape of reconstructed: ', reconstructed.shape)
        sf.write(output_path, reconstructed, self.sample_rate)
        print(f"Saved audio to {output_path}")

    def reconstruct_audio(self, audio_path, num_steps=50, save_latent=False):
        """Reconstruct entire audio file at once

        Args:
            audio_path: Path to audio file
            num_steps: Number of diffusion steps
            save_latent: Whether to return the latent representation
        """
        # Load full audio without duration limit
        signal = self.load_audio(audio_path, duration=None, offset=0.0)

        # Get audio tensor
        audio_tensor = signal.audio_data  # [channels, samples]
        if audio_tensor.dim() == 2:
            audio_tensor = audio_tensor.squeeze(0)  # [samples] for mono

        # Add batch dimension
        audio_tensor = audio_tensor.to(self.device)  # [1, samples]

        print(f"Input shape: {audio_tensor.shape}")
        print(f"Full audio duration: {audio_tensor.shape[-1] / self.sample_rate:.2f}s")

        with torch.no_grad():
            # Prepare data dict
            data = {'inp': audio_tensor}

            # Step 1: Encode to latent
            print('shape of audio_tensor: ', audio_tensor.shape)
            z = self.model.encode(audio_tensor)
            print(f"Latent shape: {z.shape}")

            # Step 2: Decode latent (if model has separate decode step)
            if hasattr(self.model, 'decode'):
                z_dec = self.model.decode(z)
            else:
                z_dec = z
            print(f"Decoded latent shape: {z_dec.shape}")

            # Step 3: Prepare dummy coordinates (based on training code)
            b, *_ = audio_tensor.shape

            # Step 4: Render using diffusion
            if hasattr(self.model, 'render'):
                # Render expects z_dec, coord, scale
                print('using render diffusion model')
                reconstructed = self.model.render(z_dec)
            else:
                # Alternative: direct decode if render not available
                reconstructed = self.model(data, mode='pred')

            # Remove batch dimension
            reconstructed = reconstructed.squeeze(0).squeeze(0).cpu().numpy()  # [samples]

            print('shape of reconstructed: ', reconstructed.shape)

        if save_latent:
            return reconstructed, z.cpu()
        else:
            return reconstructed

    def save_reconstruction(self, audio_path, output_path, num_steps=50):
        """Reconstruct and save entire audio file"""
        reconstructed = self.reconstruct_audio(audio_path, num_steps)
        self.save_audio(reconstructed, output_path)

    def compare_reconstruction(self, audio_path, output_path, num_steps=50):
        """Save original and reconstruction concatenated"""
        # Load original full audio
        original = self.load_audio(audio_path, duration=None, offset=0.0)

        # Get reconstruction of full audio
        reconstructed = self.reconstruct_audio(audio_path, num_steps)

        # Add 0.5 second silence between clips
        silence_samples = int(0.5 * self.sample_rate)
        silence_data = torch.zeros(1, silence_samples)

        # Concatenate: original -> silence -> reconstruction
        concat_data = torch.cat([
            original.audio_data.cpu(),
            silence_data,
            reconstructed.audio_data.cpu()
        ], dim=1)

        # Create concatenated signal
        comparison = AudioSignal(
            concat_data,
            sample_rate=self.sample_rate
        )

        self.save_audio(comparison, output_path)
        print(f"Saved comparison (original + reconstruction) to {output_path}")

    def visualize_latent(self, audio_path, output_path):
        """Visualize the latent representation of full audio"""
        # Get latent
        _, z = self.reconstruct_audio(audio_path, save_latent=True)

        z_np = z.squeeze(0).numpy()  # Remove batch dimension

        # Create visualization
        if z_np.ndim == 2:  # [channels, frames]
            n_channels = z_np.shape[0]
            fig, axes = plt.subplots(n_channels, 1, figsize=(12, 2*n_channels))

            if n_channels == 1:
                axes = [axes]

            for i in range(n_channels):
                im = axes[i].imshow(
                    z_np[i:i+1],
                    aspect='auto',
                    cmap='coolwarm',
                    interpolation='nearest'
                )
                axes[i].set_title(f'Latent Channel {i+1}')
                axes[i].set_xlabel('Time Frames')
                axes[i].set_ylabel('Feature')
                plt.colorbar(im, ax=axes[i])
        else:  # 1D latent
            plt.figure(figsize=(12, 4))
            plt.plot(z_np.T)
            plt.title('Latent Representation')
            plt.xlabel('Time Frames')
            plt.ylabel('Value')

        plt.tight_layout()
        plt.savefig(output_path, dpi=150)
        plt.close()

        print(f"Saved latent visualization to {output_path}")

    def batch_reconstruct(self, audio_folder, output_folder, max_files=None, num_steps=50):
        """Reconstruct all audio files in a folder (full audio)"""
        audio_folder = Path(audio_folder)
        output_folder = Path(output_folder)
        output_folder.mkdir(exist_ok=True, parents=True)

        # Get all audio files
        audio_extensions = ['.wav', '.mp3', '.flac', '.m4a', '.ogg']
        audio_paths = []
        for ext in audio_extensions:
            audio_paths.extend(audio_folder.glob(f'*{ext}'))
            audio_paths.extend(audio_folder.glob(f'*{ext.upper()}'))

        if max_files:
            audio_paths = audio_paths[:max_files]

        print(f"Processing {len(audio_paths)} audio files...")

        for audio_path in audio_paths:
            output_path = output_folder / f"recon_{audio_path.stem}.wav"
            try:
                self.save_reconstruction(
                    str(audio_path), str(output_path),
                    num_steps=num_steps
                )
            except Exception as e:
                print(f"Error processing {audio_path}: {e}")
                continue

        print("Batch reconstruction complete!")


def main():
    parser = argparse.ArgumentParser(description='Audio DiTo Inference')
    parser.add_argument('--checkpoint', type=str, required=True,
                        help='Path to Audio DiTo checkpoint')
    parser.add_argument('--input', type=str, required=True,
                        help='Input audio path or folder')
    parser.add_argument('--output', type=str, required=True,
                        help='Output path')
    parser.add_argument('--compare', action='store_true',
                        help='Save comparison with original')
    parser.add_argument('--batch', action='store_true',
                        help='Process entire folder')
    parser.add_argument('--visualize', action='store_true',
                        help='Visualize latent representation')
    parser.add_argument('--steps', type=int, default=50,
                        help='Number of diffusion steps')
    parser.add_argument('--device', type=str, default='cuda',
                        help='Device to use (cuda/cpu)')
    parser.add_argument('--max-files', type=int, default=None,
                        help='Maximum files to process in batch mode')

    args = parser.parse_args()

    # Initialize model
    audio_dito = AudioDiToInference(args.checkpoint, device=args.device)

    # Process based on mode
    if args.batch:
        # Batch processing
        audio_dito.batch_reconstruct(
            args.input, args.output,
            max_files=args.max_files,
            num_steps=args.steps
        )
    elif args.visualize:
        # Visualize latent
        audio_dito.visualize_latent(
            args.input, args.output
        )
    elif args.compare:
        # Save comparison
        audio_dito.compare_reconstruction(
            args.input, args.output,
            num_steps=args.steps
        )
    else:
        # Single reconstruction
        audio_dito.save_reconstruction(
            args.input, args.output,
            num_steps=args.steps
        )


# Example usage function for direct Python use
def reconstruct_single_audio(checkpoint_path, audio_path, output_path):
    """Simple function to reconstruct a single audio file"""
    audio_dito = AudioDiToInference(checkpoint_path)
    audio_dito.save_reconstruction(audio_path, output_path)


if __name__ == "__main__":
    main()


# Usage examples:
# 1. Single audio reconstruction (full audio):
# python audio_dito_inference.py --checkpoint ckpt-best.pth --input audio.wav --output recon.wav
#
# 2. Save comparison (original + reconstruction):
# python audio_dito_inference.py --checkpoint ckpt-best.pth --input audio.wav --output compare.wav --compare
#
# 3. Batch processing (reconstruct all audio files in folder):
# python audio_dito_inference.py --checkpoint ckpt-best.pth --input audio_folder/ --output output_folder/ --batch
#
# 4. Visualize latent representation:
# python audio_dito_inference.py --checkpoint ckpt-best.pth --input audio.wav --output latent.png --visualize
#
# 5. Use fewer diffusion steps for faster inference:
# python audio_dito_inference.py --checkpoint ckpt-best.pth --input audio.wav --output recon.wav --steps 25
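For quick experiments outside the CLI, the same class can also be driven directly from Python. A minimal sketch, assuming a trained checkpoint at ./save/dito-B-audio/ckpt-best.pth and an input clip sample.wav (both paths are placeholders, not files shipped with this commit):

from audio_dito_inference import AudioDiToInference

# Build the model once, then reuse it for the different utilities.
dito = AudioDiToInference('./save/dito-B-audio/ckpt-best.pth', device='cuda')

# Plain reconstruction written to disk.
dito.save_reconstruction('sample.wav', 'sample_recon.wav', num_steps=50)

# Original + 0.5 s silence + reconstruction in one file, for A/B listening.
dito.compare_reconstruction('sample.wav', 'sample_compare.wav', num_steps=50)

# Heatmap of the latent sequence produced by the encoder.
dito.visualize_latent('sample.wav', 'sample_latent.png')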
flowae/configs/datasets/dae.yaml
CHANGED
@@ -4,22 +4,19 @@ datasets:
     name: wrapper_audio_cae
     args:
       dataset:
-        name:
+        name: class_folder_audio
         args:
-
-          Emilia_EN: ["/home/masuser/minimax-audio/dataset/Emilia/EN"]
+          root_path: "/home/masuser/minimax-audio/dataset/Emilia/EN"
           sample_rate: 24000
           duration: 0.38
-          n_examples: 10000000
           shuffle: true
-
+          num_channels: 1
       sample_rate: 24000
       duration: 0.38
       mono: true
       normalize: true
-      return_coords: true
     loader:
-      batch_size:
+      batch_size: 52
       num_workers: 8
       drop_last: true
 
@@ -27,20 +24,17 @@ datasets:
     name: wrapper_audio_cae
     args:
       dataset:
-        name:
+        name: class_folder_audio
         args:
-
-          Emilia_EN: ["/home/masuser/minimax-audio/dataset/libritts"]
+          root_path: "/home/masuser/minimax-audio/dataset/libritts"
           sample_rate: 24000
           duration: 5.0
-          n_examples: 100
           shuffle: false
-
+          num_channels: 1
       sample_rate: 24000
       duration: 5.0
       mono: true
       normalize: true
-      return_coords: true
     loader:
       batch_size: 4
       num_workers: 8
@@ -50,20 +44,17 @@ datasets:
     name: wrapper_audio_cae
     args:
       dataset:
-        name:
+        name: class_folder_audio
        args:
-
-          Emilia_EN: ["/home/masuser/minimax-audio/dataset/libritts"]
+          root_path: "/home/masuser/minimax-audio/dataset/libritts"
           sample_rate: 24000
-          duration:
-          n_examples: 1000
+          duration: 5.0
           shuffle: false
-
+          num_channels: 1
       sample_rate: 24000
-      duration:
+      duration: 5.0
       mono: true
       normalize: true
-      return_coords: true
     loader:
       batch_size: 1
       num_workers: 8
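The wrapper resolves its nested dataset through the repo's registry, so the name/args blocks above map one-to-one onto constructor calls. A rough sketch of what the training harness is assumed to do with the train entry (the driver code itself is not part of this commit; run from inside flowae/ so that the local datasets package is imported):

import datasets  # flowae's local package, exposes register/make
from torch.utils.data import DataLoader

spec = {
    'name': 'wrapper_audio_cae',
    'args': {
        'dataset': {'name': 'class_folder_audio',
                    'args': {'root_path': '/home/masuser/minimax-audio/dataset/Emilia/EN',
                             'sample_rate': 24000, 'duration': 0.38,
                             'shuffle': True, 'num_channels': 1}},
        'sample_rate': 24000, 'duration': 0.38, 'mono': True, 'normalize': True,
    },
}
train_set = datasets.make(spec)  # the wrapper builds the inner AudioFolder itself
loader = DataLoader(train_set, batch_size=52, num_workers=8, drop_last=True)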
flowae/configs/experiments/dito-B-audio.yaml
CHANGED
@@ -8,12 +8,16 @@ model:
   # Encoder
   encoder:
     name: dac_encoder
-    args: {config_name:
+    args: {config_name: snake}
 
   # Latent configuration - now fully convolutional
   z_channels: 64 # Number of latent channels
-
-
+
+  zaug_p: 0.1
+  zaug_decoding_loss_type: suffix
+  zaug_zdm_diffusion:
+    name: fm
+    args: {timescale: 1000.0}
 
   # Decoder (identity for DiTo)
   decoder:
@@ -21,10 +25,10 @@ model:
 
   # Renderer - Fully convolutional for dynamic duration
   renderer:
-    name:
+    name: fixres_renderer_wrapper
     args:
       net:
-        name:
+        name: audio_diffusion_unet
         args:
           in_channels: 1
           z_dec_channels: 64
@@ -39,6 +43,6 @@ model:
     name: fm
     args: {timescale: 1000.0}
 
-  render_sampler: {name:
+  render_sampler: {name: fm_euler_sampler_audio}
   render_n_steps: 50
 
flowae/datasets/__init__.py
CHANGED
@@ -1,3 +1,3 @@
 from .datasets import register, make
-from . import image_folder, class_folder, webdataset
-from . import wrapper_cae
+from . import image_folder, class_folder, webdataset, class_folder_audio
+from . import wrapper_cae, wrapper_audio_cae
flowae/datasets/class_folder.py
CHANGED
@@ -6,6 +6,8 @@ from datasets import register
 from torch.utils.data import Dataset
 from torchvision import transforms
 
+import os
+import random
 
 Image.MAX_IMAGE_PIXELS = 933120000
 ImageFile.LOAD_TRUNCATED_IMAGES = True
flowae/datasets/class_folder_audio.py
ADDED
|
@@ -0,0 +1,196 @@
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
from PIL import Image, ImageFile
|
| 4 |
+
|
| 5 |
+
from datasets import register
|
| 6 |
+
from torch.utils.data import Dataset
|
| 7 |
+
from torchvision import transforms
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import random
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import Optional, Callable
|
| 13 |
+
|
| 14 |
+
from models.ldm.dac.audiotools import AudioSignal
|
| 15 |
+
from models.ldm.dac.audiotools.core import util
|
| 16 |
+
# Audio file extensions (from audiotools)
|
| 17 |
+
AUDIO_EXTS = ('.wav', '.WAV', '.flac', '.FLAC', '.mp3', '.MP3', '.mp4', '.MP4', '.m4a', '.M4A')
|
| 18 |
+
|
| 19 |
+
@register('class_folder_audio')
|
| 20 |
+
class AudioFolder(Dataset):
|
| 21 |
+
"""
|
| 22 |
+
Audio dataset that loads audio files from a folder structure.
|
| 23 |
+
Similar to ClassFolder but for audio files.
|
| 24 |
+
|
| 25 |
+
Expected folder structure:
|
| 26 |
+
root_path/
|
| 27 |
+
├── class1/
|
| 28 |
+
│ ├── audio1.wav
|
| 29 |
+
│ ├── audio2.wav
|
| 30 |
+
│ └── ...
|
| 31 |
+
├── class2/
|
| 32 |
+
│ ├── audio1.wav
|
| 33 |
+
│ └── ...
|
| 34 |
+
└── ...
|
| 35 |
+
|
| 36 |
+
Or for single class (no subfolders):
|
| 37 |
+
root_path/
|
| 38 |
+
├── audio1.wav
|
| 39 |
+
├── audio2.wav
|
| 40 |
+
└── ...
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
def __init__(
|
| 44 |
+
self,
|
| 45 |
+
root_path: str,
|
| 46 |
+
sample_rate: int = 24000,
|
| 47 |
+
duration: float = 2.0,
|
| 48 |
+
num_channels: int = 1,
|
| 49 |
+
random_crop: bool = True,
|
| 50 |
+
loudness_cutoff: float = -40,
|
| 51 |
+
audio_only: bool = False,
|
| 52 |
+
drop_label_p: float = 0.0,
|
| 53 |
+
shuffle: bool = True,
|
| 54 |
+
shuffle_state: int = 0,
|
| 55 |
+
transform: Optional[Callable] = None,
|
| 56 |
+
normalize: bool = True,
|
| 57 |
+
trim_silence: bool = False,
|
| 58 |
+
):
|
| 59 |
+
"""
|
| 60 |
+
Args:
|
| 61 |
+
root_path: Path to audio files
|
| 62 |
+
sample_rate: Target sample rate for audio
|
| 63 |
+
duration: Duration in seconds for audio clips
|
| 64 |
+
num_channels: Number of channels (1 for mono, 2 for stereo)
|
| 65 |
+
random_crop: Whether to randomly crop audio (vs deterministic)
|
| 66 |
+
loudness_cutoff: Minimum loudness threshold for audio selection
|
| 67 |
+
audio_only: If True, return only audio signal. If False, return dict with labels
|
| 68 |
+
drop_label_p: Probability of dropping labels (for unconditional training)
|
| 69 |
+
shuffle: Whether to shuffle files
|
| 70 |
+
shuffle_state: Random state for shuffling
|
| 71 |
+
transform: Additional audio transforms
|
| 72 |
+
normalize: Whether to normalize audio amplitude
|
| 73 |
+
trim_silence: Whether to trim silence from audio
|
| 74 |
+
"""
|
| 75 |
+
self.root_path = root_path
|
| 76 |
+
self.sample_rate = sample_rate
|
| 77 |
+
self.duration = duration
|
| 78 |
+
self.num_channels = num_channels
|
| 79 |
+
self.random_crop = random_crop
|
| 80 |
+
self.loudness_cutoff = loudness_cutoff
|
| 81 |
+
self.audio_only = audio_only
|
| 82 |
+
self.drop_label_p = drop_label_p
|
| 83 |
+
self.transform = transform
|
| 84 |
+
self.normalize = normalize
|
| 85 |
+
self.trim_silence = trim_silence
|
| 86 |
+
|
| 87 |
+
print(f'Audio root_path: {root_path}')
|
| 88 |
+
|
| 89 |
+
# Find audio files and labels
|
| 90 |
+
self.files = []
|
| 91 |
+
|
| 92 |
+
# Fin all audio in recursive in root_path
|
| 93 |
+
for root, dirs, files in os.walk(self.root_path):
|
| 94 |
+
for file in files:
|
| 95 |
+
if file.lower().endswith(AUDIO_EXTS):
|
| 96 |
+
self.files.append(os.path.join(root, file))
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
print(f'Found {len(self.files)} audio files')
|
| 100 |
+
|
| 101 |
+
# Shuffle files if requested
|
| 102 |
+
if shuffle:
|
| 103 |
+
state = util.random_state(shuffle_state)
|
| 104 |
+
combined = self.files
|
| 105 |
+
state.shuffle(combined)
|
| 106 |
+
self.files = combined
|
| 107 |
+
|
| 108 |
+
def __len__(self):
|
| 109 |
+
return len(self.files)
|
| 110 |
+
|
| 111 |
+
def __getitem__(self, idx):
|
| 112 |
+
try:
|
| 113 |
+
file_path = self.files[idx]
|
| 114 |
+
|
| 115 |
+
# Load audio using AudioSignal
|
| 116 |
+
if self.random_crop:
|
| 117 |
+
# Use salient excerpt for random cropping with loudness filtering
|
| 118 |
+
signal = AudioSignal.salient_excerpt(
|
| 119 |
+
str(file_path),
|
| 120 |
+
duration=self.duration,
|
| 121 |
+
loudness_cutoff=self.loudness_cutoff,
|
| 122 |
+
)
|
| 123 |
+
else:
|
| 124 |
+
# Load from beginning or deterministic offset
|
| 125 |
+
signal = AudioSignal(
|
| 126 |
+
str(file_path),
|
| 127 |
+
duration=self.duration,
|
| 128 |
+
offset=0.0,
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
# Convert to mono/stereo as needed
|
| 132 |
+
if self.num_channels == 1:
|
| 133 |
+
signal = signal.to_mono()
|
| 134 |
+
|
| 135 |
+
# Resample to target sample rate
|
| 136 |
+
signal = signal.resample(self.sample_rate)
|
| 137 |
+
|
| 138 |
+
# Ensure duration by padding or trimming
|
| 139 |
+
target_samples = int(self.duration * self.sample_rate)
|
| 140 |
+
if signal.length < target_samples:
|
| 141 |
+
signal = signal.zero_pad_to(target_samples)
|
| 142 |
+
elif signal.length > target_samples:
|
| 143 |
+
signal = signal.truncate_samples(target_samples)
|
| 144 |
+
|
| 145 |
+
# Optional audio processing
|
| 146 |
+
if self.trim_silence:
|
| 147 |
+
signal = signal.trim_silence()
|
| 148 |
+
# Re-pad if trimming made it too short
|
| 149 |
+
if signal.length < target_samples:
|
| 150 |
+
signal = signal.zero_pad_to(target_samples)
|
| 151 |
+
|
| 152 |
+
if self.normalize:
|
| 153 |
+
signal = signal.normalize()
|
| 154 |
+
|
| 155 |
+
# Clamp audio to [-1, 1] range
|
| 156 |
+
signal.audio_data = signal.audio_data.clamp(-1.0, 1.0)
|
| 157 |
+
|
| 158 |
+
# Apply additional transforms if provided
|
| 159 |
+
if self.transform is not None:
|
| 160 |
+
# Create a random state for transforms
|
| 161 |
+
state = util.random_state(idx)
|
| 162 |
+
transform_args = self.transform.instantiate(state, signal=signal)
|
| 163 |
+
signal = self.transform(signal, **transform_args)
|
| 164 |
+
|
| 165 |
+
# print('before process: ', signal.audio_data.shape)
|
| 166 |
+
# Store metadata
|
| 167 |
+
signal.metadata.update(
|
| 168 |
+
{
|
| 169 |
+
'file_path': str(file_path),
|
| 170 |
+
'original_sr': signal.sample_rate,
|
| 171 |
+
'duration': self.duration,
|
| 172 |
+
}
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
if self.audio_only:
|
| 176 |
+
return signal
|
| 177 |
+
else:
|
| 178 |
+
return {
|
| 179 |
+
'signal': signal,
|
| 180 |
+
'file_path': str(file_path),
|
| 181 |
+
'idx': idx,
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
except Exception as e:
|
| 185 |
+
print(f'Error loading audio file {self.files[idx]}: {e}')
|
| 186 |
+
# Return next file on error to avoid crashing training
|
| 187 |
+
return self.__getitem__((idx + 1) % len(self))
|
| 188 |
+
|
| 189 |
+
def collate(self, batch):
|
| 190 |
+
"""Collate function for DataLoader"""
|
| 191 |
+
if self.audio_only:
|
| 192 |
+
# Batch AudioSignals
|
| 193 |
+
return AudioSignal.batch(batch)
|
| 194 |
+
else:
|
| 195 |
+
# Collate dictionary batch
|
| 196 |
+
return util.collate(batch)
|
flowae/datasets/wrapper_audio_cae.py
ADDED
@@ -0,0 +1,89 @@
import random
from PIL import Image

import torch
from torch.utils.data import Dataset, IterableDataset

from datasets import register
import datasets

class BaseWrapperAudioCAE:
    """Base wrapper for audio Convolutional Autoencoder (CAE) training.

    Similar to the image wrapper, but for audio data.
    """

    def __init__(
        self,
        dataset,
        sample_rate=24000,
        duration=0.38,  # Duration in seconds
        n_samples=None,  # Alternative: specify exact number of samples
        return_gt=True,
        gt_sample_rate=None,  # Ground truth sample rate (if different)
        mono=True,
        normalize=True,
        return_coords=True,  # Whether to return coordinate grids
    ):
        self.dataset = datasets.make(dataset)
        self.sample_rate = sample_rate
        self.duration = duration
        self.n_samples = int(duration * sample_rate)
        self.return_gt = return_gt
        self.gt_sample_rate = gt_sample_rate or sample_rate
        self.mono = mono
        self.normalize = normalize
        self.return_coords = return_coords

    def process(self, audio_data):
        """Process audio data for DiTo training.

        Args:
            audio_data: Dictionary with 'signal' key containing AudioSignal
                or AudioSignal directly
        """
        ret = {}

        # Extract AudioSignal
        if isinstance(audio_data, dict):
            signal = audio_data['signal']
        else:
            signal = audio_data

        # Normalize audio
        audio_tensor = signal.audio_data  # Shape: [channels, samples]

        audio_tensor = audio_tensor.squeeze(0)

        # Create input tensor
        ret['inp'] = audio_tensor

        if not self.return_gt:
            return ret

        ret['gt'] = audio_tensor
        # print('audio_tensor shape: ', audio_tensor.shape)

        return ret


@register('wrapper_audio_cae')
class WrapperAudioCAE(BaseWrapperAudioCAE, Dataset):
    """Dataset wrapper for audio CAE training."""

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        data = self.dataset[idx]
        return self.process(data)


@register('wrapper_audio_cae_iterable')
class WrapperAudioCAEIterable(BaseWrapperAudioCAE, IterableDataset):
    """Iterable dataset wrapper for audio CAE training."""

    def __iter__(self):
        for data in self.dataset:
            yield self.process(data)
flowae/datasets/wrapper_cae.py
CHANGED
|
@@ -113,196 +113,4 @@ class WrapperCAE(BaseWrapperCAE, IterableDataset):
|
|
| 113 |
ret.update(data)
|
| 114 |
yield ret
|
| 115 |
else:
|
| 116 |
-
yield self.process(data)
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
class BaseWrapperAudioCAE:
|
| 124 |
-
"""Base wrapper for audio Convolutional Autoencoder (CAE) training.
|
| 125 |
-
|
| 126 |
-
Similar to the image wrapper, but for audio data.
|
| 127 |
-
"""
|
| 128 |
-
|
| 129 |
-
def __init__(
|
| 130 |
-
self,
|
| 131 |
-
dataset,
|
| 132 |
-
sample_rate=24000,
|
| 133 |
-
duration=0.38, # Duration in seconds
|
| 134 |
-
n_samples=None, # Alternative: specify exact number of samples
|
| 135 |
-
return_gt=True,
|
| 136 |
-
gt_sample_rate=None, # Ground truth sample rate (if different)
|
| 137 |
-
mono=True,
|
| 138 |
-
normalize=True,
|
| 139 |
-
return_coords=True, # Whether to return coordinate grids
|
| 140 |
-
):
|
| 141 |
-
self.dataset = dataset
|
| 142 |
-
self.sample_rate = sample_rate
|
| 143 |
-
self.duration = duration
|
| 144 |
-
self.n_samples = n_samples or int(duration * sample_rate)
|
| 145 |
-
self.return_gt = return_gt
|
| 146 |
-
self.gt_sample_rate = gt_sample_rate or sample_rate
|
| 147 |
-
self.mono = mono
|
| 148 |
-
self.normalize = normalize
|
| 149 |
-
self.return_coords = return_coords
|
| 150 |
-
|
| 151 |
-
def process(self, audio_data):
|
| 152 |
-
"""Process audio data for DiTo training.
|
| 153 |
-
|
| 154 |
-
Args:
|
| 155 |
-
audio_data: Dictionary with 'signal' key containing AudioSignal
|
| 156 |
-
or AudioSignal directly
|
| 157 |
-
"""
|
| 158 |
-
ret = {}
|
| 159 |
-
|
| 160 |
-
# Extract AudioSignal
|
| 161 |
-
if isinstance(audio_data, dict):
|
| 162 |
-
signal = audio_data['signal']
|
| 163 |
-
else:
|
| 164 |
-
signal = audio_data
|
| 165 |
-
|
| 166 |
-
# Convert to mono if needed
|
| 167 |
-
if self.mono and signal.num_channels > 1:
|
| 168 |
-
signal = signal.to_mono()
|
| 169 |
-
|
| 170 |
-
# Resample to target sample rate
|
| 171 |
-
if signal.sample_rate != self.sample_rate:
|
| 172 |
-
signal = signal.resample(self.sample_rate)
|
| 173 |
-
|
| 174 |
-
# Extract fixed duration
|
| 175 |
-
if signal.duration < self.duration:
|
| 176 |
-
# Pad if too short
|
| 177 |
-
signal = signal.zero_pad_to(self.n_samples)
|
| 178 |
-
else:
|
| 179 |
-
# Take random excerpt if too long
|
| 180 |
-
max_start = signal.num_samples - self.n_samples
|
| 181 |
-
if max_start > 0:
|
| 182 |
-
start_idx = random.randint(0, max_start)
|
| 183 |
-
signal = signal[..., start_idx:start_idx + self.n_samples]
|
| 184 |
-
else:
|
| 185 |
-
signal = signal[..., :self.n_samples]
|
| 186 |
-
|
| 187 |
-
# Normalize audio
|
| 188 |
-
audio_tensor = signal.audio_data # Shape: [channels, samples]
|
| 189 |
-
if self.normalize:
|
| 190 |
-
# Normalize to [-1, 1]
|
| 191 |
-
max_val = audio_tensor.abs().max()
|
| 192 |
-
if max_val > 0:
|
| 193 |
-
audio_tensor = audio_tensor / max_val
|
| 194 |
-
|
| 195 |
-
# Create input tensor
|
| 196 |
-
ret['inp'] = audio_tensor
|
| 197 |
-
|
| 198 |
-
if not self.return_gt:
|
| 199 |
-
return ret
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
ret['gt'] = audio_tensor
|
| 203 |
-
|
| 204 |
-
return ret
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
@register('wrapper_audio_cae')
|
| 208 |
-
class WrapperAudioCAE(BaseWrapperAudioCAE, Dataset):
|
| 209 |
-
"""Dataset wrapper for audio CAE training."""
|
| 210 |
-
|
| 211 |
-
def __len__(self):
|
| 212 |
-
return len(self.dataset)
|
| 213 |
-
|
| 214 |
-
def __getitem__(self, idx):
|
| 215 |
-
data = self.dataset[idx]
|
| 216 |
-
return self.process(data)
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
@register('wrapper_audio_cae_iterable')
|
| 220 |
-
class WrapperAudioCAEIterable(BaseWrapperAudioCAE, IterableDataset):
|
| 221 |
-
"""Iterable dataset wrapper for audio CAE training."""
|
| 222 |
-
|
| 223 |
-
def __iter__(self):
|
| 224 |
-
for data in self.dataset:
|
| 225 |
-
yield self.process(data)
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
# Example usage with your existing AudioDataset
|
| 229 |
-
def create_dito_audio_dataset(config):
|
| 230 |
-
"""Create DiTo audio dataset from config."""
|
| 231 |
-
|
| 232 |
-
# Create base audio dataset using audiotools
|
| 233 |
-
|
| 234 |
-
# Setup audio loaders
|
| 235 |
-
train_folders = config.get("train_folders", {})
|
| 236 |
-
|
| 237 |
-
loader = AudioLoader(
|
| 238 |
-
sources=list(train_folders.values()),
|
| 239 |
-
transform=tfm.Compose(
|
| 240 |
-
tfm.VolumeNorm(("uniform", -20, -10)),
|
| 241 |
-
tfm.RescaleAudio(),
|
| 242 |
-
),
|
| 243 |
-
ext=['.wav', '.flac', '.mp3'],
|
| 244 |
-
)
|
| 245 |
-
|
| 246 |
-
# Create base dataset
|
| 247 |
-
base_dataset = AudioDataset(
|
| 248 |
-
loaders=loader,
|
| 249 |
-
sample_rate=config['sample_rate'],
|
| 250 |
-
duration=config['duration'],
|
| 251 |
-
n_examples=config['n_examples'],
|
| 252 |
-
num_channels=1 if config.get('mono', True) else 2,
|
| 253 |
-
)
|
| 254 |
-
|
| 255 |
-
# Wrap with DiTo wrapper
|
| 256 |
-
dito_dataset = WrapperAudioCAE(
|
| 257 |
-
dataset=base_dataset,
|
| 258 |
-
sample_rate=config['sample_rate'],
|
| 259 |
-
duration=config['duration'],
|
| 260 |
-
mono=config.get('mono', True),
|
| 261 |
-
normalize=True,
|
| 262 |
-
return_coords=True,
|
| 263 |
-
)
|
| 264 |
-
|
| 265 |
-
return dito_dataset
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
# For your training config, you would use it like:
|
| 269 |
-
"""
|
| 270 |
-
datasets:
|
| 271 |
-
train:
|
| 272 |
-
name: wrapper_audio_cae
|
| 273 |
-
args:
|
| 274 |
-
dataset:
|
| 275 |
-
name: audio_dataset # Your base audio dataset
|
| 276 |
-
args:
|
| 277 |
-
sources: ["/path/to/audio/files"]
|
| 278 |
-
sample_rate: 44100
|
| 279 |
-
duration: 2.0
|
| 280 |
-
n_examples: 10000
|
| 281 |
-
sample_rate: 44100
|
| 282 |
-
duration: 2.0
|
| 283 |
-
mono: true
|
| 284 |
-
normalize: true
|
| 285 |
-
return_coords: true
|
| 286 |
-
loader:
|
| 287 |
-
batch_size: 16
|
| 288 |
-
num_workers: 8
|
| 289 |
-
|
| 290 |
-
val:
|
| 291 |
-
name: wrapper_audio_cae
|
| 292 |
-
args:
|
| 293 |
-
dataset:
|
| 294 |
-
name: audio_dataset
|
| 295 |
-
args:
|
| 296 |
-
sources: ["/path/to/val/audio/files"]
|
| 297 |
-
sample_rate: 44100
|
| 298 |
-
duration: 2.0
|
| 299 |
-
n_examples: 1000
|
| 300 |
-
sample_rate: 44100
|
| 301 |
-
duration: 2.0
|
| 302 |
-
mono: true
|
| 303 |
-
normalize: true
|
| 304 |
-
return_coords: true
|
| 305 |
-
loader:
|
| 306 |
-
batch_size: 16
|
| 307 |
-
num_workers: 8
|
| 308 |
-
"""
|
|
|
|
| 113 |
ret.update(data)
|
| 114 |
yield ret
|
| 115 |
else:
|
| 116 |
+
yield self.process(data)
|
flowae/{reconstruction.py → image_dito_inference.py}
RENAMED
File without changes
flowae/models/diffusion/fm.py
CHANGED
@@ -22,6 +22,21 @@ class FM:
 
     def B(self, t):
         return -(1.0 - self.sigma_min)
+
+    def _get_reduction_dims(self, x):
+        """Get appropriate dimensions for loss reduction based on tensor shape"""
+        if x.dim() == 4:
+            # Images: [batch, channels, height, width]
+            return [1, 2, 3]
+        elif x.dim() == 3:
+            # Audio: [batch, channels, samples] or [batch, latent_dim, time_frames]
+            return [1, 2]
+        elif x.dim() == 2:
+            # 1D signals: [batch, samples]
+            return [1]
+        else:
+            # Fallback: reduce over all non-batch dimensions
+            return list(range(1, x.dim()))
 
     def get_betas(self, n_timesteps):
         return torch.zeros(n_timesteps)  # Not VP and not supported
@@ -38,17 +53,20 @@ class FM:
 
         if t is None:
             t = torch.rand(x.shape[0], device=x.device)
-        print('x shape: ', x.shape)
+        # print('x shape: ', x.shape)
         x_t, noise = self.add_noise(x, t)
-        print('x_t shape: ', x_t.shape)
+        # print('x_t shape: ', x_t.shape)
         pred = net(x_t, t=t * self.timescale, **net_kwargs)
-        print('pred shape: ', pred.shape)
+        # print('pred shape: ', pred.shape)
 
         target = self.A(t) * x + self.B(t) * noise # -dxt/dt
-        print('target shape: ', target.shape)
-        print('return_loss_unreduced: ', return_loss_unreduced, 'return_all: ', return_all)
+        # print('target shape: ', target.shape)
+        # print('return_loss_unreduced: ', return_loss_unreduced, 'return_all: ', return_all)
         if return_loss_unreduced:
-            loss = ((pred.float() - target.float()) ** 2).mean(dim=[1, 2, 3])
+            print('pred shape: ', pred.shape, 'target shape: ', target.shape)
+            reduce_dims = self._get_reduction_dims(x)
+            loss = ((pred.float() - target.float()) ** 2).mean(dim=reduce_dims)
+            # loss = ((pred.float() - target.float()) ** 2).mean(dim=[1, 2, 3])
             if return_all:
                 return loss, t, x_t, pred
             else:
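The behavioural change in FM.loss is that the squared error is now averaged over whichever non-batch dimensions the input actually has, so the same diffusion class serves both image latents and 1-D audio. A small illustration of the helper (the shapes below are arbitrary examples, not values from the repo):

import torch

x_audio = torch.randn(8, 1, 9120)    # [batch, channels, samples]
x_image = torch.randn(8, 3, 32, 32)  # [batch, channels, height, width]

# FM._get_reduction_dims(x_audio) -> [1, 2]; FM._get_reduction_dims(x_image) -> [1, 2, 3]
# Either way the mean leaves one loss value per batch element.
err = torch.randn_like(x_audio)
per_sample = (err ** 2).mean(dim=[1, 2])  # what the new code computes for audio
assert per_sample.shape == (8,)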
flowae/models/ldm/dac/layers.py
CHANGED
@@ -74,7 +74,7 @@ def get_activation(activation, channels, alpha):
         return nn.LeakyReLU()
     elif activation == "tanh":
         return nn.Tanh()
-    elif activation == "
+    elif activation == "snakebeta":
         return SnakeBeta(channels, alpha)
     else:
         raise ValueError(f"Activation {activation} not supported")
flowae/models/ldm/dac/model.py
CHANGED
@@ -236,7 +236,8 @@ class Encoder(nn.Module):
 
     def forward(self, x):
         x = F.leaky_relu(x)
-
+        x = self.block(x)
+        return x
 
 
 class DecoderBlock(nn.Module):
@@ -478,6 +479,7 @@ class DACVAE(BaseModel, CodecMixin):
     ):
         x = self.encoder(audio_data)
         x = self.en_conv_post(x)
+        print('x shape: ', x.shape)
         m, logs = torch.split(x, self.latent_dim, dim=1)
         logs = torch.clamp(logs, min=-14.0, max=14.0)
 
flowae/models/ldm/dac/utils.py
CHANGED
@@ -7,16 +7,16 @@ from .model import Encoder, Decoder, WNConv1d
 
 default_configs = {
     'snake': dict(
-
-
-
+        d_model=64,
+        strides=[2, 4, 5, 8],
+        d_latent=64,
         d_in=1,
         activation='snake',
     ),
-    '
-
-
-
+    'snakebeta': dict(
+        d_model=64,
+        strides=[2, 4, 5, 8],
+        d_latent=64,
         d_in=1,
         activation='snakebeta',
     ),
@@ -27,10 +27,10 @@ default_configs = {
 def make_dac_encoder(config_name, **kwargs):
     encoder_kwargs = default_configs[config_name]
     encoder_kwargs.update(kwargs)
-
+    d_model = encoder_kwargs['d_model']
     return nn.Sequential(
         Encoder(**encoder_kwargs),
-        WNConv1d(
+        WNConv1d(d_model, d_model, kernel_size=1),
     )
 
 
@@ -38,8 +38,8 @@ def make_dac_encoder(config_name, **kwargs):
 def make_vqgan_decoder(config_name, **kwargs):
     decoder_kwargs = default_configs[config_name]
     decoder_kwargs.update(kwargs)
-
+    d_model = decoder_kwargs['d_model']
     return nn.Sequential(
-        WNConv1d(
+        WNConv1d(d_model, d_model, kernel_size=1),
         Decoder(**decoder_kwargs),
     )
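The new 'snake'/'snakebeta' defaults give the DAC encoder strides [2, 4, 5, 8], i.e. an overall hop of 320 input samples per latent frame; that is presumably the same factor hard-coded as n_frames = z_dec.size(2) * 320 in DiToAudio.render below. A one-line check of that assumption:

# 320 samples per latent time step at 24 kHz -> 75 latent frames per second
assert 2 * 4 * 5 * 8 == 320 and 24000 / 320 == 75.0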
flowae/models/ldm/dito.py
CHANGED
@@ -6,7 +6,8 @@ import torch
 import models
 from omegaconf import OmegaConf
 from models import register
-
+
+from models.ldm.ldm_base import LDMBase, LDMBaseAudio
 from models.ldm.vqgan.lpips import LPIPS
 
 
@@ -178,3 +179,143 @@
         dae_loss_w = loss_config.get('dae_loss', 1)
         ret['loss'] = ret['loss'] + dae_loss * dae_loss_w
         return ret
+
+
+
+@register('dito_audio')
+class DiToAudio(LDMBaseAudio):
+
+    def __init__(self, render_diffusion, render_sampler, render_n_steps, renderer_guidance=1, **kwargs):
+        super().__init__(**kwargs)
+        self.render_diffusion = models.make(render_diffusion)
+
+        if OmegaConf.is_config(render_sampler):
+            render_sampler = OmegaConf.to_container(render_sampler, resolve=True)
+        render_sampler = copy.deepcopy(render_sampler)
+        if render_sampler.get('args') is None:
+            render_sampler['args'] = {}
+        render_sampler['args']['diffusion'] = self.render_diffusion
+        self.render_sampler = models.make(render_sampler)
+        self.render_n_steps = render_n_steps
+        self.renderer_guidance = renderer_guidance
+
+        self.t_loss_monitor_v = [0 for _ in range(10)]
+        self.t_loss_monitor_n = [0 for _ in range(10)]
+        self.t_loss_monitor_decay = 0.99
+
+
+    def render(self, z_dec):
+        net_kwargs = {'z_dec': z_dec}
+        n_frames = z_dec.size(2) * 320
+        shape = (z_dec.size(0), z_dec.size(0), n_frames)
+
+        if self.renderer_guidance > 1:
+            uncond_z_dec = self.drop_z_emb.unsqueeze(0).expand(z_dec.shape[0], -1, -1, -1)
+            uncond_net_kwargs = {'z_dec': uncond_z_dec}
+        else:
+            uncond_net_kwargs = None
+
+        ret = self.render_sampler.sample(
+            net=self.renderer,
+            n_steps=self.render_n_steps,
+            shape=shape,
+            net_kwargs=net_kwargs,
+            uncond_net_kwargs=uncond_net_kwargs,
+            guidance=self.renderer_guidance,
+        )
+
+        # if self.use_ema_renderer:
+        #     self.swap_ema_renderer()
+
+        return ret
+
+    def forward(self, data, mode, has_optimizer=None):
+        if mode in ['z', 'z_dec']:
+            ret_z, _ = super().forward(data, mode=mode, has_optimizer=has_optimizer)
+            return ret_z
+
+        grad = self.get_grad_plan(has_optimizer)
+        loss_config = self.loss_config
+        if mode == 'pred':
+            z_dec, ret = super().forward(data, mode='z_dec', has_optimizer=has_optimizer)
+
+            gt_patch = data['gt']
+
+            if grad['renderer']:
+                return self.render(z_dec)
+            else:
+                with torch.no_grad():
+                    return self.render(z_dec)
+
+        elif mode == 'loss':
+            if not grad['renderer']:  # Only training zdm
+                _, ret = super().forward(data, mode='z', has_optimizer=has_optimizer)
+                return ret
+
+            gt_patch = data['gt']
+
+
+            z_dec, ret = super().forward(data, mode='z_dec', has_optimizer=has_optimizer)
+            net_kwargs = {'z_dec': z_dec}
+
+            # print('latent z_dec shape: ', z_dec.shape)
+
+            t = torch.rand(gt_patch.shape[0], device=gt_patch.device)
+
+
+            # print('self.zaug_p:', self.zaug_p)
+            # print('self.training:', self.training)
+
+            if (self.zaug_p is not None) and self.training:
+                tz = self._tz
+                mask_aug = self._mask_aug
+
+                typ = self.zaug_decoding_loss_type
+                if typ == 'all':
+                    tmin = torch.ones_like(tz) * 0
+                    tmax = torch.ones_like(tz) * 1
+                elif typ == 'suffix':
+                    tmin = tz
+                    tmax = torch.ones_like(tz) * 1
+                elif typ == 'tz':
+                    tmin = tz
+                    tmax = tz
+                elif typ == 'tmax':
+                    tmin = torch.ones_like(tz) * 1
+                    tmax = torch.ones_like(tz) * 1
+                else:
+                    raise NotImplementedError
+                t_aug = tmin + (tmax - tmin) * torch.rand_like(tmin)
+
+                t = mask_aug * t_aug + (1 - mask_aug) * t
+
+            loss, t = self.render_diffusion.loss(
+                net=self.renderer,
+                x=gt_patch,
+                t=t,
+                net_kwargs=net_kwargs,
+                return_loss_unreduced=True
+            )
+
+            # Visualize diffusion network loss for different timesteps #
+            if self.training:
+                m = len(self.t_loss_monitor_v)
+                for i in range(len(loss)):
+                    q = min(math.floor(t[i].item() * m), m - 1)
+                    self.t_loss_monitor_v[q] = self.t_loss_monitor_v[q] * self.t_loss_monitor_decay + loss[i].item() * (1 - self.t_loss_monitor_decay)
+                    self.t_loss_monitor_n[q] += 1
+                for q in range(m):
+                    if self.t_loss_monitor_n[q] > 0:
+                        if self.t_loss_monitor_n[q] < 500:
+                            r = 1 - math.pow(self.t_loss_monitor_decay, self.t_loss_monitor_n[q])
+                        else:
+                            r = 1
+                        ret[f'_loss_t{q}'] = self.t_loss_monitor_v[q] / r
+            # - #
+
+            dae_loss = loss.mean()
+
+            ret['dae_loss'] = dae_loss.item()
+            dae_loss_w = loss_config.get('dae_loss', 1)
+            ret['loss'] = ret['loss'] + dae_loss * dae_loss_w
+            return ret
flowae/models/ldm/ldm_base.py
CHANGED
|
@@ -47,6 +47,39 @@ class LDMBase(nn.Module):
|
|
| 47 |
use_ema_decoder=False,
|
| 48 |
use_ema_renderer=False,
|
| 49 |
):
|
|
|
|
|
| 50 |
super().__init__()
|
| 51 |
self.loss_config = loss_config if loss_config is not None else dict()
|
| 52 |
|
|
@@ -442,3 +475,194 @@ class DiagonalGaussianDistribution(object):
|
|
| 442 |
|
| 443 |
def mode(self):
|
| 444 |
return self.mean
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
use_ema_decoder=False,
|
| 48 |
use_ema_renderer=False,
|
| 49 |
):
|
| 50 |
+
print('print all the args ')
|
| 51 |
+
print("encoder: ", encoder)
|
| 52 |
+
print("z_shape: ",z_shape)
|
| 53 |
+
print("decoder: ",decoder)
|
| 54 |
+
print("renderer: ",renderer)
|
| 55 |
+
print("encoder_ema_rate: ",encoder_ema_rate)
|
| 56 |
+
print("decoder_ema_rate: ",decoder_ema_rate)
|
| 57 |
+
print("renderer_ema_rate: ",renderer_ema_rate)
|
| 58 |
+
print("z_gaussian: ",z_gaussian)
|
| 59 |
+
print("z_gaussian_sample: ",z_gaussian_sample)
|
| 60 |
+
print("z_quantizer: ",z_quantizer)
|
| 61 |
+
print("z_quantizer_n_embed: ",z_quantizer_n_embed)
|
| 62 |
+
print("z_quantizer_beta: ",z_quantizer_beta)
|
| 63 |
+
print("z_layernorm: ",z_layernorm)
|
| 64 |
+
print("zaug_p: ",zaug_p)
|
| 65 |
+
print("zaug_tmax: ",zaug_tmax)
|
| 66 |
+
print("zaug_tmax_always: ",zaug_tmax_always)
|
| 67 |
+
print("zaug_decoding_loss_type: ",zaug_decoding_loss_type)
|
| 68 |
+
print("zaug_zdm_diffusion: ",zaug_zdm_diffusion)
|
| 69 |
+
print("gt_noise_lb: ",gt_noise_lb)
|
| 70 |
+
print("drop_z_p: ",drop_z_p)
|
| 71 |
+
print("zdm_net: ",zdm_net)
|
| 72 |
+
print("zdm_diffusion: ",zdm_diffusion)
|
| 73 |
+
print("zdm_sampler: ",zdm_sampler)
|
| 74 |
+
print("zdm_n_steps: ",zdm_n_steps)
|
| 75 |
+
print("zdm_ema_rate: ",zdm_ema_rate)
|
| 76 |
+
print("zdm_train_normalize: ",zdm_train_normalize)
|
| 77 |
+
print("zdm_class_cond: ",zdm_class_cond)
|
| 78 |
+
print("zdm_force_guidance: ",zdm_force_guidance)
|
| 79 |
+
print("loss_config: ",loss_config)
|
| 80 |
+
print("use_ema_encoder: ",use_ema_encoder)
|
| 81 |
+
print("use_ema_decoder: ",use_ema_decoder)
|
| 82 |
+
print("use_ema_renderer: ",use_ema_renderer)
|
| 83 |
super().__init__()
|
| 84 |
self.loss_config = loss_config if loss_config is not None else dict()
|
| 85 |
|
|
|
|
| 475 |
|
| 476 |
def mode(self):
|
| 477 |
return self.mean
|
+
+
+class LDMBaseAudio(nn.Module):
+    def __init__(
+        self,
+        encoder,
+        z_channels,
+        decoder,
+        renderer,
+        zaug_p=0.1,
+        zaug_tmax=1.0,
+        zaug_tmax_always=False,
+        zaug_decoding_loss_type='all',
+        zaug_zdm_diffusion={'name': 'fm', 'args': {'timescale': 1000.0}},
+        zdm_ema_rate=0.9999,
+        loss_config={},
+        encoder_ema_rate=None,
+        decoder_ema_rate=None,
+        renderer_ema_rate=None,
+    ):
+        super().__init__()
+        self.loss_config = loss_config
+
+        self.encoder = models.make(encoder)
+        self.decoder = models.make(decoder)
+        self.renderer = models.make(renderer)
+
+        self.z_layernorm = nn.LayerNorm(
+            z_channels,  # e.g., 64
+            elementwise_affine=False
+        )
+
+        self.zaug_p = zaug_p
+        self.zaug_tmax = zaug_tmax
+        self.zaug_tmax_always = zaug_tmax_always
+        self.zaug_decoding_loss_type = zaug_decoding_loss_type
+        if zaug_zdm_diffusion is not None:
+            self.zaug_zdm_diffusion = models.make(zaug_zdm_diffusion)
+
+        # EMA models #
+        self.encoder_ema_rate = encoder_ema_rate
+        if self.encoder_ema_rate is not None:
+            self.encoder_ema = copy.deepcopy(self.encoder)
+            for p in self.encoder_ema.parameters():
+                p.requires_grad = False
+
+        self.decoder_ema_rate = decoder_ema_rate
+        if self.decoder_ema_rate is not None:
+            self.decoder_ema = copy.deepcopy(self.decoder)
+            for p in self.decoder_ema.parameters():
+                p.requires_grad = False
+
+        self.renderer_ema_rate = renderer_ema_rate
+        if self.renderer_ema_rate is not None:
+            self.renderer_ema = copy.deepcopy(self.renderer)
+            for p in self.renderer_ema.parameters():
+                p.requires_grad = False
+        #
+
+    def get_grad_plan(self, has_optimizer):
+        if has_optimizer is None:
+            has_optimizer = dict()
+        grad = dict()
+        grad['encoder'] = has_optimizer.get('encoder', False)
+        grad['decoder'] = grad['encoder'] or has_optimizer.get('decoder', False)
+        grad['renderer'] = grad['decoder'] or has_optimizer.get('renderer', False)
+        return grad
+
+    def normalize_latents(self, z):
+        # z shape: [batch, latent_dim, n_frames] - n_frames can vary!
+        z = z.transpose(-2, -1)  # [batch, n_frames, latent_dim]
+        z = self.z_layernorm(z)  # normalize over latent_dim for each time step
+        z = z.transpose(-2, -1)  # [batch, latent_dim, n_frames]
+        return z
+
+    def update_ema(self):
+        if self.encoder_ema_rate is not None:
+            self.update_ema_fn(self.encoder_ema, self.encoder, self.encoder_ema_rate)
+        if self.decoder_ema_rate is not None:
+            self.update_ema_fn(self.decoder_ema, self.decoder, self.decoder_ema_rate)
+        if self.renderer_ema_rate is not None:
+            self.update_ema_fn(self.renderer_ema, self.renderer, self.renderer_ema_rate)
+
+    def get_parameters(self, name):
+        if name == 'encoder':
+            return self.encoder.parameters()
+        elif name == 'decoder':
+            p = list(self.decoder.parameters())
+            if self.z_quantizer is not None:
+                p += list(self.z_quantizer.parameters())
+            return p
+        elif name == 'renderer':
+            return self.renderer.parameters()
+        elif name == 'zdm':
+            return self.zdm_net.parameters()
+
+    def encode(self, x):
+        z = self.encoder(x)
+        z = self.normalize_latents(z)
+
+        if (self.zaug_p is not None) and self.training:
+            assert self.z_layernorm is not None  # ensure 0 mean 1 std
+            if self.zaug_tmax_always:
+                tz = torch.ones(z.shape[0], device=z.device) * self.zaug_tmax
+            else:
+                tz = torch.rand(z.shape[0], device=z.device) * self.zaug_tmax
+
+            zt, _ = self.zaug_zdm_diffusion.add_noise(z, tz)
+            mask_aug = (torch.rand(z.shape[0], device=z.device) < self.zaug_p).float()
+            if z.dim() == 4:  # Image: [batch, channels, height, width]
+                mask_shape = (-1, 1, 1, 1)
+            elif z.dim() == 3:  # Audio: [batch, channels, n_frames]
+                mask_shape = (-1, 1, 1)
+            else:
+                raise ValueError(f"Unsupported tensor dimension: {z.dim()}")
+
+            z = mask_aug.view(*mask_shape) * zt + (1 - mask_aug).view(*mask_shape) * z
+            # z = mask_aug.view(-1, 1, 1, 1) * zt + (1 - mask_aug).view(-1, 1, 1, 1) * z
+            self._tz = tz
+            self._mask_aug = mask_aug
+
+        return z
+
+    def decode(self, z):
+        z_dec = self.decoder(z)
+        return z_dec
+
+    def render(self, z_dec):
+        raise NotImplementedError
+
+    def forward(self, data, mode, has_optimizer=None):
+        loss = torch.tensor(0., device=data['inp'].device)
+        ret = dict()
+        z = self.encode(data['inp'])
+
+        z_dec = self.decode(z)
+
+        ret['loss'] = loss
+        return z_dec, ret
+
+    def generate_samples(
+        self,
+        batch_size,
+        n_steps,
+        net_kwargs=None,
+        uncond_net_kwargs=None,
+        ema=False,
+        guidance=1.0,
+        noise=None,
+        return_z=False,
+    ):
+        if self.zdm_force_guidance is not None:
+            guidance = self.zdm_force_guidance
+
+        shape = (batch_size,) + self.z_shape
+        net = self.zdm_net if not ema else self.zdm_net_ema
+
+        z = self.zdm_sampler.sample(
+            net,
+            shape,
+            n_steps,
+            net_kwargs=net_kwargs,
+            uncond_net_kwargs=uncond_net_kwargs,
+            guidance=guidance,
+            noise=noise,
+        )
+
+        if return_z:
+            return z
+
+        if (self.zaug_p is not None) and self.zaug_tmax_always:
+            tz = torch.ones(z.shape[0], device=z.device) * self.zaug_tmax
+            z, _ = self.zaug_zdm_diffusion.add_noise(z, tz)
+
+        z = self.denormalize_for_zdm(z)
+        z_dec = self.decode(z)
+
+        return self.render(z_dec)
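For intuition: normalize_latents applies a LayerNorm over the channel axis independently at every latent frame, which is what lets the zaug branch assume roughly zero-mean, unit-variance latents before adding diffusion noise. A small sketch of that normalization on a dummy latent (shapes and values are illustrative only):

import torch
import torch.nn as nn

z = torch.randn(2, 64, 173)                       # [batch, latent_dim, n_frames]; n_frames can vary
ln = nn.LayerNorm(64, elementwise_affine=False)   # normalizes over the last dimension
z_norm = ln(z.transpose(-2, -1)).transpose(-2, -1)
print(z_norm.mean(dim=1).abs().max())             # close to 0 for every frame
print(z_norm.std(dim=1).mean())                   # close to 1 on average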
flowae/models/networks/__init__.py
CHANGED
@@ -1,2 +1,3 @@
 from . import consistency_decoder_unet
-from . import dit
+from . import dit
+from . import consistency_audio_decoder_unet
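The new import matters because it triggers the @register('audio_diffusion_unet') decorator at package import time, which is how configs can refer to the network by name. As a rough illustration of that pattern only (a generic sketch, not the actual flowae/models implementation):

_registry = {}

def register(name):
    def decorator(cls):
        _registry[name] = cls     # map a config name to a class
        return cls
    return decorator

def make(spec):
    # spec is a dict like {'name': 'audio_diffusion_unet', 'args': {...}}
    return _registry[spec['name']](**spec.get('args', {}))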
flowae/models/networks/consistency_audio_decoder_unet.py
ADDED
@@ -0,0 +1,322 @@
+# https://gist.github.com/mrsteyk/74ad3ec2f6f823111ae4c90e168505ac
+
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+
+from models import register
+
+class PositionalEmbedding(nn.Module):
+    def __init__(self, pe_dim=320, out_dim=1280, max_positions=10000, endpoint=True):
+        super().__init__()
+        self.num_channels = pe_dim
+        self.max_positions = max_positions
+        self.endpoint = endpoint
+        self.f_1 = nn.Linear(pe_dim, out_dim)
+        self.f_2 = nn.Linear(out_dim, out_dim)
+
+    def forward(self, x):
+        freqs = torch.arange(start=0, end=self.num_channels//2, dtype=torch.float32, device=x.device)
+        freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0))
+        freqs = (1 / self.max_positions) ** freqs
+        x = x.ger(freqs.to(x.dtype))
+        x = torch.cat([x.cos(), x.sin()], dim=1)
+
+        x = self.f_1(x)
+        x = F.silu(x)
+        return self.f_2(x)
+
+
+class AudioEmbedding(nn.Module):
+    """1D convolution for audio input embedding"""
+    def __init__(self, in_channels, out_channels=320, kernel_size=3) -> None:
+        super().__init__()
+        self.f = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size//2)
+
+    def forward(self, x) -> torch.Tensor:
+        return self.f(x)
+
+class AudioUnembedding(nn.Module):
+    """1D convolution for audio output"""
+    def __init__(self, in_channels=320, out_channels=1, kernel_size=3) -> None:
+        super().__init__()
+        self.gn = nn.GroupNorm(32, in_channels)
+        self.f = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size//2)
+
+    def forward(self, x) -> torch.Tensor:
+        return self.f(F.silu(self.gn(x)))
+
+
+class AudioConvResblock(nn.Module):
+    """1D residual block for audio"""
+    def __init__(self, in_features, out_features, t_dim, kernel_size=3) -> None:
+        super().__init__()
+        self.f_t = nn.Linear(t_dim, out_features * 2)
+
+        self.gn_1 = nn.GroupNorm(32, in_features)
+        self.f_1 = nn.Conv1d(in_features, out_features, kernel_size=kernel_size, padding=kernel_size//2)
+
+        self.gn_2 = nn.GroupNorm(32, out_features)
+        self.f_2 = nn.Conv1d(out_features, out_features, kernel_size=kernel_size, padding=kernel_size//2)
+
+        skip_conv = in_features != out_features
+        self.f_s = (
+            nn.Conv1d(in_features, out_features, kernel_size=1, padding=0)
+            if skip_conv
+            else nn.Identity()
+        )
+
+    def forward(self, x, t):
+        x_skip = x
+        t = self.f_t(F.silu(t))
+        t = t.chunk(2, dim=1)
+        t_1 = t[0].unsqueeze(dim=2) + 1  # [batch, channels, 1]
+        t_2 = t[1].unsqueeze(dim=2)      # [batch, channels, 1]
+
+        gn_1 = F.silu(self.gn_1(x))
+        f_1 = self.f_1(gn_1)
+
+        gn_2 = self.gn_2(f_1)
+
+        return self.f_s(x_skip) + self.f_2(F.silu(gn_2 * t_1 + t_2))
+
+class AudioDownsample(nn.Module):
+    """1D downsampling for audio"""
+    def __init__(self, in_channels, t_dim, downsample_factor=2) -> None:
+        super().__init__()
+        self.f_t = nn.Linear(t_dim, in_channels * 2)
+        self.downsample_factor = downsample_factor
+
+        self.gn_1 = nn.GroupNorm(32, in_channels)
+        self.f_1 = nn.Conv1d(in_channels, in_channels, kernel_size=3, padding=1)
+        self.gn_2 = nn.GroupNorm(32, in_channels)
+
+        self.f_2 = nn.Conv1d(in_channels, in_channels, kernel_size=3, padding=1)
+
+    def forward(self, x, t) -> torch.Tensor:
+        x_skip = x
+
+        t = self.f_t(F.silu(t))
+        t_1, t_2 = t.chunk(2, dim=1)
+        t_1 = t_1.unsqueeze(2) + 1
+        t_2 = t_2.unsqueeze(2)
+
+        gn_1 = F.silu(self.gn_1(x))
+        # 1D average pooling
+        avg_pool1d = F.avg_pool1d(gn_1, kernel_size=self.downsample_factor)
+        f_1 = self.f_1(avg_pool1d)
+        gn_2 = self.gn_2(f_1)
+
+        f_2 = self.f_2(F.silu(t_2 + (t_1 * gn_2)))
+
+        return f_2 + F.avg_pool1d(x_skip, kernel_size=self.downsample_factor)
+
+class AudioUpsample(nn.Module):
+    """1D upsampling for audio"""
+    def __init__(self, in_channels, t_dim, upsample_factor=2) -> None:
+        super().__init__()
+        self.f_t = nn.Linear(t_dim, in_channels * 2)
+        self.upsample_factor = upsample_factor
+
+        self.gn_1 = nn.GroupNorm(32, in_channels)
+        self.f_1 = nn.Conv1d(in_channels, in_channels, kernel_size=3, padding=1)
+        self.gn_2 = nn.GroupNorm(32, in_channels)
+
+        self.f_2 = nn.Conv1d(in_channels, in_channels, kernel_size=3, padding=1)
+
+    def forward(self, x, t) -> torch.Tensor:
+        x_skip = x
+
+        t = self.f_t(F.silu(t))
+        t_1, t_2 = t.chunk(2, dim=1)
+        t_1 = t_1.unsqueeze(2) + 1
+        t_2 = t_2.unsqueeze(2)
+
+        gn_1 = F.silu(self.gn_1(x))
+        # 1D interpolation upsampling
+        upsample = F.interpolate(gn_1, scale_factor=self.upsample_factor, mode='nearest')
+        f_1 = self.f_1(upsample)
+        gn_2 = self.gn_2(f_1)
+
+        f_2 = self.f_2(F.silu(t_2 + (t_1 * gn_2)))
+
+        return f_2 + F.interpolate(x_skip, scale_factor=self.upsample_factor, mode='nearest')
+
+
+@register('audio_diffusion_unet')
+class AudioDiffusionUNet(nn.Module):
+    """
+    1D UNet for audio diffusion with dynamic latent conditioning
+
+    Handles:
+    - x: [batch, 1, samples] - audio waveform (dynamic length)
+    - z_dec: [batch, 64, n_frames] - latent conditioning (dynamic length)
+    """
+
+    def __init__(
+        self,
+        in_channels=1,            # Audio channels (mono=1, stereo=2)
+        z_dec_channels=64,        # Latent conditioning channels
+        c0=128, c1=256, c2=512,   # Channel progression (smaller than image version)
+        pe_dim=320,
+        t_dim=1280,
+        kernel_size=3
+    ) -> None:
+        super().__init__()
+
+        # Store for dynamic conditioning
+        self.z_dec_channels = z_dec_channels
+
+        # Audio input embedding
+        self.embed_audio = AudioEmbedding(
+            in_channels=in_channels,
+            out_channels=c0,
+            kernel_size=kernel_size
+        )
+
+        # Time embedding
+        self.embed_time = PositionalEmbedding(pe_dim=pe_dim, out_dim=t_dim)
+
+        # Latent conditioning projection
+        if z_dec_channels is not None:
+            self.z_dec_proj = nn.Conv1d(z_dec_channels, c0, kernel_size=1)
+
+        # Downsampling path
+        down_0 = nn.ModuleList([
+            AudioConvResblock(c0, c0, t_dim, kernel_size),
+            AudioConvResblock(c0, c0, t_dim, kernel_size),
+            AudioConvResblock(c0, c0, t_dim, kernel_size),
+            AudioDownsample(c0, t_dim),
+        ])
+        down_1 = nn.ModuleList([
+            AudioConvResblock(c0, c1, t_dim, kernel_size),
+            AudioConvResblock(c1, c1, t_dim, kernel_size),
+            AudioConvResblock(c1, c1, t_dim, kernel_size),
+            AudioDownsample(c1, t_dim),
+        ])
+        down_2 = nn.ModuleList([
+            AudioConvResblock(c1, c2, t_dim, kernel_size),
+            AudioConvResblock(c2, c2, t_dim, kernel_size),
+            AudioConvResblock(c2, c2, t_dim, kernel_size),
+            AudioDownsample(c2, t_dim),
+        ])
+        down_3 = nn.ModuleList([
+            AudioConvResblock(c2, c2, t_dim, kernel_size),
+            AudioConvResblock(c2, c2, t_dim, kernel_size),
+            AudioConvResblock(c2, c2, t_dim, kernel_size),
+        ])
+        self.down = nn.ModuleList([down_0, down_1, down_2, down_3])
+
+        # Middle layers
+        self.mid = nn.ModuleList([
+            AudioConvResblock(c2, c2, t_dim, kernel_size),
+            AudioConvResblock(c2, c2, t_dim, kernel_size),
+        ])
+
+        # Upsampling path
+        up_3 = nn.ModuleList([
+            AudioConvResblock(c2 * 2, c2, t_dim, kernel_size),
+            AudioConvResblock(c2 * 2, c2, t_dim, kernel_size),
+            AudioConvResblock(c2 * 2, c2, t_dim, kernel_size),
+            AudioConvResblock(c2 * 2, c2, t_dim, kernel_size),
+            AudioUpsample(c2, t_dim),
+        ])
+        up_2 = nn.ModuleList([
+            AudioConvResblock(c2 * 2, c2, t_dim, kernel_size),
+            AudioConvResblock(c2 * 2, c2, t_dim, kernel_size),
+            AudioConvResblock(c2 * 2, c2, t_dim, kernel_size),
+            AudioConvResblock(c2 + c1, c2, t_dim, kernel_size),
+            AudioUpsample(c2, t_dim),
+        ])
+        up_1 = nn.ModuleList([
+            AudioConvResblock(c2 + c1, c1, t_dim, kernel_size),
+            AudioConvResblock(c1 * 2, c1, t_dim, kernel_size),
+            AudioConvResblock(c1 * 2, c1, t_dim, kernel_size),
+            AudioConvResblock(c0 + c1, c1, t_dim, kernel_size),
+            AudioUpsample(c1, t_dim),
+        ])
+        up_0 = nn.ModuleList([
+            AudioConvResblock(c0 + c1, c0, t_dim, kernel_size),
+            AudioConvResblock(c0 * 2, c0, t_dim, kernel_size),
+            AudioConvResblock(c0 * 2, c0, t_dim, kernel_size),
+            AudioConvResblock(c0 * 2, c0, t_dim, kernel_size),
+        ])
+        self.up = nn.ModuleList([up_0, up_1, up_2, up_3])
+
+        # Output layer
+        self.output = AudioUnembedding(in_channels=c0, out_channels=in_channels)
+
+    def get_last_layer_weight(self):
+        return self.output.f.weight
+
+    def condition_with_latents(self, x, z_dec):
+        """
+        Add latent conditioning to audio features
+
+        Args:
+            x: [batch, c0, audio_samples] - audio features
+            z_dec: [batch, 64, n_frames] - latent conditioning
+
+        Returns:
+            x: [batch, c0, audio_samples] - conditioned features
+        """
+        if z_dec is None:
+            return x
+
+        # Project latents to same channel dimension as audio features
+        z_proj = self.z_dec_proj(z_dec)  # [batch, c0, n_frames]
+
+        # Interpolate latents to match audio length
+        if z_proj.shape[-1] != x.shape[-1]:
+            z_proj = F.interpolate(
+                z_proj,
+                size=x.shape[-1],
+                mode='nearest'  # or 'linear' for smoother interpolation
+            )
+
+        # Add latent conditioning to audio features
+        return x + z_proj
+
+    def forward(self, x, t=None, z_dec=None) -> torch.Tensor:
+        """
+        Forward pass
+
+        Args:
+            x: [batch, 1, samples] - audio waveform (any length)
+            t: [batch] - diffusion timesteps
+            z_dec: [batch, 64, n_frames] - latent conditioning (any length)
+        """
+        # Embed audio input
+        x = self.embed_audio(x)  # [batch, c0, samples]
+
+        # Add latent conditioning
+        if z_dec is not None:
+            x = self.condition_with_latents(x, z_dec)
+
+        # Embed timestep
+        if t is None:
+            t = torch.zeros(x.shape[0], device=x.device)
+        t = self.embed_time(t)  # [batch, t_dim]
+
+        # Downsampling with skip connections
+        skips = [x]
+        for down in self.down:
+            for block in down:
+                x = block(x, t)
+                skips.append(x)
+
+        # Middle layers
+        for mid in self.mid:
+            x = mid(x, t)
+
+        # Upsampling with skip connections
+        for up in self.up[::-1]:
+            for block in up:
+                if isinstance(block, AudioConvResblock):
+                    x = torch.cat([x, skips.pop()], dim=1)
+                x = block(x, t)
+
+        # Output
+        return self.output(x)
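Because the network stacks three AudioDownsample stages that each halve the temporal axis, the input waveform length should be divisible by 8 so the skip connections line up on the way back up. A quick smoke-test sketch (batch size, sample count, and frame count are arbitrary example values):

import torch
from models.networks.consistency_audio_decoder_unet import AudioDiffusionUNet

net = AudioDiffusionUNet(in_channels=1, z_dec_channels=64)
x = torch.randn(2, 1, 16384)     # [batch, 1, samples], divisible by 8
t = torch.rand(2)                # diffusion timesteps in [0, 1)
z_dec = torch.randn(2, 64, 32)   # latent conditioning; interpolated to 16384 frames internally
y = net(x, t=t, z_dec=z_dec)
print(y.shape)                   # expected: torch.Size([2, 1, 16384])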
flowae/models/networks/consistency_decoder_unet.py
CHANGED
@@ -239,6 +239,7 @@ class ConsistencyDecoderUNet(nn.Module):

     def forward(self, x, t=None, z_dec=None) -> torch.Tensor:
         if z_dec is not None:
+            print('shape of x and z_dec: ', x.shape, z_dec.shape)
             if z_dec.shape[-2] != x.shape[-2] or z_dec.shape[-1] != x.shape[-1]:
                 assert x.shape[-2] // z_dec.shape[-2] == x.shape[-1] // z_dec.shape[-1]
                 z_dec = F.upsample_nearest(z_dec, scale_factor=x.shape[-2] // z_dec.shape[-2])
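A side note on the line that rescales z_dec: F.upsample_nearest has long been deprecated in PyTorch in favor of F.interpolate. If the deprecation warnings become noisy, an equivalent call (a suggestion, not part of this commit) would be:

z_dec = F.interpolate(z_dec, scale_factor=x.shape[-2] // z_dec.shape[-2], mode='nearest')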
flowae/run.sh
ADDED
@@ -0,0 +1,2 @@
+torchrun --nnodes=1 --nproc-per-node=1 run.py --config configs/experiments/dito-B-f8c4-noise-sync.yaml --save-root /mnt/nvme/dito
+torchrun --nnodes=1 --nproc-per-node=1 run.py --config configs/experiments/dito-B-audio.yaml --save-root /mnt/nvme/dito
flowae/upload.sh
ADDED
@@ -0,0 +1,2 @@
+az storage blob upload-batch \
+    --connection-string ""
|