Commit 0ac03c0
1 Parent(s): cd705e3

Remove lazy loading and load model on instantiation

Load the model immediately when the CrossEncoder class is instantiated
instead of waiting for the first predict() call.

Files changed:
- modeling_zeranker.py (+12 -10)

modeling_zeranker.py CHANGED
@@ -108,6 +108,16 @@ def load_model(
     return tokenizer, model
 
 
+def __init__(self, *args: Any, **kwargs: Any) -> None:
+    # Load the model immediately on instantiation
+    self.inner_tokenizer, self.inner_model = load_model(global_device)
+    self.inner_model.eval()
+    self.inner_model.gradient_checkpointing_disable()
+    self.inner_yes_token_id = self.inner_tokenizer.encode(
+        "Yes", add_special_tokens=False
+    )[0]
+
+
 def predict(
     self,
     query_documents: list[tuple[str, str]] | None = None,
@@ -125,14 +135,6 @@ def predict(
         raise ValueError("query_documents or sentences must be provided")
     query_documents = [[sentence[0], sentence[1]] for sentence in sentences]
 
-    if not hasattr(self, "inner_model"):
-        self.inner_tokenizer, self.inner_model = load_model(global_device)
-        self.inner_model.eval()
-        self.inner_model.gradient_checkpointing_disable()
-        self.inner_yes_token_id = self.inner_tokenizer.encode(
-            "Yes", add_special_tokens=False
-        )[0]
-
     model = self.inner_model
     tokenizer = self.inner_tokenizer
 
@@ -208,10 +210,10 @@ def to_device(self: _CE, new_device: torch.device) -> None:
     global_device = new_device
 
 
+_CE.__init__ = __init__
 _CE.predict = predict
+_CE.to = to_device
 
 from transformers import Qwen3Config
 
 ZEConfig = Qwen3Config
-
-_CE.to = to_device
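For readers skimming the diff, here is a minimal, self-contained sketch of the pattern it relies on: module-level functions monkey-patched onto a class, with the expensive model load moved from the first predict() call into __init__. The class stub, the load_model() body, and the return values below are illustrative stand-ins, not the repo's real implementations; only the names (_CE, load_model, global_device, inner_model, inner_tokenizer) mirror the diff.

from typing import Any

global_device = "cpu"  # the real module tracks the target torch device in a global


def load_model(device: str) -> tuple[Any, Any]:
    # Stand-in for the repo's load_model(); pretend this is the expensive part.
    print(f"loading model on {device} ...")
    return object(), object()  # (tokenizer, model) placeholders


class _CE:
    """Stub for the CrossEncoder class whose methods are replaced below."""


def __init__(self: _CE, *args: Any, **kwargs: Any) -> None:
    # Eager load: the cost is paid once, at construction time.
    self.inner_tokenizer, self.inner_model = load_model(global_device)


def predict(self: _CE, query_documents: list[tuple[str, str]]) -> list[float]:
    # No hasattr() guard needed: __init__ guarantees the attributes exist.
    _model, _tokenizer = self.inner_model, self.inner_tokenizer
    return [0.0 for _ in query_documents]  # real scoring elided


# Install the replacements, as the patched module does after defining them.
_CE.__init__ = __init__
_CE.predict = predict

ce = _CE()                      # "loading model on cpu ..." prints here,
ce.predict([("query", "doc")])  # not here.

A practical effect of eager loading, and a common reason for a change like this, is fail-fast behavior: a missing checkpoint or an out-of-memory error now surfaces when the CrossEncoder is constructed rather than on the first query, and the first predict() call no longer carries a hidden startup cost.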