| import os |
|
|
| import torch |
| import numpy as np |
| |
| from transformers import ( |
| AutoTokenizer, |
| AutoModel, |
| ) |
|
|
|
|
class BGERetriever:
    """Dense retriever built on a BGE (BAAI General Embedding) encoder.

    Encodes queries and passages into L2-normalized dense vectors and
    scores them with an inner product (equal to cosine similarity after
    normalization).
    """

    def __init__(self, model_name=None, device=None, sentence_pooling_method="cls"):
        """Initialize the BGE retriever using the multilingual BGE-m3 base model.

        Args:
            model_name (str | None): HF model id or local path. When None, a
                locally bundled checkpoint under ./models is used if present,
                otherwise the public BGE-m3 checkpoint is downloaded.
            device (str | None): Torch device string; auto-detects CUDA when None.
            sentence_pooling_method (str): "cls", "mean", or "last_token".
        """
        if model_name is None:
            # Only the last of the original chained reassignments was live;
            # the earlier experiment names were dead code and are removed.
            model_suffix = "test_encoder_only_base_bge_m3_new1"
            local_model_path = os.path.join(
                os.path.dirname(os.path.abspath(__file__)), 'models', model_suffix
            )
            if os.path.isdir(local_model_path):
                model_name = local_model_path
                print(f"Using local BGE model from: {model_name}")
            else:
                # Fall back to the public checkpoint rather than passing None
                # to from_pretrained (which would raise).
                model_name = "BAAI/bge-m3"
                print(f"Local model not found, falling back to: {model_name}")

        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        # Output-selection flags; only dense embeddings are produced here.
        self.return_dense: bool = True
        self.return_sparse: bool = False
        self.return_colbert_vecs: bool = False
        self.return_sparse_embedding: bool = False

        print(f"Loading BGE multilingual model on device: {self.device}")
        self.sentence_pooling_method = sentence_pooling_method
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(
            model_name, torch_dtype=torch.float16, device_map=self.device
        )
        self.vocab_size = self.model.config.vocab_size
        # Kept for API compatibility; scores are returned un-scaled.
        self.temperature = 1.0
        self.model.eval()

        # Populated by preprocess(): doc ids parallel to corpus_embeddings rows.
        self.corpus_ids = []
        self.corpus_embeddings = None

    def _dense_embedding(self, last_hidden_state, attention_mask):
        """Pool token states into one dense embedding per sequence.

        Args:
            last_hidden_state (torch.Tensor): (batch, seq, hidden) model output.
            attention_mask (torch.Tensor): (batch, seq) mask; padding tokens are 0.

        Raises:
            NotImplementedError: Unknown pooling method.

        Returns:
            torch.Tensor: (batch, hidden) pooled embeddings.
        """
        if self.sentence_pooling_method == "cls":
            # First token ([CLS]) representation.
            return last_hidden_state[:, 0]
        elif self.sentence_pooling_method == "mean":
            summed = torch.sum(
                last_hidden_state * attention_mask.unsqueeze(-1).float(), dim=1
            )
            # clamp guards against division by zero on an all-padding row.
            counts = attention_mask.sum(dim=1, keepdim=True).float().clamp(min=1e-9)
            return summed / counts
        elif self.sentence_pooling_method == "last_token":
            # If every row's final position is attended, the batch is
            # left-padded and the last column is the last real token.
            left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
            if left_padding:
                return last_hidden_state[:, -1]
            # Right padding: index each row at its last non-pad position.
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden_state.shape[0]
            return last_hidden_state[
                torch.arange(batch_size, device=last_hidden_state.device),
                sequence_lengths,
            ]
        else:
            raise NotImplementedError(f"pooling method {self.sentence_pooling_method} not implemented")

    def _compute_similarity(self, q_reps, p_reps):
        """Inner-product similarity between query and passage representations.

        Args:
            q_reps (torch.Tensor): Query representations.
            p_reps (torch.Tensor): Passage representations (2-D or batched 3-D).

        Returns:
            torch.Tensor: Similarity matrix.
        """
        if len(p_reps.size()) == 2:
            return torch.matmul(q_reps, p_reps.transpose(0, 1))
        return torch.matmul(q_reps, p_reps.transpose(-2, -1))

    def compute_dense_score(self, q_reps, p_reps):
        """Score queries against passages via inner product.

        Embeddings are L2-normalized in embed_texts, so this equals cosine
        similarity. (The original body had unreachable temperature-scaling
        code after the return; it has been removed — behavior is unchanged.)

        Args:
            q_reps (torch.Tensor): (num_queries, hidden) query embeddings.
            p_reps (torch.Tensor): (num_passages, hidden) passage embeddings.

        Returns:
            torch.Tensor: (num_queries, num_passages) score matrix.
        """
        return q_reps @ p_reps.T

    @torch.inference_mode()
    def embed_texts(
        self,
        texts,
        is_query=False,
        batch_size=64,
    ):
        """Embed texts into L2-normalized dense vectors on CPU.

        Note: no instruction prefix is added — BGE-m3 encodes queries and
        passages identically; is_query only controls progress logging.

        Args:
            texts (Sequence[str]): Texts to embed.
            is_query (bool): Suppresses batch-progress printing when True.
            batch_size (int): Texts per forward pass.

        Returns:
            torch.Tensor: (len(texts), hidden) tensor of normalized embeddings.
        """
        cleaned = [text.strip() for text in texts]

        chunks = []
        total_batches = (len(cleaned) + batch_size - 1) // batch_size

        for start in range(0, len(cleaned), batch_size):
            batch_num = start // batch_size + 1
            if not is_query and batch_num % 50 == 0:
                print(f"Processing batch {batch_num}/{total_batches} ({(batch_num/total_batches)*100:.1f}%)")
            if torch.cuda.is_available():
                # Release cached blocks between batches to limit fragmentation.
                torch.cuda.empty_cache()

            batch_texts = cleaned[start:start + batch_size]

            encoded = self.tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors='pt',
            ).to(self.device)

            model_output = self.model(**encoded)
            dense_vecs = self._dense_embedding(
                model_output.last_hidden_state, encoded['attention_mask']
            )
            dense_vecs = torch.nn.functional.normalize(dense_vecs, p=2, dim=1)
            chunks.append(dense_vecs.cpu())

        if not chunks:
            # Empty input: torch.cat([]) would raise, return an empty matrix.
            return torch.empty(0, self.model.config.hidden_size)

        return torch.cat(chunks, dim=0)
|
|
|
|
class BGEReranker:
    """Cross-encoder reranker based on a BGE reranker checkpoint.

    Scores (query, passage) pairs jointly for fine-grained relevance,
    used as stage 2 after dense retrieval.
    """

    def __init__(self, model_name=None, device=None):
        """Initialize the BGE reranker for fine-grained relevance scoring.

        Args:
            model_name (str | None): HF model id or local path. When None, a
                locally bundled checkpoint under ./models is used if present,
                otherwise the public bge-reranker-v2-m3 checkpoint.
            device (str | None): Torch device string; auto-detects CUDA when None.
        """
        if model_name is None:
            # Only the last of the original chained reassignments was live;
            # the earlier experiment names were dead code and are removed.
            model_suffix = "test_encoder_only_base_bge_reranker_v2_m3_new1"
            local_model_path = os.path.join(
                os.path.dirname(os.path.abspath(__file__)), 'models', model_suffix
            )
            if os.path.isdir(local_model_path):
                model_name = local_model_path
                print(f"Using local BGE model from: {model_name}")
            else:
                # Fall back to the public checkpoint rather than passing None
                # to from_pretrained (which would raise).
                model_name = "BAAI/bge-reranker-v2-m3"
                print(f"Local model not found, falling back to: {model_name}")

        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        print(f"Loading BGE reranker on device: {self.device}")

        # Imported locally so importing this module stays cheap when only
        # the retriever is used.
        from transformers import AutoModelForSequenceClassification
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            device_map=self.device,
        )
        self.model.eval()

    @torch.inference_mode()
    def rerank(self, query_text, passages, passage_ids, top_k=20, batch_size=32):
        """Rerank passages with the cross-encoder.

        Args:
            query_text (str): The query.
            passages (Sequence[str]): Candidate passage texts.
            passage_ids (Sequence): Ids parallel to `passages`.
            top_k (int): Number of results to return.
            batch_size (int): Pairs scored per forward pass.

        Returns:
            list[tuple]: (passage_id, score) pairs, best first, at most top_k.
        """
        if not passages:
            return []

        # Sort candidates by passage length so each batch pads similarly
        # sized inputs; ids stay paired with their passages.
        pairs = sorted(zip(passage_ids, passages), key=lambda x: len(x[1]))
        passage_ids, passages = zip(*pairs)

        scores = []
        for start in range(0, len(passages), batch_size):
            batch_passages = passages[start:start + batch_size]
            try:
                batch_queries = [query_text] * len(batch_passages)

                inputs = self.tokenizer(
                    batch_queries,
                    batch_passages,
                    padding=True,
                    truncation=True,
                    max_length=512,
                    return_tensors='pt'
                ).to(self.device)

                logits = self.model(**inputs).logits

                # Normalize the various classification-head shapes to one
                # scalar score per (query, passage) pair.
                if len(logits.shape) == 1:
                    batch_scores = logits.cpu().numpy()
                elif logits.shape[1] == 1:
                    # Single-logit regression-style head.
                    batch_scores = logits.squeeze(-1).cpu().numpy()
                else:
                    # Binary classifier: use the "relevant" class logit.
                    batch_scores = logits[:, 1].cpu().numpy()

                scores.extend(batch_scores.tolist())
            except Exception as e:
                # Deliberate best-effort fallback: neutral scores keep the
                # pipeline alive if one batch fails (e.g. CUDA OOM).
                print(f"Error in reranking batch {start//batch_size + 1}: {e}")
                scores.extend([0.5] * len(batch_passages))

        results = sorted(zip(passage_ids, scores), key=lambda x: x[1], reverse=True)
        return results[:top_k]
|
|
|
|
| |
# Module-level singletons shared by preprocess() and predict() so the heavy
# models are loaded only once per process.
retriever = None      # BGERetriever, assigned by preprocess()
reranker = None       # BGEReranker, assigned by preprocess()
corpus_texts = {}     # doc_id -> passage text, filled by preprocess()
|
|
|
|
def preprocess(corpus_dict):
    """Initialize the BGE retriever/reranker pipeline and embed the corpus.

    Args:
        corpus_dict: dict mapping document IDs to document dicts carrying a
            'passage' (or fallback 'text') field.

    Returns:
        dict with the initialized models, corpus ids/texts, the dense
        embedding matrix, and the document count.

    Note: results are also stored in module globals (retriever, reranker,
    corpus_texts) for efficiency, and returned so the function interface
    works without the globals.
    """
    global retriever, reranker, corpus_texts
    print("=" * 60)
    print("PREPROCESSING: Initializing BGE Reranker Pipeline...")
    print("=" * 60)

    # Reduce CUDA allocator fragmentation during the large embedding pass.
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

    print("Loading BGE retriever...")
    retriever = BGERetriever()

    print("Loading BGE reranker...")
    reranker = BGEReranker()

    print(f"Preparing corpus with {len(corpus_dict)} documents...")

    # Sort documents by passage length so each embedding batch contains
    # similarly sized inputs (less padding -> faster encoding).
    corpus_ids = list(corpus_dict.keys())
    passages = [doc.get('passage', doc.get('text', '')) for doc in corpus_dict.values()]
    sorted_pairs = sorted(zip(corpus_ids, passages), key=lambda pair: len(pair[1]))
    if sorted_pairs:
        retriever.corpus_ids, passages = zip(*sorted_pairs)
    else:
        # Empty corpus: zip(*[]) would raise ValueError, handle explicitly.
        retriever.corpus_ids, passages = (), ()

    # Keep the id -> text mapping for the reranking stage.
    corpus_texts = {doc_id: passages[i] for i, doc_id in enumerate(retriever.corpus_ids)}

    print("Computing BGE embeddings...")
    if passages:
        retriever.corpus_embeddings = retriever.embed_texts(passages, is_query=False, batch_size=64)
    else:
        # Placeholder so downstream shape logging and matmuls do not crash.
        retriever.corpus_embeddings = torch.empty(0, 0)

    print("✓ Corpus preprocessing complete!")
    print(f"✓ Generated embeddings for {len(retriever.corpus_ids)} documents")
    print(f"✓ Embedding matrix shape: {retriever.corpus_embeddings.shape}")

    return {
        'retriever': retriever,
        'reranker': reranker,
        'corpus_ids': retriever.corpus_ids,
        'corpus_embeddings': retriever.corpus_embeddings,
        'corpus_texts': corpus_texts,
        'num_documents': len(corpus_dict)
    }
|
|
|
|
def _dense_rank(query_text, top_k):
    """Stage-1 dense retrieval against the preprocessed corpus.

    Embeds the query, scores it against retriever.corpus_embeddings, and
    returns (top_indices, scores) where top_indices are the indices of the
    top_k highest-scoring passages (descending) and scores is the full
    1-D score array.
    """
    query_embedding = retriever.embed_texts([query_text], is_query=True, batch_size=1)
    dense_scores = retriever.compute_dense_score(query_embedding, retriever.corpus_embeddings)
    score_array = dense_scores.squeeze(0).numpy()
    top_indices = np.argsort(score_array)[::-1][:top_k]
    return top_indices, score_array


def predict(query, preprocessed_data):
    """Two-stage prediction: BGE dense retrieval + BGE cross-encoder reranking.

    Args:
        query: dict with a 'query' field containing the query text.
        preprocessed_data: dict from preprocess() with models and corpus data,
            used to repopulate the module globals if they are unset.

    Returns:
        list[dict]: entries with 'paragraph_uuid' and 'score', ranked by
        relevance (reranker scores; dense scores on reranking failure).
    """
    global retriever, reranker, corpus_texts

    query_text = query.get('query', '')
    if not query_text:
        return []

    # Recover state from preprocessed_data if the globals were not set
    # (e.g. predict() runs in a fresh process).
    if retriever is None:
        retriever = preprocessed_data.get('retriever')
        reranker = preprocessed_data.get('reranker')
        corpus_texts = preprocessed_data.get('corpus_texts', {})

    if retriever is None or reranker is None:
        print("Error: Missing retriever or reranker in preprocessed data")
        return []

    try:
        print("Stage 1: BGE retrieval...")
        # Retrieve a generous candidate pool for the reranker to refine.
        top_indices, _ = _dense_rank(query_text, 100)
        candidate_ids = [retriever.corpus_ids[idx] for idx in top_indices]
        candidate_passages = [corpus_texts.get(doc_id, '') for doc_id in candidate_ids]

        print("Stage 2: BGE reranking...")
        reranked_results = reranker.rerank(
            query_text,
            candidate_passages,
            candidate_ids,
            top_k=20,
            batch_size=16,
        )

        results = [
            {'paragraph_uuid': passage_id, 'score': float(rerank_score)}
            for passage_id, rerank_score in reranked_results
        ]
        print(f"✓ Returned {len(results)} results with reranker scores")
        return results

    except Exception as e:
        # Best-effort fallback: return dense-only ranking when reranking fails.
        print(f"Error in prediction: {e}")
        top_indices, dense_scores = _dense_rank(query_text, 20)
        return [
            {'paragraph_uuid': retriever.corpus_ids[idx], 'score': float(dense_scores[idx])}
            for idx in top_indices
        ]
|
|