File size: 4,047 Bytes
90f0b29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import re

import torch
from transformers import AutoTokenizer
from transformers.trainer_pt_utils import LabelSmoother

# Placeholder token that is prepended to the user prompt and registered with
# the tokenizer as an additional special token.
# NOTE(review): presumably marks where speech features are spliced into the
# LLM input -- confirm against the model code.
DEFAULT_SPEECH_TOKEN = "<speech>"
# Label value written over target positions that must not contribute to the
# loss (padding and prompt tokens). Taken from transformers' LabelSmoother so
# it stays in sync with the library's ignore index.
IGNORE_TOKEN_ID = LabelSmoother.ignore_index


class LlmTokenizerWrapper:
    """Stateless helpers for building the LLM tokenizer and tokenized prompts.

    All methods are classmethods; the class is only used as a namespace.
    """

    # Punctuation (ASCII and full-width forms) stripped from transcripts.
    # Backslash is intentionally not part of the set.
    _PUNCT_RE = re.compile(r"[,。?!,.!?《》()·“”、/]")
    # Any run of whitespace, collapsed to a single space.
    _SPACE_RE = re.compile(r"\s+")
    # One Chinese character (U+3400-U+4DBF or U+4E00-U+9FFF). The capturing
    # group makes ``split`` keep the matched characters in its result.
    _CJK_RE = re.compile(r"([\u3400-\u4dbf\u4e00-\u9fff])")

    # Chat template used for training: every message is closed with
    # '<|im_end|>' (followed by a newline except after the last message).
    _TRAIN_TEMPLATE = (
        "{% for message in messages %}"
        "{{'<|im_start|>' + message['role'] + '\n' + message['content']}}"
        "{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}"
        "{% endfor %}"
    )
    # Chat template used for decoding: the last (assistant) message is left
    # open, i.e. no '<|im_end|>' is appended after it.
    _DECODE_TEMPLATE = (
        "{% for message in messages %}"
        "{{'<|im_start|>' + message['role'] + '\n' + message['content']}}"
        "{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}"
        "{% endfor %}"
    )

    @classmethod
    def build_llm_tokenizer(cls, llm_path: str, use_flash_attn: bool = False):
        """Load the tokenizer at ``llm_path`` and prepare it for speech prompts.

        Args:
            llm_path: Name or path passed to ``AutoTokenizer.from_pretrained``.
            use_flash_attn: When true the tokenizer pads on the left,
                otherwise on the right.

        Returns:
            The tokenizer, with ``DEFAULT_SPEECH_TOKEN`` registered as an
            additional special token.
        """
        tokenizer = AutoTokenizer.from_pretrained(llm_path)
        tokenizer.padding_side = "left" if use_flash_attn else "right"
        tokenizer.add_special_tokens(
            {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
        )
        return tokenizer

    @classmethod
    def clean_text(cls, origin_text: str) -> str:
        """Normalize a transcript: drop punctuation, tidy spaces, lowercase.

        Steps, in order:
          1. Remove the punctuation characters listed in ``_PUNCT_RE``.
          2. Collapse every whitespace run into one space.
          3. Drop whitespace-only segments that sit next to a Chinese
             character (so ``"你 好"`` becomes ``"你好"``); spaces attached to
             non-Chinese text, e.g. between English words, are kept.
          4. Strip the ends and lowercase.

        Args:
            origin_text: Raw transcript text.

        Returns:
            The cleaned text.
        """
        text = cls._PUNCT_RE.sub("", origin_text)
        text = cls._SPACE_RE.sub(" ", text)

        # Splitting on a capturing group yields alternating non-Chinese
        # segments and single Chinese characters; only segments that are
        # empty or pure whitespace are discarded before re-joining.
        pieces = cls._CJK_RE.split(text.strip())
        text = "".join(piece for piece in pieces if piece.strip())
        return text.strip().lower()

    @classmethod
    def preprocess_texts(cls, origin_texts, tokenizer, max_len: int, decode: bool = False):
        """Turn transcripts into padded chat-formatted token tensors.

        Each transcript is cleaned with ``clean_text`` and wrapped in a
        two-message conversation: a fixed user prompt (starting with
        ``DEFAULT_SPEECH_TOKEN``) and an assistant reply holding the cleaned
        transcript. With ``decode=True`` the assistant reply is empty and left
        open so the model can continue it.

        Args:
            origin_texts: Iterable of raw transcript strings (non-empty).
            tokenizer: Tokenizer providing ``apply_chat_template``,
                ``pad_token_id``, ``padding_side`` and
                ``convert_tokens_to_ids`` (e.g. the one returned by
                ``build_llm_tokenizer``).
            max_len: ``max_length`` used to truncate each tokenized
                conversation.
            decode: Build inference prompts instead of training samples.

        Returns:
            Tuple ``(input_ids, attention_mask, target_ids, clean_texts)``:
            ``input_ids`` and ``target_ids`` are long tensors of shape
            ``(len(origin_texts), longest_sequence)``; ``attention_mask`` is a
            bool tensor that is False on padding positions; ``target_ids`` has
            padding and prompt positions set to ``IGNORE_TOKEN_ID``;
            ``clean_texts`` is the list of cleaned transcripts.

        Raises:
            ValueError: If ``origin_texts`` is empty.
        """
        clean_texts = [cls.clean_text(origin_text) for origin_text in origin_texts]
        if not clean_texts:
            raise ValueError("origin_texts must contain at least one text")

        template = cls._DECODE_TEMPLATE if decode else cls._TRAIN_TEMPLATE
        token_lists = []
        for text in clean_texts:
            message = [
                {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
                {"role": "assistant", "content": "" if decode else text},
            ]
            token_lists.append(
                tokenizer.apply_chat_template(
                    message,
                    tokenize=True,
                    chat_template=template,
                    add_generation_prompt=False,
                    padding="longest",
                    max_length=max_len,
                    truncation=True,
                )
            )

        # Conversations are tokenized one at a time, so the batch is padded
        # manually to the longest sequence, honoring the tokenizer's side.
        longest = max(len(ids) for ids in token_lists)
        pad_id = tokenizer.pad_token_id
        pad_on_right = tokenizer.padding_side == "right"
        padded = []
        for ids in token_lists:
            filler = [pad_id] * (longest - len(ids))
            padded.append(ids + filler if pad_on_right else filler + ids)
        input_ids = torch.tensor(padded, dtype=torch.long)

        target_ids = input_ids.clone()
        target_ids[target_ids == pad_id] = IGNORE_TOKEN_ID

        # Mask the prompt: everything up to and including the assistant role
        # token and the single token after it (the newline in the template).
        # Only the FIRST "assistant" token of each row is the role marker; a
        # transcript may contain the same token again, and masking up to a
        # later occurrence would wipe out part of the real target.
        # torch.where returns indices in row-major order, so the first hit
        # seen for a row is its leftmost one.
        assistant_id = tokenizer.convert_tokens_to_ids("assistant")
        rows, cols = torch.where(input_ids == assistant_id)
        handled_rows = set()
        for row, col in zip(rows.tolist(), cols.tolist()):
            if row in handled_rows:
                continue
            handled_rows.add(row)
            target_ids[row, : col + 2] = IGNORE_TOKEN_ID

        attention_mask = input_ids.ne(pad_id)
        return input_ids, attention_mask, target_ids, clean_texts