NeoPy commited on
Commit
71ab593
·
verified ·
1 Parent(s): 6ad0709

Create hpa-rmvpe.py

Browse files
Files changed (1) hide show
  1. hpa-rmvpe.py +592 -0
hpa-rmvpe.py ADDED
@@ -0,0 +1,592 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import torch
4
+
5
+ import numpy as np
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from librosa.filters import mel
10
+
11
+ sys.path.append(os.getcwd())
12
+
13
+ N_MELS, N_CLASS = 128, 360
14
+
15
+ def autopad(k, p=None):
16
+ if p is None: p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
17
+ return p
18
+
19
+ class Conv(nn.Module):
20
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
21
+ super().__init__()
22
+ self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
23
+ self.bn = nn.BatchNorm2d(c2)
24
+ self.act = nn.SiLU() if act else nn.Identity()
25
+
26
+ def forward(self, x):
27
+ return self.act(self.bn(self.conv(x)))
28
+
29
+ class DSConv(nn.Module):
30
+ def __init__(self, c1, c2, k=3, s=1, p=None, act=True):
31
+ super().__init__()
32
+ self.dwconv = nn.Conv2d(c1, c1, k, s, autopad(k, p), groups=c1, bias=False)
33
+ self.pwconv = nn.Conv2d(c1, c2, 1, 1, 0, bias=False)
34
+ self.bn = nn.BatchNorm2d(c2)
35
+ self.act = nn.SiLU() if act else nn.Identity()
36
+
37
+ def forward(self, x):
38
+ return self.act(self.bn(self.pwconv(self.dwconv(x))))
39
+
40
+ class DS_Bottleneck(nn.Module):
41
+ def __init__(self, c1, c2, k=3, shortcut=True):
42
+ super().__init__()
43
+ self.dsconv1 = DSConv(c1, c1, k=3, s=1)
44
+ self.dsconv2 = DSConv(c1, c2, k=k, s=1)
45
+ self.shortcut = shortcut and c1 == c2
46
+
47
+ def forward(self, x):
48
+ return x + self.dsconv2(self.dsconv1(x)) if self.shortcut else self.dsconv2(self.dsconv1(x))
49
+
50
+ class DS_C3k(nn.Module):
51
+ def __init__(self, c1, c2, n=1, k=3, e=0.5):
52
+ super().__init__()
53
+ self.cv1 = Conv(c1, int(c2 * e), 1, 1)
54
+ self.cv2 = Conv(c1, int(c2 * e), 1, 1)
55
+ self.cv3 = Conv(2 * int(c2 * e), c2, 1, 1)
56
+ self.m = nn.Sequential(*[DS_Bottleneck(int(c2 * e), int(c2 * e), k=k, shortcut=True) for _ in range(n)])
57
+
58
+ def forward(self, x):
59
+ return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))
60
+
61
+ class DS_C3k2(nn.Module):
62
+ def __init__(self, c1, c2, n=1, k=3, e=0.5):
63
+ super().__init__()
64
+ self.cv1 = Conv(c1, int(c2 * e), 1, 1)
65
+ self.m = DS_C3k(int(c2 * e), int(c2 * e), n=n, k=k, e=1.0)
66
+ self.cv2 = Conv(int(c2 * e), c2, 1, 1)
67
+
68
+ def forward(self, x):
69
+ return self.cv2(self.m(self.cv1(x)))
70
+
71
+ class AdaptiveHyperedgeGeneration(nn.Module):
72
+ def __init__(self, in_channels, num_hyperedges, num_heads):
73
+ super().__init__()
74
+ self.num_hyperedges = num_hyperedges
75
+ self.num_heads = num_heads
76
+ self.head_dim = max(1, in_channels // num_heads)
77
+ self.global_proto = nn.Parameter(torch.randn(num_hyperedges, in_channels))
78
+ self.context_mapper = nn.Linear(2 * in_channels, num_hyperedges * in_channels, bias=False)
79
+ self.query_proj = nn.Linear(in_channels, in_channels, bias=False)
80
+ self.scale = self.head_dim ** -0.5
81
+
82
+ def forward(self, x):
83
+ B, N, C = x.shape
84
+ P = self.global_proto.unsqueeze(0) + self.context_mapper(torch.cat((F.adaptive_avg_pool1d(x.permute(0, 2, 1), 1).squeeze(-1), F.adaptive_max_pool1d(x.permute(0, 2, 1), 1).squeeze(-1)), dim=1)).view(B, self.num_hyperedges, C)
85
+
86
+ return F.softmax(((self.query_proj(x).view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3) @ P.view(B, self.num_hyperedges, self.num_heads, self.head_dim).permute(0, 2, 3, 1)) * self.scale).mean(dim=1).permute(0, 2, 1), dim=-1)
87
+
88
+ class HypergraphConvolution(nn.Module):
89
+ def __init__(self, in_channels, out_channels):
90
+ super().__init__()
91
+ self.W_e = nn.Linear(in_channels, in_channels, bias=False)
92
+ self.W_v = nn.Linear(in_channels, out_channels, bias=False)
93
+ self.act = nn.SiLU()
94
+
95
+ def forward(self, x, A):
96
+ return x + self.act(self.W_v(A.transpose(1, 2).bmm(self.act(self.W_e(A.bmm(x))))))
97
+
98
+ class AdaptiveHypergraphComputation(nn.Module):
99
+ def __init__(self, in_channels, out_channels, num_hyperedges, num_heads):
100
+ super().__init__()
101
+ self.adaptive_hyperedge_gen = AdaptiveHyperedgeGeneration(in_channels, num_hyperedges, num_heads)
102
+ self.hypergraph_conv = HypergraphConvolution(in_channels, out_channels)
103
+
104
+ def forward(self, x):
105
+ B, _, H, W = x.shape
106
+ x_flat = x.flatten(2).permute(0, 2, 1)
107
+ return self.hypergraph_conv(x_flat, self.adaptive_hyperedge_gen(x_flat)).permute(0, 2, 1).view(B, -1, H, W)
108
+
109
+ class C3AH(nn.Module):
110
+ def __init__(self, c1, c2, num_hyperedges, num_heads, e=0.5):
111
+ super().__init__()
112
+ self.cv1 = Conv(c1, int(c1 * e), 1, 1)
113
+ self.cv2 = Conv(c1, int(c1 * e), 1, 1)
114
+ self.ahc = AdaptiveHypergraphComputation(int(c1 * e), int(c1 * e), num_hyperedges, num_heads)
115
+ self.cv3 = Conv(2 * int(c1 * e), c2, 1, 1)
116
+
117
+ def forward(self, x):
118
+ return self.cv3(torch.cat((self.ahc(self.cv2(x)), self.cv1(x)), dim=1))
119
+
120
+ class HyperACE(nn.Module):
121
+ def __init__(self, in_channels, out_channels, num_hyperedges=16, num_heads=8, k=2, l=1, c_h=0.5, c_l=0.25):
122
+ super().__init__()
123
+ c2, c3, c4, c5 = in_channels
124
+ c_mid = c4
125
+ self.fuse_conv = Conv(c2 + c3 + c4 + c5, c_mid, 1, 1)
126
+ self.c_h = int(c_mid * c_h)
127
+ self.c_l = int(c_mid * c_l)
128
+ self.c_s = c_mid - self.c_h - self.c_l
129
+ self.high_order_branch = nn.ModuleList([C3AH(self.c_h, self.c_h, num_hyperedges=num_hyperedges, num_heads=num_heads, e=1.0) for _ in range(k)])
130
+ self.high_order_fuse = Conv(self.c_h * k, self.c_h, 1, 1)
131
+ self.low_order_branch = nn.Sequential(*[DS_C3k(self.c_l, self.c_l, n=1, k=3, e=1.0) for _ in range(l)])
132
+ self.final_fuse = Conv(self.c_h + self.c_l + self.c_s, out_channels, 1, 1)
133
+
134
+ def forward(self, x):
135
+ B2, B3, B4, B5 = x
136
+ _, _, H4, W4 = B4.shape
137
+
138
+ x_h, x_l, x_s = self.fuse_conv(
139
+ torch.cat(
140
+ (
141
+ F.interpolate(B2, size=(H4, W4), mode='bilinear', align_corners=False),
142
+ F.interpolate(B3, size=(H4, W4), mode='bilinear', align_corners=False),
143
+ B4,
144
+ F.interpolate(B5, size=(H4, W4), mode='bilinear', align_corners=False)
145
+ ),
146
+ dim=1
147
+ )
148
+ ).split([self.c_h, self.c_l, self.c_s], dim=1)
149
+
150
+ return self.final_fuse(torch.cat((self.high_order_fuse(torch.cat([m(x_h) for m in self.high_order_branch], dim=1)), self.low_order_branch(x_l), x_s), dim=1))
151
+
152
+ class GatedFusion(nn.Module):
153
+ def __init__(self, in_channels):
154
+ super().__init__()
155
+ self.gamma = nn.Parameter(torch.zeros(1, in_channels, 1, 1))
156
+
157
+ def forward(self, f_in, h):
158
+ return f_in + self.gamma * h
159
+
160
+ class YOLO13Encoder(nn.Module):
161
+ def __init__(self, in_channels, base_channels=32):
162
+ super().__init__()
163
+ self.stem = DSConv(in_channels, base_channels, k=3, s=1)
164
+
165
+ self.p2 = nn.Sequential(
166
+ DSConv(base_channels, base_channels*2, k=3, s=(2, 2)),
167
+ DS_C3k2(base_channels*2, base_channels*2, n=1)
168
+ )
169
+
170
+ self.p3 = nn.Sequential(
171
+ DSConv(base_channels*2, base_channels*4, k=3, s=(2, 2)),
172
+ DS_C3k2(base_channels*4, base_channels*4, n=2)
173
+ )
174
+
175
+ self.p4 = nn.Sequential(
176
+ DSConv(base_channels*4, base_channels*8, k=3, s=(2, 2)),
177
+ DS_C3k2(base_channels*8, base_channels*8, n=2)
178
+ )
179
+
180
+ self.p5 = nn.Sequential(
181
+ DSConv(base_channels*8, base_channels*16, k=3, s=(2, 2)),
182
+ DS_C3k2(base_channels*16, base_channels*16, n=1)
183
+ )
184
+
185
+ self.out_channels = [base_channels*2, base_channels*4, base_channels*8, base_channels*16]
186
+
187
+ def forward(self, x):
188
+ x = self.stem(x)
189
+ p2 = self.p2(x)
190
+ p3 = self.p3(p2)
191
+ p4 = self.p4(p3)
192
+ p5 = self.p5(p4)
193
+ return [p2, p3, p4, p5]
194
+
195
+ class YOLO13FullPADDecoder(nn.Module):
196
+ def __init__(self, encoder_channels, hyperace_out_c, out_channels_final):
197
+ super().__init__()
198
+ c_p2, c_p3, c_p4, c_p5 = encoder_channels
199
+ c_d5, c_d4, c_d3, c_d2 = c_p5, c_p4, c_p3, c_p2
200
+
201
+ self.h_to_d5 = Conv(hyperace_out_c, c_d5, 1, 1)
202
+ self.h_to_d4 = Conv(hyperace_out_c, c_d4, 1, 1)
203
+ self.h_to_d3 = Conv(hyperace_out_c, c_d3, 1, 1)
204
+ self.h_to_d2 = Conv(hyperace_out_c, c_d2, 1, 1)
205
+
206
+ self.fusion_d5 = GatedFusion(c_d5)
207
+ self.fusion_d4 = GatedFusion(c_d4)
208
+ self.fusion_d3 = GatedFusion(c_d3)
209
+ self.fusion_d2 = GatedFusion(c_d2)
210
+
211
+ self.skip_p5 = Conv(c_p5, c_d5, 1, 1)
212
+ self.skip_p4 = Conv(c_p4, c_d4, 1, 1)
213
+ self.skip_p3 = Conv(c_p3, c_d3, 1, 1)
214
+ self.skip_p2 = Conv(c_p2, c_d2, 1, 1)
215
+
216
+ self.up_d5 = DS_C3k2(c_d5, c_d4, n=1)
217
+ self.up_d4 = DS_C3k2(c_d4, c_d3, n=1)
218
+ self.up_d3 = DS_C3k2(c_d3, c_d2, n=1)
219
+
220
+ self.final_d2 = DS_C3k2(c_d2, c_d2, n=1)
221
+ self.final_conv = Conv(c_d2, out_channels_final, 1, 1)
222
+
223
+ def forward(self, enc_feats, h_ace):
224
+ p2, p3, p4, p5 = enc_feats
225
+
226
+ d5 = self.skip_p5(p5)
227
+ d4 = self.up_d5(F.interpolate(self.fusion_d5(d5, self.h_to_d5(F.interpolate(h_ace, size=d5.shape[2:], mode='bilinear', align_corners=False))), size=p4.shape[2:], mode='bilinear', align_corners=False)) + self.skip_p4(p4)
228
+ d3 = self.up_d4(F.interpolate(self.fusion_d4(d4, self.h_to_d4(F.interpolate(h_ace, size=d4.shape[2:], mode='bilinear', align_corners=False))), size=p3.shape[2:], mode='bilinear', align_corners=False)) + self.skip_p3(p3)
229
+ d2 = self.up_d3(F.interpolate(self.fusion_d3(d3, self.h_to_d3(F.interpolate(h_ace, size=d3.shape[2:], mode='bilinear', align_corners=False))), size=p2.shape[2:], mode='bilinear', align_corners=False)) + self.skip_p2(p2)
230
+
231
+ return self.final_conv(self.final_d2(self.fusion_d2(d2, self.h_to_d2(F.interpolate(h_ace, size=d2.shape[2:], mode='bilinear', align_corners=False)))))
232
+
233
+ class ConvBlockRes(nn.Module):
234
+ def __init__(self, in_channels, out_channels, momentum=0.01):
235
+ super(ConvBlockRes, self).__init__()
236
+ self.conv = nn.Sequential(
237
+ nn.Conv2d(
238
+ in_channels=in_channels,
239
+ out_channels=out_channels,
240
+ kernel_size=(3, 3),
241
+ stride=(1, 1),
242
+ padding=(1, 1),
243
+ bias=False
244
+ ),
245
+ nn.BatchNorm2d(
246
+ out_channels,
247
+ momentum=momentum
248
+ ),
249
+ nn.ReLU(),
250
+ nn.Conv2d(
251
+ in_channels=out_channels,
252
+ out_channels=out_channels,
253
+ kernel_size=(3, 3),
254
+ stride=(1, 1),
255
+ padding=(1, 1),
256
+ bias=False
257
+ ),
258
+ nn.BatchNorm2d(
259
+ out_channels,
260
+ momentum=momentum
261
+ ),
262
+ nn.ReLU()
263
+ )
264
+
265
+ if in_channels != out_channels:
266
+ self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))
267
+ self.is_shortcut = True
268
+ else: self.is_shortcut = False
269
+
270
+ def forward(self, x):
271
+ return (self.conv(x) + self.shortcut(x)) if self.is_shortcut else (self.conv(x) + x)
272
+
273
+ class ResEncoderBlock(nn.Module):
274
+ def __init__(self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01):
275
+ super(ResEncoderBlock, self).__init__()
276
+ self.n_blocks = n_blocks
277
+ self.conv = nn.ModuleList()
278
+ self.conv.append(ConvBlockRes(in_channels, out_channels, momentum))
279
+
280
+ for _ in range(n_blocks - 1):
281
+ self.conv.append(ConvBlockRes(out_channels, out_channels, momentum))
282
+
283
+ self.kernel_size = kernel_size
284
+ if self.kernel_size is not None: self.pool = nn.AvgPool2d(kernel_size=kernel_size)
285
+
286
+ def forward(self, x):
287
+ for i in range(self.n_blocks):
288
+ x = self.conv[i](x)
289
+
290
+ if self.kernel_size is not None: return x, self.pool(x)
291
+ else: return x
292
+
293
+ class Encoder(nn.Module):
294
+ def __init__(self, in_channels, in_size, n_encoders, kernel_size, n_blocks, out_channels=16, momentum=0.01):
295
+ super(Encoder, self).__init__()
296
+ self.n_encoders = n_encoders
297
+ self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
298
+ self.layers = nn.ModuleList()
299
+
300
+ for _ in range(self.n_encoders):
301
+ self.layers.append(ResEncoderBlock(in_channels, out_channels, kernel_size, n_blocks, momentum=momentum))
302
+ in_channels = out_channels
303
+ out_channels *= 2
304
+ in_size //= 2
305
+
306
+ self.out_size = in_size
307
+ self.out_channel = out_channels
308
+
309
+ def forward(self, x):
310
+ concat_tensors = []
311
+ x = self.bn(x)
312
+
313
+ for layer in self.layers:
314
+ t, x = layer(x)
315
+ concat_tensors.append(t)
316
+
317
+ return x, concat_tensors
318
+
319
+ class Intermediate(nn.Module):
320
+ def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
321
+ super(Intermediate, self).__init__()
322
+ self.layers = nn.ModuleList()
323
+ self.layers.append(ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum))
324
+
325
+ for _ in range(n_inters - 1):
326
+ self.layers.append(ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum))
327
+
328
+ def forward(self, x):
329
+ for layer in self.layers:
330
+ x = layer(x)
331
+
332
+ return x
333
+
334
+ class ResDecoderBlock(nn.Module):
335
+ def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
336
+ super(ResDecoderBlock, self).__init__()
337
+ out_padding = (0, 1) if stride == (1, 2) else (1, 1)
338
+ self.conv1 = nn.Sequential(
339
+ nn.ConvTranspose2d(
340
+ in_channels=in_channels,
341
+ out_channels=out_channels,
342
+ kernel_size=(3, 3),
343
+ stride=stride,
344
+ padding=(1, 1),
345
+ output_padding=out_padding,
346
+ bias=False
347
+ ),
348
+ nn.BatchNorm2d(
349
+ out_channels,
350
+ momentum=momentum
351
+ ),
352
+ nn.ReLU()
353
+ )
354
+
355
+ self.conv2 = nn.ModuleList()
356
+ self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
357
+
358
+ for _ in range(n_blocks - 1):
359
+ self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))
360
+
361
+ def forward(self, x, concat_tensor):
362
+ x = torch.cat((self.conv1(x), concat_tensor), dim=1)
363
+ for conv2 in self.conv2:
364
+ x = conv2(x)
365
+
366
+ return x
367
+
368
+ class Decoder(nn.Module):
369
+ def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
370
+ super(Decoder, self).__init__()
371
+ self.layers = nn.ModuleList()
372
+
373
+ for _ in range(n_decoders):
374
+ out_channels = in_channels // 2
375
+ self.layers.append(ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum))
376
+ in_channels = out_channels
377
+
378
+ def forward(self, x, concat_tensors):
379
+ for i, layer in enumerate(self.layers):
380
+ x = layer(x, concat_tensors[-1 - i])
381
+
382
+ return x
383
+
384
+ class DeepUnet(nn.Module):
385
+ def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
386
+ super(DeepUnet, self).__init__()
387
+ self.encoder = Encoder(in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels)
388
+ self.intermediate = Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks)
389
+ self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks)
390
+
391
+ def forward(self, x):
392
+ x, concat_tensors = self.encoder(x)
393
+ return self.decoder(self.intermediate(x), concat_tensors)
394
+
395
+ class HPADeepUnet(nn.Module):
396
+ def __init__(self, in_channels=1, en_out_channels=16, base_channels=64, hyperace_k=2, hyperace_l=1, num_hyperedges=16, num_heads=8):
397
+ super().__init__()
398
+ self.encoder = YOLO13Encoder(in_channels, base_channels)
399
+ enc_ch = self.encoder.out_channels
400
+
401
+ self.hyperace = HyperACE(
402
+ in_channels=enc_ch,
403
+ out_channels=enc_ch[-1],
404
+ num_hyperedges=num_hyperedges,
405
+ num_heads=num_heads,
406
+ k=hyperace_k,
407
+ l=hyperace_l
408
+ )
409
+
410
+ self.decoder = YOLO13FullPADDecoder(
411
+ encoder_channels=enc_ch,
412
+ hyperace_out_c=enc_ch[-1],
413
+ out_channels_final=en_out_channels
414
+ )
415
+
416
+ def forward(self, x):
417
+ features = self.encoder(x)
418
+ return nn.functional.interpolate(self.decoder(features, self.hyperace(features)), size=x.shape[2:], mode='bilinear', align_corners=False)
419
+
420
+ class BiGRU(nn.Module):
421
+ def __init__(self, input_features, hidden_features, num_layers):
422
+ super(BiGRU, self).__init__()
423
+ self.gru = nn.GRU(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True)
424
+
425
+ def forward(self, x):
426
+ try:
427
+ return self.gru(x)[0]
428
+ except:
429
+ torch.backends.cudnn.enabled = False
430
+ return self.gru(x)[0]
431
+
432
+ class E2E(nn.Module):
433
+ def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16, hpa=False):
434
+ super(E2E, self).__init__()
435
+ self.unet = (
436
+ HPADeepUnet(
437
+ in_channels=in_channels,
438
+ en_out_channels=en_out_channels,
439
+ base_channels=64,
440
+ hyperace_k=2,
441
+ hyperace_l=1,
442
+ num_hyperedges=16,
443
+ num_heads=4
444
+ )
445
+ ) if hpa else (
446
+ DeepUnet(
447
+ kernel_size,
448
+ n_blocks,
449
+ en_de_layers,
450
+ inter_layers,
451
+ in_channels,
452
+ en_out_channels
453
+ )
454
+ )
455
+
456
+ self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
457
+ self.fc = (
458
+ nn.Sequential(
459
+ BiGRU(3 * 128, 256, n_gru),
460
+ nn.Linear(512, N_CLASS),
461
+ nn.Dropout(0.25),
462
+ nn.Sigmoid()
463
+ )
464
+ ) if n_gru else (
465
+ nn.Sequential(
466
+ nn.Linear(3 * N_MELS, N_CLASS),
467
+ nn.Dropout(0.25),
468
+ nn.Sigmoid()
469
+ )
470
+ )
471
+
472
+ def forward(self, mel):
473
+ return self.fc(self.cnn(self.unet(mel.transpose(-1, -2).unsqueeze(1))).transpose(1, 2).flatten(-2))
474
+
475
+ class MelSpectrogram(nn.Module):
476
+ def __init__(self, n_mel_channels, sample_rate, win_length, hop_length, n_fft=None, mel_fmin=0, mel_fmax=None, clamp=1e-5):
477
+ super().__init__()
478
+ n_fft = win_length if n_fft is None else n_fft
479
+ self.hann_window = {}
480
+ mel_basis = mel(sr=sample_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax, htk=True)
481
+ mel_basis = torch.from_numpy(mel_basis).float()
482
+ self.register_buffer("mel_basis", mel_basis)
483
+ self.n_fft = win_length if n_fft is None else n_fft
484
+ self.hop_length = hop_length
485
+ self.win_length = win_length
486
+ self.sample_rate = sample_rate
487
+ self.n_mel_channels = n_mel_channels
488
+ self.clamp = clamp
489
+
490
+ def forward(self, audio, keyshift=0, speed=1, center=True):
491
+ factor = 2 ** (keyshift / 12)
492
+ win_length_new = int(np.round(self.win_length * factor))
493
+ keyshift_key = str(keyshift) + "_" + str(audio.device)
494
+ if keyshift_key not in self.hann_window: self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(audio.device)
495
+
496
+ n_fft = int(np.round(self.n_fft * factor))
497
+ hop_length = int(np.round(self.hop_length * speed))
498
+
499
+ fft = torch.stft(audio, n_fft=n_fft, hop_length=hop_length, win_length=win_length_new, window=self.hann_window[keyshift_key], center=center, return_complex=True)
500
+ magnitude = (fft.real.pow(2) + fft.imag.pow(2)).sqrt()
501
+
502
+ if keyshift != 0:
503
+ size = self.n_fft // 2 + 1
504
+ resize = magnitude.size(1)
505
+ if resize < size: magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
506
+ magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
507
+
508
+ mel_output = self.mel_basis @ magnitude
509
+ return mel_output.clamp(min=self.clamp).log()
510
+
511
+ class RMVPE:
512
+ def __init__(self, model_path, is_half, device=None, providers=None, onnx=False, hpa=False):
513
+ self.onnx = onnx
514
+
515
+ if self.onnx:
516
+ import onnxruntime as ort
517
+
518
+ sess_options = ort.SessionOptions()
519
+ sess_options.log_severity_level = 3
520
+ self.model = ort.InferenceSession(model_path, sess_options=sess_options, providers=providers)
521
+ else:
522
+ model = E2E(4, 1, (2, 2), 5, 4, 1, 16, hpa=hpa)
523
+
524
+ model.load_state_dict(torch.load(model_path, map_location="cpu", weights_only=True))
525
+ model.eval()
526
+ if is_half: model = model.half()
527
+ self.model = model.to(device)
528
+
529
+ self.device = device
530
+ self.is_half = is_half
531
+ self.mel_extractor = MelSpectrogram(N_MELS, 16000, 1024, 160, None, 30, 8000).to(device)
532
+ cents_mapping = 20 * np.arange(N_CLASS) + 1997.3794084376191
533
+ self.cents_mapping = np.pad(cents_mapping, (4, 4))
534
+
535
+ def mel2hidden(self, mel, chunk_size = 32000):
536
+ with torch.no_grad():
537
+ n_frames = mel.shape[-1]
538
+ mel = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode="reflect")
539
+
540
+ output_chunks = []
541
+ pad_frames = mel.shape[-1]
542
+
543
+ for start in range(0, pad_frames, chunk_size):
544
+ mel_chunk = mel[..., start:min(start + chunk_size, pad_frames)]
545
+ assert mel_chunk.shape[-1] % 32 == 0
546
+
547
+ if self.onnx:
548
+ mel_chunk = mel_chunk.cpu().numpy().astype(np.float32)
549
+ out_chunk = torch.as_tensor(self.model.run([self.model.get_outputs()[0].name], {self.model.get_inputs()[0].name: mel_chunk})[0], device=self.device)
550
+ else:
551
+ if self.is_half: mel_chunk = mel_chunk.half()
552
+ out_chunk = self.model(mel_chunk)
553
+
554
+ output_chunks.append(out_chunk)
555
+
556
+ hidden = torch.cat(output_chunks, dim=1)
557
+ return hidden[:, :n_frames]
558
+
559
+ def decode(self, hidden, thred=0.03):
560
+ f0 = 10 * (2 ** (self.to_local_average_cents(hidden, thred=thred) / 1200))
561
+ f0[f0 == 10] = 0
562
+
563
+ return f0
564
+
565
+ def infer_from_audio(self, audio, thred=0.03):
566
+ hidden = self.mel2hidden(self.mel_extractor(torch.from_numpy(audio).float().to(self.device).unsqueeze(0), center=True))
567
+
568
+ return self.decode(hidden.squeeze(0).cpu().numpy().astype(np.float32), thred=thred)
569
+
570
+ def infer_from_audio_with_pitch(self, audio, thred=0.03, f0_min=50, f0_max=1100):
571
+ f0 = self.infer_from_audio(audio, thred)
572
+ f0[(f0 < f0_min) | (f0 > f0_max)] = 0
573
+
574
+ return f0
575
+
576
+ def to_local_average_cents(self, salience, thred=0.05):
577
+ center = np.argmax(salience, axis=1)
578
+ salience = np.pad(salience, ((0, 0), (4, 4)))
579
+ center += 4
580
+ todo_salience, todo_cents_mapping = [], []
581
+ starts = center - 4
582
+ ends = center + 5
583
+
584
+ for idx in range(salience.shape[0]):
585
+ todo_salience.append(salience[:, starts[idx] : ends[idx]][idx])
586
+ todo_cents_mapping.append(self.cents_mapping[starts[idx] : ends[idx]])
587
+
588
+ todo_salience = np.array(todo_salience)
589
+ devided = np.sum(todo_salience * np.array(todo_cents_mapping), 1) / np.sum(todo_salience, 1)
590
+ devided[np.max(salience, axis=1) <= thred] = 0
591
+
592
+ return devided