|
|
import hashlib
import json
import os
from typing import List, Dict, Optional

import chromadb
from chromadb.config import Settings
|
|
|
|
|
|
class VectorStore:
    """Persistent ChromaDB store holding one working collection.

    ``self.collection`` starts as ``None``. It is set by
    :meth:`get_or_create_collection`, which :meth:`add_documents` calls on
    demand. Until then, :meth:`search` and :meth:`get_collection_info` return
    empty results.
    """

    def __init__(self, persist_dir: str = "./chroma_db", embedding_function=None):
        """Create the on-disk directory and a persistent ChromaDB client.

        Args:
            persist_dir: Directory ChromaDB persists to; created if missing.
            embedding_function: Embedding function handed to ChromaDB when the
                collection is opened. When ``None``, nothing is forwarded and
                ChromaDB's own default applies.
        """
        self.persist_dir = persist_dir
        os.makedirs(persist_dir, exist_ok=True)

        self.client = chromadb.PersistentClient(
            path=persist_dir,
            settings=Settings(
                anonymized_telemetry=False,
                # Enables client.reset(), which deletes everything stored here.
                allow_reset=True,
            ),
        )

        self.embedding_function = embedding_function
        self.collection = None

    def _embedding_kwargs(self) -> Dict:
        """Return the ``embedding_function`` kwarg to forward, or nothing if unset."""
        # Passing embedding_function=None explicitly can override ChromaDB's
        # default embedder (version dependent). That would leave the collection
        # unable to embed `documents` / `query_texts`, so omit it instead.
        if self.embedding_function is None:
            return {}
        return {"embedding_function": self.embedding_function}

    def get_or_create_collection(self, collection_name: str = "pdf_documents"):
        """Open ``collection_name``, creating it with cosine distance if missing.

        The collection is stored on ``self.collection`` and also returned.
        The ``hnsw:space`` metadata only takes effect when the collection is
        created here; an existing collection keeps its original settings.
        """
        ef_kwargs = self._embedding_kwargs()
        try:
            collection = self.client.get_collection(name=collection_name, **ef_kwargs)
        except Exception:
            # ChromaDB reports "does not exist" with different exception types
            # across versions, hence Exception (not a bare except, which would
            # also swallow KeyboardInterrupt / SystemExit).
            self.collection = self.client.create_collection(
                name=collection_name,
                metadata={"hnsw:space": "cosine"},
                **ef_kwargs,
            )
            print(f"β Created new collection: {collection_name}")
        else:
            # Printed outside the try so a console encoding error cannot be
            # mistaken for a missing collection.
            # NOTE(review): "β" looks like a mis-decoded "✓"; kept unchanged.
            self.collection = collection
            print(f"β Loaded existing collection: {collection_name}")

        return self.collection

    @staticmethod
    def _default_id(document: str, metadata: Dict) -> str:
        """Derive a deterministic id from a chunk's text and metadata.

        The same (document, metadata) pair always yields the same id. Chunks
        that differ in text or metadata get different ids, independent of
        their position in a batch.
        """
        payload = json.dumps(
            {"document": document, "metadata": metadata},
            sort_keys=True,  # key order in the metadata dict must not matter
            default=str,  # the result is only hashed, never parsed back
        )
        return "doc_" + hashlib.sha256(payload.encode("utf-8")).hexdigest()

    def _existing_ids(self, ids: List[str]) -> set:
        """Return the subset of ``ids`` already stored in the collection."""
        # De-duplicate first; ChromaDB may reject repeated ids in get().
        unique_ids = list(dict.fromkeys(ids))
        if not unique_ids:
            return set()
        try:
            # include=[] returns ids only, without documents or metadata.
            found = self.collection.get(ids=unique_ids, include=[])
        except Exception as err:
            # Best effort, as before: if the lookup fails, attempt to add all.
            print(f"Warning: could not check for existing ids ({err}); adding all documents")
            return set()
        return set(found["ids"])

    def add_documents(self, documents: List[str], metadatas: List[Dict], ids: Optional[List[str]] = None):
        """Add documents to the collection, skipping ids that are already stored.

        Args:
            documents: Text of each chunk.
            metadatas: One metadata dict per document, in the same order.
            ids: Optional explicit ids, one per document. When omitted, each id
                is a hash of the document text plus its metadata. Separate calls
                therefore never collide, and re-adding an identical chunk is a
                no-op.

        Raises:
            ValueError: If ``metadatas`` or ``ids`` differ in length from
                ``documents``.

        Note:
            Earlier versions defaulted to positional ids (``doc_0``, ``doc_1``,
            ...) per call, so a later batch was silently skipped whenever those
            ids were already taken. Entries stored under the old positional ids
            are not recognized as duplicates of the new content-based ids.
        """
        documents = list(documents)
        metadatas = list(metadatas)
        if len(metadatas) != len(documents):
            raise ValueError(
                f"Got {len(documents)} documents but {len(metadatas)} metadatas"
            )
        if ids is None:
            ids = [self._default_id(doc, meta) for doc, meta in zip(documents, metadatas)]
        else:
            ids = list(ids)
            if len(ids) != len(documents):
                raise ValueError(f"Got {len(documents)} documents but {len(ids)} ids")

        if self.collection is None:
            self.get_or_create_collection()

        # Ids already stored, plus ids queued earlier in this batch: adding a
        # duplicate id in one add() call would make ChromaDB reject the batch.
        skip = self._existing_ids(ids)

        docs_to_add: List[str] = []
        meta_to_add: List[Dict] = []
        ids_to_add: List[str] = []
        for doc, meta, doc_id in zip(documents, metadatas, ids):
            if doc_id in skip:
                continue
            skip.add(doc_id)
            docs_to_add.append(doc)
            meta_to_add.append(meta)
            ids_to_add.append(doc_id)

        if docs_to_add:
            self.collection.add(
                documents=docs_to_add,
                metadatas=meta_to_add,
                ids=ids_to_add,
            )
            print(f"β Added {len(docs_to_add)} new documents to vector store")
        else:
            print("β All documents already in vector store")

    def search(self, query: str, n_results: int = 5) -> Dict:
        """Return the ``n_results`` stored chunks nearest to ``query``.

        The result is ChromaDB's ``query`` output, which holds one inner list
        per query text (here a single one). If no collection has been opened
        yet, an empty result with the same nested shape is returned, so
        callers can index ``[0]`` either way.
        """
        if self.collection is None:
            return {"ids": [[]], "documents": [[]], "metadatas": [[]], "distances": [[]]}

        return self.collection.query(query_texts=[query], n_results=n_results)

    def get_collection_info(self) -> Dict:
        """Return the collection's name and document count.

        Returns ``{}`` if no collection has been opened yet.
        """
        if self.collection is None:
            return {}

        return {
            "collection_name": self.collection.name,
            "document_count": self.collection.count(),
        }