File size: 15,623 Bytes
9222df3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
#!/usr/bin/env python3
"""
Vector Database Builder for Scikit-learn Documentation

This module creates a vector database from chunked Scikit-learn documentation
using ChromaDB and Sentence-Transformers for efficient semantic search.

Author: AI Assistant
Date: September 2025
"""

import json
import logging
import os
import time
from pathlib import Path
from typing import Dict, List, Any, Optional
from uuid import uuid4

import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
import numpy as np


# Configure module-wide logging: INFO level with timestamped messages.
# Note: basicConfig is a no-op if the root logger is already configured
# by an importing application, so this only takes effect for script use.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
# Module-level logger shared by everything in this file.
logger = logging.getLogger(__name__)


class VectorDatabaseBuilder:
    """
    Build a vector database from chunked documentation.

    Creates sentence-transformer embeddings for text chunks loaded from a
    JSON file and stores them in a persistent ChromaDB collection so they
    can later be retrieved via semantic search.
    """
    
    def __init__(
        self,
        model_name: str = 'all-MiniLM-L6-v2',
        db_path: str = './chroma_db',
        collection_name: str = 'sklearn_docs'
    ):
        """
        Initialize the VectorDatabaseBuilder.
        
        Args:
            model_name (str): Name of the sentence transformer model
            db_path (str): Path to store the ChromaDB database
            collection_name (str): Name of the collection in ChromaDB
        """
        self.model_name = model_name
        self.db_path = Path(db_path)
        self.collection_name = collection_name
        
        # Heavy components are created lazily (see build_database) so that
        # constructing the builder stays cheap and free of I/O side effects.
        self.embedding_model = None
        self.chroma_client = None
        self.collection = None
        
        logger.info("Initialized VectorDatabaseBuilder:")
        logger.info(f"  - Model: {model_name}")
        logger.info(f"  - Database path: {db_path}")
        logger.info(f"  - Collection: {collection_name}")
    
    def load_embedding_model(self) -> None:
        """
        Load the sentence transformer model for creating embeddings.

        Raises:
            Exception: Propagates any model-loading failure after logging it.
        """
        logger.info(f"Loading embedding model: {self.model_name}")
        
        try:
            self.embedding_model = SentenceTransformer(self.model_name)
            
            # Encode a sample sentence to verify the model works and to
            # report the embedding dimensionality up front.
            test_embedding = self.embedding_model.encode("test sentence")
            embedding_dim = len(test_embedding)
            
            logger.info("Model loaded successfully!")
            logger.info(f"  - Embedding dimension: {embedding_dim}")
            logger.info(f"  - Model device: {self.embedding_model.device}")
            
        except Exception as e:
            logger.error(f"Failed to load embedding model: {e}")
            raise
    
    def initialize_chroma_client(self) -> None:
        """
        Initialize the persistent ChromaDB client.

        Creates the database directory if needed and opens a
        PersistentClient rooted there (telemetry disabled).

        Raises:
            Exception: Propagates any client-initialization failure after
                logging it.
        """
        logger.info("Initializing ChromaDB client...")
        
        try:
            # Create database directory if it doesn't exist
            self.db_path.mkdir(parents=True, exist_ok=True)
            
            # Initialize ChromaDB client with persistent storage
            self.chroma_client = chromadb.PersistentClient(
                path=str(self.db_path),
                settings=Settings(
                    anonymized_telemetry=False,
                    allow_reset=True
                )
            )
            
            logger.info(f"ChromaDB client initialized at: {self.db_path}")
            
        except Exception as e:
            logger.error(f"Failed to initialize ChromaDB client: {e}")
            raise
    
    def create_collection(self, reset: bool = False) -> None:
        """
        Create or get the ChromaDB collection.
        
        Args:
            reset (bool): Whether to reset/recreate the collection if it exists

        Raises:
            Exception: Propagates any collection-creation failure after
                logging it.
        """
        logger.info(f"Creating/getting collection: {self.collection_name}")
        
        try:
            if reset:
                # Best-effort delete; a missing collection is not an error.
                try:
                    self.chroma_client.delete_collection(name=self.collection_name)
                    logger.info(f"Deleted existing collection: {self.collection_name}")
                except Exception:
                    # Collection doesn't exist, which is fine
                    pass
            
            # Create or get collection
            self.collection = self.chroma_client.get_or_create_collection(
                name=self.collection_name,
                metadata={"description": "Scikit-learn documentation embeddings"}
            )
            
            # Get collection info
            collection_count = self.collection.count()
            logger.info(f"Collection '{self.collection_name}' ready")
            logger.info(f"  - Current document count: {collection_count}")
            
        except Exception as e:
            logger.error(f"Failed to create collection: {e}")
            raise
    
    def load_chunks(self, chunks_file: str) -> List[Dict[str, Any]]:
        """
        Load text chunks from JSON file.
        
        Args:
            chunks_file (str): Path to the chunks JSON file
            
        Returns:
            List[Dict[str, Any]]: List of chunks with content and metadata

        Raises:
            FileNotFoundError: If chunks_file does not exist.
            ValueError: If the first chunk lacks the required keys.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        chunks_path = Path(chunks_file)
        
        if not chunks_path.exists():
            raise FileNotFoundError(f"Chunks file not found: {chunks_path}")
        
        logger.info(f"Loading chunks from: {chunks_path}")
        
        try:
            with open(chunks_path, 'r', encoding='utf-8') as f:
                chunks = json.load(f)
            
            logger.info(f"Loaded {len(chunks)} chunks")
            
            # Spot-check the first chunk's structure; a uniform file format
            # is assumed, so checking every element would be redundant.
            if chunks and isinstance(chunks[0], dict):
                required_keys = {'page_content', 'metadata'}
                if not required_keys.issubset(chunks[0].keys()):
                    raise ValueError("Invalid chunk structure. Missing required keys.")
            
            return chunks
            
        except json.JSONDecodeError as e:
            logger.error(f"Invalid JSON in chunks file: {e}")
            raise
        except Exception as e:
            logger.error(f"Error loading chunks: {e}")
            raise
    
    def create_embeddings_batch(
        self, 
        texts: List[str], 
        batch_size: int = 32
    ) -> List[List[float]]:
        """
        Create embeddings for a batch of texts.
        
        Args:
            texts (List[str]): List of texts to embed
            batch_size (int): Batch size for processing
            
        Returns:
            List[List[float]]: One embedding (list of floats) per input text,
                in input order.

        Raises:
            Exception: Propagates any embedding failure after logging it.
        """
        logger.info(f"Creating embeddings for {len(texts)} texts...")
        
        try:
            # Process in batches to manage memory
            all_embeddings = []
            
            for i in range(0, len(texts), batch_size):
                batch_texts = texts[i:i + batch_size]
                batch_embeddings = self.embedding_model.encode(
                    batch_texts,
                    show_progress_bar=False,
                    convert_to_numpy=True
                )
                
                # Convert to list of lists
                all_embeddings.extend([emb.tolist() for emb in batch_embeddings])
                
                processed = min(i + batch_size, len(texts))
                # Log roughly every 100 texts and always at completion.
                # BUGFIX: the old test `(i + batch_size) % 100 == 0` only
                # fired when i + batch_size was an exact multiple of 100,
                # which never happens for batch_size=32, so no intermediate
                # progress was ever reported. Detect crossing a 100-boundary
                # instead.
                if processed == len(texts) or i // 100 != processed // 100:
                    logger.info(f"  - Processed {processed}/{len(texts)} embeddings")
            
            logger.info("Embedding creation completed!")
            return all_embeddings
            
        except Exception as e:
            logger.error(f"Error creating embeddings: {e}")
            raise
    
    def add_documents_to_collection(
        self,
        chunks: List[Dict[str, Any]],
        batch_size: int = 100,
        embedding_batch_size: int = 32
    ) -> None:
        """
        Add documents to the ChromaDB collection.
        
        Args:
            chunks (List[Dict[str, Any]]): List of document chunks
            batch_size (int): Batch size for adding to database
            embedding_batch_size (int): Batch size for the embedding model
                (previously a hard-coded 32; default keeps old behavior)

        Raises:
            Exception: Propagates embedding or insertion failures after
                logging them.
        """
        logger.info(f"Adding {len(chunks)} documents to collection...")
        
        try:
            # Extract texts and metadata
            texts = [chunk['page_content'] for chunk in chunks]
            metadatas = []
            
            for chunk in chunks:
                # Prepare metadata - ChromaDB requires string values
                metadata = {
                    'url': chunk['metadata']['url'],
                    'chunk_index': str(chunk['metadata']['chunk_index']),
                    'source': chunk['metadata'].get('source', 'scikit-learn-docs'),
                    'content_length': str(len(chunk['page_content']))
                }
                metadatas.append(metadata)
            
            # Create embeddings
            embeddings = self.create_embeddings_batch(texts, batch_size=embedding_batch_size)
            
            # Generate unique IDs; uuid4 avoids collisions with any existing
            # documents when the collection is not reset between runs.
            ids = [str(uuid4()) for _ in range(len(chunks))]
            
            # Add to collection in batches
            for i in range(0, len(chunks), batch_size):
                end_idx = min(i + batch_size, len(chunks))
                
                batch_ids = ids[i:end_idx]
                batch_documents = texts[i:end_idx]
                batch_metadatas = metadatas[i:end_idx]
                batch_embeddings = embeddings[i:end_idx]
                
                self.collection.add(
                    ids=batch_ids,
                    documents=batch_documents,
                    metadatas=batch_metadatas,
                    embeddings=batch_embeddings
                )
                
                logger.info(f"  - Added batch {i//batch_size + 1}: documents {i+1}-{end_idx}")
            
            # Verify the addition
            final_count = self.collection.count()
            logger.info("Successfully added documents to collection!")
            logger.info(f"  - Total documents in collection: {final_count}")
            
        except Exception as e:
            logger.error(f"Error adding documents to collection: {e}")
            raise
    
    def get_database_stats(self) -> Dict[str, Any]:
        """
        Get statistics about the vector database.
        
        Returns:
            Dict[str, Any]: Database statistics, or an empty dict if they
                could not be collected (best-effort; never raises).
        """
        try:
            stats = {
                'collection_name': self.collection_name,
                'total_documents': self.collection.count(),
                'database_path': str(self.db_path),
                'embedding_model': self.model_name,
                'database_size_mb': self._get_directory_size(self.db_path) / (1024 * 1024)
            }
            
            return stats
            
        except Exception as e:
            logger.error(f"Error getting database stats: {e}")
            return {}
    
    def _get_directory_size(self, path: Path) -> int:
        """
        Calculate the total size of a directory (recursive, files only).
        
        Args:
            path (Path): Directory path
            
        Returns:
            int: Size in bytes; 0 (or a partial total) if traversal fails.
        """
        total_size = 0
        try:
            for item in path.rglob('*'):
                if item.is_file():
                    total_size += item.stat().st_size
        except (OSError, PermissionError):
            # Best-effort: unreadable entries simply aren't counted.
            pass
        return total_size
    
    def build_database(
        self,
        chunks_file: str = 'chunks.json',
        reset_collection: bool = True
    ) -> Dict[str, Any]:
        """
        Complete pipeline to build the vector database.

        Loads the model, opens ChromaDB, (re)creates the collection, loads
        the chunks, embeds them, and stores everything.
        
        Args:
            chunks_file (str): Path to chunks JSON file
            reset_collection (bool): Whether to reset existing collection
            
        Returns:
            Dict[str, Any]: Build statistics (see get_database_stats) plus
                'build_time_seconds' and 'documents_per_second'.

        Raises:
            Exception: Propagates any pipeline-stage failure after logging it.
        """
        logger.info("Starting vector database build pipeline...")
        start_time = time.time()
        
        try:
            # Load embedding model
            self.load_embedding_model()
            
            # Initialize ChromaDB
            self.initialize_chroma_client()
            
            # Create collection
            self.create_collection(reset=reset_collection)
            
            # Load chunks
            chunks = self.load_chunks(chunks_file)
            
            # Add documents to collection
            self.add_documents_to_collection(chunks)
            
            # Get final statistics; guard against a (theoretical) zero
            # elapsed time to avoid ZeroDivisionError on the rate metric.
            build_time = time.time() - start_time
            stats = self.get_database_stats()
            stats['build_time_seconds'] = round(build_time, 2)
            stats['documents_per_second'] = round(len(chunks) / max(build_time, 1e-9), 2)
            
            logger.info("Vector database build completed successfully!")
            return stats
            
        except Exception as e:
            logger.error(f"Database build failed: {e}")
            raise
    
    def test_search(self, query: str = "linear regression", n_results: int = 3) -> None:
        """
        Test the vector database with a sample search.
        
        Args:
            query (str): Test query
            n_results (int): Number of results to return

        Note:
            Failures are logged, not raised — this is a smoke test only.
        """
        logger.info(f"Testing database with query: '{query}'")
        
        try:
            results = self.collection.query(
                query_texts=[query],
                n_results=n_results
            )
            
            logger.info(f"Search test successful! Found {len(results['documents'][0])} results:")
            
            for i, (doc, metadata) in enumerate(zip(results['documents'][0], results['metadatas'][0])):
                logger.info(f"  Result {i+1}:")
                # Show only the last URL path segment to keep logs compact.
                logger.info(f"    - URL: {metadata['url'].split('/')[-1]}")
                logger.info(f"    - Preview: {doc[:100]}...")
                
        except Exception as e:
            logger.error(f"Search test failed: {e}")


def main():
    """
    Build the scikit-learn documentation vector database and report stats.

    Returns:
        int: Process exit code — 0 on success, 1 on any failure.
    """
    print("Scikit-learn Documentation Vector Database Builder")
    print("=" * 60)
    
    # Configuration
    chunks_file = "chunks.json"
    
    # Initialize builder
    builder = VectorDatabaseBuilder(
        model_name='all-MiniLM-L6-v2',
        db_path='./chroma_db',
        collection_name='sklearn_docs'
    )
    
    try:
        # Build database
        stats = builder.build_database(chunks_file, reset_collection=True)
        
        # Display results
        print("\n🎉 Vector Database Build Completed!")
        print(f"  📊 Total documents: {stats['total_documents']:,}")
        print(f"  🧠 Embedding model: {stats['embedding_model']}")
        print(f"  💾 Database size: {stats['database_size_mb']:.2f} MB")
        print(f"  ⏱️  Build time: {stats['build_time_seconds']:.1f} seconds")
        print(f"  🚀 Processing speed: {stats['documents_per_second']:.1f} docs/sec")
        print(f"  📁 Database location: {stats['database_path']}")
        
        # Test the database
        print("\n🔍 Testing database with sample search...")
        builder.test_search("What is cross validation?", n_results=2)
        
        print("\n✅ Vector database is ready for use!")
        
    except Exception as e:
        logger.error(f"Build failed: {e}")
        print(f"\n❌ Error: {e}")
        return 1
    
    return 0


if __name__ == "__main__":
    # BUGFIX: raise SystemExit instead of calling the site-injected exit()
    # builtin, which is not guaranteed to exist (e.g. under `python -S`).
    raise SystemExit(main())