| import os, pickle | |
| import numpy as np | |
| import tensorflow as tf | |
| from tqdm import tqdm | |
| from model import RNN | |
| from utils import CONFIG, decode_word, load_datasets | |
class Tester:
    """Evaluate a restored ``RNN`` on its test datasets.

    ``test()`` reports accuracy for every grammeme classifier, the main
    classifier, the lemmatizer and the inflector.
    """

    def __init__(self):
        # NOTE(review): this is the shared CONFIG object, not a copy, so the
        # flag below is visible to everything reading CONFIG. Presumably
        # RNN(...) relies on that, which is why it is set before RNN is built.
        self.config = CONFIG
        self.config['graph_part_configs']['lemm']['use_cls_placeholder'] = True
        self.rnn = RNN(True)  # meaning of the flag is defined in model.RNN
        # Character -> integer id, in the order of config['chars'].
        self.chars = {c: index for index, c in enumerate(self.config['chars'])}
        self.batch_size = 65536  # max number of words per sess.run call
        self.show_bad_items = False  # not read anywhere in this class

    def test(self):
        """Run every evaluation on the 'test' split and return the report.

        Each result line is also echoed through ``tqdm.write`` as it is
        produced.

        Returns:
            The result lines joined with newlines.
        """
        config = tf.ConfigProto(allow_soft_placement=True)
        lines = []

        def report(line):
            # Record and print exactly once.
            lines.append(line)
            tqdm.write(line)

        with tf.Session(config=config, graph=self.rnn.graph) as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            self.rnn.restore(sess)

            for gram in self.rnn.gram_keys:
                full_cls_acc, part_cls_acc, _ = self.__test_classification__(
                    sess, gram, self.rnn.gram_graph_parts[gram], 'test')
                report(f"{gram}. full_cls_acc: {full_cls_acc}; part_cls_acc: {part_cls_acc}")

            full_cls_acc, part_cls_acc, _ = self.__test_classification__(
                sess, 'main', self.rnn.main_graph_part, 'test')
            report(f"main. full_cls_acc: {full_cls_acc}; part_cls_acc: {part_cls_acc}")

            lemm_acc, _ = self.__test_lemmas__(sess, 'test')
            report(f"lemma_acc: {lemm_acc}")

            inflect_acc, _ = self.__test_inflect__(sess, 'test')
            report(f"inflect_acc: {inflect_acc}")

        return "\n".join(lines)

    # ------------------------------------------------------------------
    # Batching / encoding helpers
    # ------------------------------------------------------------------

    def _iter_batches(self, items, pbar):
        """Yield lists of consecutive ``items``, each at most ``self.batch_size`` long.

        ``pbar.update`` is called with the size of every batch before it is
        yielded.
        """
        total = len(items)
        for start in range(0, total, self.batch_size):
            stop = min(start + self.batch_size, total)
            batch = [items[i] for i in range(start, stop)]
            pbar.update(len(batch))
            yield batch

    def _encode_words(self, words, unknown=None):
        """Encode a batch of words as sparse (row, column) -> char-id components.

        Args:
            words: sequence of strings; row ``i`` of the result is ``words[i]``.
            unknown: if given, characters missing from ``self.chars`` are
                encoded as ``self.chars[unknown]``. Otherwise they raise
                ``KeyError``.

        Returns:
            ``(vals, indexes, seq_lens, max_len)``: one char id and one
            ``[row, col]`` pair per character, the length of each word, and
            the longest length (0 for an empty batch).
        """
        vals, indexes, seq_lens = [], [], []
        for row, word in enumerate(words):
            for col, char in enumerate(word):
                if unknown is not None and char not in self.chars:
                    char = unknown
                vals.append(self.chars[char])
                indexes.append([row, col])
            seq_lens.append(len(word))
        return vals, indexes, seq_lens, max(seq_lens, default=0)

    def _base_feed(self, words, unknown=None):
        """Build the input feed dict shared by every graph part for ``words``."""
        vals, indexes, seq_lens, max_len = self._encode_words(words, unknown)
        batch = len(words)
        return {
            self.rnn.batch_size: batch,
            self.rnn.x_seq_lens[0]: np.asarray(seq_lens),
            self.rnn.x_vals[0]: np.asarray(vals),
            self.rnn.x_inds[0]: np.asarray(indexes),
            self.rnn.x_shape[0]: np.asarray([batch, max_len]),
        }

    def _known_chars_only(self, items):
        """Keep only items whose ``'x_src'`` consists entirely of known characters."""
        return [item for item in items if all(c in self.chars for c in item['x_src'])]

    # ------------------------------------------------------------------
    # Scoring helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _ratio(numerator, denominator):
        """Return ``numerator / denominator``, or NaN when there is nothing to divide by.

        NaN (rather than an exception) keeps one empty dataset from aborting
        the whole report.
        """
        return numerator / denominator if denominator else float('nan')

    @classmethod
    def _score_classification(cls, results, etalon):
        """Score per-class scores against multi-hot targets.

        For a sample with ``k`` target classes, its ``k`` highest-scoring
        classes are compared with the targets.

        Returns:
            ``(full_acc, part_acc, misses)``. ``full_acc`` is the fraction of
            samples whose top-k is entirely correct. ``part_acc`` is the
            fraction of target slots filled correctly. ``misses`` lists
            ``(index, predicted_classes)`` for every sample that is not fully
            correct.
        """
        full_correct = 0
        part_correct = 0
        total_classes = 0
        misses = []
        for index, (probs, et) in enumerate(zip(results, etalon)):
            et = np.asarray(et)
            classes_count = int(et.sum())
            good_classes = np.argwhere(et == 1).ravel()
            order = np.argsort(probs)
            # Slice from an explicit start: `order[-0:]` would select every
            # class instead of none when the sample has no targets.
            rez_classes = order[max(len(order) - classes_count, 0):]
            hits = int(np.isin(rez_classes, good_classes).sum())
            part_correct += hits
            total_classes += classes_count
            if hits == len(rez_classes):
                full_correct += 1
            else:
                misses.append((index, rez_classes))
        return (cls._ratio(full_correct, len(etalon)),
                cls._ratio(part_correct, total_classes),
                misses)

    @classmethod
    def _score_words(cls, predicted, items):
        """Compare predicted strings with ``item['y_src']``.

        Returns:
            ``(accuracy, bad)`` where ``bad`` lists ``(item, prediction)``
            pairs that did not match.
        """
        bad = [(item, rez) for item, rez in zip(items, predicted) if rez != item['y_src']]
        return cls._ratio(len(items) - len(bad), len(items)), bad

    # ------------------------------------------------------------------
    # Graph runners
    # ------------------------------------------------------------------

    def __get_classification_items__(self, sess, items, graph_part):
        """Run ``graph_part.probs[0]`` over ``items`` in batches.

        Each item must provide ``'src'`` (the word) and ``'y'`` (its target).
        Characters outside the alphabet are encoded as ``'UNDEFINED'``.

        Returns:
            ``(results, etalon)``: one network output row per item and the
            matching targets, in input order.
        """
        results = []
        etalon = []
        with tqdm(total=len(items), desc='Getting classification info') as pbar:
            for batch in self._iter_batches(items, pbar):
                etalon.extend(item['y'] for item in batch)
                feed = self._base_feed([item['src'] for item in batch], unknown='UNDEFINED')
                results.extend(sess.run([graph_part.probs[0]], feed)[0])
        return results, etalon

    def __get_lemma_items__(self, sess, items):
        """Yield the decoded lemma predicted for each item, in input order.

        Each item must provide ``'x_src'`` (the word) and ``'main_cls'``.
        The latter is fed to ``self.rnn.lem_class_pl``.
        """
        with tqdm(total=len(items)) as pbar:
            for batch in self._iter_batches(items, pbar):
                feed = self._base_feed([item['x_src'] for item in batch])
                feed[self.rnn.lem_class_pl] = np.asarray([item['main_cls'] for item in batch])
                lemmas = sess.run([self.rnn.lem_result], feed)[0]
                for word_src in lemmas:
                    # NOTE(review): rows are indexed with [0] before decoding,
                    # unlike the inflect output -- output-shape assumption
                    # kept from the original; confirm against model.RNN.
                    yield decode_word(word_src[0])

    def __get_inflect_items__(self, sess, items):
        """Yield the decoded inflected form predicted for each item, in input order.

        Each item must provide ``'x_src'``, ``'x_cls'`` and ``'y_cls'``.
        """
        graph_part = self.rnn.inflect_graph_part
        with tqdm(total=len(items)) as pbar:
            for batch in self._iter_batches(items, pbar):
                feed = self._base_feed([item['x_src'] for item in batch])
                feed[graph_part.x_cls[0]] = np.asarray([item['x_cls'] for item in batch])
                feed[graph_part.y_cls[0]] = np.asarray([item['y_cls'] for item in batch])
                for word_src in sess.run([graph_part.results[0]], feed)[0]:
                    yield decode_word(word_src)

    # ------------------------------------------------------------------
    # Evaluations
    # ------------------------------------------------------------------

    def __test_classification__(self, sess, key, graph_part, *ds_types):
        """Evaluate one classifier on ``load_datasets(key, *ds_types)``.

        Returns:
            ``(full_acc, part_acc, bad_items)``. ``bad_items`` holds
            ``(dataset_item, predicted_classes)`` pairs for samples that were
            not fully correct. An accuracy is NaN when its denominator is zero.
        """
        et_items = load_datasets(key, *ds_types)
        results, etalon = self.__get_classification_items__(sess, et_items, graph_part)
        full_acc, part_acc, misses = self._score_classification(results, etalon)
        bad_items = [(et_items[index], rez_classes) for index, rez_classes in misses]
        return full_acc, part_acc, bad_items

    def __test_lemmas__(self, sess, *ds_types):
        """Evaluate lemmatization on the "lemma" datasets.

        Items containing characters outside the alphabet are skipped.

        Returns:
            ``(accuracy, bad_words)`` with ``(item, prediction)`` mismatches.
        """
        good_items = self._known_chars_only(load_datasets("lemma", *ds_types))
        results = list(self.__get_lemma_items__(sess, good_items))
        return self._score_words(results, good_items)

    def __test_inflect__(self, sess, *ds_types):
        """Evaluate inflection on the "inflect" datasets.

        Items containing characters outside the alphabet are skipped.

        Returns:
            ``(accuracy, bad_items)`` with ``(item, prediction)`` mismatches.
        """
        good_items = self._known_chars_only(load_datasets("inflect", *ds_types))
        rez_words = list(self.__get_inflect_items__(sess, good_items))
        return self._score_words(rez_words, good_items)
if __name__ == "__main__":
    # Script entry point: run the full evaluation. Per-metric result lines
    # are printed via tqdm.write inside Tester.test(); the returned report
    # string is not used here.
    Tester().test()