Instructions to use zeroentropy/zerank-2-reranker with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
  - sentence-transformers
- Notebooks
  - Google Colab
  - Kaggle

How to use zeroentropy/zerank-2-reranker with sentence-transformers:

```python
from sentence_transformers import CrossEncoder

model = CrossEncoder("zeroentropy/zerank-2-reranker")

query = "Which planet is known as the Red Planet?"
passages = [
    "Venus is often called Earth's twin because of its similar size and proximity.",
    "Mars, known for its reddish appearance, is often referred to as the Red Planet.",
    "Jupiter, the largest planet in our solar system, has a prominent red spot.",
    "Saturn, famous for its rings, is sometimes mistaken for the Red Planet."
]

scores = model.predict([(query, passage) for passage in passages])
print(scores)
```
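`predict` returns one relevance score per (query, passage) pair, so ranking is just a sort over those scores. A minimal follow-up sketch reusing `scores` and `passages` from the snippet above:

```python
# Sort passages best-first by their reranker score.
ranked = sorted(zip(scores, passages), key=lambda pair: pair[0], reverse=True)
for score, passage in ranked:
    print(f"{score:.4f}  {passage}")
```

Recent sentence-transformers versions also expose `model.rank(query, passages)`, which scores and sorts in one call.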
wenjin_lee committed · Commit d9aaf5e · Parent(s): 9ae8623

Use torch.inference_mode() and disable gradient checkpointing

Browse files:
- config.json +4 -1
- modeling_zeranker.py +11 -3
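For context on the commit message: `torch.inference_mode()` is a stricter, cheaper variant of `torch.no_grad()` (no autograd tracking at all), and gradient checkpointing trades compute for memory in a way that only pays off during training, so disabling it for inference avoids needless activation recomputation. A minimal standalone sketch of the combination (the checkpoint name is illustrative, not this repo's loading code):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; the pattern is the same for any transformers model.
name = "Qwen/Qwen3-0.6B"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)

model.eval()                            # inference behavior for dropout etc.
model.gradient_checkpointing_disable()  # checkpointing only helps training

inputs = tokenizer("Yes or No?", return_tensors="pt")
with torch.inference_mode():            # no autograd graph, lower memory
    outputs = model(**inputs, use_cache=False)
print(outputs.logits.shape)
```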
config.json CHANGED

```diff
@@ -64,5 +64,8 @@
   "transformers_version": "4.57.1",
   "use_cache": true,
   "use_sliding_window": false,
-  "vocab_size": 151936
+  "vocab_size": 151936,
+  "auto_map": {
+    "AutoConfig": "modeling_zeranker.ZEConfig"
+  }
 }
```
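The `auto_map` entry is what lets the `transformers` Auto classes find a config class that lives inside the repository rather than in the library. A minimal sketch of the load path this enables (assuming the repo id matches the model card above):

```python
from transformers import AutoConfig

# "auto_map": {"AutoConfig": "modeling_zeranker.ZEConfig"} tells AutoConfig to
# import ZEConfig from the repo's modeling_zeranker.py. Because that code ships
# with the repo, trust_remote_code=True is required.
config = AutoConfig.from_pretrained(
    "zeroentropy/zerank-2-reranker",
    trust_remote_code=True,
)
print(type(config).__name__)  # ZEConfig
```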
modeling_zeranker.py CHANGED

```diff
@@ -20,11 +20,16 @@ from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM
 from transformers.tokenization_utils_base import BatchEncoding
 from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
 
+import logging
+
+logger = logging.getLogger(__name__)
+print("Running code of HF Model")
+
 # pyright: reportUnknownMemberType=false
 # pyright: reportUnknownVariableType=false
 
 MODEL_PATH = "zeroentropy/zerank-2"
-PER_DEVICE_BATCH_SIZE_TOKENS =
+PER_DEVICE_BATCH_SIZE_TOKENS = 10_000
 global_device = (
     torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 )
@@ -126,9 +131,11 @@ def predict(
     query_documents = [[sentence[0], sentence[1]] for sentence in sentences]
 
     if not hasattr(self, "inner_model"):
+        logger.info(f"Memory reserved [Within Model File] Before Loading Model: {torch.cuda.memory_reserved()}")
         self.inner_tokenizer, self.inner_model = load_model(global_device)
-
+        logger.info(f"Memory reserved [Within Model File] After Loading Model: {torch.cuda.memory_reserved()}")
         self.inner_model.eval()
+        self.inner_model.gradient_checkpointing_disable()
         self.inner_yes_token_id = self.inner_tokenizer.encode(
             "Yes", add_special_tokens=False
         )[0]
@@ -172,7 +179,8 @@
     batch_inputs = batch_inputs.to(global_device)
 
     try:
-        outputs = model(**batch_inputs, use_cache=False)
+        with torch.inference_mode():
+            outputs = model(**batch_inputs, use_cache=False)
     except torch.OutOfMemoryError:
         print(f"GPU OOM! {torch.cuda.memory_reserved()}")
         torch.cuda.empty_cache()
```
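Two details of the new forward path are worth calling out. `PER_DEVICE_BATCH_SIZE_TOKENS = 10_000` suggests batches are capped by token count rather than example count, so batch size adapts to sequence length. And the `try`/`except torch.OutOfMemoryError` around the forward pass is a CUDA OOM fallback: log reserved memory, empty the cache, and retry on less work. A minimal sketch of that retry pattern (the helper name and the batch-halving policy are illustrative, not this repo's exact logic):

```python
import torch

def forward_with_oom_fallback(model, batch_inputs: dict) -> torch.Tensor:
    """Forward pass that, on CUDA OOM, frees the cache and retries the
    batch in two halves, concatenating the resulting logits."""
    try:
        with torch.inference_mode():
            return model(**batch_inputs, use_cache=False).logits
    except torch.OutOfMemoryError:
        print(f"GPU OOM! {torch.cuda.memory_reserved()}")
        torch.cuda.empty_cache()
        n = batch_inputs["input_ids"].shape[0]
        if n == 1:
            raise  # a single sequence already does not fit; give up
        first = {k: v[: n // 2] for k, v in batch_inputs.items()}
        second = {k: v[n // 2 :] for k, v in batch_inputs.items()}
        return torch.cat(
            [forward_with_oom_fallback(model, first),
             forward_with_oom_fallback(model, second)],
            dim=0,
        )
```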