# File size: 3,666 Bytes (revision 3c50954)
import os
from typing import List
import jieba
import torch
from wenet.cli.hub import Hub
from wenet.paraformer.search import _isAllAlpha
from wenet.text.char_tokenizer import CharTokenizer
class PuncModel:
    """Punctuation restorer backed by a TorchScript model.

    The model directory must contain ``final.zip`` (loaded with
    ``torch.jit.load``) and ``units.txt`` (passed to ``CharTokenizer``).
    A ``jieba_usr_dict`` file in the same directory is loaded into jieba
    the first time text is segmented.
    """

    # Id used to right-pad shorter sequences when batching.
    # NOTE(review): assumes id 0 is a harmless filler and that the model
    # ignores positions beyond the lengths it is given -- confirm against
    # units.txt / the exported model.
    PAD_ID = 0

    def __init__(self, model_dir: str) -> None:
        """Load the scripted model and tokenizer found in ``model_dir``."""
        self.model_dir = model_dir
        model_path = os.path.join(model_dir, 'final.zip')
        units_path = os.path.join(model_dir, 'units.txt')
        self.device = torch.device("cpu")
        # Load onto the device recorded in ``self.device`` so that the model
        # and the input tensors built in ``add_punc_batch`` always agree,
        # regardless of the device the archive was exported from.
        self.model = torch.jit.load(model_path, map_location=self.device)
        self.tokenizer = CharTokenizer(units_path)
        # True once jieba's logger is silenced and the user dict is loaded.
        self.use_jieba = False
        # Indexed by the argmax class of the model output for each word.
        self.punc_table = ['<unk>', '', ',', '。', '?', '、']

    def _ensure_jieba(self) -> None:
        """Silence jieba's logger and load the user dictionary once."""
        if self.use_jieba:
            return
        import logging
        # Disable jieba's logger
        logging.getLogger('jieba').disabled = True
        jieba.load_userdict(os.path.join(self.model_dir, 'jieba_usr_dict'))
        # Only mark as done after the dictionary loaded, so a failed load is
        # retried (and reported again) on the next call instead of being
        # silently skipped.
        self.use_jieba = True

    @staticmethod
    def _segment_run(run: List[str], run_is_ascii: bool) -> List[str]:
        """Turn one same-script run of tokens into words.

        ASCII tokens are kept as they are; non-ASCII tokens are joined
        without separators and cut by jieba (HMM disabled).
        """
        if run_is_ascii:
            return list(run)
        return list(jieba.cut(''.join(run), HMM=False))

    def split_words(self, text: str) -> List[str]:
        """Split whitespace-separated ``text`` into a list of words.

        Consecutive tokens are grouped by whether they are pure ASCII.
        ASCII groups contribute their tokens unchanged; every other group is
        concatenated and segmented with jieba.

        Args:
            text: Input text; tokens are separated by whitespace.

        Returns:
            The words in their original order (empty list for blank text).
        """
        self._ensure_jieba()
        words: List[str] = []
        run: List[str] = []
        run_is_ascii = None
        for token in text.split():
            token_is_ascii = token.isascii()
            # A change of script closes the current run.
            if run and token_is_ascii != run_is_ascii:
                words.extend(self._segment_run(run, run_is_ascii))
                run = []
            run.append(token)
            run_is_ascii = token_is_ascii
        if run:
            words.extend(self._segment_run(run, run_is_ascii))
        return words

    def add_punc_batch(self, texts: List[str]) -> List[str]:
        """Add punctuation to every text in ``texts``.

        Texts may have different lengths: id sequences are right-padded with
        ``PAD_ID`` up to the longest one, and the true lengths are passed to
        the model alongside the padded ids. Texts that contain no words
        yield ``''`` without being sent to the model.

        Args:
            texts: Input texts, tokens separated by whitespace.

        Returns:
            One punctuated string per input, in the same order.
            NOTE: a word accepted by ``_isAllAlpha`` is prefixed with a
            space, so a result whose first word is such a word starts with a
            space (unchanged from previous behavior).
        """
        batch_words = [self.split_words(text) for text in texts]
        results = [''] * len(texts)
        # Positions of the inputs that actually have something to punctuate.
        active = [i for i, words in enumerate(batch_words) if words]
        if not active:
            return results

        batch_ids = [
            list(self.tokenizer.tokens2ids(batch_words[i])) for i in active
        ]
        lens = [len(ids) for ids in batch_ids]
        max_len = max(lens)
        # torch.tensor needs a rectangular nested list, so pad ragged rows.
        padded = [
            ids + [self.PAD_ID] * (max_len - len(ids)) for ids in batch_ids
        ]
        texts_tensor = torch.tensor(padded,
                                    device=self.device,
                                    dtype=torch.int64)
        texts_lens_tensor = torch.tensor(lens,
                                         device=self.device,
                                         dtype=torch.int64)
        # Pure inference: no autograd bookkeeping needed.
        with torch.no_grad():
            log_probs, _ = self.model(texts_tensor, texts_lens_tensor)
        outs = log_probs.argmax(-1).cpu().numpy()

        for row, i in enumerate(active):
            # Drop predictions made for padded positions.
            punc_id = outs[row][:lens[row]]
            pieces = []
            for j, word in enumerate(batch_words[i]):
                if _isAllAlpha(word):
                    # '▁' marks a word boundary; it becomes a space below.
                    word = '▁' + word
                pieces.append(word + self.punc_table[punc_id[j]])
            results[i] = ''.join(pieces).replace('▁', ' ')
        return results

    def __call__(self, text: str) -> str:
        """Return ``text`` with punctuation added (``''`` for empty input)."""
        if text == '':
            return ''
        return self.add_punc_batch([text])[0]
def load_model(model_dir: str = None,
               gpu: int = -1,
               device: str = "cpu") -> PuncModel:
    """Build a ``PuncModel`` and move it to the requested device.

    Args:
        model_dir: Directory holding the model files. When ``None`` the
            directory is obtained from ``Hub.get_model_by_lang('punc')``.
        gpu: Legacy switch. Any value other than ``-1`` forces the device to
            ``"cuda"``; the index itself is not used to pick a device here.
        device: Device string used when ``gpu`` is ``-1``.

    Returns:
        The loaded ``PuncModel`` with its ``device`` attribute and scripted
        model both set to the selected device.
    """
    resolved_dir = (Hub.get_model_by_lang('punc')
                    if model_dir is None else model_dir)
    # remain the original usage of gpu
    target = device if gpu == -1 else "cuda"
    punc_model = PuncModel(resolved_dir)
    punc_model.device = torch.device(target)
    punc_model.model.to(target)
    return punc_model
|