import itertools
import operator

import numpy as np
import torch
from torch import nn
import torchtext

from seq2struct.models import variational_lstm
from seq2struct.models import transformer
from seq2struct.utils import batched_sequence


def clamp(value, abs_max):
    """Clamp `value` to the closed interval [-abs_max, abs_max]."""
    value = max(-abs_max, value)
    value = min(abs_max, value)
    return value
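
# For example, clamp(1, 2) == 1, clamp(5, 2) == 2, and clamp(-5, 2) == -2.
# compute_relations below relies on this to bucket relative token distances
# into a fixed vocabulary of distance relations.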


def get_attn_mask(seq_lengths):
    # Given a list of sequence lengths, build a [batch, max_len, max_len]
    # 0/1 mask in which entry (b, i, j) is 1 iff positions i and j both fall
    # within the b-th sequence (see the worked example below).
    # int() so that max_length has type int instead of numpy.int64.
    max_length, batch_size = int(max(seq_lengths)), len(seq_lengths)
    attn_mask = torch.LongTensor(batch_size, max_length, max_length).fill_(0)
    for batch_idx, seq_length in enumerate(seq_lengths):
        attn_mask[batch_idx, :seq_length, :seq_length] = 1
    return attn_mask
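
# For seq_lengths = [3, 1, 2], get_attn_mask produces
# [[[1, 1, 1],    [[1, 0, 0],    [[1, 1, 0],
#   [1, 1, 1],     [0, 0, 0],     [1, 1, 0],
#   [1, 1, 1]],    [0, 0, 0]],    [0, 0, 0]]]
# i.e. one max_length x max_length block per batch item, with the top-left
# seq_length x seq_length corner set to 1.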


class LookupEmbeddings(torch.nn.Module):
    def __init__(self, device, vocab, embedder, emb_size, learnable_words=()):
        super().__init__()
        self._device = device
        self.vocab = vocab
        self.embedder = embedder
        self.emb_size = emb_size

        self.embedding = torch.nn.Embedding(
            num_embeddings=len(self.vocab),
            embedding_dim=emb_size)
        if self.embedder:
            assert emb_size == self.embedder.dim

        # Initialize the embedding matrix: use the pretrained embedder's
        # vector for every word it covers, and keep the randomly initialized
        # row for the rest (those rows remain learnable).
        self.learnable_words = learnable_words
        init_embed_list = []
        for i, word in enumerate(self.vocab):
            if self.embedder and self.embedder.contains(word):
                init_embed_list.append(self.embedder.lookup(word))
            else:
                init_embed_list.append(self.embedding.weight[i])
        init_embed_weight = torch.stack(init_embed_list, 0)
        self.embedding.weight = nn.Parameter(init_embed_weight)

    def forward_unbatched(self, token_lists):
        # token_lists: list of token lists, one per desc, for a single
        # (unbatched) item; each token is a string.

        embs = []
        for tokens in token_lists:
            # token_indices shape: batch (=1) x length
            token_indices = torch.tensor(
                self.vocab.indices(tokens), device=self._device).unsqueeze(0)

            # emb shape: batch (=1) x length x word_emb_size
            emb = self.embedding(token_indices)

            # emb shape: desc length x batch (=1) x word_emb_size
            emb = emb.transpose(0, 1)
            embs.append(emb)

        # all_embs shape: sum of desc lengths x batch (=1) x word_emb_size
        all_embs = torch.cat(embs, dim=0)

        # boundaries shape: num of descs + 1
        # If the desc lengths are [2, 3, 4], then boundaries is [0, 2, 5, 9].
        boundaries = np.cumsum([0] + [emb.shape[0] for emb in embs])

        return all_embs, boundaries

    def _compute_boundaries(self, token_lists):
        # token_lists: list of list of lists
        # [batch, num descs, desc length]
        # - each inner list contains tokens
        # - each token is a string
        boundaries = [
            np.cumsum([0] + [len(token_list) for token_list in token_lists_for_item])
            for token_lists_for_item in token_lists]

        return boundaries
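
    # For example (hypothetical input), token_lists =
    #   [[['what', 'is'], ['name']], [['id']]]
    # yields boundaries [array([0, 2, 3]), array([0, 1])]: one cumulative-sum
    # array per batch item, with one entry per desc boundary.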

    def _embed_token(self, token, batch_idx):
        # Words marked learnable, and words unknown to the pretrained
        # embedder, use the learned embedding matrix; everything else uses
        # the (effectively fixed) pretrained vector.
        if token in self.learnable_words or not self.embedder.contains(token):
            return self.embedding.weight[self.vocab.index(token)]
        else:
            emb = self.embedder.lookup(token)
            return emb.to(self._device)

    def forward(self, token_lists):
        # token_lists: list of list of lists
        # [batch, num descs, desc length]
        # - each inner list contains tokens
        # - each token is a string

        # all_embs: PackedSequencePlus with shape
        # [batch, sum of desc lengths, emb_size]
        all_embs = batched_sequence.PackedSequencePlus.from_lists(
            lists=[
                [
                    token
                    for token_list in token_lists_for_item
                    for token in token_list
                ]
                for token_lists_for_item in token_lists
            ],
            item_shape=(self.emb_size,),
            device=self._device,
            item_to_tensor=self._embed_token)
        all_embs = all_embs.apply(lambda d: d.to(self._device))

        return all_embs, self._compute_boundaries(token_lists)

    def _embed_words_learned(self, token_lists):
        # token_lists: list of list of lists
        # [batch, num descs, desc length]
        # - each inner list contains tokens
        # - each token is a string

        # indices: PackedSequencePlus with shape
        # [batch, sum of desc lengths, 1]
        indices = batched_sequence.PackedSequencePlus.from_lists(
            lists=[
                [
                    token
                    for token_list in token_lists_for_item
                    for token in token_list
                ]
                for token_lists_for_item in token_lists
            ],
            item_shape=(1,),
            tensor_type=torch.LongTensor,
            item_to_tensor=lambda token, batch_idx, out: out.fill_(self.vocab.index(token))
        )
        indices = indices.apply(lambda d: d.to(self._device))
        # all_embs: PackedSequencePlus with shape
        # [batch, sum of desc lengths, emb_size]
        all_embs = indices.apply(lambda x: self.embedding(x.squeeze(-1)))

        return all_embs, self._compute_boundaries(token_lists)
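
    # A minimal usage sketch (hypothetical vocab/embedder objects):
    #   emb = LookupEmbeddings(device, vocab, glove_embedder, emb_size=300)
    #   all_embs, boundaries = emb([[['what', 'is'], ['name']]])
    # all_embs packs one 300-dim vector per token; boundaries recovers the
    # per-desc spans within each batch item's flattened token sequence.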


class EmbLinear(torch.nn.Module):
    """Applies a learned linear projection to each packed embedding."""

    def __init__(self, input_size, output_size):
        super().__init__()
        self.linear = torch.nn.Linear(input_size, output_size)

    def forward(self, input_):
        all_embs, boundaries = input_
        all_embs = all_embs.apply(lambda d: self.linear(d))
        return all_embs, boundaries


class BiLSTM(torch.nn.Module):
    def __init__(self, input_size, output_size, dropout, summarize, use_native=False):
        # input_size: dimensionality of the input embeddings
        # output_size: dimensionality of the output; each LSTM direction
        #   produces output_size // 2 units
        # dropout: dropout rate
        # summarize:
        # - True: return one summary vector per desc
        #   (the final states of both directions, concatenated)
        # - False: return one output vector per input token
        super().__init__()

        if use_native:
            # torch.nn.LSTM applies dropout only between stacked layers
            # (a no-op for a single layer), so the native path applies
            # dropout to the inputs explicitly in forward.
            self.lstm = torch.nn.LSTM(
                input_size=input_size,
                hidden_size=output_size // 2,
                bidirectional=True,
                dropout=dropout)
            self.dropout = torch.nn.Dropout(dropout)
        else:
            self.lstm = variational_lstm.LSTM(
                input_size=input_size,
                hidden_size=output_size // 2,
                bidirectional=True,
                dropout=dropout)
        self.summarize = summarize
        self.use_native = use_native

    def forward_unbatched(self, input_):
        # all_embs shape: sum of desc lengths x batch (=1) x input_size
        all_embs, boundaries = input_

        new_boundaries = [0]
        outputs = []
        for left, right in zip(boundaries, boundaries[1:]):
            # state shape:
            # - h: num_layers (=1) * num_directions (=2) x batch (=1) x output_size / 2
            # - c: num_layers (=1) * num_directions (=2) x batch (=1) x output_size / 2
            # output shape: seq len x batch size x output_size
            if self.use_native:
                inp = self.dropout(all_embs[left:right])
                output, (h, c) = self.lstm(inp)
            else:
                output, (h, c) = self.lstm(all_embs[left:right])
            if self.summarize:
                seq_emb = torch.cat((h[0], h[1]), dim=-1).unsqueeze(0)
                new_boundaries.append(new_boundaries[-1] + 1)
            else:
                seq_emb = output
                new_boundaries.append(new_boundaries[-1] + output.shape[0])
            outputs.append(seq_emb)

        return torch.cat(outputs, dim=0), new_boundaries
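
    # With summarize=True and boundaries [0, 2, 5], the two descs (lengths 2
    # and 3) each collapse to a single vector, so the output has shape
    # [2, 1, output_size] and new_boundaries == [0, 1, 2]. With
    # summarize=False, per-token outputs are kept and the boundary values
    # are unchanged.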

    def forward(self, input_):
        # all_embs: PackedSequencePlus with shape [batch, sum of desc lengths, input_size]
        # boundaries: list of lists with shape [batch, num descs + 1]
        all_embs, boundaries = input_

        # Flatten the descs of every batch item into one list of
        # (batch_idx, desc_idx, length) triples, and remember where each
        # (batch, desc) pair landed in that flat list.
        desc_lengths = []
        batch_desc_to_flat_map = {}
        for batch_idx, boundaries_for_item in enumerate(boundaries):
            for desc_idx, (left, right) in enumerate(zip(boundaries_for_item, boundaries_for_item[1:])):
                desc_lengths.append((batch_idx, desc_idx, right - left))
                batch_desc_to_flat_map[batch_idx, desc_idx] = len(batch_desc_to_flat_map)

        # Rearrange the embeddings into a new PackedSequencePlus of shape
        # [batch * num descs, desc length, input_size], so each desc becomes
        # its own sequence.
        remapped_ps_indices = []
        def rearranged_all_embs_map_index(desc_lengths_idx, seq_idx):
            batch_idx, desc_idx, _ = desc_lengths[desc_lengths_idx]
            return batch_idx, boundaries[batch_idx][desc_idx] + seq_idx
        def rearranged_all_embs_gather_from_indices(indices):
            batch_indices, seq_indices = zip(*indices)
            remapped_ps_indices[:] = all_embs.raw_index(batch_indices, seq_indices)
            return all_embs.ps.data[torch.LongTensor(remapped_ps_indices)]
        rearranged_all_embs = batched_sequence.PackedSequencePlus.from_gather(
            lengths=[length for _, _, length in desc_lengths],
            map_index=rearranged_all_embs_map_index,
            gather_from_indices=rearranged_all_embs_gather_from_indices)
        rev_remapped_ps_indices = tuple(
            x[0] for x in sorted(
                enumerate(remapped_ps_indices), key=operator.itemgetter(1)))

        # output: PackedSequence, [batch * num descs, desc length, output_size]
        # state shape:
        # - h: [num_layers (=1) * num_directions (=2), batch * num descs, output_size / 2]
        # - c: [num_layers (=1) * num_directions (=2), batch * num descs, output_size / 2]
        if self.use_native:
            rearranged_all_embs = rearranged_all_embs.apply(self.dropout)
        output, (h, c) = self.lstm(rearranged_all_embs.ps)
        if self.summarize:
            # h shape: [batch * num descs, output_size]
            h = torch.cat((h[0], h[1]), dim=-1)

            # new_all_embs: PackedSequencePlus, [batch, num descs, output_size]
            new_all_embs = batched_sequence.PackedSequencePlus.from_gather(
                lengths=[len(boundaries_for_item) - 1 for boundaries_for_item in boundaries],
                map_index=lambda batch_idx, desc_idx: rearranged_all_embs.sort_to_orig[batch_desc_to_flat_map[batch_idx, desc_idx]],
                gather_from_indices=lambda indices: h[torch.LongTensor(indices)])

            new_boundaries = [
                list(range(len(boundaries_for_item)))
                for boundaries_for_item in boundaries
            ]
        else:
            new_all_embs = all_embs.apply(
                lambda _: output.data[torch.LongTensor(rev_remapped_ps_indices)])
            new_boundaries = boundaries

        return new_all_embs, new_boundaries
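
    # Example of the rearrangement (hypothetical sizes): with
    # boundaries = [[0, 2, 5], [0, 3]], desc_lengths becomes
    # [(0, 0, 2), (0, 1, 3), (1, 0, 3)], so the three descs are run through
    # the LSTM as a mini-batch of three sequences and then scattered back to
    # their originating batch items.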


class RelationalTransformerUpdate(torch.nn.Module):
    def __init__(self, device, num_layers, num_heads, hidden_size,
            ff_size=None,
            dropout=0.1,
            merge_types=False,
            tie_layers=False,
            qq_max_dist=2,
            cc_foreign_key=True,
            cc_table_match=True,
            cc_max_dist=2,
            ct_foreign_key=True,
            ct_table_match=True,
            tc_table_match=True,
            tc_foreign_key=True,
            tt_max_dist=2,
            tt_foreign_key=True,
            sc_link=False,
            cv_link=False,
            ):
        super().__init__()
        self._device = device
        self.num_heads = num_heads

        self.qq_max_dist = qq_max_dist
        self.cc_foreign_key = cc_foreign_key
        self.cc_table_match = cc_table_match
        self.cc_max_dist = cc_max_dist
        self.ct_foreign_key = ct_foreign_key
        self.ct_table_match = ct_table_match
        self.tc_table_match = tc_table_match
        self.tc_foreign_key = tc_foreign_key
        self.tt_max_dist = tt_max_dist
        self.tt_foreign_key = tt_foreign_key

        # Assign each relation type a distinct integer id. Names follow the
        # convention: first letter = source type, second letter = target type
        # (q = question token, c = column, t = table).
        self.relation_ids = {}
        def add_relation(name):
            self.relation_ids[name] = len(self.relation_ids)
        def add_rel_dist(name, max_dist):
            for i in range(-max_dist, max_dist + 1):
                add_relation((name, i))

        add_rel_dist('qq_dist', qq_max_dist)

        add_relation('qc_default')

        add_relation('qt_default')

        add_relation('cq_default')

        add_relation('cc_default')
        if cc_foreign_key:
            add_relation('cc_foreign_key_forward')
            add_relation('cc_foreign_key_backward')
        if cc_table_match:
            add_relation('cc_table_match')
        add_rel_dist('cc_dist', cc_max_dist)

        add_relation('ct_default')
        if ct_foreign_key:
            add_relation('ct_foreign_key')
        if ct_table_match:
            add_relation('ct_primary_key')
            add_relation('ct_table_match')
            add_relation('ct_any_table')

        add_relation('tq_default')

        add_relation('tc_default')
        if tc_table_match:
            add_relation('tc_primary_key')
            add_relation('tc_table_match')
            add_relation('tc_any_table')
        if tc_foreign_key:
            add_relation('tc_foreign_key')

        add_relation('tt_default')
        if tt_foreign_key:
            add_relation('tt_foreign_key_forward')
            add_relation('tt_foreign_key_backward')
            add_relation('tt_foreign_key_both')
        add_rel_dist('tt_dist', tt_max_dist)

        # Schema-linking relations: name-based exact matches (EM) and
        # partial matches (PM) between question tokens and columns/tables.
        if sc_link:
            add_relation('qcCEM')
            add_relation('cqCEM')
            add_relation('qtTEM')
            add_relation('tqTEM')
            add_relation('qcCPM')
            add_relation('cqCPM')
            add_relation('qtTPM')
            add_relation('tqTPM')

        # Cell-value linking relations: matches between question tokens and
        # column contents or column types.
        if cv_link:
            add_relation("qcNUMBER")
            add_relation("cqNUMBER")
            add_relation("qcTIME")
            add_relation("cqTIME")
            add_relation("qcCELLMATCH")
            add_relation("cqCELLMATCH")
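
        # With the defaults above (sc_link=False, cv_link=False), the
        # inventory includes, e.g., ('qq_dist', -2) ... ('qq_dist', 2),
        # 'qc_default', 'cc_foreign_key_forward', and ('tt_dist', 0), each
        # mapped to a unique integer id.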

        if merge_types:
            assert not cc_foreign_key
            assert not cc_table_match
            assert not ct_foreign_key
            assert not ct_table_match
            assert not tc_foreign_key
            assert not tc_table_match
            assert not tt_foreign_key

            assert cc_max_dist == qq_max_dist
            assert tt_max_dist == qq_max_dist

            # Collapse the type-specific default and linking relations into a
            # single shared 'xx_default' relation, and share one set of
            # distance relations across question/column/table pairs.
            add_relation('xx_default')
            self.relation_ids['qc_default'] = self.relation_ids['xx_default']
            self.relation_ids['qt_default'] = self.relation_ids['xx_default']
            self.relation_ids['cq_default'] = self.relation_ids['xx_default']
            self.relation_ids['cc_default'] = self.relation_ids['xx_default']
            self.relation_ids['ct_default'] = self.relation_ids['xx_default']
            self.relation_ids['tq_default'] = self.relation_ids['xx_default']
            self.relation_ids['tc_default'] = self.relation_ids['xx_default']
            self.relation_ids['tt_default'] = self.relation_ids['xx_default']

            if sc_link:
                self.relation_ids['qcCEM'] = self.relation_ids['xx_default']
                self.relation_ids['qcCPM'] = self.relation_ids['xx_default']
                self.relation_ids['qtTEM'] = self.relation_ids['xx_default']
                self.relation_ids['qtTPM'] = self.relation_ids['xx_default']
                self.relation_ids['cqCEM'] = self.relation_ids['xx_default']
                self.relation_ids['cqCPM'] = self.relation_ids['xx_default']
                self.relation_ids['tqTEM'] = self.relation_ids['xx_default']
                self.relation_ids['tqTPM'] = self.relation_ids['xx_default']
            if cv_link:
                self.relation_ids["qcNUMBER"] = self.relation_ids['xx_default']
                self.relation_ids["cqNUMBER"] = self.relation_ids['xx_default']
                self.relation_ids["qcTIME"] = self.relation_ids['xx_default']
                self.relation_ids["cqTIME"] = self.relation_ids['xx_default']
                self.relation_ids["qcCELLMATCH"] = self.relation_ids['xx_default']
                self.relation_ids["cqCELLMATCH"] = self.relation_ids['xx_default']

            for i in range(-qq_max_dist, qq_max_dist + 1):
                self.relation_ids['cc_dist', i] = self.relation_ids['qq_dist', i]
                self.relation_ids['tt_dist', i] = self.relation_ids['qq_dist', i]

        if ff_size is None:
            ff_size = hidden_size * 4
        self.encoder = transformer.Encoder(
            lambda: transformer.EncoderLayer(
                hidden_size,
                transformer.MultiHeadedAttentionWithRelations(
                    num_heads,
                    hidden_size,
                    dropout),
                transformer.PositionwiseFeedForward(
                    hidden_size,
                    ff_size,
                    dropout),
                len(self.relation_ids),
                dropout),
            hidden_size,
            num_layers,
            tie_layers)

        self.align_attn = transformer.PointerWithRelations(
            hidden_size, len(self.relation_ids), dropout)

    def create_align_mask(self, num_head, q_length, c_length, t_length):
        # mask shape: [num_head, all_length, all_length]. All heads but the
        # last may attend everywhere; the last head is restricted to
        # question/column position pairs.
        all_length = q_length + c_length + t_length
        mask_1 = torch.ones(num_head - 1, all_length, all_length, device=self._device)
        mask_2 = torch.zeros(1, all_length, all_length, device=self._device)
        for i in range(q_length):
            for j in range(q_length, q_length + c_length):
                mask_2[0, i, j] = 1
                mask_2[0, j, i] = 1
        mask = torch.cat([mask_1, mask_2], 0)
        return mask
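
    # For example, create_align_mask(num_head=2, q_length=2, c_length=1,
    # t_length=1) returns a [2, 4, 4] tensor: head 0 is all ones, while
    # head 1 is one only at the (question, column) pairs (0, 2) and (1, 2)
    # and their transposes.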

    def forward_unbatched(self, desc, q_enc, c_enc, c_boundaries, t_enc, t_boundaries):
        # enc shape: total len x batch (=1) x recurrent size
        enc = torch.cat((q_enc, c_enc, t_enc), dim=0)

        # enc shape: batch (=1) x total len x recurrent size
        enc = enc.transpose(0, 1)

        relations = self.compute_relations(
            desc,
            enc_length=enc.shape[1],
            q_enc_length=q_enc.shape[0],
            c_enc_length=c_enc.shape[0],
            c_boundaries=c_boundaries,
            t_boundaries=t_boundaries)

        relations_t = torch.LongTensor(relations).to(self._device)
        enc_new = self.encoder(enc, relations_t, mask=None)

        # Split the updated encoding back into question, column, and table
        # segments.
        c_base = q_enc.shape[0]
        t_base = q_enc.shape[0] + c_enc.shape[0]
        q_enc_new = enc_new[:, :c_base]
        c_enc_new = enc_new[:, c_base:t_base]
        t_enc_new = enc_new[:, t_base:]

        # Memory-to-column and memory-to-table alignment matrices, computed
        # with relation-aware pointer attention.
        m2c_align_mat = self.align_attn(enc_new, enc_new[:, c_base:t_base],
                                        enc_new[:, c_base:t_base], relations_t[:, c_base:t_base])
        m2t_align_mat = self.align_attn(enc_new, enc_new[:, t_base:],
                                        enc_new[:, t_base:], relations_t[:, t_base:])
        return q_enc_new, c_enc_new, t_enc_new, (m2c_align_mat, m2t_align_mat)

    def forward(self, descs, q_enc, c_enc, c_boundaries, t_enc, t_boundaries):
        # enc: PackedSequencePlus with shape [batch, total len, recurrent size]
        enc = batched_sequence.PackedSequencePlus.cat_seqs((q_enc, c_enc, t_enc))

        q_enc_lengths = list(q_enc.orig_lengths())
        c_enc_lengths = list(c_enc.orig_lengths())
        t_enc_lengths = list(t_enc.orig_lengths())
        enc_lengths = list(enc.orig_lengths())
        max_enc_length = max(enc_lengths)

        # Build one [enc_length, enc_length] relation matrix per batch item,
        # zero-padded up to the longest item in the batch.
        all_relations = []
        for batch_idx, desc in enumerate(descs):
            enc_length = enc_lengths[batch_idx]
            relations_for_item = self.compute_relations(
                desc,
                enc_length,
                q_enc_lengths[batch_idx],
                c_enc_lengths[batch_idx],
                c_boundaries[batch_idx],
                t_boundaries[batch_idx])
            all_relations.append(np.pad(relations_for_item, ((0, max_enc_length - enc_length),), 'constant'))
        relations_t = torch.from_numpy(np.stack(all_relations)).to(self._device)

        # mask shape: [batch, total len, total len]
        mask = get_attn_mask(enc_lengths).to(self._device)
        # enc_new shape: [batch, total len, recurrent size]
        enc_padded, _ = enc.pad(batch_first=True)
        enc_new = self.encoder(enc_padded, relations_t, mask=mask)

        # Split the updated encoding back into question, column, and table
        # segments. (Unlike forward_unbatched, this path does not compute
        # alignment matrices.)
        def gather_from_enc_new(indices):
            batch_indices, seq_indices = zip(*indices)
            return enc_new[torch.LongTensor(batch_indices), torch.LongTensor(seq_indices)]

        q_enc_new = batched_sequence.PackedSequencePlus.from_gather(
            lengths=q_enc_lengths,
            map_index=lambda batch_idx, seq_idx: (batch_idx, seq_idx),
            gather_from_indices=gather_from_enc_new)
        c_enc_new = batched_sequence.PackedSequencePlus.from_gather(
            lengths=c_enc_lengths,
            map_index=lambda batch_idx, seq_idx: (batch_idx, q_enc_lengths[batch_idx] + seq_idx),
            gather_from_indices=gather_from_enc_new)
        t_enc_new = batched_sequence.PackedSequencePlus.from_gather(
            lengths=t_enc_lengths,
            map_index=lambda batch_idx, seq_idx: (batch_idx, q_enc_lengths[batch_idx] + c_enc_lengths[batch_idx] + seq_idx),
            gather_from_indices=gather_from_enc_new)
        return q_enc_new, c_enc_new, t_enc_new

    def compute_relations(self, desc, enc_length, q_enc_length, c_enc_length, c_boundaries, t_boundaries):
        sc_link = desc.get('sc_link', {'q_col_match': {}, 'q_tab_match': {}})
        cv_link = desc.get('cv_link', {'num_date_match': {}, 'cell_match': {}})

        # Catalogue which things are where: map each position in the
        # concatenated encoding to ('question',), ('column', c_id), or
        # ('table', t_id).
        loc_types = {}
        for i in range(q_enc_length):
            loc_types[i] = ('question',)

        c_base = q_enc_length
        for c_id, (c_start, c_end) in enumerate(zip(c_boundaries, c_boundaries[1:])):
            for i in range(c_start + c_base, c_end + c_base):
                loc_types[i] = ('column', c_id)
        t_base = q_enc_length + c_enc_length
        for t_id, (t_start, t_end) in enumerate(zip(t_boundaries, t_boundaries[1:])):
            for i in range(t_start + t_base, t_end + t_base):
                loc_types[i] = ('table', t_id)

        relations = np.empty((enc_length, enc_length), dtype=np.int64)

        for i, j in itertools.product(range(enc_length), repeat=2):
            def set_relation(name):
                relations[i, j] = self.relation_ids[name]

            i_type, j_type = loc_types[i], loc_types[j]
            if i_type[0] == 'question':
                if j_type[0] == 'question':
                    set_relation(('qq_dist', clamp(j - i, self.qq_max_dist)))
                elif j_type[0] == 'column':
                    # Prefer schema-linking matches, then cell-value matches,
                    # falling back to the default question-column relation.
                    j_real = j - c_base
                    if f"{i},{j_real}" in sc_link["q_col_match"]:
                        set_relation("qc" + sc_link["q_col_match"][f"{i},{j_real}"])
                    elif f"{i},{j_real}" in cv_link["cell_match"]:
                        set_relation("qc" + cv_link["cell_match"][f"{i},{j_real}"])
                    elif f"{i},{j_real}" in cv_link["num_date_match"]:
                        set_relation("qc" + cv_link["num_date_match"][f"{i},{j_real}"])
                    else:
                        set_relation('qc_default')
                elif j_type[0] == 'table':
                    j_real = j - t_base
                    if f"{i},{j_real}" in sc_link["q_tab_match"]:
                        set_relation("qt" + sc_link["q_tab_match"][f"{i},{j_real}"])
                    else:
                        set_relation('qt_default')

            elif i_type[0] == 'column':
                if j_type[0] == 'question':
                    i_real = i - c_base
                    if f"{j},{i_real}" in sc_link["q_col_match"]:
                        set_relation("cq" + sc_link["q_col_match"][f"{j},{i_real}"])
                    elif f"{j},{i_real}" in cv_link["cell_match"]:
                        set_relation("cq" + cv_link["cell_match"][f"{j},{i_real}"])
                    elif f"{j},{i_real}" in cv_link["num_date_match"]:
                        set_relation("cq" + cv_link["num_date_match"][f"{j},{i_real}"])
                    else:
                        set_relation('cq_default')
                elif j_type[0] == 'column':
                    col1, col2 = i_type[1], j_type[1]
                    if col1 == col2:
                        set_relation(('cc_dist', clamp(j - i, self.cc_max_dist)))
                    else:
                        set_relation('cc_default')
                        if self.cc_foreign_key:
                            if desc['foreign_keys'].get(str(col1)) == col2:
                                set_relation('cc_foreign_key_forward')
                            if desc['foreign_keys'].get(str(col2)) == col1:
                                set_relation('cc_foreign_key_backward')
                        if (self.cc_table_match and
                                desc['column_to_table'][str(col1)] == desc['column_to_table'][str(col2)]):
                            set_relation('cc_table_match')

                elif j_type[0] == 'table':
                    col, table = i_type[1], j_type[1]
                    set_relation('ct_default')
                    if self.ct_foreign_key and self.match_foreign_key(desc, col, table):
                        set_relation('ct_foreign_key')
                    if self.ct_table_match:
                        col_table = desc['column_to_table'][str(col)]
                        if col_table == table:
                            if col in desc['primary_keys']:
                                set_relation('ct_primary_key')
                            else:
                                set_relation('ct_table_match')
                        elif col_table is None:
                            set_relation('ct_any_table')

            elif i_type[0] == 'table':
                if j_type[0] == 'question':
                    i_real = i - t_base
                    if f"{j},{i_real}" in sc_link["q_tab_match"]:
                        set_relation("tq" + sc_link["q_tab_match"][f"{j},{i_real}"])
                    else:
                        set_relation('tq_default')
                elif j_type[0] == 'column':
                    table, col = i_type[1], j_type[1]
                    set_relation('tc_default')

                    if self.tc_foreign_key and self.match_foreign_key(desc, col, table):
                        set_relation('tc_foreign_key')
                    if self.tc_table_match:
                        col_table = desc['column_to_table'][str(col)]
                        if col_table == table:
                            if col in desc['primary_keys']:
                                set_relation('tc_primary_key')
                            else:
                                set_relation('tc_table_match')
                        elif col_table is None:
                            set_relation('tc_any_table')
                elif j_type[0] == 'table':
                    table1, table2 = i_type[1], j_type[1]
                    if table1 == table2:
                        set_relation(('tt_dist', clamp(j - i, self.tt_max_dist)))
                    else:
                        set_relation('tt_default')
                        if self.tt_foreign_key:
                            forward = table2 in desc['foreign_keys_tables'].get(str(table1), ())
                            backward = table1 in desc['foreign_keys_tables'].get(str(table2), ())
                            if forward and backward:
                                set_relation('tt_foreign_key_both')
                            elif forward:
                                set_relation('tt_foreign_key_forward')
                            elif backward:
                                set_relation('tt_foreign_key_backward')
        return relations
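
    # Worked example (hypothetical sizes): for a 3-token question, one
    # 2-token column, and one 1-token table (enc_length = 6, c_base = 3,
    # t_base = 5), relations[0, 1] is the id of ('qq_dist', 1),
    # relations[3, 4] is ('cc_dist', 1) (same column), and, absent any
    # linking matches, relations[0, 3] is 'qc_default' and relations[5, 0]
    # is 'tq_default'.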

    @classmethod
    def match_foreign_key(cls, desc, col, table):
        # Note: `table` is not consulted; this returns True iff `col` is a
        # foreign key whose referenced column belongs to the same table as
        # `col` itself.
        foreign_key_for = desc['foreign_keys'].get(str(col))
        if foreign_key_for is None:
            return False

        foreign_table = desc['column_to_table'][str(foreign_key_for)]
        return desc['column_to_table'][str(col)] == foreign_table
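
    # For example (hypothetical schema), with desc['foreign_keys'] == {'3': 5}
    # and desc['column_to_table'] mapping both '3' and '5' to table 1,
    # match_foreign_key(desc, 3, t) returns True for any t, since only the
    # tables of column 3 and its referenced column 5 are compared.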


class NoOpUpdate:
    def __init__(self, device, hidden_size):
        pass

    def __call__(self, desc, q_enc, c_enc, c_boundaries, t_enc, t_boundaries):
        # Pass the encodings through unchanged.
        return q_enc, c_enc, t_enc

    def forward_unbatched(self, desc, q_enc, c_enc, c_boundaries, t_enc, t_boundaries):
        """
        Same interface as RelationalTransformerUpdate.forward_unbatched.

        Returns the transposed encodings unchanged, plus a pair of None
        alignment matrices.
        """
        return q_enc.transpose(0, 1), c_enc.transpose(0, 1), t_enc.transpose(0, 1), (None, None)