import sys

import gradio as gr
import spaces
from huggingface_hub import snapshot_download

sys.path.append("./fireredasr2s")
from fireredasr2s import FireRedAsr2System, FireRedAsr2SystemConfig
from fireredasr2s.fireredasr2.asr import FireRedAsr2, FireRedAsr2Config
from fireredasr2s.fireredvad.vad import FireRedVad, FireRedVadConfig
from fireredasr2s.fireredvad.aed import FireRedAed, FireRedAedConfig
from fireredasr2s.fireredvad.stream_vad import FireRedStreamVad, FireRedStreamVadConfig
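
# NOTE (assumption): fireredasr2s is used as a local checkout placed at
# ./fireredasr2s next to this app.py (hence the sys.path.append above),
# rather than as a pip-installed dependency.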

asr_system = None
asr_model_aed = None
asr_model_llm = None
vad_model = None
aed_model = None
stream_vad_model = None
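# Model handles live in module-level globals so init_model() is idempotent:
# each model is loaded at most once and then reused across requests.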


def init_model(model_dir_aed, model_dir_llm):
    global asr_system
    global asr_model_aed
    global asr_model_llm
    global vad_model
    global aed_model
    global stream_vad_model
    if asr_system is None:
        asr_system_config = FireRedAsr2SystemConfig()  # Use default config
        asr_system = FireRedAsr2System(asr_system_config)
    if asr_model_aed is None:
        asr_config_aed = FireRedAsr2Config(
            use_gpu=True,
            use_half=False,
            beam_size=3,
            nbest=1,
            decode_max_len=0,
            softmax_smoothing=1.25,
            aed_length_penalty=0.6,
            eos_penalty=1.0,
            return_timestamp=True,
        )
        asr_model_aed = FireRedAsr2.from_pretrained("aed", model_dir_aed, asr_config_aed)
    if asr_model_llm is None:
        asr_config_llm = FireRedAsr2Config(
            use_gpu=True,
            decode_min_len=0,
            repetition_penalty=3.0,
            llm_length_penalty=1.0,
            temperature=1.0,
        )
        asr_model_llm = FireRedAsr2.from_pretrained("llm", model_dir_llm, asr_config_llm)
    if vad_model is None:
        vad_config = FireRedVadConfig(
            use_gpu=False,
            smooth_window_size=5,
            speech_threshold=0.4,
            min_speech_frame=20,
            max_speech_frame=2000,
            min_silence_frame=20,
            merge_silence_frame=0,
            extend_speech_frame=0,
            chunk_max_frame=30000,
        )
        vad_model = FireRedVad.from_pretrained("pretrained_models/FireRedVAD/VAD", vad_config)
    if aed_model is None:
        aed_config = FireRedAedConfig(
            use_gpu=False,
            smooth_window_size=5,
            speech_threshold=0.4,
            singing_threshold=0.5,
            music_threshold=0.5,
            min_event_frame=20,
            max_event_frame=2000,
            min_silence_frame=20,
            merge_silence_frame=0,
            extend_speech_frame=0,
            chunk_max_frame=30000,
        )
        aed_model = FireRedAed.from_pretrained("pretrained_models/FireRedVAD/AED", aed_config)
    if stream_vad_model is None:
        stream_vad_config = FireRedStreamVadConfig(
            use_gpu=False,
            smooth_window_size=5,
            speech_threshold=0.4,
            pad_start_frame=5,
            min_speech_frame=8,
            max_speech_frame=2000,
            min_silence_frame=20,
            chunk_max_frame=30000,
        )
        stream_vad_model = FireRedStreamVad.from_pretrained(
            "pretrained_models/FireRedVAD/Stream-VAD", stream_vad_config
        )
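
# Note: the two ASR model directories are passed in as arguments, while the
# VAD/AED/Stream-VAD checkpoints load from fixed paths under
# pretrained_models/FireRedVAD (populated by snapshot_download in __main__).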


# The Space runs on ZeroGPU ("Running on Zero"), so handlers that touch the
# GPU-backed ASR models are decorated with spaces.GPU; without this the
# `import spaces` above would be unused and no GPU would be attached.
@spaces.GPU
def asr_sys_inference(audio_file):
    if not audio_file:
        return "Please upload a wav file"
    results = asr_system.process(audio_file)
    s = (
        f'ASR: {results["text"]}\n'
        f'Sentences: {results["sentences"]}\n'
        f'VAD(ms): {results["vad_segments_ms"]}\n'
        f'Duration: {results["dur_s"]}s'
    )
    return s


@spaces.GPU
def asr_inference(audio_file):
    if not audio_file:
        return "Please upload a wav file"
    # transcribe() takes parallel lists of utterance IDs and wav paths and
    # returns one result dict per utterance; here we run a batch of one.
    batch_uttid = ["demo"]
    batch_wav_path = [audio_file]
    results = asr_model_aed.transcribe(batch_uttid, batch_wav_path)
    text_output = results[0]["text"]
    return text_output
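

# asr_inference_llm mirrors asr_inference but decodes with the LLM-based
# model; both take a single wav path from the Gradio Audio input.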
@spaces.GPU
def asr_inference_llm(audio_file):
    if not audio_file:
        return "Please upload a wav file"
    batch_uttid = ["demo"]
    batch_wav_path = [audio_file]
    results = asr_model_llm.transcribe(batch_uttid, batch_wav_path)
    text_output = results[0]["text"]
    return text_output


def vad_inference(audio_file):
    """Run all three CPU detectors on one file: offline VAD, streaming VAD
    replayed over the full file, and audio event detection."""
    if not audio_file:
        return "Please upload a wav file"
    result, probs = vad_model.detect(audio_file)
    s = f'Duration: {result["dur"]}s'
    s += f'\nVAD: {result["timestamps"]}'
    frame_results, result = stream_vad_model.detect_full(audio_file)
    s += f'\nStream VAD: {result["timestamps"]}'
    result, probs = aed_model.detect(audio_file)
    s += f'\nAudio Event: {result["event2ratio"]}\n {result["event2timestamps"]}'
    return s
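

# --- Gradio UI: one shared audio input, four buttons, four result boxes ---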
with gr.Blocks(title="FireRedASR2S") as demo:
    gr.HTML("<h1 style='text-align: center'>FireRedASR2S Demo</h1>")
    gr.Markdown("Upload an audio file (wav) to get speech-to-text results.")
    with gr.Row():
        with gr.Column():
            # audio_file = gr.Audio(label="Upload Audio", sources=["upload", "microphone"], type="filepath")
            audio_file = gr.Audio(label="Upload wav file", sources=["upload"], type="filepath")
            vad_button = gr.Button("Start Recognition (FireRedVAD)", variant="primary")
            vad_output = gr.Textbox(label="Model Result (FireRedVAD)", interactive=False, lines=3, max_lines=12)
        with gr.Column():
            asr_sys_button = gr.Button("Start Recognition (FireRedASR2S)", variant="primary")
            text_sys_output = gr.Textbox(label="Model Result (FireRedASR2S)", interactive=False, lines=3, max_lines=12)
            asr_button = gr.Button("Start Recognition (FireRedASR2-AED-L)", variant="primary")
            text_output = gr.Textbox(label="Model Result (FireRedASR2-AED-L)", interactive=False, lines=3, max_lines=12)
            asr_button_llm = gr.Button("Start Recognition (FireRedASR2-LLM-L)", variant="primary")
            text_output_llm = gr.Textbox(label="Model Result (FireRedASR2-LLM-L)", interactive=False, lines=3, max_lines=12)

    vad_button.click(fn=vad_inference, inputs=[audio_file], outputs=[vad_output])
    asr_sys_button.click(fn=asr_sys_inference, inputs=[audio_file], outputs=[text_sys_output])
    asr_button.click(fn=asr_inference, inputs=[audio_file], outputs=[text_output])
    asr_button_llm.click(fn=asr_inference_llm, inputs=[audio_file], outputs=[text_output_llm])


if __name__ == "__main__":
    # Download model checkpoints; snapshot_download reuses files that are
    # already present and up to date, so restarts avoid re-downloading.
    local_dir = 'pretrained_models/FireRedASR2-AED'
    snapshot_download(repo_id='FireRedTeam/FireRedASR2-AED', local_dir=local_dir)
    local_dir_llm = 'pretrained_models/FireRedASR2-LLM'
    snapshot_download(repo_id='FireRedTeam/FireRedASR2-LLM', local_dir=local_dir_llm)
    for name in ['FireRedVAD', 'FireRedLID', 'FireRedPunc']:
        snapshot_download(repo_id=f'FireRedTeam/{name}', local_dir=f'pretrained_models/{name}')
    # Load all models once at startup
    init_model(local_dir, local_dir_llm)
    # UI
    demo.queue()
    demo.launch()
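
# Note: demo.queue() makes long-running recognition requests queue up instead
# of running concurrently; demo.launch() keeps Gradio defaults (on Spaces the
# server host/port are supplied by the platform).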