Spaces:
Build error
Build error
| import numpy as np | |
| import logging, torch | |
| from sklearn.preprocessing import MinMaxScaler | |
| from sentence_transformers import CrossEncoder | |
| # from FlagEmbedding import FlagReranker | |
class Hybrid_search:
    """Hybrid retriever that blends BM25 keyword scores with FAISS vector distances.

    Both back-ends are queried, their scores are aligned on a common list of
    document IDs, min-max normalised, and combined into a single weighted
    score. Results below a threshold are discarded and the best ``top_n``
    are returned.

    The collaborators are duck-typed; this class uses:

    * ``bm25_search.get_scores(keywords)``, ``bm25_search.doc_ids`` and
      (only for re-ranking) ``bm25_search.get_document(doc_id)``
    * ``faiss_search.search(query, k=...)`` returning ``(distances, indices)``
      and ``faiss_search.doc_ids``
    """

    def __init__(self, bm25_search, faiss_search, reranker_model_name="BAAI/bge-reranker-v2-gemma", initial_bm25_weight=0.5):
        """Store the two search back-ends and the starting BM25 weight.

        Args:
            bm25_search: Keyword search back-end (see class docstring).
            faiss_search: Vector search back-end (see class docstring).
            reranker_model_name: Currently unused; no re-ranker is loaded.
            initial_bm25_weight: Weight of the BM25 component until the first
                call to ``advanced_search``, which overwrites it via
                ``_dynamic_weighting``.
        """
        self.bm25_search = bm25_search
        self.faiss_search = faiss_search
        self.bm25_weight = initial_bm25_weight
        # No re-ranker is loaded. Assign an object exposing
        # ``compute_score(pairs, normalize=True)`` to enable ``_rerank_results``.
        self.reranker = None
        self.logger = logging.getLogger(__name__)

    async def advanced_search(self, query, keywords, top_n=5, threshold=0.53, prefixes=None):
        """Run the hybrid search and return the best matching documents.

        Args:
            query: Free-text query passed to the FAISS back-end; its word
                count also selects the BM25 weight.
            keywords: Iterable of keyword strings, joined with spaces and
                passed to the BM25 back-end.
            top_n: Number of candidates fetched from each back-end and the
                maximum number of results returned.
            threshold: Minimum hybrid score a document needs to be returned.
            prefixes: Optional iterable of doc-ID prefixes; when given, only
                documents whose ID starts with one of them are considered.

        Returns:
            A list of ``(doc_id, hybrid_score)`` tuples sorted by descending
            score; empty when nothing matches.
        """
        self._dynamic_weighting(len(query.split()))
        keywords = " ".join(keywords)
        self.logger.info("Query: %s", query)
        self.logger.info("Keywords: %s", keywords)

        bm25_scores, bm25_doc_ids = self._get_bm25_results(keywords, top_n=top_n)

        # Query FAISS exactly once. If it fails, degrade to BM25-only results
        # instead of aborting the whole search.
        try:
            faiss_distances, _, faiss_doc_ids = self._get_faiss_results(query, top_n=top_n)
        except Exception as e:
            self.logger.error("Search failed: %s", e)
            faiss_distances, faiss_doc_ids = np.array([]), []

        bm25_scores_dict, faiss_scores_dict = self._map_scores_to_doc_ids(
            bm25_doc_ids, bm25_scores, faiss_doc_ids, faiss_distances
        )

        # Union of both candidate sets, sorted so the output order is stable.
        all_doc_ids = sorted(set(bm25_doc_ids).union(faiss_doc_ids))
        filtered_doc_ids = self._filter_doc_ids_by_prefixes(all_doc_ids, prefixes)
        if not filtered_doc_ids:
            self.logger.info("No documents match the prefixes.")
            return []

        filtered_bm25_scores, filtered_faiss_scores = self._get_filtered_scores(
            filtered_doc_ids, bm25_scores_dict, faiss_scores_dict
        )
        bm25_scores_normalized, faiss_scores_normalized = self._normalize_scores(
            filtered_bm25_scores, filtered_faiss_scores
        )
        hybrid_scores = self._calculate_hybrid_scores(bm25_scores_normalized, faiss_scores_normalized)

        for doc_id, score in zip(filtered_doc_ids, hybrid_scores):
            self.logger.debug("Hybrid Score: %.4f, Doc ID: %s", score, doc_id)

        results = self._get_top_n_results(filtered_doc_ids, hybrid_scores, top_n, threshold)
        self.logger.info("Results before reranking: %s", results)
        # Re-ranking is intentionally not applied here; see ``_rerank_results``.
        return results

    def _dynamic_weighting(self, query_length):
        """Set the BM25 weight from the query's word count.

        Queries of five words or fewer use 0.7; longer queries use 0.5.
        """
        self.bm25_weight = 0.7 if query_length <= 5 else 0.5
        self.logger.info("Dynamic BM25 weight set to: %s", self.bm25_weight)

    def _get_bm25_results(self, keywords, top_n: int = None):
        """Return the ``top_n`` best BM25 scores and their document IDs.

        Args:
            keywords: Value forwarded to ``bm25_search.get_scores``.
            top_n: Number of documents to keep; ``None`` keeps all of them,
                zero or a negative value keeps none.

        Returns:
            ``(scores, doc_ids)`` as NumPy arrays ordered by descending score.
        """
        bm25_scores = np.asarray(self.bm25_search.get_scores(keywords))
        bm25_doc_ids = np.asarray(self.bm25_search.doc_ids)

        # Best first. Slicing from the front avoids the ``[-0:]`` pitfall,
        # which would return every document for ``top_n == 0``.
        order = np.argsort(bm25_scores)[::-1]
        if top_n is not None:
            order = order[:max(top_n, 0)]
        return bm25_scores[order], bm25_doc_ids[order]

    def _get_faiss_results(self, query, top_n: int = None) -> tuple[np.ndarray, np.ndarray, list[str]]:
        """Query the FAISS back-end and map the hits to document IDs.

        Args:
            query: Value forwarded to ``faiss_search.search``.
            top_n: Number of neighbours to request; ``None`` requests as many
                as there are entries in ``faiss_search.doc_ids``.

        Returns:
            ``(distances, indices, doc_ids)``, flattened and aligned
            element-by-element. Hits whose index is negative or outside
            ``faiss_search.doc_ids`` are dropped from all three.

        Raises:
            Exception: Any error from the back-end is logged and re-raised.
        """
        try:
            known_doc_ids = self.faiss_search.doc_ids
            if top_n is None:
                top_n = len(known_doc_ids)

            distances, indices = self.faiss_search.search(query, k=top_n)
            distances = np.asarray(distances)
            indices = np.asarray(indices)
            if distances.size == 0 or indices.size == 0:
                self.logger.info("FAISS search returned no results.")
                return np.array([]), np.array([]), []

            # One mask for both arrays keeps distances and doc IDs aligned;
            # boolean indexing also flattens a (1, k) result to 1-D.
            valid = (indices >= 0) & (indices < len(known_doc_ids))
            kept_distances = distances[valid]
            kept_indices = indices[valid]
            doc_ids = [known_doc_ids[idx] for idx in kept_indices]
            return kept_distances, kept_indices, doc_ids
        except Exception as e:
            self.logger.error("Error in FAISS search: %s", e)
            raise

    def _map_scores_to_doc_ids(self, bm25_doc_ids, bm25_scores, faiss_doc_ids, faiss_scores):
        """Build ``doc_id -> score`` dictionaries for both back-ends.

        If a doc ID occurs more than once, the last score wins.
        """
        bm25_scores_dict = dict(zip(bm25_doc_ids, bm25_scores))
        faiss_scores_dict = dict(zip(faiss_doc_ids, faiss_scores))
        return bm25_scores_dict, faiss_scores_dict

    def _filter_doc_ids_by_prefixes(self, all_doc_ids, prefixes):
        """Keep only doc IDs starting with one of ``prefixes``.

        A falsy ``prefixes`` (``None`` or empty) disables filtering. A new
        list is returned in either case.
        """
        if not prefixes:
            return list(all_doc_ids)
        wanted = tuple(prefixes)  # str.startswith accepts a tuple of options
        return [doc_id for doc_id in all_doc_ids if doc_id.startswith(wanted)]

    def _get_filtered_scores(self, filtered_doc_ids, bm25_scores_dict, faiss_scores_dict):
        """Align both score sets on ``filtered_doc_ids``.

        Documents missing from BM25 get a score of 0. Documents missing from
        FAISS are treated as one unit farther away than the worst FAISS hit.
        Distances are converted to similarities with ``1 / (1 + distance)``
        so that an exact match (distance 0) gets the highest similarity, 1.0.

        NOTE(review): this assumes the FAISS values are non-negative
        distances (smaller is better); negative values are clamped to 0.

        Returns:
            ``(bm25_scores, faiss_similarities)`` as lists in the order of
            ``filtered_doc_ids``. When ``faiss_scores_dict`` is empty every
            FAISS similarity is 0.0.
        """
        bm25_aligned_scores = [bm25_scores_dict.get(doc_id, 0) for doc_id in filtered_doc_ids]

        if not faiss_scores_dict:
            # No FAISS information at all: contribute a constant, which the
            # normalisation step turns into a neutral value.
            return bm25_aligned_scores, [0.0] * len(filtered_doc_ids)

        # Computed once rather than once per document.
        missing_distance = max(faiss_scores_dict.values()) + 1
        faiss_aligned_scores = [
            1.0 / (1.0 + max(float(faiss_scores_dict.get(doc_id, missing_distance)), 0.0))
            for doc_id in filtered_doc_ids
        ]
        return bm25_aligned_scores, faiss_aligned_scores

    def _normalize_scores(self, filtered_bm25_scores, filtered_faiss_scores):
        """Min-max normalise both score lists independently."""
        bm25_scores_normalized = self._normalize_array(filtered_bm25_scores, MinMaxScaler())
        faiss_scores_normalized = self._normalize_array(filtered_faiss_scores, MinMaxScaler())
        return bm25_scores_normalized, faiss_scores_normalized

    def _normalize_array(self, scores, scaler):
        """Scale ``scores`` with ``scaler.fit_transform`` and flatten the result.

        Args:
            scores: Sequence of numeric scores.
            scaler: Object with a scikit-learn style ``fit_transform`` method.

        Returns:
            A 1-D float array. When all scores are identical, every entry is
            0.5; an empty input yields an empty array.
        """
        scores_array = np.asarray(scores)
        if scores_array.size == 0:
            # np.ptp raises on empty input.
            return np.array([], dtype=float)
        if np.ptp(scores_array) > 0:
            return scaler.fit_transform(scores_array.reshape(-1, 1)).flatten()
        # Identical scores carry no ranking information: use a neutral 0.5.
        return np.full_like(scores_array, 0.5, dtype=float)

    def _calculate_hybrid_scores(self, bm25_scores_normalized, faiss_scores_normalized):
        """Blend both normalised score arrays using ``self.bm25_weight``."""
        faiss_weight = 1 - self.bm25_weight
        return self.bm25_weight * bm25_scores_normalized + faiss_weight * faiss_scores_normalized

    def _get_top_n_results(self, filtered_doc_ids, hybrid_scores, top_n, threshold):
        """Return up to ``top_n`` documents whose score is at least ``threshold``.

        Returns:
            A list of ``(doc_id, score)`` tuples ordered by descending score;
            empty when no score reaches the threshold.
        """
        hybrid_scores = np.asarray(hybrid_scores)
        passing = np.flatnonzero(hybrid_scores >= threshold)
        if passing.size == 0:
            self.logger.info("No documents meet the threshold.")
            return []

        ranked = passing[np.argsort(hybrid_scores[passing])[::-1]]
        results = [(filtered_doc_ids[idx], hybrid_scores[idx]) for idx in ranked[:top_n]]
        self.logger.info("Top %s results: %s", top_n, results)
        return results

    def _rerank_results(self, query, results):
        """Re-rank retrieved documents with ``self.reranker``.

        Args:
            query: The search query.
            results: A list of ``(doc_id, score)`` tuples.

        Returns:
            A list of ``(doc_id, rerank_score)`` tuples sorted by descending
            re-rank score. When no re-ranker is configured or ``results`` is
            empty, a copy of ``results`` is returned unchanged.
        """
        reranker = getattr(self, "reranker", None)
        if reranker is None or not results:
            return list(results)

        doc_ids = [doc_id for doc_id, _ in results]
        # Pair the query with each document's text for the re-ranker.
        rerank_inputs = [[query, self.bm25_search.get_document(doc_id)] for doc_id in doc_ids]

        with torch.no_grad():
            rerank_scores = reranker.compute_score(rerank_inputs, normalize=True)

        return sorted(zip(doc_ids, rerank_scores), key=lambda pair: pair[1], reverse=True)