import os
import torch
import numpy as np
from transformers import (
AutoTokenizer,
AutoModel,
)
class BGERetriever:
def __init__(self, model_name=None, device=None, sentence_pooling_method="cls"):
"""
Initializes the BGE retriever using the multilingual BGE-m3 base model.
"""
        # Use local model. Earlier checkpoint suffixes tried during development:
        #   "bge-m3", "test_encoder_only_m3_bge-m3_sd",
        #   "test_encoder_only_base_bge-large-en-v1.5_sd",
        #   "test_encoder_only_base_bge_m3_new"
        if model_name is None:
            model_suffix = "test_encoder_only_base_bge_m3_new1"
local_model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'models', model_suffix)
if os.path.isdir(local_model_path):
model_name = local_model_path
print(f"Using local BGE model from: {model_name}")
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
self.return_dense: bool = True
self.return_sparse: bool = False
self.return_colbert_vecs: bool = False
self.return_sparse_embedding: bool = False
print(f"Loading BGE multilingual model on device: {self.device}")
self.sentence_pooling_method = sentence_pooling_method
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModel.from_pretrained(model_name, torch_dtype=torch.float16, device_map=self.device)
self.vocab_size = self.model.config.vocab_size
self.temperature = 1.0
self.model.eval()
self.corpus_ids = []
self.corpus_embeddings = None
def _dense_embedding(self, last_hidden_state, attention_mask):
"""Use the pooling method to get the dense embedding.
Args:
last_hidden_state (torch.Tensor): The model output's last hidden state.
attention_mask (torch.Tensor): Mask out padding tokens during pooling.
Raises:
NotImplementedError: Specified pooling method not implemented.
Returns:
torch.Tensor: The dense embeddings.
"""
if self.sentence_pooling_method == "cls":
return last_hidden_state[:, 0]
elif self.sentence_pooling_method == "mean":
s = torch.sum(
last_hidden_state * attention_mask.unsqueeze(-1).float(), dim=1
)
d = attention_mask.sum(dim=1, keepdim=True).float()
return s / d
elif self.sentence_pooling_method == "last_token":
left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
if left_padding:
return last_hidden_state[:, -1]
else:
sequence_lengths = attention_mask.sum(dim=1) - 1
batch_size = last_hidden_state.shape[0]
return last_hidden_state[
torch.arange(batch_size, device=last_hidden_state.device),
sequence_lengths,
]
else:
raise NotImplementedError(f"pooling method {self.sentence_pooling_method} not implemented")
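
    # Shape reference for the pooling branches above (illustrative):
    #   last_hidden_state: (batch, seq_len, hidden)
    #   attention_mask:    (batch, seq_len), 1 = real token, 0 = padding
    #   "cls"        -> (batch, hidden), first-token hidden state
    #   "mean"       -> (batch, hidden), masked average over real tokens
    #   "last_token" -> (batch, hidden), hidden state of the final real token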
def _compute_similarity(self, q_reps, p_reps):
"""Computes the similarity between query and passage representations using inner product.
Args:
q_reps (torch.Tensor): Query representations.
p_reps (torch.Tensor): Passage representations.
Returns:
torch.Tensor: The computed similarity matrix.
"""
if len(p_reps.size()) == 2:
return torch.matmul(q_reps, p_reps.transpose(0, 1))
return torch.matmul(q_reps, p_reps.transpose(-2, -1))
    def compute_dense_score(self, q_reps, p_reps):
        """Compute the dense score.

        Embeddings are L2-normalized in embed_texts, so this inner product
        equals cosine similarity (temperature is fixed at 1.0).

        Args:
            q_reps (torch.Tensor): Query representations.
            p_reps (torch.Tensor): Passage representations.

        Returns:
            torch.Tensor: The computed dense scores.
        """
        cos_scores = q_reps @ p_reps.T
        return cos_scores
@torch.inference_mode()
def embed_texts(
self,
texts,
is_query=False,
batch_size=64,
):
"""
Generates embeddings for texts using BGE model with proper prefixes.
BGE requires specific prefixes for queries vs passages.
"""
prefixed_texts = [text.strip() for text in texts]
all_dense_embeddings = []
total_batches = (len(prefixed_texts) + batch_size - 1) // batch_size
for i in range(0, len(prefixed_texts), batch_size):
batch_num = i // batch_size + 1
if not is_query and batch_num % 50 == 0:
print(f"Processing batch {batch_num}/{total_batches} ({(batch_num/total_batches)*100:.1f}%)")
if torch.cuda.is_available():
torch.cuda.empty_cache()
batch_texts = prefixed_texts[i:i + batch_size]
encoded = self.tokenizer(
batch_texts,
padding=True,
truncation=True,
max_length=512,
return_tensors='pt',
).to(self.device)
model_output = self.model(**encoded)
last_hidden_state = model_output.last_hidden_state
dense_vecs = self._dense_embedding(last_hidden_state, encoded['attention_mask'])
dense_vecs = torch.nn.functional.normalize(dense_vecs, p=2, dim=1)
all_dense_embeddings.append(dense_vecs.cpu())
all_dense_embeddings = torch.cat(all_dense_embeddings, dim=0)
return all_dense_embeddings
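
# Minimal usage sketch for the retriever (illustrative only; the example
# texts below are assumptions, not part of the pipeline):
#
#     retriever = BGERetriever()
#     doc_vecs = retriever.embed_texts(["some passage text"], is_query=False)
#     q_vec = retriever.embed_texts(["some query"], is_query=True)
#     scores = retriever.compute_dense_score(q_vec, doc_vecs)  # shape (1, n_docs)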
class BGEReranker:
def __init__(self, model_name=None, device=None):
"""
Initializes the BGE reranker for fine-grained relevance scoring.
"""
        # Use local model. Earlier checkpoint suffixes tried during development:
        #   'bge-reranker-v2-m3', 'test_encoder_only_base_bge_reranker_v2_m3',
        #   "test_encoder_only_base_bge_reranker_v2_m3_new"
        if model_name is None:
            model_suffix = "test_encoder_only_base_bge_reranker_v2_m3_new1"
            local_model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'models', model_suffix)
            if os.path.isdir(local_model_path):
                model_name = local_model_path
                print(f"Using local BGE reranker from: {model_name}")
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
print(f"Loading BGE reranker on device: {self.device}")
        # The BGE reranker is a cross-encoder, exposed as a sequence
        # classification model rather than a plain encoder.
        from transformers import AutoModelForSequenceClassification
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForSequenceClassification.from_pretrained(
model_name,
torch_dtype=torch.float16,
trust_remote_code=True,
device_map=self.device,
)
self.model.eval()
@torch.inference_mode()
def rerank(self, query_text, passages, passage_ids, top_k=20, batch_size=32):
"""
Rerank the passages using BGE reranker - CORRECTED VERSION.
"""
if not passages:
return []
pairs = list(zip(passage_ids, passages))
pairs.sort(key=lambda x: len(x[1]))
passage_ids, passages = zip(*pairs)
scores = []
for i in range(0, len(passages), batch_size):
batch_passages = passages[i:i + batch_size]
            try:
                # The tokenizer takes the query and passage as a text pair
                # (two separate sequences), not one concatenated string.
                batch_queries = [query_text] * len(batch_passages)
inputs = self.tokenizer(
batch_queries,
batch_passages,
padding=True,
truncation=True,
max_length=512,
return_tensors='pt'
).to(self.device)
# Get relevance scores from sequence classification model
outputs = self.model(**inputs)
# BGE reranker outputs logits for relevance classification
logits = outputs.logits
# Handle different output shapes
if len(logits.shape) == 1:
# Single score per pair
batch_scores = logits.cpu().numpy()
elif logits.shape[1] == 1:
# Single column output
batch_scores = logits.squeeze(-1).cpu().numpy()
else:
# Binary classification - take positive class (index 1)
batch_scores = logits[:, 1].cpu().numpy()
scores.extend(batch_scores.tolist())
except Exception as e:
print(f"Error in reranking batch {i//batch_size + 1}: {e}")
# Fallback: Use neutral scores for this batch
fallback_scores = [0.5] * len(batch_passages)
scores.extend(fallback_scores)
# Combine results and sort by reranking score
results = list(zip(passage_ids, scores))
results.sort(key=lambda x: x[1], reverse=True)
return results[:top_k]
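
# Minimal usage sketch for the reranker (illustrative only; the inputs and
# scores below are made up):
#
#     reranker = BGEReranker()
#     ranked = reranker.rerank("some query",
#                              ["passage a", "passage b"],
#                              ["id-a", "id-b"], top_k=2)
#     # -> e.g. [("id-b", 3.1), ("id-a", -2.4)], ids with raw logit scores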
# Global instances
retriever = None
reranker = None
corpus_texts = {} # Store original passage texts for reranking
def preprocess(corpus_dict):
"""
Preprocessing function using BGE multilingual model + BGE reranker.
Input: corpus_dict - dict mapping document IDs to document objects with 'passage'/'text' field
Output: dict containing initialized models, embeddings, and corpus data
    Note: Uses module-level globals (retriever, reranker, corpus_texts) for
    efficiency, but also returns everything predict() needs via the returned dict.
"""
global retriever, reranker, corpus_texts
print("=" * 60)
print("PREPROCESSING: Initializing BGE Reranker Pipeline...")
print("=" * 60)
# Set GPU memory optimization
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
# Initialize BGE retriever
print("Loading BGE retriever...")
retriever = BGERetriever()
# Initialize BGE reranker
print("Loading BGE reranker...")
reranker = BGEReranker()
print(f"Preparing corpus with {len(corpus_dict)} documents...")
    # Store corpus IDs and passages, sorted by passage length so that
    # embedding batches pad to similar sizes
    corpus_ids = list(corpus_dict.keys())
    passages = [doc.get('passage', doc.get('text', '')) for doc in corpus_dict.values()]
    retriever.corpus_ids, passages = zip(*sorted(zip(corpus_ids, passages), key=lambda x: len(x[1])))
    # Store original texts for reranking
    corpus_texts = {doc_id: passages[i] for i, doc_id in enumerate(retriever.corpus_ids)}
# Compute embeddings with conservative batch size for retrieval
print("Computing BGE embeddings...")
retriever.corpus_embeddings = retriever.embed_texts(passages, is_query=False, batch_size=64)
print("✓ Corpus preprocessing complete!")
print(f"✓ Generated embeddings for {len(retriever.corpus_ids)} documents")
print(f"✓ Embedding matrix shape: {retriever.corpus_embeddings.shape}")
return {
'retriever': retriever,
'reranker': reranker,
'corpus_ids': retriever.corpus_ids,
'corpus_embeddings': retriever.corpus_embeddings,
'corpus_texts': corpus_texts,
'num_documents': len(corpus_dict)
}
def predict(query, preprocessed_data):
"""
Two-stage prediction: BGE retrieval + BGE reranking.
Input:
- query: dict with 'query' field containing query text
- preprocessed_data: dict from preprocess() containing models and corpus data
Output: list of dicts with 'paragraph_uuid' and 'score' fields, ranked by relevance
    Note: Uses module-level globals for efficiency, falling back to the
    preprocessed_data argument when the globals are not populated.
"""
global retriever, reranker, corpus_texts
# Extract query text
query_text = query.get('query', '')
if not query_text:
return []
# Use global instances or get from preprocessed_data
if retriever is None:
retriever = preprocessed_data.get('retriever')
reranker = preprocessed_data.get('reranker')
corpus_texts = preprocessed_data.get('corpus_texts', {})
if retriever is None or reranker is None:
print("Error: Missing retriever or reranker in preprocessed data")
return []
    try:
        # STAGE 1: BGE retrieval (get top 100 candidates)
        print("Stage 1: BGE retrieval...")
        query_embedding = retriever.embed_texts([query_text], is_query=True, batch_size=1)
        # Cosine similarity against the precomputed corpus embeddings
        # (embeddings are L2-normalized, so the inner product suffices)
        dense_scores = retriever.compute_dense_score(query_embedding, retriever.corpus_embeddings)
        dense_scores_np = dense_scores.squeeze(0).numpy()
        # Keep the top 100 candidates for reranking
        top_100_indices = np.argsort(dense_scores_np)[::-1][:100]
# Get passages and IDs for reranking
candidate_ids = [retriever.corpus_ids[idx] for idx in top_100_indices]
candidate_passages = [corpus_texts.get(doc_id, '') for doc_id in candidate_ids]
# STAGE 2: BGE Reranking (rerank top 100 -> top 20)
print("Stage 2: BGE reranking...")
reranked_results = reranker.rerank(
query_text,
candidate_passages,
candidate_ids,
top_k=20,
batch_size=16,
)
        # Build final results using the reranker scores
        results = []
        for passage_id, rerank_score in reranked_results:
            results.append({
                'paragraph_uuid': passage_id,
                'score': float(rerank_score),  # raw BGE reranker logit
            })
print(f"✓ Returned {len(results)} results with reranker scores")
return results
    except Exception as e:
        print(f"Error in prediction: {e}")
        # Fallback: dense retrieval only, scored by cosine similarity
        query_embedding = retriever.embed_texts([query_text], is_query=True, batch_size=1)
        dense_scores = retriever.compute_dense_score(query_embedding, retriever.corpus_embeddings)
        dense_scores_np = dense_scores.squeeze(0).numpy()
        top_indices = np.argsort(dense_scores_np)[::-1][:20]
        results = []
        for idx in top_indices:
            results.append({
                'paragraph_uuid': retriever.corpus_ids[idx],
                'score': float(dense_scores_np[idx]),  # cosine similarity score
            })
        return results
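
if __name__ == "__main__":
    # Illustrative smoke test only, not part of the submitted pipeline.
    # It assumes the checkpoints under ./models (or Hub access) are
    # available; the toy corpus and query below are made-up examples.
    toy_corpus = {
        "doc-1": {"passage": "BGE-M3 is a multilingual embedding model."},
        "doc-2": {"passage": "A cross-encoder reranker scores query-passage pairs jointly."},
    }
    data = preprocess(toy_corpus)
    for hit in predict({"query": "What does a reranker do?"}, data):
        print(hit["paragraph_uuid"], hit["score"])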