Spaces:
Running on Zero
Running on Zero
| import torch | |
| from einops import rearrange | |
| from torch import nn | |
class Pretransform(nn.Module):
    """Abstract interface for audio pre-transforms.

    Subclasses map audio to an encoded representation and back via
    encode/decode, and — for discrete transforms — to tokens via
    tokenize/decode_tokens. Subclasses are expected to fill in
    ``encoded_channels`` and ``downsampling_ratio``.
    """

    def __init__(self, enable_grad, io_channels, is_discrete):
        super().__init__()
        # Whether gradients should flow through this transform.
        self.enable_grad = enable_grad
        # Number of audio channels at the transform's input/output.
        self.io_channels = io_channels
        # True when the transform produces discrete tokens.
        self.is_discrete = is_discrete
        # Populated by subclasses.
        self.encoded_channels = None
        self.downsampling_ratio = None

    def encode(self, x):
        """Map audio to the encoded representation."""
        raise NotImplementedError

    def decode(self, z):
        """Map an encoded representation back to audio."""
        raise NotImplementedError

    def tokenize(self, x):
        """Map audio to discrete tokens (discrete transforms only)."""
        raise NotImplementedError

    def decode_tokens(self, tokens):
        """Map discrete tokens back to audio (discrete transforms only)."""
        raise NotImplementedError
class AutoencoderPretransform(Pretransform):
    """Pretransform that maps audio to/from the latent space of a frozen autoencoder.

    Wraps a pretrained autoencoder ``model``, freezing its weights, and exposes
    (optionally chunked and/or fp16) encode/decode plus token access when the
    model has a discrete bottleneck.

    Args:
        model: autoencoder exposing encode_audio/decode_audio and, optionally,
            a bottleneck (discrete bottlenecks enable tokenize/decode_tokens).
        scale: latents are divided by this on encode and multiplied on decode.
        model_half: run the wrapped model in float16.
        iterate_batch: process batch items one at a time (passed to the model).
        chunked: use the model's chunked encode/decode paths by default.
    """

    def __init__(
        self, model, scale=1.0, model_half=False, iterate_batch=False, chunked=False
    ):
        super().__init__(
            enable_grad=False,
            io_channels=model.io_channels,
            is_discrete=model.bottleneck is not None and model.bottleneck.is_discrete,
        )
        self.model = model
        # The wrapped autoencoder is frozen: inference only, no gradients.
        self.model.requires_grad_(False).eval()
        self.scale = scale
        self.downsampling_ratio = model.downsampling_ratio
        self.io_channels = model.io_channels
        self.sample_rate = model.sample_rate
        self.model_half = model_half
        self.iterate_batch = iterate_batch
        self.encoded_channels = model.latent_dim
        self.latent_dim = model.latent_dim
        self.chunked = chunked
        # Quantizer metadata only exists for discrete bottlenecks.
        self.num_quantizers = (
            model.bottleneck.num_quantizers
            if model.bottleneck is not None and model.bottleneck.is_discrete
            else None
        )
        self.codebook_size = (
            model.bottleneck.codebook_size
            if model.bottleneck is not None and model.bottleneck.is_discrete
            else None
        )
        if self.model_half:
            self.model.half()

    def encode(self, x, **kwargs):
        """Encode audio (batch, channels, samples) into scaled latents."""
        if self.model_half:
            x = x.half()
            self.model.to(torch.float16)
        encoded = self.model.encode_audio(
            x, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs
        )
        if self.model_half:
            encoded = encoded.float()
        return encoded / self.scale

    def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs):
        """Encode audio into latents.

        Audio should already be preprocessed (e.g. by
        preprocess_audio_for_encoder). If ``chunked`` is True, the audio is
        split into chunks of at most ``chunk_size`` with ``overlap`` between
        them. Both parameters are measured in latents (not audio samples), so
        the same values can likely be used with decode_audio.

        An overlap of zero causes discontinuity artefacts; overlap should be
        >= the model's receptive field. Every autoencoder has a different
        receptive field (and thus ideal overlap); determine it empirically by
        diffing unchunked vs chunked output and inspecting the maximum diff.
        The final chunk may use a longer overlap in order to keep chunk_size
        consistent for all chunks. Smaller chunk_size uses less memory but
        more compute; the tradeoff is not linear and likely depends on the
        GPU/CUDA version (e.g. on an A6000, chunk_size 128 is overall faster
        than 256 and 512 even though it produces more chunks).
        """
        if not chunked:
            # Default behavior: encode the entire audio in parallel.
            return self.encode(audio, **kwargs)
        else:
            # CHUNKED ENCODING
            # samples_per_latent is the downsampling ratio (which is also the
            # upsampling ratio on the decode side).
            samples_per_latent = self.downsampling_ratio
            total_size = audio.shape[2]  # in samples
            batch_size = audio.shape[0]
            chunk_size *= samples_per_latent  # converting metric in latents to samples
            overlap *= samples_per_latent  # converting metric in latents to samples
            hop_size = chunk_size - overlap
            chunks = []
            # BUGFIX: initialize i so the final-chunk check below is defined
            # even when the audio is shorter than one chunk and the loop body
            # never runs (decode_audio already guards this the same way).
            i = 0
            for i in range(0, total_size - chunk_size + 1, hop_size):
                chunk = audio[:, :, i : i + chunk_size]
                chunks.append(chunk)
            if i + chunk_size != total_size:
                # Final chunk, anchored to the end so every chunk has the same
                # chunk_size (it may overlap the previous chunk by more than
                # `overlap`).
                chunk = audio[:, :, -chunk_size:]
                chunks.append(chunk)
            chunks = torch.stack(chunks)
            num_chunks = chunks.shape[0]
            # Note: y_size might differ from the latent length used in
            # diffusion training because we can encode audio of varying
            # lengths. However, the audio should have been padded to a
            # multiple of samples_per_latent by now.
            y_size = total_size // samples_per_latent
            # Create an empty latent; we populate it with chunks as we encode them.
            y_final = torch.zeros((batch_size, self.latent_dim, y_size)).to(
                audio.device
            )
            for i in range(num_chunks):
                x_chunk = chunks[i, :]
                # Encode the chunk.
                y_chunk = self.encode(x_chunk)
                # Figure out where this chunk lands on the latent time axis.
                if i == num_chunks - 1:
                    # The final chunk always goes at the end.
                    t_end = y_size
                    t_start = t_end - y_chunk.shape[2]
                else:
                    t_start = i * hop_size // samples_per_latent
                    t_end = t_start + chunk_size // samples_per_latent
                # Trim half the overlap from each interior chunk edge.
                ol = overlap // samples_per_latent // 2
                chunk_start = 0
                chunk_end = y_chunk.shape[2]
                if i > 0:
                    # No trim at the very start of the first chunk.
                    t_start += ol
                    chunk_start += ol
                if i < num_chunks - 1:
                    # No trim at the very end of the last chunk.
                    t_end -= ol
                    chunk_end -= ol
                # Paste the chunk's latents into the output buffer.
                y_final[:, :, t_start:t_end] = y_chunk[:, :, chunk_start:chunk_end]
            return y_final

    def decode(self, z, **kwargs):
        """Decode scaled latents back into audio."""
        z = z * self.scale
        if self.model_half:
            z = z.half()
            self.model.to(torch.float16)
        decoded = self.model.decode_audio(
            z, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs
        )
        if self.model_half:
            decoded = decoded.float()
        return decoded

    def decode_audio(
        self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs
    ):
        """Decode latents (batch, latent_dim, time) into audio.

        ``overlap`` and ``chunk_size`` have the same meaning and units
        (latents) as in encode_audio.
        """
        if not chunked:
            # Default behavior: decode the entire latent in parallel.
            return self.decode(latents, **kwargs)
        else:
            # CHUNKED DECODING
            hop_size = chunk_size - overlap
            total_size = latents.shape[2]
            batch_size = latents.shape[0]
            chunks = []
            # Initialize i so the final-chunk check is defined when the latent
            # is shorter than one chunk (the loop body never runs).
            i = 0
            for i in range(0, total_size - chunk_size + 1, hop_size):
                chunk = latents[:, :, i : i + chunk_size]
                chunks.append(chunk)
            if i + chunk_size != total_size:
                # Final chunk, anchored to the end.
                chunk = latents[:, :, -chunk_size:]
                chunks.append(chunk)
            chunks = torch.stack(chunks)
            num_chunks = chunks.shape[0]
            # samples_per_latent is just the downsampling ratio.
            samples_per_latent = self.downsampling_ratio
            # Create an empty waveform; we populate it with chunks as we decode them.
            y_size = total_size * samples_per_latent
            y_final = torch.zeros((batch_size, self.io_channels, y_size)).to(
                latents.device
            )
            for i in range(num_chunks):
                x_chunk = chunks[i, :]
                # Decode the chunk.
                y_chunk = self.decode(x_chunk)
                # Figure out where this chunk lands on the output time axis.
                if i == num_chunks - 1:
                    # The final chunk always goes at the end.
                    t_end = y_size
                    t_start = t_end - y_chunk.shape[2]
                else:
                    t_start = i * hop_size * samples_per_latent
                    t_end = t_start + chunk_size * samples_per_latent
                # Trim half the overlap from each interior chunk edge.
                ol = (overlap // 2) * samples_per_latent
                chunk_start = 0
                chunk_end = y_chunk.shape[2]
                if i > 0:
                    # No trim at the very start of the first chunk.
                    t_start += ol
                    chunk_start += ol
                if i < num_chunks - 1:
                    # No trim at the very end of the last chunk.
                    t_end -= ol
                    chunk_end -= ol
                # Paste the chunk's audio into the output buffer.
                y_final[:, :, t_start:t_end] = y_chunk[:, :, chunk_start:chunk_end]
            return y_final

    def tokenize(self, x, **kwargs):
        """Return the discrete tokens for audio x (discrete bottlenecks only)."""
        assert self.model.is_discrete, "Cannot tokenize with a continuous model"
        _, info = self.model.encode(x, return_info=True, **kwargs)
        return info[self.model.bottleneck.tokens_id]

    def decode_tokens(self, tokens, **kwargs):
        """Decode discrete tokens back into audio (discrete bottlenecks only)."""
        assert self.model.is_discrete, "Cannot decode tokens with a continuous model"
        return self.model.decode_tokens(tokens, **kwargs)

    def load_state_dict(self, state_dict, strict=True):
        # Delegate loading to the wrapped autoencoder.
        self.model.load_state_dict(state_dict, strict=strict)
class WaveletPretransform(Pretransform):
    """Fixed (non-learned) wavelet transform used as a pretransform."""

    def __init__(self, channels, levels, wavelet):
        super().__init__(enable_grad=False, io_channels=channels, is_discrete=False)
        from .wavelets import WaveletDecode1d, WaveletEncode1d

        # Forward and inverse 1-D wavelet transforms.
        self.encoder = WaveletEncode1d(channels, levels, wavelet)
        self.decoder = WaveletDecode1d(channels, levels, wavelet)

        # Each level halves the temporal resolution, so `levels` levels give a
        # 2**levels downsampling ratio; the channel count grows by the same factor.
        self.downsampling_ratio = 2**levels
        self.io_channels = channels
        self.encoded_channels = channels * self.downsampling_ratio

    def encode(self, x):
        """Apply the forward wavelet transform."""
        return self.encoder(x)

    def decode(self, z):
        """Apply the inverse wavelet transform."""
        return self.decoder(z)
class PQMFPretransform(Pretransform):
    """Pseudo-QMF filter-bank pretransform (currently mono-only)."""

    def __init__(self, attenuation=100, num_bands=16):
        # TODO: Fix PQMF to take in in-channels
        super().__init__(enable_grad=False, io_channels=1, is_discrete=False)
        from .pqmf import PQMF

        self.pqmf = PQMF(attenuation, num_bands)

    def encode(self, x):
        """Split audio into subbands.

        x is (Batch x Channels x Time). PQMF.forward returns
        (Batch x Channels x Bands x Time), but Pretransform needs
        (Batch x Channels x Time), so the band axis is folded into the
        channel axis.
        """
        subbands = self.pqmf.forward(x)
        return rearrange(subbands, "b c n t -> b (c n) t")

    def decode(self, x):
        """Inverse of encode.

        x is (Batch x (Channels Bands) x Time); unfold the bands back out to
        (Batch x Channels x Bands x Time) and resynthesize
        (Batch x Channels x Time) audio.
        """
        banded = rearrange(x, "b (c n) t -> b c n t", n=self.pqmf.num_bands)
        return self.pqmf.inverse(banded)
class PretrainedDACPretransform(Pretransform):
    """Discrete pretransform backed by a pretrained Descript Audio Codec (DAC)."""

    def __init__(
        self,
        model_type="44khz",
        model_bitrate="8kbps",
        scale=1.0,
        quantize_on_decode: bool = True,
        chunked=True,
    ):
        super().__init__(enable_grad=False, io_channels=1, is_discrete=True)
        import dac

        # Fetch (or reuse a cached copy of) the pretrained checkpoint.
        checkpoint_path = dac.utils.download(
            model_type=model_type, model_bitrate=model_bitrate
        )
        self.model = dac.DAC.load(checkpoint_path)
        self.quantize_on_decode = quantize_on_decode
        # The 44khz variant uses a larger hop than the other model types.
        self.downsampling_ratio = 512 if model_type == "44khz" else 320
        self.io_channels = 1
        self.scale = scale
        self.chunked = chunked
        self.encoded_channels = self.model.latent_dim
        self.num_quantizers = self.model.n_codebooks
        self.codebook_size = self.model.codebook_size

    def encode(self, x):
        """Encode audio to (optionally quantized, optionally scaled) latents."""
        output = self.model.encoder(x)
        if not self.quantize_on_decode:
            # Quantize now; otherwise raw encoder latents are returned and
            # quantization is deferred to decode().
            output, _, _, _, _ = self.model.quantizer(
                output, n_quantizers=self.model.n_codebooks
            )
        if self.scale != 1.0:
            output = output / self.scale
        return output

    def decode(self, z):
        """Decode latents (quantizing first when quantize_on_decode is set)."""
        if self.scale != 1.0:
            z = z * self.scale
        if self.quantize_on_decode:
            z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks)
        return self.model.decode(z)

    def tokenize(self, x):
        # model.encode returns a tuple; element 1 holds the discrete codes.
        return self.model.encode(x)[1]

    def decode_tokens(self, tokens):
        # Map codebook indices back to latents, then decode to audio.
        latents = self.model.quantizer.from_codes(tokens)
        return self.model.decode(latents)
class AudiocraftCompressionPretransform(Pretransform):
    """Discrete pretransform backed by an Audiocraft CompressionModel.

    Only token-level access is supported: use tokenize()/decode_tokens().
    Continuous encode()/decode() are intentionally unsupported.

    Args:
        model_type: Audiocraft pretrained model identifier.
        scale: stored scale factor (unused by the token paths).
        quantize_on_decode: stored flag (unused by the token paths).
    """

    def __init__(
        self,
        model_type="facebook/encodec_32khz",
        scale=1.0,
        quantize_on_decode: bool = True,
    ):
        super().__init__(enable_grad=False, io_channels=1, is_discrete=True)
        try:
            from audiocraft.models import CompressionModel
        except ImportError as err:
            raise ImportError(
                "Audiocraft is not installed. Please install audiocraft to use Audiocraft models."
            ) from err

        self.model = CompressionModel.get_pretrained(model_type)
        self.quantize_on_decode = quantize_on_decode
        # Audio samples per token frame.
        self.downsampling_ratio = round(self.model.sample_rate / self.model.frame_rate)
        self.sample_rate = self.model.sample_rate
        self.io_channels = self.model.channels
        self.scale = scale
        # self.encoded_channels = self.model.latent_dim
        self.num_quantizers = self.model.num_codebooks
        self.codebook_size = self.model.cardinality
        # Frozen fp16 inference only.
        self.model.to(torch.float16).eval().requires_grad_(False)

    def encode(self, x):
        # Raise instead of `assert False`: asserts are stripped under
        # `python -O`, which would make this silently return None.
        raise NotImplementedError(
            "Audiocraft compression models do not support continuous encoding"
        )

    def decode(self, z):
        raise NotImplementedError(
            "Audiocraft compression models do not support continuous decoding"
        )

    def tokenize(self, x):
        """Return the model's discrete codes for audio x."""
        # Autocast disabled: the model already runs in explicit fp16 (see __init__).
        with torch.cuda.amp.autocast(enabled=False):
            return self.model.encode(x.to(torch.float16))[0]

    def decode_tokens(self, tokens):
        """Decode discrete codes back into audio."""
        with torch.cuda.amp.autocast(enabled=False):
            return self.model.decode(tokens)