|
|
import torch |
|
|
import torchaudio |
|
|
from pathlib import Path |
|
|
import argparse |
|
|
from tqdm import tqdm |
|
|
from acestep.music_dcae.music_dcae_pipeline import MusicDCAE |
|
|
|
|
|
class AudioVAE:
    """Thin wrapper around MusicDCAE for encoding/decoding 48 kHz audio.

    Latents are normalized to roughly zero mean / unit std per channel using
    fixed statistics so downstream models see a standardized distribution.
    """

    def __init__(self, device: torch.device):
        self.model = MusicDCAE().to(device)
        self.model.eval()
        self.device = device
        # Fixed per-channel normalization statistics for the 8 latent
        # channels; reshaped to (1, C, 1, 1) for broadcasting over
        # (batch, channel, H, W) latents.
        self.latent_mean = torch.tensor(
            [0.1207, -0.0186, -0.0947, -0.3779, 0.5956, 0.3422, 0.1796, -0.0526],
            device=device,
        ).view(1, -1, 1, 1)
        self.latent_std = torch.tensor(
            [0.4638, 0.3154, 0.6244, 1.5078, 0.4696, 0.4633, 0.5614, 0.2707],
            device=device,
        ).view(1, -1, 1, 1)

    def encode(self, audio: torch.Tensor) -> torch.Tensor:
        """Encode a batch of audio into normalized latents.

        Args:
            audio: Tensor of shape (batch, channels, samples) on self.device.

        Returns:
            Normalized latent tensor (mean/std statistics applied).
        """
        with torch.no_grad():
            # Every item in the batch shares the same sample length, so build
            # the lengths tensor directly on the target device instead of
            # creating it on CPU and copying it over.
            audio_lengths = torch.full(
                (audio.shape[0],), audio.shape[2], device=self.device
            )
            latents, _ = self.model.encode(audio, audio_lengths, sr=48000)
            latents = (latents - self.latent_mean) / self.latent_std
            return latents

    def decode(self, latents: torch.Tensor) -> torch.Tensor:
        """Decode normalized latents back into a batch of audio.

        Args:
            latents: Normalized latent tensor as produced by `encode`.

        Returns:
            Tensor of decoded audio, stacked along the batch dimension and
            placed on self.device.
        """
        with torch.no_grad():
            # Undo the normalization applied in `encode` before decoding.
            latents = latents * self.latent_std + self.latent_mean
            _, audio_list = self.model.decode(latents, sr=48000)
            audio_batch = torch.stack(audio_list).to(self.device)
            return audio_batch
|
|
|
|
|
def load_audio(audio_path, target_sr: int = 48000) -> torch.Tensor:
    """Load an audio file and return it as a stereo tensor at `target_sr`.

    Args:
        audio_path: Path to the audio file (anything torchaudio can load).
        target_sr: Desired output sample rate in Hz.

    Returns:
        Tensor of shape (2, samples): mono inputs are duplicated to stereo,
        inputs with more than two channels are truncated to the first two.
    """
    audio, sr = torchaudio.load(audio_path)

    # Resample before any channel duplication so that a mono file is only
    # resampled once instead of twice (identical output, half the work).
    if sr != target_sr:
        resampler = torchaudio.transforms.Resample(sr, target_sr)
        audio = resampler(audio)

    if audio.shape[0] == 1:
        # Mono -> stereo by duplicating the single channel.
        audio = audio.repeat(2, 1)
    elif audio.shape[0] > 2:
        # Multichannel -> keep only the first two channels.
        audio = audio[:2]

    return audio
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: encode every audio file in a directory to .pt latents.

    Each input file produces `<output-dir>/<stem>.pt` containing the encoded
    (CPU) latent tensor. Files that fail to encode are reported and skipped.
    """
    parser = argparse.ArgumentParser(description='Encode audio files to VAE latents')

    parser.add_argument('--audio-dir', type=str, required=True,
                        help='Directory containing audio files')
    parser.add_argument('--output-dir', type=str, default="latents",
                        help='Directory to save encoded latents')

    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    audio_dir = Path(args.audio_dir)
    # Match extensions case-insensitively so files like "TRACK.WAV" or
    # "song.MP3" are picked up too (Path.glob patterns are case-sensitive).
    audio_extensions = {'.mp3', '.wav', '.flac', '.ogg', '.m4a'}
    audio_files = sorted(
        p for p in audio_dir.iterdir()
        if p.is_file() and p.suffix.lower() in audio_extensions
    )

    if not audio_files:
        raise ValueError(f"No audio files found in {args.audio_dir}")

    print(f"Found {len(audio_files)} audio files")

    vae = AudioVAE(device)
    print("VAE loaded")

    print("\nEncoding audio files...")
    for audio_path in tqdm(audio_files, desc="Encoding"):
        try:
            audio = load_audio(audio_path)
            # Add a batch dimension for the encoder, then strip it again.
            audio = audio.unsqueeze(0).to(device)
            latents = vae.encode(audio)
            latents = latents.squeeze(0)

            output_path = output_dir / f"{audio_path.stem}.pt"
            torch.save(latents.cpu(), output_path)

        except Exception as e:
            # Best-effort batch job: report the failure and keep going so one
            # corrupt file does not abort the whole run.
            print(f"\nError encoding {audio_path.name}: {e}")
            continue

    print(f"\nEncoding complete! Saved {len(list(output_dir.glob('*.pt')))} latent files to {output_dir}")
|
|
|
|
|
# Run only when executed as a script, not when imported as a module.
if __name__ == '__main__':


    main()
|
|
|