prathamj31 committed on
Commit
0ac03c0
·
1 Parent(s): cd705e3

Remove lazy loading and load model on instantiation

Browse files

Load the model immediately when the CrossEncoder class is instantiated
instead of waiting for the first predict() call.

Files changed (1) hide show
  1. modeling_zeranker.py +12 -10
modeling_zeranker.py CHANGED
@@ -108,6 +108,16 @@ def load_model(
108
  return tokenizer, model
109
 
110
 
 
 
 
 
 
 
 
 
 
 
111
  def predict(
112
  self,
113
  query_documents: list[tuple[str, str]] | None = None,
@@ -125,14 +135,6 @@ def predict(
125
  raise ValueError("query_documents or sentences must be provided")
126
  query_documents = [[sentence[0], sentence[1]] for sentence in sentences]
127
 
128
- if not hasattr(self, "inner_model"):
129
- self.inner_tokenizer, self.inner_model = load_model(global_device)
130
- self.inner_model.eval()
131
- self.inner_model.gradient_checkpointing_disable()
132
- self.inner_yes_token_id = self.inner_tokenizer.encode(
133
- "Yes", add_special_tokens=False
134
- )[0]
135
-
136
  model = self.inner_model
137
  tokenizer = self.inner_tokenizer
138
 
@@ -208,10 +210,10 @@ def to_device(self: _CE, new_device: torch.device) -> None:
208
  global_device = new_device
209
 
210
 
 
211
  _CE.predict = predict
 
212
 
213
  from transformers import Qwen3Config
214
 
215
  ZEConfig = Qwen3Config
216
-
217
- _CE.to = to_device
 
108
  return tokenizer, model
109
 
110
 
111
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
112
+ # Load the model immediately on instantiation
113
+ self.inner_tokenizer, self.inner_model = load_model(global_device)
114
+ self.inner_model.eval()
115
+ self.inner_model.gradient_checkpointing_disable()
116
+ self.inner_yes_token_id = self.inner_tokenizer.encode(
117
+ "Yes", add_special_tokens=False
118
+ )[0]
119
+
120
+
121
  def predict(
122
  self,
123
  query_documents: list[tuple[str, str]] | None = None,
 
135
  raise ValueError("query_documents or sentences must be provided")
136
  query_documents = [[sentence[0], sentence[1]] for sentence in sentences]
137
 
 
 
 
 
 
 
 
 
138
  model = self.inner_model
139
  tokenizer = self.inner_tokenizer
140
 
 
210
  global_device = new_device
211
 
212
 
213
+ _CE.__init__ = __init__
214
  _CE.predict = predict
215
+ _CE.to = to_device
216
 
217
  from transformers import Qwen3Config
218
 
219
  ZEConfig = Qwen3Config