# NOTE: the three header lines here ("Spaces:" / "Sleeping" / "Sleeping") were
# web-page scrape artifacts, not source code.
| import torch | |
| import torch.nn as nn | |
| from utils import * | |
| from model.module import * | |
| from torch.nn import functional as F | |
| import copy | |
class TreeNode:
    """Decoding-stack entry holding the hidden state (goal vector) of one tree node."""

    def __init__(self, embedding, left_flag=False):
        # left_flag marks this node as the left child of its parent operator.
        self.embedding, self.left_flag = embedding, left_flag
class TreeEmbedding:
    """Subtree-stack entry holding the embedding of a (partial) tree.

    terminal is True once the subtree represented by this entry is complete.
    """

    def __init__(self, embedding, terminal=False):
        self.embedding, self.terminal = embedding, terminal
class TreeBeam:
    """One beam-search candidate: cumulative score plus the full decoder state."""

    def __init__(self, score, node_stacks, embeddings_stacks, left_child_trees, out):
        self.score = score                          # cumulative log-probability
        self.node_stacks = node_stacks              # pending TreeNode stacks
        self.embeddings_stacks = embeddings_stacks  # TreeEmbedding stacks (pre-order)
        self.left_child_trees = left_child_trees    # finished left-subtree embeddings
        self.out = out                              # token ids emitted so far
class Prediction(nn.Module):
    # a seq2tree decoder with Problem aware dynamic encoding
    def __init__(self, cfg, op_const_size):
        # cfg: config with encoder/decoder hidden sizes, decoder embedding size
        #      and dropout rate; op_const_size: number of operator + constant tokens.
        super(Prediction, self).__init__()
        # Define layers
        self.em_dropout = nn.Dropout(cfg.dropout_rate)
        # for Computational symbols and Generated numbers
        # concat_l / concat_lg: gated transform of the parent state when no left
        # subtree is finished yet; concat_r / concat_rg: same transform fused with
        # the finished left-subtree embedding (hence the doubled input width).
        self.concat_l = nn.Linear(cfg.decoder_hidden_size, cfg.decoder_hidden_size)
        self.concat_r = nn.Linear(cfg.decoder_hidden_size * 2, cfg.decoder_hidden_size)
        self.concat_lg = nn.Linear(cfg.decoder_hidden_size, cfg.decoder_hidden_size)
        self.concat_rg = nn.Linear(cfg.decoder_hidden_size * 2, cfg.decoder_hidden_size)
        # attention module
        self.attn = Attn(cfg.encoder_hidden_size, cfg.decoder_hidden_size)
        self.score = Score(cfg.encoder_hidden_size+cfg.decoder_hidden_size, cfg.decoder_embedding_size)
        # predefined constant
        # NOTE(review): created with .cuda() at construction time — requires a GPU
        # and, being plain tensor attributes (not registered buffers), they will
        # not follow model.to(device); confirm this matches the training setup.
        self.op_const_id = torch.arange(op_const_size).unsqueeze(0).cuda()
        self.padding_hidden = torch.zeros(1, cfg.decoder_hidden_size).cuda()

    def forward(self, node_stacks, left_child_trees, encoder_outputs, var_pades, source_mask, candi_mask, embedding_op_const):
        '''
        Arguments:
            node_stacks: [[TreeNode(_)]]*B, store the variable h
            left_child_trees: [t]*B, store the representation of left tree
            encoder_outputs: [B, S1, H]
            var_pades: [B, S2, H], all_vars_encoder_outputs
            padding_hidden: [1, H]
            source_mask: [B, S1], mask for source seq
            candi_mask: [B, op_size+const_size+var_size], mask for target seq
        Returns:
            num_score: [B x (op_size+const_size+var_size)]
            current_embeddings: q [B x 1 x H], the target vector of the current node
            current_context: c [B x 1 x H], the context vector of the current node, is calculated using the target vector and encoder_outputs
            current_all_embeddings: [B x (op_size+const_size+var_size) x H] e (M_op, M_con, h_loc^p)
        '''
        # Top of each sample's node stack is the hidden state of the node to be
        # expanded next; finished samples (empty stack) use the zero padding vector.
        current_embeddings = []
        for node_list in node_stacks:
            if len(node_list) == 0:
                current_embeddings.append(self.padding_hidden)
            else:
                current_node = node_list[-1]
                current_embeddings.append(current_node.embedding)
        current_node_temp = []  # B x (1 x H)
        for l, c in zip(left_child_trees, current_embeddings):
            if l is None:
                # No finished left subtree: gate the parent state alone.
                cd = self.em_dropout(c)
                g = torch.tanh(self.concat_l(cd))
                t = torch.sigmoid(self.concat_lg(cd))
                current_node_temp.append(g*t)
            else:
                # Left subtree finished: fuse its embedding with the parent state.
                ld = self.em_dropout(l)
                cd = self.em_dropout(c)
                g = torch.tanh(self.concat_r(torch.cat((ld, cd), 1)))
                t = torch.sigmoid(self.concat_rg(torch.cat((ld, cd), 1)))
                current_node_temp.append(g*t)
        current_node = torch.stack(current_node_temp, dim=0)  # B x 1 x H (q)
        # Dropout only feeds the attention; the un-dropped q is returned below.
        current_embeddings = self.em_dropout(current_node)
        current_attn = self.attn(current_embeddings, encoder_outputs, source_mask)  # B x S
        current_context = current_attn.unsqueeze(1).bmm(encoder_outputs)  # B x 1 x H (c)
        leaf_input = torch.cat((current_node, current_context), 2)  # B x 1 x 2H
        # Embeddings of the fixed op/const vocabulary, tiled over the batch.
        embedding_weight_op_const = embedding_op_const(self.op_const_id.repeat(var_pades.size(0), 1))  # B x (op_size+const_size) x H
        # Candidate matrix: fixed op/const embeddings followed by the per-problem
        # variable encodings.
        embedding_weight_all = torch.cat((embedding_weight_op_const, var_pades), dim=1)  # B x (op_size+const_size+var_size) x H
        leaf_input = self.em_dropout(leaf_input)
        embedding_weight_all_ = self.em_dropout(embedding_weight_all)
        num_score = self.score(leaf_input, embedding_weight_all_, candi_mask)  # B x (op_size+const_size+var_size)
        return num_score, current_node, current_context, embedding_weight_all
class GenerateNode(nn.Module):
    """Generate the left/right child hidden states of an operator tree node.

    Implements the front part of eq(10)(11): each child state is a gated
    (tanh * sigmoid) projection of [q; c; e(y|P)].
    """

    def __init__(self, cfg, op_size):
        """
        Arguments:
            cfg: config with decoder_embedding_size, decoder_hidden_size, dropout_rate
            op_size: number of operator tokens (labels >= op_size are leaves)
        """
        super(GenerateNode, self).__init__()
        self.embedding_size = cfg.decoder_embedding_size
        self.hidden_size = cfg.decoder_hidden_size
        self.op_size = op_size
        self.em_dropout = nn.Dropout(cfg.dropout_rate)
        # All four projections consume the same concatenation [q; c; e(y|P)].
        self.generate_l = nn.Linear(self.hidden_size * 2 + self.embedding_size, self.hidden_size)
        self.generate_r = nn.Linear(self.hidden_size * 2 + self.embedding_size, self.hidden_size)
        self.generate_lg = nn.Linear(self.hidden_size * 2 + self.embedding_size, self.hidden_size)
        self.generate_rg = nn.Linear(self.hidden_size * 2 + self.embedding_size, self.hidden_size)

    def forward(self, current_embedding, node_label, current_context, embedding_op_const):
        """
        Generate the hidden node hl and hr of tree, according to the front part of eq(10)(11)
        Arguments:
            current_embedding: [B x 1 x H (q)], the target vector of the current node
            node_label: [B (id)]
            current_context: [B x 1 x H (c)], context vector of current node
            embedding_op_const: Embedding of op_const
        Returns:
            left_child: [B x H (h)]
            right_child: [B x H (h)]
            token_embedding: [B x H (e(y|P) of op)]
        """
        # Labels >= op_size are constants/variables; clamp them into the op/const
        # embedding range so the lookup is valid (visible callers only consume
        # token_embedding for operator labels).
        node_label_op = torch.clamp(node_label, max=self.op_size-1)
        current_embedding_ = self.em_dropout(current_embedding.squeeze(1))
        current_context_ = self.em_dropout(current_context.squeeze(1))
        token_embedding = embedding_op_const(node_label_op)
        token_embedding_ = self.em_dropout(token_embedding)
        # Build the shared input [q; c; e] once instead of four times (the
        # original recomputed the identical torch.cat for every projection).
        combined = torch.cat((current_embedding_, current_context_, token_embedding_), 1)
        # Gated child states: tanh content modulated by a sigmoid gate.
        l_child = torch.tanh(self.generate_l(combined)) * torch.sigmoid(self.generate_lg(combined))
        r_child = torch.tanh(self.generate_r(combined)) * torch.sigmoid(self.generate_rg(combined))
        return l_child, r_child, token_embedding
class Merge(nn.Module):
    """
    Get subtree embedding via Recursive Neural Network: a completed operator
    node and its two sub-tree embeddings are merged into one vector (eq 12)
    through a gated (tanh * sigmoid) projection of their concatenation.
    """

    def __init__(self, cfg):
        super(Merge, self).__init__()
        self.embedding_size = cfg.decoder_embedding_size
        self.hidden_size = cfg.decoder_hidden_size
        self.em_dropout = nn.Dropout(cfg.dropout_rate)
        # Both projections consume [e(y|P); sub_tree_1; sub_tree_2].
        self.merge = nn.Linear(self.hidden_size * 2 + self.embedding_size, self.hidden_size)
        self.merge_g = nn.Linear(self.hidden_size * 2 + self.embedding_size, self.hidden_size)

    def forward(self, node_embedding, sub_tree_1, sub_tree_2):
        '''
        Arguments:
            node_embedding: 1 x H, e(y|P) of the operator node
            sub_tree_1: 1 x H, embedding of the left subtree
            sub_tree_2: 1 x H, embedding of the right subtree
        Return:
            sub_tree: 1 x H, merged subtree embedding
        '''
        sub_tree_1 = self.em_dropout(sub_tree_1)
        sub_tree_2 = self.em_dropout(sub_tree_2)
        node_embedding = self.em_dropout(node_embedding)
        # Concatenate once instead of twice (the original rebuilt the identical
        # torch.cat for both the content and the gate projection).
        combined = torch.cat((node_embedding, sub_tree_1, sub_tree_2), 1)
        sub_tree = torch.tanh(self.merge(combined)) * torch.sigmoid(self.merge_g(combined))
        return sub_tree
class TreeDecoder(nn.Module):
    """GTS-style tree decoder: emits an equation as the pre-order traversal of
    an expression tree, combining Prediction (next-token scoring), GenerateNode
    (child state generation) and Merge (subtree embedding) sub-modules."""

    def __init__(self, cfg, tgt_lang):
        super(TreeDecoder, self).__init__()
        # embedding for op, const, num
        # var_start: id of the first variable token; presumably equals
        # op_num + const_num — verify against tgt_lang.
        self.var_start = tgt_lang.var_start
        self.op_num = tgt_lang.op_num
        self.const_num = tgt_lang.const_num
        self.embedding_op_const = nn.Embedding(self.op_num+self.const_num, cfg.decoder_embedding_size)
        self.embedding_var = None  # obtain from encoder (set per batch in forward)
        self.cfg = cfg
        # modules of TreeDecoder
        self.predict = Prediction(cfg, self.op_num+self.const_num)
        self.generate = GenerateNode(cfg, self.op_num)
        self.merge = Merge(cfg)

    def get_var_encoder_outputs(self, encoder_outputs, var_positions):
        """
        Gather the encoder hidden state at each variable's position.
        Arguments:
            encoder_outputs: B x S1 x H
            var_positions: B x S2
        Returns:
            var_embeddings: B x S2 x H
        """
        hidden_size = encoder_outputs.size(-1)
        # Expand positions to B x S2 x H so gather picks whole hidden vectors.
        expand_var_positions = var_positions.unsqueeze(-1).repeat(1, 1, hidden_size)
        var_embeddings = encoder_outputs.gather(dim=1, index = expand_var_positions)
        return var_embeddings

    def forward(self, encoder_outputs, problem_output, len_source, var_positions, len_var, \
            is_train=False, text_target=None, len_target=None):
        """
        Arguments:
            encoder_outputs: B x S1 x H
            problem_output: B x H
            len_source: B
            text_target: B x S2 (required when is_train=True)
            len_target: B (unused here)
            var_positions: B x S3
            len_var: B
        Return:
            training: output B x S x (op_size+const_size+var_size), logits of one batch
            testing: [expr] x B
        """
        self.embedding_var = self.get_var_encoder_outputs(encoder_outputs, var_positions)  # B x S2 x H
        # sequence_mask comes from utils; presumably builds a length-based
        # padding mask — verify its polarity there.
        self.source_mask = sequence_mask(len_source)
        # Candidate count per sample = fixed op/const tokens (var_start of them)
        # plus that sample's variables.
        self.candi_mask = sequence_mask(len_var+self.var_start)
        if is_train:
            return self._forward_train(encoder_outputs, problem_output, text_target)
        else:
            return self._forward_test(encoder_outputs, problem_output)

    def _forward_train(self, encoder_outputs, problem_output, text_target):
        """
        Teacher-forced decoding over the gold pre-order target sequence.
        Arguments:
            embeddings_stacks: [[TreeEmbedding(t, terminal)]]*B, a stack of subtrees t in the first order traversal
            left_child_trees: [t]*B, the representation of left tree of current node
            node_stacks: [[TreeNode(h, left_flag)]]*B, a stack of hidden state h in the first order traversal
        Returns:
            all_node_outputs: B x S x (op_size+const_size+var_size), logits of one batch
        """
        # Each sample's stack is seeded with the problem representation (root goal).
        node_stacks = [[TreeNode(init_hidden)] for init_hidden in problem_output.split(1, dim=0)]
        embeddings_stacks = [[] for _ in range(encoder_outputs.size(0))]
        left_child_trees = [None]*encoder_outputs.size(0)
        all_node_outputs = []
        for t in range(text_target.size(1)):
            num_score, current_embeddings, current_context, current_all_embeddings = self.predict(
                node_stacks,
                left_child_trees,
                encoder_outputs,
                self.embedding_var,
                self.source_mask,
                self.candi_mask,
                self.embedding_op_const)
            all_node_outputs.append(num_score)  # [B x (op_size+const_size+var_size)] * S
            # Child states are computed for every sample; they are only pushed
            # onto the stacks for operator tokens below.
            left_child, right_child, token_embedding = self.generate(
                current_embeddings,
                text_target[:,t],
                current_context,
                self.embedding_op_const)
            left_child_trees = []
            for idx, (l, r, node_stack, target_id, embeddings_stack) in enumerate(zip(left_child.split(1), right_child.split(1),
                    node_stacks, text_target[:,t].tolist(), embeddings_stacks)):
                # Determines whether the tree traversal is complete
                if len(node_stack) != 0:
                    node_stack.pop()
                else:
                    left_child_trees.append(None)
                    continue
                if target_id < self.op_num:
                    # Operator: push right child first so the left child is
                    # expanded next (pre-order).
                    node_stack.append(TreeNode(r))
                    node_stack.append(TreeNode(l, left_flag=True))
                    # embeddings_stack, put e(y|P) of op in temporarily
                    embeddings_stack.append(TreeEmbedding(token_embedding[idx].unsqueeze(0), False))
                else:
                    current_num = current_all_embeddings[idx, target_id].unsqueeze(0)  # 1 x H
                    # Reach the right leaf node and merge the tree representation from bottom up
                    while len(embeddings_stack) > 0 and embeddings_stack[-1].terminal:
                        sub_stree = embeddings_stack.pop()
                        op = embeddings_stack.pop()
                        # embedding vector of two sub-targets is merged as the subtree embedding of nodes, corresponding to eq(12)
                        # with e(y|P), sub_tree_1 and sub_tree_2
                        current_num = self.merge(op.embedding, sub_stree.embedding, current_num)
                    embeddings_stack.append(TreeEmbedding(current_num, True))
                # Reach the left leaf node and save the representation of the left subtree for generation of q
                # (runs for both branches so left_child_trees stays aligned with the batch)
                if len(embeddings_stack) > 0 and embeddings_stack[-1].terminal:
                    left_child_trees.append(embeddings_stack[-1].embedding)
                else:
                    left_child_trees.append(None)
        all_node_outputs = torch.stack(all_node_outputs, dim=1)
        return all_node_outputs

    def _forward_test(self, encoder_outputs, problem_output):
        # Beam-search decoding, one sample at a time (effective batch size 1);
        # returns a list of token-id lists, the best hypothesis per sample.
        exp_outputs = []
        for sample_id in range(encoder_outputs.size(0)):
            # set batch size as 1
            node_stacks = [[TreeNode(problem_output[sample_id:sample_id+1])]]
            embeddings_stacks = [[]]
            left_child_trees = [None]
            beams = [TreeBeam(0.0, node_stacks, embeddings_stacks, left_child_trees, [])]
            for _ in range(self.cfg.max_output_len):
                # re-maintain of one beams
                current_beams = []
                while len(beams) > 0:
                    beam_item = beams.pop()
                    # The candidates are stored in beams in all process
                    if len(beam_item.node_stacks[0]) == 0:
                        # Finished hypothesis: keep it so it competes on score.
                        current_beams.append(beam_item)
                        continue
                    num_score, current_embeddings, current_context, current_all_embeddings = self.predict(
                        beam_item.node_stacks,
                        beam_item.left_child_trees,
                        encoder_outputs[sample_id:sample_id+1],
                        self.embedding_var[sample_id:sample_id+1],
                        self.source_mask[sample_id:sample_id+1],
                        self.candi_mask[sample_id:sample_id+1],
                        self.embedding_op_const)
                    out_score = F.log_softmax(num_score, dim=1)
                    # NOTE(review): assumes beam_size <= number of candidates.
                    topv, topi = out_score.topk(self.cfg.beam_size)
                    for tv, ti in zip(topv.split(1, dim=1), topi.split(1, dim=1)):
                        # copy_list comes from utils; presumably deep-copies the
                        # nested stacks so beams don't share state — verify.
                        current_node_stack = copy_list(beam_item.node_stacks)
                        current_left_child_trees = []
                        current_embeddings_stacks = copy_list(beam_item.embeddings_stacks)
                        current_out = copy.deepcopy(beam_item.out)
                        out_token = int(ti)
                        current_out.append(out_token)
                        current_node_stack[0].pop()
                        if out_token < self.op_num:
                            # NOTE(review): hard-coded .cuda(), consistent with
                            # Prediction's constants — requires a GPU.
                            generate_input = torch.LongTensor([out_token]).cuda()
                            left_child, right_child, token_embedding = self.generate(
                                current_embeddings,
                                generate_input,
                                current_context,
                                self.embedding_op_const)
                            current_node_stack[0].append(TreeNode(right_child))
                            current_node_stack[0].append(TreeNode(left_child, left_flag=True))
                            current_embeddings_stacks[0].append(TreeEmbedding(token_embedding, False))
                        else:
                            current_num = current_all_embeddings[:, out_token]
                            # Right leaf reached: fold finished subtrees bottom-up (eq 12).
                            while len(current_embeddings_stacks[0]) > 0 and current_embeddings_stacks[0][-1].terminal:
                                sub_stree = current_embeddings_stacks[0].pop()
                                op = current_embeddings_stacks[0].pop()
                                current_num = self.merge(op.embedding, sub_stree.embedding, current_num)
                            current_embeddings_stacks[0].append(TreeEmbedding(current_num, True))
                        # Save the left-subtree embedding (if one just completed)
                        # for the next q computation; runs for both branches.
                        if len(current_embeddings_stacks[0]) > 0 and current_embeddings_stacks[0][-1].terminal:
                            current_left_child_trees.append(current_embeddings_stacks[0][-1].embedding)
                        else:
                            current_left_child_trees.append(None)
                        current_beams.append(TreeBeam(beam_item.score+float(tv), current_node_stack, current_embeddings_stacks,
                                current_left_child_trees, current_out))
                # Keep the top beam_size hypotheses by cumulative log-probability.
                beams = sorted(current_beams, key=lambda x: x.score, reverse=True)
                beams = beams[:self.cfg.beam_size]
                # early termination: stop once every surviving beam is finished
                flag = True
                for beam_item in beams:
                    if len(beam_item.node_stacks[0]) != 0:
                        flag = False
                        break
                if flag: break
            exp_outputs.append(beams[0].out)
        return exp_outputs