File size: 6,207 Bytes
3998131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
"""
Vector database module using ChromaDB
Stores and retrieves document chunks with embeddings
"""

import logging
from pathlib import Path
from typing import List, Dict, Any, Optional

try:
    import chromadb
    from chromadb.config import Settings
    CHROMADB_AVAILABLE = True
except ImportError:
    CHROMADB_AVAILABLE = False

from .config import VECTOR_DB_DIR, DEFAULT_RETRIEVAL_K

logger = logging.getLogger(__name__)


class LegalVectorDB:
    """ChromaDB vector database for legal documents.

    Wraps one persistent ChromaDB collection (``nepal_legal_docs``) stored
    under ``persist_directory`` and exposes helpers to add pre-embedded
    chunks, run similarity queries, inspect, and delete the collection.
    """

    def __init__(self, persist_directory: Path = VECTOR_DB_DIR):
        """
        Initialize ChromaDB with persistent storage

        Args:
            persist_directory: Directory to store the database (created if
                missing).

        Raises:
            ImportError: If the ``chromadb`` package is not installed.
        """
        if not CHROMADB_AVAILABLE:
            raise ImportError(
                "chromadb not installed. "
                "Install with: pip install chromadb"
            )

        self.persist_directory = Path(persist_directory)
        self.persist_directory.mkdir(parents=True, exist_ok=True)

        logger.info("Initializing ChromaDB at %s", self.persist_directory)

        # Initialize ChromaDB client with persistent storage
        self.client = chromadb.PersistentClient(
            path=str(self.persist_directory)
        )

        # Create or get collection
        self.collection_name = "nepal_legal_docs"
        self.collection = self._get_or_create_collection()

        logger.info(
            "Collection '%s' ready. Current document count: %s",
            self.collection_name,
            self.collection.count(),
        )

    def _get_or_create_collection(self):
        """Return the collection named ``self.collection_name``, creating it if absent."""
        # No embedding_function is passed, so ChromaDB uses its default one for
        # any text-based queries against this collection.
        return self.client.get_or_create_collection(
            name=self.collection_name,
            metadata={"description": "Nepal legal documents for RAG-based law explanation"}
        )

    @staticmethod
    def _clean_metadata(metadata: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        """Coerce one chunk's metadata into values ChromaDB accepts.

        ChromaDB metadata values must be ``str``, ``int``, ``float`` or ``bool``:

        - ``None`` values are dropped.
        - Scalars of an accepted type are kept as-is.
        - Non-empty lists become comma-separated strings; empty lists are dropped.
        - Anything else is converted with ``str()``.

        Args:
            metadata: The raw metadata mapping, or ``None``.

        Returns:
            The cleaned dict, or ``None`` when nothing survives cleaning.
            Some ChromaDB versions reject an empty metadata dict but accept
            ``None``.
        """
        cleaned: Dict[str, Any] = {}
        for key, value in (metadata or {}).items():
            if value is None:
                continue
            if isinstance(value, (str, int, float, bool)):
                cleaned[key] = value
            elif isinstance(value, list):
                if value:  # Only include non-empty lists
                    cleaned[key] = ', '.join(str(item) for item in value)
            else:
                cleaned[key] = str(value)
        return cleaned or None

    def add_chunks(
        self,
        chunks: List[Dict[str, Any]],
        embeddings: List[List[float]],
        batch_size: Optional[int] = None,
    ) -> None:
        """
        Add chunks with embeddings to the database

        Args:
            chunks: List of chunk dictionaries with 'chunk_id', 'text' and an
                optional 'metadata' mapping.
            embeddings: List of embedding vectors (as lists), aligned
                index-for-index with ``chunks``.
            batch_size: Maximum number of chunks sent per ``collection.add``
                call. ChromaDB enforces a per-call limit, so pass a value for
                large inserts. ``None`` (default) sends everything in one call.

        Raises:
            ValueError: If ``chunks`` and ``embeddings`` differ in length, or
                ``batch_size`` is given but is not positive.
        """
        if len(chunks) != len(embeddings):
            raise ValueError(f"Number of chunks ({len(chunks)}) must match number of embeddings ({len(embeddings)})")
        if batch_size is not None and batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        if not chunks:
            # ChromaDB rejects empty ID lists, so there is nothing to send.
            logger.info("No chunks to add; skipping")
            return

        # Extract data from chunks
        ids = [chunk['chunk_id'] for chunk in chunks]
        documents = [chunk['text'] for chunk in chunks]
        metadatas = [self._clean_metadata(chunk.get('metadata')) for chunk in chunks]

        logger.info("Adding %s chunks to vector database", len(chunks))

        # One call when batch_size is None; otherwise consecutive slices.
        step = batch_size or len(chunks)
        for start in range(0, len(chunks), step):
            end = start + step
            self.collection.add(
                ids=ids[start:end],
                documents=documents[start:end],
                embeddings=embeddings[start:end],
                metadatas=metadatas[start:end],
            )

        logger.info(
            "Successfully added chunks. Total documents in database: %s",
            self.collection.count(),
        )

    def query(
        self,
        query_text: str,
        n_results: int = DEFAULT_RETRIEVAL_K,
        where: Optional[Dict] = None
    ) -> Dict[str, Any]:
        """
        Query the database with a text query

        NOTE(review): ``query_texts`` is embedded by the collection's embedding
        function. The collection is created without an explicit one, so that
        is ChromaDB's default model. If the vectors stored via ``add_chunks``
        came from a different model, the results are not comparable and may
        fail on a dimension mismatch. In that case use
        ``query_with_embedding`` with the same model.

        Args:
            query_text: Query string
            n_results: Number of results to return
            where: Optional metadata filter; an empty dict means "no filter"

        Returns:
            Dictionary with 'ids', 'documents', 'metadatas', and 'distances'
        """
        logger.info("Querying database with: '%s...' (n_results=%s)", query_text[:50], n_results)

        return self.collection.query(
            query_texts=[query_text],
            n_results=n_results,
            where=where or None  # some ChromaDB versions reject an empty where
        )

    def query_with_embedding(
        self,
        query_embedding: List[float],
        n_results: int = DEFAULT_RETRIEVAL_K,
        where: Optional[Dict] = None
    ) -> Dict[str, Any]:
        """
        Query with pre-computed embedding

        Args:
            query_embedding: Query embedding vector (must come from the same
                model that produced the stored embeddings)
            n_results: Number of results to return
            where: Optional metadata filter; an empty dict means "no filter"

        Returns:
            Dictionary with 'ids', 'documents', 'metadatas', and 'distances'
        """
        logger.info("Querying database with embedding (n_results=%s)", n_results)

        return self.collection.query(
            query_embeddings=[query_embedding],
            n_results=n_results,
            where=where or None  # some ChromaDB versions reject an empty where
        )

    def get_count(self) -> int:
        """Get the number of documents in the database"""
        return self.collection.count()

    def delete_collection(self, recreate: bool = False) -> None:
        """Delete the entire collection (use with caution!)

        Args:
            recreate: If True, create a fresh, empty collection with the same
                name right away so this instance stays usable. If False
                (default), ``self.collection`` still refers to the deleted
                collection. Further calls on this instance will likely fail
                until a new ``LegalVectorDB`` is constructed.
        """
        logger.warning("Deleting collection '%s'", self.collection_name)
        self.client.delete_collection(name=self.collection_name)
        logger.info("Collection deleted")

        if recreate:
            self.collection = self._get_or_create_collection()
            logger.info("Collection '%s' recreated (empty)", self.collection_name)

    def peek(self, limit: int = 5) -> Dict[str, Any]:
        """
        Peek at some documents in the database

        Args:
            limit: Number of documents to return

        Returns:
            Dictionary with sample documents
        """
        return self.collection.peek(limit=limit)