|
|
import os |
|
|
import shutil |
|
|
import tensorflow as tf |
|
|
from graph.gram_cls import GramCls |
|
|
from graph.main_cls import MainCls |
|
|
from graph.lemm import Lemm |
|
|
from graph.inflect import Inflect |
|
|
from graph.base import TfContext |
|
|
from utils import MyDefaultDict, CONFIG |
|
|
from tensorflow.python.tools import freeze_graph |
|
|
|
|
|
|
|
|
class RNN: |
|
|
def __init__(self, for_usage): |
|
|
self.config = CONFIG |
|
|
self.filler = self.config['filler'] |
|
|
self.checkpoints_keep = 200000 |
|
|
self.for_usage = for_usage |
|
|
self.default_config = self.config['graph_part_configs']['default'] |
|
|
self.key_configs = MyDefaultDict( |
|
|
lambda key: self.default_config, |
|
|
{ |
|
|
key: MyDefaultDict(lambda prop_key: self.default_config[prop_key], self.config['graph_part_configs'][key]) |
|
|
for key in self.config['graph_part_configs'] |
|
|
if key != 'default' |
|
|
} |
|
|
) |
|
|
self.export_path = self.config['export_path'] |
|
|
self.save_path = self.config['save_path'] |
|
|
self.model_key = self.config['model_key'] |
|
|
self.miss_steps = self.config['miss_steps'] if 'miss_steps' in self.config else [] |
|
|
self.start_char = self.config['start_token'] |
|
|
self.end_char = self.config['end_token'] |
|
|
self.gram_keys = [ |
|
|
key |
|
|
for key in sorted(self.config['grammemes_types'], key=lambda x: self.config['grammemes_types'][x]['index']) |
|
|
] |
|
|
self.main_class_k = self.config['main_class_k'] |
|
|
self.train_steps = self.config['train_steps'] |
|
|
|
|
|
if for_usage: |
|
|
self.devices = ['/cpu:0'] |
|
|
else: |
|
|
self.devices = self.config['train_devices'] |
|
|
|
|
|
if not os.path.exists(self.save_path): |
|
|
os.mkdir(self.save_path) |
|
|
|
|
|
self._build_graph() |
|
|
|
|
|
    def _build_graph(self):
        """Construct the TensorFlow graph for every part on every device.

        In inference mode (for_usage=True) the character matrix is fed as a
        sparse tensor (indices/values/shape placeholders) and densified with
        end_char padding; in training mode each device gets a dense int32
        placeholder instead. Populates per-device placeholder lists on self
        and wires the grammeme, main-class, lemma and inflection sub-graphs.
        """
        self.graph = tf.Graph()
        self.checks = []
        # One entry per device: dense input and its sequence lengths.
        self.xs = []
        self.x_seq_lens = []

        # Sparse-input placeholders; only filled in inference mode.
        self.x_inds = []
        self.x_vals = []
        self.x_shape = []

        self.prints = []

        with self.graph.as_default(), tf.device('/cpu:0'):
            self.is_training = tf.placeholder(tf.bool, name="IsTraining")
            self.learn_rate = tf.placeholder(tf.float32, name="LearningRate")
            # BatchSize placeholder exists only for inference; in training
            # mode None is passed down to the graph parts instead.
            self.batch_size = tf.placeholder(tf.int32, [], name="BatchSize") if self.for_usage else None
            self.optimiser = tf.train.AdamOptimizer(self.learn_rate)
            # Op to wipe Adam's slot variables between training phases.
            self.reset_optimizer = tf.variables_initializer(self.optimiser.variables())
            # One classifier sub-graph per grammeme type.
            self.gram_graph_parts = {
                gram: GramCls(gram, self.for_usage, self.config, self.key_configs[gram], self.optimiser, self.reset_optimizer)
                for gram in self.gram_keys
            }
            self.lem_graph_part = Lemm(self.for_usage, self.config, self.key_configs["lemm"], self.optimiser, self.reset_optimizer)
            self.main_graph_part = MainCls(self.for_usage, self.config, self.key_configs["main"], self.optimiser, self.reset_optimizer)
            self.inflect_graph_part = Inflect(self.for_usage, self.config, self.key_configs['inflect'], self.optimiser, self.reset_optimizer)

            for device_index, device_name in enumerate(self.devices):
                with tf.device(device_name):
                    if self.for_usage:
                        # Inference input arrives as a sparse COO tensor:
                        # int32 placeholders, cast to int64 as required by
                        # tf.sparse.SparseTensor.
                        x_ind_pl = tf.placeholder(dtype=tf.int32, shape=(None, None), name='XIndexes')
                        x_val_pl = tf.placeholder(dtype=tf.int32, shape=(None,), name='XValues')
                        x_shape_pl = tf.placeholder(dtype=tf.int32, shape=(2,), name='XShape')
                        x_ind = tf.dtypes.cast(x_ind_pl, dtype=tf.int64)
                        x_val = tf.dtypes.cast(x_val_pl, dtype=tf.int64)
                        x_shape = tf.dtypes.cast(x_shape_pl, dtype=tf.int64)

                        # Densify, padding gaps with the end-of-word token.
                        x_sparse = tf.sparse.SparseTensor(x_ind, x_val, x_shape)
                        x = tf.sparse.to_dense(x_sparse, default_value=self.end_char)
                        self.x_inds.append(x_ind_pl)
                        self.x_vals.append(x_val_pl)
                        self.x_shape.append(x_shape_pl)
                    else:
                        # Training input is a plain dense matrix.
                        x = tf.placeholder(dtype=tf.int32, shape=(None, None), name='X')
                        self.xs.append(x)

                    x_seq_len = tf.placeholder(dtype=tf.int32, shape=(None,), name='SeqLen')
                    self.x_seq_lens.append(x_seq_len)

                    for gram in self.gram_keys:
                        self.gram_graph_parts[gram].build_graph_for_device(x, x_seq_len)

                    # Main classifier consumes the latest per-gram outputs
                    # added by the loop above ([-1] = this device's tensors).
                    gram_probs = [self.gram_graph_parts[gram].probs[-1] for gram in self.gram_keys]
                    gram_keep_drops = [self.gram_graph_parts[gram].keep_drops[-1] for gram in self.gram_keys]
                    self.main_graph_part.build_graph_for_device(x, x_seq_len, gram_probs, gram_keep_drops)
                    self.prints.append(tf.print("main_result", self.main_graph_part.results[0].indices))
                    if self.for_usage:
                        # Lemmatize once per top-k main class: tile the batch
                        # k times and pair each copy with one predicted class.
                        x_tiled = tf.contrib.seq2seq.tile_batch(x, multiplier=self.main_class_k)
                        seq_len_tiled = tf.contrib.seq2seq.tile_batch(x_seq_len, multiplier=self.main_class_k)
                        cls = tf.reshape(self.main_graph_part.results[0].indices, (-1,))
                        batch_size_tiled = self.batch_size * self.main_class_k
                        self.lem_graph_part.build_graph_for_device(x_tiled,
                                                                   seq_len_tiled,
                                                                   batch_size_tiled,
                                                                   cls)
                        # (batch, k, lemma_len) view of the tiled results.
                        self.lem_cls_result = tf.reshape(self.lem_graph_part.results[0],
                                                         (self.batch_size, self.main_class_k, -1))
                        # Second, untiled lemma build with caller-supplied
                        # classes (exposed via lem_class_pl).
                        self.lem_graph_part.build_graph_for_device(x,
                                                                   x_seq_len,
                                                                   self.batch_size)
                        self.lem_result = tf.expand_dims(self.lem_graph_part.results[1], 1)
                        self.lem_class_pl = self.lem_graph_part.cls[0]
                    else:
                        # Training mode: self.batch_size is None here; the
                        # part presumably derives it internally — TODO confirm.
                        self.lem_graph_part.build_graph_for_device(x,
                                                                   x_seq_len,
                                                                   self.batch_size)

                    self.inflect_graph_part.build_graph_for_device(x, x_seq_len, self.batch_size)
                    if self.for_usage:
                        self.inflect_result = self.inflect_graph_part.results[0]
                        self.inflect_x_class_pl = self.inflect_graph_part.x_cls[0]
                        self.inflect_y_class_pl = self.inflect_graph_part.y_cls[0]

            # Finalize each part after all devices are wired (loss/metric
            # aggregation happens inside build_graph_end).
            for gram in self.gram_keys:
                self.gram_graph_parts[gram].build_graph_end()
            self.main_graph_part.build_graph_end()
            self.lem_graph_part.build_graph_end()
            self.inflect_graph_part.build_graph_end()
            self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=self.checkpoints_keep)
|
|
|
|
|
def restore(self, sess): |
|
|
latest_checkpoint = tf.train.latest_checkpoint(self.save_path) |
|
|
for gram in self.gram_keys: |
|
|
if gram not in self.config['ignore_restore']: |
|
|
self.gram_graph_parts[gram].restore(sess, latest_checkpoint) |
|
|
if self.main_graph_part.key not in self.config['ignore_restore']: |
|
|
self.main_graph_part.restore(sess, latest_checkpoint) |
|
|
if self.lem_graph_part.key not in self.config['ignore_restore']: |
|
|
self.lem_graph_part.restore(sess, latest_checkpoint) |
|
|
if self.inflect_graph_part not in self.config['ignore_restore']: |
|
|
self.inflect_graph_part.restore(sess, latest_checkpoint) |
|
|
|
|
|
if self.inflect_graph_part.settings['transfer_init']: |
|
|
self.inflect_graph_part.transfer_learning_init(sess) |
|
|
|
|
|
def train(self): |
|
|
config = tf.ConfigProto(allow_soft_placement=True) |
|
|
if not os.path.isdir(self.save_path): |
|
|
os.mkdir(self.save_path) |
|
|
|
|
|
with tf.Session(config = config, graph=self.graph) as sess: |
|
|
sess.run(tf.global_variables_initializer()) |
|
|
sess.run(tf.local_variables_initializer()) |
|
|
|
|
|
tc = TfContext(sess, self.saver, self.learn_rate) |
|
|
self.restore(sess) |
|
|
sess.run(self.reset_optimizer) |
|
|
|
|
|
for gram in self.gram_keys: |
|
|
if gram in self.train_steps: |
|
|
self.gram_graph_parts[gram].train(tc) |
|
|
|
|
|
if self.main_graph_part.key in self.train_steps: |
|
|
self.main_graph_part.train(tc) |
|
|
|
|
|
if self.lem_graph_part.key in self.train_steps: |
|
|
self.lem_graph_part.train(tc) |
|
|
|
|
|
if self.inflect_graph_part.key in self.train_steps: |
|
|
self.inflect_graph_part.train(tc) |
|
|
|
|
|
    def release(self):
        """Export the inference graph as a SavedModel plus a frozen GraphDef.

        Requires an instance built with for_usage=True: it reads attributes
        (x_inds, lem_cls_result, lem_class_pl, inflect_*_class_pl,
        self.batch_size as a placeholder) that only exist in that mode.

        Returns:
            tuple: (frozen_path, gram_op_dic, op_dic) — path to the frozen
            model file, per-grammeme {'res', 'prob'} op names, and a flat
            name -> op-name map for all inputs and outputs.
        """
        with tf.Session(graph=self.graph) as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())

            # Best effort: export freshly initialized weights if no
            # checkpoint exists yet.
            latest_checkpoint = tf.train.latest_checkpoint(self.save_path)
            if latest_checkpoint:
                self.restore(sess)

            # simple_save refuses to write into an existing directory, so
            # clear any previous export first.
            if os.path.isdir(self.export_path):
                shutil.rmtree(self.export_path)

            output_dic = {}
            gram_op_dic = {}
            for gram in self.gram_keys:
                res = self.gram_graph_parts[gram].results[0]
                prob = self.gram_graph_parts[gram].probs[0]
                output_dic[f'res_{gram}'] = res
                output_dic[f'prob_{gram}'] = prob
                gram_op_dic[gram] = {
                    'res': res.op.name,
                    'prob': prob.op.name
                }

            # Top-k main-class values and indices.
            output_dic['res_values'] = self.main_graph_part.results[0].values
            output_dic['res_indexes'] = self.main_graph_part.results[0].indices

            output_dic['lem_cls_result'] = self.lem_cls_result
            output_dic['lem_result'] = self.lem_result

            output_dic['inflect_result'] = self.inflect_graph_part.results[0]

            # Write the SavedModel with the first device's placeholders as
            # the signature inputs.
            tf.saved_model.simple_save(sess,
                                       self.export_path,
                                       inputs={
                                           'x_ind': self.x_inds[0],
                                           'x_val': self.x_vals[0],
                                           'x_shape': self.x_shape[0],
                                           'seq_len': self.x_seq_lens[0],
                                           'batch_size': self.batch_size,
                                           'lem_x_cls': self.lem_class_pl,
                                           'inflect_x_cls': self.inflect_x_class_pl,
                                           'inflect_y_cls': self.inflect_y_class_pl
                                       },
                                       outputs=output_dic)

            # Freeze: write the graph def as text, then fold the checkpoint
            # variables into constants.
            input_graph = 'graph.pbtxt'
            tf.train.write_graph(sess.graph.as_graph_def(), self.export_path, input_graph, as_text=True)
            input_graph = os.path.join(self.export_path, input_graph)
            frozen_path = os.path.join(self.export_path, 'frozen_model.pb')
            output_ops = [output_dic[key].op.name for key in output_dic]
            output_ops = ",".join(output_ops)
            # Positional freeze_graph args: input_graph, input_saver(""),
            # input_binary(False), input_checkpoint, output_node_names,
            # restore_op_name(""), filename_tensor_name(""), output_graph,
            # clear_devices(True), initializer_nodes("").
            freeze_graph.freeze_graph(input_graph,
                                      "",
                                      False,
                                      latest_checkpoint,
                                      output_ops,
                                      "",
                                      "",
                                      frozen_path,
                                      True,
                                      "",
                                      input_saved_model_dir=self.export_path)

            op_dic = {
                key: output_dic[key].op.name
                for key in output_dic
            }

            op_dic['x_ind'] = self.x_inds[0].op.name
            op_dic['x_val'] = self.x_vals[0].op.name
            op_dic['x_shape'] = self.x_shape[0].op.name
            # NOTE(review): simple_save above used self.x_seq_lens[0] for
            # 'seq_len', but this op name comes from
            # main_graph_part.x_seq_lens[0] — confirm both refer to the same
            # placeholder.
            op_dic['seq_len'] = self.main_graph_part.x_seq_lens[0].op.name
            op_dic['batch_size'] = self.batch_size.op.name
            op_dic['lem_x_cls'] = self.lem_class_pl.op.name
            op_dic['inflect_x_cls'] = self.inflect_x_class_pl.op.name
            op_dic['inflect_y_cls'] = self.inflect_y_class_pl.op.name

            return frozen_path, \
                   gram_op_dic , \
                   op_dic
|
|
|