| import logging |
| import os |
| import pickle |
| import time |
| import json |
| from dataclasses import dataclass |
| from typing import Any, Dict, List, NewType, Tuple |
| from tqdm import tqdm |
|
|
|
|
| import torch |
| from torch.utils.data.dataset import Dataset |
| from transformers.tokenization_utils import PreTrainedTokenizer |
| from transformers.data.data_collator import DataCollator |
|
|
logger = logging.getLogger(__name__)


def _load_label_mapping(path):
    """Load and return the JSON label mapping stored at *path*.

    The file is opened in a ``with`` block so the handle is closed right after
    parsing, and decoded as UTF-8 regardless of the platform locale.
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)


# Label vocabulary used by the dataset and collator in this module. The keys
# read in this file are "keyword", "label_bos_token", "label_eos_token" and
# "label_padding_token".
label_mapping = _load_label_mapping("data/preprocessed_data/bart_parser_pretrain_label_mapping.json")
|
|
def pad_and_tensorize_sequence(sequences, padding_value, dtype=torch.long):
    """Right-pad variable-length sequences to a common length and stack them.

    Args:
        sequences: iterable of sequences (lists or tuples). Elements may be
            scalars or fixed-size tuples (e.g. ``(start, end)`` spans), as long
            as ``torch.tensor`` can build a rectangular tensor from the result.
        padding_value: value appended to every shorter sequence; it must have
            the same structure as the elements (a scalar for ids, a pair for
            spans).
        dtype: dtype of the returned tensor; defaults to ``torch.long``.

    Returns:
        A tensor of shape ``(len(sequences), max_len, *element_shape)``.
        An empty ``sequences`` yields an empty tensor instead of raising.
    """
    # Normalise to lists so tuple inputs can be concatenated with the padding.
    sequences = [list(sequence) for sequence in sequences]
    max_size = max((len(sequence) for sequence in sequences), default=0)
    padded_sequences = [
        sequence + [padding_value] * (max_size - len(sequence))
        for sequence in sequences
    ]
    return torch.tensor(padded_sequences, dtype=dtype)
|
|
class QuerySchema2SQLDataset(Dataset):
    """Dataset for the pretraining task: query + schema -> SQL.

    There is no masking for query and schema.

    Each line of ``file_path`` is parsed as one JSON object. The keys read are
    ``question``, ``columns``, ``tables``, ``extra``, ``negative``,
    ``processed_sql`` and ``sql_id``. Every kept example is stored as a dict
    with keys ``idx``, ``input_ids``, ``column_spans`` and ``label_ids``:

    * ``input_ids`` encodes ``[cls] question [sep] col_1 [sep] col_2 [sep] ...``;
      each schema item is lower-cased and has ``.``/``_`` replaced by spaces
      before tokenization.
    * ``column_spans`` holds one half-open ``(start, end)`` token range per
      schema item, pointing at that item's tokens inside ``input_ids``.
    * ``label_ids`` is ``[bos] l_1 ... l_k [eos]``. A lower-cased SQL token
      equal to a schema item maps to ``len(keywords) + <its position>``; any
      other token maps to its position in ``label_mapping["keyword"]``.

    An example is dropped (and counted as invalid) in three cases:

    * its input has more than ``max_input_length`` ids;
    * its SQL contains a token that is neither a schema item nor a keyword;
    * its SQL encodes to more than ``max_label_length`` labels, counted before
      bos/eos are added.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int, local_rank=-1,
                 max_input_length: int = 600, max_label_length: int = 300,
                 max_negative_columns: int = 15):
        """
        Args:
            tokenizer: tokenizer exposing ``cls_token``, ``sep_token``,
                ``tokenize`` and ``convert_tokens_to_ids``.
            file_path: path to a JSON-lines file of examples.
            block_size: accepted for interface compatibility; not used here.
            local_rank: accepted for interface compatibility; not used here.
            max_input_length: examples with more input ids are skipped.
            max_label_length: examples whose SQL encodes to more labels
                (excluding bos/eos) are skipped.
            max_negative_columns: number of ``negative`` schema items appended
                after columns, tables and extras.

        Raises:
            FileNotFoundError: if ``file_path`` is not an existing file.
        """
        if not os.path.isfile(file_path):
            raise FileNotFoundError("Dataset file not found: %s" % file_path)
        logger.info("Creating features from dataset file at %s", file_path)

        self.examples = []
        self.keywords = label_mapping["keyword"]
        self.label_eos_id = self.keywords.index(label_mapping["label_eos_token"])
        self.label_bos_id = self.keywords.index(label_mapping["label_bos_token"])
        # keyword -> first index (same result as list.index, but O(1) per lookup).
        self._keyword_to_id = {}
        for keyword_id, keyword in enumerate(self.keywords):
            self._keyword_to_id.setdefault(keyword, keyword_id)

        total, valid = 0, 0
        with open(file_path, encoding="utf-8") as f:
            for line in tqdm(f):
                total += 1
                example = json.loads(line)
                text = example["question"]
                columns = [
                    column.lower()
                    for column in (example["columns"] + example["tables"] + example["extra"]
                                   + example["negative"][:max_negative_columns])
                ]
                sql = example["processed_sql"]

                input_ids, column_spans = self._encode_input(tokenizer, text, columns)
                if len(input_ids) > max_input_length:
                    continue

                label_ids = self._encode_sql(sql, columns)
                if label_ids is None or len(label_ids) > max_label_length:
                    continue
                label_ids = [self.label_bos_id] + label_ids + [self.label_eos_id]

                self.examples.append({
                    "idx": example["sql_id"],
                    "input_ids": input_ids,
                    "column_spans": column_spans,
                    "label_ids": label_ids})
                valid += 1
        print("Valid Example {}; Invalid Example {}".format(valid, total-valid))

    @staticmethod
    def _encode_input(tokenizer, question, columns):
        """Tokenize ``[cls] question [sep]`` followed by ``column [sep]`` per column.

        Returns ``(input_ids, column_spans)``. Each span is the half-open token
        range of one column; the column's trailing ``[sep]`` is excluded.
        """
        text_tokens = [tokenizer.cls_token] + tokenizer.tokenize(question) + [tokenizer.sep_token]
        column_spans = []
        for column in columns:
            # "table.col_name" -> "table col name" so the tokenizer sees words.
            column_text = column.replace(".", " ").replace("_", " ").lower()
            column_tokens = tokenizer.tokenize(column_text)
            start_idx = len(text_tokens)
            text_tokens.extend(column_tokens)
            column_spans.append((start_idx, len(text_tokens)))
            text_tokens.append(tokenizer.sep_token)
        return tokenizer.convert_tokens_to_ids(text_tokens), column_spans

    def _encode_sql(self, sql, columns):
        """Map each whitespace-separated, lower-cased SQL token to a label id.

        Schema items take precedence over keywords; duplicate schema items
        resolve to their first position. Returns ``None`` when a token is
        neither a schema item nor a keyword, so the caller can skip the example.
        """
        column_offset = len(self.keywords)
        column_to_id = {}
        for position, column in enumerate(columns):
            column_to_id.setdefault(column, column_offset + position)

        label_ids = []
        for token in sql.split():
            token = token.lower()
            label_id = column_to_id.get(token)
            if label_id is None:
                label_id = self._keyword_to_id.get(token)
            if label_id is None:
                return None
            label_ids.append(label_id)
        return label_ids

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        return self.examples[i]
|
|
@dataclass
class DataCollatorForQuerySchema2SQL:
    """
    Data collator used for query + schema -> sql modeling.

    Padding used for each field:

    * ``input_ids`` use the tokenizer's pad id.
    * ``column_spans`` use the dummy span ``(0, 1)``.
    * ``label_ids`` use the label padding id.

    All three are returned as ``torch.long`` tensors, together with the
    special ids listed in ``collate_batch``.
    """
    tokenizer: PreTrainedTokenizer
    # Special label ids, looked up once when the class is defined. They carry no
    # annotation on purpose, so @dataclass keeps them as class attributes rather
    # than constructor fields.
    label_padding_id = label_mapping["keyword"].index(label_mapping["label_padding_token"])
    label_eos_id = label_mapping["keyword"].index(label_mapping["label_eos_token"])
    label_bos_id = label_mapping["keyword"].index(label_mapping["label_bos_token"])
    # NOTE(review): opened (and truncated) when this module is imported, shared by
    # every instance and never explicitly closed. Consider opening it lazily.
    logging_file = open("index_logging.txt", "w", encoding="utf-8")

    def collate_batch(self, examples) -> Dict[str, Any]:
        """Pad and tensorize a list of dataset examples into one batch.

        Each example's ``idx`` is appended to ``index_logging.txt`` first.

        Args:
            examples: dicts with ``idx``, ``input_ids``, ``column_spans`` and
                ``label_ids`` entries.

        Returns:
            A dict with three tensors and four ints. The tensors are
            ``input_ids``, ``column_spans`` and ``labels``. The ints are
            ``input_padding_id``, ``label_padding_id``, ``label_eos_id`` and
            ``label_bos_id``.
        """
        # Record which examples were batched. Flush so the log is not lost in
        # the write buffer if training crashes or is killed.
        self.logging_file.writelines(str(example["idx"]) + "\n" for example in examples)
        self.logging_file.flush()

        input_ids_sequences = [example["input_ids"] for example in examples]
        column_spans_sequences = [example["column_spans"] for example in examples]
        label_ids_sequences = [example["label_ids"] for example in examples]
        padded_input_ids_tensor = pad_and_tensorize_sequence(
            input_ids_sequences, padding_value=self.tokenizer.pad_token_id)
        padded_column_spans_tensor = pad_and_tensorize_sequence(
            column_spans_sequences, padding_value=(0, 1))
        label_ids_tensor = pad_and_tensorize_sequence(
            label_ids_sequences, padding_value=self.label_padding_id)
        return {
            "input_ids": padded_input_ids_tensor,
            "column_spans": padded_column_spans_tensor,
            "labels": label_ids_tensor,
            "input_padding_id": self.tokenizer.pad_token_id,
            "label_padding_id": self.label_padding_id,
            "label_eos_id": self.label_eos_id,
            "label_bos_id": self.label_bos_id
        }

    # Allow callers that invoke the collator as a function (collator(examples)).
    __call__ = collate_batch
|
|
|
|