DeepMorphy / tester.py
niobures's picture
DeepMorphy
0240c6e verified
import os, pickle
import numpy as np
import tensorflow as tf
from tqdm import tqdm
from model import RNN
from utils import CONFIG, decode_word, load_datasets
class Tester:
    """Evaluate a trained ``RNN`` on the ``'test'`` datasets.

    ``test()`` reports full/partial classification accuracy for every gram
    category and for the ``'main'`` classifier, then lemmatization and
    inflection accuracy.
    """

    def __init__(self):
        # NOTE: CONFIG is mutated in place, so this flag stays set for any
        # other code that reads CONFIG afterwards.
        self.config = CONFIG
        self.config['graph_part_configs']['lemm']['use_cls_placeholder'] = True
        self.rnn = RNN(True)
        # char -> integer id, in the order given by config['chars'].
        self.chars = {c: index for index, c in enumerate(self.config['chars'])}
        self.batch_size = 65536
        # Not read anywhere in this class.
        self.show_bad_items = False

    def test(self):
        """Restore the model and evaluate every task on the 'test' split.

        Each metric line is printed with ``tqdm.write`` as soon as it is
        computed.

        Returns:
            str: all metric lines joined with newlines.
        """
        config = tf.ConfigProto(allow_soft_placement=True)
        results = []

        def report(line):
            # Record the line for the return value and print it exactly once.
            results.append(line)
            tqdm.write(line)

        with tf.Session(config=config, graph=self.rnn.graph) as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            # Restore runs after the initializers so its values are not overwritten.
            self.rnn.restore(sess)

            for gram in self.rnn.gram_keys:
                full_cls_acc, part_cls_acc, _ = self.__test_classification__(
                    sess, gram, self.rnn.gram_graph_parts[gram], 'test')
                report(f"{gram}. full_cls_acc: {full_cls_acc}; part_cls_acc: {part_cls_acc}")

            full_cls_acc, part_cls_acc, _ = self.__test_classification__(
                sess, 'main', self.rnn.main_graph_part, 'test')
            report(f"main. full_cls_acc: {full_cls_acc}; part_cls_acc: {part_cls_acc}")

            lemm_acc, _ = self.__test_lemmas__(sess, 'test')
            report(f"lemma_acc: {lemm_acc}")

            inflect_acc, _ = self.__test_inflect__(sess, 'test')
            report(f"inflect_acc: {inflect_acc}")

        return "\n".join(results)

    # ------------------------------------------------------------------
    # Batching / encoding helpers
    # ------------------------------------------------------------------

    def _iter_batches(self, items, desc=None):
        """Yield consecutive lists of at most ``self.batch_size`` items.

        A tqdm progress bar counts the items handed out. It is closed when
        iteration finishes or the generator is closed early.
        Only ``len(items)`` and integer indexing are required of ``items``.
        """
        total = len(items)
        with tqdm(total=total, desc=desc) as pbar:
            for start in range(0, total, self.batch_size):
                stop = min(start + self.batch_size, total)
                batch = [items[i] for i in range(start, stop)]
                pbar.update(len(batch))
                yield batch

    def _char_index(self, char):
        """Return the id of ``char``, or the id of ``'UNDEFINED'`` if unknown."""
        index = self.chars.get(char)
        # Look up 'UNDEFINED' lazily: only needed (and required) for unknown chars.
        return self.chars['UNDEFINED'] if index is None else index

    def _encode_words(self, words):
        """Encode a batch of words as parallel value / position lists.

        Returns:
            tuple: ``(xs, indexes, seq_lens, max_len)``. ``xs[k]`` is the char
            id located at ``indexes[k] == [row, col]``. ``seq_lens[row]`` is
            ``len(words[row])``. ``max_len`` is the longest word length (0 for
            an empty batch).
        """
        xs = []
        indexes = []
        seq_lens = []
        for row, word in enumerate(words):
            for col, char in enumerate(word):
                xs.append(self._char_index(char))
                indexes.append([row, col])
            seq_lens.append(len(word))
        return xs, indexes, seq_lens, max(seq_lens, default=0)

    def _base_feed(self, words):
        """Build the feed-dict entries shared by every evaluated graph part."""
        xs, indexes, seq_lens, max_len = self._encode_words(words)
        return {
            self.rnn.batch_size: len(words),
            self.rnn.x_seq_lens[0]: np.asarray(seq_lens),
            self.rnn.x_vals[0]: np.asarray(xs),
            self.rnn.x_inds[0]: np.asarray(indexes),
            self.rnn.x_shape[0]: np.asarray([len(words), max_len]),
        }

    def _filter_known_chars(self, items):
        """Keep only items whose ``'x_src'`` consists solely of known chars."""
        # ``self.chars`` has the same members as config['chars'], with O(1) lookup.
        return [item for item in items if all(c in self.chars for c in item['x_src'])]

    # ------------------------------------------------------------------
    # Scoring helpers (pure; no session required)
    # ------------------------------------------------------------------

    @staticmethod
    def _ratio(numerator, denominator):
        """Return ``numerator / denominator``, or NaN when the denominator is 0."""
        return numerator / denominator if denominator else float('nan')

    @staticmethod
    def _score_classification(results, etalon):
        """Score top-k predictions against per-item target vectors.

        For each item, k is ``etalon[i].sum()``. The k highest entries of
        ``results[i]`` are checked against the positions where
        ``etalon[i] == 1``.

        Returns:
            tuple: ``(full_acc, part_acc, bad)``.
            - ``full_acc`` is the fraction of items whose top-k classes are all correct.
            - ``part_acc`` is correct predicted classes divided by total target classes.
            - ``bad`` lists ``(index, predicted_classes)`` for each item that is not fully correct.
            Each ratio is NaN when its denominator is 0.
        """
        total_classes = 0
        full_correct = 0
        part_correct = 0
        bad = []
        for index, (probs, et) in enumerate(zip(results, etalon)):
            et = np.asarray(et)
            classes_count = int(et.sum())
            good_classes = set(np.flatnonzero(et == 1).tolist())
            order = np.argsort(probs)
            # Guard k == 0: a ``[-0:]`` slice would select every class.
            rez_classes = order[-classes_count:] if classes_count > 0 else order[:0]
            hits = sum(1 for cls in rez_classes.tolist() if cls in good_classes)
            total_classes += classes_count
            part_correct += hits
            if hits == len(rez_classes):
                full_correct += 1
            else:
                bad.append((index, rez_classes))
        full_acc = Tester._ratio(full_correct, len(etalon))
        part_acc = Tester._ratio(part_correct, total_classes)
        return full_acc, part_acc, bad

    @staticmethod
    def _score_words(items, predicted):
        """Compare predicted words with each item's ``'y_src'``.

        Returns:
            tuple: ``(accuracy, bad)``. ``bad`` lists ``(item, prediction)``
            for every mismatch. Accuracy is NaN when ``items`` is empty.
        """
        bad = [(item, rez) for item, rez in zip(items, predicted) if rez != item['y_src']]
        return Tester._ratio(len(items) - len(bad), len(items)), bad

    # ------------------------------------------------------------------
    # Model runners
    # ------------------------------------------------------------------

    def __get_classification_items__(self, sess, items, graph_part):
        """Run ``graph_part.probs[0]`` over ``items`` (keys ``'src'``, ``'y'``).

        Returns:
            tuple: ``(results, etalon)``. ``results`` holds one model output row
            per item; ``etalon`` holds the items' ``'y'`` values. Both are in input order.
        """
        results = []
        etalon = []
        for batch in self._iter_batches(items, desc='Getting classification info'):
            etalon.extend(item['y'] for item in batch)
            feed = self._base_feed([item['src'] for item in batch])
            nn_results = sess.run([graph_part.probs[0]], feed)
            results.extend(nn_results[0])
        return results, etalon

    def __get_lemma_items__(self, sess, items):
        """Yield the decoded lemma predicted for each item, in input order.

        Items need ``'x_src'`` and ``'main_cls'``. Characters missing from
        ``self.chars`` map to the ``'UNDEFINED'`` id. Callers in this class
        pre-filter them out.
        """
        for batch in self._iter_batches(items):
            feed = self._base_feed([item['x_src'] for item in batch])
            feed[self.rnn.lem_class_pl] = np.asarray([item['main_cls'] for item in batch])
            results = sess.run([self.rnn.lem_result], feed)
            for word_src in results[0]:
                yield decode_word(word_src[0])

    def __get_inflect_items__(self, sess, items):
        """Yield the decoded inflected form predicted for each item, in order.

        Items need ``'x_src'``, ``'x_cls'`` and ``'y_cls'``. Characters
        missing from ``self.chars`` map to the ``'UNDEFINED'`` id. Callers in
        this class pre-filter them out.
        """
        graph_part = self.rnn.inflect_graph_part
        for batch in self._iter_batches(items):
            feed = self._base_feed([item['x_src'] for item in batch])
            feed[graph_part.x_cls[0]] = np.asarray([item['x_cls'] for item in batch])
            feed[graph_part.y_cls[0]] = np.asarray([item['y_cls'] for item in batch])
            results = sess.run([graph_part.results[0]], feed)
            for word_src in results[0]:
                yield decode_word(word_src)

    # ------------------------------------------------------------------
    # Per-task evaluations
    # ------------------------------------------------------------------

    def __test_classification__(self, sess, key, graph_part, *ds_types):
        """Evaluate a classifier graph part on ``load_datasets(key, *ds_types)``.

        Returns:
            tuple: ``(full_acc, part_acc, bad_items)``. ``bad_items`` holds
            ``(dataset_item, predicted_classes)`` pairs.
        """
        et_items = load_datasets(key, *ds_types)
        results, etalon = self.__get_classification_items__(sess, et_items, graph_part)
        full_acc, cls_correct, bad = self._score_classification(results, etalon)
        bad_items = [(et_items[index], rez_classes) for index, rez_classes in bad]
        return full_acc, cls_correct, bad_items

    def __test_lemmas__(self, sess, *ds_types):
        """Return ``(accuracy, bad_words)`` for lemmatization.

        Only items whose source word contains known characters are scored.
        """
        good_items = self._filter_known_chars(load_datasets("lemma", *ds_types))
        results = list(self.__get_lemma_items__(sess, good_items))
        return self._score_words(good_items, results)

    def __test_inflect__(self, sess, *ds_types):
        """Return ``(accuracy, bad_items)`` for inflection.

        Only items whose source word contains known characters are scored.
        """
        good_items = self._filter_known_chars(load_datasets("inflect", *ds_types))
        rez_words = list(self.__get_inflect_items__(sess, good_items))
        return self._score_words(good_items, rez_words)
if __name__ == "__main__":
    # Run the full evaluation; Tester.test() prints each metric line itself.
    Tester().test()