Uploaded Utils, Pycache and Python Files
Browse files- __pycache__/schema_filter.cpython-38.pyc +0 -0
- eval_mode.py +182 -0
- schema_filter.py +339 -0
- training_mode.py +194 -0
- utils/__pycache__/classifier_model.cpython-38.pyc +0 -0
- utils/classifier_model.py +186 -0
__pycache__/schema_filter.cpython-38.pyc
ADDED
|
Binary file (11 kB). View file
|
|
|
eval_mode.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from schema_filter import filter_func, SchemaItemClassifierInference
|
| 2 |
+
|
| 3 |
+
# In eval mode the gold SQL does not need to be provided.
def _schema_item(table_name, column_names):
    """Build one schema item whose table comment and column comments are all empty."""
    return {
        "table_name": table_name,
        "table_comment": "",
        "column_names": list(column_names),
        "column_comments": [""] * len(column_names),
    }


# Example sample: a natural-language question plus the full database schema.
data = {
    "text": "Name movie titles released in year 1945. Sort the listing by the descending order of movie popularity.",
    "sql": "",
    "schema": {
        "schema_items": [
            _schema_item("lists", [
                "user_id",
                "list_id",
                "list_title",
                "list_movie_number",
                "list_update_timestamp_utc",
                "list_creation_timestamp_utc",
                "list_followers",
                "list_url",
                "list_comments",
                "list_description",
                "list_cover_image_url",
                "list_first_image_url",
                "list_second_image_url",
                "list_third_image_url",
            ]),
            _schema_item("movies", [
                "movie_id",
                "movie_title",
                "movie_release_year",
                "movie_url",
                "movie_title_language",
                "movie_popularity",
                "movie_image_url",
                "director_id",
                "director_name",
                "director_url",
            ]),
            _schema_item("ratings_users", [
                "user_id",
                "rating_date_utc",
                "user_trialist",
                "user_subscriber",
                "user_avatar_image_url",
                "user_cover_image_url",
                "user_eligible_for_trial",
                "user_has_payment_method",
            ]),
            _schema_item("lists_users", [
                "user_id",
                "list_id",
                "list_update_date_utc",
                "list_creation_date_utc",
                "user_trialist",
                "user_subscriber",
                "user_avatar_image_url",
                "user_cover_image_url",
                "user_eligible_for_trial",
                "user_has_payment_method",
            ]),
            _schema_item("ratings", [
                "movie_id",
                "rating_id",
                "rating_url",
                "rating_score",
                "rating_timestamp_utc",
                "critic",
                "critic_likes",
                "critic_comments",
                "user_id",
                "user_trialist",
                "user_subscriber",
                "user_eligible_for_trial",
                "user_has_payment_method",
            ]),
        ]
    }
}
|
| 162 |
+
|
| 163 |
+
dataset = [data]

# Keep at most 7 tables of the database.
num_top_k_tables = 7
# For every retained table keep at most 10 columns, so the prompt contains at
# most 7 * 10 = 70 columns.
num_top_k_columns = 10

# Load the schema item classifier from the "sic_merged" checkpoint directory.
sic = SchemaItemClassifierInference("sic_merged")

# For evaluation data the trained classifier scores every table and column
# against the user question; filter_func keeps the top-scoring ones.
dataset = filter_func(
    dataset=dataset,
    dataset_type="eval",
    sic=sic,
    num_top_k_tables=num_top_k_tables,
    num_top_k_columns=num_top_k_columns,
)

print(dataset)
|
schema_filter.py
ADDED
|
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import random
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
from transformers import AutoTokenizer
|
| 7 |
+
from utils.classifier_model import SchemaItemClassifier
|
| 8 |
+
from transformers.trainer_utils import set_seed
|
| 9 |
+
|
| 10 |
+
def prepare_inputs_and_labels(sample, tokenizer):
    """Tokenize a question together with its schema and locate each schema item.

    The input is laid out as a list of "words":
    ``question | table : col , col | table : col ...`` and tokenized with
    ``is_split_into_words=True``, padded/truncated to 512 tokens.

    Args:
        sample: dict with ``"text"`` and ``"schema"["schema_items"]``; every
            schema item provides ``"table_name"`` and ``"column_names"``.
        tokenizer: callable tokenizer whose result supports ``word_ids`` and
            item access for ``"input_ids"`` / ``"attention_mask"``
            (presumably a Hugging Face fast tokenizer -- confirm).

    Returns:
        A 5-tuple ``(input_ids, attention_mask, column_name_token_indices,
        table_name_token_indices, column_num_in_each_table)``. The two index
        lists hold, per column / table, the list of token positions that
        belong to that name (empty when the name was truncated away). The
        tensors are moved to CUDA when it is available.
    """
    schema_items = sample["schema"]["schema_items"]
    column_num_in_each_table = [len(item["column_names"]) for item in schema_items]

    # Positions (in `words`) of every table name and every column name.
    table_word_positions, column_word_positions = [], []

    words = [sample["text"]]
    for item in schema_items:
        words.append("|")
        words.append(item["table_name"])
        table_word_positions.append(len(words) - 1)
        words.append(":")

        for column in item["column_names"]:
            words.append(column)
            column_word_positions.append(len(words) - 1)
            words.append(",")

        # Remove the last "," of this table.
        # NOTE(review): assumes the separator is stripped once per table.
        words = words[:-1]

    encoding = tokenizer(
        words,
        return_tensors="pt",
        is_split_into_words=True,
        padding="max_length",
        max_length=512,
        truncation=True,
    )

    # One name may be split into several sub-word tokens; collect, for every
    # word index, the positions of all tokens that came from it.
    token_positions_by_word = {}
    for token_position, word_index in enumerate(encoding.word_ids(batch_index=0)):
        token_positions_by_word.setdefault(word_index, []).append(token_position)

    column_name_token_indices = [
        token_positions_by_word.get(word_position, [])
        for word_position in column_word_positions
    ]
    table_name_token_indices = [
        token_positions_by_word.get(word_position, [])
        for word_position in table_word_positions
    ]

    encoder_input_ids = encoding["input_ids"]
    encoder_input_attention_mask = encoding["attention_mask"]

    if torch.cuda.is_available():
        encoder_input_ids = encoder_input_ids.cuda()
        encoder_input_attention_mask = encoder_input_attention_mask.cuda()

    return (
        encoder_input_ids,
        encoder_input_attention_mask,
        column_name_token_indices,
        table_name_token_indices,
        column_num_in_each_table,
    )
|
| 66 |
+
|
| 67 |
+
def get_schema(tables_and_columns):
    """Group ``(table, column)`` pairs into a ``{"schema_items": [...]}`` dict.

    Args:
        tables_and_columns: iterable of two-element ``[table_name, column_name]``
            pairs.

    Returns:
        ``{"schema_items": [{"table_name": ..., "column_names": [...]}, ...]}``
        with tables in order of first appearance and columns in input order.
    """
    # Single pass; dicts preserve insertion order, so table order is kept
    # without rescanning the pair list once per table.
    columns_by_table = {}
    for table_name, column_name in tables_and_columns:
        columns_by_table.setdefault(table_name, []).append(column_name)

    return {
        "schema_items": [
            {"table_name": table_name, "column_names": column_names}
            for table_name, column_names in columns_by_table.items()
        ]
    }
|
| 79 |
+
|
| 80 |
+
def get_sequence_length(text, tables_and_columns, tokenizer):
    """Return the token count of ``text`` followed by the given schema items.

    Args:
        text: the question text.
        tables_and_columns: iterable of ``[table_name, column_name]`` pairs.
        tokenizer: callable invoked as ``tokenizer(words, is_split_into_words=True)``
            whose result exposes ``"input_ids"``.

    Returns:
        ``len(tokenizer(...)["input_ids"])`` for the word sequence
        ``text | table : col , col | table : col ...``.
    """
    # Deduplicate table names while preserving order, grouping the columns in
    # the same pass.
    columns_by_table = {}
    for table_name, column_name in tables_and_columns:
        columns_by_table.setdefault(table_name, []).append(column_name)

    input_words = [text]
    for table_name, column_names in columns_by_table.items():
        input_words += ["|", table_name, ":"]
        for column_name in column_names:
            input_words += [column_name, ","]
        # Remove the last "," of this table.
        # NOTE(review): assumes the separator is stripped once per table.
        input_words = input_words[:-1]

    tokenized_inputs = tokenizer(input_words, is_split_into_words = True)

    return len(tokenized_inputs["input_ids"])
|
| 103 |
+
|
| 104 |
+
# handle extremely long schema sequences
|
| 105 |
+
def split_sample(sample, tokenizer):
    """Split a sample with a very long schema into several shorter samples.

    Columns are added greedily, in schema order, to the current chunk while
    the tokenized length reported by ``get_sequence_length`` stays below 500;
    otherwise the current chunk is emitted and a new one is started. Names
    with a non-empty comment are rendered as ``"name ( comment ) "``.

    Args:
        sample: dict with ``"text"`` and ``"schema"["schema_items"]``; every
            item has ``"table_name"``, ``"table_comment"``, ``"column_names"``
            and ``"column_comments"``.
        tokenizer: tokenizer forwarded to ``get_sequence_length``.

    Returns:
        A list of ``{"text": ..., "schema": {"schema_items": [...]}}`` dicts.
        Chunks are never empty: a schema without any column yields ``[]``.
    """
    text = sample["text"]

    def with_comment(name, comment):
        # Append the comment in parentheses only when there is one.
        return name + " ( " + comment + " ) " if comment != "" else name

    splitted_samples = []
    recorded_tables_and_columns = []

    def flush():
        # Emitting an empty chunk would produce a schema-less model input,
        # so only non-empty chunks are recorded.
        if recorded_tables_and_columns:
            splitted_samples.append(
                {
                    "text": text,
                    "schema": get_schema(recorded_tables_and_columns)
                }
            )

    for table in sample["schema"]["schema_items"]:
        table_name = with_comment(table["table_name"], table["table_comment"])
        for column_name, column_comment in zip(table["column_names"], table["column_comments"]):
            pair = [table_name, with_comment(column_name, column_comment)]
            if get_sequence_length(text, recorded_tables_and_columns + [pair], tokenizer) < 500:
                recorded_tables_and_columns.append(pair)
            else:
                flush()
                recorded_tables_and_columns = [pair]

    flush()

    return splitted_samples
|
| 141 |
+
|
| 142 |
+
def merge_pred_results(sample, pred_results):
    """Merge per-chunk predictions back into one entry per table of ``sample``.

    Args:
        sample: the unsplit sample; every schema item has ``"table_name"``,
            ``"table_comment"``, ``"column_names"`` and ``"column_comments"``.
        pred_results: list of dicts mapping a (comment-decorated) table name
            to ``{"table_prob": ..., "column_probs": [...]}``.

    Returns:
        A list, in schema order, of dicts with ``"table_name"``,
        ``"table_prob"`` (largest probability seen for the table, ``0`` when
        it appears in no chunk), ``"column_names"`` and ``"column_probs"``
        (column probabilities concatenated in chunk order).
    """
    def with_comment(name, comment):
        # Names carry their comment as "name ( comment ) " when it is non-empty.
        return name if comment == "" else name + " ( " + comment + " ) "

    merged_results = []
    for item in sample["schema"]["schema_items"]:
        display_name = with_comment(item["table_name"], item["table_comment"])
        display_columns = [
            with_comment(column_name, column_comment)
            for column_name, column_comment in zip(item["column_names"], item["column_comments"])
        ]

        best_table_prob = 0
        collected_column_probs = []
        for chunk_result in pred_results:
            if display_name not in chunk_result:
                continue
            entry = chunk_result[display_name]
            best_table_prob = max(best_table_prob, entry["table_prob"])
            collected_column_probs.extend(entry["column_probs"])

        merged_results.append(
            {
                "table_name": display_name,
                "table_prob": best_table_prob,
                "column_names": display_columns,
                "column_probs": collected_column_probs,
            }
        )

    return merged_results
|
| 174 |
+
|
| 175 |
+
def _top_k_indices(probs, k):
    """Return the indices of the ``k`` largest ``probs``; ties keep input order."""
    return np.argsort(-np.array(probs), kind="stable")[:k].tolist()


def _sample_labeled_indices(labels, k):
    """Return indices labeled 1, padded with random 0-labeled ones up to ``k``, shuffled."""
    indices = [idx for idx, label in enumerate(labels) if label == 1]
    if len(indices) < k:
        unused_indices = [idx for idx, label in enumerate(labels) if label == 0]
        indices += random.sample(unused_indices, min(len(unused_indices), k - len(indices)))
    random.shuffle(indices)
    return indices


def filter_func(dataset, dataset_type, sic, num_top_k_tables = 5, num_top_k_columns = 5):
    """Shrink the schema of every sample to a few tables and columns, in place.

    Args:
        dataset: list of sample dicts; each has ``"schema"["schema_items"]``
            whose items provide ``"table_name"``, ``"table_comment"``,
            ``"column_names"`` and ``"column_comments"``. Train samples also
            carry ``"table_labels"`` and ``"column_labels"``.
        dataset_type: ``"eval"`` keeps the tables/columns ranked highest by
            ``sic.predict``; ``"train"`` keeps the items labeled 1 and pads
            them with randomly sampled items labeled 0, then shuffles.
        sic: object with a ``predict(sample)`` method; only used in
            ``"eval"`` mode.
        num_top_k_tables: number of tables to keep in eval mode / to pad up
            to in train mode.
        num_top_k_columns: same, per retained table, for columns.

    Returns:
        The same ``dataset`` list; each sample's ``"schema"`` is replaced by
        the filtered schema and, in train mode, the label keys are deleted.

    Raises:
        ValueError: if ``dataset_type`` is neither ``"eval"`` nor ``"train"``
            (raised when the first sample is processed).
    """
    for data in tqdm(dataset, desc = "filtering schema items for the dataset"):
        schema_items = data["schema"]["schema_items"]

        if dataset_type == "eval":
            # Predict scores for every table and column, then rank the tables.
            pred_results = sic.predict(data)
            table_indices = _top_k_indices(
                [pred_result["table_prob"] for pred_result in pred_results],
                num_top_k_tables,
            )
        elif dataset_type == "train":
            table_indices = _sample_labeled_indices(data["table_labels"], num_top_k_tables)
        else:
            # Previously this fell through and failed later with a NameError.
            raise ValueError(
                "dataset_type must be 'eval' or 'train', got {!r}".format(dataset_type)
            )

        filtered_items = []
        for table_idx in table_indices:
            if dataset_type == "eval":
                column_indices = _top_k_indices(
                    pred_results[table_idx]["column_probs"], num_top_k_columns
                )
            else:
                column_indices = _sample_labeled_indices(
                    data["column_labels"][table_idx], num_top_k_columns
                )

            table = schema_items[table_idx]
            filtered_items.append(
                {
                    "table_name": table["table_name"],
                    "table_comment": table["table_comment"],
                    "column_names": [table["column_names"][column_idx] for column_idx in column_indices],
                    "column_comments": [table["column_comments"][column_idx] for column_idx in column_indices]
                }
            )

        # Replace the old schema with the filtered schema.
        data["schema"] = {"schema_items": filtered_items}

        if dataset_type == "train":
            del data["table_labels"]
            del data["column_labels"]

    return dataset
|
| 226 |
+
|
| 227 |
+
def lista_contains_listb(lista, listb):
    """Return 1 when every element of ``listb`` occurs in ``lista``, else 0."""
    return 1 if all(element in lista for element in listb) else 0
|
| 233 |
+
|
| 234 |
+
class SchemaItemClassifierInference():
    """Inference wrapper around ``SchemaItemClassifier``.

    Loads the tokenizer and the fine-tuned classifier from a checkpoint
    directory and scores every table and column of a sample against its
    question.
    """

    def __init__(self, model_save_path):
        """Load tokenizer and model weights from ``model_save_path``.

        Args:
            model_save_path: directory holding the tokenizer/config files and
                the ``dense_classifier.pt`` state dict.
        """
        set_seed(42)
        # load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_save_path, add_prefix_space = True)
        # initialize model
        self.model = SchemaItemClassifier(model_save_path, "test")
        # load fine-tuned params
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        self.model.load_state_dict(torch.load(model_save_path + "/dense_classifier.pt", map_location=torch.device('cpu')), strict=False)
        if torch.cuda.is_available():
            self.model = self.model.cuda()
        self.model.eval()

    def predict_one(self, sample):
        """Score the tables and columns of one (already short enough) sample.

        Args:
            sample: dict with ``"text"`` and ``"schema"["schema_items"]``.

        Returns:
            Dict keyed by table name; each value holds ``"table_name"``,
            ``"table_prob"``, ``"column_names"`` and ``"column_probs"``.
            Probabilities are column 1 of a softmax over the two logits
            (presumably the "relevant" class -- confirm against training).
        """
        encoder_input_ids, encoder_input_attention_mask, column_name_token_indices,\
            table_name_token_indices, column_num_in_each_table = prepare_inputs_and_labels(sample, self.tokenizer)

        with torch.no_grad():
            model_outputs = self.model(
                encoder_input_ids,
                encoder_input_attention_mask,
                [column_name_token_indices],
                [table_name_token_indices],
                [column_num_in_each_table]
            )

        table_logits = model_outputs["batch_table_name_cls_logits"][0]
        table_pred_probs = torch.nn.functional.softmax(table_logits, dim = 1)[:, 1].cpu().tolist()

        column_logits = model_outputs["batch_column_info_cls_logits"][0]
        flat_column_pred_probs = torch.nn.functional.softmax(column_logits, dim = 1)[:, 1].cpu().tolist()

        # Split the flat column probabilities into one list per table using a
        # running offset instead of re-summing the prefix for every table.
        column_pred_probs = []
        offset = 0
        for column_num in column_num_in_each_table:
            column_pred_probs.append(flat_column_pred_probs[offset: offset + column_num])
            offset += column_num

        result_dict = dict()
        for table_idx, table in enumerate(sample["schema"]["schema_items"]):
            result_dict[table["table_name"]] = {
                "table_name": table["table_name"],
                "table_prob": table_pred_probs[table_idx],
                "column_names": table["column_names"],
                "column_probs": column_pred_probs[table_idx],
            }

        return result_dict

    def predict(self, test_sample):
        """Score a sample of any schema size.

        The sample is split with ``split_sample``, each chunk is scored with
        ``predict_one`` and the results are combined by ``merge_pred_results``.

        Returns:
            The list produced by ``merge_pred_results`` (one entry per table).
        """
        splitted_samples = split_sample(test_sample, self.tokenizer)
        pred_results = [self.predict_one(splitted_sample) for splitted_sample in splitted_samples]

        return merge_pred_results(test_sample, pred_results)

    def evaluate_coverage(self, dataset):
        """Print top-k coverage counts of the gold tables and columns.

        For every ``k`` in ``1..100`` counts the samples (resp. tables with at
        least one used column) whose gold items are all within the top-k
        predictions. Prints, in order: number of samples, per-k table
        coverage counts, number of evaluated tables, per-k column coverage
        counts. Tables whose gold columns are not all in the top 10 are
        printed for inspection.

        Args:
            dataset: iterable of samples that carry ``"table_labels"`` and
                ``"column_labels"`` in addition to the fields ``predict`` needs.
        """
        max_k = 100
        total_num_for_table_coverage, total_num_for_column_coverage = 0, 0
        table_coverage_results = [0]*max_k
        column_coverage_results = [0]*max_k

        for data in dataset:
            indices_of_used_tables = {idx for idx, label in enumerate(data["table_labels"]) if label == 1}
            # Use this instance rather than a module-level `sic` global, which
            # only exists when the file is run as a script.
            pred_results = self.predict(data)
            table_probs = [res["table_prob"] for res in pred_results]
            # The ranking does not depend on k, so compute it once.
            ranked_tables = np.argsort(-np.array(table_probs), kind="stable").tolist()
            for k in range(max_k):
                if indices_of_used_tables.issubset(ranked_tables[:k+1]):
                    table_coverage_results[k] += 1
            total_num_for_table_coverage += 1

            for table_idx in range(len(data["table_labels"])):
                indices_of_used_columns = {idx for idx, label in enumerate(data["column_labels"][table_idx]) if label == 1}
                if len(indices_of_used_columns) == 0:
                    continue
                column_probs = pred_results[table_idx]["column_probs"]
                ranked_columns = np.argsort(-np.array(column_probs), kind="stable").tolist()
                for k in range(max_k):
                    if indices_of_used_columns.issubset(ranked_columns[:k+1]):
                        column_coverage_results[k] += 1

                total_num_for_column_coverage += 1

                # Report tables whose gold columns are not all in the top 10.
                if not indices_of_used_columns.issubset(ranked_columns[:10]):
                    print(pred_results[table_idx])
                    print(data["column_labels"][table_idx])
                    # Samples in this pipeline store the question under
                    # "text"; fall back to it when "question" is absent.
                    print(data.get("question", data.get("text")))

        print(total_num_for_table_coverage)
        print(table_coverage_results)
        print(total_num_for_column_coverage)
        print(column_coverage_results)
|
| 330 |
+
|
| 331 |
+
if __name__ == "__main__":
    import json

    # Checkpoint / evaluation set to use; "bird" and "spider" are the other
    # dataset names this script was written for.
    dataset_name = "bird_with_evidence"
    sic = SchemaItemClassifierInference("sic_ckpts/sic_{}".format(dataset_name))

    # Close the file deterministically instead of leaking the handle.
    with open("./data/sft_eval_{}_text2sql.json".format(dataset_name), encoding="utf-8") as dataset_file:
        dataset = json.load(dataset_file)

    sic.evaluate_coverage(dataset)
|
training_mode.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from schema_filter import filter_func
|
| 2 |
+
|
| 3 |
+
def _schema_item(table_name, column_names):
    """Build one schema item whose table comment and column comments are all empty."""
    return {
        "table_name": table_name,
        "table_comment": "",
        "column_names": list(column_names),
        "column_comments": [""] * len(column_names),
    }


# Example training sample: question, gold SQL and the full database schema.
data = {
    "text": "Name movie titles released in year 1945. Sort the listing by the descending order of movie popularity.",
    "sql": "SELECT movie_title FROM movies WHERE movie_release_year = 1945 ORDER BY movie_popularity DESC LIMIT 1",
    "schema": {
        "schema_items": [
            _schema_item("lists", [
                "user_id",
                "list_id",
                "list_title",
                "list_movie_number",
                "list_update_timestamp_utc",
                "list_creation_timestamp_utc",
                "list_followers",
                "list_url",
                "list_comments",
                "list_description",
                "list_cover_image_url",
                "list_first_image_url",
                "list_second_image_url",
                "list_third_image_url",
            ]),
            _schema_item("movies", [
                "movie_id",
                "movie_title",
                "movie_release_year",
                "movie_url",
                "movie_title_language",
                "movie_popularity",
                "movie_image_url",
                "director_id",
                "director_name",
                "director_url",
            ]),
            _schema_item("ratings_users", [
                "user_id",
                "rating_date_utc",
                "user_trialist",
                "user_subscriber",
                "user_avatar_image_url",
                "user_cover_image_url",
                "user_eligible_for_trial",
                "user_has_payment_method",
            ]),
            _schema_item("lists_users", [
                "user_id",
                "list_id",
                "list_update_date_utc",
                "list_creation_date_utc",
                "user_trialist",
                "user_subscriber",
                "user_avatar_image_url",
                "user_cover_image_url",
                "user_eligible_for_trial",
                "user_has_payment_method",
            ]),
            _schema_item("ratings", [
                "movie_id",
                "rating_id",
                "rating_url",
                "rating_score",
                "rating_timestamp_utc",
                "critic",
                "critic_likes",
                "critic_comments",
                "user_id",
                "user_trialist",
                "user_subscriber",
                "user_eligible_for_trial",
                "user_has_payment_method",
            ]),
        ]
    }
}
|
| 161 |
+
|
| 162 |
+
def find_used_tables_and_columns(dataset):
    """Label, in place, which tables and columns each sample's SQL mentions.

    A name counts as used when it occurs in the lower-cased SQL as a whole
    identifier, i.e. not directly preceded or followed by a word character.
    This avoids false positives such as ``ratings`` matching inside
    ``ratings_users``. Columns are matched against the whole query, whatever
    table they belong to.

    Args:
        dataset: list of samples with ``"sql"`` and
            ``"schema"["schema_items"]`` (items provide ``"table_name"`` and
            ``"column_names"``).

    Returns:
        The same ``dataset``; every sample gains ``"table_labels"`` (one 0/1
        per table) and ``"column_labels"`` (one 0/1 list per table).
    """
    import re

    def is_mentioned(identifier, sql):
        # Whole-identifier match: the name must not be embedded in a longer one.
        pattern = r"(?<!\w)" + re.escape(identifier.lower()) + r"(?!\w)"
        return 1 if re.search(pattern, sql) else 0

    for data in dataset:
        sql = data["sql"].lower()
        data["table_labels"] = []
        data["column_labels"] = []

        for table_info in data["schema"]["schema_items"]:
            data["table_labels"].append(is_mentioned(table_info["table_name"], sql))
            data["column_labels"].append(
                [is_mentioned(column_name, sql) for column_name in table_info["column_names"]]
            )
    return dataset
|
| 174 |
+
|
| 175 |
+
dataset = [data]

# Derive the used tables and columns from the gold SQL.
dataset = find_used_tables_and_columns(dataset)

# Keep at most 6 tables of the database.
num_top_k_tables = 6
# For every retained table keep at most 6 columns, so the prompt contains at
# most 6 * 6 = 36 columns.
num_top_k_columns = 6

# For training data the filtering can be simulated from the SQL labels, so no
# model is needed and sic (the schema item classifier) may simply be None.
dataset = filter_func(
    dataset=dataset,
    dataset_type="train",
    sic=None,
    num_top_k_tables=num_top_k_tables,
    num_top_k_columns=num_top_k_columns,
)

print(dataset)
|
utils/__pycache__/classifier_model.cpython-38.pyc
ADDED
|
Binary file (4.01 kB). View file
|
|
|
utils/classifier_model.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
from transformers import AutoConfig, XLMRobertaXLModel
|
| 5 |
+
|
| 6 |
+
class SchemaItemClassifier(nn.Module):
    """Binary relevance classifier over the tables and columns of a schema.

    A pretrained encoder embeds the whole input sequence.  The token
    embeddings belonging to each table name / column description are pooled
    with a 2-layer bidirectional LSTM, passed through a non-linear layer,
    and (for tables) enriched with a cross-attention over the table's own
    columns.  Two small MLP heads then emit 2-way logits for every table
    and every column.
    """

    def __init__(self, model_name_or_path, mode):
        """Build the encoder and all classification layers.

        Args:
            model_name_or_path: identifier or path handed to
                ``AutoConfig.from_pretrained`` / ``XLMRobertaXLModel.from_pretrained``.
            mode: ``"train"`` loads the pretrained encoder weights;
                ``"eval"`` or ``"test"`` only loads the config and builds a
                randomly initialised encoder (presumably the caller loads a
                fine-tuned checkpoint afterwards -- confirm against callers).

        Raises:
            ValueError: if ``mode`` is not one of the supported values.
        """
        super().__init__()
        if mode in ("eval", "test"):
            # Only the config is read: the encoder's parameters are randomly
            # initialised according to it.
            config = AutoConfig.from_pretrained(model_name_or_path)
            self.plm_encoder = XLMRobertaXLModel(config)
        elif mode == "train":
            self.plm_encoder = XLMRobertaXLModel.from_pretrained(model_name_or_path)
        else:
            raise ValueError(
                f"Unsupported mode {mode!r}; expected 'train', 'eval' or 'test'."
            )

        self.plm_hidden_size = self.plm_encoder.config.hidden_size

        # NOTE: the layers below are created in a fixed order and under fixed
        # attribute names; both determine the state-dict keys and the order
        # in which random initialisation consumes the RNG.

        # column cls head
        self.column_info_cls_head_linear1 = nn.Linear(self.plm_hidden_size, 256)
        self.column_info_cls_head_linear2 = nn.Linear(256, 2)

        # column bi-lstm layer; each direction gets half of the hidden size so
        # that the concatenated output is again plm_hidden_size wide
        self.column_info_bilstm = nn.LSTM(
            input_size=self.plm_hidden_size,
            hidden_size=self.plm_hidden_size // 2,
            num_layers=2,
            dropout=0,
            bidirectional=True,
        )

        # linear layer after column bi-lstm layer
        self.column_info_linear_after_pooling = nn.Linear(self.plm_hidden_size, self.plm_hidden_size)

        # table cls head
        self.table_name_cls_head_linear1 = nn.Linear(self.plm_hidden_size, 256)
        self.table_name_cls_head_linear2 = nn.Linear(256, 2)

        # table bi-lstm pooling layer
        self.table_name_bilstm = nn.LSTM(
            input_size=self.plm_hidden_size,
            hidden_size=self.plm_hidden_size // 2,
            num_layers=2,
            dropout=0,
            bidirectional=True,
        )
        # linear layer after table bi-lstm layer
        self.table_name_linear_after_pooling = nn.Linear(self.plm_hidden_size, self.plm_hidden_size)

        # activation function
        self.leakyrelu = nn.LeakyReLU()
        self.tanh = nn.Tanh()

        # table-column cross-attention layer
        self.table_column_cross_attention_layer = nn.MultiheadAttention(
            embed_dim=self.plm_hidden_size, num_heads=8
        )

        # dropout function, p=0.2 means randomly set 20% neurons to 0
        self.dropout = nn.Dropout(p=0.2)

    def table_column_cross_attention(
        self,
        table_name_embeddings_in_one_db,
        column_info_embeddings_in_one_db,
        column_number_in_each_table
    ):
        """Inject column information into each table embedding.

        Every table embedding attends (as the query) over the embeddings of
        its own columns (as keys and values).  The attention output is added
        to the table embedding and the result is L2-normalised per row.

        Args:
            table_name_embeddings_in_one_db: 2-D tensor with one row per table.
            column_info_embeddings_in_one_db: 2-D tensor with one row per
                column; the columns of table 0 come first, then those of
                table 1, and so on.
            column_number_in_each_table: indexable sequence giving, for each
                table, how many consecutive rows of the column tensor belong
                to it.

        Returns:
            Tensor of the same shape as ``table_name_embeddings_in_one_db``.
        """
        table_num = table_name_embeddings_in_one_db.shape[0]
        table_name_embedding_attn_list = []

        # Running offset into the column tensor; avoids re-summing the prefix
        # of column_number_in_each_table for every table.
        column_start = 0
        for table_id in range(table_num):
            column_end = column_start + column_number_in_each_table[table_id]

            # Index with a list to keep the row dimension: (1 x hidden_size).
            table_name_embedding = table_name_embeddings_in_one_db[[table_id], :]
            column_info_embeddings_in_one_table = column_info_embeddings_in_one_db[
                column_start:column_end, :
            ]

            table_name_embedding_attn, _ = self.table_column_cross_attention_layer(
                table_name_embedding,
                column_info_embeddings_in_one_table,
                column_info_embeddings_in_one_table
            )
            table_name_embedding_attn_list.append(table_name_embedding_attn)

            column_start = column_end

        # residual connection
        table_name_embeddings_in_one_db = (
            table_name_embeddings_in_one_db
            + torch.cat(table_name_embedding_attn_list, dim=0)
        )
        # row-wise L2 norm
        table_name_embeddings_in_one_db = torch.nn.functional.normalize(
            table_name_embeddings_in_one_db, p=2.0, dim=1
        )

        return table_name_embeddings_in_one_db

    def _bilstm_pool(self, bilstm, sequence_embeddings, aligned_ids):
        """Pool each group of token embeddings into one vector with ``bilstm``.

        Args:
            bilstm: the bidirectional LSTM used for pooling.
            sequence_embeddings: (seq_length x hidden_size) token embeddings
                of one sample.
            aligned_ids: iterable of index collections; each selects the
                token positions that make up one schema item.

        Returns:
            Tensor with one row of width ``plm_hidden_size`` per schema item.
        """
        pooled_embeddings = []
        for token_ids in aligned_ids:
            _, (hidden_state, _) = bilstm(sequence_embeddings[token_ids, :])
            # The last two rows of the final hidden state are flattened into a
            # single row; per the nn.LSTM documentation these are the two
            # directions of the last layer.
            pooled_embeddings.append(hidden_state[-2:, :].view(1, self.plm_hidden_size))
        return torch.cat(pooled_embeddings, dim=0)

    def table_column_cls(
        self,
        encoder_input_ids,
        encoder_input_attention_mask,
        batch_aligned_column_info_ids,
        batch_aligned_table_name_ids,
        batch_column_number_in_each_table
    ):
        """Compute table and column logits for every sample of a batch.

        Args:
            encoder_input_ids: token ids fed to the encoder; the first
                dimension is the batch size.
            encoder_input_attention_mask: attention mask fed to the encoder.
            batch_aligned_column_info_ids: per sample, the token positions of
                each column.
            batch_aligned_table_name_ids: per sample, the token positions of
                each table name.
            batch_column_number_in_each_table: per sample, the number of
                columns of each table.

        Returns:
            A pair of lists (table logits, column logits) with one tensor per
            sample; each tensor has one row per schema item and 2 columns.
        """
        batch_size = encoder_input_ids.shape[0]

        encoder_output = self.plm_encoder(
            input_ids=encoder_input_ids,
            attention_mask=encoder_input_attention_mask,
            return_dict=True
        )  # encoder_output["last_hidden_state"].shape = (batch_size x seq_length x hidden_size)

        batch_table_name_cls_logits, batch_column_info_cls_logits = [], []

        # Samples are handled one by one because the number of tables and
        # columns differs between them.
        for batch_id in range(batch_size):
            column_number_in_each_table = batch_column_number_in_each_table[batch_id]
            sequence_embeddings = encoder_output["last_hidden_state"][batch_id, :, :]  # (seq_length x hidden_size)

            # table embeddings: bi-lstm pooling + a non-linear layer
            table_name_embeddings_in_one_db = self._bilstm_pool(
                self.table_name_bilstm,
                sequence_embeddings,
                batch_aligned_table_name_ids[batch_id],
            )
            table_name_embeddings_in_one_db = self.leakyrelu(
                self.table_name_linear_after_pooling(table_name_embeddings_in_one_db)
            )

            # column embeddings: bi-lstm pooling + a non-linear layer
            column_info_embeddings_in_one_db = self._bilstm_pool(
                self.column_info_bilstm,
                sequence_embeddings,
                batch_aligned_column_info_ids[batch_id],
            )
            column_info_embeddings_in_one_db = self.leakyrelu(
                self.column_info_linear_after_pooling(column_info_embeddings_in_one_db)
            )

            # table-column (tc) cross-attention
            table_name_embeddings_in_one_db = self.table_column_cross_attention(
                table_name_embeddings_in_one_db,
                column_info_embeddings_in_one_db,
                column_number_in_each_table
            )

            # calculate table 0-1 logits
            table_name_embeddings_in_one_db = self.table_name_cls_head_linear1(table_name_embeddings_in_one_db)
            table_name_embeddings_in_one_db = self.dropout(self.leakyrelu(table_name_embeddings_in_one_db))
            table_name_cls_logits = self.table_name_cls_head_linear2(table_name_embeddings_in_one_db)

            # calculate column 0-1 logits
            column_info_embeddings_in_one_db = self.column_info_cls_head_linear1(column_info_embeddings_in_one_db)
            column_info_embeddings_in_one_db = self.dropout(self.leakyrelu(column_info_embeddings_in_one_db))
            column_info_cls_logits = self.column_info_cls_head_linear2(column_info_embeddings_in_one_db)

            batch_table_name_cls_logits.append(table_name_cls_logits)
            batch_column_info_cls_logits.append(column_info_cls_logits)

        return batch_table_name_cls_logits, batch_column_info_cls_logits

    def forward(
        self,
        encoder_input_ids,
        encoder_attention_mask,
        batch_aligned_column_info_ids,
        batch_aligned_table_name_ids,
        batch_column_number_in_each_table,
    ):
        """Run the classifier and return the logits in a dict.

        Returns:
            ``{"batch_table_name_cls_logits": [...],
            "batch_column_info_cls_logits": [...]}`` with the two per-sample
            lists produced by ``table_column_cls``.
        """
        batch_table_name_cls_logits, batch_column_info_cls_logits = self.table_column_cls(
            encoder_input_ids,
            encoder_attention_mask,
            batch_aligned_column_info_ids,
            batch_aligned_table_name_ids,
            batch_column_number_in_each_table
        )

        return {
            "batch_table_name_cls_logits": batch_table_name_cls_logits,
            "batch_column_info_cls_logits": batch_column_info_cls_logits
        }
|