yangrongzhao commited on
Commit
07f9af1
·
1 Parent(s): 5f47283

enlarge seq_len to 256

Browse files
SenseVoiceAx.py CHANGED
@@ -26,49 +26,33 @@ def sequence_mask(lengths, maxlen=None, dtype=np.float32):
26
  # 返回指定数据类型的掩码
27
  return mask.astype(dtype)[None, ...]
28
 
29
- def unique_consecutive_np(x, dim=None, return_inverse=False, return_counts=False):
30
- if dim is None:
31
- # 默认情况,展平后去重
32
- x_flat = x.ravel()
33
- mask = np.concatenate(([True], x_flat[1:] != x_flat[:-1]))
34
- unique_data = x_flat[mask]
35
- else:
36
- # 沿着指定维度去重
37
- axis = dim if dim >= 0 else x.ndim + dim
38
- if axis >= x.ndim:
39
- raise ValueError(f"dim {dim} is out of range for array of dimension {x.ndim}")
40
-
41
- # 使用 np.diff 检查相邻元素是否相同
42
- mask = np.ones(x.shape[axis], dtype=bool)
43
- if x.shape[axis] > 1:
44
- # 比较当前元素和前一个元素是否不同
45
- diff = np.diff(x, axis=axis)
46
- mask[1:] = np.any(diff != 0, axis=tuple(range(diff.ndim))[axis:])
47
-
48
- # 使用 mask 索引提取唯一元素
49
- unique_data = np.take(x, np.where(mask)[0], axis=axis)
50
-
51
- # 处理 return_inverse 和 return_counts
52
- results = (unique_data,)
53
-
54
- if return_inverse:
55
- if dim is None:
56
- inv_idx = np.cumsum(mask) - 1
57
- else:
58
- inv_idx = np.cumsum(mask) - 1
59
- # 需要调整形状以匹配输入
60
- inv_idx = np.expand_dims(inv_idx, axis=axis)
61
- inv_idx = np.broadcast_to(inv_idx, x.shape)
62
- results += (inv_idx,)
63
-
64
- if return_counts:
65
- if dim is None:
66
- counts = np.diff(np.where(np.concatenate((mask, [True])))[0])
67
- else:
68
- counts = np.diff(np.where(np.concatenate((mask, [True])))[0])
69
- results += (counts,)
70
 
71
- return results[0] if len(results) == 1 else results
 
 
 
 
 
 
72
 
73
 
74
  def longest_common_suffix_prefix_with_tolerance(
@@ -100,7 +84,7 @@ def longest_common_suffix_prefix_with_tolerance(
100
  return 0
101
 
102
  class SenseVoiceAx:
103
- def __init__(self, model_path, max_len=68, language="auto", use_itn=True, tokenizer=None):
104
  model_path_root = os.path.join(os.path.dirname(model_path), "..")
105
  embedding_root = os.path.join(model_path_root, "embeddings")
106
  self.frontend = WavFrontend(cmvn_file=f"{model_path_root}/am.mvn",
@@ -125,12 +109,29 @@ class SenseVoiceAx:
125
  self.emo_dict = {"unk": 25009, "happy": 25001, "sad": 25002, "angry": 25003, "neutral": 25004}
126
 
127
  self.position_encoding = np.load(f"{embedding_root}/position_encoding.npy")
128
- language_query = np.load(f"{embedding_root}/{language}.npy")
129
- textnorm_query = np.load(f"{embedding_root}/withitn.npy") if use_itn else np.load(f"{embedding_root}/woitn.npy")
130
- event_emo_query = np.load(f"{embedding_root}/event_emo.npy")
131
- self.input_query = np.concatenate((textnorm_query, language_query, event_emo_query), axis=1)
132
  self.query_num = self.input_query.shape[1]
133
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  def load_data(self, filepath: str) -> np.ndarray:
135
  waveform, _ = librosa.load(filepath, sr=self.sample_rate)
136
  return waveform.flatten()
@@ -165,7 +166,7 @@ class SenseVoiceAx:
165
  yseq = np.argmax(x, axis=-1)
166
 
167
  # 去除连续重复元素
168
- yseq = unique_consecutive_np(yseq, dim=-1)
169
 
170
  # 创建掩码并过滤 blank_id
171
  mask = yseq != self.blank_id
@@ -173,14 +174,16 @@ class SenseVoiceAx:
173
 
174
  return token_int
175
 
176
- def infer_waveform(self, waveform: np.ndarray):
 
 
 
177
  feat, feat_len = self.preprocess(waveform)
178
 
179
  slice_len = self.max_len - self.query_num
180
  slice_num = int(np.ceil(feat.shape[1] / slice_len))
181
 
182
  asr_res = []
183
- prev_token_int = None
184
  for i in range(slice_num):
185
  if i == 0:
186
  sub_feat = feat[:, i*slice_len:(i+1)*slice_len, :]
@@ -205,20 +208,14 @@ class SenseVoiceAx:
205
 
206
  token_int = self.postprocess(ctc_logits, encoder_out_lens)
207
 
208
- # common prefix
209
- if self.padding > 0 and prev_token_int is not None:
210
- # prefix_len = common_prefix_len(prev_token_int, token_int)
211
- prefix_len = longest_common_suffix_prefix_with_tolerance(prev_token_int, token_int, 6)
212
- common_prefix = rich_transcription_postprocess(self.tokenizer.tokens2text(token_int[:prefix_len]))
213
-
214
- asr_res[-1] = asr_res[-1][:-len(common_prefix)]
215
- prev_token_int = np.copy(token_int)
216
-
217
- asr_res.append(self.tokenizer.tokens2text(token_int))
218
 
219
  return asr_res
220
 
221
- def infer(self, filepath_or_data: Union[np.ndarray, str], print_rtf=True):
222
  if isinstance(filepath_or_data, str):
223
  waveform = self.load_data(filepath_or_data)
224
  else:
@@ -227,7 +224,7 @@ class SenseVoiceAx:
227
  total_time = waveform.shape[-1] / self.sample_rate
228
 
229
  start = time.time()
230
- asr_res = self.infer_waveform(waveform)
231
  latency = time.time() - start
232
 
233
  if print_rtf:
 
26
  # 返回指定数据类型的掩码
27
  return mask.astype(dtype)[None, ...]
28
 
29
+ def unique_consecutive_np(arr):
30
+ """
31
+ 找出数组中连续的唯一值,模拟 torch.unique_consecutive(yseq, dim=-1)
32
+
33
+ 参数:
34
+ arr: 一维numpy数组
35
+
36
+ 返回:
37
+ unique_values: 去除连续重复值后的数组
38
+ """
39
+ if len(arr) == 0:
40
+ return np.array([])
41
+
42
+ if len(arr) == 1:
43
+ return arr.copy()
44
+
45
+ # 找出变化的位置
46
+ diff = np.diff(arr)
47
+ change_positions = np.where(diff != 0)[0] + 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
+ # 添加起始位置
50
+ start_positions = np.concatenate(([0], change_positions))
51
+
52
+ # 获取唯一值(每个连续段的第一个值)
53
+ unique_values = arr[start_positions]
54
+
55
+ return unique_values
56
 
57
 
58
  def longest_common_suffix_prefix_with_tolerance(
 
84
  return 0
85
 
86
  class SenseVoiceAx:
87
+ def __init__(self, model_path, max_len=256, language="auto", use_itn=True, tokenizer=None):
88
  model_path_root = os.path.join(os.path.dirname(model_path), "..")
89
  embedding_root = os.path.join(model_path_root, "embeddings")
90
  self.frontend = WavFrontend(cmvn_file=f"{model_path_root}/am.mvn",
 
109
  self.emo_dict = {"unk": 25009, "happy": 25001, "sad": 25002, "angry": 25003, "neutral": 25004}
110
 
111
  self.position_encoding = np.load(f"{embedding_root}/position_encoding.npy")
112
+ self.language_query = np.load(f"{embedding_root}/{language}.npy")
113
+ self.textnorm_query = np.load(f"{embedding_root}/withitn.npy") if use_itn else np.load(f"{embedding_root}/woitn.npy")
114
+ self.event_emo_query = np.load(f"{embedding_root}/event_emo.npy")
115
+ self.input_query = np.concatenate((self.textnorm_query, self.language_query, self.event_emo_query), axis=1)
116
  self.query_num = self.input_query.shape[1]
117
 
118
+ self.model_path_root = model_path_root
119
+ self.embedding_root = embedding_root
120
+ self.language = language
121
+
122
+ @property
123
+ def language_options(self):
124
+ return list(self.lid_dict.keys())
125
+
126
+ @property
127
+ def textnorm_options(self):
128
+ return list(self.textnorm_dict.keys())
129
+
130
+ def choose_language(self, language):
131
+ self.language_query = np.load(f"{self.embedding_root}/{language}.npy")
132
+ self.input_query = np.concatenate((self.textnorm_query, self.language_query, self.event_emo_query), axis=1)
133
+ self.language = language
134
+
135
  def load_data(self, filepath: str) -> np.ndarray:
136
  waveform, _ = librosa.load(filepath, sr=self.sample_rate)
137
  return waveform.flatten()
 
166
  yseq = np.argmax(x, axis=-1)
167
 
168
  # 去除连续重复元素
169
+ yseq = unique_consecutive_np(yseq)
170
 
171
  # 创建掩码并过滤 blank_id
172
  mask = yseq != self.blank_id
 
174
 
175
  return token_int
176
 
177
+ def infer_waveform(self, waveform: np.ndarray, language="auto"):
178
+ if language != self.language:
179
+ self.choose_language(language)
180
+
181
  feat, feat_len = self.preprocess(waveform)
182
 
183
  slice_len = self.max_len - self.query_num
184
  slice_num = int(np.ceil(feat.shape[1] / slice_len))
185
 
186
  asr_res = []
 
187
  for i in range(slice_num):
188
  if i == 0:
189
  sub_feat = feat[:, i*slice_len:(i+1)*slice_len, :]
 
208
 
209
  token_int = self.postprocess(ctc_logits, encoder_out_lens)
210
 
211
+ if self.tokenizer is not None:
212
+ asr_res.append(self.tokenizer.tokens2text(token_int))
213
+ else:
214
+ asr_res.append(token_int)
 
 
 
 
 
 
215
 
216
  return asr_res
217
 
218
+ def infer(self, filepath_or_data: Union[np.ndarray, str], language="auto", print_rtf=True):
219
  if isinstance(filepath_or_data, str):
220
  waveform = self.load_data(filepath_or_data)
221
  else:
 
224
  total_time = waveform.shape[-1] / self.sample_rate
225
 
226
  start = time.time()
227
+ asr_res = self.infer_waveform(waveform, language)
228
  latency = time.time() - start
229
 
230
  if print_rtf:
embeddings/position_encoding.npy CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:406a92d1305e9ddd5e7538e0a5849ca3128a1922970acdf75ee9d953e6983850
3
- size 152448
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0f1c9c550bd62fa164a959517f52d46a28591812fafdf002df0df2bd998f44b5
3
+ size 573568
gradio_demo.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ from SenseVoiceAx import SenseVoiceAx
4
+ from tokenizer import SentencepiecesTokenizer
5
+ from print_utils import rich_transcription_postprocess
6
+ from download_utils import download_model
7
+
8
+ use_itn = True # 标点符号预测
9
+ max_len = 68
10
+
11
+ model_path_root = download_model("SenseVoice")
12
+ model_path = os.path.join(model_path_root, "sensevoice_ax650", "sensevoice.axmodel")
13
+ bpemodel = os.path.join(model_path_root, "chn_jpn_yue_eng_ko_spectok.bpe.model")
14
+
15
+ assert os.path.exists(model_path), f"model {model_path} not exist"
16
+
17
+ tokenizer = SentencepiecesTokenizer(bpemodel=bpemodel)
18
+ pipeline = SenseVoiceAx(model_path,
19
+ max_len=max_len,
20
+ language="auto",
21
+ use_itn=use_itn,
22
+ tokenizer=tokenizer)
23
+ # 你实现的语言转文本函数
24
+ def speech_to_text(audio_path, lang):
25
+ """
26
+ audio_path: 音频文件路径
27
+ lang: 语言类型 "auto", "zh", "en", "yue", "ja", "ko"
28
+ """
29
+ if not audio_path:
30
+ return "无音频"
31
+
32
+ pipeline.choose_language(language=lang)
33
+ asr_res = pipeline.infer(audio_path, print_rtf=True)
34
+ res = " ".join([rich_transcription_postprocess(i) for i in asr_res])
35
+ # TODO: 这里写你的语音识别逻辑
36
+ # 返回一个示例文本
37
+ return res
38
+
39
+
40
+ def main():
41
+ with gr.Blocks() as demo:
42
+ with gr.Row():
43
+ output_text = gr.Textbox(
44
+ label="识别结果",
45
+ lines=5
46
+ )
47
+
48
+
49
+ with gr.Row():
50
+ audio_input = gr.Audio(
51
+ sources=["microphone"],
52
+ type="filepath",
53
+ label="录制或上传音频",
54
+ format="mp3"
55
+ )
56
+ lang_dropdown = gr.Dropdown(
57
+ choices=["auto", "zh", "en", "yue", "ja", "ko"],
58
+ value="auto",
59
+ label="选择音频语言"
60
+ )
61
+
62
+
63
+
64
+
65
+ audio_input.change(
66
+ fn=speech_to_text,
67
+ inputs=[audio_input, lang_dropdown],
68
+ outputs=output_text
69
+ )
70
+
71
+ demo.launch(
72
+ server_name="0.0.0.0",
73
+ server_port=7860,
74
+ ssl_certfile="./cert.pem", ssl_keyfile="./key.pem", ssl_verify=False
75
+ )
76
+
77
+ if __name__ == "__main__":
78
+ main()
requirements.txt CHANGED
@@ -2,4 +2,7 @@ huggingface_hub
2
  numpy<2
3
  kaldi-native-fbank
4
  librosa==0.9.1
5
- sentencepiece
 
 
 
 
2
  numpy<2
3
  kaldi-native-fbank
4
  librosa==0.9.1
5
+ sentencepiece
6
+ fastapi
7
+ gradio
8
+ emoji
sensevoice_ax650/sensevoice.axmodel CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:bd4f1df559d3788c2873eccad31a4e58260a1342a0cdacdad959b324fb155974
3
- size 261965288
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fad2f710930c23c91ea62d6951c0c6161194e3cf356fc31611798419c6638dd9
3
+ size 262381979
server.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from fastapi import FastAPI, HTTPException, Body
3
+ from fastapi.responses import JSONResponse
4
+ from typing import List, Optional
5
+ import logging
6
+ import json
7
+ from SenseVoiceAx import SenseVoiceAx
8
+ from tokenizer import SentencepiecesTokenizer
9
+ from print_utils import rich_transcription_postprocess, rich_print_asr_res
10
+ from download_utils import download_model
11
+ import os
12
+ import librosa
13
+
14
+ # 初始化日志
15
+ logging.basicConfig(level=logging.INFO)
16
+ logger = logging.getLogger(__name__)
17
+
18
+ app = FastAPI(title="ASR Server", description="Automatic Speech Recognition API")
19
+
20
+ # 全局变量存储模型
21
+ asr_model = None
22
+
23
+ @app.on_event("startup")
24
+ async def load_model():
25
+ """
26
+ 服务启动时加载ASR模型
27
+ """
28
+ global asr_model
29
+ logger.info("Loading ASR model...")
30
+
31
+ try:
32
+ # 模型加载
33
+ language = "auto"
34
+ use_itn = True # 标点符号预测
35
+ max_len = 68
36
+
37
+ model_path_root = download_model("SenseVoice")
38
+ model_path = os.path.join(model_path_root, "sensevoice_ax650", "sensevoice.axmodel")
39
+ bpemodel = os.path.join(model_path_root, "chn_jpn_yue_eng_ko_spectok.bpe.model")
40
+
41
+ assert os.path.exists(model_path), f"model {model_path} not exist"
42
+
43
+ print(f"language: {language}")
44
+ print(f"use_itn: {use_itn}")
45
+ print(f"model_path: {model_path}")
46
+
47
+ tokenizer = SentencepiecesTokenizer(bpemodel=bpemodel)
48
+ asr_model = SenseVoiceAx(model_path,
49
+ max_len=max_len,
50
+ language=language,
51
+ use_itn=use_itn,
52
+ tokenizer=tokenizer)
53
+
54
+ logger.info("ASR model loaded successfully")
55
+ except Exception as e:
56
+ logger.error(f"Failed to load ASR model: {str(e)}")
57
+ raise
58
+
59
+ def validate_audio_data(audio_data: List[float]) -> np.ndarray:
60
+ """
61
+ 验证并转换音频数据为numpy数组
62
+
63
+ 参数:
64
+ - audio_data: 浮点数列表表示的音频数据
65
+
66
+ 返回:
67
+ - 验证后的numpy数组
68
+ """
69
+ try:
70
+ # 转换为numpy数组
71
+ np_array = np.array(audio_data, dtype=np.float32)
72
+
73
+ # 验证数据有效性
74
+ if np_array.ndim != 1:
75
+ raise ValueError("Audio data must be 1-dimensional")
76
+
77
+ if len(np_array) == 0:
78
+ raise ValueError("Audio data cannot be empty")
79
+
80
+ return np_array
81
+ except Exception as e:
82
+ raise ValueError(f"Invalid audio data: {str(e)}")
83
+
84
+ @app.get("/get_language", summary="Get current language")
85
+ async def get_language():
86
+ return JSONResponse(content={"language": asr_model.language})
87
+
88
+ @app.get("/get_language_options", summary="Get possible language options, possible options include [auto, zh, en, yue, ja, ko]")
89
+ async def get_language_options():
90
+ return JSONResponse(content={"language_options": asr_model.language_options})
91
+
92
+ @app.post("/asr", summary="Recognize speech from numpy audio data")
93
+ async def recognize_speech(
94
+ audio_data: List[float] = Body(..., embed=True, description="Audio data as list of floats"),
95
+ sample_rate: Optional[int] = Body(16000, description="Audio sample rate in Hz"),
96
+ language: Optional[str] = Body("auto", description="Language")
97
+ ):
98
+ """
99
+ 接收numpy数组格式的音频数据并返回识别结果
100
+
101
+ 参数:
102
+ - audio_data: 浮点数列表表示的音频数据
103
+ - sample_rate: 音频采样率(默认16000Hz)
104
+
105
+ 返回:
106
+ - JSON包含识别文本
107
+ """
108
+ try:
109
+ # 检查模型是否已加载
110
+ if asr_model is None:
111
+ raise HTTPException(status_code=503, detail="ASR model not loaded")
112
+
113
+ logger.info(f"Received audio data with length: {len(audio_data)}")
114
+
115
+ # 验证并转换数据
116
+ np_audio = validate_audio_data(audio_data)
117
+ if sample_rate != asr_model.sample_rate:
118
+ np_audio = librosa.resample(np_audio, sample_rate, asr_model.sample_rate)
119
+
120
+ # 调用模型进行识别
121
+ result = asr_model.infer_waveform(np_audio, language)
122
+
123
+ return JSONResponse(content={"text": result})
124
+
125
+ except ValueError as e:
126
+ logger.error(f"Validation error: {str(e)}")
127
+ raise HTTPException(status_code=400, detail=str(e))
128
+ except Exception as e:
129
+ logger.error(f"Recognition error: {str(e)}")
130
+ raise HTTPException(status_code=500, detail=str(e))
131
+
132
+ if __name__ == "__main__":
133
+ import uvicorn
134
+ uvicorn.run(app, host="0.0.0.0", port=8000)
test_wer.py CHANGED
@@ -4,73 +4,267 @@ from SenseVoiceAx import SenseVoiceAx
4
  from tokenizer import SentencepiecesTokenizer
5
  from print_utils import rich_transcription_postprocess, rich_print_asr_res
6
  from download_utils import download_model
7
- import jiwer
 
 
8
 
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  def get_args():
11
  parser = argparse.ArgumentParser()
12
- parser.add_argument("--dataset", "-d", required=True, type=str, help="Input dataset")
 
13
  parser.add_argument("--language", "-l", required=False, type=str, default="auto", choices=["auto", "zh", "en", "yue", "ja", "ko"])
 
14
  return parser.parse_args()
15
 
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  def main():
 
18
  args = get_args()
19
 
20
- dataset = args.dataset
21
  language = args.language
22
  use_itn = False # 标点符号预测
 
 
 
 
 
 
 
 
 
23
 
24
  model_path_root = download_model("SenseVoice")
25
- model_path = os.path.join(model_path_root, "sensevoice_ax650", "sensevoice.axmodel")
 
26
  bpemodel = os.path.join(model_path_root, "chn_jpn_yue_eng_ko_spectok.bpe.model")
27
 
28
  assert os.path.exists(model_path), f"model {model_path} not exist"
29
 
30
- print(f"dataset: {dataset}")
31
- print(f"language: {language}")
32
- print(f"use_itn: {use_itn}")
33
- print(f"model_path: {model_path}")
34
 
35
  tokenizer = SentencepiecesTokenizer(bpemodel=bpemodel)
36
- pipeline = SenseVoiceAx(model_path, language, use_itn, tokenizer=tokenizer)
37
-
38
- # Load dataset
39
- wav_names = []
40
- references = []
41
- with open(os.path.join(dataset, "ground_truth.txt"), "r") as f:
42
- for line in f:
43
- line = line.strip()
44
- w, r = line.split(" ")
45
- wav_names.append(w)
46
- references.append(r)
47
 
48
  # Iterate over dataset
49
  hyp = []
50
- wer_file = open("wer.txt", "w")
51
- for wav_name, reference in zip(wav_names, references):
52
- wav_path = os.path.join(dataset, "aishell_S0764", wav_name + ".wav")
 
 
 
53
 
54
- asr_res = pipeline.infer(wav_path, print_rtf=False)
55
- hypothesis = rich_print_asr_res(asr_res, will_print=False, remove_punc=True)
56
- hyp.append(hypothesis)
 
 
 
 
57
 
58
- wer = jiwer.cer(
59
- reference,
60
- hypothesis
61
- )
 
62
 
63
- line_content = f"{wav_name} reference: {reference} hypothesis: {hypothesis} WER: {wer}"
64
- wer_file.write(line_content + "\n")
65
- print(line_content)
66
-
67
- total_wer = jiwer.cer(
68
- references,
69
- hyp
70
- )
71
- print(f"Total WER: {total_wer}")
72
- wer_file.write(f"Total WER: {total_wer}")
73
- wer_file.close()
74
 
75
  if __name__ == "__main__":
76
  main()
 
4
  from tokenizer import SentencepiecesTokenizer
5
  from print_utils import rich_transcription_postprocess, rich_print_asr_res
6
  from download_utils import download_model
7
+ import logging
8
+ import re
9
+ import emoji
10
 
11
 
12
+ def setup_logging():
13
+ """配置日志系统,同时输出到控制台和文件"""
14
+ # 获取脚本所在目录
15
+ script_dir = os.path.dirname(os.path.abspath(__file__))
16
+ log_file = os.path.join(script_dir, "test_wer.log")
17
+
18
+ # 配置日志格式
19
+ log_format = '%(asctime)s - %(levelname)s - %(message)s'
20
+ date_format = '%Y-%m-%d %H:%M:%S'
21
+
22
+ # 创建logger
23
+ logger = logging.getLogger()
24
+ logger.setLevel(logging.INFO)
25
+
26
+ # 清除现有的handler
27
+ for handler in logger.handlers[:]:
28
+ logger.removeHandler(handler)
29
+
30
+ # 创建文件handler
31
+ file_handler = logging.FileHandler(log_file, mode='w', encoding='utf-8')
32
+ file_handler.setLevel(logging.INFO)
33
+ file_formatter = logging.Formatter(log_format, date_format)
34
+ file_handler.setFormatter(file_formatter)
35
+
36
+ # 创建控制台handler
37
+ console_handler = logging.StreamHandler()
38
+ console_handler.setLevel(logging.INFO)
39
+ console_formatter = logging.Formatter(log_format, date_format)
40
+ console_handler.setFormatter(console_formatter)
41
+
42
+ # 添加handler到logger
43
+ logger.addHandler(file_handler)
44
+ logger.addHandler(console_handler)
45
+
46
+ return logger
47
+
48
+
49
+ class AIShellDataset:
50
+ def __init__(self, gt_path: str):
51
+ """
52
+ 初始化数据集
53
+
54
+ Args:
55
+ json_path: voice.json文件的路径
56
+ """
57
+ self.gt_path = gt_path
58
+ self.dataset_dir = os.path.dirname(gt_path)
59
+ self.voice_dir = os.path.join(self.dataset_dir, "aishell_S0764")
60
+
61
+ # 检查必要文件和文件夹是否存在
62
+ assert os.path.exists(gt_path), f"gt文件不存在: {gt_path}"
63
+ assert os.path.exists(self.voice_dir), f"aishell_S0764文件夹不存在: {self.voice_dir}"
64
+
65
+ # 加载数据
66
+ self.data = []
67
+ with open(gt_path, 'r', encoding='utf-8') as f:
68
+ for line in f:
69
+ line = line.strip()
70
+ audio_path, gt = line.split(" ")
71
+ audio_path = os.path.join(self.voice_dir, audio_path + ".wav")
72
+ self.data.append({"audio_path": audio_path, "gt": gt})
73
+
74
+ # 使用logging而不是print
75
+ logger = logging.getLogger()
76
+ logger.info(f"加载了 {len(self.data)} 条数据")
77
+
78
+ def __iter__(self):
79
+ """返回迭代器"""
80
+ self.index = 0
81
+ return self
82
+
83
+ def __next__(self):
84
+ """返回下一个数据项"""
85
+ if self.index >= len(self.data):
86
+ raise StopIteration
87
+
88
+ item = self.data[self.index]
89
+ audio_path = item["audio_path"]
90
+ ground_truth = item["gt"]
91
+
92
+ self.index += 1
93
+ return audio_path, ground_truth
94
+
95
+ def __len__(self):
96
+ """返回数据集大小"""
97
+ return len(self.data)
98
+
99
+
100
+ class CommonVoiceDataset:
101
+ """Common Voice数据集解析器"""
102
+
103
+ def __init__(self, tsv_path: str):
104
+ """
105
+ 初始化数据集
106
+
107
+ Args:
108
+ json_path: voice.json文件的路径
109
+ """
110
+ self.tsv_path = tsv_path
111
+ self.dataset_dir = os.path.dirname(tsv_path)
112
+ self.voice_dir = os.path.join(self.dataset_dir, "clips")
113
+
114
+ # 检查必要文件和文件夹是否存在
115
+ assert os.path.exists(tsv_path), f"{tsv_path}文件不存在: {tsv_path}"
116
+ assert os.path.exists(self.voice_dir), f"voice文件夹不存在: {self.voice_dir}"
117
+
118
+ # 加载JSON数据
119
+ self.data = []
120
+ with open(tsv_path, 'r', encoding='utf-8') as f:
121
+ f.readline()
122
+ for line in f:
123
+ line = line.strip()
124
+ splits = line.split("\t")
125
+ audio_path = splits[1]
126
+ gt = splits[3]
127
+ audio_path = os.path.join(self.voice_dir, audio_path)
128
+ self.data.append({"audio_path": audio_path, "gt": gt})
129
+
130
+ # 使用logging而不是print
131
+ logger = logging.getLogger()
132
+ logger.info(f"加载了 {len(self.data)} 条数据")
133
+
134
+ def __iter__(self):
135
+ """返回迭代器"""
136
+ self.index = 0
137
+ return self
138
+
139
+ def __next__(self):
140
+ """返回下一个数据项"""
141
+ if self.index >= len(self.data):
142
+ raise StopIteration
143
+
144
+ item = self.data[self.index]
145
+ audio_path = item["audio_path"]
146
+ ground_truth = item["gt"]
147
+
148
+ self.index += 1
149
+ return audio_path, ground_truth
150
+
151
+ def __len__(self):
152
+ """返回数据集大小"""
153
+ return len(self.data)
154
+
155
+
156
  def get_args():
157
  parser = argparse.ArgumentParser()
158
+ parser.add_argument("--dataset", "-d", type=str, required=True, choices=["aishell", "common_voice"], help="Test dataset")
159
+ parser.add_argument("--gt_path", "-g", type=str, required=True, help="Test dataset ground truth file")
160
  parser.add_argument("--language", "-l", required=False, type=str, default="auto", choices=["auto", "zh", "en", "yue", "ja", "ko"])
161
+ parser.add_argument("--max_num", type=int, default=-1, required=False, help="Maximum test data num")
162
  return parser.parse_args()
163
 
164
 
165
+ def min_distance(word1: str, word2: str) -> int:
166
+
167
+ row = len(word1) + 1
168
+ column = len(word2) + 1
169
+
170
+ cache = [ [0]*column for i in range(row) ]
171
+
172
+ for i in range(row):
173
+ for j in range(column):
174
+
175
+ if i ==0 and j ==0:
176
+ cache[i][j] = 0
177
+ elif i == 0 and j!=0:
178
+ cache[i][j] = j
179
+ elif j == 0 and i!=0:
180
+ cache[i][j] = i
181
+ else:
182
+ if word1[i-1] == word2[j-1]:
183
+ cache[i][j] = cache[i-1][j-1]
184
+ else:
185
+ replace = cache[i-1][j-1] + 1
186
+ insert = cache[i][j-1] + 1
187
+ remove = cache[i-1][j] + 1
188
+
189
+ cache[i][j] = min(replace, insert, remove)
190
+
191
+ return cache[row-1][column-1]
192
+
193
+
194
+ def remove_punctuation(text):
195
+ # 定义正则表达式模式,匹配所有标点符号
196
+ # 这个模式包括常见的标点符号和中文标点
197
+ pattern = r'[^\w\s]|_'
198
+
199
+ # 使用sub方法将所有匹配的标点符号替换为空字符串
200
+ cleaned_text = re.sub(pattern, '', text)
201
+
202
+ return cleaned_text
203
+
204
+
205
  def main():
206
+ logger = setup_logging()
207
  args = get_args()
208
 
 
209
  language = args.language
210
  use_itn = False # 标点符号预测
211
+ max_num = args.max_num
212
+
213
+ dataset_type = args.dataset.lower()
214
+ if dataset_type == "aishell":
215
+ dataset = AIShellDataset(args.gt_path)
216
+ elif dataset_type == "common_voice":
217
+ dataset = CommonVoiceDataset(args.gt_path)
218
+ else:
219
+ raise ValueError(f"Unknown dataset type {dataset_type}")
220
 
221
  model_path_root = download_model("SenseVoice")
222
+ # model_path = os.path.join(model_path_root, "sensevoice_ax650", "sensevoice.axmodel")
223
+ model_path = "./model_convert/output_dir/model.onnx"
224
  bpemodel = os.path.join(model_path_root, "chn_jpn_yue_eng_ko_spectok.bpe.model")
225
 
226
  assert os.path.exists(model_path), f"model {model_path} not exist"
227
 
228
+ logger.info(f"dataset: {args.dataset}")
229
+ logger.info(f"language: {language}")
230
+ logger.info(f"use_itn: {use_itn}")
231
+ logger.info(f"model_path: {model_path}")
232
 
233
  tokenizer = SentencepiecesTokenizer(bpemodel=bpemodel)
234
+ pipeline = SenseVoiceAx(model_path, language=language, use_itn=use_itn, tokenizer=tokenizer, max_len=256)
 
 
 
 
 
 
 
 
 
 
235
 
236
  # Iterate over dataset
237
  hyp = []
238
+ references = []
239
+ all_character_error_num = 0
240
+ all_character_num = 0
241
+ max_data_num = max_num if max_num > 0 else len(dataset)
242
+ for n, (audio_path, reference) in enumerate(dataset):
243
+ reference = remove_punctuation(reference).lower()
244
 
245
+ asr_res = pipeline.infer(audio_path, print_rtf=False)
246
+ hypothesis = rich_print_asr_res(asr_res, will_print=False, remove_punc=True).lower()
247
+ hypothesis = emoji.replace_emoji(hypothesis, replace='')
248
+
249
+ character_error_num = min_distance(reference, hypothesis)
250
+ character_num = len(reference)
251
+ character_error_rate = character_error_num / character_num * 100
252
 
253
+ all_character_error_num += character_error_num
254
+ all_character_num += character_num
255
+
256
+ hyp.append(hypothesis)
257
+ references.append(reference)
258
 
259
+ line_content = f"({n+1}/{max_data_num}) {os.path.basename(audio_path)} gt: {reference} predict: {hypothesis} WER: {character_error_rate}%"
260
+ logger.info(line_content)
261
+
262
+ if n + 1 >= max_data_num:
263
+ break
264
+
265
+ total_character_error_rate = all_character_error_num / all_character_num * 100
266
+
267
+ logger.info(f"Total WER: {total_character_error_rate}%")
 
 
268
 
269
  if __name__ == "__main__":
270
  main()