diff --git a/README.md b/README.md
index 08014db4c83d13a0b564ef48b9ff46e8f740653b..bcf5cabb0f6d9b4c3f17303a6cd9e128bdd71657 100644
--- a/README.md
+++ b/README.md
@@ -12,4 +12,514 @@ license: apache-2.0
short_description: A SOTA Industrial-Grade All-in-One ASR system
---
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+
+# FireRedASR2S
+
+**A SOTA Industrial-Grade All-in-One ASR System**
+
+[[Paper]](https://arxiv.org/pdf/2603.10420)
+[[Model🤗]](https://huggingface.co/collections/FireRedTeam/fireredasr2s)
+[[Model🤖]](https://www.modelscope.cn/collections/xukaituo/FireRedASR2S)
+[[Demo]](https://huggingface.co/spaces/FireRedTeam/FireRedASR)
+
+
+FireRedASR2S is a state-of-the-art (SOTA), industrial-grade, all-in-one ASR system with ASR, VAD, LID, and Punc modules. All modules achieve SOTA performance:
+- **FireRedASR2**: Automatic Speech Recognition (ASR) supporting speech and singing transcription for Chinese (Mandarin, 20+ dialects/accents), English, and code-switching. 2.89% average CER on 4 public Mandarin benchmarks and 11.55% on 19 Chinese dialect/accent benchmarks, **outperforming Doubao-ASR, Qwen3-ASR-1.7B, Fun-ASR, and Fun-ASR-Nano-2512**. FireRedASR2-AED also supports word-level timestamps and confidence scores.
+- **FireRedVAD**: Voice Activity Detection (VAD) supporting speech/singing/music in 100+ languages. 97.57% F1, **outperforming Silero-VAD, TEN-VAD, FunASR-VAD, and WebRTC-VAD**. Supports non-streaming/streaming VAD and Multi-label VAD (mVAD).
+- **FireRedLID**: Spoken Language Identification (LID) supporting 100+ languages and 20+ Chinese dialects/accents. 97.18% accuracy, **outperforming Whisper and SpeechBrain**.
+- **FireRedPunc**: Punctuation Prediction (Punc) for Chinese and English. 78.90% average F1, outperforming FunASR-Punc (62.77%).
+
+*`2S`: `2`nd-generation FireRedASR, now expanded to an all-in-one ASR `S`ystem*
+
+
+## 🔥 News
+- [2026.03.12] 🔥 We release the FireRedASR2S technical report. See [arXiv](https://arxiv.org/abs/2603.10420).
+- [2026.03.05] 🚀 [vLLM](https://github.com/vllm-project/vllm/pull/35727) supports FireRedASR2-LLM. See the [vLLM Usage](https://github.com/FireRedTeam/FireRedASR2S?tab=readme-ov-file#vllm-usage) section.
+- [2026.02.25] 🔥 We release **FireRedASR2-LLM model weights**. [🤗](https://huggingface.co/FireRedTeam/FireRedASR2-LLM) [🤖](https://www.modelscope.cn/models/xukaituo/FireRedASR2-LLM/)
+- [2026.02.13] 🚀 Support TensorRT-LLM inference acceleration for FireRedASR2-AED (contributed by NVIDIA). A benchmark on the AISHELL-1 test set shows a **12.7x speedup** over the PyTorch baseline (single H20).
+- [2026.02.12] 🔥 We release FireRedASR2S (FireRedASR2-AED, FireRedVAD, FireRedLID, and FireRedPunc) with **model weights and inference code**. Download links below. Technical report and finetuning code coming soon.
+
+
+
+## Available Models and Languages
+
+|Model|Supported Languages & Dialects|Download|
+|:-------------:|:---------------------------------:|:----------:|
+|FireRedASR2-LLM| Chinese (Mandarin and 20+ dialects/accents*), English, Code-Switching | [🤗](https://huggingface.co/FireRedTeam/FireRedASR2-LLM) \| [🤖](https://www.modelscope.cn/models/xukaituo/FireRedASR2-LLM/)|
+|FireRedASR2-AED| Chinese (Mandarin and 20+ dialects/accents*), English, Code-Switching | [🤗](https://huggingface.co/FireRedTeam/FireRedASR2-AED) \| [🤖](https://www.modelscope.cn/models/xukaituo/FireRedASR2-AED/)|
+|FireRedVAD | 100+ languages, 20+ Chinese dialects/accents* | [🤗](https://huggingface.co/FireRedTeam/FireRedVAD) \| [🤖](https://www.modelscope.cn/models/xukaituo/FireRedVAD/)|
+|FireRedLID | 100+ languages, 20+ Chinese dialects/accents* | [🤗](https://huggingface.co/FireRedTeam/FireRedLID) \| [🤖](https://www.modelscope.cn/models/xukaituo/FireRedLID/)|
+|FireRedPunc| Chinese, English | [🤗](https://huggingface.co/FireRedTeam/FireRedPunc) \| [🤖](https://www.modelscope.cn/models/xukaituo/FireRedPunc/)|
+
+*Supported Chinese dialects/accents: Cantonese (Hong Kong & Guangdong), Sichuan, Shanghai, Wu, Minnan, Anhui, Fujian, Gansu, Guizhou, Hebei, Henan, Hubei, Hunan, Jiangxi, Liaoning, Ningxia, Shaanxi, Shanxi, Shandong, Tianjin, Yunnan, etc.
+
+
+
+## Method
+### FireRedASR2S: System Overview
+
+
+### FireRedASR2
+FireRedASR2 builds upon [FireRedASR](https://github.com/FireRedTeam/FireRedASR) with improved accuracy and is designed to meet diverse requirements for superior performance and optimal efficiency across applications. It comprises two variants:
+- **FireRedASR2-LLM**: Designed to achieve state-of-the-art performance and to enable seamless end-to-end speech interaction. It adopts an Encoder-Adapter-LLM framework leveraging large language model (LLM) capabilities (see the sketch below).
+- **FireRedASR2-AED**: Designed to balance high performance and computational efficiency and to serve as an effective speech representation module in LLM-based speech models. It utilizes an Attention-based Encoder-Decoder (AED) architecture.
+
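+For intuition, here is a minimal, simplified sketch of the Encoder-Adapter-LLM data flow, based on the `Adapter` module in this repo; the tensor shapes, downsample rate, and LLM dimension below are illustrative assumptions:
+
+```python
+import torch
+
+B, T, D = 4, 100, 1280          # illustrative: batch, encoder frames, encoder dim
+ds, llm_dim = 2, 3584           # illustrative: adapter downsample rate, LLM hidden size
+speech = torch.randn(B, T, D)   # conformer encoder output
+
+# Adapter: stack `ds` consecutive frames, then project into the LLM embedding space
+T = T - (T % ds)
+stacked = speech[:, :T].reshape(B, T // ds, D * ds)
+proj = torch.nn.Sequential(
+    torch.nn.Linear(D * ds, llm_dim), torch.nn.ReLU(), torch.nn.Linear(llm_dim, llm_dim))
+speech_embeds = proj(stacked)   # (B, T//ds, llm_dim)
+```
+
+The projected features are then spliced into the LLM's input embeddings in place of a placeholder speech token (see `_merge_input_ids_with_speech_features` in `fireredasr_llm.py`), and the LLM decodes the transcript.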
+
+
+### Other Modules
+- **FireRedVAD**: DFSMN-based non-streaming/streaming Voice Activity Detection and Multi-label VAD (mVAD). mVAD can be viewed as a lightweight Audio Event Detection (AED) system specialized for a small set of sound categories (speech/singing/music).
+- **FireRedLID**: Encoder-Decoder-based Spoken Language Identification. See [FireRedLID README](./fireredasr2s/fireredlid/README.md) for language details.
+- **FireRedPunc**: BERT-based Punctuation Prediction.
+
+
+## Quick Start
+### Setup
+1. Create a clean Python environment:
+```bash
+$ conda create --name fireredasr2s python=3.10
+$ conda activate fireredasr2s
+$ git clone https://github.com/FireRedTeam/FireRedASR2S.git
+$ cd FireRedASR2S # or fireredasr2s
+```
+
+2. Install dependencies and set up PATH and PYTHONPATH:
+```bash
+$ pip install -r requirements.txt
+$ export PATH=$PWD/fireredasr2s/:$PATH
+$ export PYTHONPATH=$PWD/:$PYTHONPATH
+```
+
+3. Download models:
+```bash
+# Download via ModelScope (recommended for users in China)
+pip install -U modelscope
+modelscope download --model xukaituo/FireRedASR2-AED --local_dir ./pretrained_models/FireRedASR2-AED
+modelscope download --model xukaituo/FireRedVAD --local_dir ./pretrained_models/FireRedVAD
+modelscope download --model xukaituo/FireRedLID --local_dir ./pretrained_models/FireRedLID
+modelscope download --model xukaituo/FireRedPunc --local_dir ./pretrained_models/FireRedPunc
+modelscope download --model xukaituo/FireRedASR2-LLM --local_dir ./pretrained_models/FireRedASR2-LLM
+
+# Download via Hugging Face
+pip install -U "huggingface_hub[cli]"
+huggingface-cli download FireRedTeam/FireRedASR2-AED --local-dir ./pretrained_models/FireRedASR2-AED
+huggingface-cli download FireRedTeam/FireRedVAD --local-dir ./pretrained_models/FireRedVAD
+huggingface-cli download FireRedTeam/FireRedLID --local-dir ./pretrained_models/FireRedLID
+huggingface-cli download FireRedTeam/FireRedPunc --local-dir ./pretrained_models/FireRedPunc
+huggingface-cli download FireRedTeam/FireRedASR2-LLM --local-dir ./pretrained_models/FireRedASR2-LLM
+```
+
+4. Convert your audio to **16kHz 16-bit mono PCM** format if needed:
+```bash
+$ ffmpeg -i input_audio -ar 16000 -ac 1 -acodec pcm_s16le -f wav output.wav
+```
+
+### Script Usage
+```bash
+$ cd examples_infer/asr_system
+$ bash inference_asr_system.sh
+```
+
+### Command-line Usage
+```bash
+$ fireredasr2s-cli --help
+$ fireredasr2s-cli --wav_paths "assets/hello_zh.wav" "assets/hello_en.wav" --outdir output
+$ cat output/result.jsonl
+# {"uttid": "hello_zh", "text": "你好世界。", "sentences": [{"start_ms": 310, "end_ms": 1840, "text": "你好世界。", "asr_confidence": 0.875, "lang": "zh mandarin", "lang_confidence": 0.999}], "vad_segments_ms": [[310, 1840]], "dur_s": 2.32, "words": [{"start_ms": 490, "end_ms": 690, "text": "你"}, {"start_ms": 690, "end_ms": 1090, "text": "好"}, {"start_ms": 1090, "end_ms": 1330, "text": "世"}, {"start_ms": 1330, "end_ms": 1795, "text": "界"}], "wav_path": "assets/hello_zh.wav"}
+# {"uttid": "hello_en", "text": "Hello speech.", "sentences": [{"start_ms": 120, "end_ms": 1840, "text": "Hello speech.", "asr_confidence": 0.833, "lang": "en", "lang_confidence": 0.998}], "vad_segments_ms": [[120, 1840]], "dur_s": 2.24, "words": [{"start_ms": 340, "end_ms": 1020, "text": "hello"}, {"start_ms": 1020, "end_ms": 1666, "text": "speech"}], "wav_path": "assets/hello_en.wav"}
+```
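+
+Each line of `result.jsonl` is a standalone JSON object, so downstream processing is straightforward. A minimal parsing sketch, using the field names shown above:
+
+```python
+import json
+
+with open("output/result.jsonl") as f:
+    for line in f:
+        utt = json.loads(line)
+        for sent in utt["sentences"]:
+            print(f'[{sent["start_ms"]}-{sent["end_ms"]}ms] ({sent["lang"]}) {sent["text"]}')
+```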
+
+### Python API Usage
+```python
+from fireredasr2s import FireRedAsr2System, FireRedAsr2SystemConfig
+
+asr_system_config = FireRedAsr2SystemConfig() # Use default config
+asr_system = FireRedAsr2System(asr_system_config)
+
+result = asr_system.process("assets/hello_zh.wav")
+print(result)
+# {'uttid': 'tmpid', 'text': '你好世界。', 'sentences': [{'start_ms': 440, 'end_ms': 1820, 'text': '你好世界。', 'asr_confidence': 0.868, 'lang': 'zh mandarin', 'lang_confidence': 0.999}], 'vad_segments_ms': [(440, 1820)], 'dur_s': 2.32, 'words': [], 'wav_path': 'assets/hello_zh.wav'}
+
+result = asr_system.process("assets/hello_en.wav")
+print(result)
+# {'uttid': 'tmpid', 'text': 'Hello speech.', 'sentences': [{'start_ms': 260, 'end_ms': 1820, 'text': 'Hello speech.', 'asr_confidence': 0.933, 'lang': 'en', 'lang_confidence': 0.993}], 'vad_segments_ms': [(260, 1820)], 'dur_s': 2.24, 'words': [], 'wav_path': 'assets/hello_en.wav'}
+```
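+
+Note: with the default config, `words` is returned empty (as in the outputs above); to get word-level timestamps, pass an ASR config with `return_timestamp=True`, as in the full ASR System example below.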
+
+
+
+## Usage of Each Module
+The four components under `fireredasr2s`, i.e. `fireredasr2`, `fireredvad`, `fireredlid`, and `fireredpunc`, are self-contained and designed to work as standalone modules. You can use any of them independently without depending on the others. `FireRedVAD` and `FireRedLID` will also be open-sourced as standalone libraries in separate repositories.
+
+### Script Usage
+```bash
+# ASR
+$ cd examples_infer/asr
+$ bash inference_asr_aed.sh
+$ bash inference_asr_llm.sh
+
+# VAD & mVAD (mVAD=Audio Event Detection, AED)
+$ cd examples_infer/vad
+$ bash inference_vad.sh
+$ bash inference_streamvad.sh
+$ bash inference_aed.sh
+
+# LID
+$ cd examples_infer/lid
+$ bash inference_lid.sh
+
+# Punc
+$ cd examples_infer/punc
+$ bash inference_punc.sh
+```
+
+### vLLM Usage
+```shell
+# Serve FireRedASR2-LLM with the latest vLLM for the highest performance.
+# For more details, see https://github.com/vllm-project/vllm/pull/35727.
+$ vllm serve allendou/FireRedASR2-LLM-vllm -tp=2 --dtype=float32
+$ python3 examples/online_serving/openai_transcription_client.py --repetition_penalty=1.0 --audio_path=/root/hello_zh.wav
+```
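+
+Here `-tp=2` sets vLLM's tensor-parallel size, sharding the model across two GPUs; adjust it to match your hardware.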
+
+### Python API Usage
+Set up `PYTHONPATH` first: `export PYTHONPATH=$PWD/:$PYTHONPATH`
+
+#### ASR
+```python
+from fireredasr2s.fireredasr2 import FireRedAsr2, FireRedAsr2Config
+
+batch_uttid = ["hello_zh", "hello_en"]
+batch_wav_path = ["assets/hello_zh.wav", "assets/hello_en.wav"]
+
+# FireRedASR2-AED
+asr_config = FireRedAsr2Config(
+ use_gpu=True,
+ use_half=False,
+ beam_size=3,
+ nbest=1,
+ decode_max_len=0,
+ softmax_smoothing=1.25,
+ aed_length_penalty=0.6,
+ eos_penalty=1.0,
+ return_timestamp=True
+)
+model = FireRedAsr2.from_pretrained("aed", "pretrained_models/FireRedASR2-AED", asr_config)
+results = model.transcribe(batch_uttid, batch_wav_path)
+print(results)
+# [{'uttid': 'hello_zh', 'text': '你好世界', 'confidence': 0.971, 'dur_s': 2.32, 'rtf': '0.0870', 'wav': 'assets/hello_zh.wav', 'timestamp': [('你', 0.42, 0.66), ('好', 0.66, 1.1), ('世', 1.1, 1.34), ('界', 1.34, 2.039)]}, {'uttid': 'hello_en', 'text': 'hello speech', 'confidence': 0.943, 'dur_s': 2.24, 'rtf': '0.0870', 'wav': 'assets/hello_en.wav', 'timestamp': [('hello', 0.34, 0.98), ('speech', 0.98, 1.766)]}]
+
+# FireRedASR2-LLM
+asr_config = FireRedAsr2Config(
+ use_gpu=True,
+ decode_min_len=0,
+ repetition_penalty=1.0,
+ llm_length_penalty=0.0,
+ temperature=1.0
+)
+model = FireRedAsr2.from_pretrained("llm", "pretrained_models/FireRedASR2-LLM", asr_config)
+results = model.transcribe(batch_uttid, batch_wav_path)
+print(results)
+# [{'uttid': 'hello_zh', 'text': '你好世界', 'rtf': '0.0681', 'wav': 'assets/hello_zh.wav'}, {'uttid': 'hello_en', 'text': 'hello speech', 'rtf': '0.0681', 'wav': 'assets/hello_en.wav'}]
+```
+
+
+#### VAD
+```python
+from fireredasr2s.fireredvad import FireRedVad, FireRedVadConfig
+
+vad_config = FireRedVadConfig(
+ use_gpu=False,
+ smooth_window_size=5,
+ speech_threshold=0.4,
+ min_speech_frame=20,
+ max_speech_frame=2000,
+ min_silence_frame=20,
+ merge_silence_frame=0,
+ extend_speech_frame=0,
+ chunk_max_frame=30000)
+vad = FireRedVad.from_pretrained("pretrained_models/FireRedVAD/VAD", vad_config)
+
+result, probs = vad.detect("assets/hello_zh.wav")
+
+print(result)
+# {'dur': 2.32, 'timestamps': [(0.44, 1.82)], 'wav_path': 'assets/hello_zh.wav'}
+```
+
+
+#### Stream VAD
+
+
+```python
+from fireredasr2s.fireredvad import FireRedStreamVad, FireRedStreamVadConfig
+
+vad_config = FireRedStreamVadConfig(
+ use_gpu=False,
+ smooth_window_size=5,
+ speech_threshold=0.4,
+ pad_start_frame=5,
+ min_speech_frame=8,
+ max_speech_frame=2000,
+ min_silence_frame=20,
+ chunk_max_frame=30000)
+stream_vad = FireRedStreamVad.from_pretrained("pretrained_models/FireRedVAD/Stream-VAD", vad_config)
+
+frame_results, result = stream_vad.detect_full("assets/hello_zh.wav")
+
+print(result)
+# {'dur': 2.32, 'timestamps': [(0.46, 1.84)], 'wav_path': 'assets/hello_zh.wav'}
+```
+
+
+
+#### mVAD (Audio Event Detection, AED)
+
+
+```python
+from fireredasr2s.fireredvad import FireRedAed, FireRedAedConfig
+
+aed_config = FireRedAedConfig(
+ use_gpu=False,
+ smooth_window_size=5,
+ speech_threshold=0.4,
+ singing_threshold=0.5,
+ music_threshold=0.5,
+ min_event_frame=20,
+ max_event_frame=2000,
+ min_silence_frame=20,
+ merge_silence_frame=0,
+ extend_speech_frame=0,
+ chunk_max_frame=30000)
+aed = FireRedAed.from_pretrained("pretrained_models/FireRedVAD/AED", aed_config)
+
+result, probs = aed.detect("assets/event.wav")
+
+print(result)
+# {'dur': 22.016, 'event2timestamps': {'speech': [(0.4, 3.56), (3.66, 9.08), (9.27, 9.77), (10.78, 21.76)], 'singing': [(1.79, 19.96), (19.97, 22.016)], 'music': [(0.09, 12.32), (12.33, 22.016)]}, 'event2ratio': {'speech': 0.848, 'singing': 0.905, 'music': 0.991}, 'wav_path': 'assets/event.wav'}
+```
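+
+Given the `event2ratio` field above, picking the dominant event type is a one-liner:
+
+```python
+dominant = max(result["event2ratio"], key=result["event2ratio"].get)  # "music" here
+```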
+
+
+
+#### LID
+
+
+```python
+from fireredasr2s.fireredlid import FireRedLid, FireRedLidConfig
+
+batch_uttid = ["hello_zh", "hello_en"]
+batch_wav_path = ["assets/hello_zh.wav", "assets/hello_en.wav"]
+
+config = FireRedLidConfig(use_gpu=True, use_half=False)
+model = FireRedLid.from_pretrained("pretrained_models/FireRedLID", config)
+
+results = model.process(batch_uttid, batch_wav_path)
+print(results)
+# [{'uttid': 'hello_zh', 'lang': 'zh mandarin', 'confidence': 0.996, 'dur_s': 2.32, 'rtf': '0.0741', 'wav': 'assets/hello_zh.wav'}, {'uttid': 'hello_en', 'lang': 'en', 'confidence': 0.996, 'dur_s': 2.24, 'rtf': '0.0741', 'wav': 'assets/hello_en.wav'}]
+```
+
+
+
+#### Punc
+
+
+```python
+from fireredasr2s.fireredpunc.punc import FireRedPunc, FireRedPuncConfig
+
+config = FireRedPuncConfig(use_gpu=True)
+model = FireRedPunc.from_pretrained("pretrained_models/FireRedPunc", config)
+
+batch_text = ["你好世界", "Hello world"]
+results = model.process(batch_text)
+
+print(results)
+# [{'punc_text': '你好世界。', 'origin_text': '你好世界'}, {'punc_text': 'Hello world!', 'origin_text': 'Hello world'}]
+```
+
+
+
+#### ASR System
+```python
+from fireredasr2s.fireredasr2 import FireRedAsr2Config
+from fireredasr2s.fireredlid import FireRedLidConfig
+from fireredasr2s.fireredpunc import FireRedPuncConfig
+from fireredasr2s.fireredvad import FireRedVadConfig
+from fireredasr2s import FireRedAsr2System, FireRedAsr2SystemConfig
+
+vad_config = FireRedVadConfig(
+ use_gpu=False,
+ smooth_window_size=5,
+ speech_threshold=0.4,
+ min_speech_frame=20,
+ max_speech_frame=2000,
+ min_silence_frame=20,
+ merge_silence_frame=0,
+ extend_speech_frame=0,
+ chunk_max_frame=30000
+)
+lid_config = FireRedLidConfig(use_gpu=True, use_half=False)
+asr_config = FireRedAsr2Config(
+ use_gpu=True,
+ use_half=False,
+ beam_size=3,
+ nbest=1,
+ decode_max_len=0,
+ softmax_smoothing=1.25,
+ aed_length_penalty=0.6,
+ eos_penalty=1.0,
+ return_timestamp=True
+)
+punc_config = FireRedPuncConfig(use_gpu=True)
+
+asr_system_config = FireRedAsr2SystemConfig(
+ "pretrained_models/FireRedVAD/VAD",
+ "pretrained_models/FireRedLID",
+ "aed", "pretrained_models/FireRedASR2-AED",
+ "pretrained_models/FireRedPunc",
+ vad_config, lid_config, asr_config, punc_config,
+ enable_vad=1, enable_lid=1, enable_punc=1
+)
+asr_system = FireRedAsr2System(asr_system_config)
+
+batch_uttid = ["hello_zh", "hello_en"]
+batch_wav_path = ["assets/hello_zh.wav", "assets/hello_en.wav"]
+for wav_path, uttid in zip(batch_wav_path, batch_uttid):
+ result = asr_system.process(wav_path, uttid)
+ print(result)
+# {'uttid': 'hello_zh', 'text': '你好世界。', 'sentences': [{'start_ms': 440, 'end_ms': 1820, 'text': '你好世界。', 'asr_confidence': 0.868, 'lang': 'zh mandarin', 'lang_confidence': 0.999}], 'vad_segments_ms': [(440, 1820)], 'dur_s': 2.32, 'words': [{'start_ms': 540, 'end_ms': 700, 'text': '你'}, {'start_ms': 700, 'end_ms': 1100, 'text': '好'}, {'start_ms': 1100, 'end_ms': 1300, 'text': '世'}, {'start_ms': 1300, 'end_ms': 1765, 'text': '界'}], 'wav_path': 'assets/hello_zh.wav'}
+# {'uttid': 'hello_en', 'text': 'Hello speech.', 'sentences': [{'start_ms': 260, 'end_ms': 1820, 'text': 'Hello speech.', 'asr_confidence': 0.933, 'lang': 'en', 'lang_confidence': 0.993}], 'vad_segments_ms': [(260, 1820)], 'dur_s': 2.24, 'words': [{'start_ms': 400, 'end_ms': 960, 'text': 'hello'}, {'start_ms': 960, 'end_ms': 1666, 'text': 'speech'}], 'wav_path': 'assets/hello_en.wav'}
+```
+
+**Note:** `FireRedASR2S` has only been tested on Ubuntu 22.04; behavior on other Linux distributions or on Windows is unverified.
+
+
+## FAQ
+**Q: What audio format is supported?**
+
+16kHz 16-bit mono PCM wav. Use ffmpeg to convert other formats: `ffmpeg -i input_audio -ar 16000 -ac 1 -acodec pcm_s16le -f wav output.wav`
+
+**Q: What are the input length limitations of ASR models?**
+
+- **FireRedASR2-AED** supports audio input **up to 60s**. Input longer than 60s may cause hallucination issues, and input exceeding 200s will trigger positional encoding errors.
+- **FireRedASR2-LLM** supports audio input **up to 40s**. The behavior for longer input is untested.
+- **FireRedASR2-LLM Batch Beam Search**: When performing batch beam search with FireRedASR2-LLM, even though attention masks are applied, it is recommended to ensure that the input lengths of the utterances are similar. If there are significant differences in utterance lengths, shorter utterances may experience **repetition issues**. You can either sort your dataset by length (a minimal sketch follows) or set `batch_size` to 1 to avoid the repetition issue.
+
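+A minimal length-sorting sketch, using a hypothetical `sort_by_duration` helper over 16kHz PCM wav inputs:
+
+```python
+import wave
+
+def sort_by_duration(wav_paths):
+    """Sort wav files by duration so each batch holds similar-length utterances."""
+    def duration_s(path):
+        with wave.open(path) as w:
+            return w.getnframes() / w.getframerate()
+    return sorted(wav_paths, key=duration_s)
+```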
+
+
+## Evaluation
+### FireRedASR2
+Metrics: Character Error Rate (CER%) for Chinese and Word Error Rate (WER%) for English. Lower is better.
+
+We evaluate FireRedASR2 on 24 public test sets covering Mandarin, 20+ Chinese dialects/accents, and singing.
+
+- **Mandarin (4 test sets)**: 2.89% (LLM) / 3.05% (AED) average CER, outperforming Doubao-ASR (3.69%), Qwen3-ASR-1.7B (3.76%), Fun-ASR (4.16%) and Fun-ASR-Nano-2512 (4.55%).
+- **Dialects (19 test sets)**: 11.55% (LLM) / 11.67% (AED) average CER, outperforming Doubao-ASR (15.39%), Qwen3-ASR-1.7B (11.85%), Fun-ASR (12.76%) and Fun-ASR-Nano-2512 (15.07%).
+
+*Note: FRASR2=FireRedASR2, ws=WenetSpeech, md=MagicData, conv=Conversational, daily=Daily-use.*
+
+|ID|Testset (CER↓)\Model|FRASR2-LLM|FRASR2-AED|Doubao-ASR|Qwen3-ASR|Fun-ASR|
+|:--:|:-----:|:-----:|:-----:|:-----:|:-----:|:-----:|
+|Avg|**All(1-24)** |**9.67** |**9.80** |12.98 |10.12 |10.92 |
+|Avg|**Mandarin(1-4)** |**2.89** |**3.05** |3.69 |3.76 |4.16 |
+|Avg|**Dialect(5-23)** |**11.55**|**11.67**|15.39|11.85|12.76|
+|1 |aishell1 |0.64 |0.57 |1.52 |1.48 |1.64 |
+|2 |aishell2 |2.15 |2.51 |2.77 |2.71 |2.38 |
+|3 |ws-net |4.44 |4.57 |5.73 |4.97 |6.85 |
+|4 |ws-meeting |4.32 |4.53 |4.74 |5.88 |5.78 |
+|5 |kespeech |3.08 |3.60 |5.38 |5.10 |5.36 |
+|6 |ws-yue-short |5.14 |5.15 |10.51|5.82 |7.34 |
+|7 |ws-yue-long |8.71 |8.54 |11.39|8.85 |10.14|
+|8 |ws-chuan-easy |10.90|10.60|11.33|11.99|12.46|
+|9 |ws-chuan-hard |20.71|21.35|20.77|21.63|22.49|
+|10|md-heavy |7.42 |7.43 |7.69 |8.02 |9.13 |
+|11|md-yue-conv |12.23|11.66|26.25|9.76 |33.71|
+|12|md-yue-daily |3.61 |3.35 |12.82|3.66 |2.69 |
+|13|md-yue-vehicle |4.50 |4.83 |8.66 |4.28 |6.00 |
+|14|md-chuan-conv |13.18|13.07|11.77|14.35|14.01|
+|15|md-chuan-daily |4.90 |5.17 |3.90 |4.93 |3.98 |
+|16|md-shanghai-conv |28.70|27.02|45.15|29.77|25.49|
+|17|md-shanghai-daily |24.94|24.18|44.06|23.93|12.55|
+|18|md-wu |7.15 |7.14 |7.70 |7.57 |10.63|
+|19|md-zhengzhou-conv |10.20|10.65|9.83 |9.55 |10.85|
+|20|md-zhengzhou-daily|5.80 |6.26 |5.77 |5.88 |6.29 |
+|21|md-wuhan |9.60 |10.81|9.94 |10.22|4.34 |
+|22|md-tianjin |15.45|15.30|15.79|16.16|19.27|
+|23|md-changsha |23.18|25.64|23.76|23.70|25.66|
+|24|opencpop |1.12 |1.17 |4.36 |2.57 |3.05 |
+
+
+### FireRedVAD
+
+We evaluate FireRedVAD on FLEURS-VAD-102, a multilingual VAD benchmark covering 102 languages.
+
+FireRedVAD achieves SOTA performance, outperforming Silero-VAD, TEN-VAD, FunASR-VAD, and WebRTC-VAD.
+
+|Metric\Model|FireRedVAD|Silero-VAD|TEN-VAD|FunASR-VAD|WebRTC-VAD|
+|:-------:|:-----:|:------:|:------:|:------:|:------:|
+|AUC-ROC↑ |**99.60**|97.99|97.81|- |- |
+|F1 score↑ |**97.57**|95.95|95.19|90.91|52.30|
+|False Alarm Rate↓ |**2.69** |9.41 |15.47|44.03|2.83 |
+|Miss Rate↓|3.62 |3.95 |2.95 |0.42 |64.15|
+
+FLEURS-VAD-102: We randomly selected ~100 audio files per language from the [FLEURS test set](https://huggingface.co/datasets/google/fleurs), resulting in 9,443 audio files with manually annotated binary VAD labels (speech=1, silence=0). This VAD testset will be open-sourced soon.
+
+Note: FunASR-VAD achieves low Miss Rate but at the cost of high False Alarm Rate (44.03%), indicating over-prediction of speech segments.
+
+
+
+### FireRedLID
+
+Metric: Utterance-level LID Accuracy (%). Higher is better.
+
+We evaluate FireRedLID on multilingual and Chinese dialect benchmarks.
+
+FireRedLID achieves SOTA performance, outperforming Whisper, SpeechBrain-LID, and Dolphin.
+
+|Testset\Model|Languages|FireRedLID|Whisper|SpeechBrain|Dolphin|
+|:-----------------:|:---------:|:---------:|:-----:|:---------:|:-----:|
+|FLEURS test |82 languages |**97.18** |79.41 |92.91 |-|
+|CommonVoice test |74 languages |**92.07** |80.81 |78.75 |-|
+|KeSpeech + MagicData|20+ Chinese dialects/accents |**88.47** |-|-|69.01|
+
+
+
+### FireRedPunc
+
+Metric: Precision/Recall/F1 Score (%). Higher is better.
+
+We evaluate FireRedPunc on multi-domain Chinese and English benchmarks.
+
+FireRedPunc achieves SOTA performance, outperforming FunASR-Punc (CT-Transformer).
+
+|Testset\Model|#Sentences|FireRedPunc|FunASR-Punc|
+|:------------------:|:---------:|:--------------:|:-----------------:|
+|Multi-domain Chinese| 88,644 |**82.84 / 83.08 / 82.96** | 77.27 / 74.03 / 75.62 |
+|Multi-domain English| 28,641 |**78.40 / 71.57 / 74.83** | 55.79 / 45.15 / 49.91 |
+|Average F1 Score | - |**78.90** | 62.77 |
+
+
+
+
+## Acknowledgements
+Thanks to the following open-source works:
+- [Qwen](https://huggingface.co/Qwen)
+- [WenetSpeech-Yue](https://github.com/ASLP-lab/WenetSpeech-Yue)
+- [WenetSpeech-Chuan](https://github.com/ASLP-lab/WenetSpeech-Chuan)
+
+
+## Citation
+```bibtex
+@article{xu2026fireredasr2s,
+ title={FireRedASR2S: A State-of-the-Art Industrial-Grade All-in-One Automatic Speech Recognition System},
+ author={Xu, Kaituo and Jia, Yan and Huang, Kai and Chen, Junjie and Li, Wenpeng and Liu, Kun and Xie, Feng-Long and Tang, Xu and Hu, Yao},
+ journal={arXiv preprint arXiv:2603.10420},
+ year={2026}
+}
+```
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..1be95b802c23906eb2ea7de7da8ee7dbeb45e88a
--- /dev/null
+++ b/app.py
@@ -0,0 +1,111 @@
+import sys
+
+import gradio as gr
+import spaces
+from huggingface_hub import snapshot_download
+
+sys.path.append("./fireredasr2s")
+from fireredasr2s import FireRedAsr2System, FireRedAsr2SystemConfig
+from fireredasr2s import FireRedAsr2, FireRedAsr2Config
+
+
+asr_model_aed = None
+asr_model_llm = None
+
+
+def init_model(model_dir_aed, model_dir_llm):
+ global asr_model_aed
+ global asr_model_llm
+ if asr_model_aed is None:
+ asr_config_aed = FireRedAsr2Config(
+ use_gpu=True,
+ use_half=False,
+ beam_size=3,
+ nbest=1,
+ decode_max_len=0,
+ softmax_smoothing=1.25,
+ aed_length_penalty=0.6,
+ eos_penalty=1.0,
+ return_timestamp=True
+ )
+ asr_model_aed = FireRedAsr2.from_pretrained("aed", model_dir_aed, asr_config_aed)
+ if asr_model_llm is None:
+ asr_config_llm = FireRedAsr2Config(
+ use_gpu=True,
+ decode_min_len=0,
+ repetition_penalty=3.0,
+ llm_length_penalty=1.0,
+ temperature=1.0
+ )
+ asr_model_llm = FireRedAsr2.from_pretrained("llm", model_dir_llm, asr_config_llm)
+
+
+@spaces.GPU(duration=20)
+def asr_inference(audio_file):
+ if not audio_file:
+ return "Please upload a wav file"
+ batch_uttid = ["demo"]
+ batch_wav_path = [audio_file]
+ results = asr_model_aed.transcribe(
+ batch_uttid,
+ batch_wav_path
+ )
+ text_output = results[0]["text"]
+ return text_output
+
+
+@spaces.GPU(duration=30)
+def asr_inference_llm(audio_file):
+ if not audio_file:
+ return "Please upload a wav file"
+ batch_uttid = ["demo"]
+ batch_wav_path = [audio_file]
+ results = asr_model_llm.transcribe(
+ batch_uttid,
+ batch_wav_path,
+ )
+ text_output = results[0]["text"]
+ return text_output
+
+
+with gr.Blocks(title="FireRedASR") as demo:
+    # NOTE: the original HTML markup was lost in formatting; a centered title is assumed
+    gr.HTML("<h1 style='text-align: center;'>FireRedASR2 Demo</h1>")
+ gr.Markdown("Upload an audio file (wav) to get speech-to-text results.")
+
+ with gr.Row():
+ with gr.Column():
+ #audio_file = gr.Audio(label="Upload Audio", sources=["upload", "microphone"], type="filepath")
+ audio_file = gr.Audio(label="Upload wav file", sources=["upload"], type="filepath")
+
+ with gr.Column():
+ asr_button = gr.Button("Start Recognition (FireRedASR2-AED-L)", variant="primary")
+ text_output = gr.Textbox(label="Model Result (FireRedASR2-AED-L)", interactive=False, lines=3, max_lines=12)
+ asr_button_llm = gr.Button("Start Recognition (FireRedASR2-LLM-L)", variant="primary")
+ text_output_llm = gr.Textbox(label="Model Result (FireRedASR2-LLM-L)", interactive=False, lines=3, max_lines=12)
+
+ asr_button.click(
+ fn=asr_inference,
+ inputs=[audio_file],
+ outputs=[text_output]
+ )
+
+ asr_button_llm.click(
+ fn=asr_inference_llm,
+ inputs=[audio_file],
+ outputs=[text_output_llm]
+ )
+
+
+if __name__ == "__main__":
+ # Download model
+    local_dir = 'pretrained_models/FireRedASR2-AED-L'
+    snapshot_download(repo_id='FireRedTeam/FireRedASR2-AED-L', local_dir=local_dir)
+    local_dir_llm = 'pretrained_models/FireRedASR2-LLM-L'
+    snapshot_download(repo_id='FireRedTeam/FireRedASR2-LLM-L', local_dir=local_dir_llm)
+ # Init model
+ init_model(local_dir, local_dir_llm)
+ # UI
+ demo.queue()
+ demo.launch()
diff --git a/fireredasr2s/__init__.py b/fireredasr2s/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a684a0614e23b51c6467d26b9cede5415666880
--- /dev/null
+++ b/fireredasr2s/__init__.py
@@ -0,0 +1,40 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Kai Huang, Yan Jia, Junjie Chen, Wenpeng Li)
+
+import os
+import sys
+import warnings
+warnings.filterwarnings('ignore')
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # 0=ALL, 1=INFO, 2=WARNING, 3=ERROR
+os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
+
+__version__ = "0.0.1"
+
+_PACKAGE_DIR = os.path.dirname(os.path.abspath(__file__))
+_PROJECT_ROOT = os.path.dirname(_PACKAGE_DIR)
+if _PROJECT_ROOT not in sys.path:
+ sys.path.insert(0, _PROJECT_ROOT)
+
+from fireredasr2s.fireredasr2system import (
+    FireRedAsr2System,
+    FireRedAsr2SystemConfig
+)
+# Import the standalone modules so every name listed in __all__ resolves
+# (e.g. `from fireredasr2s import FireRedAsr2, FireRedAsr2Config` in app.py)
+from fireredasr2s.fireredasr2 import FireRedAsr2, FireRedAsr2Config
+from fireredasr2s.fireredvad import (
+    FireRedVad, FireRedVadConfig,
+    FireRedStreamVad, FireRedStreamVadConfig,
+    FireRedAed, FireRedAedConfig
+)
+from fireredasr2s.fireredlid import FireRedLid, FireRedLidConfig
+from fireredasr2s.fireredpunc.punc import FireRedPunc, FireRedPuncConfig
+
+
+# API
+__all__ = [
+ "__version__",
+ "FireRedAsr2System",
+ "FireRedAsr2SystemConfig",
+ "FireRedAsr2",
+ "FireRedAsr2Config",
+ "FireRedVad",
+ "FireRedVadConfig",
+ "FireRedStreamVad",
+ "FireRedStreamVadConfig",
+ "FireRedAed",
+ "FireRedAedConfig",
+ "FireRedLid",
+ "FireRedLidConfig",
+ "FireRedPunc",
+ "FireRedPuncConfig",
+]
diff --git a/fireredasr2s/fireredasr2/__init__.py b/fireredasr2s/fireredasr2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..958874e94b6f1800511fdacfd8be8b51a8a60ed7
--- /dev/null
+++ b/fireredasr2s/fireredasr2/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu)
+
+import os
+import sys
+import warnings
+warnings.filterwarnings('ignore')
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # 0=ALL, 1=INFO, 2=WARNING, 3=ERROR
+os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
+
+__version__ = "0.0.1"
+
+_CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
+
+try:
+ from fireredasr2s.fireredasr2.asr import FireRedAsr2, FireRedAsr2Config
+except ImportError:
+ if _CURRENT_DIR not in sys.path:
+ sys.path.insert(0, _CURRENT_DIR)
+ from .asr import FireRedAsr2, FireRedAsr2Config
+
+
+# API
+__all__ = [
+ "__version__",
+ "FireRedAsr2",
+ "FireRedAsr2Config",
+]
diff --git a/fireredasr2s/fireredasr2/asr.py b/fireredasr2s/fireredasr2/asr.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb90a344475e2c32d1eb075a2011d25fb8464429
--- /dev/null
+++ b/fireredasr2s/fireredasr2/asr.py
@@ -0,0 +1,241 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu)
+
+import logging
+import os
+import re
+import time
+import traceback
+from dataclasses import dataclass
+
+import torch
+
+from .data.asr_feat import ASRFeatExtractor
+from .models.fireredasr_aed import FireRedAsrAed
+from .models.fireredasr_llm import FireRedAsrLlm
+from .models.lstm_lm import LstmLm
+from .models.param import count_model_parameters
+from .tokenizer.aed_tokenizer import ChineseCharEnglishSpmTokenizer
+from .tokenizer.llm_tokenizer import LlmTokenizerWrapper
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class FireRedAsr2Config:
+ use_gpu: bool = True
+ use_half: bool = False
+ beam_size: int = 3
+ nbest: int = 1
+ decode_max_len: int = 0
+ softmax_smoothing: float = 1.25
+ aed_length_penalty: float = 0.6
+ eos_penalty: float = 1.0
+ return_timestamp: bool = False
+    decode_min_len: int = 0
+    repetition_penalty: float = 1.0
+    llm_length_penalty: float = 0.0
+    temperature: float = 1.0
+    elm_dir: str = ""
+    elm_weight: float = 0.0
+
+
+class FireRedAsr2:
+ @classmethod
+    def from_pretrained(cls, asr_type, model_dir, config=None):
+        # Avoid sharing a mutable default config instance across calls
+        config = config if config is not None else FireRedAsr2Config()
+        assert asr_type in ["aed", "llm"]
+
+ cmvn_path = os.path.join(model_dir, "cmvn.ark")
+ feat_extractor = ASRFeatExtractor(cmvn_path)
+
+ if asr_type == "aed":
+ model_path = os.path.join(model_dir, "model.pth.tar")
+            dict_path = os.path.join(model_dir, "dict.txt")
+ spm_model = os.path.join(model_dir, "train_bpe1000.model")
+ model = load_fireredasr_aed_model(model_path)
+ tokenizer = ChineseCharEnglishSpmTokenizer(dict_path, spm_model)
+ elif asr_type == "llm":
+ model_path = os.path.join(model_dir, "model.pth.tar")
+ encoder_path = os.path.join(model_dir, "asr_encoder.pth.tar")
+ llm_dir = os.path.join(model_dir, "Qwen2-7B-Instruct")
+ model, tokenizer = load_firered_llm_model_and_tokenizer(
+ model_path, encoder_path, llm_dir)
+ elm = None
+ if config.elm_dir:
+ assert os.path.exists(config.elm_dir), f"{config.elm_dir}"
+ model_path = os.path.join(config.elm_dir, "model.pth.tar")
+ elm = load_lstm_lm(model_path)
+ elm.eval()
+ logger.info(elm)
+ count_model_parameters(model)
+ model.eval()
+ return cls(asr_type, feat_extractor, model, tokenizer, elm, config)
+
+ def __init__(self, asr_type, feat_extractor, model, tokenizer, elm, config):
+ self.asr_type = asr_type
+ self.feat_extractor = feat_extractor
+ self.model = model
+ self.tokenizer = tokenizer
+ self.elm = elm
+ self.config = config
+ logger.info(self.config)
+ if self.config.use_gpu:
+ if self.config.use_half:
+ self.model.half()
+ self.model.cuda()
+ if self.elm:
+ self.elm.cuda()
+ else:
+ self.model.cpu()
+
+ @torch.no_grad()
+ def transcribe(self, batch_uttid, batch_wav_path):
+ batch_uttid_origin = batch_uttid
+ try:
+ feats, lengths, durs, batch_wav_path, batch_uttid = \
+ self.feat_extractor(batch_wav_path, batch_uttid)
+ if feats is None:
+ return [{"uttid": uttid, "text":""} for uttid in batch_uttid_origin]
+        except Exception:
+ traceback.print_exc()
+ return [{"uttid": uttid, "text":""} for uttid in batch_uttid_origin]
+ total_dur = sum(durs)
+ if self.config.use_gpu:
+ feats, lengths = feats.cuda(), lengths.cuda()
+ if self.config.use_half:
+ feats = feats.half()
+
+ if self.asr_type == "aed":
+ start_time = time.time()
+
+ try:
+ hyps = self.model.transcribe(
+ feats, lengths,
+ self.config.beam_size,
+ self.config.nbest,
+ self.config.decode_max_len,
+ self.config.softmax_smoothing,
+ self.config.aed_length_penalty,
+ self.config.eos_penalty,
+ self.config.return_timestamp,
+ self.elm,
+ self.config.elm_weight
+ )
+ except Exception as e:
+ traceback.print_exc()
+ hyps = []
+
+ elapsed = time.time() - start_time
+            rtf = elapsed / total_dur if total_dur > 0 else 0
+
+ results = []
+ for uttid, wav, hyp, dur in zip(batch_uttid, batch_wav_path, hyps, durs):
+ hyp = hyp[0] # only return 1-best
+ hyp_ids = [int(id) for id in hyp["yseq"].cpu()]
+ text = self.tokenizer.detokenize(hyp_ids)
+                # Strip special tokens; the original pattern was garbled in formatting,
+                # "<unk>"/"<pad>" are assumed here
+                text = re.sub(r"(<unk>)|(<pad>)", "", text)
+ results.append({"uttid": uttid, "text": text.lower(),
+ "confidence": round(hyp["confidence"].cpu().item(), 3),
+ "dur_s": round(dur, 3), "rtf": f"{rtf:.4f}"})
+ if type(wav) == str:
+ results[-1]["wav"] = wav
+ if self.config.return_timestamp:
+ results[-1]["timestamp"] = self._get_and_fix_timestamp(hyp, hyp_ids, dur)
+ return results
+
+ elif self.asr_type == "llm":
+ input_ids, attention_mask, _, _ = \
+ LlmTokenizerWrapper.preprocess_texts(
+ origin_texts=[""]*feats.size(0), tokenizer=self.tokenizer,
+ max_len=128, decode=True)
+ if self.config.use_gpu:
+ input_ids = input_ids.cuda()
+ attention_mask = attention_mask.cuda()
+ start_time = time.time()
+
+ try:
+ generated_ids = self.model.transcribe(
+ feats, lengths, input_ids, attention_mask,
+ self.config.beam_size,
+ self.config.decode_max_len,
+ self.config.decode_min_len,
+ self.config.repetition_penalty,
+ self.config.llm_length_penalty,
+ self.config.temperature
+ )
+ texts = self.tokenizer.batch_decode(generated_ids,
+ skip_special_tokens=True)
+            except Exception:
+                traceback.print_exc()
+                texts = []
+
+ elapsed = time.time() - start_time
+            rtf = elapsed / total_dur if total_dur > 0 else 0
+ results = []
+ for uttid, wav, text in zip(batch_uttid, batch_wav_path, texts):
+ results.append({"uttid": uttid, "text": text.lower(),
+ "rtf": f"{rtf:.4f}"})
+ if type(wav) == str:
+ results[-1]["wav"] = wav
+ return results
+
+    def _get_and_fix_timestamp(self, hyp, hyp_ids, dur):
+        """Return [token, start_s, end_s] triples. Falls back to uniform
+        segmentation when no CTC alignment is available, then clamps
+        degenerate endpoints at the utterance boundary."""
+        r3 = lambda x: round(x, 3)
+ if "timestamp" not in hyp or hyp["timestamp"] is None:
+ timestamp = []
+ avg_dur = dur / len(hyp_ids) if len(hyp_ids) > 0 else 0
+ last_end = dur
+ for i, hyp_id in enumerate(hyp_ids):
+ token = self.tokenizer.detokenize([hyp_id], "", False)
+ start = min(max(0, i*avg_dur), last_end)
+ end = min((i+1)*avg_dur, dur)
+ last_end = end
+ timestamp.append([token.lower(), r3(start), r3(end)])
+ else:
+ starts, ends = hyp["timestamp"]
+ timestamp = []
+ last_end = dur
+            SHIFT = 0.06 # shift timestamps left by 60ms
+ for hyp_id, start, end in zip(hyp_ids, starts, ends):
+ token = self.tokenizer.detokenize([hyp_id], "", False)
+ start = min(max(0, start - SHIFT), last_end)
+ end = min(max(0, end - SHIFT), dur)
+ last_end = end
+ timestamp.append([token.lower(), r3(start), r3(end)])
+ # Fix case: start == dur and end == dur
+ for i in range(len(timestamp)):
+ idx = -(i+1)
+ _, start, end = timestamp[idx]
+ if abs(dur - start) < 0.001:
+ logger.info(f"start before {timestamp[idx]}")
+ timestamp[idx][1] = dur - (i+1)*0.001
+ logger.info(f"start after {timestamp[idx]}")
+ if i != 0 and abs(dur - end) < 0.001:
+ logger.info(f"end before {timestamp[idx]}")
+ timestamp[idx][2] = dur - i*0.001
+ logger.info(f"end after {timestamp[idx]}")
+ timestamp = self.tokenizer.merge_spm_timestamp(timestamp)
+ return timestamp
+
+
+def load_fireredasr_aed_model(model_path):
+ package = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=False)
+ model = FireRedAsrAed.from_args(package["args"])
+ model.load_state_dict(package["model_state_dict"], strict=False)
+ return model
+
+
+def load_firered_llm_model_and_tokenizer(model_path, encoder_path, llm_dir):
+ package = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=False)
+ package["args"].encoder_path = encoder_path
+ package["args"].llm_dir = llm_dir
+ model = FireRedAsrLlm.from_args(package["args"])
+ model.load_state_dict(package["model_state_dict"], strict=False)
+ tokenizer = LlmTokenizerWrapper.build_llm_tokenizer(llm_dir)
+ return model, tokenizer
+
+def load_lstm_lm(model_path):
+ package = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=False)
+ model = LstmLm.from_args(package["args"])
+ model.load_state_dict(package["model_state_dict"], strict=False)
+ return model
diff --git a/fireredasr2s/fireredasr2/data/asr_feat.py b/fireredasr2s/fireredasr2/data/asr_feat.py
new file mode 100644
index 0000000000000000000000000000000000000000..506ef182483e74da586ce6a69ee870c14fffe21b
--- /dev/null
+++ b/fireredasr2s/fireredasr2/data/asr_feat.py
@@ -0,0 +1,124 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu)
+
+import math
+import os
+
+import kaldiio
+import kaldi_native_fbank as knf
+import numpy as np
+import torch
+
+
+class ASRFeatExtractor:
+ def __init__(self, kaldi_cmvn_file):
+ self.cmvn = CMVN(kaldi_cmvn_file) if kaldi_cmvn_file != "" else None
+ self.fbank = KaldifeatFbank(num_mel_bins=80, frame_length=25,
+ frame_shift=10, dither=0.0)
+
+ def __call__(self, wav_paths, wav_uttids):
+ feats = []
+ durs = []
+ return_wav_paths = []
+ return_wav_uttids = []
+
+ wav_datas = []
+ if isinstance(wav_paths[0], str):
+ for wav_path in wav_paths:
+ sample_rate, wav_np = kaldiio.load_mat(wav_path)
+ wav_datas.append([sample_rate, wav_np])
+ else:
+ wav_datas = wav_paths
+
+ for (sample_rate, wav_np), path, uttid in zip(wav_datas, wav_paths, wav_uttids):
+ dur = wav_np.shape[0] / sample_rate
+ fbank = self.fbank((sample_rate, wav_np))
+ if fbank.shape[0] < 1:
+ continue
+ if self.cmvn is not None:
+ fbank = self.cmvn(fbank)
+ fbank = torch.from_numpy(fbank).float()
+ feats.append(fbank)
+ durs.append(dur)
+ return_wav_paths.append(path)
+ return_wav_uttids.append(uttid)
+ if len(feats) > 0:
+ lengths = torch.tensor([feat.size(0) for feat in feats]).long()
+ feats_pad = self.pad_feat(feats, 0.0)
+ else:
+ lengths, feats_pad = None, None
+ return feats_pad, lengths, durs, return_wav_paths, return_wav_uttids
+
+ def pad_feat(self, xs, pad_value):
+ # type: (List[Tensor], int) -> Tensor
+ n_batch = len(xs)
+ max_len = max([xs[i].size(0) for i in range(n_batch)])
+ pad = torch.ones(n_batch, max_len, *xs[0].size()[1:]).to(xs[0].device).to(xs[0].dtype).fill_(pad_value)
+ for i in range(n_batch):
+ pad[i, :xs[i].size(0)] = xs[i]
+ return pad
+
+
+class CMVN:
+    def __init__(self, kaldi_cmvn_file):
+        self.dim, self.means, self.inverse_std_variances = \
+            self.read_kaldi_cmvn(kaldi_cmvn_file)
+
+    def __call__(self, x, is_train=False):
+        assert x.shape[-1] == self.dim, "CMVN dim mismatch"
+        out = x - self.means
+        out = out * self.inverse_std_variances
+        return out
+
+    def read_kaldi_cmvn(self, kaldi_cmvn_file):
+        assert os.path.exists(kaldi_cmvn_file)
+        stats = kaldiio.load_mat(kaldi_cmvn_file)
+        assert stats.shape[0] == 2
+        dim = stats.shape[-1] - 1
+        count = stats[0, dim]
+        assert count >= 1
+        floor = 1e-20
+        means = []
+        inverse_std_variances = []
+        for d in range(dim):
+            mean = stats[0, d] / count
+            means.append(mean.item())
+            variance = (stats[1, d] / count) - mean*mean
+            if variance < floor:
+                variance = floor
+            istd = 1.0 / math.sqrt(variance)
+            inverse_std_variances.append(istd)
+        return dim, np.array(means), np.array(inverse_std_variances)
+
+
+
+class KaldifeatFbank:
+ def __init__(self, num_mel_bins=80, frame_length=25, frame_shift=10,
+ dither=1.0):
+ self.dither = dither
+ opts = knf.FbankOptions()
+ opts.frame_opts.dither = dither
+ opts.mel_opts.num_bins = num_mel_bins
+ opts.frame_opts.snip_edges = True
+ opts.mel_opts.debug_mel = False
+ self.opts = opts
+
+    def __call__(self, wav, is_train=False):
+        if type(wav) is str:
+            sample_rate, wav_np = kaldiio.load_mat(wav)
+        elif type(wav) in [tuple, list] and len(wav) == 2:
+            sample_rate, wav_np = wav
+        else:
+            raise TypeError("wav must be a path or a (sample_rate, ndarray) pair")
+        assert len(wav_np.shape) == 1
+
+ dither = self.dither if is_train else 0.0
+ self.opts.frame_opts.dither = dither
+ fbank = knf.OnlineFbank(self.opts)
+
+ fbank.accept_waveform(sample_rate, wav_np.tolist())
+ feat = []
+ for i in range(fbank.num_frames_ready):
+ feat.append(fbank.get_frame(i))
+ if len(feat) == 0:
+ print("Check data, len(feat) == 0", wav, flush=True)
+ return np.zeros((0, self.opts.mel_opts.num_bins))
+ feat = np.vstack(feat)
+ return feat
diff --git a/fireredasr2s/fireredasr2/data/token_dict.py b/fireredasr2s/fireredasr2/data/token_dict.py
new file mode 100644
index 0000000000000000000000000000000000000000..31745d383e31a4f643c423d1b1b50368033f9297
--- /dev/null
+++ b/fireredasr2s/fireredasr2/data/token_dict.py
@@ -0,0 +1,63 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu)
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class TokenDict:
+ def __init__(self, dict_path, unk=""):
+ assert dict_path != ""
+ self.id2word, self.word2id = self.read_dict(dict_path)
+ self.unk = unk
+ assert unk == "" or unk in self.word2id
+ self.unkid = self.word2id[unk] if unk else -1
+
+ def get(self, key, default):
+ if type(default) == str:
+ default = self.word2id[default]
+ return self.word2id.get(key, default)
+
+ def __getitem__(self, key):
+ if type(key) == str:
+ if self.unk:
+ return self.word2id.get(key, self.word2id[self.unk])
+ else:
+ return self.word2id[key]
+ elif type(key) == int:
+ return self.id2word[key]
+ else:
+ raise TypeError("Key should be str or int")
+
+ def __len__(self):
+ return len(self.id2word)
+
+ def __contains__(self, query):
+ if type(query) == str:
+ return query in self.word2id
+ elif type(query) == int:
+ return query in self.id2word
+ else:
+ raise TypeError("query should be str or int")
+
+ def read_dict(self, dict_path):
+ id2word, word2id = [], {}
+ with open(dict_path, encoding='utf8') as f:
+ for i, line in enumerate(f):
+ tokens = line.strip().split()
+ if len(tokens) >= 2:
+ word, index = tokens[0], int(tokens[1])
+ elif len(tokens) == 1:
+ word, index = tokens[0], i
+ else: # empty line or space
+ logger.info(f"Find empty line or space '{line.strip()}' in {dict_path}:L{i}, set to ' '")
+ word, index = " ", i
+ assert len(id2word) == index
+ assert len(word2id) == index
+                if word == "<space>":  # the special-token literal was lost in formatting; "<space>" assumed
+                    logger.info(f"NOTE: Found <space> in {dict_path}:L{i} and converted it to ' '")
+                    word = " "
+ word2id[word] = index
+ id2word.append(word)
+ assert len(id2word) == len(word2id)
+ return id2word, word2id
diff --git a/fireredasr2s/fireredasr2/models/fireredasr_aed.py b/fireredasr2s/fireredasr2/models/fireredasr_aed.py
new file mode 100644
index 0000000000000000000000000000000000000000..490d8a32dfca6c8f09c4488fe7cee171def5f72f
--- /dev/null
+++ b/fireredasr2s/fireredasr2/models/fireredasr_aed.py
@@ -0,0 +1,74 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu)
+
+import logging
+import traceback
+
+import torch
+import torchaudio
+
+from .module.conformer_encoder import ConformerEncoder
+from .module.ctc import CTC
+from .module.transformer_decoder import TransformerDecoder
+
+logger = logging.getLogger(__name__)
+
+
+class FireRedAsrAed(torch.nn.Module):
+ @classmethod
+ def from_args(cls, args):
+ return cls(args)
+
+ def __init__(self, args):
+ super().__init__()
+ self.sos_id = args.sos_id
+ self.eos_id = args.eos_id
+
+ self.encoder = ConformerEncoder(
+ args.idim, args.n_layers_enc, args.n_head, args.d_model,
+ args.residual_dropout, args.dropout_rate,
+ args.kernel_size, args.pe_maxlen)
+
+ self.decoder = TransformerDecoder(
+ args.sos_id, args.eos_id, args.pad_id, args.odim,
+ args.n_layers_dec, args.n_head, args.d_model,
+ args.residual_dropout, args.pe_maxlen)
+
+ self.ctc = CTC(args.odim, args.d_model)
+
+ def transcribe(self, padded_input, input_lengths,
+ beam_size=1, nbest=1, decode_max_len=0,
+ softmax_smoothing=1.0, length_penalty=0.0, eos_penalty=1.0,
+ return_timestamp=False, elm=None, elm_weight=0.0):
+ enc_outputs, enc_lengths, enc_mask = self.encoder(padded_input, input_lengths)
+ nbest_hyps = self.decoder.batch_beam_search(
+ enc_outputs, enc_mask,
+ beam_size, nbest, decode_max_len,
+ softmax_smoothing, length_penalty, eos_penalty,
+ elm, elm_weight)
+ if return_timestamp:
+ nbest_hyps = self.get_token_timestamp_torchaudio(enc_outputs, enc_lengths, nbest_hyps)
+ return nbest_hyps
+
+ def get_token_timestamp_torchaudio(self, enc_outputs, enc_lengths, nbest_hyps):
+ ctc_logits = self.ctc(enc_outputs)
+ for n in range(enc_outputs.size(0)):
+ try:
+ n_ctc_logits = ctc_logits[n, :enc_lengths[n]]
+ y = nbest_hyps[n][0]["yseq"]
+ y = y[y!=0] # 0 is blank
+ if y.numel() == 0 or n_ctc_logits.size()[0] == 0:
+ logger.debug("skip null output")
+ nbest_hyps[n][0]["timestamp"] = None
+ continue
+ elif y.numel() > n_ctc_logits.size()[0]:
+ nbest_hyps[n][0]["timestamp"] = None
+ continue
+
+ alignment, _ = torchaudio.functional.forced_align(
+ n_ctc_logits.unsqueeze(0), y.unsqueeze(0), blank=0)
+ alignment = alignment[0].cpu().tolist()
+ start_times, end_times = self.ctc.ctc_alignment_to_timestamp(alignment,
+ self.encoder.input_preprocessor.subsampling, blank_id=0)
+ nbest_hyps[n][0]["timestamp"] = (start_times, end_times)
+            except Exception:
+ traceback.print_exc()
+ nbest_hyps[n][0]["timestamp"] = None
+ return nbest_hyps
diff --git a/fireredasr2s/fireredasr2/models/fireredasr_llm.py b/fireredasr2s/fireredasr2/models/fireredasr_llm.py
new file mode 100644
index 0000000000000000000000000000000000000000..76f270c8c7212da3c49e03662e9a3eee2cabd35d
--- /dev/null
+++ b/fireredasr2s/fireredasr2/models/fireredasr_llm.py
@@ -0,0 +1,297 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu)
+
+import logging
+import os
+import random
+import re
+
+import torch
+import torch.nn as nn
+from transformers import AutoModelForCausalLM
+
+from ..models.fireredasr_aed import FireRedAsrAed
+from ..models.module.adapter import Adapter
+from ..models.param import count_model_parameters
+from ..tokenizer.llm_tokenizer import DEFAULT_SPEECH_TOKEN, IGNORE_TOKEN_ID
+from ..tokenizer.llm_tokenizer import LlmTokenizerWrapper
+
+
+class FireRedAsrLlm(nn.Module):
+ @classmethod
+ def load_encoder(cls, model_path):
+ assert os.path.exists(model_path)
+        package = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=False)
+ model = FireRedAsrAed.from_args(package["args"])
+ if "model_state_dict" in package:
+ model.load_state_dict(package["model_state_dict"], strict=False)
+ encoder = model.encoder
+ encoder_dim = encoder.odim
+ return encoder, encoder_dim
+
+ @classmethod
+ def from_args(cls, args):
+ logging.info(args)
+ logging.info("Build FireRedAsrLlm")
+ # Build Speech Encoder
+ encoder, encoder_dim = cls.load_encoder(args.encoder_path)
+ count_model_parameters(encoder)
+ if args.freeze_encoder:
+ logging.info(f"Frezee encoder")
+ for name, param in encoder.named_parameters():
+ param.requires_grad = False
+ encoder.eval()
+
+        # Training uses torch.bfloat16
+        attn_implementation = "flash_attention_2" if args.use_flash_attn else "eager"
+        torch_dtype = torch.bfloat16 if args.use_fp16 else torch.float32
+
+ # Build LLM
+ llm = AutoModelForCausalLM.from_pretrained(
+ args.llm_dir,
+ attn_implementation=attn_implementation,
+ torch_dtype=torch_dtype,
+ )
+ count_model_parameters(llm)
+
+ # LLM Freeze or LoRA
+ llm_dim = llm.config.hidden_size
+ if args.freeze_llm:
+ logging.info(f"Frezee LLM")
+ for name, param in llm.named_parameters():
+ param.requires_grad = False
+ llm.eval()
+ else:
+ if args.use_lora:
+ from peft import LoraConfig, get_peft_model
+ lora_config = LoraConfig(
+ r=64,
+ lora_alpha=16,
+ target_modules=[
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ "o_proj",
+ "up_proj",
+ "gate_proj",
+ "down_proj",
+ ],
+ lora_dropout=0.05,
+ task_type="CAUSAL_LM",
+ )
+ llm = get_peft_model(llm, lora_config)
+ llm.print_trainable_parameters()
+
+ tokenizer = LlmTokenizerWrapper.build_llm_tokenizer(args.llm_dir)
+ assert tokenizer.pad_token_id == tokenizer.convert_tokens_to_ids("<|endoftext|>")
+ llm.config.pad_token_id = tokenizer.pad_token_id
+ llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
+ llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
+ llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
+ DEFAULT_SPEECH_TOKEN
+ )
+
+ # Build projector
+ encoder_projector = Adapter(
+ encoder_dim, llm_dim, args.encoder_downsample_rate)
+ count_model_parameters(encoder_projector)
+
+ return cls(encoder, llm, encoder_projector,
+ args.freeze_encoder, args.freeze_llm)
+
+ def __init__(self, encoder, llm, encoder_projector,
+ freeze_encoder, freeze_llm):
+ super().__init__()
+ self.encoder = encoder
+ self.llm = llm
+ self.encoder_projector = encoder_projector
+ # args
+ self.freeze_encoder = freeze_encoder
+ self.freeze_llm = freeze_llm
+ self.llm_config = llm.config
+
+ def transcribe(self, padded_feat, feat_lengths, padded_input_ids, attention_mask,
+ beam_size=1, decode_max_len=0, decode_min_len=0,
+ repetition_penalty=1.0, llm_length_penalty=1.0, temperature=1.0):
+ encoder_outs, enc_lengths, enc_mask = self.encoder(padded_feat, feat_lengths)
+ speech_features, speech_lens = self.encoder_projector(encoder_outs, enc_lengths)
+ inputs_embeds = self.llm.get_input_embeddings()(padded_input_ids)
+
+ inputs_embeds, attention_mask, _ = \
+ self._merge_input_ids_with_speech_features(
+ speech_features.to(inputs_embeds.dtype), inputs_embeds, padded_input_ids, attention_mask,
+ speech_lens=speech_lens
+ )
+
+ max_new_tokens = speech_features.size(1) if decode_max_len < 1 else decode_max_len
+ max_new_tokens = max(1, max_new_tokens)
+
+ generated_ids = self.llm.generate(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ max_new_tokens=max_new_tokens,
+ num_beams=beam_size,
+ do_sample=False,
+ min_length=decode_min_len,
+ repetition_penalty=repetition_penalty,
+ length_penalty=llm_length_penalty,
+ temperature=temperature,
+ bos_token_id=self.llm.config.bos_token_id,
+ eos_token_id=self.llm.config.eos_token_id,
+ pad_token_id=self.llm.config.pad_token_id,
+ )
+
+ return generated_ids
+
+
+ def _merge_input_ids_with_speech_features(
+ self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None,
+ speech_lens=None
+ ):
+ """
+ Modified from: https://github.com/k2-fsa/icefall/blob/master/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py
+ """
+ speech_lens = None
+ num_speechs, speech_len, embed_dim = speech_features.shape
+ batch_size, sequence_length = input_ids.shape
+ left_padding = not torch.sum(
+ input_ids[:, -1] == torch.tensor(self.llm.config.pad_token_id)
+ )
+ # 1. Create a mask to know where special speech tokens are
+ special_speech_token_mask = input_ids == self.llm.config.default_speech_token_id
+ num_special_speech_tokens = torch.sum(special_speech_token_mask, dim=-1)
+ # Compute the maximum embed dimension
+ max_embed_dim = (
+ num_special_speech_tokens.max() * (speech_len - 1)
+ ) + sequence_length
+ batch_indices, non_speech_indices = torch.where(
+ input_ids != self.llm.config.default_speech_token_id
+ )
+
+ # 2. Compute the positions where text should be written
+ # Calculate new positions for text tokens in merged speech-text sequence.
+ # `special_speech_token_mask` identifies speech tokens. Each speech token will be replaced by `nb_text_tokens_per_speechs - 1` text tokens.
+ # `torch.cumsum` computes how each speech token shifts subsequent text token positions.
+ # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
+ new_token_positions = (
+ torch.cumsum((special_speech_token_mask * (speech_len - 1) + 1), -1) - 1
+ ) # (N,U)
+ nb_speech_pad = max_embed_dim - 1 - new_token_positions[:, -1]
+ if left_padding:
+ new_token_positions += nb_speech_pad[:, None] # offset for left padding
+ text_to_overwrite = new_token_positions[batch_indices, non_speech_indices]
+
+ # 3. Create the full embedding, already padded to the maximum position
+ final_embedding = torch.zeros(
+ batch_size,
+ max_embed_dim,
+ embed_dim,
+ dtype=inputs_embeds.dtype,
+ device=inputs_embeds.device,
+ )
+ final_attention_mask = torch.zeros(
+ batch_size,
+ max_embed_dim,
+ dtype=attention_mask.dtype,
+ device=inputs_embeds.device,
+ )
+ if labels is not None:
+ final_labels = torch.full(
+ (batch_size, max_embed_dim),
+ IGNORE_TOKEN_ID,
+ dtype=input_ids.dtype,
+ device=input_ids.device,
+ )
+        # In case the speech encoder or the language model has been offloaded to CPU, we need to manually
+        # set the corresponding tensors into their correct target device.
+ target_device = inputs_embeds.device
+ batch_indices, non_speech_indices, text_to_overwrite = (
+ batch_indices.to(target_device),
+ non_speech_indices.to(target_device),
+ text_to_overwrite.to(target_device),
+ )
+ attention_mask = attention_mask.to(target_device)
+
+        # 4. Fill the embeddings based on the mask. If we have ["hey", DEFAULT_SPEECH_TOKEN, "how", "are"]
+        # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the speech features
+ final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[
+ batch_indices, non_speech_indices
+ ]
+ final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[
+ batch_indices, non_speech_indices
+ ]
+ if labels is not None:
+ final_labels[batch_indices, text_to_overwrite] = labels[
+ batch_indices, non_speech_indices
+ ]
+
+        # 5. Fill the embeddings corresponding to the speech features. Anything that is not `text_positions` needs filling (#29835)
+ speech_to_overwrite = torch.full(
+ (batch_size, max_embed_dim),
+ True,
+ dtype=torch.bool,
+ device=inputs_embeds.device,
+ )
+ speech_to_overwrite[batch_indices, text_to_overwrite] = False
+ if speech_lens is not None:
+ speech_pad_position = speech_to_overwrite.cumsum(-1) <= speech_lens[:, None]
+ speech_to_overwrite &= speech_to_overwrite.cumsum(-1) - 1 >= nb_speech_pad[
+ :, None
+ ].to(target_device)
+ if speech_lens is not None:
+ if torch.any(speech_lens > speech_len):
+ raise ValueError(
+ f"speech_lens contains values ({speech_lens.max()}) larger than "
+ f"speech_len ({speech_len})"
+ )
+
+ speech_cumsum = speech_to_overwrite.long().cumsum(-1)
+ speech_position_counter = torch.where(speech_to_overwrite, speech_cumsum - 1, 0)
+ valid_speech_positions = speech_position_counter < speech_lens[:, None].to(target_device)
+
+ speech_to_overwrite &= valid_speech_positions
+ if speech_to_overwrite.sum().item() != int(speech_lens.sum().item()):
+ raise ValueError(
+ f"speech_lens and speech token distribution mismatch: "
+ f"expected total speech frames {speech_lens.sum().item()}, "
+ f"but got {speech_to_overwrite.sum().item()} positions."
+ )
+ batch_idx, seq_idx = torch.where(speech_to_overwrite)
+ speech_feature_idx = speech_position_counter[speech_to_overwrite]
+ final_embedding[batch_idx, seq_idx] = speech_features[batch_idx, speech_feature_idx].to(target_device)
+ else:
+ if speech_to_overwrite.sum() != speech_features.shape[:-1].numel():
+ raise ValueError(
+ f"The input provided to the model are wrong. The number of speech tokens is {speech_to_overwrite.sum()} while"
+ f" the number of speech given to the model is {num_speechs}. This prevents correct indexing and breaks batch generation."
+ )
+ final_embedding[speech_to_overwrite] = (
+ speech_features.contiguous().reshape(-1, embed_dim)[:speech_to_overwrite.sum()].to(target_device)
+ )
+
+ final_attention_mask[speech_to_overwrite] = 1
+
+ # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
+ batch_indices_pad, pad_indices = torch.where(
+ input_ids == self.llm.config.pad_token_id
+ )
+ if len(batch_indices_pad) > 0:
+ indices_to_mask = new_token_positions[batch_indices_pad, pad_indices]
+ final_embedding[batch_indices_pad, indices_to_mask] = 0
+ final_attention_mask[batch_indices_pad, indices_to_mask] = 0
+
+ if labels is None:
+ final_labels = None
+
+ return final_embedding, final_attention_mask, final_labels
diff --git a/fireredasr2s/fireredasr2/models/lstm_lm.py b/fireredasr2s/fireredasr2/models/lstm_lm.py
new file mode 100644
index 0000000000000000000000000000000000000000..6623019c119ab1804df149a70169004b8cee520d
--- /dev/null
+++ b/fireredasr2s/fireredasr2/models/lstm_lm.py
@@ -0,0 +1,65 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu)
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
+
+
+class LstmLm(nn.Module):
+ @classmethod
+ def from_args(cls, args):
+ args.padding_idx = 2
+ args.sos_id = 3
+ args.eos_id = 4
+ return cls(args)
+
+ def __init__(self, args):
+ super().__init__()
+ self.embedding = nn.Embedding(args.idim, args.embedding_dim,
+ padding_idx=args.padding_idx)
+ self.lstm = nn.LSTM(args.embedding_dim, args.hidden_size, args.num_layers,
+ batch_first=True, dropout=args.dropout)
+        # NOTE: assumes hidden_size == embedding_dim so that fc (optionally
+        # tied with the embedding) can consume the LSTM output directly.
+        self.fc_in_dim = args.embedding_dim
+        self.fc = nn.Linear(args.embedding_dim, args.odim)
+
+ self._tie_weights(args)
+ self.sos_id = args.sos_id
+ self.eos_id = args.eos_id
+ self.ignore_index = args.padding_idx
+
+ @torch.jit.ignore
+ def _tie_weights(self, args):
+ if args.tie_weights:
+ if self.fc_in_dim != args.embedding_dim or args.idim != args.odim:
+                raise ValueError('When tie_weights is set, embedding_dim must match the fc input dim and idim must equal odim')
+ self.fc.weight = self.embedding.weight
+
+ @torch.jit.export
+ def init_hidden(self, tensor, batch_size):
+ # type: (Tensor, int) -> Tuple[Tensor, Tensor]
+ return (tensor.new_zeros(self.lstm.num_layers, batch_size, self.lstm.hidden_size).float(),
+ tensor.new_zeros(self.lstm.num_layers, batch_size, self.lstm.hidden_size).float())
+
+ @torch.jit.export
+ def forward_model(self, padded_inputs, lengths=None, hidden=None):
+ # type: (Tensor, Optional[Tensor], Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
+ # Embedding Layer
+ padded_inputs = self.embedding(padded_inputs) # N, T, D
+ # LSTM Layers
+ if lengths is None:
+ output, new_hidden = self.lstm(padded_inputs, hidden)
+ else:
+ lengths = lengths.cpu().int()
+ total_length = padded_inputs.size(1) # get the max sequence length
+ packed_input = pack_padded_sequence(padded_inputs, lengths,
+ batch_first=True,
+ enforce_sorted=False)
+ #self.lstm.flatten_parameters()
+ packed_output, new_hidden = self.lstm(packed_input, hidden)
+ output, _ = pad_packed_sequence(packed_output,
+ batch_first=True,
+ total_length=total_length)
+ # Output Layer
+ score = self.fc(output) # (N, T, V)
+ return score, new_hidden
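
A hedged usage sketch for `LstmLm`. The `argparse.Namespace` fields below are inferred from the attributes `__init__` reads, and the import path assumes the `fireredasr2` package root is on `PYTHONPATH` (matching the imports in `speech2text.py`); `hidden_size == embedding_dim` and `idim == odim` are required by the tied output layer.

```python
import argparse
import torch
from fireredasr2.models.lstm_lm import LstmLm

args = argparse.Namespace(idim=100, odim=100, embedding_dim=32,
                          hidden_size=32, num_layers=2, dropout=0.1,
                          tie_weights=True)
lm = LstmLm.from_args(args)
ys = torch.randint(5, 100, (2, 7))               # (N, T) token ids
scores, hidden = lm.forward_model(ys, lengths=torch.tensor([7, 4]))
print(scores.shape)                              # torch.Size([2, 7, 100])
```
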
diff --git a/fireredasr2s/fireredasr2/models/module/adapter.py b/fireredasr2s/fireredasr2/models/module/adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1b9d082b3c77ec263c754d67e9bde1c13ae124f
--- /dev/null
+++ b/fireredasr2s/fireredasr2/models/module/adapter.py
@@ -0,0 +1,32 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu)
+
+import torch
+import torch.nn as nn
+
+
+class Adapter(nn.Module):
+ def __init__(self, encoder_dim, llm_dim, downsample_rate=2):
+ super().__init__()
+ self.ds = downsample_rate
+ self.linear1 = nn.Linear(encoder_dim * downsample_rate, llm_dim)
+ self.relu = nn.ReLU()
+ self.linear2 = nn.Linear(llm_dim, llm_dim)
+
+ def forward(self, x, x_lens):
+ batch_size, seq_len, feat_dim = x.size()
+ num_frames_to_discard = seq_len % self.ds
+ if num_frames_to_discard > 0:
+ x = x[:, :-num_frames_to_discard, :]
+ seq_len = x.size(1)
+
+ x = x.contiguous()
+ x = x.view(
+ batch_size, seq_len // self.ds, feat_dim * self.ds
+ )
+
+ x = self.linear1(x)
+ x = self.relu(x)
+ x = self.linear2(x)
+
+ new_x_lens = torch.clamp(x_lens, max=seq_len) // self.ds
+ return x, new_x_lens
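
The frame-stacking downsample in `Adapter` concatenates pairs of consecutive encoder frames, halving T and doubling the feature dim before the two-layer projection to the LLM width. A quick shape check (sizes are illustrative; import path as above):

```python
import torch
from fireredasr2.models.module.adapter import Adapter

adapter = Adapter(encoder_dim=8, llm_dim=16, downsample_rate=2)
x = torch.randn(2, 11, 8)            # odd T: the trailing frame is discarded
x_lens = torch.tensor([11, 7])
y, y_lens = adapter(x, x_lens)
print(y.shape, y_lens)               # torch.Size([2, 5, 16]) tensor([5, 3])
```
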
diff --git a/fireredasr2s/fireredasr2/models/module/conformer_encoder.py b/fireredasr2s/fireredasr2/models/module/conformer_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ebe1bd409786b577079c9d58b0a936af127e6a7
--- /dev/null
+++ b/fireredasr2s/fireredasr2/models/module/conformer_encoder.py
@@ -0,0 +1,324 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu)
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ConformerEncoder(nn.Module):
+ def __init__(self, idim, n_layers, n_head, d_model,
+ residual_dropout=0.1, dropout_rate=0.1, kernel_size=33,
+ pe_maxlen=5000):
+ super().__init__()
+ self.odim = d_model
+
+ self.input_preprocessor = Conv2dSubsampling(idim, d_model)
+ self.positional_encoding = RelPositionalEncoding(d_model)
+ self.dropout = nn.Dropout(residual_dropout)
+
+ self.layer_stack = nn.ModuleList()
+        for _ in range(n_layers):
+ block = RelPosEmbConformerBlock(d_model, n_head,
+ residual_dropout,
+ dropout_rate, kernel_size)
+ self.layer_stack.append(block)
+
+ def forward(self, padded_input, input_lengths, pad=True):
+ if pad:
+ padded_input = F.pad(padded_input,
+ (0, 0, 0, self.input_preprocessor.context - 1), 'constant', 0.0)
+ src_mask = self.padding_position_is_0(padded_input, input_lengths)
+
+ embed_output, input_lengths, src_mask = self.input_preprocessor(padded_input, src_mask)
+ enc_output = self.dropout(embed_output)
+
+ pos_emb = self.dropout(self.positional_encoding(embed_output))
+
+        for enc_layer in self.layer_stack:
+            enc_output = enc_layer(enc_output, pos_emb, slf_attn_mask=src_mask,
+                                   pad_mask=src_mask)
+
+ return enc_output, input_lengths, src_mask
+
+ def padding_position_is_0(self, padded_input, input_lengths):
+ N, T = padded_input.size()[:2]
+ mask = torch.ones((N, T)).to(padded_input.device)
+ for i in range(N):
+ mask[i, input_lengths[i]:] = 0
+ mask = mask.unsqueeze(dim=1)
+ return mask.to(torch.uint8)
+
+
+class RelPosEmbConformerBlock(nn.Module):
+ def __init__(self, d_model, n_head,
+ residual_dropout=0.1,
+ dropout_rate=0.1, kernel_size=33):
+ super().__init__()
+ self.ffn1 = ConformerFeedForward(d_model, dropout_rate)
+ self.mhsa = RelPosMultiHeadAttention(n_head, d_model,
+ residual_dropout)
+ self.conv = ConformerConvolution(d_model, kernel_size,
+ dropout_rate)
+ self.ffn2 = ConformerFeedForward(d_model, dropout_rate)
+ self.layer_norm = nn.LayerNorm(d_model)
+
+ def forward(self, x, pos_emb, slf_attn_mask=None, pad_mask=None):
+ out = 0.5 * x + 0.5 * self.ffn1(x)
+ out = self.mhsa(out, out, out, pos_emb, mask=slf_attn_mask)[0]
+ out = self.conv(out, pad_mask)
+ out = 0.5 * out + 0.5 * self.ffn2(out)
+ out = self.layer_norm(out)
+ return out
+
+
+class Swish(nn.Module):
+ def forward(self, x):
+ return x * torch.sigmoid(x)
+
+
+class Conv2dSubsampling(nn.Module):
+ def __init__(self, idim, d_model, out_channels=32):
+ super().__init__()
+ self.conv = nn.Sequential(
+ nn.Conv2d(1, out_channels, 3, 2),
+ nn.ReLU(),
+ nn.Conv2d(out_channels, out_channels, 3, 2),
+ nn.ReLU(),
+ )
+ subsample_idim = ((idim - 1) // 2 - 1) // 2
+ self.out = nn.Linear(out_channels * subsample_idim, d_model)
+
+ self.subsampling = 4
+        left_context = right_context = 3  # both exclude current frame
+ self.context = left_context + 1 + right_context # 7
+
+ def forward(self, x, x_mask):
+ x = x.unsqueeze(1)
+ x = self.conv(x)
+ N, C, T, D = x.size()
+ x = self.out(x.transpose(1, 2).contiguous().view(N, T, C * D))
+ mask = x_mask[:, :, :-2:2][:, :, :-2:2]
+ input_lengths = mask[:, -1, :].sum(dim=-1)
+ return x, input_lengths, mask
+
+
+class RelPositionalEncoding(torch.nn.Module):
+ def __init__(self, d_model, max_len=5000):
+ super().__init__()
+ pe_positive = torch.zeros(max_len, d_model, requires_grad=False)
+ pe_negative = torch.zeros(max_len, d_model, requires_grad=False)
+ position = torch.arange(0, max_len).unsqueeze(1).float()
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() *
+ -(torch.log(torch.tensor(10000.0)).item()/d_model))
+ pe_positive[:, 0::2] = torch.sin(position * div_term)
+ pe_positive[:, 1::2] = torch.cos(position * div_term)
+ pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+ pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+
+ pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+ pe_negative = pe_negative[1:].unsqueeze(0)
+ pe = torch.cat([pe_positive, pe_negative], dim=1)
+ self.register_buffer('pe', pe)
+
+ def forward(self, x):
+ # Tmax = 2 * max_len - 1
+ Tmax, T = self.pe.size(1), x.size(1)
+ pos_emb = self.pe[:, Tmax // 2 - T + 1 : Tmax // 2 + T].clone().detach()
+ return pos_emb
+
+
+class ConformerFeedForward(nn.Module):
+ def __init__(self, d_model, dropout_rate=0.1):
+ super().__init__()
+ pre_layer_norm = nn.LayerNorm(d_model)
+ linear_expand = nn.Linear(d_model, d_model*4)
+ nonlinear = Swish()
+ dropout_pre = nn.Dropout(dropout_rate)
+ linear_project = nn.Linear(d_model*4, d_model)
+ dropout_post = nn.Dropout(dropout_rate)
+ self.net = nn.Sequential(pre_layer_norm,
+ linear_expand,
+ nonlinear,
+ dropout_pre,
+ linear_project,
+ dropout_post)
+
+ def forward(self, x):
+ residual = x
+ output = self.net(x)
+ output = output + residual
+ return output
+
+
+class ConformerConvolution(nn.Module):
+ def __init__(self, d_model, kernel_size=33, dropout_rate=0.1):
+ super().__init__()
+ assert kernel_size % 2 == 1
+ self.pre_layer_norm = nn.LayerNorm(d_model)
+ self.pointwise_conv1 = nn.Conv1d(d_model, d_model*4, kernel_size=1, bias=False)
+ self.glu = F.glu
+ self.padding = (kernel_size - 1) // 2
+ self.depthwise_conv = nn.Conv1d(d_model*2, d_model*2,
+ kernel_size, stride=1,
+ padding=self.padding,
+ groups=d_model*2, bias=False)
+ self.batch_norm = nn.LayerNorm(d_model*2)
+ self.swish = Swish()
+ self.pointwise_conv2 = nn.Conv1d(d_model*2, d_model, kernel_size=1, bias=False)
+ self.dropout = nn.Dropout(dropout_rate)
+
+ def forward(self, x, mask=None):
+ residual = x
+ out = self.pre_layer_norm(x)
+ out = out.transpose(1, 2)
+ if mask is not None:
+ out.masked_fill_(mask.ne(1), 0.0)
+ out = self.pointwise_conv1(out)
+ out = F.glu(out, dim=1)
+ out = self.depthwise_conv(out)
+
+ out = out.transpose(1, 2)
+ out = self.swish(self.batch_norm(out))
+ out = out.transpose(1, 2)
+
+ out = self.dropout(self.pointwise_conv2(out))
+ if mask is not None:
+ out.masked_fill_(mask.ne(1), 0.0)
+ out = out.transpose(1, 2)
+ return out + residual
+
+
+class EncoderMultiHeadAttention(nn.Module):
+ def __init__(self, n_head, d_model,
+ residual_dropout=0.1):
+ super().__init__()
+ assert d_model % n_head == 0
+ self.n_head = n_head
+ self.d_k = d_model // n_head
+ self.d_v = self.d_k
+
+ self.w_qs = nn.Linear(d_model, n_head * self.d_k, bias=False)
+ self.w_ks = nn.Linear(d_model, n_head * self.d_k, bias=False)
+ self.w_vs = nn.Linear(d_model, n_head * self.d_v, bias=False)
+
+ self.layer_norm_q = nn.LayerNorm(d_model)
+ self.layer_norm_k = nn.LayerNorm(d_model)
+ self.layer_norm_v = nn.LayerNorm(d_model)
+
+ self.attention = ScaledDotProductAttention(temperature=self.d_k ** 0.5)
+ self.fc = nn.Linear(n_head * self.d_v, d_model, bias=False)
+ self.dropout = nn.Dropout(residual_dropout)
+
+ def forward(self, q, k, v, mask=None):
+ sz_b, len_q = q.size(0), q.size(1)
+
+ residual = q
+ q, k, v = self.forward_qkv(q, k, v)
+
+ output, attn = self.attention(q, k, v, mask=mask)
+
+ output = self.forward_output(output, residual, sz_b, len_q)
+ return output, attn
+
+ def forward_qkv(self, q, k, v):
+ d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
+ sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
+
+ q = self.layer_norm_q(q)
+ k = self.layer_norm_k(k)
+ v = self.layer_norm_v(v)
+
+ q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
+ k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
+ v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
+ q = q.transpose(1, 2)
+ k = k.transpose(1, 2)
+ v = v.transpose(1, 2)
+ return q, k, v
+
+ def forward_output(self, output, residual, sz_b, len_q):
+ output = output.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
+ fc_out = self.fc(output)
+ output = self.dropout(fc_out)
+ output = output + residual
+ return output
+
+
+class ScaledDotProductAttention(nn.Module):
+ def __init__(self, temperature):
+ super().__init__()
+ self.temperature = temperature
+ self.dropout = nn.Dropout(0.0)
+ self.INF = float('inf')
+
+ def forward(self, q, k, v, mask=None):
+ attn = torch.matmul(q, k.transpose(2, 3)) / self.temperature
+ output, attn = self.forward_attention(attn, v, mask)
+ return output, attn
+
+ def forward_attention(self, attn, v, mask=None):
+ if mask is not None:
+ mask = mask.unsqueeze(1)
+ mask = mask.eq(0)
+ attn = attn.masked_fill(mask, -self.INF)
+ attn = torch.softmax(attn, dim=-1).masked_fill(mask, 0.0)
+ else:
+ attn = torch.softmax(attn, dim=-1)
+
+ d_attn = self.dropout(attn)
+ output = torch.matmul(d_attn, v)
+
+ return output, attn
+
+
+class RelPosMultiHeadAttention(EncoderMultiHeadAttention):
+ def __init__(self, n_head, d_model,
+ residual_dropout=0.1):
+ super().__init__(n_head, d_model,
+ residual_dropout)
+ d_k = d_model // n_head
+ self.scale = 1.0 / (d_k ** 0.5)
+ self.linear_pos = nn.Linear(d_model, n_head * d_k, bias=False)
+ self.pos_bias_u = nn.Parameter(torch.FloatTensor(n_head, d_k))
+ self.pos_bias_v = nn.Parameter(torch.FloatTensor(n_head, d_k))
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
+
+ def _rel_shift(self, x):
+ N, H, T1, T2 = x.size()
+ zero_pad = torch.zeros((N, H, T1, 1), device=x.device, dtype=x.dtype)
+ x_padded = torch.cat([zero_pad, x], dim=-1)
+
+ x_padded = x_padded.view(N, H, T2 + 1, T1)
+ x = x_padded[:, :, 1:].view_as(x)
+ x = x[:, :, :, : x.size(-1) // 2 + 1]
+ return x
+
+ def forward(self, q, k, v, pos_emb, mask=None):
+ sz_b, len_q = q.size(0), q.size(1)
+
+ residual = q
+ q, k, v = self.forward_qkv(q, k, v)
+
+ q = q.transpose(1, 2)
+ n_batch_pos = pos_emb.size(0)
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.n_head, self.d_k)
+ p = p.transpose(1, 2)
+
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
+
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
+
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
+ matrix_bd = self._rel_shift(matrix_bd)
+
+ attn_scores = matrix_ac + matrix_bd
+ attn_scores.mul_(self.scale)
+
+ output, attn = self.attention.forward_attention(attn_scores, v, mask=mask)
+
+ output = self.forward_output(output, residual, sz_b, len_q)
+ return output, attn
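
A shape-level sanity check for the encoder with toy dimensions (the production model is of course much larger). The time axis is subsampled 4x by `Conv2dSubsampling`, and `forward` right-pads `context - 1 = 6` frames before masking:

```python
import torch
from fireredasr2.models.module.conformer_encoder import ConformerEncoder

enc = ConformerEncoder(idim=80, n_layers=2, n_head=4, d_model=64)
feats = torch.randn(2, 100, 80)      # (N, T, n_mels) fbank-like input
feat_lens = torch.tensor([100, 60])
out, out_lens, mask = enc(feats, feat_lens)
print(out.shape, out_lens)           # time axis reduced ~4x, per-utt lengths
```
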
diff --git a/fireredasr2s/fireredasr2/models/module/ctc.py b/fireredasr2s/fireredasr2/models/module/ctc.py
new file mode 100644
index 0000000000000000000000000000000000000000..81784c73c206b11a531867830aed33b7308310b3
--- /dev/null
+++ b/fireredasr2s/fireredasr2/models/module/ctc.py
@@ -0,0 +1,119 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu)
+
+from typing import List
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+
+class CTC(torch.nn.Module):
+ def __init__(self, odim, encoder_output_size):
+ super().__init__()
+ self.ctc_lo = torch.nn.Linear(encoder_output_size, odim)
+
+ def forward(self, encoder_output_pad):
+ """encoder_output_pad: (N, T, H)"""
+ return F.log_softmax(self.ctc_lo(encoder_output_pad), dim=2)
+
+ @classmethod
+ def ctc_align(cls, ctc_probs, y, blank_id=0):
+ """ctc forced alignment.
+
+ Args:
+ torch.Tensor ctc_probs: hidden state sequence, 2d tensor (T, D)
+ torch.Tensor y: id sequence tensor 1d tensor (L)
+ int blank_id: blank symbol index
+ Returns:
+ torch.Tensor: alignment result
+ """
+ y_insert_blank = insert_blank(y, blank_id)
+
+ log_alpha = torch.zeros((ctc_probs.size(0), len(y_insert_blank)))
+ log_alpha = log_alpha - float('inf') # log of zero
+ state_path = (torch.zeros(
+ (ctc_probs.size(0), len(y_insert_blank)), dtype=torch.int16) - 1
+ ) # state path
+
+ # init start state
+ log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]]
+ log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]]
+
+ for t in range(1, ctc_probs.size(0)):
+ for s in range(len(y_insert_blank)):
+ if y_insert_blank[s] == blank_id or s < 2 or y_insert_blank[
+ s] == y_insert_blank[s - 2]:
+ candidates = torch.tensor(
+ [log_alpha[t - 1, s], log_alpha[t - 1, s - 1]])
+ prev_state = [s, s - 1]
+ else:
+ candidates = torch.tensor([
+ log_alpha[t - 1, s],
+ log_alpha[t - 1, s - 1],
+ log_alpha[t - 1, s - 2],
+ ])
+ prev_state = [s, s - 1, s - 2]
+ log_alpha[t, s] = torch.max(candidates) + ctc_probs[t][y_insert_blank[s]]
+ state_path[t, s] = prev_state[torch.argmax(candidates)]
+
+ state_seq = -1 * torch.ones((ctc_probs.size(0), 1), dtype=torch.int16)
+
+ candidates = torch.tensor([
+ log_alpha[-1, len(y_insert_blank) - 1],
+ log_alpha[-1, len(y_insert_blank) - 2]
+ ])
+ prev_state = [len(y_insert_blank) - 1, len(y_insert_blank) - 2]
+ state_seq[-1] = prev_state[torch.argmax(candidates)]
+ for t in range(ctc_probs.size(0) - 2, -1, -1):
+ state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]]
+
+ output_alignment = []
+ for t in range(0, ctc_probs.size(0)):
+ output_alignment.append(y_insert_blank[state_seq[t, 0]])
+
+ return output_alignment
+
+ @classmethod
+ def ctc_alignment_to_timestamp(cls, ys_with_blank, subsampling, blank_id=0):
+ start_times: List[float] = []
+ end_times: List[float] = []
+ frame_shift = 10 # ms, hard code
+ T = len(ys_with_blank)
+ t = 0
+ ctc_durs = []
+ while t < T:
+ token = ys_with_blank[t]
+ t += 1
+ if token != blank_id:
+ start_t = t
+ timestamp = frame_shift * subsampling * t / 1000.0 # s
+ start_times.append(timestamp)
+ if len(start_times) == len(end_times) + 2:
+ end_times.append(start_times[-1])
+ # skip repeat token
+ while t < T and token == ys_with_blank[t]:
+ t += 1
+ assert t-start_t >= 0
+ ctc_durs.append((t-start_t+1) * frame_shift * subsampling / 1000.0)
+ end_times.append((frame_shift * subsampling * T + 25)/ 1000.0)
+ if len(start_times) == 0:
+ start_times.append(0.0)
+
+ # Refine end_times
+ assert len(ctc_durs) == len(end_times) and len(start_times) == len(end_times)
+ avg_dur = sum(e-s for s, e in zip(start_times, end_times)) / len(end_times)
+ new_end_times = []
+ for s, e, ctc_dur in zip(start_times, end_times, ctc_durs):
+ if e - s > 2 * avg_dur:
+ e = s + max(1.5*avg_dur, ctc_dur)
+ new_end_times.append(round(e, 3))
+ end_times = new_end_times
+ return start_times, end_times
+
+
+def insert_blank(label, blank_id=0):
+ """Insert blank token between every two label token."""
+ label = np.expand_dims(label, 1)
+ blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id
+ label = np.concatenate([blanks, label], axis=1)
+ label = label.reshape(-1)
+ label = np.append(label, label[0])
+ return label
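
A small end-to-end check of forced alignment plus timestamp extraction, with toy posteriors shaped to favour a known path (vocab size, frame count, and the path are illustrative; import path assumes the `fireredasr2` package root is on `PYTHONPATH`):

```python
import torch
from fireredasr2.models.module.ctc import CTC

# Toy per-frame log-posteriors over a 4-symbol vocab (0 = blank), shaped so
# the best path is roughly [blank, 2, 2, blank, 3, 3].
log_probs = torch.full((6, 4), -10.0)
for t, s in enumerate([0, 2, 2, 0, 3, 3]):
    log_probs[t, s] = 0.0

ali = CTC.ctc_align(log_probs, torch.tensor([2, 3]))  # per-frame ids with blanks
starts, ends = CTC.ctc_alignment_to_timestamp(ali, subsampling=4)
print(ali)           # ~[0, 2, 2, 0, 3, 3]
print(starts, ends)  # per-token times in seconds (10 ms frames x 4 subsampling)
```
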
diff --git a/fireredasr2s/fireredasr2/models/module/transformer_decoder.py b/fireredasr2s/fireredasr2/models/module/transformer_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..4bc84f1ab08bdda5b2f8ad3d7083bfcb973aaa9c
--- /dev/null
+++ b/fireredasr2s/fireredasr2/models/module/transformer_decoder.py
@@ -0,0 +1,329 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu)
+
+from typing import List, Optional, Dict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+
+
+class TransformerDecoder(nn.Module):
+ def __init__(
+ self, sos_id, eos_id, pad_id, odim,
+ n_layers, n_head, d_model,
+ residual_dropout=0.1, pe_maxlen=5000):
+ super().__init__()
+ self.INF = 1e10
+ # parameters
+ self.pad_id = pad_id
+ self.sos_id = sos_id
+ self.eos_id = eos_id
+ self.n_layers = n_layers
+
+ # Components
+ self.tgt_word_emb = nn.Embedding(odim, d_model, padding_idx=self.pad_id)
+ self.positional_encoding = PositionalEncoding(d_model, max_len=pe_maxlen)
+ self.dropout = nn.Dropout(residual_dropout)
+
+ self.layer_stack = nn.ModuleList()
+        for _ in range(n_layers):
+ block = DecoderLayer(d_model, n_head, residual_dropout)
+ self.layer_stack.append(block)
+
+ self.tgt_word_prj = nn.Linear(d_model, odim, bias=False)
+ self.layer_norm_out = nn.LayerNorm(d_model)
+
+ self.tgt_word_prj.weight = self.tgt_word_emb.weight
+ self.scale = (d_model ** 0.5)
+
+ def batch_beam_search(self, encoder_outputs, src_masks,
+ beam_size=1, nbest=1, decode_max_len=0,
+ softmax_smoothing=1.0, length_penalty=0.0, eos_penalty=1.0,
+ elm=None, elm_weight=0.0):
+ B = beam_size
+ N, Ti, H = encoder_outputs.size()
+ device = encoder_outputs.device
+ maxlen = decode_max_len if decode_max_len > 0 else Ti
+ assert eos_penalty > 0.0
+
+ # Init
+ encoder_outputs = encoder_outputs.unsqueeze(1).repeat(1, B, 1, 1).view(N*B, Ti, H)
+ src_mask = src_masks.unsqueeze(1).repeat(1, B, 1, 1).view(N*B, -1, Ti)
+ ys = torch.ones(N*B, 1).fill_(self.sos_id).long().to(device)
+ t_ys = ys.clone()
+ confidences = torch.zeros(N*B, 1).float().to(device)
+ caches: List[Optional[Tensor]] = []
+ for _ in range(self.n_layers):
+ caches.append(None)
+ scores = torch.tensor([0.0] + [-self.INF]*(B-1)).float().to(device)
+ scores = scores.repeat(N).view(N*B, 1)
+ is_finished = torch.zeros_like(scores)
+ if elm is not None:
+ elm_cache = elm.init_hidden(encoder_outputs, N*B)
+
+ # Autoregressive Prediction
+ for t in range(maxlen):
+ tgt_mask = self.ignored_target_position_is_0(ys, self.pad_id)
+
+ dec_output = self.dropout(
+ self.tgt_word_emb(ys) * self.scale +
+ self.positional_encoding(ys))
+# if t > 0:
+# dec_output = dec_output[:, -1:, :]
+ i = 0
+ for dec_layer in self.layer_stack:
+ dec_output = dec_layer.forward(
+ dec_output, encoder_outputs,
+ tgt_mask, src_mask,
+ cache=caches[i])
+ caches[i] = dec_output
+ i += 1
+
+ dec_output = self.layer_norm_out(dec_output)
+
+ t_logit = self.tgt_word_prj(dec_output[:, -1])
+ t_scores = F.log_softmax(t_logit / softmax_smoothing, dim=-1)
+ t_origin_scores = t_scores
+
+ if elm is not None and elm_weight > 0.0:
+ elm_logit, elm_cache = elm.forward_model(t_ys, hidden=elm_cache)
+ #elm_logit, _ = elm.forward_model(ys)
+ t_lm_scores = torch.log_softmax(elm_logit[:, -1], dim=-1) * (1 - is_finished.float()) # mask, (N*B, V)
+                    t_lm_scores[:, elm.eos_id] *= 3  # heuristic scaling of the LM <eos> score
+ t_scores = t_scores + elm_weight * t_lm_scores
+
+ if eos_penalty != 1.0:
+ t_scores[:, self.eos_id] *= eos_penalty
+
+ t_topB_scores, t_topB_ys = torch.topk(t_scores, k=B, dim=1)
+ t_topB_scores = self.set_finished_beam_score_to_zero(t_topB_scores, is_finished)
+ t_topB_ys = self.set_finished_beam_y_to_eos(t_topB_ys, is_finished)
+
+ # Accumulated
+ scores = scores + t_topB_scores
+
+ # Pruning
+ scores = scores.view(N, B*B)
+ scores, topB_score_ids = torch.topk(scores, k=B, dim=1)
+ scores = scores.view(-1, 1)
+
+ topB_row_number_in_each_B_rows_of_ys = torch.div(topB_score_ids, B).view(N*B)
+ stride = B * torch.arange(N).view(N, 1).repeat(1, B).view(N*B).to(device)
+ topB_row_number_in_ys = topB_row_number_in_each_B_rows_of_ys.long() + stride.long()
+
+ # Update ys
+ ys = ys[topB_row_number_in_ys]
+ t_ys = torch.gather(t_topB_ys.view(N, B*B), dim=1, index=topB_score_ids).view(N*B, 1)
+ ys = torch.cat((ys, t_ys), dim=1)
+
+ # Update confidences
+ confidences = confidences[topB_row_number_in_ys]
+ t_confidences = torch.gather(t_topB_scores.view(N, B*B), dim=1, index=topB_score_ids).view(N*B, 1)
+ t_confidences = torch.exp(t_confidences)
+ assert torch.all(t_confidences <= 1.0)
+ assert torch.all(t_confidences >= 0.0)
+ confidences = torch.cat((confidences, t_confidences), dim=1)
+
+ # Update caches
+ new_caches: List[Optional[Tensor]] = []
+ for cache in caches:
+ if cache is not None:
+ new_caches.append(cache[topB_row_number_in_ys])
+ caches = new_caches
+ if elm and elm_weight > 0.0:
+ elm_cache = (elm_cache[0][:, topB_row_number_in_ys], elm_cache[1][:, topB_row_number_in_ys])
+
+ # Update finished state
+ is_finished = t_ys.eq(self.eos_id)
+ if is_finished.sum().item() == N*B:
+ break
+
+ # Length penalty (follow GNMT)
+ scores = scores.view(N, B)
+ ys = ys.view(N, B, -1)
+ ys_lengths = self.get_ys_lengths(ys)
+ if length_penalty > 0.0:
+ penalty = torch.pow((5+ys_lengths.float())/(5.0+1), length_penalty)
+ scores /= penalty
+ nbest_scores, nbest_ids = torch.topk(scores, k=int(nbest), dim=1)
+ nbest_scores = -1.0 * nbest_scores
+ index = nbest_ids + B * torch.arange(N).view(N, 1).to(device).long()
+ nbest_ys = ys.view(N*B, -1)[index.view(-1)]
+ nbest_ys = nbest_ys.view(N, nbest_ids.size(1), -1)
+ nbest_ys_lengths = ys_lengths.view(N*B)[index.view(-1)].view(N, -1)
+ nbest_confidences = confidences.view(N*B, -1)[index.view(-1)].view(N, nbest_ids.size(1), -1)
+
+ # result
+ nbest_hyps: List[List[Dict[str, Tensor]]] = []
+ for n in range(N):
+ n_nbest_hyps: List[Dict[str, Tensor]] = []
+ for i, score in enumerate(nbest_scores[n]):
+ confidence = nbest_confidences[n, i, 1:nbest_ys_lengths[n, i]]
+ confidence = confidence.mean()
+ new_hyp = {
+ "yseq": nbest_ys[n, i, 1:nbest_ys_lengths[n, i]],
+ "confidence": confidence
+ }
+ n_nbest_hyps.append(new_hyp)
+ nbest_hyps.append(n_nbest_hyps)
+ return nbest_hyps
+
+ def ignored_target_position_is_0(self, padded_targets, ignore_id):
+ mask = torch.ne(padded_targets, ignore_id)
+ mask = mask.unsqueeze(dim=1)
+ T = padded_targets.size(-1)
+ upper_tri_0_mask = self.upper_triangular_is_0(T).unsqueeze(0).to(mask.dtype)
+ upper_tri_0_mask = upper_tri_0_mask.to(mask.dtype).to(mask.device)
+ return mask.to(torch.uint8) & upper_tri_0_mask.to(torch.uint8)
+
+ def upper_triangular_is_0(self, size):
+ ones = torch.ones(size, size)
+ tri_left_ones = torch.tril(ones)
+ return tri_left_ones.to(torch.uint8)
+
+ def set_finished_beam_score_to_zero(self, scores, is_finished):
+ NB, B = scores.size()
+ is_finished = is_finished.float()
+ mask_score = torch.tensor([0.0] + [-self.INF]*(B-1)).float().to(scores.device)
+ mask_score = mask_score.view(1, B).repeat(NB, 1)
+ return scores * (1 - is_finished) + mask_score * is_finished
+
+ def set_finished_beam_y_to_eos(self, ys, is_finished):
+ is_finished = is_finished.long()
+ return ys * (1 - is_finished) + self.eos_id * is_finished
+
+ def get_ys_lengths(self, ys):
+ N, B, Tmax = ys.size()
+ ys_lengths = torch.sum(torch.ne(ys, self.eos_id), dim=-1)
+ return ys_lengths.int()
+
+
+
+class DecoderLayer(nn.Module):
+ def __init__(self, d_model, n_head, dropout):
+ super().__init__()
+ self.self_attn_norm = nn.LayerNorm(d_model)
+ self.self_attn = DecoderMultiHeadAttention(d_model, n_head, dropout)
+
+ self.cross_attn_norm = nn.LayerNorm(d_model)
+ self.cross_attn = DecoderMultiHeadAttention(d_model, n_head, dropout)
+
+ self.mlp_norm = nn.LayerNorm(d_model)
+ self.mlp = PositionwiseFeedForward(d_model, d_model*4, dropout)
+
+ def forward(self, dec_input, enc_output, self_attn_mask, cross_attn_mask,
+ cache=None):
+ x = dec_input
+ residual = x
+ x = self.self_attn_norm(x)
+ if cache is not None:
+ xq = x[:, -1:, :]
+ residual = residual[:, -1:, :]
+ self_attn_mask = self_attn_mask[:, -1:, :]
+ else:
+ xq = x
+ x = self.self_attn(xq, x, x, mask=self_attn_mask)
+ x = residual + x
+
+ residual = x
+ x = self.cross_attn_norm(x)
+ x = self.cross_attn(x, enc_output, enc_output, mask=cross_attn_mask)
+ x = residual + x
+
+ residual = x
+ x = self.mlp_norm(x)
+ x = residual + self.mlp(x)
+
+ if cache is not None:
+ x = torch.cat([cache, x], dim=1)
+
+ return x
+
+
+class DecoderMultiHeadAttention(nn.Module):
+ def __init__(self, d_model, n_head, dropout=0.1):
+ super().__init__()
+ self.d_model = d_model
+ self.n_head = n_head
+ self.d_k = d_model // n_head
+
+ self.w_qs = nn.Linear(d_model, n_head * self.d_k)
+ self.w_ks = nn.Linear(d_model, n_head * self.d_k, bias=False)
+ self.w_vs = nn.Linear(d_model, n_head * self.d_k)
+
+ self.attention = DecoderScaledDotProductAttention(
+ temperature=self.d_k ** 0.5)
+ self.fc = nn.Linear(n_head * self.d_k, d_model)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, q, k, v, mask=None):
+ bs = q.size(0)
+
+ q = self.w_qs(q).view(bs, -1, self.n_head, self.d_k)
+ k = self.w_ks(k).view(bs, -1, self.n_head, self.d_k)
+ v = self.w_vs(v).view(bs, -1, self.n_head, self.d_k)
+ q = q.transpose(1, 2)
+ k = k.transpose(1, 2)
+ v = v.transpose(1, 2)
+
+ if mask is not None:
+ mask = mask.unsqueeze(1)
+
+ output = self.attention(q, k, v, mask=mask)
+
+ output = output.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
+ output = self.fc(output)
+ output = self.dropout(output)
+
+ return output
+
+
+class DecoderScaledDotProductAttention(nn.Module):
+ def __init__(self, temperature):
+ super().__init__()
+ self.temperature = temperature
+ self.INF = float("inf")
+
+ def forward(self, q, k, v, mask=None):
+ attn = torch.matmul(q, k.transpose(2, 3)) / self.temperature
+ if mask is not None:
+ mask = mask.eq(0)
+ attn = attn.masked_fill(mask, -self.INF)
+ attn = torch.softmax(attn, dim=-1).masked_fill(mask, 0.0)
+ else:
+ attn = torch.softmax(attn, dim=-1)
+ output = torch.matmul(attn, v)
+ return output
+
+
+class PositionwiseFeedForward(nn.Module):
+ def __init__(self, d_model, d_ff, dropout=0.1):
+ super().__init__()
+ self.w_1 = nn.Linear(d_model, d_ff)
+ self.act = nn.GELU()
+ self.w_2 = nn.Linear(d_ff, d_model)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x):
+ output = self.w_2(self.act(self.w_1(x)))
+ output = self.dropout(output)
+ return output
+
+
+class PositionalEncoding(nn.Module):
+ def __init__(self, d_model, max_len=5000):
+ super().__init__()
+ assert d_model % 2 == 0
+ pe = torch.zeros(max_len, d_model, requires_grad=False)
+ position = torch.arange(0, max_len).unsqueeze(1).float()
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() *
+ -(torch.log(torch.tensor(10000.0)).item()/d_model))
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ pe = pe.unsqueeze(0)
+ self.register_buffer('pe', pe)
+
+ def forward(self, x):
+ length = x.size(1)
+ return self.pe[:, :length].clone().detach()
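
A call-signature sketch for `batch_beam_search` with randomly initialized weights. The output tokens are untrained gibberish; this only demonstrates the expected input shapes (encoder states plus a uint8 source mask) and the nbest result structure:

```python
import torch
from fireredasr2.models.module.transformer_decoder import TransformerDecoder

dec = TransformerDecoder(sos_id=3, eos_id=4, pad_id=2, odim=50,
                         n_layers=2, n_head=4, d_model=64)
enc_out = torch.randn(2, 10, 64)                    # (N, Ti, H) encoder states
src_mask = torch.ones(2, 1, 10, dtype=torch.uint8)  # no padded frames
hyps = dec.batch_beam_search(enc_out, src_mask, beam_size=2, nbest=1)
print(hyps[0][0]["yseq"], hyps[0][0]["confidence"])
```
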
diff --git a/fireredasr2s/fireredasr2/models/param.py b/fireredasr2s/fireredasr2/models/param.py
new file mode 100644
index 0000000000000000000000000000000000000000..4da2c72034cac5357514337968e265960bf1d804
--- /dev/null
+++ b/fireredasr2s/fireredasr2/models/param.py
@@ -0,0 +1,17 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu)
+
+import logging
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+def count_model_parameters(model):
+ if not isinstance(model, torch.nn.Module):
+ return 0, 0
+ name = f"{model.__class__.__name__} {model.__class__}"
+ num = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ size = num * 4.0 / 1024.0 / 1024.0 # float32, MB
+ logger.info(f"#param of {name} is {num} = {size:.1f} MB (float32)")
+ return num, size
diff --git a/fireredasr2s/fireredasr2/speech2text.py b/fireredasr2s/fireredasr2/speech2text.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc01c8c7ea24bc4d11930e332aedf67d3d6c9083
--- /dev/null
+++ b/fireredasr2s/fireredasr2/speech2text.py
@@ -0,0 +1,107 @@
+#!/usr/bin/env python3
+
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu)
+
+import argparse
+import json
+import logging
+import os
+
+from fireredasr2.asr import FireRedAsr2, FireRedAsr2Config
+from fireredasr2.utils.io import get_wav_info, write_textgrid
+
+logging.basicConfig(level=logging.INFO,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
+logger = logging.getLogger("fireredasr2.bin.speech2text")
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--asr_type', type=str, required=True, choices=["aed", "llm"])
+parser.add_argument('--model_dir', type=str, required=True)
+
+# Input / Output
+parser.add_argument("--wav_path", type=str)
+parser.add_argument("--wav_paths", type=str, nargs="*")
+parser.add_argument("--wav_dir", type=str)
+parser.add_argument("--wav_scp", type=str)
+parser.add_argument("--sort_wav_by_dur", type=int, default=0)
+parser.add_argument("--output", type=str)
+
+# Decode Options
+parser.add_argument('--use_gpu', type=int, default=1)
+parser.add_argument('--use_half', type=int, default=0)
+parser.add_argument("--batch_size", type=int, default=1)
+parser.add_argument("--beam_size", type=int, default=1)
+parser.add_argument("--decode_max_len", type=int, default=0)
+# FireRedASR-AED
+parser.add_argument("--nbest", type=int, default=1)
+parser.add_argument("--softmax_smoothing", type=float, default=1.0)
+parser.add_argument("--aed_length_penalty", type=float, default=0.0)
+parser.add_argument("--eos_penalty", type=float, default=1.0)
+parser.add_argument("--return_timestamp", type=int, default=0)
+parser.add_argument("--write_textgrid", type=int, default=0)
+# AED External LM
+parser.add_argument("--elm_dir", type=str, default="")
+parser.add_argument("--elm_weight", type=float, default=0.0)
+# FireRedASR-LLM
+parser.add_argument("--decode_min_len", type=int, default=0)
+parser.add_argument("--repetition_penalty", type=float, default=1.0)
+parser.add_argument("--llm_length_penalty", type=float, default=0.0)
+parser.add_argument("--temperature", type=float, default=1.0)
+
+
+def main(args):
+ wavs = get_wav_info(args)
+ fout = open(args.output, "w") if args.output else None
+ foutl = open(args.output + ".jsonl", "w") if args.output else None
+
+ asr_config = FireRedAsr2Config(
+ args.use_gpu,
+ args.use_half,
+ args.beam_size,
+ args.nbest,
+ args.decode_max_len,
+ args.softmax_smoothing,
+ args.aed_length_penalty,
+ args.eos_penalty,
+ args.return_timestamp,
+ args.decode_min_len,
+ args.repetition_penalty,
+ args.llm_length_penalty,
+ args.temperature,
+ args.elm_dir,
+ args.elm_weight
+ )
+ model = FireRedAsr2.from_pretrained(args.asr_type, args.model_dir, asr_config)
+
+ batch_uttid = []
+ batch_wav_path = []
+ for i, wav in enumerate(wavs):
+ uttid, wav_path = wav
+ batch_uttid.append(uttid)
+ batch_wav_path.append(wav_path)
+ if len(batch_wav_path) < args.batch_size and i != len(wavs) - 1:
+ continue
+
+ results = model.transcribe(batch_uttid, batch_wav_path)
+
+ for result in results:
+ logger.info(result)
+ if fout is not None:
+ foutl.write(f"{json.dumps(result, ensure_ascii=False)}\n")
+ fout.write(f"{result['uttid']}\t{result['text']}\n")
+ if args.write_textgrid and "timestamp" in result:
+ write_textgrid(result["wav"], result["dur_s"], result["timestamp"])
+
+ if fout: fout.flush()
+ if foutl: foutl.flush()
+ batch_uttid = []
+ batch_wav_path = []
+ if fout: fout.close()
+ if foutl: foutl.close()
+
+
+if __name__ == "__main__":
+ args = parser.parse_args()
+ logger.info(args)
+ main(args)
diff --git a/fireredasr2s/fireredasr2/tokenizer/aed_tokenizer.py b/fireredasr2s/fireredasr2/tokenizer/aed_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..7125fc7cbd1ec34253efc0188fd10ee1abbb65d2
--- /dev/null
+++ b/fireredasr2s/fireredasr2/tokenizer/aed_tokenizer.py
@@ -0,0 +1,93 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu)
+
+import logging
+import re
+
+import sentencepiece as spm
+
+from ..data.token_dict import TokenDict
+
+
+class ChineseCharEnglishSpmTokenizer:
+ """
+    - Each Chinese character is one token.
+    - English words are split by the SPM model; each piece is one token.
+    - Spaces between Chinese characters are ignored.
+    - Spaces between English words are replaced with "▁" by the SPM model.
+    - SPM pieces must be present in the dict file.
+    - If spm_model is not set, English falls back to characters and `space`
+      stands in for ' '.
+ """
+ SPM_SPACE = "▁"
+
+    def __init__(self, dict_path, spm_model, unk="<unk>", space="<space>"):
+ self.dict = TokenDict(dict_path, unk=unk)
+ self.space = space
+ if spm_model:
+ self.sp = spm.SentencePieceProcessor()
+ self.sp.Load(spm_model)
+ else:
+ self.sp = None
+ print("[WRAN] Not set spm_model, will use English char")
+ print("[WARN] Please check how to deal with ' '(space)")
+ if self.space not in self.dict:
+ print("Please add to your dict, or it will be ")
+
+ def tokenize(self, text, replace_punc=True):
+ #if text == "":
+ # logging.info(f"empty text")
+ text = text.upper()
+ tokens = []
+ if replace_punc:
+ text = re.sub("[,。?!,\.?!]", " ", text)
+ pattern = re.compile(r'([\u3400-\u4dbf\u4e00-\u9fff])')
+ parts = pattern.split(text.strip())
+ parts = [p for p in parts if len(p.strip()) > 0]
+ for part in parts:
+ if pattern.fullmatch(part) is not None:
+ tokens.append(part)
+ else:
+ if self.sp:
+ for piece in self.sp.EncodeAsPieces(part.strip()):
+ tokens.append(piece)
+ else:
+ for char in part.strip():
+ tokens.append(char if char != " " else self.space)
+ tokens_id = []
+ for token in tokens:
+ tokens_id.append(self.dict.get(token, self.dict.unk))
+ return tokens, tokens_id
+
+ def detokenize(self, inputs, join_symbol="", replace_spm_space=True):
+ """inputs is ids or tokens, do not need self.sp"""
+        if len(inputs) > 0 and isinstance(inputs[0], int):
+ tokens = [self.dict[id] for id in inputs]
+ else:
+ tokens = inputs
+ s = f"{join_symbol}".join(tokens)
+ if replace_spm_space:
+ s = s.replace(self.SPM_SPACE, ' ').strip()
+ return s
+
+ def merge_spm_timestamp(self, timestamp):
+ merged_timestamp = []
+ i = 0
+ while i < len(timestamp):
+ token, start, end = timestamp[i]
+ if token.startswith(self.SPM_SPACE):
+ token = token.replace(self.SPM_SPACE, "")
+ current_end = end
+ next_i = i + 1
+ while next_i < len(timestamp):
+ next_token, next_start, next_end = timestamp[next_i]
+ if re.match("^[a-zA-Z']+$", next_token):
+ token += next_token
+ current_end = next_end
+ next_i += 1
+ else:
+ break
+ end = current_end
+ i = next_i
+ else:
+ i += 1
+ merged_timestamp.append((token, start, end))
+ return merged_timestamp
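
A hedged round-trip sketch. The `TokenDict` file format assumed below (one `token id` pair per line) is a guess, not verified against `token_dict.py`; without an SPM model, English falls back to characters as the constructor warns:

```python
from fireredasr2.tokenizer.aed_tokenizer import ChineseCharEnglishSpmTokenizer

# Assumed TokenDict format: "token id" per line (placeholder dict for the demo).
with open("toy_dict.txt", "w") as f:
    for i, tok in enumerate(["<blk>", "<unk>", "<space>", "你", "好", "W", "O"]):
        f.write(f"{tok} {i}\n")

tokenizer = ChineseCharEnglishSpmTokenizer("toy_dict.txt", spm_model=None)
tokens, ids = tokenizer.tokenize("你好, wo")
print(tokens)                        # ['你', '好', 'W', 'O'] (char fallback)
print(tokenizer.detokenize(tokens))  # 你好WO
```
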
diff --git a/fireredasr2s/fireredasr2/tokenizer/llm_tokenizer.py b/fireredasr2s/fireredasr2/tokenizer/llm_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..d18a9de68a0d86e20b15104422792cd633d9ebfe
--- /dev/null
+++ b/fireredasr2s/fireredasr2/tokenizer/llm_tokenizer.py
@@ -0,0 +1,107 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu)
+
+import re
+
+import torch
+from transformers import AutoTokenizer
+from transformers.trainer_pt_utils import LabelSmoother
+
+DEFAULT_SPEECH_TOKEN = ""
+IGNORE_TOKEN_ID = LabelSmoother.ignore_index
+
+
+class LlmTokenizerWrapper:
+ @classmethod
+ def build_llm_tokenizer(cls, llm_path, use_flash_attn=False):
+ tokenizer = AutoTokenizer.from_pretrained(llm_path)
+ if use_flash_attn:
+ tokenizer.padding_side = "left"
+ else:
+ tokenizer.padding_side = "right"
+ special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
+ tokenizer.add_special_tokens(special_tokens_dict)
+ return tokenizer
+
+ @classmethod
+ def clean_text(cls, origin_text):
+ """remove punc, remove space between Chinese and keep space between English"""
+ # remove punc
+ text = re.sub("[,。?!,\.!?《》()\·“”、\\/]", "", origin_text)
+ # merge space
+ text = re.sub("\s+", " ", text)
+
+ # remove space between Chinese and keep space between English
+ pattern = re.compile(r'([\u3400-\u4dbf\u4e00-\u9fff])') # Chinese
+ parts = pattern.split(text.strip())
+ parts = [p for p in parts if len(p.strip()) > 0]
+ text = "".join(parts)
+ text = text.strip()
+
+ text = text.lower()
+ return text
+
+ @classmethod
+ def preprocess_texts(cls, origin_texts, tokenizer, max_len, decode=False):
+ messages = []
+ clean_texts = []
+ for i, origin_text in enumerate(origin_texts):
+ text = cls.clean_text(origin_text)
+ clean_texts.append(text)
+ text = text if not decode else ""
+ message = [
+ {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
+ {"role": "assistant", "content": text},
+ ]
+ messages.append(message)
+
+ texts = []
+ if not decode:
+ TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
+ else:
+ TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
+ for i, msg in enumerate(messages):
+ texts.append(
+ tokenizer.apply_chat_template(
+ msg,
+ tokenize=True,
+ chat_template=TEMPLATE,
+ add_generation_prompt=False,
+ padding="longest",
+ max_length=max_len,
+ truncation=True,
+ )
+ )
+
+ # Padding texts
+ max_len_texts = max([len(text) for text in texts])
+ if tokenizer.padding_side == "right":
+ texts = [
+ text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
+ for text in texts
+ ]
+ else:
+ texts = [
+ [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
+ for text in texts
+ ]
+ input_ids = torch.tensor(texts, dtype=torch.int)
+
+ target_ids = input_ids.clone()
+ target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
+
+ # first get the indices of the tokens
+ mask_prompt = True
+ if mask_prompt:
+ mask_indices = torch.where(
+ input_ids == tokenizer.convert_tokens_to_ids("assistant")
+ )
+ for i in range(mask_indices[0].size(0)):
+ row = mask_indices[0][i]
+ col = mask_indices[1][i]
+ target_ids[row, : col + 2] = IGNORE_TOKEN_ID
+
+ attention_mask = input_ids.ne(tokenizer.pad_token_id)
+
+ target_ids = target_ids.type(torch.LongTensor)
+ input_ids = input_ids.type(torch.LongTensor)
+ return input_ids, attention_mask, target_ids, clean_texts
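
`clean_text` is self-contained and easy to sanity-check:

```python
from fireredasr2.tokenizer.llm_tokenizer import LlmTokenizerWrapper

print(LlmTokenizerWrapper.clean_text("你 好, World. Hello!"))
# -> "你好 world hello": punctuation dropped, intra-Chinese spaces removed,
#    inter-English spaces kept, everything lowercased
```
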
diff --git a/fireredasr2s/fireredasr2/utils/io.py b/fireredasr2s/fireredasr2/utils/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..ade8055ab2d19900b427387711df870375beeef5
--- /dev/null
+++ b/fireredasr2s/fireredasr2/utils/io.py
@@ -0,0 +1,55 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu)
+
+import glob
+import os
+import logging
+logger = logging.getLogger(__name__)
+
+from textgrid import TextGrid, IntervalTier
+
+
+def get_wav_info(args):
+ """
+ Returns:
+ wavs: list of (uttid, wav_path)
+ """
+    def base(p): return os.path.basename(p).replace(".wav", "")
+ if args.wav_path:
+ wavs = [(base(args.wav_path), args.wav_path)]
+ elif args.wav_paths and len(args.wav_paths) >= 1:
+ wavs = [(base(p), p) for p in sorted(args.wav_paths)]
+ elif args.wav_scp:
+ wavs = [line.strip().split() for line in open(args.wav_scp)]
+ if args.sort_wav_by_dur:
+ logger.info("Sort wav by duration...")
+ utt2dur = os.path.join(os.path.dirname(args.wav_scp), "utt2dur")
+ if os.path.exists(utt2dur):
+ utt2dur = [l.strip().split() for l in open(utt2dur)]
+ utt2dur = {l[0]: float(l[1]) for l in utt2dur if len(l) == 2}
+ wavs = sorted(wavs, key=lambda x: -utt2dur[x[0]])
+ logger.info("Sort Done")
+ else:
+ logger.info(f"Not find {utt2dur}, un-sort")
+ elif args.wav_dir:
+ wavs = glob.glob(f"{args.wav_dir}/**/*.wav", recursive=True)
+ wavs = [(base(p), p) for p in sorted(wavs)]
+ else:
+ raise ValueError("Please provide valid wav info")
+ logger.info(f"#wavs={len(wavs)}")
+ return wavs
+
+
+def write_textgrid(wav_path, wav_dur, event):
+ textgrid_file = wav_path.replace(".wav", ".TextGrid")
+ logger.info(f"Write {textgrid_file}")
+ textgrid = TextGrid(maxTime=wav_dur)
+ tier = IntervalTier(name="token", maxTime=wav_dur)
+ for token, start_s, end_s in event:
+ if start_s == end_s:
+ logger.info(f"Write TG, skip start=end {start_s}")
+ continue
+ start_s = max(start_s, 0)
+ end_s = min(end_s, wav_dur)
+ tier.add(minTime=start_s, maxTime=end_s, mark=token)
+ textgrid.append(tier)
+ textgrid.write(textgrid_file)
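
A quick check of the TextGrid writer (the wav file itself is never read; only its name is reused for the output path, so any path works):

```python
from fireredasr2.utils.io import write_textgrid

events = [("你", 0.00, 0.24), ("好", 0.24, 0.51), ("WORLD", 0.60, 1.10)]
write_textgrid("demo.wav", 1.5, events)  # writes demo.TextGrid
```
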
diff --git a/fireredasr2s/fireredasr2/utils/wer.py b/fireredasr2s/fireredasr2/utils/wer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8cbd2d524dc84f049323064962dc7055fa27a3a7
--- /dev/null
+++ b/fireredasr2s/fireredasr2/utils/wer.py
@@ -0,0 +1,326 @@
+#!/usr/bin/env python3
+
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu)
+
+import argparse
+import re
+from collections import OrderedDict
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--ref", type=str, required=True)
+parser.add_argument("--hyp", type=str, required=True)
+parser.add_argument("--print_sentence_wer", type=int, default=0)
+parser.add_argument("--do_tn", type=int, default=0, help="simple tn by cn2an")
+parser.add_argument("--rm_special", type=int, default=0, help="remove <\|.*?\|>")
+
+
+def main(args):
+ uttid2refs = read_uttid2tokens(args.ref, args.do_tn, args.rm_special)
+ uttid2hyps = read_uttid2tokens(args.hyp, args.do_tn, args.rm_special)
+ uttid2wer_info, wer_stat, en_dig_stat = compute_uttid2wer_info(
+ uttid2refs, uttid2hyps, args.print_sentence_wer)
+ wer_stat.print()
+ en_dig_stat.print()
+
+
+def read_uttid2tokens(filename, do_tn=False, rm_special=False):
+ print(f">>> Read uttid to tokens: {filename}", flush=True)
+ uttid2tokens = OrderedDict()
+ uttid2text = read_uttid2text(filename, do_tn, rm_special)
+ for uttid, text in uttid2text.items():
+ tokens = text2tokens(text)
+ uttid2tokens[uttid] = tokens
+ return uttid2tokens
+
+
+def read_uttid2text(filename, do_tn=False, rm_special=False):
+ uttid2text = OrderedDict()
+ with open(filename, "r", encoding="utf8") as fin:
+ for i, line in enumerate(fin):
+ cols = line.split()
+ if len(cols) == 0:
+ print("[WARN] empty line, continue", i, flush=True)
+ continue
+ assert cols[0] not in uttid2text, f"repeated uttid: {line}"
+ if len(cols) == 1:
+ uttid2text[cols[0]] = ""
+ continue
+ txt = " ".join(cols[1:])
+ if rm_special:
+ txt = " ".join([t for t in re.split("<\|.*?\|>", txt) if t.strip() != ""])
+ if do_tn:
+ import cn2an
+ txt = cn2an.transform(txt, "an2cn")
+ uttid2text[cols[0]] = txt
+ return uttid2text
+
+
+def text2tokens(text):
+ PUNCTUATIONS = ",。?!,\.?!"#$%&'()*+-/:;<=>@[\]^_`{|}~⦅⦆「」、 、〃〈〉《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏﹑﹔·。\":" + "()\[\]{}/;`|=+"
+ if text == "":
+ return []
+ tokens = []
+
+ text = re.sub("", "", text)
+ text = re.sub(r"[%s]+" % PUNCTUATIONS, " ", text)
+ text = re.sub("<.*>", "", text)
+ text = fix_abbr_simple(text)
+ #pattern = re.compile(r'([\u4e00-\u9fff])')
+ pattern = re.compile(r'([\u4e00-\u9fff\u3040-\u309f\u30a0-\u30ff\u31f0-\u31ff])')
+ parts = pattern.split(text.strip().upper())
+ parts = [p for p in parts if len(p.strip()) > 0]
+ for part in parts:
+ if pattern.fullmatch(part) is not None:
+ tokens.append(part)
+ else:
+ for word in part.strip().split():
+ tokens.append(word)
+ return tokens
+
+
+def fix_abbr_simple(text):
+    # NOTE: the original body was garbled during extraction (its comment
+    # 扔掉超长的, "drop overly long ones", suggests a length guard was lost).
+    # Minimal reconstruction: merge spaced single-letter abbreviations
+    # ("A B C" -> "ABC") and log any change.
+    ori_text = text
+    text = re.sub(r"\b((?:[A-Z] )+[A-Z])\b", lambda m: m.group(1).replace(" ", ""), text)
+    if text != ori_text:
+        print(f"[fix_abbr] '{ori_text}' -> '{text}'")
+    return text
+
+
+def compute_uttid2wer_info(refs, hyps, print_sentence_wer=False):
+ print(f">>> Compute uttid to wer info", flush=True)
+
+ uttid2wer_info = OrderedDict()
+ wer_stat = WerStats()
+ en_dig_stat = EnDigStats()
+
+ for uttid, ref in refs.items():
+ if uttid not in hyps:
+ print(f"[WARN] No hyp for {uttid}", flush=True)
+ continue
+ hyp = hyps[uttid]
+
+ if len(hyp) - len(ref) >= 8:
+ print(f"[BidLengthDiff]: {uttid} {len(ref)} {len(hyp)}#{' '.join(ref)}#{' '.join(hyp)}")
+ #continue
+
+ wer_info = compute_one_wer_info(ref, hyp)
+ uttid2wer_info[uttid] = wer_info
+        ns = count_english_digit(ref, hyp, wer_info)
+ wer_stat.add(wer_info)
+ en_dig_stat.add(*ns)
+ if print_sentence_wer:
+ print(f"{uttid} {wer_info}")
+
+ return uttid2wer_info, wer_stat, en_dig_stat
+
+
+COST_SUB = 3
+COST_DEL = 3
+COST_INS = 3
+
+ALIGN_CRT = 0
+ALIGN_SUB = 1
+ALIGN_DEL = 2
+ALIGN_INS = 3
+ALIGN_END = 4
+
+
+def compute_one_wer_info(ref, hyp):
+ """Impl minimum edit distance and backtrace.
+ Args:
+ ref, hyp: List[str]
+ Returns:
+ WerInfo
+ """
+ ref_len = len(ref)
+ hyp_len = len(hyp)
+
+ class _DpPoint:
+ def __init__(self, cost, align):
+ self.cost = cost
+ self.align = align
+
+ dp = []
+ for i in range(0, ref_len + 1):
+ dp.append([])
+ for j in range(0, hyp_len + 1):
+ dp[-1].append(_DpPoint(i * j, ALIGN_CRT))
+
+ # Initialize
+ for i in range(1, hyp_len + 1):
+        dp[0][i].cost = dp[0][i - 1].cost + COST_INS
+ dp[0][i].align = ALIGN_INS
+ for i in range(1, ref_len + 1):
+ dp[i][0].cost = dp[i - 1][0].cost + COST_DEL
+ dp[i][0].align = ALIGN_DEL
+
+ # DP
+ for i in range(1, ref_len + 1):
+ for j in range(1, hyp_len + 1):
+ min_cost = 0
+ min_align = ALIGN_CRT
+ if hyp[j - 1] == ref[i - 1]:
+ min_cost = dp[i - 1][j - 1].cost
+ min_align = ALIGN_CRT
+ else:
+ min_cost = dp[i - 1][j - 1].cost + COST_SUB
+ min_align = ALIGN_SUB
+
+ del_cost = dp[i - 1][j].cost + COST_DEL
+ if del_cost < min_cost:
+ min_cost = del_cost
+ min_align = ALIGN_DEL
+
+ ins_cost = dp[i][j - 1].cost + COST_INS
+ if ins_cost < min_cost:
+ min_cost = ins_cost
+ min_align = ALIGN_INS
+
+ dp[i][j].cost = min_cost
+ dp[i][j].align = min_align
+
+ # Backtrace
+ crt = sub = ins = det = 0
+ i = ref_len
+ j = hyp_len
+ align = []
+ while i > 0 or j > 0:
+ if dp[i][j].align == ALIGN_CRT:
+ align.append((i, j, ALIGN_CRT))
+ i -= 1
+ j -= 1
+ crt += 1
+ elif dp[i][j].align == ALIGN_SUB:
+ align.append((i, j, ALIGN_SUB))
+ i -= 1
+ j -= 1
+ sub += 1
+ elif dp[i][j].align == ALIGN_DEL:
+ align.append((i, j, ALIGN_DEL))
+ i -= 1
+ det += 1
+ elif dp[i][j].align == ALIGN_INS:
+ align.append((i, j, ALIGN_INS))
+ j -= 1
+ ins += 1
+
+ err = sub + det + ins
+ align.reverse()
+ wer_info = WerInfo(ref_len, err, crt, sub, det, ins, align)
+ return wer_info
+
+
+
+class WerInfo:
+ def __init__(self, ref, err, crt, sub, dele, ins, ali):
+ self.r = ref
+ self.e = err
+ self.c = crt
+ self.s = sub
+ self.d = dele
+ self.i = ins
+ self.ali = ali
+ r = max(self.r, 1)
+ self.wer = 100.0 * (self.s + self.d + self.i) / r
+
+ def __repr__(self):
+ s = f"wer {self.wer:.2f} ref {self.r:2d} sub {self.s:2d} del {self.d:2d} ins {self.i:2d}"
+ return s
+
+
+class WerStats:
+ def __init__(self):
+ self.infos = []
+
+ def add(self, wer_info):
+ self.infos.append(wer_info)
+
+ def print(self):
+ r = sum(info.r for info in self.infos)
+ if r <= 0:
+ print(f"REF len is {r}, check")
+ r = 1
+ s = sum(info.s for info in self.infos)
+ d = sum(info.d for info in self.infos)
+ i = sum(info.i for info in self.infos)
+ se = 100.0 * s / r
+ de = 100.0 * d / r
+ ie = 100.0 * i / r
+ wer = 100.0 * (s + d + i) / r
+ sen = max(len(self.infos), 1)
+ errsen = sum(info.e > 0 for info in self.infos)
+ ser = 100.0 * errsen / sen
+ print("-"*80)
+ print(f"ref{r:6d} sub{s:6d} del{d:6d} ins{i:6d}")
+ print(f"WER{wer:6.2f} sub{se:6.2f} del{de:6.2f} ins{ie:6.2f}")
+ print(f"SER{ser:6.2f} = {errsen} / {sen}")
+ print("-"*80)
+
+
+class EnDigStats:
+ def __init__(self):
+ self.n_en_word = 0
+ self.n_en_correct = 0
+ self.n_dig_word = 0
+ self.n_dig_correct = 0
+
+ def add(self, n_en_word, n_en_correct, n_dig_word, n_dig_correct):
+ self.n_en_word += n_en_word
+ self.n_en_correct += n_en_correct
+ self.n_dig_word += n_dig_word
+ self.n_dig_correct += n_dig_correct
+
+ def print(self):
+ print(f"English #word={self.n_en_word}, #correct={self.n_en_correct}\n"
+ f"Digit #word={self.n_dig_word}, #correct={self.n_dig_correct}")
+ print("-"*80)
+
+
+
+def count_english_digit(ref, hyp, wer_info):
+ patt_en = "[a-zA-Z\.\-\']+"
+ patt_dig = "[0-9]+"
+ patt_cjk = re.compile(r'([\u4e00-\u9fff])')
+ n_en_word = 0
+ n_en_correct = 0
+ n_dig_word = 0
+ n_dig_correct = 0
+ ali = wer_info.ali
+ for i, token in enumerate(ref):
+ if re.match(patt_en, token):
+ n_en_word += 1
+ for y in ali:
+ if y[0] == i+1 and y[2] == ALIGN_CRT:
+ j = y[1] - 1
+ n_en_correct += 1
+ break
+ if re.match(patt_dig, token):
+ n_dig_word += 1
+ for y in ali:
+ if y[0] == i+1 and y[2] == ALIGN_CRT:
+ j = y[1] - 1
+ n_dig_correct += 1
+ break
+ if not re.match(patt_cjk, token) and not re.match(patt_en, token) \
+ and not re.match(patt_dig, token):
+ print("[WiredChar]:", token)
+ return n_en_word, n_en_correct, n_dig_word, n_dig_correct
+
+
+
+if __name__ == "__main__":
+ args = parser.parse_args()
+ print(args, flush=True)
+ main(args)
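
Since `compute_one_wer_info` weights substitutions, deletions, and insertions equally (costs 3/3/3), a quick check of the alignment and the reported rates:

```python
from fireredasr2.utils.wer import compute_one_wer_info

ref = "i like speech recognition".split()
hyp = "i like beach recognition a lot".split()
print(compute_one_wer_info(ref, hyp))
# wer 75.00 ref  4 sub  1 del  0 ins  2
```
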
diff --git a/fireredasr2s/fireredasr2s-cli b/fireredasr2s/fireredasr2s-cli
new file mode 100644
index 0000000000000000000000000000000000000000..1d97813f8170c022c1e02078e17a4ac1127e892d
--- /dev/null
+++ b/fireredasr2s/fireredasr2s-cli
@@ -0,0 +1,273 @@
+#!/usr/bin/env python3
+
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Kai Huang, Yan Jia, Junjie Chen, Wenpeng Li)
+
+import argparse
+import glob
+import json
+import logging
+import os
+
+import soundfile as sf
+from textgrid import IntervalTier, TextGrid
+
+from fireredasr2s.fireredasr2 import FireRedAsr2Config
+from fireredasr2s.fireredasr2system import (FireRedAsr2System,
+ FireRedAsr2SystemConfig)
+from fireredasr2s.fireredlid import FireRedLidConfig
+from fireredasr2s.fireredpunc import FireRedPuncConfig
+from fireredasr2s.fireredvad import FireRedVadConfig
+
+logging.basicConfig(level=logging.INFO,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
+logger = logging.getLogger("fireredasr2s.asr_system")
+
+
+parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+input_g = parser.add_argument_group("Input Options")
+input_g.add_argument("--wav_path", type=str)
+input_g.add_argument("--wav_paths", type=str, nargs="*")
+input_g.add_argument("--wav_dir", type=str)
+input_g.add_argument("--wav_scp", type=str)
+input_g.add_argument("--sort_wav_by_dur", type=int, default=0)
+
+output_g = parser.add_argument_group("Output Options")
+output_g.add_argument("--outdir", type=str, default="output")
+output_g.add_argument("--write_textgrid", type=int, default=1)
+output_g.add_argument("--write_srt", type=int, default=1)
+output_g.add_argument("--save_segment", type=int, default=0)
+
+module_g = parser.add_argument_group("Module Switches")
+module_g.add_argument('--enable_vad', type=int, default=1, choices=[0, 1])
+module_g.add_argument('--enable_lid', type=int, default=1, choices=[0, 1])
+module_g.add_argument('--enable_punc', type=int, default=1, choices=[0, 1])
+
+asr_g = parser.add_argument_group("ASR Options")
+asr_g.add_argument('--asr_type', type=str, default="aed", choices=["aed", "llm"])
+asr_g.add_argument('--asr_model_dir', type=str, default="pretrained_models/FireRedASR2-AED")
+asr_g.add_argument('--asr_use_gpu', type=int, default=1)
+asr_g.add_argument('--asr_use_half', type=int, default=0)
+asr_g.add_argument("--asr_batch_size", type=int, default=1)
+# FireRedASR-AED
+asr_g.add_argument("--beam_size", type=int, default=3)
+asr_g.add_argument("--decode_max_len", type=int, default=0)
+asr_g.add_argument("--nbest", type=int, default=1)
+asr_g.add_argument("--softmax_smoothing", type=float, default=1.25)
+asr_g.add_argument("--aed_length_penalty", type=float, default=0.6)
+asr_g.add_argument("--eos_penalty", type=float, default=1.0)
+asr_g.add_argument("--return_timestamp", type=int, default=1)
+# FireRedASR-AED External LM
+asr_g.add_argument("--elm_dir", type=str, default="")
+asr_g.add_argument("--elm_weight", type=float, default=0.0)
+
+vad_g = parser.add_argument_group("VAD Options")
+vad_g.add_argument('--vad_model_dir', type=str, default="pretrained_models/FireRedVAD/VAD")
+vad_g.add_argument('--vad_use_gpu', type=int, default=1)
+# Non-streaming VAD
+vad_g.add_argument("--vad_chunk_max_frame", type=int, default=30000)
+vad_g.add_argument("--smooth_window_size", type=int, default=5)
+vad_g.add_argument("--speech_threshold", type=float, default=0.2)
+vad_g.add_argument("--min_speech_frame", type=int, default=20)
+vad_g.add_argument("--max_speech_frame", type=int, default=1000)
+vad_g.add_argument("--min_silence_frame", type=int, default=10)
+vad_g.add_argument("--merge_silence_frame", type=int, default=50)
+vad_g.add_argument("--extend_speech_frame", type=int, default=10)
+
+lid_g = parser.add_argument_group("LID Options")
+lid_g.add_argument('--lid_model_dir', type=str, default="pretrained_models/FireRedLID")
+lid_g.add_argument('--lid_use_gpu', type=int, default=1)
+
+punc_g = parser.add_argument_group("Punc Options")
+punc_g.add_argument('--punc_model_dir', type=str, default="pretrained_models/FireRedPunc")
+punc_g.add_argument('--punc_use_gpu', type=int, default=1)
+punc_g.add_argument("--punc_batch_size", type=int, default=1)
+punc_g.add_argument('--punc_with_timestamp', type=int, default=1)
+punc_g.add_argument('--punc_sentence_max_length', type=int, default=-1)
+
+
+def main(args):
+ wavs = get_wav_info(args)
+ if args.outdir:
+ os.makedirs(args.outdir, exist_ok=True)
+ fout = open(args.outdir + "/result.jsonl", "w") if args.outdir else None
+
+ # Build Models
+ # VAD
+ vad_config = FireRedVadConfig(
+ args.vad_use_gpu,
+ args.smooth_window_size,
+ args.speech_threshold,
+ args.min_speech_frame,
+ args.max_speech_frame,
+ args.min_silence_frame,
+ args.merge_silence_frame,
+ args.extend_speech_frame,
+ args.vad_chunk_max_frame
+ )
+ # LID
+ lid_config = FireRedLidConfig(args.lid_use_gpu)
+ # ASR
+ asr_config = FireRedAsr2Config(
+ args.asr_use_gpu,
+ args.asr_use_half,
+ args.beam_size,
+ args.nbest,
+ args.decode_max_len,
+ args.softmax_smoothing,
+ args.aed_length_penalty,
+ args.eos_penalty,
+ args.return_timestamp,
+ 0, 1.0, 0.0, 1.0,
+ args.elm_dir,
+ args.elm_weight
+ )
+ # Punc
+ punc_config = FireRedPuncConfig(
+ args.punc_use_gpu,
+ args.punc_sentence_max_length
+ )
+
+ asr_system_config = FireRedAsr2SystemConfig(
+ args.vad_model_dir, args.lid_model_dir,
+ args.asr_type, args.asr_model_dir, args.punc_model_dir,
+ vad_config, lid_config, asr_config, punc_config,
+ args.asr_batch_size, args.punc_batch_size,
+ args.enable_vad, args.enable_lid, args.enable_punc
+ )
+ asr_system = FireRedAsr2System(asr_system_config)
+
+ for i, (uttid, wav_path) in enumerate(wavs):
+ logger.info("")
+
+ result = asr_system.process(wav_path, uttid)
+
+ logger.info(f"FINAL: {result}")
+
+ if fout:
+ fout.write(f"{json.dumps(result, ensure_ascii=False)}\n")
+ fout.flush()
+ name = os.path.basename(wav_path).replace(".wav", "")
+ if args.write_textgrid:
+ tg_dir = os.path.join(args.outdir, "asr_tg")
+ write_textgrid(tg_dir, name, result["dur_s"], result["sentences"], result["words"])
+ if args.write_srt:
+ srt_dir = os.path.join(args.outdir, "asr_srt")
+ write_srt(srt_dir, name, result["sentences"])
+ if args.save_segment:
+ save_segment_dir = os.path.join(args.outdir, "vad_segment")
+ split_and_save_segment(wav_path, result["vad_segments_ms"], save_segment_dir)
+
+ if fout:
+ fout.close()
+ logger.info("All Done")
+
+
+def get_wav_info(args):
+ """
+ Returns:
+ wavs: list of (uttid, wav_path)
+ """
+ def base(p): return os.path.basename(p).replace(".wav", "")
+ if args.wav_path:
+ wavs = [(base(args.wav_path), args.wav_path)]
+ elif args.wav_paths and len(args.wav_paths) >= 1:
+ wavs = [(base(p), p) for p in sorted(args.wav_paths)]
+ elif args.wav_scp:
+ wavs = [line.strip().split() for line in open(args.wav_scp)]
+ elif args.wav_dir:
+ wavs = glob.glob(f"{args.wav_dir}/**/*.wav", recursive=True)
+ wavs = [(base(p), p) for p in sorted(wavs)]
+ else:
+ raise ValueError("Please provide valid wav info")
+ logger.info(f"#wavs={len(wavs)}")
+ return wavs
+
+
+def write_textgrid(tg_dir, name, wav_dur, sentences, words=None):
+ os.makedirs(tg_dir, exist_ok=True)
+ textgrid_file = os.path.join(tg_dir, name + ".TextGrid")
+ logger.info(f"Write {textgrid_file}")
+ textgrid = TextGrid(maxTime=wav_dur)
+
+ tier = IntervalTier(name="sentence", maxTime=wav_dur)
+ for sentence in sentences:
+ start_s = sentence["start_ms"] / 1000.0
+ end_s = sentence["end_ms"] / 1000.0
+ text = sentence["text"]
+ confi = sentence["asr_confidence"]
+ if start_s == end_s:
+ logger.info(f"(sent) Write TG, skip start=end {start_s} {text}")
+ continue
+ start_s = max(start_s, 0)
+ end_s = min(end_s, wav_dur)
+ tier.add(minTime=start_s, maxTime=end_s, mark=f"{text}\n{confi}")
+ textgrid.append(tier)
+
+ if words:
+ tier = IntervalTier(name="token", maxTime=wav_dur)
+ for word in words:
+ start_s = word["start_ms"] / 1000.0
+ end_s = word["end_ms"] / 1000.0
+ text = word["text"]
+ if start_s == end_s:
+ logger.info(f"(word) Write TG, skip start=end {start_s} {text}")
+ continue
+ start_s = max(start_s, 0)
+ end_s = min(end_s, wav_dur)
+ tier.add(minTime=start_s, maxTime=end_s, mark=text)
+ textgrid.append(tier)
+ textgrid.write(textgrid_file)
+
+
+def write_srt(srt_dir, name, sentences):
+ def _ms2srt_time(ms):
+ h = ms // 1000 // 3600
+ m = (ms // 1000 % 3600) // 60
+ s = (ms // 1000 % 3600) % 60
+ ms = (ms % 1000)
+ r = f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
+ return r
+ os.makedirs(srt_dir, exist_ok=True)
+ srt_file = os.path.join(srt_dir, name + ".srt")
+ logger.info(f"Write {srt_file}")
+
+ i = 0
+ with open(srt_file, "w") as fout:
+ for sentence in sentences:
+ start_ms = sentence["start_ms"]
+ end_ms = sentence["end_ms"]
+ text = sentence["text"]
+ if text.strip() == "":
+ continue
+
+ i += 1
+ fout.write(f"{i}\n")
+ s = _ms2srt_time(start_ms)
+ e = _ms2srt_time(end_ms)
+ fout.write(f"{s} --> {e}\n")
+ fout.write(f"{text}\n")
+ if i != len(sentences):
+ fout.write("\n")
+
+
+def split_and_save_segment(wav_path, timestamps_ms, save_segment_dir):
+ logger.info("Split & save segment")
+ os.makedirs(save_segment_dir, exist_ok=True)
+ wav_np, sample_rate = sf.read(wav_path, dtype="int16")
+ for i, (start_ms, end_ms) in enumerate(timestamps_ms):
+ uttid = wav_path.split("/")[-1].replace(".wav", "")
+ seg_id = f"{uttid}_{i}_{start_ms}_{end_ms}"
+ seg_path = f"{save_segment_dir}/{seg_id}.wav"
+ start = int(start_ms / 1000 * sample_rate)
+ end = int(end_ms / 1000 * sample_rate)
+ sf.write(seg_path, wav_np[start:end], samplerate=sample_rate)
+
+
+def cli_main():
+ args = parser.parse_args()
+ logger.info(args)
+ main(args)
+
+
+if __name__ == "__main__":
+ cli_main()
diff --git a/fireredasr2s/fireredasr2s_cli.py b/fireredasr2s/fireredasr2s_cli.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d97813f8170c022c1e02078e17a4ac1127e892d
--- /dev/null
+++ b/fireredasr2s/fireredasr2s_cli.py
@@ -0,0 +1,273 @@
+#!/usr/bin/env python3
+
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Kai Huang, Yan Jia, Junjie Chen, Wenpeng Li)
+
+import argparse
+import glob
+import json
+import logging
+import os
+
+import soundfile as sf
+from textgrid import IntervalTier, TextGrid
+
+from fireredasr2s.fireredasr2 import FireRedAsr2Config
+from fireredasr2s.fireredasr2system import (FireRedAsr2System,
+ FireRedAsr2SystemConfig)
+from fireredasr2s.fireredlid import FireRedLidConfig
+from fireredasr2s.fireredpunc import FireRedPuncConfig
+from fireredasr2s.fireredvad import FireRedVadConfig
+
+logging.basicConfig(level=logging.INFO,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
+logger = logging.getLogger("fireredasr2s.asr_system")
+
+
+parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+input_g = parser.add_argument_group("Input Options")
+input_g.add_argument("--wav_path", type=str)
+input_g.add_argument("--wav_paths", type=str, nargs="*")
+input_g.add_argument("--wav_dir", type=str)
+input_g.add_argument("--wav_scp", type=str)
+input_g.add_argument("--sort_wav_by_dur", type=int, default=0)
+
+output_g = parser.add_argument_group("Output Options")
+output_g.add_argument("--outdir", type=str, default="output")
+output_g.add_argument("--write_textgrid", type=int, default=1)
+output_g.add_argument("--write_srt", type=int, default=1)
+output_g.add_argument("--save_segment", type=int, default=0)
+
+module_g = parser.add_argument_group("Module Switches")
+module_g.add_argument('--enable_vad', type=int, default=1, choices=[0, 1])
+module_g.add_argument('--enable_lid', type=int, default=1, choices=[0, 1])
+module_g.add_argument('--enable_punc', type=int, default=1, choices=[0, 1])
+
+asr_g = parser.add_argument_group("ASR Options")
+asr_g.add_argument('--asr_type', type=str, default="aed", choices=["aed", "llm"])
+asr_g.add_argument('--asr_model_dir', type=str, default="pretrained_models/FireRedASR2-AED")
+asr_g.add_argument('--asr_use_gpu', type=int, default=1)
+asr_g.add_argument('--asr_use_half', type=int, default=0)
+asr_g.add_argument("--asr_batch_size", type=int, default=1)
+# FireRedASR2-AED
+asr_g.add_argument("--beam_size", type=int, default=3)
+asr_g.add_argument("--decode_max_len", type=int, default=0)
+asr_g.add_argument("--nbest", type=int, default=1)
+asr_g.add_argument("--softmax_smoothing", type=float, default=1.25)
+asr_g.add_argument("--aed_length_penalty", type=float, default=0.6)
+asr_g.add_argument("--eos_penalty", type=float, default=1.0)
+asr_g.add_argument("--return_timestamp", type=int, default=1)
+# FireRedASR2-AED External LM
+asr_g.add_argument("--elm_dir", type=str, default="")
+asr_g.add_argument("--elm_weight", type=float, default=0.0)
+
+vad_g = parser.add_argument_group("VAD Options")
+vad_g.add_argument('--vad_model_dir', type=str, default="pretrained_models/FireRedVAD/VAD")
+vad_g.add_argument('--vad_use_gpu', type=int, default=1)
+# Non-streaming VAD
+vad_g.add_argument("--vad_chunk_max_frame", type=int, default=30000)
+vad_g.add_argument("--smooth_window_size", type=int, default=5)
+vad_g.add_argument("--speech_threshold", type=float, default=0.2)
+vad_g.add_argument("--min_speech_frame", type=int, default=20)
+vad_g.add_argument("--max_speech_frame", type=int, default=1000)
+vad_g.add_argument("--min_silence_frame", type=int, default=10)
+vad_g.add_argument("--merge_silence_frame", type=int, default=50)
+vad_g.add_argument("--extend_speech_frame", type=int, default=10)
+
+lid_g = parser.add_argument_group("LID Options")
+lid_g.add_argument('--lid_model_dir', type=str, default="pretrained_models/FireRedLID")
+lid_g.add_argument('--lid_use_gpu', type=int, default=1)
+
+punc_g = parser.add_argument_group("Punc Options")
+punc_g.add_argument('--punc_model_dir', type=str, default="pretrained_models/FireRedPunc")
+punc_g.add_argument('--punc_use_gpu', type=int, default=1)
+punc_g.add_argument("--punc_batch_size", type=int, default=1)
+punc_g.add_argument('--punc_with_timestamp', type=int, default=1)
+punc_g.add_argument('--punc_sentence_max_length', type=int, default=-1)
+
+
+def main(args):
+ wavs = get_wav_info(args)
+ if args.outdir:
+ os.makedirs(args.outdir, exist_ok=True)
+    fout = open(os.path.join(args.outdir, "result.jsonl"), "w") if args.outdir else None
+
+ # Build Models
+ # VAD
+ vad_config = FireRedVadConfig(
+ args.vad_use_gpu,
+ args.smooth_window_size,
+ args.speech_threshold,
+ args.min_speech_frame,
+ args.max_speech_frame,
+ args.min_silence_frame,
+ args.merge_silence_frame,
+ args.extend_speech_frame,
+ args.vad_chunk_max_frame
+ )
+ # LID
+ lid_config = FireRedLidConfig(args.lid_use_gpu)
+ # ASR
+ asr_config = FireRedAsr2Config(
+ args.asr_use_gpu,
+ args.asr_use_half,
+ args.beam_size,
+ args.nbest,
+ args.decode_max_len,
+ args.softmax_smoothing,
+ args.aed_length_penalty,
+ args.eos_penalty,
+ args.return_timestamp,
+        # Positional LLM decode options; presumably decode_min_len,
+        # repetition_penalty, llm_length_penalty, temperature
+        0, 1.0, 0.0, 1.0,
+ args.elm_dir,
+ args.elm_weight
+ )
+ # Punc
+ punc_config = FireRedPuncConfig(
+ args.punc_use_gpu,
+ args.punc_sentence_max_length
+ )
+
+ asr_system_config = FireRedAsr2SystemConfig(
+ args.vad_model_dir, args.lid_model_dir,
+ args.asr_type, args.asr_model_dir, args.punc_model_dir,
+ vad_config, lid_config, asr_config, punc_config,
+ args.asr_batch_size, args.punc_batch_size,
+ args.enable_vad, args.enable_lid, args.enable_punc
+ )
+ asr_system = FireRedAsr2System(asr_system_config)
+
+    for uttid, wav_path in wavs:
+        logger.info("")  # blank log line separates per-utterance output
+
+ result = asr_system.process(wav_path, uttid)
+
+ logger.info(f"FINAL: {result}")
+
+ if fout:
+ fout.write(f"{json.dumps(result, ensure_ascii=False)}\n")
+ fout.flush()
+ name = os.path.basename(wav_path).replace(".wav", "")
+ if args.write_textgrid:
+ tg_dir = os.path.join(args.outdir, "asr_tg")
+ write_textgrid(tg_dir, name, result["dur_s"], result["sentences"], result["words"])
+ if args.write_srt:
+ srt_dir = os.path.join(args.outdir, "asr_srt")
+ write_srt(srt_dir, name, result["sentences"])
+ if args.save_segment:
+ save_segment_dir = os.path.join(args.outdir, "vad_segment")
+ split_and_save_segment(wav_path, result["vad_segments_ms"], save_segment_dir)
+
+ if fout:
+ fout.close()
+ logger.info("All Done")
+
+
+def get_wav_info(args):
+ """
+ Returns:
+ wavs: list of (uttid, wav_path)
+ """
+ def base(p): return os.path.basename(p).replace(".wav", "")
+ if args.wav_path:
+ wavs = [(base(args.wav_path), args.wav_path)]
+ elif args.wav_paths and len(args.wav_paths) >= 1:
+ wavs = [(base(p), p) for p in sorted(args.wav_paths)]
+ elif args.wav_scp:
+        # Kaldi-style wav.scp: "<uttid> <wav_path>" per line
+        wavs = [line.strip().split(maxsplit=1) for line in open(args.wav_scp)]
+ elif args.wav_dir:
+ wavs = glob.glob(f"{args.wav_dir}/**/*.wav", recursive=True)
+ wavs = [(base(p), p) for p in sorted(wavs)]
+ else:
+ raise ValueError("Please provide valid wav info")
+ logger.info(f"#wavs={len(wavs)}")
+ return wavs
+
+
+def write_textgrid(tg_dir, name, wav_dur, sentences, words=None):
+ os.makedirs(tg_dir, exist_ok=True)
+ textgrid_file = os.path.join(tg_dir, name + ".TextGrid")
+ logger.info(f"Write {textgrid_file}")
+ textgrid = TextGrid(maxTime=wav_dur)
+
+ tier = IntervalTier(name="sentence", maxTime=wav_dur)
+ for sentence in sentences:
+ start_s = sentence["start_ms"] / 1000.0
+ end_s = sentence["end_ms"] / 1000.0
+ text = sentence["text"]
+ confi = sentence["asr_confidence"]
+ if start_s == end_s:
+ logger.info(f"(sent) Write TG, skip start=end {start_s} {text}")
+ continue
+ start_s = max(start_s, 0)
+ end_s = min(end_s, wav_dur)
+ tier.add(minTime=start_s, maxTime=end_s, mark=f"{text}\n{confi}")
+ textgrid.append(tier)
+
+ if words:
+ tier = IntervalTier(name="token", maxTime=wav_dur)
+ for word in words:
+ start_s = word["start_ms"] / 1000.0
+ end_s = word["end_ms"] / 1000.0
+ text = word["text"]
+ if start_s == end_s:
+ logger.info(f"(word) Write TG, skip start=end {start_s} {text}")
+ continue
+ start_s = max(start_s, 0)
+ end_s = min(end_s, wav_dur)
+ tier.add(minTime=start_s, maxTime=end_s, mark=text)
+ textgrid.append(tier)
+ textgrid.write(textgrid_file)
+
+
+def write_srt(srt_dir, name, sentences):
+ def _ms2srt_time(ms):
+        # Convert milliseconds to the SRT "HH:MM:SS,mmm" format
+        total_s, msec = divmod(ms, 1000)
+        h, rem = divmod(total_s, 3600)
+        m, s = divmod(rem, 60)
+        return f"{h:02d}:{m:02d}:{s:02d},{msec:03d}"
+ os.makedirs(srt_dir, exist_ok=True)
+ srt_file = os.path.join(srt_dir, name + ".srt")
+ logger.info(f"Write {srt_file}")
+
+ i = 0
+ with open(srt_file, "w") as fout:
+ for sentence in sentences:
+ start_ms = sentence["start_ms"]
+ end_ms = sentence["end_ms"]
+ text = sentence["text"]
+ if text.strip() == "":
+ continue
+
+ i += 1
+ fout.write(f"{i}\n")
+ s = _ms2srt_time(start_ms)
+ e = _ms2srt_time(end_ms)
+ fout.write(f"{s} --> {e}\n")
+ fout.write(f"{text}\n")
+            # SRT cues are separated by a blank line; a trailing one is valid
+            fout.write("\n")
+
+
+def split_and_save_segment(wav_path, timestamps_ms, save_segment_dir):
+ logger.info("Split & save segment")
+ os.makedirs(save_segment_dir, exist_ok=True)
+ wav_np, sample_rate = sf.read(wav_path, dtype="int16")
+    uttid = os.path.basename(wav_path).replace(".wav", "")
+    for i, (start_ms, end_ms) in enumerate(timestamps_ms):
+        seg_id = f"{uttid}_{i}_{start_ms}_{end_ms}"
+        seg_path = os.path.join(save_segment_dir, f"{seg_id}.wav")
+ start = int(start_ms / 1000 * sample_rate)
+ end = int(end_ms / 1000 * sample_rate)
+ sf.write(seg_path, wav_np[start:end], samplerate=sample_rate)
+
+
+def cli_main():
+ args = parser.parse_args()
+ logger.info(args)
+ main(args)
+
+
+if __name__ == "__main__":
+ cli_main()
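
A minimal sketch of driving the CLI above programmatically, reusing its module-level `parser` and `main`. The wav path is a hypothetical example; the `pretrained_models/*` defaults from the argument groups are assumed to be downloaded already.

```python
# Sketch: run the all-in-one pipeline on a single wav and write
# result.jsonl, TextGrid and SRT files under ./output.
from fireredasr2s.fireredasr2s_cli import parser, main

args = parser.parse_args([
    "--wav_path", "examples/test.wav",  # hypothetical input
    "--outdir", "output",
    "--asr_type", "aed",
])
main(args)
```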
diff --git a/fireredasr2s/fireredasr2system.py b/fireredasr2s/fireredasr2system.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fa09dda1c1b0d3b145280a46371ffda222e5f8c
--- /dev/null
+++ b/fireredasr2s/fireredasr2system.py
@@ -0,0 +1,200 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Kai Huang, Yan Jia, Junjie Chen, Wenpeng Li)
+
+import logging
+import re
+from dataclasses import dataclass, field
+
+import soundfile as sf
+
+from fireredasr2s.fireredasr2 import FireRedAsr2, FireRedAsr2Config
+from fireredasr2s.fireredlid import FireRedLid, FireRedLidConfig
+from fireredasr2s.fireredpunc import FireRedPunc, FireRedPuncConfig
+from fireredasr2s.fireredvad import FireRedVad, FireRedVadConfig
+
+logging.basicConfig(level=logging.INFO,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
+logger = logging.getLogger("fireredasr2s.asr_system")
+
+
+@dataclass
+class FireRedAsr2SystemConfig:
+ vad_model_dir: str = "pretrained_models/FireRedVAD/VAD"
+ lid_model_dir: str = "pretrained_models/FireRedLID"
+ asr_type: str = "aed"
+ asr_model_dir: str = "pretrained_models/FireRedASR2-AED"
+ punc_model_dir: str = "pretrained_models/FireRedPunc"
+ vad_config: FireRedVadConfig = field(default_factory=FireRedVadConfig)
+ lid_config: FireRedLidConfig = field(default_factory=FireRedLidConfig)
+ asr_config: FireRedAsr2Config = field(default_factory=FireRedAsr2Config)
+ punc_config: FireRedPuncConfig = field(default_factory=FireRedPuncConfig)
+ asr_batch_size: int = 1
+ punc_batch_size: int = 1
+ enable_vad: bool = True
+ enable_lid: bool = True
+ enable_punc: bool = True
+
+
+class FireRedAsr2System:
+ def __init__(self, config):
+ c = config
+ self.vad = FireRedVad.from_pretrained(c.vad_model_dir, c.vad_config) if c.enable_vad else None
+ self.lid = FireRedLid.from_pretrained(c.lid_model_dir, c.lid_config) if c.enable_lid else None
+ self.asr = FireRedAsr2.from_pretrained(c.asr_type, c.asr_model_dir, c.asr_config)
+ self.punc = FireRedPunc.from_pretrained(c.punc_model_dir, c.punc_config) if c.enable_punc else None
+ self.config = config
+
+ def process(self, wav_path, uttid="tmpid"):
+ wav_np, sample_rate = sf.read(wav_path, dtype="int16")
+        dur = wav_np.shape[0] / sample_rate
+
+ # 1. VAD
+ if self.config.enable_vad:
+ vad_result, prob = self.vad.detect(wav_path)
+ vad_segments = vad_result["timestamps"]
+ logger.info(f"VAD: {vad_result}")
+ else:
+ vad_segments = [(0, dur)]
+ vad_result = {"timestamps" : vad_segments}
+
+ # 2. VAD output to ASR input
+ asr_results = []
+ lid_results = []
+        assert sample_rate == 16000, f"Expected 16 kHz audio, got {sample_rate} Hz"
+ batch_asr_uttid = []
+ batch_asr_wav = []
+ for j, (start_s, end_s) in enumerate(vad_segments):
+ wav_segment = wav_np[int(start_s*sample_rate):int(end_s*sample_rate)]
+ vad_uttid = f"{uttid}_s{int(start_s*1000)}_e{int(end_s*1000)}"
+ batch_asr_uttid.append(vad_uttid)
+ batch_asr_wav.append((sample_rate, wav_segment))
+ if len(batch_asr_uttid) < self.config.asr_batch_size and j != len(vad_segments) - 1:
+ continue
+
+ # 3. ASR
+ batch_asr_results = self.asr.transcribe(batch_asr_uttid, batch_asr_wav)
+ logger.info(f"ASR: {batch_asr_results}")
+
+ if self.config.enable_lid:
+ batch_lid_results = self.lid.process(batch_asr_uttid, batch_asr_wav)
+ logger.info(f"LID: {batch_lid_results}")
+ else:
+                # Use the ASR result count so the LID placeholders below stay one-to-one with ASR outputs
+ batch_lid_results = [None] * len(batch_asr_results)
+
+            # Traverse and filter in lockstep so asr_results and lid_results stay one-to-one
+ for a_res, l_res in zip(batch_asr_results, batch_lid_results):
+ text = a_res.get("text", "").strip()
+                # Filter out non-speech marker tokens and completely empty strings "";
+                # the marker patterns here are hypothetical placeholders for the
+                # model's special tags (e.g. silence/music markers)
+                if not text or re.search(r"(<\|sil\|>)|(<\|music\|>)", text):
+ continue
+ asr_results.append(a_res)
+ lid_results.append(l_res)
+
+ batch_asr_uttid = []
+ batch_asr_wav = []
+
+ # 4. ASR output to Postprocess input
+ if self.config.enable_punc:
+ punc_results = []
+ batch_asr_text = []
+ batch_asr_uttid = []
+ batch_asr_timestamp = []
+ for j, asr_result in enumerate(asr_results):
+ batch_asr_text.append(asr_result["text"])
+ batch_asr_uttid.append(asr_result["uttid"])
+ if self.config.asr_config.return_timestamp:
+ batch_asr_timestamp.append(asr_result.get("timestamp", []))
+ elif "timestamp" in asr_result:
+ batch_asr_timestamp.append(asr_result["timestamp"])
+ if len(batch_asr_text) < self.config.punc_batch_size and j != len(asr_results) - 1:
+ continue
+
+ # 5. Punc
+ if self.config.asr_config.return_timestamp:
+ batch_punc_results = self.punc.process_with_timestamp(batch_asr_timestamp, batch_asr_uttid)
+ else:
+ batch_punc_results = self.punc.process(batch_asr_text, batch_asr_uttid)
+ logger.info(f"Punc: {batch_punc_results}")
+
+ punc_results.extend(batch_punc_results)
+ batch_asr_text = []
+ batch_asr_uttid = []
+ batch_asr_timestamp = []
+ else:
+ punc_results = asr_results
+
+ # 6. Put all together & Format
+ sentences = []
+ words = []
+ for asr_result, punc_result, lid_result in zip(asr_results, punc_results, lid_results):
+ assert asr_result["uttid"] == punc_result["uttid"], f"fix code: {asr_result} | {punc_result}"
+ start_ms, end_ms = asr_result["uttid"].split("_")[-2:]
+ assert start_ms.startswith("s") and end_ms.startswith("e")
+ start_ms, end_ms = int(start_ms[1:]), int(end_ms[1:])
+ if self.config.asr_config.return_timestamp:
+ sub_sentences = []
+ if self.config.enable_punc:
+ for i, punc_sent in enumerate(punc_result["punc_sentences"]):
+ start = start_ms + int(punc_sent["start_s"]*1000)
+ end = start_ms + int(punc_sent["end_s"]*1000)
+ if i == 0:
+ start = start_ms
+ if i == len(punc_result["punc_sentences"]) - 1:
+ end = end_ms
+ sub_sentence = {
+ "start_ms": start,
+ "end_ms": end,
+ "text": punc_sent["punc_text"],
+ "asr_confidence": asr_result["confidence"],
+ "lang": None,
+ "lang_confidence": 0
+ }
+ if lid_result:
+ sub_sentence["lang"] = lid_result["lang"]
+ sub_sentence["lang_confidence"] = lid_result["confidence"]
+ sub_sentences.append(sub_sentence)
+ else:
+ sub_sentences = [{
+ "start_ms": start_ms,
+ "end_ms": end_ms,
+ "text": asr_result["text"],
+ "asr_confidence": asr_result["confidence"],
+ "lang": None,
+ "lang_confidence": 0
+ }]
+ sentences.extend(sub_sentences)
+ else:
+ text = punc_result["punc_text"] if self.config.enable_punc else asr_result["text"]
+ sentence = {
+ "start_ms": start_ms,
+ "end_ms": end_ms,
+ "text": text,
+ "asr_confidence": asr_result["confidence"],
+ "lang": None,
+ "lang_confidence": 0
+ }
+ if lid_result:
+ sentence["lang"] = lid_result["lang"]
+ sentence["lang_confidence"] = lid_result["confidence"]
+ sentences.append(sentence)
+
+ if "timestamp" in asr_result:
+ for w, s, e in asr_result["timestamp"]:
+ word = {"start_ms": int(s*1000+start_ms), "end_ms":int(e*1000+start_ms), "text": w}
+ words.append(word)
+
+ vad_segments_ms = [(int(s*1000), int(e*1000)) for s, e in vad_result["timestamps"]]
+ text = "".join(s["text"] for s in sentences)
+ # Add space after English punctuation when followed by a letter
+ text = re.sub(r'([.,!?])\s*([a-zA-Z])', r'\1 \2', text)
+
+ result = {
+ "uttid": uttid,
+ "text": text,
+ "sentences": sentences,
+ "vad_segments_ms": vad_segments_ms,
+ "dur_s": dur,
+ "words": words,
+ "wav_path": wav_path
+ }
+ return result
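
For embedding the pipeline without the CLI, here is a minimal sketch assuming the default `pretrained_models/*` layout from `FireRedAsr2SystemConfig` and a hypothetical input wav; the keys follow the `result` dict built above.

```python
from fireredasr2s.fireredasr2system import (FireRedAsr2System,
                                            FireRedAsr2SystemConfig)

system = FireRedAsr2System(FireRedAsr2SystemConfig())
result = system.process("examples/test.wav", uttid="demo")  # hypothetical wav

print(result["text"])             # full transcript with punctuation
for sent in result["sentences"]:  # per-sentence spans in milliseconds
    print(sent["start_ms"], sent["end_ms"], sent["lang"], sent["text"])
```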
diff --git a/fireredasr2s/fireredlid/README.md b/fireredasr2s/fireredlid/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..9f3cd7ed4dce8be257d08c8db859a048ff4c3a62
--- /dev/null
+++ b/fireredasr2s/fireredlid/README.md
@@ -0,0 +1,125 @@
+## Language Code
+
+| Language Code | English Name | Chinese Name |
+|---|---|---|
+| zh | Chinese | 中文 |
+| en | English | 英语 |
+| es | Spanish | 西班牙语 |
+| fr | French | 法语 |
+| ja | Japanese | 日语 |
+| ko | Korean | 韩语 |
+| ru | Russian | 俄语 |
+| de | German | 德语 |
+| pt | Portuguese | 葡萄牙语 |
+| ab | Abkhazian | 阿布哈兹语 |
+| af | Afrikaans | 南非荷兰语 |
+| am | Amharic | 阿姆哈拉语 |
+| ar | Arabic | 阿拉伯语 |
+| as | Assamese | 阿萨姆语 |
+| az | Azerbaijani | 阿塞拜疆语 |
+| ba | Bashkir | 巴什基尔语 |
+| be | Belarusian | 白俄罗斯语 |
+| bg | Bulgarian | 保加利亚语 |
+| bn | Bengali | 孟加拉语 |
+| br | Breton | 布列塔尼语 |
+| bs | Bosnian | 波斯尼亚语 |
+| ca | Catalan | 加泰罗尼亚语 |
+| ceb | Cebuano | 宿务语 |
+| cs | Czech | 捷克语 |
+| cy | Welsh | 威尔士语 |
+| da | Danish | 丹麦语 |
+| el | Greek | 希腊语 |
+| eo | Esperanto | 世界语 |
+| et | Estonian | 爱沙尼亚语 |
+| eu | Basque | 巴斯克语 |
+| fa | Persian | 波斯语 |
+| fi | Finnish | 芬兰语 |
+| fo | Faroese | 法罗语 |
+| gl | Galician | 加利西亚语 |
+| gn | Guarani | 瓜拉尼语 |
+| gu | Gujarati | 古吉拉特语 |
+| gv | Manx | 马恩语 |
+| ha | Hausa | 豪萨语 |
+| haw | Hawaiian | 夏威夷语 |
+| hi | Hindi | 印地语 |
+| hr | Croatian | 克罗地亚语 |
+| ht | Haitian Creole | 海地克里奥尔语 |
+| hu | Hungarian | 匈牙利语 |
+| hy | Armenian | 亚美尼亚语 |
+| ia | Interlingua | 国际语 |
+| id | Indonesian | 印度尼西亚语 |
+| is | Icelandic | 冰岛语 |
+| it | Italian | 意大利语 |
+| iw | Hebrew | 希伯来语 |
+| jw | Javanese | 爪哇语 |
+| ka | Georgian | 格鲁吉亚语 |
+| kk | Kazakh | 哈萨克语 |
+| km | Khmer | 高棉语 |
+| kn | Kannada | 卡纳达语 |
+| la | Latin | 拉丁语 |
+| lb | Luxembourgish | 卢森堡语 |
+| ln | Lingala | 林加拉语 |
+| lo | Lao | 老挝语 |
+| lt | Lithuanian | 立陶宛语 |
+| lv | Latvian | 拉脱维亚语 |
+| mg | Malagasy | 马尔加什语 |
+| mi | Māori | 毛利语 |
+| mk | Macedonian | 马其顿语 |
+| ml | Malayalam | 马拉雅拉姆语 |
+| mn | Mongolian | 蒙古语 |
+| mr | Marathi | 马拉地语 |
+| ms | Malay | 马来语 |
+| mt | Maltese | 马耳他语 |
+| my | Burmese | 缅甸语 |
+| ne | Nepali | 尼泊尔语 |
+| nl | Dutch | 荷兰语 |
+| nn | Norwegian Nynorsk | 挪威语 |
+| no | Norwegian | 挪威语 |
+| oc | Occitan | 奥克语 |
+| pa | Punjabi | 旁遮普语 |
+| pl | Polish | 波兰语 |
+| ps | Pashto | 普什图语 |
+| ro | Romanian | 罗马尼亚语 |
+| sa | Sanskrit | 梵语 |
+| sco | Scots | 苏格兰语 |
+| sd | Sindhi | 信德语 |
+| si | Sinhala | 僧伽罗语 |
+| sk | Slovak | 斯洛伐克语 |
+| sl | Slovenian | 斯洛文尼亚语 |
+| sn | Shona | 绍纳语 |
+| so | Somali | 索马里语 |
+| sq | Albanian | 阿尔巴尼亚语 |
+| sr | Serbian | 塞尔维亚语 |
+| su | Sundanese | 巽他语 |
+| sv | Swedish | 瑞典语 |
+| sw | Swahili | 斯瓦希里语 |
+| ta | Tamil | 泰米尔语 |
+| te | Telugu | 泰卢固语 |
+| tg | Tajik | 塔吉克语 |
+| th | Thai | 泰语 |
+| tk | Turkmen | 土库曼语 |
+| tl | Tagalog | 塔加洛语 |
+| tr | Turkish | 土耳其语 |
+| tt | Tatar | 鞑靼语 |
+| uk | Ukrainian | 乌克兰语 |
+| ur | Urdu | 乌尔都语 |
+| uz | Uzbek | 乌兹别克语 |
+| vi | Vietnamese | 越南语 |
+| war | Waray | 瓦赖语 |
+| yi | Yiddish | 意第绪语 |
+| yo | Yoruba | 约鲁巴语 |
+
+
+
+## Language Region Code (Chinese Dialects)
+
+| Language Region Code | English Name | Chinese Name |
+|---|---|---|
+| zh-mandarin | Chinese (Mandarin) | 中文(普通话) |
+| zh-yue | Chinese (Yue, Guangdong) | 中文(粤语-广东) |
+| zh-wu | Chinese (Wu, Shanghai) | 中文(吴语-上海) |
+| zh-min | Chinese (Min, Fujian) | 中文(闽语-福建) |
+| zh-north | Chinese (Mandarin, North) — Shandong / Gansu / Ningxia / Hebei / Shanxi / Liaoning / Shaanxi | 中文(官话-北方:山东/甘肃/宁夏/河北/山西/辽宁/陕西) |
+| zh-xinan | Chinese (Mandarin, Southwest) — Sichuan / Yunnan / Guizhou / Hubei / Chongqing | 中文(官话-西南:四川/云南/贵州/湖北/重庆) |
+| zh-xiang | Chinese (Xiang, Hunan) | 中文(湘语-湖南) |
+| bo | Chinese (Tibetan) | 中文(藏语) |
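
The codes above are the values FireRedLid returns in its `lang` field. A sketch of one result, with illustrative numbers and the keys produced by `fireredlid/lid.py`:

```python
{"uttid": "demo_s0_e2400", "lang": "zh-yue",
 "confidence": 0.973, "dur_s": 2.4, "rtf": "0.0150"}
```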
diff --git a/fireredasr2s/fireredlid/__init__.py b/fireredasr2s/fireredlid/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..12400de8ff282e85400fe3268e6b4517eb9a5a30
--- /dev/null
+++ b/fireredasr2s/fireredlid/__init__.py
@@ -0,0 +1,23 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Yan Jia)
+
+import os
+import sys
+
+__version__ = "0.0.1"
+
+_CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
+
+try:
+ from fireredasr2s.fireredlid.lid import FireRedLid, FireRedLidConfig
+except ImportError:
+ if _CURRENT_DIR not in sys.path:
+ sys.path.insert(0, _CURRENT_DIR)
+ from .lid import FireRedLid, FireRedLidConfig
+
+
+# API
+__all__ = [
+ "__version__",
+ "FireRedLid",
+ "FireRedLidConfig",
+]
diff --git a/fireredasr2s/fireredlid/data/feat.py b/fireredasr2s/fireredlid/data/feat.py
new file mode 100644
index 0000000000000000000000000000000000000000..f54910e0d526d18333a25e5d23bcc1a5226b3324
--- /dev/null
+++ b/fireredasr2s/fireredlid/data/feat.py
@@ -0,0 +1,124 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Yan Jia)
+
+import math
+import os
+
+import kaldiio
+import kaldi_native_fbank as knf
+import numpy as np
+import torch
+
+
+class FeatExtractor:
+ def __init__(self, kaldi_cmvn_file):
+ self.cmvn = CMVN(kaldi_cmvn_file) if kaldi_cmvn_file != "" else None
+ self.fbank = KaldifeatFbank(num_mel_bins=80, frame_length=25,
+ frame_shift=10, dither=0.0)
+
+ def __call__(self, wav_paths, wav_uttids):
+ feats = []
+ durs = []
+ return_wav_paths = []
+ return_wav_uttids = []
+
+ wav_datas = []
+ if isinstance(wav_paths[0], str):
+ for wav_path in wav_paths:
+ sample_rate, wav_np = kaldiio.load_mat(wav_path)
+ wav_datas.append([sample_rate, wav_np])
+ else:
+ wav_datas = wav_paths
+
+ for (sample_rate, wav_np), path, uttid in zip(wav_datas, wav_paths, wav_uttids):
+ dur = wav_np.shape[0] / sample_rate
+ fbank = self.fbank((sample_rate, wav_np))
+ if fbank.shape[0] < 1:
+ continue
+ if self.cmvn is not None:
+ fbank = self.cmvn(fbank)
+ fbank = torch.from_numpy(fbank).float()
+ feats.append(fbank)
+ durs.append(dur)
+ return_wav_paths.append(path)
+ return_wav_uttids.append(uttid)
+ if len(feats) > 0:
+ lengths = torch.tensor([feat.size(0) for feat in feats]).long()
+ feats_pad = self.pad_feat(feats, 0.0)
+ else:
+ lengths, feats_pad = None, None
+ return feats_pad, lengths, durs, return_wav_paths, return_wav_uttids
+
+ def pad_feat(self, xs, pad_value):
+ # type: (List[Tensor], int) -> Tensor
+ n_batch = len(xs)
+ max_len = max([xs[i].size(0) for i in range(n_batch)])
+ pad = torch.ones(n_batch, max_len, *xs[0].size()[1:]).to(xs[0].device).to(xs[0].dtype).fill_(pad_value)
+ for i in range(n_batch):
+ pad[i, :xs[i].size(0)] = xs[i]
+ return pad
+
+
+class CMVN:
+    def __init__(self, kaldi_cmvn_file):
+        self.dim, self.means, self.inverse_stds = \
+            self.read_kaldi_cmvn(kaldi_cmvn_file)
+
+    def __call__(self, x, is_train=False):
+        assert x.shape[-1] == self.dim, "CMVN dim mismatch"
+        out = x - self.means
+        out = out * self.inverse_stds
+        return out
+
+    def read_kaldi_cmvn(self, kaldi_cmvn_file):
+        assert os.path.exists(kaldi_cmvn_file)
+        stats = kaldiio.load_mat(kaldi_cmvn_file)
+        assert stats.shape[0] == 2
+        dim = stats.shape[-1] - 1
+        count = stats[0, dim]
+        assert count >= 1
+        floor = 1e-20
+        means = []
+        inverse_stds = []
+        for d in range(dim):
+            mean = stats[0, d] / count
+            means.append(mean.item())
+            variance = (stats[1, d] / count) - mean * mean
+            if variance < floor:
+                variance = floor
+            istd = 1.0 / math.sqrt(variance)
+            inverse_stds.append(istd)
+        return dim, np.array(means), np.array(inverse_stds)
+
+
+class KaldifeatFbank:
+ def __init__(self, num_mel_bins=80, frame_length=25, frame_shift=10,
+ dither=1.0):
+ self.dither = dither
+ opts = knf.FbankOptions()
+ opts.frame_opts.dither = dither
+ opts.mel_opts.num_bins = num_mel_bins
+ opts.frame_opts.snip_edges = True
+ opts.mel_opts.debug_mel = False
+ self.opts = opts
+
+ def __call__(self, wav, is_train=False):
+ if type(wav) is str:
+ sample_rate, wav_np = kaldiio.load_mat(wav)
+        elif type(wav) in [tuple, list] and len(wav) == 2:
+            sample_rate, wav_np = wav
+        else:
+            raise ValueError("wav must be a path or a (sample_rate, wav_np) pair")
+        assert len(wav_np.shape) == 1
+
+ dither = self.dither if is_train else 0.0
+ self.opts.frame_opts.dither = dither
+ fbank = knf.OnlineFbank(self.opts)
+
+ fbank.accept_waveform(sample_rate, wav_np.tolist())
+ feat = []
+ for i in range(fbank.num_frames_ready):
+ feat.append(fbank.get_frame(i))
+ if len(feat) == 0:
+ print("Check data, len(feat) == 0", wav, flush=True)
+ return np.zeros((0, self.opts.mel_opts.num_bins))
+ feat = np.vstack(feat)
+ return feat
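
`read_kaldi_cmvn` assumes the standard 2 x (dim+1) Kaldi CMVN layout: row 0 holds per-dimension feature sums with the frame count in the last column, and row 1 holds the sums of squares. A toy check of the same arithmetic (numbers invented for illustration):

```python
import numpy as np

stats = np.array([[10.0, 20.0, 5.0],   # per-dim sums, count = 5
                  [30.0, 90.0, 0.0]])  # per-dim sums of squares
count = stats[0, -1]
mean = stats[0, :-1] / count           # -> [2.0, 4.0]
var = stats[1, :-1] / count - mean**2  # -> [2.0, 2.0]
istd = 1.0 / np.sqrt(var)
frame = np.array([4.0, 6.0])
normalized = (frame - mean) * istd     # z-scored frame, as CMVN.__call__ does
```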
diff --git a/fireredasr2s/fireredlid/data/token_dict.py b/fireredasr2s/fireredlid/data/token_dict.py
new file mode 100644
index 0000000000000000000000000000000000000000..1fbd2fd55ff946ea51d17058132793cea2a21f8f
--- /dev/null
+++ b/fireredasr2s/fireredlid/data/token_dict.py
@@ -0,0 +1,63 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Yan Jia)
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class TokenDict:
+ def __init__(self, dict_path, unk=""):
+ assert dict_path != ""
+ self.id2word, self.word2id = self.read_dict(dict_path)
+ self.unk = unk
+ assert unk == "" or unk in self.word2id
+ self.unkid = self.word2id[unk] if unk else -1
+
+ def get(self, key, default):
+ if type(default) == str:
+ default = self.word2id[default]
+ return self.word2id.get(key, default)
+
+ def __getitem__(self, key):
+ if type(key) == str:
+ if self.unk:
+ return self.word2id.get(key, self.word2id[self.unk])
+ else:
+ return self.word2id[key]
+ elif type(key) == int:
+ return self.id2word[key]
+ else:
+ raise TypeError("Key should be str or int")
+
+ def __len__(self):
+ return len(self.id2word)
+
+ def __contains__(self, query):
+ if type(query) == str:
+ return query in self.word2id
+        elif type(query) == int:
+            # id2word is a list, so test the index range rather than value membership
+            return 0 <= query < len(self.id2word)
+ else:
+ raise TypeError("query should be str or int")
+
+ def read_dict(self, dict_path):
+ id2word, word2id = [], {}
+ with open(dict_path, encoding='utf8') as f:
+ for i, line in enumerate(f):
+ tokens = line.strip().split()
+ if len(tokens) >= 2:
+ word, index = tokens[0], int(tokens[1])
+ elif len(tokens) == 1:
+ word, index = tokens[0], i
+                else:  # empty or whitespace-only line
+                    logger.info(f"Found empty or whitespace-only line '{line.strip()}' in {dict_path}:L{i}, set to ' '")
+                    word, index = " ", i
+ assert len(id2word) == index
+ assert len(word2id) == index
+ if word == "":
+ logger.info(f"NOTE: Find in {dict_path}:L{i} and convert it to ' '")
+ word = " "
+ word2id[word] = index
+ id2word.append(word)
+ assert len(id2word) == len(word2id)
+ return id2word, word2id
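
A short sketch of `TokenDict` lookups, assuming a hypothetical `dict.txt` containing `<sos> 0`, `<eos> 1`, `zh 2`, `en 3` (one `token index` pair per line):

```python
d = TokenDict("dict.txt")
assert d["zh"] == 2              # str -> id
assert d[3] == "en"              # id -> str
assert "en" in d and 2 in d
assert d.get("xx", "zh") == 2    # a str default resolves to that token's id
# Without unk set, d["xx"] raises KeyError instead of mapping to an unk id
```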
diff --git a/fireredasr2s/fireredlid/lid.py b/fireredasr2s/fireredlid/lid.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b663de700483e29d2c4d10be1993632ca018e0b
--- /dev/null
+++ b/fireredasr2s/fireredlid/lid.py
@@ -0,0 +1,110 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Yan Jia)
+
+import os
+import re
+import time
+import traceback
+from dataclasses import dataclass
+
+import torch
+
+from .data.feat import FeatExtractor
+from .models.fireredlid_aed import FireRedLidAed
+from .models.param import count_model_parameters
+from .tokenizer.lid_tokenizer import LidTokenizer
+
+
+@dataclass
+class FireRedLidConfig:
+ use_gpu: bool = True
+ use_half: bool = False
+
+
+class FireRedLid:
+ @classmethod
+ def from_pretrained(cls, model_dir, config=FireRedLidConfig()):
+ cmvn_path = os.path.join(model_dir, "cmvn.ark")
+ feat_extractor = FeatExtractor(cmvn_path)
+
+ model_path = os.path.join(model_dir, "model.pth.tar")
+        dict_path = os.path.join(model_dir, "dict.txt")
+ model = load_fireredlid_model(model_path)
+ tokenizer = LidTokenizer(dict_path)
+
+ count_model_parameters(model)
+ model.eval()
+ return cls(feat_extractor, model, tokenizer, config)
+
+ def __init__(self, feat_extractor, model, tokenizer, config):
+ self.feat_extractor = feat_extractor
+ self.model = model
+ self.tokenizer = tokenizer
+ self.config = config
+        # Fixed decoding hyperparameters for LID; decode_max_len is 2 since
+        # the decoder only needs to emit a language token (plus eos)
+        self.config.beam_size = 3
+        self.config.nbest = 1
+        self.config.decode_max_len = 2
+        self.config.softmax_smoothing = 1.25
+        self.config.aed_length_penalty = 0.6
+        self.config.eos_penalty = 1.0
+ if self.config.use_gpu:
+ if self.config.use_half:
+ self.model.half()
+ self.model.cuda()
+ else:
+ self.model.cpu()
+
+ @torch.no_grad()
+ def process(self, batch_uttid, batch_wav_path):
+ batch_uttid_origin = batch_uttid
+ try:
+ feats, lengths, durs, batch_wav_path, batch_uttid = \
+ self.feat_extractor(batch_wav_path, batch_uttid)
+ if feats is None:
+ return [{"uttid": uttid, "lang":""} for uttid in batch_uttid_origin]
+ except:
+ traceback.print_exc()
+ return [{"uttid": uttid, "lang":""} for uttid in batch_uttid_origin]
+ total_dur = sum(durs)
+ if self.config.use_gpu:
+ feats, lengths = feats.cuda(), lengths.cuda()
+ if self.config.use_half:
+ feats = feats.half()
+
+ start_time = time.time()
+
+ try:
+ hyps = self.model.process(
+ feats, lengths,
+ self.config.beam_size,
+ self.config.nbest,
+ self.config.decode_max_len,
+ self.config.softmax_smoothing,
+ self.config.aed_length_penalty,
+ self.config.eos_penalty
+ )
+        except Exception:
+            traceback.print_exc()
+            hyps = []
+
+        elapsed = time.time() - start_time
+        rtf = elapsed / total_dur if total_dur > 0 else 0
+
+ results = []
+ for uttid, wav, hyp, dur in zip(batch_uttid, batch_wav_path, hyps, durs):
+ hyp = hyp[0] # only return 1-best
+ hyp_ids = [int(id) for id in hyp["yseq"].cpu()]
+ text = self.tokenizer.detokenize(hyp_ids)
+ results.append({"uttid": uttid, "lang": text,
+ "confidence": round(hyp["confidence"].cpu().item(), 3),
+ "dur_s": round(dur, 3), "rtf": f"{rtf:.4f}"})
+ if type(wav) == str:
+ results[-1]["wav"] = wav
+ return results
+
+
+def load_fireredlid_model(model_path):
+ package = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=False)
+ model = FireRedLidAed.from_args(package["args"])
+ model.load_state_dict(package["model_state_dict"], strict=False)
+ return model
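
A usage sketch for standalone LID, assuming the FireRedLID weights are in the default directory and the wav paths are hypothetical:

```python
from fireredasr2s.fireredlid import FireRedLid, FireRedLidConfig

lid = FireRedLid.from_pretrained("pretrained_models/FireRedLID",
                                 FireRedLidConfig(use_gpu=True))
results = lid.process(["utt1", "utt2"],
                      ["examples/a.wav", "examples/b.wav"])
for r in results:
    print(r["uttid"], r["lang"], r["confidence"])
```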
diff --git a/fireredasr2s/fireredlid/models/fireredlid_aed.py b/fireredasr2s/fireredlid/models/fireredlid_aed.py
new file mode 100644
index 0000000000000000000000000000000000000000..663795e5564d8ed748450053eeb0f8740f7edb6f
--- /dev/null
+++ b/fireredasr2s/fireredlid/models/fireredlid_aed.py
@@ -0,0 +1,37 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Yan Jia)
+
+import torch
+
+from .module.conformer_encoder import ConformerEncoder
+from .module.transformer_decoder import TransformerDecoder
+
+
+class FireRedLidAed(torch.nn.Module):
+ @classmethod
+ def from_args(cls, args):
+ return cls(args)
+
+ def __init__(self, args):
+ super().__init__()
+ self.sos_id = args.sos_id
+ self.eos_id = args.eos_id
+
+ self.encoder = ConformerEncoder(
+ args.idim, args.n_layers_enc, args.n_head, args.d_model,
+ args.residual_dropout, args.dropout_rate,
+ args.kernel_size, args.pe_maxlen)
+
+ self.lid_decoder = TransformerDecoder(
+ args.sos_id, args.eos_id, args.pad_id, args.lid_odim,
+ args.n_layers_lid_dec, args.n_head, args.d_model,
+ args.residual_dropout, args.pe_maxlen)
+
+ def process(self, padded_input, input_lengths,
+ beam_size=3, nbest=1, decode_max_len=2,
+ softmax_smoothing=1.25, length_penalty=0.6, eos_penalty=1.0):
+ enc_outputs, enc_lengths, enc_mask = self.encoder(padded_input, input_lengths)
+ nbest_hyps = self.lid_decoder.batch_beam_search(
+ enc_outputs, enc_mask,
+ beam_size, nbest, decode_max_len,
+ softmax_smoothing, length_penalty, eos_penalty)
+ return nbest_hyps
diff --git a/fireredasr2s/fireredlid/models/module/conformer_encoder.py b/fireredasr2s/fireredlid/models/module/conformer_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c95b8ef359781908df248511d57d1bed6cc99f1
--- /dev/null
+++ b/fireredasr2s/fireredlid/models/module/conformer_encoder.py
@@ -0,0 +1,324 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Yan Jia)
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ConformerEncoder(nn.Module):
+ def __init__(self, idim, n_layers, n_head, d_model,
+ residual_dropout=0.1, dropout_rate=0.1, kernel_size=33,
+ pe_maxlen=5000):
+ super().__init__()
+ self.odim = d_model
+
+ self.input_preprocessor = Conv2dSubsampling(idim, d_model)
+ self.positional_encoding = RelPositionalEncoding(d_model)
+ self.dropout = nn.Dropout(residual_dropout)
+
+ self.layer_stack = nn.ModuleList()
+        for _ in range(n_layers):
+ block = RelPosEmbConformerBlock(d_model, n_head,
+ residual_dropout,
+ dropout_rate, kernel_size)
+ self.layer_stack.append(block)
+
+ def forward(self, padded_input, input_lengths, pad=True):
+ if pad:
+ padded_input = F.pad(padded_input,
+ (0, 0, 0, self.input_preprocessor.context - 1), 'constant', 0.0)
+ src_mask = self.padding_position_is_0(padded_input, input_lengths)
+
+ embed_output, input_lengths, src_mask = self.input_preprocessor(padded_input, src_mask)
+ enc_output = self.dropout(embed_output)
+
+ pos_emb = self.dropout(self.positional_encoding(embed_output))
+
+        for enc_layer in self.layer_stack:
+            enc_output = enc_layer(enc_output, pos_emb, slf_attn_mask=src_mask,
+                                   pad_mask=src_mask)
+
+        return enc_output, input_lengths, src_mask
+
+ def padding_position_is_0(self, padded_input, input_lengths):
+ N, T = padded_input.size()[:2]
+ mask = torch.ones((N, T)).to(padded_input.device)
+ for i in range(N):
+ mask[i, input_lengths[i]:] = 0
+ mask = mask.unsqueeze(dim=1)
+ return mask.to(torch.uint8)
+
+
+class RelPosEmbConformerBlock(nn.Module):
+ def __init__(self, d_model, n_head,
+ residual_dropout=0.1,
+ dropout_rate=0.1, kernel_size=33):
+ super().__init__()
+ self.ffn1 = ConformerFeedForward(d_model, dropout_rate)
+ self.mhsa = RelPosMultiHeadAttention(n_head, d_model,
+ residual_dropout)
+ self.conv = ConformerConvolution(d_model, kernel_size,
+ dropout_rate)
+ self.ffn2 = ConformerFeedForward(d_model, dropout_rate)
+ self.layer_norm = nn.LayerNorm(d_model)
+
+ def forward(self, x, pos_emb, slf_attn_mask=None, pad_mask=None):
+ out = 0.5 * x + 0.5 * self.ffn1(x)
+ out = self.mhsa(out, out, out, pos_emb, mask=slf_attn_mask)[0]
+ out = self.conv(out, pad_mask)
+ out = 0.5 * out + 0.5 * self.ffn2(out)
+ out = self.layer_norm(out)
+ return out
+
+
+class Swish(nn.Module):
+ def forward(self, x):
+ return x * torch.sigmoid(x)
+
+
+class Conv2dSubsampling(nn.Module):
+ def __init__(self, idim, d_model, out_channels=32):
+ super().__init__()
+ self.conv = nn.Sequential(
+ nn.Conv2d(1, out_channels, 3, 2),
+ nn.ReLU(),
+ nn.Conv2d(out_channels, out_channels, 3, 2),
+ nn.ReLU(),
+ )
+ subsample_idim = ((idim - 1) // 2 - 1) // 2
+ self.out = nn.Linear(out_channels * subsample_idim, d_model)
+
+ self.subsampling = 4
+        left_context = right_context = 3  # both exclude the current frame
+ self.context = left_context + 1 + right_context # 7
+
+ def forward(self, x, x_mask):
+ x = x.unsqueeze(1)
+ x = self.conv(x)
+ N, C, T, D = x.size()
+ x = self.out(x.transpose(1, 2).contiguous().view(N, T, C * D))
+ mask = x_mask[:, :, :-2:2][:, :, :-2:2]
+ input_lengths = mask[:, -1, :].sum(dim=-1)
+ return x, input_lengths, mask
+
+
+class RelPositionalEncoding(torch.nn.Module):
+ def __init__(self, d_model, max_len=5000):
+ super().__init__()
+ pe_positive = torch.zeros(max_len, d_model, requires_grad=False)
+ pe_negative = torch.zeros(max_len, d_model, requires_grad=False)
+ position = torch.arange(0, max_len).unsqueeze(1).float()
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() *
+ -(torch.log(torch.tensor(10000.0)).item()/d_model))
+ pe_positive[:, 0::2] = torch.sin(position * div_term)
+ pe_positive[:, 1::2] = torch.cos(position * div_term)
+ pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+ pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+
+ pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+ pe_negative = pe_negative[1:].unsqueeze(0)
+ pe = torch.cat([pe_positive, pe_negative], dim=1)
+ self.register_buffer('pe', pe)
+
+ def forward(self, x):
+ # Tmax = 2 * max_len - 1
+ Tmax, T = self.pe.size(1), x.size(1)
+ pos_emb = self.pe[:, Tmax // 2 - T + 1 : Tmax // 2 + T].clone().detach()
+ return pos_emb
+
+
+class ConformerFeedForward(nn.Module):
+ def __init__(self, d_model, dropout_rate=0.1):
+ super().__init__()
+ pre_layer_norm = nn.LayerNorm(d_model)
+ linear_expand = nn.Linear(d_model, d_model*4)
+ nonlinear = Swish()
+ dropout_pre = nn.Dropout(dropout_rate)
+ linear_project = nn.Linear(d_model*4, d_model)
+ dropout_post = nn.Dropout(dropout_rate)
+ self.net = nn.Sequential(pre_layer_norm,
+ linear_expand,
+ nonlinear,
+ dropout_pre,
+ linear_project,
+ dropout_post)
+
+ def forward(self, x):
+ residual = x
+ output = self.net(x)
+ output = output + residual
+ return output
+
+
+class ConformerConvolution(nn.Module):
+ def __init__(self, d_model, kernel_size=33, dropout_rate=0.1):
+ super().__init__()
+ assert kernel_size % 2 == 1
+ self.pre_layer_norm = nn.LayerNorm(d_model)
+ self.pointwise_conv1 = nn.Conv1d(d_model, d_model*4, kernel_size=1, bias=False)
+ self.glu = F.glu
+ self.padding = (kernel_size - 1) // 2
+ self.depthwise_conv = nn.Conv1d(d_model*2, d_model*2,
+ kernel_size, stride=1,
+ padding=self.padding,
+ groups=d_model*2, bias=False)
+ self.batch_norm = nn.LayerNorm(d_model*2)
+ self.swish = Swish()
+ self.pointwise_conv2 = nn.Conv1d(d_model*2, d_model, kernel_size=1, bias=False)
+ self.dropout = nn.Dropout(dropout_rate)
+
+ def forward(self, x, mask=None):
+ residual = x
+ out = self.pre_layer_norm(x)
+ out = out.transpose(1, 2)
+ if mask is not None:
+ out.masked_fill_(mask.ne(1), 0.0)
+ out = self.pointwise_conv1(out)
+ out = F.glu(out, dim=1)
+ out = self.depthwise_conv(out)
+
+ out = out.transpose(1, 2)
+ out = self.swish(self.batch_norm(out))
+ out = out.transpose(1, 2)
+
+ out = self.dropout(self.pointwise_conv2(out))
+ if mask is not None:
+ out.masked_fill_(mask.ne(1), 0.0)
+ out = out.transpose(1, 2)
+ return out + residual
+
+
+class EncoderMultiHeadAttention(nn.Module):
+ def __init__(self, n_head, d_model,
+ residual_dropout=0.1):
+ super().__init__()
+ assert d_model % n_head == 0
+ self.n_head = n_head
+ self.d_k = d_model // n_head
+ self.d_v = self.d_k
+
+ self.w_qs = nn.Linear(d_model, n_head * self.d_k, bias=False)
+ self.w_ks = nn.Linear(d_model, n_head * self.d_k, bias=False)
+ self.w_vs = nn.Linear(d_model, n_head * self.d_v, bias=False)
+
+ self.layer_norm_q = nn.LayerNorm(d_model)
+ self.layer_norm_k = nn.LayerNorm(d_model)
+ self.layer_norm_v = nn.LayerNorm(d_model)
+
+ self.attention = ScaledDotProductAttention(temperature=self.d_k ** 0.5)
+ self.fc = nn.Linear(n_head * self.d_v, d_model, bias=False)
+ self.dropout = nn.Dropout(residual_dropout)
+
+ def forward(self, q, k, v, mask=None):
+ sz_b, len_q = q.size(0), q.size(1)
+
+ residual = q
+ q, k, v = self.forward_qkv(q, k, v)
+
+ output, attn = self.attention(q, k, v, mask=mask)
+
+ output = self.forward_output(output, residual, sz_b, len_q)
+ return output, attn
+
+ def forward_qkv(self, q, k, v):
+ d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
+ sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
+
+ q = self.layer_norm_q(q)
+ k = self.layer_norm_k(k)
+ v = self.layer_norm_v(v)
+
+ q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
+ k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
+ v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
+ q = q.transpose(1, 2)
+ k = k.transpose(1, 2)
+ v = v.transpose(1, 2)
+ return q, k, v
+
+ def forward_output(self, output, residual, sz_b, len_q):
+ output = output.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
+ fc_out = self.fc(output)
+ output = self.dropout(fc_out)
+ output = output + residual
+ return output
+
+
+class ScaledDotProductAttention(nn.Module):
+ def __init__(self, temperature):
+ super().__init__()
+ self.temperature = temperature
+ self.dropout = nn.Dropout(0.0)
+ self.INF = float('inf')
+
+ def forward(self, q, k, v, mask=None):
+ attn = torch.matmul(q, k.transpose(2, 3)) / self.temperature
+ output, attn = self.forward_attention(attn, v, mask)
+ return output, attn
+
+ def forward_attention(self, attn, v, mask=None):
+ if mask is not None:
+ mask = mask.unsqueeze(1)
+ mask = mask.eq(0)
+ attn = attn.masked_fill(mask, -self.INF)
+ attn = torch.softmax(attn, dim=-1).masked_fill(mask, 0.0)
+ else:
+ attn = torch.softmax(attn, dim=-1)
+
+ d_attn = self.dropout(attn)
+ output = torch.matmul(d_attn, v)
+
+ return output, attn
+
+
+class RelPosMultiHeadAttention(EncoderMultiHeadAttention):
+ def __init__(self, n_head, d_model,
+ residual_dropout=0.1):
+ super().__init__(n_head, d_model,
+ residual_dropout)
+ d_k = d_model // n_head
+ self.scale = 1.0 / (d_k ** 0.5)
+ self.linear_pos = nn.Linear(d_model, n_head * d_k, bias=False)
+ self.pos_bias_u = nn.Parameter(torch.FloatTensor(n_head, d_k))
+ self.pos_bias_v = nn.Parameter(torch.FloatTensor(n_head, d_k))
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
+
+ def _rel_shift(self, x):
+ N, H, T1, T2 = x.size()
+ zero_pad = torch.zeros((N, H, T1, 1), device=x.device, dtype=x.dtype)
+ x_padded = torch.cat([zero_pad, x], dim=-1)
+
+ x_padded = x_padded.view(N, H, T2 + 1, T1)
+ x = x_padded[:, :, 1:].view_as(x)
+ x = x[:, :, :, : x.size(-1) // 2 + 1]
+ return x
+
+ def forward(self, q, k, v, pos_emb, mask=None):
+ sz_b, len_q = q.size(0), q.size(1)
+
+ residual = q
+ q, k, v = self.forward_qkv(q, k, v)
+
+ q = q.transpose(1, 2)
+ n_batch_pos = pos_emb.size(0)
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.n_head, self.d_k)
+ p = p.transpose(1, 2)
+
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
+
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
+
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
+ matrix_bd = self._rel_shift(matrix_bd)
+
+ attn_scores = matrix_ac + matrix_bd
+ attn_scores.mul_(self.scale)
+
+ output, attn = self.attention.forward_attention(attn_scores, v, mask=mask)
+
+ output = self.forward_output(output, residual, sz_b, len_q)
+ return output, attn
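
`_rel_shift` above is the Transformer-XL re-indexing trick: `matrix_bd[i, j]` scores relative position `T-1-j` (matching `RelPositionalEncoding`'s +(T-1)..-(T-1) layout), and after the shift entry `(i, k)` holds the score for key `k` at offset `k - i` from query `i`. A standalone sanity check of that mapping:

```python
import torch

T = 3
# Tag each column with its index: x[..., i, j] = j, shape (1, 1, T, 2T-1)
x = torch.arange(2 * T - 1).float().repeat(1, 1, T, 1)

zero_pad = torch.zeros(1, 1, T, 1)
x_padded = torch.cat([zero_pad, x], dim=-1)
x_padded = x_padded.view(1, 1, 2 * T, T)
out = x_padded[:, :, 1:].view_as(x)[:, :, :, : x.size(-1) // 2 + 1]

# After the shift, row i column k holds original column (T-1) + (k - i),
# i.e. the score stored for relative offset k - i
for i in range(T):
    for k in range(T):
        assert out[0, 0, i, k].item() == (T - 1) + (k - i)
```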
diff --git a/fireredasr2s/fireredlid/models/module/transformer_decoder.py b/fireredasr2s/fireredlid/models/module/transformer_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a891aed9f79a35ff05701f992bdc91b44ee6cf3
--- /dev/null
+++ b/fireredasr2s/fireredlid/models/module/transformer_decoder.py
@@ -0,0 +1,317 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Yan Jia)
+
+from typing import List, Optional, Dict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+
+
+class TransformerDecoder(nn.Module):
+ def __init__(
+ self, sos_id, eos_id, pad_id, odim,
+ n_layers, n_head, d_model,
+ residual_dropout=0.1, pe_maxlen=5000):
+ super().__init__()
+ self.INF = 1e10
+ # parameters
+ self.pad_id = pad_id
+ self.sos_id = sos_id
+ self.eos_id = eos_id
+ self.n_layers = n_layers
+
+ # Components
+ self.tgt_word_emb = nn.Embedding(odim, d_model, padding_idx=self.pad_id)
+ self.positional_encoding = PositionalEncoding(d_model, max_len=pe_maxlen)
+ self.dropout = nn.Dropout(residual_dropout)
+
+ self.layer_stack = nn.ModuleList()
+        for _ in range(n_layers):
+ block = DecoderLayer(d_model, n_head, residual_dropout)
+ self.layer_stack.append(block)
+
+ self.tgt_word_prj = nn.Linear(d_model, odim, bias=False)
+ self.layer_norm_out = nn.LayerNorm(d_model)
+
+ self.tgt_word_prj.weight = self.tgt_word_emb.weight
+ self.scale = (d_model ** 0.5)
+
+ def batch_beam_search(self, encoder_outputs, src_masks,
+ beam_size=1, nbest=1, decode_max_len=0,
+ softmax_smoothing=1.0, length_penalty=0.0, eos_penalty=1.0):
+ B = beam_size
+ N, Ti, H = encoder_outputs.size()
+ device = encoder_outputs.device
+ maxlen = decode_max_len if decode_max_len > 0 else Ti
+ assert eos_penalty > 0.0
+
+ # Init
+ encoder_outputs = encoder_outputs.unsqueeze(1).repeat(1, B, 1, 1).view(N*B, Ti, H)
+ src_mask = src_masks.unsqueeze(1).repeat(1, B, 1, 1).view(N*B, -1, Ti)
+ ys = torch.ones(N*B, 1).fill_(self.sos_id).long().to(device)
+ t_ys = ys.clone()
+ confidences = torch.zeros(N*B, 1).float().to(device)
+ caches: List[Optional[Tensor]] = []
+ for _ in range(self.n_layers):
+ caches.append(None)
+ scores = torch.tensor([0.0] + [-self.INF]*(B-1)).float().to(device)
+ scores = scores.repeat(N).view(N*B, 1)
+ is_finished = torch.zeros_like(scores)
+
+ # Autoregressive Prediction
+ for t in range(maxlen):
+ tgt_mask = self.ignored_target_position_is_0(ys, self.pad_id)
+
+ dec_output = self.dropout(
+ self.tgt_word_emb(ys) * self.scale +
+ self.positional_encoding(ys))
+            for i, dec_layer in enumerate(self.layer_stack):
+                dec_output = dec_layer.forward(
+                    dec_output, encoder_outputs,
+                    tgt_mask, src_mask,
+                    cache=caches[i])
+                caches[i] = dec_output
+
+ dec_output = self.layer_norm_out(dec_output)
+
+ t_logit = self.tgt_word_prj(dec_output[:, -1])
+ t_scores = F.log_softmax(t_logit / softmax_smoothing, dim=-1)
+
+ if eos_penalty != 1.0:
+ t_scores[:, self.eos_id] *= eos_penalty
+
+ t_topB_scores, t_topB_ys = torch.topk(t_scores, k=B, dim=1)
+ t_topB_scores = self.set_finished_beam_score_to_zero(t_topB_scores, is_finished)
+ t_topB_ys = self.set_finished_beam_y_to_eos(t_topB_ys, is_finished)
+
+ # Accumulated
+ scores = scores + t_topB_scores
+
+ # Pruning
+ scores = scores.view(N, B*B)
+ scores, topB_score_ids = torch.topk(scores, k=B, dim=1)
+ scores = scores.view(-1, 1)
+
+            topB_row_number_in_each_B_rows_of_ys = torch.div(topB_score_ids, B, rounding_mode="floor").view(N*B)
+ stride = B * torch.arange(N).view(N, 1).repeat(1, B).view(N*B).to(device)
+ topB_row_number_in_ys = topB_row_number_in_each_B_rows_of_ys.long() + stride.long()
+
+ # Update ys
+ ys = ys[topB_row_number_in_ys]
+ t_ys = torch.gather(t_topB_ys.view(N, B*B), dim=1, index=topB_score_ids).view(N*B, 1)
+ ys = torch.cat((ys, t_ys), dim=1)
+
+ # Update confidences
+ confidences = confidences[topB_row_number_in_ys]
+ t_confidences = torch.gather(t_topB_scores.view(N, B*B), dim=1, index=topB_score_ids).view(N*B, 1)
+ t_confidences = torch.exp(t_confidences)
+ assert torch.all(t_confidences <= 1.0)
+ assert torch.all(t_confidences >= 0.0)
+ confidences = torch.cat((confidences, t_confidences), dim=1)
+
+ # Update caches
+ new_caches: List[Optional[Tensor]] = []
+ for cache in caches:
+ if cache is not None:
+ new_caches.append(cache[topB_row_number_in_ys])
+ caches = new_caches
+
+ # Update finished state
+ is_finished = t_ys.eq(self.eos_id)
+ if is_finished.sum().item() == N*B:
+ break
+
+ # Length penalty (follow GNMT)
+ scores = scores.view(N, B)
+ ys = ys.view(N, B, -1)
+ ys_lengths = self.get_ys_lengths(ys)
+ if length_penalty > 0.0:
+ penalty = torch.pow((5+ys_lengths.float())/(5.0+1), length_penalty)
+ scores /= penalty
+ nbest_scores, nbest_ids = torch.topk(scores, k=int(nbest), dim=1)
+ nbest_scores = -1.0 * nbest_scores
+ index = nbest_ids + B * torch.arange(N).view(N, 1).to(device).long()
+ nbest_ys = ys.view(N*B, -1)[index.view(-1)]
+ nbest_ys = nbest_ys.view(N, nbest_ids.size(1), -1)
+ nbest_ys_lengths = ys_lengths.view(N*B)[index.view(-1)].view(N, -1)
+ nbest_confidences = confidences.view(N*B, -1)[index.view(-1)].view(N, nbest_ids.size(1), -1)
+
+ # result
+ nbest_hyps: List[List[Dict[str, Tensor]]] = []
+ for n in range(N):
+ n_nbest_hyps: List[Dict[str, Tensor]] = []
+ for i, score in enumerate(nbest_scores[n]):
+ confidence = nbest_confidences[n, i, 1:nbest_ys_lengths[n, i]]
+ confidence = confidence.mean()
+ new_hyp = {
+ "yseq": nbest_ys[n, i, 1:nbest_ys_lengths[n, i]],
+ "confidence": confidence
+ }
+ n_nbest_hyps.append(new_hyp)
+ nbest_hyps.append(n_nbest_hyps)
+ return nbest_hyps
+
+ def ignored_target_position_is_0(self, padded_targets, ignore_id):
+ mask = torch.ne(padded_targets, ignore_id)
+ mask = mask.unsqueeze(dim=1)
+ T = padded_targets.size(-1)
+        upper_tri_0_mask = self.upper_triangular_is_0(T).unsqueeze(0)
+        upper_tri_0_mask = upper_tri_0_mask.to(mask.dtype).to(mask.device)
+        return mask.to(torch.uint8) & upper_tri_0_mask.to(torch.uint8)
+
+ def upper_triangular_is_0(self, size):
+ ones = torch.ones(size, size)
+ tri_left_ones = torch.tril(ones)
+ return tri_left_ones.to(torch.uint8)
+
+ def set_finished_beam_score_to_zero(self, scores, is_finished):
+ NB, B = scores.size()
+ is_finished = is_finished.float()
+ mask_score = torch.tensor([0.0] + [-self.INF]*(B-1)).float().to(scores.device)
+ mask_score = mask_score.view(1, B).repeat(NB, 1)
+ return scores * (1 - is_finished) + mask_score * is_finished
+
+ def set_finished_beam_y_to_eos(self, ys, is_finished):
+ is_finished = is_finished.long()
+ return ys * (1 - is_finished) + self.eos_id * is_finished
+
+ def get_ys_lengths(self, ys):
+ N, B, Tmax = ys.size()
+ ys_lengths = torch.sum(torch.ne(ys, self.eos_id), dim=-1)
+ return ys_lengths.int()
+
+
+class DecoderLayer(nn.Module):
+ def __init__(self, d_model, n_head, dropout):
+ super().__init__()
+ self.self_attn_norm = nn.LayerNorm(d_model)
+ self.self_attn = DecoderMultiHeadAttention(d_model, n_head, dropout)
+
+ self.cross_attn_norm = nn.LayerNorm(d_model)
+ self.cross_attn = DecoderMultiHeadAttention(d_model, n_head, dropout)
+
+ self.mlp_norm = nn.LayerNorm(d_model)
+ self.mlp = PositionwiseFeedForward(d_model, d_model*4, dropout)
+
+ def forward(self, dec_input, enc_output, self_attn_mask, cross_attn_mask,
+ cache=None):
+ x = dec_input
+ residual = x
+ x = self.self_attn_norm(x)
+ if cache is not None:
+ xq = x[:, -1:, :]
+ residual = residual[:, -1:, :]
+ self_attn_mask = self_attn_mask[:, -1:, :]
+ else:
+ xq = x
+ x = self.self_attn(xq, x, x, mask=self_attn_mask)
+ x = residual + x
+
+ residual = x
+ x = self.cross_attn_norm(x)
+ x = self.cross_attn(x, enc_output, enc_output, mask=cross_attn_mask)
+ x = residual + x
+
+ residual = x
+ x = self.mlp_norm(x)
+ x = residual + self.mlp(x)
+
+ if cache is not None:
+ x = torch.cat([cache, x], dim=1)
+
+ return x
+
+
+class DecoderMultiHeadAttention(nn.Module):
+ def __init__(self, d_model, n_head, dropout=0.1):
+ super().__init__()
+ self.d_model = d_model
+ self.n_head = n_head
+ self.d_k = d_model // n_head
+
+ self.w_qs = nn.Linear(d_model, n_head * self.d_k)
+ self.w_ks = nn.Linear(d_model, n_head * self.d_k, bias=False)
+ self.w_vs = nn.Linear(d_model, n_head * self.d_k)
+
+ self.attention = DecoderScaledDotProductAttention(
+ temperature=self.d_k ** 0.5)
+ self.fc = nn.Linear(n_head * self.d_k, d_model)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, q, k, v, mask=None):
+ bs = q.size(0)
+
+ q = self.w_qs(q).view(bs, -1, self.n_head, self.d_k)
+ k = self.w_ks(k).view(bs, -1, self.n_head, self.d_k)
+ v = self.w_vs(v).view(bs, -1, self.n_head, self.d_k)
+ q = q.transpose(1, 2)
+ k = k.transpose(1, 2)
+ v = v.transpose(1, 2)
+
+ if mask is not None:
+ mask = mask.unsqueeze(1)
+
+ output = self.attention(q, k, v, mask=mask)
+
+ output = output.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
+ output = self.fc(output)
+ output = self.dropout(output)
+
+ return output
+
+
+class DecoderScaledDotProductAttention(nn.Module):
+ def __init__(self, temperature):
+ super().__init__()
+ self.temperature = temperature
+ self.INF = float("inf")
+
+ def forward(self, q, k, v, mask=None):
+ attn = torch.matmul(q, k.transpose(2, 3)) / self.temperature
+ if mask is not None:
+ mask = mask.eq(0)
+ attn = attn.masked_fill(mask, -self.INF)
+ attn = torch.softmax(attn, dim=-1).masked_fill(mask, 0.0)
+ else:
+ attn = torch.softmax(attn, dim=-1)
+ output = torch.matmul(attn, v)
+ return output
+
+
+class PositionwiseFeedForward(nn.Module):
+ def __init__(self, d_model, d_ff, dropout=0.1):
+ super().__init__()
+ self.w_1 = nn.Linear(d_model, d_ff)
+ self.act = nn.GELU()
+ self.w_2 = nn.Linear(d_ff, d_model)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x):
+ output = self.w_2(self.act(self.w_1(x)))
+ output = self.dropout(output)
+ return output
+
+
+class PositionalEncoding(nn.Module):
+ def __init__(self, d_model, max_len=5000):
+ super().__init__()
+ assert d_model % 2 == 0
+ pe = torch.zeros(max_len, d_model, requires_grad=False)
+ position = torch.arange(0, max_len).unsqueeze(1).float()
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() *
+ -(torch.log(torch.tensor(10000.0)).item()/d_model))
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ pe = pe.unsqueeze(0)
+ self.register_buffer('pe', pe)
+
+ def forward(self, x):
+ length = x.size(1)
+ return self.pe[:, :length].clone().detach()
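
The length normalization in `batch_beam_search` is the GNMT formula: with `length_penalty` as $\alpha$ and $|Y|$ the hypothesis length, each beam's accumulated log-probability is divided by

$$\mathrm{lp}(Y) = \left(\frac{5 + |Y|}{5 + 1}\right)^{\alpha}$$

so longer hypotheses are not penalized merely for summing more negative log-probabilities; the code applies it only when `length_penalty > 0`.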
diff --git a/fireredasr2s/fireredlid/models/param.py b/fireredasr2s/fireredlid/models/param.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4933fc561cca9d44a42ad89db5580e493468bab
--- /dev/null
+++ b/fireredasr2s/fireredlid/models/param.py
@@ -0,0 +1,17 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Yan Jia)
+
+import logging
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+def count_model_parameters(model):
+ if not isinstance(model, torch.nn.Module):
+ return 0, 0
+ name = f"{model.__class__.__name__} {model.__class__}"
+ num = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ size = num * 4.0 / 1024.0 / 1024.0 # float32, MB
+ logger.info(f"#param of {name} is {num} = {size:.1f} MB (float32)")
+ return num, size
diff --git a/fireredasr2s/fireredlid/speech2lang.py b/fireredasr2s/fireredlid/speech2lang.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0687af346614612a9f2b2bc1db7f5d24eb359b4
--- /dev/null
+++ b/fireredasr2s/fireredlid/speech2lang.py
@@ -0,0 +1,73 @@
+#!/usr/bin/env python3
+
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Yan Jia)
+
+import argparse
+import json
+import logging
+import os
+
+from fireredlid.lid import FireRedLid, FireRedLidConfig
+from fireredlid.utils.io import get_wav_info
+
+logging.basicConfig(level=logging.INFO,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
+logger = logging.getLogger("fireredlid.bin.speech2lang")
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--model_dir', type=str, required=True)
+
+# Input / Output
+parser.add_argument("--wav_path", type=str)
+parser.add_argument("--wav_paths", type=str, nargs="*")
+parser.add_argument("--wav_dir", type=str)
+parser.add_argument("--wav_scp", type=str)
+parser.add_argument("--sort_wav_by_dur", type=int, default=0)
+parser.add_argument("--output", type=str)
+# Decode Options
+parser.add_argument('--use_gpu', type=int, default=1)
+parser.add_argument('--use_half', type=int, default=0)
+parser.add_argument("--batch_size", type=int, default=1)
+
+
+def main(args):
+ wavs = get_wav_info(args)
+ fout = open(args.output, "w") if args.output else None
+ foutl = open(args.output + ".jsonl", "w") if args.output else None
+
+ lid_config = FireRedLidConfig(
+ args.use_gpu,
+ args.use_half
+ )
+ model = FireRedLid.from_pretrained(args.model_dir, lid_config)
+
+ batch_uttid = []
+ batch_wav_path = []
+ for i, wav in enumerate(wavs):
+ uttid, wav_path = wav
+ batch_uttid.append(uttid)
+ batch_wav_path.append(wav_path)
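+        # Run inference once the batch is full, or when this is the last wav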
+ if len(batch_wav_path) < args.batch_size and i != len(wavs) - 1:
+ continue
+
+ results = model.process(batch_uttid, batch_wav_path)
+
+ for result in results:
+ logger.info(result)
+ if fout is not None:
+ foutl.write(f"{json.dumps(result, ensure_ascii=False)}\n")
+ fout.write(f"{result['uttid']}\t{result['lang']}\n")
+
+ if fout: fout.flush()
+ if foutl: foutl.flush()
+ batch_uttid = []
+ batch_wav_path = []
+ if fout: fout.close()
+ if foutl: foutl.close()
+
+
+if __name__ == "__main__":
+ args = parser.parse_args()
+ logger.info(args)
+ main(args)
diff --git a/fireredasr2s/fireredlid/tokenizer/lid_tokenizer.py b/fireredasr2s/fireredlid/tokenizer/lid_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba6ec965e23f2fc5f93e1018ad6f9e9023a2f04b
--- /dev/null
+++ b/fireredasr2s/fireredlid/tokenizer/lid_tokenizer.py
@@ -0,0 +1,17 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Yan Jia)
+
+from ..data.token_dict import TokenDict
+
+
+class LidTokenizer:
+
+ def __init__(self, dict_path, unk=""):
+ self.dict = TokenDict(dict_path, unk=unk)
+
+ def detokenize(self, inputs, join_symbol=" "):
+        if len(inputs) > 0 and isinstance(inputs[0], int):
+            tokens = [self.dict[id] for id in inputs]
+        else:
+            tokens = inputs
+        return join_symbol.join(tokens)
diff --git a/fireredasr2s/fireredlid/utils/io.py b/fireredasr2s/fireredlid/utils/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..9317b542b5928518e3206d2efde22c2baf72c943
--- /dev/null
+++ b/fireredasr2s/fireredlid/utils/io.py
@@ -0,0 +1,38 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Yan Jia)
+
+import glob
+import logging
+import os
+
+logger = logging.getLogger(__name__)
+
+
+def get_wav_info(args):
+ """
+ Returns:
+ wavs: list of (uttid, wav_path)
+ """
+ base = lambda p: os.path.basename(p).replace(".wav", "")
+ if args.wav_path:
+ wavs = [(base(args.wav_path), args.wav_path)]
+ elif args.wav_paths and len(args.wav_paths) >= 1:
+ wavs = [(base(p), p) for p in sorted(args.wav_paths)]
+ elif args.wav_scp:
+ wavs = [line.strip().split() for line in open(args.wav_scp)]
+ if args.sort_wav_by_dur:
+ logger.info("Sort wav by duration...")
+ utt2dur = os.path.join(os.path.dirname(args.wav_scp), "utt2dur")
+ if os.path.exists(utt2dur):
+ utt2dur = [l.strip().split() for l in open(utt2dur)]
+ utt2dur = {l[0]: float(l[1]) for l in utt2dur if len(l) == 2}
+ wavs = sorted(wavs, key=lambda x: -utt2dur[x[0]])
+ logger.info("Sort Done")
+ else:
+ logger.info(f"Not find {utt2dur}, un-sort")
+ elif args.wav_dir:
+ wavs = glob.glob(f"{args.wav_dir}/**/*.wav", recursive=True)
+ wavs = [(base(p), p) for p in sorted(wavs)]
+ else:
+ raise ValueError("Please provide valid wav info")
+ logger.info(f"#wavs={len(wavs)}")
+ return wavs
diff --git a/fireredasr2s/fireredpunc/__init__.py b/fireredasr2s/fireredpunc/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c853d1f79302d6c9e0b6ddc205c104a10c74f7b4
--- /dev/null
+++ b/fireredasr2s/fireredpunc/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Junjie Chen)
+
+import os
+import sys
+import warnings
+warnings.filterwarnings('ignore')
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # 0=ALL, 1=INFO, 2=WARNING, 3=ERROR
+os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
+
+__version__ = "0.0.1"
+
+_CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
+
+try:
+ from fireredasr2s.fireredpunc.punc import FireRedPunc, FireRedPuncConfig
+except ImportError:
+ if _CURRENT_DIR not in sys.path:
+ sys.path.insert(0, _CURRENT_DIR)
+ from .punc import FireRedPunc, FireRedPuncConfig
+
+
+# API
+__all__ = [
+ "__version__",
+ "FireRedPunc",
+ "FireRedPuncConfig",
+]
diff --git a/fireredasr2s/fireredpunc/add_punc.py b/fireredasr2s/fireredpunc/add_punc.py
new file mode 100644
index 0000000000000000000000000000000000000000..a05ca974e978f58ce40680379d80dd8f3529e935
--- /dev/null
+++ b/fireredasr2s/fireredpunc/add_punc.py
@@ -0,0 +1,96 @@
+#!/usr/bin/env python3
+
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Junjie Chen)
+
+import argparse
+import logging
+import re
+
+from fireredpunc.punc import FireRedPunc, FireRedPuncConfig
+
+logging.basicConfig(level=logging.INFO,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
+logger = logging.getLogger("fireredvad.bin.vad")
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--model_dir', type=str, required=True)
+# Input / Output
+parser.add_argument("--input_txt", type=str, default="")
+parser.add_argument("--input_file", type=str, default="")
+parser.add_argument("--output", type=str)
+parser.add_argument("--input_contain_uttid", type=int, default=0)
+# Punc Options
+parser.add_argument('--use_gpu', type=int, default=1)
+parser.add_argument('--batch_size', type=int, default=1)
+parser.add_argument('--sentence_max_length', type=int, default=-1)
+
+
+def main(args):
+ in_texts = get_input(args)
+ fout = open(args.output, "w") if args.output else None
+
+ punc_config = FireRedPuncConfig(
+ args.use_gpu,
+ args.sentence_max_length
+ )
+ model = FireRedPunc.from_pretrained(args.model_dir, punc_config)
+
+ batch_text = []
+ batch_uttid = []
+ for i, (uttid, text) in enumerate(in_texts):
+ batch_text.append(text)
+ batch_uttid.append(uttid)
+ if len(batch_text) < args.batch_size and i != len(in_texts) - 1:
+ continue
+
+ results = model.process(batch_text)
+
+ for uttid, result in zip(batch_uttid, results):
+ logger.info(result)
+ if fout:
+ if args.input_contain_uttid:
+ fout.write(f"{uttid}\t{result['punc_text']}\n")
+ else:
+ fout.write(f"{result['punc_text']}\n")
+
+ batch_text = []
+ batch_uttid = []
+ if fout: fout.flush()
+
+
+def get_input(args):
+ in_texts = []
+ if args.input_file:
+ with open(args.input_file, "r") as fin:
+ for i, l in enumerate(fin):
+ uttid = i
+ text = l.strip()
+ if args.input_contain_uttid:
+ uttid, text = text.split(maxsplit=1)
+ text = _remove_punc_and_fix_space(text)
+ in_texts.append((uttid, text))
+ logger.info(f"#text={len(in_texts)}")
+ elif args.input_txt:
+ logger.info(f"Input txt: {args.input_txt}")
+ text = _remove_punc_and_fix_space(args.input_txt)
+ in_texts.append((0, text))
+ return in_texts
+
+
+def _remove_punc_and_fix_space(text):
+ origin = text
+ text = re.sub("[,。?!,\.?!]", " ", text)
+ pattern = re.compile(r'([\u3400-\u4dbf\u4e00-\u9fff\u3040-\u309f\u30a0-\u30ff\u31f0-\u31ff\U00020000-\U0002a6df\U0002a700-\U0002b73f\U0002b740-\U0002b81f\U0002b820-\U0002ceaf\U0002ceb0-\U0002ebef\U00030000-\U0003134f])')
+ parts = pattern.split(text.strip())
+ parts = [p for p in parts if len(p.strip()) > 0]
+ text = "".join(parts)
+ if origin != text:
+ logger.debug(f"Change text: '{origin}' --> '{text}'")
+ return text
+
+
+if __name__ == "__main__":
+ args = parser.parse_args()
+ logger.info(args)
+ main(args)
diff --git a/fireredasr2s/fireredpunc/data/__init__.py b/fireredasr2s/fireredpunc/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/fireredasr2s/fireredpunc/data/hf_bert_tokenizer.py b/fireredasr2s/fireredpunc/data/hf_bert_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..d07846702f7e9619febaac5ee3c401ed544d8a6f
--- /dev/null
+++ b/fireredasr2s/fireredpunc/data/hf_bert_tokenizer.py
@@ -0,0 +1,180 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Junjie Chen)
+
+import logging
+import re
+import traceback
+
+from transformers import BertTokenizer
+
+logger = logging.getLogger(__name__)
+
+
+# HuggingFace BERT Tokenizer Wrapper
+class HfBertTokenizer:
+ def __init__(self, huggingface_tokenizer_dir):
+ self.tokenizer = BertTokenizer.from_pretrained(huggingface_tokenizer_dir)
+
+ def tokenize(self, text, recover_unk=False):
+ tokens = self.tokenizer.tokenize(text)
+ tokens_id = self.tokenizer.convert_tokens_to_ids(tokens)
+ if recover_unk:
+ try:
+ tokens = self._recover_unk(text.lower(), tokens)
+ except Exception as e:
+ traceback.print_exc()
+ return tokens, tokens_id
+
+ def _recover_unk(self, text, tokens):
+ if "[UNK]" not in tokens:
+ return tokens
+
+ new_tokens = []
+ text_no_space = text.replace(" ", "")
+
+        # Fast path: text with no Latin letters/digits/apostrophes, where
+        # tokens should align 1:1 with characters
+ if re.match(r"^[^a-zA-Z0-9']+$", text):
+ tmp_text = text_no_space
+ if len(tmp_text) == len(tokens):
+ success = True
+ for t, tok in zip(tmp_text, tokens):
+ if tok != "[UNK]" and t != tok:
+ success = False
+ break
+ new_tokens.append(t)
+ if success:
+ return new_tokens
+ new_tokens = []
+
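+        # General recovery: walk tokens and the de-spaced text in lockstep; for
+        # each run of [UNK]s, use the next non-UNK token as an anchor to locate
+        # the original characters the run covers.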
+ text_pos = 0
+ i = 0
+ while i < len(tokens):
+ token = tokens[i]
+ if token == "[UNK]":
+ unk_count = 0
+ j = i
+ while j < len(tokens) and tokens[j] == "[UNK]":
+ unk_count += 1
+ j += 1
+
+ post_token = ""
+ if j < len(tokens):
+ post_token = tokens[j].replace("##", "")
+
+ if post_token:
+ remaining = text_no_space[text_pos:]
+ anchor_pos = remaining.find(post_token)
+ if anchor_pos != -1:
+ unk_chars = remaining[:anchor_pos]
+ else:
+ unk_chars = remaining[:unk_count]
+ else:
+ unk_chars = text_no_space[text_pos:text_pos + unk_count]
+
+ for k in range(unk_count):
+ if k < len(unk_chars):
+ new_tokens.append(unk_chars[k])
+ else:
+ new_tokens.append("")
+ text_pos += len(unk_chars)
+ i = j
+ else:
+ new_tokens.append(token)
+ token_clean = token.replace("##", "")
+ text_pos += len(token_clean)
+ i += 1
+
+ new_tokens = [t for t in new_tokens if t and t != "[UNK]"]
+ return new_tokens
+
+ def detokenize(self, inputs, join_symbol="", replace_spm_space=True):
+ raise NotImplementedError
+
+
+
+if __name__ == "__main__":
+ import os
+ model_dir = "../../../pretrained_models/FireRedPunc"
+ tokenizer = HfBertTokenizer(os.path.join(model_dir, "chinese-lert-base"))
+
+    txts = [
+        # Basic tests
+        "你好吗",
+        "你好 吗",
+        "hello how are you",
+
+        # Consecutive rare characters (consecutive [UNK]s)
+        "寄蜉蝣于天地渺沧海之一粟",
+        "魑魅魍魉", # 4 consecutive rare characters
+        "饕餮耄耋", # another group of 4 consecutive rare characters
+
+        # Chinese-English mix + rare characters
+        "寄蜉蝣于天地渺沧海之一粟how are you魑魅魍魉你蝣蜉啊蝣",
+        "hello魑魅world魍魉test", # English interleaved with rare characters
+
+        # [UNK] at the start/end
+        "蜉蝣你好", # consecutive rare characters at the start
+        "你好蜉蝣", # consecutive rare characters at the end
+        "蜉你蝣好", # alternating
+
+        # Special symbols (may produce [UNK])
+        "你好!@#¥%",
+        "【测试】《标题》",
+        "价格:¥99.9元",
+
+        # Complex mixes
+        "【魑魅】说:你好蜉蝣",
+        "饕餮之徒hello耄耋老人",
+
+        # Edge cases
+        "", # empty string
+        "蜉", # single rare character
+        "魑魅魍魉饕餮", # 6 consecutive rare characters
+
+        # ------------------------------------------
+        # Cases where one [UNK] may cover multiple characters
+        # ------------------------------------------
+
+        # Uncommon English words (may be out of vocabulary)
+        "价格是xyz123元", # xyz123 may be tokenized as [UNK]
+        "使用qwerty键盘", # qwerty may be tokenized as [UNK]
+
+        # Special symbol combinations
+        "商标™注册®版权©", # TM / R / C symbols
+        "温度是25℃左右", # Celsius symbol
+        "面积100㎡价格", # square-meter symbol
+
+        # Japanese/Korean characters (may be absent from the Chinese vocab)
+        "你好こんにちは世界", # Japanese hiragana
+        "欢迎안녕하세요光临", # Korean
+
+        # Roman numerals
+        "第Ⅷ章内容", # Roman numeral 8
+        "共Ⅻ个部分", # Roman numeral 12
+
+        # Math symbols
+        "结果是≈100左右", # approximately-equal sign
+        "价格≤1000元", # less-than-or-equal sign
+
+        # Circled digits
+        "第①步操作", # circled digit 1
+        "共⑩个选项", # circled digit 10
+    ]
+
+    print("=" * 60)
+    print("UNK recovery test")
+    print("=" * 60)
+    for txt in txts:
+        if not txt:
+            print("(empty string) --> []")
+            continue
+        tokens_raw = tokenizer.tokenizer.tokenize(txt)
+        tokens_recovered, _ = tokenizer.tokenize(txt, recover_unk=True)
+        has_unk = "[UNK]" in tokens_raw
+        status = "✓" if "[UNK]" not in tokens_recovered else "✗"
+        if has_unk:
+            print(f"{status} {txt}")
+            print(f"  raw:       {tokens_raw}")
+            print(f"  recovered: {tokens_recovered}")
+        else:
+            print(f"  {txt} --> {tokens_recovered}")
+    print("=" * 60)
diff --git a/fireredasr2s/fireredpunc/data/token_dict.py b/fireredasr2s/fireredpunc/data/token_dict.py
new file mode 100644
index 0000000000000000000000000000000000000000..7908fc2bad5b326bcdb5fc14bd0e731a781346e9
--- /dev/null
+++ b/fireredasr2s/fireredpunc/data/token_dict.py
@@ -0,0 +1,63 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Junjie Chen)
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class TokenDict:
+ def __init__(self, dict_path, unk=""):
+ assert dict_path != ""
+ self.id2word, self.word2id = self.read_dict(dict_path)
+ self.unk = unk
+ assert unk == "" or unk in self.word2id
+ self.unkid = self.word2id[unk] if unk else -1
+
+ def get(self, key, default):
+ if type(default) == str:
+ default = self.word2id[default]
+ return self.word2id.get(key, default)
+
+ def __getitem__(self, key):
+ if type(key) == str:
+ if self.unk:
+ return self.word2id.get(key, self.word2id[self.unk])
+ else:
+ return self.word2id[key]
+ elif type(key) == int:
+ return self.id2word[key]
+ else:
+ raise TypeError("Key should be str or int")
+
+ def __len__(self):
+ return len(self.id2word)
+
+ def __contains__(self, query):
+ if type(query) == str:
+ return query in self.word2id
+ elif type(query) == int:
+ return query in self.id2word
+ else:
+ raise TypeError("query should be str or int")
+
+ def read_dict(self, dict_path):
+ id2word, word2id = [], {}
+ with open(dict_path, encoding='utf8') as f:
+ for i, line in enumerate(f):
+ tokens = line.strip().split()
+ if len(tokens) >= 2:
+ word, index = tokens[0], int(tokens[1])
+ elif len(tokens) == 1:
+ word, index = tokens[0], i
+ else: # empty line or space
+ logger.info(f"Find empty line or space '{line.strip()}' in {dict_path}:L{i}, set to ' '")
+ word, index = " ", i
+ assert len(id2word) == index
+ assert len(word2id) == index
+ if word == "":
+ logger.info(f"NOTE: Find in {dict_path}:L{i} and convert it to ' '")
+ word = " "
+ word2id[word] = index
+ id2word.append(word)
+ assert len(id2word) == len(word2id)
+ return id2word, word2id
diff --git a/fireredasr2s/fireredpunc/models/__init__.py b/fireredasr2s/fireredpunc/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/fireredasr2s/fireredpunc/models/fireredpunc_bert.py b/fireredasr2s/fireredpunc/models/fireredpunc_bert.py
new file mode 100644
index 0000000000000000000000000000000000000000..068b66136c4318ff15ddf21a4a8197a335890433
--- /dev/null
+++ b/fireredasr2s/fireredpunc/models/fireredpunc_bert.py
@@ -0,0 +1,69 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Junjie Chen)
+
+import logging
+
+import torch
+import torch.nn as nn
+import transformers
+
+logger = logging.getLogger(__name__)
+
+
+class FireRedPuncBert(nn.Module):
+ @classmethod
+ def from_args(cls, args):
+ assert args.pretrained_bert, "just support pretrained bert"
+ args.bert = transformers.BertModel.from_pretrained(f"{args.pretrained_bert}")
+ args.bert.pooler = None
+ args.hidden_size = args.bert.config.hidden_size
+ return cls(args)
+
+ def __init__(self, args):
+ super().__init__()
+ self.bert = args.bert if args.pretrained_bert else None # init in build()
+ self.dropout = nn.Dropout(float(args.classifier_dropout))
+ self.classifier = nn.Linear(args.hidden_size, args.odim)
+ self.max_input_len = self.bert.embeddings.position_embeddings.num_embeddings - 1
+ self.cls_id = args.cls_id # set in punc_data.py:PuncData.build()
+ self.ignore_index = args.ignore_index # used by loss
+
+ @torch.jit.export
+ def forward_model(self, padded_inputs, lengths):
+ if padded_inputs.size(1) <= self.max_input_len:
+ score = self._forward(padded_inputs, lengths)
+ else:
+ logger.info("padded_inputs is too long, split it into chunks") #, flush=True)
+ chunk_score = []
+ chunks = padded_inputs.split(self.max_input_len, dim=1)
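+            # Process each chunk independently; chunk_lengths tracks how many
+            # valid (non-pad) positions each sequence has within this chunk.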
+ left_lengths = lengths
+            for chunk in chunks:
+ chunk_lengths = torch.clamp(left_lengths, min=0, max=self.max_input_len)
+ left_lengths = left_lengths - chunk_lengths
+ chunk_score.append(self._forward(chunk, chunk_lengths))
+ score = torch.cat(chunk_score, dim=1)
+ return score
+
+ def _forward(self, padded_inputs, lengths):
+ padded_inputs, lengths = self.add_cls(padded_inputs, lengths)
+ attention_mask = create_huggingface_bert_attention_mask(lengths)
+ outputs = self.bert(padded_inputs, attention_mask)
+ sequence_output = outputs[0][:, 1:] # 1 means remove [CLS]'s output
+ sequence_output = self.dropout(sequence_output)
+ score = self.classifier(sequence_output)
+ return score
+
+ def add_cls(self, padded_inputs, lengths):
+ N = padded_inputs.size(0)
+ cls = padded_inputs.new_ones(N, 1).fill_(self.cls_id)
+ padded_inputs = torch.cat((cls, padded_inputs), dim=1)
+ lengths = lengths + 1
+ return padded_inputs, lengths
+
+
+def create_huggingface_bert_attention_mask(lengths):
+ N = int(lengths.size(0))
+ T = int(lengths.max())
+ mask = lengths.new_ones((N, T))
+ for i in range(N):
+ mask[i, lengths[i]:] = 0
+ return mask.float()
diff --git a/fireredasr2s/fireredpunc/models/param.py b/fireredasr2s/fireredpunc/models/param.py
new file mode 100644
index 0000000000000000000000000000000000000000..976d1167bb43b71338d4c33404cc669936d4a90e
--- /dev/null
+++ b/fireredasr2s/fireredpunc/models/param.py
@@ -0,0 +1,17 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Junjie Chen)
+
+import logging
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+def count_model_parameters(model):
+ if not isinstance(model, torch.nn.Module):
+ return 0, 0
+ name = f"{model.__class__.__name__} {model.__class__}"
+ num = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ size = num * 4.0 / 1024.0 / 1024.0 # float32, MB
+ logger.info(f"#param of {name} is {num} = {size:.1f} MB (float32)")
+ return num, size
diff --git a/fireredasr2s/fireredpunc/punc.py b/fireredasr2s/fireredpunc/punc.py
new file mode 100644
index 0000000000000000000000000000000000000000..48bd59bf6631b80b0fad9628ed9209d8e3fc6294
--- /dev/null
+++ b/fireredasr2s/fireredpunc/punc.py
@@ -0,0 +1,391 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Junjie Chen)
+
+import logging
+import os
+import re
+from dataclasses import dataclass
+
+import torch
+
+from .data.hf_bert_tokenizer import HfBertTokenizer
+from .models.fireredpunc_bert import FireRedPuncBert
+from .models.param import count_model_parameters
+from .data.token_dict import TokenDict
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class FireRedPuncConfig:
+ use_gpu: bool = True
+ sentence_max_length: int = -1
+
+
+class FireRedPunc:
+ @classmethod
+ def from_pretrained(cls, model_dir, config):
+ model = load_punc_bert_model(model_dir)
+ model_io = ModelIO(model_dir)
+ assert isinstance(config, FireRedPuncConfig)
+ count_model_parameters(model)
+ model.eval()
+ return cls(model_io, model, config)
+
+ def __init__(self, model_io, model, config):
+ self.model_io = model_io
+ self.model = model
+ self.config = config
+ if self.config.use_gpu:
+ self.model.cuda()
+ else:
+ self.model.cpu()
+
+ @torch.no_grad()
+ def process(self, batch_text, batch_uttid=None):
+ # Intercept empty input to prevent max([]) from throwing an error
+ if not batch_text:
+ return []
+
+ # 1. Prepare inputs
+ padded_inputs, lengths, txt_tokens = self.model_io.text2tensor(batch_text)
+ if self.config.use_gpu:
+ padded_inputs, lengths = padded_inputs.cuda(), lengths.cuda()
+
+ # 2. Model inference
+ logits = self.model.forward_model(padded_inputs, lengths) # (N,T,C)
+ preds = self.get_punc_pred(logits, lengths)
+
+ # 3. Add Punc to txt
+ punc_txts = self.model_io.add_punc_to_txt(txt_tokens, preds)
+        punc_txts = [RuleBasedTxtFix.fix(txt) for txt in punc_txts]
+
+ # 4. Format output
+ results = []
+ for i in range(len(batch_text)):
+ result = {
+ "punc_text": punc_txts[i],
+ "origin_text": batch_text[i],
+ }
+ if batch_uttid is not None:
+ result["uttid"] = batch_uttid[i]
+ results.append(result)
+ return results
+
+ @torch.no_grad()
+ def process_with_timestamp(self, batch_timestamp, batch_uttid=None):
+ # Intercept empty input to prevent max([]) from throwing an error
+ if not batch_timestamp:
+ return []
+
+ # 1. Prepare inputs
+ padded_inputs, lengths, batch_txt_tokens, batch_tokens_split_num = \
+ self.model_io.timestamp2tensor(batch_timestamp)
+ if self.config.use_gpu:
+ padded_inputs, lengths = padded_inputs.cuda(), lengths.cuda()
+
+ # 2. Model inference
+ logits = self.model.forward_model(padded_inputs, lengths) # (N,T,C)
+ preds = self.get_punc_pred(logits, lengths, batch_txt_tokens)
+
+ # 3. Add Punc to txt
+ punc_txts = self.model_io.add_punc_to_txt_with_timestamp(
+ batch_txt_tokens, preds, batch_timestamp, batch_tokens_split_num)
+
+ new_punc_txts = []
+ for txts in punc_txts:
+ new_txts = []
+ for idx, txt in enumerate(txts):
+ # Only capitalize first letter after sentence-ending punctuation (.!?), not after comma
+ if idx == 0:
+ cap = True
+ else:
+ prev_text = new_txts[idx - 1][0]
+ cap = bool(prev_text) and prev_text[-1] in '.!?。?!'
+                new_txts.append((RuleBasedTxtFix.fix(txt[0], capitalize_first=cap), txt[1], txt[2]))
+ new_punc_txts.append(new_txts)
+ punc_txts = new_punc_txts
+
+ # 4. Format output
+ results = []
+ for i in range(len(batch_timestamp)):
+ result = {
+ "punc_sentences": [
+ {"punc_text": t[0], "start_s": t[1], "end_s": t[2]} for t in punc_txts[i]
+ ],
+ }
+ if batch_uttid is not None:
+ result["uttid"] = batch_uttid[i]
+ results.append(result)
+ return results
+
+ def get_punc_pred(self, punc_logits, lengths, batch_txt_tokens=None):
+ max_len = torch.max(lengths).cpu().item()
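+        # Plain framewise argmax when no sentence-length limit applies (or no
+        # tokens are available); otherwise enforce the limit while decoding.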
+ if max_len <= self.config.sentence_max_length or self.config.sentence_max_length <= 0 or batch_txt_tokens is None:
+ _, preds = torch.max(punc_logits, dim=-1)
+ preds = preds.cpu().tolist()
+ preds = [pred[:lengths[i]] for i, pred in enumerate(preds)]
+ else:
+ preds = self.get_punc_pred_limit_max_len(punc_logits, lengths,
+ batch_txt_tokens)
+ return preds
+
+ def get_punc_pred_limit_max_len(self, punc_logits, lengths, batch_txt_tokens):
+ sentence_max_length = self.config.sentence_max_length
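+        # Greedy segmentation: grow a sentence until its weighted length
+        # (Latin letters/digits count 0.5, other chars 1.0) would exceed
+        # sentence_max_length, then cut at the position with the highest
+        # punctuation probability seen so far.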
+ preds = []
+ batch_probs = punc_logits.softmax(dim=-1).cpu()
+ lengths = lengths.cpu()
+ for n in range(len(batch_probs)):
+ # Process each sentence
+ single_sentence_seg_token_ids = []
+ probs = batch_probs[n]
+ L = lengths[n]
+ tokens = batch_txt_tokens[n]
+ l = 0
+ while l < L:
+ r = l
+ total_num = 0.0
+ max_seg_prob = -1.0
+ max_index = -1
+ while r < L:
+ token_num = 0.0
+ s = re.sub("^##", "", tokens[r])
+ for j in range(len(s)):
+ if re.match("[a-zA-Z0-9']", s[j]):
+ token_num += 0.5
+ else:
+ token_num += 1
+
+ if total_num + token_num > sentence_max_length and max_seg_prob >= 0:
+ break
+
+ space_prob = probs[r][0]
+ seg_prob = 1.0 - space_prob
+ if seg_prob >= max_seg_prob:
+ max_seg_prob = seg_prob
+ max_index = r
+ total_num += token_num
+ r += 1
+ if seg_prob >= space_prob:
+ break
+ if r >= L:
+ # r is == sentence_length, r-- to avoid out-of-range-access
+ r -= 1
+ else:
+ # if total_num + token_num > sentence_max_length,
+ # we find l to max score's index as a sentence
+ # (max index is betweent [l, r])
+ r = max_index
+ if token_num > sentence_max_length:
+ logger.info(f"Too long token...{n}, {l}, {r}, {total_num}, {token_num}, {tokens[l]}, {tokens[r]}")
+ # range [l, r] is a sentence
+ for idx in range(l, r):
+                    single_sentence_seg_token_ids.append(0)  # class 0 = no punctuation
+                # argmax over punctuation classes 1..C-1 (class 0 is no punctuation)
+                pred_id = 1
+                max_pred_prob = 0.0
+                for k in range(1, len(probs[r])):
+                    if probs[r][k] > max_pred_prob:
+                        pred_id = k
+                        max_pred_prob = probs[r][k]
+                single_sentence_seg_token_ids.append(pred_id)
+ l = r + 1
+ preds.append(single_sentence_seg_token_ids)
+ return preds
+
+
+def load_punc_bert_model(model_dir):
+ model_path = os.path.join(model_dir, "model.pth.tar")
+ package = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=False)
+ package["args"].bert = None
+ package["args"].pretrained_bert = os.path.join(model_dir, "chinese-lert-base")
+ model = FireRedPuncBert.from_args(package["args"])
+ model.load_state_dict(package["model_state_dict"], strict=False)
+ return model
+
+
+class ModelIO:
+ def __init__(self, model_dir):
+ self.tokenizer = HfBertTokenizer(os.path.join(model_dir, "chinese-lert-base"))
+ self.in_dict = TokenDict(os.path.join(model_dir, "chinese-bert-wwm-ext_vocab.txt"), unk="[UNK]")
+ self.out_dict = TokenDict(os.path.join(model_dir, "out_dict"))
+ self.INPUT_IGNORE_ID = self.in_dict["[PAD]"]
+ self.DEFAULT_OUT = " "
+
+ def text2tensor(self, batch_text):
+ batch_txt_tokens = []
+ batch_input_seqs = []
+ for text in batch_text:
+ tokens, _ = self.tokenizer.tokenize(text, recover_unk=True)
+ input_seq = []
+ for token in tokens:
+ input_seq.append(self.in_dict.get(token, self.in_dict.unk))
+ batch_txt_tokens.append(tokens)
+ batch_input_seqs.append(input_seq)
+ padded_inputs, lengths = self.pad_list(batch_input_seqs, self.INPUT_IGNORE_ID)
+ return padded_inputs, lengths, batch_txt_tokens
+
+ def timestamp2tensor(self, batch_timestamp):
+ batch_txt_tokens = []
+ batch_input_seqs = []
+ batch_tokens_split_num = []
+ for timestamps in batch_timestamp:
+ txt_token = []
+ input_seq = []
+ tokens_split_num = []
+ for token, start, end in timestamps:
+ sub_tokens, _ = self.tokenizer.tokenize(token, recover_unk=True)
+ tokens_split_num.append(len(sub_tokens))
+ txt_token.extend(sub_tokens)
+ for sub_token in sub_tokens:
+ input_seq.append(self.in_dict.get(sub_token, self.in_dict.unk))
+ batch_txt_tokens.append(txt_token)
+ batch_input_seqs.append(input_seq)
+ batch_tokens_split_num.append(tokens_split_num)
+ padded_inputs, lengths = self.pad_list(batch_input_seqs, self.INPUT_IGNORE_ID)
+ return padded_inputs, lengths, batch_txt_tokens, batch_tokens_split_num
+
+ @classmethod
+ def pad_list(cls, input_seqs, pad_value):
+ lengths = [len(seq) for seq in input_seqs]
+ padded_inputs = torch.zeros(len(input_seqs), max(lengths)).fill_(pad_value).long()
+ for i, input_seq in enumerate(input_seqs):
+ end = lengths[i]
+ padded_inputs[i, :end] = torch.LongTensor(input_seq[:end])
+ lengths = torch.IntTensor(lengths)
+ return padded_inputs, lengths
+
+ def add_punc_to_txt(self, token_seqs, pred_seqs):
+ punc_txts = []
+ for token_seq, pred_seq in zip(token_seqs, pred_seqs):
+ assert len(token_seq) == len(pred_seq)
+ txt = ""
+ for i, token in enumerate(token_seq):
+ tag = self.out_dict[pred_seq[i]]
+
+ # tokenizer_type == "huggingface_bert":
+ if token.startswith("##"):
+ token = token.replace("##", "")
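+                # Insert a space between adjacent Latin/digit tokens that carry
+                # no punctuation, restoring word boundaries lost by tokenization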
+ elif re.search("[a-zA-Z0-9#]+", token) and \
+ i > 0 and re.search("[a-zA-Z0-9#]+", token_seq[i-1]):
+ if self.out_dict[pred_seq[i-1]] == self.DEFAULT_OUT:
+ token = " " + token
+
+ if tag == self.DEFAULT_OUT:
+ txt += token
+ else:
+ txt += token + tag
+ txt = txt.replace(" ", " ")
+ punc_txts.append(txt)
+ return punc_txts
+
+ def add_punc_to_txt_with_timestamp(self, token_seqs, pred_seqs,
+ batch_timestamp, batch_tokens_split_num):
+ punc_txts = []
+ for token_seq, pred_seq, timestamps, tokens_split_num in \
+ zip(token_seqs, pred_seqs, batch_timestamp, batch_tokens_split_num):
+ assert len(token_seq) == len(pred_seq)
+ sentences = []
+ txt, start, end = "", -1, -1
+
+ i = 0
+ j = 0
+ last_token = ""
+ last_tag = ""
+ while i < len(token_seq) and j < len(tokens_split_num):
+ split_num = tokens_split_num[j]
+ timestamp = timestamps[j]
+ assert len(timestamp) == 3
+ if start == -1:
+ start = timestamp[1]
+ end = timestamp[2]
+
+                # Reset 'token' and 'tag' every iteration so the previous
+                # word's values cannot leak into this one
+ token = ""
+ tag = self.DEFAULT_OUT
+
+ for k in range(split_num):
+ sub_token = token_seq[i]
+ tag = self.out_dict[pred_seq[i]]
+ sub_token = re.sub("^##", "", sub_token)
+ if k == 0:
+ token = sub_token
+ else: # k > 0
+ token += sub_token
+ i += 1
+
+                # If the tokenizer produced no sub-tokens (e.g. the input token is an
+                # empty string ""), fall back to the original string so the assertion
+                # below passes and downstream logic keeps all necessary information
+ if split_num == 0:
+ token = timestamp[0]
+
+ assert token == timestamp[0], f"{token}/{timestamp}"
+ j += 1
+ # Add " " before English & Digit
+ if re.search("[a-zA-Z0-9#]+", token) and \
+ j > 0 and re.search("[a-zA-Z0-9#]+", last_token):
+ if last_tag == self.DEFAULT_OUT:
+ token = " " + token
+
+ if tag == self.DEFAULT_OUT:
+ txt += token
+ else:
+ txt += token + tag
+ # Get New sentence
+ txt = txt.replace(" ", " ")
+ assert start != -1
+ sentences.append((txt, start, end))
+ txt, start, end = "", -1, -1
+ last_token = token
+ last_tag = tag
+ if txt != "":
+ assert start != -1 and end != -1
+ sentences.append((txt, start, end))
+
+ punc_txts.append(sentences)
+ return punc_txts
+
+
+class RuleBasedTxtFix:
+ @classmethod
+ def fix(cls, txt_ori, capitalize_first=True):
+ txt = txt_ori.lower()
+ # English Punc
+ txt = re.sub(r"([a-z]),([a-z])", r"\1, \2", txt)
+ txt = re.sub(r"([a-z])。([a-z])", r"\1. \2", txt)
+ txt = re.sub(r"([a-z])?([a-z])", r"\1? \2", txt)
+ txt = re.sub(r"([a-z])!([a-z])", r"\1! \2", txt)
+ txt = re.sub(r"^([a-z]+),", r"\1,", txt)
+ txt = re.sub(r"^([a-z]+)。", r"\1.", txt)
+ txt = re.sub(r"^([a-z]+)?", r"\1?", txt)
+ txt = re.sub(r"^([a-z]+)!", r"\1!", txt)
+ txt = re.sub(r"( [a-zA-Z']+),$", r"\1,", txt)
+ txt = re.sub(r"( [a-zA-Z']+)。$", r"\1.", txt)
+ txt = re.sub(r"( [a-zA-Z']+)?$", r"\1?", txt)
+ txt = re.sub(r"( [a-zA-Z']+)!$", r"\1!", txt)
+ # I
+ txt = re.sub("^i ", "I ", txt)
+ txt = re.sub("^i'm ", "I'm ", txt)
+ txt = re.sub("^i'd ", "I'd ", txt)
+ txt = re.sub("^i've ", "I've ", txt)
+ txt = re.sub("^i'll ", "I'll ", txt)
+ txt = re.sub(" i ", " I ", txt)
+ txt = re.sub(" i'm ", " I'm ", txt)
+ txt = re.sub(" i'd ", " I'd ", txt)
+ txt = re.sub(" i've ", " I've ", txt)
+ txt = re.sub(" i'll ", " I'll ", txt)
+ # First English upper
+ if capitalize_first and len(txt) > 0 and re.match("[a-z]", txt[0]):
+ txt = txt[0].upper() + txt[1:]
+ txt = re.sub(r'([.!?。?!])\s+([a-z])', lambda m: f"{m.group(1)} {m.group(2).upper()}", txt)
+
+ return txt
+
+
+if __name__ == "__main__":
+ txts = [
+ "i'm ok. how are you? i'm fine.",
+ "Tim,"
+ ]
+    for txt in txts:
+        txt2 = RuleBasedTxtFix.fix(txt)
+        print(f"'{txt}' --> '{txt2}'")
diff --git a/fireredasr2s/fireredvad/__init__.py b/fireredasr2s/fireredvad/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2486ce72708dbddea1145bb693d63553897cbd92
--- /dev/null
+++ b/fireredasr2s/fireredvad/__init__.py
@@ -0,0 +1,57 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Kai Huang)
+
+import os
+import sys
+
+__version__ = "0.0.1"
+
+_CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
+
+try:
+ from fireredasr2s.fireredvad.aed import FireRedAed, FireRedAedConfig
+ from fireredasr2s.fireredvad.stream_vad import FireRedStreamVad, FireRedStreamVadConfig
+ from fireredasr2s.fireredvad.vad import FireRedVad, FireRedVadConfig
+except ImportError:
+ if _CURRENT_DIR not in sys.path:
+ sys.path.insert(0, _CURRENT_DIR)
+ from .aed import FireRedAed, FireRedAedConfig
+ from .stream_vad import FireRedStreamVad, FireRedStreamVadConfig
+ from .vad import FireRedVad, FireRedVadConfig
+
+
+def non_stream_vad(wav_path, model_dir="pretrained_models/FireRedVAD/VAD", **kwargs):
+ """Quick VAD inference"""
+ config = FireRedVadConfig(**kwargs)
+ vad = FireRedVad.from_pretrained(model_dir, config)
+ result, probs = vad.detect(wav_path)
+ return result
+
+
+def stream_vad_full(wav_path, model_dir="pretrained_models/FireRedVAD/Stream-VAD", **kwargs):
+ """Quick Stream VAD inference"""
+ config = FireRedStreamVadConfig(**kwargs)
+ svad = FireRedStreamVad.from_pretrained(model_dir, config)
+ frame_results, result = svad.detect_full(wav_path)
+ return frame_results, result
+
+
+def non_stream_aed(wav_path, model_dir="pretrained_models/FireRedVAD/AED", **kwargs):
+ """Quick AED inference"""
+ config = FireRedAedConfig(**kwargs)
+ aed = FireRedAed.from_pretrained(model_dir, config)
+ result, probs = aed.detect(wav_path)
+ return result
+
+
+__all__ = [
+ '__version__',
+ 'FireRedVad',
+ 'FireRedVadConfig',
+ 'FireRedAed',
+ 'FireRedAedConfig',
+ 'FireRedStreamVad',
+ 'FireRedStreamVadConfig',
+ 'non_stream_vad',
+ 'stream_vad_full',
+ 'non_stream_aed'
+]
diff --git a/fireredasr2s/fireredvad/aed.py b/fireredasr2s/fireredvad/aed.py
new file mode 100644
index 0000000000000000000000000000000000000000..391ca0a36c5462eb3e064b685ef34d73389ea20f
--- /dev/null
+++ b/fireredasr2s/fireredvad/aed.py
@@ -0,0 +1,109 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Kai Huang)
+
+import logging
+import os
+from dataclasses import dataclass
+
+import torch
+
+from .core.audio_feat import AudioFeat
+from .core.detect_model import DetectModel
+from .core.vad_postprocessor import VadPostprocessor
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class FireRedAedConfig:
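+    # All *_frame values are in 10 ms frame-shift units (100 frames = 1 s)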
+ use_gpu: bool = False
+ smooth_window_size: int = 5
+ speech_threshold: float = 0.4
+ singing_threshold: float = 0.5
+ music_threshold: float = 0.5
+ min_event_frame: int = 20
+ max_event_frame: int = 2000 # 20s
+ min_silence_frame: int = 20
+ merge_silence_frame: int = 0
+ extend_speech_frame: int = 0
+ chunk_max_frame: int = 30000 # 300s
+
+
+class FireRedAed:
+ IDX2EVENT = {0: "speech", 1: "singing", 2: "music"}
+
+ @classmethod
+ def from_pretrained(cls, model_dir, config=FireRedAedConfig()):
+ # Build Feat Extractor
+ cmvn_path = os.path.join(model_dir, "cmvn.ark")
+ audio_feat = AudioFeat(cmvn_path)
+
+ # Build Model
+ model = DetectModel.from_pretrained(model_dir)
+ if config.use_gpu:
+ model.cuda()
+ else:
+ model.cpu()
+
+ # Build Postprocessor
+ event2postprocessor = {}
+ for event in cls.IDX2EVENT.values():
+ threshold = getattr(config, f"{event}_threshold")
+ event2postprocessor[event] = VadPostprocessor(
+ config.smooth_window_size,
+ threshold,
+ config.min_event_frame,
+ config.max_event_frame,
+ config.min_silence_frame,
+ config.merge_silence_frame,
+ config.extend_speech_frame)
+ return cls(audio_feat, model, event2postprocessor, config)
+
+ def __init__(self, audio_feat, model, event2postprocessor, config):
+ self.audio_feat = audio_feat
+ self.model = model
+ self.event2postprocessor = event2postprocessor
+ self.config = config
+
+ def detect(self, audio):
+ # Extract feat
+ feat, dur = self.audio_feat.extract(audio)
+ if self.config.use_gpu:
+ feat = feat.cuda()
+
+ # Model inference
+ if feat.size(0) <= self.config.chunk_max_frame:
+ probs, _ = self.model.forward(feat.unsqueeze(0))
+ assert probs.size(-1) == len(self.IDX2EVENT)
+ probs = probs.cpu().squeeze(0) # (T,3)
+ else:
+ logger.warning(f"Too long input, split every {self.config.chunk_max_frame} frames")
+ chunk_probs = []
+ chunks = feat.split(self.config.chunk_max_frame, dim=0)
+ for chunk in chunks:
+ chunk_prob, _ = self.model.forward(chunk.unsqueeze(0))
+ assert chunk_prob.size(-1) == len(self.IDX2EVENT)
+ chunk_probs.append(chunk_prob.cpu())
+ probs = torch.cat(chunk_probs, dim=1)
+ probs = probs.squeeze(0) # (T,3)
+
+ # Prob Postprocess
+ event2starts_ends_s = {}
+ event2raw_ratio = {}
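+        # Per event class: smooth the frame probabilities, threshold them, and
+        # merge frames into (start_s, end_s) segments; raw_ratio is the fraction
+        # of frames whose raw probability reaches the threshold.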
+ for idx, event in self.IDX2EVENT.items():
+ threshold = getattr(self.config, f"{event}_threshold")
+ postprocessor = self.event2postprocessor[event]
+ event_probs = probs[:, idx].tolist()
+ decision = postprocessor.process(event_probs)
+ starts_ends_s = postprocessor.decision_to_segment(decision, dur)
+ event2starts_ends_s[event] = starts_ends_s
+
+ raw_ratio = sum(int(p>= threshold) for p in event_probs) / len(event_probs) if len(event_probs) else 0
+ event2raw_ratio[event] = round(raw_ratio, 3)
+
+ # Format result
+ result = {"dur": round(dur, 3),
+ "event2timestamps": event2starts_ends_s,
+ "event2ratio": event2raw_ratio}
+ if isinstance(audio, str):
+ result["wav_path"] = audio
+ return result, probs
diff --git a/fireredasr2s/fireredvad/bin/__init__.py b/fireredasr2s/fireredvad/bin/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/fireredasr2s/fireredvad/bin/aed.py b/fireredasr2s/fireredvad/bin/aed.py
new file mode 100644
index 0000000000000000000000000000000000000000..88e9a2b7ede04a01ca7bf3f959b78591fb3200a2
--- /dev/null
+++ b/fireredasr2s/fireredvad/bin/aed.py
@@ -0,0 +1,92 @@
+#!/usr/bin/env python3
+
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Kai Huang)
+
+import argparse
+import json
+import logging
+import time
+
+from fireredvad.aed import FireRedAedConfig, FireRedAed
+from fireredvad.utils.io import get_wav_info, write_event_textgrid, split_and_save_event_segment
+
+logging.basicConfig(level=logging.INFO,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
+logger = logging.getLogger("fireredvad.bin.aed")
+
+
+parser = argparse.ArgumentParser()
+# Input
+parser.add_argument("--wav_path", type=str)
+parser.add_argument("--wav_paths", type=str, nargs="*")
+parser.add_argument("--wav_scp", type=str)
+parser.add_argument("--wav_dir", type=str)
+# Output
+parser.add_argument("--output", type=str, default="aed_output")
+parser.add_argument("--write_textgrid", type=int, default=0)
+parser.add_argument("--save_segment_dir", type=str, default="")
+# AED Options
+parser.add_argument('--model_dir', type=str,
+ default="pretrained_models/FireRedVAD-AED-251104")
+parser.add_argument('--use_gpu', type=int, default=0)
+parser.add_argument("--smooth_window_size", type=int, default=5)
+parser.add_argument("--speech_threshold", type=float, default=0.4)
+parser.add_argument("--singing_threshold", type=float, default=0.5)
+parser.add_argument("--music_threshold", type=float, default=0.5)
+parser.add_argument("--min_event_frame", type=int, default=20)
+parser.add_argument("--max_event_frame", type=int, default=2000)
+parser.add_argument("--min_silence_frame", type=int, default=20)
+parser.add_argument("--merge_silence_frame", type=int, default=0)
+parser.add_argument("--extend_speech_frame", type=int, default=0)
+parser.add_argument("--chunk_max_frame", type=int, default=30000)
+
+
+def main(args):
+ logger.info("Start AED...\n")
+ wavs = get_wav_info(args)
+ fout = open(args.output, "w") if args.output else None
+
+ aed_config = FireRedAedConfig(
+ use_gpu = args.use_gpu,
+ smooth_window_size = args.smooth_window_size,
+ speech_threshold = args.speech_threshold,
+ singing_threshold = args.singing_threshold,
+ music_threshold = args.music_threshold,
+ min_event_frame = args.min_event_frame,
+ max_event_frame = args.max_event_frame,
+ min_silence_frame = args.min_silence_frame,
+ merge_silence_frame = args.merge_silence_frame,
+ extend_speech_frame = args.extend_speech_frame,
+ chunk_max_frame = args.chunk_max_frame)
+ logger.info(f"{aed_config}")
+ aed = FireRedAed.from_pretrained(args.model_dir, aed_config)
+
+ for i, (uttid, wav_path) in enumerate(wavs):
+ logger.info("")
+ logger.info(f">>> {i} Processing {wav_path}")
+ start_time = time.time()
+
+ result, probs = aed.detect(wav_path)
+
+ elapsed = time.time() - start_time
+ dur = result["dur"]
+ rtf = elapsed / dur if dur > 0 else 0
+ logger.info(f"Result: {result}")
+ logger.info(f"Dur={dur} elapsed(ms)={round(elapsed*1000, 2)} RTF={round(rtf, 5)}")
+
+ if fout:
+ fout.write(f"{json.dumps(result, ensure_ascii=False)}\n")
+ if args.write_textgrid:
+ write_event_textgrid(result["wav_path"], result["dur"], result["event2timestamps"])
+ if args.save_segment_dir:
+ split_and_save_event_segment(wav_path, result["event2timestamps"], args.save_segment_dir)
+ if fout: fout.close()
+
+ logger.info("All AED Done")
+
+
+
+if __name__ == "__main__":
+ args = parser.parse_args()
+ logger.info(f"{args}")
+ main(args)
diff --git a/fireredasr2s/fireredvad/bin/fireredvad_cli.py b/fireredasr2s/fireredvad/bin/fireredvad_cli.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ce43db6cfb06f550cbca53a8965dce056a04cb3
--- /dev/null
+++ b/fireredasr2s/fireredvad/bin/fireredvad_cli.py
@@ -0,0 +1,41 @@
+#!/usr/bin/env python3
+
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Kai Huang)
+
+import argparse
+import logging
+
+logging.basicConfig(level=logging.INFO,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
+logger = logging.getLogger("fireredvad")
+
+
+def main():
+ parser = argparse.ArgumentParser(description="FireRedVAD: VAD & AED")
+ parser.add_argument("--task", type=str, required=True,
+ choices=["vad", "stream_vad", "aed"],
+ help="Task type: vad, stream_vad, or aed")
+ parser.add_argument("--wav_path", type=str, required=True)
+ parser.add_argument("--model_dir", type=str, default=None)
+ parser.add_argument("--use_gpu", type=int, default=0)
+
+ args, unknown = parser.parse_known_args()
+
+ if args.task == "vad":
+ from fireredvad import non_stream_vad
+ model_dir = args.model_dir or "pretrained_models/FireRedVAD/VAD"
+ result = non_stream_vad(args.wav_path, model_dir=model_dir, use_gpu=args.use_gpu)
+ elif args.task == "stream_vad":
+ from fireredvad import stream_vad_full
+ model_dir = args.model_dir or "pretrained_models/FireRedVAD/Stream-VAD"
+ result = stream_vad_full(args.wav_path, model_dir=model_dir, use_gpu=args.use_gpu)
+ elif args.task == "aed":
+ from fireredvad import non_stream_aed
+ model_dir = args.model_dir or "pretrained_models/FireRedVAD/AED"
+ result = non_stream_aed(args.wav_path, model_dir=model_dir, use_gpu=args.use_gpu)
+
+ logger.info(f"Result: {result}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/fireredasr2s/fireredvad/bin/stream_vad.py b/fireredasr2s/fireredvad/bin/stream_vad.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6a8955237b8d2af56c2e3eba26c6accc2145cf4
--- /dev/null
+++ b/fireredasr2s/fireredvad/bin/stream_vad.py
@@ -0,0 +1,172 @@
+#!/usr/bin/env python3
+
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Kai Huang)
+
+import argparse
+import json
+import logging
+
+import soundfile as sf
+
+from fireredvad.core.constants import SAMPLE_RATE, FRAME_LENGTH_SAMPLE, FRAME_SHIFT_SAMPLE
+from fireredvad.stream_vad import FireRedStreamVadConfig, FireRedStreamVad
+from fireredvad.utils.io import get_wav_info, write_textgrid, split_and_save_segment, timeit
+
+logging.basicConfig(level=logging.INFO,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
+logger = logging.getLogger("fireredvad.bin.stream_vad")
+
+
+parser = argparse.ArgumentParser()
+# Input
+parser.add_argument("--wav_path", type=str)
+parser.add_argument("--wav_paths", type=str, nargs="*")
+parser.add_argument("--wav_scp", type=str)
+parser.add_argument("--wav_dir", type=str)
+# Output
+parser.add_argument("--output", type=str, default="vad_output")
+parser.add_argument("--write_textgrid", type=int, default=0)
+parser.add_argument("--save_segment_dir", type=str, default="")
+# VAD Options
+parser.add_argument('--model_dir', type=str,
+ default="pretrained_models/FireRedVAD-VAD-stream-251104")
+parser.add_argument('--stream_vad_mode', type=str, default="all",
+ choices=["framewise", "chunkwise", "full", "all"])
+parser.add_argument('--stream_chunk_frame', type=int, default=10)
+# Vad Config
+parser.add_argument('--use_gpu', type=int, default=0)
+parser.add_argument("--smooth_window_size", type=int, default=5)
+parser.add_argument("--speech_threshold", type=float, default=0.3)
+parser.add_argument("--pad_start_frame", type=int, default=5)
+parser.add_argument("--min_speech_frame", type=int, default=8)
+parser.add_argument("--max_speech_frame", type=int, default=2000)
+parser.add_argument("--min_silence_frame", type=int, default=20)
+parser.add_argument("--chunk_max_frame", type=int, default=30000)
+
+
+def main(args):
+ logger.info("Start Stream VAD...\n")
+ wavs = get_wav_info(args)
+ fout = open(args.output, "w") if args.output else None
+
+ vad_config = FireRedStreamVadConfig(
+ use_gpu = args.use_gpu,
+ smooth_window_size = args.smooth_window_size,
+ speech_threshold = args.speech_threshold,
+ pad_start_frame = args.pad_start_frame,
+ min_speech_frame = args.min_speech_frame,
+ max_speech_frame = args.max_speech_frame,
+ min_silence_frame = args.min_silence_frame,
+ chunk_max_frame = args.chunk_max_frame)
+ logger.info(f"{vad_config}")
+ stream_vad = FireRedStreamVad.from_pretrained(args.model_dir, vad_config)
+
+ for i, (uttid, wav_path) in enumerate(wavs):
+ logger.info("")
+ logger.info(f">>> {i} Processing {wav_path}")
+
+ if args.stream_vad_mode in ["all", "full"]:
+ results, timestamps, dur = vad_full(wav_path, stream_vad, args)
+
+ if args.stream_vad_mode in ["all", "chunkwise"]:
+ results, timestamps, dur = vad_chunkwise(wav_path, stream_vad, args)
+
+ if args.stream_vad_mode in ["all", "framewise"]:
+ results, timestamps, dur = vad_framewise(wav_path, stream_vad, args)
+
+ if fout:
+ d = {"uttid": uttid, "wav_path": wav_path, "dur": dur, "timestamps": timestamps}
+ fout.write(f"{json.dumps(d, ensure_ascii=False)}\n")
+ if args.write_textgrid:
+ write_textgrid(wav_path, dur, timestamps)
+ if args.save_segment_dir:
+ split_and_save_segment(wav_path, timestamps, args.save_segment_dir)
+ if fout: fout.close()
+
+ logger.info("All Stream VAD Done")
+
+
+@timeit
+def vad_framewise(wav_path, stream_vad, args):
+ logger.info("Stream VAD Mode: framewise")
+
+ wav_np, sr = sf.read(wav_path, dtype="int16")
+ assert sr == SAMPLE_RATE
+ n_frame = 0
+ frame_results = []
+ stream_vad.reset()
+ for j in range(0, len(wav_np) - FRAME_LENGTH_SAMPLE + 1, FRAME_SHIFT_SAMPLE):
+ audio_frame = wav_np[j:j+FRAME_LENGTH_SAMPLE]
+ result = stream_vad.detect_frame(audio_frame)
+ n_frame += 1
+ logger.debug(f"{n_frame:4d} {result}")
+ if result.is_speech_start:
+ logger.info(f"Speech start {result.speech_start_frame}")
+ elif result.is_speech_end:
+ logger.info(f"Speech end {result.speech_end_frame}")
+ frame_results.append(result)
+
+ logger.info(f"#frame={len(frame_results)}")
+ timestamps = stream_vad.results_to_timestamps(frame_results)
+ logger.info(f"timestamps(seconds): {timestamps}")
+ dur = len(wav_np) / sr
+ return frame_results, timestamps, dur
+
+
+@timeit
+def vad_chunkwise(wav_path, stream_vad, args):
+ logger.info(f"Stream VAD Mode: chunkwise {args.stream_chunk_frame}")
+ N = args.stream_chunk_frame
+ assert N > 0
+ chunk_length = FRAME_LENGTH_SAMPLE + (N-1)*FRAME_SHIFT_SAMPLE
+ chunk_shift = N * FRAME_SHIFT_SAMPLE
+
+ wav_np, sr = sf.read(wav_path, dtype="int16")
+ assert sr == SAMPLE_RATE
+ n_frame = 0
+ chunk_results = []
+ stream_vad.reset()
+ for j in range(0, len(wav_np), chunk_shift):
+ audio_chunk = wav_np[j:j+chunk_length]
+ results = stream_vad.detect_chunk(audio_chunk)
+ for result in results:
+ n_frame += 1
+ logger.debug(f"{n_frame:4d} {result}")
+ if result.is_speech_start:
+ logger.info(f"Speech start {result.speech_start_frame}")
+ elif result.is_speech_end:
+ logger.info(f"Speech end {result.speech_end_frame}")
+ chunk_results.append(result)
+
+ logger.info(f"#frame={len(chunk_results)}")
+ timestamps = stream_vad.results_to_timestamps(chunk_results)
+ logger.info(f"timestamps(seconds): {timestamps}")
+ dur = len(wav_np) / sr
+ return chunk_results, timestamps, dur
+
+
+@timeit
+def vad_full(wav_path, stream_vad, args):
+ logger.info("Stream VAD Mode: full")
+ frame_results, result = stream_vad.detect_full(wav_path)
+ logger.info(f"Result: {result}")
+ timestamps = result["timestamps"]
+ dur = result["dur"]
+
+ n_frame = 0
+ for frame_result in frame_results:
+ n_frame += 1
+ logger.debug(f"{n_frame:4d} {result}")
+ if frame_result.is_speech_start:
+ logger.info(f"Speech start {frame_result.speech_start_frame}")
+ elif frame_result.is_speech_end:
+ logger.info(f"Speech end {frame_result.speech_end_frame}")
+ logger.info(f"#frame={len(frame_results)}")
+
+ return frame_results, timestamps, dur
+
+
+if __name__ == "__main__":
+ args = parser.parse_args()
+ logger.info(f"{args}")
+ main(args)
diff --git a/fireredasr2s/fireredvad/bin/vad.py b/fireredasr2s/fireredvad/bin/vad.py
new file mode 100644
index 0000000000000000000000000000000000000000..905dc074874130a352c88d088f557eae241af740
--- /dev/null
+++ b/fireredasr2s/fireredvad/bin/vad.py
@@ -0,0 +1,88 @@
+#!/usr/bin/env python3
+
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Kai Huang)
+
+import argparse
+import json
+import logging
+import time
+
+from fireredvad.vad import FireRedVadConfig, FireRedVad
+from fireredvad.utils.io import get_wav_info, write_textgrid, split_and_save_segment
+
+logging.basicConfig(level=logging.INFO,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
+logger = logging.getLogger("fireredvad.bin.vad")
+
+
+parser = argparse.ArgumentParser()
+# Input
+parser.add_argument("--wav_path", type=str)
+parser.add_argument("--wav_paths", type=str, nargs="*")
+parser.add_argument("--wav_scp", type=str)
+parser.add_argument("--wav_dir", type=str)
+# Output
+parser.add_argument("--output", type=str, default="vad_output")
+parser.add_argument("--write_textgrid", type=int, default=0)
+parser.add_argument("--save_segment_dir", type=str, default="")
+# VAD Options
+parser.add_argument('--model_dir', type=str,
+ default="pretrained_models/FireRedVAD-VAD-preview")
+parser.add_argument('--use_gpu', type=int, default=0)
+parser.add_argument("--smooth_window_size", type=int, default=5)
+parser.add_argument("--speech_threshold", type=float, default=0.4)
+parser.add_argument("--min_speech_frame", type=int, default=20)
+parser.add_argument("--max_speech_frame", type=int, default=2000)
+parser.add_argument("--min_silence_frame", type=int, default=20)
+parser.add_argument("--merge_silence_frame", type=int, default=0)
+parser.add_argument("--extend_speech_frame", type=int, default=0)
+parser.add_argument("--chunk_max_frame", type=int, default=30000)
+
+
+def main(args):
+ logger.info("Start VAD...\n")
+ wavs = get_wav_info(args)
+ fout = open(args.output, "w") if args.output else None
+
+ vad_config = FireRedVadConfig(
+ use_gpu = args.use_gpu,
+ smooth_window_size = args.smooth_window_size,
+ speech_threshold = args.speech_threshold,
+ min_speech_frame = args.min_speech_frame,
+ max_speech_frame = args.max_speech_frame,
+ min_silence_frame = args.min_silence_frame,
+ merge_silence_frame = args.merge_silence_frame,
+ extend_speech_frame = args.extend_speech_frame,
+ chunk_max_frame = args.chunk_max_frame)
+ logger.info(f"{vad_config}")
+ vad = FireRedVad.from_pretrained(args.model_dir, vad_config)
+
+ for i, (uttid, wav_path) in enumerate(wavs):
+ logger.info("")
+ logger.info(f">>> {i} Processing {wav_path}")
+ start_time = time.time()
+
+ result, probs = vad.detect(wav_path)
+
+ elapsed = time.time() - start_time
+ dur = result["dur"]
+ rtf = elapsed / dur if dur > 0 else 0
+ logger.info(f"Result: {result}")
+ logger.info(f"Dur={dur} elapsed(ms)={round(elapsed*1000, 2)} RTF={round(rtf, 5)}")
+
+ if fout:
+ fout.write(f"{json.dumps(result, ensure_ascii=False)}\n")
+ if args.write_textgrid:
+ write_textgrid(result["wav_path"], result["dur"], result["timestamps"])
+ if args.save_segment_dir:
+ split_and_save_segment(wav_path, result["timestamps"], args.save_segment_dir)
+ if fout: fout.close()
+
+ logger.info("All VAD Done")
+
+
+
+if __name__ == "__main__":
+ args = parser.parse_args()
+ logger.info(f"{args}")
+ main(args)
diff --git a/fireredasr2s/fireredvad/core/__init__.py b/fireredasr2s/fireredvad/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/fireredasr2s/fireredvad/core/audio_feat.py b/fireredasr2s/fireredvad/core/audio_feat.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a34f646c2ede81c945f09683e2769a5ab86f177
--- /dev/null
+++ b/fireredasr2s/fireredvad/core/audio_feat.py
@@ -0,0 +1,106 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Kai Huang)
+
+import math
+import os
+
+import soundfile as sf
+import kaldiio
+import kaldi_native_fbank as knf
+import numpy as np
+import torch
+
+
+class AudioFeat:
+ def __init__(self, kaldi_cmvn_file):
+ self.cmvn = CMVN(kaldi_cmvn_file) if kaldi_cmvn_file != "" else None
+ self.fbank = KaldifeatFbank(num_mel_bins=80, frame_length=25,
+ frame_shift=10, dither=0)
+
+ def reset(self):
+ pass
+
+ def extract(self, audio):
+ if isinstance(audio, str):
+ wav_np, sample_rate = sf.read(audio, dtype="int16")
+ elif isinstance(audio, (list, tuple)):
+ wav_np, sample_rate = audio
+ else:
+ wav_np = audio
+ sample_rate = 16000
+ assert sample_rate == 16000
+
+ dur = wav_np.shape[0] / sample_rate
+ fbank = self.fbank((sample_rate, wav_np))
+ if self.cmvn is not None:
+ fbank = self.cmvn(fbank)
+ feat = torch.from_numpy(fbank).float()
+ return feat, dur
+
+
+
+class CMVN:
+ def __init__(self, kaldi_cmvn_file):
+ self.dim, self.means, self.inverse_std_variances = \
+ self.read_kaldi_cmvn(kaldi_cmvn_file)
+
+ def __call__(self, x, is_train=False):
+ assert x.shape[-1] == self.dim, "CMVN dim mismatch"
+ out = x - self.means
+ out = out * self.inverse_std_variances
+ return out
+
+ def read_kaldi_cmvn(self, kaldi_cmvn_file):
+ assert os.path.exists(kaldi_cmvn_file)
+ stats = kaldiio.load_mat(kaldi_cmvn_file)
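+        # Kaldi CMVN stats: row 0 holds per-dim sums (frame count in the last
+        # column), row 1 per-dim sums of squares; derive mean and 1/std per dim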
+ assert stats.shape[0] == 2
+ dim = stats.shape[-1] - 1
+ count = stats[0, dim]
+ assert count >= 1
+ floor = 1e-20
+ means = []
+ inverse_std_variances = []
+ for d in range(dim):
+ mean = stats[0, d] / count
+ means.append(mean.item())
+ variance = (stats[1, d] / count) - mean * mean
+ if variance < floor:
+ variance = floor
+ istd = 1.0 / math.sqrt(variance)
+ inverse_std_variances.append(istd)
+ return dim, np.array(means), np.array(inverse_std_variances)
+
+
+
+class KaldifeatFbank:
+ def __init__(self, num_mel_bins=80, frame_length=25, frame_shift=10,
+ dither=0):
+ self.dither = dither
+ opts = knf.FbankOptions()
+ opts.frame_opts.samp_freq = 16000
+ opts.frame_opts.frame_length_ms = 25
+ opts.frame_opts.frame_shift_ms = 10
+ opts.frame_opts.dither = dither
+ opts.frame_opts.snip_edges = True
+ opts.mel_opts.num_bins = num_mel_bins
+ opts.mel_opts.debug_mel = False
+ self.opts = opts
+
+ def __call__(self, wav, is_train=False):
+        if isinstance(wav, str):
+            wav_np, sample_rate = sf.read(wav, dtype="int16")
+        elif isinstance(wav, (tuple, list)) and len(wav) == 2:
+            sample_rate, wav_np = wav
+        else:
+            raise TypeError("wav must be a path or a (sample_rate, waveform) pair")
+        assert len(wav_np.shape) == 1
+
+ dither = self.dither if is_train else 0.0
+ self.opts.frame_opts.dither = dither
+ fbank = knf.OnlineFbank(self.opts)
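+        # kaldi_native_fbank computes features incrementally: feed the whole
+        # waveform, then read out every frame that is ready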
+
+ fbank.accept_waveform(sample_rate, wav_np.tolist())
+ feat = []
+ for i in range(fbank.num_frames_ready):
+ feat.append(fbank.get_frame(i))
+ if len(feat) == 0:
+ return np.zeros((0, self.opts.mel_opts.num_bins))
+ feat = np.vstack(feat)
+ return feat
diff --git a/fireredasr2s/fireredvad/core/constants.py b/fireredasr2s/fireredvad/core/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..1151a4e00194547d11d594e156b5c504fd44ca61
--- /dev/null
+++ b/fireredasr2s/fireredvad/core/constants.py
@@ -0,0 +1,10 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Kai Huang)
+
+SAMPLE_RATE = 16000
+FRAME_LENGTH_MS = 25
+FRAME_SHIFT_MS = 10
+FRAME_LENGTH_S = 0.025
+FRAME_SHIFT_S = 0.010
+FRAME_LENGTH_SAMPLE = int(SAMPLE_RATE * FRAME_LENGTH_MS / 1000)
+FRAME_SHIFT_SAMPLE = int(SAMPLE_RATE * FRAME_SHIFT_MS / 1000)
+FRAME_PER_SECONDS = int(1000 / FRAME_SHIFT_MS)
diff --git a/fireredasr2s/fireredvad/core/detect_model.py b/fireredasr2s/fireredvad/core/detect_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..2112fe9d3e5b43ab78d51ba7ae0d5b9a324a7e44
--- /dev/null
+++ b/fireredasr2s/fireredvad/core/detect_model.py
@@ -0,0 +1,246 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Kai Huang)
+
+import os
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class DetectModel(nn.Module):
+ @classmethod
+ def from_pretrained(cls, model_dir):
+ model_path = os.path.join(model_dir, "model.pth.tar")
+ package = torch.load(model_path,
+ map_location=lambda storage, loc: storage, weights_only=False)
+ model = cls(package["args"])
+ model.load_state_dict(package["model_state_dict"], strict=True)
+ model.eval()
+ return model
+
+ def __init__(self, args):
+ super().__init__()
+ self.dfsmn = DFSMN(args.idim, args.R, args.M, args.H, args.P,
+ args.N1, args.S1, args.N2, args.S2,
+ args.dropout)
+ self.out = torch.nn.Linear(args.H, args.odim)
+
+ @torch.no_grad()
+ def forward(self, feat, caches=None):
+ # type: (Tensor, Optional[List[Tensor]]) -> Tuple[Tensor, List[Tensor]]
+ x, new_caches = self.dfsmn(feat, caches=caches)
+ logits = self.out(x)
+ probs = torch.sigmoid(logits)
+ return probs, new_caches
+
+
+class DFSMN(nn.Module):
+ def __init__(self, D, R, M, H, P, N1, S1, N2=0, S2=0, dropout=0.1):
+ """
+ DFSMN config: Rx[H-P(N1,N2,S1,S2)]-MxH
+ Args:
+ D: input dimension
+ R: number of DFSMN blocks
+ M: number of DNN layers
+ H: hidden size
+ P: projection size
+ N1: lookback order
+ S1: lookback stride
+ N2: lookahead order
+ S2: lookahead stride
+ """
+ super().__init__()
+ # Components
+ # 1st FSMN block connecting input layer, without skip connection
+ self.fc1 = nn.Sequential(nn.Linear(D, H, bias=True),
+ nn.ReLU(),
+ nn.Dropout(dropout))
+ self.fc2 = nn.Sequential(nn.Linear(H, P, bias=True),
+ nn.ReLU(),
+ nn.Dropout(dropout))
+ self.fsmn1 = FSMN(P, N1, S1, N2, S2)
+        # R-1 DFSMN blocks
+ self.fsmns = nn.ModuleList([DFSMNBlock(H, P, N1, S1, N2, S2, dropout) for _ in range(R-1)])
+ # M DNN layers
+ dnn = [nn.Linear(P, H, bias=True), nn.ReLU(), nn.Dropout(dropout)]
+ for l in range(M - 1):
+ dnn += [nn.Linear(H, H, bias=True), nn.ReLU(), nn.Dropout(dropout)]
+ self.dnns = nn.Sequential(*dnn)
+
+ def forward(self, inputs, input_lengths=None, caches=None):
+ # type: (Tensor, Optional[Tensor], Optional[List[Tensor]]) -> Tuple[Tensor, List[Tensor]]
+ """
+ Args:
+ inputs: [N, T, D], padded, T is sequence length, D is input dim
+ mask: processing padding issue, masked position is 1
+ tensor.masked_fill(mask, value) will fill elements of tensor with value where mask is one.
+ Returns:
+ output: [N, T, P]
+ """
+ if input_lengths is None:
+ mask = None
+ else:
+ mask = get_mask_from_lengths(input_lengths)
+ # 1st FSMN
+ h = self.fc1(inputs)
+ p = self.fc2(h)
+ new_caches = []
+        cache = None if caches is None else caches[0]
+ memory, new_cache = self.fsmn1(p, mask=mask, cache=cache)
+ new_caches.append(new_cache)
+
+        # R-1 remaining DFSMN blocks
+        for i, fsmn in enumerate(self.fsmns, start=1):
+            cache = None if caches is None else caches[i]
+            memory, new_cache = fsmn(memory, mask=mask, cache=cache)
+            new_caches.append(new_cache)
+ # M DNN
+ output = self.dnns(memory)
+ return output, new_caches
+
+
+
+def get_mask_from_lengths(lengths):
+ """Mask position is set to 1 for Tensor.masked_fill(mask, value)
+ Args:
+ lengths: (N, )
+ Return:
+ mask: (N, T)
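+    Example:
+        get_mask_from_lengths(torch.tensor([3, 1])) returns
+        tensor([[0, 0, 0], [0, 1, 1]], dtype=torch.uint8)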
+ """
+ N = lengths.size(0)
+ T = torch.max(lengths).item()
+ mask = torch.zeros(N, T).to(lengths.device)
+ for i in range(N):
+ mask[i, lengths[i]:] = 1
+ return mask.to(torch.uint8)
+
+
+
+class DFSMNBlock(nn.Module):
+ def __init__(self, H, P, N1, S1, N2=0, S2=0, dropout=0.1):
+ """
+ DFSMNBlock = [input -> Affine+ReLU -> Affine -> vFSMN -> output]
+ | ^
+ |-------------------------------------------|
+ (skip connection)
+ Args:
+ H: hidden size
+ P: projection size
+ N1: lookback order
+ S1: lookback stride
+ N2: lookahead order
+ S2: lookahead stride
+ """
+ super().__init__()
+ # Hyper-parameter
+ self.H, self.P, self.N1, self.S1, self.N2, self.S2 = H, P, N1, S1, N2, S2
+ # Components
+ # step1. \hat{P}^{l-1} -> H^{l}, nonlinear affine transform
+ self.fc1 = nn.Sequential(nn.Linear(P, H, bias=True),
+ nn.ReLU(),
+ nn.Dropout(dropout))
+ # Step2. H^{l} -> P^{l}, linear affine transform
+ self.fc2 = nn.Linear(H, P, bias=False)
+ # Step3. P^{l}-> \hat{P}^{l}, fsmn layer
+ self.fsmn = FSMN(P, N1, S1, N2, S2)
+
+ def forward(self, inputs, mask=None, cache=None):
+ # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor]
+ """
+ Args:
+ inputs: [N, T, P], padded, T is sequence length, P is projection size
+ mask: processing padding issue, masked position is 1
+ tensor.masked_fill(mask, value) will fill elements of tensor with value where mask is one.
+ Returns:
+ output: [N, T, P]
+ """
+ residual = inputs
+ # step1. \hat{P}^{l-1} -> H^{l}, nonlinear affine transform
+ h = self.fc1(inputs)
+ # Step2. H^{l} -> P^{l}, linear affine transform
+ p = self.fc2(h)
+ # Step3. P^{l}-> \hat{P}^{l}, fsmn layer
+ memory, new_cache = self.fsmn(p, mask=mask, cache=cache)
+ # Step4. skip connection
+ output = memory + residual
+ return output, new_cache
+
+
+class FSMN(nn.Module):
+ def __init__(self, P, N1, S1, N2=0, S2=0):
+ """
+ Args:
+ P: projection size
+ N1: lookback order
+ S1: lookback stride
+ N2: lookahead order
+ S2: lookahead stride
+ """
+ super().__init__()
+ # Hyper-parameter
+ assert N1 >= 1
+ self.N1, self.S1, self.N2, self.S2 = N1, S1, N2, S2
+ # Components
+ # P^{l}-> \hat{P}^{l}
+ self.lookback_padding = (N1-1)*S1
+ self.lookback_filter = nn.Conv1d(in_channels=P, out_channels=P,
+ kernel_size=N1, stride=1,
+ padding=self.lookback_padding, dilation=S1,
+ groups=P, bias=False)
+ if self.N2 > 0:
+ self.lookahead_filter = nn.Conv1d(in_channels=P, out_channels=P,
+ kernel_size=N2, stride=1,
+ padding=(N2-1)*S2, dilation=S2,
+ groups=P, bias=False)
+ else:
+ self.lookahead_filter = nn.Identity()
+
+ def forward(self, inputs, mask=None, cache=None):
+ # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor]
+ """
+ Args:
+ inputs: [N, T, P], padded, T is sequence length, P is projection size
+ mask: processing padding issue, masked position is 1
+ tensor.masked_fill(mask, value) will fill elements of tensor with value where mask is one.
+ Returns:
+ memory: [N, T, P]
+ """
+ T = inputs.size(1)
+ if mask is not None:
+ mask = mask.unsqueeze(-1) # [N, T, 1]
+ inputs = inputs.masked_fill(mask, 0.0)
+
+ inputs = inputs.permute((0, 2, 1)).contiguous() # [N, T, P] -> [N, P, T]
+ residual = inputs
+
+        if cache is not None:
+            inputs = torch.cat((cache, inputs), dim=2)  # (N, P, C+T)
+        # Keep the last lookback_padding frames as the next streaming cache.
+        # Guard lookback_padding == 0, where [-0:] would wrongly keep everything.
+        if self.lookback_padding > 0:
+            new_cache = inputs[:, :, -self.lookback_padding:]  # (N, P, Co)
+        else:
+            new_cache = inputs[:, :, :0]
+
+        # P^{l}-> \hat{P}^{l}, fsmn layer
+        lookback = self.lookback_filter(inputs)
+        if self.N1 > 1:
+            lookback = lookback[:, :, :-(self.N1-1)*self.S1]  # drop right padding
+        if cache is not None:
+            lookback = lookback[:, :, cache.size(2):]  # drop frames covered by cache
+        memory = residual + lookback
+
+ if self.N2 > 0 and T > 1:
+ lookahead = self.lookahead_filter(inputs)
+ memory += F.pad(lookahead[:, :, self.N2*self.S2:], (0, self.S2))
+ memory = memory.permute((0, 2, 1)).contiguous() # [N, P, T] -> [N, T, P]
+
+ if mask is not None:
+ memory = memory.masked_fill(mask, 0.0)
+ return memory, new_cache
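+
+
+if __name__ == "__main__":
+    # Streaming-consistency sketch (illustrative; all hyper-parameters here are
+    # made up, not a released configuration): with no lookahead (N2=0), chunked
+    # inference with caches should reproduce the full-sequence forward exactly.
+    torch.manual_seed(0)
+    dfsmn = DFSMN(D=40, R=3, M=2, H=64, P=32, N1=5, S1=1, N2=0, S2=0,
+                  dropout=0.0).eval()
+    x = torch.randn(1, 20, 40)
+    with torch.no_grad():
+        full, _ = dfsmn(x)
+        out1, caches = dfsmn(x[:, :10], caches=None)
+        out2, _ = dfsmn(x[:, 10:], caches=caches)
+    print(torch.allclose(full, torch.cat([out1, out2], dim=1), atol=1e-6))  # True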
diff --git a/fireredasr2s/fireredvad/core/stream_vad_postprocessor.py b/fireredasr2s/fireredvad/core/stream_vad_postprocessor.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3feac6e66f55a9c1287a0921b18b3a912dbe274
--- /dev/null
+++ b/fireredasr2s/fireredvad/core/stream_vad_postprocessor.py
@@ -0,0 +1,163 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Kai Huang)
+
+import enum
+from collections import deque
+from dataclasses import dataclass
+
+
+@dataclass
+class StreamVadFrameResult:
+ frame_idx: int # 1-based
+ is_speech: bool
+ raw_prob: float
+ smoothed_prob: float
+ is_speech_start: bool = False
+ is_speech_end: bool = False
+ speech_start_frame: int = -1 # 1-based
+ speech_end_frame: int = -1 # 1-based
+
+
+@enum.unique
+class VadState(enum.Enum):
+ SILENCE = 0
+ POSSIBLE_SPEECH = 1
+ SPEECH = 2
+ POSSIBLE_SILENCE = 3
+
+
+class StreamVadPostprocessor:
+ def __init__(self,
+ smooth_window_size,
+ speech_threshold,
+ pad_start_frame,
+ min_speech_frame,
+ max_speech_frame,
+ min_silence_frame):
+ self.smooth_window_size = max(1, smooth_window_size)
+ self.speech_threshold = speech_threshold
+ self.pad_start_frame = max(self.smooth_window_size, pad_start_frame)
+ self.min_speech_frame = min_speech_frame
+ self.max_speech_frame = max_speech_frame
+ self.min_silence_frame = min_silence_frame
+ self.reset()
+
+ def reset(self):
+ self.frame_cnt = 0
+ # smooth window
+ self.smooth_window = deque()
+ self.smooth_window_sum = 0.0
+ # state transition
+ self.state = VadState.SILENCE
+ self.speech_cnt = 0
+ self.silence_cnt = 0
+ self.hit_max_speech = False
+ self.last_speech_start_frame = -1
+ self.last_speech_end_frame = -1
+
+ def process_one_frame(self, raw_prob):
+        assert isinstance(raw_prob, float)
+        assert 0.0 <= raw_prob <= 1.0
+ self.frame_cnt += 1
+
+ smoothed_prob = self.smooth_prob(raw_prob)
+
+ is_speech = self.apply_threshold(smoothed_prob)
+
+ result = StreamVadFrameResult(
+            frame_idx=self.frame_cnt,
+ is_speech=is_speech,
+ raw_prob=round(raw_prob, 3),
+ smoothed_prob=round(smoothed_prob, 3)
+ )
+
+ result = self.state_transition(is_speech, result)
+
+ return result
+
+ def smooth_prob(self, prob):
+ if self.smooth_window_size <= 1:
+ return prob
+ self.smooth_window.append(prob)
+ self.smooth_window_sum += prob
+ if len(self.smooth_window) > self.smooth_window_size:
+ left = self.smooth_window.popleft()
+ self.smooth_window_sum -= left
+ smoothed_prob = self.smooth_window_sum / len(self.smooth_window)
+ return smoothed_prob
+
+ def apply_threshold(self, prob):
+ return int(prob >= self.speech_threshold)
+
+ def state_transition(self, is_speech, result):
+ if self.hit_max_speech:
+ result.is_speech_start = True
+ result.speech_start_frame = self.frame_cnt
+ self.last_speech_start_frame = result.speech_start_frame
+ self.hit_max_speech = False
+
+ if self.state == VadState.SILENCE:
+ if is_speech:
+ self.state = VadState.POSSIBLE_SPEECH
+ self.speech_cnt += 1
+ else:
+ self.silence_cnt += 1
+ self.speech_cnt = 0
+
+ elif self.state == VadState.POSSIBLE_SPEECH:
+ if is_speech:
+ self.speech_cnt += 1
+ if self.speech_cnt >= self.min_speech_frame:
+ self.state = VadState.SPEECH
+ result.is_speech_start = True
+ result.speech_start_frame = max(1,
+ self.frame_cnt - self.speech_cnt + 1 - self.pad_start_frame,
+ self.last_speech_end_frame + 1)
+ self.last_speech_start_frame = result.speech_start_frame
+ self.silence_cnt = 0
+ else:
+ self.state = VadState.SILENCE
+ self.silence_cnt = 1
+ self.speech_cnt = 0
+
+ elif self.state == VadState.SPEECH:
+ self.speech_cnt += 1
+ if is_speech:
+ self.silence_cnt = 0
+ if self.speech_cnt >= self.max_speech_frame:
+ self.hit_max_speech = True
+ self.speech_cnt = 0
+ result.is_speech_end = True
+ result.speech_end_frame = self.frame_cnt
+ result.speech_start_frame = self.last_speech_start_frame
+ self.last_speech_start_frame = -1
+ self.last_speech_end_frame = result.speech_end_frame
+ else:
+ self.state = VadState.POSSIBLE_SILENCE
+ self.silence_cnt += 1
+
+ elif self.state == VadState.POSSIBLE_SILENCE:
+ self.speech_cnt += 1
+ if is_speech:
+ self.state = VadState.SPEECH
+ self.silence_cnt = 0
+ if self.speech_cnt >= self.max_speech_frame:
+ self.hit_max_speech = True
+ self.speech_cnt = 0
+ result.is_speech_end = True
+ result.speech_end_frame = self.frame_cnt
+ result.speech_start_frame = self.last_speech_start_frame
+ self.last_speech_start_frame = -1
+ self.last_speech_end_frame = result.speech_end_frame
+
+ else:
+ self.silence_cnt += 1
+ if self.silence_cnt >= self.min_silence_frame:
+ self.state = VadState.SILENCE
+ result.is_speech_end = True
+ result.speech_end_frame = self.frame_cnt
+ result.speech_start_frame = self.last_speech_start_frame
+ self.last_speech_end_frame = result.speech_end_frame
+ self.last_speech_start_frame = -1
+ self.speech_cnt = 0
+
+ return result
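+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (illustrative; these thresholds are made up):
+    # stream synthetic frame probabilities through the post-processor and
+    # print the detected segment boundaries.
+    pp = StreamVadPostprocessor(smooth_window_size=1, speech_threshold=0.5,
+                                pad_start_frame=0, min_speech_frame=3,
+                                max_speech_frame=100, min_silence_frame=5)
+    for prob in [0.9] * 20 + [0.1] * 20:
+        r = pp.process_one_frame(prob)
+        if r.is_speech_start:
+            print("speech starts at frame", r.speech_start_frame)
+        if r.is_speech_end:
+            print("speech ends at frame", r.speech_end_frame)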
diff --git a/fireredasr2s/fireredvad/core/vad_postprocessor.py b/fireredasr2s/fireredvad/core/vad_postprocessor.py
new file mode 100644
index 0000000000000000000000000000000000000000..e95cdf5c30a3ea476aafe70669fd2370c5455ad0
--- /dev/null
+++ b/fireredasr2s/fireredvad/core/vad_postprocessor.py
@@ -0,0 +1,247 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Kai Huang)
+
+import enum
+import logging
+from collections import deque
+
+import numpy as np
+
+from .constants import FRAME_LENGTH_S, FRAME_SHIFT_S
+
+logger = logging.getLogger(__name__)
+
+
+@enum.unique
+class VadState(enum.Enum):
+ SILENCE = 0
+ POSSIBLE_SPEECH = 1
+ SPEECH = 2
+ POSSIBLE_SILENCE = 3
+
+
+class VadPostprocessor:
+ def __init__(self, smooth_window_size,
+ prob_threshold,
+ min_speech_frame,
+ max_speech_frame,
+ min_silence_frame,
+ merge_silence_frame,
+ extend_speech_frame):
+ self.smooth_window_size = max(1, smooth_window_size)
+ self.prob_threshold = prob_threshold
+ self.min_speech_frame = min_speech_frame
+ self.max_speech_frame = max_speech_frame
+ self.min_silence_frame = min_silence_frame
+ self.merge_silence_frame = merge_silence_frame
+ self.extend_speech_frame = extend_speech_frame
+
+ def process(self, raw_probs):
+ if not raw_probs:
+ return []
+
+ smoothed_probs = self._smooth_prob(raw_probs)
+
+ binary_preds = self._apply_threshold(smoothed_probs)
+
+ # decision: 0 means silence, 1 means speech
+ decisions = self._smooth_preds_with_state_machine(binary_preds)
+
+ fixed_decisions = self._fix_smooth_window_start(decisions)
+ smoothed_decisions = self._merge_short_silence_segments(fixed_decisions)
+ extend_decisions = self._extend_speech_segments(smoothed_decisions)
+ final_decisions = self._split_long_speech_segments(extend_decisions, raw_probs)
+ # don't call _merge_short_silence_segments after _split_long_speech_segments
+
+ return final_decisions
+
+ def decision_to_segment(self, decisions, wav_dur=None):
+ segments = []
+ speech_start = None
+ for t, decision in enumerate(decisions):
+ if decision == 1 and speech_start is None:
+ speech_start = t
+ elif decision == 0 and speech_start is not None:
+ if (t - speech_start) < self.min_speech_frame:
+ logger.warning("Unexpected short speech segment, check vad_postprocessor.py")
+ segments.append((speech_start * FRAME_SHIFT_S,
+ t * FRAME_SHIFT_S))
+ speech_start = None
+ if speech_start is not None:
+ t = len(decisions) - 1
+ if (t - speech_start) < self.min_speech_frame:
+ logger.warning("Unexpected short speech segment, check vad_postprocessor.py")
+ end_time = len(decisions) * FRAME_SHIFT_S + FRAME_LENGTH_S
+ if wav_dur is not None:
+ end_time = min(end_time, wav_dur)
+ segments.append((speech_start * FRAME_SHIFT_S,
+ end_time))
+ segments = [(round(s, 3), round(e, 3)) for s, e in segments]
+ return segments
+
+ def _smooth_prob_simple(self, probs):
+ if self.smooth_window_size <= 1:
+ return probs
+ smoothed_probs = probs.copy()
+ window = deque()
+ window_sum = 0.0
+ for i, p in enumerate(probs):
+ window.append(p)
+ window_sum += p
+ if len(window) > self.smooth_window_size:
+ left = window.popleft()
+ window_sum -= left
+ window_avg = window_sum / len(window)
+ smoothed_probs[i] = window_avg
+ return smoothed_probs
+
+ def _smooth_prob(self, probs):
+ if self.smooth_window_size <= 1:
+ return np.asarray(probs)
+ probs_np = np.array(probs)
+ kernel = np.ones(self.smooth_window_size) / self.smooth_window_size
+        # mode='full' truncated to len(probs) yields a causal moving average
+        # (each frame averages itself and the preceding frames), matching the
+        # deque-based _smooth_prob_simple above.
+        smoothed = np.convolve(probs_np, kernel, mode='full')[:len(probs)]
+        # Boundary: the first window_size-1 frames use a cumulative average.
+        for i in range(min(self.smooth_window_size - 1, len(probs))):
+            smoothed[i] = np.mean(probs_np[:i+1])
+        return smoothed
+
+ def _apply_threshold_simple(self, probs):
+ return [int(p >= self.prob_threshold) for p in probs]
+
+ def _apply_threshold(self, probs):
+ probs_np = np.asarray(probs)
+ return (probs_np >= self.prob_threshold).astype(int).tolist()
+
+ def _smooth_preds_with_state_machine(self, binary_preds):
+ """
+ state transition is constrained by min_speech_frame & min_silence_frame
+ """
+ if self.min_speech_frame <= 0 and self.min_silence_frame <= 0:
+ return binary_preds
+ decisions = [0] * len(binary_preds)
+ state = VadState.SILENCE
+ speech_start = -1
+ silence_start = -1
+ for t, is_speech in enumerate(binary_preds):
+ # State transition
+ if state == VadState.SILENCE:
+ if is_speech:
+ state = VadState.POSSIBLE_SPEECH
+ speech_start = t
+
+ elif state == VadState.POSSIBLE_SPEECH:
+ if is_speech:
+ assert speech_start != -1
+ if t - speech_start >= self.min_speech_frame:
+ state = VadState.SPEECH
+ decisions[speech_start:t] = [1] * (t - speech_start)
+ else:
+ state = VadState.SILENCE
+ speech_start = -1
+
+ elif state == VadState.SPEECH:
+ if not is_speech:
+ state = VadState.POSSIBLE_SILENCE
+ silence_start = t
+
+ elif state == VadState.POSSIBLE_SILENCE:
+ if not is_speech:
+ assert silence_start != -1
+ if t - silence_start >= self.min_silence_frame:
+ state = VadState.SILENCE
+ speech_start = -1
+ else:
+ state = VadState.SPEECH
+ silence_start = -1
+
+ # current frame's decision
+ if state == VadState.SPEECH or state == VadState.POSSIBLE_SILENCE:
+ decision = 1
+ elif state == VadState.SILENCE or state == VadState.POSSIBLE_SPEECH:
+ decision = 0
+ else:
+ raise ValueError("Impossible VAD state")
+
+ decisions[t] = decision
+ return decisions
+
+ def _fix_smooth_window_start(self, decisions):
+ new_decisions = decisions.copy()
+ for t, decision in enumerate(decisions):
+ if t > 0 and decisions[t-1] == 0 and decision == 1:
+ start = max(0, t-self.smooth_window_size)
+ new_decisions[start:t] = [1] * (t - start)
+ return new_decisions
+
+ def _merge_short_silence_segments(self, decisions):
+ if self.merge_silence_frame <= 0:
+ return decisions
+ new_decisions = decisions.copy()
+ silence_start = None
+ for t, decision in enumerate(decisions):
+ if t > 0 and decisions[t-1] == 1 and decision == 0 and silence_start is None:
+ silence_start = t
+ elif t > 0 and decisions[t-1] == 0 and decision == 1 and silence_start is not None:
+ silence_frame = t - silence_start
+ if silence_frame < self.merge_silence_frame:
+ new_decisions[silence_start:t] = [1] * silence_frame
+ silence_start = None
+ return new_decisions
+
+ def _extend_speech_segments_simple(self, decisions):
+ """
+ extend N frames before & after speech segments
+ """
+ if self.extend_speech_frame <= 0:
+ return decisions
+ new_decisions = decisions.copy()
+ for t, decision in enumerate(decisions):
+ if decision == 1:
+ start = max(0, t - self.extend_speech_frame)
+ end = min(len(decisions), t + self.extend_speech_frame + 1)
+ new_decisions[start:end] = [1] * (end - start)
+ return new_decisions
+
+ def _extend_speech_segments(self, decisions):
+ """
+ extend N frames before & after speech segments
+ """
+ if self.extend_speech_frame <= 0:
+ return decisions
+ decisions_np = np.array(decisions)
+ kernel = np.ones(2 * self.extend_speech_frame + 1)
+ extended = np.convolve(decisions_np, kernel, mode='same')
+ return (extended > 0).astype(int).tolist()
+
+ def _split_long_speech_segments(self, decisions, probs):
+ new_decisions = decisions.copy()
+ segments = self.decision_to_segment(decisions)
+ for start_s, end_s in segments:
+            # round() guards against float error: int(0.25 / 0.01) == 24
+            start_frame = int(round(start_s / FRAME_SHIFT_S))
+            end_frame = int(round(end_s / FRAME_SHIFT_S))
+ dur_frames = end_frame - start_frame
+ if dur_frames > self.max_speech_frame:
+ segment_probs = probs[start_frame:end_frame]
+ split_points = self._find_split_points(segment_probs)
+ for split_point in split_points:
+ split_frame = start_frame + split_point
+ new_decisions[split_frame] = 0
+ return new_decisions
+
+ def _find_split_points(self, probs):
+ split_points = []
+ length = len(probs)
+ start = 0
+ while start < length:
+ if (length - start) <= self.max_speech_frame:
+ break
+ window_start = int(start + self.max_speech_frame / 2)
+ window_end = int(start + self.max_speech_frame)
+ window_probs = probs[window_start:window_end]
+
+ min_index = window_start + np.argmin(window_probs)
+ split_points.append(min_index)
+
+ start = min_index + 1
+ return split_points
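+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (illustrative; these thresholds are made up):
+    # post-process synthetic frame probabilities into speech segments.
+    pp = VadPostprocessor(smooth_window_size=1, prob_threshold=0.5,
+                          min_speech_frame=3, max_speech_frame=1000,
+                          min_silence_frame=5, merge_silence_frame=0,
+                          extend_speech_frame=0)
+    decisions = pp.process([0.9] * 20 + [0.1] * 20)
+    print(pp.decision_to_segment(decisions, wav_dur=0.4))  # [(0.0, 0.25)]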
diff --git a/fireredasr2s/fireredvad/stream_vad.py b/fireredasr2s/fireredvad/stream_vad.py
new file mode 100644
index 0000000000000000000000000000000000000000..6338841901c0c000d25006fe6adecf386de3deea
--- /dev/null
+++ b/fireredasr2s/fireredvad/stream_vad.py
@@ -0,0 +1,185 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Kai Huang)
+
+import logging
+import os
+from dataclasses import dataclass
+from typing import List, Tuple, Union
+
+import torch
+import numpy as np
+
+from .core.constants import FRAME_LENGTH_SAMPLE, FRAME_PER_SECONDS
+from .core.audio_feat import AudioFeat
+from .core.detect_model import DetectModel
+from .core.stream_vad_postprocessor import StreamVadPostprocessor, StreamVadFrameResult
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class FireRedStreamVadConfig:
+ use_gpu: bool = False
+ smooth_window_size: int = 5
+ speech_threshold: float = 0.5
+    pad_start_frame: int = 5
+ min_speech_frame: int = 8
+ max_speech_frame: int = 2000 # 20s
+ min_silence_frame: int = 20
+ chunk_max_frame: int = 30000 # 300s
+ def __post_init__(self):
+ if self.speech_threshold < 0 or self.speech_threshold > 1:
+ raise ValueError("speech_threshold must be in [0, 1]")
+ if self.min_speech_frame <= 0:
+ raise ValueError("min_speech_frame must be positive")
+
+
+
+class FireRedStreamVad:
+ @classmethod
+ def from_pretrained(cls, model_dir, config=FireRedStreamVadConfig()):
+ # Feat
+ cmvn_path = os.path.join(model_dir, "cmvn.ark")
+ feat_extractor = AudioFeat(cmvn_path)
+
+ # Load & Build Model
+ vad_model = DetectModel.from_pretrained(model_dir)
+ if config.use_gpu:
+ vad_model.cuda()
+ else:
+ vad_model.cpu()
+
+ # Build Postprocessor
+ postprocessor = StreamVadPostprocessor(
+ config.smooth_window_size,
+ config.speech_threshold,
+ config.pad_start_frame,
+ config.min_speech_frame,
+ config.max_speech_frame,
+ config.min_silence_frame)
+ return cls(feat_extractor, vad_model, postprocessor, config)
+
+ def __init__(self, audio_feat, vad_model, postprocessor, config):
+ self.audio_feat = audio_feat
+ self.vad_model = vad_model
+ self.postprocessor = postprocessor
+ self.config = config
+ self.model_caches = None
+
+ def reset(self):
+ self.model_caches = None
+ self.audio_feat.reset()
+ self.postprocessor.reset()
+
+ def detect_frame(self, audio_frame: np.ndarray) -> StreamVadFrameResult:
+ if len(audio_frame) != FRAME_LENGTH_SAMPLE:
+ raise ValueError(f"Expected {FRAME_LENGTH_SAMPLE} samples, got {len(audio_frame)}")
+
+ feat, dur = self.audio_feat.extract(audio_frame)
+ if self.config.use_gpu:
+ feat = feat.cuda()
+
+ prob, self.model_caches = self.vad_model.forward(
+ feat.unsqueeze(0), caches=self.model_caches)
+ raw_prob = prob.cpu().squeeze().tolist()
+
+ frame_result = self.postprocessor.process_one_frame(raw_prob)
+ return frame_result
+
+ def detect_chunk(self, audio_chunk: np.ndarray) -> List[StreamVadFrameResult]:
+ feats, dur = self.audio_feat.extract(audio_chunk)
+ if self.config.use_gpu:
+ feats = feats.cuda()
+
+ probs, self.model_caches = self.vad_model.forward(
+ feats.unsqueeze(0), caches=self.model_caches)
+ raw_probs = probs.cpu().squeeze().tolist()
+ if isinstance(raw_probs, float):
+ raw_probs = [raw_probs]
+
+ chunk_results = []
+ for t, raw_prob in enumerate(raw_probs):
+ stream_vad_frame_result = self.postprocessor.process_one_frame(raw_prob)
+ chunk_results.append(stream_vad_frame_result)
+ return chunk_results
+
+ def detect_full(self, audio: Union[str, np.ndarray]) -> Tuple[List[StreamVadFrameResult], dict]:
+ self.reset()
+ feats, dur = self.audio_feat.extract(audio)
+ if self.config.use_gpu:
+ feats = feats.cuda()
+
+ if feats.size(0) <= self.config.chunk_max_frame:
+ probs, _ = self.vad_model.forward(feats.unsqueeze(0))
+ else:
+ logger.warning(f"Too long input, split every {self.config.chunk_max_frame} frames")
+ chunk_probs = []
+ chunks = feats.split(self.config.chunk_max_frame, dim=0)
+ for chunk in chunks:
+ chunk_prob, _ = self.vad_model.forward(chunk.unsqueeze(0))
+ chunk_probs.append(chunk_prob)
+ probs = torch.cat(chunk_probs, dim=1)
+        raw_probs = probs.squeeze().cpu().tolist()  # (T,)
+ if isinstance(raw_probs, float):
+ raw_probs = [raw_probs]
+
+ frame_results = []
+ for t, raw_prob in enumerate(raw_probs):
+ stream_vad_frame_result = self.postprocessor.process_one_frame(raw_prob)
+ frame_results.append(stream_vad_frame_result)
+ self.reset()
+
+ # Format result
+ timestamps = self.results_to_timestamps(frame_results)
+ result = {"dur": round(dur, 3),
+ "timestamps": timestamps}
+ if isinstance(audio, str):
+ result["wav_path"] = audio
+ return frame_results, result
+
+    def set_mode(self, mode: int = 0):
+        if mode == 0:    # VERY_PERMISSIVE
+            self.config.speech_threshold = 0.3
+            self.config.min_speech_frame = 8
+            self.config.min_silence_frame = 20
+        elif mode == 1:  # PERMISSIVE
+            self.config.speech_threshold = 0.5
+            self.config.min_speech_frame = 10
+            self.config.min_silence_frame = 15
+        elif mode == 2:  # AGGRESSIVE
+            self.config.speech_threshold = 0.7
+            self.config.min_speech_frame = 15
+            self.config.min_silence_frame = 10
+        elif mode == 3:  # VERY_AGGRESSIVE
+            self.config.speech_threshold = 0.9
+            self.config.min_speech_frame = 20
+            self.config.min_silence_frame = 5
+        else:
+            raise ValueError(f"mode must be 0, 1, 2 or 3, got {mode}")
+ self.postprocessor.speech_threshold = self.config.speech_threshold
+ self.postprocessor.min_speech_frame = self.config.min_speech_frame
+ self.postprocessor.min_silence_frame = self.config.min_silence_frame
+
+ @classmethod
+ def results_to_timestamps(cls, results):
+ results = sorted(results, key=lambda r: r.frame_idx)
+ # Get frame index (0-based)
+ frame_timestamps = []
+ start, end = -1, -1
+ for r in results:
+ if r.is_speech_start:
+                if start != -1:
+                    logger.warning("New speech start before previous segment ended")
+ start = max(0, r.speech_start_frame - 1)
+ end = -1
+ elif r.is_speech_end:
+ assert end == -1
+ end = max(0, r.speech_end_frame - 1)
+ frame_timestamps.append((start, end))
+ start, end = -1, -1
+ if start != -1:
+ assert end == -1
+ end = results[-1].frame_idx - 1
+ frame_timestamps.append((start, end))
+ # Convert to seconds
+ timestamps = []
+ for s, e in frame_timestamps:
+ timestamps.append((s/FRAME_PER_SECONDS, e/FRAME_PER_SECONDS))
+ return timestamps
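+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (illustrative): offline detection over a whole file.
+    # "pretrained_models/FireRedVAD" is a hypothetical local path holding the
+    # downloaded model.pth.tar and cmvn.ark; "example.wav" is a hypothetical
+    # 16 kHz mono recording.
+    vad = FireRedStreamVad.from_pretrained("pretrained_models/FireRedVAD")
+    frame_results, result = vad.detect_full("example.wav")
+    print(result["dur"], result["timestamps"])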
diff --git a/fireredasr2s/fireredvad/utils/__init__.py b/fireredasr2s/fireredvad/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/fireredasr2s/fireredvad/utils/io.py b/fireredasr2s/fireredvad/utils/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..22946c70e98fe800467d5705b97742b0f95a8e8e
--- /dev/null
+++ b/fireredasr2s/fireredvad/utils/io.py
@@ -0,0 +1,107 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Kai Huang)
+
+import functools
+import glob
+import logging
+import os
+import time
+
+import soundfile as sf
+from textgrid import TextGrid, IntervalTier
+
+logger = logging.getLogger(__name__)
+
+
+def get_wav_info(args):
+ """
+ Returns:
+ wavs: list of (uttid, wav_path)
+ """
+ base = lambda p: os.path.basename(p).replace(".wav", "")
+ if args.wav_path:
+ wavs = [(base(args.wav_path), args.wav_path)]
+ elif args.wav_paths and len(args.wav_paths) >= 1:
+ wavs = [(base(p), p) for p in sorted(args.wav_paths)]
+ elif args.wav_scp:
+ with open(args.wav_scp) as fin:
+ wavs = [line.strip().split() for line in fin]
+ elif args.wav_dir:
+ wavs = glob.glob(f"{args.wav_dir}/**/*.wav", recursive=True)
+ wavs = [(base(p), p) for p in sorted(wavs)]
+ else:
+ raise ValueError("Please provide valid wav info")
+ logger.info(f"#wavs={len(wavs)}")
+ return wavs
+
+
+def write_textgrid(wav_path, wav_dur, event):
+ textgrid_file = wav_path.replace(".wav", ".TextGrid")
+ logger.info(f"Write {textgrid_file}")
+ textgrid = TextGrid(maxTime=wav_dur)
+ tier = IntervalTier(name="voice", maxTime=wav_dur)
+ for start_s, end_s in event:
+ if start_s == end_s:
+ logger.warning(f"Write TG, skip start=end {start_s}")
+ continue
+ start_s = max(start_s, 0)
+ end_s = min(end_s, wav_dur)
+ tier.add(minTime=start_s, maxTime=end_s, mark="1")
+ textgrid.append(tier)
+ textgrid.write(textgrid_file)
+
+
+def write_event_textgrid(wav_path, wav_dur, event2starts_ends_s):
+ textgrid_file = wav_path.replace(".wav", ".TextGrid")
+ logger.info(f"Write {textgrid_file}")
+ textgrid = TextGrid(maxTime=wav_dur)
+ for event, starts_ends_s in event2starts_ends_s.items():
+ tier = IntervalTier(name=event, maxTime=wav_dur)
+ for start_s, end_s in starts_ends_s:
+ if start_s == end_s:
+ logger.warning(f"Write TG, skip start=end {start_s}")
+ continue
+ start_s = max(start_s, 0)
+ end_s = min(end_s, wav_dur)
+ tier.add(minTime=start_s, maxTime=end_s, mark="1")
+ textgrid.append(tier)
+ textgrid.write(textgrid_file)
+
+
+
+def split_and_save_segment(wav_path, timestamps, save_segment_dir):
+ logger.info("Split & save segment")
+ os.makedirs(save_segment_dir, exist_ok=True)
+ wav_np, sample_rate = sf.read(wav_path, dtype="int16")
+ for j, (start_s, end_s) in enumerate(timestamps):
+ uttid = wav_path.split("/")[-1].replace(".wav", "")
+ seg_id = f"{uttid}_{j}_{int(start_s*1000)}_{int(end_s*1000)}"
+ seg_path = f"{save_segment_dir}/{seg_id}.wav"
+ start, end = int(start_s * sample_rate), int(end_s * sample_rate)
+ sf.write(seg_path, wav_np[start:end], samplerate=sample_rate)
+
+
+def split_and_save_event_segment(wav_path, event2timestamps, save_segment_dir):
+ logger.info("Split & save segment")
+ os.makedirs(save_segment_dir, exist_ok=True)
+ wav_np, sample_rate = sf.read(wav_path, dtype="int16")
+ for event, timestamps in event2timestamps.items():
+ for i, (start_s, end_s) in enumerate(timestamps):
+ uttid = wav_path.split("/")[-1].replace(".wav", "")
+ seg_id = f"{uttid}_{event}_{i}_{int(start_s*1000)}_{int(end_s*1000)}"
+ seg_path = f"{save_segment_dir}/{seg_id}.wav"
+ start, end = int(start_s * sample_rate), int(end_s * sample_rate)
+ sf.write(seg_path, wav_np[start:end], samplerate=sample_rate)
+
+
+def timeit(func):
+ # dur must be last return value of func
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ start = time.time()
+ r = func(*args, **kwargs)
+ elapsed = time.time() - start
+ dur = r[-1]
+ rtf = elapsed / dur if dur else 0
+ logger.info(f"RTF={round(rtf, 5)}, elapsed={round(elapsed*1000, 2)}ms, dur={dur}s")
+ return r
+ return wrapper
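+
+
+if __name__ == "__main__":
+    # Minimal sketch of the timeit decorator (illustrative): the decorated
+    # function must return the processed audio duration (seconds) as its
+    # last value so the real-time factor (RTF) can be derived.
+    logging.basicConfig(level=logging.INFO)
+
+    @timeit
+    def fake_detect():
+        time.sleep(0.01)
+        return {"timestamps": []}, 1.0  # (result, dur)
+
+    fake_detect()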
diff --git a/fireredasr2s/fireredvad/vad.py b/fireredasr2s/fireredvad/vad.py
new file mode 100644
index 0000000000000000000000000000000000000000..78266c2ad1513b3328117a94af2c05d8f7438f08
--- /dev/null
+++ b/fireredasr2s/fireredvad/vad.py
@@ -0,0 +1,98 @@
+# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Kai Huang)
+
+import logging
+import os
+from dataclasses import dataclass
+
+import torch
+
+from .core.audio_feat import AudioFeat
+from .core.detect_model import DetectModel
+from .core.vad_postprocessor import VadPostprocessor
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class FireRedVadConfig:
+ use_gpu: bool = False
+ smooth_window_size: int = 5
+ speech_threshold: float = 0.4
+ min_speech_frame: int = 20
+ max_speech_frame: int = 2000 # 20s
+ min_silence_frame: int = 20
+ merge_silence_frame: int = 0
+ extend_speech_frame: int = 0
+ chunk_max_frame: int = 30000 # 300s
+ def __post_init__(self):
+ if self.speech_threshold < 0 or self.speech_threshold > 1:
+ raise ValueError("speech_threshold must be in [0, 1]")
+ if self.min_speech_frame <= 0:
+ raise ValueError("min_speech_frame must be positive")
+
+
+
+class FireRedVad:
+ @classmethod
+ def from_pretrained(cls, model_dir, config=FireRedVadConfig()):
+ # Build Feat Extractor
+ cmvn_path = os.path.join(model_dir, "cmvn.ark")
+ audio_feat = AudioFeat(cmvn_path)
+
+ # Build Model
+ vad_model = DetectModel.from_pretrained(model_dir)
+ if config.use_gpu:
+ vad_model.cuda()
+ else:
+ vad_model.cpu()
+
+ # Build Postprocessor
+ vad_postprocessor = VadPostprocessor(
+ config.smooth_window_size,
+ config.speech_threshold,
+ config.min_speech_frame,
+ config.max_speech_frame,
+ config.min_silence_frame,
+ config.merge_silence_frame,
+ config.extend_speech_frame)
+ return cls(audio_feat, vad_model, vad_postprocessor, config)
+
+ def __init__(self, audio_feat, vad_model, vad_postprocessor, config):
+ self.audio_feat = audio_feat
+ self.vad_model = vad_model
+ self.vad_postprocessor = vad_postprocessor
+ self.config = config
+
+ def detect(self, audio, do_postprocess=True):
+ # Extract feat
+ feats, dur = self.audio_feat.extract(audio)
+ if self.config.use_gpu:
+ feats = feats.cuda()
+
+ # Model inference
+ if feats.size(0) <= self.config.chunk_max_frame:
+ probs, _ = self.vad_model.forward(feats.unsqueeze(0))
+ probs = probs.cpu().squeeze() # (T,)
+ else:
+ logger.warning(f"Too long input, split every {self.config.chunk_max_frame} frames")
+ chunk_probs = []
+ chunks = feats.split(self.config.chunk_max_frame, dim=0)
+ for chunk in chunks:
+ chunk_prob, _ = self.vad_model.forward(chunk.unsqueeze(0))
+ chunk_probs.append(chunk_prob.cpu())
+ probs = torch.cat(chunk_probs, dim=1)
+ probs = probs.squeeze() # (T,)
+
+ if not do_postprocess:
+ return None, probs
+
+ # Prob Postprocess
+ decisions = self.vad_postprocessor.process(probs.tolist())
+ starts_ends_s = self.vad_postprocessor.decision_to_segment(decisions, dur)
+
+ # Format result
+ result = {"dur": round(dur, 3),
+ "timestamps": starts_ends_s}
+ if isinstance(audio, str):
+ result["wav_path"] = audio
+ return result, probs
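+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (illustrative): the model directory and wav path are
+    # hypothetical; point them at the downloaded FireRedVAD weights and a
+    # 16 kHz mono recording.
+    vad = FireRedVad.from_pretrained("pretrained_models/FireRedVAD")
+    result, probs = vad.detect("example.wav")
+    print(result["dur"], result["timestamps"])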
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7bb31b39072132cf94248777d496ad46b1f78fce
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,11 @@
+--extra-index-url https://download.pytorch.org/whl/cu118
+torch==2.1.0+cu118
+torchaudio==2.1.0+cu118
+transformers==4.51.3
+numpy==1.26.1
+cn2an==0.5.23
+kaldiio==2.18.0
+kaldi_native_fbank==1.15
+sentencepiece==0.1.99
+soundfile==0.12.1
+textgrid