PeteBleackley committed 858f75e (1 parent: f0ad7f1)

There's a simpler way of doing this, I hope

qarac/models/QaracDecoderModel.py CHANGED
@@ -92,9 +92,11 @@ class QaracDecoderModel(transformers.RobertaModel,
         None.
 
         """
-        super(QaracDecoderModel,self).from_pretrained(model_path,config=config)
-        self.decoder_head = QaracDecoderHead(self.base_model.config,
-                                             self.base_model.roberta.get_input_embeddings())
+        super(QaracDecoderModel,self).__init__(config)
+        self.decoder_base = transformers.RobertaModel.from_pretrained(model_path,
+                                                                      config=config)
+        self.decoder_head = QaracDecoderHead(self.config,
+                                             self.decoder_base.roberta.get_input_embeddings())
         self.tokenizer = tokenizer
 
 
@@ -119,7 +121,7 @@ class QaracDecoderModel(transformers.RobertaModel,
         (v,s) = (kwargs['vector'],inputs) if 'vector' in kwargs else inputs
 
         return self.decoder_head(torch.unsqueeze(v,1),
-                                 self.base_model(s).last_hidden_state,
+                                 self.decoder_base(s).last_hidden_state,
                                  training = kwargs.get('training',False))
 
     def prepare_inputs_for_generation(self,
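Both hunks are the same refactor: instead of calling from_pretrained on super() inside __init__ (which returns a fresh model rather than initialising self), the pretrained RoBERTa is held as a sub-module, decoder_base, and addressed through that attribute. Below is a minimal sketch of the pattern with a hypothetical class name, assuming a standard transformers install; note that a bare RobertaModel exposes get_input_embeddings() directly, whereas the .roberta attribute used above exists only on task wrappers such as RobertaForMaskedLM.

    import transformers


    class ComposedDecoder(transformers.PreTrainedModel):
        # Hypothetical illustration of the commit's composition pattern:
        # initialise self from the config, then hold the pretrained weights
        # as a sub-module instead of loading them "into" self.

        def __init__(self, model_path, config):
            super().__init__(config)
            self.decoder_base = transformers.RobertaModel.from_pretrained(model_path,
                                                                          config=config)
            # A bare RobertaModel has no .roberta attribute, so the input
            # embeddings come straight off the sub-module.
            self.embeddings = self.decoder_base.get_input_embeddings()

        def forward(self, input_ids):
            # Delegate to the wrapped model, as the forward() above does.
            return self.decoder_base(input_ids).last_hidden_state

Holding the base model as an attribute also keeps its parameters registered as a PyTorch sub-module, so .parameters() and .to(device) on the outer model behave as expected.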
qarac/models/QaracEncoderModel.py CHANGED
@@ -9,9 +9,9 @@ Created on Tue Sep 5 10:01:39 2023
 import transformers
 import qarac.models.layers.GlobalAttentionPoolingHead
 
-class QaracEncoderModel(transformers.RobertaModel):
+class QaracEncoderModel(transformers.PreTrainedModel):
 
-    def __init__(self,base_model):
+    def __init__(self,path):
         """
         Creates the encoder model
 
@@ -25,8 +25,9 @@ class QaracEncoderModel(transformers.RobertaModel):
         None.
 
         """
-        config = transformers.PretrainedConfig.from_pretrained(base_model)
-        super(QaracEncoderModel,self).from_pretrained(base_model,config=config)
+        config = transformers.PretrainedConfig.from_pretrained(path)
+        super(QaracEncoderModel,self).__init__(config)
+        self.encoder = transformers.RobertaModel(path)
         self.head = qarac.models.layers.GlobalAttentionPoolingHead.GlobalAttentionPoolingHead(config)
 
 
@@ -47,8 +48,8 @@ class QaracEncoderModel(transformers.RobertaModel):
 
         """
 
-        return self.head(self.base_model(input_ids,
-                         attention_mask).last_hidden_state,
+        return self.head(self.encoder(input_ids,
+                         attention_mask).last_hidden_state,
                          attention_mask)
 
     @property
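One line in the new __init__ looks suspect: transformers.RobertaModel(path) passes a path string where RobertaModel's constructor expects a config object, and constructing from a config alone would yield randomly initialised weights in any case. If the intent is to load the pretrained checkpoint, the standard transformers call is from_pretrained. A sketch of how the class would presumably read, with only the loading line changed (an assumption, not the committed code):

    import transformers
    import qarac.models.layers.GlobalAttentionPoolingHead


    class QaracEncoderModel(transformers.PreTrainedModel):

        def __init__(self, path):
            config = transformers.PretrainedConfig.from_pretrained(path)
            super().__init__(config)
            # from_pretrained loads the saved weights; RobertaModel(config)
            # alone would build a randomly initialised encoder.
            self.encoder = transformers.RobertaModel.from_pretrained(path)
            self.head = qarac.models.layers.GlobalAttentionPoolingHead.GlobalAttentionPoolingHead(config)

        def forward(self, input_ids, attention_mask):
            # Pool the encoder's token states into a single vector, as above.
            return self.head(self.encoder(input_ids, attention_mask).last_hidden_state,
                             attention_mask)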