Spaces:
Configuration error
Configuration error
Update cross_encoder_reranking_train.py
Browse files — cross_encoder_reranking_train.py (+17 −12)
cross_encoder_reranking_train.py
CHANGED
|
@@ -276,24 +276,29 @@ def hybrid_score(cross_encoder_score, semantic_score, weight_cross=0.7, weight_s
|
|
| 276 |
def cross_encoder_reranking(query_text, doc_texts, model, tokenizer, batch_size=64, max_length=2048):
|
| 277 |
device = next(model.parameters()).device
|
| 278 |
cross_scores = []
|
| 279 |
-
query_emb = embed_text_list([query_text])[0]
|
| 280 |
|
| 281 |
instructed_query = get_detailed_instruct("", query_text)
|
| 282 |
|
| 283 |
-
|
| 284 |
-
|
| 285 |
|
| 286 |
-
|
|
|
|
| 287 |
|
| 288 |
with torch.no_grad():
|
| 289 |
-
batch_dict = tokenizer(
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 297 |
|
| 298 |
# Semantic scores
|
| 299 |
doc_embeddings = embed_text_list(doc_texts)
|
|
|
|
| 276 |
def cross_encoder_reranking(query_text, doc_texts, model, tokenizer, batch_size=64, max_length=2048):
|
| 277 |
device = next(model.parameters()).device
|
| 278 |
cross_scores = []
|
| 279 |
+
query_emb = embed_text_list([query_text])[0] # Move embedder to CPU
|
| 280 |
|
| 281 |
instructed_query = get_detailed_instruct("", query_text)
|
| 282 |
|
| 283 |
+
# Pre-create all input pairs (concatenation-based cross-encoder setup)
|
| 284 |
+
input_texts = [f"{instructed_query} {doc}" for doc in doc_texts]
|
| 285 |
|
| 286 |
+
for i in tqdm(range(0, len(input_texts), batch_size), desc="Scoring documents", leave=False):
|
| 287 |
+
batch_input_texts = input_texts[i:i+batch_size]
|
| 288 |
|
| 289 |
with torch.no_grad():
|
| 290 |
+
batch_dict = tokenizer(batch_input_texts, max_length=max_length, padding=True, truncation=True, return_tensors='pt').to(device)
|
| 291 |
+
|
| 292 |
+
# Mixed precision for faster inference and lower memory
|
| 293 |
+
with torch.cuda.amp.autocast():
|
| 294 |
+
outputs = model(**batch_dict)
|
| 295 |
+
embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
|
| 296 |
+
embeddings = F.normalize(embeddings, p=2, dim=1)
|
| 297 |
+
|
| 298 |
+
# Since queries are repeated in each pair, compare to instructed query embedding (first one)
|
| 299 |
+
query_vector = embeddings[0].unsqueeze(0) # Use first as query
|
| 300 |
+
batch_cross_scores = (query_vector @ embeddings.T).squeeze(0).cpu().numpy()[1:] # Exclude self-comparison
|
| 301 |
+
cross_scores.extend(batch_cross_scores)
|
| 302 |
|
| 303 |
# Semantic scores
|
| 304 |
doc_embeddings = embed_text_list(doc_texts)
|