import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers.activations import ACT2FN
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel

from .configuration_bs_roformer import BSRoformerConfig


def rotate_half(x):
    """Rotates half the hidden dims of the input: (x1, x2) -> (-x2, x1)."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    """Applies rotary position embeddings to the query and key tensors."""
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
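

# Note: with cos = 1 and sin = 0 the rotation above is the identity, i.e.
# apply_rotary_pos_emb returns q and k unchanged. A quick sanity-check sketch
# (illustrative only, not part of the module):
#
#     q = k = torch.randn(1, 2, 4, 8)
#     cos, sin = torch.ones(1, 4, 8), torch.zeros(1, 4, 8)
#     q2, k2 = apply_rotary_pos_emb(q, k, cos, sin)
#     assert torch.allclose(q2, q) and torch.allclose(k2, k)
#
# BSRoformerModel.forward relies on this identity when it pads the cos/sin
# tables for the register tokens (see pad_rope there).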


class RotaryEmbedding(nn.Module):
    def __init__(self, config: BSRoformerConfig):
        super().__init__()
        self.head_dim = config.hidden_size // config.num_attention_heads
        inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x, position_ids):
        # Expand inv_freq to (batch, head_dim // 2, 1) and positions to (batch, 1, seq_len).
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # Compute the angles in float32 with autocast disabled ("mps" does not
        # support a disabled autocast context, hence the cpu fallback).
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


class BSRoformerMLP(nn.Module):
    """Gated (GeGLU-style) feed-forward block: down_proj(gelu(gate_proj(x)) * up_proj(x))."""

    def __init__(self, config: BSRoformerConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN["gelu"]

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj


class BSRoformerAttention(nn.Module):
    def __init__(self, config: BSRoformerConfig):
        super().__init__()
        self.is_causal = False
        self.config = config

        self.head_dim = config.hidden_size // config.num_attention_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout

        # Grouped-query attention: each key/value head is shared by this many query heads.
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads

        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)

    def forward(
        self,
        hidden_states,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask=None,
    ):
        input_shape = hidden_states.size()[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Dispatch to the SDPA kernel registered with transformers.
        attention_interface = ALL_ATTENTION_FUNCTIONS["sdpa"]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)

        return attn_output, attn_weights


class BSRoformerLayer(nn.Module):
    def __init__(self, config: BSRoformerConfig):
        super().__init__()
        self.self_attn = BSRoformerAttention(config)
        self.mlp = BSRoformerMLP(config)

        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states,
        position_embeddings,
        attention_mask,
    ):
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, _ = self.self_attn(
            hidden_states,
            position_embeddings,
            attention_mask,
        )
        hidden_states = hidden_states + residual

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = hidden_states + residual

        return hidden_states


class BSRoformerAxialTransformer(nn.Module):
    def __init__(
        self,
        config: BSRoformerConfig,
        transformer_depth: int,
        is_time_transformer: bool,
    ):
        super().__init__()
        self.layers = nn.ModuleList([BSRoformerLayer(config) for _ in range(transformer_depth)])
        self.is_time_transformer = is_time_transformer

    def forward(
        self,
        hidden_states,
        position_embeddings,
        attention_mask,
    ):
        # Axial attention: attend along a single axis (time or frequency bands),
        # folding the other axis into the batch dimension.
        if self.is_time_transformer:
            hidden_states = rearrange(hidden_states, 'b t f d -> b f t d')

        b, seq_len_1, seq_len_2, d = hidden_states.shape
        hidden_states = rearrange(hidden_states, 'b n m d -> (b n) m d')

        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                position_embeddings,
                attention_mask,
            )

        hidden_states = rearrange(hidden_states, '(b n) m d -> b n m d', b=b)

        if self.is_time_transformer:
            hidden_states = rearrange(hidden_states, 'b f t d -> b t f d')

        return hidden_states


class BandSplit(nn.Module):
    def __init__(self, config: BSRoformerConfig):
        super().__init__()
        # Each band consumes 2 (real/imag) * num_freqs_in_band * num_input_channel features.
        self.dim_inputs = tuple(2 * f * config.num_input_channel for f in config.freqs_per_bands)
        self.to_features = nn.ModuleList(
            [
                nn.Sequential(nn.RMSNorm(dim_in, eps=config.rms_norm_eps), nn.Linear(dim_in, config.hidden_size))
                for dim_in in self.dim_inputs
            ]
        )

    def forward(self, x):
        x_split = x.split(self.dim_inputs, dim=-1)
        outs = [to_feature(split_input) for split_input, to_feature in zip(x_split, self.to_features)]
        return torch.stack(outs, dim=-2)


class MaskEstimator(nn.Module):
    def __init__(self, config: BSRoformerConfig):
        super().__init__()
        dim_inputs = tuple(2 * f * config.num_input_channel for f in config.freqs_per_bands)
        self.to_freq_mlps = nn.ModuleList([nn.Linear(config.hidden_size, dim_in) for dim_in in dim_inputs])

    def forward(self, x):
        # Inverse of BandSplit: project each band back to its flattened spectrogram features.
        x_unbind = x.unbind(dim=-2)
        outs = [mlp(band_features) for band_features, mlp in zip(x_unbind, self.to_freq_mlps)]
        return torch.cat(outs, dim=-1)
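

# Shape round trip between BandSplit and MaskEstimator, sketched with assumed
# (illustrative) values: num_input_channel=2 and a freqs_per_bands tuple that
# sums to stft_n_fft // 2 + 1:
#
#     x = torch.randn(1, 100, sum(2 * f * 2 for f in config.freqs_per_bands))
#     bands = BandSplit(config)(x)         # (1, 100, num_bands, hidden_size)
#     mask = MaskEstimator(config)(bands)  # back to (1, 100, sum(2 * f * 2 for ...))
#     assert mask.shape == x.shape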


class BSRoformerPreTrainedModel(PreTrainedModel):
    config_class = BSRoformerConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["BSRoformerLayer"]


class BSRoformerModel(BSRoformerPreTrainedModel):
    def __init__(self, config: BSRoformerConfig):
        super().__init__(config)
        self.config = config
        self.band_split = BandSplit(config)
        self.layers = nn.ModuleList(
            nn.ModuleList(
                [
                    BSRoformerAxialTransformer(config, config.time_transformer_depth, is_time_transformer=True),
                    BSRoformerAxialTransformer(config, config.freq_transformer_depth, is_time_transformer=False),
                ]
            )
            for _ in range(config.depth)
        )
        self.rotary_emb = RotaryEmbedding(config)
        self.final_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Learned register tokens, stored in the (rn x rn) corner block appended
        # past the time and frequency axes in forward().
        rn = config.register_token_num
        self.register_tokens = nn.Parameter(torch.normal(0, 0.02, size=(rn, rn, config.hidden_size)))

        self.post_init()

    def forward(
        self,
        x,
        position_ids=None,
    ):
        hidden_states = self.band_split(x)

        b, t, n, h = hidden_states.shape

        if position_ids is None:
            position_ids = torch.arange(t, device=hidden_states.device).unsqueeze(0)
        pos_embeds = self.rotary_emb(hidden_states, position_ids)
        pos_embeds_for_freq = self.rotary_emb(
            hidden_states,
            torch.arange(n, device=hidden_states.device).unsqueeze(0),
        )

        # Append register positions past both axes; the corner block holds the
        # learned tokens, the remaining padded positions stay zero.
        rn = self.config.register_token_num
        hidden_states = F.pad(hidden_states, (0, 0, 0, rn, 0, rn))
        hidden_states[:, t:, n:, :] = self.register_tokens

        # Register positions get an identity rotation (cos=1, sin=0).
        def pad_rope(cos, sin):
            cos_padded = F.pad(cos, (0, 0, 0, rn), value=1.0)
            sin_padded = F.pad(sin, (0, 0, 0, rn), value=0.0)
            return cos_padded, sin_padded

        pos_embeds = pad_rope(*pos_embeds)
        pos_embeds_for_freq = pad_rope(*pos_embeds_for_freq)

        # Alternate attention along the time axis and the frequency-band axis.
        for time_transformer, freq_transformer in self.layers:
            hidden_states = time_transformer(
                hidden_states,
                position_embeddings=pos_embeds,
                attention_mask=None,
            )
            hidden_states = freq_transformer(
                hidden_states,
                position_embeddings=pos_embeds_for_freq,
                attention_mask=None,
            )

        # Drop the register positions before the final norm.
        hidden_states = hidden_states[:, :t, :n, :]

        return self.final_norm(hidden_states)


class BSRoformerForMaskedEstimation(BSRoformerPreTrainedModel):
    def __init__(self, config: BSRoformerConfig):
        super().__init__(config)
        self.config = config
        self.model = BSRoformerModel(config)
        self.mask_estimators = nn.ModuleList([MaskEstimator(config) for _ in range(config.num_stems)])

        self.stft_kwargs = dict(
            n_fft=config.stft_n_fft,
            hop_length=config.stft_hop_length,
            win_length=config.stft_win_length,
            normalized=False,
        )
        self.register_buffer("stft_window", torch.hann_window(config.stft_win_length), persistent=False)

        freqs = config.stft_n_fft // 2 + 1
        if sum(config.freqs_per_bands) != freqs:
            raise ValueError(f"Sum of freqs_per_bands must be {freqs}, got {sum(config.freqs_per_bands)}")
        self.wave_channels = config.num_input_channel
    def forward(
        self,
        raw_audio: torch.Tensor,
        target: Optional[torch.Tensor] = None,
    ):
        device = raw_audio.device

        # Run the STFT with autocast disabled; complex STFT ops are not autocast-safe.
        with torch.autocast(device_type=device.type, enabled=False):
            b, c, t = raw_audio.shape
            raw_audio_packed = rearrange(raw_audio, "b c t -> (b c) t")
            stft_repr = torch.stft(
                raw_audio_packed,
                **self.stft_kwargs,
                window=self.stft_window,
                return_complex=True,
            )
            stft_repr = torch.view_as_real(stft_repr)
            stft_repr = rearrange(stft_repr, "(b c) f t T -> b c f t T", c=c)
            # Flatten (freq, channel, real/imag) so BandSplit can split it per band.
            stft_repr_merged = rearrange(stft_repr, "b c f t T -> b t (f c T)")

        hidden_states = self.model(stft_repr_merged)

        # One complex mask per stem.
        mask = torch.stack([fn(hidden_states) for fn in self.mask_estimators], dim=1)
        mask = rearrange(mask, "b n t (f c T) -> b n c f t T", T=2, c=c)
        mask = mask.to(dtype=torch.float32)

        # Apply the complex masks and invert the STFT, again outside of autocast.
        with torch.autocast(device_type=device.type, enabled=False):
            stft_repr_expanded = rearrange(stft_repr, "b c f t T -> b 1 c f t T")
            stft_repr_complex = torch.view_as_complex(stft_repr_expanded)
            mask_complex = torch.view_as_complex(mask)
            masked_stft = stft_repr_complex * mask_complex

            masked_stft = rearrange(masked_stft, "b n c f t -> (b n c) f t")
            recon_audio = torch.istft(
                masked_stft,
                **self.stft_kwargs,
                window=self.stft_window,
                return_complex=False,
                length=raw_audio.shape[-1],
            )
        recon_audio = rearrange(recon_audio, "(b n c) t -> b n c t", c=self.wave_channels, n=self.config.num_stems)

        if target is None:
            return recon_audio

        target = target[..., : recon_audio.shape[-1]]
        loss = F.l1_loss(recon_audio, target)
        return loss
    def separate(
        self,
        mixed_wave: torch.Tensor,
        chunk_size: int = 44100 * 8,
        overlap_size: int = 44100 * 4,
        batch_size: int = 16,
        gap_size: int = 44100 * 1,
        verbose: bool = True,
    ):
        """
        Separates a full audio waveform into its constituent stems using a sliding-window approach.

        Args:
            mixed_wave (`torch.Tensor` of shape `(channels, time)`):
                The raw audio waveform of the mixture.
            chunk_size (`int`, *optional*, defaults to `352800` (8 seconds at 44.1kHz)):
                The size of each audio chunk for processing.
            overlap_size (`int`, *optional*, defaults to `176400` (4 seconds at 44.1kHz)):
                The hop between the starts of consecutive chunks. With the default 8-second
                chunks this yields a 4-second overlap between neighbors.
            batch_size (`int`, *optional*, defaults to `16`):
                The number of chunks to process in a single batch.
            gap_size (`int`, *optional*, defaults to `44100` (1 second at 44.1kHz)):
                The number of samples at each chunk edge that are zero-weighted before the
                fade-in/fade-out window begins.
            verbose (`bool`, *optional*, defaults to `True`):
                Whether to print progress information during processing.

        Returns:
            `torch.Tensor` of shape `(num_stems, channels, time)`:
                The separated audio waveforms.
        """
        if mixed_wave.dim() != 2:
            raise ValueError("Input `mixed_wave` must be a 2D tensor of shape (channels, time)")

        device = mixed_wave.device

        # Cross-fade window: zero for gap_size samples at each edge, a linear fade
        # over fade_size samples, and flat in the middle.
        fade_size = chunk_size // 10
        window = torch.ones(chunk_size - 2 * gap_size, device=device)
        window[:fade_size] = torch.linspace(0, 1, fade_size, device=device)
        window[-fade_size:] = torch.linspace(1, 0, fade_size, device=device)
        window = F.pad(window, (gap_size, gap_size), value=0.0)

        with torch.inference_mode():
            wave_length = mixed_wave.shape[-1]

            if wave_length <= chunk_size:
                num_chunks = 1
            else:
                num_chunks = math.ceil((wave_length - chunk_size) / overlap_size) + 1

            # Pad so that the last chunk is full-length.
            required_length = (num_chunks - 1) * overlap_size + chunk_size
            padded_wave = F.pad(
                mixed_wave,
                (0, required_length - wave_length),
                mode="constant",
            )

            # (channels, num_chunks, chunk_size) -> (num_chunks, channels, chunk_size)
            unfolded_chunks = padded_wave.unfold(
                dimension=-1,
                size=chunk_size,
                step=overlap_size,
            )
            batch = unfolded_chunks.permute(1, 0, 2)

            if verbose:
                print(f"Input wave shape: {mixed_wave.shape}")
                print(f"Padded wave shape: {padded_wave.shape}")
                print(f"Number of chunks: {num_chunks}")
            output_chunks = []
            for i in range(0, num_chunks, batch_size):
                chunk_batch = batch[i : i + batch_size]
                output_chunk = self(chunk_batch)
                output_chunks.append(output_chunk)
                if verbose:
                    print(f"Processed chunks {i} to {i + chunk_batch.shape[0]}")
            batch_output = torch.cat(output_chunks, dim=0)

            # Overlap-add the windowed chunks back into a full-length waveform.
            _, num_stems, C, _ = batch_output.shape
            batch_output = batch_output.view(num_chunks, -1, chunk_size).permute(1, 0, 2)
            batch_output = batch_output * window
            output_result_buffer = F.fold(
                batch_output.permute(0, 2, 1),
                output_size=(1, required_length),
                kernel_size=(1, chunk_size),
                stride=(1, overlap_size),
            )

            # Accumulate the per-sample window weights for normalization.
            window_for_fold = window.expand(1, 1, -1).repeat(1, num_chunks, 1)
            weighted_sum_counter = F.fold(
                window_for_fold.permute(0, 2, 1),
                output_size=(1, required_length),
                kernel_size=(1, chunk_size),
                stride=(1, overlap_size),
            )

            output_result_buffer = output_result_buffer.view(num_stems, C, -1)
            weighted_sum_counter = weighted_sum_counter.view(1, 1, -1)
            weighted_sum_counter.clamp_min_(1e-8)

            final_output = (output_result_buffer / weighted_sum_counter)[:, :, :wave_length]

        return final_output
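

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative, not part of the original module).
    # It assumes BSRoformerConfig accepts the fields used above as keyword
    # arguments; the tiny values here are arbitrary and chosen only so the model
    # runs quickly on CPU. Note that freqs_per_bands must sum to
    # stft_n_fft // 2 + 1 (here 16 // 2 + 1 = 9).
    config = BSRoformerConfig(
        hidden_size=32,
        intermediate_size=64,
        num_attention_heads=4,
        num_key_value_heads=2,
        rope_theta=10000.0,
        attention_dropout=0.0,
        rms_norm_eps=1e-6,
        depth=1,
        time_transformer_depth=1,
        freq_transformer_depth=1,
        register_token_num=1,
        num_stems=2,
        num_input_channel=2,
        stft_n_fft=16,
        stft_hop_length=4,
        stft_win_length=16,
        freqs_per_bands=(2, 3, 4),
    )
    model = BSRoformerForMaskedEstimation(config).eval()

    mixture = torch.randn(2, 44100)  # (channels, time), ~1 second of noise
    stems = model.separate(mixture, chunk_size=8192, overlap_size=4096, gap_size=512, verbose=False)
    print(stems.shape)  # expected: (num_stems, channels, time) == (2, 2, 44100)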