## @package beam_search
# Module caffe2.python.models.seq2seq.beam_search

from collections import namedtuple
from caffe2.python import core
import caffe2.python.models.seq2seq.seq2seq_util as seq2seq_util
from caffe2.python.models.seq2seq.seq2seq_model_helper import Seq2SeqModelHelper


class BeamSearchForwardOnly(object):
    """
    Class generalizing forward beam search for seq2seq models.

    Also provides types to specify the recurrent structure of decoding:

    StateConfig:
        initial_value: blob providing value of state at first step_model
        state_prev_link: LinkConfig describing how recurrent step receives
            input from global state blob in each step
        state_link: LinkConfig describing how step writes (produces new state)
            to global state blob in each step

    LinkConfig:
        blob: blob connecting global state blob to step application
        offset: offset from beginning of global blob for link in time dimension
        window: width of global blob to read/write in time dimension
    """

    LinkConfig = namedtuple('LinkConfig', ['blob', 'offset', 'window'])

    StateConfig = namedtuple(
        'StateConfig',
        ['initial_value', 'state_prev_link', 'state_link'],
    )

    def __init__(
        self,
        beam_size,
        model,
        eos_token_id,
        go_token_id=seq2seq_util.GO_ID,
        post_eos_penalty=None,
    ):
        """
        Build the per-step model skeleton shared by all decode steps.

        Args:
            beam_size: number of hypotheses kept per step (also the k used
                in both TopK ops inside apply()).
            model: the outer model helper; its param_init_net receives all
                constant blobs and its net receives the RecurrentNetwork op.
            eos_token_id: token id compared against predecessors when
                post_eos_penalty is enabled.
            go_token_id: token id used to fill the initial tokens state when
                apply() is not given an explicit go_token_id blob.
            post_eos_penalty: if not None, an additive penalty applied to the
                log-probs of hypotheses whose previous token was EOS.
        """
        self.beam_size = beam_size
        self.model = model
        # Step model shares parameters with the outer model; its net becomes
        # the step_net of the RecurrentNetwork op created in apply().
        self.step_model = Seq2SeqModelHelper(
            name='step_model',
            param_model=self.model,
        )
        self.go_token_id = go_token_id
        self.eos_token_id = eos_token_id
        self.post_eos_penalty = post_eos_penalty

        # External inputs of the step net; these are linked to the global
        # per-timestep state blobs by the RecurrentNetwork wiring in apply().
        (
            self.timestep,
            self.scores_t_prev,
            self.tokens_t_prev,
            self.hypo_t_prev,
            self.attention_t_prev,
        ) = self.step_model.net.AddExternalInputs(
            'timestep',
            'scores_t_prev',
            'tokens_t_prev',
            'hypo_t_prev',
            'attention_t_prev',
        )
        # Tokens are carried through recurrent state as FLOAT; cast back to
        # INT32 and flatten to [1, -1] for consumption by the decoder step.
        tokens_t_prev_int32 = self.step_model.net.Cast(
            self.tokens_t_prev,
            'tokens_t_prev_int32',
            to=core.DataType.INT32,
        )
        self.tokens_t_prev_int32_flattened, _ = self.step_model.net.Reshape(
            [tokens_t_prev_int32],
            [tokens_t_prev_int32, 'input_t_int32_old_shape'],
            shape=[1, -1],
        )

    def get_step_model(self):
        # Step-model accessor: callers add their decoder ops to this net.
        return self.step_model

    def get_previous_tokens(self):
        # Flattened int32 tokens from the previous step ([1, -1]).
        return self.tokens_t_prev_int32_flattened

    def get_timestep(self):
        # Scalar timestep blob provided by the RecurrentNetwork op.
        return self.timestep

    # TODO: make attentions a generic state
    # data_dependencies is a list of blobs that the operator should wait for
    # before beginning execution. This ensures that ops are run in the correct
    # order when the RecurrentNetwork op is embedded in a DAGNet, for ex.
    def apply(
        self,
        inputs,
        length,
        log_probs,
        attentions,
        state_configs,
        data_dependencies,
        word_rewards=None,
        possible_translation_tokens=None,
        go_token_id=None,
    ):
        """
        Wire the beam-search step net into a RecurrentNetwork op on model.net.

        Args:
            inputs: encoder inputs blob; only reshaped flat to size the
                zero-filled initial attention state.
            length: blob used (via input_as_shape) to size the fake input
                that drives the number of recurrent steps.
            log_probs: step-net blob of per-hypothesis token log-probs
                (indexed [hypo, token] by the TopK below).
            attentions: step-net blob of per-hypothesis attention weights
                (gathered by chosen predecessor below).
            state_configs: list of StateConfig for caller-defined recurrent
                states (e.g. decoder hidden states).
            data_dependencies: extra input blobs the RecurrentNetwork op
                must wait on (see class-level comment).
            word_rewards: optional [vocab]-indexed additive rewards applied
                per chosen token.
            possible_translation_tokens: optional vocab-subset blob; when
                given, TopK indices are translated through it.
            go_token_id: optional blob copied into the initial tokens state;
                falls back to the constructor's go_token_id constant.

        Returns:
            Tuple of (output_token_beam_list, output_prev_index_beam_list,
            output_score_beam_list, output_attention_weights_beam_list) —
            the per-timestep token / predecessor-index / score / attention
            beam lists produced by the unrolled recurrent net.
        """
        ZERO = self.model.param_init_net.ConstantFill(
            [],
            'ZERO',
            shape=[1],
            value=0,
            dtype=core.DataType.INT32,
        )
        # True only at timestep 0; reused both for the EOS penalty gate and
        # for the slice_end Conditional below.
        on_initial_step = self.step_model.net.EQ(
            [ZERO, self.timestep],
            'on_initial_step',
        )

        if self.post_eos_penalty is not None:
            # Penalize continuations of hypotheses that already emitted EOS,
            # except on the initial step (where there is no real predecessor).
            eos_token = self.model.param_init_net.ConstantFill(
                [],
                'eos_token',
                shape=[self.beam_size],
                value=self.eos_token_id,
                dtype=core.DataType.INT32,
            )
            finished_penalty = self.model.param_init_net.ConstantFill(
                [],
                'finished_penalty',
                shape=[1],
                value=float(self.post_eos_penalty),
                dtype=core.DataType.FLOAT,
            )
            ZERO_FLOAT = self.model.param_init_net.ConstantFill(
                [],
                'ZERO_FLOAT',
                shape=[1],
                value=0.0,
                dtype=core.DataType.FLOAT,
            )
            # Penalty is zeroed on the initial step.
            finished_penalty = self.step_model.net.Conditional(
                [on_initial_step, ZERO_FLOAT, finished_penalty],
                'possible_finished_penalty',
            )

            tokens_t_flat = self.step_model.net.FlattenToVec(
                self.tokens_t_prev,
                'tokens_t_flat',
            )
            tokens_t_flat_int = self.step_model.net.Cast(
                tokens_t_flat,
                'tokens_t_flat_int',
                to=core.DataType.INT32,
            )

            # Per-hypothesis mask: 1.0 where the previous token was EOS.
            predecessor_is_eos = self.step_model.net.EQ(
                [tokens_t_flat_int, eos_token],
                'predecessor_is_eos',
            )
            predecessor_is_eos_float = self.step_model.net.Cast(
                predecessor_is_eos,
                'predecessor_is_eos_float',
                to=core.DataType.FLOAT,
            )
            predecessor_is_eos_penalty = self.step_model.net.Mul(
                [predecessor_is_eos_float, finished_penalty],
                'predecessor_is_eos_penalty',
                broadcast=1,
            )

            # Broadcast the per-hypothesis penalty across the token axis.
            log_probs = self.step_model.net.Add(
                [log_probs, predecessor_is_eos_penalty],
                'log_probs_penalized',
                broadcast=1,
                axis=0,
            )

        # Per-predecessor top-k over the token axis.
        # [beam_size, beam_size]
        best_scores_per_hypo, best_tokens_per_hypo = self.step_model.net.TopK(
            log_probs,
            ['best_scores_per_hypo', 'best_tokens_per_hypo_indices'],
            k=self.beam_size,
        )
        if possible_translation_tokens:
            # TopK returned indices into the vocab subset; map them back to
            # real token ids.
            # [beam_size, beam_size]
            best_tokens_per_hypo = self.step_model.net.Gather(
                [possible_translation_tokens, best_tokens_per_hypo],
                ['best_tokens_per_hypo']
            )

        # [beam_size]
        scores_t_prev_squeezed, _ = self.step_model.net.Reshape(
            self.scores_t_prev,
            ['scores_t_prev_squeezed', 'scores_t_prev_old_shape'],
            shape=[self.beam_size],
        )
        # Accumulate each predecessor's running score into its candidates.
        # [beam_size, beam_size]
        output_scores = self.step_model.net.Add(
            [best_scores_per_hypo, scores_t_prev_squeezed],
            'output_scores',
            broadcast=1,
            axis=0,
        )
        if word_rewards is not None:
            # Additive per-token reward, looked up by chosen token id.
            # [beam_size, beam_size]
            word_rewards_for_best_tokens_per_hypo = self.step_model.net.Gather(
                [word_rewards, best_tokens_per_hypo],
                'word_rewards_for_best_tokens_per_hypo',
            )
            # [beam_size, beam_size]
            output_scores = self.step_model.net.Add(
                [output_scores, word_rewards_for_best_tokens_per_hypo],
                'output_scores',
            )
        # NOTE: in-place Reshape — output blob is output_scores itself.
        # [beam_size * beam_size]
        output_scores_flattened, _ = self.step_model.net.Reshape(
            [output_scores],
            [output_scores, 'output_scores_old_shape'],
            shape=[-1],
        )
        MINUS_ONE_INT32 = self.model.param_init_net.ConstantFill(
            [],
            'MINUS_ONE_INT32',
            value=-1,
            shape=[1],
            dtype=core.DataType.INT32,
        )
        BEAM_SIZE = self.model.param_init_net.ConstantFill(
            [],
            'beam_size',
            shape=[1],
            value=self.beam_size,
            dtype=core.DataType.INT32,
        )

        # current_beam_size (predecessor states from previous step)
        # is 1 on first step (so we just need beam_size scores),
        # and beam_size subsequently (so we need all beam_size * beam_size
        # scores)
        slice_end = self.step_model.net.Conditional(
            [on_initial_step, BEAM_SIZE, MINUS_ONE_INT32],
            ['slice_end'],
        )

        # [current_beam_size * beam_size]
        output_scores_flattened_slice = self.step_model.net.Slice(
            [output_scores_flattened, ZERO, slice_end],
            'output_scores_flattened_slice',
        )
        # [1, current_beam_size * beam_size]
        output_scores_flattened_slice, _ = self.step_model.net.Reshape(
            output_scores_flattened_slice,
            [
                output_scores_flattened_slice,
                'output_scores_flattened_slice_old_shape',
            ],
            shape=[1, -1],
        )
        # Global top-k across all surviving candidates.
        # [1, beam_size]
        scores_t, best_indices = self.step_model.net.TopK(
            output_scores_flattened_slice,
            ['scores_t', 'best_indices'],
            k=self.beam_size,
        )
        BEAM_SIZE_64 = self.model.param_init_net.Cast(
            BEAM_SIZE,
            'BEAM_SIZE_64',
            to=core.DataType.INT64,
        )
        # Flat index // beam_size recovers which predecessor hypothesis each
        # chosen candidate came from.
        # [1, beam_size]
        hypo_t_int32 = self.step_model.net.Div(
            [best_indices, BEAM_SIZE_64],
            'hypo_t_int32',
            broadcast=1,
        )
        # Recurrent states are carried as FLOAT; cast for the state link.
        hypo_t = self.step_model.net.Cast(
            hypo_t_int32,
            'hypo_t',
            to=core.DataType.FLOAT,
        )

        # Re-select attention rows for the chosen predecessors.
        # [beam_size, encoder_length, 1]
        attention_t = self.step_model.net.Gather(
            [attentions, hypo_t_int32],
            'attention_t',
        )
        # [1, beam_size, encoder_length]
        attention_t, _ = self.step_model.net.Reshape(
            attention_t,
            [attention_t, 'attention_t_old_shape'],
            shape=[1, self.beam_size, -1],
        )
        # [beam_size * beam_size]
        best_tokens_per_hypo_flatten, _ = self.step_model.net.Reshape(
            best_tokens_per_hypo,
            [
                'best_tokens_per_hypo_flatten',
                'best_tokens_per_hypo_old_shape',
            ],
            shape=[-1],
        )
        # Flat best_indices address the flattened [hypo, token] grid directly.
        tokens_t_int32 = self.step_model.net.Gather(
            [best_tokens_per_hypo_flatten, best_indices],
            'tokens_t_int32',
        )
        tokens_t = self.step_model.net.Cast(
            tokens_t_int32,
            'tokens_t',
            to=core.DataType.FLOAT,
        )

        def choose_state_per_hypo(state_config):
            # Reorder a caller-defined recurrent state so that row i holds
            # the state of the predecessor chosen for beam slot i.
            state_flattened, _ = self.step_model.net.Reshape(
                state_config.state_link.blob,
                [
                    state_config.state_link.blob,
                    state_config.state_link.blob + '_old_shape',
                ],
                shape=[self.beam_size, -1],
            )
            state_chosen_per_hypo = self.step_model.net.Gather(
                [state_flattened, hypo_t_int32],
                str(state_config.state_link.blob) + '_chosen_per_hypo',
            )
            return self.StateConfig(
                initial_value=state_config.initial_value,
                state_prev_link=state_config.state_prev_link,
                state_link=self.LinkConfig(
                    blob=state_chosen_per_hypo,
                    offset=state_config.state_link.offset,
                    window=state_config.state_link.window,
                )
            )
        state_configs = [choose_state_per_hypo(c) for c in state_configs]

        initial_scores = self.model.param_init_net.ConstantFill(
            [],
            'initial_scores',
            shape=[1],
            value=0.0,
            dtype=core.DataType.FLOAT,
        )
        # NOTE(review): truthiness test on a blob reference — assumes a
        # provided go_token_id blob is always truthy; confirm with callers.
        if go_token_id:
            initial_tokens = self.model.net.Copy(
                [go_token_id],
                'initial_tokens',
            )
        else:
            initial_tokens = self.model.param_init_net.ConstantFill(
                [],
                'initial_tokens',
                shape=[1],
                value=float(self.go_token_id),
                dtype=core.DataType.FLOAT,
            )

        initial_hypo = self.model.param_init_net.ConstantFill(
            [],
            'initial_hypo',
            shape=[1],
            value=0.0,
            dtype=core.DataType.FLOAT,
        )
        # Zero attention state sized like the flattened encoder inputs.
        encoder_inputs_flattened, _ = self.model.net.Reshape(
            inputs,
            ['encoder_inputs_flattened', 'encoder_inputs_old_shape'],
            shape=[-1],
        )
        init_attention = self.model.net.ConstantFill(
            encoder_inputs_flattened,
            'init_attention',
            value=0.0,
            dtype=core.DataType.FLOAT,
        )
        # Beam-search bookkeeping states (scores / tokens / predecessor
        # indices / attention) are appended after caller-defined states.
        state_configs = state_configs + [
            self.StateConfig(
                initial_value=initial_scores,
                state_prev_link=self.LinkConfig(self.scores_t_prev, 0, 1),
                state_link=self.LinkConfig(scores_t, 1, 1),
            ),
            self.StateConfig(
                initial_value=initial_tokens,
                state_prev_link=self.LinkConfig(self.tokens_t_prev, 0, 1),
                state_link=self.LinkConfig(tokens_t, 1, 1),
            ),
            self.StateConfig(
                initial_value=initial_hypo,
                state_prev_link=self.LinkConfig(self.hypo_t_prev, 0, 1),
                state_link=self.LinkConfig(hypo_t, 1, 1),
            ),
            self.StateConfig(
                initial_value=init_attention,
                state_prev_link=self.LinkConfig(self.attention_t_prev, 0, 1),
                state_link=self.LinkConfig(attention_t, 1, 1),
            ),
        ]
        # Dummy sequence input whose length (via input_as_shape on `length`)
        # determines how many steps RecurrentNetwork runs.
        fake_input = self.model.net.ConstantFill(
            length,
            'beam_search_fake_input',
            input_as_shape=True,
            extra_shape=[self.beam_size, 1],
            value=0.0,
            dtype=core.DataType.FLOAT,
        )
        all_inputs = (
            [fake_input] +
            self.step_model.params +
            [state_config.initial_value for state_config in state_configs] +
            data_dependencies
        )
        # Build (internal blob, global state blob, offset, window) links for
        # both the read side (state_prev_link) and write side (state_link)
        # of every recurrent state.
        forward_links = []
        recurrent_states = []
        for state_config in state_configs:
            state_name = str(state_config.state_prev_link.blob) + '_states'
            recurrent_states.append(state_name)
            forward_links.append((
                state_config.state_prev_link.blob,
                state_name,
                state_config.state_prev_link.offset,
                state_config.state_prev_link.window,
            ))
            forward_links.append((
                state_config.state_link.blob,
                state_name,
                state_config.state_link.offset,
                state_config.state_link.window,
            ))
        link_internal, link_external, link_offset, link_window = (
            zip(*forward_links)
        )
        all_outputs = [
            str(s) + '_all'
            for s in [scores_t, tokens_t, hypo_t, attention_t]
        ]
        results = self.model.net.RecurrentNetwork(
            all_inputs,
            all_outputs + ['step_workspaces'],
            param=[all_inputs.index(p) for p in self.step_model.params],
            # Outputs alias the full state tensors of the four beam states.
            alias_src=[
                str(s) + '_states'
                for s in [
                    self.scores_t_prev,
                    self.tokens_t_prev,
                    self.hypo_t_prev,
                    self.attention_t_prev,
                ]
            ],
            alias_dst=all_outputs,
            alias_offset=[0] * 4,
            recurrent_states=recurrent_states,
            initial_recurrent_state_ids=[
                all_inputs.index(state_config.initial_value)
                for state_config in state_configs
            ],
            link_internal=[str(l) for l in link_internal],
            link_external=[str(l) for l in link_external],
            link_offset=link_offset,
            link_window=link_window,
            # Forward-only: no backward links and no gradient outputs.
            backward_link_internal=[],
            backward_link_external=[],
            backward_link_offset=[],
            step_net=self.step_model.net.Proto(),
            timestep=str(self.timestep),
            outputs_with_grads=[],
            enable_rnn_executor=1,
            rnn_executor_debug=0
        )
        # Drop the trailing 'step_workspaces' output.
        score_t_all, tokens_t_all, hypo_t_all, attention_t_all = results[:4]

        output_token_beam_list = self.model.net.Cast(
            tokens_t_all,
            'output_token_beam_list',
            to=core.DataType.INT32,
        )
        output_prev_index_beam_list = self.model.net.Cast(
            hypo_t_all,
            'output_prev_index_beam_list',
            to=core.DataType.INT32,
        )
        output_score_beam_list = self.model.net.Alias(
            score_t_all,
            'output_score_beam_list',
        )
        output_attention_weights_beam_list = self.model.net.Alias(
            attention_t_all,
            'output_attention_weights_beam_list',
        )

        return (
            output_token_beam_list,
            output_prev_index_beam_list,
            output_score_beam_list,
            output_attention_weights_beam_list,
        )