import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def initialize_weights(m):
    """Initialize a single module; intended for ``model.apply(initialize_weights)``.

    Conv1d weights get He-style normal init scaled by kernel fan-out,
    BatchNorm1d is set to identity (weight=1, bias=0), and Linear weights
    get a small normal (std=0.001). Biases, where present, are zeroed.
    """
    if isinstance(m, nn.Conv1d):
        # He/Kaiming normal: std = sqrt(2 / fan_out) for a 1D kernel.
        n = m.kernel_size[0] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2 / n))
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.BatchNorm1d):
        nn.init.constant_(m.weight.data, 1)
        nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.Linear):
        m.weight.data.normal_(0, 0.001)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)


class SELayer(nn.Module):
    """Squeeze-and-Excitation channel gating for ``(B, C, L)`` inputs.

    Squeezes with a global average over the length dimension, excites with a
    two-layer MLP (bottleneck of ``inp // reduction``), and rescales each
    channel by a sigmoid gate in (0, 1).
    """

    def __init__(self, inp, reduction=4):
        super(SELayer, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(inp, int(inp // reduction)),
            nn.SiLU(),
            nn.Linear(int(inp // reduction), inp),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _ = x.size()
        # Squeeze: per-channel mean over the length dimension -> (B, C).
        y = x.view(b, c, -1).mean(dim=2)
        # Excite: per-channel gates, broadcast back over length.
        y = self.fc(y).view(b, c, 1)
        return x * y


class EffBlock(nn.Module):
    """Inverted-bottleneck 1D block (MobileNet/EfficientNet style).

    Pipeline: 1x1 expansion (x ``resize_factor``) -> depthwise conv of size
    ``ks`` -> Squeeze-and-Excitation -> 1x1 projection to ``out_ch``.
    All convs use ``padding='same'``, so the length dimension is preserved.

    Fix: the projection conv and its BatchNorm previously hard-coded
    ``in_ch``, silently ignoring a caller-supplied ``out_ch``. They now use
    ``self.out_ch``; the default (``out_ch=None`` -> ``out_ch = in_ch``)
    keeps the old behavior for existing callers.

    Args:
        in_ch: number of input channels.
        ks: depthwise kernel size.
        resize_factor: channel expansion factor for the inner dim.
        activation: activation class (e.g. ``nn.SiLU``), instantiated 3x.
        out_ch: output channels; defaults to ``in_ch``.
        se_reduction: SE bottleneck reduction; defaults to ``resize_factor``.
    """

    def __init__(self, in_ch, ks, resize_factor, activation, out_ch=None, se_reduction=None):
        super().__init__()
        self.in_ch = in_ch
        self.out_ch = self.in_ch if out_ch is None else out_ch
        self.resize_factor = resize_factor
        self.se_reduction = resize_factor if se_reduction is None else se_reduction
        self.ks = ks
        self.inner_dim = self.in_ch * self.resize_factor

        self.block = nn.Sequential(
            # Pointwise expansion to the inner (wider) dimension.
            nn.Conv1d(
                in_channels=self.in_ch,
                out_channels=self.inner_dim,
                kernel_size=1,
                padding='same',
                bias=False,
            ),
            nn.BatchNorm1d(self.inner_dim),
            activation(),
            # Depthwise conv: groups == channels, one filter per channel.
            nn.Conv1d(
                in_channels=self.inner_dim,
                out_channels=self.inner_dim,
                kernel_size=ks,
                groups=self.inner_dim,
                padding='same',
                bias=False,
            ),
            nn.BatchNorm1d(self.inner_dim),
            activation(),
            SELayer(self.inner_dim, reduction=self.se_reduction),
            # Pointwise projection — now honors out_ch (was self.in_ch).
            nn.Conv1d(
                in_channels=self.inner_dim,
                out_channels=self.out_ch,
                kernel_size=1,
                padding='same',
                bias=False,
            ),
            nn.BatchNorm1d(self.out_ch),
            activation(),
        )

    def forward(self, x):
        return self.block(x)


class LocalBlock(nn.Module):
    """Conv1d -> BatchNorm1d -> activation, length-preserving (padding='same')."""

    def __init__(self, in_ch, ks, activation, out_ch=None):
        super().__init__()
        self.in_ch = in_ch
        self.out_ch = self.in_ch if out_ch is None else out_ch
        self.ks = ks
        self.block = nn.Sequential(
            nn.Conv1d(
                in_channels=self.in_ch,
                out_channels=self.out_ch,
                kernel_size=self.ks,
                padding='same',
                bias=False,
            ),
            nn.BatchNorm1d(self.out_ch),
            activation(),
        )

    def forward(self, x):
        return self.block(x)


class ResidualConcat(nn.Module):
    """Residual-by-concatenation: returns ``cat([fn(x), x])`` along channels.

    Output channel count is ``fn(x).channels + x.channels``; downstream
    layers must expect the doubled width.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return torch.concat([self.fn(x, **kwargs), x], dim=1)


class MapperBlock(nn.Module):
    """BatchNorm1d followed by a 1x1 Conv1d channel remap.

    NOTE(review): ``activation`` is accepted but never used by this block;
    it is kept only for backward compatibility with existing call sites.
    """

    def __init__(self, in_features, out_features, activation=nn.SiLU):
        super().__init__()
        self.block = nn.Sequential(
            nn.BatchNorm1d(in_features),
            nn.Conv1d(in_channels=in_features, out_channels=out_features, kernel_size=1),
        )

    def forward(self, x):
        return self.block(x)


class LegNet(nn.Module):
    """1D CNN: stem -> stacked EffBlock stages -> 1x1 mapper -> pooled MLP head.

    Each stage wraps an ``EffBlock`` in ``ResidualConcat`` (doubling the
    channel count), remaps with a ``LocalBlock``, then optionally max-pools.
    The head global-average-pools over length and regresses a single scalar
    per example.

    Args:
        in_ch: input channels (e.g. 4 for one-hot DNA — TODO confirm).
        stem_ch: channels produced by the stem conv.
        stem_ks: stem kernel size.
        ef_ks: kernel size used inside every EffBlock / LocalBlock stage.
        ef_block_sizes: output channels per stage.
        pool_sizes: max-pool size per stage (1 means no pooling); must be
            the same length as ``ef_block_sizes``.
        resize_factor: EffBlock expansion factor.
        activation: activation class used throughout.

    Input:  ``(B, in_ch, L)``.  Output: ``(B,)`` scalar predictions.
    """

    def __init__(self,
                 in_ch,
                 stem_ch,
                 stem_ks,
                 ef_ks,
                 ef_block_sizes,
                 pool_sizes,
                 resize_factor,
                 activation=nn.SiLU,
                 ):
        super().__init__()
        assert len(pool_sizes) == len(ef_block_sizes)
        self.in_ch = in_ch
        self.stem = LocalBlock(in_ch=in_ch, out_ch=stem_ch, ks=stem_ks, activation=activation)

        blocks = []
        in_ch = stem_ch
        out_ch = stem_ch  # fallback if the stage lists are empty
        for pool_sz, out_ch in zip(pool_sizes, ef_block_sizes):
            blc = nn.Sequential(
                ResidualConcat(
                    EffBlock(
                        in_ch=in_ch,
                        out_ch=in_ch,
                        ks=ef_ks,
                        resize_factor=resize_factor,
                        activation=activation,
                    )
                ),
                # ResidualConcat doubled the channels; remap to out_ch.
                LocalBlock(in_ch=in_ch * 2, out_ch=out_ch, ks=ef_ks, activation=activation),
                nn.MaxPool1d(pool_sz) if pool_sz != 1 else nn.Identity(),
            )
            in_ch = out_ch
            blocks.append(blc)
        self.main = nn.Sequential(*blocks)

        self.mapper = MapperBlock(in_features=out_ch, out_features=out_ch * 2)
        self.head = nn.Sequential(
            nn.Linear(out_ch * 2, out_ch * 2),
            nn.BatchNorm1d(out_ch * 2),
            activation(),
            nn.Linear(out_ch * 2, 1),
        )

    def forward(self, x):
        x = self.stem(x)
        x = self.main(x)
        x = self.mapper(x)
        # Global average pool over length, then drop the length axis.
        x = F.adaptive_avg_pool1d(x, 1)
        x = x.squeeze(-1)
        x = self.head(x)
        # Drop the trailing singleton output dim -> (B,).
        x = x.squeeze(-1)
        return x