| import os.path | |
| import tensorflow as tf | |
class VisualizeCallback(tf.keras.callbacks.Callback):
    """Keras callback that runs a user-supplied function every few epochs.

    Args:
        epoch_interval: ``func`` is called at the end of every epoch whose
            (0-based, as passed by Keras) index is a non-zero multiple of
            this value.
        func: Callable invoked as ``func(model, epoch)``. Defaults to a
            function that does nothing.
    """

    def __init__(self, epoch_interval=1, func=lambda model, epoch: None):
        super().__init__()
        self.epoch_interval = epoch_interval
        self.func = func

    def on_epoch_end(self, epoch, logs=None):
        """Call ``self.func`` when ``epoch`` is aligned to the interval."""
        # The modulo test is evaluated first, exactly as before, so the
        # error behaviour for a zero interval is unchanged.
        if epoch % self.epoch_interval != 0 or epoch <= 0:
            return
        self.func(self.model, epoch)
class CheckpointCallback(tf.keras.callbacks.Callback):
    """Keras callback that checkpoints a VAE and its optimizer during training.

    A ``tf.train.Checkpoint`` tracking ``vae`` and ``vae.optimizer`` is
    managed by a ``tf.train.CheckpointManager`` rooted at ``path``. A
    numbered checkpoint is written at the end of every epoch whose index is
    a non-zero multiple of ``epoch_interval``. A final checkpoint is written
    with the prefix ``<path>/trained-vae`` when training ends. Training state
    can optionally be restored when training begins.

    Args:
        vae: Model to checkpoint; its ``optimizer`` attribute is tracked too.
        path: Directory the checkpoint manager writes to and searches for
            the latest checkpoint.
        epoch_interval: Save every this many epochs (Keras passes 0-based
            epoch indices; index 0 is skipped).
        restore_training: If True, restore state in ``on_train_begin``.
        restore_path: Explicit checkpoint path to restore from. If None, the
            latest checkpoint in ``path`` is used.
    """

    def __init__(self, vae, path, epoch_interval=1, restore_training=False,
                 restore_path=None):
        super(CheckpointCallback, self).__init__()
        self.epoch_interval = epoch_interval
        self.path = path
        self.vae = vae
        self.ckpt = tf.train.Checkpoint(vae=vae,
                                        vae_optimizer=vae.optimizer)
        # max_to_keep=None: the manager never deletes older checkpoints.
        self.ckpt_manager = tf.train.CheckpointManager(checkpoint=self.ckpt,
                                                       directory=self.path,
                                                       max_to_keep=None)
        self.restore_training = restore_training
        self.restore_path = restore_path
        # NOTE(review): not read anywhere in the visible code; kept in case
        # other code relies on it.
        self._saved = False

    def on_epoch_end(self, epoch, logs=None):
        """Save a checkpoint numbered by ``epoch`` when it hits the interval."""
        if epoch % self.epoch_interval == 0 and epoch > 0:
            self.ckpt_manager.save(checkpoint_number=epoch)

    def on_train_begin(self, logs=None):
        """Restore VAE/optimizer state if ``restore_training`` is set."""
        if not self.restore_training:
            return
        if self.restore_path is None:
            latest = self.ckpt_manager.latest_checkpoint
            if latest is None:
                # Nothing has been saved in `path` yet; report it instead of
                # claiming to resume from a "None" checkpoint.
                print("No checkpoint found in ", self.path,
                      ", starting training from scratch\n")
                return
            # expect_partial() (previously misspelled `except_partial`, which
            # raised AttributeError) silences warnings about parts of the
            # checkpoint that are not consumed by this restore.
            self.ckpt.restore(latest).expect_partial()
            print("Resume training from checkpoint ", latest, "\n")
        else:
            self.ckpt.restore(self.restore_path)
            print("resume training from checkpoint ", self.restore_path, "\n")

    def on_train_end(self, logs=None):
        """Write a final checkpoint under the ``trained-vae`` prefix."""
        weights_path = os.path.join(self.path, "trained-vae")
        # Checkpoint.save derives the actual file names from this prefix
        # plus the checkpoint's save counter.
        self.ckpt.save(file_prefix=weights_path)