# audio_dito_inference.py — Audio DiTo checkpoint inference utilities.
import torch
import torch.nn as nn
import torchaudio
import numpy as np
from pathlib import Path
import argparse
import soundfile as sf
from omegaconf import OmegaConf
import matplotlib.pyplot as plt
# Import models
import models
from models.ldm.dac.audiotools import AudioSignal
class AudioDiToInference:
    """Inference wrapper for a trained Audio DiTo (diffusion autoencoder) checkpoint.

    Loads the model and its OmegaConf config from a checkpoint file, then offers
    full-file reconstruction (encode -> decode -> render), side-by-side
    comparison export, latent visualization, and folder-level batch processing.
    """

    def __init__(self, checkpoint_path, device='cuda'):
        """Load the checkpoint onto *device* and put the model in eval mode.

        Args:
            checkpoint_path: Path to a checkpoint produced by training; must
                contain 'config' and 'model'/'sd' entries.
            device: Torch device string ('cuda' or 'cpu').
        """
        self.device = device
        print(f"Loading checkpoint from {checkpoint_path}")
        ckpt = torch.load(checkpoint_path, map_location='cpu')
        # The training run serialized its full config inside the checkpoint.
        self.config = OmegaConf.create(ckpt['config'])
        self.model = models.make(self.config['model'])
        self.model.load_state_dict(ckpt['model']['sd'])
        self.model = self.model.to(device)
        self.model.eval()
        # Audio parameters fall back to defaults when absent from the config.
        self.sample_rate = self.config.get('sample_rate', 24000)
        self.mono = self.config.get('mono', True)
        print("Model loaded successfully!")
        print(f"Sample rate: {self.sample_rate} Hz")
        print(f"Mono: {self.mono}")

    def load_audio(self, audio_path, duration=None, offset=0.0):
        """Load an audio file as an AudioSignal, matched to the model's format.

        Args:
            audio_path: Path to audio file.
            duration: Duration in seconds (None loads the full file).
            offset: Start offset in seconds (only used when duration is given).

        Returns:
            AudioSignal resampled to self.sample_rate, optionally mono,
            peak-normalized and clamped to [-1, 1].
        """
        if duration is not None:
            signal = AudioSignal(str(audio_path), duration=duration, offset=offset)
        else:
            signal = AudioSignal(str(audio_path))
        if self.mono and signal.num_channels > 1:
            signal = signal.to_mono()
        if signal.sample_rate != self.sample_rate:
            signal = signal.resample(self.sample_rate)
        signal = signal.normalize()
        # Guard against clipping introduced by normalization/resampling.
        signal.audio_data = signal.audio_data.clamp(-1.0, 1.0)
        return signal

    def save_audio(self, reconstructed, output_path):
        """Write a waveform array to disk via soundfile.

        Args:
            reconstructed: numpy array of samples (1-D expected) in [-1, 1].
            output_path: Destination path; format inferred from extension.
        """
        print('shape of reconstructed: ', reconstructed.shape)
        sf.write(output_path, reconstructed, self.sample_rate)
        print(f"Saved audio to {output_path}")

    def reconstruct_audio(self, audio_path, num_steps=50, save_latent=False):
        """Reconstruct an entire audio file in a single forward pass.

        Args:
            audio_path: Path to the input audio file.
            num_steps: Intended number of diffusion steps.
                NOTE(review): currently not forwarded to the model — render()
                and 'pred' mode run with the model's internal defaults; confirm
                whether the model API accepts a step-count argument.
            save_latent: When True, also return the encoder latent (on CPU).

        Returns:
            1-D numpy waveform, or (waveform, latent) when save_latent is True.
        """
        signal = self.load_audio(audio_path, duration=None, offset=0.0)
        # AudioSignal.audio_data — presumably [batch, channels, samples];
        # the squeeze below only fires for 2-D data. TODO confirm shape.
        audio_tensor = signal.audio_data
        if audio_tensor.dim() == 2:
            audio_tensor = audio_tensor.squeeze(0)
        audio_tensor = audio_tensor.to(self.device)
        print(f"Input shape: {audio_tensor.shape}")
        print(f"Full audio duration: {audio_tensor.shape[-1] / self.sample_rate:.2f}s")
        with torch.no_grad():
            data = {'inp': audio_tensor}
            print('shape of audio_tensor: ', audio_tensor.shape)
            z = self.model.encode(audio_tensor)
            print(f"Latent shape: {z.shape}")
            # Some model variants expose a separate latent-decoder stage.
            z_dec = self.model.decode(z) if hasattr(self.model, 'decode') else z
            print(f"Decoded latent shape: {z_dec.shape}")
            if hasattr(self.model, 'render'):
                print('using render diffusion model')
                reconstructed = self.model.render(z_dec)
            else:
                # Fallback path when the model has no render() method.
                reconstructed = self.model(data, mode='pred')
        # Drop batch and channel dimensions -> 1-D numpy waveform.
        reconstructed = reconstructed.squeeze(0).squeeze(0).cpu().numpy()
        print('shape of reconstructed: ', reconstructed.shape)
        if save_latent:
            return reconstructed, z.cpu()
        return reconstructed

    def save_reconstruction(self, audio_path, output_path, num_steps=50):
        """Reconstruct *audio_path* and write the result to *output_path*."""
        reconstructed = self.reconstruct_audio(audio_path, num_steps)
        self.save_audio(reconstructed, output_path)

    def compare_reconstruction(self, audio_path, output_path, num_steps=50):
        """Write original followed by its reconstruction, separated by 0.5 s of silence.

        Bug fix: reconstruct_audio() returns a numpy array, so the previous
        `reconstructed.audio_data` access raised AttributeError, and the
        AudioSignal object handed to save_audio() was not writable by
        sf.write. Everything is now flattened to 1-D float32 numpy before
        concatenation along the time axis.
        """
        original = self.load_audio(audio_path, duration=None, offset=0.0)
        reconstructed = self.reconstruct_audio(audio_path, num_steps)
        original_np = original.audio_data.squeeze().cpu().numpy()
        silence = np.zeros(int(0.5 * self.sample_rate), dtype=np.float32)
        comparison = np.concatenate([
            np.asarray(original_np, dtype=np.float32).ravel(),
            silence,
            np.asarray(reconstructed, dtype=np.float32).ravel(),
        ])
        self.save_audio(comparison, output_path)
        print(f"Saved comparison (original + reconstruction) to {output_path}")

    def visualize_latent(self, audio_path, output_path):
        """Render the latent representation of the full audio as an image file."""
        _, z = self.reconstruct_audio(audio_path, save_latent=True)
        z_np = z.squeeze(0).numpy()  # drop batch dimension
        if z_np.ndim == 2:  # [channels, frames] -> one strip per channel
            n_channels = z_np.shape[0]
            fig, axes = plt.subplots(n_channels, 1, figsize=(12, 2 * n_channels))
            if n_channels == 1:
                axes = [axes]  # subplots() returns a bare Axes for n=1
            for i in range(n_channels):
                im = axes[i].imshow(
                    z_np[i:i + 1],
                    aspect='auto',
                    cmap='coolwarm',
                    interpolation='nearest'
                )
                axes[i].set_title(f'Latent Channel {i + 1}')
                axes[i].set_xlabel('Time Frames')
                axes[i].set_ylabel('Feature')
                plt.colorbar(im, ax=axes[i])
        else:  # 1-D latent: plot as curves
            plt.figure(figsize=(12, 4))
            plt.plot(z_np.T)
            plt.title('Latent Representation')
            plt.xlabel('Time Frames')
            plt.ylabel('Value')
        plt.tight_layout()
        plt.savefig(output_path, dpi=150)
        plt.close()
        print(f"Saved latent visualization to {output_path}")

    def batch_reconstruct(self, audio_folder, output_folder, max_files=None, num_steps=50):
        """Reconstruct every audio file found directly inside *audio_folder*.

        Args:
            audio_folder: Directory to scan (non-recursive).
            output_folder: Directory for 'recon_<stem>.wav' outputs (created).
            max_files: Optional cap on the number of files processed.
            num_steps: Diffusion step count forwarded to save_reconstruction.
        """
        audio_folder = Path(audio_folder)
        output_folder = Path(output_folder)
        output_folder.mkdir(exist_ok=True, parents=True)
        audio_extensions = ['.wav', '.mp3', '.flac', '.m4a', '.ogg']
        # De-duplicate: on case-insensitive filesystems the lower- and
        # upper-case globs match the same files. Sorting makes runs stable.
        audio_paths = sorted({
            p
            for ext in audio_extensions
            for pattern in (f'*{ext}', f'*{ext.upper()}')
            for p in audio_folder.glob(pattern)
        })
        if max_files:
            audio_paths = audio_paths[:max_files]
        print(f"Processing {len(audio_paths)} audio files...")
        for audio_path in audio_paths:
            output_path = output_folder / f"recon_{audio_path.stem}.wav"
            try:
                self.save_reconstruction(
                    str(audio_path), str(output_path),
                    num_steps=num_steps
                )
            except Exception as e:
                # Best effort: log and keep going so one bad file does not
                # abort the whole batch.
                print(f"Error processing {audio_path}: {e}")
                continue
        print("Batch reconstruction complete!")
def main():
    """CLI entry point: parse arguments and dispatch to the requested mode."""
    parser = argparse.ArgumentParser(description='Audio DiTo Inference')
    parser.add_argument('--checkpoint', type=str, required=True,
                        help='Path to Audio DiTo checkpoint')
    parser.add_argument('--input', type=str, required=True,
                        help='Input audio path or folder')
    parser.add_argument('--output', type=str, required=True,
                        help='Output path')
    parser.add_argument('--compare', action='store_true',
                        help='Save comparison with original')
    parser.add_argument('--batch', action='store_true',
                        help='Process entire folder')
    parser.add_argument('--visualize', action='store_true',
                        help='Visualize latent representation')
    parser.add_argument('--steps', type=int, default=50,
                        help='Number of diffusion steps')
    parser.add_argument('--device', type=str, default='cuda',
                        help='Device to use (cuda/cpu)')
    parser.add_argument('--max-files', type=int, default=None,
                        help='Maximum files to process in batch mode')
    args = parser.parse_args()

    engine = AudioDiToInference(args.checkpoint, device=args.device)

    # Mode precedence: batch > visualize > compare > single reconstruction.
    if args.batch:
        engine.batch_reconstruct(
            args.input, args.output,
            max_files=args.max_files,
            num_steps=args.steps
        )
        return
    if args.visualize:
        engine.visualize_latent(args.input, args.output)
        return
    if args.compare:
        engine.compare_reconstruction(
            args.input, args.output,
            num_steps=args.steps
        )
        return
    engine.save_reconstruction(
        args.input, args.output,
        num_steps=args.steps
    )
# Example usage function for direct Python use
def reconstruct_single_audio(checkpoint_path, audio_path, output_path):
    """Convenience helper: load a checkpoint and reconstruct one audio file.

    Intended for direct use from Python rather than the CLI.
    """
    engine = AudioDiToInference(checkpoint_path)
    engine.save_reconstruction(audio_path, output_path)
if __name__ == "__main__":
main()
# Usage examples:
# 1. Single audio reconstruction (full audio):
# python audio_dito_inference.py --checkpoint ckpt-best.pth --input audio.wav --output recon.wav
#
# 2. Save comparison (original + reconstruction):
# python audio_dito_inference.py --checkpoint ckpt-best.pth --input audio.wav --output compare.wav --compare
#
# 3. Batch processing (reconstruct all audio files in folder):
# python audio_dito_inference.py --checkpoint ckpt-best.pth --input audio_folder/ --output output_folder/ --batch
#
# 4. Visualize latent representation:
# python audio_dito_inference.py --checkpoint ckpt-best.pth --input audio.wav --output latent.png --visualize
#
# 5. Use fewer diffusion steps for faster inference:
# python audio_dito_inference.py --checkpoint ckpt-best.pth --input audio.wav --output recon.wav --steps 25 |