# DeepMorphy / tester.py
# Source page metadata (not code): uploader niobures; repository DeepMorphy;
# revision 0240c6e (verified); file size 8.94 kB.
import os, pickle
import numpy as np
import tensorflow as tf
from tqdm import tqdm
from model import RNN
from utils import CONFIG, decode_word, load_datasets
class Tester:
    """Evaluate a restored ``RNN`` model on the 'test' datasets.

    ``test()`` runs, in order, one classification check per key in
    ``rnn.gram_keys``, the 'main' classification check, the lemma check and
    the inflect check, and returns the accuracy lines as one string.
    """

    def __init__(self):
        # CONFIG is the shared module-level object; the flag below is written
        # into it *before* RNN is constructed.
        # NOTE(review): presumably RNN reads this flag while building its
        # graph -- confirm against model.py.
        self.config = CONFIG
        self.config['graph_part_configs']['lemm']['use_cls_placeholder'] = True
        self.rnn = RNN(True)
        # Character -> integer id, in the order given by config['chars'].
        self.chars = {c: index for index, c in enumerate(self.config['chars'])}
        self.batch_size = 65536
        # Not read anywhere in this class; kept for callers that may set it.
        self.show_bad_items = False

    def test(self):
        """Run all evaluations and return the report lines joined by newlines.

        Each line is also echoed once through ``tqdm.write`` as soon as it is
        produced.
        """
        config = tf.ConfigProto(allow_soft_placement=True)
        results = []

        def report(line):
            # Record the line for the return value and echo it exactly once.
            results.append(line)
            tqdm.write(line)

        with tf.Session(config=config, graph=self.rnn.graph) as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            self.rnn.restore(sess)

            for gram in self.rnn.gram_keys:
                full_cls_acc, part_cls_acc, _ = self.__test_classification__(
                    sess, gram, self.rnn.gram_graph_parts[gram], 'test')
                report(f"{gram}. full_cls_acc: {full_cls_acc}; part_cls_acc: {part_cls_acc}")

            full_cls_acc, part_cls_acc, _ = self.__test_classification__(
                sess, 'main', self.rnn.main_graph_part, 'test')
            report(f"main. full_cls_acc: {full_cls_acc}; part_cls_acc: {part_cls_acc}")

            lemm_acc, _ = self.__test_lemmas__(sess, 'test')
            report(f"lemma_acc: {lemm_acc}")

            inflect_acc, _ = self.__test_inflect__(sess, 'test')
            report(f"inflect_acc: {inflect_acc}")

        return "\n".join(results)

    # ------------------------------------------------------------------
    # Private helpers shared by the three evaluation paths.
    # ------------------------------------------------------------------

    def _iter_batches(self, items, desc=None):
        """Yield consecutive lists of at most ``self.batch_size`` items.

        A tqdm progress bar sized to ``len(items)`` is advanced per batch and
        is closed when the generator finishes or is discarded.
        """
        total = len(items)
        with tqdm(total=total, desc=desc) as pbar:
            for start in range(0, total, self.batch_size):
                stop = min(start + self.batch_size, total)
                # Index one by one: only len() and integer indexing are
                # required of ``items``.
                batch = [items[i] for i in range(start, stop)]
                pbar.update(len(batch))
                yield batch

    def _encode_words(self, words, allow_unknown=False):
        """Encode words as flat character ids plus their (row, column) indexes.

        Args:
            words: sequence of strings forming one batch.
            allow_unknown: when true, characters missing from ``self.chars``
                are mapped to the id of the ``'UNDEFINED'`` entry; otherwise
                a missing character raises ``KeyError``.

        Returns:
            ``(xs, indexes, seq_lens, max_len)``: character ids, the matching
            ``[word_index, char_index]`` pairs, the length of every word, and
            the longest length (0 for an empty batch).
        """
        xs = []
        indexes = []
        seq_lens = []
        for row, word in enumerate(words):
            for col, char in enumerate(word):
                if allow_unknown and char not in self.chars:
                    char = 'UNDEFINED'
                xs.append(self.chars[char])
                indexes.append([row, col])
            seq_lens.append(len(word))
        return xs, indexes, seq_lens, max(seq_lens, default=0)

    def _base_feed(self, words, allow_unknown=False):
        """Build the feed dict entries common to every graph part."""
        xs, indexes, seq_lens, max_len = self._encode_words(words, allow_unknown)
        batch_len = len(seq_lens)
        return {
            self.rnn.batch_size: batch_len,
            self.rnn.x_seq_lens[0]: np.asarray(seq_lens),
            self.rnn.x_vals[0]: np.asarray(xs),
            self.rnn.x_inds[0]: np.asarray(indexes),
            self.rnn.x_shape[0]: np.asarray([batch_len, max_len]),
        }

    def _known_chars_only(self, items):
        """Keep the items whose ``'x_src'`` uses only characters in ``self.chars``."""
        return [
            item
            for item in items
            if all(c in self.chars for c in item['x_src'])
        ]

    @staticmethod
    def _collect_mismatches(items, predictions):
        """Return ``(item, prediction)`` pairs where prediction != item['y_src']."""
        return [
            (item, rez)
            for item, rez in zip(items, predictions)
            if rez != item['y_src']
        ]

    @staticmethod
    def _ratio(numerator, denominator):
        """Return ``numerator / denominator``, or 0.0 when the denominator is 0.

        An empty dataset therefore reports accuracy 0.0 instead of raising
        ``ZeroDivisionError``.
        """
        if not denominator:
            return 0.0
        return numerator / denominator

    @staticmethod
    def _score_classification(results, etalon):
        """Compare the top-scoring classes with the expected class vectors.

        For every expected vector ``et`` with ``k = et.sum()`` entries set,
        the ``k`` highest-scoring classes of the matching result are taken as
        the prediction. An item is fully correct when every predicted class
        is one of the positions where ``et == 1``.
        NOTE(review): assumes each ``et`` and each result is a 1-D numpy
        array -- confirm against load_datasets / the graph output.

        Returns:
            ``(full_correct, part_correct, total_classes, misses)`` where
            ``misses`` is a list of ``(index, predicted_classes)`` for the
            items that were not fully correct.
        """
        full_correct = 0
        part_correct = 0
        total_classes = 0
        misses = []
        for index, et in enumerate(etalon):
            classes_count = int(et.sum())
            good_classes = set(np.argwhere(et == 1).ravel().tolist())
            ranked = np.argsort(results[index])
            # ranked[-0:] would select *every* class, so the zero case must
            # be handled explicitly to get an empty prediction.
            rez_classes = ranked[-classes_count:] if classes_count > 0 else ranked[:0]
            total_classes += classes_count
            hits = sum(1 for cls in rez_classes.tolist() if cls in good_classes)
            part_correct += hits
            if hits == len(rez_classes):
                full_correct += 1
            else:
                misses.append((index, rez_classes))
        return full_correct, part_correct, total_classes, misses

    # ------------------------------------------------------------------
    # Network runners.
    # ------------------------------------------------------------------

    def __get_classification_items__(self, sess, items, graph_part):
        """Run ``graph_part.probs[0]`` over ``items`` in batches.

        Each item supplies its word under ``'src'`` and its expected value
        under ``'y'``. Characters missing from ``self.chars`` are fed as
        ``'UNDEFINED'``.

        Returns:
            ``(results, etalon)``: the per-item network outputs and the
            per-item ``'y'`` values, in input order.
        """
        results = []
        etalon = []
        for batch in self._iter_batches(items, desc='Getting classification info'):
            etalon.extend(item['y'] for item in batch)
            feed = self._base_feed([item['src'] for item in batch], allow_unknown=True)
            nn_results = sess.run([graph_part.probs[0]], feed)
            results.extend(nn_results[0])
        return results, etalon

    def __get_lemma_items__(self, sess, items):
        """Yield the decoded ``rnn.lem_result`` output for every item, in order.

        Each item supplies its word under ``'x_src'`` and its class under
        ``'main_cls'``. Every character must be present in ``self.chars``.
        """
        for batch in self._iter_batches(items):
            feed = self._base_feed([item['x_src'] for item in batch])
            feed[self.rnn.lem_class_pl] = np.asarray([item['main_cls'] for item in batch])
            results = sess.run([self.rnn.lem_result], feed)
            for word_src in results[0]:
                # Only the first entry of each per-word result is decoded.
                yield decode_word(word_src[0])

    def __get_inflect_items__(self, sess, items):
        """Yield the decoded inflect-graph output for every item, in order.

        Each item supplies its word under ``'x_src'`` and the classes under
        ``'x_cls'`` and ``'y_cls'``. Every character must be present in
        ``self.chars``.
        """
        inflect = self.rnn.inflect_graph_part
        for batch in self._iter_batches(items):
            feed = self._base_feed([item['x_src'] for item in batch])
            feed[inflect.x_cls[0]] = np.asarray([item['x_cls'] for item in batch])
            feed[inflect.y_cls[0]] = np.asarray([item['y_cls'] for item in batch])
            results = sess.run([inflect.results[0]], feed)
            for word_src in results[0]:
                yield decode_word(word_src)

    # ------------------------------------------------------------------
    # Evaluations.
    # ------------------------------------------------------------------

    def __test_classification__(self, sess, key, graph_part, *ds_types):
        """Evaluate one classification graph part.

        Args:
            sess: session used to run the graph.
            key: dataset key passed to ``load_datasets``.
            graph_part: object whose ``probs[0]`` is evaluated.
            *ds_types: dataset types forwarded to ``load_datasets``.

        Returns:
            ``(full_acc, cls_correct, bad_items)``: the share of fully correct
            items, the share of correctly predicted classes, and a list of
            ``(item, predicted_classes)`` for the items that were not fully
            correct. Both shares are 0.0 when there is nothing to count.
        """
        et_items = load_datasets(key, *ds_types)
        results, etalon = self.__get_classification_items__(sess, et_items, graph_part)
        full_correct, part_correct, total_classes, misses = \
            self._score_classification(results, etalon)
        bad_items = [(et_items[index], rez_classes) for index, rez_classes in misses]
        full_acc = self._ratio(full_correct, len(etalon))
        cls_correct = self._ratio(part_correct, total_classes)
        return full_acc, cls_correct, bad_items

    def __test_lemmas__(self, sess, *ds_types):
        """Evaluate the lemma output on the "lemma" datasets.

        Items containing characters outside ``self.chars`` are skipped.

        Returns:
            ``(acc, bad_words)``: the share of items whose decoded output
            equals ``'y_src'`` (0.0 when no items remain) and the list of
            ``(item, output)`` mismatches.
        """
        good_items = self._known_chars_only(load_datasets("lemma", *ds_types))
        bad_words = self._collect_mismatches(
            good_items, self.__get_lemma_items__(sess, good_items))
        total = len(good_items)
        acc = self._ratio(total - len(bad_words), total)
        return acc, bad_words

    def __test_inflect__(self, sess, *ds_types):
        """Evaluate the inflect output on the "inflect" datasets.

        Items containing characters outside ``self.chars`` are skipped.

        Returns:
            ``(acc, bad_items)``: the share of items whose decoded output
            equals ``'y_src'`` (0.0 when no items remain) and the list of
            ``(item, output)`` mismatches.
        """
        good_items = self._known_chars_only(load_datasets("inflect", *ds_types))
        bad_items = self._collect_mismatches(
            good_items, self.__get_inflect_items__(sess, good_items))
        total = len(good_items)
        acc = self._ratio(total - len(bad_items), total)
        return acc, bad_items
if __name__ == "__main__":
    # Script entry point: build the tester and run the full evaluation.
    Tester().test()