| """Input utils for virtual adversarial text classification."""
|
|
|
| from __future__ import absolute_import
|
| from __future__ import division
|
| from __future__ import print_function
|
|
|
| import os
|
|
|
|
|
|
|
| import tensorflow as tf
|
|
|
| from data import data_utils
|
|
|
|
|
class VatxtInput(object):
  """Wrapper around NextQueuedSequenceBatch."""

  def __init__(self,
               batch,
               state_name=None,
               tokens=None,
               num_states=0,
               eos_id=None):
    """Construct VatxtInput.

    Args:
      batch: NextQueuedSequenceBatch.
      state_name: str, name of state to fetch and save.
      tokens: int Tensor, tokens. Defaults to batch's F_TOKEN_ID sequence.
      num_states: int, the number of LSTM states to store.
      eos_id: int, id of the end-of-sequence token.
    """
    self._batch = batch
    self._state_name = state_name
    self._tokens = (tokens if tokens is not None else
                    batch.sequences[data_utils.SequenceWrapper.F_TOKEN_ID])
    self._num_states = num_states

    w = batch.sequences[data_utils.SequenceWrapper.F_WEIGHT]
    self._weights = w

    l = batch.sequences[data_utils.SequenceWrapper.F_LABEL]
    self._labels = l

    # EOS weights are 1.0 at end-of-sequence positions and 0.0 everywhere
    # else, so per-timestep losses (e.g. the KL term in virtual adversarial
    # training) can be restricted to sequence ends.
    self._eos_weights = None
    if eos_id is not None:  # explicit None check so an eos_id of 0 still works
      ew = tf.cast(tf.equal(self._tokens, eos_id), tf.float32)
      self._eos_weights = ew
|
  @property
  def tokens(self):
    return self._tokens

  @property
  def weights(self):
    return self._weights

  @property
  def eos_weights(self):
    return self._eos_weights

  @property
  def labels(self):
    return self._labels

  @property
  def length(self):
    return self._batch.length

  @property
  def state_name(self):
    return self._state_name
|
  @property
  def state(self):
    """LSTM state fetched from the batch, as one LSTMStateTuple per layer."""
    state_names = _get_tuple_state_names(self._num_states, self._state_name)
    return tuple([
        tf.contrib.rnn.LSTMStateTuple(
            self._batch.state(c_name), self._batch.state(h_name))
        for c_name, h_name in state_names
    ])
|
  def save_state(self, value):
    """Saves the given per-layer (c, h) states; returns a grouped save op."""
    state_names = _get_tuple_state_names(self._num_states, self._state_name)
    save_ops = []
    for (c_state, h_state), (c_name, h_name) in zip(value, state_names):
      save_ops.append(self._batch.save_state(c_name, c_state))
      save_ops.append(self._batch.save_state(h_name, h_state))
    return tf.group(*save_ops)
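

# A typical truncated-BPTT step wires `state` and `save_state` together
# (sketch only; `cell`, `embedded`, and `loss` are assumed to be built
# elsewhere and are not part of this module):
#
#   inp = inputs(data_dir='...', state_size=1024, num_layers=1)  # VatxtInput
#   outputs, next_state = tf.nn.dynamic_rnn(
#       cell, embedded, initial_state=inp.state)
#   with tf.control_dependencies([inp.save_state(next_state)]):
#     loss = tf.identity(loss)  # persist LSTM state before the train step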
|
|
def _get_tuple_state_names(num_states, base_name):
  """Returns state names for use with LSTM tuple state."""
  state_names = [('{}_{}_c'.format(i, base_name),
                  '{}_{}_h'.format(i, base_name)) for i in range(num_states)]
  return state_names
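

# For example, _get_tuple_state_names(2, 'lstm') returns
# [('0_lstm_c', '0_lstm_h'), ('1_lstm_c', '1_lstm_h')].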
|
|
def _split_bidir_tokens(batch):
  """Splits bidirectional tokens into forward and reverse sequences."""
  tokens = batch.sequences[data_utils.SequenceWrapper.F_TOKEN_ID]
  # Tokens are packed along the last axis, shape [batch, time, 2]; split it
  # and squeeze it away to recover the forward and reverse token tensors.
  forward, reverse = [
      tf.squeeze(t, axis=[2]) for t in tf.split(tokens, 2, axis=2)
  ]
  return forward, reverse
|
|
def _filenames_for_data_spec(phase, bidir, pretrain, use_seq2seq):
  """Returns input filenames for a configuration.

  Args:
    phase: str, one of {'train', 'test', 'valid'}.
    bidir: bool, bidirectional model.
    pretrain: bool, pretraining or classification.
    use_seq2seq: bool, seq2seq data; only valid if pretrain=True.

  Returns:
    Tuple of filenames.

  Raises:
    ValueError: if an invalid combination of arguments is provided that does
      not map to any data files (e.g. pretrain=False, use_seq2seq=True).
  """
  data_spec = (phase, bidir, pretrain, use_seq2seq)
  data_specs = {
      ('train', True, True, False): (data_utils.TRAIN_LM,
                                     data_utils.TRAIN_REV_LM),
      ('train', True, False, False): (data_utils.TRAIN_BD_CLASS,),
      ('train', False, True, False): (data_utils.TRAIN_LM,),
      ('train', False, True, True): (data_utils.TRAIN_SA,),
      ('train', False, False, False): (data_utils.TRAIN_CLASS,),
      ('test', True, True, False): (data_utils.TEST_LM,
                                    data_utils.TRAIN_REV_LM),
      ('test', True, False, False): (data_utils.TEST_BD_CLASS,),
      ('test', False, True, False): (data_utils.TEST_LM,),
      ('test', False, True, True): (data_utils.TEST_SA,),
      ('test', False, False, False): (data_utils.TEST_CLASS,),
      ('valid', True, False, False): (data_utils.VALID_BD_CLASS,),
      ('valid', False, False, False): (data_utils.VALID_CLASS,),
  }
  if data_spec not in data_specs:
    raise ValueError(
        'Data specification (phase, bidir, pretrain, use_seq2seq) %s not '
        'supported' % str(data_spec))
  return data_specs[data_spec]
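

# For example, unidirectional LM pretraining reads a single file:
# _filenames_for_data_spec('train', False, True, False) == (data_utils.TRAIN_LM,),
# while bidirectional pretraining reads both forward and reverse LM files.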
|
|
def _read_single_sequence_example(file_list, tokens_shape=None):
  """Reads and parses SequenceExamples from TFRecord-encoded file_list."""
  tf.logging.info('Constructing TFRecordReader from files: %s', file_list)
  file_queue = tf.train.string_input_producer(file_list)
  reader = tf.TFRecordReader()
  seq_key, serialized_record = reader.read(file_queue)
  ctx, sequence = tf.parse_single_sequence_example(
      serialized_record,
      sequence_features={
          data_utils.SequenceWrapper.F_TOKEN_ID:
              tf.FixedLenSequenceFeature(tokens_shape or [], dtype=tf.int64),
          data_utils.SequenceWrapper.F_LABEL:
              tf.FixedLenSequenceFeature([], dtype=tf.int64),
          data_utils.SequenceWrapper.F_WEIGHT:
              tf.FixedLenSequenceFeature([], dtype=tf.float32),
      })
  return seq_key, ctx, sequence
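

# Each parsed sequence feature has a leading time dimension: token ids come
# back as [sequence_length] (or [sequence_length, 2] when tokens_shape=[2]
# for bidirectional input); labels and weights are both [sequence_length].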
|
|
def _read_and_batch(data_dir,
                    fname,
                    state_name,
                    state_size,
                    num_layers,
                    unroll_steps,
                    batch_size,
                    bidir_input=False):
  """Reads inputs from data_dir/fname and batches them with saved LSTM state.

  Args:
    data_dir: str, directory containing TFRecord files of SequenceExample.
    fname: str, input file name.
    state_name: string, key for saved state of LSTM.
    state_size: int, size of LSTM state.
    num_layers: int, the number of layers in the LSTM.
    unroll_steps: int, number of timesteps to unroll for TBPTT.
    batch_size: int, batch size.
    bidir_input: bool, whether the input is bidirectional. If True, creates 2
      states, state_name and state_name + '_reverse'.

  Returns:
    Instance of NextQueuedSequenceBatch.

  Raises:
    ValueError: if the file for the input specification is not found.
  """
  data_path = os.path.join(data_dir, fname)
  if not tf.gfile.Exists(data_path):
    raise ValueError('Failed to find file: %s' % data_path)

  # Bidirectional data packs forward and reverse tokens along a final axis.
  tokens_shape = [2] if bidir_input else []
  seq_key, ctx, sequence = _read_single_sequence_example(
      [data_path], tokens_shape=tokens_shape)

  # Zero-initialize a c and h state of shape [state_size] per LSTM layer.
  state_names = _get_tuple_state_names(num_layers, state_name)
  initial_states = {}
  for c_state, h_state in state_names:
    initial_states[c_state] = tf.zeros(state_size)
    initial_states[h_state] = tf.zeros(state_size)
  if bidir_input:
    rev_state_names = _get_tuple_state_names(num_layers,
                                             '{}_reverse'.format(state_name))
    for rev_c_state, rev_h_state in rev_state_names:
      initial_states[rev_c_state] = tf.zeros(state_size)
      initial_states[rev_h_state] = tf.zeros(state_size)
  batch = tf.contrib.training.batch_sequences_with_states(
      input_key=seq_key,
      input_sequences=sequence,
      input_context=ctx,
      input_length=tf.shape(sequence['token_id'])[0],
      initial_states=initial_states,
      num_unroll=unroll_steps,
      batch_size=batch_size,
      allow_small_batch=False,
      num_threads=4,
      capacity=batch_size * 10,
      make_keys_unique=True,
      make_keys_unique_seed=29392)
  return batch
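

# batch_sequences_with_states splits each sequence into num_unroll-sized
# segments and carries the named states across consecutive segments of the
# same sequence, which is what enables truncated backprop through time here.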
|
|
def inputs(data_dir=None,
           phase='train',
           bidir=False,
           pretrain=False,
           use_seq2seq=False,
           state_name='lstm',
           state_size=None,
           num_layers=0,
           batch_size=32,
           unroll_steps=100,
           eos_id=None):
  """Inputs for text model.

  Args:
    data_dir: str, directory containing TFRecord files of SequenceExample.
    phase: str, which dataset to read, one of {'train', 'valid', 'test'}.
    bidir: bool, bidirectional LSTM.
    pretrain: bool, whether to read pretraining data or classification data.
    use_seq2seq: bool, whether to read seq2seq data or the language model data.
    state_name: string, key for saved state of LSTM.
    state_size: int, size of LSTM state.
    num_layers: int, the number of LSTM layers.
    batch_size: int, batch size.
    unroll_steps: int, number of timesteps to unroll for TBPTT.
    eos_id: int, id of the end-of-sequence token. Used for the KL weights in
      virtual adversarial training.

  Returns:
    Instance of VatxtInput, or a pair of instances (forward and reverse) if
    bidir=True.
  """
  with tf.name_scope('inputs'):
    filenames = _filenames_for_data_spec(phase, bidir, pretrain, use_seq2seq)

    if bidir and pretrain:
      # Bidirectional pretraining reads separate forward and reverse files,
      # each with its own saved LSTM state.
      forward_fname, reverse_fname = filenames
      forward_batch = _read_and_batch(data_dir, forward_fname, state_name,
                                      state_size, num_layers, unroll_steps,
                                      batch_size)
      state_name_rev = state_name + '_reverse'
      reverse_batch = _read_and_batch(data_dir, reverse_fname, state_name_rev,
                                      state_size, num_layers, unroll_steps,
                                      batch_size)
      forward_input = VatxtInput(
          forward_batch,
          state_name=state_name,
          num_states=num_layers,
          eos_id=eos_id)
      reverse_input = VatxtInput(
          reverse_batch,
          state_name=state_name_rev,
          num_states=num_layers,
          eos_id=eos_id)
      return forward_input, reverse_input
    elif bidir:
      # Bidirectional classification reads a single file with forward and
      # reverse tokens packed together; split them into two inputs that
      # share one underlying batch.
      fname, = filenames
      batch = _read_and_batch(
          data_dir,
          fname,
          state_name,
          state_size,
          num_layers,
          unroll_steps,
          batch_size,
          bidir_input=True)
      forward_tokens, reverse_tokens = _split_bidir_tokens(batch)
      forward_input = VatxtInput(
          batch,
          state_name=state_name,
          tokens=forward_tokens,
          num_states=num_layers)
      reverse_input = VatxtInput(
          batch,
          state_name=state_name + '_reverse',
          tokens=reverse_tokens,
          num_states=num_layers)
      return forward_input, reverse_input
    else:
      fname, = filenames
      batch = _read_and_batch(
          data_dir,
          fname,
          state_name,
          state_size,
          num_layers,
          unroll_steps,
          batch_size,
          bidir_input=False)
      return VatxtInput(
          batch, state_name=state_name, num_states=num_layers, eos_id=eos_id)
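

# Example usage (sketch only; the flag values here are hypothetical):
#
#   train_input = inputs(
#       data_dir='/tmp/imdb_tfrecords',
#       phase='train',
#       pretrain=True,
#       state_size=1024,
#       num_layers=1,
#       batch_size=64,
#       unroll_steps=100,
#       eos_id=2)
#   # train_input.tokens: int64 Tensor of shape [batch_size, unroll_steps]
#   # train_input.state / train_input.save_state: persist LSTM state
#   # across successive unrolled segments of the same sequences.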