MPTR / prompt_dataset.py
mcps5601
Add application files
e77bcc6
import torch
import pandas as pd
def get_prompt_length(tokenizer, prompt):
return len(tokenizer.encode(prompt))
def tokenize_multipart_input(
tokenizer,
input_text_list: list,
max_seq_len: int,
template=None,
prompt=None,
):
"""This function is an adaptation of the `tokenize_multipart_input` found in princeton-nlp's repository
at https://github.com/princeton-nlp/LM-BFF/blob/main/src/dataset.py.
Modifications include:
- Extension of automatic prompt generation for multi-label classification.
- Removal of parameters like `first_sent_limit`, `other_sent_limit`, `gpt3`, `truncate_head`, and `support_labels`.
- Optimization of the code flow.
Args:
tokenizer: a pre-trained tokenizer from Hugging Face Transformers
input_text_list (list): documents ready for tokenization.
max_seq_len (int): max sequence length after adding the prompt along with special tokens from BERT.
template (str, optional): placeholder for the prompt.
prompt (str, optional): the prompt we use for input text.
"""
def enc(text):
return tokenizer.encode(text, add_special_tokens=False)
input_ids = []
attention_mask = []
token_type_ids = [] # Only for BERT
mask_pos = None # Position of the mask token
if prompt:
special_token_mapping = {
"cls": tokenizer.cls_token_id,
"mask": tokenizer.mask_token_id,
"sep": tokenizer.sep_token_id,
"sep+": tokenizer.sep_token_id,
}
# Get variable list in the template
if prompt != "auto":
template = template.replace("[PROMPT]", prompt)
template_list = template.split("*")
if prompt == "auto":
# find cls place
cls_pos = template_list.index("cls")
if template_list[cls_pos + 1] == "":
# For these kinds of cases: *cls**sent_0*_Liver*mask*.*sep+*
# Prompt is next to sent_0.
prompt = template_list[cls_pos + 3]
elif template_list[cls_pos + 1] != "" and (
template_list[cls_pos + 1].startswith("_")
):
# For these kinds of cases: *cls*_Liver*mask*.*+sent_0**sep+*
# Prompt is next to cls.
prompt = template_list[cls_pos + 1]
if prompt.startswith("_"):
prompt = prompt[1:]
segment_id = 0
for part in template_list:
new_tokens = []
segment_plus_1_flag = False
if part in special_token_mapping:
new_tokens.append(special_token_mapping[part])
if part == "sep+":
segment_plus_1_flag = True
elif part[:5] == "sent_" or part[:6] == "+sent_":
sent_id = int(part.split("_")[1])
max_len = max_seq_len - 3 - get_prompt_length(tokenizer, prompt)
# Tokenize and truncate to max_seq_len
tokens = enc(input_text_list[sent_id])[-max_len:]
new_tokens += tokens
else:
# Just natural language prompt
part = part.replace("_", " ")
# handle special case when T5 tokenizer might add an extra space
if len(part) == 1:
new_tokens.append(tokenizer.convert_tokens_to_ids(part))
else:
new_tokens += enc(part)
input_ids += new_tokens
attention_mask += [1 for i in range(len(new_tokens))]
token_type_ids += [segment_id for i in range(len(new_tokens))]
if segment_plus_1_flag:
segment_id += 1
mask_pos = [input_ids.index(tokenizer.mask_token_id)]
# Make sure that the masked position is inside the max_length
assert mask_pos[0] < max_seq_len
else:
input_ids = [tokenizer.cls_token_id]
attention_mask = [1]
token_type_ids = [0]
max_len = max_seq_len - 2
for sent_id, input_text in enumerate(input_text_list):
if input_text is None:
# Do not have text_b
continue
if pd.isna(input_text) or input_text is None:
# Empty input
input_text = ""
input_tokens = enc(input_text)[:max_len] + [tokenizer.sep_token_id]
input_ids += input_tokens
attention_mask += [1 for i in range(len(input_tokens))]
token_type_ids += [sent_id for i in range(len(input_tokens))]
return input_ids, attention_mask, token_type_ids, mask_pos
class InferenceDataset(torch.utils.data.Dataset):
"""
A class for creating the CGMH dataset in PyTorch.
Currently, this class supports:
(1) Few-shot data (e.g., train_size=16)
(2) Small-size data (e.g., train_size>100)
---
Attributes
data (pd.DataFrame): the CGMH dataset
tokenizer: a pre-trained HuggingFace tokenizer
max_seq_len (int): maximum length for a sequence
template (_type_, optional): template for the model. Defaults to None.
prompt (_type_, optional): prompt for the model. Defaults to None.
"""
def __init__(
self,
input_text: str,
tokenizer,
max_seq_len: int,
template=None,
prompt=None,
):
self.doc = input_text
self.template = template
self.prompt = prompt
self.tokenizer = tokenizer
self.max_seq_len = max_seq_len
def __getitem__(self, idx):
input_ids, attn_mask, segs, mask_pos = tokenize_multipart_input(
tokenizer=self.tokenizer,
input_text_list=[self.doc],
template=self.template,
prompt=self.prompt,
max_seq_len=self.max_seq_len,
)
item = {
"input_ids": input_ids,
"token_type_ids": segs,
"attention_mask": attn_mask,
}
if self.prompt:
item["mask_pos"] = mask_pos
return item
def __len__(self):
return 1