Upload 12 files
Browse fileshere is everything. Hope this leads to Good results.
- BS_Base_Model.ckpt +3 -0
- BS_Base_Model.yaml +197 -0
- MelBand Base Model.ckpt +3 -0
- MelBand Base Model.yaml +72 -0
- README.md +67 -3
- agent_monitor.py +108 -0
- convert_bs_to_rokan.py +219 -0
- eval_fidelity_report.md +24 -0
- evaluate_rokan_fidelity.py +236 -0
- requirements.txt +9 -0
- run_infer_rokan.py +217 -0
- train_rokan.py +141 -0
BS_Base_Model.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d0363fdc84906eb52c092b842c6dc1b231065d927604b35b6da6cbc1c38c28a6
|
| 3 |
+
size 1102136494
|
BS_Base_Model.yaml
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
audio:
|
| 2 |
+
chunk_size: 588800
|
| 3 |
+
dim_f: 1024
|
| 4 |
+
dim_t: 801
|
| 5 |
+
hop_length: 441
|
| 6 |
+
min_mean_abs: 0.0
|
| 7 |
+
n_fft: 2048
|
| 8 |
+
num_channels: 2
|
| 9 |
+
sample_rate: 44100
|
| 10 |
+
augmentations:
|
| 11 |
+
all:
|
| 12 |
+
channel_shuffle: 0.5
|
| 13 |
+
random_inverse: 0.1
|
| 14 |
+
random_polarity: 0.5
|
| 15 |
+
bass:
|
| 16 |
+
pitch_shift: 0.1
|
| 17 |
+
pitch_shift_max_semitones: 2
|
| 18 |
+
pitch_shift_min_semitones: -2
|
| 19 |
+
seven_band_parametric_eq: 0.1
|
| 20 |
+
seven_band_parametric_eq_max_gain_db: 6
|
| 21 |
+
seven_band_parametric_eq_min_gain_db: -3
|
| 22 |
+
tanh_distortion: 0.1
|
| 23 |
+
tanh_distortion_max: 0.5
|
| 24 |
+
tanh_distortion_min: 0.1
|
| 25 |
+
drums:
|
| 26 |
+
pitch_shift: 0.1
|
| 27 |
+
pitch_shift_max_semitones: 5
|
| 28 |
+
pitch_shift_min_semitones: -5
|
| 29 |
+
seven_band_parametric_eq: 0.1
|
| 30 |
+
seven_band_parametric_eq_max_gain_db: 9
|
| 31 |
+
seven_band_parametric_eq_min_gain_db: -9
|
| 32 |
+
tanh_distortion: 0.1
|
| 33 |
+
tanh_distortion_max: 0.6
|
| 34 |
+
tanh_distortion_min: 0.1
|
| 35 |
+
enable: true
|
| 36 |
+
loudness: true
|
| 37 |
+
loudness_max: 1.5
|
| 38 |
+
loudness_min: 0.5
|
| 39 |
+
mixup: true
|
| 40 |
+
mixup_loudness_max: 1.5
|
| 41 |
+
mixup_loudness_min: 0.5
|
| 42 |
+
mixup_probs: !!python/tuple
|
| 43 |
+
- 0.2
|
| 44 |
+
- 0.02
|
| 45 |
+
other:
|
| 46 |
+
gaussian_noise: 0.1
|
| 47 |
+
gaussian_noise_max_amplitude: 0.015
|
| 48 |
+
gaussian_noise_min_amplitude: 0.001
|
| 49 |
+
pitch_shift: 0.1
|
| 50 |
+
pitch_shift_max_semitones: 4
|
| 51 |
+
pitch_shift_min_semitones: -4
|
| 52 |
+
time_stretch: 0.1
|
| 53 |
+
time_stretch_max_rate: 1.25
|
| 54 |
+
time_stretch_min_rate: 0.8
|
| 55 |
+
vocals:
|
| 56 |
+
pitch_shift: 0.1
|
| 57 |
+
pitch_shift_max_semitones: 5
|
| 58 |
+
pitch_shift_min_semitones: -5
|
| 59 |
+
seven_band_parametric_eq: 0.1
|
| 60 |
+
seven_band_parametric_eq_max_gain_db: 9
|
| 61 |
+
seven_band_parametric_eq_min_gain_db: -9
|
| 62 |
+
tanh_distortion: 0.1
|
| 63 |
+
tanh_distortion_max: 0.7
|
| 64 |
+
tanh_distortion_min: 0.1
|
| 65 |
+
inference:
|
| 66 |
+
batch_size: 1
|
| 67 |
+
dim_t: 1101
|
| 68 |
+
normalize: false
|
| 69 |
+
num_overlap: 2
|
| 70 |
+
model:
|
| 71 |
+
attn_dropout: 0.1
|
| 72 |
+
depth: 12
|
| 73 |
+
dim: 256
|
| 74 |
+
dim_freqs_in: 1025
|
| 75 |
+
dim_head: 64
|
| 76 |
+
ff_dropout: 0.1
|
| 77 |
+
flash_attn: false
|
| 78 |
+
freq_transformer_depth: 1
|
| 79 |
+
freqs_per_bands:
|
| 80 |
+
- 2
|
| 81 |
+
- 2
|
| 82 |
+
- 2
|
| 83 |
+
- 2
|
| 84 |
+
- 2
|
| 85 |
+
- 2
|
| 86 |
+
- 2
|
| 87 |
+
- 2
|
| 88 |
+
- 2
|
| 89 |
+
- 2
|
| 90 |
+
- 2
|
| 91 |
+
- 2
|
| 92 |
+
- 2
|
| 93 |
+
- 2
|
| 94 |
+
- 2
|
| 95 |
+
- 2
|
| 96 |
+
- 2
|
| 97 |
+
- 2
|
| 98 |
+
- 2
|
| 99 |
+
- 2
|
| 100 |
+
- 2
|
| 101 |
+
- 2
|
| 102 |
+
- 2
|
| 103 |
+
- 2
|
| 104 |
+
- 4
|
| 105 |
+
- 4
|
| 106 |
+
- 4
|
| 107 |
+
- 4
|
| 108 |
+
- 4
|
| 109 |
+
- 4
|
| 110 |
+
- 4
|
| 111 |
+
- 4
|
| 112 |
+
- 4
|
| 113 |
+
- 4
|
| 114 |
+
- 4
|
| 115 |
+
- 4
|
| 116 |
+
- 12
|
| 117 |
+
- 12
|
| 118 |
+
- 12
|
| 119 |
+
- 12
|
| 120 |
+
- 12
|
| 121 |
+
- 12
|
| 122 |
+
- 12
|
| 123 |
+
- 12
|
| 124 |
+
- 24
|
| 125 |
+
- 24
|
| 126 |
+
- 24
|
| 127 |
+
- 24
|
| 128 |
+
- 24
|
| 129 |
+
- 24
|
| 130 |
+
- 24
|
| 131 |
+
- 24
|
| 132 |
+
- 48
|
| 133 |
+
- 48
|
| 134 |
+
- 48
|
| 135 |
+
- 48
|
| 136 |
+
- 48
|
| 137 |
+
- 48
|
| 138 |
+
- 48
|
| 139 |
+
- 48
|
| 140 |
+
- 128
|
| 141 |
+
- 129
|
| 142 |
+
heads: 8
|
| 143 |
+
kan_grid_size: 8
|
| 144 |
+
linear_transformer_depth: 0
|
| 145 |
+
mask_estimator_depth: 2
|
| 146 |
+
mlp_expansion_factor: 4
|
| 147 |
+
multi_stft_hop_size: 147
|
| 148 |
+
multi_stft_normalized: false
|
| 149 |
+
multi_stft_resolution_loss_weight: 1.0
|
| 150 |
+
multi_stft_resolutions_window_sizes:
|
| 151 |
+
- 4096
|
| 152 |
+
- 2048
|
| 153 |
+
- 1024
|
| 154 |
+
- 512
|
| 155 |
+
- 256
|
| 156 |
+
num_stems: 6
|
| 157 |
+
sage_attention: false
|
| 158 |
+
skip_connection: false
|
| 159 |
+
stereo: true
|
| 160 |
+
stft_hop_length: 512
|
| 161 |
+
stft_n_fft: 2048
|
| 162 |
+
stft_normalized: false
|
| 163 |
+
stft_win_length: 2048
|
| 164 |
+
time_transformer_depth: 1
|
| 165 |
+
use_kan: true
|
| 166 |
+
use_torch_checkpoint: false
|
| 167 |
+
training:
|
| 168 |
+
augmentation: false
|
| 169 |
+
augmentation_loudness: true
|
| 170 |
+
augmentation_loudness_max: 1.5
|
| 171 |
+
augmentation_loudness_min: 0.5
|
| 172 |
+
augmentation_loudness_type: 1
|
| 173 |
+
augmentation_mix: true
|
| 174 |
+
augmentation_type: simple1
|
| 175 |
+
batch_size: 2
|
| 176 |
+
coarse_loss_clip: true
|
| 177 |
+
ema_momentum: 0.999
|
| 178 |
+
grad_clip: 0
|
| 179 |
+
gradient_accumulation_steps: 1
|
| 180 |
+
instruments:
|
| 181 |
+
- bass
|
| 182 |
+
- drums
|
| 183 |
+
- other
|
| 184 |
+
- vocals
|
| 185 |
+
- guitar
|
| 186 |
+
- piano
|
| 187 |
+
lr: 1.0e-05
|
| 188 |
+
num_epochs: 1000
|
| 189 |
+
num_steps: 1000
|
| 190 |
+
optimizer: adam
|
| 191 |
+
other_fix: false
|
| 192 |
+
patience: 3
|
| 193 |
+
q: 0.95
|
| 194 |
+
reduce_factor: 0.95
|
| 195 |
+
target_instrument: null
|
| 196 |
+
use_amp: true
|
| 197 |
+
use_mp3_compress: false
|
MelBand Base Model.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b2a9652c40d90519a5708898b8c32b8f90666e1f8ef95890f91cced72dc22ac8
|
| 3 |
+
size 1366088139
|
MelBand Base Model.yaml
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
audio:
|
| 2 |
+
chunk_size: 352800
|
| 3 |
+
dim_f: 1024
|
| 4 |
+
dim_t: 256
|
| 5 |
+
hop_length: 441
|
| 6 |
+
min_mean_abs: 0
|
| 7 |
+
n_fft: 2048
|
| 8 |
+
num_channels: 2
|
| 9 |
+
sample_rate: 44100
|
| 10 |
+
inference:
|
| 11 |
+
batch_size: 2
|
| 12 |
+
dim_t: 256
|
| 13 |
+
num_overlap: 4
|
| 14 |
+
model:
|
| 15 |
+
attn_dropout: 0
|
| 16 |
+
depth: 6
|
| 17 |
+
dim: 384
|
| 18 |
+
dim_freqs_in: 1025
|
| 19 |
+
dim_head: 64
|
| 20 |
+
ff_dropout: 0
|
| 21 |
+
flash_attn: false
|
| 22 |
+
freq_transformer_depth: 1
|
| 23 |
+
heads: 8
|
| 24 |
+
kan_grid_size: 8
|
| 25 |
+
mask_estimator_depth: 2
|
| 26 |
+
multi_stft_hop_size: 147
|
| 27 |
+
multi_stft_normalized: false
|
| 28 |
+
multi_stft_resolution_loss_weight: 1.0
|
| 29 |
+
multi_stft_resolutions_window_sizes:
|
| 30 |
+
- 4096
|
| 31 |
+
- 2048
|
| 32 |
+
- 1024
|
| 33 |
+
- 512
|
| 34 |
+
- 256
|
| 35 |
+
num_bands: 60
|
| 36 |
+
num_stems: 1
|
| 37 |
+
sage_attention: false
|
| 38 |
+
sample_rate: 44100
|
| 39 |
+
stereo: true
|
| 40 |
+
stft_hop_length: 441
|
| 41 |
+
stft_n_fft: 2048
|
| 42 |
+
stft_normalized: false
|
| 43 |
+
stft_win_length: 2048
|
| 44 |
+
time_transformer_depth: 1
|
| 45 |
+
use_kan: true
|
| 46 |
+
use_torch_checkpoint: false
|
| 47 |
+
training:
|
| 48 |
+
augmentation: false
|
| 49 |
+
augmentation_loudness: false
|
| 50 |
+
augmentation_loudness_max: 0
|
| 51 |
+
augmentation_loudness_min: 0
|
| 52 |
+
augmentation_loudness_type: 1
|
| 53 |
+
augmentation_mix: false
|
| 54 |
+
augmentation_type: null
|
| 55 |
+
batch_size: 2
|
| 56 |
+
coarse_loss_clip: false
|
| 57 |
+
ema_momentum: 0.999
|
| 58 |
+
grad_clip: 0
|
| 59 |
+
gradient_accumulation_steps: 1
|
| 60 |
+
instruments:
|
| 61 |
+
- dry
|
| 62 |
+
- other
|
| 63 |
+
lr: 1.0e-05
|
| 64 |
+
num_epochs: 1000
|
| 65 |
+
num_steps: 4032
|
| 66 |
+
optimizer: adam
|
| 67 |
+
other_fix: false
|
| 68 |
+
patience: 8
|
| 69 |
+
q: 0.95
|
| 70 |
+
reduce_factor: 0.95
|
| 71 |
+
target_instrument: dry
|
| 72 |
+
use_mp3_compress: false
|
README.md
CHANGED
|
@@ -1,3 +1,67 @@
|
|
| 1 |
-
-
|
| 2 |
-
|
| 3 |
-
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Faster-RoKAN Core
|
| 2 |
+
|
| 3 |
+
Faster-RoKAN is a next-generation hybrid architecture that integrates Faster-KAN (Kolmogorov-Arnold Networks) into the BS-Roformer audio source separation model.
|
| 4 |
+
|
| 5 |
+
## Features
|
| 6 |
+
- **Isomorphic Conversion**: Convert standard BS-Roformer or MelBand-Roformer models to the RoKAN architecture with ZERO fidelity loss (MAE ≈ 0.0).
|
| 7 |
+
- **Faster-KAN (RSWAF)**: Replaces linear MLP layers with Reflectional Switch Wavelet Activation Functions for efficient, expressive, and detailed non-linear learning. High-frequency artifacts are filtered out through smooth geometric spline curves.
|
| 8 |
+
- **Gentle Training**: Optimized for standard consumer hardware with thermal management considerations.
|
| 9 |
+
|
| 10 |
+
## Includes Base Model
|
| 11 |
+
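To get you started immediately, we have included pre-converted base models in this package: **`BS_Base_Model.ckpt`** / **`BS_Base_Model.yaml`** (BS-Roformer) and **`MelBand Base Model.ckpt`** / **`MelBand Base Model.yaml`** (MelBand-Roformer). In the commands below, `Base_Model.ckpt` / `Base_Model.yaml` stand for whichever pair you choose.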
|
| 12 |
+
These base models already work out of the box. You can skip the conversion step entirely and jump straight to fine-tuning on your own dataset!
|
| 13 |
+
|
| 14 |
+
## Setup
|
| 15 |
+
```bash
|
| 16 |
+
pip install -r requirements.txt
|
| 17 |
+
```
|
| 18 |
+
|
| 19 |
+
## Usage
|
| 20 |
+
|
| 21 |
+
### 0. (Optional) How to Make Your Own RoKAN Model
|
| 22 |
+
If you want to use a different checkpoint rather than the provided `Base_Model`, you can convert your existing standard `.ckpt` to the RoKAN format automatically with `convert_bs_to_rokan.py`.
|
| 23 |
+
**(Note: You do NOT need to do this if you just want to use the included `Base_Model`.)**
|
| 24 |
+
|
| 25 |
+
```bash
|
| 26 |
+
python convert_bs_to_rokan.py \
|
| 27 |
+
--src_yaml dataset/Models/your_model.yaml \
|
| 28 |
+
--src_ckpt dataset/Models/your_model.ckpt \
|
| 29 |
+
--out_yaml converted/rokan.yaml \
|
| 30 |
+
--out_ckpt converted/rokan.ckpt
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
**How it works (For both BS & MelBand):**
|
| 34 |
+
The `convert_bs_to_rokan.py` script automatically analyzes your `.yaml` configuration to determine whether it is a **BS-Roformer** or a **MelBand-Roformer** (by checking for the `num_bands` parameter).
|
| 35 |
+
Depending on the architecture, it seamlessly intercepts the standard linear MLP components located inside the Siamese or Standard Transformer FeedForward blocks, and replaces them with our custom `FasterKANLinear` blocks. All base knowledge is perfectly preserved without any fidelity loss.
|
| 36 |
+
|
| 37 |
+
### 1. Fine-tuning
|
| 38 |
+
Train only the new KAN spline parameters on your dataset to remove high-frequency artifacts and teach the model geometric patterns. The script will automatically unfreeze *only* the new KAN parameters while keeping the base knowledge perfectly intact.
|
| 39 |
+
|
| 40 |
+
```bash
|
| 41 |
+
python train_rokan.py --ckpt_path Base_Model.ckpt --yaml_path Base_Model.yaml
|
| 42 |
+
```
|
| 43 |
+
*(Store your vocal audio in `dataset/vocals/` and instrumental audio in `dataset/instrumentals/` before running).*
|
| 44 |
+
|
| 45 |
+
### 2. Inference
|
| 46 |
+
Run source separation using the pre-tuned or fine-tuned model:
|
| 47 |
+
```bash
|
| 48 |
+
python run_infer_rokan.py \
|
| 49 |
+
--model_path Base_Model.ckpt \
|
| 50 |
+
--config_path Base_Model.yaml \
|
| 51 |
+
--input_audio your_song.wav
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
---
|
| 55 |
+
|
| 56 |
+
## Credits, Contact & Disclaimer
|
| 57 |
+
|
| 58 |
+
**All methods were created by Himadayon.**
|
| 59 |
+
**IMPORTANT:** If you release or distribute any models that utilize this architecture or are fine-tuned using this repository, you **must** explicitly credit `Himadayon` in your release notes or repository.
|
| 60 |
+
|
| 61 |
+
**Contact:**
|
| 62 |
+
If you have any questions or inquiries regarding this project, please send an email to:
|
| 63 |
+
📧 **Joker200702@gmail.com**
|
| 64 |
+
*(Please make sure to include a clear subject line and detailed contents in your email).*
|
| 65 |
+
|
| 66 |
+
**Disclaimer:**
|
| 67 |
+
For the purpose of experimental verification and architectural testing, existing base models originally developed by **unwa** and **Aname** were utilized during the development of this project.
|
agent_monitor.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import time
|
| 4 |
+
import subprocess
|
| 5 |
+
import json
|
| 6 |
+
import urllib.request
|
| 7 |
+
|
| 8 |
+
# ==========================================================
|
| 9 |
+
# Terminal Agent (Gemini API) for BS-RoKAN 監視
|
| 10 |
+
# VRAM消費: 0GB / CPU負荷: 極小
|
| 11 |
+
# ==========================================================
|
| 12 |
+
|
| 13 |
+
# Resolve the Gemini API key: a local key file takes precedence over the
# GEMINI_API_KEY environment variable. An empty string means "no key", in
# which case analyze_logs_with_llm() skips the remote call.
KEY_FILE = "APIKey From Google AI Studio.txt"
if not os.path.exists(KEY_FILE):
    API_KEY = os.environ.get("GEMINI_API_KEY", "")
else:
    with open(KEY_FILE, "r") as f:
        API_KEY = f.read().strip()

# Gemini model queried by the monitoring agent.
MODEL_NAME = "gemini-3.1-flash-lite"
|
| 22 |
+
|
| 23 |
+
def analyze_logs_with_llm(log_buffer):
    """Ask Gemini to judge recent training logs and return a control decision.

    Args:
        log_buffer: Iterable of log-line strings; they are joined with
            newlines and appended to the prompt.

    Returns:
        "LOWER_LR" or "RESTART" when the model's reply contains that word
        (checked in that order), otherwise "OK". "OK" is also returned when
        no API key is configured, or when the request or the parsing of the
        response fails, so a monitoring problem never interrupts training.
    """
    if not API_KEY:
        print("[Agent] API_KEYがないため判定をスキップ(OK)")
        return "OK"

    system_instruction = "あなたは音声分離モデルBS-RoKANの学習監視エージェントです。以下の学習ログを見て、学習が順調か評価してください。"
    prompt = f"{system_instruction} 出力は OK, LOWER_LR, RESTART のいずれか1語のみにしてください。 \n\nログ:\n" + "\n".join(log_buffer)

    # The key travels in the x-goog-api-key header instead of a ?key= query
    # parameter so it does not end up in URLs recorded by proxies or logs.
    url = f"https://generativelanguage.googleapis.com/v1beta/models/{MODEL_NAME}:generateContent"
    headers = {
        "Content-Type": "application/json",
        "x-goog-api-key": API_KEY,
    }

    # Gemini API (REST) request body.
    payload = {
        "contents": [{
            "parts": [{"text": prompt}]
        }],
        "generationConfig": {
            "temperature": 0.1,
            "maxOutputTokens": 10
        }
    }

    try:
        req = urllib.request.Request(url, data=json.dumps(payload).encode(), headers=headers)
        with urllib.request.urlopen(req, timeout=15) as r:
            response = json.loads(r.read())
        # Extract the reply text from the Gemini response structure.
        decision = response["candidates"][0]["content"]["parts"][0]["text"].strip().upper()

        if "LOWER_LR" in decision:
            return "LOWER_LR"
        if "RESTART" in decision:
            return "RESTART"
        return "OK"
    except Exception as e:
        # Deliberate best-effort: any network/format failure is reported and
        # treated as "OK" so the training run keeps going.
        print(f"[Agent] Gemini APIエラー: {e}")
        return "OK"
|
| 57 |
+
|
| 58 |
+
def main():
    """Supervise train_rokan.py, restarting it according to Gemini's verdicts.

    The training script is launched as a child process whose combined
    stdout/stderr is echoed line by line. Lines containing "Loss" or "Saved:"
    are buffered. Each time a "Saved:" line arrives with more than five
    buffered lines, the last 30 buffered lines are passed to
    analyze_logs_with_llm():

    * "LOWER_LR" - terminate the child and relaunch it with "--gate_lr 5e-4"
      (the flag is appended only once).
    * "RESTART"  - terminate the child, wait 5 seconds, and relaunch it.
    * otherwise  - keep training and clear the buffer.

    A child that ends with a non-zero exit code is relaunched after 10
    seconds. A child that exits on its own with code 0 ends the agent instead
    of being relaunched in an endless loop. Ctrl+C terminates the child and
    exits with status 0.
    """
    print(f"[*] Gemini Terminal Agent 起動成功 (Model: {MODEL_NAME})")
    print("[*] 学習プロセスを起動中...")

    # Assumes an RX 9070 XT under WSL2: start with batch size 2.
    # sys.executable keeps the child on the same interpreter (and virtualenv)
    # as the agent instead of whatever "python" resolves to on PATH.
    cmd = [sys.executable or "python", "-u", "train_rokan.py", "--batch_size", "2"]

    while True:
        print(f"\n[Agent] 訓練開始: {' '.join(cmd)}")
        process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1)
        log_buffer = []
        stopped_by_agent = False  # True once the agent itself terminated the child

        try:
            for line in process.stdout:
                line = line.strip()
                if not line:
                    continue
                print(line)

                if "Loss" in line or "Saved:" in line:
                    log_buffer.append(line)

                # Ask Gemini for a diagnosis at every save (end of an epoch).
                if "Saved:" in line and len(log_buffer) > 5:
                    decision = analyze_logs_with_llm(log_buffer[-30:])
                    if decision == "LOWER_LR":
                        print(f"[Agent] Geminiの判定: {decision} (学習率を下げて再開します)")
                        process.terminate()
                        stopped_by_agent = True
                        if "--gate_lr" not in cmd:
                            # NOTE(review): presumably halves a 1e-3 default in
                            # train_rokan.py - confirm against that script.
                            cmd.extend(["--gate_lr", "5e-4"])
                        break
                    elif decision == "RESTART":
                        print(f"[Agent] Geminiの判定: {decision} (異常検知につき再起動します)")
                        process.terminate()
                        stopped_by_agent = True
                        time.sleep(5)
                        break
                    else:
                        print(f"[Agent] Geminiの判定: {decision} (順調です)")
                        log_buffer = []  # clear the buffer

        except KeyboardInterrupt:
            print("\n[Agent] ユーザーによる中断。プロセスを終了します。")
            process.terminate()
            sys.exit(0)

        process.wait()
        if process.returncode == 0 and not stopped_by_agent:
            # Training completed by itself; relaunching would repeat the whole
            # run forever, so stop supervising here.
            print("[Agent] Training process finished successfully (Code: 0). Exiting.")
            return
        if process.returncode != 0 and process.returncode is not None:
            print(f"[Agent] 訓練プロセスが終了しました (Code: {process.returncode})。10秒後に再起動を試みます。")
            time.sleep(10)


if __name__ == "__main__":
    main()
|
convert_bs_to_rokan.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
convert_bs_to_rokan.py
|
| 3 |
+
=======================
|
| 4 |
+
Universal converter: Any standard BS-Roformer checkpoint → Faster-RoKAN
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
python convert_bs_to_rokan.py \\
|
| 8 |
+
--src_yaml dataset/Models/BS-Rofo-SW-Fixed.yaml \\
|
| 9 |
+
--src_ckpt dataset/Models/BS-Rofo-SW-Fixed.ckpt \\
|
| 10 |
+
--out_yaml bs_rokan_sw.yaml \\
|
| 11 |
+
--out_ckpt bs_rokan_sw.ckpt \\
|
| 12 |
+
--grid_size 8
|
| 13 |
+
|
| 14 |
+
What it does:
|
| 15 |
+
1. Reads the source YAML and builds a matching BSRoformer with use_kan=True
|
| 16 |
+
2. Loads the source checkpoint
|
| 17 |
+
3. Copies ALL compatible weights (Attention, norms, band-split, mask-estimator)
|
| 18 |
+
4. Remaps FeedForward Linear weights -> FasterKANLinear.base_weight
|
| 19 |
+
net.1.weight -> net.1.base_weight (first projection)
|
| 20 |
+
net.4.weight -> net.3.base_weight (second projection)
|
| 21 |
+
5. Saves the new Faster-RoKAN checkpoint + YAML
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
import os
|
| 25 |
+
import sys
|
| 26 |
+
import inspect
|
| 27 |
+
import argparse
|
| 28 |
+
import torch
|
| 29 |
+
import yaml
|
| 30 |
+
|
| 31 |
+
sys.path.insert(0, '/home/boss/BS-RoKAN-lab')
|
| 32 |
+
from models.bs_roformer.bs_roformer import BSRoformer
|
| 33 |
+
from models.bs_roformer.mel_band_roformer import MelBandRoformer
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# ── YAML helpers ──────────────────────────────────────────────────────────────
|
| 37 |
+
|
| 38 |
+
def load_yaml_fullloader(path):
    """Parse the YAML file at *path* with PyYAML's FullLoader and return the result."""
    with open(path, 'r') as stream:
        parsed = yaml.load(stream, Loader=yaml.FullLoader)
    return parsed
|
| 41 |
+
|
| 42 |
+
def load_yaml_strip_tags(path):
    """Fallback loader: delete every ``!!python/tuple`` tag, then safe-load the text."""
    with open(path, 'r') as stream:
        text = stream.read()
    # Without the tag the tagged sequences load as plain lists.
    return yaml.safe_load(text.replace('!!python/tuple', ''))
|
| 48 |
+
|
| 49 |
+
def load_yaml_any(path):
    """Load YAML with FullLoader, retrying with the tag-stripping loader on any error."""
    try:
        cfg = load_yaml_fullloader(path)
    except Exception:
        # Any FullLoader failure triggers the tag-stripping fallback.
        cfg = load_yaml_strip_tags(path)
    return cfg
|
| 54 |
+
|
| 55 |
+
def ensure_tuples(cfg):
    """Convert the known sequence fields of *cfg* to tuples, in place.

    Only 'freqs_per_bands' and 'multi_stft_resolutions_window_sizes' are
    touched, and only when present and not already tuples (beartype
    requirement). The same dict object is returned.
    """
    tuple_fields = ('freqs_per_bands', 'multi_stft_resolutions_window_sizes')
    for name in tuple_fields:
        if name not in cfg:
            continue
        value = cfg[name]
        if not isinstance(value, tuple):
            cfg[name] = tuple(value)
    return cfg
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# ── Checkpoint helpers ────────────────────────────────────────────────────────
|
| 64 |
+
|
| 65 |
+
def load_ckpt_flexible(path):
    """Load a checkpoint on CPU and return a flat ``{name: tensor}`` dict.

    If the loaded object is a dict wrapping the weights under 'state_dict'
    (checked first) or 'model', that inner mapping is used. A leading
    'model.' is removed from every key.
    """
    # NOTE(review): torch.load unpickles the file; only open checkpoints from
    # sources you trust.
    payload = torch.load(path, map_location='cpu')
    if isinstance(payload, dict):
        for wrapper_key in ('state_dict', 'model'):
            if wrapper_key in payload:
                payload = payload[wrapper_key]
                break

    prefix = 'model.'
    cleaned = {}
    for name, tensor in payload.items():
        if name.startswith(prefix):
            name = name[len(prefix):]
        cleaned[name] = tensor
    return cleaned
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# ── Model builder ─────────────────────────────────────────────────────────────
|
| 77 |
+
|
| 78 |
+
def build_rokan(src_cfg, grid_size):
    """Instantiate a Faster-RoKAN model mirroring the source model config.

    A shallow copy of *src_cfg* gets its tuple fields normalised and the KAN
    flags forced on (``use_kan=True``, ``kan_grid_size=grid_size``), with
    flash attention, torch checkpointing and sage attention disabled.
    MelBandRoformer is chosen when the config has 'num_bands', otherwise
    BSRoformer. Keys the chosen constructor does not accept are dropped.
    """
    cfg = ensure_tuples(dict(src_cfg))  # work on a copy of the mapping
    cfg.update(
        use_kan=True,
        kan_grid_size=grid_size,
        flash_attn=False,  # disabled for stability during conversion
        use_torch_checkpoint=False,
        sage_attention=False,
    )

    if 'num_bands' in cfg:
        model_cls = MelBandRoformer
    else:
        model_cls = BSRoformer

    # Pass only the keyword arguments the constructor actually declares.
    accepted = set(inspect.signature(model_cls.__init__).parameters) - {'self'}
    kwargs = {name: value for name, value in cfg.items() if name in accepted}
    return model_cls(**kwargs)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
# ── Weight mapping ────────────────────────────────────────────────────────────
|
| 97 |
+
|
| 98 |
+
def remap_and_load(src_sd, model):
    """Load source weights into a Faster-RoKAN model and return the model.

    For every tensor in *src_sd*:
      - a key that exists in the model with the same shape is copied as-is;
      - otherwise FeedForward Linear tensors are renamed by suffix:
          *.net.1.weight -> *.net.1.base_weight
          *.net.4.weight -> *.net.3.base_weight
          *.net.1.bias   -> *.net.1.base_bias
          *.net.4.bias   -> *.net.3.base_bias
        and copied when the renamed key exists with the same shape;
      - anything else is counted as skipped.

    Model tensors that receive no source value keep their initial values.
    A short summary is printed.

    Args:
        src_sd: Mapping of parameter name to tensor from the source checkpoint.
        model: Object exposing ``state_dict()`` and ``load_state_dict()``.
    """
    # (source suffix, target suffix) pairs; only the trailing suffix is
    # rewritten, so an identical fragment earlier in a key is left untouched.
    suffix_map = (
        ('.net.1.weight', '.net.1.base_weight'),
        ('.net.4.weight', '.net.3.base_weight'),
        ('.net.1.bias', '.net.1.base_bias'),
        ('.net.4.bias', '.net.3.base_bias'),
    )

    model_dict = model.state_dict()
    matched = {}
    remapped = 0
    skipped = []

    for k, v in src_sd.items():
        # Direct match
        if k in model_dict and v.shape == model_dict[k].shape:
            matched[k] = v
            continue

        # Remap FF Linear -> base_weight / base_bias
        remap = None
        for old_suffix, new_suffix in suffix_map:
            if k.endswith(old_suffix):
                remap = k[:-len(old_suffix)] + new_suffix
                break

        if remap is not None and remap in model_dict and v.shape == model_dict[remap].shape:
            matched[remap] = v
            remapped += 1
        else:
            skipped.append(k)

    model_dict.update(matched)
    model.load_state_dict(model_dict)

    print(f" Loaded: {len(matched)} tensors")
    print(f" Remapped: {remapped} FF Linear → base_weight")
    print(f" Skipped: {len(skipped)} (biases, incompatible shapes)")

    # Report which model tensors received no source value (left at init).
    kan_random = [k for k in model_dict if k not in matched]
    kan_types = set(k.split('.')[-1] for k in kan_random)
    print(f" KAN init: {len(kan_random)} tensors types={kan_types}")
    return model
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
# ── YAML writer ───────────────────────────────────────────────────────────────
|
| 150 |
+
|
| 151 |
+
def write_out_yaml(src_yaml_path, out_yaml_path, grid_size):
    """Write the source config to *out_yaml_path* with the RoKAN flags set.

    The 'model' section gets ``use_kan=True`` and ``kan_grid_size=grid_size``,
    with ``flash_attn``, ``use_torch_checkpoint`` and ``sage_attention`` set
    to False. All other sections are written back as loaded.

    Args:
        src_yaml_path: Path of the source model YAML.
        out_yaml_path: Destination path for the converted YAML.
        grid_size: Faster-KAN grid size to record in the config.
    """
    # Use the same tolerant loader as main(): a config that only parses via
    # the tag-stripping fallback would otherwise fail here, after main() has
    # already saved the converted checkpoint.
    raw = load_yaml_any(src_yaml_path)
    model_cfg = raw['model']
    model_cfg['use_kan'] = True
    model_cfg['kan_grid_size'] = grid_size
    model_cfg['flash_attn'] = False
    model_cfg['use_torch_checkpoint'] = False
    model_cfg['sage_attention'] = False

    # Store tuple fields as plain lists so the dumped YAML carries no
    # python-specific tuple tags.
    for key in ('freqs_per_bands', 'multi_stft_resolutions_window_sizes'):
        if key in model_cfg and isinstance(model_cfg[key], tuple):
            model_cfg[key] = list(model_cfg[key])

    with open(out_yaml_path, 'w') as f:
        yaml.dump(raw, f, default_flow_style=False, allow_unicode=True)
    print(f" Wrote YAML: {out_yaml_path}")
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
# ── Main ──────────────────────────────────────────────────────────────────────
|
| 171 |
+
|
| 172 |
+
def main():
    """Command-line entry point: convert a Roformer checkpoint to Faster-RoKAN.

    Steps: load the source YAML, build the matching KAN-enabled model, load
    and remap the source checkpoint into it, then save the new checkpoint and
    the updated YAML. Missing parent directories of the output paths are
    created before saving, so paths such as ``converted/rokan.ckpt`` work
    without a manual ``mkdir``.
    """
    parser = argparse.ArgumentParser(description='Convert BS-Roformer → Faster-RoKAN')
    parser.add_argument('--src_yaml', required=True, help='Source model YAML')
    parser.add_argument('--src_ckpt', required=True, help='Source model checkpoint (.ckpt)')
    parser.add_argument('--out_yaml', default='bs_rokan_converted.yaml', help='Output YAML path')
    parser.add_argument('--out_ckpt', default='bs_rokan_converted.ckpt', help='Output checkpoint path')
    parser.add_argument('--grid_size', type=int, default=8, help='Faster-KAN grid size (wavelet count)')
    args = parser.parse_args()

    print(f"\n[*] BS-Roformer → Faster-RoKAN Converter")
    print(f" src_yaml : {args.src_yaml}")
    print(f" src_ckpt : {args.src_ckpt}")
    print(f" out_yaml : {args.out_yaml}")
    print(f" out_ckpt : {args.out_ckpt}")
    print(f" grid_size: {args.grid_size}\n")

    # 1. Load source config
    print("[1/4] Loading source YAML...")
    src_raw = load_yaml_any(args.src_yaml)
    src_cfg = src_raw['model']
    src_cfg = ensure_tuples(src_cfg)
    print(f" dim={src_cfg['dim']}, depth={src_cfg['depth']}, stereo={src_cfg.get('stereo')}")

    # 2. Build Faster-RoKAN model
    print("\n[2/4] Building Faster-RoKAN model...")
    model = build_rokan(src_cfg, args.grid_size)
    total_params = sum(p.numel() for p in model.parameters()) / 1e6
    print(f" Model built. Parameters: {total_params:.1f}M")

    # 3. Load & remap weights
    print("\n[3/4] Loading source checkpoint and remapping weights...")
    src_sd = load_ckpt_flexible(args.src_ckpt)
    print(f" Source checkpoint has {len(src_sd)} tensors")
    model = remap_and_load(src_sd, model)

    # 4. Save
    print("\n[4/4] Saving Faster-RoKAN...")
    # Create missing output directories first; otherwise the save fails only
    # after all of the conversion work has been done.
    for out_path in (args.out_ckpt, args.out_yaml):
        os.makedirs(os.path.dirname(os.path.abspath(out_path)), exist_ok=True)

    torch.save(model.state_dict(), args.out_ckpt)
    print(f" Saved checkpoint: {args.out_ckpt}")

    write_out_yaml(args.src_yaml, args.out_yaml, args.grid_size)

    print("\n[*] Conversion complete!")
    print(f" Inference: MODEL_YAML={args.out_yaml} MODEL_CKPT={args.out_ckpt} python run_infer_rokan.py")
    print(f" Training : python train_rokan.py (update ckpt_path in script to {args.out_ckpt})")


if __name__ == '__main__':
    main()
|
eval_fidelity_report.md
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# RoKAN Fidelity Report
|
| 2 |
+
|
| 3 |
+
- input_wav: `input/Arctic Tundra.wav`
|
| 4 |
+
- device: `cuda`
|
| 5 |
+
|
| 6 |
+
## BS-Roformer
|
| 7 |
+
- status: OK
|
| 8 |
+
- sample_rate: 44100
|
| 9 |
+
- audio_seconds: 152.63
|
| 10 |
+
- teacher_infer_sec: 30.59
|
| 11 |
+
- rokan_infer_sec: 102.77
|
| 12 |
+
- mae: 0.00000000
|
| 13 |
+
- rmse: 0.00000000
|
| 14 |
+
- max_abs: 0.00000004
|
| 15 |
+
|
| 16 |
+
## MelBand-Roformer
|
| 17 |
+
- status: OK
|
| 18 |
+
- sample_rate: 44100
|
| 19 |
+
- audio_seconds: 152.63
|
| 20 |
+
- teacher_infer_sec: 21.30
|
| 21 |
+
- rokan_infer_sec: 79.54
|
| 22 |
+
- mae: 0.00000384
|
| 23 |
+
- rmse: 0.00000723
|
| 24 |
+
- max_abs: 0.00013021
|
evaluate_rokan_fidelity.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import time
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import soundfile as sf
|
| 6 |
+
import torch
|
| 7 |
+
import torchaudio.functional as AF
|
| 8 |
+
import yaml
|
| 9 |
+
|
| 10 |
+
from models.bs_roformer.bs_roformer import BSRoformer
|
| 11 |
+
from models.bs_roformer.mel_band_roformer import MelBandRoformer
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def load_cfg(path: Path):
    """Parse the YAML file at ``path`` and return the resulting object.

    The file is read as UTF-8 and parsed with ``yaml.FullLoader``.

    NOTE(review): presumably the full loader is used because model configs
    may contain tags the safe loader rejects -- confirm before switching to
    ``yaml.safe_load``. Only load configs from trusted sources.
    """
    with path.open("r", encoding="utf-8") as handle:
        parsed = yaml.load(handle, Loader=yaml.FullLoader)
    return parsed
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def clean_state_dict(ckpt_path: Path):
    """Load a checkpoint on CPU and return a flat ``{name: value}`` mapping.

    If the loaded object is a dict wrapping the weights under ``"state_dict"``
    and/or ``"model"``, those wrappers are peeled off in that order. A leading
    ``"model."`` on any key is then stripped.

    NOTE: ``torch.load`` unpickles the file; only use trusted checkpoints.
    """
    payload = torch.load(str(ckpt_path), map_location="cpu")

    # Unwrap nested containers; the order matters ("state_dict" first).
    for wrapper_key in ("state_dict", "model"):
        if isinstance(payload, dict) and wrapper_key in payload:
            payload = payload[wrapper_key]

    prefix = "model."
    return {
        (name[len(prefix):] if name.startswith(prefix) else name): value
        for name, value in payload.items()
    }
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def build_model_from_yaml(yaml_path: Path):
    """Instantiate the model described by a YAML config.

    The config must provide ``model.dim``, ``model.depth`` and
    ``audio.sample_rate``; every other hyper-parameter falls back to the
    default listed below. ``flash_attn`` and ``use_torch_checkpoint`` are
    always forced to ``False``. A ``num_bands`` entry in the model section
    selects ``MelBandRoformer``; otherwise ``BSRoformer`` is built.

    Returns:
        ``(model, sample_rate)`` where ``sample_rate`` is
        ``cfg["audio"]["sample_rate"]``.
    """
    cfg = load_cfg(yaml_path)
    model_cfg = cfg["model"]
    audio_cfg = cfg["audio"]

    # Optional hyper-parameters and the value used when the YAML omits them.
    optional_defaults = {
        "stereo": True,
        "num_stems": 1,
        "time_transformer_depth": 1,
        "freq_transformer_depth": 1,
        "linear_transformer_depth": 0,
        "dim_head": 64,
        "heads": 8,
        "attn_dropout": 0.0,
        "ff_dropout": 0.0,
        "dim_freqs_in": 1025,
        "stft_n_fft": 2048,
        "stft_hop_length": 512,
        "stft_win_length": 2048,
        "stft_normalized": False,
        "mask_estimator_depth": 2,
        "multi_stft_resolution_loss_weight": 1.0,
        "multi_stft_hop_size": 147,
        "multi_stft_normalized": False,
        "mlp_expansion_factor": 4,
        "skip_connection": False,
        "sage_attention": False,
        "use_kan": False,
        "kan_grid_size": 8,
    }

    kwargs = {"dim": model_cfg["dim"], "depth": model_cfg["depth"]}
    for key, fallback in optional_defaults.items():
        kwargs[key] = model_cfg.get(key, fallback)

    # Sequence-valued options are normalised to tuples.
    kwargs["multi_stft_resolutions_window_sizes"] = tuple(
        model_cfg.get("multi_stft_resolutions_window_sizes", (4096, 2048, 1024, 512, 256))
    )
    if "freqs_per_bands" in model_cfg:
        kwargs["freqs_per_bands"] = tuple(model_cfg["freqs_per_bands"])

    # Fixed regardless of what the config says.
    kwargs["flash_attn"] = False
    kwargs["use_torch_checkpoint"] = False

    if "num_bands" in model_cfg:
        kwargs["num_bands"] = model_cfg.get("num_bands", 60)
        kwargs["sample_rate"] = model_cfg.get("sample_rate", audio_cfg.get("sample_rate", 44100))
        return MelBandRoformer(**kwargs), audio_cfg["sample_rate"]

    return BSRoformer(**kwargs), audio_cfg["sample_rate"]
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def load_audio(path: Path, target_sr: int):
    """Read an audio file and return it as a 2-channel float tensor with a batch dim.

    The samples are transposed so that dim 0 is the channel axis, resampled
    to ``target_sr`` when the file's rate differs, and forced to exactly two
    channels: mono is duplicated, extra channels beyond the first two are
    dropped. The result has a leading batch dimension of size 1.
    """
    samples, native_sr = sf.read(str(path), always_2d=True)
    waveform = torch.from_numpy(samples.T).float()

    if native_sr != target_sr:
        waveform = AF.resample(waveform, native_sr, target_sr)

    channels = waveform.shape[0]
    if channels == 1:
        waveform = waveform.repeat(2, 1)
    elif channels > 2:
        waveform = waveform[:2, :]

    return waveform.unsqueeze(0)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def infer_chunked(model, audio, chunk_size=353280, context=132096):
    """Run ``model`` over ``audio`` in overlapping windows and stitch the centres.

    The input is replicate-padded by ``context`` samples on both sides. Each
    window is ``chunk_size`` samples long; only the central
    ``chunk_size - 2 * context`` samples of each model output are written to
    the result, so every output sample is produced with ``context`` samples
    of surrounding input on each side.

    Args:
        model: callable applied to each window.
        audio: 3-D tensor; the last dimension is time.
        chunk_size: window length fed to the model.
        context: number of samples discarded at each edge of a window.

    Returns:
        Tensor of shape ``(out.shape[0], out.shape[1], audio.shape[-1])`` on
        ``audio.device``, or ``None`` if the loop never runs (zero-length time
        axis that gets past padding).

    Raises:
        RuntimeError: if ``chunk_size <= 2 * context`` or the model output is
            neither 3-D nor 4-D.
    """
    hop = chunk_size - 2 * context
    if hop <= 0:
        raise RuntimeError("chunk_size must be larger than 2*context")

    pad_fn = torch.nn.functional.pad
    total = audio.shape[-1]
    padded = pad_fn(audio, (context, context), mode="replicate")
    result = None

    for start in range(0, total, hop):
        stop = min(start + hop, total)
        keep = stop - start

        window = padded[:, :, start : start + chunk_size]
        shortfall = chunk_size - window.shape[-1]
        if shortfall > 0:
            # Last window: extend to the full length the model is fed everywhere else.
            window = pad_fn(window, (0, shortfall), mode="replicate")

        with torch.inference_mode():
            if audio.is_cuda:
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    pred = model(window)
            else:
                pred = model(window)

        # Normalise to [B, C, T]; a 4-D output is treated as [B, N, C, T]
        # (multi-stem) and only the first stem is kept.
        if pred.ndim == 4:
            pred = pred[:, 0, :, :]
        elif pred.ndim != 3:
            raise RuntimeError(f"Unsupported output ndim={pred.ndim}, shape={tuple(pred.shape)}")

        if result is None:
            result = torch.zeros((pred.shape[0], pred.shape[1], total), device=audio.device)

        result[:, :, start:stop] = pred[:, :, context : context + keep]

    return result
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def eval_pair(name, teacher_yaml, teacher_ckpt, rokan_yaml, rokan_ckpt, wav_path):
    """Compare a teacher model against its RoKAN conversion on one audio file.

    Both models are built from their YAML configs, loaded from their
    checkpoints, run over the same audio with ``infer_chunked`` and compared
    sample by sample.

    Args:
        name: label used in error messages and in the returned row.
        teacher_yaml, teacher_ckpt: config / checkpoint of the reference model.
        rokan_yaml, rokan_ckpt: config / checkpoint of the converted model.
        wav_path: audio file to run both models on.

    Returns:
        dict with keys ``name``, ``sample_rate``, ``audio_seconds``,
        ``teacher_sec``, ``rokan_sec``, ``mae``, ``rmse`` and ``max_abs``.

    Raises:
        RuntimeError: if the two configs disagree on the sample rate.
    """
    t_model, t_sr = build_model_from_yaml(teacher_yaml)
    r_model, r_sr = build_model_from_yaml(rokan_yaml)
    if t_sr != r_sr:
        raise RuntimeError(f"{name}: sample rate mismatch {t_sr} vs {r_sr}")

    # strict=False tolerates key mismatches, so surface them: a fidelity
    # number is meaningless if part of a model kept its initial weights.
    for label, net, ckpt in (("teacher", t_model, teacher_ckpt), ("rokan", r_model, rokan_ckpt)):
        missing, unexpected = net.load_state_dict(clean_state_dict(ckpt), strict=False)
        if missing or unexpected:
            print(f"[{name}] {label}: missing={len(missing)} unexpected={len(unexpected)} keys")

    t_model = t_model.to(DEVICE).eval()
    r_model = r_model.to(DEVICE).eval()

    audio = load_audio(wav_path, t_sr).to(DEVICE)

    def _timed_infer(net):
        # CUDA work is queued asynchronously; synchronise before reading the
        # clock so the interval covers the work actually executed.
        if DEVICE == "cuda":
            torch.cuda.synchronize()
        tic = time.perf_counter()
        result = infer_chunked(net, audio)
        if DEVICE == "cuda":
            torch.cuda.synchronize()
        return result, time.perf_counter() - tic

    t_out, t_sec = _timed_infer(t_model)
    r_out, r_sec = _timed_infer(r_model)

    diff = (t_out - r_out).float()
    abs_diff = diff.abs()
    return {
        "name": name,
        "sample_rate": t_sr,
        "audio_seconds": float(audio.shape[-1]) / float(t_sr),
        "teacher_sec": t_sec,
        "rokan_sec": r_sec,
        "mae": abs_diff.mean().item(),
        "rmse": torch.sqrt((diff ** 2).mean()).item(),
        "max_abs": abs_diff.max().item(),
    }
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def main():
    """Evaluate every configured teacher/RoKAN pair and write a Markdown report.

    The input file is ``--input_wav`` if given, otherwise the first ``*.wav``
    (sorted by name) in the ``input`` directory next to this script. A pair
    with missing files, or whose evaluation raises, is recorded as a FAIL row
    instead of aborting the run. The report is written to
    ``converted_models/eval_fidelity_report.md`` next to this script.

    Raises:
        RuntimeError: if no input wav can be found.
    """
    parser = argparse.ArgumentParser(description="Evaluate teacher vs RoKAN fidelity for BS and MelBand models")
    parser.add_argument("--input_wav", type=str, default="")
    args = parser.parse_args()

    root = Path(__file__).resolve().parent
    input_dir = root / "input"
    wav_path = Path(args.input_wav) if args.input_wav else None
    if wav_path is None:
        wavs = sorted(input_dir.glob("*.wav"))
        if not wavs:
            raise RuntimeError("No wav in input/. Set --input_wav explicitly.")
        wav_path = wavs[0]
    if not wav_path.exists():
        raise RuntimeError(f"Input wav not found: {wav_path}")

    # (label, teacher yaml, teacher ckpt, rokan yaml, rokan ckpt)
    pairs = [
        (
            "BS-Rofo-SW-Fixed",
            root / "dataset/Models/BS-Rofo-SW-Fixed.yaml",
            root / "dataset/Models/BS-Rofo-SW-Fixed.ckpt",
            root / "converted_models/BS-Rofo-SW-Fixed_rokan.yaml",
            root / "converted_models/BS-Rofo-SW-Fixed_rokan.ckpt",
        ),
        (
            "MelBand denoise",
            root / "dataset/Models/denoise_mel_band_roformer_aufr33_sdr_27.9959.yaml",
            root / "dataset/Models/denoise_mel_band_roformer_aufr33_sdr_27.9959.ckpt",
            root / "converted_models/denoise_mel_band_roformer_aufr33_sdr_27.9959_rokan.yaml",
            root / "converted_models/denoise_mel_band_roformer_aufr33_sdr_27.9959_rokan.ckpt",
        ),
    ]

    rows = []
    for name, ty, tc, ry, rc in pairs:
        missing = [str(p) for p in (ty, tc, ry, rc) if not p.exists()]
        if missing:
            rows.append({"name": name, "error": "missing files: " + ", ".join(missing)})
            continue
        try:
            rows.append(eval_pair(name, ty, tc, ry, rc, wav_path))
        except Exception as e:
            # Deliberate best effort: one failing pair must not hide the others.
            rows.append({"name": name, "error": str(e)})

    out_path = root / "converted_models" / "eval_fidelity_report.md"
    # The directory may not exist yet (e.g. nothing has been converted); create
    # it so the FAIL report can still be written.
    out_path.parent.mkdir(parents=True, exist_ok=True)

    lines = [
        "# RoKAN Fidelity Report",
        "",
        f"- input_wav: `{wav_path}`",
        f"- device: `{DEVICE}`",
        "",
    ]
    for r in rows:
        lines.append(f"## {r['name']}")
        if "error" in r:
            lines.append("- status: FAIL")
            lines.append(f"- error: `{r['error']}`")
        else:
            lines.append("- status: OK")
            lines.append(f"- sample_rate: {r['sample_rate']}")
            lines.append(f"- audio_seconds: {r['audio_seconds']:.2f}")
            lines.append(f"- teacher_infer_sec: {r['teacher_sec']:.2f}")
            lines.append(f"- rokan_infer_sec: {r['rokan_sec']:.2f}")
            lines.append(f"- mae: {r['mae']:.8f}")
            lines.append(f"- rmse: {r['rmse']:.8f}")
            lines.append(f"- max_abs: {r['max_abs']:.8f}")
        lines.append("")

    out_path.write_text("\n".join(lines), encoding="utf-8")
    print(f"wrote: {out_path}")
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
if __name__ == "__main__":
|
| 235 |
+
main()
|
| 236 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.0.0
|
| 2 |
+
torchaudio
|
| 3 |
+
einops
|
| 4 |
+
rotary-embedding-torch
|
| 5 |
+
librosa
|
| 6 |
+
soundfile
|
| 7 |
+
pyyaml
|
| 8 |
+
beartype
|
| 9 |
+
tqdm
|
run_infer_rokan.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import yaml
|
| 3 |
+
import torch
|
| 4 |
+
import soundfile as sf
|
| 5 |
+
import torchaudio.functional as AF
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
from models.bs_roformer.bs_roformer import BSRoformer
|
| 9 |
+
from models.bs_roformer.mel_band_roformer import MelBandRoformer
|
| 10 |
+
|
| 11 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 12 |
+
|
| 13 |
+
MODEL_YAML = os.environ.get("MODEL_YAML", "bs_rokan.yaml")
|
| 14 |
+
MODEL_CKPT = os.environ.get("MODEL_CKPT", "bs_rokan.ckpt")
|
| 15 |
+
INPUT_DIR = os.environ.get("INPUT_DIR", os.path.expanduser("~/BS-RoKAN-lab/input"))
|
| 16 |
+
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", os.path.expanduser("~/BS-RoKAN-lab/RoKAN output"))
|
| 17 |
+
|
| 18 |
+
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
| 19 |
+
|
| 20 |
+
def _env_int(name: str, default: int) -> int:
|
| 21 |
+
v = os.environ.get(name, str(default)).strip()
|
| 22 |
+
try:
|
| 23 |
+
return int(v)
|
| 24 |
+
except Exception:
|
| 25 |
+
return default
|
| 26 |
+
|
| 27 |
+
def _env_bool(name: str, default: bool) -> bool:
|
| 28 |
+
v = os.environ.get(name)
|
| 29 |
+
if v is None:
|
| 30 |
+
return default
|
| 31 |
+
return v.strip().lower() in ("1", "true", "yes", "y", "on")
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def load_model():
    """Build the model described by ``MODEL_YAML`` and load ``MODEL_CKPT`` into it.

    ``MelBandRoformer`` is used when the model section of the config has a
    ``num_bands`` entry, ``BSRoformer`` otherwise. The checkpoint may be a
    bare state dict or wrap one under ``"state_dict"`` / ``"model"``; a
    leading ``"model."`` on keys is stripped. Loading is non-strict and the
    number of missing / unexpected keys is printed.

    Environment switches (CUDA only for TF32):
        INFER_TF32: allow TF32 matmul/cuDNN kernels (default on).
        INFER_COMPILE: wrap the model with ``torch.compile`` (default off).

    Returns:
        ``(model, sample_rate)`` with ``sample_rate`` taken from
        ``cfg["audio"]["sample_rate"]``.
    """
    # Explicit UTF-8 so parsing does not depend on the platform default encoding.
    with open(MODEL_YAML, "r", encoding="utf-8") as f:
        cfg = yaml.load(f, Loader=yaml.FullLoader)

    model_cfg = cfg["model"]
    audio_cfg = cfg["audio"]

    kwargs = dict(
        dim=model_cfg["dim"],
        depth=model_cfg["depth"],
        stereo=model_cfg.get("stereo", True),
        num_stems=model_cfg.get("num_stems", 1),
        time_transformer_depth=model_cfg.get("time_transformer_depth", 1),
        freq_transformer_depth=model_cfg.get("freq_transformer_depth", 1),
        linear_transformer_depth=model_cfg.get("linear_transformer_depth", 0),
        # NOTE(review): this fallback band layout is passed to whichever model
        # class is selected below, including MelBandRoformer -- confirm that
        # class accepts ``freqs_per_bands``.
        freqs_per_bands=tuple(model_cfg.get("freqs_per_bands", (
            2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
            2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
            2, 2, 2, 2,
            4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
            12, 12, 12, 12, 12, 12, 12, 12,
            24, 24, 24, 24, 24, 24, 24, 24,
            48, 48, 48, 48, 48, 48, 48, 48,
            128, 129,
        ))),
        dim_head=model_cfg.get("dim_head", 64),
        heads=model_cfg.get("heads", 8),
        attn_dropout=model_cfg.get("attn_dropout", 0.0),
        ff_dropout=model_cfg.get("ff_dropout", 0.0),
        flash_attn=False,
        dim_freqs_in=model_cfg.get("dim_freqs_in", 1025),
        stft_n_fft=model_cfg.get("stft_n_fft", 2048),
        stft_hop_length=model_cfg.get("stft_hop_length", 512),
        stft_win_length=model_cfg.get("stft_win_length", 2048),
        stft_normalized=model_cfg.get("stft_normalized", False),
        mask_estimator_depth=model_cfg.get("mask_estimator_depth", 2),
        multi_stft_resolution_loss_weight=model_cfg.get("multi_stft_resolution_loss_weight", 1.0),
        multi_stft_resolutions_window_sizes=tuple(
            model_cfg.get("multi_stft_resolutions_window_sizes", (4096, 2048, 1024, 512, 256))
        ),
        multi_stft_hop_size=model_cfg.get("multi_stft_hop_size", 147),
        multi_stft_normalized=model_cfg.get("multi_stft_normalized", False),
        mlp_expansion_factor=model_cfg.get("mlp_expansion_factor", 4),
        use_torch_checkpoint=model_cfg.get("use_torch_checkpoint", False),
        skip_connection=model_cfg.get("skip_connection", False),
        sage_attention=model_cfg.get("sage_attention", False),
        use_kan=model_cfg.get("use_kan", False),
        kan_grid_size=model_cfg.get("kan_grid_size", 5),
    )

    print("Building model...")
    model_cls = MelBandRoformer if "num_bands" in model_cfg else BSRoformer
    if model_cls is MelBandRoformer:
        kwargs["num_bands"] = model_cfg.get("num_bands", 60)
        kwargs["sample_rate"] = model_cfg.get("sample_rate", audio_cfg.get("sample_rate", 44100))
    model = model_cls(**kwargs).to(DEVICE)
    model.eval()

    print("Loading checkpoint...")
    # NOTE: torch.load unpickles the file; only use trusted checkpoints.
    ckpt = torch.load(MODEL_CKPT, map_location="cpu")

    if "state_dict" in ckpt:
        state = ckpt["state_dict"]
    elif "model" in ckpt:
        state = ckpt["model"]
    else:
        state = ckpt

    prefix = "model."
    clean_state = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state.items()
    }

    missing, unexpected = model.load_state_dict(clean_state, strict=False)
    print("missing:", len(missing), "unexpected:", len(unexpected))

    # Optional inference optimizations, controlled by environment variables:
    # TF32 is enabled unless INFER_TF32 disables it; torch.compile is opt-in.
    if DEVICE == "cuda":
        if _env_bool("INFER_TF32", True):
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
            try:
                torch.set_float32_matmul_precision("high")
            except Exception:
                # Best effort: the precision hint is optional.
                pass

    if _env_bool("INFER_COMPILE", False) and hasattr(torch, "compile"):
        try:
            model = torch.compile(model)
            print("torch.compile enabled")
        except Exception as e:
            print(f"torch.compile skipped: {e}")

    return model, audio_cfg["sample_rate"]
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def load_audio(path: str, target_sr: int) -> torch.Tensor:
    """Read an audio file as a 2-channel float tensor on ``DEVICE``.

    The samples are transposed so that dim 0 is the channel axis, resampled
    to ``target_sr`` when the file's rate differs, and forced to exactly two
    channels (mono is duplicated, channels past the second are dropped). A
    batch dimension of size 1 is prepended before moving to ``DEVICE``.
    """
    samples, native_sr = sf.read(path, always_2d=True)
    waveform = torch.from_numpy(samples.T).float()

    if native_sr != target_sr:
        waveform = AF.resample(waveform, native_sr, target_sr)

    num_channels = waveform.shape[0]
    if num_channels == 1:
        waveform = waveform.repeat(2, 1)
    elif num_channels > 2:
        waveform = waveform[:2, :]

    batched = waveform.unsqueeze(0)
    return batched.to(DEVICE)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def separate_with_context(model: torch.nn.Module, audio: torch.Tensor) -> torch.Tensor:
    """Run ``model`` over ``audio`` in overlapping windows and stitch the centres.

    The input is replicate-padded by ``context`` samples on both sides; each
    window is ``chunk_size`` samples long and only its central
    ``chunk_size - 2 * context`` output samples are kept.

    Environment variables:
        INFER_CHUNK_SIZE: window length (default 353280).
        INFER_CONTEXT: samples discarded at each window edge (default 132096).
        INFER_AMP: use float16 autocast on CUDA (default on).

    Returns:
        Tensor of shape ``(1, audio.shape[1], audio.shape[-1])`` on ``DEVICE``.

    Raises:
        RuntimeError: if ``chunk_size <= 2 * context`` or the model output is
            neither 3-D nor 4-D.
    """
    # Tunable via env for Colab optimization / VRAM tradeoffs
    chunk_size = _env_int("INFER_CHUNK_SIZE", 353280)
    context = _env_int("INFER_CONTEXT", 132096)
    center_size = chunk_size - 2 * context

    if center_size <= 0:
        raise RuntimeError("chunk_size must be larger than 2 * context")

    # Loop-invariant: read the switch once instead of once per chunk.
    use_amp = DEVICE == "cuda" and _env_bool("INFER_AMP", True)

    audio_len = audio.shape[-1]
    padded = F.pad(audio, (context, context), mode="replicate")
    output = torch.zeros((1, audio.shape[1], audio_len), device=DEVICE)

    pos = 0
    while pos < audio_len:
        center_end = min(pos + center_size, audio_len)
        valid_len = center_end - pos

        chunk = padded[:, :, pos:pos + chunk_size]
        if chunk.shape[-1] < chunk_size:
            # Last window: extend to the full length fed to the model elsewhere.
            pad = chunk_size - chunk.shape[-1]
            chunk = F.pad(chunk, (0, pad), mode="replicate")

        with torch.inference_mode():
            if use_amp:
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    out_chunk = model(chunk)
            else:
                out_chunk = model(chunk)

        # Normalise to [B, C, T]. A 4-D output is treated as [B, N, C, T]
        # (multi-stem) and the first stem is kept; without this the time
        # slice below would be applied to the channel axis.
        if out_chunk.ndim == 4:
            out_chunk = out_chunk[:, 0, :, :]
        elif out_chunk.ndim != 3:
            raise RuntimeError(f"Unsupported output ndim={out_chunk.ndim}, shape={tuple(out_chunk.shape)}")

        output[:, :, pos:center_end] = out_chunk[:, :, context:context + valid_len]

        pos += center_size

    return output
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def main():
    """Separate every ``.wav`` in ``INPUT_DIR`` and write results to ``OUTPUT_DIR``.

    Each output file keeps the input's file name and is written at the
    model's sample rate. When ``INFER_EMPTY_CACHE`` is enabled on CUDA, the
    CUDA cache is released after each file.

    Raises:
        RuntimeError: if ``INPUT_DIR`` contains no ``.wav`` files.
    """
    model, sample_rate_target = load_model()

    wav_files = [entry for entry in os.listdir(INPUT_DIR) if entry.lower().endswith(".wav")]
    if not wav_files:
        raise RuntimeError(f"No wav files found in input folder: {INPUT_DIR}")

    for wav_name in wav_files:
        src = os.path.join(INPUT_DIR, wav_name)
        dst = os.path.join(OUTPUT_DIR, wav_name)

        print(f"Processing: {wav_name}")

        audio = load_audio(src, sample_rate_target)
        separated = separate_with_context(model, audio)

        # Drop the batch dim and transpose back before handing to soundfile.
        frames = separated.squeeze(0).detach().cpu().T.numpy()
        sf.write(dst, frames, sample_rate_target)

        # Drop references before the next file is loaded.
        del audio, separated, frames
        if DEVICE == "cuda" and _env_bool("INFER_EMPTY_CACHE", False):
            torch.cuda.empty_cache()

        print(f"Saved: {dst}")

    print("All done.")
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
if __name__ == "__main__":
|
| 217 |
+
main()
|
train_rokan.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import glob
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from torch.utils.data import Dataset, DataLoader
|
| 6 |
+
import torchaudio
|
| 7 |
+
import yaml
|
| 8 |
+
import argparse
|
| 9 |
+
import time
|
| 10 |
+
from models.bs_roformer.bs_roformer import BSRoformer
|
| 11 |
+
from models.bs_roformer.mel_band_roformer import MelBandRoformer
|
| 12 |
+
|
| 13 |
+
def set_requires_grad_selective(model):
    """Freeze ``model`` except for its KAN spline tensors.

    A parameter stays trainable only if its qualified name ends with
    ``.spline_weight`` or ``.spline_gate``; every other parameter gets
    ``requires_grad = False``. The number of trainable tensors is printed.

    Returns:
        The same ``model`` object, modified in place.
    """
    trainable_suffixes = ('.spline_weight', '.spline_gate')
    unfrozen_count = 0
    for param_name, param in model.named_parameters():
        keep_trainable = param_name.endswith(trainable_suffixes)
        param.requires_grad = keep_trainable
        if keep_trainable:
            unfrozen_count += 1
    print(f"[*] Training: Unfroze {unfrozen_count} KAN tensors")
    return model
|
| 23 |
+
|
| 24 |
+
class SimpleAudioDataset(Dataset):
    """Dataset of (mix, vocals) excerpts built from two directories of stems.

    Files are paired by base name: only ``.wav`` files present in both
    ``vocab_dir`` and ``inst_dir`` are used. Each item is a fixed-length
    2-channel excerpt of ``chunk_seconds`` seconds at ``sample_rate``.
    """

    def __init__(self, vocab_dir, inst_dir, sample_rate=44100, chunk_seconds=4.0):
        self.vocab_dir = vocab_dir
        self.inst_dir = inst_dir
        self.sample_rate = sample_rate
        self.chunk_size = int(sample_rate * chunk_seconds)
        vocab_files = {os.path.basename(f) for f in glob.glob(os.path.join(vocab_dir, "*.wav"))}
        inst_files = {os.path.basename(f) for f in glob.glob(os.path.join(inst_dir, "*.wav"))}
        # Sorted so that index -> file is stable across runs (set order is not).
        self.matched_files = sorted(vocab_files & inst_files)
        if not self.matched_files:
            print("WARNING: No matching .wav files found!")

    def __len__(self):
        return len(self.matched_files)

    def _read_and_pad(self, path):
        """Load ``path`` as a ``(2, chunk_size)`` float tensor.

        The audio is resampled to ``self.sample_rate`` if needed and forced
        to two channels. Longer files are cropped at a random offset; shorter
        ones are zero-padded at the end.
        """
        import soundfile as sf

        data, sr = sf.read(path, always_2d=True)
        audio = torch.from_numpy(data.T).float()
        if sr != self.sample_rate:
            audio = torchaudio.functional.resample(audio, sr, self.sample_rate)

        if audio.shape[0] == 1:
            audio = audio.repeat(2, 1)
        elif audio.shape[0] > 2:
            audio = audio[:2, :]

        if audio.shape[-1] > self.chunk_size:
            # randint's upper bound is exclusive; +1 makes the final window
            # (the one ending at the last sample) reachable too.
            start = torch.randint(0, audio.shape[-1] - self.chunk_size + 1, (1,)).item()
            audio = audio[:, start:start + self.chunk_size]
        else:
            pad = self.chunk_size - audio.shape[-1]
            audio = torch.nn.functional.pad(audio, (0, pad))
        return audio

    def __getitem__(self, idx):
        """Return ``(mix, vocals)`` for the ``idx``-th matched file."""
        filename = self.matched_files[idx]
        # Each stem is cropped with its own random offset, so the two excerpts
        # are generally not time-aligned; the mix is still exactly the sum of
        # the returned vocals and the instrumental excerpt.
        vocals = self._read_and_pad(os.path.join(self.vocab_dir, filename))
        insts = self._read_and_pad(os.path.join(self.inst_dir, filename))
        mix = vocals + insts
        return mix, vocals
|
| 61 |
+
|
| 62 |
+
def train():
    """Fine-tune the KAN spline tensors of a RoKAN model from the command line.

    Builds the model from ``--config``, loads ``--ckpt`` non-strictly, freezes
    everything except ``.spline_weight`` / ``.spline_gate`` parameters and
    trains those with AdamW (separate learning rates for gates and splines)
    on pairs from ``dataset/vocals`` and ``dataset/instrumentals``. A
    checkpoint is written to ``--output_dir`` every ``--save_every`` epochs.
    """
    parser = argparse.ArgumentParser(description="BS-RoKAN Fine-Tuning")
    parser.add_argument("--config", required=True, help="Path to rokan.yaml")
    parser.add_argument("--ckpt", required=True, help="Path to rokan.ckpt")
    parser.add_argument("--output_dir", default="./", help="Where to save checkpoints")
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--gate_lr", type=float, default=1e-3)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--save_every", type=int, default=5)
    parser.add_argument("--num_workers", type=int, default=4)
    args = parser.parse_args()

    # "cuda:0" -> "cuda": autocast and GradScaler take a device *type*.
    device_type = torch.device(args.device).type

    # Load config
    with open(args.config, 'r', encoding='utf-8') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    m_cfg = dict(config['model'])
    for k in ['freqs_per_bands', 'multi_stft_resolutions_window_sizes']:
        if k in m_cfg:
            m_cfg[k] = tuple(m_cfg[k])

    model_cls = MelBandRoformer if 'num_bands' in m_cfg else BSRoformer
    base_model = model_cls(**m_cfg)
    if os.path.exists(args.ckpt):
        base_model.load_state_dict(torch.load(args.ckpt, map_location='cpu'), strict=False)
    else:
        print(f"[!] Checkpoint not found: {args.ckpt} -- continuing with freshly initialised weights")
    base_model = base_model.to(args.device)
    base_model = set_requires_grad_selective(base_model)

    # Keep a handle on the uncompiled module: the torch.compile wrapper shares
    # its parameters but prefixes state_dict keys with "_orig_mod.", which the
    # checkpoint loaders in this project do not strip.
    model = base_model
    if device_type == 'cuda' and hasattr(torch, 'compile'):
        try:
            model = torch.compile(base_model)
        except Exception as e:
            print(f"torch.compile skipped: {e}")
    model.train()

    dataset = SimpleAudioDataset('dataset/vocals', 'dataset/instrumentals')
    if len(dataset) == 0:
        print("\n[!] Dataset empty. Exit.")
        return

    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=args.num_workers, pin_memory=True)

    gate_params = [p for n, p in base_model.named_parameters() if p.requires_grad and n.endswith('.spline_gate')]
    spline_params = [p for n, p in base_model.named_parameters() if p.requires_grad and n.endswith('.spline_weight')]
    optimizer = torch.optim.AdamW([
        {'params': gate_params, 'lr': args.gate_lr},
        {'params': spline_params, 'lr': args.lr},
    ], weight_decay=1e-4)

    # Best effort: fall back to unscaled training if the scaler is unavailable.
    try:
        from torch.amp import GradScaler
        scaler = GradScaler(device_type)
    except Exception:
        scaler = None

    for epoch in range(1, args.epochs + 1):
        epoch_loss = 0.0
        for batch_idx, (mix, vocals) in enumerate(dataloader):
            mix = mix.to(args.device)
            vocals = vocals.to(args.device)
            optimizer.zero_grad()
            with torch.amp.autocast(device_type=device_type, dtype=torch.float16):
                loss = model(mix, target=vocals)
            if scaler:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()
            epoch_loss += loss.item()

            # Short pause between batches to reduce the load on the PC.
            time.sleep(0.2)

            if (batch_idx+1) % 10 == 0:
                print(f"Epoch {epoch} | Batch {batch_idx+1}/{len(dataloader)} | Loss: {loss.item():.4f}")

        # drop_last=True can leave zero batches when the dataset is smaller
        # than the batch size; avoid dividing by zero in that case.
        num_batches = max(len(dataloader), 1)
        print(f"==> Epoch {epoch} Average Loss: {epoch_loss/num_batches:.4f}")
        if epoch % args.save_every == 0:
            os.makedirs(args.output_dir, exist_ok=True)
            save_path = os.path.join(args.output_dir, f"checkpoint_ep{epoch}.ckpt")
            torch.save(base_model.state_dict(), save_path)
            gate_vals = [p.item() for n, p in base_model.named_parameters() if n.endswith('.spline_gate')]
            avg_gate = sum(abs(v) for v in gate_vals) / len(gate_vals) if gate_vals else 0
            print(f"[*] Saved: {save_path} | Avg|gate|: {avg_gate:.4f}")
|
| 139 |
+
|
| 140 |
+
if __name__ == "__main__":
|
| 141 |
+
train()
|