| """ |
| Pytorch Unet Module used for diffusion. |
| """ |
|
|
| from dataclasses import dataclass |
| import typing as tp |
|
|
| import torch |
| from torch import nn |
| from torch.nn import functional as F |
| from audiocraft.modules.transformer import StreamingTransformer, create_sin_embedding |
|


@dataclass
class Output:
    """Output of the diffusion model's forward pass."""
    sample: torch.Tensor


def get_model(cfg, channels: int, side: int, num_steps: int):
    """Instantiate the model requested by `cfg.model`; only 'unet' is supported."""
    if cfg.model == 'unet':
        return DiffusionUnet(
            chin=channels, num_steps=num_steps, **cfg.diffusion_unet)
    else:
        raise RuntimeError('Not Implemented')
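
# Usage sketch (illustrative, not part of the original module): `cfg` is expected
# to expose a `model` name and a `diffusion_unet` mapping of keyword arguments.
# The `SimpleNamespace` stand-in below is a hypothetical example:
#
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(model='unet', diffusion_unet={'hidden': 24, 'depth': 3})
#   model = get_model(cfg, channels=8, side=1, num_steps=1000)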


class ResBlock(nn.Module):
    """Length-preserving residual block: two dilated Conv1d layers, each
    preceded by GroupNorm and an activation, with the input added back at
    the end."""
    def __init__(self, channels: int, kernel: int = 3, norm_groups: int = 4,
                 dilation: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
                 dropout: float = 0.):
        super().__init__()
        stride = 1
        padding = dilation * (kernel - stride) // 2
        Conv = nn.Conv1d
        Drop = nn.Dropout1d
        self.norm1 = nn.GroupNorm(norm_groups, channels)
        self.conv1 = Conv(channels, channels, kernel, 1, padding, dilation=dilation)
        self.activation1 = activation()
        self.dropout1 = Drop(dropout)

        self.norm2 = nn.GroupNorm(norm_groups, channels)
        self.conv2 = Conv(channels, channels, kernel, 1, padding, dilation=dilation)
        self.activation2 = activation()
        self.dropout2 = Drop(dropout)

    def forward(self, x):
        h = self.dropout1(self.conv1(self.activation1(self.norm1(x))))
        h = self.dropout2(self.conv2(self.activation2(self.norm2(h))))
        return x + h
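
# Shape check (a worked example from the defaults above): with kernel=3, stride=1
# and dilation=2, the convolution spans dilation * (kernel - 1) + 1 = 5 input
# samples, and padding = dilation * (kernel - stride) // 2 = 2 on each side keeps
# the output length equal to the input length, so `x + h` is always shape-compatible.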


class DecoderLayer(nn.Module):
    """Upsampling decoder layer: a stack of dilated residual blocks, then
    GroupNorm, an activation, and a transposed convolution that upsamples
    the temporal length by `stride`."""
    def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2,
                 norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
                 dropout: float = 0.):
        super().__init__()
        padding = (kernel - stride) // 2
        self.res_blocks = nn.Sequential(
            *[ResBlock(chin, norm_groups=norm_groups, dilation=2**idx, dropout=dropout)
              for idx in range(res_blocks)])
        self.norm = nn.GroupNorm(norm_groups, chin)
        ConvTr = nn.ConvTranspose1d
        self.convtr = ConvTr(chin, chout, kernel, stride, padding, bias=False)
        self.activation = activation()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.res_blocks(x)
        x = self.norm(x)
        x = self.activation(x)
        x = self.convtr(x)
        return x
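
# Shape check (a worked example with the defaults kernel=4, stride=2, hence
# padding=1): a length-L input yields (L - 1) * 2 - 2 * 1 + 4 = 2 * L output
# steps, exactly undoing the matching EncoderLayer's downsampling (up to the
# encoder's right padding, which DiffusionUnet.forward trims off).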


class EncoderLayer(nn.Module):
    """Downsampling encoder layer: a strided convolution, then GroupNorm, an
    activation, and a stack of dilated residual blocks."""
    def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2,
                 norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
                 dropout: float = 0.):
        super().__init__()
        padding = (kernel - stride) // 2
        Conv = nn.Conv1d
        self.conv = Conv(chin, chout, kernel, stride, padding, bias=False)
        self.norm = nn.GroupNorm(norm_groups, chout)
        self.activation = activation()
        self.res_blocks = nn.Sequential(
            *[ResBlock(chout, norm_groups=norm_groups, dilation=2**idx, dropout=dropout)
              for idx in range(res_blocks)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, T = x.shape
        stride, = self.conv.stride
        # Right-pad so the length is a multiple of the stride before downsampling.
        pad = (stride - (T % stride)) % stride
        x = F.pad(x, (0, pad))
        x = self.conv(x)
        x = self.norm(x)
        x = self.activation(x)
        x = self.res_blocks(x)
        return x
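
# Shape check (a worked example with the defaults kernel=4, stride=2, hence
# padding=1): an input of length T = 101 is right-padded by (2 - 101 % 2) % 2 = 1
# sample to length 102, and the strided convolution then outputs
# (102 + 2 * 1 - 4) // 2 + 1 = 51 steps, i.e. ceil(T / stride).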


class BLSTM(nn.Module):
    """Bidirectional LSTM with as many hidden units as input channels, with a
    linear layer projecting the two directions back to the input dimension."""
    def __init__(self, dim, layers=2):
        super().__init__()
        self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
        self.linear = nn.Linear(2 * dim, dim)

    def forward(self, x):
        x = x.permute(2, 0, 1)
        x = self.lstm(x)[0]
        x = self.linear(x)
        x = x.permute(1, 2, 0)
        return x
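
# Note on shapes: `nn.LSTM` is sequence-first by default, hence the permutes in
# `BLSTM.forward`: (B, C, T) -> (T, B, C) into the LSTM, then the linear layer
# merges the two directions (2 * dim -> dim) before permuting back to (B, C, T),
# so the module preserves both the channel count and the length.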


class DiffusionUnet(nn.Module):
    """Symmetric 1D convolutional U-Net used as the diffusion denoiser. A learned
    embedding of the diffusion step is added after the first encoder layer (and,
    with `emb_all_layers`, after every encoder layer), and the bottleneck can be
    a BiLSTM or a StreamingTransformer, optionally conditioned on codec frames."""
    def __init__(self, chin: int = 3, hidden: int = 24, depth: int = 3, growth: float = 2.,
                 max_channels: int = 10_000, num_steps: int = 1000, emb_all_layers=False,
                 cross_attention: bool = False, bilstm: bool = False, transformer: bool = False,
                 codec_dim: tp.Optional[int] = None, **kwargs):
        super().__init__()
        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        self.embeddings: tp.Optional[nn.ModuleList] = None
        self.embedding = nn.Embedding(num_steps, hidden)
        if emb_all_layers:
            self.embeddings = nn.ModuleList()
        self.condition_embedding: tp.Optional[nn.Module] = None
        for d in range(depth):
            encoder = EncoderLayer(chin, hidden, **kwargs)
            decoder = DecoderLayer(hidden, chin, **kwargs)
            self.encoders.append(encoder)
            # Prepend so that decoders[i] mirrors encoders[-1 - i].
            self.decoders.insert(0, decoder)
            if emb_all_layers and d > 0:
                assert self.embeddings is not None
                self.embeddings.append(nn.Embedding(num_steps, hidden))
            chin = hidden
            hidden = min(int(chin * growth), max_channels)
        self.bilstm: tp.Optional[nn.Module]
        if bilstm:
            self.bilstm = BLSTM(chin)
        else:
            self.bilstm = None
        self.use_transformer = transformer
        self.cross_attention = False
        if transformer:
            self.cross_attention = cross_attention
            self.transformer = StreamingTransformer(chin, 8, 6, bias_ff=False, bias_attn=False,
                                                    cross_attention=cross_attention)

        self.use_codec = False
        if codec_dim is not None:
            self.conv_codec = nn.Conv1d(codec_dim, chin, 1)
            self.use_codec = True
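
    # Channel progression (a worked example with the defaults chin=3, hidden=24,
    # growth=2., depth=3): the encoders map 3 -> 24 -> 48 -> 96 channels, the
    # mirrored decoders map 96 -> 48 -> 24 -> 3, and the bottleneck (BLSTM or
    # transformer) operates on 96 channels.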

    def forward(self, x: torch.Tensor, step: tp.Union[int, torch.Tensor],
                condition: tp.Optional[torch.Tensor] = None):
        skips = []
        bs = x.size(0)
        z = x
        view_args = [1]
        if type(step) is torch.Tensor:
            step_tensor = step
        else:
            step_tensor = torch.tensor([step], device=x.device, dtype=torch.long).expand(bs)

        for idx, encoder in enumerate(self.encoders):
            z = encoder(z)
            if idx == 0:
                z = z + self.embedding(step_tensor).view(bs, -1, *view_args).expand_as(z)
            elif self.embeddings is not None:
                z = z + self.embeddings[idx - 1](step_tensor).view(bs, -1, *view_args).expand_as(z)
            skips.append(z)

        cross_attention_src = None  # otherwise undefined when the transformer runs without codec conditioning
        if self.use_codec:  # insert the condition in the bottleneck
            assert condition is not None, "Model defined for conditional generation"
            condition_emb = self.conv_codec(condition)  # project to the bottleneck dim
            assert condition_emb.size(-1) <= 2 * z.size(-1), \
                f"You are downsampling the conditioning with a factor >= 2: {condition_emb.size(-1)=} and {z.size(-1)=}"
            if not self.cross_attention:
                condition_emb = torch.nn.functional.interpolate(condition_emb, z.size(-1))
                assert z.size() == condition_emb.size()
                z += condition_emb
            else:
                cross_attention_src = condition_emb.permute(0, 2, 1)  # B, T, C
                B, T, C = cross_attention_src.shape
                positions = torch.arange(T, device=x.device).view(1, -1, 1)
                pos_emb = create_sin_embedding(positions, C, max_period=10_000, dtype=cross_attention_src.dtype)
                cross_attention_src = cross_attention_src + pos_emb
        if self.use_transformer:
            z = self.transformer(z.permute(0, 2, 1), cross_attention_src=cross_attention_src).permute(0, 2, 1)
        else:
            if self.bilstm is None:
                z = torch.zeros_like(z)
            else:
                z = self.bilstm(z)

        for decoder in self.decoders:
            s = skips.pop(-1)
            # Trim to the skip's length (encoders may have right-padded), then add it.
            z = z[:, :, :s.shape[2]]
            z = z + s
            z = decoder(z)

        z = z[:, :, :x.shape[2]]
        return Output(z)
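

if __name__ == '__main__':
    # Smoke test (a minimal sketch, not part of the original module): run one
    # denoising step on random data and check that shapes round-trip. All sizes
    # below are illustrative only.
    model = DiffusionUnet(chin=8, hidden=24, depth=3, num_steps=1000)
    waveform = torch.randn(2, 8, 237)  # (batch, channels, time); odd length on purpose
    out = model(waveform, step=10)
    assert out.sample.shape == waveform.shape, out.sample.shape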