| import os | |
| import sys | |
| import torch | |
| import torch.nn as nn | |
| sys.path.append(os.getcwd()) | |
| from infer.lib.predictors.RMVPE.yolo import YOLO13Encoder, YOLO13FullPADDecoder, HyperACE | |
class ConvBlockRes(nn.Module):
    """Residual block made of two 3x3 convolution stages.

    Each stage is ``Conv2d`` (3x3, stride 1, padding 1, no bias) followed by
    ``BatchNorm2d`` and ``ReLU``. The block input is added to the output of
    the second stage. When ``in_channels`` differs from ``out_channels`` the
    input is first passed through a 1x1 convolution so the two terms have the
    same number of channels.

    Args:
        in_channels: Number of channels of the input.
        out_channels: Number of channels produced by the block.
        momentum: Momentum given to both ``BatchNorm2d`` layers.

    Attributes:
        conv: ``nn.Sequential`` holding the two conv/BN/ReLU stages.
        is_shortcut: ``True`` when the 1x1 projection ``shortcut`` exists.
        shortcut: 1x1 ``Conv2d``; only created when the channel counts differ.
    """

    def __init__(self, in_channels, out_channels, momentum=0.01):
        super().__init__()

        # Stage one maps in -> out channels; stage two keeps out -> out.
        # Layers are appended in a flat list so the Sequential indices
        # (0, 1, 2, 3, 4, 5) stay the same as a hand-written sequence.
        stages = []
        for stage_in in (in_channels, out_channels):
            stages.append(
                nn.Conv2d(
                    in_channels=stage_in,
                    out_channels=out_channels,
                    kernel_size=(3, 3),
                    stride=(1, 1),
                    padding=(1, 1),
                    bias=False,
                )
            )
            stages.append(nn.BatchNorm2d(out_channels, momentum=momentum))
            stages.append(nn.ReLU())
        self.conv = nn.Sequential(*stages)

        # A projection is only needed when the residual cannot be added as-is.
        self.is_shortcut = in_channels != out_channels
        if self.is_shortcut:
            self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))

    def forward(self, x):
        """Return ``conv(x)`` plus the (optionally projected) input."""
        out = self.conv(x)
        if self.is_shortcut:
            return out + self.shortcut(x)
        return out + x
class ResEncoderBlock(nn.Module):
    """Chain of ``ConvBlockRes`` blocks with optional average pooling.

    The first block maps ``in_channels`` to ``out_channels``; every further
    block keeps ``out_channels``.

    Args:
        in_channels: Number of channels of the input.
        out_channels: Number of channels produced by the block chain.
        kernel_size: Kernel size for ``nn.AvgPool2d``. ``None`` disables
            pooling, in which case no ``pool`` attribute is created.
        n_blocks: Number of residual blocks applied in ``forward``.
        momentum: Batch-norm momentum forwarded to every ``ConvBlockRes``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        n_blocks=1,
        momentum=0.01,
    ):
        super().__init__()
        self.n_blocks = n_blocks
        self.kernel_size = kernel_size

        # One channel-changing block is always built, then n_blocks - 1
        # channel-preserving ones.
        blocks = [ConvBlockRes(in_channels, out_channels, momentum)]
        blocks += [
            ConvBlockRes(out_channels, out_channels, momentum)
            for _ in range(n_blocks - 1)
        ]
        self.conv = nn.ModuleList(blocks)

        if kernel_size is not None:
            self.pool = nn.AvgPool2d(kernel_size=kernel_size)

    def forward(self, x):
        """Apply the residual blocks, then pool if pooling is configured.

        Returns:
            ``(x, pooled)`` when ``kernel_size`` is not ``None`` (the
            un-pooled tensor first), otherwise just ``x``.
        """
        # Pairing with range(n_blocks) caps the number of applied blocks at
        # exactly n_blocks, matching the configured count.
        for _, block in zip(range(self.n_blocks), self.conv):
            x = block(x)

        if self.kernel_size is None:
            return x
        return x, self.pool(x)
class Encoder(nn.Module):
    """Batch-norm followed by a stack of ``ResEncoderBlock`` stages.

    The channel count doubles from one stage to the next, starting at
    ``out_channels`` for the first stage.

    Args:
        in_channels: Number of channels of the input.
        in_size: Starting size; only used to derive ``out_size``.
        n_encoders: Number of encoder stages.
        kernel_size: Pooling kernel size given to every stage.
        n_blocks: Number of residual blocks per stage.
        out_channels: Output channels of the first stage.
        momentum: Batch-norm momentum used throughout.

    Attributes:
        out_size: ``in_size`` floor-halved once per stage.
        out_channel: Output channels of the first stage doubled once per
            stage, i.e. twice the output channels of the last stage.
    """

    def __init__(
        self,
        in_channels,
        in_size,
        n_encoders,
        kernel_size,
        n_blocks,
        out_channels=16,
        momentum=0.01,
    ):
        super().__init__()
        self.n_encoders = n_encoders
        self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)

        stages = []
        stage_in, stage_out, size = in_channels, out_channels, in_size
        for _ in range(n_encoders):
            stages.append(
                ResEncoderBlock(
                    stage_in,
                    stage_out,
                    kernel_size,
                    n_blocks,
                    momentum=momentum,
                )
            )
            # The next stage consumes what this one produced and doubles it.
            stage_in, stage_out = stage_out, stage_out * 2
            size //= 2
        self.layers = nn.ModuleList(stages)

        self.out_size = size
        self.out_channel = stage_out

    def forward(self, x):
        """Run all stages and collect their pre-pooling outputs.

        Each stage must return a pair (un-pooled, pooled), which
        ``ResEncoderBlock`` does when its ``kernel_size`` is not ``None``.

        Returns:
            A tuple ``(x, concat_tensors)``: the output of the last stage
            after pooling, and the list of un-pooled outputs in stage order.
        """
        x = self.bn(x)
        skips = []
        for stage in self.layers:
            pre_pool, x = stage(x)
            skips.append(pre_pool)
        return x, skips
class Intermediate(nn.Module):
    """Sequence of non-pooling ``ResEncoderBlock`` layers.

    The first layer maps ``in_channels`` to ``out_channels``; the remaining
    ``n_inters - 1`` layers keep ``out_channels``. Every layer is created
    with ``kernel_size=None`` so it returns a single tensor.

    Args:
        in_channels: Number of channels of the input.
        out_channels: Number of channels produced by every layer.
        n_inters: Total number of layers.
        n_blocks: Number of residual blocks inside each layer.
        momentum: Batch-norm momentum forwarded to every layer.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        n_inters,
        n_blocks,
        momentum=0.01,
    ):
        super().__init__()
        stages = [
            ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum)
        ]
        stages += [
            ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum)
            for _ in range(n_inters - 1)
        ]
        self.layers = nn.ModuleList(stages)

    def forward(self, x):
        """Feed ``x`` through every layer in order and return the result."""
        for stage in self.layers:
            x = stage(x)
        return x
class ResDecoderBlock(nn.Module):
    """Upsampling block: transposed conv, skip concatenation, residual convs.

    ``conv1`` is a 3x3 ``ConvTranspose2d`` (padding 1, no bias) followed by
    ``BatchNorm2d`` and ``ReLU``. Its output is concatenated with a skip
    tensor along the channel axis and passed through ``n_blocks``
    ``ConvBlockRes`` blocks. The first of those takes ``out_channels * 2``
    input channels, so the skip tensor is expected to carry ``out_channels``
    channels.

    Args:
        in_channels: Number of channels of the input.
        out_channels: Number of channels produced by the block.
        stride: Stride of the transposed convolution. A stride equal to
            ``(1, 2)`` (given as a tuple or a list) uses
            ``output_padding=(0, 1)``; any other value uses ``(1, 1)``.
        n_blocks: Number of residual blocks applied after concatenation.
        momentum: Batch-norm momentum used throughout.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        stride,
        n_blocks=1,
        momentum=0.01,
    ):
        super().__init__()

        # Normalise sequences to a tuple before comparing: a list such as
        # [1, 2] (e.g. loaded from a JSON/YAML config) never equals the tuple
        # (1, 2) and would otherwise pick output_padding=(1, 1), which is not
        # smaller than the stride along the first axis.
        stride_key = tuple(stride) if isinstance(stride, (list, tuple)) else stride
        out_padding = (0, 1) if stride_key == (1, 2) else (1, 1)

        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(3, 3),
                stride=stride,
                padding=(1, 1),
                output_padding=out_padding,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU(),
        )

        # The first residual block sees upsampled + skip channels; the rest
        # keep out_channels.
        self.conv2 = nn.ModuleList()
        self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
        for _ in range(n_blocks - 1):
            self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))

    def forward(self, x, concat_tensor):
        """Upsample ``x``, concatenate ``concat_tensor`` on dim 1, refine.

        Args:
            x: Input to the transposed convolution.
            concat_tensor: Skip tensor; must match the upsampled tensor in
                every dimension except dim 1 for ``torch.cat`` to succeed.

        Returns:
            Output of the last residual block.
        """
        x = torch.cat((self.conv1(x), concat_tensor), dim=1)
        for conv2 in self.conv2:
            x = conv2(x)
        return x
class Decoder(nn.Module):
    """Stack of ``ResDecoderBlock`` layers that halve the channel count.

    Layer ``k`` maps its input channels to ``channels // 2``; the next layer
    starts from that halved count.

    Args:
        in_channels: Number of channels entering the first layer.
        n_decoders: Number of decoder layers.
        stride: Stride forwarded to every ``ResDecoderBlock``.
        n_blocks: Number of residual blocks per layer.
        momentum: Batch-norm momentum forwarded to every layer.
    """

    def __init__(
        self,
        in_channels,
        n_decoders,
        stride,
        n_blocks,
        momentum=0.01,
    ):
        super().__init__()
        self.layers = nn.ModuleList()

        channels = in_channels
        for _ in range(n_decoders):
            halved = channels // 2
            self.layers.append(
                ResDecoderBlock(channels, halved, stride, n_blocks, momentum)
            )
            channels = halved

    def forward(self, x, concat_tensors):
        """Run the layers, pairing each with a skip tensor from the end.

        Args:
            x: Input to the first decoder layer.
            concat_tensors: Indexable collection of skip tensors. The first
                layer uses the last entry, the second layer the one before
                it, and so on.

        Returns:
            Output of the last decoder layer.
        """
        for depth, layer in enumerate(self.layers):
            # Walk the skip tensors backwards: -1, -2, -3, ...
            x = layer(x, concat_tensors[-(depth + 1)])
        return x
class DeepUnet(nn.Module):
    """Encoder / intermediate / decoder network with skip connections.

    Args:
        kernel_size: Pooling kernel size for the encoder; the same value is
            passed to the decoder as its stride.
        n_blocks: Number of residual blocks per encoder, intermediate and
            decoder layer.
        en_de_layers: Number of encoder stages and of decoder layers.
        inter_layers: Number of intermediate layers.
        in_channels: Number of channels of the input.
        en_out_channels: Output channels of the first encoder stage.
    """

    def __init__(
        self,
        kernel_size,
        n_blocks,
        en_de_layers=5,
        inter_layers=4,
        in_channels=1,
        en_out_channels=16,
    ):
        super().__init__()

        # 128 is the encoder's ``in_size`` argument.
        self.encoder = Encoder(
            in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels
        )

        # ``out_channel`` is twice the channel count of the encoder's last
        # stage, hence the halving for the intermediate input.
        bottleneck_channels = self.encoder.out_channel
        self.intermediate = Intermediate(
            bottleneck_channels // 2, bottleneck_channels, inter_layers, n_blocks
        )
        self.decoder = Decoder(
            bottleneck_channels, en_de_layers, kernel_size, n_blocks
        )

    def forward(self, x):
        """Encode ``x``, refine the bottleneck, and decode with the skips."""
        encoded, skips = self.encoder(x)
        bottleneck = self.intermediate(encoded)
        return self.decoder(bottleneck, skips)
class HPADeepUnet(nn.Module):
    """Network built from ``YOLO13Encoder``, ``HyperACE`` and
    ``YOLO13FullPADDecoder``, with the output resized to the input size.

    Args:
        in_channels: Passed to ``YOLO13Encoder`` as its first argument.
        en_out_channels: Passed to the decoder as ``out_channels_final``.
        base_channels: Passed to ``YOLO13Encoder`` as its second argument.
        hyperace_k: Passed to ``HyperACE`` as ``k``.
        hyperace_l: Passed to ``HyperACE`` as ``l``.
        num_hyperedges: Passed to ``HyperACE`` as ``num_hyperedges``.
        num_heads: Passed to ``HyperACE`` as ``num_heads``.
    """

    def __init__(
        self,
        in_channels=1,
        en_out_channels=16,
        base_channels=64,
        hyperace_k=2,
        hyperace_l=1,
        num_hyperedges=16,
        num_heads=8,
    ):
        super().__init__()
        self.encoder = YOLO13Encoder(in_channels, base_channels)

        # ``out_channels`` is indexed with [-1], so it is a sequence
        # (presumably one entry per encoder feature level; confirm in the
        # yolo module).
        enc_ch = self.encoder.out_channels
        last_ch = enc_ch[-1]

        self.hyperace = HyperACE(
            in_channels=enc_ch,
            out_channels=last_ch,
            num_hyperedges=num_hyperedges,
            num_heads=num_heads,
            k=hyperace_k,
            l=hyperace_l,
        )
        self.decoder = YOLO13FullPADDecoder(
            encoder_channels=enc_ch,
            hyperace_out_c=last_ch,
            out_channels_final=en_out_channels,
        )

    def forward(self, x):
        """Encode, fuse, decode, then bilinearly resize to ``x.shape[2:]``.

        ``x`` must be 4-D, since bilinear interpolation takes a
        two-element target size.
        """
        features = self.encoder(x)
        fused = self.hyperace(features)
        decoded = self.decoder(features, fused)
        return nn.functional.interpolate(
            decoded,
            size=x.shape[2:],
            mode='bilinear',
            align_corners=False,
        )