Spaces:
Sleeping
Sleeping
| from __future__ import print_function | |
| from collections import deque | |
| from datetime import datetime | |
| import logging | |
| import os | |
| import pprint as pp | |
| import time | |
| import numpy as np | |
| import tensorflow.compat.v1 as tf | |
| tf.disable_v2_behavior() | |
| from tf_utils import shape | |
class TFBaseModel(object):
    """Interface containing some boilerplate code for training tensorflow models.

    Subclassing models must implement self.calculate_loss(), which returns a tensor for the batch
    loss.  Code for the training loop, parameter updates, checkpointing, and inference are
    implemented here and subclasses are mainly responsible for building the computational graph
    beginning with the placeholders and ending with the loss tensor.

    Args:
        reader: Class with attributes train_batch_generator, val_batch_generator, and
            test_batch_generator that yield dictionaries mapping tf.placeholder names
            (as strings) to batch data (numpy arrays).
        batch_sizes: List of minibatch sizes, one per training stage.  When early stopping
            triggers, training restarts from the best checkpoint using the next stage's
            parameters; the number of restarts is len(batch_sizes) - 1.
        num_training_steps: Maximum number of training steps.
        learning_rates: List of learning rates, one per training stage.
        beta1_decays: List of beta1 (adam) / decay (rms) values, one per training stage.
        optimizer: 'rms' for RMSProp, 'adam' for Adam, 'gd' for gradient descent.
        grad_clip: Clip gradients elementwise to have absolute value at most grad_clip.
        regularization_constant: Regularization constant applied to all trainable parameters.
        keep_prob: 1 - p, where p is the dropout probability.
        patiences: List of early stopping patiences (number of steps to continue training
            after the early stopping metric has stopped decreasing), one per training stage.
        warm_start_init_step: If nonzero, model will resume training a restored model
            beginning at warm_start_init_step.
        enable_parameter_averaging: If true, model saves exponential weighted averages of
            parameters to a separate checkpoint file.
        min_steps_to_checkpoint: Model only saves after min_steps_to_checkpoint training
            steps have passed.
        log_interval: Train and validation losses are logged every log_interval training steps.
        logging_level: Level passed to logging.basicConfig.
        loss_averaging_window: Train/validation losses are averaged over the last
            loss_averaging_window training steps.
        validation_batch_size: Batch size used for the validation evaluation at each step.
        log_dir: Directory where logs are written.
        checkpoint_dir: Directory where checkpoints are saved.
        prediction_dir: Directory where predictions/outputs are saved.
    """

    def __init__(
        self,
        reader=None,
        batch_sizes=None,
        num_training_steps=20000,
        learning_rates=None,
        beta1_decays=None,
        optimizer='adam',
        grad_clip=5,
        regularization_constant=0.0,
        keep_prob=1.0,
        patiences=None,
        warm_start_init_step=0,
        enable_parameter_averaging=False,
        min_steps_to_checkpoint=100,
        log_interval=20,
        logging_level=logging.INFO,
        loss_averaging_window=100,
        validation_batch_size=64,
        log_dir='logs',
        checkpoint_dir='checkpoints',
        prediction_dir='predictions',
    ):
        # Avoid mutable default arguments; these sentinels reproduce the old defaults.
        batch_sizes = [128] if batch_sizes is None else batch_sizes
        learning_rates = [.01] if learning_rates is None else learning_rates
        beta1_decays = [.99] if beta1_decays is None else beta1_decays
        patiences = [3000] if patiences is None else patiences

        # The per-stage schedules must line up: stage i uses batch_sizes[i],
        # learning_rates[i], and patiences[i].
        assert len(batch_sizes) == len(learning_rates) == len(patiences)
        self.batch_sizes = batch_sizes
        self.learning_rates = learning_rates
        self.beta1_decays = beta1_decays
        self.patiences = patiences
        self.num_restarts = len(batch_sizes) - 1
        self.restart_idx = 0
        self.update_train_params()

        self.reader = reader
        self.num_training_steps = num_training_steps
        self.optimizer = optimizer
        self.grad_clip = grad_clip
        self.regularization_constant = regularization_constant
        self.warm_start_init_step = warm_start_init_step
        self.keep_prob_scalar = keep_prob
        self.enable_parameter_averaging = enable_parameter_averaging
        self.min_steps_to_checkpoint = min_steps_to_checkpoint
        self.log_interval = log_interval
        self.loss_averaging_window = loss_averaging_window
        self.validation_batch_size = validation_batch_size

        self.log_dir = log_dir
        self.logging_level = logging_level
        self.prediction_dir = prediction_dir
        self.checkpoint_dir = checkpoint_dir
        if self.enable_parameter_averaging:
            self.checkpoint_dir_averaged = checkpoint_dir + '_avg'

        self.init_logging(self.log_dir)
        logging.info('\nnew run with parameters:\n{}'.format(pp.pformat(self.__dict__)))

        self.graph = self.build_graph()
        self.session = tf.Session(graph=self.graph)
        logging.info('built graph')

    def update_train_params(self):
        """Load the batch size, learning rate, beta1/decay, and patience for the current stage."""
        self.batch_size = self.batch_sizes[self.restart_idx]
        self.learning_rate = self.learning_rates[self.restart_idx]
        self.beta1_decay = self.beta1_decays[self.restart_idx]
        self.early_stopping_steps = self.patiences[self.restart_idx]

    def calculate_loss(self):
        """Build the model graph and return the batch loss tensor.  Subclass responsibility."""
        raise NotImplementedError('subclass must implement this')

    def fit(self):
        """Run the training loop: alternate validation and train steps, log smoothed losses,
        checkpoint on improvement, and restart with the next stage's parameters (or stop)
        when the early stopping metric plateaus."""
        with self.session.as_default():

            if self.warm_start_init_step:
                self.restore(self.warm_start_init_step)
                step = self.warm_start_init_step
            else:
                self.session.run(self.init)
                step = 0

            train_generator = self.reader.train_batch_generator(self.batch_size)
            val_generator = self.reader.val_batch_generator(self.validation_batch_size)

            # Rolling windows used for smoothed loss/metric/timing reporting.
            train_loss_history = deque(maxlen=self.loss_averaging_window)
            val_loss_history = deque(maxlen=self.loss_averaging_window)
            train_time_history = deque(maxlen=self.loss_averaging_window)
            val_time_history = deque(maxlen=self.loss_averaging_window)
            if not hasattr(self, 'metrics'):
                self.metrics = {}

            metric_histories = {
                metric_name: deque(maxlen=self.loss_averaging_window) for metric_name in self.metrics
            }

            best_validation_loss, best_validation_tstep = float('inf'), 0

            while step < self.num_training_steps:

                # validation evaluation
                val_start = time.time()
                val_batch_df = next(val_generator)
                val_feed_dict = {
                    getattr(self, placeholder_name, None): data
                    for placeholder_name, data in val_batch_df.items() if hasattr(self, placeholder_name)
                }

                val_feed_dict.update(
                    {self.learning_rate_var: self.learning_rate, self.beta1_decay_var: self.beta1_decay})
                if hasattr(self, 'keep_prob'):
                    val_feed_dict.update({self.keep_prob: 1.0})
                if hasattr(self, 'is_training'):
                    val_feed_dict.update({self.is_training: False})

                # Fix: in python 3 dict views can't be concatenated with a list
                # ([self.loss] + self.metrics.values() raises TypeError); pin an
                # explicit ordering and materialize the fetches.
                metric_names = list(self.metrics.keys())
                results = self.session.run(
                    fetches=[self.loss] + [self.metrics[name] for name in metric_names],
                    feed_dict=val_feed_dict
                )
                val_loss = results[0]
                val_metrics = dict(zip(metric_names, results[1:]))
                val_loss_history.append(val_loss)
                val_time_history.append(time.time() - val_start)
                for key in val_metrics:
                    metric_histories[key].append(val_metrics[key])

                # Optional debugging hook: dump summary stats of arbitrary tensors.
                if hasattr(self, 'monitor_tensors'):
                    for name, tensor in self.monitor_tensors.items():
                        [np_val] = self.session.run([tensor], feed_dict=val_feed_dict)
                        print(name)
                        print('min', np_val.min())
                        print('max', np_val.max())
                        print('mean', np_val.mean())
                        print('std', np_val.std())
                        print('nans', np.isnan(np_val).sum())
                        print()
                    print()
                    print()

                # train step
                train_start = time.time()
                train_batch_df = next(train_generator)
                train_feed_dict = {
                    getattr(self, placeholder_name, None): data
                    for placeholder_name, data in train_batch_df.items() if hasattr(self, placeholder_name)
                }

                train_feed_dict.update(
                    {self.learning_rate_var: self.learning_rate, self.beta1_decay_var: self.beta1_decay})
                if hasattr(self, 'keep_prob'):
                    train_feed_dict.update({self.keep_prob: self.keep_prob_scalar})
                if hasattr(self, 'is_training'):
                    train_feed_dict.update({self.is_training: True})

                train_loss, _ = self.session.run(
                    fetches=[self.loss, self.step],
                    feed_dict=train_feed_dict
                )
                train_loss_history.append(train_loss)
                train_time_history.append(time.time() - train_start)

                if step % self.log_interval == 0:
                    avg_train_loss = sum(train_loss_history) / len(train_loss_history)
                    avg_val_loss = sum(val_loss_history) / len(val_loss_history)
                    avg_train_time = sum(train_time_history) / len(train_time_history)
                    avg_val_time = sum(val_time_history) / len(val_time_history)
                    metric_log = (
                        "[[step {:>8}]] "
                        "[[train {:>4}s]] loss: {:<12} "
                        "[[val {:>4}s]] loss: {:<12} "
                    ).format(
                        step,
                        round(avg_train_time, 4),
                        round(avg_train_loss, 8),
                        round(avg_val_time, 4),
                        round(avg_val_loss, 8),
                    )
                    # Subclasses may designate one of their metrics for early stopping;
                    # fall back to the averaged validation loss when none is set
                    # (getattr guard: the attribute is optional).
                    early_stopping_metric = avg_val_loss
                    early_stopping_metric_name = getattr(self, 'early_stopping_metric', None)
                    for metric_name, metric_history in metric_histories.items():
                        metric_val = sum(metric_history) / len(metric_history)
                        metric_log += '{}: {:<4} '.format(metric_name, round(metric_val, 4))
                        if metric_name == early_stopping_metric_name:
                            early_stopping_metric = metric_val
                    logging.info(metric_log)

                    if early_stopping_metric < best_validation_loss:
                        best_validation_loss = early_stopping_metric
                        best_validation_tstep = step
                        if step > self.min_steps_to_checkpoint:
                            self.save(step)
                            if self.enable_parameter_averaging:
                                self.save(step, averaged=True)

                    if step - best_validation_tstep > self.early_stopping_steps:

                        if self.num_restarts is None or self.restart_idx >= self.num_restarts:
                            logging.info('best validation loss of {} at training step {}'.format(
                                best_validation_loss, best_validation_tstep))
                            logging.info('early stopping - ending training.')
                            return

                        if self.restart_idx < self.num_restarts:
                            # Rewind to the best checkpoint and continue with the next
                            # stage's batch size / learning rate / patience.
                            self.restore(best_validation_tstep)
                            step = best_validation_tstep

                            self.restart_idx += 1
                            self.update_train_params()
                            train_generator = self.reader.train_batch_generator(self.batch_size)

                step += 1

            # If training ended before the checkpoint threshold was ever crossed,
            # save once so a model exists on disk.
            if step <= self.min_steps_to_checkpoint:
                best_validation_tstep = step
                self.save(step)
                if self.enable_parameter_averaging:
                    self.save(step, averaged=True)

            logging.info('num_training_steps reached - ending training')

    def predict(self, chunk_size=256):
        """Run inference over the test set and save each tensor in self.prediction_tensors
        (and each variable in self.parameter_tensors) to prediction_dir as a .npy file.

        Args:
            chunk_size: Number of test examples fed through the graph per batch.
        """
        if not os.path.isdir(self.prediction_dir):
            os.makedirs(self.prediction_dir)

        if hasattr(self, 'prediction_tensors'):
            prediction_dict = {tensor_name: [] for tensor_name in self.prediction_tensors}

            # The fetch ordering is loop-invariant; compute it once.
            tensor_names, tf_tensors = zip(*self.prediction_tensors.items())

            test_generator = self.reader.test_batch_generator(chunk_size)
            for i, test_batch_df in enumerate(test_generator):
                if i % 10 == 0:
                    print(i*len(test_batch_df))

                test_feed_dict = {
                    getattr(self, placeholder_name, None): data
                    for placeholder_name, data in test_batch_df.items() if hasattr(self, placeholder_name)
                }
                if hasattr(self, 'keep_prob'):
                    test_feed_dict.update({self.keep_prob: 1.0})
                if hasattr(self, 'is_training'):
                    test_feed_dict.update({self.is_training: False})

                np_tensors = self.session.run(
                    fetches=tf_tensors,
                    feed_dict=test_feed_dict
                )
                for tensor_name, tensor in zip(tensor_names, np_tensors):
                    prediction_dict[tensor_name].append(tensor)

            for tensor_name, tensor in prediction_dict.items():
                np_tensor = np.concatenate(tensor, 0)
                save_file = os.path.join(self.prediction_dir, '{}.npy'.format(tensor_name))
                logging.info('saving {} with shape {} to {}'.format(tensor_name, np_tensor.shape, save_file))
                np.save(save_file, np_tensor)

        if hasattr(self, 'parameter_tensors'):
            for tensor_name, tensor in self.parameter_tensors.items():
                # Fix: the first positional argument of Tensor.eval is feed_dict, not
                # session — pass the session by keyword.
                np_tensor = tensor.eval(session=self.session)
                save_file = os.path.join(self.prediction_dir, '{}.npy'.format(tensor_name))
                logging.info('saving {} with shape {} to {}'.format(tensor_name, np_tensor.shape, save_file))
                np.save(save_file, np_tensor)

    def save(self, step, averaged=False):
        """Save a checkpoint at the given global step (to the averaged dir if averaged=True)."""
        saver = self.saver_averaged if averaged else self.saver
        checkpoint_dir = self.checkpoint_dir_averaged if averaged else self.checkpoint_dir
        if not os.path.isdir(checkpoint_dir):
            logging.info('creating checkpoint directory {}'.format(checkpoint_dir))
            os.mkdir(checkpoint_dir)

        model_path = os.path.join(checkpoint_dir, 'model')
        logging.info('saving model to {}'.format(model_path))
        saver.save(self.session, model_path, global_step=step)

    def restore(self, step=None, averaged=False):
        """Restore model parameters from a checkpoint.

        Args:
            step: Global step of the checkpoint to restore; when falsy, the latest
                checkpoint in the directory is used.
            averaged: If True, restore from the parameter-averaged checkpoint directory.
        """
        saver = self.saver_averaged if averaged else self.saver
        checkpoint_dir = self.checkpoint_dir_averaged if averaged else self.checkpoint_dir
        if not step:
            model_path = tf.train.latest_checkpoint(checkpoint_dir)
            logging.info('restoring model parameters from {}'.format(model_path))
            saver.restore(self.session, model_path)
        else:
            # Fix: save() always writes the prefix 'model', regardless of averaging
            # (the averaged copies are distinguished by directory), so the explicit-step
            # path must be 'model-{step}' — the old 'model_avg-{step}' never existed.
            model_path = os.path.join(checkpoint_dir, 'model-{}'.format(step))
            logging.info('restoring model from {}'.format(model_path))
            saver.restore(self.session, model_path)

    def init_logging(self, log_dir):
        """Configure root logging to write to a timestamped file in log_dir and to stderr."""
        if not os.path.isdir(log_dir):
            os.makedirs(log_dir)

        date_str = datetime.now().strftime('%Y-%m-%d_%H-%M')
        log_file = 'log_{}.txt'.format(date_str)

        # basicConfig is a no-op when the root logger already has handlers; drop any
        # existing handlers so repeated runs reconfigure cleanly.  This replaces the
        # old python 2 reload(logging) hack, which doesn't work on python 3.
        root = logging.getLogger()
        for handler in root.handlers[:]:
            root.removeHandler(handler)
        logging.basicConfig(
            filename=os.path.join(log_dir, log_file),
            level=self.logging_level,
            format='[[%(asctime)s]] %(message)s',
            datefmt='%m/%d/%Y %I:%M:%S %p'
        )
        logging.getLogger().addHandler(logging.StreamHandler())

    def update_parameters(self, loss):
        """Build the training op (stored as self.step) from the loss tensor, applying
        optional L2 regularization, elementwise gradient clipping, graph UPDATE_OPS
        (e.g. batch norm), and optional exponential parameter averaging."""
        if self.regularization_constant != 0:
            l2_norm = tf.reduce_sum([tf.sqrt(tf.reduce_sum(tf.square(param))) for param in tf.trainable_variables()])
            loss = loss + self.regularization_constant*l2_norm

        optimizer = self.get_optimizer(self.learning_rate_var, self.beta1_decay_var)
        grads = optimizer.compute_gradients(loss)
        # Fix: compute_gradients returns None for variables unreachable from the loss,
        # and tf.clip_by_value(None, ...) raises — skip those pairs.
        clipped = [
            (tf.clip_by_value(g, -self.grad_clip, self.grad_clip), v_)
            for g, v_ in grads if g is not None
        ]

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            step = optimizer.apply_gradients(clipped, global_step=self.global_step)

        if self.enable_parameter_averaging:
            maintain_averages_op = self.ema.apply(tf.trainable_variables())
            with tf.control_dependencies([step]):
                self.step = tf.group(maintain_averages_op)
        else:
            self.step = step

        logging.info('all parameters:')
        logging.info(pp.pformat([(var.name, shape(var)) for var in tf.global_variables()]))

        logging.info('trainable parameters:')
        logging.info(pp.pformat([(var.name, shape(var)) for var in tf.trainable_variables()]))

        logging.info('trainable parameter count:')
        # Use builtin sum: np.sum over a bare generator is deprecated and unreliable.
        logging.info(str(sum(np.prod(shape(var)) for var in tf.trainable_variables())))

    def get_optimizer(self, learning_rate, beta1_decay):
        """Return the tf optimizer named by self.optimizer ('adam', 'gd', or 'rms').

        Args:
            learning_rate: Scalar tensor/variable fed with the current learning rate.
            beta1_decay: Scalar tensor/variable fed with beta1 (adam) / decay (rms).

        Raises:
            ValueError: If self.optimizer is not one of the supported names.
        """
        if self.optimizer == 'adam':
            return tf.train.AdamOptimizer(learning_rate, beta1=beta1_decay)
        elif self.optimizer == 'gd':
            return tf.train.GradientDescentOptimizer(learning_rate)
        elif self.optimizer == 'rms':
            return tf.train.RMSPropOptimizer(learning_rate, decay=beta1_decay, momentum=0.9)
        else:
            # raise instead of assert: asserts are stripped under python -O
            raise ValueError('optimizer must be adam, gd, or rms')

    def build_graph(self):
        """Construct the full computational graph (loss, train op, savers, initializer)
        in a fresh tf.Graph and return it."""
        with tf.Graph().as_default() as graph:
            self.ema = tf.train.ExponentialMovingAverage(decay=0.99)
            self.global_step = tf.Variable(0, trainable=False)
            self.learning_rate_var = tf.Variable(0.0, trainable=False)
            self.beta1_decay_var = tf.Variable(0.0, trainable=False)

            self.loss = self.calculate_loss()
            self.update_parameters(self.loss)

            self.saver = tf.train.Saver(max_to_keep=1)
            if self.enable_parameter_averaging:
                self.saver_averaged = tf.train.Saver(self.ema.variables_to_restore(), max_to_keep=1)

            self.init = tf.global_variables_initializer()

            return graph