FireRed Team committed
Commit 0ddb4a4 · verified · 1 Parent(s): fc63095
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. README.md +511 -1
  2. app.py +111 -0
  3. fireredasr2s/__init__.py +40 -0
  4. fireredasr2s/fireredasr2/__init__.py +27 -0
  5. fireredasr2s/fireredasr2/asr.py +241 -0
  6. fireredasr2s/fireredasr2/data/asr_feat.py +124 -0
  7. fireredasr2s/fireredasr2/data/token_dict.py +63 -0
  8. fireredasr2s/fireredasr2/models/fireredasr_aed.py +74 -0
  9. fireredasr2s/fireredasr2/models/fireredasr_llm.py +297 -0
  10. fireredasr2s/fireredasr2/models/lstm_lm.py +65 -0
  11. fireredasr2s/fireredasr2/models/module/adapter.py +32 -0
  12. fireredasr2s/fireredasr2/models/module/conformer_encoder.py +324 -0
  13. fireredasr2s/fireredasr2/models/module/ctc.py +119 -0
  14. fireredasr2s/fireredasr2/models/module/transformer_decoder.py +329 -0
  15. fireredasr2s/fireredasr2/models/param.py +17 -0
  16. fireredasr2s/fireredasr2/speech2text.py +107 -0
  17. fireredasr2s/fireredasr2/tokenizer/aed_tokenizer.py +93 -0
  18. fireredasr2s/fireredasr2/tokenizer/llm_tokenizer.py +107 -0
  19. fireredasr2s/fireredasr2/utils/io.py +55 -0
  20. fireredasr2s/fireredasr2/utils/wer.py +326 -0
  21. fireredasr2s/fireredasr2s-cli +273 -0
  22. fireredasr2s/fireredasr2s_cli.py +273 -0
  23. fireredasr2s/fireredasr2system.py +200 -0
  24. fireredasr2s/fireredlid/README.md +125 -0
  25. fireredasr2s/fireredlid/__init__.py +23 -0
  26. fireredasr2s/fireredlid/data/feat.py +124 -0
  27. fireredasr2s/fireredlid/data/token_dict.py +63 -0
  28. fireredasr2s/fireredlid/lid.py +110 -0
  29. fireredasr2s/fireredlid/models/fireredlid_aed.py +37 -0
  30. fireredasr2s/fireredlid/models/module/conformer_encoder.py +324 -0
  31. fireredasr2s/fireredlid/models/module/transformer_decoder.py +317 -0
  32. fireredasr2s/fireredlid/models/param.py +17 -0
  33. fireredasr2s/fireredlid/speech2lang.py +73 -0
  34. fireredasr2s/fireredlid/tokenizer/lid_tokenizer.py +17 -0
  35. fireredasr2s/fireredlid/utils/io.py +38 -0
  36. fireredasr2s/fireredpunc/__init__.py +27 -0
  37. fireredasr2s/fireredpunc/add_punc.py +96 -0
  38. fireredasr2s/fireredpunc/data/__init__.py +0 -0
  39. fireredasr2s/fireredpunc/data/hf_bert_tokenizer.py +180 -0
  40. fireredasr2s/fireredpunc/data/token_dict.py +63 -0
  41. fireredasr2s/fireredpunc/models/__init__.py +0 -0
  42. fireredasr2s/fireredpunc/models/fireredpunc_bert.py +69 -0
  43. fireredasr2s/fireredpunc/models/param.py +17 -0
  44. fireredasr2s/fireredpunc/punc.py +391 -0
  45. fireredasr2s/fireredvad/__init__.py +57 -0
  46. fireredasr2s/fireredvad/aed.py +109 -0
  47. fireredasr2s/fireredvad/bin/__init__.py +0 -0
  48. fireredasr2s/fireredvad/bin/aed.py +92 -0
  49. fireredasr2s/fireredvad/bin/fireredvad_cli.py +41 -0
  50. fireredasr2s/fireredvad/bin/stream_vad.py +172 -0
README.md CHANGED
@@ -12,4 +12,514 @@ license: apache-2.0
  short_description: A SOTA Industrial-Grade All-in-One ASR system
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ <div align="center">
16
+ <h1>
17
+ FireRedASR2S
18
+ <br>
19
+ A SOTA Industrial-Grade All-in-One ASR System
20
+ </h1>
21
+
22
+ </div>
23
+
24
+ [[Paper]](https://arxiv.org/pdf/2603.10420)
25
+ [[Model🤗]](https://huggingface.co/collections/FireRedTeam/fireredasr2s)
26
+ [[Model🤖]](https://www.modelscope.cn/collections/xukaituo/FireRedASR2S)
27
+ [[Demo]](https://huggingface.co/spaces/FireRedTeam/FireRedASR)
28
+
29
+
30
+ FireRedASR2S is a state-of-the-art (SOTA), industrial-grade, all-in-one ASR system with ASR, VAD, LID, and Punc modules. All modules achieve SOTA performance:
31
+ - **FireRedASR2**: Automatic Speech Recognition (ASR) supporting speech and singing transcription for Chinese (Mandarin and 20+ dialects/accents), English, and code-switching. 2.89% average CER on 4 public Mandarin benchmarks and 11.55% on 19 Chinese dialect and accent benchmarks, **outperforming Doubao-ASR, Qwen3-ASR-1.7B, Fun-ASR, and Fun-ASR-Nano-2512**. FireRedASR2-AED also supports word-level timestamps and confidence scores.
32
+ - **FireRedVAD**: Voice Activity Detection (VAD) supporting speech/singing/music in 100+ languages. 97.57% F1, **outperforming Silero-VAD, TEN-VAD, FunASR-VAD and WebRTC-VAD**. Supports non-streaming/streaming VAD and Multi-label VAD (mVAD).
33
+ - **FireRedLID**: Spoken Language Identification (LID) supporting 100+ languages and 20+ Chinese dialects/accents. 97.18% accuracy, **outperforming Whisper and SpeechBrain**.
34
+ - **FireRedPunc**: Punctuation Prediction (Punc) for Chinese and English. 78.90% average F1, outperforming FunASR-Punc (62.77%).
35
+
36
+ *`2S`: `2`nd-generation FireRedASR, now expanded to an all-in-one ASR `S`ystem*
37
+
38
+
39
+ ## 🔥 News
40
+ - [2026.03.12] 🔥 We release the FireRedASR2S technical report. See [arXiv](https://arxiv.org/abs/2603.10420).
41
+ - [2026.03.05] 🚀 [vLLM](https://github.com/vllm-project/vllm/pull/35727) supports FireRedASR2-LLM. See the [vLLM Usage](https://github.com/FireRedTeam/FireRedASR2S?tab=readme-ov-file#vllm-usage) section.
42
+ - [2026.02.25] 🔥 We release **FireRedASR2-LLM model weights**. [🤗](https://huggingface.co/FireRedTeam/FireRedASR2-LLM) [🤖](https://www.modelscope.cn/models/xukaituo/FireRedASR2-LLM/)
43
+ - [2026.02.13] 🚀 Added TensorRT-LLM inference acceleration for FireRedASR2-AED (contributed by NVIDIA). A benchmark on the AISHELL-1 test set shows a **12.7x speedup** over the PyTorch baseline (single H20).
44
+ - [2026.02.12] 🔥 We release FireRedASR2S (FireRedASR2-AED, FireRedVAD, FireRedLID, and FireRedPunc) with **model weights and inference code**. Download links below. Technical report and finetuning code coming soon.
45
+
46
+
47
+
48
+ ## Available Models and Languages
49
+
50
+ |Model|Supported Languages & Dialects|Download|
51
+ |:-------------:|:---------------------------------:|:----------:|
52
+ |FireRedASR2-LLM| Chinese (Mandarin and 20+ dialects/accents<sup>*</sup>), English, Code-Switching | [🤗](https://huggingface.co/FireRedTeam/FireRedASR2-LLM) \| [🤖](https://www.modelscope.cn/models/xukaituo/FireRedASR2-LLM/)|
53
+ |FireRedASR2-AED| Chinese (Mandarin and 20+ dialects/accents<sup>*</sup>), English, Code-Switching | [🤗](https://huggingface.co/FireRedTeam/FireRedASR2-AED) \| [🤖](https://www.modelscope.cn/models/xukaituo/FireRedASR2-AED/)|
54
+ |FireRedVAD | 100+ languages, 20+ Chinese dialects/accents<sup>*</sup> | [🤗](https://huggingface.co/FireRedTeam/FireRedVAD) \| [🤖](https://www.modelscope.cn/models/xukaituo/FireRedVAD/)|
55
+ |FireRedLID | 100+ languages, 20+ Chinese dialects/accents<sup>*</sup> | [🤗](https://huggingface.co/FireRedTeam/FireRedLID) \| [🤖](https://www.modelscope.cn/models/xukaituo/FireRedLID/)|
56
+ |FireRedPunc| Chinese, English | [🤗](https://huggingface.co/FireRedTeam/FireRedPunc) \| [🤖](https://www.modelscope.cn/models/xukaituo/FireRedPunc/)|
57
+
58
+ <sup>*</sup>Supported Chinese dialects/accents: Cantonese (Hong Kong & Guangdong), Sichuan, Shanghai, Wu, Minnan, Anhui, Fujian, Gansu, Guizhou, Hebei, Henan, Hubei, Hunan, Jiangxi, Liaoning, Ningxia, Shaanxi, Shanxi, Shandong, Tianjin, Yunnan, etc.
59
+
60
+
61
+
62
+ ## Method
63
+ ### FireRedASR2S: System Overview
64
+ ![Model](./assets/FireRedASR2S.png)
65
+
66
+ ### FireRedASR2
67
+ FireRedASR2 builds upon [FireRedASR](https://github.com/FireRedTeam/FireRedASR) with improved accuracy and is designed to cover diverse application needs, from superior performance to optimal efficiency. It comprises two variants:
68
+ - **FireRedASR2-LLM**: Designed to achieve state-of-the-art performance and to enable seamless end-to-end speech interaction. It adopts an Encoder-Adapter-LLM framework leveraging large language model (LLM) capabilities.
69
+ - **FireRedASR2-AED**: Designed to balance high performance and computational efficiency and to serve as an effective speech representation module in LLM-based speech models. It utilizes an Attention-based Encoder-Decoder (AED) architecture.
70
+
71
+ ![Model](./assets/FireRedASR2.png)
72
+
73
+ ### Other Modules
74
+ - **FireRedVAD**: DFSMN-based non-streaming/streaming Voice Activity Detection and Multi-label VAD (mVAD). mVAD can be viewed as a lightweight Audio Event Detection (AED) system specialized for a small set of sound categories (speech/singing/music).
75
+ - **FireRedLID**: Encoder-Decoder-based Spoken Language Identification. See [FireRedLID README](./fireredasr2s/fireredlid/README.md) for language details.
76
+ - **FireRedPunc**: BERT-based Punctuation Prediction.
77
+
78
+
79
+ ## Quick Start
80
+ ### Setup
81
+ 1. Create a clean Python environment:
82
+ ```bash
83
+ $ conda create --name fireredasr2s python=3.10
84
+ $ conda activate fireredasr2s
85
+ $ git clone https://github.com/FireRedTeam/FireRedASR2S.git
86
+ $ cd FireRedASR2S # or fireredasr2s
87
+ ```
88
+
89
+ 2. Install dependencies and set up PATH and PYTHONPATH:
90
+ ```bash
91
+ $ pip install -r requirements.txt
92
+ $ export PATH=$PWD/fireredasr2s/:$PATH
93
+ $ export PYTHONPATH=$PWD/:$PYTHONPATH
94
+ ```
95
+
96
+ 3. Download models:
97
+ ```bash
98
+ # Download via ModelScope (recommended for users in China)
99
+ pip install -U modelscope
100
+ modelscope download --model xukaituo/FireRedASR2-AED --local_dir ./pretrained_models/FireRedASR2-AED
101
+ modelscope download --model xukaituo/FireRedVAD --local_dir ./pretrained_models/FireRedVAD
102
+ modelscope download --model xukaituo/FireRedLID --local_dir ./pretrained_models/FireRedLID
103
+ modelscope download --model xukaituo/FireRedPunc --local_dir ./pretrained_models/FireRedPunc
104
+ modelscope download --model xukaituo/FireRedASR2-LLM --local_dir ./pretrained_models/FireRedASR2-LLM
105
+
106
+ # Download via Hugging Face
107
+ pip install -U "huggingface_hub[cli]"
108
+ huggingface-cli download FireRedTeam/FireRedASR2-AED --local-dir ./pretrained_models/FireRedASR2-AED
109
+ huggingface-cli download FireRedTeam/FireRedVAD --local-dir ./pretrained_models/FireRedVAD
110
+ huggingface-cli download FireRedTeam/FireRedLID --local-dir ./pretrained_models/FireRedLID
111
+ huggingface-cli download FireRedTeam/FireRedPunc --local-dir ./pretrained_models/FireRedPunc
112
+ huggingface-cli download FireRedTeam/FireRedASR2-LLM --local-dir ./pretrained_models/FireRedASR2-LLM
113
+ ```
114
+
115
+ 4. Convert your audio to **16kHz 16-bit mono PCM** format if needed:
116
+ ```bash
117
+ $ ffmpeg -i <input_audio_path> -ar 16000 -ac 1 -acodec pcm_s16le -f wav <output_wav_path>
118
+ ```
119
+
120
+ ### Script Usage
121
+ ```bash
122
+ $ cd examples_infer/asr_system
123
+ $ bash inference_asr_system.sh
124
+ ```
125
+
126
+ ### Command-line Usage
127
+ ```bash
128
+ $ fireredasr2s-cli --help
129
+ $ fireredasr2s-cli --wav_paths "assets/hello_zh.wav" "assets/hello_en.wav" --outdir output
130
+ $ cat output/result.jsonl
131
+ # {"uttid": "hello_zh", "text": "你好世界。", "sentences": [{"start_ms": 310, "end_ms": 1840, "text": "你好世界。", "asr_confidence": 0.875, "lang": "zh mandarin", "lang_confidence": 0.999}], "vad_segments_ms": [[310, 1840]], "dur_s": 2.32, "words": [{"start_ms": 490, "end_ms": 690, "text": "你"}, {"start_ms": 690, "end_ms": 1090, "text": "好"}, {"start_ms": 1090, "end_ms": 1330, "text": "世"}, {"start_ms": 1330, "end_ms": 1795, "text": "界"}], "wav_path": "assets/hello_zh.wav"}
132
+ # {"uttid": "hello_en", "text": "Hello speech.", "sentences": [{"start_ms": 120, "end_ms": 1840, "text": "Hello speech.", "asr_confidence": 0.833, "lang": "en", "lang_confidence": 0.998}], "vad_segments_ms": [[120, 1840]], "dur_s": 2.24, "words": [{"start_ms": 340, "end_ms": 1020, "text": "hello"}, {"start_ms": 1020, "end_ms": 1666, "text": "speech"}], "wav_path": "assets/hello_en.wav"}
133
+ ```
134
+
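+ Each line of `output/result.jsonl` is a standalone JSON object, so it can be post-processed with a few lines of Python. A minimal sketch (standard library only; the field names follow the example output above):
+
+ ```python
+ import json
+
+ # Each line of result.jsonl holds one utterance-level result.
+ with open("output/result.jsonl", encoding="utf-8") as f:
+     for line in f:
+         result = json.loads(line)
+         print(result["uttid"], result["text"])
+         for sent in result.get("sentences", []):
+             print(f'  [{sent["start_ms"]}-{sent["end_ms"]}ms] {sent["lang"]}: {sent["text"]}')
+ ```
+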
135
+ ### Python API Usage
136
+ ```python
137
+ from fireredasr2s import FireRedAsr2System, FireRedAsr2SystemConfig
138
+
139
+ asr_system_config = FireRedAsr2SystemConfig() # Use default config
140
+ asr_system = FireRedAsr2System(asr_system_config)
141
+
142
+ result = asr_system.process("assets/hello_zh.wav")
143
+ print(result)
144
+ # {'uttid': 'tmpid', 'text': '你好世界。', 'sentences': [{'start_ms': 440, 'end_ms': 1820, 'text': '你好世界。', 'asr_confidence': 0.868, 'lang': 'zh mandarin', 'lang_confidence': 0.999}], 'vad_segments_ms': [(440, 1820)], 'dur_s': 2.32, 'words': [], 'wav_path': 'assets/hello_zh.wav'}
145
+
146
+ result = asr_system.process("assets/hello_en.wav")
147
+ print(result)
148
+ # {'uttid': 'tmpid', 'text': 'Hello speech.', 'sentences': [{'start_ms': 260, 'end_ms': 1820, 'text': 'Hello speech.', 'asr_confidence': 0.933, 'lang': 'en', 'lang_confidence': 0.993}], 'vad_segments_ms': [(260, 1820)], 'dur_s': 2.24, 'words': [], 'wav_path': 'assets/hello_en.wav'}
149
+ ```
150
+
151
+
152
+
153
+ ## Usage of Each Module
154
+ The four components under `fireredasr2s`, i.e., `fireredasr2`, `fireredvad`, `fireredlid`, and `fireredpunc`, are self-contained and designed to work as standalone modules. You can use any of them independently without depending on the others. `FireRedVAD` and `FireRedLID` will also be open-sourced as standalone libraries in separate repositories.
155
+
156
+ ### Script Usage
157
+ ```bash
158
+ # ASR
159
+ $ cd examples_infer/asr
160
+ $ bash inference_asr_aed.sh
161
+ $ bash inference_asr_llm.sh
162
+
163
+ # VAD & mVAD (mVAD=Audio Event Detection, AED)
164
+ $ cd examples_infer/vad
165
+ $ bash inference_vad.sh
166
+ $ bash inference_streamvad.sh
167
+ $ bash inference_aed.sh
168
+
169
+ # LID
170
+ $ cd examples_infer/lid
171
+ $ bash inference_lid.sh
172
+
173
+ # Punc
174
+ $ cd examples_infer/punc
175
+ $ bash inference_punc.sh
176
+ ```
177
+
178
+ ### vLLM Usage
179
+ ```shell
180
+ # Serve FireRedASR2-LLM with the latest vLLM for the best performance.
181
+ # For more details, see https://github.com/vllm-project/vllm/pull/35727.
182
+ $ vllm serve allendou/FireRedASR2-LLM-vllm -tp=2 --dtype=float32
183
+ $ python3 examples/online_serving/openai_transcription_client.py --repetition_penalty=1.0 --audio_path=/root/hello_zh.wav
184
+ ```
185
+
186
+ ### Python API Usage
187
+ Set up `PYTHONPATH` first: `export PYTHONPATH=$PWD/:$PYTHONPATH`
188
+
189
+ #### ASR
190
+ ```python
191
+ from fireredasr2s.fireredasr2 import FireRedAsr2, FireRedAsr2Config
192
+
193
+ batch_uttid = ["hello_zh", "hello_en"]
194
+ batch_wav_path = ["assets/hello_zh.wav", "assets/hello_en.wav"]
195
+
196
+ # FireRedASR2-AED
197
+ asr_config = FireRedAsr2Config(
198
+ use_gpu=True,
199
+ use_half=False,
200
+ beam_size=3,
201
+ nbest=1,
202
+ decode_max_len=0,
203
+ softmax_smoothing=1.25,
204
+ aed_length_penalty=0.6,
205
+ eos_penalty=1.0,
206
+ return_timestamp=True
207
+ )
208
+ model = FireRedAsr2.from_pretrained("aed", "pretrained_models/FireRedASR2-AED", asr_config)
209
+ results = model.transcribe(batch_uttid, batch_wav_path)
210
+ print(results)
211
+ # [{'uttid': 'hello_zh', 'text': '你好世界', 'confidence': 0.971, 'dur_s': 2.32, 'rtf': '0.0870', 'wav': 'assets/hello_zh.wav', 'timestamp': [('你', 0.42, 0.66), ('好', 0.66, 1.1), ('世', 1.1, 1.34), ('界', 1.34, 2.039)]}, {'uttid': 'hello_en', 'text': 'hello speech', 'confidence': 0.943, 'dur_s': 2.24, 'rtf': '0.0870', 'wav': 'assets/hello_en.wav', 'timestamp': [('hello', 0.34, 0.98), ('speech', 0.98, 1.766)]}]
212
+
213
+ # FireRedASR2-LLM
214
+ asr_config = FireRedAsr2Config(
215
+ use_gpu=True,
216
+ decode_min_len=0,
217
+ repetition_penalty=1.0,
218
+ llm_length_penalty=0.0,
219
+ temperature=1.0
220
+ )
221
+ model = FireRedAsr2.from_pretrained("llm", "pretrained_models/FireRedASR2-LLM", asr_config)
222
+ results = model.transcribe(batch_uttid, batch_wav_path)
223
+ print(results)
224
+ # [{'uttid': 'hello_zh', 'text': '你好世界', 'rtf': '0.0681', 'wav': 'assets/hello_zh.wav'}, {'uttid': 'hello_en', 'text': 'hello speech', 'rtf': '0.0681', 'wav': 'assets/hello_en.wav'}]
225
+ ```
226
+
227
+
228
+ #### VAD
229
+ ```python
230
+ from fireredasr2s.fireredvad import FireRedVad, FireRedVadConfig
231
+
232
+ vad_config = FireRedVadConfig(
233
+ use_gpu=False,
234
+ smooth_window_size=5,
235
+ speech_threshold=0.4,
236
+ min_speech_frame=20,
237
+ max_speech_frame=2000,
238
+ min_silence_frame=20,
239
+ merge_silence_frame=0,
240
+ extend_speech_frame=0,
241
+ chunk_max_frame=30000)
242
+ vad = FireRedVad.from_pretrained("pretrained_models/FireRedVAD/VAD", vad_config)
243
+
244
+ result, probs = vad.detect("assets/hello_zh.wav")
245
+
246
+ print(result)
247
+ # {'dur': 2.32, 'timestamps': [(0.44, 1.82)], 'wav_path': 'assets/hello_zh.wav'}
248
+ ```
249
+
250
+
251
+ #### Stream VAD
252
+ <details>
253
+ <summary>Click to expand</summary>
254
+
255
+ ```python
256
+ from fireredasr2s.fireredvad import FireRedStreamVad, FireRedStreamVadConfig
257
+
258
+ vad_config=FireRedStreamVadConfig(
259
+ use_gpu=False,
260
+ smooth_window_size=5,
261
+ speech_threshold=0.4,
262
+ pad_start_frame=5,
263
+ min_speech_frame=8,
264
+ max_speech_frame=2000,
265
+ min_silence_frame=20,
266
+ chunk_max_frame=30000)
267
+ stream_vad = FireRedStreamVad.from_pretrained("pretrained_models/FireRedVAD/Stream-VAD", vad_config)
268
+
269
+ frame_results, result = stream_vad.detect_full("assets/hello_zh.wav")
270
+
271
+ print(result)
272
+ # {'dur': 2.32, 'timestamps': [(0.46, 1.84)], 'wav_path': 'assets/hello_zh.wav'}
273
+ ```
274
+ </details>
275
+
276
+
277
+ #### mVAD (Audio Event Detection, AED)
278
+ <details>
279
+ <summary>Click to expand</summary>
280
+
281
+ ```python
282
+ from fireredasr2s.fireredvad import FireRedAed, FireRedAedConfig
283
+
284
+ aed_config=FireRedAedConfig(
285
+ use_gpu=False,
286
+ smooth_window_size=5,
287
+ speech_threshold=0.4,
288
+ singing_threshold=0.5,
289
+ music_threshold=0.5,
290
+ min_event_frame=20,
291
+ max_event_frame=2000,
292
+ min_silence_frame=20,
293
+ merge_silence_frame=0,
294
+ extend_speech_frame=0,
295
+ chunk_max_frame=30000)
296
+ aed = FireRedAed.from_pretrained("pretrained_models/FireRedVAD/AED", aed_config)
297
+
298
+ result, probs = aed.detect("assets/event.wav")
299
+
300
+ print(result)
301
+ # {'dur': 22.016, 'event2timestamps': {'speech': [(0.4, 3.56), (3.66, 9.08), (9.27, 9.77), (10.78, 21.76)], 'singing': [(1.79, 19.96), (19.97, 22.016)], 'music': [(0.09, 12.32), (12.33, 22.016)]}, 'event2ratio': {'speech': 0.848, 'singing': 0.905, 'music': 0.991}, 'wav_path': 'assets/event.wav'}
302
+ ```
303
+ </details>
304
+
305
+
306
+ #### LID
307
+ <details>
308
+ <summary>Click to expand</summary>
309
+
310
+ ```python
311
+ from fireredasr2s.fireredlid import FireRedLid, FireRedLidConfig
312
+
313
+ batch_uttid = ["hello_zh", "hello_en"]
314
+ batch_wav_path = ["assets/hello_zh.wav", "assets/hello_en.wav"]
315
+
316
+ config = FireRedLidConfig(use_gpu=True, use_half=False)
317
+ model = FireRedLid.from_pretrained("pretrained_models/FireRedLID", config)
318
+
319
+ results = model.process(batch_uttid, batch_wav_path)
320
+ print(results)
321
+ # [{'uttid': 'hello_zh', 'lang': 'zh mandarin', 'confidence': 0.996, 'dur_s': 2.32, 'rtf': '0.0741', 'wav': 'assets/hello_zh.wav'}, {'uttid': 'hello_en', 'lang': 'en', 'confidence': 0.996, 'dur_s': 2.24, 'rtf': '0.0741', 'wav': 'assets/hello_en.wav'}]
322
+ ```
323
+ </details>
324
+
325
+
326
+ #### Punc
327
+ <details>
328
+ <summary>Click to expand</summary>
329
+
330
+ ```python
331
+ from fireredasr2s.fireredpunc.punc import FireRedPunc, FireRedPuncConfig
332
+
333
+ config = FireRedPuncConfig(use_gpu=True)
334
+ model = FireRedPunc.from_pretrained("pretrained_models/FireRedPunc", config)
335
+
336
+ batch_text = ["你好世界", "Hello world"]
337
+ results = model.process(batch_text)
338
+
339
+ print(results)
340
+ # [{'punc_text': '你好世界。', 'origin_text': '你好世界'}, {'punc_text': 'Hello world!', 'origin_text': 'Hello world'}]
341
+ ```
342
+ </details>
343
+
344
+
345
+ #### ASR System
346
+ ```python
347
+ from fireredasr2s.fireredasr2 import FireRedAsr2Config
348
+ from fireredasr2s.fireredlid import FireRedLidConfig
349
+ from fireredasr2s.fireredpunc import FireRedPuncConfig
350
+ from fireredasr2s.fireredvad import FireRedVadConfig
351
+ from fireredasr2s import FireRedAsr2System, FireRedAsr2SystemConfig
352
+
353
+ vad_config = FireRedVadConfig(
354
+ use_gpu=False,
355
+ smooth_window_size=5,
356
+ speech_threshold=0.4,
357
+ min_speech_frame=20,
358
+ max_speech_frame=2000,
359
+ min_silence_frame=20,
360
+ merge_silence_frame=0,
361
+ extend_speech_frame=0,
362
+ chunk_max_frame=30000
363
+ )
364
+ lid_config = FireRedLidConfig(use_gpu=True, use_half=False)
365
+ asr_config = FireRedAsr2Config(
366
+ use_gpu=True,
367
+ use_half=False,
368
+ beam_size=3,
369
+ nbest=1,
370
+ decode_max_len=0,
371
+ softmax_smoothing=1.25,
372
+ aed_length_penalty=0.6,
373
+ eos_penalty=1.0,
374
+ return_timestamp=True
375
+ )
376
+ punc_config = FireRedPuncConfig(use_gpu=True)
377
+
378
+ asr_system_config = FireRedAsr2SystemConfig(
379
+ "pretrained_models/FireRedVAD/VAD",
380
+ "pretrained_models/FireRedLID",
381
+ "aed", "pretrained_models/FireRedASR2-AED",
382
+ "pretrained_models/FireRedPunc",
383
+ vad_config, lid_config, asr_config, punc_config,
384
+ enable_vad=1, enable_lid=1, enable_punc=1
385
+ )
386
+ asr_system = FireRedAsr2System(asr_system_config)
387
+
388
+ batch_uttid = ["hello_zh", "hello_en"]
389
+ batch_wav_path = ["assets/hello_zh.wav", "assets/hello_en.wav"]
390
+ for wav_path, uttid in zip(batch_wav_path, batch_uttid):
391
+ result = asr_system.process(wav_path, uttid)
392
+ print(result)
393
+ # {'uttid': 'hello_zh', 'text': '你好世界。', 'sentences': [{'start_ms': 440, 'end_ms': 1820, 'text': '你好世界。', 'asr_confidence': 0.868, 'lang': 'zh mandarin', 'lang_confidence': 0.999}], 'vad_segments_ms': [(440, 1820)], 'dur_s': 2.32, 'words': [{'start_ms': 540, 'end_ms': 700, 'text': '你'}, {'start_ms': 700, 'end_ms': 1100, 'text': '好'}, {'start_ms': 1100, 'end_ms': 1300, 'text': '世'}, {'start_ms': 1300, 'end_ms': 1765, 'text': '界'}], 'wav_path': 'assets/hello_zh.wav'}
394
+ # {'uttid': 'hello_en', 'text': 'Hello speech.', 'sentences': [{'start_ms': 260, 'end_ms': 1820, 'text': 'Hello speech.', 'asr_confidence': 0.933, 'lang': 'en', 'lang_confidence': 0.993}], 'vad_segments_ms': [(260, 1820)], 'dur_s': 2.24, 'words': [{'start_ms': 400, 'end_ms': 960, 'text': 'hello'}, {'start_ms': 960, 'end_ms': 1666, 'text': 'speech'}], 'wav_path': 'assets/hello_en.wav'}
395
+ ```
396
+
397
+ **Note:** The `FireRedASR2S` code has only been tested on Ubuntu 22.04; behavior on other Linux distributions or on Windows is unverified.
398
+
399
+
400
+ ## FAQ
401
+ **Q: What audio format is supported?**
402
+
403
+ 16kHz 16-bit mono PCM wav. Use ffmpeg to convert other formats: `ffmpeg -i <input_audio_path> -ar 16000 -ac 1 -acodec pcm_s16le -f wav <output_wav_path>`
404
+
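+ A quick programmatic check of the header is also possible with Python's standard `wave` module (a minimal sketch, not part of the FireRedASR2S API):
+
+ ```python
+ import wave
+
+ # Verify that a wav file is 16kHz, mono, 16-bit PCM before feeding it to the system.
+ w = wave.open("assets/hello_zh.wav", "rb")
+ ok = w.getframerate() == 16000 and w.getnchannels() == 1 and w.getsampwidth() == 2
+ w.close()
+ print("format ok" if ok else "please convert with ffmpeg first")
+ ```
+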
405
+ **Q: What are the input length limitations of ASR models?**
406
+
407
+ - **FireRedASR2-AED** supports audio input **up to 60s**. Input longer than 60s may cause hallucination issues, and input exceeding 200s will trigger positional encoding errors.
408
+ - **FireRedASR2-LLM** supports audio input **up to 40s**. The behavior for longer input is untested.
409
+ - **FireRedASR2-LLM Batch Beam Search**: When performing batch beam search with FireRedASR2-LLM, keep the input lengths within a batch similar even though attention masks are applied. If utterance lengths differ significantly, shorter utterances may suffer from **repetition issues**. Either sort your dataset by length or set `batch_size` to 1 to avoid this; a sketch of the length-sorting approach is shown after this list.
410
+
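+ A minimal sketch of the length-sorting workaround (the wav paths and batch size are illustrative; `model` is assumed to be a FireRedASR2-LLM `FireRedAsr2` instance created as in the Python API example above):
+
+ ```python
+ import wave
+
+ def wav_duration_s(path):
+     # Duration in seconds of a PCM wav file.
+     w = wave.open(path, "rb")
+     dur = w.getnframes() / w.getframerate()
+     w.close()
+     return dur
+
+ wav_paths = ["a.wav", "b.wav", "c.wav", "d.wav"]   # illustrative paths
+ wav_paths = sorted(wav_paths, key=wav_duration_s)  # keep similar lengths in the same batch
+
+ batch_size = 2
+ for i in range(0, len(wav_paths), batch_size):
+     batch = wav_paths[i:i + batch_size]
+     uttids = [p.rsplit("/", 1)[-1] for p in batch]
+     print(model.transcribe(uttids, batch))
+ ```
+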
411
+
412
+
413
+ ## Evaluation
414
+ ### FireRedASR2
415
+ Metrics: Character Error Rate (CER%) for Chinese and Word Error Rate (WER%) for English. Lower is better.
416
+
417
+ We evaluate FireRedASR2 on 24 public test sets covering Mandarin, 20+ Chinese dialects/accents, and singing.
418
+
419
+ - **Mandarin (4 test sets)**: 2.89% (LLM) / 3.05% (AED) average CER, outperforming Doubao-ASR (3.69%), Qwen3-ASR-1.7B (3.76%), Fun-ASR (4.16%) and Fun-ASR-Nano-2512 (4.55%).
420
+ - **Dialects (19 test sets)**: 11.55% (LLM) / 11.67% (AED) average CER, outperforming Doubao-ASR (15.39%), Qwen3-ASR-1.7B (11.85%), Fun-ASR (12.76%) and Fun-ASR-Nano-2512 (15.07%).
421
+
422
+ *Note: FRASR2=FireRedASR2, ws=WenetSpeech, md=MagicData, conv=Conversational, daily=Daily-use.*
423
+
424
+ |ID|Testset\CER\Model|FRASR2-LLM|FRASR2-AED|Doubao-ASR|Qwen3-ASR|Fun-ASR|
425
+ |:--:|:-----:|:-----:|:-----:|:-----:|:-----:|:-----:|
426
+ |Avg|**All(1-24)** |**9.67** |**9.80** |12.98 |10.12 |10.92 |
427
+ |Avg|**Mandarin(1-4)** |**2.89** |**3.05** |3.69 |3.76 |4.16 |
428
+ |Avg|**Dialect(5-23)** |**11.55**|**11.67**|15.39|11.85|12.76|
429
+ |1 |aishell1 |0.64 |0.57 |1.52 |1.48 |1.64 |
430
+ |2 |aishell2 |2.15 |2.51 |2.77 |2.71 |2.38 |
431
+ |3 |ws-net |4.44 |4.57 |5.73 |4.97 |6.85 |
432
+ |4 |ws-meeting |4.32 |4.53 |4.74 |5.88 |5.78 |
433
+ |5 |kespeech |3.08 |3.60 |5.38 |5.10 |5.36 |
434
+ |6 |ws-yue-short |5.14 |5.15 |10.51|5.82 |7.34 |
435
+ |7 |ws-yue-long |8.71 |8.54 |11.39|8.85 |10.14|
436
+ |8 |ws-chuan-easy |10.90|10.60|11.33|11.99|12.46|
437
+ |9 |ws-chuan-hard |20.71|21.35|20.77|21.63|22.49|
438
+ |10|md-heavy |7.42 |7.43 |7.69 |8.02 |9.13 |
439
+ |11|md-yue-conv |12.23|11.66|26.25|9.76 |33.71|
440
+ |12|md-yue-daily |3.61 |3.35 |12.82|3.66 |2.69 |
441
+ |13|md-yue-vehicle |4.50 |4.83 |8.66 |4.28 |6.00 |
442
+ |14|md-chuan-conv |13.18|13.07|11.77|14.35|14.01|
443
+ |15|md-chuan-daily |4.90 |5.17 |3.90 |4.93 |3.98 |
444
+ |16|md-shanghai-conv |28.70|27.02|45.15|29.77|25.49|
445
+ |17|md-shanghai-daily |24.94|24.18|44.06|23.93|12.55|
446
+ |18|md-wu |7.15 |7.14 |7.70 |7.57 |10.63|
447
+ |19|md-zhengzhou-conv |10.20|10.65|9.83 |9.55 |10.85|
448
+ |20|md-zhengzhou-daily|5.80 |6.26 |5.77 |5.88 |6.29 |
449
+ |21|md-wuhan |9.60 |10.81|9.94 |10.22|4.34 |
450
+ |22|md-tianjin |15.45|15.30|15.79|16.16|19.27|
451
+ |23|md-changsha |23.18|25.64|23.76|23.70|25.66|
452
+ |24|opencpop |1.12 |1.17 |4.36 |2.57 |3.05 |
453
+
454
+
455
+ ### FireRedVAD
456
+ <details>
457
+ <summary>Click to expand</summary>
458
+ We evaluate FireRedVAD on FLEURS-VAD-102, a multilingual VAD benchmark covering 102 languages.
459
+
460
+ FireRedVAD achieves SOTA performance, outperforming Silero-VAD, TEN-VAD, FunASR-VAD, and WebRTC-VAD.
461
+
462
+ |Metric\Model|FireRedVAD|Silero-VAD|TEN-VAD|FunASR-VAD|WebRTC-VAD|
463
+ |:-------:|:-----:|:------:|:------:|:------:|:------:|
464
+ |AUC-ROC↑ |**99.60**|97.99|97.81|- |- |
465
+ |F1 score↑ |**97.57**|95.95|95.19|90.91|52.30|
466
+ |False Alarm Rate↓ |**2.69** |9.41 |15.47|44.03|2.83 |
467
+ |Miss Rate↓|3.62 |3.95 |2.95 |0.42 |64.15|
468
+
469
+ FLEURS-VAD-102: We randomly selected ~100 audio files per language from the [FLEURS test set](https://huggingface.co/datasets/google/fleurs), resulting in 9,443 audio files with manually annotated binary VAD labels (speech=1, silence=0). This VAD test set will be open-sourced soon.
470
+
471
+ Note: FunASR-VAD achieves low Miss Rate but at the cost of high False Alarm Rate (44.03%), indicating over-prediction of speech segments.
472
+ </details>
473
+
474
+
475
+ ### FireRedLID
476
+ <details>
477
+ <summary>Click to expand</summary>
478
+ Metric: Utterance-level LID Accuracy (%). Higher is better.
479
+
480
+ We evaluate FireRedLID on multilingual and Chinese dialect benchmarks.
481
+
482
+ FireRedLID achieves SOTA performance, outperforming Whisper, SpeechBrain-LID, and Dolphin.
483
+
484
+ |Testset\Model|Languages|FireRedLID|Whisper|SpeechBrain|Dolphin|
485
+ |:-----------------:|:---------:|:---------:|:-----:|:---------:|:-----:|
486
+ |FLEURS test |82 languages |**97.18** |79.41 |92.91 |-|
487
+ |CommonVoice test |74 languages |**92.07** |80.81 |78.75 |-|
488
+ |KeSpeech + MagicData|20+ Chinese dialects/accents |**88.47** |-|-|69.01|
489
+ </details>
490
+
491
+
492
+ ### FireRedPunc
493
+ <details>
494
+ <summary>Click to expand</summary>
495
+ Metric: Precision/Recall/F1 Score (%). Higher is better.
496
+
497
+ We evaluate FireRedPunc on multi-domain Chinese and English benchmarks.
498
+
499
+ FireRedPunc achieves SOTA performance, outperforming FunASR-Punc (CT-Transformer).
500
+
501
+ |Testset\Model|#Sentences|FireRedPunc|FunASR-Punc|
502
+ |:------------------:|:---------:|:--------------:|:-----------------:|
503
+ |Multi-domain Chinese| 88,644 |**82.84 / 83.08 / 82.96** | 77.27 / 74.03 / 75.62 |
504
+ |Multi-domain English| 28,641 |**78.40 / 71.57 / 74.83** | 55.79 / 45.15 / 49.91 |
505
+ |Average F1 Score | - |**78.90** | 62.77 |
506
+
507
+ </details>
508
+
509
+
510
+ ## Acknowledgements
511
+ Thanks to the following open-source works:
512
+ - [Qwen](https://huggingface.co/Qwen)
513
+ - [WenetSpeech-Yue](https://github.com/ASLP-lab/WenetSpeech-Yue)
514
+ - [WenetSpeech-Chuan](https://github.com/ASLP-lab/WenetSpeech-Chuan)
515
+
516
+
517
+ ## Citation
518
+ ```bibtex
519
+ @article{xu2026fireredasr2s,
520
+ title={FireRedASR2S: A State-of-the-Art Industrial-Grade All-in-One Automatic Speech Recognition System},
521
+ author={Xu, Kaituo and Jia, Yan and Huang, Kai and Chen, Junjie and Li, Wenpeng and Liu, Kun and Xie, Feng-Long and Tang, Xu and Hu, Yao},
522
+ journal={arXiv preprint arXiv:2603.10420},
523
+ year={2026}
524
+ }
525
+ ```
app.py ADDED
@@ -0,0 +1,111 @@
1
+ import sys
2
+
3
+ import gradio as gr
4
+ import spaces
5
+ from huggingface_hub import snapshot_download
6
+
7
+ sys.path.append("./fireredasr2s")
8
+ from fireredasr2s import FireRedAsr2System, FireRedAsr2SystemConfig
9
+ from fireredasr2s import FireRedAsr2, FireRedAsr2Config
10
+
11
+
12
+ asr_model_aed = None
13
+ asr_model_llm = None
14
+
15
+
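+ # Build both ASR models once per process and cache them in the module-level globals above.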
16
+ def init_model(model_dir_aed, model_dir_llm):
17
+ global asr_model_aed
18
+ global asr_model_llm
19
+ if asr_model_aed is None:
20
+ asr_config_aed = FireRedAsr2Config(
21
+ use_gpu=True,
22
+ use_half=False,
23
+ beam_size=3,
24
+ nbest=1,
25
+ decode_max_len=0,
26
+ softmax_smoothing=1.25,
27
+ aed_length_penalty=0.6,
28
+ eos_penalty=1.0,
29
+ return_timestamp=True
30
+ )
31
+ asr_model_aed = FireRedAsr2.from_pretrained("aed", model_dir_aed, asr_config_aed)
32
+ if asr_model_llm is None:
33
+ asr_config_llm = FireRedAsr2Config(
34
+ use_gpu=True,
35
+ decode_min_len=0,
36
+ repetition_penalty=3.0,
37
+ llm_length_penalty=1.0,
38
+ temperature=1.0
39
+ )
40
+ asr_model_llm = FireRedAsr2.from_pretrained("llm", model_dir_llm, asr_config_llm)
41
+
42
+
43
+ @spaces.GPU(duration=20)
44
+ def asr_inference(audio_file):
45
+ if not audio_file:
46
+ return "Please upload a wav file"
47
+ batch_uttid = ["demo"]
48
+ batch_wav_path = [audio_file]
49
+ results = asr_model_aed.transcribe(
50
+ batch_uttid,
51
+ batch_wav_path
52
+ )
53
+ text_output = results[0]["text"]
54
+ return text_output
55
+
56
+
57
+ @spaces.GPU(duration=30)
58
+ def asr_inference_llm(audio_file):
59
+ if not audio_file:
60
+ return "Please upload a wav file"
61
+ batch_uttid = ["demo"]
62
+ batch_wav_path = [audio_file]
63
+ results = asr_model_llm.transcribe(
64
+ batch_uttid,
65
+ batch_wav_path,
66
+ )
67
+ text_output = results[0]["text"]
68
+ return text_output
69
+
70
+
71
+ with gr.Blocks(title="FireRedASR") as demo:
72
+ gr.HTML(
73
+ "<h1 style='text-align: center'>FireRedASR2 Demo</h1>"
74
+ )
75
+ gr.Markdown("Upload an audio file (wav) to get speech-to-text results.")
76
+
77
+ with gr.Row():
78
+ with gr.Column():
79
+ #audio_file = gr.Audio(label="Upload Audio", sources=["upload", "microphone"], type="filepath")
80
+ audio_file = gr.Audio(label="Upload wav file", sources=["upload"], type="filepath")
81
+
82
+ with gr.Column():
83
+ asr_button = gr.Button("Start Recognition (FireRedASR2-AED-L)", variant="primary")
84
+ text_output = gr.Textbox(label="Model Result (FireRedASR2-AED-L)", interactive=False, lines=3, max_lines=12)
85
+ asr_button_llm = gr.Button("Start Recognition (FireRedASR2-LLM-L)", variant="primary")
86
+ text_output_llm = gr.Textbox(label="Model Result (FireRedASR2-LLM-L)", interactive=False, lines=3, max_lines=12)
87
+
88
+ asr_button.click(
89
+ fn=asr_inference,
90
+ inputs=[audio_file],
91
+ outputs=[text_output]
92
+ )
93
+
94
+ asr_button_llm.click(
95
+ fn=asr_inference_llm,
96
+ inputs=[audio_file],
97
+ outputs=[text_output_llm]
98
+ )
99
+
100
+
101
+ if __name__ == "__main__":
102
+ # Download model
103
+ local_dir='pretrained_models/FireRedASR2-AED-L'
104
+ snapshot_download(repo_id='FireRedTeam/FireRedASR2-AED-L', local_dir=local_dir)
105
+ local_dir_llm='pretrained_models/FireRedASR2-LLM-L'
106
+ snapshot_download(repo_id='FireRedTeam/FireRedASR2-LLM-L', local_dir=local_dir_llm)
107
+ # Init model
108
+ init_model(local_dir, local_dir_llm)
109
+ # UI
110
+ demo.queue()
111
+ demo.launch()
fireredasr2s/__init__.py ADDED
@@ -0,0 +1,40 @@
1
+ # Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Kai Huang, Yan Jia, Junjie Chen, Wenpeng Li)
2
+
3
+ import os
4
+ import sys
5
+ import warnings
6
+ warnings.filterwarnings('ignore')
7
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # 0=ALL, 1=INFO, 2=WARNING, 3=ERROR
8
+ os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
9
+
10
+ __version__ = "0.0.1"
11
+
12
+ _PACKAGE_DIR = os.path.dirname(os.path.abspath(__file__))
13
+ _PROJECT_ROOT = os.path.dirname(_PACKAGE_DIR)
14
+ if _PROJECT_ROOT not in sys.path:
15
+ sys.path.insert(0, _PROJECT_ROOT)
16
+
17
+ from fireredasr2s.fireredasr2system import (
18
+ FireRedAsr2System,
19
+ FireRedAsr2SystemConfig
20
+ )
21
+
22
+
23
+ # API
24
+ __all__ = [
25
+ "__version__",
26
+ "FireRedAsr2System",
27
+ "FireRedAsr2SystemConfig",
28
+ "FireRedAsr2",
29
+ "FireRedAsr2Config",
30
+ "FireRedVad",
31
+ "FireRedVadConfig",
32
+ "FireRedStreamVad",
33
+ "FireRedStreamVadConfig",
34
+ "FireRedAed",
35
+ "FireRedAedConfig",
36
+ "FireRedLid",
37
+ "FireRedLidConfig",
38
+ "FireRedPunc",
39
+ "FireRedPuncConfig",
40
+ ]
fireredasr2s/fireredasr2/__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright 2026 Xiaohongshu. (Author: Kaituo Xu)
2
+
3
+ import os
4
+ import sys
5
+ import warnings
6
+ warnings.filterwarnings('ignore')
7
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # 0=ALL, 1=INFO, 2=WARNING, 3=ERROR
8
+ os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
9
+
10
+ __version__ = "0.0.1"
11
+
12
+ _CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
13
+
14
+ try:
15
+ from fireredasr2s.fireredasr2.asr import FireRedAsr2, FireRedAsr2Config
16
+ except ImportError:
17
+ if _CURRENT_DIR not in sys.path:
18
+ sys.path.insert(0, _CURRENT_DIR)
19
+ from .asr import FireRedAsr2, FireRedAsr2Config
20
+
21
+
22
+ # API
23
+ __all__ = [
24
+ "__version__",
25
+ "FireRedAsr2",
26
+ "FireRedAsr2Config",
27
+ ]
fireredasr2s/fireredasr2/asr.py ADDED
@@ -0,0 +1,241 @@
1
+ # Copyright 2026 Xiaohongshu. (Author: Kaituo Xu)
2
+
3
+ import logging
4
+ import os
5
+ import re
6
+ import time
7
+ import traceback
8
+ from dataclasses import dataclass
9
+
10
+ import torch
11
+
12
+ from .data.asr_feat import ASRFeatExtractor
13
+ from .models.fireredasr_aed import FireRedAsrAed
14
+ from .models.fireredasr_llm import FireRedAsrLlm
15
+ from .models.lstm_lm import LstmLm
16
+ from .models.param import count_model_parameters
17
+ from .tokenizer.aed_tokenizer import ChineseCharEnglishSpmTokenizer
18
+ from .tokenizer.llm_tokenizer import LlmTokenizerWrapper
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ @dataclass
24
+ class FireRedAsr2Config:
25
+ use_gpu: bool = True
26
+ use_half: bool = False
27
+ beam_size: int = 3
28
+ nbest: int = 1
29
+ decode_max_len: int = 0
30
+ softmax_smoothing: float = 1.25
31
+ aed_length_penalty: float = 0.6
32
+ eos_penalty: float = 1.0
33
+ return_timestamp: bool = False
34
+ decode_min_len: int = 0
35
+ repetition_penalty: float = 1.0
36
+ llm_length_penalty: float = 0.0
37
+ temperature: float = 1.0
38
+ elm_dir: str = ""
39
+ elm_weight: float = 0.0
40
+ def __post_init__(self):
41
+ pass
42
+
43
+
44
+ class FireRedAsr2:
45
+ @classmethod
46
+ def from_pretrained(cls, asr_type, model_dir, config=FireRedAsr2Config()):
47
+ assert asr_type in ["aed", "llm"]
48
+
49
+ cmvn_path = os.path.join(model_dir, "cmvn.ark")
50
+ feat_extractor = ASRFeatExtractor(cmvn_path)
51
+
52
+ if asr_type == "aed":
53
+ model_path = os.path.join(model_dir, "model.pth.tar")
54
+ dict_path = os.path.join(model_dir, "dict.txt")
55
+ spm_model = os.path.join(model_dir, "train_bpe1000.model")
56
+ model = load_fireredasr_aed_model(model_path)
57
+ tokenizer = ChineseCharEnglishSpmTokenizer(dict_path, spm_model)
58
+ elif asr_type == "llm":
59
+ model_path = os.path.join(model_dir, "model.pth.tar")
60
+ encoder_path = os.path.join(model_dir, "asr_encoder.pth.tar")
61
+ llm_dir = os.path.join(model_dir, "Qwen2-7B-Instruct")
62
+ model, tokenizer = load_firered_llm_model_and_tokenizer(
63
+ model_path, encoder_path, llm_dir)
64
+ elm = None
65
+ if config.elm_dir:
66
+ assert os.path.exists(config.elm_dir), f"{config.elm_dir}"
67
+ model_path = os.path.join(config.elm_dir, "model.pth.tar")
68
+ elm = load_lstm_lm(model_path)
69
+ elm.eval()
70
+ logger.info(elm)
71
+ count_model_parameters(model)
72
+ model.eval()
73
+ return cls(asr_type, feat_extractor, model, tokenizer, elm, config)
74
+
75
+ def __init__(self, asr_type, feat_extractor, model, tokenizer, elm, config):
76
+ self.asr_type = asr_type
77
+ self.feat_extractor = feat_extractor
78
+ self.model = model
79
+ self.tokenizer = tokenizer
80
+ self.elm = elm
81
+ self.config = config
82
+ logger.info(self.config)
83
+ if self.config.use_gpu:
84
+ if self.config.use_half:
85
+ self.model.half()
86
+ self.model.cuda()
87
+ if self.elm:
88
+ self.elm.cuda()
89
+ else:
90
+ self.model.cpu()
91
+
92
+ @torch.no_grad()
93
+ def transcribe(self, batch_uttid, batch_wav_path):
94
+ batch_uttid_origin = batch_uttid
95
+ try:
96
+ feats, lengths, durs, batch_wav_path, batch_uttid = \
97
+ self.feat_extractor(batch_wav_path, batch_uttid)
98
+ if feats is None:
99
+ return [{"uttid": uttid, "text":""} for uttid in batch_uttid_origin]
100
+ except:
101
+ traceback.print_exc()
102
+ return [{"uttid": uttid, "text":""} for uttid in batch_uttid_origin]
103
+ total_dur = sum(durs)
104
+ if self.config.use_gpu:
105
+ feats, lengths = feats.cuda(), lengths.cuda()
106
+ if self.config.use_half:
107
+ feats = feats.half()
108
+
109
+ if self.asr_type == "aed":
110
+ start_time = time.time()
111
+
112
+ try:
113
+ hyps = self.model.transcribe(
114
+ feats, lengths,
115
+ self.config.beam_size,
116
+ self.config.nbest,
117
+ self.config.decode_max_len,
118
+ self.config.softmax_smoothing,
119
+ self.config.aed_length_penalty,
120
+ self.config.eos_penalty,
121
+ self.config.return_timestamp,
122
+ self.elm,
123
+ self.config.elm_weight
124
+ )
125
+ except Exception as e:
126
+ traceback.print_exc()
127
+ hyps = []
128
+
129
+ elapsed = time.time() - start_time
130
+ rtf = elapsed / total_dur if total_dur > 0 else 0
131
+
132
+ results = []
133
+ for uttid, wav, hyp, dur in zip(batch_uttid, batch_wav_path, hyps, durs):
134
+ hyp = hyp[0] # only return 1-best
135
+ hyp_ids = [int(id) for id in hyp["yseq"].cpu()]
136
+ text = self.tokenizer.detokenize(hyp_ids)
137
+ text = re.sub(r"(<blank>)|(<sil>)", "", text)
138
+ results.append({"uttid": uttid, "text": text.lower(),
139
+ "confidence": round(hyp["confidence"].cpu().item(), 3),
140
+ "dur_s": round(dur, 3), "rtf": f"{rtf:.4f}"})
141
+ if type(wav) == str:
142
+ results[-1]["wav"] = wav
143
+ if self.config.return_timestamp:
144
+ results[-1]["timestamp"] = self._get_and_fix_timestamp(hyp, hyp_ids, dur)
145
+ return results
146
+
147
+ elif self.asr_type == "llm":
148
+ input_ids, attention_mask, _, _ = \
149
+ LlmTokenizerWrapper.preprocess_texts(
150
+ origin_texts=[""]*feats.size(0), tokenizer=self.tokenizer,
151
+ max_len=128, decode=True)
152
+ if self.config.use_gpu:
153
+ input_ids = input_ids.cuda()
154
+ attention_mask = attention_mask.cuda()
155
+ start_time = time.time()
156
+
157
+ try:
158
+ generated_ids = self.model.transcribe(
159
+ feats, lengths, input_ids, attention_mask,
160
+ self.config.beam_size,
161
+ self.config.decode_max_len,
162
+ self.config.decode_min_len,
163
+ self.config.repetition_penalty,
164
+ self.config.llm_length_penalty,
165
+ self.config.temperature
166
+ )
167
+ texts = self.tokenizer.batch_decode(generated_ids,
168
+ skip_special_tokens=True)
169
+ except Exception as e:
170
+ traceback.print_exc()
+ texts = []
171
+
172
+ elapsed = time.time() - start_time
173
+ rtf = elapsed / total_dur if total_dur > 0 else 0
174
+ results = []
175
+ for uttid, wav, text in zip(batch_uttid, batch_wav_path, texts):
176
+ results.append({"uttid": uttid, "text": text.lower(),
177
+ "rtf": f"{rtf:.4f}"})
178
+ if type(wav) == str:
179
+ results[-1]["wav"] = wav
180
+ return results
181
+
182
+ def _get_and_fix_timestamp(self, hyp, hyp_ids, dur):
183
+ r3 = lambda x: round(x, 3)
184
+ if "timestamp" not in hyp or hyp["timestamp"] is None:
185
+ timestamp = []
186
+ avg_dur = dur / len(hyp_ids) if len(hyp_ids) > 0 else 0
187
+ last_end = dur
188
+ for i, hyp_id in enumerate(hyp_ids):
189
+ token = self.tokenizer.detokenize([hyp_id], "", False)
190
+ start = min(max(0, i*avg_dur), last_end)
191
+ end = min((i+1)*avg_dur, dur)
192
+ last_end = end
193
+ timestamp.append([token.lower(), r3(start), r3(end)])
194
+ else:
195
+ starts, ends = hyp["timestamp"]
196
+ timestamp = []
197
+ last_end = dur
198
+ SHIFT = 0.06 # shift 60ms
199
+ for hyp_id, start, end in zip(hyp_ids, starts, ends):
200
+ token = self.tokenizer.detokenize([hyp_id], "", False)
201
+ start = min(max(0, start - SHIFT), last_end)
202
+ end = min(max(0, end - SHIFT), dur)
203
+ last_end = end
204
+ timestamp.append([token.lower(), r3(start), r3(end)])
205
+ # Fix case: start == dur and end == dur
206
+ for i in range(len(timestamp)):
207
+ idx = -(i+1)
208
+ _, start, end = timestamp[idx]
209
+ if abs(dur - start) < 0.001:
210
+ logger.info(f"start before {timestamp[idx]}")
211
+ timestamp[idx][1] = dur - (i+1)*0.001
212
+ logger.info(f"start after {timestamp[idx]}")
213
+ if i != 0 and abs(dur - end) < 0.001:
214
+ logger.info(f"end before {timestamp[idx]}")
215
+ timestamp[idx][2] = dur - i*0.001
216
+ logger.info(f"end after {timestamp[idx]}")
217
+ timestamp = self.tokenizer.merge_spm_timestamp(timestamp)
218
+ return timestamp
219
+
220
+
221
+ def load_fireredasr_aed_model(model_path):
222
+ package = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=False)
223
+ model = FireRedAsrAed.from_args(package["args"])
224
+ model.load_state_dict(package["model_state_dict"], strict=False)
225
+ return model
226
+
227
+
228
+ def load_firered_llm_model_and_tokenizer(model_path, encoder_path, llm_dir):
229
+ package = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=False)
230
+ package["args"].encoder_path = encoder_path
231
+ package["args"].llm_dir = llm_dir
232
+ model = FireRedAsrLlm.from_args(package["args"])
233
+ model.load_state_dict(package["model_state_dict"], strict=False)
234
+ tokenizer = LlmTokenizerWrapper.build_llm_tokenizer(llm_dir)
235
+ return model, tokenizer
236
+
237
+ def load_lstm_lm(model_path):
238
+ package = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=False)
239
+ model = LstmLm.from_args(package["args"])
240
+ model.load_state_dict(package["model_state_dict"], strict=False)
241
+ return model
fireredasr2s/fireredasr2/data/asr_feat.py ADDED
@@ -0,0 +1,124 @@
1
+ # Copyright 2026 Xiaohongshu. (Author: Kaituo Xu)
2
+
3
+ import math
4
+ import os
5
+
6
+ import kaldiio
7
+ import kaldi_native_fbank as knf
8
+ import numpy as np
9
+ import torch
10
+
11
+
12
+ class ASRFeatExtractor:
13
+ def __init__(self, kaldi_cmvn_file):
14
+ self.cmvn = CMVN(kaldi_cmvn_file) if kaldi_cmvn_file != "" else None
15
+ self.fbank = KaldifeatFbank(num_mel_bins=80, frame_length=25,
16
+ frame_shift=10, dither=0.0)
17
+
18
+ def __call__(self, wav_paths, wav_uttids):
19
+ feats = []
20
+ durs = []
21
+ return_wav_paths = []
22
+ return_wav_uttids = []
23
+
24
+ wav_datas = []
25
+ if isinstance(wav_paths[0], str):
26
+ for wav_path in wav_paths:
27
+ sample_rate, wav_np = kaldiio.load_mat(wav_path)
28
+ wav_datas.append([sample_rate, wav_np])
29
+ else:
30
+ wav_datas = wav_paths
31
+
32
+ for (sample_rate, wav_np), path, uttid in zip(wav_datas, wav_paths, wav_uttids):
33
+ dur = wav_np.shape[0] / sample_rate
34
+ fbank = self.fbank((sample_rate, wav_np))
35
+ if fbank.shape[0] < 1:
36
+ continue
37
+ if self.cmvn is not None:
38
+ fbank = self.cmvn(fbank)
39
+ fbank = torch.from_numpy(fbank).float()
40
+ feats.append(fbank)
41
+ durs.append(dur)
42
+ return_wav_paths.append(path)
43
+ return_wav_uttids.append(uttid)
44
+ if len(feats) > 0:
45
+ lengths = torch.tensor([feat.size(0) for feat in feats]).long()
46
+ feats_pad = self.pad_feat(feats, 0.0)
47
+ else:
48
+ lengths, feats_pad = None, None
49
+ return feats_pad, lengths, durs, return_wav_paths, return_wav_uttids
50
+
51
+ def pad_feat(self, xs, pad_value):
52
+ # type: (List[Tensor], int) -> Tensor
53
+ n_batch = len(xs)
54
+ max_len = max([xs[i].size(0) for i in range(n_batch)])
55
+ pad = torch.ones(n_batch, max_len, *xs[0].size()[1:]).to(xs[0].device).to(xs[0].dtype).fill_(pad_value)
56
+ for i in range(n_batch):
57
+ pad[i, :xs[i].size(0)] = xs[i]
58
+ return pad
59
+
60
+
61
+ class CMVN:
62
+ def __init__(self, kaldi_cmvn_file):
63
+ self.dim, self.means, self.inverse_std_variences = \
64
+ self.read_kaldi_cmvn(kaldi_cmvn_file)
65
+
66
+ def __call__(self, x, is_train=False):
67
+ assert x.shape[-1] == self.dim, "CMVN dim mismatch"
68
+ out = x - self.means
69
+ out = out * self.inverse_std_variences
70
+ return out
71
+
72
+ def read_kaldi_cmvn(self, kaldi_cmvn_file):
73
+ assert os.path.exists(kaldi_cmvn_file)
74
+ stats = kaldiio.load_mat(kaldi_cmvn_file)
75
+ assert stats.shape[0] == 2
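+ # Kaldi CMVN stats layout assumed here: a 2 x (dim+1) matrix; row 0 holds per-dim feature sums (frame count in the last column), row 1 holds per-dim sums of squares.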
76
+ dim = stats.shape[-1] - 1
77
+ count = stats[0, dim]
78
+ assert count >= 1
79
+ floor = 1e-20
80
+ means = []
81
+ inverse_std_variences = []
82
+ for d in range(dim):
83
+ mean = stats[0, d] / count
84
+ means.append(mean.item())
85
+ varience = (stats[1, d] / count) - mean*mean
86
+ if varience < floor:
87
+ varience = floor
88
+ istd = 1.0 / math.sqrt(varience)
89
+ inverse_std_variences.append(istd)
90
+ return dim, np.array(means), np.array(inverse_std_variences)
91
+
92
+
93
+
94
+ class KaldifeatFbank:
95
+ def __init__(self, num_mel_bins=80, frame_length=25, frame_shift=10,
96
+ dither=1.0):
97
+ self.dither = dither
98
+ opts = knf.FbankOptions()
99
+ opts.frame_opts.dither = dither
100
+ opts.mel_opts.num_bins = num_mel_bins
101
+ opts.frame_opts.snip_edges = True
102
+ opts.mel_opts.debug_mel = False
103
+ self.opts = opts
104
+
105
+ def __call__(self, wav, is_train=False):
106
+ if type(wav) is str:
107
+ sample_rate, wav_np = kaldiio.load_mat(wav)
108
+ elif type(wav) in [tuple, list] and len(wav) == 2:
109
+ sample_rate, wav_np = wav
110
+ assert len(wav_np.shape) == 1
111
+
112
+ dither = self.dither if is_train else 0.0
113
+ self.opts.frame_opts.dither = dither
114
+ fbank = knf.OnlineFbank(self.opts)
115
+
116
+ fbank.accept_waveform(sample_rate, wav_np.tolist())
117
+ feat = []
118
+ for i in range(fbank.num_frames_ready):
119
+ feat.append(fbank.get_frame(i))
120
+ if len(feat) == 0:
121
+ print("Check data, len(feat) == 0", wav, flush=True)
122
+ return np.zeros((0, self.opts.mel_opts.num_bins))
123
+ feat = np.vstack(feat)
124
+ return feat
fireredasr2s/fireredasr2/data/token_dict.py ADDED
@@ -0,0 +1,63 @@
1
+ # Copyright 2026 Xiaohongshu. (Author: Kaituo Xu)
2
+
3
+ import logging
4
+
5
+ logger = logging.getLogger(__name__)
6
+
7
+
8
+ class TokenDict:
9
+ def __init__(self, dict_path, unk=""):
10
+ assert dict_path != ""
11
+ self.id2word, self.word2id = self.read_dict(dict_path)
12
+ self.unk = unk
13
+ assert unk == "" or unk in self.word2id
14
+ self.unkid = self.word2id[unk] if unk else -1
15
+
16
+ def get(self, key, default):
17
+ if type(default) == str:
18
+ default = self.word2id[default]
19
+ return self.word2id.get(key, default)
20
+
21
+ def __getitem__(self, key):
22
+ if type(key) == str:
23
+ if self.unk:
24
+ return self.word2id.get(key, self.word2id[self.unk])
25
+ else:
26
+ return self.word2id[key]
27
+ elif type(key) == int:
28
+ return self.id2word[key]
29
+ else:
30
+ raise TypeError("Key should be str or int")
31
+
32
+ def __len__(self):
33
+ return len(self.id2word)
34
+
35
+ def __contains__(self, query):
36
+ if type(query) == str:
37
+ return query in self.word2id
38
+ elif type(query) == int:
39
+ return query in self.id2word
40
+ else:
41
+ raise TypeError("query should be str or int")
42
+
43
+ def read_dict(self, dict_path):
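+ # Each dict line is either "<token> <id>" or just "<token>" (the id then defaults to the line number); empty lines are mapped to a space token.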
44
+ id2word, word2id = [], {}
45
+ with open(dict_path, encoding='utf8') as f:
46
+ for i, line in enumerate(f):
47
+ tokens = line.strip().split()
48
+ if len(tokens) >= 2:
49
+ word, index = tokens[0], int(tokens[1])
50
+ elif len(tokens) == 1:
51
+ word, index = tokens[0], i
52
+ else: # empty line or space
53
+ logger.info(f"Find empty line or space '{line.strip()}' in {dict_path}:L{i}, set to ' '")
54
+ word, index = " ", i
55
+ assert len(id2word) == index
56
+ assert len(word2id) == index
57
+ if word == "<space>":
58
+ logger.info(f"NOTE: Find <space> in {dict_path}:L{i} and convert it to ' '")
59
+ word = " "
60
+ word2id[word] = index
61
+ id2word.append(word)
62
+ assert len(id2word) == len(word2id)
63
+ return id2word, word2id
fireredasr2s/fireredasr2/models/fireredasr_aed.py ADDED
@@ -0,0 +1,74 @@
1
+ # Copyright 2026 Xiaohongshu. (Author: Kaituo Xu)
2
+
3
+ import logging
+ import traceback
4
+
5
+ import torch
6
+ import torchaudio
7
+
8
+ from .module.conformer_encoder import ConformerEncoder
9
+ from .module.ctc import CTC
10
+ from .module.transformer_decoder import TransformerDecoder
+
+ logger = logging.getLogger(__name__)
+
13
+ class FireRedAsrAed(torch.nn.Module):
14
+ @classmethod
15
+ def from_args(cls, args):
16
+ return cls(args)
17
+
18
+ def __init__(self, args):
19
+ super().__init__()
20
+ self.sos_id = args.sos_id
21
+ self.eos_id = args.eos_id
22
+
23
+ self.encoder = ConformerEncoder(
24
+ args.idim, args.n_layers_enc, args.n_head, args.d_model,
25
+ args.residual_dropout, args.dropout_rate,
26
+ args.kernel_size, args.pe_maxlen)
27
+
28
+ self.decoder = TransformerDecoder(
29
+ args.sos_id, args.eos_id, args.pad_id, args.odim,
30
+ args.n_layers_dec, args.n_head, args.d_model,
31
+ args.residual_dropout, args.pe_maxlen)
32
+
33
+ self.ctc = CTC(args.odim, args.d_model)
34
+
35
+ def transcribe(self, padded_input, input_lengths,
36
+ beam_size=1, nbest=1, decode_max_len=0,
37
+ softmax_smoothing=1.0, length_penalty=0.0, eos_penalty=1.0,
38
+ return_timestamp=False, elm=None, elm_weight=0.0):
39
+ enc_outputs, enc_lengths, enc_mask = self.encoder(padded_input, input_lengths)
40
+ nbest_hyps = self.decoder.batch_beam_search(
41
+ enc_outputs, enc_mask,
42
+ beam_size, nbest, decode_max_len,
43
+ softmax_smoothing, length_penalty, eos_penalty,
44
+ elm, elm_weight)
45
+ if return_timestamp:
46
+ nbest_hyps = self.get_token_timestamp_torchaudio(enc_outputs, enc_lengths, nbest_hyps)
47
+ return nbest_hyps
48
+
49
+ def get_token_timestamp_torchaudio(self, enc_outputs, enc_lengths, nbest_hyps):
50
+ ctc_logits = self.ctc(enc_outputs)
51
+ enc_lengths = enc_lengths
52
+ for n in range(enc_outputs.size(0)):
53
+ try:
54
+ n_ctc_logits = ctc_logits[n, :enc_lengths[n]]
55
+ y = nbest_hyps[n][0]["yseq"]
56
+ y = y[y!=0] # 0 is blank
57
+ if y.numel() == 0 or n_ctc_logits.size()[0] == 0:
58
+ logger.debug("skip null output")
59
+ nbest_hyps[n][0]["timestamp"] = None
60
+ continue
61
+ elif y.numel() > n_ctc_logits.size()[0]:
62
+ nbest_hyps[n][0]["timestamp"] = None
63
+ continue
64
+
65
+ alignment, _ = torchaudio.functional.forced_align(
66
+ n_ctc_logits.unsqueeze(0), y.unsqueeze(0), blank=0)
67
+ alignment = alignment[0].cpu().tolist()
68
+ start_times, end_times = self.ctc.ctc_alignment_to_timestamp(alignment,
69
+ self.encoder.input_preprocessor.subsampling, blank_id=0)
70
+ nbest_hyps[n][0]["timestamp"] = (start_times, end_times)
71
+ except:
72
+ traceback.print_exc()
73
+ nbest_hyps[n][0]["timestamp"] = None
74
+ return nbest_hyps
fireredasr2s/fireredasr2/models/fireredasr_llm.py ADDED
@@ -0,0 +1,297 @@
1
+ # Copyright 2026 Xiaohongshu. (Author: Kaituo Xu)
2
+
3
+ import logging
4
+ import os
5
+ import random
6
+ import re
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from transformers import AutoModelForCausalLM
11
+
12
+ from ..models.fireredasr_aed import FireRedAsrAed
13
+ from ..models.module.adapter import Adapter
14
+ from ..models.param import count_model_parameters
15
+ from ..tokenizer.llm_tokenizer import DEFAULT_SPEECH_TOKEN, IGNORE_TOKEN_ID
16
+ from ..tokenizer.llm_tokenizer import LlmTokenizerWrapper
17
+
18
+
19
+ class FireRedAsrLlm(nn.Module):
20
+ @classmethod
21
+ def load_encoder(cls, model_path):
22
+ assert os.path.exists(model_path)
23
+ package = torch.load(model_path, map_location=lambda storage, loc: storage)
24
+ model = FireRedAsrAed.from_args(package["args"])
25
+ if "model_state_dict" in package:
26
+ model.load_state_dict(package["model_state_dict"], strict=False)
27
+ encoder = model.encoder
28
+ encoder_dim = encoder.odim
29
+ return encoder, encoder_dim
30
+
31
+ @classmethod
32
+ def from_args(cls, args):
33
+ logging.info(args)
34
+ logging.info("Build FireRedAsrLlm")
35
+ # Build Speech Encoder
36
+ encoder, encoder_dim = cls.load_encoder(args.encoder_path)
37
+ count_model_parameters(encoder)
38
+ if args.freeze_encoder:
39
+ logging.info(f"Frezee encoder")
40
+ for name, param in encoder.named_parameters():
41
+ param.requires_grad = False
42
+ encoder.eval()
43
+
44
+ # Training use torch.bfloat16
45
+ if args.use_flash_attn:
46
+ attn_implementation = "flash_attention_2"
47
+ if args.use_fp16:
48
+ #torch_dtype = torch.float16
49
+ torch_dtype = torch.bfloat16
50
+ else:
51
+ torch_dtype = torch.float32
52
+ else:
53
+ attn_implementation = "eager"
54
+ if args.use_fp16:
55
+ #torch_dtype = torch.float16
56
+ torch_dtype = torch.bfloat16
57
+ else:
58
+ torch_dtype = torch.float32
59
+
60
+ # Build LLM
61
+ llm = AutoModelForCausalLM.from_pretrained(
62
+ args.llm_dir,
63
+ attn_implementation=attn_implementation,
64
+ torch_dtype=torch_dtype,
65
+ )
66
+ count_model_parameters(llm)
67
+
68
+ # LLM Freeze or LoRA
69
+ llm_dim = llm.config.hidden_size
70
+ if args.freeze_llm:
71
+ logging.info("Freeze LLM")
72
+ for name, param in llm.named_parameters():
73
+ param.requires_grad = False
74
+ llm.eval()
75
+ else:
76
+ if args.use_lora:
77
+ from peft import LoraConfig, get_peft_model
78
+ lora_config = LoraConfig(
79
+ r=64,
80
+ lora_alpha=16,
81
+ target_modules=[
82
+ "q_proj",
83
+ "k_proj",
84
+ "v_proj",
85
+ "o_proj",
86
+ "up_proj",
87
+ "gate_proj",
88
+ "down_proj",
89
+ ],
90
+ lora_dropout=0.05,
91
+ task_type="CAUSAL_LM",
92
+ )
93
+ llm = get_peft_model(llm, lora_config)
94
+ llm.print_trainable_parameters()
95
+
96
+ tokenizer = LlmTokenizerWrapper.build_llm_tokenizer(args.llm_dir)
97
+ assert tokenizer.pad_token_id == tokenizer.convert_tokens_to_ids("<|endoftext|>")
98
+ llm.config.pad_token_id = tokenizer.pad_token_id
99
+ llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
100
+ llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
101
+ llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
102
+ DEFAULT_SPEECH_TOKEN
103
+ )
104
+
105
+ # Build projector
106
+ encoder_projector = Adapter(
107
+ encoder_dim, llm_dim, args.encoder_downsample_rate)
108
+ count_model_parameters(encoder_projector)
109
+
110
+ return cls(encoder, llm, encoder_projector,
111
+ args.freeze_encoder, args.freeze_llm)
112
+
113
+ def __init__(self, encoder, llm, encoder_projector,
114
+ freeze_encoder, freeze_llm):
115
+ super().__init__()
116
+ self.encoder = encoder
117
+ self.llm = llm
118
+ self.encoder_projector = encoder_projector
119
+ # args
120
+ self.freeze_encoder = freeze_encoder
121
+ self.freeze_llm = freeze_llm
122
+ self.llm_config = llm.config
123
+
124
+ def transcribe(self, padded_feat, feat_lengths, padded_input_ids, attention_mask,
125
+ beam_size=1, decode_max_len=0, decode_min_len=0,
126
+ repetition_penalty=1.0, llm_length_penalty=1.0, temperature=1.0):
127
+ encoder_outs, enc_lengths, enc_mask = self.encoder(padded_feat, feat_lengths)
128
+ speech_features, speech_lens = self.encoder_projector(encoder_outs, enc_lengths)
129
+ inputs_embeds = self.llm.get_input_embeddings()(padded_input_ids)
130
+
131
+ inputs_embeds, attention_mask, _ = \
132
+ self._merge_input_ids_with_speech_features(
133
+ speech_features.to(inputs_embeds.dtype), inputs_embeds, padded_input_ids, attention_mask,
134
+ speech_lens=speech_lens
135
+ )
136
+
137
+ max_new_tokens = speech_features.size(1) if decode_max_len < 1 else decode_max_len
138
+ max_new_tokens = max(1, max_new_tokens)
139
+
140
+ generated_ids = self.llm.generate(
141
+ inputs_embeds=inputs_embeds,
142
+ attention_mask=attention_mask,
143
+ max_new_tokens=max_new_tokens,
144
+ num_beams=beam_size,
145
+ do_sample=False,
146
+ min_length=decode_min_len,
147
+ repetition_penalty=repetition_penalty,
148
+ length_penalty=llm_length_penalty,
149
+ temperature=temperature,
150
+ bos_token_id=self.llm.config.bos_token_id,
151
+ eos_token_id=self.llm.config.eos_token_id,
152
+ pad_token_id=self.llm.config.pad_token_id,
153
+ )
154
+
155
+ return generated_ids
156
+
157
+
158
+ def _merge_input_ids_with_speech_features(
159
+ self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None,
160
+ speech_lens=None
161
+ ):
162
+ """
163
+ Modified from: https://github.com/k2-fsa/icefall/blob/master/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py
164
+ """
165
+ speech_lens = None
166
+ num_speechs, speech_len, embed_dim = speech_features.shape
167
+ batch_size, sequence_length = input_ids.shape
168
+ left_padding = not torch.sum(
169
+ input_ids[:, -1] == torch.tensor(self.llm.config.pad_token_id)
170
+ )
171
+ # 1. Create a mask to know where special speech tokens are
172
+ special_speech_token_mask = input_ids == self.llm.config.default_speech_token_id
173
+ num_special_speech_tokens = torch.sum(special_speech_token_mask, dim=-1)
174
+ # Compute the maximum embed dimension
175
+ max_embed_dim = (
176
+ num_special_speech_tokens.max() * (speech_len - 1)
177
+ ) + sequence_length
178
+ batch_indices, non_speech_indices = torch.where(
179
+ input_ids != self.llm.config.default_speech_token_id
180
+ )
181
+
182
+ # 2. Compute the positions where text should be written
183
+ # Calculate new positions for text tokens in merged speech-text sequence.
184
+ # `special_speech_token_mask` identifies speech tokens. Each speech token expands into `speech_len` speech embeddings, i.e. `speech_len - 1` extra positions.
185
+ # `torch.cumsum` computes how each speech token shifts subsequent text token positions.
186
+ # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
187
+ new_token_positions = (
188
+ torch.cumsum((special_speech_token_mask * (speech_len - 1) + 1), -1) - 1
189
+ ) # (N,U)
190
+ nb_speech_pad = max_embed_dim - 1 - new_token_positions[:, -1]
191
+ if left_padding:
192
+ new_token_positions += nb_speech_pad[:, None] # offset for left padding
193
+ text_to_overwrite = new_token_positions[batch_indices, non_speech_indices]
194
+
195
+ # 3. Create the full embedding, already padded to the maximum position
196
+ final_embedding = torch.zeros(
197
+ batch_size,
198
+ max_embed_dim,
199
+ embed_dim,
200
+ dtype=inputs_embeds.dtype,
201
+ device=inputs_embeds.device,
202
+ )
203
+ final_attention_mask = torch.zeros(
204
+ batch_size,
205
+ max_embed_dim,
206
+ dtype=attention_mask.dtype,
207
+ device=inputs_embeds.device,
208
+ )
209
+ if labels is not None:
210
+ final_labels = torch.full(
211
+ (batch_size, max_embed_dim),
212
+ IGNORE_TOKEN_ID,
213
+ dtype=input_ids.dtype,
214
+ device=input_ids.device,
215
+ )
216
+ # In case the speech encoder or the language model has been offloaded to CPU, we need to manually
217
+ # set the corresponding tensors into their correct target device.
218
+ target_device = inputs_embeds.device
219
+ batch_indices, non_speech_indices, text_to_overwrite = (
220
+ batch_indices.to(target_device),
221
+ non_speech_indices.to(target_device),
222
+ text_to_overwrite.to(target_device),
223
+ )
224
+ attention_mask = attention_mask.to(target_device)
225
+
226
+ # 4. Fill the embeddings based on the mask. If we have ["hey" "<speech>", "how", "are"]
227
+ # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the speech features
228
+ final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[
229
+ batch_indices, non_speech_indices
230
+ ]
231
+ final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[
232
+ batch_indices, non_speech_indices
233
+ ]
234
+ if labels is not None:
235
+ final_labels[batch_indices, text_to_overwrite] = labels[
236
+ batch_indices, non_speech_indices
237
+ ]
238
+
239
+ # 5. Fill the embeddings corresponding to the speechs. Anything that is not `text_positions` needs filling (#29835)
240
+ speech_to_overwrite = torch.full(
241
+ (batch_size, max_embed_dim),
242
+ True,
243
+ dtype=torch.bool,
244
+ device=inputs_embeds.device,
245
+ )
246
+ speech_to_overwrite[batch_indices, text_to_overwrite] = False
247
+ if speech_lens is not None:
248
+ speech_pad_position = speech_to_overwrite.cumsum(-1) <= speech_lens[:, None]
249
+ speech_to_overwrite &= speech_to_overwrite.cumsum(-1) - 1 >= nb_speech_pad[
250
+ :, None
251
+ ].to(target_device)
252
+ if speech_lens is not None:
253
+ if torch.any(speech_lens > speech_len):
254
+ raise ValueError(
255
+ f"speech_lens contains values ({speech_lens.max()}) larger than "
256
+ f"speech_len ({speech_len})"
257
+ )
258
+
259
+ speech_cumsum = speech_to_overwrite.long().cumsum(-1)
260
+ speech_position_counter = torch.where(speech_to_overwrite, speech_cumsum - 1, 0)
261
+ valid_speech_positions = speech_position_counter < speech_lens[:, None].to(target_device)
262
+
263
+ speech_to_overwrite &= valid_speech_positions
264
+ if speech_to_overwrite.sum().item() != int(speech_lens.sum().item()):
265
+ raise ValueError(
266
+ f"speech_lens and speech token distribution mismatch: "
267
+ f"expected total speech frames {speech_lens.sum().item()}, "
268
+ f"but got {speech_to_overwrite.sum().item()} positions."
269
+ )
270
+ batch_idx, seq_idx = torch.where(speech_to_overwrite)
271
+ speech_feature_idx = speech_position_counter[speech_to_overwrite]
272
+ final_embedding[batch_idx, seq_idx] = speech_features[batch_idx, speech_feature_idx].to(target_device)
273
+ else:
274
+ if speech_to_overwrite.sum() != speech_features.shape[:-1].numel():
275
+ raise ValueError(
276
+ f"The input provided to the model are wrong. The number of speech tokens is {speech_to_overwrite.sum()} while"
277
+ f" the number of speech given to the model is {num_speechs}. This prevents correct indexing and breaks batch generation."
278
+ )
279
+ final_embedding[speech_to_overwrite] = (
280
+ speech_features.contiguous().reshape(-1, embed_dim)[:speech_to_overwrite.sum()].to(target_device)
281
+ )
282
+
283
+ final_attention_mask[speech_to_overwrite] = 1
284
+
285
+ # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
286
+ batch_indices_pad, pad_indices = torch.where(
287
+ input_ids == self.llm.config.pad_token_id
288
+ )
289
+ if len(batch_indices_pad) > 0:
290
+ indices_to_mask = new_token_positions[batch_indices_pad, pad_indices]
291
+ final_embedding[batch_indices_pad, indices_to_mask] = 0
292
+ final_attention_mask[batch_indices_pad, indices_to_mask] = 0
293
+
294
+ if labels is None:
295
+ final_labels = None
296
+
297
+ return final_embedding, final_attention_mask, final_labels
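A minimal sketch of the position arithmetic used in step 2 above, assuming a single speech placeholder that expands to `speech_len` projected frames (the ids below are toy values, not the real tokenizer's):

import torch

speech_len = 4                                   # assumed number of projected speech frames
input_ids = torch.tensor([[11, 99, 12, 13]])     # 99 stands in for the <speech> token id
special = input_ids == 99
# each speech token shifts the later text tokens by (speech_len - 1) positions
new_token_positions = torch.cumsum(special * (speech_len - 1) + 1, -1) - 1
print(new_token_positions)                       # tensor([[0, 4, 5, 6]])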
fireredasr2s/fireredasr2/models/lstm_lm.py ADDED
@@ -0,0 +1,65 @@
1
+ # Copyright 2026 Xiaohongshu. (Author: Kaituo Xu)
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
7
+
8
+
9
+ class LstmLm(nn.Module):
10
+ @classmethod
11
+ def from_args(cls, args):
12
+ args.padding_idx = 2
13
+ args.sos_id = 3
14
+ args.eos_id = 4
15
+ return cls(args)
16
+
17
+ def __init__(self, args):
18
+ super().__init__()
19
+ self.embedding = nn.Embedding(args.idim, args.embedding_dim,
20
+ padding_idx=args.padding_idx)
21
+ self.lstm = nn.LSTM(args.embedding_dim, args.hidden_size, args.num_layers,
22
+ batch_first=True, dropout=args.dropout)
23
+ self.fc_in_dim = args.embedding_dim
24
+ self.fc = nn.Linear(args.embedding_dim, args.odim)
25
+
26
+ self._tie_weights(args)
27
+ self.sos_id = args.sos_id
28
+ self.eos_id = args.eos_id
29
+ self.ignore_index = args.padding_idx
30
+
31
+ @torch.jit.ignore
32
+ def _tie_weights(self, args):
33
+ if args.tie_weights:
34
+ if self.fc_in_dim != args.embedding_dim or args.idim != args.odim:
35
+ raise ValueError('When using the tied flag, fc_in_dim must equal embedding_dim and idim must equal odim')
36
+ self.fc.weight = self.embedding.weight
37
+
38
+ @torch.jit.export
39
+ def init_hidden(self, tensor, batch_size):
40
+ # type: (Tensor, int) -> Tuple[Tensor, Tensor]
41
+ return (tensor.new_zeros(self.lstm.num_layers, batch_size, self.lstm.hidden_size).float(),
42
+ tensor.new_zeros(self.lstm.num_layers, batch_size, self.lstm.hidden_size).float())
43
+
44
+ @torch.jit.export
45
+ def forward_model(self, padded_inputs, lengths=None, hidden=None):
46
+ # type: (Tensor, Optional[Tensor], Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
47
+ # Embedding Layer
48
+ padded_inputs = self.embedding(padded_inputs) # N, T, D
49
+ # LSTM Layers
50
+ if lengths is None:
51
+ output, new_hidden = self.lstm(padded_inputs, hidden)
52
+ else:
53
+ lengths = lengths.cpu().int()
54
+ total_length = padded_inputs.size(1) # get the max sequence length
55
+ packed_input = pack_padded_sequence(padded_inputs, lengths,
56
+ batch_first=True,
57
+ enforce_sorted=False)
58
+ #self.lstm.flatten_parameters()
59
+ packed_output, new_hidden = self.lstm(packed_input, hidden)
60
+ output, _ = pad_packed_sequence(packed_output,
61
+ batch_first=True,
62
+ total_length=total_length)
63
+ # Output Layer
64
+ score = self.fc(output) # (N, T, V)
65
+ return score, new_hidden
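A minimal usage sketch of the LM above, with the LstmLm class in scope and an assumed argparse.Namespace config (vocabulary size and dimensions are illustrative):

import argparse
import torch

args = argparse.Namespace(idim=1000, odim=1000, embedding_dim=256, hidden_size=256,
                          num_layers=2, dropout=0.0, tie_weights=True)
lm = LstmLm.from_args(args)
scores, hidden = lm.forward_model(torch.randint(5, 1000, (2, 7)))   # (batch, time) token ids
print(scores.shape)   # torch.Size([2, 7, 1000])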
fireredasr2s/fireredasr2/models/module/adapter.py ADDED
@@ -0,0 +1,32 @@
1
+ # Copyright 2026 Xiaohongshu. (Author: Kaituo Xu)
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class Adapter(nn.Module):
8
+ def __init__(self, encoder_dim, llm_dim, downsample_rate=2):
9
+ super().__init__()
10
+ self.ds = downsample_rate
11
+ self.linear1 = nn.Linear(encoder_dim * downsample_rate, llm_dim)
12
+ self.relu = nn.ReLU()
13
+ self.linear2 = nn.Linear(llm_dim, llm_dim)
14
+
15
+ def forward(self, x, x_lens):
16
+ batch_size, seq_len, feat_dim = x.size()
17
+ num_frames_to_discard = seq_len % self.ds
18
+ if num_frames_to_discard > 0:
19
+ x = x[:, :-num_frames_to_discard, :]
20
+ seq_len = x.size(1)
21
+
22
+ x = x.contiguous()
23
+ x = x.view(
24
+ batch_size, seq_len // self.ds, feat_dim * self.ds
25
+ )
26
+
27
+ x = self.linear1(x)
28
+ x = self.relu(x)
29
+ x = self.linear2(x)
30
+
31
+ new_x_lens = torch.clamp(x_lens, max=seq_len) // self.ds
32
+ return x, new_x_lens
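A quick sketch of the Adapter's downsampling: adjacent frames are stacked, halving the frame rate before projecting into the LLM dimension (class above in scope; dimensions are illustrative):

import torch

adapter = Adapter(encoder_dim=8, llm_dim=16, downsample_rate=2)
x = torch.randn(1, 10, 8)                    # (batch, frames, encoder_dim)
y, y_lens = adapter(x, torch.tensor([10]))
print(y.shape, y_lens)                       # torch.Size([1, 5, 16]) tensor([5])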
fireredasr2s/fireredasr2/models/module/conformer_encoder.py ADDED
@@ -0,0 +1,324 @@
1
+ # Copyright 2026 Xiaohongshu. (Author: Kaituo Xu)
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ class ConformerEncoder(nn.Module):
9
+ def __init__(self, idim, n_layers, n_head, d_model,
10
+ residual_dropout=0.1, dropout_rate=0.1, kernel_size=33,
11
+ pe_maxlen=5000):
12
+ super().__init__()
13
+ self.odim = d_model
14
+
15
+ self.input_preprocessor = Conv2dSubsampling(idim, d_model)
16
+ self.positional_encoding = RelPositionalEncoding(d_model)
17
+ self.dropout = nn.Dropout(residual_dropout)
18
+
19
+ self.layer_stack = nn.ModuleList()
20
+ for l in range(n_layers):
21
+ block = RelPosEmbConformerBlock(d_model, n_head,
22
+ residual_dropout,
23
+ dropout_rate, kernel_size)
24
+ self.layer_stack.append(block)
25
+
26
+ def forward(self, padded_input, input_lengths, pad=True):
27
+ if pad:
28
+ padded_input = F.pad(padded_input,
29
+ (0, 0, 0, self.input_preprocessor.context - 1), 'constant', 0.0)
30
+ src_mask = self.padding_position_is_0(padded_input, input_lengths)
31
+
32
+ embed_output, input_lengths, src_mask = self.input_preprocessor(padded_input, src_mask)
33
+ enc_output = self.dropout(embed_output)
34
+
35
+ pos_emb = self.dropout(self.positional_encoding(embed_output))
36
+
37
+ enc_outputs = []
38
+ for enc_layer in self.layer_stack:
39
+ enc_output = enc_layer(enc_output, pos_emb, slf_attn_mask=src_mask,
40
+ pad_mask=src_mask)
41
+ enc_outputs.append(enc_output)
42
+
43
+ return enc_output, input_lengths, src_mask
44
+
45
+ def padding_position_is_0(self, padded_input, input_lengths):
46
+ N, T = padded_input.size()[:2]
47
+ mask = torch.ones((N, T)).to(padded_input.device)
48
+ for i in range(N):
49
+ mask[i, input_lengths[i]:] = 0
50
+ mask = mask.unsqueeze(dim=1)
51
+ return mask.to(torch.uint8)
52
+
53
+
54
+ class RelPosEmbConformerBlock(nn.Module):
55
+ def __init__(self, d_model, n_head,
56
+ residual_dropout=0.1,
57
+ dropout_rate=0.1, kernel_size=33):
58
+ super().__init__()
59
+ self.ffn1 = ConformerFeedForward(d_model, dropout_rate)
60
+ self.mhsa = RelPosMultiHeadAttention(n_head, d_model,
61
+ residual_dropout)
62
+ self.conv = ConformerConvolution(d_model, kernel_size,
63
+ dropout_rate)
64
+ self.ffn2 = ConformerFeedForward(d_model, dropout_rate)
65
+ self.layer_norm = nn.LayerNorm(d_model)
66
+
67
+ def forward(self, x, pos_emb, slf_attn_mask=None, pad_mask=None):
68
+ out = 0.5 * x + 0.5 * self.ffn1(x)
69
+ out = self.mhsa(out, out, out, pos_emb, mask=slf_attn_mask)[0]
70
+ out = self.conv(out, pad_mask)
71
+ out = 0.5 * out + 0.5 * self.ffn2(out)
72
+ out = self.layer_norm(out)
73
+ return out
74
+
75
+
76
+ class Swish(nn.Module):
77
+ def forward(self, x):
78
+ return x * torch.sigmoid(x)
79
+
80
+
81
+ class Conv2dSubsampling(nn.Module):
82
+ def __init__(self, idim, d_model, out_channels=32):
83
+ super().__init__()
84
+ self.conv = nn.Sequential(
85
+ nn.Conv2d(1, out_channels, 3, 2),
86
+ nn.ReLU(),
87
+ nn.Conv2d(out_channels, out_channels, 3, 2),
88
+ nn.ReLU(),
89
+ )
90
+ subsample_idim = ((idim - 1) // 2 - 1) // 2
91
+ self.out = nn.Linear(out_channels * subsample_idim, d_model)
92
+
93
+ self.subsampling = 4
94
+ left_context = right_context = 3 # both exclude the current frame
95
+ self.context = left_context + 1 + right_context # 7
96
+
97
+ def forward(self, x, x_mask):
98
+ x = x.unsqueeze(1)
99
+ x = self.conv(x)
100
+ N, C, T, D = x.size()
101
+ x = self.out(x.transpose(1, 2).contiguous().view(N, T, C * D))
102
+ mask = x_mask[:, :, :-2:2][:, :, :-2:2]
103
+ input_lengths = mask[:, -1, :].sum(dim=-1)
104
+ return x, input_lengths, mask
105
+
106
+
107
+ class RelPositionalEncoding(torch.nn.Module):
108
+ def __init__(self, d_model, max_len=5000):
109
+ super().__init__()
110
+ pe_positive = torch.zeros(max_len, d_model, requires_grad=False)
111
+ pe_negative = torch.zeros(max_len, d_model, requires_grad=False)
112
+ position = torch.arange(0, max_len).unsqueeze(1).float()
113
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() *
114
+ -(torch.log(torch.tensor(10000.0)).item()/d_model))
115
+ pe_positive[:, 0::2] = torch.sin(position * div_term)
116
+ pe_positive[:, 1::2] = torch.cos(position * div_term)
117
+ pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
118
+ pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
119
+
120
+ pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
121
+ pe_negative = pe_negative[1:].unsqueeze(0)
122
+ pe = torch.cat([pe_positive, pe_negative], dim=1)
123
+ self.register_buffer('pe', pe)
124
+
125
+ def forward(self, x):
126
+ # Tmax = 2 * max_len - 1
127
+ Tmax, T = self.pe.size(1), x.size(1)
128
+ pos_emb = self.pe[:, Tmax // 2 - T + 1 : Tmax // 2 + T].clone().detach()
129
+ return pos_emb
130
+
131
+
132
+ class ConformerFeedForward(nn.Module):
133
+ def __init__(self, d_model, dropout_rate=0.1):
134
+ super().__init__()
135
+ pre_layer_norm = nn.LayerNorm(d_model)
136
+ linear_expand = nn.Linear(d_model, d_model*4)
137
+ nonlinear = Swish()
138
+ dropout_pre = nn.Dropout(dropout_rate)
139
+ linear_project = nn.Linear(d_model*4, d_model)
140
+ dropout_post = nn.Dropout(dropout_rate)
141
+ self.net = nn.Sequential(pre_layer_norm,
142
+ linear_expand,
143
+ nonlinear,
144
+ dropout_pre,
145
+ linear_project,
146
+ dropout_post)
147
+
148
+ def forward(self, x):
149
+ residual = x
150
+ output = self.net(x)
151
+ output = output + residual
152
+ return output
153
+
154
+
155
+ class ConformerConvolution(nn.Module):
156
+ def __init__(self, d_model, kernel_size=33, dropout_rate=0.1):
157
+ super().__init__()
158
+ assert kernel_size % 2 == 1
159
+ self.pre_layer_norm = nn.LayerNorm(d_model)
160
+ self.pointwise_conv1 = nn.Conv1d(d_model, d_model*4, kernel_size=1, bias=False)
161
+ self.glu = F.glu
162
+ self.padding = (kernel_size - 1) // 2
163
+ self.depthwise_conv = nn.Conv1d(d_model*2, d_model*2,
164
+ kernel_size, stride=1,
165
+ padding=self.padding,
166
+ groups=d_model*2, bias=False)
167
+ self.batch_norm = nn.LayerNorm(d_model*2)
168
+ self.swish = Swish()
169
+ self.pointwise_conv2 = nn.Conv1d(d_model*2, d_model, kernel_size=1, bias=False)
170
+ self.dropout = nn.Dropout(dropout_rate)
171
+
172
+ def forward(self, x, mask=None):
173
+ residual = x
174
+ out = self.pre_layer_norm(x)
175
+ out = out.transpose(1, 2)
176
+ if mask is not None:
177
+ out.masked_fill_(mask.ne(1), 0.0)
178
+ out = self.pointwise_conv1(out)
179
+ out = F.glu(out, dim=1)
180
+ out = self.depthwise_conv(out)
181
+
182
+ out = out.transpose(1, 2)
183
+ out = self.swish(self.batch_norm(out))
184
+ out = out.transpose(1, 2)
185
+
186
+ out = self.dropout(self.pointwise_conv2(out))
187
+ if mask is not None:
188
+ out.masked_fill_(mask.ne(1), 0.0)
189
+ out = out.transpose(1, 2)
190
+ return out + residual
191
+
192
+
193
+ class EncoderMultiHeadAttention(nn.Module):
194
+ def __init__(self, n_head, d_model,
195
+ residual_dropout=0.1):
196
+ super().__init__()
197
+ assert d_model % n_head == 0
198
+ self.n_head = n_head
199
+ self.d_k = d_model // n_head
200
+ self.d_v = self.d_k
201
+
202
+ self.w_qs = nn.Linear(d_model, n_head * self.d_k, bias=False)
203
+ self.w_ks = nn.Linear(d_model, n_head * self.d_k, bias=False)
204
+ self.w_vs = nn.Linear(d_model, n_head * self.d_v, bias=False)
205
+
206
+ self.layer_norm_q = nn.LayerNorm(d_model)
207
+ self.layer_norm_k = nn.LayerNorm(d_model)
208
+ self.layer_norm_v = nn.LayerNorm(d_model)
209
+
210
+ self.attention = ScaledDotProductAttention(temperature=self.d_k ** 0.5)
211
+ self.fc = nn.Linear(n_head * self.d_v, d_model, bias=False)
212
+ self.dropout = nn.Dropout(residual_dropout)
213
+
214
+ def forward(self, q, k, v, mask=None):
215
+ sz_b, len_q = q.size(0), q.size(1)
216
+
217
+ residual = q
218
+ q, k, v = self.forward_qkv(q, k, v)
219
+
220
+ output, attn = self.attention(q, k, v, mask=mask)
221
+
222
+ output = self.forward_output(output, residual, sz_b, len_q)
223
+ return output, attn
224
+
225
+ def forward_qkv(self, q, k, v):
226
+ d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
227
+ sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
228
+
229
+ q = self.layer_norm_q(q)
230
+ k = self.layer_norm_k(k)
231
+ v = self.layer_norm_v(v)
232
+
233
+ q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
234
+ k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
235
+ v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
236
+ q = q.transpose(1, 2)
237
+ k = k.transpose(1, 2)
238
+ v = v.transpose(1, 2)
239
+ return q, k, v
240
+
241
+ def forward_output(self, output, residual, sz_b, len_q):
242
+ output = output.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
243
+ fc_out = self.fc(output)
244
+ output = self.dropout(fc_out)
245
+ output = output + residual
246
+ return output
247
+
248
+
249
+ class ScaledDotProductAttention(nn.Module):
250
+ def __init__(self, temperature):
251
+ super().__init__()
252
+ self.temperature = temperature
253
+ self.dropout = nn.Dropout(0.0)
254
+ self.INF = float('inf')
255
+
256
+ def forward(self, q, k, v, mask=None):
257
+ attn = torch.matmul(q, k.transpose(2, 3)) / self.temperature
258
+ output, attn = self.forward_attention(attn, v, mask)
259
+ return output, attn
260
+
261
+ def forward_attention(self, attn, v, mask=None):
262
+ if mask is not None:
263
+ mask = mask.unsqueeze(1)
264
+ mask = mask.eq(0)
265
+ attn = attn.masked_fill(mask, -self.INF)
266
+ attn = torch.softmax(attn, dim=-1).masked_fill(mask, 0.0)
267
+ else:
268
+ attn = torch.softmax(attn, dim=-1)
269
+
270
+ d_attn = self.dropout(attn)
271
+ output = torch.matmul(d_attn, v)
272
+
273
+ return output, attn
274
+
275
+
276
+ class RelPosMultiHeadAttention(EncoderMultiHeadAttention):
277
+ def __init__(self, n_head, d_model,
278
+ residual_dropout=0.1):
279
+ super().__init__(n_head, d_model,
280
+ residual_dropout)
281
+ d_k = d_model // n_head
282
+ self.scale = 1.0 / (d_k ** 0.5)
283
+ self.linear_pos = nn.Linear(d_model, n_head * d_k, bias=False)
284
+ self.pos_bias_u = nn.Parameter(torch.FloatTensor(n_head, d_k))
285
+ self.pos_bias_v = nn.Parameter(torch.FloatTensor(n_head, d_k))
286
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
287
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
288
+
289
+ def _rel_shift(self, x):
290
+ N, H, T1, T2 = x.size()
291
+ zero_pad = torch.zeros((N, H, T1, 1), device=x.device, dtype=x.dtype)
292
+ x_padded = torch.cat([zero_pad, x], dim=-1)
293
+
294
+ x_padded = x_padded.view(N, H, T2 + 1, T1)
295
+ x = x_padded[:, :, 1:].view_as(x)
296
+ x = x[:, :, :, : x.size(-1) // 2 + 1]
297
+ return x
298
+
299
+ def forward(self, q, k, v, pos_emb, mask=None):
300
+ sz_b, len_q = q.size(0), q.size(1)
301
+
302
+ residual = q
303
+ q, k, v = self.forward_qkv(q, k, v)
304
+
305
+ q = q.transpose(1, 2)
306
+ n_batch_pos = pos_emb.size(0)
307
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.n_head, self.d_k)
308
+ p = p.transpose(1, 2)
309
+
310
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
311
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
312
+
313
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
314
+
315
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
316
+ matrix_bd = self._rel_shift(matrix_bd)
317
+
318
+ attn_scores = matrix_ac + matrix_bd
319
+ attn_scores.mul_(self.scale)
320
+
321
+ output, attn = self.attention.forward_attention(attn_scores, v, mask=mask)
322
+
323
+ output = self.forward_output(output, residual, sz_b, len_q)
324
+ return output, attn
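A small sketch of the roughly 4x time subsampling performed by Conv2dSubsampling (class above in scope; an 80-dim fbank input is assumed):

import torch

sub = Conv2dSubsampling(idim=80, d_model=512)
x = torch.randn(1, 103, 80)                        # (batch, frames, fbank_dim)
mask = torch.ones(1, 1, 103, dtype=torch.uint8)
y, lengths, new_mask = sub(x, mask)
print(y.shape, lengths)                            # torch.Size([1, 25, 512]) tensor([25])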
fireredasr2s/fireredasr2/models/module/ctc.py ADDED
@@ -0,0 +1,119 @@
1
+ # Copyright 2026 Xiaohongshu. (Author: Kaituo Xu)
2
+
3
+ from typing import List
+
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+
8
+ class CTC(torch.nn.Module):
9
+ def __init__(self, odim, encoder_output_size):
10
+ super().__init__()
11
+ self.ctc_lo = torch.nn.Linear(encoder_output_size, odim)
12
+
13
+ def forward(self, encoder_output_pad):
14
+ """encoder_output_pad: (N, T, H)"""
15
+ return F.log_softmax(self.ctc_lo(encoder_output_pad), dim=2)
16
+
17
+ @classmethod
18
+ def ctc_align(cls, ctc_probs, y, blank_id=0):
19
+ """ctc forced alignment.
20
+
21
+ Args:
22
+ torch.Tensor ctc_probs: hidden state sequence, 2d tensor (T, D)
23
+ torch.Tensor y: id sequence tensor 1d tensor (L)
24
+ int blank_id: blank symbol index
25
+ Returns:
26
+ torch.Tensor: alignment result
27
+ """
28
+ y_insert_blank = insert_blank(y, blank_id)
29
+
30
+ log_alpha = torch.zeros((ctc_probs.size(0), len(y_insert_blank)))
31
+ log_alpha = log_alpha - float('inf') # log of zero
32
+ state_path = (torch.zeros(
33
+ (ctc_probs.size(0), len(y_insert_blank)), dtype=torch.int16) - 1
34
+ ) # state path
35
+
36
+ # init start state
37
+ log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]]
38
+ log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]]
39
+
40
+ for t in range(1, ctc_probs.size(0)):
41
+ for s in range(len(y_insert_blank)):
42
+ if y_insert_blank[s] == blank_id or s < 2 or y_insert_blank[
43
+ s] == y_insert_blank[s - 2]:
44
+ candidates = torch.tensor(
45
+ [log_alpha[t - 1, s], log_alpha[t - 1, s - 1]])
46
+ prev_state = [s, s - 1]
47
+ else:
48
+ candidates = torch.tensor([
49
+ log_alpha[t - 1, s],
50
+ log_alpha[t - 1, s - 1],
51
+ log_alpha[t - 1, s - 2],
52
+ ])
53
+ prev_state = [s, s - 1, s - 2]
54
+ log_alpha[t, s] = torch.max(candidates) + ctc_probs[t][y_insert_blank[s]]
55
+ state_path[t, s] = prev_state[torch.argmax(candidates)]
56
+
57
+ state_seq = -1 * torch.ones((ctc_probs.size(0), 1), dtype=torch.int16)
58
+
59
+ candidates = torch.tensor([
60
+ log_alpha[-1, len(y_insert_blank) - 1],
61
+ log_alpha[-1, len(y_insert_blank) - 2]
62
+ ])
63
+ prev_state = [len(y_insert_blank) - 1, len(y_insert_blank) - 2]
64
+ state_seq[-1] = prev_state[torch.argmax(candidates)]
65
+ for t in range(ctc_probs.size(0) - 2, -1, -1):
66
+ state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]]
67
+
68
+ output_alignment = []
69
+ for t in range(0, ctc_probs.size(0)):
70
+ output_alignment.append(y_insert_blank[state_seq[t, 0]])
71
+
72
+ return output_alignment
73
+
74
+ @classmethod
75
+ def ctc_alignment_to_timestamp(cls, ys_with_blank, subsampling, blank_id=0):
76
+ start_times: List[float] = []
77
+ end_times: List[float] = []
78
+ frame_shift = 10 # ms, hard code
79
+ T = len(ys_with_blank)
80
+ t = 0
81
+ ctc_durs = []
82
+ while t < T:
83
+ token = ys_with_blank[t]
84
+ t += 1
85
+ if token != blank_id:
86
+ start_t = t
87
+ timestamp = frame_shift * subsampling * t / 1000.0 # s
88
+ start_times.append(timestamp)
89
+ if len(start_times) == len(end_times) + 2:
90
+ end_times.append(start_times[-1])
91
+ # skip repeat token
92
+ while t < T and token == ys_with_blank[t]:
93
+ t += 1
94
+ assert t-start_t >= 0
95
+ ctc_durs.append((t-start_t+1) * frame_shift * subsampling / 1000.0)
96
+ end_times.append((frame_shift * subsampling * T + 25)/ 1000.0)
97
+ if len(start_times) == 0:
98
+ start_times.append(0.0)
99
+
100
+ # Refine end_times
101
+ assert len(ctc_durs) == len(end_times) and len(start_times) == len(end_times)
102
+ avg_dur = sum(e-s for s, e in zip(start_times, end_times)) / len(end_times)
103
+ new_end_times = []
104
+ for s, e, ctc_dur in zip(start_times, end_times, ctc_durs):
105
+ if e - s > 2 * avg_dur:
106
+ e = s + max(1.5*avg_dur, ctc_dur)
107
+ new_end_times.append(round(e, 3))
108
+ end_times = new_end_times
109
+ return start_times, end_times
110
+
111
+
112
+ def insert_blank(label, blank_id=0):
113
+ """Insert blank token between every two label token."""
114
+ label = np.expand_dims(label, 1)
115
+ blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id
116
+ label = np.concatenate([blanks, label], axis=1)
117
+ label = label.reshape(-1)
118
+ label = np.append(label, label[0])
119
+ return label
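A small sketch of how a frame-level CTC alignment is converted into token timestamps by the classmethod above (10 ms frame shift times a subsampling factor of 4):

align = [0, 7, 7, 0, 0, 8, 0]   # per-frame token ids after forced alignment, 0 = blank
starts, ends = CTC.ctc_alignment_to_timestamp(align, subsampling=4, blank_id=0)
print(starts, ends)             # [0.08, 0.24] [0.24, 0.305]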
fireredasr2s/fireredasr2/models/module/transformer_decoder.py ADDED
@@ -0,0 +1,329 @@
1
+ # Copyright 2026 Xiaohongshu. (Author: Kaituo Xu)
2
+
3
+ from typing import List, Optional, Dict
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch import Tensor
9
+
10
+
11
+ class TransformerDecoder(nn.Module):
12
+ def __init__(
13
+ self, sos_id, eos_id, pad_id, odim,
14
+ n_layers, n_head, d_model,
15
+ residual_dropout=0.1, pe_maxlen=5000):
16
+ super().__init__()
17
+ self.INF = 1e10
18
+ # parameters
19
+ self.pad_id = pad_id
20
+ self.sos_id = sos_id
21
+ self.eos_id = eos_id
22
+ self.n_layers = n_layers
23
+
24
+ # Components
25
+ self.tgt_word_emb = nn.Embedding(odim, d_model, padding_idx=self.pad_id)
26
+ self.positional_encoding = PositionalEncoding(d_model, max_len=pe_maxlen)
27
+ self.dropout = nn.Dropout(residual_dropout)
28
+
29
+ self.layer_stack = nn.ModuleList()
30
+ for l in range(n_layers):
31
+ block = DecoderLayer(d_model, n_head, residual_dropout)
32
+ self.layer_stack.append(block)
33
+
34
+ self.tgt_word_prj = nn.Linear(d_model, odim, bias=False)
35
+ self.layer_norm_out = nn.LayerNorm(d_model)
36
+
37
+ self.tgt_word_prj.weight = self.tgt_word_emb.weight
38
+ self.scale = (d_model ** 0.5)
39
+
40
+ def batch_beam_search(self, encoder_outputs, src_masks,
41
+ beam_size=1, nbest=1, decode_max_len=0,
42
+ softmax_smoothing=1.0, length_penalty=0.0, eos_penalty=1.0,
43
+ elm=None, elm_weight=0.0):
44
+ B = beam_size
45
+ N, Ti, H = encoder_outputs.size()
46
+ device = encoder_outputs.device
47
+ maxlen = decode_max_len if decode_max_len > 0 else Ti
48
+ assert eos_penalty > 0.0
49
+
50
+ # Init
51
+ encoder_outputs = encoder_outputs.unsqueeze(1).repeat(1, B, 1, 1).view(N*B, Ti, H)
52
+ src_mask = src_masks.unsqueeze(1).repeat(1, B, 1, 1).view(N*B, -1, Ti)
53
+ ys = torch.ones(N*B, 1).fill_(self.sos_id).long().to(device)
54
+ t_ys = ys.clone()
55
+ confidences = torch.zeros(N*B, 1).float().to(device)
56
+ caches: List[Optional[Tensor]] = []
57
+ for _ in range(self.n_layers):
58
+ caches.append(None)
59
+ scores = torch.tensor([0.0] + [-self.INF]*(B-1)).float().to(device)
60
+ scores = scores.repeat(N).view(N*B, 1)
61
+ is_finished = torch.zeros_like(scores)
62
+ if elm is not None:
63
+ elm_cache = elm.init_hidden(encoder_outputs, N*B)
64
+
65
+ # Autoregressive Prediction
66
+ for t in range(maxlen):
67
+ tgt_mask = self.ignored_target_position_is_0(ys, self.pad_id)
68
+
69
+ dec_output = self.dropout(
70
+ self.tgt_word_emb(ys) * self.scale +
71
+ self.positional_encoding(ys))
72
+ # if t > 0:
73
+ # dec_output = dec_output[:, -1:, :]
74
+ i = 0
75
+ for dec_layer in self.layer_stack:
76
+ dec_output = dec_layer.forward(
77
+ dec_output, encoder_outputs,
78
+ tgt_mask, src_mask,
79
+ cache=caches[i])
80
+ caches[i] = dec_output
81
+ i += 1
82
+
83
+ dec_output = self.layer_norm_out(dec_output)
84
+
85
+ t_logit = self.tgt_word_prj(dec_output[:, -1])
86
+ t_scores = F.log_softmax(t_logit / softmax_smoothing, dim=-1)
87
+ t_origin_scores = t_scores
88
+
89
+ if elm is not None and elm_weight > 0.0:
90
+ elm_logit, elm_cache = elm.forward_model(t_ys, hidden=elm_cache)
91
+ #elm_logit, _ = elm.forward_model(ys)
92
+ t_lm_scores = torch.log_softmax(elm_logit[:, -1], dim=-1) * (1 - is_finished.float()) # mask, (N*B, V)
93
+ t_lm_scores[:, elm.eos_id] *= 3
94
+ t_scores = t_scores + elm_weight * t_lm_scores
95
+
96
+ if eos_penalty != 1.0:
97
+ t_scores[:, self.eos_id] *= eos_penalty
98
+
99
+ t_topB_scores, t_topB_ys = torch.topk(t_scores, k=B, dim=1)
100
+ t_topB_scores = self.set_finished_beam_score_to_zero(t_topB_scores, is_finished)
101
+ t_topB_ys = self.set_finished_beam_y_to_eos(t_topB_ys, is_finished)
102
+
103
+ # Accumulated
104
+ scores = scores + t_topB_scores
105
+
106
+ # Pruning
107
+ scores = scores.view(N, B*B)
108
+ scores, topB_score_ids = torch.topk(scores, k=B, dim=1)
109
+ scores = scores.view(-1, 1)
110
+
111
+ topB_row_number_in_each_B_rows_of_ys = torch.div(topB_score_ids, B).view(N*B)
112
+ stride = B * torch.arange(N).view(N, 1).repeat(1, B).view(N*B).to(device)
113
+ topB_row_number_in_ys = topB_row_number_in_each_B_rows_of_ys.long() + stride.long()
114
+
115
+ # Update ys
116
+ ys = ys[topB_row_number_in_ys]
117
+ t_ys = torch.gather(t_topB_ys.view(N, B*B), dim=1, index=topB_score_ids).view(N*B, 1)
118
+ ys = torch.cat((ys, t_ys), dim=1)
119
+
120
+ # Update confidences
121
+ confidences = confidences[topB_row_number_in_ys]
122
+ t_confidences = torch.gather(t_topB_scores.view(N, B*B), dim=1, index=topB_score_ids).view(N*B, 1)
123
+ t_confidences = torch.exp(t_confidences)
124
+ assert torch.all(t_confidences <= 1.0)
125
+ assert torch.all(t_confidences >= 0.0)
126
+ confidences = torch.cat((confidences, t_confidences), dim=1)
127
+
128
+ # Update caches
129
+ new_caches: List[Optional[Tensor]] = []
130
+ for cache in caches:
131
+ if cache is not None:
132
+ new_caches.append(cache[topB_row_number_in_ys])
133
+ caches = new_caches
134
+ if elm and elm_weight > 0.0:
135
+ elm_cache = (elm_cache[0][:, topB_row_number_in_ys], elm_cache[1][:, topB_row_number_in_ys])
136
+
137
+ # Update finished state
138
+ is_finished = t_ys.eq(self.eos_id)
139
+ if is_finished.sum().item() == N*B:
140
+ break
141
+
142
+ # Length penalty (follow GNMT)
143
+ scores = scores.view(N, B)
144
+ ys = ys.view(N, B, -1)
145
+ ys_lengths = self.get_ys_lengths(ys)
146
+ if length_penalty > 0.0:
147
+ penalty = torch.pow((5+ys_lengths.float())/(5.0+1), length_penalty)
148
+ scores /= penalty
149
+ nbest_scores, nbest_ids = torch.topk(scores, k=int(nbest), dim=1)
150
+ nbest_scores = -1.0 * nbest_scores
151
+ index = nbest_ids + B * torch.arange(N).view(N, 1).to(device).long()
152
+ nbest_ys = ys.view(N*B, -1)[index.view(-1)]
153
+ nbest_ys = nbest_ys.view(N, nbest_ids.size(1), -1)
154
+ nbest_ys_lengths = ys_lengths.view(N*B)[index.view(-1)].view(N, -1)
155
+ nbest_confidences = confidences.view(N*B, -1)[index.view(-1)].view(N, nbest_ids.size(1), -1)
156
+
157
+ # result
158
+ nbest_hyps: List[List[Dict[str, Tensor]]] = []
159
+ for n in range(N):
160
+ n_nbest_hyps: List[Dict[str, Tensor]] = []
161
+ for i, score in enumerate(nbest_scores[n]):
162
+ confidence = nbest_confidences[n, i, 1:nbest_ys_lengths[n, i]]
163
+ confidence = confidence.mean()
164
+ new_hyp = {
165
+ "yseq": nbest_ys[n, i, 1:nbest_ys_lengths[n, i]],
166
+ "confidence": confidence
167
+ }
168
+ n_nbest_hyps.append(new_hyp)
169
+ nbest_hyps.append(n_nbest_hyps)
170
+ return nbest_hyps
171
+
172
+ def ignored_target_position_is_0(self, padded_targets, ignore_id):
173
+ mask = torch.ne(padded_targets, ignore_id)
174
+ mask = mask.unsqueeze(dim=1)
175
+ T = padded_targets.size(-1)
176
+ upper_tri_0_mask = self.upper_triangular_is_0(T).unsqueeze(0).to(mask.dtype)
177
+ upper_tri_0_mask = upper_tri_0_mask.to(mask.dtype).to(mask.device)
178
+ return mask.to(torch.uint8) & upper_tri_0_mask.to(torch.uint8)
179
+
180
+ def upper_triangular_is_0(self, size):
181
+ ones = torch.ones(size, size)
182
+ tri_left_ones = torch.tril(ones)
183
+ return tri_left_ones.to(torch.uint8)
184
+
185
+ def set_finished_beam_score_to_zero(self, scores, is_finished):
186
+ NB, B = scores.size()
187
+ is_finished = is_finished.float()
188
+ mask_score = torch.tensor([0.0] + [-self.INF]*(B-1)).float().to(scores.device)
189
+ mask_score = mask_score.view(1, B).repeat(NB, 1)
190
+ return scores * (1 - is_finished) + mask_score * is_finished
191
+
192
+ def set_finished_beam_y_to_eos(self, ys, is_finished):
193
+ is_finished = is_finished.long()
194
+ return ys * (1 - is_finished) + self.eos_id * is_finished
195
+
196
+ def get_ys_lengths(self, ys):
197
+ N, B, Tmax = ys.size()
198
+ ys_lengths = torch.sum(torch.ne(ys, self.eos_id), dim=-1)
199
+ return ys_lengths.int()
200
+
201
+
202
+
203
+ class DecoderLayer(nn.Module):
204
+ def __init__(self, d_model, n_head, dropout):
205
+ super().__init__()
206
+ self.self_attn_norm = nn.LayerNorm(d_model)
207
+ self.self_attn = DecoderMultiHeadAttention(d_model, n_head, dropout)
208
+
209
+ self.cross_attn_norm = nn.LayerNorm(d_model)
210
+ self.cross_attn = DecoderMultiHeadAttention(d_model, n_head, dropout)
211
+
212
+ self.mlp_norm = nn.LayerNorm(d_model)
213
+ self.mlp = PositionwiseFeedForward(d_model, d_model*4, dropout)
214
+
215
+ def forward(self, dec_input, enc_output, self_attn_mask, cross_attn_mask,
216
+ cache=None):
217
+ x = dec_input
218
+ residual = x
219
+ x = self.self_attn_norm(x)
220
+ if cache is not None:
221
+ xq = x[:, -1:, :]
222
+ residual = residual[:, -1:, :]
223
+ self_attn_mask = self_attn_mask[:, -1:, :]
224
+ else:
225
+ xq = x
226
+ x = self.self_attn(xq, x, x, mask=self_attn_mask)
227
+ x = residual + x
228
+
229
+ residual = x
230
+ x = self.cross_attn_norm(x)
231
+ x = self.cross_attn(x, enc_output, enc_output, mask=cross_attn_mask)
232
+ x = residual + x
233
+
234
+ residual = x
235
+ x = self.mlp_norm(x)
236
+ x = residual + self.mlp(x)
237
+
238
+ if cache is not None:
239
+ x = torch.cat([cache, x], dim=1)
240
+
241
+ return x
242
+
243
+
244
+ class DecoderMultiHeadAttention(nn.Module):
245
+ def __init__(self, d_model, n_head, dropout=0.1):
246
+ super().__init__()
247
+ self.d_model = d_model
248
+ self.n_head = n_head
249
+ self.d_k = d_model // n_head
250
+
251
+ self.w_qs = nn.Linear(d_model, n_head * self.d_k)
252
+ self.w_ks = nn.Linear(d_model, n_head * self.d_k, bias=False)
253
+ self.w_vs = nn.Linear(d_model, n_head * self.d_k)
254
+
255
+ self.attention = DecoderScaledDotProductAttention(
256
+ temperature=self.d_k ** 0.5)
257
+ self.fc = nn.Linear(n_head * self.d_k, d_model)
258
+ self.dropout = nn.Dropout(dropout)
259
+
260
+ def forward(self, q, k, v, mask=None):
261
+ bs = q.size(0)
262
+
263
+ q = self.w_qs(q).view(bs, -1, self.n_head, self.d_k)
264
+ k = self.w_ks(k).view(bs, -1, self.n_head, self.d_k)
265
+ v = self.w_vs(v).view(bs, -1, self.n_head, self.d_k)
266
+ q = q.transpose(1, 2)
267
+ k = k.transpose(1, 2)
268
+ v = v.transpose(1, 2)
269
+
270
+ if mask is not None:
271
+ mask = mask.unsqueeze(1)
272
+
273
+ output = self.attention(q, k, v, mask=mask)
274
+
275
+ output = output.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
276
+ output = self.fc(output)
277
+ output = self.dropout(output)
278
+
279
+ return output
280
+
281
+
282
+ class DecoderScaledDotProductAttention(nn.Module):
283
+ def __init__(self, temperature):
284
+ super().__init__()
285
+ self.temperature = temperature
286
+ self.INF = float("inf")
287
+
288
+ def forward(self, q, k, v, mask=None):
289
+ attn = torch.matmul(q, k.transpose(2, 3)) / self.temperature
290
+ if mask is not None:
291
+ mask = mask.eq(0)
292
+ attn = attn.masked_fill(mask, -self.INF)
293
+ attn = torch.softmax(attn, dim=-1).masked_fill(mask, 0.0)
294
+ else:
295
+ attn = torch.softmax(attn, dim=-1)
296
+ output = torch.matmul(attn, v)
297
+ return output
298
+
299
+
300
+ class PositionwiseFeedForward(nn.Module):
301
+ def __init__(self, d_model, d_ff, dropout=0.1):
302
+ super().__init__()
303
+ self.w_1 = nn.Linear(d_model, d_ff)
304
+ self.act = nn.GELU()
305
+ self.w_2 = nn.Linear(d_ff, d_model)
306
+ self.dropout = nn.Dropout(dropout)
307
+
308
+ def forward(self, x):
309
+ output = self.w_2(self.act(self.w_1(x)))
310
+ output = self.dropout(output)
311
+ return output
312
+
313
+
314
+ class PositionalEncoding(nn.Module):
315
+ def __init__(self, d_model, max_len=5000):
316
+ super().__init__()
317
+ assert d_model % 2 == 0
318
+ pe = torch.zeros(max_len, d_model, requires_grad=False)
319
+ position = torch.arange(0, max_len).unsqueeze(1).float()
320
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() *
321
+ -(torch.log(torch.tensor(10000.0)).item()/d_model))
322
+ pe[:, 0::2] = torch.sin(position * div_term)
323
+ pe[:, 1::2] = torch.cos(position * div_term)
324
+ pe = pe.unsqueeze(0)
325
+ self.register_buffer('pe', pe)
326
+
327
+ def forward(self, x):
328
+ length = x.size(1)
329
+ return self.pe[:, :length].clone().detach()
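The "Length penalty (follow GNMT)" step in batch_beam_search divides beam scores by ((5 + |y|) / 6) ** length_penalty; a quick numeric check:

import torch

ys_lengths = torch.tensor([10.0, 20.0])
length_penalty = 1.0
penalty = torch.pow((5 + ys_lengths) / (5.0 + 1), length_penalty)
print(penalty)   # tensor([2.5000, 4.1667])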
fireredasr2s/fireredasr2/models/param.py ADDED
@@ -0,0 +1,17 @@
1
+ # Copyright 2026 Xiaohongshu. (Author: Kaituo Xu)
2
+
3
+ import logging
4
+
5
+ import torch
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+
10
+ def count_model_parameters(model):
11
+ if not isinstance(model, torch.nn.Module):
12
+ return 0, 0
13
+ name = f"{model.__class__.__name__} {model.__class__}"
14
+ num = sum(p.numel() for p in model.parameters() if p.requires_grad)
15
+ size = num * 4.0 / 1024.0 / 1024.0 # float32, MB
16
+ logger.info(f"#param of {name} is {num} = {size:.1f} MB (float32)")
17
+ return num, size
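count_model_parameters simply sums the trainable parameters of a module and reports their float32 size; for example:

import torch.nn as nn

num, size_mb = count_model_parameters(nn.Linear(512, 512))
print(num, round(size_mb, 2))   # 262656 1.0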
fireredasr2s/fireredasr2/speech2text.py ADDED
@@ -0,0 +1,107 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright 2026 Xiaohongshu. (Author: Kaituo Xu)
4
+
5
+ import argparse
6
+ import json
7
+ import logging
8
+ import os
9
+
10
+ from fireredasr2.asr import FireRedAsr2, FireRedAsr2Config
11
+ from fireredasr2.utils.io import get_wav_info, write_textgrid
12
+
13
+ logging.basicConfig(level=logging.INFO,
14
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
15
+ logger = logging.getLogger("fireredasr2.bin.speech2text")
16
+
17
+
18
+ parser = argparse.ArgumentParser()
19
+ parser.add_argument('--asr_type', type=str, required=True, choices=["aed", "llm"])
20
+ parser.add_argument('--model_dir', type=str, required=True)
21
+
22
+ # Input / Output
23
+ parser.add_argument("--wav_path", type=str)
24
+ parser.add_argument("--wav_paths", type=str, nargs="*")
25
+ parser.add_argument("--wav_dir", type=str)
26
+ parser.add_argument("--wav_scp", type=str)
27
+ parser.add_argument("--sort_wav_by_dur", type=int, default=0)
28
+ parser.add_argument("--output", type=str)
29
+
30
+ # Decode Options
31
+ parser.add_argument('--use_gpu', type=int, default=1)
32
+ parser.add_argument('--use_half', type=int, default=0)
33
+ parser.add_argument("--batch_size", type=int, default=1)
34
+ parser.add_argument("--beam_size", type=int, default=1)
35
+ parser.add_argument("--decode_max_len", type=int, default=0)
36
+ # FireRedASR-AED
37
+ parser.add_argument("--nbest", type=int, default=1)
38
+ parser.add_argument("--softmax_smoothing", type=float, default=1.0)
39
+ parser.add_argument("--aed_length_penalty", type=float, default=0.0)
40
+ parser.add_argument("--eos_penalty", type=float, default=1.0)
41
+ parser.add_argument("--return_timestamp", type=int, default=0)
42
+ parser.add_argument("--write_textgrid", type=int, default=0)
43
+ # AED External LM
44
+ parser.add_argument("--elm_dir", type=str, default="")
45
+ parser.add_argument("--elm_weight", type=float, default=0.0)
46
+ # FireRedASR-LLM
47
+ parser.add_argument("--decode_min_len", type=int, default=0)
48
+ parser.add_argument("--repetition_penalty", type=float, default=1.0)
49
+ parser.add_argument("--llm_length_penalty", type=float, default=0.0)
50
+ parser.add_argument("--temperature", type=float, default=1.0)
51
+
52
+
53
+ def main(args):
54
+ wavs = get_wav_info(args)
55
+ fout = open(args.output, "w") if args.output else None
56
+ foutl = open(args.output + ".jsonl", "w") if args.output else None
57
+
58
+ asr_config = FireRedAsr2Config(
59
+ args.use_gpu,
60
+ args.use_half,
61
+ args.beam_size,
62
+ args.nbest,
63
+ args.decode_max_len,
64
+ args.softmax_smoothing,
65
+ args.aed_length_penalty,
66
+ args.eos_penalty,
67
+ args.return_timestamp,
68
+ args.decode_min_len,
69
+ args.repetition_penalty,
70
+ args.llm_length_penalty,
71
+ args.temperature,
72
+ args.elm_dir,
73
+ args.elm_weight
74
+ )
75
+ model = FireRedAsr2.from_pretrained(args.asr_type, args.model_dir, asr_config)
76
+
77
+ batch_uttid = []
78
+ batch_wav_path = []
79
+ for i, wav in enumerate(wavs):
80
+ uttid, wav_path = wav
81
+ batch_uttid.append(uttid)
82
+ batch_wav_path.append(wav_path)
83
+ if len(batch_wav_path) < args.batch_size and i != len(wavs) - 1:
84
+ continue
85
+
86
+ results = model.transcribe(batch_uttid, batch_wav_path)
87
+
88
+ for result in results:
89
+ logger.info(result)
90
+ if fout is not None:
91
+ foutl.write(f"{json.dumps(result, ensure_ascii=False)}\n")
92
+ fout.write(f"{result['uttid']}\t{result['text']}\n")
93
+ if args.write_textgrid and "timestamp" in result:
94
+ write_textgrid(result["wav"], result["dur_s"], result["timestamp"])
95
+
96
+ if fout: fout.flush()
97
+ if foutl: foutl.flush()
98
+ batch_uttid = []
99
+ batch_wav_path = []
100
+ if fout: fout.close()
101
+ if foutl: foutl.close()
102
+
103
+
104
+ if __name__ == "__main__":
105
+ args = parser.parse_args()
106
+ logger.info(args)
107
+ main(args)
fireredasr2s/fireredasr2/tokenizer/aed_tokenizer.py ADDED
@@ -0,0 +1,93 @@
1
+ # Copyright 2026 Xiaohongshu. (Author: Kaituo Xu)
2
+
3
+ import logging
4
+ import re
5
+
6
+ import sentencepiece as spm
7
+
8
+ from ..data.token_dict import TokenDict
9
+
10
+
11
+ class ChineseCharEnglishSpmTokenizer:
12
+ """
13
+ - One Chinese char is a token.
14
+ - Split English word into SPM and one piece is a token.
15
+ - Ignore ' ' between Chinese char
16
+ - Replace ' ' between English word with "▁" by spm_model
17
+ - Need to put SPM piece into dict file
18
+ - If not set spm_model, will use English char and <space>
19
+ """
20
+ SPM_SPACE = "▁"
21
+
22
+ def __init__(self, dict_path, spm_model, unk="<unk>", space="<space>"):
23
+ self.dict = TokenDict(dict_path, unk=unk)
24
+ self.space = space
25
+ if spm_model:
26
+ self.sp = spm.SentencePieceProcessor()
27
+ self.sp.Load(spm_model)
28
+ else:
29
+ self.sp = None
30
+ print("[WARN] spm_model not set, falling back to English chars")
31
+ print("[WARN] Please check how ' ' (space) should be handled")
32
+ if self.space not in self.dict:
33
+ print("Please add <space> to your dict, or it will be <unk>")
34
+
35
+ def tokenize(self, text, replace_punc=True):
36
+ #if text == "":
37
+ # logging.info(f"empty text")
38
+ text = text.upper()
39
+ tokens = []
40
+ if replace_punc:
41
+ text = re.sub("[,。?!,\.?!]", " ", text)
42
+ pattern = re.compile(r'([\u3400-\u4dbf\u4e00-\u9fff])')
43
+ parts = pattern.split(text.strip())
44
+ parts = [p for p in parts if len(p.strip()) > 0]
45
+ for part in parts:
46
+ if pattern.fullmatch(part) is not None:
47
+ tokens.append(part)
48
+ else:
49
+ if self.sp:
50
+ for piece in self.sp.EncodeAsPieces(part.strip()):
51
+ tokens.append(piece)
52
+ else:
53
+ for char in part.strip():
54
+ tokens.append(char if char != " " else self.space)
55
+ tokens_id = []
56
+ for token in tokens:
57
+ tokens_id.append(self.dict.get(token, self.dict.unk))
58
+ return tokens, tokens_id
59
+
60
+ def detokenize(self, inputs, join_symbol="", replace_spm_space=True):
61
+ """inputs is ids or tokens, do not need self.sp"""
62
+ if len(inputs) > 0 and type(inputs[0]) == int:
63
+ tokens = [self.dict[id] for id in inputs]
64
+ else:
65
+ tokens = inputs
66
+ s = f"{join_symbol}".join(tokens)
67
+ if replace_spm_space:
68
+ s = s.replace(self.SPM_SPACE, ' ').strip()
69
+ return s
70
+
71
+ def merge_spm_timestamp(self, timestamp):
72
+ merged_timestamp = []
73
+ i = 0
74
+ while i < len(timestamp):
75
+ token, start, end = timestamp[i]
76
+ if token.startswith(self.SPM_SPACE):
77
+ token = token.replace(self.SPM_SPACE, "")
78
+ current_end = end
79
+ next_i = i + 1
80
+ while next_i < len(timestamp):
81
+ next_token, next_start, next_end = timestamp[next_i]
82
+ if re.match("^[a-zA-Z']+$", next_token):
83
+ token += next_token
84
+ current_end = next_end
85
+ next_i += 1
86
+ else:
87
+ break
88
+ end = current_end
89
+ i = next_i
90
+ else:
91
+ i += 1
92
+ merged_timestamp.append((token, start, end))
93
+ return merged_timestamp
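A sketch of the char/word split the tokenizer performs before SPM (the regex is the same one used in tokenize above; no dict or SPM model file is needed for this step):

import re

pattern = re.compile(r'([\u3400-\u4dbf\u4e00-\u9fff])')
parts = [p for p in pattern.split("我爱 machine learning".upper()) if p.strip()]
print(parts)   # ['我', '爱', ' MACHINE LEARNING']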
fireredasr2s/fireredasr2/tokenizer/llm_tokenizer.py ADDED
@@ -0,0 +1,107 @@
1
+ # Copyright 2026 Xiaohongshu. (Author: Kaituo Xu)
2
+
3
+ import re
4
+
5
+ import torch
6
+ from transformers import AutoTokenizer
7
+ from transformers.trainer_pt_utils import LabelSmoother
8
+
9
+ DEFAULT_SPEECH_TOKEN = "<speech>"
10
+ IGNORE_TOKEN_ID = LabelSmoother.ignore_index
11
+
12
+
13
+ class LlmTokenizerWrapper:
14
+ @classmethod
15
+ def build_llm_tokenizer(cls, llm_path, use_flash_attn=False):
16
+ tokenizer = AutoTokenizer.from_pretrained(llm_path)
17
+ if use_flash_attn:
18
+ tokenizer.padding_side = "left"
19
+ else:
20
+ tokenizer.padding_side = "right"
21
+ special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
22
+ tokenizer.add_special_tokens(special_tokens_dict)
23
+ return tokenizer
24
+
25
+ @classmethod
26
+ def clean_text(cls, origin_text):
27
+ """remove punc, remove space between Chinese and keep space between English"""
28
+ # remove punc
29
+ text = re.sub("[,。?!,\.!?《》()\·“”、\\/]", "", origin_text)
30
+ # merge space
31
+ text = re.sub("\s+", " ", text)
32
+
33
+ # remove space between Chinese and keep space between English
34
+ pattern = re.compile(r'([\u3400-\u4dbf\u4e00-\u9fff])') # Chinese
35
+ parts = pattern.split(text.strip())
36
+ parts = [p for p in parts if len(p.strip()) > 0]
37
+ text = "".join(parts)
38
+ text = text.strip()
39
+
40
+ text = text.lower()
41
+ return text
42
+
43
+ @classmethod
44
+ def preprocess_texts(cls, origin_texts, tokenizer, max_len, decode=False):
45
+ messages = []
46
+ clean_texts = []
47
+ for i, origin_text in enumerate(origin_texts):
48
+ text = cls.clean_text(origin_text)
49
+ clean_texts.append(text)
50
+ text = text if not decode else ""
51
+ message = [
52
+ {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
53
+ {"role": "assistant", "content": text},
54
+ ]
55
+ messages.append(message)
56
+
57
+ texts = []
58
+ if not decode:
59
+ TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
60
+ else:
61
+ TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
62
+ for i, msg in enumerate(messages):
63
+ texts.append(
64
+ tokenizer.apply_chat_template(
65
+ msg,
66
+ tokenize=True,
67
+ chat_template=TEMPLATE,
68
+ add_generation_prompt=False,
69
+ padding="longest",
70
+ max_length=max_len,
71
+ truncation=True,
72
+ )
73
+ )
74
+
75
+ # Padding texts
76
+ max_len_texts = max([len(text) for text in texts])
77
+ if tokenizer.padding_side == "right":
78
+ texts = [
79
+ text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
80
+ for text in texts
81
+ ]
82
+ else:
83
+ texts = [
84
+ [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
85
+ for text in texts
86
+ ]
87
+ input_ids = torch.tensor(texts, dtype=torch.int)
88
+
89
+ target_ids = input_ids.clone()
90
+ target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
91
+
92
+ # first get the indices of the tokens
93
+ mask_prompt = True
94
+ if mask_prompt:
95
+ mask_indices = torch.where(
96
+ input_ids == tokenizer.convert_tokens_to_ids("assistant")
97
+ )
98
+ for i in range(mask_indices[0].size(0)):
99
+ row = mask_indices[0][i]
100
+ col = mask_indices[1][i]
101
+ target_ids[row, : col + 2] = IGNORE_TOKEN_ID
102
+
103
+ attention_mask = input_ids.ne(tokenizer.pad_token_id)
104
+
105
+ target_ids = target_ids.type(torch.LongTensor)
106
+ input_ids = input_ids.type(torch.LongTensor)
107
+ return input_ids, attention_mask, target_ids, clean_texts
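clean_text strips punctuation and the spaces between Chinese characters while keeping English word spacing; with the class above in scope, for example:

print(LlmTokenizerWrapper.clean_text("你好, 世界! hello WORLD."))
# -> 你好世界 hello world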
fireredasr2s/fireredasr2/utils/io.py ADDED
@@ -0,0 +1,55 @@
+ # Copyright 2026 Xiaohongshu. (Author: Kaituo Xu)
+
+ import glob
+ import os
+ import logging
+ logger = logging.getLogger(__name__)
+
+ from textgrid import TextGrid, IntervalTier
+
+
+ def get_wav_info(args):
+     """
+     Returns:
+         wavs: list of (uttid, wav_path)
+     """
+     base = lambda p: os.path.basename(p).replace(".wav", "")
+     if args.wav_path:
+         wavs = [(base(args.wav_path), args.wav_path)]
+     elif args.wav_paths and len(args.wav_paths) >= 1:
+         wavs = [(base(p), p) for p in sorted(args.wav_paths)]
+     elif args.wav_scp:
+         wavs = [line.strip().split() for line in open(args.wav_scp)]
+         if args.sort_wav_by_dur:
+             logger.info("Sort wav by duration...")
+             utt2dur = os.path.join(os.path.dirname(args.wav_scp), "utt2dur")
+             if os.path.exists(utt2dur):
+                 utt2dur = [l.strip().split() for l in open(utt2dur)]
+                 utt2dur = {l[0]: float(l[1]) for l in utt2dur if len(l) == 2}
+                 wavs = sorted(wavs, key=lambda x: -utt2dur[x[0]])
+                 logger.info("Sort Done")
+             else:
+                 logger.info(f"Cannot find {utt2dur}, skip sorting")
+     elif args.wav_dir:
+         wavs = glob.glob(f"{args.wav_dir}/**/*.wav", recursive=True)
+         wavs = [(base(p), p) for p in sorted(wavs)]
+     else:
+         raise ValueError("Please provide valid wav info")
+     logger.info(f"#wavs={len(wavs)}")
+     return wavs
+
+
+ def write_textgrid(wav_path, wav_dur, event):
+     textgrid_file = wav_path.replace(".wav", ".TextGrid")
+     logger.info(f"Write {textgrid_file}")
+     textgrid = TextGrid(maxTime=wav_dur)
+     tier = IntervalTier(name="token", maxTime=wav_dur)
+     for token, start_s, end_s in event:
+         if start_s == end_s:
+             logger.info(f"Write TG, skip start=end {start_s}")
+             continue
+         start_s = max(start_s, 0)
+         end_s = min(end_s, wav_dur)
+         tier.add(minTime=start_s, maxTime=end_s, mark=token)
+     textgrid.append(tier)
+     textgrid.write(textgrid_file)
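A hedged usage sketch for the two helpers above; the Namespace fields mirror the CLI flags defined later in this commit, and the paths are hypothetical:

    import argparse
    from fireredasr2s.fireredasr2.utils.io import get_wav_info, write_textgrid

    args = argparse.Namespace(wav_path="", wav_paths=None, wav_scp="",
                              wav_dir="data/wavs", sort_wav_by_dur=0)
    wavs = get_wav_info(args)                    # [(uttid, wav_path), ...]
    # Write a token tier for one 2-second file from (token, start_s, end_s) triples
    write_textgrid(wavs[0][1], 2.0, [("hello", 0.2, 0.8), ("world", 0.9, 1.5)])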
fireredasr2s/fireredasr2/utils/wer.py ADDED
@@ -0,0 +1,326 @@
+ #!/usr/bin/env python3
+
+ # Copyright 2026 Xiaohongshu. (Author: Kaituo Xu)
+
+ import argparse
+ import re
+ from collections import OrderedDict
+
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--ref", type=str, required=True)
+ parser.add_argument("--hyp", type=str, required=True)
+ parser.add_argument("--print_sentence_wer", type=int, default=0)
+ parser.add_argument("--do_tn", type=int, default=0, help="simple tn by cn2an")
+ parser.add_argument("--rm_special", type=int, default=0, help=r"remove <\|.*?\|>")
+
+
+ def main(args):
+     uttid2refs = read_uttid2tokens(args.ref, args.do_tn, args.rm_special)
+     uttid2hyps = read_uttid2tokens(args.hyp, args.do_tn, args.rm_special)
+     uttid2wer_info, wer_stat, en_dig_stat = compute_uttid2wer_info(
+         uttid2refs, uttid2hyps, args.print_sentence_wer)
+     wer_stat.print()
+     en_dig_stat.print()
+
+
+ def read_uttid2tokens(filename, do_tn=False, rm_special=False):
+     print(f">>> Read uttid to tokens: {filename}", flush=True)
+     uttid2tokens = OrderedDict()
+     uttid2text = read_uttid2text(filename, do_tn, rm_special)
+     for uttid, text in uttid2text.items():
+         tokens = text2tokens(text)
+         uttid2tokens[uttid] = tokens
+     return uttid2tokens
+
+
+ def read_uttid2text(filename, do_tn=False, rm_special=False):
+     uttid2text = OrderedDict()
+     with open(filename, "r", encoding="utf8") as fin:
+         for i, line in enumerate(fin):
+             cols = line.split()
+             if len(cols) == 0:
+                 print("[WARN] empty line, continue", i, flush=True)
+                 continue
+             assert cols[0] not in uttid2text, f"repeated uttid: {line}"
+             if len(cols) == 1:
+                 uttid2text[cols[0]] = ""
+                 continue
+             txt = " ".join(cols[1:])
+             if rm_special:
+                 txt = " ".join([t for t in re.split(r"<\|.*?\|>", txt) if t.strip() != ""])
+             if do_tn:
+                 import cn2an
+                 txt = cn2an.transform(txt, "an2cn")
+             uttid2text[cols[0]] = txt
+     return uttid2text
+
+
+ def text2tokens(text):
+     PUNCTUATIONS = ",。?!,\.?!"#$%&'()*+-/:;<=>@[\]^_`{|}~⦅⦆「」、 、〃〈〉《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏﹑﹔·。\":" + "()\[\]{}/;`|=+"
+     if text == "":
+         return []
+     tokens = []
+
+     text = re.sub("<unk>", "", text)
+     text = re.sub(r"[%s]+" % PUNCTUATIONS, " ", text)
+     text = re.sub("<.*>", "", text)
+     text = fix_abbr_simple(text)
+     # Split on CJK ideographs and Japanese kana: each becomes a single token
+     pattern = re.compile(r'([\u4e00-\u9fff\u3040-\u309f\u30a0-\u30ff\u31f0-\u31ff])')
+     parts = pattern.split(text.strip().upper())
+     parts = [p for p in parts if len(p.strip()) > 0]
+     for part in parts:
+         if pattern.fullmatch(part) is not None:
+             tokens.append(part)
+         else:
+             for word in part.strip().split():
+                 tokens.append(word)
+     return tokens
+
+
+ def fix_abbr_simple(text):
+     ori_text = text
+     # Skip runs of 6+ spaced single letters: too long to be an abbreviation
+     if re.search(r"(?<![a-zA-Z'])([a-zA-Z] ){5,}[a-zA-Z](?![a-zA-Z'])", text):
+         return text
+
+     # Merge repeatedly until no "X Y" single-letter pattern remains
+     prev = None
+     while prev != text:
+         prev = text
+         text = re.sub(r"(?<![a-zA-Z'])([a-zA-Z]) (?=[a-zA-Z](?![a-zA-Z']))", r'\1', text)
+
+     if ori_text != text:
+         print(f"WER FIX: '{ori_text}' --> '{text}'")
+     return text
+
+
+ def compute_uttid2wer_info(refs, hyps, print_sentence_wer=False):
+     print(">>> Compute uttid to wer info", flush=True)
+
+     uttid2wer_info = OrderedDict()
+     wer_stat = WerStats()
+     en_dig_stat = EnDigStats()
+
+     for uttid, ref in refs.items():
+         if uttid not in hyps:
+             print(f"[WARN] No hyp for {uttid}", flush=True)
+             continue
+         hyp = hyps[uttid]
+
+         if len(hyp) - len(ref) >= 8:
+             print(f"[BigLengthDiff]: {uttid} {len(ref)} {len(hyp)}#{' '.join(ref)}#{' '.join(hyp)}")
+             #continue
+
+         wer_info = compute_one_wer_info(ref, hyp)
+         uttid2wer_info[uttid] = wer_info
+         ns = count_english_digit(ref, hyp, wer_info)
+         wer_stat.add(wer_info)
+         en_dig_stat.add(*ns)
+         if print_sentence_wer:
+             print(f"{uttid} {wer_info}")
+
+     return uttid2wer_info, wer_stat, en_dig_stat
+
+
+ COST_SUB = 3
+ COST_DEL = 3
+ COST_INS = 3
+
+ ALIGN_CRT = 0
+ ALIGN_SUB = 1
+ ALIGN_DEL = 2
+ ALIGN_INS = 3
+ ALIGN_END = 4
+
+
+ def compute_one_wer_info(ref, hyp):
+     """Impl minimum edit distance and backtrace.
+     Args:
+         ref, hyp: List[str]
+     Returns:
+         WerInfo
+     """
+     ref_len = len(ref)
+     hyp_len = len(hyp)
+
+     class _DpPoint:
+         def __init__(self, cost, align):
+             self.cost = cost
+             self.align = align
+
+     dp = []
+     for i in range(0, ref_len + 1):
+         dp.append([])
+         for j in range(0, hyp_len + 1):
+             dp[-1].append(_DpPoint(i * j, ALIGN_CRT))  # placeholder, overwritten below
+
+     # Initialize
+     for i in range(1, hyp_len + 1):
+         dp[0][i].cost = dp[0][i - 1].cost + COST_INS
+         dp[0][i].align = ALIGN_INS
+     for i in range(1, ref_len + 1):
+         dp[i][0].cost = dp[i - 1][0].cost + COST_DEL
+         dp[i][0].align = ALIGN_DEL
+
+     # DP
+     for i in range(1, ref_len + 1):
+         for j in range(1, hyp_len + 1):
+             min_cost = 0
+             min_align = ALIGN_CRT
+             if hyp[j - 1] == ref[i - 1]:
+                 min_cost = dp[i - 1][j - 1].cost
+                 min_align = ALIGN_CRT
+             else:
+                 min_cost = dp[i - 1][j - 1].cost + COST_SUB
+                 min_align = ALIGN_SUB
+
+             del_cost = dp[i - 1][j].cost + COST_DEL
+             if del_cost < min_cost:
+                 min_cost = del_cost
+                 min_align = ALIGN_DEL
+
+             ins_cost = dp[i][j - 1].cost + COST_INS
+             if ins_cost < min_cost:
+                 min_cost = ins_cost
+                 min_align = ALIGN_INS
+
+             dp[i][j].cost = min_cost
+             dp[i][j].align = min_align
+
+     # Backtrace
+     crt = sub = ins = det = 0
+     i = ref_len
+     j = hyp_len
+     align = []
+     while i > 0 or j > 0:
+         if dp[i][j].align == ALIGN_CRT:
+             align.append((i, j, ALIGN_CRT))
+             i -= 1
+             j -= 1
+             crt += 1
+         elif dp[i][j].align == ALIGN_SUB:
+             align.append((i, j, ALIGN_SUB))
+             i -= 1
+             j -= 1
+             sub += 1
+         elif dp[i][j].align == ALIGN_DEL:
+             align.append((i, j, ALIGN_DEL))
+             i -= 1
+             det += 1
+         elif dp[i][j].align == ALIGN_INS:
+             align.append((i, j, ALIGN_INS))
+             j -= 1
+             ins += 1
+
+     err = sub + det + ins
+     align.reverse()
+     wer_info = WerInfo(ref_len, err, crt, sub, det, ins, align)
+     return wer_info
+
+
+ class WerInfo:
+     def __init__(self, ref, err, crt, sub, dele, ins, ali):
+         self.r = ref
+         self.e = err
+         self.c = crt
+         self.s = sub
+         self.d = dele
+         self.i = ins
+         self.ali = ali
+         r = max(self.r, 1)
+         self.wer = 100.0 * (self.s + self.d + self.i) / r
+
+     def __repr__(self):
+         s = f"wer {self.wer:.2f} ref {self.r:2d} sub {self.s:2d} del {self.d:2d} ins {self.i:2d}"
+         return s
+
+
+ class WerStats:
+     def __init__(self):
+         self.infos = []
+
+     def add(self, wer_info):
+         self.infos.append(wer_info)
+
+     def print(self):
+         r = sum(info.r for info in self.infos)
+         if r <= 0:
+             print(f"REF len is {r}, check")
+             r = 1
+         s = sum(info.s for info in self.infos)
+         d = sum(info.d for info in self.infos)
+         i = sum(info.i for info in self.infos)
+         se = 100.0 * s / r
+         de = 100.0 * d / r
+         ie = 100.0 * i / r
+         wer = 100.0 * (s + d + i) / r
+         sen = max(len(self.infos), 1)
+         errsen = sum(info.e > 0 for info in self.infos)
+         ser = 100.0 * errsen / sen
+         print("-"*80)
+         print(f"ref{r:6d} sub{s:6d} del{d:6d} ins{i:6d}")
+         print(f"WER{wer:6.2f} sub{se:6.2f} del{de:6.2f} ins{ie:6.2f}")
+         print(f"SER{ser:6.2f} = {errsen} / {sen}")
+         print("-"*80)
+
+
+ class EnDigStats:
+     def __init__(self):
+         self.n_en_word = 0
+         self.n_en_correct = 0
+         self.n_dig_word = 0
+         self.n_dig_correct = 0
+
+     def add(self, n_en_word, n_en_correct, n_dig_word, n_dig_correct):
+         self.n_en_word += n_en_word
+         self.n_en_correct += n_en_correct
+         self.n_dig_word += n_dig_word
+         self.n_dig_correct += n_dig_correct
+
+     def print(self):
+         print(f"English #word={self.n_en_word}, #correct={self.n_en_correct}\n"
+               f"Digit #word={self.n_dig_word}, #correct={self.n_dig_correct}")
+         print("-"*80)
+
+
+ def count_english_digit(ref, hyp, wer_info):
+     patt_en = r"[a-zA-Z\.\-\']+"
+     patt_dig = r"[0-9]+"
+     patt_cjk = re.compile(r'([\u4e00-\u9fff])')
+     n_en_word = 0
+     n_en_correct = 0
+     n_dig_word = 0
+     n_dig_correct = 0
+     ali = wer_info.ali
+     for i, token in enumerate(ref):
+         if re.match(patt_en, token):
+             n_en_word += 1
+             for y in ali:
+                 if y[0] == i + 1 and y[2] == ALIGN_CRT:
+                     n_en_correct += 1
+                     break
+         if re.match(patt_dig, token):
+             n_dig_word += 1
+             for y in ali:
+                 if y[0] == i + 1 and y[2] == ALIGN_CRT:
+                     n_dig_correct += 1
+                     break
+         if not re.match(patt_cjk, token) and not re.match(patt_en, token) \
+                 and not re.match(patt_dig, token):
+             print("[WeirdChar]:", token)
+     return n_en_word, n_en_correct, n_dig_word, n_dig_correct
+
+
+ if __name__ == "__main__":
+     args = parser.parse_args()
+     print(args, flush=True)
+     main(args)
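A quick sanity check of the edit-distance core above (illustrative values only):

    ref = "i like speech".split()
    hyp = "i like speed recognition".split()
    info = compute_one_wer_info(ref, hyp)   # 1 substitution + 1 insertion
    print(info)                             # wer 66.67 ref  3 sub  1 del  0 ins  1

The script itself is run on Kaldi-style "uttid text" files, e.g. `python wer.py --ref ref.txt --hyp hyp.txt --print_sentence_wer 1`.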
fireredasr2s/fireredasr2s-cli ADDED
@@ -0,0 +1,273 @@
+ #!/usr/bin/env python3
+
+ # Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Kai Huang, Yan Jia, Junjie Chen, Wenpeng Li)
+
+ import argparse
+ import glob
+ import json
+ import logging
+ import os
+
+ import soundfile as sf
+ from textgrid import IntervalTier, TextGrid
+
+ from fireredasr2s.fireredasr2 import FireRedAsr2Config
+ from fireredasr2s.fireredasr2system import (FireRedAsr2System,
+                                             FireRedAsr2SystemConfig)
+ from fireredasr2s.fireredlid import FireRedLidConfig
+ from fireredasr2s.fireredpunc import FireRedPuncConfig
+ from fireredasr2s.fireredvad import FireRedVadConfig
+
+ logging.basicConfig(level=logging.INFO,
+     format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
+ logger = logging.getLogger("fireredasr2s.asr_system")
+
+
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ input_g = parser.add_argument_group("Input Options")
+ input_g.add_argument("--wav_path", type=str)
+ input_g.add_argument("--wav_paths", type=str, nargs="*")
+ input_g.add_argument("--wav_dir", type=str)
+ input_g.add_argument("--wav_scp", type=str)
+ input_g.add_argument("--sort_wav_by_dur", type=int, default=0)
+
+ output_g = parser.add_argument_group("Output Options")
+ output_g.add_argument("--outdir", type=str, default="output")
+ output_g.add_argument("--write_textgrid", type=int, default=1)
+ output_g.add_argument("--write_srt", type=int, default=1)
+ output_g.add_argument("--save_segment", type=int, default=0)
+
+ module_g = parser.add_argument_group("Module Switches")
+ module_g.add_argument('--enable_vad', type=int, default=1, choices=[0, 1])
+ module_g.add_argument('--enable_lid', type=int, default=1, choices=[0, 1])
+ module_g.add_argument('--enable_punc', type=int, default=1, choices=[0, 1])
+
+ asr_g = parser.add_argument_group("ASR Options")
+ asr_g.add_argument('--asr_type', type=str, default="aed", choices=["aed", "llm"])
+ asr_g.add_argument('--asr_model_dir', type=str, default="pretrained_models/FireRedASR2-AED")
+ asr_g.add_argument('--asr_use_gpu', type=int, default=1)
+ asr_g.add_argument('--asr_use_half', type=int, default=0)
+ asr_g.add_argument("--asr_batch_size", type=int, default=1)
+ # FireRedASR-AED
+ asr_g.add_argument("--beam_size", type=int, default=3)
+ asr_g.add_argument("--decode_max_len", type=int, default=0)
+ asr_g.add_argument("--nbest", type=int, default=1)
+ asr_g.add_argument("--softmax_smoothing", type=float, default=1.25)
+ asr_g.add_argument("--aed_length_penalty", type=float, default=0.6)
+ asr_g.add_argument("--eos_penalty", type=float, default=1.0)
+ asr_g.add_argument("--return_timestamp", type=int, default=1)
+ # FireRedASR-AED External LM
+ asr_g.add_argument("--elm_dir", type=str, default="")
+ asr_g.add_argument("--elm_weight", type=float, default=0.0)
+
+ vad_g = parser.add_argument_group("VAD Options")
+ vad_g.add_argument('--vad_model_dir', type=str, default="pretrained_models/FireRedVAD/VAD")
+ vad_g.add_argument('--vad_use_gpu', type=int, default=1)
+ # Non-streaming VAD
+ vad_g.add_argument("--vad_chunk_max_frame", type=int, default=30000)
+ vad_g.add_argument("--smooth_window_size", type=int, default=5)
+ vad_g.add_argument("--speech_threshold", type=float, default=0.2)
+ vad_g.add_argument("--min_speech_frame", type=int, default=20)
+ vad_g.add_argument("--max_speech_frame", type=int, default=1000)
+ vad_g.add_argument("--min_silence_frame", type=int, default=10)
+ vad_g.add_argument("--merge_silence_frame", type=int, default=50)
+ vad_g.add_argument("--extend_speech_frame", type=int, default=10)
+
+ lid_g = parser.add_argument_group("LID Options")
+ lid_g.add_argument('--lid_model_dir', type=str, default="pretrained_models/FireRedLID")
+ lid_g.add_argument('--lid_use_gpu', type=int, default=1)
+
+ punc_g = parser.add_argument_group("Punc Options")
+ punc_g.add_argument('--punc_model_dir', type=str, default="pretrained_models/FireRedPunc")
+ punc_g.add_argument('--punc_use_gpu', type=int, default=1)
+ punc_g.add_argument("--punc_batch_size", type=int, default=1)
+ punc_g.add_argument('--punc_with_timestamp', type=int, default=1)
+ punc_g.add_argument('--punc_sentence_max_length', type=int, default=-1)
+
+
+ def main(args):
+     wavs = get_wav_info(args)
+     if args.outdir:
+         os.makedirs(args.outdir, exist_ok=True)
+     fout = open(args.outdir + "/result.jsonl", "w") if args.outdir else None
+
+     # Build Models
+     # VAD
+     vad_config = FireRedVadConfig(
+         args.vad_use_gpu,
+         args.smooth_window_size,
+         args.speech_threshold,
+         args.min_speech_frame,
+         args.max_speech_frame,
+         args.min_silence_frame,
+         args.merge_silence_frame,
+         args.extend_speech_frame,
+         args.vad_chunk_max_frame
+     )
+     # LID
+     lid_config = FireRedLidConfig(args.lid_use_gpu)
+     # ASR
+     asr_config = FireRedAsr2Config(
+         args.asr_use_gpu,
+         args.asr_use_half,
+         args.beam_size,
+         args.nbest,
+         args.decode_max_len,
+         args.softmax_smoothing,
+         args.aed_length_penalty,
+         args.eos_penalty,
+         args.return_timestamp,
+         0, 1.0, 0.0, 1.0,
+         args.elm_dir,
+         args.elm_weight
+     )
+     # Punc
+     punc_config = FireRedPuncConfig(
+         args.punc_use_gpu,
+         args.punc_sentence_max_length
+     )
+
+     asr_system_config = FireRedAsr2SystemConfig(
+         args.vad_model_dir, args.lid_model_dir,
+         args.asr_type, args.asr_model_dir, args.punc_model_dir,
+         vad_config, lid_config, asr_config, punc_config,
+         args.asr_batch_size, args.punc_batch_size,
+         args.enable_vad, args.enable_lid, args.enable_punc
+     )
+     asr_system = FireRedAsr2System(asr_system_config)
+
+     for i, (uttid, wav_path) in enumerate(wavs):
+         logger.info("")
+
+         result = asr_system.process(wav_path, uttid)
+
+         logger.info(f"FINAL: {result}")
+
+         if fout:
+             fout.write(f"{json.dumps(result, ensure_ascii=False)}\n")
+             fout.flush()
+         name = os.path.basename(wav_path).replace(".wav", "")
+         if args.write_textgrid:
+             tg_dir = os.path.join(args.outdir, "asr_tg")
+             write_textgrid(tg_dir, name, result["dur_s"], result["sentences"], result["words"])
+         if args.write_srt:
+             srt_dir = os.path.join(args.outdir, "asr_srt")
+             write_srt(srt_dir, name, result["sentences"])
+         if args.save_segment:
+             save_segment_dir = os.path.join(args.outdir, "vad_segment")
+             split_and_save_segment(wav_path, result["vad_segments_ms"], save_segment_dir)
+
+     if fout:
+         fout.close()
+     logger.info("All Done")
+
+
+ def get_wav_info(args):
+     """
+     Returns:
+         wavs: list of (uttid, wav_path)
+     """
+     def base(p): return os.path.basename(p).replace(".wav", "")
+     if args.wav_path:
+         wavs = [(base(args.wav_path), args.wav_path)]
+     elif args.wav_paths and len(args.wav_paths) >= 1:
+         wavs = [(base(p), p) for p in sorted(args.wav_paths)]
+     elif args.wav_scp:
+         wavs = [line.strip().split() for line in open(args.wav_scp)]
+     elif args.wav_dir:
+         wavs = glob.glob(f"{args.wav_dir}/**/*.wav", recursive=True)
+         wavs = [(base(p), p) for p in sorted(wavs)]
+     else:
+         raise ValueError("Please provide valid wav info")
+     logger.info(f"#wavs={len(wavs)}")
+     return wavs
+
+
+ def write_textgrid(tg_dir, name, wav_dur, sentences, words=None):
+     os.makedirs(tg_dir, exist_ok=True)
+     textgrid_file = os.path.join(tg_dir, name + ".TextGrid")
+     logger.info(f"Write {textgrid_file}")
+     textgrid = TextGrid(maxTime=wav_dur)
+
+     tier = IntervalTier(name="sentence", maxTime=wav_dur)
+     for sentence in sentences:
+         start_s = sentence["start_ms"] / 1000.0
+         end_s = sentence["end_ms"] / 1000.0
+         text = sentence["text"]
+         confi = sentence["asr_confidence"]
+         if start_s == end_s:
+             logger.info(f"(sent) Write TG, skip start=end {start_s} {text}")
+             continue
+         start_s = max(start_s, 0)
+         end_s = min(end_s, wav_dur)
+         tier.add(minTime=start_s, maxTime=end_s, mark=f"{text}\n{confi}")
+     textgrid.append(tier)
+
+     if words:
+         tier = IntervalTier(name="token", maxTime=wav_dur)
+         for word in words:
+             start_s = word["start_ms"] / 1000.0
+             end_s = word["end_ms"] / 1000.0
+             text = word["text"]
+             if start_s == end_s:
+                 logger.info(f"(word) Write TG, skip start=end {start_s} {text}")
+                 continue
+             start_s = max(start_s, 0)
+             end_s = min(end_s, wav_dur)
+             tier.add(minTime=start_s, maxTime=end_s, mark=text)
+         textgrid.append(tier)
+     textgrid.write(textgrid_file)
+
+
+ def write_srt(srt_dir, name, sentences):
+     def _ms2srt_time(ms):
+         h = ms // 1000 // 3600
+         m = (ms // 1000 % 3600) // 60
+         s = (ms // 1000 % 3600) % 60
+         ms = (ms % 1000)
+         r = f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
+         return r
+     os.makedirs(srt_dir, exist_ok=True)
+     srt_file = os.path.join(srt_dir, name + ".srt")
+     logger.info(f"Write {srt_file}")
+
+     i = 0
+     with open(srt_file, "w") as fout:
+         for sentence in sentences:
+             start_ms = sentence["start_ms"]
+             end_ms = sentence["end_ms"]
+             text = sentence["text"]
+             if text.strip() == "":
+                 continue
+
+             i += 1
+             fout.write(f"{i}\n")
+             s = _ms2srt_time(start_ms)
+             e = _ms2srt_time(end_ms)
+             fout.write(f"{s} --> {e}\n")
+             fout.write(f"{text}\n")
+             if i != len(sentences):
+                 fout.write("\n")
+
+
+ def split_and_save_segment(wav_path, timestamps_ms, save_segment_dir):
+     logger.info("Split & save segment")
+     os.makedirs(save_segment_dir, exist_ok=True)
+     wav_np, sample_rate = sf.read(wav_path, dtype="int16")
+     for i, (start_ms, end_ms) in enumerate(timestamps_ms):
+         uttid = wav_path.split("/")[-1].replace(".wav", "")
+         seg_id = f"{uttid}_{i}_{start_ms}_{end_ms}"
+         seg_path = f"{save_segment_dir}/{seg_id}.wav"
+         start = int(start_ms / 1000 * sample_rate)
+         end = int(end_ms / 1000 * sample_rate)
+         sf.write(seg_path, wav_np[start:end], samplerate=sample_rate)
+
+
+ def cli_main():
+     args = parser.parse_args()
+     logger.info(args)
+     main(args)
+
+
+ if __name__ == "__main__":
+     cli_main()
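A hedged invocation sketch for this CLI; the model directories are the parser defaults, and the wav path is hypothetical:

    python fireredasr2s/fireredasr2s_cli.py \
        --wav_path examples/test.wav \
        --asr_type aed --asr_model_dir pretrained_models/FireRedASR2-AED \
        --outdir output

This writes `output/result.jsonl` plus, by default, a TextGrid under `output/asr_tg/` and an SRT under `output/asr_srt/`.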
fireredasr2s/fireredasr2s_cli.py ADDED
@@ -0,0 +1,273 @@
+ #!/usr/bin/env python3
+
+ # Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Kai Huang, Yan Jia, Junjie Chen, Wenpeng Li)
+
+ import argparse
+ import glob
+ import json
+ import logging
+ import os
+
+ import soundfile as sf
+ from textgrid import IntervalTier, TextGrid
+
+ from fireredasr2s.fireredasr2 import FireRedAsr2Config
+ from fireredasr2s.fireredasr2system import (FireRedAsr2System,
+                                             FireRedAsr2SystemConfig)
+ from fireredasr2s.fireredlid import FireRedLidConfig
+ from fireredasr2s.fireredpunc import FireRedPuncConfig
+ from fireredasr2s.fireredvad import FireRedVadConfig
+
+ logging.basicConfig(level=logging.INFO,
+     format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
+ logger = logging.getLogger("fireredasr2s.asr_system")
+
+
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ input_g = parser.add_argument_group("Input Options")
+ input_g.add_argument("--wav_path", type=str)
+ input_g.add_argument("--wav_paths", type=str, nargs="*")
+ input_g.add_argument("--wav_dir", type=str)
+ input_g.add_argument("--wav_scp", type=str)
+ input_g.add_argument("--sort_wav_by_dur", type=int, default=0)
+
+ output_g = parser.add_argument_group("Output Options")
+ output_g.add_argument("--outdir", type=str, default="output")
+ output_g.add_argument("--write_textgrid", type=int, default=1)
+ output_g.add_argument("--write_srt", type=int, default=1)
+ output_g.add_argument("--save_segment", type=int, default=0)
+
+ module_g = parser.add_argument_group("Module Switches")
+ module_g.add_argument('--enable_vad', type=int, default=1, choices=[0, 1])
+ module_g.add_argument('--enable_lid', type=int, default=1, choices=[0, 1])
+ module_g.add_argument('--enable_punc', type=int, default=1, choices=[0, 1])
+
+ asr_g = parser.add_argument_group("ASR Options")
+ asr_g.add_argument('--asr_type', type=str, default="aed", choices=["aed", "llm"])
+ asr_g.add_argument('--asr_model_dir', type=str, default="pretrained_models/FireRedASR2-AED")
+ asr_g.add_argument('--asr_use_gpu', type=int, default=1)
+ asr_g.add_argument('--asr_use_half', type=int, default=0)
+ asr_g.add_argument("--asr_batch_size", type=int, default=1)
+ # FireRedASR-AED
+ asr_g.add_argument("--beam_size", type=int, default=3)
+ asr_g.add_argument("--decode_max_len", type=int, default=0)
+ asr_g.add_argument("--nbest", type=int, default=1)
+ asr_g.add_argument("--softmax_smoothing", type=float, default=1.25)
+ asr_g.add_argument("--aed_length_penalty", type=float, default=0.6)
+ asr_g.add_argument("--eos_penalty", type=float, default=1.0)
+ asr_g.add_argument("--return_timestamp", type=int, default=1)
+ # FireRedASR-AED External LM
+ asr_g.add_argument("--elm_dir", type=str, default="")
+ asr_g.add_argument("--elm_weight", type=float, default=0.0)
+
+ vad_g = parser.add_argument_group("VAD Options")
+ vad_g.add_argument('--vad_model_dir', type=str, default="pretrained_models/FireRedVAD/VAD")
+ vad_g.add_argument('--vad_use_gpu', type=int, default=1)
+ # Non-streaming VAD
+ vad_g.add_argument("--vad_chunk_max_frame", type=int, default=30000)
+ vad_g.add_argument("--smooth_window_size", type=int, default=5)
+ vad_g.add_argument("--speech_threshold", type=float, default=0.2)
+ vad_g.add_argument("--min_speech_frame", type=int, default=20)
+ vad_g.add_argument("--max_speech_frame", type=int, default=1000)
+ vad_g.add_argument("--min_silence_frame", type=int, default=10)
+ vad_g.add_argument("--merge_silence_frame", type=int, default=50)
+ vad_g.add_argument("--extend_speech_frame", type=int, default=10)
+
+ lid_g = parser.add_argument_group("LID Options")
+ lid_g.add_argument('--lid_model_dir', type=str, default="pretrained_models/FireRedLID")
+ lid_g.add_argument('--lid_use_gpu', type=int, default=1)
+
+ punc_g = parser.add_argument_group("Punc Options")
+ punc_g.add_argument('--punc_model_dir', type=str, default="pretrained_models/FireRedPunc")
+ punc_g.add_argument('--punc_use_gpu', type=int, default=1)
+ punc_g.add_argument("--punc_batch_size", type=int, default=1)
+ punc_g.add_argument('--punc_with_timestamp', type=int, default=1)
+ punc_g.add_argument('--punc_sentence_max_length', type=int, default=-1)
+
+
+ def main(args):
+     wavs = get_wav_info(args)
+     if args.outdir:
+         os.makedirs(args.outdir, exist_ok=True)
+     fout = open(args.outdir + "/result.jsonl", "w") if args.outdir else None
+
+     # Build Models
+     # VAD
+     vad_config = FireRedVadConfig(
+         args.vad_use_gpu,
+         args.smooth_window_size,
+         args.speech_threshold,
+         args.min_speech_frame,
+         args.max_speech_frame,
+         args.min_silence_frame,
+         args.merge_silence_frame,
+         args.extend_speech_frame,
+         args.vad_chunk_max_frame
+     )
+     # LID
+     lid_config = FireRedLidConfig(args.lid_use_gpu)
+     # ASR
+     asr_config = FireRedAsr2Config(
+         args.asr_use_gpu,
+         args.asr_use_half,
+         args.beam_size,
+         args.nbest,
+         args.decode_max_len,
+         args.softmax_smoothing,
+         args.aed_length_penalty,
+         args.eos_penalty,
+         args.return_timestamp,
+         0, 1.0, 0.0, 1.0,
+         args.elm_dir,
+         args.elm_weight
+     )
+     # Punc
+     punc_config = FireRedPuncConfig(
+         args.punc_use_gpu,
+         args.punc_sentence_max_length
+     )
+
+     asr_system_config = FireRedAsr2SystemConfig(
+         args.vad_model_dir, args.lid_model_dir,
+         args.asr_type, args.asr_model_dir, args.punc_model_dir,
+         vad_config, lid_config, asr_config, punc_config,
+         args.asr_batch_size, args.punc_batch_size,
+         args.enable_vad, args.enable_lid, args.enable_punc
+     )
+     asr_system = FireRedAsr2System(asr_system_config)
+
+     for i, (uttid, wav_path) in enumerate(wavs):
+         logger.info("")
+
+         result = asr_system.process(wav_path, uttid)
+
+         logger.info(f"FINAL: {result}")
+
+         if fout:
+             fout.write(f"{json.dumps(result, ensure_ascii=False)}\n")
+             fout.flush()
+         name = os.path.basename(wav_path).replace(".wav", "")
+         if args.write_textgrid:
+             tg_dir = os.path.join(args.outdir, "asr_tg")
+             write_textgrid(tg_dir, name, result["dur_s"], result["sentences"], result["words"])
+         if args.write_srt:
+             srt_dir = os.path.join(args.outdir, "asr_srt")
+             write_srt(srt_dir, name, result["sentences"])
+         if args.save_segment:
+             save_segment_dir = os.path.join(args.outdir, "vad_segment")
+             split_and_save_segment(wav_path, result["vad_segments_ms"], save_segment_dir)
+
+     if fout:
+         fout.close()
+     logger.info("All Done")
+
+
+ def get_wav_info(args):
+     """
+     Returns:
+         wavs: list of (uttid, wav_path)
+     """
+     def base(p): return os.path.basename(p).replace(".wav", "")
+     if args.wav_path:
+         wavs = [(base(args.wav_path), args.wav_path)]
+     elif args.wav_paths and len(args.wav_paths) >= 1:
+         wavs = [(base(p), p) for p in sorted(args.wav_paths)]
+     elif args.wav_scp:
+         wavs = [line.strip().split() for line in open(args.wav_scp)]
+     elif args.wav_dir:
+         wavs = glob.glob(f"{args.wav_dir}/**/*.wav", recursive=True)
+         wavs = [(base(p), p) for p in sorted(wavs)]
+     else:
+         raise ValueError("Please provide valid wav info")
+     logger.info(f"#wavs={len(wavs)}")
+     return wavs
+
+
+ def write_textgrid(tg_dir, name, wav_dur, sentences, words=None):
+     os.makedirs(tg_dir, exist_ok=True)
+     textgrid_file = os.path.join(tg_dir, name + ".TextGrid")
+     logger.info(f"Write {textgrid_file}")
+     textgrid = TextGrid(maxTime=wav_dur)
+
+     tier = IntervalTier(name="sentence", maxTime=wav_dur)
+     for sentence in sentences:
+         start_s = sentence["start_ms"] / 1000.0
+         end_s = sentence["end_ms"] / 1000.0
+         text = sentence["text"]
+         confi = sentence["asr_confidence"]
+         if start_s == end_s:
+             logger.info(f"(sent) Write TG, skip start=end {start_s} {text}")
+             continue
+         start_s = max(start_s, 0)
+         end_s = min(end_s, wav_dur)
+         tier.add(minTime=start_s, maxTime=end_s, mark=f"{text}\n{confi}")
+     textgrid.append(tier)
+
+     if words:
+         tier = IntervalTier(name="token", maxTime=wav_dur)
+         for word in words:
+             start_s = word["start_ms"] / 1000.0
+             end_s = word["end_ms"] / 1000.0
+             text = word["text"]
+             if start_s == end_s:
+                 logger.info(f"(word) Write TG, skip start=end {start_s} {text}")
+                 continue
+             start_s = max(start_s, 0)
+             end_s = min(end_s, wav_dur)
+             tier.add(minTime=start_s, maxTime=end_s, mark=text)
+         textgrid.append(tier)
+     textgrid.write(textgrid_file)
+
+
+ def write_srt(srt_dir, name, sentences):
+     def _ms2srt_time(ms):
+         h = ms // 1000 // 3600
+         m = (ms // 1000 % 3600) // 60
+         s = (ms // 1000 % 3600) % 60
+         ms = (ms % 1000)
+         r = f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
+         return r
+     os.makedirs(srt_dir, exist_ok=True)
+     srt_file = os.path.join(srt_dir, name + ".srt")
+     logger.info(f"Write {srt_file}")
+
+     i = 0
+     with open(srt_file, "w") as fout:
+         for sentence in sentences:
+             start_ms = sentence["start_ms"]
+             end_ms = sentence["end_ms"]
+             text = sentence["text"]
+             if text.strip() == "":
+                 continue
+
+             i += 1
+             fout.write(f"{i}\n")
+             s = _ms2srt_time(start_ms)
+             e = _ms2srt_time(end_ms)
+             fout.write(f"{s} --> {e}\n")
+             fout.write(f"{text}\n")
+             if i != len(sentences):
+                 fout.write("\n")
+
+
+ def split_and_save_segment(wav_path, timestamps_ms, save_segment_dir):
+     logger.info("Split & save segment")
+     os.makedirs(save_segment_dir, exist_ok=True)
+     wav_np, sample_rate = sf.read(wav_path, dtype="int16")
+     for i, (start_ms, end_ms) in enumerate(timestamps_ms):
+         uttid = wav_path.split("/")[-1].replace(".wav", "")
+         seg_id = f"{uttid}_{i}_{start_ms}_{end_ms}"
+         seg_path = f"{save_segment_dir}/{seg_id}.wav"
+         start = int(start_ms / 1000 * sample_rate)
+         end = int(end_ms / 1000 * sample_rate)
+         sf.write(seg_path, wav_np[start:end], samplerate=sample_rate)
+
+
+ def cli_main():
+     args = parser.parse_args()
+     logger.info(args)
+     main(args)
+
+
+ if __name__ == "__main__":
+     cli_main()
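As a worked check of the `_ms2srt_time` arithmetic in `write_srt` above (values are illustrative):

    ms = 3725042                        # 1 h, 2 min, 5 s, 42 ms
    h = ms // 1000 // 3600              # 1
    m = (ms // 1000 % 3600) // 60       # 2
    s = (ms // 1000 % 3600) % 60        # 5
    print(f"{h:02d}:{m:02d}:{s:02d},{ms % 1000:03d}")   # 01:02:05,042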
fireredasr2s/fireredasr2system.py ADDED
@@ -0,0 +1,200 @@
+ # Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Kai Huang, Yan Jia, Junjie Chen, Wenpeng Li)
+
+ import logging
+ import re
+ from dataclasses import dataclass, field
+
+ import soundfile as sf
+
+ from fireredasr2s.fireredasr2 import FireRedAsr2, FireRedAsr2Config
+ from fireredasr2s.fireredlid import FireRedLid, FireRedLidConfig
+ from fireredasr2s.fireredpunc import FireRedPunc, FireRedPuncConfig
+ from fireredasr2s.fireredvad import FireRedVad, FireRedVadConfig
+
+ logging.basicConfig(level=logging.INFO,
+     format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
+ logger = logging.getLogger("fireredasr2s.asr_system")
+
+
+ @dataclass
+ class FireRedAsr2SystemConfig:
+     vad_model_dir: str = "pretrained_models/FireRedVAD/VAD"
+     lid_model_dir: str = "pretrained_models/FireRedLID"
+     asr_type: str = "aed"
+     asr_model_dir: str = "pretrained_models/FireRedASR2-AED"
+     punc_model_dir: str = "pretrained_models/FireRedPunc"
+     vad_config: FireRedVadConfig = field(default_factory=FireRedVadConfig)
+     lid_config: FireRedLidConfig = field(default_factory=FireRedLidConfig)
+     asr_config: FireRedAsr2Config = field(default_factory=FireRedAsr2Config)
+     punc_config: FireRedPuncConfig = field(default_factory=FireRedPuncConfig)
+     asr_batch_size: int = 1
+     punc_batch_size: int = 1
+     enable_vad: bool = True
+     enable_lid: bool = True
+     enable_punc: bool = True
+
+
+ class FireRedAsr2System:
+     def __init__(self, config):
+         c = config
+         self.vad = FireRedVad.from_pretrained(c.vad_model_dir, c.vad_config) if c.enable_vad else None
+         self.lid = FireRedLid.from_pretrained(c.lid_model_dir, c.lid_config) if c.enable_lid else None
+         self.asr = FireRedAsr2.from_pretrained(c.asr_type, c.asr_model_dir, c.asr_config)
+         self.punc = FireRedPunc.from_pretrained(c.punc_model_dir, c.punc_config) if c.enable_punc else None
+         self.config = config
+
+     def process(self, wav_path, uttid="tmpid"):
+         wav_np, sample_rate = sf.read(wav_path, dtype="int16")
+         dur = wav_np.shape[0] / sample_rate
+
+         # 1. VAD
+         if self.config.enable_vad:
+             vad_result, prob = self.vad.detect(wav_path)
+             vad_segments = vad_result["timestamps"]
+             logger.info(f"VAD: {vad_result}")
+         else:
+             vad_segments = [(0, dur)]
+             vad_result = {"timestamps": vad_segments}
+
+         # 2. VAD output to ASR input
+         asr_results = []
+         lid_results = []
+         assert sample_rate == 16000
+         batch_asr_uttid = []
+         batch_asr_wav = []
+         for j, (start_s, end_s) in enumerate(vad_segments):
+             wav_segment = wav_np[int(start_s*sample_rate):int(end_s*sample_rate)]
+             vad_uttid = f"{uttid}_s{int(start_s*1000)}_e{int(end_s*1000)}"
+             batch_asr_uttid.append(vad_uttid)
+             batch_asr_wav.append((sample_rate, wav_segment))
+             if len(batch_asr_uttid) < self.config.asr_batch_size and j != len(vad_segments) - 1:
+                 continue
+
+             # 3. ASR
+             batch_asr_results = self.asr.transcribe(batch_asr_uttid, batch_asr_wav)
+             logger.info(f"ASR: {batch_asr_results}")
+
+             if self.config.enable_lid:
+                 batch_lid_results = self.lid.process(batch_asr_uttid, batch_asr_wav)
+                 logger.info(f"LID: {batch_lid_results}")
+             else:
+                 # Note: The original batch size is used here to ensure alignment with the initial number of ASR results
+                 batch_lid_results = [None] * len(batch_asr_results)
+
+             # Synchronously traverse and filter to ensure that asr_results and lid_results always maintain a one-to-one correspondence
+             for a_res, l_res in zip(batch_asr_results, batch_lid_results):
+                 text = a_res.get("text", "").strip()
+                 # Filter out <blank>, <sil> and completely empty strings ""
+                 if not text or re.search(r"(<blank>)|(<sil>)", text):
+                     continue
+                 asr_results.append(a_res)
+                 lid_results.append(l_res)
+
+             batch_asr_uttid = []
+             batch_asr_wav = []
+
+         # 4. ASR output to Postprocess input
+         if self.config.enable_punc:
+             punc_results = []
+             batch_asr_text = []
+             batch_asr_uttid = []
+             batch_asr_timestamp = []
+             for j, asr_result in enumerate(asr_results):
+                 batch_asr_text.append(asr_result["text"])
+                 batch_asr_uttid.append(asr_result["uttid"])
+                 if self.config.asr_config.return_timestamp:
+                     batch_asr_timestamp.append(asr_result.get("timestamp", []))
+                 elif "timestamp" in asr_result:
+                     batch_asr_timestamp.append(asr_result["timestamp"])
+                 if len(batch_asr_text) < self.config.punc_batch_size and j != len(asr_results) - 1:
+                     continue
+
+                 # 5. Punc
+                 if self.config.asr_config.return_timestamp:
+                     batch_punc_results = self.punc.process_with_timestamp(batch_asr_timestamp, batch_asr_uttid)
+                 else:
+                     batch_punc_results = self.punc.process(batch_asr_text, batch_asr_uttid)
+                 logger.info(f"Punc: {batch_punc_results}")
+
+                 punc_results.extend(batch_punc_results)
+                 batch_asr_text = []
+                 batch_asr_uttid = []
+                 batch_asr_timestamp = []
+         else:
+             punc_results = asr_results
+
+         # 6. Put all together & Format
+         sentences = []
+         words = []
+         for asr_result, punc_result, lid_result in zip(asr_results, punc_results, lid_results):
+             assert asr_result["uttid"] == punc_result["uttid"], f"fix code: {asr_result} | {punc_result}"
+             start_ms, end_ms = asr_result["uttid"].split("_")[-2:]
+             assert start_ms.startswith("s") and end_ms.startswith("e")
+             start_ms, end_ms = int(start_ms[1:]), int(end_ms[1:])
+             if self.config.asr_config.return_timestamp:
+                 sub_sentences = []
+                 if self.config.enable_punc:
+                     for i, punc_sent in enumerate(punc_result["punc_sentences"]):
+                         start = start_ms + int(punc_sent["start_s"]*1000)
+                         end = start_ms + int(punc_sent["end_s"]*1000)
+                         if i == 0:
+                             start = start_ms
+                         if i == len(punc_result["punc_sentences"]) - 1:
+                             end = end_ms
+                         sub_sentence = {
+                             "start_ms": start,
+                             "end_ms": end,
+                             "text": punc_sent["punc_text"],
+                             "asr_confidence": asr_result["confidence"],
+                             "lang": None,
+                             "lang_confidence": 0
+                         }
+                         if lid_result:
+                             sub_sentence["lang"] = lid_result["lang"]
+                             sub_sentence["lang_confidence"] = lid_result["confidence"]
+                         sub_sentences.append(sub_sentence)
+                 else:
+                     sub_sentences = [{
+                         "start_ms": start_ms,
+                         "end_ms": end_ms,
+                         "text": asr_result["text"],
+                         "asr_confidence": asr_result["confidence"],
+                         "lang": None,
+                         "lang_confidence": 0
+                     }]
+                 sentences.extend(sub_sentences)
+             else:
+                 text = punc_result["punc_text"] if self.config.enable_punc else asr_result["text"]
+                 sentence = {
+                     "start_ms": start_ms,
+                     "end_ms": end_ms,
+                     "text": text,
+                     "asr_confidence": asr_result["confidence"],
+                     "lang": None,
+                     "lang_confidence": 0
+                 }
+                 if lid_result:
+                     sentence["lang"] = lid_result["lang"]
+                     sentence["lang_confidence"] = lid_result["confidence"]
+                 sentences.append(sentence)
+
+             if "timestamp" in asr_result:
+                 for w, s, e in asr_result["timestamp"]:
+                     word = {"start_ms": int(s*1000 + start_ms), "end_ms": int(e*1000 + start_ms), "text": w}
+                     words.append(word)
+
+         vad_segments_ms = [(int(s*1000), int(e*1000)) for s, e in vad_result["timestamps"]]
+         text = "".join(s["text"] for s in sentences)
+         # Add space after English punctuation when followed by a letter
+         text = re.sub(r'([.,!?])\s*([a-zA-Z])', r'\1 \2', text)
+
+         result = {
+             "uttid": uttid,
+             "text": text,
+             "sentences": sentences,
+             "vad_segments_ms": vad_segments_ms,
+             "dur_s": dur,
+             "words": words,
+             "wav_path": wav_path
+         }
+         return result
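A hedged end-to-end sketch of driving the system class above from Python (all-default config, which expects the pretrained model directories to exist; the wav path is hypothetical):

    from fireredasr2s.fireredasr2system import (FireRedAsr2System,
                                                FireRedAsr2SystemConfig)

    config = FireRedAsr2SystemConfig()   # VAD + LID + AED ASR + Punc, batch size 1
    system = FireRedAsr2System(config)
    result = system.process("examples/test.wav", uttid="demo")
    print(result["text"])                # punctuated full transcript
    for sent in result["sentences"]:
        print(sent["start_ms"], sent["end_ms"], sent["lang"], sent["text"])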
fireredasr2s/fireredlid/README.md ADDED
@@ -0,0 +1,125 @@
+ ## Language Code
+
+ | Language Code | English Name | Chinese Name |
+ |---|---|---|
+ | zh | Chinese | 中文 |
+ | en | English | 英语 |
+ | es | Spanish | 西班牙语 |
+ | fr | French | 法语 |
+ | ja | Japanese | 日语 |
+ | ko | Korean | 韩语 |
+ | ru | Russian | 俄语 |
+ | de | German | 德语 |
+ | pt | Portuguese | 葡萄牙语 |
+ | ab | Abkhazian | 阿布哈兹语 |
+ | af | Afrikaans | 南非荷兰语 |
+ | am | Amharic | 阿姆哈拉语 |
+ | ar | Arabic | 阿拉伯语 |
+ | as | Assamese | 阿萨姆语 |
+ | az | Azerbaijani | 阿塞拜疆语 |
+ | ba | Bashkir | 巴什基尔语 |
+ | be | Belarusian | 白俄罗斯语 |
+ | bg | Bulgarian | 保加利亚语 |
+ | bn | Bengali | 孟加拉语 |
+ | br | Breton | 布列塔尼语 |
+ | bs | Bosnian | 波斯尼亚语 |
+ | ca | Catalan | 加泰罗尼亚语 |
+ | ceb | Cebuano | 宿务语 |
+ | cs | Czech | 捷克语 |
+ | cy | Welsh | 威尔士语 |
+ | da | Danish | 丹麦语 |
+ | el | Greek | 希腊语 |
+ | eo | Esperanto | 世界语 |
+ | et | Estonian | 爱沙尼亚语 |
+ | eu | Basque | 巴斯克语 |
+ | fa | Persian | 波斯语 |
+ | fi | Finnish | 芬兰语 |
+ | fo | Faroese | 法罗语 |
+ | gl | Galician | 加利西亚语 |
+ | gn | Guarani | 瓜拉尼语 |
+ | gu | Gujarati | 古吉拉特语 |
+ | gv | Manx | 马恩语 |
+ | ha | Hausa | 豪萨语 |
+ | haw | Hawaiian | 夏威夷语 |
+ | hi | Hindi | 印地语 |
+ | hr | Croatian | 克罗地亚语 |
+ | ht | Haitian Creole | 海地克里奥尔语 |
+ | hu | Hungarian | 匈牙利语 |
+ | hy | Armenian | 亚美尼亚语 |
+ | ia | Interlingua | 国际语 |
+ | id | Indonesian | 印度尼西亚语 |
+ | is | Icelandic | 冰岛语 |
+ | it | Italian | 意大利语 |
+ | iw | Hebrew | 希伯来语 |
+ | jw | Javanese | 爪哇语 |
+ | ka | Georgian | 格鲁吉亚语 |
+ | kk | Kazakh | 哈萨克语 |
+ | km | Khmer | 高棉语 |
+ | kn | Kannada | 卡纳达语 |
+ | la | Latin | 拉丁语 |
+ | lb | Luxembourgish | 卢森堡语 |
+ | ln | Lingala | 林加拉语 |
+ | lo | Lao | 老挝语 |
+ | lt | Lithuanian | 立陶宛语 |
+ | lv | Latvian | 拉脱维亚语 |
+ | mg | Malagasy | 马尔加什语 |
+ | mi | Māori | 毛利语 |
+ | mk | Macedonian | 马其顿语 |
+ | ml | Malayalam | 马拉雅拉姆语 |
+ | mn | Mongolian | 蒙古语 |
+ | mr | Marathi | 马拉地语 |
+ | ms | Malay | 马来语 |
+ | mt | Maltese | 马耳他语 |
+ | my | Burmese | 缅甸语 |
+ | ne | Nepali | 尼泊尔语 |
+ | nl | Dutch | 荷兰语 |
+ | nn | Norwegian Nynorsk | 挪威语 |
+ | no | Norwegian | 挪威语 |
+ | oc | Occitan | 奥克语 |
+ | pa | Punjabi | 旁遮普语 |
+ | pl | Polish | 波兰语 |
+ | ps | Pashto | 普什图语 |
+ | ro | Romanian | 罗马尼亚语 |
+ | sa | Sanskrit | 梵语 |
+ | sco | Scots | 苏格兰语 |
+ | sd | Sindhi | 信德语 |
+ | si | Sinhala | 僧伽罗语 |
+ | sk | Slovak | 斯洛伐克语 |
+ | sl | Slovenian | 斯洛文尼亚语 |
+ | sn | Shona | 绍纳语 |
+ | so | Somali | 索马里语 |
+ | sq | Albanian | 阿尔巴尼亚语 |
+ | sr | Serbian | 塞尔维亚语 |
+ | su | Sundanese | 巽他语 |
+ | sv | Swedish | 瑞典语 |
+ | sw | Swahili | 斯瓦希里语 |
+ | ta | Tamil | 泰米尔语 |
+ | te | Telugu | 泰卢固语 |
+ | tg | Tajik | 塔吉克语 |
+ | th | Thai | 泰语 |
+ | tk | Turkmen | 土库曼语 |
+ | tl | Tagalog | 塔加洛语 |
+ | tr | Turkish | 土耳其语 |
+ | tt | Tatar | 鞑靼语 |
+ | uk | Ukrainian | 乌克兰语 |
+ | ur | Urdu | 乌尔都语 |
+ | uz | Uzbek | 乌兹别克语 |
+ | vi | Vietnamese | 越南语 |
+ | war | Waray | 瓦赖语 |
+ | yi | Yiddish | 意第绪语 |
+ | yo | Yoruba | 约鲁巴语 |
+
+
+
+ ## Language Region Code (Chinese Dialects)
+
+ | Language Region Code | English Name | Chinese Name |
+ |---|---|---|
+ | zh-mandarin | Chinese (Mandarin) | 中文(普通话) |
+ | zh-yue | Chinese (Yue, Guangdong) | 中文(粤语-广东) |
+ | zh-wu | Chinese (Wu, Shanghai) | 中文(吴语-上海) |
+ | zh-min | Chinese (Min, Fujian) | 中文(闽语-福建) |
+ | zh-north | Chinese (Mandarin, North) — Shandong / Gansu / Ningxia / Hebei / Shanxi / Liaoning / Shaanxi | 中文(官话-北方:山东/甘肃/宁夏/河北/山西/辽宁/陕西) |
+ | zh-xinan | Chinese (Mandarin, Southwest) — Sichuan / Yunnan / Guizhou / Hubei / Chongqing | 中文(官话-西南:四川/云南/贵州/湖北/重庆) |
+ | zh-xiang | Chinese (Xiang, Hunan) | 中文(湘语-湖南) |
+ | bo | Chinese (Tibetan) | 中文(藏语) |
fireredasr2s/fireredlid/__init__.py ADDED
@@ -0,0 +1,23 @@
+ # Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Yan Jia)
+
+ import os
+ import sys
+
+ __version__ = "0.0.1"
+
+ _CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
+
+ try:
+     from fireredasr2s.fireredlid.lid import FireRedLid, FireRedLidConfig
+ except ImportError:
+     if _CURRENT_DIR not in sys.path:
+         sys.path.insert(0, _CURRENT_DIR)
+     from .lid import FireRedLid, FireRedLidConfig
+
+
+ # API
+ __all__ = [
+     "__version__",
+     "FireRedLid",
+     "FireRedLidConfig",
+ ]
fireredasr2s/fireredlid/data/feat.py ADDED
@@ -0,0 +1,124 @@
+ # Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Yan Jia)
+
+ import math
+ import os
+
+ import kaldiio
+ import kaldi_native_fbank as knf
+ import numpy as np
+ import torch
+
+
+ class FeatExtractor:
+     def __init__(self, kaldi_cmvn_file):
+         self.cmvn = CMVN(kaldi_cmvn_file) if kaldi_cmvn_file != "" else None
+         self.fbank = KaldifeatFbank(num_mel_bins=80, frame_length=25,
+                                     frame_shift=10, dither=0.0)
+
+     def __call__(self, wav_paths, wav_uttids):
+         feats = []
+         durs = []
+         return_wav_paths = []
+         return_wav_uttids = []
+
+         wav_datas = []
+         if isinstance(wav_paths[0], str):
+             for wav_path in wav_paths:
+                 sample_rate, wav_np = kaldiio.load_mat(wav_path)
+                 wav_datas.append([sample_rate, wav_np])
+         else:
+             wav_datas = wav_paths
+
+         for (sample_rate, wav_np), path, uttid in zip(wav_datas, wav_paths, wav_uttids):
+             dur = wav_np.shape[0] / sample_rate
+             fbank = self.fbank((sample_rate, wav_np))
+             if fbank.shape[0] < 1:
+                 continue
+             if self.cmvn is not None:
+                 fbank = self.cmvn(fbank)
+             fbank = torch.from_numpy(fbank).float()
+             feats.append(fbank)
+             durs.append(dur)
+             return_wav_paths.append(path)
+             return_wav_uttids.append(uttid)
+         if len(feats) > 0:
+             lengths = torch.tensor([feat.size(0) for feat in feats]).long()
+             feats_pad = self.pad_feat(feats, 0.0)
+         else:
+             lengths, feats_pad = None, None
+         return feats_pad, lengths, durs, return_wav_paths, return_wav_uttids
+
+     def pad_feat(self, xs, pad_value):
+         # type: (List[Tensor], int) -> Tensor
+         n_batch = len(xs)
+         max_len = max([xs[i].size(0) for i in range(n_batch)])
+         pad = torch.ones(n_batch, max_len, *xs[0].size()[1:]).to(xs[0].device).to(xs[0].dtype).fill_(pad_value)
+         for i in range(n_batch):
+             pad[i, :xs[i].size(0)] = xs[i]
+         return pad
+
+
+ class CMVN:
+     def __init__(self, kaldi_cmvn_file):
+         self.dim, self.means, self.inverse_std_variances = \
+             self.read_kaldi_cmvn(kaldi_cmvn_file)
+
+     def __call__(self, x, is_train=False):
+         assert x.shape[-1] == self.dim, "CMVN dim mismatch"
+         out = x - self.means
+         out = out * self.inverse_std_variances
+         return out
+
+     def read_kaldi_cmvn(self, kaldi_cmvn_file):
+         assert os.path.exists(kaldi_cmvn_file)
+         stats = kaldiio.load_mat(kaldi_cmvn_file)
+         assert stats.shape[0] == 2
+         dim = stats.shape[-1] - 1
+         count = stats[0, dim]
+         assert count >= 1
+         floor = 1e-20
+         means = []
+         inverse_std_variances = []
+         for d in range(dim):
+             mean = stats[0, d] / count
+             means.append(mean.item())
+             variance = (stats[1, d] / count) - mean*mean
+             if variance < floor:
+                 variance = floor
+             istd = 1.0 / math.sqrt(variance)
+             inverse_std_variances.append(istd)
+         return dim, np.array(means), np.array(inverse_std_variances)
+
+
+ class KaldifeatFbank:
+     def __init__(self, num_mel_bins=80, frame_length=25, frame_shift=10,
+                  dither=1.0):
+         self.dither = dither
+         opts = knf.FbankOptions()
+         opts.frame_opts.dither = dither
+         opts.mel_opts.num_bins = num_mel_bins
+         opts.frame_opts.snip_edges = True
+         opts.mel_opts.debug_mel = False
+         self.opts = opts
+
+     def __call__(self, wav, is_train=False):
+         if type(wav) is str:
+             sample_rate, wav_np = kaldiio.load_mat(wav)
+         elif type(wav) in [tuple, list] and len(wav) == 2:
+             sample_rate, wav_np = wav
+         assert len(wav_np.shape) == 1
+
+         dither = self.dither if is_train else 0.0
+         self.opts.frame_opts.dither = dither
+         fbank = knf.OnlineFbank(self.opts)
+
+         fbank.accept_waveform(sample_rate, wav_np.tolist())
+         feat = []
+         for i in range(fbank.num_frames_ready):
+             feat.append(fbank.get_frame(i))
+         if len(feat) == 0:
+             print("Check data, len(feat) == 0", wav, flush=True)
+             return np.zeros((0, self.opts.mel_opts.num_bins))
+         feat = np.vstack(feat)
+         return feat
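A minimal sketch of running the extractor above on in-memory audio; it assumes 16 kHz int16 samples, and the cmvn path is hypothetical:

    import numpy as np
    from fireredasr2s.fireredlid.data.feat import FeatExtractor

    extractor = FeatExtractor("pretrained_models/FireRedLID/cmvn.ark")
    wav = (16000, np.zeros(16000, dtype=np.int16))      # 1 s of silence
    feats, lengths, durs, _, uttids = extractor([wav], ["utt1"])
    print(feats.shape)                  # (1, 98, 80): 25 ms window, 10 ms shift fbank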
fireredasr2s/fireredlid/data/token_dict.py ADDED
@@ -0,0 +1,63 @@
+ # Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Yan Jia)
+
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+
+ class TokenDict:
+     def __init__(self, dict_path, unk=""):
+         assert dict_path != ""
+         self.id2word, self.word2id = self.read_dict(dict_path)
+         self.unk = unk
+         assert unk == "" or unk in self.word2id
+         self.unkid = self.word2id[unk] if unk else -1
+
+     def get(self, key, default):
+         if type(default) == str:
+             default = self.word2id[default]
+         return self.word2id.get(key, default)
+
+     def __getitem__(self, key):
+         if type(key) == str:
+             if self.unk:
+                 return self.word2id.get(key, self.word2id[self.unk])
+             else:
+                 return self.word2id[key]
+         elif type(key) == int:
+             return self.id2word[key]
+         else:
+             raise TypeError("Key should be str or int")
+
+     def __len__(self):
+         return len(self.id2word)
+
+     def __contains__(self, query):
+         if type(query) == str:
+             return query in self.word2id
+         elif type(query) == int:
+             # An int query asks whether it is a valid token id
+             return 0 <= query < len(self.id2word)
+         else:
+             raise TypeError("query should be str or int")
+
+     def read_dict(self, dict_path):
+         id2word, word2id = [], {}
+         with open(dict_path, encoding='utf8') as f:
+             for i, line in enumerate(f):
+                 tokens = line.strip().split()
+                 if len(tokens) >= 2:
+                     word, index = tokens[0], int(tokens[1])
+                 elif len(tokens) == 1:
+                     word, index = tokens[0], i
+                 else:  # empty line or space
+                     logger.info(f"Find empty line or space '{line.strip()}' in {dict_path}:L{i}, set to ' '")
+                     word, index = " ", i
+                 assert len(id2word) == index
+                 assert len(word2id) == index
+                 if word == "<space>":
+                     logger.info(f"NOTE: Find <space> in {dict_path}:L{i} and convert it to ' '")
+                     word = " "
+                 word2id[word] = index
+                 id2word.append(word)
+         assert len(id2word) == len(word2id)
+         return id2word, word2id
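A hedged sketch of the lookup behavior above, given a hypothetical four-line dict.txt ("<pad> 0", "<unk> 1", "zh 2", "en 3"):

    from fireredasr2s.fireredlid.data.token_dict import TokenDict

    d = TokenDict("dict.txt", unk="<unk>")
    print(d["zh"], d[3], len(d))    # 2 en 4
    print(d["never-seen"])          # 1, falls back to the <unk> id
    print("en" in d, 99 in d)       # True False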
fireredasr2s/fireredlid/lid.py ADDED
@@ -0,0 +1,110 @@
+ # Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Yan Jia)
+
+ import os
+ import time
+ import traceback
+ from dataclasses import dataclass
+
+ import torch
+
+ from .data.feat import FeatExtractor
+ from .models.fireredlid_aed import FireRedLidAed
+ from .models.param import count_model_parameters
+ from .tokenizer.lid_tokenizer import LidTokenizer
+
+
+ @dataclass
+ class FireRedLidConfig:
+     use_gpu: bool = True
+     use_half: bool = False
+
+
+ class FireRedLid:
+     @classmethod
+     def from_pretrained(cls, model_dir, config=FireRedLidConfig()):
+         cmvn_path = os.path.join(model_dir, "cmvn.ark")
+         feat_extractor = FeatExtractor(cmvn_path)
+
+         model_path = os.path.join(model_dir, "model.pth.tar")
+         dict_path = os.path.join(model_dir, "dict.txt")
+         model = load_fireredlid_model(model_path)
+         tokenizer = LidTokenizer(dict_path)
+
+         count_model_parameters(model)
+         model.eval()
+         return cls(feat_extractor, model, tokenizer, config)
+
+     def __init__(self, feat_extractor, model, tokenizer, config):
+         self.feat_extractor = feat_extractor
+         self.model = model
+         self.tokenizer = tokenizer
+         self.config = config
+         self.config.beam_size = 3
+         self.config.nbest = 1
+         self.config.decode_max_len = 2
+         self.config.softmax_smoothing = 1.25
+         self.config.aed_length_penalty = 0.6
+         self.config.eos_penalty = 1.0
+         if self.config.use_gpu:
+             if self.config.use_half:
+                 self.model.half()
+             self.model.cuda()
+         else:
+             self.model.cpu()
+
+     @torch.no_grad()
+     def process(self, batch_uttid, batch_wav_path):
+         batch_uttid_origin = batch_uttid
+         try:
+             feats, lengths, durs, batch_wav_path, batch_uttid = \
+                 self.feat_extractor(batch_wav_path, batch_uttid)
+             if feats is None:
+                 return [{"uttid": uttid, "lang": ""} for uttid in batch_uttid_origin]
+         except Exception:
+             traceback.print_exc()
+             return [{"uttid": uttid, "lang": ""} for uttid in batch_uttid_origin]
+         total_dur = sum(durs)
+         if self.config.use_gpu:
+             feats, lengths = feats.cuda(), lengths.cuda()
+             if self.config.use_half:
+                 feats = feats.half()
+
+         start_time = time.time()
+
+         try:
+             hyps = self.model.process(
+                 feats, lengths,
+                 self.config.beam_size,
+                 self.config.nbest,
+                 self.config.decode_max_len,
+                 self.config.softmax_smoothing,
+                 self.config.aed_length_penalty,
+                 self.config.eos_penalty
+             )
+         except Exception:
+             traceback.print_exc()
+             hyps = []
+
+         elapsed = time.time() - start_time
+         rtf = elapsed / total_dur if total_dur > 0 else 0
+
+         results = []
+         for uttid, wav, hyp, dur in zip(batch_uttid, batch_wav_path, hyps, durs):
+             hyp = hyp[0]  # only return 1-best
+             hyp_ids = [int(id) for id in hyp["yseq"].cpu()]
+             text = self.tokenizer.detokenize(hyp_ids)
+             results.append({"uttid": uttid, "lang": text,
+                             "confidence": round(hyp["confidence"].cpu().item(), 3),
+                             "dur_s": round(dur, 3), "rtf": f"{rtf:.4f}"})
+             if type(wav) == str:
+                 results[-1]["wav"] = wav
+         return results
+
+
+ def load_fireredlid_model(model_path):
+     package = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=False)
+     #print(package["args"])
+     model = FireRedLidAed.from_args(package["args"])
+     model.load_state_dict(package["model_state_dict"], strict=False)
+     return model
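A hedged usage sketch for the LID wrapper above (the model directory is the CLI default and must already contain the pretrained files; the wav path is hypothetical):

    from fireredasr2s.fireredlid import FireRedLid, FireRedLidConfig

    lid = FireRedLid.from_pretrained("pretrained_models/FireRedLID",
                                     FireRedLidConfig(use_gpu=False))
    results = lid.process(["utt1"], ["examples/test.wav"])
    print(results[0]["lang"], results[0]["confidence"])   # e.g. "zh" 0.998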
fireredasr2s/fireredlid/models/fireredlid_aed.py ADDED
@@ -0,0 +1,37 @@
1
+ # Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Yan Jia)
2
+
3
+ import torch
4
+
5
+ from .module.conformer_encoder import ConformerEncoder
6
+ from .module.transformer_decoder import TransformerDecoder
7
+
8
+
9
+ class FireRedLidAed(torch.nn.Module):
10
+ @classmethod
11
+ def from_args(cls, args):
12
+ return cls(args)
13
+
14
+ def __init__(self, args):
15
+ super().__init__()
16
+ self.sos_id = args.sos_id
17
+ self.eos_id = args.eos_id
18
+
19
+ self.encoder = ConformerEncoder(
20
+ args.idim, args.n_layers_enc, args.n_head, args.d_model,
21
+ args.residual_dropout, args.dropout_rate,
22
+ args.kernel_size, args.pe_maxlen)
23
+
24
+ self.lid_decoder = TransformerDecoder(
25
+ args.sos_id, args.eos_id, args.pad_id, args.lid_odim,
26
+ args.n_layers_lid_dec, args.n_head, args.d_model,
27
+ args.residual_dropout, args.pe_maxlen)
28
+
29
+ def process(self, padded_input, input_lengths,
30
+ beam_size=3, nbest=1, decode_max_len=2,
31
+ softmax_smoothing=1.25, length_penalty=0.6, eos_penalty=1.0):
32
+ enc_outputs, enc_lengths, enc_mask = self.encoder(padded_input, input_lengths)
33
+ nbest_hyps = self.lid_decoder.batch_beam_search(
34
+ enc_outputs, enc_mask,
35
+ beam_size, nbest, decode_max_len,
36
+ softmax_smoothing, length_penalty, eos_penalty)
37
+ return nbest_hyps
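For a quick smoke test of this AED pipeline (Conformer encoder, then a very short beam search over language tokens), here is a sketch with illustrative hyperparameters; the released checkpoint's dimensions may differ:

```python
import argparse
import torch
from fireredlid.models.fireredlid_aed import FireRedLidAed

# Hypothetical model sizes, chosen only so the module assembles.
args = argparse.Namespace(
    sos_id=1, eos_id=2, pad_id=0, idim=80, lid_odim=20,
    n_layers_enc=4, n_layers_lid_dec=2, n_head=4, d_model=256,
    residual_dropout=0.1, dropout_rate=0.1, kernel_size=33, pe_maxlen=5000)
model = FireRedLidAed.from_args(args).eval()

feats = torch.randn(2, 100, 80)          # (batch, frames, fbank-dim)
lengths = torch.tensor([100, 80])
with torch.no_grad():
    hyps = model.process(feats, lengths)  # one n-best list per utterance
print(hyps[0][0]["yseq"], hyps[0][0]["confidence"])
```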
fireredasr2s/fireredlid/models/module/conformer_encoder.py ADDED
@@ -0,0 +1,324 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Yan Jia)
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ConformerEncoder(nn.Module):
+    def __init__(self, idim, n_layers, n_head, d_model,
+                 residual_dropout=0.1, dropout_rate=0.1, kernel_size=33,
+                 pe_maxlen=5000):
+        super().__init__()
+        self.odim = d_model
+
+        self.input_preprocessor = Conv2dSubsampling(idim, d_model)
+        self.positional_encoding = RelPositionalEncoding(d_model)
+        self.dropout = nn.Dropout(residual_dropout)
+
+        self.layer_stack = nn.ModuleList()
+        for _ in range(n_layers):
+            block = RelPosEmbConformerBlock(d_model, n_head,
+                                            residual_dropout,
+                                            dropout_rate, kernel_size)
+            self.layer_stack.append(block)
+
+    def forward(self, padded_input, input_lengths, pad=True):
+        if pad:
+            padded_input = F.pad(padded_input,
+                                 (0, 0, 0, self.input_preprocessor.context - 1),
+                                 'constant', 0.0)
+        src_mask = self.padding_position_is_0(padded_input, input_lengths)
+
+        embed_output, input_lengths, src_mask = self.input_preprocessor(padded_input, src_mask)
+        enc_output = self.dropout(embed_output)
+
+        pos_emb = self.dropout(self.positional_encoding(embed_output))
+
+        enc_outputs = []
+        for enc_layer in self.layer_stack:
+            enc_output = enc_layer(enc_output, pos_emb, slf_attn_mask=src_mask,
+                                   pad_mask=src_mask)
+            enc_outputs.append(enc_output)
+
+        return enc_output, input_lengths, src_mask
+
+    def padding_position_is_0(self, padded_input, input_lengths):
+        N, T = padded_input.size()[:2]
+        mask = torch.ones((N, T)).to(padded_input.device)
+        for i in range(N):
+            mask[i, input_lengths[i]:] = 0
+        mask = mask.unsqueeze(dim=1)
+        return mask.to(torch.uint8)
+
+
+class RelPosEmbConformerBlock(nn.Module):
+    def __init__(self, d_model, n_head,
+                 residual_dropout=0.1,
+                 dropout_rate=0.1, kernel_size=33):
+        super().__init__()
+        self.ffn1 = ConformerFeedForward(d_model, dropout_rate)
+        self.mhsa = RelPosMultiHeadAttention(n_head, d_model,
+                                             residual_dropout)
+        self.conv = ConformerConvolution(d_model, kernel_size,
+                                         dropout_rate)
+        self.ffn2 = ConformerFeedForward(d_model, dropout_rate)
+        self.layer_norm = nn.LayerNorm(d_model)
+
+    def forward(self, x, pos_emb, slf_attn_mask=None, pad_mask=None):
+        out = 0.5 * x + 0.5 * self.ffn1(x)
+        out = self.mhsa(out, out, out, pos_emb, mask=slf_attn_mask)[0]
+        out = self.conv(out, pad_mask)
+        out = 0.5 * out + 0.5 * self.ffn2(out)
+        out = self.layer_norm(out)
+        return out
+
+
+class Swish(nn.Module):
+    def forward(self, x):
+        return x * torch.sigmoid(x)
+
+
+class Conv2dSubsampling(nn.Module):
+    def __init__(self, idim, d_model, out_channels=32):
+        super().__init__()
+        self.conv = nn.Sequential(
+            nn.Conv2d(1, out_channels, 3, 2),
+            nn.ReLU(),
+            nn.Conv2d(out_channels, out_channels, 3, 2),
+            nn.ReLU(),
+        )
+        subsample_idim = ((idim - 1) // 2 - 1) // 2
+        self.out = nn.Linear(out_channels * subsample_idim, d_model)
+
+        self.subsampling = 4
+        left_context = right_context = 3  # both exclude the current frame
+        self.context = left_context + 1 + right_context  # 7
+
+    def forward(self, x, x_mask):
+        x = x.unsqueeze(1)
+        x = self.conv(x)
+        N, C, T, D = x.size()
+        x = self.out(x.transpose(1, 2).contiguous().view(N, T, C * D))
+        mask = x_mask[:, :, :-2:2][:, :, :-2:2]
+        input_lengths = mask[:, -1, :].sum(dim=-1)
+        return x, input_lengths, mask
+
+
+class RelPositionalEncoding(torch.nn.Module):
+    def __init__(self, d_model, max_len=5000):
+        super().__init__()
+        pe_positive = torch.zeros(max_len, d_model, requires_grad=False)
+        pe_negative = torch.zeros(max_len, d_model, requires_grad=False)
+        position = torch.arange(0, max_len).unsqueeze(1).float()
+        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
+                             -(torch.log(torch.tensor(10000.0)).item() / d_model))
+        pe_positive[:, 0::2] = torch.sin(position * div_term)
+        pe_positive[:, 1::2] = torch.cos(position * div_term)
+        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+
+        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+        pe_negative = pe_negative[1:].unsqueeze(0)
+        pe = torch.cat([pe_positive, pe_negative], dim=1)
+        self.register_buffer('pe', pe)
+
+    def forward(self, x):
+        # Tmax = 2 * max_len - 1
+        Tmax, T = self.pe.size(1), x.size(1)
+        pos_emb = self.pe[:, Tmax // 2 - T + 1 : Tmax // 2 + T].clone().detach()
+        return pos_emb
+
+
+class ConformerFeedForward(nn.Module):
+    def __init__(self, d_model, dropout_rate=0.1):
+        super().__init__()
+        pre_layer_norm = nn.LayerNorm(d_model)
+        linear_expand = nn.Linear(d_model, d_model * 4)
+        nonlinear = Swish()
+        dropout_pre = nn.Dropout(dropout_rate)
+        linear_project = nn.Linear(d_model * 4, d_model)
+        dropout_post = nn.Dropout(dropout_rate)
+        self.net = nn.Sequential(pre_layer_norm,
+                                 linear_expand,
+                                 nonlinear,
+                                 dropout_pre,
+                                 linear_project,
+                                 dropout_post)
+
+    def forward(self, x):
+        residual = x
+        output = self.net(x)
+        output = output + residual
+        return output
+
+
+class ConformerConvolution(nn.Module):
+    def __init__(self, d_model, kernel_size=33, dropout_rate=0.1):
+        super().__init__()
+        assert kernel_size % 2 == 1
+        self.pre_layer_norm = nn.LayerNorm(d_model)
+        self.pointwise_conv1 = nn.Conv1d(d_model, d_model * 4, kernel_size=1, bias=False)
+        self.padding = (kernel_size - 1) // 2
+        self.depthwise_conv = nn.Conv1d(d_model * 2, d_model * 2,
+                                        kernel_size, stride=1,
+                                        padding=self.padding,
+                                        groups=d_model * 2, bias=False)
+        self.batch_norm = nn.LayerNorm(d_model * 2)  # a LayerNorm, despite the attribute name
+        self.swish = Swish()
+        self.pointwise_conv2 = nn.Conv1d(d_model * 2, d_model, kernel_size=1, bias=False)
+        self.dropout = nn.Dropout(dropout_rate)
+
+    def forward(self, x, mask=None):
+        residual = x
+        out = self.pre_layer_norm(x)
+        out = out.transpose(1, 2)
+        if mask is not None:
+            out.masked_fill_(mask.ne(1), 0.0)
+        out = self.pointwise_conv1(out)
+        out = F.glu(out, dim=1)
+        out = self.depthwise_conv(out)
+
+        out = out.transpose(1, 2)
+        out = self.swish(self.batch_norm(out))
+        out = out.transpose(1, 2)
+
+        out = self.dropout(self.pointwise_conv2(out))
+        if mask is not None:
+            out.masked_fill_(mask.ne(1), 0.0)
+        out = out.transpose(1, 2)
+        return out + residual
+
+
+class EncoderMultiHeadAttention(nn.Module):
+    def __init__(self, n_head, d_model,
+                 residual_dropout=0.1):
+        super().__init__()
+        assert d_model % n_head == 0
+        self.n_head = n_head
+        self.d_k = d_model // n_head
+        self.d_v = self.d_k
+
+        self.w_qs = nn.Linear(d_model, n_head * self.d_k, bias=False)
+        self.w_ks = nn.Linear(d_model, n_head * self.d_k, bias=False)
+        self.w_vs = nn.Linear(d_model, n_head * self.d_v, bias=False)
+
+        self.layer_norm_q = nn.LayerNorm(d_model)
+        self.layer_norm_k = nn.LayerNorm(d_model)
+        self.layer_norm_v = nn.LayerNorm(d_model)
+
+        self.attention = ScaledDotProductAttention(temperature=self.d_k ** 0.5)
+        self.fc = nn.Linear(n_head * self.d_v, d_model, bias=False)
+        self.dropout = nn.Dropout(residual_dropout)
+
+    def forward(self, q, k, v, mask=None):
+        sz_b, len_q = q.size(0), q.size(1)
+
+        residual = q
+        q, k, v = self.forward_qkv(q, k, v)
+
+        output, attn = self.attention(q, k, v, mask=mask)
+
+        output = self.forward_output(output, residual, sz_b, len_q)
+        return output, attn
+
+    def forward_qkv(self, q, k, v):
+        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
+        sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
+
+        q = self.layer_norm_q(q)
+        k = self.layer_norm_k(k)
+        v = self.layer_norm_v(v)
+
+        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
+        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
+        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
+        q = q.transpose(1, 2)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+        return q, k, v
+
+    def forward_output(self, output, residual, sz_b, len_q):
+        output = output.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
+        fc_out = self.fc(output)
+        output = self.dropout(fc_out)
+        output = output + residual
+        return output
+
+
+class ScaledDotProductAttention(nn.Module):
+    def __init__(self, temperature):
+        super().__init__()
+        self.temperature = temperature
+        self.dropout = nn.Dropout(0.0)
+        self.INF = float('inf')
+
+    def forward(self, q, k, v, mask=None):
+        attn = torch.matmul(q, k.transpose(2, 3)) / self.temperature
+        output, attn = self.forward_attention(attn, v, mask)
+        return output, attn
+
+    def forward_attention(self, attn, v, mask=None):
+        if mask is not None:
+            mask = mask.unsqueeze(1)
+            mask = mask.eq(0)
+            attn = attn.masked_fill(mask, -self.INF)
+            attn = torch.softmax(attn, dim=-1).masked_fill(mask, 0.0)
+        else:
+            attn = torch.softmax(attn, dim=-1)
+
+        d_attn = self.dropout(attn)
+        output = torch.matmul(d_attn, v)
+
+        return output, attn
+
+
+class RelPosMultiHeadAttention(EncoderMultiHeadAttention):
+    def __init__(self, n_head, d_model,
+                 residual_dropout=0.1):
+        super().__init__(n_head, d_model,
+                         residual_dropout)
+        d_k = d_model // n_head
+        self.scale = 1.0 / (d_k ** 0.5)
+        self.linear_pos = nn.Linear(d_model, n_head * d_k, bias=False)
+        self.pos_bias_u = nn.Parameter(torch.FloatTensor(n_head, d_k))
+        self.pos_bias_v = nn.Parameter(torch.FloatTensor(n_head, d_k))
+        torch.nn.init.xavier_uniform_(self.pos_bias_u)
+        torch.nn.init.xavier_uniform_(self.pos_bias_v)
+
+    def _rel_shift(self, x):
+        N, H, T1, T2 = x.size()
+        zero_pad = torch.zeros((N, H, T1, 1), device=x.device, dtype=x.dtype)
+        x_padded = torch.cat([zero_pad, x], dim=-1)
+
+        x_padded = x_padded.view(N, H, T2 + 1, T1)
+        x = x_padded[:, :, 1:].view_as(x)
+        x = x[:, :, :, : x.size(-1) // 2 + 1]
+        return x
+
+    def forward(self, q, k, v, pos_emb, mask=None):
+        sz_b, len_q = q.size(0), q.size(1)
+
+        residual = q
+        q, k, v = self.forward_qkv(q, k, v)
+
+        q = q.transpose(1, 2)
+        n_batch_pos = pos_emb.size(0)
+        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.n_head, self.d_k)
+        p = p.transpose(1, 2)
+
+        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
+        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
+
+        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
+
+        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
+        matrix_bd = self._rel_shift(matrix_bd)
+
+        attn_scores = matrix_ac + matrix_bd
+        attn_scores.mul_(self.scale)
+
+        output, attn = self.attention.forward_attention(attn_scores, v, mask=mask)
+
+        output = self.forward_output(output, residual, sz_b, len_q)
+        return output, attn
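A quick sanity check on the subsampling arithmetic in `Conv2dSubsampling`: two stride-2 3x3 convolutions shrink both the time and the feature axis, which is why the linear layer expects `out_channels * (((idim - 1) // 2 - 1) // 2)` inputs and why the mask is sliced twice with `[:, :, :-2:2]`. The input sizes below are illustrative:

```python
import torch
from fireredlid.models.module.conformer_encoder import Conv2dSubsampling

sub = Conv2dSubsampling(idim=80, d_model=256)
x = torch.randn(1, 102, 80)                      # 102 frames of 80-dim fbank
mask = torch.ones(1, 1, 102, dtype=torch.uint8)
out, lengths, out_mask = sub(x, mask)
print(out.shape)                 # torch.Size([1, 24, 256]): 102 frames -> 24
print(((80 - 1) // 2 - 1) // 2)  # 19, so the linear layer sees 32 * 19 = 608
```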
fireredasr2s/fireredlid/models/module/transformer_decoder.py ADDED
@@ -0,0 +1,317 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Yan Jia)
+
+from typing import List, Optional, Dict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+
+
+class TransformerDecoder(nn.Module):
+    def __init__(
+            self, sos_id, eos_id, pad_id, odim,
+            n_layers, n_head, d_model,
+            residual_dropout=0.1, pe_maxlen=5000):
+        super().__init__()
+        self.INF = 1e10
+        # Parameters
+        self.pad_id = pad_id
+        self.sos_id = sos_id
+        self.eos_id = eos_id
+        self.n_layers = n_layers
+
+        # Components
+        self.tgt_word_emb = nn.Embedding(odim, d_model, padding_idx=self.pad_id)
+        self.positional_encoding = PositionalEncoding(d_model, max_len=pe_maxlen)
+        self.dropout = nn.Dropout(residual_dropout)
+
+        self.layer_stack = nn.ModuleList()
+        for _ in range(n_layers):
+            block = DecoderLayer(d_model, n_head, residual_dropout)
+            self.layer_stack.append(block)
+
+        self.tgt_word_prj = nn.Linear(d_model, odim, bias=False)
+        self.layer_norm_out = nn.LayerNorm(d_model)
+
+        self.tgt_word_prj.weight = self.tgt_word_emb.weight
+        self.scale = (d_model ** 0.5)
+
+    def batch_beam_search(self, encoder_outputs, src_masks,
+                          beam_size=1, nbest=1, decode_max_len=0,
+                          softmax_smoothing=1.0, length_penalty=0.0, eos_penalty=1.0):
+        B = beam_size
+        N, Ti, H = encoder_outputs.size()
+        device = encoder_outputs.device
+        maxlen = decode_max_len if decode_max_len > 0 else Ti
+        assert eos_penalty > 0.0
+
+        # Init
+        encoder_outputs = encoder_outputs.unsqueeze(1).repeat(1, B, 1, 1).view(N*B, Ti, H)
+        src_mask = src_masks.unsqueeze(1).repeat(1, B, 1, 1).view(N*B, -1, Ti)
+        ys = torch.ones(N*B, 1).fill_(self.sos_id).long().to(device)
+        t_ys = ys.clone()
+        confidences = torch.zeros(N*B, 1).float().to(device)
+        caches: List[Optional[Tensor]] = []
+        for _ in range(self.n_layers):
+            caches.append(None)
+        scores = torch.tensor([0.0] + [-self.INF] * (B-1)).float().to(device)
+        scores = scores.repeat(N).view(N*B, 1)
+        is_finished = torch.zeros_like(scores)
+
+        # Autoregressive prediction
+        for t in range(maxlen):
+            tgt_mask = self.ignored_target_position_is_0(ys, self.pad_id)
+
+            dec_output = self.dropout(
+                self.tgt_word_emb(ys) * self.scale +
+                self.positional_encoding(ys))
+            for i, dec_layer in enumerate(self.layer_stack):
+                dec_output = dec_layer.forward(
+                    dec_output, encoder_outputs,
+                    tgt_mask, src_mask,
+                    cache=caches[i])
+                caches[i] = dec_output
+
+            dec_output = self.layer_norm_out(dec_output)
+
+            t_logit = self.tgt_word_prj(dec_output[:, -1])
+            t_scores = F.log_softmax(t_logit / softmax_smoothing, dim=-1)
+
+            if eos_penalty != 1.0:
+                t_scores[:, self.eos_id] *= eos_penalty
+
+            t_topB_scores, t_topB_ys = torch.topk(t_scores, k=B, dim=1)
+            t_topB_scores = self.set_finished_beam_score_to_zero(t_topB_scores, is_finished)
+            t_topB_ys = self.set_finished_beam_y_to_eos(t_topB_ys, is_finished)
+
+            # Accumulate
+            scores = scores + t_topB_scores
+
+            # Pruning
+            scores = scores.view(N, B*B)
+            scores, topB_score_ids = torch.topk(scores, k=B, dim=1)
+            scores = scores.view(-1, 1)
+
+            topB_row_number_in_each_B_rows_of_ys = torch.div(topB_score_ids, B, rounding_mode="floor").view(N*B)
+            stride = B * torch.arange(N).view(N, 1).repeat(1, B).view(N*B).to(device)
+            topB_row_number_in_ys = topB_row_number_in_each_B_rows_of_ys.long() + stride.long()
+
+            # Update ys
+            ys = ys[topB_row_number_in_ys]
+            t_ys = torch.gather(t_topB_ys.view(N, B*B), dim=1, index=topB_score_ids).view(N*B, 1)
+            ys = torch.cat((ys, t_ys), dim=1)
+
+            # Update confidences
+            confidences = confidences[topB_row_number_in_ys]
+            t_confidences = torch.gather(t_topB_scores.view(N, B*B), dim=1, index=topB_score_ids).view(N*B, 1)
+            t_confidences = torch.exp(t_confidences)
+            assert torch.all(t_confidences <= 1.0)
+            assert torch.all(t_confidences >= 0.0)
+            confidences = torch.cat((confidences, t_confidences), dim=1)
+
+            # Update caches
+            new_caches: List[Optional[Tensor]] = []
+            for cache in caches:
+                if cache is not None:
+                    new_caches.append(cache[topB_row_number_in_ys])
+            caches = new_caches
+
+            # Update finished state
+            is_finished = t_ys.eq(self.eos_id)
+            if is_finished.sum().item() == N*B:
+                break
+
+        # Length penalty (following GNMT)
+        scores = scores.view(N, B)
+        ys = ys.view(N, B, -1)
+        ys_lengths = self.get_ys_lengths(ys)
+        if length_penalty > 0.0:
+            penalty = torch.pow((5 + ys_lengths.float()) / (5.0 + 1), length_penalty)
+            scores /= penalty
+        nbest_scores, nbest_ids = torch.topk(scores, k=int(nbest), dim=1)
+        nbest_scores = -1.0 * nbest_scores
+        index = nbest_ids + B * torch.arange(N).view(N, 1).to(device).long()
+        nbest_ys = ys.view(N*B, -1)[index.view(-1)]
+        nbest_ys = nbest_ys.view(N, nbest_ids.size(1), -1)
+        nbest_ys_lengths = ys_lengths.view(N*B)[index.view(-1)].view(N, -1)
+        nbest_confidences = confidences.view(N*B, -1)[index.view(-1)].view(N, nbest_ids.size(1), -1)
+
+        # Result
+        nbest_hyps: List[List[Dict[str, Tensor]]] = []
+        for n in range(N):
+            n_nbest_hyps: List[Dict[str, Tensor]] = []
+            for i, score in enumerate(nbest_scores[n]):
+                confidence = nbest_confidences[n, i, 1:nbest_ys_lengths[n, i]]
+                confidence = confidence.mean()
+                new_hyp = {
+                    "yseq": nbest_ys[n, i, 1:nbest_ys_lengths[n, i]],
+                    "confidence": confidence
+                }
+                n_nbest_hyps.append(new_hyp)
+            nbest_hyps.append(n_nbest_hyps)
+        return nbest_hyps
+
+    def ignored_target_position_is_0(self, padded_targets, ignore_id):
+        mask = torch.ne(padded_targets, ignore_id)
+        mask = mask.unsqueeze(dim=1)
+        T = padded_targets.size(-1)
+        upper_tri_0_mask = self.upper_triangular_is_0(T).unsqueeze(0)
+        upper_tri_0_mask = upper_tri_0_mask.to(mask.dtype).to(mask.device)
+        return mask.to(torch.uint8) & upper_tri_0_mask.to(torch.uint8)
+
+    def upper_triangular_is_0(self, size):
+        ones = torch.ones(size, size)
+        tri_left_ones = torch.tril(ones)
+        return tri_left_ones.to(torch.uint8)
+
+    def set_finished_beam_score_to_zero(self, scores, is_finished):
+        NB, B = scores.size()
+        is_finished = is_finished.float()
+        mask_score = torch.tensor([0.0] + [-self.INF] * (B-1)).float().to(scores.device)
+        mask_score = mask_score.view(1, B).repeat(NB, 1)
+        return scores * (1 - is_finished) + mask_score * is_finished
+
+    def set_finished_beam_y_to_eos(self, ys, is_finished):
+        is_finished = is_finished.long()
+        return ys * (1 - is_finished) + self.eos_id * is_finished
+
+    def get_ys_lengths(self, ys):
+        N, B, Tmax = ys.size()
+        ys_lengths = torch.sum(torch.ne(ys, self.eos_id), dim=-1)
+        return ys_lengths.int()
+
+
+class DecoderLayer(nn.Module):
+    def __init__(self, d_model, n_head, dropout):
+        super().__init__()
+        self.self_attn_norm = nn.LayerNorm(d_model)
+        self.self_attn = DecoderMultiHeadAttention(d_model, n_head, dropout)
+
+        self.cross_attn_norm = nn.LayerNorm(d_model)
+        self.cross_attn = DecoderMultiHeadAttention(d_model, n_head, dropout)
+
+        self.mlp_norm = nn.LayerNorm(d_model)
+        self.mlp = PositionwiseFeedForward(d_model, d_model * 4, dropout)
+
+    def forward(self, dec_input, enc_output, self_attn_mask, cross_attn_mask,
+                cache=None):
+        x = dec_input
+        residual = x
+        x = self.self_attn_norm(x)
+        if cache is not None:
+            # Only the newest position is computed; earlier positions are
+            # read back from the cache below
+            xq = x[:, -1:, :]
+            residual = residual[:, -1:, :]
+            self_attn_mask = self_attn_mask[:, -1:, :]
+        else:
+            xq = x
+        x = self.self_attn(xq, x, x, mask=self_attn_mask)
+        x = residual + x
+
+        residual = x
+        x = self.cross_attn_norm(x)
+        x = self.cross_attn(x, enc_output, enc_output, mask=cross_attn_mask)
+        x = residual + x
+
+        residual = x
+        x = self.mlp_norm(x)
+        x = residual + self.mlp(x)
+
+        if cache is not None:
+            x = torch.cat([cache, x], dim=1)
+
+        return x
+
+
+class DecoderMultiHeadAttention(nn.Module):
+    def __init__(self, d_model, n_head, dropout=0.1):
+        super().__init__()
+        self.d_model = d_model
+        self.n_head = n_head
+        self.d_k = d_model // n_head
+
+        self.w_qs = nn.Linear(d_model, n_head * self.d_k)
+        self.w_ks = nn.Linear(d_model, n_head * self.d_k, bias=False)
+        self.w_vs = nn.Linear(d_model, n_head * self.d_k)
+
+        self.attention = DecoderScaledDotProductAttention(
+            temperature=self.d_k ** 0.5)
+        self.fc = nn.Linear(n_head * self.d_k, d_model)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, q, k, v, mask=None):
+        bs = q.size(0)
+
+        q = self.w_qs(q).view(bs, -1, self.n_head, self.d_k)
+        k = self.w_ks(k).view(bs, -1, self.n_head, self.d_k)
+        v = self.w_vs(v).view(bs, -1, self.n_head, self.d_k)
+        q = q.transpose(1, 2)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        if mask is not None:
+            mask = mask.unsqueeze(1)
+
+        output = self.attention(q, k, v, mask=mask)
+
+        output = output.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
+        output = self.fc(output)
+        output = self.dropout(output)
+
+        return output
+
+
+class DecoderScaledDotProductAttention(nn.Module):
+    def __init__(self, temperature):
+        super().__init__()
+        self.temperature = temperature
+        self.INF = float("inf")
+
+    def forward(self, q, k, v, mask=None):
+        attn = torch.matmul(q, k.transpose(2, 3)) / self.temperature
+        if mask is not None:
+            mask = mask.eq(0)
+            attn = attn.masked_fill(mask, -self.INF)
+            attn = torch.softmax(attn, dim=-1).masked_fill(mask, 0.0)
+        else:
+            attn = torch.softmax(attn, dim=-1)
+        output = torch.matmul(attn, v)
+        return output
+
+
+class PositionwiseFeedForward(nn.Module):
+    def __init__(self, d_model, d_ff, dropout=0.1):
+        super().__init__()
+        self.w_1 = nn.Linear(d_model, d_ff)
+        self.act = nn.GELU()
+        self.w_2 = nn.Linear(d_ff, d_model)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        output = self.w_2(self.act(self.w_1(x)))
+        output = self.dropout(output)
+        return output
+
+
+class PositionalEncoding(nn.Module):
+    def __init__(self, d_model, max_len=5000):
+        super().__init__()
+        assert d_model % 2 == 0
+        pe = torch.zeros(max_len, d_model, requires_grad=False)
+        position = torch.arange(0, max_len).unsqueeze(1).float()
+        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
+                             -(torch.log(torch.tensor(10000.0)).item() / d_model))
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = pe.unsqueeze(0)
+        self.register_buffer('pe', pe)
+
+    def forward(self, x):
+        length = x.size(1)
+        return self.pe[:, :length].clone().detach()
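The length normalization in `batch_beam_search` follows GNMT (Wu et al., 2016): accumulated log-probabilities are divided by `((5 + |Y|) / 6) ** length_penalty`, which keeps longer hypotheses competitive against short ones. A small numeric sketch of how the divisor grows:

```python
import torch

def gnmt_penalty(ys_lengths, length_penalty=0.6):
    # same formula as in batch_beam_search above
    return torch.pow((5 + ys_lengths.float()) / (5.0 + 1), length_penalty)

print(gnmt_penalty(torch.tensor([1, 2, 10])))
# tensor([1.0000, 1.0969, 1.7329]) approximately
```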
fireredasr2s/fireredlid/models/param.py ADDED
@@ -0,0 +1,17 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Yan Jia)
+
+import logging
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+def count_model_parameters(model):
+    if not isinstance(model, torch.nn.Module):
+        return 0, 0
+    name = f"{model.__class__.__name__} {model.__class__}"
+    num = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    size = num * 4.0 / 1024.0 / 1024.0  # float32, MB
+    logger.info(f"#param of {name} is {num} = {size:.1f} MB (float32)")
+    return num, size
fireredasr2s/fireredlid/speech2lang.py ADDED
@@ -0,0 +1,73 @@
+#!/usr/bin/env python3
+
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Yan Jia)
+
+import argparse
+import json
+import logging
+import os
+
+from fireredlid.lid import FireRedLid, FireRedLidConfig
+from fireredlid.utils.io import get_wav_info
+
+logging.basicConfig(level=logging.INFO,
+    format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
+logger = logging.getLogger("fireredlid.bin.speech2lang")
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--model_dir', type=str, required=True)
+
+# Input / Output
+parser.add_argument("--wav_path", type=str)
+parser.add_argument("--wav_paths", type=str, nargs="*")
+parser.add_argument("--wav_dir", type=str)
+parser.add_argument("--wav_scp", type=str)
+parser.add_argument("--sort_wav_by_dur", type=int, default=0)
+parser.add_argument("--output", type=str)
+# Decode Options
+parser.add_argument('--use_gpu', type=int, default=1)
+parser.add_argument('--use_half', type=int, default=0)
+parser.add_argument("--batch_size", type=int, default=1)
+
+
+def main(args):
+    wavs = get_wav_info(args)
+    fout = open(args.output, "w") if args.output else None
+    foutl = open(args.output + ".jsonl", "w") if args.output else None
+
+    lid_config = FireRedLidConfig(
+        args.use_gpu,
+        args.use_half
+    )
+    model = FireRedLid.from_pretrained(args.model_dir, lid_config)
+
+    batch_uttid = []
+    batch_wav_path = []
+    for i, wav in enumerate(wavs):
+        uttid, wav_path = wav
+        batch_uttid.append(uttid)
+        batch_wav_path.append(wav_path)
+        if len(batch_wav_path) < args.batch_size and i != len(wavs) - 1:
+            continue
+
+        results = model.process(batch_uttid, batch_wav_path)
+
+        for result in results:
+            logger.info(result)
+            if fout is not None:
+                foutl.write(f"{json.dumps(result, ensure_ascii=False)}\n")
+                fout.write(f"{result['uttid']}\t{result['lang']}\n")
+
+        if fout: fout.flush()
+        if foutl: foutl.flush()
+        batch_uttid = []
+        batch_wav_path = []
+    if fout: fout.close()
+    if foutl: foutl.close()
+
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+    logger.info(args)
+    main(args)
fireredasr2s/fireredlid/tokenizer/lid_tokenizer.py ADDED
@@ -0,0 +1,17 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Yan Jia)
+
+from ..data.token_dict import TokenDict
+
+
+class LidTokenizer:
+
+    def __init__(self, dict_path, unk="<unk>"):
+        self.dict = TokenDict(dict_path, unk=unk)
+
+    def detokenize(self, inputs, join_symbol=" "):
+        if len(inputs) > 0 and isinstance(inputs[0], int):
+            tokens = [self.dict[id] for id in inputs]
+        else:
+            tokens = inputs
+        return join_symbol.join(tokens)
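`detokenize` accepts either token ids or already-decoded strings. A tiny sketch (the dict path and language labels are illustrative, not the shipped inventory):

```python
tokenizer = LidTokenizer("pretrained_models/FireRedLID/dict.txt")
print(tokenizer.detokenize([3]))           # id lookup via the dict, e.g. "zh"
print(tokenizer.detokenize(["zh", "en"]))  # strings pass through -> "zh en"
```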
fireredasr2s/fireredlid/utils/io.py ADDED
@@ -0,0 +1,38 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Yan Jia)
+
+import glob
+import logging
+import os
+
+logger = logging.getLogger(__name__)
+
+
+def get_wav_info(args):
+    """
+    Returns:
+        wavs: list of (uttid, wav_path)
+    """
+    base = lambda p: os.path.basename(p).replace(".wav", "")
+    if args.wav_path:
+        wavs = [(base(args.wav_path), args.wav_path)]
+    elif args.wav_paths and len(args.wav_paths) >= 1:
+        wavs = [(base(p), p) for p in sorted(args.wav_paths)]
+    elif args.wav_scp:
+        wavs = [line.strip().split() for line in open(args.wav_scp)]
+        if args.sort_wav_by_dur:
+            logger.info("Sorting wavs by duration...")
+            utt2dur = os.path.join(os.path.dirname(args.wav_scp), "utt2dur")
+            if os.path.exists(utt2dur):
+                utt2dur = [l.strip().split() for l in open(utt2dur)]
+                utt2dur = {l[0]: float(l[1]) for l in utt2dur if len(l) == 2}
+                wavs = sorted(wavs, key=lambda x: -utt2dur[x[0]])
+                logger.info("Sort done")
+            else:
+                logger.info(f"{utt2dur} not found, keeping the original order")
+    elif args.wav_dir:
+        wavs = glob.glob(f"{args.wav_dir}/**/*.wav", recursive=True)
+        wavs = [(base(p), p) for p in sorted(wavs)]
+    else:
+        raise ValueError("Please provide valid wav info")
+    logger.info(f"#wavs={len(wavs)}")
+    return wavs
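`get_wav_info` accepts a single wav, an explicit list, a directory, or a Kaldi-style `wav.scp`. A sketch of the scp route, with placeholder paths:

```python
import argparse

# wav.scp maps uttid -> wav path, one pair per line:
#   utt001 /data/wavs/utt001.wav
#   utt002 /data/wavs/utt002.wav
# An optional utt2dur file ("uttid duration_s") in the same directory
# enables --sort_wav_by_dur.
args = argparse.Namespace(wav_path=None, wav_paths=None, wav_dir=None,
                          wav_scp="data/wav.scp", sort_wav_by_dur=0)
wavs = get_wav_info(args)  # -> [("utt001", "/data/wavs/utt001.wav"), ...]
```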
fireredasr2s/fireredpunc/__init__.py ADDED
@@ -0,0 +1,27 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Junjie Chen)
+
+import os
+import sys
+import warnings
+warnings.filterwarnings('ignore')
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # 0=ALL, 1=INFO, 2=WARNING, 3=ERROR
+os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
+
+__version__ = "0.0.1"
+
+_CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
+
+try:
+    from fireredasr2s.fireredpunc.punc import FireRedPunc, FireRedPuncConfig
+except ImportError:
+    if _CURRENT_DIR not in sys.path:
+        sys.path.insert(0, _CURRENT_DIR)
+    from .punc import FireRedPunc, FireRedPuncConfig
+
+
+# API
+__all__ = [
+    "__version__",
+    "FireRedPunc",
+    "FireRedPuncConfig",
+]
fireredasr2s/fireredpunc/add_punc.py ADDED
@@ -0,0 +1,96 @@
+#!/usr/bin/env python3
+
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Junjie Chen)
+
+import argparse
+import logging
+import re
+
+from fireredpunc.punc import FireRedPunc, FireRedPuncConfig
+
+logging.basicConfig(level=logging.INFO,
+    format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
+logger = logging.getLogger("fireredpunc.bin.add_punc")
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--model_dir', type=str, required=True)
+# Input / Output
+parser.add_argument("--input_txt", type=str, default="")
+parser.add_argument("--input_file", type=str, default="")
+parser.add_argument("--output", type=str)
+parser.add_argument("--input_contain_uttid", type=int, default=0)
+# Punc Options
+parser.add_argument('--use_gpu', type=int, default=1)
+parser.add_argument('--batch_size', type=int, default=1)
+parser.add_argument('--sentence_max_length', type=int, default=-1)
+
+
+def main(args):
+    in_texts = get_input(args)
+    fout = open(args.output, "w") if args.output else None
+
+    punc_config = FireRedPuncConfig(
+        args.use_gpu,
+        args.sentence_max_length
+    )
+    model = FireRedPunc.from_pretrained(args.model_dir, punc_config)
+
+    batch_text = []
+    batch_uttid = []
+    for i, (uttid, text) in enumerate(in_texts):
+        batch_text.append(text)
+        batch_uttid.append(uttid)
+        if len(batch_text) < args.batch_size and i != len(in_texts) - 1:
+            continue
+
+        results = model.process(batch_text)
+
+        for uttid, result in zip(batch_uttid, results):
+            logger.info(result)
+            if fout:
+                if args.input_contain_uttid:
+                    fout.write(f"{uttid}\t{result['punc_text']}\n")
+                else:
+                    fout.write(f"{result['punc_text']}\n")
+
+        batch_text = []
+        batch_uttid = []
+        if fout: fout.flush()
+    if fout: fout.close()
+
+
+def get_input(args):
+    in_texts = []
+    if args.input_file:
+        with open(args.input_file, "r") as fin:
+            for i, l in enumerate(fin):
+                uttid = i
+                text = l.strip()
+                if args.input_contain_uttid:
+                    uttid, text = text.split(maxsplit=1)
+                text = _remove_punc_and_fix_space(text)
+                in_texts.append((uttid, text))
+        logger.info(f"#text={len(in_texts)}")
+    elif args.input_txt:
+        logger.info(f"Input txt: {args.input_txt}")
+        text = _remove_punc_and_fix_space(args.input_txt)
+        in_texts.append((0, text))
+    return in_texts
+
+
+def _remove_punc_and_fix_space(text):
+    origin = text
+    text = re.sub(r"[,。?!,\.?!]", " ", text)
+    pattern = re.compile(r'([\u3400-\u4dbf\u4e00-\u9fff\u3040-\u309f\u30a0-\u30ff\u31f0-\u31ff\U00020000-\U0002a6df\U0002a700-\U0002b73f\U0002b740-\U0002b81f\U0002b820-\U0002ceaf\U0002ceb0-\U0002ebef\U00030000-\U0003134f])')
+    parts = pattern.split(text.strip())
+    parts = [p for p in parts if len(p.strip()) > 0]
+    text = "".join(parts)
+    if origin != text:
+        logger.debug(f"Change text: '{origin}' --> '{text}'")
+    return text
+
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+    logger.info(args)
+    main(args)
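`_remove_punc_and_fix_space` strips existing punctuation (both full-width and ASCII) and removes spaces between CJK characters, so the punctuation model always sees a clean, unpunctuated stream. Two traced examples:

```python
print(_remove_punc_and_fix_space("你好, 世界。"))    # -> "你好世界"
print(_remove_punc_and_fix_space("转发 评论 点赞"))  # -> "转发评论点赞"
```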
fireredasr2s/fireredpunc/data/__init__.py ADDED
File without changes
fireredasr2s/fireredpunc/data/hf_bert_tokenizer.py ADDED
@@ -0,0 +1,180 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Junjie Chen)
+
+import logging
+import re
+import traceback
+
+from transformers import BertTokenizer
+
+logger = logging.getLogger(__name__)
+
+
+# HuggingFace BERT tokenizer wrapper
+class HfBertTokenizer:
+    def __init__(self, huggingface_tokenizer_dir):
+        self.tokenizer = BertTokenizer.from_pretrained(huggingface_tokenizer_dir)
+
+    def tokenize(self, text, recover_unk=False):
+        tokens = self.tokenizer.tokenize(text)
+        # NOTE: ids are computed before UNK recovery, so they still refer to
+        # the raw tokens (including [UNK]) even when recover_unk=True
+        tokens_id = self.tokenizer.convert_tokens_to_ids(tokens)
+        if recover_unk:
+            try:
+                tokens = self._recover_unk(text.lower(), tokens)
+            except Exception:
+                traceback.print_exc()
+        return tokens, tokens_id
+
+    def _recover_unk(self, text, tokens):
+        if "[UNK]" not in tokens:
+            return tokens
+
+        new_tokens = []
+        text_no_space = text.replace(" ", "")
+
+        # Fast path: no Latin letters/digits, so tokens and characters align 1:1
+        if re.match(r"^[^a-zA-Z0-9']+$", text):
+            tmp_text = text_no_space
+            if len(tmp_text) == len(tokens):
+                success = True
+                for t, tok in zip(tmp_text, tokens):
+                    if tok != "[UNK]" and t != tok:
+                        success = False
+                        break
+                    new_tokens.append(t)
+                if success:
+                    return new_tokens
+            new_tokens = []
+
+        text_pos = 0
+        i = 0
+        while i < len(tokens):
+            token = tokens[i]
+            if token == "[UNK]":
+                unk_count = 0
+                j = i
+                while j < len(tokens) and tokens[j] == "[UNK]":
+                    unk_count += 1
+                    j += 1
+
+                post_token = ""
+                if j < len(tokens):
+                    post_token = tokens[j].replace("##", "")
+
+                if post_token:
+                    remaining = text_no_space[text_pos:]
+                    anchor_pos = remaining.find(post_token)
+                    if anchor_pos != -1:
+                        unk_chars = remaining[:anchor_pos]
+                    else:
+                        unk_chars = remaining[:unk_count]
+                else:
+                    unk_chars = text_no_space[text_pos:text_pos + unk_count]
+
+                for k in range(unk_count):
+                    if k < len(unk_chars):
+                        new_tokens.append(unk_chars[k])
+                    else:
+                        new_tokens.append("")
+                text_pos += len(unk_chars)
+                i = j
+            else:
+                new_tokens.append(token)
+                token_clean = token.replace("##", "")
+                text_pos += len(token_clean)
+                i += 1
+
+        new_tokens = [t for t in new_tokens if t and t != "[UNK]"]
+        return new_tokens
+
+    def detokenize(self, inputs, join_symbol="", replace_spm_space=True):
+        raise NotImplementedError
+
+
+if __name__ == "__main__":
+    import os
+    model_dir = "../../../pretrained_models/FireRedPunc"
+    tokenizer = HfBertTokenizer(os.path.join(model_dir, "chinese-lert-base"))
+
+    txts = [
+        # Basic cases
+        "你好吗",
+        "你好 吗",
+        "hello how are you",
+
+        # Consecutive rare characters (consecutive [UNK])
+        "寄蜉蝣于天地渺沧海之一粟",
+        "魑魅魍魉",  # 4 consecutive rare characters
+        "饕餮耄耋",  # another 4 consecutive rare characters
+
+        # Mixed Chinese/English + rare characters
+        "寄蜉蝣于天地渺沧海之一粟how are you魑魅魍魉你蝣蜉啊蝣",
+        "hello魑魅world魍魉test",  # rare characters embedded in English
+
+        # [UNK] at the beginning/end
+        "蜉蝣你好",  # consecutive rare characters at the start
+        "你好蜉蝣",  # consecutive rare characters at the end
+        "蜉你蝣好",  # alternating
+
+        # Special symbols (may produce [UNK])
+        "你好!@#¥%",
+        "【测试】《标题》",
+        "价格:¥99.9元",
+
+        # Complex mixtures
+        "【魑魅】说:你好蜉蝣",
+        "饕餮之徒hello耄耋老人",
+
+        # Edge cases
+        "",  # empty string
+        "蜉",  # single rare character
+        "魑魅魍魉饕餮",  # 6 consecutive rare characters
+
+        # ------------------------------------------
+        # Cases where one [UNK] may cover several characters
+        # ------------------------------------------
+
+        # Uncommon English strings (may be out of vocabulary)
+        "价格是xyz123元",  # xyz123 may be tagged as [UNK]
+        "使用qwerty键盘",  # qwerty may be tagged as [UNK]
+
+        # Symbol combinations
+        "商标™注册®版权©",  # TM/R/C symbols
+        "温度是25℃左右",  # Celsius symbol
+        "面积100㎡价格",  # square-meter symbol
+
+        # Japanese/Korean characters (may be absent from a Chinese vocab)
+        "你好こんにちは世界",  # Japanese hiragana
+        "欢迎안녕하세요光临",  # Korean
+
+        # Roman numerals
+        "第Ⅷ章内容",  # Roman numeral 8
+        "共Ⅻ个部分",  # Roman numeral 12
+
+        # Math symbols
+        "结果是≈100左右",  # approximately-equal sign
+        "价格≤1000元",  # less-than-or-equal sign
+
+        # Circled digits
+        "第①步操作",  # circled digit 1
+        "共⑩个选项",  # circled digit 10
+    ]
+
+    print("=" * 60)
+    print("UNK recovery test")
+    print("=" * 60)
+    for txt in txts:
+        if not txt:
+            print("(empty string) --> []")
+            continue
+        tokens_raw = tokenizer.tokenizer.tokenize(txt)
+        tokens_recovered, _ = tokenizer.tokenize(txt, recover_unk=True)
+        has_unk = "[UNK]" in tokens_raw
+        status = "✓" if "[UNK]" not in tokens_recovered else "✗"
+        if has_unk:
+            print(f"{status} {txt}")
+            print(f"  raw:       {tokens_raw}")
+            print(f"  recovered: {tokens_recovered}")
+        else:
+            print(f"  {txt} --> {tokens_recovered}")
+    print("=" * 60)
fireredasr2s/fireredpunc/data/token_dict.py ADDED
@@ -0,0 +1,63 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Junjie Chen)
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class TokenDict:
+    def __init__(self, dict_path, unk=""):
+        assert dict_path != ""
+        self.id2word, self.word2id = self.read_dict(dict_path)
+        self.unk = unk
+        assert unk == "" or unk in self.word2id
+        self.unkid = self.word2id[unk] if unk else -1
+
+    def get(self, key, default):
+        if isinstance(default, str):
+            default = self.word2id[default]
+        return self.word2id.get(key, default)
+
+    def __getitem__(self, key):
+        if isinstance(key, str):
+            if self.unk:
+                return self.word2id.get(key, self.word2id[self.unk])
+            else:
+                return self.word2id[key]
+        elif isinstance(key, int):
+            return self.id2word[key]
+        else:
+            raise TypeError("Key should be str or int")
+
+    def __len__(self):
+        return len(self.id2word)
+
+    def __contains__(self, query):
+        if isinstance(query, str):
+            return query in self.word2id
+        elif isinstance(query, int):
+            # an int query asks whether the id is valid, not whether the
+            # integer happens to be an element of the word list
+            return 0 <= query < len(self.id2word)
+        else:
+            raise TypeError("query should be str or int")
+
+    def read_dict(self, dict_path):
+        id2word, word2id = [], {}
+        with open(dict_path, encoding='utf8') as f:
+            for i, line in enumerate(f):
+                tokens = line.strip().split()
+                if len(tokens) >= 2:
+                    word, index = tokens[0], int(tokens[1])
+                elif len(tokens) == 1:
+                    word, index = tokens[0], i
+                else:  # empty line or whitespace
+                    logger.info(f"Found empty line or space '{line.strip()}' in {dict_path}:L{i}, set to ' '")
+                    word, index = " ", i
+                assert len(id2word) == index
+                assert len(word2id) == index
+                if word == "<space>":
+                    logger.info(f"NOTE: Found <space> in {dict_path}:L{i}, converting it to ' '")
+                    word = " "
+                word2id[word] = index
+                id2word.append(word)
+        assert len(id2word) == len(word2id)
+        return id2word, word2id
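The dict files consumed by `TokenDict` hold one token per line, optionally followed by an explicit integer index. A sketch with a hypothetical file:

```python
# dict.txt (hypothetical contents):
#   <pad> 0
#   <unk> 1
#   zh 2
#   en 3
d = TokenDict("dict.txt", unk="<unk>")
print(d["zh"], d[2], "en" in d)  # 2 zh True
print(d["oov-token"])            # unknown words fall back to the <unk> id: 1
```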
fireredasr2s/fireredpunc/models/__init__.py ADDED
File without changes
fireredasr2s/fireredpunc/models/fireredpunc_bert.py ADDED
@@ -0,0 +1,69 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Junjie Chen)
+
+import logging
+
+import torch
+import torch.nn as nn
+import transformers
+
+logger = logging.getLogger(__name__)
+
+
+class FireRedPuncBert(nn.Module):
+    @classmethod
+    def from_args(cls, args):
+        assert args.pretrained_bert, "only pretrained BERT is supported"
+        args.bert = transformers.BertModel.from_pretrained(f"{args.pretrained_bert}")
+        args.bert.pooler = None
+        args.hidden_size = args.bert.config.hidden_size
+        return cls(args)
+
+    def __init__(self, args):
+        super().__init__()
+        self.bert = args.bert if args.pretrained_bert else None  # init in build()
+        self.dropout = nn.Dropout(float(args.classifier_dropout))
+        self.classifier = nn.Linear(args.hidden_size, args.odim)
+        self.max_input_len = self.bert.embeddings.position_embeddings.num_embeddings - 1
+        self.cls_id = args.cls_id  # set in punc_data.py:PuncData.build()
+        self.ignore_index = args.ignore_index  # used by loss
+
+    @torch.jit.export
+    def forward_model(self, padded_inputs, lengths):
+        if padded_inputs.size(1) <= self.max_input_len:
+            score = self._forward(padded_inputs, lengths)
+        else:
+            logger.info("padded_inputs is too long, splitting it into chunks")
+            chunk_score = []
+            chunks = padded_inputs.split(self.max_input_len, dim=1)
+            left_lengths = lengths
+            for chunk in chunks:
+                chunk_lengths = torch.clamp(left_lengths, min=0, max=self.max_input_len)
+                left_lengths = left_lengths - chunk_lengths
+                chunk_score.append(self._forward(chunk, chunk_lengths))
+            score = torch.cat(chunk_score, dim=1)
+        return score
+
+    def _forward(self, padded_inputs, lengths):
+        padded_inputs, lengths = self.add_cls(padded_inputs, lengths)
+        attention_mask = create_huggingface_bert_attention_mask(lengths)
+        outputs = self.bert(padded_inputs, attention_mask)
+        sequence_output = outputs[0][:, 1:]  # drop the [CLS] position's output
+        sequence_output = self.dropout(sequence_output)
+        score = self.classifier(sequence_output)
+        return score
+
+    def add_cls(self, padded_inputs, lengths):
+        N = padded_inputs.size(0)
+        cls = padded_inputs.new_ones(N, 1).fill_(self.cls_id)
+        padded_inputs = torch.cat((cls, padded_inputs), dim=1)
+        lengths = lengths + 1
+        return padded_inputs, lengths
+
+
+def create_huggingface_bert_attention_mask(lengths):
+    N = int(lengths.size(0))
+    T = int(lengths.max())
+    mask = lengths.new_ones((N, T))
+    for i in range(N):
+        mask[i, lengths[i]:] = 0
+    return mask.float()
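When an input exceeds BERT's position-embedding budget, `forward_model` splits it on the time axis and scores each chunk independently. A sketch of the bookkeeping with illustrative sizes (BERT-base has 512 positions, one of which this model reserves for the prepended `[CLS]`):

```python
import torch

max_input_len = 511                                   # 512 positions - [CLS]
padded_inputs = torch.zeros(2, 1200, dtype=torch.long)
lengths = torch.tensor([1200, 700])

chunks = padded_inputs.split(max_input_len, dim=1)
print([c.size(1) for c in chunks])                    # [511, 511, 178]

left = lengths
for chunk in chunks:
    chunk_lengths = torch.clamp(left, min=0, max=max_input_len)
    left = left - chunk_lengths                       # utt 2: 700 -> 189 -> 0
    # each chunk would be scored here and the logits concatenated on dim=1
```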
fireredasr2s/fireredpunc/models/param.py ADDED
@@ -0,0 +1,17 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Junjie Chen)
+
+import logging
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+def count_model_parameters(model):
+    if not isinstance(model, torch.nn.Module):
+        return 0, 0
+    name = f"{model.__class__.__name__} {model.__class__}"
+    num = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    size = num * 4.0 / 1024.0 / 1024.0  # float32, MB
+    logger.info(f"#param of {name} is {num} = {size:.1f} MB (float32)")
+    return num, size
fireredasr2s/fireredpunc/punc.py ADDED
@@ -0,0 +1,391 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Junjie Chen)
2
+
3
+ import logging
4
+ import os
5
+ import re
6
+ from dataclasses import dataclass
7
+
8
+ import torch
9
+
10
+ from .data.hf_bert_tokenizer import HfBertTokenizer
11
+ from .models.fireredpunc_bert import FireRedPuncBert
12
+ from .models.param import count_model_parameters
13
+ from .data.token_dict import TokenDict
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ @dataclass
19
+ class FireRedPuncConfig:
20
+ use_gpu: bool = True
21
+ sentence_max_length: int = -1
22
+
23
+
24
+ class FireRedPunc:
25
+ @classmethod
26
+ def from_pretrained(cls, model_dir, config):
27
+ model = load_punc_bert_model(model_dir)
28
+ model_io = ModelIO(model_dir)
29
+ assert isinstance(config, FireRedPuncConfig)
30
+ count_model_parameters(model)
31
+ model.eval()
32
+ return cls(model_io, model, config)
33
+
34
+ def __init__(self, model_io, model, config):
35
+ self.model_io = model_io
36
+ self.model = model
37
+ self.config = config
38
+ if self.config.use_gpu:
39
+ self.model.cuda()
40
+ else:
41
+ self.model.cpu()
42
+
43
+ @torch.no_grad()
44
+ def process(self, batch_text, batch_uttid=None):
45
+ # Intercept empty input to prevent max([]) from throwing an error
46
+ if not batch_text:
47
+ return []
48
+
49
+ # 1. Prepare inputs
50
+ padded_inputs, lengths, txt_tokens = self.model_io.text2tensor(batch_text)
51
+ if self.config.use_gpu:
52
+ padded_inputs, lengths = padded_inputs.cuda(), lengths.cuda()
53
+
54
+ # 2. Model inference
55
+ logits = self.model.forward_model(padded_inputs, lengths) # (N,T,C)
56
+ preds = self.get_punc_pred(logits, lengths)
57
+
58
+ # 3. Add Punc to txt
59
+ punc_txts = self.model_io.add_punc_to_txt(txt_tokens, preds)
60
+ punc_txts = [RuleBaedTxtFix.fix(txt) for txt in punc_txts]
61
+
62
+ # 4. Format output
63
+ results = []
64
+ for i in range(len(batch_text)):
65
+ result = {
66
+ "punc_text": punc_txts[i],
67
+ "origin_text": batch_text[i],
68
+ }
69
+ if batch_uttid is not None:
70
+ result["uttid"] = batch_uttid[i]
71
+ results.append(result)
72
+ return results
73
+
74
+ @torch.no_grad()
75
+ def process_with_timestamp(self, batch_timestamp, batch_uttid=None):
76
+ # Intercept empty input to prevent max([]) from throwing an error
77
+ if not batch_timestamp:
78
+ return []
79
+
80
+ # 1. Prepare inputs
81
+ padded_inputs, lengths, batch_txt_tokens, batch_tokens_split_num = \
82
+ self.model_io.timestamp2tensor(batch_timestamp)
83
+ if self.config.use_gpu:
84
+ padded_inputs, lengths = padded_inputs.cuda(), lengths.cuda()
85
+
86
+ # 2. Model inference
87
+ logits = self.model.forward_model(padded_inputs, lengths) # (N,T,C)
88
+ preds = self.get_punc_pred(logits, lengths, batch_txt_tokens)
89
+
90
+ # 3. Add Punc to txt
91
+ punc_txts = self.model_io.add_punc_to_txt_with_timestamp(
92
+ batch_txt_tokens, preds, batch_timestamp, batch_tokens_split_num)
93
+
94
+ new_punc_txts = []
95
+ for txts in punc_txts:
96
+ new_txts = []
97
+ for idx, txt in enumerate(txts):
98
+ # Only capitalize first letter after sentence-ending punctuation (.!?), not after comma
99
+ if idx == 0:
100
+ cap = True
101
+ else:
102
+ prev_text = new_txts[idx - 1][0]
103
+ cap = bool(prev_text) and prev_text[-1] in '.!?。?!'
104
+ new_txts.append((RuleBaedTxtFix.fix(txt[0], capitalize_first=cap), txt[1], txt[2]))
105
+ new_punc_txts.append(new_txts)
106
+ punc_txts = new_punc_txts
107
+
108
+ # 4. Format output
109
+ results = []
110
+ for i in range(len(batch_timestamp)):
111
+ result = {
112
+ "punc_sentences": [
113
+ {"punc_text": t[0], "start_s": t[1], "end_s": t[2]} for t in punc_txts[i]
114
+ ],
115
+ }
116
+ if batch_uttid is not None:
117
+ result["uttid"] = batch_uttid[i]
118
+ results.append(result)
119
+ return results
120
+
121
+ def get_punc_pred(self, punc_logits, lengths, batch_txt_tokens=None):
122
+ max_len = torch.max(lengths).cpu().item()
123
+ if max_len <= self.config.sentence_max_length or self.config.sentence_max_length <= 0 or batch_txt_tokens is None:
124
+ _, preds = torch.max(punc_logits, dim=-1)
125
+ preds = preds.cpu().tolist()
126
+ preds = [pred[:lengths[i]] for i, pred in enumerate(preds)]
127
+ else:
128
+ preds = self.get_punc_pred_limit_max_len(punc_logits, lengths,
129
+ batch_txt_tokens)
130
+ return preds
131
+
132
+ def get_punc_pred_limit_max_len(self, punc_logits, lengths, batch_txt_tokens):
133
+ sentence_max_length = self.config.sentence_max_length
134
+ preds = []
135
+ batch_probs = punc_logits.softmax(dim=-1).cpu()
136
+ lengths = lengths.cpu()
137
+ for n in range(len(batch_probs)):
138
+ # Process each sentence
139
+ single_sentence_seg_token_ids = []
140
+ probs = batch_probs[n]
141
+ L = lengths[n]
142
+ tokens = batch_txt_tokens[n]
143
+ l = 0
144
+ while l < L:
145
+ r = l
146
+ total_num = 0.0
147
+ max_seg_prob = -1.0
148
+ max_index = -1
149
+ while r < L:
150
+ token_num = 0.0
151
+ s = re.sub("^##", "", tokens[r])
152
+ for j in range(len(s)):
153
+ if re.match("[a-zA-Z0-9']", s[j]):
154
+ token_num += 0.5
155
+ else:
156
+ token_num += 1
157
+
158
+ if total_num + token_num > sentence_max_length and max_seg_prob >= 0:
159
+ break
160
+
161
+ space_prob = probs[r][0]
162
+ seg_prob = 1.0 - space_prob
163
+ if seg_prob >= max_seg_prob:
164
+ max_seg_prob = seg_prob
165
+ max_index = r
166
+ total_num += token_num
167
+ r += 1
168
+ if seg_prob >= space_prob:
169
+ break
170
+ if r >= L:
171
+ # r is == sentence_length, r-- to avoid out-of-range-access
172
+ r -= 1
173
+ else:
174
+ # if total_num + token_num > sentence_max_length,
175
+ # we find l to max score's index as a sentence
176
+ # (max index is betweent [l, r])
177
+ r = max_index
178
+ if token_num > sentence_max_length:
179
+ logger.info(f"Too long token...{n}, {l}, {r}, {total_num}, {token_num}, {tokens[l]}, {tokens[r]}")
180
+ # range [l, r] is a sentence
181
+ for idx in range(l, r):
182
+ single_sentence_seg_token_ids.append(0) # 0 should be space
183
+ # argmax BEGIN (find an elegant way?)
184
+ pred_id = 1;
185
+ max_pred_prob = 0.0;
186
+ for k in range(1, len(probs[r])):
187
+ if probs[r][k] > max_pred_prob:
188
+ pred_id = k;
189
+ max_pred_prob = probs[r][k];
190
+ # argmax END
191
+ single_sentence_seg_token_ids.append(pred_id);
192
+ l = r + 1
193
+ preds.append(single_sentence_seg_token_ids)
194
+ return preds
195
+
196
+
197
+ def load_punc_bert_model(model_dir):
198
+ model_path = os.path.join(model_dir, "model.pth.tar")
199
+ package = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=False)
200
+ package["args"].bert = None
201
+ package["args"].pretrained_bert = os.path.join(model_dir, "chinese-lert-base")
202
+ model = FireRedPuncBert.from_args(package["args"])
203
+ model.load_state_dict(package["model_state_dict"], strict=False)
204
+ return model
205
+
206
+
207
+ class ModelIO:
208
+ def __init__(self, model_dir):
209
+ self.tokenizer = HfBertTokenizer(os.path.join(model_dir, "chinese-lert-base"))
210
+ self.in_dict = TokenDict(os.path.join(model_dir, "chinese-bert-wwm-ext_vocab.txt"), unk="[UNK]")
211
+ self.out_dict = TokenDict(os.path.join(model_dir, "out_dict"))
212
+ self.INPUT_IGNORE_ID = self.in_dict["[PAD]"]
213
+ self.DEFAULT_OUT = " "
214
+
215
+ def text2tensor(self, batch_text):
216
+ batch_txt_tokens = []
217
+ batch_input_seqs = []
218
+ for text in batch_text:
219
+             tokens, _ = self.tokenizer.tokenize(text, recover_unk=True)
+             input_seq = []
+             for token in tokens:
+                 input_seq.append(self.in_dict.get(token, self.in_dict.unk))
+             batch_txt_tokens.append(tokens)
+             batch_input_seqs.append(input_seq)
+         padded_inputs, lengths = self.pad_list(batch_input_seqs, self.INPUT_IGNORE_ID)
+         return padded_inputs, lengths, batch_txt_tokens
+
+     def timestamp2tensor(self, batch_timestamp):
+         batch_txt_tokens = []
+         batch_input_seqs = []
+         batch_tokens_split_num = []
+         for timestamps in batch_timestamp:
+             txt_token = []
+             input_seq = []
+             tokens_split_num = []
+             for token, start, end in timestamps:
+                 sub_tokens, _ = self.tokenizer.tokenize(token, recover_unk=True)
+                 tokens_split_num.append(len(sub_tokens))
+                 txt_token.extend(sub_tokens)
+                 for sub_token in sub_tokens:
+                     input_seq.append(self.in_dict.get(sub_token, self.in_dict.unk))
+             batch_txt_tokens.append(txt_token)
+             batch_input_seqs.append(input_seq)
+             batch_tokens_split_num.append(tokens_split_num)
+         padded_inputs, lengths = self.pad_list(batch_input_seqs, self.INPUT_IGNORE_ID)
+         return padded_inputs, lengths, batch_txt_tokens, batch_tokens_split_num
+
+     @classmethod
+     def pad_list(cls, input_seqs, pad_value):
+         lengths = [len(seq) for seq in input_seqs]
+         padded_inputs = torch.zeros(len(input_seqs), max(lengths)).fill_(pad_value).long()
+         for i, input_seq in enumerate(input_seqs):
+             end = lengths[i]
+             padded_inputs[i, :end] = torch.LongTensor(input_seq[:end])
+         lengths = torch.IntTensor(lengths)
+         return padded_inputs, lengths
+
+     def add_punc_to_txt(self, token_seqs, pred_seqs):
+         punc_txts = []
+         for token_seq, pred_seq in zip(token_seqs, pred_seqs):
+             assert len(token_seq) == len(pred_seq)
+             txt = ""
+             for i, token in enumerate(token_seq):
+                 tag = self.out_dict[pred_seq[i]]
+
+                 # tokenizer_type == "huggingface_bert":
+                 if token.startswith("##"):
+                     token = token.replace("##", "")
+                 elif re.search("[a-zA-Z0-9#]+", token) and \
+                         i > 0 and re.search("[a-zA-Z0-9#]+", token_seq[i-1]):
+                     if self.out_dict[pred_seq[i-1]] == self.DEFAULT_OUT:
+                         token = " " + token
+
+                 if tag == self.DEFAULT_OUT:
+                     txt += token
+                 else:
+                     txt += token + tag
+             txt = txt.replace("  ", " ")  # collapse double spaces
+             punc_txts.append(txt)
+         return punc_txts
+
+     def add_punc_to_txt_with_timestamp(self, token_seqs, pred_seqs,
+                                        batch_timestamp, batch_tokens_split_num):
+         punc_txts = []
+         for token_seq, pred_seq, timestamps, tokens_split_num in \
+                 zip(token_seqs, pred_seqs, batch_timestamp, batch_tokens_split_num):
+             assert len(token_seq) == len(pred_seq)
+             sentences = []
+             txt, start, end = "", -1, -1
+
+             i = 0
+             j = 0
+             last_token = ""
+             last_tag = ""
+             while i < len(token_seq) and j < len(tokens_split_num):
+                 split_num = tokens_split_num[j]
+                 timestamp = timestamps[j]
+                 assert len(timestamp) == 3
+                 if start == -1:
+                     start = timestamp[1]
+                 end = timestamp[2]
+
+                 # Reset 'token' and 'tag' before each word so values left over
+                 # from the previous word cannot leak into this one
+                 token = ""
+                 tag = self.DEFAULT_OUT
+
+                 for k in range(split_num):
+                     sub_token = token_seq[i]
+                     tag = self.out_dict[pred_seq[i]]
+                     sub_token = re.sub("^##", "", sub_token)
+                     if k == 0:
+                         token = sub_token
+                     else:  # k > 0
+                         token += sub_token
+                     i += 1
+
+                 # If the tokenizer produced no sub-tokens (e.g. the input is an
+                 # empty string ""), fall back to the original string so the
+                 # assertion below passes and no information is lost
+                 if split_num == 0:
+                     token = timestamp[0]
+
+                 assert token == timestamp[0], f"{token}/{timestamp}"
+                 j += 1
+                 # Add " " before English words & digits
+                 if re.search("[a-zA-Z0-9#]+", token) and \
+                         j > 0 and re.search("[a-zA-Z0-9#]+", last_token):
+                     if last_tag == self.DEFAULT_OUT:
+                         token = " " + token
+
+                 if tag == self.DEFAULT_OUT:
+                     txt += token
+                 else:
+                     txt += token + tag
+                     # A punctuation tag closes the current sentence
+                     txt = txt.replace("  ", " ")
+                     assert start != -1
+                     sentences.append((txt, start, end))
+                     txt, start, end = "", -1, -1
+                 last_token = token
+                 last_tag = tag
+             if txt != "":
+                 assert start != -1 and end != -1
+                 sentences.append((txt, start, end))
+
+             punc_txts.append(sentences)
+         return punc_txts
+
+
+ class RuleBaedTxtFix:
+     @classmethod
+     def fix(cls, txt_ori, capitalize_first=True):
+         txt = txt_ori.lower()
+         # English punctuation: full-width -> half-width with proper spacing
+         txt = re.sub(r"([a-z]),([a-z])", r"\1, \2", txt)
+         txt = re.sub(r"([a-z])。([a-z])", r"\1. \2", txt)
+         txt = re.sub(r"([a-z])?([a-z])", r"\1? \2", txt)
+         txt = re.sub(r"([a-z])!([a-z])", r"\1! \2", txt)
+         txt = re.sub(r"^([a-z]+),", r"\1,", txt)
+         txt = re.sub(r"^([a-z]+)。", r"\1.", txt)
+         txt = re.sub(r"^([a-z]+)?", r"\1?", txt)
+         txt = re.sub(r"^([a-z]+)!", r"\1!", txt)
+         txt = re.sub(r"( [a-zA-Z']+),$", r"\1,", txt)
+         txt = re.sub(r"( [a-zA-Z']+)。$", r"\1.", txt)
+         txt = re.sub(r"( [a-zA-Z']+)?$", r"\1?", txt)
+         txt = re.sub(r"( [a-zA-Z']+)!$", r"\1!", txt)
+         # Capitalize the pronoun "I"
+         txt = re.sub("^i ", "I ", txt)
+         txt = re.sub("^i'm ", "I'm ", txt)
+         txt = re.sub("^i'd ", "I'd ", txt)
+         txt = re.sub("^i've ", "I've ", txt)
+         txt = re.sub("^i'll ", "I'll ", txt)
+         txt = re.sub(" i ", " I ", txt)
+         txt = re.sub(" i'm ", " I'm ", txt)
+         txt = re.sub(" i'd ", " I'd ", txt)
+         txt = re.sub(" i've ", " I've ", txt)
+         txt = re.sub(" i'll ", " I'll ", txt)
+         # Capitalize the first English letter and sentence starts
+         if capitalize_first and len(txt) > 0 and re.match("[a-z]", txt[0]):
+             txt = txt[0].upper() + txt[1:]
+         txt = re.sub(r'([.!?。?!])\s+([a-z])',
+                      lambda m: f"{m.group(1)} {m.group(2).upper()}", txt)
+
+         return txt
+
+
+ if __name__ == "__main__":
+     txts = [
+         "i'm ok. how are you? i'm fine.",
+         "Tim,"
+     ]
+     for txt in txts:
+         txt2 = RuleBaedTxtFix.fix(txt)
+         print(f"{txt} => {txt2}")
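
As a quick sanity check of the rule fixer above, the demo strings should come out lower-cased, re-capitalized, and with full-width punctuation converted next to English text. A minimal sketch (the import path is assumed from this repo's layout; the outputs were derived by tracing the regex rules, so treat them as illustrative):

    from fireredasr2s.fireredpunc.punc import RuleBaedTxtFix

    print(RuleBaedTxtFix.fix("i'm ok. how are you? i'm fine."))
    # -> "I'm ok. How are you? I'm fine."
    print(RuleBaedTxtFix.fix("Tim,"))
    # -> "Tim,"  (full-width comma after an English word becomes half-width)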
fireredasr2s/fireredvad/__init__.py ADDED
@@ -0,0 +1,57 @@
+ # Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Kai Huang)
+
+ import os
+ import sys
+
+ __version__ = "0.0.1"
+
+ _CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
+
+ try:
+     from fireredasr2s.fireredvad.aed import FireRedAed, FireRedAedConfig
+     from fireredasr2s.fireredvad.stream_vad import FireRedStreamVad, FireRedStreamVadConfig
+     from fireredasr2s.fireredvad.vad import FireRedVad, FireRedVadConfig
+ except ImportError:
+     if _CURRENT_DIR not in sys.path:
+         sys.path.insert(0, _CURRENT_DIR)
+     from .aed import FireRedAed, FireRedAedConfig
+     from .stream_vad import FireRedStreamVad, FireRedStreamVadConfig
+     from .vad import FireRedVad, FireRedVadConfig
+
+
+ def non_stream_vad(wav_path, model_dir="pretrained_models/FireRedVAD/VAD", **kwargs):
+     """Quick VAD inference"""
+     config = FireRedVadConfig(**kwargs)
+     vad = FireRedVad.from_pretrained(model_dir, config)
+     result, probs = vad.detect(wav_path)
+     return result
+
+
+ def stream_vad_full(wav_path, model_dir="pretrained_models/FireRedVAD/Stream-VAD", **kwargs):
+     """Quick Stream VAD inference"""
+     config = FireRedStreamVadConfig(**kwargs)
+     svad = FireRedStreamVad.from_pretrained(model_dir, config)
+     frame_results, result = svad.detect_full(wav_path)
+     return frame_results, result
+
+
+ def non_stream_aed(wav_path, model_dir="pretrained_models/FireRedVAD/AED", **kwargs):
+     """Quick AED inference"""
+     config = FireRedAedConfig(**kwargs)
+     aed = FireRedAed.from_pretrained(model_dir, config)
+     result, probs = aed.detect(wav_path)
+     return result
+
+
+ __all__ = [
+     '__version__',
+     'FireRedVad',
+     'FireRedVadConfig',
+     'FireRedAed',
+     'FireRedAedConfig',
+     'FireRedStreamVad',
+     'FireRedStreamVadConfig',
+     'non_stream_vad',
+     'stream_vad_full',
+     'non_stream_aed'
+ ]
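
A minimal usage sketch of the three convenience wrappers above (assuming the pretrained weights sit under the default pretrained_models/FireRedVAD/ layout and fireredasr2s is importable):

    from fireredasr2s.fireredvad import non_stream_vad, stream_vad_full, non_stream_aed

    vad_result = non_stream_vad("test.wav")             # offline VAD result dict
    frames, svad_result = stream_vad_full("test.wav")   # per-frame results + summary
    aed_result = non_stream_aed("test.wav", music_threshold=0.6)

Extra keyword arguments are forwarded to the matching Config dataclass, so any config field (e.g. music_threshold of FireRedAedConfig) can be overridden per call.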
fireredasr2s/fireredvad/aed.py ADDED
@@ -0,0 +1,109 @@
+ # Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Kai Huang)
+
+ import logging
+ import os
+ from dataclasses import dataclass
+
+ import torch
+
+ from .core.audio_feat import AudioFeat
+ from .core.detect_model import DetectModel
+ from .core.vad_postprocessor import VadPostprocessor
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class FireRedAedConfig:
+     use_gpu: bool = False
+     smooth_window_size: int = 5
+     speech_threshold: float = 0.4
+     singing_threshold: float = 0.5
+     music_threshold: float = 0.5
+     min_event_frame: int = 20
+     max_event_frame: int = 2000  # 20s
+     min_silence_frame: int = 20
+     merge_silence_frame: int = 0
+     extend_speech_frame: int = 0
+     chunk_max_frame: int = 30000  # 300s
+
+
+ class FireRedAed:
+     IDX2EVENT = {0: "speech", 1: "singing", 2: "music"}
+
+     @classmethod
+     def from_pretrained(cls, model_dir, config=FireRedAedConfig()):
+         # Build feat extractor
+         cmvn_path = os.path.join(model_dir, "cmvn.ark")
+         audio_feat = AudioFeat(cmvn_path)
+
+         # Build model
+         model = DetectModel.from_pretrained(model_dir)
+         if config.use_gpu:
+             model.cuda()
+         else:
+             model.cpu()
+
+         # Build one postprocessor per event type
+         event2postprocessor = {}
+         for event in cls.IDX2EVENT.values():
+             threshold = getattr(config, f"{event}_threshold")
+             event2postprocessor[event] = VadPostprocessor(
+                 config.smooth_window_size,
+                 threshold,
+                 config.min_event_frame,
+                 config.max_event_frame,
+                 config.min_silence_frame,
+                 config.merge_silence_frame,
+                 config.extend_speech_frame)
+         return cls(audio_feat, model, event2postprocessor, config)
+
+     def __init__(self, audio_feat, model, event2postprocessor, config):
+         self.audio_feat = audio_feat
+         self.model = model
+         self.event2postprocessor = event2postprocessor
+         self.config = config
+
+     def detect(self, audio):
+         # Extract feat
+         feat, dur = self.audio_feat.extract(audio)
+         if self.config.use_gpu:
+             feat = feat.cuda()
+
+         # Model inference
+         if feat.size(0) <= self.config.chunk_max_frame:
+             probs, _ = self.model.forward(feat.unsqueeze(0))
+             assert probs.size(-1) == len(self.IDX2EVENT)
+             probs = probs.cpu().squeeze(0)  # (T,3)
+         else:
+             logger.warning(f"Too long input, split every {self.config.chunk_max_frame} frames")
+             chunk_probs = []
+             chunks = feat.split(self.config.chunk_max_frame, dim=0)
+             for chunk in chunks:
+                 chunk_prob, _ = self.model.forward(chunk.unsqueeze(0))
+                 assert chunk_prob.size(-1) == len(self.IDX2EVENT)
+                 chunk_probs.append(chunk_prob.cpu())
+             probs = torch.cat(chunk_probs, dim=1)
+             probs = probs.squeeze(0)  # (T,3)
+
+         # Prob postprocess
+         event2starts_ends_s = {}
+         event2raw_ratio = {}
+         for idx, event in self.IDX2EVENT.items():
+             threshold = getattr(self.config, f"{event}_threshold")
+             postprocessor = self.event2postprocessor[event]
+             event_probs = probs[:, idx].tolist()
+             decision = postprocessor.process(event_probs)
+             starts_ends_s = postprocessor.decision_to_segment(decision, dur)
+             event2starts_ends_s[event] = starts_ends_s
+
+             raw_ratio = (sum(int(p >= threshold) for p in event_probs) / len(event_probs)
+                          if len(event_probs) else 0)
+             event2raw_ratio[event] = round(raw_ratio, 3)
+
+         # Format result
+         result = {"dur": round(dur, 3),
+                   "event2timestamps": event2starts_ends_s,
+                   "event2ratio": event2raw_ratio}
+         if isinstance(audio, str):
+             result["wav_path"] = audio
+         return result, probs
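
For finer control than the non_stream_aed wrapper, FireRedAed can be driven directly. A sketch (the model_dir path is illustrative; it must contain cmvn.ark plus the detector weights):

    from fireredasr2s.fireredvad.aed import FireRedAed, FireRedAedConfig

    config = FireRedAedConfig(use_gpu=False, speech_threshold=0.5)
    aed = FireRedAed.from_pretrained("pretrained_models/FireRedVAD/AED", config)
    result, probs = aed.detect("test.wav")
    print(result["event2timestamps"]["speech"])  # detected speech segments in seconds
    print(probs.shape)                           # (T, 3): frame-level speech/singing/music probs

One design note: the config=FireRedAedConfig() default in from_pretrained is evaluated once at class-definition time, so the same default instance is shared across calls; pass an explicit config whenever non-default settings are needed.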
fireredasr2s/fireredvad/bin/__init__.py ADDED
File without changes
fireredasr2s/fireredvad/bin/aed.py ADDED
@@ -0,0 +1,92 @@
+ #!/usr/bin/env python3
+
+ # Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Kai Huang)
+
+ import argparse
+ import json
+ import logging
+ import time
+
+ from fireredvad.aed import FireRedAedConfig, FireRedAed
+ from fireredvad.utils.io import get_wav_info, write_event_textgrid, split_and_save_event_segment
+
+ logging.basicConfig(level=logging.INFO,
+                     format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
+ logger = logging.getLogger("fireredvad.bin.aed")
+
+
+ parser = argparse.ArgumentParser()
+ # Input
+ parser.add_argument("--wav_path", type=str)
+ parser.add_argument("--wav_paths", type=str, nargs="*")
+ parser.add_argument("--wav_scp", type=str)
+ parser.add_argument("--wav_dir", type=str)
+ # Output
+ parser.add_argument("--output", type=str, default="aed_output")
+ parser.add_argument("--write_textgrid", type=int, default=0)
+ parser.add_argument("--save_segment_dir", type=str, default="")
+ # AED Options
+ parser.add_argument("--model_dir", type=str,
+                     default="pretrained_models/FireRedVAD-AED-251104")
+ parser.add_argument("--use_gpu", type=int, default=0)
+ parser.add_argument("--smooth_window_size", type=int, default=5)
+ parser.add_argument("--speech_threshold", type=float, default=0.4)
+ parser.add_argument("--singing_threshold", type=float, default=0.5)
+ parser.add_argument("--music_threshold", type=float, default=0.5)
+ parser.add_argument("--min_event_frame", type=int, default=20)
+ parser.add_argument("--max_event_frame", type=int, default=2000)
+ parser.add_argument("--min_silence_frame", type=int, default=20)
+ parser.add_argument("--merge_silence_frame", type=int, default=0)
+ parser.add_argument("--extend_speech_frame", type=int, default=0)
+ parser.add_argument("--chunk_max_frame", type=int, default=30000)
+
+
+ def main(args):
+     logger.info("Start AED...\n")
+     wavs = get_wav_info(args)
+     fout = open(args.output, "w") if args.output else None
+
+     aed_config = FireRedAedConfig(
+         use_gpu=args.use_gpu,
+         smooth_window_size=args.smooth_window_size,
+         speech_threshold=args.speech_threshold,
+         singing_threshold=args.singing_threshold,
+         music_threshold=args.music_threshold,
+         min_event_frame=args.min_event_frame,
+         max_event_frame=args.max_event_frame,
+         min_silence_frame=args.min_silence_frame,
+         merge_silence_frame=args.merge_silence_frame,
+         extend_speech_frame=args.extend_speech_frame,
+         chunk_max_frame=args.chunk_max_frame)
+     logger.info(f"{aed_config}")
+     aed = FireRedAed.from_pretrained(args.model_dir, aed_config)
+
+     for i, (uttid, wav_path) in enumerate(wavs):
+         logger.info("")
+         logger.info(f">>> {i} Processing {wav_path}")
+         start_time = time.time()
+
+         result, probs = aed.detect(wav_path)
+
+         elapsed = time.time() - start_time
+         dur = result["dur"]
+         rtf = elapsed / dur if dur > 0 else 0
+         logger.info(f"Result: {result}")
+         logger.info(f"Dur={dur} elapsed(ms)={round(elapsed*1000, 2)} RTF={round(rtf, 5)}")
+
+         if fout:
+             fout.write(f"{json.dumps(result, ensure_ascii=False)}\n")
+         if args.write_textgrid:
+             write_event_textgrid(result["wav_path"], result["dur"], result["event2timestamps"])
+         if args.save_segment_dir:
+             split_and_save_event_segment(wav_path, result["event2timestamps"], args.save_segment_dir)
+     if fout: fout.close()
+
+     logger.info("All AED Done")
+
+
+ if __name__ == "__main__":
+     args = parser.parse_args()
+     logger.info(f"{args}")
+     main(args)
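
A typical batch invocation of this script (illustrative; run with the fireredasr2s/ directory on PYTHONPATH so the fireredvad imports resolve):

    PYTHONPATH=fireredasr2s python fireredasr2s/fireredvad/bin/aed.py \
        --wav_scp wav.scp --output aed_output --write_textgrid 1

Each processed file appends one JSON object per line to --output, so the result file can be consumed as JSONL.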
fireredasr2s/fireredvad/bin/fireredvad_cli.py ADDED
@@ -0,0 +1,41 @@
+ #!/usr/bin/env python3
+
+ # Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Kai Huang)
+
+ import argparse
+ import logging
+
+ logging.basicConfig(level=logging.INFO,
+                     format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
+ logger = logging.getLogger("fireredvad")
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="FireRedVAD: VAD & AED")
+     parser.add_argument("--task", type=str, required=True,
+                         choices=["vad", "stream_vad", "aed"],
+                         help="Task type: vad, stream_vad, or aed")
+     parser.add_argument("--wav_path", type=str, required=True)
+     parser.add_argument("--model_dir", type=str, default=None)
+     parser.add_argument("--use_gpu", type=int, default=0)
+
+     args, unknown = parser.parse_known_args()
+
+     if args.task == "vad":
+         from fireredvad import non_stream_vad
+         model_dir = args.model_dir or "pretrained_models/FireRedVAD/VAD"
+         result = non_stream_vad(args.wav_path, model_dir=model_dir, use_gpu=args.use_gpu)
+     elif args.task == "stream_vad":
+         from fireredvad import stream_vad_full
+         model_dir = args.model_dir or "pretrained_models/FireRedVAD/Stream-VAD"
+         result = stream_vad_full(args.wav_path, model_dir=model_dir, use_gpu=args.use_gpu)
+     elif args.task == "aed":
+         from fireredvad import non_stream_aed
+         model_dir = args.model_dir or "pretrained_models/FireRedVAD/AED"
+         result = non_stream_aed(args.wav_path, model_dir=model_dir, use_gpu=args.use_gpu)
+
+     logger.info(f"Result: {result}")
+
+
+ if __name__ == "__main__":
+     main()
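
Example invocations of the unified CLI (same PYTHONPATH caveat as above; --model_dir falls back to the per-task defaults shown in the code when omitted):

    python fireredasr2s/fireredvad/bin/fireredvad_cli.py --task vad --wav_path test.wav
    python fireredasr2s/fireredvad/bin/fireredvad_cli.py --task aed --wav_path test.wav --use_gpu 1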
fireredasr2s/fireredvad/bin/stream_vad.py ADDED
@@ -0,0 +1,172 @@
+ #!/usr/bin/env python3
+
+ # Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Kai Huang)
+
+ import argparse
+ import json
+ import logging
+
+ import soundfile as sf
+
+ from fireredvad.core.constants import SAMPLE_RATE, FRAME_LENGTH_SAMPLE, FRAME_SHIFT_SAMPLE
+ from fireredvad.stream_vad import FireRedStreamVadConfig, FireRedStreamVad
+ from fireredvad.utils.io import get_wav_info, write_textgrid, split_and_save_segment, timeit
+
+ logging.basicConfig(level=logging.INFO,
+                     format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
+ logger = logging.getLogger("fireredvad.bin.stream_vad")
+
+
+ parser = argparse.ArgumentParser()
+ # Input
+ parser.add_argument("--wav_path", type=str)
+ parser.add_argument("--wav_paths", type=str, nargs="*")
+ parser.add_argument("--wav_scp", type=str)
+ parser.add_argument("--wav_dir", type=str)
+ # Output
+ parser.add_argument("--output", type=str, default="vad_output")
+ parser.add_argument("--write_textgrid", type=int, default=0)
+ parser.add_argument("--save_segment_dir", type=str, default="")
+ # VAD Options
+ parser.add_argument("--model_dir", type=str,
+                     default="pretrained_models/FireRedVAD-VAD-stream-251104")
+ parser.add_argument("--stream_vad_mode", type=str, default="all",
+                     choices=["framewise", "chunkwise", "full", "all"])
+ parser.add_argument("--stream_chunk_frame", type=int, default=10)
+ # VAD Config
+ parser.add_argument("--use_gpu", type=int, default=0)
+ parser.add_argument("--smooth_window_size", type=int, default=5)
+ parser.add_argument("--speech_threshold", type=float, default=0.3)
+ parser.add_argument("--pad_start_frame", type=int, default=5)
+ parser.add_argument("--min_speech_frame", type=int, default=8)
+ parser.add_argument("--max_speech_frame", type=int, default=2000)
+ parser.add_argument("--min_silence_frame", type=int, default=20)
+ parser.add_argument("--chunk_max_frame", type=int, default=30000)
+
+
+ def main(args):
+     logger.info("Start Stream VAD...\n")
+     wavs = get_wav_info(args)
+     fout = open(args.output, "w") if args.output else None
+
+     vad_config = FireRedStreamVadConfig(
+         use_gpu=args.use_gpu,
+         smooth_window_size=args.smooth_window_size,
+         speech_threshold=args.speech_threshold,
+         pad_start_frame=args.pad_start_frame,
+         min_speech_frame=args.min_speech_frame,
+         max_speech_frame=args.max_speech_frame,
+         min_silence_frame=args.min_silence_frame,
+         chunk_max_frame=args.chunk_max_frame)
+     logger.info(f"{vad_config}")
+     stream_vad = FireRedStreamVad.from_pretrained(args.model_dir, vad_config)
+
+     for i, (uttid, wav_path) in enumerate(wavs):
+         logger.info("")
+         logger.info(f">>> {i} Processing {wav_path}")
+
+         if args.stream_vad_mode in ["all", "full"]:
+             results, timestamps, dur = vad_full(wav_path, stream_vad, args)
+
+         if args.stream_vad_mode in ["all", "chunkwise"]:
+             results, timestamps, dur = vad_chunkwise(wav_path, stream_vad, args)
+
+         if args.stream_vad_mode in ["all", "framewise"]:
+             results, timestamps, dur = vad_framewise(wav_path, stream_vad, args)
+
+         if fout:
+             d = {"uttid": uttid, "wav_path": wav_path, "dur": dur, "timestamps": timestamps}
+             fout.write(f"{json.dumps(d, ensure_ascii=False)}\n")
+         if args.write_textgrid:
+             write_textgrid(wav_path, dur, timestamps)
+         if args.save_segment_dir:
+             split_and_save_segment(wav_path, timestamps, args.save_segment_dir)
+     if fout: fout.close()
+
+     logger.info("All Stream VAD Done")
+
+
+ @timeit
+ def vad_framewise(wav_path, stream_vad, args):
+     logger.info("Stream VAD Mode: framewise")
+
+     wav_np, sr = sf.read(wav_path, dtype="int16")
+     assert sr == SAMPLE_RATE
+     n_frame = 0
+     frame_results = []
+     stream_vad.reset()
+     for j in range(0, len(wav_np) - FRAME_LENGTH_SAMPLE + 1, FRAME_SHIFT_SAMPLE):
+         audio_frame = wav_np[j:j+FRAME_LENGTH_SAMPLE]
+         result = stream_vad.detect_frame(audio_frame)
+         n_frame += 1
+         logger.debug(f"{n_frame:4d} {result}")
+         if result.is_speech_start:
+             logger.info(f"Speech start {result.speech_start_frame}")
+         elif result.is_speech_end:
+             logger.info(f"Speech end {result.speech_end_frame}")
+         frame_results.append(result)
+
+     logger.info(f"#frame={len(frame_results)}")
+     timestamps = stream_vad.results_to_timestamps(frame_results)
+     logger.info(f"timestamps(seconds): {timestamps}")
+     dur = len(wav_np) / sr
+     return frame_results, timestamps, dur
+
+
+ @timeit
+ def vad_chunkwise(wav_path, stream_vad, args):
+     logger.info(f"Stream VAD Mode: chunkwise {args.stream_chunk_frame}")
+     N = args.stream_chunk_frame
+     assert N > 0
+     chunk_length = FRAME_LENGTH_SAMPLE + (N-1)*FRAME_SHIFT_SAMPLE
+     chunk_shift = N * FRAME_SHIFT_SAMPLE
+
+     wav_np, sr = sf.read(wav_path, dtype="int16")
+     assert sr == SAMPLE_RATE
+     n_frame = 0
+     chunk_results = []
+     stream_vad.reset()
+     for j in range(0, len(wav_np), chunk_shift):
+         audio_chunk = wav_np[j:j+chunk_length]
+         results = stream_vad.detect_chunk(audio_chunk)
+         for result in results:
+             n_frame += 1
+             logger.debug(f"{n_frame:4d} {result}")
+             if result.is_speech_start:
+                 logger.info(f"Speech start {result.speech_start_frame}")
+             elif result.is_speech_end:
+                 logger.info(f"Speech end {result.speech_end_frame}")
+             chunk_results.append(result)
+
+     logger.info(f"#frame={len(chunk_results)}")
+     timestamps = stream_vad.results_to_timestamps(chunk_results)
+     logger.info(f"timestamps(seconds): {timestamps}")
+     dur = len(wav_np) / sr
+     return chunk_results, timestamps, dur
+
+
+ @timeit
+ def vad_full(wav_path, stream_vad, args):
+     logger.info("Stream VAD Mode: full")
+     frame_results, result = stream_vad.detect_full(wav_path)
+     logger.info(f"Result: {result}")
+     timestamps = result["timestamps"]
+     dur = result["dur"]
+
+     n_frame = 0
+     for frame_result in frame_results:
+         n_frame += 1
+         logger.debug(f"{n_frame:4d} {frame_result}")
+         if frame_result.is_speech_start:
+             logger.info(f"Speech start {frame_result.speech_start_frame}")
+         elif frame_result.is_speech_end:
+             logger.info(f"Speech end {frame_result.speech_end_frame}")
+     logger.info(f"#frame={len(frame_results)}")
+
+     return frame_results, timestamps, dur
+
+
+ if __name__ == "__main__":
+     args = parser.parse_args()
+     logger.info(f"{args}")
+     main(args)
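
The framewise loop in vad_framewise generalizes directly to live audio: feed one frame per hop and react to is_speech_start / is_speech_end as they fire. A condensed sketch (the core.constants module path is inferred from the imports above; the output format of results_to_timestamps follows the "timestamps(seconds)" log):

    import soundfile as sf
    from fireredasr2s.fireredvad import FireRedStreamVad, FireRedStreamVadConfig
    from fireredasr2s.fireredvad.core.constants import (
        SAMPLE_RATE, FRAME_LENGTH_SAMPLE, FRAME_SHIFT_SAMPLE)

    svad = FireRedStreamVad.from_pretrained(
        "pretrained_models/FireRedVAD/Stream-VAD", FireRedStreamVadConfig())
    wav, sr = sf.read("test.wav", dtype="int16")
    assert sr == SAMPLE_RATE
    svad.reset()  # clear streaming state before a new session
    results = []
    for j in range(0, len(wav) - FRAME_LENGTH_SAMPLE + 1, FRAME_SHIFT_SAMPLE):
        results.append(svad.detect_frame(wav[j:j + FRAME_LENGTH_SAMPLE]))
    print(svad.results_to_timestamps(results))  # speech segments in seconds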