import torch
import torch.nn as nn
import math
from einops import rearrange
class LinearAttention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=32, q_norm=True):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv1d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv1d(hidden_dim, dim, 1)
        self.q_norm = q_norm

    def forward(self, x):
        # x: (batch, length, channels) -> (batch, channels, length) for Conv1d
        x = x.permute(0, 2, 1)
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b (qkv heads c) l -> qkv b heads c l',
                            heads=self.heads, qkv=3)
        # Normalizing k over the sequence axis (and optionally q over the feature
        # axis) lets the context be aggregated in time linear in the length.
        k = k.softmax(dim=-1)
        if self.q_norm:
            q = q.softmax(dim=-2)
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        out = torch.einsum('bhde,bhdn->bhen', context, q)
        out = rearrange(out, 'b heads c l -> b (heads c) l',
                        heads=self.heads)
        # back to (batch, length, channels)
        return self.to_out(out).permute(0, 2, 1)

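
# A minimal smoke test for LinearAttention (an illustrative sketch added here,
# not part of the original code; the function name and sizes are hypothetical):
# the module maps a (batch, length, dim) tensor to a tensor of the same shape.
def _linear_attention_example():
    attn = LinearAttention(dim=64, heads=8, dim_head=8)
    x = torch.rand(2, 16, 64)   # (batch, length, dim)
    out = attn(x)
    assert out.shape == x.shape
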
class TransformerBlock(nn.Module):
    def __init__(self, dim, n_heads=4, layer_norm_first=True):
        super().__init__()
        dim_head = dim // n_heads
        self.attention = LinearAttention(dim, heads=n_heads, dim_head=dim_head)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.feed_forward = nn.Sequential(nn.Linear(dim, dim * 2),
                                          nn.SiLU(),
                                          nn.Linear(dim * 2, dim))
        self.dropout1 = nn.Dropout(0.2)
        self.dropout2 = nn.Dropout(0.2)
        self.layer_norm_first = layer_norm_first

    def forward(self, x):
        if self.layer_norm_first:
            # Pre-norm: normalize before each sub-layer, then add the residual.
            x = x + self.dropout1(self.attention(self.norm1(x)))
            x = x + self.dropout2(self.feed_forward(self.norm2(x)))
        else:
            # Post-norm: add the residual first, then normalize and drop out
            # (the variant that was previously commented out).
            x = self.dropout1(self.norm1(self.attention(x) + x))
            x = self.dropout2(self.norm2(self.feed_forward(x) + x))
        return x

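
# A shape-preservation check for TransformerBlock (an illustrative sketch, not
# part of the original code; the function name and sizes are hypothetical):
# linear attention plus a SiLU feed-forward, both with residual connections,
# keep the (batch, length, dim) shape unchanged.
def _transformer_block_example():
    block = TransformerBlock(dim=64, n_heads=4)
    x = torch.rand(2, 16, 64)
    assert block(x).shape == x.shape
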
class PitchFormer(nn.Module):
    def __init__(self, n_mels, hidden_size, attn_layers=4):
        super().__init__()
        # Project the mel-spectrogram and the MIDI pitch curve to half the hidden
        # size each, so their concatenation has hidden_size channels.
        self.sp_linear = nn.Sequential(nn.Conv1d(n_mels, hidden_size, kernel_size=1),
                                       nn.SiLU(),
                                       nn.Conv1d(hidden_size, hidden_size // 2, kernel_size=1))
        self.midi_linear = nn.Sequential(nn.Conv1d(1, hidden_size, kernel_size=1),
                                         nn.SiLU(),
                                         nn.Conv1d(hidden_size, hidden_size // 2, kernel_size=1))
        self.hidden_size = hidden_size
        # Convolutional relative positional encoding (wav2vec 2.0 style init).
        self.pos_conv = nn.Conv1d(hidden_size, hidden_size,
                                  kernel_size=63,
                                  padding=31)
        dropout = 0
        std = math.sqrt((4 * (1.0 - dropout)) / (63 * hidden_size))
        nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
        nn.init.constant_(self.pos_conv.bias, 0)
        self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
        self.pos_conv = nn.Sequential(self.pos_conv, nn.SiLU())
        self.attn_block = nn.ModuleList([TransformerBlock(hidden_size, 4)
                                         for _ in range(attn_layers)])
        self.linear = nn.Sequential(nn.Linear(hidden_size, hidden_size),
                                    nn.SiLU(),
                                    nn.Linear(hidden_size, 1))

    def forward(self, midi, sp):
        # midi: (batch, frames); sp: (batch, n_mels, frames)
        midi = midi.unsqueeze(1)
        midi = self.midi_linear(midi)
        sp = self.sp_linear(sp)
        x = torch.cat([midi, sp], dim=1)
        # Add the convolutional positional encoding.
        x_conv = self.pos_conv(x)
        x = x + x_conv
        # (batch, channels, frames) -> (batch, frames, channels) for attention.
        x = x.permute(0, 2, 1)
        for layer in self.attn_block:
            x = layer(x)
        x = self.linear(x)
        return x.squeeze(-1)

if __name__ == '__main__':
    model = PitchFormer(100, 256)
    sp = torch.rand((4, 100, 64))   # (batch, n_mels, frames)
    midi = torch.rand((4, 64))      # (batch, frames)
    y = model(midi, sp)
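    # Sanity check (a sketch, assuming the shapes used above): the model should
    # map (batch, frames) midi and (batch, n_mels, frames) spectra to a
    # (batch, frames) pitch curve.
    assert y.shape == (4, 64)
    print(y.shape)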