prathamj31 commited on
Commit
3aefef2
·
1 Parent(s): f899c80

Remove unused __init__ method and related code

Browse files

The __init__ patching never worked due to timing (CrossEncoder instance
is created before module is loaded). Model loading happens in to_device()
instead. Also removed unused 'types' import.

Files changed (1) hide show
  1. modeling_zeranker.py +0 -23
modeling_zeranker.py CHANGED
@@ -3,8 +3,6 @@ from sentence_transformers import CrossEncoder as _CE
3
  import math
4
  import logging
5
  from typing import cast, Any
6
- import types
7
-
8
 
9
  import torch
10
  from transformers.configuration_utils import PretrainedConfig
@@ -116,26 +114,6 @@ def load_model(
116
  return tokenizer, model
117
 
118
 
119
- # Store the original __init__ method
120
- _original_init = _CE.__init__
121
-
122
-
123
- def __init__(self, *args: Any, **kwargs: Any) -> None:
124
- logger.info("Initializing CrossEncoder with eager model loading")
125
- # Call the original CrossEncoder __init__ first
126
- _original_init(self, *args, **kwargs)
127
-
128
- # Load the model immediately on instantiation
129
- logger.info("Loading model on instantiation (no lazy loading)")
130
- self.inner_tokenizer, self.inner_model = load_model(global_device)
131
- self.inner_model.eval()
132
- self.inner_model.gradient_checkpointing_disable()
133
- self.inner_yes_token_id = self.inner_tokenizer.encode(
134
- "Yes", add_special_tokens=False
135
- )[0]
136
- logger.info(f"CrossEncoder initialization complete. Yes token ID: {self.inner_yes_token_id}")
137
-
138
-
139
  def predict(
140
  self,
141
  query_documents: list[tuple[str, str]] | None = None,
@@ -247,7 +225,6 @@ def to_device(self: _CE, new_device: torch.device) -> None:
247
  logger.info(f"Model loaded successfully. Yes token ID: {self.inner_yes_token_id}")
248
 
249
 
250
- _CE.__init__ = __init__
251
  _CE.predict = predict
252
  _CE.to = to_device
253
 
 
3
  import math
4
  import logging
5
  from typing import cast, Any
 
 
6
 
7
  import torch
8
  from transformers.configuration_utils import PretrainedConfig
 
114
  return tokenizer, model
115
 
116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  def predict(
118
  self,
119
  query_documents: list[tuple[str, str]] | None = None,
 
225
  logger.info(f"Model loaded successfully. Yes token ID: {self.inner_yes_token_id}")
226
 
227
 
 
228
  _CE.predict = predict
229
  _CE.to = to_device
230