Commit c397ab6 (parent: b0df9b2)
Yuekai Zhang committed

add gpu support

Dockerfile CHANGED
@@ -1,4 +1,9 @@
 FROM nvcr.io/nvidia/pytorch:22.12-py3
+ENV DEBIAN_FRONTEND=noninteractive
+RUN apt-get update && apt-get install -y ffmpeg
 COPY ./ /workspace/
 WORKDIR /workspace/
-RUN pip3 install -r requirements.txt
+RUN pip3 install --no-cache-dir --upgrade -r requirements-gradio.txt
+RUN chmod -R 777 /workspace/*
+
+CMD ["python", "app.py"]
app.py ADDED
@@ -0,0 +1,115 @@
+from funasr_onnx import Fsmn_vad, Paraformer, CT_Transformer
+from transcribe import get_models, transcribe
+import soundfile
+import gradio as gr
+import pytube as pt
+import datetime
+import os
+
+asr_model, vad_model, punc_model = get_models("./models")
+
+def convert_to_wav(in_filename: str):
+    """Convert the input audio to a 16 kHz wave file and return its samples"""
+    out_filename = in_filename + ".wav"
+    if '.mp3' in in_filename:
+        _ = os.system(f"ffmpeg -y -i '{in_filename}' -acodec pcm_s16le -ac 1 -ar 16000 '{out_filename}'")
+    else:
+        _ = os.system(f"ffmpeg -hide_banner -y -i '{in_filename}' -ar 16000 '{out_filename}'")
+    speech, _ = soundfile.read(out_filename)
+    print(f"load speech shape {speech.shape}")
+    return speech
+
+def file_transcribe(microphone, file_upload):
+    warn_output = ""
+    if (microphone is not None) and (file_upload is not None):
+        warn_output = (
+            "WARNING: You've uploaded an audio file and used the microphone. "
+            "The recorded file from the microphone will be used and the uploaded audio will be discarded.\n"
+        )
+
+    elif (microphone is None) and (file_upload is None):
+        return "ERROR: You have to either use the microphone or upload an audio file"
+
+    file = microphone if microphone is not None else file_upload
+
+    speech = convert_to_wav(file)
+
+    items = []
+    vad_model.vad_scorer.AllResetDetection()
+    for item in transcribe(speech, asr_model, vad_model, punc_model):
+        items.append(item)
+        print(item)
+
+    text = "\n".join(items)
+
+    return warn_output + text
+
+
+def _return_yt_html_embed(yt_url):
+    video_id = yt_url.split("?v=")[-1]
+    HTML_str = (
+        f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>'
+        " </center>"
+    )
+    return HTML_str
+
+
+def youtube_transcribe(yt_url):
+    yt = pt.YouTube(yt_url)
+    html_embed_str = _return_yt_html_embed(yt_url)
+    stream = yt.streams.filter(only_audio=True)[0]
+    filename = "audio.mp3"
+    stream.download(filename=filename)
+
+    speech = convert_to_wav(filename)
+    items = []
+    vad_model.vad_scorer.AllResetDetection()
+    for item in transcribe(speech, asr_model, vad_model, punc_model):
+        items.append(item)
+        print(item)
+
+    text = "\n".join(items)
+    os.system("rm -rf audio.mp3 audio.mp3.wav")
+    return html_embed_str, text
+
+
+def run():
+    gr.close_all()
+    demo = gr.Blocks()
+
+    mf_transcribe = gr.Interface(
+        fn=file_transcribe,
+        inputs=[
+            gr.inputs.Audio(source="microphone", type="filepath", optional=True),
+            gr.inputs.Audio(source="upload", type="filepath", optional=True),
+        ],
+        outputs="text",
+        layout="horizontal",
+        theme="huggingface",
+        title="ParaformerX: Copilot for Audio",
+        description=(
+            "Transcribe long-form microphone or audio inputs with the click of a button! Demo uses the pretrained Paraformer model to transcribe audio files of arbitrary length."
+        ),
+        allow_flagging="never",
+    )
+
+    yt_transcribe = gr.Interface(
+        fn=youtube_transcribe,
+        inputs=[gr.inputs.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL")],
+        outputs=["html", "text"],
+        layout="horizontal",
+        theme="huggingface",
+        title="Demo: Transcribe YouTube",
+        description=(
+            "Transcribe long-form YouTube videos with the click of a button! Demo uses the pretrained Paraformer model to transcribe audio files of arbitrary length."
+        ),
+        allow_flagging="never",
+    )
+
+    with demo:
+        gr.TabbedInterface([mf_transcribe, yt_transcribe], ["Transcribe Audio", "Transcribe YouTube"])
+
+    demo.launch(server_name="0.0.0.0", server_port=7860, enable_queue=True)
+
+if __name__ == "__main__":
+    run()
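
Note: the pipeline that app.py wires into Gradio can also be exercised directly. A minimal sketch, assuming a local 16 kHz mono file named test.wav (a placeholder) and the ./models directory layout used above:

import soundfile

from transcribe import get_models, transcribe

# Load the ASR, VAD, and punctuation ONNX models, as app.py does at startup.
asr_model, vad_model, punc_model = get_models("./models")

# 16 kHz mono samples, matching what convert_to_wav produces via ffmpeg.
speech, sample_rate = soundfile.read("test.wav")
assert sample_rate == 16000

# Reset VAD state between files, then stream out timestamped lines.
vad_model.vad_scorer.AllResetDetection()
for line in transcribe(speech, asr_model, vad_model, punc_model):
    print(line)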
funasr_onnx/utils/__pycache__/__init__.cpython-38.pyc DELETED
Binary file (164 Bytes)
 
funasr_onnx/utils/__pycache__/e2e_vad.cpython-38.pyc DELETED
Binary file (16.4 kB)
 
funasr_onnx/utils/__pycache__/frontend.cpython-38.pyc DELETED
Binary file (6.1 kB)
 
funasr_onnx/utils/__pycache__/postprocess_utils.cpython-38.pyc DELETED
Binary file (3.84 kB)
 
funasr_onnx/utils/__pycache__/timestamp_utils.cpython-38.pyc DELETED
Binary file (1.52 kB)
 
funasr_onnx/utils/__pycache__/utils.cpython-38.pyc DELETED
Binary file (10.8 kB)
 
requirements-gradio.txt ADDED
@@ -0,0 +1,12 @@
+WeTextProcessing
+onnxruntime-gpu
+onnxruntime
+soundfile
+librosa
+scipy
+numpy
+typeguard==2.13.3
+kaldi-native-fbank
+PyYAML>=5.1.2
+gradio
+pytube
requirements.txt CHANGED
@@ -1,4 +1,5 @@
 WeTextProcessing
+onnxruntime-gpu
 onnxruntime
 soundfile
 librosa
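
Note: both requirement files now list onnxruntime-gpu alongside onnxruntime. A quick sanity check (not part of the commit) for whether the GPU build actually took effect, since the CPU package can shadow it:

import onnxruntime as ort

# With onnxruntime-gpu installed correctly, the CUDA provider is listed
# ahead of the CPU fallback.
print(ort.get_available_providers())
# e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider']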
transcribe.py CHANGED
@@ -3,6 +3,7 @@ from funasr_onnx import Fsmn_vad, Paraformer, CT_Transformer
 import datetime
 from itn.chinese.inverse_normalizer import InverseNormalizer
 import argparse
+import torch
 
 def get_args():
     parser = argparse.ArgumentParser(
@@ -29,14 +30,17 @@ def process_time(milliseconds):
     delta = datetime.timedelta(milliseconds=milliseconds)
     time_str = str(delta)
     time_parts = time_str.split(".")[0].split(":")
-    time_hms = "{:02d}:{:02d}:{:02d}".format(int(time_parts[0]), int(time_parts[1]), int(time_parts[2]))
+    time_hms = "{:02d}:{:02d}:{:02d}:{:03d}".format(int(time_parts[0]), int(time_parts[1]), int(time_parts[2]), int(str(milliseconds)[-3:]))
     return time_hms
 
-def get_models(model_dir):
+def get_models(model_dir, batch_size=16, enable_gpu=False):
     vad_model_dir = model_dir + "/speech_fsmn_vad_zh-cn-16k-common-pytorch"
     asr_model_dir = model_dir + "/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
     punc_model_dir = model_dir + "/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
-    asr_model = Paraformer(asr_model_dir, batch_size=1, plot_timestamp_to="./", pred_bias=0) # cpu
+    if torch.cuda.is_available() and enable_gpu:
+        asr_model = Paraformer(asr_model_dir, batch_size=batch_size, device_id=0, plot_timestamp_to="./", pred_bias=0) # gpu
+    else:
+        asr_model = Paraformer(asr_model_dir, batch_size=1, plot_timestamp_to="./", pred_bias=0) # cpu
     punc_model = CT_Transformer(punc_model_dir)
     vad_model = Fsmn_vad(vad_model_dir)
     return asr_model, vad_model, punc_model
@@ -47,22 +51,61 @@ def load_audio(wav_path):
 
 def transcribe(speech, asr_model, vad_model=None, punc_model=None, invnormalizer=None):
     if vad_model:
+        vad_model.vad_scorer.AllResetDetection()
         segments_info = vad_model(audio_in=speech)
-        assert len(segments_info) == 1, "only support batch_size 1"
-        for seg in segments_info[0]:
-            if seg[1] == -1: # end of speech
-                seg[1] = len(speech) // 16
-            seg_speech = speech[seg[0]*16:seg[1]*16]
+        assert len(segments_info) == 1, "only support batch_size 1"
+        if asr_model.batch_size > 1:
+            all_results = []
+            assert torch.cuda.is_available(), "only support batch_size > 1 on gpu"
+            i, end, step = 0, len(segments_info[0]), asr_model.batch_size
+            while i < end:
+                sub_segments_info = segments_info[0][i:i+step]
+                seg_speech_list, duration = [], 0
+                for seg in sub_segments_info:
+                    if seg[1] == -1: # end of speech
+                        seg[1] = len(speech) // 16
+                    seg_speech = speech[seg[0]*16:seg[1]*16]
+                    duration += (seg[1] - seg[0]) / 1000
+                    if duration < 8 * asr_model.batch_size: # cumulative audio per batch stays under 8 s * batch_size
+                        seg_speech_list.append(seg_speech)
+                        i += 1
+                    else:
+                        break
+                assert seg_speech_list
+                result = asr_model(seg_speech_list)
+                all_results.extend(result)
+            assert len(all_results) == len(segments_info[0])
+            for i, seg in enumerate(segments_info[0]):
+                if seg[1] == -1: # end of speech
+                    seg[1] = len(speech) // 16
+                result = all_results[i]['preds'][0]
+                if invnormalizer:
+                    try:
+                        result = invnormalizer.normalize(result)
+                    except Exception:
+                        print("error in normalization")
+                if punc_model:
+                    if result:
+                        result = punc_model(result)
+                        result = result[0]
+                item = f"{process_time(seg[0])}-->{process_time(seg[1])} {result}"
+                yield item
+        else:
+            for seg in segments_info[0]:
+                if seg[1] == -1: # end of speech
+                    seg[1] = len(speech) // 16
+                seg_speech = speech[seg[0]*16:seg[1]*16]
 
-            result = asr_model(seg_speech)
-            result = result[0]['preds'][0]
-            if invnormalizer:
-                result = invnormalizer.normalize(result)
-            if punc_model:
-                result = punc_model(result)
-                result = result[0]
-            item = f"{process_time(seg[0])}-->{process_time(seg[1])} {result}"
-            yield item
+                result = asr_model(seg_speech)
+                result = result[0]['preds'][0]
+                if invnormalizer:
+                    result = invnormalizer.normalize(result)
+                if punc_model:
+                    if result:
+                        result = punc_model(result)
+                        result = result[0]
+                item = f"{process_time(seg[0])}-->{process_time(seg[1])} {result}"
+                yield item
 
 if __name__ == "__main__":
     args = get_args()
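
Two notes on the transcribe.py changes. The reworked process_time appends milliseconds, so 75250 ms now renders as 00:01:15:250 instead of 00:01:15. And the heart of the GPU path is a greedy packing loop: VAD segments are taken in order, up to batch_size per ASR call, but a batch is cut short once its cumulative duration reaches 8 s * batch_size. A self-contained illustration of that packing logic with dummy millisecond boundaries (pack_segments and max_secs_per_item are illustrative names, not from the repo):

def pack_segments(segments, batch_size=16, max_secs_per_item=8):
    """Greedily group [start_ms, end_ms] segments into ASR batches."""
    batches, i = [], 0
    while i < len(segments):
        batch, duration = [], 0.0
        for seg in segments[i:i + batch_size]:
            duration += (seg[1] - seg[0]) / 1000  # ms -> s
            if duration < max_secs_per_item * batch_size:
                batch.append(seg)
                i += 1
            else:
                break
        # Mirrors the assert in the commit: a batch must never come out
        # empty, which implicitly assumes no single segment exceeds the
        # 8 s * batch_size budget.
        assert batch
        batches.append(batch)
    return batches

segs = [[0, 4000], [4000, 9000], [9000, 21000], [21000, 30000]]
print(pack_segments(segs, batch_size=2))
# With batch_size=2 the budget is 16 s: the 4 s and 5 s segments share a
# batch, but the 12 s segment cannot also take the 9 s one, so the split
# is [[0,4000],[4000,9000]], [[9000,21000]], [[21000,30000]].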