File size: 9,930 Bytes
0240c6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
import os
import tensorflow as tf
import tf_utils as tfu
from tqdm import tqdm
from abc import ABC, abstractmethod
from utils import RANDOM


class TfContext:
    """Bundle of TensorFlow session objects shared by the training loops.

    Attributes:
        sess: session every op is run in.
        saver: saver used to write a checkpoint for each improved epoch.
        learn_rate_op: tensor that the train loop feeds with the current
            learning rate on every step.
        epoch: counter starting at 0; GraphPartBase.train advances it after
            each checkpoint it saves.
    """

    def __init__(self, sess, saver, learn_rate_op):
        self.sess, self.saver, self.learn_rate_op = sess, saver, learn_rate_op
        # No epoch has been saved yet.
        self.epoch = 0


class GraphPartBase(ABC):
    """Base class for one trainable part of a multi-device TensorFlow graph.

    Subclasses build the per-device sub-graph (``__build_graph_for_device__``),
    load their datasets (``__load_dataset__``) and add part-specific feed
    values (``__update_feed_dict__``). This class averages per-device
    gradients and metrics, runs the train/validation loops, checkpoints the
    best epoch and decays the learning rate when validation stalls.
    """

    def __init__(self, for_usage, global_settings, current_settings, optimiser, reset_optimiser, key, metric_names):
        """Read configuration and initialise the per-device collections.

        Args:
            for_usage: when true, ``build_graph_end`` skips the gradient and
                optimisation ops.
            global_settings: settings dict shared by all graph parts.
            current_settings: settings dict for this graph part.
            optimiser: optimiser whose ``apply_gradients`` builds ``optimize``.
            reset_optimiser: op run before rolling back to the best epoch.
            key: part name; ``key.title()`` is used as the variable scope.
            metric_names: ordered names of the metrics this part reports.
        """
        self.key = key
        self.global_settings = global_settings
        self.filler = global_settings['filler']
        self.main_metric_name = current_settings['main_metric_type']
        self.settings = current_settings
        self.for_usage = for_usage
        self.optimiser = optimiser
        self.reset_optimiser = reset_optimiser
        self.metric_names = metric_names
        self.max_word_size = global_settings['max_word_size']
        self.checkpoints_keep = 10000
        # NOTE(review): +1 presumably reserves one extra char index (padding?) — confirm.
        self.chars_count = len(global_settings['chars']) + 1
        self.dataset_path = global_settings['dataset_path']
        self.grammemes_count = len(global_settings['grammemes_types'])
        self.main_classes = global_settings['main_classes']
        self.main_classes_count = len(self.main_classes)
        # Metric ops collected from every device by create_*_metric.
        self.metrics_reset = []
        self.metrics_update = []
        self.devices_metrics = {metr: [] for metr in self.metric_names}
        self.main_scope_name = key.title()
        self.save_path = global_settings['save_path']
        # Per-device gradients/losses, filled by subclasses while building.
        self.dev_grads = []
        self.losses = []
        self.devices = global_settings['train_devices']
        self.devices_count = len(self.devices)
        # Per-device input tensors, indexed by device number in feed dicts.
        self.xs = []
        self.x_seq_lens = []
        self.prints = []
        # Same object as self.main_classes; kept for code reading this name.
        self.main_cls_dic = self.main_classes
        self.learn_rate_val = self.settings['learn_rate']
        self.best_model_metric = None
        self.best_epoch = None
        self.init_checkpoint = None

    def train(self, tc):
        """Train this part, keeping the best validation checkpoint.

        An epoch that improves the validation metric is checkpointed and
        advances ``tc.epoch``. An epoch that does not improve resets the
        optimiser and rolls back to the best weights. Stalled epochs (no
        improvement, or one below ``settings['stop_main_metric_delta']``)
        shuffle the training data and bump a counter. Once that counter has
        reached ``settings['return_step']``, the next stalled epoch decays
        the learning rate instead. Training ends when the decayed rate is
        below ``settings['min_learn_rate']`` and ``__before_finish__()``
        returns True.

        Args:
            tc: TfContext with session, saver and learning-rate tensor.

        Returns:
            (best_epoch, best_model_metric). best_epoch is -1 when no epoch
            beat the metric measured before training started.
        """
        return_step = 0
        trains = self.__load_dataset__('train')
        valids = self.__load_dataset__('valid')
        # Baseline: metric of whatever weights the session holds right now.
        self.best_model_metric = self.__valid_loop__(tc, valids)
        self.best_epoch = -1
        while True:
            tqdm.write(self.filler)
            tqdm.write(self.filler)
            tqdm.write(self.main_scope_name)

            train_main_metric = self.__train_loop__(tc, trains)
            valid_main_metric = self.__valid_loop__(tc, valids)

            tqdm.write(f"Epoch {tc.epoch} Train {self.main_metric_name}: {train_main_metric} Validation {self.main_metric_name}: {valid_main_metric}")
            need_decay = False
            delta = self.__calc_metric_delta__(self.best_model_metric, valid_main_metric)

            if delta > 0:
                if delta < self.settings['stop_main_metric_delta']:
                    tqdm.write(f"{self.main_metric_name} delta is less then min value")
                    need_decay = True
                else:
                    return_step = 0
                self.best_model_metric = valid_main_metric
                self.best_epoch = tc.epoch
                # Saver.save writes the checkpoint as "<save_path>-<tc.epoch>".
                tc.saver.save(tc.sess, self.save_path, tc.epoch)
                tc.epoch += 1
            else:
                tqdm.write("Best epoch is better then current")
                tc.sess.run(self.reset_optimiser)
                need_decay = True
                self.__restore_best_epoch__(tc)

            if not need_decay:
                continue

            if return_step == self.settings['return_step']:
                self.__decay_params__()
                if self.learn_rate_val < self.settings['min_learn_rate']:
                    tqdm.write(f"Learning rate {self.learn_rate_val} is less then min learning rate")
                    finish = self.__before_finish__()
                    if finish:
                        break
                return_step = 0
            else:
                RANDOM.shuffle(trains)
                tqdm.write("Return step increased")
                return_step += 1

        return self.best_epoch, self.best_model_metric

    def _make_launch_list(self, with_optimizer):
        """Return the ops fetched per batch: [optimize] + metric updates + prints."""
        launch = [self.optimize] if with_optimizer else []
        launch.extend(self.metrics_update)
        launch.extend(self.prints)
        return launch

    def __train_loop__(self, tc, trains):
        """Run one optimisation pass over ``trains``; return the main train metric."""
        tc.sess.run(self.metrics_reset)
        # The fetch list is the same for every batch, so build it once.
        launch = self._make_launch_list(True)
        for item in tqdm(trains, desc=f"Train, epoch {tc.epoch}"):
            feed_dic = self.__create_feed_dict__('train', item)
            feed_dic[tc.learn_rate_op] = self.learn_rate_val
            tc.sess.run(launch, feed_dic)

        return self.__write_metrics_report__(tc.sess, "Train")

    def __valid_loop__(self, tc, valids):
        """Evaluate metrics over ``valids`` without optimising; return the main metric."""
        tc.sess.run(self.metrics_reset)
        launch = self._make_launch_list(False)
        for item in tqdm(valids, desc=f"Validation, epoch {tc.epoch}"):
            feed_dic = self.__create_feed_dict__('valid', item)
            tc.sess.run(launch, feed_dic)

        return self.__write_metrics_report__(tc.sess, "Valid")

    def __calc_metric_delta__(self, best_model_metric, cur_model_metric):
        """Return the improvement of the current metric over the best one.

        The result is positive when the current value is better. For "Loss",
        lower is better. For any other metric, higher is better.
        """
        delta = best_model_metric - cur_model_metric
        if self.main_metric_name != "Loss":
            delta = -delta
        return delta

    def __before_finish__(self):
        """Hook called when the learning rate drops below its minimum.

        Return True to stop training. Subclasses may override this to return
        False and keep training.
        """
        return True

    def __decay_params__(self):
        """Multiply the learning rate by ``settings['learn_rate_decay_step']``."""
        self.learn_rate_val = self.learn_rate_val * self.settings['learn_rate_decay_step']
        tqdm.write(f"Learning rate decayed. New value: {self.learn_rate_val}")

    def _checkpoint_path(self, epoch):
        """Return the checkpoint prefix that ``Saver.save(..., global_step=epoch)`` writes.

        ``tf.train.Saver.save`` appends "-<global_step>" directly to
        ``save_path``. It does not insert a directory separator.
        """
        return f"{self.save_path}-{epoch}"

    def __restore_best_epoch__(self, tc):
        """Roll the session's variables back to the best weights seen so far."""
        if self.best_epoch == -1 and tc.epoch == 0:
            tqdm.write(f"Restoring from init_checkpoint {self.init_checkpoint}")
            self.restore(tc.sess, self.init_checkpoint)
        elif self.best_epoch == -1:
            # NOTE(review): train() has not saved checkpoint tc.epoch in this
            # run (it saves and then increments). This relies on the caller
            # having written it — verify.
            tqdm.write(f"Restoring best epoch {tc.epoch}")
            self.restore(tc.sess, self._checkpoint_path(tc.epoch))
        else:
            tqdm.write(f"Restoring best epoch {self.best_epoch}")
            self.restore(tc.sess, self._checkpoint_path(self.best_epoch))

    def build_graph_end(self):
        """Build the cross-device ops after every device graph has been added.

        Creates ``self.metrics``, the mean over devices of each metric. Unless
        ``for_usage`` is set, it also creates the averaged (optionally clipped
        to [-1, 1]) gradients, the ``optimize`` op and the summed ``loss``.
        """
        with tf.variable_scope(self.main_scope_name, reuse=tf.AUTO_REUSE):
            self.metrics = {
                metr: tf.reduce_mean(self.devices_metrics[metr], name=metr)
                for metr in self.devices_metrics
            }
            if not self.for_usage:
                self.grads = tfu.average_gradients(self.dev_grads)
                if self.settings['clip_grads']:
                    self.grads = [(tf.clip_by_value(grad, -1., 1.), var) for grad, var in self.grads]

                self.optimize = self.optimiser.apply_gradients(self.grads, name='Optimize')
                self.loss = tf.reduce_sum(self.losses, name='GlobalLoss')

    def build_graph_for_device(self, *args):
        """Build one device's sub-graph inside this part's variable scope."""
        with tf.variable_scope(self.main_scope_name, reuse=tf.AUTO_REUSE):
            self.__build_graph_for_device__(*args)

    def restore(self, sess, check_point):
        """Restore this part's variables (excluding Adam slots) from ``check_point``.

        This is best-effort. Any failure is logged and training continues with
        the current weights. On success ``init_checkpoint`` is updated.
        """
        try:
            part_vars = [
                var
                for var in tf.global_variables(f"{self.main_scope_name}/")
                if "Adam" not in var.name
            ]
            saver = tf.train.Saver(var_list=part_vars)
            saver.restore(sess, check_point)
            self.init_checkpoint = check_point
            tqdm.write(f"Restoration for graph part '{self.key}', scope {self.main_scope_name} success")
        except Exception as ex:
            tqdm.write(f"Restoration for graph part '{self.key}', scope {self.main_scope_name} failed. Error: {ex}")

    def __write_metrics_report__(self, sess, step_name):
        """Evaluate all metrics, log one summary line, and return the main metric."""
        tqdm.write('')
        launch_results = sess.run(self.metrics)
        result = [f"{step_name} metrics: "]

        # Label each value with its own dict key rather than a parallel index.
        for metr in self.metrics:
            result.append('{:>8}'.format(metr))
            result.append("=")
            result.append("{0:.7f}".format(launch_results[metr]))
            result.append(" ")

        tqdm.write("".join(result))
        return launch_results[self.main_metric_name]

    def _register_metric(self, metric_index, metric_ops):
        """Store one device's (value, update, reset) metric ops."""
        metr_value, metr_update, metr_reset = metric_ops
        self.metrics_reset.append(metr_reset)
        self.metrics_update.append(metr_update)
        self.devices_metrics[self.metric_names[metric_index]].append(metr_value)

    def create_mean_metric(self, metric_index, values):
        """Add a resettable ``tf.metrics.mean`` over ``values`` for this device."""
        self._register_metric(metric_index, tfu.create_reset_metric(
            tf.metrics.mean,
            self.metric_names[metric_index],
            values
        ))

    def create_accuracy_metric(self, metric_index, labels, predictions):
        """Add a resettable ``tf.metrics.accuracy`` for this device."""
        self._register_metric(metric_index, tfu.create_reset_metric(
            tf.metrics.accuracy,
            self.metric_names[metric_index],
            labels=labels,
            predictions=predictions
        ))

    def __create_feed_dict__(self, op_name, item):
        """Build a feed dict from ``item``, a sequence of per-device batches.

        The i-th batch feeds the i-th device's inputs. Subclasses then add
        their own entries in ``__update_feed_dict__``.
        """
        feed_dic = {}
        for dev_num, batch in enumerate(item):
            feed_dic[self.xs[dev_num]] = batch['x']
            feed_dic[self.x_seq_lens[dev_num]] = batch['x_seq_len']
            self.__update_feed_dict__(op_name, feed_dic, batch, dev_num)

        return feed_dic

    @abstractmethod
    def __update_feed_dict__(self, op_name, feed_dict, batch, dev_num):
        """Add part-specific entries for device ``dev_num`` to ``feed_dict`` in place."""
        pass

    @abstractmethod
    def __build_graph_for_device__(self, *args):
        """Build this part's ops for one device."""
        pass

    @abstractmethod
    def __load_dataset__(self, operation_name):
        """Return the batches for ``operation_name`` ('train' or 'valid')."""
        return []