"""Compare a teacher model against its RoKAN conversion on one audio file.

For each configured (teacher, RoKAN) pair the script builds both models from
their YAML configs, loads their checkpoints, runs chunked inference on the
same input wav, and writes a Markdown report with timing and the MAE / RMSE /
max-abs difference between the two outputs.
"""
import argparse
import time
from pathlib import Path

import soundfile as sf
import torch
import torchaudio.functional as AF
import yaml

from models.bs_roformer.bs_roformer import BSRoformer
from models.bs_roformer.mel_band_roformer import MelBandRoformer

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def load_cfg(path: Path):
    """Parse the YAML file at *path* and return the loaded object."""
    with path.open("r", encoding="utf-8") as f:
        # NOTE(review): FullLoader is more permissive than yaml.safe_load.
        # It is kept because the configs may rely on tags safe_load rejects
        # (TODO confirm); only point this at trusted config files.
        return yaml.load(f, Loader=yaml.FullLoader)


def clean_state_dict(ckpt_path: Path):
    """Load a checkpoint and return a flat state dict.

    Unwraps an optional ``"state_dict"`` and then an optional ``"model"``
    container, and strips a leading ``"model."`` from every key.
    """
    # NOTE(review): torch.load unpickles the file; use trusted checkpoints only.
    sd = torch.load(str(ckpt_path), map_location="cpu")
    if isinstance(sd, dict) and "state_dict" in sd:
        sd = sd["state_dict"]
    if isinstance(sd, dict) and "model" in sd:
        sd = sd["model"]

    prefix = "model."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in sd.items()
    }


def build_model_from_yaml(yaml_path: Path):
    """Instantiate the model described by a YAML config.

    A ``MelBandRoformer`` is built when the ``model`` section contains
    ``num_bands``; otherwise a ``BSRoformer`` is built.

    Returns:
        ``(model, sample_rate)`` where ``sample_rate`` is
        ``audio.sample_rate`` from the config, or 44100 when it is absent.

    Raises:
        KeyError: if the config lacks the ``model`` or ``audio`` section, or
            the required ``dim`` / ``depth`` entries.
    """
    cfg = load_cfg(yaml_path)
    m = cfg["model"]
    audio_cfg = cfg["audio"]
    # One fallback for both the model kwarg and the returned value, so a
    # config without audio.sample_rate behaves the same on both paths.
    sample_rate = audio_cfg.get("sample_rate", 44100)

    kwargs = dict(
        dim=m["dim"],
        depth=m["depth"],
        stereo=m.get("stereo", True),
        num_stems=m.get("num_stems", 1),
        time_transformer_depth=m.get("time_transformer_depth", 1),
        freq_transformer_depth=m.get("freq_transformer_depth", 1),
        linear_transformer_depth=m.get("linear_transformer_depth", 0),
        dim_head=m.get("dim_head", 64),
        heads=m.get("heads", 8),
        attn_dropout=m.get("attn_dropout", 0.0),
        ff_dropout=m.get("ff_dropout", 0.0),
        flash_attn=False,
        dim_freqs_in=m.get("dim_freqs_in", 1025),
        stft_n_fft=m.get("stft_n_fft", 2048),
        stft_hop_length=m.get("stft_hop_length", 512),
        stft_win_length=m.get("stft_win_length", 2048),
        stft_normalized=m.get("stft_normalized", False),
        mask_estimator_depth=m.get("mask_estimator_depth", 2),
        multi_stft_resolution_loss_weight=m.get("multi_stft_resolution_loss_weight", 1.0),
        multi_stft_resolutions_window_sizes=tuple(
            m.get("multi_stft_resolutions_window_sizes", (4096, 2048, 1024, 512, 256))
        ),
        multi_stft_hop_size=m.get("multi_stft_hop_size", 147),
        multi_stft_normalized=m.get("multi_stft_normalized", False),
        mlp_expansion_factor=m.get("mlp_expansion_factor", 4),
        use_torch_checkpoint=False,
        skip_connection=m.get("skip_connection", False),
        sage_attention=m.get("sage_attention", False),
        use_kan=m.get("use_kan", False),
        kan_grid_size=m.get("kan_grid_size", 8),
    )

    if "freqs_per_bands" in m:
        kwargs["freqs_per_bands"] = tuple(m["freqs_per_bands"])

    if "num_bands" in m:
        kwargs["num_bands"] = m["num_bands"]
        kwargs["sample_rate"] = m.get("sample_rate", sample_rate)
        model = MelBandRoformer(**kwargs)
    else:
        model = BSRoformer(**kwargs)

    return model, sample_rate


def load_audio(path: Path, target_sr: int):
    """Read an audio file as a float32 stereo tensor with a leading batch dim.

    The audio is resampled to *target_sr* when the file's rate differs. A
    single channel is duplicated to two; channels beyond the second are
    dropped.

    Returns:
        Tensor of shape ``[1, 2, T]``.
    """
    wav_np, sr = sf.read(str(path), always_2d=True)
    # Transposed so that dim 0 is the channel axis used by the checks below.
    wav = torch.from_numpy(wav_np.T).float()
    if sr != target_sr:
        wav = AF.resample(wav, sr, target_sr)

    if wav.shape[0] == 1:
        wav = wav.repeat(2, 1)
    elif wav.shape[0] > 2:
        wav = wav[:2, :]
    return wav.unsqueeze(0)


def infer_chunked(model, audio, chunk_size=353280, context=132096):
    """Run *model* over *audio* in fixed-size windows and stitch the centers.

    Each window is ``chunk_size`` samples long; only its middle
    ``chunk_size - 2 * context`` samples are kept, so every kept sample was
    predicted with ``context`` samples of surrounding signal (replicate
    padding at the edges).

    Args:
        model: Callable taking and returning a tensor.
        audio: 3-D tensor whose last dimension is time.
        chunk_size: Window length in samples passed to the model.
        context: Samples discarded on each side of every window.

    Returns:
        Tensor with the time length of *audio*, or ``None`` when *audio* has
        no samples.

    Raises:
        RuntimeError: if ``chunk_size <= 2 * context`` or the model output is
            neither 3-D nor 4-D.
    """
    center_size = chunk_size - 2 * context
    if center_size <= 0:
        raise RuntimeError("chunk_size must be larger than 2*context")

    pad_fn = torch.nn.functional.pad
    audio_len = audio.shape[-1]
    padded = pad_fn(audio, (context, context), mode="replicate")

    out = None
    pos = 0
    while pos < audio_len:
        center_end = min(pos + center_size, audio_len)
        valid_len = center_end - pos

        # Window start `pos` in padded coordinates puts audio sample `pos`
        # at offset `context` inside the chunk.
        chunk = padded[:, :, pos : pos + chunk_size]
        shortfall = chunk_size - chunk.shape[-1]
        if shortfall > 0:
            chunk = pad_fn(chunk, (0, shortfall), mode="replicate")

        with torch.inference_mode():
            if audio.is_cuda:
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    out_chunk = model(chunk)
            else:
                out_chunk = model(chunk)

        # Normalize output shape to [B, C, T]
        # Some checkpoints return [B, N, C, T] (multi-stem).
        if out_chunk.ndim == 4:
            out_chunk = out_chunk[:, 0, :, :]
        elif out_chunk.ndim != 3:
            raise RuntimeError(
                f"Unsupported output ndim={out_chunk.ndim}, shape={tuple(out_chunk.shape)}"
            )

        if out is None:
            out = torch.zeros(
                (out_chunk.shape[0], out_chunk.shape[1], audio_len),
                device=audio.device,
            )

        out[:, :, pos:center_end] = out_chunk[:, :, context : context + valid_len]
        pos += center_size

    return out


def _timed_inference(model, audio):
    """Run ``infer_chunked`` and return ``(output, elapsed_seconds)``."""
    on_cuda = audio.is_cuda
    # CUDA kernels are launched asynchronously: without a synchronize the
    # clock can stop before the GPU has finished, and leftover work from a
    # previous run would be billed to this one.
    if on_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    result = infer_chunked(model, audio)
    if on_cuda:
        torch.cuda.synchronize()
    return result, time.perf_counter() - start


def eval_pair(name, teacher_yaml, teacher_ckpt, rokan_yaml, rokan_ckpt, wav_path):
    """Run teacher and RoKAN models on one wav and measure their difference.

    Returns:
        Dict with keys ``name``, ``sample_rate``, ``audio_seconds``,
        ``teacher_sec``, ``rokan_sec``, ``mae``, ``rmse`` and ``max_abs``.

    Raises:
        RuntimeError: if the two configs disagree on the sample rate.
    """
    t_model, t_sr = build_model_from_yaml(teacher_yaml)
    r_model, r_sr = build_model_from_yaml(rokan_yaml)
    if t_sr != r_sr:
        raise RuntimeError(f"{name}: sample rate mismatch {t_sr} vs {r_sr}")

    # NOTE(review): strict=False silently ignores missing/unexpected keys, so
    # a badly matched checkpoint still "loads"; consider inspecting the return
    # value of load_state_dict if the reported metrics look implausible.
    t_model.load_state_dict(clean_state_dict(teacher_ckpt), strict=False)
    r_model.load_state_dict(clean_state_dict(rokan_ckpt), strict=False)

    t_model = t_model.to(DEVICE).eval()
    r_model = r_model.to(DEVICE).eval()

    audio = load_audio(wav_path, t_sr).to(DEVICE)

    t_out, t_sec = _timed_inference(t_model, audio)
    r_out, r_sec = _timed_inference(r_model, audio)

    diff = (t_out - r_out).float()
    abs_diff = diff.abs()

    return {
        "name": name,
        "sample_rate": t_sr,
        "audio_seconds": float(audio.shape[-1]) / float(t_sr),
        "teacher_sec": t_sec,
        "rokan_sec": r_sec,
        "mae": abs_diff.mean().item(),
        "rmse": torch.sqrt((diff ** 2).mean()).item(),
        "max_abs": abs_diff.max().item(),
    }


def _render_report(rows, wav_path):
    """Return the Markdown report text for the evaluated *rows*."""
    lines = [
        "# RoKAN Fidelity Report",
        "",
        f"- input_wav: `{wav_path}`",
        f"- device: `{DEVICE}`",
        "",
    ]
    for r in rows:
        lines.append(f"## {r['name']}")
        if "error" in r:
            lines.append("- status: FAIL")
            lines.append(f"- error: `{r['error']}`")
        else:
            lines.append("- status: OK")
            lines.append(f"- sample_rate: {r['sample_rate']}")
            lines.append(f"- audio_seconds: {r['audio_seconds']:.2f}")
            lines.append(f"- teacher_infer_sec: {r['teacher_sec']:.2f}")
            lines.append(f"- rokan_infer_sec: {r['rokan_sec']:.2f}")
            lines.append(f"- mae: {r['mae']:.8f}")
            lines.append(f"- rmse: {r['rmse']:.8f}")
            lines.append(f"- max_abs: {r['max_abs']:.8f}")
        lines.append("")
    return "\n".join(lines)


def main():
    """Evaluate every configured model pair and write the fidelity report.

    The input wav is ``--input_wav`` when given, otherwise the first ``*.wav``
    (sorted by name) in the ``input`` directory next to this script. Per-pair
    failures are recorded in the report rather than aborting the run.

    Raises:
        RuntimeError: if no input wav can be found.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate teacher vs RoKAN fidelity for BS and MelBand models"
    )
    parser.add_argument("--input_wav", type=str, default="")
    args = parser.parse_args()

    root = Path(__file__).resolve().parent
    input_dir = root / "input"

    if args.input_wav:
        wav_path = Path(args.input_wav)
    else:
        wavs = sorted(input_dir.glob("*.wav"))
        if not wavs:
            raise RuntimeError("No wav in input/. Set --input_wav explicitly.")
        wav_path = wavs[0]

    if not wav_path.exists():
        raise RuntimeError(f"Input wav not found: {wav_path}")

    pairs = [
        (
            "BS-Rofo-SW-Fixed",
            root / "dataset/Models/BS-Rofo-SW-Fixed.yaml",
            root / "dataset/Models/BS-Rofo-SW-Fixed.ckpt",
            root / "converted_models/BS-Rofo-SW-Fixed_rokan.yaml",
            root / "converted_models/BS-Rofo-SW-Fixed_rokan.ckpt",
        ),
        (
            "MelBand denoise",
            root / "dataset/Models/denoise_mel_band_roformer_aufr33_sdr_27.9959.yaml",
            root / "dataset/Models/denoise_mel_band_roformer_aufr33_sdr_27.9959.ckpt",
            root / "converted_models/denoise_mel_band_roformer_aufr33_sdr_27.9959_rokan.yaml",
            root / "converted_models/denoise_mel_band_roformer_aufr33_sdr_27.9959_rokan.ckpt",
        ),
    ]

    rows = []
    for name, ty, tc, ry, rc in pairs:
        missing = [str(p) for p in (ty, tc, ry, rc) if not p.exists()]
        if missing:
            rows.append({"name": name, "error": "missing files: " + ", ".join(missing)})
            continue
        try:
            rows.append(eval_pair(name, ty, tc, ry, rc, wav_path))
        except Exception as e:
            # Deliberate best-effort boundary: one failing pair must not
            # prevent the other pairs from being reported.
            rows.append({"name": name, "error": str(e)})

    out_path = root / "converted_models" / "eval_fidelity_report.md"
    # The directory may be absent (e.g. nothing has been converted yet); the
    # report must still be written so the "missing files" errors are visible.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(_render_report(rows, wav_path), encoding="utf-8")
    print(f"wrote: {out_path}")


if __name__ == "__main__":
    main()