import tensorflow as tf

import tf_utils as tfu
from graph.base import GraphPartBase


def _variables_by_relative_name(prefix):
    """Map scope-relative names to the global variables found under *prefix*.

    The key is the variable's full name with the leading *prefix* removed
    (the ``:0``-style output suffix, if any, stays in the key). Variables
    whose full name contains ``"Adam"`` are left out.
    """
    return {
        var.name[len(prefix):]: var
        for var in tf.global_variables(prefix)
        # presumably skips Adam optimiser slot variables -- confirm
        if "Adam" not in var.name
    }


class Inflect(GraphPartBase):
    """Graph part that feeds class-conditioned sequences into ``tfu.seq2seq``.

    The part registers itself with the base class under the name
    ``'inflect'`` together with the names ``"Loss"``, ``"AccuracyByChar"``
    and ``"Accuracy"``.
    NOTE(review): the base class is not visible here; the exact meaning of
    those constructor arguments should be confirmed against
    ``GraphPartBase``.
    """

    def __init__(self, for_usage, global_settings, current_settings, optimiser, reset_optimiser):
        """Forward the settings to the base class and create per-device lists.

        Args:
            for_usage: Passed through to ``GraphPartBase`` unchanged.
            global_settings: Mapping that must contain ``'start_token'`` and
                ``'end_token'``; also passed to the base class.
            current_settings: Passed through to ``GraphPartBase`` unchanged.
            optimiser: Passed through to ``GraphPartBase`` unchanged.
            reset_optimiser: Passed through to ``GraphPartBase`` unchanged.
        """
        super().__init__(
            for_usage,
            global_settings,
            current_settings,
            optimiser,
            reset_optimiser,
            'inflect',
            ["Loss", "AccuracyByChar", "Accuracy"],
        )
        # NOTE(review): ``chars_count`` is presumably set by the base class;
        # the reason for the extra slot is not visible in this file.
        self.chars_count = self.chars_count + 1
        self.start_char_index = global_settings['start_token']
        self.end_char_index = global_settings['end_token']

        # One entry per device, indexed by ``dev_num`` in
        # ``__update_feed_dict__``.
        self.results = []
        self.x_cls = []
        self.ys = []
        self.y_seq_lens = []
        self.y_cls = []
        # Never appended to in this class; presumably populated by
        # ``tfu.seq2seq`` -- confirm.
        self.keep_drops = []
        self.decoder_keep_drops = []

    def __build_graph_for_device__(self, x, x_seq_len, batch_size, cls=None):
        """Create this device's placeholders and build the seq2seq graph.

        Args:
            x: Input tensor, recorded in ``self.xs``.
            x_seq_len: Input lengths tensor, recorded in ``self.x_seq_lens``.
            batch_size: Batch size; when ``None`` the value of
                ``self.settings['batch_size']`` is used instead.
            cls: Unused by this implementation.
        """
        # ``xs`` / ``x_seq_lens`` are not created in this class; presumably
        # they are provided by ``GraphPartBase``.
        self.xs.append(x)
        self.x_seq_lens.append(x_seq_len)

        x_cls = tf.placeholder(dtype=tf.int32, shape=(None,), name='XClass')
        self.x_cls.append(x_cls)

        if batch_size is None:
            batch_size = self.settings['batch_size']

        y = tf.placeholder(dtype=tf.int32, shape=(None, None), name='Y')
        self.ys.append(y)

        y_seq_len = tf.placeholder(dtype=tf.int32, shape=(None,), name='YSeqLen')
        self.y_seq_lens.append(y_seq_len)

        y_cls = tf.placeholder(dtype=tf.int32, shape=(None,), name='YClass')
        self.y_cls.append(y_cls)

        tfu.seq2seq(self, batch_size, x, y_cls, x_seq_len, y, x_cls, y_seq_len)

    def __update_feed_dict__(self, op_name, feed_dict, batch, dev_num):
        """Add this part's placeholder values for device *dev_num* to *feed_dict*.

        Args:
            op_name: Unused by this implementation.
            feed_dict: Mapping that is updated in place.
            batch: Mapping with the keys ``'x_cls'``, ``'y_cls'``, ``'y'``
                and ``'y_seq_len'``.
            dev_num: Index of the device whose placeholders are fed.
        """
        feed_dict[self.x_cls[dev_num]] = batch['x_cls']
        feed_dict[self.y_cls[dev_num]] = batch['y_cls']
        feed_dict[self.ys[dev_num]] = batch['y']
        feed_dict[self.y_seq_lens[dev_num]] = batch['y_seq_len']
        feed_dict[self.keep_drops[dev_num]] = self.settings['keep_drop']
        feed_dict[self.decoder_keep_drops[dev_num]] = self.settings['decoder']['keep_drop']

    def __load_dataset__(self, operation_name):
        """Return the batches yielded by ``tfu.load_inflect_dataset`` as a list.

        Args:
            operation_name: Forwarded to ``tfu.load_inflect_dataset``.
        """
        return list(tfu.load_inflect_dataset(
            self.dataset_path,
            self.devices_count,
            operation_name,
            self.settings['batch_size'],
        ))

    def transfer_learning_init(self, sess):
        """Copy variable values from the ``Lemm/`` scope into this part's scope.

        Variables are paired by their name relative to the scope prefix.
        Variables whose name contains ``"Adam"`` are ignored on both sides.

        Args:
            sess: Session used to run the assignments.

        Raises:
            KeyError: If a variable of this part has no same-named
                counterpart under ``Lemm/``. Nothing is copied in that case.
        """
        own_vars = _variables_by_relative_name(f"{self.main_scope_name}/")
        lem_vars = _variables_by_relative_name("Lemm/")

        # Validate before touching any variable so a mismatch cannot leave
        # the model half-initialised.
        missing = sorted(name for name in own_vars if name not in lem_vars)
        if missing:
            raise KeyError(
                "No matching 'Lemm/' variable for: " + ", ".join(missing)
            )

        # Assign variable-to-variable inside the graph: the values never
        # travel through Python and are not embedded in the graph as
        # constants, and a single ``run`` call performs every copy.
        assign_ops = [
            target.assign(lem_vars[name]) for name, target in own_vars.items()
        ]
        if assign_ops:
            sess.run(tf.group(*assign_ops))

        # The blank line on stdout is kept for output compatibility.
        print()