FireRedASR2S / app.py
FireRedTeam's picture
Upload app.py
100f398 verified
import sys
import gradio as gr
import spaces
from huggingface_hub import snapshot_download
sys.path.append("./fireredasr2s")
from fireredasr2s import FireRedAsr2System, FireRedAsr2SystemConfig
from fireredasr2s.fireredasr2.asr import FireRedAsr2, FireRedAsr2Config
from fireredasr2s.fireredvad.vad import FireRedVad, FireRedVadConfig
from fireredasr2s.fireredvad.aed import FireRedAed, FireRedAedConfig
from fireredasr2s.fireredvad.stream_vad import FireRedStreamVad, FireRedStreamVadConfig
# Model handles shared by the inference handlers below. All start as None
# and are populated (only while still None) by init_model().
asr_system = asr_model_aed = asr_model_llm = None
vad_model = aed_model = stream_vad_model = None
def init_model(model_dir_aed, model_dir_llm, vad_model_dir="pretrained_models/FireRedVAD"):
    """Load every model used by the demo into its module-level global.

    A model is constructed only while its global is still ``None``, so a
    repeated call leaves already-loaded models untouched.

    Args:
        model_dir_aed: Directory holding the FireRedASR2-AED checkpoint.
        model_dir_llm: Directory holding the FireRedASR2-LLM checkpoint.
        vad_model_dir: Directory of the FireRedVAD snapshot. The offline VAD,
            audio-event (AED) and streaming-VAD checkpoints are loaded from its
            ``VAD``, ``AED`` and ``Stream-VAD`` subdirectories. Defaults to the
            location that was previously hard-coded here.
    """
    global asr_system, asr_model_aed, asr_model_llm
    global vad_model, aed_model, stream_vad_model

    if asr_system is None:
        # End-to-end pipeline built from its default config. It is constructed
        # independently of the standalone models loaded below.
        # NOTE(review): presumably the default config loads its own model copies
        # from fixed paths under pretrained_models/ -- confirm.
        asr_system = FireRedAsr2System(FireRedAsr2SystemConfig())

    if asr_model_aed is None:
        asr_config_aed = FireRedAsr2Config(
            use_gpu=True,
            use_half=False,
            beam_size=3,
            nbest=1,
            decode_max_len=0,
            softmax_smoothing=1.25,
            aed_length_penalty=0.6,
            eos_penalty=1.0,
            return_timestamp=True,
        )
        asr_model_aed = FireRedAsr2.from_pretrained("aed", model_dir_aed, asr_config_aed)

    if asr_model_llm is None:
        asr_config_llm = FireRedAsr2Config(
            use_gpu=True,
            decode_min_len=0,
            repetition_penalty=3.0,
            llm_length_penalty=1.0,
            temperature=1.0,
        )
        asr_model_llm = FireRedAsr2.from_pretrained("llm", model_dir_llm, asr_config_llm)

    # The three detectors below are configured with use_gpu=False, unlike the
    # ASR models above.
    if vad_model is None:
        vad_config = FireRedVadConfig(
            use_gpu=False,
            smooth_window_size=5,
            speech_threshold=0.4,
            min_speech_frame=20,
            max_speech_frame=2000,
            min_silence_frame=20,
            merge_silence_frame=0,
            extend_speech_frame=0,
            chunk_max_frame=30000,
        )
        vad_model = FireRedVad.from_pretrained(f"{vad_model_dir}/VAD", vad_config)

    if aed_model is None:
        aed_config = FireRedAedConfig(
            use_gpu=False,
            smooth_window_size=5,
            speech_threshold=0.4,
            singing_threshold=0.5,
            music_threshold=0.5,
            min_event_frame=20,
            max_event_frame=2000,
            min_silence_frame=20,
            merge_silence_frame=0,
            extend_speech_frame=0,
            chunk_max_frame=30000,
        )
        aed_model = FireRedAed.from_pretrained(f"{vad_model_dir}/AED", aed_config)

    if stream_vad_model is None:
        # Distinct name so the streaming config is never confused with the
        # offline `vad_config` (a different config type).
        stream_vad_config = FireRedStreamVadConfig(
            use_gpu=False,
            smooth_window_size=5,
            speech_threshold=0.4,
            pad_start_frame=5,
            min_speech_frame=8,
            max_speech_frame=2000,
            min_silence_frame=20,
            chunk_max_frame=30000,
        )
        stream_vad_model = FireRedStreamVad.from_pretrained(
            f"{vad_model_dir}/Stream-VAD", stream_vad_config
        )
@spaces.GPU(duration=20)
def asr_sys_inference(audio_file):
    """Run the full FireRedASR2S pipeline on one file and format a text report.

    Returns a four-line summary: transcript, sentences, VAD segments (ms)
    and duration (s). If no file was provided, returns a prompt string
    instead.
    """
    if not audio_file:
        return "Please upload a wav file"
    out = asr_system.process(audio_file)
    report_lines = (
        f"ASR: {out['text']}",
        f"Sentences: {out['sentences']}",
        f"VAD(ms): {out['vad_segments_ms']}",
        f"Duration: {out['dur_s']}s",
    )
    return "\n".join(report_lines)
@spaces.GPU(duration=20)
def asr_inference(audio_file):
    """Transcribe one uploaded file with the standalone FireRedASR2-AED model.

    Returns the recognized text of the single utterance. If no file was
    provided, returns a prompt string instead.
    """
    if audio_file:
        # transcribe() takes parallel lists of utterance ids and wav paths;
        # submit a batch of one under the fixed id "demo".
        outputs = asr_model_aed.transcribe(["demo"], [audio_file])
        return outputs[0]["text"]
    return "Please upload a wav file"
@spaces.GPU(duration=30)
def asr_inference_llm(audio_file):
    """Transcribe one uploaded file with the standalone FireRedASR2-LLM model.

    Returns the recognized text of the single utterance. If no file was
    provided, returns a prompt string instead.
    """
    if audio_file:
        # transcribe() takes parallel lists of utterance ids and wav paths;
        # submit a batch of one under the fixed id "demo".
        outputs = asr_model_llm.transcribe(["demo"], [audio_file])
        return outputs[0]["text"]
    return "Please upload a wav file"
@spaces.GPU(duration=20)
def vad_inference(audio_file):
    """Run offline VAD, streaming VAD and audio-event detection on one file.

    Returns a multi-line text report combining the three detectors' outputs.
    If no file was provided, returns a prompt string instead.
    """
    if not audio_file:
        return "Please upload a wav file"

    # Each detector's result gets its own name (the dicts have different
    # keys), and return values that are never displayed are discarded.
    # Lines are appended right after each detection, preserving the original
    # order of model calls and key lookups.
    parts = []

    # Offline VAD. The second return value (named `probs` originally) is unused.
    vad_result, _ = vad_model.detect(audio_file)
    parts.append(f'Duration: {vad_result["dur"]}s')
    parts.append(f'VAD: {vad_result["timestamps"]}')

    # Streaming VAD. Only the second return value is displayed.
    _, stream_result = stream_vad_model.detect_full(audio_file)
    parts.append(f'Stream VAD: {stream_result["timestamps"]}')

    # Audio event detection. The second return value is unused.
    aed_result, _ = aed_model.detect(audio_file)
    parts.append(f'Audio Event: {aed_result["event2ratio"]}\n {aed_result["event2timestamps"]}')

    return "\n".join(parts)
# Demo UI: one shared audio input, an offline-VAD panel on the left and the
# three ASR variants on the right. Each button runs its handler on the same
# uploaded file path.
with gr.Blocks(title="FireRedASR2S") as demo:
    gr.HTML(
        "<h1 style='text-align: center'>FireRedASR2S Demo</h1>"
    )
    gr.Markdown("Upload an audio file (wav) to get speech-to-text results.")
    with gr.Row():
        with gr.Column():
            # type="filepath" makes handlers receive a path string (falsy when empty).
            audio_file = gr.Audio(label="Upload wav file", sources=["upload"], type="filepath")
            vad_button = gr.Button("Start Recognition (FireRedVAD)", variant="primary")
            vad_output = gr.Textbox(label="Model Result (FireRedVAD)", interactive=False, lines=3, max_lines=12)
        with gr.Column():
            asr_sys_button = gr.Button("Start Recognition (FireRedASR2S)", variant="primary")
            text_sys_output = gr.Textbox(label="Model Result (FireRedASR2S)", interactive=False, lines=3, max_lines=12)
            asr_button = gr.Button("Start Recognition (FireRedASR2-AED-L)", variant="primary")
            text_output = gr.Textbox(label="Model Result (FireRedASR2-AED-L)", interactive=False, lines=3, max_lines=12)
            asr_button_llm = gr.Button("Start Recognition (FireRedASR2-LLM-L)", variant="primary")
            text_output_llm = gr.Textbox(label="Model Result (FireRedASR2-LLM-L)", interactive=False, lines=3, max_lines=12)

    # Event wiring: every handler takes the shared audio input and writes to
    # the textbox next to its button.
    vad_button.click(
        fn=vad_inference,
        inputs=[audio_file],
        outputs=[vad_output]
    )
    asr_sys_button.click(
        fn=asr_sys_inference,
        inputs=[audio_file],
        outputs=[text_sys_output]
    )
    asr_button.click(
        fn=asr_inference,
        inputs=[audio_file],
        outputs=[text_output]
    )
    asr_button_llm.click(
        fn=asr_inference_llm,
        inputs=[audio_file],
        outputs=[text_output_llm]
    )
if __name__ == "__main__":
    # Fetch every checkpoint from the Hugging Face Hub into
    # pretrained_models/<repo name>, in the same order as before.
    # NOTE(review): FireRedLID and FireRedPunc are not passed to init_model;
    # presumably FireRedAsr2System's default config reads them -- confirm.
    aed_dir = 'pretrained_models/FireRedASR2-AED'
    llm_dir = 'pretrained_models/FireRedASR2-LLM'
    repo_names = ('FireRedASR2-AED', 'FireRedASR2-LLM', 'FireRedVAD', 'FireRedLID', 'FireRedPunc')
    for repo_name in repo_names:
        snapshot_download(repo_id=f'FireRedTeam/{repo_name}', local_dir=f'pretrained_models/{repo_name}')

    # Load all models before the UI starts accepting requests.
    init_model(aed_dir, llm_dir)

    # Enable request queuing, then serve the UI.
    demo.queue()
    demo.launch()