FireRedTeam committed on
Commit
162974e
·
1 Parent(s): 66d962d

add fireredasr inference code

Browse files
Files changed (3) hide show
  1. app.py +43 -5
  2. fireredasr +1 -0
  3. pretrained_models/README.md +1 -0
app.py CHANGED
@@ -1,23 +1,53 @@
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
 
4
def asr_inference(audio_file):
    """Placeholder speech-recognition endpoint.

    Returns a guard message when no file is supplied; otherwise echoes the
    uploaded file path prefixed with a "Developing" marker (real inference
    not wired up yet).
    """
    if not audio_file:
        return "Please upload a wav file"
    # Model loading is not implemented in this placeholder version.
    model = None
    text_output = "Developing " + audio_file
    return text_output
10
 
11
 
12
  with gr.Blocks(title="FireRedASR") as demo:
13
  gr.HTML(
14
- "<h1 style='text-align: center'>FireRedASR</h1>"
15
  )
16
  gr.Markdown("Upload an audio file (wav) to get speech-to-text results.")
17
 
18
  with gr.Row():
19
  with gr.Column():
20
- audio_file = gr.Audio(label="Upload Audio", sources=["microphone", "upload"], type="filepath")
21
  #audio_file = gr.Audio(label="Upload wav file", sources=["upload"], type="filepath")
22
  asr_button = gr.Button("Start Recognition", variant="primary")
23
 
@@ -31,4 +61,12 @@ with gr.Blocks(title="FireRedASR") as demo:
31
  )
32
 
33
 
34
- demo.launch()
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from huggingface_hub import snapshot_download
3
+
4
+ from fireredasr.fireredasr.models.fireredasr import FireRedAsr
5
+
6
+
7
+ asr_model_aed = None
8
+
9
+
10
def init_model(model_dir_aed):
    """Lazily initialise the module-level FireRedASR AED model.

    Idempotent: loads the model only on the first call; subsequent calls
    are no-ops.

    Args:
        model_dir_aed: Local directory containing the pretrained
            FireRedASR-AED checkpoint (e.g. the snapshot_download target).
    """
    global asr_model_aed
    if asr_model_aed is None:
        # Bug fix: original referenced undefined name `model_dir`
        # (NameError); the parameter is `model_dir_aed`.
        asr_model_aed = FireRedAsr.from_pretrained("aed", model_dir_aed)
14
 
15
 
16
def asr_inference(audio_file):
    """Transcribe an uploaded wav file with the FireRedASR AED model.

    Args:
        audio_file: Filesystem path to the audio file from the Gradio
            Audio widget (type="filepath"); falsy when nothing uploaded.

    Returns:
        The recognised text, or a prompt string when no file was given.
    """
    if not audio_file:
        return "Please upload a wav file"
    batch_uttid = ["demo"]
    batch_wav_path = [audio_file]
    # Bug fix: original called undefined name `model` (NameError); the
    # module-level handle populated by init_model() is `asr_model_aed`.
    results = asr_model_aed.transcribe(
        batch_uttid,
        batch_wav_path,
        {
            "use_gpu": False,
            "beam_size": 3,
            "nbest": 1,
            "decode_max_len": 0,
            "softmax_smoothing": 1.25,
            "aed_length_penalty": 0.6,
            "eos_penalty": 1.0,
            #"decode_min_len": args.decode_min_len,
            #"repetition_penalty": args.repetition_penalty,
            #"llm_length_penalty": args.llm_length_penalty,
            #"temperature": args.temperature
        },
    )
    # NOTE(review): assumes transcribe() returns a mapping with a "text"
    # key. If it instead returns a list of per-utterance dicts (common for
    # batch APIs), this should be results[0]["text"] — verify against the
    # fireredasr submodule.
    text_output = results["text"]
    return text_output
40
 
41
 
42
  with gr.Blocks(title="FireRedASR") as demo:
43
  gr.HTML(
44
+ "<h1 style='text-align: center'>FireRedASR Demo</h1>"
45
  )
46
  gr.Markdown("Upload an audio file (wav) to get speech-to-text results.")
47
 
48
  with gr.Row():
49
  with gr.Column():
50
+ audio_file = gr.Audio(label="Upload Audio", sources=["upload", "microphone"], type="filepath")
51
  #audio_file = gr.Audio(label="Upload wav file", sources=["upload"], type="filepath")
52
  asr_button = gr.Button("Start Recognition", variant="primary")
53
 
 
61
  )
62
 
63
 
64
+ if __name__ == "__main__":
65
+ # Download model
66
+ local_dir='pretrained_models/FireRedASR-AED-L'
67
+ snapshot_download(repo_id='FireRedTeam/FireRedASR-AED-L', local_dir=local_dir)
68
+ # Init model
69
+ init_model(local_dir)
70
+ # UI
71
+ demo.queue()
72
+ demo.launch()
fireredasr ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit 1eadb81b66eca948cd492bc0aeedd786333c049d
pretrained_models/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ Put pretrained models here.