Spaces:
Paused
Paused
| #!/usr/bin/python3 | |
| # -*- coding: utf-8 -*- | |
| """ | |
| https://github.com/jzi040941/PercepNet | |
| https://arxiv.org/abs/2008.04259 | |
| https://modelzoo.co/model/percepnet | |
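| This is too complicated. | |
| (1) The PyTorch model is only one part of the whole pipeline. | |
| (2) Training samples must first go through pitch analysis, spectral-envelope computation, and similar processing. | |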
| """ | |
| import torch | |
| import torch.nn as nn | |
class PercepNet(nn.Module):
    """
    Recurrent network following the PercepNet reference training script.

    Reference implementation:
    https://github.com/jzi040941/PercepNet/blob/main/rnn_train.py#L105

    Reported cost: 4.1% of an x86 CPU core.

    Pipeline: a per-frame dense layer, two causal 1-D convolutions, a stack
    of single-layer GRUs, and two sigmoid heads of 34 values each whose
    outputs are concatenated on the last axis.

    Args:
        input_dim: number of features per input frame.
    """

    def __init__(self, input_dim=70):
        super().__init__()

        # Per-frame feature projection.
        self.fc = nn.Sequential(
            nn.Linear(in_features=input_dim, out_features=128),
            nn.ReLU(),
        )

        # Padding is kernel_size - 1 (kept to align with the C++ DNN);
        # forward() trims the same number of trailing frames so every
        # output frame only sees the current and earlier input frames.
        self.conv1 = nn.Sequential(
            nn.Conv1d(
                in_channels=128,
                out_channels=512,
                kernel_size=5,
                stride=1,
                padding=4,
            ),
            nn.ReLU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv1d(
                in_channels=512,
                out_channels=512,
                kernel_size=3,
                stride=1,
                padding=2,
            ),
            nn.Tanh(),
        )

        # Three stacked GRUs kept as separate modules so that each
        # intermediate output can be fed to the "gb" head.
        self.gru1 = nn.GRU(
            input_size=512, hidden_size=512, num_layers=1, batch_first=True
        )
        self.gru2 = nn.GRU(
            input_size=512, hidden_size=512, num_layers=1, batch_first=True
        )
        self.gru3 = nn.GRU(
            input_size=512, hidden_size=512, num_layers=1, batch_first=True
        )

        # Branch-specific recurrent layers.
        self.gru_gb = nn.GRU(
            input_size=512, hidden_size=512, num_layers=1, batch_first=True
        )
        self.gru_rb = nn.GRU(
            input_size=1024, hidden_size=128, num_layers=1, batch_first=True
        )

        # "gb" head consumes five concatenated 512-wide feature streams.
        self.fc_gb = nn.Sequential(
            nn.Linear(in_features=512 * 5, out_features=34),
            nn.Sigmoid(),
        )
        # "rb" head consumes the 128-wide output of gru_rb.
        self.fc_rb = nn.Sequential(
            nn.Linear(in_features=128, out_features=34),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor):
        """
        Run the network on a batch of feature sequences.

        Args:
            x: tensor of shape [b, t, f] with f == input_dim.

        Returns:
            Tensor of shape [b, t, 68]: the 34 "gb" values followed by the
            34 "rb" values, each in (0, 1) because of the sigmoid heads.
        """
        # [b, t, f] -> [b, f, t] so Conv1d convolves along time.
        hidden = self.fc(x).permute([0, 2, 1])

        # Causal convolutions: drop the trailing frames created by the
        # symmetric padding (4 for conv1, 2 for conv2).
        hidden = self.conv1(hidden)[:, :, :-4]
        conv_feat = self.conv2(hidden)[:, :, :-2]

        # Back to [b, t, f] for the batch_first GRUs.
        conv_feat = conv_feat.permute([0, 2, 1])

        rnn1, _ = self.gru1(conv_feat)
        rnn2, _ = self.gru2(rnn1)
        rnn3, _ = self.gru3(rnn2)
        rnn_gb, _ = self.gru_gb(rnn3)

        gb_features = torch.cat(
            (conv_feat, rnn1, rnn2, rnn3, rnn_gb), dim=-1
        )
        gb = self.fc_gb(gb_features)

        # NOTE: the reference code marks this concatenation as
        # "concat rb need fix"; it is reproduced unchanged here.
        rb_features = torch.cat((rnn3, conv_feat), dim=-1)
        rnn_rb, _ = self.gru_rb(rb_features)
        rb = self.fc_rb(rnn_rb)

        return torch.cat((gb, rb), dim=-1)
def main():
    """Smoke test: push a random batch through PercepNet and print the output shape."""
    net = PercepNet()

    # 20 sequences of 8 frames, 70 features per frame.
    features = torch.randn(20, 8, 70)

    print(net(features).shape)


# Run the smoke test only when executed as a script.
if __name__ == "__main__":
    main()