Instructions to use zeroentropy/zerank-2-reranker with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- sentence-transformers
How to use zeroentropy/zerank-2-reranker with sentence-transformers:
from sentence_transformers import CrossEncoder model = CrossEncoder("zeroentropy/zerank-2-reranker") query = "Which planet is known as the Red Planet?" passages = [ "Venus is often called Earth's twin because of its similar size and proximity.", "Mars, known for its reddish appearance, is often referred to as the Red Planet.", "Jupiter, the largest planet in our solar system, has a prominent red spot.", "Saturn, famous for its rings, is sometimes mistaken for the Red Planet." ] scores = model.predict([(query, passage) for passage in passages]) print(scores) - Notebooks
- Google Colab
- Kaggle
Commit ·
456ffeb
1
Parent(s): 16b4531
Add logging statements throughout model lifecycle
Browse files- Add logger initialization using __name__
- Log model loading process including device and config type
- Log CrossEncoder initialization stages
- Log prediction batching and processing
- Replace print statements with logger for OOM handling
- Log device changes in to_device function
- modeling_zeranker.py +21 -3
modeling_zeranker.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
from sentence_transformers import CrossEncoder as _CE
|
| 2 |
|
| 3 |
import math
|
|
|
|
| 4 |
from typing import cast, Any
|
| 5 |
import types
|
| 6 |
|
|
@@ -23,6 +24,8 @@ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
|
|
| 23 |
# pyright: reportUnknownMemberType=false
|
| 24 |
# pyright: reportUnknownVariableType=false
|
| 25 |
|
|
|
|
|
|
|
| 26 |
MODEL_PATH = "zeroentropy/zerank-2"
|
| 27 |
PER_DEVICE_BATCH_SIZE_TOKENS = 10_000
|
| 28 |
global_device = (
|
|
@@ -74,9 +77,12 @@ def load_model(
|
|
| 74 |
if device is None:
|
| 75 |
device = global_device
|
| 76 |
|
|
|
|
|
|
|
| 77 |
config = AutoConfig.from_pretrained(MODEL_PATH)
|
| 78 |
assert isinstance(config, PretrainedConfig)
|
| 79 |
|
|
|
|
| 80 |
model = AutoModelForCausalLM.from_pretrained(
|
| 81 |
MODEL_PATH,
|
| 82 |
torch_dtype="auto",
|
|
@@ -93,6 +99,7 @@ def load_model(
|
|
| 93 |
| Qwen3ForCausalLM,
|
| 94 |
)
|
| 95 |
|
|
|
|
| 96 |
tokenizer = cast(
|
| 97 |
AutoTokenizer,
|
| 98 |
AutoTokenizer.from_pretrained(
|
|
@@ -105,6 +112,7 @@ def load_model(
|
|
| 105 |
if tokenizer.pad_token is None:
|
| 106 |
tokenizer.pad_token = tokenizer.eos_token
|
| 107 |
|
|
|
|
| 108 |
return tokenizer, model
|
| 109 |
|
| 110 |
|
|
@@ -113,16 +121,19 @@ _original_init = _CE.__init__
|
|
| 113 |
|
| 114 |
|
| 115 |
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
|
|
|
| 116 |
# Call the original CrossEncoder __init__ first
|
| 117 |
_original_init(self, *args, **kwargs)
|
| 118 |
|
| 119 |
# Load the model immediately on instantiation
|
|
|
|
| 120 |
self.inner_tokenizer, self.inner_model = load_model(global_device)
|
| 121 |
self.inner_model.eval()
|
| 122 |
self.inner_model.gradient_checkpointing_disable()
|
| 123 |
self.inner_yes_token_id = self.inner_tokenizer.encode(
|
| 124 |
"Yes", add_special_tokens=False
|
| 125 |
)[0]
|
|
|
|
| 126 |
|
| 127 |
|
| 128 |
def predict(
|
|
@@ -142,6 +153,8 @@ def predict(
|
|
| 142 |
raise ValueError("query_documents or sentences must be provided")
|
| 143 |
query_documents = [[sentence[0], sentence[1]] for sentence in sentences]
|
| 144 |
|
|
|
|
|
|
|
| 145 |
model = self.inner_model
|
| 146 |
tokenizer = self.inner_tokenizer
|
| 147 |
|
|
@@ -170,9 +183,12 @@ def predict(
|
|
| 170 |
batches[-1].append((query, document))
|
| 171 |
max_length = max(max_length, 20 + len(query) + len(document))
|
| 172 |
|
|
|
|
|
|
|
| 173 |
# Inference all of the document batches
|
| 174 |
all_logits: list[float] = []
|
| 175 |
-
for batch in batches:
|
|
|
|
| 176 |
batch_inputs = format_pointwise_datapoints(
|
| 177 |
tokenizer,
|
| 178 |
batch,
|
|
@@ -184,9 +200,9 @@ def predict(
|
|
| 184 |
with torch.inference_mode():
|
| 185 |
outputs = model(**batch_inputs, use_cache=False)
|
| 186 |
except torch.OutOfMemoryError:
|
| 187 |
-
|
| 188 |
torch.cuda.empty_cache()
|
| 189 |
-
|
| 190 |
outputs = model(**batch_inputs, use_cache=False)
|
| 191 |
|
| 192 |
# Extract the logits
|
|
@@ -209,11 +225,13 @@ def predict(
|
|
| 209 |
# Unsort by indices
|
| 210 |
scores = [score for _, score in sorted(zip(permutation, scores, strict=True))]
|
| 211 |
|
|
|
|
| 212 |
return scores
|
| 213 |
|
| 214 |
|
| 215 |
def to_device(self: _CE, new_device: torch.device) -> None:
|
| 216 |
global global_device
|
|
|
|
| 217 |
global_device = new_device
|
| 218 |
|
| 219 |
|
|
|
|
| 1 |
from sentence_transformers import CrossEncoder as _CE
|
| 2 |
|
| 3 |
import math
|
| 4 |
+
import logging
|
| 5 |
from typing import cast, Any
|
| 6 |
import types
|
| 7 |
|
|
|
|
| 24 |
# pyright: reportUnknownMemberType=false
|
| 25 |
# pyright: reportUnknownVariableType=false
|
| 26 |
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
MODEL_PATH = "zeroentropy/zerank-2"
|
| 30 |
PER_DEVICE_BATCH_SIZE_TOKENS = 10_000
|
| 31 |
global_device = (
|
|
|
|
| 77 |
if device is None:
|
| 78 |
device = global_device
|
| 79 |
|
| 80 |
+
logger.info(f"Loading model from {MODEL_PATH} on device: {device}")
|
| 81 |
+
|
| 82 |
config = AutoConfig.from_pretrained(MODEL_PATH)
|
| 83 |
assert isinstance(config, PretrainedConfig)
|
| 84 |
|
| 85 |
+
logger.info(f"Loading model with config type: {config.model_type}")
|
| 86 |
model = AutoModelForCausalLM.from_pretrained(
|
| 87 |
MODEL_PATH,
|
| 88 |
torch_dtype="auto",
|
|
|
|
| 99 |
| Qwen3ForCausalLM,
|
| 100 |
)
|
| 101 |
|
| 102 |
+
logger.info("Loading tokenizer")
|
| 103 |
tokenizer = cast(
|
| 104 |
AutoTokenizer,
|
| 105 |
AutoTokenizer.from_pretrained(
|
|
|
|
| 112 |
if tokenizer.pad_token is None:
|
| 113 |
tokenizer.pad_token = tokenizer.eos_token
|
| 114 |
|
| 115 |
+
logger.info("Model and tokenizer loaded successfully")
|
| 116 |
return tokenizer, model
|
| 117 |
|
| 118 |
|
|
|
|
| 121 |
|
| 122 |
|
| 123 |
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
| 124 |
+
logger.info("Initializing CrossEncoder with eager model loading")
|
| 125 |
# Call the original CrossEncoder __init__ first
|
| 126 |
_original_init(self, *args, **kwargs)
|
| 127 |
|
| 128 |
# Load the model immediately on instantiation
|
| 129 |
+
logger.info("Loading model on instantiation (no lazy loading)")
|
| 130 |
self.inner_tokenizer, self.inner_model = load_model(global_device)
|
| 131 |
self.inner_model.eval()
|
| 132 |
self.inner_model.gradient_checkpointing_disable()
|
| 133 |
self.inner_yes_token_id = self.inner_tokenizer.encode(
|
| 134 |
"Yes", add_special_tokens=False
|
| 135 |
)[0]
|
| 136 |
+
logger.info(f"CrossEncoder initialization complete. Yes token ID: {self.inner_yes_token_id}")
|
| 137 |
|
| 138 |
|
| 139 |
def predict(
|
|
|
|
| 153 |
raise ValueError("query_documents or sentences must be provided")
|
| 154 |
query_documents = [[sentence[0], sentence[1]] for sentence in sentences]
|
| 155 |
|
| 156 |
+
logger.info(f"Starting prediction for {len(query_documents)} query-document pairs")
|
| 157 |
+
|
| 158 |
model = self.inner_model
|
| 159 |
tokenizer = self.inner_tokenizer
|
| 160 |
|
|
|
|
| 183 |
batches[-1].append((query, document))
|
| 184 |
max_length = max(max_length, 20 + len(query) + len(document))
|
| 185 |
|
| 186 |
+
logger.info(f"Created {len(batches)} batches for inference")
|
| 187 |
+
|
| 188 |
# Inference all of the document batches
|
| 189 |
all_logits: list[float] = []
|
| 190 |
+
for batch_idx, batch in enumerate(batches):
|
| 191 |
+
logger.debug(f"Processing batch {batch_idx + 1}/{len(batches)} with {len(batch)} pairs")
|
| 192 |
batch_inputs = format_pointwise_datapoints(
|
| 193 |
tokenizer,
|
| 194 |
batch,
|
|
|
|
| 200 |
with torch.inference_mode():
|
| 201 |
outputs = model(**batch_inputs, use_cache=False)
|
| 202 |
except torch.OutOfMemoryError:
|
| 203 |
+
logger.warning(f"GPU OOM! Memory reserved: {torch.cuda.memory_reserved()}")
|
| 204 |
torch.cuda.empty_cache()
|
| 205 |
+
logger.info(f"GPU cache cleared. Memory reserved: {torch.cuda.memory_reserved()}")
|
| 206 |
outputs = model(**batch_inputs, use_cache=False)
|
| 207 |
|
| 208 |
# Extract the logits
|
|
|
|
| 225 |
# Unsort by indices
|
| 226 |
scores = [score for _, score in sorted(zip(permutation, scores, strict=True))]
|
| 227 |
|
| 228 |
+
logger.info(f"Prediction complete. Generated {len(scores)} scores")
|
| 229 |
return scores
|
| 230 |
|
| 231 |
|
| 232 |
def to_device(self: _CE, new_device: torch.device) -> None:
|
| 233 |
global global_device
|
| 234 |
+
logger.info(f"Changing device from {global_device} to {new_device}")
|
| 235 |
global_device = new_device
|
| 236 |
|
| 237 |
|