inoryQwQ committed on
Commit f3ecff1 · 1 Parent(s): 5a86cb7

change structure

LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 祈Inory
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,3 +1,50 @@
- ---
- license: mit
- ---
+ # sensevoice.axera
+ FunASR SenseVoice on Axera, official repo: https://github.com/FunAudioLLM/SenseVoice
+
+ ## Features
+ - Speech recognition
+ - Automatic language identification
+ - Emotion recognition
+ - Automatic punctuation
+
+ ## Supported Platforms
+
+ - [x] AX650N
+ - [ ] AX630C
+
+ ## Environment Setup
+ ```
+ pip3 install -r requirements.txt
+ ```
+ If disk space is limited, use --prefix to install to another location.
+
+
+ ## Usage
+ ```
+ # The first run downloads the model from Hugging Face automatically and saves it under models/
+ python3 main.py -i <input audio file>
+ ```
+ Runtime arguments:
+ | Argument | Description | Default |
+ | --- | --- | --- |
+ | --input/-i | Input audio file | |
+ | --language/-l | Recognition language; one of auto, zh, en, yue, ja, ko | auto |
+
+
+ ### Example
+ Test audio files are provided under example/.
+
+ Cantonese test:
+ ```
+ python3 main.py -i example/yue.mp3
+ ```
+ Output:
+ ```
+ RTF: 0.03026517820946964 Latency: 0.15689468383789062s Total length: 5.184s
+ ['呢几个字。', '都表达唔到,我想讲嘅意。', '思。']
+ ```
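+
+ You can also call the pipeline directly from Python. A minimal sketch based on main.py (it assumes the model and BPE files are already in place, as above):
+ ```
+ from SenseVoiceAx import SenseVoiceAx
+ from tokenizer import SentencepiecesTokenizer
+
+ tokenizer = SentencepiecesTokenizer(bpemodel="chn_jpn_yue_eng_ko_spectok.bpe.model")
+ pipeline = SenseVoiceAx("sensevoice_ax650/sensevoice.axmodel", language="auto", use_itn=True, tokenizer=tokenizer)
+ print(pipeline.infer("example/yue.mp3"))
+ ```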
+
+ ## Technical Discussion
+
+ - GitHub issues
+ - QQ group: 139953715
SenseVoiceAx.py ADDED
@@ -0,0 +1,192 @@
+ import axengine as axe
+ import numpy as np
+ import librosa
+ from frontend import WavFrontend
+ import os
+ import time
+ from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
+
+ def sequence_mask(lengths, maxlen=None, dtype=np.float32):
+     # If maxlen is not given, use the largest value in lengths
+     if maxlen is None:
+         maxlen = np.max(lengths)
+
+     # Build a row vector 0 .. maxlen-1
+     row_vector = np.arange(0, maxlen, 1)
+
+     # Turn lengths into a column vector
+     matrix = np.expand_dims(lengths, axis=-1)
+
+     # Compare them to generate the mask
+     mask = row_vector < matrix
+
+     # Return the mask in the requested dtype
+     return mask.astype(dtype)[None, ...]
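+ # For instance (illustrative): sequence_mask(np.array([3], dtype=np.int32), maxlen=5)
+ # returns array([[[1., 1., 1., 0., 0.]]]) with shape (1, 1, 5).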
+
+ def unique_consecutive_np(x, dim=None, return_inverse=False, return_counts=False):
+     if dim is None:
+         # Default case: flatten, then drop consecutive duplicates
+         x_flat = x.ravel()
+         mask = np.concatenate(([True], x_flat[1:] != x_flat[:-1]))
+         unique_data = x_flat[mask]
+     else:
+         # Drop consecutive duplicates along the given dimension
+         axis = dim if dim >= 0 else x.ndim + dim
+         if axis >= x.ndim:
+             raise ValueError(f"dim {dim} is out of range for array of dimension {x.ndim}")
+
+         # Use np.diff to check whether neighbouring slices differ
+         mask = np.ones(x.shape[axis], dtype=bool)
+         if x.shape[axis] > 1:
+             # Keep a slice if any element differs from the previous slice,
+             # reducing over every axis except the working one
+             diff = np.diff(x, axis=axis)
+             other_axes = tuple(i for i in range(diff.ndim) if i != axis)
+             mask[1:] = np.any(diff != 0, axis=other_axes)
+
+         # Index with the mask to extract the unique slices
+         unique_data = np.take(x, np.where(mask)[0], axis=axis)
+
+     # Handle return_inverse and return_counts
+     results = (unique_data,)
+
+     if return_inverse:
+         inv_idx = np.cumsum(mask) - 1
+         if dim is not None:
+             # Reshape so the inverse index matches the input
+             inv_idx = np.expand_dims(inv_idx, axis=axis)
+             inv_idx = np.broadcast_to(inv_idx, x.shape)
+         results += (inv_idx,)
+
+     if return_counts:
+         counts = np.diff(np.where(np.concatenate((mask, [True])))[0])
+         results += (counts,)
+
+     return results[0] if len(results) == 1 else results
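+ # For instance (illustrative): unique_consecutive_np(np.array([1, 1, 2, 2, 3, 1]))
+ # returns array([1, 2, 3, 1]); only consecutive repeats are collapsed.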
+
+ class SenseVoiceAx:
+     def __init__(self, model_path, language="auto", use_itn=True, tokenizer=None):
+         model_path_root = os.path.join(os.path.dirname(model_path), "../embeddings")
+         self.frontend = WavFrontend(cmvn_file="am.mvn",
+                                     fs=16000,
+                                     window="hamming",
+                                     n_mels=80,
+                                     frame_length=25,
+                                     frame_shift=10,
+                                     lfr_m=7,
+                                     lfr_n=6,)
+         self.model = axe.InferenceSession(model_path)
+         self.sample_rate = 16000
+         self.tokenizer = tokenizer
+         self.blank_id = 0
+         self.max_len = 34
+
+         self.lid_dict = {"auto": 0, "zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeech": 13}
+         self.lid_int_dict = {24884: 3, 24885: 4, 24888: 7, 24892: 11, 24896: 12, 24992: 13}
+         self.textnorm_dict = {"withitn": 14, "woitn": 15}
+         self.textnorm_int_dict = {25016: 14, 25017: 15}
+         self.emo_dict = {"unk": 25009, "happy": 25001, "sad": 25002, "angry": 25003, "neutral": 25004}
+
+         self.position_encoding = np.load(f"{model_path_root}/position_encoding.npy")
+         language_query = np.load(f"{model_path_root}/{language}.npy")
+         textnorm_query = np.load(f"{model_path_root}/withitn.npy") if use_itn else np.load(f"{model_path_root}/woitn.npy")
+         event_emo_query = np.load(f"{model_path_root}/event_emo.npy")
+         self.input_query = np.concatenate((textnorm_query, language_query, event_emo_query), axis=1)
+         self.query_num = self.input_query.shape[1]
+         self.masks = sequence_mask(np.array([self.max_len], dtype=np.int32), dtype=np.float32)
+
+     def load_data(self, filepath: str) -> np.ndarray:
+         waveform, _ = librosa.load(filepath, sr=self.sample_rate)
+         return waveform.flatten()
+
+     @staticmethod
+     def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
+         def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
+             pad_width = ((0, max_feat_len - cur_len), (0, 0))
+             return np.pad(feat, pad_width, "constant", constant_values=0)
+
+         feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
+         feats = np.array(feat_res).astype(np.float32)
+         return feats
+
+     def preprocess(self, waveform):
+         feats, feats_len = [], []
+         for wf in [waveform]:
+             speech, _ = self.frontend.fbank(wf)
+             feat, feat_len = self.frontend.lfr_cmvn(speech)
+             feats.append(feat)
+             feats_len.append(feat_len)
+
+         feats = self.pad_feats(feats, np.max(feats_len))
+         feats_len = np.array(feats_len).astype(np.int32)
+         return feats, feats_len
+
+     def postprocess(self, ctc_logits, encoder_out_lens):
+         # Slice out the valid frames
+         x = ctc_logits[0, :encoder_out_lens[0], :]
+
+         # Greedy decoding: take the argmax index per frame
+         yseq = np.argmax(x, axis=-1)
+
+         # Collapse consecutive repeated tokens
+         yseq = unique_consecutive_np(yseq, dim=-1)
+
+         # Build a mask and filter out blank_id
+         mask = yseq != self.blank_id
+         token_int = yseq[mask].tolist()
+
+         return token_int
+
+     def infer_waveform(self, waveform: np.ndarray):
+         feat, feat_len = self.preprocess(waveform)
+
+         slice_len = self.max_len - self.query_num
+         slice_num = int(np.ceil(feat.shape[1] / slice_len))
+
+         asr_res = []
+         for i in range(slice_num):
+             sub_feat = feat[:, i*slice_len:(i+1)*slice_len, :]
+             # concat query
+             sub_feat = np.concatenate([self.input_query, sub_feat], axis=1)
+
+             if sub_feat.shape[1] < self.max_len:
+                 sub_feat = np.concatenate([
+                     sub_feat,
+                     np.zeros((1, self.max_len - sub_feat.shape[1], sub_feat.shape[-1]), dtype=np.float32)
+                 ],
+                 axis=1)
+
+             outputs = self.model.run(None, {"speech": sub_feat,
+                                             "masks": self.masks,
+                                             "position_encoding": self.position_encoding})
+             ctc_logits, encoder_out_lens = outputs
+
+             token_int = self.postprocess(ctc_logits, encoder_out_lens)
+             if self.tokenizer is not None:
+                 asr_res.append(self.tokenizer.tokens2text(token_int))
+             else:
+                 asr_res.append(token_int)
+
+         return asr_res
+
+     def infer(self, filepath_or_data: Union[np.ndarray, str], print_rtf=True):
+         if isinstance(filepath_or_data, str):
+             waveform = self.load_data(filepath_or_data)
+         else:
+             waveform = filepath_or_data
+
+         total_time = waveform.shape[-1] / self.sample_rate
+
+         start = time.time()
+         asr_res = self.infer_waveform(waveform)
+         latency = time.time() - start
+
+         if print_rtf:
+             rtf = latency / total_time
+             print(f"RTF: {rtf} Latency: {latency}s Total length: {total_time}s")
+         return asr_res
download_utils.py ADDED
@@ -0,0 +1,29 @@
+ import os
+ # Speed up hf download using mirror url
+ os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
+ from huggingface_hub import snapshot_download
+
+ current_file_path = os.path.dirname(__file__)
+ REPO_ROOT = "AXERA-TECH"
+ CACHE_PATH = os.path.join(current_file_path, "models")
+
+ def download_model(model_name: str) -> str:
+     """
+     Download a model from AXERA-TECH's Hugging Face space.
+
+     model_name: str
+         Available model names can be found at https://huggingface.co/AXERA-TECH.
+
+     Returns:
+         str: Local path to model_name
+     """
+     os.makedirs(CACHE_PATH, exist_ok=True)
+
+     model_path = os.path.join(CACHE_PATH, model_name)
+     if not os.path.exists(model_path):
+         print(f"Downloading {model_name}...")
+         snapshot_download(repo_id=f"{REPO_ROOT}/{model_name}",
+                           local_dir=os.path.join(CACHE_PATH, model_name))
+
+     return model_path
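+
+ # Usage sketch (illustrative; the model name below is hypothetical, check
+ # https://huggingface.co/AXERA-TECH for the actual repo names):
+ #   model_dir = download_model("SenseVoice")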
auto.npy → embeddings/auto.npy RENAMED
File without changes
en.npy → embeddings/en.npy RENAMED
File without changes
event_emo.npy → embeddings/event_emo.npy RENAMED
File without changes
ja.npy → embeddings/ja.npy RENAMED
File without changes
ko.npy → embeddings/ko.npy RENAMED
File without changes
nospeech.npy → embeddings/nospeech.npy RENAMED
File without changes
position_encoding.npy → embeddings/position_encoding.npy RENAMED
File without changes
withitn.npy → embeddings/withitn.npy RENAMED
File without changes
woitn.npy → embeddings/woitn.npy RENAMED
File without changes
yue.npy → embeddings/yue.npy RENAMED
File without changes
zh.npy → embeddings/zh.npy RENAMED
File without changes
frontend.py ADDED
@@ -0,0 +1,429 @@
+ # -*- encoding: utf-8 -*-
+ from pathlib import Path
+ from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
+ import copy
+
+ import numpy as np
+ import kaldi_native_fbank as knf
+
+
+ class WavFrontend:
+     """Conventional frontend structure for ASR."""
+
+     def __init__(
+         self,
+         cmvn_file: str = None,
+         fs: int = 16000,
+         window: str = "hamming",
+         n_mels: int = 80,
+         frame_length: int = 25,
+         frame_shift: int = 10,
+         lfr_m: int = 1,
+         lfr_n: int = 1,
+         dither: float = 1.0,
+         **kwargs,
+     ) -> None:
+
+         opts = knf.FbankOptions()
+         opts.frame_opts.samp_freq = fs
+         opts.frame_opts.dither = dither
+         opts.frame_opts.window_type = window
+         opts.frame_opts.frame_shift_ms = float(frame_shift)
+         opts.frame_opts.frame_length_ms = float(frame_length)
+         opts.mel_opts.num_bins = n_mels
+         opts.energy_floor = 0
+         opts.frame_opts.snip_edges = True
+         opts.mel_opts.debug_mel = False
+         self.opts = opts
+
+         self.lfr_m = lfr_m
+         self.lfr_n = lfr_n
+         self.cmvn_file = cmvn_file
+
+         if self.cmvn_file:
+             self.cmvn = self.load_cmvn()
+         self.fbank_fn = None
+         self.fbank_beg_idx = 0
+         self.reset_status()
+
+     def fbank(self, waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+         waveform = waveform * (1 << 15)
+         self.fbank_fn = knf.OnlineFbank(self.opts)
+         self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
+         frames = self.fbank_fn.num_frames_ready
+         mat = np.empty([frames, self.opts.mel_opts.num_bins])
+         for i in range(frames):
+             mat[i, :] = self.fbank_fn.get_frame(i)
+         feat = mat.astype(np.float32)
+         feat_len = np.array(mat.shape[0]).astype(np.int32)
+         return feat, feat_len
+
+     def fbank_online(self, waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+         waveform = waveform * (1 << 15)
+         # self.fbank_fn = knf.OnlineFbank(self.opts)
+         self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
+         frames = self.fbank_fn.num_frames_ready
+         mat = np.empty([frames, self.opts.mel_opts.num_bins])
+         for i in range(self.fbank_beg_idx, frames):
+             mat[i, :] = self.fbank_fn.get_frame(i)
+         # self.fbank_beg_idx += (frames-self.fbank_beg_idx)
+         feat = mat.astype(np.float32)
+         feat_len = np.array(mat.shape[0]).astype(np.int32)
+         return feat, feat_len
+
+     def reset_status(self):
+         self.fbank_fn = knf.OnlineFbank(self.opts)
+         self.fbank_beg_idx = 0
+
+     def lfr_cmvn(self, feat: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+         if self.lfr_m != 1 or self.lfr_n != 1:
+             feat = self.apply_lfr(feat, self.lfr_m, self.lfr_n)
+
+         if self.cmvn_file:
+             feat = self.apply_cmvn(feat)
+
+         feat_len = np.array(feat.shape[0]).astype(np.int32)
+         return feat, feat_len
+
+     @staticmethod
+     def apply_lfr(inputs: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray:
+         LFR_inputs = []
+
+         T = inputs.shape[0]
+         T_lfr = int(np.ceil(T / lfr_n))
+         left_padding = np.tile(inputs[0], ((lfr_m - 1) // 2, 1))
+         inputs = np.vstack((left_padding, inputs))
+         T = T + (lfr_m - 1) // 2
+         for i in range(T_lfr):
+             if lfr_m <= T - i * lfr_n:
+                 LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).reshape(1, -1))
+             else:
+                 # process last LFR frame
+                 num_padding = lfr_m - (T - i * lfr_n)
+                 frame = inputs[i * lfr_n :].reshape(-1)
+                 for _ in range(num_padding):
+                     frame = np.hstack((frame, inputs[-1]))
+
+                 LFR_inputs.append(frame)
+         LFR_outputs = np.vstack(LFR_inputs).astype(np.float32)
+         return LFR_outputs
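+     # With the settings used by SenseVoiceAx (lfr_m=7, lfr_n=6, n_mels=80), this
+     # stacks 7 consecutive 80-dim fbank frames into one 560-dim frame and advances
+     # 6 frames per step, so a T-frame input yields roughly ceil(T/6) LFR frames.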
+
+     def apply_cmvn(self, inputs: np.ndarray) -> np.ndarray:
+         """
+         Apply CMVN with mvn data
+         """
+         frame, dim = inputs.shape
+         means = np.tile(self.cmvn[0:1, :dim], (frame, 1))
+         vars = np.tile(self.cmvn[1:2, :dim], (frame, 1))
+         inputs = (inputs + means) * vars
+         return inputs
+
+     def load_cmvn(
+         self,
+     ) -> np.ndarray:
+         with open(self.cmvn_file, "r", encoding="utf-8") as f:
+             lines = f.readlines()
+
+         means_list = []
+         vars_list = []
+         for i in range(len(lines)):
+             line_item = lines[i].split()
+             if line_item[0] == "<AddShift>":
+                 line_item = lines[i + 1].split()
+                 if line_item[0] == "<LearnRateCoef>":
+                     add_shift_line = line_item[3 : (len(line_item) - 1)]
+                     means_list = list(add_shift_line)
+                     continue
+             elif line_item[0] == "<Rescale>":
+                 line_item = lines[i + 1].split()
+                 if line_item[0] == "<LearnRateCoef>":
+                     rescale_line = line_item[3 : (len(line_item) - 1)]
+                     vars_list = list(rescale_line)
+                     continue
+
+         means = np.array(means_list).astype(np.float64)
+         vars = np.array(vars_list).astype(np.float64)
+         cmvn = np.array([means, vars])
+         return cmvn
+
+
+ class WavFrontendOnline(WavFrontend):
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+         # self.fbank_fn = knf.OnlineFbank(self.opts)
+         # add variables
+         self.frame_sample_length = int(
+             self.opts.frame_opts.frame_length_ms * self.opts.frame_opts.samp_freq / 1000
+         )
+         self.frame_shift_sample_length = int(
+             self.opts.frame_opts.frame_shift_ms * self.opts.frame_opts.samp_freq / 1000
+         )
+         self.waveform = None
+         self.reserve_waveforms = None
+         self.input_cache = None
+         self.lfr_splice_cache = []
+
+     @staticmethod
+     # inputs has catted the cache
+     def apply_lfr(
+         inputs: np.ndarray, lfr_m: int, lfr_n: int, is_final: bool = False
+     ) -> Tuple[np.ndarray, np.ndarray, int]:
+         """
+         Apply lfr with data
+         """
+
+         LFR_inputs = []
+         T = inputs.shape[0]  # include the right context
+         T_lfr = int(
+             np.ceil((T - (lfr_m - 1) // 2) / lfr_n)
+         )  # minus the right context: (lfr_m - 1) // 2
+         splice_idx = T_lfr
+         for i in range(T_lfr):
+             if lfr_m <= T - i * lfr_n:
+                 LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).reshape(1, -1))
+             else:  # process last LFR frame
+                 if is_final:
+                     num_padding = lfr_m - (T - i * lfr_n)
+                     frame = (inputs[i * lfr_n :]).reshape(-1)
+                     for _ in range(num_padding):
+                         frame = np.hstack((frame, inputs[-1]))
+                     LFR_inputs.append(frame)
+                 else:
+                     # update splice_idx and break the loop
+                     splice_idx = i
+                     break
+         splice_idx = min(T - 1, splice_idx * lfr_n)
+         lfr_splice_cache = inputs[splice_idx:, :]
+         LFR_outputs = np.vstack(LFR_inputs)
+         return LFR_outputs.astype(np.float32), lfr_splice_cache, splice_idx
+
+     @staticmethod
+     def compute_frame_num(
+         sample_length: int, frame_sample_length: int, frame_shift_sample_length: int
+     ) -> int:
+         frame_num = int((sample_length - frame_sample_length) / frame_shift_sample_length + 1)
+         return frame_num if frame_num >= 1 and sample_length >= frame_sample_length else 0
+
+     def fbank(
+         self, input: np.ndarray, input_lengths: np.ndarray
+     ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+         self.fbank_fn = knf.OnlineFbank(self.opts)
+         batch_size = input.shape[0]
+         if self.input_cache is None:
+             self.input_cache = np.empty((batch_size, 0), dtype=np.float32)
+         input = np.concatenate((self.input_cache, input), axis=1)
+         frame_num = self.compute_frame_num(
+             input.shape[-1], self.frame_sample_length, self.frame_shift_sample_length
+         )
+         # update self.in_cache
+         self.input_cache = input[
+             :, -(input.shape[-1] - frame_num * self.frame_shift_sample_length) :
+         ]
+         waveforms = np.empty(0, dtype=np.float32)
+         feats_pad = np.empty(0, dtype=np.float32)
+         feats_lens = np.empty(0, dtype=np.int32)
+         if frame_num:
+             waveforms = []
+             feats = []
+             feats_lens = []
+             for i in range(batch_size):
+                 waveform = input[i]
+                 waveforms.append(
+                     waveform[
+                         : (
+                             (frame_num - 1) * self.frame_shift_sample_length
+                             + self.frame_sample_length
+                         )
+                     ]
+                 )
+                 waveform = waveform * (1 << 15)
+
+                 self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
+                 frames = self.fbank_fn.num_frames_ready
+                 mat = np.empty([frames, self.opts.mel_opts.num_bins])
+                 for i in range(frames):
+                     mat[i, :] = self.fbank_fn.get_frame(i)
+                 feat = mat.astype(np.float32)
+                 feat_len = np.array(mat.shape[0]).astype(np.int32)
+                 feats.append(feat)
+                 feats_lens.append(feat_len)
+
+             waveforms = np.stack(waveforms)
+             feats_lens = np.array(feats_lens)
+             feats_pad = np.array(feats)
+         self.fbanks = feats_pad
+         self.fbanks_lens = copy.deepcopy(feats_lens)
+         return waveforms, feats_pad, feats_lens
+
+     def get_fbank(self) -> Tuple[np.ndarray, np.ndarray]:
+         return self.fbanks, self.fbanks_lens
+
+     def lfr_cmvn(
+         self, input: np.ndarray, input_lengths: np.ndarray, is_final: bool = False
+     ) -> Tuple[np.ndarray, np.ndarray, List[int]]:
+         batch_size = input.shape[0]
+         feats = []
+         feats_lens = []
+         lfr_splice_frame_idxs = []
+         for i in range(batch_size):
+             mat = input[i, : input_lengths[i], :]
+             lfr_splice_frame_idx = -1
+             if self.lfr_m != 1 or self.lfr_n != 1:
+                 # update self.lfr_splice_cache in self.apply_lfr
+                 mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(
+                     mat, self.lfr_m, self.lfr_n, is_final
+                 )
+             if self.cmvn_file is not None:
+                 mat = self.apply_cmvn(mat)
+             feat_length = mat.shape[0]
+             feats.append(mat)
+             feats_lens.append(feat_length)
+             lfr_splice_frame_idxs.append(lfr_splice_frame_idx)
+
+         feats_lens = np.array(feats_lens)
+         feats_pad = np.array(feats)
+         return feats_pad, feats_lens, lfr_splice_frame_idxs
+
+     def extract_fbank(
+         self, input: np.ndarray, input_lengths: np.ndarray, is_final: bool = False
+     ) -> Tuple[np.ndarray, np.ndarray]:
+         batch_size = input.shape[0]
+         assert (
+             batch_size == 1
+         ), "online feature extraction currently supports batch size 1 only"
+         waveforms, feats, feats_lengths = self.fbank(input, input_lengths)  # input shape: B T D
+         if feats.shape[0]:
+             self.waveforms = (
+                 waveforms
+                 if self.reserve_waveforms is None
+                 else np.concatenate((self.reserve_waveforms, waveforms), axis=1)
+             )
+             if not self.lfr_splice_cache:
+                 for i in range(batch_size):
+                     self.lfr_splice_cache.append(
+                         np.expand_dims(feats[i][0, :], axis=0).repeat((self.lfr_m - 1) // 2, axis=0)
+                     )
+
+             if feats_lengths[0] + self.lfr_splice_cache[0].shape[0] >= self.lfr_m:
+                 lfr_splice_cache_np = np.stack(self.lfr_splice_cache)  # B T D
+                 feats = np.concatenate((lfr_splice_cache_np, feats), axis=1)
+                 feats_lengths += lfr_splice_cache_np[0].shape[0]
+                 frame_from_waveforms = int(
+                     (self.waveforms.shape[1] - self.frame_sample_length)
+                     / self.frame_shift_sample_length
+                     + 1
+                 )
+                 minus_frame = (self.lfr_m - 1) // 2 if self.reserve_waveforms is None else 0
+                 feats, feats_lengths, lfr_splice_frame_idxs = self.lfr_cmvn(
+                     feats, feats_lengths, is_final
+                 )
+                 if self.lfr_m == 1:
+                     self.reserve_waveforms = None
+                 else:
+                     reserve_frame_idx = lfr_splice_frame_idxs[0] - minus_frame
+                     # print('reserve_frame_idx: ' + str(reserve_frame_idx))
+                     # print('frame_frame: ' + str(frame_from_waveforms))
+                     self.reserve_waveforms = self.waveforms[
+                         :,
+                         reserve_frame_idx
+                         * self.frame_shift_sample_length : frame_from_waveforms
+                         * self.frame_shift_sample_length,
+                     ]
+                     sample_length = (
+                         frame_from_waveforms - 1
+                     ) * self.frame_shift_sample_length + self.frame_sample_length
+                     self.waveforms = self.waveforms[:, :sample_length]
+             else:
+                 # update self.reserve_waveforms and self.lfr_splice_cache
+                 self.reserve_waveforms = self.waveforms[
+                     :, : -(self.frame_sample_length - self.frame_shift_sample_length)
+                 ]
+                 for i in range(batch_size):
+                     self.lfr_splice_cache[i] = np.concatenate(
+                         (self.lfr_splice_cache[i], feats[i]), axis=0
+                     )
+                 return np.empty(0, dtype=np.float32), feats_lengths
+         else:
+             if is_final:
+                 self.waveforms = (
+                     waveforms if self.reserve_waveforms is None else self.reserve_waveforms
+                 )
+                 feats = np.stack(self.lfr_splice_cache)
+                 feats_lengths = np.zeros(batch_size, dtype=np.int32) + feats.shape[1]
+                 feats, feats_lengths, _ = self.lfr_cmvn(feats, feats_lengths, is_final)
+         if is_final:
+             self.cache_reset()
+         return feats, feats_lengths
+
+     def get_waveforms(self):
+         return self.waveforms
+
+     def cache_reset(self):
+         self.fbank_fn = knf.OnlineFbank(self.opts)
+         self.reserve_waveforms = None
+         self.input_cache = None
+         self.lfr_splice_cache = []
+
+
+ def load_bytes(input):
+     middle_data = np.frombuffer(input, dtype=np.int16)
+     middle_data = np.asarray(middle_data)
+     if middle_data.dtype.kind not in "iu":
+         raise TypeError("'middle_data' must be an array of integers")
+     dtype = np.dtype("float32")
+     if dtype.kind != "f":
+         raise TypeError("'dtype' must be a floating point type")
+
+     i = np.iinfo(middle_data.dtype)
+     abs_max = 2 ** (i.bits - 1)
+     offset = i.min + abs_max
+     array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
+     return array
+
+
+ class SinusoidalPositionEncoderOnline:
+     """Streaming Positional encoding."""
+
+     def encode(self, positions: np.ndarray = None, depth: int = None, dtype: np.dtype = np.float32):
+         batch_size = positions.shape[0]
+         positions = positions.astype(dtype)
+         log_timescale_increment = np.log(np.array([10000], dtype=dtype)) / (depth / 2 - 1)
+         inv_timescales = np.exp(np.arange(depth / 2).astype(dtype) * (-log_timescale_increment))
+         inv_timescales = np.reshape(inv_timescales, [batch_size, -1])
+         scaled_time = np.reshape(positions, [1, -1, 1]) * np.reshape(inv_timescales, [1, 1, -1])
+         encoding = np.concatenate((np.sin(scaled_time), np.cos(scaled_time)), axis=2)
+         return encoding.astype(dtype)
+
+     def forward(self, x, start_idx=0):
+         batch_size, timesteps, input_dim = x.shape
+         positions = np.arange(1, timesteps + 1 + start_idx)[None, :]
+         position_encoding = self.encode(positions, input_dim, x.dtype)
+
+         return x + position_encoding[:, start_idx : start_idx + timesteps]
+
+
+ def test():
+     path = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav"
+     import librosa
+
+     cmvn_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/am.mvn"
+     config_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/config.yaml"
+     from funasr.runtime.python.onnxruntime.rapid_paraformer.utils.utils import read_yaml
+
+     config = read_yaml(config_file)
+     waveform, _ = librosa.load(path, sr=None)
+     frontend = WavFrontend(
+         cmvn_file=cmvn_file,
+         **config["frontend_conf"],
+     )
+     speech, _ = frontend.fbank_online(waveform)  # 1d, (sample,), numpy
+     feat, feat_len = frontend.lfr_cmvn(
+         speech
+     )  # 2d, (frame, 450), np.float32 -> torch, torch.from_numpy(), dtype, (1, frame, 450)
+
+     frontend.reset_status()  # clear cache
+     return feat, feat_len
+
+
+ if __name__ == "__main__":
+     test()
main.py ADDED
@@ -0,0 +1,40 @@
+ import os, sys
+ import argparse
+ from SenseVoiceAx import SenseVoiceAx
+ from tokenizer import SentencepiecesTokenizer
+ from print_utils import rich_transcription_postprocess, rich_print_asr_res
+ from download_utils import download_model
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--input", "-i", required=True, type=str, help="Input audio file")
+     parser.add_argument("--language", "-l", required=False, type=str, default="auto", choices=["auto", "zh", "en", "yue", "ja", "ko"])
+     return parser.parse_args()
+
+
+ def main():
+     args = get_args()
+
+     input_audio = args.input
+     language = args.language
+     use_itn = True  # punctuation prediction (inverse text normalization)
+
+     model_path = os.path.join("sensevoice_ax650", "sensevoice.axmodel")
+     bpemodel = "chn_jpn_yue_eng_ko_spectok.bpe.model"
+
+     assert os.path.exists(model_path), f"model {model_path} does not exist"
+
+     print(f"input_audio: {input_audio}")
+     print(f"language: {language}")
+     print(f"use_itn: {use_itn}")
+     print(f"model_path: {model_path}")
+
+     tokenizer = SentencepiecesTokenizer(bpemodel=bpemodel)
+     pipeline = SenseVoiceAx(model_path, language, use_itn, tokenizer=tokenizer)
+     asr_res = pipeline.infer(input_audio, print_rtf=True)
+     print([rich_transcription_postprocess(i) for i in asr_res])
+     # rich_print_asr_res(asr_res)
+
+ if __name__ == "__main__":
+     main()
print_utils.py ADDED
@@ -0,0 +1,121 @@
+ emo_dict = {
+     "<|HAPPY|>": "😊",
+     "<|SAD|>": "😔",
+     "<|ANGRY|>": "😡",
+     "<|NEUTRAL|>": "",
+     "<|FEARFUL|>": "😰",
+     "<|DISGUSTED|>": "🤢",
+     "<|SURPRISED|>": "😮",
+ }
+
+ event_dict = {
+     "<|BGM|>": "🎼",
+     "<|Speech|>": "",
+     "<|Applause|>": "👏",
+     "<|Laughter|>": "😀",
+     "<|Cry|>": "😭",
+     "<|Sneeze|>": "🤧",
+     "<|Breath|>": "",
+     "<|Cough|>": "🤧",
+ }
+
+ lang_dict = {
+     "<|zh|>": "<|lang|>",
+     "<|en|>": "<|lang|>",
+     "<|yue|>": "<|lang|>",
+     "<|ja|>": "<|lang|>",
+     "<|ko|>": "<|lang|>",
+     "<|nospeech|>": "<|lang|>",
+ }
+
+ emoji_dict = {
+     "<|nospeech|><|Event_UNK|>": "❓",
+     "<|zh|>": "",
+     "<|en|>": "",
+     "<|yue|>": "",
+     "<|ja|>": "",
+     "<|ko|>": "",
+     "<|nospeech|>": "",
+     "<|HAPPY|>": "😊",
+     "<|SAD|>": "😔",
+     "<|ANGRY|>": "😡",
+     "<|NEUTRAL|>": "",
+     "<|BGM|>": "🎼",
+     "<|Speech|>": "",
+     "<|Applause|>": "👏",
+     "<|Laughter|>": "😀",
+     "<|FEARFUL|>": "😰",
+     "<|DISGUSTED|>": "🤢",
+     "<|SURPRISED|>": "😮",
+     "<|Cry|>": "😭",
+     "<|EMO_UNKNOWN|>": "",
+     "<|Sneeze|>": "🤧",
+     "<|Breath|>": "",
+     "<|Cough|>": "😷",
+     "<|Sing|>": "",
+     "<|Speech_Noise|>": "",
+     "<|withitn|>": "",
+     "<|woitn|>": "",
+     "<|GBG|>": "",
+     "<|Event_UNK|>": "",
+ }
+
+ emo_set = {"😊", "😔", "😡", "😰", "🤢", "😮"}
+ event_set = {
+     "🎼",
+     "👏",
+     "😀",
+     "😭",
+     "🤧",
+     "😷",
+ }
+
+
+ def format_str_v2(s):
+     sptk_dict = {}
+     for sptk in emoji_dict:
+         sptk_dict[sptk] = s.count(sptk)
+         s = s.replace(sptk, "")
+     emo = "<|NEUTRAL|>"
+     for e in emo_dict:
+         if sptk_dict[e] > sptk_dict[emo]:
+             emo = e
+     for e in event_dict:
+         if sptk_dict[e] > 0:
+             s = event_dict[e] + s
+     s = s + emo_dict[emo]
+
+     for emoji in emo_set.union(event_set):
+         s = s.replace(" " + emoji, emoji)
+         s = s.replace(emoji + " ", emoji)
+     return s.strip()
+
+ def rich_transcription_postprocess(s):
+     def get_emo(s):
+         return s[-1] if s[-1] in emo_set else None
+
+     def get_event(s):
+         return s[0] if s[0] in event_set else None
+
+     s = s.replace("<|nospeech|><|Event_UNK|>", "❓")
+     for lang in lang_dict:
+         s = s.replace(lang, "<|lang|>")
+     s_list = [format_str_v2(s_i).strip(" ") for s_i in s.split("<|lang|>")]
+     new_s = " " + s_list[0]
+     cur_ent_event = get_event(new_s)
+     for i in range(1, len(s_list)):
+         if len(s_list[i]) == 0:
+             continue
+         if get_event(s_list[i]) == cur_ent_event and get_event(s_list[i]) is not None:
+             s_list[i] = s_list[i][1:]
+         # else:
+         cur_ent_event = get_event(s_list[i])
+         if get_emo(s_list[i]) is not None and get_emo(s_list[i]) == get_emo(new_s):
+             new_s = new_s[:-1]
+         new_s += s_list[i].strip().lstrip()
+     new_s = new_s.replace("The.", " ")
+     return new_s.strip()
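+
+ # For instance (illustrative):
+ #   rich_transcription_postprocess("<|zh|><|HAPPY|><|Speech|><|withitn|>你好。")
+ # returns "你好。😊": the special tags are stripped and the dominant emotion is
+ # appended as an emoji.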
+
+ def rich_print_asr_res(asr_res):
+     res = "".join([rich_transcription_postprocess(i) for i in asr_res])
+     print(res)
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ huggingface_hub
+ numpy<2
+ kaldi-native-fbank
+ librosa==0.9.1
+ sentencepiece
sensevoice.axmodel → sensevoice_ax650/sensevoice.axmodel RENAMED
File without changes
tokenizer.py ADDED
@@ -0,0 +1,133 @@
+ import sentencepiece as spm
+
+ from pathlib import Path
+ from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
+
+ import json
+ from abc import abstractmethod
+ from abc import ABC
+ import numpy as np
+
+
+ class BaseTokenizer(ABC):
+     def __init__(
+         self,
+         token_list: Union[Path, str, Iterable[str]] = None,
+         unk_symbol: str = "<unk>",
+         **kwargs,
+     ):
+
+         if token_list is not None:
+             if isinstance(token_list, (Path, str)) and str(token_list).endswith(".txt"):
+                 token_list = Path(token_list)
+                 self.token_list_repr = str(token_list)
+                 self.token_list: List[str] = []
+
+                 with token_list.open("r", encoding="utf-8") as f:
+                     for idx, line in enumerate(f):
+                         line = line.rstrip()
+                         self.token_list.append(line)
+             elif isinstance(token_list, (Path, str)) and str(token_list).endswith(".json"):
+                 token_list = Path(token_list)
+                 self.token_list_repr = str(token_list)
+                 self.token_list: List[str] = []
+
+                 with open(token_list, "r", encoding="utf-8") as f:
+                     self.token_list = json.load(f)
+
+             else:
+                 self.token_list: List[str] = list(token_list)
+                 self.token_list_repr = ""
+                 for i, t in enumerate(self.token_list):
+                     if i == 3:
+                         break
+                     self.token_list_repr += f"{t}, "
+                 self.token_list_repr += f"... (NVocab={(len(self.token_list))})"
+
+             self.token2id: Dict[str, int] = {}
+             for i, t in enumerate(self.token_list):
+                 if t in self.token2id:
+                     raise RuntimeError(f'Symbol "{t}" is duplicated')
+                 self.token2id[t] = i
+
+             self.unk_symbol = unk_symbol
+             if self.unk_symbol not in self.token2id:
+                 raise RuntimeError(f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list")
+             self.unk_id = self.token2id[self.unk_symbol]
+
+     def encode(self, text, **kwargs):
+         tokens = self.text2tokens(text)
+         text_ints = self.tokens2ids(tokens)
+
+         return text_ints
+
+     def decode(self, text_ints):
+         token = self.ids2tokens(text_ints)
+         text = self.tokens2text(token)
+         return text
+
+     def get_num_vocabulary_size(self) -> int:
+         return len(self.token_list)
+
+     def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
+         if isinstance(integers, np.ndarray) and integers.ndim != 1:
+             raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
+         return [self.token_list[i] for i in integers]
+
+     def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
+         return [self.token2id.get(i, self.unk_id) for i in tokens]
+
+     @abstractmethod
+     def text2tokens(self, line: str) -> List[str]:
+         raise NotImplementedError
+
+     @abstractmethod
+     def tokens2text(self, tokens: Iterable[str]) -> str:
+         raise NotImplementedError
+
+
+ class SentencepiecesTokenizer(BaseTokenizer):
+     def __init__(self, bpemodel: Union[Path, str], **kwargs):
+         super().__init__(**kwargs)
+         self.bpemodel = str(bpemodel)
+         # NOTE(kamo):
+         # Don't build SentencePieceProcessor in __init__()
+         # because it's not picklable and it may cause following error,
+         # "TypeError: can't pickle SwigPyObject objects",
+         # when giving it as argument of "multiprocessing.Process()".
+         self.sp = None
+         self._build_sentence_piece_processor()
+
+     def __repr__(self):
+         return f'{self.__class__.__name__}(model="{self.bpemodel}")'
+
+     def _build_sentence_piece_processor(self):
+         # Build SentencePieceProcessor lazily.
+         if self.sp is None:
+             self.sp = spm.SentencePieceProcessor()
+             self.sp.load(self.bpemodel)
+
+     def text2tokens(self, line: str) -> List[str]:
+         self._build_sentence_piece_processor()
+         return self.sp.EncodeAsPieces(line)
+
+     def tokens2text(self, tokens: Iterable[str]) -> str:
+         self._build_sentence_piece_processor()
+         return self.sp.DecodePieces(list(tokens))
+
+     def encode(self, line: str, **kwargs) -> List[int]:
+         self._build_sentence_piece_processor()
+         return self.sp.EncodeAsIds(line)
+
+     def decode(self, line: List[int], **kwargs):
+         self._build_sentence_piece_processor()
+         return self.sp.DecodeIds(line)
+
+     def get_vocab_size(self):
+         return self.sp.GetPieceSize()
+
+     def ids2tokens(self, *args, **kwargs):
+         return self.decode(*args, **kwargs)
+
+     def tokens2ids(self, *args, **kwargs):
+         return self.encode(*args, **kwargs)
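+
+ # Usage sketch (illustrative; the BPE model file name comes from main.py):
+ #   tok = SentencepiecesTokenizer(bpemodel="chn_jpn_yue_eng_ko_spectok.bpe.model")
+ #   ids = tok.encode("hello world")
+ #   text = tok.decode(ids)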