PeteBleackley committed on
Commit
dd9c3ed
·
1 Parent(s): a8c528d

Completed script for testing consistency

Browse files
Files changed (1) hide show
  1. scripts.py +69 -10
scripts.py CHANGED
@@ -22,13 +22,11 @@ import nltk.corpus
22
  import difflib
23
  import scipy.stats
24
  import scipy.spatial
 
 
25
 
26
 
27
 
28
- def decoder_loss(y_true,y_pred):
29
- return keras.losses.sparse_categorical_crossentropy(y_true,
30
- y_pred,
31
- from_logits=True)
32
 
33
  def capitalise(token,i):
34
  return token.text_with_ws.title() if i==0 or token.tag_.startswith('NNP') else token.text_with_ws.lower()
@@ -123,9 +121,9 @@ def train_models(path):
123
  trainer = qarac.models.QaracTrainerModel.QaracTrainerModel(encoder_base,
124
  decoder_base,
125
  tokenizer)
126
- losses={'encode_decode':decoder_loss,
127
  'question_answering':keras.losses.mean_squared_error,
128
- 'reasoning':decoder_loss,
129
  'consistency':keras.losses.mean_squared_error}
130
  optimizer = keras.optimizers.Nadam(learning_rate=keras.optimizers.schedules.ExponentialDecay(1.0e-5, 100, 0.99))
131
  trainer.compile(optimizer=optimizer,
@@ -395,11 +393,70 @@ def test_consistency(path):
395
  maxlen = max((len(sentence for sentence in s0)))
396
  for sentence in s0:
397
  sentence.pad(maxlen,pad_id=pad_token)
398
- s0_in = tensorflow
399
-
400
-
401
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
402
 
 
403
  if __name__ == '__main__':
404
  parser = argparse.ArgumentParser(prog='QARAC',
405
  description='Experimental NLP system, aimed at improving factual accuracy')
@@ -422,4 +479,6 @@ if __name__ == '__main__':
422
  test_question_answering(args.filename)
423
  elif args.task=="test_reasoning":
424
  test_reasoning(args.filename)
 
 
425
 
 
22
  import difflib
23
  import scipy.stats
24
  import scipy.spatial
25
+ import seaborn
26
+
27
 
28
 
29
 
 
 
 
 
30
 
31
  def capitalise(token,i):
32
  return token.text_with_ws.title() if i==0 or token.tag_.startswith('NNP') else token.text_with_ws.lower()
 
121
  trainer = qarac.models.QaracTrainerModel.QaracTrainerModel(encoder_base,
122
  decoder_base,
123
  tokenizer)
124
+ losses={'encode_decode':keras.losses.SparseCategoricalCrossentropy(from_logits=True),
125
  'question_answering':keras.losses.mean_squared_error,
126
+ 'reasoning':keras.losses.SparseCategoricalCrossentropy(from_logits=True),
127
  'consistency':keras.losses.mean_squared_error}
128
  optimizer = keras.optimizers.Nadam(learning_rate=keras.optimizers.schedules.ExponentialDecay(1.0e-5, 100, 0.99))
129
  trainer.compile(optimizer=optimizer,
 
393
  maxlen = max((len(sentence for sentence in s0)))
394
  for sentence in s0:
395
  sentence.pad(maxlen,pad_id=pad_token)
396
+ s0_in = tensorflow.constant([sentence.ids for sentence in s0])
397
+ s0_attn = tensorflow.constant(numpy.not_equal(s0_in.numpy(),
398
+ pad_token).astype(int))
399
+ maxlen = max((len(sentence for sentence in s1)))
400
+ for sentence in s1:
401
+ sentence.pad(maxlen,pad_id=pad_token)
402
+ s1_in = tensorflow.constant([sentence.ids for sentence in s1])
403
+ s1_attn = tensorflow.constant(numpy.not_equal(s1_in.numpy(),
404
+ pad_token).astype(int))
405
+ s0_vec = tensorflow.l2_norm(encoder(s0_in,attention_mask=s0_attn),
406
+ axis=1)
407
+ s1_vec = tensorflow.l2_norm(encoder(s1_in,attention_mask=s1_attn),
408
+ axis=1)
409
+ @tensorflow.function
410
+ def dotprod(vecs):
411
+ (x,y)=vecs
412
+ return tensorflow.tensordot(x,y,axes=1)
413
+ consistency = tensorflow.vectorized_map(dotprod, (s0_vec,s1_vec)).numpy()
414
+ results = pandas.DataFrame({'label':data['gold_label'],
415
+ 'score':consistency})
416
+ third = 1.0/3.0
417
+ def predicted_labels(x):
418
+ return 'entailment' if x>third else 'contradiction' if x<-third else 'neutral'
419
+ results['prediction'] = results['score'].apply(predicted_labels)
420
+ confusion=results.groupby('label')['prediction'].value_counts().fillna(0)
421
+ seaborn.heatmap(confusion).save('consistency_confusion_matrix.svg')
422
+ correct = pandas.Series({label:confusion[label,label]
423
+ for label in confusion.index})
424
+ print("Accuracy: {}".format(correct.sum()/data.shape[0]))
425
+ print("Precision")
426
+ print(correct/confusion.sum(axis='columns'))
427
+ print("Recall")
428
+ print(correct/confusion.sum(axis='rows'))
429
+ def stats(group):
430
+ (alpha,beta,loc,scale) = scipy.stats.beta.fit(group)
431
+ mean = group.mean()
432
+ sd = group.std()
433
+ return pandas.Series({'mean':mean,
434
+ 'sd':sd,
435
+ 'min':loc,
436
+ 'max':loc+scale,
437
+ 'alpha':alpha,
438
+ 'beta':beta})
439
+ print(results.groupby('label')['score'].apply(stats))
440
+ quartiles = numpy.quantile(consistency,[0.0,0.25,0.5,0.75,1.0])
441
+ IQR = quartiles[3]-quartiles[1]
442
+ bin_width = 2.0*IQR/(data.shape[0]**1.5)
443
+ n_bins = int((quartiles[4]-quartiles[0])/bin_width)
444
+ bins = numpy.linspace(quartiles[0],quartiles[4],n_bins)
445
+ def hist(col):
446
+ (result,_) = numpy.histogram(col,bins)
447
+ return result
448
+ histograms = results.groupby('label')['score'].apply(hist)
449
+ histograms.coluumns = (bins[1:]+bins[:-1])/2
450
+ with pandas.option_context('plotting.backend','matploblib.backends.backend_svg') as options:
451
+ axes=histograms.T.plot.bar(stacked=True)
452
+ axes.get_figure().savefig('consistency_histograms.svg')
453
+ percent = numpy.linspace(0.0,1.0,101)
454
+ percentiles = results.groupby('label')['score'].apply(lambda x: numpy.percentile(x,percent))
455
+ with pandas.option_context('plotting.backend','matploblib.backends.backend_svg') as options:
456
+ axes=percentiles.T.plot.line()
457
+ axes.get_figure().savefig('consistency_percentiles.svg')
458
 
459
+
460
  if __name__ == '__main__':
461
  parser = argparse.ArgumentParser(prog='QARAC',
462
  description='Experimental NLP system, aimed at improving factual accuracy')
 
479
  test_question_answering(args.filename)
480
  elif args.task=="test_reasoning":
481
  test_reasoning(args.filename)
482
+ elif args.task=='test_consistency':
483
+ test_consistency(args.filename)
484