# advance-multidoc-rag / src/context_retriever.py
# Author: Fnu Mahnoor — commit c07baa6 ("Fix Inference")
import os
import faiss
import torch
import numpy as np
from pathlib import Path
import logging
import json
from functools import lru_cache
from sentence_transformers import SentenceTransformer
from .embeddings_utils import load_faiss_index, load_metadata
from .graph_index import HierarchicalGraphManager
# Local graph manager used for community-summary retrieval.
# 1. Global Cache for the Embedding Model
# Keeping the model warm avoids the multi-second reload cost on every question.
_MODEL_CACHE = {}


def get_embedding_model(model_name: str):
    """Return a process-wide cached SentenceTransformer instance.

    The model is loaded at most once per model name; it is placed on the GPU
    when CUDA is available, otherwise on the CPU.
    """
    cached = _MODEL_CACHE.get(model_name)
    if cached is None:
        target_device = "cuda" if torch.cuda.is_available() else "cpu"
        cached = SentenceTransformer(model_name, device=target_device, trust_remote_code=True)
        _MODEL_CACHE[model_name] = cached
    return cached
def query_index(index_path: str, meta_path: str, query: str, top_k: int = 10, model_name: str = "nomic-ai/nomic-embed-text-v1"):
    """Embed *query* and return the ``top_k`` nearest chunks from a FAISS index.

    Args:
        index_path: Path to the serialized FAISS index file.
        meta_path: Path to the pickled per-chunk metadata list.
        query: Natural-language query text.
        top_k: Maximum number of results to return.
        model_name: SentenceTransformer model used for the query embedding.

    Returns:
        A list of dicts with ``score``, ``text``, ``source``, ``page`` and
        ``chunk_id`` keys, ordered by similarity.
    """
    index = load_faiss_index(index_path)
    meta = load_metadata(meta_path)

    # Reuse the warm, cached embedding model.
    model = get_embedding_model(model_name)

    # Inference only — no autograd bookkeeping needed.
    with torch.no_grad():
        query_vec = model.encode([query], convert_to_numpy=True).astype("float32")

    # Unit-normalize so inner-product search behaves as cosine similarity,
    # matching how the index was built.
    faiss.normalize_L2(query_vec)

    scores, ids = index.search(query_vec, top_k)

    hits = []
    for score, chunk_idx in zip(scores[0], ids[0]):
        # FAISS pads missing results with -1; also guard against stale ids.
        if not 0 <= chunk_idx < len(meta):
            continue
        record = meta[chunk_idx]
        hits.append({
            "score": float(score),
            "text": record.get("text"),
            "source": record.get("source"),
            "page": record.get("page"),
            "chunk_id": record.get("chunk_id"),
        })
    return hits
def get_latest_session_paths():
    """Return (index, metadata, graph) paths for the newest session directory.

    Scans ``data/_staging`` and ``data/uploads`` for session subdirectories and
    picks the one with the most recent modification time. Falls back to the
    root ``data/`` files when no session directory exists.
    """
    search_roots = (Path("data/_staging"), Path("data/uploads"))

    # Collect every session subdirectory under the roots that exist.
    candidates = [
        child
        for root in search_roots
        if root.exists()
        for child in root.iterdir()
        if child.is_dir()
    ]

    if not candidates:
        # No sessions yet — fall back to the root data directory.
        return Path("data/faiss.index"), Path("data/meta.pkl"), Path("data/graph_data.pkl")

    # Most recently touched directory wins.
    newest = max(candidates, key=lambda d: d.stat().st_mtime)
    logging.info(f"🚀 Context Retriever: Pointing to latest session -> {newest.name}")
    return (
        newest / "faiss.index",
        newest / "meta.pkl",
        newest / "graph_data.pkl",
    )
def retrieve_contexts(query: str, idx_path, meta_path, k: int = 10, return_with_sources: bool = True):
    """
    Performs vector search against a prebuilt FAISS index.

    Args:
        query: Natural-language query text.
        idx_path: Path (or plain str) to the FAISS index file.
        meta_path: Path (or plain str) to the chunk-metadata pickle.
        k: Number of top results to retrieve.
        return_with_sources: If True, return full result dicts; otherwise
            return only the chunk texts.

    Returns:
        A list of result dicts (or texts), or [] when the index files are
        missing or the search fails.
    """
    # Coerce to Path so plain string paths work too (query_index itself
    # takes strings, so callers may reasonably pass either).
    idx_path = Path(idx_path)
    meta_path = Path(meta_path)

    if not idx_path.exists() or not meta_path.exists():
        logging.warning("⚠️ No vector index found in the latest session.")
        return []

    try:
        results = query_index(str(idx_path), str(meta_path), query, top_k=k)
    except Exception as e:
        # Best-effort: a retrieval failure degrades to "no context" rather
        # than crashing the answering pipeline.
        logging.error(f"Vector Retrieval Error: {e}")
        return []

    return results if return_with_sources else [r["text"] for r in results]
def retrieve_hybrid_context(query: str, k: int = 5):
    """
    The Master Coordinator: Fetches data and logs progress for debugging.

    Combines vector (FAISS) retrieval with graph community summaries for the
    most recent ingestion session.

    Args:
        query: Natural-language question to retrieve context for.
        k: Number of vector-search chunks to fetch.

    Returns:
        A JSON string with keys "vector_context" (chunk result dicts),
        "graph_context" (community summaries), and "metadata" (currently the
        same list as "vector_context").
    """
    # 1. Track Session Discovery
    idx_path, meta_path, graph_path = get_latest_session_paths()
    logging.info(f"📂 Session Target: {idx_path.parent.name}")
    # 2. Get Vector Context
    vector_results = retrieve_contexts(query, idx_path, meta_path, k=k, return_with_sources=True)
    logging.info(f"🔍 Vector Search: Found {len(vector_results)} chunks")
    # 3. Get Graph Context
    gm = HierarchicalGraphManager(storage_path=str(graph_path))
    graph_summaries = gm.get_relevant_community_summaries(query)
    logging.info(f"🕸️ Graph Search: Found {len(graph_summaries)} community summaries")
    # 4. Prepare for Serialization
    # NOTE(review): "metadata" duplicates "vector_context" verbatim, doubling
    # the serialized payload — confirm downstream consumers need both keys.
    data = {
        "vector_context": vector_results,
        "graph_context": graph_summaries,
        "metadata": vector_results
    }
    # 5. Log the final state before string conversion
    if not data["vector_context"] and not data["graph_context"]:
        logging.warning("⚠️ Hybrid retrieval returned ZERO context.")
    else:
        logging.info("✅ Hybrid context successfully compiled.")
    # Return as string with default=str to catch non-serializable metadata
    return json.dumps(data, indent=4, ensure_ascii=False, default=str)