| |
| |
| |
| |
| |
|
|
| import logging |
| import multiprocessing |
| from pathlib import Path |
| import typing as tp |
|
|
| import flashy |
| import omegaconf |
| import torch |
| from torch import nn |
|
|
| from . import base, builders |
| from .. import models, quantization |
| from ..utils import checkpoint |
| from ..utils.samples.manager import SampleManager |
| from ..utils.utils import get_pool_executor |
|
|
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
class CompressionSolver(base.StandardSolver):
    """Solver for compression task.

    The compression task combines a set of perceptual and objective losses
    to train an EncodecModel (composed of an encoder-decoder and a quantizer)
    to perform high fidelity audio reconstruction.
    """
    def __init__(self, cfg: omegaconf.DictConfig):
        """Build the adversarial, auxiliary and informative losses and the loss balancer.

        Args:
            cfg (omegaconf.DictConfig): Solver configuration. This constructor reads
                `cfg.losses` (loss name -> weight), `cfg.balancer` and `cfg.fsdp.use`.
        """
        super().__init__(cfg)
        self.rng: torch.Generator  # assigned at the start of every epoch, see `run_epoch`
        self.adv_losses = builders.get_adversarial_losses(self.cfg)
        self.aux_losses = nn.ModuleDict()
        self.info_losses = nn.ModuleDict()
        assert not cfg.fsdp.use, "FSDP not supported by CompressionSolver."
        loss_weights: dict = {}
        for loss_name, weight in self.cfg.losses.items():
            if loss_name in ('adv', 'feat'):
                # One weight entry per adversary, keyed like the losses produced in `run_step`.
                for adv_name in self.adv_losses.keys():
                    loss_weights[f'{loss_name}_{adv_name}'] = weight
            elif weight > 0:
                # Positive weight: the loss is handed to the balancer during training.
                self.aux_losses[loss_name] = builders.get_loss(loss_name, self.cfg)
                loss_weights[loss_name] = weight
            else:
                # Non-positive weight: the loss is only evaluated (without gradient) for reporting.
                self.info_losses[loss_name] = builders.get_loss(loss_name, self.cfg)
        self.balancer = builders.get_balancer(loss_weights, self.cfg.balancer)
        self.register_stateful('adv_losses')

    @property
    def best_metric_name(self) -> tp.Optional[str]:
        """Name of the metric designating the best model; always None for this solver."""
        # No metric is nominated here; how the base solver treats None is defined outside this file.
        return None

    def build_model(self):
        """Instantiate model and optimizer."""
        # Model and optimizer, both built from the solver configuration.
        self.model = models.builders.get_compression_model(self.cfg).to(self.device)
        self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim)
        self.register_stateful('model', 'optimizer')
        self.register_best_state('model')
        self.register_ema('model')

    def build_dataloaders(self):
        """Instantiate audio dataloaders for each stage."""
        self.dataloaders = builders.get_audio_datasets(self.cfg)

    def show(self):
        """Show the compression model and employed adversarial loss."""
        self.logger.info(f"Compression model with {self.model.quantizer.total_codebooks} codebooks:")
        self.log_model_summary(self.model)
        self.logger.info("Adversarial loss:")
        self.log_model_summary(self.adv_losses)
        self.logger.info("Auxiliary losses:")
        self.logger.info(self.aux_losses)
        self.logger.info("Info losses:")
        self.logger.info(self.info_losses)

    def run_step(self, idx: int, batch: torch.Tensor, metrics: dict):
        """Perform one training or valid step on a given batch.

        Args:
            idx (int): Index of the step (not used by this implementation).
            batch (torch.Tensor): Input audio batch; it is used both as model input and as target.
            metrics (dict): Metrics dictionary, updated in place with the losses of this step.

        Returns:
            dict: The same `metrics` dictionary that was passed in.
        """
        x = batch.to(self.device)
        y = x.clone()

        qres = self.model(x)
        assert isinstance(qres, quantization.QuantizedResult)
        y_pred = qres.x
        # Log the mean of the bandwidth reported in the model output.
        metrics['bandwidth'] = qres.bandwidth.mean()

        if self.is_training:
            d_losses: dict = {}
            # The adversaries are trained only when a uniform draw from the per-epoch generator
            # is <= 1 / cfg.adversarial.every, i.e. on a random subset of the steps.
            if len(self.adv_losses) > 0 and torch.rand(1, generator=self.rng).item() <= 1 / self.cfg.adversarial.every:
                for adv_name, adversary in self.adv_losses.items():
                    disc_loss = adversary.train_adv(y_pred, y)
                    d_losses[f'd_{adv_name}'] = disc_loss
                metrics['d_loss'] = torch.sum(torch.stack(list(d_losses.values())))
            metrics.update(d_losses)

        balanced_losses: dict = {}
        other_losses: dict = {}

        # Penalty from the quantized result, kept only when it can receive gradients.
        if qres.penalty is not None and qres.penalty.requires_grad:
            other_losses['penalty'] = qres.penalty

        # Adversarial losses: each adversary returns an (adv, feat) pair.
        for adv_name, adversary in self.adv_losses.items():
            adv_loss, feat_loss = adversary(y_pred, y)
            balanced_losses[f'adv_{adv_name}'] = adv_loss
            balanced_losses[f'feat_{adv_name}'] = feat_loss

        # Auxiliary losses.
        for loss_name, criterion in self.aux_losses.items():
            loss = criterion(y_pred, y)
            balanced_losses[loss_name] = loss

        # Report every loss computed so far, plus the metrics attached to the model output.
        metrics.update(balanced_losses)
        metrics.update(other_losses)
        metrics.update(qres.metrics)

        if self.is_training:
            # Backprop the losses that are not handled by the balancer.
            other_loss = torch.tensor(0., device=self.device)
            if 'penalty' in other_losses:
                other_loss += other_losses['penalty']
            if other_loss.requires_grad:
                # retain_graph: the balancer backward below goes through the same graph.
                other_loss.backward(retain_graph=True)
                # L2 norm of all parameter gradients accumulated so far.
                ratio1 = sum(p.grad.data.norm(p=2).pow(2)
                             for p in self.model.parameters() if p.grad is not None)
                assert isinstance(ratio1, torch.Tensor)
                metrics['ratio1'] = ratio1.sqrt()

            # Backward of the balanced losses; the value returned by the balancer is logged as `g_loss`.
            metrics['g_loss'] = self.balancer.backward(balanced_losses, y_pred)
            # Add the metrics exposed by the balancer.
            metrics.update(self.balancer.metrics)
            # Same gradient norm as `ratio1`, measured after the balancer backward.
            ratio2 = sum(p.grad.data.norm(p=2).pow(2)
                         for p in self.model.parameters() if p.grad is not None)
            assert isinstance(ratio2, torch.Tensor)
            metrics['ratio2'] = ratio2.sqrt()

            # Optimization step.
            flashy.distrib.sync_model(self.model)
            if self.cfg.optim.max_norm:
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.cfg.optim.max_norm
                )
            self.optimizer.step()
            self.optimizer.zero_grad()

        # Informative losses only: evaluated without gradient tracking.
        info_losses: dict = {}
        with torch.no_grad():
            for loss_name, criterion in self.info_losses.items():
                loss = criterion(y_pred, y)
                info_losses[loss_name] = loss

        metrics.update(info_losses)

        # Aggregated GAN losses: sum every metric whose name starts with 'adv' (resp. 'feat')
        # so a single value can be reported whatever set of adversaries is configured.
        adv_losses = [loss for loss_name, loss in metrics.items() if loss_name.startswith('adv')]
        if len(adv_losses) > 0:
            metrics['adv'] = torch.sum(torch.stack(adv_losses))
        feat_losses = [loss for loss_name, loss in metrics.items() if loss_name.startswith('feat')]
        if len(feat_losses) > 0:
            metrics['feat'] = torch.sum(torch.stack(feat_losses))

        return metrics

    def run_epoch(self):
        """Re-seed the solver random generator from the epoch number, then run the epoch."""
        # A fresh generator seeded with 1234 + epoch; `run_step` draws from it.
        self.rng = torch.Generator()
        self.rng.manual_seed(1234 + self.epoch)
        # Delegate the actual epoch logic to the base solver.
        super().run_epoch()

    def evaluate(self):
        """Evaluate stage. Runs audio reconstruction evaluation.

        Each batch is reconstructed by the model, then `evaluate_audio_reconstruction`
        is submitted to a pool executor for that batch; results are averaged as they are collected.

        Returns:
            dict: Running average of the per-batch metrics, passed through
                `flashy.distrib.average_metrics`.
        """
        self.model.eval()
        evaluate_stage_name = str(self.current_stage)

        loader = self.dataloaders['evaluate']
        updates = len(loader)
        lp = self.log_progress(f'{evaluate_stage_name} inference', loader, total=updates, updates=self.log_updates)
        average = flashy.averager()

        # Start from an empty dict so that a loader yielding no batch does not leave
        # `metrics` unbound when it is used after the loop.
        metrics: dict = {}
        pendings = []
        ctx = multiprocessing.get_context('spawn')
        with get_pool_executor(self.cfg.evaluate.num_workers, mp_context=ctx) as pool:
            for batch in lp:
                x = batch.to(self.device)
                with torch.no_grad():
                    qres = self.model(x)

                # Tensors are moved to CPU before being handed to the pool.
                y_pred = qres.x.cpu()
                y = batch.cpu()
                pendings.append(pool.submit(evaluate_audio_reconstruction, y_pred, y, self.cfg))

            # Collect the results while the pool is still open.
            metrics_lp = self.log_progress(f'{evaluate_stage_name} metrics', pendings, updates=self.log_updates)
            for pending in metrics_lp:
                metrics = average(pending.result())

        metrics = flashy.distrib.average_metrics(metrics, len(loader))
        return metrics

    def generate(self):
        """Generate stage.

        Reconstructs every reference batch of the 'generate' dataloader with the model
        and hands the estimates to a `SampleManager` together with the references.
        """
        self.model.eval()
        sample_manager = SampleManager(self.xp, map_reference_to_sample_id=True)
        generate_stage_name = str(self.current_stage)

        loader = self.dataloaders['generate']
        updates = len(loader)
        lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates)

        for batch in lp:
            # Batches are unpacked as pairs; only the first element (the reference audio) is used.
            reference, _ = batch
            reference = reference.to(self.device)
            with torch.no_grad():
                qres = self.model(reference)
            assert isinstance(qres, quantization.QuantizedResult)

            reference = reference.cpu()
            estimate = qres.x.cpu()
            sample_manager.add_samples(estimate, self.epoch, ground_truth_wavs=reference)

        flashy.distrib.barrier()

    def load_from_pretrained(self, name: str) -> dict:
        """Load the weights of a pretrained compression model as a solver state.

        Args:
            name (str): Name passed to `models.CompressionModel.get_pretrained`.

        Returns:
            dict: `{'best_state': {'model': state_dict}}`.

        Raises:
            RuntimeError: If the pretrained model is a DAC model or of an unsupported type.
        """
        model = models.CompressionModel.get_pretrained(name)
        if isinstance(model, models.DAC):
            raise RuntimeError("Cannot fine tune a DAC model.")
        elif isinstance(model, models.HFEncodecCompressionModel):
            self.logger.warning('Trying to automatically convert a HuggingFace model '
                                'to AudioCraft, this might fail!')
            state = model.model.state_dict()
            new_state = {}
            # Rename every key of the HuggingFace state dict; the order of the replacements matters
            # because later patterns match the output of earlier ones.
            for k, v in state.items():
                if k.startswith('decoder.layers') and '.conv.' in k and '.block.' not in k:
                    # Determine whether this decoder layer is a transposed or a regular convolution.
                    layer = int(k.split('.')[2])
                    if isinstance(model.model.decoder.layers[layer].conv, torch.nn.ConvTranspose1d):
                        k = k.replace('.conv.', '.convtr.')
                k = k.replace('encoder.layers.', 'encoder.model.')
                k = k.replace('decoder.layers.', 'decoder.model.')
                k = k.replace('conv.', 'conv.conv.')
                k = k.replace('convtr.', 'convtr.convtr.')
                k = k.replace('quantizer.layers.', 'quantizer.vq.layers.')
                k = k.replace('.codebook.', '._codebook.')
                new_state[k] = v
            state = new_state
        elif isinstance(model, models.EncodecModel):
            state = model.state_dict()
        else:
            raise RuntimeError(f"Cannot fine tune model type {type(model)}.")
        return {
            'best_state': {'model': state}
        }

    @staticmethod
    def model_from_checkpoint(checkpoint_path: tp.Union[Path, str],
                              device: tp.Union[torch.device, str] = 'cpu') -> models.CompressionModel:
        """Instantiate a CompressionModel from a given checkpoint path or dora sig.
        This method is a convenient endpoint to load a CompressionModel to use in other solvers.

        Args:
            checkpoint_path (Path or str): Path to checkpoint or dora sig from where the checkpoint is resolved.
                This also supports pre-trained models by using a path of the form //pretrained/NAME,
                in which case NAME is forwarded to `models.CompressionModel.get_pretrained`.
            device (torch.device or str): Device on which the model is loaded.

        Returns:
            models.CompressionModel: The loaded model. When loaded from a checkpoint it is put
                in eval mode with the weights stored under `best_state`.
        """
        checkpoint_path = str(checkpoint_path)
        if checkpoint_path.startswith('//pretrained/'):
            # Everything after the '//pretrained/' prefix is the pretrained model name.
            name = checkpoint_path.split('/', 3)[-1]
            return models.CompressionModel.get_pretrained(name, device)
        logger.info(f"Loading compression model from checkpoint: {checkpoint_path}")
        _checkpoint_path = checkpoint.resolve_checkpoint_path(checkpoint_path, use_fsdp=False)
        assert _checkpoint_path is not None, f"Could not resolve compression model checkpoint path: {checkpoint_path}"
        state = checkpoint.load_checkpoint(_checkpoint_path)
        assert state is not None and 'xp.cfg' in state, f"Could not load compression model from ckpt: {checkpoint_path}"
        # Rebuild the model from the configuration stored in the checkpoint, on the requested device.
        cfg = state['xp.cfg']
        cfg.device = device
        compression_model = models.builders.get_compression_model(cfg).to(device)
        assert compression_model.sample_rate == cfg.sample_rate, "Compression model sample rate should match"

        assert 'best_state' in state and state['best_state'] != {}
        assert 'exported' not in state, "When loading an exported checkpoint, use the //pretrained/ prefix."
        compression_model.load_state_dict(state['best_state']['model'])
        compression_model.eval()
        logger.info("Compression model loaded!")
        return compression_model

    @staticmethod
    def wrapped_model_from_checkpoint(cfg: omegaconf.DictConfig,
                                      checkpoint_path: tp.Union[Path, str],
                                      device: tp.Union[torch.device, str] = 'cpu') -> models.CompressionModel:
        """Instantiate a wrapped CompressionModel from a given checkpoint path or dora sig.

        Args:
            cfg (omegaconf.DictConfig): Configuration to read from for wrapped mode.
            checkpoint_path (Path or str): Path to checkpoint or dora sig from where the checkpoint is resolved.
            device (torch.device or str): Device on which the model is loaded.

        Returns:
            models.CompressionModel: The model loaded by `model_from_checkpoint`, wrapped by
                `models.builders.get_wrapped_compression_model`.
        """
        compression_model = CompressionSolver.model_from_checkpoint(checkpoint_path, device)
        compression_model = models.builders.get_wrapped_compression_model(compression_model, cfg)
        return compression_model
|
|
|
|
def evaluate_audio_reconstruction(y_pred: torch.Tensor, y: torch.Tensor, cfg: omegaconf.DictConfig) -> dict:
    """Compute reconstruction metrics for one batch; a plain module-level function so it can be pickled.

    Args:
        y_pred (torch.Tensor): Reconstructed audio.
        y (torch.Tensor): Reference audio.
        cfg (omegaconf.DictConfig): Configuration; `cfg.evaluate.metrics.visqol` toggles the
            ViSQOL metric, which is then built from `cfg.metrics.visqol`.

    Returns:
        dict: Always holds 'sisnr'; also holds 'visqol' when that metric is enabled.
    """
    results: dict = {}
    # ViSQOL is optional and only built when requested by the configuration.
    if cfg.evaluate.metrics.visqol:
        visqol_metric = builders.get_visqol(cfg.metrics.visqol)
        results['visqol'] = visqol_metric(y_pred, y, cfg.sample_rate)
    # SI-SNR is always reported.
    sisnr_metric = builders.get_loss('sisnr', cfg)
    results['sisnr'] = sisnr_metric(y_pred, y)
    return results
|
|