from typing import Dict, List, Optional

import pytorch_lightning as pl
import torch
import torchaudio as ta
from torch import nn

from .bandsplit import BandSplitModule
from .maskestim import OverlappingMaskEstimationModule
from .tfmodel import SeqBandModellingModule
from .utils import MusicalBandsplitSpecification


class BaseEndToEndModule(pl.LightningModule):
    """Minimal LightningModule base class for end-to-end separation models."""

    def __init__(self) -> None:
        super().__init__()


class BaseBandit(BaseEndToEndModule):
    """Shared skeleton of the Bandit band-split separation model.

    Builds the three front-end stages — STFT/iSTFT transforms, the band-split
    encoder, and the time-frequency (sequence) modelling stack.  Mask
    estimation and the concrete :meth:`separate` logic are provided by
    subclasses (see :class:`Bandit`).
    """

    def __init__(
        self,
        in_channels: int,
        fs: int,
        band_type: str = "musical",
        n_bands: int = 64,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
    ):
        """Instantiate the spectral transforms, band split, and TF model.

        Args:
            in_channels: Number of audio channels fed to the band split.
            fs: Sampling rate used to lay out the musical band boundaries.
            band_type: Band-split scheme; only ``"musical"`` is supported.
            n_bands: Number of frequency bands in the split.
            n_sqm_modules: Number of stacked sequence-band-modelling modules.
            emb_dim / rnn_dim / bidirectional / rnn_type: TF-model settings.
            n_fft .. onesided: Passed through to torchaudio's Spectrogram.
        """
        super().__init__()

        self.in_channels = in_channels

        # NOTE: the misspelled method name is the established interface
        # (subclasses may override it), so it is kept; a correctly spelled
        # alias is defined below.
        self.instantitate_spectral(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            normalized=normalized,
            center=center,
            pad_mode=pad_mode,
            onesided=onesided,
        )

        self.instantiate_bandsplit(
            in_channels=in_channels,
            band_type=band_type,
            n_bands=n_bands,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
            n_fft=n_fft,
            fs=fs,
        )

        self.instantiate_tf_modelling(
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
        )

    def instantitate_spectral(
        self,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        normalized: bool = True,
        center: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
    ):
        """Build the forward STFT and inverse STFT transforms.

        ``power`` must be ``None`` so the Spectrogram returns a complex
        spectrogram (required for complex masking and for the inverse STFT).
        """
        assert power is None

        # Resolve the window-function name (e.g. "hann_window") to the
        # corresponding torch callable.
        window_callable = getattr(torch, window_fn)

        self.stft = ta.transforms.Spectrogram(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            pad_mode=pad_mode,
            pad=0,
            window_fn=window_callable,
            wkwargs=wkwargs,
            power=power,
            normalized=normalized,
            center=center,
            onesided=onesided,
        )

        self.istft = ta.transforms.InverseSpectrogram(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            pad_mode=pad_mode,
            pad=0,
            window_fn=window_callable,
            wkwargs=wkwargs,
            normalized=normalized,
            center=center,
            onesided=onesided,
        )

    # Correctly spelled alias; the misspelled original name is retained
    # above for backward compatibility.
    instantiate_spectral = instantitate_spectral

    def instantiate_bandsplit(
        self,
        in_channels: int,
        band_type: str = "musical",
        n_bands: int = 64,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        emb_dim: int = 128,
        n_fft: int = 2048,
        fs: int = 44100,
    ):
        """Build the musical band specification and the band-split encoder."""
        assert band_type == "musical"

        self.band_specs = MusicalBandsplitSpecification(
            nfft=n_fft, fs=fs, n_bands=n_bands
        )

        self.band_split = BandSplitModule(
            in_channels=in_channels,
            band_specs=self.band_specs.get_band_specs(),
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )

    def instantiate_tf_modelling(
        self,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
    ):
        """Build the stacked sequence-band-modelling (TF) module.

        ``torch.compile`` is invoked with ``disable=True`` (a no-op wrapper);
        the try/except keeps construction working on torch builds where
        ``torch.compile`` itself is unavailable or raises.
        """
        tf_model = SeqBandModellingModule(
            n_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
        )
        try:
            self.tf_model = torch.compile(tf_model, disable=True)
        except Exception:
            self.tf_model = tf_model

    def mask(self, x, m):
        """Apply an (optionally complex) mask to the mixture spectrogram."""
        return x * m

    def forward(self, batch, mode="train"):
        """Separate a mixture into per-stem audio estimates.

        The model processes channels independently: a raw tensor input
        (assumed shape ``(batch, channels, time)`` — TODO confirm against
        callers) is flattened to ``(batch * channels, 1, time)`` mono, run
        through separation, then each stem is reshaped back and the stems are
        stacked into ``(batch, n_stems, channels, time)``.

        A dict input (already in the ``{"mixture": {"audio": ...}}`` layout)
        is passed through and the populated batch dict is returned.
        """
        input_was_tensor = not isinstance(batch, dict)
        if input_was_tensor:
            # BUGFIX: .shape was previously read before the isinstance check,
            # crashing on dict inputs.
            init_shape = batch.shape
            mono = batch.view(-1, 1, batch.shape[-1])
            batch = {"mixture": {"audio": mono}}

        # Spectrograms are fixed transforms; no gradients needed here.
        with torch.no_grad():
            mixture = batch["mixture"]["audio"]
            batch["mixture"]["spectrogram"] = self.stft(mixture)

            if "sources" in batch:
                for stem in batch["sources"]:
                    s = batch["sources"][stem]["audio"]
                    batch["sources"][stem]["spectrogram"] = self.stft(s)

        batch = self.separate(batch)

        if input_was_tensor:
            # Undo the mono flattening per stem, then stack stems along dim 1.
            # NOTE: relies on self.stems, which is defined by subclasses.
            stacked = [
                batch["estimates"][stem]["audio"].view(
                    -1, init_shape[1], init_shape[2]
                )
                for stem in self.stems
            ]
            return torch.stack(stacked, dim=1)

        return batch

    def encode(self, batch):
        """Run the band-split encoder and TF model on the mixture.

        Returns:
            Tuple of (mixture spectrogram, TF-model embedding of shape
            (batch, emb_dim, n_band, n_time), original audio length).
        """
        x = batch["mixture"]["spectrogram"]
        length = batch["mixture"]["audio"].shape[-1]

        z = self.band_split(x)  # (batch, emb_dim, n_band, n_time)
        q = self.tf_model(z)  # (batch, emb_dim, n_band, n_time)

        return x, q, length

    def separate(self, batch):
        """Populate ``batch["estimates"]``; implemented by subclasses."""
        raise NotImplementedError


class Bandit(BaseBandit):
    """Bandit model: band-split encoder + TF model + per-stem mask estimation.

    One :class:`OverlappingMaskEstimationModule` is built per stem; masks are
    applied to the mixture spectrogram and inverted back to audio.
    """

    def __init__(
        self,
        in_channels: int,
        stems: List[str],
        band_type: str = "musical",
        n_bands: int = 64,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        use_freq_weights: bool = True,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        fs: int = 44100,
        # NOTE(review): the four *_precisions arguments are accepted for
        # config compatibility but are not used anywhere in this class.
        stft_precisions="32",
        bandsplit_precisions="bf16",
        tf_model_precisions="bf16",
        mask_estim_precisions="bf16",
    ):
        """Build the full model: base stages plus per-stem mask estimators.

        Args:
            stems: Names of the sources to estimate; one mask-estimation
                module is created per stem.
            mlp_dim / hidden_activation / hidden_activation_kwargs /
                complex_mask / use_freq_weights: Mask-estimator settings.
            Remaining arguments are forwarded to :class:`BaseBandit`.
        """
        super().__init__(
            in_channels=in_channels,
            band_type=band_type,
            n_bands=n_bands,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
            fs=fs,
        )

        self.stems = stems

        self.instantiate_mask_estim(
            in_channels=in_channels,
            stems=stems,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            # One-sided spectrogram has n_fft // 2 + 1 frequency bins.
            n_freq=n_fft // 2 + 1,
            use_freq_weights=use_freq_weights,
        )

    def instantiate_mask_estim(
        self,
        in_channels: int,
        stems: List[str],
        emb_dim: int,
        mlp_dim: int,
        hidden_activation: str,
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        n_freq: Optional[int] = None,
        use_freq_weights: bool = False,
    ):
        """Create one overlapping mask-estimation module per stem."""
        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}

        assert n_freq is not None

        self.mask_estim = nn.ModuleDict(
            {
                stem: OverlappingMaskEstimationModule(
                    band_specs=self.band_specs.get_band_specs(),
                    freq_weights=self.band_specs.get_freq_weights(),
                    n_freq=n_freq,
                    emb_dim=emb_dim,
                    mlp_dim=mlp_dim,
                    in_channels=in_channels,
                    hidden_activation=hidden_activation,
                    hidden_activation_kwargs=hidden_activation_kwargs,
                    complex_mask=complex_mask,
                    use_freq_weights=use_freq_weights,
                )
                for stem in stems
            }
        )

    def separate(self, batch):
        """Estimate a mask per stem, apply it, and invert back to audio.

        Fills ``batch["estimates"][stem]`` with the masked spectrogram and
        the corresponding time-domain audio (trimmed to the mixture length).
        """
        batch["estimates"] = {}

        x, q, length = self.encode(batch)

        for stem, estimator in self.mask_estim.items():
            m = estimator(q)
            # Cast the mask to the spectrogram dtype before multiplying
            # (e.g. complex mask onto a complex spectrogram).
            s = self.mask(x, m.to(x.dtype))
            s = torch.reshape(s, x.shape)
            batch["estimates"][stem] = {
                "audio": self.istft(s, length),
                "spectrogram": s,
            }

        return batch