Text Ranking
sentence-transformers
Safetensors
English
qwen3
finance
legal
code
stem
medical
custom_code
Instructions to use zeroentropy/zerank-1-reranker with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- sentence-transformers
How to use zeroentropy/zerank-1-reranker with sentence-transformers:
```python
from sentence_transformers import CrossEncoder

model = CrossEncoder("zeroentropy/zerank-1-reranker", trust_remote_code=True)

query = "Which planet is known as the Red Planet?"
passages = [
    "Venus is often called Earth's twin because of its similar size and proximity.",
    "Mars, known for its reddish appearance, is often referred to as the Red Planet.",
    "Jupiter, the largest planet in our solar system, has a prominent red spot.",
    "Saturn, famous for its rings, is sometimes mistaken for the Red Planet."
]

scores = model.predict([(query, passage) for passage in passages])
print(scores)
```
- Notebooks
- Google Colab
- Kaggle
Update modeling_zeranker.py
Browse files
- modeling_zeranker.py +39 -13
modeling_zeranker.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
from sentence_transformers import CrossEncoder as _CE
|
| 2 |
|
| 3 |
import math
|
| 4 |
-
from typing import cast
|
| 5 |
import types
|
| 6 |
|
| 7 |
import torch
|
|
@@ -21,8 +21,11 @@ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
|
|
| 21 |
# pyright: reportUnknownMemberType=false
|
| 22 |
# pyright: reportUnknownVariableType=false
|
| 23 |
|
| 24 |
-
MODEL_PATH = "zeroentropy/
|
| 25 |
PER_DEVICE_BATCH_SIZE_TOKENS = 15_000
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
|
| 28 |
def format_pointwise_datapoints(
|
|
@@ -67,7 +70,7 @@ def load_model(
|
|
| 67 |
| Qwen3ForCausalLM,
|
| 68 |
]:
|
| 69 |
if device is None:
|
| 70 |
-
device =
|
| 71 |
|
| 72 |
config = AutoConfig.from_pretrained(MODEL_PATH)
|
| 73 |
assert isinstance(config, PretrainedConfig)
|
|
@@ -80,7 +83,6 @@ def load_model(
|
|
| 80 |
)
|
| 81 |
if config.model_type == "llama":
|
| 82 |
model.config.attn_implementation = "flash_attention_2"
|
| 83 |
-
print(f"Model Type: {config.model_type}")
|
| 84 |
assert isinstance(
|
| 85 |
model,
|
| 86 |
LlamaForCausalLM
|
|
@@ -104,13 +106,30 @@ def load_model(
|
|
| 104 |
return tokenizer, model
|
| 105 |
|
| 106 |
|
| 107 |
-
def predict(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
if not hasattr(self, "inner_model"):
|
| 109 |
-
self.inner_tokenizer, self.inner_model = load_model(
|
| 110 |
self.inner_model.gradient_checkpointing_enable()
|
| 111 |
self.inner_model.eval()
|
| 112 |
-
self.inner_yes_token_id = self.inner_tokenizer.encode(
|
| 113 |
-
|
|
|
|
| 114 |
|
| 115 |
model = self.inner_model
|
| 116 |
tokenizer = self.inner_tokenizer
|
|
@@ -120,11 +139,11 @@ def predict(self, query_documents: list[tuple[str, str]]) -> list[float]:
|
|
| 120 |
]
|
| 121 |
# Sort
|
| 122 |
permutation = list(range(len(query_documents)))
|
| 123 |
-
permutation.sort(
|
|
|
|
|
|
|
| 124 |
query_documents = [query_documents[i] for i in permutation]
|
| 125 |
|
| 126 |
-
device = torch.device("cuda")
|
| 127 |
-
|
| 128 |
# Extract document batches from this line of datapoints
|
| 129 |
max_length = 0
|
| 130 |
batches: list[list[tuple[str, str]]] = []
|
|
@@ -148,7 +167,7 @@ def predict(self, query_documents: list[tuple[str, str]]) -> list[float]:
|
|
| 148 |
batch,
|
| 149 |
)
|
| 150 |
|
| 151 |
-
batch_inputs = batch_inputs.to(
|
| 152 |
|
| 153 |
try:
|
| 154 |
outputs = model(**batch_inputs, use_cache=False)
|
|
@@ -164,7 +183,7 @@ def predict(self, query_documents: list[tuple[str, str]]) -> list[float]:
|
|
| 164 |
last_positions = attention_mask.sum(dim=1) - 1
|
| 165 |
|
| 166 |
batch_size = logits.shape[0]
|
| 167 |
-
batch_indices = torch.arange(batch_size, device=
|
| 168 |
last_logits = logits[batch_indices, last_positions]
|
| 169 |
|
| 170 |
yes_logits = last_logits[:, self.inner_yes_token_id]
|
|
@@ -181,8 +200,15 @@ def predict(self, query_documents: list[tuple[str, str]]) -> list[float]:
|
|
| 181 |
return scores
|
| 182 |
|
| 183 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
_CE.predict = predict
|
| 185 |
|
| 186 |
from transformers import Qwen3Config
|
| 187 |
|
| 188 |
ZEConfig = Qwen3Config
|
|
|
|
|
|
|
|
|
| 1 |
from sentence_transformers import CrossEncoder as _CE
|
| 2 |
|
| 3 |
import math
|
| 4 |
+
from typing import cast, Any
|
| 5 |
import types
|
| 6 |
|
| 7 |
import torch
|
|
|
|
| 21 |
# pyright: reportUnknownMemberType=false
|
| 22 |
# pyright: reportUnknownVariableType=false
|
| 23 |
|
| 24 |
+
MODEL_PATH = "zeroentropy/zerank-1"
|
| 25 |
PER_DEVICE_BATCH_SIZE_TOKENS = 15_000
|
| 26 |
+
global_device = (
|
| 27 |
+
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
| 28 |
+
)
|
| 29 |
|
| 30 |
|
| 31 |
def format_pointwise_datapoints(
|
|
|
|
| 70 |
| Qwen3ForCausalLM,
|
| 71 |
]:
|
| 72 |
if device is None:
|
| 73 |
+
device = global_device
|
| 74 |
|
| 75 |
config = AutoConfig.from_pretrained(MODEL_PATH)
|
| 76 |
assert isinstance(config, PretrainedConfig)
|
|
|
|
| 83 |
)
|
| 84 |
if config.model_type == "llama":
|
| 85 |
model.config.attn_implementation = "flash_attention_2"
|
|
|
|
| 86 |
assert isinstance(
|
| 87 |
model,
|
| 88 |
LlamaForCausalLM
|
|
|
|
| 106 |
return tokenizer, model
|
| 107 |
|
| 108 |
|
| 109 |
+
def predict(
|
| 110 |
+
self,
|
| 111 |
+
query_documents: list[tuple[str, str]] | None = None,
|
| 112 |
+
*,
|
| 113 |
+
sentences: Any = None,
|
| 114 |
+
batch_size: Any = None,
|
| 115 |
+
show_progress_bar: Any = None,
|
| 116 |
+
activation_fn: Any = None,
|
| 117 |
+
apply_softmax: Any = None,
|
| 118 |
+
convert_to_numpy: Any = None,
|
| 119 |
+
convert_to_tensor: Any = None,
|
| 120 |
+
) -> list[float]:
|
| 121 |
+
if query_documents is None:
|
| 122 |
+
if sentences is None:
|
| 123 |
+
raise ValueError("query_documents or sentences must be provided")
|
| 124 |
+
query_documents = [[sentence[0], sentence[1]] for sentence in sentences]
|
| 125 |
+
|
| 126 |
if not hasattr(self, "inner_model"):
|
| 127 |
+
self.inner_tokenizer, self.inner_model = load_model(global_device)
|
| 128 |
self.inner_model.gradient_checkpointing_enable()
|
| 129 |
self.inner_model.eval()
|
| 130 |
+
self.inner_yes_token_id = self.inner_tokenizer.encode(
|
| 131 |
+
"Yes", add_special_tokens=False
|
| 132 |
+
)[0]
|
| 133 |
|
| 134 |
model = self.inner_model
|
| 135 |
tokenizer = self.inner_tokenizer
|
|
|
|
| 139 |
]
|
| 140 |
# Sort
|
| 141 |
permutation = list(range(len(query_documents)))
|
| 142 |
+
permutation.sort(
|
| 143 |
+
key=lambda i: -len(query_documents[i][0]) - len(query_documents[i][1])
|
| 144 |
+
)
|
| 145 |
query_documents = [query_documents[i] for i in permutation]
|
| 146 |
|
|
|
|
|
|
|
| 147 |
# Extract document batches from this line of datapoints
|
| 148 |
max_length = 0
|
| 149 |
batches: list[list[tuple[str, str]]] = []
|
|
|
|
| 167 |
batch,
|
| 168 |
)
|
| 169 |
|
| 170 |
+
batch_inputs = batch_inputs.to(global_device)
|
| 171 |
|
| 172 |
try:
|
| 173 |
outputs = model(**batch_inputs, use_cache=False)
|
|
|
|
| 183 |
last_positions = attention_mask.sum(dim=1) - 1
|
| 184 |
|
| 185 |
batch_size = logits.shape[0]
|
| 186 |
+
batch_indices = torch.arange(batch_size, device=global_device)
|
| 187 |
last_logits = logits[batch_indices, last_positions]
|
| 188 |
|
| 189 |
yes_logits = last_logits[:, self.inner_yes_token_id]
|
|
|
|
| 200 |
return scores
|
| 201 |
|
| 202 |
|
| 203 |
+
def to_device(self: _CE, new_device: torch.device) -> None:
|
| 204 |
+
global global_device
|
| 205 |
+
global_device = new_device
|
| 206 |
+
|
| 207 |
+
|
| 208 |
_CE.predict = predict
|
| 209 |
|
| 210 |
from transformers import Qwen3Config
|
| 211 |
|
| 212 |
ZEConfig = Qwen3Config
|
| 213 |
+
|
| 214 |
+
_CE.to = to_device
|