inoryQwQ commited on
Commit
22b36ed
·
1 Parent(s): b3b007c

fix providers, add test_wer.py

Browse files
Files changed (2) hide show
  1. SenseVoiceAx.py +2 -1
  2. test_wer.py +1 -2
SenseVoiceAx.py CHANGED
@@ -80,6 +80,7 @@ class SenseVoiceAx:
80
  hot_words: Optional[List[str]] = None,
81
  use_itn: bool = True,
82
  streaming: bool = False,
 
83
  ):
84
  """
85
  Initialize SenseVoiceAx
@@ -126,7 +127,7 @@ class SenseVoiceAx:
126
  lfr_m=7,
127
  lfr_n=6,
128
  )
129
- self.model = axe.InferenceSession(model_path)
130
  self.sample_rate = 16000
131
  self.blank_id = 0
132
  self.max_len = max_len
 
80
  hot_words: Optional[List[str]] = None,
81
  use_itn: bool = True,
82
  streaming: bool = False,
83
+ providers=['AxEngineExecutionProvider']
84
  ):
85
  """
86
  Initialize SenseVoiceAx
 
127
  lfr_m=7,
128
  lfr_n=6,
129
  )
130
+ self.model = axe.InferenceSession(model_path, providers=providers)
131
  self.sample_rate = 16000
132
  self.blank_id = 0
133
  self.max_len = max_len
test_wer.py CHANGED
@@ -251,9 +251,8 @@ def main():
251
  logger.info(f"use_itn: {use_itn}")
252
  logger.info(f"model_path: {model_path}")
253
 
254
- tokenizer = SentencepiecesTokenizer(bpemodel=bpemodel)
255
  pipeline = SenseVoiceAx(
256
- model_path, language=language, use_itn=use_itn, tokenizer=tokenizer, max_len=256
257
  )
258
 
259
  # Iterate over dataset
 
251
  logger.info(f"use_itn: {use_itn}")
252
  logger.info(f"model_path: {model_path}")
253
 
 
254
  pipeline = SenseVoiceAx(
255
+ model_path, language=language
256
  )
257
 
258
  # Iterate over dataset