PeteBleackley commited on
Commit
87535ff
·
1 Parent(s): 2f6dc26

Removed a base model that was causing a loop in model initialisation

Browse files
qarac/models/QaracBaseModel.py DELETED
@@ -1,16 +0,0 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- Created on Mon Oct 9 09:52:39 2023
5
-
6
- @author: peter
7
- """
8
-
9
- import transformers
10
-
11
- class QaracBaseModel(transformers.PreTrainedModel):
12
- """Base class for Qarac Models. Provided config_class"""
13
- config_class = transformers.PretrainedConfig
14
-
15
- def __init__(self,config,*inputs,**kwargs):
16
- super(QaracBaseModel,self).__init__(config,*inputs,**kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
qarac/models/QaracDecoderModel.py CHANGED
@@ -8,7 +8,6 @@ Created on Tue Sep 5 10:29:03 2023
8
 
9
  import transformers
10
  import torch
11
- import qarac.models.QaracBaseModel
12
 
13
  class QaracDecoderHead(torch.nn.Module):
14
 
@@ -76,7 +75,7 @@ class QaracDecoderHead(torch.nn.Module):
76
  False,
77
  training)[0])
78
 
79
- class QaracDecoderModel(qarac.models.QaracBaseModel.QaracBaseModel,
80
  transformers.generation_utils.GenerationMixin):
81
 
82
  def __init__(self,config,tokenizer):
 
8
 
9
  import transformers
10
  import torch
 
11
 
12
  class QaracDecoderHead(torch.nn.Module):
13
 
 
75
  False,
76
  training)[0])
77
 
78
+ class QaracDecoderModel(transformers.PreTrainedModel,
79
  transformers.generation_utils.GenerationMixin):
80
 
81
  def __init__(self,config,tokenizer):
qarac/models/QaracEncoderModel.py CHANGED
@@ -6,10 +6,10 @@ Created on Tue Sep 5 10:01:39 2023
6
  @author: peter
7
  """
8
 
9
- import qarac.models.QaracBaseModel
10
  import qarac.models.layers.GlobalAttentionPoolingHead
11
 
12
- class QaracEncoderModel(qarac.models.QaracBaseModel.QaracBaseModel):
13
 
14
  def __init__(self,base_model):
15
  """
@@ -25,8 +25,9 @@ class QaracEncoderModel(qarac.models.QaracBaseModel.QaracBaseModel):
25
  None.
26
 
27
  """
28
- super(QaracEncoderModel,self).from_pretrained(base_model)
29
- self.head = qarac.models.layers.GlobalAttentionPoolingHead.GlobalAttentionPoolingHead(self.base_model.config)
 
30
 
31
 
32
  def forward(self,input_ids,
 
6
  @author: peter
7
  """
8
 
9
+ import transformers
10
  import qarac.models.layers.GlobalAttentionPoolingHead
11
 
12
+ class QaracEncoderModel(transformers.PreTrainedModel):
13
 
14
  def __init__(self,base_model):
15
  """
 
25
  None.
26
 
27
  """
28
+ config = transformers.PretrainedConfig.from_pretrained(base_model)
29
+ super(QaracEncoderModel,self).__init__(base_model,config=config)
30
+ self.head = qarac.models.layers.GlobalAttentionPoolingHead.GlobalAttentionPoolingHead(config)
31
 
32
 
33
  def forward(self,input_ids,