Spaces: Running on Zero
| import sys | |
| import gradio as gr | |
| import spaces | |
| from huggingface_hub import snapshot_download | |
| sys.path.append("./fireredasr") | |
| from fireredasr.models.fireredasr import FireRedAsr | |
# Lazily-initialized module-level model handles; populated once by init_model().
asr_model_aed = None
asr_model_llm = None


def init_model(model_dir_aed, model_dir_llm):
    """Load both FireRedASR models into the module globals.

    Safe to call more than once: a model that is already loaded is kept
    as-is, so repeated calls are cheap no-ops.

    Args:
        model_dir_aed: Local directory holding the FireRedASR-AED-L weights.
        model_dir_llm: Local directory holding the FireRedASR-LLM-L weights.
    """
    global asr_model_aed, asr_model_llm
    if asr_model_aed is None:
        asr_model_aed = FireRedAsr.from_pretrained("aed", model_dir_aed)
    if asr_model_llm is None:
        asr_model_llm = FireRedAsr.from_pretrained("llm", model_dir_llm)
def asr_inference(audio_file):
    """Transcribe one wav file with the FireRedASR-AED-L model.

    Args:
        audio_file: Filesystem path to a wav file (Gradio ``filepath``
            value), or a falsy value when nothing was uploaded.

    Returns:
        The recognized text of the best hypothesis, or an instructional
        message when no file was supplied or the model is not loaded yet.
    """
    if not audio_file:
        return "Please upload a wav file"
    if asr_model_aed is None:
        # Fail with a readable message instead of an AttributeError when
        # init_model() has not run before the UI accepts requests.
        return "Model not initialized"
    results = asr_model_aed.transcribe(
        ["demo"],      # batch of one utterance id
        [audio_file],  # batch of one wav path
        {
            "use_gpu": True,
            "beam_size": 3,
            "nbest": 1,
            "decode_max_len": 0,  # 0 = no explicit cap on decode length
            "softmax_smoothing": 1.25,
            "aed_length_penalty": 0.6,
            "eos_penalty": 1.0,
        },
    )
    # transcribe() returns one result dict per utterance; batch size is 1.
    return results[0]["text"]
def asr_inference_llm(audio_file):
    """Run speech recognition on one wav file using the LLM-based model.

    Args:
        audio_file: Path to the uploaded wav file, or falsy when absent.

    Returns:
        Recognized text for the single utterance, or a prompt asking the
        user to upload a file.
    """
    if not audio_file:
        return "Please upload a wav file"
    decode_opts = {
        "use_gpu": True,
        "beam_size": 3,
        "nbest": 1,
        "decode_max_len": 0,
        "decode_min_len": 0,
        "repetition_penalty": 3.0,
        "llm_length_penalty": 1.0,
        "temperature": 1.0,
    }
    hyps = asr_model_llm.transcribe(["demo"], [audio_file], decode_opts)
    return hyps[0]["text"]
# Build the Gradio UI: one audio upload on the left, one button + output
# textbox per model on the right.
with gr.Blocks(title="FireRedASR") as demo:
    gr.HTML("<h1 style='text-align: center'>FireRedASR Demo</h1>")
    gr.Markdown("Upload an audio file (wav) to get speech-to-text results.")

    with gr.Row():
        with gr.Column():
            audio_file = gr.Audio(
                label="Upload wav file", sources=["upload"], type="filepath"
            )
        with gr.Column():
            asr_button = gr.Button(
                "Start Recognition (FireRedASR-AED-L)", variant="primary"
            )
            text_output = gr.Textbox(
                label="Model Result (FireRedASR-AED-L)",
                interactive=False,
                lines=3,
                max_lines=12,
            )
            asr_button_llm = gr.Button(
                "Start Recognition (FireRedASR-LLM-L)", variant="primary"
            )
            text_output_llm = gr.Textbox(
                label="Model Result (FireRedASR-LLM-L)",
                interactive=False,
                lines=3,
                max_lines=12,
            )

    # Wire each button to its model-specific inference function.
    asr_button.click(fn=asr_inference, inputs=[audio_file], outputs=[text_output])
    asr_button_llm.click(
        fn=asr_inference_llm, inputs=[audio_file], outputs=[text_output_llm]
    )
if __name__ == "__main__":
    # Fetch model weights from the Hugging Face Hub (no-op when cached).
    aed_dir = 'pretrained_models/FireRedASR-AED-L'
    llm_dir = 'pretrained_models/FireRedASR-LLM-L'
    qwen_dir = 'pretrained_models/FireRedASR-LLM-L/Qwen2-7B-Instruct'
    snapshot_download(repo_id='FireRedTeam/FireRedASR-AED-L', local_dir=aed_dir)
    snapshot_download(repo_id='FireRedTeam/FireRedASR-LLM-L', local_dir=llm_dir)
    snapshot_download(repo_id='Qwen/Qwen2-7B-Instruct', local_dir=qwen_dir)

    # Load both ASR models into the module-level globals before serving.
    init_model(aed_dir, llm_dir)

    # Serve the Gradio UI with request queuing enabled.
    demo.queue()
    demo.launch()