# fireredvad-c / export_frvd.py
# Author: eschmidbauer; commit 221475f ("add export script")
#!/usr/bin/env -S uv run --extra cpu
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "fireredvad",
# "torch>=2.0.0",
# "numpy",
# "kaldiio",
# "huggingface_hub",
# ]
# ///
"""
Export FireRedVAD PyTorch checkpoints to the FRVD binary format consumed by
the pure-C inference engine (mod_fireredvad / fireredvad-dart).
Downloads the official FireRedTeam/FireRedVAD checkpoints from HuggingFace if
not already present, then writes:
fireredvad.bin "FRVD" + uint32 version + concatenated VAD then AED
weights as little-endian float32. Layout matches the
reader in fireredvad.c::fireredvad_load_weights.
fireredvad.json CMVN normalization stats {"means": [...], "inv_std": [...]}
converted from the kaldi-format cmvn.ark.
Usage:
python export_frvd.py
python export_frvd.py --output-dir /tmp/frvd
python export_frvd.py --model-root path/to/FireRedVAD
"""
import argparse
import json
import os
import struct
import sys
import numpy as np
import torch
# Where the official checkpoints live on HuggingFace, and where they are
# placed locally when no --model-root is given.
HF_REPO = "FireRedTeam/FireRedVAD"
DEFAULT_MODEL_ROOT = "pretrained_models/FireRedVAD"

# FRVD file header: magic bytes followed by the format version number.
FRVD_MAGIC = b"FRVD"
FRVD_VERSION = 1

# Network dimensions. Must match constants in fireredvad.h
D_IN = 80          # in_features of the first projection (dfsmn.fc1.0)
D_HIDDEN = 256     # hidden width of the per-block fc1 / dnns layers
D_PROJ = 128       # projection width fed into the FSMN memory filters
D_FILTER = 20      # kernel length of each depthwise memory filter
N_FSMN_BLOCKS = 7  # stacked DFSMN blocks (dfsmn.fsmns.0 .. dfsmn.fsmns.6)
N_BLOCKS = N_FSMN_BLOCKS + 1  # plus the input FSMN layer (dfsmn.fsmn1)
AED_NUM_CLASSES = 3  # output units of the AED head
def download_models(model_root):
    """Download FireRedVAD checkpoints from HuggingFace if missing.

    If both the ``Stream-VAD`` and ``AED`` subdirectories already exist under
    *model_root*, nothing is downloaded. Otherwise the whole ``HF_REPO``
    snapshot is fetched into *model_root* via
    ``huggingface_hub.snapshot_download``.

    Exits the process with status 1 if ``huggingface_hub`` is not installed,
    or if the downloaded snapshot still lacks one of the expected
    subdirectories.
    """
    needed = ["Stream-VAD", "AED"]

    def missing_subdirs():
        # Re-evaluated after the download, so it must not be cached.
        return [sub for sub in needed
                if not os.path.isdir(os.path.join(model_root, sub))]

    if not missing_subdirs():
        print(f"Models already present at: {model_root}")
        return
    try:
        from huggingface_hub import snapshot_download
    except ImportError:
        print("Error: huggingface_hub not found. Install with: pip install huggingface_hub",
              file=sys.stderr)
        sys.exit(1)
    print(f"Downloading {HF_REPO} to {model_root} ...")
    snapshot_download(repo_id=HF_REPO, local_dir=model_root)
    # Fail here with a clear message if the repo layout differs from what the
    # exporter expects, rather than with an obscure error while loading.
    still_missing = missing_subdirs()
    if still_missing:
        print(f"Error: download finished but {', '.join(still_missing)} "
              f"not found under {model_root}", file=sys.stderr)
        sys.exit(1)
def load_state_dict(model_dir):
    """Build the FireRedVAD ``DetectModel`` stored in *model_dir* and return its state_dict.

    The model is switched to eval mode before its parameters are read.
    ``fireredvad`` is imported lazily so the rest of the script can be
    imported without it.
    """
    from fireredvad.core.detect_model import DetectModel

    detector = DetectModel.from_pretrained(model_dir)
    detector.eval()
    weights = detector.state_dict()
    return weights
def linear_w(sd, prefix, expected_in, expected_out):
    """Return ``{prefix}.weight`` transposed and flattened for the FRVD reader.

    PyTorch Linear stores [out, in]; FRVD reader expects [in, out] flat, so the
    matrix is transposed before being flattened in row-major order.

    Args:
        sd: state_dict mapping parameter names to tensors.
        prefix: parameter name prefix, e.g. ``"dfsmn.fc1.0"``.
        expected_in: expected input feature count.
        expected_out: expected output feature count.

    Returns:
        1-D float32 numpy array of ``expected_in * expected_out`` values.

    Raises:
        KeyError: if ``{prefix}.weight`` is absent from ``sd``.
        ValueError: if the weight shape is not ``(expected_out, expected_in)``.
    """
    key = f"{prefix}.weight"
    w = sd[key]
    # Explicit raise instead of assert: asserts vanish under ``python -O`` and
    # a wrong shape would then silently corrupt the binary layout.
    if tuple(w.shape) != (expected_out, expected_in):
        raise ValueError(
            f"{key} shape {tuple(w.shape)} != ({expected_out}, {expected_in})")
    # detach() so tensors that require grad can still be converted to numpy.
    return w.detach().t().contiguous().cpu().float().numpy().reshape(-1)
def linear_b(sd, prefix, expected_out):
    """Return ``{prefix}.bias`` as a flat float32 numpy array.

    Args:
        sd: state_dict mapping parameter names to tensors.
        prefix: parameter name prefix, e.g. ``"dfsmn.fc1.0"``.
        expected_out: expected bias length.

    Raises:
        KeyError: if ``{prefix}.bias`` is absent from ``sd``.
        ValueError: if the bias shape is not ``(expected_out,)``.
    """
    key = f"{prefix}.bias"
    b = sd[key]
    # Explicit raise so the layout check survives ``python -O``.
    if tuple(b.shape) != (expected_out,):
        raise ValueError(f"{key} shape {tuple(b.shape)} != ({expected_out},)")
    return b.detach().cpu().float().numpy().reshape(-1)
def conv_filter(sd, prefix, expected_p, expected_k):
    """Return a depthwise Conv1d weight flattened to the FRVD ``[P, K]`` layout.

    Depthwise Conv1d weight is [P, 1, K]; dropping the singleton middle axis
    and flattening row-major gives the [P, K] flat layout the reader expects.

    Args:
        sd: state_dict mapping parameter names to tensors.
        prefix: parameter name prefix, e.g. ``"dfsmn.fsmn1.lookback_filter"``.
        expected_p: expected channel count.
        expected_k: expected kernel length.

    Raises:
        KeyError: if ``{prefix}.weight`` is absent from ``sd``.
        ValueError: if the weight shape is not ``(expected_p, 1, expected_k)``.
    """
    key = f"{prefix}.weight"
    w = sd[key]
    # Explicit raise so the layout check survives ``python -O``.
    if tuple(w.shape) != (expected_p, 1, expected_k):
        raise ValueError(
            f"{key} shape {tuple(w.shape)} != ({expected_p}, 1, {expected_k})")
    return w.detach().cpu().float().numpy().reshape(-1)
def _serialize_dfsmn(sd, out, *, lookahead, num_classes):
    """Append one DFSMN checkpoint's weights to ``out`` in FRVD reader order.

    Shared by the VAD and AED exporters, whose layouts differ only in whether
    each memory layer has a look-ahead filter and in the output head width.

    Args:
        sd: state_dict of the model being exported.
        out: bytearray that receives little-endian float32 data.
        lookahead: if True, write each layer's ``lookahead_filter`` right
            after its ``lookback_filter``.
        num_classes: out_features of the final ``out`` Linear layer.
    """
    def write(arr):
        out.extend(np.ascontiguousarray(arr, dtype=np.float32).tobytes())

    def write_memory(prefix):
        # Order matters: look-back first, then (optionally) look-ahead.
        write(conv_filter(sd, f"{prefix}.lookback_filter", D_PROJ, D_FILTER))
        if lookahead:
            write(conv_filter(sd, f"{prefix}.lookahead_filter", D_PROJ, D_FILTER))

    # Input projection
    write(linear_w(sd, "dfsmn.fc1.0", D_IN, D_HIDDEN))
    write(linear_b(sd, "dfsmn.fc1.0", D_HIDDEN))
    write(linear_w(sd, "dfsmn.fc2.0", D_HIDDEN, D_PROJ))
    write(linear_b(sd, "dfsmn.fc2.0", D_PROJ))
    write_memory("dfsmn.fsmn1")

    # DFSMN blocks (block fc2 has no bias)
    for i in range(N_FSMN_BLOCKS):
        block = f"dfsmn.fsmns.{i}"
        write(linear_w(sd, f"{block}.fc1.0", D_PROJ, D_HIDDEN))
        write(linear_b(sd, f"{block}.fc1.0", D_HIDDEN))
        write(linear_w(sd, f"{block}.fc2", D_HIDDEN, D_PROJ))
        write_memory(f"{block}.fsmn")

    # Output head: dnns[0] is Linear(P, H); self.out is Linear(H, num_classes)
    write(linear_w(sd, "dfsmn.dnns.0", D_PROJ, D_HIDDEN))
    write(linear_b(sd, "dfsmn.dnns.0", D_HIDDEN))
    write(linear_w(sd, "out", D_HIDDEN, num_classes))
    write(linear_b(sd, "out", num_classes))


def serialize_vad(sd, out):
    """Append VAD weights to `out` (a bytearray) in FRVD reader order.

    The streaming VAD uses look-back memory filters only and a single output
    unit.
    """
    _serialize_dfsmn(sd, out, lookahead=False, num_classes=1)


def serialize_aed(sd, out):
    """Append AED weights to `out`. AED adds lookahead filters and odim=3.

    Every memory layer's look-ahead filter is written immediately after its
    look-back filter, and the output head has ``AED_NUM_CLASSES`` units.
    """
    _serialize_dfsmn(sd, out, lookahead=True, num_classes=AED_NUM_CLASSES)
def export_weights(vad_dir, aed_dir, out_path):
    """Write the combined FRVD weight file to *out_path*.

    The file starts with ``FRVD_MAGIC`` and ``FRVD_VERSION`` packed as a
    little-endian uint32. The VAD tensors follow, then the AED tensors, in
    the order emitted by ``serialize_vad`` / ``serialize_aed``.
    """
    loaded = []
    for label, ckpt_dir in (("VAD", vad_dir), ("AED", aed_dir)):
        print(f"Loading {label} checkpoint from: {ckpt_dir}")
        loaded.append(load_state_dict(ckpt_dir))
    vad_sd, aed_sd = loaded

    payload = bytearray(FRVD_MAGIC + struct.pack("<I", FRVD_VERSION))
    serialize_vad(vad_sd, payload)
    serialize_aed(aed_sd, payload)

    with open(out_path, "wb") as fh:
        fh.write(payload)

    n_bytes = len(payload)
    print(f"Wrote {out_path} ({n_bytes:,} bytes, {n_bytes / (1024 * 1024):.2f} MB)")
def cmvn_from_stats(stats):
    """Convert a kaldi CMVN accumulator matrix into ``(means, inv_std)`` lists.

    As interpreted here, ``stats`` has shape ``(2, dim + 1)``. Row 0 holds
    per-bin sums, with the frame count in its last column. Row 1 holds
    per-bin sums of squares, so ``var = sum_sq / count - mean**2``.

    The per-bin arithmetic mirrors fireredvad.core.audio_feat.CMVN. It runs
    at the precision of the loaded stats (kaldi CMVN stats are usually
    stored as doubles); the C engine consumes the values as float32 anyway.

    Returns:
        Tuple of two lists of Python floats, each of length ``dim``.

    Raises:
        ValueError: if ``stats`` is not a ``(2, dim + 1)`` matrix with
            ``dim >= 1``, or if the frame count is below 1 (or NaN).
    """
    stats = np.asarray(stats)
    if stats.ndim != 2 or stats.shape[0] != 2 or stats.shape[1] < 2:
        raise ValueError(f"Unexpected cmvn shape: {stats.shape}")
    dim = stats.shape[1] - 1
    count = stats[0, dim]
    # Written as `not >=` so a NaN count is rejected too.
    if not count >= 1:
        raise ValueError(f"Bad frame count in cmvn.ark: {count}")
    # Variance floor guards against zero/negative variance from rounding.
    floor = np.float32(1e-20)
    means, inv_std = [], []
    for d in range(dim):
        mean = stats[0, d] / count
        var = (stats[1, d] / count) - mean * mean
        if var < floor:
            var = floor
        means.append(float(mean))
        inv_std.append(float(np.float32(1.0) / np.sqrt(var)))
    return means, inv_std


def export_cmvn(cmvn_ark, out_path):
    """Convert kaldi cmvn.ark to {"means": [...], "inv_std": [...]} JSON.

    Exits with status 1 if ``kaldiio`` is not installed.

    Raises:
        ValueError: if the stats matrix is malformed (see ``cmvn_from_stats``)
            or yields non-finite values, which strict JSON cannot represent.
    """
    try:
        import kaldiio
    except ImportError:
        print("Error: kaldiio not found. Install with: pip install kaldiio", file=sys.stderr)
        sys.exit(1)
    means, inv_std = cmvn_from_stats(kaldiio.load_mat(cmvn_ark))
    with open(out_path, "w", encoding="utf-8") as f:
        # allow_nan=False: fail loudly instead of emitting NaN/Infinity tokens
        # that are not valid JSON for the consuming engine.
        json.dump({"means": means, "inv_std": inv_std}, f, allow_nan=False)
    print(f"Wrote {out_path} ({len(means)} bins)")
def main():
    """Command-line entry point: write fireredvad.bin and fireredvad.json.

    Model directories default to ``{--model-root}/Stream-VAD`` and
    ``{--model-root}/AED``. Unless ``--skip-download`` is given, the
    checkpoints are downloaded whenever at least one of those defaults is in
    use. CMVN stats come from the VAD directory's cmvn.ark, falling back to
    the AED directory's. If neither exists, only the .bin is written.

    Exits with status 1 if a model directory does not exist.
    """
    parser = argparse.ArgumentParser(
        description="Export FireRedVAD PyTorch checkpoints to FRVD .bin + cmvn .json")
    parser.add_argument("--model-root", default=DEFAULT_MODEL_ROOT,
                        help=f"Root directory for downloaded models (default: {DEFAULT_MODEL_ROOT})")
    parser.add_argument("--vad-dir", default=None,
                        help="Path to streaming VAD model directory (default: {model-root}/Stream-VAD)")
    parser.add_argument("--aed-dir", default=None,
                        help="Path to AED model directory (default: {model-root}/AED)")
    parser.add_argument("--output-dir", default=".",
                        help="Output directory (default: current directory)")
    parser.add_argument("--skip-download", action="store_true",
                        help="Skip downloading models (use existing local files)")
    args = parser.parse_args()

    vad_dir = args.vad_dir or os.path.join(args.model_root, "Stream-VAD")
    aed_dir = args.aed_dir or os.path.join(args.model_root, "AED")

    # Download whenever either directory falls back to --model-root (e.g.
    # only --vad-dir was given): that default path must exist too. Truthiness
    # matches the `or` fallback above, so an empty override counts as unset.
    uses_model_root = not args.vad_dir or not args.aed_dir
    if not args.skip_download and uses_model_root:
        download_models(args.model_root)

    for label, path in (("VAD", vad_dir), ("AED", aed_dir)):
        if not os.path.isdir(path):
            print(f"Error: {label} model directory not found: {path}", file=sys.stderr)
            sys.exit(1)

    os.makedirs(args.output_dir, exist_ok=True)
    bin_path = os.path.join(args.output_dir, "fireredvad.bin")
    json_path = os.path.join(args.output_dir, "fireredvad.json")
    export_weights(vad_dir, aed_dir, bin_path)

    # A single CMVN file is exported: prefer the VAD copy, else the AED one.
    candidates = (os.path.join(vad_dir, "cmvn.ark"), os.path.join(aed_dir, "cmvn.ark"))
    cmvn_ark = next((p for p in candidates if os.path.isfile(p)), None)
    if cmvn_ark is None:
        print(f"Warning: cmvn.ark not found in {vad_dir} or {aed_dir}; skipping JSON",
              file=sys.stderr)
        return
    export_cmvn(cmvn_ark, json_path)
    print("Done.")


if __name__ == "__main__":
    main()