| | import time
|
| | from math import ceil
|
| | import warnings
|
| |
|
| | import torch
|
| | import pytorch_lightning as pl
|
| | from torch_ema import ExponentialMovingAverage
|
| |
|
| | from sgmse import sampling
|
| | from sgmse.sdes import SDERegistry
|
| | from sgmse.backbones import BackboneRegistry
|
| | from sgmse.util.inference import evaluate_model
|
| | from sgmse.util.other import pad_spec
|
| |
|
| |
|
| | class ScoreModel(pl.LightningModule):
|
| | @staticmethod
|
| | def add_argparse_args(parser):
|
| | parser.add_argument("--lr", type=float, default=1e-4, help="The learning rate (1e-4 by default)")
|
| | parser.add_argument("--ema_decay", type=float, default=0.999, help="The parameter EMA decay constant (0.999 by default)")
|
| | parser.add_argument("--t_eps", type=float, default=0.03, help="The minimum process time (0.03 by default)")
|
| | parser.add_argument("--num_eval_files", type=int, default=20, help="Number of files for speech enhancement performance evaluation during training. Pass 0 to turn off (no checkpoints based on evaluation metrics will be generated).")
|
| | parser.add_argument("--loss_type", type=str, default="mse", choices=("mse", "mae"), help="The type of loss function to use.")
|
| | return parser
|
| |
|
| | def __init__(
|
| | self, backbone, sde, lr=1e-4, ema_decay=0.999, t_eps=0.03,
|
| | num_eval_files=20, loss_type='mse', data_module_cls=None, **kwargs
|
| | ):
|
| | """
|
| | Create a new ScoreModel.
|
| |
|
| | Args:
|
| | backbone: Backbone DNN that serves as a score-based model.
|
| | sde: The SDE that defines the diffusion process.
|
| | lr: The learning rate of the optimizer. (1e-4 by default).
|
| | ema_decay: The decay constant of the parameter EMA (0.999 by default).
|
| | t_eps: The minimum time to practically run for to avoid issues very close to zero (1e-5 by default).
|
| | loss_type: The type of loss to use (wrt. noise z/std). Options are 'mse' (default), 'mae'
|
| | """
|
| | super().__init__()
|
| |
|
| | self.backbone = backbone
|
| | dnn_cls = BackboneRegistry.get_by_name(backbone)
|
| | self.dnn = dnn_cls(**kwargs)
|
| |
|
| | sde_cls = SDERegistry.get_by_name(sde)
|
| | self.sde = sde_cls(**kwargs)
|
| |
|
| | self.lr = lr
|
| | self.ema_decay = ema_decay
|
| | self.ema = ExponentialMovingAverage(self.parameters(), decay=self.ema_decay)
|
| | self._error_loading_ema = False
|
| | self.t_eps = t_eps
|
| | self.loss_type = loss_type
|
| | self.num_eval_files = num_eval_files
|
| |
|
| | self.save_hyperparameters(ignore=['no_wandb'])
|
| | self.data_module = data_module_cls(**kwargs, gpu=kwargs.get('gpus', 0) > 0)
|
| |
|
| | def configure_optimizers(self):
|
| | optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
|
| | return optimizer
|
| |
|
| | def optimizer_step(self, *args, **kwargs):
|
| |
|
| | super().optimizer_step(*args, **kwargs)
|
| | self.ema.update(self.parameters())
|
| |
|
| |
|
| | def on_load_checkpoint(self, checkpoint):
|
| | ema = checkpoint.get('ema', None)
|
| | if ema is not None:
|
| | self.ema.load_state_dict(checkpoint['ema'])
|
| | else:
|
| | self._error_loading_ema = True
|
| | warnings.warn("EMA state_dict not found in checkpoint!")
|
| |
|
| | def on_save_checkpoint(self, checkpoint):
|
| | checkpoint['ema'] = self.ema.state_dict()
|
| |
|
| | def train(self, mode, no_ema=False):
|
| | res = super().train(mode)
|
| | if not self._error_loading_ema:
|
| | if mode == False and not no_ema:
|
| |
|
| | self.ema.store(self.parameters())
|
| | self.ema.copy_to(self.parameters())
|
| | else:
|
| |
|
| | if self.ema.collected_params is not None:
|
| | self.ema.restore(self.parameters())
|
| | return res
|
| |
|
| | def eval(self, no_ema=False):
|
| | return self.train(False, no_ema=no_ema)
|
| |
|
| | def _loss(self, err):
|
| | if self.loss_type == 'mse':
|
| | losses = torch.square(err.abs())
|
| | elif self.loss_type == 'mae':
|
| | losses = err.abs()
|
| |
|
| |
|
| | loss = torch.mean(0.5*torch.sum(losses.reshape(losses.shape[0], -1), dim=-1))
|
| | return loss
|
| |
|
| | def _step(self, batch, batch_idx):
|
| | x, y = batch
|
| | t = torch.rand(x.shape[0], device=x.device) * (self.sde.T - self.t_eps) + self.t_eps
|
| | mean, std = self.sde.marginal_prob(x, t, y)
|
| | z = torch.randn_like(x)
|
| | sigmas = std[:, None, None, None]
|
| | perturbed_data = mean + sigmas * z
|
| | score = self(perturbed_data, t, y)
|
| | err = score * sigmas + z
|
| | loss = self._loss(err)
|
| | return loss
|
| |
|
| | def training_step(self, batch, batch_idx):
|
| | loss = self._step(batch, batch_idx)
|
| | self.log('train_loss', loss, on_step=True, on_epoch=True)
|
| | return loss
|
| |
|
| | def validation_step(self, batch, batch_idx):
|
| | loss = self._step(batch, batch_idx)
|
| | self.log('valid_loss', loss, on_step=False, on_epoch=True)
|
| |
|
| |
|
| | if batch_idx == 0 and self.num_eval_files != 0:
|
| | pesq, si_sdr, estoi = evaluate_model(self, self.num_eval_files)
|
| | self.log('pesq', pesq, on_step=False, on_epoch=True)
|
| | self.log('si_sdr', si_sdr, on_step=False, on_epoch=True)
|
| | self.log('estoi', estoi, on_step=False, on_epoch=True)
|
| |
|
| | return loss
|
| |
|
| | def forward(self, x, t, y):
|
| |
|
| | dnn_input = torch.cat([x, y], dim=1)
|
| |
|
| |
|
| | score = -self.dnn(dnn_input, t)
|
| | return score
|
| |
|
| | def to(self, *args, **kwargs):
|
| | """Override PyTorch .to() to also transfer the EMA of the model weights"""
|
| | self.ema.to(*args, **kwargs)
|
| | return super().to(*args, **kwargs)
|
| |
|
| | def get_pc_sampler(self, predictor_name, corrector_name, y, N=None, minibatch=None, **kwargs):
|
| | N = self.sde.N if N is None else N
|
| | sde = self.sde.copy()
|
| | sde.N = N
|
| |
|
| | kwargs = {"eps": self.t_eps, **kwargs}
|
| | if minibatch is None:
|
| | return sampling.get_pc_sampler(predictor_name, corrector_name, sde=sde, score_fn=self, y=y, **kwargs)
|
| | else:
|
| | M = y.shape[0]
|
| | def batched_sampling_fn():
|
| | samples, ns = [], []
|
| | for i in range(int(ceil(M / minibatch))):
|
| | y_mini = y[i*minibatch:(i+1)*minibatch]
|
| | sampler = sampling.get_pc_sampler(predictor_name, corrector_name, sde=sde, score_fn=self, y=y_mini, **kwargs)
|
| | sample, n = sampler()
|
| | samples.append(sample)
|
| | ns.append(n)
|
| | samples = torch.cat(samples, dim=0)
|
| | return samples, ns
|
| | return batched_sampling_fn
|
| |
|
| | def get_ode_sampler(self, y, N=None, minibatch=None, **kwargs):
|
| | N = self.sde.N if N is None else N
|
| | sde = self.sde.copy()
|
| | sde.N = N
|
| |
|
| | kwargs = {"eps": self.t_eps, **kwargs}
|
| | if minibatch is None:
|
| | return sampling.get_ode_sampler(sde, self, y=y, **kwargs)
|
| | else:
|
| | M = y.shape[0]
|
| | def batched_sampling_fn():
|
| | samples, ns = [], []
|
| | for i in range(int(ceil(M / minibatch))):
|
| | y_mini = y[i*minibatch:(i+1)*minibatch]
|
| | sampler = sampling.get_ode_sampler(sde, self, y=y_mini, **kwargs)
|
| | sample, n = sampler()
|
| | samples.append(sample)
|
| | ns.append(n)
|
| | samples = torch.cat(samples, dim=0)
|
| | return sample, ns
|
| | return batched_sampling_fn
|
| |
|
| | def train_dataloader(self):
|
| | return self.data_module.train_dataloader()
|
| |
|
| | def val_dataloader(self):
|
| | return self.data_module.val_dataloader()
|
| |
|
| | def test_dataloader(self):
|
| | return self.data_module.test_dataloader()
|
| |
|
| | def setup(self, stage=None):
|
| | return self.data_module.setup(stage=stage)
|
| |
|
| | def to_audio(self, spec, length=None):
|
| | return self._istft(self._backward_transform(spec), length)
|
| |
|
| | def _forward_transform(self, spec):
|
| | return self.data_module.spec_fwd(spec)
|
| |
|
| | def _backward_transform(self, spec):
|
| | return self.data_module.spec_back(spec)
|
| |
|
| | def _stft(self, sig):
|
| | return self.data_module.stft(sig)
|
| |
|
| | def _istft(self, spec, length=None):
|
| | return self.data_module.istft(spec, length)
|
| |
|
| | def enhance(self, y, sampler_type="pc", predictor="reverse_diffusion",
|
| | corrector="ald", N=30, corrector_steps=1, snr=0.5, timeit=False,
|
| | **kwargs
|
| | ):
|
| | """
|
| | One-call speech enhancement of noisy speech `y`, for convenience.
|
| | """
|
| | sr=16000
|
| | start = time.time()
|
| | T_orig = y.size(1)
|
| | norm_factor = y.abs().max().item()
|
| | y = y / norm_factor
|
| | Y = torch.unsqueeze(self._forward_transform(self._stft(y.cuda())), 0)
|
| | Y = pad_spec(Y)
|
| | if sampler_type == "pc":
|
| | sampler = self.get_pc_sampler(predictor, corrector, Y.cuda(), N=N,
|
| | corrector_steps=corrector_steps, snr=snr, intermediate=False,
|
| | **kwargs)
|
| | elif sampler_type == "ode":
|
| | sampler = self.get_ode_sampler(Y.cuda(), N=N, **kwargs)
|
| | else:
|
| | print("{} is not a valid sampler type!".format(sampler_type))
|
| | sample, nfe = sampler()
|
| | x_hat = self.to_audio(sample.squeeze(), T_orig)
|
| | x_hat = x_hat * norm_factor
|
| | x_hat = x_hat.squeeze().cpu().numpy()
|
| | end = time.time()
|
| | if timeit:
|
| | rtf = (end-start)/(len(x_hat)/sr)
|
| | return x_hat, nfe, rtf
|
| | else:
|
| | return x_hat
|
| |
|