| import torch |
| import torch.nn as nn |
|
|
| from audio_denoiser.modules.Permute import Permute |
| from audio_denoiser.modules.SimpleRoberta import SimpleRoberta |
| from audio_denoiser.modules.SpectrogramScaler import SpectrogramScaler |
|
|
| import json |
|
|
|
|
class AudioNoiseModel(nn.Module):
    """Denoising network that sums the outputs of two parallel branches.

    * ``model1``: a stack of ``Conv1d`` layers with ``kernel_size=1`` and ELU
      activations, going ``in_channels -> 1024 -> 1024 -> in_channels``.
    * ``model2``: ``Permute(0, 2, 1)``, a ``Linear`` projection from
      ``in_channels`` to ``roberta_hidden_size``, a 5-layer ``SimpleRoberta``,
      a ``Linear`` projection back to ``in_channels``, and another
      ``Permute(0, 2, 1)``.

    Args:
        config: Model configuration dict. ``"scaler"`` is required and is
            handed to ``SpectrogramScaler.from_dict``. Optional keys:
            ``"in_channels"`` (default 257), ``"roberta_hidden_size"``
            (default 768), ``"sample_rate"`` (16000), ``"n_fft"`` (512),
            ``"num_frames"`` (32).
    """

    def __init__(self, config: dict):
        super().__init__()
        self.config = config
        self.scaler = SpectrogramScaler.from_dict(config["scaler"])

        channels = config.get("in_channels", 257)
        hidden = config.get("roberta_hidden_size", 768)
        self.in_channels = channels
        self.roberta_hidden_size = hidden

        # The attribute names ("model1"/"model2") and the layer order inside
        # each Sequential determine the state_dict keys (e.g. "model1.0.weight"),
        # so they must stay stable for saved checkpoints to load.
        self.model1 = nn.Sequential(
            nn.Conv1d(channels, 1024, kernel_size=1),
            nn.ELU(),
            nn.Conv1d(1024, 1024, kernel_size=1),
            nn.ELU(),
            nn.Conv1d(1024, channels, kernel_size=1),
        )
        self.model2 = nn.Sequential(
            Permute(0, 2, 1),
            nn.Linear(channels, hidden),
            SimpleRoberta(num_hidden_layers=5, hidden_size=hidden),
            nn.Linear(hidden, channels),
            Permute(0, 2, 1),
        )

    @property
    def sample_rate(self) -> int:
        """``config["sample_rate"]``, falling back to 16000."""
        return self.config.get("sample_rate", 16000)

    @property
    def n_fft(self) -> int:
        """``config["n_fft"]``, falling back to 512."""
        return self.config.get("n_fft", 512)

    @property
    def num_frames(self) -> int:
        """``config["num_frames"]``, falling back to 32."""
        return self.config.get("num_frames", 32)

    def forward(self, x, use_scaler: bool = False, out_scale: float = 1.0):
        """Run both branches on ``x`` and return their scaled sum.

        Args:
            x: Input tensor. Presumably (batch, in_channels, frames), since
                ``model1`` starts with ``Conv1d(in_channels, ...)``; TODO confirm.
            use_scaler: When True, ``x`` is first passed through ``self.scaler``.
            out_scale: Factor applied to the summed branch outputs.

        Returns:
            ``(model1(x') + model2(x')) * out_scale``, where ``x'`` is the
            (optionally scaled) input.
        """
        inp = self.scaler(x) if use_scaler else x
        conv_out = self.model1(inp)
        roberta_out = self.model2(inp)
        return (conv_out + roberta_out) * out_scale
|
|
|
|
def load_audio_denosier_model(dir_path: str, device) -> AudioNoiseModel:
    """Build an ``AudioNoiseModel`` from a checkpoint directory and move it to ``device``.

    The directory must contain:

    * ``config.json``: parsed with ``json.load`` and passed to ``AudioNoiseModel``.
    * ``pytorch_model.bin``: a state dict read with ``torch.load``.

    Args:
        dir_path: Directory holding the two files above.
        device: Target device, anything accepted by ``nn.Module.to``
            (e.g. ``"cpu"``, ``"cuda"``, a ``torch.device``).

    Returns:
        The model with its weights loaded, placed on ``device``. The module's
        train/eval mode is not changed here; call ``.eval()`` for inference if needed.
    """
    import os

    # Use a context manager so the config file handle is always closed.
    with open(os.path.join(dir_path, "config.json"), "r", encoding="utf-8") as f:
        config = json.load(f)

    model = AudioNoiseModel(config)

    # Map tensors to CPU first so a checkpoint saved on a GPU can still be read
    # on a CPU-only machine; model.to(device) below moves them to the target.
    # NOTE(review): torch.load unpickles the file. Only load trusted checkpoints
    # (consider weights_only=True where the installed torch version supports it).
    state_dict = torch.load(
        os.path.join(dir_path, "pytorch_model.bin"), map_location="cpu"
    )
    model.load_state_dict(state_dict)

    # Module.to() recursively moves all registered submodules (model1, model2,
    # ...), so no separate per-submodule .to() calls are needed.
    model.to(device)
    return model
|
|