import math

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F


class StyleAdaptiveLayerNorm(nn.Module):
    def __init__(self, normalized_shape, eps=1e-5):
        super().__init__()
        self.in_dim = normalized_shape
        self.norm = nn.LayerNorm(self.in_dim, eps=eps, elementwise_affine=False)
        # Predict a per-channel scale (gamma) and shift (beta) from the style condition.
        self.style = nn.Linear(self.in_dim, self.in_dim * 2)
        # Bias init: the gamma half starts at 1 and the beta half at 0.
        self.style.bias.data[: self.in_dim] = 1
        self.style.bias.data[self.in_dim :] = 0

    def forward(self, x, condition):
        # x: (B, T, d) content features; condition: (B, T_cond, d) style frames.
        # Mean-pool the condition over time, then predict gamma and beta.
        style = self.style(torch.mean(condition, dim=1, keepdim=True))
        gamma, beta = style.chunk(2, dim=-1)

        # Normalize without affine parameters, then apply the style-dependent modulation.
        out = self.norm(x)
        out = gamma * out + beta
        return out
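
# Usage sketch for StyleAdaptiveLayerNorm (illustrative only; the tensor shapes are
# assumed from the forward signature above, not taken from any training pipeline):
#
#   saln = StyleAdaptiveLayerNorm(256)
#   x = torch.randn(2, 50, 256)       # content features: (batch, time, hidden)
#   cond = torch.randn(2, 32, 256)    # style frames, mean-pooled over time inside forward
#   out = saln(x, cond)               # -> (2, 50, 256), scaled and shifted per channel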


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout, max_len=5000):
        super().__init__()

        # Dropout rate applied after the encodings are added.
        self.dropout = dropout
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        # Sinusoidal table of shape (max_len, 1, d_model).
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x):
        # x is batch-first: (B, T, d_model). Take the first T positions and
        # broadcast them over the batch dimension.
        x = x + self.pe[: x.size(1)].transpose(0, 1)
        return F.dropout(x, self.dropout, training=self.training)
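
# Usage sketch for PositionalEncoding (illustrative only; batch-first input assumed,
# matching the batch_first attention used below):
#
#   pos_emb = PositionalEncoding(d_model=256, dropout=0.1)
#   x = torch.randn(2, 50, 256)       # (batch, time, d_model)
#   x = pos_emb(x)                    # adds sinusoidal encodings for the first 50 positions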


class TransformerFFNLayer(nn.Module):
    def __init__(
        self, encoder_hidden, conv_filter_size, conv_kernel_size, encoder_dropout
    ):
        super().__init__()

        self.encoder_hidden = encoder_hidden
        self.conv_filter_size = conv_filter_size
        self.conv_kernel_size = conv_kernel_size
        self.encoder_dropout = encoder_dropout

        # Position-wise feed-forward: a 1D convolution over time followed by a
        # linear projection back to the hidden size.
        self.ffn_1 = nn.Conv1d(
            self.encoder_hidden,
            self.conv_filter_size,
            self.conv_kernel_size,
            padding=self.conv_kernel_size // 2,
        )
        self.ffn_1.weight.data.normal_(0.0, 0.02)
        self.ffn_2 = nn.Linear(self.conv_filter_size, self.encoder_hidden)
        self.ffn_2.weight.data.normal_(0.0, 0.02)

    def forward(self, x):
        # x: (B, T, d). Conv1d expects (B, d, T), so permute around the convolution.
        x = self.ffn_1(x.permute(0, 2, 1)).permute(0, 2, 1)
        x = F.relu(x)
        x = F.dropout(x, self.encoder_dropout, training=self.training)
        x = self.ffn_2(x)
        return x
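
# Usage sketch for TransformerFFNLayer (illustrative only): a drop-in position-wise
# feed-forward block that keeps the sequence length unchanged for odd kernel sizes.
#
#   ffn = TransformerFFNLayer(encoder_hidden=256, conv_filter_size=1024,
#                             conv_kernel_size=5, encoder_dropout=0.1)
#   y = ffn(torch.randn(2, 50, 256))  # -> (2, 50, 256)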


class TransformerEncoderLayer(nn.Module):
    def __init__(
        self,
        encoder_hidden,
        encoder_head,
        conv_filter_size,
        conv_kernel_size,
        encoder_dropout,
        use_cln,
    ):
        super().__init__()
        self.encoder_hidden = encoder_hidden
        self.encoder_head = encoder_head
        self.conv_filter_size = conv_filter_size
        self.conv_kernel_size = conv_kernel_size
        self.encoder_dropout = encoder_dropout
        self.use_cln = use_cln

        # Use style-adaptive (conditional) layer norm when a style condition is available.
        if not self.use_cln:
            self.ln_1 = nn.LayerNorm(self.encoder_hidden)
            self.ln_2 = nn.LayerNorm(self.encoder_hidden)
        else:
            self.ln_1 = StyleAdaptiveLayerNorm(self.encoder_hidden)
            self.ln_2 = StyleAdaptiveLayerNorm(self.encoder_hidden)

        self.self_attn = nn.MultiheadAttention(
            self.encoder_hidden, self.encoder_head, batch_first=True
        )

        self.ffn = TransformerFFNLayer(
            self.encoder_hidden,
            self.conv_filter_size,
            self.conv_kernel_size,
            self.encoder_dropout,
        )

    def forward(self, x, key_padding_mask, condition=None):
        # Pre-norm self-attention block with a residual connection.
        residual = x
        if self.use_cln:
            x = self.ln_1(x, condition)
        else:
            x = self.ln_1(x)

        # key_padding_mask uses 1/True for valid frames; nn.MultiheadAttention expects
        # True at padded positions, so invert it here.
        if key_padding_mask is not None:
            key_padding_mask_input = ~(key_padding_mask.bool())
        else:
            key_padding_mask_input = None
        x, _ = self.self_attn(
            query=x, key=x, value=x, key_padding_mask=key_padding_mask_input
        )
        x = F.dropout(x, self.encoder_dropout, training=self.training)
        x = residual + x

        # Pre-norm convolutional feed-forward block with a residual connection.
        residual = x
        if self.use_cln:
            x = self.ln_2(x, condition)
        else:
            x = self.ln_2(x)
        x = self.ffn(x)
        x = residual + x

        return x
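
# Usage sketch for TransformerEncoderLayer (illustrative only). The mask convention
# follows the inversion in forward(): key_padding_mask is 1/True for valid frames
# and 0/False for padding.
#
#   layer = TransformerEncoderLayer(256, 4, 1024, 5, 0.1, use_cln=True)
#   x = torch.randn(2, 50, 256)
#   mask = torch.ones(2, 50, dtype=torch.bool)   # all frames valid
#   cond = torch.randn(2, 32, 256)
#   y = layer(x, mask, cond)                     # -> (2, 50, 256)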


class TransformerEncoder(nn.Module):
    def __init__(
        self,
        enc_emb_tokens=None,
        encoder_layer=4,
        encoder_hidden=256,
        encoder_head=4,
        conv_filter_size=1024,
        conv_kernel_size=5,
        encoder_dropout=0.1,
        use_cln=False,
        cfg=None,
    ):
        super().__init__()

        # Explicit arguments take precedence; fall back to cfg when an argument is None.
        self.encoder_layer = (
            encoder_layer if encoder_layer is not None else cfg.encoder_layer
        )
        self.encoder_hidden = (
            encoder_hidden if encoder_hidden is not None else cfg.encoder_hidden
        )
        self.encoder_head = (
            encoder_head if encoder_head is not None else cfg.encoder_head
        )
        self.conv_filter_size = (
            conv_filter_size if conv_filter_size is not None else cfg.conv_filter_size
        )
        self.conv_kernel_size = (
            conv_kernel_size if conv_kernel_size is not None else cfg.conv_kernel_size
        )
        self.encoder_dropout = (
            encoder_dropout if encoder_dropout is not None else cfg.encoder_dropout
        )
        self.use_cln = use_cln if use_cln is not None else cfg.use_cln

        # Optional token embedding so the encoder can also consume integer token ids.
        if enc_emb_tokens is not None:
            self.use_enc_emb = True
            self.enc_emb_tokens = enc_emb_tokens
        else:
            self.use_enc_emb = False

        self.position_emb = PositionalEncoding(
            self.encoder_hidden, self.encoder_dropout
        )

        self.layers = nn.ModuleList(
            [
                TransformerEncoderLayer(
                    self.encoder_hidden,
                    self.encoder_head,
                    self.conv_filter_size,
                    self.conv_kernel_size,
                    self.encoder_dropout,
                    self.use_cln,
                )
                for _ in range(self.encoder_layer)
            ]
        )

        if self.use_cln:
            self.last_ln = StyleAdaptiveLayerNorm(self.encoder_hidden)
        else:
            self.last_ln = nn.LayerNorm(self.encoder_hidden)

    def forward(self, x, key_padding_mask, condition=None):
        # x: (B, T) token ids when an embedding table is provided, otherwise (B, T, d) features.
        if len(x.shape) == 2 and self.use_enc_emb:
            x = self.enc_emb_tokens(x)
        x = self.position_emb(x)

        for layer in self.layers:
            x = layer(x, key_padding_mask, condition)

        if self.use_cln:
            x = self.last_ln(x, condition)
        else:
            x = self.last_ln(x)

        return x
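

if __name__ == "__main__":
    # Minimal smoke test (illustrative only; the hyperparameters below are chosen
    # arbitrarily for the test and do not correspond to any particular config).
    encoder = TransformerEncoder(
        encoder_layer=2,
        encoder_hidden=256,
        encoder_head=4,
        conv_filter_size=1024,
        conv_kernel_size=5,
        encoder_dropout=0.1,
        use_cln=True,
    )
    x = torch.randn(2, 50, 256)                    # already-embedded input features
    mask = torch.ones(2, 50, dtype=torch.bool)     # 1/True marks valid frames
    condition = torch.randn(2, 32, 256)            # style condition for the CLN layers
    out = encoder(x, mask, condition)
    print(out.shape)                               # torch.Size([2, 50, 256])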