yangrongzhao commited on
Commit ·
78224f0
1
Parent(s): 249d07c
Add provider, Update README, add RTF
Browse files- README.md +13 -15
- ax_common.py +10 -6
- run_ax.py +5 -0
README.md
CHANGED
|
@@ -1,14 +1,3 @@
|
|
| 1 |
-
---
|
| 2 |
-
license: mit
|
| 3 |
-
language:
|
| 4 |
-
- zh
|
| 5 |
-
pipeline_tag: automatic-speech-recognition
|
| 6 |
-
tags:
|
| 7 |
-
- wenet
|
| 8 |
-
- axera
|
| 9 |
-
- speech-recognition
|
| 10 |
-
---
|
| 11 |
-
|
| 12 |
# wenet.axera
|
| 13 |
|
| 14 |
WeNet on Axera.
|
|
@@ -140,23 +129,32 @@ python run_ort.py \
|
|
| 140 |
Offline CTC:
|
| 141 |
|
| 142 |
```bash
|
| 143 |
-
python3 run_ax.py -i demo.wav --mode ctc_prefix_beam_search
|
| 144 |
```
|
| 145 |
|
| 146 |
Online CTC:
|
| 147 |
|
| 148 |
```bash
|
| 149 |
-
python3 run_ax.py -i demo.wav --online --mode ctc_prefix_beam_search
|
| 150 |
```
|
| 151 |
|
| 152 |
Offline attention rescoring:
|
| 153 |
|
| 154 |
```bash
|
| 155 |
-
python3 run_ax.py -i demo.wav --mode attention_rescoring
|
| 156 |
```
|
| 157 |
|
| 158 |
Online attention rescoring:
|
| 159 |
|
| 160 |
```bash
|
| 161 |
-
python3 run_ax.py -i demo.wav --online --mode attention_rescoring
|
| 162 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# wenet.axera
|
| 2 |
|
| 3 |
WeNet on Axera.
|
|
|
|
| 129 |
Offline CTC:
|
| 130 |
|
| 131 |
```bash
|
| 132 |
+
python3 run_ax.py -i demo.wav --mode ctc_prefix_beam_search --provider AxEngineExecutionProvider
|
| 133 |
```
|
| 134 |
|
| 135 |
Online CTC:
|
| 136 |
|
| 137 |
```bash
|
| 138 |
+
python3 run_ax.py -i demo.wav --online --mode ctc_prefix_beam_search --provider AxEngineExecutionProvider
|
| 139 |
```
|
| 140 |
|
| 141 |
Offline attention rescoring:
|
| 142 |
|
| 143 |
```bash
|
| 144 |
+
python3 run_ax.py -i demo.wav --mode attention_rescoring --provider AxEngineExecutionProvider
|
| 145 |
```
|
| 146 |
|
| 147 |
Online attention rescoring:
|
| 148 |
|
| 149 |
```bash
|
| 150 |
+
python3 run_ax.py -i demo.wav --online --mode attention_rescoring --provider AxEngineExecutionProvider
|
| 151 |
```
|
| 152 |
+
|
| 153 |
+
RTF 测试结果,`demo.wav` 时长 4.204s,repeat 5,不含模型加载:
|
| 154 |
+
|
| 155 |
+
| 模式 | 平均耗时 | RTF |
|
| 156 |
+
| --- | ---: | ---: |
|
| 157 |
+
| offline CTC | 0.5202s | 0.1237 |
|
| 158 |
+
| online CTC | 0.5582s | 0.1328 |
|
| 159 |
+
| offline attention rescoring | 0.5266s | 0.1253 |
|
| 160 |
+
| online attention rescoring | 0.5626s | 0.1338 |
|
ax_common.py
CHANGED
|
@@ -468,10 +468,10 @@ def update_online_state(state, outputs):
|
|
| 468 |
|
| 469 |
class AxModel:
|
| 470 |
|
| 471 |
-
def __init__(self, path):
|
| 472 |
from axengine import InferenceSession
|
| 473 |
|
| 474 |
-
self.session = InferenceSession(path)
|
| 475 |
self.output_names = [item.name for item in self.session.get_outputs()]
|
| 476 |
|
| 477 |
def run(self, input_feed):
|
|
@@ -491,7 +491,8 @@ class WenetAXRunner:
|
|
| 491 |
decoder_len=32,
|
| 492 |
decoding_chunk_size=16,
|
| 493 |
num_decoding_left_chunks=5,
|
| 494 |
-
batch_size=1
|
|
|
|
| 495 |
self.config_path = config_path
|
| 496 |
self.vocab_path = vocab_path
|
| 497 |
self.encoder_offline_path = encoder_offline_path
|
|
@@ -502,6 +503,7 @@ class WenetAXRunner:
|
|
| 502 |
self.decoding_chunk_size = decoding_chunk_size
|
| 503 |
self.num_decoding_left_chunks = num_decoding_left_chunks
|
| 504 |
self.batch_size = batch_size
|
|
|
|
| 505 |
|
| 506 |
self.configs = load_config(config_path)
|
| 507 |
self.vocabulary, self.char_dict = load_vocab(vocab_path)
|
|
@@ -514,19 +516,21 @@ class WenetAXRunner:
|
|
| 514 |
@property
|
| 515 |
def offline_encoder(self):
|
| 516 |
if self._offline_encoder is None:
|
| 517 |
-
self._offline_encoder = AxModel(self.encoder_offline_path
|
|
|
|
| 518 |
return self._offline_encoder
|
| 519 |
|
| 520 |
@property
|
| 521 |
def online_encoder(self):
|
| 522 |
if self._online_encoder is None:
|
| 523 |
-
self._online_encoder = AxModel(self.encoder_online_path
|
|
|
|
| 524 |
return self._online_encoder
|
| 525 |
|
| 526 |
@property
|
| 527 |
def decoder(self):
|
| 528 |
if self._decoder is None:
|
| 529 |
-
self._decoder = AxModel(self.decoder_path)
|
| 530 |
return self._decoder
|
| 531 |
|
| 532 |
def compute_feats(self, audio_file):
|
|
|
|
| 468 |
|
| 469 |
class AxModel:
|
| 470 |
|
| 471 |
+
def __init__(self, path, provider="AxEngineExecutionProvider"):
|
| 472 |
from axengine import InferenceSession
|
| 473 |
|
| 474 |
+
self.session = InferenceSession(path, providers=[provider])
|
| 475 |
self.output_names = [item.name for item in self.session.get_outputs()]
|
| 476 |
|
| 477 |
def run(self, input_feed):
|
|
|
|
| 491 |
decoder_len=32,
|
| 492 |
decoding_chunk_size=16,
|
| 493 |
num_decoding_left_chunks=5,
|
| 494 |
+
batch_size=1,
|
| 495 |
+
provider="AxEngineExecutionProvider"):
|
| 496 |
self.config_path = config_path
|
| 497 |
self.vocab_path = vocab_path
|
| 498 |
self.encoder_offline_path = encoder_offline_path
|
|
|
|
| 503 |
self.decoding_chunk_size = decoding_chunk_size
|
| 504 |
self.num_decoding_left_chunks = num_decoding_left_chunks
|
| 505 |
self.batch_size = batch_size
|
| 506 |
+
self.provider = provider
|
| 507 |
|
| 508 |
self.configs = load_config(config_path)
|
| 509 |
self.vocabulary, self.char_dict = load_vocab(vocab_path)
|
|
|
|
| 516 |
@property
|
| 517 |
def offline_encoder(self):
|
| 518 |
if self._offline_encoder is None:
|
| 519 |
+
self._offline_encoder = AxModel(self.encoder_offline_path,
|
| 520 |
+
self.provider)
|
| 521 |
return self._offline_encoder
|
| 522 |
|
| 523 |
@property
|
| 524 |
def online_encoder(self):
|
| 525 |
if self._online_encoder is None:
|
| 526 |
+
self._online_encoder = AxModel(self.encoder_online_path,
|
| 527 |
+
self.provider)
|
| 528 |
return self._online_encoder
|
| 529 |
|
| 530 |
@property
|
| 531 |
def decoder(self):
|
| 532 |
if self._decoder is None:
|
| 533 |
+
self._decoder = AxModel(self.decoder_path, self.provider)
|
| 534 |
return self._decoder
|
| 535 |
|
| 536 |
def compute_feats(self, audio_file):
|
run_ax.py
CHANGED
|
@@ -29,6 +29,9 @@ def get_args():
|
|
| 29 |
parser.add_argument("--decoder_len", type=int, default=32)
|
| 30 |
parser.add_argument("--decoding_chunk_size", type=int, default=16)
|
| 31 |
parser.add_argument("--num_decoding_left_chunks", type=int, default=5)
|
|
|
|
|
|
|
|
|
|
| 32 |
parser.add_argument("--mode",
|
| 33 |
choices=[
|
| 34 |
"ctc_greedy_search", "ctc_prefix_beam_search",
|
|
@@ -43,6 +46,7 @@ def main():
|
|
| 43 |
args = get_args()
|
| 44 |
print(f"online: {args.online}")
|
| 45 |
print(f"mode: {args.mode}")
|
|
|
|
| 46 |
|
| 47 |
runner = WenetAXRunner(
|
| 48 |
args.config,
|
|
@@ -54,6 +58,7 @@ def main():
|
|
| 54 |
decoder_len=args.decoder_len,
|
| 55 |
decoding_chunk_size=args.decoding_chunk_size,
|
| 56 |
num_decoding_left_chunks=args.num_decoding_left_chunks,
|
|
|
|
| 57 |
)
|
| 58 |
result = runner.transcribe(args.input,
|
| 59 |
online=args.online,
|
|
|
|
| 29 |
parser.add_argument("--decoder_len", type=int, default=32)
|
| 30 |
parser.add_argument("--decoding_chunk_size", type=int, default=16)
|
| 31 |
parser.add_argument("--num_decoding_left_chunks", type=int, default=5)
|
| 32 |
+
parser.add_argument("--provider",
|
| 33 |
+
type=str,
|
| 34 |
+
default="AxEngineExecutionProvider")
|
| 35 |
parser.add_argument("--mode",
|
| 36 |
choices=[
|
| 37 |
"ctc_greedy_search", "ctc_prefix_beam_search",
|
|
|
|
| 46 |
args = get_args()
|
| 47 |
print(f"online: {args.online}")
|
| 48 |
print(f"mode: {args.mode}")
|
| 49 |
+
print(f"provider: {args.provider}")
|
| 50 |
|
| 51 |
runner = WenetAXRunner(
|
| 52 |
args.config,
|
|
|
|
| 58 |
decoder_len=args.decoder_len,
|
| 59 |
decoding_chunk_size=args.decoding_chunk_size,
|
| 60 |
num_decoding_left_chunks=args.num_decoding_left_chunks,
|
| 61 |
+
provider=args.provider,
|
| 62 |
)
|
| 63 |
result = runner.transcribe(args.input,
|
| 64 |
online=args.online,
|