mokhles commited on
Commit
af37875
·
1 Parent(s): b891d55

Initial commit: Insurance RAG API

Browse files
Files changed (7) hide show
  1. Dockerfile +16 -0
  2. README.md +5 -4
  3. app.py +23 -0
  4. chroma.py +302 -0
  5. requirements.txt +19 -0
  6. retrieval.py +208 -0
  7. vector_store.py +228 -0
Dockerfile ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Slim Python 3.11 base image for the Insurance RAG API.
FROM python:3.11-slim

WORKDIR /app

# build-essential lets pip compile native wheels when no prebuilt one exists;
# the apt cache is purged in the same layer to keep the image small.
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

# Copy and install requirements before the source so this layer is cached
# across code-only changes.
COPY requirements.txt .
RUN pip install --no-cache-dir --upgrade -r requirements.txt

COPY . .

# 7860 is the port Hugging Face Spaces expects a Docker Space to serve on.
EXPOSE 7860

CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,10 +1,11 @@
1
  ---
2
- title: RAG Insurance
3
- emoji: 😻
4
- colorFrom: green
5
- colorTo: indigo
6
  sdk: docker
7
  pinned: false
 
8
  ---
9
 
10
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Insurance Rag Api
3
+ emoji: 🌖
4
+ colorFrom: blue
5
+ colorTo: red
6
  sdk: docker
7
  pinned: false
8
+ short_description: FastAPI Retrieval-Augmented Generation (RAG) API
9
  ---
10
 
11
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# app.py
"""FastAPI entry point for the Insurance RAG API.

Mounts the /retrieval router and configures permissive CORS for
browser-based clients.
"""
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from retrieval import router as retrieval_router


app = FastAPI(title="Insurance RAG API", version="1.0.0")

# FIX: the CORS spec forbids `Access-Control-Allow-Origin: *` on
# credentialed requests, so the previous allow_origins=["*"] +
# allow_credentials=True combination made browsers reject credentialed
# cross-origin calls. Credentials are disabled for the wildcard origin;
# list explicit origins here if cookies/auth headers are ever needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.include_router(retrieval_router)


@app.get("/")
async def root():
    """Health/landing endpoint pointing clients at the interactive docs."""
    return {"message": "Insurance RAG API is running", "docs": "/docs"}
chroma.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # chroma.py (minimal, no visualization, WITH sentence-transformers, with .env)
2
+
3
+ import os
4
+ import warnings
5
+ from pathlib import Path
6
+ from typing import List, Dict
7
+
8
+ import pandas as pd # (currently unused but kept if you need it later)
9
+ from dotenv import load_dotenv
10
+
11
+ from llama_parse import LlamaParse
12
+ from llama_index.core.node_parser import SentenceSplitter
13
+
14
+ import chromadb
15
+ from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
16
+ from openai import OpenAI
17
+
18
+ import nest_asyncio
19
+ nest_asyncio.apply()
20
+
21
+ warnings.filterwarnings("ignore")
22
+
23
# ---------- LOAD .env ----------
load_dotenv()

# ---------- CONFIG ----------
CONFIG = {
    # FIX: the PDF directory was hard-coded to one developer's Windows path,
    # which can never exist inside the Docker image. It is now overridable
    # via PDF_DIRECTORY (default preserved for backward compatibility).
    "pdf_directory": os.getenv(
        "PDF_DIRECTORY", r"C:\Users\Legion\Documents\Ominimo Job\Pdfs for RAG"
    ),
    # Where the persistent ChromaDB store (and other artifacts) are written.
    "output_directory": os.getenv("OUTPUT_DIRECTORY", "./output/"),
    "llm_model": "gpt-4.1-mini",
    "chunk_size": 512,
    "chunk_overlap": 50,
    "top_k_retrieval": 3,

    # SentenceTransformer embedding model (384-D for MiniLM).
    # Must match the retrieval-side embedding model.
    "embedding_model": "all-MiniLM-L6-v2",

    # Optional: force device ("cpu" or "cuda")
    "embedding_device": os.getenv("EMB_DEVICE", "cpu"),
}

Path(CONFIG["output_directory"]).mkdir(parents=True, exist_ok=True)

# ---------- OPENAI CLIENT (for summaries only) ----------
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    raise RuntimeError("OPENAI_API_KEY is not set in the environment or .env file.")

client = OpenAI(api_key=OPENAI_API_KEY)
# Per-PDF summaries keyed by filename, filled in during parsing.
document_summaries: Dict[str, str] = {}
52
+
53
+
54
def summarize_document(text: str, client: OpenAI, model: str) -> str:
    """Generate a concise summary of *text* via the OpenAI chat API.

    Only the first 4000 characters are sent, keeping the prompt bounded.
    """
    system_msg = {
        "role": "system",
        "content": (
            "You are a helpful assistant that creates concise "
            "summaries of documents."
        ),
    }
    user_msg = {
        "role": "user",
        "content": (
            "Please provide a comprehensive summary of the "
            "following document:\n\n"
            f"{text[:4000]}"
        ),
    }
    completion = client.chat.completions.create(
        model=model,
        messages=[system_msg, user_msg],
        temperature=0.3,
        max_tokens=500,
    )
    return completion.choices[0].message.content
79
+
80
+
81
# ---------- PDF PARSING ----------
def parse_pdfs_with_llamaparse(pdf_directory: str) -> List[Dict]:
    """Parse PDFs using LlamaParse with batch processing.

    Returns a list of ``{"text", "source", "metadata"}`` dicts, one per
    parsed document section, and fills the module-level
    ``document_summaries`` with an OpenAI summary per PDF file.

    Raises:
        RuntimeError: if LLAMA_CLOUD_API_KEY is not configured.
    """
    pdf_files = list(Path(pdf_directory).glob("*.pdf"))
    print(f"Found {len(pdf_files)} PDF files")

    llama_key = os.environ.get("LLAMA_CLOUD_API_KEY")
    if not llama_key:
        raise RuntimeError("LLAMA_CLOUD_API_KEY is not set in the environment or .env.")

    parser = LlamaParse(
        api_key=llama_key,
        result_type="markdown",
        verbose=True,
        language="en",
        num_workers=4,
    )

    all_documents: List[Dict] = []

    try:
        print("\nParsing all PDFs in batch...")
        pdf_paths = [str(pdf) for pdf in pdf_files]
        documents_batch = parser.load_data(pdf_paths)
        print(f"✓ Successfully parsed {len(documents_batch)} document sections")

        # Re-associate the flat batch output with its source PDFs.
        # NOTE(review): this assumes load_data returns sections grouped in
        # input-file order; "file_path" metadata, when present, is used to
        # detect the boundary between files — confirm against LlamaParse docs.
        doc_index = 0
        for pdf_path in pdf_files:
            print(f"\nProcessing: {pdf_path.name}")
            pdf_docs = []

            while doc_index < len(documents_batch):
                doc = documents_batch[doc_index]

                if hasattr(doc, "metadata") and doc.metadata.get("file_path"):
                    # Section names its file: consume it only while it still
                    # belongs to the current PDF; otherwise move to next PDF.
                    if pdf_path.name in doc.metadata.get("file_path", ""):
                        pdf_docs.append(doc)
                        doc_index += 1
                    else:
                        break
                else:
                    # No file_path metadata: fall back to positional grouping —
                    # metadata-less sections accumulate into the current PDF.
                    pdf_docs.append(doc)
                    doc_index += 1
                    if doc_index >= len(documents_batch):
                        break

            if pdf_docs:
                # One summary per source PDF, over its concatenated text.
                full_text = " ".join([d.text for d in pdf_docs])
                summary = summarize_document(full_text, client, CONFIG["llm_model"])
                document_summaries[pdf_path.name] = summary

                print(f"Summary for {pdf_path.name}:")
                print(summary[:200] + "...\n")

                for d in pdf_docs:
                    all_documents.append(
                        {
                            "text": d.text,
                            "source": pdf_path.name,
                            "metadata": d.metadata if hasattr(d, "metadata") else {},
                        }
                    )
            else:
                print(f"Warning: No content extracted from {pdf_path.name}")
                document_summaries[pdf_path.name] = "No content extracted"

    except Exception as e:
        # Batch call failed (e.g. API error): retry one file at a time,
        # throttled, so a single bad PDF cannot sink the whole run.
        print(f"Batch processing failed: {str(e)}")
        print("\nFalling back to individual file processing with sleep delays...")

        import time

        for pdf_path in pdf_files:
            print(f"\nParsing: {pdf_path.name}")

            try:
                time.sleep(2)  # throttle requests to avoid rate limiting
                documents = parser.load_data(str(pdf_path))

                if documents:
                    full_text = " ".join([d.text for d in documents])
                    summary = summarize_document(full_text, client, CONFIG["llm_model"])
                    document_summaries[pdf_path.name] = summary

                    print(f"Summary for {pdf_path.name}:")
                    print(summary[:200] + "...\n")

                    for d in documents:
                        all_documents.append(
                            {
                                "text": d.text,
                                "source": pdf_path.name,
                                "metadata": d.metadata if hasattr(d, "metadata") else {},
                            }
                        )
                else:
                    print(f"Warning: No content extracted from {pdf_path.name}")
                    document_summaries[pdf_path.name] = "No content extracted"

            except Exception as e2:
                # Record the failure in the summaries map and keep going.
                print(f"Error parsing {pdf_path.name}: {str(e2)}")
                document_summaries[pdf_path.name] = f"Failed to parse: {str(e2)}"
                continue

    return all_documents
186
+
187
+
188
# ---------- CHUNKING ----------
def chunk_documents(
    documents: List[Dict],
    chunk_size: int = 512,
    chunk_overlap: int = 50,
) -> List[Dict]:
    """Chunk documents using semantic splitting."""
    splitter = SentenceSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )

    chunked: List[Dict] = []
    next_id = 0

    for doc in documents:
        # Split each document's text, carrying source/metadata onto every
        # resulting chunk; chunk ids are globally sequential.
        for piece in splitter.split_text(doc["text"]):
            chunked.append(
                {
                    "chunk_id": f"chunk_{next_id}",
                    "text": piece,
                    "source": doc["source"],
                    "metadata": doc["metadata"],
                }
            )
            next_id += 1

    return chunked
218
+
219
+
220
# ---------- CHROMA (SBERT EMBEDDINGS, 384-D) ----------
def create_chromadb_collection(
    chunks: List[Dict],
    collection_name: str = "rag_documents",
) -> chromadb.Collection:
    """Create and populate ChromaDB collection using SentenceTransformer embeddings.

    Any pre-existing collection of the same name is dropped first, so stale
    vectors of a different dimensionality cannot linger; chunks are then
    inserted in batches of 100.
    """

    sbert_ef = SentenceTransformerEmbeddingFunction(
        model_name=CONFIG["embedding_model"],
        device=CONFIG["embedding_device"],
    )

    client_db = chromadb.PersistentClient(
        path=os.path.join(CONFIG["output_directory"], "chromadb")
    )

    # Delete existing collection to avoid old 1536-D vectors.
    try:
        client_db.delete_collection(collection_name)
        print(f"Deleted existing collection: {collection_name}")
    except Exception:
        pass  # collection did not exist yet — nothing to delete

    collection = client_db.create_collection(
        name=collection_name,
        metadata={
            "description": "RAG document chunks",
            "embedding_model": CONFIG["embedding_model"],
            "embedding_dim": 384,  # MiniLM dim
        },
        embedding_function=sbert_ef,
    )

    # Column-parallel lists expected by Chroma's add().
    ids = [chunk["chunk_id"] for chunk in chunks]
    documents = [chunk["text"] for chunk in chunks]
    metadatas = [
        {"source": chunk["source"], **(chunk["metadata"] or {})}
        for chunk in chunks
    ]

    # Insert in batches; embeddings are computed by the collection's
    # embedding function at add() time.
    batch_size = 100
    for i in range(0, len(ids), batch_size):
        batch_end = min(i + batch_size, len(ids))

        collection.add(
            ids=ids[i:batch_end],
            documents=documents[i:batch_end],
            metadatas=metadatas[i:batch_end],
        )

        print(
            f"Added batch {i // batch_size + 1}/"
            f"{(len(ids) - 1) // batch_size + 1}"
        )

    print(f"✓ ChromaDB collection created with {len(ids)} chunks")
    return collection
277
+
278
+
279
# ---------- MAIN ----------
def main():
    """Run the ingestion pipeline: parse PDFs, chunk them, index in Chroma."""
    print("✓ Starting pipeline with .env configuration (SentenceTransformer embeddings)")

    print("Starting PDF parsing...")
    docs = parse_pdfs_with_llamaparse(CONFIG["pdf_directory"])
    print(f"\n✓ Parsed {len(docs)} document sections from PDFs")

    doc_chunks = chunk_documents(
        docs,
        CONFIG["chunk_size"],
        CONFIG["chunk_overlap"],
    )
    print(f"✓ Created {len(doc_chunks)} chunks")
    if doc_chunks:
        print("\nSample chunk:")
        print(doc_chunks[0])

    create_chromadb_collection(doc_chunks)
    print("ChromaDB collection ready for querying.")


if __name__ == "__main__":
    main()
requirements.txt ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn[standard]
3
+
4
+ pydantic
5
+ python-dotenv
6
+
7
+ pandas
8
+
9
+ llama-index-core
10
+ llama-parse
11
+
12
+ chromadb
13
+ sentence-transformers
14
+ rank-bm25
15
+
16
+ openai
17
+ nest-asyncio
18
+
19
+ numpy
retrieval.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, Depends, HTTPException, status, Query
2
+ from pydantic import BaseModel, Field, computed_field
3
+ from typing import List, Optional, Dict, Any
4
+ import logging
5
+ import numpy as np
6
+ from sentence_transformers import CrossEncoder
7
+
8
+ from vector_store import get_vector_store, VectorStoreManager
9
+
10
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/retrieval", tags=["retrieval"])

# Lazily-loaded cross-encoder shared by all requests; loading the model is
# slow, so it is deferred to first use and cached at module scope.
_reranker = None


def get_reranker():
    """Return the module-wide CrossEncoder, loading it on first call.

    NOTE(review): not guarded by a lock — concurrent first calls could each
    load the model (wasteful but harmless); confirm whether that matters.
    """
    global _reranker
    if _reranker is None:
        logger.info("Loading cross-encoder reranker...")
        _reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    return _reranker
22
+
23
+
24
class RetrievalRequest(BaseModel):
    """Request schema for POST /retrieval/search."""

    # The natural-language question to retrieve for.
    question: str = Field(..., min_length=1, max_length=500)
    # Maximum number of documents to return.
    top_k: int = Field(default=5, ge=1, le=20)

    # Optional exact-match metadata filters.
    filter_by_cluster: Optional[str] = None
    filter_by_source: Optional[str] = None
    filter_by_topic: Optional[str] = None
    # Optional substring the document text must contain.
    contains_text: Optional[str] = None

    # Candidates whose distance exceeds this are dropped (lower = stricter).
    similarity_threshold: float = Field(default=1.0, ge=0.0, le=2.0)

    # Hybrid retrieval toggles
    enable_bm25: bool = Field(
        default=False,
        description="Enable BM25 + semantic hybrid retrieval",
    )
    bm25_k: int = Field(
        default=20,
        ge=5,
        le=100,
        description="How many BM25 candidates to consider",
    )
    hybrid_alpha: float = Field(
        default=0.4,
        ge=0.0,
        le=1.0,
        description="Dense weight in hybrid fusion (alpha=1 => semantic only)",
    )

    # Cross-encoder reranking toggles.
    enable_rerank: bool = Field(default=False)
    rerank_top_k: int = Field(default=3, ge=1, le=10)
56
+
57
+
58
class DocumentResult(BaseModel):
    """A single retrieved chunk with its retrieval scores."""

    chunk_id: str
    text: str
    source: str
    topic: Optional[str]
    cluster: Optional[str]
    # Embedding-space distance; lower is better. BM25-only candidates are
    # assigned 1.5 upstream before construction.
    distance: float
    # Cross-encoder score, set only when reranking was applied.
    rerank_score: Optional[float] = None

    @computed_field
    @property
    def relevance_label(self) -> str:
        """Human-readable bucket derived from the distance thresholds."""
        if self.distance < 0.8:
            return "Highly Relevant"
        elif self.distance < 1.0:
            return "Relevant"
        elif self.distance < 1.5:
            return "Somewhat Relevant"
        return "Low Relevance"
77
+
78
+
79
class RetrievalResponse(BaseModel):
    """Response schema for POST /retrieval/search."""

    documents: List[DocumentResult]
    # Convenience: equals len(documents).
    count: int
    # Echo of the original question.
    query: str
    # The filter values the request carried (including defaults).
    filters_applied: Dict[str, Any]
    # Diagnostics: method used, counts, and distance statistics.
    retrieval_stats: Dict[str, Any]
85
+
86
+
87
def rerank_documents(query: str, documents: List[DocumentResult], top_k: int = 3):
    """Re-score *documents* against *query* with the cross-encoder.

    Sets each document's ``rerank_score`` in place and returns the best
    *top_k* by that score. On any failure the original order is returned,
    truncated to *top_k*. Document text is capped at 1500 characters per
    scoring pair.
    """
    if not documents or len(documents) <= 1:
        return documents

    try:
        model = get_reranker()
        candidate_pairs = [[query, doc.text[:1500]] for doc in documents]

        predictions = model.predict(candidate_pairs)

        for doc, pred in zip(documents, predictions):
            doc.rerank_score = float(pred)

        by_score = sorted(documents, key=lambda d: d.rerank_score or 0.0, reverse=True)
        return by_score[:top_k]

    except Exception as e:
        logger.error(f"Reranking failed: {str(e)}, returning original results")
        return documents[:top_k]
106
+
107
+
108
@router.post("/search", response_model=RetrievalResponse)
async def retrieve_documents_endpoint(
    request: RetrievalRequest,
    vector_store: VectorStoreManager = Depends(get_vector_store),
):
    """Retrieve document chunks relevant to a question.

    Pipeline: (1) build metadata/text filters, (2) fetch candidates from the
    vector store (semantic or hybrid), (3) drop candidates above the
    similarity threshold, (4) optionally rerank with the cross-encoder, and
    (5) return the top results with retrieval statistics.

    Raises HTTP 500 on any unexpected failure.
    """
    try:
        logger.info(f"Processing query: '{request.question}' top_k={request.top_k}")

        # Exact-match metadata filters (Chroma `where` clause).
        where_filters: Dict[str, Any] = {}
        if request.filter_by_cluster:
            where_filters["cluster"] = request.filter_by_cluster
        if request.filter_by_source:
            where_filters["source"] = request.filter_by_source
        if request.filter_by_topic:
            where_filters["topic"] = request.filter_by_topic

        where_document = {"$contains": request.contains_text} if request.contains_text else None

        # If reranking or hybrid, fetch more candidates (3x) so the
        # threshold/rerank stages have something to prune.
        n_candidates = request.top_k * 3 if (request.enable_rerank or request.enable_bm25) else request.top_k

        candidates = vector_store.retrieve_documents(
            question=request.question,
            n_results=n_candidates,
            where_filters=where_filters if where_filters else None,
            where_document=where_document,
            enable_bm25=request.enable_bm25,
            bm25_k=request.bm25_k,
            alpha=request.hybrid_alpha,
        )

        documents: List[DocumentResult] = []
        filtered_count = 0  # candidates rejected by the similarity threshold

        for c in candidates:
            distance = c.get("distance")
            # if candidate came only from BM25, distance may be None
            if distance is None:
                distance = 1.5  # treat as weak semantic match

            if distance <= request.similarity_threshold:
                meta = c.get("metadata") or {}
                documents.append(
                    DocumentResult(
                        chunk_id=c["id"],
                        text=c["text"],
                        source=meta.get("source", "Unknown"),
                        topic=meta.get("topic"),
                        cluster=meta.get("cluster"),
                        distance=float(distance),
                    )
                )
            else:
                filtered_count += 1

        total_retrieved = len(candidates)

        # Rerank if enabled (only worthwhile with >1 surviving document);
        # otherwise just truncate to the requested top_k.
        if request.enable_rerank and len(documents) > 1:
            documents = rerank_documents(request.question, documents, request.rerank_top_k)
            retrieval_method = "hybrid_with_rerank" if request.enable_bm25 else "semantic_with_rerank"
        else:
            documents = documents[:request.top_k]
            retrieval_method = "hybrid" if request.enable_bm25 else "semantic"

        # Distance statistics over the documents actually returned.
        distances = [d.distance for d in documents]
        avg_distance = float(np.mean(distances)) if distances else None
        best_distance = min(distances) if distances else None

        return RetrievalResponse(
            documents=documents,
            count=len(documents),
            query=request.question,
            filters_applied={
                "cluster": request.filter_by_cluster,
                "source": request.filter_by_source,
                "topic": request.filter_by_topic,
                "contains_text": request.contains_text,
                "similarity_threshold": request.similarity_threshold,
                "enable_bm25": request.enable_bm25,
                "bm25_k": request.bm25_k,
                "hybrid_alpha": request.hybrid_alpha,
            },
            retrieval_stats={
                "method": retrieval_method,
                "total_retrieved": total_retrieved,
                "filtered_by_threshold": filtered_count,
                "returned": len(documents),
                "best_distance": best_distance,
                "avg_distance": avg_distance,
                "reranking_applied": request.enable_rerank,
                "bm25_applied": request.enable_bm25,
            },
        )

    except Exception as e:
        logger.error(f"Retrieval failed: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Retrieval failed: {str(e)}",
        )
+ )
vector_store.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Optional, Dict, Any, List
3
+ import threading
4
+ import re
5
+
6
+ import numpy as np
7
+ import chromadb
8
+ from rank_bm25 import BM25Okapi
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
class VectorStoreManager:
    """Process-wide singleton over the persistent Chroma 'rag_documents' store.

    On first construction it opens the Chroma client, loads the collection,
    and builds an in-memory BM25 index over every stored chunk.
    ``retrieve_documents`` then serves semantic-only or hybrid
    (semantic + BM25, fused with Reciprocal Rank Fusion) retrieval.
    """

    # Singleton state shared by every instantiation in this process.
    _instance = None
    _lock = threading.Lock()
    _initialized = False

    def __new__(cls):
        # Creation always happens under the class lock, so exactly one
        # instance exists per process.
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
            return cls._instance

    def __init__(self):
        # Run the expensive initialization only once; later constructions
        # reuse the already-initialized singleton.
        with self._lock:
            if not self._initialized:
                self._initialize()
                VectorStoreManager._initialized = True

    def _initialize(self):
        """Initialize vector store with single collection + BM25 index"""
        try:
            logger.info("Initializing vector store components...")

            self.client = None
            self.collection = None

            db_path = "output/chromadb"  # Match your pipeline path
            self.client = chromadb.PersistentClient(path=db_path)
            logger.info(f"ChromaDB client initialized at path: {db_path}")

            available_collections = [col.name for col in self.client.list_collections()]
            logger.info(f"Available collections: {available_collections}")

            try:
                self.collection = self.client.get_collection("rag_documents")
                collection_count = self.collection.count()
                logger.info(
                    f"Collection 'rag_documents' loaded with {collection_count} documents"
                )
            except Exception as e:
                # Fail fast with the list of collections that DO exist.
                logger.error(f"Collection 'rag_documents' not found: {str(e)}")
                raise ValueError(
                    "Required collection 'rag_documents' not found. "
                    f"Available: {available_collections}"
                )

            # ---- Build BM25 index from all stored docs ----
            # Pulls the whole collection into memory; ids come back
            # alongside the requested documents/metadatas.
            logger.info("Building BM25 index from Chroma collection...")
            data = self.collection.get(include=["documents", "metadatas"])

            # Parallel lists: index i describes one chunk in all three.
            self.all_ids: List[str] = data["ids"]
            self.all_docs: List[str] = data["documents"]
            self.all_metas: List[Dict[str, Any]] = data["metadatas"]

            self.tokenized_corpus = [self._tokenize(d) for d in self.all_docs]
            self.bm25 = BM25Okapi(self.tokenized_corpus)

            logger.info(f"BM25 index ready with {len(self.all_docs)} chunks")
            logger.info("Vector store initialized successfully")

        except Exception as e:
            logger.error(f"Failed to initialize vector store: {str(e)}")
            # Reset the flag so a later instantiation can retry.
            VectorStoreManager._initialized = False
            raise

    # ----------------- Helpers -----------------
    def _tokenize(self, text: str) -> List[str]:
        # Lowercased word tokens; safe on None/empty input.
        return re.findall(r"\w+", (text or "").lower())

    def _matches_filters(
        self,
        meta: Dict[str, Any],
        doc_text: str,
        where_filters: Optional[Dict[str, Any]],
        where_document: Optional[Dict[str, Any]],
    ) -> bool:
        """Return True when a chunk passes the metadata and text filters.

        Mirrors the subset of Chroma filter semantics this service uses on
        the BM25 side: exact-match metadata keys plus a case-insensitive
        ``$contains`` document filter.
        """
        if where_filters:
            for k, v in where_filters.items():
                if meta.get(k) != v:
                    return False

        if where_document:
            # you only use {"$contains": "..."}
            contains = where_document.get("$contains")
            if contains and contains.lower() not in (doc_text or "").lower():
                return False

        return True

    def _rrf_fuse(
        self,
        dense_ranked: List[Dict[str, Any]],
        sparse_ranked: List[Dict[str, Any]],
        k: int = 60,
        w_dense: float = 0.6,
        w_sparse: float = 0.4,
    ) -> List[Dict[str, Any]]:
        """
        Reciprocal Rank Fusion
        score = w_dense/(k+rank_dense) + w_sparse/(k+rank_sparse)
        """
        scores: Dict[str, Dict[str, Any]] = {}

        # setdefault keeps the first-seen item dict per id (dense wins on
        # overlap), while the RRF score accumulates over both rankings.
        for rank, item in enumerate(dense_ranked):
            doc_id = item["id"]
            scores.setdefault(doc_id, {"score": 0.0, "item": item})
            scores[doc_id]["score"] += w_dense / (k + rank + 1)

        for rank, item in enumerate(sparse_ranked):
            doc_id = item["id"]
            scores.setdefault(doc_id, {"score": 0.0, "item": item})
            scores[doc_id]["score"] += w_sparse / (k + rank + 1)

        fused = sorted(scores.values(), key=lambda x: x["score"], reverse=True)
        return [x["item"] for x in fused]

    # ----------------- Main retrieval -----------------
    def retrieve_documents(
        self,
        question: str,
        n_results: int = 5,
        where_filters: Optional[Dict[str, Any]] = None,
        where_document: Optional[Dict[str, Any]] = None,
        enable_bm25: bool = False,
        bm25_k: Optional[int] = None,
        alpha: float = 0.6,  # dense weight in hybrid fusion
    ) -> List[Dict[str, Any]]:
        """
        Retrieve documents using:
          - semantic-only (Chroma)
          - or hybrid semantic + BM25 (RRF fusion)

        Returns a list of dicts:
          {id, text, metadata, distance, bm25_score(optional)}

        Raises RuntimeError when the singleton failed to initialize.
        """
        if not self._initialized or self.collection is None:
            raise RuntimeError("VectorStoreManager not properly initialized")

        logger.info(f"Retrieving documents for query: {question[:50]}...")
        dense_k = n_results
        bm25_k = bm25_k or n_results

        # ----- Dense retrieval (semantic via Chroma) -----
        try:
            dense_res = self.collection.query(
                query_texts=[question],
                n_results=dense_k,
                include=["documents", "metadatas", "distances"],
                where=where_filters if where_filters else None,
                where_document=where_document if where_document else None,
            )
        except Exception as e:
            logger.error(f"Dense retrieval failed: {str(e)}")
            raise

        # Chroma returns column-parallel lists nested per query; unpack the
        # single query's columns into row dicts.
        dense_ranked: List[Dict[str, Any]] = []
        if dense_res and dense_res.get("documents") and dense_res["documents"][0]:
            for i in range(len(dense_res["documents"][0])):
                meta = dense_res["metadatas"][0][i]
                dense_ranked.append({
                    "id": dense_res["ids"][0][i],
                    "text": dense_res["documents"][0][i],
                    "metadata": meta,
                    "distance": float(dense_res["distances"][0][i]),
                    "source": meta.get("source", "Unknown"),
                })

        if not enable_bm25:
            logger.info(f"Semantic-only retrieved {len(dense_ranked)} docs")
            return dense_ranked

        # ----- Sparse retrieval (BM25) -----
        q_tokens = self._tokenize(question)
        scores = self.bm25.get_scores(q_tokens)

        # Apply same filters to BM25 corpus
        valid_indices = []
        for idx, (doc, meta) in enumerate(zip(self.all_docs, self.all_metas)):
            if self._matches_filters(meta, doc, where_filters, where_document):
                valid_indices.append(idx)

        # take top bm25_k from valid indices
        valid_scores = [(idx, scores[idx]) for idx in valid_indices]
        valid_scores.sort(key=lambda x: x[1], reverse=True)
        top_sparse = valid_scores[:bm25_k]

        sparse_ranked: List[Dict[str, Any]] = []
        for idx, s in top_sparse:
            meta = self.all_metas[idx]
            sparse_ranked.append({
                "id": self.all_ids[idx],
                "text": self.all_docs[idx],
                "metadata": meta,
                "bm25_score": float(s),
                "distance": None,  # may be absent if not in dense top-k
                "source": meta.get("source", "Unknown"),
            })

        # ----- Fuse dense + sparse -----
        fused = self._rrf_fuse(
            dense_ranked,
            sparse_ranked,
            w_dense=alpha,
            w_sparse=1.0 - alpha,
        )

        logger.info(
            f"Hybrid retrieved dense={len(dense_ranked)} sparse={len(sparse_ranked)} "
            f"fused={len(fused)}"
        )
        return fused
224
+
225
+
226
def get_vector_store() -> VectorStoreManager:
    """FastAPI dependency for injecting VectorStoreManager"""
    # Constructing the singleton returns the shared, already-initialized
    # instance after the first call.
    store = VectorStoreManager()
    return store