Spaces:
Sleeping
Sleeping
Initial commit: Insurance RAG API
Browse files- Dockerfile +16 -0
- README.md +5 -4
- app.py +23 -0
- chroma.py +302 -0
- requirements.txt +19 -0
- retrieval.py +208 -0
- vector_store.py +228 -0
Dockerfile
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Slim Python 3.11 base keeps the final image small.
FROM python:3.11-slim

WORKDIR /app

# build-essential is required to compile native wheels pulled in by the
# requirements (e.g. chromadb dependencies); the apt cache is purged in the
# same layer so it does not bloat the image.
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements.txt alone first so the (slow) pip layer is cached and
# only re-runs when dependencies change, not on every source edit.
COPY requirements.txt .
RUN pip install --no-cache-dir --upgrade -r requirements.txt

COPY . .

# 7860 — presumably the port the hosting platform (HF Spaces) expects; the
# CMD below binds uvicorn to the same port.
EXPOSE 7860

CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md
CHANGED
|
@@ -1,10 +1,11 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Insurance Rag Api
|
| 3 |
+
emoji: 🌖
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: red
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
+
short_description: Production-ready FastAPI Retrieval-Augmented Generation API
|
| 9 |
---
|
| 10 |
|
| 11 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# app.py
"""FastAPI entry point for the Insurance RAG API.

Wires up CORS, mounts the retrieval router, and exposes a landing route.
"""
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from retrieval import router as retrieval_router

app = FastAPI(title="Insurance RAG API", version="1.0.0")

# NOTE(review): wildcard origins combined with allow_credentials=True is a
# permissive CORS setup — confirm this is intended before production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.include_router(retrieval_router)


@app.get("/")
async def root():
    """Simple liveness/landing endpoint pointing callers at the docs."""
    return {"message": "Insurance RAG API is running", "docs": "/docs"}
chroma.py
ADDED
|
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# chroma.py (minimal, no visualization, WITH sentence-transformers, with .env)

import os
import warnings
from pathlib import Path
from typing import List, Dict

import pandas as pd  # (currently unused but kept if you need it later)
from dotenv import load_dotenv

from llama_parse import LlamaParse
from llama_index.core.node_parser import SentenceSplitter

import chromadb
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
from openai import OpenAI

import nest_asyncio

# LlamaParse drives asyncio internally; allow nested event loops.
nest_asyncio.apply()

warnings.filterwarnings("ignore")

# ---------- LOAD .env ----------
load_dotenv()

# ---------- CONFIG ----------
CONFIG = {
    # Overridable via the PDF_DIRECTORY env var so the pipeline also works
    # inside Docker / CI; the default preserves the original local dev path.
    "pdf_directory": os.getenv(
        "PDF_DIRECTORY", r"C:\Users\Legion\Documents\Ominimo Job\Pdfs for RAG"
    ),
    "output_directory": "./output/",
    "llm_model": "gpt-4.1-mini",
    "chunk_size": 512,
    "chunk_overlap": 50,
    "top_k_retrieval": 3,

    # SentenceTransformer embedding model (384-D for MiniLM).
    # Must match your retrieval embedding model.
    "embedding_model": "all-MiniLM-L6-v2",

    # Optional: force device ("cpu" or "cuda")
    "embedding_device": os.getenv("EMB_DEVICE", "cpu"),
}

Path(CONFIG["output_directory"]).mkdir(parents=True, exist_ok=True)

# ---------- OPENAI CLIENT (for summaries only) ----------
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    raise RuntimeError("OPENAI_API_KEY is not set in the environment or .env file.")

client = OpenAI(api_key=OPENAI_API_KEY)

# Per-PDF summaries keyed by filename; populated during parsing.
document_summaries: Dict[str, str] = {}
def summarize_document(text: str, client: OpenAI, model: str) -> str:
    """Generate a summary of the document using OpenAI (used only for summaries).

    Only the first 4000 characters of *text* are sent, to bound prompt size.
    """
    prompt_messages = [
        {
            "role": "system",
            "content": (
                "You are a helpful assistant that creates concise "
                "summaries of documents."
            ),
        },
        {
            "role": "user",
            "content": (
                "Please provide a comprehensive summary of the "
                "following document:\n\n"
                f"{text[:4000]}"
            ),
        },
    ]
    completion = client.chat.completions.create(
        model=model,
        messages=prompt_messages,
        temperature=0.3,
        max_tokens=500,
    )
    return completion.choices[0].message.content
# ---------- PDF PARSING ----------
def _register_parsed_docs(pdf_name: str, docs: List, all_documents: List[Dict]) -> None:
    """Summarize one PDF's parsed sections and append them to all_documents.

    Shared by the batch path and the per-file fallback; also records the
    summary (or a "No content extracted" marker) in document_summaries.
    """
    if docs:
        full_text = " ".join(d.text for d in docs)
        summary = summarize_document(full_text, client, CONFIG["llm_model"])
        document_summaries[pdf_name] = summary

        print(f"Summary for {pdf_name}:")
        print(summary[:200] + "...\n")

        for d in docs:
            all_documents.append(
                {
                    "text": d.text,
                    "source": pdf_name,
                    "metadata": d.metadata if hasattr(d, "metadata") else {},
                }
            )
    else:
        print(f"Warning: No content extracted from {pdf_name}")
        document_summaries[pdf_name] = "No content extracted"


def parse_pdfs_with_llamaparse(pdf_directory: str) -> List[Dict]:
    """Parse PDFs using LlamaParse with batch processing.

    Tries one batched load_data() call first; on failure falls back to
    per-file parsing with a 2s delay between files. Returns a list of dicts
    with keys: text, source (filename), metadata.

    Raises RuntimeError if LLAMA_CLOUD_API_KEY is missing.
    """
    pdf_files = list(Path(pdf_directory).glob("*.pdf"))
    print(f"Found {len(pdf_files)} PDF files")
    if not pdf_files:
        # Nothing to parse — avoid a pointless LlamaParse round-trip.
        return []

    llama_key = os.environ.get("LLAMA_CLOUD_API_KEY")
    if not llama_key:
        raise RuntimeError("LLAMA_CLOUD_API_KEY is not set in the environment or .env.")

    parser = LlamaParse(
        api_key=llama_key,
        result_type="markdown",
        verbose=True,
        language="en",
        num_workers=4,
    )

    all_documents: List[Dict] = []

    try:
        print("\nParsing all PDFs in batch...")
        pdf_paths = [str(pdf) for pdf in pdf_files]
        documents_batch = parser.load_data(pdf_paths)
        print(f"✓ Successfully parsed {len(documents_batch)} document sections")

        # Attribute each returned section back to its source PDF. Sections
        # lacking file_path metadata are assigned to the current PDF, which
        # assumes LlamaParse returns sections in input order — TODO confirm.
        doc_index = 0
        for pdf_path in pdf_files:
            print(f"\nProcessing: {pdf_path.name}")
            pdf_docs = []

            while doc_index < len(documents_batch):
                doc = documents_batch[doc_index]

                if hasattr(doc, "metadata") and doc.metadata.get("file_path"):
                    if pdf_path.name in doc.metadata.get("file_path", ""):
                        pdf_docs.append(doc)
                        doc_index += 1
                    else:
                        break
                else:
                    pdf_docs.append(doc)
                    doc_index += 1
                    if doc_index >= len(documents_batch):
                        break

            _register_parsed_docs(pdf_path.name, pdf_docs, all_documents)

    except Exception as e:
        print(f"Batch processing failed: {str(e)}")
        print("\nFalling back to individual file processing with sleep delays...")

        import time

        for pdf_path in pdf_files:
            print(f"\nParsing: {pdf_path.name}")

            try:
                time.sleep(2)  # throttle after a batch failure
                documents = parser.load_data(str(pdf_path))
                _register_parsed_docs(pdf_path.name, documents, all_documents)
            except Exception as e2:
                print(f"Error parsing {pdf_path.name}: {str(e2)}")
                document_summaries[pdf_path.name] = f"Failed to parse: {str(e2)}"
                continue

    return all_documents
# ---------- CHUNKING ----------
def chunk_documents(
    documents: List[Dict],
    chunk_size: int = 512,
    chunk_overlap: int = 50,
) -> List[Dict]:
    """Chunk documents using semantic splitting.

    Each chunk inherits its parent's source and metadata and receives a
    globally sequential id of the form "chunk_<n>".
    """
    splitter = SentenceSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )

    all_chunks: List[Dict] = []
    next_id = 0

    for doc in documents:
        for piece in splitter.split_text(doc["text"]):
            all_chunks.append(
                {
                    "chunk_id": f"chunk_{next_id}",
                    "text": piece,
                    "source": doc["source"],
                    "metadata": doc["metadata"],
                }
            )
            next_id += 1

    return all_chunks
# ---------- CHROMA (SBERT EMBEDDINGS, 384-D) ----------
def create_chromadb_collection(
    chunks: List[Dict],
    collection_name: str = "rag_documents",
) -> chromadb.Collection:
    """Create and populate ChromaDB collection using SentenceTransformer embeddings.

    Any same-named collection is dropped first so stale vectors with a
    different dimensionality (e.g. old 1536-D ones) cannot linger.
    """
    embedder = SentenceTransformerEmbeddingFunction(
        model_name=CONFIG["embedding_model"],
        device=CONFIG["embedding_device"],
    )

    client_db = chromadb.PersistentClient(
        path=os.path.join(CONFIG["output_directory"], "chromadb")
    )

    # Delete existing collection to avoid old 1536-D vectors.
    try:
        client_db.delete_collection(collection_name)
        print(f"Deleted existing collection: {collection_name}")
    except Exception:
        pass  # no previous collection — nothing to drop

    collection = client_db.create_collection(
        name=collection_name,
        metadata={
            "description": "RAG document chunks",
            "embedding_model": CONFIG["embedding_model"],
            "embedding_dim": 384,  # MiniLM dim
        },
        embedding_function=embedder,
    )

    ids = [c["chunk_id"] for c in chunks]
    documents = [c["text"] for c in chunks]
    metadatas = [{"source": c["source"], **(c["metadata"] or {})} for c in chunks]

    # Insert in slices of 100 so each add() call stays small.
    batch_size = 100
    for start in range(0, len(ids), batch_size):
        end = min(start + batch_size, len(ids))

        collection.add(
            ids=ids[start:end],
            documents=documents[start:end],
            metadatas=metadatas[start:end],
        )

        print(
            f"Added batch {start // batch_size + 1}/"
            f"{(len(ids) - 1) // batch_size + 1}"
        )

    print(f"✓ ChromaDB collection created with {len(ids)} chunks")
    return collection
# ---------- MAIN ----------
def main():
    """Run the full ingestion pipeline: parse PDFs -> chunk -> index in Chroma."""
    print("✓ Starting pipeline with .env configuration (SentenceTransformer embeddings)")

    print("Starting PDF parsing...")
    parsed_documents = parse_pdfs_with_llamaparse(CONFIG["pdf_directory"])
    print(f"\n✓ Parsed {len(parsed_documents)} document sections from PDFs")

    chunks = chunk_documents(
        parsed_documents,
        CONFIG["chunk_size"],
        CONFIG["chunk_overlap"],
    )
    print(f"✓ Created {len(chunks)} chunks")
    if chunks:
        print("\nSample chunk:")
        print(chunks[0])

    chroma_collection = create_chromadb_collection(chunks)
    print("ChromaDB collection ready for querying.")


if __name__ == "__main__":
    main()
requirements.txt
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- API layer ---
fastapi
uvicorn[standard]

pydantic
python-dotenv

pandas

# --- PDF parsing & chunking ---
llama-index-core
llama-parse

# --- Retrieval (dense + sparse) ---
chromadb
sentence-transformers
rank-bm25

openai
nest-asyncio

numpy
retrieval.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
| 2 |
+
from pydantic import BaseModel, Field, computed_field
|
| 3 |
+
from typing import List, Optional, Dict, Any
|
| 4 |
+
import logging
|
| 5 |
+
import numpy as np
|
| 6 |
+
from sentence_transformers import CrossEncoder
|
| 7 |
+
|
| 8 |
+
from vector_store import get_vector_store, VectorStoreManager
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
router = APIRouter(prefix="/retrieval", tags=["retrieval"])
|
| 12 |
+
|
# Module-level cache for the cross-encoder; loaded once, on first use.
_reranker = None


def get_reranker():
    """Return the process-wide cross-encoder reranker, loading it lazily."""
    global _reranker
    if _reranker is None:
        logger.info("Loading cross-encoder reranker...")
        _reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    return _reranker
| 22 |
+
|
| 23 |
+
|
class RetrievalRequest(BaseModel):
    """Request body for POST /retrieval/search."""

    question: str = Field(..., min_length=1, max_length=500)
    top_k: int = Field(default=5, ge=1, le=20)

    # Exact-match metadata filters; all optional.
    filter_by_cluster: Optional[str] = None
    filter_by_source: Optional[str] = None
    filter_by_topic: Optional[str] = None
    # Substring that retrieved documents must contain.
    contains_text: Optional[str] = None

    # Candidates whose distance exceeds this are dropped (default keeps all
    # dense hits up to distance 1.0; BM25-only hits get distance 1.5 upstream).
    similarity_threshold: float = Field(default=1.0, ge=0.0, le=2.0)

    # ✅ Hybrid retrieval toggles
    enable_bm25: bool = Field(
        default=False,
        description="Enable BM25 + semantic hybrid retrieval",
    )
    bm25_k: int = Field(
        default=20,
        ge=5,
        le=100,
        description="How many BM25 candidates to consider",
    )
    hybrid_alpha: float = Field(
        default=0.4,
        ge=0.0,
        le=1.0,
        description="Dense weight in hybrid fusion (alpha=1 => semantic only)",
    )

    # Reranking
    enable_rerank: bool = Field(default=False)
    rerank_top_k: int = Field(default=3, ge=1, le=10)
| 56 |
+
|
| 57 |
+
|
class DocumentResult(BaseModel):
    """A single retrieved chunk as returned to the API caller."""

    chunk_id: str
    text: str
    source: str
    topic: Optional[str]
    cluster: Optional[str]
    # Embedding-space distance (lower = more similar).
    distance: float
    # Cross-encoder score; populated only when reranking was applied.
    rerank_score: Optional[float] = None

    @computed_field
    @property
    def relevance_label(self) -> str:
        # Human-readable bucket derived from the raw distance.
        if self.distance < 0.8:
            return "Highly Relevant"
        elif self.distance < 1.0:
            return "Relevant"
        elif self.distance < 1.5:
            return "Somewhat Relevant"
        return "Low Relevance"
| 77 |
+
|
| 78 |
+
|
class RetrievalResponse(BaseModel):
    """Envelope returned by POST /retrieval/search."""

    documents: List[DocumentResult]
    # Number of documents returned (== len(documents)).
    count: int
    # The original question, echoed back for convenience.
    query: str
    # The filters/toggles that were applied, for client-side debugging.
    filters_applied: Dict[str, Any]
    # Retrieval method, candidate counts, best/avg distances, etc.
    retrieval_stats: Dict[str, Any]
| 85 |
+
|
| 86 |
+
|
def rerank_documents(query: str, documents: List[DocumentResult], top_k: int = 3):
    """Re-order candidates with the cross-encoder and keep the best top_k.

    On any scoring failure, falls back to the first top_k original results.
    Lists of zero or one documents are returned untouched.
    """
    if not documents or len(documents) <= 1:
        return documents

    try:
        model = get_reranker()
        # Truncate passages to 1500 chars to bound cross-encoder input length.
        pairs = [[query, doc.text[:1500]] for doc in documents]

        predicted = model.predict(pairs)

        for doc, score in zip(documents, predicted):
            doc.rerank_score = float(score)

        by_score = sorted(documents, key=lambda d: d.rerank_score or 0.0, reverse=True)
        return by_score[:top_k]

    except Exception as e:
        logger.error(f"Reranking failed: {str(e)}, returning original results")
        return documents[:top_k]
| 106 |
+
|
| 107 |
+
|
@router.post("/search", response_model=RetrievalResponse)
async def retrieve_documents_endpoint(
    request: RetrievalRequest,
    vector_store: VectorStoreManager = Depends(get_vector_store),
):
    """Semantic or hybrid document search with optional cross-encoder reranking.

    Pipeline: build metadata/text filters -> fetch candidates from the vector
    store (over-fetching 3x when a second stage will prune) -> drop candidates
    above the similarity threshold -> optionally rerank -> assemble response.
    """
    try:
        logger.info(f"Processing query: '{request.question}' top_k={request.top_k}")

        # Exact-match metadata filters.
        where_filters: Dict[str, Any] = {}
        if request.filter_by_cluster:
            where_filters["cluster"] = request.filter_by_cluster
        if request.filter_by_source:
            where_filters["source"] = request.filter_by_source
        if request.filter_by_topic:
            where_filters["topic"] = request.filter_by_topic

        where_document = {"$contains": request.contains_text} if request.contains_text else None

        # Over-fetch when reranking or BM25 fusion will prune afterwards.
        if request.enable_rerank or request.enable_bm25:
            n_candidates = request.top_k * 3
        else:
            n_candidates = request.top_k

        candidates = vector_store.retrieve_documents(
            question=request.question,
            n_results=n_candidates,
            where_filters=where_filters if where_filters else None,
            where_document=where_document,
            enable_bm25=request.enable_bm25,
            bm25_k=request.bm25_k,
            alpha=request.hybrid_alpha,
        )

        documents: List[DocumentResult] = []
        filtered_count = 0

        for cand in candidates:
            distance = cand.get("distance")
            # A candidate that came only from BM25 carries no dense distance.
            if distance is None:
                distance = 1.5  # treat as weak semantic match

            if distance > request.similarity_threshold:
                filtered_count += 1
                continue

            meta = cand.get("metadata") or {}
            documents.append(
                DocumentResult(
                    chunk_id=cand["id"],
                    text=cand["text"],
                    source=meta.get("source", "Unknown"),
                    topic=meta.get("topic"),
                    cluster=meta.get("cluster"),
                    distance=float(distance),
                )
            )

        total_retrieved = len(candidates)

        # Second stage: rerank if requested, otherwise plain truncation.
        if request.enable_rerank and len(documents) > 1:
            documents = rerank_documents(request.question, documents, request.rerank_top_k)
            retrieval_method = "hybrid_with_rerank" if request.enable_bm25 else "semantic_with_rerank"
        else:
            documents = documents[:request.top_k]
            retrieval_method = "hybrid" if request.enable_bm25 else "semantic"

        distances = [d.distance for d in documents]
        avg_distance = float(np.mean(distances)) if distances else None
        best_distance = min(distances) if distances else None

        return RetrievalResponse(
            documents=documents,
            count=len(documents),
            query=request.question,
            filters_applied={
                "cluster": request.filter_by_cluster,
                "source": request.filter_by_source,
                "topic": request.filter_by_topic,
                "contains_text": request.contains_text,
                "similarity_threshold": request.similarity_threshold,
                "enable_bm25": request.enable_bm25,
                "bm25_k": request.bm25_k,
                "hybrid_alpha": request.hybrid_alpha,
            },
            retrieval_stats={
                "method": retrieval_method,
                "total_retrieved": total_retrieved,
                "filtered_by_threshold": filtered_count,
                "returned": len(documents),
                "best_distance": best_distance,
                "avg_distance": avg_distance,
                "reranking_applied": request.enable_rerank,
                "bm25_applied": request.enable_bm25,
            },
        )

    except Exception as e:
        logger.error(f"Retrieval failed: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Retrieval failed: {str(e)}",
        )
vector_store.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import Optional, Dict, Any, List
|
| 3 |
+
import threading
|
| 4 |
+
import re
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import chromadb
|
| 8 |
+
from rank_bm25 import BM25Okapi
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
|
class VectorStoreManager:
    """Process-wide singleton wrapping the Chroma collection plus a BM25 index.

    Construction loads the persisted 'rag_documents' collection and builds an
    in-memory BM25 index over all of its documents; both happen exactly once.
    """

    _instance = None
    _lock = threading.Lock()
    _initialized = False

    def __new__(cls):
        # Serialize instance creation so exactly one object ever exists.
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every VectorStoreManager() call; the flag makes
        # the expensive initialization happen only the first time.
        with self._lock:
            if not self._initialized:
                self._initialize()
                VectorStoreManager._initialized = True

    def _initialize(self):
        """Initialize vector store with single collection + BM25 index"""
        try:
            logger.info("Initializing vector store components...")

            self.client = None
            self.collection = None

            db_path = "output/chromadb"  # Match your pipeline path
            self.client = chromadb.PersistentClient(path=db_path)
            logger.info(f"ChromaDB client initialized at path: {db_path}")

            available_collections = [col.name for col in self.client.list_collections()]
            logger.info(f"Available collections: {available_collections}")

            try:
                self.collection = self.client.get_collection("rag_documents")
                collection_count = self.collection.count()
                logger.info(
                    f"Collection 'rag_documents' loaded with {collection_count} documents"
                )
            except Exception as e:
                logger.error(f"Collection 'rag_documents' not found: {str(e)}")
                raise ValueError(
                    "Required collection 'rag_documents' not found. "
                    f"Available: {available_collections}"
                )

            # ---- Build BM25 index from all stored docs ----
            logger.info("Building BM25 index from Chroma collection...")
            data = self.collection.get(include=["documents", "metadatas"])

            self.all_ids: List[str] = data["ids"]
            self.all_docs: List[str] = data["documents"]
            self.all_metas: List[Dict[str, Any]] = data["metadatas"]

            self.tokenized_corpus = [self._tokenize(d) for d in self.all_docs]
            self.bm25 = BM25Okapi(self.tokenized_corpus)

            logger.info(f"BM25 index ready with {len(self.all_docs)} chunks")
            logger.info("Vector store initialized successfully")

        except Exception as e:
            logger.error(f"Failed to initialize vector store: {str(e)}")
            # Allow a later construction attempt to retry initialization.
            VectorStoreManager._initialized = False
            raise

    # ----------------- Helpers -----------------
    def _tokenize(self, text: str) -> List[str]:
        """Lower-cased word tokenizer shared by indexing and querying."""
        return re.findall(r"\w+", (text or "").lower())

    def _matches_filters(
        self,
        meta: Dict[str, Any],
        doc_text: str,
        where_filters: Optional[Dict[str, Any]],
        where_document: Optional[Dict[str, Any]],
    ) -> bool:
        """Mirror Chroma's where / where_document filtering for the BM25 path."""
        if where_filters:
            if any(meta.get(key) != wanted for key, wanted in where_filters.items()):
                return False

        if where_document:
            # Only the {"$contains": "..."} operator is supported here.
            needle = where_document.get("$contains")
            if needle and needle.lower() not in (doc_text or "").lower():
                return False

        return True

    def _rrf_fuse(
        self,
        dense_ranked: List[Dict[str, Any]],
        sparse_ranked: List[Dict[str, Any]],
        k: int = 60,
        w_dense: float = 0.6,
        w_sparse: float = 0.4,
    ) -> List[Dict[str, Any]]:
        """
        Reciprocal Rank Fusion
        score = w_dense/(k+rank_dense) + w_sparse/(k+rank_sparse)
        """
        scores: Dict[str, Dict[str, Any]] = {}

        def accumulate(ranked, weight):
            # The first list to mention an id also supplies its payload.
            for rank, item in enumerate(ranked):
                entry = scores.setdefault(item["id"], {"score": 0.0, "item": item})
                entry["score"] += weight / (k + rank + 1)

        accumulate(dense_ranked, w_dense)
        accumulate(sparse_ranked, w_sparse)

        fused = sorted(scores.values(), key=lambda x: x["score"], reverse=True)
        return [x["item"] for x in fused]

    # ----------------- Main retrieval -----------------
    def retrieve_documents(
        self,
        question: str,
        n_results: int = 5,
        where_filters: Optional[Dict[str, Any]] = None,
        where_document: Optional[Dict[str, Any]] = None,
        enable_bm25: bool = False,
        bm25_k: Optional[int] = None,
        alpha: float = 0.6,  # dense weight in hybrid fusion
    ) -> List[Dict[str, Any]]:
        """
        Retrieve documents using:
          - semantic-only (Chroma)
          - or hybrid semantic + BM25 (RRF fusion)

        Returns a list of dicts:
            {id, text, metadata, distance, bm25_score(optional)}
        """
        if not self._initialized or self.collection is None:
            raise RuntimeError("VectorStoreManager not properly initialized")

        logger.info(f"Retrieving documents for query: {question[:50]}...")
        dense_k = n_results
        bm25_k = bm25_k or n_results

        # ----- Dense retrieval (semantic via Chroma) -----
        try:
            dense_res = self.collection.query(
                query_texts=[question],
                n_results=dense_k,
                include=["documents", "metadatas", "distances"],
                where=where_filters if where_filters else None,
                where_document=where_document if where_document else None,
            )
        except Exception as e:
            logger.error(f"Dense retrieval failed: {str(e)}")
            raise

        dense_ranked: List[Dict[str, Any]] = []
        if dense_res and dense_res.get("documents") and dense_res["documents"][0]:
            hits = zip(
                dense_res["ids"][0],
                dense_res["documents"][0],
                dense_res["metadatas"][0],
                dense_res["distances"][0],
            )
            for doc_id, text, meta, dist in hits:
                dense_ranked.append({
                    "id": doc_id,
                    "text": text,
                    "metadata": meta,
                    "distance": float(dist),
                    "source": meta.get("source", "Unknown"),
                })

        if not enable_bm25:
            logger.info(f"Semantic-only retrieved {len(dense_ranked)} docs")
            return dense_ranked

        # ----- Sparse retrieval (BM25) -----
        bm25_scores = self.bm25.get_scores(self._tokenize(question))

        # Apply the same filters Chroma used, then keep the top bm25_k hits.
        valid_scores = [
            (idx, bm25_scores[idx])
            for idx, (doc, meta) in enumerate(zip(self.all_docs, self.all_metas))
            if self._matches_filters(meta, doc, where_filters, where_document)
        ]
        valid_scores.sort(key=lambda pair: pair[1], reverse=True)

        sparse_ranked: List[Dict[str, Any]] = []
        for idx, score in valid_scores[:bm25_k]:
            meta = self.all_metas[idx]
            sparse_ranked.append({
                "id": self.all_ids[idx],
                "text": self.all_docs[idx],
                "metadata": meta,
                "bm25_score": float(score),
                "distance": None,  # may be absent if not in dense top-k
                "source": meta.get("source", "Unknown"),
            })

        # ----- Fuse dense + sparse -----
        fused = self._rrf_fuse(
            dense_ranked,
            sparse_ranked,
            w_dense=alpha,
            w_sparse=1.0 - alpha,
        )

        logger.info(
            f"Hybrid retrieved dense={len(dense_ranked)} sparse={len(sparse_ranked)} "
            f"fused={len(fused)}"
        )
        return fused
| 224 |
+
|
| 225 |
+
|
def get_vector_store() -> VectorStoreManager:
    """FastAPI dependency returning the shared VectorStoreManager singleton."""
    # Cheap after first call: construction is a no-op once initialized.
    return VectorStoreManager()