PeteBleackley committed on
Commit
56e5680
·
1 Parent(s): 13f1508

Converted QaracTrainerModel to use PyTorch

Browse files
Files changed (1) hide show
  1. qarac/models/QaracTrainerModel.py +57 -45
qarac/models/QaracTrainerModel.py CHANGED
@@ -6,11 +6,13 @@ Created on Tue Sep 5 15:30:06 2023
6
  @author: peter
7
  """
8
 
9
- import keras
10
  import qarac.models.QaracEncoderModel
11
  import qarac.models.QaracDecoderModel
12
 
13
- class QaracTrainerModel(keras.Model):
 
 
14
 
15
  def __init__(self,base_encoder_model,base_decoder_model,tokenizer):
16
  """
@@ -18,9 +20,9 @@ class QaracTrainerModel(keras.Model):
18
 
19
  Parameters
20
  ----------
21
- base_encoder_model : transformers.TFRobertaModel
22
  Base model for encoders.
23
- base_decoder_model : transformers.TFRobertaModel
24
  Base model for decoder
25
  tokenizer : transformers.RobertaTokenizer
26
  Tokenizer for decoder
@@ -33,54 +35,64 @@ class QaracTrainerModel(keras.Model):
33
  self.question_encoder = qarac.models.QaracEncoderModel.QaracEncoderModel(base_encoder_model)
34
  self.answer_encoder = qarac.models.QaracEncoderModel.QaracEncoderModel(base_encoder_model)
35
  self.decoder = qarac.models.QaracDecoderModel.QaracDecoderModel(base_decoder_model,tokenizer)
36
- self.consistency = keras.layers.Dot(axes=1,normalize=True)
37
 
38
- def call(self,inputs,training=None):
 
 
 
 
 
 
 
 
 
39
  """
40
- Generates training objective outputs from training data
41
 
42
  Parameters
43
  ----------
44
- inputs : dict[str,tensorflow.tensor]
45
- Fields are
46
- 'all_text': Tokenized text to train answer encoder to produce vectors
47
- and decoder to convert them back to text
48
- 'offset_text': Same text as in 'all_text', but preceded by <s>
49
- 'question': Tokenized text of questions for question answering
50
- objective
51
- 'answer': Tokenized text of answers for question answering objective
52
- 'proposition0': tokenized proposition for reasoning objective
53
- 'proposition1': tokenized proposition for reasoning objective
54
- 'conclusion_offset': tokenized text of conclusions for reasoning
55
- objective, prefixed by '<s>'
56
- 'statement0': tokenized statement for consistency objective
57
- 'statement1': tokenized statement for consistency objective
58
- training : Bool, optional
59
- Not used. The default is None.
 
 
60
 
61
  Returns
62
  -------
63
- results : dict[str,tensorflow.tensor]
64
- Fields are
65
- 'encode_decode': tokenized text from decoding of vectors produced by
66
- answer encoder from 'all_text'
67
- 'question_answering': difference between vector produced by question
68
- encoder for 'question' and answer encoder for
69
- 'answer'
70
- 'reasoning': tokenised text produced by decoder from sum of vectors
71
- produced by answer encoder for 'proposition0' and
72
- 'proposition1'
73
- 'consistency': cosine similarity of vectors produced by answer encoder
74
- from 'statement0' and 'statement1'
75
 
76
  """
77
- results = {}
78
- results['encode_decode'] = self.decoder((self.answer_encoder(inputs['all_text']),
79
- inputs['offset_text']))
80
- results['question_answering'] = self.question_encoder(inputs['question']) - self.answer_encoder(inputs['answer'])
81
- results['reasoning'] = self.decoder((self.answer_encoder(inputs['proposition0'])
82
- +self.answer_encoder(inputs['proposition1']),
83
- inputs['conclusion_offset']))
84
- results['consistency'] = self.consistency((self.answer_encoder(inputs['statement0']),
85
- self.answer_encoder(inputs['statement1'])))
86
- return results
 
 
 
 
 
6
  @author: peter
7
  """
8
 
9
+ import torch
10
  import qarac.models.QaracEncoderModel
11
  import qarac.models.QaracDecoderModel
12
 
13
# Lower bound applied to vector norms so that normalising a zero vector
# does not divide by zero.
EPSILON = 1.0e-12


class QaracTrainerModel(torch.nn.Module):
    """
    Training wrapper that holds a question encoder, an answer encoder and a
    decoder, and evaluates all four training objectives in one forward pass.
    """

    def __init__(self, base_encoder_model, base_decoder_model, tokenizer):
        """
        Build the question encoder, answer encoder and decoder.

        Parameters
        ----------
        base_encoder_model : transformers.RobertaModel
            Base model for encoders.
        base_decoder_model : transformers.RobertaModel
            Base model for decoder
        tokenizer : transformers.RobertaTokenizer
            Tokenizer for decoder

        Returns
        -------
        None.

        """
        # torch.nn.Module must be initialised before submodules are assigned.
        # NOTE(review): the summary line above and this call sit in a part of
        # the file that was not visible during review; confirm against the
        # complete source.
        super().__init__()
        # NOTE(review): both encoders are constructed from the same
        # base_encoder_model object; whether they end up sharing weights
        # depends on QaracEncoderModel -- confirm this is intended.
        self.question_encoder = qarac.models.QaracEncoderModel.QaracEncoderModel(base_encoder_model)
        self.answer_encoder = qarac.models.QaracEncoderModel.QaracEncoderModel(base_encoder_model)
        self.decoder = qarac.models.QaracDecoderModel.QaracDecoderModel(base_decoder_model, tokenizer)

    def forward(self,
                all_text,
                offset_text,
                question,
                answer,
                proposition0,
                proposition1,
                conclusion_offset,
                statement0,
                statement1):
        """
        Generates training objectives from data

        Parameters
        ----------
        all_text : torch.tensor
            Tokenized text for encode-decode objective
        offset_text : torch.tensor
            As above, prefixed with <s>
        question : torch.tensor
            Tokenized question for question answering objective
        answer : torch.tensor
            Tokenized answer for question answering objective
        proposition0 : torch.tensor
            Tokenized proposition for reasoning objective
        proposition1 : torch.tensor
            Tokenized proposition for reasoning objective
        conclusion_offset : torch.tensor
            Tokenized conclusion for reasoning objective, prefixed with <s>
        statement0 : torch.tensor
            Tokenized statement for consistency objective
        statement1 : torch.tensor
            Tokenized statement for consistency objective

        Returns
        -------
        encode_decode : transformers.modeling_outputs.CausalLMOutputWithCrossAttentions
            Predicted text for encode-decode task (whatever the decoder
            returns)
        question_answering : torch.tensor
            Difference between encoded question and encoded answer
        reasoning : transformers.modeling_outputs.CausalLMOutputWithCrossAttentions
            Predicted text for reasoning objective (whatever the decoder
            returns)
        consistency : torch.tensor
            Cosine similarity of vectorized statements, one value per row

        """
        # The decoder receives a single (vector, offset tokens) tuple.
        encode_decode = self.decoder((self.answer_encoder(all_text),
                                      offset_text))
        question_answering = (self.question_encoder(question)
                              - self.answer_encoder(answer))
        premises = (self.answer_encoder(proposition0)
                    + self.answer_encoder(proposition1))
        reasoning = self.decoder((premises, conclusion_offset))
        # normalize() divides each row by max(||row||, EPSILON) and keeps the
        # reduced dimension, so the division broadcasts per row and a zero
        # vector yields zeros instead of NaN.
        s0 = torch.nn.functional.normalize(self.answer_encoder(statement0),
                                           p=2, dim=1, eps=EPSILON)
        s1 = torch.nn.functional.normalize(self.answer_encoder(statement1),
                                           p=2, dim=1, eps=EPSILON)
        # Row-wise dot product of unit vectors == cosine similarity.
        consistency = torch.einsum('ij,ij->i', s0, s1)
        return (encode_decode, question_answering, reasoning, consistency)