| from typing import Any, Callable, Dict
|
| import random
|
| try:
|
| import lightning.pytorch as pl
|
| except ImportError:
|
| import pytorch_lightning as pl
|
| import torch
|
| import torch.nn as nn
|
| import torch.optim as optim
|
| from torch.optim.lr_scheduler import LambdaLR
|
|
|
|
|
class AudioSep(pl.LightningModule):
    def __init__(
        self,
        ss_model: nn.Module,
        waveform_mixer,
        query_encoder,
        loss_function,
        optimizer_type: str,
        learning_rate: float,
        lr_lambda_func,
        use_text_ratio=1.0,
    ):
        r"""PyTorch Lightning wrapper of a source-separation model, including
        the training step and optimizer configuration.

        Args:
            ss_model: nn.Module, separation network; called as
                ``ss_model(input_dict)`` and expected to return a dict with a
                ``'waveform'`` entry.
            waveform_mixer: callable producing ``(mixtures, segments)`` from a
                batch of waveforms.
            query_encoder: query encoder exposing ``encoder_type`` and
                ``get_query_embed``.
            loss_function: callable ``loss_function(output_dict, target_dict)``.
            optimizer_type: str, currently only ``"AdamW"`` is supported.
            learning_rate: float, base learning rate.
            lr_lambda_func: multiplier function handed to ``LambdaLR``.
            use_text_ratio: float in [0, 1], fraction of queries conditioned on
                text (vs. audio) when the query encoder is CLAP.
        """
        super().__init__()
        self.ss_model = ss_model
        self.waveform_mixer = waveform_mixer
        self.query_encoder = query_encoder
        self.query_encoder_type = self.query_encoder.encoder_type
        self.use_text_ratio = use_text_ratio
        self.loss_function = loss_function
        self.optimizer_type = optimizer_type
        self.learning_rate = learning_rate
        self.lr_lambda_func = lr_lambda_func

    def forward(self, x):
        # Inference is performed through self.ss_model directly; the Lightning
        # forward hook is intentionally unused.
        pass

    def training_step(self, batch_data_dict, batch_idx):
        r"""Forward a mini-batch through the model and compute the loss.

        Args:
            batch_data_dict: e.g.
                'audio_text': {
                    'text': ['a sound of dog', ...]
                    'waveform': (batch_size, 1, samples)
                }
            batch_idx: int

        Returns:
            loss: scalar tensor, loss of this mini-batch
        """
        # Re-seed per step so the (random) waveform mixing is reproducible.
        random.seed(batch_idx)

        batch_audio_text_dict = batch_data_dict['audio_text']
        batch_text = batch_audio_text_dict['text']
        batch_audio = batch_audio_text_dict['waveform']

        # Mix source segments into training mixtures.
        mixtures, segments = self.waveform_mixer(
            waveforms=batch_audio
        )

        if self.query_encoder_type == 'CLAP':
            conditions = self.query_encoder.get_query_embed(
                # NOTE: 'hybird' (sic) is the exact key the encoder matches
                # against — do not "fix" the spelling here alone.
                modality='hybird',
                text=batch_text,
                audio=segments.squeeze(1),
                use_text_ratio=self.use_text_ratio,
            )
        else:
            # Previously this fell through to a NameError on `conditions`;
            # fail loudly with a clear message instead.
            raise NotImplementedError(
                f"Unsupported query encoder type: {self.query_encoder_type}"
            )

        input_dict = {
            # The original `mixtures[:, None, :].squeeze(1)` inserts a dim and
            # immediately removes it — a net no-op, so pass mixtures directly.
            'mixture': mixtures,
            'condition': conditions,
        }

        target_dict = {
            'segment': segments.squeeze(1),
        }

        self.ss_model.train()
        sep_segment = self.ss_model(input_dict)['waveform']
        # Squeeze only the channel dim: a plain .squeeze() would also drop the
        # batch axis when batch_size == 1. Assumes output is (B, 1, T) — TODO
        # confirm against the separation model.
        sep_segment = sep_segment.squeeze(1)

        output_dict = {
            'segment': sep_segment,
        }

        loss = self.loss_function(output_dict, target_dict)

        self.log_dict({"train_loss": loss})

        return loss

    def test_step(self, batch, batch_idx):
        # Evaluation is handled outside the Lightning test loop.
        pass

    def configure_optimizers(self):
        r"""Configure the optimizer and a per-step LambdaLR schedule.

        Returns:
            dict with ``"optimizer"`` and a step-interval ``"lr_scheduler"``.

        Raises:
            NotImplementedError: if ``self.optimizer_type`` is not "AdamW".
        """
        if self.optimizer_type == "AdamW":
            optimizer = optim.AdamW(
                params=self.ss_model.parameters(),
                lr=self.learning_rate,
                betas=(0.9, 0.999),
                eps=1e-08,
                weight_decay=0.0,
                amsgrad=True,
            )
        else:
            raise NotImplementedError(
                f"Unsupported optimizer type: {self.optimizer_type}"
            )

        scheduler = LambdaLR(optimizer, self.lr_lambda_func)

        output_dict = {
            "optimizer": optimizer,
            "lr_scheduler": {
                'scheduler': scheduler,
                'interval': 'step',  # step the schedule every optimizer step
                'frequency': 1,
            }
        }

        return output_dict
|
|
|
|
|
def get_model_class(model_type):
    r"""Resolve a separation-model class from its name.

    Args:
        model_type: str, model identifier; currently only 'ResUNet30'.

    Returns:
        The model class (imported lazily, so importing this module does not
        require the model package).

    Raises:
        NotImplementedError: if ``model_type`` is not recognized.
    """
    if model_type == 'ResUNet30':
        from models.resunet import ResUNet30
        return ResUNet30

    # Name the rejected value so misconfiguration is easy to diagnose.
    raise NotImplementedError(f"Unknown model type: {model_type}")
|
|
|