# MPRALegNet / model.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


def initialize_weights(m):
    """Per-module weight initializer for Conv1d, BatchNorm1d and Linear layers."""
    if isinstance(m, nn.Conv1d):
        # He/Kaiming-style normal init scaled by kernel size * output channels.
        n = m.kernel_size[0] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2 / n))
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.BatchNorm1d):
        # BatchNorm starts out as the identity transform.
        nn.init.constant_(m.weight.data, 1)
        nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.Linear):
        m.weight.data.normal_(0, 0.001)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
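

# A minimal usage sketch (not part of the original upload): an initializer with
# this signature is typically applied recursively via nn.Module.apply, e.g.
#
#     model = LegNet(...)              # hypothetical hyperparameters
#     model.apply(initialize_weights)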


class SELayer(nn.Module):
    """Squeeze-and-excitation block: channel-wise recalibration of a 1D feature map."""

    def __init__(self, inp, reduction=4):
        super(SELayer, self).__init__()
        # Bottleneck MLP mapping pooled channel statistics to per-channel gates.
        self.fc = nn.Sequential(
            nn.Linear(inp, int(inp // reduction)),
            nn.SiLU(),
            nn.Linear(int(inp // reduction), inp),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _ = x.size()
        # Squeeze: global average pooling over the length dimension.
        y = x.view(b, c, -1).mean(dim=2)
        # Excite: rescale each channel by its learned gate.
        y = self.fc(y).view(b, c, 1)
        return x * y


class EffBlock(nn.Module):
    """Inverted-bottleneck 1D block (EfficientNet-style): 1x1 expansion, depthwise
    convolution, squeeze-and-excitation, then 1x1 projection back to in_ch."""

    def __init__(self, in_ch, ks, resize_factor, activation, out_ch=None, se_reduction=None):
        super().__init__()
        self.in_ch = in_ch
        self.out_ch = self.in_ch if out_ch is None else out_ch
        self.resize_factor = resize_factor
        self.se_reduction = resize_factor if se_reduction is None else se_reduction
        self.ks = ks
        self.inner_dim = self.in_ch * self.resize_factor

        block = nn.Sequential(
            # 1x1 expansion convolution.
            nn.Conv1d(
                in_channels=self.in_ch,
                out_channels=self.inner_dim,
                kernel_size=1,
                padding='same',
                bias=False
            ),
            nn.BatchNorm1d(self.inner_dim),
            activation(),
            # Depthwise convolution over the expanded channels.
            nn.Conv1d(
                in_channels=self.inner_dim,
                out_channels=self.inner_dim,
                kernel_size=ks,
                groups=self.inner_dim,
                padding='same',
                bias=False
            ),
            nn.BatchNorm1d(self.inner_dim),
            activation(),
            SELayer(self.inner_dim, reduction=self.se_reduction),
            # 1x1 projection back to the input channel width.
            nn.Conv1d(
                in_channels=self.inner_dim,
                out_channels=self.in_ch,
                kernel_size=1,
                padding='same',
                bias=False
            ),
            nn.BatchNorm1d(self.in_ch),
            activation(),
        )
        self.block = block

    def forward(self, x):
        return self.block(x)


class LocalBlock(nn.Module):
    """Plain Conv1d -> BatchNorm1d -> activation block with 'same' padding."""

    def __init__(self, in_ch, ks, activation, out_ch=None):
        super().__init__()
        self.in_ch = in_ch
        self.out_ch = self.in_ch if out_ch is None else out_ch
        self.ks = ks
        self.block = nn.Sequential(
            nn.Conv1d(
                in_channels=self.in_ch,
                out_channels=self.out_ch,
                kernel_size=self.ks,
                padding='same',
                bias=False
            ),
            nn.BatchNorm1d(self.out_ch),
            activation()
        )

    def forward(self, x):
        return self.block(x)


class ResidualConcat(nn.Module):
    """Concatenates the wrapped module's output with its input along the channel dim."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return torch.concat([self.fn(x, **kwargs), x], dim=1)


class MapperBlock(nn.Module):
    """BatchNorm1d followed by a 1x1 convolution that remaps the channel dimension.

    Note: the ``activation`` argument is accepted but not used inside the block.
    """

    def __init__(self, in_features, out_features, activation=nn.SiLU):
        super().__init__()
        self.block = nn.Sequential(
            nn.BatchNorm1d(in_features),
            nn.Conv1d(in_channels=in_features,
                      out_channels=out_features,
                      kernel_size=1),
        )

    def forward(self, x):
        return self.block(x)


class LegNet(nn.Module):
    """LegNet-style 1D CNN: a convolutional stem, a stack of EfficientNet-like
    blocks with residual concatenation and optional max pooling, and a pooled
    regression head that outputs a single scalar per input sequence."""

    def __init__(self,
                 in_ch,
                 stem_ch,
                 stem_ks,
                 ef_ks,
                 ef_block_sizes,
                 pool_sizes,
                 resize_factor,
                 activation=nn.SiLU,
                 ):
        super().__init__()
        assert len(pool_sizes) == len(ef_block_sizes)
        self.in_ch = in_ch

        # Stem: a wide local convolution lifting the input channels to stem_ch.
        self.stem = LocalBlock(in_ch=in_ch,
                               out_ch=stem_ch,
                               ks=stem_ks,
                               activation=activation)

        blocks = []
        in_ch = stem_ch
        out_ch = stem_ch
        for pool_sz, out_ch in zip(pool_sizes, ef_block_sizes):
            blc = nn.Sequential(
                # EffBlock output is concatenated with its input (channels double),
                # then a LocalBlock maps the concatenation down to out_ch.
                ResidualConcat(
                    EffBlock(
                        in_ch=in_ch,
                        out_ch=in_ch,
                        ks=ef_ks,
                        resize_factor=resize_factor,
                        activation=activation)
                ),
                LocalBlock(in_ch=in_ch * 2,
                           out_ch=out_ch,
                           ks=ef_ks,
                           activation=activation),
                nn.MaxPool1d(pool_sz) if pool_sz != 1 else nn.Identity()
            )
            in_ch = out_ch
            blocks.append(blc)
        self.main = nn.Sequential(*blocks)

        # Channel-doubling 1x1 mapper followed by a two-layer regression head.
        self.mapper = MapperBlock(in_features=out_ch,
                                  out_features=out_ch * 2)
        self.head = nn.Sequential(nn.Linear(out_ch * 2, out_ch * 2),
                                  nn.BatchNorm1d(out_ch * 2),
                                  activation(),
                                  nn.Linear(out_ch * 2, 1))

    def forward(self, x):
        x = self.stem(x)
        x = self.main(x)
        x = self.mapper(x)
        # Global average pooling over the length dimension, then the scalar head.
        x = F.adaptive_avg_pool1d(x, 1)
        x = x.squeeze(-1)
        x = self.head(x)
        x = x.squeeze(-1)
        return x
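

if __name__ == "__main__":
    # Minimal smoke test, not part of the original upload. It assumes one-hot
    # encoded DNA-like input of shape (batch, 4, length); the hyperparameters
    # below are illustrative placeholders, not the values of the released
    # MPRALegNet checkpoint.
    model = LegNet(
        in_ch=4,
        stem_ch=64,
        stem_ks=7,
        ef_ks=9,
        ef_block_sizes=[80, 96, 112, 128],
        pool_sizes=[2, 2, 2, 2],
        resize_factor=4,
    )
    model.apply(initialize_weights)
    model.eval()

    with torch.no_grad():
        dummy = torch.randn(2, 4, 230)   # two random "sequences" of length 230
        out = model(dummy)
    print(out.shape)                      # expected: torch.Size([2])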