| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| from .base_model import BaseModel |
|
|
class RMSNorm(nn.Module):
    """Grouped root-mean-square normalization for 3-D tensors.

    The channel axis (dim 1) is split into ``groups`` equal chunks. Each chunk
    is divided by the RMS taken over its channels, independently for every
    batch item and every position on the last axis, and the result is scaled
    by a learnable per-channel gain initialised to one.
    """

    def __init__(self, dimension, groups=1):
        super().__init__()

        self.eps = 1e-5
        self.groups = groups
        self.weight = nn.Parameter(torch.ones(dimension))

    def forward(self, input):
        """Normalize ``input`` of shape (B, N, T); the result has the same shape."""
        batch, channels, frames = input.shape
        assert channels % self.groups == 0

        # Statistics are computed in float32 and cast back to the input dtype.
        grouped = input.reshape(batch, self.groups, -1, frames).float()
        mean_square = grouped.pow(2).mean(-2, keepdim=True)
        normalized = grouped * torch.rsqrt(mean_square + self.eps)

        gain = self.weight.reshape(1, -1, 1)
        return normalized.type_as(input).reshape(batch, channels, frames) * gain
| |
class RMVN(nn.Module):
    """
    Rescaled MVN.

    Channels (dim 1) are split into ``groups`` chunks. Each chunk is shifted to
    zero mean and scaled to unit variance across its channels, then a learnable
    per-channel scale (``std``) and offset (``mean``) are applied.
    """

    def __init__(self, dimension, groups=1):
        super().__init__()

        self.mean = nn.Parameter(torch.zeros(dimension))
        self.std = nn.Parameter(torch.ones(dimension))
        self.eps = 1e-5
        self.groups = groups

    def forward(self, input):
        """Normalize ``input`` of shape (B, N, ...); the result has the same shape."""
        B, N = input.shape[:2]
        assert N % self.groups == 0

        # Every axis after the channel axis is flattened into one trailing axis.
        grouped = input.reshape(B, self.groups, N // self.groups, -1)
        T = grouped.shape[-1]

        centre = grouped.mean(2, keepdim=True)
        spread = (grouped.var(2, keepdim=True) + self.eps).sqrt()
        whitened = ((grouped - centre) / spread).reshape(B, N, T)

        scale = self.std.reshape(1, -1, 1)
        shift = self.mean.reshape(1, -1, 1)
        return (whitened * scale + shift).reshape(input.shape)
| |
class Roformer(nn.Module):
    """
    Transformer with rotary positional embedding.

    A single pre-norm block: multi-head self-attention whose queries and keys
    are rotated by a rotary position embedding, followed by a gated MLP. Both
    sub-layers are residual. ``forward`` takes and returns (B, input_size, T).

    Args:
        input_size: number of input/output channels.
        hidden_size: total attention width; it is divided by ``num_head`` and
            the attribute ``self.hidden_size`` stores the per-head width.
        num_head: number of attention heads.
        theta: base of the rotary frequency schedule.
        window: number of positions in the precomputed rotary table.
        input_drop: dropout probability applied to the normalized input.
        attention_drop: attention dropout probability (training mode only).
        causal: forwarded as ``is_causal`` to the attention call.

    Raises:
        ValueError: if the per-head width is odd, because the rotary embedding
            rotates channels in pairs.
    """

    def __init__(self, input_size, hidden_size, num_head=8, theta=10000, window=10000,
                 input_drop=0., attention_drop=0., causal=True):
        super().__init__()

        self.input_size = input_size
        # From here on ``hidden_size`` is the width of a single head.
        self.hidden_size = hidden_size // num_head
        self.num_head = num_head
        self.theta = theta
        self.window = window

        if self.hidden_size % 2 != 0:
            raise ValueError(
                f"per-head width must be even for rotary embedding, got {self.hidden_size}"
            )

        cos_freq, sin_freq = self._calc_rotary_emb()
        self.register_buffer("cos_freq", cos_freq)
        self.register_buffer("sin_freq", sin_freq)

        self.attention_drop = attention_drop
        self.causal = causal
        self.eps = 1e-5

        self.input_norm = RMSNorm(self.input_size)
        self.input_drop = nn.Dropout(p=input_drop)
        # One 1x1 convolution produces Q, K and V for every head at once.
        self.weight = nn.Conv1d(self.input_size, self.hidden_size*self.num_head*3, 1, bias=False)
        self.output = nn.Conv1d(self.hidden_size*self.num_head, self.input_size, 1, bias=False)

        self.MLP = nn.Sequential(RMSNorm(self.input_size),
                                 nn.Conv1d(self.input_size, self.input_size*8, 1, bias=False),
                                 nn.SiLU()
                                 )
        self.MLP_output = nn.Conv1d(self.input_size*4, self.input_size, 1, bias=False)

    def _calc_rotary_emb(self):
        """Return the (window, hidden_size) cosine and sine rotary tables."""
        freq = 1. / (self.theta ** (torch.arange(0, self.hidden_size, 2)[:(self.hidden_size // 2)] / self.hidden_size))
        freq = freq.reshape(1, -1)
        pos = torch.arange(0, self.window).reshape(-1, 1)
        angle = pos * freq
        # Each frequency is repeated twice so both channels of a pair share an angle.
        cos_freq = torch.stack([torch.cos(angle)]*2, -1).reshape(self.window, self.hidden_size)
        sin_freq = torch.stack([torch.sin(angle)]*2, -1).reshape(self.window, self.hidden_size)

        return cos_freq, sin_freq

    @staticmethod
    def _rotate_pairs(feature):
        """Map every adjacent channel pair (x0, x1) on the last axis to (-x1, x0)."""
        N = feature.shape[-1]
        reverse_sign = torch.tensor([-1, 1], dtype=feature.dtype, device=feature.device)
        swapped = torch.flip(feature.reshape(-1, N//2, 2), [-1])
        return (swapped * reverse_sign.reshape(1, 1, 2)).reshape(feature.shape)

    def _add_rotary_emb(self, feature, pos):
        """Rotate ``feature`` (..., N) as if it sat at the single position ``pos``.

        Positions beyond the precomputed table are clamped to the last entry.
        """
        N = feature.shape[-1]

        feature_reshape = feature.reshape(-1, N)
        pos = min(pos, self.window-1)
        cos_freq = self.cos_freq[pos]
        sin_freq = self.sin_freq[pos]
        feature_rope = (feature_reshape * cos_freq.unsqueeze(0)
                        + self._rotate_pairs(feature_reshape) * sin_freq.unsqueeze(0))

        return feature_rope.reshape(feature.shape)

    def _add_rotary_sequence(self, feature):
        """Rotate ``feature`` (..., T, N) with positions 0..T-1.

        When T exceeds the table length, positions past the end reuse the last
        table row, which is the same clamping ``_add_rotary_emb`` applies.
        """
        T, N = feature.shape[-2:]
        feature_reshape = feature.reshape(-1, T, N)

        if T <= self.window:
            cos_freq = self.cos_freq[:T]
            sin_freq = self.sin_freq[:T]
        else:
            pos = torch.arange(T, device=self.cos_freq.device).clamp(max=self.window - 1)
            cos_freq = self.cos_freq[pos]
            sin_freq = self.sin_freq[pos]

        feature_rope = (feature_reshape * cos_freq.unsqueeze(0)
                        + self._rotate_pairs(feature_reshape) * sin_freq.unsqueeze(0))

        return feature_rope.reshape(feature.shape)

    def forward(self, input):
        """Run attention and the gated MLP on ``input`` of shape (B, input_size, T).

        Returns:
            ``(output, (K_rot, V))`` where ``output`` has the input's shape and
            ``K_rot`` / ``V`` are the rotated keys and the values, each of shape
            (B, num_head, T, hidden_size).
        """
        B, _, T = input.shape

        weight = self.weight(self.input_drop(self.input_norm(input))).reshape(B, self.num_head, self.hidden_size*3, T).mT
        Q, K, V = torch.split(weight, self.hidden_size, dim=-1)

        Q_rot = self._add_rotary_sequence(Q)
        K_rot = self._add_rotary_sequence(K)

        # The functional attention applies dropout whenever dropout_p > 0, so it
        # has to be switched off explicitly outside training.
        dropout_p = self.attention_drop if self.training else 0.0
        attention_output = F.scaled_dot_product_attention(Q_rot.contiguous(), K_rot.contiguous(), V.contiguous(),
                                                          dropout_p=dropout_p, is_causal=self.causal)
        attention_output = attention_output.mT.reshape(B, -1, T)
        output = self.output(attention_output) + input

        # NOTE(review): self.MLP already ends in SiLU, so the gate half passes
        # through SiLU twice. Kept as-is so existing weights behave the same.
        gate, z = self.MLP(output).chunk(2, dim=1)
        output = output + self.MLP_output(F.silu(gate) * z)

        return output, (K_rot, V)
| |
class ConvActNorm1d(nn.Module):
    """Residual block: depthwise conv -> RMSNorm -> 1x1 conv -> SiLU -> 1x1 conv.

    Args:
        in_channel: number of input and output channels.
        hidden_channel: width of the intermediate 1x1 convolution.
        kernel: kernel size of the depthwise convolution.
        causal: if True, the depthwise convolution is padded by ``kernel - 1``
            and the surplus trailing frames are dropped, so frame ``t`` only
            sees frames ``<= t``. Otherwise symmetric padding is used.

    Raises:
        ValueError: if ``causal`` is False and ``kernel`` is even. Symmetric
            padding then yields one frame fewer than the input, so the residual
            sum could never match.
    """

    def __init__(self, in_channel, hidden_channel, kernel=7, causal=False):
        super().__init__()

        if not causal and kernel % 2 == 0:
            raise ValueError(f"kernel must be odd when causal=False, got {kernel}")

        self.in_channel = in_channel
        self.kernel = kernel
        self.causal = causal

        # The two modes differ only in the depthwise padding. The layer order is
        # kept so the child indices of ``self.conv`` stay the same.
        padding = kernel - 1 if causal else (kernel - 1) // 2
        self.conv = nn.Sequential(nn.Conv1d(in_channel, in_channel, kernel, padding=padding, groups=in_channel),
                                  RMSNorm(in_channel),
                                  nn.Conv1d(in_channel, hidden_channel, 1),
                                  nn.SiLU(),
                                  nn.Conv1d(hidden_channel, in_channel, 1)
                                  )

    def forward(self, input):
        """Return ``input + conv(input)``; the result has the input's shape."""
        output = self.conv(input)
        if self.causal:
            # Keep the first T frames. Unlike ``[..., :-kernel+1]`` this also
            # works for kernel == 1, where that slice would be empty.
            output = output[..., :input.shape[-1]]
        return input + output
|
|
class ICB(nn.Module):
    """Stack of three ``ConvActNorm1d`` blocks applied one after another.

    Each block uses ``in_channel`` channels, a hidden width of
    ``in_channel * 4``, and the given ``kernel`` and ``causal`` settings.
    """

    def __init__(self, in_channel, kernel=7, causal=False):
        super().__init__()

        stages = (
            ConvActNorm1d(in_channel, in_channel*4, kernel, causal=causal)
            for _ in range(3)
        )
        self.blocks = nn.Sequential(*stages)

    def forward(self, input):
        """Pass ``input`` through the three blocks in order."""
        output = self.blocks(input)
        return output
|
|
class BSNet(nn.Module):
    """One layer that first mixes across bands, then along the last axis.

    ``band_net`` (a non-causal ``Roformer`` with a 100-position rotary table)
    is run with the band axis as its sequence axis. ``seq_net`` (an ``ICB``)
    is then run along the last axis of the input.
    """

    def __init__(self, feature_dim, kernel=7):
        super().__init__()

        self.feature_dim = feature_dim

        self.band_net = Roformer(self.feature_dim, self.feature_dim, num_head=8, window=100, causal=False)
        self.seq_net = ICB(self.feature_dim, kernel=kernel)

    def forward(self, input):
        """Process ``input`` of shape (B, nband, N, T) and return the same layout."""
        B, nband, N, T = input.shape

        # Swap the band and last axes so bands become the sequence axis:
        # (B, nband, N, T) -> (B, T, N, nband) -> (B*T, N, nband).
        across_bands = input.transpose(1, 3).reshape(B*T, -1, nband)
        across_bands, _ = self.band_net(across_bands)
        band_output = across_bands.reshape(B, T, -1, nband).transpose(1, 3)

        # Fold bands into the batch axis and run the convolutional stack along T.
        per_band = band_output.reshape(B*nband, -1, T)
        return self.seq_net(per_band).reshape(B, nband, -1, T)
| |
class Apollo(BaseModel):
    """Band-split model that processes a waveform in the STFT domain.

    The complex spectrogram is cut into 80 frequency bands. Each band is
    normalized by its per-frame energy and projected to ``feature_dim``
    channels, the stacked band features go through ``layer`` ``BSNet`` blocks,
    and per-band heads predict real/imaginary parts that are inverted back to
    a waveform with ``torch.istft``.

    Args:
        sr: sample rate; also forwarded to ``BaseModel`` as ``sample_rate``.
        win: STFT window length; converted to samples as ``sr * win // 1000``
            (presumably milliseconds — confirm against callers).
        feature_dim: channel width of the per-band features.
        layer: number of stacked ``BSNet`` blocks.
    """

    def __init__(
        self,
        sr: int,
        win: int,
        feature_dim: int,
        layer: int
    ):
        super().__init__(sample_rate=sr)

        self.sr = sr
        self.win = int(sr * win // 1000)
        self.stride = self.win // 2
        self.enc_dim = self.win // 2 + 1
        self.feature_dim = feature_dim
        self.eps = torch.finfo(torch.float32).eps

        # 79 equal-width bands plus one final band that takes the remaining bins.
        bandwidth = int(self.win / 160)
        self.band_width = [bandwidth]*79
        # Builtin sum keeps every width a plain int; np.sum would leak an
        # np.int64 into the layer sizes below.
        self.band_width.append(self.enc_dim - sum(self.band_width))
        self.nband = len(self.band_width)
        print(self.band_width, self.nband)

        # Per-band input projection. Each band contributes its real part, its
        # imaginary part and one log-energy channel, hence width*2 + 1.
        self.BN = nn.ModuleList(
            nn.Sequential(RMSNorm(width*2+1),
                          nn.Conv1d(width*2+1, self.feature_dim, 1))
            for width in self.band_width
        )

        self.net = nn.Sequential(*[BSNet(self.feature_dim) for _ in range(layer)])

        # Per-band output head. GLU halves width*4 to width*2 channels, which
        # forward() splits into real and imaginary parts.
        self.output = nn.ModuleList(
            nn.Sequential(RMSNorm(self.feature_dim),
                          nn.Conv1d(self.feature_dim, width*4, 1),
                          nn.GLU(dim=1))
            for width in self.band_width
        )

    def spec_band_split(self, input):
        """Compute the STFT of ``input`` and split it into energy-normalized bands.

        Args:
            input: waveform tensor of shape (B, nch, nsample).

        Returns:
            ``(subband_spec_norm, subband_power)``: a list with one complex
            tensor of shape (B*nch, band_width[i], frames) per band, and a real
            tensor of shape (B*nch, nband, frames) with the per-frame band
            energies that were divided out.
        """
        B, nch, nsample = input.shape

        # reshape (not view) so non-contiguous inputs are accepted as well.
        window = torch.hann_window(self.win, device=input.device)
        spec = torch.stft(input.reshape(B*nch, nsample), n_fft=self.win, hop_length=self.stride,
                          window=window, return_complex=True)

        subband_spec_norm = []
        subband_power = []
        band_idx = 0
        for width in self.band_width:
            this_spec = spec[:, band_idx:band_idx+width]
            # Square root of the summed squared magnitudes; eps keeps it positive.
            power = (this_spec.abs().pow(2).sum(1) + self.eps).sqrt().unsqueeze(1)
            subband_power.append(power)
            subband_spec_norm.append(torch.complex(this_spec.real / power, this_spec.imag / power))
            band_idx += width
        subband_power = torch.cat(subband_power, 1)

        return subband_spec_norm, subband_power

    def feature_extractor(self, input):
        """Return stacked per-band features of shape (B*nch, nband, feature_dim, frames)."""
        subband_spec_norm, subband_power = self.spec_band_split(input)

        subband_feature = []
        for i, band_spec in enumerate(subband_spec_norm):
            log_power = torch.log(subband_power[:, i].unsqueeze(1))
            concat_spec = torch.cat([band_spec.real, band_spec.imag, log_power], 1)
            subband_feature.append(self.BN[i](concat_spec))
        subband_feature = torch.stack(subband_feature, 1)

        return subband_feature

    def forward(self, input):
        """Map a waveform (B, nch, nsample) to an estimated waveform of the same shape."""
        B, nch, nsample = input.shape

        subband_feature = self.feature_extractor(input)
        feature = self.net(subband_feature)

        est_spec = []
        for i, width in enumerate(self.band_width):
            # Channel axis holds [real | imag], each ``width`` bins wide.
            this_RI = self.output[i](feature[:, i]).reshape(B*nch, 2, width, -1)
            est_spec.append(torch.complex(this_RI[:, 0], this_RI[:, 1]))
        est_spec = torch.cat(est_spec, 1)

        window = torch.hann_window(self.win, device=input.device)
        output = torch.istft(est_spec, n_fft=self.win, hop_length=self.stride,
                             window=window, length=nsample).reshape(B, nch, -1)

        return output

    def get_model_args(self):
        """Return the fixed dictionary ``{"n_sample_rate": 2}``.

        NOTE(review): presumably consumed by ``BaseModel``; the meaning of the
        constant cannot be determined from this file — confirm.
        """
        model_args = {"n_sample_rate": 2}
        return model_args