# Source: First-RoKAN-Model / train_rokan.py
# Uploaded by tekitoutarou ("Upload 12 files", revision f73ae00, verified)
import os
import glob
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchaudio
import yaml
import argparse
import time
from models.bs_roformer.bs_roformer import BSRoformer
from models.bs_roformer.mel_band_roformer import MelBandRoformer
def set_requires_grad_selective(model):
    """Freeze the whole model, then re-enable gradients for KAN tensors only.

    A parameter counts as a KAN tensor when its qualified name ends with
    ``.spline_weight`` or ``.spline_gate``. The number of re-enabled tensors
    is printed, and the same ``model`` object is returned.
    """
    kan_suffixes = ('.spline_weight', '.spline_gate')

    # Start from a fully frozen model.
    for tensor in model.parameters():
        tensor.requires_grad = False

    # Re-enable only the tensors whose name marks them as KAN splines/gates.
    trainable = [
        tensor
        for label, tensor in model.named_parameters()
        if label.endswith(kan_suffixes)
    ]
    for tensor in trainable:
        tensor.requires_grad = True

    print(f"[*] Training: Unfroze {len(trainable)} KAN tensors")
    return model
class SimpleAudioDataset(Dataset):
    """Paired vocal / instrumental dataset yielding fixed-length stereo chunks.

    Files are paired by basename: a ``.wav`` is used only when a file with the
    same name exists in both ``vocab_dir`` and ``inst_dir``. Each item is a
    ``(mix, vocals)`` tuple where ``mix = vocals + instrumental``.
    """

    def __init__(self, vocab_dir, inst_dir, sample_rate=44100, chunk_seconds=4.0):
        """Index the ``.wav`` files present in both directories.

        Args:
            vocab_dir: Directory holding the vocal stems.
            inst_dir: Directory holding the instrumental stems.
            sample_rate: Rate every file is resampled to when it differs.
            chunk_seconds: Length of each returned chunk, in seconds.
        """
        self.vocab_dir = vocab_dir
        self.inst_dir = inst_dir
        self.sample_rate = sample_rate
        self.chunk_size = int(sample_rate * chunk_seconds)

        vocab_files = {
            os.path.basename(f) for f in glob.glob(os.path.join(vocab_dir, "*.wav"))
        }
        inst_files = {
            os.path.basename(f) for f in glob.glob(os.path.join(inst_dir, "*.wav"))
        }
        # Sorted so the index -> file mapping is the same on every run
        # (set iteration order is not stable across interpreter runs).
        self.matched_files = sorted(vocab_files & inst_files)
        if not self.matched_files:
            print("WARNING: No matching .wav files found!")

    def __len__(self):
        """Return the number of matched vocal/instrumental pairs."""
        return len(self.matched_files)

    def _fit_to_chunk(self, audio):
        """Randomly crop or right-pad ``audio`` to ``self.chunk_size`` samples.

        ``audio`` is indexed as ``[channel, sample]``. Longer inputs are
        cropped at a uniformly random offset; shorter (or equal) inputs are
        zero-padded at the end.
        """
        total = audio.shape[-1]
        if total > self.chunk_size:
            # torch.randint's upper bound is exclusive, so add 1 to make the
            # last valid start offset (total - chunk_size) reachable.
            start = torch.randint(0, total - self.chunk_size + 1, (1,)).item()
            return audio[:, start:start + self.chunk_size]
        return torch.nn.functional.pad(audio, (0, self.chunk_size - total))

    def _read_and_pad(self, path):
        """Load ``path`` as a float stereo tensor of exactly ``chunk_size`` samples."""
        # soundfile is imported lazily here rather than at module level.
        import soundfile as sf

        # always_2d=True yields (frames, channels); transpose to channel-first.
        data, sr = sf.read(path, always_2d=True)
        audio = torch.from_numpy(data.T).float()
        if sr != self.sample_rate:
            audio = torchaudio.functional.resample(audio, sr, self.sample_rate)

        # Force exactly two channels: duplicate mono, drop extras beyond stereo.
        if audio.shape[0] == 1:
            audio = audio.repeat(2, 1)
        elif audio.shape[0] > 2:
            audio = audio[:2, :]
        return self._fit_to_chunk(audio)

    def __getitem__(self, idx):
        """Return ``(mix, vocals)`` for the pair at position ``idx``."""
        filename = self.matched_files[idx]
        # NOTE(review): each stem is cropped at its own random offset, so for
        # files longer than one chunk the mix is not time-aligned with the
        # original song -- confirm this is the intended augmentation.
        vocals = self._read_and_pad(os.path.join(self.vocab_dir, filename))
        insts = self._read_and_pad(os.path.join(self.inst_dir, filename))
        mix = vocals + insts
        return mix, vocals
def _parse_train_args():
    """Build and parse the command-line options for fine-tuning."""
    parser = argparse.ArgumentParser(description="BS-RoKAN Fine-Tuning")
    parser.add_argument("--config", required=True, help="Path to rokan.yaml")
    parser.add_argument("--ckpt", required=True, help="Path to rokan.ckpt")
    parser.add_argument("--output_dir", default="./", help="Where to save checkpoints")
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--gate_lr", type=float, default=1e-3)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--save_every", type=int, default=5)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--vocals_dir", default="dataset/vocals", help="Directory with vocal stems")
    parser.add_argument("--inst_dir", default="dataset/instrumentals", help="Directory with instrumental stems")
    parser.add_argument("--batch_sleep", type=float, default=0.2,
                        help="Seconds to pause after every batch (0 disables the pause)")
    return parser.parse_args()


def _load_checkpoint(model, ckpt_path):
    """Load ``ckpt_path`` into ``model`` non-strictly and report key mismatches."""
    if not os.path.exists(ckpt_path):
        print(f"[!] Checkpoint not found: {ckpt_path} (continuing without loading weights)")
        return
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(ckpt_path, map_location='cpu')

    # State dicts taken from a torch.compile-wrapped module carry an
    # "_orig_mod." key prefix; strip it so such checkpoints load into a
    # plain module.
    prefix = "_orig_mod."
    if isinstance(state, dict) and any(
        isinstance(key, str) and key.startswith(prefix) for key in state
    ):
        state = {
            (key[len(prefix):] if isinstance(key, str) and key.startswith(prefix) else key): value
            for key, value in state.items()
        }

    result = model.load_state_dict(state, strict=False)
    if result.missing_keys or result.unexpected_keys:
        print(f"[!] Checkpoint key mismatch: {len(result.missing_keys)} missing, "
              f"{len(result.unexpected_keys)} unexpected")


def _make_grad_scaler(device_type):
    """Return a gradient scaler for fp16 autocast, or ``None`` if none can be built."""
    try:
        from torch.amp import GradScaler
        return GradScaler(device_type)
    except (ImportError, TypeError, RuntimeError) as err:
        print(f"[!] torch.amp.GradScaler unavailable ({err})")
    if device_type == 'cuda':
        # Older PyTorch releases only expose the CUDA-specific scaler.
        return torch.cuda.amp.GradScaler()
    print("[!] Training without gradient scaling")
    return None


def train():
    """Fine-tune only the KAN spline tensors of a BS / Mel-band RoFormer model.

    Reads the command-line options, builds the model from the YAML config,
    loads the checkpoint, freezes everything except the ``.spline_weight`` /
    ``.spline_gate`` parameters, and trains them with AdamW under fp16
    autocast. A checkpoint is written every ``--save_every`` epochs and after
    the final epoch.
    """
    args = _parse_train_args()
    device = torch.device(args.device)

    # Load config
    # NOTE(review): FullLoader can construct some Python objects from YAML
    # tags; only use config files from trusted sources.
    with open(args.config, 'r', encoding='utf-8') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    m_cfg = dict(config['model'])
    for k in ['freqs_per_bands', 'multi_stft_resolutions_window_sizes']:
        if k in m_cfg:
            m_cfg[k] = tuple(m_cfg[k])

    # A config that defines 'num_bands' selects the Mel-band variant.
    model_cls = MelBandRoformer if 'num_bands' in m_cfg else BSRoformer
    raw_model = model_cls(**m_cfg)
    _load_checkpoint(raw_model, args.ckpt)
    raw_model = raw_model.to(device)
    raw_model = set_requires_grad_selective(raw_model)

    # Keep `raw_model` for parameter lookup and saving: the compiled wrapper
    # prefixes every state-dict key with "_orig_mod.".
    model = raw_model
    if device.type == 'cuda' and hasattr(torch, 'compile'):
        try:
            model = torch.compile(raw_model)
        except Exception as err:  # compilation is optional; fall back to eager
            print(f"[!] torch.compile skipped: {err}")
            model = raw_model
    model.train()

    dataset = SimpleAudioDataset(args.vocals_dir, args.inst_dir)
    if len(dataset) == 0:
        print("\n[!] Dataset empty. Exit.")
        return
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=args.num_workers,
        pin_memory=(device.type == 'cuda'),
    )
    num_batches = len(dataloader)
    if num_batches == 0:
        # drop_last=True discards the only (incomplete) batch.
        print("\n[!] Dataset smaller than batch size. Exit.")
        return

    gate_params = [p for n, p in raw_model.named_parameters()
                   if p.requires_grad and n.endswith('.spline_gate')]
    spline_params = [p for n, p in raw_model.named_parameters()
                     if p.requires_grad and n.endswith('.spline_weight')]
    # Gates and spline weights get separate learning rates; skip empty groups.
    param_groups = [
        group for group in (
            {'params': gate_params, 'lr': args.gate_lr},
            {'params': spline_params, 'lr': args.lr},
        ) if group['params']
    ]
    if not param_groups:
        print("\n[!] No trainable KAN parameters found. Exit.")
        return
    optimizer = torch.optim.AdamW(param_groups, weight_decay=1e-4)
    scaler = _make_grad_scaler(device.type)

    for epoch in range(1, args.epochs + 1):
        epoch_loss = 0.0
        for batch_idx, (mix, vocals) in enumerate(dataloader, start=1):
            mix = mix.to(device)
            vocals = vocals.to(device)
            optimizer.zero_grad()
            with torch.amp.autocast(device_type=device.type, dtype=torch.float16):
                # Called with `target`, the model's output is used as the loss.
                loss = model(mix, target=vocals)
            if scaler is not None:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()
            batch_loss = loss.item()
            epoch_loss += batch_loss
            # Pause briefly to reduce the load on the PC.
            if args.batch_sleep > 0:
                time.sleep(args.batch_sleep)
            if batch_idx % 10 == 0:
                print(f"Epoch {epoch} | Batch {batch_idx}/{num_batches} | Loss: {batch_loss:.4f}")
        print(f"==> Epoch {epoch} Average Loss: {epoch_loss/num_batches:.4f}")

        periodic = args.save_every > 0 and epoch % args.save_every == 0
        if periodic or epoch == args.epochs:
            os.makedirs(args.output_dir, exist_ok=True)
            save_path = os.path.join(args.output_dir, f"checkpoint_ep{epoch}.ckpt")
            torch.save(raw_model.state_dict(), save_path)
            # Mean |gate| per tensor; identical to |p.item()| for scalar gates.
            gate_vals = [p.detach().abs().mean().item()
                         for n, p in raw_model.named_parameters()
                         if n.endswith('.spline_gate')]
            avg_gate = sum(gate_vals) / len(gate_vals) if gate_vals else 0
            print(f"[*] Saved: {save_path} | Avg|gate|: {avg_gate:.4f}")
# Run the fine-tuning loop only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    train()