Commit 519dfd1 · PeteBleackley committed
1 Parent(s): 4cda7b6

PyTorch implementation of HuggingFace PreTrainedModel class does not allow direct setting of base_model. Rejig constructors accordingly
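For context, a minimal sketch of the constructor contract the commit message refers to (MyWrapperModel is a hypothetical name; the rest is the public transformers API): PreTrainedModel.__init__ takes a PretrainedConfig and builds submodules from it, and pretrained weights are loaded through the from_pretrained classmethod rather than by assigning a ready-built base model to an instance.

import transformers

class MyWrapperModel(transformers.PreTrainedModel):
    # Hypothetical wrapper, not the repository's code.
    config_class = transformers.RobertaConfig
    base_model_prefix = 'roberta'

    def __init__(self, config):
        # __init__ must receive a PretrainedConfig, not a model instance.
        super().__init__(config)
        # Submodules are built from the config; weights are loaded later.
        self.roberta = transformers.RobertaModel(config)

# Weights arrive via the classmethod, which constructs and returns the model:
# model = MyWrapperModel.from_pretrained('roberta-base')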
qarac/models/QaracDecoderModel.py
CHANGED
@@ -11,7 +11,7 @@ import transformers
 
 class QaracDecoderHead(torch.nn.Module):
 
-    def __init__(self,config,input_embeddings):
+    def __init__(self,base_model,config,input_embeddings):
         """
         Creates the Decoder head
 
@@ -25,7 +25,7 @@ class QaracDecoderHead(torch.nn.Module):
         None.
 
         """
-        super(QaracDecoderHead,self).__init__()
+        super(QaracDecoderHead,self).from_pretrained(base_model,config)
         self.layer_0 = transformers.models.roberta.modeling_roberta.RobertaLayer(config)
         self.layer_1 = transformers.models.roberta.modeling_roberta.RobertaLayer(config)
         self.head = transformers.models.roberta.modeling_roberta.RobertaLMHead(config,
@@ -77,7 +77,7 @@ class QaracDecoderHead(torch.nn.Module):
 
 class QaracDecoderModel(transformers.PreTrainedModel,transformers.generation_utils.GenerationMixin):
 
-    def __init__(self,
+    def __init__(self,config,tokenizer):
         """
         Creates decoder model from base model
 
@@ -91,7 +91,7 @@ class QaracDecoderModel(transformers.PreTrainedModel,transformers.generation_utils.GenerationMixin):
         None.
 
         """
-        super(QaracDecoderModel,self).__init__(
+        super(QaracDecoderModel,self).__init__(config)
         self.base_model = base_model
         self.decoder_head = QaracDecoderHead(self.base_model.config,
                                              self.base_model.roberta.get_input_embeddings())
qarac/models/QaracEncoderModel.py
CHANGED
@@ -25,9 +25,8 @@ class QaracEncoderModel(transformers.PreTrainedModel):
         None.
 
         """
-        super(QaracEncoderModel,self).__init__(base_model.config)
-        self.base_model = base_model
-        self.head = qarac.models.layers.GlobalAttentionPoolingHead.GlobalAttentionPoolingHead(base_model.config)
+        super(QaracEncoderModel,self).from_pretrained(base_model)
+        self.head = qarac.models.layers.GlobalAttentionPoolingHead.GlobalAttentionPoolingHead(self.base_model.config)
 
 
     def forward(self,input_ids,
@@ -50,6 +49,10 @@ class QaracEncoderModel(transformers.PreTrainedModel):
         return self.head(self.base_model(input_ids,
                                           attention_mask).last_hidden_state,
                          attention_mask)
+
+    @property
+    def config(self):
+        return self.base_model.config
 
 
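A minimal sketch, assuming a roberta-style backbone, of wrapping a pretrained encoder in a plain torch.nn.Module while delegating config the way the new code does (EncoderWithPooling is a hypothetical name, not the repository's class):

import torch
import transformers

class EncoderWithPooling(torch.nn.Module):
    # Hypothetical sketch for illustration only.
    def __init__(self, base_model_path):
        super().__init__()
        # from_pretrained is a classmethod: it builds and returns the model.
        self.base_model = transformers.RobertaModel.from_pretrained(base_model_path)

    @property
    def config(self):
        # Delegate configuration to the wrapped backbone.
        return self.base_model.config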
qarac/models/QaracTrainerModel.py
CHANGED
@@ -14,7 +14,7 @@ EPSILON=1.0e-12
 
 class QaracTrainerModel(torch.nn.Module):
 
-    def __init__(self,
+    def __init__(self,base_model_path,tokenizer):
        """
        Sets up the Trainer model
 
@@ -32,9 +32,13 @@ class QaracTrainerModel(torch.nn.Module):
 
        """
        super(QaracTrainerModel,self).__init__()
-        self.question_encoder = qarac.models.QaracEncoderModel.QaracEncoderModel(
-        self.answer_encoder = qarac.models.QaracEncoderModel.QaracEncoderModel(
-
+        self.question_encoder = qarac.models.QaracEncoderModel.QaracEncoderModel(base_model_path)
+        self.answer_encoder = qarac.models.QaracEncoderModel.QaracEncoderModel(base_model_path)
+        config = self.answer_encoder.config
+        config.is_decoder = True
+        self.decoder = qarac.models.QaracDecoderModel.QaracDecoderModel(base_model_path,
+                                                                        config,
+                                                                        tokenizer)
 
     def forward(self,
                 all_text,
scripts.py
CHANGED
@@ -123,14 +123,8 @@ def prepare_training_datasets():
     consistency.to_csv('corpora/consistency.csv')
 
 def train_models(path):
-
-
-    config.is_decoder = True
-    decoder_base = transformers.TFRobertaModel.from_pretrained('roberta-base',
-                                                               config=config)
-    tokenizer = tokenizers.Tokenizer.from_pretrained('roberta-base')
-    trainer = qarac.models.QaracTrainerModel.QaracTrainerModel(encoder_base,
-                                                               decoder_base,
+    tokenizer = tokenizers.from_pretrained('roberta-base')
+    trainer = qarac.models.QaracTrainerModel.QaracTrainerModel('roberta_base',
                                                                tokenizer)
     loss_fn = CombinedLoss()
     optimizer = torch.optim.NAdam(trainer.parameters(),lr=5.0e-5)
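For reference, the pre-change code loaded the tokenizer through the Hugging Face tokenizers package, where from_pretrained is a classmethod of the Tokenizer class rather than a module-level function:

import tokenizers

# Tokenizer.from_pretrained fetches a pretrained tokenizer by hub identifier.
tokenizer = tokenizers.Tokenizer.from_pretrained('roberta-base')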