| import operator |
| import networkx as nx |
| import attr |
| import torch |
|
|
| from seq2struct.beam_search import Hypothesis |
| from seq2struct.models.nl2code.decoder import TreeState, get_field_presence_info |
| from seq2struct.models.nl2code.tree_traversal import TreeTraversal |
|
|
@attr.s
class Hypothesis4Filtering(Hypothesis):
    # Hypothesis augmented with per-clause bookkeeping used to filter
    # FROM-clause candidates during two-phase beam search:
    #   column_history     - column ids picked outside the FROM clause
    #   table_history      - table ids picked inside the FROM clause
    #   key_column_history - join-key column ids picked inside the FROM clause
    column_history = attr.ib(default=attr.Factory(list))
    table_history = attr.ib(default=attr.Factory(list))
    key_column_history = attr.ib(default=attr.Factory(list))
|
|
|
|
def beam_search_with_heuristics(model, orig_item, preproc_item, beam_size, max_steps, from_cond=True):
    """
    Find the valid FROM clause with beam search.

    Decoding alternates between two phases:
      1. expand hypotheses normally until each one is about to enter a
         "from" node;
      2. decode the FROM clause itself, then keep only completions whose
         mentioned tables / join columns are consistent with the schema's
         foreign-key graph.
    Filtered finished hypotheses accumulate in `cached_finished_seqs`; the
    top `beam_size` of them (by score) are returned.
    """
    inference_state, next_choices = model.begin_inference(orig_item, preproc_item)
    beam = [Hypothesis4Filtering(inference_state, next_choices)]

    cached_finished_seqs = []  # finished hypotheses that passed the FROM-clause filter
    beam_prefix = beam
    while True:
        # ---- Phase 1: decode up to (but not into) the FROM clause. ----
        prefixes2fill_from = []
        for step in range(max_steps):
            if len(prefixes2fill_from) >= beam_size:
                break

            candidates = []
            for hyp in beam_prefix:
                # Park hypotheses that are about to expand a "from" node;
                # they are continued in phase 2 below.
                if hyp.inference_state.cur_item.state == TreeTraversal.State.CHILDREN_APPLY \
                        and hyp.inference_state.cur_item.node_type == "from":
                    prefixes2fill_from.append(hyp)
                else:
                    candidates += [(hyp, choice, choice_score.item(),
                                    hyp.score + choice_score.item())
                                   for choice, choice_score in hyp.next_choices]
            # Keep the best continuations; parked prefixes count against the beam.
            candidates.sort(key=operator.itemgetter(3), reverse=True)
            candidates = candidates[:beam_size-len(prefixes2fill_from)]

            beam_prefix = []
            for hyp, choice, choice_score, cum_score in candidates:
                inference_state = hyp.inference_state.clone()

                # Record columns chosen outside the FROM clause (SELECT, WHERE,
                # ...); these constrain which tables the FROM clause must cover.
                column_history = hyp.column_history[:]
                if hyp.inference_state.cur_item.state == TreeTraversal.State.POINTER_APPLY and \
                        hyp.inference_state.cur_item.node_type == "column":
                    column_history = column_history + [choice]

                next_choices = inference_state.step(choice)
                # The tree cannot be complete before a FROM clause appears.
                assert next_choices is not None
                beam_prefix.append(
                    Hypothesis4Filtering(inference_state, next_choices, cum_score,
                                         hyp.choice_history + [choice],
                                         hyp.score_history + [choice_score],
                                         column_history))

        prefixes2fill_from.sort(key=operator.attrgetter('score'), reverse=True)

        # ---- Phase 2: decode the FROM clause from each parked prefix. ----
        beam_from = prefixes2fill_from
        max_size = 6  # wider than beam_size: candidates are filtered below
        unfiltered_finished = []
        prefixes_unfinished = []
        for step in range(max_steps):
            if len(unfiltered_finished) + len(prefixes_unfinished) > max_size:
                break

            candidates = []
            for hyp in beam_from:
                # step > 0: reaching another "from" node after the first step
                # means this FROM clause is done (e.g. a nested query follows);
                # park the hypothesis so phase 1 can resume from it.
                if step > 0 and hyp.inference_state.cur_item.state == TreeTraversal.State.CHILDREN_APPLY \
                        and hyp.inference_state.cur_item.node_type == "from":
                    prefixes_unfinished.append(hyp)
                else:
                    candidates += [(hyp, choice, choice_score.item(),
                                    hyp.score + choice_score.item())
                                   for choice, choice_score in hyp.next_choices]
            candidates.sort(key=operator.itemgetter(3), reverse=True)
            candidates = candidates[:max_size - len(prefixes_unfinished)]

            beam_from = []
            for hyp, choice, choice_score, cum_score in candidates:
                inference_state = hyp.inference_state.clone()

                # Track tables and join-key columns picked inside the FROM clause.
                table_history = hyp.table_history[:]
                key_column_history = hyp.key_column_history[:]
                if hyp.inference_state.cur_item.state == TreeTraversal.State.POINTER_APPLY:
                    if hyp.inference_state.cur_item.node_type == "table":
                        table_history = table_history + [choice]
                    elif hyp.inference_state.cur_item.node_type == "column":
                        key_column_history = key_column_history + [choice]

                next_choices = inference_state.step(choice)
                if next_choices is None:
                    # Whole tree finished while decoding this FROM clause.
                    unfiltered_finished.append(Hypothesis4Filtering(
                        inference_state,
                        None,
                        cum_score,
                        hyp.choice_history + [choice],
                        hyp.score_history + [choice_score],
                        hyp.column_history, table_history,
                        key_column_history))
                else:
                    beam_from.append(
                        Hypothesis4Filtering(inference_state, next_choices, cum_score,
                                             hyp.choice_history + [choice],
                                             hyp.score_history + [choice_score],
                                             hyp.column_history, table_history,
                                             key_column_history))

        unfiltered_finished.sort(key=operator.attrgetter('score'), reverse=True)

        # ---- Filter finished hypotheses by FROM-clause validity. ----
        filtered_finished = []
        for hyp in unfiltered_finished:
            mentioned_column_ids = set(hyp.column_history)
            mentioned_key_column_ids = set(hyp.key_column_history)
            mentioned_table_ids = set(hyp.table_history)

            # Reject FROM clauses that mention the same table twice.
            # NOTE(review): this also rejects legitimate self-joins.
            if len(mentioned_table_ids) != len(hyp.table_history):
                continue

            if from_cond:
                # The join-key columns must be exactly the foreign keys on the
                # shortest paths connecting the mentioned tables.
                # NOTE(review): assumes at least one table was mentioned —
                # candidate_table_ids[0] raises IndexError otherwise; confirm
                # a finished hypothesis always has a non-empty table_history.
                covered_tables = set()
                must_include_key_columns = set()
                candidate_table_ids = sorted(mentioned_table_ids)
                start_table_id = candidate_table_ids[0]
                for table_id in candidate_table_ids[1:]:
                    if table_id in covered_tables:
                        continue
                    try:
                        path = nx.shortest_path(
                            orig_item.schema.foreign_key_graph, source=start_table_id, target=table_id)
                    except (nx.NetworkXNoPath, nx.NodeNotFound):
                        # Unreachable table: skip it rather than reject outright.
                        covered_tables.add(table_id)
                        continue

                    for source_table_id, target_table_id in zip(path, path[1:]):
                        if target_table_id in covered_tables:
                            continue
                        if target_table_id not in mentioned_table_ids:
                            continue
                        col1, col2 = orig_item.schema.foreign_key_graph[source_table_id][target_table_id]['columns']
                        must_include_key_columns.add(col1)
                        must_include_key_columns.add(col2)
                if not must_include_key_columns == mentioned_key_column_ids:
                    continue

            # Every column mentioned elsewhere in the query must belong to a
            # table that appears in the FROM clause.
            must_table_ids = set()
            for col in mentioned_column_ids:
                tab_ = orig_item.schema.columns[col].table
                if tab_ is not None:  # presumably None for table-less columns such as '*'
                    must_table_ids.add(tab_.id)
            if not must_table_ids.issubset(mentioned_table_ids):
                continue

            filtered_finished.append(hyp)

        filtered_finished.sort(key=operator.attrgetter('score'), reverse=True)

        prefixes_unfinished.sort(key=operator.attrgetter('score'), reverse=True)

        # Keep the overall top beam_size across "still decoding" prefixes and
        # "finished and valid" hypotheses.
        prefixes_, filtered_ = merge_beams(prefixes_unfinished, filtered_finished, beam_size)

        if filtered_:
            cached_finished_seqs = cached_finished_seqs + filtered_
            cached_finished_seqs.sort(key=operator.attrgetter('score'), reverse=True)

        if prefixes_ and len(prefixes_[0].choice_history) < 200:
            # Resume phase 1 from the surviving prefixes; per-clause histories
            # are reset because they only describe the FROM scope just decoded.
            beam_prefix = prefixes_
            for hyp in beam_prefix:
                hyp.table_history = []
                hyp.column_history = []
                hyp.key_column_history = []
        elif cached_finished_seqs:
            return cached_finished_seqs[:beam_size]
        else:
            # Nothing passed the filter: fall back to unfiltered completions.
            return unfiltered_finished[:beam_size]
|
|
|
|
| |
def merge_beams(beam_1, beam_2, beam_size):
    """Keep the overall top-`beam_size` hypotheses across two beams.

    Both beams are merged, ranked by `.score` (descending, stable so ties
    keep `beam_1` entries first), truncated to `beam_size`, and split back
    into their original groups. If either beam is empty, both are returned
    unchanged.
    """
    if not beam_1 or not beam_2:
        return beam_1, beam_2

    # Tag every hypothesis with its beam of origin before ranking.
    tagged = [(hyp.score, 1, hyp) for hyp in beam_1]
    tagged += [(hyp.score, 2, hyp) for hyp in beam_2]
    tagged.sort(key=lambda entry: entry[0], reverse=True)

    kept = {1: [], 2: []}
    for _, origin, hyp in tagged[:beam_size]:
        kept[origin].append(hyp)
    return kept[1], kept[2]
|
|
|
|
def beam_search_with_oracle_column(model, orig_item, preproc_item, beam_size, max_steps, visualize_flag=False):
    """
    Greedy decoding (beam_size must be 1) in which every column/table pointer
    choice is forced to the gold value from the preprocessed tree, while all
    other (structural) decisions are left to the model.

    Returns the list of finished hypotheses sorted by score (at most one).
    """
    inference_state, next_choices = model.begin_inference(orig_item, preproc_item)
    beam = [Hypothesis(inference_state, next_choices)]
    finished = []
    assert beam_size == 1

    # Gold tree of the target program.
    root_node = preproc_item[1].tree

    # Gold columns/tables, consumed from the front as the decoder requests them.
    # NOTE(review): reversed() assumes find_all_descendants_of_type yields
    # values in reverse decoding order — confirm against its implementation.
    col_queue = list(reversed([val for val in model.decoder.ast_wrapper.find_all_descendants_of_type(root_node, "column")]))
    tab_queue = list(reversed([val for val in model.decoder.ast_wrapper.find_all_descendants_of_type(root_node, "table")]))
    col_queue_copy = col_queue[:]
    tab_queue_copy = tab_queue[:]

    predict_counter = 0  # number of pointer decisions actually forced/made

    for step in range(max_steps):
        if visualize_flag:
            print('step:')
            print(step)

        if len(finished) == beam_size:
            break

        assert len(beam) == 1
        hyp = beam[0]
        # If the decoder is about to pick a column/table, restrict its
        # candidate list (in place) to the single gold choice.
        if hyp.inference_state.cur_item.state == TreeTraversal.State.POINTER_APPLY:
            if hyp.inference_state.cur_item.node_type == "column" \
                    and len(col_queue) > 0:
                gold_col = col_queue[0]

                flag = False
                for _choice in hyp.next_choices:
                    if _choice[0] == gold_col:
                        flag = True
                        hyp.next_choices = [_choice]
                        col_queue = col_queue[1:]
                        break
                assert flag  # the gold column must be among the candidates
            elif hyp.inference_state.cur_item.node_type == "table" \
                    and len(tab_queue) > 0:
                gold_tab = tab_queue[0]

                flag = False
                for _choice in hyp.next_choices:
                    if _choice[0] == gold_tab:
                        flag = True
                        hyp.next_choices = [_choice]
                        tab_queue = tab_queue[1:]
                        break
                assert flag  # the gold table must be among the candidates

        if hyp.inference_state.cur_item.state == TreeTraversal.State.POINTER_APPLY:
            predict_counter += 1

        # Standard beam expansion (degenerate with beam_size == 1).
        candidates = []
        for hyp in beam:
            candidates += [(hyp, choice, choice_score.item(),
                            hyp.score + choice_score.item())
                           for choice, choice_score in hyp.next_choices]

        candidates.sort(key=operator.itemgetter(3), reverse=True)
        candidates = candidates[:beam_size - len(finished)]

        beam = []
        for hyp, choice, choice_score, cum_score in candidates:
            inference_state = hyp.inference_state.clone()
            next_choices = inference_state.step(choice)
            if next_choices is None:
                finished.append(Hypothesis(
                    inference_state,
                    None,
                    cum_score,
                    hyp.choice_history + [choice],
                    hyp.score_history + [choice_score]))
            else:
                beam.append(
                    Hypothesis(inference_state, next_choices, cum_score,
                               hyp.choice_history + [choice],
                               hyp.score_history + [choice_score]))
    # Deliberate no-op: the decoder's structure may legitimately request a
    # different number of pointers than the gold tree contains.
    if (len(col_queue_copy) + len(tab_queue_copy)) != predict_counter:

        pass
    finished.sort(key=operator.attrgetter('score'), reverse=True)
    return finished
|
|
|
|
def beam_search_with_oracle_sketch(model, orig_item, preproc_item, beam_size, max_steps, visualize_flag=False):
    """
    Force-decode the gold program tree, recording every decoder step.

    Walks the preprocessed gold tree depth-first and feeds the decoder the
    exact rule / pointer / token choice at every step, so the resulting
    hypothesis reproduces the gold sketch (all scores are 0).

    Returns [] if the gold code cannot be parsed or a required rule is
    absent from the grammar; otherwise a single-element list containing the
    forced hypothesis.
    """
    inference_state, next_choices = model.begin_inference(orig_item, preproc_item)
    hyp = Hypothesis(inference_state, next_choices)

    # The gold code must be parseable under the current grammar.
    parsed = model.decoder.preproc.grammar.parse(orig_item.code, "val")
    if not parsed:
        return []

    queue = [
        TreeState(
            node=preproc_item[1].tree,
            parent_field_type=model.decoder.preproc.grammar.root_type,
        )
    ]

    while queue:
        item = queue.pop()
        node = item.node
        parent_field_type = item.parent_field_type

        if isinstance(node, (list, tuple)):
            # Sequence field: emit the list-length rule, then push the
            # elements (reversed, since `queue` is a LIFO stack).
            node_type = parent_field_type + '*'
            rule = (node_type, len(node))
            if rule not in model.decoder.rules_index:
                # This gold list length was never seen during training.
                return []
            rule_idx = model.decoder.rules_index[rule]
            assert inference_state.cur_item.state == TreeTraversal.State.LIST_LENGTH_APPLY
            next_choices = inference_state.step(rule_idx)

            if model.decoder.preproc.use_seq_elem_rules and \
                    parent_field_type in model.decoder.ast_wrapper.sum_types:
                parent_field_type += '_seq_elem'

            for elem in reversed(node):
                queue.append(
                    TreeState(
                        node=elem,
                        parent_field_type=parent_field_type,
                    ))

            hyp = Hypothesis(
                inference_state,
                None,
                0,
                hyp.choice_history + [rule_idx],
                hyp.score_history + [0])
            continue

        if parent_field_type in model.decoder.preproc.grammar.pointers:
            # Pointer field (e.g. a column/table id): feed the gold index.
            assert inference_state.cur_item.state == TreeTraversal.State.POINTER_APPLY
            assert isinstance(node, int)
            next_choices = inference_state.step(node)
            hyp = Hypothesis(
                inference_state,
                None,
                0,
                hyp.choice_history + [node],
                hyp.score_history + [0])
            continue

        if parent_field_type in model.decoder.ast_wrapper.primitive_types:
            # Primitive field: feed the gold value token by token, then <EOS>.
            field_value_split = model.decoder.preproc.grammar.tokenize_field_value(node) + [
                '<EOS>']

            for token in field_value_split:
                next_choices = inference_state.step(token)
            hyp = Hypothesis(
                inference_state,
                None,
                0,
                hyp.choice_history + field_value_split,
                hyp.score_history + [0])
            continue

        type_info = model.decoder.ast_wrapper.singular_types[node['_type']]

        if parent_field_type in model.decoder.preproc.sum_type_constructors:
            # Sum-type field: emit the constructor-choice rule.
            rule = (parent_field_type, type_info.name)
            rule_idx = model.decoder.rules_index[rule]
            # Bug fix: this line was a bare `==` comparison (a no-op
            # statement); it is now the state assertion it was clearly meant
            # to be, matching the other branches above.
            assert inference_state.cur_item.state == TreeTraversal.State.SUM_TYPE_APPLY
            extra_rules = [
                model.decoder.rules_index[parent_field_type, extra_type]
                for extra_type in node.get('_extra_types', [])]
            next_choices = inference_state.step(rule_idx, extra_rules)

            hyp = Hypothesis(
                inference_state,
                None,
                0,
                hyp.choice_history + [rule_idx],
                hyp.score_history + [0])

        if type_info.fields:
            # Constructor with fields: emit the field-presence rule.
            present = get_field_presence_info(model.decoder.ast_wrapper, node, type_info.fields)
            rule = (node['_type'], tuple(present))
            rule_idx = model.decoder.rules_index[rule]
            next_choices = inference_state.step(rule_idx)

            hyp = Hypothesis(
                inference_state,
                None,
                0,
                hyp.choice_history + [rule_idx],
                hyp.score_history + [0])

        # Push present children (reversed for stack order).
        for field_info in reversed(type_info.fields):
            if field_info.name not in node:
                continue

            queue.append(
                TreeState(
                    node=node[field_info.name],
                    parent_field_type=field_info.type,
                ))

    return [hyp]