Text Ranking
sentence-transformers
ONNX
Safetensors
OpenVINO
Transformers
English
electra
text-classification
custom_code
Tom Aarsen committed
Commit · 23a9cf3
Parent(s): be3c10d
Integrate with (Sentence) Transformers
- README.md +58 -3
- config.json +16 -3
- model.safetensors +2 -2
- modeling.py +88 -0
README.md
CHANGED
@@ -1,11 +1,66 @@
 ---
 license: apache-2.0
 pipeline_tag: text-ranking
-
+language:
+- en
+library_name: sentence-transformers
 base_model:
 - google/electra-base-discriminator
+tags:
+- transformers
 ---
 
-
+## Cross-Encoder for Text Ranking
 
-
+This model is a port of the [webis/monoelectra-base](https://huggingface.co/webis/monoelectra-base) model from [lightning-ir](https://github.com/webis-de/lightning-ir) to [Sentence Transformers](https://sbert.net/) and [Transformers](https://huggingface.co/docs/transformers).
+
+The original model was introduced in the paper [A Systematic Investigation of Distilling Large Language Models into Cross-Encoders for Passage Re-ranking](https://arxiv.org/abs/2405.07920). See https://github.com/webis-de/rank-distillm for code used to train the original model.
+
+The model can be used as a reranker in a 2-stage "retrieve-rerank" pipeline, where it reorders passages returned by a retriever model (e.g. an embedding model or BM25) given some query. See [SBERT.net Retrieve & Re-rank](https://www.sbert.net/examples/applications/retrieve_rerank/README.html) for more details.
+
+## Usage with Sentence Transformers
+
+The usage is easy when you have [SentenceTransformers](https://www.sbert.net/) installed.
+
+```bash
+pip install sentence-transformers
+```
+
+Then you can use the pre-trained model like this:
+
+```python
+from sentence_transformers import CrossEncoder
+
+model = CrossEncoder("cross-encoder/monoelectra-base", trust_remote_code=True)
+scores = model.predict([
+    ("How many people live in Berlin?", "Berlin had a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers."),
+    ("How many people live in Berlin?", "Berlin is well known for its museums."),
+])
+print(scores)
+# [ 8.607138 -4.320078]
+```
+
+## Usage with Transformers
+
+```python
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+import torch
+
+model = AutoModelForSequenceClassification.from_pretrained("cross-encoder/monoelectra-base", trust_remote_code=True)
+tokenizer = AutoTokenizer.from_pretrained("cross-encoder/monoelectra-base")
+
+features = tokenizer(
+    [
+        ["How many people live in Berlin?", "How many people live in Berlin?"],
+        ["Berlin has a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.", "New York City is famous for the Metropolitan Museum of Art."],
+    ],
+    padding=True,
+    truncation=True,
+    return_tensors="pt",
+)
+
+model.eval()
+with torch.no_grad():
+    scores = model(**features).logits
+    print(scores)
+```
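The "retrieve-rerank" flow that the new README describes can be sketched without loading the model. This is a minimal illustration under stated assumptions, not the library's implementation: `rerank` and `toy_score` are hypothetical names, and `toy_score` is a word-overlap stand-in for a cross-encoder that returns one relevance score per (query, passage) pair.

```python
from typing import Callable, List, Sequence, Tuple


def toy_score(pairs: Sequence[Tuple[str, str]]) -> List[float]:
    """Hypothetical stand-in for a cross-encoder: counts shared lowercase words."""
    scores = []
    for query, passage in pairs:
        query_words = set(query.lower().split())
        passage_words = set(passage.lower().split())
        scores.append(float(len(query_words & passage_words)))
    return scores


def rerank(
    query: str,
    passages: Sequence[str],
    score_fn: Callable[[Sequence[Tuple[str, str]]], List[float]],
) -> List[Tuple[str, float]]:
    """Score every (query, passage) pair and sort passages by descending score."""
    pairs = [(query, passage) for passage in passages]
    scores = score_fn(pairs)
    return sorted(zip(passages, scores), key=lambda item: item[1], reverse=True)


# Candidates as a first-stage retriever might return them (made-up examples).
candidates = [
    "Berlin is well known for its museums.",
    "How many people live in Berlin is reported by the census office.",
]
ranked = rerank("How many people live in Berlin?", candidates, toy_score)
```

With Sentence Transformers installed, passing `model.predict` as `score_fn` in place of `toy_score` is the intended substitution, on the assumption that it accepts the same list of pairs and returns one score per pair.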
config.json
CHANGED
@@ -1,8 +1,11 @@
 {
   "architectures": [
-    "
+    "WebisCrossEncoderForSequenceClassification"
   ],
   "attention_probs_dropout_prob": 0.1,
+  "auto_map": {
+    "AutoModelForSequenceClassification": "modeling.WebisCrossEncoderForSequenceClassification"
+  },
   "backbone_model_type": "electra",
   "classifier_dropout": null,
   "doc_length": 256,
@@ -10,23 +13,33 @@
   "hidden_act": "gelu",
   "hidden_dropout_prob": 0.1,
   "hidden_size": 768,
+  "id2label": {
+    "0": "LABEL_0"
+  },
   "initializer_range": 0.02,
   "intermediate_size": 3072,
+  "label2id": {
+    "LABEL_0": 0
+  },
   "layer_norm_eps": 1e-12,
   "max_position_embeddings": 512,
-  "model_type": "
+  "model_type": "electra",
   "num_attention_heads": 12,
   "num_hidden_layers": 12,
   "pad_token_id": 0,
   "pooling_strategy": "first",
   "position_embedding_type": "absolute",
   "query_length": 32,
+  "sentence_transformers": {
+    "activation_fn": "torch.nn.modules.linear.Identity",
+    "version": "4.0.1"
+  },
   "summary_activation": "gelu",
   "summary_last_dropout": 0.1,
   "summary_type": "first",
   "summary_use_proj": true,
   "torch_dtype": "float32",
-  "transformers_version": "4.
+  "transformers_version": "4.49.0",
   "type_vocab_size": 2,
   "use_cache": true,
   "vocab_size": 30522
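The `sentence_transformers.activation_fn` entry added to config.json names an identity function, which suggests that scores come back as raw logits, such as the `8.607138` and `-4.320078` shown in the README. If scores in the (0, 1) range are wanted, one option is to apply a logistic function afterwards. The sketch below is illustrative only: `sigmoid` is a local helper, and whether a sigmoid-normalized score is meaningful for this model is an assumption, since the commit does not say so.

```python
import math
from typing import List


def sigmoid(logit: float) -> float:
    """Numerically stable logistic function."""
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    exp_logit = math.exp(logit)
    return exp_logit / (1.0 + exp_logit)


# Example raw scores copied from the README added in this commit.
raw_scores: List[float] = [8.607138, -4.320078]

# Monotonic mapping: the ranking order is unchanged, only the scale differs.
normalized = [sigmoid(score) for score in raw_scores]
```

Because the mapping is monotonic, reranking by `normalized` gives the same order as reranking by `raw_scores`.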
model.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:5516a5f0510d6b44fc7415d3b283118f935c6438391e44e0850d079c0e644796
+size 435593564
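The model.safetensors entry in the diff is a Git LFS pointer, not the weights themselves: three `key value` lines giving the spec version, the SHA-256 object ID, and the size in bytes. A minimal sketch of reading those fields follows; `parse_lfs_pointer` is a hypothetical helper written for illustration, not part of any Git LFS tooling.

```python
from typing import Dict

# New pointer contents as shown in this commit.
POINTER_TEXT = """version https://git-lfs.github.com/spec/v1
oid sha256:5516a5f0510d6b44fc7415d3b283118f935c6438391e44e0850d079c0e644796
size 435593564
"""


def parse_lfs_pointer(text: str) -> Dict[str, str]:
    """Split each 'key value' line of a Git LFS pointer into a dict entry."""
    fields = {}
    for line in text.splitlines():
        if not line.strip():
            continue
        key, _, value = line.partition(" ")
        fields[key] = value
    return fields


pointer = parse_lfs_pointer(POINTER_TEXT)
hash_algorithm, _, digest = pointer["oid"].partition(":")
size_in_bytes = int(pointer["size"])
```

A downloaded file could then be checked by comparing its byte length with `size_in_bytes` and its `hashlib.sha256` hex digest with `digest`.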
modeling.py
ADDED
@@ -0,0 +1,88 @@
+
+from typing import Optional, Tuple, Union
+import torch
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers import ElectraPreTrainedModel, ElectraModel, ElectraConfig
+from transformers.modeling_outputs import SequenceClassifierOutput
+
+
+class WebisCrossEncoderForSequenceClassification(ElectraPreTrainedModel):
+    def __init__(self, config: ElectraConfig):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.config = config
+        self.electra = ElectraModel(config)
+        self.linear = torch.nn.Linear(config.hidden_size, config.num_labels, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        discriminator_hidden_states = self.electra(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = discriminator_hidden_states[0]
+        logits = self.linear(sequence_output[:, 0, :])  # Take [CLS] token representation for classification
+
+        loss = None
+        if labels is not None:
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
+
+        if not return_dict:
+            output = (logits,) + discriminator_hidden_states[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=discriminator_hidden_states.hidden_states,
+            attentions=discriminator_hidden_states.attentions,
+        )
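The loss branch in `forward()` only runs when `labels` is passed, and it picks a loss from `config.problem_type`, inferring that value when it is unset. The selection logic can be restated in plain Python as a rough sketch. `infer_problem_type` and `LOSS_BY_PROBLEM_TYPE` are hypothetical names introduced here, and the boolean `labels_are_integer` stands in for the `labels.dtype` check in modeling.py.

```python
from typing import Optional


def infer_problem_type(
    num_labels: int,
    labels_are_integer: bool,
    configured: Optional[str] = None,
) -> str:
    """Mirror the branch in forward() that picks a problem type when none is configured."""
    if configured is not None:
        return configured
    if num_labels == 1:
        return "regression"
    if num_labels > 1 and labels_are_integer:
        return "single_label_classification"
    return "multi_label_classification"


# Loss class names used by each branch in modeling.py.
LOSS_BY_PROBLEM_TYPE = {
    "regression": "MSELoss",
    "single_label_classification": "CrossEntropyLoss",
    "multi_label_classification": "BCEWithLogitsLoss",
}

# config.json in this commit lists a single label in id2label, so num_labels is 1.
problem_type = infer_problem_type(num_labels=1, labels_are_integer=False)
loss_name = LOSS_BY_PROBLEM_TYPE[problem_type]
```

For the single-label configuration in this commit, the sketch lands on the regression branch, which matches the single relevance logit the model emits per (query, passage) pair.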