| import torch |
| import torchaudio |
| from transformers import Pipeline |
| from librosa import resample |
| from soundfile import write |
| from sgmse.model import ScoreModel |
| from sgmse.util.other import pad_spec |
|
|
class CustomSpeechEnhancementPipeline(Pipeline):
    """Speech-enhancement pipeline built around a ``ScoreModel``.

    The pipeline loads an audio file, resamples it to ``target_sr`` when
    needed, peak-normalizes it, converts it to a padded spectrogram, runs the
    model's predictor-corrector sampler and converts the result back to a
    waveform at the original level.

    ``args`` must be supplied before the pipeline is run: ``preprocess`` reads
    ``args.device`` and ``_forward`` reads ``args.device``, ``args.corrector``,
    ``args.N``, ``args.corrector_steps`` and ``args.snr``.
    """

    # Maps the padding-mode names used by this pipeline to the mode names
    # accepted by ``torch.nn.functional.pad``. "zero_pad" is this class's
    # default; the other two presumably mirror the names accepted by
    # ``sgmse.util.other.pad_spec`` -- TODO confirm against that helper.
    _TORCH_PAD_MODES = {
        "zero_pad": "constant",
        "reflection": "reflect",
        "replication": "replicate",
    }

    def __init__(self, model, target_sr=16000, pad_mode="zero_pad", args=None):
        """
        Custom pipeline for speech enhancement using ScoreModel.

        Args:
            model: The speech enhancement model loaded from a checkpoint (ScoreModel).
            target_sr: Target sample rate for the input audio (default is 16 kHz).
            pad_mode: Padding mode for spectrogram (default is "zero_pad").
            args: Parsed arguments (device, corrector, corrector_steps, snr, etc.).
        """
        # NOTE(review): depending on the transformers version, Pipeline.__init__
        # may expect ``model.config`` to exist -- confirm that ScoreModel
        # checkpoints satisfy this.
        super().__init__(model=model)
        self.target_sr = target_sr
        self.pad_mode = pad_mode
        self.args = args

    def _sanitize_parameters(self, **kwargs):
        """Split call-time keyword arguments between the three pipeline stages.

        ``Pipeline`` requires every subclass to implement this hook. None of
        the stages of this pipeline accepts extra keyword arguments, so three
        empty dictionaries (preprocess, forward, postprocess) are returned.
        """
        return {}, {}, {}

    def preprocess(self, audio_path):
        """Load an audio file and turn it into a padded model input.

        Args:
            audio_path: Location of the audio file, passed to ``torchaudio.load``.

        Returns:
            A tuple ``(Y, norm_factor, T_orig)`` where ``Y`` is the padded
            spectrogram with a leading dimension of size 1 added,
            ``norm_factor`` is the scalar tensor the waveform was divided by,
            and ``T_orig`` is ``waveform.size(1)`` (the number of samples,
            assuming torchaudio's default channels-first layout).
        """
        waveform, sr = torchaudio.load(audio_path)

        # The model is fed audio at a fixed rate; resample only on mismatch.
        if sr != self.target_sr:
            waveform = torch.tensor(
                resample(waveform.numpy(), orig_sr=sr, target_sr=self.target_sr)
            )

        # Peak-normalize. A completely silent clip has a peak of 0; dividing
        # by it would fill the signal with NaNs, so such input is left
        # unscaled (factor 1) instead.
        peak = waveform.abs().max()
        norm_factor = peak if peak > 0 else torch.ones_like(peak)
        waveform = waveform / norm_factor

        spec = self.model._forward_transform(
            self.model._stft(waveform.to(self.args.device))
        )
        Y = torch.unsqueeze(spec, 0)
        # The bare name resolves to the module-level ``pad_spec`` helper
        # imported from sgmse, not to the method of the same name below.
        Y = pad_spec(Y, mode=self.pad_mode)

        return Y, norm_factor, waveform.size(1)

    def _forward(self, model_inputs):
        """Run the reverse-diffusion sampler and reconstruct the waveform.

        Args:
            model_inputs: The ``(Y, norm_factor, T_orig)`` tuple produced by
                ``preprocess``.

        Returns:
            The enhanced waveform tensor, scaled back by ``norm_factor``.
        """
        Y, norm_factor, T_orig = model_inputs

        sampler = self.model.get_pc_sampler(
            "reverse_diffusion",
            self.args.corrector,
            Y.to(self.args.device),
            N=self.args.N,
            corrector_steps=self.args.corrector_steps,
            snr=self.args.snr,
        )

        # The sampler returns a pair; only the first element (the sample) is used.
        sample, _ = sampler()

        x_hat = self.model.to_audio(sample.squeeze(), T_orig)

        # Undo the peak normalization applied in ``preprocess``.
        return x_hat * norm_factor

    def postprocess(self, model_outputs):
        """Move the enhanced waveform to the CPU and return it as a NumPy array."""
        return model_outputs.cpu().numpy()

    def pad_spec(self, Y):
        """
        Apply padding to the spectrogram as per the model's required padding mode.

        One element is appended at the end of the second-to-last dimension;
        the last dimension is left unchanged.

        ``self.pad_mode`` is translated to the corresponding
        ``torch.nn.functional.pad`` mode first (for example "zero_pad" becomes
        "constant"), because ``torch.nn.functional.pad`` rejects the
        pipeline-level names. Values without a translation are passed through
        unchanged, so native modes such as "constant" or "reflect" keep working.

        Args:
            Y: Input spectrogram tensor.

        Returns:
            Padded spectrogram.
        """
        torch_mode = self._TORCH_PAD_MODES.get(self.pad_mode, self.pad_mode)
        return torch.nn.functional.pad(Y, (0, 0, 0, 1), mode=torch_mode)
|
|