File size: 3,677 Bytes
784e76a
 
64bc319
34d8fa9
162974e
 
784e76a
 
162974e
 
 
3bc7439
162974e
 
3bc7439
162974e
3bc7439
162974e
f526d77
3bc7439
 
 
64bc319
9e835c6
66d962d
 
 
162974e
 
ff49c86
162974e
 
 
9e835c6
162974e
 
 
 
 
 
 
 
 
 
 
 
e8ace30
66d962d
 
 
3bc7439
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66d962d
 
162974e
66d962d
 
 
 
 
c762067
 
66d962d
 
3bc7439
 
 
 
66d962d
 
 
 
 
 
 
3bc7439
 
 
 
 
 
66d962d
162974e
 
 
 
3bc7439
 
2fa2902
 
162974e
3bc7439
162974e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import sys

import gradio as gr
import spaces
from huggingface_hub import snapshot_download

sys.path.append("./fireredasr")
from fireredasr.models.fireredasr import FireRedAsr


# Module-level model singletons; both stay None until init_model() loads them.
asr_model_aed = None  # FireRedASR "aed" variant (attention encoder-decoder)
asr_model_llm = None  # FireRedASR "llm" variant (LLM-based decoder)


def init_model(model_dir_aed, model_dir_llm):
    """Load both FireRedASR models once, caching them in module globals.

    Calling this again is a no-op for any model that is already loaded.
    """
    global asr_model_aed, asr_model_llm
    if asr_model_aed is None:
        asr_model_aed = FireRedAsr.from_pretrained("aed", model_dir_aed)
    if asr_model_llm is None:
        asr_model_llm = FireRedAsr.from_pretrained("llm", model_dir_llm)


@spaces.GPU(duration=20)
def asr_inference(audio_file):
    """Transcribe one wav file with the FireRedASR-AED-L model.

    Args:
        audio_file: Filesystem path of the uploaded wav (from gr.Audio,
            type="filepath"), or None/"" when nothing was uploaded.

    Returns:
        The recognized text, or a user-facing message when the input is
        missing or the model has not been initialized yet.
    """
    if not audio_file:
        return "Please upload a wav file"
    if asr_model_aed is None:
        # Guard: init_model() has not run yet (e.g. module imported without
        # the __main__ path) — fail with a message instead of AttributeError.
        return "Model is not initialized yet, please try again later"
    batch_uttid = ["demo"]
    batch_wav_path = [audio_file]
    results = asr_model_aed.transcribe(
        batch_uttid,
        batch_wav_path,
        {
            "use_gpu": True,
            "beam_size": 3,
            "nbest": 1,
            "decode_max_len": 0,          # 0 = no explicit decode-length cap
            "softmax_smoothing": 1.25,
            "aed_length_penalty": 0.6,
            "eos_penalty": 1.0,
        }
    )
    # nbest=1, so the single best hypothesis is results[0].
    text_output = results[0]["text"]
    return text_output


@spaces.GPU(duration=30)
def asr_inference_llm(audio_file):
    """Transcribe one wav file with the FireRedASR-LLM-L model.

    Args:
        audio_file: Filesystem path of the uploaded wav (from gr.Audio,
            type="filepath"), or None/"" when nothing was uploaded.

    Returns:
        The recognized text, or a user-facing message when the input is
        missing or the model has not been initialized yet.
    """
    if not audio_file:
        return "Please upload a wav file"
    if asr_model_llm is None:
        # Guard: init_model() has not run yet — return a message instead of
        # crashing with AttributeError on the None global.
        return "Model is not initialized yet, please try again later"
    batch_uttid = ["demo"]
    batch_wav_path = [audio_file]
    results = asr_model_llm.transcribe(
        batch_uttid,
        batch_wav_path,
        {
            "use_gpu": True,
            "beam_size": 3,
            "nbest": 1,
            "decode_max_len": 0,          # 0 = no explicit decode-length cap
            "decode_min_len": 0,
            "repetition_penalty": 3.0,
            "llm_length_penalty": 1.0,
            "temperature": 1.0
        }
    )
    # nbest=1, so the single best hypothesis is results[0].
    text_output = results[0]["text"]
    return text_output


# Gradio UI: one audio upload feeding two recognition buttons (AED and LLM),
# each writing its result into its own textbox.
with gr.Blocks(title="FireRedASR") as demo:
    gr.HTML("<h1 style='text-align: center'>FireRedASR Demo</h1>")
    gr.Markdown("Upload an audio file (wav) to get speech-to-text results.")

    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(
                label="Upload wav file", sources=["upload"], type="filepath"
            )

        with gr.Column():
            aed_button = gr.Button(
                "Start Recognition (FireRedASR-AED-L)", variant="primary"
            )
            aed_output = gr.Textbox(
                label="Model Result (FireRedASR-AED-L)",
                interactive=False, lines=3, max_lines=12,
            )
            llm_button = gr.Button(
                "Start Recognition (FireRedASR-LLM-L)", variant="primary"
            )
            llm_output = gr.Textbox(
                label="Model Result (FireRedASR-LLM-L)",
                interactive=False, lines=3, max_lines=12,
            )

    aed_button.click(fn=asr_inference, inputs=[audio_input], outputs=[aed_output])
    llm_button.click(fn=asr_inference_llm, inputs=[audio_input], outputs=[llm_output])


if __name__ == "__main__":
    # Fetch all pretrained checkpoints from the HF Hub. The LLM model also
    # needs its Qwen2 backbone placed inside its own model directory.
    aed_dir = 'pretrained_models/FireRedASR-AED-L'
    llm_dir = 'pretrained_models/FireRedASR-LLM-L'
    downloads = (
        ('FireRedTeam/FireRedASR-AED-L', aed_dir),
        ('FireRedTeam/FireRedASR-LLM-L', llm_dir),
        ('Qwen/Qwen2-7B-Instruct', llm_dir + '/Qwen2-7B-Instruct'),
    )
    for repo, target in downloads:
        snapshot_download(repo_id=repo, local_dir=target)
    # Load both models once, then serve the UI.
    init_model(aed_dir, llm_dir)
    demo.queue()
    demo.launch()