#!/usr/bin/env -S uv run --extra cpu
# /// script
# requires-python = ">=3.10"
# dependencies = [
#   "fireredvad",
#   "torch>=2.0.0",
#   "numpy",
#   "kaldiio",
#   "huggingface_hub",
# ]
# ///
"""
Export FireRedVAD PyTorch checkpoints to the FRVD binary format consumed by
the pure-C inference engine (mod_fireredvad / fireredvad-dart).

Downloads the official FireRedTeam/FireRedVAD checkpoints from HuggingFace if
not already present, then writes:

  fireredvad.bin   "FRVD" + uint32 version + concatenated VAD then AED
                   weights, all little-endian float32. Layout matches the
                   reader in fireredvad.c::fireredvad_load_weights.
  fireredvad.json  CMVN normalization stats {"means": [...], "inv_std": [...]}
                   converted from the kaldi-format cmvn.ark.

Usage:
  python export_frvd.py
  python export_frvd.py --output-dir /tmp/frvd
  python export_frvd.py --model-root path/to/FireRedVAD
"""

import argparse
import json
import os
import struct
import sys

import numpy as np
import torch

HF_REPO = "FireRedTeam/FireRedVAD"
DEFAULT_MODEL_ROOT = "pretrained_models/FireRedVAD"

FRVD_MAGIC = b"FRVD"
FRVD_VERSION = 1

# Must match constants in fireredvad.h
D_IN = 80
D_HIDDEN = 256
D_PROJ = 128
D_FILTER = 20
N_BLOCKS = 8  # 1 input FSMN + 7 DFSMN blocks
N_FSMN_BLOCKS = N_BLOCKS - 1
AED_NUM_CLASSES = 3


def download_models(model_root):
    """Download the FireRedVAD checkpoints from HuggingFace if missing.

    Nothing is fetched when both the ``Stream-VAD`` and ``AED``
    sub-directories already exist under *model_root*. Exits the process with
    status 1 if ``huggingface_hub`` is not installed.
    """
    needed = ["Stream-VAD", "AED"]
    if all(os.path.isdir(os.path.join(model_root, sub)) for sub in needed):
        print(f"Models already present at: {model_root}")
        return
    try:
        from huggingface_hub import snapshot_download
    except ImportError:
        print("Error: huggingface_hub not found. Install with: pip install huggingface_hub",
              file=sys.stderr)
        sys.exit(1)
    print(f"Downloading {HF_REPO} to {model_root} ...")
    snapshot_download(repo_id=HF_REPO, local_dir=model_root)


def load_state_dict(model_dir):
    """Load a FireRedVAD checkpoint from *model_dir* and return its state_dict.

    ``fireredvad`` is imported lazily so the serialization helpers in this
    module stay importable without it.
    """
    from fireredvad.core.detect_model import DetectModel

    model = DetectModel.from_pretrained(model_dir)
    model.eval()
    return model.state_dict()


def _fetch(sd, key, expected_shape):
    """Return ``sd[key]`` after verifying it has exactly *expected_shape*.

    Explicit raises (rather than ``assert``) keep these checks active under
    ``python -O``, so a mismatched checkpoint can never yield a silently
    corrupt .bin file.

    Raises:
        KeyError: if *key* is absent from *sd*.
        ValueError: if the tensor's shape differs from *expected_shape*.
    """
    try:
        tensor = sd[key]
    except KeyError:
        raise KeyError(f"checkpoint is missing tensor {key!r}") from None
    expected_shape = tuple(expected_shape)
    if tuple(tensor.shape) != expected_shape:
        raise ValueError(f"{key} shape {tuple(tensor.shape)} != {expected_shape}")
    return tensor


def _to_f32(tensor):
    """Flatten *tensor* into a 1-D float32 numpy array on the CPU."""
    return tensor.detach().cpu().float().numpy().reshape(-1)


def linear_w(sd, prefix, expected_in, expected_out):
    """Return ``{prefix}.weight`` flattened in FRVD ``[in, out]`` order.

    PyTorch ``nn.Linear`` stores its weight as ``[out, in]``; the FRVD reader
    expects the transpose, flattened row-major.

    Raises:
        KeyError / ValueError: see ``_fetch``.
    """
    w = _fetch(sd, f"{prefix}.weight", (expected_out, expected_in))
    return _to_f32(w.t().contiguous())


def linear_b(sd, prefix, expected_out):
    """Return ``{prefix}.bias`` (shape ``[expected_out]``) as flat float32.

    Raises:
        KeyError / ValueError: see ``_fetch``.
    """
    return _to_f32(_fetch(sd, f"{prefix}.bias", (expected_out,)))


def conv_filter(sd, prefix, expected_p, expected_k):
    """Return a depthwise Conv1d filter flattened in FRVD ``[P, K]`` order.

    The PyTorch weight is ``[P, 1, K]``; dropping the singleton axis leaves
    the row-major flattening unchanged.

    Raises:
        KeyError / ValueError: see ``_fetch``.
    """
    w = _fetch(sd, f"{prefix}.weight", (expected_p, 1, expected_k))
    return _to_f32(w)


def _append_f32(out, arr):
    """Append *arr* to the bytearray *out* as little-endian float32 bytes.

    The dtype is pinned to ``"<f4"`` so the file layout does not depend on
    the host's native byte order.
    """
    out.extend(np.ascontiguousarray(arr, dtype="<f4").tobytes())


def _serialize_dfsmn(sd, out, *, with_lookahead, n_out):
    """Append one DFSMN model's weights to *out* in FRVD reader order.

    Shared by the VAD and AED exporters, which differ only in whether every
    FSMN filter bank includes a lookahead filter and in the output width
    *n_out*. The append order below is the on-disk layout; do not reorder.
    """
    def put(arr):
        _append_f32(out, arr)

    def put_filters(prefix):
        # Lookback is always written before lookahead for the same FSMN.
        put(conv_filter(sd, f"{prefix}.lookback_filter", D_PROJ, D_FILTER))
        if with_lookahead:
            put(conv_filter(sd, f"{prefix}.lookahead_filter", D_PROJ, D_FILTER))

    # Input projection: Linear(D_IN, H) -> Linear(H, P) -> input FSMN filters.
    put(linear_w(sd, "dfsmn.fc1.0", D_IN, D_HIDDEN))
    put(linear_b(sd, "dfsmn.fc1.0", D_HIDDEN))
    put(linear_w(sd, "dfsmn.fc2.0", D_HIDDEN, D_PROJ))
    put(linear_b(sd, "dfsmn.fc2.0", D_PROJ))
    put_filters("dfsmn.fsmn1")

    # DFSMN blocks (block fc2 has no bias).
    for i in range(N_FSMN_BLOCKS):
        block = f"dfsmn.fsmns.{i}"
        put(linear_w(sd, f"{block}.fc1.0", D_PROJ, D_HIDDEN))
        put(linear_b(sd, f"{block}.fc1.0", D_HIDDEN))
        put(linear_w(sd, f"{block}.fc2", D_HIDDEN, D_PROJ))
        put_filters(f"{block}.fsmn")

    # Output head: dnns[0] is Linear(P, H); self.out is Linear(H, n_out).
    put(linear_w(sd, "dfsmn.dnns.0", D_PROJ, D_HIDDEN))
    put(linear_b(sd, "dfsmn.dnns.0", D_HIDDEN))
    put(linear_w(sd, "out", D_HIDDEN, n_out))
    put(linear_b(sd, "out", n_out))


def serialize_vad(sd, out):
    """Append streaming-VAD weights to *out* (a bytearray) in FRVD reader order.

    The VAD uses lookback-only FSMN filters and a single-logit output head.
    """
    _serialize_dfsmn(sd, out, with_lookahead=False, n_out=1)


def serialize_aed(sd, out):
    """Append AED weights to *out*. AED adds lookahead filters and odim=3."""
    _serialize_dfsmn(sd, out, with_lookahead=True, n_out=AED_NUM_CLASSES)


def export_weights(vad_dir, aed_dir, out_path):
    """Write the FRVD weight file combining the VAD and AED checkpoints.

    Layout: ``FRVD`` magic, little-endian uint32 ``FRVD_VERSION``, then the
    VAD weights followed by the AED weights (little-endian float32). The
    whole image is built in memory before *out_path* is opened, so a
    checkpoint validation error never leaves a partially-written file.
    """
    print(f"Loading VAD checkpoint from: {vad_dir}")
    vad_sd = load_state_dict(vad_dir)
    print(f"Loading AED checkpoint from: {aed_dir}")
    aed_sd = load_state_dict(aed_dir)

    buf = bytearray()
    buf.extend(FRVD_MAGIC)
    buf.extend(struct.pack("<I", FRVD_VERSION))
    serialize_vad(vad_sd, buf)
    serialize_aed(aed_sd, buf)

    with open(out_path, "wb") as f:
        f.write(buf)
    print(f"Wrote {out_path} ({len(buf)} bytes)")


def export_cmvn(ark_path, out_path):
    """Convert a Kaldi global-CMVN stats file to the engine's JSON format.

    Expects the Kaldi accumulator layout: a ``2 x (dim + 1)`` matrix whose
    row 0 holds per-bin feature sums with the frame count in the last column
    and row 1 holds per-bin sums of squares. Writes
    ``{"means": [...], "inv_std": [...]}`` with one entry per feature bin.

    Raises:
        ValueError: if the matrix shape, bin count or frame count is invalid.
    """
    import kaldiio

    stats = np.asarray(kaldiio.load_mat(ark_path))
    if stats.ndim != 2 or stats.shape[0] != 2 or stats.shape[1] < 2:
        raise ValueError(f"Unexpected cmvn.ark stats shape: {stats.shape}")
    dim = stats.shape[1] - 1
    if dim != D_IN:
        raise ValueError(f"cmvn.ark has {dim} bins but the model expects {D_IN}")
    count = stats[0, dim]
    # `not count >= 1` (rather than `count < 1`) also rejects a NaN count.
    if not count >= 1:
        raise ValueError(f"Bad frame count in cmvn.ark: {count}")

    # Compute the same way as fireredvad.core.audio_feat.CMVN so values stay
    # in float32 precision (the C engine consumes them as float32 anyway).
    floor = np.float32(1e-20)
    means, inv_std = [], []
    for d in range(dim):
        mean = stats[0, d] / count
        var = (stats[1, d] / count) - mean * mean
        if var < floor:
            var = floor
        means.append(float(mean))
        inv_std.append(float(np.float32(1.0) / np.sqrt(var)))

    with open(out_path, "w") as f:
        json.dump({"means": means, "inv_std": inv_std}, f)
    print(f"Wrote {out_path} ({dim} bins)")


def main():
    """CLI entry point: export fireredvad.bin and fireredvad.json."""
    parser = argparse.ArgumentParser(
        description="Export FireRedVAD PyTorch checkpoints to FRVD .bin + cmvn .json")
    parser.add_argument("--model-root", default=DEFAULT_MODEL_ROOT,
                        help=f"Root directory for downloaded models (default: {DEFAULT_MODEL_ROOT})")
    parser.add_argument("--vad-dir", default=None,
                        help="Path to streaming VAD model directory (default: {model-root}/Stream-VAD)")
    parser.add_argument("--aed-dir", default=None,
                        help="Path to AED model directory (default: {model-root}/AED)")
    parser.add_argument("--output-dir", default=".",
                        help="Output directory (default: current directory)")
    parser.add_argument("--skip-download", action="store_true",
                        help="Skip downloading models (use existing local files)")
    args = parser.parse_args()

    # Download whenever at least one model directory is taken from
    # --model-root; explicitly supplied directories are used as-is.
    if not args.skip_download and (args.vad_dir is None or args.aed_dir is None):
        download_models(args.model_root)

    vad_dir = args.vad_dir or os.path.join(args.model_root, "Stream-VAD")
    aed_dir = args.aed_dir or os.path.join(args.model_root, "AED")

    os.makedirs(args.output_dir, exist_ok=True)
    bin_path = os.path.join(args.output_dir, "fireredvad.bin")
    json_path = os.path.join(args.output_dir, "fireredvad.json")

    export_weights(vad_dir, aed_dir, bin_path)

    # Prefer the VAD's CMVN stats; fall back to the AED directory's copy.
    cmvn_ark = os.path.join(vad_dir, "cmvn.ark")
    if not os.path.isfile(cmvn_ark):
        cmvn_ark = os.path.join(aed_dir, "cmvn.ark")
    if not os.path.isfile(cmvn_ark):
        print(f"Warning: cmvn.ark not found in {vad_dir} or {aed_dir}; skipping JSON",
              file=sys.stderr)
        return
    export_cmvn(cmvn_ark, json_path)

    print("Done.")


if __name__ == "__main__":
    main()