|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Evaluates text classification model."""
|
|
|
| from __future__ import absolute_import
|
| from __future__ import division
|
| from __future__ import print_function
|
|
|
| import math
|
| import time
|
|
|
|
|
|
|
| import tensorflow as tf
|
|
|
| import graphs
|
|
|
flags = tf.app.flags
FLAGS = flags.FLAGS

# Eval master, output location and which dataset split to evaluate.
flags.DEFINE_string(
    'master', '',
    'BNS name prefix of the Tensorflow eval master, or "local".')
flags.DEFINE_string(
    'eval_dir', '/tmp/text_eval', 'Directory where to write event logs.')
flags.DEFINE_string(
    'eval_data', 'test',
    'Specify which dataset is used. ("train", "valid", "test") ')

# Checkpoint source and evaluation schedule.
flags.DEFINE_string(
    'checkpoint_dir', '/tmp/text_train',
    'Directory where to read model checkpoints.')
flags.DEFINE_integer('eval_interval_secs', 60, 'How often to run the eval.')
flags.DEFINE_integer('num_examples', 32, 'Number of examples to run.')
flags.DEFINE_bool('run_once', False, 'Whether to run eval only once.')
|
|
|
|
|
def restore_from_checkpoint(sess, saver):
  """Loads the latest checkpoint found in FLAGS.checkpoint_dir into `sess`.

  Args:
    sess: Session whose variables are restored.
    saver: Saver used to perform the restore.

  Returns:
    True if a checkpoint path was found and restored, False otherwise (an
    info message is logged in that case).
  """
  state = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
  # get_checkpoint_state may return None, or a state without a model path.
  latest_path = state.model_checkpoint_path if state else None
  if latest_path:
    saver.restore(sess, latest_path)
    return True
  tf.logging.info('No checkpoint found at %s', FLAGS.checkpoint_dir)
  return False
|
|
|
|
|
def run_eval(eval_ops, summary_writer, saver):
  """Runs evaluation over FLAGS.num_examples examples.

  Restores the latest checkpoint from FLAGS.checkpoint_dir, then runs every
  metric's update op for ceil(FLAGS.num_examples / FLAGS.batch_size) batches.
  Intermediate metric values are logged every 50 batches. The final values are
  logged and written to `summary_writer`.

  Args:
    eval_ops: dict<metric name, tuple(value, update_op)>
    summary_writer: Summary writer that receives the final metric values.
    saver: Saver used to restore the checkpoint.

  Returns:
    None. If no checkpoint is found, returns early without running any batch.

  Raises:
    ValueError: if `eval_ops` is empty.
  """
  if not eval_ops:
    raise ValueError('eval_ops must contain at least one metric.')

  sv = tf.train.Supervisor(
      logdir=FLAGS.eval_dir, saver=None, summary_op=None, summary_writer=None)
  with sv.managed_session(
      master=FLAGS.master, start_standard_services=False) as sess:
    if not restore_from_checkpoint(sess, saver):
      return
    sv.start_queue_runners(sess)

    metric_names, ops = zip(*eval_ops.items())
    value_ops, update_ops = zip(*ops)
    value_ops_dict = dict(zip(metric_names, value_ops))

    # Relies on `from __future__ import division` (file top) for true division.
    num_batches = int(math.ceil(FLAGS.num_examples / FLAGS.batch_size))
    tf.logging.info('Running %d batches for evaluation.', num_batches)
    for batch_num in range(1, num_batches + 1):
      if batch_num % 10 == 0:
        tf.logging.info('Running batch %d/%d...', batch_num, num_batches)
      sess.run(update_ops)
      # Report intermediate values only after this batch has been accumulated.
      # Skip the last batch: the final values are logged right after the loop.
      if batch_num % 50 == 0 and batch_num < num_batches:
        _log_values(sess, value_ops_dict)

    _log_values(sess, value_ops_dict, summary_writer=summary_writer)
|
|
|
|
|
def _log_values(sess, value_ops, summary_writer=None):
  """Evaluates, logs, and optionally writes summaries of eval metric values.

  Args:
    sess: Session used to evaluate the metric value ops.
    value_ops: dict<metric name, value op>.
    summary_writer: Optional summary writer. If given, a summary holding one
      simple value per metric is written at the current global step. If the
      graph has no global step, the summary is written without a step.

  Returns:
    dict<metric name, evaluated value>.
  """
  metric_names, metric_tensors = zip(*value_ops.items())
  values = sess.run(metric_tensors)

  tf.logging.info('Eval metric values:')
  summary = tf.summary.Summary()
  for name, val in zip(metric_names, values):
    # simple_value is a float proto field; coerce numpy scalars explicitly.
    summary.value.add(tag=name, simple_value=float(val))
    tf.logging.info('%s = %.3f', name, val)

  if summary_writer is not None:
    global_step = tf.train.get_global_step()
    # get_global_step() returns None when the graph has no global step, and
    # sess.run(None) would raise; add_summary accepts a None step.
    if global_step is not None:
      global_step_val = sess.run(global_step)
    else:
      global_step_val = None
    tf.logging.info('Finished eval for step %s', global_step_val)
    summary_writer.add_summary(summary, global_step_val)

  return dict(zip(metric_names, values))
|
|
|
|
|
def main(_):
  """Builds the eval graph and evaluates checkpoints until stopped.

  Evaluates once if FLAGS.run_once is set. Otherwise evaluates every
  FLAGS.eval_interval_secs seconds. The summary writer is always closed on
  exit, so pending events are flushed to FLAGS.eval_dir.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.gfile.MakeDirs(FLAGS.eval_dir)
  tf.logging.info('Building eval graph...')
  eval_ops, moving_averaged_variables = graphs.get_model().eval_graph(
      FLAGS.eval_data)

  # The saver restores only the variables returned by eval_graph.
  saver = tf.train.Saver(moving_averaged_variables)
  summary_writer = tf.summary.FileWriter(
      FLAGS.eval_dir, graph=tf.get_default_graph())

  try:
    while True:
      run_eval(eval_ops, summary_writer, saver)
      if FLAGS.run_once:
        break
      time.sleep(FLAGS.eval_interval_secs)
  finally:
    # FileWriter buffers events; close() flushes them to disk. Without it,
    # a --run_once invocation can exit before the final summary is written.
    summary_writer.close()
|
|
|
|
|
if __name__ == '__main__':
  # tf.app.run parses command-line flags, then invokes main with argv.
  tf.app.run(main=main)
|
|
|