|
|
import os
from abc import ABC, abstractmethod

import tensorflow as tf
from tqdm import tqdm

import tf_utils as tfu
from utils import RANDOM
|
|
|
|
|
|
|
|
class TfContext:
    """Mutable bundle of the session-level objects used while training.

    Attributes:
        sess: Session object; graph parts call ``sess.run`` on it.
        saver: Saver object; graph parts call ``saver.save`` on it.
        learn_rate_op: Graph handle used as the feed-dict key for the
            current learning-rate value.
        epoch: Epoch counter, starts at 0 and is advanced by the training
            code that receives this context.
    """

    def __init__(self, sess, saver, learn_rate_op):
        self.sess = sess
        self.saver = saver
        self.learn_rate_op = learn_rate_op
        self.epoch = 0
|
|
|
|
|
|
|
|
class GraphPartBase(ABC):
    """Base class for one trainable part of a multi-device TensorFlow graph.

    The class stores settings, collects the per-device placeholders,
    metrics, losses and gradients that subclasses register, and implements
    the shared training procedure (train / validate / checkpoint / restore /
    learning-rate decay).

    Subclasses must implement ``__build_graph_for_device__``,
    ``__update_feed_dict__`` and ``__load_dataset__``.
    """

    def __init__(self, for_usage, global_settings, current_settings, optimiser, reset_optimiser, key, metric_names):
        """Store settings and initialise the empty per-device collections.

        Args:
            for_usage: When truthy, ``build_graph_end`` does not create the
                gradient / optimiser ops.
            global_settings: Mapping; the keys ``filler``, ``max_word_size``,
                ``chars``, ``dataset_path``, ``grammemes_types``,
                ``main_classes``, ``save_path`` and ``train_devices`` are
                read here.
            current_settings: Mapping with the settings of this graph part;
                ``main_metric_type`` and ``learn_rate`` are read here, the
                remaining training keys are read lazily by ``train``,
                ``__decay_params__`` and ``build_graph_end``.
            optimiser: Object whose ``apply_gradients`` is called in
                ``build_graph_end``.
            reset_optimiser: Op run through ``sess.run`` when an epoch did
                not improve the validation metric.
            key: Name of this graph part; its title-cased form is used as
                the variable scope name.
            metric_names: Names of the metrics tracked by this graph part.
        """
        self.key = key
        self.global_settings = global_settings
        self.filler = global_settings['filler']
        self.main_metric_name = current_settings['main_metric_type']
        self.settings = current_settings
        self.for_usage = for_usage
        self.optimiser = optimiser
        self.reset_optimiser = reset_optimiser
        self.metric_names = metric_names
        self.max_word_size = global_settings['max_word_size']
        self.checkpoints_keep = 10000
        # One extra slot on top of the configured characters.
        self.chars_count = len(global_settings['chars']) + 1
        self.dataset_path = global_settings['dataset_path']
        self.grammemes_count = len(global_settings['grammemes_types'])
        self.main_classes = global_settings['main_classes']
        self.main_classes_count = len(self.main_classes)
        self.metrics_reset = []
        self.metrics_update = []
        # Per metric name: list of per-device metric tensors.
        self.devices_metrics = {metr: [] for metr in self.metric_names}
        self.main_scope_name = key.title()
        self.save_path = global_settings['save_path']
        self.dev_grads = []
        self.losses = []
        self.devices = global_settings['train_devices']
        self.devices_count = len(self.devices)
        self.xs = []
        self.x_seq_lens = []
        self.prints = []
        # Same object as ``self.main_classes``; kept as a separate attribute
        # for code that reads it under this name.
        self.main_cls_dic = self.global_settings['main_classes']
        self.learn_rate_val = self.settings['learn_rate']
        self.best_model_metric = None
        self.best_epoch = None
        self.init_checkpoint = None

    def train(self, tc):
        """Run the training procedure until ``__before_finish__`` allows a stop.

        Each iteration trains on the whole train set and validates. If the
        validation metric improved, the model is checkpointed and the epoch
        counter of ``tc`` is advanced; otherwise the optimiser is reset and
        the best known weights are restored. When decay is requested and
        ``return_step`` has reached ``settings['return_step']``, the
        learning rate is decayed; once it drops below
        ``settings['min_learn_rate']`` the loop may terminate.

        Args:
            tc: Training context providing ``sess``, ``saver``,
                ``learn_rate_op`` and a mutable ``epoch``.

        Returns:
            Tuple ``(best_epoch, best_model_metric)``.
        """
        return_step = 0
        trains = self.__load_dataset__('train')
        valids = self.__load_dataset__('valid')
        # Baseline: metric of the model before any training in this call.
        self.best_model_metric = self.__valid_loop__(tc, valids)
        self.best_epoch = -1
        while True:
            tqdm.write(self.filler)
            tqdm.write(self.filler)
            tqdm.write(self.main_scope_name)

            train_main_metric = self.__train_loop__(tc, trains)
            valid_main_metric = self.__valid_loop__(tc, valids)

            tqdm.write(f"Epoch {tc.epoch} Train {self.main_metric_name}: {train_main_metric} Validation {self.main_metric_name}: {valid_main_metric}")
            need_decay = False
            delta = self.__calc_metric_delta__(self.best_model_metric, valid_main_metric)

            if delta > 0:
                if delta < self.settings['stop_main_metric_delta']:
                    tqdm.write(f"{self.main_metric_name} delta is less then min value")
                    need_decay = True
                else:
                    return_step = 0
                # NOTE(review): the bookkeeping below is taken to apply to
                # every improvement, including sub-threshold ones — confirm.
                self.best_model_metric = valid_main_metric
                self.best_epoch = tc.epoch
                tc.saver.save(tc.sess, self.save_path, tc.epoch)
                tc.epoch += 1
            else:
                tqdm.write("Best epoch is better then current")
                tc.sess.run(self.reset_optimiser)
                need_decay = True
                self.__restore_best_epoch__(tc)

            if not need_decay:
                continue

            if return_step == self.settings['return_step']:
                self.__decay_params__()
                if self.learn_rate_val < self.settings['min_learn_rate']:
                    tqdm.write(f"Learning rate {self.learn_rate_val} is less then min learning rate")
                    finish = self.__before_finish__()
                    if finish:
                        break
                # NOTE(review): the counter is taken to restart after every
                # decay, not only when the minimum rate was reached — confirm.
                return_step = 0
            else:
                RANDOM.shuffle(trains)
                tqdm.write("Return step increased")
                return_step += 1

        return self.best_epoch, self.best_model_metric

    def __train_loop__(self, tc, trains):
        """Run one optimisation pass over ``trains``.

        Args:
            tc: Training context (``sess``, ``learn_rate_op``, ``epoch``).
            trains: Iterable of items accepted by ``__create_feed_dict__``.

        Returns:
            Value of the main metric reported after the pass.
        """
        tc.sess.run(self.metrics_reset)
        # The fetch list does not depend on the item, so build it once.
        launch = [self.optimize, *self.metrics_update, *self.prints]
        for item in tqdm(trains, desc=f"Train, epoch {tc.epoch}"):
            feed_dic = self.__create_feed_dict__('train', item)
            feed_dic[tc.learn_rate_op] = self.learn_rate_val
            tc.sess.run(launch, feed_dic)

        return self.__write_metrics_report__(tc.sess, "Train")

    def __valid_loop__(self, tc, valids):
        """Run one evaluation pass over ``valids`` (no optimiser op).

        Args:
            tc: Training context (``sess``, ``epoch``).
            valids: Iterable of items accepted by ``__create_feed_dict__``.

        Returns:
            Value of the main metric reported after the pass.
        """
        tc.sess.run(self.metrics_reset)
        # The fetch list does not depend on the item, so build it once.
        launch = [*self.metrics_update, *self.prints]
        for item in tqdm(valids, desc=f"Validation, epoch {tc.epoch}"):
            feed_dic = self.__create_feed_dict__('valid', item)
            tc.sess.run(launch, feed_dic)

        return self.__write_metrics_report__(tc.sess, "Valid")

    def __calc_metric_delta__(self, best_model_metric, cur_model_metric):
        """Return the improvement of the current metric over the best one.

        The result is positive when the current value is better: lower for
        a main metric named ``"Loss"``, higher for any other metric.
        """
        delta = best_model_metric - cur_model_metric
        if self.main_metric_name != "Loss":
            delta = -delta
        return delta

    def __before_finish__(self):
        """Hook called before training stops; return True to stop.

        The base implementation always returns True.
        """
        return True

    def __decay_params__(self):
        """Multiply the learning rate by ``settings['learn_rate_decay_step']``."""
        self.learn_rate_val = self.learn_rate_val * self.settings['learn_rate_decay_step']
        tqdm.write(f"Learning rate decayed. New value: {self.learn_rate_val}")

    def __restore_best_epoch__(self, tc):
        """Reload the weights of the best known state into ``tc.sess``.

        With no improved epoch yet (``best_epoch == -1``) the initial
        checkpoint is used when ``tc.epoch`` is 0, otherwise the checkpoint
        addressed by ``tc.epoch``; in all other cases the checkpoint
        addressed by ``best_epoch`` is used.
        """
        # NOTE(review): checkpoint paths are built as
        # os.path.join(save_path, "-<epoch>"); this presumably relies on
        # ``save_path`` ending with a path separator — confirm.
        if self.best_epoch == -1 and tc.epoch == 0:
            # Log the checkpoint actually restored (best_epoch is always -1 here).
            tqdm.write(f"Restoring from init_checkpoint {self.init_checkpoint}")
            self.restore(tc.sess, self.init_checkpoint)
        elif self.best_epoch == -1:
            tqdm.write(f"Restoring best epoch {tc.epoch}")
            self.restore(tc.sess, os.path.join(self.save_path, f"-{tc.epoch}"))
        else:
            tqdm.write(f"Restoring best epoch {self.best_epoch}")
            self.restore(tc.sess, os.path.join(self.save_path, f"-{self.best_epoch}"))

    def build_graph_end(self):
        """Create the cross-device ops once all devices have been built.

        Always creates ``self.metrics`` (per-metric mean over devices).
        Unless ``for_usage`` is set, also creates ``self.grads`` (averaged
        and optionally clipped to [-1, 1]), ``self.optimize`` and
        ``self.loss``.
        """
        with tf.variable_scope(self.main_scope_name, reuse=tf.AUTO_REUSE):
            self.metrics = {
                metr: tf.reduce_mean(self.devices_metrics[metr], name=metr)
                for metr in self.devices_metrics
            }
            if not self.for_usage:
                self.grads = tfu.average_gradients(self.dev_grads)
                if self.settings['clip_grads']:
                    self.grads = [(tf.clip_by_value(grad, -1., 1.), var) for grad, var in self.grads]

                # NOTE(review): both ops below are taken to be training-only
                # (``self.grads`` exists only in this branch) — confirm for
                # ``self.loss``.
                self.optimize = self.optimiser.apply_gradients(self.grads, name='Optimize')
                self.loss = tf.reduce_sum(self.losses, name='GlobalLoss')

    def build_graph_for_device(self, *args):
        """Call ``__build_graph_for_device__`` inside this part's variable scope."""
        with tf.variable_scope(self.main_scope_name, reuse=tf.AUTO_REUSE):
            self.__build_graph_for_device__(*args)

    def restore(self, sess, check_point):
        """Best-effort restore of this part's variables from ``check_point``.

        Only global variables under this part's scope whose name does not
        contain ``"Adam"`` are restored. On success the checkpoint is
        remembered in ``self.init_checkpoint``. Any failure is logged and
        swallowed, so callers are never interrupted.
        """
        try:
            variables = [
                var
                for var in tf.global_variables(f"{self.main_scope_name}/")
                if "Adam" not in var.name
            ]
            saver = tf.train.Saver(var_list=variables)
            saver.restore(sess, check_point)
            self.init_checkpoint = check_point
            tqdm.write(f"Restoration for graph part '{self.key}', scope {self.main_scope_name} success")
        except Exception as ex:
            # Deliberately broad: a missing or incompatible checkpoint must
            # not abort the caller.
            tqdm.write(f"Restoration for graph part '{self.key}', scope {self.main_scope_name} failed. Error: {ex}")

    def __write_metrics_report__(self, sess, step_name):
        """Evaluate all metrics, log them on one line and return the main one.

        Args:
            sess: Session used to evaluate ``self.metrics``.
            step_name: Prefix of the log line (e.g. ``"Train"``).

        Returns:
            The evaluated value stored under ``self.main_metric_name``.
        """
        tqdm.write('')
        launch_results = sess.run(self.metrics)
        parts = [f"{step_name} metrics: "]
        # Label every value with the key it was fetched under, so the label
        # always matches the value.
        for metr in self.metrics:
            parts.append(f"{metr:>8}={launch_results[metr]:.7f} ")

        tqdm.write("".join(parts))
        return launch_results[self.main_metric_name]

    def _register_metric(self, metric_index, metric_fn, *args, **kwargs):
        """Create a resettable metric and record its value/update/reset ops."""
        name = self.metric_names[metric_index]
        metr_value, metr_update, metr_reset = tfu.create_reset_metric(metric_fn, name, *args, **kwargs)
        self.metrics_reset.append(metr_reset)
        self.metrics_update.append(metr_update)
        self.devices_metrics[name].append(metr_value)

    def create_mean_metric(self, metric_index, values):
        """Register a ``tf.metrics.mean`` metric over ``values``.

        Args:
            metric_index: Index into ``self.metric_names`` selecting the name.
            values: Passed to ``tfu.create_reset_metric`` as the metric input.
        """
        self._register_metric(metric_index, tf.metrics.mean, values)

    def create_accuracy_metric(self, metric_index, labels, predictions):
        """Register a ``tf.metrics.accuracy`` metric.

        Args:
            metric_index: Index into ``self.metric_names`` selecting the name.
            labels: Forwarded as the ``labels`` keyword.
            predictions: Forwarded as the ``predictions`` keyword.
        """
        self._register_metric(metric_index, tf.metrics.accuracy, labels=labels, predictions=predictions)

    def __create_feed_dict__(self, op_name, item):
        """Build the feed dict for one item (a sequence of per-device batches).

        For the batch at position ``dev_num`` the entries ``batch['x']`` and
        ``batch['x_seq_len']`` are fed to ``self.xs[dev_num]`` and
        ``self.x_seq_lens[dev_num]``; ``__update_feed_dict__`` then adds the
        subclass-specific entries.

        Args:
            op_name: Operation name forwarded to ``__update_feed_dict__``
                (``'train'`` or ``'valid'`` in this class).
            item: Iterable of batch mappings, one per device.

        Returns:
            The populated feed dict.
        """
        feed_dic = {}
        for dev_num, batch in enumerate(item):
            feed_dic[self.xs[dev_num]] = batch['x']
            feed_dic[self.x_seq_lens[dev_num]] = batch['x_seq_len']
            self.__update_feed_dict__(op_name, feed_dic, batch, dev_num)

        return feed_dic

    @abstractmethod
    def __update_feed_dict__(self, op_name, feed_dict, batch, dev_num):
        """Add subclass-specific entries for ``batch`` to ``feed_dict`` in place."""
        pass

    @abstractmethod
    def __build_graph_for_device__(self, *args):
        """Build this part's graph for one device; called inside its variable scope."""
        pass

    @abstractmethod
    def __load_dataset__(self, operation_name):
        """Return the list of items for ``operation_name`` (``'train'`` / ``'valid'``)."""
        return []
|
|
|