# utils/chroma_utils.py
"""
ChromaDBManager: A utility class for managing ChromaDB configurations, embeddings, and logging.
This module provides an interface to configure and interact with a ChromaDB document store
using Hugging Face embedding models. It supports persistent storage configuration, model selection,
and logging setup. Designed for integration in local development environments or cloud deployments.
Design Assumptions:
- Configuration values (e.g., DB path, collection name) are loaded from a `.env` file.
- Hugging Face API keys are stored securely in the macOS keychain.
- Logging is configured per class and can output to both the console and file.
- Embedding models are specified via a `models.txt` file, which is automatically created if missing.
- The default models support a range of needs: small, medium, multilingual, and e5 variants.
Core Logic:
- Loads config values using `dotenv_values`.
- Ensures the persistence path exists or is created.
- Initializes a ChromaDB instance with the chosen embedding model.
- Logs system and model configuration for traceability.
- Supports both local and remote Hugging Face embeddings.
Instructions for Use:
1. Create a `.env` file in the repo root with keys like:
CHROMA_DB_PATH=data/chroma_db
CHROMA_DB_COLLECTION=documents
2. Ensure your Hugging Face API key is stored in your keyring under the appropriate label.
3. Run the script using the CLI, or import `ChromaDBManager` into your application.
4. Optionally, configure the logging level via the YAML config file
(config/chroma_config.yml, key: log_level).
Important:
To test the module, run it from the root directory using the `-m` flag:
python -m utils.chroma_utils
Do not run the script directly from an IDE or its file path, or relative paths and module imports may break.
Attributes:
db_path (str): Path to the local ChromaDB storage.
collection_name (str): Default ChromaDB collection name.
model_mapping (dict): Maps user-friendly model names to Hugging Face model IDs.
"""
# Standard library
import hashlib
import json
import logging
import os
import platform
import re
import sys
import warnings
from pathlib import Path

# 3rd-party libraries
import chromadb
import keyring
import numpy as np
import yaml
from dotenv import dotenv_values
from huggingface_hub import login
from sentence_transformers import SentenceTransformer

# Project utilities (absolute imports)
from utils.metadata_utils import enhance_metadata
from utils.logging_utils import setup_logging
# Silence FutureWarning noise emitted by third-party libraries at import time.
warnings.filterwarnings("ignore", category=FutureWarning)
# Module-level logger; writes to a log file named after this source file
# (project `setup_logging` helper configures handlers/format).
logger = setup_logging(
    logger_name=__name__,
    log_filename=f"{Path(__file__).stem}.log"
)
class ChromaDBManager:
    """Manage a persistent ChromaDB document store backed by a SentenceTransformer.

    Responsibilities:
    - Load configuration (DB path, collection name) from a `.env` file.
    - Authenticate with Hugging Face (env var, macOS keyring, or keys file).
    - Resolve and load the embedding model from a JSON mapping file.
    - Provide add/query helpers over the underlying ChromaDB collection.

    The class is a process-wide singleton: `__new__` always returns the same
    instance, and the `_initialized` flag makes repeated `__init__` calls no-ops.
    """

    _instance = None  # singleton storage used by __new__

    def __new__(cls, *args, **kwargs):
        # The original code kept `_instance` and an `_initialized` guard but never
        # overrode __new__, so every construction still built a fresh object and
        # the guard could never fire. This makes the intended singleton work.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, model_size="medium", keys_file="keys.txt", models_file="models.txt",
                 dataset_repo=None, db_path=None):
        """Initialize the manager (no-op on every construction after the first).

        Args:
            model_size (str): Key into the model mapping ("small", "medium", ...).
            keys_file (str): Hugging Face keys file name (resolved under config/).
            models_file (str): Model-mapping JSON file name (under config/).
            dataset_repo (str, optional): HF dataset repo; falls back to HF_DATASET_REPO.
            db_path (str, optional): Override for CHROMA_DB_PATH from .env.
        """
        if getattr(self, "_initialized", False):
            return
        self.logger = setup_logging(
            logger_name="ChromaDBManager",
            log_filename="ChromaDBManager.log",
        )
        self.logger.info("Initializing ChromaDBManager")
        # Explicit constructor arguments win over .env values.
        env = dotenv_values(".env")
        self.db_path = db_path or env.get("CHROMA_DB_PATH", "data/chroma_db")
        self.collection_name = env.get("CHROMA_DB_COLLECTION", "documents")
        self.dataset_repo = dataset_repo or env.get("HF_DATASET_REPO")
        self.logger.info(f"Using Chroma DB path: {self.db_path}")
        self.logger.info(f"Default collection: {self.collection_name}")
        if self.dataset_repo:
            self.logger.info(f"Using dataset repository: {self.dataset_repo}")
        self._authenticate_huggingface(keys_file)
        self._ensure_db_directory_exists()
        self._load_and_initialize_model(model_size, models_file)
        self.client = chromadb.PersistentClient(path=str(self.db_path))
        self.collection = self.client.get_or_create_collection(name=self.collection_name)
        self.logger.info(f"ChromaDBManager initialized with model: {self.model_name}")
        self._initialized = True

    def _load_repo_configuration(self):
        """Load DB path/collection from .env, resolving the path against the repo root.

        Returns:
            dict: ``{"db_path": str, "custom_settings": {"default_collection": str}}``

        NOTE(review): not called from __init__ (which resolves db_path relative to
        the CWD instead); kept for external callers.
        """
        config = dotenv_values(".env")
        relative_path = config.get("CHROMA_DB_PATH", "data/chroma_db")
        # Project root is the parent of the utils/ directory containing this file.
        project_dir = Path(__file__).resolve().parent.parent
        self.db_path = project_dir / relative_path
        self.collection_name = config.get("CHROMA_DB_COLLECTION", "documents")
        self.logger.info(f"Using Chroma DB path from .env: {self.db_path}")
        self.logger.info(f"Using default collection: {self.collection_name}")
        # Best-effort: make sure the persistence directory exists.
        try:
            os.makedirs(self.db_path, exist_ok=True)
        except Exception as e:
            self.logger.warning(f"Failed to ensure DB path exists: {e}")
        return {
            "db_path": str(self.db_path),
            "custom_settings": {
                "default_collection": self.collection_name,
            },
        }

    def _authenticate_huggingface(self, keys_file=None):
        """Authenticate with Hugging Face.

        Token sources, in priority order:
        1. Environment variables HF_API_KEY / HF_TOKEN / HUGGINGFACE_TOKEN
        2. macOS keyring (service "HF_API_KEY", username "rressler")
        3. A local keys file (resolved under config/ when the path is relative)

        Returns:
            bool: True on successful login, False otherwise.
        """
        # Always leave self.hf_token defined: the original only assigned it on the
        # success path or inside the outer except, so the "no token found" path
        # could leave the attribute missing entirely.
        self.hf_token = None
        try:
            token = (
                os.environ.get("HF_API_KEY")
                or os.environ.get("HF_TOKEN")
                or os.environ.get("HUGGINGFACE_TOKEN")
            )
            # Keyring lookup is attempted only on macOS.
            if not token and platform.system() == "Darwin":
                try:
                    token = keyring.get_password("HF_API_KEY", "rressler")
                    if token:
                        self.logger.info("Using Hugging Face API key from macOS keyring")
                except Exception as e:
                    self.logger.warning(f"Keyring access failed: {e}")
            # Fall back to a keys file (default location: <root>/config/<name>).
            if not token and keys_file:
                try:
                    keys_path = Path(keys_file)
                    if not keys_path.is_absolute():
                        keys_path = Path(__file__).resolve().parent.parent / "config" / keys_path.name
                    if keys_path.exists():
                        with open(keys_path, "r") as f:
                            for line in f:
                                if line.strip().startswith("HF_API_KEY="):
                                    _, token = line.strip().split("=", 1)
                                    token = token.strip()
                                    # Skip the template placeholder value.
                                    if token and token != "your_api_key":
                                        self.logger.info("Using Hugging Face API key from keys file")
                                        break
                except Exception as e:
                    self.logger.warning(f"Error reading keys file: {e}")
            if token:
                try:
                    login(token=token)
                    self.hf_token = token
                    self.logger.info("Hugging Face authentication successful")
                    return True
                except Exception as e:
                    self.logger.error(f"Failed to authenticate with Hugging Face: {e}")
            else:
                self.logger.warning("No Hugging Face API token available from any source")
        except Exception as e:
            self.logger.error(f"Unexpected error during authentication: {e}")
        return False

    def _ensure_db_directory_exists(self):
        """Create the database directory if it does not already exist."""
        if os.path.exists(self.db_path):
            self.logger.info(f"Database directory already exists at: {self.db_path}")
            return
        try:
            os.makedirs(self.db_path, exist_ok=True)
            self.logger.info(f"Created database directory at: {self.db_path}")
        except Exception as e:
            self.logger.error(f"Error creating database directory: {e}")
            raise

    def _load_model_mapping(self, models_file="models.txt"):
        """Load the model-size → HF-model-id mapping from a JSON file.

        Creates config/<models_file> seeded with the defaults when missing and
        falls back to the defaults on any read/parse error.

        Args:
            models_file (str): File name inside the project's config/ directory.

        Returns:
            dict: Mapping of size keys to Hugging Face model identifiers.
        """
        default_mapping = {
            "small": "sentence-transformers/all-MiniLM-L6-v2",
            "medium": "sentence-transformers/all-mpnet-base-v2",
            "large": "sentence-transformers/all-roberta-large-v1",
            "multilingual": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
            "e5": "intfloat/e5-large-v2",
        }
        # (Removed a dead `Print(...)` call that followed the final return: `Print`
        # is undefined and `e` was out of scope — it could never execute.)
        try:
            project_root = Path(__file__).resolve().parent.parent
            config_dir = project_root / "config"
            config_dir.mkdir(parents=True, exist_ok=True)
            models_path = config_dir / models_file
            if not models_path.exists():
                # Seed a template file so users can customise the mapping later.
                with models_path.open("w") as f:
                    json.dump(default_mapping, f, indent=2)
                self.logger.info(f"Template models file created at {models_path}")
                return default_mapping
            with models_path.open("r") as f:
                model_mapping = json.load(f)
            if isinstance(model_mapping, dict) and model_mapping:
                self.logger.info(f"Loaded {len(model_mapping)} models from {models_path}")
                return model_mapping
            self.logger.warning(f"Invalid or empty model mapping in {models_path}. Using defaults.")
            return default_mapping
        except Exception as e:
            self.logger.error(f"Failed to load model mapping: {e}. Using defaults.")
            return default_mapping

    def _load_and_initialize_model(self, model_size="medium", models_file="models.txt"):
        """Resolve *model_size* via the mapping file and load the SentenceTransformer.

        Falls back to "medium" for unknown sizes, and ultimately to
        all-MiniLM-L6-v2 if model loading fails for any reason.

        Args:
            model_size (str): Size key into the model mapping.
            models_file (str): Model-mapping file name.
        """
        try:
            model_mapping = self._load_model_mapping(models_file)
            if model_size not in model_mapping:
                self.logger.warning(f"Model size '{model_size}' not found. Falling back to 'medium'.")
                model_size = "medium"
            model_path = model_mapping[model_size]
            self.model_name = model_path  # kept for logging/inspection
            self.logger.info(f"Initializing embedding model: {model_path}")
            self.model = SentenceTransformer(model_path)
            if hasattr(self.model, "get_sentence_embedding_dimension"):
                self.logger.info(
                    f"Model embedding dimension: {self.model.get_sentence_embedding_dimension()}"
                )
        except Exception as e:
            self.logger.error(f"Error initializing embedding model: {e}")
            self.logger.warning("Falling back to default small model")
            default_model = "sentence-transformers/all-MiniLM-L6-v2"
            self.model = SentenceTransformer(default_model)
            self.model_name = default_model

    def generate_valid_id(self, text):
        """Sanitize *text* into a short ID: strip punctuation, spaces → "_", max 20 chars."""
        if text is None:
            text = "untitled"
        sanitized = re.sub(r"[^\w\s]", "", str(text))
        return sanitized.replace(" ", "_")[:20]

    def get_collection(self, name="documents"):
        """Get or create a collection by name, printing a short status report."""
        try:
            collection = self.client.get_or_create_collection(name=name)
            print(f"✅ Collection '{name}' successfully loaded.")
            # count() instead of len(collection.get()['documents']): the original
            # fetched every stored document just to count them.
            print(f"📄 Number of docs: {collection.count()}")
            return collection
        except Exception as e:
            self.logger.error(f"Error getting/creating collection {name}: {e}")
            raise

    def embed_text(self, text):
        """Return the embedding vector (list of floats) for *text* (None → "")."""
        try:
            if text is None:
                text = ""
            return self.model.encode(str(text)).tolist()
        except Exception as e:
            self.logger.error(f"Error generating embeddings: {e}")
            raise

    def add_document(self, text, metadata, doc_id=None, collection_name="documents"):
        """Add a document to the specified collection with enhanced metadata.

        Args:
            text (str): Document text content to embed and store.
            metadata (dict): Metadata associated with the document.
            doc_id (str, optional): Document ID; derived from title + content hash
                when not provided.
            collection_name (str, optional): Target collection name.

        Returns:
            str: Document ID of the added document.

        Raises:
            ValueError: If *text* is None or blank (original raised AttributeError
                for None).
        """
        if text is None or not str(text).strip():
            raise ValueError("Cannot add an empty document.")
        collection = self.get_collection(collection_name)
        embedding = self.embed_text(text)
        title = metadata.get("title", "untitled")
        base_id = self.generate_valid_id(title)
        if doc_id is None:
            # hashlib, not built-in hash(): hash() is salted per interpreter run
            # (PYTHONHASHSEED), so re-adding the same text in a new process got a
            # different ID and duplicated the document instead of upserting it.
            digest = int(hashlib.md5(str(text).encode("utf-8")).hexdigest(), 16)
            doc_id = f"{base_id}_{digest % 10000}"
        enhanced_metadata = enhance_metadata(metadata)
        self.logger.debug(f"Enhanced metadata for document '{base_id}': {enhanced_metadata}")
        collection.upsert(
            documents=[text],
            embeddings=[embedding],
            metadatas=[enhanced_metadata],
            ids=[doc_id],
        )
        self.logger.info(f"Added document to '{collection_name}' with ID: {doc_id} | Title: {title}")
        return doc_id

    def query(self, query_text, n_results=5, collection_name="documents"):
        """Embed *query_text* and return the top *n_results* matches."""
        collection = self.get_collection(collection_name)
        query_embedding = self.embed_text(query_text)
        return collection.query(
            query_embeddings=[query_embedding],
            n_results=n_results,
        )
# Create a singleton instance
# NOTE(review): constructed at import time — importing this module triggers
# HF authentication, model download, and DB directory creation as side effects.
chroma_manager = ChromaDBManager()
# Convenience functions
def get_chroma_manager(model_size="medium", keys_file="keys.txt", models_file="models.txt", db_path=None):
    """Return the cached ChromaDBManager, building it on first call.

    The instance is memoised on the function object itself; every later call
    hands back that same instance and silently ignores the arguments.

    Args:
        model_size (str, optional): Size of the embedding model.
        keys_file (str, optional): Path to the keys file.
        models_file (str, optional): Path to the models mapping file.
        db_path (str, optional): Path to the ChromaDB database.
    """
    cached = getattr(get_chroma_manager, "_instance", None)
    if cached is not None:
        return cached
    get_chroma_manager._instance = ChromaDBManager(
        model_size=model_size,
        keys_file=keys_file,
        models_file=models_file,
        db_path=db_path,
    )
    return get_chroma_manager._instance
def query_documents(query_text, n_results=3):
    """Module-level shortcut: similarity-query the default collection via the singleton."""
    return chroma_manager.query(query_text, n_results)
def add_document(text, metadata, doc_id=None):
    """Module-level shortcut: store *text* with *metadata* in the default collection."""
    return chroma_manager.add_document(text, metadata, doc_id)
def setup_logger(level=logging.INFO):
    """Configure root logging once; on later calls only adjust the level.

    If the root logger already has handlers (configured elsewhere), just update
    its level; otherwise install the standard basicConfig handler and format.
    """
    root = logging.getLogger()
    if root.handlers:
        # Logging already set up elsewhere — only change verbosity.
        root.setLevel(level)
    else:
        logging.basicConfig(
            level=level,
            format="%(asctime)s [%(levelname)s] %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
def init_chroma_db_manager(config: dict) -> ChromaDBManager:
    """Build a ChromaDBManager from a YAML-style config dict.

    Recognised keys: "model", "keys_file", "models_file", "repo" (optional HF
    dataset repo) and "db_path" (optional .env override). Missing keys fall
    back to the same defaults as the constructor.
    """
    kwargs = {
        "model_size": config.get("model", "medium"),
        "keys_file": config.get("keys_file", "keys.txt"),
        "models_file": config.get("models_file", "models.txt"),
        "dataset_repo": config.get("repo"),
        "db_path": config.get("db_path"),
    }
    return ChromaDBManager(**kwargs)
def main():
    """CLI entry point: read the log level from config/chroma_config.yml and boot the manager."""
    config_path = "config/chroma_config.yml"
    log_level = 'INFO'  # used when the YAML file is missing or unreadable
    try:
        with open(config_path, 'r') as fh:
            cfg = yaml.safe_load(fh) or {}
        log_level = cfg.get('log_level', 'INFO').upper()
    except Exception as e:
        print(f"Could not load {config_path}, using default log level {log_level}: {e}")
    # Set up logging (logging accepts level names as strings).
    setup_logger(level=log_level)
    # Initialize ChromaDBManager (no CLI args, configuration file only).
    chroma = ChromaDBManager()
    print("ChromaDBManager initialized successfully.")


if __name__ == "__main__":
    main()