# model/decoder/tree_decoder.py
# Tree-structured (seq2tree) decoder modules for the PGPS demo.
import torch
import torch.nn as nn
from utils import *
from model.module import *
from torch.nn import functional as F
import copy
class TreeNode:
    """A pending node on the decoding stack.

    embedding -- hidden-state tensor of the node awaiting expansion.
    left_flag -- True when this node was pushed as the left child of
                 its parent operator, False otherwise.
    """
    def __init__(self, embedding, left_flag=False):
        self.embedding = embedding
        self.left_flag = left_flag
class TreeEmbedding:
    """An entry on the subtree-embedding stack.

    embedding -- tensor for a token or a completed subtree.
    terminal  -- True once the entry is a finished subtree (leaf or
                 merged result), False while it is a pending operator.
    """
    def __init__(self, embedding, terminal=False):
        self.embedding = embedding
        self.terminal = terminal
class TreeBeam:
    """One hypothesis kept during beam-search decoding.

    score            -- accumulated log-probability of the hypothesis.
    node_stacks      -- per-sample stacks of pending TreeNode objects.
    embeddings_stacks-- per-sample stacks of TreeEmbedding subtrees.
    left_child_trees -- per-sample finished-left-subtree embedding or None.
    out              -- list of token ids emitted so far.
    """
    def __init__(self, score, node_stacks, embeddings_stacks, left_child_trees, out):
        self.score = score
        self.node_stacks = node_stacks
        self.embeddings_stacks = embeddings_stacks
        self.left_child_trees = left_child_trees
        self.out = out
class Prediction(nn.Module):
    """Goal-vector prediction module of the seq2tree decoder.

    For the tree node currently on top of each sample's stack it builds
    the goal vector q, attends over the encoder outputs to get a context
    vector c, and scores every candidate token (operators, constants and
    the problem's variables).
    """
    def __init__(self, cfg, op_const_size):
        super(Prediction, self).__init__()
        # Dropout applied to every embedding entering the gated layers.
        self.em_dropout = nn.Dropout(cfg.dropout_rate)
        # Gated transforms producing the goal vector q:
        #   concat_l/concat_lg : node has no finished left subtree yet,
        #   concat_r/concat_rg : node is conditioned on its left subtree.
        self.concat_l = nn.Linear(cfg.decoder_hidden_size, cfg.decoder_hidden_size)
        self.concat_r = nn.Linear(cfg.decoder_hidden_size * 2, cfg.decoder_hidden_size)
        self.concat_lg = nn.Linear(cfg.decoder_hidden_size, cfg.decoder_hidden_size)
        self.concat_rg = nn.Linear(cfg.decoder_hidden_size * 2, cfg.decoder_hidden_size)
        # Attention over encoder outputs and candidate scoring head.
        self.attn = Attn(cfg.encoder_hidden_size, cfg.decoder_hidden_size)
        self.score = Score(cfg.encoder_hidden_size+cfg.decoder_hidden_size, cfg.decoder_embedding_size)
        # FIX: register as non-persistent buffers instead of calling
        # .cuda() at construction time.  The tensors now follow the module
        # through .to(device)/.cuda()/.cpu(), and the model no longer
        # requires a GPU just to be instantiated.  persistent=False keeps
        # them out of the state_dict, so existing checkpoints still load.
        self.register_buffer("op_const_id", torch.arange(op_const_size).unsqueeze(0), persistent=False)
        self.register_buffer("padding_hidden", torch.zeros(1, cfg.decoder_hidden_size), persistent=False)
    def forward(self, node_stacks, left_child_trees, encoder_outputs, var_pades, source_mask, candi_mask, embedding_op_const):
        '''
        Arguments:
            node_stacks: [[TreeNode(_)]]*B, per-sample stacks holding the hidden state h
            left_child_trees: [t]*B, embedding of the finished left subtree or None
            encoder_outputs: [B, S1, H]
            var_pades: [B, S2, H], encoder outputs gathered at variable positions
            source_mask: [B, S1], mask for source seq
            candi_mask: [B, op_size+const_size+var_size], mask for candidate tokens
        Returns:
            num_score: [B x (op_size+const_size+var_size)], candidate scores
            current_node: q [B x 1 x H], the goal vector of the current node
            current_context: c [B x 1 x H], attention context computed from q and encoder_outputs
            embedding_weight_all: [B x (op_size+const_size+var_size) x H], (M_op, M_con, h_loc^p)
        '''
        # Hidden state of the node on top of each stack; samples whose
        # tree is already complete contribute the all-zero padding vector.
        current_embeddings = []
        for node_list in node_stacks:
            if node_list:
                current_embeddings.append(node_list[-1].embedding)
            else:
                current_embeddings.append(self.padding_hidden)
        # Gated combination producing the goal vector q for each node.
        current_node_temp = []  # B x (1 x H)
        for left_tree, node_hidden in zip(left_child_trees, current_embeddings):
            node_hidden = self.em_dropout(node_hidden)
            if left_tree is None:
                # No finished left subtree: condition on h alone.
                g = torch.tanh(self.concat_l(node_hidden))
                t = torch.sigmoid(self.concat_lg(node_hidden))
            else:
                # Right-child expansion: condition on (left subtree, h).
                # Concatenation is hoisted so it is built once, not twice.
                pair = torch.cat((self.em_dropout(left_tree), node_hidden), 1)
                g = torch.tanh(self.concat_r(pair))
                t = torch.sigmoid(self.concat_rg(pair))
            current_node_temp.append(g * t)
        current_node = torch.stack(current_node_temp, dim=0)  # B x 1 x H (q)
        # Attention over the source sequence yields the context vector c.
        current_embeddings = self.em_dropout(current_node)
        current_attn = self.attn(current_embeddings, encoder_outputs, source_mask)  # B x S1
        current_context = current_attn.unsqueeze(1).bmm(encoder_outputs)  # B x 1 x H (c)
        leaf_input = torch.cat((current_node, current_context), 2)  # B x 1 x 2H
        # Candidate embeddings: learned op/const rows followed by the
        # problem-specific variable encodings.
        embedding_weight_op_const = embedding_op_const(self.op_const_id.repeat(var_pades.size(0), 1))  # B x (op+const) x H
        embedding_weight_all = torch.cat((embedding_weight_op_const, var_pades), dim=1)  # B x (op+const+var) x H
        leaf_input = self.em_dropout(leaf_input)
        embedding_weight_all_ = self.em_dropout(embedding_weight_all)
        num_score = self.score(leaf_input, embedding_weight_all_, candi_mask)  # B x (op+const+var)
        return num_score, current_node, current_context, embedding_weight_all
class GenerateNode(nn.Module):
    """Generate the child hidden states h_l and h_r of an operator node.

    Implements the front part of eq.(10)/(11): given the goal vector q,
    the context vector c and the embedding of the predicted operator,
    gated linear layers produce the hidden states that seed the left and
    right child nodes.
    """
    def __init__(self, cfg, op_size):
        super(GenerateNode, self).__init__()
        self.embedding_size = cfg.decoder_embedding_size
        self.hidden_size = cfg.decoder_hidden_size
        self.op_size = op_size
        self.em_dropout = nn.Dropout(cfg.dropout_rate)
        self.generate_l = nn.Linear(self.hidden_size * 2 + self.embedding_size, self.hidden_size)
        self.generate_r = nn.Linear(self.hidden_size * 2 + self.embedding_size, self.hidden_size)
        self.generate_lg = nn.Linear(self.hidden_size * 2 + self.embedding_size, self.hidden_size)
        self.generate_rg = nn.Linear(self.hidden_size * 2 + self.embedding_size, self.hidden_size)
    def forward(self, current_embedding, node_label, current_context, embedding_op_const):
        """
        Arguments:
            current_embedding: [B x 1 x H (q)], goal vector of the current node
            node_label: [B], predicted token ids (non-operators are clamped
                into the operator range; the caller discards their children)
            current_context: [B x 1 x H (c)], context vector of current node
            embedding_op_const: shared embedding table of op/const tokens
        Returns:
            l_child: [B x H], hidden state for the left child
            r_child: [B x H], hidden state for the right child
            token_embedding: [B x H], e(y|P) of the (clamped) operator, without dropout
        """
        node_label_op = torch.clamp(node_label, max=self.op_size-1)
        token_embedding = embedding_op_const(node_label_op)
        # FIX: build the shared input [q; c; e(y|P)] once instead of
        # concatenating the same three tensors four times.  The dropout
        # draws are unchanged (one per tensor), so results are identical.
        current_embedding_ = self.em_dropout(current_embedding.squeeze(1))
        current_context_ = self.em_dropout(current_context.squeeze(1))
        token_embedding_ = self.em_dropout(token_embedding)
        fused = torch.cat((current_embedding_, current_context_, token_embedding_), 1)
        # Gated outputs: tanh branch modulated by a sigmoid gate.
        l_child = torch.tanh(self.generate_l(fused)) * torch.sigmoid(self.generate_lg(fused))
        r_child = torch.tanh(self.generate_r(fused)) * torch.sigmoid(self.generate_rg(fused))
        return l_child, r_child, token_embedding
class Merge(nn.Module):
    """
    Fold an operator embedding and its two finished subtrees into one
    subtree embedding via a gated recursive cell (eq.(12)).
    """
    def __init__(self, cfg):
        super(Merge, self).__init__()
        self.embedding_size = cfg.decoder_embedding_size
        self.hidden_size = cfg.decoder_hidden_size
        self.em_dropout = nn.Dropout(cfg.dropout_rate)
        self.merge = nn.Linear(self.hidden_size * 2 + self.embedding_size, self.hidden_size)
        self.merge_g = nn.Linear(self.hidden_size * 2 + self.embedding_size, self.hidden_size)
    def forward(self, node_embedding, sub_tree_1, sub_tree_2):
        '''
        Arguments:
            node_embedding: 1 x H, e(y|P) embedding of the operator
            sub_tree_1: 1 x H, embedding of the first (left) subtree
            sub_tree_2: 1 x H, embedding of the second (right) subtree
        Return:
            sub_tree: 1 x H, gated merge of the three inputs
        '''
        # Dropout order matches the original implementation so RNG draws
        # are identical in training mode.
        left = self.em_dropout(sub_tree_1)
        right = self.em_dropout(sub_tree_2)
        parent = self.em_dropout(node_embedding)
        combined = torch.cat((parent, left, right), 1)
        return torch.tanh(self.merge(combined)) * torch.sigmoid(self.merge_g(combined))
class TreeDecoder(nn.Module):
    """Goal-driven tree-structured decoder.

    Wires together the sub-modules defined above: ``Prediction`` scores
    candidate tokens for the current tree node, ``GenerateNode`` produces
    child hidden states for operator nodes, and ``Merge`` folds completed
    subtrees into a single embedding.  Training decodes teacher-forced
    along ``text_target``; testing runs a per-sample beam search.
    """
    def __init__(self, cfg, tgt_lang):
        super(TreeDecoder, self).__init__()
        # embedding for op, const, num
        self.var_start = tgt_lang.var_start
        self.op_num = tgt_lang.op_num
        self.const_num = tgt_lang.const_num
        # Learned table covers only operators and constants; variable
        # embeddings are gathered from the encoder outputs per problem.
        self.embedding_op_const = nn.Embedding(self.op_num+self.const_num, cfg.decoder_embedding_size)
        self.embedding_var = None # obtain from encoder
        self.cfg = cfg
        # modules of TreeDecoder
        self.predict = Prediction(cfg, self.op_num+self.const_num)
        self.generate = GenerateNode(cfg, self.op_num)
        self.merge = Merge(cfg)
    def get_var_encoder_outputs(self, encoder_outputs, var_positions):
        """
        Gather the encoder hidden states at the variable token positions.

        Arguments:
            encoder_outputs: B x S1 x H
            var_positions: B x S2, indices into the source sequence
        Returns:
            var_embeddings: B x S2 x H
        """
        hidden_size = encoder_outputs.size(-1)
        # Expand positions to B x S2 x H so gather picks whole hidden vectors.
        expand_var_positions = var_positions.unsqueeze(-1).repeat(1, 1, hidden_size)
        var_embeddings = encoder_outputs.gather(dim=1, index = expand_var_positions)
        return var_embeddings
    def forward(self, encoder_outputs, problem_output, len_source, var_positions, len_var, \
        is_train=False, text_target=None, len_target=None):
        """
        Arguments:
            encoder_outputs: B x S1 x H
            problem_output: B x H, root goal vector of each sample
            len_source: B, valid lengths of the source sequences
            text_target: B x S2, teacher-forcing targets (required when is_train)
            len_target: B (not referenced in this block; kept for interface compatibility)
            var_positions: B x S3, positions of variable tokens in the source
            len_var: B, number of variables per sample
        Return:
            training: output B x S x (op_size+const_size+var_size), logits of one batch
            testing: [expr] x B
        """
        self.embedding_var = self.get_var_encoder_outputs(encoder_outputs, var_positions) # B x S2 x H
        self.source_mask = sequence_mask(len_source)
        # Candidate mask: ops+consts (first var_start slots) plus the
        # per-sample number of variables.
        self.candi_mask = sequence_mask(len_var+self.var_start)
        if is_train:
            return self._forward_train(encoder_outputs, problem_output, text_target)
        else:
            return self._forward_test(encoder_outputs, problem_output)
    def _forward_train(self, encoder_outputs, problem_output, text_target):
        """
        Teacher-forced decoding over the target sequence (pre-order).

        Local state per sample:
            node_stacks: [[TreeNode(h, left_flag)]]*B, a stack of hidden state h in the first order traversal
            embeddings_stacks: [[TreeEmbedding(t, terminal)]]*B, a stack of subtrees t in the first order traversal
            left_child_trees: [t]*B, the representation of left tree of current node
        Returns:
            all_node_outputs: B x S x (op_size+const_size+var_size), logits of one batch
        """
        # Root node of each sample starts from the problem representation.
        node_stacks = [[TreeNode(init_hidden)] for init_hidden in problem_output.split(1, dim=0)]
        embeddings_stacks = [[] for _ in range(encoder_outputs.size(0))]
        left_child_trees = [None]*encoder_outputs.size(0)
        all_node_outputs = []
        for t in range(text_target.size(1)):
            # Score every candidate token for the node on top of each stack.
            num_score, current_embeddings, current_context, current_all_embeddings = self.predict(
                node_stacks,
                left_child_trees,
                encoder_outputs,
                self.embedding_var,
                self.source_mask,
                self.candi_mask,
                self.embedding_op_const)
            all_node_outputs.append(num_score) # [B x (op_size+const_size+var_size)] * S
            # Child hidden states for the (teacher-forced) gold token.
            left_child, right_child, token_embedding = self.generate(
                current_embeddings,
                text_target[:,t],
                current_context,
                self.embedding_op_const)
            left_child_trees = []
            for idx, (l, r, node_stack, target_id, embeddings_stack) in enumerate(zip(left_child.split(1), right_child.split(1),
                                                                node_stacks, text_target[:,t].tolist(), embeddings_stacks)):
                # Determines whether the tree traversal is complete
                if len(node_stack) != 0:
                    node_stack.pop()
                else:
                    left_child_trees.append(None)
                    continue
                if target_id < self.op_num:
                    # Operator: push right child first so the left child is
                    # expanded next (pre-order traversal).
                    node_stack.append(TreeNode(r))
                    node_stack.append(TreeNode(l, left_flag=True))
                    # embeddings_stack, put e(y|P) of op in temporarily
                    embeddings_stack.append(TreeEmbedding(token_embedding[idx].unsqueeze(0), False))
                else:
                    # Leaf (constant or variable): its embedding comes from
                    # the candidate table built by Prediction.
                    current_num = current_all_embeddings[idx, target_id].unsqueeze(0) # 1 x H
                    # Reach the right leaf node and merge the tree representation from bottom up
                    while len(embeddings_stack) > 0 and embeddings_stack[-1].terminal:
                        sub_stree = embeddings_stack.pop()
                        op = embeddings_stack.pop()
                        # embedding vector of two sub-targets is merged as the subtree embedding of nodes, corresponding to eq(12)
                        # with e(y|P), sub_tree_1 and sub_tree_2
                        current_num = self.merge(op.embedding, sub_stree.embedding, current_num)
                    embeddings_stack.append(TreeEmbedding(current_num, True))
                # Reach the left leaf node and save the representation of the left subtree for generation of q
                if len(embeddings_stack) > 0 and embeddings_stack[-1].terminal:
                    left_child_trees.append(embeddings_stack[-1].embedding)
                else:
                    left_child_trees.append(None)
        all_node_outputs = torch.stack(all_node_outputs, dim=1)
        return all_node_outputs
    def _forward_test(self, encoder_outputs, problem_output):
        """
        Beam-search decoding, one sample at a time (internal batch size 1).

        Returns:
            exp_outputs: list (length B) of predicted token-id sequences,
                one per sample (the best-scoring finished beam).
        """
        exp_outputs = []
        for sample_id in range(encoder_outputs.size(0)):
            # set batch size as 1
            node_stacks = [[TreeNode(problem_output[sample_id:sample_id+1])]]
            embeddings_stacks = [[]]
            left_child_trees = [None]
            beams = [TreeBeam(0.0, node_stacks, embeddings_stacks, left_child_trees, [])]
            for _ in range(self.cfg.max_output_len):
                # re-maintain of one beams
                current_beams = []
                while len(beams) > 0:
                    beam_item = beams.pop()
                    # The candidates are stored in beams in all process
                    # A finished beam (empty node stack) is carried forward unchanged.
                    if len(beam_item.node_stacks[0]) == 0:
                        current_beams.append(beam_item)
                        continue
                    num_score, current_embeddings, current_context, current_all_embeddings = self.predict(
                        beam_item.node_stacks,
                        beam_item.left_child_trees,
                        encoder_outputs[sample_id:sample_id+1],
                        self.embedding_var[sample_id:sample_id+1],
                        self.source_mask[sample_id:sample_id+1],
                        self.candi_mask[sample_id:sample_id+1],
                        self.embedding_op_const)
                    out_score = F.log_softmax(num_score, dim=1)
                    topv, topi = out_score.topk(self.cfg.beam_size)
                    # Expand this beam by its top-k candidate tokens; stacks
                    # are copied so siblings never share mutable state.
                    for tv, ti in zip(topv.split(1, dim=1), topi.split(1, dim=1)):
                        current_node_stack = copy_list(beam_item.node_stacks)
                        current_left_child_trees = []
                        current_embeddings_stacks = copy_list(beam_item.embeddings_stacks)
                        current_out = copy.deepcopy(beam_item.out)
                        out_token = int(ti)
                        current_out.append(out_token)
                        current_node_stack[0].pop()
                        if out_token < self.op_num:
                            # Operator: generate child states and push right-then-left.
                            generate_input = torch.LongTensor([out_token]).cuda()
                            left_child, right_child, token_embedding = self.generate(
                                current_embeddings,
                                generate_input,
                                current_context,
                                self.embedding_op_const)
                            current_node_stack[0].append(TreeNode(right_child))
                            current_node_stack[0].append(TreeNode(left_child, left_flag=True))
                            current_embeddings_stacks[0].append(TreeEmbedding(token_embedding, False))
                        else:
                            # Leaf: merge finished subtrees bottom-up, as in training.
                            current_num = current_all_embeddings[:, out_token]
                            while len(current_embeddings_stacks[0]) > 0 and current_embeddings_stacks[0][-1].terminal:
                                sub_stree = current_embeddings_stacks[0].pop()
                                op = current_embeddings_stacks[0].pop()
                                current_num = self.merge(op.embedding, sub_stree.embedding, current_num)
                            current_embeddings_stacks[0].append(TreeEmbedding(current_num, True))
                        if len(current_embeddings_stacks[0]) > 0 and current_embeddings_stacks[0][-1].terminal:
                            current_left_child_trees.append(current_embeddings_stacks[0][-1].embedding)
                        else:
                            current_left_child_trees.append(None)
                        current_beams.append(TreeBeam(beam_item.score+float(tv), current_node_stack, current_embeddings_stacks,
                                                    current_left_child_trees, current_out))
                # Keep only the best beam_size hypotheses.
                beams = sorted(current_beams, key=lambda x: x.score, reverse=True)
                beams = beams[:self.cfg.beam_size]
                # early termination
                flag = True
                for beam_item in beams:
                    if len(beam_item.node_stacks[0]) != 0:
                        flag = False
                        break
                if flag: break
            exp_outputs.append(beams[0].out)
        return exp_outputs