change structure

- LICENSE +21 -0
- README.md +50 -3
- SenseVoiceAx.py +192 -0
- download_utils.py +29 -0
- auto.npy → embeddings/auto.npy +0 -0
- en.npy → embeddings/en.npy +0 -0
- event_emo.npy → embeddings/event_emo.npy +0 -0
- ja.npy → embeddings/ja.npy +0 -0
- ko.npy → embeddings/ko.npy +0 -0
- nospeech.npy → embeddings/nospeech.npy +0 -0
- position_encoding.npy → embeddings/position_encoding.npy +0 -0
- withitn.npy → embeddings/withitn.npy +0 -0
- woitn.npy → embeddings/woitn.npy +0 -0
- yue.npy → embeddings/yue.npy +0 -0
- zh.npy → embeddings/zh.npy +0 -0
- frontend.py +429 -0
- main.py +40 -0
- print_utils.py +121 -0
- requirements.txt +5 -0
- sensevoice.axmodel → sensevoice_ax650/sensevoice.axmodel +0 -0
- tokenizer.py +133 -0
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 祈Inory

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
CHANGED
@@ -1,3 +1,50 @@
# sensevoice.axera
FunASR SenseVoice on Axera, official repo: https://github.com/FunAudioLLM/SenseVoice

## Features
- Speech recognition
- Automatic language identification
- Emotion recognition
- Automatic punctuation

## Supported Platforms

- [x] AX650N
- [ ] AX630C

## Installation
```
pip3 install -r requirements.txt
```
If there is not enough space, you can use --prefix to install to a different location.


## Usage
```
# The first run automatically downloads the model from Hugging Face and saves it under models/
python3 main.py -i <input audio file>
```
Runtime options:
| Option | Description | Default |
| --- | --- | --- |
| --input/-i | Input audio file | |
| --language/-l | Recognition language; supports auto, zh, en, yue, ja, ko | auto |


### Example
Test audio files are provided under example/.

For example, a Cantonese test:
```
python3 main.py -i example/yue.mp3
```
Output:
```
RTF: 0.03026517820946964 Latency: 0.15689468383789062s Total length: 5.184s
['呢几个字。', '都表达唔到,我想讲嘅意。', '思。']
```

## Technical Discussion

- GitHub issues
- QQ group: 139953715
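For reference, the same pipeline can also be driven from Python instead of the CLI. A minimal sketch mirroring main.py; it assumes sensevoice_ax650/sensevoice.axmodel, the embeddings/ directory, am.mvn, and chn_jpn_yue_eng_ko_spectok.bpe.model are present in the working directory, and that axengine is installed on the board:

```
from SenseVoiceAx import SenseVoiceAx
from tokenizer import SentencepiecesTokenizer
from print_utils import rich_transcription_postprocess

# Build the tokenizer and the AX650 inference pipeline (paths as used by main.py)
tokenizer = SentencepiecesTokenizer(bpemodel="chn_jpn_yue_eng_ko_spectok.bpe.model")
pipeline = SenseVoiceAx("sensevoice_ax650/sensevoice.axmodel", language="auto", use_itn=True, tokenizer=tokenizer)

# Run recognition on one audio file and strip the rich tags for display
asr_res = pipeline.infer("example/yue.mp3", print_rtf=True)
print([rich_transcription_postprocess(r) for r in asr_res])
```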
SenseVoiceAx.py
ADDED
@@ -0,0 +1,192 @@
import axengine as axe
import numpy as np
import librosa
from frontend import WavFrontend
import os
import time
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union

def sequence_mask(lengths, maxlen=None, dtype=np.float32):
    # If maxlen is not given, use the maximum value in lengths
    if maxlen is None:
        maxlen = np.max(lengths)

    # Create a row vector from 0 to maxlen-1
    row_vector = np.arange(0, maxlen, 1)

    # Turn lengths into a column vector
    matrix = np.expand_dims(lengths, axis=-1)

    # Compare to build the mask
    mask = row_vector < matrix

    # Return the mask in the requested dtype
    return mask.astype(dtype)[None, ...]

def unique_consecutive_np(x, dim=None, return_inverse=False, return_counts=False):
    if dim is None:
        # Default case: flatten, then drop consecutive duplicates
        x_flat = x.ravel()
        mask = np.concatenate(([True], x_flat[1:] != x_flat[:-1]))
        unique_data = x_flat[mask]
    else:
        # Deduplicate along the given dimension
        axis = dim if dim >= 0 else x.ndim + dim
        if axis >= x.ndim:
            raise ValueError(f"dim {dim} is out of range for array of dimension {x.ndim}")

        # Use np.diff to check whether adjacent elements differ
        mask = np.ones(x.shape[axis], dtype=bool)
        if x.shape[axis] > 1:
            # Compare each element with the previous one
            diff = np.diff(x, axis=axis)
            mask[1:] = np.any(diff != 0, axis=tuple(range(diff.ndim))[axis:])

        # Use the mask indices to extract the unique elements
        unique_data = np.take(x, np.where(mask)[0], axis=axis)

    # Handle return_inverse and return_counts
    results = (unique_data,)

    if return_inverse:
        if dim is None:
            inv_idx = np.cumsum(mask) - 1
        else:
            inv_idx = np.cumsum(mask) - 1
            # Reshape to match the input
            inv_idx = np.expand_dims(inv_idx, axis=axis)
            inv_idx = np.broadcast_to(inv_idx, x.shape)
        results += (inv_idx,)

    if return_counts:
        if dim is None:
            counts = np.diff(np.where(np.concatenate((mask, [True])))[0])
        else:
            counts = np.diff(np.where(np.concatenate((mask, [True])))[0])
        results += (counts,)

    return results[0] if len(results) == 1 else results

class SenseVoiceAx:
    def __init__(self, model_path, language="auto", use_itn=True, tokenizer=None):
        model_path_root = os.path.join(os.path.dirname(model_path), "../embeddings")
        self.frontend = WavFrontend(cmvn_file="am.mvn",
                                    fs=16000,
                                    window="hamming",
                                    n_mels=80,
                                    frame_length=25,
                                    frame_shift=10,
                                    lfr_m=7,
                                    lfr_n=6,)
        self.model = axe.InferenceSession(model_path)
        self.sample_rate = 16000
        self.tokenizer = tokenizer
        self.blank_id = 0
        self.max_len = 34

        self.lid_dict = {"auto": 0, "zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeech": 13}
        self.lid_int_dict = {24884: 3, 24885: 4, 24888: 7, 24892: 11, 24896: 12, 24992: 13}
        self.textnorm_dict = {"withitn": 14, "woitn": 15}
        self.textnorm_int_dict = {25016: 14, 25017: 15}
        self.emo_dict = {"unk": 25009, "happy": 25001, "sad": 25002, "angry": 25003, "neutral": 25004}

        self.position_encoding = np.load(f"{model_path_root}/position_encoding.npy")
        language_query = np.load(f"{model_path_root}/{language}.npy")
        textnorm_query = np.load(f"{model_path_root}/withitn.npy") if use_itn else np.load(f"{model_path_root}/woitn.npy")
        event_emo_query = np.load(f"{model_path_root}/event_emo.npy")
        self.input_query = np.concatenate((textnorm_query, language_query, event_emo_query), axis=1)
        self.query_num = self.input_query.shape[1]
        self.masks = sequence_mask(np.array([self.max_len], dtype=np.int32), dtype=np.float32)

    def load_data(self, filepath: str) -> np.ndarray:
        waveform, _ = librosa.load(filepath, sr=self.sample_rate)
        return waveform.flatten()

    @staticmethod
    def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
        def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
            pad_width = ((0, max_feat_len - cur_len), (0, 0))
            return np.pad(feat, pad_width, "constant", constant_values=0)

        feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
        feats = np.array(feat_res).astype(np.float32)
        return feats

    def preprocess(self, waveform):
        feats, feats_len = [], []
        for wf in [waveform]:
            speech, _ = self.frontend.fbank(wf)
            feat, feat_len = self.frontend.lfr_cmvn(speech)
            feats.append(feat)
            feats_len.append(feat_len)

        feats = self.pad_feats(feats, np.max(feats_len))
        feats_len = np.array(feats_len).astype(np.int32)
        return feats, feats_len

    def postprocess(self, ctc_logits, encoder_out_lens):
        # Slice out the valid frames
        x = ctc_logits[0, :encoder_out_lens[0], :]

        # Greedy decoding: take the argmax per frame
        yseq = np.argmax(x, axis=-1)

        # Collapse consecutive repeated tokens
        yseq = unique_consecutive_np(yseq, dim=-1)

        # Build a mask and filter out blank_id
        mask = yseq != self.blank_id
        token_int = yseq[mask].tolist()

        return token_int

    def infer_waveform(self, waveform: np.ndarray):
        feat, feat_len = self.preprocess(waveform)

        slice_len = self.max_len - self.query_num
        slice_num = int(np.ceil(feat.shape[1] / slice_len))

        asr_res = []
        for i in range(slice_num):
            sub_feat = feat[:, i*slice_len:(i+1)*slice_len, :]
            # concat query
            sub_feat = np.concatenate([self.input_query, sub_feat], axis=1)

            if sub_feat.shape[1] < self.max_len:
                sub_feat = np.concatenate([
                    sub_feat,
                    np.zeros((1, self.max_len - sub_feat.shape[1], sub_feat.shape[-1]), dtype=np.float32)
                ],
                axis=1)

            outputs = self.model.run(None, {"speech": sub_feat,
                                            "masks": self.masks,
                                            "position_encoding": self.position_encoding})
            ctc_logits, encoder_out_lens = outputs

            token_int = self.postprocess(ctc_logits, encoder_out_lens)
            if self.tokenizer is not None:
                asr_res.append(self.tokenizer.tokens2text(token_int))
            else:
                asr_res.append(token_int)

        return asr_res

    def infer(self, filepath_or_data: Union[np.ndarray, str], print_rtf=True):
        if isinstance(filepath_or_data, str):
            waveform = self.load_data(filepath_or_data)
        else:
            waveform = filepath_or_data

        total_time = waveform.shape[-1] / self.sample_rate

        start = time.time()
        asr_res = self.infer_waveform(waveform)
        latency = time.time() - start

        if print_rtf:
            rtf = latency / total_time
            print(f"RTF: {rtf} Latency: {latency}s Total length: {total_time}s")
        return asr_res
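The CTC greedy decoding in postprocess boils down to collapsing repeated argmax ids and then dropping the blank id. A standalone numpy illustration with made-up token ids (not taken from the model):

```
import numpy as np

# Hypothetical per-frame argmax ids; 0 is the CTC blank, as in SenseVoiceAx.blank_id
yseq = np.array([0, 0, 7, 7, 7, 0, 3, 3, 0, 0, 5])

# Collapse consecutive duplicates (what unique_consecutive_np(yseq, dim=-1) returns)
collapsed = yseq[np.concatenate(([True], yseq[1:] != yseq[:-1]))]

# Drop blanks to obtain the final token ids fed to the tokenizer
token_int = collapsed[collapsed != 0].tolist()
print(token_int)  # -> [7, 3, 5]
```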
download_utils.py
ADDED
@@ -0,0 +1,29 @@
import os
# Speed up HF downloads by using a mirror URL
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
from huggingface_hub import snapshot_download

current_file_path = os.path.dirname(__file__)
REPO_ROOT = "AXERA-TECH"
CACHE_PATH = os.path.join(current_file_path, "models")

def download_model(model_name: str) -> str:
    """
    Download a model from the AXERA-TECH Hugging Face organization.

    model_name: str
        Available model names can be found at https://huggingface.co/AXERA-TECH.

    Returns:
        str: Local path to the downloaded model directory.

    """
    os.makedirs(CACHE_PATH, exist_ok=True)

    model_path = os.path.join(CACHE_PATH, model_name)
    if not os.path.exists(model_path):
        print(f"Downloading {model_name}...")
        snapshot_download(repo_id=f"{REPO_ROOT}/{model_name}",
                          local_dir=os.path.join(CACHE_PATH, model_name))

    return model_path
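A minimal usage sketch of download_model; the model name below is a placeholder, actual repo names are listed at https://huggingface.co/AXERA-TECH:

```
from download_utils import download_model

# Downloads AXERA-TECH/<model_name> into ./models/<model_name> on first call;
# later calls just return the cached path ("SenseVoice" is a hypothetical name here)
model_dir = download_model("SenseVoice")
print(model_dir)
```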
auto.npy → embeddings/auto.npy
RENAMED
File without changes
en.npy → embeddings/en.npy
RENAMED
File without changes
event_emo.npy → embeddings/event_emo.npy
RENAMED
File without changes
ja.npy → embeddings/ja.npy
RENAMED
File without changes
ko.npy → embeddings/ko.npy
RENAMED
File without changes
nospeech.npy → embeddings/nospeech.npy
RENAMED
File without changes
position_encoding.npy → embeddings/position_encoding.npy
RENAMED
File without changes
withitn.npy → embeddings/withitn.npy
RENAMED
File without changes
woitn.npy → embeddings/woitn.npy
RENAMED
File without changes
yue.npy → embeddings/yue.npy
RENAMED
File without changes
zh.npy → embeddings/zh.npy
RENAMED
File without changes
frontend.py
ADDED
@@ -0,0 +1,429 @@
# -*- encoding: utf-8 -*-
from pathlib import Path
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
import copy

import numpy as np
import kaldi_native_fbank as knf


class WavFrontend:
    """Conventional frontend structure for ASR."""

    def __init__(
        self,
        cmvn_file: str = None,
        fs: int = 16000,
        window: str = "hamming",
        n_mels: int = 80,
        frame_length: int = 25,
        frame_shift: int = 10,
        lfr_m: int = 1,
        lfr_n: int = 1,
        dither: float = 1.0,
        **kwargs,
    ) -> None:

        opts = knf.FbankOptions()
        opts.frame_opts.samp_freq = fs
        opts.frame_opts.dither = dither
        opts.frame_opts.window_type = window
        opts.frame_opts.frame_shift_ms = float(frame_shift)
        opts.frame_opts.frame_length_ms = float(frame_length)
        opts.mel_opts.num_bins = n_mels
        opts.energy_floor = 0
        opts.frame_opts.snip_edges = True
        opts.mel_opts.debug_mel = False
        self.opts = opts

        self.lfr_m = lfr_m
        self.lfr_n = lfr_n
        self.cmvn_file = cmvn_file

        if self.cmvn_file:
            self.cmvn = self.load_cmvn()
        self.fbank_fn = None
        self.fbank_beg_idx = 0
        self.reset_status()

    def fbank(self, waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        waveform = waveform * (1 << 15)
        self.fbank_fn = knf.OnlineFbank(self.opts)
        self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
        frames = self.fbank_fn.num_frames_ready
        mat = np.empty([frames, self.opts.mel_opts.num_bins])
        for i in range(frames):
            mat[i, :] = self.fbank_fn.get_frame(i)
        feat = mat.astype(np.float32)
        feat_len = np.array(mat.shape[0]).astype(np.int32)
        return feat, feat_len

    def fbank_online(self, waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        waveform = waveform * (1 << 15)
        # self.fbank_fn = knf.OnlineFbank(self.opts)
        self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
        frames = self.fbank_fn.num_frames_ready
        mat = np.empty([frames, self.opts.mel_opts.num_bins])
        for i in range(self.fbank_beg_idx, frames):
            mat[i, :] = self.fbank_fn.get_frame(i)
        # self.fbank_beg_idx += (frames-self.fbank_beg_idx)
        feat = mat.astype(np.float32)
        feat_len = np.array(mat.shape[0]).astype(np.int32)
        return feat, feat_len

    def reset_status(self):
        self.fbank_fn = knf.OnlineFbank(self.opts)
        self.fbank_beg_idx = 0

    def lfr_cmvn(self, feat: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        if self.lfr_m != 1 or self.lfr_n != 1:
            feat = self.apply_lfr(feat, self.lfr_m, self.lfr_n)

        if self.cmvn_file:
            feat = self.apply_cmvn(feat)

        feat_len = np.array(feat.shape[0]).astype(np.int32)
        return feat, feat_len

    @staticmethod
    def apply_lfr(inputs: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray:
        LFR_inputs = []

        T = inputs.shape[0]
        T_lfr = int(np.ceil(T / lfr_n))
        left_padding = np.tile(inputs[0], ((lfr_m - 1) // 2, 1))
        inputs = np.vstack((left_padding, inputs))
        T = T + (lfr_m - 1) // 2
        for i in range(T_lfr):
            if lfr_m <= T - i * lfr_n:
                LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).reshape(1, -1))
            else:
                # process last LFR frame
                num_padding = lfr_m - (T - i * lfr_n)
                frame = inputs[i * lfr_n :].reshape(-1)
                for _ in range(num_padding):
                    frame = np.hstack((frame, inputs[-1]))

                LFR_inputs.append(frame)
        LFR_outputs = np.vstack(LFR_inputs).astype(np.float32)
        return LFR_outputs

    def apply_cmvn(self, inputs: np.ndarray) -> np.ndarray:
        """
        Apply CMVN with mvn data
        """
        frame, dim = inputs.shape
        means = np.tile(self.cmvn[0:1, :dim], (frame, 1))
        vars = np.tile(self.cmvn[1:2, :dim], (frame, 1))
        inputs = (inputs + means) * vars
        return inputs

    def load_cmvn(
        self,
    ) -> np.ndarray:
        with open(self.cmvn_file, "r", encoding="utf-8") as f:
            lines = f.readlines()

        means_list = []
        vars_list = []
        for i in range(len(lines)):
            line_item = lines[i].split()
            if line_item[0] == "<AddShift>":
                line_item = lines[i + 1].split()
                if line_item[0] == "<LearnRateCoef>":
                    add_shift_line = line_item[3 : (len(line_item) - 1)]
                    means_list = list(add_shift_line)
                    continue
            elif line_item[0] == "<Rescale>":
                line_item = lines[i + 1].split()
                if line_item[0] == "<LearnRateCoef>":
                    rescale_line = line_item[3 : (len(line_item) - 1)]
                    vars_list = list(rescale_line)
                    continue

        means = np.array(means_list).astype(np.float64)
        vars = np.array(vars_list).astype(np.float64)
        cmvn = np.array([means, vars])
        return cmvn


class WavFrontendOnline(WavFrontend):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # self.fbank_fn = knf.OnlineFbank(self.opts)
        # add variables
        self.frame_sample_length = int(
            self.opts.frame_opts.frame_length_ms * self.opts.frame_opts.samp_freq / 1000
        )
        self.frame_shift_sample_length = int(
            self.opts.frame_opts.frame_shift_ms * self.opts.frame_opts.samp_freq / 1000
        )
        self.waveform = None
        self.reserve_waveforms = None
        self.input_cache = None
        self.lfr_splice_cache = []

    @staticmethod
    # inputs has catted the cache
    def apply_lfr(
        inputs: np.ndarray, lfr_m: int, lfr_n: int, is_final: bool = False
    ) -> Tuple[np.ndarray, np.ndarray, int]:
        """
        Apply lfr with data
        """

        LFR_inputs = []
        T = inputs.shape[0]  # include the right context
        T_lfr = int(
            np.ceil((T - (lfr_m - 1) // 2) / lfr_n)
        )  # minus the right context: (lfr_m - 1) // 2
        splice_idx = T_lfr
        for i in range(T_lfr):
            if lfr_m <= T - i * lfr_n:
                LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).reshape(1, -1))
            else:  # process last LFR frame
                if is_final:
                    num_padding = lfr_m - (T - i * lfr_n)
                    frame = (inputs[i * lfr_n :]).reshape(-1)
                    for _ in range(num_padding):
                        frame = np.hstack((frame, inputs[-1]))
                    LFR_inputs.append(frame)
                else:
                    # update splice_idx and break the circle
                    splice_idx = i
                    break
        splice_idx = min(T - 1, splice_idx * lfr_n)
        lfr_splice_cache = inputs[splice_idx:, :]
        LFR_outputs = np.vstack(LFR_inputs)
        return LFR_outputs.astype(np.float32), lfr_splice_cache, splice_idx

    @staticmethod
    def compute_frame_num(
        sample_length: int, frame_sample_length: int, frame_shift_sample_length: int
    ) -> int:
        frame_num = int((sample_length - frame_sample_length) / frame_shift_sample_length + 1)
        return frame_num if frame_num >= 1 and sample_length >= frame_sample_length else 0

    def fbank(
        self, input: np.ndarray, input_lengths: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        self.fbank_fn = knf.OnlineFbank(self.opts)
        batch_size = input.shape[0]
        if self.input_cache is None:
            self.input_cache = np.empty((batch_size, 0), dtype=np.float32)
        input = np.concatenate((self.input_cache, input), axis=1)
        frame_num = self.compute_frame_num(
            input.shape[-1], self.frame_sample_length, self.frame_shift_sample_length
        )
        # update self.in_cache
        self.input_cache = input[
            :, -(input.shape[-1] - frame_num * self.frame_shift_sample_length) :
        ]
        waveforms = np.empty(0, dtype=np.float32)
        feats_pad = np.empty(0, dtype=np.float32)
        feats_lens = np.empty(0, dtype=np.int32)
        if frame_num:
            waveforms = []
            feats = []
            feats_lens = []
            for i in range(batch_size):
                waveform = input[i]
                waveforms.append(
                    waveform[
                        : (
                            (frame_num - 1) * self.frame_shift_sample_length
                            + self.frame_sample_length
                        )
                    ]
                )
                waveform = waveform * (1 << 15)

                self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
                frames = self.fbank_fn.num_frames_ready
                mat = np.empty([frames, self.opts.mel_opts.num_bins])
                for i in range(frames):
                    mat[i, :] = self.fbank_fn.get_frame(i)
                feat = mat.astype(np.float32)
                feat_len = np.array(mat.shape[0]).astype(np.int32)
                feats.append(feat)
                feats_lens.append(feat_len)

            waveforms = np.stack(waveforms)
            feats_lens = np.array(feats_lens)
            feats_pad = np.array(feats)
        self.fbanks = feats_pad
        self.fbanks_lens = copy.deepcopy(feats_lens)
        return waveforms, feats_pad, feats_lens

    def get_fbank(self) -> Tuple[np.ndarray, np.ndarray]:
        return self.fbanks, self.fbanks_lens

    def lfr_cmvn(
        self, input: np.ndarray, input_lengths: np.ndarray, is_final: bool = False
    ) -> Tuple[np.ndarray, np.ndarray, List[int]]:
        batch_size = input.shape[0]
        feats = []
        feats_lens = []
        lfr_splice_frame_idxs = []
        for i in range(batch_size):
            mat = input[i, : input_lengths[i], :]
            lfr_splice_frame_idx = -1
            if self.lfr_m != 1 or self.lfr_n != 1:
                # update self.lfr_splice_cache in self.apply_lfr
                mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(
                    mat, self.lfr_m, self.lfr_n, is_final
                )
            if self.cmvn_file is not None:
                mat = self.apply_cmvn(mat)
            feat_length = mat.shape[0]
            feats.append(mat)
            feats_lens.append(feat_length)
            lfr_splice_frame_idxs.append(lfr_splice_frame_idx)

        feats_lens = np.array(feats_lens)
        feats_pad = np.array(feats)
        return feats_pad, feats_lens, lfr_splice_frame_idxs

    def extract_fbank(
        self, input: np.ndarray, input_lengths: np.ndarray, is_final: bool = False
    ) -> Tuple[np.ndarray, np.ndarray]:
        batch_size = input.shape[0]
        assert (
            batch_size == 1
        ), "we support to extract feature online only when the batch size is equal to 1 now"
        waveforms, feats, feats_lengths = self.fbank(input, input_lengths)  # input shape: B T D
        if feats.shape[0]:
            self.waveforms = (
                waveforms
                if self.reserve_waveforms is None
                else np.concatenate((self.reserve_waveforms, waveforms), axis=1)
            )
            if not self.lfr_splice_cache:
                for i in range(batch_size):
                    self.lfr_splice_cache.append(
                        np.expand_dims(feats[i][0, :], axis=0).repeat((self.lfr_m - 1) // 2, axis=0)
                    )

            if feats_lengths[0] + self.lfr_splice_cache[0].shape[0] >= self.lfr_m:
                lfr_splice_cache_np = np.stack(self.lfr_splice_cache)  # B T D
                feats = np.concatenate((lfr_splice_cache_np, feats), axis=1)
                feats_lengths += lfr_splice_cache_np[0].shape[0]
                frame_from_waveforms = int(
                    (self.waveforms.shape[1] - self.frame_sample_length)
                    / self.frame_shift_sample_length
                    + 1
                )
                minus_frame = (self.lfr_m - 1) // 2 if self.reserve_waveforms is None else 0
                feats, feats_lengths, lfr_splice_frame_idxs = self.lfr_cmvn(
                    feats, feats_lengths, is_final
                )
                if self.lfr_m == 1:
                    self.reserve_waveforms = None
                else:
                    reserve_frame_idx = lfr_splice_frame_idxs[0] - minus_frame
                    # print('reserve_frame_idx: ' + str(reserve_frame_idx))
                    # print('frame_frame: ' + str(frame_from_waveforms))
                    self.reserve_waveforms = self.waveforms[
                        :,
                        reserve_frame_idx
                        * self.frame_shift_sample_length : frame_from_waveforms
                        * self.frame_shift_sample_length,
                    ]
                    sample_length = (
                        frame_from_waveforms - 1
                    ) * self.frame_shift_sample_length + self.frame_sample_length
                    self.waveforms = self.waveforms[:, :sample_length]
            else:
                # update self.reserve_waveforms and self.lfr_splice_cache
                self.reserve_waveforms = self.waveforms[
                    :, : -(self.frame_sample_length - self.frame_shift_sample_length)
                ]
                for i in range(batch_size):
                    self.lfr_splice_cache[i] = np.concatenate(
                        (self.lfr_splice_cache[i], feats[i]), axis=0
                    )
                return np.empty(0, dtype=np.float32), feats_lengths
        else:
            if is_final:
                self.waveforms = (
                    waveforms if self.reserve_waveforms is None else self.reserve_waveforms
                )
                feats = np.stack(self.lfr_splice_cache)
                feats_lengths = np.zeros(batch_size, dtype=np.int32) + feats.shape[1]
                feats, feats_lengths, _ = self.lfr_cmvn(feats, feats_lengths, is_final)
        if is_final:
            self.cache_reset()
        return feats, feats_lengths

    def get_waveforms(self):
        return self.waveforms

    def cache_reset(self):
        self.fbank_fn = knf.OnlineFbank(self.opts)
        self.reserve_waveforms = None
        self.input_cache = None
        self.lfr_splice_cache = []


def load_bytes(input):
    middle_data = np.frombuffer(input, dtype=np.int16)
    middle_data = np.asarray(middle_data)
    if middle_data.dtype.kind not in "iu":
        raise TypeError("'middle_data' must be an array of integers")
    dtype = np.dtype("float32")
    if dtype.kind != "f":
        raise TypeError("'dtype' must be a floating point type")

    i = np.iinfo(middle_data.dtype)
    abs_max = 2 ** (i.bits - 1)
    offset = i.min + abs_max
    array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
    return array


class SinusoidalPositionEncoderOnline:
    """Streaming Positional encoding."""

    def encode(self, positions: np.ndarray = None, depth: int = None, dtype: np.dtype = np.float32):
        batch_size = positions.shape[0]
        positions = positions.astype(dtype)
        log_timescale_increment = np.log(np.array([10000], dtype=dtype)) / (depth / 2 - 1)
        inv_timescales = np.exp(np.arange(depth / 2).astype(dtype) * (-log_timescale_increment))
        inv_timescales = np.reshape(inv_timescales, [batch_size, -1])
        scaled_time = np.reshape(positions, [1, -1, 1]) * np.reshape(inv_timescales, [1, 1, -1])
        encoding = np.concatenate((np.sin(scaled_time), np.cos(scaled_time)), axis=2)
        return encoding.astype(dtype)

    def forward(self, x, start_idx=0):
        batch_size, timesteps, input_dim = x.shape
        positions = np.arange(1, timesteps + 1 + start_idx)[None, :]
        position_encoding = self.encode(positions, input_dim, x.dtype)

        return x + position_encoding[:, start_idx : start_idx + timesteps]


def test():
    path = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav"
    import librosa

    cmvn_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/am.mvn"
    config_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/config.yaml"
    from funasr.runtime.python.onnxruntime.rapid_paraformer.utils.utils import read_yaml

    config = read_yaml(config_file)
    waveform, _ = librosa.load(path, sr=None)
    frontend = WavFrontend(
        cmvn_file=cmvn_file,
        **config["frontend_conf"],
    )
    speech, _ = frontend.fbank_online(waveform)  # 1d, (sample,), numpy
    feat, feat_len = frontend.lfr_cmvn(
        speech
    )  # 2d, (frame, 450), np.float32 -> torch, torch.from_numpy(), dtype, (1, frame, 450)

    frontend.reset_status()  # clear cache
    return feat, feat_len


if __name__ == "__main__":
    test()
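A small self-contained sketch of the offline WavFrontend on a synthetic waveform; cmvn_file is left unset so no am.mvn is needed, and lfr_m=7, lfr_n=6 match the values SenseVoiceAx.py passes:

```
import numpy as np
from frontend import WavFrontend

# One second of a 440 Hz tone at 16 kHz, in the [-1, 1] float range librosa would produce
sr = 16000
waveform = 0.1 * np.sin(2 * np.pi * 440 * np.arange(sr) / sr).astype(np.float32)

frontend = WavFrontend(cmvn_file=None, fs=sr, n_mels=80, lfr_m=7, lfr_n=6)
speech, _ = frontend.fbank(waveform)        # (frames, 80) log-mel features
feat, feat_len = frontend.lfr_cmvn(speech)  # LFR-stacked features, (ceil(frames/6), 560)
print(speech.shape, feat.shape, int(feat_len))
```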
main.py
ADDED
@@ -0,0 +1,40 @@
import os, sys
import argparse
from SenseVoiceAx import SenseVoiceAx
from tokenizer import SentencepiecesTokenizer
from print_utils import rich_transcription_postprocess, rich_print_asr_res
from download_utils import download_model


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", "-i", required=True, type=str, help="Input audio file")
    parser.add_argument("--language", "-l", required=False, type=str, default="auto", choices=["auto", "zh", "en", "yue", "ja", "ko"])
    return parser.parse_args()


def main():
    args = get_args()

    input_audio = args.input
    language = args.language
    use_itn = True  # punctuation prediction (inverse text normalization)

    model_path = os.path.join("sensevoice_ax650", "sensevoice.axmodel")
    bpemodel = "chn_jpn_yue_eng_ko_spectok.bpe.model"

    assert os.path.exists(model_path), f"model {model_path} does not exist"

    print(f"input_audio: {input_audio}")
    print(f"language: {language}")
    print(f"use_itn: {use_itn}")
    print(f"model_path: {model_path}")

    tokenizer = SentencepiecesTokenizer(bpemodel=bpemodel)
    pipeline = SenseVoiceAx(model_path, language, use_itn, tokenizer=tokenizer)
    asr_res = pipeline.infer(input_audio, print_rtf=True)
    print([rich_transcription_postprocess(i) for i in asr_res])
    # rich_print_asr_res(asr_res)

if __name__ == "__main__":
    main()
print_utils.py
ADDED
@@ -0,0 +1,121 @@
emo_dict = {
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
}

event_dict = {
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|Cry|>": "😭",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "🤧",
}

lang_dict = {
    "<|zh|>": "<|lang|>",
    "<|en|>": "<|lang|>",
    "<|yue|>": "<|lang|>",
    "<|ja|>": "<|lang|>",
    "<|ko|>": "<|lang|>",
    "<|nospeech|>": "<|lang|>",
}

emoji_dict = {
    "<|nospeech|><|Event_UNK|>": "❓",
    "<|zh|>": "",
    "<|en|>": "",
    "<|yue|>": "",
    "<|ja|>": "",
    "<|ko|>": "",
    "<|nospeech|>": "",
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
    "<|Cry|>": "😭",
    "<|EMO_UNKNOWN|>": "",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "😷",
    "<|Sing|>": "",
    "<|Speech_Noise|>": "",
    "<|withitn|>": "",
    "<|woitn|>": "",
    "<|GBG|>": "",
    "<|Event_UNK|>": "",
}

emo_set = {"😊", "😔", "😡", "😰", "🤢", "😮"}
event_set = {
    "🎼",
    "👏",
    "😀",
    "😭",
    "🤧",
    "😷",
}


def format_str_v2(s):
    sptk_dict = {}
    for sptk in emoji_dict:
        sptk_dict[sptk] = s.count(sptk)
        s = s.replace(sptk, "")
    emo = "<|NEUTRAL|>"
    for e in emo_dict:
        if sptk_dict[e] > sptk_dict[emo]:
            emo = e
    for e in event_dict:
        if sptk_dict[e] > 0:
            s = event_dict[e] + s
    s = s + emo_dict[emo]

    for emoji in emo_set.union(event_set):
        s = s.replace(" " + emoji, emoji)
        s = s.replace(emoji + " ", emoji)
    return s.strip()

def rich_transcription_postprocess(s):
    def get_emo(s):
        return s[-1] if s[-1] in emo_set else None

    def get_event(s):
        return s[0] if s[0] in event_set else None

    s = s.replace("<|nospeech|><|Event_UNK|>", "❓")
    for lang in lang_dict:
        s = s.replace(lang, "<|lang|>")
    s_list = [format_str_v2(s_i).strip(" ") for s_i in s.split("<|lang|>")]
    new_s = " " + s_list[0]
    cur_ent_event = get_event(new_s)
    for i in range(1, len(s_list)):
        if len(s_list[i]) == 0:
            continue
        if get_event(s_list[i]) == cur_ent_event and get_event(s_list[i]) != None:
            s_list[i] = s_list[i][1:]
        # else:
        cur_ent_event = get_event(s_list[i])
        if get_emo(s_list[i]) != None and get_emo(s_list[i]) == get_emo(new_s):
            new_s = new_s[:-1]
        new_s += s_list[i].strip().lstrip()
    new_s = new_s.replace("The.", " ")
    return new_s.strip()

def rich_print_asr_res(asr_res):
    res = "".join([rich_transcription_postprocess(i) for i in asr_res])
    print(res)
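A quick sketch of the tag-to-emoji postprocessing on a made-up SenseVoice-style output string:

```
from print_utils import rich_transcription_postprocess

# Hypothetical raw model output carrying language/emotion/event/itn tags
raw = "<|zh|><|HAPPY|><|Speech|><|withitn|>今天天气真好。"
print(rich_transcription_postprocess(raw))  # tags removed, emotion rendered as an emoji
```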
requirements.txt
ADDED
@@ -0,0 +1,5 @@
huggingface_hub
numpy<2
kaldi-native-fbank
librosa==0.9.1
sentencepiece
sensevoice.axmodel → sensevoice_ax650/sensevoice.axmodel
RENAMED
File without changes
tokenizer.py
ADDED
@@ -0,0 +1,133 @@
import sentencepiece as spm

from pathlib import Path
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union

import json
from abc import abstractmethod
from abc import ABC
import numpy as np


class BaseTokenizer(ABC):
    def __init__(
        self,
        token_list: Union[Path, str, Iterable[str]] = None,
        unk_symbol: str = "<unk>",
        **kwargs,
    ):

        if token_list is not None:
            if isinstance(token_list, (Path, str)) and token_list.endswith(".txt"):
                token_list = Path(token_list)
                self.token_list_repr = str(token_list)
                self.token_list: List[str] = []

                with token_list.open("r", encoding="utf-8") as f:
                    for idx, line in enumerate(f):
                        line = line.rstrip()
                        self.token_list.append(line)
            elif isinstance(token_list, (Path, str)) and token_list.endswith(".json"):
                token_list = Path(token_list)
                self.token_list_repr = str(token_list)
                self.token_list: List[str] = []

                with open(token_list, "r", encoding="utf-8") as f:
                    self.token_list = json.load(f)

            else:
                self.token_list: List[str] = list(token_list)
                self.token_list_repr = ""
                for i, t in enumerate(self.token_list):
                    if i == 3:
                        break
                    self.token_list_repr += f"{t}, "
                self.token_list_repr += f"... (NVocab={(len(self.token_list))})"

            self.token2id: Dict[str, int] = {}
            for i, t in enumerate(self.token_list):
                if t in self.token2id:
                    raise RuntimeError(f'Symbol "{t}" is duplicated')
                self.token2id[t] = i

            self.unk_symbol = unk_symbol
            if self.unk_symbol not in self.token2id:
                raise RuntimeError(f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list")
            self.unk_id = self.token2id[self.unk_symbol]

    def encode(self, text, **kwargs):
        tokens = self.text2tokens(text)
        text_ints = self.tokens2ids(tokens)

        return text_ints

    def decode(self, text_ints):
        token = self.ids2tokens(text_ints)
        text = self.tokens2text(token)
        return text

    def get_num_vocabulary_size(self) -> int:
        return len(self.token_list)

    def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
        if isinstance(integers, np.ndarray) and integers.ndim != 1:
            raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
        return [self.token_list[i] for i in integers]

    def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
        return [self.token2id.get(i, self.unk_id) for i in tokens]

    @abstractmethod
    def text2tokens(self, line: str) -> List[str]:
        raise NotImplementedError

    @abstractmethod
    def tokens2text(self, tokens: Iterable[str]) -> str:
        raise NotImplementedError


class SentencepiecesTokenizer(BaseTokenizer):
    def __init__(self, bpemodel: Union[Path, str], **kwargs):
        super().__init__(**kwargs)
        self.bpemodel = str(bpemodel)
        # NOTE(kamo):
        # Don't build SentencePieceProcessor in __init__()
        # because it's not picklable and it may cause following error,
        # "TypeError: can't pickle SwigPyObject objects",
        # when giving it as argument of "multiprocessing.Process()".
        self.sp = None
        self._build_sentence_piece_processor()

    def __repr__(self):
        return f'{self.__class__.__name__}(model="{self.bpemodel}")'

    def _build_sentence_piece_processor(self):
        # Build SentencePieceProcessor lazily.
        if self.sp is None:
            self.sp = spm.SentencePieceProcessor()
            self.sp.load(self.bpemodel)

    def text2tokens(self, line: str) -> List[str]:
        self._build_sentence_piece_processor()
        return self.sp.EncodeAsPieces(line)

    def tokens2text(self, tokens: Iterable[str]) -> str:
        self._build_sentence_piece_processor()
        return self.sp.DecodePieces(list(tokens))

    def encode(self, line: str, **kwargs) -> List[int]:
        self._build_sentence_piece_processor()
        return self.sp.EncodeAsIds(line)

    def decode(self, line: List[int], **kwargs):
        self._build_sentence_piece_processor()
        return self.sp.DecodeIds(line)

    def get_vocab_size(self):
        return self.sp.GetPieceSize()

    def ids2tokens(self, *args, **kwargs):
        return self.decode(*args, **kwargs)

    def tokens2ids(self, *args, **kwargs):
        return self.encode(*args, **kwargs)
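A short usage sketch of SentencepiecesTokenizer, assuming the chn_jpn_yue_eng_ko_spectok.bpe.model file referenced by main.py is in the working directory:

```
from tokenizer import SentencepiecesTokenizer

tokenizer = SentencepiecesTokenizer(bpemodel="chn_jpn_yue_eng_ko_spectok.bpe.model")

# encode()/decode() round-trip text through SentencePiece ids;
# SenseVoiceAx feeds its decoded CTC token ids to tokens2text()
ids = tokenizer.encode("hello world")
print(ids)
print(tokenizer.decode(ids))
```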