| import logging |
| import os |
| import pickle |
| import time |
| import json |
| from dataclasses import dataclass |
| from typing import Any, Dict, List, NewType, Tuple |
| from tqdm import tqdm |
|
|
|
|
| import torch |
| from torch.utils.data.dataset import Dataset |
| from transformers.tokenization_utils import PreTrainedTokenizer |
| from transformers.data.data_collator import DataCollator |
| from transformers.tokenization_bart import BartTokenizer |
| from transformers.tokenization_roberta import RobertaTokenizer |
| from relogic.pretrainkit.datasets.utils import pad_and_tensorize_sequence |
logger = logging.getLogger(__name__)


# Label vocabulary shared by the dataset and the collator in this module.
# Keys read here: "keyword" (list of output labels) and "label_bos_token",
# "label_eos_token", "label_padding_token" (each looked up in that list).
# NOTE(review): the path is relative, so it resolves against the current
# working directory at import time.
with open("data/preprocessed_data/bart_parser_label_mapping_2.json", encoding="utf-8") as _label_mapping_file:
    label_mapping = json.load(_label_mapping_file)
|
|
class QuerySchemaRelation2SQLDataset(Dataset):
    """
    Dataset for relation-aware text-to-SQL: query + schema + relation -> SQL

    ``file_path`` is read as UTF-8 JSON lines, one example per line. The
    question, column texts and table texts of each example are encoded as::

        [CLS] question sub-tokens [SEP] (column sub-tokens [SEP])* (table sub-tokens [SEP])*

    The ``sql`` string is split on whitespace. Each token becomes a label id:
    ``len(keywords) + i`` if it is the i-th entry of ``columns``, otherwise its
    index in ``label_mapping["keyword"]``. The sequence is wrapped in the BOS
    and EOS label ids.

    An example is skipped (counted as invalid) when its encoded input is
    longer than ``block_size``, or when an SQL token is neither a column nor
    a keyword.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int, local_rank=-1):
        """Read, encode and label every example in ``file_path``.

        Args:
            tokenizer: splits question/column/table text into sub-tokens;
                its ``cls_token`` and ``sep_token`` delimit the segments.
            file_path: JSON-lines file, one example per line.
            block_size: maximum number of input ids; longer examples are skipped.
            local_rank: not used by this class; kept for interface compatibility.
        """
        self.examples = []
        self.keywords = label_mapping["keyword"]
        self.label_eos_id = self.keywords.index(label_mapping["label_eos_token"])
        self.label_bos_id = self.keywords.index(label_mapping["label_bos_token"])
        # Only BART/RoBERTa tokenizers are called with add_prefix_space=True
        # (presumably so a standalone word tokenizes as it would mid-sentence).
        add_prefix_space = isinstance(tokenizer, (BartTokenizer, RobertaTokenizer))
        total, valid = 0, 0
        with open(file_path, encoding="utf-8") as f:
            for line in tqdm(f):
                total += 1
                encoded = self._encode_example(json.loads(line), tokenizer, block_size, add_prefix_space)
                if encoded is None:
                    continue
                self.examples.append(encoded)
                valid += 1
        print("Valid Example {}; Invalid Example {}".format(valid, total - valid))

    def _encode_example(self, example, tokenizer, block_size, add_prefix_space):
        """Turn one parsed JSON example into a dataset entry, or return None to skip it."""
        text = example["normalized_question"]
        columns = example["columns"]
        tables = example["tables"]
        columns_text = example["column_text"]
        tables_text = example["table_text"]
        sql = example["sql"]

        text_tokens, token_idx_to_sub_token_start_idx = self._tokenize_question(
            text, tokenizer, add_prefix_space)
        # Index of the first question sub-token (after CLS) and of the SEP closing the question.
        question_start, question_end = 1, len(text_tokens) - 1

        column_spans = self._append_segments(columns_text, text_tokens, tokenizer, add_prefix_space)
        column_start = [start for start, _ in column_spans]
        column_end = [end for _, end in column_spans]

        table_spans = self._append_segments(tables_text, text_tokens, tokenizer, add_prefix_space)
        table_start = [start for start, _ in table_spans]
        table_end = [end for _, end in table_spans]

        input_ids = tokenizer.convert_tokens_to_ids(text_tokens)
        if len(input_ids) > block_size:
            return None

        label_ids = self._encode_sql(sql, columns)
        if label_ids is None:
            return None

        primary_key = [int(x) for x in example["sc_struct"]["primary_key"]]
        # Each foreign-key entry is an "a,b" string; keys stay strings, values become ints.
        foreign_key = {x.split(",")[0]: int(x.split(",")[1]) for x in example["sc_struct"]["foreign_key"]}

        word_offsets = token_idx_to_sub_token_start_idx
        sc_link = {
            "q_col_match": self._remap_link_keys(example["sc_link"]["q_col_match"], word_offsets),
            "q_tab_match": self._remap_link_keys(example["sc_link"]["q_tab_match"], word_offsets),
        }
        cv_link = {
            "num_date_match": self._remap_link_keys(example["cv_link"]["num_date_match"], word_offsets),
            "cell_match": self._remap_link_keys(example["cv_link"]["cell_match"], word_offsets),
        }

        column_to_table = self._map_columns_to_tables(columns, tables)
        foreign_keys_tables = self._foreign_key_tables(foreign_key, column_to_table)

        return {
            "input_ids": input_ids,
            "example_info": {
                "normalized_question": text,
                "columns": columns,
                "tables": tables,
                "tokens": text_tokens,
                "question_start": question_start,
                "question_end": question_end,
                "column_start": torch.LongTensor(column_start),
                "column_end": torch.LongTensor(column_end),
                "table_start": torch.LongTensor(table_start),
                "table_end": torch.LongTensor(table_end),
                "sc_link": sc_link,
                "cv_link": cv_link,
                "primary_keys": primary_key,
                "foreign_keys": foreign_key,
                "column_to_table": column_to_table,
                "foreign_keys_tables": foreign_keys_tables
            },
            "column_spans": column_spans,
            "label_ids": label_ids}

    @staticmethod
    def _tokenize_question(text, tokenizer, add_prefix_space):
        """Return ([CLS] + question sub-tokens + [SEP], word index -> first sub-token offset).

        The offsets count from the first question sub-token, so the leading CLS
        token is not included. NOTE(review): confirm consumers of sc_link/cv_link
        add question_start if they need absolute positions.
        """
        text_tokens = [tokenizer.cls_token]
        word_offsets = {}
        offset = 0
        for idx, word in enumerate(text.split()):
            sub_tokens = tokenizer.tokenize(word, add_prefix_space=add_prefix_space)
            word_offsets[idx] = offset
            text_tokens.extend(sub_tokens)
            offset += len(sub_tokens)
        text_tokens.append(tokenizer.sep_token)
        return text_tokens, word_offsets

    @staticmethod
    def _append_segments(segments, text_tokens, tokenizer, add_prefix_space):
        """Append each segment's sub-tokens plus a SEP to text_tokens (in place); return their spans.

        Each segment is a list of words joined with spaces before tokenizing.
        A span is (start, end) over text_tokens, end-exclusive, and it
        excludes the trailing SEP.
        """
        spans = []
        start_idx = len(text_tokens)
        for words in segments:
            sub_tokens = tokenizer.tokenize(" ".join(words), add_prefix_space=add_prefix_space)
            text_tokens.extend(sub_tokens)
            text_tokens.append(tokenizer.sep_token)
            end_idx = start_idx + len(sub_tokens)
            spans.append((start_idx, end_idx))
            start_idx = end_idx + 1  # skip the SEP just appended
        return spans

    def _encode_sql(self, sql, columns):
        """Map whitespace-split SQL tokens to label ids wrapped in BOS/EOS; None if a token is unknown."""
        label_ids = [self.label_bos_id]
        for token in sql.split():
            if token in columns:
                # Column labels live after all keyword labels.
                label_ids.append(columns.index(token) + len(self.keywords))
                continue
            try:
                label_ids.append(self.keywords.index(token))
            except ValueError:
                # Token is neither a column nor a known keyword: skip the example.
                return None
        label_ids.append(self.label_eos_id)
        return label_ids

    @staticmethod
    def _remap_link_keys(links, word_offsets):
        """Rewrite "<word_idx>,<x>[,...]" keys to "<sub_token_offset>,<x>", keeping values.

        Only the first two comma-separated fields of each key are used.
        Raises KeyError if a word index has no entry in word_offsets.
        """
        remapped = {}
        for key, value in links.items():
            parts = key.split(",")
            remapped[str(word_offsets[int(parts[0])]) + "," + parts[1]] = value
        return remapped

    @staticmethod
    def _map_columns_to_tables(columns, tables):
        """Map str(column index) to the index in tables of the column's "<table>." prefix.

        "0" starts as None. It is overwritten if columns[0] is not "*";
        other "*" columns get no entry.
        """
        column_to_table = {"0": None}
        for idx, column in enumerate(columns):
            if column == "*":
                continue
            column_to_table[str(idx)] = tables.index(column.split(".")[0])
        return column_to_table

    @staticmethod
    def _foreign_key_tables(foreign_key, column_to_table):
        """Group foreign keys by table: str(source table index) -> unique target table indices (ints)."""
        foreign_keys_tables = {}
        for k, v in foreign_key.items():
            t_k = str(column_to_table[str(k)])
            t_v = str(column_to_table[str(v)])
            linked = foreign_keys_tables.setdefault(t_k, [])
            if int(t_v) not in linked:
                linked.append(int(t_v))
        return foreign_keys_tables

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        return self.examples[i]
|
|
|
|
@dataclass
class DataCollatorForQuerySchemaRelation2SQL:
    """
    Data collator used for query + schema -> sql modeling.

    Expects examples shaped like the entries of QuerySchemaRelation2SQLDataset
    (keys "input_ids", "column_spans", "label_ids", "example_info"). It pads
    the first three into batch tensors and passes "example_info" through
    unchanged as a list.
    """
    tokenizer: PreTrainedTokenizer
    # These have no type annotation, so they are plain class attributes rather
    # than dataclass fields (they do not appear in __init__). They are
    # resolved from the module-level label_mapping once, at import time.
    label_padding_id = label_mapping["keyword"].index(label_mapping["label_padding_token"])
    label_eos_id = label_mapping["keyword"].index(label_mapping["label_eos_token"])
    label_bos_id = label_mapping["keyword"].index(label_mapping["label_bos_token"])

    def collate_batch(self, examples) -> Dict[str, Any]:
        """Pad a list of dataset examples into one batch.

        Args:
            examples: dicts with "input_ids", "column_spans", "label_ids"
                and "example_info" entries.

        Returns:
            A dict holding:
              * "input_ids", "column_spans", "labels": the padded batches
                produced by pad_and_tensorize_sequence;
              * "example_info_list": the per-example "example_info" values;
              * "input_padding_id", "label_padding_id", "label_eos_id",
                "label_bos_id": the ids used for padding / sequence bounds.
        """
        input_ids_sequences = [example["input_ids"] for example in examples]
        column_spans_sequences = [example["column_spans"] for example in examples]
        label_ids_sequences = [example["label_ids"] for example in examples]

        padded_input_ids_tensor = pad_and_tensorize_sequence(
            input_ids_sequences, padding_value=self.tokenizer.pad_token_id)
        # Padded span positions are filled with (0, 1).
        padded_column_spans_tensor = pad_and_tensorize_sequence(
            column_spans_sequences, padding_value=(0, 1))

        example_info_list = [example["example_info"] for example in examples]
        label_ids_tensor = pad_and_tensorize_sequence(
            label_ids_sequences, padding_value=self.label_padding_id)
        return {
            "input_ids": padded_input_ids_tensor,
            "column_spans": padded_column_spans_tensor,
            "labels": label_ids_tensor,
            "example_info_list": example_info_list,
            "input_padding_id": self.tokenizer.pad_token_id,
            "label_padding_id": self.label_padding_id,
            "label_eos_id": self.label_eos_id,
            "label_bos_id": self.label_bos_id
        }
|
|
|
|