# telecom-rag / src /vector_store.py
# ajaymauryabbn's picture
# feat: Telecom RAG System - Production Ready
# eb731f7
"""Telecom RAG - Vector Store Module
ChromaDB-based vector store for document storage and retrieval.
"""
import os
from pathlib import Path
from typing import List, Dict, Any, Optional
from .config import CHROMA_PERSIST_DIR, CHROMA_COLLECTION_NAME, TOP_K_RESULTS
from .data_loader import Document
from .embeddings import EmbeddingModel
class TelecomVectorStore:
    """ChromaDB vector store for telecom documents.

    Wraps a persistent ChromaDB collection (configured for cosine distance)
    together with the project's ``EmbeddingModel``. Documents are embedded
    on insertion and queries are embedded before the nearest-neighbour
    lookup.
    """

    # Collection-level settings passed to ChromaDB; selects cosine distance.
    _COLLECTION_METADATA = {"hnsw:space": "cosine"}

    def __init__(self, collection_name: Optional[str] = None):
        """Create the store.

        Args:
            collection_name: Name of the ChromaDB collection. Falls back to
                ``CHROMA_COLLECTION_NAME`` when omitted or empty.
        """
        self.collection_name = collection_name or CHROMA_COLLECTION_NAME
        self.embedding_model = EmbeddingModel()
        self._initialize_store()

    def _get_or_create_collection(self):
        """Return the collection named ``self.collection_name``, creating it if needed."""
        return self.client.get_or_create_collection(
            name=self.collection_name,
            # Copy so the shared class-level dict is never handed out.
            metadata=dict(self._COLLECTION_METADATA),
        )

    def _initialize_store(self) -> None:
        """Initialize the persistent ChromaDB client and the collection.

        Raises:
            ImportError: If ``chromadb`` cannot be imported. The original
                import failure is chained as the cause.
        """
        # Only the imports are guarded, so an ImportError raised later (for
        # example while creating the client) is not misreported as
        # "chromadb not installed".
        try:
            import chromadb
            from chromadb.config import Settings
        except ImportError as err:
            raise ImportError(
                "chromadb not installed. Run: pip install chromadb"
            ) from err

        # Ensure the persist directory exists. Path() is a no-op for Path
        # values and additionally tolerates a plain string from the config.
        persist_dir = Path(CHROMA_PERSIST_DIR)
        persist_dir.mkdir(parents=True, exist_ok=True)

        # Initialize client with persistence
        self.client = chromadb.PersistentClient(
            path=str(persist_dir),
            settings=Settings(anonymized_telemetry=False),
        )
        self.collection = self._get_or_create_collection()

        print(f"✅ ChromaDB initialized: {self.collection_name}")
        print(f" Documents in collection: {self.collection.count()}")

    def _embed_individually(self, contents: List[str]):
        """Embed texts one at a time, skipping the ones that fail.

        Args:
            contents: Texts to embed.

        Returns:
            A ``(kept_indices, embeddings)`` pair. ``kept_indices[k]`` is the
            position in ``contents`` that ``embeddings[k]`` was computed
            from, so callers can keep texts and vectors aligned.
        """
        kept_indices = []
        embeddings = []
        for index, content in enumerate(contents):
            try:
                embedding = self.embedding_model.embed([content])[0]
            except Exception as inner_e:
                print(f" Skipping document due to error: {inner_e}")
                continue
            kept_indices.append(index)
            embeddings.append(embedding)
        return kept_indices, embeddings

    def add_documents(self, documents: List[Document], batch_size: int = 100):
        """Add documents to the vector store.

        Documents are embedded and inserted in batches. If embedding a whole
        batch fails, each document of that batch is retried on its own and
        only the ones that embed successfully are stored.

        Args:
            documents: Documents to add; each provides ``content`` and
                ``metadata``.
            batch_size: Maximum number of documents embedded and inserted
                per batch. Must be at least 1.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if not documents:
            print("⚠️ No documents to add")
            return
        if batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size}"
            )

        print(f"📥 Adding {len(documents)} documents to vector store...")

        # Continue numbering after what is already stored so that a second
        # call does not reuse "doc_0", "doc_1", ... IDs are handed out only
        # to documents that are actually stored, which keeps them dense.
        # NOTE(review): assumes entries are never deleted individually from
        # this collection; confirm if deletion is added elsewhere.
        next_id = self.collection.count()
        total_batches = (len(documents) - 1) // batch_size + 1

        for batch_num, start in enumerate(
            range(0, len(documents), batch_size), 1
        ):
            batch = documents[start:start + batch_size]
            contents = [doc.content for doc in batch]
            metadatas = [doc.metadata for doc in batch]

            try:
                embeddings = self.embedding_model.embed(contents)
            except Exception as e:
                print(f"⚠️ Error generating embeddings for batch {batch_num}: {e}")
                print(" Retrying with smaller batch...")
                kept, embeddings = self._embed_individually(contents)
                if not embeddings:
                    print(f" Skipping entire batch {batch_num}")
                    continue
                # Drop exactly the documents that failed so every remaining
                # text/metadata stays paired with its own embedding.
                contents = [contents[k] for k in kept]
                metadatas = [metadatas[k] for k in kept]

            ids = [f"doc_{next_id + k}" for k in range(len(contents))]
            next_id += len(contents)

            # Add to collection
            self.collection.add(
                ids=ids,
                embeddings=embeddings,
                documents=contents,
                metadatas=metadatas,
            )
            print(f" Added batch {batch_num}/{total_batches}")

        print(f"✅ Vector store now contains {self.collection.count()} documents")

    def search(
        self,
        query: str,
        top_k: int = TOP_K_RESULTS,
        filter_dict: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """Search for similar documents.

        Args:
            query: Search query.
            top_k: Maximum number of results to return.
            filter_dict: Optional metadata filter, passed to the collection
                query as ``where``.

        Returns:
            A list of dicts with the keys ``content``, ``metadata``,
            ``distance`` and ``similarity`` (computed as ``1 - distance``).
            The list is empty when the collection holds no documents.
        """
        available = self.collection.count()
        if available == 0:
            # Nothing to match against; skip the embedding call entirely.
            return []

        query_embedding = self.embedding_model.embed_query(query)

        query_params = {
            "query_embeddings": [query_embedding],
            # Never ask for more neighbours than the collection holds.
            "n_results": min(top_k, available),
            "include": ["documents", "metadatas", "distances"],
        }
        if filter_dict:
            query_params["where"] = filter_dict

        results = self.collection.query(**query_params)

        # Results are nested one list per query embedding; only one query
        # embedding is sent, so only index 0 is read.
        documents = results["documents"][0] if results["documents"] else []
        metadatas = results["metadatas"][0] if results["metadatas"] else None
        distances = results["distances"][0] if results["distances"] else None

        formatted_results = []
        for idx, content in enumerate(documents):
            metadata = metadatas[idx] if metadatas is not None else None
            distance = distances[idx] if distances is not None else 0
            formatted_results.append({
                "content": content,
                # A missing (None) metadata entry is reported as {}.
                "metadata": metadata or {},
                "distance": distance,
                "similarity": 1 - distance,
            })
        return formatted_results

    def search_by_category(
        self,
        query: str,
        category: str,
        top_k: int = TOP_K_RESULTS
    ) -> List[Dict[str, Any]]:
        """Search within a specific category.

        Equivalent to ``search`` with the metadata filter
        ``{"category": category}``.
        """
        return self.search(query, top_k, filter_dict={"category": category})

    def get_stats(self) -> Dict[str, Any]:
        """Get vector store statistics.

        Returns:
            A dict with ``total_documents``, ``collection_name`` and
            ``categories_sample``, a category -> count mapping computed from
            a sample of at most 100 entries. The mapping is empty if the
            collection is empty or the sample could not be read.
        """
        count = self.collection.count()

        categories: Dict[str, int] = {}
        if count > 0:
            # Best effort: the category breakdown is informational, so a
            # failure here must not break the stats call.
            try:
                sample = self.collection.peek(min(100, count))
                for meta in sample["metadatas"] or []:
                    # Entries without metadata count as "unknown".
                    cat = (meta or {}).get("category", "unknown")
                    categories[cat] = categories.get(cat, 0) + 1
            except Exception as e:
                print(f"⚠️ Could not sample categories: {e}")
                categories = {}

        return {
            "total_documents": count,
            "collection_name": self.collection_name,
            "categories_sample": categories,
        }

    def clear(self) -> None:
        """Clear all documents from the collection.

        The collection is deleted and recreated with the same name and
        settings.
        """
        self.client.delete_collection(self.collection_name)
        self.collection = self._get_or_create_collection()
        print(f"🗑️ Cleared collection: {self.collection_name}")
if __name__ == "__main__":
    # Manual smoke test: ingest two sample passages, run one query and
    # print the collection statistics.
    sample_passages = [
        (
            "HARQ (Hybrid Automatic Repeat Request) is a combination of high-rate forward error correction and ARQ error-control.",
            "5g_terminology",
        ),
        (
            "5G NR supports both FDD and TDD modes of operation for flexible spectrum usage.",
            "5g_specifications",
        ),
    ]

    store = TelecomVectorStore()
    store.add_documents([
        Document(
            content=text,
            metadata={"source": "test", "category": category},
        )
        for text, category in sample_passages
    ])

    # Query the freshly ingested passages.
    hits = store.search("What is HARQ?")
    print("\n🔍 Search results for 'What is HARQ?':")
    for hit in hits:
        snippet = hit["content"][:100]
        score = hit["similarity"]
        print(f" - {snippet}... (similarity: {score:.3f})")

    print("\n📊 Stats:", store.get_stats())