PeteBleackley commited on
Commit
5b7a8ed
·
1 Parent(s): fb4c0b0

Further changes for compatibility with HuggingFace Pytorch implementation

Browse files
qarac/models/QaracBaseModel.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on Mon Oct 9 09:52:39 2023
5
+
6
+ @author: peter
7
+ """
8
+
9
+ import transformers
10
+
11
+ class QaracBaseModel(transformers.PreTraineModel):
12
+ """Base class for Qarac Models. Provided config_class"""
13
+ config_class = transformers.RobertaConfig
14
+
15
+ def __init__(self,config,*inputs,**kwargs):
16
+ super(QaracBaseModel,self).__init__(config,*inputs,**kwargs)
qarac/models/QaracDecoderModel.py CHANGED
@@ -6,8 +6,9 @@ Created on Tue Sep 5 10:29:03 2023
6
  @author: peter
7
  """
8
 
9
- import torch
10
  import transformers
 
 
11
 
12
  class QaracDecoderHead(torch.nn.Module):
13
 
@@ -75,7 +76,8 @@ class QaracDecoderHead(torch.nn.Module):
75
  False,
76
  training)[0])
77
 
78
- class QaracDecoderModel(transformers.PreTrainedModel,transformers.generation_utils.GenerationMixin):
 
79
 
80
  def __init__(self,config,tokenizer):
81
  """
 
6
  @author: peter
7
  """
8
 
 
9
  import transformers
10
+ import torch
11
+ import qarac.models.QaracBaseModel.QaracBaseModel
12
 
13
  class QaracDecoderHead(torch.nn.Module):
14
 
 
76
  False,
77
  training)[0])
78
 
79
+ class QaracDecoderModel(qarac.models.QaracBaseModel.QaracBaseModel,
80
+ transformers.generation_utils.GenerationMixin):
81
 
82
  def __init__(self,config,tokenizer):
83
  """
qarac/models/QaracEncoderModel.py CHANGED
@@ -6,10 +6,10 @@ Created on Tue Sep 5 10:01:39 2023
6
  @author: peter
7
  """
8
 
9
- import transformers
10
  import qarac.models.layers.GlobalAttentionPoolingHead
11
 
12
- class QaracEncoderModel(transformers.PreTrainedModel):
13
 
14
  def __init__(self,base_model):
15
  """
 
6
  @author: peter
7
  """
8
 
9
+ import qarac.models.QaracBaseModel
10
  import qarac.models.layers.GlobalAttentionPoolingHead
11
 
12
+ class QaracEncoderModel(qarac.models.QaracBaseModel.QaracBaseModel):
13
 
14
  def __init__(self,base_model):
15
  """