import sys

import gradio as gr
import spaces
from huggingface_hub import snapshot_download

sys.path.append("./fireredasr")
from fireredasr.models.fireredasr import FireRedAsr

# Models are loaded once at startup and reused across requests.
asr_model_aed = None
asr_model_llm = None


def init_model(model_dir_aed, model_dir_llm):
    global asr_model_aed
    global asr_model_llm
    if asr_model_aed is None:
        asr_model_aed = FireRedAsr.from_pretrained("aed", model_dir_aed)
    if asr_model_llm is None:
        asr_model_llm = FireRedAsr.from_pretrained("llm", model_dir_llm)


@spaces.GPU(duration=20)
def asr_inference(audio_file):
    """Transcribe an uploaded wav file with FireRedASR-AED-L."""
    if not audio_file:
        return "Please upload a wav file"
    batch_uttid = ["demo"]
    batch_wav_path = [audio_file]
    results = asr_model_aed.transcribe(
        batch_uttid,
        batch_wav_path,
        {
            "use_gpu": True,
            "beam_size": 3,
            "nbest": 1,
            "decode_max_len": 0,
            "softmax_smoothing": 1.25,
            "aed_length_penalty": 0.6,
            "eos_penalty": 1.0,
            # "decode_min_len": args.decode_min_len,
            # "repetition_penalty": args.repetition_penalty,
            # "llm_length_penalty": args.llm_length_penalty,
            # "temperature": args.temperature
        }
    )
    text_output = results[0]["text"]
    return text_output


@spaces.GPU(duration=30)
def asr_inference_llm(audio_file):
    """Transcribe an uploaded wav file with FireRedASR-LLM-L."""
    if not audio_file:
        return "Please upload a wav file"
    batch_uttid = ["demo"]
    batch_wav_path = [audio_file]
    results = asr_model_llm.transcribe(
        batch_uttid,
        batch_wav_path,
        {
            "use_gpu": True,
            "beam_size": 3,
            "nbest": 1,
            "decode_max_len": 0,
            "decode_min_len": 0,
            "repetition_penalty": 3.0,
            "llm_length_penalty": 1.0,
            "temperature": 1.0
        }
    )
    text_output = results[0]["text"]
    return text_output


with gr.Blocks(title="FireRedASR") as demo:
    gr.HTML("<h1 style='text-align: center;'>FireRedASR Demo</h1>")
" ) gr.Markdown("Upload an audio file (wav) to get speech-to-text results.") with gr.Row(): with gr.Column(): #audio_file = gr.Audio(label="Upload Audio", sources=["upload", "microphone"], type="filepath") audio_file = gr.Audio(label="Upload wav file", sources=["upload"], type="filepath") with gr.Column(): asr_button = gr.Button("Start Recognition (FireRedASR-AED-L)", variant="primary") text_output = gr.Textbox(label="Model Result (FireRedASR-AED-L)", interactive=False, lines=3, max_lines=12) asr_button_llm = gr.Button("Start Recognition (FireRedASR-LLM-L)", variant="primary") text_output_llm = gr.Textbox(label="Model Result (FireRedASR-LLM-L)", interactive=False, lines=3, max_lines=12) asr_button.click( fn=asr_inference, inputs=[audio_file], outputs=[text_output] ) asr_button_llm.click( fn=asr_inference_llm, inputs=[audio_file], outputs=[text_output_llm] ) if __name__ == "__main__": # Download model local_dir='pretrained_models/FireRedASR-AED-L' snapshot_download(repo_id='FireRedTeam/FireRedASR-AED-L', local_dir=local_dir) local_dir_llm='pretrained_models/FireRedASR-LLM-L' snapshot_download(repo_id='FireRedTeam/FireRedASR-LLM-L', local_dir=local_dir_llm) local_dir_qwen='pretrained_models/FireRedASR-LLM-L/Qwen2-7B-Instruct' snapshot_download(repo_id='Qwen/Qwen2-7B-Instruct', local_dir=local_dir_qwen) # Init model init_model(local_dir, local_dir_llm) # UI demo.queue() demo.launch()