import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class GLU(nn.Module):
    """Gated Linear Unit: splits the input in half along `dim` and gates one
    half with the sigmoid of the other."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        out, gate = x.chunk(2, dim=self.dim)
        return out * gate.sigmoid()
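
# Example usage (illustrative shapes, not from the original): GLU halves the
# gated dimension, e.g.
#   glu = GLU(dim=-1)
#   glu(torch.randn(2, 10, 512)).shape  # -> torch.Size([2, 10, 256])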


class conform_conv(nn.Module):
    """Conformer convolution module: pointwise Conv1d + GLU, depthwise Conv1d,
    BatchNorm, SiLU, pointwise Conv1d, dropout."""

    def __init__(
        self, channels: int, kernel_size: int = 31, DropoutL=0.1, bias: bool = True
    ):
        super().__init__()
        self.act2 = nn.SiLU()
        self.act1 = GLU(1)

        # The first pointwise conv doubles the channels so the GLU can split
        # them into a signal half and a gate half.
        self.pointwise_conv1 = nn.Conv1d(
            channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias
        )

        # Symmetric (non-causal) convolution: an odd kernel size pads evenly
        # on both sides, so the depthwise conv preserves the sequence length.
        assert (kernel_size - 1) % 2 == 0
        padding = (kernel_size - 1) // 2
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=padding,
            groups=channels,
            bias=bias,
        )
        self.norm = nn.BatchNorm1d(channels)
        self.pointwise_conv2 = nn.Conv1d(
            channels, channels, kernel_size=1, stride=1, padding=0, bias=bias
        )
        self.drop = nn.Dropout(DropoutL) if DropoutL > 0.0 else nn.Identity()
    def forward(self, x):
        # (batch, time, channels) -> (batch, channels, time) for Conv1d.
        x = x.transpose(1, 2)
        x = self.act1(self.pointwise_conv1(x))
        x = self.depthwise_conv(x)
        x = self.norm(x)
        x = self.act2(x)
        x = self.pointwise_conv2(x)
        return self.drop(x).transpose(1, 2)
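
# Example usage (illustrative shapes, not from the original): the module is
# shape-preserving over (batch, time, channels), e.g.
#   conv = conform_conv(channels=64, kernel_size=31)
#   conv(torch.randn(2, 100, 64)).shape  # -> torch.Size([2, 100, 64])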


class Attention(nn.Module):
    """Multi-head attention; keys/values may come from a separate conditioning
    sequence of width `conditiondim` (cross-attention), defaulting to
    self-attention when no condition is given."""

    def __init__(self, dim, heads=4, dim_head=32, conditiondim=None):
        super().__init__()
        if conditiondim is None:
            conditiondim = dim

        # Kept for reference; F.scaled_dot_product_attention applies its own
        # 1/sqrt(dim_head) scaling internally.
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_q = nn.Linear(dim, hidden_dim, bias=False)
        self.to_kv = nn.Linear(conditiondim, hidden_dim * 2, bias=False)
        self.to_out = nn.Sequential(nn.Linear(hidden_dim, dim))
    def forward(self, q, kv=None, mask=None):
        # Self-attention when no conditioning sequence is given.
        if kv is None:
            kv = q

        q = self.to_q(q)
        k, v = self.to_kv(kv).chunk(2, dim=2)
        q, k, v = map(
            lambda t: rearrange(t, "b t (h c) -> b h t c", h=self.heads), (q, k, v)
        )

        if mask is not None:
            # (batch, time) -> (batch, 1, 1, time), broadcasting over heads
            # and query positions; True marks positions that may be attended.
            mask = mask.unsqueeze(1).unsqueeze(1)

        # Prefer a fused SDPA kernel. Note that torch.backends.cuda.sdp_kernel
        # is deprecated in recent PyTorch in favour of
        # torch.nn.attention.sdpa_kernel.
        with torch.backends.cuda.sdp_kernel(enable_math=False):
            out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

        out = rearrange(out, "b h t c -> b t (h c)")
        return self.to_out(out)
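
# Example usage (illustrative shapes, not from the original):
#   self_att = Attention(dim=64, heads=4, dim_head=16)
#   self_att(torch.randn(2, 100, 64)).shape  # -> torch.Size([2, 100, 64])
#   cross_att = Attention(dim=64, heads=4, dim_head=16, conditiondim=32)
#   cross_att(torch.randn(2, 100, 64), kv=torch.randn(2, 50, 32)).shape
#   # -> torch.Size([2, 100, 64])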


class conform_ffn(nn.Module):
    """Conformer feed-forward module: Linear -> SiLU -> Dropout -> Linear ->
    Dropout, with a 4x hidden expansion."""

    def __init__(self, dim, DropoutL1: float = 0.1, DropoutL2: float = 0.1):
        super().__init__()
        self.ln1 = nn.Linear(dim, dim * 4)
        self.ln2 = nn.Linear(dim * 4, dim)
        self.drop1 = nn.Dropout(DropoutL1) if DropoutL1 > 0.0 else nn.Identity()
        self.drop2 = nn.Dropout(DropoutL2) if DropoutL2 > 0.0 else nn.Identity()
        self.act = nn.SiLU()

    def forward(self, x):
        x = self.ln1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.ln2(x)
        return self.drop2(x)


class conform_blocke(nn.Module):
    """Macaron-style Conformer block: half-step FFN, self-attention,
    convolution module, half-step FFN, final LayerNorm."""

    def __init__(
        self,
        dim: int,
        kernel_size: int = 31,
        conv_drop: float = 0.1,
        ffn_latent_drop: float = 0.1,
        ffn_out_drop: float = 0.1,
        attention_drop: float = 0.1,
        attention_heads: int = 4,
        attention_heads_dim: int = 64,
    ):
        super().__init__()
        self.ffn1 = conform_ffn(dim, ffn_latent_drop, ffn_out_drop)
        self.ffn2 = conform_ffn(dim, ffn_latent_drop, ffn_out_drop)
        self.att = Attention(dim, heads=attention_heads, dim_head=attention_heads_dim)
        self.attdrop = (
            nn.Dropout(attention_drop) if attention_drop > 0.0 else nn.Identity()
        )
        self.conv = conform_conv(dim, kernel_size=kernel_size, DropoutL=conv_drop)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.norm4 = nn.LayerNorm(dim)
        self.norm5 = nn.LayerNorm(dim)

    def forward(self, x, mask=None):
        # The two feed-forward modules are half-step residuals (the 0.5
        # factor), as in the Macaron-style Conformer.
        x = self.ffn1(self.norm1(x)) * 0.5 + x
        x = self.attdrop(self.att(self.norm2(x), mask=mask)) + x
        x = self.conv(self.norm3(x)) + x
        x = self.ffn2(self.norm4(x)) * 0.5 + x
        return self.norm5(x)
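
# Example usage (illustrative shapes, not from the original):
#   block = conform_blocke(dim=64, attention_heads=4, attention_heads_dim=16)
#   block(torch.randn(2, 100, 64)).shape  # -> torch.Size([2, 100, 64])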


class Gcf(nn.Module):
    """Runs a Conformer block on each of two parallel streams (midi and
    boundary), then lets the streams exchange GLU-gated residuals."""

    def __init__(
        self,
        dim: int,
        kernel_size: int = 31,
        conv_drop: float = 0.1,
        ffn_latent_drop: float = 0.1,
        ffn_out_drop: float = 0.1,
        attention_drop: float = 0.1,
        attention_heads: int = 4,
        attention_heads_dim: int = 64,
    ):
        super().__init__()
        self.att1 = conform_blocke(
            dim=dim,
            kernel_size=kernel_size,
            conv_drop=conv_drop,
            ffn_latent_drop=ffn_latent_drop,
            ffn_out_drop=ffn_out_drop,
            attention_drop=attention_drop,
            attention_heads=attention_heads,
            attention_heads_dim=attention_heads_dim,
        )
        self.att2 = conform_blocke(
            dim=dim,
            kernel_size=kernel_size,
            conv_drop=conv_drop,
            ffn_latent_drop=ffn_latent_drop,
            ffn_out_drop=ffn_out_drop,
            attention_drop=attention_drop,
            attention_heads=attention_heads,
            attention_heads_dim=attention_heads_dim,
        )
        self.glu1 = nn.Sequential(nn.Linear(dim, dim * 2), GLU(2))
        self.glu2 = nn.Sequential(nn.Linear(dim, dim * 2), GLU(2))

    def forward(self, midi, bound):
        midi = self.att1(midi)
        bound = self.att2(bound)
        # Each stream adds a gated projection of the other as a residual.
        midis = self.glu1(midi)
        bounds = self.glu2(bound)
        return midi + bounds, bound + midis
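
# Example usage (illustrative shapes, not from the original): both streams
# keep their shapes while exchanging gated residuals, e.g.
#   gcf = Gcf(dim=64, attention_heads=4, attention_heads_dim=16)
#   midi, bound = gcf(torch.randn(2, 100, 64), torch.randn(2, 100, 64))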


class Gmidi_conform(nn.Module):
    """Stack of `lay` Gcf layers over two projections of the same input,
    followed by one extra Conformer block per stream. The midi stream is
    projected to `outdim` logits; the boundary stream to a per-frame scalar."""

    def __init__(
        self,
        lay: int,
        dim: int,
        indim: int,
        outdim: int,
        use_lay_skip: bool,
        kernel_size: int = 31,
        conv_drop: float = 0.1,
        ffn_latent_drop: float = 0.1,
        ffn_out_drop: float = 0.1,
        attention_drop: float = 0.1,
        attention_heads: int = 4,
        attention_heads_dim: int = 64,
    ):
        super().__init__()
        self.inln = nn.Linear(indim, dim)
        self.inln1 = nn.Linear(indim, dim)
        self.outln = nn.Linear(dim, outdim)
        self.cutheard = nn.Linear(dim, 1)
        self.lay = lay
        # Note: use_lay_skip is stored but not used in forward().
        self.use_lay_skip = use_lay_skip
        self.cf_lay = nn.ModuleList(
            [
                Gcf(
                    dim=dim,
                    kernel_size=kernel_size,
                    conv_drop=conv_drop,
                    ffn_latent_drop=ffn_latent_drop,
                    ffn_out_drop=ffn_out_drop,
                    attention_drop=attention_drop,
                    attention_heads=attention_heads,
                    attention_heads_dim=attention_heads_dim,
                )
                for _ in range(lay)
            ]
        )
        self.att1 = conform_blocke(
            dim=dim,
            kernel_size=kernel_size,
            conv_drop=conv_drop,
            ffn_latent_drop=ffn_latent_drop,
            ffn_out_drop=ffn_out_drop,
            attention_drop=attention_drop,
            attention_heads=attention_heads,
            attention_heads_dim=attention_heads_dim,
        )
        self.att2 = conform_blocke(
            dim=dim,
            kernel_size=kernel_size,
            conv_drop=conv_drop,
            ffn_latent_drop=ffn_latent_drop,
            ffn_out_drop=ffn_out_drop,
            attention_drop=attention_drop,
            attention_heads=attention_heads,
            attention_heads_dim=attention_heads_dim,
        )
    def forward(self, x, mask=None):
        # Two independent projections of the same input features: x feeds the
        # midi stream, x1 the boundary stream.
        x1 = x.clone()
        x = self.inln(x)
        x1 = self.inln1(x1)
        if mask is not None:
            x = x.masked_fill(~mask.unsqueeze(-1), 0)
        for layer in self.cf_lay:
            x, x1 = layer(x, x1)
            if mask is not None:
                x = x.masked_fill(~mask.unsqueeze(-1), 0)
        x, x1 = self.att1(x), self.att2(x1)
        cutprp = self.cutheard(x1)
        midiout = self.outln(x)
        return midiout, cutprp
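

if __name__ == "__main__":
    # Minimal smoke test; the shapes and hyperparameters below are
    # illustrative assumptions, not values from the original training setup.
    model = Gmidi_conform(
        lay=2,
        dim=64,
        indim=80,
        outdim=128,
        use_lay_skip=True,
        attention_heads=4,
        attention_heads_dim=16,
    )
    feats = torch.randn(2, 100, 80)  # (batch, frames, indim)
    mask = torch.ones(2, 100, dtype=torch.bool)  # True = valid frame
    midiout, cutprp = model(feats, mask)
    print(midiout.shape)  # torch.Size([2, 100, 128])
    print(cutprp.shape)   # torch.Size([2, 100, 1])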