# app.py -- FireRedASR demo (Hugging Face Space, running on ZeroGPU)
import sys
import gradio as gr
import spaces
from huggingface_hub import snapshot_download
sys.path.append("./fireredasr")
from fireredasr.models.fireredasr import FireRedAsr
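# Model handles live in module-level globals so they are loaded once at
# startup and reused across Gradio requests.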
asr_model_aed = None
asr_model_llm = None
def init_model(model_dir_aed, model_dir_llm):
    global asr_model_aed
    global asr_model_llm
    if asr_model_aed is None:
        asr_model_aed = FireRedAsr.from_pretrained("aed", model_dir_aed)
    if asr_model_llm is None:
        asr_model_llm = FireRedAsr.from_pretrained("llm", model_dir_llm)
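# On ZeroGPU Spaces, @spaces.GPU attaches a GPU only while the decorated
# function runs; `duration` declares the expected runtime budget in seconds.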
@spaces.GPU(duration=20)
def asr_inference(audio_file):
    if not audio_file:
        return "Please upload a wav file"
    batch_uttid = ["demo"]
    batch_wav_path = [audio_file]
    results = asr_model_aed.transcribe(
        batch_uttid,
        batch_wav_path,
        {
            "use_gpu": True,
            "beam_size": 3,
            "nbest": 1,
            "decode_max_len": 0,
            "softmax_smoothing": 1.25,
            "aed_length_penalty": 0.6,
            "eos_penalty": 1.0,
            #"decode_min_len": args.decode_min_len,
            #"repetition_penalty": args.repetition_penalty,
            #"llm_length_penalty": args.llm_length_penalty,
            #"temperature": args.temperature
        }
    )
    text_output = results[0]["text"]
    return text_output
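# Same flow as above, but decoding with the LLM variant; it is given a
# larger GPU budget (30 s vs 20 s), presumably because LLM decoding is slower.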
@spaces.GPU(duration=30)
def asr_inference_llm(audio_file):
    if not audio_file:
        return "Please upload a wav file"
    batch_uttid = ["demo"]
    batch_wav_path = [audio_file]
    results = asr_model_llm.transcribe(
        batch_uttid,
        batch_wav_path,
        {
            "use_gpu": True,
            "beam_size": 3,
            "nbest": 1,
            "decode_max_len": 0,
            "decode_min_len": 0,
            "repetition_penalty": 3.0,
            "llm_length_penalty": 1.0,
            "temperature": 1.0
        }
    )
    text_output = results[0]["text"]
    return text_output
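# Gradio UI: one shared audio input feeds two independent recognition
# buttons, one per model variant, each writing to its own textbox.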
with gr.Blocks(title="FireRedASR") as demo:
    gr.HTML(
        "<h1 style='text-align: center'>FireRedASR Demo</h1>"
    )
    gr.Markdown("Upload an audio file (wav) to get speech-to-text results.")
    with gr.Row():
        with gr.Column():
            #audio_file = gr.Audio(label="Upload Audio", sources=["upload", "microphone"], type="filepath")
            audio_file = gr.Audio(label="Upload wav file", sources=["upload"], type="filepath")
        with gr.Column():
            asr_button = gr.Button("Start Recognition (FireRedASR-AED-L)", variant="primary")
            text_output = gr.Textbox(label="Model Result (FireRedASR-AED-L)", interactive=False, lines=3, max_lines=12)
            asr_button_llm = gr.Button("Start Recognition (FireRedASR-LLM-L)", variant="primary")
            text_output_llm = gr.Textbox(label="Model Result (FireRedASR-LLM-L)", interactive=False, lines=3, max_lines=12)

    asr_button.click(
        fn=asr_inference,
        inputs=[audio_file],
        outputs=[text_output]
    )
    asr_button_llm.click(
        fn=asr_inference_llm,
        inputs=[audio_file],
        outputs=[text_output_llm]
    )
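# Startup: fetch weights from the Hugging Face Hub, load both models once,
# then launch the queued Gradio app. FireRedASR-LLM-L also needs the
# Qwen2-7B-Instruct weights, downloaded into its own model directory.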
if __name__ == "__main__":
    # Download models
    local_dir = 'pretrained_models/FireRedASR-AED-L'
    snapshot_download(repo_id='FireRedTeam/FireRedASR-AED-L', local_dir=local_dir)
    local_dir_llm = 'pretrained_models/FireRedASR-LLM-L'
    snapshot_download(repo_id='FireRedTeam/FireRedASR-LLM-L', local_dir=local_dir_llm)
    local_dir_qwen = 'pretrained_models/FireRedASR-LLM-L/Qwen2-7B-Instruct'
    snapshot_download(repo_id='Qwen/Qwen2-7B-Instruct', local_dir=local_dir_qwen)
    # Init models
    init_model(local_dir, local_dir_llm)
    # UI
    demo.queue()
    demo.launch()