|
|
|
|
|
|
|
|
|
|
|
import math
|
|
|
from inspect import isfunction
|
|
|
from math import ceil, floor, log, pi, log2
|
|
|
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
|
|
|
from packaging import version
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
from einops import rearrange, reduce, repeat
|
|
|
from einops.layers.torch import Rearrange
|
|
|
from einops_exts import rearrange_many
|
|
|
from torch import Tensor, einsum
|
|
|
from torch.backends.cuda import sdp_kernel
|
|
|
from torch.nn import functional as F
|
|
|
from dac.nn.layers import Snake1d
|
|
|
|
|
|
"""
|
|
|
Utils
|
|
|
"""
|
|
|
|
|
|
|
|
|
class ConditionedSequential(nn.Module):
    """Sequential container whose every submodule receives an optional
    conditioning tensor `mapping` alongside the running activation."""

    def __init__(self, *modules):
        super().__init__()
        # NOTE(review): `*modules` is unpacked again into nn.ModuleList, so
        # callers must pass a single iterable (ConditionedSequential([a, b]));
        # passing modules as separate positional args would hand ModuleList
        # too many arguments — confirm call sites.
        self.module_list = nn.ModuleList(*modules)

    def forward(self, x: Tensor, mapping: Optional[Tensor] = None):
        # Each submodule must accept (x, mapping) and return the new x.
        for module in self.module_list:
            x = module(x, mapping)
        return x
|
|
|
|
|
|
T = TypeVar("T")


def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
    """Return `val` when it is not None, otherwise the fallback `d`.

    A plain-function fallback is invoked lazily; any other value (including
    non-function callables such as classes) is returned as-is.
    """
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d
|
|
|
|
|
|
def exists(val: Optional[T]) -> bool:
    """Return True when `val` is not None.

    Fix: the return annotation was `-> T`, but the function returns a bool
    (it is used throughout the file as a None-check predicate).
    """
    return val is not None
|
|
|
|
|
|
def closest_power_2(x: float) -> int:
    """Return the power of two nearest to `x`.

    On a tie the smaller power wins (floor exponent is listed first and
    `min` keeps the first minimum).
    """
    exponent = log2(x)
    candidates = (floor(exponent), ceil(exponent))
    best_exponent = min(candidates, key=lambda e: abs(x - 2 ** e))
    return 2 ** int(best_exponent)
|
|
|
|
|
|
def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
    """Split `d` into (entries whose keys start with `prefix`, all others).

    Keys are kept verbatim in both result dicts; iteration order follows `d`.
    """
    with_prefix: Dict = {}
    without_prefix: Dict = {}
    for key, value in d.items():
        target = with_prefix if key.startswith(prefix) else without_prefix
        target[key] = value
    return with_prefix, without_prefix
|
|
|
|
|
|
def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
    """Split `d` by key prefix, optionally stripping the prefix from the
    matched keys (the default); the remainder dict is returned unchanged."""
    matched, rest = group_dict_by_prefix(prefix, d)
    if keep_prefix:
        return matched, rest
    stripped = {key[len(prefix):]: value for key, value in matched.items()}
    return stripped, rest
|
|
|
|
|
|
"""
|
|
|
Convolutional Blocks
|
|
|
"""
|
|
|
import typing as tp
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
                                 padding_total: int = 0) -> int:
    """Return how many extra trailing samples are needed so the final
    convolution window is complete. See `pad_for_conv1d`."""
    length = x.shape[-1]
    # Fractional frame count; any fractional part means a partial last window.
    n_frames = (length - kernel_size + padding_total) / stride + 1
    full_frames = math.ceil(n_frames)
    ideal_length = (full_frames - 1) * stride + kernel_size - padding_total
    return ideal_length - length
|
|
|
|
|
|
|
|
|
def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
    """Right-pad `x` with zeros so the convolution's last window is full.

    This guarantees that an output of the original length can be rebuilt:
    without it, even with padding, a transposed convolution followed by
    padding removal can lose trailing time steps. Example with
    padding_total=4, kernel=4, stride=2:
        0 0 1 2 3 4 5 0 0   (0s are padding)
          1   2   3         (conv frames; the final 0 is never consumed)
        0 0 1 2 3 4 5 0     (tr. conv output; position 5 is trimmed as padding)
        1 2 3 4             (one time step is missing after unpadding!)
    """
    missing = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
    return F.pad(x, (0, missing))
|
|
|
|
|
|
|
|
|
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
    """Wrapper around F.pad that also supports reflect-padding inputs shorter
    than the requested pad: zeros are appended first so the reflection is
    well-defined, then stripped from the padded result.
    """
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode != 'reflect':
        return F.pad(x, paddings, mode, value)
    length = x.shape[-1]
    max_pad = max(padding_left, padding_right)
    # Reflection needs length > pad; grow the signal just enough.
    extra_pad = max_pad - length + 1 if length <= max_pad else 0
    if extra_pad > 0:
        x = F.pad(x, (0, extra_pad))
    padded = F.pad(x, paddings, mode, value)
    return padded[..., :padded.shape[-1] - extra_pad]
|
|
|
|
|
|
|
|
|
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Strip `paddings = (left, right)` samples from the last axis of `x`.

    Handles zero padding correctly (a right pad of 0 keeps the full tail).
    Only meant for 1d (last-axis) padding removal.
    """
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    assert (padding_left + padding_right) <= x.shape[-1]
    stop = x.shape[-1] - padding_right
    return x[..., padding_left:stop]
|
|
|
|
|
|
|
|
|
class Conv1d(nn.Conv1d):
    """nn.Conv1d that computes "same"-style padding dynamically at call time,
    with an optional causal mode (all padding on the left)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x: Tensor, causal=False) -> Tensor:
        kernel_size = self.kernel_size[0]
        stride = self.stride[0]
        dilation = self.dilation[0]
        # Effective kernel size once dilation is accounted for.
        kernel_size = (kernel_size - 1) * dilation + 1
        padding_total = kernel_size - stride
        # Extra right padding so the last window is complete (see helper).
        extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
        if causal:
            # Left pad only, so no output sample depends on future input.
            x = pad1d(x, (padding_total, extra_padding))
        else:
            # Split padding roughly evenly; the left side gets the odd sample.
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            x = pad1d(x, (padding_left, padding_right + extra_padding))
        return super().forward(x)
|
|
|
|
|
|
class ConvTranspose1d(nn.ConvTranspose1d):
    """nn.ConvTranspose1d that trims its own output padding after the fact,
    with an optional causal mode (trim entirely from the right)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x: Tensor, causal=False) -> Tensor:
        kernel_size = self.kernel_size[0]
        stride = self.stride[0]
        padding_total = kernel_size - stride

        y = super().forward(x)

        if causal:
            # NOTE(review): ceil() of an int is a no-op, so causal trimming
            # removes everything from the right (padding_left becomes 0) —
            # likely a leftover from a trim_right_ratio variant; confirm.
            padding_right = ceil(padding_total)
            padding_left = padding_total - padding_right
            y = unpad1d(y, (padding_left, padding_right))
        else:
            # Symmetric trim; left side absorbs the odd sample.
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            y = unpad1d(y, (padding_left, padding_right))
        return y
|
|
|
|
|
|
|
|
|
def Downsample1d(
    in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
) -> nn.Module:
    """Build a strided Conv1d that downsamples by `factor`.

    The kernel spans `factor * kernel_multiplier + 1` samples; the multiplier
    must be even so the kernel stays odd-sized.
    """
    assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"
    kernel_size = factor * kernel_multiplier + 1
    return Conv1d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=factor,
    )
|
|
|
|
|
|
|
|
|
def Upsample1d(
    in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
) -> nn.Module:
    """Build an upsampling module for the given `factor`.

    factor == 1: a plain size-preserving conv. Otherwise either nearest-
    neighbour upsampling followed by a conv, or a strided transposed conv.
    """
    if factor == 1:
        return Conv1d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=3
        )
    if use_nearest:
        return nn.Sequential(
            nn.Upsample(scale_factor=factor, mode="nearest"),
            Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
            ),
        )
    return ConvTranspose1d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=factor * 2,
        stride=factor,
    )
|
|
|
|
|
|
|
|
|
class ConvBlock1d(nn.Module):
    """Norm -> (optional FiLM scale/shift) -> activation -> conv, the basic
    unit used by ResnetBlock1d."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        kernel_size: int = 3,
        stride: int = 1,
        dilation: int = 1,
        num_groups: int = 8,
        use_norm: bool = True,
        use_snake: bool = False
    ) -> None:
        super().__init__()

        # GroupNorm over the input channels, or a pass-through when disabled.
        self.groupnorm = (
            nn.GroupNorm(num_groups=num_groups, num_channels=in_channels)
            if use_norm
            else nn.Identity()
        )

        # Snake (periodic) activation from the DAC codebase, else SiLU.
        if use_snake:
            self.activation = Snake1d(in_channels)
        else:
            self.activation = nn.SiLU()

        # Custom Conv1d (defined above) computes its own padding per call.
        self.project = Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
        )

    def forward(
        self, x: Tensor, scale_shift: Optional[Tuple[Tensor, Tensor]] = None, causal=False
    ) -> Tensor:
        x = self.groupnorm(x)
        if exists(scale_shift):
            # FiLM-style modulation; +1 keeps zero-scale near identity.
            scale, shift = scale_shift
            x = x * (scale + 1) + shift
        x = self.activation(x)
        return self.project(x, causal=causal)
|
|
|
|
|
|
|
|
|
class MappingToScaleShift(nn.Module):
    """Project a conditioning vector to per-channel (scale, shift) pairs for
    FiLM-style modulation of a 1d feature map."""

    def __init__(
        self,
        features: int,
        channels: int,
    ):
        super().__init__()

        # SiLU then a linear producing 2 * channels values
        # (one scale and one shift per channel).
        self.to_scale_shift = nn.Sequential(
            nn.SiLU(),
            nn.Linear(in_features=features, out_features=channels * 2),
        )

    def forward(self, mapping: Tensor) -> Tuple[Tensor, Tensor]:
        scale_shift = self.to_scale_shift(mapping)
        # (b, 2c) -> (b, 2c, 1) so both halves broadcast over the time axis.
        scale_shift = scale_shift.unsqueeze(-1)
        scale, shift = scale_shift.chunk(2, dim=1)
        return scale, shift
|
|
|
|
|
|
|
|
|
class ResnetBlock1d(nn.Module):
    """Two ConvBlock1d layers with a residual connection; the optional
    conditioning `mapping` is turned into a FiLM scale/shift applied between
    the two blocks."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        kernel_size: int = 3,
        stride: int = 1,
        dilation: int = 1,
        use_norm: bool = True,
        use_snake: bool = False,
        num_groups: int = 8,
        context_mapping_features: Optional[int] = None,
    ) -> None:
        super().__init__()

        self.use_mapping = exists(context_mapping_features)

        self.block1 = ConvBlock1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            use_norm=use_norm,
            num_groups=num_groups,
            use_snake=use_snake
        )

        if self.use_mapping:
            assert exists(context_mapping_features)
            self.to_scale_shift = MappingToScaleShift(
                features=context_mapping_features, channels=out_channels
            )

        # Second block keeps channel count; stride/dilation only in block1.
        self.block2 = ConvBlock1d(
            in_channels=out_channels,
            out_channels=out_channels,
            use_norm=use_norm,
            num_groups=num_groups,
            use_snake=use_snake
        )

        # 1x1 conv aligns the residual when channel counts differ.
        self.to_out = (
            Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
            if in_channels != out_channels
            else nn.Identity()
        )

    def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor:
        # XOR: mapping must be given iff the block was built with mapping features.
        assert_message = "context mapping required if context_mapping_features > 0"
        assert not (self.use_mapping ^ exists(mapping)), assert_message

        h = self.block1(x, causal=causal)

        scale_shift = None
        if self.use_mapping:
            scale_shift = self.to_scale_shift(mapping)

        h = self.block2(h, scale_shift=scale_shift, causal=causal)

        return h + self.to_out(x)
|
|
|
|
|
|
|
|
|
class Patcher(nn.Module):
    """Trade time resolution for channels: a resnet block followed by a
    rearrange that folds `patch_size` consecutive time steps into channels
    (length / p, channels * p)."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        patch_size: int,
        context_mapping_features: Optional[int] = None,
        use_snake: bool = False,
    ):
        super().__init__()
        assert_message = f"out_channels must be divisible by patch_size ({patch_size})"
        assert out_channels % patch_size == 0, assert_message
        self.patch_size = patch_size

        # The block emits out_channels // patch_size channels; the rearrange in
        # forward multiplies channels by patch_size to reach out_channels.
        self.block = ResnetBlock1d(
            in_channels=in_channels,
            out_channels=out_channels // patch_size,
            num_groups=1,
            context_mapping_features=context_mapping_features,
            use_snake=use_snake
        )

    def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor:
        x = self.block(x, mapping, causal=causal)
        # Fold p adjacent time steps into the channel axis.
        x = rearrange(x, "b c (l p) -> b (c p) l", p=self.patch_size)
        return x
|
|
|
|
|
|
|
|
|
class Unpatcher(nn.Module):
    """Inverse of Patcher: unfold `patch_size` channel groups back into time
    (length * p, channels / p), then apply a resnet block."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        patch_size: int,
        context_mapping_features: Optional[int] = None,
        use_snake: bool = False
    ):
        super().__init__()
        assert_message = f"in_channels must be divisible by patch_size ({patch_size})"
        assert in_channels % patch_size == 0, assert_message
        self.patch_size = patch_size

        # The rearrange in forward divides channels by patch_size first, so
        # the block consumes in_channels // patch_size.
        self.block = ResnetBlock1d(
            in_channels=in_channels // patch_size,
            out_channels=out_channels,
            num_groups=1,
            context_mapping_features=context_mapping_features,
            use_snake=use_snake
        )

    def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor:
        # Unfold channel groups back onto the time axis (exact inverse of Patcher).
        x = rearrange(x, " b (c p) l -> b c (l p) ", p=self.patch_size)
        x = self.block(x, mapping, causal=causal)
        return x
|
|
|
|
|
|
|
|
|
"""
|
|
|
Attention Components
|
|
|
"""
|
|
|
def FeedForward(features: int, multiplier: int) -> nn.Module:
    """Two-layer GELU MLP: expand to features * multiplier, project back."""
    hidden = features * multiplier
    return nn.Sequential(
        nn.Linear(in_features=features, out_features=hidden),
        nn.GELU(),
        nn.Linear(in_features=hidden, out_features=features),
    )
|
|
|
|
|
|
def add_mask(sim: Tensor, mask: Tensor) -> Tensor:
    """Fill masked-out (False) positions of attention scores `sim` with the
    most negative representable value, after broadcasting a 2d (n, m) or 3d
    (b, n, m) mask to (b, 1, n, m)."""
    batch = sim.shape[0]
    if mask.ndim == 3:
        mask = mask.unsqueeze(1)  # (b, n, m) -> (b, 1, n, m)
    elif mask.ndim == 2:
        mask = mask[None, None].repeat(batch, 1, 1, 1)  # (n, m) -> (b, 1, n, m)
    neg_limit = -torch.finfo(sim.dtype).max
    return sim.masked_fill(~mask, neg_limit)
|
|
|
|
|
|
def causal_mask(q: Tensor, k: Tensor) -> Tensor:
    """Boolean (b, n_q, n_k) mask, True on and below the offset diagonal:
    each query may attend only to keys at the same or an earlier offset."""
    batch, n_q, n_k = q.shape[0], q.shape[-2], k.shape[-2]
    upper = torch.ones((n_q, n_k), dtype=torch.bool, device=q.device).triu(n_k - n_q + 1)
    allowed = ~upper
    return allowed.unsqueeze(0).repeat(batch, 1, 1)
|
|
|
|
|
|
class AttentionBase(nn.Module):
    """Multi-head scaled-dot-product attention core.

    Expects q/k/v already projected to (batch, n, num_heads * head_features).
    Splits heads, attends — via torch's fused SDP kernels when CUDA and
    torch >= 2.0 are available, otherwise a manual einsum path — and projects
    back to `out_features` (defaults to `features`).
    """

    def __init__(
        self,
        features: int,
        *,
        head_features: int,
        num_heads: int,
        out_features: Optional[int] = None,
    ):
        super().__init__()
        self.scale = head_features**-0.5
        self.num_heads = num_heads
        mid_features = head_features * num_heads
        out_features = default(out_features, features)

        self.to_out = nn.Linear(
            in_features=mid_features, out_features=out_features
        )

        # Fused/flash attention requires torch >= 2.0 and a CUDA device.
        self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')

        if not self.use_flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            # A100 (sm80): enable only the flash kernel.
            self.sdp_kernel_config = (True, False, False)
        else:
            # Other GPUs: math and memory-efficient kernels.
            self.sdp_kernel_config = (False, True, True)

    def forward(
        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, is_causal: bool = False
    ) -> Tensor:
        # Split heads: (b, n, h * d) -> (b, h, n, d)
        q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads)

        if not self.use_flash:
            # FIX: was `if is_causal and not mask:` — `not tensor` raises
            # "Boolean value of Tensor ... is ambiguous" for multi-element
            # masks; test for absence explicitly instead.
            if is_causal and mask is None:
                # Default to a lower-triangular causal mask.
                mask = causal_mask(q, k)

            sim = einsum("... n d, ... m d -> ... n m", q, k) * self.scale
            sim = add_mask(sim, mask) if exists(mask) else sim

            # Softmax in float32 for numerical stability with low-precision inputs.
            attn = sim.softmax(dim=-1, dtype=torch.float32)

            out = einsum("... n m, ... m d -> ... n d", attn, v)
        else:
            with sdp_kernel(*self.sdp_kernel_config):
                out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=is_causal)

        # Merge heads back: (b, h, n, d) -> (b, n, h * d)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)
|
|
|
|
|
|
class Attention(nn.Module):
    """Self- or cross-attention layer: LayerNorms, q/kv projections, then
    AttentionBase. When `context_features` is set, k/v come from `context`
    (cross-attention); otherwise everything comes from `x`."""

    def __init__(
        self,
        features: int,
        *,
        head_features: int,
        num_heads: int,
        out_features: Optional[int] = None,
        context_features: Optional[int] = None,
        causal: bool = False,
    ):
        super().__init__()
        self.context_features = context_features
        self.causal = causal
        mid_features = head_features * num_heads
        context_features = default(context_features, features)

        self.norm = nn.LayerNorm(features)
        self.norm_context = nn.LayerNorm(context_features)
        self.to_q = nn.Linear(
            in_features=features, out_features=mid_features, bias=False
        )
        # Single projection producing both k and v (chunked in forward).
        self.to_kv = nn.Linear(
            in_features=context_features, out_features=mid_features * 2, bias=False
        )
        self.attention = AttentionBase(
            features,
            num_heads=num_heads,
            head_features=head_features,
            out_features=out_features,
        )

    def forward(
        self,
        x: Tensor,
        context: Optional[Tensor] = None,
        context_mask: Optional[Tensor] = None,
        causal: Optional[bool] = False,
    ) -> Tensor:
        assert_message = "You must provide a context when using context_features"
        assert not self.context_features or exists(context), assert_message

        # Self-attention falls back to x as its own context.
        context = default(context, x)

        x, context = self.norm(x), self.norm_context(context)

        q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))

        if exists(context_mask):
            # NOTE(review): masked context positions are zeroed in k/v rather
            # than masked out of the attention logits, so they can still
            # receive non-zero attention weight — confirm this is intended.
            mask = repeat(context_mask, "b m -> b m d", d=v.shape[-1])
            k, v = k * mask, v * mask

        # Compute and return attention
        return self.attention(q, k, v, is_causal=self.causal or causal)
|
|
|
|
|
|
|
|
|
# NOTE(review): this re-defines FeedForward with an identical body to the
# definition earlier in this file; the later definition wins at import time.
# One of the two should be removed.
def FeedForward(features: int, multiplier: int) -> nn.Module:
    """Two-layer GELU MLP: expand to features * multiplier, project back."""
    mid_features = features * multiplier
    return nn.Sequential(
        nn.Linear(in_features=features, out_features=mid_features),
        nn.GELU(),
        nn.Linear(in_features=mid_features, out_features=features),
    )
|
|
|
|
|
|
"""
|
|
|
Transformer Blocks
|
|
|
"""
|
|
|
|
|
|
|
|
|
class TransformerBlock(nn.Module):
    """Pre-norm transformer block with residuals: self-attention, optional
    cross-attention on `context`, then a feed-forward MLP."""

    def __init__(
        self,
        features: int,
        num_heads: int,
        head_features: int,
        multiplier: int,
        context_features: Optional[int] = None,
    ):
        super().__init__()

        # Cross-attention only when context features are declared (> 0).
        self.use_cross_attention = exists(context_features) and context_features > 0

        self.attention = Attention(
            features=features,
            num_heads=num_heads,
            head_features=head_features
        )

        if self.use_cross_attention:
            self.cross_attention = Attention(
                features=features,
                num_heads=num_heads,
                head_features=head_features,
                context_features=context_features
            )

        self.feed_forward = FeedForward(features=features, multiplier=multiplier)

    def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal: Optional[bool] = False) -> Tensor:
        x = self.attention(x, causal=causal) + x
        if self.use_cross_attention:
            x = self.cross_attention(x, context=context, context_mask=context_mask) + x
        x = self.feed_forward(x) + x
        return x
|
|
|
|
|
|
|
|
|
"""
|
|
|
Transformers
|
|
|
"""
|
|
|
|
|
|
|
|
|
class Transformer1d(nn.Module):
    """Stack of TransformerBlocks operating on a (b, c, t) feature map:
    GroupNorm + 1x1 conv in, blocks over the time axis, 1x1 conv out."""

    def __init__(
        self,
        num_layers: int,
        channels: int,
        num_heads: int,
        head_features: int,
        multiplier: int,
        context_features: Optional[int] = None,
    ):
        super().__init__()

        # (b, c, t) -> normalized, projected, then transposed to (b, t, c)
        # for the attention blocks. The Conv1d here is the file's custom class;
        # inside nn.Sequential its `causal` argument keeps its default (False).
        self.to_in = nn.Sequential(
            nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True),
            Conv1d(
                in_channels=channels,
                out_channels=channels,
                kernel_size=1,
            ),
            Rearrange("b c t -> b t c"),
        )

        self.blocks = nn.ModuleList(
            [
                TransformerBlock(
                    features=channels,
                    head_features=head_features,
                    num_heads=num_heads,
                    multiplier=multiplier,
                    context_features=context_features,
                )
                for i in range(num_layers)
            ]
        )

        # Back to (b, c, t) and a final 1x1 projection.
        self.to_out = nn.Sequential(
            Rearrange("b t c -> b c t"),
            Conv1d(
                in_channels=channels,
                out_channels=channels,
                kernel_size=1,
            ),
        )

    def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal=False) -> Tensor:
        x = self.to_in(x)
        for block in self.blocks:
            x = block(x, context=context, context_mask=context_mask, causal=causal)
        x = self.to_out(x)
        return x
|
|
|
|
|
|
|
|
|
"""
|
|
|
Time Embeddings
|
|
|
"""
|
|
|
|
|
|
|
|
|
class SinusoidalEmbedding(nn.Module):
    """Fixed sinusoidal embedding: maps a 1d tensor of positions/timesteps to
    (len, dim) features built from sin/cos at geometrically spaced
    frequencies (transformer-style)."""

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, x: Tensor) -> Tensor:
        device, half_dim = x.device, self.dim // 2
        # Frequency exponents decay geometrically from 1 down to ~1/10000.
        decay = torch.tensor(log(10000) / (half_dim - 1), device=device)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -decay)
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
|
|
|
|
|
|
|
|
|
class LearnedPositionalEmbedding(nn.Module):
    """Used for continuous time: random-Fourier features with learned
    frequencies, concatenated with the raw input value (hence dim + 1
    output features)."""

    def __init__(self, dim: int):
        super().__init__()
        assert (dim % 2) == 0
        half_dim = dim // 2
        self.weights = nn.Parameter(torch.randn(half_dim))

    def forward(self, x: Tensor) -> Tensor:
        x = x[:, None]  # (b,) -> (b, 1)
        freqs = x * self.weights[None, :] * 2 * pi
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
        fouriered = torch.cat((x, fouriered), dim=-1)
        return fouriered
|
|
|
|
|
|
|
|
|
def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
    """Continuous-time embedding: learned Fourier features plus the raw time
    value (hence `dim + 1` inputs), projected linearly to `out_features`."""
    return nn.Sequential(
        LearnedPositionalEmbedding(dim),
        nn.Linear(in_features=dim + 1, out_features=out_features),
    )
|
|
|
|
|
|
|
|
|
"""
|
|
|
Encoder/Decoder Components
|
|
|
"""
|
|
|
|
|
|
|
|
|
class DownsampleBlock1d(nn.Module):
    """Encoder-side UNet stage: downsample (before or after), resnet blocks
    with optional FiLM conditioning, an optional transformer, and optional
    auxiliary 'extract' output channels."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        factor: int,
        num_groups: int,
        num_layers: int,
        kernel_multiplier: int = 2,
        use_pre_downsample: bool = True,
        use_skip: bool = False,
        use_snake: bool = False,
        extract_channels: int = 0,
        context_channels: int = 0,
        num_transformer_blocks: int = 0,
        attention_heads: Optional[int] = None,
        attention_features: Optional[int] = None,
        attention_multiplier: Optional[int] = None,
        context_mapping_features: Optional[int] = None,
        context_embedding_features: Optional[int] = None,
    ):
        super().__init__()
        self.use_pre_downsample = use_pre_downsample
        self.use_skip = use_skip
        self.use_transformer = num_transformer_blocks > 0
        self.use_extract = extract_channels > 0
        self.use_context = context_channels > 0

        # Blocks run at the post-downsample width when downsampling first.
        channels = out_channels if use_pre_downsample else in_channels

        self.downsample = Downsample1d(
            in_channels=in_channels,
            out_channels=out_channels,
            factor=factor,
            kernel_multiplier=kernel_multiplier,
        )

        # First block also absorbs concatenated context channels (see forward).
        self.blocks = nn.ModuleList(
            [
                ResnetBlock1d(
                    in_channels=channels + context_channels if i == 0 else channels,
                    out_channels=channels,
                    num_groups=num_groups,
                    context_mapping_features=context_mapping_features,
                    use_snake=use_snake
                )
                for i in range(num_layers)
            ]
        )

        if self.use_transformer:
            # Need heads or per-head features (the other is derived) plus a multiplier.
            assert (
                (exists(attention_heads) or exists(attention_features))
                and exists(attention_multiplier)
            )

            if attention_features is None and attention_heads is not None:
                attention_features = channels // attention_heads

            if attention_heads is None and attention_features is not None:
                attention_heads = channels // attention_features

            self.transformer = Transformer1d(
                num_layers=num_transformer_blocks,
                channels=channels,
                num_heads=attention_heads,
                head_features=attention_features,
                multiplier=attention_multiplier,
                context_features=context_embedding_features
            )

        if self.use_extract:
            num_extract_groups = min(num_groups, extract_channels)
            self.to_extracted = ResnetBlock1d(
                in_channels=out_channels,
                out_channels=extract_channels,
                num_groups=num_extract_groups,
                use_snake=use_snake
            )

    def forward(
        self,
        x: Tensor,
        *,
        mapping: Optional[Tensor] = None,
        channels: Optional[Tensor] = None,
        embedding: Optional[Tensor] = None,
        embedding_mask: Optional[Tensor] = None,
        causal: Optional[bool] = False
    ) -> Union[Tuple[Tensor, List[Tensor]], Tensor]:

        if self.use_pre_downsample:
            x = self.downsample(x)

        # Concatenate injected context channels before the first resnet block.
        if self.use_context and exists(channels):
            x = torch.cat([x, channels], dim=1)

        skips = []
        for block in self.blocks:
            x = block(x, mapping=mapping, causal=causal)
            skips += [x] if self.use_skip else []

        if self.use_transformer:
            x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal)
            skips += [x] if self.use_skip else []

        if not self.use_pre_downsample:
            x = self.downsample(x)

        if self.use_extract:
            # NOTE(review): the extract path returns (x, extracted) and drops
            # any collected skips — confirm use_skip and use_extract are never
            # enabled together.
            extracted = self.to_extracted(x)
            return x, extracted

        return (x, skips) if self.use_skip else x
|
|
|
|
|
|
|
|
|
class UpsampleBlock1d(nn.Module):
    """Decoder-side UNet stage: resnet blocks consuming skip connections,
    an optional transformer, then upsampling (before or after), with optional
    auxiliary 'extract' output channels."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        factor: int,
        num_layers: int,
        num_groups: int,
        use_nearest: bool = False,
        use_pre_upsample: bool = False,
        use_skip: bool = False,
        use_snake: bool = False,
        skip_channels: int = 0,
        use_skip_scale: bool = False,
        extract_channels: int = 0,
        num_transformer_blocks: int = 0,
        attention_heads: Optional[int] = None,
        attention_features: Optional[int] = None,
        attention_multiplier: Optional[int] = None,
        context_mapping_features: Optional[int] = None,
        context_embedding_features: Optional[int] = None,
    ):
        super().__init__()

        self.use_extract = extract_channels > 0
        self.use_pre_upsample = use_pre_upsample
        self.use_transformer = num_transformer_blocks > 0
        self.use_skip = use_skip
        # 1/sqrt(2) keeps variance roughly constant when adding skips.
        self.skip_scale = 2 ** -0.5 if use_skip_scale else 1.0

        # Blocks run at the post-upsample width when upsampling first.
        channels = out_channels if use_pre_upsample else in_channels

        # Each block consumes its skip via channel concatenation.
        self.blocks = nn.ModuleList(
            [
                ResnetBlock1d(
                    in_channels=channels + skip_channels,
                    out_channels=channels,
                    num_groups=num_groups,
                    context_mapping_features=context_mapping_features,
                    use_snake=use_snake
                )
                for _ in range(num_layers)
            ]
        )

        if self.use_transformer:
            # Need heads or per-head features (the other is derived) plus a multiplier.
            assert (
                (exists(attention_heads) or exists(attention_features))
                and exists(attention_multiplier)
            )

            if attention_features is None and attention_heads is not None:
                attention_features = channels // attention_heads

            if attention_heads is None and attention_features is not None:
                attention_heads = channels // attention_features

            self.transformer = Transformer1d(
                num_layers=num_transformer_blocks,
                channels=channels,
                num_heads=attention_heads,
                head_features=attention_features,
                multiplier=attention_multiplier,
                context_features=context_embedding_features,
            )

        self.upsample = Upsample1d(
            in_channels=in_channels,
            out_channels=out_channels,
            factor=factor,
            use_nearest=use_nearest,
        )

        if self.use_extract:
            num_extract_groups = min(num_groups, extract_channels)
            self.to_extracted = ResnetBlock1d(
                in_channels=out_channels,
                out_channels=extract_channels,
                num_groups=num_extract_groups,
                use_snake=use_snake
            )

    def add_skip(self, x: Tensor, skip: Tensor) -> Tensor:
        # Concatenate the (optionally rescaled) skip along channels.
        return torch.cat([x, skip * self.skip_scale], dim=1)

    def forward(
        self,
        x: Tensor,
        *,
        skips: Optional[List[Tensor]] = None,
        mapping: Optional[Tensor] = None,
        embedding: Optional[Tensor] = None,
        embedding_mask: Optional[Tensor] = None,
        causal: Optional[bool] = False
    ) -> Union[Tuple[Tensor, Tensor], Tensor]:

        if self.use_pre_upsample:
            x = self.upsample(x)

        for block in self.blocks:
            # NOTE(review): pop() consumes the caller's list in place, and
            # exists([]) is True — an empty (but non-None) skips list would
            # raise IndexError here; callers must size skips to num_layers.
            x = self.add_skip(x, skip=skips.pop()) if exists(skips) else x
            x = block(x, mapping=mapping, causal=causal)

        if self.use_transformer:
            x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal)

        if not self.use_pre_upsample:
            x = self.upsample(x)

        if self.use_extract:
            extracted = self.to_extracted(x)
            return x, extracted

        return x
|
|
|
|
|
|
|
|
|
class BottleneckBlock1d(nn.Module):
    """UNet bottleneck: resnet block, optional transformer, resnet block,
    all at a fixed channel width."""

    def __init__(
        self,
        channels: int,
        *,
        num_groups: int,
        num_transformer_blocks: int = 0,
        attention_heads: Optional[int] = None,
        attention_features: Optional[int] = None,
        attention_multiplier: Optional[int] = None,
        context_mapping_features: Optional[int] = None,
        context_embedding_features: Optional[int] = None,
        use_snake: bool = False,
    ):
        super().__init__()
        self.use_transformer = num_transformer_blocks > 0

        self.pre_block = ResnetBlock1d(
            in_channels=channels,
            out_channels=channels,
            num_groups=num_groups,
            context_mapping_features=context_mapping_features,
            use_snake=use_snake
        )

        if self.use_transformer:
            # Need heads or per-head features (the other is derived) plus a multiplier.
            assert (
                (exists(attention_heads) or exists(attention_features))
                and exists(attention_multiplier)
            )

            if attention_features is None and attention_heads is not None:
                attention_features = channels // attention_heads

            if attention_heads is None and attention_features is not None:
                attention_heads = channels // attention_features

            self.transformer = Transformer1d(
                num_layers=num_transformer_blocks,
                channels=channels,
                num_heads=attention_heads,
                head_features=attention_features,
                multiplier=attention_multiplier,
                context_features=context_embedding_features,
            )

        self.post_block = ResnetBlock1d(
            in_channels=channels,
            out_channels=channels,
            num_groups=num_groups,
            context_mapping_features=context_mapping_features,
            use_snake=use_snake
        )

    def forward(
        self,
        x: Tensor,
        *,
        mapping: Optional[Tensor] = None,
        embedding: Optional[Tensor] = None,
        embedding_mask: Optional[Tensor] = None,
        causal: Optional[bool] = False
    ) -> Tensor:
        x = self.pre_block(x, mapping=mapping, causal=causal)
        if self.use_transformer:
            x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal)
        x = self.post_block(x, mapping=mapping, causal=causal)
        return x
|
|
|
|
|
|
|
|
|
"""
|
|
|
UNet
|
|
|
"""
|
|
|
|
|
|
|
|
|
class UNet1d(nn.Module):
    """1-D U-Net over `len(multipliers) - 1` resolution levels plus a bottleneck.

    Optional conditioning paths (each independently switchable):
      * time conditioning (`use_context_time`) via a positional embedding,
      * global feature conditioning (`context_features`),
      * per-level channel conditioning (`context_channels`) concatenated to the
        activations at the matching resolution,
      * cross-attention conditioning (`context_embedding_features`) inside the
        transformer blocks selected per level through `attentions`,
      * an STFT front/back-end (`use_stft`) so the network operates on stacked
        spectrogram channels instead of raw waveforms.

    Extra `attention_*` kwargs are forwarded (prefix kept) to every block;
    `stft_*` kwargs configure the STFT when `use_stft=True`.
    """

    def __init__(
        self,
        in_channels: int,
        channels: int,
        multipliers: Sequence[int],
        factors: Sequence[int],
        num_blocks: Sequence[int],
        attentions: Sequence[int],
        patch_size: int = 1,
        resnet_groups: int = 8,
        use_context_time: bool = True,
        kernel_multiplier_downsample: int = 2,
        use_nearest_upsample: bool = False,
        use_skip_scale: bool = True,
        use_snake: bool = False,
        use_stft: bool = False,
        use_stft_context: bool = False,
        out_channels: Optional[int] = None,
        context_features: Optional[int] = None,
        context_features_multiplier: int = 4,
        context_channels: Optional[Sequence[int]] = None,
        context_embedding_features: Optional[int] = None,
        **kwargs,
    ):
        super().__init__()
        out_channels = default(out_channels, in_channels)
        context_channels = list(default(context_channels, []))
        num_layers = len(multipliers) - 1
        use_context_features = exists(context_features)
        use_context_channels = len(context_channels) > 0
        context_mapping_features = None

        # Split off "attention_*" kwargs; prefix is kept so blocks receive them as-is.
        attention_kwargs, kwargs = groupby("attention_", kwargs, keep_prefix=True)

        self.num_layers = num_layers
        self.use_context_time = use_context_time
        self.use_context_features = use_context_features
        self.use_context_channels = use_context_channels
        self.use_stft = use_stft
        self.use_stft_context = use_stft_context

        self.context_features = context_features
        # Pad with zeros so there is one context-channel entry per level (+1 for input).
        context_channels_pad_length = num_layers + 1 - len(context_channels)
        context_channels = context_channels + [0] * context_channels_pad_length
        self.context_channels = context_channels
        self.context_embedding_features = context_embedding_features

        if use_context_channels:
            has_context = [c > 0 for c in context_channels]
            self.has_context = has_context
            # channels_ids[layer] = index into the dense user-provided channels_list
            # (only layers with context consume an entry).
            self.channels_ids = [sum(has_context[:i]) for i in range(len(has_context))]

        assert (
            len(factors) == num_layers
            and len(attentions) >= num_layers
            and len(num_blocks) == num_layers
        )

        if use_context_time or use_context_features:
            context_mapping_features = channels * context_features_multiplier

            # Shared MLP merging time/feature conditioning into a single mapping vector.
            self.to_mapping = nn.Sequential(
                nn.Linear(context_mapping_features, context_mapping_features),
                nn.GELU(),
                nn.Linear(context_mapping_features, context_mapping_features),
                nn.GELU(),
            )

        if use_context_time:
            assert exists(context_mapping_features)
            self.to_time = nn.Sequential(
                TimePositionalEmbedding(
                    dim=channels, out_features=context_mapping_features
                ),
                nn.GELU(),
            )

        if use_context_features:
            assert exists(context_features) and exists(context_mapping_features)
            self.to_features = nn.Sequential(
                nn.Linear(
                    in_features=context_features, out_features=context_mapping_features
                ),
                nn.GELU(),
            )

        if use_stft:
            stft_kwargs, kwargs = groupby("stft_", kwargs)
            assert "num_fft" in stft_kwargs, "stft_num_fft required if use_stft=True"
            # Each waveform channel expands to (num_fft // 2 + 1) * 2 spectrogram
            # channels (two components per frequency bin).
            stft_channels = (stft_kwargs["num_fft"] // 2 + 1) * 2
            in_channels *= stft_channels
            out_channels *= stft_channels
            context_channels[0] *= stft_channels if use_stft_context else 1
            assert exists(in_channels) and exists(out_channels)
            self.stft = STFT(**stft_kwargs)

        assert not kwargs, f"Unknown arguments: {', '.join(list(kwargs.keys()))}"

        self.to_in = Patcher(
            in_channels=in_channels + context_channels[0],
            out_channels=channels * multipliers[0],
            patch_size=patch_size,
            context_mapping_features=context_mapping_features,
            use_snake=use_snake
        )

        self.downsamples = nn.ModuleList(
            [
                DownsampleBlock1d(
                    in_channels=channels * multipliers[i],
                    out_channels=channels * multipliers[i + 1],
                    context_mapping_features=context_mapping_features,
                    context_channels=context_channels[i + 1],
                    context_embedding_features=context_embedding_features,
                    num_layers=num_blocks[i],
                    factor=factors[i],
                    kernel_multiplier=kernel_multiplier_downsample,
                    num_groups=resnet_groups,
                    use_pre_downsample=True,
                    use_skip=True,
                    use_snake=use_snake,
                    num_transformer_blocks=attentions[i],
                    **attention_kwargs,
                )
                for i in range(num_layers)
            ]
        )

        self.bottleneck = BottleneckBlock1d(
            channels=channels * multipliers[-1],
            context_mapping_features=context_mapping_features,
            context_embedding_features=context_embedding_features,
            num_groups=resnet_groups,
            num_transformer_blocks=attentions[-1],
            use_snake=use_snake,
            **attention_kwargs,
        )

        self.upsamples = nn.ModuleList(
            [
                UpsampleBlock1d(
                    in_channels=channels * multipliers[i + 1],
                    out_channels=channels * multipliers[i],
                    context_mapping_features=context_mapping_features,
                    context_embedding_features=context_embedding_features,
                    # One extra layer balances the attention block on the down path.
                    num_layers=num_blocks[i] + (1 if attentions[i] else 0),
                    factor=factors[i],
                    use_nearest=use_nearest_upsample,
                    num_groups=resnet_groups,
                    use_skip_scale=use_skip_scale,
                    use_pre_upsample=False,
                    use_skip=True,
                    use_snake=use_snake,
                    skip_channels=channels * multipliers[i + 1],
                    num_transformer_blocks=attentions[i],
                    **attention_kwargs,
                )
                for i in reversed(range(num_layers))
            ]
        )

        self.to_out = Unpatcher(
            in_channels=channels * multipliers[0],
            out_channels=out_channels,
            patch_size=patch_size,
            context_mapping_features=context_mapping_features,
            use_snake=use_snake
        )

    def get_channels(
        self, channels_list: Optional[Sequence[Tensor]] = None, layer: int = 0
    ) -> Optional[Tensor]:
        """Gets context channels at `layer` and checks that shape is correct"""
        use_context_channels = self.use_context_channels and self.has_context[layer]
        if not use_context_channels:
            return None
        assert exists(channels_list), "Missing context"

        # Dense index into channels_list (only layers with context consume entries).
        channels_id = self.channels_ids[layer]

        channels = channels_list[channels_id]
        message = f"Missing context for layer {layer} at index {channels_id}"
        assert exists(channels), message

        num_channels = self.context_channels[layer]
        message = f"Expected context with {num_channels} channels at idx {channels_id}"
        assert channels.shape[1] == num_channels, message

        # STFT-encode the context channels if conditioning happens in STFT space.
        channels = self.stft.encode1d(channels) if self.use_stft_context else channels
        return channels

    def get_mapping(
        self, time: Optional[Tensor] = None, features: Optional[Tensor] = None
    ) -> Optional[Tensor]:
        """Combines context time features and features into mapping"""
        items, mapping = [], None

        if self.use_context_time:
            assert_message = "use_context_time=True but no time features provided"
            assert exists(time), assert_message
            items += [self.to_time(time)]

        if self.use_context_features:
            assert_message = "context_features exists but no features provided"
            assert exists(features), assert_message
            items += [self.to_features(features)]

        if self.use_context_time or self.use_context_features:
            # Sum the (1 or 2) conditioning vectors, then project through the MLP.
            mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
            mapping = self.to_mapping(mapping)
        return mapping

    def forward(
        self,
        x: Tensor,
        time: Optional[Tensor] = None,
        *,
        features: Optional[Tensor] = None,
        channels_list: Optional[Sequence[Tensor]] = None,
        embedding: Optional[Tensor] = None,
        embedding_mask: Optional[Tensor] = None,
        causal: Optional[bool] = False,
    ) -> Tensor:
        """Runs the UNet; returns a tensor with the same layout as the input.

        `channels_list` supplies per-level context channels, `embedding`/
        `embedding_mask` feed cross-attention, `time`/`features` build the
        conditioning mapping.
        """
        # Level-0 context is concatenated to the (possibly STFT-encoded) input.
        channels = self.get_channels(channels_list, layer=0)

        x = self.stft.encode1d(x) if self.use_stft else x

        x = torch.cat([x, channels], dim=1) if exists(channels) else x

        mapping = self.get_mapping(time, features)
        x = self.to_in(x, mapping, causal=causal)
        skips_list = [x]

        for i, downsample in enumerate(self.downsamples):
            channels = self.get_channels(channels_list, layer=i + 1)
            x, skips = downsample(
                x, mapping=mapping, channels=channels, embedding=embedding, embedding_mask=embedding_mask, causal=causal
            )
            skips_list += [skips]

        x = self.bottleneck(x, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal)

        for i, upsample in enumerate(self.upsamples):
            skips = skips_list.pop()
            x = upsample(x, skips=skips, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal)

        # Remaining skip is the patched input: residual connection around the UNet.
        x += skips_list.pop()
        x = self.to_out(x, mapping, causal=causal)
        x = self.stft.decode1d(x) if self.use_stft else x

        return x
|
|
|
|
|
|
|
""" Conditioning Modules """
|
|
|
|
|
|
|
|
|
class FixedEmbedding(nn.Module):
    """Learned positional embedding, broadcast over the batch.

    Given any input of shape (batch, length, ...), returns the embedding of
    positions 0..length-1 repeated for every batch element.
    """

    def __init__(self, max_length: int, features: int):
        super().__init__()
        self.max_length = max_length
        self.embedding = nn.Embedding(max_length, features)

    def forward(self, x: Tensor) -> Tensor:
        b, n = x.shape[0], x.shape[1]
        assert n <= self.max_length, "Input sequence length must be <= max_length"
        positions = torch.arange(n, device=x.device)
        # (n, d) -> (b, n, d), copied per batch element.
        return self.embedding(positions).unsqueeze(0).repeat(b, 1, 1)
|
|
|
|
|
|
|
|
|
def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor:
    """Sample a boolean tensor of `shape` where each element is True with
    probability `proba`; the endpoints 0 and 1 are handled deterministically."""
    if proba == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    if proba == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    probs = torch.full(shape, proba, device=device)
    return torch.bernoulli(probs).to(torch.bool)
|
|
|
|
|
|
|
|
|
class UNetCFG1d(UNet1d):

    """UNet1d with Classifier-Free Guidance.

    Keeps a learned "fixed" (unconditional) embedding; at train time the real
    embedding is randomly replaced by it (`embedding_mask_proba`), and at
    sampling time conditional and unconditional outputs are blended with
    `embedding_scale`. Optionally appends the diffusion timestep as an extra
    cross-attention token (`use_xattn_time`).
    """

    def __init__(
        self,
        context_embedding_max_length: int,
        context_embedding_features: int,
        use_xattn_time: bool = False,
        **kwargs,
    ):
        super().__init__(
            context_embedding_features=context_embedding_features, **kwargs
        )

        self.use_xattn_time = use_xattn_time

        if use_xattn_time:
            assert exists(context_embedding_features)
            # Projects the timestep into one extra cross-attention token.
            self.to_time_embedding = nn.Sequential(
                TimePositionalEmbedding(
                    dim=kwargs["channels"], out_features=context_embedding_features
                ),
                nn.GELU(),
            )

            # Reserve one extra position for the appended time token.
            context_embedding_max_length += 1

        # Learned unconditional embedding used when the context is masked out.
        self.fixed_embedding = FixedEmbedding(
            max_length=context_embedding_max_length, features=context_embedding_features
        )

    def forward(
        self,
        x: Tensor,
        time: Tensor,
        *,
        embedding: Tensor,
        embedding_mask: Optional[Tensor] = None,
        embedding_scale: float = 1.0,
        embedding_mask_proba: float = 0.0,
        batch_cfg: bool = False,
        rescale_cfg: bool = False,
        scale_phi: float = 0.4,
        negative_embedding: Optional[Tensor] = None,
        negative_embedding_mask: Optional[Tensor] = None,
        **kwargs,
    ) -> Tensor:
        """Runs the UNet with classifier-free guidance.

        When `embedding_scale != 1.0` two passes are made (conditional and
        unconditional/negative) — either batched together (`batch_cfg=True`)
        or sequentially — and blended as
        `out_masked + (out - out_masked) * embedding_scale`.
        `rescale_cfg` applies the CFG-rescale trick with mixing factor
        `scale_phi` to restore the conditional output's per-channel std.
        """
        b, device = embedding.shape[0], embedding.device

        if self.use_xattn_time:
            # Append the time token; extend the mask so it is always attended to.
            embedding = torch.cat([embedding, self.to_time_embedding(time).unsqueeze(1)], dim=1)

            if embedding_mask is not None:
                embedding_mask = torch.cat([embedding_mask, torch.ones((b, 1), device=device)], dim=1)

        fixed_embedding = self.fixed_embedding(embedding)

        if embedding_mask_proba > 0.0:
            # Training-time CFG dropout: whole-sample replacement by the
            # unconditional embedding with probability embedding_mask_proba.
            batch_mask = rand_bool(
                shape=(b, 1, 1), proba=embedding_mask_proba, device=device
            )
            embedding = torch.where(batch_mask, fixed_embedding, embedding)

        if embedding_scale != 1.0:
            if batch_cfg:
                # Run conditional and unconditional passes in one doubled batch.
                batch_x = torch.cat([x, x], dim=0)
                batch_time = torch.cat([time, time], dim=0)

                if negative_embedding is not None:
                    if negative_embedding_mask is not None:
                        # Masked-out negative positions fall back to the
                        # unconditional embedding.
                        negative_embedding_mask = negative_embedding_mask.to(torch.bool).unsqueeze(2)

                        negative_embedding = torch.where(negative_embedding_mask, negative_embedding, fixed_embedding)

                    batch_embed = torch.cat([embedding, negative_embedding], dim=0)

                else:
                    batch_embed = torch.cat([embedding, fixed_embedding], dim=0)

                batch_mask = None
                if embedding_mask is not None:
                    batch_mask = torch.cat([embedding_mask, embedding_mask], dim=0)

                batch_features = None
                features = kwargs.pop("features", None)
                if self.use_context_features:
                    batch_features = torch.cat([features, features], dim=0)

                batch_channels = None
                channels_list = kwargs.pop("channels_list", None)
                if self.use_context_channels:
                    batch_channels = []
                    for channels in channels_list:
                        batch_channels += [torch.cat([channels, channels], dim=0)]

                # Single forward over the doubled batch, then split halves.
                batch_out = super().forward(batch_x, batch_time, embedding=batch_embed, embedding_mask=batch_mask, features=batch_features, channels_list=batch_channels, **kwargs)
                out, out_masked = batch_out.chunk(2, dim=0)

            else:
                # Two sequential passes (lower memory than batch_cfg).
                out = super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs)
                out_masked = super().forward(x, time, embedding=fixed_embedding, embedding_mask=embedding_mask, **kwargs)

            out_cfg = out_masked + (out - out_masked) * embedding_scale

            if rescale_cfg:
                # CFG rescale: match the guided output's std (over dim=1) to the
                # conditional output's, mixed in by scale_phi.
                out_std = out.std(dim=1, keepdim=True)
                out_cfg_std = out_cfg.std(dim=1, keepdim=True)

                return scale_phi * (out_cfg * (out_std/out_cfg_std)) + (1-scale_phi) * out_cfg

            else:

                return out_cfg

        else:
            # No guidance: plain conditional forward.
            return super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs)
|
|
|
|
|
|
|
|
|
class UNetNCCA1d(UNet1d):

    """UNet1d with Noise Channel Conditioning Augmentation.

    Each conditioning-channel tensor is blended with Gaussian noise by a
    per-(batch, entry) scale, and the scales themselves are embedded and passed
    as global context `features`, so the network is told how corrupted its
    conditioning is.
    """

    def __init__(self, context_features: int, **kwargs):
        super().__init__(context_features=context_features, **kwargs)
        # Embeds the augmentation scales into the `features` conditioning vector.
        self.embedder = NumberEmbedder(features=context_features)

    def expand(self, x: Any, shape: Tuple[int, ...]) -> Tensor:
        """Broadcast a scalar / sequence / tensor `x` to the given shape."""
        x = x if torch.is_tensor(x) else torch.tensor(x)
        return x.expand(shape)

    def forward(
        self,
        x: Tensor,
        time: Tensor,
        *,
        channels_list: Sequence[Tensor],
        channels_augmentation: Union[
            bool, Sequence[bool], Sequence[Sequence[bool]], Tensor
        ] = False,
        channels_scale: Union[
            float, Sequence[float], Sequence[Sequence[float]], Tensor
        ] = 0,
        **kwargs,
    ) -> Tensor:
        b, n = x.shape[0], len(channels_list)
        channels_augmentation = self.expand(channels_augmentation, shape=(b, n)).to(x)
        channels_scale = self.expand(channels_scale, shape=(b, n)).to(x)

        # Fix: build a new list instead of assigning into the caller's
        # `channels_list` — the original mutated the argument in place, a
        # surprising side effect for callers reusing the list.
        augmented_list = []
        for i in range(n):
            # Scale is zero wherever augmentation is disabled for that entry.
            scale = channels_scale[:, i] * channels_augmentation[:, i]
            scale = rearrange(scale, "b -> b 1 1")
            item = channels_list[i]
            # Linear blend between clean channels and Gaussian noise.
            augmented_list.append(torch.randn_like(item) * scale + item * (1 - scale))

        # Sum the per-entry scale embeddings into one global feature vector.
        channels_scale_emb = self.embedder(channels_scale)
        channels_scale_emb = reduce(channels_scale_emb, "b n d -> b d", "sum")

        return super().forward(
            x=x,
            time=time,
            channels_list=augmented_list,
            features=channels_scale_emb,
            **kwargs,
        )
|
|
|
|
|
|
|
|
|
class UNetAll1d(UNetCFG1d, UNetNCCA1d):
    """Combines classifier-free guidance (CFG) and noise channel conditioning
    augmentation (NCCA).

    The redundant `__init__` that only forwarded `*args, **kwargs` to
    `super().__init__` has been removed — Python's MRO reaches the parent
    constructors identically without it.
    """

    def forward(self, *args, **kwargs):
        # Dispatch explicitly to the CFG forward; NCCA conditioning arrives
        # through kwargs and is handled along the MRO.
        return UNetCFG1d.forward(self, *args, **kwargs)
|
|
|
|
|
|
|
|
|
def XUNet1d(type: str = "base", **kwargs) -> UNet1d:
    """Factory for UNet1d variants: 'base', 'all', 'cfg' or 'ncca'.

    Raises ValueError for any other `type` string.
    """
    if type == "base":
        return UNet1d(**kwargs)
    if type == "all":
        return UNetAll1d(**kwargs)
    if type == "cfg":
        return UNetCFG1d(**kwargs)
    if type == "ncca":
        return UNetNCCA1d(**kwargs)
    raise ValueError(f"Unknown XUNet1d type: {type}")
|
|
|
|
|
|
class NumberEmbedder(nn.Module):
    """Embeds arbitrarily-shaped batches of scalars into `features`-dim vectors.

    Output shape is the input shape with a trailing `features` axis appended.
    """

    def __init__(
        self,
        features: int,
        dim: int = 256,
    ):
        super().__init__()
        self.features = features
        self.embedding = TimePositionalEmbedding(dim=dim, out_features=features)

    def forward(self, x: Union[List[float], Tensor]) -> Tensor:
        if not torch.is_tensor(x):
            # Place the converted tensor on the same device as the module.
            device = next(self.embedding.parameters()).device
            x = torch.tensor(x, device=device)
        assert isinstance(x, Tensor)
        original_shape = x.shape
        # Flatten, embed each scalar, then restore the leading shape.
        flat = rearrange(x, "... -> (...)")
        embedded = self.embedding(flat)
        return embedded.view(*original_shape, self.features)
|
|
|
|
|
|
|
|
|
"""
|
|
|
Audio Transforms
|
|
|
"""
|
|
|
|
|
|
|
|
|
class STFT(nn.Module):
    """Helper for torch stft and istft.

    `encode`/`decode` convert (b, c, t) waveforms to/from a pair of (b, c, f, l)
    tensors — either (real, imag) when `use_complex=True` or
    (magnitude, phase) otherwise. `encode1d`/`decode1d` additionally fold the
    frequency axis into the channel axis so the result is a 1-D feature map.
    """

    def __init__(
        self,
        num_fft: int = 1023,
        hop_length: int = 256,
        window_length: Optional[int] = None,
        length: Optional[int] = None,
        use_complex: bool = False,
    ):
        super().__init__()
        self.num_fft = num_fft
        # NOTE(review): hop_length has a non-None default, so the num_fft // 4
        # fallback is effectively dead; kept for interface compatibility.
        self.hop_length = default(hop_length, floor(num_fft // 4))
        self.window_length = default(window_length, num_fft)
        self.length = length
        self.register_buffer("window", torch.hann_window(self.window_length))
        self.use_complex = use_complex

    def encode(self, wave: Tensor) -> Tuple[Tensor, Tensor]:
        """STFT of a (b, c, t) waveform -> two (b, c, f, l) component tensors."""
        b = wave.shape[0]
        wave = rearrange(wave, "b c t -> (b c) t")

        stft = torch.stft(
            wave,
            n_fft=self.num_fft,
            hop_length=self.hop_length,
            win_length=self.window_length,
            window=self.window,
            return_complex=True,
            normalized=True,
        )

        if self.use_complex:
            # Real and imaginary parts
            stft_a, stft_b = stft.real, stft.imag
        else:
            # Magnitude and phase
            magnitude, phase = torch.abs(stft), torch.angle(stft)
            stft_a, stft_b = magnitude, phase

        return rearrange_many((stft_a, stft_b), "(b c) f l -> b c f l", b=b)

    def decode(self, stft_a: Tensor, stft_b: Tensor) -> Tensor:
        """Inverse STFT of two (b, c, f, l) component tensors -> (b, c, t) waveform."""
        b, l = stft_a.shape[0], stft_a.shape[-1]
        # Output length defaults to the closest power of two of l * hop_length.
        length = closest_power_2(l * self.hop_length)

        stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> (b c) f l")

        if self.use_complex:
            real, imag = stft_a, stft_b
        else:
            magnitude, phase = stft_a, stft_b
            real, imag = magnitude * torch.cos(phase), magnitude * torch.sin(phase)

        # Fix: torch.istft requires a complex-valued input — real input stacked
        # as [..., 2] was deprecated in torch 1.8 and later removed — so build
        # a complex tensor instead of torch.stack([real, imag], dim=-1).
        stft = torch.complex(real, imag)

        wave = torch.istft(
            stft,
            n_fft=self.num_fft,
            hop_length=self.hop_length,
            win_length=self.window_length,
            window=self.window,
            length=default(self.length, length),
            normalized=True,
        )

        return rearrange(wave, "(b c) t -> b c t", b=b)

    def encode1d(
        self, wave: Tensor, stacked: bool = True
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """Like `encode`, but folds frequency into channels: (b, c*f, l).
        If `stacked`, the two components are concatenated along channels."""
        stft_a, stft_b = self.encode(wave)
        stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> b (c f) l")
        return torch.cat((stft_a, stft_b), dim=1) if stacked else (stft_a, stft_b)

    def decode1d(self, stft_pair: Tensor) -> Tensor:
        """Inverse of stacked `encode1d`: split components, unfold f, istft."""
        f = self.num_fft // 2 + 1
        stft_a, stft_b = stft_pair.chunk(chunks=2, dim=1)
        stft_a, stft_b = rearrange_many((stft_a, stft_b), "b (c f) l -> b c f l", f=f)
        return self.decode(stft_a, stft_b)
|
|
|
|