Ni-os commited on
Commit
ccfe1ec
·
verified ·
1 Parent(s): 09838c6

Upload model.py

Browse files
Files changed (1) hide show
  1. model.py +187 -0
model.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ def initialize_weights(m):
8
+ if isinstance(m, nn.Conv1d):
9
+ n = m.kernel_size[0] * m.out_channels
10
+ m.weight.data.normal_(0, math.sqrt(2 / n))
11
+ if m.bias is not None:
12
+ nn.init.constant_(m.bias.data, 0)
13
+ elif isinstance(m, nn.BatchNorm1d):
14
+ nn.init.constant_(m.weight.data, 1)
15
+ nn.init.constant_(m.bias.data, 0)
16
+ elif isinstance(m, nn.Linear):
17
+ m.weight.data.normal_(0, 0.001)
18
+ if m.bias is not None:
19
+ nn.init.constant_(m.bias.data, 0)
20
+
21
+ class SELayer(nn.Module):
22
+ def __init__(self, inp, reduction=4):
23
+ super(SELayer, self).__init__()
24
+ self.fc = nn.Sequential(
25
+ nn.Linear(inp, int(inp // reduction)),
26
+ nn.SiLU(),
27
+ nn.Linear(int(inp // reduction), inp),
28
+ nn.Sigmoid()
29
+ )
30
+
31
+ def forward(self, x):
32
+ b, c, _, = x.size()
33
+ y = x.view(b, c, -1).mean(dim=2)
34
+ y = self.fc(y).view(b, c, 1)
35
+ return x * y
36
+
37
+ class EffBlock(nn.Module):
38
+ def __init__(self, in_ch, ks, resize_factor, activation, out_ch=None, se_reduction=None):
39
+ super().__init__()
40
+ self.in_ch = in_ch
41
+ self.out_ch = self.in_ch if out_ch is None else out_ch
42
+ self.resize_factor = resize_factor
43
+ self.se_reduction = resize_factor if se_reduction is None else se_reduction
44
+ self.ks = ks
45
+ self.inner_dim = self.in_ch * self.resize_factor
46
+
47
+ block = nn.Sequential(
48
+ nn.Conv1d(
49
+ in_channels=self.in_ch,
50
+ out_channels=self.inner_dim,
51
+ kernel_size=1,
52
+ padding='same',
53
+ bias=False
54
+ ),
55
+ nn.BatchNorm1d(self.inner_dim),
56
+ activation(),
57
+
58
+ nn.Conv1d(
59
+ in_channels=self.inner_dim,
60
+ out_channels=self.inner_dim,
61
+ kernel_size=ks,
62
+ groups=self.inner_dim,
63
+ padding='same',
64
+ bias=False
65
+ ),
66
+ nn.BatchNorm1d(self.inner_dim),
67
+ activation(),
68
+ SELayer(self.inner_dim, reduction=self.se_reduction),
69
+ nn.Conv1d(
70
+ in_channels=self.inner_dim,
71
+ out_channels=self.in_ch,
72
+ kernel_size=1,
73
+ padding='same',
74
+ bias=False
75
+ ),
76
+ nn.BatchNorm1d(self.in_ch),
77
+ activation(),
78
+ )
79
+
80
+ self.block = block
81
+
82
+ def forward(self, x):
83
+ return self.block(x)
84
+
85
+ class LocalBlock(nn.Module):
86
+ def __init__(self, in_ch, ks, activation, out_ch=None):
87
+ super().__init__()
88
+ self.in_ch = in_ch
89
+ self.out_ch = self.in_ch if out_ch is None else out_ch
90
+ self.ks = ks
91
+
92
+ self.block = nn.Sequential(
93
+ nn.Conv1d(
94
+ in_channels=self.in_ch,
95
+ out_channels=self.out_ch,
96
+ kernel_size=self.ks,
97
+ padding='same',
98
+ bias=False
99
+ ),
100
+ nn.BatchNorm1d(self.out_ch),
101
+ activation()
102
+ )
103
+
104
+ def forward(self, x):
105
+ return self.block(x)
106
+
107
+ class ResidualConcat(nn.Module):
108
+ def __init__(self, fn):
109
+ super().__init__()
110
+ self.fn = fn
111
+
112
+ def forward(self, x, **kwargs):
113
+ return torch.concat([self.fn(x, **kwargs), x], dim=1)
114
+
115
+ class MapperBlock(nn.Module):
116
+ def __init__(self, in_features, out_features, activation=nn.SiLU):
117
+ super().__init__()
118
+ self.block = nn.Sequential(
119
+ nn.BatchNorm1d(in_features),
120
+ nn.Conv1d(in_channels=in_features,
121
+ out_channels=out_features,
122
+ kernel_size=1),
123
+ )
124
+
125
+ def forward(self, x):
126
+ return self.block(x)
127
+
128
+ class LegNet(nn.Module):
129
+ def __init__(self,
130
+ in_ch,
131
+ stem_ch,
132
+ stem_ks,
133
+ ef_ks,
134
+ ef_block_sizes,
135
+ pool_sizes,
136
+ resize_factor,
137
+ activation=nn.SiLU,
138
+ ):
139
+ super().__init__()
140
+ assert len(pool_sizes) == len(ef_block_sizes)
141
+
142
+ self.in_ch = in_ch
143
+ self.stem = LocalBlock(in_ch=in_ch,
144
+ out_ch=stem_ch,
145
+ ks=stem_ks,
146
+ activation=activation)
147
+
148
+ blocks = []
149
+
150
+ in_ch = stem_ch
151
+ out_ch = stem_ch
152
+ for pool_sz, out_ch in zip(pool_sizes, ef_block_sizes):
153
+ blc = nn.Sequential(
154
+ ResidualConcat(
155
+ EffBlock(
156
+ in_ch=in_ch,
157
+ out_ch=in_ch,
158
+ ks=ef_ks,
159
+ resize_factor=resize_factor,
160
+ activation=activation)
161
+ ),
162
+ LocalBlock(in_ch=in_ch * 2,
163
+ out_ch=out_ch,
164
+ ks=ef_ks,
165
+ activation=activation),
166
+ nn.MaxPool1d(pool_sz) if pool_sz != 1 else nn.Identity()
167
+ )
168
+ in_ch = out_ch
169
+ blocks.append(blc)
170
+ self.main = nn.Sequential(*blocks)
171
+
172
+ self.mapper = MapperBlock(in_features=out_ch,
173
+ out_features=out_ch * 2)
174
+ self.head = nn.Sequential(nn.Linear(out_ch * 2, out_ch * 2),
175
+ nn.BatchNorm1d(out_ch * 2),
176
+ activation(),
177
+ nn.Linear(out_ch * 2, 1))
178
+
179
+ def forward(self, x):
180
+ x = self.stem(x)
181
+ x = self.main(x)
182
+ x = self.mapper(x)
183
+ x = F.adaptive_avg_pool1d(x, 1)
184
+ x = x.squeeze(-1)
185
+ x = self.head(x)
186
+ x = x.squeeze(-1)
187
+ return x