# LocalSong / train_lora_encode_latents.py
# (Hugging Face upload metadata: Localsong's picture — "Upload 5 files" — commit 12bbde9, verified)
import torch
import torchaudio
from pathlib import Path
import argparse
from tqdm import tqdm
from acestep.music_dcae.music_dcae_pipeline import MusicDCAE
class AudioVAE:
    """Thin wrapper around MusicDCAE that normalizes/denormalizes latents.

    The per-channel mean/std constants map raw DCAE latents to roughly
    zero-mean, unit-variance values (and back), which is what downstream
    LoRA training expects.
    """

    def __init__(self, device: torch.device):
        self.device = device
        self.model = MusicDCAE().to(device)
        self.model.eval()
        # Precomputed per-latent-channel statistics, shaped (1, C, 1, 1)
        # so they broadcast over (batch, C, H, W) latents.
        mean_vals = [0.1207, -0.0186, -0.0947, -0.3779, 0.5956, 0.3422, 0.1796, -0.0526]
        std_vals = [0.4638, 0.3154, 0.6244, 1.5078, 0.4696, 0.4633, 0.5614, 0.2707]
        self.latent_mean = torch.tensor(mean_vals, device=device).view(1, -1, 1, 1)
        self.latent_std = torch.tensor(std_vals, device=device).view(1, -1, 1, 1)

    def encode(self, audio):
        """Encode a (batch, channels, samples) waveform batch to normalized latents."""
        with torch.no_grad():
            batch_size, n_samples = audio.shape[0], audio.shape[2]
            # Every item in the batch is assumed to be full length.
            lengths = torch.tensor([n_samples] * batch_size).to(self.device)
            raw_latents, _ = self.model.encode(audio, lengths, sr=48000)
            return (raw_latents - self.latent_mean) / self.latent_std

    def decode(self, latents: torch.Tensor) -> torch.Tensor:
        """Denormalize latents and decode them back into a stacked audio batch."""
        with torch.no_grad():
            denormalized = latents * self.latent_std + self.latent_mean
            _, waveforms = self.model.decode(denormalized, sr=48000)
            return torch.stack(waveforms).to(self.device)
def load_audio(audio_path, target_sr=48000):
    """Load an audio file as a stereo waveform at target_sr.

    Mono input is duplicated to two channels; inputs with more than two
    channels keep only the first two. Resampling happens only when the
    file's native rate differs from target_sr.
    """
    waveform, sample_rate = torchaudio.load(audio_path)
    n_channels = waveform.shape[0]
    if n_channels == 1:
        # Duplicate the single channel to fake stereo.
        waveform = waveform.repeat(2, 1)
    elif n_channels > 2:
        # Drop any channels beyond the first two.
        waveform = waveform[:2]
    if sample_rate != target_sr:
        waveform = torchaudio.transforms.Resample(sample_rate, target_sr)(waveform)
    return waveform
def main():
    """Encode every audio file found in --audio-dir into a .pt latent file.

    Each latent is saved under --output-dir as <stem>.pt; files that fail
    to encode are reported and skipped rather than aborting the run.
    """
    parser = argparse.ArgumentParser(description='Encode audio files to VAE latents')
    parser.add_argument('--audio-dir', type=str, required=True,
                        help='Directory containing audio files')
    parser.add_argument('--output-dir', type=str, default="latents",
                        help='Directory to save encoded latents')
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    source_dir = Path(args.audio_dir)
    patterns = ('*.mp3', '*.wav', '*.flac', '*.ogg', '*.m4a')
    audio_files = sorted(
        match for pattern in patterns for match in source_dir.glob(pattern)
    )
    if not audio_files:
        raise ValueError(f"No audio files found in {args.audio_dir}")
    print(f"Found {len(audio_files)} audio files")

    vae = AudioVAE(device)
    print("VAE loaded")

    # Encode each audio file
    print("\nEncoding audio files...")
    for audio_path in tqdm(audio_files, desc="Encoding"):
        try:
            waveform = load_audio(audio_path).unsqueeze(0).to(device)
            latents = vae.encode(waveform).squeeze(0)
            torch.save(latents.cpu(), output_dir / f"{audio_path.stem}.pt")
        except Exception as e:
            # Best-effort batch job: report the failure and keep going.
            print(f"\nError encoding {audio_path.name}: {e}")
            continue

    print(f"\nEncoding complete! Saved {len(list(output_dir.glob('*.pt')))} latent files to {output_dir}")


if __name__ == '__main__':
    main()