File size: 6,115 Bytes
e77bcc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import torch
import pandas as pd


def get_prompt_length(tokenizer, prompt):
    """Return the number of token ids the tokenizer produces for *prompt*."""
    token_ids = tokenizer.encode(prompt)
    return len(token_ids)


def tokenize_multipart_input(
    tokenizer,
    input_text_list: list,
    max_seq_len: int,
    template=None,
    prompt=None,
):
    """Tokenize one or more text segments, optionally injecting a prompt template.

    This function is an adaptation of the `tokenize_multipart_input` found in
    princeton-nlp's repository at
    https://github.com/princeton-nlp/LM-BFF/blob/main/src/dataset.py.

    Modifications include:
    - Extension of automatic prompt generation for multi-label classification.
    - Removal of parameters like `first_sent_limit`, `other_sent_limit`,
      `gpt3`, `truncate_head`, and `support_labels`.
    - Optimization of the code flow.

    Args:
        tokenizer: a pre-trained tokenizer from Hugging Face Transformers.
        input_text_list (list): documents ready for tokenization.
        max_seq_len (int): max sequence length after adding the prompt along
            with special tokens from BERT.
        template (str, optional): placeholder for the prompt; parts are
            separated by ``*`` (e.g. ``*cls**sent_0*_Liver*mask*.*sep+*``).
        prompt (str, optional): the prompt we use for the input text, or
            ``"auto"`` to recover it from the template itself.

    Returns:
        tuple: ``(input_ids, attention_mask, token_type_ids, mask_pos)``.
        ``mask_pos`` is ``None`` when no prompt is used.
    """

    def enc(text):
        # Encode without special tokens; the template supplies them explicitly.
        return tokenizer.encode(text, add_special_tokens=False)

    input_ids = []
    attention_mask = []
    token_type_ids = []  # Only for BERT
    mask_pos = None  # Position of the mask token

    if prompt:
        special_token_mapping = {
            "cls": tokenizer.cls_token_id,
            "mask": tokenizer.mask_token_id,
            "sep": tokenizer.sep_token_id,
            "sep+": tokenizer.sep_token_id,  # "sep+" additionally bumps the segment id
        }
        # Substitute the literal prompt into the template, unless it must be
        # discovered automatically from the template itself.
        if prompt != "auto":
            template = template.replace("[PROMPT]", prompt)
        template_list = template.split("*")
        if prompt == "auto":
            # Recover the prompt text from its position relative to *cls*.
            cls_pos = template_list.index("cls")
            if template_list[cls_pos + 1] == "":
                # For these kinds of cases: *cls**sent_0*_Liver*mask*.*sep+*
                # Prompt is next to sent_0.
                prompt = template_list[cls_pos + 3]
            elif template_list[cls_pos + 1].startswith("_"):
                # For these kinds of cases: *cls*_Liver*mask*.*+sent_0**sep+*
                # Prompt is next to cls.
                # (The original `!= ""` guard was redundant: "" never starts with "_".)
                prompt = template_list[cls_pos + 1]
            if prompt.startswith("_"):
                # A leading "_" is the template's stand-in for a space.
                prompt = prompt[1:]
        segment_id = 0

        for part in template_list:
            new_tokens = []
            segment_plus_1_flag = False
            if part in special_token_mapping:
                new_tokens.append(special_token_mapping[part])
                if part == "sep+":
                    segment_plus_1_flag = True
            elif part.startswith(("sent_", "+sent_")):
                sent_id = int(part.split("_")[1])
                # Reserve room for 3 special tokens plus the prompt itself.
                max_len = max_seq_len - 3 - get_prompt_length(tokenizer, prompt)
                # Tokenize and keep the *tail* of the document within budget.
                new_tokens += enc(input_text_list[sent_id])[-max_len:]
            else:
                # Just natural language prompt; "_" encodes a space.
                part = part.replace("_", " ")
                # handle special case when T5 tokenizer might add an extra space
                if len(part) == 1:
                    new_tokens.append(tokenizer.convert_tokens_to_ids(part))
                else:
                    new_tokens += enc(part)

            input_ids += new_tokens
            attention_mask += [1] * len(new_tokens)
            token_type_ids += [segment_id] * len(new_tokens)

            if segment_plus_1_flag:
                segment_id += 1

        mask_pos = [input_ids.index(tokenizer.mask_token_id)]
        # Make sure that the masked position is inside the max_length
        assert mask_pos[0] < max_seq_len

    else:
        # Plain (non-prompt) encoding: [CLS] sent_0 [SEP] sent_1 [SEP] ...
        input_ids = [tokenizer.cls_token_id]
        attention_mask = [1]
        token_type_ids = [0]
        max_len = max_seq_len - 2  # room for [CLS] and the trailing [SEP]

        for sent_id, input_text in enumerate(input_text_list):
            if input_text is None:
                # Do not have text_b
                continue
            if pd.isna(input_text):
                # Empty input (e.g. NaN from a DataFrame) becomes the empty
                # string. (A second `is None` test here was unreachable: None
                # is already skipped above.)
                input_text = ""
            input_tokens = enc(input_text)[:max_len] + [tokenizer.sep_token_id]
            input_ids += input_tokens
            attention_mask += [1] * len(input_tokens)
            token_type_ids += [sent_id] * len(input_tokens)

    return input_ids, attention_mask, token_type_ids, mask_pos


class InferenceDataset(torch.utils.data.Dataset):
    """A single-document PyTorch dataset for CGMH inference.

    Wraps exactly one input document and tokenizes it lazily in
    ``__getitem__``, so it can be consumed through a standard DataLoader.
    ---
    Attributes
        doc (str): the raw input document.
        tokenizer: a pre-trained HuggingFace tokenizer.
        max_seq_len (int): maximum length for a tokenized sequence.
        template (str, optional): template for the model. Defaults to None.
        prompt (str, optional): prompt for the model. Defaults to None.
    """

    def __init__(
        self,
        input_text: str,
        tokenizer,
        max_seq_len: int,
        template=None,
        prompt=None,
    ):
        self.doc = input_text
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.template = template
        self.prompt = prompt

    def __getitem__(self, idx):
        # Only one document is held, so `idx` is effectively ignored.
        ids, mask, segments, mask_position = tokenize_multipart_input(
            tokenizer=self.tokenizer,
            input_text_list=[self.doc],
            max_seq_len=self.max_seq_len,
            template=self.template,
            prompt=self.prompt,
        )
        encoded = {
            "input_ids": ids,
            "token_type_ids": segments,
            "attention_mask": mask,
        }
        if self.prompt:
            # The mask position only exists for prompt-based inference.
            encoded["mask_pos"] = mask_position
        return encoded

    def __len__(self):
        # Exactly one wrapped document.
        return 1