import os
import shutil

import tensorflow as tf
from tensorflow.python.tools import freeze_graph

from graph.gram_cls import GramCls
from graph.main_cls import MainCls
from graph.lemm import Lemm
from graph.inflect import Inflect
from graph.base import TfContext
from utils import MyDefaultDict, CONFIG


class RNN:
    """Combined morphology model: grammeme classifiers, a main classifier,
    a lemmatizer and an inflector sharing one TF graph and optimizer.

    ``for_usage=True`` builds a CPU-only inference graph fed through sparse
    placeholders (so it can be frozen/exported); ``for_usage=False`` builds a
    training graph replicated across the configured devices.
    """

    def __init__(self, for_usage):
        """Read configuration, resolve per-part configs and build the graph.

        :param for_usage: True for the inference/export graph, False for training.
        """
        self.config = CONFIG
        self.filler = self.config['filler']
        self.checkpoints_keep = 200000
        self.for_usage = for_usage
        # Per-graph-part configs fall back to the 'default' section for any
        # key a specific part does not override.
        self.default_config = self.config['graph_part_configs']['default']
        self.key_configs = MyDefaultDict(
            lambda key: self.default_config,
            {
                key: MyDefaultDict(
                    lambda prop_key: self.default_config[prop_key],
                    self.config['graph_part_configs'][key]
                )
                for key in self.config['graph_part_configs']
                if key != 'default'
            }
        )
        self.export_path = self.config['export_path']
        self.save_path = self.config['save_path']
        self.model_key = self.config['model_key']
        self.miss_steps = self.config['miss_steps'] if 'miss_steps' in self.config else []
        self.start_char = self.config['start_token']
        self.end_char = self.config['end_token']
        # Grammeme categories ordered by their configured 'index'.
        self.gram_keys = [
            key
            for key in sorted(
                self.config['grammemes_types'],
                key=lambda x: self.config['grammemes_types'][x]['index']
            )
        ]
        self.main_class_k = self.config['main_class_k']
        self.train_steps = self.config['train_steps']
        # Inference is pinned to CPU; training uses the configured device list.
        if for_usage:
            self.devices = ['/cpu:0']
        else:
            self.devices = self.config['train_devices']
        if not os.path.exists(self.save_path):
            os.mkdir(self.save_path)
        self._build_graph()

    def _build_graph(self):
        """Build the shared TF graph: common placeholders/optimizer on CPU,
        then one tower of graph parts per device."""
        self.graph = tf.Graph()
        self.checks = []
        self.xs = []
        self.x_seq_lens = []
        self.x_inds = []
        self.x_vals = []
        self.x_shape = []
        self.prints = []
        with self.graph.as_default(), tf.device('/cpu:0'):
            self.is_training = tf.placeholder(tf.bool, name="IsTraining")
            self.learn_rate = tf.placeholder(tf.float32, name="LearningRate")
            # Batch size is only needed as a feedable scalar in inference mode.
            self.batch_size = tf.placeholder(tf.int32, [], name="BatchSize") if self.for_usage else None
            self.optimiser = tf.train.AdamOptimizer(self.learn_rate)
            self.reset_optimizer = tf.variables_initializer(self.optimiser.variables())
            self.gram_graph_parts = {
                gram: GramCls(gram, self.for_usage, self.config, self.key_configs[gram],
                              self.optimiser, self.reset_optimizer)
                for gram in self.gram_keys
            }
            self.lem_graph_part = Lemm(self.for_usage, self.config, self.key_configs["lemm"],
                                       self.optimiser, self.reset_optimizer)
            self.main_graph_part = MainCls(self.for_usage, self.config, self.key_configs["main"],
                                           self.optimiser, self.reset_optimizer)
            self.inflect_graph_part = Inflect(self.for_usage, self.config, self.key_configs['inflect'],
                                              self.optimiser, self.reset_optimizer)
            for device_index, device_name in enumerate(self.devices):
                with tf.device(device_name):
                    if self.for_usage:
                        # Inference feeds X as a sparse tensor (indexes/values/shape)
                        # and densifies it with end_char as padding.
                        x_ind_pl = tf.placeholder(dtype=tf.int32, shape=(None, None), name='XIndexes')
                        x_val_pl = tf.placeholder(dtype=tf.int32, shape=(None,), name='XValues')
                        x_shape_pl = tf.placeholder(dtype=tf.int32, shape=(2,), name='XShape')
                        x_ind = tf.dtypes.cast(x_ind_pl, dtype=tf.int64)
                        x_val = tf.dtypes.cast(x_val_pl, dtype=tf.int64)
                        x_shape = tf.dtypes.cast(x_shape_pl, dtype=tf.int64)
                        x_sparse = tf.sparse.SparseTensor(x_ind, x_val, x_shape)
                        x = tf.sparse.to_dense(x_sparse, default_value=self.end_char)
                        self.x_inds.append(x_ind_pl)
                        self.x_vals.append(x_val_pl)
                        self.x_shape.append(x_shape_pl)
                    else:
                        x = tf.placeholder(dtype=tf.int32, shape=(None, None), name='X')
                    self.xs.append(x)
                    x_seq_len = tf.placeholder(dtype=tf.int32, shape=(None,), name='SeqLen')
                    self.x_seq_lens.append(x_seq_len)
                    for gram in self.gram_keys:
                        self.gram_graph_parts[gram].build_graph_for_device(x, x_seq_len)
                    # Main classifier consumes every grammeme tower's outputs.
                    gram_probs = [self.gram_graph_parts[gram].probs[-1] for gram in self.gram_keys]
                    gram_keep_drops = [self.gram_graph_parts[gram].keep_drops[-1] for gram in self.gram_keys]
                    self.main_graph_part.build_graph_for_device(x, x_seq_len, gram_probs, gram_keep_drops)
                    self.prints.append(tf.print("main_result", self.main_graph_part.results[0].indices))
                    if self.for_usage:
                        # Lemmatize once per top-k main class: tile the batch
                        # k times so each copy is conditioned on one class.
                        x_tiled = tf.contrib.seq2seq.tile_batch(x, multiplier=self.main_class_k)
                        seq_len_tiled = tf.contrib.seq2seq.tile_batch(x_seq_len, multiplier=self.main_class_k)
                        cls = tf.reshape(self.main_graph_part.results[0].indices, (-1,))
                        batch_size_tiled = self.batch_size * self.main_class_k
                        self.lem_graph_part.build_graph_for_device(x_tiled, seq_len_tiled, batch_size_tiled, cls)
                        self.lem_cls_result = tf.reshape(self.lem_graph_part.results[0],
                                                         (self.batch_size, self.main_class_k, -1))
                        # Second, un-tiled build for lemmatization with an
                        # externally supplied class placeholder.
                        self.lem_graph_part.build_graph_for_device(x, x_seq_len, self.batch_size)
                        self.lem_result = tf.expand_dims(self.lem_graph_part.results[1], 1)
                        self.lem_class_pl = self.lem_graph_part.cls[0]
                    else:
                        self.lem_graph_part.build_graph_for_device(x, x_seq_len, self.batch_size)
                    self.inflect_graph_part.build_graph_for_device(x, x_seq_len, self.batch_size)
                    if self.for_usage:
                        self.inflect_result = self.inflect_graph_part.results[0]
                        self.inflect_x_class_pl = self.inflect_graph_part.x_cls[0]
                        self.inflect_y_class_pl = self.inflect_graph_part.y_cls[0]
            for gram in self.gram_keys:
                self.gram_graph_parts[gram].build_graph_end()
            self.main_graph_part.build_graph_end()
            self.lem_graph_part.build_graph_end()
            self.inflect_graph_part.build_graph_end()
            self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=self.checkpoints_keep)

    def restore(self, sess):
        """Restore every graph part (except those listed in 'ignore_restore')
        from the latest checkpoint in save_path.

        :param sess: an open tf.Session on self.graph.
        """
        latest_checkpoint = tf.train.latest_checkpoint(self.save_path)
        for gram in self.gram_keys:
            if gram not in self.config['ignore_restore']:
                self.gram_graph_parts[gram].restore(sess, latest_checkpoint)
        if self.main_graph_part.key not in self.config['ignore_restore']:
            self.main_graph_part.restore(sess, latest_checkpoint)
        if self.lem_graph_part.key not in self.config['ignore_restore']:
            self.lem_graph_part.restore(sess, latest_checkpoint)
        # BUG FIX: the original tested `self.inflect_graph_part not in ...`,
        # comparing the graph-part OBJECT against a list of string keys — the
        # condition was effectively always true, so 'ignore_restore' could
        # never skip the inflect part. Use .key like the sibling checks above.
        if self.inflect_graph_part.key not in self.config['ignore_restore']:
            self.inflect_graph_part.restore(sess, latest_checkpoint)
        if self.inflect_graph_part.settings['transfer_init']:
            self.inflect_graph_part.transfer_learning_init(sess)

    def train(self):
        """Run the training step of every graph part listed in train_steps,
        after restoring the latest checkpoint and resetting the optimizer."""
        config = tf.ConfigProto(allow_soft_placement=True)
        if not os.path.isdir(self.save_path):
            os.mkdir(self.save_path)
        with tf.Session(config=config, graph=self.graph) as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            tc = TfContext(sess, self.saver, self.learn_rate)
            self.restore(sess)
            # Fresh Adam slot variables so restored weights don't inherit
            # stale optimizer state.
            sess.run(self.reset_optimizer)
            for gram in self.gram_keys:
                if gram in self.train_steps:
                    self.gram_graph_parts[gram].train(tc)
            if self.main_graph_part.key in self.train_steps:
                self.main_graph_part.train(tc)
            if self.lem_graph_part.key in self.train_steps:
                self.lem_graph_part.train(tc)
            if self.inflect_graph_part.key in self.train_steps:
                self.inflect_graph_part.train(tc)

    def release(self):
        """Export the inference graph: SavedModel in export_path plus a frozen
        GraphDef, and return (frozen_path, gram_op_dic, op_dic) where the two
        dicts map logical names to graph op names.

        Requires the graph to have been built with for_usage=True (uses the
        sparse-input placeholders and usage-only result attributes).
        """
        with tf.Session(graph=self.graph) as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            # Loading checkpoint
            latest_checkpoint = tf.train.latest_checkpoint(self.save_path)
            if latest_checkpoint:
                self.restore(sess)
            # simple_save refuses to write into an existing directory.
            if os.path.isdir(self.export_path):
                shutil.rmtree(self.export_path)
            output_dic = {}
            gram_op_dic = {}
            for gram in self.gram_keys:
                res = self.gram_graph_parts[gram].results[0]
                prob = self.gram_graph_parts[gram].probs[0]
                output_dic[f'res_{gram}'] = res
                output_dic[f'prob_{gram}'] = prob
                gram_op_dic[gram] = {
                    'res': res.op.name,
                    'prob': prob.op.name
                }
            output_dic['res_values'] = self.main_graph_part.results[0].values
            output_dic['res_indexes'] = self.main_graph_part.results[0].indices
            output_dic['lem_cls_result'] = self.lem_cls_result
            output_dic['lem_result'] = self.lem_result
            output_dic['inflect_result'] = self.inflect_graph_part.results[0]
            # Saving model
            tf.saved_model.simple_save(
                sess,
                self.export_path,
                inputs={
                    'x_ind': self.x_inds[0],
                    'x_val': self.x_vals[0],
                    'x_shape': self.x_shape[0],
                    'seq_len': self.x_seq_lens[0],
                    'batch_size': self.batch_size,
                    'lem_x_cls': self.lem_class_pl,
                    'inflect_x_cls': self.inflect_x_class_pl,
                    'inflect_y_cls': self.inflect_y_class_pl
                },
                outputs=output_dic
            )
            # Freezing graph
            input_graph = 'graph.pbtxt'
            tf.train.write_graph(sess.graph.as_graph_def(), self.export_path, input_graph, as_text=True)
            input_graph = os.path.join(self.export_path, input_graph)
            frozen_path = os.path.join(self.export_path, 'frozen_model.pb')
            output_ops = [output_dic[key].op.name for key in output_dic]
            output_ops = ",".join(output_ops)
            freeze_graph.freeze_graph(
                input_graph,
                "",
                False,
                latest_checkpoint,
                output_ops,
                "",
                "",
                frozen_path,
                True,
                "",
                input_saved_model_dir=self.export_path
            )
            op_dic = {
                key: output_dic[key].op.name
                for key in output_dic
            }
            op_dic['x_ind'] = self.x_inds[0].op.name
            op_dic['x_val'] = self.x_vals[0].op.name
            op_dic['x_shape'] = self.x_shape[0].op.name
            # NOTE(review): the SavedModel input above uses self.x_seq_lens[0]
            # while this op name comes from main_graph_part.x_seq_lens[0] —
            # presumably the same placeholder, but confirm they match.
            op_dic['seq_len'] = self.main_graph_part.x_seq_lens[0].op.name
            op_dic['batch_size'] = self.batch_size.op.name
            op_dic['lem_x_cls'] = self.lem_class_pl.op.name
            op_dic['inflect_x_cls'] = self.inflect_x_class_pl.op.name
            op_dic['inflect_y_cls'] = self.inflect_y_class_pl.op.name
            return frozen_path, \
                   gram_op_dic, \
                   op_dic