#!/usr/bin/env python3
"""
Vector Database Builder for Scikit-learn Documentation

This module creates a vector database from chunked Scikit-learn documentation
using ChromaDB and Sentence-Transformers for efficient semantic search.

Author: AI Assistant
Date: September 2025
"""

# Standard library
import json
import logging
import os
import time
from pathlib import Path
from typing import Dict, List, Any, Optional
from uuid import uuid4

# Third-party
import chromadb
import numpy as np
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer

# Configure module-wide logging: timestamped INFO-level messages to stderr.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
class VectorDatabaseBuilder:
    """
    A class for building a vector database from chunked documentation.

    This class handles the creation of embeddings from text chunks and
    their storage in a ChromaDB vector database for efficient retrieval.

    Typical usage: construct, then call ``build_database()`` which runs the
    whole pipeline (model load, client init, collection setup, chunk load,
    embedding, insertion).
    """

    def __init__(
        self,
        model_name: str = 'all-MiniLM-L6-v2',
        db_path: str = './chroma_db',
        collection_name: str = 'sklearn_docs'
    ):
        """
        Initialize the VectorDatabaseBuilder.

        Args:
            model_name (str): Name of the sentence transformer model
            db_path (str): Path to store the ChromaDB database
            collection_name (str): Name of the collection in ChromaDB
        """
        self.model_name = model_name
        self.db_path = Path(db_path)
        self.collection_name = collection_name

        # Heavy components are created lazily by load_embedding_model(),
        # initialize_chroma_client() and create_collection(), so constructing
        # the builder itself is cheap and side-effect free.
        self.embedding_model = None
        self.chroma_client = None
        self.collection = None

        logger.info("Initialized VectorDatabaseBuilder:")
        logger.info(f" - Model: {model_name}")
        logger.info(f" - Database path: {db_path}")
        logger.info(f" - Collection: {collection_name}")

    def load_embedding_model(self) -> None:
        """
        Load the sentence transformer model for creating embeddings.

        Raises:
            Exception: Propagated from SentenceTransformer if the model
                cannot be downloaded or loaded.
        """
        logger.info(f"Loading embedding model: {self.model_name}")
        try:
            self.embedding_model = SentenceTransformer(self.model_name)
            # Encode a throwaway sentence once to verify the model works and
            # to report its embedding dimensionality.
            test_embedding = self.embedding_model.encode("test sentence")
            embedding_dim = len(test_embedding)
            logger.info("Model loaded successfully!")
            logger.info(f" - Embedding dimension: {embedding_dim}")
            logger.info(f" - Model device: {self.embedding_model.device}")
        except Exception as e:
            logger.error(f"Failed to load embedding model: {e}")
            raise

    def initialize_chroma_client(self) -> None:
        """
        Initialize ChromaDB client and create/get collection.

        Creates the database directory if needed and opens a persistent
        client at ``self.db_path`` with telemetry disabled.
        """
        logger.info("Initializing ChromaDB client...")
        try:
            # Create database directory if it doesn't exist
            self.db_path.mkdir(parents=True, exist_ok=True)
            # Persistent storage so the index survives process restarts.
            self.chroma_client = chromadb.PersistentClient(
                path=str(self.db_path),
                settings=Settings(
                    anonymized_telemetry=False,
                    allow_reset=True
                )
            )
            logger.info(f"ChromaDB client initialized at: {self.db_path}")
        except Exception as e:
            logger.error(f"Failed to initialize ChromaDB client: {e}")
            raise

    def create_collection(self, reset: bool = False) -> None:
        """
        Create or get the ChromaDB collection.

        Args:
            reset (bool): Whether to reset/recreate the collection if it exists
        """
        logger.info(f"Creating/getting collection: {self.collection_name}")
        try:
            if reset:
                # Best-effort delete: deleting a collection that does not
                # exist raises, and that is fine — we only care that it is
                # gone afterwards.
                try:
                    self.chroma_client.delete_collection(name=self.collection_name)
                    logger.info(f"Deleted existing collection: {self.collection_name}")
                except Exception:
                    pass

            self.collection = self.chroma_client.get_or_create_collection(
                name=self.collection_name,
                metadata={"description": "Scikit-learn documentation embeddings"}
            )

            collection_count = self.collection.count()
            logger.info(f"Collection '{self.collection_name}' ready")
            logger.info(f" - Current document count: {collection_count}")
        except Exception as e:
            logger.error(f"Failed to create collection: {e}")
            raise

    def load_chunks(self, chunks_file: str) -> List[Dict[str, Any]]:
        """
        Load text chunks from JSON file.

        Args:
            chunks_file (str): Path to the chunks JSON file

        Returns:
            List[Dict[str, Any]]: List of chunks with content and metadata

        Raises:
            FileNotFoundError: If ``chunks_file`` does not exist.
            ValueError: If the first chunk is missing required keys.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        chunks_path = Path(chunks_file)
        if not chunks_path.exists():
            raise FileNotFoundError(f"Chunks file not found: {chunks_path}")

        logger.info(f"Loading chunks from: {chunks_path}")
        try:
            with open(chunks_path, 'r', encoding='utf-8') as f:
                chunks = json.load(f)
            logger.info(f"Loaded {len(chunks)} chunks")

            # Cheap structural sanity check — only the first chunk is
            # inspected; a fully malformed file would fail later anyway.
            if chunks and isinstance(chunks[0], dict):
                required_keys = {'page_content', 'metadata'}
                if not required_keys.issubset(chunks[0].keys()):
                    raise ValueError("Invalid chunk structure. Missing required keys.")
            return chunks
        except json.JSONDecodeError as e:
            logger.error(f"Invalid JSON in chunks file: {e}")
            raise
        except Exception as e:
            logger.error(f"Error loading chunks: {e}")
            raise

    def create_embeddings_batch(
        self,
        texts: List[str],
        batch_size: int = 32
    ) -> List[List[float]]:
        """
        Create embeddings for a batch of texts.

        Args:
            texts (List[str]): List of texts to embed
            batch_size (int): Batch size for processing

        Returns:
            List[List[float]]: List of embeddings
        """
        logger.info(f"Creating embeddings for {len(texts)} texts...")
        try:
            # Process in batches to bound peak memory during encoding.
            all_embeddings = []
            for i in range(0, len(texts), batch_size):
                batch_texts = texts[i:i + batch_size]
                batch_embeddings = self.embedding_model.encode(
                    batch_texts,
                    show_progress_bar=False,
                    convert_to_numpy=True
                )
                # ChromaDB expects plain lists of floats, not numpy arrays.
                all_embeddings.extend([emb.tolist() for emb in batch_embeddings])

                if (i + batch_size) % 100 == 0 or (i + batch_size) >= len(texts):
                    logger.info(f" - Processed {min(i + batch_size, len(texts))}/{len(texts)} embeddings")

            logger.info("Embedding creation completed!")
            return all_embeddings
        except Exception as e:
            logger.error(f"Error creating embeddings: {e}")
            raise

    def add_documents_to_collection(
        self,
        chunks: List[Dict[str, Any]],
        batch_size: int = 100
    ) -> None:
        """
        Add documents to the ChromaDB collection.

        Args:
            chunks (List[Dict[str, Any]]): List of document chunks
            batch_size (int): Batch size for adding to database
        """
        # Guard: ChromaDB's add() rejects empty id/document lists, so an
        # empty input is treated as a no-op rather than an error.
        if not chunks:
            logger.warning("No chunks provided; nothing to add.")
            return

        logger.info(f"Adding {len(chunks)} documents to collection...")
        try:
            texts = [chunk['page_content'] for chunk in chunks]

            # ChromaDB metadata values must be scalars; numeric fields are
            # stringified here to keep the schema uniform.
            metadatas = []
            for chunk in chunks:
                metadata = {
                    'url': chunk['metadata']['url'],
                    'chunk_index': str(chunk['metadata']['chunk_index']),
                    'source': chunk['metadata'].get('source', 'scikit-learn-docs'),
                    'content_length': str(len(chunk['page_content']))
                }
                metadatas.append(metadata)

            embeddings = self.create_embeddings_batch(texts, batch_size=32)

            # Random UUIDs — chunks have no natural unique key here.
            ids = [str(uuid4()) for _ in range(len(chunks))]

            # Insert in batches to avoid one oversized request.
            for i in range(0, len(chunks), batch_size):
                end_idx = min(i + batch_size, len(chunks))
                batch_ids = ids[i:end_idx]
                batch_documents = texts[i:end_idx]
                batch_metadatas = metadatas[i:end_idx]
                batch_embeddings = embeddings[i:end_idx]

                self.collection.add(
                    ids=batch_ids,
                    documents=batch_documents,
                    metadatas=batch_metadatas,
                    embeddings=batch_embeddings
                )
                logger.info(f" - Added batch {i//batch_size + 1}: documents {i+1}-{end_idx}")

            final_count = self.collection.count()
            logger.info("Successfully added documents to collection!")
            logger.info(f" - Total documents in collection: {final_count}")
        except Exception as e:
            logger.error(f"Error adding documents to collection: {e}")
            raise

    def get_database_stats(self) -> Dict[str, Any]:
        """
        Get statistics about the vector database.

        Returns:
            Dict[str, Any]: Database statistics; empty dict on failure.
        """
        try:
            stats = {
                'collection_name': self.collection_name,
                'total_documents': self.collection.count(),
                'database_path': str(self.db_path),
                'embedding_model': self.model_name,
                'database_size_mb': self._get_directory_size(self.db_path) / (1024 * 1024)
            }
            return stats
        except Exception as e:
            logger.error(f"Error getting database stats: {e}")
            return {}

    def _get_directory_size(self, path: Path) -> int:
        """
        Calculate the total size of a directory.

        Args:
            path (Path): Directory path

        Returns:
            int: Size in bytes (files readable so far; permission errors
                truncate the walk and yield a partial total)
        """
        total_size = 0
        try:
            for item in path.rglob('*'):
                if item.is_file():
                    total_size += item.stat().st_size
        except (OSError, PermissionError):
            # Best effort: an unreadable entry aborts the remaining walk.
            pass
        return total_size

    def build_database(
        self,
        chunks_file: str = 'chunks.json',
        reset_collection: bool = True
    ) -> Dict[str, Any]:
        """
        Complete pipeline to build the vector database.

        Args:
            chunks_file (str): Path to chunks JSON file
            reset_collection (bool): Whether to reset existing collection

        Returns:
            Dict[str, Any]: Build statistics, including timing figures.
        """
        logger.info("Starting vector database build pipeline...")
        start_time = time.time()
        try:
            self.load_embedding_model()
            self.initialize_chroma_client()
            self.create_collection(reset=reset_collection)
            chunks = self.load_chunks(chunks_file)
            self.add_documents_to_collection(chunks)

            build_time = time.time() - start_time
            stats = self.get_database_stats()
            stats['build_time_seconds'] = round(build_time, 2)
            # Guard against a (theoretical) zero elapsed time on very fast
            # runs / coarse clocks, which would otherwise divide by zero.
            stats['documents_per_second'] = (
                round(len(chunks) / build_time, 2) if build_time > 0 else 0.0
            )

            logger.info("Vector database build completed successfully!")
            return stats
        except Exception as e:
            logger.error(f"Database build failed: {e}")
            raise

    def test_search(self, query: str = "linear regression", n_results: int = 3) -> None:
        """
        Test the vector database with a sample search.

        Args:
            query (str): Test query
            n_results (int): Number of results to return
        """
        logger.info(f"Testing database with query: '{query}'")
        try:
            results = self.collection.query(
                query_texts=[query],
                n_results=n_results
            )
            logger.info(f"Search test successful! Found {len(results['documents'][0])} results:")
            for i, (doc, metadata) in enumerate(zip(results['documents'][0], results['metadatas'][0])):
                logger.info(f" Result {i+1}:")
                logger.info(f" - URL: {metadata['url'].split('/')[-1]}")
                logger.info(f" - Preview: {doc[:100]}...")
        except Exception as e:
            # Deliberately non-fatal: a failed smoke test is logged only.
            logger.error(f"Search test failed: {e}")
def main():
    """
    Build the Scikit-learn documentation vector database end to end.

    Returns:
        int: Process exit code — 0 on success, 1 on any build failure.
    """
    print("Scikit-learn Documentation Vector Database Builder")
    print("=" * 60)

    # Configuration
    chunks_file = "chunks.json"

    builder = VectorDatabaseBuilder(
        model_name='all-MiniLM-L6-v2',
        db_path='./chroma_db',
        collection_name='sklearn_docs'
    )

    try:
        stats = builder.build_database(chunks_file, reset_collection=True)

        # Display results. Plain ASCII labels: the original decorative
        # glyphs were corrupted by an encoding round-trip.
        print("\nVector Database Build Completed!")
        print(f"  Total documents: {stats['total_documents']:,}")
        print(f"  Embedding model: {stats['embedding_model']}")
        print(f"  Database size: {stats['database_size_mb']:.2f} MB")
        print(f"  Build time: {stats['build_time_seconds']:.1f} seconds")
        print(f"  Processing speed: {stats['documents_per_second']:.1f} docs/sec")
        print(f"  Database location: {stats['database_path']}")

        # Smoke-test retrieval on the freshly built index.
        print("\nTesting database with sample search...")
        builder.test_search("What is cross validation?", n_results=2)

        print("\nVector database is ready for use!")
    except Exception as e:
        logger.error(f"Build failed: {e}")
        print(f"\nError: {e}")
        return 1

    return 0
if __name__ == "__main__":
    # raise SystemExit rather than the site-module builtin exit(), which is
    # not guaranteed to exist (e.g. under `python -S`).
    raise SystemExit(main())