File size: 7,090 Bytes
0ddb4a4
 
 
 
 
 
 
 
1d62baa
b71b08b
92dd882
f3bc9f0
0ddb4a4
 
1d62baa
0ddb4a4
 
b71b08b
 
f3bc9f0
0ddb4a4
 
 
1d62baa
0ddb4a4
 
b71b08b
 
f3bc9f0
c0614ee
 
 
0ddb4a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b71b08b
 
 
 
 
 
 
 
 
 
 
 
f3bc9f0
 
b71b08b
 
 
 
 
 
 
 
 
 
 
 
f3bc9f0
 
 
 
 
 
 
 
 
 
 
0ddb4a4
 
c0614ee
 
 
 
 
b71b08b
7a6c59c
c0614ee
 
0ddb4a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b71b08b
 
 
 
f3bc9f0
b71b08b
f3bc9f0
100f398
f3bc9f0
 
100f398
b71b08b
 
 
c0614ee
0ddb4a4
c0614ee
0ddb4a4
 
 
 
 
 
 
b71b08b
 
0ddb4a4
 
c0614ee
 
0ddb4a4
 
 
 
 
b71b08b
 
 
 
 
 
c0614ee
 
 
 
 
 
0ddb4a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a22caea
 
 
 
c0614ee
 
0ddb4a4
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
import sys

import gradio as gr
import spaces
from huggingface_hub import snapshot_download

sys.path.append("./fireredasr2s")
from fireredasr2s import FireRedAsr2System, FireRedAsr2SystemConfig
from fireredasr2s.fireredasr2.asr import FireRedAsr2, FireRedAsr2Config
from fireredasr2s.fireredvad.vad import FireRedVad, FireRedVadConfig
from fireredasr2s.fireredvad.aed import FireRedAed, FireRedAedConfig
from fireredasr2s.fireredvad.stream_vad import FireRedStreamVad, FireRedStreamVadConfig


# Lazily-initialized model singletons, populated exactly once by init_model().
asr_system = None        # FireRedAsr2System: end-to-end pipeline (see asr_sys_inference)
asr_model_aed = None     # FireRedAsr2 "aed" model (see asr_inference)
asr_model_llm = None     # FireRedAsr2 "llm" model (see asr_inference_llm)
vad_model = None         # FireRedVad: offline voice-activity detection
aed_model = None         # FireRedAed: audio-event detection
stream_vad_model = None  # FireRedStreamVad: streaming voice-activity detection


def init_model(model_dir_aed, model_dir_llm):
    """Lazily construct every model singleton used by the demo.

    Each module-level global is created at most once; calling this again
    is a no-op for any model that is already initialized.

    Args:
        model_dir_aed: local directory holding the FireRedASR2-AED weights.
        model_dir_llm: local directory holding the FireRedASR2-LLM weights.
    """
    global asr_system, asr_model_aed, asr_model_llm
    global vad_model, aed_model, stream_vad_model

    if asr_system is None:
        # End-to-end system with its default configuration.
        asr_system = FireRedAsr2System(FireRedAsr2SystemConfig())

    if asr_model_aed is None:
        aed_asr_cfg = FireRedAsr2Config(
            use_gpu=True,
            use_half=False,
            beam_size=3,
            nbest=1,
            decode_max_len=0,
            softmax_smoothing=1.25,
            aed_length_penalty=0.6,
            eos_penalty=1.0,
            return_timestamp=True,
        )
        asr_model_aed = FireRedAsr2.from_pretrained("aed", model_dir_aed, aed_asr_cfg)

    if asr_model_llm is None:
        llm_asr_cfg = FireRedAsr2Config(
            use_gpu=True,
            decode_min_len=0,
            repetition_penalty=3.0,
            llm_length_penalty=1.0,
            temperature=1.0,
        )
        asr_model_llm = FireRedAsr2.from_pretrained("llm", model_dir_llm, llm_asr_cfg)

    if vad_model is None:
        offline_vad_cfg = FireRedVadConfig(
            use_gpu=False,
            smooth_window_size=5,
            speech_threshold=0.4,
            min_speech_frame=20,
            max_speech_frame=2000,
            min_silence_frame=20,
            merge_silence_frame=0,
            extend_speech_frame=0,
            chunk_max_frame=30000,
        )
        vad_model = FireRedVad.from_pretrained(
            "pretrained_models/FireRedVAD/VAD", offline_vad_cfg)

    if aed_model is None:
        event_cfg = FireRedAedConfig(
            use_gpu=False,
            smooth_window_size=5,
            speech_threshold=0.4,
            singing_threshold=0.5,
            music_threshold=0.5,
            min_event_frame=20,
            max_event_frame=2000,
            min_silence_frame=20,
            merge_silence_frame=0,
            extend_speech_frame=0,
            chunk_max_frame=30000,
        )
        aed_model = FireRedAed.from_pretrained(
            "pretrained_models/FireRedVAD/AED", event_cfg)

    if stream_vad_model is None:
        stream_cfg = FireRedStreamVadConfig(
            use_gpu=False,
            smooth_window_size=5,
            speech_threshold=0.4,
            pad_start_frame=5,
            min_speech_frame=8,
            max_speech_frame=2000,
            min_silence_frame=20,
            chunk_max_frame=30000,
        )
        stream_vad_model = FireRedStreamVad.from_pretrained(
            "pretrained_models/FireRedVAD/Stream-VAD", stream_cfg)


@spaces.GPU(duration=20)
def asr_sys_inference(audio_file):
    """Run the end-to-end FireRedASR2S pipeline on one file and format the result.

    Args:
        audio_file: filepath from the Gradio Audio component (falsy if none).

    Returns:
        A multi-line report string, or a prompt if no file was given.
    """
    if not audio_file:
        return "Please upload a wav file"
    res = asr_system.process(audio_file)
    report = [
        f'ASR: {res["text"]}',
        f'Sentences: {res["sentences"]}',
        f'VAD(ms): {res["vad_segments_ms"]}',
        f'Duration: {res["dur_s"]}s',
    ]
    return "\n".join(report)


@spaces.GPU(duration=20)
def asr_inference(audio_file):
    """Transcribe one file with the FireRedASR2 AED model.

    Args:
        audio_file: filepath from the Gradio Audio component (falsy if none).

    Returns:
        The recognized text of the single-utterance batch, or a prompt
        if no file was given.
    """
    if not audio_file:
        return "Please upload a wav file"
    # transcribe() takes parallel lists of utterance ids and wav paths.
    hypotheses = asr_model_aed.transcribe(["demo"], [audio_file])
    return hypotheses[0]["text"]


@spaces.GPU(duration=30)
def asr_inference_llm(audio_file):
    """Transcribe one file with the FireRedASR2 LLM model.

    Args:
        audio_file: filepath from the Gradio Audio component (falsy if none).

    Returns:
        The recognized text of the single-utterance batch, or a prompt
        if no file was given.
    """
    if not audio_file:
        return "Please upload a wav file"
    # transcribe() takes parallel lists of utterance ids and wav paths.
    hypotheses = asr_model_llm.transcribe(["demo"], [audio_file])
    return hypotheses[0]["text"]


@spaces.GPU(duration=20)
def vad_inference(audio_file):
    """Run offline VAD, streaming VAD, and audio-event detection on one file.

    Fixes over the previous version: the unused ``probs`` return values
    are now explicitly discarded with ``_``, and the three detector
    results get distinct names instead of repeatedly shadowing ``result``.

    Args:
        audio_file: filepath from the Gradio Audio component (falsy if none).

    Returns:
        A multi-line report (duration, VAD segments, streaming-VAD
        segments, detected audio events), or a prompt if no file given.
    """
    if not audio_file:
        return "Please upload a wav file"
    # Offline VAD: duration plus speech-segment timestamps.
    vad_result, _ = vad_model.detect(audio_file)
    # Streaming VAD: per-frame output is not reported, only the segments.
    _, stream_result = stream_vad_model.detect_full(audio_file)
    # Audio-event detection (event ratios and their time ranges).
    aed_result, _ = aed_model.detect(audio_file)
    return "\n".join([
        f'Duration: {vad_result["dur"]}s',
        f'VAD: {vad_result["timestamps"]}',
        f'Stream VAD: {stream_result["timestamps"]}',
        f'Audio Event: {aed_result["event2ratio"]}\n  {aed_result["event2timestamps"]}',
    ])


# Gradio UI: one shared audio input feeds four recognition buttons, each
# wired to its own inference function and output textbox.
with gr.Blocks(title="FireRedASR2S") as demo:
    gr.HTML(
        "<h1 style='text-align: center'>FireRedASR2S Demo</h1>"
    )
    gr.Markdown("Upload an audio file (wav) to get speech-to-text results.")

    with gr.Row():
        with gr.Column():
            # NOTE(review): microphone source is disabled — upload only.
            #audio_file = gr.Audio(label="Upload Audio", sources=["upload", "microphone"], type="filepath")
            audio_file = gr.Audio(label="Upload wav file", sources=["upload"], type="filepath")
            vad_button = gr.Button("Start Recognition (FireRedVAD)", variant="primary")
            vad_output = gr.Textbox(label="Model Result (FireRedVAD)", interactive=False, lines=3, max_lines=12)

        with gr.Column():
            # Three ASR variants: full system, AED model, LLM model.
            asr_sys_button = gr.Button("Start Recognition (FireRedASR2S)", variant="primary")
            text_sys_output = gr.Textbox(label="Model Result (FireRedASR2S)", interactive=False, lines=3, max_lines=12)
            asr_button = gr.Button("Start Recognition (FireRedASR2-AED-L)", variant="primary")
            text_output = gr.Textbox(label="Model Result (FireRedASR2-AED-L)", interactive=False, lines=3, max_lines=12)
            asr_button_llm = gr.Button("Start Recognition (FireRedASR2-LLM-L)", variant="primary")
            text_output_llm = gr.Textbox(label="Model Result (FireRedASR2-LLM-L)", interactive=False, lines=3, max_lines=12)

    # Button wiring: every handler reads the same audio input and writes
    # to its dedicated output textbox.
    vad_button.click(
        fn=vad_inference,
        inputs=[audio_file],
        outputs=[vad_output]
    )

    asr_sys_button.click(
        fn=asr_sys_inference,
        inputs=[audio_file],
        outputs=[text_sys_output]
    )

    asr_button.click(
        fn=asr_inference,
        inputs=[audio_file],
        outputs=[text_output]
    )

    asr_button_llm.click(
        fn=asr_inference_llm,
        inputs=[audio_file],
        outputs=[text_output_llm]
    )


if __name__ == "__main__":
    # Fetch all pretrained checkpoints from the Hugging Face Hub.
    aed_dir = 'pretrained_models/FireRedASR2-AED'
    llm_dir = 'pretrained_models/FireRedASR2-LLM'
    snapshot_download(repo_id='FireRedTeam/FireRedASR2-AED', local_dir=aed_dir)
    snapshot_download(repo_id='FireRedTeam/FireRedASR2-LLM', local_dir=llm_dir)
    for name in ['FireRedVAD', 'FireRedLID', 'FireRedPunc']:
        snapshot_download(repo_id=f'FireRedTeam/{name}', local_dir=f'pretrained_models/{name}')
    # Build the model singletons once before serving requests.
    init_model(aed_dir, llm_dir)
    # Launch the Gradio app with request queueing enabled.
    demo.queue()
    demo.launch()