# Source: First-RoKAN-Model / evaluate_rokan_fidelity.py
# Uploaded by tekitoutarou ("Upload 12 files", commit f73ae00, verified)
import argparse
import time
from pathlib import Path
import soundfile as sf
import torch
import torchaudio.functional as AF
import yaml
from models.bs_roformer.bs_roformer import BSRoformer
from models.bs_roformer.mel_band_roformer import MelBandRoformer
# Device used for models and audio: CUDA when available, otherwise the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def load_cfg(path: Path):
    """Parse the YAML file at ``path`` and return the loaded object.

    The file is opened as UTF-8 text and parsed with ``yaml.FullLoader``.
    """
    # NOTE(review): FullLoader is more permissive than SafeLoader; only pass
    # config files from sources you trust.
    with path.open("r", encoding="utf-8") as handle:
        parsed = yaml.load(handle, Loader=yaml.FullLoader)
    return parsed
def clean_state_dict(ckpt_path: Path):
    """Load a checkpoint onto the CPU and return its weights as a plain dict.

    If the loaded object is a dict containing ``"state_dict"``, that entry is
    used; if the result is a dict containing ``"model"``, that entry is used in
    turn. A leading ``"model."`` is then removed from every key that has one;
    all other keys are kept unchanged.
    """
    prefix = "model."
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    payload = torch.load(str(ckpt_path), map_location="cpu")

    # Unwrap the two known container layouts, in this order.
    for wrapper_key in ("state_dict", "model"):
        if isinstance(payload, dict) and wrapper_key in payload:
            payload = payload[wrapper_key]

    return {
        (name[len(prefix):] if name.startswith(prefix) else name): value
        for name, value in payload.items()
    }
def build_model_from_yaml(yaml_path: Path):
    """Instantiate the model described by a YAML config file.

    The config must have a ``model`` section (with at least ``dim`` and
    ``depth``) and an ``audio`` section containing ``sample_rate``. Every other
    model option falls back to the default listed below when it is absent.

    When the ``model`` section contains ``num_bands`` a ``MelBandRoformer`` is
    built; otherwise a ``BSRoformer`` is built. ``flash_attn`` and
    ``use_torch_checkpoint`` are always passed as ``False``.

    Returns:
        A ``(model, sample_rate)`` tuple, where ``sample_rate`` is the value
        of ``audio.sample_rate`` from the config.
    """
    cfg = load_cfg(yaml_path)
    model_cfg = cfg["model"]
    audio_cfg = cfg["audio"]
    opt = model_cfg.get

    kwargs = {
        "dim": model_cfg["dim"],
        "depth": model_cfg["depth"],
        "stereo": opt("stereo", True),
        "num_stems": opt("num_stems", 1),
        "time_transformer_depth": opt("time_transformer_depth", 1),
        "freq_transformer_depth": opt("freq_transformer_depth", 1),
        "linear_transformer_depth": opt("linear_transformer_depth", 0),
        "dim_head": opt("dim_head", 64),
        "heads": opt("heads", 8),
        "attn_dropout": opt("attn_dropout", 0.0),
        "ff_dropout": opt("ff_dropout", 0.0),
        "flash_attn": False,
        "dim_freqs_in": opt("dim_freqs_in", 1025),
        "stft_n_fft": opt("stft_n_fft", 2048),
        "stft_hop_length": opt("stft_hop_length", 512),
        "stft_win_length": opt("stft_win_length", 2048),
        "stft_normalized": opt("stft_normalized", False),
        "mask_estimator_depth": opt("mask_estimator_depth", 2),
        "multi_stft_resolution_loss_weight": opt("multi_stft_resolution_loss_weight", 1.0),
        "multi_stft_resolutions_window_sizes": tuple(
            opt("multi_stft_resolutions_window_sizes", (4096, 2048, 1024, 512, 256))
        ),
        "multi_stft_hop_size": opt("multi_stft_hop_size", 147),
        "multi_stft_normalized": opt("multi_stft_normalized", False),
        "mlp_expansion_factor": opt("mlp_expansion_factor", 4),
        "use_torch_checkpoint": False,
        "skip_connection": opt("skip_connection", False),
        "sage_attention": opt("sage_attention", False),
        "use_kan": opt("use_kan", False),
        "kan_grid_size": opt("kan_grid_size", 8),
    }

    # Only forwarded when the config supplies it explicitly.
    if "freqs_per_bands" in model_cfg:
        kwargs["freqs_per_bands"] = tuple(model_cfg["freqs_per_bands"])

    # No ``num_bands`` entry means this is a band-split (BS) config.
    if "num_bands" not in model_cfg:
        return BSRoformer(**kwargs), audio_cfg["sample_rate"]

    kwargs["num_bands"] = model_cfg["num_bands"]
    # The model section's sample rate wins; otherwise use the audio section's,
    # and 44100 when neither provides one.
    kwargs["sample_rate"] = opt("sample_rate", audio_cfg.get("sample_rate", 44100))
    return MelBandRoformer(**kwargs), audio_cfg["sample_rate"]
def load_audio(path: Path, target_sr: int):
    """Read an audio file and return it as a two-channel float tensor.

    The file is read with ``soundfile`` (always as a 2-D array), transposed so
    that channels come first, converted to float32, and resampled to
    ``target_sr`` when the file's rate differs. A single-channel signal is
    duplicated to two channels; anything beyond the first two channels is
    dropped.

    Returns:
        The waveform with a leading axis of size 1 added in front of the
        channel axis.
    """
    samples, source_sr = sf.read(str(path), always_2d=True)
    waveform = torch.from_numpy(samples.T).float()

    if source_sr != target_sr:
        waveform = AF.resample(waveform, source_sr, target_sr)

    num_channels = waveform.shape[0]
    if num_channels == 1:
        waveform = waveform.repeat(2, 1)
    elif num_channels > 2:
        waveform = waveform[:2, :]

    return waveform.unsqueeze(0)
def infer_chunked(model, audio, chunk_size=353280, context=132096):
    """Run ``model`` over ``audio`` window by window and stitch the centres.

    The input is replicate-padded by ``context`` samples on both sides. Each
    window of ``chunk_size`` samples is fed to the model, and only the middle
    ``chunk_size - 2 * context`` samples of its output are copied into the
    result, so every output sample was predicted with ``context`` samples of
    surrounding signal on each side. A final window that is too short is
    replicate-padded on the right up to ``chunk_size``.

    On a CUDA input the model runs under float16 autocast. A 4-D model output
    is reduced to its first entry along axis 1; a 3-D output is used as is.

    Args:
        model: Callable taking a 3-D tensor and returning a 3-D or 4-D tensor.
        audio: 3-D tensor whose last axis is time.
        chunk_size: Number of samples per model call.
        context: Samples of context discarded from each side of a window.

    Returns:
        Tensor with the model's batch and channel sizes and the same length as
        ``audio`` along the last axis (``None`` if no window was processed).

    Raises:
        RuntimeError: If ``chunk_size <= 2 * context`` or the model output is
            neither 3-D nor 4-D.
    """
    hop = chunk_size - 2 * context
    if hop <= 0:
        raise RuntimeError("chunk_size must be larger than 2*context")

    pad_fn = torch.nn.functional.pad
    total = audio.shape[-1]
    padded = pad_fn(audio, (context, context), mode="replicate")

    result = None
    for start in range(0, total, hop):
        stop = min(start + hop, total)
        keep = stop - start

        window = padded[:, :, start : start + chunk_size]
        shortfall = chunk_size - window.shape[-1]
        if shortfall > 0:
            window = pad_fn(window, (0, shortfall), mode="replicate")

        with torch.inference_mode():
            if audio.is_cuda:
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    pred = model(window)
            else:
                pred = model(window)

        # Bring the prediction to three axes; a 4-D result carries an extra
        # axis at position 1 (e.g. multi-stem checkpoints) and only its first
        # entry is kept.
        if pred.ndim == 4:
            pred = pred[:, 0, :, :]
        elif pred.ndim != 3:
            raise RuntimeError(f"Unsupported output ndim={pred.ndim}, shape={tuple(pred.shape)}")

        # Allocate lazily: batch and channel sizes are only known after the
        # first model call.
        if result is None:
            result = torch.zeros((pred.shape[0], pred.shape[1], total), device=audio.device)
        result[:, :, start:stop] = pred[:, :, context : context + keep]

    return result
def _timed_infer(model, audio):
    """Run ``infer_chunked`` and return ``(output, elapsed_seconds)``.

    CUDA kernels are launched asynchronously, so the device is synchronised
    before the clock starts and again before it stops; otherwise the measured
    time would not cover the work that was actually queued.
    """
    on_cuda = audio.is_cuda
    if on_cuda:
        torch.cuda.synchronize(audio.device)
    start = time.perf_counter()
    output = infer_chunked(model, audio)
    if on_cuda:
        torch.cuda.synchronize(audio.device)
    return output, time.perf_counter() - start


def eval_pair(name, teacher_yaml, teacher_ckpt, rokan_yaml, rokan_ckpt, wav_path):
    """Compare a teacher model with its RoKAN counterpart on one audio file.

    Both models are built from their YAML configs, loaded from their
    checkpoints, moved to ``DEVICE`` in eval mode and run over the same audio
    with ``infer_chunked``. The error metrics are computed on the elementwise
    difference between the two outputs.

    Args:
        name: Label copied into the result.
        teacher_yaml: Path to the teacher model config.
        teacher_ckpt: Path to the teacher checkpoint.
        rokan_yaml: Path to the RoKAN model config.
        rokan_ckpt: Path to the RoKAN checkpoint.
        wav_path: Path to the audio file used as input.

    Returns:
        Dict with keys ``name``, ``sample_rate``, ``audio_seconds``,
        ``teacher_sec``, ``rokan_sec``, ``mae``, ``rmse`` and ``max_abs``.

    Raises:
        RuntimeError: If the two configs declare different sample rates.
    """
    t_model, t_sr = build_model_from_yaml(teacher_yaml)
    r_model, r_sr = build_model_from_yaml(rokan_yaml)
    if t_sr != r_sr:
        raise RuntimeError(f"{name}: sample rate mismatch {t_sr} vs {r_sr}")

    # NOTE(review): strict=False silently ignores missing or unexpected keys,
    # so a checkpoint that only partially matches still "loads".
    t_model.load_state_dict(clean_state_dict(teacher_ckpt), strict=False)
    r_model.load_state_dict(clean_state_dict(rokan_ckpt), strict=False)
    t_model = t_model.to(DEVICE).eval()
    r_model = r_model.to(DEVICE).eval()

    audio = load_audio(wav_path, t_sr).to(DEVICE)

    t_out, t_sec = _timed_infer(t_model, audio)
    r_out, r_sec = _timed_infer(r_model, audio)

    diff = (t_out - r_out).float()
    abs_diff = diff.abs()
    return {
        "name": name,
        "sample_rate": t_sr,
        "audio_seconds": float(audio.shape[-1]) / float(t_sr),
        "teacher_sec": t_sec,
        "rokan_sec": r_sec,
        "mae": abs_diff.mean().item(),
        "rmse": torch.sqrt((diff ** 2).mean()).item(),
        "max_abs": abs_diff.max().item(),
    }
def main():
    """Evaluate every teacher/RoKAN pair and write a Markdown report.

    The input audio is taken from ``--input_wav`` or, when that is not given,
    from the first ``*.wav`` (sorted by name) in the ``input`` directory next
    to this script. Each configured pair is evaluated with ``eval_pair``; a
    pair with missing files or one that raises is recorded as a failure
    instead of aborting the run. The report is written to
    ``converted_models/eval_fidelity_report.md`` next to this script.

    Raises:
        RuntimeError: If no input audio can be found.
    """
    parser = argparse.ArgumentParser(description="Evaluate teacher vs RoKAN fidelity for BS and MelBand models")
    parser.add_argument("--input_wav", type=str, default="")
    args = parser.parse_args()

    root = Path(__file__).resolve().parent
    input_dir = root / "input"

    if args.input_wav:
        wav_path = Path(args.input_wav)
    else:
        wavs = sorted(input_dir.glob("*.wav"))
        if not wavs:
            raise RuntimeError("No wav in input/. Set --input_wav explicitly.")
        wav_path = wavs[0]
    if not wav_path.exists():
        raise RuntimeError(f"Input wav not found: {wav_path}")

    # (label, teacher yaml, teacher ckpt, rokan yaml, rokan ckpt)
    pairs = [
        (
            "BS-Rofo-SW-Fixed",
            root / "dataset/Models/BS-Rofo-SW-Fixed.yaml",
            root / "dataset/Models/BS-Rofo-SW-Fixed.ckpt",
            root / "converted_models/BS-Rofo-SW-Fixed_rokan.yaml",
            root / "converted_models/BS-Rofo-SW-Fixed_rokan.ckpt",
        ),
        (
            "MelBand denoise",
            root / "dataset/Models/denoise_mel_band_roformer_aufr33_sdr_27.9959.yaml",
            root / "dataset/Models/denoise_mel_band_roformer_aufr33_sdr_27.9959.ckpt",
            root / "converted_models/denoise_mel_band_roformer_aufr33_sdr_27.9959_rokan.yaml",
            root / "converted_models/denoise_mel_band_roformer_aufr33_sdr_27.9959_rokan.ckpt",
        ),
    ]

    rows = []
    for name, ty, tc, ry, rc in pairs:
        missing = [str(p) for p in (ty, tc, ry, rc) if not p.exists()]
        if missing:
            rows.append({"name": name, "error": "missing files: " + ", ".join(missing)})
            continue
        try:
            rows.append(eval_pair(name, ty, tc, ry, rc, wav_path))
        except Exception as e:
            # Deliberate best effort: one failing pair must not prevent the
            # others from being reported. Fall back to repr() so exceptions
            # without a message still leave a usable trace in the report.
            rows.append({"name": name, "error": str(e) or repr(e)})

    out_path = root / "converted_models" / "eval_fidelity_report.md"

    lines = [
        "# RoKAN Fidelity Report",
        "",
        f"- input_wav: `{wav_path}`",
        f"- device: `{DEVICE}`",
        "",
    ]
    for r in rows:
        lines.append(f"## {r['name']}")
        if "error" in r:
            lines.append("- status: FAIL")
            lines.append(f"- error: `{r['error']}`")
        else:
            lines.append("- status: OK")
            lines.append(f"- sample_rate: {r['sample_rate']}")
            lines.append(f"- audio_seconds: {r['audio_seconds']:.2f}")
            lines.append(f"- teacher_infer_sec: {r['teacher_sec']:.2f}")
            lines.append(f"- rokan_infer_sec: {r['rokan_sec']:.2f}")
            lines.append(f"- mae: {r['mae']:.8f}")
            lines.append(f"- rmse: {r['rmse']:.8f}")
            lines.append(f"- max_abs: {r['max_abs']:.8f}")
        lines.append("")

    # The output directory may not exist (that is exactly the case in which
    # every pair is reported as "missing files"), so create it before writing.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text("\n".join(lines), encoding="utf-8")
    print(f"wrote: {out_path}")
# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()