| import os, pickle |
| import numpy as np |
| import tensorflow as tf |
| from tqdm import tqdm |
| from model import RNN |
| from utils import CONFIG, decode_word, load_datasets |
|
|
|
|
class Tester:
    """Evaluates a trained RNN morphology model on the test datasets.

    Reports full/partial classification accuracy for every grammeme graph
    part and for the main classifier, plus word-level lemmatization and
    inflection accuracy.
    """

    def __init__(self):
        self.config = CONFIG
        # The lemmatizer graph part must expose a class placeholder so we
        # can feed the gold 'main_cls' alongside each word during testing.
        self.config['graph_part_configs']['lemm']['use_cls_placeholder'] = True
        self.rnn = RNN(True)
        # Maps each known character to its integer id for input encoding.
        self.chars = {c: index for index, c in enumerate(self.config['chars'])}
        self.batch_size = 65536
        self.show_bad_items = False

    def test(self):
        """Run all evaluations and return the report as one string.

        Each metric line is also echoed via tqdm.write so that console
        output does not garble active progress bars.
        """
        config = tf.ConfigProto(allow_soft_placement=True)
        results = []
        with tf.Session(config=config, graph=self.rnn.graph) as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            self.rnn.restore(sess)

            # Per-grammeme classifiers.
            for gram in self.rnn.gram_keys:
                full_cls_acc, part_cls_acc, _ = self.__test_classification__(
                    sess, gram, self.rnn.gram_graph_parts[gram], 'test'
                )
                result = f"{gram}. full_cls_acc: {full_cls_acc}; part_cls_acc: {part_cls_acc}"
                results.append(result)
                tqdm.write(result)

            # Main (combined) classifier.
            full_cls_acc, part_cls_acc, _ = self.__test_classification__(
                sess, 'main', self.rnn.main_graph_part, 'test'
            )
            result = f"main. full_cls_acc: {full_cls_acc}; part_cls_acc: {part_cls_acc}"
            results.append(result)
            tqdm.write(result)

            # Lemmatization accuracy.
            lemm_acc, _ = self.__test_lemmas__(sess, 'test')
            result = f"lemma_acc: {lemm_acc}"
            results.append(result)
            tqdm.write(result)

            # Inflection accuracy.
            inflect_acc, _ = self.__test_inflect__(sess, 'test')
            result = f"inflect_acc: {inflect_acc}"
            results.append(result)
            tqdm.write(result)

        return "\n".join(results)

    def __get_classification_items__(self, sess, items, graph_part):
        """Run the classifier over *items* in batches of self.batch_size.

        Returns (probability rows, etalon label vectors), aligned by index.
        Unknown characters fall back to the 'UNDEFINED' character id.
        """
        undefined_id = self.chars['UNDEFINED']
        wi = 0
        pbar = tqdm(total=len(items), desc='Getting classification info')
        results = []
        etalon = []

        while wi < len(items):
            bi = 0
            xs = []
            indexes = []
            seq_lens = []
            max_len = 0

            # Pack up to batch_size words into sparse-tensor feed arrays:
            # char ids (xs), their [row, col] positions (indexes) and
            # per-word lengths (seq_lens).
            while bi < self.batch_size and wi < len(items):
                word = items[wi]['src']
                etalon.append(items[wi]['y'])
                for c_index, char in enumerate(word):
                    xs.append(self.chars.get(char, undefined_id))
                    indexes.append([bi, c_index])
                cur_len = len(word)
                max_len = max(max_len, cur_len)
                seq_lens.append(cur_len)
                bi += 1
                wi += 1
                pbar.update(1)

            nn_results = sess.run(
                [graph_part.probs[0]],
                {
                    self.rnn.batch_size: bi,
                    self.rnn.x_seq_lens[0]: np.asarray(seq_lens),
                    self.rnn.x_vals[0]: np.asarray(xs),
                    self.rnn.x_inds[0]: np.asarray(indexes),
                    self.rnn.x_shape[0]: np.asarray([bi, max_len])
                }
            )
            results.extend(nn_results[0])

        pbar.close()
        return results, etalon

    def __get_lemma_items__(self, sess, items):
        """Yield the decoded lemma predicted for each item, in input order.

        Items must already be filtered to words whose characters are all
        known (see __test_lemmas__) — lookups here do not fall back.
        """
        wi = 0
        pbar = tqdm(total=len(items))
        while wi < len(items):
            bi = 0
            xs = []
            clss = []
            indexes = []
            seq_lens = []
            max_len = 0

            while bi < self.batch_size and wi < len(items):
                item = items[wi]
                word = item['x_src']
                for c_index, char in enumerate(word):
                    xs.append(self.chars[char])
                    indexes.append([bi, c_index])
                cur_len = len(word)
                # Gold main class id fed into the lemmatizer placeholder.
                clss.append(item['main_cls'])
                max_len = max(max_len, cur_len)
                seq_lens.append(cur_len)
                bi += 1
                wi += 1
                pbar.update(1)

            results = sess.run(
                [self.rnn.lem_result],
                {
                    self.rnn.batch_size: bi,
                    self.rnn.x_seq_lens[0]: np.asarray(seq_lens),
                    self.rnn.x_vals[0]: np.asarray(xs),
                    self.rnn.x_inds[0]: np.asarray(indexes),
                    self.rnn.lem_class_pl: np.asarray(clss),
                    self.rnn.x_shape[0]: np.asarray([bi, max_len])
                }
            )
            for word_src in results[0]:
                yield decode_word(word_src[0])
        pbar.close()

    def __get_inflect_items__(self, sess, items):
        """Yield the decoded inflected form predicted for each item, in order."""
        wi = 0
        pbar = tqdm(total=len(items))
        while wi < len(items):
            bi = 0
            xs = []
            x_clss = []
            y_clss = []
            indexes = []
            seq_lens = []
            max_len = 0

            while bi < self.batch_size and wi < len(items):
                item = items[wi]
                word = item['x_src']
                for c_index, char in enumerate(word):
                    xs.append(self.chars[char])
                    indexes.append([bi, c_index])
                cur_len = len(word)
                # Source and target class ids fed to the inflection graph.
                x_clss.append(item['x_cls'])
                y_clss.append(item['y_cls'])
                max_len = max(max_len, cur_len)
                seq_lens.append(cur_len)
                bi += 1
                wi += 1
                pbar.update(1)

            results = sess.run(
                [self.rnn.inflect_graph_part.results[0]],
                {
                    self.rnn.batch_size: bi,
                    self.rnn.x_seq_lens[0]: np.asarray(seq_lens),
                    self.rnn.x_vals[0]: np.asarray(xs),
                    self.rnn.x_inds[0]: np.asarray(indexes),
                    self.rnn.inflect_graph_part.x_cls[0]: np.asarray(x_clss),
                    self.rnn.inflect_graph_part.y_cls[0]: np.asarray(y_clss),
                    self.rnn.x_shape[0]: np.asarray([bi, max_len])
                }
            )

            for word_src in results[0]:
                yield decode_word(word_src)
        pbar.close()

    def __test_classification__(self, sess, key, graph_part, *ds_types):
        """Measure classification accuracy for one graph part.

        Full accuracy: fraction of words whose top-k predicted classes
        exactly match the k gold classes.  Partial accuracy: fraction of
        individual gold classes recovered.  Returns
        (full_acc, part_acc, bad_items).
        """
        et_items = load_datasets(key, *ds_types)
        results, etalon = self.__get_classification_items__(sess, et_items, graph_part)
        total = len(etalon)
        total_classes = 0
        full_correct = 0
        part_correct = 0
        bad_items = []

        for index, et in enumerate(etalon):
            classes_count = et.sum()
            good_classes = np.argwhere(et == 1).ravel()
            # Take exactly as many top-probability classes as gold ones.
            rez_classes = np.argsort(results[index])[-classes_count:]

            total_classes += classes_count
            correct = True
            for cls in rez_classes:
                if cls in good_classes:
                    part_correct += 1
                else:
                    correct = False

            if correct:
                full_correct += 1
            else:
                bad_items.append((et_items[index], rez_classes))

        full_acc = full_correct / total
        cls_correct = part_correct / total_classes
        return full_acc, cls_correct, bad_items

    def __filter_known__(self, items):
        """Keep only items whose source word uses known characters.

        Uses the char->id dict built in __init__ for O(1) membership tests.
        """
        return [
            word
            for word in items
            if all(c in self.chars for c in word['x_src'])
        ]

    def __score_words__(self, good_items, predictions):
        """Compare *predictions* with the gold 'y_src' words, pairwise.

        Returns (accuracy, list of (gold item, wrong prediction)).  An
        empty input scores 0.0 rather than raising ZeroDivisionError.
        """
        bad_words = []
        wrong = 0
        for index, rez in enumerate(predictions):
            et_word = good_items[index]
            if rez != et_word['y_src']:
                wrong += 1
                bad_words.append((et_word, rez))

        total = len(good_items)
        acc = (total - wrong) / total if total else 0.0
        return acc, bad_words

    def __test_lemmas__(self, sess, *ds_types):
        """Word-level lemmatization accuracy on the given dataset types."""
        good_items = self.__filter_known__(load_datasets("lemma", *ds_types))
        results = list(self.__get_lemma_items__(sess, good_items))
        return self.__score_words__(good_items, results)

    def __test_inflect__(self, sess, *ds_types):
        """Word-level inflection accuracy on the given dataset types."""
        good_items = self.__filter_known__(load_datasets("inflect", *ds_types))
        rez_words = list(self.__get_inflect_items__(sess, good_items))
        return self.__score_words__(good_items, rez_words)
|
|
|
|
if __name__ == "__main__":
    # Run the full evaluation suite when executed as a script.
    runner = Tester()
    runner.test()
|
|