|
|
import argparse |
|
|
import time |
|
|
import wave |
|
|
from pathlib import Path |
|
|
from typing import Tuple |
|
|
|
|
|
import numpy as np |
|
|
import sherpa_onnx |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
|
|
|
def get_args():
    """Parse the command-line arguments of the decoding script.

    Returns:
        argparse.Namespace: parsed options with the attributes ``lang``,
        ``hf_token``, ``num_threads``, ``decoding_method``,
        ``max_active_paths``, ``lm``, ``lm_scale``, ``provider``,
        ``hotwords_file``, ``hotwords_score`` and ``sound_files`` (a
        non-empty list of paths).

    Raises:
        SystemExit: raised by argparse when a required option is missing or
            a value is invalid (including an unknown ``--decoding-method``).
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--lang",
        type=str,
        required=True,
        help="Language code (e.g., 'en', 'fr', 'de')",
    )

    parser.add_argument(
        "--hf-token",
        type=str,
        required=True,
        help="Hugging Face access token for private model repository",
    )

    parser.add_argument(
        "--num-threads",
        type=int,
        default=1,
        help="Number of threads for neural network computation",
    )

    # Restrict to the two documented values so that a typo is reported by
    # argparse immediately instead of surfacing later, after the models have
    # already been downloaded.
    parser.add_argument(
        "--decoding-method",
        type=str,
        default="greedy_search",
        choices=["greedy_search", "modified_beam_search"],
        help="Valid values: greedy_search and modified_beam_search",
    )

    parser.add_argument(
        "--max-active-paths",
        type=int,
        default=4,
        help="Used only when --decoding-method is modified_beam_search.",
    )

    parser.add_argument(
        "--lm",
        type=str,
        default="",
        help="Used only when --decoding-method is modified_beam_search. Path of language model.",
    )

    parser.add_argument(
        "--lm-scale",
        type=float,
        default=0.1,
        help="Used only when --decoding-method is modified_beam_search. Scale of language model.",
    )

    parser.add_argument(
        "--provider",
        type=str,
        default="cpu",
        help="Valid values: cpu, cuda, coreml",
    )

    parser.add_argument(
        "--hotwords-file",
        type=str,
        default="",
        help="The file containing hotwords, one word/phrase per line.",
    )

    parser.add_argument(
        "--hotwords-score",
        type=float,
        default=1.5,
        help="Hotword score for biasing word/phrase. Used only if --hotwords-file is given.",
    )

    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        help="The input sound file(s) to decode. Must be WAVE format, single channel, 16-bit.",
    )

    return parser.parse_args()
|
|
|
|
|
|
|
|
def assert_file_exists(filename: str):
    """Fail unless ``filename`` refers to an existing regular file.

    Args:
        filename: Path to check.

    Raises:
        AssertionError: if the path does not exist or is not a regular file
            (for example, a directory).
    """
    # Raise explicitly rather than using an ``assert`` statement: asserts are
    # removed when Python runs with -O, which would silently disable this
    # check. The exception type and message are unchanged.
    if not Path(filename).is_file():
        raise AssertionError(f"{filename} does not exist!")
|
|
|
|
|
|
|
|
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
    """Load a single-channel, 16-bit WAVE file as normalized float32 samples.

    Args:
        wave_filename: Path of the WAVE file to read.

    Returns:
        A tuple ``(samples, sample_rate)``. ``samples`` is a 1-D float32
        array holding the int16 samples divided by 32768 (so values lie in
        ``[-1, 1)``); ``sample_rate`` is the frame rate stored in the file
        header.

    Raises:
        AssertionError: if the file is not single channel or its sample
            width is not 2 bytes.
    """
    with wave.open(wave_filename, "rb") as f:
        # Explicit raises instead of ``assert`` statements so the format
        # checks still run under ``python -O``.
        num_channels = f.getnchannels()
        if num_channels != 1:
            raise AssertionError(
                f"{wave_filename}: expected 1 channel, got {num_channels}"
            )

        sample_width = f.getsampwidth()  # in bytes
        if sample_width != 2:
            raise AssertionError(
                f"{wave_filename}: expected 2 bytes per sample, got {sample_width}"
            )

        sample_rate = f.getframerate()
        raw_frames = f.readframes(f.getnframes())

    # WAVE PCM data is little-endian; "<i2" decodes it correctly on any host,
    # whereas a bare np.int16 would use the machine's native byte order.
    samples_int16 = np.frombuffer(raw_frames, dtype="<i2")
    samples_float32 = samples_int16.astype(np.float32) / 32768
    return samples_float32, sample_rate
|
|
|
|
|
|
|
|
def download_models(language_code, hf_token, repo_id="Banafo/test-onnx"):
    """Download encoder, decoder, joiner, and tokens.txt from Hugging Face.

    The files are looked up in ``repo_id`` under the names
    ``<language_code>_encoder.onnx``, ``<language_code>_decoder.onnx``,
    ``<language_code>_joiner.onnx`` and ``<language_code>_tokens.txt``.

    Args:
        language_code: Prefix of the model file names (e.g. ``"en"``).
        hf_token: Access token passed to ``hf_hub_download`` as ``token``.
        repo_id: Hugging Face repository to download from. Defaults to the
            repository this script was written for.

    Returns:
        dict: maps ``"encoder"``, ``"decoder"``, ``"joiner"`` and
        ``"tokens"`` to the value returned by ``hf_hub_download`` for the
        corresponding file.
    """
    model_filenames = {
        "encoder": f"{language_code}_encoder.onnx",
        "decoder": f"{language_code}_decoder.onnx",
        "joiner": f"{language_code}_joiner.onnx",
        "tokens": f"{language_code}_tokens.txt",
    }

    model_paths = {}
    for model_name, filename in model_filenames.items():
        print(f"Downloading {filename}...")
        model_paths[model_name] = hf_hub_download(
            repo_id=repo_id, filename=filename, token=hf_token
        )
        print(f"Loaded {filename}")

    return model_paths
|
|
|
|
|
|
|
|
def main():
    """Decode the WAVE files given on the command line and print the results.

    Steps: parse the arguments, check that every input file exists, download
    the model files, build a streaming transducer recognizer, feed each file
    into its own stream, decode all streams until none is ready, then print
    each transcript followed by timing statistics (including the real time
    factor, i.e. elapsed time divided by total audio duration).

    Raises:
        AssertionError: propagated from ``assert_file_exists`` / ``read_wave``
            when an input file is missing or has an unsupported format.
    """
    args = get_args()

    # Fail fast: verify the inputs before the model download and recognizer
    # construction, so a mistyped path does not cost a full download first.
    for wave_filename in args.sound_files:
        assert_file_exists(wave_filename)

    model_paths = download_models(args.lang, args.hf_token)

    recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
        tokens=model_paths["tokens"],
        encoder=model_paths["encoder"],
        decoder=model_paths["decoder"],
        joiner=model_paths["joiner"],
        num_threads=args.num_threads,
        provider=args.provider,
        sample_rate=16000,
        feature_dim=80,
        decoding_method=args.decoding_method,
        max_active_paths=args.max_active_paths,
        lm=args.lm,
        lm_scale=args.lm_scale,
        hotwords_file=args.hotwords_file,
        hotwords_score=args.hotwords_score,
    )

    print("Started!")
    # perf_counter() is monotonic, so the measured interval cannot be skewed
    # by system clock adjustments the way time.time() can.
    start_time = time.perf_counter()

    streams = []
    total_duration = 0
    for wave_filename in args.sound_files:
        samples, sample_rate = read_wave(wave_filename)
        total_duration += len(samples) / sample_rate

        s = recognizer.create_stream()
        s.accept_waveform(sample_rate, samples)

        # Append 0.66 s of zeros after the audio.
        # NOTE(review): presumably this lets the streaming model emit the
        # final tokens of the utterance -- confirm against sherpa-onnx docs.
        tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32)
        s.accept_waveform(sample_rate, tail_paddings)
        s.input_finished()

        streams.append(s)

    # Keep decoding in batches until no stream reports itself as ready.
    while True:
        ready_list = [s for s in streams if recognizer.is_ready(s)]
        if not ready_list:
            break
        recognizer.decode_streams(ready_list)

    results = [recognizer.get_result(s) for s in streams]
    end_time = time.perf_counter()
    print("Done!")

    for wave_filename, result in zip(args.sound_files, results):
        print(f"{wave_filename}\n{result}")
        print("-" * 10)

    elapsed_seconds = end_time - start_time
    # Zero-length input files would otherwise raise ZeroDivisionError after
    # all the decoding work is done; the RTF is undefined in that case.
    if total_duration > 0:
        rtf = elapsed_seconds / total_duration
    else:
        rtf = float("nan")
    print(f"num_threads: {args.num_threads}")
    print(f"decoding_method: {args.decoding_method}")
    print(f"Wave duration: {total_duration:.3f} s")
    print(f"Elapsed time: {elapsed_seconds:.3f} s")
    print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}")
|
|
|
|
|
|
|
|
# Run the decoder only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
|
|
|