tekitoutarou commited on
Commit
f73ae00
·
verified ·
1 Parent(s): 73e70d7

Upload 12 files

Browse files

here is everything. Hope this leads to Good results.

BS_Base_Model.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d0363fdc84906eb52c092b842c6dc1b231065d927604b35b6da6cbc1c38c28a6
3
+ size 1102136494
BS_Base_Model.yaml ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ audio:
2
+ chunk_size: 588800
3
+ dim_f: 1024
4
+ dim_t: 801
5
+ hop_length: 441
6
+ min_mean_abs: 0.0
7
+ n_fft: 2048
8
+ num_channels: 2
9
+ sample_rate: 44100
10
+ augmentations:
11
+ all:
12
+ channel_shuffle: 0.5
13
+ random_inverse: 0.1
14
+ random_polarity: 0.5
15
+ bass:
16
+ pitch_shift: 0.1
17
+ pitch_shift_max_semitones: 2
18
+ pitch_shift_min_semitones: -2
19
+ seven_band_parametric_eq: 0.1
20
+ seven_band_parametric_eq_max_gain_db: 6
21
+ seven_band_parametric_eq_min_gain_db: -3
22
+ tanh_distortion: 0.1
23
+ tanh_distortion_max: 0.5
24
+ tanh_distortion_min: 0.1
25
+ drums:
26
+ pitch_shift: 0.1
27
+ pitch_shift_max_semitones: 5
28
+ pitch_shift_min_semitones: -5
29
+ seven_band_parametric_eq: 0.1
30
+ seven_band_parametric_eq_max_gain_db: 9
31
+ seven_band_parametric_eq_min_gain_db: -9
32
+ tanh_distortion: 0.1
33
+ tanh_distortion_max: 0.6
34
+ tanh_distortion_min: 0.1
35
+ enable: true
36
+ loudness: true
37
+ loudness_max: 1.5
38
+ loudness_min: 0.5
39
+ mixup: true
40
+ mixup_loudness_max: 1.5
41
+ mixup_loudness_min: 0.5
42
+ mixup_probs: !!python/tuple
43
+ - 0.2
44
+ - 0.02
45
+ other:
46
+ gaussian_noise: 0.1
47
+ gaussian_noise_max_amplitude: 0.015
48
+ gaussian_noise_min_amplitude: 0.001
49
+ pitch_shift: 0.1
50
+ pitch_shift_max_semitones: 4
51
+ pitch_shift_min_semitones: -4
52
+ time_stretch: 0.1
53
+ time_stretch_max_rate: 1.25
54
+ time_stretch_min_rate: 0.8
55
+ vocals:
56
+ pitch_shift: 0.1
57
+ pitch_shift_max_semitones: 5
58
+ pitch_shift_min_semitones: -5
59
+ seven_band_parametric_eq: 0.1
60
+ seven_band_parametric_eq_max_gain_db: 9
61
+ seven_band_parametric_eq_min_gain_db: -9
62
+ tanh_distortion: 0.1
63
+ tanh_distortion_max: 0.7
64
+ tanh_distortion_min: 0.1
65
+ inference:
66
+ batch_size: 1
67
+ dim_t: 1101
68
+ normalize: false
69
+ num_overlap: 2
70
+ model:
71
+ attn_dropout: 0.1
72
+ depth: 12
73
+ dim: 256
74
+ dim_freqs_in: 1025
75
+ dim_head: 64
76
+ ff_dropout: 0.1
77
+ flash_attn: false
78
+ freq_transformer_depth: 1
79
+ freqs_per_bands:
80
+ - 2
81
+ - 2
82
+ - 2
83
+ - 2
84
+ - 2
85
+ - 2
86
+ - 2
87
+ - 2
88
+ - 2
89
+ - 2
90
+ - 2
91
+ - 2
92
+ - 2
93
+ - 2
94
+ - 2
95
+ - 2
96
+ - 2
97
+ - 2
98
+ - 2
99
+ - 2
100
+ - 2
101
+ - 2
102
+ - 2
103
+ - 2
104
+ - 4
105
+ - 4
106
+ - 4
107
+ - 4
108
+ - 4
109
+ - 4
110
+ - 4
111
+ - 4
112
+ - 4
113
+ - 4
114
+ - 4
115
+ - 4
116
+ - 12
117
+ - 12
118
+ - 12
119
+ - 12
120
+ - 12
121
+ - 12
122
+ - 12
123
+ - 12
124
+ - 24
125
+ - 24
126
+ - 24
127
+ - 24
128
+ - 24
129
+ - 24
130
+ - 24
131
+ - 24
132
+ - 48
133
+ - 48
134
+ - 48
135
+ - 48
136
+ - 48
137
+ - 48
138
+ - 48
139
+ - 48
140
+ - 128
141
+ - 129
142
+ heads: 8
143
+ kan_grid_size: 8
144
+ linear_transformer_depth: 0
145
+ mask_estimator_depth: 2
146
+ mlp_expansion_factor: 4
147
+ multi_stft_hop_size: 147
148
+ multi_stft_normalized: false
149
+ multi_stft_resolution_loss_weight: 1.0
150
+ multi_stft_resolutions_window_sizes:
151
+ - 4096
152
+ - 2048
153
+ - 1024
154
+ - 512
155
+ - 256
156
+ num_stems: 6
157
+ sage_attention: false
158
+ skip_connection: false
159
+ stereo: true
160
+ stft_hop_length: 512
161
+ stft_n_fft: 2048
162
+ stft_normalized: false
163
+ stft_win_length: 2048
164
+ time_transformer_depth: 1
165
+ use_kan: true
166
+ use_torch_checkpoint: false
167
+ training:
168
+ augmentation: false
169
+ augmentation_loudness: true
170
+ augmentation_loudness_max: 1.5
171
+ augmentation_loudness_min: 0.5
172
+ augmentation_loudness_type: 1
173
+ augmentation_mix: true
174
+ augmentation_type: simple1
175
+ batch_size: 2
176
+ coarse_loss_clip: true
177
+ ema_momentum: 0.999
178
+ grad_clip: 0
179
+ gradient_accumulation_steps: 1
180
+ instruments:
181
+ - bass
182
+ - drums
183
+ - other
184
+ - vocals
185
+ - guitar
186
+ - piano
187
+ lr: 1.0e-05
188
+ num_epochs: 1000
189
+ num_steps: 1000
190
+ optimizer: adam
191
+ other_fix: false
192
+ patience: 3
193
+ q: 0.95
194
+ reduce_factor: 0.95
195
+ target_instrument: null
196
+ use_amp: true
197
+ use_mp3_compress: false
MelBand Base Model.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b2a9652c40d90519a5708898b8c32b8f90666e1f8ef95890f91cced72dc22ac8
3
+ size 1366088139
MelBand Base Model.yaml ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ audio:
2
+ chunk_size: 352800
3
+ dim_f: 1024
4
+ dim_t: 256
5
+ hop_length: 441
6
+ min_mean_abs: 0
7
+ n_fft: 2048
8
+ num_channels: 2
9
+ sample_rate: 44100
10
+ inference:
11
+ batch_size: 2
12
+ dim_t: 256
13
+ num_overlap: 4
14
+ model:
15
+ attn_dropout: 0
16
+ depth: 6
17
+ dim: 384
18
+ dim_freqs_in: 1025
19
+ dim_head: 64
20
+ ff_dropout: 0
21
+ flash_attn: false
22
+ freq_transformer_depth: 1
23
+ heads: 8
24
+ kan_grid_size: 8
25
+ mask_estimator_depth: 2
26
+ multi_stft_hop_size: 147
27
+ multi_stft_normalized: false
28
+ multi_stft_resolution_loss_weight: 1.0
29
+ multi_stft_resolutions_window_sizes:
30
+ - 4096
31
+ - 2048
32
+ - 1024
33
+ - 512
34
+ - 256
35
+ num_bands: 60
36
+ num_stems: 1
37
+ sage_attention: false
38
+ sample_rate: 44100
39
+ stereo: true
40
+ stft_hop_length: 441
41
+ stft_n_fft: 2048
42
+ stft_normalized: false
43
+ stft_win_length: 2048
44
+ time_transformer_depth: 1
45
+ use_kan: true
46
+ use_torch_checkpoint: false
47
+ training:
48
+ augmentation: false
49
+ augmentation_loudness: false
50
+ augmentation_loudness_max: 0
51
+ augmentation_loudness_min: 0
52
+ augmentation_loudness_type: 1
53
+ augmentation_mix: false
54
+ augmentation_type: null
55
+ batch_size: 2
56
+ coarse_loss_clip: false
57
+ ema_momentum: 0.999
58
+ grad_clip: 0
59
+ gradient_accumulation_steps: 1
60
+ instruments:
61
+ - dry
62
+ - other
63
+ lr: 1.0e-05
64
+ num_epochs: 1000
65
+ num_steps: 4032
66
+ optimizer: adam
67
+ other_fix: false
68
+ patience: 8
69
+ q: 0.95
70
+ reduce_factor: 0.95
71
+ target_instrument: dry
72
+ use_mp3_compress: false
README.md CHANGED
@@ -1,3 +1,67 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Faster-RoKAN Core
2
+
3
+ Faster-RoKAN is a next-generation hybrid architecture that integrates Faster-KAN (Kolmogorov-Arnold Networks) into the BS-Roformer audio source separation model.
4
+
5
+ ## Features
6
+ - **Isomorphic Conversion**: Convert standard BS-Roformer or MelBand-Roformer models to the RoKAN architecture with ZERO fidelity loss (MAE ≈ 0.0).
7
+ - **Faster-KAN (RSWAF)**: Replaces linear MLP layers with Reflectional Switch Wavelet Activation Functions for efficient, expressive, and detailed non-linear learning. High-frequency artifacts are filtered out through smooth geometric spline curves.
8
+ - **Gentle Training**: Optimized for standard consumer hardware with thermal management considerations.
9
+
10
+ ## Includes Base Model
11
+ To get you started immediately, we have included a pre-converted **`Base_Model.ckpt`** and **`Base_Model.yaml`** in this package.
12
+ This base model is already functioning perfectly. You can skip the conversion step entirely and jump straight to fine-tuning it on your own dataset!
13
+
14
+ ## Setup
15
+ ```bash
16
+ pip install -r requirements.txt
17
+ ```
18
+
19
+ ## Usage
20
+
21
+ ### 0. (Optional) How to Make Your Own RoKAN Model
22
+ If you want to use a different checkpoint rather than the provided `Base_Model`, you can convert your existing standard `.ckpt` to the RoKAN format automatically with `convert_bs_to_rokan.py`.
23
+ **(Note: You do NOT need to do this if you just want to use the included `Base_Model`.)**
24
+
25
+ ```bash
26
+ python convert_bs_to_rokan.py \
27
+ --src_yaml dataset/Models/your_model.yaml \
28
+ --src_ckpt dataset/Models/your_model.ckpt \
29
+ --out_yaml converted/rokan.yaml \
30
+ --out_ckpt converted/rokan.ckpt
31
+ ```
32
+
33
+ **How it works (For both BS & MelBand):**
34
+ The `convert_bs_to_rokan.py` script automatically analyzes your `.yaml` configuration to determine whether it is a **BS-Roformer** or a **MelBand-Roformer** (by checking for the `num_bands` parameter).
35
+ Depending on the architecture, it seamlessly intercepts the standard linear MLP components located inside the Siamese or Standard Transformer FeedForward blocks, and replaces them with our custom `FasterKANLinear` blocks. All base knowledge is perfectly preserved without any fidelity loss.
36
+
37
+ ### 1. Fine-tuning
38
+ Train only the new KAN spline parameters on your dataset to remove high-frequency artifacts and teach the model geometric patterns. The script will automatically unfreeze *only* the new KAN parameters while keeping the base knowledge perfectly intact.
39
+
40
+ ```bash
41
+ python train_rokan.py --ckpt_path Base_Model.ckpt --yaml_path Base_Model.yaml
42
+ ```
43
+ *(Store your vocal audio in `dataset/vocals/` and instrumental audio in `dataset/instrumentals/` before running).*
44
+
45
+ ### 2. Inference
46
+ Run source separation using the pre-tuned or fine-tuned model:
47
+ ```bash
48
+ python run_infer_rokan.py \
49
+ --model_path Base_Model.ckpt \
50
+ --config_path Base_Model.yaml \
51
+ --input_audio your_song.wav
52
+ ```
53
+
54
+ ---
55
+
56
+ ## Credits, Contact & Disclaimer
57
+
58
+ **All Method Made By Himadayon.**
59
+ **IMPORTANT:** If you release or distribute any models that utilize this architecture or are fine-tuned using this repository, you **must** explicitly credit `Himadayon` in your release notes or repository.
60
+
61
+ **Contact:**
62
+ If you have any questions or inquiries regarding this project, please send an email to:
63
+ 📧 **Joker200702@gmail.com**
64
+ *(Please make sure to include a clear subject line and detailed contents in your email).*
65
+
66
+ **Disclaimer:**
67
+ For the purpose of experimental verification and architectural testing, existing base models originally developed by **unwa** and **Aname** were utilized during the development of this project.
agent_monitor.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import time
4
+ import subprocess
5
+ import json
6
+ import urllib.request
7
+
8
+ # ==========================================================
9
+ # Terminal Agent (Gemini API) for BS-RoKAN 監視
10
+ # VRAM消費: 0GB / CPU負荷: 極小
11
+ # ==========================================================
12
+
13
+ # APIキーをファイルから読み込む
14
+ KEY_FILE = "APIKey From Google AI Studio.txt"
15
+ if os.path.exists(KEY_FILE):
16
+ with open(KEY_FILE, "r") as f:
17
+ API_KEY = f.read().strip()
18
+ else:
19
+ API_KEY = os.environ.get("GEMINI_API_KEY", "")
20
+
21
+ MODEL_NAME = "gemini-3.1-flash-lite"
22
+
23
def analyze_logs_with_llm(log_buffer):
    """Ask the Gemini API whether the recent training log looks healthy.

    Args:
        log_buffer: list of log lines (strings) to send to the model.

    Returns:
        One of the strings "OK", "LOWER_LR" or "RESTART". "OK" is also
        returned when no API key is configured or when the request fails, so
        a monitoring problem never interrupts training.
    """
    if not API_KEY:
        print("[Agent] API_KEYがないため判定をスキップ(OK)")
        return "OK"

    system_instruction = "あなたは音声分離モデルBS-RoKANの学習監視エージェントです。以下の学習ログを見て、学習が順調か評価してください。"
    prompt = f"{system_instruction} 出力は OK, LOWER_LR, RESTART のいずれか1語のみにしてください。 \n\nログ:\n" + "\n".join(log_buffer)

    # The key is sent in a request header instead of the query string so it
    # does not end up in URLs that proxies, logs or tracebacks may record.
    url = f"https://generativelanguage.googleapis.com/v1beta/models/{MODEL_NAME}:generateContent"
    headers = {
        "Content-Type": "application/json",
        "x-goog-api-key": API_KEY,
    }

    # Gemini API (REST) request body.
    payload = {
        "contents": [{
            "parts": [{"text": prompt}]
        }],
        "generationConfig": {
            "temperature": 0.1,
            "maxOutputTokens": 10
        }
    }

    try:
        req = urllib.request.Request(url, data=json.dumps(payload).encode(), headers=headers)
        with urllib.request.urlopen(req, timeout=15) as r:
            response = json.loads(r.read())
        # Extract the generated text from the Gemini response structure.
        decision = response["candidates"][0]["content"]["parts"][0]["text"].strip().upper()
    except Exception as e:
        # Deliberate best-effort boundary: report the failure and keep training.
        print(f"[Agent] Gemini APIエラー: {e}")
        return "OK"

    if "LOWER_LR" in decision:
        return "LOWER_LR"
    if "RESTART" in decision:
        return "RESTART"
    return "OK"
57
+
58
def main():
    """Run train_rokan.py under supervision and restart it when needed.

    The child's combined stdout/stderr is echoed line by line. Lines that
    contain "Loss" or "Saved:" are buffered; each "Saved:" line (once more
    than five lines are buffered) triggers a Gemini diagnosis:

    * "LOWER_LR": stop the child and relaunch with a lower ``--gate_lr``
      (5e-4 the first time, halved on every later request).
    * "RESTART": stop the child, wait five seconds and relaunch unchanged.
    * anything else: keep going and clear the buffer.

    If the child exits by itself with a non-zero code it is relaunched after
    ten seconds. If it exits with code 0 the supervisor stops as well.
    """
    print(f"[*] Gemini Terminal Agent 起動成功 (Model: {MODEL_NAME})")
    print(f"[*] 学習プロセスを起動中...")

    # Original note: targets an RX 9070 XT under WSL2, starting at batch size 2.
    # sys.executable keeps the child in the same interpreter/virtualenv as the
    # agent; fall back to "python" if the interpreter path is unknown.
    cmd = [sys.executable or "python", "-u", "train_rokan.py", "--batch_size", "2"]

    while True:
        print(f"\n[Agent] 訓練開始: {' '.join(cmd)}")
        process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1)
        log_buffer = []
        restart_requested = False  # True when the agent itself stopped the child

        try:
            for line in process.stdout:
                line = line.strip()
                if not line:
                    continue
                print(line)

                if "Loss" in line or "Saved:" in line:
                    log_buffer.append(line)

                # Diagnose with Gemini at every save (end of an epoch).
                if "Saved:" in line and len(log_buffer) > 5:
                    decision = analyze_logs_with_llm(log_buffer[-30:])
                    if decision == "LOWER_LR":
                        print(f"[Agent] Geminiの判定: {decision} (学習率を下げて再開します)")
                        process.terminate()
                        if "--gate_lr" not in cmd:
                            cmd.extend(["--gate_lr", "5e-4"])  # 1e-3 -> 5e-4
                        else:
                            # Already lowered once: halve again so a repeated
                            # LOWER_LR verdict actually changes something.
                            value_idx = cmd.index("--gate_lr") + 1
                            cmd[value_idx] = f"{float(cmd[value_idx]) / 2:g}"
                        restart_requested = True
                        break
                    elif decision == "RESTART":
                        print(f"[Agent] Geminiの判定: {decision} (異常検知につき再起動します)")
                        process.terminate()
                        time.sleep(5)
                        restart_requested = True
                        break
                    else:
                        print(f"[Agent] Geminiの判定: {decision} (順調です)")
                        log_buffer = []  # clear the buffer

        except KeyboardInterrupt:
            print("\n[Agent] ユーザーによる中断。プロセスを終了します。")
            process.terminate()
            sys.exit(0)

        process.wait()
        process.stdout.close()  # release the pipe before the next launch

        if restart_requested:
            continue
        if process.returncode == 0:
            # Training finished normally; relaunching would repeat it forever.
            break
        print(f"[Agent] 訓練プロセスが終了しました (Code: {process.returncode})。10秒後に再起動を試みます。")
        time.sleep(10)
106
+
107
+ if __name__ == "__main__":
108
+ main()
convert_bs_to_rokan.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ convert_bs_to_rokan.py
3
+ =======================
4
+ Universal converter: Any standard BS-Roformer checkpoint → Faster-RoKAN
5
+
6
+ Usage:
7
+ python convert_bs_to_rokan.py \\
8
+ --src_yaml dataset/Models/BS-Rofo-SW-Fixed.yaml \\
9
+ --src_ckpt dataset/Models/BS-Rofo-SW-Fixed.ckpt \\
10
+ --out_yaml bs_rokan_sw.yaml \\
11
+ --out_ckpt bs_rokan_sw.ckpt \\
12
+ --grid_size 8
13
+
14
+ What it does:
15
+ 1. Reads the source YAML and builds a matching BSRoformer with use_kan=True
16
+ 2. Loads the source checkpoint
17
+ 3. Copies ALL compatible weights (Attention, norms, band-split, mask-estimator)
18
+ 4. Remaps FeedForward Linear weights -> FasterKANLinear.base_weight
19
+ net.1.weight -> net.1.base_weight (first projection)
20
+ net.4.weight -> net.3.base_weight (second projection)
21
+ 5. Saves the new Faster-RoKAN checkpoint + YAML
22
+ """
23
+
24
+ import os
25
+ import sys
26
+ import inspect
27
+ import argparse
28
+ import torch
29
+ import yaml
30
+
31
# Make the sibling ``models`` package importable regardless of the current
# working directory, without depending on one developer's home directory.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
32
+ from models.bs_roformer.bs_roformer import BSRoformer
33
+ from models.bs_roformer.mel_band_roformer import MelBandRoformer
34
+
35
+
36
+ # ── YAML helpers ──────────────────────────────────────────────────────────────
37
+
38
def load_yaml_fullloader(path):
    """Parse the YAML file at *path* with ``yaml.FullLoader`` and return it.

    The file is read as UTF-8 so the result does not depend on the platform's
    locale encoding.
    """
    # NOTE(review): FullLoader is less restrictive than safe_load; only use
    # this on configuration files you trust.
    with open(path, 'r', encoding='utf-8') as f:
        return yaml.load(f, Loader=yaml.FullLoader)
41
+
42
def load_yaml_strip_tags(path):
    """Fallback loader: drop ``!!python/tuple`` tags, then ``yaml.safe_load``.

    With the tag removed, the tagged sequences load as plain lists. The file
    is read as UTF-8 so the result does not depend on the locale encoding.
    """
    with open(path, 'r', encoding='utf-8') as f:
        raw = f.read()
    raw = raw.replace('!!python/tuple', '')
    return yaml.safe_load(raw)
48
+
49
def load_yaml_any(path):
    """Load a YAML config, tolerating ``!!python/tuple`` tags.

    Tries ``load_yaml_fullloader`` first. Only when that fails with a YAML
    error does it retry with ``load_yaml_strip_tags``. Other failures, such as
    a missing file, propagate unchanged instead of being retried.
    """
    try:
        return load_yaml_fullloader(path)
    except yaml.YAMLError:
        return load_yaml_strip_tags(path)
54
+
55
def ensure_tuples(cfg):
    """Convert the known sequence fields of *cfg* to tuples, in place.

    Only ``freqs_per_bands`` and ``multi_stft_resolutions_window_sizes`` are
    touched, and only when present and not already tuples (the original
    docstring notes this is required by beartype). Returns the same *cfg*
    object.
    """
    tuple_fields = ('freqs_per_bands', 'multi_stft_resolutions_window_sizes')
    for field in tuple_fields:
        if field not in cfg:
            continue
        value = cfg[field]
        if isinstance(value, tuple):
            continue
        cfg[field] = tuple(value)
    return cfg
61
+
62
+
63
+ # ── Checkpoint helpers ────────────────────────────────────────────────────────
64
+
65
def load_ckpt_flexible(path):
    """Load a checkpoint from *path* and return a flat ``{name: tensor}`` dict.

    If the loaded object is a dict with a ``'state_dict'`` key, that entry is
    used; otherwise, if it has a ``'model'`` key, that entry is used. A
    leading ``'model.'`` prefix is stripped from every key.

    Raises:
        TypeError: if the unwrapped payload is not a mapping.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # sources you trust.
    payload = torch.load(path, map_location='cpu')
    if isinstance(payload, dict):
        if 'state_dict' in payload:
            payload = payload['state_dict']
        elif 'model' in payload:
            payload = payload['model']

    if not hasattr(payload, 'items'):
        # Fail with a clear message instead of an AttributeError on .items().
        raise TypeError(
            f"Checkpoint {path!r} does not contain a state dict "
            f"(got {type(payload).__name__})"
        )

    prefix = 'model.'
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in payload.items()
    }
74
+
75
+
76
+ # ── Model builder ─────────────────────────────────────────────────────────────
77
+
78
def build_rokan(src_cfg, grid_size):
    """Instantiate the KAN-enabled model described by *src_cfg*.

    A shallow copy of *src_cfg* is adjusted (KAN switched on with the given
    grid size; flash attention, torch checkpointing and sage attention
    switched off). ``MelBandRoformer`` is chosen when the config has a
    ``num_bands`` key, otherwise ``BSRoformer``. Keys that the chosen
    constructor does not accept are dropped before the call.
    """
    overrides = ensure_tuples(dict(src_cfg))  # work on a copy of the config
    overrides.pop('use_torch_checkpoint', None)
    overrides.update(
        use_kan=True,
        kan_grid_size=grid_size,
        flash_attn=False,  # disabled for stability during conversion
        use_torch_checkpoint=False,
        sage_attention=False,
    )

    if 'num_bands' in overrides:
        model_cls = MelBandRoformer
    else:
        model_cls = BSRoformer

    accepted = set(inspect.signature(model_cls.__init__).parameters) - {'self'}
    ctor_kwargs = {name: value for name, value in overrides.items() if name in accepted}
    return model_cls(**ctor_kwargs)
94
+
95
+
96
+ # ── Weight mapping ────────────────────────────────────────────────────────────
97
+
98
def remap_and_load(src_sd, model):
    """Copy source weights into *model*, remapping FeedForward Linear tensors.

    For every tensor in *src_sd*:

    - If the key exists in the model with the same shape, it is copied as-is.
    - Otherwise these key suffixes are rewritten and the tensor is copied when
      the target exists with the same shape:
        ``.net.1.weight`` -> ``.net.1.base_weight``
        ``.net.4.weight`` -> ``.net.3.base_weight``
        ``.net.1.bias``   -> ``.net.1.base_bias``
        ``.net.4.bias``   -> ``.net.3.base_bias``
    - Anything else is counted as skipped.

    Model tensors that receive no source value keep their initial values.
    A summary is printed, and the same *model* object is returned after
    ``load_state_dict``.
    """
    # (source suffix, target suffix); only the trailing suffix is rewritten.
    suffix_map = (
        ('.net.1.weight', '.net.1.base_weight'),
        ('.net.4.weight', '.net.3.base_weight'),
        ('.net.1.bias', '.net.1.base_bias'),
        ('.net.4.bias', '.net.3.base_bias'),
    )

    model_dict = model.state_dict()
    matched = {}
    remapped = 0
    skipped = []

    for k, v in src_sd.items():
        # Direct match
        if k in model_dict and v.shape == model_dict[k].shape:
            matched[k] = v
            continue

        # Remap FF Linear -> base_weight / base_bias
        remap = None
        for old_suffix, new_suffix in suffix_map:
            if k.endswith(old_suffix):
                remap = k[:-len(old_suffix)] + new_suffix
                break

        if remap and remap in model_dict and v.shape == model_dict[remap].shape:
            matched[remap] = v
            remapped += 1
        else:
            skipped.append(k)

    model_dict.update(matched)
    model.load_state_dict(model_dict)

    print(f" Loaded: {len(matched)} tensors")
    print(f" Remapped: {remapped} FF Linear → base_weight")
    print(f" Skipped: {len(skipped)} (biases, incompatible shapes)")

    # Show which model tensors got no source value (left at init, to be trained).
    kan_random = [k for k in model_dict if k not in matched]
    kan_types = set(k.split('.')[-1] for k in kan_random)
    print(f" KAN init: {len(kan_random)} tensors types={kan_types}")
    return model
147
+
148
+
149
+ # ── YAML writer ───────────────────────────────────────────────────────────────
150
+
151
def write_out_yaml(src_yaml_path, out_yaml_path, grid_size):
    """Write a copy of the source YAML with the Faster-RoKAN model settings.

    In the ``model`` section, ``use_kan`` is set to True and ``kan_grid_size``
    to *grid_size*; ``flash_attn``, ``use_torch_checkpoint`` and
    ``sage_attention`` are set to False. Tuple-valued ``freqs_per_bands`` and
    ``multi_stft_resolutions_window_sizes`` are written as plain lists.
    """
    # Use the same tolerant loader as main(), so a config that needed the
    # tag-stripping fallback there does not fail here, after the checkpoint
    # has already been written.
    raw = load_yaml_any(src_yaml_path)
    raw['model']['use_kan'] = True
    raw['model']['kan_grid_size'] = grid_size
    raw['model']['flash_attn'] = False
    raw['model']['use_torch_checkpoint'] = False
    raw['model']['sage_attention'] = False

    # Dump tuple fields as plain lists so the output carries no python tags.
    for key in ('freqs_per_bands', 'multi_stft_resolutions_window_sizes'):
        if key in raw['model'] and isinstance(raw['model'][key], tuple):
            raw['model'][key] = list(raw['model'][key])

    # allow_unicode=True can emit non-ASCII text, so fix the file encoding.
    with open(out_yaml_path, 'w', encoding='utf-8') as f:
        yaml.dump(raw, f, default_flow_style=False, allow_unicode=True)
    print(f" Wrote YAML: {out_yaml_path}")
168
+
169
+
170
+ # ── Main ──────────────────────────────────────────────────────────────────────
171
+
172
def main():
    """Command-line entry point: convert a source checkpoint to Faster-RoKAN.

    Steps: load the source YAML, build the KAN-enabled model, load and remap
    the source weights, then save the new checkpoint and YAML. Progress is
    printed along the way.
    """
    cli = argparse.ArgumentParser(description='Convert BS-Roformer → Faster-RoKAN')
    cli.add_argument('--src_yaml', required=True, help='Source model YAML')
    cli.add_argument('--src_ckpt', required=True, help='Source model checkpoint (.ckpt)')
    cli.add_argument('--out_yaml', default='bs_rokan_converted.yaml', help='Output YAML path')
    cli.add_argument('--out_ckpt', default='bs_rokan_converted.ckpt', help='Output checkpoint path')
    cli.add_argument('--grid_size', type=int, default=8, help='Faster-KAN grid size (wavelet count)')
    args = cli.parse_args()

    print(f"\n[*] BS-Roformer → Faster-RoKAN Converter")
    print(f" src_yaml : {args.src_yaml}")
    print(f" src_ckpt : {args.src_ckpt}")
    print(f" out_yaml : {args.out_yaml}")
    print(f" out_ckpt : {args.out_ckpt}")
    print(f" grid_size: {args.grid_size}\n")

    # Step 1: source configuration (only the "model" section is needed).
    print("[1/4] Loading source YAML...")
    src_cfg = ensure_tuples(load_yaml_any(args.src_yaml)['model'])
    print(f" dim={src_cfg['dim']}, depth={src_cfg['depth']}, stereo={src_cfg.get('stereo')}")

    # Step 2: target architecture.
    print("\n[2/4] Building Faster-RoKAN model...")
    model = build_rokan(src_cfg, args.grid_size)
    param_count = sum(p.numel() for p in model.parameters())
    total_params = param_count / 1e6
    print(f" Model built. Parameters: {total_params:.1f}M")

    # Step 3: weights.
    print("\n[3/4] Loading source checkpoint and remapping weights...")
    src_sd = load_ckpt_flexible(args.src_ckpt)
    print(f" Source checkpoint has {len(src_sd)} tensors")
    model = remap_and_load(src_sd, model)

    # Step 4: outputs.
    print("\n[4/4] Saving Faster-RoKAN...")
    torch.save(model.state_dict(), args.out_ckpt)
    print(f" Saved checkpoint: {args.out_ckpt}")

    write_out_yaml(args.src_yaml, args.out_yaml, args.grid_size)

    print("\n[*] Conversion complete!")
    print(f" Inference: MODEL_YAML={args.out_yaml} MODEL_CKPT={args.out_ckpt} python run_infer_rokan.py")
    print(f" Training : python train_rokan.py (update ckpt_path in script to {args.out_ckpt})")
217
+
218
+ if __name__ == '__main__':
219
+ main()
eval_fidelity_report.md ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # RoKAN Fidelity Report
2
+
3
+ - input_wav: `input/Arctic Tundra.wav`
4
+ - device: `cuda`
5
+
6
+ ## BS-Roformer
7
+ - status: OK
8
+ - sample_rate: 44100
9
+ - audio_seconds: 152.63
10
+ - teacher_infer_sec: 30.59
11
+ - rokan_infer_sec: 102.77
12
+ - mae: 0.00000000
13
+ - rmse: 0.00000000
14
+ - max_abs: 0.00000004
15
+
16
+ ## MelBand-Roformer
17
+ - status: OK
18
+ - sample_rate: 44100
19
+ - audio_seconds: 152.63
20
+ - teacher_infer_sec: 21.30
21
+ - rokan_infer_sec: 79.54
22
+ - mae: 0.00000384
23
+ - rmse: 0.00000723
24
+ - max_abs: 0.00013021
evaluate_rokan_fidelity.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import time
3
+ from pathlib import Path
4
+
5
+ import soundfile as sf
6
+ import torch
7
+ import torchaudio.functional as AF
8
+ import yaml
9
+
10
+ from models.bs_roformer.bs_roformer import BSRoformer
11
+ from models.bs_roformer.mel_band_roformer import MelBandRoformer
12
+
13
+
14
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
15
+
16
+
17
def load_cfg(path: Path):
    """Read the UTF-8 YAML file at *path* with ``yaml.FullLoader`` and return it."""
    with path.open("r", encoding="utf-8") as stream:
        parsed = yaml.load(stream, Loader=yaml.FullLoader)
    return parsed
20
+
21
+
22
def clean_state_dict(ckpt_path: Path):
    """Load a checkpoint on CPU and return its state dict with clean keys.

    A ``"state_dict"`` wrapper is unwrapped first, then a ``"model"`` wrapper;
    the two checks run in sequence, so a nested wrapper is handled too. A
    leading ``"model."`` is removed from every key.
    """
    payload = torch.load(str(ckpt_path), map_location="cpu")
    for wrapper_key in ("state_dict", "model"):
        if isinstance(payload, dict) and wrapper_key in payload:
            payload = payload[wrapper_key]

    prefix = "model."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): tensor
        for key, tensor in payload.items()
    }
32
+
33
+
34
def build_model_from_yaml(yaml_path: Path):
    """Build the model described by a YAML config.

    Reads the ``model`` and ``audio`` sections. ``dim`` and ``depth`` are
    required; every other constructor argument falls back to a default when
    absent. ``flash_attn`` and ``use_torch_checkpoint`` are always False.
    A config with ``num_bands`` yields a ``MelBandRoformer``, otherwise a
    ``BSRoformer``.

    Returns:
        ``(model, sample_rate)`` where the sample rate comes from the
        ``audio`` section.
    """
    cfg = load_cfg(yaml_path)
    m = cfg["model"]
    audio_cfg = cfg["audio"]

    # Optional settings and the value used when the config omits them.
    optional_defaults = {
        "stereo": True,
        "num_stems": 1,
        "time_transformer_depth": 1,
        "freq_transformer_depth": 1,
        "linear_transformer_depth": 0,
        "dim_head": 64,
        "heads": 8,
        "attn_dropout": 0.0,
        "ff_dropout": 0.0,
        "dim_freqs_in": 1025,
        "stft_n_fft": 2048,
        "stft_hop_length": 512,
        "stft_win_length": 2048,
        "stft_normalized": False,
        "mask_estimator_depth": 2,
        "multi_stft_resolution_loss_weight": 1.0,
        "multi_stft_hop_size": 147,
        "multi_stft_normalized": False,
        "mlp_expansion_factor": 4,
        "skip_connection": False,
        "sage_attention": False,
        "use_kan": False,
        "kan_grid_size": 8,
    }

    kwargs = {"dim": m["dim"], "depth": m["depth"]}
    for option, fallback in optional_defaults.items():
        kwargs[option] = m.get(option, fallback)
    kwargs["flash_attn"] = False
    kwargs["use_torch_checkpoint"] = False
    kwargs["multi_stft_resolutions_window_sizes"] = tuple(
        m.get("multi_stft_resolutions_window_sizes", (4096, 2048, 1024, 512, 256))
    )
    if "freqs_per_bands" in m:
        kwargs["freqs_per_bands"] = tuple(m["freqs_per_bands"])

    if "num_bands" not in m:
        model = BSRoformer(**kwargs)
    else:
        kwargs["num_bands"] = m.get("num_bands", 60)
        kwargs["sample_rate"] = m.get("sample_rate", audio_cfg.get("sample_rate", 44100))
        model = MelBandRoformer(**kwargs)
    return model, audio_cfg["sample_rate"]
78
+
79
+
80
def load_audio(path: Path, target_sr: int):
    """Read an audio file and return it as a float tensor with a batch axis.

    The file is resampled when its rate differs from *target_sr*. A
    one-channel signal is duplicated to two channels; anything with more than
    two channels keeps only the first two. The result has one extra leading
    dimension added by ``unsqueeze(0)``.
    """
    samples, native_sr = sf.read(str(path), always_2d=True)
    wav = torch.from_numpy(samples.T).float()  # transpose to channels-first
    if native_sr != target_sr:
        wav = AF.resample(wav, native_sr, target_sr)

    channel_count = wav.shape[0]
    if channel_count == 1:
        wav = wav.repeat(2, 1)
    elif channel_count > 2:
        wav = wav[:2, :]
    return wav.unsqueeze(0)
90
+
91
+
92
def infer_chunked(model, audio, chunk_size=353280, context=132096):
    """Run *model* over *audio* in fixed-size windows and stitch the centres.

    The input is padded by *context* samples on both sides (replicate mode).
    Each window is ``chunk_size`` long; only its central
    ``chunk_size - 2 * context`` samples are written to the output, and the
    window then advances by that amount. A short final window is
    right-padded to ``chunk_size``. On CUDA inputs the model runs under
    float16 autocast. A 4-D model output keeps only index 0 of its second
    axis.

    Returns:
        A tensor whose last dimension equals the input length, or ``None``
        if the input length is zero (the loop never runs).

    Raises:
        RuntimeError: if ``chunk_size <= 2 * context`` or the model output is
            neither 3-D nor 4-D.
    """
    F = torch.nn.functional
    center_size = chunk_size - 2 * context
    if center_size <= 0:
        raise RuntimeError("chunk_size must be larger than 2*context")

    audio_len = audio.shape[-1]
    padded = F.pad(audio, (context, context), mode="replicate")
    result = None

    for start in range(0, audio_len, center_size):
        stop = min(start + center_size, audio_len)
        keep = stop - start

        window = padded[:, :, start : start + chunk_size]
        shortfall = chunk_size - window.shape[-1]
        if shortfall > 0:
            window = F.pad(window, (0, shortfall), mode="replicate")

        with torch.inference_mode():
            if audio.is_cuda:
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    pred = model(window)
            else:
                pred = model(window)

        # Bring the prediction to three dimensions; some checkpoints return
        # an extra stem axis, in which case only the first stem is kept.
        if pred.ndim == 4:
            pred = pred[:, 0, :, :]
        elif pred.ndim != 3:
            raise RuntimeError(f"Unsupported output ndim={pred.ndim}, shape={tuple(pred.shape)}")

        if result is None:
            result = torch.zeros((pred.shape[0], pred.shape[1], audio_len), device=audio.device)

        result[:, :, start:stop] = pred[:, :, context : context + keep]
    return result
126
+
127
+
128
def eval_pair(name, teacher_yaml, teacher_ckpt, rokan_yaml, rokan_ckpt, wav_path):
    """Compare a teacher model and its RoKAN conversion on one audio file.

    Both models are built from their YAML files, loaded from their
    checkpoints, moved to ``DEVICE`` and put in eval mode. The same audio is
    run through both with ``infer_chunked``, and the outputs are compared.

    Returns:
        A dict with ``name``, ``sample_rate``, ``audio_seconds``,
        ``teacher_sec``, ``rokan_sec``, ``mae``, ``rmse`` and ``max_abs``.

    Raises:
        RuntimeError: if the two configs declare different sample rates.
    """
    t_model, t_sr = build_model_from_yaml(teacher_yaml)
    r_model, r_sr = build_model_from_yaml(rokan_yaml)
    if t_sr != r_sr:
        raise RuntimeError(f"{name}: sample rate mismatch {t_sr} vs {r_sr}")
    # NOTE(review): strict=False silently ignores missing or unexpected keys;
    # a partly loaded model would still be evaluated.
    t_model.load_state_dict(clean_state_dict(teacher_ckpt), strict=False)
    r_model.load_state_dict(clean_state_dict(rokan_ckpt), strict=False)
    t_model = t_model.to(DEVICE).eval()
    r_model = r_model.to(DEVICE).eval()

    audio = load_audio(wav_path, t_sr).to(DEVICE)

    def _timed_infer(model):
        # CUDA kernels run asynchronously, so wait for the GPU before reading
        # the clock on either side of the call.
        if DEVICE == "cuda":
            torch.cuda.synchronize()
        start = time.perf_counter()
        output = infer_chunked(model, audio)
        if DEVICE == "cuda":
            torch.cuda.synchronize()
        return output, time.perf_counter() - start

    t_out, t_sec = _timed_infer(t_model)
    r_out, r_sec = _timed_infer(r_model)

    diff = (t_out - r_out).float()
    mae = diff.abs().mean().item()
    rmse = torch.sqrt((diff ** 2).mean()).item()
    max_abs = diff.abs().max().item()
    return {
        "name": name,
        "sample_rate": t_sr,
        "audio_seconds": float(audio.shape[-1]) / float(t_sr),
        "teacher_sec": t_sec,
        "rokan_sec": r_sec,
        "mae": mae,
        "rmse": rmse,
        "max_abs": max_abs,
    }
160
+
161
+
162
def main():
    """Evaluate the configured teacher/RoKAN pairs and write a Markdown report.

    The input is ``--input_wav`` when given, otherwise the first ``*.wav``
    (sorted) in the ``input`` directory next to this script. A pair with
    missing files, or one whose evaluation raises, is reported as FAIL and
    does not stop the run. The report is written to
    ``converted_models/eval_fidelity_report.md`` next to this script.

    Raises:
        RuntimeError: if no input wav can be found.
    """
    parser = argparse.ArgumentParser(description="Evaluate teacher vs RoKAN fidelity for BS and MelBand models")
    parser.add_argument("--input_wav", type=str, default="")
    args = parser.parse_args()

    root = Path(__file__).resolve().parent
    input_dir = root / "input"
    wav_path = Path(args.input_wav) if args.input_wav else None
    if wav_path is None:
        wavs = sorted(input_dir.glob("*.wav"))
        if not wavs:
            raise RuntimeError("No wav in input/. Set --input_wav explicitly.")
        wav_path = wavs[0]
    if not wav_path.exists():
        raise RuntimeError(f"Input wav not found: {wav_path}")

    # (name, teacher yaml, teacher ckpt, rokan yaml, rokan ckpt)
    pairs = [
        (
            "BS-Rofo-SW-Fixed",
            root / "dataset/Models/BS-Rofo-SW-Fixed.yaml",
            root / "dataset/Models/BS-Rofo-SW-Fixed.ckpt",
            root / "converted_models/BS-Rofo-SW-Fixed_rokan.yaml",
            root / "converted_models/BS-Rofo-SW-Fixed_rokan.ckpt",
        ),
        (
            "MelBand denoise",
            root / "dataset/Models/denoise_mel_band_roformer_aufr33_sdr_27.9959.yaml",
            root / "dataset/Models/denoise_mel_band_roformer_aufr33_sdr_27.9959.ckpt",
            root / "converted_models/denoise_mel_band_roformer_aufr33_sdr_27.9959_rokan.yaml",
            root / "converted_models/denoise_mel_band_roformer_aufr33_sdr_27.9959_rokan.ckpt",
        ),
    ]

    rows = []
    for row in pairs:
        name, ty, tc, ry, rc = row
        missing = [str(p) for p in (ty, tc, ry, rc) if not p.exists()]
        if missing:
            rows.append({"name": name, "error": "missing files: " + ", ".join(missing)})
            continue
        try:
            rows.append(eval_pair(name, ty, tc, ry, rc, wav_path))
        except Exception as e:
            # Record the failure in the report and continue with the next pair.
            rows.append({"name": name, "error": str(e)})

    out_path = root / "converted_models" / "eval_fidelity_report.md"
    lines = []
    lines.append("# RoKAN Fidelity Report")
    lines.append("")
    lines.append(f"- input_wav: `{wav_path}`")
    lines.append(f"- device: `{DEVICE}`")
    lines.append("")
    for r in rows:
        lines.append(f"## {r['name']}")
        if "error" in r:
            lines.append("- status: FAIL")
            lines.append(f"- error: `{r['error']}`")
        else:
            lines.append("- status: OK")
            lines.append(f"- sample_rate: {r['sample_rate']}")
            lines.append(f"- audio_seconds: {r['audio_seconds']:.2f}")
            lines.append(f"- teacher_infer_sec: {r['teacher_sec']:.2f}")
            lines.append(f"- rokan_infer_sec: {r['rokan_sec']:.2f}")
            lines.append(f"- mae: {r['mae']:.8f}")
            lines.append(f"- rmse: {r['rmse']:.8f}")
            lines.append(f"- max_abs: {r['max_abs']:.8f}")
        lines.append("")

    # The directory may not exist yet (for example when every pair failed on
    # missing files); create it so the report can still be written.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text("\n".join(lines), encoding="utf-8")
    print(f"wrote: {out_path}")
232
+
233
+
234
+ if __name__ == "__main__":
235
+ main()
236
+
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ torchaudio
3
+ einops
4
+ rotary-embedding-torch
5
+ librosa
6
+ soundfile
7
+ pyyaml
8
+ beartype
9
+ tqdm
run_infer_rokan.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import yaml
3
+ import torch
4
+ import soundfile as sf
5
+ import torchaudio.functional as AF
6
+ import torch.nn.functional as F
7
+
8
+ from models.bs_roformer.bs_roformer import BSRoformer
9
+ from models.bs_roformer.mel_band_roformer import MelBandRoformer
10
+
11
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Every path below can be overridden with an environment variable of the
# same name; the literals are the fallbacks used when the variable is unset.
MODEL_YAML = os.environ.get("MODEL_YAML", "bs_rokan.yaml")
MODEL_CKPT = os.environ.get("MODEL_CKPT", "bs_rokan.ckpt")
INPUT_DIR = os.environ.get("INPUT_DIR", os.path.expanduser("~/BS-RoKAN-lab/input"))
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", os.path.expanduser("~/BS-RoKAN-lab/RoKAN output"))

# NOTE: this runs at import time, so merely importing the module creates the
# output directory.
os.makedirs(OUTPUT_DIR, exist_ok=True)
19
+
20
def _env_int(name: str, default: int) -> int:
    """Read an integer setting from the environment.

    Args:
        name: Environment variable to look up.
        default: Value returned when the variable is unset or not an integer.

    Returns:
        The parsed integer, or ``default`` when parsing is not possible.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    try:
        return int(raw.strip())
    except ValueError:
        # Keep the best-effort fallback, but say so: a mistyped value would
        # otherwise be replaced by the default without any trace.
        print(f"[!] {name}={raw!r} is not an integer; using default {default}")
        return default
26
+
27
def _env_bool(name: str, default: bool) -> bool:
    """Interpret an environment variable as an on/off switch.

    Returns ``default`` when the variable is unset. Otherwise the value is
    trimmed and lower-cased, and counts as true only when it is one of
    ``1``, ``true``, ``yes``, ``y`` or ``on``.
    """
    truthy = ("1", "true", "yes", "y", "on")
    raw = os.environ.get(name)
    return default if raw is None else raw.strip().lower() in truthy
32
+
33
+
34
def load_model():
    """Build the separation model from MODEL_YAML and load MODEL_CKPT into it.

    The ``model`` section of the YAML supplies the constructor arguments
    (with the fallbacks listed below for missing keys); a ``num_bands`` key
    selects MelBandRoformer, otherwise BSRoformer is used. The checkpoint is
    loaded non-strictly, so key mismatches are only reported as counts.

    Returns:
        A ``(model, sample_rate)`` tuple, where ``sample_rate`` is
        ``cfg["audio"]["sample_rate"]``.
    """
    with open(MODEL_YAML, "r") as f:
        # NOTE: FullLoader can construct some Python-specific tags (the
        # bundled configs use `!!python/tuple`); only load trusted configs.
        cfg = yaml.load(f, Loader=yaml.FullLoader)

    model_cfg = cfg["model"]
    audio_cfg = cfg["audio"]

    # Only `dim` and `depth` are mandatory; every other argument falls back
    # to the literal given here when the config omits it.
    kwargs = dict(
        dim=model_cfg["dim"],
        depth=model_cfg["depth"],
        stereo=model_cfg.get("stereo", True),
        num_stems=model_cfg.get("num_stems", 1),
        time_transformer_depth=model_cfg.get("time_transformer_depth", 1),
        freq_transformer_depth=model_cfg.get("freq_transformer_depth", 1),
        linear_transformer_depth=model_cfg.get("linear_transformer_depth", 0),
        freqs_per_bands=tuple(model_cfg.get("freqs_per_bands", (
            2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
            2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
            2, 2, 2, 2,
            4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
            12, 12, 12, 12, 12, 12, 12, 12,
            24, 24, 24, 24, 24, 24, 24, 24,
            48, 48, 48, 48, 48, 48, 48, 48,
            128, 129,
        ))),
        dim_head=model_cfg.get("dim_head", 64),
        heads=model_cfg.get("heads", 8),
        attn_dropout=model_cfg.get("attn_dropout", 0.0),
        ff_dropout=model_cfg.get("ff_dropout", 0.0),
        # Flash attention is always disabled here, whatever the config says.
        flash_attn=False,
        dim_freqs_in=model_cfg.get("dim_freqs_in", 1025),
        stft_n_fft=model_cfg.get("stft_n_fft", 2048),
        stft_hop_length=model_cfg.get("stft_hop_length", 512),
        stft_win_length=model_cfg.get("stft_win_length", 2048),
        stft_normalized=model_cfg.get("stft_normalized", False),
        mask_estimator_depth=model_cfg.get("mask_estimator_depth", 2),
        multi_stft_resolution_loss_weight=model_cfg.get("multi_stft_resolution_loss_weight", 1.0),
        multi_stft_resolutions_window_sizes=tuple(
            model_cfg.get("multi_stft_resolutions_window_sizes", (4096, 2048, 1024, 512, 256))
        ),
        multi_stft_hop_size=model_cfg.get("multi_stft_hop_size", 147),
        multi_stft_normalized=model_cfg.get("multi_stft_normalized", False),
        mlp_expansion_factor=model_cfg.get("mlp_expansion_factor", 4),
        use_torch_checkpoint=model_cfg.get("use_torch_checkpoint", False),
        skip_connection=model_cfg.get("skip_connection", False),
        sage_attention=model_cfg.get("sage_attention", False),
        use_kan=model_cfg.get("use_kan", False),
        kan_grid_size=model_cfg.get("kan_grid_size", 5),
    )

    print("Building model...")
    # The presence of `num_bands` in the config is what selects the mel-band variant.
    model_cls = MelBandRoformer if "num_bands" in model_cfg else BSRoformer
    if model_cls is MelBandRoformer:
        kwargs["num_bands"] = model_cfg.get("num_bands", 60)
        kwargs["sample_rate"] = model_cfg.get("sample_rate", audio_cfg.get("sample_rate", 44100))
    model = model_cls(**kwargs).to(DEVICE)
    model.eval()

    print("Loading checkpoint...")
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # from sources you trust.
    ckpt = torch.load(MODEL_CKPT, map_location="cpu")

    # Accept a bare state dict as well as one wrapped under "state_dict" or "model".
    if "state_dict" in ckpt:
        state = ckpt["state_dict"]
    elif "model" in ckpt:
        state = ckpt["model"]
    else:
        state = ckpt

    # Drop a leading "model." from every key so wrapped checkpoints line up
    # with the bare module's parameter names.
    clean_state = {}
    for k, v in state.items():
        if k.startswith("model."):
            clean_state[k[len("model."):]] = v
        else:
            clean_state[k] = v

    # strict=False: mismatched keys do not raise, they are only counted below.
    missing, unexpected = model.load_state_dict(clean_state, strict=False)
    print("missing:", len(missing), "unexpected:", len(unexpected))

    # Optional inference optimizations: TF32 defaults to ON (INFER_TF32),
    # torch.compile defaults to OFF (INFER_COMPILE).
    if DEVICE == "cuda":
        if _env_bool("INFER_TF32", True):
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
            try:
                torch.set_float32_matmul_precision("high")
            except Exception:
                # Best effort: a failure here is deliberately ignored.
                pass

    # NOTE(review): block nesting was reconstructed from an indentation-stripped
    # listing -- confirm whether compilation is meant to be CUDA-only.
    if _env_bool("INFER_COMPILE", False) and hasattr(torch, "compile"):
        try:
            model = torch.compile(model)
            print("torch.compile enabled")
        except Exception as e:
            print(f"torch.compile skipped: {e}")

    return model, audio_cfg["sample_rate"]
130
+
131
+
132
def load_audio(path: str, target_sr: int) -> torch.Tensor:
    """Read an audio file and return it as a two-channel batch of one on DEVICE.

    The file is read with soundfile, transposed, converted to float and
    resampled to ``target_sr`` when its own rate differs. A single channel is
    duplicated and anything beyond the first two channels is dropped before
    the leading batch axis is added.
    """
    samples, native_sr = sf.read(path, always_2d=True)
    waveform = torch.from_numpy(samples.T).float()

    if native_sr != target_sr:
        waveform = AF.resample(waveform, native_sr, target_sr)

    channel_count = waveform.shape[0]
    if channel_count == 1:
        # Mono input: duplicate the only channel.
        waveform = waveform.repeat(2, 1)
    elif channel_count > 2:
        # More than two channels: keep the first two.
        waveform = waveform[:2, :]

    batched = waveform.unsqueeze(0)
    return batched.to(DEVICE)
145
+
146
+
147
def separate_with_context(model: torch.nn.Module, audio: torch.Tensor) -> torch.Tensor:
    """Run the model over ``audio`` in overlapping chunks and stitch the centres.

    Each chunk is ``INFER_CHUNK_SIZE`` samples long and carries
    ``INFER_CONTEXT`` samples of surrounding audio on both sides; only the
    middle ``chunk_size - 2 * context`` samples of every model output are
    written to the result, so chunk borders never reach the output.

    Args:
        model: Separation network called as ``model(chunk)``.
        audio: Input of shape ``(1, channels, samples)``, as produced by
            ``load_audio``.

    Returns:
        A float32 tensor of shape ``(1, channels, samples)`` on the same
        device as ``audio``.

    Raises:
        RuntimeError: If the configured chunk size leaves no centre region.
    """
    # Tunable via env for Colab optimization / VRAM tradeoffs
    chunk_size = _env_int("INFER_CHUNK_SIZE", 353280)
    context = _env_int("INFER_CONTEXT", 132096)
    center_size = chunk_size - 2 * context

    if center_size <= 0:
        raise RuntimeError("chunk_size must be larger than 2 * context")

    audio_len = audio.shape[-1]
    # Allocate next to the input instead of relying on the module-level device.
    output = torch.zeros((1, audio.shape[1], audio_len), device=audio.device)
    if audio_len == 0:
        # Replicate padding cannot extend an empty signal; nothing to separate.
        return output

    # Pad both ends so the first and last centre regions also get full context.
    padded = F.pad(audio, (context, context), mode="replicate")

    # The AMP switch cannot change between chunks, so read it once.
    use_amp = audio.is_cuda and _env_bool("INFER_AMP", True)

    pos = 0
    while pos < audio_len:
        center_end = min(pos + center_size, audio_len)
        valid_len = center_end - pos

        # Index `pos` in `padded` is `pos - context` in the original signal,
        # so this slice is the centre region plus context on both sides.
        chunk = padded[:, :, pos:pos + chunk_size]

        if chunk.shape[-1] < chunk_size:
            # Final chunk: extend to the fixed length the model is fed.
            chunk = F.pad(chunk, (0, chunk_size - chunk.shape[-1]), mode="replicate")

        with torch.inference_mode():
            if use_amp:
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    out_chunk = model(chunk)
            else:
                out_chunk = model(chunk)

        # NOTE(review): assumes the model returns a tensor laid out like its
        # input (batch, channels, samples) -- confirm for num_stems > 1.
        output[:, :, pos:center_end] = out_chunk[:, :, context:context + valid_len]

        pos += center_size

    return output
186
+
187
+
188
def main():
    """Separate every ``.wav`` file in INPUT_DIR and write results to OUTPUT_DIR.

    Each output keeps the input's file name. Raises RuntimeError when the
    input folder holds no ``.wav`` files.
    """
    model, sample_rate_target = load_model()

    wav_files = []
    for entry in os.listdir(INPUT_DIR):
        if entry.lower().endswith(".wav"):
            wav_files.append(entry)
    if len(wav_files) == 0:
        raise RuntimeError(f"No wav files found in input folder: {INPUT_DIR}")

    for wav_name in wav_files:
        source_path = os.path.join(INPUT_DIR, wav_name)
        target_path = os.path.join(OUTPUT_DIR, wav_name)

        print(f"Processing: {wav_name}")

        waveform = load_audio(source_path, sample_rate_target)
        separated = separate_with_context(model, waveform)

        # Drop the batch axis and transpose before handing the array to soundfile.
        separated_np = separated.squeeze(0).detach().cpu().T.numpy()
        sf.write(target_path, separated_np, sample_rate_target)

        # Release the per-file tensors before the next file is loaded.
        del waveform, separated, separated_np
        if DEVICE == "cuda" and _env_bool("INFER_EMPTY_CACHE", False):
            torch.cuda.empty_cache()

        print(f"Saved: {target_path}")

    print("All done.")


if __name__ == "__main__":
    main()
train_rokan.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.utils.data import Dataset, DataLoader
6
+ import torchaudio
7
+ import yaml
8
+ import argparse
9
+ import time
10
+ from models.bs_roformer.bs_roformer import BSRoformer
11
+ from models.bs_roformer.mel_band_roformer import MelBandRoformer
12
+
13
def set_requires_grad_selective(model):
    """Freeze every parameter except the KAN spline tensors.

    A parameter stays trainable only when its qualified name ends with
    ``.spline_weight`` or ``.spline_gate``. The number of tensors left
    trainable is printed, and the same model object is returned.
    """
    kan_suffixes = ('.spline_weight', '.spline_gate')
    unfrozen_count = 0
    for name, param in model.named_parameters():
        trainable = name.endswith(kan_suffixes)
        param.requires_grad = trainable
        if trainable:
            unfrozen_count += 1
    print(f"[*] Training: Unfroze {unfrozen_count} KAN tensors")
    return model
23
+
24
class SimpleAudioDataset(Dataset):
    """Pairs of vocal / instrumental ``.wav`` files that share a file name.

    Every item is a ``(mix, vocals)`` tuple where ``mix`` is the sum of the
    vocal and instrumental excerpts, each ``chunk_size`` samples long.

    Args:
        vocab_dir: Folder holding the vocal stems.
        inst_dir: Folder holding the instrumental stems.
        sample_rate: Rate every file is resampled to.
        chunk_seconds: Excerpt length in seconds.
    """

    def __init__(self, vocab_dir, inst_dir, sample_rate=44100, chunk_seconds=4.0):
        self.vocab_dir = vocab_dir
        self.inst_dir = inst_dir
        self.sample_rate = sample_rate
        self.chunk_size = int(sample_rate * chunk_seconds)
        vocab_files = {os.path.basename(f) for f in glob.glob(os.path.join(vocab_dir, "*.wav"))}
        inst_files = {os.path.basename(f) for f in glob.glob(os.path.join(inst_dir, "*.wav"))}
        # Sorted so that an index refers to the same file on every run; plain
        # set iteration order for strings changes between interpreter runs.
        self.matched_files = sorted(vocab_files & inst_files)
        if not self.matched_files:
            print("WARNING: No matching .wav files found!")

    def __len__(self):
        """Return the number of file names present in both folders."""
        return len(self.matched_files)

    def _crop_or_pad(self, audio):
        """Return exactly ``chunk_size`` samples: a random crop, or zero-padded at the end."""
        length = audio.shape[-1]
        if length > self.chunk_size:
            # randint's upper bound is exclusive, so +1 keeps the final valid
            # offset (the crop that ends on the last sample) reachable.
            start = torch.randint(0, length - self.chunk_size + 1, (1,)).item()
            return audio[:, start:start + self.chunk_size]
        return torch.nn.functional.pad(audio, (0, self.chunk_size - length))

    def _read_and_pad(self, path):
        """Load ``path`` as a two-channel excerpt of ``chunk_size`` samples."""
        import soundfile as sf
        data, sr = sf.read(path, always_2d=True)
        audio = torch.from_numpy(data.T).float()
        if sr != self.sample_rate:
            audio = torchaudio.functional.resample(audio, sr, self.sample_rate)
        if audio.shape[0] == 1:
            # Mono input: duplicate the only channel.
            audio = audio.repeat(2, 1)
        elif audio.shape[0] > 2:
            audio = audio[:2, :]
        return self._crop_or_pad(audio)

    def __getitem__(self, idx):
        """Return ``(mix, vocals)`` for the ``idx``-th matched file name."""
        filename = self.matched_files[idx]
        # NOTE(review): the two stems are cropped at independent random
        # offsets, so the mix is not time-aligned with the original song --
        # confirm this is intended as augmentation.
        vocals = self._read_and_pad(os.path.join(self.vocab_dir, filename))
        insts = self._read_and_pad(os.path.join(self.inst_dir, filename))
        mix = vocals + insts
        return mix, vocals
61
+
62
def train():
    """Fine-tune the KAN spline tensors of a RoKAN model from the command line.

    Builds the model described by ``--config``, loads ``--ckpt`` non-strictly
    when the file exists, freezes everything except the ``.spline_weight`` /
    ``.spline_gate`` tensors and trains them with AdamW (separate learning
    rates for gates and spline weights) on ``dataset/vocals`` +
    ``dataset/instrumentals``. A checkpoint is written to ``--output_dir``
    every ``--save_every`` epochs.
    """
    parser = argparse.ArgumentParser(description="BS-RoKAN Fine-Tuning")
    parser.add_argument("--config", required=True, help="Path to rokan.yaml")
    parser.add_argument("--ckpt", required=True, help="Path to rokan.ckpt")
    parser.add_argument("--output_dir", default="./", help="Where to save checkpoints")
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--gate_lr", type=float, default=1e-3)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--save_every", type=int, default=5)
    parser.add_argument("--num_workers", type=int, default=4)
    args = parser.parse_args()

    # autocast and GradScaler want the bare device type ("cuda"), while
    # --device may carry an index such as "cuda:0".
    device_type = torch.device(args.device).type

    # Load config
    with open(args.config, 'r') as f:
        # NOTE: FullLoader can construct some Python-specific tags (the
        # bundled configs use `!!python/tuple`); only load trusted configs.
        config = yaml.load(f, Loader=yaml.FullLoader)

    m_cfg = dict(config['model'])
    for k in ['freqs_per_bands', 'multi_stft_resolutions_window_sizes']:
        if k in m_cfg:
            m_cfg[k] = tuple(m_cfg[k])

    model_cls = MelBandRoformer if 'num_bands' in m_cfg else BSRoformer
    base_model = model_cls(**m_cfg)
    if os.path.exists(args.ckpt):
        # strict=False never raises on key mismatches, so report them.
        missing, unexpected = base_model.load_state_dict(
            torch.load(args.ckpt, map_location='cpu'), strict=False
        )
        print(f"[*] Loaded {args.ckpt} | missing: {len(missing)} unexpected: {len(unexpected)}")
    else:
        print(f"[!] Checkpoint not found: {args.ckpt} -- training from freshly initialised weights")
    base_model = base_model.to(args.device)

    set_requires_grad_selective(base_model)

    # Keep `base_model` around: the compiled wrapper shares its parameters but
    # prefixes every state-dict key with "_orig_mod.", which non-strict
    # loaders would silently fail to match.
    model = base_model
    if device_type == 'cuda' and hasattr(torch, 'compile'):
        try:
            model = torch.compile(base_model)
        except Exception as e:
            print(f"[!] torch.compile skipped: {e}")
    model.train()

    dataset = SimpleAudioDataset('dataset/vocals', 'dataset/instrumentals')
    if len(dataset) == 0:
        print("\n[!] Dataset empty. Exit.")
        return

    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=args.num_workers, pin_memory=True)
    num_batches = len(dataloader)
    if num_batches == 0:
        # drop_last=True discards the only (incomplete) batch in this case.
        print("\n[!] Fewer matched files than --batch_size; no batches to train on. Exit.")
        return

    gate_params = [p for n, p in base_model.named_parameters() if p.requires_grad and n.endswith('.spline_gate')]
    spline_params = [p for n, p in base_model.named_parameters() if p.requires_grad and n.endswith('.spline_weight')]
    optimizer = torch.optim.AdamW([
        {'params': gate_params, 'lr': args.gate_lr},
        {'params': spline_params, 'lr': args.lr},
    ], weight_decay=1e-4)

    try:
        from torch.amp import GradScaler
        scaler = GradScaler(device_type)
    except Exception as e:
        # Fall back to unscaled training rather than aborting.
        print(f"[!] GradScaler unavailable ({e}); training without loss scaling")
        scaler = None

    for epoch in range(1, args.epochs + 1):
        epoch_loss = 0.0
        for batch_idx, (mix, vocals) in enumerate(dataloader):
            mix = mix.to(args.device)
            vocals = vocals.to(args.device)
            optimizer.zero_grad()
            with torch.amp.autocast(device_type=device_type, dtype=torch.float16):
                # Called with a target, the model returns the training loss.
                loss = model(mix, target=vocals)
            if scaler is not None:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()
            epoch_loss += loss.item()

            # Short pause after every batch to reduce the load on the PC.
            time.sleep(0.2)

            if (batch_idx + 1) % 10 == 0:
                print(f"Epoch {epoch} | Batch {batch_idx+1}/{num_batches} | Loss: {loss.item():.4f}")

        print(f"==> Epoch {epoch} Average Loss: {epoch_loss/num_batches:.4f}")
        if epoch % args.save_every == 0:
            os.makedirs(args.output_dir, exist_ok=True)
            save_path = os.path.join(args.output_dir, f"checkpoint_ep{epoch}.ckpt")
            # Save from the uncompiled module so the keys carry no wrapper prefix.
            torch.save(base_model.state_dict(), save_path)
            # abs().mean() equals abs() for scalar gates and also works should
            # a gate hold more than one element.
            gate_vals = [
                p.detach().abs().mean().item()
                for n, p in base_model.named_parameters()
                if n.endswith('.spline_gate')
            ]
            avg_gate = sum(gate_vals) / len(gate_vals) if gate_vals else 0
            print(f"[*] Saved: {save_path} | Avg|gate|: {avg_gate:.4f}")


if __name__ == "__main__":
    train()