Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,158 Bytes
784e76a 64bc319 162974e 784e76a 162974e f526d77 64bc319 66d962d 162974e ff49c86 162974e 66d962d 162974e 66d962d 162974e 66d962d 162974e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
import sys
import gradio as gr
from huggingface_hub import snapshot_download
sys.path.append("./fireredasr")
from fireredasr.models.fireredasr import FireRedAsr
# Recognizer shared by the whole app; stays None until init_model() has run.
asr_model_aed = None


def init_model(model_dir_aed):
    """Load the AED recognizer from ``model_dir_aed`` the first time it is called.

    Once the module-level ``asr_model_aed`` is set, further calls do nothing.
    """
    global asr_model_aed
    if asr_model_aed is not None:
        return
    asr_model_aed = FireRedAsr.from_pretrained("aed", model_dir_aed)
def asr_inference(audio_file):
    """Transcribe a single audio file with the module-level AED model.

    Args:
        audio_file: Path of the audio file to recognize. A falsy value
            (``None`` or ``""``) means no audio was provided.

    Returns:
        The recognized text, the prompt ``"Please upload a wav file"`` when
        no audio was given, or ``""`` when the model returns no result.

    Raises:
        RuntimeError: If the model has not been loaded via ``init_model``.
    """
    if not audio_file:
        return "Please upload a wav file"
    if asr_model_aed is None:
        # Fail with an actionable message instead of an AttributeError on None.
        raise RuntimeError("ASR model is not loaded; call init_model() first")

    decode_options = {
        "use_gpu": False,
        "beam_size": 3,
        "nbest": 1,
        "decode_max_len": 0,
        "softmax_smoothing": 1.25,
        "aed_length_penalty": 0.6,
        "eos_penalty": 1.0,
    }
    # One utterance per request: a fixed id paired with the given path.
    results = asr_model_aed.transcribe(["demo"], [audio_file], decode_options)

    # NOTE(review): upstream FireRedAsr.transcribe appears to return a list
    # with one result dict per utterance, so indexing it with "text" directly
    # would raise TypeError. Unwrap the first entry, but still accept a bare
    # mapping in case the vendored copy differs -- confirm against ./fireredasr.
    if isinstance(results, (list, tuple)):
        if not results:
            return ""
        results = results[0]
    return results["text"]
# UI definition: a title, a short instruction, then one row with two columns
# (audio input + trigger button in the first, transcript box in the second).
with gr.Blocks(title="FireRedASR") as demo:
    gr.HTML("<h1 style='text-align: center'>FireRedASR Demo</h1>")
    gr.Markdown("Upload an audio file (wav) to get speech-to-text results.")

    with gr.Row():
        with gr.Column():
            # type="filepath" hands the callback a path string rather than raw samples.
            audio_file = gr.Audio(
                label="Upload Audio",
                sources=["upload", "microphone"],
                type="filepath",
            )
            asr_button = gr.Button("Start Recognition", variant="primary")
        with gr.Column():
            text_output = gr.Textbox(
                label="Model Result",
                interactive=False,
                lines=6,
                max_lines=12,
            )

    # Pressing the button runs recognition on the selected audio and shows the text.
    asr_button.click(fn=asr_inference, inputs=[audio_file], outputs=[text_output])
if __name__ == "__main__":
    # Fetch the pretrained AED checkpoint from the Hub into a local folder.
    model_repo = "FireRedTeam/FireRedASR-AED-L"
    local_dir = "pretrained_models/FireRedASR-AED-L"
    snapshot_download(repo_id=model_repo, local_dir=local_dir)

    # Load the recognizer before the UI is started.
    init_model(local_dir)

    # Turn on request queuing, then start the Gradio app.
    demo.queue()
    demo.launch()
|