Spaces:
Sleeping
Sleeping
File size: 1,239 Bytes
ce367e1 3c27def ce367e1 3c27def 25ba0c9 3c27def ce367e1 3c27def ce367e1 3c27def ce367e1 3c27def ce367e1 25ba0c9 ce367e1 25ba0c9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
import torch
import torch.nn as nn
import torch.nn.functional as F
class DynamicConv1d(nn.Module):
    """Dynamic 1-D convolution: an input-conditioned mixture of K Conv1d experts.

    Runs ``K`` parallel ``nn.Conv1d`` branches and combines their outputs with
    per-sample mixing weights predicted by a squeeze-style attention head
    (global average pool -> 1x1 conv bottleneck -> K logits -> softmax),
    then adds a residual connection.

    Args:
        in_channels: channels of the input (last dim of the channels-last input).
        out_channels: channels produced by each expert convolution.
        kernel_size: kernel size of each expert convolution ('same' padding,
            so sequence length is preserved).
        K: number of parallel convolution experts.
        reduction: bottleneck reduction factor for the attention head.

    Shape:
        input  -- (batch, length, in_channels); permuted internally to
                  channels-first for Conv1d.
        output -- (batch, out_channels, length).
        NOTE(review): the output is NOT permuted back to channels-last even
        though the input is channels-last — confirm callers expect
        channels-first output.
    """

    def __init__(self, in_channels, out_channels, kernel_size, K=4, reduction=4):
        super().__init__()
        self.K = K
        # K parallel expert convolutions; 'same' padding preserves length.
        self.convs = nn.ModuleList([
            nn.Conv1d(in_channels, out_channels, kernel_size,
                      padding='same')
            for _ in range(K)
        ])
        # Attention head: (B, C, L) -> pooled (B, C, 1) -> (B, K, 1) logits,
        # one mixing logit per expert, per sample.
        self.attn = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Conv1d(in_channels, in_channels // reduction, 1),
            nn.SiLU(),
            nn.Conv1d(in_channels // reduction, in_channels // reduction, 1),
            nn.SiLU(),
            nn.Conv1d(in_channels // reduction, K, 1)
        )
        # Small-std init keeps the logits near zero, so the initial mixture
        # starts close to a uniform average of the K experts.
        nn.init.normal_(self.attn[-1].weight, mean=0, std=0.1)
        # Residual path: identity when channel counts already match (keeps the
        # original behavior and parameter/RNG layout byte-identical); a 1x1
        # projection otherwise, so the skip connection is shape-compatible.
        # Previously `out + x` crashed whenever out_channels != in_channels.
        if in_channels == out_channels:
            self.residual = nn.Identity()
        else:
            self.residual = nn.Conv1d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Apply the dynamic convolution.

        Args:
            x: tensor of shape (batch, length, in_channels).

        Returns:
            Tensor of shape (batch, out_channels, length) — channels-first.
        """
        # Channels-last -> channels-first for Conv1d.
        x = x.permute(0, 2, 1)
        # Per-sample mixture weights over the K experts: (B, K, 1).
        attn_logits = self.attn(x)
        attn_weights = F.softmax(attn_logits, dim=1)
        conv_outs = [conv(x) for conv in self.convs]
        # Weighted sum; each (B, 1, 1) weight broadcasts over (B, C_out, L).
        out = sum(w * o for w, o in zip(attn_weights.split(1, dim=1), conv_outs))
        # Residual connection (projected with a 1x1 conv if channels differ).
        return out + self.residual(x)
|