File size: 2,113 Bytes
64bc319
162974e
 
 
 
 
 
 
 
 
 
 
 
64bc319
 
66d962d
 
 
162974e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66d962d
 
 
 
 
162974e
66d962d
 
 
 
 
162974e
66d962d
 
 
 
 
 
 
 
 
 
 
 
 
162974e
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import gradio as gr
from huggingface_hub import snapshot_download

from fireredasr.fireredasr.models.fireredasr import FireRedAsr


# Module-level singleton for the loaded FireRedASR AED model.
# Starts as None; populated lazily by init_model() before the UI serves requests.
asr_model_aed = None


def init_model(model_dir_aed):
    """Lazily load the FireRedASR AED model into the module-level singleton.

    Idempotent: does nothing if the model is already loaded.

    Args:
        model_dir_aed: Path to the local directory holding the pretrained
            FireRedASR-AED model files.
    """
    global asr_model_aed
    if asr_model_aed is None:
        # Bug fix: original referenced undefined name `model_dir`,
        # which raised NameError on first call; the parameter is `model_dir_aed`.
        asr_model_aed = FireRedAsr.from_pretrained("aed", model_dir_aed)


def asr_inference(audio_file):
    """Transcribe one uploaded audio file with the global AED model.

    Args:
        audio_file: Filesystem path to the audio clip provided by Gradio
            (``type="filepath"``), or a falsy value when nothing was uploaded.

    Returns:
        The recognized text, or a human-readable message if no file was
        uploaded or the model has not been initialized yet.
    """
    if not audio_file:
        return "Please upload a wav file"
    # Robustness: avoid an AttributeError if inference is triggered before
    # init_model() has run (e.g. during startup).
    if asr_model_aed is None:
        return "Model is not initialized yet, please try again later"
    batch_uttid = ["demo"]
    batch_wav_path = [audio_file]
    # Bug fix: original called `model.transcribe(...)` but no `model` name
    # exists anywhere in this file — use the module-level `asr_model_aed`.
    results = asr_model_aed.transcribe(
        batch_uttid,
        batch_wav_path,
        {
            "use_gpu": False,
            "beam_size": 3,
            "nbest": 1,
            "decode_max_len": 0,
            "softmax_smoothing": 1.25,
            "aed_length_penalty": 0.6,
            "eos_penalty": 1.0,
        }
    )
    # NOTE(review): assumes transcribe() returns a mapping with a "text"
    # key, as the original code did — confirm against FireRedAsr's API.
    text_output = results["text"]
    return text_output


# Gradio UI definition: a two-column layout with an audio input on the left
# and the recognition result on the right. `demo` is launched in __main__.
with gr.Blocks(title="FireRedASR") as demo:
    gr.HTML(
        "<h1 style='text-align: center'>FireRedASR Demo</h1>"
    )
    gr.Markdown("Upload an audio file (wav) to get speech-to-text results.")

    with gr.Row():
        with gr.Column():
            # type="filepath" makes Gradio hand asr_inference a path string,
            # which is what FireRedAsr.transcribe expects.
            audio_file = gr.Audio(label="Upload Audio", sources=["upload", "microphone"], type="filepath")
            #audio_file = gr.Audio(label="Upload wav file", sources=["upload"], type="filepath")
            asr_button = gr.Button("Start Recognition", variant="primary")

        with gr.Column():
            # Read-only output box for the transcription result.
            text_output = gr.Textbox(label="Model Result", interactive=False, lines=6, max_lines=12)

    # Wire the button to the inference function: path in, text out.
    asr_button.click(
        fn=asr_inference,
        inputs=[audio_file],
        outputs=[text_output]
    )


if __name__ == "__main__":
    # Fetch the pretrained AED model snapshot from the Hugging Face Hub
    # into a local directory (no-op if already downloaded).
    model_dir = 'pretrained_models/FireRedASR-AED-L'
    snapshot_download(repo_id='FireRedTeam/FireRedASR-AED-L', local_dir=model_dir)
    # Load the model once at startup so the first request is not slow.
    init_model(model_dir)
    # Enable request queuing, then serve the Gradio app.
    demo.queue()
    demo.launch()