{ "cells": [ { "cell_type": "markdown", "id": "4c6d6855", "metadata": {}, "source": [ "- torch>=2.9.1\n", "- tqdm>=4.67.1\n", "- ipykernel>=6.30.1\n", "- ipywidgets>=8.1.7\n", "- transformers>=4.55.4\n", "- git+https://github.com/HiDolen/pytorch_lightning_quick_start_utils@v0.3\n", "- tensorboard>=2.20.0\n", "- lightning>=2.5.3\n", "- audiomentations[extras]>=0.42.0\n", "- matplotlib>=3.10.5\n", "- torchao>=0.12.0" ] }, { "cell_type": "markdown", "id": "e8306d9f", "metadata": {}, "source": [ "## 初始化" ] }, { "cell_type": "code", "execution_count": null, "id": "51338b4a", "metadata": {}, "outputs": [], "source": [ "import os\n", "import time\n", "from typing import List, Union, Optional\n", "import math\n", "from types import SimpleNamespace\n", "import random\n", "import glob\n", "from pathlib import Path\n", "import pickle\n", "from contextlib import nullcontext\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "from torch.utils.data import DataLoader, IterableDataset, Dataset\n", "\n", "from transformers.configuration_utils import PretrainedConfig\n", "from transformers.modeling_utils import PreTrainedModel\n", "from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS\n", "from transformers.activations import ACT2FN\n", "\n", "from einops import rearrange, pack, unpack\n", "import numpy as np\n", "from tqdm import tqdm\n", "\n", "import soundfile\n", "import audiomentations\n", "\n", "import numpy as np\n", "from tqdm import tqdm\n", "\n", "from pl_utils import BaseModule, TrainingConfig\n", "from pl_utils.misc import is_notebook, get_model_parameters_count" ] }, { "cell_type": "code", "execution_count": null, "id": "e15cad0e", "metadata": {}, "outputs": [], "source": [ "from pl_utils import init_before_training\n", "\n", "\n", "init_before_training(\n", " matmul_precision=\"medium\",\n", " seed=42,\n", ")\n", "\n", "num_workers = 28\n", "batch_size = 8\n", "\n", "wave_chunk_size = 44100 * 8" ] }, { "cell_type": "markdown", "id": "a828912f", "metadata": {}, "source": [ "## 定义" ] }, { "cell_type": "markdown", "id": "9592af7a", "metadata": {}, "source": [ "### Utils 定义" ] }, { "cell_type": "code", "execution_count": null, "id": "84dd1eec", "metadata": {}, "outputs": [], "source": [ "def loudness_db2linear(db):\n", " return 10 ** (db / 20)\n", "\n", "\n", "def loudness_linear2db(linear):\n", " return 20 * np.log10(linear)" ] }, { "cell_type": "markdown", "id": "68c460af", "metadata": {}, "source": [ "### Dataset 定义" ] }, { "cell_type": "code", "execution_count": null, "id": "71aaa349", "metadata": {}, "outputs": [], "source": [ "class AugmentDataset(IterableDataset):\n", " \"\"\"\n", " 用于 MUSDB18HQ 数据的、含有数据增强的 Dataset。返回分块音频。\n", "\n", " 期望的数据目录结构:\n", "\n", " dataset/\n", " ├── A Classic Education - NightOwl\n", " │ ├── bass.wav\n", " │ ├── drums.wav\n", " │ ├── mixture.wav\n", " │ ├── other.wav\n", " │ └── vocals.wav\n", " ├── Actions - Devil's Words\n", " │ ├── bass.wav\n", " │ ├── drums.wav\n", " │ ├── mixture.wav\n", " │ ├── other.wav\n", " │ └── vocals.wav\n", " ···\n", " \"\"\"\n", "\n", " def __init__(\n", " self,\n", " data_path,\n", " wave_chunk_size=44100 * 8,\n", " sample_rate=44100,\n", " same_stem_mixup_prob=[0.2, 0.02],\n", " same_stem_mixup_loudness_range=[-3, 3],\n", " stem_names=[\"bass\", \"drums\", \"other\", \"vocals\"],\n", " debug=False,\n", " ):\n", " if type(data_path) is not list:\n", " data_path = [data_path]\n", " self.data_path = [Path(p) for p in data_path]\n", "\n", " self.wave_chunk_size = wave_chunk_size\n", " self.sample_rate = sample_rate\n", "\n", " self.same_stem_mixup_prob = same_stem_mixup_prob\n", " self.same_stem_mixup_loudness_range = same_stem_mixup_loudness_range\n", " self.stem_names = stem_names\n", "\n", " self.metadata = self._get_metadata()\n", "\n", " self.augments = audiomentations.Compose(\n", " [\n", " # 极性反转\n", " audiomentations.PolarityInversion(p=0.5),\n", " # 音高偏移\n", " # audiomentations.PitchShift(\n", " # min_semitones=-5,\n", " # max_semitones=5,\n", " # p=0.5,\n", " # ),\n", " # 七频段 eq 随机调整\n", " audiomentations.SevenBandParametricEQ(\n", " min_gain_db=-9,\n", " max_gain_db=9,\n", " p=1.0,\n", " ),\n", " # tanh 失真\n", " audiomentations.TanhDistortion(\n", " min_distortion=0.1,\n", " max_distortion=0.6,\n", " p=0.5,\n", " ),\n", " # 低品质失真\n", " audiomentations.Mp3Compression(\n", " min_bitrate=32,\n", " max_bitrate=256,\n", " p=0.4,\n", " ),\n", " # 拉伸\n", " # audiomentations.TimeStretch(\n", " # min_rate=0.8,\n", " # max_rate=1.25,\n", " # p=1.0,\n", " # ),\n", " # 随机音量\n", " # audiomentations.GainTransition(\n", " # min_gain_db=-3,\n", " # max_gain_db=3,\n", " # min_duration=0.5,\n", " # max_duration=4.0,\n", " # p=1.0,\n", " # ),\n", " ]\n", " )\n", "\n", " self.file_handles = {}\n", " self.debug = debug\n", "\n", " def _get_one_of_metadata(self, data_path):\n", " song_paths = [p for p in data_path.iterdir() if p.is_dir()]\n", " # 读取缓存\n", " cache_path = data_path / \"metadata.pkl\"\n", " if cache_path.exists():\n", " with open(cache_path, \"rb\") as f:\n", " song_metadata = pickle.load(f)\n", " cache_paths = [m[0] for m in song_metadata]\n", " # 文件没有改动,直接使用缓存\n", " if set(cache_paths) == set(song_paths):\n", " return song_metadata\n", "\n", " # 构建缓存\n", " song_metadata = []\n", " for song_path in tqdm(song_paths, desc=\"Scanning dataset\"):\n", " wave_files = [f for f in song_path.iterdir() if f.is_file() and f.stem in self.stem_names]\n", "\n", " lengths = []\n", " for wave_file in wave_files:\n", " data, samplerate = soundfile.read(wave_file)\n", " assert samplerate == self.sample_rate, f\"Sample rate {samplerate} is not desired {self.sample_rate}\"\n", " track_length = len(data)\n", " lengths.append(track_length)\n", " if len(set(lengths)) > 1:\n", " print(f\"Warning: Inconsistent track lengths found in {song_path}. Using min length: {min(lengths)}\")\n", "\n", " stem_file_dict = {f.stem: f for f in wave_files}\n", " song_metadata.append((song_path, min(lengths), stem_file_dict))\n", "\n", " # 保存缓存\n", " with open(cache_path, \"wb\") as f:\n", " pickle.dump(song_metadata, f)\n", "\n", " return song_metadata\n", "\n", " def _get_metadata(self):\n", " all_metadata = []\n", " for p in self.data_path:\n", " metadata = self._get_one_of_metadata(p)\n", " all_metadata.extend(metadata)\n", " return all_metadata\n", "\n", " def _load_random_wave(self, stem_name):\n", " \"\"\"\n", " 从 self.metadata 选取出指定 stem_name 的音轨。来源歌曲、截取位置都随机。\n", "\n", " 截取长度由 `self.wave_chunk_size` 决定。\n", " \"\"\"\n", "\n", " # 尝试 10 次,保证音频响度大于 -50dB\n", " for _ in range(10):\n", " song_path, length, stem_file_dict = random.choice(self.metadata)\n", "\n", " # random offset within track\n", " offset = np.random.randint(length - self.wave_chunk_size + 1)\n", " # get or open cached file handle\n", " file_path = stem_file_dict[stem_name]\n", " if file_path not in self.file_handles:\n", " self.file_handles[file_path] = soundfile.SoundFile(str(file_path), mode='r')\n", " handle = self.file_handles[file_path]\n", " # seek and read chunk\n", " handle.seek(offset)\n", " wave = handle.read(self.wave_chunk_size, dtype='float32')\n", " wave = wave.T # (channel, time)\n", " if len(wave.shape) == 1: # 对 mono 音频添加 channel 维度\n", " wave = np.expand_dims(wave, axis=0)\n", "\n", " rms = np.sqrt(np.mean(wave**2))\n", " if rms > loudness_db2linear(-50):\n", " break\n", "\n", " if self.debug:\n", " print(f\"Warning: sampled very silent audio from {file_path} (rms={rms:.6f})\")\n", " # augmentation\n", " wave = self._apply_augment(wave, stem_name)\n", "\n", " return wave\n", "\n", " def _load_random_stems(self):\n", " \"\"\"\n", " 加载随机的 self.stem_names 分轨。\n", "\n", " 包含的数据增强:\n", "\n", " - 单个 stem 的来源歌曲和截取位置都随机(由 `self._load_random_track()` 实现)\n", " - 单个 stem 可能是多个同类型 stem 混合获得,概率由 `self.same_stem_mixup_prob` 决定\n", " - 混合 stem 时各个 stem 的响度在 `self.same_stem_mixup_loudness_range` 范围内随机\n", " \"\"\"\n", " waves = []\n", " for stem_name in self.stem_names:\n", " wave = self._load_random_wave(stem_name)\n", "\n", " mixup_waves = [wave]\n", " for prob in self.same_stem_mixup_prob:\n", " if random.uniform(0, 1) < prob:\n", " wave2 = self._load_random_wave(stem_name)\n", " mixup_waves.append(wave2)\n", "\n", " mixup_waves = np.stack(mixup_waves, axis=0)\n", "\n", " # 在 self.same_stem_mixup_loudness_range 范围内的随机响度\n", " loudness = np.random.uniform(\n", " low=loudness_db2linear(self.same_stem_mixup_loudness_range[0]),\n", " high=loudness_db2linear(self.same_stem_mixup_loudness_range[1]),\n", " size=(len(mixup_waves),),\n", " )\n", " mixup_waves *= loudness[:, None, None]\n", " mixup_wave = mixup_waves.mean(axis=0)\n", "\n", " waves.append(mixup_wave)\n", "\n", " waves = np.stack(waves, axis=0)\n", "\n", " return waves\n", "\n", " def _apply_augment(self, wave, stem_name):\n", " # Channel shuffle\n", " if random.uniform(0, 1) < 0.5:\n", " wave = wave[::-1].copy()\n", "\n", " # self.stem_augment\n", " wave = self.augments(samples=wave, sample_rate=self.sample_rate)\n", "\n", " return wave\n", "\n", " def __iter__(self):\n", " while True:\n", " waves = self._load_random_stems()\n", "\n", " # 随机分轨音量\n", " loudnesses = np.random.uniform(\n", " low=loudness_db2linear(-3),\n", " high=loudness_db2linear(3),\n", " size=(len(waves),),\n", " )\n", " # 各个 stem 有 10% 概率变为空音频\n", " loudnesses *= (np.random.uniform(0, 1, size=(len(waves),)) > 0.1).astype(np.float32)\n", " # 施加到 waves 上\n", " waves *= loudnesses[:, None, None]\n", "\n", " # 获得混合音频\n", " mixed_wave = waves.sum(0)\n", "\n", " yield waves, mixed_wave\n", "\n", " def __del__(self):\n", " # Close any open SoundFile handles when dataset is destroyed\n", " for handle in self.file_handles.values():\n", " try:\n", " handle.close()\n", " except Exception:\n", " pass\n", "\n", "\n", "class ValidationDataset(Dataset):\n", " \"\"\"\n", " 用于 MUSDB18HQ 数据的、用于验证的 Dataset。返回完整音频。\n", "\n", " 期望的数据目录结构:\n", "\n", " dataset/\n", " ├── A Classic Education - NightOwl\n", " │ ├── bass.wav\n", " │ ├── drums.wav\n", " │ ├── mixture.wav\n", " │ ├── other.wav\n", " │ └── vocals.wav\n", " ├── Actions - Devil's Words\n", " │ ├── bass.wav\n", " │ ├── drums.wav\n", " │ ├── mixture.wav\n", " │ ├── other.wav\n", " │ └── vocals.wav\n", " ···\n", " \"\"\"\n", "\n", " def __init__(\n", " self,\n", " data_path,\n", " sample_rate=44100,\n", " stem_names=[\"bass\", \"drums\", \"other\", \"vocals\"],\n", " ):\n", " self.data_path = Path(data_path)\n", " self.sample_rate = sample_rate\n", " self.stem_names = stem_names\n", "\n", " self.metadata = self._get_metadata()\n", "\n", " def _get_metadata(self):\n", " song_paths = [p for p in self.data_path.iterdir() if p.is_dir()]\n", " # 读取缓存\n", " cache_path = self.data_path / \"metadata.pkl\"\n", " if cache_path.exists():\n", " with open(cache_path, \"rb\") as f:\n", " song_metadata = pickle.load(f)\n", " cache_paths = [m[0] for m in song_metadata]\n", " # 文件没有改动,直接使用缓存\n", " if set(cache_paths) == set(song_paths):\n", " return song_metadata\n", "\n", " # 构建缓存\n", " song_metadata = []\n", " for song_path in tqdm(song_paths, desc=\"Scanning dataset\"):\n", " wave_files = [f for f in song_path.iterdir() if f.is_file() and f.stem in self.stem_names]\n", "\n", " lengths = []\n", " for wave_file in wave_files:\n", " data, samplerate = soundfile.read(wave_file)\n", " assert samplerate == self.sample_rate, f\"Sample rate {samplerate} is not desired {self.sample_rate}\"\n", " track_length = len(data)\n", " lengths.append(track_length)\n", " if len(set(lengths)) > 1:\n", " print(f\"Warning: Inconsistent track lengths found in {song_path}. Using min length: {min(lengths)}\")\n", "\n", " stem_file_dict = {f.stem: f for f in wave_files}\n", " song_metadata.append((song_path, min(lengths), stem_file_dict))\n", "\n", " # 保存缓存\n", " with open(cache_path, \"wb\") as f:\n", " pickle.dump(song_metadata, f)\n", "\n", " return song_metadata\n", "\n", " def __len__(self):\n", " return len(self.metadata)\n", "\n", " def __getitem__(self, index):\n", " song_path, length, stem_file_dict = self.metadata[index]\n", "\n", " waves = []\n", " for stem_name in self.stem_names:\n", " stem_file = stem_file_dict[stem_name]\n", " wave = soundfile.read(\n", " stem_file,\n", " dtype=\"float32\",\n", " )[0]\n", " wave = wave.T\n", " if len(wave.shape) == 1: # 对 mono 音频添加 channel 维度\n", " wave = np.expand_dims(wave, axis=0)\n", " waves.append(wave)\n", "\n", " waves = np.stack(waves, axis=0) # (stem, channel, time)\n", "\n", " # 获得混合音频\n", " mixed_wave = waves.sum(0)\n", "\n", " return waves, mixed_wave" ] }, { "cell_type": "markdown", "id": "22caec1a", "metadata": {}, "source": [ "### ModuleConfig 定义" ] }, { "cell_type": "code", "execution_count": null, "id": "591a48cd", "metadata": {}, "outputs": [], "source": [ "# in n_fft=4086,out n_fft=2048\n", "DEFAULT_FREQS_PER_BANDS = tuple([2 * 2] * 24 + [2 * 4] * 12 + [2 * 12] * 8 + [2 * 24] * 8 + [2 * 48] * 8 + [256, 257])\n", "DEFAULT_FREQS_PER_BANDS_OUT = tuple([2] * 24 + [4] * 12 + [12] * 8 + [24] * 8 + [48] * 8 + [128, 129])\n", "\n", "\n", "class BSRoformerConfig(PretrainedConfig):\n", "\n", " model_type = \"bs_roformer\"\n", "\n", " def __init__(\n", " self,\n", " hidden_size=384,\n", " num_hidden_layers=6,\n", " head_dim=None,\n", " attention_dropout=0.0,\n", " num_attention_heads=8,\n", " num_key_value_heads=8,\n", " intermediate_size=384 * 4,\n", " register_token_num=4,\n", " #\n", " num_input_channels=1,\n", " num_stems=1,\n", " band_proj_size=None,\n", " time_conv_length=None,\n", " time_transformer_depth=1,\n", " freq_transformer_depth=1,\n", " freqs_per_bands: tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,\n", " freqs_per_bands_out: tuple[int, ...] = DEFAULT_FREQS_PER_BANDS_OUT,\n", " #\n", " stft_n_fft=4096,\n", " stft_n_fft_out=2048,\n", " stft_hop_length=512,\n", " wave_sample_rate=44100, # 不会影响模型结构,只是记录训练参数\n", " wave_chunk_size=44100 * 6, # 不会影响模型结构,只是记录训练参数\n", " #\n", " rms_norm_eps=1e-6,\n", " rope_base=10000.0,\n", " #\n", " initializer_range=0.02,\n", " **kwargs,\n", " ):\n", " self.hidden_size = hidden_size\n", " self.num_hidden_layers = num_hidden_layers\n", " self.head_dim = head_dim if head_dim is not None else hidden_size // num_attention_heads\n", " self.attention_dropout = attention_dropout\n", " self.num_attention_heads = num_attention_heads\n", " self.num_key_value_heads = num_key_value_heads\n", " self.intermediate_size = intermediate_size\n", " self.register_token_num = register_token_num\n", "\n", " self.num_input_channels = num_input_channels\n", " self.num_stems = num_stems\n", " self.band_proj_size = band_proj_size if band_proj_size is not None else hidden_size\n", " self.time_conv_length = time_conv_length\n", " self.time_transformer_depth = time_transformer_depth\n", " self.freq_transformer_depth = freq_transformer_depth\n", " self.freqs_per_bands = freqs_per_bands\n", " self.freqs_per_bands_out = freqs_per_bands_out\n", " assert len(self.freqs_per_bands) == len(\n", " self.freqs_per_bands_out,\n", " ), f\"len(freqs_per_bands) ({len(self.freqs_per_bands)}) != len(freqs_per_bands_out) ({len(self.freqs_per_bands_out)})\"\n", "\n", " self.stft_n_fft = stft_n_fft\n", " self.stft_n_fft_out = stft_n_fft_out\n", " self.stft_hop_length = stft_hop_length\n", " self.wave_sample_rate = wave_sample_rate\n", " self.wave_chunk_size = wave_chunk_size\n", "\n", " freq_count = stft_n_fft // 2 + 1\n", " freq_count_out = self.stft_n_fft_out // 2 + 1\n", " assert (\n", " sum(self.freqs_per_bands) == freq_count\n", " ), f\"sum(freqs_per_bands) ({sum(self.freqs_per_bands)}) != freq_count ({freq_count})\"\n", " assert (\n", " sum(self.freqs_per_bands_out) == freq_count_out\n", " ), f\"sum(freqs_per_bands_out) ({sum(self.freqs_per_bands_out)}) != freq_count_out ({freq_count_out})\"\n", "\n", " self.rms_norm_eps = rms_norm_eps\n", " self.rope_base = rope_base\n", "\n", " self.initializer_range = initializer_range\n", "\n", " super().__init__(**kwargs)" ] }, { "cell_type": "markdown", "id": "ba4ce953", "metadata": {}, "source": [ "### 模型定义" ] }, { "cell_type": "markdown", "id": "6513ea69", "metadata": {}, "source": [ "#### RotaryEmbedding" ] }, { "cell_type": "code", "execution_count": null, "id": "48b33373", "metadata": {}, "outputs": [], "source": [ "def rotate_half(x):\n", " x1 = x[..., : x.shape[-1] // 2]\n", " x2 = x[..., x.shape[-1] // 2 :]\n", " return torch.cat((-x2, x1), dim=-1)\n", "\n", "\n", "def apply_rotary_pos_emb(q, k, cos, sin):\n", " q_embed = (q * cos) + (rotate_half(q) * sin)\n", " k_embed = (k * cos) + (rotate_half(k) * sin)\n", " return q_embed, k_embed\n", "\n", "\n", "class RotaryEmbedding(nn.Module):\n", " def __init__(\n", " self,\n", " head_dim,\n", " theta=10000.0,\n", " ):\n", " super().__init__()\n", " self.head_dim = head_dim\n", " inv_freq = 1.0 / (theta ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))\n", " self.register_buffer(\"inv_freq\", inv_freq)\n", "\n", " def forward(self, x, position_ids):\n", " inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)\n", " position_ids_expanded = position_ids[:, None, :].float()\n", "\n", " device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != \"mps\" else \"cpu\"\n", " with torch.autocast(device_type=device_type, enabled=False): # Force float32\n", " freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)\n", " emb = torch.cat((freqs, freqs), dim=-1)\n", " cos = emb.cos()\n", " sin = emb.sin()\n", "\n", " return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)" ] }, { "cell_type": "markdown", "id": "66f449e8", "metadata": {}, "source": [ "#### MLP" ] }, { "cell_type": "code", "execution_count": null, "id": "c2ae0987", "metadata": {}, "outputs": [], "source": [ "class MLP(nn.Module):\n", "\n", " def __init__(\n", " self,\n", " hidden_size: int,\n", " intermediate_size: int,\n", " out_size: int | None = None,\n", " bias: bool = False,\n", " ):\n", " super().__init__()\n", " self.hidden_size = hidden_size\n", " self.intermediate_size = intermediate_size\n", " self.out_size = out_size if out_size is not None else hidden_size\n", " \n", " self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)\n", " self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)\n", " self.down_proj = nn.Linear(self.intermediate_size, self.out_size, bias=bias)\n", " self.act_fn = ACT2FN[\"gelu\"]\n", "\n", " def forward(self, x):\n", " down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))\n", " return down_proj" ] }, { "cell_type": "markdown", "id": "6034d448", "metadata": {}, "source": [ "#### Attention" ] }, { "cell_type": "code", "execution_count": null, "id": "876edc29", "metadata": {}, "outputs": [], "source": [ "class Attention(nn.Module):\n", " def __init__(\n", " self,\n", " hidden_size: int,\n", " num_attention_heads: int,\n", " num_key_value_heads: int,\n", " attention_dropout: float,\n", " head_dim: int,\n", " ):\n", " super().__init__()\n", " self.is_causal = False\n", "\n", " self.head_dim = head_dim\n", " self.scaling = self.head_dim**-0.5\n", " self.attention_dropout = attention_dropout\n", "\n", " self.num_key_value_groups = num_attention_heads // num_key_value_heads\n", " self.q_proj = nn.Linear(hidden_size, num_attention_heads * self.head_dim, bias=True)\n", " self.k_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=True)\n", " self.v_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=True)\n", " self.o_proj = nn.Linear(num_attention_heads * self.head_dim, hidden_size, bias=True)\n", " def forward(\n", " self,\n", " hidden_states,\n", " position_embeddings: tuple[torch.Tensor, torch.Tensor],\n", " attention_mask=None,\n", " ):\n", " input_shape = hidden_states.size()[:-1]\n", " hidden_shape = (*input_shape, -1, self.head_dim) # b, n, d -> b, n, -1, d'\n", "\n", " # proj\n", " query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n", " key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n", " value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n", "\n", " # positional embeddings\n", " cos, sin = position_embeddings\n", " query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)\n", "\n", " # multi-group attention\n", " # key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)\n", " # value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)\n", "\n", " # mask\n", " if attention_mask is not None and attention_mask.dim() == 2:\n", " attention_mask = attention_mask[:, None, None, :] # [batch, 1, 1, seq_len]\n", "\n", " # attention\n", " attention_interface = ALL_ATTENTION_FUNCTIONS[\"sdpa\"]\n", " attn_output, attn_weights = attention_interface(\n", " self,\n", " query_states,\n", " key_states,\n", " value_states,\n", " attention_mask,\n", " dropout=0.0 if not self.training else self.attention_dropout,\n", " scaling=self.scaling,\n", " )\n", "\n", " attn_output = attn_output.reshape(*input_shape, -1).contiguous()\n", " attn_output = self.o_proj(attn_output)\n", "\n", " return attn_output, attn_weights" ] }, { "cell_type": "markdown", "id": "740ba83a", "metadata": {}, "source": [ "#### BSRoformerLayer" ] }, { "cell_type": "code", "execution_count": null, "id": "1574728b", "metadata": {}, "outputs": [], "source": [ "class BSRoformerLayer(nn.Module):\n", " def __init__(self, config: BSRoformerConfig):\n", " super().__init__()\n", " self.self_attn = Attention(\n", " config.hidden_size,\n", " config.num_attention_heads,\n", " config.num_key_value_heads,\n", " config.attention_dropout,\n", " config.head_dim,\n", " )\n", " self.mlp = MLP(\n", " config.hidden_size,\n", " config.intermediate_size,\n", " )\n", "\n", " self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n", " self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n", "\n", " def forward(\n", " self,\n", " hidden_states,\n", " position_embeddings,\n", " attention_mask,\n", " ):\n", " # Self Attention\n", " residual = hidden_states\n", " hidden_states = self.input_layernorm(hidden_states)\n", " hidden_states, _ = self.self_attn(\n", " hidden_states,\n", " position_embeddings,\n", " attention_mask,\n", " )\n", " hidden_states = hidden_states + residual\n", "\n", " # Fully Connected\n", " residual = hidden_states\n", " hidden_states = self.post_attention_layernorm(hidden_states)\n", " hidden_states = self.mlp(hidden_states)\n", " hidden_states = hidden_states + residual\n", "\n", " return hidden_states" ] }, { "cell_type": "markdown", "id": "9a286c76", "metadata": {}, "source": [ "#### BSRoformerAxialTransformer" ] }, { "cell_type": "code", "execution_count": null, "id": "e2aeeb00", "metadata": {}, "outputs": [], "source": [ "class BSRoformerAxialTransformer(nn.Module):\n", " def __init__(\n", " self,\n", " config: BSRoformerConfig,\n", " transformer_depth: int,\n", " is_time_transformer: bool,\n", " ):\n", " super().__init__()\n", " self.layers = nn.ModuleList([BSRoformerLayer(config) for _ in range(transformer_depth)])\n", " self.is_time_transformer = is_time_transformer\n", "\n", " def forward(\n", " self,\n", " hidden_states,\n", " position_embeddings,\n", " attention_mask,\n", " ):\n", " if self.is_time_transformer:\n", " hidden_states = rearrange(hidden_states, 'b t f d -> b f t d')\n", "\n", " # merge batch\n", " b, seq_len_1, seq_len_2, d = hidden_states.shape\n", " hidden_states = rearrange(hidden_states, 'b n m d -> (b n) m d')\n", "\n", " for layer in self.layers:\n", " hidden_states = layer(\n", " hidden_states,\n", " position_embeddings,\n", " attention_mask,\n", " )\n", "\n", " # unpack batch\n", " hidden_states = rearrange(hidden_states, '(b n) m d -> b n m d', b=b)\n", "\n", " if self.is_time_transformer:\n", " hidden_states = rearrange(hidden_states, 'b f t d -> b t f d')\n", "\n", " return hidden_states" ] }, { "cell_type": "markdown", "id": "51ca681e", "metadata": {}, "source": [ "#### BandSplit" ] }, { "cell_type": "code", "execution_count": null, "id": "62926fb5", "metadata": {}, "outputs": [], "source": [ "class BandSplit(nn.Module):\n", " def __init__(self, config: BSRoformerConfig):\n", " super().__init__()\n", " self.dim_inputs = tuple(2 * f * config.num_input_channels for f in config.freqs_per_bands)\n", " self.to_features = nn.ModuleList(\n", " [\n", " nn.Sequential(nn.RMSNorm(dim_in, eps=config.rms_norm_eps), nn.Linear(dim_in, config.band_proj_size))\n", " for dim_in in self.dim_inputs\n", " ]\n", " )\n", "\n", " def forward(self, x):\n", " x_split = x.split(self.dim_inputs, dim=-1)\n", " outs = [to_feature(split_input) for split_input, to_feature in zip(x_split, self.to_features)]\n", " return torch.stack(outs, dim=-2)" ] }, { "cell_type": "markdown", "id": "71547d98", "metadata": {}, "source": [ "#### MaskEstimator" ] }, { "cell_type": "code", "execution_count": null, "id": "b12c1524", "metadata": {}, "outputs": [], "source": [ "class MaskEstimator(nn.Module):\n", " def __init__(self, config: BSRoformerConfig):\n", " super().__init__()\n", "\n", " dim_inputs = tuple(f * config.num_input_channels * 2 for f in config.freqs_per_bands_out)\n", " self.to_freq_mlps = nn.ModuleList([nn.Linear(config.band_proj_size, dim) for dim in dim_inputs])\n", " self.to_gate_mlps = nn.ModuleList([nn.Linear(config.band_proj_size, dim // 2) for dim in dim_inputs])\n", "\n", " def forward(self, x):\n", " \"\"\"\n", "\n", " Args:\n", " x: (batch, time, bands, band_proj_size)\n", " \n", " Returns:\n", " (batch, time, freq * channel * 2)\n", " \"\"\"\n", " x_unbind = x.unbind(dim=-2)\n", " outs = []\n", " for band_features, freq_mlp, gate_mlp in zip(x_unbind, self.to_freq_mlps, self.to_gate_mlps):\n", " mask = freq_mlp(band_features)\n", " gate = gate_mlp(band_features)\n", " gate = gate.repeat_interleave(2, dim=-1)\n", " outs.append(mask * torch.sigmoid(gate))\n", " return torch.cat(outs, dim=-1)" ] }, { "cell_type": "markdown", "id": "780e0637", "metadata": {}, "source": [ "#### BSRoformerPreTrainedModel" ] }, { "cell_type": "code", "execution_count": null, "id": "7ff31afd", "metadata": {}, "outputs": [], "source": [ "class BSRoformerPreTrainedModel(PreTrainedModel):\n", " config_class = BSRoformerConfig\n", " base_model_prefix = \"freq_domain_model\"\n", " _no_split_modules = [\"BSRoformerLayer\"]" ] }, { "cell_type": "markdown", "id": "5a5ff848", "metadata": {}, "source": [ "#### BSRoformerModel" ] }, { "cell_type": "code", "execution_count": null, "id": "b54ae746", "metadata": {}, "outputs": [], "source": [ "class BSRoformerModel(BSRoformerPreTrainedModel):\n", " \"\"\"BS-RoFormer 模型的核心,在频域上对音频进行建模。\"\"\"\n", " def __init__(self, config: BSRoformerConfig):\n", " super().__init__(config)\n", " self.config = config\n", "\n", " # 主要模块\n", " self.rotary_emb = RotaryEmbedding(config.head_dim, theta=config.rope_base)\n", " self.band_split = BandSplit(config)\n", " self.layers = nn.ModuleList(\n", " nn.ModuleList(\n", " [\n", " BSRoformerAxialTransformer(config, config.time_transformer_depth, is_time_transformer=True),\n", " BSRoformerAxialTransformer(config, config.freq_transformer_depth, is_time_transformer=False),\n", " ]\n", " )\n", " for _ in range(config.num_hidden_layers)\n", " )\n", " self.final_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n", " self.mask_estimators = nn.ModuleList([MaskEstimator(config) for _ in range(config.num_stems)])\n", "\n", " # 时域压缩\n", " self.time_conv_length = config.time_conv_length\n", " if self.time_conv_length is not None:\n", " self.time_conv = nn.Sequential(\n", " nn.RMSNorm(config.band_proj_size * self.time_conv_length, eps=config.rms_norm_eps),\n", " MLP(\n", " hidden_size=config.band_proj_size * self.time_conv_length,\n", " intermediate_size=config.hidden_size * self.time_conv_length,\n", " out_size=config.hidden_size,\n", " bias=True,\n", " ),\n", " )\n", " self.time_deconv = nn.Sequential(\n", " MLP(\n", " hidden_size=config.hidden_size,\n", " intermediate_size=config.hidden_size * self.time_conv_length,\n", " out_size=config.band_proj_size * self.time_conv_length,\n", " bias=True,\n", " ),\n", " nn.RMSNorm(config.band_proj_size * self.time_conv_length, eps=config.rms_norm_eps),\n", " )\n", "\n", "\n", " # 寄存器\n", " rn = config.register_token_num\n", " self.register_tokens = nn.Parameter(torch.normal(0, 0.02, size=(rn, rn, config.hidden_size)))\n", "\n", " self.post_init()\n", "\n", " def forward(\n", " self,\n", " x,\n", " position_ids=None,\n", " ):\n", " \"\"\"\n", " Args:\n", " x: (batch, time, freq_bins),其中 freq_bins = freq * channel * 2\n", " position_ids: (batch, time)\n", " Returns:\n", " mask: (batch, num_stems, time, freq_bins)\n", " \"\"\"\n", " origin_dtype = x.dtype\n", " target_dtype = next(self.parameters()).dtype\n", " x = x.to(dtype=target_dtype)\n", " t_origin = x.shape[1]\n", "\n", " # 1. band split\n", " if self.time_conv_length is not None:\n", " pad_t = (self.time_conv_length - (t_origin % self.time_conv_length)) % self.time_conv_length\n", " if pad_t > 0:\n", " x = F.pad(x, (0, 0, 0, pad_t), value=0.0)\n", " hidden_states = self.band_split(x)\n", " if self.time_conv_length is not None:\n", " hidden_states = rearrange(hidden_states, \"b (t t_c) n d -> b t n (d t_c)\", t_c=self.time_conv_length)\n", " hidden_states = self.time_conv(hidden_states)\n", " b, t, n, h = hidden_states.shape # [batch, t, n, hidden_size]\n", "\n", " # 2. RoPE\n", " if position_ids is None:\n", " position_ids = torch.arange(t, device=hidden_states.device).unsqueeze(0)\n", " pos_embeds = self.rotary_emb(hidden_states, position_ids)\n", " pos_embeds_for_freq = self.rotary_emb(\n", " hidden_states,\n", " torch.arange(n, device=hidden_states.device).unsqueeze(0),\n", " )\n", "\n", " # 3. add register tokens\n", " rn = self.config.register_token_num\n", " hidden_states = F.pad(hidden_states, (0, 0, 0, rn, 0, rn))\n", " hidden_states[:, t:, n:, :] = self.register_tokens\n", "\n", " def pad_rope(cos, sin):\n", " cos_padded = F.pad(cos, (0, 0, 0, rn), value=1.0)\n", " sin_padded = F.pad(sin, (0, 0, 0, rn), value=0.0)\n", " return cos_padded, sin_padded\n", "\n", " pos_embeds = pad_rope(*pos_embeds)\n", " pos_embeds_for_freq = pad_rope(*pos_embeds_for_freq)\n", "\n", " # 4. axial transformer layers\n", " for time_transformer, freq_transformer in self.layers:\n", " hidden_states = time_transformer(\n", " hidden_states,\n", " position_embeddings=pos_embeds,\n", " attention_mask=None,\n", " )\n", " hidden_states = freq_transformer(\n", " hidden_states,\n", " position_embeddings=pos_embeds_for_freq,\n", " attention_mask=None,\n", " )\n", "\n", " # 5. remove register tokens, and final norm\n", " hidden_states = hidden_states[:, :t, :n, :]\n", " hidden_states = self.final_norm(hidden_states)\n", "\n", " # 6. mask estimation\n", " if self.time_conv_length is not None:\n", " hidden_states = self.time_deconv(hidden_states)\n", " hidden_states = rearrange(hidden_states, \"b t n (d t_c) -> b (t t_c) n d\", t_c=self.time_conv_length)\n", " hidden_states = hidden_states[:, :t_origin, :, :]\n", " mask = torch.stack([fn(hidden_states) for fn in self.mask_estimators], dim=1)\n", "\n", " return mask.to(dtype=origin_dtype)" ] }, { "cell_type": "markdown", "id": "e0da7e63", "metadata": {}, "source": [ "#### BSRoformerForMaskedEstimation" ] }, { "cell_type": "code", "execution_count": null, "id": "d4e9d81c", "metadata": {}, "outputs": [], "source": [ "class BSRoformerForMaskedEstimation(BSRoformerPreTrainedModel):\n", " \"\"\"包含 STFT/iSTFT 和频域处理的完整模型。\"\"\"\n", "\n", " def __init__(\n", " self,\n", " config: BSRoformerConfig,\n", " ):\n", " super().__init__(config)\n", " self.freq_domain_model = BSRoformerModel(config)\n", " self.config = config\n", "\n", " self.register_buffer(\"stft_window\", torch.hann_window(config.stft_n_fft), persistent=False)\n", " self.register_buffer(\"stft_out_window\", torch.hann_window(config.stft_n_fft_out), persistent=False)\n", "\n", " self.stft_kwargs = dict(\n", " n_fft=config.stft_n_fft,\n", " hop_length=config.stft_hop_length,\n", " win_length=config.stft_n_fft,\n", " normalized=False,\n", " )\n", " self.stft_out_kwargs = dict(\n", " n_fft=config.stft_n_fft_out,\n", " hop_length=config.stft_hop_length,\n", " win_length=config.stft_n_fft_out,\n", " normalized=False,\n", " )\n", "\n", " freqs = config.stft_n_fft // 2 + 1\n", " assert sum(config.freqs_per_bands) == freqs, f\"Sum of freqs_per_bands must be {freqs}\"\n", " self.wave_channels = config.num_input_channels\n", "\n", " def forward(\n", " self,\n", " raw_audio: torch.Tensor,\n", " target: torch.Tensor | None = None,\n", " ):\n", " \"\"\"\n", " Args:\n", " raw_audio (`torch.Tensor` of shape `(batch, channels, time)`):\n", " The raw audio waveform. `time` must be `config.wave_chunk_size`.\n", " target (`torch.Tensor`, *optional*, shape `(batch, num_stems, channels, time)`):\n", " The target audio waveform for loss calculation.\n", "\n", " Returns:\n", " torch.Tensor (`torch.Tensor` of shape `(batch, num_stems, channels, time)`):\n", " The reconstructed audio waveform.\n", " \"\"\"\n", "\n", " device = raw_audio.device\n", " dtype = raw_audio.dtype\n", " b, c, t = raw_audio.shape # batch, channel, time\n", "\n", " # 1. STFT: Convert audio to spectrogram\n", " with torch.autocast(device_type=device.type, enabled=False):\n", " raw_audio = raw_audio.to(dtype=torch.float32)\n", "\n", " raw_audio_packed = rearrange(raw_audio, \"b c t -> (b c) t\")\n", " stft_repr = torch.stft(\n", " raw_audio_packed,\n", " **self.stft_kwargs,\n", " window=self.stft_window,\n", " return_complex=True,\n", " )\n", " stft_repr = torch.view_as_real(stft_repr)\n", " stft_repr = rearrange(stft_repr, \"(b c) f t T -> b c f t T\", c=c)\n", " # Merge frequency, channel, and complex dimensions\n", " stft_repr_merged = rearrange(stft_repr, \"b c f t T -> b t (f c T)\")\n", "\n", " stft_repr_merged = stft_repr_merged.to(dtype=dtype)\n", "\n", " # 2. Mask Estimation\n", " mask = self.freq_domain_model(stft_repr_merged)\n", " mask = rearrange(mask, \"b n t (f c T) -> b n c f t T\", T=2, c=c)\n", " mask = mask.to(dtype=torch.float32)\n", "\n", " # 3. Mask Application\n", " with torch.autocast(device_type=device.type, enabled=False):\n", " stft_repr = torch.stft(\n", " raw_audio_packed,\n", " **self.stft_out_kwargs,\n", " window=self.stft_out_window,\n", " return_complex=True,\n", " )\n", " stft_repr = torch.view_as_real(stft_repr)\n", " stft_repr_expanded = rearrange(stft_repr, \"(b c) f t T -> b 1 c f t T\", c=c)\n", " stft_repr_complex = torch.view_as_complex(stft_repr_expanded)\n", " mask_complex = torch.view_as_complex(mask)\n", " masked_stft = stft_repr_complex * mask_complex\n", "\n", " # 4. iSTFT: Convert masked spectrogram back to audio\n", " masked_stft = rearrange(masked_stft, \"b n c f t -> (b n c) f t\")\n", " recon_audio = torch.istft(\n", " masked_stft,\n", " **self.stft_out_kwargs,\n", " window=self.stft_out_window,\n", " return_complex=False,\n", " length=raw_audio.shape[-1],\n", " )\n", " recon_audio = rearrange(recon_audio, \"(b n c) t -> b n c t\", c=self.wave_channels, n=self.config.num_stems)\n", "\n", " if target is None: # return recon_audio\n", " return recon_audio\n", "\n", " # 5. Loss Calculation\n", " target = target[..., : recon_audio.shape[-1]]\n", " loss = F.l1_loss(recon_audio, target)\n", " return loss\n", "\n", " @torch.inference_mode()\n", " def separate(\n", " model,\n", " mixed_wave,\n", " chunk_size=None,\n", " overlap_size=None,\n", " batch_size=4,\n", " gap_size=44100 * 1,\n", " verbose=True,\n", " ):\n", " \"\"\"\n", " 输入一段 (C, wave_length) 音频张量,使用模型推理,输出 (num_stems, C, wave_length) 音频张量。\n", "\n", " 其中 C 是音频通道数,num_stems 是分轨数量。\n", "\n", " Separates a full audio waveform into its constituent stems.\n", "\n", " Args:\n", " mixed_wave (`torch.Tensor` of shape `(channels, time)`):\n", " The raw audio waveform of the mixture.\n", " chunk_size (`int`, *optional*, defaults to model.config.wave_chunk_size):\n", " The size of each audio chunk for processing.\n", " overlap_size (`int`, *optional*, defaults to `chunk_size // 2`):\n", " The size of the overlap between consecutive chunks.\n", " batch_size (`int`, *optional*, defaults to `4`):\n", " The number of chunks to process in a single batch.\n", " gap_size (`int`, *optional*, defaults to `44100` (1 second at 44.1kHz)):\n", " The size of the gap for the fade-in/fade-out window.\n", " verbose (`bool`, *optional*, defaults to `True`):\n", " Whether to print progress information during processing.\n", " Returns:\n", " torch.Tensor (`torch.Tensor` of shape `(num_stems, channels, time)`):\n", " The separated audio waveforms.\n", " \"\"\"\n", " assert mixed_wave.dim() == 2, \"mixed_wave must be a 2D tensor of shape (channels, time)\"\n", " assert (\n", " mixed_wave.size(0) == model.config.num_input_channels\n", " ), f\"mixed_wave must have {model.config.num_input_channels} channels, but got {mixed_wave.size(0)}\"\n", "\n", " chunk_size = chunk_size or model.config.wave_chunk_size\n", " overlap_size = overlap_size or (chunk_size // 2)\n", "\n", " # 淡入淡出 窗口\n", " fade_size = chunk_size // 10\n", " window = torch.ones(chunk_size - 2 * gap_size)\n", " window[:fade_size] = torch.linspace(0, 1, fade_size)\n", " window[-fade_size:] = torch.linspace(1, 0, fade_size)\n", " window = F.pad(window, (gap_size, gap_size), value=0.0)\n", " window = window.to(mixed_wave.device)\n", "\n", " # 分块准备\n", " wave_length = mixed_wave.shape[-1]\n", " n = math.ceil(max(wave_length - chunk_size, 0) / overlap_size) + 1 # 分块数量\n", " required_length = (n - 1) * overlap_size + chunk_size\n", "\n", " if verbose:\n", " print(f\"Input wave shape: {mixed_wave.shape}\")\n", " print(f\"Padded wave length: {required_length}\")\n", " print(f\"Batch size: {batch_size}\")\n", "\n", " # pad 与分块\n", " padded_wave = F.pad(mixed_wave, (0, required_length - wave_length), mode=\"constant\")\n", " unfolded_chunks = padded_wave.unfold(dimension=-1, size=chunk_size, step=overlap_size) # (C, n, chunk_size)\n", " batch = unfolded_chunks.permute(1, 0, 2) # (n, C, chunk_size)\n", "\n", " # 模型推理\n", " outputs = []\n", " for i, chunk_batch in enumerate(batch.split(batch_size, dim=0)):\n", " if verbose:\n", " print(f\"\\rProcessing: {i * batch_size + chunk_batch.shape[0]} / {n}\")\n", " outputs.append(model(chunk_batch))\n", " batch = torch.cat(outputs, dim=0) # (n, num_stems, C, chunk_size)\n", "\n", " # 加窗\n", " _, num_stems, C, _ = batch.shape\n", " batch = batch * window\n", "\n", " # 还原波形\n", " batch = batch.view(n, -1, chunk_size).permute(1, 2, 0) # (num_stems * C, chunk_size, n)\n", " output_result_buffer = F.fold(\n", " batch,\n", " output_size=(1, required_length),\n", " kernel_size=(1, chunk_size),\n", " stride=(1, overlap_size),\n", " ) # (num_stems * C, 1, 1, required_length)\n", " output_result_buffer = output_result_buffer.view(num_stems, C, -1) # (num_stems, C, required_length)\n", "\n", " # 获得权重和\n", " window_for_fold = window.expand(1, 1, -1).repeat(1, n, 1)\n", " weighted_sum_counter = F.fold(\n", " window_for_fold.permute(0, 2, 1),\n", " output_size=(1, required_length),\n", " kernel_size=(1, chunk_size),\n", " stride=(1, overlap_size),\n", " ) # (1, 1, 1, required_length)\n", " weighted_sum_counter = weighted_sum_counter.view(1, 1, -1) # (1, 1, required_length)\n", " weighted_sum_counter.clamp_min_(1e-8)\n", "\n", " # 归一化\n", " final_output = (output_result_buffer / weighted_sum_counter)[:, :, :wave_length]\n", "\n", " return final_output" ] }, { "cell_type": "markdown", "id": "e4e5c718", "metadata": {}, "source": [ "#### 冒烟测试" ] }, { "cell_type": "code", "execution_count": null, "id": "f0f2b263", "metadata": {}, "outputs": [], "source": [ "if is_notebook():\n", " model_config = BSRoformerConfig(\n", " hidden_size=8,\n", " num_hidden_layers=1,\n", " head_dim=4,\n", " num_attention_heads=4,\n", " num_key_value_heads=2,\n", " intermediate_size=8 * 2,\n", " register_token_num=2,\n", " #\n", " num_input_channels=2,\n", " num_stems=4,\n", " time_transformer_depth=1,\n", " freq_transformer_depth=1,\n", " freqs_per_bands=DEFAULT_FREQS_PER_BANDS,\n", " freqs_per_bands_out=DEFAULT_FREQS_PER_BANDS_OUT,\n", " #\n", " stft_n_fft=4096,\n", " stft_n_fft_out=2048,\n", " stft_hop_length=512,\n", " )\n", " model = BSRoformerForMaskedEstimation(model_config)\n", "\n", " dummy_input = torch.randn(6, 2, model_config.wave_chunk_size)\n", " output = model(dummy_input)\n", "\n", " dummy_targets = torch.randn(6, 4, 2, model_config.wave_chunk_size)\n", " loss = model(dummy_input, target=dummy_targets)\n", "\n", " dummy_song = torch.randn(2, 44100 * 30)\n", " result = model.separate(\n", " dummy_song,\n", " chunk_size=model_config.wave_chunk_size,\n", " overlap_size=model_config.wave_chunk_size // 2,\n", " batch_size=8,\n", " gap_size=44100 * 1,\n", " )\n", "\n", " del model, model_config, dummy_input, output, dummy_targets, loss" ] }, { "cell_type": "markdown", "id": "9d26ff61", "metadata": {}, "source": [ "## 实例化 Datasets" ] }, { "cell_type": "code", "execution_count": null, "id": "f4d791c0", "metadata": {}, "outputs": [], "source": [ "train_dataset = AugmentDataset(\n", " data_path=[\n", " \"/mnt/sda/data/20250826_MUSDB18HQ/train\",\n", " \"/mnt/sda/data/20250826_MUSDB18HQ/test\",\n", " # \"/mnt/sda/data/20250902_DSD100/datas\",\n", " ],\n", " wave_chunk_size=wave_chunk_size,\n", " stem_names=[\"bass\", \"drums\", \"other\", \"vocals\"],\n", ")\n", "val_dataset = ValidationDataset(\n", " data_path=\"/mnt/sda/data/20250826_MUSDB18HQ/valid\",\n", " stem_names=[\"bass\", \"drums\", \"other\", \"vocals\"],\n", ")\n", "\n", "train_loader = DataLoader(\n", " train_dataset,\n", " batch_size=batch_size,\n", " num_workers=num_workers,\n", " pin_memory=True,\n", " persistent_workers=True if num_workers > 0 else False,\n", " # prefetch_factor=4 if num_workers > 0 else None,\n", ")\n", "val_loader = DataLoader(\n", " val_dataset,\n", " batch_size=1,\n", " num_workers=num_workers,\n", " pin_memory=True,\n", " persistent_workers=True if num_workers > 0 else False,\n", " shuffle=False,\n", " # prefetch_factor=4 if num_workers > 0 else None,\n", ")" ] }, { "cell_type": "markdown", "id": "21701211", "metadata": {}, "source": [ "## Lightning" ] }, { "cell_type": "code", "execution_count": null, "id": "23cda886", "metadata": {}, "outputs": [], "source": [ "def compute_sdr(target, estimate):\n", " target_np = target.float().cpu().numpy()\n", " estimate_np = estimate.float().cpu().numpy()\n", "\n", " sdr_list = []\n", "\n", " for this_target, this_estimate in zip(target_np, estimate_np):\n", " channel_sdrs = []\n", " for this_channel_target, this_channel_estimate in zip(this_target, this_estimate):\n", " signal_power = np.sum(this_channel_target ** 2)\n", " noise_power = np.sum((this_channel_target - this_channel_estimate) ** 2)\n", "\n", " if noise_power == 0:\n", " sdr = float('inf')\n", " else:\n", " sdr = 10 * np.log10(signal_power / noise_power)\n", "\n", " channel_sdrs.append(sdr)\n", "\n", " channel_sdr_mean = np.mean(channel_sdrs)\n", " sdr_list.append(channel_sdr_mean)\n", "\n", " return sdr_list\n" ] }, { "cell_type": "code", "execution_count": null, "id": "2e5002b1", "metadata": {}, "outputs": [], "source": [ "from lightning.pytorch.strategies import ParallelStrategy\n", "\n", "class LightningModel(BaseModule):\n", "\n", " def __init__(\n", " self,\n", " model,\n", " training_config: TrainingConfig,\n", " ):\n", " super().__init__(model, training_config)\n", "\n", " self.validation_sdr_results = []\n", "\n", " def forward(self, x):\n", " return self.model(x)\n", "\n", " def training_step(self, batch, batch_idx):\n", " optimizers = self.optimizers()\n", " optimizers = [optimizers] if not isinstance(optimizers, list) else optimizers\n", "\n", " need_step = (batch_idx + 1) % self.training_config.accumulate_grad_batches == 0\n", " need_step = need_step or self.trainer.is_last_batch\n", "\n", " target_stems, mixed_audio = batch\n", "\n", " context = nullcontext()\n", " if not need_step and isinstance(self.trainer.strategy, ParallelStrategy):\n", " context = self.trainer.strategy.block_backward_sync()\n", " with context:\n", " loss = self.model(mixed_audio, target=target_stems)\n", " self.manual_backward(loss)\n", " if need_step:\n", " self.log('train/loss', loss.item(), sync_dist=True)\n", "\n", " # 梯度和优化器\n", " if need_step:\n", " grad_norm = torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=5.0)\n", " self.log('train/grad_norm', grad_norm.item(), sync_dist=True)\n", " for optimizer in optimizers:\n", " optimizer.step()\n", " optimizer.zero_grad()\n", " # 学习率调度器\n", " schedulers = self.lr_schedulers()\n", " schedulers = [schedulers] if not isinstance(schedulers, list) else schedulers\n", " for scheduler in schedulers:\n", " scheduler.step()\n", "\n", " def validation_step(self, batch, batch_idx):\n", " target_stems, mixed_audio = batch\n", "\n", " batch_size = mixed_audio.shape[0]\n", " batch_sdr_scores = []\n", "\n", " for i in range(batch_size):\n", " single_mixed = mixed_audio[i] # (channels, time)\n", " single_target = target_stems[i] # (stems, channels, time)\n", "\n", " with torch.no_grad():\n", " predicted_stems = self.model.separate(\n", " single_mixed,\n", " batch_size=16,\n", " gap_size=0,\n", " verbose=False,\n", " ) # (stems, channels, time)\n", "\n", " sdr = compute_sdr(single_target, predicted_stems)\n", " batch_sdr_scores.append(sdr)\n", "\n", " sdrs = np.array(batch_sdr_scores)\n", " sdrs = sdrs.mean(axis=0)\n", "\n", " self.validation_sdr_results.append(sdrs)\n", "\n", " return {\n", " \"val/sdr\": sdrs,\n", " }\n", "\n", " def on_validation_epoch_end(self):\n", " if len(self.validation_sdr_results) > 0:\n", " avg_sdrs = np.mean(self.validation_sdr_results, axis=0)\n", " self.log('val/sdr', avg_sdrs.mean(), on_step=False, on_epoch=True, sync_dist=True, prog_bar=True)\n", " for i, one in enumerate(avg_sdrs):\n", " self.log(f'val/sdr_stem_{i}', one, on_step=False, on_epoch=True, sync_dist=True)\n", "\n", " self.validation_sdr_results.clear()\n", "\n", " def on_save_checkpoint(self, checkpoint: dict) -> None:\n", " checkpoint_keys = list(checkpoint.keys())\n", " for key in checkpoint_keys:\n", " if key != 'state_dict':\n", " del checkpoint[key]" ] }, { "cell_type": "markdown", "id": "7e15c0cb", "metadata": {}, "source": [ "## 特殊 mel" ] }, { "cell_type": "code", "execution_count": null, "id": "9c26c48d", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import librosa\n", "from typing import Optional\n", "\n", "\n", "def compute_mel_freqs_per_bands(\n", " n_fft: int,\n", " sample_rate: int,\n", " num_bands: int,\n", " fmin: float = 0.0,\n", " fmax: Optional[float] = None,\n", " alpha: float = 1.0,\n", ") -> tuple[int, ...]:\n", "\n", " fmax = sample_rate / 2 if fmax is None else fmax\n", " freq_bins = n_fft // 2 + 1\n", "\n", " # 构造“变形 mel 轴”\n", " mel_min = librosa.hz_to_mel(fmin)\n", " mel_max = librosa.hz_to_mel(fmax)\n", " mel_lin = np.linspace(mel_min, mel_max, num_bands + 1)\n", " mel_warp = mel_min + (mel_lin - mel_min) ** alpha / (mel_max - mel_min) ** (alpha - 1)\n", " warped_freqs = librosa.mel_to_hz(mel_warp)\n", "\n", " # 查找 FFT bin 边界\n", " fft_freqs = librosa.fft_frequencies(sr=sample_rate, n_fft=n_fft)\n", " bin_boundaries = np.searchsorted(fft_freqs, warped_freqs, side=\"left\")\n", "\n", " # 强制单调递增,每个区间至少占 1 个 bin\n", " bin_boundaries[0] = 0\n", " for i in range(1, len(bin_boundaries)):\n", " bin_boundaries[i] = max(bin_boundaries[i], bin_boundaries[i - 1] + 1)\n", " bin_boundaries[-1] = freq_bins # 并确保覆盖到末尾\n", "\n", " return tuple(np.sort(np.diff(bin_boundaries)).tolist())\n", "\n", "\n", "freqs_per_bands_out = compute_mel_freqs_per_bands(\n", " n_fft=2048,\n", " sample_rate=44100,\n", " num_bands=80,\n", " alpha=1.5,\n", ")\n", "freqs_per_bands = list(one * 2 for one in freqs_per_bands_out)\n", "freqs_per_bands[-1] -= 1\n", "\n", "print(freqs_per_bands_out)\n", "print(sum(freqs_per_bands_out))" ] }, { "cell_type": "code", "execution_count": null, "id": "ae11c004", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "cum_adapt = np.cumsum(freqs_per_bands)\n", "cum_default = np.cumsum(DEFAULT_FREQS_PER_BANDS)\n", "\n", "plt.figure(figsize=(8, 4))\n", "plt.plot(cum_adapt, marker='o', label='freq_bands')\n", "plt.plot(cum_default, marker='x', label='DEFAULT_FREQS_PER_BANDS')\n", "plt.grid(alpha=0.3)\n", "plt.legend()\n", "plt.tight_layout()\n", "plt.show()\n", "\n", "# 可选:每个频带宽度直接比较\n", "plt.figure(figsize=(8, 4))\n", "plt.plot(freqs_per_bands, marker='o', label='freq_bands')\n", "plt.plot(DEFAULT_FREQS_PER_BANDS, marker='x', label='DEFAULT_FREQS_PER_BANDS')\n", "plt.grid(alpha=0.3)\n", "plt.legend()\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "e13d2d53", "metadata": {}, "source": [ "## 配置与实例化" ] }, { "cell_type": "markdown", "id": "c4304f6c", "metadata": {}, "source": [ "### 基本参数" ] }, { "cell_type": "code", "execution_count": null, "id": "c31b1d32", "metadata": {}, "outputs": [], "source": [ "from pl_utils import OneOptimizerConfig, TrainingConfig\n", "from pl_utils import LinearWarmupStepDecayLR\n", "\n", "scheduler = LinearWarmupStepDecayLR(\n", " max_steps=20000,\n", " lr_initial=1e-2,\n", " lr_max=1,\n", " lr_warmup_steps=400,\n", " decay_factor=0.99,\n", " decay_steps=40000,\n", ")\n", "\n", "muon_config = OneOptimizerConfig(\n", " torch.optim.Muon,\n", " {\n", " \"lr\": 5e-4,\n", " \"weight_decay\": 1e-2,\n", " \"adjust_lr_fn\": \"match_rms_adamw\",\n", " },\n", " scheduler=scheduler,\n", " keywords=[\"mlp.\", \"attn.\", \"conv\"],\n", " excluded_from_weight_decay=[\"bias\", \"norm\", \"embed\", \"scale\"],\n", ")\n", "\n", "adamw_config = OneOptimizerConfig(\n", " torch.optim.AdamW,\n", " {\n", " \"lr\": 5e-4,\n", " \"weight_decay\": 1e-2,\n", " },\n", " scheduler=scheduler,\n", " excluded_from_weight_decay=[\"bias\", \"norm\", \"embed\", \"scale\"],\n", ")\n", "\n", "training_config = TrainingConfig(\n", " optimizers=[muon_config, adamw_config],\n", " accumulate_grad_batches=1,\n", ")" ] }, { "cell_type": "markdown", "id": "2cdabb95", "metadata": {}, "source": [ "### 实例化" ] }, { "cell_type": "code", "execution_count": null, "id": "13030935", "metadata": {}, "outputs": [], "source": [ "model_config = BSRoformerConfig(\n", " hidden_size=384,\n", " num_hidden_layers=9,\n", " # head_dim=48,\n", " num_attention_heads=8,\n", " num_key_value_heads=4,\n", " intermediate_size=384 * 3,\n", " register_token_num=4,\n", " #\n", " num_input_channels=2,\n", " num_stems=4,\n", " band_proj_size=256,\n", " time_conv_length=4,\n", " time_transformer_depth=1,\n", " freq_transformer_depth=1,\n", " freqs_per_bands=freqs_per_bands,\n", " freqs_per_bands_out=freqs_per_bands_out,\n", " #\n", " stft_n_fft=4096,\n", " stft_n_fft_out=2048,\n", " stft_hop_length=512,\n", " wave_chunk_size=wave_chunk_size,\n", ")\n", "model = BSRoformerForMaskedEstimation(model_config)\n", "\n", "pl_model = LightningModel(model, training_config=training_config)" ] }, { "cell_type": "markdown", "id": "9a430e4f", "metadata": {}, "source": [ "## 正式训练" ] }, { "cell_type": "markdown", "id": "5af097b5", "metadata": {}, "source": [ "### 准备" ] }, { "cell_type": "code", "execution_count": null, "id": "90058d7a", "metadata": {}, "outputs": [], "source": [ "from lightning.pytorch.utilities.model_summary import summarize\n", "\n", "summarize(pl_model, max_depth=3)" ] }, { "cell_type": "code", "execution_count": null, "id": "9f80c543", "metadata": {}, "outputs": [], "source": [ "from lightning.fabric.utilities.throughput import measure_flops\n", "\n", "dummy_song = torch.randn(2, 44100 * 30).to(\"cuda\")\n", "model = model.to(\"cuda\")\n", "\n", "with torch.inference_mode():\n", " model_fwd = lambda: model.separate(\n", " dummy_song,\n", " gap_size=0,\n", " batch_size=16,\n", " )\n", " fwd_flops = measure_flops(model, model_fwd)\n", "print(f\"Forward FLOPs (inference 30s audio): {fwd_flops / 1e9:.2f} GFLOPs\")" ] }, { "cell_type": "code", "execution_count": null, "id": "d52e16e9", "metadata": {}, "outputs": [], "source": [ "from torchao.float8 import convert_to_float8_training, Float8LinearConfig\n", "\n", "def module_filter_fn(mod: torch.nn.Module, fqn: str):\n", " if \"attn.\" not in fqn and \"mlp.\" not in fqn:\n", " return False\n", " if isinstance(mod, torch.nn.Linear):\n", " if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:\n", " return False\n", " return True\n", "\n", "# convert_to_float8_training(model.freq_domain_model, module_filter_fn=module_filter_fn)\n", "\n", "model.freq_domain_model.compile(mode=\"reduce-overhead\")" ] }, { "cell_type": "markdown", "id": "2ec2b83b", "metadata": {}, "source": [ "### 开始" ] }, { "cell_type": "code", "execution_count": null, "id": "6ff9af11", "metadata": {}, "outputs": [], "source": [ "import lightning.pytorch as L\n", "from lightning.pytorch.callbacks import ModelCheckpoint\n", "from lightning.pytorch.loggers import TensorBoardLogger\n", "from pl_utils.lightning import format_next_version_name\n", "from lightning.pytorch.strategies import DDPStrategy\n", "\n", "name = \"baseline4,{},9层大模型,准备收尾\".format(get_model_parameters_count(pl_model)[\"total_readable\"])\n", "logger = TensorBoardLogger(save_dir=\"./\", version=format_next_version_name(name))\n", "\n", "checkpoint_callback = ModelCheckpoint(\n", " auto_insert_metric_name=True,\n", " save_top_k=1,\n", " monitor=\"val/sdr\",\n", " mode=\"max\",\n", " every_n_epochs=1,\n", " save_weights_only=True,\n", " # save_last=\"link\",\n", " save_on_train_epoch_end=False,\n", " save_last=True,\n", ")\n", "\n", "trainer = L.Trainer(\n", " logger=logger,\n", " accelerator='gpu',\n", " # max_epochs=16,\n", " strategy=DDPStrategy(find_unused_parameters=False),\n", " precision='bf16-true',\n", " # accumulate_grad_batches=4,\n", " max_steps=200000 * 4,\n", " val_check_interval=2000,\n", " log_every_n_steps=200,\n", " default_root_dir=\"./\",\n", " #\n", " callbacks=[checkpoint_callback],\n", " # enable_checkpointing=False,\n", " #\n", " num_sanity_val_steps=0,\n", " # fast_dev_run=True,\n", " # enable_checkpointing=False,\n", " enable_model_summary=True,\n", ")\n", "\n", "trainer.fit(pl_model, train_loader, val_loader)" ] }, { "cell_type": "markdown", "id": "f9304c3e", "metadata": {}, "source": [ "## 提前退出" ] }, { "cell_type": "code", "execution_count": null, "id": "5e11d871", "metadata": {}, "outputs": [], "source": [ "import sys\n", "\n", "if not is_notebook():\n", " sys.exit()" ] }, { "cell_type": "markdown", "id": "8bca044c", "metadata": {}, "source": [ "## 加载与推理" ] }, { "cell_type": "code", "execution_count": null, "id": "2f65245d", "metadata": {}, "outputs": [], "source": [ "ckpt = \"lightning_logs/version_229_baseline4,46.8M,9层大模型,准备收尾/checkpoints/epoch=0-step=620000.ckpt\"\n", "state_dict = torch.load(ckpt, weights_only=True)[\"state_dict\"]\n", "pl_model.load_state_dict(state_dict)\n", "pl_model = pl_model.to(\"cuda\")" ] }, { "cell_type": "code", "execution_count": null, "id": "be884045", "metadata": {}, "outputs": [], "source": [ "file = \"はるまきごはん,初音ミク - 宇宙分解.mp3\"\n", "waveform, sr = librosa.load(file, sr=44100, mono=False)\n", "waveform = torch.tensor(waveform).float()\n", "mixed_wave = waveform.to(\"cuda\")" ] }, { "cell_type": "code", "execution_count": null, "id": "badd73dd", "metadata": {}, "outputs": [], "source": [ "with torch.inference_mode():\n", " predicted_stems = pl_model.model.separate(\n", " torch.tensor(mixed_wave).to(\"cuda\"),\n", " chunk_size=wave_chunk_size,\n", " overlap_size=wave_chunk_size // 2,\n", " gap_size=0,\n", " batch_size=8,\n", " ) # (stems, channels, time)" ] }, { "cell_type": "code", "execution_count": null, "id": "a1fa8bbd", "metadata": {}, "outputs": [], "source": [ "predicted_stems.shape" ] }, { "cell_type": "code", "execution_count": null, "id": "751c4974", "metadata": {}, "outputs": [], "source": [ "os.makedirs(\"./outputs\", exist_ok=True)\n", "\n", "for i in range(predicted_stems.shape[0]):\n", " import soundfile as sf\n", "\n", " sf.write(f\"./outputs/predicted_stem_{i}.wav\", predicted_stems[i].cpu().numpy().T, 44100)" ] } ], "metadata": { "kernelspec": { "display_name": "20250820_bs-roformer", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.5" } }, "nbformat": 4, "nbformat_minor": 5 }