yangrongzhao commited on
Commit
78224f0
·
1 Parent(s): 249d07c

Add provider, Update README, add RTF

Browse files
Files changed (3) hide show
  1. README.md +13 -15
  2. ax_common.py +10 -6
  3. run_ax.py +5 -0
README.md CHANGED
@@ -1,14 +1,3 @@
1
- ---
2
- license: mit
3
- language:
4
- - zh
5
- pipeline_tag: automatic-speech-recognition
6
- tags:
7
- - wenet
8
- - axera
9
- - speech-recognition
10
- ---
11
-
12
  # wenet.axera
13
 
14
  WeNet on Axera.
@@ -140,23 +129,32 @@ python run_ort.py \
140
  Offline CTC:
141
 
142
  ```bash
143
- python3 run_ax.py -i demo.wav --mode ctc_prefix_beam_search
144
  ```
145
 
146
  Online CTC:
147
 
148
  ```bash
149
- python3 run_ax.py -i demo.wav --online --mode ctc_prefix_beam_search
150
  ```
151
 
152
  Offline attention rescoring:
153
 
154
  ```bash
155
- python3 run_ax.py -i demo.wav --mode attention_rescoring
156
  ```
157
 
158
  Online attention rescoring:
159
 
160
  ```bash
161
- python3 run_ax.py -i demo.wav --online --mode attention_rescoring
162
  ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # wenet.axera
2
 
3
  WeNet on Axera.
 
129
  Offline CTC:
130
 
131
  ```bash
132
+ python3 run_ax.py -i demo.wav --mode ctc_prefix_beam_search --provider AxEngineExecutionProvider
133
  ```
134
 
135
  Online CTC:
136
 
137
  ```bash
138
+ python3 run_ax.py -i demo.wav --online --mode ctc_prefix_beam_search --provider AxEngineExecutionProvider
139
  ```
140
 
141
  Offline attention rescoring:
142
 
143
  ```bash
144
+ python3 run_ax.py -i demo.wav --mode attention_rescoring --provider AxEngineExecutionProvider
145
  ```
146
 
147
  Online attention rescoring:
148
 
149
  ```bash
150
+ python3 run_ax.py -i demo.wav --online --mode attention_rescoring --provider AxEngineExecutionProvider
151
  ```
152
+
153
+ RTF 测试结果,`demo.wav` 时长 4.204s,repeat 5,不含模型加载:
154
+
155
+ | 模式 | 平均耗时 | RTF |
156
+ | --- | ---: | ---: |
157
+ | offline CTC | 0.5202s | 0.1237 |
158
+ | online CTC | 0.5582s | 0.1328 |
159
+ | offline attention rescoring | 0.5266s | 0.1253 |
160
+ | online attention rescoring | 0.5626s | 0.1338 |
ax_common.py CHANGED
@@ -468,10 +468,10 @@ def update_online_state(state, outputs):
468
 
469
  class AxModel:
470
 
471
- def __init__(self, path):
472
  from axengine import InferenceSession
473
 
474
- self.session = InferenceSession(path)
475
  self.output_names = [item.name for item in self.session.get_outputs()]
476
 
477
  def run(self, input_feed):
@@ -491,7 +491,8 @@ class WenetAXRunner:
491
  decoder_len=32,
492
  decoding_chunk_size=16,
493
  num_decoding_left_chunks=5,
494
- batch_size=1):
 
495
  self.config_path = config_path
496
  self.vocab_path = vocab_path
497
  self.encoder_offline_path = encoder_offline_path
@@ -502,6 +503,7 @@ class WenetAXRunner:
502
  self.decoding_chunk_size = decoding_chunk_size
503
  self.num_decoding_left_chunks = num_decoding_left_chunks
504
  self.batch_size = batch_size
 
505
 
506
  self.configs = load_config(config_path)
507
  self.vocabulary, self.char_dict = load_vocab(vocab_path)
@@ -514,19 +516,21 @@ class WenetAXRunner:
514
  @property
515
  def offline_encoder(self):
516
  if self._offline_encoder is None:
517
- self._offline_encoder = AxModel(self.encoder_offline_path)
 
518
  return self._offline_encoder
519
 
520
  @property
521
  def online_encoder(self):
522
  if self._online_encoder is None:
523
- self._online_encoder = AxModel(self.encoder_online_path)
 
524
  return self._online_encoder
525
 
526
  @property
527
  def decoder(self):
528
  if self._decoder is None:
529
- self._decoder = AxModel(self.decoder_path)
530
  return self._decoder
531
 
532
  def compute_feats(self, audio_file):
 
468
 
469
  class AxModel:
470
 
471
+ def __init__(self, path, provider="AxEngineExecutionProvider"):
472
  from axengine import InferenceSession
473
 
474
+ self.session = InferenceSession(path, providers=[provider])
475
  self.output_names = [item.name for item in self.session.get_outputs()]
476
 
477
  def run(self, input_feed):
 
491
  decoder_len=32,
492
  decoding_chunk_size=16,
493
  num_decoding_left_chunks=5,
494
+ batch_size=1,
495
+ provider="AxEngineExecutionProvider"):
496
  self.config_path = config_path
497
  self.vocab_path = vocab_path
498
  self.encoder_offline_path = encoder_offline_path
 
503
  self.decoding_chunk_size = decoding_chunk_size
504
  self.num_decoding_left_chunks = num_decoding_left_chunks
505
  self.batch_size = batch_size
506
+ self.provider = provider
507
 
508
  self.configs = load_config(config_path)
509
  self.vocabulary, self.char_dict = load_vocab(vocab_path)
 
516
  @property
517
  def offline_encoder(self):
518
  if self._offline_encoder is None:
519
+ self._offline_encoder = AxModel(self.encoder_offline_path,
520
+ self.provider)
521
  return self._offline_encoder
522
 
523
  @property
524
  def online_encoder(self):
525
  if self._online_encoder is None:
526
+ self._online_encoder = AxModel(self.encoder_online_path,
527
+ self.provider)
528
  return self._online_encoder
529
 
530
  @property
531
  def decoder(self):
532
  if self._decoder is None:
533
+ self._decoder = AxModel(self.decoder_path, self.provider)
534
  return self._decoder
535
 
536
  def compute_feats(self, audio_file):
run_ax.py CHANGED
@@ -29,6 +29,9 @@ def get_args():
29
  parser.add_argument("--decoder_len", type=int, default=32)
30
  parser.add_argument("--decoding_chunk_size", type=int, default=16)
31
  parser.add_argument("--num_decoding_left_chunks", type=int, default=5)
 
 
 
32
  parser.add_argument("--mode",
33
  choices=[
34
  "ctc_greedy_search", "ctc_prefix_beam_search",
@@ -43,6 +46,7 @@ def main():
43
  args = get_args()
44
  print(f"online: {args.online}")
45
  print(f"mode: {args.mode}")
 
46
 
47
  runner = WenetAXRunner(
48
  args.config,
@@ -54,6 +58,7 @@ def main():
54
  decoder_len=args.decoder_len,
55
  decoding_chunk_size=args.decoding_chunk_size,
56
  num_decoding_left_chunks=args.num_decoding_left_chunks,
 
57
  )
58
  result = runner.transcribe(args.input,
59
  online=args.online,
 
29
  parser.add_argument("--decoder_len", type=int, default=32)
30
  parser.add_argument("--decoding_chunk_size", type=int, default=16)
31
  parser.add_argument("--num_decoding_left_chunks", type=int, default=5)
32
+ parser.add_argument("--provider",
33
+ type=str,
34
+ default="AxEngineExecutionProvider")
35
  parser.add_argument("--mode",
36
  choices=[
37
  "ctc_greedy_search", "ctc_prefix_beam_search",
 
46
  args = get_args()
47
  print(f"online: {args.online}")
48
  print(f"mode: {args.mode}")
49
+ print(f"provider: {args.provider}")
50
 
51
  runner = WenetAXRunner(
52
  args.config,
 
58
  decoder_len=args.decoder_len,
59
  decoding_chunk_size=args.decoding_chunk_size,
60
  num_decoding_left_chunks=args.num_decoding_left_chunks,
61
+ provider=args.provider,
62
  )
63
  result = runner.transcribe(args.input,
64
  online=args.online,