"""Pretransform wrappers for mapping audio to and from latent or token spaces."""
import torch
from einops import rearrange
from torch import nn
class Pretransform(nn.Module):
    """Abstract base for transforms that map audio to a latent space and back.

    Subclasses override ``encode``/``decode`` for continuous latents and,
    when ``is_discrete`` is True, ``tokenize``/``decode_tokens`` for
    token-based models.
    """

    def __init__(self, enable_grad, io_channels, is_discrete):
        super().__init__()
        # True when the latent space is token-based (discrete bottleneck).
        self.is_discrete = is_discrete
        # Number of audio channels at the transform's input/output.
        self.io_channels = io_channels
        # When False the transform is treated as frozen (no gradients).
        self.enable_grad = enable_grad
        # Filled in by subclasses once the latent geometry is known.
        self.encoded_channels = None
        self.downsampling_ratio = None

    def encode(self, x):
        """Map audio to latents; must be overridden."""
        raise NotImplementedError

    def decode(self, z):
        """Map latents back to audio; must be overridden."""
        raise NotImplementedError

    def tokenize(self, x):
        """Map audio to discrete tokens; must be overridden."""
        raise NotImplementedError

    def decode_tokens(self, tokens):
        """Map discrete tokens back to audio; must be overridden."""
        raise NotImplementedError
class AutoencoderPretransform(Pretransform):
    """Pretransform wrapping a frozen autoencoder.

    Maps audio to continuous latents (divided by ``scale``) and back.
    Supports half-precision inference (``model_half``), per-item batch
    iteration (``iterate_batch``), and overlapped chunked processing of
    long signals (``chunked`` / the ``*_audio`` helpers).
    """

    def __init__(
        self, model, scale=1.0, model_half=False, iterate_batch=False, chunked=False
    ):
        super().__init__(
            enable_grad=False,
            io_channels=model.io_channels,
            is_discrete=model.bottleneck is not None and model.bottleneck.is_discrete,
        )
        self.model = model

        # The wrapped autoencoder is inference-only: freeze it.
        self.model.requires_grad_(False).eval()

        self.scale = scale
        self.downsampling_ratio = model.downsampling_ratio
        self.io_channels = model.io_channels
        self.sample_rate = model.sample_rate

        self.model_half = model_half
        self.iterate_batch = iterate_batch

        self.encoded_channels = model.latent_dim
        self.latent_dim = model.latent_dim

        self.chunked = chunked

        # Discrete-bottleneck metadata (None for continuous models).
        self.num_quantizers = (
            model.bottleneck.num_quantizers
            if model.bottleneck is not None and model.bottleneck.is_discrete
            else None
        )
        self.codebook_size = (
            model.bottleneck.codebook_size
            if model.bottleneck is not None and model.bottleneck.is_discrete
            else None
        )

        if self.model_half:
            self.model.half()

    def encode(self, x, **kwargs):
        """Encode audio (B, C, T) into latents, scaled by ``1 / self.scale``."""
        if self.model_half:
            x = x.half()
            self.model.to(torch.float16)

        encoded = self.model.encode_audio(
            x, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs
        )

        if self.model_half:
            encoded = encoded.float()

        return encoded / self.scale

    def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs):
        """
        Encode audios into latents. Audios should already be preprocessed by preprocess_audio_for_encoder.
        If chunked is True, split the audio into chunks of a given maximum size chunk_size, with given overlap.
        Overlap and chunk_size params are both measured in number of latents (not audio samples)
        # and therefore you likely could use the same values with decode_audio.
        A overlap of zero will cause discontinuity artefacts. Overlap should be => receptive field size.
        Every autoencoder will have a different receptive field size, and thus ideal overlap.
        You can determine it empirically by diffing unchunked vs chunked output and looking at maximum diff.
        The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
        Smaller chunk_size uses less memory, but more compute.
        The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version
        For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks
        """
        if not chunked:
            # Default behavior: encode the entire audio in parallel.
            return self.encode(audio, **kwargs)
        else:
            # CHUNKED ENCODING
            # samples_per_latent is just the downsampling ratio (which is also the upsampling ratio)
            samples_per_latent = self.downsampling_ratio
            total_size = audio.shape[2]  # in samples
            batch_size = audio.shape[0]
            chunk_size *= samples_per_latent  # converting metric in latents to samples
            overlap *= samples_per_latent  # converting metric in latents to samples
            hop_size = chunk_size - overlap
            chunks = []
            # Bug fix: initialize i so the final-chunk check below is defined even
            # when the audio is shorter than one chunk and the loop never runs
            # (matches decode_audio, which already guards this case).
            i = 0
            for i in range(0, total_size - chunk_size + 1, hop_size):
                chunk = audio[:, :, i : i + chunk_size]
                chunks.append(chunk)
            if i + chunk_size != total_size:
                # Final chunk: anchored to the end so every chunk is chunk_size long.
                chunk = audio[:, :, -chunk_size:]
                chunks.append(chunk)
            chunks = torch.stack(chunks)
            num_chunks = chunks.shape[0]
            # Note: y_size might be a different value from the latent length used in diffusion training
            # because we can encode audio of varying lengths
            # However, the audio should've been padded to a multiple of samples_per_latent by now.
            y_size = total_size // samples_per_latent
            # Create an empty latent, we will populate it with chunks as we encode them
            y_final = torch.zeros((batch_size, self.latent_dim, y_size)).to(
                audio.device
            )
            for i in range(num_chunks):
                x_chunk = chunks[i, :]
                # encode the chunk
                y_chunk = self.encode(x_chunk)
                # figure out where to put the audio along the time domain
                if i == num_chunks - 1:
                    # final chunk always goes at the end
                    t_end = y_size
                    t_start = t_end - y_chunk.shape[2]
                else:
                    t_start = i * hop_size // samples_per_latent
                    t_end = t_start + chunk_size // samples_per_latent
                # remove the edges of the overlaps: trim half the overlap from
                # each interior chunk edge
                ol = overlap // samples_per_latent // 2
                chunk_start = 0
                chunk_end = y_chunk.shape[2]
                if i > 0:
                    # no overlap trim for the start of the first chunk
                    t_start += ol
                    chunk_start += ol
                if i < num_chunks - 1:
                    # no overlap trim for the end of the last chunk
                    t_end -= ol
                    chunk_end -= ol
                # paste the chunked latents into our y_final output buffer
                y_final[:, :, t_start:t_end] = y_chunk[:, :, chunk_start:chunk_end]
            return y_final

    def decode(self, z, **kwargs):
        """Decode latents (B, latent_dim, T') back to audio; inverts the scale."""
        z = z * self.scale
        if self.model_half:
            z = z.half()
            self.model.to(torch.float16)

        decoded = self.model.decode_audio(
            z, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs
        )

        if self.model_half:
            decoded = decoded.float()

        return decoded

    def decode_audio(
        self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs
    ):
        """Decode latents to audio; see encode_audio for the chunking contract
        (overlap and chunk_size are measured in latents)."""
        if not chunked:
            # default behavior. Decode the entire latent in parallel
            return self.decode(latents, **kwargs)
        else:
            # chunked decoding
            hop_size = chunk_size - overlap
            total_size = latents.shape[2]
            batch_size = latents.shape[0]
            chunks = []
            # i pre-initialized so the final-chunk check works when the loop is empty
            i = 0
            for i in range(0, total_size - chunk_size + 1, hop_size):
                chunk = latents[:, :, i : i + chunk_size]
                chunks.append(chunk)
            if i + chunk_size != total_size:
                # Final chunk
                chunk = latents[:, :, -chunk_size:]
                chunks.append(chunk)
            chunks = torch.stack(chunks)
            num_chunks = chunks.shape[0]
            # samples_per_latent is just the downsampling ratio
            samples_per_latent = self.downsampling_ratio
            # Create an empty waveform, we will populate it with chunks as we decode them
            y_size = total_size * samples_per_latent
            y_final = torch.zeros((batch_size, self.io_channels, y_size)).to(
                latents.device
            )
            for i in range(num_chunks):
                x_chunk = chunks[i, :]
                # decode the chunk
                y_chunk = self.decode(x_chunk)
                # figure out where to put the audio along the time domain
                if i == num_chunks - 1:
                    # final chunk always goes at the end
                    t_end = y_size
                    t_start = t_end - y_chunk.shape[2]
                else:
                    t_start = i * hop_size * samples_per_latent
                    t_end = t_start + chunk_size * samples_per_latent
                # remove the edges of the overlaps: trim half the overlap from
                # each interior chunk edge
                ol = (overlap // 2) * samples_per_latent
                chunk_start = 0
                chunk_end = y_chunk.shape[2]
                if i > 0:
                    # no overlap trim for the start of the first chunk
                    t_start += ol
                    chunk_start += ol
                if i < num_chunks - 1:
                    # no overlap trim for the end of the last chunk
                    t_end -= ol
                    chunk_end -= ol
                # paste the chunked audio into our y_final output audio
                y_final[:, :, t_start:t_end] = y_chunk[:, :, chunk_start:chunk_end]
            return y_final

    def tokenize(self, x, **kwargs):
        """Encode audio to discrete tokens via the model's discrete bottleneck."""
        assert self.model.is_discrete, "Cannot tokenize with a continuous model"
        _, info = self.model.encode(x, return_info=True, **kwargs)
        return info[self.model.bottleneck.tokens_id]

    def decode_tokens(self, tokens, **kwargs):
        """Decode discrete tokens back to audio."""
        assert self.model.is_discrete, "Cannot decode tokens with a continuous model"
        return self.model.decode_tokens(tokens, **kwargs)

    def load_state_dict(self, state_dict, strict=True):
        # Checkpoints are saved for the wrapped model, not this wrapper,
        # so forward the state dict to it directly.
        self.model.load_state_dict(state_dict, strict=strict)
class WaveletPretransform(Pretransform):
    """Frozen, invertible wavelet decomposition used as a pretransform."""

    def __init__(self, channels, levels, wavelet):
        super().__init__(enable_grad=False, io_channels=channels, is_discrete=False)

        from .wavelets import WaveletDecode1d, WaveletEncode1d

        self.encoder = WaveletEncode1d(channels, levels, wavelet)
        self.decoder = WaveletDecode1d(channels, levels, wavelet)

        # Each wavelet level halves time resolution, so the total
        # downsampling factor is 2**levels and channels grow to match.
        self.downsampling_ratio = 2**levels
        self.io_channels = channels
        self.encoded_channels = channels * self.downsampling_ratio

    def encode(self, x):
        """Run the wavelet analysis (forward) transform."""
        return self.encoder(x)

    def decode(self, z):
        """Run the wavelet synthesis (inverse) transform."""
        return self.decoder(z)
class PQMFPretransform(Pretransform):
    """Pseudo-QMF filterbank pretransform (currently mono-only)."""

    def __init__(self, attenuation=100, num_bands=16):
        # TODO: Fix PQMF to take in in-channels
        super().__init__(enable_grad=False, io_channels=1, is_discrete=False)
        from .pqmf import PQMF

        self.pqmf = PQMF(attenuation, num_bands)

    def encode(self, x):
        """(Batch, Channels, Time) audio -> (Batch, Channels*Bands, Time) subbands."""
        # pqmf.forward returns (Batch x Channels x Bands x Time), but
        # Pretransform needs (Batch x Channels x Time), so fold the band
        # axis into the channel axis.
        subbands = self.pqmf.forward(x)
        return rearrange(subbands, "b c n t -> b (c n) t")

    def decode(self, x):
        """(Batch, Channels*Bands, Time) subbands -> (Batch, Channels, Time) audio."""
        # Split the fused channel axis back into (Channels, Bands) before
        # running the inverse filterbank.
        subbands = rearrange(x, "b (c n) t -> b c n t", n=self.pqmf.num_bands)
        return self.pqmf.inverse(subbands)
class PretrainedDACPretransform(Pretransform):
    """Pretransform wrapping a pretrained Descript Audio Codec (DAC) model.

    Mono-only, discrete. When ``quantize_on_decode`` is True, ``encode``
    returns pre-quantization latents and quantization happens in ``decode``;
    otherwise latents are quantized at encode time.
    """

    def __init__(
        self,
        model_type="44khz",
        model_bitrate="8kbps",
        scale=1.0,
        quantize_on_decode: bool = True,
        chunked=True,
    ):
        super().__init__(enable_grad=False, io_channels=1, is_discrete=True)

        import dac

        model_path = dac.utils.download(
            model_type=model_type, model_bitrate=model_bitrate
        )
        self.model = dac.DAC.load(model_path)

        self.quantize_on_decode = quantize_on_decode

        # Hop size of the codec: 512 samples/latent for the 44khz model,
        # 320 for the other variants.
        self.downsampling_ratio = 512 if model_type == "44khz" else 320

        self.io_channels = 1
        self.scale = scale
        self.chunked = chunked

        self.encoded_channels = self.model.latent_dim
        self.num_quantizers = self.model.n_codebooks
        self.codebook_size = self.model.codebook_size

    def encode(self, x):
        """Encode audio to continuous latents, optionally quantized, scaled by 1/scale."""
        latents = self.model.encoder(x)
        if self.quantize_on_decode:
            # Defer quantization to decode(); return raw encoder latents.
            output = latents
        else:
            z, _, _, _, _ = self.model.quantizer(
                latents, n_quantizers=self.model.n_codebooks
            )
            output = z
        return output / self.scale if self.scale != 1.0 else output

    def decode(self, z):
        """Decode (optionally quantizing first) latents back to audio."""
        if self.scale != 1.0:
            z = z * self.scale
        if self.quantize_on_decode:
            z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks)
        return self.model.decode(z)

    def tokenize(self, x):
        """Return the codebook indices from DAC's encode (its second output)."""
        return self.model.encode(x)[1]

    def decode_tokens(self, tokens):
        """Map codebook indices to latents, then decode to audio."""
        latents = self.model.quantizer.from_codes(tokens)
        return self.model.decode(latents)
class AudiocraftCompressionPretransform(Pretransform):
    """Pretransform wrapping an Audiocraft CompressionModel (e.g. EnCodec).

    These models are token-based only: continuous ``encode``/``decode``
    are unsupported — use ``tokenize``/``decode_tokens`` instead.
    """

    def __init__(
        self,
        model_type="facebook/encodec_32khz",
        scale=1.0,
        quantize_on_decode: bool = True,
    ):
        super().__init__(enable_grad=False, io_channels=1, is_discrete=True)

        try:
            from audiocraft.models import CompressionModel
        except ImportError:
            raise ImportError(
                "Audiocraft is not installed. Please install audiocraft to use Audiocraft models."
            )

        self.model = CompressionModel.get_pretrained(model_type)

        self.quantize_on_decode = quantize_on_decode

        # Audio samples per token frame.
        self.downsampling_ratio = round(self.model.sample_rate / self.model.frame_rate)
        self.sample_rate = self.model.sample_rate
        self.io_channels = self.model.channels
        self.scale = scale

        self.num_quantizers = self.model.num_codebooks
        self.codebook_size = self.model.cardinality

        # Frozen fp16 inference only.
        self.model.to(torch.float16).eval().requires_grad_(False)

    def encode(self, x):
        # Raise (rather than `assert False`) so the failure survives `python -O`,
        # where asserts are stripped and this would silently return None.
        raise NotImplementedError(
            "Audiocraft compression models do not support continuous encoding"
        )

    def decode(self, z):
        # See encode(): NotImplementedError matches the Pretransform contract.
        raise NotImplementedError(
            "Audiocraft compression models do not support continuous decoding"
        )

    def tokenize(self, x):
        """Audio (B, C, T) -> integer codes; model.encode returns (codes, scale)."""
        # Autocast is disabled because the model is manually cast to fp16 above;
        # the input is cast to match.
        with torch.cuda.amp.autocast(enabled=False):
            return self.model.encode(x.to(torch.float16))[0]

    def decode_tokens(self, tokens):
        """Integer codes -> reconstructed audio."""
        with torch.cuda.amp.autocast(enabled=False):
            return self.model.decode(tokens)