# agentic-defensor/src/models/retriever.py
# (Hosting-page header residue kept for reference, not part of the module:
#  "vichudo's picture" · "fix" · 254ca68 · Raw · History / Blame / Contribute / Delete · 10.5 kB)
import faiss
import pickle
import numpy as np
from typing import List, Dict, Any, Tuple, Optional
import sys
import os
from src.utils.config import TOP_K, FAISS_INDEX_PATH, DOC_CHUNKS_PATH
# Try to import from the proper location, otherwise use the local copy
try:
from src.embeddings.embedder import TextEmbedder
except ImportError:
try:
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from embedder import TextEmbedder
print("Using local copy of embedder.py")
except ImportError as e:
print(f"Error importing TextEmbedder: {e}")
# Simple resource manager to avoid circular imports
class SimpleResourceManager:
    """
    Process-wide holder for the FAISS index and the document chunks.

    Every call to ``SimpleResourceManager()`` yields the same object, so
    resources stored on it by one caller are visible to all others.
    """

    # The single shared instance; created lazily on first construction.
    _instance = None

    def __new__(cls):
        """Return the shared instance, creating and initialising it once."""
        shared = cls._instance
        if shared is not None:
            return shared

        shared = super().__new__(cls)
        cls._instance = shared
        # Nothing is loaded yet: both resources start out empty.
        shared.faiss_index = None
        shared.doc_chunks = None
        shared.initialized = False
        return shared

    def get_faiss_index(self):
        """Return the stored FAISS index, or ``None`` if none has been set."""
        return self.faiss_index

    def get_doc_chunks(self):
        """Return the stored document chunks, or ``None`` if none have been set."""
        return self.doc_chunks
# Module-level shared instance: Retriever reads the cached FAISS index and
# document chunks from it, and stores them back after loading them from disk.
resource_manager = SimpleResourceManager()
class Retriever:
    """
    Handles retrieval of relevant document chunks using FAISS vector search.

    The FAISS index and the document chunks are shared through the
    module-level ``resource_manager``: an instance first asks the manager for
    already-loaded resources and only reads them from disk (and publishes
    them back to the manager) when they are missing.
    """

    def __init__(self,
                 index_path: str = FAISS_INDEX_PATH,
                 chunks_path: str = DOC_CHUNKS_PATH,
                 top_k: int = TOP_K):
        """
        Initialize the retriever with paths to the FAISS index and document chunks.

        Args:
            index_path: Path to the FAISS index file
            chunks_path: Path to the pickled document chunks
            top_k: Number of chunks to retrieve for a query

        Raises:
            Exception: Whatever ``_load_resources`` raises when the index or
                the chunks cannot be read from disk.
        """
        self.index_path = index_path
        self.chunks_path = chunks_path
        self.top_k = top_k
        self.embedder = TextEmbedder()

        # Try to get resources from the resource manager first
        self.index = resource_manager.get_faiss_index()
        self.doc_chunks = resource_manager.get_doc_chunks()

        if self.index is None or self.doc_chunks is None:
            # Not cached (or only partially cached): load both from disk.
            self._load_resources()
        else:
            # Cached resources skip _load_resources, but this instance owns a
            # fresh embedder that still has to be aligned with the index.
            self._ensure_embedder_compatibility()

    def _load_resources(self) -> None:
        """
        Load the FAISS index and document chunks from disk.

        On success the loaded objects are also stored on the module-level
        ``resource_manager`` and the embedder dimension is checked against
        the index.

        Raises:
            Exception: Any error raised while loading is printed together
                with its traceback and then re-raised unchanged.
        """
        try:
            print(f"Loading FAISS index from {self.index_path}...")
            self.index = faiss.read_index(self.index_path)

            print(f"Loading document chunks from {self.chunks_path}...")
            # NOTE(security): pickle.load can execute arbitrary code embedded
            # in the file. Only point chunks_path at trusted, locally built data.
            with open(self.chunks_path, "rb") as f:
                self.doc_chunks = pickle.load(f)
            print(f"Resources loaded: {len(self.doc_chunks)} document chunks available.")

            # Update the resource manager with our loaded resources
            resource_manager.faiss_index = self.index
            resource_manager.doc_chunks = self.doc_chunks
            resource_manager.initialized = True

            # Ensure embedder dimension matches FAISS index
            self._ensure_embedder_compatibility()
        except Exception as e:
            print(f"Error loading resources: {e}")
            import traceback
            traceback.print_exc()
            raise

    def _ensure_embedder_compatibility(self) -> None:
        """
        Ensure the embedder's dimension matches the FAISS index dimension.

        Does nothing when no index is loaded or when the embedder does not
        expose ``set_dimension``. Otherwise, on a mismatch, a warning is
        printed and the embedder is switched to the index dimension.
        """
        if self.index is None or not hasattr(self.embedder, 'set_dimension'):
            return

        faiss_dim = self.index.d
        embedder_dim = self.embedder.embedding_dim
        if faiss_dim != embedder_dim:
            print(f"Warning: Dimension mismatch between FAISS index ({faiss_dim}) and embedder ({embedder_dim})")
            print("Adjusting embedder dimension to match FAISS index")
            self.embedder.set_dimension(faiss_dim)

    @staticmethod
    def _build_chunk_info(chunk: Any, idx: int, score: float) -> Dict[str, Any]:
        """
        Return a result dict for one stored chunk without mutating the original.

        Args:
            chunk: The stored chunk; a dict is shallow-copied, anything else
                is wrapped as ``{"text": chunk}``
            idx: Position of the chunk in ``doc_chunks``; used for the default
                ``source`` and ``chunk_id`` values
            score: Value stored under the ``score`` key

        Returns:
            Dict that always has ``score``, ``source`` and ``chunk_id``, plus
            ``text`` when the chunk provides ``text`` or ``chunk``.
        """
        chunk_info = chunk.copy() if isinstance(chunk, dict) else {"text": chunk}
        chunk_info['score'] = score

        # Ensure basic required fields exist
        if 'text' not in chunk_info and 'chunk' in chunk_info:
            chunk_info['text'] = chunk_info['chunk']
        chunk_info.setdefault('source', f"source_{idx}")
        chunk_info.setdefault('chunk_id', idx)
        return chunk_info

    def _print_search_diagnostics(self, query_embedding: Any) -> None:
        """Print best-effort details about the index and query after a failed search."""
        try:
            # Check if embeddings and index are compatible
            if self.index is None:
                print("FAISS index is None - index was not loaded properly")
            else:
                print(f"FAISS index dimension: {self.index.d}, total vectors: {self.index.ntotal}")

            if query_embedding is None:
                print("Query embedding is None")
            else:
                print(f"Query embedding shape: {query_embedding.shape}, dtype: {query_embedding.dtype}")
                if query_embedding.shape[1] != self.index.d:
                    print(f"Dimension mismatch: query embedding ({query_embedding.shape[1]}) vs. FAISS index ({self.index.d})")
        except Exception as diagnostic_e:
            # Diagnostics must never mask the original search failure.
            print(f"Error during diagnostics: {diagnostic_e}")

    def retrieve(self, query: str, top_k: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Retrieve the most relevant document chunks for a query.

        Args:
            query: The search query
            top_k: Number of chunks to retrieve (overrides instance default if provided)

        Returns:
            List of the most relevant document chunks with their metadata.
            Each entry carries a ``score`` holding the raw value returned by
            the FAISS search for that hit. If the search fails, or yields no
            usable hit, every available chunk is returned with a placeholder
            score instead.
        """
        if top_k is None:
            top_k = self.top_k

        # Adjust top_k if we have fewer chunks than requested
        if self.doc_chunks and len(self.doc_chunks) < top_k:
            top_k = len(self.doc_chunks)
            print(f"Adjusted top_k to {top_k} based on available chunks")

        # Get the query embedding
        query_embedding = self.embedder.get_query_embedding(query)

        # Search the FAISS index
        try:
            print(f"FAISS index info - ntotal: {self.index.ntotal}, dimension: {self.index.d}")
            print(f"Query embedding shape: {query_embedding.shape}")
            distances, indices = self.index.search(query_embedding, top_k)

            # Log first few results for debugging
            top_indices = indices[0][:3]
            top_distances = distances[0][:3]
            print(f"Top 3 results - indices: {top_indices}, distances: {top_distances}")
        except Exception as e:
            print(f"Error during FAISS search: {e}")
            import traceback
            traceback.print_exc()
            self._print_search_diagnostics(query_embedding)
            # Return all available chunks as fallback
            return self._get_all_chunks_with_placeholder_scores()

        # Collect the retrieved chunks
        chunk_count = len(self.doc_chunks)
        retrieved_chunks = []
        for rank, raw_idx in enumerate(indices[0]):
            # Plain int keeps NumPy integer types out of the returned metadata.
            idx = int(raw_idx)
            # FAISS pads the result with -1 when it finds fewer than top_k
            # neighbours; a negative index must not wrap around to the last chunk.
            if not 0 <= idx < chunk_count:
                continue
            try:
                retrieved_chunks.append(
                    self._build_chunk_info(self.doc_chunks[idx], idx, float(distances[0][rank]))
                )
            except Exception as e:
                print(f"Error processing chunk at index {idx}: {e}")

        # If we couldn't retrieve any chunks, return fallback chunks
        if not retrieved_chunks:
            print("No chunks could be retrieved, using fallback")
            return self._get_all_chunks_with_placeholder_scores()

        return retrieved_chunks

    def _get_all_chunks_with_placeholder_scores(self) -> List[Dict[str, Any]]:
        """
        Return all available chunks with placeholder scores as fallback.

        Scores start at 1.0 and decrease by 0.1 per chunk position; they only
        encode the storage order, not relevance. Returns an empty list when no
        chunks are loaded.
        """
        fallback_chunks = []
        for idx, chunk in enumerate(self.doc_chunks or []):
            try:
                # Placeholder decreasing scores
                fallback_chunks.append(self._build_chunk_info(chunk, idx, 1.0 - (idx * 0.1)))
            except Exception as e:
                print(f"Error creating fallback chunk at index {idx}: {e}")
        return fallback_chunks

    def get_formatted_context(self, retrieved_chunks: List[Dict[str, Any]]) -> str:
        """
        Format the retrieved chunks into a context string for the LLM.

        Args:
            retrieved_chunks: List of retrieved document chunks

        Returns:
            Formatted context string: one ``[source - chunk id]:`` header plus
            text per chunk, separated by blank lines. Chunks that cannot be
            formatted are reported and skipped.
        """
        formatted_chunks = []
        for chunk in retrieved_chunks:
            try:
                # Get the chunk text (might be in 'text' or 'chunk' field)
                chunk_text = chunk.get('text', chunk.get('chunk', "No text available"))

                # Create a header with available metadata
                source = chunk.get('source', 'unknown_source')
                chunk_id = chunk.get('chunk_id', 'unknown_id')
                header = f"[{source} - chunk {chunk_id}]"

                formatted_chunks.append(f"{header}:\n{chunk_text}")
            except Exception as e:
                print(f"Error formatting chunk: {e}")
        return "\n\n".join(formatted_chunks)