updating model peptriever_2023-06-23T16:07:24.508460
Browse files- bi_encoder.py +64 -0
- config.json +4 -0
bi_encoder.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import PreTrainedModel
|
| 2 |
+
from transformers.models.bert.modeling_bert import BertOnlyMLMHead
|
| 3 |
+
|
| 4 |
+
from peptriever.model.bert_embedding import BertEmbeddingConfig, BertForEmbedding
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class BiEncoderConfig(BertEmbeddingConfig):
    """Configuration for the two-tower bi-encoder.

    Extends ``BertEmbeddingConfig`` with a separate maximum sequence
    length for each encoder tower; ``_replace_max_length`` maps each of
    these onto ``max_position_embeddings`` when building a tower.
    """

    # Maximum sequence length used by the first tower (``bert1``).
    max_length1: int
    # Maximum sequence length used by the second tower (``bert2``).
    max_length2: int
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class BiEncoder(PreTrainedModel):
    """Two independent BERT embedding towers driven by one config.

    Both towers are built from the same ``BiEncoderConfig``; they differ
    only in ``max_position_embeddings``, taken from ``max_length1`` and
    ``max_length2`` respectively.
    """

    config_class = BiEncoderConfig

    def __init__(self, config: BiEncoderConfig):
        super().__init__(config)
        self.bert1 = BertForEmbedding(_replace_max_length(config, "max_length1"))
        self.bert2 = BertForEmbedding(_replace_max_length(config, "max_length2"))
        self.post_init()

    def forward(self, x1, x2):
        """Embed both inputs; returns ``{"y1": ..., "y2": ...}``."""
        return {"y1": self.forward1(x1), "y2": self.forward2(x2)}

    def forward2(self, x2):
        """Embed ``x2`` (a mapping holding ``input_ids``) with tower 2."""
        return self.bert2(input_ids=x2["input_ids"])

    def forward1(self, x1):
        """Embed ``x1`` (a mapping holding ``input_ids``) with tower 1."""
        return self.bert1(input_ids=x1["input_ids"])
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class BiEncoderWithMaskedLM(PreTrainedModel):
    """Bi-encoder variant with a masked-LM head on each tower.

    Each tower exposes ``forward_with_state`` returning both the final
    embedding and the hidden state that the corresponding
    ``BertOnlyMLMHead`` scores per token.
    """

    config_class = BiEncoderConfig

    def __init__(self, config: BiEncoderConfig):
        super().__init__(config=config)

        cfg1 = _replace_max_length(config, "max_length1")
        self.bert1 = BertForEmbedding(cfg1)
        self.lm_head1 = BertOnlyMLMHead(config=cfg1)

        cfg2 = _replace_max_length(config, "max_length2")
        self.bert2 = BertForEmbedding(cfg2)
        self.lm_head2 = BertOnlyMLMHead(config=cfg2)
        self.post_init()

    def forward(self, x1, x2):
        """Return embeddings and MLM token scores for both towers."""
        emb1, hidden1 = self.bert1.forward_with_state(input_ids=x1["input_ids"])
        emb2, hidden2 = self.bert2.forward_with_state(input_ids=x2["input_ids"])
        return {
            "y1": emb1,
            "y2": emb2,
            "scores1": self.lm_head1(hidden1),
            "scores2": self.lm_head2(hidden2),
        }
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _replace_max_length(config, length_key):
    """Derive a per-tower ``BertEmbeddingConfig`` from *config*.

    Copies all of *config*'s attributes, removes *length_key*
    (``"max_length1"`` or ``"max_length2"``) and installs its value as
    ``max_position_embeddings`` in the new config.
    """
    attrs = dict(config.__dict__)
    attrs["max_position_embeddings"] = attrs.pop(length_key)
    return BertEmbeddingConfig(**attrs)
|
config.json
CHANGED
|
@@ -4,6 +4,10 @@
|
|
| 4 |
"BiEncoder"
|
| 5 |
],
|
| 6 |
"attention_probs_dropout_prob": 0.1,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
"classifier_dropout": null,
|
| 8 |
"distance_func": "euclidean",
|
| 9 |
"hidden_act": "gelu",
|
|
|
|
| 4 |
"BiEncoder"
|
| 5 |
],
|
| 6 |
"attention_probs_dropout_prob": 0.1,
|
| 7 |
+
"auto_map": {
|
| 8 |
+
"AutoConfig": "bi_encoder.BiEncoderConfig",
|
| 9 |
+
"AutoModel": "bi_encoder.BiEncoder"
|
| 10 |
+
},
|
| 11 |
"classifier_dropout": null,
|
| 12 |
"distance_func": "euclidean",
|
| 13 |
"hidden_act": "gelu",
|