|
|
from abc import ABC, abstractmethod
|
|
|
from math import floor
|
|
|
from typing import Any, Callable, Optional, Sequence, Tuple, Union
|
|
|
|
|
|
import torch
|
|
|
from einops import pack, rearrange, unpack
|
|
|
from torch import Generator, Tensor, nn
|
|
|
|
|
|
from .components import AppendChannelsPlugin, MelSpectrogram
|
|
|
from .diffusion import ARVDiffusion, ARVSampler, VDiffusion, VSampler
|
|
|
from .utils import (
|
|
|
closest_power_2,
|
|
|
default,
|
|
|
downsample,
|
|
|
exists,
|
|
|
groupby,
|
|
|
randn_like,
|
|
|
upsample,
|
|
|
)
|
|
|
|
|
|
|
|
|
class DiffusionModel(nn.Module):
    """Bundle a network with a diffusion wrapper and a sampler wrapper.

    The network is built once and shared by both wrappers: ``forward``
    delegates to the diffusion object, ``sample`` delegates to the sampler.

    Keyword routing (performed with ``groupby``):
      * ``diffusion_*`` keys go to ``diffusion_t``
      * ``sampler_*`` keys go to ``sampler_t``
      * everything else goes to ``net_t`` together with ``dim``
    (presumably the prefix is stripped by ``groupby`` -- confirm in utils.)
    """

    def __init__(
        self,
        net_t: Callable,
        diffusion_t: Callable = VDiffusion,
        sampler_t: Callable = VSampler,
        loss_fn: Callable = torch.nn.functional.mse_loss,
        dim: int = 1,
        **kwargs,
    ):
        """Build the network, then wrap it for training and for sampling.

        Args:
            net_t: Callable returning the network; called as
                ``net_t(dim=dim, **remaining_kwargs)``.
            diffusion_t: Callable building the diffusion wrapper from
                ``net``, ``loss_fn`` and the ``diffusion_*`` options.
            sampler_t: Callable building the sampler from ``net`` and the
                ``sampler_*`` options.
            loss_fn: Loss callable handed to ``diffusion_t``.
            dim: Forwarded to ``net_t`` as ``dim``.
            **kwargs: Mixed options, split by prefix as described above.
        """
        super().__init__()

        # Peel off the two prefixed groups; what is left configures the net.
        diffusion_kwargs, kwargs = groupby("diffusion_", kwargs)
        sampler_kwargs, net_kwargs = groupby("sampler_", kwargs)

        network = net_t(dim=dim, **net_kwargs)
        self.net = network
        # Both wrappers receive the very same network instance.
        self.diffusion = diffusion_t(net=network, loss_fn=loss_fn, **diffusion_kwargs)
        self.sampler = sampler_t(net=network, **sampler_kwargs)

    def forward(self, *args, **kwargs) -> Tensor:
        """Forward all arguments to the diffusion wrapper and return its result."""
        diffusion = self.diffusion
        return diffusion(*args, **kwargs)

    @torch.no_grad()
    def sample(self, *args, **kwargs) -> Tensor:
        """Forward all arguments to the sampler, with gradients disabled."""
        sampler = self.sampler
        return sampler(*args, **kwargs)
|
|
|
|
|
|
|
|
|
class EncoderBase(nn.Module, ABC):
    """Abstract base class for encoders plugged into DiffusionAE.

    ``__init__`` is abstract, so the class cannot be instantiated directly;
    concrete encoders must provide their own constructor.
    """

    @abstractmethod
    def __init__(self):
        """Initialise the module and declare the two encoder attributes.

        Both attributes start as ``None``. DiffusionAE reads
        ``out_channels`` and ``downsample_factor`` from the encoder it is
        given, so subclasses are expected to assign real values.
        """
        super().__init__()
        self.out_channels = self.downsample_factor = None
|
|
|
|
|
|
|
|
|
class AdapterBase(nn.Module, ABC):
    """Abstract base class for DiffusionAE adapters.

    An adapter is an encode/decode pair applied around the diffusion model:
    DiffusionAE calls ``encode`` on the input before computing the loss and
    ``decode`` on the sampled output. Subclasses must implement both
    methods; the class cannot be instantiated otherwise.
    """

    @abstractmethod
    def encode(self, x: Tensor) -> Tensor:
        """Map ``x`` into the representation the diffusion model operates on.

        Args:
            x: Input tensor.

        Returns:
            The adapted tensor. This base implementation returns ``None``.
        """

    @abstractmethod
    def decode(self, x: Tensor) -> Tensor:
        """Map a tensor produced by the diffusion model back to the output domain.

        Args:
            x: Tensor in the adapted representation.

        Returns:
            The decoded tensor. This base implementation returns ``None``.
        """
|
|
|
|
|
|
|
|
|
class DiffusionAE(DiffusionModel):
    """Diffusion autoencoder.

    An encoder turns the input into a latent, which is handed to the
    network through its ``channels`` argument at position ``inject_depth``.
    An optional adapter transforms the signal on the way in (training) and
    on the way out (decoding).
    """

    def __init__(
        self,
        in_channels: int,
        channels: Sequence[int],
        encoder: EncoderBase,
        inject_depth: int,
        latent_factor: Optional[int] = None,
        adapter: Optional[AdapterBase] = None,
        **kwargs,
    ):
        """Configure the network context channels and store encoder/adapter.

        Args:
            in_channels: Channel count of the signal; also used for the
                noise tensor created in ``decode``.
            channels: Per-level channel sizes forwarded to the network.
            encoder: Encoder producing the latent; its ``out_channels`` is
                used as the context width at ``inject_depth``.
            inject_depth: Index into ``channels`` where the latent is injected.
            latent_factor: Ratio used in ``decode`` to size the noise from
                the latent length; falls back to ``encoder.downsample_factor``.
            adapter: Optional adapter; its gradients are disabled.
            **kwargs: Forwarded to ``DiffusionModel.__init__``.
        """
        # Only the injection level carries context; every other level gets 0.
        context_channels = [0 for _ in channels]
        context_channels[inject_depth] = encoder.out_channels

        super().__init__(
            in_channels=in_channels,
            channels=channels,
            context_channels=context_channels,
            **kwargs,
        )

        self.in_channels = in_channels
        self.encoder = encoder
        self.inject_depth = inject_depth
        self.latent_factor = default(latent_factor, self.encoder.downsample_factor)

        # The adapter is used as-is: switch off its gradients before storing it.
        if exists(adapter):
            frozen_adapter = adapter.requires_grad_(False)
        else:
            frozen_adapter = None
        self.adapter = frozen_adapter

    def _with_latent(self, latent: Tensor) -> list:
        """Return the ``channels`` list: ``inject_depth`` Nones, then the latent."""
        placeholders = [None] * self.inject_depth
        return placeholders + [latent]

    def forward(
        self, x: Tensor, with_info: bool = False, **kwargs
    ) -> Union[Tensor, Tuple[Tensor, Any]]:
        """Encode ``x``, then compute the diffusion loss conditioned on the latent.

        Args:
            x: Input tensor; encoded as given, then passed through the
                adapter (if any) before reaching the diffusion model.
            with_info: When true, also return the encoder's info object.
            **kwargs: Forwarded to ``DiffusionModel.forward``.

        Returns:
            The loss, or ``(loss, info)`` when ``with_info`` is true.
        """
        # The encoder is always asked for its info; it is dropped if unwanted.
        latent, info = self.encode(x, with_info=True)
        channels = self._with_latent(latent)

        if exists(self.adapter):
            x = self.adapter.encode(x)

        loss = super().forward(x, channels=channels, **kwargs)
        if with_info:
            return loss, info
        return loss

    def encode(self, *args, **kwargs):
        """Run the encoder with the given arguments and return its output."""
        encoder = self.encoder
        return encoder(*args, **kwargs)

    @torch.no_grad()
    def decode(
        self, latent: Tensor, generator: Optional[Generator] = None, **kwargs
    ) -> Tensor:
        """Sample a signal conditioned on ``latent``.

        Args:
            latent: Latent tensor; axis 0 is the batch and axis 2 the length
                used to size the noise.
            generator: Optional random generator for the starting noise.
            **kwargs: Forwarded to ``DiffusionModel.sample``.

        Returns:
            The sampled tensor, passed through ``adapter.decode`` when an
            adapter is configured.
        """
        batch_size = latent.shape[0]
        # Noise length: latent length scaled by latent_factor, then passed
        # through closest_power_2.
        noise_length = closest_power_2(latent.shape[2] * self.latent_factor)
        noise_shape = (batch_size, self.in_channels, noise_length)

        noise = torch.randn(
            noise_shape,
            device=latent.device,
            dtype=latent.dtype,
            generator=generator,
        )
        channels = self._with_latent(latent)

        out = super().sample(noise, channels=channels, **kwargs)

        if exists(self.adapter):
            return self.adapter.decode(out)
        return out
|
|
|
|
|
|
|
|
|
class DiffusionUpsampler(DiffusionModel):
    """Diffusion model conditioned on a down-then-up-sampled copy of the signal.

    The network is wrapped in ``AppendChannelsPlugin`` so the conditioning
    signal can be supplied through the ``append_channels`` argument.
    """

    def __init__(
        self,
        in_channels: int,
        upsample_factor: int,
        net_t: Callable,
        **kwargs,
    ):
        """Wrap ``net_t`` with the append-channels plugin and store the factor.

        Args:
            in_channels: Channel count of the signal; also the number of
                channels given to ``AppendChannelsPlugin``.
            upsample_factor: Factor passed to ``downsample`` / ``upsample``.
            net_t: Network constructor to be wrapped by the plugin.
            **kwargs: Forwarded to ``DiffusionModel.__init__``.
        """
        plugin_net_t = AppendChannelsPlugin(net_t, channels=in_channels)
        super().__init__(net_t=plugin_net_t, in_channels=in_channels, **kwargs)
        # Plain attribute; nothing in the parent constructor depends on it.
        self.upsample_factor = upsample_factor

    def reupsample(self, x: Tensor) -> Tensor:
        """Return a copy of ``x`` downsampled and upsampled again by the factor."""
        factor = self.upsample_factor
        # Work on a clone so the caller's tensor is never touched.
        coarse = downsample(x.clone(), factor=factor)
        return upsample(coarse, factor=factor)

    def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
        """Run the diffusion forward pass on ``x`` with its reupsampled copy appended.

        Args:
            x: Input tensor.
            *args, **kwargs: Forwarded to ``DiffusionModel.forward``.

        Returns:
            The result of the parent forward pass.
        """
        conditioning = self.reupsample(x)
        return super().forward(x, *args, append_channels=conditioning, **kwargs)

    @torch.no_grad()
    def sample(
        self, downsampled: Tensor, generator: Optional[Generator] = None, **kwargs
    ) -> Tensor:
        """Sample conditioned on ``downsampled``.

        Args:
            downsampled: Tensor to be upsampled by ``upsample_factor`` and
                used as the appended conditioning channels.
            generator: Optional random generator for the starting noise.
            **kwargs: Forwarded to ``DiffusionModel.sample``.

        Returns:
            The sampler output.
        """
        conditioning = upsample(downsampled, factor=self.upsample_factor)
        # Start from noise shaped like the upsampled conditioning signal.
        start_noise = randn_like(conditioning, generator=generator)
        return super().sample(start_noise, append_channels=conditioning, **kwargs)
|
|
|
|
|
|
|
|
|
class DiffusionVocoder(DiffusionModel):
    """Diffusion model conditioned on a mel spectrogram.

    The spectrogram is projected to a single channel by a transposed 1-D
    convolution (``to_flat``) and supplied to the network through
    ``append_channels``.
    """

    def __init__(
        self,
        net_t: Callable,
        mel_channels: int,
        mel_n_fft: int,
        mel_hop_length: Optional[int] = None,
        mel_win_length: Optional[int] = None,
        in_channels: int = 1,
        **kwargs,
    ):
        """Build the diffusion model plus the spectrogram and flattening layers.

        Args:
            net_t: Network constructor, wrapped in ``AppendChannelsPlugin``
                with one appended channel.
            mel_channels: Number of mel channels; also the input channel
                count of ``to_flat``.
            mel_n_fft: FFT size passed to ``MelSpectrogram``.
            mel_hop_length: Hop length; falls back to ``floor(mel_n_fft) // 4``.
            mel_win_length: Window length; falls back to ``mel_n_fft``.
            in_channels: Accepted but not used; the parent is always built
                with ``in_channels=1``.
            **kwargs: ``mel_*`` keys go to ``MelSpectrogram``; the rest goes
                to ``DiffusionModel.__init__``.
        """
        hop_length = default(mel_hop_length, floor(mel_n_fft) // 4)
        win_length = default(mel_win_length, mel_n_fft)
        mel_kwargs, kwargs = groupby("mel_", kwargs)

        plugin_net_t = AppendChannelsPlugin(net_t, channels=1)
        super().__init__(net_t=plugin_net_t, in_channels=1, **kwargs)

        self.to_spectrogram = MelSpectrogram(
            n_fft=mel_n_fft,
            hop_length=hop_length,
            win_length=win_length,
            n_mel_channels=mel_channels,
            **mel_kwargs,
        )

        # Transposed conv with stride = hop length, collapsing the mel
        # channels into one output channel.
        flat_padding = (win_length - hop_length) // 2
        self.to_flat = nn.ConvTranspose1d(
            in_channels=mel_channels,
            out_channels=1,
            kernel_size=win_length,
            stride=hop_length,
            padding=flat_padding,
            bias=False,
        )

    def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
        """Run the diffusion forward pass on ``x`` conditioned on its spectrogram.

        Args:
            x: Tensor laid out as ``(b, c, t)``; channels are folded into
                the batch axis before the parent forward pass.
            *args, **kwargs: Forwarded to ``DiffusionModel.forward``.

        Returns:
            The result of the parent forward pass.
        """
        # Spectrogram is (b, c, f, l); fold channels into the batch axis.
        mel = self.to_spectrogram(x)
        mel = rearrange(mel, "b c f l -> (b c) f l")
        conditioning = self.to_flat(mel)

        # Same folding for the signal, leaving a single channel.
        x = rearrange(x, "b c t -> (b c) 1 t")
        return super().forward(x, *args, append_channels=conditioning, **kwargs)

    @torch.no_grad()
    def sample(
        self, spectrogram: Tensor, generator: Optional[Generator] = None, **kwargs
    ) -> Tensor:
        """Sample a signal from ``spectrogram``.

        Args:
            spectrogram: Tensor whose last two axes are ``(f, l)``; any
                leading axes are packed into one batch axis and restored on
                the output.
            generator: Optional random generator for the starting noise.
            **kwargs: Forwarded to ``DiffusionModel.sample``.

        Returns:
            The sampled tensor with the single channel axis removed and the
            original leading axes restored.
        """
        # Flatten all leading axes; `packed_shapes` remembers them.
        packed, packed_shapes = pack([spectrogram], "* f l")
        conditioning = self.to_flat(packed)

        start_noise = randn_like(conditioning, generator=generator)
        sampled = super().sample(start_noise, append_channels=conditioning, **kwargs)

        # Drop the singleton channel, then restore the leading axes.
        sampled = rearrange(sampled, "... 1 t -> ... t")
        return unpack(sampled, packed_shapes, "* t")[0]
|
|
|
|
|
|
|
|
|
class DiffusionAR(DiffusionModel):
    """DiffusionModel preset using the ARVDiffusion / ARVSampler pair by default."""

    def __init__(
        self,
        in_channels: int,
        length: int,
        num_splits: int,
        diffusion_t: Callable = ARVDiffusion,
        sampler_t: Callable = ARVSampler,
        **kwargs,
    ):
        """Configure the network, diffusion wrapper and sampler consistently.

        Args:
            in_channels: Channel count of the signal. The network is built
                with ``in_channels + 1`` inputs and ``in_channels`` outputs.
            length: Passed to both the diffusion wrapper and the sampler.
            num_splits: Passed to both the diffusion wrapper and the sampler.
            diffusion_t: Diffusion wrapper constructor.
            sampler_t: Sampler constructor.
            **kwargs: Forwarded to ``DiffusionModel.__init__``.
        """
        super().__init__(
            # Network options: one extra input channel, with time
            # conditioning and modulation switched off.
            in_channels=in_channels + 1,
            out_channels=in_channels,
            use_time_conditioning=False,
            use_modulation=False,
            # Diffusion wrapper and its prefixed options.
            diffusion_t=diffusion_t,
            diffusion_length=length,
            diffusion_num_splits=num_splits,
            # Sampler and its prefixed options.
            sampler_t=sampler_t,
            sampler_in_channels=in_channels,
            sampler_length=length,
            sampler_num_splits=num_splits,
            **kwargs,
        )
|
|
|
|