| import torch |
|
|
def compute_align_loss(model, desc_enc, example):
    """Compute the schema-alignment loss for one training example.

    Encourages the memory-to-schema alignment matrices to place probability
    mass on the columns and tables that actually appear in the gold tree:
    for each relevant column/table we take its best alignment score over all
    memory positions, and the loss is the negative mean log of those scores.

    Args:
        model: an nl2code decoder; must expose ``ast_wrapper`` (used to walk
            the tree) and ``_device``.
        desc_enc: encoder output carrying ``m2c_align_mat`` (memory x columns)
            and ``m2t_align_mat`` (memory x tables) alignment matrices.
        example: training example whose ``.tree`` is the gold AST.

    Returns:
        A scalar tensor; lower means the relevant schema items are better
        aligned to the question memory.
    """
    root_node = example.tree
    # Column/table ids referenced anywhere in the gold tree. Duplicates are
    # possible here; they are removed via set() below, so ordering of the
    # traversal is irrelevant.
    rel_cols = list(model.ast_wrapper.find_all_descendants_of_type(root_node, "column"))
    rel_tabs = list(model.ast_wrapper.find_all_descendants_of_type(root_node, "table"))

    rel_cols_t = torch.LongTensor(sorted(set(rel_cols))).to(model._device)
    rel_tabs_t = torch.LongTensor(sorted(set(rel_tabs))).to(model._device)

    # Best alignment score over all memory positions for each relevant column.
    mc_max_rel_att, _ = desc_enc.m2c_align_mat.index_select(1, rel_cols_t).max(dim=0)
    mc_max_rel_att.clamp_(min=1e-9)  # guard against log(0)

    # Likewise for each relevant table.
    mt_max_rel_att, _ = desc_enc.m2t_align_mat.index_select(1, rel_tabs_t).max(dim=0)
    mt_max_rel_att.clamp_(min=1e-9)

    # NOTE(review): the original version also computed margin terms against
    # the *irrelevant* columns/tables (``mc_margin``/``mt_margin``) and a
    # ``gamma`` weight, but never used them in the returned loss. That dead
    # code has been removed; the returned value is unchanged.
    align_loss = -torch.log(mc_max_rel_att).mean() - torch.log(mt_max_rel_att).mean()
    return align_loss
|
|
|
|
def compute_pointer_with_align(
        model,
        node_type,
        prev_state,
        prev_action_emb,
        parent_h,
        parent_action_emb,
        desc_enc):
    """Run one decoder step and produce schema-pointer log-probabilities.

    Advances the decoder state, attends from the new output over the
    question memory, and redistributes that attention onto schema items
    (columns or tables, depending on ``node_type``) through the encoder's
    alignment matrices.

    Returns:
        A 4-tuple ``(output, new_state, pointer_logits, attention_weights)``
        where ``pointer_logits`` are log-probabilities over columns (for
        ``node_type == "column"``) or tables (for ``node_type == "table"``).

    Raises:
        AssertionError: if ``node_type`` is neither "column" nor "table".
    """
    new_state, attention_weights = model._update_state(
        node_type, prev_state, prev_action_emb, parent_h,
        parent_action_emb, desc_enc)
    output = new_state[0]

    # Attention of the decoder output over the question memory.
    memory_logits = model.pointers[node_type](output, desc_enc.memory)
    memory_probs = torch.nn.functional.softmax(memory_logits, dim=1)

    # Pick the alignment matrix matching the node type being pointed at.
    if node_type == "column":
        align_mat = desc_enc.m2c_align_mat
    else:
        assert node_type == "table"
        align_mat = desc_enc.m2t_align_mat

    # Redistribute memory attention onto schema items; clamp before the log
    # so zero-probability entries do not produce -inf.
    pointer_probs = torch.mm(memory_probs, align_mat).clamp(min=1e-9)
    pointer_logits = torch.log(pointer_probs)
    return output, new_state, pointer_logits, attention_weights
|
|