import faiss
import pickle
import numpy as np
from typing import List, Dict, Any, Tuple, Optional
import sys
import os
import traceback

from src.utils.config import TOP_K, FAISS_INDEX_PATH, DOC_CHUNKS_PATH

# Try to import from the proper location, otherwise use the local copy
try:
    from src.embeddings.embedder import TextEmbedder
except ImportError:
    try:
        # Fall back to an embedder.py in the directory two levels above the one
        # containing this module.
        sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
        from embedder import TextEmbedder
        print("Using local copy of embedder.py")
    except ImportError as e:
        print(f"Error importing TextEmbedder: {e}")
        # Keep the name bound so Retriever() can raise a clear ImportError
        # instead of an opaque NameError.
        TextEmbedder = None


# Simple resource manager to avoid circular imports
class SimpleResourceManager:
    """Singleton cache for the loaded FAISS index and document chunks.

    Every call to ``SimpleResourceManager()`` returns the same instance, so all
    Retriever objects share whatever was stored on it. No locking is performed.
    """

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(SimpleResourceManager, cls).__new__(cls)
            cls._instance.faiss_index = None
            cls._instance.doc_chunks = None
            cls._instance.initialized = False
        return cls._instance

    def get_faiss_index(self):
        """Return the cached FAISS index, or None if nothing has been cached yet."""
        return self.faiss_index

    def get_doc_chunks(self):
        """Return the cached document chunks, or None if nothing has been cached yet."""
        return self.doc_chunks


# Create a local resource manager
resource_manager = SimpleResourceManager()


class Retriever:
    """
    Handles retrieval of relevant document chunks using FAISS vector search.
    """

    def __init__(self, index_path: str = FAISS_INDEX_PATH, chunks_path: str = DOC_CHUNKS_PATH, top_k: int = TOP_K):
        """
        Initialize the retriever with paths to the FAISS index and document chunks.

        Resources already cached on ``resource_manager`` are reused. Otherwise
        they are loaded from disk and then cached for later instances.

        Args:
            index_path: Path to the FAISS index file
            chunks_path: Path to the pickled document chunks
            top_k: Number of chunks to retrieve for a query

        Raises:
            ImportError: If TextEmbedder could not be imported when this module loaded.
            Exception: Any error raised while loading the index or chunks is re-raised.
        """
        if TextEmbedder is None:
            raise ImportError("TextEmbedder is unavailable; see the import error printed at module load")

        self.index_path = index_path
        self.chunks_path = chunks_path
        self.top_k = top_k
        self.embedder = TextEmbedder()

        # Try to get resources from the resource manager first.
        # NOTE(review): cached resources are reused even when index_path /
        # chunks_path differ from the paths they were originally loaded from.
        self.index = resource_manager.get_faiss_index()
        self.doc_chunks = resource_manager.get_doc_chunks()

        if self.index is None or self.doc_chunks is None:
            # Not cached yet: load from disk (this also aligns the embedder).
            self._load_resources()
        else:
            # Cached resources bypass _load_resources, but this instance's fresh
            # embedder still has to be aligned with the index dimension.
            self._ensure_embedder_compatibility()

    def _load_resources(self) -> None:
        """Load the FAISS index and document chunks from disk and cache them.

        Raises:
            Exception: Any loading error is printed with its traceback and re-raised.
        """
        try:
            print(f"Loading FAISS index from {self.index_path}...")
            self.index = faiss.read_index(self.index_path)

            print(f"Loading document chunks from {self.chunks_path}...")
            # SECURITY: pickle.load can execute arbitrary code, so chunks_path
            # must point to a trusted file.
            with open(self.chunks_path, "rb") as f:
                self.doc_chunks = pickle.load(f)

            print(f"Resources loaded: {len(self.doc_chunks)} document chunks available.")

            # Update the resource manager with our loaded resources
            resource_manager.faiss_index = self.index
            resource_manager.doc_chunks = self.doc_chunks
            resource_manager.initialized = True

            # Ensure embedder dimension matches FAISS index
            self._ensure_embedder_compatibility()
        except Exception as e:
            print(f"Error loading resources: {e}")
            traceback.print_exc()
            raise

    def _ensure_embedder_compatibility(self) -> None:
        """Resize the embedder to the FAISS index dimension when the two differ.

        Only acts when an index is loaded and the embedder exposes ``set_dimension``.
        """
        if self.index is None or not hasattr(self.embedder, 'set_dimension'):
            return

        faiss_dim = self.index.d
        # getattr: an embedder without an embedding_dim attribute is treated as
        # mismatched and resized, instead of crashing with AttributeError.
        embedder_dim = getattr(self.embedder, 'embedding_dim', None)
        if faiss_dim != embedder_dim:
            print(f"Warning: Dimension mismatch between FAISS index ({faiss_dim}) and embedder ({embedder_dim})")
            print("Adjusting embedder dimension to match FAISS index")
            self.embedder.set_dimension(faiss_dim)

    @staticmethod
    def _normalize_chunk(chunk: Any, idx: int, score: float) -> Dict[str, Any]:
        """Build a result dict for the chunk stored at position *idx*.

        Dict chunks are shallow-copied so the cached originals are not mutated.
        Any other value is wrapped as ``{"text": chunk}``. The method sets
        ``score``, copies ``chunk`` into ``text`` when only ``chunk`` is present,
        and fills in ``source`` and ``chunk_id`` defaults derived from *idx*.
        """
        chunk_info = chunk.copy() if isinstance(chunk, dict) else {"text": chunk}
        chunk_info['score'] = score

        # Ensure basic required fields exist
        if 'text' not in chunk_info and 'chunk' in chunk_info:
            chunk_info['text'] = chunk_info['chunk']
        if 'source' not in chunk_info:
            chunk_info['source'] = f"source_{idx}"
        if 'chunk_id' not in chunk_info:
            chunk_info['chunk_id'] = idx
        return chunk_info

    def retrieve(self, query: str, top_k: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Retrieve the most relevant document chunks for a query.

        Args:
            query: The search query
            top_k: Number of chunks to retrieve (overrides instance default if provided)

        Returns:
            List of the most relevant document chunks with their metadata, in the
            order FAISS returned them. Each entry's ``score`` is the raw value
            FAISS reported for that hit; whether that is a distance or a
            similarity depends on the index's metric. An empty list is returned
            when ``top_k <= 0``. If the search fails or yields no usable hits,
            every chunk is returned with placeholder scores instead.
        """
        if top_k is None:
            top_k = self.top_k
        if top_k <= 0:
            # FAISS rejects k <= 0; asking for no results yields no results.
            return []

        chunks = self.doc_chunks or []

        # Adjust top_k if we have fewer chunks than requested
        if chunks and len(chunks) < top_k:
            top_k = len(chunks)
            print(f"Adjusted top_k to {top_k} based on available chunks")

        # Get the query embedding
        query_embedding = self.embedder.get_query_embedding(query)

        # Search the FAISS index
        try:
            print(f"FAISS index info - ntotal: {self.index.ntotal}, dimension: {self.index.d}")
            if query_embedding is None:
                raise ValueError("Query embedding is None")
            # FAISS requires a C-contiguous float32 matrix of shape (n_queries, dim).
            query_matrix = np.ascontiguousarray(query_embedding, dtype=np.float32)
            if query_matrix.ndim == 1:
                query_matrix = query_matrix.reshape(1, -1)
            print(f"Query embedding shape: {query_matrix.shape}")
            distances, indices = self.index.search(query_matrix, top_k)

            # Log first few results for debugging
            print(f"Top 3 results - indices: {indices[0][:3]}, distances: {distances[0][:3]}")
        except Exception as e:
            print(f"Error during FAISS search: {e}")
            traceback.print_exc()
            self._print_search_diagnostics(query_embedding)
            # Return all available chunks as fallback
            return self._get_all_chunks_with_placeholder_scores()

        # Collect the retrieved chunks
        retrieved_chunks = []
        for rank, raw_idx in enumerate(indices[0]):
            idx = int(raw_idx)
            # FAISS pads missing results with -1. Without the lower-bound check,
            # a -1 would wrap around to the last chunk and appear as a real hit.
            if not 0 <= idx < len(chunks):
                continue
            try:
                retrieved_chunks.append(self._normalize_chunk(chunks[idx], idx, float(distances[0][rank])))
            except Exception as e:
                print(f"Error processing chunk at index {idx}: {e}")

        # If we couldn't retrieve any chunks, return fallback chunks
        if not retrieved_chunks:
            print("No chunks could be retrieved, using fallback")
            return self._get_all_chunks_with_placeholder_scores()

        return retrieved_chunks

    def _print_search_diagnostics(self, query_embedding: Any) -> None:
        """Print index and query-embedding details that help explain a failed search."""
        try:
            # Check if embeddings and index are compatible
            if self.index is None:
                print("FAISS index is None - index was not loaded properly")
            else:
                print(f"FAISS index dimension: {self.index.d}, total vectors: {self.index.ntotal}")

            if query_embedding is None:
                print("Query embedding is None")
            else:
                print(f"Query embedding shape: {query_embedding.shape}, dtype: {query_embedding.dtype}")
                # shape[-1] works for both 1-D vectors and (n, dim) matrices.
                if self.index is not None and query_embedding.shape[-1] != self.index.d:
                    print(f"Dimension mismatch: query embedding ({query_embedding.shape[-1]}) vs. FAISS index ({self.index.d})")
        except Exception as diagnostic_e:
            print(f"Error during diagnostics: {diagnostic_e}")

    def _get_all_chunks_with_placeholder_scores(self) -> List[Dict[str, Any]]:
        """Return every available chunk with placeholder scores as a fallback.

        Scores start at 1.0 and drop by 0.1 per chunk position, so they only
        reflect the stored order and become negative for larger corpora.
        """
        fallback_chunks = []
        for idx, chunk in enumerate(self.doc_chunks or []):
            try:
                # Placeholder decreasing scores
                fallback_chunks.append(self._normalize_chunk(chunk, idx, 1.0 - (idx * 0.1)))
            except Exception as e:
                print(f"Error creating fallback chunk at index {idx}: {e}")
        return fallback_chunks

    def get_formatted_context(self, retrieved_chunks: List[Dict[str, Any]]) -> str:
        """
        Format the retrieved chunks into a context string for the LLM.

        Each chunk becomes a "[<source> - chunk <chunk_id>]:" header line followed
        by the chunk text. Chunks are separated by a blank line. A chunk that
        cannot be formatted is reported and skipped.

        Args:
            retrieved_chunks: List of retrieved document chunks

        Returns:
            Formatted context string
        """
        formatted_chunks = []
        for chunk in retrieved_chunks:
            try:
                # Get the chunk text (might be in 'text' or 'chunk' field)
                chunk_text = chunk.get('text', chunk.get('chunk', "No text available"))

                # Create a header with available metadata
                source = chunk.get('source', 'unknown_source')
                chunk_id = chunk.get('chunk_id', 'unknown_id')
                header = f"[{source} - chunk {chunk_id}]"

                formatted_chunks.append(f"{header}:\n{chunk_text}")
            except Exception as e:
                print(f"Error formatting chunk: {e}")

        return "\n\n".join(formatted_chunks)