Instructions to use zeroentropy/zerank-2-reranker with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- sentence-transformers
How to use zeroentropy/zerank-2-reranker with sentence-transformers:
```python
from sentence_transformers import CrossEncoder

model = CrossEncoder("zeroentropy/zerank-2-reranker")

query = "Which planet is known as the Red Planet?"
passages = [
    "Venus is often called Earth's twin because of its similar size and proximity.",
    "Mars, known for its reddish appearance, is often referred to as the Red Planet.",
    "Jupiter, the largest planet in our solar system, has a prominent red spot.",
    "Saturn, famous for its rings, is sometimes mistaken for the Red Planet."
]

scores = model.predict([(query, passage) for passage in passages])
print(scores)
```
- Notebooks
- Google Colab
- Kaggle
Tom Aarsen committed on
Commit ·
d713204
1
Parent(s): 9ae8623
Integrate with transformers, sentence transformers
Browse files
- README.md +1 -1
- config.json +7 -1
- modeling_zeranker.py +128 -206
- tokenizer_config.json +4 -1
README.md
CHANGED
|
@@ -41,8 +41,8 @@ query_documents = [
|
|
| 41 |
]
|
| 42 |
|
| 43 |
scores = model.predict(query_documents)
|
| 44 |
-
|
| 45 |
print(scores)
|
|
|
|
| 46 |
```
|
| 47 |
|
| 48 |
The model can also be inferenced using ZeroEntropy's [/models/rerank](https://docs.zeroentropy.dev/api-reference/models/rerank) endpoint, and on [AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-o7avk66msiukc).
|
|
|
|
| 41 |
]
|
| 42 |
|
| 43 |
scores = model.predict(query_documents)
|
|
|
|
| 44 |
print(scores)
|
| 45 |
+
# [0.7531883 0.28894895]
|
| 46 |
```
|
| 47 |
|
| 48 |
The model can also be inferenced using ZeroEntropy's [/models/rerank](https://docs.zeroentropy.dev/api-reference/models/rerank) endpoint, and on [AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-o7avk66msiukc).
|
config.json
CHANGED
|
@@ -1,9 +1,13 @@
|
|
| 1 |
{
|
| 2 |
"architectures": [
|
| 3 |
-
"
|
| 4 |
],
|
| 5 |
"attention_bias": false,
|
| 6 |
"attention_dropout": 0.0,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
"bos_token_id": 151643,
|
| 8 |
"dtype": "bfloat16",
|
| 9 |
"eos_token_id": 151645,
|
|
@@ -56,6 +60,8 @@
|
|
| 56 |
"num_attention_heads": 32,
|
| 57 |
"num_hidden_layers": 36,
|
| 58 |
"num_key_value_heads": 8,
|
|
|
|
|
|
|
| 59 |
"rms_norm_eps": 1e-06,
|
| 60 |
"rope_scaling": null,
|
| 61 |
"rope_theta": 1000000,
|
|
|
|
| 1 |
{
|
| 2 |
"architectures": [
|
| 3 |
+
"ZeroEntropyForSequenceClassification"
|
| 4 |
],
|
| 5 |
"attention_bias": false,
|
| 6 |
"attention_dropout": 0.0,
|
| 7 |
+
"auto_map": {
|
| 8 |
+
"AutoConfig": "modeling_zeranker.ZeroEntropyConfig",
|
| 9 |
+
"AutoModelForSequenceClassification": "modeling_zeranker.ZeroEntropyForSequenceClassification"
|
| 10 |
+
},
|
| 11 |
"bos_token_id": 151643,
|
| 12 |
"dtype": "bfloat16",
|
| 13 |
"eos_token_id": 151645,
|
|
|
|
| 60 |
"num_attention_heads": 32,
|
| 61 |
"num_hidden_layers": 36,
|
| 62 |
"num_key_value_heads": 8,
|
| 63 |
+
"num_labels": 1,
|
| 64 |
+
"pad_token_id": 151643,
|
| 65 |
"rms_norm_eps": 1e-06,
|
| 66 |
"rope_scaling": null,
|
| 67 |
"rope_theta": 1000000,
|
modeling_zeranker.py
CHANGED
|
@@ -1,216 +1,138 @@
|
|
| 1 |
-
from
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
| 6 |
|
|
|
|
| 7 |
|
|
|
|
| 8 |
import torch
|
| 9 |
-
from transformers
|
| 10 |
-
|
| 11 |
-
from transformers.models.auto.configuration_auto import AutoConfig
|
| 12 |
-
from transformers.models.auto.modeling_auto import AutoModelForCausalLM
|
| 13 |
-
from transformers.models.auto.tokenization_auto import AutoTokenizer
|
| 14 |
-
from transformers.models.gemma3.modeling_gemma3 import (
|
| 15 |
-
Gemma3ForCausalLM,
|
| 16 |
-
Gemma3ForConditionalGeneration,
|
| 17 |
-
)
|
| 18 |
-
from transformers.models.llama.modeling_llama import LlamaForCausalLM
|
| 19 |
-
from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM
|
| 20 |
-
from transformers.tokenization_utils_base import BatchEncoding
|
| 21 |
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
|
| 22 |
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
)
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
)
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
AutoTokenizer,
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
apply_softmax: Any = None,
|
| 120 |
-
convert_to_numpy: Any = None,
|
| 121 |
-
convert_to_tensor: Any = None,
|
| 122 |
-
) -> list[float]:
|
| 123 |
-
if query_documents is None:
|
| 124 |
-
if sentences is None:
|
| 125 |
-
raise ValueError("query_documents or sentences must be provided")
|
| 126 |
-
query_documents = [[sentence[0], sentence[1]] for sentence in sentences]
|
| 127 |
-
|
| 128 |
-
if not hasattr(self, "inner_model"):
|
| 129 |
-
self.inner_tokenizer, self.inner_model = load_model(global_device)
|
| 130 |
-
self.inner_model.gradient_checkpointing_enable()
|
| 131 |
-
self.inner_model.eval()
|
| 132 |
-
self.inner_yes_token_id = self.inner_tokenizer.encode(
|
| 133 |
-
"Yes", add_special_tokens=False
|
| 134 |
-
)[0]
|
| 135 |
-
|
| 136 |
-
model = self.inner_model
|
| 137 |
-
tokenizer = self.inner_tokenizer
|
| 138 |
-
|
| 139 |
-
query_documents = [
|
| 140 |
-
(query[:2_000], document[:10_000]) for query, document in query_documents
|
| 141 |
-
]
|
| 142 |
-
# Sort
|
| 143 |
-
permutation = list(range(len(query_documents)))
|
| 144 |
-
permutation.sort(
|
| 145 |
-
key=lambda i: -len(query_documents[i][0]) - len(query_documents[i][1])
|
| 146 |
-
)
|
| 147 |
-
query_documents = [query_documents[i] for i in permutation]
|
| 148 |
-
|
| 149 |
-
# Extract document batches from this line of datapoints
|
| 150 |
-
max_length = 0
|
| 151 |
-
batches: list[list[tuple[str, str]]] = []
|
| 152 |
-
for query, document in query_documents:
|
| 153 |
-
if (
|
| 154 |
-
len(batches) == 0
|
| 155 |
-
or (len(batches[-1]) + 1) * max(max_length, len(query) + len(document))
|
| 156 |
-
> PER_DEVICE_BATCH_SIZE_TOKENS
|
| 157 |
-
):
|
| 158 |
-
batches.append([])
|
| 159 |
-
max_length = 0
|
| 160 |
-
|
| 161 |
-
batches[-1].append((query, document))
|
| 162 |
-
max_length = max(max_length, 20 + len(query) + len(document))
|
| 163 |
-
|
| 164 |
-
# Inference all of the document batches
|
| 165 |
-
all_logits: list[float] = []
|
| 166 |
-
for batch in batches:
|
| 167 |
-
batch_inputs = format_pointwise_datapoints(
|
| 168 |
-
tokenizer,
|
| 169 |
-
batch,
|
| 170 |
)
|
| 171 |
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
outputs = model(**batch_inputs, use_cache=False)
|
| 181 |
|
| 182 |
-
# Extract the logits
|
| 183 |
-
logits = cast(torch.Tensor, outputs.logits)
|
| 184 |
-
attention_mask = cast(torch.Tensor, batch_inputs.attention_mask)
|
| 185 |
last_positions = attention_mask.sum(dim=1) - 1
|
| 186 |
-
|
| 187 |
batch_size = logits.shape[0]
|
| 188 |
-
batch_indices = torch.arange(batch_size, device=
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
yes_logits =
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
scores = [score for _, score in sorted(zip(permutation, scores, strict=True))]
|
| 201 |
-
|
| 202 |
-
return scores
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
def to_device(self: _CE, new_device: torch.device) -> None:
|
| 206 |
-
global global_device
|
| 207 |
-
global_device = new_device
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
_CE.predict = predict
|
| 211 |
-
|
| 212 |
-
from transformers import Qwen3Config
|
| 213 |
-
|
| 214 |
-
ZEConfig = Qwen3Config
|
| 215 |
-
|
| 216 |
-
_CE.to = to_device
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
from transformers.modeling_outputs import (
|
| 3 |
+
BaseModelOutputWithPast,
|
| 4 |
+
CausalLMOutputWithPast,
|
| 5 |
+
SequenceClassifierOutputWithPast,
|
| 6 |
+
)
|
| 7 |
+
from transformers.utils import auto_docstring
|
| 8 |
+
from transformers.utils.generic import TransformersKwargs, can_return_tuple
|
| 9 |
|
| 10 |
+
from typing import Optional, Union
|
| 11 |
|
| 12 |
+
from transformers.processing_utils import Unpack
|
| 13 |
import torch
|
| 14 |
+
from transformers import Cache, Qwen3Config
|
| 15 |
+
from transformers.models.qwen3.modeling_qwen3 import Qwen3PreTrainedModel, Qwen3Model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
|
| 17 |
|
| 18 |
+
from transformers.utils import logging
|
| 19 |
+
|
| 20 |
+
logger = logging.get_logger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class ZeroEntropyTokenizer(PreTrainedTokenizerFast):
|
| 24 |
+
def __init__(self, **kwargs):
|
| 25 |
+
super().__init__(**kwargs)
|
| 26 |
+
|
| 27 |
+
def __call__(self, pairs, *args, **kwargs):
|
| 28 |
+
input_texts: list[str] = []
|
| 29 |
+
for query, document in pairs:
|
| 30 |
+
messages = [
|
| 31 |
+
{"role": "system", "content": query.strip()},
|
| 32 |
+
{"role": "user", "content": document.strip()},
|
| 33 |
+
]
|
| 34 |
+
input_text = self.apply_chat_template(
|
| 35 |
+
messages, tokenize=False, add_generation_prompt=True
|
| 36 |
+
)
|
| 37 |
+
assert isinstance(input_text, str)
|
| 38 |
+
input_texts.append(input_text)
|
| 39 |
+
|
| 40 |
+
batch_inputs = super().__call__(input_texts, *args, **kwargs)
|
| 41 |
+
return batch_inputs
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class ZeroEntropyConfig(Qwen3Config):
|
| 45 |
+
model_type = "zeroentropy"
|
| 46 |
+
|
| 47 |
+
def __init__(self, yes_token_id: int = 9454, **kwargs):
|
| 48 |
+
super().__init__(**kwargs)
|
| 49 |
+
self.yes_token_id = yes_token_id
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class ZeroEntropyForSequenceClassification(Qwen3PreTrainedModel):
|
| 53 |
+
config: ZeroEntropyConfig
|
| 54 |
+
|
| 55 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 56 |
+
_tp_plan = {"lm_head": "colwise_rep"}
|
| 57 |
+
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
| 58 |
+
|
| 59 |
+
def __init__(self, config):
|
| 60 |
+
super().__init__(config)
|
| 61 |
+
self.model = Qwen3Model(config)
|
| 62 |
+
self.vocab_size = config.vocab_size
|
| 63 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 64 |
+
|
| 65 |
+
# Initialize weights and apply final processing
|
| 66 |
+
self.post_init()
|
| 67 |
+
|
| 68 |
+
@can_return_tuple
|
| 69 |
+
@auto_docstring
|
| 70 |
+
def forward(
|
| 71 |
+
self,
|
| 72 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 73 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 74 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 75 |
+
past_key_values: Optional[Cache] = None,
|
| 76 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 77 |
+
labels: Optional[torch.LongTensor] = None,
|
| 78 |
+
use_cache: Optional[bool] = None,
|
| 79 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 80 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 81 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 82 |
+
) -> CausalLMOutputWithPast:
|
| 83 |
+
r"""
|
| 84 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 85 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 86 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
| 87 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
| 88 |
+
|
| 89 |
+
Example:
|
| 90 |
+
|
| 91 |
+
```python
|
| 92 |
+
>>> from transformers import AutoTokenizer, Qwen3ForCausalLM
|
| 93 |
+
|
| 94 |
+
>>> model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B")
|
| 95 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
|
| 96 |
+
|
| 97 |
+
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
| 98 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
| 99 |
+
|
| 100 |
+
>>> # Generate
|
| 101 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
| 102 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 103 |
+
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
| 104 |
+
```"""
|
| 105 |
+
outputs: BaseModelOutputWithPast = self.model(
|
| 106 |
+
input_ids=input_ids,
|
| 107 |
+
attention_mask=attention_mask,
|
| 108 |
+
position_ids=position_ids,
|
| 109 |
+
past_key_values=past_key_values,
|
| 110 |
+
inputs_embeds=inputs_embeds,
|
| 111 |
+
use_cache=use_cache,
|
| 112 |
+
cache_position=cache_position,
|
| 113 |
+
**kwargs,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
)
|
| 115 |
|
| 116 |
+
hidden_states = outputs.last_hidden_state
|
| 117 |
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 118 |
+
slice_indices = (
|
| 119 |
+
slice(-logits_to_keep, None)
|
| 120 |
+
if isinstance(logits_to_keep, int)
|
| 121 |
+
else logits_to_keep
|
| 122 |
+
)
|
| 123 |
+
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
|
|
|
| 124 |
|
|
|
|
|
|
|
|
|
|
| 125 |
last_positions = attention_mask.sum(dim=1) - 1
|
|
|
|
| 126 |
batch_size = logits.shape[0]
|
| 127 |
+
batch_indices = torch.arange(batch_size, device=logits.device)
|
| 128 |
+
yes_logits = logits[batch_indices, last_positions, self.config.yes_token_id]
|
| 129 |
+
yes_logits = yes_logits / 5.0
|
| 130 |
+
yes_logits = yes_logits.unsqueeze(-1)
|
| 131 |
+
|
| 132 |
+
return SequenceClassifierOutputWithPast(
|
| 133 |
+
loss=None,
|
| 134 |
+
logits=yes_logits,
|
| 135 |
+
past_key_values=outputs.past_key_values,
|
| 136 |
+
hidden_states=outputs.hidden_states,
|
| 137 |
+
attentions=outputs.attentions,
|
| 138 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tokenizer_config.json
CHANGED
|
@@ -226,6 +226,9 @@
|
|
| 226 |
"<|image_pad|>",
|
| 227 |
"<|video_pad|>"
|
| 228 |
],
|
|
|
|
|
|
|
|
|
|
| 229 |
"bos_token": null,
|
| 230 |
"clean_up_tokenization_spaces": false,
|
| 231 |
"eos_token": "<|im_end|>",
|
|
@@ -235,6 +238,6 @@
|
|
| 235 |
"pad_token": "<|endoftext|>",
|
| 236 |
"padding_side": "right",
|
| 237 |
"split_special_tokens": false,
|
| 238 |
-
"tokenizer_class": "
|
| 239 |
"unk_token": null
|
| 240 |
}
|
|
|
|
| 226 |
"<|image_pad|>",
|
| 227 |
"<|video_pad|>"
|
| 228 |
],
|
| 229 |
+
"auto_map": {
|
| 230 |
+
"AutoTokenizer": [null, "modeling_zeranker.ZeroEntropyTokenizer"]
|
| 231 |
+
},
|
| 232 |
"bos_token": null,
|
| 233 |
"clean_up_tokenization_spaces": false,
|
| 234 |
"eos_token": "<|im_end|>",
|
|
|
|
| 238 |
"pad_token": "<|endoftext|>",
|
| 239 |
"padding_side": "right",
|
| 240 |
"split_special_tokens": false,
|
| 241 |
+
"tokenizer_class": "ZeroEntropyTokenizer",
|
| 242 |
"unk_token": null
|
| 243 |
}
|