prathamj31 committed
Commit 456ffeb · Parent: 16b4531

Add logging statements throughout model lifecycle


- Add logger initialization using __name__
- Log model loading process including device and config type
- Log CrossEncoder initialization stages
- Log prediction batching and processing
- Replace print statements with logger for OOM handling
- Log device changes in to_device function (a sketch for surfacing these logs follows this list)
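
Since the module obtains its logger with logging.getLogger(__name__), the new messages stay silent until the host application configures logging. A minimal sketch of opting in; the level and format below are illustrative choices, not part of this commit:

import logging

# Python's logging module drops INFO/DEBUG records unless a handler is
# configured; DEBUG also surfaces the per-batch messages added here.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s %(name)s %(levelname)s: %(message)s",
)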

Files changed (1):
modeling_zeranker.py (+21 -3)
modeling_zeranker.py CHANGED
@@ -1,6 +1,7 @@
 from sentence_transformers import CrossEncoder as _CE
 
 import math
+import logging
 from typing import cast, Any
 import types
 
@@ -23,6 +24,8 @@ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
 # pyright: reportUnknownMemberType=false
 # pyright: reportUnknownVariableType=false
 
+logger = logging.getLogger(__name__)
+
 MODEL_PATH = "zeroentropy/zerank-2"
 PER_DEVICE_BATCH_SIZE_TOKENS = 10_000
 global_device = (
@@ -74,9 +77,12 @@ def load_model(
     if device is None:
         device = global_device
 
+    logger.info(f"Loading model from {MODEL_PATH} on device: {device}")
+
     config = AutoConfig.from_pretrained(MODEL_PATH)
     assert isinstance(config, PretrainedConfig)
 
+    logger.info(f"Loading model with config type: {config.model_type}")
     model = AutoModelForCausalLM.from_pretrained(
         MODEL_PATH,
         torch_dtype="auto",
@@ -93,6 +99,7 @@ def load_model(
         | Qwen3ForCausalLM,
     )
 
+    logger.info("Loading tokenizer")
     tokenizer = cast(
         AutoTokenizer,
         AutoTokenizer.from_pretrained(
@@ -105,6 +112,7 @@ def load_model(
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
 
+    logger.info("Model and tokenizer loaded successfully")
     return tokenizer, model
 
 
@@ -113,16 +121,19 @@ _original_init = _CE.__init__
 
 
 def __init__(self, *args: Any, **kwargs: Any) -> None:
+    logger.info("Initializing CrossEncoder with eager model loading")
     # Call the original CrossEncoder __init__ first
     _original_init(self, *args, **kwargs)
 
     # Load the model immediately on instantiation
+    logger.info("Loading model on instantiation (no lazy loading)")
     self.inner_tokenizer, self.inner_model = load_model(global_device)
     self.inner_model.eval()
     self.inner_model.gradient_checkpointing_disable()
     self.inner_yes_token_id = self.inner_tokenizer.encode(
         "Yes", add_special_tokens=False
     )[0]
+    logger.info(f"CrossEncoder initialization complete. Yes token ID: {self.inner_yes_token_id}")
 
 
 def predict(
@@ -142,6 +153,8 @@ def predict(
         raise ValueError("query_documents or sentences must be provided")
     query_documents = [[sentence[0], sentence[1]] for sentence in sentences]
 
+    logger.info(f"Starting prediction for {len(query_documents)} query-document pairs")
+
     model = self.inner_model
     tokenizer = self.inner_tokenizer
 
@@ -170,9 +183,12 @@ def predict(
         batches[-1].append((query, document))
         max_length = max(max_length, 20 + len(query) + len(document))
 
+    logger.info(f"Created {len(batches)} batches for inference")
+
     # Inference all of the document batches
     all_logits: list[float] = []
-    for batch in batches:
+    for batch_idx, batch in enumerate(batches):
+        logger.debug(f"Processing batch {batch_idx + 1}/{len(batches)} with {len(batch)} pairs")
         batch_inputs = format_pointwise_datapoints(
             tokenizer,
             batch,
@@ -184,9 +200,9 @@ def predict(
             with torch.inference_mode():
                 outputs = model(**batch_inputs, use_cache=False)
         except torch.OutOfMemoryError:
-            print(f"GPU OOM! {torch.cuda.memory_reserved()}")
+            logger.warning(f"GPU OOM! Memory reserved: {torch.cuda.memory_reserved()}")
             torch.cuda.empty_cache()
-            print(f"GPU After OOM Cache Clear: {torch.cuda.memory_reserved()}")
+            logger.info(f"GPU cache cleared. Memory reserved: {torch.cuda.memory_reserved()}")
             outputs = model(**batch_inputs, use_cache=False)
 
     # Extract the logits
@@ -209,11 +225,13 @@ def predict(
     # Unsort by indices
     scores = [score for _, score in sorted(zip(permutation, scores, strict=True))]
 
+    logger.info(f"Prediction complete. Generated {len(scores)} scores")
    return scores
 
 
 def to_device(self: _CE, new_device: torch.device) -> None:
     global global_device
+    logger.info(f"Changing device from {global_device} to {new_device}")
     global_device = new_device
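
With the patched __init__ and predict in place, the lifecycle logs appear during normal CrossEncoder use. A hedged end-to-end sketch; the trust_remote_code flag and the query/document pair are assumptions for illustration, and the list-of-pairs predict call follows the sentences fallback shown in the diff:

import logging

from sentence_transformers import CrossEncoder

# Surface the INFO-level lifecycle messages added by this commit.
logging.basicConfig(level=logging.INFO)

# Repos that ship custom modeling code generally require
# trust_remote_code=True; the pair below is illustrative only.
model = CrossEncoder("zeroentropy/zerank-2", trust_remote_code=True)
scores = model.predict(
    [("What does zerank-2 do?", "zerank-2 is a reranker that scores query-document pairs.")]
)
print(scores)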