# FireRedASR / app.py — Hugging Face Space demo
# FireRedTeam — commit 9e835c6 ("add gpu"), 2.18 kB
import sys

import gradio as gr
import spaces  # required by the @spaces.GPU decorator below (HF Spaces GPU allocation)
from huggingface_hub import snapshot_download

# Local package lives under ./fireredasr; make it importable before the local import.
sys.path.append("./fireredasr")
from fireredasr.models.fireredasr import FireRedAsr
asr_model_aed = None
def init_model(model_dir_aed):
global asr_model_aed
if asr_model_aed is None:
asr_model_aed = FireRedAsr.from_pretrained("aed", model_dir_aed)
@spaces.GPU(duration=20)
def asr_inference(audio_file):
    """Transcribe *audio_file* (a wav filepath from Gradio) with the AED model.

    Returns the recognized text, or a user-facing message when no file was
    supplied or the model has not been initialized yet.
    """
    if not audio_file:
        return "Please upload a wav file"
    if asr_model_aed is None:
        # Guard: init_model() has not run yet — without this, the call below
        # would raise AttributeError on None.transcribe.
        return "Model is not initialized yet, please try again shortly"
    batch_uttid = ["demo"]
    batch_wav_path = [audio_file]
    # Beam-search decoding hyper-parameters for the AED model.
    results = asr_model_aed.transcribe(
        batch_uttid,
        batch_wav_path,
        {
            "use_gpu": True,
            "beam_size": 3,
            "nbest": 1,
            "decode_max_len": 0,  # 0 = no explicit cap on decoded length
            "softmax_smoothing": 1.25,
            "aed_length_penalty": 0.6,
            "eos_penalty": 1.0,
        },
    )
    # NOTE(review): assumes transcribe() returns a mapping with a "text" key —
    # confirm against FireRedAsr.transcribe in the fireredasr package.
    text_output = results["text"]
    return text_output
# ---- Gradio UI definition -------------------------------------------------
with gr.Blocks(title="FireRedASR") as demo:
    gr.HTML("<h1 style='text-align: center'>FireRedASR Demo</h1>")
    gr.Markdown("Upload an audio file (wav) to get speech-to-text results.")
    with gr.Row():
        # Left column: audio input plus the trigger button.
        with gr.Column():
            audio_file = gr.Audio(
                label="Upload Audio",
                sources=["upload", "microphone"],
                type="filepath",
            )
            asr_button = gr.Button("Start Recognition", variant="primary")
        # Right column: read-only transcription output.
        with gr.Column():
            text_output = gr.Textbox(
                label="Model Result",
                interactive=False,
                lines=6,
                max_lines=12,
            )
    # Wire the button to the inference function.
    asr_button.click(fn=asr_inference, inputs=[audio_file], outputs=[text_output])
if __name__ == "__main__":
    # Fetch the pretrained AED checkpoint from the Hugging Face Hub.
    local_dir = 'pretrained_models/FireRedASR-AED-L'
    snapshot_download(repo_id='FireRedTeam/FireRedASR-AED-L', local_dir=local_dir)
    # Load the model once at startup so the first request is not cold.
    init_model(local_dir)
    # Enable request queuing, then serve the UI.
    demo.queue()
    demo.launch()