PeteBleackley committed on
Commit
63b2c6a
·
1 Parent(s): 4f504fa

Training script

Browse files
Files changed (2) hide show
  1. qarac/models/QaracDecoderModel.py +1 -1
  2. scripts.py +48 -2
qarac/models/QaracDecoderModel.py CHANGED
@@ -98,7 +98,7 @@ class QaracDecoderModel(transformers.TFPretrainedModel,transformers.TFGeneration
98
  Parameters
99
  ----------
100
  inputs : tuple of Tensorflow.Tensors OR tensorflow.Tensor
101
- Vector to be converted to text and seed text ORtokenized seed text
102
  kwargs : optional keyword arguments
103
  vector : tensorflow.Tensor vector to be decoded. May be supplied
104
  via a keyword argument when this is invoked by .generate
 
98
  Parameters
99
  ----------
100
  inputs : tuple of Tensorflow.Tensors OR tensorflow.Tensor
101
+ Vector to be converted to text and seed text OR tokenized seed text
102
  kwargs : optional keyword arguments
103
  vector : tensorflow.Tensor vector to be decoded. May be supplied
104
  via a keyword argument when this is invoked by .generate
scripts.py CHANGED
@@ -4,9 +4,12 @@ import re
4
  import argparse
5
  import pickle
6
  import tokenizers
 
7
  import qarac.corpora.BNCorpus
8
  import qarac.corpora.Batcher
9
  import qarac.models.qarac_base_model
 
 
10
  import keras
11
  import tensorflow
12
  import spacy
@@ -102,7 +105,48 @@ def prepare_training_datasets():
102
  question_answering.to_csv('corpora/question_answering.csv')
103
  reasoning.to_csv('corpora/reasoning_train.csv')
104
  consistency.to_csv('corpora/consistency.csv')
105
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
 
108
 
@@ -115,10 +159,12 @@ if __name__ == '__main__':
115
  parser.add_argument('-t','--training-task')
116
  parser.add_argument('-o','--outputfile')
117
  args = parser.parse_args()
118
- if args.task == 'train_base_model':
119
  train_base_model(args.training_task,args.filename)
120
  elif args.task == 'prepare_wiki_qa':
121
  prepare_wiki_qa(args.filename,args.outputfile)
122
  elif args.task == 'prepare_training_datasets':
123
  prepare_training_datasets()
 
 
124
 
 
4
  import argparse
5
  import pickle
6
  import tokenizers
7
+ import transformers
8
  import qarac.corpora.BNCorpus
9
  import qarac.corpora.Batcher
10
  import qarac.models.qarac_base_model
11
+ import qarac.models.QaracTrainerModel
12
+ import qarac.corpora.CombinedCorpus
13
  import keras
14
  import tensorflow
15
  import spacy
 
105
  question_answering.to_csv('corpora/question_answering.csv')
106
  reasoning.to_csv('corpora/reasoning_train.csv')
107
  consistency.to_csv('corpora/consistency.csv')
108
+
109
def train_models(path):
    """Train the combined QARAC trainer model and publish its components.

    Builds RoBERTa-based encoder and decoder bases, trains the combined
    trainer model on the prepared corpora, pushes the three component
    models to the Hugging Face Hub, and writes model summaries and
    architecture diagrams to local files.

    Parameters
    ----------
    path : str
        Hugging Face namespace (user or organisation) under which the
        trained component models are pushed.
    """
    # Hub model id is 'roberta-base' (hyphen); 'roberta_base' does not exist.
    encoder_base = transformers.TFRobertaModel.from_pretrained('roberta-base')
    config = encoder_base.config
    config.is_decoder = True
    decoder_base = transformers.TFRobertaModel.from_pretrained('roberta-base',
                                                               config=config)
    # The tokenizers package has no module-level from_pretrained();
    # Tokenizer.from_pretrained is the correct constructor.
    tokenizer = tokenizers.Tokenizer.from_pretrained('roberta-base')
    # NOTE(review): 'QuaracTrainerModel' looks like a typo for
    # 'QaracTrainerModel' — confirm against the class definition in
    # qarac/models/QaracTrainerModel.py before renaming.
    trainer = qarac.models.QaracTrainerModel.QuaracTrainerModel(encoder_base,
                                                                decoder_base,
                                                                tokenizer)
    # decoder_loss is assumed to be defined elsewhere in this module —
    # TODO confirm; it is not visible in this section of the file.
    losses = {'encode_decode': decoder_loss,
              'question_answering': keras.losses.mean_squared_error,
              'reasoning': decoder_loss,
              'consistency': keras.losses.mean_squared_error}
    optimizer = keras.optimizers.Nadam(
        learning_rate=keras.optimizers.schedules.ExponentialDecay(1.0e-5,
                                                                  100,
                                                                  0.99))
    trainer.compile(optimizer=optimizer,
                    loss=losses)
    # CombinedCorpus is a module; instantiate its class, matching the
    # module.Class pattern used for QaracTrainerModel above.
    training_data = qarac.corpora.CombinedCorpus.CombinedCorpus(
        tokenizer,
        all_text='corpora/all_text.csv',
        question_answering='corpora/question_answering.csv',
        reasoning='corpora/reasoning_train.csv',
        consistency='corpora/consistency.csv')
    trainer.fit(training_data,
                epochs=10,
                workers=16,
                use_multiprocessing=True)
    trainer.question_encoder.push_to_hub('{}/qarac-roberta-question-encoder'.format(path))
    trainer.answer_encoder.push_to_hub('{}/qarac-roberta-answer-encoder'.format(path))
    trainer.decoder.push_to_hub('{}/qarac-roberta-decoder'.format(path))
    # Open for writing: the original read-mode open would make .write() raise.
    # Model.summary() prints and returns None, so stream each summary into
    # the file via print_fn instead of writing its (None) return value.
    with open('model_summaries.txt', 'w') as summaries:
        write_line = lambda line: summaries.write(line + '\n')
        summaries.write('TRAINER MODEL\n')
        trainer.summary(print_fn=write_line)
        summaries.write('QUESTION ENCODER\n')
        trainer.question_encoder.summary(print_fn=write_line)
        summaries.write('ANSWER ENCODER\n')
        trainer.answer_encoder.summary(print_fn=write_line)
        summaries.write('DECODER\n')
        trainer.decoder.summary(print_fn=write_line)
    keras.utils.plot_model(trainer, 'trainer_model.png')
    keras.utils.plot_model(trainer.answer_encoder, 'encoder_model.png')
    keras.utils.plot_model(trainer.decoder, 'decoder_model.png')
150
 
151
 
152
 
 
159
  parser.add_argument('-t','--training-task')
160
  parser.add_argument('-o','--outputfile')
161
  args = parser.parse_args()
162
+ if args.task == 'train_base_model':
163
  train_base_model(args.training_task,args.filename)
164
  elif args.task == 'prepare_wiki_qa':
165
  prepare_wiki_qa(args.filename,args.outputfile)
166
  elif args.task == 'prepare_training_datasets':
167
  prepare_training_datasets()
168
+ elif args.task == 'train_models':
169
+ train_models(args.filename)
170