# Mini-BS-RoFormer / modeling_bs_roformer.py
# Author: HiDolen — last commit 3d69caf ("Update modeling_bs_roformer.py", verified)
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers.activations import ACT2FN
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from .configuration_bs_roformer import BSRoformerConfig
def rotate_half(x):
    """Rotate the last dimension by half: ``[a, b] -> [-b, a]``.

    The last dimension is split at ``size // 2`` into a leading part ``a`` and a
    trailing part ``b`` (``b`` gets the extra element for odd sizes); the result
    is ``-b`` followed by ``a``. Used by ``apply_rotary_pos_emb``.
    """
    half = x.shape[-1] // 2
    leading, trailing = x[..., :half], x[..., half:]
    return torch.cat([trailing.neg(), leading], dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Apply rotary position embeddings (RoPE) to query and key tensors.

    Args:
        q, k: Query / key tensors. In ``BSRoformerAttention`` they are
            ``(batch, heads, seq, head_dim)``.
        cos, sin: Rotation tables whose last two axes are ``(seq, head_dim)``.
            ``RotaryEmbedding`` returns them as ``(position_batch, seq, head_dim)``.
        unsqueeze_dim: Axis at which a head axis is inserted into ``cos``/``sin``
            when they have one dimension fewer than ``q``. This makes their
            leading axis line up with the batch axis of ``q``/``k`` rather than the
            head axis. ``cos``/``sin`` that already have ``q``'s rank are used as-is.

    Returns:
        Tuple ``(q_embed, k_embed)`` with the same shapes as ``q`` and ``k``.
    """
    if cos.dim() == q.dim() - 1:
        # Without this, a (position_batch, seq, head_dim) table broadcasts its
        # first axis against the *head* axis of q/k, which only happens to work
        # when position_batch == 1.
        cos = cos.unsqueeze(unsqueeze_dim)
        sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
class RotaryEmbedding(nn.Module):
    """Build RoPE ``cos``/``sin`` tables for integer positions.

    ``head_dim = hidden_size // num_attention_heads``. The ``inv_freq`` buffer holds
    ``1 / rope_theta ** (j / head_dim)`` for every even ``j`` in ``[0, head_dim)``.
    """

    def __init__(self, config: BSRoformerConfig):
        super().__init__()
        self.head_dim = config.hidden_size // config.num_attention_heads
        exponents = torch.arange(0, self.head_dim, 2).float() / self.head_dim
        self.register_buffer("inv_freq", 1.0 / (config.rope_theta ** exponents))

    def forward(self, x, position_ids):
        """Return ``(cos, sin)``, each ``(position_batch, seq, head_dim)``, cast to ``x.dtype``.

        ``x`` only supplies the target device and dtype. ``position_ids`` is indexed as a
        2-D ``(position_batch, seq)`` tensor.
        """
        n_pos_batch = position_ids.shape[0]
        inv_freq = self.inv_freq[None, :, None].float().expand(n_pos_batch, -1, 1).to(x.device)
        positions = position_ids[:, None, :].float()

        # autocast is disabled below so the angles are computed in float32; "mps" is
        # routed to a CPU autocast context (presumably because autocast does not
        # accept "mps" on all PyTorch versions).
        dev_type = x.device.type
        if not isinstance(dev_type, str) or dev_type == "mps":
            dev_type = "cpu"
        with torch.autocast(device_type=dev_type, enabled=False):
            # (B, head_dim/2, 1) @ (B, 1, seq) -> (B, head_dim/2, seq) -> (B, seq, head_dim/2)
            angles = torch.matmul(inv_freq.float(), positions.float()).transpose(1, 2)
            angles = torch.cat([angles, angles], dim=-1)
            cos, sin = angles.cos(), angles.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class BSRoformerMLP(nn.Module):
    """Gated feed-forward block ``down(gelu(gate(x)) * up(x))`` with bias-free projections."""

    def __init__(self, config: BSRoformerConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        d_model, d_ff = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)
        # The activation is hard-coded to GELU; no activation setting is read from config.
        self.act_fn = ACT2FN["gelu"]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
class BSRoformerAttention(nn.Module):
    """Non-causal multi-head self-attention with RoPE, run through the "sdpa" backend.

    Keys/values use ``num_key_value_heads`` heads, which may be fewer than the
    ``num_attention_heads`` query heads. ``num_key_value_groups`` stores the ratio,
    presumably for the attention backend to repeat K/V heads (TODO confirm against
    the transformers version in use).
    """

    def __init__(self, config: BSRoformerConfig):
        super().__init__()
        self.is_causal = False
        self.config = config
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads

        q_width = config.num_attention_heads * self.head_dim
        kv_width = config.num_key_value_heads * self.head_dim
        self.q_proj = nn.Linear(config.hidden_size, q_width, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, kv_width, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, kv_width, bias=False)
        self.o_proj = nn.Linear(q_width, config.hidden_size, bias=False)

    def forward(
        self,
        hidden_states,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask=None,
    ):
        """Attend over the second axis of ``hidden_states``.

        ``hidden_states`` is expected to be 3-D ``(batch, seq, hidden_size)``, since
        the per-head split is moved in front of ``seq`` with ``transpose(1, 2)``.
        ``position_embeddings`` is the ``(cos, sin)`` pair applied to queries and keys.
        Returns ``(output, weights)``, where ``output`` has the input's shape and
        ``weights`` is whatever the sdpa backend returns.
        """
        leading_shape = hidden_states.size()[:-1]
        split_heads = (*leading_shape, -1, self.head_dim)
        query_states, key_states, value_states = (
            proj(hidden_states).view(split_heads).transpose(1, 2)
            for proj in (self.q_proj, self.k_proj, self.v_proj)
        )

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Dropout on attention probabilities only while training.
        dropout_p = self.attention_dropout if self.training else 0.0
        attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS["sdpa"](
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=dropout_p,
            scaling=self.scaling,
        )
        merged = attn_output.reshape(*leading_shape, -1).contiguous()
        return self.o_proj(merged), attn_weights
class BSRoformerLayer(nn.Module):
    """Pre-norm transformer block.

    Applies ``x + attn(rms_norm(x))`` and then ``x + mlp(rms_norm(x))``.
    """

    def __init__(self, config: BSRoformerConfig):
        super().__init__()
        self.self_attn = BSRoformerAttention(config)
        self.mlp = BSRoformerMLP(config)
        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states,
        position_embeddings,
        attention_mask,
    ):
        attn_out, _ = self.self_attn(
            self.input_layernorm(hidden_states),
            position_embeddings,
            attention_mask,
        )
        hidden_states = attn_out + hidden_states

        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return mlp_out + hidden_states
class BSRoformerAxialTransformer(nn.Module):
    """Stack of ``BSRoformerLayer``s applied along one axis of a 4-D feature grid.

    The input is ``(batch, axis1, axis2, hidden)``. With ``is_time_transformer=True``
    the layers attend along axis 1; otherwise they attend along axis 2. The other
    axis is folded into the batch. The output has the input's shape.
    """

    def __init__(
        self,
        config: BSRoformerConfig,
        transformer_depth: int,
        is_time_transformer: bool,
    ):
        super().__init__()
        self.layers = nn.ModuleList([BSRoformerLayer(config) for _ in range(transformer_depth)])
        self.is_time_transformer = is_time_transformer
        # transformers' gradient_checkpointing_enable() looks for modules that carry
        # this attribute, sets it to True and attaches `_gradient_checkpointing_func`.
        # Without it, enabling checkpointing on the model raises.
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        position_embeddings,
        attention_mask,
    ):
        if self.is_time_transformer:
            # Move the time axis next to hidden so it becomes the sequence axis.
            hidden_states = rearrange(hidden_states, 'b t f d -> b f t d')
        batch_size = hidden_states.shape[0]
        hidden_states = rearrange(hidden_states, 'b n m d -> (b n) m d')

        use_checkpointing = self.gradient_checkpointing and self.training
        for layer in self.layers:
            if use_checkpointing:
                # Recompute activations in the backward pass instead of storing them.
                hidden_states = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    position_embeddings,
                    attention_mask,
                )
            else:
                hidden_states = layer(
                    hidden_states,
                    position_embeddings,
                    attention_mask,
                )

        hidden_states = rearrange(hidden_states, '(b n) m d -> b n m d', b=batch_size)
        if self.is_time_transformer:
            hidden_states = rearrange(hidden_states, 'b f t d -> b t f d')
        return hidden_states
class BandSplit(nn.Module):
    """Project each frequency band of a flattened spectrogram frame to ``hidden_size``.

    The last input axis is cut into consecutive slices of width
    ``2 * freqs * num_input_channel``, one per entry of ``config.freqs_per_bands``.
    The factor 2 matches the real/imag pair packed by ``BSRoformerForMaskedEstimation``.
    Each slice goes through its own RMSNorm + Linear, and the per-band results are
    stacked on a new second-to-last axis.
    """

    def __init__(self, config: BSRoformerConfig):
        super().__init__()
        channels = config.num_input_channel
        eps = config.rms_norm_eps
        self.dim_inputs = tuple(2 * f * channels for f in config.freqs_per_bands)
        self.to_features = nn.ModuleList(
            nn.Sequential(nn.RMSNorm(width, eps=eps), nn.Linear(width, config.hidden_size))
            for width in self.dim_inputs
        )

    def forward(self, x):
        band_slices = torch.split(x, self.dim_inputs, dim=-1)
        features = [proj(band) for proj, band in zip(self.to_features, band_slices)]
        return torch.stack(features, dim=-2)
class MaskEstimator(nn.Module):
    """Map per-band features back to one flattened feature axis (mirror of ``BandSplit``).

    Band ``i`` (taken from the second-to-last input axis) is projected by its own
    Linear to ``2 * freqs_per_bands[i] * num_input_channel`` values. The per-band
    outputs are concatenated on the last axis.
    """

    def __init__(self, config: BSRoformerConfig):
        super().__init__()
        widths = [2 * f * config.num_input_channel for f in config.freqs_per_bands]
        self.to_freq_mlps = nn.ModuleList(nn.Linear(config.hidden_size, w) for w in widths)

    def forward(self, x):
        bands = torch.unbind(x, dim=-2)
        return torch.cat([head(band) for head, band in zip(self.to_freq_mlps, bands)], dim=-1)
class BSRoformerPreTrainedModel(PreTrainedModel):
    """Common ``PreTrainedModel`` base for the BS-RoFormer models in this file.

    Binds ``BSRoformerConfig`` and names the base-model attribute ``model`` (the
    attribute ``BSRoformerForMaskedEstimation`` stores its backbone under). It also
    lists ``BSRoformerLayer`` in ``_no_split_modules``, which transformers' device-map
    logic uses to keep such a module on a single device.
    """

    config_class = BSRoformerConfig
    base_model_prefix = "model"
    _no_split_modules = ["BSRoformerLayer"]
    supports_gradient_checkpointing = True
class BSRoformerModel(BSRoformerPreTrainedModel):
    """Band-split RoFormer backbone.

    ``BandSplit`` turns the ``(batch, time, features)`` input into a
    ``(batch, time, bands, hidden)`` grid. The grid is then processed as follows:

    1. Both the time and band axes are extended by ``register_token_num`` entries.
       The learned ``register_tokens`` fill the new ``(rn, rn)`` corner, and the
       other new cells are zero.
    2. ``depth`` rounds of a time-axis transformer followed by a band-axis
       transformer are applied.
    3. The extra entries are cropped away and a final RMSNorm is applied.
    """

    def __init__(self, config: BSRoformerConfig):
        super().__init__(config)
        self.config = config
        # NOTE: construction order is kept fixed because it determines the RNG draws
        # used to initialise the parameters.
        self.band_split = BandSplit(config)
        self.layers = nn.ModuleList(self._make_axial_pair(config) for _ in range(config.depth))
        self.rotary_emb = RotaryEmbedding(config)
        self.final_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        num_reg = config.register_token_num
        self.register_tokens = nn.Parameter(
            torch.normal(0, 0.02, size=(num_reg, num_reg, config.hidden_size))
        )
        self.post_init()

    @staticmethod
    def _make_axial_pair(config):
        """Build one ``[time transformer, band transformer]`` round."""
        return nn.ModuleList(
            [
                BSRoformerAxialTransformer(config, config.time_transformer_depth, is_time_transformer=True),
                BSRoformerAxialTransformer(config, config.freq_transformer_depth, is_time_transformer=False),
            ]
        )

    @staticmethod
    def _pad_rope(cos, sin, num_extra):
        """Append ``num_extra`` identity rotations (cos=1, sin=0) along the sequence axis."""
        return (
            F.pad(cos, (0, 0, 0, num_extra), value=1.0),
            F.pad(sin, (0, 0, 0, num_extra), value=0.0),
        )

    def forward(
        self,
        x,
        position_ids=None,
    ):
        """Encode ``x`` and return ``(batch, time, bands, hidden)`` features.

        ``position_ids`` defaults to ``arange(time)[None]``. NOTE(review): attention
        folds the other axis into the batch, so a position batch other than 1 is
        not expected to broadcast correctly. Confirm before passing per-sample
        positions.
        """
        hidden_states = self.band_split(x)
        _, num_frames, num_bands, _ = hidden_states.shape
        device = hidden_states.device
        if position_ids is None:
            position_ids = torch.arange(num_frames, device=device).unsqueeze(0)
        band_ids = torch.arange(num_bands, device=device).unsqueeze(0)

        num_reg = self.config.register_token_num
        # Register positions get an identity rotation (no positional signal).
        time_rope = self._pad_rope(*self.rotary_emb(hidden_states, position_ids), num_reg)
        band_rope = self._pad_rope(*self.rotary_emb(hidden_states, band_ids), num_reg)

        # Pad (time, bands) -> (time + rn, bands + rn); only the new corner gets tokens.
        hidden_states = F.pad(hidden_states, (0, 0, 0, num_reg, 0, num_reg))
        hidden_states[:, num_frames:, num_bands:, :] = self.register_tokens

        for time_transformer, freq_transformer in self.layers:
            hidden_states = time_transformer(
                hidden_states,
                position_embeddings=time_rope,
                attention_mask=None,
            )
            hidden_states = freq_transformer(
                hidden_states,
                position_embeddings=band_rope,
                attention_mask=None,
            )

        return self.final_norm(hidden_states[:, :num_frames, :num_bands, :])
class BSRoformerForMaskedEstimation(BSRoformerPreTrainedModel):
    """BS-RoFormer source separator: STFT -> band-split RoFormer -> complex masks -> iSTFT.

    ``forward`` maps a ``(batch, channels, samples)`` waveform to per-stem waveforms
    ``(batch, num_stems, channels, samples)``, or to an L1 loss when ``target`` is
    given. ``separate`` applies ``forward`` to overlapping chunks of a long recording.
    """

    def __init__(self, config: BSRoformerConfig):
        super().__init__(config)
        self.config = config
        self.model = BSRoformerModel(config)
        self.mask_estimators = nn.ModuleList([MaskEstimator(config) for _ in range(config.num_stems)])
        self.stft_kwargs = dict(
            n_fft=config.stft_n_fft,
            hop_length=config.stft_hop_length,
            win_length=config.stft_win_length,
            normalized=False,
        )
        self.register_buffer("stft_window", torch.hann_window(config.stft_win_length), persistent=False)
        freqs = config.stft_n_fft // 2 + 1
        # Explicit raise (not assert) so the check also runs under `python -O`.
        if sum(config.freqs_per_bands) != freqs:
            raise ValueError(f"Sum of freqs_per_bands must be {freqs}")
        self.wave_channels = config.num_input_channel

    @staticmethod
    def _autocast_device_type(device):
        """Device type to hand to ``torch.autocast``, with "mps" mapped to "cpu".

        This mirrors the fallback in ``RotaryEmbedding.forward``. The assumption
        (TODO confirm) is that some PyTorch releases reject ``device_type="mps"``
        even with ``enabled=False``.
        """
        return "cpu" if device.type == "mps" else device.type

    def forward(
        self,
        raw_audio: torch.Tensor,
        target: Optional[torch.Tensor] = None,
    ):
        """Separate ``raw_audio`` into stems, or compute the L1 loss against ``target``.

        Args:
            raw_audio: ``(batch, channels, samples)`` waveform. ``channels`` must equal
                ``config.num_input_channel`` for the band split to line up.
            target: Optional reference stems. The last axis is cropped to the
                reconstruction length, and the rest is expected to match
                ``(batch, num_stems, channels)``.

        Returns:
            The reconstructed stems ``(batch, num_stems, channels, samples)`` when
            ``target`` is None, otherwise the L1 loss.
        """
        autocast_device = self._autocast_device_type(raw_audio.device)
        # Spectral transforms run with autocast disabled.
        with torch.autocast(device_type=autocast_device, enabled=False):
            _, c, _ = raw_audio.shape
            raw_audio_packed = rearrange(raw_audio, "b c t -> (b c) t")
            stft_repr = torch.stft(
                raw_audio_packed,
                **self.stft_kwargs,
                window=self.stft_window,
                return_complex=True,
            )
            stft_repr = torch.view_as_real(stft_repr)  # (b*c, freqs, frames, 2)
            stft_repr = rearrange(stft_repr, "(b c) f t T -> b c f t T", c=c)
            # One feature vector per frame, ordered (freq, channel, real/imag); BandSplit
            # slices this axis into bands of consecutive frequency bins.
            stft_repr_merged = rearrange(stft_repr, "b c f t T -> b t (f c T)")

        hidden_states = self.model(stft_repr_merged)
        mask = torch.stack([fn(hidden_states) for fn in self.mask_estimators], dim=1)
        mask = rearrange(mask, "b n t (f c T) -> b n c f t T", T=2, c=c)
        mask = mask.to(dtype=torch.float32)

        with torch.autocast(device_type=autocast_device, enabled=False):
            stft_repr_expanded = rearrange(stft_repr, "b c f t T -> b 1 c f t T")
            stft_repr_complex = torch.view_as_complex(stft_repr_expanded)
            mask_complex = torch.view_as_complex(mask)
            # Complex multiplication: the mask scales and rotates each STFT bin per stem.
            masked_stft = stft_repr_complex * mask_complex
            masked_stft = rearrange(masked_stft, "b n c f t -> (b n c) f t")
            recon_audio = torch.istft(
                masked_stft,
                **self.stft_kwargs,
                window=self.stft_window,
                return_complex=False,
                length=raw_audio.shape[-1],
            )
            recon_audio = rearrange(recon_audio, "(b n c) t -> b n c t", c=self.wave_channels, n=self.config.num_stems)

        if target is None:
            return recon_audio
        target = target[..., : recon_audio.shape[-1]]
        loss = F.l1_loss(recon_audio, target)
        return loss

    def separate(
        self,
        mixed_wave: torch.Tensor,
        chunk_size: int = 44100 * 8,
        overlap_size: int = 44100 * 4,
        batch_size: int = 16,
        gap_size: int = 44100 * 1,
        verbose: bool = True,
    ):
        """
        Separates a full audio waveform into its constituent stems using a sliding window approach.

        The padded waveform is cut into chunks of `chunk_size` samples whose starts are
        `overlap_size` samples apart. Each chunk is run through `forward`, and the
        overlapping outputs are combined as a weighted average.
        The per-chunk weight is:
        - zero for `gap_size` samples at both ends;
        - a linear fade over `chunk_size // 10` samples (clamped to fit);
        - one in the middle.
        Because no other chunk covers the very start and end of the recording, the
        first chunk's leading edge and the last chunk's trailing edge get full weight.
        If `overlap_size` exceeds roughly `chunk_size - 2 * gap_size`, interior
        samples receive no weight and come out as silence.

        Call `.eval()` first: `forward` applies attention dropout while `self.training` is set.

        Args:
            mixed_wave (`torch.Tensor` of shape `(channels, time)`):
                The raw audio waveform of the mixture.
            chunk_size (`int`, *optional*, defaults to `352800` (8 seconds at 44.1kHz)):
                The size of each audio chunk for processing. Must be larger than `2 * gap_size`.
            overlap_size (`int`, *optional*, defaults to `176400` (4 seconds at 44.1kHz)):
                The hop (stride) between the starts of consecutive chunks. Consecutive chunks
                therefore overlap by `chunk_size - overlap_size` samples. Must be positive.
            batch_size (`int`, *optional*, defaults to `16`):
                The number of chunks to process in a single batch.
            gap_size (`int`, *optional*, defaults to `44100` (1 second at 44.1kHz)):
                Number of samples at each end of a chunk that get zero weight, except at the
                start and end of the recording.
            verbose (`bool`, *optional*, defaults to `True`):
                Whether to print progress information during processing.
        Returns:
            torch.Tensor (`torch.Tensor` of shape `(num_stems, channels, time)`):
                The separated audio waveforms.
        Raises:
            ValueError: If `mixed_wave` is not 2-D, `overlap_size <= 0`, or `chunk_size <= 2 * gap_size`.
        """
        if mixed_wave.dim() != 2:
            raise ValueError("Input `mixed_wave` must be a 2D tensor of shape (channels, time)")
        if overlap_size <= 0:
            raise ValueError("`overlap_size` must be a positive number of samples")
        inner_size = chunk_size - 2 * gap_size
        if inner_size <= 0:
            raise ValueError("`chunk_size` must be larger than `2 * gap_size`")
        device = mixed_wave.device

        # Fade-in/fade-out window. The fade is clamped to the non-gap length so the
        # slices below always match, and it is skipped when empty, because
        # `window[-0:]` would address the whole window.
        fade_size = min(chunk_size // 10, inner_size)
        window = torch.ones(inner_size, device=device)
        if fade_size > 0:
            window[:fade_size] = torch.linspace(0, 1, fade_size, device=device)
            window[-fade_size:] = torch.linspace(1, 0, fade_size, device=device)
        window = F.pad(window, (gap_size, gap_size), value=0.0)

        with torch.inference_mode():
            wave_length = mixed_wave.shape[-1]
            if wave_length <= chunk_size:
                num_chunks = 1
            else:
                num_chunks = math.ceil((wave_length - chunk_size) / overlap_size) + 1
            required_length = (num_chunks - 1) * overlap_size + chunk_size
            padded_wave = F.pad(
                mixed_wave,
                (0, required_length - wave_length),
                mode="constant",
            )
            unfolded_chunks = padded_wave.unfold(
                dimension=-1,
                size=chunk_size,
                step=overlap_size,
            )  # (C, num_chunks, chunk_size)
            batch = unfolded_chunks.permute(1, 0, 2)  # (num_chunks, C, chunk_size)
            if verbose:
                print(f"Input wave shape: {mixed_wave.shape}")
                print(f"Padded wave shape: {padded_wave.shape}")
                print(f"Number of chunks: {num_chunks}")
            output_chunks = []
            for i in range(0, num_chunks, batch_size):
                chunk_batch = batch[i : i + batch_size]
                output_chunk = self(chunk_batch)  # Call forward method
                output_chunks.append(output_chunk)
                if verbose:
                    print(f"Processed chunks {i} to {i + chunk_batch.shape[0]}")
            batch_output = torch.cat(output_chunks, dim=0)  # (num_chunks, num_stems, C, chunk_size)
            _, num_stems, C, _ = batch_output.shape
            batch_output = batch_output.view(num_chunks, -1, chunk_size).permute(1, 0, 2)  # (num_stems * C, num_chunks, chunk_size)

            # Per-chunk weights. The zero gap at the outer edge of the first and last
            # chunk is covered by no other chunk, so it would silence the start/end
            # of the recording. Give those edges full weight instead.
            chunk_windows = window.repeat(num_chunks, 1)  # (num_chunks, chunk_size)
            edge_size = gap_size + fade_size
            chunk_windows[0, :edge_size] = 1.0
            chunk_windows[-1, chunk_size - edge_size :] = 1.0

            batch_output = batch_output * chunk_windows
            # Overlap-add the weighted chunks and, separately, the weights themselves.
            output_result_buffer = F.fold(
                batch_output.permute(0, 2, 1),
                output_size=(1, required_length),
                kernel_size=(1, chunk_size),
                stride=(1, overlap_size),
            )  # (num_stems * C, 1, 1, required_length)
            weighted_sum_counter = F.fold(
                chunk_windows.unsqueeze(0).permute(0, 2, 1),
                output_size=(1, required_length),
                kernel_size=(1, chunk_size),
                stride=(1, overlap_size),
            )  # (1, 1, 1, required_length)
            output_result_buffer = output_result_buffer.view(num_stems, C, -1)  # (num_stems, C, required_length)
            weighted_sum_counter = weighted_sum_counter.view(1, 1, -1)
            # Avoid division by zero where no chunk contributes weight.
            weighted_sum_counter.clamp_min_(1e-8)
            final_output = (output_result_buffer / weighted_sum_counter)[:, :, :wave_length]
        return final_output