# RAG-Insurance / vector_store.py
# Author: mokhles — "Initial commit: Insurance RAG API" (commit af37875)
import logging
from typing import Optional, Dict, Any, List
import threading
import re
import numpy as np
import chromadb
from rank_bm25 import BM25Okapi
# Module-level logger following the standard `__name__` convention.
logger = logging.getLogger(__name__)
class VectorStoreManager:
    """Thread-safe singleton giving access to the insurance RAG corpus.

    Wraps a persistent ChromaDB collection ('rag_documents') for dense
    semantic search, plus an in-memory BM25 index built once at startup
    from the same collection for sparse keyword search. Hybrid retrieval
    fuses both rankings with Reciprocal Rank Fusion (RRF).
    """

    _instance: Optional["VectorStoreManager"] = None
    _lock = threading.Lock()
    # True only after _initialize() completes; reset on failure so a later
    # construction attempt retries initialization.
    _initialized = False

    def __new__(cls):
        # Lock-guarded singleton creation (safe under concurrent first calls).
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every VectorStoreManager() call; only the first
        # successful call performs the expensive initialization.
        with self._lock:
            if not self._initialized:
                self._initialize()
                VectorStoreManager._initialized = True

    def _initialize(self) -> None:
        """Open the persistent ChromaDB collection and build the BM25 index.

        Raises:
            ValueError: if the required 'rag_documents' collection is missing.
            Exception: any other ChromaDB/BM25 failure is logged and re-raised.
        """
        try:
            logger.info("Initializing vector store components...")
            self.client = None
            self.collection = None

            db_path = "output/chromadb"  # Match your pipeline path
            self.client = chromadb.PersistentClient(path=db_path)
            logger.info("ChromaDB client initialized at path: %s", db_path)

            available_collections = [col.name for col in self.client.list_collections()]
            logger.info("Available collections: %s", available_collections)

            try:
                self.collection = self.client.get_collection("rag_documents")
                collection_count = self.collection.count()
                logger.info(
                    "Collection 'rag_documents' loaded with %d documents",
                    collection_count,
                )
            except Exception as e:
                logger.error("Collection 'rag_documents' not found: %s", e)
                # Chain the original error so the root cause is preserved
                # (the original raise dropped it).
                raise ValueError(
                    "Required collection 'rag_documents' not found. "
                    f"Available: {available_collections}"
                ) from e

            # ---- Build BM25 index from all stored docs ----
            logger.info("Building BM25 index from Chroma collection...")
            data = self.collection.get(include=["documents", "metadatas"])
            self.all_ids: List[str] = data["ids"]
            self.all_docs: List[str] = data["documents"]
            self.all_metas: List[Dict[str, Any]] = data["metadatas"]

            self.tokenized_corpus = [self._tokenize(d) for d in self.all_docs]
            self.bm25 = BM25Okapi(self.tokenized_corpus)
            logger.info("BM25 index ready with %d chunks", len(self.all_docs))

            logger.info("Vector store initialized successfully")
        except Exception as e:
            logger.error("Failed to initialize vector store: %s", e)
            # Allow a later construction attempt to retry initialization.
            VectorStoreManager._initialized = False
            raise

    # ----------------- Helpers -----------------
    def _tokenize(self, text: str) -> List[str]:
        """Lowercase word tokenizer shared by the BM25 corpus and queries."""
        return re.findall(r"\w+", (text or "").lower())

    def _matches_filters(
        self,
        meta: Dict[str, Any],
        doc_text: str,
        where_filters: Optional[Dict[str, Any]],
        where_document: Optional[Dict[str, Any]],
    ) -> bool:
        """Mirror the Chroma filters on the in-memory BM25 corpus.

        Supports exact-equality metadata matches and the single
        `{"$contains": "..."}` document operator used by this project
        (case-insensitive substring).
        """
        if where_filters:
            for k, v in where_filters.items():
                if meta.get(k) != v:
                    return False
        if where_document:
            # you only use {"$contains": "..."}
            contains = where_document.get("$contains")
            if contains and contains.lower() not in (doc_text or "").lower():
                return False
        return True

    def _rrf_fuse(
        self,
        dense_ranked: List[Dict[str, Any]],
        sparse_ranked: List[Dict[str, Any]],
        k: int = 60,
        w_dense: float = 0.6,
        w_sparse: float = 0.4,
    ) -> List[Dict[str, Any]]:
        """Reciprocal Rank Fusion of two ranked lists.

        score = w_dense/(k+rank_dense) + w_sparse/(k+rank_sparse)

        When a document appears in both lists, the dense item dict is kept
        (setdefault wins on first insert), so its `distance` is populated but
        `bm25_score` is absent.
        """
        scores: Dict[str, Dict[str, Any]] = {}
        for rank, item in enumerate(dense_ranked):
            doc_id = item["id"]
            scores.setdefault(doc_id, {"score": 0.0, "item": item})
            scores[doc_id]["score"] += w_dense / (k + rank + 1)
        for rank, item in enumerate(sparse_ranked):
            doc_id = item["id"]
            scores.setdefault(doc_id, {"score": 0.0, "item": item})
            scores[doc_id]["score"] += w_sparse / (k + rank + 1)
        fused = sorted(scores.values(), key=lambda x: x["score"], reverse=True)
        return [x["item"] for x in fused]

    def _dense_search(
        self,
        question: str,
        dense_k: int,
        where_filters: Optional[Dict[str, Any]],
        where_document: Optional[Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        """Semantic top-k retrieval via the Chroma collection."""
        try:
            dense_res = self.collection.query(
                query_texts=[question],
                n_results=dense_k,
                include=["documents", "metadatas", "distances"],
                where=where_filters if where_filters else None,
                where_document=where_document if where_document else None,
            )
        except Exception as e:
            logger.error("Dense retrieval failed: %s", e)
            raise
        ranked: List[Dict[str, Any]] = []
        if dense_res and dense_res.get("documents") and dense_res["documents"][0]:
            # Chroma returns parallel lists per query; iterate them in lockstep.
            parallel = zip(
                dense_res["ids"][0],
                dense_res["documents"][0],
                dense_res["metadatas"][0],
                dense_res["distances"][0],
            )
            for doc_id, text, meta, dist in parallel:
                ranked.append({
                    "id": doc_id,
                    "text": text,
                    "metadata": meta,
                    "distance": float(dist),
                    "source": meta.get("source", "Unknown"),
                })
        return ranked

    def _sparse_search(
        self,
        question: str,
        bm25_k: int,
        where_filters: Optional[Dict[str, Any]],
        where_document: Optional[Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        """BM25 top-k retrieval over the in-memory corpus, honoring filters."""
        q_tokens = self._tokenize(question)
        scores = self.bm25.get_scores(q_tokens)
        # Apply the same filters Chroma applied on the dense side.
        valid_scores = [
            (idx, scores[idx])
            for idx, (doc, meta) in enumerate(zip(self.all_docs, self.all_metas))
            if self._matches_filters(meta, doc, where_filters, where_document)
        ]
        valid_scores.sort(key=lambda pair: pair[1], reverse=True)
        ranked: List[Dict[str, Any]] = []
        for idx, score in valid_scores[:bm25_k]:
            meta = self.all_metas[idx]
            ranked.append({
                "id": self.all_ids[idx],
                "text": self.all_docs[idx],
                "metadata": meta,
                "bm25_score": float(score),
                "distance": None,  # may be absent if not in dense top-k
                "source": meta.get("source", "Unknown"),
            })
        return ranked

    # ----------------- Main retrieval -----------------
    def retrieve_documents(
        self,
        question: str,
        n_results: int = 5,
        where_filters: Optional[Dict[str, Any]] = None,
        where_document: Optional[Dict[str, Any]] = None,
        enable_bm25: bool = False,
        bm25_k: Optional[int] = None,
        alpha: float = 0.6,  # dense weight in hybrid fusion
    ) -> List[Dict[str, Any]]:
        """Retrieve documents using:
            - semantic-only (Chroma)
            - or hybrid semantic + BM25 (RRF fusion)

        Returns a list of dicts:
            {id, text, metadata, distance, bm25_score(optional), source}

        Raises:
            RuntimeError: if called before successful initialization.
        """
        if not self._initialized or self.collection is None:
            raise RuntimeError("VectorStoreManager not properly initialized")
        logger.info("Retrieving documents for query: %s...", question[:50])

        dense_k = n_results
        # Explicit None check: the previous `bm25_k or n_results` silently
        # treated an intentional bm25_k=0 as "unset".
        if bm25_k is None:
            bm25_k = n_results

        # ----- Dense retrieval (semantic via Chroma) -----
        dense_ranked = self._dense_search(question, dense_k, where_filters, where_document)

        if not enable_bm25:
            logger.info("Semantic-only retrieved %d docs", len(dense_ranked))
            return dense_ranked

        # ----- Sparse retrieval (BM25) -----
        sparse_ranked = self._sparse_search(question, bm25_k, where_filters, where_document)

        # ----- Fuse dense + sparse -----
        # NOTE(review): the fused list may hold up to dense_k + bm25_k entries
        # (it is not truncated to n_results) — presumably for downstream
        # reranking; confirm callers expect this.
        fused = self._rrf_fuse(
            dense_ranked,
            sparse_ranked,
            w_dense=alpha,
            w_sparse=1.0 - alpha,
        )
        logger.info(
            "Hybrid retrieved dense=%d sparse=%d fused=%d",
            len(dense_ranked),
            len(sparse_ranked),
            len(fused),
        )
        return fused
def get_vector_store() -> VectorStoreManager:
    """FastAPI dependency: hand back the process-wide VectorStoreManager singleton."""
    manager = VectorStoreManager()
    return manager