| import argparse |
| import time |
| from pathlib import Path |
|
|
| import soundfile as sf |
| import torch |
| import torchaudio.functional as AF |
| import yaml |
|
|
| from models.bs_roformer.bs_roformer import BSRoformer |
| from models.bs_roformer.mel_band_roformer import MelBandRoformer |
|
|
|
|
# Compute device used for every model and tensor in this script:
# "cuda" when PyTorch reports CUDA as available, otherwise "cpu".
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
|
|
|
|
def load_cfg(path: Path):
    """Parse the YAML file at ``path`` and return the loaded object.

    The file is opened as UTF-8 text and parsed with ``yaml.FullLoader``.
    """
    # NOTE(review): FullLoader is more permissive than yaml.safe_load, so
    # only point this at trusted config files. Presumably the configs may
    # rely on tags safe_load rejects -- confirm before tightening.
    with path.open("r", encoding="utf-8") as stream:
        parsed = yaml.load(stream, Loader=yaml.FullLoader)
    return parsed
|
|
|
|
def clean_state_dict(ckpt_path: Path):
    """Load a checkpoint on the CPU and return a flat, prefix-free state dict.

    If the loaded object is a dict wrapping the weights under a
    ``"state_dict"`` key and/or a ``"model"`` key, those wrappers are peeled
    off in that order. A leading ``"model."`` is then stripped from every key
    that has one; all other keys are kept unchanged.
    """
    prefix = "model."

    # NOTE(review): torch.load unpickles the file; use trusted checkpoints only.
    payload = torch.load(str(ckpt_path), map_location="cpu")

    # Unwrap the known container layouts, outermost first.
    for wrapper_key in ("state_dict", "model"):
        if isinstance(payload, dict) and wrapper_key in payload:
            payload = payload[wrapper_key]

    return {
        (name[len(prefix):] if name.startswith(prefix) else name): value
        for name, value in payload.items()
    }
|
|
|
|
def build_model_from_yaml(yaml_path: Path):
    """Instantiate the model described by a YAML config file.

    The config must have a ``model`` section (with at least ``dim`` and
    ``depth``) and an ``audio`` section containing ``sample_rate``. Every
    other hyper-parameter is optional and falls back to the defaults listed
    below. ``flash_attn`` and ``use_torch_checkpoint`` are always disabled.

    A ``MelBandRoformer`` is built when the model section contains
    ``num_bands``; otherwise a ``BSRoformer`` is built.

    Returns:
        A ``(model, sample_rate)`` tuple, where ``sample_rate`` comes from the
        ``audio`` section of the config.
    """
    cfg = load_cfg(yaml_path)
    model_cfg = cfg["model"]
    audio_cfg = cfg["audio"]

    # Optional hyper-parameters and the value used when the config omits them.
    optional_defaults = {
        "stereo": True,
        "num_stems": 1,
        "time_transformer_depth": 1,
        "freq_transformer_depth": 1,
        "linear_transformer_depth": 0,
        "dim_head": 64,
        "heads": 8,
        "attn_dropout": 0.0,
        "ff_dropout": 0.0,
        "dim_freqs_in": 1025,
        "stft_n_fft": 2048,
        "stft_hop_length": 512,
        "stft_win_length": 2048,
        "stft_normalized": False,
        "mask_estimator_depth": 2,
        "multi_stft_resolution_loss_weight": 1.0,
        "multi_stft_resolutions_window_sizes": (4096, 2048, 1024, 512, 256),
        "multi_stft_hop_size": 147,
        "multi_stft_normalized": False,
        "mlp_expansion_factor": 4,
        "skip_connection": False,
        "sage_attention": False,
        "use_kan": False,
        "kan_grid_size": 8,
    }

    # Required entries first, so a missing one fails before anything else.
    kwargs = {"dim": model_cfg["dim"], "depth": model_cfg["depth"]}
    for key, fallback in optional_defaults.items():
        kwargs[key] = model_cfg.get(key, fallback)

    # YAML sequences load as lists; the window sizes are passed as a tuple.
    kwargs["multi_stft_resolutions_window_sizes"] = tuple(
        kwargs["multi_stft_resolutions_window_sizes"]
    )

    # Settings that are fixed regardless of what the config says.
    kwargs["flash_attn"] = False
    kwargs["use_torch_checkpoint"] = False

    if "freqs_per_bands" in model_cfg:
        kwargs["freqs_per_bands"] = tuple(model_cfg["freqs_per_bands"])

    if "num_bands" not in model_cfg:
        model = BSRoformer(**kwargs)
    else:
        kwargs["num_bands"] = model_cfg["num_bands"]
        kwargs["sample_rate"] = model_cfg.get(
            "sample_rate", audio_cfg.get("sample_rate", 44100)
        )
        model = MelBandRoformer(**kwargs)

    return model, audio_cfg["sample_rate"]
|
|
|
|
def load_audio(path: Path, target_sr: int):
    """Read an audio file and return it as a two-channel float tensor.

    The samples are read with ``soundfile`` (always as a 2-D array),
    transposed so that the first axis is treated as the channel axis, and
    converted to float32. The signal is resampled when the file's rate
    differs from ``target_sr``. A single channel is duplicated to two, and
    anything beyond the first two channels is dropped.

    Returns:
        The waveform with an extra leading dimension of size 1.
    """
    samples, source_sr = sf.read(str(path), always_2d=True)
    waveform = torch.from_numpy(samples.T).float()

    if source_sr != target_sr:
        waveform = AF.resample(waveform, source_sr, target_sr)

    channel_count = waveform.shape[0]
    if channel_count == 1:
        # Mono input: repeat the only channel to get two identical channels.
        waveform = waveform.repeat(2, 1)
    elif channel_count > 2:
        # More than two channels: keep only the first two.
        waveform = waveform[:2, :]

    return waveform.unsqueeze(0)
|
|
|
|
def infer_chunked(model, audio, chunk_size=353280, context=132096):
    """Run ``model`` over ``audio`` in overlapping windows and stitch the result.

    The input is replicate-padded by ``context`` samples on both ends of its
    last axis. Each window of ``chunk_size`` samples is passed to the model,
    but only its central ``chunk_size - 2 * context`` samples are written to
    the output, so each output sample comes from exactly one window. A short
    final window is replicate-padded on the right up to ``chunk_size``.

    When the model returns a 4-D tensor, only index 0 of its second axis is
    kept. On CUDA inputs the forward pass runs under float16 autocast.

    Args:
        model: Callable applied to each window.
        audio: Tensor indexed as ``[:, :, time]``.
        chunk_size: Number of samples fed to the model per window.
        context: Samples on each side of a window that are discarded.

    Returns:
        A tensor on ``audio.device`` whose last axis has the same length as
        ``audio``'s, or ``None`` if the loop never runs.

    Raises:
        RuntimeError: If ``chunk_size <= 2 * context`` or the model output is
            neither 3-D nor 4-D.
    """
    hop = chunk_size - 2 * context
    if hop <= 0:
        raise RuntimeError("chunk_size must be larger than 2*context")

    pad_fn = torch.nn.functional.pad
    total = audio.shape[-1]
    padded = pad_fn(audio, (context, context), mode="replicate")
    on_gpu = audio.is_cuda

    result = None
    for start in range(0, total, hop):
        stop = min(start + hop, total)
        keep = stop - start

        # `padded` is shifted by `context`, so this window is centred on
        # audio[start:stop] with `context` samples of history on each side.
        window = padded[:, :, start : start + chunk_size]
        shortfall = chunk_size - window.shape[-1]
        if shortfall > 0:
            window = pad_fn(window, (0, shortfall), mode="replicate")

        with torch.inference_mode():
            if on_gpu:
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    pred = model(window)
            else:
                pred = model(window)

        if pred.ndim == 4:
            pred = pred[:, 0, :, :]
        elif pred.ndim != 3:
            raise RuntimeError(f"Unsupported output ndim={pred.ndim}, shape={tuple(pred.shape)}")

        # Allocate lazily: the output's leading dims are only known after the
        # first forward pass.
        if result is None:
            result = torch.zeros((pred.shape[0], pred.shape[1], total), device=audio.device)

        result[:, :, start:stop] = pred[:, :, context : context + keep]

    return result
|
|
|
|
def _load_weights(model, ckpt_path, label):
    """Load ``ckpt_path`` into ``model`` non-strictly and surface key mismatches."""
    report = model.load_state_dict(clean_state_dict(ckpt_path), strict=False)
    missing = list(getattr(report, "missing_keys", ()))
    unexpected = list(getattr(report, "unexpected_keys", ()))

    expected = len(model.state_dict())
    if expected and len(missing) >= expected:
        # Nothing was loaded: the model would still hold its initial weights
        # and every metric computed from it would be meaningless.
        raise RuntimeError(
            f"{label}: checkpoint {ckpt_path} matched none of the model's "
            f"{expected} state entries"
        )
    if missing or unexpected:
        print(
            f"warning: {label}: {len(missing)} missing and "
            f"{len(unexpected)} unexpected keys when loading {ckpt_path}"
        )


def _timed_inference(model, audio):
    """Run ``infer_chunked`` and return ``(output, elapsed_seconds)``."""
    # CUDA kernels run asynchronously; synchronise on both sides so the
    # measurement covers exactly the work queued by this inference.
    if audio.is_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    output = infer_chunked(model, audio)
    if audio.is_cuda:
        torch.cuda.synchronize()
    return output, time.perf_counter() - start


def eval_pair(name, teacher_yaml, teacher_ckpt, rokan_yaml, rokan_ckpt, wav_path):
    """Compare a teacher model with its RoKAN counterpart on one audio file.

    Both models are built from their YAML configs, loaded from their
    checkpoints, moved to ``DEVICE`` in eval mode, and run over the same
    audio with ``infer_chunked``. The element-wise difference of the two
    outputs is summarised as MAE, RMSE and maximum absolute error.

    Args:
        name: Label used in error messages and in the returned record.
        teacher_yaml: Config path of the teacher model.
        teacher_ckpt: Checkpoint path of the teacher model.
        rokan_yaml: Config path of the RoKAN model.
        rokan_ckpt: Checkpoint path of the RoKAN model.
        wav_path: Audio file to evaluate on.

    Returns:
        A dict with keys ``name``, ``sample_rate``, ``audio_seconds``,
        ``teacher_sec``, ``rokan_sec``, ``mae``, ``rmse`` and ``max_abs``.

    Raises:
        RuntimeError: If the two configs disagree on the sample rate, if a
            checkpoint matches none of its model's state entries, or if the
            two outputs have different shapes.
    """
    t_model, t_sr = build_model_from_yaml(teacher_yaml)
    r_model, r_sr = build_model_from_yaml(rokan_yaml)
    if t_sr != r_sr:
        raise RuntimeError(f"{name}: sample rate mismatch {t_sr} vs {r_sr}")

    _load_weights(t_model, teacher_ckpt, f"{name} (teacher)")
    _load_weights(r_model, rokan_ckpt, f"{name} (rokan)")
    t_model = t_model.to(DEVICE).eval()
    r_model = r_model.to(DEVICE).eval()

    audio = load_audio(wav_path, t_sr).to(DEVICE)
    t_out, t_sec = _timed_inference(t_model, audio)
    r_out, r_sec = _timed_inference(r_model, audio)

    # Refuse to compare outputs that would otherwise broadcast silently or
    # fail with an opaque size-mismatch error.
    if t_out.shape != r_out.shape:
        raise RuntimeError(
            f"{name}: output shape mismatch {tuple(t_out.shape)} vs {tuple(r_out.shape)}"
        )

    diff = (t_out - r_out).float()
    abs_diff = diff.abs()
    return {
        "name": name,
        "sample_rate": t_sr,
        "audio_seconds": float(audio.shape[-1]) / float(t_sr),
        "teacher_sec": t_sec,
        "rokan_sec": r_sec,
        "mae": abs_diff.mean().item(),
        "rmse": torch.sqrt((diff ** 2).mean()).item(),
        "max_abs": abs_diff.max().item(),
    }
|
|
|
|
def main():
    """Evaluate every configured teacher/RoKAN pair and write a Markdown report.

    The input audio is taken from ``--input_wav`` or, when that is empty,
    from the alphabetically first ``*.wav`` in the ``input`` directory next
    to this script. Each pair is evaluated independently: missing files or
    an exception during evaluation are recorded as a FAIL entry instead of
    aborting the run. The report is written to
    ``converted_models/eval_fidelity_report.md`` next to this script.

    Raises:
        RuntimeError: If no input wav can be found.
    """
    parser = argparse.ArgumentParser(description="Evaluate teacher vs RoKAN fidelity for BS and MelBand models")
    parser.add_argument("--input_wav", type=str, default="")
    args = parser.parse_args()

    root = Path(__file__).resolve().parent
    if args.input_wav:
        wav_path = Path(args.input_wav)
    else:
        candidates = sorted((root / "input").glob("*.wav"))
        if not candidates:
            raise RuntimeError("No wav in input/. Set --input_wav explicitly.")
        wav_path = candidates[0]
    # is_file() also rejects a directory, which exists() would let through.
    if not wav_path.is_file():
        raise RuntimeError(f"Input wav not found: {wav_path}")

    source_dir = root / "dataset" / "Models"
    converted_dir = root / "converted_models"
    # (report name, file stem shared by the teacher and RoKAN artefacts)
    model_stems = [
        ("BS-Rofo-SW-Fixed", "BS-Rofo-SW-Fixed"),
        ("MelBand denoise", "denoise_mel_band_roformer_aufr33_sdr_27.9959"),
    ]
    pairs = [
        (
            name,
            source_dir / f"{stem}.yaml",
            source_dir / f"{stem}.ckpt",
            converted_dir / f"{stem}_rokan.yaml",
            converted_dir / f"{stem}_rokan.ckpt",
        )
        for name, stem in model_stems
    ]

    rows = []
    for name, ty, tc, ry, rc in pairs:
        missing = [str(p) for p in (ty, tc, ry, rc) if not p.exists()]
        if missing:
            rows.append({"name": name, "error": "missing files: " + ", ".join(missing)})
            continue
        try:
            rows.append(eval_pair(name, ty, tc, ry, rc, wav_path))
        except Exception as e:
            # Top-level boundary: keep evaluating the remaining pairs. The
            # exception type is included because str(e) alone can be
            # uninformative (e.g. a KeyError renders as just the quoted key).
            rows.append({"name": name, "error": f"{type(e).__name__}: {e}"})

    lines = [
        "# RoKAN Fidelity Report",
        "",
        f"- input_wav: `{wav_path}`",
        f"- device: `{DEVICE}`",
        "",
    ]
    for r in rows:
        lines.append(f"## {r['name']}")
        if "error" in r:
            lines.append("- status: FAIL")
            lines.append(f"- error: `{r['error']}`")
        else:
            lines.append("- status: OK")
            lines.append(f"- sample_rate: {r['sample_rate']}")
            lines.append(f"- audio_seconds: {r['audio_seconds']:.2f}")
            lines.append(f"- teacher_infer_sec: {r['teacher_sec']:.2f}")
            lines.append(f"- rokan_infer_sec: {r['rokan_sec']:.2f}")
            lines.append(f"- mae: {r['mae']:.8f}")
            lines.append(f"- rmse: {r['rmse']:.8f}")
            lines.append(f"- max_abs: {r['max_abs']:.8f}")
        lines.append("")

    out_path = converted_dir / "eval_fidelity_report.md"
    # The directory is absent exactly when every pair is reported as missing;
    # create it so the report can still be written.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text("\n".join(lines), encoding="utf-8")
    print(f"wrote: {out_path}")
|
|
|
|
# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
|
|
|
|