xjsc0's picture
1
64ec292
# https://github.com/Human9000/nd-Mamba2-torch
from __future__ import print_function
import numpy as np
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential
try:
from mamba_ssm.modules.mamba2 import Mamba2
except Exception as e:
print("Exception during load Mamba2 modules: {}".format(str(e)))
print("Load local torch implementation!")
from .ex_bi_mamba2 import Mamba2
class MambaBlock(nn.Module):
def __init__(self, in_channels):
super(MambaBlock, self).__init__()
self.forward_mamba2 = Mamba2(
d_model=in_channels,
d_state=128,
d_conv=4,
expand=4,
headdim=64,
)
self.backward_mamba2 = Mamba2(
d_model=in_channels,
d_state=128,
d_conv=4,
expand=4,
headdim=64,
)
def forward(self, input):
forward_f = input
forward_f_output = self.forward_mamba2(forward_f)
backward_f = torch.flip(input, [1])
backward_f_output = self.backward_mamba2(backward_f)
backward_f_output2 = torch.flip(backward_f_output, [1])
output = torch.cat([forward_f_output + input, backward_f_output2 + input], -1)
return output
class TAC(nn.Module):
"""
A transform-average-concatenate (TAC) module.
"""
def __init__(self, input_size, hidden_size):
super(TAC, self).__init__()
self.input_size = input_size
self.eps = torch.finfo(torch.float32).eps
self.input_norm = nn.GroupNorm(1, input_size, self.eps)
self.TAC_input = nn.Sequential(nn.Linear(input_size, hidden_size), nn.Tanh())
self.TAC_mean = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.Tanh())
self.TAC_output = nn.Sequential(
nn.Linear(hidden_size * 2, input_size), nn.Tanh()
)
def forward(self, input):
# input shape: batch, group, N, *
batch_size, G, N = input.shape[:3]
output = self.input_norm(input.view(batch_size * G, N, -1)).view(
batch_size, G, N, -1
)
T = output.shape[-1]
# transform
group_input = output # B, G, N, T
group_input = (
group_input.permute(0, 3, 1, 2).contiguous().view(-1, N)
) # B*T*G, N
group_output = self.TAC_input(group_input).view(
batch_size, T, G, -1
) # B, T, G, H
# mean pooling
group_mean = group_output.mean(2).view(batch_size * T, -1) # B*T, H
group_mean = (
self.TAC_mean(group_mean)
.unsqueeze(1)
.expand(batch_size * T, G, group_mean.shape[-1])
.contiguous()
) # B*T, G, H
# concate
group_output = group_output.view(batch_size * T, G, -1) # B*T, G, H
group_output = torch.cat([group_output, group_mean], 2) # B*T, G, 2H
group_output = self.TAC_output(
group_output.view(-1, group_output.shape[-1])
) # B*T*G, N
group_output = (
group_output.view(batch_size, T, G, -1).permute(0, 2, 3, 1).contiguous()
) # B, G, N, T
output = input + group_output.view(input.shape)
return output
class ResMamba(nn.Module):
def __init__(self, input_size, hidden_size, dropout=0.0, bidirectional=True):
super(ResMamba, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.eps = torch.finfo(torch.float32).eps
self.norm = nn.GroupNorm(1, input_size, self.eps)
self.dropout = nn.Dropout(p=dropout)
self.rnn = MambaBlock(input_size)
self.proj = nn.Linear(input_size * 2, input_size)
# linear projection layer
def forward(self, input):
# input shape: batch, dim, seq
rnn_output = self.rnn(
self.dropout(self.norm(input)).transpose(1, 2).contiguous()
)
rnn_output = self.proj(
rnn_output.contiguous().view(-1, rnn_output.shape[2])
).view(input.shape[0], input.shape[2], input.shape[1])
return input + rnn_output.transpose(1, 2).contiguous()
class BSNet(nn.Module):
def __init__(self, in_channel, nband=7):
super(BSNet, self).__init__()
self.nband = nband
self.feature_dim = in_channel // nband
self.band_rnn = ResMamba(self.feature_dim, self.feature_dim * 2)
self.band_comm = ResMamba(self.feature_dim, self.feature_dim * 2)
self.channel_comm = TAC(self.feature_dim, self.feature_dim * 3)
def forward(self, input):
# input shape: B, nch, nband*N, T
B, nch, N, T = input.shape
band_output = self.band_rnn(
input.view(B * nch * self.nband, self.feature_dim, -1)
).view(B * nch, self.nband, -1, T)
# band comm
band_output = (
band_output.permute(0, 3, 2, 1)
.contiguous()
.view(B * nch * T, -1, self.nband)
)
output = (
self.band_comm(band_output)
.view(B * nch, T, -1, self.nband)
.permute(0, 3, 2, 1)
.contiguous()
)
# channel comm
output = (
output.view(B, nch, self.nband, -1, T)
.transpose(1, 2)
.contiguous()
.view(B * self.nband, nch, -1, T)
)
output = (
self.channel_comm(output)
.view(B, self.nband, nch, -1, T)
.transpose(1, 2)
.contiguous()
)
return output.view(B, nch, N, T)
class Separator(nn.Module):
def __init__(
self,
sr=44100,
win=2048,
stride=512,
feature_dim=128,
num_repeat_mask=8,
num_repeat_map=4,
num_output=4,
):
super(Separator, self).__init__()
self.sr = sr
self.win = win
self.stride = stride
self.group = self.win // 2
self.enc_dim = self.win // 2 + 1
self.feature_dim = feature_dim
self.num_output = num_output
self.eps = torch.finfo(torch.float32).eps
# 0-1k (50 hop), 1k-2k (100 hop), 2k-4k (250 hop), 4k-8k (500 hop), 8k-16k (1k hop), 16k-20k (2k hop), 20k-inf
bandwidth_50 = int(np.floor(50 / (sr / 2.0) * self.enc_dim))
bandwidth_100 = int(np.floor(100 / (sr / 2.0) * self.enc_dim))
bandwidth_250 = int(np.floor(250 / (sr / 2.0) * self.enc_dim))
bandwidth_500 = int(np.floor(500 / (sr / 2.0) * self.enc_dim))
bandwidth_1k = int(np.floor(1000 / (sr / 2.0) * self.enc_dim))
bandwidth_2k = int(np.floor(2000 / (sr / 2.0) * self.enc_dim))
self.band_width = [bandwidth_50] * 20
self.band_width += [bandwidth_100] * 10
self.band_width += [bandwidth_250] * 8
self.band_width += [bandwidth_500] * 8
self.band_width += [bandwidth_1k] * 8
self.band_width += [bandwidth_2k] * 2
self.band_width.append(self.enc_dim - np.sum(self.band_width))
self.nband = len(self.band_width)
print(self.band_width)
self.BN_mask = nn.ModuleList([])
for i in range(self.nband):
self.BN_mask.append(
nn.Sequential(
nn.GroupNorm(1, self.band_width[i] * 2, self.eps),
nn.Conv1d(self.band_width[i] * 2, self.feature_dim, 1),
)
)
self.BN_map = nn.ModuleList([])
for i in range(self.nband):
self.BN_map.append(
nn.Sequential(
nn.GroupNorm(1, self.band_width[i] * 2, self.eps),
nn.Conv1d(self.band_width[i] * 2, self.feature_dim, 1),
)
)
self.separator_mask = []
for i in range(num_repeat_mask):
self.separator_mask.append(BSNet(self.nband * self.feature_dim, self.nband))
self.separator_mask = nn.Sequential(*self.separator_mask)
self.separator_map = []
for i in range(num_repeat_map):
self.separator_map.append(BSNet(self.nband * self.feature_dim, self.nband))
self.separator_map = nn.Sequential(*self.separator_map)
self.in_conv = nn.Conv1d(self.feature_dim * 2, self.feature_dim, 1)
self.Tanh = nn.Tanh()
self.mask = nn.ModuleList([])
self.map = nn.ModuleList([])
for i in range(self.nband):
self.mask.append(
nn.Sequential(
nn.GroupNorm(1, self.feature_dim, torch.finfo(torch.float32).eps),
nn.Conv1d(
self.feature_dim, self.feature_dim * 1 * self.num_output, 1
),
nn.Tanh(),
nn.Conv1d(
self.feature_dim * 1 * self.num_output,
self.feature_dim * 1 * self.num_output,
1,
groups=self.num_output,
),
nn.Tanh(),
nn.Conv1d(
self.feature_dim * 1 * self.num_output,
self.band_width[i] * 4 * self.num_output,
1,
groups=self.num_output,
),
)
)
self.map.append(
nn.Sequential(
nn.GroupNorm(1, self.feature_dim, torch.finfo(torch.float32).eps),
nn.Conv1d(
self.feature_dim, self.feature_dim * 1 * self.num_output, 1
),
nn.Tanh(),
nn.Conv1d(
self.feature_dim * 1 * self.num_output,
self.feature_dim * 1 * self.num_output,
1,
groups=self.num_output,
),
nn.Tanh(),
nn.Conv1d(
self.feature_dim * 1 * self.num_output,
self.band_width[i] * 4 * self.num_output,
1,
groups=self.num_output,
),
)
)
def pad_input(self, input, window, stride):
"""
Zero-padding input according to window/stride size.
"""
batch_size, nsample = input.shape
# pad the signals at the end for matching the window/stride size
rest = window - (stride + nsample % window) % window
if rest > 0:
pad = torch.zeros(batch_size, rest).type(input.type())
input = torch.cat([input, pad], 1)
pad_aux = torch.zeros(batch_size, stride).type(input.type())
input = torch.cat([pad_aux, input, pad_aux], 1)
return input, rest
def forward(self, input):
# input shape: (B, C, T)
batch_size, nch, nsample = input.shape
input = input.view(batch_size * nch, -1)
# frequency-domain separation
spec = torch.stft(
input,
n_fft=self.win,
hop_length=self.stride,
window=torch.hann_window(self.win).to(input.device).type(input.type()),
return_complex=True,
)
# concat real and imag, split to subbands
spec_RI = torch.stack([spec.real, spec.imag], 1) # B*nch, 2, F, T
subband_spec_RI = []
subband_spec = []
band_idx = 0
for i in range(len(self.band_width)):
subband_spec_RI.append(
spec_RI[:, :, band_idx : band_idx + self.band_width[i]].contiguous()
)
subband_spec.append(
spec[:, band_idx : band_idx + self.band_width[i]]
) # B*nch, BW, T
band_idx += self.band_width[i]
# normalization and bottleneck
subband_feature_mask = []
for i in range(len(self.band_width)):
subband_feature_mask.append(
self.BN_mask[i](
subband_spec_RI[i].view(
batch_size * nch, self.band_width[i] * 2, -1
)
)
)
subband_feature_mask = torch.stack(subband_feature_mask, 1) # B, nband, N, T
subband_feature_map = []
for i in range(len(self.band_width)):
subband_feature_map.append(
self.BN_map[i](
subband_spec_RI[i].view(
batch_size * nch, self.band_width[i] * 2, -1
)
)
)
subband_feature_map = torch.stack(subband_feature_map, 1) # B, nband, N, T
# separator
sep_output = checkpoint_sequential(
self.separator_mask,
2,
subband_feature_mask.view(
batch_size, nch, self.nband * self.feature_dim, -1
),
) # B, nband*N, T
sep_output = sep_output.view(batch_size * nch, self.nband, self.feature_dim, -1)
combined = torch.cat((subband_feature_map, sep_output), dim=2)
combined1 = combined.reshape(
batch_size * nch * self.nband, self.feature_dim * 2, -1
)
combined2 = self.Tanh(self.in_conv(combined1))
combined3 = combined2.reshape(
batch_size * nch, self.nband, self.feature_dim, -1
)
sep_output2 = checkpoint_sequential(
self.separator_map,
2,
combined3.view(batch_size, nch, self.nband * self.feature_dim, -1),
) # 1B, nband*N, T
sep_output2 = sep_output2.view(
batch_size * nch, self.nband, self.feature_dim, -1
)
sep_subband_spec = []
sep_subband_spec_mask = []
for i in range(self.nband):
this_output = self.mask[i](sep_output[:, i]).view(
batch_size * nch, 2, 2, self.num_output, self.band_width[i], -1
)
this_mask = this_output[:, 0] * torch.sigmoid(
this_output[:, 1]
) # B*nch, 2, K, BW, T
this_mask_real = this_mask[:, 0] # B*nch, K, BW, T
this_mask_imag = this_mask[:, 1] # B*nch, K, BW, T
# force mask sum to 1
this_mask_real_sum = this_mask_real.sum(1).unsqueeze(1) # B*nch, 1, BW, T
this_mask_imag_sum = this_mask_imag.sum(1).unsqueeze(1) # B*nch, 1, BW, T
this_mask_real = this_mask_real - (this_mask_real_sum - 1) / self.num_output
this_mask_imag = this_mask_imag - this_mask_imag_sum / self.num_output
est_spec_real = (
subband_spec[i].real.unsqueeze(1) * this_mask_real
- subband_spec[i].imag.unsqueeze(1) * this_mask_imag
) # B*nch, K, BW, T
est_spec_imag = (
subband_spec[i].real.unsqueeze(1) * this_mask_imag
+ subband_spec[i].imag.unsqueeze(1) * this_mask_real
) # B*nch, K, BW, T
##################################
this_output2 = self.map[i](sep_output2[:, i]).view(
batch_size * nch, 2, 2, self.num_output, self.band_width[i], -1
)
this_map = this_output2[:, 0] * torch.sigmoid(
this_output2[:, 1]
) # B*nch, 2, K, BW, T
this_map_real = this_map[:, 0] # B*nch, K, BW, T
this_map_imag = this_map[:, 1] # B*nch, K, BW, T
est_spec_real2 = est_spec_real + this_map_real
est_spec_imag2 = est_spec_imag + this_map_imag
sep_subband_spec.append(torch.complex(est_spec_real2, est_spec_imag2))
sep_subband_spec_mask.append(torch.complex(est_spec_real, est_spec_imag))
sep_subband_spec = torch.cat(sep_subband_spec, 2)
est_spec_mask = torch.cat(sep_subband_spec_mask, 2)
output = torch.istft(
sep_subband_spec.view(batch_size * nch * self.num_output, self.enc_dim, -1),
n_fft=self.win,
hop_length=self.stride,
window=torch.hann_window(self.win).to(input.device).type(input.type()),
length=nsample,
)
output_mask = torch.istft(
est_spec_mask.view(batch_size * nch * self.num_output, self.enc_dim, -1),
n_fft=self.win,
hop_length=self.stride,
window=torch.hann_window(self.win).to(input.device).type(input.type()),
length=nsample,
)
output = (
output.view(batch_size, nch, self.num_output, -1)
.transpose(1, 2)
.contiguous()
)
output_mask = (
output_mask.view(batch_size, nch, self.num_output, -1)
.transpose(1, 2)
.contiguous()
)
# return output, output_mask
return output
if __name__ == "__main__":
model = Separator().cuda()
arr = np.zeros((1, 2, 3 * 44100), dtype=np.float32)
x = torch.from_numpy(arr).cuda()
res = model(x)