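"""Inference script for an Audio DiTo checkpoint.

Loads a trained model from a checkpoint, reconstructs audio files through
its encode -> decode -> render pipeline, and provides comparison, latent
visualization, and batch-processing modes.
"""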
import argparse
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
import torch
from omegaconf import OmegaConf

# Project-local imports
import models
from models.ldm.dac.audiotools import AudioSignal
class AudioDiToInference:
    def __init__(self, checkpoint_path, device='cuda'):
        """Initialize the Audio DiTo model from a checkpoint"""
        self.device = device

        # Load checkpoint
        print(f"Loading checkpoint from {checkpoint_path}")
        ckpt = torch.load(checkpoint_path, map_location='cpu')

        # Extract config
        self.config = OmegaConf.create(ckpt['config'])

        # Create model and load weights
        self.model = models.make(self.config['model'])
        self.model.load_state_dict(ckpt['model']['sd'])

        # Move to device and set to eval mode
        self.model = self.model.to(device)
        self.model.eval()

        # Get audio parameters from config
        self.sample_rate = self.config.get('sample_rate', 24000)
        self.mono = self.config.get('mono', True)

        print("Model loaded successfully!")
        print(f"Sample rate: {self.sample_rate} Hz")
        print(f"Mono: {self.mono}")
    def load_audio(self, audio_path, duration=None, offset=0.0):
        """Load an audio file using AudioSignal

        Args:
            audio_path: Path to audio file
            duration: Duration in seconds (None for full audio)
            offset: Start offset in seconds
        """
        if duration is not None:
            signal = AudioSignal(
                str(audio_path),
                duration=duration,
                offset=offset,
            )
        else:
            # Load the full file
            signal = AudioSignal(str(audio_path))

        # Convert to mono if needed
        if self.mono and signal.num_channels > 1:
            signal = signal.to_mono()

        # Resample to the model's sample rate
        if signal.sample_rate != self.sample_rate:
            signal = signal.resample(self.sample_rate)

        # Normalize, then clamp to [-1, 1]
        signal = signal.normalize()
        signal.audio_data = signal.audio_data.clamp(-1.0, 1.0)

        return signal
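    # Note (assumption): audiotools' AudioSignal stores audio_data as a 3-D
    # tensor of shape [batch, channels, samples]; the methods below rely on
    # that convention when squeezing tensors down to plain waveforms.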
    def save_audio(self, reconstructed, output_path):
        """Save a 1-D numpy waveform to file"""
        print(f"Shape of reconstructed: {reconstructed.shape}")
        sf.write(output_path, reconstructed, self.sample_rate)
        print(f"Saved audio to {output_path}")
    def reconstruct_audio(self, audio_path, num_steps=50, save_latent=False):
        """Reconstruct an entire audio file in one pass

        Args:
            audio_path: Path to audio file
            num_steps: Number of diffusion steps (see the note below on
                forwarding this to the model)
            save_latent: Whether to also return the latent representation
        """
        # Load full audio without duration limit
        signal = self.load_audio(audio_path, duration=None, offset=0.0)

        # AudioSignal keeps audio_data as [batch, channels, samples];
        # collapse the channel dimension for mono input -> [batch, samples]
        audio_tensor = signal.audio_data
        if audio_tensor.dim() == 3:
            audio_tensor = audio_tensor.squeeze(1)
        audio_tensor = audio_tensor.to(self.device)

        print(f"Input shape: {tuple(audio_tensor.shape)}")
        print(f"Full audio duration: {audio_tensor.shape[-1] / self.sample_rate:.2f}s")

        with torch.no_grad():
            # Step 1: Encode to latent
            z = self.model.encode(audio_tensor)
            print(f"Latent shape: {tuple(z.shape)}")

            # Step 2: Decode latent (if the model has a separate decode step)
            if hasattr(self.model, 'decode'):
                z_dec = self.model.decode(z)
            else:
                z_dec = z
            print(f"Decoded latent shape: {tuple(z_dec.shape)}")

            # Step 3: Render the waveform from the decoded latent
            if hasattr(self.model, 'render'):
                print('Rendering with diffusion model')
                # NOTE: num_steps is not forwarded here; pass it through if
                # your model's render() exposes a step-count argument.
                reconstructed = self.model.render(z_dec)
            else:
                # Fallback: direct forward pass if render() is unavailable
                reconstructed = self.model({'inp': audio_tensor}, mode='pred')

        # Drop batch/channel dimensions -> 1-D numpy waveform [samples]
        reconstructed = reconstructed.squeeze().cpu().numpy()
        print(f"Reconstructed shape: {reconstructed.shape}")

        if save_latent:
            return reconstructed, z.cpu()
        return reconstructed
    def save_reconstruction(self, audio_path, output_path, num_steps=50):
        """Reconstruct an entire audio file and save the result"""
        reconstructed = self.reconstruct_audio(audio_path, num_steps)
        self.save_audio(reconstructed, output_path)
    def compare_reconstruction(self, audio_path, output_path, num_steps=50):
        """Save the original and the reconstruction concatenated, separated by silence"""
        # Load original full audio
        original = self.load_audio(audio_path, duration=None, offset=0.0)

        # Get reconstruction of the full audio (1-D numpy array)
        reconstructed = self.reconstruct_audio(audio_path, num_steps)

        # 0.5 second of silence between the clips
        silence_samples = int(0.5 * self.sample_rate)
        silence = np.zeros(silence_samples, dtype=np.float32)

        # Flatten the original to a 1-D waveform to match the reconstruction
        original_np = original.audio_data.squeeze().cpu().numpy()

        # Concatenate: original -> silence -> reconstruction
        comparison = np.concatenate([original_np, silence, reconstructed])

        self.save_audio(comparison, output_path)
        print(f"Saved comparison (original + reconstruction) to {output_path}")
    def visualize_latent(self, audio_path, output_path):
        """Visualize the latent representation of a full audio file"""
        # Get the latent
        _, z = self.reconstruct_audio(audio_path, save_latent=True)
        z_np = z.squeeze(0).numpy()  # Remove batch dimension

        # Create visualization
        if z_np.ndim == 2:  # [channels, frames]
            n_channels = z_np.shape[0]
            fig, axes = plt.subplots(n_channels, 1, figsize=(12, 2 * n_channels))
            if n_channels == 1:
                axes = [axes]
            for i in range(n_channels):
                im = axes[i].imshow(
                    z_np[i:i + 1],
                    aspect='auto',
                    cmap='coolwarm',
                    interpolation='nearest'
                )
                axes[i].set_title(f'Latent Channel {i + 1}')
                axes[i].set_xlabel('Time Frames')
                axes[i].set_ylabel('Feature')
                plt.colorbar(im, ax=axes[i])
        else:  # 1-D latent
            plt.figure(figsize=(12, 4))
            plt.plot(z_np.T)
            plt.title('Latent Representation')
            plt.xlabel('Time Frames')
            plt.ylabel('Value')

        plt.tight_layout()
        plt.savefig(output_path, dpi=150)
        plt.close()
        print(f"Saved latent visualization to {output_path}")
    def batch_reconstruct(self, audio_folder, output_folder, max_files=None, num_steps=50):
        """Reconstruct all audio files in a folder (full audio)"""
        audio_folder = Path(audio_folder)
        output_folder = Path(output_folder)
        output_folder.mkdir(exist_ok=True, parents=True)

        # Collect all audio files, deduplicated and sorted (globbing both
        # cases can return the same file twice on case-insensitive filesystems)
        audio_extensions = ['.wav', '.mp3', '.flac', '.m4a', '.ogg']
        audio_paths = []
        for ext in audio_extensions:
            audio_paths.extend(audio_folder.glob(f'*{ext}'))
            audio_paths.extend(audio_folder.glob(f'*{ext.upper()}'))
        audio_paths = sorted(set(audio_paths))

        if max_files:
            audio_paths = audio_paths[:max_files]

        print(f"Processing {len(audio_paths)} audio files...")
        for audio_path in audio_paths:
            output_path = output_folder / f"recon_{audio_path.stem}.wav"
            try:
                self.save_reconstruction(
                    str(audio_path), str(output_path),
                    num_steps=num_steps
                )
            except Exception as e:
                print(f"Error processing {audio_path}: {e}")
                continue

        print("Batch reconstruction complete!")
def main():
    parser = argparse.ArgumentParser(description='Audio DiTo Inference')
    parser.add_argument('--checkpoint', type=str, required=True,
                        help='Path to Audio DiTo checkpoint')
    parser.add_argument('--input', type=str, required=True,
                        help='Input audio path or folder')
    parser.add_argument('--output', type=str, required=True,
                        help='Output path')
    parser.add_argument('--compare', action='store_true',
                        help='Save comparison with original')
    parser.add_argument('--batch', action='store_true',
                        help='Process an entire folder')
    parser.add_argument('--visualize', action='store_true',
                        help='Visualize latent representation')
    parser.add_argument('--steps', type=int, default=50,
                        help='Number of diffusion steps')
    parser.add_argument('--device', type=str, default='cuda',
                        help='Device to use (cuda/cpu)')
    parser.add_argument('--max-files', type=int, default=None,
                        help='Maximum files to process in batch mode')
    args = parser.parse_args()

    # Initialize model
    audio_dito = AudioDiToInference(args.checkpoint, device=args.device)

    # Dispatch based on mode
    if args.batch:
        # Batch processing
        audio_dito.batch_reconstruct(
            args.input, args.output,
            max_files=args.max_files,
            num_steps=args.steps
        )
    elif args.visualize:
        # Visualize latent
        audio_dito.visualize_latent(args.input, args.output)
    elif args.compare:
        # Save comparison
        audio_dito.compare_reconstruction(
            args.input, args.output,
            num_steps=args.steps
        )
    else:
        # Single reconstruction
        audio_dito.save_reconstruction(
            args.input, args.output,
            num_steps=args.steps
        )
# Convenience function for direct Python use
def reconstruct_single_audio(checkpoint_path, audio_path, output_path):
    """Reconstruct a single audio file in one call"""
    audio_dito = AudioDiToInference(checkpoint_path)
    audio_dito.save_reconstruction(audio_path, output_path)
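# Example (file names are placeholders):
#   reconstruct_single_audio('ckpt-best.pth', 'audio.wav', 'recon.wav')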
if __name__ == "__main__":
    main()
# Usage examples:
# 1. Single audio reconstruction (full audio):
#    python audio_dito_inference.py --checkpoint ckpt-best.pth --input audio.wav --output recon.wav
#
# 2. Save comparison (original + reconstruction):
#    python audio_dito_inference.py --checkpoint ckpt-best.pth --input audio.wav --output compare.wav --compare
#
# 3. Batch processing (reconstruct all audio files in a folder):
#    python audio_dito_inference.py --checkpoint ckpt-best.pth --input audio_folder/ --output output_folder/ --batch
#
# 4. Visualize latent representation:
#    python audio_dito_inference.py --checkpoint ckpt-best.pth --input audio.wav --output latent.png --visualize
#
# 5. Use fewer diffusion steps for faster inference:
#    python audio_dito_inference.py --checkpoint ckpt-best.pth --input audio.wav --output recon.wav --steps 25