Upload export_nemo_ctc.py with huggingface_hub
export_nemo_ctc.py +113 -0
export_nemo_ctc.py
ADDED
@@ -0,0 +1,113 @@
# /// script
# requires-python = ">=3.10,<3.13"
# dependencies = [
#     "nemo_toolkit[asr]",
#     "torch<2.6",
#     "onnx",
#     "onnxruntime",
#     "huggingface_hub",
# ]
# ///
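# The block above is PEP 723 inline script metadata: `uv run` reads it and
# supplies the pinned dependencies, so no pre-built virtualenv is needed.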
"""Export NeMo CTC model to ONNX format for sherpa-onnx.

Usage:
    uv run --python 3.11 scripts/export_nemo_ctc.py

Output files are written to the gibberish models directory, e.g. on macOS:
    ~/Library/Application Support/gibberish/models/nemo-conformer-ca/
"""

import os
from pathlib import Path

import nemo.collections.asr as nemo_asr
import onnx
from huggingface_hub import hf_hub_download

HF_REPO = "nvidia/stt_ca_conformer_ctc_large"
HF_FILENAME = "stt_ca_conformer_ctc_large.nemo"

# Output directory
if os.name == "nt":
    models_dir = Path(os.environ.get("LOCALAPPDATA", "")) / "gibberish" / "models"
else:
    models_dir = Path.home() / "Library" / "Application Support" / "gibberish" / "models"

output_dir = models_dir / "nemo-conformer-ca"
output_dir.mkdir(parents=True, exist_ok=True)

# Download .nemo file from HuggingFace
print(f"Downloading {HF_FILENAME} from {HF_REPO}...")
nemo_path = hf_hub_download(repo_id=HF_REPO, filename=HF_FILENAME)
print(f"Downloaded to: {nemo_path}")

print("Loading NeMo model...")
m = nemo_asr.models.EncDecCTCModel.restore_from(nemo_path)
m.eval()  # disable dropout and similar training-time behaviour before export

# Export tokens in sherpa-onnx format: TOKEN ID per line, with <blk> at end
tokens_path = output_dir / "tokens.txt"

# BPE models use tokenizer, not labels
if hasattr(m, 'tokenizer') and m.tokenizer is not None:
    vocab_size = m.tokenizer.vocab_size
    print(f"Writing {vocab_size} BPE tokens + <blk> to {tokens_path}")
    with open(tokens_path, "w", encoding="utf-8") as f:
        for i in range(vocab_size):
            token = m.tokenizer.ids_to_tokens([i])[0]
            # Replace special characters for sherpa compatibility
            if token == " ":
                token = "▁"  # SentencePiece space marker
            f.write(f"{token} {i}\n")
        f.write(f"<blk> {vocab_size}\n")
else:
    # Character-based CTC model
    labels = list(m.cfg.labels)
    print(f"Writing {len(labels)} tokens + <blk> to {tokens_path}")
    with open(tokens_path, "w", encoding="utf-8") as f:
        for i, t in enumerate(labels):
            f.write(f"{t} {i}\n")
        f.write(f"<blk> {len(labels)}\n")

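# For reference, tokens.txt now holds one "TOKEN ID" pair per line with the
# CTC blank appended last, e.g. (illustrative entries, not the real vocab):
#
#   ▁el 42
#   <blk> 128
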
# Export ONNX model
model_path = output_dir / "model.onnx"
print(f"Exporting ONNX model to {model_path}")
m.export(str(model_path))

# Add required metadata for sherpa-onnx compatibility
print("Adding sherpa-onnx metadata to model...")
model = onnx.load(str(model_path))

# Get vocab size (including blank token)
if hasattr(m, 'tokenizer') and m.tokenizer is not None:
    vocab_size = m.tokenizer.vocab_size + 1  # +1 for blank
else:
    vocab_size = len(m.cfg.labels) + 1

# Extract config from model
normalize_type = str(m.preprocessor._cfg.get("normalize", ""))
subsampling_factor = str(m.encoder._cfg.get("subsampling_factor", 4))

# Add metadata
metadata = {
    "vocab_size": str(vocab_size),
    "normalize_type": normalize_type,
    "subsampling_factor": subsampling_factor,
    "model_type": "EncDecCTCModelBPE",
    "version": "1",
}

for key, value in metadata.items():
    meta = model.metadata_props.add()
    meta.key = key
    meta.value = value

onnx.save(model, str(model_path))
print(f"Added metadata: vocab_size={vocab_size}")

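# Note: metadata_props entries are plain string key/value pairs stored on the
# ONNX ModelProto itself, which is why every value above is stringified;
# sherpa-onnx reads this metadata at load time rather than a side-car config.
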
print(f"\nDone! Files written to: {output_dir}")
print("  - model.onnx")
print("  - tokens.txt")
# onnx writes weights to an external .onnx_data file when the serialized
# model would exceed protobuf's 2 GB limit
if (output_dir / "model.onnx_data").exists():
    print("  - model.onnx_data")

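# Optional sanity check (a minimal sketch): load the exported graph with
# onnxruntime (already pinned in the dependency block above) and print its
# input/output signature to confirm the tensor names sherpa-onnx will see.
# The exact names depend on NeMo's exporter, so this introspects them rather
# than assuming any.
import onnxruntime as ort

sess = ort.InferenceSession(str(model_path), providers=["CPUExecutionProvider"])
print("ONNX inputs: ", [(i.name, i.shape) for i in sess.get_inputs()])
print("ONNX outputs:", [(o.name, o.shape) for o in sess.get_outputs()])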