# dfd-1b-v1 / conformer.py
# Commit 63dd064 (trisongz): "Duplicate from Speech-Arena-2025/DF_Arena_1B_V_1"
import math
import torch
from torch import nn, einsum
import torch.nn.functional as F
import torch
import torch.nn as nn
from torch.nn.modules.transformer import _get_clones
from torch import Tensor
from einops import rearrange
from einops.layers.torch import Rearrange
# helper functions
def exists(val):
    """Return True unless *val* is None (falsy values such as 0 or '' still exist)."""
    return not (val is None)
def default(val, d):
    """Return *val*, falling back to *d* only when *val* is None."""
    if not exists(val):
        return d
    return val
def calc_same_padding(kernel_size):
    """Return (left, right) padding that keeps sequence length for a stride-1 conv.

    Odd kernels pad symmetrically; even kernels put one less element on the
    right so the total padding is ``kernel_size - 1``.
    """
    left = kernel_size // 2
    right = left if kernel_size % 2 else left - 1
    return (left, right)
# helper classes
class Swish(nn.Module):
    """Element-wise Swish activation: ``x * sigmoid(x)``."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate
class GLU(nn.Module):
    """Gated linear unit over axis ``dim``.

    The input is split into two equal halves along ``dim``; the first half is
    multiplied by the sigmoid of the second.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        value, gate = torch.chunk(x, 2, dim=self.dim)
        return value * torch.sigmoid(gate)
class DepthWiseConv1d(nn.Module):
    """Grouped 1-D convolution with ``groups=chan_in`` and explicit padding.

    ``padding`` is handed to ``F.pad`` before the (unpadded) convolution, so it
    may be asymmetric, e.g. ``(k - 1, 0)`` for causal use.
    """

    def __init__(self, chan_in, chan_out, kernel_size, padding):
        super().__init__()
        self.padding = padding
        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups=chan_in)

    def forward(self, x):
        padded = F.pad(x, self.padding)
        return self.conv(padded)
# attention, feedforward, and conv module
class Scale(nn.Module):
    """Call ``fn`` (forwarding keyword arguments) and multiply its output by ``scale``."""

    def __init__(self, scale, fn):
        super().__init__()
        self.fn = fn
        self.scale = scale

    def forward(self, x, **kwargs):
        result = self.fn(x, **kwargs)
        return result * self.scale
class PreNorm(nn.Module):
    """Apply ``LayerNorm(dim)`` over the last axis, then call ``fn`` on the result.

    Extra keyword arguments are passed through to ``fn`` unchanged.
    """

    def __init__(self, dim, fn):
        super().__init__()
        # Registration order (fn, then norm) is kept stable for state_dict layout.
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)
class Attention(nn.Module):
    """Multi-head self-attention with extra per-head "head tokens".

    For every head, the slice of features belonging to that head is averaged
    over the sequence. It is then projected (``ht_proj``), layer-normed per
    ``dim_head`` chunk, GELU-activated and offset by a learned per-head
    embedding (``pos_embed``). This yields ``heads`` extra tokens, which are
    appended to the sequence before plain scaled-dot-product multi-head
    self-attention. Afterwards the head tokens are removed again, and their
    mean plus the mean of tokens ``1..N-1`` is added to token 0, which is
    treated as a class token.

    Shape requirements implied by the reshapes in ``forward``: the input is
    ``(B, N, C)`` and C is split into ``num_heads`` chunks of ``C // num_heads``,
    while the layers are sized from ``dim``, ``dim_head`` and
    ``inner_dim = heads * dim_head``. The code therefore only works when
    ``C == dim == heads * dim_head``.

    Args:
        dim: input/output feature size.
        heads: number of attention heads; also the number of head tokens.
        dim_head: per-head width; the logits are scaled by ``dim_head ** -0.5``.
        qkv_bias: whether the fused q/k/v projection has a bias.
        dropout: stored as ``self.attn_drop``, but its application in
            ``forward`` is commented out, so it currently has no effect.
        proj_drop: dropout applied to the output after the head-token merge.
    """
    # Head Token attention: https://arxiv.org/pdf/2210.05958.pdf
    def __init__(self, dim, heads=8, dim_head=64, qkv_bias=False, dropout=0., proj_drop=0.):
        super().__init__()
        self.num_heads = heads
        inner_dim = dim_head * heads
        self.scale = dim_head ** -0.5
        self.qkv = nn.Linear(dim, inner_dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(dropout)
        self.proj = nn.Linear(inner_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.act = nn.GELU()
        # Head-token path: project each head's pooled dim_head vector up to dim.
        self.ht_proj = nn.Linear(dim_head, dim,bias=True)
        self.ht_norm = nn.LayerNorm(dim_head)
        # Learned offset for the heads head tokens, zero-initialised.
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_heads, dim))
    def forward(self, x, mask=None):
        """Attend over ``x`` plus the generated head tokens.

        Args:
            x: tensor of shape (B, N, C). Position 0 is treated as a class
                token. N >= 2 is presumably expected, because ``patch`` below
                is empty when N == 1 — confirm.
            mask: accepted but not used by this implementation.

        Returns:
            (out, attn): ``out`` has shape (B, N, dim). ``attn`` holds the
            softmax attention probabilities, shape (B, heads, N + heads,
            N + heads), and includes the head-token positions.
        """
        B, N, C = x.shape
        # head token
        head_pos = self.pos_embed.expand(x.shape[0], -1, -1)  # (B, heads, dim)
        # Split features per head: (B, heads, N, C // heads).
        x_ = x.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        x_ = x_.mean(dim=2) # mean over the sequence axis (no keepdim): shape is now [B, h, C//h]
        # ht_proj gives (B, h, dim); with dim == C this reshapes to (B, h, h, C//h).
        x_ = self.ht_proj(x_).reshape(B, -1, self.num_heads, C // self.num_heads)
        # LayerNorm over each C//h chunk, GELU, then merge back to (B, h, C).
        x_ = self.act(self.ht_norm(x_)).flatten(2)
        x_ = x_ + head_pos
        # Append the h head tokens after the N input tokens: (B, N + h, C).
        x = torch.cat([x, x_], dim=1)
        # normal mhsa
        # qkv: (3, B, heads, N + heads, C // heads)
        qkv = self.qkv(x).reshape(B, N+self.num_heads, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        # NOTE(review): attention dropout is disabled here, so the `dropout`
        # constructor argument has no effect on the attention weights.
        # attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N+self.num_heads, C)
        x = self.proj(x)
        # merge head tokens into cls token
        # cls: (B, 1, dim); patch: (B, N-1, dim); ht: (B, heads, dim)
        cls, patch, ht = torch.split(x, [1, N-1, self.num_heads], dim=1)
        cls = cls + torch.mean(ht, dim=1, keepdim=True) + torch.mean(patch, dim=1, keepdim=True)
        # Head tokens are dropped; output length is back to N.
        x = torch.cat([cls, patch], dim=1)
        x = self.proj_drop(x)
        return x, attn
class FeedForward(nn.Module):
    """Position-wise MLP: Linear(dim -> dim*mult) -> Swish -> Dropout -> Linear(-> dim) -> Dropout.

    Submodule indices inside ``self.net`` (0..4) are fixed, so state_dict keys
    such as ``net.0.weight`` and ``net.3.weight`` stay the same.
    """

    def __init__(self, dim, mult=4, dropout=0.):
        super().__init__()
        hidden = dim * mult
        self.net = nn.Sequential(
            nn.Linear(dim, hidden), Swish(), nn.Dropout(dropout),
            nn.Linear(hidden, dim), nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
class ConformerConvModule(nn.Module):
    """Conformer convolution module for (batch, seq, dim) inputs.

    Pipeline: LayerNorm -> pointwise Conv1d (dim -> 2*inner) -> GLU over
    channels -> depthwise Conv1d -> BatchNorm1d (Identity when causal) ->
    Swish -> pointwise Conv1d (inner -> dim) -> Dropout, where
    ``inner = dim * expansion_factor``.
    """

    def __init__(self, dim, causal=False, expansion_factor=2, kernel_size=31, dropout=0.):
        super().__init__()
        hidden = dim * expansion_factor
        if causal:
            # Pad only on the left so no output step sees later steps.
            pad = (kernel_size - 1, 0)
        else:
            pad = calc_same_padding(kernel_size)
        # Order (and therefore Sequential indices / init order) is significant.
        layers = [
            nn.LayerNorm(dim),
            Rearrange('b n c -> b c n'),
            nn.Conv1d(dim, hidden * 2, 1),
            GLU(dim=1),
            DepthWiseConv1d(hidden, hidden, kernel_size=kernel_size, padding=pad),
            nn.Identity() if causal else nn.BatchNorm1d(hidden),
            Swish(),
            nn.Conv1d(hidden, dim, 1),
            Rearrange('b c n -> b n c'),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out
# Conformer Block
class ConformerBlock(nn.Module):
    """One macaron-style Conformer block.

    x -> x + 0.5 * FF1(LN(x)) -> x + Attn(LN(x)) -> x + Conv(x)
      -> x + 0.5 * FF2(LN(x)) -> LayerNorm.

    ``forward`` returns ``(output, attention_weights)``; the weights come
    straight from the ``Attention`` module.
    """

    def __init__(
        self,
        *,
        dim,
        dim_head=64,
        heads=8,
        ff_mult=4,
        conv_expansion_factor=2,
        conv_kernel_size=31,
        attn_dropout=0.,
        ff_dropout=0.,
        conv_dropout=0.,
        conv_causal=False,
    ):
        super().__init__()
        # Construction order matters: it fixes the RNG sequence used for
        # weight initialisation and the state_dict layout (ff1.fn.fn.*,
        # attn.fn.*, conv.net.*, ff2.fn.fn.*, post_norm.*).
        self.ff1 = Scale(0.5, PreNorm(dim, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)))
        self.attn = PreNorm(dim, Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout))
        self.conv = ConformerConvModule(
            dim=dim,
            causal=conv_causal,
            expansion_factor=conv_expansion_factor,
            kernel_size=conv_kernel_size,
            dropout=conv_dropout,
        )
        self.ff2 = Scale(0.5, PreNorm(dim, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)))
        self.post_norm = nn.LayerNorm(dim)

    def forward(self, x, mask=None):
        h = x + self.ff1(x)
        attended, attn_weight = self.attn(h, mask=mask)
        h = h + attended
        h = h + self.conv(h)
        h = h + self.ff2(h)
        return self.post_norm(h), attn_weight
# Conformer
class Conformer(nn.Module):
    """A stack of ``depth`` ConformerBlocks that maps (B, N, dim) -> (B, N, dim).

    Args:
        dim: feature size.
        depth: number of blocks.
        dim_head, heads, ff_mult, conv_expansion_factor, conv_kernel_size,
        attn_dropout, ff_dropout, conv_dropout, conv_causal: passed to every
            ConformerBlock.
    """

    def __init__(
        self,
        dim,
        *,
        depth,
        dim_head=64,
        heads=8,
        ff_mult=4,
        conv_expansion_factor=2,
        conv_kernel_size=31,
        attn_dropout=0.,
        ff_dropout=0.,
        conv_dropout=0.,
        conv_causal=False,
    ):
        super().__init__()
        self.dim = dim
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            # Forward the dropout settings too; they used to be silently dropped.
            self.layers.append(ConformerBlock(
                dim=dim,
                dim_head=dim_head,
                heads=heads,
                ff_mult=ff_mult,
                conv_expansion_factor=conv_expansion_factor,
                conv_kernel_size=conv_kernel_size,
                attn_dropout=attn_dropout,
                ff_dropout=ff_dropout,
                conv_dropout=conv_dropout,
                conv_causal=conv_causal,
            ))

    def forward(self, x):
        """Run every block in sequence and return the final hidden states."""
        for block in self.layers:
            # ConformerBlock returns (output, attention_weights); only the
            # output feeds the next block.
            x, _ = block(x)
        return x
def sinusoidal_embedding(n_channels, dim):
    """Build a fixed sinusoidal position table of shape (1, n_channels, dim).

    ``n_channels`` is the number of positions (rows) and ``dim`` the embedding
    width. Entry ``[0, p, i]`` is ``sin(p / 10000 ** (2 * (i // 2) / dim))`` for
    even ``i`` and the matching ``cos`` for odd ``i``.

    Angles are computed in float64 and rounded to float32 before sin/cos
    (matching a Python-float computation followed by ``torch.FloatTensor``).
    The result is float32. The divisor depends only on the column, so it is
    computed once per column in a vectorised op instead of once per
    (position, column) pair in a Python loop.
    """
    positions = torch.arange(n_channels, dtype=torch.float64).unsqueeze(1)  # (n_channels, 1)
    pair_index = torch.div(torch.arange(dim), 2, rounding_mode='floor')     # i // 2
    exponent = (2 * pair_index).to(torch.float64) / dim                     # (dim,)
    pe = (positions / torch.pow(10000.0, exponent)).to(torch.float32)       # (n_channels, dim)
    pe[:, 0::2] = torch.sin(pe[:, 0::2])
    pe[:, 1::2] = torch.cos(pe[:, 1::2])
    return pe.unsqueeze(0)
class FinalConformer(nn.Module):
    """Conformer encoder with a learnable class token and a 2-way linear head.

    Args:
        emb_size: feature size of each input frame and of every encoder block.
        heads: number of attention heads. ``emb_size`` should be divisible by
            it, because the attention reshapes need ``heads * (emb_size // heads) == emb_size``.
        ffmult: feed-forward expansion multiplier.
        exp_fac: channel expansion factor of the convolution module.
        kernel_size: depthwise convolution kernel size.
        n_encoders: number of stacked ConformerBlocks. They are deep copies of
            one freshly initialised block, so all start with identical weights.
    """

    def __init__(self, emb_size=128, heads=4, ffmult=4, exp_fac=2, kernel_size=16, n_encoders=1):
        super().__init__()
        self.dim_head = emb_size // heads
        self.dim = emb_size
        self.heads = heads
        self.kernel_size = kernel_size
        self.n_encoders = n_encoders
        # Fixed (non-trainable) sinusoidal table covering up to 10000 time steps.
        self.positional_emb = nn.Parameter(sinusoidal_embedding(10000, emb_size), requires_grad=False)
        self.encoder_blocks = _get_clones(
            ConformerBlock(dim=emb_size, dim_head=self.dim_head, heads=heads,
                           ff_mult=ffmult, conv_expansion_factor=exp_fac, conv_kernel_size=kernel_size),
            n_encoders)
        self.class_token = nn.Parameter(torch.rand(1, emb_size))
        self.fc5 = nn.Linear(emb_size, 2)

    def forward(self, x):
        """Encode a batch of frame sequences and classify the class-token output.

        Args:
            x: tensor of shape [batch, time, features]. ``features`` is expected
                to equal ``emb_size`` (it is added to the positional table and
                concatenated with the class token), and ``time`` must not exceed
                10000 (the size of the positional table).

        Returns:
            (out, list_attn_weight): ``out`` holds the logits, shape [batch, 2].
            ``list_attn_weight`` has one attention-weight tensor per encoder block.
        """
        x = x + self.positional_emb[:, :x.size(1), :]
        # Prepend the shared class token to every sequence in one op (no
        # per-sample Python loop): [batch, 1 + time, emb_size].
        cls = self.class_token.unsqueeze(0).expand(x.size(0), -1, -1)
        x = torch.cat((cls, x), dim=1)
        list_attn_weight = []
        for layer in self.encoder_blocks:
            x, attn_weight = layer(x)  # [batch, 1 + time, emb_size]
            list_attn_weight.append(attn_weight)
        embedding = x[:, 0, :]  # class-token output: [batch, emb_size]
        out = self.fc5(embedding)  # [batch, 2]
        return out, list_attn_weight