| import collections |
| import itertools |
| import json |
| import os |
|
|
| import attr |
| import nltk.corpus |
| import torch |
| import torchtext |
| import numpy as np |
|
|
| from seq2struct.models import abstract_preproc, transformer |
| from seq2struct.models.spider import spider_enc_modules |
| from seq2struct.resources import pretrained_embeddings |
| from seq2struct.utils import registry |
| from seq2struct.utils import vocab |
| from seq2struct.utils import serialization |
| from seq2struct import resources |
| from seq2struct.resources import corenlp |
| from transformers import BertModel, BertTokenizer, BartModel, BartTokenizer |
| from seq2struct.models.spider.spider_match_utils import ( |
| compute_schema_linking, |
| compute_cell_value_linking |
| ) |
|
|
|
|
@attr.s
class SpiderEncoderState:
    """Result of encoding one example: concatenated memories for attention /
    copying plus maps from schema items to positions in the pointer memories.

    All fields are required by attrs (no defaults).
    NOTE(review): `SpiderEncoderV2.forward_unbatched` in this file constructs
    this class without question_memory/schema_memory/m2c_align_mat/
    m2t_align_mat, which would raise TypeError — confirm that path is unused.
    """
    state = attr.ib()  # decoder-facing recurrent state (None in this file)
    memory = attr.ib()  # concatenation of the encodings selected for memory
    question_memory = attr.ib()  # encoding of the question tokens
    schema_memory = attr.ib()  # columns followed by tables, concatenated
    words = attr.ib()  # tokens aligned with `memory` positions, for copying

    pointer_memories = attr.ib()  # {'column': ..., 'table': ...} encodings
    pointer_maps = attr.ib()  # {'column'/'table': {item idx -> memory positions}}

    m2c_align_mat = attr.ib()  # memory-to-column alignment matrix (may be None)
    m2t_align_mat = attr.ib()  # memory-to-table alignment matrix (may be None)

    def find_word_occurrences(self, word):
        # All positions of `word` among the words stored for copying.
        return [i for i, w in enumerate(self.words) if w == word]
|
|
|
|
@attr.s
class PreprocessedSchema:
    """Tokenized and derived views of one database schema, populated by
    preprocess_schema_uncached."""
    column_names = attr.ib(factory=list)  # per column: name tokens (+ '<type: ...>' token)
    table_names = attr.ib(factory=list)  # per table: list of name tokens
    table_bounds = attr.ib(factory=list)  # first column index of each table, plus len(columns) sentinel
    column_to_table = attr.ib(factory=dict)  # str(column idx) -> table id (or None)
    table_to_columns = attr.ib(factory=dict)  # str(table id) -> [column indices]
    foreign_keys = attr.ib(factory=dict)  # str(column id) -> referenced column id
    foreign_keys_tables = attr.ib(factory=lambda: collections.defaultdict(set))  # str(table id) -> {referenced table ids}
    primary_keys = attr.ib(factory=list)  # column ids of primary-key columns

    # Only populated when preprocessing for BERT (bert=True): Bertokens
    # wrappers over the tokenized names, cached for schema linking.
    normalized_column_names = attr.ib(factory=list)
    normalized_table_names = attr.ib(factory=list)
|
|
def preprocess_schema_uncached(schema,
                               tokenize_func,
                               include_table_name_in_column,
                               fix_issue_16_primary_keys,
                               bert=False):
    """Tokenize a schema into a PreprocessedSchema.

    If it's bert, we also cache the normalized version of
    question/column/table for schema linking (as Bertokens wrappers).

    Args:
        schema: schema object with .columns and .tables (project type).
        tokenize_func: callable(presplit_tokens, unsplit_name) -> tokens.
        include_table_name_in_column: append '<table-sep>' + table-name tokens
            to each column's tokens (disallowed together with bert).
        fix_issue_16_primary_keys: if False, reproduces a historical bug in
            how primary keys are collected (see below).
        bert: place the type token after (not before) the name tokens and
            populate the normalized_* fields.
    """
    r = PreprocessedSchema()

    if bert: assert not include_table_name_in_column

    last_table_id = None
    for i, column in enumerate(schema.columns):
        col_toks = tokenize_func(
            column.name, column.unsplit_name)

        # Marker token carrying the column's SQL type, e.g. '<type: text>'.
        type_tok = '<type: {}>'.format(column.type)
        if bert:
            # For bert, we take the representation of the first word-piece,
            # so the type token goes last instead of first.
            column_name = col_toks + [type_tok]
            r.normalized_column_names.append(Bertokens(col_toks))
        else:
            column_name = [type_tok] + col_toks

        if include_table_name_in_column:
            if column.table is None:
                table_name = ['<any-table>']
            else:
                table_name = tokenize_func(
                    column.table.name, column.table.unsplit_name)
            column_name += ['<table-sep>'] + table_name
        r.column_names.append(column_name)

        table_id = None if column.table is None else column.table.id
        r.column_to_table[str(i)] = table_id
        if table_id is not None:
            columns = r.table_to_columns.setdefault(str(table_id), [])
            columns.append(i)
        # Columns are assumed grouped by table; record where each new table's
        # columns begin.
        if last_table_id != table_id:
            r.table_bounds.append(i)
            last_table_id = table_id

        if column.foreign_key_for is not None:
            r.foreign_keys[str(column.id)] = column.foreign_key_for.id
            r.foreign_keys_tables[str(column.table.id)].add(column.foreign_key_for.table.id)

    # Final sentinel so bounds come in (start, end) pairs per table.
    r.table_bounds.append(len(schema.columns))
    assert len(r.table_bounds) == len(schema.tables) + 1

    for i, table in enumerate(schema.tables):
        table_toks = tokenize_func(
            table.name, table.unsplit_name)
        r.table_names.append(table_toks)
        if bert:
            r.normalized_table_names.append(Bertokens(table_toks))
    last_table = schema.tables[-1]

    # defaultdict(set) is not JSON-serializable; convert to plain dict of
    # sorted lists.
    r.foreign_keys_tables = serialization.to_dict_with_sorted_values(r.foreign_keys_tables)
    # NOTE: the non-fixed branch reproduces a historical bug ("issue 16"):
    # it collects last_table's primary keys once per table instead of each
    # table's own keys. Kept behind the flag for backward compatibility.
    r.primary_keys = [
        column.id
        for table in schema.tables
        for column in table.primary_keys
    ] if fix_issue_16_primary_keys else [
        column.id
        for column in last_table.primary_keys
        for table in schema.tables
    ]

    return r
|
|
class SpiderEncoderV2Preproc(abstract_preproc.AbstractPreproc):
    """Preprocessor for SpiderEncoderV2.

    Tokenizes questions and schemas, optionally computes schema linking and
    cell-value linking, accumulates a vocabulary from the training section,
    and stores each section as a `.jsonl` file under `save_path`/enc.
    """

    def __init__(
            self,
            save_path,
            min_freq=3,
            max_count=5000,
            include_table_name_in_column=True,
            word_emb=None,
            count_tokens_in_word_emb_for_vocab=False,
            fix_issue_16_primary_keys=False,
            compute_sc_link=False,
            compute_cv_link=False,
            db_path=None):
        # Optional pretrained word-embedding wrapper, built via the registry.
        if word_emb is None:
            self.word_emb = None
        else:
            self.word_emb = registry.construct('word_emb', word_emb)

        self.data_dir = os.path.join(save_path, 'enc')
        self.include_table_name_in_column = include_table_name_in_column
        self.count_tokens_in_word_emb_for_vocab = count_tokens_in_word_emb_for_vocab
        self.fix_issue_16_primary_keys = fix_issue_16_primary_keys
        self.compute_sc_link = compute_sc_link
        self.compute_cv_link = compute_cv_link
        # section name -> list of preprocessed example dicts
        self.texts = collections.defaultdict(list)

        # Cell-value linking reads the databases, so it needs their location.
        self.db_path = db_path
        if self.compute_cv_link:
            assert self.db_path is not None

        self.vocab_builder = vocab.VocabBuilder(min_freq, max_count)
        self.vocab_path = os.path.join(save_path, 'enc_vocab.json')
        self.vocab_word_freq_path = os.path.join(save_path, 'enc_word_freq.json')
        self.vocab = None
        self.counted_db_ids = set()
        # db_id -> PreprocessedSchema, so each schema is processed once.
        self.preprocessed_schemas = {}

    def validate_item(self, item, section):
        """All items are accepted; subclasses may impose length limits."""
        return True, None

    def add_item(self, item, section, validation_info):
        """Preprocess one example, store it, and update vocabulary counts."""
        preprocessed = self.preprocess_item(item, validation_info)
        self.texts[section].append(preprocessed)

        if section == 'train':
            # Schema tokens are only counted the first time a DB is seen;
            # afterwards only the question contributes counts.
            if item.schema.db_id in self.counted_db_ids:
                to_count = preprocessed['question']
            else:
                self.counted_db_ids.add(item.schema.db_id)
                to_count = itertools.chain(
                    preprocessed['question'],
                    *preprocessed['columns'],
                    *preprocessed['tables'])

            for token in to_count:
                # Tokens covered by the pretrained embedding are skipped
                # unless explicitly requested, so the learned vocab only
                # holds out-of-embedding words.
                count_token = (
                    self.word_emb is None or
                    self.count_tokens_in_word_emb_for_vocab or
                    self.word_emb.lookup(token) is None)
                if count_token:
                    self.vocab_builder.add_word(token)

    def clear_items(self):
        """Drop all stored preprocessed examples."""
        self.texts = collections.defaultdict(list)

    def preprocess_item(self, item, validation_info):
        """Return the JSON-serializable dict stored for one example."""
        question, question_for_copying = self._tokenize_for_copying(item.text, item.orig['question'])
        preproc_schema = self._preprocess_schema(item.schema)
        if self.compute_sc_link:
            # Non-bert columns start with the '<type: ...>' token; strip it
            # before matching question words against column names.
            assert preproc_schema.column_names[0][0].startswith("<type:")
            column_names_without_types = [col[1:] for col in preproc_schema.column_names]
            sc_link = compute_schema_linking(
                question, column_names_without_types, preproc_schema.table_names)
        else:
            sc_link = {"q_col_match": {}, "q_tab_match": {}}

        if self.compute_cv_link:
            cv_link = compute_cell_value_linking(question, item.schema, self.db_path)
        else:
            cv_link = {"num_date_match": {}, "cell_match": {}}

        return {
            'raw_question': item.orig['question'],
            'question': question,
            'question_for_copying': question_for_copying,
            'db_id': item.schema.db_id,
            'sc_link': sc_link,
            'cv_link': cv_link,
            'columns': preproc_schema.column_names,
            'tables': preproc_schema.table_names,
            'table_bounds': preproc_schema.table_bounds,
            'column_to_table': preproc_schema.column_to_table,
            'table_to_columns': preproc_schema.table_to_columns,
            'foreign_keys': preproc_schema.foreign_keys,
            'foreign_keys_tables': preproc_schema.foreign_keys_tables,
            'primary_keys': preproc_schema.primary_keys,
        }

    def _preprocess_schema(self, schema):
        """Preprocess a schema, memoized by db_id."""
        if schema.db_id in self.preprocessed_schemas:
            return self.preprocessed_schemas[schema.db_id]
        result = preprocess_schema_uncached(schema, self._tokenize,
                self.include_table_name_in_column, self.fix_issue_16_primary_keys)
        self.preprocessed_schemas[schema.db_id] = result
        return result

    def _tokenize(self, presplit, unsplit):
        # The embedding's tokenizer wins when available; otherwise trust the
        # caller's pre-split tokens.
        if self.word_emb:
            return self.word_emb.tokenize(unsplit)
        return presplit

    def _tokenize_for_copying(self, presplit, unsplit):
        # Returns (model tokens, tokens used as copy targets).
        if self.word_emb:
            return self.word_emb.tokenize_for_copying(unsplit)
        return presplit, presplit

    def save(self):
        """Finalize the vocabulary and write it plus all sections to disk."""
        os.makedirs(self.data_dir, exist_ok=True)
        self.vocab = self.vocab_builder.finish()
        print(f"{len(self.vocab)} words in vocab")
        self.vocab.save(self.vocab_path)
        self.vocab_builder.save(self.vocab_word_freq_path)

        for section, texts in self.texts.items():
            with open(os.path.join(self.data_dir, section + '.jsonl'), 'w') as f:
                for text in texts:
                    f.write(json.dumps(text) + '\n')

    def load(self):
        """Load the previously saved vocabulary and word frequencies."""
        self.vocab = vocab.Vocab.load(self.vocab_path)
        self.vocab_builder.load(self.vocab_word_freq_path)

    def dataset(self, section):
        """Load all preprocessed examples for a section.

        Fixed: the original opened the file inside a list comprehension and
        never closed it; use a context manager so the handle is released
        deterministically.
        """
        path = os.path.join(self.data_dir, section + '.jsonl')
        with open(path) as f:
            return [json.loads(line) for line in f]
|
|
|
|
@registry.register('encoder', 'spiderv2')
class SpiderEncoderV2(torch.nn.Module):
    """Embedding+LSTM encoder for Spider examples: encodes the question,
    columns, and tables separately, then refines them jointly with a
    configurable update module (e.g. a relational transformer)."""

    batched = True
    Preproc = SpiderEncoderV2Preproc

    def __init__(
            self,
            device,
            preproc,
            word_emb_size=128,
            recurrent_size=256,
            dropout=0.,
            question_encoder=('emb', 'bilstm'),
            column_encoder=('emb', 'bilstm'),
            table_encoder=('emb', 'bilstm'),
            update_config={},
            include_in_memory=('question', 'column', 'table'),
            batch_encs_update=True,
            top_k_learnable=0):
        super().__init__()
        self._device = device
        self.preproc = preproc

        self.vocab = preproc.vocab
        self.word_emb_size = word_emb_size
        self.recurrent_size = recurrent_size
        # The BiLSTM splits the hidden size across two directions.
        assert self.recurrent_size % 2 == 0
        # Only the top-k most frequent words get learnable embeddings;
        # everything else uses the (frozen) pretrained embedding.
        word_freq = self.preproc.vocab_builder.word_freq
        top_k_words = set([_a[0] for _a in word_freq.most_common(top_k_learnable)])
        self.learnable_words = top_k_words

        self.include_in_memory = set(include_in_memory)
        self.dropout = dropout

        self.question_encoder = self._build_modules(question_encoder)
        self.column_encoder = self._build_modules(column_encoder)
        self.table_encoder = self._build_modules(table_encoder)

        update_modules = {
            'relational_transformer':
                spider_enc_modules.RelationalTransformerUpdate,
            'none':
                spider_enc_modules.NoOpUpdate,
        }

        self.encs_update = registry.instantiate(
            update_modules[update_config['name']],
            update_config,
            unused_keys={"name"},
            device=self._device,
            hidden_size=recurrent_size,
        )
        self.batch_encs_update = batch_encs_update

    def _build_modules(self, module_types):
        """Chain the named sub-encoders into a torch.nn.Sequential."""
        module_builder = {
            'emb': lambda: spider_enc_modules.LookupEmbeddings(
                self._device,
                self.vocab,
                self.preproc.word_emb,
                self.word_emb_size,
                self.learnable_words),
            'linear': lambda: spider_enc_modules.EmbLinear(
                input_size=self.word_emb_size,
                output_size=self.word_emb_size),
            'bilstm': lambda: spider_enc_modules.BiLSTM(
                input_size=self.word_emb_size,
                output_size=self.recurrent_size,
                dropout=self.dropout,
                summarize=False),
            'bilstm-native': lambda: spider_enc_modules.BiLSTM(
                input_size=self.word_emb_size,
                output_size=self.recurrent_size,
                dropout=self.dropout,
                summarize=False,
                use_native=True),
            'bilstm-summarize': lambda: spider_enc_modules.BiLSTM(
                input_size=self.word_emb_size,
                output_size=self.recurrent_size,
                dropout=self.dropout,
                summarize=True),
            'bilstm-native-summarize': lambda: spider_enc_modules.BiLSTM(
                input_size=self.word_emb_size,
                output_size=self.recurrent_size,
                dropout=self.dropout,
                summarize=True,
                use_native=True),
        }

        modules = []
        for module_type in module_types:
            modules.append(module_builder[module_type]())
        return torch.nn.Sequential(*modules)

    def forward_unbatched(self, desc):
        """Encode a single preprocessed example dict.

        NOTE(review): this path appears to be legacy; `forward` handles
        batches (including batch size 1).
        """
        q_enc, (_, _) = self.question_encoder([desc['question']])

        # c_boundaries[i]:c_boundaries[i+1] is the token span of column i.
        c_enc, c_boundaries = self.column_encoder(desc['columns'])
        column_pointer_maps = {
            i: list(range(left, right))
            for i, (left, right) in enumerate(zip(c_boundaries, c_boundaries[1:]))
        }

        t_enc, t_boundaries = self.table_encoder(desc['tables'])
        c_enc_length = c_enc.shape[0]
        # A table points at all of its columns' positions plus its own
        # token span (offset past the column section of the memory).
        table_pointer_maps = {
            i: [
                idx
                for col in desc['table_to_columns'][str(i)]
                for idx in column_pointer_maps[col]
            ] + list(range(left + c_enc_length, right + c_enc_length))
            for i, (left, right) in enumerate(zip(t_boundaries, t_boundaries[1:]))
        }

        q_enc_new, c_enc_new, t_enc_new = self.encs_update(
            desc, q_enc, c_enc, c_boundaries, t_enc, t_boundaries)

        memory = []
        words_for_copying = []
        if 'question' in self.include_in_memory:
            memory.append(q_enc_new)
            if 'question_for_copying' in desc:
                # Fixed: the original compared shape[1] against the list
                # itself instead of its length.
                assert q_enc_new.shape[1] == len(desc['question_for_copying'])
                words_for_copying += desc['question_for_copying']
            else:
                words_for_copying += [''] * q_enc_new.shape[1]
        if 'column' in self.include_in_memory:
            memory.append(c_enc_new)
            words_for_copying += [''] * c_enc_new.shape[1]
        if 'table' in self.include_in_memory:
            memory.append(t_enc_new)
            words_for_copying += [''] * t_enc_new.shape[1]
        memory = torch.cat(memory, dim=1)

        return SpiderEncoderState(
            state=None,
            memory=memory,
            # Fixed: SpiderEncoderState requires all fields; the original
            # omitted these four, which raises TypeError on construction.
            question_memory=q_enc_new,
            schema_memory=torch.cat((c_enc_new, t_enc_new), dim=1),
            words=words_for_copying,
            pointer_memories={
                'column': c_enc_new,
                'table': torch.cat((c_enc_new, t_enc_new), dim=1),
            },
            pointer_maps={
                'column': column_pointer_maps,
                'table': table_pointer_maps,
            },
            # This (non-relational) path produces no alignment matrices.
            m2c_align_mat=None,
            m2t_align_mat=None,
        )

    def forward(self, descs):
        """Encode a batch of preprocessed example dicts; returns a list of
        SpiderEncoderState, one per example."""
        qs = [[desc['question']] for desc in descs]
        q_enc, _ = self.question_encoder(qs)

        c_enc, c_boundaries = self.column_encoder([desc['columns'] for desc in descs])

        # For each item: column i occupies positions [left, right) of the
        # column section of the memory.
        column_pointer_maps = [
            {
                i: list(range(left, right))
                for i, (left, right) in enumerate(zip(bounds, bounds[1:]))
            }
            for bounds in c_boundaries
        ]

        t_enc, t_boundaries = self.table_encoder([desc['tables'] for desc in descs])

        table_pointer_maps = [
            {
                i: list(range(left, right))
                for i, (left, right) in enumerate(zip(bounds, bounds[1:]))
            }
            for bounds in t_boundaries
        ]

        if self.batch_encs_update:
            q_enc_new, c_enc_new, t_enc_new = self.encs_update(
                descs, q_enc, c_enc, c_boundaries, t_enc, t_boundaries)

        result = []
        for batch_idx, desc in enumerate(descs):
            # Fixed: alignment matrices are only produced by the unbatched
            # update; default them so the batched path does not raise
            # NameError when building SpiderEncoderState below.
            align_mat_item = (None, None)
            if self.batch_encs_update:
                # NOTE(review): one-argument .select() suggests these are
                # custom batched containers, not plain tensors
                # (torch.Tensor.select takes (dim, index)) — confirm.
                q_enc_new_item = q_enc_new.select(batch_idx).unsqueeze(0)
                c_enc_new_item = c_enc_new.select(batch_idx).unsqueeze(0)
                t_enc_new_item = t_enc_new.select(batch_idx).unsqueeze(0)
            else:
                q_enc_new_item, c_enc_new_item, t_enc_new_item, align_mat_item = \
                    self.encs_update.forward_unbatched(
                        desc,
                        q_enc.select(batch_idx).unsqueeze(1),
                        c_enc.select(batch_idx).unsqueeze(1),
                        c_boundaries[batch_idx],
                        t_enc.select(batch_idx).unsqueeze(1),
                        t_boundaries[batch_idx])

            memory = []
            words_for_copying = []
            if 'question' in self.include_in_memory:
                memory.append(q_enc_new_item)
                if 'question_for_copying' in desc:
                    assert q_enc_new_item.shape[1] == len(desc['question_for_copying'])
                    words_for_copying += desc['question_for_copying']
                else:
                    words_for_copying += [''] * q_enc_new_item.shape[1]
            if 'column' in self.include_in_memory:
                memory.append(c_enc_new_item)
                words_for_copying += [''] * c_enc_new_item.shape[1]
            if 'table' in self.include_in_memory:
                memory.append(t_enc_new_item)
                words_for_copying += [''] * t_enc_new_item.shape[1]
            memory = torch.cat(memory, dim=1)

            result.append(SpiderEncoderState(
                state=None,
                memory=memory,
                question_memory=q_enc_new_item,
                schema_memory=torch.cat((c_enc_new_item, t_enc_new_item), dim=1),
                words=words_for_copying,
                pointer_memories={
                    'column': c_enc_new_item,
                    'table': torch.cat((c_enc_new_item, t_enc_new_item), dim=1),
                },
                pointer_maps={
                    'column': column_pointer_maps[batch_idx],
                    'table': table_pointer_maps[batch_idx],
                },
                m2c_align_mat=align_mat_item[0],
                m2t_align_mat=align_mat_item[1],
            ))
        return result
|
|
|
|
class Bertokens:
    """Wraps a list of BERT wordpieces and exposes lemmatized whole words
    (via CoreNLP) plus an index map back to the original piece positions.
    Only used for schema linking."""

    def __init__(self, pieces):
        self.pieces = pieces  # raw wordpiece tokens, e.g. ['a', '##b', '##c']

        self.normalized_pieces = None  # lemmatized whole words
        self.idx_map = None  # normalized-token index -> original piece index

        self.normalize_toks()

    def normalize_toks(self):
        """
        If the token is not a word piece, then find its lemma
        If it is, combine pieces into a word, and then find its lemma
        E.g., a ##b ##c will be normalized as "abc", "", ""
        NOTE: this is only used for schema linking
        """
        # startidx2pieces: start index of a multi-piece word -> exclusive end
        # index; pieces2startidx: every piece of such a word -> start index.
        self.startidx2pieces = dict()
        self.pieces2startidx = dict()
        cache_start = None
        # The appended "" sentinel guarantees the final run is flushed.
        for i, piece in enumerate(self.pieces + [""]):
            if piece.startswith("##"):
                if cache_start is None:
                    cache_start = i - 1

                self.pieces2startidx[i] = cache_start
                self.pieces2startidx[i-1] = cache_start
            else:
                if cache_start is not None:
                    self.startidx2pieces[cache_start] = i
                cache_start = None
        assert cache_start is None

        # Join each multi-piece run back into one word.
        combined_word = {}
        for start, end in self.startidx2pieces.items():
            # Sanity bound on word length (end is exclusive).
            assert end - start + 1 < 10
            # NOTE(review): strip("##") removes '#' from BOTH ends of the
            # piece, not just the leading "##" — fine for normal wordpieces,
            # but would also eat a trailing '#'.
            pieces = [self.pieces[start]] + [self.pieces[_id].strip("##") for _id in range(start+1, end)]
            word = "".join(pieces)
            combined_word[start] = word

        # Rebuild the token list with combined words, remembering where each
        # new token came from in the original piece sequence.
        idx_map = {}
        new_toks = []
        for i, piece in enumerate(self.pieces):
            if i in combined_word:
                idx_map[len(new_toks)] = i
                new_toks.append(combined_word[i])
            elif i in self.pieces2startidx:
                # Non-first piece of a combined word: already accounted for.
                pass
            else:
                idx_map[len(new_toks)] = i
                new_toks.append(piece)
        self.idx_map = idx_map

        # Lemmatize each rebuilt token (one CoreNLP call per token).
        normalized_toks = []
        for i, tok in enumerate(new_toks):
            ann = corenlp.annotate(tok, annotators = ['tokenize', 'ssplit', 'lemma'])
            lemmas = [tok.lemma.lower() for sent in ann.sentence for tok in sent.token]
            lemma_word = " ".join(lemmas)
            normalized_toks.append(lemma_word)

        self.normalized_pieces = normalized_toks

    def bert_schema_linking(self, columns, tables):
        """Run schema linking on the normalized tokens, then remap question
        indices back to the original wordpiece positions via idx_map.

        Args:
            columns/tables: lists of Bertokens for the schema's names.
        Returns:
            dict with the same match types as compute_schema_linking, keyed
            by "question_piece_idx,column_or_table_idx".
        """
        question_tokens =self.normalized_pieces
        column_tokens = [c.normalized_pieces for c in columns]
        table_tokens = [t.normalized_pieces for t in tables]
        sc_link = compute_schema_linking(question_tokens, column_tokens, table_tokens)

        new_sc_link = {}
        for m_type in sc_link:
            _match = {}
            for ij_str in sc_link[m_type]:
                q_id_str, col_tab_id_str = ij_str.split(",")
                q_id, col_tab_id = int(q_id_str), int(col_tab_id_str)
                # Translate normalized-token index -> original piece index.
                real_q_id = self.idx_map[q_id]
                _match[f"{real_q_id},{col_tab_id}"] = sc_link[m_type][ij_str]

            new_sc_link[m_type] = _match
        return new_sc_link
|
|
|
|
class SpiderEncoderBertPreproc(SpiderEncoderV2Preproc):
    """Preprocessor for the BERT encoder.

    Deliberately does NOT call the base __init__: there is no word-embedding
    vocabulary to build (tokenization is done by BertTokenizer), so only the
    attributes the overridden methods need are set here.
    """

    def __init__(
            self,
            save_path,
            db_path,
            fix_issue_16_primary_keys=False,
            include_table_name_in_column = False,
            bert_version = "bert-base-uncased",
            compute_sc_link=True,
            compute_cv_link=False):

        self.data_dir = os.path.join(save_path, 'enc')
        self.db_path = db_path
        self.texts = collections.defaultdict(list)
        self.fix_issue_16_primary_keys = fix_issue_16_primary_keys
        self.include_table_name_in_column = include_table_name_in_column
        self.compute_sc_link = compute_sc_link
        self.compute_cv_link = compute_cv_link

        self.counted_db_ids = set()
        self.preprocessed_schemas = {}

        self.tokenizer = BertTokenizer.from_pretrained(bert_version)
        self.tokenizer.add_special_tokens({"additional_special_tokens": ["<col>"]})
        # Column-type marker tokens must survive wordpiece tokenization, so
        # they are registered as whole tokens.
        column_types = ["text", "number", "time", "boolean", "others"]
        self.tokenizer.add_tokens([f"<type: {t}>" for t in column_types])

    def _tokenize(self, presplit, unsplit):
        # Wordpiece-tokenize the unsplit string; presplit is only a fallback.
        if self.tokenizer:
            toks = self.tokenizer.tokenize(unsplit)
            return toks
        return presplit

    def add_item(self, item, section, validation_info):
        # No vocabulary counting here (unlike the base class): BERT brings
        # its own vocabulary.
        preprocessed = self.preprocess_item(item, validation_info)
        self.texts[section].append(preprocessed)

    def preprocess_item(self, item, validation_info):
        """Return the JSON-serializable dict stored for one example.

        Note: unlike the base class there is no 'question_for_copying' key.
        """
        question = self._tokenize(item.text, item.orig['question'])
        preproc_schema = self._preprocess_schema(item.schema)
        if self.compute_sc_link:
            # Schema linking runs on lemmatized whole words; Bertokens
            # handles merging wordpieces and mapping indices back.
            question_bert_tokens = Bertokens(item.text)
            sc_link = question_bert_tokens.bert_schema_linking(
                preproc_schema.normalized_column_names,
                preproc_schema.normalized_table_names
            )
        else:
            sc_link = {"q_col_match": {}, "q_tab_match": {}}

        if self.compute_cv_link:
            question_bert_tokens = Bertokens(question)
            cv_link = compute_cell_value_linking(
                question_bert_tokens.normalized_pieces, item.schema, self.db_path)
        else:
            cv_link = {"num_date_match": {}, "cell_match": {}}

        return {
            'raw_question': item.orig['question'],
            'question': question,
            'db_id': item.schema.db_id,
            'sc_link': sc_link,
            'cv_link': cv_link,
            'columns': preproc_schema.column_names,
            'tables': preproc_schema.table_names,
            'table_bounds': preproc_schema.table_bounds,
            'column_to_table': preproc_schema.column_to_table,
            'table_to_columns': preproc_schema.table_to_columns,
            'foreign_keys': preproc_schema.foreign_keys,
            'foreign_keys_tables': preproc_schema.foreign_keys_tables,
            'primary_keys': preproc_schema.primary_keys,
        }

    def validate_item(self, item, section):
        """Reject examples whose full BERT input would exceed 512 tokens."""
        question = self._tokenize(item.text, item.orig['question'])
        preproc_schema = self._preprocess_schema(item.schema)

        # +2 for [CLS]/[SEP] around the question; +1 per column/table for
        # the trailing [SEP] (mirrors pad_single_sentence_for_bert).
        num_words = len(question) + 2 + \
                    sum(len(c) + 1 for c in preproc_schema.column_names) + \
                    sum(len(t) + 1 for t in preproc_schema.table_names)
        if num_words > 512:
            return False, None
        else:
            return True, None

    def _preprocess_schema(self, schema):
        """Preprocess a schema (bert=True variant), memoized by db_id."""
        if schema.db_id in self.preprocessed_schemas:
            return self.preprocessed_schemas[schema.db_id]
        result = preprocess_schema_uncached(schema, self._tokenize,
                                            self.include_table_name_in_column,
                                            self.fix_issue_16_primary_keys, bert=True)
        self.preprocessed_schemas[schema.db_id] = result
        return result

    def save(self):
        """Write the tokenizer plus all sections to disk (no vocab here)."""
        os.makedirs(self.data_dir, exist_ok=True)
        self.tokenizer.save_pretrained(self.data_dir)

        for section, texts in self.texts.items():
            with open(os.path.join(self.data_dir, section + '.jsonl'), 'w') as f:
                for text in texts:
                    f.write(json.dumps(text) + '\n')

    def load(self):
        self.tokenizer = BertTokenizer.from_pretrained(self.data_dir)
|
|
|
|
|
|
@registry.register('encoder', 'spider-bert')
class SpiderEncoderBert(torch.nn.Module):
    """Encoder that runs BERT over one packed sequence per example
    ("[CLS] question [SEP] col1 [SEP] ... tab1 [SEP] ..."), gathers per-item
    vectors, and refines them with a relation-aware update module."""

    Preproc = SpiderEncoderBertPreproc
    batched = True

    def __init__(
            self,
            device,
            preproc,
            update_config={},
            bert_token_type=False,
            bert_version="bert-base-uncased",
            summarize_header="first",
            use_column_type=True,
            include_in_memory=('question', 'column', 'table')):
        super().__init__()
        self._device = device
        self.preproc = preproc
        # When True, segment ids separate the question from the schema.
        self.bert_token_type = bert_token_type
        self.base_enc_hidden_size = 1024 if bert_version == "bert-large-uncased-whole-word-masking" else 768

        # "first": a column/table is represented by its first token;
        # "avg": average of its first and last token.
        assert summarize_header in ["first", "avg"]
        self.summarize_header = summarize_header
        self.enc_hidden_size = self.base_enc_hidden_size
        self.use_column_type = use_column_type

        self.include_in_memory = set(include_in_memory)
        update_modules = {
            'relational_transformer':
                spider_enc_modules.RelationalTransformerUpdate,
            'none':
                spider_enc_modules.NoOpUpdate,
        }

        self.encs_update = registry.instantiate(
            update_modules[update_config['name']],
            update_config,
            unused_keys={"name"},
            device=self._device,
            hidden_size=self.enc_hidden_size,
            sc_link=True,
        )

        self.bert_model = BertModel.from_pretrained(bert_version)
        self.tokenizer = self.preproc.tokenizer
        # The preprocessor registered extra tokens (<col>, <type: ...>), so
        # the embedding table must be resized to match.
        self.bert_model.resize_token_embeddings(len(self.tokenizer))

    def forward(self, descs):
        """Encode a batch of preprocessed example dicts; returns a list of
        SpiderEncoderState, one per example."""
        # ---- Phase 1: build one token sequence per example and remember
        # which positions to gather for question/column/table vectors.
        batch_token_lists = []
        batch_id_to_retrieve_question = []
        batch_id_to_retrieve_column = []
        batch_id_to_retrieve_table = []
        if self.summarize_header == "avg":
            batch_id_to_retrieve_column_2 = []
            batch_id_to_retrieve_table_2 = []
        long_seq_set = set()
        # Examples over 512 tokens are skipped, so original batch index ->
        # index within the BERT batch.
        batch_id_map = {}
        for batch_idx, desc in enumerate(descs):
            qs = self.pad_single_sentence_for_bert(desc['question'], cls=True)
            if self.use_column_type:
                cols = [self.pad_single_sentence_for_bert(c, cls=False) for c in desc['columns']]
            else:
                # Drop the trailing '<type: ...>' token from each column.
                cols = [self.pad_single_sentence_for_bert(c[:-1], cls=False) for c in desc['columns']]
            tabs = [self.pad_single_sentence_for_bert(t, cls=False) for t in desc['tables']]

            token_list = qs + [c for col in cols for c in col] + \
                         [t for tab in tabs for t in tab]
            assert self.check_bert_seq(token_list)
            if len(token_list) > 512:
                long_seq_set.add(batch_idx)
                continue

            q_b = len(qs)
            col_b = q_b + sum(len(c) for c in cols)
            # Question positions, excluding [CLS] and [SEP].
            question_indexes = list(range(q_b))[1:-1]
            # First-token position of each column / table segment.
            column_indexes = \
                np.cumsum([q_b] + [len(ts) for ts in cols[:-1]]).tolist()
            table_indexes = \
                np.cumsum([col_b] + [len(ts) for ts in tabs[:-1]]).tolist()
            if self.summarize_header == "avg":
                # Last-token position (before the [SEP]) of each segment.
                column_indexes_2 = \
                    np.cumsum([q_b - 2] + [len(ts) for ts in cols]).tolist()[1:]
                table_indexes_2 = \
                    np.cumsum([col_b - 2] + [len(ts) for ts in tabs]).tolist()[1:]

            indexed_token_list = self.tokenizer.convert_tokens_to_ids(token_list)
            batch_token_lists.append(indexed_token_list)

            question_rep_ids = torch.LongTensor(question_indexes).to(self._device)
            batch_id_to_retrieve_question.append(question_rep_ids)
            column_rep_ids = torch.LongTensor(column_indexes).to(self._device)
            batch_id_to_retrieve_column.append(column_rep_ids)
            table_rep_ids = torch.LongTensor(table_indexes).to(self._device)
            batch_id_to_retrieve_table.append(table_rep_ids)
            if self.summarize_header == "avg":
                assert(all(i2 >= i1 for i1, i2 in zip(column_indexes, column_indexes_2)))
                column_rep_ids_2 = torch.LongTensor(column_indexes_2).to(self._device)
                batch_id_to_retrieve_column_2.append(column_rep_ids_2)
                assert(all(i2 >= i1 for i1, i2 in zip(table_indexes, table_indexes_2)))
                table_rep_ids_2 = torch.LongTensor(table_indexes_2).to(self._device)
                batch_id_to_retrieve_table_2.append(table_rep_ids_2)

            batch_id_map[batch_idx] = len(batch_id_map)

        # ---- Phase 2: one padded BERT forward pass over the whole batch.
        padded_token_lists, att_mask_lists, tok_type_lists = self.pad_sequence_for_bert_batch(batch_token_lists)
        tokens_tensor = torch.LongTensor(padded_token_lists).to(self._device)
        att_masks_tensor = torch.LongTensor(att_mask_lists).to(self._device)

        if self.bert_token_type:
            tok_type_tensor = torch.LongTensor(tok_type_lists).to(self._device)
            bert_output = self.bert_model(tokens_tensor,
                attention_mask=att_masks_tensor, token_type_ids=tok_type_tensor)[0]
        else:
            bert_output = self.bert_model(tokens_tensor,
                attention_mask=att_masks_tensor)[0]

        enc_output = bert_output

        # Each column/table is summarized to a single vector, so pointer
        # maps are the identity.
        column_pointer_maps = [
            {
                i: [i]
                for i in range(len(desc['columns']))
            }
            for desc in descs
        ]
        table_pointer_maps = [
            {
                i: [i]
                for i in range(len(desc['tables']))
            }
            for desc in descs
        ]

        # The long-sequence fallback (encoder_long_seq) is deprecated and
        # disabled: over-long examples must be filtered by validate_item.
        assert len(long_seq_set) == 0

        # ---- Phase 3: per-example gather + relation-aware update.
        result = []
        for batch_idx, desc in enumerate(descs):
            c_boundary = list(range(len(desc["columns"]) + 1))
            t_boundary = list(range(len(desc["tables"]) + 1))

            if batch_idx in long_seq_set:
                q_enc, col_enc, tab_enc = self.encoder_long_seq(desc)
            else:
                bert_batch_idx = batch_id_map[batch_idx]
                q_enc = enc_output[bert_batch_idx][batch_id_to_retrieve_question[bert_batch_idx]]
                col_enc = enc_output[bert_batch_idx][batch_id_to_retrieve_column[bert_batch_idx]]
                tab_enc = enc_output[bert_batch_idx][batch_id_to_retrieve_table[bert_batch_idx]]

                # Fixed: moved inside the else-branch — bert_batch_idx (and
                # the *_2 retrieval lists) only exist for non-long sequences.
                if self.summarize_header == "avg":
                    col_enc_2 = enc_output[bert_batch_idx][batch_id_to_retrieve_column_2[bert_batch_idx]]
                    tab_enc_2 = enc_output[bert_batch_idx][batch_id_to_retrieve_table_2[bert_batch_idx]]

                    col_enc = (col_enc + col_enc_2) / 2.0
                    tab_enc = (tab_enc + tab_enc_2) / 2.0

            assert q_enc.size()[0] == len(desc["question"])
            assert col_enc.size()[0] == c_boundary[-1]
            assert tab_enc.size()[0] == t_boundary[-1]

            q_enc_new_item, c_enc_new_item, t_enc_new_item, align_mat_item = \
                self.encs_update.forward_unbatched(
                    desc,
                    q_enc.unsqueeze(1),
                    col_enc.unsqueeze(1),
                    c_boundary,
                    tab_enc.unsqueeze(1),
                    t_boundary)
            # (Removed: leftover debug code that pickled every example's
            # encoder inputs to descs_{batch_idx}.pkl on each forward pass.)

            memory = []
            if 'question' in self.include_in_memory:
                memory.append(q_enc_new_item)
            if 'column' in self.include_in_memory:
                memory.append(c_enc_new_item)
            if 'table' in self.include_in_memory:
                memory.append(t_enc_new_item)
            memory = torch.cat(memory, dim=1)

            result.append(SpiderEncoderState(
                state=None,
                memory=memory,
                question_memory=q_enc_new_item,
                schema_memory=torch.cat((c_enc_new_item, t_enc_new_item), dim=1),
                words=desc['question'],
                pointer_memories={
                    'column': c_enc_new_item,
                    'table': t_enc_new_item,
                },
                pointer_maps={
                    'column': column_pointer_maps[batch_idx],
                    'table': table_pointer_maps[batch_idx],
                },
                m2c_align_mat=align_mat_item[0],
                m2t_align_mat=align_mat_item[1],
            ))
        return result

    def encoder_long_seq(self, desc):
        """
        Since bert cannot handle sequence longer than 512, each column/table is encoded individually
        The representation of a column/table is the vector of the first token [CLS]

        Deprecated. (Fixed: the original used @DeprecationWarning as a
        decorator, which replaces the method with an exception instance and
        makes it uncallable; emit a runtime warning instead.)
        """
        import warnings
        warnings.warn("encoder_long_seq is deprecated", DeprecationWarning)
        qs = self.pad_single_sentence_for_bert(desc['question'], cls=True)
        cols = [self.pad_single_sentence_for_bert(c, cls=True) for c in desc['columns']]
        tabs = [self.pad_single_sentence_for_bert(t, cls=True) for t in desc['tables']]

        enc_q = self._bert_encode(qs)
        enc_col = self._bert_encode(cols)
        enc_tab = self._bert_encode(tabs)
        return enc_q, enc_col, enc_tab

    def _bert_encode(self, toks):
        """Deprecated helper for encoder_long_seq (same decorator fix)."""
        import warnings
        warnings.warn("_bert_encode is deprecated", DeprecationWarning)
        if not isinstance(toks[0], list):  # single sentence
            indexed_tokens = self.tokenizer.convert_tokens_to_ids(toks)
            tokens_tensor = torch.tensor([indexed_tokens]).to(self._device)
            outputs = self.bert_model(tokens_tensor)
            return outputs[0][0, 1:-1]  # remove [CLS] and [SEP]
        else:
            max_len = max([len(it) for it in toks])
            tok_ids = []
            for item_toks in toks:
                item_toks = item_toks + [self.tokenizer.pad_token] * (max_len - len(item_toks))
                indexed_tokens = self.tokenizer.convert_tokens_to_ids(item_toks)
                tok_ids.append(indexed_tokens)

            tokens_tensor = torch.tensor(tok_ids).to(self._device)
            outputs = self.bert_model(tokens_tensor)
            return outputs[0][:, 0, :]  # [CLS] vector of each sentence

    def check_bert_seq(self, toks):
        """A well-formed sequence starts with [CLS] and ends with [SEP]."""
        if toks[0] == self.tokenizer.cls_token and toks[-1] == self.tokenizer.sep_token:
            return True
        else:
            return False

    def pad_single_sentence_for_bert(self, toks, cls=True):
        """Wrap tokens as '[CLS] toks [SEP]' (cls=True) or 'toks [SEP]'."""
        if cls:
            return [self.tokenizer.cls_token] + toks + [self.tokenizer.sep_token]
        else:
            return toks + [self.tokenizer.sep_token]

    def pad_sequence_for_bert_batch(self, tokens_lists):
        """Pad id lists to a common length; return (ids, attention masks,
        token-type ids) where type 0 covers up to and including the first
        [SEP] (the question) and type 1 covers the rest."""
        pad_id = self.tokenizer.pad_token_id
        max_len = max([len(it) for it in tokens_lists])
        assert max_len <= 512
        toks_ids = []
        att_masks = []
        tok_type_lists = []
        for item_toks in tokens_lists:
            padded_item_toks = item_toks + [pad_id] * (max_len - len(item_toks))
            toks_ids.append(padded_item_toks)

            _att_mask = [1] * len(item_toks) + [0] * (max_len - len(item_toks))
            att_masks.append(_att_mask)

            first_sep_id = padded_item_toks.index(self.tokenizer.sep_token_id)
            assert first_sep_id > 0
            _tok_type_list = [0] * (first_sep_id + 1) + [1] * (max_len - first_sep_id - 1)
            tok_type_lists.append(_tok_type_list)
        return toks_ids, att_masks, tok_type_lists
|
|
|
|
| """ |
| ############################### |
| BART models |
| ############################### |
| """ |
|
|
class BartTokens:
    """A string viewed through two tokenizations, with an index map between them.

    Holds (a) the BART subword pieces implied by the text, (b) a
    CoreNLP-lemmatized word list (``normalized_pieces``) used for schema /
    cell-value linking, and (c) ``idx_map`` from each NLTK word index to the
    index of that word's first BART piece, so link positions computed over
    words can be re-expressed over BART pieces.
    """

    def __init__(self, text, tokenizer):
        self.text = text
        self.tokenizer = tokenizer
        self.normalized_pieces = None
        self.idx_map = None

        self.normalize_toks()

    def normalize_toks(self):
        # Space out quote characters so NLTK keeps them as standalone words.
        words = nltk.word_tokenize(
            self.text.replace("'", " ' ").replace('"', ' " '))

        # idx_map: word index -> index of that word's first BART piece.
        self.idx_map = {}
        pieces = []
        for word_idx, word in enumerate(words):
            self.idx_map[word_idx] = len(pieces)
            pieces.extend(self.tokenizer.tokenize(word, add_prefix_space=True))

        # Lemmatize every word with CoreNLP; a single word may produce
        # several lemmas, which are joined back into one string.
        lemmatized = []
        for word in words:
            ann = corenlp.annotate(
                word, annotators=["tokenize", "ssplit", "lemma"])
            lemmas = (token.lemma.lower()
                      for sent in ann.sentence
                      for token in sent.token)
            lemmatized.append(" ".join(lemmas))
        self.normalized_pieces = lemmatized

    def _remap_match_key(self, ij_key):
        # Keys look like "<word_idx>,<col_or_tab_idx>"; rewrite the word
        # index into BART-piece coordinates via idx_map.
        word_idx_str, other_idx_str = ij_key.split(",")
        piece_idx = self.idx_map[int(word_idx_str)]
        return f"{piece_idx},{int(other_idx_str)}"

    def bart_schema_linking(self, columns, tables):
        """Question<->schema-name matches, re-keyed to BART-piece indices."""
        sc_link = compute_schema_linking(
            self.normalized_pieces,
            [col.normalized_pieces for col in columns],
            [tab.normalized_pieces for tab in tables])

        return {
            match_type: {self._remap_match_key(key): label
                         for key, label in matches.items()}
            for match_type, matches in sc_link.items()
        }

    def bart_cv_linking(self, schema, db_path):
        """Question<->cell-value matches, re-keyed to BART-piece indices."""
        cv_link = compute_cell_value_linking(
            self.normalized_pieces, schema, db_path)

        remapped = {}
        for match_type, matches in cv_link.items():
            if match_type == "normalized_token":
                # Pass-through entry; not keyed by question position.
                remapped[match_type] = matches
            else:
                remapped[match_type] = {
                    self._remap_match_key(key): label
                    for key, label in matches.items()}
        return remapped
|
|
|
|
|
|
|
|
def preprocess_schema_uncached_bart(schema,
                                    tokenizer,
                                    tokenize_func,
                                    include_table_name_in_column,
                                    fix_issue_16_primary_keys,
                                    bart=False):
    """Build a PreprocessedSchema for one database (BART variant).

    When ``bart`` is True, also caches BartTokens-normalized column/table
    names (consumed later by schema linking) and appends the '<type: ...>'
    token *after* the column name instead of before it.
    """
    r = PreprocessedSchema()

    # The BART linking path does not support folding the table name into
    # the column tokens.
    if bart: assert not include_table_name_in_column

    last_table_id = None
    for i, column in enumerate(schema.columns):
        col_toks = tokenize_func(
            column.name, column.unsplit_name)

        # e.g. '<type: text>'; these markers are added to the tokenizer
        # vocabulary by the preprocessor.
        type_tok = '<type: {}>'.format(column.type)
        if bart:
            # For BART the type token follows the column name.
            column_name = col_toks + [type_tok]
            r.normalized_column_names.append(BartTokens(column.unsplit_name, tokenizer))
        else:
            column_name = [type_tok] + col_toks

        if include_table_name_in_column:
            if column.table is None:
                table_name = ['<any-table>']
            else:
                table_name = tokenize_func(
                    column.table.name, column.table.unsplit_name)
            column_name += ['<table-sep>'] + table_name
        r.column_names.append(column_name)

        table_id = None if column.table is None else column.table.id
        r.column_to_table[str(i)] = table_id
        if table_id is not None:
            columns = r.table_to_columns.setdefault(str(table_id), [])
            columns.append(i)
        # Columns are assumed to arrive grouped by table, so a change of
        # table id marks the start of a new table's column range.
        if last_table_id != table_id:
            r.table_bounds.append(i)
            last_table_id = table_id

        if column.foreign_key_for is not None:
            r.foreign_keys[str(column.id)] = column.foreign_key_for.id
            r.foreign_keys_tables[str(column.table.id)].add(column.foreign_key_for.table.id)

    # Close the final table's range; bounds delimit len(tables) ranges.
    r.table_bounds.append(len(schema.columns))
    assert len(r.table_bounds) == len(schema.tables) + 1

    for i, table in enumerate(schema.tables):
        table_toks = tokenize_func(
            table.name, table.unsplit_name)
        r.table_names.append(table_toks)
        if bart:
            r.normalized_table_names.append(BartTokens(table.unsplit_name, tokenizer))
    # Only used by the legacy primary-key branch below.
    last_table = schema.tables[-1]

    r.foreign_keys_tables = serialization.to_dict_with_sorted_values(r.foreign_keys_tables)
    # When fix_issue_16_primary_keys is False, the else-branch deliberately
    # reproduces the original (buggy) behavior — only the *last* table's
    # primary keys, each repeated once per table — kept for backward
    # compatibility with previously preprocessed datasets.
    r.primary_keys = [
        column.id
        for table in schema.tables
        for column in table.primary_keys
    ] if fix_issue_16_primary_keys else [
        column.id
        for column in last_table.primary_keys
        for table in schema.tables
    ]

    return r
|
|
| import nltk |
class SpiderEncoderBartPreproc(SpiderEncoderV2Preproc):
    """Dataset preprocessor for the BART-based Spider encoder.

    Tokenizes questions and schemas with a BartTokenizer (word-split with
    NLTK first, then BPE), adds '<type: ...>' tokens for column types, and
    precomputes schema linking and (optionally) cell-value linking.
    """

    def __init__(
            self,
            save_path,
            db_path,
            fix_issue_16_primary_keys=False,
            include_table_name_in_column=False,
            bart_version = "bart-large",
            compute_sc_link=True,
            compute_cv_link=False):
        # Preprocessed jsonl files and the tokenizer are saved here.
        self.data_dir = os.path.join(save_path, 'enc')
        self.db_path = db_path
        self.texts = collections.defaultdict(list)
        self.fix_issue_16_primary_keys = fix_issue_16_primary_keys
        self.include_table_name_in_column = include_table_name_in_column
        self.compute_sc_link = compute_sc_link
        self.compute_cv_link = compute_cv_link

        self.counted_db_ids = set()
        # Cache: db_id -> PreprocessedSchema (schemas are shared across items).
        self.preprocessed_schemas = {}

        # NOTE(review): SpiderEncoderBart defaults to "facebook/bart-large";
        # newer transformers releases may require the namespaced id here too
        # — confirm against the installed transformers version.
        self.tokenizer = BartTokenizer.from_pretrained(bart_version)

        # Column-type marker tokens used when serializing columns.
        column_types = ["text", "number", "time", "boolean", "others"]
        self.tokenizer.add_tokens([f"<type: {t}>" for t in column_types])

    def _tokenize(self, presplit, unsplit):
        # ``presplit`` is intentionally ignored: tokenization starts from the
        # raw string.  Quotes are spaced out so NLTK keeps them as separate
        # words before BPE.
        tokens = nltk.word_tokenize(unsplit.replace("'", " ' ").replace('"', ' " '))
        toks = []
        for token in tokens:
            toks.extend(self.tokenizer.tokenize(token, add_prefix_space=True))
        return toks

    def add_item(self, item, section, validation_info):
        preprocessed = self.preprocess_item(item, validation_info)
        self.texts[section].append(preprocessed)

    def preprocess_item(self, item, validation_info):
        # Tokenize the question and the (memoized per db_id) schema.
        question = self._tokenize(item.text, item.orig['question'])
        preproc_schema = self._preprocess_schema(item.schema)
        question_bart_tokens = BartTokens(item.orig['question'], self.tokenizer)
        if self.compute_sc_link:
            # Question <-> column/table name matches, in BART-piece indices.
            sc_link = question_bart_tokens.bart_schema_linking(
                preproc_schema.normalized_column_names,
                preproc_schema.normalized_table_names
            )
        else:
            sc_link = {"q_col_match": {}, "q_tab_match": {}}

        if self.compute_cv_link:
            # Question <-> database cell-value matches (reads the DB files).
            cv_link = question_bart_tokens.bart_cv_linking(
                item.schema, self.db_path)
        else:
            cv_link = {"num_date_match": {}, "cell_match": {}}

        return {
            'raw_question': item.orig['question'],
            'question': question,
            'db_id': item.schema.db_id,
            'sc_link': sc_link,
            'cv_link': cv_link,
            'columns': preproc_schema.column_names,
            'tables': preproc_schema.table_names,
            'table_bounds': preproc_schema.table_bounds,
            'column_to_table': preproc_schema.column_to_table,
            'table_to_columns': preproc_schema.table_to_columns,
            'foreign_keys': preproc_schema.foreign_keys,
            'foreign_keys_tables': preproc_schema.foreign_keys_tables,
            'primary_keys': preproc_schema.primary_keys,
        }

    def validate_item(self, item, section):
        question = self._tokenize(item.text, item.orig['question'])
        preproc_schema = self._preprocess_schema(item.schema)
        # Rough length estimate (question + CLS/SEP + one separator per
        # column/table); items that cannot fit in the 512-position window
        # are dropped so forward() never sees an over-length input.
        num_words = len(question) + 2 + \
            sum(len(c) + 1 for c in preproc_schema.column_names) + \
            sum(len(t) + 1 for t in preproc_schema.table_names)
        if num_words > 512:
            return False, None
        else:
            return True, None

    def _preprocess_schema(self, schema):
        # Memoized per database id.
        if schema.db_id in self.preprocessed_schemas:
            return self.preprocessed_schemas[schema.db_id]
        result = preprocess_schema_uncached_bart(schema, self.tokenizer, self._tokenize,
                                                 self.include_table_name_in_column,
                                                 self.fix_issue_16_primary_keys, bart=True)
        self.preprocessed_schemas[schema.db_id] = result
        return result

    def save(self):
        os.makedirs(self.data_dir, exist_ok=True)
        self.tokenizer.save_pretrained(self.data_dir)

        for section, texts in self.texts.items():
            with open(os.path.join(self.data_dir, section + '.jsonl'), 'w') as f:
                for text in texts:
                    f.write(json.dumps(text) + '\n')

    def load(self):
        # save_pretrained() stored the added '<type: ...>' tokens, so the
        # reloaded tokenizer carries the same (extended) vocabulary.
        self.tokenizer = BartTokenizer.from_pretrained(self.data_dir)
|
|
|
|
| @registry.register('encoder', 'spider-bart') |
| class SpiderEncoderBart(torch.nn.Module): |
| Preproc = SpiderEncoderBartPreproc |
| batched = True |
|
|
def __init__(
        self,
        device,
        preproc,
        update_config={},  # NOTE(review): mutable default; only read here, but fragile
        bart_version="facebook/bart-large",
        summarize_header="first",
        use_column_type=True,
        include_in_memory=('question', 'column', 'table')):
    """Build the BART-based Spider encoder.

    Loads a pretrained BartModel, overwrites its encoder weights from a
    local checkpoint, keeps only the encoder half, and instantiates the
    relation-aware update module configured by ``update_config['name']``.
    """
    super().__init__()
    self._device = device
    self.preproc = preproc
    # Hidden size of bart-large.
    self.base_enc_hidden_size = 1024

    # "first": a schema item is summarized by its first position;
    # "avg": averaged with a second retrieved position (see forward()).
    assert summarize_header in ["first", "avg"]
    self.summarize_header = summarize_header
    self.enc_hidden_size = self.base_enc_hidden_size
    self.use_column_type = use_column_type

    self.include_in_memory = set(include_in_memory)
    update_modules = {
        'relational_transformer':
            spider_enc_modules.RelationalTransformerUpdate,
        'none':
            spider_enc_modules.NoOpUpdate,
    }

    # update_config['name'] selects the module; remaining keys are its config.
    self.encs_update = registry.instantiate(
        update_modules[update_config['name']],
        update_config,
        unused_keys={"name"},
        device=self._device,
        hidden_size=self.enc_hidden_size,
        sc_link=True,
    )

    self.bert_model = BartModel.from_pretrained(bart_version)
    # Debug print: encoder weights before loading the local checkpoint.
    print(next(self.bert_model.encoder.parameters()))

    def replace_model_with_pretrained(model, path, prefix):
        # Load only the state-dict entries under ``prefix`` (prefix
        # stripped) into ``model``.
        restore_state_dict = torch.load(
            path, map_location=lambda storage, location: storage)
        keep_keys = []
        for key in restore_state_dict.keys():
            if key.startswith(prefix):
                keep_keys.append(key)
        # NOTE(review): str.replace removes *every* occurrence of the
        # prefix, not just the leading one — fine as long as the prefix
        # never recurs inside a key; confirm against the checkpoint.
        loaded_dict = {k.replace(prefix, ""): restore_state_dict[k] for k in keep_keys}
        model.load_state_dict(loaded_dict)
        print("Updated the model with {}".format(path))

    self.tokenizer = self.preproc.tokenizer
    # NOTE(review): 50266 is presumably the vocab size the checkpoint was
    # trained with (BART's 50265 plus one) — confirm against the checkpoint.
    self.bert_model.resize_token_embeddings(50266)

    # Hard-coded checkpoint location; torch.load raises if it is missing.
    replace_model_with_pretrained(self.bert_model.encoder, os.path.join(
        "./pretrained_checkpoint",
        "pytorch_model.bin"), "bert.model.encoder.")
    # Resize again to match the tokenizer actually in use (it includes the
    # added '<type: ...>' tokens).
    self.bert_model.resize_token_embeddings(len(self.tokenizer))
    # Keep only the encoder half of BART.
    self.bert_model = self.bert_model.encoder
    self.bert_model.decoder = None

    # Debug print: weights after loading, for a visual sanity check.
    print(next(self.bert_model.parameters()))
|
|
def forward(self, descs):
    """Encode a batch of preprocessed items into SpiderEncoderState objects.

    Each item's question/columns/tables are concatenated into a single
    BART input; per-item index lists then pull back the question-token
    vectors and one summary vector per column/table, which the
    relation-aware update module refines.

    Fix over the original: a leftover debug ``pickle.dump`` wrote a
    ``descs_<i>.pkl`` file (via an unclosed file handle) for every item on
    every forward pass; it has been removed.
    """
    batch_token_lists = []
    batch_id_to_retrieve_question = []
    batch_id_to_retrieve_column = []
    batch_id_to_retrieve_table = []
    if self.summarize_header == "avg":
        batch_id_to_retrieve_column_2 = []
        batch_id_to_retrieve_table_2 = []
    long_seq_set = set()
    batch_id_map = {}  # batch_idx -> row in the (filtered) encoder batch
    for batch_idx, desc in enumerate(descs):
        qs = self.pad_single_sentence_for_bert(desc['question'], cls=True)
        if self.use_column_type:
            cols = [self.pad_single_sentence_for_bert(c, cls=False) for c in desc['columns']]
        else:
            # Strip the trailing '<type: ...>' token from each column.
            cols = [self.pad_single_sentence_for_bert(c[:-1], cls=False) for c in desc['columns']]
        tabs = [self.pad_single_sentence_for_bert(t, cls=False) for t in desc['tables']]

        token_list = qs + [c for col in cols for c in col] + \
                     [t for tab in tabs for t in tab]
        assert self.check_bert_seq(token_list)
        if len(token_list) > 512:
            # Should be prevented upstream by validate_item (see assert
            # below); tracked here for completeness.
            long_seq_set.add(batch_idx)
            continue

        q_b = len(qs)
        col_b = q_b + sum(len(c) for c in cols)
        # Question token positions, excluding CLS and SEP.
        question_indexes = list(range(q_b))[1:-1]
        # First-token position of every column / table segment.
        column_indexes = \
            np.cumsum([q_b] + [len(seg) for seg in cols[:-1]]).tolist()
        table_indexes = \
            np.cumsum([col_b] + [len(seg) for seg in tabs[:-1]]).tolist()
        if self.summarize_header == "avg":
            # Second retrieval point per segment (its last position), used
            # to average with the first one below.
            column_indexes_2 = \
                np.cumsum([q_b - 2] + [len(seg) for seg in cols]).tolist()[1:]
            table_indexes_2 = \
                np.cumsum([col_b - 2] + [len(seg) for seg in tabs]).tolist()[1:]

        indexed_token_list = self.tokenizer.convert_tokens_to_ids(token_list)
        batch_token_lists.append(indexed_token_list)

        question_rep_ids = torch.LongTensor(question_indexes).to(self._device)
        batch_id_to_retrieve_question.append(question_rep_ids)
        column_rep_ids = torch.LongTensor(column_indexes).to(self._device)
        batch_id_to_retrieve_column.append(column_rep_ids)
        table_rep_ids = torch.LongTensor(table_indexes).to(self._device)
        batch_id_to_retrieve_table.append(table_rep_ids)
        if self.summarize_header == "avg":
            assert (all(i2 >= i1 for i1, i2 in zip(column_indexes, column_indexes_2)))
            column_rep_ids_2 = torch.LongTensor(column_indexes_2).to(self._device)
            batch_id_to_retrieve_column_2.append(column_rep_ids_2)
            assert (all(i2 >= i1 for i1, i2 in zip(table_indexes, table_indexes_2)))
            table_rep_ids_2 = torch.LongTensor(table_indexes_2).to(self._device)
            batch_id_to_retrieve_table_2.append(table_rep_ids_2)

        batch_id_map[batch_idx] = len(batch_id_map)

    padded_token_lists, att_mask_lists, tok_type_lists = self.pad_sequence_for_bert_batch(batch_token_lists)
    tokens_tensor = torch.LongTensor(padded_token_lists).to(self._device)
    att_masks_tensor = torch.LongTensor(att_mask_lists).to(self._device)

    # BART takes no token_type_ids; only input ids + attention mask.
    enc_output = self.bert_model(tokens_tensor,
                                 attention_mask=att_masks_tensor)[0]

    # Identity pointer maps: each schema item maps to exactly one
    # encoder position here.
    column_pointer_maps = [
        {
            i: [i]
            for i in range(len(desc['columns']))
        }
        for desc in descs
    ]
    table_pointer_maps = [
        {
            i: [i]
            for i in range(len(desc['tables']))
        }
        for desc in descs
    ]

    # Over-length inputs are filtered out during preprocessing
    # (validate_item), so none should reach this point.
    assert len(long_seq_set) == 0

    result = []
    for batch_idx, desc in enumerate(descs):
        c_boundary = list(range(len(desc["columns"]) + 1))
        t_boundary = list(range(len(desc["tables"]) + 1))

        if batch_idx in long_seq_set:
            q_enc, col_enc, tab_enc = self.encoder_long_seq(desc)
        else:
            bert_batch_idx = batch_id_map[batch_idx]
            q_enc = enc_output[bert_batch_idx][batch_id_to_retrieve_question[bert_batch_idx]]
            col_enc = enc_output[bert_batch_idx][batch_id_to_retrieve_column[bert_batch_idx]]
            tab_enc = enc_output[bert_batch_idx][batch_id_to_retrieve_table[bert_batch_idx]]

            if self.summarize_header == "avg":
                col_enc_2 = enc_output[bert_batch_idx][batch_id_to_retrieve_column_2[bert_batch_idx]]
                tab_enc_2 = enc_output[bert_batch_idx][batch_id_to_retrieve_table_2[bert_batch_idx]]
                # Summary = mean of the segment's first and last positions.
                col_enc = (col_enc + col_enc_2) / 2.0
                tab_enc = (tab_enc + tab_enc_2) / 2.0

        assert q_enc.size()[0] == len(desc["question"])
        assert col_enc.size()[0] == c_boundary[-1]
        assert tab_enc.size()[0] == t_boundary[-1]

        q_enc_new_item, c_enc_new_item, t_enc_new_item, align_mat_item = \
            self.encs_update.forward_unbatched(
                desc,
                q_enc.unsqueeze(1),
                col_enc.unsqueeze(1),
                c_boundary,
                tab_enc.unsqueeze(1),
                t_boundary)

        memory = []
        if 'question' in self.include_in_memory:
            memory.append(q_enc_new_item)
        if 'column' in self.include_in_memory:
            memory.append(c_enc_new_item)
        if 'table' in self.include_in_memory:
            memory.append(t_enc_new_item)
        memory = torch.cat(memory, dim=1)

        result.append(SpiderEncoderState(
            state=None,
            memory=memory,
            question_memory=q_enc_new_item,
            schema_memory=torch.cat((c_enc_new_item, t_enc_new_item), dim=1),
            words=desc['question'],
            pointer_memories={
                'column': c_enc_new_item,
                'table': t_enc_new_item,
            },
            pointer_maps={
                'column': column_pointer_maps[batch_idx],
                'table': table_pointer_maps[batch_idx],
            },
            m2c_align_mat=align_mat_item[0],
            m2t_align_mat=align_mat_item[1],
        ))
    return result
|
|
def encoder_long_seq(self, desc):
    """Deprecated: encode question/columns/tables one sequence at a time.

    Since the model cannot handle sequences longer than 512 positions,
    each column/table is encoded individually; its representation is the
    vector of the first token (see _bert_encode).

    Fix: this method used to be decorated with ``@DeprecationWarning``,
    which *replaced* it with an exception instance — the call site in
    forward() would have raised TypeError.  The decorator is removed and a
    real runtime warning is emitted instead.
    """
    import warnings
    warnings.warn("encoder_long_seq is deprecated", DeprecationWarning, stacklevel=2)
    qs = self.pad_single_sentence_for_bert(desc['question'], cls=True)
    cols = [self.pad_single_sentence_for_bert(c, cls=True) for c in desc['columns']]
    tabs = [self.pad_single_sentence_for_bert(t, cls=True) for t in desc['tables']]

    enc_q = self._bert_encode(qs)
    enc_col = self._bert_encode(cols)
    enc_tab = self._bert_encode(tabs)
    return enc_q, enc_col, enc_tab
|
|
| @DeprecationWarning |
| def _bert_encode(self, toks): |
| if not isinstance(toks[0], list): |
| indexed_tokens = self.tokenizer.convert_tokens_to_ids(toks) |
| tokens_tensor = torch.tensor([indexed_tokens]).to(self._device) |
| outputs = self.bert_model(tokens_tensor) |
| return outputs[0][0, 1:-1] |
| else: |
| max_len = max([len(it) for it in toks]) |
| tok_ids = [] |
| for item_toks in toks: |
| item_toks = item_toks + [self.tokenizer.pad_token] * (max_len - len(item_toks)) |
| indexed_tokens = self.tokenizer.convert_tokens_to_ids(item_toks) |
| tok_ids.append(indexed_tokens) |
|
|
| tokens_tensor = torch.tensor(tok_ids).to(self._device) |
| outputs = self.bert_model(tokens_tensor) |
| return outputs[0][:, 0, :] |
|
|
def check_bert_seq(self, toks):
    """True iff the sequence starts with the CLS token and ends with SEP."""
    return (toks[0] == self.tokenizer.cls_token
            and toks[-1] == self.tokenizer.sep_token)
|
|
def pad_single_sentence_for_bert(self, toks, cls=True):
    """Append the SEP token, optionally prepending CLS."""
    prefix = [self.tokenizer.cls_token] if cls else []
    return prefix + toks + [self.tokenizer.sep_token]
|
|
def pad_sequence_for_bert_batch(self, tokens_lists):
    """Right-pad each id sequence to the batch maximum.

    Returns (token_ids, attention_masks, token_type_ids); token types are
    0 through the first SEP (inclusive) and 1 afterwards.
    """
    pad_id = self.tokenizer.pad_token_id
    sep_id = self.tokenizer.sep_token_id
    target_len = max(len(seq) for seq in tokens_lists)
    assert target_len <= 512

    toks_ids, att_masks, tok_type_lists = [], [], []
    for seq in tokens_lists:
        n_pad = target_len - len(seq)
        padded = seq + [pad_id] * n_pad

        first_sep = padded.index(sep_id)
        assert first_sep > 0

        toks_ids.append(padded)
        att_masks.append([1] * len(seq) + [0] * n_pad)
        tok_type_lists.append(
            [0] * (first_sep + 1) + [1] * (target_len - first_sep - 1))
    return toks_ids, att_masks, tok_type_lists