| import os |
| import glob |
| import torch |
| import torch.nn as nn |
| from torch.utils.data import Dataset, DataLoader |
| import torchaudio |
| import yaml |
| import argparse |
| import time |
| from models.bs_roformer.bs_roformer import BSRoformer |
| from models.bs_roformer.mel_band_roformer import MelBandRoformer |
|
|
def set_requires_grad_selective(model):
    """Freeze every parameter of *model* except the KAN spline tensors.

    Parameters whose qualified name ends in ``.spline_weight`` or
    ``.spline_gate`` stay trainable; every other parameter gets
    ``requires_grad = False``.

    Args:
        model: A ``torch.nn.Module``; modified in place.

    Returns:
        The same *model* object, so the call can be chained.
    """
    kan_suffixes = ('.spline_weight', '.spline_gate')

    # Pass 1: freeze everything.
    for tensor in model.parameters():
        tensor.requires_grad = False

    # Pass 2: re-enable gradients only for the KAN tensors (matched by name suffix).
    kan_tensors = [t for qual_name, t in model.named_parameters() if qual_name.endswith(kan_suffixes)]
    for tensor in kan_tensors:
        tensor.requires_grad = True

    print(f"[*] Training: Unfroze {len(kan_tensors)} KAN tensors")
    return model
|
|
class SimpleAudioDataset(Dataset):
    """Random fixed-length ``(mix, vocals)`` pairs built from paired stem files.

    A file is used only when a ``.wav`` with the same basename exists in both
    *vocab_dir* and *inst_dir*. Each item is ``(vocals + instrumental, vocals)``,
    each tensor shaped ``(2, chunk_size)``.

    Args:
        vocab_dir: Directory containing vocal stems (``*.wav``).
        inst_dir: Directory containing instrumental stems (``*.wav``).
        sample_rate: Target sample rate; files at other rates are resampled.
        chunk_seconds: Length of every returned excerpt, in seconds.
    """

    def __init__(self, vocab_dir, inst_dir, sample_rate=44100, chunk_seconds=4.0):
        self.vocab_dir = vocab_dir
        self.inst_dir = inst_dir
        self.sample_rate = sample_rate
        self.chunk_size = int(sample_rate * chunk_seconds)
        vocab_files = self._wav_basenames(vocab_dir)
        inst_files = self._wav_basenames(inst_dir)
        # Sorted so the index -> file mapping is reproducible across runs
        # (iteration order of a set of str depends on hash randomization).
        self.matched_files = sorted(vocab_files & inst_files)
        if not self.matched_files:
            print("WARNING: No matching .wav files found!")

    @staticmethod
    def _wav_basenames(directory):
        """Return the set of ``*.wav`` basenames found directly inside *directory*."""
        return {os.path.basename(p) for p in glob.glob(os.path.join(directory, "*.wav"))}

    def __len__(self):
        return len(self.matched_files)

    def _fit_to_chunk(self, audio):
        """Make *audio* (channels, samples) stereo and exactly ``chunk_size`` samples long.

        Mono input is duplicated to two channels and channels beyond the first
        two are dropped. Longer signals are cropped at a uniformly random
        offset; shorter ones are zero-padded at the end.
        """
        if audio.shape[0] == 1:
            audio = audio.repeat(2, 1)
        elif audio.shape[0] > 2:
            audio = audio[:2, :]
        length = audio.shape[-1]
        if length > self.chunk_size:
            # torch.randint's upper bound is exclusive; +1 makes the last
            # window (start == length - chunk_size) reachable.
            start = torch.randint(0, length - self.chunk_size + 1, (1,)).item()
            return audio[:, start:start + self.chunk_size]
        return torch.nn.functional.pad(audio, (0, self.chunk_size - length))

    def _read_and_pad(self, path):
        """Load *path*, resample it to ``sample_rate`` and fit it to one chunk."""
        import soundfile as sf  # lazy import, as in the original script

        data, sr = sf.read(path, always_2d=True)
        audio = torch.from_numpy(data.T).float()  # transposed: channels first
        if sr != self.sample_rate:
            audio = torchaudio.functional.resample(audio, sr, self.sample_rate)
        return self._fit_to_chunk(audio)

    def __getitem__(self, idx):
        filename = self.matched_files[idx]
        # Each stem is cropped with its own random offset, so the vocals and
        # instrumental summed below are generally not time-aligned.
        vocals = self._read_and_pad(os.path.join(self.vocab_dir, filename))
        insts = self._read_and_pad(os.path.join(self.inst_dir, filename))
        mix = vocals + insts
        return mix, vocals
|
|
def _parse_args():
    """Build and parse the command-line interface used by ``train``."""
    parser = argparse.ArgumentParser(description="BS-RoKAN Fine-Tuning")
    parser.add_argument("--config", required=True, help="Path to rokan.yaml")
    parser.add_argument("--ckpt", required=True, help="Path to rokan.ckpt")
    parser.add_argument("--output_dir", default="./", help="Where to save checkpoints")
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--gate_lr", type=float, default=1e-3)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--save_every", type=int, default=5,
                        help="Save every N epochs (<= 0: only after the final epoch)")
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--vocals_dir", default="dataset/vocals", help="Directory of vocal stems (*.wav)")
    parser.add_argument("--inst_dir", default="dataset/instrumentals", help="Directory of instrumental stems (*.wav)")
    parser.add_argument("--sample_rate", type=int, default=44100)
    parser.add_argument("--chunk_seconds", type=float, default=4.0)
    parser.add_argument("--batch_sleep", type=float, default=0.0,
                        help="Optional pause in seconds after each batch (was hard-coded to 0.2)")
    return parser.parse_args()


def _unwrap(model):
    """Return the module behind a ``torch.compile`` wrapper, or *model* itself."""
    return getattr(model, "_orig_mod", model)


def _load_weights(model, ckpt_path):
    """Non-strictly load the state dict stored at *ckpt_path* into *model*.

    Keys saved from a ``torch.compile``-wrapped model carry an ``_orig_mod.``
    prefix; it is stripped so such checkpoints still match a plain model.
    Missing/unexpected key counts are printed because ``strict=False`` would
    otherwise hide a wholesale mismatch.
    """
    if not os.path.exists(ckpt_path):
        print(f"[!] Checkpoint not found: {ckpt_path} -- continuing with randomly initialised weights")
        return
    state = torch.load(ckpt_path, map_location="cpu")
    prefix = "_orig_mod."
    state = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state.items()}
    result = model.load_state_dict(state, strict=False)
    print(f"[*] Loaded {ckpt_path} | missing keys: {len(result.missing_keys)} "
          f"| unexpected keys: {len(result.unexpected_keys)}")


def _build_model(config_path, ckpt_path, device):
    """Instantiate the RoFormer described by the YAML config, load weights, move to *device*."""
    with open(config_path, "r") as f:
        # NOTE(review): FullLoader kept from the original (rather than safe_load),
        # presumably for Python-specific tags such as !!python/tuple; only load
        # config files you trust.
        config = yaml.load(f, Loader=yaml.FullLoader)

    m_cfg = dict(config["model"])
    # The model constructors receive these as tuples rather than YAML lists.
    for key in ("freqs_per_bands", "multi_stft_resolutions_window_sizes"):
        if key in m_cfg:
            m_cfg[key] = tuple(m_cfg[key])

    model_cls = MelBandRoformer if "num_bands" in m_cfg else BSRoformer
    model = model_cls(**m_cfg)
    _load_weights(model, ckpt_path)
    return model.to(device)


def _build_optimizer(model, lr, gate_lr):
    """Create AdamW over the trainable KAN tensors, or return None if there are none.

    ``*.spline_gate`` tensors use *gate_lr*; ``*.spline_weight`` tensors use *lr*.
    """
    trainable = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
    gate_params = [p for n, p in trainable if n.endswith(".spline_gate")]
    spline_params = [p for n, p in trainable if n.endswith(".spline_weight")]
    groups = [
        {"params": params, "lr": group_lr}
        for params, group_lr in ((gate_params, gate_lr), (spline_params, lr))
        if params
    ]
    if not groups:
        return None
    return torch.optim.AdamW(groups, weight_decay=1e-4)


def _make_grad_scaler(device_type, enabled):
    """Create a GradScaler (a pass-through when *enabled* is False)."""
    try:
        from torch.amp import GradScaler
    except ImportError:  # older PyTorch only ships the CUDA-specific scaler
        from torch.cuda.amp import GradScaler as CudaGradScaler
        return CudaGradScaler(enabled=enabled)
    return GradScaler(device_type, enabled=enabled)


def _autocast(device_type, enabled):
    """Return an fp16 autocast context on *device_type*, or a no-op context when disabled."""
    from contextlib import nullcontext

    if not enabled:
        return nullcontext()
    return torch.amp.autocast(device_type=device_type, dtype=torch.float16)


def _save_checkpoint(model, output_dir, epoch):
    """Write the model weights for *epoch* and print the mean absolute gate value."""
    os.makedirs(output_dir, exist_ok=True)
    save_path = os.path.join(output_dir, f"checkpoint_ep{epoch}.ckpt")
    # Save the unwrapped module so keys carry no "_orig_mod." prefix from torch.compile.
    torch.save(_unwrap(model).state_dict(), save_path)
    # Mean |gate| per gate tensor (equals abs(p.item()) for scalar gates), then averaged.
    gate_vals = [p.detach().abs().mean().item() for n, p in model.named_parameters() if n.endswith(".spline_gate")]
    avg_gate = sum(gate_vals) / len(gate_vals) if gate_vals else 0
    print(f"[*] Saved: {save_path} | Avg|gate|: {avg_gate:.4f}")


def train():
    """Fine-tune only the KAN spline tensors of a BS-/MelBand-RoFormer.

    Settings come from the command line. The model described by ``--config``
    is loaded with weights from ``--ckpt``. Every parameter except
    ``*.spline_weight`` / ``*.spline_gate`` is frozen, and the rest are trained
    on ``(mix, vocals)`` pairs from ``SimpleAudioDataset``. The forward call
    ``model(mix, target=vocals)`` is used directly as the loss. fp16 autocast
    and gradient scaling are used only on CUDA devices. A checkpoint is
    written every ``--save_every`` epochs and always after the final epoch.
    """
    args = _parse_args()
    device = torch.device(args.device)
    # autocast / GradScaler expect a device *type* ("cuda"), not e.g. "cuda:1".
    device_type = device.type
    on_cuda = device_type == "cuda"

    model = _build_model(args.config, args.ckpt, device)
    if on_cuda and hasattr(torch, "compile"):
        try:
            model = torch.compile(model)
        except Exception as exc:  # best effort: keep the eager model
            print(f"[!] torch.compile failed ({exc}); continuing without it")

    model = set_requires_grad_selective(model)
    model.train()

    optimizer = _build_optimizer(model, args.lr, args.gate_lr)
    if optimizer is None:
        print("\n[!] No trainable *.spline_weight / *.spline_gate parameters found. Exit.")
        return

    dataset = SimpleAudioDataset(args.vocals_dir, args.inst_dir,
                                 sample_rate=args.sample_rate, chunk_seconds=args.chunk_seconds)
    if len(dataset) == 0:
        print("\n[!] Dataset empty. Exit.")
        return

    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True,
                            num_workers=args.num_workers, pin_memory=on_cuda)
    num_batches = len(dataloader)
    if num_batches == 0:
        # drop_last=True discards the only (partial) batch.
        print(f"\n[!] {len(dataset)} file(s) but batch_size={args.batch_size}: no full batch. Exit.")
        return

    scaler = _make_grad_scaler(device_type, enabled=on_cuda)

    for epoch in range(1, args.epochs + 1):
        epoch_loss = 0.0
        for batch_idx, (mix, vocals) in enumerate(dataloader):
            mix = mix.to(device, non_blocking=True)
            vocals = vocals.to(device, non_blocking=True)
            optimizer.zero_grad()
            with _autocast(device_type, enabled=on_cuda):
                loss = model(mix, target=vocals)
            # When the scaler is disabled these calls reduce to backward() + step().
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            epoch_loss += loss.item()

            if args.batch_sleep > 0:
                time.sleep(args.batch_sleep)

            if (batch_idx + 1) % 10 == 0:
                print(f"Epoch {epoch} | Batch {batch_idx+1}/{num_batches} | Loss: {loss.item():.4f}")

        print(f"==> Epoch {epoch} Average Loss: {epoch_loss/num_batches:.4f}")
        periodic = args.save_every > 0 and epoch % args.save_every == 0
        if periodic or epoch == args.epochs:
            _save_checkpoint(model, args.output_dir, epoch)


if __name__ == "__main__":
    train()
|
|