import torch
import torch.nn as nn
import torchaudio
import numpy as np
from pathlib import Path
import argparse
import soundfile as sf
from omegaconf import OmegaConf
import matplotlib.pyplot as plt

# Import models
import models
from models.ldm.dac.audiotools import AudioSignal


class AudioDiToInference:
    def __init__(self, checkpoint_path, device='cuda'):
        """Initialize Audio DiTo model from checkpoint"""
        self.device = device
        
        # Load checkpoint
        print(f"Loading checkpoint from {checkpoint_path}")
        ckpt = torch.load(checkpoint_path, map_location='cpu')
        
        # Extract config
        self.config = OmegaConf.create(ckpt['config'])
        
        # Create model
        self.model = models.make(self.config['model'])
        
        # Load state dict
        self.model.load_state_dict(ckpt['model']['sd'])
        
        # Move to device and set to eval
        self.model = self.model.to(device)
        self.model.eval()
        
        # Get audio parameters from config
        self.sample_rate = self.config.get('sample_rate', 24000)
        self.mono = self.config.get('mono', True)
        
        print(f"Model loaded successfully!")
        print(f"Sample rate: {self.sample_rate} Hz")
        print(f"Mono: {self.mono}")
        
    def load_audio(self, audio_path, duration=None, offset=0.0):
        """Load audio file using AudioSignal
        
        Args:
            audio_path: Path to audio file
            duration: Duration in seconds (None for full audio)
            offset: Start offset in seconds
        """
        # Load audio using AudioSignal
        if duration is not None:
            signal = AudioSignal(
                str(audio_path),
                duration=duration,
                offset=offset,
            )
        else:
            # Load full audio
            signal = AudioSignal(str(audio_path))
        
        # Convert to mono if needed
        if self.mono and signal.num_channels > 1:
            signal = signal.to_mono()
        
        # Resample to model sample rate
        if signal.sample_rate != self.sample_rate:
            signal = signal.resample(self.sample_rate)
        
        # Normalize
        signal = signal.normalize()
        
        # Clamp to [-1, 1]
        signal.audio_data = signal.audio_data.clamp(-1.0, 1.0)
        
        return signal
    
    def save_audio(self, reconstructed, output_path):
        """Save a 1-D numpy array of samples to an audio file"""
        print(f"Writing audio with shape: {reconstructed.shape}")
        sf.write(output_path, reconstructed, self.sample_rate)
        print(f"Saved audio to {output_path}")
    
    def reconstruct_audio(self, audio_path, num_steps=50, save_latent=False):
        """Reconstruct an entire audio file in one pass
        
        Args:
            audio_path: Path to audio file
            num_steps: Number of diffusion steps (not currently forwarded to
                render(); wire it through if your model exposes a step count)
            save_latent: Whether to also return the latent representation
        """
        # Load full audio without duration limit
        signal = self.load_audio(audio_path, duration=None, offset=0.0)
        
        # AudioSignal.audio_data is [batch, channels, samples]
        audio_tensor = signal.audio_data
        if audio_tensor.dim() == 2:
            # Add batch dim if a raw [channels, samples] tensor slipped through
            audio_tensor = audio_tensor.unsqueeze(0)
        audio_tensor = audio_tensor.to(self.device)  # [1, channels, samples]
        
        print(f"Input shape: {audio_tensor.shape}")
        print(f"Full audio duration: {audio_tensor.shape[-1] / self.sample_rate:.2f}s")
        
        with torch.no_grad():
            # Data dict for the fallback prediction path below
            data = {'inp': audio_tensor}
            
            # Step 1: Encode to latent
            z = self.model.encode(audio_tensor)
            print(f"Latent shape: {z.shape}")
            
            # Step 2: Decode latent (if model has separate decode step)
            if hasattr(self.model, 'decode'):
                z_dec = self.model.decode(z)
            else:
                z_dec = z
            print(f"Decoded latent shape: {z_dec.shape}")
            
            # Step 3: Render the waveform from the decoded latent
            if hasattr(self.model, 'render'):
                print('Using render() for diffusion decoding')
                reconstructed = self.model.render(z_dec)
            else:
                # Fallback: direct prediction if render() is not available
                reconstructed = self.model(data, mode='pred')
        
        # Drop batch/channel dims -> 1-D numpy array of samples
        reconstructed = reconstructed.squeeze().cpu().numpy()
        print(f"Reconstructed shape: {reconstructed.shape}")
        
        if save_latent:
            return reconstructed, z.cpu()
        else:
            return reconstructed
    
    def save_reconstruction(self, audio_path, output_path, num_steps=50):
        """Reconstruct and save entire audio file"""
        reconstructed = self.reconstruct_audio(audio_path, num_steps)
        self.save_audio(reconstructed, output_path)
    
    def compare_reconstruction(self, audio_path, output_path, num_steps=50):
        """Save original and reconstruction concatenated"""
        # Load original full audio
        original = self.load_audio(audio_path, duration=None, offset=0.0)
        
        # Reconstruction of the full audio (1-D numpy array of samples)
        reconstructed = self.reconstruct_audio(audio_path, num_steps)
        
        # Flatten the original to a matching 1-D numpy array
        original_np = original.audio_data.squeeze().cpu().numpy()
        
        # 0.5 seconds of silence between the two clips
        silence = np.zeros(int(0.5 * self.sample_rate), dtype=original_np.dtype)
        
        # Concatenate: original -> silence -> reconstruction
        comparison = np.concatenate([original_np, silence, reconstructed])
        
        self.save_audio(comparison, output_path)
        print(f"Saved comparison (original + reconstruction) to {output_path}")
    
    def visualize_latent(self, audio_path, output_path):
        """Visualize the latent representation of full audio"""
        # Get latent
        _, z = self.reconstruct_audio(audio_path, save_latent=True)
        
        z_np = z.squeeze(0).numpy()  # Remove batch dimension
        
        # Create visualization
        if z_np.ndim == 2:  # [channels, frames]
            n_channels = z_np.shape[0]
            fig, axes = plt.subplots(n_channels, 1, figsize=(12, 2*n_channels))
            
            if n_channels == 1:
                axes = [axes]
            
            for i in range(n_channels):
                im = axes[i].imshow(
                    z_np[i:i+1], 
                    aspect='auto', 
                    cmap='coolwarm',
                    interpolation='nearest'
                )
                axes[i].set_title(f'Latent Channel {i+1}')
                axes[i].set_xlabel('Time Frames')
                axes[i].set_ylabel('Feature')
                plt.colorbar(im, ax=axes[i])
        else:  # 1D latent
            plt.figure(figsize=(12, 4))
            plt.plot(z_np.T)
            plt.title('Latent Representation')
            plt.xlabel('Time Frames')
            plt.ylabel('Value')
        
        plt.tight_layout()
        plt.savefig(output_path, dpi=150)
        plt.close()
        
        print(f"Saved latent visualization to {output_path}")
    
    def batch_reconstruct(self, audio_folder, output_folder, max_files=None, num_steps=50):
        """Reconstruct all audio files in a folder (full audio)"""
        audio_folder = Path(audio_folder)
        output_folder = Path(output_folder)
        output_folder.mkdir(exist_ok=True, parents=True)
        
        # Get all audio files
        audio_extensions = ['.wav', '.mp3', '.flac', '.m4a', '.ogg']
        audio_paths = []
        for ext in audio_extensions:
            audio_paths.extend(audio_folder.glob(f'*{ext}'))
            audio_paths.extend(audio_folder.glob(f'*{ext.upper()}'))
        # De-duplicate (case-insensitive filesystems match both globs) and sort
        audio_paths = sorted(set(audio_paths))
        
        if max_files:
            audio_paths = audio_paths[:max_files]
        
        print(f"Processing {len(audio_paths)} audio files...")
        
        for audio_path in audio_paths:
            output_path = output_folder / f"recon_{audio_path.stem}.wav"
            try:
                self.save_reconstruction(
                    str(audio_path), str(output_path), 
                    num_steps=num_steps
                )
            except Exception as e:
                print(f"Error processing {audio_path}: {e}")
                continue
        
        print("Batch reconstruction complete!")


def main():
    parser = argparse.ArgumentParser(description='Audio DiTo Inference')
    parser.add_argument('--checkpoint', type=str, required=True,
                        help='Path to Audio DiTo checkpoint')
    parser.add_argument('--input', type=str, required=True,
                        help='Input audio path or folder')
    parser.add_argument('--output', type=str, required=True,
                        help='Output path')
    parser.add_argument('--compare', action='store_true',
                        help='Save comparison with original')
    parser.add_argument('--batch', action='store_true',
                        help='Process entire folder')
    parser.add_argument('--visualize', action='store_true',
                        help='Visualize latent representation')
    parser.add_argument('--steps', type=int, default=50,
                        help='Number of diffusion steps')
    parser.add_argument('--device', type=str, default='cuda',
                        help='Device to use (cuda/cpu)')
    parser.add_argument('--max-files', type=int, default=None,
                        help='Maximum files to process in batch mode')
    
    args = parser.parse_args()
    
    # Initialize model
    audio_dito = AudioDiToInference(args.checkpoint, device=args.device)
    
    # Process based on mode
    if args.batch:
        # Batch processing
        audio_dito.batch_reconstruct(
            args.input, args.output, 
            max_files=args.max_files,
            num_steps=args.steps
        )
    elif args.visualize:
        # Visualize latent
        audio_dito.visualize_latent(
            args.input, args.output
        )
    elif args.compare:
        # Save comparison
        audio_dito.compare_reconstruction(
            args.input, args.output,
            num_steps=args.steps
        )
    else:
        # Single reconstruction
        audio_dito.save_reconstruction(
            args.input, args.output,
            num_steps=args.steps
        )


# Example usage function for direct Python use
def reconstruct_single_audio(checkpoint_path, audio_path, output_path):
    """Simple function to reconstruct a single audio file"""
    audio_dito = AudioDiToInference(checkpoint_path)
    audio_dito.save_reconstruction(audio_path, output_path)
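

# Companion wrapper (a minimal sketch built on the class's own
# compare_reconstruction method): saves original + reconstruction back to
# back for quick A/B listening.
def compare_single_audio(checkpoint_path, audio_path, output_path, num_steps=50):
    """Simple function to save an original/reconstruction comparison"""
    audio_dito = AudioDiToInference(checkpoint_path)
    audio_dito.compare_reconstruction(audio_path, output_path, num_steps=num_steps)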


if __name__ == "__main__":
    main()


# Usage examples:
# 1. Single audio reconstruction (full audio):
#    python audio_dito_inference.py --checkpoint ckpt-best.pth --input audio.wav --output recon.wav
#
# 2. Save comparison (original + reconstruction):
#    python audio_dito_inference.py --checkpoint ckpt-best.pth --input audio.wav --output compare.wav --compare
#
# 3. Batch processing (reconstruct all audio files in folder):
#    python audio_dito_inference.py --checkpoint ckpt-best.pth --input audio_folder/ --output output_folder/ --batch
#
# 4. Visualize latent representation:
#    python audio_dito_inference.py --checkpoint ckpt-best.pth --input audio.wav --output latent.png --visualize
#
# 5. Use fewer diffusion steps for faster inference:
#    python audio_dito_inference.py --checkpoint ckpt-best.pth --input audio.wav --output recon.wav --steps 25
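#
# 6. Programmatic use from Python (a minimal sketch; the module name and
#    paths below are placeholders matching the examples above):
#    from audio_dito_inference import AudioDiToInference
#    dito = AudioDiToInference('ckpt-best.pth', device='cuda')
#    recon, z = dito.reconstruct_audio('audio.wav', save_latent=True)
#    dito.save_audio(recon, 'recon.wav')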