Commit
·
456ffeb
1
Parent(s):
16b4531
Add logging statements throughout model lifecycle
Browse files
- Add logger initialization using __name__
- Log model loading process including device and config type
- Log CrossEncoder initialization stages
- Log prediction batching and processing
- Replace print statements with logger for OOM handling
- Log device changes in to_device function
- modeling_zeranker.py +21 -3
modeling_zeranker.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
from sentence_transformers import CrossEncoder as _CE
|
| 2 |
|
| 3 |
import math
|
|
|
|
| 4 |
from typing import cast, Any
|
| 5 |
import types
|
| 6 |
|
|
@@ -23,6 +24,8 @@ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
|
|
| 23 |
# pyright: reportUnknownMemberType=false
|
| 24 |
# pyright: reportUnknownVariableType=false
|
| 25 |
|
|
|
|
|
|
|
| 26 |
MODEL_PATH = "zeroentropy/zerank-2"
|
| 27 |
PER_DEVICE_BATCH_SIZE_TOKENS = 10_000
|
| 28 |
global_device = (
|
|
@@ -74,9 +77,12 @@ def load_model(
|
|
| 74 |
if device is None:
|
| 75 |
device = global_device
|
| 76 |
|
|
|
|
|
|
|
| 77 |
config = AutoConfig.from_pretrained(MODEL_PATH)
|
| 78 |
assert isinstance(config, PretrainedConfig)
|
| 79 |
|
|
|
|
| 80 |
model = AutoModelForCausalLM.from_pretrained(
|
| 81 |
MODEL_PATH,
|
| 82 |
torch_dtype="auto",
|
|
@@ -93,6 +99,7 @@ def load_model(
|
|
| 93 |
| Qwen3ForCausalLM,
|
| 94 |
)
|
| 95 |
|
|
|
|
| 96 |
tokenizer = cast(
|
| 97 |
AutoTokenizer,
|
| 98 |
AutoTokenizer.from_pretrained(
|
|
@@ -105,6 +112,7 @@ def load_model(
|
|
| 105 |
if tokenizer.pad_token is None:
|
| 106 |
tokenizer.pad_token = tokenizer.eos_token
|
| 107 |
|
|
|
|
| 108 |
return tokenizer, model
|
| 109 |
|
| 110 |
|
|
@@ -113,16 +121,19 @@ _original_init = _CE.__init__
|
|
| 113 |
|
| 114 |
|
| 115 |
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
|
|
|
| 116 |
# Call the original CrossEncoder __init__ first
|
| 117 |
_original_init(self, *args, **kwargs)
|
| 118 |
|
| 119 |
# Load the model immediately on instantiation
|
|
|
|
| 120 |
self.inner_tokenizer, self.inner_model = load_model(global_device)
|
| 121 |
self.inner_model.eval()
|
| 122 |
self.inner_model.gradient_checkpointing_disable()
|
| 123 |
self.inner_yes_token_id = self.inner_tokenizer.encode(
|
| 124 |
"Yes", add_special_tokens=False
|
| 125 |
)[0]
|
|
|
|
| 126 |
|
| 127 |
|
| 128 |
def predict(
|
|
@@ -142,6 +153,8 @@ def predict(
|
|
| 142 |
raise ValueError("query_documents or sentences must be provided")
|
| 143 |
query_documents = [[sentence[0], sentence[1]] for sentence in sentences]
|
| 144 |
|
|
|
|
|
|
|
| 145 |
model = self.inner_model
|
| 146 |
tokenizer = self.inner_tokenizer
|
| 147 |
|
|
@@ -170,9 +183,12 @@ def predict(
|
|
| 170 |
batches[-1].append((query, document))
|
| 171 |
max_length = max(max_length, 20 + len(query) + len(document))
|
| 172 |
|
|
|
|
|
|
|
| 173 |
# Inference all of the document batches
|
| 174 |
all_logits: list[float] = []
|
| 175 |
-
for batch in batches:
|
|
|
|
| 176 |
batch_inputs = format_pointwise_datapoints(
|
| 177 |
tokenizer,
|
| 178 |
batch,
|
|
@@ -184,9 +200,9 @@ def predict(
|
|
| 184 |
with torch.inference_mode():
|
| 185 |
outputs = model(**batch_inputs, use_cache=False)
|
| 186 |
except torch.OutOfMemoryError:
|
| 187 |
-
|
| 188 |
torch.cuda.empty_cache()
|
| 189 |
-
|
| 190 |
outputs = model(**batch_inputs, use_cache=False)
|
| 191 |
|
| 192 |
# Extract the logits
|
|
@@ -209,11 +225,13 @@ def predict(
|
|
| 209 |
# Unsort by indices
|
| 210 |
scores = [score for _, score in sorted(zip(permutation, scores, strict=True))]
|
| 211 |
|
|
|
|
| 212 |
return scores
|
| 213 |
|
| 214 |
|
| 215 |
def to_device(self: _CE, new_device: torch.device) -> None:
|
| 216 |
global global_device
|
|
|
|
| 217 |
global_device = new_device
|
| 218 |
|
| 219 |
|
|
|
|
| 1 |
from sentence_transformers import CrossEncoder as _CE
|
| 2 |
|
| 3 |
import math
|
| 4 |
+
import logging
|
| 5 |
from typing import cast, Any
|
| 6 |
import types
|
| 7 |
|
|
|
|
| 24 |
# pyright: reportUnknownMemberType=false
|
| 25 |
# pyright: reportUnknownVariableType=false
|
| 26 |
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
MODEL_PATH = "zeroentropy/zerank-2"
|
| 30 |
PER_DEVICE_BATCH_SIZE_TOKENS = 10_000
|
| 31 |
global_device = (
|
|
|
|
| 77 |
if device is None:
|
| 78 |
device = global_device
|
| 79 |
|
| 80 |
+
logger.info(f"Loading model from {MODEL_PATH} on device: {device}")
|
| 81 |
+
|
| 82 |
config = AutoConfig.from_pretrained(MODEL_PATH)
|
| 83 |
assert isinstance(config, PretrainedConfig)
|
| 84 |
|
| 85 |
+
logger.info(f"Loading model with config type: {config.model_type}")
|
| 86 |
model = AutoModelForCausalLM.from_pretrained(
|
| 87 |
MODEL_PATH,
|
| 88 |
torch_dtype="auto",
|
|
|
|
| 99 |
| Qwen3ForCausalLM,
|
| 100 |
)
|
| 101 |
|
| 102 |
+
logger.info("Loading tokenizer")
|
| 103 |
tokenizer = cast(
|
| 104 |
AutoTokenizer,
|
| 105 |
AutoTokenizer.from_pretrained(
|
|
|
|
| 112 |
if tokenizer.pad_token is None:
|
| 113 |
tokenizer.pad_token = tokenizer.eos_token
|
| 114 |
|
| 115 |
+
logger.info("Model and tokenizer loaded successfully")
|
| 116 |
return tokenizer, model
|
| 117 |
|
| 118 |
|
|
|
|
| 121 |
|
| 122 |
|
| 123 |
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
| 124 |
+
logger.info("Initializing CrossEncoder with eager model loading")
|
| 125 |
# Call the original CrossEncoder __init__ first
|
| 126 |
_original_init(self, *args, **kwargs)
|
| 127 |
|
| 128 |
# Load the model immediately on instantiation
|
| 129 |
+
logger.info("Loading model on instantiation (no lazy loading)")
|
| 130 |
self.inner_tokenizer, self.inner_model = load_model(global_device)
|
| 131 |
self.inner_model.eval()
|
| 132 |
self.inner_model.gradient_checkpointing_disable()
|
| 133 |
self.inner_yes_token_id = self.inner_tokenizer.encode(
|
| 134 |
"Yes", add_special_tokens=False
|
| 135 |
)[0]
|
| 136 |
+
logger.info(f"CrossEncoder initialization complete. Yes token ID: {self.inner_yes_token_id}")
|
| 137 |
|
| 138 |
|
| 139 |
def predict(
|
|
|
|
| 153 |
raise ValueError("query_documents or sentences must be provided")
|
| 154 |
query_documents = [[sentence[0], sentence[1]] for sentence in sentences]
|
| 155 |
|
| 156 |
+
logger.info(f"Starting prediction for {len(query_documents)} query-document pairs")
|
| 157 |
+
|
| 158 |
model = self.inner_model
|
| 159 |
tokenizer = self.inner_tokenizer
|
| 160 |
|
|
|
|
| 183 |
batches[-1].append((query, document))
|
| 184 |
max_length = max(max_length, 20 + len(query) + len(document))
|
| 185 |
|
| 186 |
+
logger.info(f"Created {len(batches)} batches for inference")
|
| 187 |
+
|
| 188 |
# Inference all of the document batches
|
| 189 |
all_logits: list[float] = []
|
| 190 |
+
for batch_idx, batch in enumerate(batches):
|
| 191 |
+
logger.debug(f"Processing batch {batch_idx + 1}/{len(batches)} with {len(batch)} pairs")
|
| 192 |
batch_inputs = format_pointwise_datapoints(
|
| 193 |
tokenizer,
|
| 194 |
batch,
|
|
|
|
| 200 |
with torch.inference_mode():
|
| 201 |
outputs = model(**batch_inputs, use_cache=False)
|
| 202 |
except torch.OutOfMemoryError:
|
| 203 |
+
logger.warning(f"GPU OOM! Memory reserved: {torch.cuda.memory_reserved()}")
|
| 204 |
torch.cuda.empty_cache()
|
| 205 |
+
logger.info(f"GPU cache cleared. Memory reserved: {torch.cuda.memory_reserved()}")
|
| 206 |
outputs = model(**batch_inputs, use_cache=False)
|
| 207 |
|
| 208 |
# Extract the logits
|
|
|
|
| 225 |
# Unsort by indices
|
| 226 |
scores = [score for _, score in sorted(zip(permutation, scores, strict=True))]
|
| 227 |
|
| 228 |
+
logger.info(f"Prediction complete. Generated {len(scores)} scores")
|
| 229 |
return scores
|
| 230 |
|
| 231 |
|
| 232 |
def to_device(self: _CE, new_device: torch.device) -> None:
|
| 233 |
global global_device
|
| 234 |
+
logger.info(f"Changing device from {global_device} to {new_device}")
|
| 235 |
global_device = new_device
|
| 236 |
|
| 237 |
|