| import torch | |
| import pytorch_lightning as pl | |
| import torchaudio | |
| import os | |
| import pathlib | |
| import tqdm | |
| from model import ( | |
| EncoderModule, | |
| ChannelFeatureModule, | |
| ChannelModule, | |
| MultiScaleSpectralLoss, | |
| GSTModule, | |
| ) | |
class PretrainLightningModule(pl.LightningModule):
    """Pretraining stage.

    Jointly trains the analysis encoder (optionally with a GST-based channel
    embedding) and the channel (degradation) model so that the channel applied
    to the clean auxiliary waveform reproduces the observed degraded waveform,
    while the encoder output matches the clean target features.
    """

    def __init__(self, config):
        super().__init__()
        self.save_hyperparameters()
        self.config = config
        if config["general"]["use_gst"]:
            self.encoder = EncoderModule(config)
            self.gst = GSTModule(config)
        else:
            # Channel features are extracted from the encoder's hidden state.
            self.encoder = EncoderModule(config, use_channel=True)
            self.channelfeats = ChannelFeatureModule(config)
        self.channel = ChannelModule(config)
        # Vocoder is attached externally before testing (None during training).
        self.vocoder = None
        self.criteria_a = MultiScaleSpectralLoss(config)
        if "feature_loss" in config["train"]:
            if config["train"]["feature_loss"]["type"] == "mae":
                self.criteria_b = torch.nn.L1Loss()
            else:
                self.criteria_b = torch.nn.MSELoss()
        else:
            # BUG FIX: this was assigned to ``self.criteria`` while every
            # training/validation step reads ``self.criteria_b``, raising
            # AttributeError whenever the config lacks a "feature_loss" section.
            self.criteria_b = torch.nn.L1Loss()
        # Mixing weight between waveform reconstruction and feature loss.
        self.alpha = config["train"]["alpha"]

    def forward(self, melspecs, wavsaux):
        """Return (predicted clean features, degraded auxiliary waveform)."""
        if self.config["general"]["use_gst"]:
            enc_out = self.encoder(melspecs.unsqueeze(1).transpose(2, 3))
            chfeats = self.gst(melspecs.transpose(1, 2))
        else:
            enc_out, enc_hidden = self.encoder(melspecs.unsqueeze(1).transpose(2, 3))
            chfeats = self.channelfeats(enc_hidden)
        enc_out = enc_out.squeeze(1).transpose(1, 2)
        wavsdeg = self.channel(wavsaux, chfeats)
        return enc_out, wavsdeg

    def training_step(self, batch, batch_idx):
        """Optimize alpha * recon loss + (1 - alpha) * feature loss."""
        if self.config["general"]["use_gst"]:
            enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3))
            chfeats = self.gst(batch["melspecs"].transpose(1, 2))
        else:
            enc_out, enc_hidden = self.encoder(
                batch["melspecs"].unsqueeze(1).transpose(2, 3)
            )
            chfeats = self.channelfeats(enc_hidden)
        enc_out = enc_out.squeeze(1).transpose(1, 2)
        wavsdeg = self.channel(batch["wavsaux"], chfeats)
        loss_recons = self.criteria_a(wavsdeg, batch["wavs"])
        # Encoder target depends on the configured feature type.
        if self.config["general"]["feature_type"] == "melspec":
            loss_encoder = self.criteria_b(enc_out, batch["melspecsaux"])
        elif self.config["general"]["feature_type"] == "vocfeats":
            loss_encoder = self.criteria_b(enc_out, batch["melceps"])
        loss = self.alpha * loss_recons + (1.0 - self.alpha) * loss_encoder
        self.log(
            "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
        )
        self.log(
            "train_loss_recons",
            loss_recons,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        self.log(
            "train_loss_encoder",
            loss_encoder,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute validation losses and collect tensors for media logging."""
        if self.config["general"]["use_gst"]:
            enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3))
            chfeats = self.gst(batch["melspecs"].transpose(1, 2))
        else:
            enc_out, enc_hidden = self.encoder(
                batch["melspecs"].unsqueeze(1).transpose(2, 3)
            )
            chfeats = self.channelfeats(enc_hidden)
        enc_out = enc_out.squeeze(1).transpose(1, 2)
        wavsdeg = self.channel(batch["wavsaux"], chfeats)
        loss_recons = self.criteria_a(wavsdeg, batch["wavs"])
        if self.config["general"]["feature_type"] == "melspec":
            val_aux_feats = batch["melspecsaux"]
            feats_name = "melspec"
            loss_encoder = self.criteria_b(enc_out, val_aux_feats)
        elif self.config["general"]["feature_type"] == "vocfeats":
            val_aux_feats = batch["melceps"]
            feats_name = "melcep"
            loss_encoder = self.criteria_b(enc_out, val_aux_feats)
        loss = self.alpha * loss_recons + (1.0 - self.alpha) * loss_encoder
        logger_img_dict = {
            "val_src_melspec": batch["melspecs"],
            "val_pred_{}".format(feats_name): enc_out,
            "val_aux_{}".format(feats_name): val_aux_feats,
        }
        logger_wav_dict = {
            "val_src_wav": batch["wavs"],
            "val_pred_wav": wavsdeg,
            "val_aux_wav": batch["wavsaux"],
        }
        return {
            "val_loss": loss,
            "val_loss_recons": loss_recons,
            "val_loss_encoder": loss_encoder,
            "logger_dict": [logger_img_dict, logger_wav_dict],
        }

    def validation_epoch_end(self, outputs):
        """Aggregate per-step validation losses and log epoch means."""
        val_loss = torch.stack([out["val_loss"] for out in outputs]).mean().item()
        val_loss_recons = (
            torch.stack([out["val_loss_recons"] for out in outputs]).mean().item()
        )
        val_loss_encoder = (
            torch.stack([out["val_loss_encoder"] for out in outputs]).mean().item()
        )
        self.log("val_loss", val_loss, on_epoch=True, prog_bar=True, logger=True)
        self.log(
            "val_loss_recons",
            val_loss_recons,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        self.log(
            "val_loss_encoder",
            val_loss_encoder,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

    def test_step(self, batch, batch_idx):
        """Synthesize waveforms with the external vocoder for evaluation.

        Requires ``self.vocoder`` to have been attached before testing.
        """
        if self.config["general"]["use_gst"]:
            enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3))
            chfeats = self.gst(batch["melspecs"].transpose(1, 2))
        else:
            enc_out, enc_hidden = self.encoder(
                batch["melspecs"].unsqueeze(1).transpose(2, 3)
            )
            chfeats = self.channelfeats(enc_hidden)
        enc_out = enc_out.squeeze(1).transpose(1, 2)
        wavsdeg = self.channel(batch["wavsaux"], chfeats)
        if self.config["general"]["feature_type"] == "melspec":
            enc_feats = enc_out
            enc_feats_aux = batch["melspecsaux"]
        elif self.config["general"]["feature_type"] == "vocfeats":
            # Vocoder features are f0 stacked on top of the mel cepstrum.
            enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1)
            enc_feats_aux = torch.cat(
                (batch["f0s"].unsqueeze(1), batch["melceps"]), dim=1
            )
        recons_wav = self.vocoder(enc_feats_aux).squeeze(1)
        remas = self.vocoder(enc_feats).squeeze(1)
        if self.config["general"]["feature_type"] == "melspec":
            enc_feats_input = batch["melspecs"]
        elif self.config["general"]["feature_type"] == "vocfeats":
            enc_feats_input = torch.cat(
                (batch["f0s"].unsqueeze(1), batch["melcepssrc"]), dim=1
            )
        input_recons = self.vocoder(enc_feats_input).squeeze(1)
        gt_wav = batch["wavsaux"] if "wavsaux" in batch else None
        return {
            "reconstructed": recons_wav,
            "remastered": remas,
            "channeled": wavsdeg,
            "groundtruth": gt_wav,
            "input": batch["wavs"],
            "input_recons": input_recons,
        }

    def test_epoch_end(self, outputs):
        """Save the first waveform of every test batch under the log dir."""
        wav_dir = (
            pathlib.Path(self.logger.experiment[0].log_dir).parent.parent / "test_wavs"
        )
        os.makedirs(wav_dir, exist_ok=True)
        mel_dir = (
            pathlib.Path(self.logger.experiment[0].log_dir).parent.parent / "test_mels"
        )
        os.makedirs(mel_dir, exist_ok=True)
        print("Saving mel spectrogram plots ...")
        for idx, out in enumerate(tqdm.tqdm(outputs)):
            for key in [
                "reconstructed",
                "remastered",
                "channeled",
                "input",
                "input_recons",
                "groundtruth",
            ]:
                # ``groundtruth`` may be absent (None); identity check instead
                # of ``!=`` to avoid tensor elementwise comparison semantics.
                if out[key] is not None:
                    torchaudio.save(
                        wav_dir / "{}-{}.wav".format(idx, key),
                        out[key][0, ...].unsqueeze(0).cpu(),
                        sample_rate=self.config["preprocess"]["sampling_rate"],
                        channels_first=True,
                    )

    def configure_optimizers(self):
        """Adam + ReduceLROnPlateau monitored on ``val_loss``."""
        optimizer = torch.optim.Adam(
            self.parameters(), lr=self.config["train"]["learning_rate"]
        )
        lr_scheduler_config = {
            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, mode="min", factor=0.5, min_lr=1e-5, verbose=True
            ),
            "interval": "epoch",
            "frequency": 3,
            "monitor": "val_loss",
        }
        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}
class SSLBaseModule(pl.LightningModule):
    """Shared base for the self-supervised fine-tuning stages.

    Builds encoder/channel submodules mirroring ``PretrainLightningModule``,
    optionally warm-starts them from a pretraining checkpoint, and provides
    the common ``forward``/``test_step``/``test_epoch_end``. Training and
    validation hooks are abstract and implemented by subclasses.
    """

    def __init__(self, config):
        super().__init__()
        self.save_hyperparameters()
        self.config = config
        if config["general"]["use_gst"]:
            self.encoder = EncoderModule(config)
            self.gst = GSTModule(config)
        else:
            self.encoder = EncoderModule(config, use_channel=True)
            self.channelfeats = ChannelFeatureModule(config)
        self.channel = ChannelModule(config)
        if config["train"]["load_pretrained"]:
            # Warm-start from the pretraining stage; ``strict=False`` tolerates
            # key mismatches between the two stages' module definitions.
            pre_model = PretrainLightningModule.load_from_checkpoint(
                checkpoint_path=config["train"]["pretrained_path"]
            )
            self.encoder.load_state_dict(pre_model.encoder.state_dict(), strict=False)
            self.channel.load_state_dict(pre_model.channel.state_dict(), strict=False)
            if config["general"]["use_gst"]:
                self.gst.load_state_dict(pre_model.gst.state_dict(), strict=False)
            else:
                self.channelfeats.load_state_dict(
                    pre_model.channelfeats.state_dict(), strict=False
                )
        # Vocoder is attached externally before use (None until then).
        self.vocoder = None
        self.criteria = self.get_loss_function(config)

    def training_step(self, batch, batch_idx):
        raise NotImplementedError()

    def validation_step(self, batch, batch_idx):
        raise NotImplementedError()

    def validation_epoch_end(self, outputs):
        raise NotImplementedError()

    def configure_optimizers(self):
        raise NotImplementedError()

    def get_loss_function(self, config):
        raise NotImplementedError()

    def forward(self, melspecs, f0s=None):
        """Return (remastered waveform, re-degraded waveform).

        ``f0s`` is required only when ``feature_type == "vocfeats"``.
        """
        if self.config["general"]["use_gst"]:
            enc_out = self.encoder(melspecs.unsqueeze(1).transpose(2, 3))
            chfeats = self.gst(melspecs.transpose(1, 2))
        else:
            enc_out, enc_hidden = self.encoder(melspecs.unsqueeze(1).transpose(2, 3))
            chfeats = self.channelfeats(enc_hidden)
        enc_out = enc_out.squeeze(1).transpose(1, 2)
        if self.config["general"]["feature_type"] == "melspec":
            enc_feats = enc_out
        elif self.config["general"]["feature_type"] == "vocfeats":
            enc_feats = torch.cat((f0s.unsqueeze(1), enc_out), dim=1)
        remas = self.vocoder(enc_feats).squeeze(1)
        wavsdeg = self.channel(remas, chfeats)
        return remas, wavsdeg

    def test_step(self, batch, batch_idx):
        """Synthesize evaluation waveforms; aux/ground-truth entries are None
        when the batch carries no clean reference (``wavsaux``)."""
        if self.config["general"]["use_gst"]:
            enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3))
            chfeats = self.gst(batch["melspecs"].transpose(1, 2))
        else:
            enc_out, enc_hidden = self.encoder(
                batch["melspecs"].unsqueeze(1).transpose(2, 3)
            )
            chfeats = self.channelfeats(enc_hidden)
        enc_out = enc_out.squeeze(1).transpose(1, 2)
        if self.config["general"]["feature_type"] == "melspec":
            enc_feats = enc_out
        elif self.config["general"]["feature_type"] == "vocfeats":
            enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1)
        remas = self.vocoder(enc_feats).squeeze(1)
        wavsdeg = self.channel(remas, chfeats)
        if self.config["general"]["feature_type"] == "melspec":
            enc_feats_input = batch["melspecs"]
        elif self.config["general"]["feature_type"] == "vocfeats":
            enc_feats_input = torch.cat(
                (batch["f0s"].unsqueeze(1), batch["melcepssrc"]), dim=1
            )
        input_recons = self.vocoder(enc_feats_input).squeeze(1)
        if "wavsaux" in batch:
            gt_wav = batch["wavsaux"]
            if self.config["general"]["feature_type"] == "melspec":
                enc_feats_aux = batch["melspecsaux"]
            elif self.config["general"]["feature_type"] == "vocfeats":
                enc_feats_aux = torch.cat(
                    (batch["f0s"].unsqueeze(1), batch["melceps"]), dim=1
                )
            recons_wav = self.vocoder(enc_feats_aux).squeeze(1)
        else:
            gt_wav = None
            recons_wav = None
        return {
            "reconstructed": recons_wav,
            "remastered": remas,
            "channeled": wavsdeg,
            "input": batch["wavs"],
            "input_recons": input_recons,
            "groundtruth": gt_wav,
        }

    def test_epoch_end(self, outputs):
        """Save the first waveform of every test batch under the log dir."""
        wav_dir = (
            pathlib.Path(self.logger.experiment[0].log_dir).parent.parent / "test_wavs"
        )
        os.makedirs(wav_dir, exist_ok=True)
        mel_dir = (
            pathlib.Path(self.logger.experiment[0].log_dir).parent.parent / "test_mels"
        )
        os.makedirs(mel_dir, exist_ok=True)
        print("Saving mel spectrogram plots ...")
        for idx, out in enumerate(tqdm.tqdm(outputs)):
            for key in [
                "reconstructed",
                "remastered",
                "channeled",
                "input",
                "input_recons",
                "groundtruth",
            ]:
                # Identity check (was ``!= None``): entries may legitimately be
                # None, and ``!=`` on tensors has elementwise semantics.
                if out[key] is not None:
                    torchaudio.save(
                        wav_dir / "{}-{}.wav".format(idx, key),
                        out[key][0, ...].unsqueeze(0).cpu(),
                        sample_rate=self.config["preprocess"]["sampling_rate"],
                        channels_first=True,
                    )
class SSLStepLightningModule(SSLBaseModule):
    """Self-supervised stage with alternating (two-phase) optimization.

    Uses two optimizers — one for the channel path, one for the encoder —
    and a custom ``optimizer_step`` that steps only the channel optimizer for
    the first ``epoch_channel`` epochs, then only the encoder optimizer.
    """

    def __init__(self, config):
        super().__init__(config)
        if config["train"]["fix_channel"]:
            # Freeze the channel model; only GST/channel-feature extractors
            # and the encoder remain trainable (see configure_optimizers).
            for param in self.channel.parameters():
                param.requires_grad = False

    def training_step(self, batch, batch_idx, optimizer_idx):
        """One step of encoder -> vocoder -> channel; loss on degraded wav.

        ``optimizer_idx`` is supplied by Lightning because two optimizers are
        configured; gating of which one actually steps happens in
        ``optimizer_step``.
        """
        if self.config["general"]["use_gst"]:
            enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3))
            chfeats = self.gst(batch["melspecs"].transpose(1, 2))
        else:
            enc_out, enc_hidden = self.encoder(
                batch["melspecs"].unsqueeze(1).transpose(2, 3)
            )
            chfeats = self.channelfeats(enc_hidden)
        enc_out = enc_out.squeeze(1).transpose(1, 2)
        if self.config["general"]["feature_type"] == "melspec":
            enc_feats = enc_out
        elif self.config["general"]["feature_type"] == "vocfeats":
            # Vocoder input is f0 stacked on top of the predicted mel cepstrum.
            enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1)
        remas = self.vocoder(enc_feats).squeeze(1)
        wavsdeg = self.channel(remas, chfeats)
        loss = self.criteria(wavsdeg, batch["wavs"])
        self.log(
            "train_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        """Validation loss plus image/waveform dicts for media logging."""
        if self.config["general"]["use_gst"]:
            enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3))
            chfeats = self.gst(batch["melspecs"].transpose(1, 2))
        else:
            enc_out, enc_hidden = self.encoder(
                batch["melspecs"].unsqueeze(1).transpose(2, 3)
            )
            chfeats = self.channelfeats(enc_hidden)
        enc_out = enc_out.squeeze(1).transpose(1, 2)
        if self.config["general"]["feature_type"] == "melspec":
            enc_feats = enc_out
            feats_name = "melspec"
        elif self.config["general"]["feature_type"] == "vocfeats":
            enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1)
            feats_name = "melcep"
        remas = self.vocoder(enc_feats).squeeze(1)
        wavsdeg = self.channel(remas, chfeats)
        loss = self.criteria(wavsdeg, batch["wavs"])
        logger_img_dict = {
            "val_src_melspec": batch["melspecs"],
            "val_pred_{}".format(feats_name): enc_out,
        }
        # Clean auxiliary features/waveforms are optional in this stage.
        for auxfeats in ["melceps", "melspecsaux"]:
            if auxfeats in batch:
                logger_img_dict["val_aux_{}".format(auxfeats)] = batch[auxfeats]
        logger_wav_dict = {
            "val_src_wav": batch["wavs"],
            "val_remastered_wav": remas,
            "val_pred_wav": wavsdeg,
        }
        if "wavsaux" in batch:
            logger_wav_dict["val_aux_wav"] = batch["wavsaux"]
        d_out = {"val_loss": loss, "logger_dict": [logger_img_dict, logger_wav_dict]}
        return d_out

    def validation_epoch_end(self, outputs):
        # Log the mean validation loss over the epoch (scalar for the
        # ReduceLROnPlateau monitors in configure_optimizers).
        self.log(
            "val_loss",
            torch.stack([out["val_loss"] for out in outputs]).mean().item(),
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

    def optimizer_step(
        self,
        epoch,
        batch_idx,
        optimizer,
        optimizer_idx,
        optimizer_closure,
        on_tpu=False,
        using_native_amp=False,
        using_lbfgs=False,
    ):
        """Alternate which optimizer steps, by epoch phase.

        Before ``epoch_channel``: only optimizer 0 (channel path) steps.
        From ``epoch_channel`` on: only optimizer 1 (encoder) steps.
        The skipped optimizer still runs ``optimizer_closure()`` so
        Lightning's forward/backward bookkeeping stays consistent.
        (Legacy pytorch-lightning ``optimizer_step`` signature.)
        """
        if epoch < self.config["train"]["epoch_channel"]:
            if optimizer_idx == 0:
                optimizer.step(closure=optimizer_closure)
            elif optimizer_idx == 1:
                optimizer_closure()
        else:
            if optimizer_idx == 0:
                optimizer_closure()
            elif optimizer_idx == 1:
                optimizer.step(closure=optimizer_closure)

    def configure_optimizers(self):
        """Two Adam optimizers (channel path, encoder) + plateau schedulers.

        When ``fix_channel`` is set the channel model's (frozen) parameters
        are excluded from the channel-path optimizer.
        """
        if self.config["train"]["fix_channel"]:
            if self.config["general"]["use_gst"]:
                optimizer_channel = torch.optim.Adam(
                    self.gst.parameters(), lr=self.config["train"]["learning_rate"]
                )
            else:
                optimizer_channel = torch.optim.Adam(
                    self.channelfeats.parameters(),
                    lr=self.config["train"]["learning_rate"],
                )
            optimizer_encoder = torch.optim.Adam(
                self.encoder.parameters(), lr=self.config["train"]["learning_rate"]
            )
        else:
            if self.config["general"]["use_gst"]:
                optimizer_channel = torch.optim.Adam(
                    [
                        {"params": self.channel.parameters()},
                        {"params": self.gst.parameters()},
                    ],
                    lr=self.config["train"]["learning_rate"],
                )
            else:
                optimizer_channel = torch.optim.Adam(
                    [
                        {"params": self.channel.parameters()},
                        {"params": self.channelfeats.parameters()},
                    ],
                    lr=self.config["train"]["learning_rate"],
                )
            optimizer_encoder = torch.optim.Adam(
                self.encoder.parameters(), lr=self.config["train"]["learning_rate"]
            )
        optimizers = [optimizer_channel, optimizer_encoder]
        schedulers = [
            {
                "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                    optimizers[0], mode="min", factor=0.5, min_lr=1e-5, verbose=True
                ),
                "interval": "epoch",
                "frequency": 3,
                "monitor": "val_loss",
            },
            {
                "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                    optimizers[1], mode="min", factor=0.5, min_lr=1e-5, verbose=True
                ),
                "interval": "epoch",
                "frequency": 3,
                "monitor": "val_loss",
            },
        ]
        return optimizers, schedulers

    def get_loss_function(self, config):
        """Waveform-domain loss used for both training and validation."""
        return MultiScaleSpectralLoss(config)
class SSLDualLightningModule(SSLBaseModule):
    """Self-supervised stage with a dual (joint) objective.

    Optimizes beta * reconstruction loss on the observed waveform plus
    (1 - beta) * feature loss on a cycle-consistency "task": a clean task
    waveform is degraded by the (detached) channel, re-analyzed, and the
    encoder must recover its clean features.
    """

    def __init__(self, config):
        super().__init__(config)
        if config["train"]["fix_channel"]:
            # Freeze the channel model; it is still used in the forward pass.
            for param in self.channel.parameters():
                param.requires_grad = False
        # Mel analysis used to re-analyze the degraded task waveform on the
        # fly; parameters mirror the offline preprocessing configuration.
        self.spec_module = torchaudio.transforms.MelSpectrogram(
            sample_rate=config["preprocess"]["sampling_rate"],
            n_fft=config["preprocess"]["fft_length"],
            win_length=config["preprocess"]["frame_length"],
            hop_length=config["preprocess"]["frame_shift"],
            f_min=config["preprocess"]["fmin"],
            f_max=config["preprocess"]["fmax"],
            n_mels=config["preprocess"]["n_mels"],
            power=1,
            center=True,
            norm="slaney",
            mel_scale="slaney",
        )
        # Mixing weight between reconstruction and task losses.
        self.beta = config["train"]["beta"]
        self.criteria_a, self.criteria_b = self.get_loss_function(config)

    def training_step(self, batch, batch_idx):
        """Joint reconstruction + cycle-consistency task loss."""
        if self.config["general"]["use_gst"]:
            enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3))
            chfeats = self.gst(batch["melspecs"].transpose(1, 2))
        else:
            enc_out, enc_hidden = self.encoder(
                batch["melspecs"].unsqueeze(1).transpose(2, 3)
            )
            chfeats = self.channelfeats(enc_hidden)
        enc_out = enc_out.squeeze(1).transpose(1, 2)
        if self.config["general"]["feature_type"] == "melspec":
            enc_feats = enc_out
        elif self.config["general"]["feature_type"] == "vocfeats":
            enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1)
        remas = self.vocoder(enc_feats).squeeze(1)
        wavsdeg = self.channel(remas, chfeats)
        loss_recons = self.criteria_a(wavsdeg, batch["wavs"])
        # Degrade the clean task waveform without backprop through the
        # channel, then re-analyze it with the (trainable) encoder.
        with torch.no_grad():
            wavsdegtask = self.channel(batch["wavstask"], chfeats)
        melspecstask = self.calc_spectrogram(wavsdegtask)
        if self.config["general"]["use_gst"]:
            enc_out_task = self.encoder(melspecstask.unsqueeze(1).transpose(2, 3))
        else:
            enc_out_task, _ = self.encoder(melspecstask.unsqueeze(1).transpose(2, 3))
        enc_out_task = enc_out_task.squeeze(1).transpose(1, 2)
        if self.config["general"]["feature_type"] == "melspec":
            loss_task = self.criteria_b(enc_out_task, batch["melspecstask"])
        elif self.config["general"]["feature_type"] == "vocfeats":
            loss_task = self.criteria_b(enc_out_task, batch["melcepstask"])
        loss = self.beta * loss_recons + (1 - self.beta) * loss_task
        self.log(
            "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
        )
        self.log(
            "train_loss_recons",
            loss_recons,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        self.log(
            "train_loss_task",
            loss_task,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        """Validation version of the joint objective + media logging dicts."""
        if self.config["general"]["use_gst"]:
            enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3))
            chfeats = self.gst(batch["melspecs"].transpose(1, 2))
        else:
            enc_out, enc_hidden = self.encoder(
                batch["melspecs"].unsqueeze(1).transpose(2, 3)
            )
            chfeats = self.channelfeats(enc_hidden)
        enc_out = enc_out.squeeze(1).transpose(1, 2)
        if self.config["general"]["feature_type"] == "melspec":
            enc_feats = enc_out
            feats_name = "melspec"
        elif self.config["general"]["feature_type"] == "vocfeats":
            enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1)
            feats_name = "melcep"
        remas = self.vocoder(enc_feats).squeeze(1)
        wavsdeg = self.channel(remas, chfeats)
        loss_recons = self.criteria_a(wavsdeg, batch["wavs"])
        wavsdegtask = self.channel(batch["wavstask"], chfeats)
        melspecstask = self.calc_spectrogram(wavsdegtask)
        if self.config["general"]["use_gst"]:
            enc_out_task = self.encoder(melspecstask.unsqueeze(1).transpose(2, 3))
        else:
            enc_out_task, _ = self.encoder(melspecstask.unsqueeze(1).transpose(2, 3))
        enc_out_task = enc_out_task.squeeze(1).transpose(1, 2)
        if self.config["general"]["feature_type"] == "melspec":
            enc_out_task_truth = batch["melspecstask"]
            loss_task = self.criteria_b(enc_out_task, enc_out_task_truth)
        elif self.config["general"]["feature_type"] == "vocfeats":
            enc_out_task_truth = batch["melcepstask"]
            loss_task = self.criteria_b(enc_out_task, enc_out_task_truth)
        loss = self.beta * loss_recons + (1 - self.beta) * loss_task
        logger_img_dict = {
            "val_src_melspec": batch["melspecs"],
            "val_pred_{}".format(feats_name): enc_out,
            "val_truth_{}_task".format(feats_name): enc_out_task_truth,
            "val_pred_{}_task".format(feats_name): enc_out_task,
        }
        for auxfeats in ["melceps", "melspecsaux"]:
            if auxfeats in batch:
                logger_img_dict["val_aux_{}".format(auxfeats)] = batch[auxfeats]
        logger_wav_dict = {
            "val_src_wav": batch["wavs"],
            "val_remastered_wav": remas,
            "val_pred_wav": wavsdeg,
            "val_truth_wavtask": batch["wavstask"],
            "val_deg_wavtask": wavsdegtask,
        }
        if "wavsaux" in batch:
            logger_wav_dict["val_aux_wav"] = batch["wavsaux"]
        d_out = {
            "val_loss": loss,
            "val_loss_recons": loss_recons,
            "val_loss_task": loss_task,
            "logger_dict": [logger_img_dict, logger_wav_dict],
        }
        return d_out

    def validation_epoch_end(self, outputs):
        """Log epoch-mean validation losses."""
        self.log(
            "val_loss",
            torch.stack([out["val_loss"] for out in outputs]).mean().item(),
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        self.log(
            "val_loss_recons",
            torch.stack([out["val_loss_recons"] for out in outputs]).mean().item(),
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        self.log(
            "val_loss_task",
            torch.stack([out["val_loss_task"] for out in outputs]).mean().item(),
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

    def test_step(self, batch, batch_idx):
        """Synthesize evaluation waveforms, including the degraded task wav."""
        if self.config["general"]["use_gst"]:
            enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3))
            chfeats = self.gst(batch["melspecs"].transpose(1, 2))
        else:
            enc_out, enc_hidden = self.encoder(
                batch["melspecs"].unsqueeze(1).transpose(2, 3)
            )
            chfeats = self.channelfeats(enc_hidden)
        enc_out = enc_out.squeeze(1).transpose(1, 2)
        if self.config["general"]["feature_type"] == "melspec":
            enc_feats = enc_out
        elif self.config["general"]["feature_type"] == "vocfeats":
            enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1)
        remas = self.vocoder(enc_feats).squeeze(1)
        wavsdeg = self.channel(remas, chfeats)
        if self.config["general"]["feature_type"] == "melspec":
            enc_feats_input = batch["melspecs"]
        elif self.config["general"]["feature_type"] == "vocfeats":
            enc_feats_input = torch.cat(
                (batch["f0s"].unsqueeze(1), batch["melcepssrc"]), dim=1
            )
        input_recons = self.vocoder(enc_feats_input).squeeze(1)
        wavsdegtask = self.channel(batch["wavstask"], chfeats)
        if "wavsaux" in batch:
            gt_wav = batch["wavsaux"]
            if self.config["general"]["feature_type"] == "melspec":
                enc_feats_aux = batch["melspecsaux"]
            elif self.config["general"]["feature_type"] == "vocfeats":
                enc_feats_aux = torch.cat(
                    (batch["f0s"].unsqueeze(1), batch["melceps"]), dim=1
                )
            recons_wav = self.vocoder(enc_feats_aux).squeeze(1)
        else:
            gt_wav = None
            recons_wav = None
        return {
            "reconstructed": recons_wav,
            "remastered": remas,
            "channeled": wavsdeg,
            "channeled_task": wavsdegtask,
            "input": batch["wavs"],
            "input_recons": input_recons,
            "groundtruth": gt_wav,
        }

    def test_epoch_end(self, outputs):
        """Save the first waveform of every test batch under the log dir."""
        wav_dir = (
            pathlib.Path(self.logger.experiment[0].log_dir).parent.parent / "test_wavs"
        )
        os.makedirs(wav_dir, exist_ok=True)
        mel_dir = (
            pathlib.Path(self.logger.experiment[0].log_dir).parent.parent / "test_mels"
        )
        os.makedirs(mel_dir, exist_ok=True)
        print("Saving mel spectrogram plots ...")
        for idx, out in enumerate(tqdm.tqdm(outputs)):
            for key in [
                "reconstructed",
                "remastered",
                "channeled",
                "channeled_task",
                "input",
                "input_recons",
                "groundtruth",
            ]:
                # Identity check (was ``!= None``): entries may legitimately be
                # None, and ``!=`` on tensors has elementwise semantics.
                if out[key] is not None:
                    torchaudio.save(
                        wav_dir / "{}-{}.wav".format(idx, key),
                        out[key][0, ...].unsqueeze(0).cpu(),
                        sample_rate=self.config["preprocess"]["sampling_rate"],
                        channels_first=True,
                    )

    def configure_optimizers(self):
        """Single Adam over all parameters + plateau scheduler on val_loss."""
        optimizer = torch.optim.Adam(
            self.parameters(), lr=self.config["train"]["learning_rate"]
        )
        lr_scheduler_config = {
            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, mode="min", factor=0.5, min_lr=1e-5, verbose=True
            ),
            "interval": "epoch",
            "frequency": 3,
            "monitor": "val_loss",
        }
        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}

    def calc_spectrogram(self, wav):
        """Log-compressed mel spectrogram matching the offline preprocessing."""
        specs = self.spec_module(wav)
        log_spec = torch.log(
            torch.clamp_min(specs, self.config["preprocess"]["min_magnitude"])
            * self.config["preprocess"]["comp_factor"]
        ).to(torch.float32)
        return log_spec

    def get_loss_function(self, config):
        """Return (waveform loss, feature loss) per the configured type."""
        if config["train"]["feature_loss"]["type"] == "mae":
            feature_loss = torch.nn.L1Loss()
        else:
            feature_loss = torch.nn.MSELoss()
        return MultiScaleSpectralLoss(config), feature_loss