File size: 3,926 Bytes
835ecb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import chromadb
from chromadb.config import Settings
from langchain_chroma import Chroma
from langchain_core.documents import Document
from typing import List, Dict, Optional
import os

class VectorStoreManager:
    """Manage a persistent ChromaDB collection through the LangChain wrapper.

    The manager owns a ``chromadb.PersistentClient`` rooted at ``persist_dir``
    and a LangChain ``Chroma`` wrapper bound to ``collection_name``; document
    ingestion and search go through that wrapper.
    """

    def __init__(self,
                 persist_dir: str = "./chroma_db",
                 collection_name: str = "pdf_documents",
                 embeddings=None):
        """Create (or reopen) the persistent vector store.

        Args:
            persist_dir: Directory for persistent storage; created if missing.
            collection_name: Name of the Chroma collection to use.
            embeddings: LangChain embeddings instance, passed to the wrapper
                as its ``embedding_function``.
        """
        self.persist_dir = persist_dir
        self.collection_name = collection_name
        self.embeddings = embeddings

        os.makedirs(persist_dir, exist_ok=True)

        # Persistent client so the collection is stored under persist_dir.
        self.client = chromadb.PersistentClient(path=persist_dir)

        # LangChain wrapper used for all add/search/retriever operations.
        self.vector_store = self._build_vector_store()

        print(f"Vector store initialized: {persist_dir}/{collection_name}")

    def _build_vector_store(self) -> Chroma:
        """Return a LangChain Chroma wrapper bound to this manager's collection."""
        return Chroma(
            client=self.client,
            collection_name=self.collection_name,
            embedding_function=self.embeddings,
            persist_directory=self.persist_dir,
        )

    def add_documents(self, documents: List[Document], batch_size: int = 50) -> int:
        """Add documents to the vector store in batches.

        Ingestion is best-effort: a batch that raises is reported and skipped,
        and the remaining batches are still attempted.

        Args:
            documents: LangChain Document objects to add.
            batch_size: Number of documents per ``add_documents`` call; must
                be positive.

        Returns:
            Number of documents belonging to batches that were added without
            raising.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            # range() would otherwise crash on 0 or silently do nothing on < 0.
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        added = 0
        starts = range(0, len(documents), batch_size)
        for batch_number, start in enumerate(starts, start=1):
            batch = documents[start:start + batch_size]
            try:
                self.vector_store.add_documents(batch)
            except Exception as e:
                # Report the failing batch so the caller knows what was lost.
                print(f"Error adding documents (batch {batch_number}): {e}")
                continue
            added += len(batch)
            print(f"Added {len(batch)} documents (batch {batch_number})")
        return added

    def search(self, query: str, k: int = 5) -> List[Dict]:
        """Search for documents similar to ``query``.

        Args:
            query: Search query text.
            k: Maximum number of results to return.

        Returns:
            One dict per hit with keys ``content``, ``metadata``,
            ``similarity`` and ``distance``. The score returned by
            ``similarity_search_with_score`` is a distance (lower means more
            similar); it is kept under the legacy ``similarity`` key for
            existing callers and also exposed under the accurate ``distance``
            key.
        """
        results = self.vector_store.similarity_search_with_score(query, k=k)
        return [
            {
                "content": doc.page_content,
                "metadata": doc.metadata,
                "similarity": score,  # legacy name; value is a distance
                "distance": score,
            }
            for doc, score in results
        ]

    def get_retriever(self, search_kwargs: Optional[Dict] = None):
        """Return a retriever over the vector store for use in a RAG chain.

        Args:
            search_kwargs: Keyword arguments forwarded to the retriever;
                defaults to ``{"k": 5}``.
        """
        if search_kwargs is None:
            search_kwargs = {"k": 5}
        return self.vector_store.as_retriever(search_kwargs=search_kwargs)

    def collection_count(self) -> int:
        """Return the number of entries in the collection, or 0 on error."""
        try:
            collection = self.client.get_collection(self.collection_name)
            return collection.count()
        except Exception as e:
            print(f"Error getting collection count: {e}")
            return 0

    def clear_collection(self) -> None:
        """Delete the collection and rebind the wrapper to a fresh one.

        Errors are reported and swallowed (best-effort).
        """
        try:
            self.client.delete_collection(self.collection_name)
            # The old wrapper points at the deleted collection; rebuild it.
            self.vector_store = self._build_vector_store()
            print(f"Collection cleared: {self.collection_name}")
        except Exception as e:
            print(f"Error clearing collection: {e}")