| |
| |
| |
| |
| |
|
|
| import typing as tp |
|
|
| import flashy |
| import julius |
| import omegaconf |
| import torch |
| import torch.nn.functional as F |
|
|
| from . import builders |
| from . import base |
| from .. import models |
| from ..modules.diffusion_schedule import NoiseSchedule |
| from ..metrics import RelativeVolumeMel |
| from ..models.builders import get_processor |
| from ..utils.samples.manager import SampleManager |
| from ..solvers.compression import CompressionSolver |
|
|
|
|
class PerStageMetrics:
    """Bucket metrics by range of diffusion steps.

    Splits the `num_steps` diffusion steps into `num_stages` equal buckets
    and reports each metric per bucket, e.g. the average loss for t in [250, 500).
    """
    def __init__(self, num_steps: int, num_stages: int = 4):
        self.num_steps = num_steps
        self.num_stages = num_stages

    def __call__(self, losses: dict, step: tp.Union[int, torch.Tensor]):
        # Scalar step: every loss belongs to the single bucket containing `step`.
        if type(step) is int:
            bucket = int((step / self.num_steps) * self.num_stages)
            return {f"{name}_{bucket}": value for name, value in losses.items()}
        # Tensor step: one step per batch element, so average each loss
        # within every bucket that has at least one element.
        elif type(step) is torch.Tensor:
            buckets = ((step / self.num_steps) * self.num_stages).long()
            result: tp.Dict[str, float] = {}
            for bucket in range(self.num_stages):
                selected = (buckets == bucket)
                count = selected.sum()
                if count > 0:
                    for name, value in losses.items():
                        result[f"{name}_{bucket}"] = (selected * value).sum() / count
            return result
|
|
|
|
class DataProcess:
    """Apply band filtering and/or resampling to audio batches.

    Args:
        initial_sr (int): Sample rate of the dataset.
        target_sr (int): Sample rate after resampling.
        use_resampling (bool): Whether to resample from initial_sr to target_sr.
        use_filter (bool): When True, keep only one frequency band of the signal.
        n_bands (int): Number of bands used when filtering.
        idx_band (int): Index of the kept band. 0 is the lows ... (n_bands - 1) the highs.
        device (torch.device or str): Device on which the band-split filter runs.
        cutoffs (list or None): Explicit cutoff frequencies for the band filtering;
            if None, mel-scale bands are used.
        boost (bool): When True, normalize the data scale (to match our music dataset).
    """
    def __init__(self, initial_sr: int = 24000, target_sr: int = 16000, use_resampling: bool = False,
                 use_filter: bool = False, n_bands: int = 4,
                 idx_band: int = 0, device: torch.device = torch.device('cpu'), cutoffs=None, boost=False):
        assert idx_band < n_bands
        self.idx_band = idx_band
        if use_filter:
            # Explicit cutoffs take precedence; otherwise julius picks mel-scale bands.
            if cutoffs is not None:
                self.filter = julius.SplitBands(sample_rate=initial_sr, cutoffs=cutoffs).to(device)
            else:
                self.filter = julius.SplitBands(sample_rate=initial_sr, n_bands=n_bands).to(device)
        self.use_filter = use_filter
        self.use_resampling = use_resampling
        self.target_sr = target_sr
        self.initial_sr = initial_sr
        self.boost = boost

    def process_data(self, x, metric=False):
        """Normalize, band-filter and downsample `x` according to the config.

        Args:
            x (torch.Tensor or None): Audio of shape [B, C, T]; passed through if None.
            metric (bool): When True, skip the band filtering (used for metric inputs).
        Returns:
            The processed tensor, or None if `x` was None.
        """
        if x is None:
            return None
        if self.boost:
            # NOTE: in-place division, so the caller's tensor is normalized too.
            x /= torch.clamp(x.std(dim=(1, 2), keepdim=True), min=1e-4)
            # Fix: the original wrote `x * 0.22`, discarding the result and
            # making the rescaling a no-op; assign it so the boost applies.
            x = x * 0.22
        if self.use_filter and not metric:
            x = self.filter(x)[self.idx_band]
        if self.use_resampling:
            x = julius.resample_frac(x, old_sr=self.initial_sr, new_sr=self.target_sr)
        return x

    def inverse_process(self, x):
        """Upsampling only: bring `x` back from target_sr to initial_sr."""
        if self.use_resampling:
            # Fix: the original resampled target_sr -> target_sr, a no-op;
            # the inverse of process_data must return to the initial rate.
            x = julius.resample_frac(x, old_sr=self.target_sr, new_sr=self.initial_sr)
        return x
|
|
|
|
class DiffusionSolver(base.StandardSolver):
    """Solver for the diffusion task.

    The diffusion task allows for MultiBand diffusion model training:
    a diffusion model is trained to regenerate audio conditioned on the
    latent representation of a frozen, pretrained compression model (codec).

    Args:
        cfg (DictConfig): Configuration.
    """
    def __init__(self, cfg: omegaconf.DictConfig):
        super().__init__(cfg)
        self.cfg = cfg
        self.device = cfg.device
        self.sample_rate: int = self.cfg.sample_rate
        # Frozen codec whose latent is used as conditioning.
        self.codec_model = CompressionSolver.model_from_checkpoint(
            cfg.compression_model_checkpoint, device=self.device)

        self.codec_model.set_num_codebooks(cfg.n_q)
        assert self.codec_model.sample_rate == self.cfg.sample_rate, (
            f"Codec model sample rate is {self.codec_model.sample_rate} but "
            f"Solver sample rate is {self.cfg.sample_rate}."
        )
        assert self.codec_model.sample_rate == self.sample_rate, \
            f"Sample rate of solver {self.sample_rate} and codec {self.codec_model.sample_rate} " \
            "don't match."

        # Processor applied to clean samples before noising (registered in solver state).
        self.sample_processor = get_processor(cfg.processor, sample_rate=self.sample_rate)
        self.register_stateful('sample_processor')
        self.sample_processor.to(self.device)

        self.schedule = NoiseSchedule(
            **cfg.schedule, device=self.device, sample_processor=self.sample_processor)

        self.eval_metric: tp.Optional[torch.nn.Module] = None

        # Relative Volume Mel metric used during evaluation.
        self.rvm = RelativeVolumeMel()
        self.data_processor = DataProcess(initial_sr=self.sample_rate, target_sr=cfg.resampling.target_sr,
                                          use_resampling=cfg.resampling.use, cutoffs=cfg.filter.cutoffs,
                                          use_filter=cfg.filter.use, n_bands=cfg.filter.n_bands,
                                          idx_band=cfg.filter.idx_band, device=self.device)

    @property
    def best_metric_name(self) -> tp.Optional[str]:
        # Track RVM during evaluation, plain loss otherwise.
        if self._current_stage == "evaluate":
            return 'rvm'
        else:
            return 'loss'

    @torch.no_grad()
    def get_condition(self, wav: torch.Tensor) -> torch.Tensor:
        """Return the codec latent used as conditioning for `wav`."""
        codes, scale = self.codec_model.encode(wav)
        assert scale is None, "Scaled compression models not supported."
        emb = self.codec_model.decode_latent(codes)
        return emb

    def build_model(self):
        """Build model and optimizer as well as optional Exponential Moving Average of the model.
        """
        self.model = models.builders.get_diffusion_model(self.cfg).to(self.device)
        self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim)
        self.register_stateful('model', 'optimizer')
        self.register_best_state('model')
        self.register_ema('model')

    def build_dataloaders(self):
        """Build audio dataloaders for each stage."""
        self.dataloaders = builders.get_audio_datasets(self.cfg)

    def show(self):
        # No visualization is implemented for this solver.
        raise NotImplementedError()

    def run_step(self, idx: int, batch: torch.Tensor, metrics: dict):
        """Perform one training or valid step on a given batch."""
        x = batch.to(self.device)
        loss_fun = F.mse_loss if self.cfg.loss.kind == 'mse' else F.l1_loss

        condition = self.get_condition(x)
        sample = self.data_processor.process_data(x)

        # Draw a (possibly per-sample) diffusion step and the noised input/target pair.
        input_, target, step = self.schedule.get_training_item(sample,
                                                               tensor_step=self.cfg.schedule.variable_step_batch)
        out = self.model(input_, step, condition=condition).sample

        # Normalize the loss by the loss of the identity prediction, raised to
        # cfg.loss.norm_power, so all diffusion steps contribute comparably.
        base_loss = loss_fun(out, target, reduction='none').mean(dim=(1, 2))
        reference_loss = loss_fun(input_, target, reduction='none').mean(dim=(1, 2))
        loss = base_loss / reference_loss ** self.cfg.loss.norm_power

        if self.is_training:
            loss.mean().backward()
            flashy.distrib.sync_model(self.model)
            self.optimizer.step()
            self.optimizer.zero_grad()
        metrics = {
            'loss': loss.mean(), 'normed_loss': (base_loss / reference_loss).mean(),
        }
        # self.per_stage is (re)built at the start of each epoch in run_epoch.
        metrics.update(self.per_stage({'loss': loss, 'normed_loss': base_loss / reference_loss}, step))
        metrics.update({
            'std_in': input_.std(), 'std_out': out.std()})
        return metrics

    def run_epoch(self):
        # Seed the RNG per epoch for reproducible noise draws.
        self.rng = torch.Generator()
        self.rng.manual_seed(1234 + self.epoch)
        self.per_stage = PerStageMetrics(self.schedule.num_steps, self.cfg.metrics.num_stage)

        super().run_epoch()

    def evaluate(self):
        """Evaluate stage.

        Runs audio reconstruction evaluation, reporting RVM metrics
        averaged uniformly over all evaluation batches.
        """
        self.model.eval()
        evaluate_stage_name = f'{self.current_stage}'
        loader = self.dataloaders['evaluate']
        updates = len(loader)
        lp = self.log_progress(f'{evaluate_stage_name} estimate', loader, total=updates, updates=self.log_updates)

        metrics = {}
        for idx, batch in enumerate(lp):
            x = batch.to(self.device)
            with torch.no_grad():
                y_pred = self.regenerate(x)

            y_pred = y_pred.cpu()
            y = batch.cpu()  # should already be on CPU but just in case
            rvm = self.rvm(y_pred, y)
            lp.update(**rvm)
            if len(metrics) == 0:
                metrics = rvm
            else:
                # Fix: the original kept a counter `n = 1` that was never
                # incremented, so each update computed (prev + new) / 2 and
                # exponentially over-weighted later batches. Using the batch
                # index gives the correct uniform running mean.
                for key in rvm.keys():
                    metrics[key] = (metrics[key] * idx + rvm[key]) / (idx + 1)
        metrics = flashy.distrib.average_metrics(metrics)
        return metrics

    @torch.no_grad()
    def regenerate(self, wav: torch.Tensor, step_list: tp.Optional[list] = None):
        """Regenerate the given waveform through the codec condition + diffusion."""
        condition = self.get_condition(wav)
        initial = self.schedule.get_initial_noise(self.data_processor.process_data(wav))
        result = self.schedule.generate_subsampled(self.model, initial=initial, condition=condition,
                                                   step_list=step_list)
        # Bring the generated audio back to the solver sample rate.
        result = self.data_processor.inverse_process(result)
        return result

    def generate(self):
        """Generate stage: regenerate reference audio and log the samples."""
        sample_manager = SampleManager(self.xp)
        self.model.eval()
        generate_stage_name = f'{self.current_stage}'

        loader = self.dataloaders['generate']
        updates = len(loader)
        lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates)

        for batch in lp:
            reference, _ = batch
            reference = reference.to(self.device)
            estimate = self.regenerate(reference)
            reference = reference.cpu()
            estimate = estimate.cpu()
            sample_manager.add_samples(estimate, self.epoch, ground_truth_wavs=reference)
        flashy.distrib.barrier()
|
|