Spaces:
Build error
Build error
PeteBleackley
committed on
Commit
·
dd9c3ed
1
Parent(s):
a8c528d
Completed script for testing consistency
Browse files- scripts.py +69 -10
scripts.py
CHANGED
|
@@ -22,13 +22,11 @@ import nltk.corpus
|
|
| 22 |
import difflib
|
| 23 |
import scipy.stats
|
| 24 |
import scipy.spatial
|
|
|
|
|
|
|
| 25 |
|
| 26 |
|
| 27 |
|
| 28 |
-
def decoder_loss(y_true,y_pred):
    """Sparse categorical cross-entropy computed on raw (unnormalised) logits.

    y_true: integer token ids; y_pred: unscaled decoder outputs.
    """
    return keras.losses.sparse_categorical_crossentropy(
        y_true,
        y_pred,
        from_logits=True,
    )
|
| 32 |
|
| 33 |
def capitalise(token,i):
    """Return the token's text (with trailing whitespace) title-cased when it
    starts the sentence (i == 0) or is tagged as a proper noun (tag_ begins
    with 'NNP'), otherwise lower-cased."""
    is_proper_noun = token.tag_.startswith('NNP')
    text = token.text_with_ws
    return text.title() if i == 0 or is_proper_noun else text.lower()
|
|
@@ -123,9 +121,9 @@ def train_models(path):
|
|
| 123 |
trainer = qarac.models.QaracTrainerModel.QaracTrainerModel(encoder_base,
|
| 124 |
decoder_base,
|
| 125 |
tokenizer)
|
| 126 |
-
losses={'encode_decode':
|
| 127 |
'question_answering':keras.losses.mean_squared_error,
|
| 128 |
-
'reasoning':
|
| 129 |
'consistency':keras.losses.mean_squared_error}
|
| 130 |
optimizer = keras.optimizers.Nadam(learning_rate=keras.optimizers.schedules.ExponentialDecay(1.0e-5, 100, 0.99))
|
| 131 |
trainer.compile(optimizer=optimizer,
|
|
@@ -395,11 +393,70 @@ def test_consistency(path):
|
|
| 395 |
maxlen = max((len(sentence for sentence in s0)))
|
| 396 |
for sentence in s0:
|
| 397 |
sentence.pad(maxlen,pad_id=pad_token)
|
| 398 |
-
s0_in = tensorflow
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 402 |
|
|
|
|
| 403 |
if __name__ == '__main__':
|
| 404 |
parser = argparse.ArgumentParser(prog='QARAC',
|
| 405 |
description='Experimental NLP system, aimed at improving factual accuracy')
|
|
@@ -422,4 +479,6 @@ if __name__ == '__main__':
|
|
| 422 |
test_question_answering(args.filename)
|
| 423 |
elif args.task=="test_reasoning":
|
| 424 |
test_reasoning(args.filename)
|
|
|
|
|
|
|
| 425 |
|
|
|
|
| 22 |
import difflib
|
| 23 |
import scipy.stats
|
| 24 |
import scipy.spatial
|
| 25 |
+
import seaborn
|
| 26 |
+
|
| 27 |
|
| 28 |
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
def capitalise(token,i):
    """Normalise capitalisation: title-case sentence-initial tokens and proper
    nouns ('NNP*' tags), lower-case everything else; whitespace is kept."""
    if i == 0 or token.tag_.startswith('NNP'):
        return token.text_with_ws.title()
    return token.text_with_ws.lower()
|
|
|
|
| 121 |
trainer = qarac.models.QaracTrainerModel.QaracTrainerModel(encoder_base,
|
| 122 |
decoder_base,
|
| 123 |
tokenizer)
|
| 124 |
+
losses={'encode_decode':keras.losses.SparseCategoricalCrossentropy(from_logits=True),
|
| 125 |
'question_answering':keras.losses.mean_squared_error,
|
| 126 |
+
'reasoning':keras.losses.SparseCategoricalCrossentropy(from_logits=True),
|
| 127 |
'consistency':keras.losses.mean_squared_error}
|
| 128 |
optimizer = keras.optimizers.Nadam(learning_rate=keras.optimizers.schedules.ExponentialDecay(1.0e-5, 100, 0.99))
|
| 129 |
trainer.compile(optimizer=optimizer,
|
|
|
|
| 393 |
# Pad every tokenised sentence in s0 to a common length so they can be
# stacked into one tensor.
# Fix: the original applied len() to a generator object (TypeError);
# take the max of the per-sentence lengths instead.
maxlen = max(len(sentence) for sentence in s0)
for sentence in s0:
    sentence.pad(maxlen,pad_id=pad_token)
|
| 396 |
+
s0_in = tensorflow.constant([sentence.ids for sentence in s0])
|
| 397 |
+
s0_attn = tensorflow.constant(numpy.not_equal(s0_in.numpy(),
|
| 398 |
+
pad_token).astype(int))
|
| 399 |
+
# Pad every tokenised sentence in s1 to a common length (mirrors the s0 pass).
# Fix: len() was called on a generator (TypeError); compute the max of the
# individual sentence lengths.
maxlen = max(len(sentence) for sentence in s1)
for sentence in s1:
    sentence.pad(maxlen,pad_id=pad_token)
|
| 402 |
+
s1_in = tensorflow.constant([sentence.ids for sentence in s1])
|
| 403 |
+
s1_attn = tensorflow.constant(numpy.not_equal(s1_in.numpy(),
|
| 404 |
+
pad_token).astype(int))
|
| 405 |
+
# Encode both sentence batches and unit-normalise the vectors so the pairwise
# dot products computed below are cosine similarities in [-1, 1].
# Fix: tensorflow has no 'l2_norm'; the normalisation op is
# tensorflow.math.l2_normalize.
s0_vec = tensorflow.math.l2_normalize(encoder(s0_in,attention_mask=s0_attn),
                                      axis=1)
s1_vec = tensorflow.math.l2_normalize(encoder(s1_in,attention_mask=s1_attn),
                                      axis=1)
|
| 409 |
+
@tensorflow.function
def dotprod(vecs):
    """Dot product of a (vector, vector) pair; shaped for vectorized_map."""
    left, right = vecs
    return tensorflow.tensordot(left, right, axes=1)
|
| 413 |
+
consistency = tensorflow.vectorized_map(dotprod, (s0_vec,s1_vec)).numpy()
|
| 414 |
+
results = pandas.DataFrame({'label':data['gold_label'],
|
| 415 |
+
'score':consistency})
|
| 416 |
+
third = 1.0/3.0
|
| 417 |
+
def predicted_labels(x):
    """Map a consistency score to an NLI label using +/-`third` thresholds."""
    if x > third:
        return 'entailment'
    if x < -third:
        return 'contradiction'
    return 'neutral'
|
| 419 |
+
results['prediction'] = results['score'].apply(predicted_labels)
|
| 420 |
+
# Confusion counts: one row per gold label, counts of each predicted label.
confusion=results.groupby('label')['prediction'].value_counts().fillna(0)
# Fix: seaborn.heatmap returns a matplotlib Axes, which has no .save();
# save through the owning Figure instead.
seaborn.heatmap(confusion).get_figure().savefig('consistency_confusion_matrix.svg')
|
| 422 |
+
correct = pandas.Series({label:confusion[label,label]
|
| 423 |
+
for label in confusion.index})
|
| 424 |
+
print("Accuracy: {}".format(correct.sum()/data.shape[0]))
|
| 425 |
+
print("Precision")
|
| 426 |
+
print(correct/confusion.sum(axis='columns'))
|
| 427 |
+
print("Recall")
|
| 428 |
+
print(correct/confusion.sum(axis='rows'))
|
| 429 |
+
def stats(group):
    """Summary statistics for one label's scores.

    Fits a beta distribution to the scores and returns a pandas.Series with
    the sample mean and sd plus the fitted support (loc -> 'min',
    loc+scale -> 'max') and shape parameters ('alpha', 'beta').
    """
    alpha, beta, loc, scale = scipy.stats.beta.fit(group)
    summary = {'mean': group.mean(),
               'sd': group.std(),
               'min': loc,
               'max': loc + scale,
               'alpha': alpha,
               'beta': beta}
    return pandas.Series(summary)
|
| 439 |
+
print(results.groupby('label')['score'].apply(stats))
|
| 440 |
+
# Freedman-Diaconis binning for the score histograms.
# Fix: the rule is bin_width = 2*IQR / n**(1/3); the original divided by
# n**1.5, which collapses the bin width and makes n_bins explode with n.
quartiles = numpy.quantile(consistency,[0.0,0.25,0.5,0.75,1.0])
IQR = quartiles[3]-quartiles[1]
bin_width = 2.0*IQR/(data.shape[0]**(1.0/3.0))
n_bins = int((quartiles[4]-quartiles[0])/bin_width)
bins = numpy.linspace(quartiles[0],quartiles[4],n_bins)
|
| 445 |
+
def hist(col):
    """Histogram counts of *col* over the shared `bins` edges (counts only)."""
    counts, _ = numpy.histogram(col, bins)
    return counts
|
| 448 |
+
# Per-label score histograms, plotted as a stacked bar chart.
histograms = results.groupby('label')['score'].apply(hist)
# Fix typo: 'coluumns' assigned a dead attribute, so the bin centres were
# never attached to the table.
histograms.columns = (bins[1:]+bins[:-1])/2
# Fix typo in the backend string ('matploblib' -> 'matplotlib').
# NOTE(review): pandas' 'plotting.backend' option normally expects a plotting
# backend name such as 'matplotlib', not a figure-canvas module - confirm.
with pandas.option_context('plotting.backend','matplotlib.backends.backend_svg') as options:
    axes=histograms.T.plot.bar(stacked=True)
    axes.get_figure().savefig('consistency_histograms.svg')
|
| 453 |
+
# Percentile curves of the consistency score for each gold label.
# Fix: numpy.percentile takes q in [0, 100]; linspace(0.0, 1.0, 101) sampled
# only the bottom one-hundredth of each distribution.
percent = numpy.linspace(0.0,100.0,101)
percentiles = results.groupby('label')['score'].apply(lambda x: numpy.percentile(x,percent))
# Fix typo in the backend string ('matploblib' -> 'matplotlib');
# NOTE(review): confirm pandas accepts this value for 'plotting.backend'.
with pandas.option_context('plotting.backend','matplotlib.backends.backend_svg') as options:
    axes=percentiles.T.plot.line()
    axes.get_figure().savefig('consistency_percentiles.svg')
|
| 458 |
|
| 459 |
+
|
| 460 |
if __name__ == '__main__':
|
| 461 |
parser = argparse.ArgumentParser(prog='QARAC',
|
| 462 |
description='Experimental NLP system, aimed at improving factual accuracy')
|
|
|
|
| 479 |
test_question_answering(args.filename)
|
| 480 |
elif args.task=="test_reasoning":
|
| 481 |
test_reasoning(args.filename)
|
| 482 |
+
elif args.task=='test_consistency':
|
| 483 |
+
test_consistency(args.filename)
|
| 484 |
|