mpuig committed
Commit d711c8f · verified · 1 Parent(s): 2242193

Upload export_nemo_ctc.py with huggingface_hub

Files changed (1)
  1. export_nemo_ctc.py +113 -0
export_nemo_ctc.py ADDED
@@ -0,0 +1,113 @@
# /// script
# requires-python = ">=3.10,<3.13"
# dependencies = [
#     "nemo_toolkit[asr]",
#     "torch<2.6",
#     "onnx",
#     "onnxruntime",
#     "huggingface_hub",
# ]
# ///
"""Export NeMo CTC model to ONNX format for sherpa-onnx.

Usage:
    uv run --python 3.11 scripts/export_nemo_ctc.py

Output files are written to the gibberish models directory:
    ~/Library/Application Support/gibberish/models/nemo-conformer-ca/
"""

import os
from pathlib import Path

import nemo.collections.asr as nemo_asr
from huggingface_hub import hf_hub_download

HF_REPO = "nvidia/stt_ca_conformer_ctc_large"
HF_FILENAME = "stt_ca_conformer_ctc_large.nemo"

# Output directory
if os.name == "nt":
    models_dir = Path(os.environ.get("LOCALAPPDATA", "")) / "gibberish" / "models"
else:
    models_dir = Path.home() / "Library" / "Application Support" / "gibberish" / "models"

output_dir = models_dir / "nemo-conformer-ca"
output_dir.mkdir(parents=True, exist_ok=True)

# Download .nemo file from HuggingFace
print(f"Downloading {HF_FILENAME} from {HF_REPO}...")
nemo_path = hf_hub_download(repo_id=HF_REPO, filename=HF_FILENAME)
print(f"Downloaded to: {nemo_path}")

print("Loading NeMo model...")
m = nemo_asr.models.EncDecCTCModel.restore_from(nemo_path)
m.eval()

# Export tokens in sherpa-onnx format: one "TOKEN ID" pair per line, with <blk> at the end
tokens_path = output_dir / "tokens.txt"

# BPE models use a tokenizer, not labels
if hasattr(m, 'tokenizer') and m.tokenizer is not None:
    vocab_size = m.tokenizer.vocab_size
    print(f"Writing {vocab_size} BPE tokens + <blk> to {tokens_path}")
    with open(tokens_path, "w", encoding="utf-8") as f:
        for i in range(vocab_size):
            token = m.tokenizer.ids_to_tokens([i])[0]
            # Replace special characters for sherpa compatibility
            if token == " ":
                token = "▁"  # SentencePiece space marker
            f.write(f"{token} {i}\n")
        f.write(f"<blk> {vocab_size}\n")
else:
    # Character-based CTC model
    labels = list(m.cfg.labels)
    print(f"Writing {len(labels)} tokens + <blk> to {tokens_path}")
    with open(tokens_path, "w", encoding="utf-8") as f:
        for i, t in enumerate(labels):
            f.write(f"{t} {i}\n")
        f.write(f"<blk> {len(labels)}\n")

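# For illustration, tokens.txt ends up looking roughly like this (the token
# strings and ids shown here are hypothetical; the real ones depend on the
# BPE vocabulary of stt_ca_conformer_ctc_large):
#
#     <unk> 0
#     ▁de 1
#     ...
#     <blk> 128
#
# i.e. one "TOKEN ID" pair per line, with the CTC blank appended last.
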
# Export ONNX model
model_path = output_dir / "model.onnx"
print(f"Exporting ONNX model to {model_path}")
m.export(str(model_path))

# Add required metadata for sherpa-onnx compatibility
import onnx

print("Adding sherpa-onnx metadata to model...")
model = onnx.load(str(model_path))

# Get vocab size (including blank token)
if hasattr(m, 'tokenizer') and m.tokenizer is not None:
    vocab_size = m.tokenizer.vocab_size + 1  # +1 for blank
else:
    vocab_size = len(m.cfg.labels) + 1

# Extract config from model
normalize_type = str(m.preprocessor._cfg.get("normalize", ""))
subsampling_factor = str(m.encoder._cfg.get("subsampling_factor", 4))

# Add metadata
metadata = {
    "vocab_size": str(vocab_size),
    "normalize_type": normalize_type,
    "subsampling_factor": subsampling_factor,
    "model_type": "EncDecCTCModelBPE",
    "version": "1",
}

for key, value in metadata.items():
    meta = model.metadata_props.add()
    meta.key = key
    meta.value = value

onnx.save(model, str(model_path))
print(f"Added metadata: vocab_size={vocab_size}")

print(f"\nDone! Files written to: {output_dir}")
print(" - model.onnx")
print(" - tokens.txt")
if (output_dir / "model.onnx_data").exists():
    print(" - model.onnx_data")