Spaces:
Sleeping
Sleeping
| """ | |
| Search Module for Advanced RAG | |
| Handles hybrid search combining BM25 and semantic search with score fusion. | |
| """ | |
| import re | |
| import time | |
| import numpy as np | |
| from typing import List, Dict, Any | |
| from pathlib import Path | |
| from rank_bm25 import BM25Okapi | |
| from qdrant_client import QdrantClient | |
| from config.config import ( | |
| OUTPUT_DIR, TOP_K, SCORE_THRESHOLD, ENABLE_HYBRID_SEARCH, | |
| BM25_WEIGHT, SEMANTIC_WEIGHT, USE_TOTAL_BUDGET_APPROACH | |
| ) | |
class SearchManager:
    """Manage hybrid search that combines BM25 keyword scores with semantic scores.

    One on-disk Qdrant database is expected per document at
    ``OUTPUT_DIR/<doc_id>_collection.db`` holding a collection named
    ``<doc_id>_collection``. Clients, BM25 indexes and the chunk data backing
    those indexes are cached per ``doc_id`` until ``cleanup`` is called.
    """

    # Number of points requested per scroll page while building a BM25 corpus.
    _SCROLL_PAGE_SIZE = 10000
    # Everything that is neither a word character nor whitespace is a separator.
    _NON_WORD_RE = re.compile(r'[^\w\s]')

    def __init__(self, embedding_manager):
        """Initialize the search manager.

        Args:
            embedding_manager: Object exposing an awaitable ``encode_query(query)``
                that returns the query vector passed to Qdrant.
        """
        self.embedding_manager = embedding_manager
        self.base_db_path = Path(OUTPUT_DIR)
        self.qdrant_clients = {}  # doc_id -> QdrantClient
        self.bm25_indexes = {}  # doc_id -> BM25Okapi, or None when no index could be built
        self.document_chunks = {}  # doc_id -> chunk texts / ids / tokens / payloads for BM25
        print("β Search Manager initialized")

    def get_qdrant_client(self, doc_id: str) -> QdrantClient:
        """Return the cached Qdrant client for a document, creating it on first use.

        Raises:
            FileNotFoundError: If the document's database path does not exist.
        """
        if doc_id not in self.qdrant_clients:
            db_path = self.base_db_path / f"{doc_id}_collection.db"
            if not db_path.exists():
                raise FileNotFoundError(f"Database not found for document {doc_id}")
            self.qdrant_clients[doc_id] = QdrantClient(path=str(db_path))
        return self.qdrant_clients[doc_id]

    def _load_bm25_index(self, doc_id: str):
        """Build and cache the BM25 index for a document (no-op if already cached).

        All points of the document's collection are read page by page, so
        collections larger than one scroll page are indexed completely.

        When the collection yields no indexable tokens, or loading fails, ``None``
        is cached as the index so that BM25 is skipped for this document instead
        of being retried (or crashing) on every search.
        """
        if doc_id in self.bm25_indexes:
            return

        print(f"π Loading BM25 index for {doc_id}")
        client = self.get_qdrant_client(doc_id)
        collection_name = f"{doc_id}_collection"

        try:
            chunks = []
            chunk_ids = []
            payloads = []
            offset = None
            while True:
                # scroll() returns (points, next_page_offset); the offset is None
                # once the last page has been delivered.
                points, offset = client.scroll(
                    collection_name=collection_name,
                    limit=self._SCROLL_PAGE_SIZE,
                    offset=offset,
                    with_payload=True,
                    with_vectors=False,
                )
                for point in points:
                    payload = point.payload or {}
                    chunks.append(payload.get('text') or '')
                    chunk_ids.append(point.id)
                    payloads.append(payload)
                if offset is None:
                    break

            tokenized_chunks = [self._tokenize_text(chunk) for chunk in chunks]

            # BM25Okapi cannot be built from a corpus without any token
            # (rank_bm25 divides by the corpus size and the vocabulary size),
            # so only construct it when at least one chunk has tokens.
            if any(tokenized_chunks):
                self.bm25_indexes[doc_id] = BM25Okapi(tokenized_chunks)
            else:
                self.bm25_indexes[doc_id] = None
            self.document_chunks[doc_id] = {
                'chunks': chunks,
                'chunk_ids': chunk_ids,
                'tokenized_chunks': tokenized_chunks,
                'payloads': payloads,
            }
            print(f"β BM25 index loaded for {doc_id} with {len(chunks)} chunks")
        except Exception as e:
            print(f"β Error loading BM25 index for {doc_id}: {e}")
            # Fallback: no index; hybrid_search then relies on semantic results only.
            self.bm25_indexes[doc_id] = None
            self.document_chunks[doc_id] = {
                'chunks': [],
                'chunk_ids': [],
                'tokenized_chunks': [],
                'payloads': [],
            }

    def _tokenize_text(self, text: str) -> List[str]:
        """Tokenize text for BM25.

        The text is lowercased, punctuation is replaced by spaces, and tokens of
        one or two characters are dropped. Empty or missing text yields ``[]``.
        """
        if not text:
            return []
        cleaned = self._NON_WORD_RE.sub(' ', text.lower())
        return [token for token in cleaned.split() if len(token) > 2]

    @staticmethod
    def _new_candidate(payload) -> Dict[str, Any]:
        """Return a fresh candidate record with zeroed scores for ``payload``."""
        return {
            'semantic_score': 0,
            'bm25_score': 0,
            'payload': payload,
            'fusion_score': 0,
            'contributing_queries': [],
        }

    @staticmethod
    def _shorten(query: str, limit: int) -> str:
        """Return ``query`` cut to ``limit`` characters plus '...' when it is longer."""
        return query[:limit] + '...' if len(query) > limit else query

    async def _collect_semantic_hits(self, client, collection_name: str, query: str,
                                     query_idx: int, query_budget: int, search_limit: int,
                                     all_candidates: Dict) -> int:
        """Run the semantic search for one query and merge its hits into ``all_candidates``.

        Returns the number of hits accepted for this query. Failures are reported
        and swallowed so that one failing query does not abort the whole search.
        """
        accepted = 0
        try:
            query_vector = await self.embedding_manager.encode_query(query)
            semantic_results = client.search(
                collection_name=collection_name,
                query_vector=query_vector,
                limit=search_limit,
                score_threshold=SCORE_THRESHOLD,
            )
            for result in semantic_results:
                if USE_TOTAL_BUDGET_APPROACH and accepted >= query_budget:
                    break  # Respect budget limit
                point_id = str(result.id)
                semantic_score = float(result.score)
                candidate = all_candidates.get(point_id)
                if candidate is None:
                    candidate = self._new_candidate(result.payload)
                    all_candidates[point_id] = candidate
                # Keep the best score seen across queries, but record every
                # query that surfaced this point.
                if semantic_score > candidate['semantic_score']:
                    candidate['semantic_score'] = semantic_score
                candidate['contributing_queries'].append({
                    'query_idx': query_idx,
                    'query_text': self._shorten(query, 50),
                    'semantic_score': semantic_score,
                    'type': 'semantic',
                })
                accepted += 1
        except Exception as e:
            print(f"β οΈ Semantic search failed for query '{query[:50]}...': {e}")
        return accepted

    def _collect_bm25_hits(self, doc_id: str, query: str, query_idx: int,
                           query_budget: int, search_limit: int,
                           all_candidates: Dict) -> int:
        """Score one query against the cached BM25 index and merge its hits.

        Returns the number of hits accepted for this query; 0 when the document
        has no usable BM25 index. Only chunks with a positive BM25 score count.
        """
        index = self.bm25_indexes.get(doc_id)
        if index is None:
            return 0

        accepted = 0
        try:
            tokenized_query = self._tokenize_text(query)
            bm25_scores = index.get_scores(tokenized_query)
            chunk_data = self.document_chunks[doc_id]
            chunk_ids = chunk_data['chunk_ids']
            payloads = chunk_data.get('payloads')
            # Highest scores first, limited to the per-query search window.
            bm25_top_indices = np.argsort(bm25_scores)[::-1][:search_limit]
            for idx in bm25_top_indices:
                if USE_TOTAL_BUDGET_APPROACH and accepted >= query_budget:
                    break  # Respect budget limit
                if idx >= len(chunk_ids) or bm25_scores[idx] <= 0:
                    continue
                point_id = str(chunk_ids[idx])
                bm25_score = float(bm25_scores[idx])
                candidate = all_candidates.get(point_id)
                if candidate is None:
                    if payloads:
                        # Copy the stored payload so BM25-only hits carry the same
                        # metadata as semantic hits without exposing the cache.
                        payload = dict(payloads[idx])
                    else:
                        payload = {'text': chunk_data['chunks'][idx]}
                    candidate = self._new_candidate(payload)
                    all_candidates[point_id] = candidate
                # Keep the best score seen across queries, but record every
                # query that surfaced this point.
                if bm25_score > candidate['bm25_score']:
                    candidate['bm25_score'] = bm25_score
                candidate['contributing_queries'].append({
                    'query_idx': query_idx,
                    'query_text': self._shorten(query, 50),
                    'bm25_score': bm25_score,
                    'type': 'bm25',
                })
                accepted += 1
        except Exception as e:
            print(f"β οΈ BM25 search failed for query '{query[:50]}...': {e}")
        return accepted

    async def hybrid_search(self, queries: List[str], doc_id: str, top_k: int = TOP_K) -> List[Dict]:
        """Perform hybrid search combining BM25 and semantic search.

        Every query is searched semantically and, when ENABLE_HYBRID_SEARCH is
        set, with BM25. Hits are merged per point id (best score per modality),
        fused into a single score and the ``top_k`` best are returned.

        With USE_TOTAL_BUDGET_APPROACH and more than one query, ``top_k`` is
        split across the queries (the remainder goes to the first queries);
        otherwise every query may contribute up to ``top_k`` candidates.

        Args:
            queries: Sub-queries to search with.
            doc_id: Document whose collection is searched.
            top_k: Maximum number of results to return.

        Returns:
            Dicts with keys ``id``, ``score`` (fusion score), ``payload``,
            ``semantic_score``, ``bm25_score`` and ``contributing_queries``,
            sorted by descending fusion score.

        Raises:
            FileNotFoundError: If the document's database does not exist.
        """
        collection_name = f"{doc_id}_collection"
        client = self.get_qdrant_client(doc_id)

        # Ensure BM25 index is loaded
        if doc_id not in self.bm25_indexes:
            self._load_bm25_index(doc_id)

        # Calculate per-query budget based on approach
        if USE_TOTAL_BUDGET_APPROACH and len(queries) > 1:
            per_query_budget = max(1, top_k // len(queries))
            extra_budget = top_k % len(queries)  # Distribute remaining budget
            print(f"π― Total Budget Approach: Distributing {top_k} candidates across {len(queries)} queries")
            print(f" π Base budget per query: {per_query_budget}")
            if extra_budget > 0:
                print(f" β Extra budget for first {extra_budget} queries: +1 each")
        else:
            per_query_budget = top_k
            extra_budget = 0
            print(f"π Per-Query Approach: Each query gets {per_query_budget} candidates")

        all_candidates = {}  # point_id -> candidate record (see _new_candidate)
        query_performance = {}  # query index -> timing / candidate statistics

        print(f"π Running hybrid search with {len(queries)} focused queries...")
        for query_idx, query in enumerate(queries):
            query_start = time.time()

            # extra_budget is 0 outside the total-budget approach, so this
            # reduces to per_query_budget there.
            query_budget = per_query_budget + (1 if query_idx < extra_budget else 0)
            search_limit = query_budget * 2  # Get extra for better selection
            print(f" Q{query_idx+1} Budget: {query_budget} candidates (searching {search_limit})")

            # 1. Semantic search (always performed)
            query_candidates = await self._collect_semantic_hits(
                client, collection_name, query, query_idx,
                query_budget, search_limit, all_candidates,
            )

            # 2. BM25 search (if enabled)
            if ENABLE_HYBRID_SEARCH:
                query_candidates += self._collect_bm25_hits(
                    doc_id, query, query_idx, query_budget, search_limit, all_candidates,
                )

            # Track query performance with budget info
            query_performance[query_idx] = {
                'query': self._shorten(query, 80),
                'candidates_found': query_candidates,
                'budget_allocated': query_budget if USE_TOTAL_BUDGET_APPROACH else 'unlimited',
                'time': time.time() - query_start,
            }

        # 3. Score fusion
        self._apply_score_fusion(all_candidates)

        # 4. Sort by fusion score and return top results
        sorted_candidates = sorted(
            all_candidates.items(),
            key=lambda item: item[1]['fusion_score'],
            reverse=True,
        )
        hybrid_results = [
            {
                'id': point_id,
                'score': data['fusion_score'],
                'payload': data['payload'],
                'semantic_score': data['semantic_score'],
                'bm25_score': data['bm25_score'],
                'contributing_queries': data['contributing_queries'],
            }
            for point_id, data in sorted_candidates[:top_k]
        ]

        # Log performance summary
        approach_name = "Total Budget" if USE_TOTAL_BUDGET_APPROACH else "Per-Query"
        print(f"π Hybrid search completed ({approach_name} Approach):")
        print(f" π {len(all_candidates)} total candidates from {len(queries)} focused queries")
        print(f" π― Top {len(hybrid_results)} results selected")

        # Log per-query performance with budget info
        total_budget_used = 0
        for idx, perf in query_performance.items():
            budget_info = f" (budget: {perf['budget_allocated']})" if USE_TOTAL_BUDGET_APPROACH else ""
            print(f" Q{idx+1}: {perf['candidates_found']} candidates{budget_info} in {perf['time']:.3f}s")
            print(f" Query: {perf['query']}")
            if USE_TOTAL_BUDGET_APPROACH and isinstance(perf['budget_allocated'], int):
                total_budget_used += perf['candidates_found']
        if USE_TOTAL_BUDGET_APPROACH:
            print(f" π° Total budget efficiency: {total_budget_used}/{top_k} candidates used")

        return hybrid_results

    @staticmethod
    def _min_max_normalize(score, low, high):
        """Min-max normalize a positive ``score`` into [0, 1]; non-hits map to 0.

        When all hits share the same score (``high <= low``, which includes the
        single-hit case) each hit maps to 1.0, so the modality still contributes
        its weight instead of being zeroed out.
        """
        if score <= 0:
            return 0
        if high <= low:
            return 1.0
        return (score - low) / (high - low)

    def _apply_score_fusion(self, candidates: Dict):
        """Compute ``fusion_score`` in place for every candidate.

        Semantic and BM25 scores are min-max normalized over the candidates that
        have a positive score in that modality, then combined with
        SEMANTIC_WEIGHT / BM25_WEIGHT (semantic only when hybrid search is
        disabled), plus a small score-dependent bonus.
        """
        if not candidates:
            return

        semantic_scores = [data['semantic_score'] for data in candidates.values() if data['semantic_score'] > 0]
        bm25_scores = [data['bm25_score'] for data in candidates.values() if data['bm25_score'] > 0]

        sem_min, sem_max = (min(semantic_scores), max(semantic_scores)) if semantic_scores else (0, 0)
        bm25_min, bm25_max = (min(bm25_scores), max(bm25_scores)) if bm25_scores else (0, 0)

        for data in candidates.values():
            norm_semantic = self._min_max_normalize(data['semantic_score'], sem_min, sem_max)
            norm_bm25 = self._min_max_normalize(data['bm25_score'], bm25_min, bm25_max)

            # Weighted combination
            if ENABLE_HYBRID_SEARCH:
                fusion_score = (SEMANTIC_WEIGHT * norm_semantic) + (BM25_WEIGHT * norm_bm25)
            else:
                fusion_score = norm_semantic

            # NOTE(review): originally labelled a "reciprocal rank fusion bonus",
            # but this term is not rank-based: it shrinks as the larger normalized
            # score grows (0.1 at 0, about 0.009 at 1). Formula kept unchanged;
            # confirm whether a true rank-based bonus was intended.
            rank_bonus = 1.0 / (1.0 + max(norm_semantic, norm_bm25) * 10)
            fusion_score += rank_bonus * 0.1

            data['fusion_score'] = fusion_score

    def cleanup(self):
        """Close all Qdrant clients and drop every cached index and chunk list.

        Closing is best-effort: a client that fails to close is reported and the
        remaining clients are still closed.
        """
        print("π§Ή Cleaning up Search Manager resources...")
        for doc_id, client in self.qdrant_clients.items():
            try:
                client.close()
            except Exception as e:
                print(f"Failed to close Qdrant client for {doc_id}: {e}")
        self.qdrant_clients.clear()
        self.bm25_indexes.clear()
        self.document_chunks.clear()
        print("β Search Manager cleanup completed")