inoryQwQ committed on
Commit
ff70a9c
·
1 Parent(s): e138696

1. move python scripts to new folder python/

Browse files

2. add cpp prebuilt binary and library
3. Update README for cpp demo

.gitattributes CHANGED
@@ -36,4 +36,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
36
  sensevoice.axmodel filter=lfs diff=lfs merge=lfs -text
37
  *.axmodel filter=lfs diff=lfs merge=lfs -text
38
  *.wav filter=lfs diff=lfs merge=lfs -text
39
- *.mp3 filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
36
  sensevoice.axmodel filter=lfs diff=lfs merge=lfs -text
37
  *.axmodel filter=lfs diff=lfs merge=lfs -text
38
  *.wav filter=lfs diff=lfs merge=lfs -text
39
+ *.mp3 filter=lfs diff=lfs merge=lfs -text
40
+ cpp/ax650/lib/libax_asr_api.so filter=lfs diff=lfs merge=lfs -text
41
+ cpp/ax630c/lib/libax_asr_api.so filter=lfs diff=lfs merge=lfs -text
42
+ cpp/ax650/test_sensevoice filter=lfs diff=lfs merge=lfs -text
43
+ cpp/ax630c/test_sensevoice filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -4,13 +4,13 @@ language:
4
  - en
5
  pipeline_tag: automatic-speech-recognition
6
  ---
7
- # sensevoice.axera
8
  FunASR SenseVoice on Axera, official repo: https://github.com/FunAudioLLM/SenseVoice
9
 
10
  ## TODO
11
 
12
  - [x] 支持AX630C
13
- - [ ] 支持C++
14
  - [x] 支持FastAPI
15
 
16
  ## 功能
@@ -25,6 +25,13 @@ FunASR SenseVoice on Axera, official repo: https://github.com/FunAudioLLM/SenseV
25
  - [x] AX650N
26
  - [x] AX630C
27
 
 
 
 
 
 
 
 
28
 
29
  ## 环境安装
30
 
@@ -60,6 +67,9 @@ pip install https://github.com/AXERA-TECH/pyaxengine/releases/download/0.1.3.rc2
60
  安装,或把版本号更改为你想使用的版本
61
 
62
  ## 使用
 
 
 
63
  ```
64
  # 首次运行会自动从huggingface上下载模型, 保存到models中
65
  python3 main.py -i 输入音频文件
@@ -71,6 +81,20 @@ python3 main.py -i 输入音频文件
71
  | --language/-l | 识别语言,支持auto, zh, en, yue, ja, ko | auto |
72
  | --streaming | 流式识别 | |
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  ### 示例:
76
  example下有测试音频
 
4
  - en
5
  pipeline_tag: automatic-speech-recognition
6
  ---
7
+ # SenseVoice
8
  FunASR SenseVoice on Axera, official repo: https://github.com/FunAudioLLM/SenseVoice
9
 
10
  ## TODO
11
 
12
  - [x] 支持AX630C
13
+ - [x] 支持C++
14
  - [x] 支持FastAPI
15
 
16
  ## 功能
 
25
  - [x] AX650N
26
  - [x] AX630C
27
 
28
+ ## Table of contents
29
+
30
+ - [环境安装](#环境安装)
31
+ - [使用](#使用)
32
+ - [准确率](#准确率)
33
+ - [技术讨论](#技术讨论)
34
+
35
 
36
  ## 环境安装
37
 
 
67
  安装,或把版本号更改为你想使用的版本
68
 
69
  ## 使用
70
+
71
+ ### Python
72
+
73
  ```
74
  # 首次运行会自动从huggingface上下载模型, 保存到models中
75
  python3 main.py -i 输入音频文件
 
81
  | --language/-l | 识别语言,支持auto, zh, en, yue, ja, ko | auto |
82
  | --streaming | 流式识别 | |
83
 
84
+ ### CPP
85
+
86
+ - AX650
87
+ ```
88
+ ./cpp/ax650/test_sensevoice -a example/zh.mp3 -p sensevoice_ax650/
89
+ ```
90
+
91
+ - AX630C
92
+ ```
93
+ ./cpp/ax630c/test_sensevoice -a example/zh.mp3 -p sensevoice_ax630c/
94
+ ```
95
+
96
+ 对应的源码在[Github](https://github.com/AXERA-TECH/ax_asr_api)上
97
+
98
 
99
  ### 示例:
100
  example下有测试音频
SenseVoiceAx.py DELETED
@@ -1,335 +0,0 @@
1
- import axengine as axe
2
- import numpy as np
3
- import librosa
4
- from frontend import WavFrontend
5
- import time
6
- from typing import List, Union, Optional
7
- from asr_decoder import CTCDecoder
8
- from online_fbank import OnlineFbank
9
- import torch
10
-
11
-
12
- def unique_consecutive(arr):
13
- """
14
- 找出数组中连续的唯一值,模拟 torch.unique_consecutive(yseq, dim=-1)
15
-
16
- 参数:
17
- arr: 一维numpy数组
18
-
19
- 返回:
20
- unique_values: 去除连续重复值后的数组
21
- """
22
- if len(arr) == 0:
23
- return np.array([])
24
-
25
- if len(arr) == 1:
26
- return arr.copy()
27
-
28
- # 找出变化的位置
29
- diff = np.diff(arr)
30
- change_positions = np.where(diff != 0)[0] + 1
31
-
32
- # 添加起始位置
33
- start_positions = np.concatenate(([0], change_positions))
34
-
35
- # 获取唯一值(每个连续段的第一个值)
36
- unique_values = arr[start_positions]
37
-
38
- return unique_values
39
-
40
-
41
- class SenseVoiceAx:
42
- """SenseVoice axmodel runner"""
43
-
44
- def __init__(
45
- self,
46
- model_path: str,
47
- cmvn_file: str,
48
- token_file: str,
49
- bpe_model: str = None,
50
- max_seq_len: int = 256,
51
- beam_size: int = 3,
52
- hot_words: Optional[List[str]] = None,
53
- streaming: bool = False,
54
- providers=["AxEngineExecutionProvider"],
55
- ):
56
- """
57
- Initialize SenseVoiceAx
58
-
59
- Args:
60
- model_path: Path of axmodel
61
- max_len: Fixed shape of input of axmodel
62
- beam_size: Max number of hypos to hold after each decode step
63
- language: Support auto, zh(Chinese), en(English), yue(Cantonese), ja(Japanese), ko(Korean)
64
- hot_words: Words that may fail to recognize,
65
- special words/phrases (aka hotwords) like rare words, personalized information etc.
66
- use_itn: Allow Invert Text Normalization if True,
67
- ITN converts ASR model output into its written form to improve text readability,
68
- For example, the ITN module replaces “one hundred and twenty-three dollars” transcribed by an ASR model with “$123.”
69
- streaming: Processes audio in small segments or "chunks" sequentially and outputs text on the fly.
70
- Use stream_infer method if streaming is true otherwise infer.
71
-
72
- """
73
-
74
- self.streaming = streaming
75
-
76
- self.frontend = WavFrontend(
77
- cmvn_file=cmvn_file,
78
- fs=16000,
79
- window="hamming",
80
- n_mels=80,
81
- frame_length=25,
82
- frame_shift=10,
83
- lfr_m=7,
84
- lfr_n=6,
85
- )
86
-
87
- self.model = axe.InferenceSession(model_path, providers=providers)
88
- self.sample_rate = 16000
89
- self.blank_id = 0
90
- self.max_seq_len = max_seq_len
91
- self.padding = 16
92
- self.input_size = 560
93
- self.query_num = 4
94
- self.tokens = self.load_tokens(token_file)
95
-
96
- self.lid_dict = {
97
- "auto": 0,
98
- "zh": 3,
99
- "en": 4,
100
- "yue": 7,
101
- "ja": 11,
102
- "ko": 12,
103
- "nospeech": 13,
104
- }
105
-
106
- # decoder
107
- if beam_size > 1 and hot_words is not None:
108
- self.beam_size = beam_size
109
- symbol_table = {}
110
- for i in range(len(self.tokens)):
111
- symbol_table[self.tokens[i]] = i
112
- self.decoder = CTCDecoder(hot_words, symbol_table, bpe_model)
113
- else:
114
- self.beam_size = 1
115
- self.decoder = CTCDecoder()
116
-
117
- if streaming:
118
- self.cur_idx = -1
119
- self.chunk_size = max_seq_len - self.padding
120
- self.caches_shape = (max_seq_len, self.input_size)
121
- self.caches = np.zeros(self.caches_shape, dtype=np.float32)
122
- self.zeros = np.zeros((1, self.input_size), dtype=np.float32)
123
- self.neg_mean, self.inv_stddev = (
124
- self.frontend.cmvn[0, :],
125
- self.frontend.cmvn[1, :],
126
- )
127
-
128
- self.fbank = OnlineFbank(window_type="hamming")
129
- self.stream_mask = self.sequence_mask(
130
- max_seq_len + self.query_num, max_seq_len + self.query_num
131
- )
132
-
133
- def load_tokens(self, token_file):
134
- tokens = []
135
- with open(token_file, "r") as f:
136
- for line in f:
137
- tokens.append(line[:-1])
138
- return tokens
139
-
140
- @property
141
- def language_options(self):
142
- return list(self.lid_dict.keys())
143
-
144
- def sequence_mask(self, max_seq_len, actual_seq_len):
145
- mask = np.zeros((1, 1, max_seq_len), dtype=np.int32)
146
- mask[:, :, :actual_seq_len] = 1
147
- return mask
148
-
149
- def load_data(self, filepath: str) -> np.ndarray:
150
- waveform, _ = librosa.load(filepath, sr=self.sample_rate)
151
- return waveform.flatten()
152
-
153
- @staticmethod
154
- def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
155
- def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
156
- pad_width = ((0, max_feat_len - cur_len), (0, 0))
157
- return np.pad(feat, pad_width, "constant", constant_values=0)
158
-
159
- feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
160
- feats = np.array(feat_res).astype(np.float32)
161
- return feats
162
-
163
- def preprocess(self, waveform):
164
- feats, feats_len = [], []
165
- for wf in [waveform]:
166
- speech, _ = self.frontend.fbank(wf)
167
- feat, feat_len = self.frontend.lfr_cmvn(speech)
168
- feats.append(feat)
169
- feats_len.append(feat_len)
170
-
171
- feats = self.pad_feats(feats, np.max(feats_len))
172
- feats_len = np.array(feats_len).astype(np.int32)
173
- return feats, feats_len
174
-
175
- def postprocess(self, ctc_logits, encoder_out_lens):
176
- # 提取数据
177
- x = ctc_logits[0, 4 : encoder_out_lens[0], :]
178
-
179
- # 获取最大值索引
180
- yseq = np.argmax(x, axis=-1)
181
-
182
- # 去除连续重复元素
183
- yseq = unique_consecutive(yseq)
184
-
185
- # 创建掩码并过滤 blank_id
186
- mask = yseq != self.blank_id
187
- token_int = yseq[mask].tolist()
188
-
189
- return token_int
190
-
191
- def infer_waveform(self, waveform: np.ndarray, language="auto"):
192
- # start = time.time()
193
- feat, feat_len = self.preprocess(waveform)
194
- # print(f"Preprocess take {time.time() - start}s")
195
-
196
- slice_len = self.max_seq_len - self.query_num
197
- slice_num = int(np.ceil(feat.shape[1] / slice_len))
198
-
199
- language_token = self.lid_dict[language]
200
- language_token = np.array([language_token], dtype=np.int32)
201
-
202
- asr_res = []
203
- for i in range(slice_num):
204
- if i == 0:
205
- sub_feat = feat[:, i * slice_len : (i + 1) * slice_len, :]
206
- else:
207
- sub_feat = feat[
208
- :,
209
- i * slice_len - self.padding : (i + 1) * slice_len - self.padding,
210
- :,
211
- ]
212
-
213
- real_len = sub_feat.shape[1]
214
- if real_len < self.max_seq_len:
215
- sub_feat = np.concatenate(
216
- [
217
- sub_feat,
218
- np.zeros(
219
- (1, self.max_seq_len - real_len, sub_feat.shape[-1]),
220
- dtype=np.float32,
221
- ),
222
- ],
223
- axis=1,
224
- )
225
-
226
- mask = self.sequence_mask(self.max_seq_len + self.query_num, real_len)
227
-
228
- # start = time.time()
229
- outputs = self.model.run(
230
- None,
231
- {
232
- "speech": sub_feat,
233
- "mask": mask,
234
- "language": language_token,
235
- },
236
- )
237
- ctc_logits, encoder_out_lens = outputs
238
-
239
- token_int = self.postprocess(ctc_logits, encoder_out_lens)
240
-
241
- asr_res.extend(token_int)
242
-
243
- text = "".join([self.tokens[i] for i in asr_res])
244
- return text
245
-
246
- def infer(
247
- self, filepath_or_data: Union[np.ndarray, str], language="auto", print_rtf=False
248
- ):
249
- assert not self.streaming, "This method is for non-streaming model"
250
-
251
- if isinstance(filepath_or_data, str):
252
- waveform = self.load_data(filepath_or_data)
253
- else:
254
- waveform = filepath_or_data
255
-
256
- total_time = waveform.shape[-1] / self.sample_rate
257
-
258
- start = time.time()
259
- asr_res = self.infer_waveform(waveform, language)
260
- latency = time.time() - start
261
-
262
- if print_rtf:
263
- rtf = latency / total_time
264
- print(f"RTF: {rtf} Latency: {latency}s Total length: {total_time}s")
265
- return asr_res
266
-
267
- def decode(self, times, tokens):
268
- times_ms = []
269
- for step, token in zip(times, tokens):
270
- if len(self.tokens[token].strip()) == 0:
271
- continue
272
- times_ms.append(step * 60)
273
- return times_ms, "".join([self.tokens[i] for i in tokens])
274
-
275
- def reset(self):
276
- self.cur_idx = -1
277
- self.decoder.reset()
278
- self.fbank = OnlineFbank(window_type="hamming")
279
- self.caches = np.zeros(self.caches_shape)
280
-
281
- def get_size(self):
282
- effective_size = self.cur_idx + 1 - self.padding
283
- if effective_size <= 0:
284
- return 0
285
- return effective_size % self.chunk_size or self.chunk_size
286
-
287
- def stream_infer(self, audio, is_last, language="auto"):
288
- assert self.streaming, "This method is for streaming model"
289
-
290
- language_token = self.lid_dict[language]
291
- language_token = np.array([language_token], dtype=np.int32)
292
-
293
- self.fbank.accept_waveform(audio, is_last)
294
- features = self.fbank.get_lfr_frames(
295
- neg_mean=self.neg_mean, inv_stddev=self.inv_stddev
296
- )
297
-
298
- if is_last and len(features) == 0:
299
- features = self.zeros
300
-
301
- for idx, feature in enumerate(features):
302
- is_last = is_last and idx == features.shape[0] - 1
303
- self.caches = np.roll(self.caches, -1, axis=0)
304
- self.caches[-1, :] = feature
305
- self.cur_idx += 1
306
- cur_size = self.get_size()
307
- if cur_size != self.chunk_size and not is_last:
308
- continue
309
-
310
- speech = self.caches[None, ...]
311
- outputs = self.model.run(
312
- None,
313
- {
314
- "speech": speech,
315
- "mask": self.stream_mask,
316
- "language": language_token,
317
- },
318
- )
319
- ctc_logits, encoder_out_lens = outputs
320
- probs = ctc_logits[0, 4 : encoder_out_lens[0]]
321
- probs = torch.from_numpy(probs)
322
-
323
- if cur_size != self.chunk_size:
324
- probs = probs[self.chunk_size - cur_size :]
325
- if not is_last:
326
- probs = probs[: self.chunk_size]
327
- if self.beam_size > 1:
328
- res = self.decoder.ctc_prefix_beam_search(
329
- probs, beam_size=self.beam_size, is_last=is_last
330
- )
331
- times_ms, text = self.decode(res["times"][0], res["tokens"][0])
332
- else:
333
- res = self.decoder.ctc_greedy_search(probs, is_last=is_last)
334
- times_ms, text = self.decode(res["times"], res["tokens"])
335
- yield {"timestamps": times_ms, "text": text}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cpp/ax630c/include/ax_asr_api.h ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /**************************************************************************************************
2
+ *
3
+ * Copyright (c) 2019-2026 Axera Semiconductor (Ningbo) Co., Ltd. All Rights Reserved.
4
+ *
5
+ * This source file is the property of Axera Semiconductor (Ningbo) Co., Ltd. and
6
+ * may not be copied or distributed in any isomorphic form without the prior
7
+ * written consent of Axera Semiconductor (Ningbo) Co., Ltd.
8
+ *
9
+ **************************************************************************************************/
10
+ #ifndef _AX_ASR_API_H_
11
+ #define _AX_ASR_API_H_
12
+
13
+ #ifdef __cplusplus
14
+ extern "C" {
15
+ #endif
16
+
17
+ #define AX_ASR_API __attribute__((visibility("default")))
18
+
19
+
20
+ // Supported asr
21
+ enum AX_ASR_TYPE_E {
22
+ AX_WHISPER_TINY = 0,
23
+ AX_WHISPER_BASE,
24
+ AX_WHISPER_SMALL,
25
+ AX_WHISPER_TURBO,
26
+ AX_SENSEVOICE
27
+ };
28
+
29
+ /**
30
+ * @brief Opaque handle type for the ASR context
31
+ *
32
+ * This handle encapsulates all internal state of the asr ASR system.
33
+ * The actual implementation is hidden from C callers to maintain ABI stability.
34
+ */
35
+ typedef void* AX_ASR_HANDLE;
36
+
37
+ /**
38
+ * @brief Initialize the ASR system with a specific configuration
39
+ *
40
+ * Creates and initializes a new asr ASR context with the specified
41
+ * model type and model path. This function loads the appropriate
42
+ * models, configures the recognizer, and prepares it for speech recognition.
43
+ *
44
+ * @param asr_type Type of ASR model to use (one of AX_ASR_TYPE_E)
45
+ * @param model_path Directory path where model files are stored
46
+ * Model files are expected to be in the format: *.axmodel
47
+ *
48
+ * @return AX_ASR_HANDLE Opaque handle to the initialized asr context,
49
+ * or NULL if initialization fails
50
+ *
51
+ * @note The caller is responsible for calling AX_ASR_Uninit() to free
52
+ * resources when the handle is no longer needed.
53
+ * @example
54
+ * // Initialize recognition with whisper tiny model
55
+ * AX_ASR_HANDLE handle = AX_ASR_Init(AX_WHISPER_TINY, "./models-ax650/");
56
+ *
57
+ */
58
+ AX_ASR_API AX_ASR_HANDLE AX_ASR_Init(AX_ASR_TYPE_E asr_type, const char* model_path);
59
+
60
+ /**
61
+ * @brief Deinitialize and release asr ASR resources
62
+ *
63
+ * Cleans up all resources associated with the asr context, including
64
+ * unloading models, freeing memory, and releasing hardware resources.
65
+ *
66
+ * @param handle asr context handle obtained from AX_ASR_Init()
67
+ *
68
+ * @warning After calling this function, the handle becomes invalid and
69
+ * should not be used in any subsequent API calls.
70
+ */
71
+ AX_ASR_API void AX_ASR_Uninit(AX_ASR_HANDLE handle);
72
+
73
+ /**
74
+ * @brief Perform speech recognition and return dynamically allocated string
75
+ *
76
+ * @param handle asr context handle
77
+ * @param wav_file Path to the input 16k pcmf32 WAV audio file
78
+ * @param language Preferred language,
79
+ * For whisper, check https://whisper-api.com/docs/languages/
80
+ * For sensevoice, support auto, zh, en, yue, ja, ko
81
+ * @param result Pointer to receive the allocated result string
82
+ *
83
+ * @return int Status code (0 = success, <0 = error)
84
+ *
85
+ * @note The returned string is allocated with malloc() and must be freed
86
+ * by the caller using free() when no longer needed.
87
+ */
88
+ AX_ASR_API int AX_ASR_RunFile(AX_ASR_HANDLE handle,
89
+ const char* wav_file,
90
+ const char* language,
91
+ char** result);
92
+
93
+ /**
94
+ * @brief Perform speech recognition and return dynamically allocated string
95
+ *
96
+ * @param handle asr context handle
97
+ * @param pcm_data 16k Mono PCM f32 data, range from -1.0 to 1.0,
98
+ * will be resampled if not 16k
99
+ * @param num_samples Sample num of PCM data
100
+ * @param sample_rate Sample rate of input audio
101
+ * @param language Preferred language,
102
+ * For whisper, check https://whisper-api.com/docs/languages/
103
+ * For sensevoice, support auto, zh, en, yue, ja, ko
104
+ * @param result Pointer to receive the allocated result string
105
+ *
106
+ * @return int Status code (0 = success, <0 = error)
107
+ *
108
+ * @note The returned string is allocated with malloc() and must be freed
109
+ * by the caller using free() when no longer needed.
110
+ */
111
+ AX_ASR_API int AX_ASR_RunPCM(AX_ASR_HANDLE handle,
112
+ float* pcm_data,
113
+ int num_samples,
114
+ int sample_rate,
115
+ const char* language,
116
+ char** result);
117
+
118
+ #ifdef __cplusplus
119
+ }
120
+ #endif
121
+
122
+ #endif // _AX_ASR_API_H_
cpp/ax630c/lib/cmake/ax_asr_api/ax_asr_api-config-release.cmake ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #----------------------------------------------------------------
2
+ # Generated CMake target import file for configuration "Release".
3
+ #----------------------------------------------------------------
4
+
5
+ # Commands may need to know the format version.
6
+ set(CMAKE_IMPORT_FILE_VERSION 1)
7
+
8
+ # Import target "ax::ax_asr_api" for configuration "Release"
9
+ set_property(TARGET ax::ax_asr_api APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE)
10
+ set_target_properties(ax::ax_asr_api PROPERTIES
11
+ IMPORTED_LOCATION_RELEASE "${_IMPORT_PREFIX}/lib/libax_asr_api.so"
12
+ IMPORTED_SONAME_RELEASE "libax_asr_api.so"
13
+ )
14
+
15
+ list(APPEND _cmake_import_check_targets ax::ax_asr_api )
16
+ list(APPEND _cmake_import_check_files_for_ax::ax_asr_api "${_IMPORT_PREFIX}/lib/libax_asr_api.so" )
17
+
18
+ # Commands beyond this point should not need to know the version.
19
+ set(CMAKE_IMPORT_FILE_VERSION)
cpp/ax630c/lib/cmake/ax_asr_api/ax_asr_api-config.cmake ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Generated by CMake
2
+
3
+ if("${CMAKE_MAJOR_VERSION}.${CMAKE_MINOR_VERSION}" LESS 2.8)
4
+ message(FATAL_ERROR "CMake >= 2.8.0 required")
5
+ endif()
6
+ if(CMAKE_VERSION VERSION_LESS "2.8.3")
7
+ message(FATAL_ERROR "CMake >= 2.8.3 required")
8
+ endif()
9
+ cmake_policy(PUSH)
10
+ cmake_policy(VERSION 2.8.3...3.28)
11
+ #----------------------------------------------------------------
12
+ # Generated CMake target import file.
13
+ #----------------------------------------------------------------
14
+
15
+ # Commands may need to know the format version.
16
+ set(CMAKE_IMPORT_FILE_VERSION 1)
17
+
18
+ # Protect against multiple inclusion, which would fail when already imported targets are added once more.
19
+ set(_cmake_targets_defined "")
20
+ set(_cmake_targets_not_defined "")
21
+ set(_cmake_expected_targets "")
22
+ foreach(_cmake_expected_target IN ITEMS ax::ax_asr_api)
23
+ list(APPEND _cmake_expected_targets "${_cmake_expected_target}")
24
+ if(TARGET "${_cmake_expected_target}")
25
+ list(APPEND _cmake_targets_defined "${_cmake_expected_target}")
26
+ else()
27
+ list(APPEND _cmake_targets_not_defined "${_cmake_expected_target}")
28
+ endif()
29
+ endforeach()
30
+ unset(_cmake_expected_target)
31
+ if(_cmake_targets_defined STREQUAL _cmake_expected_targets)
32
+ unset(_cmake_targets_defined)
33
+ unset(_cmake_targets_not_defined)
34
+ unset(_cmake_expected_targets)
35
+ unset(CMAKE_IMPORT_FILE_VERSION)
36
+ cmake_policy(POP)
37
+ return()
38
+ endif()
39
+ if(NOT _cmake_targets_defined STREQUAL "")
40
+ string(REPLACE ";" ", " _cmake_targets_defined_text "${_cmake_targets_defined}")
41
+ string(REPLACE ";" ", " _cmake_targets_not_defined_text "${_cmake_targets_not_defined}")
42
+ message(FATAL_ERROR "Some (but not all) targets in this export set were already defined.\nTargets Defined: ${_cmake_targets_defined_text}\nTargets not yet defined: ${_cmake_targets_not_defined_text}\n")
43
+ endif()
44
+ unset(_cmake_targets_defined)
45
+ unset(_cmake_targets_not_defined)
46
+ unset(_cmake_expected_targets)
47
+
48
+
49
+ # Compute the installation prefix relative to this file.
50
+ get_filename_component(_IMPORT_PREFIX "${CMAKE_CURRENT_LIST_FILE}" PATH)
51
+ get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
52
+ get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
53
+ get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
54
+ if(_IMPORT_PREFIX STREQUAL "/")
55
+ set(_IMPORT_PREFIX "")
56
+ endif()
57
+
58
+ # Create imported target ax::ax_asr_api
59
+ add_library(ax::ax_asr_api SHARED IMPORTED)
60
+
61
+ set_target_properties(ax::ax_asr_api PROPERTIES
62
+ INTERFACE_INCLUDE_DIRECTORIES "${_IMPORT_PREFIX}/include"
63
+ )
64
+
65
+ # Load information for each installed configuration.
66
+ file(GLOB _cmake_config_files "${CMAKE_CURRENT_LIST_DIR}/ax_asr_api-config-*.cmake")
67
+ foreach(_cmake_config_file IN LISTS _cmake_config_files)
68
+ include("${_cmake_config_file}")
69
+ endforeach()
70
+ unset(_cmake_config_file)
71
+ unset(_cmake_config_files)
72
+
73
+ # Cleanup temporary variables.
74
+ set(_IMPORT_PREFIX)
75
+
76
+ # Loop over all imported files and verify that they actually exist
77
+ foreach(_cmake_target IN LISTS _cmake_import_check_targets)
78
+ if(CMAKE_VERSION VERSION_LESS "3.28"
79
+ OR NOT DEFINED _cmake_import_check_xcframework_for_${_cmake_target}
80
+ OR NOT IS_DIRECTORY "${_cmake_import_check_xcframework_for_${_cmake_target}}")
81
+ foreach(_cmake_file IN LISTS "_cmake_import_check_files_for_${_cmake_target}")
82
+ if(NOT EXISTS "${_cmake_file}")
83
+ message(FATAL_ERROR "The imported target \"${_cmake_target}\" references the file
84
+ \"${_cmake_file}\"
85
+ but this file does not exist. Possible reasons include:
86
+ * The file was deleted, renamed, or moved to another location.
87
+ * An install or uninstall procedure did not complete successfully.
88
+ * The installation package was faulty and contained
89
+ \"${CMAKE_CURRENT_LIST_FILE}\"
90
+ but not all the files it references.
91
+ ")
92
+ endif()
93
+ endforeach()
94
+ endif()
95
+ unset(_cmake_file)
96
+ unset("_cmake_import_check_files_for_${_cmake_target}")
97
+ endforeach()
98
+ unset(_cmake_target)
99
+ unset(_cmake_import_check_targets)
100
+
101
+ # This file does not depend on other imported targets which have
102
+ # been exported from the same project but in a separate export set.
103
+
104
+ # Commands beyond this point should not need to know the version.
105
+ set(CMAKE_IMPORT_FILE_VERSION)
106
+ cmake_policy(POP)
cpp/ax630c/lib/libax_asr_api.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2b0ac391d15517d2bda5f589faa6b4bc0f3af6782cce9d1384c4f8a2f471c7fc
3
+ size 421408
cpp/ax630c/test_sensevoice ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ad05307ac08e8b83589839788412a1e73bd48b8f8a8abfa079ab2bda9b547610
3
+ size 161088
cpp/ax650/include/ax_asr_api.h ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /**************************************************************************************************
2
+ *
3
+ * Copyright (c) 2019-2026 Axera Semiconductor (Ningbo) Co., Ltd. All Rights Reserved.
4
+ *
5
+ * This source file is the property of Axera Semiconductor (Ningbo) Co., Ltd. and
6
+ * may not be copied or distributed in any isomorphic form without the prior
7
+ * written consent of Axera Semiconductor (Ningbo) Co., Ltd.
8
+ *
9
+ **************************************************************************************************/
10
+ #ifndef _AX_ASR_API_H_
11
+ #define _AX_ASR_API_H_
12
+
13
+ #ifdef __cplusplus
14
+ extern "C" {
15
+ #endif
16
+
17
+ #define AX_ASR_API __attribute__((visibility("default")))
18
+
19
+
20
+ // Supported asr
21
+ enum AX_ASR_TYPE_E {
22
+ AX_WHISPER_TINY = 0,
23
+ AX_WHISPER_BASE,
24
+ AX_WHISPER_SMALL,
25
+ AX_WHISPER_TURBO,
26
+ AX_SENSEVOICE
27
+ };
28
+
29
+ /**
30
+ * @brief Opaque handle type for the ASR context
31
+ *
32
+ * This handle encapsulates all internal state of the asr ASR system.
33
+ * The actual implementation is hidden from C callers to maintain ABI stability.
34
+ */
35
+ typedef void* AX_ASR_HANDLE;
36
+
37
+ /**
38
+ * @brief Initialize the ASR system with a specific configuration
39
+ *
40
+ * Creates and initializes a new asr ASR context with the specified
41
+ * model type and model path. This function loads the appropriate
42
+ * models, configures the recognizer, and prepares it for speech recognition.
43
+ *
44
+ * @param asr_type Type of ASR model to use (one of AX_ASR_TYPE_E)
45
+ * @param model_path Directory path where model files are stored
46
+ * Model files are expected to be in the format: *.axmodel
47
+ *
48
+ * @return AX_ASR_HANDLE Opaque handle to the initialized asr context,
49
+ * or NULL if initialization fails
50
+ *
51
+ * @note The caller is responsible for calling AX_ASR_Uninit() to free
52
+ * resources when the handle is no longer needed.
53
+ * @example
54
+ * // Initialize recognition with whisper tiny model
55
+ * AX_ASR_HANDLE handle = AX_ASR_Init(AX_WHISPER_TINY, "./models-ax650/");
56
+ *
57
+ */
58
+ AX_ASR_API AX_ASR_HANDLE AX_ASR_Init(AX_ASR_TYPE_E asr_type, const char* model_path);
59
+
60
+ /**
61
+ * @brief Deinitialize and release asr ASR resources
62
+ *
63
+ * Cleans up all resources associated with the asr context, including
64
+ * unloading models, freeing memory, and releasing hardware resources.
65
+ *
66
+ * @param handle asr context handle obtained from AX_ASR_Init()
67
+ *
68
+ * @warning After calling this function, the handle becomes invalid and
69
+ * should not be used in any subsequent API calls.
70
+ */
71
+ AX_ASR_API void AX_ASR_Uninit(AX_ASR_HANDLE handle);
72
+
73
+ /**
74
+ * @brief Perform speech recognition and return dynamically allocated string
75
+ *
76
+ * @param handle asr context handle
77
+ * @param wav_file Path to the input 16k pcmf32 WAV audio file
78
+ * @param language Preferred language,
79
+ * For whisper, check https://whisper-api.com/docs/languages/
80
+ * For sensevoice, support auto, zh, en, yue, ja, ko
81
+ * @param result Pointer to receive the allocated result string
82
+ *
83
+ * @return int Status code (0 = success, <0 = error)
84
+ *
85
+ * @note The returned string is allocated with malloc() and must be freed
86
+ * by the caller using free() when no longer needed.
87
+ */
88
+ AX_ASR_API int AX_ASR_RunFile(AX_ASR_HANDLE handle,
89
+ const char* wav_file,
90
+ const char* language,
91
+ char** result);
92
+
93
+ /**
94
+ * @brief Perform speech recognition and return dynamically allocated string
95
+ *
96
+ * @param handle asr context handle
97
+ * @param pcm_data 16k Mono PCM f32 data, range from -1.0 to 1.0,
98
+ * will be resampled if not 16k
99
+ * @param num_samples Sample num of PCM data
100
+ * @param sample_rate Sample rate of input audio
101
+ * @param language Preferred language,
102
+ * For whisper, check https://whisper-api.com/docs/languages/
103
+ * For sensevoice, support auto, zh, en, yue, ja, ko
104
+ * @param result Pointer to receive the allocated result string
105
+ *
106
+ * @return int Status code (0 = success, <0 = error)
107
+ *
108
+ * @note The returned string is allocated with malloc() and must be freed
109
+ * by the caller using free() when no longer needed.
110
+ */
111
+ AX_ASR_API int AX_ASR_RunPCM(AX_ASR_HANDLE handle,
112
+ float* pcm_data,
113
+ int num_samples,
114
+ int sample_rate,
115
+ const char* language,
116
+ char** result);
117
+
118
+ #ifdef __cplusplus
119
+ }
120
+ #endif
121
+
122
+ #endif // _AX_ASR_API_H_
cpp/ax650/lib/cmake/ax_asr_api/ax_asr_api-config-debug.cmake ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #----------------------------------------------------------------
2
+ # Generated CMake target import file for configuration "Debug".
3
+ #----------------------------------------------------------------
4
+
5
+ # Commands may need to know the format version.
6
+ set(CMAKE_IMPORT_FILE_VERSION 1)
7
+
8
+ # Import target "ax::ax_asr_api" for configuration "Debug"
9
+ set_property(TARGET ax::ax_asr_api APPEND PROPERTY IMPORTED_CONFIGURATIONS DEBUG)
10
+ set_target_properties(ax::ax_asr_api PROPERTIES
11
+ IMPORTED_LOCATION_DEBUG "${_IMPORT_PREFIX}/lib/libax_asr_api.so"
12
+ IMPORTED_SONAME_DEBUG "libax_asr_api.so"
13
+ )
14
+
15
+ list(APPEND _cmake_import_check_targets ax::ax_asr_api )
16
+ list(APPEND _cmake_import_check_files_for_ax::ax_asr_api "${_IMPORT_PREFIX}/lib/libax_asr_api.so" )
17
+
18
+ # Commands beyond this point should not need to know the version.
19
+ set(CMAKE_IMPORT_FILE_VERSION)
cpp/ax650/lib/cmake/ax_asr_api/ax_asr_api-config-release.cmake ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #----------------------------------------------------------------
2
+ # Generated CMake target import file for configuration "Release".
3
+ #----------------------------------------------------------------
4
+
5
+ # Commands may need to know the format version.
6
+ set(CMAKE_IMPORT_FILE_VERSION 1)
7
+
8
+ # Import target "ax::ax_asr_api" for configuration "Release"
9
+ set_property(TARGET ax::ax_asr_api APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE)
10
+ set_target_properties(ax::ax_asr_api PROPERTIES
11
+ IMPORTED_LOCATION_RELEASE "${_IMPORT_PREFIX}/lib/libax_asr_api.so"
12
+ IMPORTED_SONAME_RELEASE "libax_asr_api.so"
13
+ )
14
+
15
+ list(APPEND _cmake_import_check_targets ax::ax_asr_api )
16
+ list(APPEND _cmake_import_check_files_for_ax::ax_asr_api "${_IMPORT_PREFIX}/lib/libax_asr_api.so" )
17
+
18
+ # Commands beyond this point should not need to know the version.
19
+ set(CMAKE_IMPORT_FILE_VERSION)
cpp/ax650/lib/cmake/ax_asr_api/ax_asr_api-config.cmake ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Generated by CMake
2
+
3
+ if("${CMAKE_MAJOR_VERSION}.${CMAKE_MINOR_VERSION}" LESS 2.8)
4
+ message(FATAL_ERROR "CMake >= 2.8.0 required")
5
+ endif()
6
+ if(CMAKE_VERSION VERSION_LESS "2.8.3")
7
+ message(FATAL_ERROR "CMake >= 2.8.3 required")
8
+ endif()
9
+ cmake_policy(PUSH)
10
+ cmake_policy(VERSION 2.8.3...3.28)
11
+ #----------------------------------------------------------------
12
+ # Generated CMake target import file.
13
+ #----------------------------------------------------------------
14
+
15
+ # Commands may need to know the format version.
16
+ set(CMAKE_IMPORT_FILE_VERSION 1)
17
+
18
+ # Protect against multiple inclusion, which would fail when already imported targets are added once more.
19
+ set(_cmake_targets_defined "")
20
+ set(_cmake_targets_not_defined "")
21
+ set(_cmake_expected_targets "")
22
+ foreach(_cmake_expected_target IN ITEMS ax::ax_asr_api)
23
+ list(APPEND _cmake_expected_targets "${_cmake_expected_target}")
24
+ if(TARGET "${_cmake_expected_target}")
25
+ list(APPEND _cmake_targets_defined "${_cmake_expected_target}")
26
+ else()
27
+ list(APPEND _cmake_targets_not_defined "${_cmake_expected_target}")
28
+ endif()
29
+ endforeach()
30
+ unset(_cmake_expected_target)
31
+ if(_cmake_targets_defined STREQUAL _cmake_expected_targets)
32
+ unset(_cmake_targets_defined)
33
+ unset(_cmake_targets_not_defined)
34
+ unset(_cmake_expected_targets)
35
+ unset(CMAKE_IMPORT_FILE_VERSION)
36
+ cmake_policy(POP)
37
+ return()
38
+ endif()
39
+ if(NOT _cmake_targets_defined STREQUAL "")
40
+ string(REPLACE ";" ", " _cmake_targets_defined_text "${_cmake_targets_defined}")
41
+ string(REPLACE ";" ", " _cmake_targets_not_defined_text "${_cmake_targets_not_defined}")
42
+ message(FATAL_ERROR "Some (but not all) targets in this export set were already defined.\nTargets Defined: ${_cmake_targets_defined_text}\nTargets not yet defined: ${_cmake_targets_not_defined_text}\n")
43
+ endif()
44
+ unset(_cmake_targets_defined)
45
+ unset(_cmake_targets_not_defined)
46
+ unset(_cmake_expected_targets)
47
+
48
+
49
+ # Compute the installation prefix relative to this file.
50
+ get_filename_component(_IMPORT_PREFIX "${CMAKE_CURRENT_LIST_FILE}" PATH)
51
+ get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
52
+ get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
53
+ get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
54
+ if(_IMPORT_PREFIX STREQUAL "/")
55
+ set(_IMPORT_PREFIX "")
56
+ endif()
57
+
58
+ # Create imported target ax::ax_asr_api
59
+ add_library(ax::ax_asr_api SHARED IMPORTED)
60
+
61
+ set_target_properties(ax::ax_asr_api PROPERTIES
62
+ INTERFACE_INCLUDE_DIRECTORIES "${_IMPORT_PREFIX}/include"
63
+ )
64
+
65
+ # Load information for each installed configuration.
66
+ file(GLOB _cmake_config_files "${CMAKE_CURRENT_LIST_DIR}/ax_asr_api-config-*.cmake")
67
+ foreach(_cmake_config_file IN LISTS _cmake_config_files)
68
+ include("${_cmake_config_file}")
69
+ endforeach()
70
+ unset(_cmake_config_file)
71
+ unset(_cmake_config_files)
72
+
73
+ # Cleanup temporary variables.
74
+ set(_IMPORT_PREFIX)
75
+
76
+ # Loop over all imported files and verify that they actually exist
77
+ foreach(_cmake_target IN LISTS _cmake_import_check_targets)
78
+ if(CMAKE_VERSION VERSION_LESS "3.28"
79
+ OR NOT DEFINED _cmake_import_check_xcframework_for_${_cmake_target}
80
+ OR NOT IS_DIRECTORY "${_cmake_import_check_xcframework_for_${_cmake_target}}")
81
+ foreach(_cmake_file IN LISTS "_cmake_import_check_files_for_${_cmake_target}")
82
+ if(NOT EXISTS "${_cmake_file}")
83
+ message(FATAL_ERROR "The imported target \"${_cmake_target}\" references the file
84
+ \"${_cmake_file}\"
85
+ but this file does not exist. Possible reasons include:
86
+ * The file was deleted, renamed, or moved to another location.
87
+ * An install or uninstall procedure did not complete successfully.
88
+ * The installation package was faulty and contained
89
+ \"${CMAKE_CURRENT_LIST_FILE}\"
90
+ but not all the files it references.
91
+ ")
92
+ endif()
93
+ endforeach()
94
+ endif()
95
+ unset(_cmake_file)
96
+ unset("_cmake_import_check_files_for_${_cmake_target}")
97
+ endforeach()
98
+ unset(_cmake_target)
99
+ unset(_cmake_import_check_targets)
100
+
101
+ # This file does not depend on other imported targets which have
102
+ # been exported from the same project but in a separate export set.
103
+
104
+ # Commands beyond this point should not need to know the version.
105
+ set(CMAKE_IMPORT_FILE_VERSION)
106
+ cmake_policy(POP)
cpp/ax650/lib/libax_asr_api.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4796b07d503ef78826a37b1f458941ca460ee73e0e87c90d5e9dfb999335b9ec
3
+ size 421624
cpp/ax650/test_sensevoice ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:21c9de78331edac15eede7b14b0d1eb5743ce3e74d95e70c2c89b06417707baf
3
+ size 161088
download_utils.py DELETED
@@ -1,33 +0,0 @@
1
- import os
2
-
3
- # Speed up hf download using mirror url
4
- os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
5
- from huggingface_hub import snapshot_download
6
-
7
- current_file_path = os.path.dirname(__file__)
8
- REPO_ROOT = "AXERA-TECH"
9
- CACHE_PATH = os.path.join(current_file_path, "models")
10
-
11
-
12
def download_model(model_name: str) -> str:
    """
    Fetch a model repository from AXERA-TECH's Hugging Face space.

    The repository is stored under ``CACHE_PATH/<model_name>``. If that
    directory already exists it is assumed to hold the model and nothing is
    downloaded.

    Args:
        model_name: Repository name under https://huggingface.co/AXERA-TECH.

    Returns:
        str: Local directory that holds (or now holds) the model files.
    """
    os.makedirs(CACHE_PATH, exist_ok=True)
    target_dir = os.path.join(CACHE_PATH, model_name)

    # An existing directory counts as "already downloaded".
    if os.path.exists(target_dir):
        return target_dir

    print(f"Downloading {model_name}...")
    snapshot_download(repo_id=f"{REPO_ROOT}/{model_name}", local_dir=target_dir)
    return target_dir
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
frontend.py DELETED
@@ -1,460 +0,0 @@
1
- # -*- encoding: utf-8 -*-
2
- from pathlib import Path
3
- from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
4
- import copy
5
-
6
- import numpy as np
7
- import kaldi_native_fbank as knf
8
-
9
-
10
class WavFrontend:
    """Conventional frontend structure for ASR.

    Computes filter-bank features with ``kaldi_native_fbank`` and optionally
    applies low-frame-rate (LFR) stacking and CMVN normalisation loaded from
    a Kaldi-nnet style ``am.mvn`` text file.
    """

    def __init__(
        self,
        cmvn_file: str = None,
        fs: int = 16000,
        window: str = "hamming",
        n_mels: int = 80,
        frame_length: int = 25,
        frame_shift: int = 10,
        lfr_m: int = 1,
        lfr_n: int = 1,
        dither: float = 1.0,
        **kwargs,
    ) -> None:
        """Configure the fbank extractor.

        Args:
            cmvn_file: Path of the CMVN statistics file; falsy disables CMVN.
            fs: Sample rate in Hz passed to the fbank extractor.
            window: Window type name understood by kaldi_native_fbank.
            n_mels: Number of mel bins.
            frame_length: Frame length in milliseconds.
            frame_shift: Frame shift in milliseconds.
            lfr_m: Number of frames stacked into one LFR frame.
            lfr_n: Hop, in frames, between consecutive LFR frames.
            dither: Dither value passed to the fbank extractor.
            **kwargs: Accepted and ignored so config dicts can be splatted in.
        """
        opts = knf.FbankOptions()
        opts.frame_opts.samp_freq = fs
        opts.frame_opts.dither = dither
        opts.frame_opts.window_type = window
        opts.frame_opts.frame_shift_ms = float(frame_shift)
        opts.frame_opts.frame_length_ms = float(frame_length)
        opts.mel_opts.num_bins = n_mels
        opts.energy_floor = 0
        opts.frame_opts.snip_edges = True
        opts.mel_opts.debug_mel = False
        self.opts = opts

        self.lfr_m = lfr_m
        self.lfr_n = lfr_n
        self.cmvn_file = cmvn_file

        if self.cmvn_file:
            self.cmvn = self.load_cmvn()
        self.fbank_fn = None
        self.fbank_beg_idx = 0
        self.reset_status()

    def _read_ready_frames(self, start: int = 0) -> Tuple[np.ndarray, np.ndarray]:
        """Copy the frames currently ready in ``self.fbank_fn`` into a matrix.

        Rows ``start .. num_frames_ready - 1`` are filled; the matrix always
        has ``num_frames_ready`` rows.
        """
        frames = self.fbank_fn.num_frames_ready
        mat = np.empty([frames, self.opts.mel_opts.num_bins])
        for i in range(start, frames):
            mat[i, :] = self.fbank_fn.get_frame(i)
        feat = mat.astype(np.float32)
        feat_len = np.array(mat.shape[0]).astype(np.int32)
        return feat, feat_len

    def fbank(self, waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Compute fbank features for a whole utterance.

        A fresh extractor is created on every call, so no state leaks between
        utterances. The waveform is scaled by 2**15 before extraction.

        Returns:
            ``(feat, feat_len)``: float32 matrix of frames and the frame
            count as a 0-d int32 array.
        """
        waveform = waveform * (1 << 15)
        self.fbank_fn = knf.OnlineFbank(self.opts)
        self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
        return self._read_ready_frames()

    def fbank_online(self, waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Feed more audio to the existing extractor and return its frames.

        Unlike :meth:`fbank` the extractor is reused, so frames accumulate
        until :meth:`reset_status` is called.
        """
        waveform = waveform * (1 << 15)
        self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
        # NOTE(review): fbank_beg_idx is never advanced in this class, so
        # every ready frame is copied on each call. If it were ever set > 0,
        # the rows before it would be left uninitialised.
        return self._read_ready_frames(self.fbank_beg_idx)

    def reset_status(self):
        """Drop the accumulated extractor state and start from frame zero."""
        self.fbank_fn = knf.OnlineFbank(self.opts)
        self.fbank_beg_idx = 0

    def lfr_cmvn(self, feat: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Apply LFR stacking (when configured) followed by CMVN (when configured).

        Returns:
            ``(feat, feat_len)`` where ``feat_len`` is a 0-d int32 array.
        """
        if self.lfr_m != 1 or self.lfr_n != 1:
            feat = self.apply_lfr(feat, self.lfr_m, self.lfr_n)

        if self.cmvn_file:
            feat = self.apply_cmvn(feat)

        feat_len = np.array(feat.shape[0]).astype(np.int32)
        return feat, feat_len

    @staticmethod
    def apply_lfr(inputs: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray:
        """Stack ``lfr_m`` consecutive frames every ``lfr_n`` frames.

        The input is left-padded with ``(lfr_m - 1) // 2`` copies of the first
        frame; an incomplete last window is right-padded with copies of the
        last frame.

        Args:
            inputs: 2-D matrix of frames, one frame per row.
            lfr_m: Frames per stacked output row.
            lfr_n: Hop between output rows, in input frames.

        Returns:
            float32 matrix of shape ``(ceil(T / lfr_n), lfr_m * dim)``.
            Zero input frames yield zero output rows.
        """
        T = inputs.shape[0]
        if T == 0:
            # Nothing to stack (e.g. audio shorter than one analysis frame).
            # Previously this raised IndexError on ``inputs[0]``.
            return np.zeros((0, lfr_m * inputs.shape[1]), dtype=np.float32)

        T_lfr = int(np.ceil(T / lfr_n))
        left_pad = (lfr_m - 1) // 2
        inputs = np.vstack((np.tile(inputs[0], (left_pad, 1)), inputs))
        T = T + left_pad

        stacked = []
        for i in range(T_lfr):
            begin = i * lfr_n
            remaining = T - begin
            if lfr_m <= remaining:
                stacked.append(inputs[begin : begin + lfr_m].reshape(1, -1))
            else:
                # Last window is short: repeat the final frame to fill it.
                padding = np.tile(inputs[-1], lfr_m - remaining)
                stacked.append(np.concatenate((inputs[begin:].reshape(-1), padding)))
        return np.vstack(stacked).astype(np.float32)

    def apply_cmvn(self, inputs: np.ndarray) -> np.ndarray:
        """Normalise features as ``(inputs + shift) * scale``.

        ``self.cmvn[0]`` holds the additive shift and ``self.cmvn[1]`` the
        multiplicative scale; only the first ``dim`` columns are used.
        """
        _, dim = inputs.shape
        # Row vectors broadcast over the frame axis.
        shift = self.cmvn[0:1, :dim]
        scale = self.cmvn[1:2, :dim]
        return (inputs + shift) * scale

    def load_cmvn(
        self,
    ) -> np.ndarray:
        """Parse ``self.cmvn_file`` into a ``(2, dim)`` float64 array.

        The values are read from the ``<LearnRateCoef>`` line that directly
        follows an ``<AddShift>`` marker (row 0) and a ``<Rescale>`` marker
        (row 1). Blank lines and a marker on the last line are skipped
        instead of raising IndexError.
        """
        with open(self.cmvn_file, "r", encoding="utf-8") as f:
            lines = f.readlines()

        means_list = []
        scales_list = []
        for idx, line in enumerate(lines):
            fields = line.split()
            if not fields or fields[0] not in ("<AddShift>", "<Rescale>"):
                continue
            if idx + 1 >= len(lines):
                continue
            values = lines[idx + 1].split()
            if not values or values[0] != "<LearnRateCoef>":
                continue
            # Drop "<LearnRateCoef>", the coefficient, "[" and the closing "]".
            payload = values[3:-1]
            if fields[0] == "<AddShift>":
                means_list = list(payload)
            else:
                scales_list = list(payload)

        means = np.array(means_list).astype(np.float64)
        scales = np.array(scales_list).astype(np.float64)
        return np.array([means, scales])
150
-
151
-
152
class WavFrontendOnline(WavFrontend):
    """Streaming variant of WavFrontend that keeps caches between chunks.

    State carried across calls:
      * ``input_cache``       - raw samples left over after the last full frame
      * ``lfr_splice_cache``  - per-batch fbank frames not yet consumed by LFR
      * ``reserve_waveforms`` - waveform samples matching the cached frames
    """

    def __init__(self, **kwargs):
        """Forward all keyword arguments to WavFrontend and reset the caches."""
        super().__init__(**kwargs)
        # self.fbank_fn = knf.OnlineFbank(self.opts)
        # Frame length / shift expressed in samples instead of milliseconds.
        self.frame_sample_length = int(
            self.opts.frame_opts.frame_length_ms * self.opts.frame_opts.samp_freq / 1000
        )
        self.frame_shift_sample_length = int(
            self.opts.frame_opts.frame_shift_ms * self.opts.frame_opts.samp_freq / 1000
        )
        # NOTE(review): the rest of this class reads/writes ``self.waveforms``
        # (plural); this singular attribute is never used below.
        self.waveform = None
        self.reserve_waveforms = None
        self.input_cache = None
        self.lfr_splice_cache = []

    @staticmethod
    # ``inputs`` already has the cached frames concatenated in front.
    def apply_lfr(
        inputs: np.ndarray, lfr_m: int, lfr_n: int, is_final: bool = False
    ) -> Tuple[np.ndarray, np.ndarray, int]:
        """
        Stack ``lfr_m`` frames every ``lfr_n`` frames for one streaming chunk.

        When ``is_final`` is False the loop stops at the first window that
        cannot be filled; when True that window is padded with copies of the
        last frame.

        Returns:
            ``(LFR_outputs, lfr_splice_cache, splice_idx)``: the stacked
            float32 frames, the input frames from ``splice_idx`` onwards (to
            be carried into the next chunk), and that index.
        """

        LFR_inputs = []
        T = inputs.shape[0]  # include the right context
        T_lfr = int(
            np.ceil((T - (lfr_m - 1) // 2) / lfr_n)
        )  # minus the right context: (lfr_m - 1) // 2
        splice_idx = T_lfr
        for i in range(T_lfr):
            if lfr_m <= T - i * lfr_n:
                LFR_inputs.append(
                    (inputs[i * lfr_n : i * lfr_n + lfr_m]).reshape(1, -1)
                )
            else:  # process last LFR frame
                if is_final:
                    num_padding = lfr_m - (T - i * lfr_n)
                    frame = (inputs[i * lfr_n :]).reshape(-1)
                    for _ in range(num_padding):
                        frame = np.hstack((frame, inputs[-1]))
                    LFR_inputs.append(frame)
                else:
                    # update splice_idx and leave the loop
                    splice_idx = i
                    break
        # Convert from LFR-frame index to input-frame index, clamped to the
        # last input frame.
        splice_idx = min(T - 1, splice_idx * lfr_n)
        lfr_splice_cache = inputs[splice_idx:, :]
        LFR_outputs = np.vstack(LFR_inputs)
        return LFR_outputs.astype(np.float32), lfr_splice_cache, splice_idx

    @staticmethod
    def compute_frame_num(
        sample_length: int, frame_sample_length: int, frame_shift_sample_length: int
    ) -> int:
        """Return how many full frames fit into ``sample_length`` samples (0 if none)."""
        frame_num = int(
            (sample_length - frame_sample_length) / frame_shift_sample_length + 1
        )
        return (
            frame_num if frame_num >= 1 and sample_length >= frame_sample_length else 0
        )

    def fbank(
        self, input: np.ndarray, input_lengths: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Compute fbank features for one chunk, prepending cached samples.

        ``input`` is indexed as (batch, samples). ``input_lengths`` is not
        read by this method.

        Returns:
            ``(waveforms, feats_pad, feats_lens)``; all three are empty
            arrays when the cached + new samples do not cover one frame.
        """
        self.fbank_fn = knf.OnlineFbank(self.opts)
        batch_size = input.shape[0]
        if self.input_cache is None:
            self.input_cache = np.empty((batch_size, 0), dtype=np.float32)
        input = np.concatenate((self.input_cache, input), axis=1)
        frame_num = self.compute_frame_num(
            input.shape[-1], self.frame_sample_length, self.frame_shift_sample_length
        )
        # update self.input_cache: keep the samples after the last frame shift
        self.input_cache = input[
            :, -(input.shape[-1] - frame_num * self.frame_shift_sample_length) :
        ]
        waveforms = np.empty(0, dtype=np.float32)
        feats_pad = np.empty(0, dtype=np.float32)
        feats_lens = np.empty(0, dtype=np.int32)
        if frame_num:
            waveforms = []
            feats = []
            feats_lens = []
            for i in range(batch_size):
                waveform = input[i]
                # Keep only the samples covered by the computed frames.
                waveforms.append(
                    waveform[
                        : (
                            (frame_num - 1) * self.frame_shift_sample_length
                            + self.frame_sample_length
                        )
                    ]
                )
                waveform = waveform * (1 << 15)

                self.fbank_fn.accept_waveform(
                    self.opts.frame_opts.samp_freq, waveform.tolist()
                )
                frames = self.fbank_fn.num_frames_ready
                mat = np.empty([frames, self.opts.mel_opts.num_bins])
                # NOTE(review): this inner loop reuses the name ``i`` of the
                # batch loop; extract_fbank asserts batch_size == 1.
                for i in range(frames):
                    mat[i, :] = self.fbank_fn.get_frame(i)
                feat = mat.astype(np.float32)
                feat_len = np.array(mat.shape[0]).astype(np.int32)
                feats.append(feat)
                feats_lens.append(feat_len)

            waveforms = np.stack(waveforms)
            feats_lens = np.array(feats_lens)
            feats_pad = np.array(feats)
        # Remember the latest features so get_fbank() can return them.
        self.fbanks = feats_pad
        self.fbanks_lens = copy.deepcopy(feats_lens)
        return waveforms, feats_pad, feats_lens

    def get_fbank(self) -> Tuple[np.ndarray, np.ndarray]:
        """Return the features and lengths stored by the last fbank() call."""
        return self.fbanks, self.fbanks_lens

    def lfr_cmvn(
        self, input: np.ndarray, input_lengths: np.ndarray, is_final: bool = False
    ) -> Tuple[np.ndarray, np.ndarray, List[int]]:
        """Apply LFR and CMVN to every batch item.

        Returns:
            ``(feats_pad, feats_lens, lfr_splice_frame_idxs)``; the splice
            index is -1 for items where LFR is disabled (lfr_m == lfr_n == 1).
        """
        batch_size = input.shape[0]
        feats = []
        feats_lens = []
        lfr_splice_frame_idxs = []
        for i in range(batch_size):
            mat = input[i, : input_lengths[i], :]
            lfr_splice_frame_idx = -1
            if self.lfr_m != 1 or self.lfr_n != 1:
                # update self.lfr_splice_cache in self.apply_lfr
                mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(
                    mat, self.lfr_m, self.lfr_n, is_final
                )
            if self.cmvn_file is not None:
                mat = self.apply_cmvn(mat)
            feat_length = mat.shape[0]
            feats.append(mat)
            feats_lens.append(feat_length)
            lfr_splice_frame_idxs.append(lfr_splice_frame_idx)

        feats_lens = np.array(feats_lens)
        feats_pad = np.array(feats)
        return feats_pad, feats_lens, lfr_splice_frame_idxs

    def extract_fbank(
        self, input: np.ndarray, input_lengths: np.ndarray, is_final: bool = False
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Run the full streaming pipeline (fbank -> LFR -> CMVN) on one chunk.

        Only batch size 1 is supported. When ``is_final`` is True all caches
        are reset before returning.

        Returns:
            ``(feats, feats_lengths)``. ``feats`` is an empty array when the
            chunk plus cache holds fewer than ``lfr_m`` frames.
        """
        batch_size = input.shape[0]
        assert (
            batch_size == 1
        ), "we support to extract feature online only when the batch size is equal to 1 now"
        waveforms, feats, feats_lengths = self.fbank(
            input, input_lengths
        )  # input shape: B T D
        if feats.shape[0]:
            # At least one new fbank frame: prepend the reserved waveform.
            self.waveforms = (
                waveforms
                if self.reserve_waveforms is None
                else np.concatenate((self.reserve_waveforms, waveforms), axis=1)
            )
            if not self.lfr_splice_cache:
                # First chunk: seed the cache with (lfr_m - 1) // 2 copies of
                # the first frame (the LFR left padding).
                for i in range(batch_size):
                    self.lfr_splice_cache.append(
                        np.expand_dims(feats[i][0, :], axis=0).repeat(
                            (self.lfr_m - 1) // 2, axis=0
                        )
                    )

            if feats_lengths[0] + self.lfr_splice_cache[0].shape[0] >= self.lfr_m:
                # Enough frames for at least one LFR window.
                lfr_splice_cache_np = np.stack(self.lfr_splice_cache)  # B T D
                feats = np.concatenate((lfr_splice_cache_np, feats), axis=1)
                feats_lengths += lfr_splice_cache_np[0].shape[0]
                frame_from_waveforms = int(
                    (self.waveforms.shape[1] - self.frame_sample_length)
                    / self.frame_shift_sample_length
                    + 1
                )
                # The seeded left padding has no matching waveform samples.
                minus_frame = (
                    (self.lfr_m - 1) // 2 if self.reserve_waveforms is None else 0
                )
                feats, feats_lengths, lfr_splice_frame_idxs = self.lfr_cmvn(
                    feats, feats_lengths, is_final
                )
                if self.lfr_m == 1:
                    self.reserve_waveforms = None
                else:
                    reserve_frame_idx = lfr_splice_frame_idxs[0] - minus_frame
                    # print('reserve_frame_idx: ' + str(reserve_frame_idx))
                    # print('frame_frame: ' + str(frame_from_waveforms))
                    self.reserve_waveforms = self.waveforms[
                        :,
                        reserve_frame_idx
                        * self.frame_shift_sample_length : frame_from_waveforms
                        * self.frame_shift_sample_length,
                    ]
                    # NOTE(review): the scraped source lost its indentation;
                    # the two statements below are assumed to belong to this
                    # else-branch - confirm against the original file.
                    sample_length = (
                        frame_from_waveforms - 1
                    ) * self.frame_shift_sample_length + self.frame_sample_length
                    self.waveforms = self.waveforms[:, :sample_length]
            else:
                # Too few frames: update self.reserve_waveforms and
                # self.lfr_splice_cache, emit nothing for this chunk.
                self.reserve_waveforms = self.waveforms[
                    :, : -(self.frame_sample_length - self.frame_shift_sample_length)
                ]
                for i in range(batch_size):
                    self.lfr_splice_cache[i] = np.concatenate(
                        (self.lfr_splice_cache[i], feats[i]), axis=0
                    )
                return np.empty(0, dtype=np.float32), feats_lengths
        else:
            # No new frame in this chunk; on the final call flush the cache.
            if is_final:
                self.waveforms = (
                    waveforms
                    if self.reserve_waveforms is None
                    else self.reserve_waveforms
                )
                feats = np.stack(self.lfr_splice_cache)
                feats_lengths = np.zeros(batch_size, dtype=np.int32) + feats.shape[1]
                feats, feats_lengths, _ = self.lfr_cmvn(feats, feats_lengths, is_final)
        if is_final:
            self.cache_reset()
        return feats, feats_lengths

    def get_waveforms(self):
        """Return the waveform stored by the last extract_fbank() call."""
        return self.waveforms

    def cache_reset(self):
        """Recreate the fbank extractor and clear all streaming caches."""
        self.fbank_fn = knf.OnlineFbank(self.opts)
        self.reserve_waveforms = None
        self.input_cache = None
        self.lfr_splice_cache = []
384
-
385
-
386
def load_bytes(input):
    """Decode raw 16-bit PCM bytes into float32 samples.

    The buffer is read as native-endian int16 and scaled by 1 / 2**15, so the
    result lies in ``[-1.0, 1.0)``.

    Args:
        input: Bytes-like object whose length is a multiple of 2.

    Returns:
        1-D ``np.float32`` array with one value per int16 sample.
    """
    pcm = np.frombuffer(input, dtype=np.int16)
    # int16 is centred on zero (min + 2**15 == 0), so no offset is needed;
    # the former dtype checks could never fail for these fixed dtypes.
    return pcm.astype(np.float32) / np.float32(32768)
402
-
403
-
404
class SinusoidalPositionEncoderOnline:
    """Streaming Positional encoding."""

    def encode(
        self,
        positions: np.ndarray = None,
        depth: int = None,
        dtype: np.dtype = np.float32,
    ):
        """Build the sinusoidal table for ``positions``.

        The first ``depth / 2`` channels hold sines, the remaining ones
        cosines, with timescales spaced geometrically from 1 to 1/10000.

        Returns:
            Array of shape ``(1, num_positions, depth)`` in ``dtype``.
        """
        n_rows = positions.shape[0]
        pos = positions.astype(dtype)
        half_depth = depth / 2
        step = np.log(np.array([10000], dtype=dtype)) / (half_depth - 1)
        inv_timescales = np.exp(np.arange(half_depth).astype(dtype) * (-step))
        inv_timescales = inv_timescales.reshape([n_rows, -1])
        # Outer product: every position times every inverse timescale.
        angles = pos.reshape([1, -1, 1]) * inv_timescales.reshape([1, 1, -1])
        table = np.concatenate((np.sin(angles), np.cos(angles)), axis=2)
        return table.astype(dtype)

    def forward(self, x, start_idx=0):
        """Add position encodings to ``x`` (batch, time, dim).

        ``start_idx`` is the number of frames already consumed by earlier
        chunks, so this chunk continues the position count from there.
        """
        _, timesteps, input_dim = x.shape
        stop = start_idx + timesteps
        # Positions are 1-based.
        positions = np.arange(1, stop + 1)[None, :]
        table = self.encode(positions, input_dim, x.dtype)
        return x + table[:, start_idx:stop]
434
-
435
-
436
def test():
    """Developer smoke test: run the frontend on one hard-coded example.

    Needs ``librosa``, ``funasr`` and the files under the hard-coded paths,
    so it only works on the original author's machine.

    Returns:
        ``(feat, feat_len)`` as produced by ``WavFrontend.lfr_cmvn``.
    """
    import librosa
    from funasr.runtime.python.onnxruntime.rapid_paraformer.utils.utils import read_yaml

    path = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav"
    cmvn_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/am.mvn"
    config_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/config.yaml"

    frontend_conf = read_yaml(config_file)["frontend_conf"]
    audio, _ = librosa.load(path, sr=None)  # 1d, (sample,), numpy
    frontend = WavFrontend(cmvn_file=cmvn_file, **frontend_conf)

    fbank_feat, _ = frontend.fbank_online(audio)
    # 2d, (frame, 450), np.float32
    result = frontend.lfr_cmvn(fbank_feat)

    frontend.reset_status()  # clear cache
    feat, feat_len = result
    return feat, feat_len
457
-
458
-
459
- if __name__ == "__main__":
460
- test()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo.py DELETED
@@ -1,70 +0,0 @@
1
- import gradio as gr
2
- import os
3
- from SenseVoiceAx import SenseVoiceAx
4
- from download_utils import download_model
5
-
6
# Module-level setup: download (or reuse) the SenseVoice assets and build the
# non-streaming AX650 recogniser shared by every Gradio request.
model_root = download_model("SenseVoice")
model_root = os.path.join(model_root, "sensevoice_ax650")
max_seq_len = 256
model_path = os.path.join(model_root, "sensevoice.axmodel")

# Raise explicitly instead of asserting: asserts are stripped under
# ``python -O``, which would let a missing model slip through.
if not os.path.exists(model_path):
    raise FileNotFoundError(f"model {model_path} not exist")

cmvn_file = os.path.join(model_root, "am.mvn")
bpe_model = os.path.join(model_root, "chn_jpn_yue_eng_ko_spectok.bpe.model")
token_file = os.path.join(model_root, "tokens.txt")

model = SenseVoiceAx(
    model_path,
    cmvn_file,
    token_file,
    bpe_model,
    max_seq_len=max_seq_len,
    beam_size=3,
    hot_words=None,
    streaming=False,
)
27
-
28
# Speech-to-text callback used by the Gradio audio widget.
def speech_to_text(audio_path, lang):
    """
    Transcribe one audio file with the module-level model.

    audio_path: path of the audio file; falsy when nothing was recorded
    lang: language code, one of "auto", "zh", "en", "yue", "ja", "ko"
    """
    if audio_path:
        return model.infer(audio_path, lang, print_rtf=False)
    return "无音频"
39
-
40
-
41
def main():
    """Build the Gradio page and serve it over HTTPS on port 7860.

    The page has a result text box, a microphone recorder and a language
    selector; every change of the recording triggers ``speech_to_text``.
    """
    language_choices = ["auto", "zh", "en", "yue", "ja", "ko"]
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "ssl_certfile": "./cert.pem",
        "ssl_keyfile": "./key.pem",
        "ssl_verify": False,
    }

    with gr.Blocks() as demo:
        with gr.Row():
            output_text = gr.Textbox(label="识别结果", lines=5)

        with gr.Row():
            audio_input = gr.Audio(
                sources=["microphone"],
                type="filepath",
                label="录制或上传音频",
                format="mp3",
            )
            lang_dropdown = gr.Dropdown(
                choices=language_choices,
                value="auto",
                label="选择音频语言",
            )

        # Re-run recognition whenever a new recording arrives.
        audio_input.change(
            fn=speech_to_text,
            inputs=[audio_input, lang_dropdown],
            outputs=output_text,
        )

    demo.launch(**launch_options)
67
-
68
-
69
- if __name__ == "__main__":
70
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main.py DELETED
@@ -1,80 +0,0 @@
1
- import os
2
- import argparse
3
- from SenseVoiceAx import SenseVoiceAx
4
- import librosa
5
- from download_utils import download_model
6
- import time
7
-
8
-
9
def get_args():
    """Parse the command line of the demo script.

    Options:
        --input/-i     audio file to transcribe (required)
        --language/-l  one of auto, zh, en, yue, ja, ko (default: auto)
        --streaming    switch to streaming recognition

    Returns:
        argparse.Namespace with ``input``, ``language`` and ``streaming``.
    """
    language_option = {
        "required": False,
        "type": str,
        "default": "auto",
        "choices": ["auto", "zh", "en", "yue", "ja", "ko"],
    }

    cli = argparse.ArgumentParser()
    cli.add_argument("--input", "-i", required=True, type=str, help="Input audio file")
    cli.add_argument("--language", "-l", **language_option)
    cli.add_argument("--streaming", action="store_true")
    return cli.parse_args()
24
-
25
-
26
def main():
    """Transcribe one audio file, either in one shot or chunk by chunk.

    Non-streaming mode calls ``model.infer`` on the file. Streaming mode
    feeds 100 ms chunks to ``model.stream_infer``, prints each partial
    result and finally the real-time factor (processing time / audio length).

    Raises:
        FileNotFoundError: the selected ``.axmodel`` file is missing.
        ValueError: streaming mode was given audio without any samples.
    """
    args = get_args()
    print(vars(args))

    input_audio = args.input
    language = args.language
    model_root = download_model("SenseVoice")
    model_root = os.path.join(model_root, "sensevoice_ax650")
    if not args.streaming:
        max_seq_len = 256
        model_path = os.path.join(model_root, "sensevoice.axmodel")
    else:
        max_seq_len = 26
        model_path = os.path.join(model_root, "streaming_sensevoice.axmodel")

    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"model {model_path} not exist")

    cmvn_file = os.path.join(model_root, "am.mvn")
    bpe_model = os.path.join(model_root, "chn_jpn_yue_eng_ko_spectok.bpe.model")
    token_file = os.path.join(model_root, "tokens.txt")

    model = SenseVoiceAx(
        model_path,
        cmvn_file,
        token_file,
        bpe_model,
        max_seq_len=max_seq_len,
        beam_size=3,
        hot_words=None,
        streaming=args.streaming,
    )

    if not args.streaming:
        asr_res = model.infer(input_audio, language, print_rtf=True)
        print("ASR result: " + asr_res)
        return

    samples, sr = librosa.load(input_audio, sr=16000)
    # The streaming API is fed samples scaled to the int16 range.
    samples = (samples * 32768).tolist()
    if not samples:
        # Without this the RTF computation below divides by zero.
        raise ValueError(f"no audio samples decoded from {input_audio}")
    duration = len(samples) / sr

    start = time.time()
    step = int(0.1 * sr)  # 100 ms per chunk
    for i in range(0, len(samples), step):
        is_last = i + step >= len(samples)
        for res in model.stream_infer(samples[i : i + step], is_last, language):
            print(res)

    cost_time = time.time() - start
    print(f"RTF: {cost_time / duration}")
77
-
78
-
79
- if __name__ == "__main__":
80
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt DELETED
@@ -1,8 +0,0 @@
1
- huggingface_hub
2
- numpy<2
3
- kaldi-native-fbank
4
- librosa==0.9.1
5
- fastapi
6
- gradio==5.47.1
7
- online-fbank
8
- asr_decoder
 
 
 
 
 
 
 
 
 
server.py DELETED
@@ -1,153 +0,0 @@
1
- import numpy as np
2
- from fastapi import FastAPI, HTTPException, Body
3
- from fastapi.responses import JSONResponse
4
- from typing import List, Optional
5
- import logging
6
- import json
7
- from SenseVoiceAx import SenseVoiceAx
8
- from download_utils import download_model
9
- import os
10
- import librosa
11
-
12
- # 初始化日志
13
- logging.basicConfig(level=logging.INFO)
14
- logger = logging.getLogger(__name__)
15
-
16
- app = FastAPI(title="ASR Server", description="Automatic Speech Recognition API")
17
-
18
- # 全局变量存储模型
19
- asr_model = None
20
-
21
-
22
@app.on_event("startup")
async def load_model():
    """
    Load the ASR model once when the service starts.

    Downloads (or reuses) the SenseVoice assets, builds the non-streaming
    AX650 recogniser and stores it in the module-level ``asr_model``.
    Any failure is logged and re-raised so the server does not start without
    a model.
    """
    global asr_model
    logger.info("Loading ASR model...")

    try:
        # Values below are only reported in the startup output.
        language = "auto"
        use_itn = True  # punctuation prediction

        model_root = download_model("SenseVoice")
        model_root = os.path.join(model_root, "sensevoice_ax650")
        max_seq_len = 256
        model_path = os.path.join(model_root, "sensevoice.axmodel")

        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"model {model_path} not exist")

        cmvn_file = os.path.join(model_root, "am.mvn")
        bpe_model = os.path.join(model_root, "chn_jpn_yue_eng_ko_spectok.bpe.model")
        token_file = os.path.join(model_root, "tokens.txt")

        asr_model = SenseVoiceAx(
            model_path,
            cmvn_file,
            token_file,
            bpe_model,
            max_seq_len=max_seq_len,
            beam_size=3,
            hot_words=None,
            streaming=False,
        )

        print(f"language: {language}")
        print(f"use_itn: {use_itn}")
        print(f"model_path: {model_path}")

        logger.info("ASR model loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load ASR model: {str(e)}")
        raise
66
-
67
-
68
def validate_audio_data(audio_data: List[float]) -> np.ndarray:
    """
    Validate audio samples and convert them to a numpy array.

    Args:
        audio_data: Audio samples as a list of floats.

    Returns:
        1-D, non-empty ``np.float32`` array with the same samples.

    Raises:
        ValueError: the data cannot be converted, is not 1-dimensional or is
            empty. The message is prefixed with "Invalid audio data: " and the
            original error is attached as ``__cause__``.
    """
    try:
        samples = np.array(audio_data, dtype=np.float32)

        if samples.ndim != 1:
            raise ValueError("Audio data must be 1-dimensional")

        if len(samples) == 0:
            raise ValueError("Audio data cannot be empty")

        return samples
    except Exception as e:
        # Chain explicitly so the traceback shows the real reason as the
        # cause rather than as an unrelated error raised during handling.
        raise ValueError(f"Invalid audio data: {str(e)}") from e
92
-
93
-
94
@app.get("/get_language", summary="Get current language")
async def get_language():
    """Return the ``language`` attribute of the loaded ASR model as JSON."""
    payload = {"language": asr_model.language}
    return JSONResponse(content=payload)
97
-
98
-
99
@app.get(
    "/get_language_options",
    summary="Get possible language options, possible options include [auto, zh, en, yue, ja, ko]",
)
async def get_language_options():
    """Return the ``language_options`` attribute of the loaded ASR model as JSON."""
    payload = {"language_options": asr_model.language_options}
    return JSONResponse(content=payload)
105
-
106
-
107
@app.post("/asr", summary="Recognize speech from numpy audio data")
async def recognize_speech(
    audio_data: List[float] = Body(
        ..., embed=True, description="Audio data as list of floats"
    ),
    sample_rate: Optional[int] = Body(16000, description="Audio sample rate in Hz"),
    language: Optional[str] = Body("auto", description="Language"),
):
    """
    Recognise speech from audio samples sent as a list of floats.

    Args:
        audio_data: Audio samples as a list of floats.
        sample_rate: Sample rate of ``audio_data`` in Hz (default 16000);
            the audio is resampled when it differs from the model's rate.
        language: Language passed to the model (default "auto").

    Returns:
        JSON object ``{"text": <recognised text>}``.

    Raises:
        HTTPException: 503 when the model is not loaded, 400 for invalid
            audio data, 500 for any other failure.
    """
    # Checked outside the try block: inside it the 503 would be caught by
    # ``except Exception`` below and turned into a 500.
    if asr_model is None:
        raise HTTPException(status_code=503, detail="ASR model not loaded")

    try:
        logger.info(f"Received audio data with length: {len(audio_data)}")

        # Validate and convert the payload.
        np_audio = validate_audio_data(audio_data)
        if sample_rate != asr_model.sample_rate:
            # Keyword arguments: newer librosa rejects positional rates.
            np_audio = librosa.resample(
                np_audio, orig_sr=sample_rate, target_sr=asr_model.sample_rate
            )

        # Run recognition.
        result = asr_model.infer_waveform(np_audio, language)

        return JSONResponse(content={"text": result})

    except ValueError as e:
        logger.error(f"Validation error: {str(e)}")
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.error(f"Recognition error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
148
-
149
-
150
- if __name__ == "__main__":
151
- import uvicorn
152
-
153
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
test_wer.py DELETED
@@ -1,299 +0,0 @@
1
- import os
2
- import argparse
3
- from SenseVoiceAx import SenseVoiceAx
4
- from download_utils import download_model
5
- import logging
6
- import re
7
-
8
-
9
def setup_logging():
    """Configure the root logger to log to both the console and a file.

    Any handlers already attached to the root logger are removed first.
    The log file is ``test_wer.log`` in the directory containing this script
    and is truncated on every call.

    Returns:
        The configured root logger (level INFO).
    """
    log_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "test_wer.log"
    )

    root = logging.getLogger()
    root.setLevel(logging.INFO)

    # Iterate over a copy: removeHandler mutates root.handlers.
    for stale in list(root.handlers):
        root.removeHandler(stale)

    # File sink first, console sink second.
    sinks = (
        logging.FileHandler(log_path, mode="w", encoding="utf-8"),
        logging.StreamHandler(),
    )
    for sink in sinks:
        sink.setLevel(logging.INFO)
        sink.setFormatter(
            logging.Formatter(
                "%(asctime)s - %(levelname)s - %(message)s", "%Y-%m-%d %H:%M:%S"
            )
        )
        root.addHandler(sink)

    return root
44
-
45
-
46
class AIShellDataset:
    """Iterable over ``(audio_path, transcript)`` pairs of an AIShell test set.

    The ground-truth file holds one utterance per line: an utterance id,
    whitespace, then the transcript. The audio for utterance ``<id>`` is
    looked up as ``aishell_S0764/<id>.wav`` next to the ground-truth file.
    """

    def __init__(self, gt_path: str):
        """
        Load the ground-truth file.

        Args:
            gt_path: Path to the ground-truth text file.

        Raises:
            AssertionError: If ``gt_path`` or the ``aishell_S0764`` folder
                next to it does not exist.
            ValueError: If a non-blank line has no transcript after the id.
        """
        self.gt_path = gt_path
        self.dataset_dir = os.path.dirname(gt_path)
        self.voice_dir = os.path.join(self.dataset_dir, "aishell_S0764")
        # Iteration cursor; set here so next() also works before iter().
        self.index = 0

        # Make sure the required file and folder exist.
        assert os.path.exists(gt_path), f"gt文件不存在: {gt_path}"
        assert os.path.exists(self.voice_dir), f"aishell_S0764文件夹不存在: {self.voice_dir}"

        # Load the data.
        self.data = []
        with open(gt_path, "r", encoding="utf-8") as f:
            for line_no, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue
                # Split only once so transcripts containing spaces stay intact.
                fields = line.split(maxsplit=1)
                if len(fields) != 2:
                    raise ValueError(
                        f"{gt_path}:{line_no}: expected '<utterance id> <transcript>', got {line!r}"
                    )
                utt_id, gt = fields
                audio_path = os.path.join(self.voice_dir, utt_id + ".wav")
                self.data.append({"audio_path": audio_path, "gt": gt})

        # Report through logging rather than print.
        logger = logging.getLogger()
        logger.info(f"加载了 {len(self.data)} 条数据")

    def __iter__(self):
        """Reset the cursor and return self as the iterator."""
        self.index = 0
        return self

    def __next__(self):
        """Return the next ``(audio_path, ground_truth)`` pair."""
        if self.index >= len(self.data):
            raise StopIteration

        item = self.data[self.index]
        self.index += 1
        return item["audio_path"], item["gt"]

    def __len__(self):
        """Return the number of loaded utterances."""
        return len(self.data)
95
-
96
-
97
class CommonVoiceDataset:
    """Iterable over ``(audio_path, transcript)`` pairs of a Common Voice split.

    The split is a tab-separated file whose first line is a header. Audio
    files are looked up in the ``clips`` folder next to the TSV file.
    """

    def __init__(self, tsv_path: str):
        """
        Load the TSV split file.

        The ``path`` and ``sentence`` columns are located by name in the
        header. If the header does not name them, columns 1 and 3 are used.

        Args:
            tsv_path: Path to the Common Voice TSV file.

        Raises:
            AssertionError: If ``tsv_path`` or the ``clips`` folder next to
                it does not exist.
            IndexError: If a non-blank row has too few columns.
        """
        self.tsv_path = tsv_path
        self.dataset_dir = os.path.dirname(tsv_path)
        self.voice_dir = os.path.join(self.dataset_dir, "clips")
        # Iteration cursor; set here so next() also works before iter().
        self.index = 0

        # Make sure the required file and folder exist.
        assert os.path.exists(tsv_path), f"{tsv_path}文件不存在: {tsv_path}"
        assert os.path.exists(self.voice_dir), f"voice文件夹不存在: {self.voice_dir}"

        # Load the TSV rows.
        self.data = []
        with open(tsv_path, "r", encoding="utf-8") as f:
            header = f.readline().rstrip("\r\n").split("\t")
            # Resolve columns by name so differing column orders still work.
            path_col = header.index("path") if "path" in header else 1
            sentence_col = header.index("sentence") if "sentence" in header else 3
            needed = max(path_col, sentence_col) + 1

            for line_no, line in enumerate(f, start=2):
                line = line.rstrip("\r\n")
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue
                splits = line.split("\t")
                if len(splits) < needed:
                    raise IndexError(
                        f"{tsv_path}:{line_no}: expected at least {needed} tab-separated columns, got {len(splits)}"
                    )
                audio_path = os.path.join(self.voice_dir, splits[path_col])
                gt = splits[sentence_col]
                self.data.append({"audio_path": audio_path, "gt": gt})

        # Report through logging rather than print.
        logger = logging.getLogger()
        logger.info(f"加载了 {len(self.data)} 条数据")

    def __iter__(self):
        """Reset the cursor and return self as the iterator."""
        self.index = 0
        return self

    def __next__(self):
        """Return the next ``(audio_path, ground_truth)`` pair."""
        if self.index >= len(self.data):
            raise StopIteration

        item = self.data[self.index]
        self.index += 1
        return item["audio_path"], item["gt"]

    def __len__(self):
        """Return the number of loaded rows."""
        return len(self.data)
151
-
152
-
153
def get_args():
    """Parse the command-line options of the evaluation script.

    Options:
        --dataset/-d:  required, one of ``aishell`` or ``common_voice``.
        --gt_path/-g:  required, path to the ground-truth file.
        --language/-l: optional, one of auto/zh/en/yue/ja/ko (default auto).
        --max_num:     optional int, default -1.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    cli = argparse.ArgumentParser()

    # (flags, keyword options) for every supported argument, in order.
    option_specs = (
        (
            ("--dataset", "-d"),
            {
                "type": str,
                "required": True,
                "choices": ["aishell", "common_voice"],
                "help": "Test dataset",
            },
        ),
        (
            ("--gt_path", "-g"),
            {
                "type": str,
                "required": True,
                "help": "Test dataset ground truth file",
            },
        ),
        (
            ("--language", "-l"),
            {
                "required": False,
                "type": str,
                "default": "auto",
                "choices": ["auto", "zh", "en", "yue", "ja", "ko"],
            },
        ),
        (
            ("--max_num",),
            {
                "type": int,
                "default": -1,
                "required": False,
                "help": "Maximum test data num",
            },
        ),
    )
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)

    return cli.parse_args()
182
-
183
-
184
def min_distance(word1: str, word2: str) -> int:
    """Return the Levenshtein edit distance between ``word1`` and ``word2``.

    Insertions, deletions and substitutions each cost 1.
    """
    # Distances from the empty prefix of word1 to every prefix of word2.
    previous = list(range(len(word2) + 1))

    for i, left in enumerate(word1, start=1):
        # Turning the first i characters of word1 into "" takes i deletions.
        current = [i]
        for j, right in enumerate(word2, start=1):
            if left == right:
                current.append(previous[j - 1])
            else:
                current.append(
                    1 + min(previous[j - 1], current[j - 1], previous[j])
                )
        previous = current

    return previous[-1]
211
-
212
-
213
def remove_punctuation(text):
    """Return ``text`` with punctuation and underscores removed.

    Word characters and whitespace are kept. Everything else is dropped,
    and so is ``_``, which the ``\\w`` class would otherwise keep.
    """
    punctuation = re.compile(r"[^\w\s]|_")
    return punctuation.sub("", text)
222
-
223
-
224
def main():
    """Evaluate the SenseVoice model on a dataset and log the error rates.

    Steps:
        1. Configure logging and parse the command-line arguments.
        2. Build the dataset selected by ``--dataset``.
        3. Locate the model files under the directory returned by
           ``download_model("SenseVoice")`` and create the model.
        4. Transcribe each sample and log its error rate. The comparison is
           character-level, on lower-cased text with punctuation removed.
        5. Log the error rate aggregated over all evaluated samples.

    The log messages label the metric "WER" although it is computed per
    character. The label is kept for compatibility with existing logs.

    Raises:
        ValueError: If the dataset type is unknown.
        AssertionError: If the model file does not exist.
    """
    logger = setup_logging()
    args = get_args()

    language = args.language
    max_num = args.max_num

    dataset_type = args.dataset.lower()
    if dataset_type == "aishell":
        dataset = AIShellDataset(args.gt_path)
    elif dataset_type == "common_voice":
        dataset = CommonVoiceDataset(args.gt_path)
    else:
        raise ValueError(f"Unknown dataset type {dataset_type}")

    model_root = download_model("SenseVoice")
    model_root = os.path.join(model_root, "sensevoice_ax650")
    max_seq_len = 256
    model_path = os.path.join(model_root, "sensevoice.axmodel")

    assert os.path.exists(model_path), f"model {model_path} not exist"

    cmvn_file = os.path.join(model_root, "am.mvn")
    bpe_model = os.path.join(model_root, "chn_jpn_yue_eng_ko_spectok.bpe.model")
    token_file = os.path.join(model_root, "tokens.txt")

    model = SenseVoiceAx(
        model_path,
        cmvn_file,
        token_file,
        bpe_model,
        max_seq_len=max_seq_len,
        beam_size=3,
        hot_words=None,
        streaming=False,
    )

    logger.info(f"dataset: {args.dataset}")
    logger.info(f"language: {language}")
    logger.info(f"model_path: {model_path}")

    # Iterate over the dataset.
    all_character_error_num = 0
    all_character_num = 0
    # Never report more samples than the dataset actually holds.
    max_data_num = min(max_num, len(dataset)) if max_num > 0 else len(dataset)
    for n, (audio_path, reference) in enumerate(dataset):
        reference = remove_punctuation(reference).lower()

        asr_res = model.infer(audio_path, language, print_rtf=False)
        hypothesis = remove_punctuation(asr_res).lower()

        character_error_num = min_distance(reference, hypothesis)
        character_num = len(reference)
        if character_num > 0:
            character_error_rate = character_error_num / character_num * 100
        else:
            # The reference is empty after punctuation removal, so the ratio
            # is undefined. Report 0% if the hypothesis is empty too, else 100%.
            character_error_rate = 0.0 if character_error_num == 0 else 100.0

        all_character_error_num += character_error_num
        all_character_num += character_num

        line_content = f"({n+1}/{max_data_num}) {os.path.basename(audio_path)} gt: {reference} predict: {hypothesis} WER: {character_error_rate}%"
        logger.info(line_content)

        if n + 1 >= max_data_num:
            break

    if all_character_num == 0:
        # Empty dataset or only empty references: nothing to divide by.
        logger.warning("No reference characters were evaluated; total WER is undefined")
        return

    total_character_error_rate = all_character_error_num / all_character_num * 100

    logger.info(f"Total WER: {total_character_error_rate}%")


if __name__ == "__main__":
    main()