# Snapshot of upstream revision 64ec292.
from typing import Dict, List, Optional
import pytorch_lightning as pl
import torch
import torchaudio as ta
from torch import nn
from .bandsplit import BandSplitModule
from .maskestim import OverlappingMaskEstimationModule
from .tfmodel import SeqBandModellingModule
from .utils import MusicalBandsplitSpecification
class BaseEndToEndModule(pl.LightningModule):
    """Minimal LightningModule root for the end-to-end separation models below.

    Exists purely as a common ancestor; it adds no state or behavior beyond
    ``pl.LightningModule`` itself.
    """

    def __init__(self) -> None:
        super().__init__()
class BaseBandit(BaseEndToEndModule):
    """Band-split source-separation backbone ("Bandit").

    Pipeline: STFT -> per-band embedding (``BandSplitModule``) ->
    time-frequency sequence modelling (``SeqBandModellingModule``).
    Subclasses must set ``self.stems`` and implement :meth:`separate`
    (which is expected to fill ``batch["estimates"]``).
    """

    def __init__(
        self,
        in_channels: int,
        fs: int,
        band_type: str = "musical",
        n_bands: int = 64,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
    ):
        """Build the STFT pair, band-split module and TF model.

        Args:
            in_channels: number of audio channels fed to the band-split stage.
            fs: sample rate used to lay out the musical bands.
            band_type: only ``"musical"`` is supported.
            n_fft/win_length/hop_length/...: forwarded to torchaudio's
                ``Spectrogram`` / ``InverseSpectrogram``.
        """
        super().__init__()
        self.in_channels = in_channels
        # NOTE(review): method name keeps its historical misspelling
        # ("instantitate") so subclasses that override it keep working.
        self.instantitate_spectral(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            normalized=normalized,
            center=center,
            pad_mode=pad_mode,
            onesided=onesided,
        )
        self.instantiate_bandsplit(
            in_channels=in_channels,
            band_type=band_type,
            n_bands=n_bands,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
            n_fft=n_fft,
            fs=fs,
        )
        self.instantiate_tf_modelling(
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
        )

    def instantitate_spectral(
        self,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        normalized: bool = True,
        center: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
    ):
        """Create matching ``self.stft`` / ``self.istft`` transforms.

        ``power`` must be ``None`` so the STFT stays complex (the complex
        spectrogram is required for masking and for the inverse transform).
        """
        assert power is None
        # Resolve the window factory by name, e.g. "hann_window" -> torch.hann_window.
        window_fn = torch.__dict__[window_fn]
        self.stft = ta.transforms.Spectrogram(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            pad_mode=pad_mode,
            pad=0,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            normalized=normalized,
            center=center,
            onesided=onesided,
        )
        self.istft = ta.transforms.InverseSpectrogram(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            pad_mode=pad_mode,
            pad=0,
            window_fn=window_fn,
            wkwargs=wkwargs,
            normalized=normalized,
            center=center,
            onesided=onesided,
        )

    def instantiate_bandsplit(
        self,
        in_channels: int,
        band_type: str = "musical",
        n_bands: int = 64,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        emb_dim: int = 128,
        n_fft: int = 2048,
        fs: int = 44100,
    ):
        """Create ``self.band_specs`` and ``self.band_split``.

        Only the "musical" band layout is implemented.
        """
        assert band_type == "musical"
        self.band_specs = MusicalBandsplitSpecification(
            nfft=n_fft, fs=fs, n_bands=n_bands
        )
        self.band_split = BandSplitModule(
            in_channels=in_channels,
            band_specs=self.band_specs.get_band_specs(),
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )

    def instantiate_tf_modelling(
        self,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
    ):
        """Create ``self.tf_model``, wrapped in (disabled) ``torch.compile``.

        ``torch.compile`` may be unavailable on older torch builds, hence the
        best-effort fallback to the bare module.
        """
        try:
            self.tf_model = torch.compile(
                SeqBandModellingModule(
                    n_modules=n_sqm_modules,
                    emb_dim=emb_dim,
                    rnn_dim=rnn_dim,
                    bidirectional=bidirectional,
                    rnn_type=rnn_type,
                ),
                disable=True,
            )
        except Exception:
            self.tf_model = SeqBandModellingModule(
                n_modules=n_sqm_modules,
                emb_dim=emb_dim,
                rnn_dim=rnn_dim,
                bidirectional=bidirectional,
                rnn_type=rnn_type,
            )

    def mask(self, x, m):
        """Apply (complex) mask ``m`` to spectrogram ``x`` elementwise."""
        return x * m

    def forward(self, batch, mode="train"):
        """Separate ``batch`` into per-stem estimates.

        ``batch`` is either a waveform tensor — assumed (batch, channels,
        time); TODO confirm against callers — or an already-built batch dict
        containing ``batch["mixture"]["audio"]``.  For tensor input each
        channel is processed independently as mono and the result is
        re-stacked to (batch, n_stems, channels, time); for dict input the
        dict returned by :meth:`separate` is passed through.
        """
        # BUG FIX: ``init_shape`` was previously read before the isinstance
        # check, so dict input crashed on ``dict.shape``.  It is now computed
        # only on the tensor path (the only path that uses it).
        tensor_input = not isinstance(batch, dict)
        if tensor_input:
            # Model takes mono input; process each channel independently.
            init_shape = batch.shape
            mono = batch.view(-1, 1, batch.shape[-1])
            batch = {"mixture": {"audio": mono}}

        with torch.no_grad():
            mixture = batch["mixture"]["audio"]
            batch["mixture"]["spectrogram"] = self.stft(mixture)
            if "sources" in batch.keys():
                for stem in batch["sources"].keys():
                    s = batch["sources"][stem]["audio"]
                    batch["sources"][stem]["spectrogram"] = self.stft(s)

        batch = self.separate(batch)

        if tensor_input:
            # Undo the mono view (restore the original channel dim) and return
            # one stacked tensor instead of per-stem dict entries.
            # ``self.stems`` is provided by the concrete subclass.
            estimates = [
                batch["estimates"][stem]["audio"].view(
                    -1, init_shape[1], init_shape[2]
                )
                for stem in self.stems
            ]
            batch = torch.stack(estimates, dim=1)
        return batch

    def encode(self, batch):
        """Embed the mixture spectrogram; return (spec, embedding, n_samples)."""
        x = batch["mixture"]["spectrogram"]
        length = batch["mixture"]["audio"].shape[-1]
        z = self.band_split(x)  # (batch, emb_dim, n_band, n_time)
        q = self.tf_model(z)  # (batch, emb_dim, n_band, n_time)
        return x, q, length

    def separate(self, batch):
        """Fill ``batch["estimates"]``; implemented by subclasses."""
        raise NotImplementedError
class Bandit(BaseBandit):
    """Concrete Bandit separator: one overlapping mask estimator per stem.

    Extends :class:`BaseBandit` with per-stem mask estimation and implements
    :meth:`separate`.
    """

    def __init__(
        self,
        in_channels: int,
        stems: List[str],
        band_type: str = "musical",
        n_bands: int = 64,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Dict | None = None,
        complex_mask: bool = True,
        use_freq_weights: bool = True,
        n_fft: int = 2048,
        win_length: int | None = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Dict | None = None,
        power: int | None = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        fs: int = 44100,
        stft_precisions="32",
        bandsplit_precisions="bf16",
        tf_model_precisions="bf16",
        mask_estim_precisions="bf16",
    ):
        """Build the backbone, record the stem names, and attach one mask
        estimator per stem.

        The ``*_precisions`` arguments are accepted but not used here;
        presumably consumed by external config/tooling — verify against
        callers before removing.
        """
        super().__init__(
            in_channels=in_channels,
            band_type=band_type,
            n_bands=n_bands,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
            fs=fs,
        )
        self.stems = stems
        self.instantiate_mask_estim(
            in_channels=in_channels,
            stems=stems,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            n_freq=n_fft // 2 + 1,
            use_freq_weights=use_freq_weights,
        )

    def instantiate_mask_estim(
        self,
        in_channels: int,
        stems: List[str],
        emb_dim: int,
        mlp_dim: int,
        hidden_activation: str,
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        n_freq: Optional[int] = None,
        use_freq_weights: bool = False,
    ):
        """Create ``self.mask_estim``: one estimator per stem, sharing the
        band layout computed by the backbone."""
        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}
        assert n_freq is not None
        estimators = {}
        for stem in stems:
            estimators[stem] = OverlappingMaskEstimationModule(
                band_specs=self.band_specs.get_band_specs(),
                freq_weights=self.band_specs.get_freq_weights(),
                n_freq=n_freq,
                emb_dim=emb_dim,
                mlp_dim=mlp_dim,
                in_channels=in_channels,
                hidden_activation=hidden_activation,
                hidden_activation_kwargs=hidden_activation_kwargs,
                complex_mask=complex_mask,
                use_freq_weights=use_freq_weights,
            )
        self.mask_estim = nn.ModuleDict(estimators)

    def separate(self, batch):
        """Estimate a mask per stem, apply it, and decode back to audio.

        Fills ``batch["estimates"][stem]`` with both the masked spectrogram
        and its inverse STFT.
        """
        batch["estimates"] = {}
        spec, embedding, n_samples = self.encode(batch)
        for name, estimator in self.mask_estim.items():
            mask = estimator(embedding)
            masked = self.mask(spec, mask.to(spec.dtype))
            masked = torch.reshape(masked, spec.shape)
            batch["estimates"][name] = {
                "audio": self.istft(masked, n_samples),
                "spectrogram": masked,
            }
        return batch