File size: 2,990 Bytes
0240c6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import tensorflow as tf
import tf_utils as tfu
from graph.base import GraphPartBase


class Inflect(GraphPartBase):
    """Graph part for the 'inflect' task, built on ``tfu.seq2seq``.

    For each device it creates placeholders for the input class, the target
    class and the target character sequence (with its lengths), feeds them
    from batches, and can initialise its weights from the ``Lemm/`` scope.
    """

    def __init__(self, for_usage, global_settings, current_settings, optimiser, reset_optimiser):
        super().__init__(
            for_usage,
            global_settings,
            current_settings,
            optimiser,
            reset_optimiser,
            'inflect',
            ["Loss", "AccuracyByChar", "Accuracy"],
        )
        # NOTE(review): one extra char index on top of the base class count —
        # presumably reserved for a special token; confirm against GraphPartBase.
        self.chars_count = self.chars_count + 1
        self.start_char_index = global_settings['start_token']
        self.end_char_index = global_settings['end_token']
        # Per-device tensors; __update_feed_dict__ indexes them by device number.
        self.results = []
        self.x_cls = []
        self.ys = []
        self.y_seq_lens = []
        self.y_cls = []
        # Not appended to in this class; presumably filled by
        # tfu.seq2seq(self, ...) — TODO confirm.
        self.keep_drops = []
        self.decoder_keep_drops = []

    def __build_graph_for_device__(self, x, x_seq_len, batch_size, cls=None):
        """Create this device's inflection placeholders and build seq2seq on them.

        Args:
            x: input tensor for this device (appended to ``self.xs``).
            x_seq_len: input length tensor (appended to ``self.x_seq_lens``).
            batch_size: batch size; ``None`` falls back to
                ``self.settings['batch_size']``.
            cls: unused here; kept for interface compatibility.
        """
        self.xs.append(x)
        self.x_seq_lens.append(x_seq_len)

        x_cls = tf.placeholder(dtype=tf.int32, shape=(None,), name='XClass')
        self.x_cls.append(x_cls)

        if batch_size is None:
            batch_size = self.settings['batch_size']

        y = tf.placeholder(dtype=tf.int32, shape=(None, None), name='Y')
        self.ys.append(y)

        y_seq_len = tf.placeholder(dtype=tf.int32, shape=(None,), name='YSeqLen')
        self.y_seq_lens.append(y_seq_len)

        y_cls = tf.placeholder(dtype=tf.int32, shape=(None,), name='YClass')
        self.y_cls.append(y_cls)

        # NOTE(review): y_cls is passed right after x, and x_cls only after y;
        # verify this positional order against the tfu.seq2seq signature.
        tfu.seq2seq(self,
                    batch_size,
                    x,
                    y_cls,
                    x_seq_len,
                    y,
                    x_cls,
                    y_seq_len)

    def __update_feed_dict__(self, op_name, feed_dict, batch, dev_num):
        """Add device ``dev_num``'s class/target/dropout values to ``feed_dict``.

        Reads the ``'x_cls'``, ``'y_cls'``, ``'y'`` and ``'y_seq_len'`` entries
        of ``batch``. ``op_name`` is not used here.
        """
        feed_dict[self.x_cls[dev_num]] = batch['x_cls']
        feed_dict[self.y_cls[dev_num]] = batch['y_cls']
        feed_dict[self.ys[dev_num]] = batch['y']
        feed_dict[self.y_seq_lens[dev_num]] = batch['y_seq_len']
        # NOTE(review): the configured keep_drop values are fed regardless of
        # op_name; confirm evaluation ops are not meant to use 1.0 here.
        feed_dict[self.keep_drops[dev_num]] = self.settings['keep_drop']
        feed_dict[self.decoder_keep_drops[dev_num]] = self.settings['decoder']['keep_drop']

    def __load_dataset__(self, operation_name):
        """Return the inflection dataset items for ``operation_name`` as a list."""
        return list(tfu.load_inflect_dataset(
            self.dataset_path,
            self.devices_count,
            operation_name,
            self.settings['batch_size'],
        ))

    @staticmethod
    def _non_adam_vars_by_suffix(prefix):
        """Map each global variable under ``prefix`` to its name minus ``prefix``.

        Variables whose name contains "Adam" (optimizer slot variables) are
        skipped.
        """
        return {
            var.name[len(prefix):]: var
            for var in tf.global_variables(prefix)
            if "Adam" not in var.name
        }

    def transfer_learning_init(self, sess):
        """Overwrite this part's weights with the matching ``Lemm/`` weights.

        Every non-Adam global variable under ``<main_scope_name>/`` receives
        the value of the variable with the same relative name under ``Lemm/``.
        ``Lemm/`` variables without a counterpart here are ignored.

        Args:
            sess: session in which both sets of variables live.

        Raises:
            KeyError: if a variable of this part has no ``Lemm/`` counterpart.
                The check happens before any assignment, so nothing is
                partially copied.
        """
        my_vars = self._non_adam_vars_by_suffix(f"{self.main_scope_name}/")
        lem_vars = self._non_adam_vars_by_suffix("Lemm/")

        missing = sorted(set(my_vars) - set(lem_vars))
        if missing:
            raise KeyError(f"No 'Lemm/' counterpart for variables: {missing}")

        # Assign directly from the source variables inside the graph.
        # Assigning a fetched numpy value would embed every weight matrix as
        # a constant node in the GraphDef. It would also cost two session
        # calls per variable.
        sess.run(tf.group(*(
            my_var.assign(lem_vars[name]) for name, my_var in my_vars.items()
        )))