Spaces:
Sleeping
Sleeping
| import faiss | |
| import pickle | |
| import numpy as np | |
| from typing import List, Dict, Any, Tuple, Optional | |
| import sys | |
| import os | |
| from src.utils.config import TOP_K, FAISS_INDEX_PATH, DOC_CHUNKS_PATH | |
# Resolve TextEmbedder: prefer the packaged module, otherwise fall back to a
# top-level ``embedder`` module made importable by extending sys.path.
try:
    from src.embeddings.embedder import TextEmbedder
except ImportError:
    try:
        import sys
        import os

        # Append the directory three levels above this file's absolute path
        # so that ``import embedder`` can be resolved from there.
        sys.path.append(
            os.path.dirname(
                os.path.dirname(
                    os.path.dirname(os.path.abspath(__file__))
                )
            )
        )
        from embedder import TextEmbedder
        print("Using local copy of embedder.py")
    except ImportError as e:
        # Best effort: the failure is only reported here. TextEmbedder stays
        # unbound, so code that needs it will fail later when it is used.
        print(f"Error importing TextEmbedder: {e}")
# Simple resource manager to avoid circular imports
class SimpleResourceManager:
    """
    Singleton that holds the shared FAISS index and document chunks.

    Every call to ``SimpleResourceManager()`` returns the same object, so
    resources stored on it by one caller are visible to all others.

    Attributes:
        faiss_index: The cached FAISS index, or None until one is stored.
        doc_chunks: The cached document chunks, or None until stored.
        initialized: False on creation; callers set it once resources are stored.
    """

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            # Fully initialise the object before publishing it on the class,
            # so the shared instance is never observable without its attributes.
            instance = super().__new__(cls)
            instance.faiss_index = None
            instance.doc_chunks = None
            instance.initialized = False
            cls._instance = instance
        return cls._instance

    def get_faiss_index(self):
        """Return the cached FAISS index (None if nothing has been stored)."""
        return self.faiss_index

    def get_doc_chunks(self):
        """Return the cached document chunks (None if nothing has been stored)."""
        return self.doc_chunks


# Create a local resource manager
resource_manager = SimpleResourceManager()
class Retriever:
    """
    Handles retrieval of relevant document chunks using FAISS vector search.

    The FAISS index and the document chunks are taken from the module-level
    ``resource_manager`` when it already holds them; otherwise they are read
    from disk and stored on the resource manager for later instances.
    """

    def __init__(self,
                 index_path: str = FAISS_INDEX_PATH,
                 chunks_path: str = DOC_CHUNKS_PATH,
                 top_k: int = TOP_K):
        """
        Initialize the retriever with paths to the FAISS index and document chunks.

        Args:
            index_path: Path to the FAISS index file
            chunks_path: Path to the pickled document chunks
            top_k: Number of chunks to retrieve for a query

        Raises:
            Exception: Whatever ``_load_resources`` re-raises when the index
                or the chunks cannot be loaded from disk.
        """
        self.index_path = index_path
        self.chunks_path = chunks_path
        self.top_k = top_k
        self.index = None
        self.doc_chunks = None
        self.embedder = TextEmbedder()

        # Try to get resources from the resource manager first.
        # NOTE(review): cached resources are used regardless of index_path /
        # chunks_path, so custom paths only take effect on the first load.
        self.index = resource_manager.get_faiss_index()
        self.doc_chunks = resource_manager.get_doc_chunks()

        if self.index is None or self.doc_chunks is None:
            # Not available in the resource manager: load directly. The
            # loader also runs the embedder compatibility check.
            self._load_resources()
        else:
            # Cached resources bypass _load_resources, so this freshly created
            # embedder still has to be aligned with the index dimension.
            self._ensure_embedder_compatibility()

    def _load_resources(self) -> None:
        """
        Load the FAISS index and document chunks from disk.

        On success the loaded objects are also stored on ``resource_manager``
        and the embedder dimension is checked against the index.

        Raises:
            Exception: Any error raised while loading is printed together
                with its traceback and then re-raised unchanged.
        """
        try:
            print(f"Loading FAISS index from {self.index_path}...")
            self.index = faiss.read_index(self.index_path)

            print(f"Loading document chunks from {self.chunks_path}...")
            # SECURITY: pickle.load can execute arbitrary code; chunks_path
            # must only ever point at a trusted file.
            with open(self.chunks_path, "rb") as f:
                self.doc_chunks = pickle.load(f)
            print(f"Resources loaded: {len(self.doc_chunks)} document chunks available.")

            # Update the resource manager with our loaded resources
            resource_manager.faiss_index = self.index
            resource_manager.doc_chunks = self.doc_chunks
            resource_manager.initialized = True

            # Ensure embedder dimension matches FAISS index
            self._ensure_embedder_compatibility()
        except Exception as e:
            print(f"Error loading resources: {e}")
            import traceback
            traceback.print_exc()
            raise

    def _ensure_embedder_compatibility(self) -> None:
        """
        Ensure the embedder's dimension matches the FAISS index dimension.

        Does nothing when no index is loaded or when the embedder does not
        expose ``set_dimension``. Otherwise, if ``embedder.embedding_dim``
        differs from ``index.d``, the embedder is switched to ``index.d``.
        """
        if self.index is None or not hasattr(self.embedder, 'set_dimension'):
            return

        faiss_dim = self.index.d
        embedder_dim = self.embedder.embedding_dim
        if faiss_dim != embedder_dim:
            print(f"Warning: Dimension mismatch between FAISS index ({faiss_dim}) and embedder ({embedder_dim})")
            print("Adjusting embedder dimension to match FAISS index")
            self.embedder.set_dimension(faiss_dim)

    @staticmethod
    def _build_chunk_info(chunk: Any, idx: int, score: float) -> Dict[str, Any]:
        """Return a result dict for ``chunk`` with score, text, source and chunk_id filled in."""
        # Copy dict chunks so the stored originals are never mutated;
        # wrap anything else as the text of a new record.
        chunk_info = chunk.copy() if isinstance(chunk, dict) else {"text": chunk}
        chunk_info['score'] = score

        # Ensure basic required fields exist
        if 'text' not in chunk_info and 'chunk' in chunk_info:
            chunk_info['text'] = chunk_info['chunk']
        chunk_info.setdefault('source', f"source_{idx}")
        chunk_info.setdefault('chunk_id', idx)
        return chunk_info

    def _print_search_diagnostics(self, query_embedding: Any) -> None:
        """Print what is known about the index and the query embedding after a failed search."""
        try:
            if self.index is None:
                print("FAISS index is None - index was not loaded properly")
            else:
                print(f"FAISS index dimension: {self.index.d}, total vectors: {self.index.ntotal}")

            if query_embedding is None:
                print("Query embedding is None")
            else:
                print(f"Query embedding shape: {query_embedding.shape}, dtype: {query_embedding.dtype}")
                # Only compare dimensions when there is an index to compare with.
                # Assumes a 2-D embedding (rows, dim) - TODO confirm against the embedder.
                if self.index is not None and query_embedding.shape[1] != self.index.d:
                    print(f"Dimension mismatch: query embedding ({query_embedding.shape[1]}) vs. FAISS index ({self.index.d})")
        except Exception as diagnostic_e:
            # Diagnostics are best effort and must never mask the fallback.
            print(f"Error during diagnostics: {diagnostic_e}")

    def retrieve(self, query: str, top_k: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Retrieve the most relevant document chunks for a query.

        Args:
            query: The search query
            top_k: Number of chunks to retrieve (overrides instance default if provided)

        Returns:
            List of the most relevant document chunks with their metadata.
            Each entry is a dict carrying at least ``score``, ``source`` and
            ``chunk_id``; ``score`` is the value FAISS returned for that hit.
            If the search fails, or no hit maps to a stored chunk, all stored
            chunks are returned with placeholder scores instead.
        """
        if top_k is None:
            top_k = self.top_k

        # Adjust top_k if we have fewer chunks than requested
        if self.doc_chunks and len(self.doc_chunks) < top_k:
            top_k = len(self.doc_chunks)
            print(f"Adjusted top_k to {top_k} based on available chunks")

        # Get the query embedding
        query_embedding = self.embedder.get_query_embedding(query)

        # Search the FAISS index
        try:
            print(f"FAISS index info - ntotal: {self.index.ntotal}, dimension: {self.index.d}")
            print(f"Query embedding shape: {query_embedding.shape}")
            distances, indices = self.index.search(query_embedding, top_k)

            # Log first few results for debugging
            top_indices = indices[0][:3]
            top_distances = distances[0][:3]
            print(f"Top 3 results - indices: {top_indices}, distances: {top_distances}")
        except Exception as e:
            print(f"Error during FAISS search: {e}")
            import traceback
            traceback.print_exc()

            # Provide diagnostic information
            self._print_search_diagnostics(query_embedding)

            # Return all available chunks as fallback
            return self._get_all_chunks_with_placeholder_scores()

        # Collect the retrieved chunks
        chunks = [] if self.doc_chunks is None else self.doc_chunks
        retrieved_chunks = []
        for rank, raw_idx in enumerate(indices[0]):
            # Convert the NumPy integer to a plain int so chunk_id and the
            # default source label stay JSON-serialisable.
            idx = int(raw_idx)

            # FAISS pads missing hits with -1 when it finds fewer than top_k
            # neighbours; a negative index would otherwise wrap around and
            # return the last chunk as a bogus result.
            if not 0 <= idx < len(chunks):
                continue

            try:
                retrieved_chunks.append(
                    self._build_chunk_info(chunks[idx], idx, float(distances[0][rank]))
                )
            except Exception as e:
                print(f"Error processing chunk at index {idx}: {e}")

        # If we couldn't retrieve any chunks, return fallback chunks
        if not retrieved_chunks:
            print("No chunks could be retrieved, using fallback")
            return self._get_all_chunks_with_placeholder_scores()

        return retrieved_chunks

    def _get_all_chunks_with_placeholder_scores(self) -> List[Dict[str, Any]]:
        """
        Return all available chunks with placeholder scores as fallback.

        Scores start at 1.0 and decrease by 0.1 per chunk position; they do
        not come from a search. Returns an empty list when no chunks are loaded.
        """
        if self.doc_chunks is None:
            return []

        fallback_chunks = []
        for idx, chunk in enumerate(self.doc_chunks):
            try:
                # Placeholder decreasing scores
                fallback_chunks.append(
                    self._build_chunk_info(chunk, idx, 1.0 - (idx * 0.1))
                )
            except Exception as e:
                print(f"Error creating fallback chunk at index {idx}: {e}")
        return fallback_chunks

    def get_formatted_context(self, retrieved_chunks: List[Dict[str, Any]]) -> str:
        """
        Format the retrieved chunks into a context string for the LLM.

        Args:
            retrieved_chunks: List of retrieved document chunks

        Returns:
            Formatted context string: one ``[source - chunk id]:`` header
            followed by the chunk text per chunk, separated by blank lines.
            Empty string when there are no chunks.
        """
        if not retrieved_chunks:
            return ""

        formatted_chunks = []
        for chunk in retrieved_chunks:
            try:
                # Get the chunk text (might be in 'text' or 'chunk' field)
                chunk_text = chunk.get('text', chunk.get('chunk', "No text available"))

                # Create a header with available metadata
                source = chunk.get('source', 'unknown_source')
                chunk_id = chunk.get('chunk_id', 'unknown_id')
                header = f"[{source} - chunk {chunk_id}]"

                formatted_chunks.append(f"{header}:\n{chunk_text}")
            except Exception as e:
                # Skip entries that cannot be formatted rather than fail the whole context.
                print(f"Error formatting chunk: {e}")
        return "\n\n".join(formatted_chunks)