File size: 8,987 Bytes
221475f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 | #!/usr/bin/env -S uv run --extra cpu
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "fireredvad",
# "torch>=2.0.0",
# "numpy",
# "kaldiio",
# "huggingface_hub",
# ]
# ///
"""
Export FireRedVAD PyTorch checkpoints to the FRVD binary format consumed by
the pure-C inference engine (mod_fireredvad / fireredvad-dart).
Downloads the official FireRedTeam/FireRedVAD checkpoints from HuggingFace if
not already present, then writes:
fireredvad.bin "FRVD" + uint32 version + concatenated VAD then AED
weights as little-endian float32. Layout matches the
reader in fireredvad.c::fireredvad_load_weights.
fireredvad.json CMVN normalization stats {"means": [...], "inv_std": [...]}
converted from the Kaldi-format cmvn.ark (taken from the VAD model
directory, falling back to the AED directory if it is absent there).
Usage:
python export_frvd.py
python export_frvd.py --output-dir /tmp/frvd
python export_frvd.py --model-root path/to/FireRedVAD
"""
import argparse
import json
import os
import struct
import sys
import numpy as np
import torch
# Hugging Face repository with the official checkpoints, and the local
# directory it is mirrored into by default.
HF_REPO = "FireRedTeam/FireRedVAD"
DEFAULT_MODEL_ROOT = "pretrained_models/FireRedVAD"

# FRVD container header: 4-byte magic, then the format version (written as a
# little-endian uint32 by export_weights).
FRVD_MAGIC = "FRVD".encode("ascii")
FRVD_VERSION = 1

# Network dimensions -- these must match the constants in fireredvad.h.
# D_IN is the input width of dfsmn.fc1; D_FILTER is the FSMN filter length.
D_IN, D_HIDDEN, D_PROJ, D_FILTER = 80, 256, 128, 20

# One input FSMN block followed by N_FSMN_BLOCKS DFSMN blocks; the serializers
# iterate over the latter.
N_BLOCKS = 8
N_FSMN_BLOCKS = N_BLOCKS - 1

# Width of the AED model's output layer.
AED_NUM_CLASSES = 3
def download_models(model_root):
    """Download the FireRedVAD checkpoints into *model_root* if they are missing.

    The download is skipped when both the ``Stream-VAD`` and ``AED``
    subdirectories already exist under *model_root*; otherwise the full
    ``HF_REPO`` snapshot is fetched with ``huggingface_hub.snapshot_download``.

    Args:
        model_root: Local directory that holds (or will receive) the models.

    Exits the process with status 1 if ``huggingface_hub`` is not installed,
    or if the expected subdirectories are still absent after downloading.
    """
    needed = ("Stream-VAD", "AED")

    def missing_subdirs():
        # Entries of `needed` that are not yet directories under model_root.
        return [sub for sub in needed
                if not os.path.isdir(os.path.join(model_root, sub))]

    if not missing_subdirs():
        print(f"Models already present at: {model_root}")
        return
    try:
        from huggingface_hub import snapshot_download
    except ImportError:
        # Diagnostics go to stderr so they are not mixed into normal output.
        print("Error: huggingface_hub not found. Install with: pip install huggingface_hub",
              file=sys.stderr)
        sys.exit(1)
    print(f"Downloading {HF_REPO} to {model_root} ...")
    snapshot_download(repo_id=HF_REPO, local_dir=model_root)
    # Fail fast with a clear message rather than a later, harder-to-read
    # checkpoint-loading error if the snapshot lacks the expected layout.
    still_missing = missing_subdirs()
    if still_missing:
        print(f"Error: {', '.join(still_missing)} not found under {model_root} "
              f"after downloading {HF_REPO}", file=sys.stderr)
        sys.exit(1)
def load_state_dict(model_dir):
    """Build a FireRedVAD ``DetectModel`` from *model_dir* and return its state_dict.

    ``fireredvad`` is imported inside the function, so importing this module
    does not itself require the package.
    """
    from fireredvad.core.detect_model import DetectModel

    detector = DetectModel.from_pretrained(model_dir)
    # Put the model in eval mode before snapshotting its parameters
    # (presumably a torch nn.Module -- TODO confirm in fireredvad).
    detector.eval()
    weights = detector.state_dict()
    return weights
def linear_w(sd, prefix, expected_in, expected_out):
    """Return ``{prefix}.weight`` as a flat float32 array in FRVD ``[in, out]`` order.

    PyTorch ``nn.Linear`` stores its weight as ``[out, in]``; the FRVD reader
    expects the transposed ``[in, out]`` matrix, flattened row-major.

    Args:
        sd: Mapping from parameter names to tensors (a state_dict).
        prefix: Module prefix, e.g. ``"dfsmn.fc1.0"``.
        expected_in: Expected number of input features.
        expected_out: Expected number of output features.

    Returns:
        1-D numpy float32 array of length ``expected_in * expected_out``.

    Raises:
        KeyError: If ``{prefix}.weight`` is not in *sd*.
        ValueError: If the weight shape is not ``(expected_out, expected_in)``.
    """
    w = sd[f"{prefix}.weight"]
    # Raise rather than assert: asserts are stripped under `python -O`, which
    # would let a mismatched checkpoint silently produce a corrupt .bin.
    if tuple(w.shape) != (expected_out, expected_in):
        raise ValueError(
            f"{prefix}.weight shape {tuple(w.shape)} != ({expected_out}, {expected_in})")
    # detach() so tensors that require grad can still be converted to numpy.
    return w.detach().t().contiguous().cpu().float().numpy().reshape(-1)
def linear_b(sd, prefix, expected_out):
    """Return ``{prefix}.bias`` as a flat float32 array.

    Args:
        sd: Mapping from parameter names to tensors (a state_dict).
        prefix: Module prefix, e.g. ``"dfsmn.fc1.0"``.
        expected_out: Expected bias length.

    Returns:
        1-D numpy float32 array of length ``expected_out``.

    Raises:
        KeyError: If ``{prefix}.bias`` is not in *sd*.
        ValueError: If the bias shape is not ``(expected_out,)``.
    """
    b = sd[f"{prefix}.bias"]
    # Raise rather than assert: asserts are stripped under `python -O`.
    if tuple(b.shape) != (expected_out,):
        raise ValueError(f"{prefix}.bias shape {tuple(b.shape)} != ({expected_out},)")
    # detach() so tensors that require grad can still be converted to numpy.
    return b.detach().cpu().float().numpy().reshape(-1)
def conv_filter(sd, prefix, expected_p, expected_k):
    """Return a depthwise Conv1d filter as a flat float32 array in FRVD ``[P, K]`` order.

    The Conv1d weight is stored as ``[P, 1, K]``; dropping the singleton axis
    and flattening row-major yields the ``[P, K]`` layout the FRVD reader uses.

    Args:
        sd: Mapping from parameter names to tensors (a state_dict).
        prefix: Filter module prefix, e.g. ``"dfsmn.fsmn1.lookback_filter"``.
        expected_p: Expected number of channels.
        expected_k: Expected filter length.

    Returns:
        1-D numpy float32 array of length ``expected_p * expected_k``.

    Raises:
        KeyError: If ``{prefix}.weight`` is not in *sd*.
        ValueError: If the weight shape is not ``(expected_p, 1, expected_k)``.
    """
    w = sd[f"{prefix}.weight"]
    # Raise rather than assert: asserts are stripped under `python -O`.
    if tuple(w.shape) != (expected_p, 1, expected_k):
        raise ValueError(
            f"{prefix}.weight shape {tuple(w.shape)} != ({expected_p}, 1, {expected_k})")
    # detach() so tensors that require grad can still be converted to numpy.
    return w.detach().cpu().float().numpy().reshape(-1)
def _serialize_dfsmn(sd, out, *, with_lookahead, num_outputs):
    """Append one DFSMN model's weights to *out* (a bytearray) in FRVD reader order.

    Shared by serialize_vad and serialize_aed, whose layouts differ only in
    the lookahead filters (written right after each lookback filter when
    *with_lookahead* is true) and the width of the final ``out`` layer.

    Args:
        sd: State dict of one model.
        out: bytearray extended in place.
        with_lookahead: Whether each FSMN memory also has a lookahead filter.
        num_outputs: Output width of the ``out`` Linear layer.

    Missing keys raise KeyError; shape checks are delegated to the
    linear_w / linear_b / conv_filter helpers.
    """
    def write(arr):
        # Force little-endian float32 ("<f4") so the file matches the FRVD
        # format regardless of the host's native byte order.
        out.extend(np.ascontiguousarray(arr, dtype="<f4").tobytes())

    def write_memory_filters(prefix):
        write(conv_filter(sd, f"{prefix}.lookback_filter", D_PROJ, D_FILTER))
        if with_lookahead:
            write(conv_filter(sd, f"{prefix}.lookahead_filter", D_PROJ, D_FILTER))

    # Input projection: fc1 (D_IN -> D_HIDDEN), fc2 (D_HIDDEN -> D_PROJ), then
    # the first FSMN memory.
    write(linear_w(sd, "dfsmn.fc1.0", D_IN, D_HIDDEN))
    write(linear_b(sd, "dfsmn.fc1.0", D_HIDDEN))
    write(linear_w(sd, "dfsmn.fc2.0", D_HIDDEN, D_PROJ))
    write(linear_b(sd, "dfsmn.fc2.0", D_PROJ))
    write_memory_filters("dfsmn.fsmn1")

    # DFSMN blocks (the block fc2 projection has no bias).
    for i in range(N_FSMN_BLOCKS):
        block = f"dfsmn.fsmns.{i}"
        write(linear_w(sd, f"{block}.fc1.0", D_PROJ, D_HIDDEN))
        write(linear_b(sd, f"{block}.fc1.0", D_HIDDEN))
        write(linear_w(sd, f"{block}.fc2", D_HIDDEN, D_PROJ))
        write_memory_filters(f"{block}.fsmn")

    # Output head: dnns[0] is Linear(D_PROJ, D_HIDDEN); `out` is
    # Linear(D_HIDDEN, num_outputs).
    write(linear_w(sd, "dfsmn.dnns.0", D_PROJ, D_HIDDEN))
    write(linear_b(sd, "dfsmn.dnns.0", D_HIDDEN))
    write(linear_w(sd, "out", D_HIDDEN, num_outputs))
    write(linear_b(sd, "out", num_outputs))


def serialize_vad(sd, out):
    """Append streaming-VAD weights to *out* (a bytearray) in FRVD reader order.

    The VAD layout has lookback filters only and a single output unit.
    """
    _serialize_dfsmn(sd, out, with_lookahead=False, num_outputs=1)


def serialize_aed(sd, out):
    """Append AED weights to *out* (a bytearray) in FRVD reader order.

    Compared to the VAD layout, AED adds lookahead filters and has
    AED_NUM_CLASSES outputs.
    """
    _serialize_dfsmn(sd, out, with_lookahead=True, num_outputs=AED_NUM_CLASSES)
def export_weights(vad_dir, aed_dir, out_path):
    """Load the VAD and AED checkpoints and write them as one FRVD file.

    Layout: FRVD_MAGIC, FRVD_VERSION as a little-endian uint32, the VAD
    weights, then the AED weights. Both checkpoints are loaded before any
    serialization happens, and the file is written in a single call.
    """
    loaded = []
    for label, model_dir in (("VAD", vad_dir), ("AED", aed_dir)):
        print(f"Loading {label} checkpoint from: {model_dir}")
        loaded.append(load_state_dict(model_dir))
    vad_state, aed_state = loaded

    payload = bytearray(FRVD_MAGIC)
    payload += struct.pack("<I", FRVD_VERSION)
    serialize_vad(vad_state, payload)
    serialize_aed(aed_state, payload)

    with open(out_path, "wb") as fh:
        fh.write(payload)

    n_bytes = len(payload)
    print(f"Wrote {out_path} ({n_bytes:,} bytes, {n_bytes / 1024 / 1024:.2f} MB)")
def export_cmvn(cmvn_ark, out_path):
    """Convert a Kaldi ``cmvn.ark`` into ``{"means": [...], "inv_std": [...]}`` JSON.

    The stats matrix must have 2 rows and ``dim + 1`` columns. Row 0 holds
    per-bin sums, with the frame count in its last column. Row 1 holds per-bin
    sums of squares. For each bin ``d``::

        mean[d]    = stats[0, d] / count
        var[d]     = max(stats[1, d] / count - mean[d] ** 2, float32(1e-20))
        inv_std[d] = 1 / sqrt(var[d])

    Args:
        cmvn_ark: Path to the Kaldi-format CMVN statistics file.
        out_path: Destination path for the JSON file.

    Raises:
        ValueError: If the stats matrix is not 2 x (dim + 1) or its frame
            count is below 1.

    Exits the process with status 1 if ``kaldiio`` is not installed.
    """
    try:
        import kaldiio
    except ImportError:
        print("Error: kaldiio not found. Install with: pip install kaldiio", file=sys.stderr)
        sys.exit(1)
    stats = kaldiio.load_mat(cmvn_ark)
    # Validate with real exceptions: asserts are stripped under `python -O`.
    if stats.ndim != 2 or stats.shape[0] != 2:
        raise ValueError(f"Unexpected cmvn shape: {stats.shape}")
    dim = stats.shape[1] - 1
    count = stats[0, dim]
    if not count >= 1:  # also rejects NaN
        raise ValueError(f"Bad frame count in cmvn.ark: {count}")
    # Per-bin arithmetic intended to match fireredvad.core.audio_feat.CMVN.
    # Precision follows the dtype kaldiio returns for the ark (float64 for a
    # double-precision ark); the C engine reads the exported values as float32.
    floor = np.float32(1e-20)
    means, inv_std = [], []
    for d in range(dim):
        mean = stats[0, d] / count
        var = (stats[1, d] / count) - mean * mean
        if var < floor:
            var = floor
        means.append(float(mean))
        inv_std.append(float(np.float32(1.0) / np.sqrt(var)))
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump({"means": means, "inv_std": inv_std}, f)
    print(f"Wrote {out_path} ({dim} bins)")
def main():
    """Parse CLI arguments, then export fireredvad.bin and fireredvad.json.

    Unless --skip-download is given, the models are first downloaded into
    --model-root whenever --vad-dir or --aed-dir falls back to it.
    """
    parser = argparse.ArgumentParser(
        description="Export FireRedVAD PyTorch checkpoints to FRVD .bin + cmvn .json")
    parser.add_argument("--model-root", default=DEFAULT_MODEL_ROOT,
                        help=f"Root directory for downloaded models (default: {DEFAULT_MODEL_ROOT})")
    parser.add_argument("--vad-dir", default=None,
                        help="Path to streaming VAD model directory (default: {model-root}/Stream-VAD)")
    parser.add_argument("--aed-dir", default=None,
                        help="Path to AED model directory (default: {model-root}/AED)")
    parser.add_argument("--output-dir", default=".",
                        help="Output directory (default: current directory)")
    parser.add_argument("--skip-download", action="store_true",
                        help="Skip downloading models (use existing local files)")
    args = parser.parse_args()

    vad_dir = args.vad_dir or os.path.join(args.model_root, "Stream-VAD")
    aed_dir = args.aed_dir or os.path.join(args.model_root, "AED")
    # Download whenever either directory falls back to --model-root. An
    # explicit --vad-dir alone must not suppress fetching the default AED
    # model (and vice versa). The test mirrors the `or` fallback above.
    uses_model_root = not args.vad_dir or not args.aed_dir
    if not args.skip_download and uses_model_root:
        download_models(args.model_root)

    os.makedirs(args.output_dir, exist_ok=True)
    bin_path = os.path.join(args.output_dir, "fireredvad.bin")
    json_path = os.path.join(args.output_dir, "fireredvad.json")

    export_weights(vad_dir, aed_dir, bin_path)

    # CMVN stats come from the VAD directory, falling back to the AED one.
    cmvn_ark = os.path.join(vad_dir, "cmvn.ark")
    if not os.path.isfile(cmvn_ark):
        cmvn_ark = os.path.join(aed_dir, "cmvn.ark")
    if not os.path.isfile(cmvn_ark):
        print(f"Warning: cmvn.ark not found in {vad_dir} or {aed_dir}; skipping JSON",
              file=sys.stderr)
        return
    export_cmvn(cmvn_ark, json_path)
    print("Done.")
# Run the exporter only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|