|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Create TFRecord files of SequenceExample protos from dataset.
|
|
|
| Constructs 3 datasets:
|
| 1. Labeled data for the LSTM classification model, optionally with label gain.
|
| "*_classification.tfrecords" (for both unidirectional and bidirectional
|
| models).
|
| 2. Data for the unsupervised LM-LSTM model that predicts the next token.
|
| "*_lm.tfrecords" (generates forward and reverse data).
|
| 3. Data for the unsupervised SA-LSTM model that uses Seq2Seq.
|
| "*_sa.tfrecords".
|
| """
|
| from __future__ import absolute_import
|
| from __future__ import division
|
| from __future__ import print_function
|
|
|
| import os
|
| import string
|
|
|
|
|
|
|
| import tensorflow as tf
|
|
|
| from data import data_utils
|
| from data import document_generators
|
|
|
# Short aliases used throughout this module.
data = data_utils

flags = tf.app.flags

FLAGS = flags.FLAGS
|
|
|
|
|
# Command-line flags controlling input vocabulary and output location.
flags.DEFINE_string('vocab_file', '', 'Path to the vocabulary file. Defaults '
                    'to FLAGS.output_dir/vocab.txt.')

flags.DEFINE_string('output_dir', '', 'Path to save tfrecords.')


# When enabled, training-set classification examples carry the label at every
# timestep with linearly increasing weight (validation examples are excluded;
# see generate_training_data).
flags.DEFINE_boolean('label_gain', False,
                     'Enable linear label gain. If True, sentiment label will '
                     'be included at each timestep with linear weight '
                     'increase.')
|
|
|
|
|
def build_shuffling_tf_record_writer(fname):
  """Returns a shuffling TFRecord writer for `fname` under FLAGS.output_dir."""
  output_path = os.path.join(FLAGS.output_dir, fname)
  return data.ShufflingTFRecordWriter(output_path)
|
|
|
|
|
def build_tf_record_writer(fname):
  """Returns a plain (order-preserving) TFRecord writer under FLAGS.output_dir."""
  output_path = os.path.join(FLAGS.output_dir, fname)
  return tf.python_io.TFRecordWriter(output_path)
|
|
|
|
|
def build_input_sequence(doc, vocab_ids):
  """Builds input sequence from file.

  Splits lines on whitespace. Treats punctuation as whitespace. For word-level
  sequences, only keeps terms that are in the vocab.

  Terms are added as token in the SequenceExample. The EOS_TOKEN is also
  appended. Label and weight features are set to 0.

  Args:
    doc: Document (defined in `document_generators`) from which to build the
      sequence.
    vocab_ids: dict<term, id>.

  Returns:
    SequenceExampleWrapper.
  """
  sequence = data.SequenceWrapper()
  # Drop out-of-vocabulary terms; tokenization itself is delegated to
  # document_generators.tokens.
  in_vocab = (term for term in document_generators.tokens(doc)
              if term in vocab_ids)
  for term in in_vocab:
    sequence.add_timestep().set_token(vocab_ids[term])

  # Every sequence ends with an explicit end-of-sequence marker.
  sequence.add_timestep().set_token(vocab_ids[data.EOS_TOKEN])
  return sequence
|
|
|
|
|
def make_vocab_ids(vocab_filename):
  """Builds the term -> integer-id mapping for the dataset.

  Args:
    vocab_filename: Path to a newline-delimited vocabulary file; the id of a
      term is its (0-based) line number. Ignored when FLAGS.output_char is
      True.

  Returns:
    dict<term, int>: vocabulary mapping. In character mode the keys are
    `string.printable` characters plus data.EOS_TOKEN.
  """
  if FLAGS.output_char:
    # Character-level vocabulary: one id per printable character, with the
    # EOS token assigned the next free id.
    vocab = {char: i for i, char in enumerate(string.printable)}
    vocab[data.EOS_TOKEN] = len(string.printable)
    return vocab
  # Word-level vocabulary: id is the term's line number in the vocab file.
  with open(vocab_filename, encoding='utf-8') as vocab_f:
    return {line.strip(): i for i, line in enumerate(vocab_f)}
|
|
|
|
|
def generate_training_data(vocab_ids, writer_lm_all, writer_seq_ae_all):
  """Generates training data.

  Writes LM (forward + reverse), seq-autoencoder, and classification
  (unidirectional + bidirectional) TFRecord files for the 'train' split,
  including unlabeled and validation documents.

  Args:
    vocab_ids: dict<term, id> produced by make_vocab_ids.
    writer_lm_all: writer receiving every LM record (train + validation).
    writer_seq_ae_all: writer receiving every seq-autoencoder record.
  """

  # Training outputs are shuffled; the unidirectional classification
  # validation split uses a plain writer and keeps document order.
  writer_lm = build_shuffling_tf_record_writer(data.TRAIN_LM)
  writer_seq_ae = build_shuffling_tf_record_writer(data.TRAIN_SA)
  writer_class = build_shuffling_tf_record_writer(data.TRAIN_CLASS)
  writer_valid_class = build_tf_record_writer(data.VALID_CLASS)
  writer_rev_lm = build_shuffling_tf_record_writer(data.TRAIN_REV_LM)
  writer_bd_class = build_shuffling_tf_record_writer(data.TRAIN_BD_CLASS)
  writer_bd_valid_class = build_shuffling_tf_record_writer(data.VALID_BD_CLASS)

  for doc in document_generators.documents(
      dataset='train', include_unlabeled=True, include_validation=True):
    input_seq = build_input_sequence(doc, vocab_ids)
    # Skip sequences with fewer than 2 timesteps (just EOS, or empty): they
    # cannot form an (input, target) pair for the LM/seq-ae objectives.
    if len(input_seq) < 2:
      continue
    rev_seq = data.build_reverse_sequence(input_seq)
    lm_seq = data.build_lm_sequence(input_seq)
    rev_lm_seq = data.build_lm_sequence(rev_seq)
    seq_ae_seq = data.build_seq_ae_sequence(input_seq)
    # Only labeled documents contribute classification examples.
    if doc.label is not None:
      # Label gain is disabled for validation documents even when the flag
      # is on, so validation loss stays comparable.
      label_seq = data.build_labeled_sequence(
          input_seq,
          doc.label,
          label_gain=(FLAGS.label_gain and not doc.is_validation))
      bd_label_seq = data.build_labeled_sequence(
          data.build_bidirectional_seq(input_seq, rev_seq),
          doc.label,
          label_gain=(FLAGS.label_gain and not doc.is_validation))
      class_writer = writer_valid_class if doc.is_validation else writer_class
      bd_class_writer = (writer_bd_valid_class
                         if doc.is_validation else writer_bd_class)
      class_writer.write(label_seq.seq.SerializeToString())
      bd_class_writer.write(bd_label_seq.seq.SerializeToString())

    # Serialize once and reuse for the per-split and aggregate writers.
    lm_seq_ser = lm_seq.seq.SerializeToString()
    seq_ae_seq_ser = seq_ae_seq.seq.SerializeToString()
    writer_lm_all.write(lm_seq_ser)
    writer_seq_ae_all.write(seq_ae_seq_ser)
    # Validation documents go only to the "*_all" aggregates above, not the
    # training LM / seq-ae files.
    if not doc.is_validation:
      writer_lm.write(lm_seq_ser)
      writer_rev_lm.write(rev_lm_seq.seq.SerializeToString())
      writer_seq_ae.write(seq_ae_seq_ser)

  # Closing flushes (and, for shuffling writers, shuffles) the records.
  writer_lm.close()
  writer_seq_ae.close()
  writer_class.close()
  writer_valid_class.close()
  writer_rev_lm.close()
  writer_bd_class.close()
  writer_bd_valid_class.close()
|
|
|
|
|
def generate_test_data(vocab_ids, writer_lm_all, writer_seq_ae_all):
  """Generates test data.

  Writes LM (forward + reverse), seq-autoencoder, and classification
  (unidirectional + bidirectional) TFRecord files for the 'test' split.
  Unlike training, every document is assumed labeled and no label gain or
  validation split handling applies.

  Args:
    vocab_ids: dict<term, id> produced by make_vocab_ids.
    writer_lm_all: writer receiving every LM record across splits.
    writer_seq_ae_all: writer receiving every seq-autoencoder record.
  """

  # The classification output preserves document order (plain writer);
  # the unsupervised outputs are shuffled.
  writer_lm = build_shuffling_tf_record_writer(data.TEST_LM)
  writer_rev_lm = build_shuffling_tf_record_writer(data.TEST_REV_LM)
  writer_seq_ae = build_shuffling_tf_record_writer(data.TEST_SA)
  writer_class = build_tf_record_writer(data.TEST_CLASS)
  writer_bd_class = build_shuffling_tf_record_writer(data.TEST_BD_CLASS)

  for doc in document_generators.documents(
      dataset='test', include_unlabeled=False, include_validation=True):
    input_seq = build_input_sequence(doc, vocab_ids)
    # Skip sequences too short to form an (input, target) pair.
    if len(input_seq) < 2:
      continue
    rev_seq = data.build_reverse_sequence(input_seq)
    lm_seq = data.build_lm_sequence(input_seq)
    rev_lm_seq = data.build_lm_sequence(rev_seq)
    seq_ae_seq = data.build_seq_ae_sequence(input_seq)
    label_seq = data.build_labeled_sequence(input_seq, doc.label)
    bd_label_seq = data.build_labeled_sequence(
        data.build_bidirectional_seq(input_seq, rev_seq), doc.label)

    # Serialize once and reuse for the per-split and aggregate writers.
    writer_class.write(label_seq.seq.SerializeToString())
    writer_bd_class.write(bd_label_seq.seq.SerializeToString())
    lm_seq_ser = lm_seq.seq.SerializeToString()
    seq_ae_seq_ser = seq_ae_seq.seq.SerializeToString()
    writer_lm.write(lm_seq_ser)
    writer_rev_lm.write(rev_lm_seq.seq.SerializeToString())
    writer_seq_ae.write(seq_ae_seq_ser)
    writer_lm_all.write(lm_seq_ser)
    writer_seq_ae_all.write(seq_ae_seq_ser)

  # Closing flushes (and, for shuffling writers, shuffles) the records.
  writer_lm.close()
  writer_rev_lm.close()
  writer_seq_ae.close()
  writer_class.close()
  writer_bd_class.close()
|
|
|
|
|
def main(_):
  """Builds the vocabulary mapping, then writes all TFRecord datasets."""
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.logging.info('Assigning vocabulary ids...')
  vocab_path = FLAGS.vocab_file or os.path.join(FLAGS.output_dir, 'vocab.txt')
  vocab_ids = make_vocab_ids(vocab_path)

  # The "*_all" writers aggregate LM and seq-autoencoder records across both
  # the training and test passes.
  with build_shuffling_tf_record_writer(data.ALL_LM) as writer_lm_all, \
      build_shuffling_tf_record_writer(data.ALL_SA) as writer_seq_ae_all:
    tf.logging.info('Generating training data...')
    generate_training_data(vocab_ids, writer_lm_all, writer_seq_ae_all)

    tf.logging.info('Generating test data...')
    generate_test_data(vocab_ids, writer_lm_all, writer_seq_ae_all)
|
|
|
|
|
# TF 1.x entry point: tf.app.run() parses flags, then invokes main().
if __name__ == '__main__':
  tf.app.run()
|
|
|