|
|
import torch |
|
|
import pandas as pd |
|
|
|
|
|
|
|
|
def get_prompt_length(tokenizer, prompt):
    """Return the number of tokens the prompt occupies once encoded.

    NOTE(review): `tokenizer.encode` is called with its defaults, so the
    returned count may include the tokenizer's special tokens — confirm
    this is intended by callers budgeting sequence length.

    Args:
        tokenizer: a pre-trained tokenizer from Hugging Face Transformers.
        prompt (str): the prompt text to measure.

    Returns:
        int: the length of the encoded prompt.
    """
    encoded = tokenizer.encode(prompt)
    return len(encoded)
|
|
|
|
|
|
|
|
def tokenize_multipart_input(
    tokenizer,
    input_text_list: list,
    max_seq_len: int,
    template=None,
    prompt=None,
):
    """Tokenize one or more text segments, optionally following a prompt template.

    This function is an adaptation of the `tokenize_multipart_input` found in
    princeton-nlp's repository at
    https://github.com/princeton-nlp/LM-BFF/blob/main/src/dataset.py.

    Modifications include:
    - Extension of automatic prompt generation for multi-label classification.
    - Removal of parameters like `first_sent_limit`, `other_sent_limit`,
      `gpt3`, `truncate_head`, and `support_labels`.
    - Optimization of the code flow.

    Args:
        tokenizer: a pre-trained tokenizer from Hugging Face Transformers
        input_text_list (list): documents ready for tokenization.
        max_seq_len (int): max sequence length after adding the prompt along
            with special tokens from BERT.
        template (str, optional): placeholder for the prompt; parts are
            separated by "*" (e.g. "*cls**sent_0**[PROMPT]**mask**sep+*").
            Required when `prompt` is given.
        prompt (str, optional): the prompt we use for input text, or the
            literal string "auto" to recover the prompt from the template.

    Returns:
        tuple: (input_ids, attention_mask, token_type_ids, mask_pos), where
        `mask_pos` is a one-element list with the mask-token index when a
        prompt is used, otherwise None.
    """

    def enc(text):
        # Raw sub-token ids only; special tokens are inserted manually below.
        return tokenizer.encode(text, add_special_tokens=False)

    input_ids = []
    attention_mask = []
    token_type_ids = []
    mask_pos = None

    if prompt:
        special_token_mapping = {
            "cls": tokenizer.cls_token_id,
            "mask": tokenizer.mask_token_id,
            "sep": tokenizer.sep_token_id,
            "sep+": tokenizer.sep_token_id,  # "sep+" also advances the segment id
        }

        if prompt != "auto":
            template = template.replace("[PROMPT]", prompt)
        template_list = template.split("*")
        if prompt == "auto":
            # Recover the prompt text from the template itself: it is either
            # the non-empty, underscore-prefixed part right after "cls", or
            # (when that slot is empty) the part two positions further on.
            cls_pos = template_list.index("cls")
            if template_list[cls_pos + 1] == "":
                prompt = template_list[cls_pos + 3]
            elif template_list[cls_pos + 1].startswith("_"):
                # The original also tested `!= ""`, which startswith("_")
                # already implies ("" never starts with "_").
                prompt = template_list[cls_pos + 1]
            if prompt.startswith("_"):
                prompt = prompt[1:]
        segment_id = 0

        for part in template_list:
            new_tokens = []
            segment_plus_1_flag = False
            if part in special_token_mapping:
                new_tokens.append(special_token_mapping[part])
                if part == "sep+":
                    segment_plus_1_flag = True
            elif part[:5] == "sent_" or part[:6] == "+sent_":
                sent_id = int(part.split("_")[1])
                # Budget: 3 special tokens plus the (encoded) prompt length.
                max_len = max_seq_len - 3 - get_prompt_length(tokenizer, prompt)
                # Keep the tail of the document (truncate from the head).
                tokens = enc(input_text_list[sent_id])[-max_len:]
                new_tokens += tokens
            else:
                # Literal template text: underscores encode spaces.
                part = part.replace("_", " ")
                # A single character is looked up directly as a token.
                if len(part) == 1:
                    new_tokens.append(tokenizer.convert_tokens_to_ids(part))
                else:
                    new_tokens += enc(part)

            input_ids += new_tokens
            attention_mask += [1] * len(new_tokens)
            token_type_ids += [segment_id] * len(new_tokens)

            if segment_plus_1_flag:
                segment_id += 1

        mask_pos = [input_ids.index(tokenizer.mask_token_id)]
        # The mask token must survive truncation.
        assert mask_pos[0] < max_seq_len

    else:
        input_ids = [tokenizer.cls_token_id]
        attention_mask = [1]
        token_type_ids = [0]
        max_len = max_seq_len - 2  # room for [CLS] and one [SEP]

        for sent_id, input_text in enumerate(input_text_list):
            if input_text is None:
                # Skip missing segments entirely (no [SEP] appended).
                continue
            # NaN (e.g. from a pandas cell) becomes an empty segment; the
            # original also re-tested `input_text is None` here, which is
            # unreachable after the `continue` above.
            if pd.isna(input_text):
                input_text = ""
            input_tokens = enc(input_text)[:max_len] + [tokenizer.sep_token_id]
            input_ids += input_tokens
            attention_mask += [1] * len(input_tokens)
            token_type_ids += [sent_id] * len(input_tokens)

    return input_ids, attention_mask, token_type_ids, mask_pos
|
|
|
|
|
|
|
|
class InferenceDataset(torch.utils.data.Dataset):
    """A PyTorch dataset wrapping a single document for inference.

    Written for the CGMH task: it holds exactly one input text and
    tokenizes it (optionally through a prompt template) on access.

    Attributes:
        doc (str): the raw input document to tokenize.
        tokenizer: a pre-trained HuggingFace tokenizer.
        max_seq_len (int): maximum length for a sequence.
        template (str, optional): template for the model. Defaults to None.
        prompt (str, optional): prompt for the model. Defaults to None.
    """

    def __init__(
        self,
        input_text: str,
        tokenizer,
        max_seq_len: int,
        template=None,
        prompt=None,
    ):
        self.doc = input_text
        self.template = template
        self.prompt = prompt
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len

    def __getitem__(self, idx):
        """Tokenize the wrapped document and return model-ready features.

        `idx` is ignored: the dataset always holds exactly one document.
        """
        input_ids, attn_mask, segs, mask_pos = tokenize_multipart_input(
            tokenizer=self.tokenizer,
            input_text_list=[self.doc],
            template=self.template,
            prompt=self.prompt,
            max_seq_len=self.max_seq_len,
        )
        item = {
            "input_ids": input_ids,
            "token_type_ids": segs,
            "attention_mask": attn_mask,
        }
        if self.prompt:
            # Prompt-based models additionally need the mask-token position.
            item["mask_pos"] = mask_pos
        return item

    def __len__(self):
        # Always exactly one document.
        return 1
|
|
|