| import time |
| from math import ceil |
| import warnings |
|
|
| import torch |
| import pytorch_lightning as pl |
| from torch_ema import ExponentialMovingAverage |
|
|
| from sgmse import sampling |
| from sgmse.sdes import SDERegistry |
| from sgmse.backbones import BackboneRegistry |
| from sgmse.util.inference import evaluate_model |
| from sgmse.util.other import pad_spec |
|
|
|
|
class ScoreModel(pl.LightningModule):
    """Score-based generative model for speech enhancement.

    Wraps a backbone DNN (score network) and an SDE (diffusion process),
    trains with denoising score matching, and provides PC/ODE samplers
    plus a one-call `enhance()` convenience method.
    """

    @staticmethod
    def add_argparse_args(parser):
        """Register this model's hyperparameters as command-line arguments on `parser`."""
        parser.add_argument("--lr", type=float, default=1e-4, help="The learning rate (1e-4 by default)")
        parser.add_argument("--ema_decay", type=float, default=0.999, help="The parameter EMA decay constant (0.999 by default)")
        parser.add_argument("--t_eps", type=float, default=0.03, help="The minimum process time (0.03 by default)")
        parser.add_argument("--num_eval_files", type=int, default=20, help="Number of files for speech enhancement performance evaluation during training. Pass 0 to turn off (no checkpoints based on evaluation metrics will be generated).")
        parser.add_argument("--loss_type", type=str, default="mse", choices=("mse", "mae"), help="The type of loss function to use.")
        return parser

    def __init__(
        self, backbone, sde, lr=1e-4, ema_decay=0.999, t_eps=0.03,
        num_eval_files=20, loss_type='mse', data_module_cls=None, **kwargs
    ):
        """
        Create a new ScoreModel.

        Args:
            backbone: Registry name of the backbone DNN that serves as a score-based model.
            sde: Registry name of the SDE that defines the diffusion process.
            lr: The learning rate of the optimizer. (1e-4 by default).
            ema_decay: The decay constant of the parameter EMA (0.999 by default).
            t_eps: The minimum time to practically run for to avoid issues very close to zero (0.03 by default).
            num_eval_files: Number of files used for evaluation during training; 0 disables evaluation.
            loss_type: The type of loss to use (wrt. noise z/std). Options are 'mse' (default), 'mae'.
            data_module_cls: LightningDataModule class, instantiated from the remaining **kwargs.
        """
        super().__init__()
        # Initialize the backbone DNN from the registry.
        self.backbone = backbone
        dnn_cls = BackboneRegistry.get_by_name(backbone)
        self.dnn = dnn_cls(**kwargs)
        # Initialize the SDE that defines the forward diffusion process.
        sde_cls = SDERegistry.get_by_name(sde)
        self.sde = sde_cls(**kwargs)
        # Store hyperparams and set up EMA of the model parameters.
        self.lr = lr
        self.ema_decay = ema_decay
        self.ema = ExponentialMovingAverage(self.parameters(), decay=self.ema_decay)
        self._error_loading_ema = False  # set True if a checkpoint lacks EMA state
        self.t_eps = t_eps
        self.loss_type = loss_type
        self.num_eval_files = num_eval_files

        self.save_hyperparameters(ignore=['no_wandb'])
        self.data_module = data_module_cls(**kwargs, gpu=kwargs.get('gpus', 0) > 0)

    def configure_optimizers(self):
        """Set up the Adam optimizer with the configured learning rate."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

    def optimizer_step(self, *args, **kwargs):
        """Run the usual optimizer step, then update the parameter EMA."""
        # Method overridden so that the EMA params are updated after each optimizer step
        super().optimizer_step(*args, **kwargs)
        self.ema.update(self.parameters())

    # on_load_checkpoint / on_save_checkpoint needed for EMA storing/loading
    def on_load_checkpoint(self, checkpoint):
        """Restore the EMA state from the checkpoint, warning if it is missing."""
        ema = checkpoint.get('ema', None)
        if ema is not None:
            self.ema.load_state_dict(checkpoint['ema'])
        else:
            self._error_loading_ema = True
            warnings.warn("EMA state_dict not found in checkpoint!")

    def on_save_checkpoint(self, checkpoint):
        """Store the EMA state into the checkpoint."""
        checkpoint['ema'] = self.ema.state_dict()

    def train(self, mode=True, no_ema=False):
        """Switch between train/eval mode, handling the EMA parameter swap.

        Entering eval mode stores the current parameters and substitutes
        the EMA parameters; returning to train mode restores them.
        `mode` defaults to True to match `torch.nn.Module.train`, so bare
        `model.train()` calls from framework internals keep working.
        """
        res = super().train(mode)
        if not self._error_loading_ema:
            if not mode and not no_ema:
                # eval: swap in the averaged (EMA) parameters for evaluation
                self.ema.store(self.parameters())
                self.ema.copy_to(self.parameters())
            else:
                # train: restore the stored (raw) parameters, if any were stored
                if self.ema.collected_params is not None:
                    self.ema.restore(self.parameters())
        return res

    def eval(self, no_ema=False):
        """Switch to eval mode; pass `no_ema=True` to keep the raw (non-EMA) weights."""
        return self.train(False, no_ema=no_ema)

    def _loss(self, err):
        """Reduce the complex-valued error to a scalar loss.

        Computes 0.5 * sum over all non-batch dimensions, then the mean
        over the batch ('mse' squares the magnitudes first).
        """
        if self.loss_type == 'mse':
            losses = torch.square(err.abs())
        elif self.loss_type == 'mae':
            losses = err.abs()
        else:
            # Guard against an unknown loss type leaking through as a NameError.
            raise ValueError(f"Unknown loss type: {self.loss_type!r}")

        loss = torch.mean(0.5*torch.sum(losses.reshape(losses.shape[0], -1), dim=-1))
        return loss

    def _step(self, batch, batch_idx):
        """Denoising score matching step: perturb clean x toward noisy y, score it, compute the loss."""
        x, y = batch
        # Sample t uniformly in [t_eps, T] to avoid numerical issues near t=0.
        t = torch.rand(x.shape[0], device=x.device) * (self.sde.T - self.t_eps) + self.t_eps
        mean, std = self.sde.marginal_prob(x, t, y)
        z = torch.randn_like(x)  # complex-valued? NOTE(review): assumed same dtype as x
        sigmas = std[:, None, None, None]
        perturbed_data = mean + sigmas * z
        score = self(perturbed_data, t, y)
        # DSM objective: score should approximate -z/sigma, i.e. err -> 0.
        err = score * sigmas + z
        loss = self._loss(err)
        return loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        loss = self._step(batch, batch_idx)
        self.log('train_loss', loss, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute/log validation loss; on the first batch, run full-metric evaluation."""
        loss = self._step(batch, batch_idx)
        self.log('valid_loss', loss, on_step=False, on_epoch=True)

        # Evaluate speech enhancement performance (PESQ, SI-SDR, ESTOI)
        if batch_idx == 0 and self.num_eval_files != 0:
            pesq, si_sdr, estoi = evaluate_model(self, self.num_eval_files)
            self.log('pesq', pesq, on_step=False, on_epoch=True)
            self.log('si_sdr', si_sdr, on_step=False, on_epoch=True)
            self.log('estoi', estoi, on_step=False, on_epoch=True)

        return loss

    def forward(self, x, t, y):
        """Evaluate the score at (x, t), conditioned on the noisy input y."""
        # Concatenate y as an extra channel
        dnn_input = torch.cat([x, y], dim=1)

        # The minus is most likely a convention (the score is the negated network output).
        score = -self.dnn(dnn_input, t)
        return score

    def to(self, *args, **kwargs):
        """Override PyTorch .to() to also transfer the EMA of the model weights"""
        self.ema.to(*args, **kwargs)
        return super().to(*args, **kwargs)

    def get_pc_sampler(self, predictor_name, corrector_name, y, N=None, minibatch=None, **kwargs):
        """Build a Predictor-Corrector sampler conditioned on y.

        If `minibatch` is given, returns a function that samples y in
        chunks of that size and concatenates the results.
        """
        N = self.sde.N if N is None else N
        sde = self.sde.copy()
        sde.N = N

        kwargs = {"eps": self.t_eps, **kwargs}
        if minibatch is None:
            return sampling.get_pc_sampler(predictor_name, corrector_name, sde=sde, score_fn=self, y=y, **kwargs)
        else:
            M = y.shape[0]
            def batched_sampling_fn():
                samples, ns = [], []
                for i in range(int(ceil(M / minibatch))):
                    y_mini = y[i*minibatch:(i+1)*minibatch]
                    sampler = sampling.get_pc_sampler(predictor_name, corrector_name, sde=sde, score_fn=self, y=y_mini, **kwargs)
                    sample, n = sampler()
                    samples.append(sample)
                    ns.append(n)
                samples = torch.cat(samples, dim=0)
                return samples, ns
            return batched_sampling_fn

    def get_ode_sampler(self, y, N=None, minibatch=None, **kwargs):
        """Build an ODE (probability flow) sampler conditioned on y.

        If `minibatch` is given, returns a function that samples y in
        chunks of that size and concatenates the results.
        """
        N = self.sde.N if N is None else N
        sde = self.sde.copy()
        sde.N = N

        kwargs = {"eps": self.t_eps, **kwargs}
        if minibatch is None:
            return sampling.get_ode_sampler(sde, self, y=y, **kwargs)
        else:
            M = y.shape[0]
            def batched_sampling_fn():
                samples, ns = [], []
                for i in range(int(ceil(M / minibatch))):
                    y_mini = y[i*minibatch:(i+1)*minibatch]
                    sampler = sampling.get_ode_sampler(sde, self, y=y_mini, **kwargs)
                    sample, n = sampler()
                    samples.append(sample)
                    ns.append(n)
                samples = torch.cat(samples, dim=0)
                # BUGFIX: return the concatenated batch, not just the last minibatch.
                return samples, ns
            return batched_sampling_fn

    def train_dataloader(self):
        """Delegate to the wrapped data module."""
        return self.data_module.train_dataloader()

    def val_dataloader(self):
        """Delegate to the wrapped data module."""
        return self.data_module.val_dataloader()

    def test_dataloader(self):
        """Delegate to the wrapped data module."""
        return self.data_module.test_dataloader()

    def setup(self, stage=None):
        """Delegate dataset setup to the wrapped data module."""
        return self.data_module.setup(stage=stage)

    def to_audio(self, spec, length=None):
        """Convert a (transformed) spectrogram back to a time-domain signal."""
        return self._istft(self._backward_transform(spec), length)

    def _forward_transform(self, spec):
        """Apply the data module's forward spectrogram transform (e.g. compression)."""
        return self.data_module.spec_fwd(spec)

    def _backward_transform(self, spec):
        """Invert the data module's spectrogram transform."""
        return self.data_module.spec_back(spec)

    def _stft(self, sig):
        """Short-time Fourier transform via the data module."""
        return self.data_module.stft(sig)

    def _istft(self, spec, length=None):
        """Inverse STFT via the data module, optionally trimmed to `length` samples."""
        return self.data_module.istft(spec, length)

    def enhance(self, y, sampler_type="pc", predictor="reverse_diffusion",
        corrector="ald", N=30, corrector_steps=1, snr=0.5, timeit=False,
        **kwargs
    ):
        """
        One-call speech enhancement of noisy speech `y`, for convenience.

        Args:
            y: Noisy waveform tensor of shape (channels, samples). NOTE(review): assumed — verify against callers.
            sampler_type: "pc" (predictor-corrector) or "ode".
            predictor, corrector, N, corrector_steps, snr: sampler settings.
            timeit: If True, also return the number of function evaluations and the real-time factor.

        Returns:
            The enhanced waveform as a numpy array; if `timeit`, a tuple (x_hat, nfe, rtf).

        Raises:
            ValueError: If `sampler_type` is not "pc" or "ode".
        """
        sr = 16000  # assumed sample rate, used only for the real-time-factor computation
        start = time.time()
        T_orig = y.size(1)
        # Normalize to unit max magnitude; undone after sampling.
        norm_factor = y.abs().max().item()
        y = y / norm_factor
        Y = torch.unsqueeze(self._forward_transform(self._stft(y.cuda())), 0)
        Y = pad_spec(Y)
        if sampler_type == "pc":
            sampler = self.get_pc_sampler(predictor, corrector, Y.cuda(), N=N,
                corrector_steps=corrector_steps, snr=snr, intermediate=False,
                **kwargs)
        elif sampler_type == "ode":
            sampler = self.get_ode_sampler(Y.cuda(), N=N, **kwargs)
        else:
            # BUGFIX: previously only printed, then crashed with UnboundLocalError.
            raise ValueError("{} is not a valid sampler type!".format(sampler_type))
        sample, nfe = sampler()
        x_hat = self.to_audio(sample.squeeze(), T_orig)
        x_hat = x_hat * norm_factor
        x_hat = x_hat.squeeze().cpu().numpy()
        end = time.time()
        if timeit:
            rtf = (end-start)/(len(x_hat)/sr)
            return x_hat, nfe, rtf
        else:
            return x_hat
|
|