FireRedTeam commited on
Commit
3bc7439
·
1 Parent(s): c762067

add llm asr

Browse files
Files changed (1) hide show
  1. app.py +43 -4
app.py CHANGED
@@ -9,12 +9,17 @@ from fireredasr.models.fireredasr import FireRedAsr
9
 
10
 
11
  asr_model_aed = None
 
12
 
13
 
14
- def init_model(model_dir_aed):
15
  global asr_model_aed
 
16
  if asr_model_aed is None:
17
  asr_model_aed = FireRedAsr.from_pretrained("aed", model_dir_aed)
 
 
 
18
 
19
  @spaces.GPU(duration=20)
20
  def asr_inference(audio_file):
@@ -43,6 +48,30 @@ def asr_inference(audio_file):
43
  return text_output
44
 
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  with gr.Blocks(title="FireRedASR") as demo:
47
  gr.HTML(
48
  "<h1 style='text-align: center'>FireRedASR Demo</h1>"
@@ -53,10 +82,12 @@ with gr.Blocks(title="FireRedASR") as demo:
53
  with gr.Column():
54
  #audio_file = gr.Audio(label="Upload Audio", sources=["upload", "microphone"], type="filepath")
55
  audio_file = gr.Audio(label="Upload wav file", sources=["upload"], type="filepath")
56
- asr_button = gr.Button("Start Recognition", variant="primary")
57
 
58
  with gr.Column():
59
- text_output = gr.Textbox(label="Model Result", interactive=False, lines=6, max_lines=12)
 
 
 
60
 
61
  asr_button.click(
62
  fn=asr_inference,
@@ -64,13 +95,21 @@ with gr.Blocks(title="FireRedASR") as demo:
64
  outputs=[text_output]
65
  )
66
 
 
 
 
 
 
 
67
 
68
  if __name__ == "__main__":
69
  # Download model
70
  local_dir='pretrained_models/FireRedASR-AED-L'
71
  snapshot_download(repo_id='FireRedTeam/FireRedASR-AED-L', local_dir=local_dir)
 
 
72
  # Init model
73
- init_model(local_dir)
74
  # UI
75
  demo.queue()
76
  demo.launch()
 
9
 
10
 
11
  asr_model_aed = None
12
+ asr_model_llm = None
13
 
14
 
15
+ def init_model(model_dir_aed, model_dir_llm):
16
  global asr_model_aed
17
+ global asr_model_llm
18
  if asr_model_aed is None:
19
  asr_model_aed = FireRedAsr.from_pretrained("aed", model_dir_aed)
20
+ if asr_model_llm is None:
21
+ asr_model_llm = FireRedAsr.from_pretrained("llm", model_dir_llm)
22
+
23
 
24
  @spaces.GPU(duration=20)
25
  def asr_inference(audio_file):
 
48
  return text_output
49
 
50
 
51
+ @spaces.GPU(duration=30)
52
+ def asr_inference_llm(audio_file):
53
+ if not audio_file:
54
+ return "Please upload a wav file"
55
+ batch_uttid = ["demo"]
56
+ batch_wav_path = [audio_file]
57
+ results = asr_model_llm.transcribe(
58
+ batch_uttid,
59
+ batch_wav_path,
60
+ {
61
+ "use_gpu": True,
62
+ "beam_size": 3,
63
+ "nbest": 1,
64
+ "decode_max_len": 0,
65
+ "decode_min_len": 0,
66
+ "repetition_penalty": 3.0,
67
+ "llm_length_penalty": 1.0,
68
+ "temperature": 1.0
69
+ }
70
+ )
71
+ text_output = results[0]["text"]
72
+ return text_output
73
+
74
+
75
  with gr.Blocks(title="FireRedASR") as demo:
76
  gr.HTML(
77
  "<h1 style='text-align: center'>FireRedASR Demo</h1>"
 
82
  with gr.Column():
83
  #audio_file = gr.Audio(label="Upload Audio", sources=["upload", "microphone"], type="filepath")
84
  audio_file = gr.Audio(label="Upload wav file", sources=["upload"], type="filepath")
 
85
 
86
  with gr.Column():
87
+ asr_button = gr.Button("Start Recognition (FireRedASR-AED-L)", variant="primary")
88
+ text_output = gr.Textbox(label="Model Result (FireRedASR-AED-L)", interactive=False, lines=3, max_lines=12)
89
+ asr_button_llm = gr.Button("Start Recognition (FireRedASR-LLM-L)", variant="primary")
90
+ text_output_llm = gr.Textbox(label="Model Result (FireRedASR-LLM-L)", interactive=False, lines=3, max_lines=12)
91
 
92
  asr_button.click(
93
  fn=asr_inference,
 
95
  outputs=[text_output]
96
  )
97
 
98
+ asr_button_llm.click(
99
+ fn=asr_inference_llm,
100
+ inputs=[audio_file],
101
+ outputs=[text_output_llm]
102
+ )
103
+
104
 
105
  if __name__ == "__main__":
106
  # Download model
107
  local_dir='pretrained_models/FireRedASR-AED-L'
108
  snapshot_download(repo_id='FireRedTeam/FireRedASR-AED-L', local_dir=local_dir)
109
+ local_dir_llm='pretrained_models/FireRedASR-LLM-L'
110
+ snapshot_download(repo_id='FireRedTeam/FireRedASR-LLM-L', local_dir=local_dir_llm)
111
  # Init model
112
+ init_model(local_dir, local_dir_llm)
113
  # UI
114
  demo.queue()
115
  demo.launch()