"""Two-stage passage retrieval: BGE dense retrieval + BGE cross-encoder reranking.

Stage 1 (BGERetriever) embeds the corpus once and scores queries by inner
product of L2-normalized embeddings. Stage 2 (BGEReranker) rescores the top
candidates with a sequence-classification cross-encoder.
"""

import os

import numpy as np
import torch
from transformers import (
    AutoModel,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)


class BGERetriever:
    """Dense retriever backed by a BGE-m3 style encoder."""

    def __init__(self, model_name=None, device=None, sentence_pooling_method="cls"):
        """Initialize the BGE retriever.

        Args:
            model_name: HF hub id or local path. When None, a bundled
                checkpoint under ./models is used if present, otherwise
                falls back to "BAAI/bge-m3".
            device: Torch device string; defaults to CUDA when available.
            sentence_pooling_method: One of "cls", "mean", "last_token".
        """
        if model_name is None:
            # The original reassigned model_suffix several times in a row;
            # only the final assignment ever took effect, so it is kept.
            model_suffix = "test_encoder_only_base_bge_m3_new1"
            local_model_path = os.path.join(
                os.path.dirname(os.path.abspath(__file__)), "models", model_suffix
            )
            if os.path.isdir(local_model_path):
                model_name = local_model_path
                print(f"Using local BGE model from: {model_name}")
            else:
                # Fix: the original left model_name as None here, which made
                # from_pretrained(None) fail. Fall back to the hub checkpoint.
                model_name = "BAAI/bge-m3"

        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # Output-type flags; only dense embeddings are produced by this class.
        self.return_dense: bool = True
        self.return_sparse: bool = False
        self.return_colbert_vecs: bool = False
        self.return_sparse_embedding: bool = False
        print(f"Loading BGE multilingual model on device: {self.device}")
        self.sentence_pooling_method = sentence_pooling_method
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(
            model_name, torch_dtype=torch.float16, device_map=self.device
        )
        self.vocab_size = self.model.config.vocab_size
        self.temperature = 1.0  # kept for interface parity; scoring ignores it
        self.model.eval()
        self.corpus_ids = []           # document ids, parallel to corpus_embeddings rows
        self.corpus_embeddings = None  # torch.Tensor filled in by preprocess()

    def _dense_embedding(self, last_hidden_state, attention_mask):
        """Pool per-token hidden states into one sentence embedding.

        Args:
            last_hidden_state: (batch, seq, hidden) encoder output.
            attention_mask: (batch, seq) mask; 1 for real tokens, 0 for pad.

        Returns:
            (batch, hidden) pooled embeddings.

        Raises:
            NotImplementedError: For an unknown pooling method.
        """
        method = self.sentence_pooling_method
        if method == "cls":
            # First token ([CLS]) representation.
            return last_hidden_state[:, 0]
        if method == "mean":
            # Mask-weighted mean over real tokens only.
            mask = attention_mask.unsqueeze(-1).float()
            summed = (last_hidden_state * mask).sum(dim=1)
            counts = attention_mask.sum(dim=1, keepdim=True).float()
            return summed / counts
        if method == "last_token":
            # With left padding every sequence ends at the final position;
            # otherwise index each row at its own last real token.
            left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
            if left_padding:
                return last_hidden_state[:, -1]
            lengths = attention_mask.sum(dim=1) - 1
            rows = torch.arange(
                last_hidden_state.shape[0], device=last_hidden_state.device
            )
            return last_hidden_state[rows, lengths]
        raise NotImplementedError(f"pooling method {method} not implemented")

    def _compute_similarity(self, q_reps, p_reps):
        """Inner-product similarity between query and passage representations.

        Args:
            q_reps (torch.Tensor): Query representations.
            p_reps (torch.Tensor): Passage representations (2-D or batched 3-D).

        Returns:
            torch.Tensor: Similarity matrix.
        """
        if len(p_reps.size()) == 2:
            return torch.matmul(q_reps, p_reps.transpose(0, 1))
        return torch.matmul(q_reps, p_reps.transpose(-2, -1))

    def compute_dense_score(self, q_reps, p_reps):
        """Score queries against passages by inner product.

        Embeddings are L2-normalized in embed_texts, so the inner product
        equals cosine similarity. (The original body contained an unreachable
        temperature-scaled code path after the return; it has been removed.)

        Args:
            q_reps (torch.Tensor): (num_queries, hidden) query embeddings.
            p_reps (torch.Tensor): (num_passages, hidden) passage embeddings.

        Returns:
            torch.Tensor: (num_queries, num_passages) score matrix.
        """
        return q_reps @ p_reps.T

    @torch.inference_mode()
    def embed_texts(self, texts, is_query=False, batch_size=64):
        """Embed texts into L2-normalized dense vectors.

        Each text is stripped before encoding; no instruction prefix is
        added (BGE-m3 checkpoints do not require query/passage prefixes).

        Args:
            texts: Iterable of strings.
            is_query: When False, progress is logged every 50 batches.
            batch_size: Tokenization/forward batch size.

        Returns:
            torch.Tensor of shape (len(texts), hidden_size) on CPU.
        """
        cleaned = [text.strip() for text in texts]
        chunks = []
        total_batches = (len(cleaned) + batch_size - 1) // batch_size
        for start in range(0, len(cleaned), batch_size):
            batch_num = start // batch_size + 1
            if not is_query and batch_num % 50 == 0:
                print(f"Processing batch {batch_num}/{total_batches} ({(batch_num/total_batches)*100:.1f}%)")
            if torch.cuda.is_available():
                # Keep peak GPU memory down between batches.
                torch.cuda.empty_cache()
            encoded = self.tokenizer(
                cleaned[start:start + batch_size],
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors='pt',
            ).to(self.device)
            output = self.model(**encoded)
            vecs = self._dense_embedding(
                output.last_hidden_state, encoded['attention_mask']
            )
            # Normalize so downstream inner products are cosine similarities.
            vecs = torch.nn.functional.normalize(vecs, p=2, dim=1)
            chunks.append(vecs.cpu())
        if not chunks:
            # Fix: torch.cat([]) raises on empty input; return an empty matrix.
            return torch.empty(0, self.model.config.hidden_size)
        return torch.cat(chunks, dim=0)


class BGEReranker:
    """Cross-encoder reranker for fine-grained query/passage relevance."""

    def __init__(self, model_name=None, device=None):
        """Initialize the BGE reranker.

        Args:
            model_name: HF hub id or local path. When None, a bundled
                checkpoint under ./models is used if present, otherwise
                falls back to "BAAI/bge-reranker-v2-m3".
            device: Torch device string; defaults to CUDA when available.
        """
        if model_name is None:
            # Only the last of the original's repeated assignments took effect.
            model_suffix = "test_encoder_only_base_bge_reranker_v2_m3_new1"
            local_model_path = os.path.join(
                os.path.dirname(os.path.abspath(__file__)), "models", model_suffix
            )
            if os.path.isdir(local_model_path):
                model_name = local_model_path
                print(f"Using local BGE model from: {model_name}")
            else:
                # Fix: avoid from_pretrained(None) when no local checkpoint exists.
                model_name = "BAAI/bge-reranker-v2-m3"

        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Loading BGE reranker on device: {self.device}")
        # The reranker is a sequence-classification head over query/passage pairs.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            device_map=self.device,
        )
        self.model.eval()

    @torch.inference_mode()
    def rerank(self, query_text, passages, passage_ids, top_k=20, batch_size=32):
        """Rerank passages for a query with the cross-encoder.

        Passages are sorted by length first so each batch pads to a similar
        sequence length (reduces wasted compute). A failing batch falls back
        to neutral 0.5 scores rather than aborting the whole rerank.

        Args:
            query_text: The query string.
            passages: Candidate passage texts.
            passage_ids: Ids parallel to ``passages``.
            top_k: Number of results to return.
            batch_size: Pairs scored per forward pass.

        Returns:
            List of (passage_id, score) tuples, best first, at most top_k long.
        """
        if not passages:
            return []

        # Sort (id, passage) pairs by passage length for efficient padding.
        pairs = sorted(zip(passage_ids, passages), key=lambda item: len(item[1]))
        passage_ids, passages = zip(*pairs)

        scores = []
        for start in range(0, len(passages), batch_size):
            batch_passages = passages[start:start + batch_size]
            try:
                # The reranker tokenizer takes query and passage as separate
                # sequences (text pairs), not one concatenated string.
                batch_queries = [query_text] * len(batch_passages)
                inputs = self.tokenizer(
                    batch_queries,
                    batch_passages,
                    padding=True,
                    truncation=True,
                    max_length=512,
                    return_tensors='pt'
                ).to(self.device)

                logits = self.model(**inputs).logits

                # Normalize the various possible head shapes to one score/pair.
                if len(logits.shape) == 1:
                    batch_scores = logits.cpu().numpy()
                elif logits.shape[1] == 1:
                    batch_scores = logits.squeeze(-1).cpu().numpy()
                else:
                    # Binary classification head: take the positive class.
                    batch_scores = logits[:, 1].cpu().numpy()

                scores.extend(batch_scores.tolist())
            except Exception as e:
                print(f"Error in reranking batch {start//batch_size + 1}: {e}")
                # Deliberate best-effort: neutral scores for the failed batch.
                scores.extend([0.5] * len(batch_passages))

        # Pair ids with scores and return the best top_k.
        results = list(zip(passage_ids, scores))
        results.sort(key=lambda item: item[1], reverse=True)
        return results[:top_k]


# Global instances (populated by preprocess(); predict() can also recover
# them from its preprocessed_data argument).
retriever = None
reranker = None
corpus_texts = {}  # passage_id -> original passage text, for reranking


def preprocess(corpus_dict):
    """Initialize models and embed the corpus.

    Args:
        corpus_dict: Mapping of document id -> document dict carrying the
            text under 'passage' (or 'text' as a fallback).

    Returns:
        Dict with the initialized retriever/reranker, corpus ids, the
        embedding matrix, the id->text map, and the document count. The same
        objects are also stored in module globals for reuse by predict().
    """
    global retriever, reranker, corpus_texts

    print("=" * 60)
    print("PREPROCESSING: Initializing BGE Reranker Pipeline...")
    print("=" * 60)

    # Reduce CUDA fragmentation for the large embedding pass.
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

    print("Loading BGE retriever...")
    retriever = BGERetriever()

    print("Loading BGE reranker...")
    reranker = BGEReranker()

    print(f"Preparing corpus with {len(corpus_dict)} documents...")

    corpus_ids = list(corpus_dict.keys())
    passages = [doc.get('passage', doc.get('text', '')) for doc in corpus_dict.values()]
    # Sort corpus by passage length so embedding batches pad efficiently;
    # ids stay parallel to the embedding rows.
    retriever.corpus_ids, passages = zip(
        *sorted(zip(corpus_ids, passages), key=lambda item: len(item[1]))
    )

    # Keep the raw text around for the reranking stage.
    corpus_texts = {doc_id: passages[i] for i, doc_id in enumerate(retriever.corpus_ids)}

    print("Computing BGE embeddings...")
    retriever.corpus_embeddings = retriever.embed_texts(passages, is_query=False, batch_size=64)

    print("✓ Corpus preprocessing complete!")
    print(f"✓ Generated embeddings for {len(retriever.corpus_ids)} documents")
    print(f"✓ Embedding matrix shape: {retriever.corpus_embeddings.shape}")

    return {
        'retriever': retriever,
        'reranker': reranker,
        'corpus_ids': retriever.corpus_ids,
        'corpus_embeddings': retriever.corpus_embeddings,
        'corpus_texts': corpus_texts,
        'num_documents': len(corpus_dict)
    }


def predict(query, preprocessed_data):
    """Two-stage prediction: dense retrieval then cross-encoder reranking.

    Args:
        query: Dict with the query text under 'query'.
        preprocessed_data: Dict returned by preprocess(); used to restore the
            module globals when they are not already set.

    Returns:
        List of {'paragraph_uuid', 'score'} dicts ranked by relevance. On a
        reranking failure, falls back to retrieval-only scores for the top 20.
    """
    global retriever, reranker, corpus_texts

    query_text = query.get('query', '')
    if not query_text:
        return []

    # Restore globals from preprocessed_data if this process hasn't run
    # preprocess() itself.
    if retriever is None:
        retriever = preprocessed_data.get('retriever')
        reranker = preprocessed_data.get('reranker')
        corpus_texts = preprocessed_data.get('corpus_texts', {})

    if retriever is None or reranker is None:
        print("Error: Missing retriever or reranker in preprocessed data")
        return []

    try:
        # STAGE 1: dense retrieval — top 100 candidates by cosine similarity.
        print("Stage 1: BGE retrieval...")
        query_embedding = retriever.embed_texts([query_text], is_query=True, batch_size=1)
        dense_scores = retriever.compute_dense_score(query_embedding, retriever.corpus_embeddings)
        sim_scores = dense_scores.squeeze(0).numpy()

        top_100_indices = np.argsort(sim_scores)[::-1][:100]
        candidate_ids = [retriever.corpus_ids[idx] for idx in top_100_indices]
        candidate_passages = [corpus_texts.get(doc_id, '') for doc_id in candidate_ids]

        # STAGE 2: cross-encoder reranking — top 100 down to top 20.
        print("Stage 2: BGE reranking...")
        reranked_results = reranker.rerank(
            query_text,
            candidate_passages,
            candidate_ids,
            top_k=20,
            batch_size=16,
        )

        # Report the actual reranker scores, not the retrieval scores.
        results = [
            {'paragraph_uuid': passage_id, 'score': float(rerank_score)}
            for passage_id, rerank_score in reranked_results
        ]
        print(f"✓ Returned {len(results)} results with reranker scores")
        return results

    except Exception as e:
        print(f"Error in prediction: {e}")
        # Fallback: retrieval-only ranking with cosine-similarity scores.
        query_embedding = retriever.embed_texts([query_text], is_query=True, batch_size=1)
        dense_scores = retriever.compute_dense_score(query_embedding, retriever.corpus_embeddings)
        sim_scores = dense_scores.squeeze(0).numpy()
        top_indices = np.argsort(sim_scores)[::-1][:20]
        return [
            {
                'paragraph_uuid': retriever.corpus_ids[idx],
                'score': float(sim_scores[idx]),
            }
            for idx in top_indices
        ]