import torch
from torch import nn
import torch.nn.functional as F

from einops.layers.torch import Rearrange
from ring_attention_pytorch import RingAttention

# helper functions


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


def calc_same_padding(kernel_size):
    # (left, right) padding that keeps the sequence length unchanged for a
    # stride-1 convolution, e.g. kernel_size = 31 -> (15, 15)
    pad = kernel_size // 2
    return (pad, pad - (kernel_size + 1) % 2)


# helper classes


class Swish(nn.Module):
    def forward(self, x):
        return x * x.sigmoid()


class GLU(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        out, gate = x.chunk(2, dim=self.dim)
        return out * gate.sigmoid()


class DepthWiseConv1d(nn.Module):
    def __init__(self, chan_in, chan_out, kernel_size, padding):
        super().__init__()
        self.padding = padding
        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups=chan_in)

    def forward(self, x):
        x = F.pad(x, self.padding)
        return self.conv(x)


# wrappers, feedforward, and conv module


class Scale(nn.Module):
    def __init__(self, scale, fn):
        super().__init__()
        self.fn = fn
        self.scale = scale

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)


class FeedForward(nn.Module):
    def __init__(self, dim, mult=4, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            Swish(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class ConformerConvModule(nn.Module):
    def __init__(
        self, dim, causal=False, expansion_factor=2, kernel_size=31, dropout=0.0
    ):
        super().__init__()

        inner_dim = dim * expansion_factor
        padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0)

        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            Rearrange("b n c -> b c n"),
            nn.Conv1d(dim, inner_dim * 2, 1),
            GLU(dim=1),
            DepthWiseConv1d(
                inner_dim, inner_dim, kernel_size=kernel_size, padding=padding
            ),
            nn.BatchNorm1d(inner_dim) if not causal else nn.Identity(),
            Swish(),
            nn.Conv1d(inner_dim, dim, 1),
            Rearrange("b c n -> b n c"),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
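

# A minimal shape-check sketch (not part of the module, sizes are arbitrary):
# the conv module is shape-preserving, so a (batch, seq, dim) input comes back
# with the same shape.
#
#   conv = ConformerConvModule(dim=512, kernel_size=31)
#   out = conv(torch.randn(2, 100, 512))   # -> torch.Size([2, 100, 512])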


# Conformer block


class ConformerBlock(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head=64,
        heads=8,
        ff_mult=4,
        conv_expansion_factor=2,
        conv_kernel_size=31,
        attn_dropout=0.0,
        ff_dropout=0.0,
        conv_dropout=0.0,
        conv_causal=False,
    ):
        super().__init__()
        self.ff1 = FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)

        # attention is a RingAttention module; causal attention and the ring
        # sequence size are hardcoded here, and attn_dropout is applied to the
        # attention output via self_attn_dropout rather than inside the module
        self.attn = RingAttention(
            dim=dim,
            dim_head=dim_head,
            heads=heads,
            causal=True,
            auto_shard_seq=False,
            ring_attn=True,
            ring_seq_size=512,
        )
        self.self_attn_dropout = nn.Dropout(attn_dropout)

        self.conv = ConformerConvModule(
            dim=dim,
            causal=conv_causal,
            expansion_factor=conv_expansion_factor,
            kernel_size=conv_kernel_size,
            dropout=conv_dropout,
        )
        self.ff2 = FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)

        # pre-norm every sub-module; the two feedforwards are halved (macaron-style)
        self.attn = PreNorm(dim, self.attn)
        self.ff1 = Scale(0.5, PreNorm(dim, self.ff1))
        self.ff2 = Scale(0.5, PreNorm(dim, self.ff2))

        self.post_norm = nn.LayerNorm(dim)

    def forward(self, x, mask=None):
        # macaron structure: half-step feedforward, attention, conv module,
        # half-step feedforward, final layernorm; every sub-block is residual
        x = self.ff1(x) + x
        x = self.self_attn_dropout(self.attn(x, mask=mask)) + x
        x = self.conv(x) + x
        x = self.ff2(x) + x
        x = self.post_norm(x)
        return x
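

# A quick usage sketch for a single block. The (batch, seq) boolean padding
# mask is an assumption about the RingAttention interface; sizes are arbitrary.
#
#   block = ConformerBlock(dim=512)
#   x = torch.randn(1, 512, 512)
#   mask = torch.ones(1, 512, dtype=torch.bool)
#   out = block(x, mask=mask)   # -> (1, 512, 512)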


# Conformer


class Conformer(nn.Module):
    def __init__(
        self,
        dim,
        *,
        depth,
        dim_head=64,
        heads=8,
        ff_mult=4,
        conv_expansion_factor=2,
        conv_kernel_size=31,
        attn_dropout=0.0,
        ff_dropout=0.0,
        conv_dropout=0.0,
        conv_causal=False,
    ):
        super().__init__()
        self.dim = dim
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(
                ConformerBlock(
                    dim=dim,
                    dim_head=dim_head,
                    heads=heads,
                    ff_mult=ff_mult,
                    conv_expansion_factor=conv_expansion_factor,
                    conv_kernel_size=conv_kernel_size,
                    attn_dropout=attn_dropout,
                    ff_dropout=ff_dropout,
                    conv_dropout=conv_dropout,
                    conv_causal=conv_causal,
                )
            )

    def forward(self, x):
        for block in self.layers:
            x = block(x)

        return x
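

# Minimal end-to-end sketch (assumptions: a single, non-distributed process, a
# dummy input, and arbitrary model sizes). This assumes ring_attention_pytorch
# handles the single-process case; launch across multiple devices to actually
# exercise ring attention. The sequence length matches the hardcoded
# ring_seq_size of 512.
if __name__ == "__main__":
    model = Conformer(
        dim=512,
        depth=2,
        dim_head=64,
        heads=8,
        conv_kernel_size=31,
    )

    x = torch.randn(1, 512, 512)  # (batch, seq, dim)
    out = model(x)
    print(out.shape)  # expected: torch.Size([1, 512, 512])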