File size: 4,208 Bytes
0240c6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import tensorflow as tf
import tf_utils as tfu
from graph.base import GraphPartBase


class MainCls(GraphPartBase):
    def __init__(self, for_usage, global_settings, current_settings, optimiser, reset_optimiser):
        super().__init__(for_usage, global_settings, current_settings, optimiser, reset_optimiser, 'main', ['Loss', 'Accuracy'])
        self.checks = []
        self.weights = []
        self.keep_drops = []
        self.losses = []
        self.probs = []
        self.results = []
        self.ys = []
        self.drops = []
        self.top_k = global_settings['main_class_k']

    def __build_graph_for_device__(self, x, seq_len, gram_probs, gram_drop):
        self.xs.append(x)
        self.x_seq_lens.append(seq_len)
        self.drops.append(gram_drop)

        y = tf.placeholder(dtype=tf.int32, shape=(None, self.main_classes_count), name='Y')
        weights = tf.placeholder(dtype=tf.float32, shape=(None,), name='Weight')
        x_emd_init = tf.random_normal((self.chars_count, self.settings['char_vector_size']))
        x_emb = tf.get_variable("Embeddings", initializer=x_emd_init)
        rnn_input = tf.nn.embedding_lookup(x_emb, x)

        if self.for_usage:
            cls_keep_drop = tf.constant(1, dtype=tf.float32, name='KeepDrop')
        else:
            cls_keep_drop = tf.placeholder(dtype=tf.float32, name='KeepDrop')
        self.keep_drops.append(cls_keep_drop)

        init_state = tf.concat(gram_probs, 1)
        with tf.variable_scope("InitRnnState", reuse=tf.AUTO_REUSE) as scope:
            rez_size = self.settings['rnn_state_size']
            w_softmax = tf.get_variable("W", (init_state.shape[1], rez_size))
            b_softmax = tf.get_variable("b", [rez_size])
            init_state = tf.matmul(init_state, w_softmax) + b_softmax

        rnn_logits = tfu.build_rnn(rnn_input,
                                   cls_keep_drop,
                                   seq_len,
                                   self.settings,
                                   init_state,
                                   init_state,
                                   self.for_usage)

        logits = tfu.rnn_top('RnnTop',
                             rnn_logits,
                             self.settings,
                             self.main_classes_count)

        float_y = tf.cast(y, tf.float32)
        errors = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,
                                                         labels=float_y)
        errors = tf.reduce_sum(errors, axis=1)
        errors = errors * weights
        if not self.for_usage:
            self.checks.append(tf.check_numerics(errors, "ErrorNullCheck"))

        probs = tf.nn.softmax(logits)
        result = tf.math.top_k(probs, self.top_k, name="Results")
        loss = tf.reduce_sum(errors)

        if not self.for_usage:
            self.checks.append(tf.check_numerics(errors, "LossNullCheck"))

        vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.main_scope_name)
        grads = self.optimiser.compute_gradients(loss, var_list=vars)

        self.losses.append(loss)
        self.probs.append(probs)
        self.results.append(result)
        self.ys.append(y)
        self.dev_grads.append(grads)
        self.weights.append(weights)

        # metrics
        self.create_mean_metric(0, loss)
        labels = tf.math.argmax(y, axis=1)
        predictions = tf.math.argmax(probs, axis=1)
        self.create_accuracy_metric(1, labels, predictions)

    def __update_feed_dict__(self, op_name, feed_dict, batch, dev_num):
        for gram_drop in self.drops[dev_num]:
            feed_dict[gram_drop] = 1
        feed_dict[self.keep_drops[dev_num]] = 1 if op_name == 'test' else self.settings['keep_drop']
        feed_dict[self.ys[dev_num]] = batch['y']
        feed_dict[self.weights[dev_num]] = batch['weight']

    def __load_dataset__(self, operation_name):
        return list(
            tfu.load_cls_dataset(
                self.dataset_path,
                self.devices_count,
                operation_name,
                self.settings['train_batch_size'],
                self.settings['use_weights']
            )
        )