"""Fine-tuning script ("BS-RoKAN") for BSRoformer / MelBandRoformer models.

Only parameters whose names end in ``.spline_weight`` or ``.spline_gate`` are
trained; every other parameter of the model is frozen.
"""
import argparse
import glob
import os
import time

import torch
import torch.nn as nn
import torchaudio
import yaml
from torch.utils.data import DataLoader, Dataset

from models.bs_roformer.bs_roformer import BSRoformer
from models.bs_roformer.mel_band_roformer import MelBandRoformer

# Parameter-name suffixes that stay trainable during fine-tuning.
_GATE_SUFFIX = '.spline_gate'
_SPLINE_SUFFIX = '.spline_weight'
_TRAINABLE_SUFFIXES = (_SPLINE_SUFFIX, _GATE_SUFFIX)

# Prefix that torch.compile's wrapper module adds to every state_dict key.
_COMPILE_PREFIX = '_orig_mod.'


def set_requires_grad_selective(model):
    """Freeze every parameter except the spline weight / gate tensors.

    Args:
        model: Module whose parameters are modified in place.

    Returns:
        The same ``model`` object, for call chaining.
    """
    unfrozen_count = 0
    for name, param in model.named_parameters():
        trainable = name.endswith(_TRAINABLE_SUFFIXES)
        param.requires_grad = trainable
        if trainable:
            unfrozen_count += 1
    print(f"[*] Training: Unfroze {unfrozen_count} KAN tensors")
    return model


class SimpleAudioDataset(Dataset):
    """Pairs of vocal / instrumental ``.wav`` files matched by file name.

    Each item is a ``(mix, vocals)`` tuple of stereo tensors with exactly
    ``chunk_size`` samples, where ``mix`` is the sum of both stems.
    """

    def __init__(self, vocab_dir, inst_dir, sample_rate=44100, chunk_seconds=4.0):
        """Index the files that exist under both directories.

        Args:
            vocab_dir: Directory containing the vocal stems.
            inst_dir: Directory containing the instrumental stems.
            sample_rate: Rate every file is resampled to when it differs.
            chunk_seconds: Length of each returned excerpt, in seconds.
        """
        self.vocab_dir = vocab_dir
        self.inst_dir = inst_dir
        self.sample_rate = sample_rate
        self.chunk_size = int(sample_rate * chunk_seconds)

        # Sorted so that index -> file is stable between runs (set iteration
        # order of strings is not).
        self.matched_files = sorted(
            self._wav_names(vocab_dir) & self._wav_names(inst_dir)
        )
        if not self.matched_files:
            print("WARNING: No matching .wav files found!")

    @staticmethod
    def _wav_names(directory):
        """Return the set of ``*.wav`` base names directly inside ``directory``."""
        return {
            os.path.basename(path)
            for path in glob.glob(os.path.join(directory, "*.wav"))
        }

    def __len__(self):
        """Number of file names present in both directories."""
        return len(self.matched_files)

    def _read_and_pad(self, path):
        """Load ``path`` as a ``(2, chunk_size)`` float32 tensor.

        The audio is resampled if needed, forced to two channels (mono is
        duplicated, extra channels are dropped), then either randomly cropped
        or zero-padded on the right to ``chunk_size`` samples.
        """
        # Imported lazily, as in the original, so the module can be imported
        # without soundfile installed.
        import soundfile as sf

        data, sr = sf.read(path, always_2d=True)
        audio = torch.from_numpy(data.T).float()  # (frames, ch) -> (ch, frames)

        if sr != self.sample_rate:
            audio = torchaudio.functional.resample(audio, sr, self.sample_rate)

        if audio.shape[0] == 1:
            audio = audio.repeat(2, 1)
        elif audio.shape[0] > 2:
            audio = audio[:2, :]

        n_samples = audio.shape[-1]
        if n_samples > self.chunk_size:
            # randint's upper bound is exclusive; +1 makes the final valid
            # start offset reachable.
            last_start = n_samples - self.chunk_size
            start = torch.randint(0, last_start + 1, (1,)).item()
            audio = audio[:, start:start + self.chunk_size]
        else:
            pad = self.chunk_size - n_samples
            audio = torch.nn.functional.pad(audio, (0, pad))
        return audio

    def __getitem__(self, idx):
        """Return ``(mix, vocals)`` for the ``idx``-th matched file."""
        filename = self.matched_files[idx]
        # NOTE(review): the two stems are cropped at independent random
        # offsets, so the mix is not time-aligned with the original song.
        # Presumably intentional augmentation -- confirm.
        vocals = self._read_and_pad(os.path.join(self.vocab_dir, filename))
        insts = self._read_and_pad(os.path.join(self.inst_dir, filename))
        mix = vocals + insts
        return mix, vocals


def _parse_args():
    """Build the command-line parser and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(description="BS-RoKAN Fine-Tuning")
    parser.add_argument("--config", required=True, help="Path to rokan.yaml")
    parser.add_argument("--ckpt", required=True, help="Path to rokan.ckpt")
    parser.add_argument("--output_dir", default="./", help="Where to save checkpoints")
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--gate_lr", type=float, default=1e-3)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--device", type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--save_every", type=int, default=5)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--vocals_dir", default="dataset/vocals",
                        help="Directory with vocal stems")
    parser.add_argument("--inst_dir", default="dataset/instrumentals",
                        help="Directory with instrumental stems")
    parser.add_argument("--batch_sleep", type=float, default=0.2,
                        help="Seconds to pause after each batch (0 disables)")
    return parser.parse_args()


def _strip_compile_prefix(state_dict):
    """Remove the ``_orig_mod.`` key prefix written by compiled-model checkpoints.

    The input is returned untouched when no key carries the prefix.
    """
    if not any(key.startswith(_COMPILE_PREFIX) for key in state_dict):
        return state_dict
    return {
        (key[len(_COMPILE_PREFIX):] if key.startswith(_COMPILE_PREFIX) else key): value
        for key, value in state_dict.items()
    }


def _build_model(args):
    """Instantiate the model described by ``args.config`` and load ``args.ckpt``.

    Returns the plain (uncompiled) module on the CPU.
    """
    # SECURITY NOTE: FullLoader is kept on the assumption that these configs
    # may rely on Python-specific YAML tags; only point --config at trusted files.
    with open(args.config, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    m_cfg = dict(config['model'])
    for key in ('freqs_per_bands', 'multi_stft_resolutions_window_sizes'):
        if key in m_cfg:
            m_cfg[key] = tuple(m_cfg[key])

    # A 'num_bands' entry selects the mel-band variant.
    model_cls = MelBandRoformer if 'num_bands' in m_cfg else BSRoformer
    model = model_cls(**m_cfg)

    if os.path.exists(args.ckpt):
        state = _strip_compile_prefix(torch.load(args.ckpt, map_location='cpu'))
        result = model.load_state_dict(state, strict=False)
        # strict=False hides mismatches; surface them so a wrong checkpoint
        # does not go unnoticed.
        if result.missing_keys or result.unexpected_keys:
            print(f"[!] Checkpoint key mismatch: {len(result.missing_keys)} missing, "
                  f"{len(result.unexpected_keys)} unexpected")
    else:
        print(f"[!] Checkpoint not found: {args.ckpt} (continuing with freshly "
              f"initialised weights)")
    return model


def train():
    """Command-line entry point: fine-tune the spline tensors and save checkpoints.

    Reads its configuration from ``sys.argv``. Every ``--save_every`` epochs
    the model's ``state_dict`` is written to
    ``<output_dir>/checkpoint_ep<N>.ckpt``.
    """
    args = _parse_args()
    # autocast / GradScaler expect the bare device type ('cuda'), not 'cuda:0'.
    device = torch.device(args.device)

    raw_model = set_requires_grad_selective(_build_model(args).to(device))

    # Keep a handle on the uncompiled module: the compiled wrapper shares its
    # parameters but prefixes every state_dict key with '_orig_mod.'.
    model = raw_model
    if device.type == 'cuda' and hasattr(torch, 'compile'):
        try:
            model = torch.compile(raw_model)
        except Exception as exc:  # best effort: fall back to eager execution
            print(f"[!] torch.compile failed, running uncompiled: {exc}")
    model.train()

    dataset = SimpleAudioDataset(args.vocals_dir, args.inst_dir)
    if len(dataset) == 0:
        print("\n[!] Dataset empty. Exit.")
        return

    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
                            drop_last=True, num_workers=args.num_workers,
                            pin_memory=(device.type == 'cuda'))
    num_batches = len(dataloader)
    if num_batches == 0:
        # drop_last=True discards the only (partial) batch when the dataset is
        # smaller than batch_size; the average loss would divide by zero.
        print(f"\n[!] Dataset has {len(dataset)} file(s), fewer than "
              f"batch_size={args.batch_size}. Exit.")
        return

    trainable = [(n, p) for n, p in raw_model.named_parameters() if p.requires_grad]
    gate_params = [p for n, p in trainable if n.endswith(_GATE_SUFFIX)]
    spline_params = [p for n, p in trainable if n.endswith(_SPLINE_SUFFIX)]
    optimizer = torch.optim.AdamW([
        {'params': gate_params, 'lr': args.gate_lr},
        {'params': spline_params, 'lr': args.lr},
    ], weight_decay=1e-4)

    try:
        from torch.amp import GradScaler
        scaler = GradScaler(device.type)
    except Exception as exc:  # missing torch.amp.GradScaler or unsupported device
        print(f"[!] GradScaler unavailable, training without loss scaling: {exc}")
        scaler = None

    for epoch in range(1, args.epochs + 1):
        epoch_loss = 0.0
        for batch_no, (mix, vocals) in enumerate(dataloader, start=1):
            mix = mix.to(device)
            vocals = vocals.to(device)

            optimizer.zero_grad()
            with torch.amp.autocast(device_type=device.type, dtype=torch.float16):
                # Called with a target, the model's output is used directly as
                # the training loss.
                loss = model(mix, target=vocals)

            if scaler is not None:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()

            loss_value = loss.item()
            epoch_loss += loss_value

            # Short pause to reduce the load on the PC.
            if args.batch_sleep > 0:
                time.sleep(args.batch_sleep)

            if batch_no % 10 == 0:
                print(f"Epoch {epoch} | Batch {batch_no}/{num_batches} | Loss: {loss_value:.4f}")

        print(f"==> Epoch {epoch} Average Loss: {epoch_loss/num_batches:.4f}")

        if epoch % args.save_every == 0:
            os.makedirs(args.output_dir, exist_ok=True)
            save_path = os.path.join(args.output_dir, f"checkpoint_ep{epoch}.ckpt")
            # Save from the uncompiled module so keys carry no '_orig_mod.' prefix.
            torch.save(raw_model.state_dict(), save_path)

            # Mean |value| per gate tensor; equals |value| for scalar gates.
            gate_vals = [p.detach().abs().mean().item()
                         for n, p in raw_model.named_parameters()
                         if n.endswith(_GATE_SUFFIX)]
            avg_gate = sum(gate_vals) / len(gate_vals) if gate_vals else 0
            print(f"[*] Saved: {save_path} | Avg|gate|: {avg_gate:.4f}")


if __name__ == "__main__":
    train()