chaaim123 commited on
Commit
ec123f6
·
verified ·
1 Parent(s): bb09f1a

Create utils/chroma_utils.py

Browse files
Files changed (1) hide show
  1. utils/chroma_utils.py +472 -0
utils/chroma_utils.py ADDED
@@ -0,0 +1,472 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ChromaDBManager: A utility class for managing ChromaDB configurations, embeddings, and logging.
3
+ This module provides an interface to configure and interact with a ChromaDB document store
4
+ using Hugging Face embedding models. It supports persistent storage configuration, model selection,
5
+ and logging setup. Designed for integration in local development environments or cloud deployments.
6
+ Design Assumptions:
7
+ - Configuration values (e.g., DB path, collection name) are loaded from a `.env` file.
8
+ - Hugging Face API keys are stored securely in the macOS keychain.
9
+ - Logging is configured per class and can output to both the console and file.
10
+ - Embedding models are specified via a `models.txt` file, which is automatically created if missing.
11
+ - The default models support a range of needs: small, medium, multilingual, and e5 variants.
12
+ Core Logic:
13
+ - Loads config values using `dotenv_values`.
14
+ - Ensures the persistence path exists or is created.
15
+ - Initializes a ChromaDB instance with the chosen embedding model.
16
+ - Logs system and model configuration for traceability.
17
+ - Supports both local and remote Hugging Face embeddings.
18
+ Instructions for Use:
19
+ 1. Create a `.env` file in the repo root with keys like:
20
+ CHROMA_DB_PATH=data/chroma_db
21
+ CHROMA_DB_COLLECTION=documents
22
+ 2. Ensure your Hugging Face API key is stored in your keyring under the appropriate label.
23
+ 3. Run the script using the CLI, or import `ChromaDBManager` into your application.
24
+ 4. Optionally, configure the logging level via CLI argument:
25
+ python -m chroma_db_manager --log-level DEBUG
26
+ Important:
27
+ To test the module, run it from the root directory using the `-m` flag:
28
+ python -m chroma_db_manager
29
+ Do not run the script directly from an IDE or its file path, or relative paths and module imports may break.
30
+ Attributes:
31
+ db_path (str): Path to the local ChromaDB storage.
32
+ collection_name (str): Default ChromaDB collection name.
33
+ model_mapping (dict): Maps user-friendly model names to Hugging Face model IDs.
34
+ """
35
+ from dotenv import dotenv_values
36
+ import os
37
+ import sys
38
+ import platform
39
+ import logging
40
+ import warnings
41
+ import json
42
+ import re
43
+ from pathlib import Path
44
+ import yaml
45
+
46
+ # 3rd-party libraries
47
+ import keyring
48
+ import chromadb
49
+ import numpy as np
50
+ from sentence_transformers import SentenceTransformer
51
+ from huggingface_hub import login
52
+
53
+ # Project utilities (absolute imports)
54
+ from utils.metadata_utils import enhance_metadata
55
+ from utils.logging_utils import setup_logging
56
+
57
+ warnings.filterwarnings("ignore", category=FutureWarning)
58
+
59
+
60
+ # Configure logging with debug mode from arguments
61
+ logger = setup_logging(
62
+ logger_name=__name__,
63
+ log_filename=f"{Path(__file__).stem}.log"
64
+ )
65
+
66
+ class ChromaDBManager:
67
+ _instance = None
68
+
69
+ def _load_repo_configuration(self):
70
+ """
71
+ Load configuration from the .env file and initialize db_path.
72
+ """
73
+ # Load .env variables
74
+ config = dotenv_values(".env")
75
+
76
+ # Get the relative path from the environment variable
77
+ relative_path = config.get("CHROMA_DB_PATH", "data/chroma_db")
78
+
79
+ # Resolve the relative path to the project directory
80
+ project_dir = Path(__file__).resolve().parent.parent # Assuming this script is inside the au_advisor folder
81
+ self.db_path = project_dir / relative_path # Combine project directory with the relative path
82
+
83
+ # Set the collection name
84
+ self.collection_name = config.get("CHROMA_DB_COLLECTION", "documents")
85
+
86
+ # Log the paths being used
87
+ self.logger.info(f"Using Chroma DB path from .env: {self.db_path}")
88
+ self.logger.info(f"Using default collection: {self.collection_name}")
89
+ print(f"Using Chroma DB path from .env: {self.db_path}")
90
+ print(f"Using default collection: {self.collection_name}")
91
+
92
+ # Optionally ensure the DB path exists
93
+ try:
94
+ os.makedirs(self.db_path, exist_ok=True)
95
+ except Exception as e:
96
+ self.logger.warning(f"Failed to ensure DB path exists: {e}")
97
+
98
+ # Return as a config dict only if needed elsewhere
99
+ return {
100
+ "db_path": str(self.db_path), # Return as string in case the path object needs to be used
101
+ "custom_settings": {
102
+ "default_collection": self.collection_name
103
+ }
104
+ }
105
+
106
+ def _load_and_initialize_model(self, model_size="medium", models_file="models.txt"):
107
+ """
108
+ Load model mapping and initialize the embedding model.
109
+ Args:
110
+ model_size (str): Size of the model to use. Defaults to "medium".
111
+ models_file (str): Path to the models mapping file. Defaults to "models.txt".
112
+ """
113
+ try:
114
+ # Load the model mapping from the file
115
+ model_mapping = self._load_model_mapping(models_file)
116
+
117
+ # Validate the requested model size
118
+ if model_size not in model_mapping:
119
+ self.logger.warning(f"Model size '{model_size}' not found. Falling back to 'medium'.")
120
+ model_size = "medium"
121
+
122
+ # Get the model path
123
+ model_path = model_mapping[model_size]
124
+ self.model_name = model_path # Store the model name for logging
125
+
126
+ # Initialize the SentenceTransformer model
127
+ self.logger.info(f"Initializing embedding model: {model_path}")
128
+ self.model = SentenceTransformer(model_path)
129
+
130
+ # Optional: Log model details
131
+ if hasattr(self.model, 'get_sentence_embedding_dimension'):
132
+ embedding_dim = self.model.get_sentence_embedding_dimension()
133
+ self.logger.info(f"Model embedding dimension: {embedding_dim}")
134
+
135
+ except Exception as e:
136
+ self.logger.error(f"Error initializing embedding model: {e}")
137
+ # Fallback to a default model if initialization fails
138
+ self.logger.warning("Falling back to default small model")
139
+ default_model = "sentence-transformers/all-MiniLM-L6-v2"
140
+ self.model = SentenceTransformer(default_model)
141
+ self.model_name = default_model
142
+
143
+ def _ensure_db_directory_exists(self):
144
+ """
145
+ Ensure that the database directory exists. If it doesn't, create it.
146
+ """
147
+ if not os.path.exists(self.db_path):
148
+ try:
149
+ os.makedirs(self.db_path)
150
+ self.logger.info(f"Created database directory at: {self.db_path}")
151
+ except Exception as e:
152
+ self.logger.error(f"Error creating database directory: {e}")
153
+ raise
154
+ else:
155
+ self.logger.info(f"Database directory already exists at: {self.db_path}")
156
+
157
+ def __init__(self, model_size="medium", keys_file="keys.txt", models_file="models.txt",
158
+ dataset_repo=None, db_path=None):
159
+ """
160
+ Initialize the ChromaDBManager with optional overrides.
161
+ """
162
+ if hasattr(self, '_initialized') and self._initialized:
163
+ return
164
+
165
+ self.logger = setup_logging(
166
+ logger_name="ChromaDBManager",
167
+ log_filename="ChromaDBManager.log",
168
+ )
169
+ self.logger.info("Initializing ChromaDBManager")
170
+
171
+ # Load .env configuration
172
+ env = dotenv_values(".env")
173
+
174
+ self.db_path = db_path or env.get("CHROMA_DB_PATH", "data/chroma_db")
175
+ self.collection_name = env.get("CHROMA_DB_COLLECTION", "documents")
176
+ self.dataset_repo = dataset_repo or env.get("HF_DATASET_REPO")
177
+
178
+ self.logger.info(f"Using Chroma DB path: {self.db_path}")
179
+ self.logger.info(f"Default collection: {self.collection_name}")
180
+ if self.dataset_repo:
181
+ self.logger.info(f"Using dataset repository: {self.dataset_repo}")
182
+
183
+ self._authenticate_huggingface(keys_file)
184
+ self._ensure_db_directory_exists()
185
+ self._load_and_initialize_model(model_size, models_file)
186
+
187
+ self.client = chromadb.PersistentClient(path=self.db_path)
188
+ self.collection = self.client.get_or_create_collection(name=self.collection_name)
189
+
190
+ self.logger.info(f"ChromaDBManager initialized with model: {self.model_name}")
191
+ self._initialized = True
192
+
193
+ def _authenticate_huggingface(self, keys_file=None):
194
+ """
195
+ Authenticate with Hugging Face using (in order of priority):
196
+ 1. Environment variable (HF_API_KEY, HF_TOKEN, HUGGINGFACE_TOKEN)
197
+ 2. macOS keyring (under "HF_API_KEY" and username "rressler")
198
+ 3. Local keys file (default: config/keys.txt)
199
+ """
200
+ try:
201
+ token = (
202
+ os.environ.get("HF_API_KEY")
203
+ or os.environ.get("HF_TOKEN")
204
+ or os.environ.get("HUGGINGFACE_TOKEN")
205
+ )
206
+
207
+ # Try keyring only on macOS
208
+ if not token and platform.system() == 'Darwin':
209
+ try:
210
+ token = keyring.get_password("HF_API_KEY", "rressler")
211
+ if token:
212
+ self.logger.info("Using Hugging Face API key from macOS keyring")
213
+ except Exception as e:
214
+ self.logger.warning(f"Keyring access failed: {e}")
215
+
216
+ # Try keys file (default: config/keys.txt)
217
+ if not token and keys_file:
218
+ try:
219
+ keys_path = Path(keys_file)
220
+ if not keys_path.is_absolute():
221
+ keys_path = Path(__file__).resolve().parent.parent / "config" / keys_path.name
222
+
223
+ if keys_path.exists():
224
+ with open(keys_path, "r") as f:
225
+ for line in f:
226
+ if line.strip().startswith("HF_API_KEY="):
227
+ _, token = line.strip().split("=", 1)
228
+ token = token.strip()
229
+ if token and token != "your_api_key":
230
+ self.logger.info("Using Hugging Face API key from keys file")
231
+ break
232
+ except Exception as e:
233
+ self.logger.warning(f"Error reading keys file: {e}")
234
+
235
+ # Try to login if we have a token
236
+ if token:
237
+ try:
238
+ login(token=token)
239
+ self.hf_token = token
240
+ self.logger.info("Hugging Face authentication successful")
241
+ return True
242
+ except Exception as e:
243
+ self.logger.error(f"Failed to authenticate with Hugging Face: {e}")
244
+ else:
245
+ self.logger.warning("No Hugging Face API token available from any source")
246
+
247
+ except Exception as e:
248
+ self.logger.error(f"Unexpected error during authentication: {e}")
249
+
250
+ self.hf_token = None
251
+ return False
252
+
253
+ def _load_model_mapping(self, models_file="models.txt"):
254
+ """
255
+ Load embedding model mapping from models.txt JSON file or create it if missing.
256
+ Supports both local and Hugging Face deployment.
257
+ """
258
+ default_mapping = {
259
+ "small": "sentence-transformers/all-MiniLM-L6-v2",
260
+ "medium": "sentence-transformers/all-mpnet-base-v2",
261
+ "large": "sentence-transformers/all-roberta-large-v1",
262
+ "multilingual": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
263
+ "e5": "intfloat/e5-large-v2"
264
+ }
265
+
266
+ try:
267
+ # Use a config directory relative to this file
268
+ project_root = Path(__file__).resolve().parent.parent
269
+ config_dir = project_root / "config"
270
+ config_dir.mkdir(parents=True, exist_ok=True)
271
+
272
+ models_path = config_dir / models_file
273
+
274
+ if not models_path.exists():
275
+ with models_path.open("w") as f:
276
+ json.dump(default_mapping, f, indent=2)
277
+ self.logger.info(f"Template models file created at {models_path}")
278
+ return default_mapping
279
+
280
+ with models_path.open("r") as f:
281
+ model_mapping = json.load(f)
282
+
283
+ if isinstance(model_mapping, dict) and model_mapping:
284
+ self.logger.info(f"Loaded {len(model_mapping)} models from {models_path}")
285
+ return model_mapping
286
+ else:
287
+ self.logger.warning(f"Invalid or empty model mapping in {models_path}. Using defaults.")
288
+ return default_mapping
289
+
290
+ except Exception as e:
291
+ self.logger.error(f"Failed to load model mapping: {e}. Using defaults.")
292
+ return default_mapping
293
+ Print(f"Load model mapping: {e}. Using defaults.")
294
+
295
+ def generate_valid_id(self, text):
296
+ """Sanitize the ID by removing special characters and limiting length."""
297
+ if text is None:
298
+ text = "untitled"
299
+
300
+ # Remove non-alphanumeric chars
301
+ sanitized_text = re.sub(r"[^\w\s]", "", str(text))
302
+
303
+ # Replace spaces with underscores and limit length
304
+ sanitized_text = sanitized_text.replace(" ", "_")[:20]
305
+
306
+ return sanitized_text
307
+
308
+ def get_collection(self, name="documents"):
309
+ """Get or create a collection by name."""
310
+ try:
311
+ collection = self.client.get_or_create_collection(name=name)
312
+ print(f"✅ Collection '{name}' successfully loaded.")
313
+ print(f"📄 Number of docs: {len(collection.get()['documents'])}")
314
+ return collection
315
+ except Exception as e:
316
+ self.logger.error(f"Error getting/creating collection {name}: {e}")
317
+ raise
318
+
319
+ def embed_text(self, text):
320
+ """Generate embeddings for the given text."""
321
+ try:
322
+ # Convert text to a string and handle potential None input
323
+ if text is None:
324
+ text = ""
325
+
326
+ # Generate embeddings
327
+ embeddings = self.model.encode(str(text)).tolist()
328
+ return embeddings
329
+ except Exception as e:
330
+ self.logger.error(f"Error generating embeddings: {e}")
331
+ raise
332
+
333
+ def add_document(self, text, metadata, doc_id=None, collection_name="documents"):
334
+ """
335
+ Add a document to the specified collection with enhanced metadata.
336
+ Args:
337
+ text (str): Document text content to embed and store
338
+ metadata (dict): Metadata associated with the document
339
+ doc_id (str, optional): Document ID, generated if not provided
340
+ collection_name (str, optional): Target collection name
341
+ Returns:
342
+ str: Document ID of the added document
343
+ """
344
+ if not text.strip():
345
+ raise ValueError("Cannot add an empty document.")
346
+
347
+ collection = self.get_collection(collection_name)
348
+ embedding = self.embed_text(text)
349
+
350
+ # Generate or normalize doc_id
351
+ title = metadata.get("title", "untitled")
352
+ base_id = self.generate_valid_id(title)
353
+
354
+ if doc_id is None:
355
+ doc_id = f"{base_id}_{hash(text) % 10000}"
356
+
357
+ # Enhance and log metadata
358
+ enhanced_metadata = enhance_metadata(metadata)
359
+
360
+ # Log key additions
361
+ self.logger.debug(f"Enhanced metadata for document '{base_id}': {enhanced_metadata}")
362
+
363
+ # Upsert into ChromaDB
364
+ collection.upsert(
365
+ documents=[text],
366
+ embeddings=[embedding],
367
+ metadatas=[enhanced_metadata],
368
+ ids=[doc_id]
369
+ )
370
+
371
+ self.logger.info(f"Added document to '{collection_name}' with ID: {doc_id} | Title: {title}")
372
+ return doc_id
373
+
374
+ def query(self, query_text, n_results=5, collection_name="documents"):
375
+ """Query the collection and return results."""
376
+ collection = self.get_collection(collection_name)
377
+ query_embedding = self.embed_text(query_text)
378
+
379
+ results = collection.query(
380
+ query_embeddings=[query_embedding],
381
+ n_results=n_results
382
+ )
383
+
384
+ return results
385
+
386
+ # Create a singleton instance
387
+ chroma_manager = ChromaDBManager()
388
+
389
+ # Convenience functions
390
+ def get_chroma_manager(model_size="medium", keys_file="keys.txt", models_file="models.txt", db_path=None):
391
+ """
392
+ Get the ChromaDBManager singleton instance with specified configuration.
393
+ If the instance already exists, returns it without reinitializing.
394
+
395
+ Args:
396
+ model_size (str, optional): Size of the embedding model.
397
+ keys_file (str, optional): Path to the keys file.
398
+ models_file (str, optional): Path to the models mapping file.
399
+ db_path (str, optional): Path to the ChromaDB database.
400
+ """
401
+ # Check if the instance already exists
402
+ if hasattr(get_chroma_manager, '_instance') and get_chroma_manager._instance is not None:
403
+ return get_chroma_manager._instance
404
+
405
+ # Create a new instance with the specified configuration
406
+ instance = ChromaDBManager(
407
+ model_size=model_size,
408
+ keys_file=keys_file,
409
+ models_file=models_file,
410
+ db_path=db_path
411
+ )
412
+
413
+ # Store the instance as a static variable
414
+ get_chroma_manager._instance = instance
415
+
416
+ return instance
417
+
418
+ def query_documents(query_text, n_results=3):
419
+ """Query documents in the default collection."""
420
+ return chroma_manager.query(query_text, n_results)
421
+
422
+ def add_document(text, metadata, doc_id=None):
423
+ """Add a document to the default collection."""
424
+ return chroma_manager.add_document(text, metadata, doc_id)
425
+
426
+ def setup_logger(level=logging.INFO):
427
+ root_logger = logging.getLogger()
428
+ if not root_logger.handlers:
429
+ logging.basicConfig(
430
+ level=level,
431
+ format="%(asctime)s [%(levelname)s] %(message)s",
432
+ datefmt="%Y-%m-%d %H:%M:%S",
433
+ )
434
+ else:
435
+ # Just set the level if already configured elsewhere
436
+ root_logger.setLevel(level)
437
+
438
+ def init_chroma_db_manager(config: dict) -> ChromaDBManager:
439
+ """
440
+ Convenience function to initialize ChromaDBManager using a config dict.
441
+ Intended for external scripts using YAML configuration.
442
+ """
443
+ return ChromaDBManager(
444
+ model_size=config.get("model", "medium"),
445
+ keys_file=config.get("keys_file", "keys.txt"),
446
+ models_file=config.get("models_file", "models.txt"),
447
+ dataset_repo=config.get("repo"), # Optional: for HF datasets
448
+ db_path=config.get("db_path") # Optional: override .env default
449
+ )
450
+
451
+ def main():
452
+ # Load config YAML for logging level if needed
453
+ config_path = "config/chroma_config.yml"
454
+ try:
455
+ with open(config_path, 'r') as file:
456
+ config = yaml.safe_load(file) or {}
457
+ log_level = config.get('log_level', 'INFO').upper() # Default to INFO if not set in YAML
458
+ except Exception as e:
459
+ log_level = 'INFO' # Default to INFO if there is an error loading config
460
+ print(f"Could not load {config_path}, using default log level {log_level}: {e}")
461
+
462
+ # Set up logging
463
+ setup_logger(level=log_level)
464
+
465
+ # Initialize ChromaDBManager (No CLI args, just configuration file)
466
+ chroma = ChromaDBManager()
467
+
468
+ # Example of using ChromaDBManager
469
+ print("ChromaDBManager initialized successfully.")
470
+
471
+ if __name__ == "__main__":
472
+ main()