| from torch.utils.data.dataset import Dataset |
| from transformers.tokenization_utils import PreTrainedTokenizer |
| from tqdm import tqdm |
| import json |
| from dataclasses import dataclass |
| import torch |
| from relogic.pretrainkit.datasets.utils import pad_and_tensorize_sequence |
| import random |
|
|
class TaBARTDataset(Dataset):
  """
  Dataset for pretraining on generation-based or retrieval-based
  text-schema pair examples.

  Each line of `file_path` is a JSON object; the fields used are `question`,
  `table_info.header`, and `entities`. Upstream preprocessing guarantees that
  every entity in `entities` appears in `table_info.header`.
  """
  def __init__(self,
               tokenizer: PreTrainedTokenizer,
               file_path: str,
               col_token: str):
    """Read a JSON-lines file and tokenize each example.

    Args:
      tokenizer: tokenizer used for both the question and the column names.
      file_path: path to the JSON-lines data file.
      col_token: special token inserted before/between column names.
    """
    self.examples = []
    total = 0
    valid = 0
    with open(file_path, encoding="utf-8") as f:
      for line in tqdm(f):
        total += 1
        example = json.loads(line)
        text = example["question"]
        schema = example["table_info"]["header"]
        # Layout: [CLS] question tokens [COL] col1 tokens [COL] col2 tokens ... [SEP]
        tokens = [tokenizer.cls_token] + tokenizer.tokenize(text, add_prefix_space=True) + [col_token]
        column_spans = []
        start_idx = len(tokens)
        for column in schema:
          column_tokens = tokenizer.tokenize(column.lower(), add_prefix_space=True)
          tokens.extend(column_tokens)
          # Half-open span (start, end) of this column's tokens in `tokens`.
          column_spans.append((start_idx, start_idx + len(column_tokens)))
          tokens.append(col_token)
          start_idx += len(column_tokens) + 1

        # The last appended token is a trailing col_token; replace it with [SEP].
        tokens[-1] = tokenizer.sep_token
        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        # Skip overly long sequences before doing any label work.
        if len(input_ids) > 600:
          continue
        # First occurrence of each header name: O(1) lookups below instead of
        # repeated O(n) schema.index calls, with identical semantics when a
        # header name is duplicated.
        first_pos = {}
        for idx, column in enumerate(schema):
          first_pos.setdefault(column, idx)
        column_labels = [0] * len(schema)
        for entity in example["entities"]:
          # "limit" and "*" are SQL artifacts, not real columns.
          if entity not in ("limit", "*"):
            column_labels[first_pos[entity]] = 1
        self.examples.append({
          "input_ids": input_ids,
          "column_spans": column_spans,
          "column_labels": column_labels
        })
        valid += 1

    print(f"Total {total} and Valid {valid}")

  def __len__(self):
    """Number of retained (valid) examples."""
    return len(self.examples)

  def __getitem__(self, i):
    """Return the i-th preprocessed example dict."""
    return self.examples[i]
|
|
|
|
@dataclass
class DataCollatorForTaBART:
  """Collate TaBARTDataset examples into padded batches.

  Supports three modes via `task`: "mlm" (masked/denoising LM), "col_pred"
  (column relevance prediction), or "mlm+col_pred" (a 60/40 random mix of
  the two objectives, chosen per batch).
  """
  tokenizer: PreTrainedTokenizer
  task: str  # one of "mlm", "col_pred", "mlm+col_pred"
  mlm_probability: float = 0.35

  def __post_init__(self):
    # NOTE(review): these are taken from cls/sep, but _mlm_batch emits
    # tokenizer.bos/eos ids directly. They coincide for BART/RoBERTa
    # vocabularies but would diverge for tokenizers where cls != bos —
    # confirm which pair is intended.
    self.label_bos_id = self.tokenizer.cls_token_id
    self.label_eos_id = self.tokenizer.sep_token_id

  def _mlm_batch(self, padded_input_ids_tensor):
    """Build an MLM batch: corrupt a copy of the inputs, keep originals as labels."""
    # The per-position label tensor from mask_tokens is intentionally unused:
    # the full (uncorrupted) input ids serve as seq2seq reconstruction targets.
    inputs, _ = self.mask_tokens(padded_input_ids_tensor.clone())
    return {
      "task": "mlm",
      "input_ids": inputs,
      "labels": padded_input_ids_tensor,
      "pad_token_id": self.tokenizer.pad_token_id,
      "label_bos_id": self.tokenizer.bos_token_id,
      "label_eos_id": self.tokenizer.eos_token_id,
      "label_padding_id": self.tokenizer.pad_token_id}

  def _col_pred_batch(self, examples, padded_input_ids_tensor):
    """Build a column-prediction batch from per-example spans and labels."""
    column_labels_sequences = [example["column_labels"] for example in examples]
    # -100 padding so the loss ignores padded column positions.
    padded_label_ids_tensor = pad_and_tensorize_sequence(column_labels_sequences,
                                                         padding_value=-100)
    column_spans_sequences = [example["column_spans"] for example in examples]
    # (0, 1) is a harmless dummy span for padded column slots.
    padded_column_spans_tensor = pad_and_tensorize_sequence(column_spans_sequences,
                                                            padding_value=(0, 1))
    return {
      "task": "col_pred",
      "input_ids": padded_input_ids_tensor,
      "column_spans": padded_column_spans_tensor,
      "labels": padded_label_ids_tensor,
      "pad_token_id": self.tokenizer.pad_token_id}

  def collate_batch(self, examples):
    """Pad a list of example dicts and dispatch on `self.task`.

    Raises:
      ValueError: if `self.task` is not a recognized mode.
    """
    input_ids_sequences = [example["input_ids"] for example in examples]
    padded_input_ids_tensor = pad_and_tensorize_sequence(input_ids_sequences,
                                                         padding_value=self.tokenizer.pad_token_id)
    if self.task == "mlm":
      return self._mlm_batch(padded_input_ids_tensor)
    if self.task == "col_pred":
      return self._col_pred_batch(examples, padded_input_ids_tensor)
    if self.task == "mlm+col_pred":
      # Stochastic mix: 60% MLM batches, 40% column-prediction batches.
      if random.random() < 0.6:
        return self._mlm_batch(padded_input_ids_tensor)
      return self._col_pred_batch(examples, padded_input_ids_tensor)
    raise ValueError("Unknown task for DataCollatorForTaBART: {}".format(self.task))

  def mask_tokens(self, inputs):
    """
    Prepare masked tokens inputs/labels for masked language modeling:
    80% MASK, 10% random, 10% original.

    Returns:
      (inputs, labels): `inputs` with masked positions corrupted in place,
      and `labels` equal to the original ids at masked positions and -100
      elsewhere.
    """
    if self.tokenizer.mask_token is None:
      raise ValueError(
        "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
      )

    labels = inputs.clone()
    # Sample masked positions with probability mlm_probability, excluding
    # special tokens and padding.
    probability_matrix = torch.full(labels.shape, self.mlm_probability)
    special_tokens_mask = [
      self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
    ]
    probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
    if self.tokenizer.pad_token is not None:
      padding_mask = labels.eq(self.tokenizer.pad_token_id)
      probability_matrix.masked_fill_(padding_mask, value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    # Only compute loss on masked positions.
    labels[~masked_indices] = -100

    # 80% of masked positions -> [MASK]
    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
    inputs[indices_replaced] = self.tokenizer.mask_token_id

    # Half of the remaining 20% (i.e. 10% overall) -> random token;
    # the other 10% keep the original token.
    indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
    random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]

    return inputs, labels
|
|
|
|