import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers.activations import ACT2FN
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel

from .configuration_bs_roformer import BSRoformerConfig


def rotate_half(x):
    """Split the last dimension into halves ``[x1, x2]`` and return ``[-x2, x1]``."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply rotary position embeddings to queries and keys.

    ``cos`` and ``sin`` must broadcast against ``q`` and ``k`` (their last dimension is the head dim).
    """
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class RotaryEmbedding(nn.Module):
    """Produce rotary (RoPE) cos/sin tables for given integer positions."""

    def __init__(self, config: BSRoformerConfig):
        super().__init__()
        self.head_dim = config.hidden_size // config.num_attention_heads
        inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x, position_ids):
        """Return ``(cos, sin)``, each ``(position_ids.shape[0], seq_len, head_dim)`` in ``x.dtype``.

        ``x`` only supplies the target device and dtype; ``position_ids`` is ``(batch, seq_len)``.
        """
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # Autocast is disabled on the matching device type so the angles are computed in float32.
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


class BSRoformerMLP(nn.Module):
    """Gated feed-forward block: ``down(gelu(gate(x)) * up(x))``, all projections bias-free."""

    def __init__(self, config: BSRoformerConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN["gelu"]

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj


class BSRoformerAttention(nn.Module):
    """Non-causal multi-head self-attention with RoPE, dispatched to the HF ``sdpa`` kernel."""

    def __init__(self, config: BSRoformerConfig):
        super().__init__()
        self.is_causal = False
        self.config = config
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        # Read by the HF attention interface to repeat K/V heads when num_key_value_heads < num_attention_heads.
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads

        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)

    def forward(
        self,
        hidden_states,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask=None,
    ):
        """Attend over dim 1 of ``hidden_states`` (``(batch, seq, hidden)``); return ``(output, weights)``."""
        input_shape = hidden_states.size()[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        attention_interface = ALL_ATTENTION_FUNCTIONS["sdpa"]
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
        )

        # Presumably the interface returns (batch, seq, heads, head_dim); merge heads back into hidden.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class BSRoformerLayer(nn.Module):
    """Pre-norm transformer layer: RMSNorm -> attention -> residual, RMSNorm -> MLP -> residual."""

    def __init__(self, config: BSRoformerConfig):
        super().__init__()
        self.self_attn = BSRoformerAttention(config)
        self.mlp = BSRoformerMLP(config)
        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states,
        position_embeddings,
        attention_mask,
    ):
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, _ = self.self_attn(
            hidden_states,
            position_embeddings,
            attention_mask,
        )
        hidden_states = hidden_states + residual

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = hidden_states + residual
        return hidden_states


class BSRoformerAxialTransformer(nn.Module):
    """Stack of layers applied along one axis of a ``(batch, time, freq, hidden)`` tensor.

    With ``is_time_transformer=True`` attention runs along the time axis (one sequence per
    frequency band); otherwise it runs along the frequency axis (one sequence per frame).
    """

    def __init__(
        self,
        config: BSRoformerConfig,
        transformer_depth: int,
        is_time_transformer: bool,
    ):
        super().__init__()
        self.layers = nn.ModuleList([BSRoformerLayer(config) for _ in range(transformer_depth)])
        self.is_time_transformer = is_time_transformer
        # Toggled (together with `_gradient_checkpointing_func`) by
        # `PreTrainedModel.gradient_checkpointing_enable()`; without this attribute that call raises.
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        position_embeddings,
        attention_mask,
    ):
        if self.is_time_transformer:
            hidden_states = rearrange(hidden_states, "b t f d -> b f t d")

        b, _, _, _ = hidden_states.shape
        # Fold the non-attended axis into the batch so each layer sees (batch', seq, hidden).
        hidden_states = rearrange(hidden_states, "b n m d -> (b n) m d")

        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    position_embeddings,
                    attention_mask,
                )
            else:
                hidden_states = layer(
                    hidden_states,
                    position_embeddings,
                    attention_mask,
                )

        hidden_states = rearrange(hidden_states, "(b n) m d -> b n m d", b=b)

        if self.is_time_transformer:
            hidden_states = rearrange(hidden_states, "b f t d -> b t f d")

        return hidden_states


class BandSplit(nn.Module):
    """Split the last feature dim into per-band slices and project each to ``hidden_size``.

    Slice ``i`` has width ``2 * freqs_per_bands[i] * num_input_channel``; each goes through its own
    RMSNorm + Linear. Outputs are stacked on a new band axis at dim -2.
    """

    def __init__(self, config: BSRoformerConfig):
        super().__init__()
        self.dim_inputs = tuple(2 * f * config.num_input_channel for f in config.freqs_per_bands)
        self.to_features = nn.ModuleList(
            [
                nn.Sequential(nn.RMSNorm(dim_in, eps=config.rms_norm_eps), nn.Linear(dim_in, config.hidden_size))
                for dim_in in self.dim_inputs
            ]
        )

    def forward(self, x):
        x_split = x.split(self.dim_inputs, dim=-1)
        outs = [to_feature(split_input) for split_input, to_feature in zip(x_split, self.to_features)]
        return torch.stack(outs, dim=-2)


class MaskEstimator(nn.Module):
    """Inverse of :class:`BandSplit`: project each band's hidden vector back to its feature slice.

    The per-band outputs are concatenated along the last dimension.
    """

    def __init__(self, config: BSRoformerConfig):
        super().__init__()
        dim_inputs = tuple(2 * f * config.num_input_channel for f in config.freqs_per_bands)
        self.to_freq_mlps = nn.ModuleList([nn.Linear(config.hidden_size, dim_in) for dim_in in dim_inputs])

    def forward(self, x):
        x_unbind = x.unbind(dim=-2)
        outs = [mlp(band_features) for band_features, mlp in zip(x_unbind, self.to_freq_mlps)]
        return torch.cat(outs, dim=-1)


class BSRoformerPreTrainedModel(PreTrainedModel):
    config_class = BSRoformerConfig
    base_model_prefix = "model"
    # Backed by the `gradient_checkpointing` attribute on BSRoformerAxialTransformer.
    supports_gradient_checkpointing = True
    _no_split_modules = ["BSRoformerLayer"]


class BSRoformerModel(BSRoformerPreTrainedModel):
    """Band-split encoder that alternates time-axis and frequency-axis transformers.

    ``register_token_num`` extra time steps and extra bands are appended. Only their corner
    ``(rn, rn)`` block is filled with learned register tokens. Everything is cropped away before
    the final RMSNorm.
    """

    def __init__(self, config: BSRoformerConfig):
        super().__init__(config)
        self.config = config
        self.band_split = BandSplit(config)
        self.layers = nn.ModuleList(
            nn.ModuleList(
                [
                    BSRoformerAxialTransformer(config, config.time_transformer_depth, is_time_transformer=True),
                    BSRoformerAxialTransformer(config, config.freq_transformer_depth, is_time_transformer=False),
                ]
            )
            for _ in range(config.depth)
        )
        self.rotary_emb = RotaryEmbedding(config)
        self.final_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        rn = config.register_token_num
        self.register_tokens = nn.Parameter(torch.normal(0, 0.02, size=(rn, rn, config.hidden_size)))
        self.post_init()

    def forward(
        self,
        x,
        position_ids=None,
    ):
        """Encode ``x`` of shape ``(batch, frames, features)`` into ``(batch, frames, bands, hidden)``.

        Args:
            x: Merged STFT features. The last dim is split per band by :class:`BandSplit`.
            position_ids: Optional ``(1, frames)`` time positions for RoPE; defaults to ``arange(frames)``.
        """
        hidden_states = self.band_split(x)
        _, num_frames, num_bands, _ = hidden_states.shape

        if position_ids is None:
            position_ids = torch.arange(num_frames, device=hidden_states.device).unsqueeze(0)
        pos_embeds = self.rotary_emb(hidden_states, position_ids)
        pos_embeds_for_freq = self.rotary_emb(
            hidden_states,
            torch.arange(num_bands, device=hidden_states.device).unsqueeze(0),
        )

        rn = self.config.register_token_num
        # Pad time (dim 1) and band (dim 2) by `rn`; only the new corner receives register tokens.
        hidden_states = F.pad(hidden_states, (0, 0, 0, rn, 0, rn))
        hidden_states[:, num_frames:, num_bands:, :] = self.register_tokens

        def pad_rope(cos, sin):
            # cos=1 / sin=0 is the identity rotation, so register positions get no positional signal.
            cos_padded = F.pad(cos, (0, 0, 0, rn), value=1.0)
            sin_padded = F.pad(sin, (0, 0, 0, rn), value=0.0)
            return cos_padded, sin_padded

        pos_embeds = pad_rope(*pos_embeds)
        pos_embeds_for_freq = pad_rope(*pos_embeds_for_freq)

        for time_transformer, freq_transformer in self.layers:
            hidden_states = time_transformer(
                hidden_states,
                position_embeddings=pos_embeds,
                attention_mask=None,
            )
            hidden_states = freq_transformer(
                hidden_states,
                position_embeddings=pos_embeds_for_freq,
                attention_mask=None,
            )

        hidden_states = hidden_states[:, :num_frames, :num_bands, :]
        return self.final_norm(hidden_states)


class BSRoformerForMaskedEstimation(BSRoformerPreTrainedModel):
    """STFT -> BS-RoFormer -> per-stem complex masks -> iSTFT source separator."""

    def __init__(self, config: BSRoformerConfig):
        super().__init__(config)
        self.config = config
        self.model = BSRoformerModel(config)
        self.mask_estimators = nn.ModuleList([MaskEstimator(config) for _ in range(config.num_stems)])

        self.stft_kwargs = dict(
            n_fft=config.stft_n_fft,
            hop_length=config.stft_hop_length,
            win_length=config.stft_win_length,
            normalized=False,
        )
        self.register_buffer("stft_window", torch.hann_window(config.stft_win_length), persistent=False)

        freqs = config.stft_n_fft // 2 + 1
        # Explicit raise (not `assert`) so the check survives `python -O`.
        if sum(config.freqs_per_bands) != freqs:
            raise ValueError(f"Sum of freqs_per_bands must be {freqs}")
        self.wave_channels = config.num_input_channel

    def forward(
        self,
        raw_audio: torch.Tensor,
        target: Optional[torch.Tensor] = None,
    ):
        """Separate ``raw_audio`` of shape ``(batch, channels, samples)``.

        Returns:
            The reconstruction ``(batch, num_stems, channels, samples)`` when ``target`` is None.
            Otherwise, the mean L1 loss against ``target``, after trimming ``target``'s last dim to
            ``samples``.
        """
        device = raw_audio.device

        # STFT in full precision regardless of any surrounding autocast.
        with torch.autocast(device_type=device.type, enabled=False):
            _, channels, _ = raw_audio.shape
            raw_audio_packed = rearrange(raw_audio, "b c t -> (b c) t")
            stft_repr = torch.stft(
                raw_audio_packed,
                **self.stft_kwargs,
                window=self.stft_window,
                return_complex=True,
            )
            stft_repr = torch.view_as_real(stft_repr)
            stft_repr = rearrange(stft_repr, "(b c) f t T -> b c f t T", c=channels)
            # Frequency is outermost so contiguous slices line up with BandSplit's per-band widths.
            stft_repr_merged = rearrange(stft_repr, "b c f t T -> b t (f c T)")

        hidden_states = self.model(stft_repr_merged)

        mask = torch.stack([fn(hidden_states) for fn in self.mask_estimators], dim=1)
        mask = rearrange(mask, "b n t (f c T) -> b n c f t T", T=2, c=channels)
        mask = mask.to(dtype=torch.float32)

        with torch.autocast(device_type=device.type, enabled=False):
            stft_repr_expanded = rearrange(stft_repr, "b c f t T -> b 1 c f t T")
            stft_repr_complex = torch.view_as_complex(stft_repr_expanded)
            mask_complex = torch.view_as_complex(mask)
            masked_stft = stft_repr_complex * mask_complex

            masked_stft = rearrange(masked_stft, "b n c f t -> (b n c) f t")
            recon_audio = torch.istft(
                masked_stft,
                **self.stft_kwargs,
                window=self.stft_window,
                return_complex=False,
                length=raw_audio.shape[-1],
            )
            recon_audio = rearrange(recon_audio, "(b n c) t -> b n c t", c=self.wave_channels, n=self.config.num_stems)

        if target is None:
            return recon_audio

        target = target[..., : recon_audio.shape[-1]]
        loss = F.l1_loss(recon_audio, target)
        return loss

    def separate(
        self,
        mixed_wave: torch.Tensor,
        chunk_size: int = 44100 * 8,
        overlap_size: int = 44100 * 4,
        batch_size: int = 16,
        gap_size: int = 44100 * 1,
        verbose: bool = True,
    ):
        """
        Separates a full audio waveform into its constituent stems using a sliding window approach.

        Each chunk's output is weighted by a window that is zero for `gap_size` samples at both ends
        and fades linearly over `chunk_size // 10` samples next to those gaps. Overlapping chunks are
        summed and normalized by the summed window weights. The signal is zero-padded on both sides
        by `gap_size + chunk_size // 10` samples, so its first and last samples also get full weight.

        Args:
            mixed_wave (`torch.Tensor` of shape `(channels, time)`):
                The raw audio waveform of the mixture.
            chunk_size (`int`, *optional*, defaults to `352800` (8 seconds at 44.1kHz)):
                The size of each audio chunk for processing. Must be at least 10.
            overlap_size (`int`, *optional*, defaults to `176400` (4 seconds at 44.1kHz)):
                The hop (stride) between the starts of consecutive chunks. Consecutive chunks overlap
                by `chunk_size - overlap_size` samples. It must be positive and at most
                `chunk_size - 2 * gap_size - 2`, so that every sample gets a non-zero window weight.
            batch_size (`int`, *optional*, defaults to `16`):
                The number of chunks to process in a single batch.
            gap_size (`int`, *optional*, defaults to `44100` (1 second at 44.1kHz)):
                Number of samples at each end of a chunk whose output is discarded (zero weight).
            verbose (`bool`, *optional*, defaults to `True`):
                Whether to print progress information during processing.

        Returns:
            torch.Tensor (`torch.Tensor` of shape `(num_stems, channels, time)`): The separated audio waveforms.

        Raises:
            ValueError: If `mixed_wave` is not 2D or has the wrong number of channels, or if the
                chunking parameters cannot produce a valid, gap-free cross-fade window.
        """
        if mixed_wave.dim() != 2:
            raise ValueError("Input `mixed_wave` must be a 2D tensor of shape (channels, time)")
        if mixed_wave.shape[0] != self.wave_channels:
            raise ValueError(
                f"Input `mixed_wave` has {mixed_wave.shape[0]} channels but the model expects {self.wave_channels}"
            )
        if batch_size < 1:
            raise ValueError("`batch_size` must be a positive integer")
        if gap_size < 0:
            raise ValueError("`gap_size` must be non-negative")

        fade_size = chunk_size // 10
        core_size = chunk_size - 2 * gap_size  # window span between the two zero gaps
        if fade_size < 1 or core_size < 2 * fade_size:
            raise ValueError(
                "`chunk_size` must be at least 10 and `chunk_size - 2 * gap_size` must be at least "
                "`2 * (chunk_size // 10)` to fit both fades"
            )
        # The non-zero part of each window spans `core_size - 2` samples (linspace endpoints are 0),
        # so a larger hop leaves samples that no chunk covers with non-zero weight.
        if not 0 < overlap_size <= core_size - 2:
            raise ValueError("`overlap_size` (hop between chunk starts) must be in [1, chunk_size - 2 * gap_size - 2]")

        device = mixed_wave.device

        # Fade-in/fade-out window
        window = torch.ones(core_size, device=device)
        window[:fade_size] = torch.linspace(0, 1, fade_size, device=device)
        window[-fade_size:] = torch.linspace(1, 0, fade_size, device=device)
        window = F.pad(window, (gap_size, gap_size), value=0.0)

        # Without this margin, the first `gap_size` samples are covered only by chunk 0's zero gap, and
        # the tail can fall into the last chunk's gap, so their output came out as silence.
        edge_size = gap_size + fade_size

        with torch.inference_mode():
            wave_length = mixed_wave.shape[-1]
            padded_content_length = wave_length + 2 * edge_size
            if padded_content_length <= chunk_size:
                num_chunks = 1
            else:
                num_chunks = math.ceil((padded_content_length - chunk_size) / overlap_size) + 1

            required_length = (num_chunks - 1) * overlap_size + chunk_size
            padded_wave = F.pad(
                mixed_wave,
                (edge_size, required_length - wave_length - edge_size),
                mode="constant",
            )

            unfolded_chunks = padded_wave.unfold(
                dimension=-1,
                size=chunk_size,
                step=overlap_size,
            )  # (C, num_chunks, chunk_size)
            batch = unfolded_chunks.permute(1, 0, 2)  # (num_chunks, C, chunk_size)

            if verbose:
                print(f"Input wave shape: {mixed_wave.shape}")
                print(f"Padded wave shape: {padded_wave.shape}")
                print(f"Number of chunks: {num_chunks}")

            output_chunks = []
            for i in range(0, num_chunks, batch_size):
                chunk_batch = batch[i : i + batch_size]
                output_chunk = self(chunk_batch)  # Call forward method
                output_chunks.append(output_chunk)
                if verbose:
                    print(f"Processed chunks {i} to {i + chunk_batch.shape[0]}")

            batch_output = torch.cat(output_chunks, dim=0)  # (num_chunks, num_stems, C, chunk_size)
            _, num_stems, C, _ = batch_output.shape
            batch_output = batch_output.view(num_chunks, -1, chunk_size).permute(
                1, 0, 2
            )  # (num_stems * C, num_chunks, chunk_size)
            batch_output = batch_output * window

            # Overlap-add the weighted chunks back onto the padded timeline.
            output_result_buffer = F.fold(
                batch_output.permute(0, 2, 1),
                output_size=(1, required_length),
                kernel_size=(1, chunk_size),
                stride=(1, overlap_size),
            )  # (num_stems * C, 1, 1, required_length)

            # Overlap-add the bare window to get the per-sample normalization weights.
            window_for_fold = window.expand(1, 1, -1).repeat(1, num_chunks, 1)
            weighted_sum_counter = F.fold(
                window_for_fold.permute(0, 2, 1),
                output_size=(1, required_length),
                kernel_size=(1, chunk_size),
                stride=(1, overlap_size),
            )  # (1, 1, 1, required_length)

            output_result_buffer = output_result_buffer.view(num_stems, C, -1)  # (num_stems, C, required_length)
            weighted_sum_counter = weighted_sum_counter.view(1, 1, -1)
            weighted_sum_counter.clamp_min_(1e-8)

            # Crop off the left margin and the right padding to recover exactly `wave_length` samples.
            final_output = (output_result_buffer / weighted_sum_counter)[:, :, edge_size : edge_size + wave_length]

        return final_output