PeteBleackley committed
Commit 519dfd1 · 1 Parent(s): 4cda7b6

PyTorch implementation of the HuggingFace PreTrainedModel class does not allow direct setting of base_model. Rejig constructors accordingly.
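For context: in the transformers PyTorch API, PreTrainedModel exposes base_model as a read-only property that returns the submodule registered under the class's base_model_prefix, so a subclass cannot meaningfully assign its own base_model attribute. A minimal sketch of the convention (my illustration, not code from this repo):

import transformers

class Demo(transformers.PreTrainedModel):
    base_model_prefix = 'roberta'

    def __init__(self, config):
        super().__init__(config)
        # Registering the backbone under the name given by
        # base_model_prefix makes it reachable through the inherited
        # base_model property.
        self.roberta = transformers.RobertaModel(config)

model = Demo(transformers.RobertaConfig())
assert model.base_model is model.roberta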

qarac/models/QaracDecoderModel.py CHANGED
@@ -11,7 +11,7 @@ import transformers
 
 class QaracDecoderHead(torch.nn.Module):
 
-    def __init__(self,config,input_embeddings):
+    def __init__(self,base_model,config,input_embeddings):
         """
         Creates the Decoder head
 
@@ -25,7 +25,7 @@ class QaracDecoderHead(torch.nn.Module):
         None.
 
         """
-        super(QaracDecoderHead,self).__init__()
+        super(QaracDecoderHead,self).from_pretrained(base_model,config)
         self.layer_0 = transformers.models.roberta.modeling_roberta.RobertaLayer(config)
         self.layer_1 = transformers.models.roberta.modeling_roberta.RobertaLayer(config)
         self.head = transformers.models.roberta.modeling_roberta.RobertaLMHead(config,
@@ -77,7 +77,7 @@ class QaracDecoderHead(torch.nn.Module):
 
 class QaracDecoderModel(transformers.PreTrainedModel,transformers.generation_utils.GenerationMixin):
 
-    def __init__(self,base_model,tokenizer):
+    def __init__(self,config,tokenizer):
         """
         Creates decoder model from base model
 
@@ -91,7 +91,7 @@ class QaracDecoderModel(transformers.PreTrainedModel,transformers.generation_utils.GenerationMixin):
         None.
 
         """
-        super(QaracDecoderModel,self).__init__(base_model.config)
+        super(QaracDecoderModel,self).__init__(config)
         self.base_model = base_model
         self.decoder_head = QaracDecoderHead(self.base_model.config,
                                              self.base_model.roberta.get_input_embeddings())
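Worth noting alongside this hunk: from_pretrained is a classmethod that builds and returns a new model rather than filling in an existing instance, so calling it through super() inside __init__ will not initialise self. A sketch of the conventional call (checkpoint name assumed for illustration, not taken from the commit):

import transformers

# from_pretrained returns a new, fully initialised model.
backbone = transformers.RobertaModel.from_pretrained('roberta-base')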
qarac/models/QaracEncoderModel.py CHANGED
@@ -25,9 +25,8 @@ class QaracEncoderModel(transformers.PreTrainedModel):
         None.
 
         """
-        super(QaracEncoderModel,self).__init__(base_model.config)
-        self.base_model = base_model
-        self.head = qarac.models.layers.GlobalAttentionPoolingHead.GlobalAttentionPoolingHead(base_model.config)
+        super(QaracEncoderModel,self).from_pretrained(base_model)
+        self.head = qarac.models.layers.GlobalAttentionPoolingHead.GlobalAttentionPoolingHead(self.base_model.config)
 
 
     def forward(self,input_ids,
@@ -50,6 +49,10 @@ class QaracEncoderModel(transformers.PreTrainedModel):
         return self.head(self.base_model(input_ids,
                                          attention_mask).last_hidden_state,
                          attention_mask)
+
+    @property
+    def config(self):
+        return self.base_model.config
 
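Hypothetical usage of the encoder's forward pass, assuming an already-constructed QaracEncoderModel named encoder; the token ids are invented for illustration:

import torch

input_ids = torch.tensor([[0, 31414, 232, 2]])  # assumed RoBERTa token ids
attention_mask = torch.ones_like(input_ids)
vector = encoder(input_ids, attention_mask)     # pooled by GlobalAttentionPoolingHead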
qarac/models/QaracTrainerModel.py CHANGED
@@ -14,7 +14,7 @@ EPSILON=1.0e-12
 
 class QaracTrainerModel(torch.nn.Module):
 
-    def __init__(self,base_encoder_model,base_decoder_model,tokenizer):
+    def __init__(self,base_model_path,tokenizer):
         """
         Sets up the Trainer model
 
@@ -32,9 +32,13 @@ class QaracTrainerModel(torch.nn.Module):
 
         """
         super(QaracTrainerModel,self).__init__()
-        self.question_encoder = qarac.models.QaracEncoderModel.QaracEncoderModel(base_encoder_model)
-        self.answer_encoder = qarac.models.QaracEncoderModel.QaracEncoderModel(base_encoder_model)
-        self.decoder = qarac.models.QaracDecoderModel.QaracDecoderModel(base_decoder_model,tokenizer)
+        self.question_encoder = qarac.models.QaracEncoderModel.QaracEncoderModel(base_model_path)
+        self.answer_encoder = qarac.models.QaracEncoderModel.QaracEncoderModel(base_model_path)
+        config = self.answer_encoder.config
+        config.is_decoder = True
+        self.decoder = qarac.models.QaracDecoderModel.QaracDecoderModel(base_model_path,
+                                                                        config,
+                                                                        tokenizer)
 
     def forward(self,
                 all_text,
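For reference, the is_decoder flag flipped above is what switches the shared RoBERTa architecture into decoder-style (causal) attention before the decoder is built. The same flip in isolation (checkpoint name assumed for illustration):

import transformers

config = transformers.RobertaConfig.from_pretrained('roberta-base')
config.is_decoder = True  # causal, decoder-style attention for generation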
scripts.py CHANGED
@@ -123,14 +123,8 @@ def prepare_training_datasets():
     consistency.to_csv('corpora/consistency.csv')
 
 def train_models(path):
-    encoder_base = transformers.TFRobertaModel.from_pretrained('roberta-base')
-    config = encoder_base.config
-    config.is_decoder = True
-    decoder_base = transformers.TFRobertaModel.from_pretrained('roberta-base',
-                                                               config=config)
-    tokenizer = tokenizers.Tokenizer.from_pretrained('roberta-base')
-    trainer = qarac.models.QaracTrainerModel.QaracTrainerModel(encoder_base,
-                                                               decoder_base,
+    tokenizer = tokenizers.Tokenizer.from_pretrained('roberta-base')
+    trainer = qarac.models.QaracTrainerModel.QaracTrainerModel('roberta-base',
                                                                tokenizer)
     loss_fn = CombinedLoss()
     optimizer = torch.optim.NAdam(trainer.parameters(),lr=5.0e-5)
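The hunk ends at optimizer construction; a sketch of how trainer, loss_fn and optimizer might combine into a training step (the data loader and batch fields are my assumptions, not repo code):

for batch in training_data:                        # hypothetical DataLoader
    optimizer.zero_grad()
    predictions = trainer(batch['all_text'])       # remaining forward arguments not shown in the diff
    loss = loss_fn(predictions, batch['targets'])  # CombinedLoss from scripts.py
    loss.backward()
    optimizer.step()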