| | |
| | |
| |
|
| |
|
| |
|
| |
|
| |
|
| | from abc import ABCMeta, abstractmethod |
| | import argparse |
| | from future.utils import viewitems |
| | import logging |
| | import numpy as np |
| | import sys |
| |
|
| | from caffe2.python import core, rnn_cell, workspace |
| | from caffe2.python.models.seq2seq.beam_search import BeamSearchForwardOnly |
| | from caffe2.python.models.seq2seq.seq2seq_model_helper import Seq2SeqModelHelper |
| | import caffe2.python.models.seq2seq.seq2seq_util as seq2seq_util |
| |
|
| |
|
# Module-level logger: INFO and above, echoed to stderr so that translations
# printed on stdout stay clean.
logger = logging.getLogger(__name__)
logger.addHandler(logging.StreamHandler(stream=sys.stderr))
logger.setLevel(logging.INFO)
| |
|
| |
|
def _weighted_sum(model, values, weight, output_name):
    """Add a WeightedSum operator that pairs every value with one shared weight.

    Args:
        model: object exposing ``net.WeightedSum``.
        values: sized iterable of blobs to combine.
        weight: blob used as the weight for each entry of ``values``.
        output_name: name of the output blob.

    Returns:
        Whatever ``model.net.WeightedSum`` returns for the interleaved
        ``[value_0, weight, value_1, weight, ...]`` input list.
    """
    shared_weights = [weight] * len(values)
    interleaved = []
    for value, value_weight in zip(values, shared_weights):
        interleaved += [value, value_weight]
    return model.net.WeightedSum(interleaved, output_name)
| |
|
| |
|
class Seq2SeqModelCaffe2EnsembleDecoderBase(metaclass=ABCMeta):
    """Shared checkpoint-loading and word-reward logic for ensemble decoders.

    Subclasses say where each model's checkpoint lives and which DB type it
    uses. ``load_models`` additionally expects the instance to provide
    ``self.models``, ``self.decoder_scope_names`` and ``self.model``.
    """

    @abstractmethod
    def get_model_file(self, model):
        """Return the checkpoint location for one ensemble member."""
        pass

    @abstractmethod
    def get_db_type(self):
        """Return the ``db_type`` string passed to the CreateDB operator."""
        pass

    def build_word_rewards(self, vocab_size, word_reward, unk_reward):
        """Build the per-token reward vector.

        Args:
            vocab_size: length of the returned vector.
            word_reward: reward assigned to every ordinary token.
            unk_reward: extra reward added on top of ``word_reward`` for the
                UNK token.

        Returns:
            A float32 numpy array of length ``vocab_size``. PAD, GO and EOS
            get 0, UNK gets ``word_reward + unk_reward`` and every other
            entry gets ``word_reward``.
        """
        word_rewards = np.full([vocab_size], word_reward, dtype=np.float32)
        word_rewards[seq2seq_util.PAD_ID] = 0
        word_rewards[seq2seq_util.GO_ID] = 0
        word_rewards[seq2seq_util.EOS_ID] = 0
        word_rewards[seq2seq_util.UNK_ID] = word_reward + unk_reward
        return word_rewards

    def load_models(self):
        """Load every ensemble member's parameters into the workspace.

        For each ``(model, scope_name)`` pair this opens the model's
        checkpoint with a CreateDB operator and runs a Load operator whose
        outputs are the parameters of ``self.model`` whose names start with
        ``scope_name``.

        Raises:
            AssertionError: if the CreateDB or Load operator reports failure.
        """
        db_reader = 'reader'
        for model, scope_name in zip(self.models, self.decoder_scope_names):
            model_file = self.get_model_file(model)
            params_for_current_model = [
                param
                for param in self.model.GetAllParams()
                if str(param).startswith(scope_name)
            ]

            # The operators must run outside of ``assert`` statements:
            # ``python -O`` strips asserts, and with them the calls that
            # actually load the checkpoint.
            db_created = workspace.RunOperatorOnce(core.CreateOperator(
                'CreateDB',
                [], [db_reader],
                db=model_file,
                db_type=self.get_db_type(),
            ))
            if not db_created:
                raise AssertionError(
                    'Failed to create db {}'.format(model_file)
                )

            # NOTE(review): add_prefix/strip_prefix presumably rename blobs
            # saved under 'gpu_0/' into this model's scope -- confirm
            # against the Load operator documentation.
            params_loaded = workspace.RunOperatorOnce(core.CreateOperator(
                'Load',
                [db_reader],
                params_for_current_model,
                load_all=1,
                add_prefix=scope_name + '/',
                strip_prefix='gpu_0/',
            ))
            if not params_loaded:
                raise AssertionError(
                    'Failed to load model {} from {}'.format(
                        scope_name, model_file)
                )

            logger.info('Model {} is loaded from a checkpoint {}'.format(
                scope_name, model_file))
| |
|
| |
|
class Seq2SeqModelCaffe2EnsembleDecoder(Seq2SeqModelCaffe2EnsembleDecoderBase):
    """Beam-search decoder over one or more seq2seq models (an ensemble).

    The constructor builds one encoder/decoder per entry of
    ``translate_params['ensemble_models']`` under the scopes ``model0``,
    ``model1``, ... . The per-model log-probabilities and attention weights
    are each combined with a WeightedSum using weight ``1 / N`` before they
    are handed to ``BeamSearchForwardOnly``. Checkpoints are read as
    ``minidb`` files from each model's ``'model_file'`` entry.
    """

    def get_model_file(self, model):
        """Return the checkpoint path stored under ``model['model_file']``."""
        return model['model_file']

    def get_db_type(self):
        """Return the checkpoint database type, ``'minidb'``."""
        return 'minidb'

    def scope(self, scope_name, blob_name):
        """Return ``scope_name + '/' + blob_name``.

        When ``scope_name`` is None the bare ``blob_name`` is returned.
        """
        return (
            scope_name + '/' + blob_name
            if scope_name is not None
            else blob_name
        )

    def _build_decoder(
        self,
        model,
        step_model,
        model_params,
        scope,
        previous_tokens,
        timestep,
        fake_seq_lengths,
    ):
        """Add one ensemble member's encoder and single-step decoder.

        Encoder-side operators go into ``model``; the per-step decoder
        operators go into ``step_model``.

        Args:
            model: model helper that receives the encoder operators.
            step_model: model that receives the per-step decoder operators.
            model_params: dict read for the keys ``'attention'`` (``'none'``
                or ``'regular'``), ``'encoder_type'``,
                ``'encoder_embedding_size'``, ``'decoder_embedding_size'``,
                ``'decoder_layer_configs'`` and ``'decoder_softmax_size'``.
            scope: name scope under which this member's blobs are created.
            previous_tokens: blob gathered from the decoder embeddings to
                form the step input.
            timestep: timestep blob passed to the decoder's ``apply``.
            fake_seq_lengths: blob passed as ``seq_lengths`` to the
                decoder's ``apply``.

        Returns:
            A tuple ``(state_configs, output_log_probs, attention_weights)``:
            the ``BeamSearchForwardOnly.StateConfig`` list for the recurrent
            states, the log of the softmax over the output projection, and
            the attention weights blob (an all-zero stand-in when attention
            is disabled).
        """
        attention_type = model_params['attention']
        assert attention_type in ['none', 'regular']
        use_attention = (attention_type != 'none')

        with core.NameScope(scope):
            encoder_embeddings = seq2seq_util.build_embeddings(
                model=model,
                vocab_size=self.source_vocab_size,
                embedding_size=model_params['encoder_embedding_size'],
                name='encoder_embeddings',
                freeze_embeddings=False,
            )

        (
            encoder_outputs,
            weighted_encoder_outputs,
            final_encoder_hidden_states,
            final_encoder_cell_states,
            encoder_units_per_layer,
        ) = seq2seq_util.build_embedding_encoder(
            model=model,
            encoder_params=model_params['encoder_type'],
            num_decoder_layers=len(model_params['decoder_layer_configs']),
            inputs=self.encoder_inputs,
            input_lengths=self.encoder_lengths,
            vocab_size=self.source_vocab_size,
            embeddings=encoder_embeddings,
            embedding_size=model_params['encoder_embedding_size'],
            use_attention=use_attention,
            num_gpus=0,
            forward_only=True,
            scope=scope,
        )
        with core.NameScope(scope):
            if use_attention:
                # Repeat the encoder outputs beam_size times along axis 1
                # (presumably the batch axis -- TODO confirm) so each beam
                # hypothesis has its own copy.
                encoder_outputs = model.net.Tile(
                    encoder_outputs,
                    'encoder_outputs_tiled',
                    tiles=self.beam_size,
                    axis=1,
                )

            if weighted_encoder_outputs is not None:
                weighted_encoder_outputs = model.net.Tile(
                    weighted_encoder_outputs,
                    'weighted_encoder_outputs_tiled',
                    tiles=self.beam_size,
                    axis=1,
                )

            decoder_embeddings = seq2seq_util.build_embeddings(
                model=model,
                vocab_size=self.target_vocab_size,
                embedding_size=model_params['decoder_embedding_size'],
                name='decoder_embeddings',
                freeze_embeddings=False,
            )
            # Step input: embeddings of the tokens chosen at the previous
            # step.
            embedded_tokens_t_prev = step_model.net.Gather(
                [decoder_embeddings, previous_tokens],
                'embedded_tokens_t_prev',
            )

        # One LSTM cell per decoder layer. The first layer consumes the
        # token embedding; each later layer consumes the previous layer's
        # units.
        decoder_cells = []
        decoder_units_per_layer = []
        for i, layer_config in enumerate(model_params['decoder_layer_configs']):
            num_units = layer_config['num_units']
            decoder_units_per_layer.append(num_units)
            if i == 0:
                input_size = model_params['decoder_embedding_size']
            else:
                input_size = (
                    model_params['decoder_layer_configs'][i - 1]['num_units']
                )

            cell = rnn_cell.LSTMCell(
                forward_only=True,
                input_size=input_size,
                hidden_size=num_units,
                forget_bias=0.0,
                memory_optimization=False,
            )
            decoder_cells.append(cell)

        with core.NameScope(scope):
            # Tile the final encoder states (in place, skipping None
            # entries) beam_size times along axis 1, as was done for the
            # encoder outputs.
            if final_encoder_hidden_states is not None:
                for i in range(len(final_encoder_hidden_states)):
                    if final_encoder_hidden_states[i] is not None:
                        final_encoder_hidden_states[i] = model.net.Tile(
                            final_encoder_hidden_states[i],
                            'final_encoder_hidden_tiled_{}'.format(i),
                            tiles=self.beam_size,
                            axis=1,
                        )
            if final_encoder_cell_states is not None:
                for i in range(len(final_encoder_cell_states)):
                    if final_encoder_cell_states[i] is not None:
                        final_encoder_cell_states[i] = model.net.Tile(
                            final_encoder_cell_states[i],
                            'final_encoder_cell_tiled_{}'.format(i),
                            tiles=self.beam_size,
                            axis=1,
                        )
            initial_states = \
                seq2seq_util.build_initial_rnn_decoder_states(
                    model=model,
                    encoder_units_per_layer=encoder_units_per_layer,
                    decoder_units_per_layer=decoder_units_per_layer,
                    final_encoder_hidden_states=final_encoder_hidden_states,
                    final_encoder_cell_states=final_encoder_cell_states,
                    use_attention=use_attention,
                )

        attention_decoder = seq2seq_util.LSTMWithAttentionDecoder(
            encoder_outputs=encoder_outputs,
            encoder_output_dim=encoder_units_per_layer[-1],
            encoder_lengths=None,
            vocab_size=self.target_vocab_size,
            attention_type=attention_type,
            embedding_size=model_params['decoder_embedding_size'],
            decoder_num_units=decoder_units_per_layer[-1],
            decoder_cells=decoder_cells,
            weighted_encoder_outputs=weighted_encoder_outputs,
            name=scope,
        )
        # The previous-step states are external inputs of the step net,
        # named '<scope>/<state>_prev'.
        states_prev = step_model.net.AddExternalInputs(*[
            '{}/{}_prev'.format(scope, s)
            for s in attention_decoder.get_state_names()
        ])
        decoder_outputs, states = attention_decoder.apply(
            model=step_model,
            input_t=embedded_tokens_t_prev,
            seq_lengths=fake_seq_lengths,
            states=states_prev,
            timestep=timestep,
        )

        # Describe each recurrent state to the beam search: its initial
        # value, the blob read as the previous state (offset 0) and the
        # blob written as the new state (offset 1).
        state_configs = [
            BeamSearchForwardOnly.StateConfig(
                initial_value=initial_state,
                state_prev_link=BeamSearchForwardOnly.LinkConfig(
                    blob=state_prev,
                    offset=0,
                    window=1,
                ),
                state_link=BeamSearchForwardOnly.LinkConfig(
                    blob=state,
                    offset=1,
                    window=1,
                ),
            )
            for initial_state, state_prev, state in zip(
                initial_states,
                states_prev,
                states,
            )
        ]

        with core.NameScope(scope):
            # Collapse the decoder output to 2-D, [-1, output_dim], before
            # the output projection.
            decoder_outputs_flattened, _ = step_model.net.Reshape(
                [decoder_outputs],
                [
                    'decoder_outputs_flattened',
                    'decoder_outputs_and_contexts_combination_old_shape',
                ],
                shape=[-1, attention_decoder.get_output_dim()],
            )
            output_logits = seq2seq_util.output_projection(
                model=step_model,
                decoder_outputs=decoder_outputs_flattened,
                decoder_output_size=attention_decoder.get_output_dim(),
                target_vocab_size=self.target_vocab_size,
                decoder_softmax_size=model_params['decoder_softmax_size'],
            )
            # Log-probabilities are computed as Log(Softmax(logits)).
            output_probs = step_model.net.Softmax(
                output_logits,
                'output_probs',
            )
            output_log_probs = step_model.net.Log(
                output_probs,
                'output_log_probs',
            )
            if use_attention:
                attention_weights = attention_decoder.get_attention_weights()
            else:
                # No attention: build an all-zero stand-in so a weights
                # blob is always returned. ConstantFill takes
                # encoder_inputs as its input (presumably to copy its
                # shape -- TODO confirm); the result is transposed and
                # tiled beam_size times along axis 0.
                attention_weights = step_model.net.ConstantFill(
                    [self.encoder_inputs],
                    'zero_attention_weights_tmp_1',
                    value=0.0,
                )
                attention_weights = step_model.net.Transpose(
                    attention_weights,
                    'zero_attention_weights_tmp_2',
                )
                attention_weights = step_model.net.Tile(
                    attention_weights,
                    'zero_attention_weights_tmp',
                    tiles=self.beam_size,
                    axis=0,
                )

        return (
            state_configs,
            output_log_probs,
            attention_weights,
        )

    def __init__(
        self,
        translate_params,
    ):
        """Build the ensemble beam-search net and create it in the workspace.

        Args:
            translate_params: dict with two keys. ``'ensemble_models'`` is a
                non-empty list of dicts, each holding ``'source_vocab'``,
                ``'target_vocab'`` and ``'model_params'``; all entries must
                share the same vocabularies. ``'decoding_params'`` is a dict
                holding ``'beam_size'``, ``'word_reward'`` and
                ``'unk_reward'``.

        Model parameters are not loaded here; call ``load_models()`` for
        that.
        """
        self.models = translate_params['ensemble_models']
        decoding_params = translate_params['decoding_params']
        self.beam_size = decoding_params['beam_size']

        # Every ensemble member must use the same source and target
        # vocabularies.
        assert len(self.models) > 0
        source_vocab = self.models[0]['source_vocab']
        target_vocab = self.models[0]['target_vocab']
        for model in self.models:
            assert model['source_vocab'] == source_vocab
            assert model['target_vocab'] == target_vocab

        self.source_vocab_size = len(source_vocab)
        self.target_vocab_size = len(target_vocab)

        # Scope names 'model0', 'model1', ... are also the prefixes that
        # load_models() uses to select each member's parameters.
        self.decoder_scope_names = [
            'model{}'.format(i) for i in range(len(self.models))
        ]

        self.model = Seq2SeqModelHelper(init_params=True)

        # External inputs; decode() feeds them before every run.
        self.encoder_inputs = self.model.net.AddExternalInput('encoder_inputs')
        self.encoder_lengths = self.model.net.AddExternalInput(
            'encoder_lengths'
        )
        self.max_output_seq_len = self.model.net.AddExternalInput(
            'max_output_seq_len'
        )

        # Constant placeholder lengths, one per beam entry. The large value
        # presumably keeps the decoder step from masking any hypothesis by
        # length -- TODO confirm.
        fake_seq_lengths = self.model.param_init_net.ConstantFill(
            [],
            'fake_seq_lengths',
            shape=[self.beam_size],
            value=100000,
            dtype=core.DataType.INT32,
        )

        beam_decoder = BeamSearchForwardOnly(
            beam_size=self.beam_size,
            model=self.model,
            go_token_id=seq2seq_util.GO_ID,
            eos_token_id=seq2seq_util.EOS_ID,
        )
        step_model = beam_decoder.get_step_model()

        # Build each member and collect its recurrent-state configs,
        # log-probabilities and attention weights.
        state_configs = []
        output_log_probs = []
        attention_weights = []
        for model, scope_name in zip(
            self.models,
            self.decoder_scope_names,
        ):
            (
                state_configs_per_decoder,
                output_log_probs_per_decoder,
                attention_weights_per_decoder,
            ) = self._build_decoder(
                model=self.model,
                step_model=step_model,
                model_params=model['model_params'],
                scope=scope_name,
                previous_tokens=beam_decoder.get_previous_tokens(),
                timestep=beam_decoder.get_timestep(),
                fake_seq_lengths=fake_seq_lengths,
            )
            state_configs.extend(state_configs_per_decoder)
            output_log_probs.append(output_log_probs_per_decoder)
            if attention_weights_per_decoder is not None:
                attention_weights.append(attention_weights_per_decoder)

        # Combine the attention weights with weight 1 / N each, where N is
        # the number of collected attention blobs.
        assert len(attention_weights) > 0
        num_decoders_with_attention_blob = (
            self.model.param_init_net.ConstantFill(
                [],
                'num_decoders_with_attention_blob',
                value=1 / float(len(attention_weights)),
                shape=[1],
            )
        )
        attention_weights_average = _weighted_sum(
            model=step_model,
            values=attention_weights,
            weight=num_decoders_with_attention_blob,
            output_name='attention_weights_average',
        )

        # Same 1 / N weighting for the log-probabilities.
        num_decoders_blob = self.model.param_init_net.ConstantFill(
            [],
            'num_decoders_blob',
            value=1 / float(len(output_log_probs)),
            shape=[1],
        )
        output_log_probs_average = _weighted_sum(
            model=step_model,
            values=output_log_probs,
            weight=num_decoders_blob,
            output_name='output_log_probs_average',
        )
        # Created as zeros here; the real rewards are fed into this blob
        # below, after param_init_net has run.
        word_rewards = self.model.param_init_net.ConstantFill(
            [],
            'word_rewards',
            shape=[self.target_vocab_size],
            value=0.0,
            dtype=core.DataType.FLOAT,
        )
        (
            self.output_token_beam_list,
            self.output_prev_index_beam_list,
            self.output_score_beam_list,
            self.output_attention_weights_beam_list,
        ) = beam_decoder.apply(
            inputs=self.encoder_inputs,
            length=self.max_output_seq_len,
            log_probs=output_log_probs_average,
            attentions=attention_weights_average,
            state_configs=state_configs,
            data_dependencies=[],
            word_rewards=word_rewards,
        )

        # Initialise the constants, then overwrite the zero word_rewards
        # blob with the configured rewards.
        workspace.RunNetOnce(self.model.param_init_net)
        workspace.FeedBlob(
            'word_rewards',
            self.build_word_rewards(
                vocab_size=self.target_vocab_size,
                word_reward=translate_params['decoding_params']['word_reward'],
                unk_reward=translate_params['decoding_params']['unk_reward'],
            )
        )

        workspace.CreateNet(
            self.model.net,
            input_blobs=[
                str(self.encoder_inputs),
                str(self.encoder_lengths),
                str(self.max_output_seq_len),
            ],
        )

        logger.info('Params created: ')
        for param in self.model.params:
            logger.info(param)

    def decode(self, numberized_input, max_output_seq_len):
        """Translate one numberized sentence with beam search.

        Args:
            numberized_input: sequence of source token ids.
            max_output_seq_len: number of beam-search steps to run.

        Returns:
            A tuple ``(output, attention_weights_per_token, best_score)``.
            ``output`` is the list of token ids of the chosen hypothesis in
            generation order. ``attention_weights_per_token`` holds, per
            output token, the attention weights in source order, truncated
            to ``len(numberized_input)``. ``best_score`` is the negated
            beam score of the chosen hypothesis.
        """
        # The source sentence is fed in reverse order, one token per row.
        workspace.FeedBlob(
            self.encoder_inputs,
            np.array([
                [token_id] for token_id in reversed(numberized_input)
            ]).astype(dtype=np.int32),
        )
        workspace.FeedBlob(
            self.encoder_lengths,
            np.array([len(numberized_input)]).astype(dtype=np.int32),
        )
        workspace.FeedBlob(
            self.max_output_seq_len,
            np.array([max_output_seq_len]).astype(dtype=np.int64),
        )

        workspace.RunNet(self.model.net)

        num_steps = max_output_seq_len
        score_beam_list = workspace.FetchBlob(self.output_score_beam_list)
        token_beam_list = (
            workspace.FetchBlob(self.output_token_beam_list)
        )
        prev_index_beam_list = (
            workspace.FetchBlob(self.output_prev_index_beam_list)
        )

        attention_weights_beam_list = (
            workspace.FetchBlob(self.output_attention_weights_beam_list)
        )
        # Pick the best-scoring hypothesis among those that either emitted
        # EOS at step i or are still alive at the last step. The first
        # hypothesis of the last step is the starting candidate.
        best_indices = (num_steps, 0)
        for i in range(num_steps + 1):
            for hyp_index in range(self.beam_size):
                if (
                    (
                        token_beam_list[i][hyp_index][0] ==
                        seq2seq_util.EOS_ID or
                        i == num_steps
                    ) and
                    (
                        score_beam_list[i][hyp_index][0] >
                        score_beam_list[best_indices[0]][best_indices[1]][0]
                    )
                ):
                    best_indices = (i, hyp_index)

        # Walk the back-pointers from the chosen hypothesis down to step 1,
        # collecting tokens and attention weights last-to-first.
        i, hyp_index = best_indices
        output = []
        attention_weights_per_token = []
        best_score = -score_beam_list[i][hyp_index][0]
        while i > 0:
            output.append(token_beam_list[i][hyp_index][0])
            attention_weights_per_token.append(
                attention_weights_beam_list[i][hyp_index]
            )
            hyp_index = prev_index_beam_list[i][hyp_index][0]
            i -= 1

        attention_weights_per_token = reversed(attention_weights_per_token)
        # The encoder inputs were fed reversed, so flip each weight vector
        # back to source order and drop entries beyond the input length.
        attention_weights_per_token = [
            list(reversed(attention_weights))[:len(numberized_input)]
            for attention_weights in attention_weights_per_token
        ]
        output = list(reversed(output))
        return output, attention_weights_per_token, best_score
| |
|
| |
|
def run_seq2seq_beam_decoder(args, model_params, decoding_params):
    """Translate sentences read from stdin using a single checkpoint.

    Vocabularies are generated from ``args.source_corpus`` and
    ``args.target_corpus`` with ``args.unk_threshold``. A one-model ensemble
    decoder is built and loaded from ``args.checkpoint``. Each stdin line is
    then decoded with a length limit of ``2 * len(source) + 5`` and printed
    as space-separated target tokens.

    Args:
        args: parsed command-line namespace (``source_corpus``,
            ``target_corpus``, ``unk_threshold`` and ``checkpoint`` are
            read).
        model_params: model configuration dict passed to the decoder.
        decoding_params: decoding configuration dict passed to the decoder.
    """
    source_vocab = seq2seq_util.gen_vocab(
        args.source_corpus, args.unk_threshold)
    logger.info('Source vocab size {}'.format(len(source_vocab)))

    target_vocab = seq2seq_util.gen_vocab(
        args.target_corpus, args.unk_threshold)
    # Reverse mapping, token id -> token, for rendering the output.
    id_to_target_token = {}
    for token, token_id in viewitems(target_vocab):
        id_to_target_token[token_id] = token
    logger.info('Target vocab size {}'.format(len(target_vocab)))

    single_model = dict(
        source_vocab=source_vocab,
        target_vocab=target_vocab,
        model_params=model_params,
        model_file=args.checkpoint,
    )
    translate_params = dict(
        ensemble_models=[single_model],
        decoding_params=decoding_params,
    )
    decoder = Seq2SeqModelCaffe2EnsembleDecoder(
        translate_params=translate_params,
    )
    decoder.load_models()

    for line in sys.stdin:
        source_ids = seq2seq_util.get_numberized_sentence(line, source_vocab)
        length_limit = 2 * len(source_ids) + 5
        translation, _alignment, _score = decoder.decode(
            source_ids, length_limit)
        print(' '.join(id_to_target_token[tid] for tid in translation))
| |
|
| |
|
def main():
    """Parse command-line options and run the stdin beam-search translator.

    Builds the per-layer encoder and decoder configurations from the parsed
    arguments and passes them, with the decoding options, to
    ``run_seq2seq_beam_decoder``.
    """
    parser = argparse.ArgumentParser(
        description='Caffe2: Seq2Seq Translation',
    )
    parser.add_argument('--source-corpus', type=str, default=None,
                        help='Path to source corpus in a text file format. Each '
                        'line in the file should contain a single sentence',
                        required=True)
    parser.add_argument('--target-corpus', type=str, default=None,
                        help='Path to target corpus in a text file format',
                        required=True)
    parser.add_argument('--unk-threshold', type=int, default=50,
                        help='Threshold frequency under which token becomes '
                        'labeled unknown token')

    parser.add_argument('--use-bidirectional-encoder', action='store_true',
                        help='Set flag to use bidirectional recurrent network '
                        'in encoder')
    parser.add_argument('--use-attention', action='store_true',
                        help='Set flag to use seq2seq with attention model')
    parser.add_argument('--encoder-cell-num-units', type=int, default=512,
                        help='Number of cell units per encoder layer')
    parser.add_argument('--encoder-num-layers', type=int, default=2,
                        help='Number encoder layers')
    parser.add_argument('--decoder-cell-num-units', type=int, default=512,
                        help='Number of cell units in the decoder layer')
    parser.add_argument('--decoder-num-layers', type=int, default=2,
                        help='Number decoder layers')
    parser.add_argument('--encoder-embedding-size', type=int, default=256,
                        help='Size of embedding in the encoder layer')
    parser.add_argument('--decoder-embedding-size', type=int, default=512,
                        help='Size of embedding in the decoder layer')
    parser.add_argument('--decoder-softmax-size', type=int, default=None,
                        help='Size of softmax layer in the decoder')

    parser.add_argument('--beam-size', type=int, default=6,
                        help='Size of beam for the decoder')
    parser.add_argument('--word-reward', type=float, default=0.0,
                        help='Reward per each word generated.')
    parser.add_argument('--unk-reward', type=float, default=0.0,
                        help='Reward per each UNK token generated. '
                        'Typically should be negative.')

    parser.add_argument('--checkpoint', type=str, default=None,
                        help='Path to checkpoint', required=True)

    args = parser.parse_args()

    encoder_layer_configs = [
        dict(
            num_units=args.encoder_cell_num_units,
        ),
    ] * args.encoder_num_layers

    if args.use_bidirectional_encoder:
        assert args.encoder_cell_num_units % 2 == 0
        # Floor division keeps num_units an int. Under Python 3, `/= 2`
        # turned it into a float, which is not a valid layer size.
        # NOTE(review): the list above repeats one dict object, so this
        # halves num_units for every encoder layer, not just layer 0. It is
        # left as is so the layer sizes still match existing checkpoints --
        # confirm against the training script before changing.
        encoder_layer_configs[0]['num_units'] //= 2

    decoder_layer_configs = [
        dict(
            num_units=args.decoder_cell_num_units,
        ),
    ] * args.decoder_num_layers

    run_seq2seq_beam_decoder(
        args,
        model_params=dict(
            attention=('regular' if args.use_attention else 'none'),
            decoder_layer_configs=decoder_layer_configs,
            encoder_type=dict(
                encoder_layer_configs=encoder_layer_configs,
                use_bidirectional_encoder=args.use_bidirectional_encoder,
            ),
            encoder_embedding_size=args.encoder_embedding_size,
            decoder_embedding_size=args.decoder_embedding_size,
            decoder_softmax_size=args.decoder_softmax_size,
        ),
        decoding_params=dict(
            beam_size=args.beam_size,
            word_reward=args.word_reward,
            unk_reward=args.unk_reward,
        ),
    )
| |
|
| |
|
# Run the command-line translator only when this file is executed as a
# script, not when it is imported.
if __name__ == '__main__':
    main()
| |
|