fix providers, add test_wer.py
Browse files- SenseVoiceAx.py +2 -1
- test_wer.py +1 -2
SenseVoiceAx.py
CHANGED
|
@@ -80,6 +80,7 @@ class SenseVoiceAx:
|
|
| 80 |
hot_words: Optional[List[str]] = None,
|
| 81 |
use_itn: bool = True,
|
| 82 |
streaming: bool = False,
|
|
|
|
| 83 |
):
|
| 84 |
"""
|
| 85 |
Initialize SenseVoiceAx
|
|
@@ -126,7 +127,7 @@ class SenseVoiceAx:
|
|
| 126 |
lfr_m=7,
|
| 127 |
lfr_n=6,
|
| 128 |
)
|
| 129 |
-
self.model = axe.InferenceSession(model_path)
|
| 130 |
self.sample_rate = 16000
|
| 131 |
self.blank_id = 0
|
| 132 |
self.max_len = max_len
|
|
|
|
| 80 |
hot_words: Optional[List[str]] = None,
|
| 81 |
use_itn: bool = True,
|
| 82 |
streaming: bool = False,
|
| 83 |
+
providers=['AxEngineExecutionProvider']
|
| 84 |
):
|
| 85 |
"""
|
| 86 |
Initialize SenseVoiceAx
|
|
|
|
| 127 |
lfr_m=7,
|
| 128 |
lfr_n=6,
|
| 129 |
)
|
| 130 |
+
self.model = axe.InferenceSession(model_path, providers=providers)
|
| 131 |
self.sample_rate = 16000
|
| 132 |
self.blank_id = 0
|
| 133 |
self.max_len = max_len
|
test_wer.py
CHANGED
|
@@ -251,9 +251,8 @@ def main():
|
|
| 251 |
logger.info(f"use_itn: {use_itn}")
|
| 252 |
logger.info(f"model_path: {model_path}")
|
| 253 |
|
| 254 |
-
tokenizer = SentencepiecesTokenizer(bpemodel=bpemodel)
|
| 255 |
pipeline = SenseVoiceAx(
|
| 256 |
-
model_path, language=language
|
| 257 |
)
|
| 258 |
|
| 259 |
# Iterate over dataset
|
|
|
|
| 251 |
logger.info(f"use_itn: {use_itn}")
|
| 252 |
logger.info(f"model_path: {model_path}")
|
| 253 |
|
|
|
|
| 254 |
pipeline = SenseVoiceAx(
|
| 255 |
+
model_path, language=language
|
| 256 |
)
|
| 257 |
|
| 258 |
# Iterate over dataset
|