Spaces:
Build error
Build error
PeteBleackley
committed on
Commit
·
63b2c6a
1
Parent(s):
4f504fa
Training script
Browse files
- qarac/models/QaracDecoderModel.py +1 -1
- scripts.py +48 -2
qarac/models/QaracDecoderModel.py
CHANGED
|
@@ -98,7 +98,7 @@ class QaracDecoderModel(transformers.TFPretrainedModel,transformers.TFGeneration
|
|
| 98 |
Parameters
|
| 99 |
----------
|
| 100 |
inputs : tuple of Tensorflow.Tensors OR tensorflow.Tensor
|
| 101 |
-
Vector to be converted to text and seed text
|
| 102 |
kwargs : optional keyword arguments
|
| 103 |
vector : tensorflow.Tensor vector to be decoded. May be supplied
|
| 104 |
via a keyword argument when this is invoked by .generate
|
|
|
|
| 98 |
Parameters
|
| 99 |
----------
|
| 100 |
inputs : tuple of Tensorflow.Tensors OR tensorflow.Tensor
|
| 101 |
+
Vector to be converted to text and seed text OR tokenized seed text
|
| 102 |
kwargs : optional keyword arguments
|
| 103 |
vector : tensorflow.Tensor vector to be decoded. May be supplied
|
| 104 |
via a keyword argument when this is invoked by .generate
|
scripts.py
CHANGED
|
@@ -4,9 +4,12 @@ import re
|
|
| 4 |
import argparse
|
| 5 |
import pickle
|
| 6 |
import tokenizers
|
|
|
|
| 7 |
import qarac.corpora.BNCorpus
|
| 8 |
import qarac.corpora.Batcher
|
| 9 |
import qarac.models.qarac_base_model
|
|
|
|
|
|
|
| 10 |
import keras
|
| 11 |
import tensorflow
|
| 12 |
import spacy
|
|
@@ -102,7 +105,48 @@ def prepare_training_datasets():
|
|
| 102 |
question_answering.to_csv('corpora/question_answering.csv')
|
| 103 |
reasoning.to_csv('corpora/reasoning_train.csv')
|
| 104 |
consistency.to_csv('corpora/consistency.csv')
|
| 105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
|
| 107 |
|
| 108 |
|
|
@@ -115,10 +159,12 @@ if __name__ == '__main__':
|
|
| 115 |
parser.add_argument('-t','--training-task')
|
| 116 |
parser.add_argument('-o','--outputfile')
|
| 117 |
args = parser.parse_args()
|
| 118 |
-
if args.task == 'train_base_model':
|
| 119 |
train_base_model(args.training_task,args.filename)
|
| 120 |
elif args.task == 'prepare_wiki_qa':
|
| 121 |
prepare_wiki_qa(args.filename,args.outputfile)
|
| 122 |
elif args.task == 'prepare_training_datasets':
|
| 123 |
prepare_training_datasets()
|
|
|
|
|
|
|
| 124 |
|
|
|
|
| 4 |
import argparse
|
| 5 |
import pickle
|
| 6 |
import tokenizers
|
| 7 |
+
import transformers
|
| 8 |
import qarac.corpora.BNCorpus
|
| 9 |
import qarac.corpora.Batcher
|
| 10 |
import qarac.models.qarac_base_model
|
| 11 |
+
import qarac.models.QaracTrainerModel
|
| 12 |
+
import qarac.corpora.CombinedCorpus
|
| 13 |
import keras
|
| 14 |
import tensorflow
|
| 15 |
import spacy
|
|
|
|
| 105 |
question_answering.to_csv('corpora/question_answering.csv')
|
| 106 |
reasoning.to_csv('corpora/reasoning_train.csv')
|
| 107 |
consistency.to_csv('corpora/consistency.csv')
|
| 108 |
+
|
| 109 |
+
def train_models(path):
|
| 110 |
+
encoder_base = transformers.TFRobertaModel.from_pretrained('roberta_base')
|
| 111 |
+
config = encoder_base.config
|
| 112 |
+
config.is_decoder = True
|
| 113 |
+
decoder_base = transformers.TFRobertaModel.from_pretrained('roberta_base',
|
| 114 |
+
config=config)
|
| 115 |
+
tokenizer = tokenizers.from_pretrained('roberta_base')
|
| 116 |
+
trainer = qarac.models.QaracTrainerModel.QuaracTrainerModel(encoder_base,
|
| 117 |
+
decoder_base,
|
| 118 |
+
tokenizer)
|
| 119 |
+
losses={'encode_decode':decoder_loss,
|
| 120 |
+
'question_answering':keras.losses.mean_squared_error,
|
| 121 |
+
'reasoning':decoder_loss,
|
| 122 |
+
'consistency':keras.losses.mean_squared_error}
|
| 123 |
+
optimizer = keras.optimizers.Nadam(learning_rate=keras.optimizers.schedules.ExponentialDecay(1.0e-5, 100, 0.99))
|
| 124 |
+
trainer.compile(optimizer=optimizer,
|
| 125 |
+
loss=losses)
|
| 126 |
+
training_data = qarac.corpora.CombinedCorpus(tokenizer,
|
| 127 |
+
all_text='corpora/all_text.csv',
|
| 128 |
+
question_answering='corpora/question_answering.csv',
|
| 129 |
+
reasoning='corpora/reasoning_train.csv',
|
| 130 |
+
consistency='corpora/consistency.csv')
|
| 131 |
+
trainer.fit(training_data,
|
| 132 |
+
epochs=10,
|
| 133 |
+
workers=16,
|
| 134 |
+
use_multiprocessing=True)
|
| 135 |
+
trainer.question_encoder.push_to_hub('{}/qarac-roberta-question-encoder'.format(path))
|
| 136 |
+
trainer.answer_encoder.push_to_hub('{}/qarac-roberta-answer-encoder'.format(path))
|
| 137 |
+
trainer.decoder.push_to_hub('{}/qarac-roberta-decoder'.format(path))
|
| 138 |
+
with open('model_summaries.txt') as summaries:
|
| 139 |
+
summaries.write('TRAINER MODEL\n')
|
| 140 |
+
summaries.write(trainer.summary())
|
| 141 |
+
summaries.write('QUESTION ENCODER\n')
|
| 142 |
+
summaries.write(trainer.question_encoder.summary())
|
| 143 |
+
summaries.write('ANSWER ENCODER\n')
|
| 144 |
+
summaries.write(trainer.answer_encoder.summary())
|
| 145 |
+
summaries.write('DECODER\n')
|
| 146 |
+
summaries.write(trainer.decoder.summary())
|
| 147 |
+
keras.utils.plot_model(trainer,'trainer_model.png')
|
| 148 |
+
keras.utils.plot_model(trainer.answer_encoder,'encoder_model.png')
|
| 149 |
+
keras.utils.plot_model(trainer.decoder,'decoder_model.png')
|
| 150 |
|
| 151 |
|
| 152 |
|
|
|
|
| 159 |
parser.add_argument('-t','--training-task')
|
| 160 |
parser.add_argument('-o','--outputfile')
|
| 161 |
args = parser.parse_args()
|
| 162 |
+
if args.task == 'train_base_model':
|
| 163 |
train_base_model(args.training_task,args.filename)
|
| 164 |
elif args.task == 'prepare_wiki_qa':
|
| 165 |
prepare_wiki_qa(args.filename,args.outputfile)
|
| 166 |
elif args.task == 'prepare_training_datasets':
|
| 167 |
prepare_training_datasets()
|
| 168 |
+
elif args.task == 'train_models':
|
| 169 |
+
train_models(args.filename)
|
| 170 |
|