Banafo committed on
Commit
b8a6bf1
·
verified ·
1 Parent(s): abf4082

Upload decode_file.py

Browse files
Files changed (1) hide show
  1. decode_file.py +200 -0
decode_file.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import time
3
+ import wave
4
+ from pathlib import Path
5
+ from typing import Tuple
6
+
7
+ import numpy as np
8
+ import sherpa_onnx
9
+ from huggingface_hub import hf_hub_download
10
+
11
+
12
def get_args():
    """Build the command-line parser for the decoder script and return the parsed args."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # (flag, keyword-arguments) pairs for every optional argument, in the
    # same order they are registered so --help output is unchanged.
    option_specs = [
        ("--lang",
         dict(type=str, required=True,
              help="Language code (e.g., 'en', 'fr', 'de')")),
        ("--hf-token",
         dict(type=str, required=True,
              help="Hugging Face access token for private model repository")),
        ("--num-threads",
         dict(type=int, default=1,
              help="Number of threads for neural network computation")),
        ("--decoding-method",
         dict(type=str, default="greedy_search",
              help="Valid values: greedy_search and modified_beam_search")),
        ("--max-active-paths",
         dict(type=int, default=4,
              help="Used only when --decoding-method is modified_beam_search.")),
        ("--lm",
         dict(type=str, default="",
              help="Used only when --decoding-method is modified_beam_search. Path of language model.")),
        ("--lm-scale",
         dict(type=float, default=0.1,
              help="Used only when --decoding-method is modified_beam_search. Scale of language model.")),
        ("--provider",
         dict(type=str, default="cpu",
              help="Valid values: cpu, cuda, coreml")),
        ("--hotwords-file",
         dict(type=str, default="",
              help="The file containing hotwords, one word/phrase per line.")),
        ("--hotwords-score",
         dict(type=float, default=1.5,
              help="Hotword score for biasing word/phrase. Used only if --hotwords-file is given.")),
    ]
    for flag, kwargs in option_specs:
        parser.add_argument(flag, **kwargs)

    # Positional: one or more input recordings.
    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        help="The input sound file(s) to decode. Must be WAVE format, single channel, 16-bit.",
    )

    return parser.parse_args()
95
+
96
+
97
def assert_file_exists(filename: str):
    """Abort with an AssertionError if *filename* is not an existing regular file.

    NOTE(review): ``assert`` is stripped under ``python -O``; kept here to
    preserve the exception type callers may rely on.
    """
    # Bug fix: the original f-string had no placeholder, so every failure
    # reported the literal text "(unknown)" instead of the offending path.
    assert Path(filename).is_file(), f"{filename} does not exist!"
99
+
100
+
101
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
    """Read a WAVE file and return its audio and sample rate.

    Args:
      wave_filename: Path to a single-channel, 16-bit PCM WAVE file.

    Returns:
      A tuple ``(samples, sample_rate)`` where ``samples`` is a 1-D
      float32 array scaled to the range [-1, 1).
    """
    with wave.open(wave_filename) as wav:
        # Only mono, 16-bit input is supported.
        assert wav.getnchannels() == 1, wav.getnchannels()
        assert wav.getsampwidth() == 2, wav.getsampwidth()
        raw = wav.readframes(wav.getnframes())
        # Convert int16 PCM to normalised float32.
        audio = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768
        return audio, wav.getframerate()
110
+
111
+
112
def download_models(language_code, hf_token):
    """Downloads encoder, decoder, joiner, and tokens.txt from Hugging Face.

    Args:
      language_code: Prefix of the model files in the repo (e.g. "en").
      hf_token: Hugging Face access token for the private repository.

    Returns:
      Dict mapping "encoder"/"decoder"/"joiner"/"tokens" to the local
      cached file paths returned by ``hf_hub_download``.
    """
    repo_id = "Banafo/test-onnx"

    # Per-language file naming convention used by the model repo.
    model_filenames = {
        "encoder": f"{language_code}_encoder.onnx",
        "decoder": f"{language_code}_decoder.onnx",
        "joiner": f"{language_code}_joiner.onnx",
        "tokens": f"{language_code}_tokens.txt",
    }

    model_paths = {}
    for model_name, filename in model_filenames.items():
        # Bug fix: both progress messages were f-strings without placeholders,
        # so they printed the literal text "(unknown)" — report the real names.
        print(f"Downloading {filename}...")
        model_paths[model_name] = hf_hub_download(
            repo_id=repo_id, filename=filename, token=hf_token
        )
        print(f"Loaded {model_name}")

    return model_paths
130
+
131
+
132
def main():
    """Entry point: fetch the models, decode every input WAVE file, report timings."""
    args = get_args()

    # Pull the ONNX transducer components and token table from the Hub.
    model_paths = download_models(args.lang, args.hf_token)

    # Build the streaming transducer recognizer from the downloaded files.
    recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
        tokens=model_paths["tokens"],
        encoder=model_paths["encoder"],
        decoder=model_paths["decoder"],
        joiner=model_paths["joiner"],
        num_threads=args.num_threads,
        provider=args.provider,
        sample_rate=16000,
        feature_dim=80,
        decoding_method=args.decoding_method,
        max_active_paths=args.max_active_paths,
        lm=args.lm,
        lm_scale=args.lm_scale,
        hotwords_file=args.hotwords_file,
        hotwords_score=args.hotwords_score,
    )

    print("Started!")
    start_time = time.time()

    # One stream per input file; feed each the full waveform up front.
    streams = []
    total_duration = 0
    for sound_file in args.sound_files:
        assert_file_exists(sound_file)
        audio, rate = read_wave(sound_file)
        total_duration += len(audio) / rate

        stream = recognizer.create_stream()
        stream.accept_waveform(rate, audio)

        # Trailing silence flushes the final frames through the model.
        padding = np.zeros(int(0.66 * rate), dtype=np.float32)
        stream.accept_waveform(rate, padding)
        stream.input_finished()

        streams.append(stream)

    # Batch-decode all streams in lockstep until none has frames left.
    while True:
        pending = [s for s in streams if recognizer.is_ready(s)]
        if not pending:
            break
        recognizer.decode_streams(pending)

    results = [recognizer.get_result(s) for s in streams]
    end_time = time.time()
    print("Done!")

    for sound_file, transcript in zip(args.sound_files, results):
        print(f"{sound_file}\n{transcript}")
        print("-" * 10)

    # Summary statistics, including the real-time factor.
    elapsed_seconds = end_time - start_time
    rtf = elapsed_seconds / total_duration
    print(f"num_threads: {args.num_threads}")
    print(f"decoding_method: {args.decoding_method}")
    print(f"Wave duration: {total_duration:.3f} s")
    print(f"Elapsed time: {elapsed_seconds:.3f} s")
    print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}")
197
+
198
+
199
# Run the decoder only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()