from typing import Callable, Optional, Sequence

import torch
import torch.nn.functional as F
from a_unet import (
    ClassifierFreeGuidancePlugin,
    Conv,
    Module,
    TextConditioningPlugin,
    TimeConditioningPlugin,
    default,
    exists,
)
from a_unet.apex import (
    AttentionItem,
    CrossAttentionItem,
    InjectChannelsItem,
    ModulationItem,
    ResnetItem,
    SkipCat,
    SkipModulate,
    XBlock,
    XUNet,
)
from einops import pack, unpack
from torch import Tensor, nn
from torchaudio import transforms

"""
UNets (built with a-unet: https://github.com/archinetai/a-unet)
"""

def UNetV0(
    dim: int,
    in_channels: int,
    channels: Sequence[int],
    factors: Sequence[int],
    items: Sequence[int],
    attentions: Optional[Sequence[int]] = None,
    cross_attentions: Optional[Sequence[int]] = None,
    context_channels: Optional[Sequence[int]] = None,
    attention_features: Optional[int] = None,
    attention_heads: Optional[int] = None,
    embedding_features: Optional[int] = None,
    resnet_groups: int = 8,
    use_modulation: bool = True,
    modulation_features: int = 1024,
    embedding_max_length: Optional[int] = None,
    use_time_conditioning: bool = True,
    use_embedding_cfg: bool = False,
    use_text_conditioning: bool = False,
    out_channels: Optional[int] = None,
):
    # Set per-layer defaults and check that every sequence has one entry per layer
    num_layers = len(channels)
    attentions = default(attentions, [0] * num_layers)
    cross_attentions = default(cross_attentions, [0] * num_layers)
    context_channels = default(context_channels, [0] * num_layers)
    xs = (channels, factors, items, attentions, cross_attentions, context_channels)
    assert all(len(x) == num_layers for x in xs)

    # Select the base UNet type, then wrap it with the requested plugins
    UNetV0 = XUNet

    if use_embedding_cfg:
        msg = "use_embedding_cfg requires embedding_max_length"
        assert exists(embedding_max_length), msg
        UNetV0 = ClassifierFreeGuidancePlugin(UNetV0, embedding_max_length)

    if use_text_conditioning:
        UNetV0 = TextConditioningPlugin(UNetV0)

    if use_time_conditioning:
        assert use_modulation, "use_time_conditioning requires use_modulation=True"
        UNetV0 = TimeConditioningPlugin(UNetV0)

    # Build the network: one XBlock per layer, each repeating its item stack
    # `items` times; the int/bool multipliers include or repeat optional items
    return UNetV0(
        dim=dim,
        in_channels=in_channels,
        out_channels=out_channels,
        blocks=[
            XBlock(
                channels=channels,
                factor=factor,
                context_channels=ctx_channels,
                items=(
                    [ResnetItem]
                    + [ModulationItem] * use_modulation
                    + [InjectChannelsItem] * (ctx_channels > 0)
                    + [AttentionItem] * att
                    + [CrossAttentionItem] * cross
                )
                * items,
            )
            for channels, factor, items, att, cross, ctx_channels in zip(*xs)
        ],
        skip_t=SkipModulate if use_modulation else SkipCat,
        attention_features=attention_features,
        attention_heads=attention_heads,
        embedding_features=embedding_features,
        modulation_features=modulation_features,
        resnet_groups=resnet_groups,
    )
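
# Example usage (an illustrative sketch, not part of the original file): a small
# unconditional 1D UNet. The argument values below are assumptions chosen for
# demonstration. With the default flags, TimeConditioningPlugin makes the
# returned module accept a `time` keyword:
#
#     unet = UNetV0(
#         dim=1,
#         in_channels=2,
#         channels=[8, 32, 64],
#         factors=[2, 2, 2],
#         items=[1, 2, 2],
#         attentions=[0, 0, 1],
#         attention_features=64,
#         attention_heads=8,
#     )
#     x = torch.randn(1, 2, 1024)
#     y = unet(x, time=torch.tensor([0.5]))  # y has the same shape as x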


"""
Plugins
"""

def LTPlugin(
    net_t: Callable, num_filters: int, window_length: int, stride: int
) -> Callable[..., nn.Module]:
    """Learned Transform Plugin: wraps a network with a learned strided-window
    encoder/decoder so it operates on `num_filters` learned bands per channel."""

    def Net(
        dim: int, in_channels: int, out_channels: Optional[int] = None, **kwargs
    ) -> nn.Module:
        out_channels = default(out_channels, in_channels)
        in_channel_transform = in_channels * num_filters
        out_channel_transform = out_channels * num_filters

        # Pad so the strided conv divides the input length exactly by `stride`
        padding = window_length // 2 - stride // 2
        encode = Conv(
            dim=dim,
            in_channels=in_channels,
            out_channels=in_channel_transform,
            kernel_size=window_length,
            stride=stride,
            padding=padding,
            padding_mode="reflect",
            bias=False,
        )
        # Transposed conv inverts the learned transform (note: 1D only)
        decode = nn.ConvTranspose1d(
            in_channels=out_channel_transform,
            out_channels=out_channels,
            kernel_size=window_length,
            stride=stride,
            padding=padding,
            bias=False,
        )
        net = net_t(
            dim=dim,
            in_channels=in_channel_transform,
            out_channels=out_channel_transform,
            **kwargs,
        )

        def forward(x: Tensor, *args, **kwargs):
            x = encode(x)
            x = net(x, *args, **kwargs)
            x = decode(x)
            return x

        return Module([encode, decode, net], forward)

    return Net
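
# Example usage (an illustrative sketch; all values are assumptions): have the
# UNet operate on a learned transform of the waveform instead of raw samples.
# Extra keywords such as `time` pass through to the wrapped network:
#
#     UNetT = LTPlugin(UNetV0, num_filters=32, window_length=64, stride=32)
#     net = UNetT(dim=1, in_channels=2, channels=[32, 64], factors=[2, 2], items=[1, 1])
#     y = net(torch.randn(1, 2, 1024), time=torch.tensor([0.5]))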


def AppendChannelsPlugin(
    net_t: Callable,
    channels: int,
):
    def Net(
        in_channels: int, out_channels: Optional[int] = None, **kwargs
    ) -> nn.Module:
        out_channels = default(out_channels, in_channels)
        # The inner network sees the input plus the appended conditioning channels
        net = net_t(
            in_channels=in_channels + channels, out_channels=out_channels, **kwargs
        )

        def forward(x: Tensor, *args, append_channels: Tensor, **kwargs):
            x = torch.cat([x, append_channels], dim=1)
            return net(x, *args, **kwargs)

        return Module([net], forward)

    return Net
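
# Example usage (an illustrative sketch; values are assumptions): condition the
# UNet on a signal concatenated along the channel dimension, e.g. for
# upsampling or inpainting:
#
#     UNetT = AppendChannelsPlugin(UNetV0, channels=2)
#     net = UNetT(dim=1, in_channels=2, channels=[8, 32], factors=[2, 2], items=[1, 1])
#     x = torch.randn(1, 2, 256)
#     cond = torch.randn(1, 2, 256)  # conditioning signal, same length as x
#     y = net(x, time=torch.tensor([0.5]), append_channels=cond)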


"""
Other
"""

class MelSpectrogram(nn.Module):
    def __init__(
        self,
        n_fft: int,
        hop_length: int,
        win_length: int,
        sample_rate: int,
        n_mel_channels: int,
        center: bool = False,
        normalize: bool = False,
        normalize_log: bool = False,
    ):
        super().__init__()
        self.padding = (n_fft - hop_length) // 2
        self.normalize = normalize
        self.normalize_log = normalize_log
        self.hop_length = hop_length

        self.to_spectrogram = transforms.Spectrogram(
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            center=center,
            power=None,
        )

        self.to_mel_scale = transforms.MelScale(
            n_mels=n_mel_channels, n_stft=n_fft // 2 + 1, sample_rate=sample_rate
        )

    def forward(self, waveform: Tensor) -> Tensor:
        # Pack all leading dimensions into one batch dimension
        waveform, ps = pack([waveform], "* t")
        # Pad manually, since center=False above
        waveform = F.pad(waveform, [self.padding] * 2, mode="reflect")
        # Complex STFT (power=None), then magnitude
        spectrogram = self.to_spectrogram(waveform)
        spectrogram = torch.abs(spectrogram)
        # Convert linear frequencies to the mel scale
        mel_spectrogram = self.to_mel_scale(spectrogram)
        # Optional fourth-root companding to [-1, 1], or log scaling
        if self.normalize:
            mel_spectrogram = mel_spectrogram / torch.max(mel_spectrogram)
            mel_spectrogram = 2 * torch.pow(mel_spectrogram, 0.25) - 1
        if self.normalize_log:
            mel_spectrogram = torch.log(torch.clamp(mel_spectrogram, min=1e-5))
        # Restore the packed leading dimensions
        return unpack(mel_spectrogram, ps, "* f l")[0]
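
# Example usage (an illustrative sketch; parameter values are assumptions):
#
#     mel = MelSpectrogram(
#         n_fft=1024,
#         hop_length=256,
#         win_length=1024,
#         sample_rate=48000,
#         n_mel_channels=80,
#     )
#     spec = mel(torch.randn(1, 2, 2**16))  # -> [1, 2, 80, 256]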