lxowalle committed on
Commit
6c5434c
·
1 Parent(s): 5a18609

* optimize models path

Browse files
README.md CHANGED
@@ -4,93 +4,6 @@ language:
4
  - en
5
  pipeline_tag: automatic-speech-recognition
6
  ---
7
- # sensevoice.axera
8
- FunASR SenseVoice on Axera, official repo: https://github.com/FunAudioLLM/SenseVoice
9
 
10
- ## TODO
11
-
12
- - [x] 支持AX630C
13
- - [ ] 支持C++
14
- - [x] 支持FastAPI
15
-
16
- ## 功能
17
- - 语音识别
18
- - 自动识别语言(支持中文、英文、粤语、日语、韩语)
19
- - 情感识别
20
- - 自动标点
21
- - 支持流式识别
22
-
23
- ## 支持平台
24
-
25
- - [x] AX650N
26
- - [x] AX630C
27
-
28
- ## 环境安装
29
- ```
30
- pip3 install -r requirements.txt
31
- ```
32
- 如果空间不足可以使用 --prefix 指定别的安装路径
33
-
34
-
35
- ## 使用
36
- ```
37
- # 首次运行会自动从huggingface上下载模型, 保存到models中
38
- python3 main.py -i 输入音频文件
39
- ```
40
- 运行参数说明:
41
- | 参数名称 | 说明 | 默认值 |
42
- | --- | --- | --- |
43
- | --input/-i | 输入音频文件 | |
44
- | --language/-l | 识别语言,支持auto, zh, en, yue, ja, ko | auto |
45
- | --streaming | 流式识别 | |
46
-
47
-
48
- ### 示例:
49
- example下有测试音频
50
-
51
- 如 粤语测试
52
- ```
53
- python3 main.py -i example/yue.mp3
54
- ```
55
- 输出
56
- ```
57
- RTF: 0.03026517820946964 Latency: 0.15689468383789062s Total length: 5.184s
58
- ['呢几个字。', '都表达唔到,我想讲嘅意。', '思。']
59
- ```
60
-
61
- 流式识别
62
-
63
- ```
64
- python3 main.py -i example/zh.mp3 --streaming
65
- ```
66
- 输出
67
- ```
68
- {'timestamps': [540], 'text': '开'}
69
- {'timestamps': [540, 780, 1080], 'text': '开放时'}
70
- {'timestamps': [540, 780, 1080, 1260, 1740], 'text': '开放时间早'}
71
- {'timestamps': [540, 780, 1080, 1260, 1740, 1920, 2340], 'text': '开放时间早上9'}
72
- {'timestamps': [540, 780, 1080, 1260, 1740, 1920, 2340, 2640], 'text': '开放时间早上9点'}
73
- {'timestamps': [540, 780, 1080, 1260, 1740, 1920, 2340, 2640, 3060], 'text': '开放时间早上9点至'}
74
- {'timestamps': [540, 780, 1080, 1260, 1740, 1920, 2340, 2640, 3060, 3780, 4020], 'text': '开放时间早上9点至下午'}
75
- {'timestamps': [540, 780, 1080, 1260, 1740, 1920, 2340, 2640, 3060, 3780, 4020, 4440, 4620], 'text': '开放时间早上9点至下午五点'}
76
- RTF: 0.03678379235444246
77
-
78
- ```
79
-
80
- ## 准确率
81
-
82
- 使用WER(Word-Error-Rate)作为评价标准
83
-
84
- **WER = 0.0389**
85
-
86
- ### 复现测试结果
87
-
88
- ```
89
- ./download_datasets.sh
90
- python test_wer.py -d datasets -l zh
91
- ```
92
-
93
- ## 技术讨论
94
-
95
- - Github issues
96
- - QQ 群: 139953715
 
4
  - en
5
  pipeline_tag: automatic-speech-recognition
6
  ---
7
+ # SenseVoice
 
8
 
9
+ # Refer [here](https://wiki.sipeed.com/maixpy/doc/en/mllm/asr_sensevoice.html)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
README_ZH.md ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ language:
4
+ - en
5
+ pipeline_tag: automatic-speech-recognition
6
+ ---
7
+ # SenseVoice
8
+
9
+ 使用文档参考[这里](https://wiki.sipeed.com/maixpy/doc/zh/mllm/asr_sensevoice.html)
config.json DELETED
File without changes
download_dataset.sh DELETED
@@ -1,2 +0,0 @@
1
- wget https://github.com/ml-inory/whisper.axera/releases/download/v1.0/datasets.zip
2
- unzip datasets.zip -d ./
 
 
 
download_utils.py DELETED
@@ -1,33 +0,0 @@
1
- import os
2
-
3
- # Speed up hf download using mirror url
4
- os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
5
- from huggingface_hub import snapshot_download
6
-
7
- current_file_path = os.path.dirname(__file__)
8
- REPO_ROOT = "AXERA-TECH"
9
- CACHE_PATH = os.path.join(current_file_path, "models")
10
-
11
-
12
- def download_model(model_name: str) -> str:
13
- """
14
- Download model from AXERA-TECH's huggingface space.
15
-
16
- model_name: str
17
- Available model names could be checked on https://huggingface.co/AXERA-TECH.
18
-
19
- Returns:
20
- str: Path to model_name
21
-
22
- """
23
- os.makedirs(CACHE_PATH, exist_ok=True)
24
-
25
- model_path = os.path.join(CACHE_PATH, model_name)
26
- if not os.path.exists(model_path):
27
- print(f"Downloading {model_name}...")
28
- snapshot_download(
29
- repo_id=f"{REPO_ROOT}/{model_name}",
30
- local_dir=os.path.join(CACHE_PATH, model_name),
31
- )
32
-
33
- return model_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
example/en.mp3 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:f10378336a4e584f3f63799e62f99d5add3c2a401b51d3abe7d3a3a82f255ada
3
- size 57441
 
 
 
 
example/ja.mp3 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:496dbc43b289e1d0d0cb916df9737450bca56acd8aaca046a7a2472363b1be53
3
- size 57837
 
 
 
 
example/ko.mp3 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:8612f62db8319a6cb4ab4b1d2039bfc32f174f89611889ddafdeb5c0a6070b5f
3
- size 27909
 
 
 
 
example/yue.mp3 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:5098eebc13530a66e4eac1f30d3246e65c9cfc4e096665f9d395aca8eff0d181
3
- size 31246
 
 
 
 
example/zh.mp3 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:0e64de19e4ff9a02e682955c9112f32d2317cfdbb5bc2f3504664044c993f195
3
- size 44973
 
 
 
 
gradio_demo.py DELETED
@@ -1,62 +0,0 @@
1
- import gradio as gr
2
- import os
3
- from SenseVoiceAx import SenseVoiceAx
4
- from print_utils import rich_transcription_postprocess
5
-
6
- max_len = 256
7
-
8
- model_path = os.path.join("sensevoice_ax650", "sensevoice.axmodel")
9
-
10
- assert os.path.exists(model_path), f"model {model_path} not exist"
11
-
12
- pipeline = SenseVoiceAx(
13
- model_path,
14
- max_len=max_len,
15
- beam_size=3,
16
- language="auto",
17
- hot_words=None,
18
- use_itn=True,
19
- streaming=False,
20
- )
21
-
22
-
23
- def speech_to_text(audio_path, lang):
24
- """
25
- audio_path: 音频文件路径
26
- lang: 语言类型 "auto", "zh", "en", "yue", "ja", "ko"
27
- """
28
- if not audio_path:
29
- return "无音频"
30
-
31
- pipeline.choose_language(language=lang)
32
- asr_res = pipeline.infer(audio_path, print_rtf=False)
33
-
34
- return asr_res
35
-
36
-
37
- def main():
38
- with gr.Blocks() as demo:
39
- with gr.Row():
40
- output_text = gr.Textbox(label="识别结果", lines=5)
41
-
42
- with gr.Row():
43
- audio_input = gr.Audio(
44
- sources=["upload"], type="filepath", label="录制或上传音频", format="mp3"
45
- )
46
- lang_dropdown = gr.Dropdown(
47
- choices=["auto", "zh", "en", "yue", "ja", "ko"],
48
- value="auto",
49
- label="选择音频语言",
50
- )
51
-
52
- audio_input.change(
53
- fn=speech_to_text, inputs=[audio_input, lang_dropdown], outputs=output_text
54
- )
55
-
56
- demo.launch(
57
- server_name="0.0.0.0",
58
- )
59
-
60
-
61
- if __name__ == "__main__":
62
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main.py DELETED
@@ -1,79 +0,0 @@
1
- import os
2
- import argparse
3
- from SenseVoiceAx import SenseVoiceAx
4
- import librosa
5
- import numpy as np
6
- import time
7
-
8
-
9
- def get_args():
10
- parser = argparse.ArgumentParser()
11
- parser.add_argument(
12
- "--input", "-i", required=True, type=str, help="Input audio file"
13
- )
14
- parser.add_argument(
15
- "--language",
16
- "-l",
17
- required=False,
18
- type=str,
19
- default="auto",
20
- choices=["auto", "zh", "en", "yue", "ja", "ko"],
21
- )
22
- parser.add_argument("--streaming", action="store_true")
23
- return parser.parse_args()
24
-
25
-
26
- def main():
27
- args = get_args()
28
-
29
- input_audio = args.input
30
- language = args.language
31
- use_itn = True # 标点符号预测
32
- if not args.streaming:
33
- max_len = 256
34
- model_path = os.path.join("sensevoice_ax630c", "sensevoice.axmodel")
35
- else:
36
- max_len = 26
37
- model_path = os.path.join("sensevoice_ax630c", "streaming_sensevoice.axmodel")
38
-
39
- assert os.path.exists(model_path), f"model {model_path} not exist"
40
-
41
- print(f"input_audio: {input_audio}")
42
- print(f"language: {language}")
43
- print(f"use_itn: {use_itn}")
44
- print(f"model_path: {model_path}")
45
- print(f"streaming: {args.streaming}")
46
-
47
- pipeline = SenseVoiceAx(
48
- model_path,
49
- max_len=max_len,
50
- beam_size=3,
51
- language="auto",
52
- hot_words=None,
53
- use_itn=True,
54
- streaming=args.streaming,
55
- )
56
-
57
- if not args.streaming:
58
- asr_res = pipeline.infer(input_audio, print_rtf=True)
59
- print("ASR result: " + asr_res)
60
- else:
61
- samples, sr = librosa.load(input_audio, sr=16000)
62
- samples = (samples * 32768).tolist()
63
- duration = len(samples) / 16000
64
-
65
- start = time.time()
66
- step = int(0.1 * sr)
67
- for i in range(0, len(samples), step):
68
- is_last = i + step >= len(samples)
69
- for res in pipeline.stream_infer(samples[i : i + step], is_last):
70
- print(res)
71
-
72
- end = time.time()
73
- cost_time = end - start
74
-
75
- print(f"RTF: {cost_time / duration}")
76
-
77
-
78
- if __name__ == "__main__":
79
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
print_utils.py DELETED
@@ -1,131 +0,0 @@
1
- emo_dict = {
2
- "<|HAPPY|>": "😊",
3
- "<|SAD|>": "😔",
4
- "<|ANGRY|>": "😡",
5
- "<|NEUTRAL|>": "",
6
- "<|FEARFUL|>": "😰",
7
- "<|DISGUSTED|>": "🤢",
8
- "<|SURPRISED|>": "😮",
9
- }
10
-
11
- event_dict = {
12
- "<|BGM|>": "🎼",
13
- "<|Speech|>": "",
14
- "<|Applause|>": "👏",
15
- "<|Laughter|>": "😀",
16
- "<|Cry|>": "😭",
17
- "<|Sneeze|>": "🤧",
18
- "<|Breath|>": "",
19
- "<|Cough|>": "🤧",
20
- }
21
-
22
- lang_dict = {
23
- "<|zh|>": "<|lang|>",
24
- "<|en|>": "<|lang|>",
25
- "<|yue|>": "<|lang|>",
26
- "<|ja|>": "<|lang|>",
27
- "<|ko|>": "<|lang|>",
28
- "<|nospeech|>": "<|lang|>",
29
- }
30
-
31
- emoji_dict = {
32
- "<|nospeech|><|Event_UNK|>": "❓",
33
- "<|zh|>": "",
34
- "<|en|>": "",
35
- "<|yue|>": "",
36
- "<|ja|>": "",
37
- "<|ko|>": "",
38
- "<|nospeech|>": "",
39
- "<|HAPPY|>": "😊",
40
- "<|SAD|>": "😔",
41
- "<|ANGRY|>": "😡",
42
- "<|NEUTRAL|>": "",
43
- "<|BGM|>": "🎼",
44
- "<|Speech|>": "",
45
- "<|Applause|>": "👏",
46
- "<|Laughter|>": "😀",
47
- "<|FEARFUL|>": "😰",
48
- "<|DISGUSTED|>": "🤢",
49
- "<|SURPRISED|>": "😮",
50
- "<|Cry|>": "😭",
51
- "<|EMO_UNKNOWN|>": "",
52
- "<|Sneeze|>": "🤧",
53
- "<|Breath|>": "",
54
- "<|Cough|>": "😷",
55
- "<|Sing|>": "",
56
- "<|Speech_Noise|>": "",
57
- "<|withitn|>": "",
58
- "<|woitn|>": "",
59
- "<|GBG|>": "",
60
- "<|Event_UNK|>": "",
61
- }
62
-
63
- emo_set = {"😊", "😔", "😡", "😰", "🤢", "😮"}
64
- event_set = {
65
- "🎼",
66
- "👏",
67
- "😀",
68
- "😭",
69
- "🤧",
70
- "😷",
71
- }
72
-
73
-
74
- def format_str_v2(s):
75
- sptk_dict = {}
76
- for sptk in emoji_dict:
77
- sptk_dict[sptk] = s.count(sptk)
78
- s = s.replace(sptk, "")
79
- emo = "<|NEUTRAL|>"
80
- for e in emo_dict:
81
- if sptk_dict[e] > sptk_dict[emo]:
82
- emo = e
83
- for e in event_dict:
84
- if sptk_dict[e] > 0:
85
- s = event_dict[e] + s
86
- s = s + emo_dict[emo]
87
-
88
- for emoji in emo_set.union(event_set):
89
- s = s.replace(" " + emoji, emoji)
90
- s = s.replace(emoji + " ", emoji)
91
- return s.strip()
92
-
93
-
94
- def rich_transcription_postprocess(s):
95
- def get_emo(s):
96
- return s[-1] if s[-1] in emo_set else None
97
-
98
- def get_event(s):
99
- return s[0] if s[0] in event_set else None
100
-
101
- s = s.replace("<|nospeech|><|Event_UNK|>", "❓")
102
- for lang in lang_dict:
103
- s = s.replace(lang, "<|lang|>")
104
- s_list = [format_str_v2(s_i).strip(" ") for s_i in s.split("<|lang|>")]
105
- new_s = " " + s_list[0]
106
- cur_ent_event = get_event(new_s)
107
- for i in range(1, len(s_list)):
108
- if len(s_list[i]) == 0:
109
- continue
110
- if get_event(s_list[i]) == cur_ent_event and get_event(s_list[i]) != None:
111
- s_list[i] = s_list[i][1:]
112
- # else:
113
- cur_ent_event = get_event(s_list[i])
114
- if get_emo(s_list[i]) != None and get_emo(s_list[i]) == get_emo(new_s):
115
- new_s = new_s[:-1]
116
- new_s += s_list[i].strip().lstrip()
117
- new_s = new_s.replace("The.", " ")
118
- return new_s.strip()
119
-
120
-
121
- def rich_print_asr_res(asr_res, will_print=True, remove_punc=False):
122
- res = "".join([rich_transcription_postprocess(i) for i in asr_res])
123
-
124
- if remove_punc:
125
- res = res.replace(",", "")
126
- res = res.replace("。", "")
127
-
128
- if will_print:
129
- print(res)
130
-
131
- return res
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt DELETED
@@ -1,11 +0,0 @@
1
- huggingface_hub
2
- numpy<2
3
- kaldi-native-fbank
4
- librosa==0.9.1
5
- sentencepiece
6
- fastapi
7
- gradio
8
- emoji
9
- asr-decoder
10
- online-fbank
11
- torch
 
 
 
 
 
 
 
 
 
 
 
 
SenseVoiceAx.py → sensevoice-maixcam2/SenseVoiceAx.py RENAMED
@@ -281,18 +281,18 @@ class SenseVoiceAx:
281
 
282
  """
283
  model_path_root = os.path.dirname(model_path)
284
- emb_path = os.path.join(model_path_root, "../embeddings.npy")
285
- cmvn_file = os.path.join(model_path_root, "../am.mvn")
286
  bpe_model = os.path.join(
287
- model_path_root, "../chn_jpn_yue_eng_ko_spectok.bpe.model"
288
  )
289
  if streaming:
290
  self.position_encoding = np.load(
291
- os.path.join(model_path_root, "../pe_streaming.npy")
292
  )
293
  else:
294
  self.position_encoding = np.load(
295
- os.path.join(model_path_root, "../pe_nonstream.npy")
296
  )
297
 
298
  self.streaming = streaming
@@ -553,7 +553,7 @@ class SenseVoiceAx:
553
  self.cur_idx = -1
554
  self.decoder.reset()
555
  self.fbank = OnlineFbank(window_type="hamming")
556
- self.caches = np.zeros(self.caches_shape)
557
 
558
  def get_size(self):
559
  effective_size = self.cur_idx + 1 - self.padding
 
281
 
282
  """
283
  model_path_root = os.path.dirname(model_path)
284
+ emb_path = os.path.join(model_path_root, "embeddings.npy")
285
+ cmvn_file = os.path.join(model_path_root, "am.mvn")
286
  bpe_model = os.path.join(
287
+ model_path_root, "chn_jpn_yue_eng_ko_spectok.bpe.model"
288
  )
289
  if streaming:
290
  self.position_encoding = np.load(
291
+ os.path.join(model_path_root, "pe_streaming.npy")
292
  )
293
  else:
294
  self.position_encoding = np.load(
295
+ os.path.join(model_path_root, "pe_nonstream.npy")
296
  )
297
 
298
  self.streaming = streaming
 
553
  self.cur_idx = -1
554
  self.decoder.reset()
555
  self.fbank = OnlineFbank(window_type="hamming")
556
+ self.caches = np.zeros(self.caches_shape, dtype=np.float32)
557
 
558
  def get_size(self):
559
  effective_size = self.cur_idx + 1 - self.padding
am.mvn → sensevoice-maixcam2/am.mvn RENAMED
File without changes
chn_jpn_yue_eng_ko_spectok.bpe.model → sensevoice-maixcam2/chn_jpn_yue_eng_ko_spectok.bpe.model RENAMED
File without changes
client.py → sensevoice-maixcam2/client.py RENAMED
@@ -6,7 +6,7 @@ import requests, json, os
6
  import wave
7
  import numpy as np
8
  import threading
9
- from maix import app, time
10
 
11
  class Sensevoice:
12
  def __init__(self, model = "", url="http://0.0.0.0:12347", lauguage="auto", stream=False):
@@ -26,6 +26,17 @@ class Sensevoice:
26
  if not os.path.exists(model):
27
  raise ValueError(f'Model {self.model} is not existed!')
28
 
 
 
 
 
 
 
 
 
 
 
 
29
  def _check_service(self):
30
  try:
31
  response = requests.get(self.url + '/status')
@@ -75,6 +86,15 @@ class Sensevoice:
75
  except:
76
  return "not loaded"
77
 
 
 
 
 
 
 
 
 
 
78
  def _start_model(self):
79
  try:
80
  data = {
@@ -83,6 +103,7 @@ class Sensevoice:
83
  "language": self.launguage,
84
  "stream": self.stream
85
  }
 
86
  response = requests.post(self.url + '/start_model', json=data)
87
  if response.status_code == 200:
88
  res = json.loads(response.text)
@@ -129,6 +150,9 @@ class Sensevoice:
129
 
130
  def is_ready(self, block=False):
131
  while not app.need_exit():
 
 
 
132
  if self._get_status() == "loaded":
133
  return True
134
  else:
@@ -136,10 +160,6 @@ class Sensevoice:
136
  time.sleep(1)
137
  else:
138
  return False
139
-
140
- if self.thread_is_exit:
141
- return True if self.thread_exit_code == 0 else False
142
-
143
  return False
144
 
145
  def stop(self):
@@ -225,17 +245,20 @@ class Sensevoice:
225
  }
226
 
227
  try:
 
228
  response = requests.post(self.url + '/asr', json=data)
229
  if response.status_code == 200:
230
  res = json.loads(response.text)
231
  text = res.get("text", "")
232
  if len(text) > 0:
233
- return text[0]
234
  else:
235
- return ""
236
  else:
237
  print(f"Requests failed: {response.status_code}")
238
- return ""
 
 
239
  except Exception as e:
240
  print("Requests failed:", e)
241
  return ""
@@ -258,13 +281,15 @@ class Sensevoice:
258
  "launguage": "auto",
259
  "step": 0.1,
260
  }
261
- print('start post')
262
  try:
263
  response = requests.post(self.url + '/asr_stream', json=data, stream=True)
264
  for line in response.iter_lines():
265
  if line:
266
  chunk = json.loads(line)
267
  yield chunk.get("text", "")
 
 
268
  except Exception as e:
269
  print("Requests failed:", e)
270
  return ""
 
6
  import wave
7
  import numpy as np
8
  import threading
9
+ from maix import app, time, nn
10
 
11
  class Sensevoice:
12
  def __init__(self, model = "", url="http://0.0.0.0:12347", lauguage="auto", stream=False):
 
26
  if not os.path.exists(model):
27
  raise ValueError(f'Model {self.model} is not existed!')
28
 
29
+ self.model_mud = nn.MUD(model)
30
+ self.model_configs = self.model_mud.items
31
+ server_path = self.model_configs.get('extra', {}).get('server', '/root/models/sensevoice_maixcam2/server.py')
32
+ model_dir = os.path.dirname(self.model)
33
+ service_env_path = self.model_configs.get('extra', {}).get('sensevoice_service_env')
34
+ self._create_service_environment_file(model_dir, os.path.join(model_dir, server_path), os.path.join(model_dir, service_env_path))
35
+
36
+ def _create_service_environment_file(self, working_dir: str, server_path: str, env_file_path: str = "/tmp/sensevoice.service.env"):
37
+ os.system(f"echo 'WORK_DIR={os.path.realpath(working_dir)}' > {os.path.realpath(env_file_path)}")
38
+ os.system(f"echo 'SERVER_FILE={os.path.realpath(server_path)}' >> {os.path.realpath(env_file_path)}")
39
+
40
  def _check_service(self):
41
  try:
42
  response = requests.get(self.url + '/status')
 
86
  except:
87
  return "not loaded"
88
 
89
+ def _reset_model(self):
90
+ try:
91
+ response = requests.post(self.url + '/reset_model')
92
+ if response.status_code == 200:
93
+ res = json.loads(response.text)
94
+ return res["status"]
95
+ except:
96
+ return "not loaded"
97
+
98
  def _start_model(self):
99
  try:
100
  data = {
 
103
  "language": self.launguage,
104
  "stream": self.stream
105
  }
106
+
107
  response = requests.post(self.url + '/start_model', json=data)
108
  if response.status_code == 200:
109
  res = json.loads(response.text)
 
150
 
151
  def is_ready(self, block=False):
152
  while not app.need_exit():
153
+ if self.thread_is_exit:
154
+ return True if self.thread_exit_code == 0 else False
155
+
156
  if self._get_status() == "loaded":
157
  return True
158
  else:
 
160
  time.sleep(1)
161
  else:
162
  return False
 
 
 
 
163
  return False
164
 
165
  def stop(self):
 
245
  }
246
 
247
  try:
248
+ result = ""
249
  response = requests.post(self.url + '/asr', json=data)
250
  if response.status_code == 200:
251
  res = json.loads(response.text)
252
  text = res.get("text", "")
253
  if len(text) > 0:
254
+ result = text[0]
255
  else:
256
+ result = ""
257
  else:
258
  print(f"Requests failed: {response.status_code}")
259
+ result = ""
260
+
261
+ return result
262
  except Exception as e:
263
  print("Requests failed:", e)
264
  return ""
 
281
  "launguage": "auto",
282
  "step": 0.1,
283
  }
284
+
285
  try:
286
  response = requests.post(self.url + '/asr_stream', json=data, stream=True)
287
  for line in response.iter_lines():
288
  if line:
289
  chunk = json.loads(line)
290
  yield chunk.get("text", "")
291
+
292
+ self._reset_model()
293
  except Exception as e:
294
  print("Requests failed:", e)
295
  return ""
embeddings.npy → sensevoice-maixcam2/embeddings.npy RENAMED
File without changes
frontend.py → sensevoice-maixcam2/frontend.py RENAMED
File without changes
model.mud → sensevoice-maixcam2/model.mud RENAMED
@@ -1,6 +1,6 @@
1
  [basic]
2
  type = axmodel
3
- model_npu = sensevoice_ax630c/sensevoice.axmodel
4
  model_vnpu =
5
 
6
  [extra]
@@ -11,5 +11,6 @@ beam_size = 3
11
  language = auto
12
  hot_words = None,
13
  use_itn = True
14
- stream_model = sensevoice_ax630c/streaming_sensevoice.axmodel
15
- server_url = http://127.0.0.1:12345
 
 
1
  [basic]
2
  type = axmodel
3
+ model_npu = sensevoice.axmodel
4
  model_vnpu =
5
 
6
  [extra]
 
11
  language = auto
12
  hot_words = None,
13
  use_itn = True
14
+ stream_model = streaming_sensevoice.axmodel
15
+ sensevoice_service_env = /tmp/sensevoice.service.env
16
+ server = server.py
pe_nonstream.npy → sensevoice-maixcam2/pe_nonstream.npy RENAMED
File without changes
pe_streaming.npy → sensevoice-maixcam2/pe_streaming.npy RENAMED
File without changes
{sensevoice_ax630c → sensevoice-maixcam2}/sensevoice.axmodel RENAMED
File without changes
server.py → sensevoice-maixcam2/server.py RENAMED
@@ -1,6 +1,7 @@
1
  import numpy as np
2
  from fastapi import FastAPI, HTTPException, Body
3
  from fastapi.responses import JSONResponse, StreamingResponse
 
4
  from typing import List, Optional
5
  import logging
6
  import json
@@ -15,13 +16,12 @@ import time
15
  logging.basicConfig(level=logging.INFO)
16
  logger = logging.getLogger(__name__)
17
 
18
- app = FastAPI(title="ASR Server", description="Automatic Speech Recognition API")
19
-
20
  # 全局变量存储模型
21
  asr_model = None
22
  asr_model_is_loaded = False
23
  mud_configs = None
24
 
 
25
  def parse_config_file_to_json(file_path):
26
  """从文件读取配置并解析为JSON"""
27
  if not os.path.exists(file_path):
@@ -96,9 +96,12 @@ async def handle_heartbeat():
96
  last_get_heartbeat_s = time.time() # 模型未加载时,重置获取心跳的时间
97
  await asyncio.sleep(10) # 每隔 10 秒检查一次
98
 
99
- @app.on_event("startup")
100
- async def load_model():
101
  asyncio.create_task(handle_heartbeat())
 
 
 
102
 
103
  def validate_audio_data(audio_data: List[float]) -> np.ndarray:
104
  """
@@ -147,6 +150,15 @@ async def get_status():
147
  global asr_model_is_loaded
148
  return JSONResponse(content={"status": "loaded" if asr_model_is_loaded else "not loaded"})
149
 
 
 
 
 
 
 
 
 
 
150
  @app.post("/start_model", summary="Load model")
151
  async def start_model(
152
  model_path: str = Body(
@@ -161,10 +173,14 @@ async def start_model(
161
  """
162
  global asr_model
163
  global asr_model_is_loaded
164
- logger.info("Loading ASR model...")
165
 
166
  if asr_model_is_loaded:
167
- return JSONResponse(content={"status": "loaded"})
 
 
 
 
 
168
 
169
  try:
170
  mud_configs = parse_config_file_to_json(model_path)
 
1
  import numpy as np
2
  from fastapi import FastAPI, HTTPException, Body
3
  from fastapi.responses import JSONResponse, StreamingResponse
4
+ from contextlib import asynccontextmanager
5
  from typing import List, Optional
6
  import logging
7
  import json
 
16
  logging.basicConfig(level=logging.INFO)
17
  logger = logging.getLogger(__name__)
18
 
 
 
19
  # 全局变量存储模型
20
  asr_model = None
21
  asr_model_is_loaded = False
22
  mud_configs = None
23
 
24
+
25
  def parse_config_file_to_json(file_path):
26
  """从文件读取配置并解析为JSON"""
27
  if not os.path.exists(file_path):
 
96
  last_get_heartbeat_s = time.time() # 模型未加载时,重置获取心跳的时间
97
  await asyncio.sleep(10) # 每隔 10 秒检查一次
98
 
99
+ @asynccontextmanager
100
+ async def lifespan(app: FastAPI):
101
  asyncio.create_task(handle_heartbeat())
102
+ yield
103
+
104
+ app = FastAPI(title="ASR Server", lifespan = lifespan, description="Automatic Speech Recognition API")
105
 
106
  def validate_audio_data(audio_data: List[float]) -> np.ndarray:
107
  """
 
150
  global asr_model_is_loaded
151
  return JSONResponse(content={"status": "loaded" if asr_model_is_loaded else "not loaded"})
152
 
153
+ @app.post("/reset_model", summary="Reset model")
154
+ async def reset_model():
155
+ global asr_model
156
+ global asr_model_is_loaded
157
+ if not asr_model_is_loaded:
158
+ return JSONResponse(content={"status": "not loaded"})
159
+ asr_model.reset()
160
+ return JSONResponse(content={"status": "ok"})
161
+
162
  @app.post("/start_model", summary="Load model")
163
  async def start_model(
164
  model_path: str = Body(
 
173
  """
174
  global asr_model
175
  global asr_model_is_loaded
 
176
 
177
  if asr_model_is_loaded:
178
+ if asr_model.streaming == stream:
179
+ return JSONResponse(content={"status": "loaded"})
180
+ else:
181
+ await try_stop_model()
182
+
183
+ logger.info("Loading ASR model...")
184
 
185
  try:
186
  mud_configs = parse_config_file_to_json(model_path)
{sensevoice_ax630c → sensevoice-maixcam2}/streaming_sensevoice.axmodel RENAMED
File without changes
tokenizer.py → sensevoice-maixcam2/tokenizer.py RENAMED
File without changes
test_wer.py DELETED
@@ -1,296 +0,0 @@
1
- import os, sys
2
- import argparse
3
- from SenseVoiceAx import SenseVoiceAx
4
- from tokenizer import SentencepiecesTokenizer
5
- from print_utils import rich_transcription_postprocess, rich_print_asr_res
6
- from download_utils import download_model
7
- import logging
8
- import re
9
- import emoji
10
-
11
-
12
- def setup_logging():
13
- """配置日志系统,同时输出到控制台和文件"""
14
- # 获取脚本所在目录
15
- script_dir = os.path.dirname(os.path.abspath(__file__))
16
- log_file = os.path.join(script_dir, "test_wer.log")
17
-
18
- # 配置日志格式
19
- log_format = "%(asctime)s - %(levelname)s - %(message)s"
20
- date_format = "%Y-%m-%d %H:%M:%S"
21
-
22
- # 创建logger
23
- logger = logging.getLogger()
24
- logger.setLevel(logging.INFO)
25
-
26
- # 清除现有的handler
27
- for handler in logger.handlers[:]:
28
- logger.removeHandler(handler)
29
-
30
- # 创建文件handler
31
- file_handler = logging.FileHandler(log_file, mode="w", encoding="utf-8")
32
- file_handler.setLevel(logging.INFO)
33
- file_formatter = logging.Formatter(log_format, date_format)
34
- file_handler.setFormatter(file_formatter)
35
-
36
- # 创建控制台handler
37
- console_handler = logging.StreamHandler()
38
- console_handler.setLevel(logging.INFO)
39
- console_formatter = logging.Formatter(log_format, date_format)
40
- console_handler.setFormatter(console_formatter)
41
-
42
- # 添加handler到logger
43
- logger.addHandler(file_handler)
44
- logger.addHandler(console_handler)
45
-
46
- return logger
47
-
48
-
49
- class AIShellDataset:
50
- def __init__(self, gt_path: str):
51
- """
52
- 初始化数据集
53
-
54
- Args:
55
- json_path: voice.json文件的路径
56
- """
57
- self.gt_path = gt_path
58
- self.dataset_dir = os.path.dirname(gt_path)
59
- self.voice_dir = os.path.join(self.dataset_dir, "aishell_S0764")
60
-
61
- # 检查必要文件和文件夹是否存在
62
- assert os.path.exists(gt_path), f"gt文件不存在: {gt_path}"
63
- assert os.path.exists(self.voice_dir), f"aishell_S0764文件夹不存在: {self.voice_dir}"
64
-
65
- # 加载数据
66
- self.data = []
67
- with open(gt_path, "r", encoding="utf-8") as f:
68
- for line in f:
69
- line = line.strip()
70
- audio_path, gt = line.split(" ")
71
- audio_path = os.path.join(self.voice_dir, audio_path + ".wav")
72
- self.data.append({"audio_path": audio_path, "gt": gt})
73
-
74
- # 使用logging而不是print
75
- logger = logging.getLogger()
76
- logger.info(f"加载了 {len(self.data)} 条数据")
77
-
78
- def __iter__(self):
79
- """返回迭代器"""
80
- self.index = 0
81
- return self
82
-
83
- def __next__(self):
84
- """返回下一个数据项"""
85
- if self.index >= len(self.data):
86
- raise StopIteration
87
-
88
- item = self.data[self.index]
89
- audio_path = item["audio_path"]
90
- ground_truth = item["gt"]
91
-
92
- self.index += 1
93
- return audio_path, ground_truth
94
-
95
- def __len__(self):
96
- """返回数据集大小"""
97
- return len(self.data)
98
-
99
-
100
- class CommonVoiceDataset:
101
- """Common Voice数据集解析器"""
102
-
103
- def __init__(self, tsv_path: str):
104
- """
105
- 初始化数据集
106
-
107
- Args:
108
- json_path: voice.json文件的路径
109
- """
110
- self.tsv_path = tsv_path
111
- self.dataset_dir = os.path.dirname(tsv_path)
112
- self.voice_dir = os.path.join(self.dataset_dir, "clips")
113
-
114
- # 检查必要文件和文件夹是否存在
115
- assert os.path.exists(tsv_path), f"{tsv_path}文件不存在: {tsv_path}"
116
- assert os.path.exists(self.voice_dir), f"voice文件夹不存在: {self.voice_dir}"
117
-
118
- # 加载JSON数据
119
- self.data = []
120
- with open(tsv_path, "r", encoding="utf-8") as f:
121
- f.readline()
122
- for line in f:
123
- line = line.strip()
124
- splits = line.split("\t")
125
- audio_path = splits[1]
126
- gt = splits[3]
127
- audio_path = os.path.join(self.voice_dir, audio_path)
128
- self.data.append({"audio_path": audio_path, "gt": gt})
129
-
130
- # 使用logging而不是print
131
- logger = logging.getLogger()
132
- logger.info(f"加载了 {len(self.data)} 条数据")
133
-
134
- def __iter__(self):
135
- """返回迭代器"""
136
- self.index = 0
137
- return self
138
-
139
- def __next__(self):
140
- """返回下一个数据项"""
141
- if self.index >= len(self.data):
142
- raise StopIteration
143
-
144
- item = self.data[self.index]
145
- audio_path = item["audio_path"]
146
- ground_truth = item["gt"]
147
-
148
- self.index += 1
149
- return audio_path, ground_truth
150
-
151
- def __len__(self):
152
- """返回数据集大小"""
153
- return len(self.data)
154
-
155
-
156
- def get_args():
157
- parser = argparse.ArgumentParser()
158
- parser.add_argument(
159
- "--dataset",
160
- "-d",
161
- type=str,
162
- required=True,
163
- choices=["aishell", "common_voice"],
164
- help="Test dataset",
165
- )
166
- parser.add_argument(
167
- "--gt_path",
168
- "-g",
169
- type=str,
170
- required=True,
171
- help="Test dataset ground truth file",
172
- )
173
- parser.add_argument(
174
- "--language",
175
- "-l",
176
- required=False,
177
- type=str,
178
- default="auto",
179
- choices=["auto", "zh", "en", "yue", "ja", "ko"],
180
- )
181
- parser.add_argument(
182
- "--max_num", type=int, default=-1, required=False, help="Maximum test data num"
183
- )
184
- return parser.parse_args()
185
-
186
-
187
- def min_distance(word1: str, word2: str) -> int:
188
-
189
- row = len(word1) + 1
190
- column = len(word2) + 1
191
-
192
- cache = [[0] * column for i in range(row)]
193
-
194
- for i in range(row):
195
- for j in range(column):
196
-
197
- if i == 0 and j == 0:
198
- cache[i][j] = 0
199
- elif i == 0 and j != 0:
200
- cache[i][j] = j
201
- elif j == 0 and i != 0:
202
- cache[i][j] = i
203
- else:
204
- if word1[i - 1] == word2[j - 1]:
205
- cache[i][j] = cache[i - 1][j - 1]
206
- else:
207
- replace = cache[i - 1][j - 1] + 1
208
- insert = cache[i][j - 1] + 1
209
- remove = cache[i - 1][j] + 1
210
-
211
- cache[i][j] = min(replace, insert, remove)
212
-
213
- return cache[row - 1][column - 1]
214
-
215
-
216
- def remove_punctuation(text):
217
- # 定义正则表达式模式,匹配所有标点符号
218
- # 这个模式包括常见的标点符号和中文标点
219
- pattern = r"[^\w\s]|_"
220
-
221
- # 使用sub方法将所有匹配的标点符号替换为空字符串
222
- cleaned_text = re.sub(pattern, "", text)
223
-
224
- return cleaned_text
225
-
226
-
227
- def main():
228
- logger = setup_logging()
229
- args = get_args()
230
-
231
- language = args.language
232
- use_itn = False # 标点符号预测
233
- max_num = args.max_num
234
-
235
- dataset_type = args.dataset.lower()
236
- if dataset_type == "aishell":
237
- dataset = AIShellDataset(args.gt_path)
238
- elif dataset_type == "common_voice":
239
- dataset = CommonVoiceDataset(args.gt_path)
240
- else:
241
- raise ValueError(f"Unknown dataset type {dataset_type}")
242
-
243
- # model_path_root = download_model("SenseVoice")
244
- model_path = os.path.join("sensevoice_ax650", "sensevoice.axmodel")
245
- bpemodel = "chn_jpn_yue_eng_ko_spectok.bpe.model"
246
-
247
- assert os.path.exists(model_path), f"model {model_path} not exist"
248
-
249
- logger.info(f"dataset: {args.dataset}")
250
- logger.info(f"language: {language}")
251
- logger.info(f"use_itn: {use_itn}")
252
- logger.info(f"model_path: {model_path}")
253
-
254
- tokenizer = SentencepiecesTokenizer(bpemodel=bpemodel)
255
- pipeline = SenseVoiceAx(
256
- model_path, language=language, use_itn=use_itn, tokenizer=tokenizer, max_len=256
257
- )
258
-
259
- # Iterate over dataset
260
- hyp = []
261
- references = []
262
- all_character_error_num = 0
263
- all_character_num = 0
264
- max_data_num = max_num if max_num > 0 else len(dataset)
265
- for n, (audio_path, reference) in enumerate(dataset):
266
- reference = remove_punctuation(reference).lower()
267
-
268
- asr_res = pipeline.infer(audio_path, print_rtf=False)
269
- hypothesis = rich_print_asr_res(
270
- asr_res, will_print=False, remove_punc=True
271
- ).lower()
272
- hypothesis = emoji.replace_emoji(hypothesis, replace="")
273
-
274
- character_error_num = min_distance(reference, hypothesis)
275
- character_num = len(reference)
276
- character_error_rate = character_error_num / character_num * 100
277
-
278
- all_character_error_num += character_error_num
279
- all_character_num += character_num
280
-
281
- hyp.append(hypothesis)
282
- references.append(reference)
283
-
284
- line_content = f"({n+1}/{max_data_num}) {os.path.basename(audio_path)} gt: {reference} predict: {hypothesis} WER: {character_error_rate}%"
285
- logger.info(line_content)
286
-
287
- if n + 1 >= max_data_num:
288
- break
289
-
290
- total_character_error_rate = all_character_error_num / all_character_num * 100
291
-
292
- logger.info(f"Total WER: {total_character_error_rate}%")
293
-
294
-
295
- if __name__ == "__main__":
296
- main()