| | """ |
| | ChromaDBManager: A utility class for managing ChromaDB configurations, embeddings, and logging. |
| | This module provides an interface to configure and interact with a ChromaDB document store |
| | using Hugging Face embedding models. It supports persistent storage configuration, model selection, |
| | and logging setup. Designed for integration in local development environments or cloud deployments. |
| | Design Assumptions: |
| | - Configuration values (e.g., DB path, collection name) are loaded from a `.env` file. |
| | - Hugging Face API keys are stored securely in the macOS keychain. |
| | - Logging is configured per class and can output to both the console and file. |
| | - Embedding models are specified via a `models.txt` file, which is automatically created if missing. |
| | - The default models support a range of needs: small, medium, multilingual, and e5 variants. |
| | Core Logic: |
| | - Loads config values using `dotenv_values`. |
| | - Ensures the persistence path exists or is created. |
| | - Initializes a ChromaDB instance with the chosen embedding model. |
| | - Logs system and model configuration for traceability. |
| | - Supports both local and remote Hugging Face embeddings. |
| | Instructions for Use: |
| | 1. Create a `.env` file in the repo root with keys like: |
| | CHROMA_DB_PATH=data/chroma_db |
| | CHROMA_DB_COLLECTION=documents |
| | 2. Ensure your Hugging Face API key is stored in your keyring under the appropriate label. |
| | 3. Run the script using the CLI, or import `ChromaDBManager` into your application. |
| | 4. Optionally, configure the logging level via CLI argument: |
| | python -m chroma_db_manager --log-level DEBUG |
| | Important: |
| | To test the module, run it from the root directory using the `-m` flag: |
| | python -m chroma_db_manager |
| | Do not run the script directly from an IDE or its file path, or relative paths and module imports may break. |
| | Attributes: |
| | db_path (str): Path to the local ChromaDB storage. |
| | collection_name (str): Default ChromaDB collection name. |
| | model_mapping (dict): Maps user-friendly model names to Hugging Face model IDs. |
| | """ |
| | from dotenv import dotenv_values |
| | import os |
| | import sys |
| | import platform |
| | import logging |
| | import warnings |
| | import json |
| | import re |
| | from pathlib import Path |
| | import yaml |
| |
|
| | |
| | import keyring |
| | import chromadb |
| | import numpy as np |
| | from sentence_transformers import SentenceTransformer |
| | from huggingface_hub import login |
| |
|
| | |
| | from utils.metadata_utils import enhance_metadata |
| | from utils.logging_utils import setup_logging |
| |
|
# Suppress FutureWarning noise emitted by third-party libraries at import time.
warnings.filterwarnings("ignore", category=FutureWarning)


# Module-level logger; the log file is named after this source file's stem.
logger = setup_logging(
    logger_name=__name__,
    log_filename=f"{Path(__file__).stem}.log"
)
| |
|
class ChromaDBManager:
    """
    Manager for a persistent ChromaDB document store with Hugging Face embeddings.

    Responsibilities:
      - Load configuration from a .env file (DB path, default collection).
      - Authenticate with Hugging Face via env var, macOS keyring, or keys file.
      - Load and initialize a SentenceTransformer embedding model.
      - Add and query documents in a persistent ChromaDB collection.

    Construction is routed through a single process-wide instance (see
    __new__); __init__ is guarded by _initialized so setup runs only once.
    """

    # Process-wide shared instance, created on first construction.
    _instance = None

    def __new__(cls, *args, **kwargs):
        # The class already carried _instance and the _initialized guard in
        # __init__, but __new__ was never overridden, so every call produced
        # a fresh object and the guard never fired. Route construction
        # through one shared instance to make the intended singleton work.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def _load_repo_configuration(self):
        """
        Load configuration from the .env file and initialize db_path.

        Resolves CHROMA_DB_PATH relative to the project root (the parent of
        this file's directory), ensures the directory exists, and records
        the default collection name.

        Returns:
            dict: {"db_path": str, "custom_settings": {"default_collection": str}}
        """
        config = dotenv_values(".env")

        relative_path = config.get("CHROMA_DB_PATH", "data/chroma_db")

        # Anchor the relative path at the project root so the resolved
        # location does not depend on the current working directory.
        project_dir = Path(__file__).resolve().parent.parent
        self.db_path = project_dir / relative_path

        self.collection_name = config.get("CHROMA_DB_COLLECTION", "documents")

        self.logger.info(f"Using Chroma DB path from .env: {self.db_path}")
        self.logger.info(f"Using default collection: {self.collection_name}")
        print(f"Using Chroma DB path from .env: {self.db_path}")
        print(f"Using default collection: {self.collection_name}")

        # Best-effort directory creation: a failure is logged, not fatal.
        try:
            os.makedirs(self.db_path, exist_ok=True)
        except Exception as e:
            self.logger.warning(f"Failed to ensure DB path exists: {e}")

        return {
            "db_path": str(self.db_path),
            "custom_settings": {
                "default_collection": self.collection_name
            }
        }

    def _load_and_initialize_model(self, model_size="medium", models_file="models.txt"):
        """
        Load the model mapping and initialize the embedding model.

        Sets self.model and self.model_name. On any failure, falls back to
        the small all-MiniLM-L6-v2 model.

        Args:
            model_size (str): Key into the model mapping. Defaults to "medium".
            models_file (str): Models mapping file name (resolved under
                config/). Defaults to "models.txt".
        """
        try:
            model_mapping = self._load_model_mapping(models_file)

            if model_size not in model_mapping:
                self.logger.warning(f"Model size '{model_size}' not found. Falling back to 'medium'.")
                model_size = "medium"

            # NOTE: if a custom mapping lacks "medium" as well, this raises
            # KeyError and the except branch below falls back to the small model.
            model_path = model_mapping[model_size]
            self.model_name = model_path

            self.logger.info(f"Initializing embedding model: {model_path}")
            self.model = SentenceTransformer(model_path)

            if hasattr(self.model, 'get_sentence_embedding_dimension'):
                embedding_dim = self.model.get_sentence_embedding_dimension()
                self.logger.info(f"Model embedding dimension: {embedding_dim}")

        except Exception as e:
            self.logger.error(f"Error initializing embedding model: {e}")
            self.logger.warning("Falling back to default small model")
            default_model = "sentence-transformers/all-MiniLM-L6-v2"
            self.model = SentenceTransformer(default_model)
            self.model_name = default_model

    def _ensure_db_directory_exists(self):
        """
        Ensure that the database directory exists, creating it if needed.

        Raises:
            Exception: re-raised if directory creation fails.
        """
        if not os.path.exists(self.db_path):
            try:
                os.makedirs(self.db_path)
                self.logger.info(f"Created database directory at: {self.db_path}")
            except Exception as e:
                self.logger.error(f"Error creating database directory: {e}")
                raise
        else:
            self.logger.info(f"Database directory already exists at: {self.db_path}")

    def __init__(self, model_size="medium", keys_file="keys.txt", models_file="models.txt",
                 dataset_repo=None, db_path=None):
        """
        Initialize the ChromaDBManager with optional overrides.

        Args:
            model_size (str): Embedding model size key. Defaults to "medium".
            keys_file (str): Keys file name (resolved under config/).
            models_file (str): Models mapping file name (resolved under config/).
            dataset_repo (str, optional): Hugging Face dataset repo override;
                falls back to HF_DATASET_REPO in .env.
            db_path (str, optional): ChromaDB storage path override; falls
                back to CHROMA_DB_PATH in .env.
        """
        # Singleton guard: skip re-initialization on repeated construction.
        if hasattr(self, '_initialized') and self._initialized:
            return

        self.logger = setup_logging(
            logger_name="ChromaDBManager",
            log_filename="ChromaDBManager.log",
        )
        self.logger.info("Initializing ChromaDBManager")

        env = dotenv_values(".env")

        # NOTE(review): unlike _load_repo_configuration, this path is NOT
        # anchored at the project root, so a relative CHROMA_DB_PATH resolves
        # against the current working directory.
        self.db_path = db_path or env.get("CHROMA_DB_PATH", "data/chroma_db")
        self.collection_name = env.get("CHROMA_DB_COLLECTION", "documents")
        self.dataset_repo = dataset_repo or env.get("HF_DATASET_REPO")

        self.logger.info(f"Using Chroma DB path: {self.db_path}")
        self.logger.info(f"Default collection: {self.collection_name}")
        if self.dataset_repo:
            self.logger.info(f"Using dataset repository: {self.dataset_repo}")

        self._authenticate_huggingface(keys_file)
        self._ensure_db_directory_exists()
        self._load_and_initialize_model(model_size, models_file)

        # str() so a pathlib.Path override is also accepted by the client.
        self.client = chromadb.PersistentClient(path=str(self.db_path))
        self.collection = self.client.get_or_create_collection(name=self.collection_name)

        self.logger.info(f"ChromaDBManager initialized with model: {self.model_name}")
        self._initialized = True

    def _authenticate_huggingface(self, keys_file=None):
        """
        Authenticate with Hugging Face using (in order of priority):
        1. Environment variable (HF_API_KEY, HF_TOKEN, HUGGINGFACE_TOKEN)
        2. macOS keyring (service "HF_API_KEY", username "rressler")
        3. Local keys file (default: config/keys.txt)

        Sets self.hf_token to the token on success, None on failure.

        Returns:
            bool: True if login succeeded, False otherwise.
        """
        try:
            token = (
                os.environ.get("HF_API_KEY")
                or os.environ.get("HF_TOKEN")
                or os.environ.get("HUGGINGFACE_TOKEN")
            )

            if not token and platform.system() == 'Darwin':
                try:
                    token = keyring.get_password("HF_API_KEY", "rressler")
                    if token:
                        self.logger.info("Using Hugging Face API key from macOS keyring")
                except Exception as e:
                    self.logger.warning(f"Keyring access failed: {e}")

            if not token and keys_file:
                try:
                    keys_path = Path(keys_file)
                    if not keys_path.is_absolute():
                        keys_path = Path(__file__).resolve().parent.parent / "config" / keys_path.name

                    if keys_path.exists():
                        with open(keys_path, "r") as f:
                            for line in f:
                                if line.strip().startswith("HF_API_KEY="):
                                    candidate = line.strip().split("=", 1)[1].strip()
                                    # Skip the unedited template placeholder
                                    # instead of attempting login with it
                                    # (previously it could leak into `token`).
                                    if candidate and candidate != "your_api_key":
                                        token = candidate
                                        self.logger.info("Using Hugging Face API key from keys file")
                                        break
                except Exception as e:
                    self.logger.warning(f"Error reading keys file: {e}")

            if token:
                try:
                    login(token=token)
                    self.hf_token = token
                    self.logger.info("Hugging Face authentication successful")
                    return True
                except Exception as e:
                    self.logger.error(f"Failed to authenticate with Hugging Face: {e}")
            else:
                self.logger.warning("No Hugging Face API token available from any source")

        except Exception as e:
            self.logger.error(f"Unexpected error during authentication: {e}")

        self.hf_token = None
        return False

    def _load_model_mapping(self, models_file="models.txt"):
        """
        Load the embedding model mapping from a JSON file under config/,
        creating a template with the defaults if the file is missing.

        Returns:
            dict: Mapping of size keys to Hugging Face model IDs; the
            defaults on any read/parse failure.
        """
        default_mapping = {
            "small": "sentence-transformers/all-MiniLM-L6-v2",
            "medium": "sentence-transformers/all-mpnet-base-v2",
            "large": "sentence-transformers/all-roberta-large-v1",
            "multilingual": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
            "e5": "intfloat/e5-large-v2"
        }

        try:
            project_root = Path(__file__).resolve().parent.parent
            config_dir = project_root / "config"
            config_dir.mkdir(parents=True, exist_ok=True)

            models_path = config_dir / models_file

            if not models_path.exists():
                with models_path.open("w") as f:
                    json.dump(default_mapping, f, indent=2)
                self.logger.info(f"Template models file created at {models_path}")
                return default_mapping

            with models_path.open("r") as f:
                model_mapping = json.load(f)

            if isinstance(model_mapping, dict) and model_mapping:
                self.logger.info(f"Loaded {len(model_mapping)} models from {models_path}")
                return model_mapping
            else:
                self.logger.warning(f"Invalid or empty model mapping in {models_path}. Using defaults.")
                return default_mapping

        except Exception as e:
            self.logger.error(f"Failed to load model mapping: {e}. Using defaults.")
            return default_mapping
            # (Removed the unreachable `Print(...)` line that followed this
            # return: `Print` is not a defined name and could never execute.)

    def generate_valid_id(self, text):
        """
        Sanitize text into a valid document ID.

        Removes non-word, non-space characters, replaces spaces with
        underscores, and truncates to 20 characters. None becomes "untitled".
        """
        if text is None:
            text = "untitled"

        sanitized_text = re.sub(r"[^\w\s]", "", str(text))
        sanitized_text = sanitized_text.replace(" ", "_")[:20]

        return sanitized_text

    def get_collection(self, name="documents"):
        """
        Get or create a collection by name.

        Raises:
            Exception: re-raised if the client call fails.
        """
        try:
            collection = self.client.get_or_create_collection(name=name)
            print(f"✅ Collection '{name}' successfully loaded.")
            # NOTE(review): collection.get() fetches every document just to
            # count them — consider collection.count() for large collections.
            print(f"📄 Number of docs: {len(collection.get()['documents'])}")
            return collection
        except Exception as e:
            self.logger.error(f"Error getting/creating collection {name}: {e}")
            raise

    def embed_text(self, text):
        """
        Generate an embedding (list of floats) for the given text.
        None is treated as the empty string.
        """
        try:
            if text is None:
                text = ""

            embeddings = self.model.encode(str(text)).tolist()
            return embeddings
        except Exception as e:
            self.logger.error(f"Error generating embeddings: {e}")
            raise

    def add_document(self, text, metadata, doc_id=None, collection_name="documents"):
        """
        Add a document to the specified collection with enhanced metadata.

        Args:
            text (str): Document text content to embed and store
            metadata (dict): Metadata associated with the document
            doc_id (str, optional): Document ID, generated if not provided
            collection_name (str, optional): Target collection name

        Returns:
            str: Document ID of the added document

        Raises:
            ValueError: If text is None, empty, or only whitespace.
        """
        # Guard None explicitly: previously None crashed with AttributeError
        # on .strip() instead of raising the intended ValueError.
        if text is None or not text.strip():
            raise ValueError("Cannot add an empty document.")

        collection = self.get_collection(collection_name)
        embedding = self.embed_text(text)

        title = metadata.get("title", "untitled")
        base_id = self.generate_valid_id(title)

        if doc_id is None:
            # hash() is salted per process for str, so generated IDs are not
            # stable across runs; pass doc_id explicitly for reproducible IDs.
            doc_id = f"{base_id}_{hash(text) % 10000}"

        enhanced_metadata = enhance_metadata(metadata)

        self.logger.debug(f"Enhanced metadata for document '{base_id}': {enhanced_metadata}")

        # upsert so re-adding the same ID updates rather than errors.
        collection.upsert(
            documents=[text],
            embeddings=[embedding],
            metadatas=[enhanced_metadata],
            ids=[doc_id]
        )

        self.logger.info(f"Added document to '{collection_name}' with ID: {doc_id} | Title: {title}")
        return doc_id

    def query(self, query_text, n_results=5, collection_name="documents"):
        """
        Query a collection by semantic similarity.

        Args:
            query_text (str): Query text to embed and search with.
            n_results (int): Maximum number of results. Defaults to 5.
            collection_name (str): Collection to query. Defaults to "documents".

        Returns:
            dict: Raw ChromaDB query results.
        """
        collection = self.get_collection(collection_name)
        query_embedding = self.embed_text(query_text)

        results = collection.query(
            query_embeddings=[query_embedding],
            n_results=n_results
        )

        return results
| | |
| | |
# Eager module-level instance shared by the convenience wrappers below
# (query_documents / add_document). NOTE: this runs at import time.
chroma_manager = ChromaDBManager()
| |
|
| | |
def get_chroma_manager(model_size="medium", keys_file="keys.txt", models_file="models.txt", db_path=None):
    """
    Get the ChromaDBManager singleton instance with specified configuration.
    If the instance already exists, returns it without reinitializing.

    Args:
        model_size (str, optional): Size of the embedding model.
        keys_file (str, optional): Path to the keys file.
        models_file (str, optional): Path to the models mapping file.
        db_path (str, optional): Path to the ChromaDB database.
    """
    # Reuse the instance cached on the function object, if any.
    cached = getattr(get_chroma_manager, '_instance', None)
    if cached is not None:
        return cached

    # First call: build the manager and stash it for subsequent calls.
    get_chroma_manager._instance = ChromaDBManager(
        model_size=model_size,
        keys_file=keys_file,
        models_file=models_file,
        db_path=db_path
    )
    return get_chroma_manager._instance
| |
|
def query_documents(query_text, n_results=3):
    """Run a similarity query against the default collection via the shared manager."""
    return chroma_manager.query(query_text, n_results=n_results)
| |
|
def add_document(text, metadata, doc_id=None):
    """Insert a document into the default collection via the shared manager."""
    return chroma_manager.add_document(text, metadata, doc_id=doc_id)
| |
|
def setup_logger(level=logging.INFO):
    """
    Configure the root logger at the requested level.

    Calls basicConfig on first use; if handlers are already attached,
    only adjusts the root logger's level.
    """
    root = logging.getLogger()
    if root.handlers:
        # Already configured elsewhere: just raise/lower the threshold.
        root.setLevel(level)
    else:
        logging.basicConfig(
            level=level,
            format="%(asctime)s [%(levelname)s] %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
| |
|
def init_chroma_db_manager(config: dict) -> ChromaDBManager:
    """
    Convenience function to initialize ChromaDBManager using a config dict.
    Intended for external scripts using YAML configuration.
    """
    # Collect the overrides first, then construct in one shot.
    kwargs = {
        "model_size": config.get("model", "medium"),
        "keys_file": config.get("keys_file", "keys.txt"),
        "models_file": config.get("models_file", "models.txt"),
        "dataset_repo": config.get("repo"),
        "db_path": config.get("db_path"),
    }
    return ChromaDBManager(**kwargs)
| |
|
def main():
    """
    CLI entry point: read the YAML config, configure logging, and
    initialize the ChromaDBManager.
    """
    config_path = "config/chroma_config.yml"
    try:
        with open(config_path, 'r') as file:
            config = yaml.safe_load(file) or {}
        # str() guards against non-string YAML values (e.g. `log_level: 10`),
        # which previously raised on .upper() and were misreported as a
        # failure to load the file.
        log_level = str(config.get('log_level', 'INFO')).upper()
    except Exception as e:
        log_level = 'INFO'
        print(f"Could not load {config_path}, using default log level {log_level}: {e}")

    setup_logger(level=log_level)

    # Construct for its side effects (DB, model, auth); the instance itself
    # is reachable via the module-level chroma_manager.
    ChromaDBManager()

    print("ChromaDBManager initialized successfully.")
| |
|
# Script entry guard; run as `python -m chroma_db_manager` from the repo root.
if __name__ == "__main__":
    main()