prathamj31 commited on
Commit
f899c80
·
1 Parent(s): 456ffeb

Load model during the to_device() call to achieve eager loading

Browse files

Move model loading into to_device(), since patching __init__ doesn't work
due to timing (CrossEncoder instance is created before this module
is loaded from HuggingFace). to_device() is called during CrossEncoder
initialization, making this effectively eager loading.

Files changed (1) hide show
  1. modeling_zeranker.py +12 -0
modeling_zeranker.py CHANGED
@@ -234,6 +234,18 @@ def to_device(self: _CE, new_device: torch.device) -> None:
234
  logger.info(f"Changing device from {global_device} to {new_device}")
235
  global_device = new_device
236
 
 
 
 
 
 
 
 
 
 
 
 
 
237
 
238
  _CE.__init__ = __init__
239
  _CE.predict = predict
 
234
  logger.info(f"Changing device from {global_device} to {new_device}")
235
  global_device = new_device
236
 
237
+ # Load the model now since __init__ patching doesn't work due to timing
238
+ # (CrossEncoder instance is created before this module is loaded)
239
+ if not hasattr(self, "inner_model"):
240
+ logger.info("Loading model during device setup (eager loading)")
241
+ self.inner_tokenizer, self.inner_model = load_model(global_device)
242
+ self.inner_model.eval()
243
+ self.inner_model.gradient_checkpointing_disable()
244
+ self.inner_yes_token_id = self.inner_tokenizer.encode(
245
+ "Yes", add_special_tokens=False
246
+ )[0]
247
+ logger.info(f"Model loaded successfully. Yes token ID: {self.inner_yes_token_id}")
248
+
249
 
250
  _CE.__init__ = __init__
251
  _CE.predict = predict