HiDolen committed on
Commit
94e120b
·
verified ·
1 Parent(s): 5f78dd6

Upload training process.ipynb

Browse files
Files changed (1) hide show
  1. training process.ipynb +1598 -0
training process.ipynb ADDED
@@ -0,0 +1,1598 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "e8306d9f",
6
+ "metadata": {},
7
+ "source": [
8
+ "## 初始化"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": null,
14
+ "id": "51338b4a",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "import os\n",
19
+ "import time\n",
20
+ "from typing import List, Union, Optional\n",
21
+ "import math\n",
22
+ "from types import SimpleNamespace\n",
23
+ "import random\n",
24
+ "import glob\n",
25
+ "from pathlib import Path\n",
26
+ "import pickle\n",
27
+ "\n",
28
+ "import torch\n",
29
+ "import torch.nn as nn\n",
30
+ "import torch.nn.functional as F\n",
31
+ "import torch.optim as optim\n",
32
+ "from torch.utils.data import DataLoader, IterableDataset, Dataset\n",
33
+ "\n",
34
+ "from transformers.configuration_utils import PretrainedConfig\n",
35
+ "from transformers.modeling_utils import PreTrainedModel\n",
36
+ "from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS\n",
37
+ "from transformers.activations import ACT2FN\n",
38
+ "\n",
39
+ "from einops import rearrange, pack, unpack\n",
40
+ "import numpy as np\n",
41
+ "from tqdm import tqdm\n",
42
+ "\n",
43
+ "import soundfile\n",
44
+ "import audiomentations\n",
45
+ "\n",
46
+ "import numpy as np\n",
47
+ "from tqdm import tqdm\n",
48
+ "\n",
49
+ "from pl_utils import BaseModule, LearningRateConfig, TrainingConfig"
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "code",
54
+ "execution_count": null,
55
+ "id": "e15cad0e",
56
+ "metadata": {},
57
+ "outputs": [],
58
+ "source": [
59
+ "from pl_utils import init_before_training\n",
60
+ "\n",
61
+ "\n",
62
+ "init_before_training(\n",
63
+ " matmul_precision=\"medium\",\n",
64
+ " empty_cache=False,\n",
65
+ " seed=42,\n",
66
+ ")\n",
67
+ "\n",
68
+ "num_workers = 28"
69
+ ]
70
+ },
71
+ {
72
+ "cell_type": "markdown",
73
+ "id": "a828912f",
74
+ "metadata": {},
75
+ "source": [
76
+ "## 定义"
77
+ ]
78
+ },
79
+ {
80
+ "cell_type": "markdown",
81
+ "id": "9592af7a",
82
+ "metadata": {},
83
+ "source": [
84
+ "### Utils 定义"
85
+ ]
86
+ },
87
+ {
88
+ "cell_type": "code",
89
+ "execution_count": null,
90
+ "id": "84dd1eec",
91
+ "metadata": {},
92
+ "outputs": [],
93
+ "source": [
94
+ "def loudness_db2linear(db):\n",
95
+ " return 10 ** (db / 20)\n",
96
+ "\n",
97
+ "\n",
98
+ "def loudness_linear2db(linear):\n",
99
+ " return 20 * np.log10(linear)\n",
100
+ "\n",
101
+ "\n",
102
+ "def inference_one_with_model(\n",
103
+ " model,\n",
104
+ " mixed_wave,\n",
105
+ " chunk_size=44100 * 8,\n",
106
+ " overlap_size=44100 * 4,\n",
107
+ " batch_size=16,\n",
108
+ " gap_size=44100 * 1,\n",
109
+ "):\n",
110
+ " \"\"\"\n",
111
+ " 输入一段 (C, wave_length) 音频张量,使用模型推理,输出 (num_stems, C, wave_length) 音频张量。\n",
112
+ " \"\"\"\n",
113
+ " # 淡入淡出 窗口\n",
114
+ " fade_size = chunk_size // 10\n",
115
+ " window = torch.ones(chunk_size - 2 * gap_size)\n",
116
+ " window[:fade_size] = torch.linspace(0, 1, fade_size)\n",
117
+ " window[-fade_size:] = torch.linspace(1, 0, fade_size)\n",
118
+ " window = F.pad(window, (gap_size, gap_size), value=0.0)\n",
119
+ " window = window.to(mixed_wave.device)\n",
120
+ "\n",
121
+ " with torch.inference_mode():\n",
122
+ " wave_length = mixed_wave.shape[-1]\n",
123
+ "\n",
124
+ " if wave_length <= chunk_size:\n",
125
+ " num_chunks = 1\n",
126
+ " else:\n",
127
+ " num_chunks = math.ceil((wave_length - chunk_size) / overlap_size) + 1\n",
128
+ "\n",
129
+ " required_length = (num_chunks - 1) * overlap_size + chunk_size\n",
130
+ " padded_wave = F.pad(\n",
131
+ " mixed_wave,\n",
132
+ " (0, required_length - wave_length),\n",
133
+ " mode=\"constant\",\n",
134
+ " )\n",
135
+ "\n",
136
+ " unfolded_chunks = padded_wave.unfold(\n",
137
+ " dimension=-1,\n",
138
+ " size=chunk_size,\n",
139
+ " step=overlap_size,\n",
140
+ " ) # (C, num_chunks, chunk_size)\n",
141
+ " batch = unfolded_chunks.permute(1, 0, 2) # (num_chunks, C, chunk_size)\n",
142
+ "\n",
143
+ " output_chunks = []\n",
144
+ " for i in range(0, num_chunks, batch_size):\n",
145
+ " chunk_batch = batch[i : i + batch_size]\n",
146
+ " output_chunk = model(chunk_batch)\n",
147
+ " output_chunks.append(output_chunk)\n",
148
+ " batch = torch.cat(output_chunks, dim=0) # (num_chunks, num_stems, C, chunk_size)\n",
149
+ "\n",
150
+ " _, num_stems, C, _ = batch.shape\n",
151
+ " batch = batch.view(num_chunks, -1, chunk_size).permute(1, 0, 2) # (num_stems * C, num_chunks, chunk_size)\n",
152
+ " batch = batch * window\n",
153
+ " output_result_buffer = F.fold(\n",
154
+ " batch.permute(0, 2, 1),\n",
155
+ " output_size=(1, required_length),\n",
156
+ " kernel_size=(1, chunk_size),\n",
157
+ " stride=(1, overlap_size),\n",
158
+ " ) # (num_stems * C, 1, 1, required_length)\n",
159
+ "\n",
160
+ " window_for_fold = window.expand(1, 1, -1).repeat(1, num_chunks, 1)\n",
161
+ " weighted_sum_counter = F.fold(\n",
162
+ " window_for_fold.permute(0, 2, 1),\n",
163
+ " output_size=(1, required_length),\n",
164
+ " kernel_size=(1, chunk_size),\n",
165
+ " stride=(1, overlap_size),\n",
166
+ " ) # (1, 1, 1, required_length)\n",
167
+ "\n",
168
+ " output_result_buffer = output_result_buffer.view(num_stems, C, -1) # (num_stems, C, required_length)\n",
169
+ " weighted_sum_counter = weighted_sum_counter.view(1, 1, -1)\n",
170
+ " weighted_sum_counter.clamp_min_(1e-8)\n",
171
+ "\n",
172
+ " final_output = (output_result_buffer / weighted_sum_counter)[:, :, :wave_length]\n",
173
+ "\n",
174
+ " return final_output"
175
+ ]
176
+ },
177
+ {
178
+ "cell_type": "markdown",
179
+ "id": "68c460af",
180
+ "metadata": {},
181
+ "source": [
182
+ "### Dataset 定义"
183
+ ]
184
+ },
185
+ {
186
+ "cell_type": "code",
187
+ "execution_count": null,
188
+ "id": "71aaa349",
189
+ "metadata": {},
190
+ "outputs": [],
191
+ "source": [
192
+ "class AugmentDataset(IterableDataset):\n",
193
+ " \"\"\"\n",
194
+ " 用于 MUSDB18HQ 数据的、含有数据增强的 Dataset。返回分块音频。\n",
195
+ "\n",
196
+ " 期望的数据目录结构:\n",
197
+ "\n",
198
+ " dataset/\n",
199
+ " ├── A Classic Education - NightOwl\n",
200
+ " │ ├── bass.wav\n",
201
+ " │ ├── drums.wav\n",
202
+ " │ ├── mixture.wav\n",
203
+ " │ ├── other.wav\n",
204
+ " │ └── vocals.wav\n",
205
+ " ├── Actions - Devil's Words\n",
206
+ " │ ├── bass.wav\n",
207
+ " │ ├── drums.wav\n",
208
+ " │ ├── mixture.wav\n",
209
+ " │ ├── other.wav\n",
210
+ " │ └── vocals.wav\n",
211
+ " ···\n",
212
+ " \"\"\"\n",
213
+ "\n",
214
+ " def __init__(\n",
215
+ " self,\n",
216
+ " data_path,\n",
217
+ " wave_chunk_size=44100 * 8,\n",
218
+ " sample_rate=44100,\n",
219
+ " same_stem_mixup_prob=[0.2, 0.02],\n",
220
+ " same_stem_mixup_loudness_range=[-3, 3],\n",
221
+ " stem_names=[\"bass\", \"drums\", \"other\", \"vocals\"],\n",
222
+ " debug=False,\n",
223
+ " ):\n",
224
+ " if type(data_path) is not list:\n",
225
+ " data_path = [data_path]\n",
226
+ " self.data_path = [Path(p) for p in data_path]\n",
227
+ "\n",
228
+ " self.wave_chunk_size = wave_chunk_size\n",
229
+ " self.sample_rate = sample_rate\n",
230
+ "\n",
231
+ " self.same_stem_mixup_prob = same_stem_mixup_prob\n",
232
+ " self.same_stem_mixup_loudness_range = same_stem_mixup_loudness_range\n",
233
+ " self.stem_names = stem_names\n",
234
+ "\n",
235
+ " self.metadata = self._get_metadata()\n",
236
+ "\n",
237
+ " self.augments = audiomentations.Compose(\n",
238
+ " [\n",
239
+ " # 极性反转\n",
240
+ " audiomentations.PolarityInversion(p=0.5),\n",
241
+ " # 音高偏移\n",
242
+ " # audiomentations.PitchShift(\n",
243
+ " # min_semitones=-5,\n",
244
+ " # max_semitones=5,\n",
245
+ " # p=0.5,\n",
246
+ " # ),\n",
247
+ " # 七频段 eq 随机调整\n",
248
+ " audiomentations.SevenBandParametricEQ(\n",
249
+ " min_gain_db=-9,\n",
250
+ " max_gain_db=9,\n",
251
+ " p=1.0,\n",
252
+ " ),\n",
253
+ " # tanh 失真\n",
254
+ " audiomentations.TanhDistortion(\n",
255
+ " min_distortion=0.1,\n",
256
+ " max_distortion=0.6,\n",
257
+ " p=0.5,\n",
258
+ " ),\n",
259
+ " # 低品质失真\n",
260
+ " audiomentations.Mp3Compression(\n",
261
+ " min_bitrate=32,\n",
262
+ " max_bitrate=256,\n",
263
+ " p=0.4,\n",
264
+ " ),\n",
265
+ " # 拉伸\n",
266
+ " # audiomentations.TimeStretch(\n",
267
+ " # min_rate=0.8,\n",
268
+ " # max_rate=1.25,\n",
269
+ " # p=1.0,\n",
270
+ " # ),\n",
271
+ " # 随机音量\n",
272
+ " # audiomentations.GainTransition(\n",
273
+ " # min_gain_db=-3,\n",
274
+ " # max_gain_db=3,\n",
275
+ " # min_duration=0.5,\n",
276
+ " # max_duration=4.0,\n",
277
+ " # p=1.0,\n",
278
+ " # ),\n",
279
+ " ]\n",
280
+ " )\n",
281
+ "\n",
282
+ " self.file_handles = {}\n",
283
+ " self.debug = debug\n",
284
+ "\n",
285
+ " def _get_one_of_metadata(self, data_path):\n",
286
+ " song_paths = [p for p in data_path.iterdir() if p.is_dir()]\n",
287
+ " # 读取缓存\n",
288
+ " cache_path = data_path / \"metadata.pkl\"\n",
289
+ " if cache_path.exists():\n",
290
+ " with open(cache_path, \"rb\") as f:\n",
291
+ " song_metadata = pickle.load(f)\n",
292
+ " cache_paths = [m[0] for m in song_metadata]\n",
293
+ " # 文件没有改动,直接使用缓存\n",
294
+ " if set(cache_paths) == set(song_paths):\n",
295
+ " return song_metadata\n",
296
+ "\n",
297
+ " # 构建缓存\n",
298
+ " song_metadata = []\n",
299
+ " for song_path in tqdm(song_paths, desc=\"Scanning dataset\"):\n",
300
+ " wave_files = [f for f in song_path.iterdir() if f.is_file() and f.stem in self.stem_names]\n",
301
+ "\n",
302
+ " lengths = []\n",
303
+ " for wave_file in wave_files:\n",
304
+ " data, samplerate = soundfile.read(wave_file)\n",
305
+ " assert samplerate == self.sample_rate, f\"Sample rate {samplerate} is not desired {self.sample_rate}\"\n",
306
+ " track_length = len(data)\n",
307
+ " lengths.append(track_length)\n",
308
+ " if len(set(lengths)) > 1:\n",
309
+ " print(f\"Warning: Inconsistent track lengths found in {song_path}. Using min length: {min(lengths)}\")\n",
310
+ "\n",
311
+ " stem_file_dict = {f.stem: f for f in wave_files}\n",
312
+ " song_metadata.append((song_path, min(lengths), stem_file_dict))\n",
313
+ "\n",
314
+ " # 保存缓存\n",
315
+ " with open(cache_path, \"wb\") as f:\n",
316
+ " pickle.dump(song_metadata, f)\n",
317
+ "\n",
318
+ " return song_metadata\n",
319
+ "\n",
320
+ " def _get_metadata(self):\n",
321
+ " all_metadata = []\n",
322
+ " for p in self.data_path:\n",
323
+ " metadata = self._get_one_of_metadata(p)\n",
324
+ " all_metadata.extend(metadata)\n",
325
+ " return all_metadata\n",
326
+ "\n",
327
+ " def _load_random_wave(self, stem_name):\n",
328
+ " \"\"\"\n",
329
+ " 从 self.metadata 选取出指定 stem_name 的音轨。来源歌曲、截取位置都随机。\n",
330
+ "\n",
331
+ " 截取长度由 `self.wave_chunk_size` 决定。\n",
332
+ " \"\"\"\n",
333
+ "\n",
334
+ " # 最多尝试 10 次,尽量选到响度大于 -50dB 的音频;若全部失败则沿用最后一次采样\n",
335
+ " for _ in range(10):\n",
336
+ " song_path, length, stem_file_dict = random.choice(self.metadata)\n",
337
+ "\n",
338
+ " # random offset within track\n",
339
+ " offset = np.random.randint(length - self.wave_chunk_size + 1)\n",
340
+ " # get or open cached file handle\n",
341
+ " file_path = stem_file_dict[stem_name]\n",
342
+ " if file_path not in self.file_handles:\n",
343
+ " self.file_handles[file_path] = soundfile.SoundFile(str(file_path), mode='r')\n",
344
+ " handle = self.file_handles[file_path]\n",
345
+ " # seek and read chunk\n",
346
+ " handle.seek(offset)\n",
347
+ " wave = handle.read(self.wave_chunk_size, dtype='float32')\n",
348
+ " wave = wave.T # (channel, time)\n",
349
+ " if len(wave.shape) == 1: # 对 mono 音频添加 channel 维度\n",
350
+ " wave = np.expand_dims(wave, axis=0)\n",
351
+ "\n",
352
+ " rms = np.sqrt(np.mean(wave**2))\n",
353
+ " if rms > loudness_db2linear(-50):\n",
354
+ " break\n",
355
+ "\n",
356
+ " if self.debug:\n",
357
+ " print(f\"Warning: sampled very silent audio from {file_path} (rms={rms:.6f})\")\n",
358
+ " # augmentation\n",
359
+ " wave = self._apply_augment(wave, stem_name)\n",
360
+ "\n",
361
+ " return wave\n",
362
+ "\n",
363
+ " def _load_random_stems(self):\n",
364
+ " \"\"\"\n",
365
+ " 加载随机的 self.stem_names 分轨。\n",
366
+ "\n",
367
+ " 包含的数据增强:\n",
368
+ "\n",
369
+ " - 单个 stem 的来源歌曲和截取位置都随机(由 `self._load_random_wave()` 实现)\n",
370
+ " - 单个 stem 可能是多个同类型 stem 混合获得,概率由 `self.same_stem_mixup_prob` 决定\n",
371
+ " - 混合 stem 时各个 stem 的响度在 `self.same_stem_mixup_loudness_range` 范围内随机\n",
372
+ " \"\"\"\n",
373
+ " waves = []\n",
374
+ " for stem_name in self.stem_names:\n",
375
+ " wave = self._load_random_wave(stem_name)\n",
376
+ "\n",
377
+ " mixup_waves = [wave]\n",
378
+ " for prob in self.same_stem_mixup_prob:\n",
379
+ " if random.uniform(0, 1) < prob:\n",
380
+ " wave2 = self._load_random_wave(stem_name)\n",
381
+ " mixup_waves.append(wave2)\n",
382
+ "\n",
383
+ " mixup_waves = np.stack(mixup_waves, axis=0)\n",
384
+ "\n",
385
+ " # 在 self.same_stem_mixup_loudness_range 范围内的随机响度\n",
386
+ " loudness = np.random.uniform(\n",
387
+ " low=loudness_db2linear(self.same_stem_mixup_loudness_range[0]),\n",
388
+ " high=loudness_db2linear(self.same_stem_mixup_loudness_range[1]),\n",
389
+ " size=(len(mixup_waves),),\n",
390
+ " )\n",
391
+ " mixup_waves *= loudness[:, None, None]\n",
392
+ " mixup_wave = mixup_waves.mean(axis=0)\n",
393
+ "\n",
394
+ " waves.append(mixup_wave)\n",
395
+ "\n",
396
+ " waves = np.stack(waves, axis=0)\n",
397
+ "\n",
398
+ " return waves\n",
399
+ "\n",
400
+ " def _apply_augment(self, wave, stem_name):\n",
401
+ " # Channel shuffle\n",
402
+ " if random.uniform(0, 1) < 0.5:\n",
403
+ " wave = wave[::-1].copy()\n",
404
+ "\n",
405
+ " # self.augments\n",
406
+ " wave = self.augments(samples=wave, sample_rate=self.sample_rate)\n",
407
+ "\n",
408
+ " return wave\n",
409
+ "\n",
410
+ " def __iter__(self):\n",
411
+ " while True:\n",
412
+ " waves = self._load_random_stems()\n",
413
+ "\n",
414
+ " # 随机分轨音量\n",
415
+ " loudnesses = np.random.uniform(\n",
416
+ " low=loudness_db2linear(-3),\n",
417
+ " high=loudness_db2linear(3),\n",
418
+ " size=(len(waves),),\n",
419
+ " )\n",
420
+ " # 各个 stem 有 10% 概率变为空音频\n",
421
+ " loudnesses *= (np.random.uniform(0, 1, size=(len(waves),)) > 0.1).astype(np.float32)\n",
422
+ " # 施加到 waves 上\n",
423
+ " waves *= loudnesses[:, None, None]\n",
424
+ "\n",
425
+ " # 获得混合音频\n",
426
+ " mixed_wave = waves.sum(0)\n",
427
+ "\n",
428
+ " yield waves, mixed_wave\n",
429
+ "\n",
430
+ " def __del__(self):\n",
431
+ " # Close any open SoundFile handles when dataset is destroyed\n",
432
+ " for handle in self.file_handles.values():\n",
433
+ " try:\n",
434
+ " handle.close()\n",
435
+ " except Exception:\n",
436
+ " pass\n",
437
+ "\n",
438
+ "\n",
439
+ "class ValidationDataset(Dataset):\n",
440
+ " \"\"\"\n",
441
+ " 用于 MUSDB18HQ 数据的、用于验证的 Dataset。返回完整音频。\n",
442
+ "\n",
443
+ " 期望的数据目录结构:\n",
444
+ "\n",
445
+ " dataset/\n",
446
+ " ├── A Classic Education - NightOwl\n",
447
+ " │ ├── bass.wav\n",
448
+ " │ ├── drums.wav\n",
449
+ " │ ├── mixture.wav\n",
450
+ " │ ├── other.wav\n",
451
+ " │ └── vocals.wav\n",
452
+ " ├── Actions - Devil's Words\n",
453
+ " │ ├── bass.wav\n",
454
+ " │ ├── drums.wav\n",
455
+ " │ ├── mixture.wav\n",
456
+ " │ ├── other.wav\n",
457
+ " │ └── vocals.wav\n",
458
+ " ···\n",
459
+ " \"\"\"\n",
460
+ "\n",
461
+ " def __init__(\n",
462
+ " self,\n",
463
+ " data_path,\n",
464
+ " sample_rate=44100,\n",
465
+ " stem_names=[\"bass\", \"drums\", \"other\", \"vocals\"],\n",
466
+ " ):\n",
467
+ " self.data_path = Path(data_path)\n",
468
+ " self.sample_rate = sample_rate\n",
469
+ " self.stem_names = stem_names\n",
470
+ "\n",
471
+ " self.metadata = self._get_metadata()\n",
472
+ "\n",
473
+ " def _get_metadata(self):\n",
474
+ " song_paths = [p for p in self.data_path.iterdir() if p.is_dir()]\n",
475
+ " # 读取缓存\n",
476
+ " cache_path = self.data_path / \"metadata.pkl\"\n",
477
+ " if cache_path.exists():\n",
478
+ " with open(cache_path, \"rb\") as f:\n",
479
+ " song_metadata = pickle.load(f)\n",
480
+ " cache_paths = [m[0] for m in song_metadata]\n",
481
+ " # 文件没有改动,直接使用缓存\n",
482
+ " if set(cache_paths) == set(song_paths):\n",
483
+ " return song_metadata\n",
484
+ "\n",
485
+ " # 构建缓存\n",
486
+ " song_metadata = []\n",
487
+ " for song_path in tqdm(song_paths, desc=\"Scanning dataset\"):\n",
488
+ " wave_files = [f for f in song_path.iterdir() if f.is_file() and f.stem in self.stem_names]\n",
489
+ "\n",
490
+ " lengths = []\n",
491
+ " for wave_file in wave_files:\n",
492
+ " data, samplerate = soundfile.read(wave_file)\n",
493
+ " assert samplerate == self.sample_rate, f\"Sample rate {samplerate} is not desired {self.sample_rate}\"\n",
494
+ " track_length = len(data)\n",
495
+ " lengths.append(track_length)\n",
496
+ " if len(set(lengths)) > 1:\n",
497
+ " print(f\"Warning: Inconsistent track lengths found in {song_path}. Using min length: {min(lengths)}\")\n",
498
+ "\n",
499
+ " stem_file_dict = {f.stem: f for f in wave_files}\n",
500
+ " song_metadata.append((song_path, min(lengths), stem_file_dict))\n",
501
+ "\n",
502
+ " # 保存缓存\n",
503
+ " with open(cache_path, \"wb\") as f:\n",
504
+ " pickle.dump(song_metadata, f)\n",
505
+ "\n",
506
+ " return song_metadata\n",
507
+ "\n",
508
+ " def __len__(self):\n",
509
+ " return len(self.metadata)\n",
510
+ "\n",
511
+ " def __getitem__(self, index):\n",
512
+ " song_path, length, stem_file_dict = self.metadata[index]\n",
513
+ "\n",
514
+ " waves = []\n",
515
+ " for stem_name in self.stem_names:\n",
516
+ " stem_file = stem_file_dict[stem_name]\n",
517
+ " wave = soundfile.read(\n",
518
+ " stem_file,\n",
519
+ " dtype=\"float32\",\n",
520
+ " )[0]\n",
521
+ " wave = wave.T\n",
522
+ " if len(wave.shape) == 1: # 对 mono 音频添加 channel 维度\n",
523
+ " wave = np.expand_dims(wave, axis=0)\n",
524
+ " waves.append(wave)\n",
525
+ "\n",
526
+ " waves = np.stack(waves, axis=0) # (stem, channel, time)\n",
527
+ "\n",
528
+ " # 获得混合音频\n",
529
+ " mixed_wave = waves.sum(0)\n",
530
+ "\n",
531
+ " return waves, mixed_wave"
532
+ ]
533
+ },
534
+ {
535
+ "cell_type": "markdown",
536
+ "id": "22caec1a",
537
+ "metadata": {},
538
+ "source": [
539
+ "### ModuleConfig 定义"
540
+ ]
541
+ },
542
+ {
543
+ "cell_type": "code",
544
+ "execution_count": null,
545
+ "id": "591a48cd",
546
+ "metadata": {},
547
+ "outputs": [],
548
+ "source": [
549
+ "DEFAULT_FREQS_PER_BANDS = (\n",
550
+ " 2,\n",
551
+ " 2,\n",
552
+ " 2,\n",
553
+ " 2,\n",
554
+ " 2,\n",
555
+ " 2,\n",
556
+ " 2,\n",
557
+ " 2,\n",
558
+ " 2,\n",
559
+ " 2,\n",
560
+ " 2,\n",
561
+ " 2,\n",
562
+ " 2,\n",
563
+ " 2,\n",
564
+ " 2,\n",
565
+ " 2,\n",
566
+ " 2,\n",
567
+ " 2,\n",
568
+ " 2,\n",
569
+ " 2,\n",
570
+ " 2,\n",
571
+ " 2,\n",
572
+ " 2,\n",
573
+ " 2,\n",
574
+ " 4,\n",
575
+ " 4,\n",
576
+ " 4,\n",
577
+ " 4,\n",
578
+ " 4,\n",
579
+ " 4,\n",
580
+ " 4,\n",
581
+ " 4,\n",
582
+ " 4,\n",
583
+ " 4,\n",
584
+ " 4,\n",
585
+ " 4,\n",
586
+ " 12,\n",
587
+ " 12,\n",
588
+ " 12,\n",
589
+ " 12,\n",
590
+ " 12,\n",
591
+ " 12,\n",
592
+ " 12,\n",
593
+ " 12,\n",
594
+ " 24,\n",
595
+ " 24,\n",
596
+ " 24,\n",
597
+ " 24,\n",
598
+ " 24,\n",
599
+ " 24,\n",
600
+ " 24,\n",
601
+ " 24,\n",
602
+ " 48,\n",
603
+ " 48,\n",
604
+ " 48,\n",
605
+ " 48,\n",
606
+ " 48,\n",
607
+ " 48,\n",
608
+ " 48,\n",
609
+ " 48,\n",
610
+ " 128,\n",
611
+ " 129,\n",
612
+ ")\n",
613
+ "\n",
614
+ "\n",
615
+ "class BSRoformerConfig(PretrainedConfig):\n",
616
+ "\n",
617
+ " model_type = \"bs_roformer\"\n",
618
+ "\n",
619
+ " def __init__(\n",
620
+ " self,\n",
621
+ " hidden_size=384,\n",
622
+ " depth=6,\n",
623
+ " num_input_channel=1,\n",
624
+ " num_stems=1,\n",
625
+ " time_transformer_depth=2,\n",
626
+ " freq_transformer_depth=2,\n",
627
+ " freqs_per_bands: tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,\n",
628
+ " attention_dropout=0.0,\n",
629
+ " num_attention_heads=8,\n",
630
+ " num_key_value_heads=8,\n",
631
+ " intermediate_size=384 * 4,\n",
632
+ " #\n",
633
+ " stft_n_fft=2048,\n",
634
+ " stft_hop_length=512,\n",
635
+ " stft_win_length=2048,\n",
636
+ " mask_estimator_depth=2,\n",
637
+ " multi_stft_loss_weight=1.0, # TODO 权重降低会发生什么\n",
638
+ " multi_stft_loss_window_sizes: tuple[int, ...] = (4096, 2048, 1024, 512, 256),\n",
639
+ " multi_stft_loss_hop_size=147,\n",
640
+ " rms_norm_eps=1e-6,\n",
641
+ " rope_theta=10000.0,\n",
642
+ " #\n",
643
+ " initializer_range=0.02,\n",
644
+ " register_token_num=4,\n",
645
+ " **kwargs,\n",
646
+ " ):\n",
647
+ " self.hidden_size = hidden_size\n",
648
+ " self.depth = depth\n",
649
+ " self.num_input_channel = num_input_channel\n",
650
+ " self.num_stems = num_stems\n",
651
+ " self.time_transformer_depth = time_transformer_depth\n",
652
+ " self.freq_transformer_depth = freq_transformer_depth\n",
653
+ " self.freqs_per_bands = freqs_per_bands\n",
654
+ " self.attention_dropout = attention_dropout\n",
655
+ " self.num_attention_heads = num_attention_heads\n",
656
+ " self.num_key_value_heads = num_key_value_heads\n",
657
+ " self.intermediate_size = intermediate_size\n",
658
+ "\n",
659
+ " self.stft_n_fft = stft_n_fft\n",
660
+ " self.stft_hop_length = stft_hop_length\n",
661
+ " self.stft_win_length = stft_win_length\n",
662
+ "\n",
663
+ " self.mask_estimator_depth = mask_estimator_depth\n",
664
+ " self.multi_stft_loss_weight = multi_stft_loss_weight\n",
665
+ " self.multi_stft_loss_window_sizes = multi_stft_loss_window_sizes\n",
666
+ " self.multi_stft_loss_hop_size = multi_stft_loss_hop_size\n",
667
+ " self.rms_norm_eps = rms_norm_eps\n",
668
+ " self.rope_theta = rope_theta\n",
669
+ "\n",
670
+ " self.initializer_range = initializer_range\n",
671
+ " self.register_token_num = register_token_num\n",
672
+ "\n",
673
+ " super().__init__(**kwargs)"
674
+ ]
675
+ },
676
+ {
677
+ "cell_type": "markdown",
678
+ "id": "ba4ce953",
679
+ "metadata": {},
680
+ "source": [
681
+ "### 模型定义"
682
+ ]
683
+ },
684
+ {
685
+ "cell_type": "code",
686
+ "execution_count": null,
687
+ "id": "48b33373",
688
+ "metadata": {},
689
+ "outputs": [],
690
+ "source": [
691
+ "# RoPE\n",
692
+ "class BSRoformerRotaryEmbedding(nn.Module):\n",
693
+ " def __init__(self, dim, theta=10000.0):\n",
694
+ " super().__init__()\n",
695
+ " inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))\n",
696
+ " self.register_buffer(\"inv_freq\", inv_freq)\n",
697
+ "\n",
698
+ " def forward(self, x, seq_len: int):\n",
699
+ " t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)\n",
700
+ " freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n",
701
+ " emb = torch.cat((freqs, freqs), dim=-1)\n",
702
+ " return emb.cos(), emb.sin()\n",
703
+ "\n",
704
+ "\n",
705
+ "def rotate_half(x):\n",
706
+ " x1 = x[..., : x.shape[-1] // 2]\n",
707
+ " x2 = x[..., x.shape[-1] // 2 :]\n",
708
+ " return torch.cat((-x2, x1), dim=-1)\n",
709
+ "\n",
710
+ "\n",
711
+ "def apply_rotary_pos_emb(q, k, cos, sin):\n",
712
+ " q_embed = (q * cos) + (rotate_half(q) * sin)\n",
713
+ " k_embed = (k * cos) + (rotate_half(k) * sin)\n",
714
+ " return q_embed, k_embed\n",
715
+ "\n",
716
+ "\n",
717
+ "class RotaryEmbedding(nn.Module):\n",
718
+ " def __init__(self, config: BSRoformerConfig):\n",
719
+ " super().__init__()\n",
720
+ " self.head_dim = config.hidden_size // config.num_attention_heads\n",
721
+ " inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))\n",
722
+ " self.register_buffer(\"inv_freq\", inv_freq)\n",
723
+ "\n",
724
+ " def forward(self, x, position_ids):\n",
725
+ " inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)\n",
726
+ " position_ids_expanded = position_ids[:, None, :].float()\n",
727
+ "\n",
728
+ " device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != \"mps\" else \"cpu\"\n",
729
+ " with torch.autocast(device_type=device_type, enabled=False): # Force float32\n",
730
+ " freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)\n",
731
+ " emb = torch.cat((freqs, freqs), dim=-1)\n",
732
+ " cos = emb.cos()\n",
733
+ " sin = emb.sin()\n",
734
+ "\n",
735
+ " return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)\n",
736
+ "\n",
737
+ "\n",
738
+ "# Attention\n",
739
+ "class BSRoformerMLP(nn.Module):\n",
740
+ " def __init__(self, config: BSRoformerConfig):\n",
741
+ " super().__init__()\n",
742
+ " self.config = config\n",
743
+ " self.hidden_size = config.hidden_size\n",
744
+ " self.intermediate_size = config.intermediate_size\n",
745
+ " self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n",
746
+ " self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n",
747
+ " self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)\n",
748
+ " self.act_fn = ACT2FN[\"gelu\"]\n",
749
+ "\n",
750
+ " def forward(self, x):\n",
751
+ " down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))\n",
752
+ " return down_proj\n",
753
+ "\n",
754
+ "\n",
755
+ "class BSRoformerAttention(nn.Module):\n",
756
+ " def __init__(self, config: BSRoformerConfig):\n",
757
+ " super().__init__()\n",
758
+ " self.is_causal = False\n",
759
+ " self.config = config\n",
760
+ "\n",
761
+ " self.head_dim = config.hidden_size // config.num_attention_heads\n",
762
+ " self.scaling = self.head_dim**-0.5\n",
763
+ " self.attention_dropout = config.attention_dropout\n",
764
+ "\n",
765
+ " self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads\n",
766
+ "\n",
767
+ " self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)\n",
768
+ " self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)\n",
769
+ " self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)\n",
770
+ " self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)\n",
771
+ "\n",
772
+ " def forward(\n",
773
+ " self,\n",
774
+ " hidden_states,\n",
775
+ " position_embeddings: tuple[torch.Tensor, torch.Tensor],\n",
776
+ " attention_mask=None,\n",
777
+ " ):\n",
778
+ " input_shape = hidden_states.size()[:-1]\n",
779
+ " hidden_shape = (*input_shape, -1, self.head_dim) # b, n, d -> b, n, -1, d'\n",
780
+ "\n",
781
+ " # proj\n",
782
+ " query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n",
783
+ " key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n",
784
+ " value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n",
785
+ "\n",
786
+ " # positional embeddings\n",
787
+ " cos, sin = position_embeddings\n",
788
+ " query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)\n",
789
+ "\n",
790
+ " # multi-group attention\n",
791
+ " # key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)\n",
792
+ " # value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)\n",
793
+ "\n",
794
+ " attention_interface = ALL_ATTENTION_FUNCTIONS[\"sdpa\"]\n",
795
+ "\n",
796
+ " attn_output, attn_weights = attention_interface(\n",
797
+ " self,\n",
798
+ " query_states,\n",
799
+ " key_states,\n",
800
+ " value_states,\n",
801
+ " attention_mask,\n",
802
+ " dropout=0.0 if not self.training else self.attention_dropout,\n",
803
+ " scaling=self.scaling,\n",
804
+ " )\n",
805
+ "\n",
806
+ " attn_output = attn_output.reshape(*input_shape, -1).contiguous()\n",
807
+ " attn_output = self.o_proj(attn_output)\n",
808
+ "\n",
809
+ " return attn_output, attn_weights\n",
810
+ "\n",
811
+ "\n",
812
+ "class BSRoformerLayer(nn.Module):\n",
813
+ " def __init__(self, config: BSRoformerConfig):\n",
814
+ " super().__init__()\n",
815
+ " self.self_attn = BSRoformerAttention(config)\n",
816
+ " self.mlp = BSRoformerMLP(config)\n",
817
+ "\n",
818
+ " self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n",
819
+ " self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n",
820
+ "\n",
821
+ " def forward(\n",
822
+ " self,\n",
823
+ " hidden_states,\n",
824
+ " position_embeddings,\n",
825
+ " attention_mask,\n",
826
+ " ):\n",
827
+ " # Self Attention\n",
828
+ " residual = hidden_states\n",
829
+ " hidden_states = self.input_layernorm(hidden_states)\n",
830
+ " hidden_states, _ = self.self_attn(\n",
831
+ " hidden_states,\n",
832
+ " position_embeddings,\n",
833
+ " attention_mask,\n",
834
+ " )\n",
835
+ " hidden_states = hidden_states + residual\n",
836
+ "\n",
837
+ " # Fully Connected\n",
838
+ " residual = hidden_states\n",
839
+ " hidden_states = self.post_attention_layernorm(hidden_states)\n",
840
+ " hidden_states = self.mlp(hidden_states)\n",
841
+ " hidden_states = hidden_states + residual\n",
842
+ "\n",
843
+ " return hidden_states\n",
844
+ "\n",
845
+ "\n",
846
+ "class BSRoformerAxialTransformer(nn.Module):\n",
847
+ " def __init__(\n",
848
+ " self,\n",
849
+ " config: BSRoformerConfig,\n",
850
+ " transformer_depth: int,\n",
851
+ " is_time_transformer: bool,\n",
852
+ " ):\n",
853
+ " super().__init__()\n",
854
+ " self.layers = nn.ModuleList([BSRoformerLayer(config) for _ in range(transformer_depth)])\n",
855
+ " self.is_time_transformer = is_time_transformer\n",
856
+ "\n",
857
+ " def forward(\n",
858
+ " self,\n",
859
+ " hidden_states,\n",
860
+ " position_embeddings,\n",
861
+ " attention_mask,\n",
862
+ " ):\n",
863
+ " if self.is_time_transformer:\n",
864
+ " hidden_states = rearrange(hidden_states, 'b t f d -> b f t d')\n",
865
+ "\n",
866
+ " # merge batch\n",
867
+ " b, seq_len_1, seq_len_2, d = hidden_states.shape\n",
868
+ " hidden_states = rearrange(hidden_states, 'b n m d -> (b n) m d')\n",
869
+ "\n",
870
+ " for layer in self.layers:\n",
871
+ " hidden_states = layer(\n",
872
+ " hidden_states,\n",
873
+ " position_embeddings,\n",
874
+ " attention_mask,\n",
875
+ " )\n",
876
+ "\n",
877
+ " # unpack batch\n",
878
+ " hidden_states = rearrange(hidden_states, '(b n) m d -> b n m d', b=b)\n",
879
+ "\n",
880
+ " if self.is_time_transformer:\n",
881
+ " hidden_states = rearrange(hidden_states, 'b f t d -> b t f d')\n",
882
+ "\n",
883
+ " return hidden_states\n",
884
+ "\n",
885
+ "\n",
886
+ "# BandSplit & MaskEstimator\n",
887
+ "class BandSplit(nn.Module):\n",
888
+ " def __init__(self, config: BSRoformerConfig):\n",
889
+ " super().__init__()\n",
890
+ " self.dim_inputs = tuple(2 * f * config.num_input_channel for f in config.freqs_per_bands)\n",
891
+ " self.to_features = nn.ModuleList(\n",
892
+ " [\n",
893
+ " nn.Sequential(nn.RMSNorm(dim_in, eps=config.rms_norm_eps), nn.Linear(dim_in, config.hidden_size))\n",
894
+ " for dim_in in self.dim_inputs\n",
895
+ " ]\n",
896
+ " )\n",
897
+ "\n",
898
+ " def forward(self, x):\n",
899
+ " x_split = x.split(self.dim_inputs, dim=-1)\n",
900
+ " outs = [to_feature(split_input) for split_input, to_feature in zip(x_split, self.to_features)]\n",
901
+ " return torch.stack(outs, dim=-2)\n",
902
+ "\n",
903
+ "\n",
904
+ "def MLP(dim_in, dim_out, dim_hidden, depth, activation=nn.Tanh):\n",
905
+ " net = []\n",
906
+ " dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)\n",
907
+ " for i, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):\n",
908
+ " net.append(nn.Linear(layer_dim_in, layer_dim_out))\n",
909
+ " if i < len(dims) - 2:\n",
910
+ " net.append(activation())\n",
911
+ " return nn.Sequential(*net)\n",
912
+ "\n",
913
+ "\n",
914
+ "class MaskEstimator(nn.Module):\n",
915
+ " def __init__(self, config: BSRoformerConfig):\n",
916
+ " super().__init__()\n",
917
+ "\n",
918
+ " class MiniGeGLU(nn.Module):\n",
919
+ "\n",
920
+ " def __init__(self, out_size):\n",
921
+ " super().__init__()\n",
922
+ " self.gate_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)\n",
923
+ " self.up_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)\n",
924
+ " self.down_proj = nn.Linear(config.hidden_size, out_size, bias=False)\n",
925
+ " self.act_fn = nn.GELU()\n",
926
+ "\n",
927
+ " def forward(self, x):\n",
928
+ " down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))\n",
929
+ " return down_proj\n",
930
+ "\n",
931
+ " dim_inputs = tuple(2 * f * config.num_input_channel for f in config.freqs_per_bands)\n",
932
+ " # self.to_freq_mlps = nn.ModuleList([MiniGeGLU(dim_in) for dim_in in dim_inputs])\n",
933
+ " self.to_freq_mlps = nn.ModuleList([nn.Linear(config.hidden_size, dim_in) for dim_in in dim_inputs])\n",
934
+ "\n",
935
+ " def forward(self, x):\n",
936
+ " x_unbind = x.unbind(dim=-2)\n",
937
+ " outs = [mlp(band_features) for band_features, mlp in zip(x_unbind, self.to_freq_mlps)]\n",
938
+ " return torch.cat(outs, dim=-1)\n",
939
+ "\n",
940
+ "\n",
941
+ "# Main Model\n",
942
+ "class BSRoformerPreTrainedModel(PreTrainedModel):\n",
943
+ " config_class = BSRoformerConfig\n",
944
+ " base_model_prefix = \"model\"\n",
945
+ " supports_gradient_checkpointing = True\n",
946
+ " _no_split_modules = [\"BSRoformerLayer\"]\n",
947
+ "\n",
948
+ "\n",
949
+ "class BSRoformerModel(BSRoformerPreTrainedModel):\n",
950
+ " def __init__(self, config: BSRoformerConfig):\n",
951
+ " super().__init__(config)\n",
952
+ " self.config = config\n",
953
+ " self.band_split = BandSplit(config)\n",
954
+ " self.layers = nn.ModuleList(\n",
955
+ " nn.ModuleList(\n",
956
+ " [\n",
957
+ " BSRoformerAxialTransformer(config, config.time_transformer_depth, is_time_transformer=True),\n",
958
+ " BSRoformerAxialTransformer(config, config.freq_transformer_depth, is_time_transformer=False),\n",
959
+ " ]\n",
960
+ " )\n",
961
+ " for _ in range(config.depth)\n",
962
+ " )\n",
963
+ " self.rotary_emb = RotaryEmbedding(config)\n",
964
+ " self.final_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n",
965
+ "\n",
966
+ " rn = config.register_token_num\n",
967
+ " self.register_tokens = nn.Parameter(torch.normal(0, 0.02, size=(rn, rn, config.hidden_size)))\n",
968
+ "\n",
969
+ " self.post_init()\n",
970
+ "\n",
971
+ " def forward(\n",
972
+ " self,\n",
973
+ " x,\n",
974
+ " position_ids=None,\n",
975
+ " ):\n",
976
+ " hidden_states = self.band_split(x)\n",
977
+ "\n",
978
+ " b, t, n, h = hidden_states.shape # [batch, t, n, hidden_size]\n",
979
+ "\n",
980
+ " if position_ids is None:\n",
981
+ " position_ids = torch.arange(t, device=hidden_states.device).unsqueeze(0)\n",
982
+ " pos_embeds = self.rotary_emb(hidden_states, position_ids)\n",
983
+ " pos_embeds_for_freq = self.rotary_emb(\n",
984
+ " hidden_states,\n",
985
+ " torch.arange(n, device=hidden_states.device).unsqueeze(0),\n",
986
+ " )\n",
987
+ "\n",
988
+ " # add register tokens\n",
989
+ " rn = self.config.register_token_num\n",
990
+ " hidden_states = F.pad(hidden_states, (0, 0, 0, rn, 0, rn))\n",
991
+ " hidden_states[:, t:, n:, :] = self.register_tokens\n",
992
+ "\n",
993
+ " def pad_rope(cos, sin):\n",
994
+ " cos_padded = F.pad(cos, (0, 0, 0, rn), value=1.0)\n",
995
+ " sin_padded = F.pad(sin, (0, 0, 0, rn), value=0.0)\n",
996
+ " return cos_padded, sin_padded\n",
997
+ "\n",
998
+ " pos_embeds = pad_rope(*pos_embeds)\n",
999
+ " pos_embeds_for_freq = pad_rope(*pos_embeds_for_freq)\n",
1000
+ "\n",
1001
+ " for time_transformer, freq_transformer in self.layers:\n",
1002
+ " hidden_states = time_transformer(\n",
1003
+ " hidden_states,\n",
1004
+ " position_embeddings=pos_embeds,\n",
1005
+ " attention_mask=None,\n",
1006
+ " )\n",
1007
+ " hidden_states = freq_transformer(\n",
1008
+ " hidden_states,\n",
1009
+ " position_embeddings=pos_embeds_for_freq,\n",
1010
+ " attention_mask=None,\n",
1011
+ " )\n",
1012
+ "\n",
1013
+ " hidden_states = hidden_states[:, :t, :n, :]\n",
1014
+ "\n",
1015
+ " return self.final_norm(hidden_states)\n",
1016
+ "\n",
1017
+ "\n",
1018
+ "class BSRoformerForMaskedEstimation(BSRoformerPreTrainedModel):\n",
1019
+ " def __init__(self, config: BSRoformerConfig):\n",
1020
+ " super().__init__(config)\n",
1021
+ " self.config = config\n",
1022
+ " self.model = BSRoformerModel(config)\n",
1023
+ " self.mask_estimators = nn.ModuleList([MaskEstimator(config) for _ in range(config.num_stems)])\n",
1024
+ "\n",
1025
+ " # STFT parameters\n",
1026
+ " self.stft_kwargs = dict(\n",
1027
+ " n_fft=config.stft_n_fft,\n",
1028
+ " hop_length=config.stft_hop_length,\n",
1029
+ " win_length=config.stft_win_length,\n",
1030
+ " normalized=False,\n",
1031
+ " )\n",
1032
+ " self.register_buffer(\"stft_window\", torch.hann_window(config.stft_win_length), persistent=False)\n",
1033
+ "\n",
1034
+ " freqs = config.stft_n_fft // 2 + 1\n",
1035
+ " assert sum(config.freqs_per_bands) == freqs, f\"Sum of freqs_per_bands must be {freqs}\"\n",
1036
+ " self.wave_channels = config.num_input_channel\n",
1037
+ "\n",
1038
+ " def forward(\n",
1039
+ " self,\n",
1040
+ " raw_audio: torch.Tensor,\n",
1041
+ " target: Optional[torch.Tensor] = None,\n",
1042
+ " return_loss_breakdown: bool = False,\n",
1043
+ " ):\n",
1044
+ " \"\"\"\n",
1045
+ " Args:\n",
1046
+ " raw_audio (`torch.Tensor` of shape `(batch, channels, time)`):\n",
1047
+ " The raw audio waveform.\n",
1048
+ " target (`torch.Tensor`, *optional*, shape `(batch, num_stems, channels, time)`):\n",
1049
+ " The target audio waveform for loss calculation.\n",
1050
+ " return_loss_breakdown (`bool`, *optional*, defaults to `False`):\n",
1051
+ " Whether to return the breakdown of the loss components. Currently unused: this method only ever returns the single L1 loss.\n",
1052
+ "\n",
1053
+ " Returns:\n",
1054
+ " torch.Tensor (`torch.Tensor` of shape `(batch, num_stems, channels, time)`):\n",
1055
+ " The reconstructed audio waveform when `target` is None; otherwise the scalar L1 loss between the reconstruction and `target`.\n",
1056
+ " \"\"\"\n",
1057
+ " device = raw_audio.device\n",
1058
+ "\n",
1059
+ " # 1. STFT: Convert audio to spectrogram\n",
1060
+ " with torch.autocast(device_type=device.type, enabled=False):\n",
1061
+ " b, c, t = raw_audio.shape # batch, channel, time\n",
1062
+ " raw_audio_packed = rearrange(raw_audio, \"b c t -> (b c) t\")\n",
1063
+ " stft_repr = torch.stft(\n",
1064
+ " raw_audio_packed,\n",
1065
+ " **self.stft_kwargs,\n",
1066
+ " window=self.stft_window,\n",
1067
+ " return_complex=True,\n",
1068
+ " )\n",
1069
+ " stft_repr = torch.view_as_real(stft_repr) # ((b c), f, t) complex -> ((b c), f, t, 2) real\n",
1070
+ " stft_repr = rearrange(stft_repr, \"(b c) f t T -> b c f t T\", c=c)\n",
1071
+ " # Merge frequency, channel, and complex dimensions for the model\n",
1072
+ " stft_repr_merged = rearrange(stft_repr, \"b c f t T -> b t (f c T)\")\n",
1073
+ "\n",
1074
+ " # 2. Model Processing\n",
1075
+ " hidden_states = self.model(stft_repr_merged)\n",
1076
+ "\n",
1077
+ " # 3. Mask Estimation\n",
1078
+ " # (b, t, bands, d) -> (b, n, t, (f c 2)) where n is num_stems\n",
1079
+ " mask = torch.stack([fn(hidden_states) for fn in self.mask_estimators], dim=1)\n",
1080
+ " mask = rearrange(mask, \"b n t (f c T) -> b n c f t T\", T=2, c=c)\n",
1081
+ " mask = mask.to(dtype=torch.float32)\n",
1082
+ "\n",
1083
+ " # 4. Mask Application\n",
1084
+ " with torch.autocast(device_type=device.type, enabled=False):\n",
1085
+ " stft_repr_expanded = rearrange(stft_repr, \"b c f t T -> b 1 c f t T\")\n",
1086
+ " stft_repr_complex = torch.view_as_complex(stft_repr_expanded)\n",
1087
+ " mask_complex = torch.view_as_complex(mask)\n",
1088
+ " masked_stft = stft_repr_complex * mask_complex\n",
1089
+ "\n",
1090
+ " # 5. iSTFT: Convert masked spectrogram back to audio\n",
1091
+ " # (b, n, c, f, t) -> ((b n c), f, t)\n",
1092
+ " masked_stft = rearrange(masked_stft, \"b n c f t -> (b n c) f t\")\n",
1093
+ " recon_audio = torch.istft(\n",
1094
+ " masked_stft,\n",
1095
+ " **self.stft_kwargs,\n",
1096
+ " window=self.stft_window,\n",
1097
+ " return_complex=False,\n",
1098
+ " length=raw_audio.shape[-1],\n",
1099
+ " )\n",
1100
+ " # ((b n c), t) -> (b, n, c, t)\n",
1101
+ " recon_audio = rearrange(recon_audio, \"(b n c) t -> b n c t\", c=self.wave_channels, n=self.config.num_stems)\n",
1102
+ "\n",
1103
+ " if target is None:\n",
1104
+ " return recon_audio\n",
1105
+ "\n",
1106
+ " # 6. Loss Calculation\n",
1107
+ " # Ensure target has the same length as the reconstructed audio\n",
1108
+ " target = target[..., : recon_audio.shape[-1]]\n",
1109
+ "\n",
1110
+ " loss = F.l1_loss(recon_audio, target)\n",
1111
+ "\n",
1112
+ " return loss\n"
1113
+ ]
1114
+ },
1115
+ {
1116
+ "cell_type": "code",
1117
+ "execution_count": null,
1118
+ "id": "f0f2b263",
1119
+ "metadata": {},
1120
+ "outputs": [],
1121
+ "source": [
1122
+ "# model_config = BSRoformerConfig(\n",
1123
+ "# hidden_size=64,\n",
1124
+ "# depth=1,\n",
1125
+ "# num_input_channel=2,\n",
1126
+ "# num_stems=4,\n",
1127
+ "# intermediate_size=64 * 2,\n",
1128
+ "# time_transformer_depth=1,\n",
1129
+ "# freq_transformer_depth=1,\n",
1130
+ "# num_attention_heads=8,\n",
1131
+ "# num_key_value_heads=2,\n",
1132
+ "# #\n",
1133
+ "# mask_estimator_depth=1,\n",
1134
+ "# )\n",
1135
+ "# model = BSRoformerForMaskedEstimation(model_config)\n",
1136
+ "\n",
1137
+ "# dummy_input = torch.randn(6, 2, 44100 * 6)\n",
1138
+ "# output = model(dummy_input)\n",
1139
+ "\n",
1140
+ "# dummy_targets = torch.randn(6, 4, 2, 44100 * 6)\n",
1141
+ "# loss = model(dummy_input, target=dummy_targets)\n",
1142
+ "\n",
1143
+ "# dummy_song = torch.randn(2, 44100 * 30)\n",
1144
+ "# result = inference_one_with_model(\n",
1145
+ "# model,\n",
1146
+ "# dummy_song,\n",
1147
+ "# chunk_size=44100 * 6,\n",
1148
+ "# overlap_size=44100 * 3,\n",
1149
+ "# gap_size=44100 * 1,\n",
1150
+ "# )\n",
1151
+ "\n",
1152
+ "# del model, model_config, dummy_input, output, dummy_targets, loss"
1153
+ ]
1154
+ },
1155
+ {
1156
+ "cell_type": "markdown",
1157
+ "id": "9d26ff61",
1158
+ "metadata": {},
1159
+ "source": [
1160
+ "## 实例化 Datasets"
1161
+ ]
1162
+ },
1163
+ {
1164
+ "cell_type": "code",
1165
+ "execution_count": null,
1166
+ "id": "f4d791c0",
1167
+ "metadata": {},
1168
+ "outputs": [],
1169
+ "source": [
1170
+ "train_dataset = AugmentDataset(\n",
1171
+ " data_path=[\n",
1172
+ " \"/mnt/sda/data/20250826_MUSDB18HQ/train\",\n",
1173
+ " \"/mnt/sda/data/20250826_MUSDB18HQ/test\",\n",
1174
+ " # \"/mnt/sda/data/20250902_DSD100/datas\",\n",
1175
+ " ],\n",
1176
+ " wave_chunk_size=44100 * 6,\n",
1177
+ " stem_names=[\"bass\", \"drums\", \"other\", \"vocals\"],\n",
1178
+ ")\n",
1179
+ "val_dataset = ValidationDataset(\n",
1180
+ " data_path=\"/mnt/sda/data/20250826_MUSDB18HQ/valid\",\n",
1181
+ " stem_names=[\"bass\", \"drums\", \"other\", \"vocals\"],\n",
1182
+ ")\n",
1183
+ "\n",
1184
+ "train_loader = DataLoader(\n",
1185
+ " train_dataset,\n",
1186
+ " batch_size=18,\n",
1187
+ " num_workers=num_workers,\n",
1188
+ " pin_memory=True,\n",
1189
+ " persistent_workers=True if num_workers > 0 else False,\n",
1190
+ " prefetch_factor=4 if num_workers > 0 else None,\n",
1191
+ ")\n",
1192
+ "val_loader = DataLoader(\n",
1193
+ " val_dataset,\n",
1194
+ " batch_size=1,\n",
1195
+ " num_workers=num_workers,\n",
1196
+ " pin_memory=True,\n",
1197
+ " persistent_workers=True if num_workers > 0 else False,\n",
1198
+ " shuffle=False,\n",
1199
+ " prefetch_factor=4 if num_workers > 0 else None,\n",
1200
+ ")"
1201
+ ]
1202
+ },
1203
+ {
1204
+ "cell_type": "markdown",
1205
+ "id": "21701211",
1206
+ "metadata": {},
1207
+ "source": [
1208
+ "## Lightning"
1209
+ ]
1210
+ },
1211
+ {
1212
+ "cell_type": "code",
1213
+ "execution_count": null,
1214
+ "id": "23cda886",
1215
+ "metadata": {},
1216
+ "outputs": [],
1217
+ "source": [
1218
+ "def compute_sdr(target, estimate):\n",
1219
+ " target_np = target.float().cpu().numpy()\n",
1220
+ " estimate_np = estimate.float().cpu().numpy()\n",
1221
+ "\n",
1222
+ " sdr_list = []\n",
1223
+ "\n",
1224
+ " for this_target, this_estimate in zip(target_np, estimate_np):\n",
1225
+ " channel_sdrs = []\n",
1226
+ " for this_channel_target, this_channel_estimate in zip(this_target, this_estimate):\n",
1227
+ " signal_power = np.sum(this_channel_target ** 2)\n",
1228
+ " noise_power = np.sum((this_channel_target - this_channel_estimate) ** 2)\n",
1229
+ "\n",
1230
+ " if noise_power == 0:\n",
1231
+ " sdr = float('inf')\n",
1232
+ " else:\n",
1233
+ " sdr = 10 * np.log10(signal_power / noise_power)\n",
1234
+ "\n",
1235
+ " # sdr_list.append(sdr)\n",
1236
+ " channel_sdrs.append(sdr)\n",
1237
+ "\n",
1238
+ " channel_sdr_mean = np.mean(channel_sdrs)\n",
1239
+ " sdr_list.append(channel_sdr_mean)\n",
1240
+ "\n",
1241
+ " return sdr_list\n"
1242
+ ]
1243
+ },
1244
+ {
1245
+ "cell_type": "code",
1246
+ "execution_count": null,
1247
+ "id": "2e5002b1",
1248
+ "metadata": {},
1249
+ "outputs": [],
1250
+ "source": [
1251
+ "class LightningModel(BaseModule):\n",
1252
+ "\n",
1253
+ " def __init__(\n",
1254
+ " self,\n",
1255
+ " model,\n",
1256
+ " lr_config: LearningRateConfig,\n",
1257
+ " training_config: TrainingConfig,\n",
1258
+ " ):\n",
1259
+ " super().__init__(\n",
1260
+ " model,\n",
1261
+ " lr_config,\n",
1262
+ " training_config,\n",
1263
+ " )\n",
1264
+ "\n",
1265
+ " self.validation_sdr_results = []\n",
1266
+ "\n",
1267
+ " def forward(self, x):\n",
1268
+ " return self.model(x)\n",
1269
+ "\n",
1270
+ " def training_step(self, batch, batch_idx):\n",
1271
+ " target_stems, mixed_audio = batch\n",
1272
+ " # target_stems: (batch, stems, channels, time)\n",
1273
+ " # mixed_audio: (batch, channels, time)\n",
1274
+ "\n",
1275
+ " loss = self.model(mixed_audio, target=target_stems)\n",
1276
+ "\n",
1277
+ " grad_norm = torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=4.0)\n",
1278
+ "\n",
1279
+ " self.log('train/grad_norm', grad_norm.item(), on_step=True, on_epoch=False, sync_dist=True)\n",
1280
+ " self.log('train/loss', loss, on_step=True, on_epoch=False, sync_dist=True)\n",
1281
+ "\n",
1282
+ " return loss\n",
1283
+ "\n",
1284
+ " def validation_step(self, batch, batch_idx):\n",
1285
+ " target_stems, mixed_audio = batch\n",
1286
+ "\n",
1287
+ " batch_size = mixed_audio.shape[0]\n",
1288
+ " batch_sdr_scores = []\n",
1289
+ "\n",
1290
+ " for i in range(batch_size):\n",
1291
+ " single_mixed = mixed_audio[i] # (channels, time)\n",
1292
+ " single_target = target_stems[i] # (stems, channels, time)\n",
1293
+ "\n",
1294
+ " with torch.no_grad():\n",
1295
+ " predicted_stems = inference_one_with_model(\n",
1296
+ " self.model,\n",
1297
+ " single_mixed,\n",
1298
+ " chunk_size=44100 * 6,\n",
1299
+ " overlap_size=44100 * 3,\n",
1300
+ " gap_size=0,\n",
1301
+ " ) # (stems, channels, time)\n",
1302
+ "\n",
1303
+ " sdr = compute_sdr(single_target, predicted_stems)\n",
1304
+ " batch_sdr_scores.append(sdr)\n",
1305
+ "\n",
1306
+ " sdrs = np.array(batch_sdr_scores)\n",
1307
+ " sdrs = sdrs.mean(axis=0)\n",
1308
+ "\n",
1309
+ " self.validation_sdr_results.append(sdrs)\n",
1310
+ "\n",
1311
+ " return {\n",
1312
+ " \"val/sdr\": sdrs,\n",
1313
+ " }\n",
1314
+ "\n",
1315
+ " def on_validation_epoch_end(self):\n",
1316
+ " if len(self.validation_sdr_results) > 0:\n",
1317
+ " avg_sdrs = np.mean(self.validation_sdr_results, axis=0)\n",
1318
+ " self.log('val/sdr', avg_sdrs.mean(), on_step=False, on_epoch=True, sync_dist=True, prog_bar=True)\n",
1319
+ " for i, one in enumerate(avg_sdrs):\n",
1320
+ " self.log(f'val/sdr_stem_{i}', one, on_step=False, on_epoch=True, sync_dist=True)\n",
1321
+ "\n",
1322
+ " self.validation_sdr_results.clear()"
1323
+ ]
1324
+ },
1325
+ {
1326
+ "cell_type": "markdown",
1327
+ "id": "e13d2d53",
1328
+ "metadata": {},
1329
+ "source": [
1330
+ "## 配置与实例化"
1331
+ ]
1332
+ },
1333
+ {
1334
+ "cell_type": "code",
1335
+ "execution_count": null,
1336
+ "id": "c31b1d32",
1337
+ "metadata": {},
1338
+ "outputs": [],
1339
+ "source": [
1340
+ "from pl_utils import LearningRateConfig, TrainingConfig\n",
1341
+ "\n",
1342
+ "\n",
1343
+ "learning_rate_config = LearningRateConfig(\n",
1344
+ " lr_warmup_steps=400,\n",
1345
+ " lr_initial=1e-5,\n",
1346
+ " lr_max=5e-4,\n",
1347
+ " lr_end=5e-4,\n",
1348
+ " max_steps=20000,\n",
1349
+ ")\n",
1350
+ "\n",
1351
+ "training_config = TrainingConfig(\n",
1352
+ " optimizer='adamw',\n",
1353
+ " optimizer_args={\n",
1354
+ " 'betas': (0.9, 0.95),\n",
1355
+ " 'weight_decay': 1e-2,\n",
1356
+ " \"fused\": True,\n",
1357
+ " },\n",
1358
+ " excluded_from_weight_decay=[\"bias\", \"norm\", \"embed\", \"scale\"],\n",
1359
+ ")"
1360
+ ]
1361
+ },
1362
+ {
1363
+ "cell_type": "code",
1364
+ "execution_count": null,
1365
+ "id": "13030935",
1366
+ "metadata": {},
1367
+ "outputs": [],
1368
+ "source": [
1369
+ "model_config = BSRoformerConfig(\n",
1370
+ " hidden_size=256,\n",
1371
+ " depth=3,\n",
1372
+ " num_input_channel=2,\n",
1373
+ " num_stems=4,\n",
1374
+ " intermediate_size=256 * 2,\n",
1375
+ " time_transformer_depth=1,\n",
1376
+ " freq_transformer_depth=1,\n",
1377
+ " num_attention_heads=8,\n",
1378
+ " num_key_value_heads=4,\n",
1379
+ " #\n",
1380
+ " mask_estimator_depth=1,\n",
1381
+ " multi_stft_loss_weight=0.0,\n",
1382
+ ")\n",
1383
+ "model = BSRoformerForMaskedEstimation(model_config)\n",
1384
+ "\n",
1385
+ "pl_model = LightningModel(\n",
1386
+ " model,\n",
1387
+ " lr_config=learning_rate_config,\n",
1388
+ " training_config=training_config,\n",
1389
+ ")"
1390
+ ]
1391
+ },
1392
+ {
1393
+ "cell_type": "markdown",
1394
+ "id": "9a430e4f",
1395
+ "metadata": {},
1396
+ "source": [
1397
+ "## 正式训练"
1398
+ ]
1399
+ },
1400
+ {
1401
+ "cell_type": "code",
1402
+ "execution_count": null,
1403
+ "id": "d52e16e9",
1404
+ "metadata": {},
1405
+ "outputs": [],
1406
+ "source": [
1407
+ "from lightning.pytorch.utilities.model_summary import summarize\n",
1408
+ "\n",
1409
+ "summarize(pl_model, max_depth=2)\n",
1410
+ "\n",
1411
+ "model.model.compile(options={\"shape_padding\": True})"
1412
+ ]
1413
+ },
1414
+ {
1415
+ "cell_type": "code",
1416
+ "execution_count": null,
1417
+ "id": "6ff9af11",
1418
+ "metadata": {},
1419
+ "outputs": [],
1420
+ "source": [
1421
+ "import lightning.pytorch as L\n",
1422
+ "from lightning.pytorch.callbacks import ModelCheckpoint\n",
1423
+ "from lightning.pytorch.loggers import TensorBoardLogger\n",
1424
+ "from pl_utils.lightning import format_next_version_name\n",
1425
+ "from lightning.pytorch.strategies import DDPStrategy\n",
1426
+ "\n",
1427
+ "name = \"准备收尾。3层小模型,batch18\"\n",
1428
+ "logger = TensorBoardLogger(save_dir=\"./\", version=format_next_version_name(name))\n",
1429
+ "\n",
1430
+ "checkpoint_callback = ModelCheckpoint(\n",
1431
+ " auto_insert_metric_name=True,\n",
1432
+ " save_top_k=1,\n",
1433
+ " monitor=\"val/sdr\",\n",
1434
+ " mode=\"max\",\n",
1435
+ " every_n_epochs=1,\n",
1436
+ " save_weights_only=True,\n",
1437
+ " # save_last=\"link\",\n",
1438
+ " save_on_train_epoch_end=False,\n",
1439
+ " save_last=True,\n",
1440
+ ")\n",
1441
+ "\n",
1442
+ "trainer = L.Trainer(\n",
1443
+ " logger=logger,\n",
1444
+ " accelerator='gpu',\n",
1445
+ " # max_epochs=16,\n",
1446
+ " strategy=DDPStrategy(find_unused_parameters=False),\n",
1447
+ " precision='16-mixed',\n",
1448
+ " # accumulate_grad_batches=4,\n",
1449
+ " max_steps=200000,\n",
1450
+ " val_check_interval=500,\n",
1451
+ " log_every_n_steps=4,\n",
1452
+ " default_root_dir=\"./\",\n",
1453
+ " #\n",
1454
+ " callbacks=[checkpoint_callback],\n",
1455
+ " # enable_checkpointing=False,\n",
1456
+ " #\n",
1457
+ " num_sanity_val_steps=0,\n",
1458
+ " # fast_dev_run=True,\n",
1459
+ " # enable_checkpointing=False,\n",
1460
+ " enable_model_summary=True,\n",
1461
+ ")\n",
1462
+ "\n",
1463
+ "trainer.fit(pl_model, train_loader, val_loader)"
1464
+ ]
1465
+ },
1466
+ {
1467
+ "cell_type": "markdown",
1468
+ "id": "f9304c3e",
1469
+ "metadata": {},
1470
+ "source": [
1471
+ "## 提前退出"
1472
+ ]
1473
+ },
1474
+ {
1475
+ "cell_type": "code",
1476
+ "execution_count": null,
1477
+ "id": "5e11d871",
1478
+ "metadata": {},
1479
+ "outputs": [],
1480
+ "source": [
1481
+ "import sys\n",
1482
+ "from IPython import get_ipython\n",
1483
+ "\n",
1484
+ "\n",
1485
+ "# 如果是脚本而不是jupyter notebook,此时就该退出了\n",
1486
+ "try:\n",
1487
+ " shell = get_ipython()\n",
1488
+ " if shell is None:\n",
1489
+ " sys.exit()\n",
1490
+ "except:\n",
1491
+ " sys.exit()"
1492
+ ]
1493
+ },
1494
+ {
1495
+ "cell_type": "markdown",
1496
+ "id": "8bca044c",
1497
+ "metadata": {},
1498
+ "source": [
1499
+ "## 加载与推理"
1500
+ ]
1501
+ },
1502
+ {
1503
+ "cell_type": "code",
1504
+ "execution_count": null,
1505
+ "id": "2f65245d",
1506
+ "metadata": {},
1507
+ "outputs": [],
1508
+ "source": [
1509
+ "pl_model = LightningModel.load_from_checkpoint(\n",
1510
+ " \"lightning_logs/version_029_可学习残差(策略为共享一个参数)/checkpoints/last.ckpt\",\n",
1511
+ " model=model,\n",
1512
+ ")"
1513
+ ]
1514
+ },
1515
+ {
1516
+ "cell_type": "code",
1517
+ "execution_count": null,
1518
+ "id": "f865e1f5",
1519
+ "metadata": {},
1520
+ "outputs": [],
1521
+ "source": [
1522
+ "waves, mixed_wave = val_dataset[0]"
1523
+ ]
1524
+ },
1525
+ {
1526
+ "cell_type": "code",
1527
+ "execution_count": null,
1528
+ "id": "badd73dd",
1529
+ "metadata": {},
1530
+ "outputs": [],
1531
+ "source": [
1532
+ "with torch.inference_mode():\n",
1533
+ " predicted_stems = inference_one_with_model(\n",
1534
+ " pl_model.model,\n",
1535
+ " torch.tensor(mixed_wave).to(\"cuda\"),\n",
1536
+ " chunk_size=44100 * 6,\n",
1537
+ " overlap_size=44100 * 3,\n",
1538
+ " ) # (stems, channels, time)"
1539
+ ]
1540
+ },
1541
+ {
1542
+ "cell_type": "code",
1543
+ "execution_count": null,
1544
+ "id": "a1fa8bbd",
1545
+ "metadata": {},
1546
+ "outputs": [],
1547
+ "source": [
1548
+ "predicted_stems.shape"
1549
+ ]
1550
+ },
1551
+ {
1552
+ "cell_type": "code",
1553
+ "execution_count": null,
1554
+ "id": "751c4974",
1555
+ "metadata": {},
1556
+ "outputs": [],
1557
+ "source": [
1558
+ "os.makedirs(\"./outputs\", exist_ok=True)\n",
1559
+ "\n",
1560
+ "for i in range(predicted_stems.shape[0]):\n",
1561
+ " import soundfile as sf\n",
1562
+ "\n",
1563
+ " sf.write(f\"./outputs/predicted_stem_{i}.wav\", predicted_stems[i].cpu().numpy().T, 44100)"
1564
+ ]
1565
+ },
1566
+ {
1567
+ "cell_type": "code",
1568
+ "execution_count": null,
1569
+ "id": "d934415c",
1570
+ "metadata": {},
1571
+ "outputs": [],
1572
+ "source": [
1573
+ "sf.write(\"./outputs/mixed.wav\", mixed_wave.T, 44100)"
1574
+ ]
1575
+ }
1576
+ ],
1577
+ "metadata": {
1578
+ "kernelspec": {
1579
+ "display_name": "20250820_bs-roformer",
1580
+ "language": "python",
1581
+ "name": "python3"
1582
+ },
1583
+ "language_info": {
1584
+ "codemirror_mode": {
1585
+ "name": "ipython",
1586
+ "version": 3
1587
+ },
1588
+ "file_extension": ".py",
1589
+ "mimetype": "text/x-python",
1590
+ "name": "python",
1591
+ "nbconvert_exporter": "python",
1592
+ "pygments_lexer": "ipython3",
1593
+ "version": "3.13.5"
1594
+ }
1595
+ },
1596
+ "nbformat": 4,
1597
+ "nbformat_minor": 5
1598
+ }