Commit f899c80 (1 changed file) — Parent(s): 456ffeb
Load model during to_device call for eager loading
Move model loading to to_device() since __init__ patching doesn't work
due to timing (CrossEncoder instance is created before this module
is loaded from HuggingFace). to_device() is called during CrossEncoder
initialization, making this effectively eager loading.
- modeling_zeranker.py +12 -0
modeling_zeranker.py
CHANGED
|
@@ -234,6 +234,18 @@ def to_device(self: _CE, new_device: torch.device) -> None:
|
|
| 234 |
logger.info(f"Changing device from {global_device} to {new_device}")
|
| 235 |
global_device = new_device
|
| 236 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
|
| 238 |
_CE.__init__ = __init__
|
| 239 |
_CE.predict = predict
|
|
|
|
| 234 |
logger.info(f"Changing device from {global_device} to {new_device}")
|
| 235 |
global_device = new_device
|
| 236 |
|
| 237 |
+
# Load the model now since __init__ patching doesn't work due to timing
|
| 238 |
+
# (CrossEncoder instance is created before this module is loaded)
|
| 239 |
+
if not hasattr(self, "inner_model"):
|
| 240 |
+
logger.info("Loading model during device setup (eager loading)")
|
| 241 |
+
self.inner_tokenizer, self.inner_model = load_model(global_device)
|
| 242 |
+
self.inner_model.eval()
|
| 243 |
+
self.inner_model.gradient_checkpointing_disable()
|
| 244 |
+
self.inner_yes_token_id = self.inner_tokenizer.encode(
|
| 245 |
+
"Yes", add_special_tokens=False
|
| 246 |
+
)[0]
|
| 247 |
+
logger.info(f"Model loaded successfully. Yes token ID: {self.inner_yes_token_id}")
|
| 248 |
+
|
| 249 |
|
| 250 |
_CE.__init__ = __init__
|
| 251 |
_CE.predict = predict
|