File size: 3,447 Bytes
12bbde9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import torch
import torchaudio
from pathlib import Path
import argparse
from tqdm import tqdm
from acestep.music_dcae.music_dcae_pipeline import MusicDCAE

class AudioVAE:
    """Thin wrapper around MusicDCAE mapping 48 kHz audio to normalized latents and back.

    Latents are normalized per channel with precomputed statistics over the
    8 latent channels so downstream consumers see roughly zero-mean,
    unit-variance inputs.
    """

    def __init__(self, device: torch.device) -> None:
        self.model = MusicDCAE().to(device)
        self.model.eval()
        self.device = device
        # Per-channel statistics, shaped (1, C, 1, 1) to broadcast over
        # (batch, channel, H, W) latents.
        self.latent_mean = torch.tensor(
            [0.1207, -0.0186, -0.0947, -0.3779, 0.5956, 0.3422, 0.1796, -0.0526],
            device=device,
        ).view(1, -1, 1, 1)
        self.latent_std = torch.tensor(
            [0.4638, 0.3154, 0.6244, 1.5078, 0.4696, 0.4633, 0.5614, 0.2707],
            device=device,
        ).view(1, -1, 1, 1)

    def encode(self, audio: torch.Tensor) -> torch.Tensor:
        """Encode a batch of audio into normalized latents.

        Args:
            audio: Waveform batch of shape (batch, channels, samples),
                assumed to be at 48 kHz — confirm with caller.

        Returns:
            Normalized latent tensor (batch, 8, H, W).
        """
        with torch.no_grad():
            # Every item in the batch shares the same length; build the
            # length vector directly on the target device instead of going
            # through a Python list and a CPU tensor.
            audio_lengths = torch.full(
                (audio.shape[0],), audio.shape[2],
                dtype=torch.long, device=self.device,
            )
            latents, _ = self.model.encode(audio, audio_lengths, sr=48000)
            latents = (latents - self.latent_mean) / self.latent_std
        return latents

    def decode(self, latents: torch.Tensor) -> torch.Tensor:
        """Decode normalized latents back to a batch of audio waveforms."""
        with torch.no_grad():
            # Undo the normalization applied in encode() before decoding.
            latents = latents * self.latent_std + self.latent_mean
            _, audio_list = self.model.decode(latents, sr=48000)
            audio_batch = torch.stack(audio_list).to(self.device)
        return audio_batch

def load_audio(audio_path, target_sr: int = 48000) -> torch.Tensor:
    """Load an audio file as a stereo waveform at ``target_sr``.

    Args:
        audio_path: Path to the audio file.
        target_sr: Desired sample rate in Hz (default 48 kHz).

    Returns:
        Tensor of shape (2, samples). Mono input is duplicated to two
        channels; input with more than two channels is truncated to the
        first two.
    """
    audio, sr = torchaudio.load(audio_path)

    if audio.shape[0] == 1:
        # Mono -> stereo by duplicating the single channel.
        audio = audio.repeat(2, 1)
    elif audio.shape[0] > 2:
        # Keep only the first two channels of multi-channel audio.
        audio = audio[:2]

    if sr != target_sr:
        # Functional form avoids constructing a fresh Resample module
        # on every call; it is the same kernel the transform wraps.
        audio = torchaudio.functional.resample(audio, sr, target_sr)

    return audio


def main():
    """CLI entry point: encode every audio file in a directory to latent .pt files."""
    parser = argparse.ArgumentParser(description='Encode audio files to VAE latents')

    parser.add_argument('--audio-dir', type=str, required=True,
                        help='Directory containing audio files')
    parser.add_argument('--output-dir', type=str, default="latents",
                        help='Directory to save encoded latents')

    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    audio_dir = Path(args.audio_dir)
    # Match extensions case-insensitively so e.g. '.WAV' files are not
    # silently skipped, and deduplicate via a single directory scan.
    audio_extensions = {'.mp3', '.wav', '.flac', '.ogg', '.m4a'}
    audio_files = sorted(
        p for p in audio_dir.iterdir()
        if p.is_file() and p.suffix.lower() in audio_extensions
    )

    if not audio_files:
        raise ValueError(f"No audio files found in {args.audio_dir}")

    print(f"Found {len(audio_files)} audio files")

    vae = AudioVAE(device)
    print("VAE loaded")

    # Encode each audio file
    print("\nEncoding audio files...")
    saved = 0  # count only files written by this run, not pre-existing .pt files
    for audio_path in tqdm(audio_files, desc="Encoding"):
        try:
            audio = load_audio(audio_path)
            audio = audio.unsqueeze(0).to(device)  # add batch dimension
            latents = vae.encode(audio)
            latents = latents.squeeze(0)  # drop batch dimension for storage

            output_path = output_dir / f"{audio_path.stem}.pt"
            torch.save(latents.cpu(), output_path)
            saved += 1

        except Exception as e:
            # Best-effort batch job: report the failure and keep going.
            print(f"\nError encoding {audio_path.name}: {e}")
            continue

    print(f"\nEncoding complete! Saved {saved} latent files to {output_dir}")

if __name__ == '__main__':
    main()