import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

class GLU(nn.Module):
    """Gated Linear Unit: splits the input in two along `dim` and gates one
    half with the sigmoid of the other."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        out, gate = x.chunk(2, dim=self.dim)
        return out * gate.sigmoid()
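
# Usage sketch (illustrative, not part of the original file): GLU halves the
# size of the chunked dimension, so that dimension must be even.
#
#   glu = GLU(dim=2)
#   x = torch.randn(2, 100, 512)   # (batch, time, 2 * features)
#   y = glu(x)                     # -> torch.Size([2, 100, 256])
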
class conform_conv(nn.Module):
    """Conformer convolution module: pointwise conv + GLU, depthwise conv,
    BatchNorm, SiLU, pointwise conv, then dropout."""

    def __init__(
        self, channels: int, kernel_size: int = 31, DropoutL=0.1, bias: bool = True
    ):
        super().__init__()
        self.act2 = nn.SiLU()
        self.act1 = GLU(1)  # gate along the channel dimension

        # The pointwise conv doubles the channels so the GLU can split them again.
        self.pointwise_conv1 = nn.Conv1d(
            channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias
        )

        # Only the symmetric (non-causal) case is implemented here: the
        # depthwise conv pads (kernel_size - 1) // 2 frames on each side,
        # which requires an odd kernel size.
        assert (kernel_size - 1) % 2 == 0
        padding = (kernel_size - 1) // 2
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=padding,
            groups=channels,
            bias=bias,
        )
        self.norm = nn.BatchNorm1d(channels)
        self.pointwise_conv2 = nn.Conv1d(
            channels, channels, kernel_size=1, stride=1, padding=0, bias=bias
        )
        self.drop = nn.Dropout(DropoutL) if DropoutL > 0.0 else nn.Identity()

    def forward(self, x):
        x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T) for Conv1d
        x = self.act1(self.pointwise_conv1(x))
        x = self.depthwise_conv(x)
        x = self.norm(x)
        x = self.act2(x)
        x = self.pointwise_conv2(x)
        return self.drop(x).transpose(1, 2)  # back to (B, T, C)
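
# Usage sketch (illustrative): the module is shape-preserving on
# (batch, time, channels) input; the assert above requires an odd kernel size.
#
#   conv = conform_conv(channels=256, kernel_size=31)
#   x = torch.randn(2, 100, 256)
#   y = conv(x)                    # -> torch.Size([2, 100, 256])
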
class Attention(nn.Module):
    """Multi-head attention: self-attention when `kv` is None, cross-attention
    over `kv` (of size `conditiondim`) otherwise."""

    def __init__(self, dim, heads=4, dim_head=32, conditiondim=None):
        super().__init__()
        if conditiondim is None:
            conditiondim = dim

        self.scale = dim_head**-0.5  # unused: F.scaled_dot_product_attention scales internally
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_q = nn.Linear(dim, hidden_dim, bias=False)
        self.to_kv = nn.Linear(conditiondim, hidden_dim * 2, bias=False)
        self.to_out = nn.Sequential(nn.Linear(hidden_dim, dim))

    def forward(self, q, kv=None, mask=None):
        if kv is None:
            kv = q  # self-attention

        q = self.to_q(q)
        k, v = self.to_kv(kv).chunk(2, dim=2)
        q, k, v = map(
            lambda t: rearrange(t, "b t (h c) -> b h t c", h=self.heads), (q, k, v)
        )

        if mask is not None:
            # (B, T_kv) -> (B, 1, 1, T_kv) so it broadcasts over heads and queries.
            mask = mask.unsqueeze(1).unsqueeze(1)

        # Force a fused SDPA kernel; note this context manager is deprecated in
        # newer PyTorch in favor of torch.nn.attention.sdpa_kernel.
        with torch.backends.cuda.sdp_kernel(enable_math=False):
            out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

        out = rearrange(out, "b h t c -> b t (h c)")
        return self.to_out(out)
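
# Usage sketch (illustrative): self-attention when only `q` is given,
# cross-attention when `kv` comes from another sequence. `mask` is a boolean
# (batch, time_kv) tensor with True marking valid positions.
#
#   att = Attention(dim=256, heads=4, dim_head=64)
#   x = torch.randn(2, 100, 256)
#   y = att(x)                     # self-attention, -> (2, 100, 256)
#   cond = torch.randn(2, 50, 256)
#   y = att(x, kv=cond)            # cross-attention over `cond`
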
class conform_ffn(nn.Module):
    """Position-wise feed-forward module with a 4x expansion and SiLU."""

    def __init__(self, dim, DropoutL1: float = 0.1, DropoutL2: float = 0.1):
        super().__init__()
        self.ln1 = nn.Linear(dim, dim * 4)
        self.ln2 = nn.Linear(dim * 4, dim)
        self.drop1 = nn.Dropout(DropoutL1) if DropoutL1 > 0.0 else nn.Identity()
        self.drop2 = nn.Dropout(DropoutL2) if DropoutL2 > 0.0 else nn.Identity()
        self.act = nn.SiLU()

    def forward(self, x):
        x = self.ln1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.ln2(x)
        return self.drop2(x)
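
# Usage sketch (illustrative): a shape-preserving position-wise FFN.
#
#   ffn = conform_ffn(dim=256)
#   y = ffn(torch.randn(2, 100, 256))   # -> torch.Size([2, 100, 256])
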
class conform_blocke(nn.Module):
    """Macaron-style Conformer block: half-step FFN, self-attention,
    convolution module, half-step FFN, final LayerNorm."""

    def __init__(
        self,
        dim: int,
        kernel_size: int = 31,
        conv_drop: float = 0.1,
        ffn_latent_drop: float = 0.1,
        ffn_out_drop: float = 0.1,
        attention_drop: float = 0.1,
        attention_heads: int = 4,
        attention_heads_dim: int = 64,
    ):
        super().__init__()
        self.ffn1 = conform_ffn(dim, ffn_latent_drop, ffn_out_drop)
        self.ffn2 = conform_ffn(dim, ffn_latent_drop, ffn_out_drop)
        self.att = Attention(dim, heads=attention_heads, dim_head=attention_heads_dim)
        self.attdrop = (
            nn.Dropout(attention_drop) if attention_drop > 0.0 else nn.Identity()
        )
        self.conv = conform_conv(dim, kernel_size=kernel_size, DropoutL=conv_drop)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.norm4 = nn.LayerNorm(dim)
        self.norm5 = nn.LayerNorm(dim)

    def forward(self, x, mask=None):
        # Pre-norm residual branches; the two FFNs use the Macaron half-step (0.5).
        x = self.ffn1(self.norm1(x)) * 0.5 + x
        x = self.attdrop(self.att(self.norm2(x), mask=mask)) + x
        x = self.conv(self.norm3(x)) + x
        x = self.ffn2(self.norm4(x)) * 0.5 + x
        return self.norm5(x)
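
# Usage sketch (illustrative): one shape-preserving Conformer block.
#
#   block = conform_blocke(dim=256)
#   x = torch.randn(2, 100, 256)
#   mask = torch.ones(2, 100, dtype=torch.bool)   # True = valid frame
#   y = block(x, mask=mask)        # -> torch.Size([2, 100, 256])
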
class Gcf(nn.Module):
    """Two parallel Conformer blocks whose outputs are exchanged through GLU
    gates: each stream adds a gated projection of the other."""

    def __init__(
        self,
        dim: int,
        kernel_size: int = 31,
        conv_drop: float = 0.1,
        ffn_latent_drop: float = 0.1,
        ffn_out_drop: float = 0.1,
        attention_drop: float = 0.1,
        attention_heads: int = 4,
        attention_heads_dim: int = 64,
    ):
        super().__init__()
        self.att1 = conform_blocke(
            dim=dim,
            kernel_size=kernel_size,
            conv_drop=conv_drop,
            ffn_latent_drop=ffn_latent_drop,
            ffn_out_drop=ffn_out_drop,
            attention_drop=attention_drop,
            attention_heads=attention_heads,
            attention_heads_dim=attention_heads_dim,
        )
        self.att2 = conform_blocke(
            dim=dim,
            kernel_size=kernel_size,
            conv_drop=conv_drop,
            ffn_latent_drop=ffn_latent_drop,
            ffn_out_drop=ffn_out_drop,
            attention_drop=attention_drop,
            attention_heads=attention_heads,
            attention_heads_dim=attention_heads_dim,
        )
        self.glu1 = nn.Sequential(nn.Linear(dim, dim * 2), GLU(2))
        self.glu2 = nn.Sequential(nn.Linear(dim, dim * 2), GLU(2))

    def forward(self, midi, bound):
        midi = self.att1(midi)
        bound = self.att2(bound)
        midis = self.glu1(midi)
        bounds = self.glu2(bound)
        # Cross-stream exchange: each stream receives the other's gated features.
        return midi + bounds, bound + midis
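
# Usage sketch (illustrative): both streams must share (batch, time, dim);
# each comes back with the other's gated features added in.
#
#   layer = Gcf(dim=256)
#   midi = torch.randn(2, 100, 256)
#   bound = torch.randn(2, 100, 256)
#   midi, bound = layer(midi, bound)   # both stay (2, 100, 256)
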
class Gmidi_conform(nn.Module):
    """Two-stream MIDI Conformer: `lay` Gcf layers with cross-stream exchange,
    one final Conformer block per stream, a MIDI head, and a boundary head."""

    def __init__(
        self,
        lay: int,
        dim: int,
        indim: int,
        outdim: int,
        use_lay_skip: bool,
        kernel_size: int = 31,
        conv_drop: float = 0.1,
        ffn_latent_drop: float = 0.1,
        ffn_out_drop: float = 0.1,
        attention_drop: float = 0.1,
        attention_heads: int = 4,
        attention_heads_dim: int = 64,
    ):
        super().__init__()
        self.inln = nn.Linear(indim, dim)
        self.inln1 = nn.Linear(indim, dim)
        self.outln = nn.Linear(dim, outdim)
        self.cutheard = nn.Linear(dim, 1)  # boundary ("cut") head
        self.lay = lay
        self.use_lay_skip = use_lay_skip  # stored but not used in this forward
        self.cf_lay = nn.ModuleList(
            [
                Gcf(
                    dim=dim,
                    kernel_size=kernel_size,
                    conv_drop=conv_drop,
                    ffn_latent_drop=ffn_latent_drop,
                    ffn_out_drop=ffn_out_drop,
                    attention_drop=attention_drop,
                    attention_heads=attention_heads,
                    attention_heads_dim=attention_heads_dim,
                )
                for _ in range(lay)
            ]
        )
        self.att1 = conform_blocke(
            dim=dim,
            kernel_size=kernel_size,
            conv_drop=conv_drop,
            ffn_latent_drop=ffn_latent_drop,
            ffn_out_drop=ffn_out_drop,
            attention_drop=attention_drop,
            attention_heads=attention_heads,
            attention_heads_dim=attention_heads_dim,
        )
        self.att2 = conform_blocke(
            dim=dim,
            kernel_size=kernel_size,
            conv_drop=conv_drop,
            ffn_latent_drop=ffn_latent_drop,
            ffn_out_drop=ffn_out_drop,
            attention_drop=attention_drop,
            attention_heads=attention_heads,
            attention_heads_dim=attention_heads_dim,
        )

    def forward(self, x, mask=None):
        # Project the same input features once per stream.
        x1 = x.clone()
        x = self.inln(x)
        x1 = self.inln1(x1)
        if mask is not None:
            x = x.masked_fill(~mask.unsqueeze(-1), 0)
        for blk in self.cf_lay:
            x, x1 = blk(x, x1)
            if mask is not None:
                x = x.masked_fill(~mask.unsqueeze(-1), 0)
        x, x1 = self.att1(x), self.att2(x1)
        cutprp = self.cutheard(x1)  # per-frame boundary logit
        midiout = self.outln(x)  # per-frame MIDI logits
        return midiout, cutprp
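
# Usage sketch (illustrative; the sizes below are made up for the demo):
#
#   model = Gmidi_conform(lay=2, dim=256, indim=128, outdim=129,
#                         use_lay_skip=True)
#   feats = torch.randn(2, 100, 128)             # per-frame input features
#   mask = torch.ones(2, 100, dtype=torch.bool)  # True = valid frame
#   midiout, cutprp = model(feats, mask)
#   # midiout: (2, 100, 129) per-frame MIDI logits
#   # cutprp:  (2, 100, 1)   per-frame boundary logits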