|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """CNN-BiLSTM sentence encoder."""
|
|
|
| from __future__ import absolute_import
|
| from __future__ import division
|
| from __future__ import print_function
|
|
|
| import tensorflow as tf
|
| from base import embeddings
|
| from model import model_helpers
|
|
|
|
|
class Encoder(object):
  """CNN-BiLSTM sentence encoder.

  Builds token-level representations in three stages:
    1. Word representations: pretrained word embeddings (with dropout and a
       learned scalar rescaling), optionally concatenated with max-pooled
       character-CNN features.
    2. A "unidirectional" multi-layer BiLSTM over the word representations;
       forward/backward outputs are exposed separately (uni_fw, uni_bw) and
       concatenated (uni_reprs).
    3. A stack of further BiLSTM layers over uni_reprs, producing the final
       bidirectional representations (bi_fw, bi_bw, bi_reprs).
  """

  def __init__(self, config, inputs, pretrained_embeddings):
    """Constructs the encoder graph.

    Args:
      config: configuration object; attributes read here are use_chars,
        char_embedding_size, char_cnn_filter_widths, char_cnn_n_filters,
        word_embedding_size, unidirectional_sizes, bidirectional_sizes, and
        projection_size.
      inputs: input tensors/placeholders; attributes read here are words,
        chars, lengths, and keep_prob.
      pretrained_embeddings: initializer (e.g. a numpy array) for the word
        embedding matrix.
    """
    self._config = config
    self._inputs = inputs

    self.word_reprs = self._get_word_reprs(pretrained_embeddings)
    self.uni_fw, self.uni_bw = self._get_unidirectional_reprs(self.word_reprs)
    self.uni_reprs = tf.concat([self.uni_fw, self.uni_bw], axis=-1)
    self.bi_fw, self.bi_bw, self.bi_reprs = self._get_bidirectional_reprs(
        self.uni_reprs)

  def _get_word_reprs(self, pretrained_embeddings):
    """Returns per-token input representations.

    Word embeddings are looked up from a matrix initialized with
    pretrained_embeddings, passed through dropout, and multiplied by a
    learned scalar. If config.use_chars is set, character-CNN features
    (one conv + relu + max-over-characters per filter width) are
    concatenated onto the word embeddings.

    Args:
      pretrained_embeddings: initializer for the word embedding matrix.

    Returns:
      A [batch, words, repr_size] float tensor of token representations.
    """
    with tf.variable_scope('word_embeddings'):
      word_embedding_matrix = tf.get_variable(
          'word_embedding_matrix', initializer=pretrained_embeddings)
      word_embeddings = tf.nn.embedding_lookup(
          word_embedding_matrix, self._inputs.words)
      word_embeddings = tf.nn.dropout(word_embeddings, self._inputs.keep_prob)
      # Learned scalar rescaling of the (pretrained) embeddings.
      word_embeddings *= tf.get_variable('emb_scale', initializer=1.0)

    if not self._config.use_chars:
      return word_embeddings

    with tf.variable_scope('char_embeddings'):
      char_embedding_matrix = tf.get_variable(
          'char_embeddings',
          shape=[embeddings.NUM_CHARS, self._config.char_embedding_size])
      char_embeddings = tf.nn.embedding_lookup(char_embedding_matrix,
                                               self._inputs.chars)
      shape = tf.shape(char_embeddings)
      # Flatten [batch, words, chars, emb] -> [batch*words, chars, emb] so
      # the 1-D convolutions run over the characters within each word.
      char_embeddings = tf.reshape(
          char_embeddings,
          shape=[-1, shape[-2], self._config.char_embedding_size])
      char_reprs = []
      for filter_width in self._config.char_cnn_filter_widths:
        conv = tf.layers.conv1d(
            char_embeddings, self._config.char_cnn_n_filters, filter_width)
        conv = tf.nn.relu(conv)
        # Max-pool over character positions, then apply dropout.
        conv = tf.nn.dropout(tf.reduce_max(conv, axis=1),
                             self._inputs.keep_prob)
        # Un-flatten back to [batch, words, n_filters].
        conv = tf.reshape(conv, shape=[-1, shape[1],
                                       self._config.char_cnn_n_filters])
        char_reprs.append(conv)
      return tf.concat([word_embeddings] + char_reprs, axis=-1)

  def _get_unidirectional_reprs(self, word_reprs):
    """Runs the first (multi-layer) BiLSTM over the word representations.

    Args:
      word_reprs: [batch, words, repr_size] token representations.

    Returns:
      A (forward_outputs, backward_outputs) pair of
      [batch, words, hidden_size] tensors.
    """
    with tf.variable_scope('unidirectional_reprs'):
      word_lstm_input_size = (
          self._config.word_embedding_size if not self._config.use_chars else
          (self._config.word_embedding_size +
           len(self._config.char_cnn_filter_widths)
           * self._config.char_cnn_n_filters))
      # dynamic_rnn needs a static final dimension; the reshapes above leave
      # it unknown, so pin it here.
      word_reprs.set_shape([None, None, word_lstm_input_size])
      (outputs_fw, outputs_bw), _ = tf.nn.bidirectional_dynamic_rnn(
          model_helpers.multi_lstm_cell(self._config.unidirectional_sizes,
                                        self._inputs.keep_prob,
                                        self._config.projection_size),
          model_helpers.multi_lstm_cell(self._config.unidirectional_sizes,
                                        self._inputs.keep_prob,
                                        self._config.projection_size),
          word_reprs,
          dtype=tf.float32,
          sequence_length=self._inputs.lengths,
          scope='unilstm'
      )
      return outputs_fw, outputs_bw

  def _get_bidirectional_reprs(self, uni_reprs):
    """Runs a stack of BiLSTM layers over the unidirectional representations.

    Args:
      uni_reprs: [batch, words, size] concatenated fw/bw representations.

    Returns:
      A (forward_outputs, backward_outputs, concatenated_outputs) triple
      from the last layer of the stack. The forward/backward entries are
      None when config.bidirectional_sizes is empty.
    """
    with tf.variable_scope('bidirectional_reprs'):
      current_outputs = uni_reprs
      outputs_fw, outputs_bw = None, None
      for i, size in enumerate(self._config.bidirectional_sizes):
        # Each layer needs its own variable scope: reusing 'bilstm' for every
        # layer raises a variable-already-exists error in TF1 when more than
        # one bidirectional layer is configured. Layer 0 keeps the original
        # 'bilstm' name so existing checkpoints still load.
        layer_scope = 'bilstm' if i == 0 else 'bilstm_%d' % i
        (outputs_fw, outputs_bw), _ = tf.nn.bidirectional_dynamic_rnn(
            model_helpers.lstm_cell(size, self._inputs.keep_prob,
                                    self._config.projection_size),
            model_helpers.lstm_cell(size, self._inputs.keep_prob,
                                    self._config.projection_size),
            current_outputs,
            dtype=tf.float32,
            sequence_length=self._inputs.lengths,
            scope=layer_scope
        )
        current_outputs = tf.concat([outputs_fw, outputs_bw], axis=-1)
      return outputs_fw, outputs_bw, current_outputs
|
|
|