# FireRedASR2S / fireredasr2s / fireredasr2s-cli
# FireRed Team
# Init
# 0ddb4a4 verified
#!/usr/bin/env python3
# Copyright 2026 Xiaohongshu. (Author: Kaituo Xu, Kai Huang, Yan Jia, Junjie Chen, Wenpeng Li)
import argparse
import glob
import json
import logging
import os
import soundfile as sf
from textgrid import IntervalTier, TextGrid
from fireredasr2s.fireredasr2 import FireRedAsr2Config
from fireredasr2s.fireredasr2system import (FireRedAsr2System,
FireRedAsr2SystemConfig)
from fireredasr2s.fireredlid import FireRedLidConfig
from fireredasr2s.fireredpunc import FireRedPuncConfig
from fireredasr2s.fireredvad import FireRedVadConfig
# Timestamped log lines carry module:lineno so messages trace back to source.
logging.basicConfig(level=logging.INFO,
                    format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
logger = logging.getLogger("fireredasr2s.asr_system")

# Command-line interface. Boolean switches follow the project convention of
# int 0/1 flags rather than argparse store_true.
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

# Input: one of wav_path / wav_paths / wav_scp / wav_dir is expected
# (resolved in priority order by get_wav_info()).
input_g = parser.add_argument_group("Input Options")
input_g.add_argument("--wav_path", type=str)  # a single wav file
input_g.add_argument("--wav_paths", type=str, nargs="*")  # explicit list of wav files
input_g.add_argument("--wav_dir", type=str)  # directory searched recursively for *.wav
input_g.add_argument("--wav_scp", type=str)  # Kaldi-style "uttid wav_path" list file
input_g.add_argument("--sort_wav_by_dur", type=int, default=0)  # NOTE(review): declared but unused in this file

# Output: result.jsonl always goes under --outdir; the rest are per-format toggles.
output_g = parser.add_argument_group("Output Options")
output_g.add_argument("--outdir", type=str, default="output")
output_g.add_argument("--write_textgrid", type=int, default=1)  # write <outdir>/asr_tg/<name>.TextGrid
output_g.add_argument("--write_srt", type=int, default=1)  # write <outdir>/asr_srt/<name>.srt
output_g.add_argument("--save_segment", type=int, default=0)  # save VAD-cut wavs under <outdir>/vad_segment

# Enable/disable individual pipeline stages.
module_g = parser.add_argument_group("Module Switches")
module_g.add_argument('--enable_vad', type=int, default=1, choices=[0, 1])
module_g.add_argument('--enable_lid', type=int, default=1, choices=[0, 1])
module_g.add_argument('--enable_punc', type=int, default=1, choices=[0, 1])

asr_g = parser.add_argument_group("ASR Options")
asr_g.add_argument('--asr_type', type=str, default="aed", choices=["aed", "llm"])
asr_g.add_argument('--asr_model_dir', type=str, default="pretrained_models/FireRedASR2-AED")
asr_g.add_argument('--asr_use_gpu', type=int, default=1)
asr_g.add_argument('--asr_use_half', type=int, default=0)
asr_g.add_argument("--asr_batch_size", type=int, default=1)
# FireRedASR-AED
# Beam-search decoding options forwarded to FireRedAsr2Config.
asr_g.add_argument("--beam_size", type=int, default=3)
asr_g.add_argument("--decode_max_len", type=int, default=0)  # NOTE(review): 0 presumably means "no limit" -- confirm in decoder
asr_g.add_argument("--nbest", type=int, default=1)
asr_g.add_argument("--softmax_smoothing", type=float, default=1.25)
asr_g.add_argument("--aed_length_penalty", type=float, default=0.6)
asr_g.add_argument("--eos_penalty", type=float, default=1.0)
asr_g.add_argument("--return_timestamp", type=int, default=1)
# FireRedASR-AED External LM
# Empty elm_dir / zero weight leave the external LM disabled.
asr_g.add_argument("--elm_dir", type=str, default="")
asr_g.add_argument("--elm_weight", type=float, default=0.0)

vad_g = parser.add_argument_group("VAD Options")
vad_g.add_argument('--vad_model_dir', type=str, default="pretrained_models/FireRedVAD/VAD")
vad_g.add_argument('--vad_use_gpu', type=int, default=1)
# Non-streaming VAD
# Segmentation thresholds, all in VAD frame units; forwarded to FireRedVadConfig.
vad_g.add_argument("--vad_chunk_max_frame", type=int, default=30000)
vad_g.add_argument("--smooth_window_size", type=int, default=5)
vad_g.add_argument("--speech_threshold", type=float, default=0.2)
vad_g.add_argument("--min_speech_frame", type=int, default=20)
vad_g.add_argument("--max_speech_frame", type=int, default=1000)
vad_g.add_argument("--min_silence_frame", type=int, default=10)
vad_g.add_argument("--merge_silence_frame", type=int, default=50)
vad_g.add_argument("--extend_speech_frame", type=int, default=10)

lid_g = parser.add_argument_group("LID Options")
lid_g.add_argument('--lid_model_dir', type=str, default="pretrained_models/FireRedLID")
lid_g.add_argument('--lid_use_gpu', type=int, default=1)

punc_g = parser.add_argument_group("Punc Options")
punc_g.add_argument('--punc_model_dir', type=str, default="pretrained_models/FireRedPunc")
punc_g.add_argument('--punc_use_gpu', type=int, default=1)
punc_g.add_argument("--punc_batch_size", type=int, default=1)
punc_g.add_argument('--punc_with_timestamp', type=int, default=1)
punc_g.add_argument('--punc_sentence_max_length', type=int, default=-1)  # NOTE(review): -1 presumably means unlimited -- confirm
def main(args):
    """Run the full pipeline (VAD -> LID -> ASR -> Punc) over every input wav.

    Per-utterance results are appended as JSON lines to <outdir>/result.jsonl;
    optional TextGrid / SRT files and VAD segment wavs are written alongside.

    Args:
        args: parsed command-line arguments (see the parser defined above).
    """
    wavs = get_wav_info(args)
    if args.outdir:
        os.makedirs(args.outdir, exist_ok=True)
    # Fix: open with explicit UTF-8 (the jsonl is written with ensure_ascii=False,
    # so the platform default encoding could fail) and guarantee the handle is
    # closed via try/finally even when the pipeline raises mid-run.
    fout = (open(os.path.join(args.outdir, "result.jsonl"), "w", encoding="utf-8")
            if args.outdir else None)
    try:
        # Build Models
        # VAD
        vad_config = FireRedVadConfig(
            args.vad_use_gpu,
            args.smooth_window_size,
            args.speech_threshold,
            args.min_speech_frame,
            args.max_speech_frame,
            args.min_silence_frame,
            args.merge_silence_frame,
            args.extend_speech_frame,
            args.vad_chunk_max_frame
        )
        # LID
        lid_config = FireRedLidConfig(args.lid_use_gpu)
        # ASR
        asr_config = FireRedAsr2Config(
            args.asr_use_gpu,
            args.asr_use_half,
            args.beam_size,
            args.nbest,
            args.decode_max_len,
            args.softmax_smoothing,
            args.aed_length_penalty,
            args.eos_penalty,
            args.return_timestamp,
            # NOTE(review): positional constants for decoding options the CLI
            # does not expose -- confirm meaning against FireRedAsr2Config.
            0, 1.0, 0.0, 1.0,
            args.elm_dir,
            args.elm_weight
        )
        # Punc
        punc_config = FireRedPuncConfig(
            args.punc_use_gpu,
            args.punc_sentence_max_length
        )
        asr_system_config = FireRedAsr2SystemConfig(
            args.vad_model_dir, args.lid_model_dir,
            args.asr_type, args.asr_model_dir, args.punc_model_dir,
            vad_config, lid_config, asr_config, punc_config,
            args.asr_batch_size, args.punc_batch_size,
            args.enable_vad, args.enable_lid, args.enable_punc
        )
        asr_system = FireRedAsr2System(asr_system_config)

        for uttid, wav_path in wavs:
            logger.info("")  # blank log line visually separates utterances
            result = asr_system.process(wav_path, uttid)
            logger.info(f"FINAL: {result}")
            if fout:
                fout.write(f"{json.dumps(result, ensure_ascii=False)}\n")
                fout.flush()  # flush per utterance so partial results survive a crash
            name = os.path.basename(wav_path).replace(".wav", "")
            if args.write_textgrid:
                tg_dir = os.path.join(args.outdir, "asr_tg")
                write_textgrid(tg_dir, name, result["dur_s"], result["sentences"], result["words"])
            if args.write_srt:
                srt_dir = os.path.join(args.outdir, "asr_srt")
                write_srt(srt_dir, name, result["sentences"])
            if args.save_segment:
                save_segment_dir = os.path.join(args.outdir, "vad_segment")
                split_and_save_segment(wav_path, result["vad_segments_ms"], save_segment_dir)
    finally:
        if fout:
            fout.close()
    logger.info("All Done")
def get_wav_info(args):
    """Resolve the CLI input options into a list of utterances.

    Exactly one input source is used, checked in priority order:
    --wav_path, --wav_paths, --wav_scp, --wav_dir.

    Returns:
        wavs: list of (uttid, wav_path)

    Raises:
        ValueError: if no input option was provided.
    """
    def base(p):
        # uttid = file name with the ".wav" suffix removed
        return os.path.basename(p).replace(".wav", "")

    if args.wav_path:
        wavs = [(base(args.wav_path), args.wav_path)]
    elif args.wav_paths and len(args.wav_paths) >= 1:
        wavs = [(base(p), p) for p in sorted(args.wav_paths)]
    elif args.wav_scp:
        # Kaldi-style scp: "<uttid> <wav_path>" per line.
        # Fixes: close the file deterministically, skip blank lines, and use
        # maxsplit=1 so wav paths containing spaces stay intact.
        with open(args.wav_scp) as f:
            wavs = [line.strip().split(maxsplit=1) for line in f if line.strip()]
    elif args.wav_dir:
        found = glob.glob(f"{args.wav_dir}/**/*.wav", recursive=True)
        wavs = [(base(p), p) for p in sorted(found)]
    else:
        raise ValueError("Please provide valid wav info")
    logger.info(f"#wavs={len(wavs)}")
    return wavs
def write_textgrid(tg_dir, name, wav_dur, sentences, words=None):
    """Write sentence (and optional word) alignments as a Praat TextGrid.

    Args:
        tg_dir: output directory (created if missing).
        name: utterance name; file is written as <tg_dir>/<name>.TextGrid.
        wav_dur: total wav duration in seconds (tier maxTime).
        sentences: dicts with "start_ms", "end_ms", "text", "asr_confidence".
        words: optional dicts with "start_ms", "end_ms", "text".
    """
    os.makedirs(tg_dir, exist_ok=True)
    textgrid_file = os.path.join(tg_dir, name + ".TextGrid")
    logger.info(f"Write {textgrid_file}")
    textgrid = TextGrid(maxTime=wav_dur)
    # Sentence marks carry the ASR confidence on a second line.
    textgrid.append(_build_tier(
        "sentence", sentences, wav_dur,
        lambda s: f"{s['text']}\n{s['asr_confidence']}", "sent"))
    if words:
        textgrid.append(_build_tier(
            "token", words, wav_dur, lambda w: w["text"], "word"))
    textgrid.write(textgrid_file)


def _build_tier(tier_name, items, wav_dur, make_mark, tag):
    """Build an IntervalTier from timed items.

    Zero-length intervals are skipped (logged), and interval bounds are
    clamped to [0, wav_dur]. `make_mark` renders the item's interval label;
    `tag` is the short label used in skip-log messages.
    """
    tier = IntervalTier(name=tier_name, maxTime=wav_dur)
    for item in items:
        start_s = item["start_ms"] / 1000.0
        end_s = item["end_ms"] / 1000.0
        text = item["text"]
        if start_s == end_s:
            logger.info(f"({tag}) Write TG, skip start=end {start_s} {text}")
            continue
        tier.add(minTime=max(start_s, 0), maxTime=min(end_s, wav_dur),
                 mark=make_mark(item))
    return tier
def write_srt(srt_dir, name, sentences):
    """Write sentences as a SubRip (.srt) subtitle file.

    Sentences with empty text are skipped and the remaining entries are
    renumbered from 1. Fix: the original wrote the blank separator after each
    entry with `if i != len(sentences)`, which leaves a stray trailing blank
    line after the last entry whenever any sentence was skipped; separators
    are now written only *between* entries.
    """
    def _ms2srt_time(ms):
        # SRT timestamp format: HH:MM:SS,mmm
        total_s, ms = divmod(ms, 1000)
        h, rem = divmod(total_s, 3600)
        m, s = divmod(rem, 60)
        return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"

    os.makedirs(srt_dir, exist_ok=True)
    srt_file = os.path.join(srt_dir, name + ".srt")
    logger.info(f"Write {srt_file}")
    entries = [s for s in sentences if s["text"].strip() != ""]
    with open(srt_file, "w") as fout:
        for i, sentence in enumerate(entries, 1):
            if i > 1:
                fout.write("\n")  # blank line separates entries; none after the last
            fout.write(f"{i}\n")
            start = _ms2srt_time(sentence["start_ms"])
            end = _ms2srt_time(sentence["end_ms"])
            fout.write(f"{start} --> {end}\n")
            fout.write(f"{sentence['text']}\n")
def split_and_save_segment(wav_path, timestamps_ms, save_segment_dir):
    """Cut wav_path at the given (start_ms, end_ms) spans and save each piece.

    Segments are written as <save_segment_dir>/<uttid>_<i>_<start>_<end>.wav
    at the source sample rate, int16 PCM.

    Fixes: uttid is loop-invariant and now computed once; os.path.basename
    replaces splitting on "/" so Windows paths work too.
    """
    logger.info("Split & save segment")
    os.makedirs(save_segment_dir, exist_ok=True)
    wav_np, sample_rate = sf.read(wav_path, dtype="int16")
    uttid = os.path.basename(wav_path).replace(".wav", "")
    for i, (start_ms, end_ms) in enumerate(timestamps_ms):
        seg_id = f"{uttid}_{i}_{start_ms}_{end_ms}"
        seg_path = f"{save_segment_dir}/{seg_id}.wav"
        # Convert milliseconds to sample indices at the file's own rate.
        start = int(start_ms / 1000 * sample_rate)
        end = int(end_ms / 1000 * sample_rate)
        sf.write(seg_path, wav_np[start:end], samplerate=sample_rate)
def cli_main():
    """CLI entry point: parse arguments, log them, and run the pipeline."""
    parsed = parser.parse_args()
    logger.info(parsed)
    main(parsed)


if __name__ == "__main__":
    cli_main()