import torch
from torch import nn
import torch.nn.functional as F

from einops.layers.torch import Rearrange
from ring_attention_pytorch import RingAttention

# helper functions

def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


def calc_same_padding(kernel_size):
    pad = kernel_size // 2
    return (pad, pad - (kernel_size + 1) % 2)
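
# Worked example: calc_same_padding(31) -> (15, 15), while an even kernel such as
# calc_same_padding(4) -> (2, 1); the left/right padding is asymmetric for even
# kernel sizes so the convolution output length matches the input length.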

# helper classes

class Swish(nn.Module):
    def forward(self, x):
        return x * x.sigmoid()


class GLU(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        out, gate = x.chunk(2, dim=self.dim)
        return out * gate.sigmoid()


class DepthWiseConv1d(nn.Module):
    def __init__(self, chan_in, chan_out, kernel_size, padding):
        super().__init__()
        self.padding = padding
        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups=chan_in)

    def forward(self, x):
        x = F.pad(x, self.padding)
        return self.conv(x)

# attention, feedforward, and conv modules

class Scale(nn.Module):
    def __init__(self, scale, fn):
        super().__init__()
        self.fn = fn
        self.scale = scale

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, mult=4, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            Swish(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class ConformerConvModule(nn.Module):
    def __init__(
        self, dim, causal=False, expansion_factor=2, kernel_size=31, dropout=0.0
    ):
        super().__init__()

        inner_dim = dim * expansion_factor
        padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0)

        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            Rearrange("b n c -> b c n"),
            nn.Conv1d(dim, inner_dim * 2, 1),
            GLU(dim=1),
            DepthWiseConv1d(
                inner_dim, inner_dim, kernel_size=kernel_size, padding=padding
            ),
            nn.BatchNorm1d(inner_dim) if not causal else nn.Identity(),
            Swish(),
            nn.Conv1d(inner_dim, dim, 1),
            Rearrange("b c n -> b n c"),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
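
# Shape sketch (illustrative, not from the original source): the conv module preserves
# sequence length, e.g. ConformerConvModule(dim=256)(torch.randn(2, 100, 256)) returns
# a tensor of shape (2, 100, 256).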

# Conformer block

class ConformerBlock(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head=64,
        heads=8,
        ff_mult=4,
        conv_expansion_factor=2,
        conv_kernel_size=31,
        attn_dropout=0.0,
        ff_dropout=0.0,
        conv_dropout=0.0,
        conv_causal=False,
    ):
        super().__init__()
        self.ff1 = FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
        # ring attention is fixed causal here, with sequence auto-sharding disabled
        self.attn = RingAttention(
            dim=dim,
            dim_head=dim_head,
            heads=heads,
            causal=True,
            auto_shard_seq=False,
            ring_attn=True,
            ring_seq_size=512,
        )
        self.self_attn_dropout = nn.Dropout(attn_dropout)
        self.conv = ConformerConvModule(
            dim=dim,
            causal=conv_causal,
            expansion_factor=conv_expansion_factor,
            kernel_size=conv_kernel_size,
            dropout=conv_dropout,
        )
        self.ff2 = FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)

        self.attn = PreNorm(dim, self.attn)
        self.ff1 = Scale(0.5, PreNorm(dim, self.ff1))
        self.ff2 = Scale(0.5, PreNorm(dim, self.ff2))

        self.post_norm = nn.LayerNorm(dim)

    def forward(self, x, mask=None):
        # standard Conformer ordering: feedforward -> self-attention -> conv -> feedforward,
        # each with a residual connection, followed by a final layer norm
        x = self.ff1(x) + x
        x = self.self_attn_dropout(self.attn(x, mask=mask)) + x
        x = self.conv(x) + x
        x = self.ff2(x) + x
        return self.post_norm(x)
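
# Usage sketch (illustrative): a single block maps (batch, seq, dim) -> (batch, seq, dim),
# e.g. ConformerBlock(dim=256, heads=4)(torch.randn(1, 1024, 256)) has shape
# (1, 1024, 256), assuming RingAttention runs unsharded on a single, non-distributed device.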

# Conformer

class Conformer(nn.Module):
    def __init__(
        self,
        dim,
        *,
        depth,
        dim_head=64,
        heads=8,
        ff_mult=4,
        conv_expansion_factor=2,
        conv_kernel_size=31,
        attn_dropout=0.0,
        ff_dropout=0.0,
        conv_dropout=0.0,
        conv_causal=False,
    ):
        super().__init__()
        self.dim = dim

        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(
                ConformerBlock(
                    dim=dim,
                    dim_head=dim_head,
                    heads=heads,
                    ff_mult=ff_mult,
                    conv_expansion_factor=conv_expansion_factor,
                    conv_kernel_size=conv_kernel_size,
                    attn_dropout=attn_dropout,
                    ff_dropout=ff_dropout,
                    conv_dropout=conv_dropout,
                    conv_causal=conv_causal,
                )
            )

    def forward(self, x):
        for block in self.layers:
            x = block(x)

        return x
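

# Minimal smoke test (an illustrative sketch, not part of the original module). Sizes are
# arbitrary; it assumes RingAttention from ring_attention_pytorch falls back to ordinary
# single-device attention when torch.distributed has not been initialized.
if __name__ == "__main__":
    model = Conformer(
        dim=256,
        depth=2,
        dim_head=64,
        heads=4,
    )

    x = torch.randn(1, 1024, 256)  # (batch, sequence length, feature dimension)
    out = model(x)
    print(out.shape)  # expected: torch.Size([1, 1024, 256])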