# project2/src/vector_store.py
# Uploaded by dnj0 in commit 555c75a ("Upload 6 files").
import chromadb
from chromadb.config import Settings
import os
from typing import List, Dict, Optional
class VectorStore:
    """Wrapper around a persistent ChromaDB client and one bound collection.

    The collection is bound lazily by ``get_or_create_collection`` (called
    automatically by ``add_documents``); until then ``self.collection`` is
    ``None``.
    """

    def __init__(self, persist_dir: str = "./chroma_db", embedding_function=None):
        """Open (or create) the on-disk ChromaDB database.

        Args:
            persist_dir: Directory holding the ChromaDB files; created if missing.
            embedding_function: Passed to ChromaDB when the collection is opened
                or created. ``None`` presumably lets ChromaDB use its default
                embedder -- confirm against the installed ChromaDB version.
        """
        self.persist_dir = persist_dir
        os.makedirs(persist_dir, exist_ok=True)
        self.client = chromadb.PersistentClient(
            path=persist_dir,
            settings=Settings(
                anonymized_telemetry=False,
                allow_reset=True,
            ),
        )
        self.embedding_function = embedding_function
        # Bound lazily by get_or_create_collection().
        self.collection = None

    def get_or_create_collection(self, collection_name: str = "pdf_documents"):
        """Bind ``self.collection`` to ``collection_name``, creating it if needed.

        A newly created collection is configured with cosine distance
        (``hnsw:space``); an existing collection is opened as-is, without
        passing any metadata.

        Returns:
            The bound collection (also stored on ``self.collection``).
        """
        try:
            self.collection = self.client.get_collection(
                name=collection_name,
                embedding_function=self.embedding_function,
            )
        except Exception:
            # The "collection not found" exception type differs between ChromaDB
            # releases, so catch Exception -- but not a bare except, so that
            # KeyboardInterrupt / SystemExit still propagate.
            self.collection = self.client.create_collection(
                name=collection_name,
                embedding_function=self.embedding_function,
                metadata={"hnsw:space": "cosine"},
            )
            print(f"✓ Created new collection: {collection_name}")
        else:
            print(f"✓ Loaded existing collection: {collection_name}")
        return self.collection

    def add_documents(self, documents: List[str], metadatas: List[Dict], ids: Optional[List[str]] = None):
        """Add documents whose IDs are not already stored.

        Args:
            documents: Texts to store in the collection.
            metadatas: One metadata dict per document, aligned with ``documents``.
            ids: One ID per document. Defaults to ``doc_0``, ``doc_1``, ... by
                position. Re-adding the same list is then a no-op, but a
                *different* list added later without explicit IDs reuses the
                same positional IDs and is skipped as "already present". Pass
                stable, unique IDs when adding in several batches.

        Inputs are paired with ``zip``, so trailing items of longer lists are
        ignored. IDs already in the collection, or repeated earlier in this
        call, are skipped (first occurrence wins).
        """
        if self.collection is None:
            self.get_or_create_collection()

        if ids is None:
            ids = [f"doc_{i}" for i in range(len(documents))]

        # Fetch only the IDs (include=[]) instead of every stored document and
        # metadata, and keep them in a set for O(1) membership tests.
        try:
            seen_ids = set(self.collection.get(include=[])["ids"])
        except Exception as exc:
            # Best-effort: proceed without duplicate filtering rather than fail.
            print(f"! Could not read existing IDs, skipping duplicate check: {exc}")
            seen_ids = set()

        docs_to_add: List[str] = []
        meta_to_add: List[Dict] = []
        ids_to_add: List[str] = []
        for doc, meta, doc_id in zip(documents, metadatas, ids):
            if doc_id in seen_ids:
                continue
            # Record the ID so a repeat later in this same batch is skipped too;
            # ChromaDB rejects duplicate IDs within a single add() call.
            seen_ids.add(doc_id)
            docs_to_add.append(doc)
            meta_to_add.append(meta)
            ids_to_add.append(doc_id)

        if docs_to_add:
            self.collection.add(
                documents=docs_to_add,
                metadatas=meta_to_add,
                ids=ids_to_add,
            )
            print(f"✓ Added {len(docs_to_add)} new documents to vector store")
        else:
            print("✓ All documents already in vector store")

    @staticmethod
    def _empty_query_result() -> Dict:
        """Return a fresh, empty result shaped like a one-query ``query`` call."""
        # query() returns one inner list per query text; mirror that nesting so
        # callers can index results["documents"][0] unconditionally.
        return {"ids": [[]], "documents": [[]], "metadatas": [[]], "distances": [[]]}

    def search(self, query: str, n_results: int = 5) -> Dict:
        """Return up to ``n_results`` stored documents nearest to ``query``.

        The result is ChromaDB's query result for the single ``query`` text:
        each field holds one inner list of hits. When no collection is bound,
        or the bound collection is empty, the same nested shape is returned
        with empty inner lists (``ids``, ``documents``, ``metadatas``,
        ``distances``).
        """
        if self.collection is None:
            return self._empty_query_result()

        available = self.collection.count()
        if available == 0:
            return self._empty_query_result()

        # Never ask for more neighbours than the collection holds.
        return self.collection.query(
            query_texts=[query],
            n_results=min(n_results, available),
        )

    def get_collection_info(self) -> Dict:
        """Return the bound collection's name and document count.

        Returns an empty dict when no collection has been bound yet.
        """
        if self.collection is None:
            return {}
        return {
            "collection_name": self.collection.name,
            "document_count": self.collection.count(),
        }