Spaces:
Build error
Build error
PeteBleackley
committed on
Commit
·
37a581e
1
Parent(s):
32df2f1
Converted QaracEncoderModel to use PyTorch
Browse files
qarac/models/QaracEncoderModel.py
CHANGED
|
@@ -9,7 +9,7 @@ Created on Tue Sep 5 10:01:39 2023
|
|
| 9 |
import transformers
|
| 10 |
import qarac.models.layers.GlobalAttentionPoolingHead
|
| 11 |
|
| 12 |
-
class QaracEncoderModel(transformers.
|
| 13 |
|
| 14 |
def __init__(self,base_model):
|
| 15 |
"""
|
|
@@ -27,27 +27,11 @@ class QaracEncoderModel(transformers.TFPreTrainedModel):
|
|
| 27 |
"""
|
| 28 |
super(QaracEncoderModel,self).__init__(base_model.config)
|
| 29 |
self.base_model = base_model
|
| 30 |
-
self.head = qarac.models.layers.GlobalAttentionPoolingHead.GlobalAttentionPoolingHead()
|
| 31 |
|
| 32 |
-
def build(self,input_shape):
|
| 33 |
-
"""
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
Parameters
|
| 37 |
-
----------
|
| 38 |
-
input_shape : tuple
|
| 39 |
-
shape of input data.
|
| 40 |
-
|
| 41 |
-
Returns
|
| 42 |
-
-------
|
| 43 |
-
None.
|
| 44 |
-
|
| 45 |
-
"""
|
| 46 |
-
self.built=True
|
| 47 |
|
| 48 |
-
def
|
| 49 |
-
attention_mask=None
|
| 50 |
-
training=False):
|
| 51 |
"""
|
| 52 |
Vectorizes a tokenised text
|
| 53 |
|
|
@@ -64,10 +48,8 @@ class QaracEncoderModel(transformers.TFPreTrainedModel):
|
|
| 64 |
"""
|
| 65 |
|
| 66 |
return self.head(self.base_model(input_ids,
|
| 67 |
-
attention_mask,
|
| 68 |
-
|
| 69 |
-
attention_mask,
|
| 70 |
-
training)
|
| 71 |
|
| 72 |
|
| 73 |
|
|
|
|
| 9 |
import transformers
|
| 10 |
import qarac.models.layers.GlobalAttentionPoolingHead
|
| 11 |
|
| 12 |
class QaracEncoderModel(transformers.PreTrainedModel):
    """Wraps a base transformer and pools its hidden states into one vector per text."""

    def __init__(self, base_model):
        """
        Create the encoder.

        Parameters
        ----------
        base_model : transformers.PreTrainedModel
            Base transformer whose last hidden state will be pooled.
            Its config is reused both for this wrapper and for the
            attention-pooling head.

        Returns
        -------
        None.
        """
        # Python 3 zero-argument super() replaces the legacy
        # super(QaracEncoderModel, self) spelling.
        super().__init__(base_model.config)
        self.base_model = base_model
        self.head = qarac.models.layers.GlobalAttentionPoolingHead.GlobalAttentionPoolingHead(base_model.config)

    def forward(self, input_ids, attention_mask=None):
        """
        Vectorizes a tokenised text.

        Parameters
        ----------
        input_ids : torch.Tensor
            Token ids produced by the tokenizer.
        attention_mask : torch.Tensor, optional
            Mask distinguishing real tokens from padding.
            The default is None.

        Returns
        -------
        torch.Tensor
            Pooled representation of the input text, as produced by the
            GlobalAttentionPoolingHead.
        """
        # Pass attention_mask by keyword: relying on it being the second
        # positional argument of the base model's forward() is fragile
        # across transformers model classes and versions.
        hidden = self.base_model(input_ids,
                                 attention_mask=attention_mask).last_hidden_state
        return self.head(hidden, attention_mask)