qqc1989 commited on
Commit
566fca0
·
verified ·
1 Parent(s): fedb44c

Upload 30 files

Browse files
.gitattributes CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  *.axmodel filter=lfs diff=lfs merge=lfs -text
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  *.axmodel filter=lfs diff=lfs merge=lfs -text
37
+ cpp/whisper filter=lfs diff=lfs merge=lfs -text
cpp/TSCharacters.ocd2 ADDED
Binary file (46.1 kB). View file
 
cpp/TSPhrases.ocd2 ADDED
Binary file (9.78 kB). View file
 
cpp/t2s.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "Traditional Chinese to Simplified Chinese",
3
+ "segmentation": {
4
+ "type": "mmseg",
5
+ "dict": {
6
+ "type": "ocd2",
7
+ "file": "TSPhrases.ocd2"
8
+ }
9
+ },
10
+ "conversion_chain": [{
11
+ "dict": {
12
+ "type": "group",
13
+ "dicts": [{
14
+ "type": "ocd2",
15
+ "file": "TSPhrases.ocd2"
16
+ }, {
17
+ "type": "ocd2",
18
+ "file": "TSCharacters.ocd2"
19
+ }]
20
+ }
21
+ }]
22
+ }
cpp/whisper ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b2cbd2f10e309e8bdc3c63989b0637311fe3eb2a39d03c76c17bc66ac86405bc
3
+ size 489848
models-ax630c/base-decoder-loop.axmodel ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4b12160aaa1ca31248a32ce05713fd72e273b16444389853c1f52990cf5130eb
3
+ size 130364397
models-ax630c/base-decoder-main.axmodel ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:940f273d111e3aee53cdb692a384a29556981aa146afbb2f558f6aac262c0621
3
+ size 135675471
models-ax630c/base-encoder.axmodel ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b9f89ed5bbe31bcf98aa0e479ced1699b39816db2d3e2e2ff84c6e887af2b79b
3
+ size 56024079
models-ax630c/base-positional_embedding.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:88fa1cdbf2b06f86b0ecb7be0fccfc39e906502986572b8cf5319c250e857169
3
+ size 917504
models-ax630c/base-tokens.txt ADDED
The diff for this file is too large to render. See raw diff
 
models-ax650/small-decoder-loop.axmodel ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b472a0f3539d17fece09e92bf6cd69ebf391928a6050896bbf86b558a25def22
3
+ size 269002567
models-ax650/small-decoder-main.axmodel ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f3bfc577f60c35192d8ce8cc24f9ca4aa84af72756ba11af9d178d337cb7eb1c
3
+ size 285531695
models-ax650/small-encoder.axmodel ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7b3bc8db9762f9b2dfe78bffbc8070fb877b2572c5288253573e49a8c7b37948
3
+ size 139705612
models-ax650/small-positional_embedding.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c10bc44f2bd94bdf1b7aa03581309fa536132b3fe79bfe22c9a6934a42cd8b58
3
+ size 1376256
models-ax650/small-tokens.txt ADDED
The diff for this file is too large to render. See raw diff
 
models-onnx/base-decoder-loop.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1616a829b7d3d643616633551204b8d0f008fb7a7dc38919eda2e8c6c6ed9714
3
+ size 194571088
models-onnx/base-decoder-main.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1096b83590016bdbe74c66c7ccad1c0120abd6d37214560b1dfe4cd886a0e683
3
+ size 205485892
models-onnx/base-encoder.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dd4b51bd569e9b2b2d83a8ed56f3618811f0c593aa95c010069df675027b5f2b
3
+ size 95026988
models-onnx/base-positional_embedding.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:88fa1cdbf2b06f86b0ecb7be0fccfc39e906502986572b8cf5319c250e857169
3
+ size 917504
models-onnx/base-tokens.txt ADDED
The diff for this file is too large to render. See raw diff
 
models-onnx/small-positional_embedding.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c10bc44f2bd94bdf1b7aa03581309fa536132b3fe79bfe22c9a6934a42cd8b58
3
+ size 1376256
models-onnx/small-tokens.txt ADDED
The diff for this file is too large to render. See raw diff
 
models-onnx/tiny-decoder-loop.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5cbb3533939e2dfdf567b27762b12cf0956b7d7982bfb915228d24789f483058
3
+ size 112843354
models-onnx/tiny-decoder-main.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:59ced1cf4e9a6f2aef0a2457f64f846e5682033abb4b894ba7680a60c792ad73
3
+ size 118301861
models-onnx/tiny-encoder.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a8030a6d1f3615b8a5e000995fee88357768c7dbaad05a79f853a4040c97087b
3
+ size 37606186
models-onnx/tiny-positional_embedding.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c13450ae630323a0bdd39b1226f92a7ac251131a909c7efdb7d2f5516736eb83
3
+ size 688128
models-onnx/tiny-tokens.txt ADDED
The diff for this file is too large to render. See raw diff
 
python/languages.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ WHISPER_LANGUAGES = {
2
+ "en": "english",
3
+ "zh": "chinese",
4
+ "de": "german",
5
+ "es": "spanish",
6
+ "ru": "russian",
7
+ "ko": "korean",
8
+ "fr": "french",
9
+ "ja": "japanese",
10
+ "pt": "portuguese",
11
+ "tr": "turkish",
12
+ "pl": "polish",
13
+ "ca": "catalan",
14
+ "nl": "dutch",
15
+ "ar": "arabic",
16
+ "sv": "swedish",
17
+ "it": "italian",
18
+ "id": "indonesian",
19
+ "hi": "hindi",
20
+ "fi": "finnish",
21
+ "vi": "vietnamese",
22
+ "he": "hebrew",
23
+ "uk": "ukrainian",
24
+ "el": "greek",
25
+ "ms": "malay",
26
+ "cs": "czech",
27
+ "ro": "romanian",
28
+ "da": "danish",
29
+ "hu": "hungarian",
30
+ "ta": "tamil",
31
+ "no": "norwegian",
32
+ "th": "thai",
33
+ "ur": "urdu",
34
+ "hr": "croatian",
35
+ "bg": "bulgarian",
36
+ "lt": "lithuanian",
37
+ "la": "latin",
38
+ "mi": "maori",
39
+ "ml": "malayalam",
40
+ "cy": "welsh",
41
+ "sk": "slovak",
42
+ "te": "telugu",
43
+ "fa": "persian",
44
+ "lv": "latvian",
45
+ "bn": "bengali",
46
+ "sr": "serbian",
47
+ "az": "azerbaijani",
48
+ "sl": "slovenian",
49
+ "kn": "kannada",
50
+ "et": "estonian",
51
+ "mk": "macedonian",
52
+ "br": "breton",
53
+ "eu": "basque",
54
+ "is": "icelandic",
55
+ "hy": "armenian",
56
+ "ne": "nepali",
57
+ "mn": "mongolian",
58
+ "bs": "bosnian",
59
+ "kk": "kazakh",
60
+ "sq": "albanian",
61
+ "sw": "swahili",
62
+ "gl": "galician",
63
+ "mr": "marathi",
64
+ "pa": "punjabi",
65
+ "si": "sinhala",
66
+ "km": "khmer",
67
+ "sn": "shona",
68
+ "yo": "yoruba",
69
+ "so": "somali",
70
+ "af": "afrikaans",
71
+ "oc": "occitan",
72
+ "ka": "georgian",
73
+ "be": "belarusian",
74
+ "tg": "tajik",
75
+ "sd": "sindhi",
76
+ "gu": "gujarati",
77
+ "am": "amharic",
78
+ "yi": "yiddish",
79
+ "lo": "lao",
80
+ "uz": "uzbek",
81
+ "fo": "faroese",
82
+ "ht": "haitian creole",
83
+ "ps": "pashto",
84
+ "tk": "turkmen",
85
+ "nn": "nynorsk",
86
+ "mt": "maltese",
87
+ "sa": "sanskrit",
88
+ "lb": "luxembourgish",
89
+ "my": "myanmar",
90
+ "bo": "tibetan",
91
+ "tl": "tagalog",
92
+ "mg": "malagasy",
93
+ "as": "assamese",
94
+ "tt": "tatar",
95
+ "haw": "hawaiian",
96
+ "ln": "lingala",
97
+ "ha": "hausa",
98
+ "ba": "bashkir",
99
+ "jw": "javanese",
100
+ "su": "sundanese",
101
+ "yue": "cantonese",
102
+ }
python/requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ numpy==1.26.4
2
+ soundfile
3
+ librosa
4
+ zhconv
python/whisper.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import axengine as axe
3
+ import numpy as np
4
+ import librosa
5
+ import os
6
+ from typing import Tuple
7
+ import soundfile as sf
8
+ import base64
9
+ import zhconv
10
+ import time
11
+ from languages import WHISPER_LANGUAGES
12
+
13
+
14
+ WHISPER_N_MELS = 80
15
+ WHISPER_SAMPLE_RATE = 16000
16
+ WHISPER_N_FFT = 480
17
+ WHISPER_HOP_LENGTH = 160
18
+
19
+ WHISPER_SOT = 50258
20
+ WHISPER_EOT = 50257
21
+ WHISPER_BLANK = 220
22
+ WHISPER_NO_TIMESTAMPS = 50363
23
+ WHISPER_NO_SPEECH = 50362
24
+ WHISPER_TRANSLATE = 50358
25
+ WHISPER_TRANSCRIBE = 50359
26
+ WHISPER_VOCAB_SIZE = 51865
27
+ WHISPER_N_TEXT_CTX = 448
28
+
29
+ NEG_INF = float("-inf")
30
+ SOT_SEQUENCE = np.array([WHISPER_SOT,WHISPER_SOT + 1 + tuple(WHISPER_LANGUAGES).index("zh"),WHISPER_TRANSCRIBE,WHISPER_NO_TIMESTAMPS], dtype=np.int32)
31
+ WHISPER_N_TEXT_STATE_MAP = {
32
+ "tiny": 384,
33
+ "base": 512,
34
+ "small": 768
35
+ }
36
+
37
+
38
+ def get_args():
39
+ parser = argparse.ArgumentParser(
40
+ prog="whisper",
41
+ description="Run Whisper on input audio file"
42
+ )
43
+ parser.add_argument("--wav", "-w", type=str, required=True, help="Input audio file")
44
+ parser.add_argument("--model_type", "-t", type=str, choices=["tiny", "base", "small"], required=True, help="model type, only support tiny, base and small currently")
45
+ parser.add_argument("--model_path", "-p", type=str, required=False, default="../models", help="model path for *.axmodel, tokens.txt, positional_embedding.bin")
46
+ parser.add_argument("--language", "-l", type=str, required=False, default="zh", help="Target language, support en, zh, ja, and others. See languages.py for more options.")
47
+ return parser.parse_args()
48
+
49
+
50
+ def print_args(args):
51
+ print(f"wav: {args.wav}")
52
+ print(f"model_type: {args.model_type}")
53
+ print(f"model_path: {args.model_path}")
54
+ print(f"language: {args.language}")
55
+
56
+
57
+ def load_audio(filename: str) -> Tuple[np.ndarray, int]:
58
+ data, sample_rate = sf.read(
59
+ filename,
60
+ always_2d=True,
61
+ dtype="float32",
62
+ )
63
+ data = data[:, 0] # use only the first channel
64
+ data = librosa.resample(data, orig_sr=sample_rate, target_sr=WHISPER_SAMPLE_RATE)
65
+ samples = np.ascontiguousarray(data)
66
+ return samples, sample_rate
67
+
68
+
69
+ def load_models(model_path, model_type):
70
+ encoder_path = f"{model_type}-encoder.axmodel"
71
+ decoder_main_path = f"{model_type}-decoder-main.axmodel"
72
+ decoder_loop_path = f"{model_type}-decoder-loop.axmodel"
73
+ pe_path = f"{model_type}-positional_embedding.bin"
74
+ token_path = f"{model_type}-tokens.txt"
75
+
76
+ required_files = [os.path.join(model_path, i) for i in (encoder_path, decoder_main_path, decoder_loop_path, pe_path, token_path)]
77
+ # Check file existence
78
+ for i, file_path in enumerate(required_files):
79
+ assert os.path.exists(file_path), f"{file_path} NOT exist"
80
+
81
+ # Load encoder
82
+ encoder = axe.InferenceSession(required_files[0])
83
+ # Load decoder main
84
+ decoder_main = axe.InferenceSession(required_files[1])
85
+ # Load decoder loop
86
+ decoder_loop = axe.InferenceSession(required_files[2])
87
+ # Load position embedding
88
+ pe = np.fromfile(required_files[3], dtype=np.float32)
89
+ # Load tokens
90
+ tokens = []
91
+ with open(required_files[4], "r") as f:
92
+ for line in f:
93
+ line = line.strip()
94
+ tokens.append(line.split(" ")[0])
95
+
96
+ return encoder, decoder_main, decoder_loop, pe, tokens
97
+
98
+
99
+ def compute_feature(wav_path, n_mels = WHISPER_N_MELS, padding = 480000):
100
+ audio, sr = load_audio(wav_path)
101
+
102
+ audio = np.concatenate((audio, np.zeros((padding,), dtype=np.float32)), axis=-1)
103
+
104
+ mel = librosa.feature.melspectrogram(y=audio, sr=sr, n_fft=WHISPER_N_FFT, hop_length=WHISPER_HOP_LENGTH, window="hann", center=True, pad_mode="reflect", power=2.0, n_mels=n_mels)
105
+ log_spec = np.log10(np.maximum(mel, 1e-10))
106
+ log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
107
+ mel = (log_spec + 4.0) / 4.0
108
+
109
+ # We pad 1500 frames at the end so that it is able to detect eot
110
+ # You can use another value instead of 1500.
111
+ # mel = np.concatenate((mel, np.zeros((n_mels, 1500), dtype=np.float32)), axis=-1)
112
+
113
+ target = 3000
114
+ if mel.shape[1] > target:
115
+ # -50 so that there are some zero tail paddings.
116
+ mel = mel[:, : target]
117
+ mel[:, -50:] = 0
118
+
119
+ # We don't need to pad it to 30 seconds now!
120
+ if mel.shape[1] < target:
121
+ mel = np.concatenate((mel, np.zeros((n_mels, target - mel.shape[1]), dtype=np.float32)), axis=-1)
122
+
123
+ return mel
124
+
125
+
126
+ def supress_tokens(logits, is_initial):
127
+ if is_initial:
128
+ logits[WHISPER_EOT] = NEG_INF
129
+ logits[WHISPER_BLANK] = NEG_INF
130
+
131
+ logits[WHISPER_NO_TIMESTAMPS] = NEG_INF
132
+ logits[WHISPER_SOT] = NEG_INF
133
+ logits[WHISPER_NO_SPEECH] = NEG_INF
134
+ logits[WHISPER_TRANSLATE] = NEG_INF
135
+ return logits
136
+
137
+
138
+ def choose_language(lang):
139
+ if lang not in WHISPER_LANGUAGES.keys():
140
+ raise Exception(f"Unknown language: {lang}. Check languages.py for correct options.")
141
+ SOT_SEQUENCE[1] = WHISPER_SOT + 1 + tuple(WHISPER_LANGUAGES.keys()).index(lang)
142
+
143
+
144
+ def main():
145
+ args = get_args()
146
+ print_args(args)
147
+
148
+ # Check wav existence
149
+ wav_path = args.wav
150
+ assert os.path.exists(wav_path), f"{wav_path} NOT exist"
151
+
152
+ # Choose language
153
+ choose_language(args.language)
154
+
155
+ # Load models and other stuff
156
+ start = time.time()
157
+ encoder, decoder_main, decoder_loop, pe, token_table = load_models(args.model_path, args.model_type)
158
+ print(f"Load models take {(time.time() - start) * 1000}ms")
159
+ WHISPER_N_TEXT_STATE = WHISPER_N_TEXT_STATE_MAP[args.model_type]
160
+
161
+ # Preprocess
162
+ start = time.time()
163
+ mel = compute_feature(wav_path, n_mels=WHISPER_N_MELS)
164
+ print(f"Preprocess wav take {(time.time() - start) * 1000}ms")
165
+ # mel.tofile("mel.bin")
166
+
167
+ # Run encoder
168
+ start = time.time()
169
+ x = encoder.run(None, input_feed={"mel": mel[None, ...]})
170
+ n_layer_cross_k, n_layer_cross_v = x
171
+ print(f"Run encoder take {(time.time() - start) * 1000}ms")
172
+
173
+ # n_layer_cross_k.tofile("n_layer_cross_k.bin")
174
+ # n_layer_cross_v.tofile("n_layer_cross_v.bin")
175
+
176
+ # Run decoder_main
177
+ start = time.time()
178
+ x = decoder_main.run(None, input_feed={
179
+ "tokens": SOT_SEQUENCE[None, ...],
180
+ "n_layer_cross_k": n_layer_cross_k,
181
+ "n_layer_cross_v": n_layer_cross_v
182
+ })
183
+ logits, n_layer_self_k_cache, n_layer_self_v_cache = x
184
+ print(f"Run decoder_main take {(time.time() - start) * 1000}ms")
185
+
186
+ # Decode token
187
+ logits = logits[0, -1, :]
188
+ logits = supress_tokens(logits, is_initial=True)
189
+ # logits.tofile("logits.bin")
190
+ max_token_id = np.argmax(logits)
191
+ output_tokens = []
192
+ print(f"First token: {max_token_id}")
193
+
194
+ # Position embedding offset
195
+ offset = SOT_SEQUENCE.shape[0]
196
+
197
+ # Autoregressively run decoder until token meets EOT
198
+ for i in range(WHISPER_N_TEXT_CTX - SOT_SEQUENCE.shape[0]):
199
+ if max_token_id == WHISPER_EOT:
200
+ break
201
+
202
+ output_tokens.append(max_token_id)
203
+
204
+ mask = np.zeros((WHISPER_N_TEXT_CTX,), dtype=np.float32)
205
+ mask[: WHISPER_N_TEXT_CTX - offset - 1] = NEG_INF
206
+
207
+ # Run decoder_loop
208
+ start = time.time()
209
+ x = decoder_loop.run(None, input_feed={
210
+ "tokens": np.array([[output_tokens[-1]]], dtype=np.int32),
211
+ "in_n_layer_self_k_cache": n_layer_self_k_cache,
212
+ "in_n_layer_self_v_cache": n_layer_self_v_cache,
213
+ "n_layer_cross_k": n_layer_cross_k,
214
+ "n_layer_cross_v": n_layer_cross_v,
215
+ "positional_embedding": pe[offset * WHISPER_N_TEXT_STATE : (offset + 1) * WHISPER_N_TEXT_STATE][None, ...],
216
+ "mask": mask
217
+ })
218
+ logits, n_layer_self_k_cache, n_layer_self_v_cache = x
219
+ print(f"Run decoder_loop take {(time.time() - start) * 1000}ms")
220
+
221
+ # Decode token
222
+ offset += 1
223
+ logits = supress_tokens(logits.flatten(), is_initial=False)
224
+ max_token_id = np.argmax(logits)
225
+
226
+ print(f"Iter {i} \t Token: {max_token_id}")
227
+
228
+ s = b""
229
+ for i in output_tokens:
230
+ s += base64.b64decode(token_table[i])
231
+ # print(s.decode().strip())
232
+ pd = s.decode().strip()
233
+ if args.language == "zh":
234
+ pd = zhconv.convert(pd, 'zh-hans')
235
+
236
+ print(f"Result: {pd}")
237
+
238
+
239
+ if __name__ == "__main__":
240
+ main()
python/whisper_onnx.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import onnxruntime as ort
3
+ import numpy as np
4
+ import librosa
5
+ import os
6
+ from typing import Tuple
7
+ import soundfile as sf
8
+ import base64
9
+ import zhconv
10
+ import time
11
+ import torch
12
+ from torch.nn import functional as F
13
+ from languages import WHISPER_LANGUAGES
14
+
15
+
16
+ WHISPER_N_MELS = 80
17
+ WHISPER_SAMPLE_RATE = 16000
18
+ WHISPER_N_FFT = 480
19
+ WHISPER_HOP_LENGTH = 160
20
+
21
+ WHISPER_SOT = 50258
22
+ WHISPER_EOT = 50257
23
+ WHISPER_BLANK = 220
24
+ WHISPER_NO_TIMESTAMPS = 50363
25
+ WHISPER_NO_SPEECH = 50362
26
+ WHISPER_TRANSLATE = 50358
27
+ WHISPER_TRANSCRIBE = 50359
28
+ WHISPER_VOCAB_SIZE = 51865
29
+ WHISPER_N_TEXT_CTX = 448
30
+
31
+ NEG_INF = float("-inf")
32
+ SOT_SEQUENCE = np.array([WHISPER_SOT,WHISPER_SOT + 1 + tuple(WHISPER_LANGUAGES).index("zh"),WHISPER_TRANSCRIBE,WHISPER_NO_TIMESTAMPS], dtype=np.int64)
33
+ WHISPER_N_TEXT_STATE_MAP = {
34
+ "tiny": 384,
35
+ "base": 512,
36
+ "small": 768
37
+ }
38
+
39
+
40
+ def get_args():
41
+ parser = argparse.ArgumentParser(
42
+ prog="whisper",
43
+ description="Run Whisper on input audio file"
44
+ )
45
+ parser.add_argument("--wav", "-w", type=str, required=True, help="Input audio file")
46
+ parser.add_argument("--model_type", "-t", type=str, choices=["tiny", "base", "small"], required=True, help="model type, only support tiny/base/small currently")
47
+ parser.add_argument("--model_path", "-p", type=str, required=False, default="../models", help="model path for *.axmodel, tokens.txt, positional_embedding.bin")
48
+ parser.add_argument("--language", "-l", type=str, required=False, default="zh", help="Target language, support en, zh, ja, and others. See languages.py for more options.")
49
+ return parser.parse_args()
50
+
51
+
52
+ def print_args(args):
53
+ print(f"wav: {args.wav}")
54
+ print(f"model_type: {args.model_type}")
55
+ print(f"model_path: {args.model_path}")
56
+ print(f"language: {args.language}")
57
+
58
+
59
+ def load_audio(filename: str) -> Tuple[np.ndarray, int]:
60
+ data, sample_rate = sf.read(
61
+ filename,
62
+ always_2d=True,
63
+ dtype="float32",
64
+ )
65
+ data = data[:, 0] # use only the first channel
66
+ data = librosa.resample(data, orig_sr=sample_rate, target_sr=WHISPER_SAMPLE_RATE)
67
+ samples = np.ascontiguousarray(data)
68
+ return samples, sample_rate
69
+
70
+
71
+ def load_models(model_path, model_type):
72
+ encoder_path = f"{model_type}-encoder.onnx"
73
+ decoder_main_path = f"{model_type}-decoder-main.onnx"
74
+ decoder_loop_path = f"{model_type}-decoder-loop.onnx"
75
+ pe_path = f"{model_type}-positional_embedding.bin"
76
+ token_path = f"{model_type}-tokens.txt"
77
+
78
+ required_files = [os.path.join(model_path, i) for i in (encoder_path, decoder_main_path, decoder_loop_path, pe_path, token_path)]
79
+ # Check file existence
80
+ for i, file_path in enumerate(required_files):
81
+ assert os.path.exists(file_path), f"{file_path} NOT exist"
82
+
83
+ # Load encoder
84
+ encoder = ort.InferenceSession(required_files[0], providers=['CPUExecutionProvider'])
85
+ # Load decoder main
86
+ decoder_main = ort.InferenceSession(required_files[1], providers=['CPUExecutionProvider'])
87
+ # Load decoder loop
88
+ decoder_loop = ort.InferenceSession(required_files[2], providers=['CPUExecutionProvider'])
89
+ # Load position embedding
90
+ pe = np.fromfile(required_files[3], dtype=np.float32)
91
+ # Load tokens
92
+ tokens = []
93
+ with open(required_files[4], "r") as f:
94
+ for line in f:
95
+ line = line.strip()
96
+ tokens.append(line.split(" ")[0])
97
+
98
+ return encoder, decoder_main, decoder_loop, pe, tokens
99
+
100
+
101
+ def compute_feature(wav_path, n_mels = WHISPER_N_MELS, padding = 480000):
102
+ audio, sr = load_audio(wav_path)
103
+
104
+ audio = np.concatenate((audio, np.zeros((padding,), dtype=np.float32)), axis=-1)
105
+
106
+ mel = librosa.feature.melspectrogram(y=audio, sr=sr, n_fft=WHISPER_N_FFT, hop_length=WHISPER_HOP_LENGTH, window="hann", center=True, pad_mode="reflect", power=2.0, n_mels=n_mels)
107
+ log_spec = np.log10(np.maximum(mel, 1e-10))
108
+ log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
109
+ mel = (log_spec + 4.0) / 4.0
110
+
111
+ # We pad 1500 frames at the end so that it is able to detect eot
112
+ # You can use another value instead of 1500.
113
+ # mel = np.concatenate((mel, np.zeros((n_mels, 1500), dtype=np.float32)), axis=-1)
114
+
115
+ target = 3000
116
+ if mel.shape[1] > target:
117
+ # -50 so that there are some zero tail paddings.
118
+ mel = mel[:, : target]
119
+ mel[:, -50:] = 0
120
+
121
+ # We don't need to pad it to 30 seconds now!
122
+ if mel.shape[1] < target:
123
+ mel = np.concatenate((mel, np.zeros((n_mels, target - mel.shape[1]), dtype=np.float32)), axis=-1)
124
+
125
+ return mel
126
+
127
+
128
+ def supress_tokens(logits, is_initial):
129
+ if is_initial:
130
+ logits[WHISPER_EOT] = NEG_INF
131
+ logits[WHISPER_BLANK] = NEG_INF
132
+
133
+ logits[WHISPER_NO_TIMESTAMPS] = NEG_INF
134
+ logits[WHISPER_SOT] = NEG_INF
135
+ logits[WHISPER_NO_SPEECH] = NEG_INF
136
+ logits[WHISPER_TRANSLATE] = NEG_INF
137
+ return logits
138
+
139
+
140
+ def choose_language(lang):
141
+ if lang not in WHISPER_LANGUAGES.keys():
142
+ raise Exception(f"Unknown language: {lang}. Check languages.py for correct options.")
143
+ SOT_SEQUENCE[1] = WHISPER_SOT + 1 + tuple(WHISPER_LANGUAGES.keys()).index(lang)
144
+
145
+
146
+ def main():
147
+ args = get_args()
148
+ print_args(args)
149
+
150
+ # Check wav existence
151
+ wav_path = args.wav
152
+ assert os.path.exists(wav_path), f"{wav_path} NOT exist"
153
+
154
+ # Choose language
155
+ choose_language(args.language)
156
+
157
+ # Load models and other stuff
158
+ encoder, decoder_main, decoder_loop, pe, token_table = load_models(args.model_path, args.model_type)
159
+ WHISPER_N_TEXT_STATE = WHISPER_N_TEXT_STATE_MAP[args.model_type]
160
+
161
+ # Preprocess
162
+ mel = compute_feature(wav_path, n_mels=WHISPER_N_MELS)
163
+ # mel.tofile("mel.bin")
164
+ # mel = np.load("../mel.npy")[..., :3000]
165
+
166
+ # Run encoder
167
+ start = time.time()
168
+ x = encoder.run(None, input_feed={"mel": mel[None, ...]})
169
+ n_layer_cross_k, n_layer_cross_v = x
170
+ print(f"Run encoder take {(time.time() - start) * 1000}ms")
171
+
172
+ # n_layer_cross_k.tofile("n_layer_cross_k.bin")
173
+ # n_layer_cross_v.tofile("n_layer_cross_v.bin")
174
+
175
+ # Run decoder_main
176
+ start = time.time()
177
+ x = decoder_main.run(None, input_feed={
178
+ "tokens": SOT_SEQUENCE[None, ...],
179
+ "n_layer_cross_k": n_layer_cross_k,
180
+ "n_layer_cross_v": n_layer_cross_v
181
+ })
182
+ logits, n_layer_self_k_cache, n_layer_self_v_cache = x
183
+ print(f"Run decoder_main take {(time.time() - start) * 1000}ms")
184
+
185
+ # Decode token
186
+ logits = logits[0, -1, :]
187
+ logits = supress_tokens(logits, is_initial=True)
188
+ # logits.tofile("logits.bin")
189
+ max_token_id = np.argmax(logits)
190
+ output_tokens = []
191
+ print(f"First token: {max_token_id}")
192
+
193
+ # Position embedding offset
194
+ offset = SOT_SEQUENCE.shape[0]
195
+
196
+ # Autoregressively run decoder until token meets EOT
197
+ for i in range(WHISPER_N_TEXT_CTX - SOT_SEQUENCE.shape[0]):
198
+ if max_token_id == WHISPER_EOT:
199
+ break
200
+
201
+ output_tokens.append(max_token_id)
202
+
203
+ mask = np.zeros((WHISPER_N_TEXT_CTX,), dtype=np.float32)
204
+ mask[: WHISPER_N_TEXT_CTX - offset - 1] = NEG_INF
205
+
206
+ # Run decoder_loop
207
+ start = time.time()
208
+ x = decoder_loop.run(None, input_feed={
209
+ "tokens": np.array([[output_tokens[-1]]], dtype=np.int64),
210
+ "in_n_layer_self_k_cache": n_layer_self_k_cache,
211
+ "in_n_layer_self_v_cache": n_layer_self_v_cache,
212
+ "n_layer_cross_k": n_layer_cross_k,
213
+ "n_layer_cross_v": n_layer_cross_v,
214
+ "positional_embedding": pe[offset * WHISPER_N_TEXT_STATE : (offset + 1) * WHISPER_N_TEXT_STATE][None, ...],
215
+ "mask": mask
216
+ })
217
+ logits, n_layer_self_k_cache, n_layer_self_v_cache = x
218
+ print(f"Run decoder_loop take {(time.time() - start) * 1000}ms")
219
+
220
+ # Decode token
221
+ offset += 1
222
+ logits = supress_tokens(logits.flatten(), is_initial=False)
223
+ max_token_id = np.argmax(logits)
224
+
225
+ print(f"Iter {i} \t Token: {max_token_id}")
226
+
227
+ s = b""
228
+ for i in output_tokens:
229
+ s += base64.b64decode(token_table[i])
230
+ # print(s.decode().strip())
231
+ pd = s.decode().strip()
232
+ if args.language == "zh":
233
+ pd = zhconv.convert(pd, 'zh-hans')
234
+
235
+ print(f"Result: {pd}")
236
+
237
+
238
+ if __name__ == "__main__":
239
+ main()