| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| Export FireRedVAD PyTorch checkpoints to the FRVD binary format consumed by |
| the pure-C inference engine (mod_fireredvad / fireredvad-dart). |
| |
| Downloads the official FireRedTeam/FireRedVAD checkpoints from HuggingFace if |
| not already present, then writes: |
| |
| fireredvad.bin "FRVD" + uint32 version + concatenated VAD then AED |
| weights as little-endian float32. Layout matches the |
| reader in fireredvad.c::fireredvad_load_weights. |
| |
| fireredvad.json CMVN normalization stats {"means": [...], "inv_std": [...]} |
| converted from the kaldi-format cmvn.ark. |
| |
| Usage: |
| python export_frvd.py |
| python export_frvd.py --output-dir /tmp/frvd |
| python export_frvd.py --model-root path/to/FireRedVAD |
| """ |
|
|
| import argparse |
| import json |
| import os |
| import struct |
| import sys |
|
|
| import numpy as np |
| import torch |
|
|
# Where the official checkpoints live on HuggingFace, and the local directory
# they are downloaded into when no --model-root is given.
HF_REPO = "FireRedTeam/FireRedVAD"
DEFAULT_MODEL_ROOT = "pretrained_models/FireRedVAD"

# FRVD file header: 4-byte magic, then a little-endian uint32 format version.
FRVD_MAGIC, FRVD_VERSION = b"FRVD", 1

# Layer dimensions; the serializers check every checkpoint tensor against these.
D_IN, D_HIDDEN, D_PROJ, D_FILTER = 80, 256, 128, 20
N_BLOCKS = 8
# The first memory block ("dfsmn.fsmn1") is serialized on its own; the
# remaining ones are "dfsmn.fsmns.0" .. "dfsmn.fsmns.{N_FSMN_BLOCKS - 1}".
N_FSMN_BLOCKS = N_BLOCKS - 1
# Width of the AED model's "out" layer.
AED_NUM_CLASSES = 3
|
|
|
|
def download_models(model_root):
    """Download the FireRedVAD checkpoints from HuggingFace if missing.

    The download is skipped when both the ``Stream-VAD`` and ``AED``
    subdirectories already exist under *model_root*; otherwise the whole
    ``HF_REPO`` snapshot is fetched into *model_root*.

    Args:
        model_root: Local directory that holds (or will hold) the checkpoints.

    Exits the process with status 1 if ``huggingface_hub`` is not installed
    and a download is required.
    """
    needed = ("Stream-VAD", "AED")
    # NOTE(review): only the directories' existence is checked, not their
    # contents, so an interrupted earlier download is not detected here.
    if all(os.path.isdir(os.path.join(model_root, sub)) for sub in needed):
        print(f"Models already present at: {model_root}")
        return

    try:
        from huggingface_hub import snapshot_download
    except ImportError:
        # Fatal diagnostics go to stderr so they are not mixed into stdout.
        print("Error: huggingface_hub not found. Install with: pip install huggingface_hub",
              file=sys.stderr)
        sys.exit(1)

    print(f"Downloading {HF_REPO} to {model_root} ...")
    snapshot_download(repo_id=HF_REPO, local_dir=model_root)
|
|
|
|
def load_state_dict(model_dir):
    """Return the state_dict of the FireRedVAD model found in *model_dir*.

    The model is instantiated with ``DetectModel.from_pretrained`` and put
    into eval mode before its parameters are collected.
    """
    # Imported here rather than at module level, so the fireredvad package is
    # only required once a checkpoint is actually loaded.
    from fireredvad.core.detect_model import DetectModel

    detector = DetectModel.from_pretrained(model_dir)
    detector.eval()
    weights = detector.state_dict()
    return weights
|
|
|
|
def linear_w(sd, prefix, expected_in, expected_out):
    """Return ``{prefix}.weight`` transposed and flattened for the FRVD reader.

    PyTorch ``nn.Linear`` stores its weight as ``[out, in]``; the FRVD reader
    expects ``[in, out]`` in row-major order.

    Args:
        sd: State dict mapping parameter names to tensors.
        prefix: Parameter name prefix (without the ``.weight`` suffix).
        expected_in: Expected number of input features.
        expected_out: Expected number of output features.

    Returns:
        1-D float32 numpy array of length ``expected_in * expected_out``.

    Raises:
        KeyError: If ``{prefix}.weight`` is not in *sd*.
        ValueError: If the weight shape is not ``(expected_out, expected_in)``.
    """
    key = f"{prefix}.weight"
    w = sd[key]
    # Explicit check rather than ``assert``: assertions are stripped under
    # ``python -O`` and a wrong shape would silently corrupt the binary.
    if tuple(w.shape) != (expected_out, expected_in):
        raise ValueError(
            f"{key} shape {tuple(w.shape)} != ({expected_out}, {expected_in})")
    # detach() so tensors that still require grad can be converted to numpy.
    return w.detach().t().contiguous().cpu().float().numpy().reshape(-1)
|
|
|
|
def linear_b(sd, prefix, expected_out):
    """Return ``{prefix}.bias`` as a flat float32 numpy array.

    Args:
        sd: State dict mapping parameter names to tensors.
        prefix: Parameter name prefix (without the ``.bias`` suffix).
        expected_out: Expected bias length.

    Returns:
        1-D float32 numpy array of length ``expected_out``.

    Raises:
        KeyError: If ``{prefix}.bias`` is not in *sd*.
        ValueError: If the bias shape is not ``(expected_out,)``.
    """
    key = f"{prefix}.bias"
    b = sd[key]
    # Explicit check rather than ``assert`` so it also runs under ``python -O``.
    if tuple(b.shape) != (expected_out,):
        raise ValueError(f"{key} shape {tuple(b.shape)} != ({expected_out},)")
    # detach() so tensors that still require grad can be converted to numpy.
    return b.detach().cpu().float().numpy().reshape(-1)
|
|
|
|
def conv_filter(sd, prefix, expected_p, expected_k):
    """Return a depthwise Conv1d weight flattened to the FRVD ``[P, K]`` layout.

    The checkpoint tensor is ``[P, 1, K]``; dropping the singleton middle axis
    and flattening row-major yields ``[P, K]``.

    Args:
        sd: State dict mapping parameter names to tensors.
        prefix: Parameter name prefix (without the ``.weight`` suffix).
        expected_p: Expected number of channels.
        expected_k: Expected filter length.

    Returns:
        1-D float32 numpy array of length ``expected_p * expected_k``.

    Raises:
        KeyError: If ``{prefix}.weight`` is not in *sd*.
        ValueError: If the weight shape is not ``(expected_p, 1, expected_k)``.
    """
    key = f"{prefix}.weight"
    w = sd[key]
    # Explicit check rather than ``assert`` so it also runs under ``python -O``.
    if tuple(w.shape) != (expected_p, 1, expected_k):
        raise ValueError(
            f"{key} shape {tuple(w.shape)} != ({expected_p}, 1, {expected_k})")
    # detach() so tensors that still require grad can be converted to numpy.
    return w.detach().cpu().float().numpy().reshape(-1)
|
|
|
|
def serialize_vad(sd, out):
    """Append the streaming-VAD weights to *out* in FRVD reader order.

    Args:
        sd: State dict of the VAD checkpoint.
        out: ``bytearray`` that the serialized weights are appended to.

    Every tensor is validated against the module-level dimension constants
    by ``linear_w`` / ``linear_b`` / ``conv_filter`` before it is written.
    """
    def write(arr):
        # "<f4" forces little-endian float32 regardless of host byte order,
        # matching the "<I" header and the documented file format.
        out.extend(np.ascontiguousarray(arr, dtype="<f4").tobytes())

    # Input stage: fc1 (D_IN -> D_HIDDEN) + bias, fc2 (D_HIDDEN -> D_PROJ)
    # + bias, then the first memory block's look-back filter.
    write(linear_w(sd, "dfsmn.fc1.0", D_IN, D_HIDDEN))
    write(linear_b(sd, "dfsmn.fc1.0", D_HIDDEN))
    write(linear_w(sd, "dfsmn.fc2.0", D_HIDDEN, D_PROJ))
    write(linear_b(sd, "dfsmn.fc2.0", D_PROJ))
    write(conv_filter(sd, "dfsmn.fsmn1.lookback_filter", D_PROJ, D_FILTER))

    # Remaining memory blocks: fc1 + bias, fc2 weight only (presumably the
    # checkpoint's fc2 has no bias — TODO confirm), then the look-back filter.
    for i in range(N_FSMN_BLOCKS):
        write(linear_w(sd, f"dfsmn.fsmns.{i}.fc1.0", D_PROJ, D_HIDDEN))
        write(linear_b(sd, f"dfsmn.fsmns.{i}.fc1.0", D_HIDDEN))
        write(linear_w(sd, f"dfsmn.fsmns.{i}.fc2", D_HIDDEN, D_PROJ))
        write(conv_filter(sd, f"dfsmn.fsmns.{i}.fsmn.lookback_filter", D_PROJ, D_FILTER))

    # Output stage: DNN layer + bias, then the single-output "out" layer + bias.
    write(linear_w(sd, "dfsmn.dnns.0", D_PROJ, D_HIDDEN))
    write(linear_b(sd, "dfsmn.dnns.0", D_HIDDEN))
    write(linear_w(sd, "out", D_HIDDEN, 1))
    write(linear_b(sd, "out", 1))
|
|
|
|
def serialize_aed(sd, out):
    """Append the AED weights to *out* in FRVD reader order.

    Compared with the VAD layout, every memory block also carries a
    look-ahead filter (written after its look-back filter), and the "out"
    layer has AED_NUM_CLASSES outputs.

    Args:
        sd: State dict of the AED checkpoint.
        out: ``bytearray`` that the serialized weights are appended to.
    """
    def write(arr):
        # "<f4" forces little-endian float32 regardless of host byte order,
        # matching the "<I" header and the documented file format.
        out.extend(np.ascontiguousarray(arr, dtype="<f4").tobytes())

    # Input stage: fc1 + bias, fc2 + bias, then the first memory block's
    # look-back and look-ahead filters.
    write(linear_w(sd, "dfsmn.fc1.0", D_IN, D_HIDDEN))
    write(linear_b(sd, "dfsmn.fc1.0", D_HIDDEN))
    write(linear_w(sd, "dfsmn.fc2.0", D_HIDDEN, D_PROJ))
    write(linear_b(sd, "dfsmn.fc2.0", D_PROJ))
    write(conv_filter(sd, "dfsmn.fsmn1.lookback_filter", D_PROJ, D_FILTER))
    write(conv_filter(sd, "dfsmn.fsmn1.lookahead_filter", D_PROJ, D_FILTER))

    # Remaining memory blocks: fc1 + bias, fc2 weight only, then the
    # look-back and look-ahead filters.
    for i in range(N_FSMN_BLOCKS):
        write(linear_w(sd, f"dfsmn.fsmns.{i}.fc1.0", D_PROJ, D_HIDDEN))
        write(linear_b(sd, f"dfsmn.fsmns.{i}.fc1.0", D_HIDDEN))
        write(linear_w(sd, f"dfsmn.fsmns.{i}.fc2", D_HIDDEN, D_PROJ))
        write(conv_filter(sd, f"dfsmn.fsmns.{i}.fsmn.lookback_filter", D_PROJ, D_FILTER))
        write(conv_filter(sd, f"dfsmn.fsmns.{i}.fsmn.lookahead_filter", D_PROJ, D_FILTER))

    # Output stage: DNN layer + bias, then the multi-class "out" layer + bias.
    write(linear_w(sd, "dfsmn.dnns.0", D_PROJ, D_HIDDEN))
    write(linear_b(sd, "dfsmn.dnns.0", D_HIDDEN))
    write(linear_w(sd, "out", D_HIDDEN, AED_NUM_CLASSES))
    write(linear_b(sd, "out", AED_NUM_CLASSES))
|
|
|
|
def export_weights(vad_dir, aed_dir, out_path):
    """Build the combined FRVD weight file and write it to *out_path*.

    Layout: FRVD_MAGIC, FRVD_VERSION as a little-endian uint32, the VAD
    weights, then the AED weights. The whole file is assembled in memory
    before it is written.
    """
    print(f"Loading VAD checkpoint from: {vad_dir}")
    vad_sd = load_state_dict(vad_dir)
    print(f"Loading AED checkpoint from: {aed_dir}")
    aed_sd = load_state_dict(aed_dir)

    header = FRVD_MAGIC + struct.pack("<I", FRVD_VERSION)
    buf = bytearray(header)
    for serialize, state in ((serialize_vad, vad_sd), (serialize_aed, aed_sd)):
        serialize(state, buf)

    with open(out_path, "wb") as sink:
        sink.write(buf)
    size_mb = len(buf) / 1024 / 1024
    print(f"Wrote {out_path} ({len(buf):,} bytes, {size_mb:.2f} MB)")
|
|
|
|
def export_cmvn(cmvn_ark, out_path):
    """Convert Kaldi CMVN stats to ``{"means": [...], "inv_std": [...]}`` JSON.

    Args:
        cmvn_ark: Path to a Kaldi ``cmvn.ark`` file (read with ``kaldiio``),
            or an already-loaded stats array of shape ``(2, dim + 1)``.
            Row 0 holds per-bin sums with the frame count in the last
            column; row 1 holds per-bin sums of squares.
        out_path: Destination JSON file.

    Raises:
        ValueError: If the stats do not have shape ``(2, dim + 1)`` or the
            frame count is below 1 (or NaN).

    Exits the process with status 1 if *cmvn_ark* is a path and
    ``kaldiio`` is not installed.
    """
    if isinstance(cmvn_ark, np.ndarray):
        stats = cmvn_ark
    else:
        try:
            import kaldiio
        except ImportError:
            print("Error: kaldiio not found. Install with: pip install kaldiio",
                  file=sys.stderr)
            sys.exit(1)
        stats = kaldiio.load_mat(cmvn_ark)

    # Explicit checks rather than ``assert`` so validation survives ``python -O``.
    if stats.ndim != 2 or stats.shape[0] != 2:
        raise ValueError(f"Unexpected cmvn shape: {stats.shape}")
    dim = stats.shape[1] - 1
    count = stats[0, dim]
    # Written as ``not count >= 1`` so a NaN count is rejected too.
    if not count >= 1:
        raise ValueError(f"Bad frame count in cmvn.ark: {count}")

    # Variance is floored at 1e-20 so constant bins do not produce a
    # division by zero (or a negative variance from rounding).
    floor = np.float32(1e-20)
    means = stats[0, :dim] / count
    var = stats[1, :dim] / count - means * means
    var = np.where(var < floor, floor, var)
    inv_std = np.float32(1.0) / np.sqrt(var)

    with open(out_path, "w", encoding="utf-8") as f:
        json.dump({"means": means.tolist(), "inv_std": inv_std.tolist()}, f)
    print(f"Wrote {out_path} ({dim} bins)")
|
|
|
|
def main():
    """Command-line entry point: fetch checkpoints if needed and export FRVD files.

    Writes ``fireredvad.bin`` into ``--output-dir`` and, when a ``cmvn.ark``
    is found in the VAD model directory (or else the AED one), also
    ``fireredvad.json``.
    """
    parser = argparse.ArgumentParser(
        description="Export FireRedVAD PyTorch checkpoints to FRVD .bin + cmvn .json")
    parser.add_argument("--model-root", default=DEFAULT_MODEL_ROOT,
                        help=f"Root directory for downloaded models (default: {DEFAULT_MODEL_ROOT})")
    parser.add_argument("--vad-dir", default=None,
                        help="Path to streaming VAD model directory (default: {model-root}/Stream-VAD)")
    parser.add_argument("--aed-dir", default=None,
                        help="Path to AED model directory (default: {model-root}/AED)")
    parser.add_argument("--output-dir", default=".",
                        help="Output directory (default: current directory)")
    parser.add_argument("--skip-download", action="store_true",
                        help="Skip downloading models (use existing local files)")
    args = parser.parse_args()

    # Any model directory not given explicitly falls back to a subdirectory of
    # --model-root, so make sure the checkpoints exist there in that case.
    if not args.skip_download and (args.vad_dir is None or args.aed_dir is None):
        download_models(args.model_root)

    vad_dir = args.vad_dir or os.path.join(args.model_root, "Stream-VAD")
    aed_dir = args.aed_dir or os.path.join(args.model_root, "AED")

    os.makedirs(args.output_dir, exist_ok=True)
    bin_path = os.path.join(args.output_dir, "fireredvad.bin")
    json_path = os.path.join(args.output_dir, "fireredvad.json")

    export_weights(vad_dir, aed_dir, bin_path)

    # Prefer the VAD model's CMVN stats, falling back to the AED model's.
    candidates = [os.path.join(d, "cmvn.ark") for d in (vad_dir, aed_dir)]
    cmvn_ark = next((p for p in candidates if os.path.isfile(p)), None)
    if cmvn_ark is None:
        print(f"Warning: cmvn.ark not found in {vad_dir} or {aed_dir}; skipping JSON",
              file=sys.stderr)
    else:
        export_cmvn(cmvn_ark, json_path)

    print("Done.")


if __name__ == "__main__":
    main()
|
|