# FireRedASR / app.py
# Commit 2fa2902 by FireRedTeam: "fix qwen"
import sys
import gradio as gr
import spaces
from huggingface_hub import snapshot_download
sys.path.append("./fireredasr")
from fireredasr.models.fireredasr import FireRedAsr
# Model singletons: both start empty, are filled in by init_model(), and are
# read by the Gradio inference callbacks.
asr_model_aed, asr_model_llm = None, None
def init_model(model_dir_aed, model_dir_llm):
    """Populate the module-level AED and LLM model globals.

    A model is loaded only while its global is still ``None``. Once both are
    set, further calls leave them untouched. The AED model is handled before
    the LLM model.

    Args:
        model_dir_aed: Directory handed to ``FireRedAsr.from_pretrained("aed", ...)``.
        model_dir_llm: Directory handed to ``FireRedAsr.from_pretrained("llm", ...)``.
    """
    global asr_model_aed, asr_model_llm
    asr_model_aed = (
        asr_model_aed
        if asr_model_aed is not None
        else FireRedAsr.from_pretrained("aed", model_dir_aed)
    )
    asr_model_llm = (
        asr_model_llm
        if asr_model_llm is not None
        else FireRedAsr.from_pretrained("llm", model_dir_llm)
    )
@spaces.GPU(duration=20)
def asr_inference(audio_file):
    """Transcribe one uploaded file with the FireRedASR-AED-L model.

    Args:
        audio_file: File path coming from the Gradio ``Audio`` input
            (configured with ``type="filepath"``). It may be ``None`` or empty
            when nothing has been uploaded.

    Returns:
        The recognized text of the single utterance. If ``audio_file`` is
        falsy, returns a prompt asking the user to upload a file instead.
    """
    if not audio_file:
        return "Please upload a wav file"
    # transcribe() takes parallel lists of utterance ids and wav paths, so the
    # single upload is wrapped as a one-element batch.
    batch_uttid = ["demo"]
    batch_wav_path = [audio_file]
    results = asr_model_aed.transcribe(
        batch_uttid,
        batch_wav_path,
        {
            "use_gpu": True,
            "beam_size": 3,
            "nbest": 1,
            "decode_max_len": 0,
            "softmax_smoothing": 1.25,
            "aed_length_penalty": 0.6,
            "eos_penalty": 1.0,
            # decode_min_len, repetition_penalty, llm_length_penalty and
            # temperature are passed only on the LLM path
            # (asr_inference_llm), not to the AED model.
        },
    )
    return results[0]["text"]
@spaces.GPU(duration=30)
def asr_inference_llm(audio_file):
    """Recognize speech in one uploaded file using FireRedASR-LLM-L.

    Args:
        audio_file: Path supplied by the Gradio audio input. It is falsy when
            no file has been uploaded.

    Returns:
        The text of the first transcription result. If there is no input,
        returns the upload prompt.
    """
    if audio_file:
        decode_config = {
            "use_gpu": True,
            "beam_size": 3,
            "nbest": 1,
            "decode_max_len": 0,
            "decode_min_len": 0,
            "repetition_penalty": 3.0,
            "llm_length_penalty": 1.0,
            "temperature": 1.0
        }
        # One-item batch: a fixed utterance id paired with the uploaded path.
        first_result = asr_model_llm.transcribe(["demo"], [audio_file], decode_config)[0]
        return first_result["text"]
    return "Please upload a wav file"
# Demo UI. A single audio input feeds two recognizers (AED and LLM). Each
# recognizer has its own button and its own read-only result textbox.
with gr.Blocks(title="FireRedASR") as demo:
    gr.HTML(
        "<h1 style='text-align: center'>FireRedASR Demo</h1>"
    )
    gr.Markdown("Upload an audio file (wav) to get speech-to-text results.")
    with gr.Row():
        with gr.Column():
            # Upload only. The component also accepts "microphone" in
            # `sources` if in-browser recording is wanted. type="filepath"
            # hands the callbacks a path, which they use as the wav path.
            audio_file = gr.Audio(label="Upload wav file", sources=["upload"], type="filepath")
        with gr.Column():
            asr_button = gr.Button("Start Recognition (FireRedASR-AED-L)", variant="primary")
            text_output = gr.Textbox(label="Model Result (FireRedASR-AED-L)", interactive=False, lines=3, max_lines=12)
            asr_button_llm = gr.Button("Start Recognition (FireRedASR-LLM-L)", variant="primary")
            text_output_llm = gr.Textbox(label="Model Result (FireRedASR-LLM-L)", interactive=False, lines=3, max_lines=12)
    # Both buttons read the same audio input and write to their own textbox.
    asr_button.click(
        fn=asr_inference,
        inputs=[audio_file],
        outputs=[text_output]
    )
    asr_button_llm.click(
        fn=asr_inference_llm,
        inputs=[audio_file],
        outputs=[text_output_llm]
    )
if __name__ == "__main__":
    # Fetch each checkpoint, in order. The Qwen2-7B-Instruct weights go into
    # a subdirectory of the FireRedASR-LLM-L checkout.
    aed_dir = 'pretrained_models/FireRedASR-AED-L'
    llm_dir = 'pretrained_models/FireRedASR-LLM-L'
    checkpoints = (
        ('FireRedTeam/FireRedASR-AED-L', aed_dir),
        ('FireRedTeam/FireRedASR-LLM-L', llm_dir),
        ('Qwen/Qwen2-7B-Instruct', 'pretrained_models/FireRedASR-LLM-L/Qwen2-7B-Instruct'),
    )
    for repo_id, target_dir in checkpoints:
        snapshot_download(repo_id=repo_id, local_dir=target_dir)

    # Load both recognizers up front so the inference callbacks find them set.
    init_model(aed_dir, llm_dir)

    # Start the Gradio app with the request queue enabled.
    demo.queue()
    demo.launch()