import json
import torch
import torchaudio.transforms as T
from torch import nn
from .autoencoders import create_autoencoder_from_config
from .utils import load_ckpt_state_dict
class PadCrop(nn.Module):
    """Pad or crop a (channels, samples) signal to a fixed length.

    If the signal is longer than ``n_samples`` it is cropped (from a random
    offset when ``randomize`` is True, otherwise from the start); if shorter,
    it is right-padded with zeros.
    """

    def __init__(self, n_samples, randomize=True):
        super().__init__()
        self.n_samples = n_samples  # target length in samples
        self.randomize = randomize  # random crop offset vs. always start at 0

    # Defined as forward (not __call__) so nn.Module's call machinery and
    # hooks work as intended; instances are still invoked the same way.
    def forward(self, signal):
        """Return *signal* padded/cropped to shape (channels, n_samples)."""
        n, s = signal.shape
        # A random start only makes sense when the signal exceeds the target
        # length; max(0, ...) keeps randint's upper bound valid otherwise.
        start = (
            0
            if (not self.randomize)
            else torch.randint(0, max(0, s - self.n_samples) + 1, []).item()
        )
        end = start + self.n_samples
        output = signal.new_zeros([n, self.n_samples])
        # Copy min(s, n_samples) samples; any remainder stays zero (padding).
        output[:, : min(s, self.n_samples)] = signal[:, start:end]
        return output
def set_audio_channels(audio, target_channels):
    """Coerce a (batch, channels, samples) tensor to the requested channel count.

    Mono target: average every channel down to one. Stereo target: duplicate
    a mono channel, or keep only the first two of a multichannel signal.
    Any other target (or an already-matching stereo input) passes through
    unchanged.
    """
    if target_channels == 1:
        # Downmix by averaging across the channel axis.
        return audio.mean(1, keepdim=True)
    if target_channels == 2:
        n_channels = audio.shape[1]
        if n_channels == 1:
            # Mono -> stereo: repeat the single channel.
            return audio.repeat(1, 2, 1)
        if n_channels > 2:
            # Multichannel -> stereo: drop everything past the first two.
            return audio[:, :2, :]
    return audio
def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
    """Move audio to *device*, resample, pad/crop, and set channel count.

    Args:
        audio: Waveform tensor, (samples,), (channels, samples), or already
            batched (batch, channels, samples).
        in_sr: Sample rate of *audio*.
        target_sr: Desired sample rate; resampled only when it differs.
        target_length: Desired length in samples, or None to keep the
            (post-resample) length unchanged.
        target_channels: Desired channel count (see set_audio_channels).
        device: Target device for all processing.

    Returns:
        A (batch, target_channels, target_length) tensor on *device*.
    """
    audio = audio.to(device)
    if in_sr != target_sr:
        resample_tf = T.Resample(in_sr, target_sr).to(device)
        audio = resample_tf(audio)
    # Removed leftover `assert target_length is None` debug guard: it made
    # the target_length parameter unusable (and asserts vanish under -O).
    if target_length is None:
        target_length = audio.shape[-1]
    # Promote 1D (samples,) input to (1, samples) so PadCrop can unpack
    # (channels, samples); previously 1D input crashed inside PadCrop.
    if audio.dim() == 1:
        audio = audio.unsqueeze(0)
    audio = PadCrop(target_length, randomize=False)(audio)
    # Add batch dimension
    if audio.dim() == 2:
        audio = audio.unsqueeze(0)
    audio = set_audio_channels(audio, target_channels)
    return audio
class StableAudioInfer(nn.Module):
    """Thin inference wrapper around a Stable Audio VAE autoencoder.

    Builds the autoencoder from a JSON config (optionally loading a
    checkpoint) and exposes encode/decode helpers that handle device
    placement, resampling, peak normalization, and chunked processing for
    long inputs.
    """

    def __init__(self, model_config_path, model_ckpt_path=None):
        """Load config/weights and cache model metadata.

        Args:
            model_config_path: Path to the model's JSON config file.
            model_ckpt_path: Optional checkpoint path; skipped when None.
        """
        super().__init__()
        with open(model_config_path) as f:
            self.model_config = json.load(f)
        self.model = create_autoencoder_from_config(self.model_config)
        if model_ckpt_path is not None:
            self.model.load_state_dict(load_ckpt_state_dict(model_ckpt_path))
        self.sample_rate = self.model_config["sample_rate"]
        self.sample_size = self.model_config["sample_size"]
        self.io_channels = self.model.io_channels
        # NOTE(review): hard-coded override clobbers the config's
        # sample_size read just above — confirm 24576 is intentional.
        self.sample_size = 24576

    @property
    def device(self):
        """Device of this module's parameters."""
        return next(self.parameters()).device

    def normalize_audio(self, y, target_dbfs=0):
        """Scale *y* so its peak amplitude sits at *target_dbfs* dBFS.

        Silent input (all zeros) is returned unchanged instead of
        producing NaNs from a zero division.
        """
        max_amplitude = torch.max(torch.abs(y))
        if max_amplitude == 0:
            # Nothing to scale; avoid 0/0 -> NaN.
            return y
        target_amplitude = 10.0 ** (target_dbfs / 20.0)
        scale_factor = target_amplitude / max_amplitude
        return y * scale_factor

    def encode_audio(self, input_audio, in_sr):
        """Encode audio waveform into VAE latent representation.

        Args:
            input_audio: Input audio tensor.
            in_sr: Input sample rate.

        Returns:
            Latent tensor from the VAE encoder.
        """
        input_audio = prepare_audio(
            input_audio,
            in_sr=in_sr,
            # NOTE(review): uses self.model.sample_rate while __init__ caches
            # self.sample_rate from the config — confirm these agree.
            target_sr=self.model.sample_rate,
            target_length=None,  # Determined after resampling
            target_channels=self.io_channels,
            device=self.device,
        )
        # Normalize peaks to -6 dBFS before encoding.
        input_audio = self.normalize_audio(input_audio, -6)
        with torch.no_grad():
            # Use chunked encoding for long audio (> ~138 s) to save memory.
            if input_audio.shape[-1] > (128 + 10) * self.model.sample_rate:
                latent = self.model.encode_audio(input_audio, chunked=True)
            else:
                latent = self.model.encode_audio(input_audio, chunked=False)
        return latent

    def decode_audio(self, latent):
        """Decode VAE latent back to audio waveform.

        Args:
            latent: Latent tensor.

        Returns:
            Decoded audio tensor.
        """
        with torch.no_grad():
            # Use chunked decoding for long latents (> 138 frames) to save
            # memory; mirrors the threshold used in encode_audio.
            if latent.shape[-1] > 128 + 10:
                output = self.model.decode_audio(latent, chunked=True)
            else:
                output = self.model.decode_audio(latent, chunked=False)
        return output

    def forward(self, func_type, x, sr=None):
        """Dispatch to encode or decode.

        Args:
            func_type: "encode" (requires sr) or "decode".
            x: Audio tensor (encode) or latent tensor (decode).
            sr: Sample rate of x; required for encoding.

        Raises:
            ValueError: If func_type is not "encode" or "decode".
        """
        # Reuse the device property instead of re-deriving it here.
        x = x.to(self.device)
        if func_type == "encode":
            assert sr is not None, "sr is required for encoding"
            return self.encode_audio(input_audio=x, in_sr=sr)
        elif func_type == "decode":
            return self.decode_audio(x)
        else:
            raise ValueError(f"Unknown func_type: {func_type}")
if __name__ == "__main__":
    import torchaudio

    # Demo: round-trip a wav file through the VAE (encode -> decode).
    device = "cuda"
    vae_model = StableAudioInfer(
        model_config_path="config/stable_audio_2_0_vae_20hz_official.json",
        model_ckpt_path="ckpts/stable_audio_2_0_vae_20hz_official.ckpt",
    )
    vae_model = vae_model.eval().to(device)
    input_audio, in_sr = torchaudio.load("path/to/input.wav")
    latent = vae_model(func_type="encode", x=input_audio, sr=in_sr)
    output_audio = vae_model(func_type="decode", x=latent, sr=None)
    # Drop the batch dimension and move to CPU for saving.
    output_audio = output_audio.squeeze(0).cpu()
    # Save at the model's native rate; the previous hard-coded 44100 would
    # silently pitch-shift output for any non-44.1 kHz model config.
    torchaudio.save("output.wav", output_audio, sample_rate=vae_model.sample_rate)