File size: 3,666 Bytes
3c50954
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os
from typing import List

import jieba
import torch
from wenet.cli.hub import Hub
from wenet.paraformer.search import _isAllAlpha
from wenet.text.char_tokenizer import CharTokenizer


class PuncModel:
    """Add punctuation to unpunctuated text with a TorchScript model.

    The model and its token table are loaded from ``model_dir``. Instances
    are callable: ``punc(text)`` returns ``text`` with punctuation inserted.
    """

    def __init__(self, model_dir: str) -> None:
        """Load the scripted model and tokenizer units from ``model_dir``.

        Args:
            model_dir: Directory containing ``final.zip`` (the TorchScript
                model) and ``units.txt`` (tokenizer units). The file
                ``jieba_usr_dict`` in the same directory is read lazily on
                the first call to :meth:`split_words`.
        """
        self.model_dir = model_dir
        model_path = os.path.join(model_dir, 'final.zip')
        units_path = os.path.join(model_dir, 'units.txt')

        self.device = torch.device("cpu")
        # Map the weights onto ``self.device`` while loading so the model and
        # the input tensors built in ``add_punc_batch`` start on the same
        # device, even if the archive was saved from another one.
        self.model = torch.jit.load(model_path, map_location=self.device)
        self.tokenizer = CharTokenizer(units_path)
        # Despite its name, this flag records whether jieba has already been
        # initialised (user dictionary loaded) by ``split_words``.
        self.use_jieba = False

        # Indexed by the model's predicted class id; the entry is the text
        # appended after the corresponding word.
        self.punc_table = ['<unk>', '', ',', '。', '?', '、']

    def _init_jieba(self) -> None:
        """Silence jieba's logger and load the model's user dictionary."""
        # Set the flag first, as the original code did, so initialisation is
        # attempted only once.
        self.use_jieba = True
        import logging

        # Disable jieba's logger
        logging.getLogger('jieba').disabled = True
        jieba.load_userdict(os.path.join(self.model_dir, 'jieba_usr_dict'))

    @staticmethod
    def _segment_run(run: List[str], is_ascii: bool) -> List[str]:
        """Turn one run of same-script tokens into words.

        ASCII tokens are already words. Non-ASCII tokens are concatenated
        and segmented by jieba with HMM disabled.
        """
        if is_ascii:
            return list(run)
        return list(jieba.cut(''.join(run), HMM=False))

    def split_words(self, text: str) -> List[str]:
        """Split whitespace-separated ``text`` into model input words.

        Consecutive tokens are grouped by whether they are pure ASCII.
        ASCII groups are kept token by token; non-ASCII groups are joined
        and re-segmented with jieba.

        Args:
            text: Input text whose tokens are separated by whitespace.

        Returns:
            The list of words, in input order. Empty when ``text`` has no
            tokens.
        """
        if not self.use_jieba:
            self._init_jieba()

        words: List[str] = []
        run: List[str] = []
        run_is_ascii = None
        for token in text.split():
            token_is_ascii = token.isascii()
            # A script change closes the current run.
            if run and token_is_ascii != run_is_ascii:
                words.extend(self._segment_run(run, run_is_ascii))
                run = []
            run.append(token)
            run_is_ascii = token_is_ascii

        if run:
            words.extend(self._segment_run(run, run_is_ascii))
        return words

    def add_punc_batch(self, texts: List[str]) -> List[str]:
        """Punctuate a batch of texts.

        Args:
            texts: Texts to punctuate. They may tokenize to different
                lengths; shorter ones are padded before the model call.

        Returns:
            One punctuated string per input, in the same order. Inputs
            that contain no words (empty or whitespace-only) yield ``''``
            without being sent to the model.
        """
        results = [''] * len(texts)
        kept = []  # positions in ``texts`` that produced at least one word
        batch_words = []
        batch_ids = []
        for pos, text in enumerate(texts):
            words = self.split_words(text)
            if not words:
                continue
            kept.append(pos)
            batch_words.append(words)
            batch_ids.append(list(self.tokenizer.tokens2ids(words)))

        if not kept:
            return results

        lens = [len(ids) for ids in batch_ids]
        longest = max(lens)
        # torch.tensor() rejects ragged nested lists, so right-pad every row
        # to the longest one. Predictions for padded positions are discarded
        # below via ``lens``.
        # NOTE(review): pad id 0 assumes the model ignores positions beyond
        # the lengths it is given -- confirm against the exported model.
        padded = [ids + [0] * (longest - len(ids)) for ids in batch_ids]

        texts_tensor = torch.tensor(padded,
                                    device=self.device,
                                    dtype=torch.int64)
        texts_lens_tensor = torch.tensor(lens,
                                         device=self.device,
                                         dtype=torch.int64)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            log_probs, _ = self.model(texts_tensor, texts_lens_tensor)
        outs = log_probs.argmax(-1).cpu().numpy()

        for pos, words, length, out in zip(kept, batch_words, lens, outs):
            punc_ids = out[:length]
            pieces = []
            for j, word in enumerate(words):
                # Words flagged by _isAllAlpha (presumably alphabetic, i.e.
                # English, words) get a '▁' marker that becomes a space
                # below, separating them from the preceding word.
                prefix = '▁' if _isAllAlpha(word) else ''
                pieces.append(prefix + word + self.punc_table[punc_ids[j]])
            results[pos] = ''.join(pieces).replace('▁', ' ')
        return results

    def __call__(self, text: str) -> str:
        """Punctuate a single text; an empty string is returned unchanged."""
        if text == '':
            return ''
        return self.add_punc_batch([text])[0]


def load_model(model_dir: str = None,
               gpu: int = -1,
               device: str = "cpu") -> PuncModel:
    """Build a ``PuncModel`` and move it to the requested device.

    Args:
        model_dir: Directory passed to ``PuncModel``. When ``None``, the
            directory is obtained from ``Hub.get_model_by_lang('punc')``.
        gpu: Legacy switch. Any value other than ``-1`` selects ``"cuda"``
            and overrides ``device``; the value itself is not used as a
            device index here.
        device: Torch device string used when ``gpu`` is ``-1``.

    Returns:
        The model wrapper, with its ``device`` attribute set and its
        underlying model moved to that device.
    """
    resolved_dir = (Hub.get_model_by_lang('punc')
                    if model_dir is None else model_dir)
    # Keep the original ``gpu`` usage: it only toggles CUDA on.
    target = "cuda" if gpu != -1 else device

    punc = PuncModel(resolved_dir)
    punc.device = torch.device(target)
    punc.model.to(target)
    return punc