# NOTE: removed paste/scrape artifacts ("Spaces:", "Runtime error" x2) that
# preceded the actual Python source.
| import torch | |
| import pandas as pd | |
def get_prompt_length(tokenizer, prompt):
    """Return the number of token ids *prompt* occupies under *tokenizer*.

    Note: ``tokenizer.encode`` is called with its defaults, so for most
    tokenizers the count includes special tokens (e.g. [CLS]/[SEP]) —
    presumably intentional to reserve extra headroom; verify against callers.
    """
    encoded = tokenizer.encode(prompt)
    return len(encoded)
def tokenize_multipart_input(
    tokenizer,
    input_text_list: list,
    max_seq_len: int,
    template=None,
    prompt=None,
):
    """Tokenize one or more text segments, optionally weaving in a prompt template.

    This function is an adaptation of the `tokenize_multipart_input` found in
    princeton-nlp's repository at
    https://github.com/princeton-nlp/LM-BFF/blob/main/src/dataset.py.
    Modifications include:
    - Extension of automatic prompt generation for multi-label classification.
    - Removal of parameters like `first_sent_limit`, `other_sent_limit`,
      `gpt3`, `truncate_head`, and `support_labels`.
    - Optimization of the code flow.

    Args:
        tokenizer: a pre-trained tokenizer from Hugging Face Transformers
        input_text_list (list): documents ready for tokenization.
        max_seq_len (int): max sequence length after adding the prompt along
            with special tokens from BERT.
        template (str, optional): placeholder for the prompt, a "*"-separated
            sequence of parts such as ``*cls**sent_0*_Liver*mask*.*sep+*``.
        prompt (str, optional): the prompt we use for input text; the string
            ``"auto"`` asks this function to recover the prompt from the
            template itself.

    Returns:
        tuple: ``(input_ids, attention_mask, token_type_ids, mask_pos)``.
        ``mask_pos`` is ``None`` when no prompt is used, otherwise a
        single-element list with the index of the mask token.
    """

    def enc(text):
        # Encode without special tokens; [CLS]/[SEP]/[MASK] are inserted
        # explicitly by the template logic below.
        return tokenizer.encode(text, add_special_tokens=False)

    input_ids = []
    attention_mask = []
    token_type_ids = []  # Only for BERT
    mask_pos = None  # Position of the mask token

    if prompt:
        special_token_mapping = {
            "cls": tokenizer.cls_token_id,
            "mask": tokenizer.mask_token_id,
            "sep": tokenizer.sep_token_id,
            "sep+": tokenizer.sep_token_id,
        }
        # Get variable list in the template
        if prompt != "auto":
            template = template.replace("[PROMPT]", prompt)
        template_list = template.split("*")
        if prompt == "auto":
            # Recover the natural-language prompt from the template itself.
            cls_pos = template_list.index("cls")
            if template_list[cls_pos + 1] == "":
                # For these kinds of cases: *cls**sent_0*_Liver*mask*.*sep+*
                # Prompt is next to sent_0.
                prompt = template_list[cls_pos + 3]
            elif template_list[cls_pos + 1].startswith("_"):
                # For these kinds of cases: *cls*_Liver*mask*.*+sent_0**sep+*
                # Prompt is next to cls.
                prompt = template_list[cls_pos + 1]
            # NOTE(review): if neither shape matches, `prompt` stays "auto"
            # and its token length is used for truncation below — confirm
            # all templates match one of the two shapes.
        if prompt.startswith("_"):
            # Template parts encode a leading space as "_"; strip it.
            prompt = prompt[1:]

        segment_id = 0
        for part in template_list:
            new_tokens = []
            segment_plus_1_flag = False
            if part in special_token_mapping:
                new_tokens.append(special_token_mapping[part])
                if part == "sep+":
                    # Everything after a "sep+" belongs to the next segment.
                    segment_plus_1_flag = True
            elif part[:5] == "sent_" or part[:6] == "+sent_":
                sent_id = int(part.split("_")[1])
                # Reserve room for 3 special tokens plus the prompt tokens.
                max_len = max_seq_len - 3 - get_prompt_length(tokenizer, prompt)
                # Tokenize and keep the *tail* of the document when too long.
                tokens = enc(input_text_list[sent_id])[-max_len:]
                new_tokens += tokens
            else:
                # Just natural language prompt; "_" encodes a space.
                part = part.replace("_", " ")
                if len(part) == 1:
                    # Handle special case when the T5 tokenizer might add an
                    # extra space while encoding a single character.
                    new_tokens.append(tokenizer.convert_tokens_to_ids(part))
                else:
                    new_tokens += enc(part)
            input_ids += new_tokens
            attention_mask += [1] * len(new_tokens)
            token_type_ids += [segment_id] * len(new_tokens)
            if segment_plus_1_flag:
                segment_id += 1

        mask_pos = [input_ids.index(tokenizer.mask_token_id)]
        # Make sure that the masked position is inside the max_length
        assert mask_pos[0] < max_seq_len
    else:
        input_ids = [tokenizer.cls_token_id]
        attention_mask = [1]
        token_type_ids = [0]
        max_len = max_seq_len - 2  # room for [CLS] and one [SEP]
        for sent_id, input_text in enumerate(input_text_list):
            if input_text is None:
                # Do not have text_b
                continue
            # `None` is already handled above, so only the NaN check remains
            # (e.g. an empty cell coming out of a pandas DataFrame).
            if pd.isna(input_text):
                # Empty input
                input_text = ""
            input_tokens = enc(input_text)[:max_len] + [tokenizer.sep_token_id]
            input_ids += input_tokens
            attention_mask += [1] * len(input_tokens)
            token_type_ids += [sent_id] * len(input_tokens)

    return input_ids, attention_mask, token_type_ids, mask_pos
class InferenceDataset(torch.utils.data.Dataset):
    """PyTorch dataset wrapping a single document for inference.

    Each instance holds exactly one input text (``__len__`` is always 1) and
    tokenizes it lazily in ``__getitem__`` via ``tokenize_multipart_input``.

    Attributes:
        doc (str): the raw input document.
        tokenizer: a pre-trained HuggingFace tokenizer.
        max_seq_len (int): maximum length for a sequence.
        template (optional): prompt template passed through to tokenization.
            Defaults to None.
        prompt (optional): prompt passed through to tokenization.
            Defaults to None.
    """

    def __init__(
        self,
        input_text: str,
        tokenizer,
        max_seq_len: int,
        template=None,
        prompt=None,
    ):
        self.doc = input_text
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.template = template
        self.prompt = prompt

    def __getitem__(self, idx):
        ids, mask, segments, mask_position = tokenize_multipart_input(
            tokenizer=self.tokenizer,
            input_text_list=[self.doc],
            template=self.template,
            prompt=self.prompt,
            max_seq_len=self.max_seq_len,
        )
        item = {
            "input_ids": ids,
            "token_type_ids": segments,
            "attention_mask": mask,
        }
        # The mask position only exists (and is only needed) in prompt mode.
        if self.prompt:
            item["mask_pos"] = mask_position
        return item

    def __len__(self):
        # Exactly one document is wrapped per dataset instance.
        return 1