PeteBleackley commited on
Commit
37a581e
·
1 Parent(s): 32df2f1

Converted QaracEncoderModel to use PyTorch

Browse files
Files changed (1) hide show
  1. qarac/models/QaracEncoderModel.py +6 -24
qarac/models/QaracEncoderModel.py CHANGED
@@ -9,7 +9,7 @@ Created on Tue Sep 5 10:01:39 2023
9
  import transformers
10
  import qarac.models.layers.GlobalAttentionPoolingHead
11
 
12
- class QaracEncoderModel(transformers.TFPreTrainedModel):
13
 
14
  def __init__(self,base_model):
15
  """
@@ -27,27 +27,11 @@ class QaracEncoderModel(transformers.TFPreTrainedModel):
27
  """
28
  super(QaracEncoderModel,self).__init__(base_model.config)
29
  self.base_model = base_model
30
- self.head = qarac.models.layers.GlobalAttentionPoolingHead.GlobalAttentionPoolingHead()
31
 
32
- def build(self,input_shape):
33
- """
34
-
35
-
36
- Parameters
37
- ----------
38
- input_shape : tuple
39
- shape of input data.
40
-
41
- Returns
42
- -------
43
- None.
44
-
45
- """
46
- self.built=True
47
 
48
- def call(self,input_ids,
49
- attention_mask=None,
50
- training=False):
51
  """
52
  Vectorizes a tokenised text
53
 
@@ -64,10 +48,8 @@ class QaracEncoderModel(transformers.TFPreTrainedModel):
64
  """
65
 
66
  return self.head(self.base_model(input_ids,
67
- attention_mask,
68
- training=training).last_hidden_state,
69
- attention_mask,
70
- training)
71
 
72
 
73
 
 
9
  import transformers
10
  import qarac.models.layers.GlobalAttentionPoolingHead
11
 
12
+ class QaracEncoderModel(transformers.PreTrainedModel):
13
 
14
  def __init__(self,base_model):
15
  """
 
27
  """
28
  super(QaracEncoderModel,self).__init__(base_model.config)
29
  self.base_model = base_model
30
+ self.head = qarac.models.layers.GlobalAttentionPoolingHead.GlobalAttentionPoolingHead(base_model.config)
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
+ def forward(self,input_ids,
34
+ attention_mask=None):
 
35
  """
36
  Vectorizes a tokenised text
37
 
 
48
  """
49
 
50
  return self.head(self.base_model(input_ids,
51
+ attention_mask).last_hidden_state,
52
+ attention_mask)
 
 
53
 
54
 
55