import argparse
import os
import shutil

from langchain_community.document_loaders import DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema.document import Document

# Import with fallback for older versions
try:
    from langchain_huggingface import HuggingFaceEmbeddings
    from langchain_chroma import Chroma
except ImportError:
    # Fallback to older imports
    from langchain_community.embeddings import HuggingFaceEmbeddings
    from langchain_community.vectorstores import Chroma

CHROMA_PATH = "chroma_db"
DATA_PATH = "data"

# Embedding configuration
model_name = "sentence-transformers/all-mpnet-base-v2"
model_kwargs = {'device': 'cpu'}
encode_kwargs = {'normalize_embeddings': False}


def get_embedding_function():
    """Return the HuggingFace embedding model used to index documents.

    The model name, device and encode options come from the module-level
    ``model_name``, ``model_kwargs`` and ``encode_kwargs`` settings.
    """
    embeddings = HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
    )
    return embeddings


def main():
    """Parse CLI arguments, then load, split and index documents into Chroma."""
    parser = argparse.ArgumentParser(description="Populate the Chroma database with documents.")
    parser.add_argument("--data_path", type=str, default=DATA_PATH, help="Path to the directory containing documents.")
    parser.add_argument("--reset", action="store_true", help="Reset the Chroma database before adding documents.")
    args = parser.parse_args()

    # Clear existing Chroma database.
    # NOTE: add_to_chroma() rebuilds the database from scratch on every run
    # that has chunks to add, so --reset mainly matters when the data
    # directory is empty (it then leaves an empty database behind).
    if args.reset:
        print("Resetting Chroma database...")
        clear_database()

    # Load documents from the directory given on the command line.
    documents = load_documents(args.data_path)
    chunks = split_documents(documents)
    add_to_chroma(chunks)


def load_documents(data_path: str = DATA_PATH) -> list[Document]:
    """Load every document found under *data_path*.

    Args:
        data_path: Directory scanned by ``DirectoryLoader``. Defaults to
            ``DATA_PATH`` so zero-argument callers keep the old behavior.

    Returns:
        The loaded documents.
    """
    loader = DirectoryLoader(data_path, show_progress=True)
    documents = loader.load()
    print(f"Loaded {len(documents)} documents.")
    return documents


def split_documents(documents: list[Document]) -> list[Document]:
    """Split documents into ~1000-character chunks with 200 characters of overlap."""
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
        length_function=len,
        is_separator_regex=False,
    )
    chunks = text_splitter.split_documents(documents)
    print(f"Split into {len(chunks)} chunks.")
    return chunks


def add_to_chroma(chunks: list[Document]):
    """Rebuild the Chroma database at ``CHROMA_PATH`` from *chunks*.

    Any existing database directory is deleted first, then every chunk is
    embedded and stored under the ID assigned by ``calculate_chunks_ids``.
    If *chunks* is empty, nothing is deleted or written.
    """
    if not chunks:
        # Don't destroy an existing database when there is nothing to replace
        # it with; Chroma may also reject an empty ID batch.
        print("No chunks to add; leaving database unchanged.")
        return

    # Simply recreate the database to avoid API compatibility issues
    if os.path.exists(CHROMA_PATH):
        print("Removing existing database...")
        clear_database()

    print("Creating new database...")
    db = Chroma(
        persist_directory=CHROMA_PATH,
        embedding_function=get_embedding_function(),
    )

    chunks_with_ids = calculate_chunks_ids(chunks)
    print(f"Adding {len(chunks_with_ids)} chunks to Chroma.")
    chunk_ids = [chunk.metadata["id"] for chunk in chunks_with_ids]
    db.add_documents(chunks_with_ids, ids=chunk_ids)

    # Older Chroma wrappers need an explicit persist(); newer ones have no
    # such method and persist automatically.
    try:
        db.persist()
    except AttributeError:
        print("Database auto-persisted (newer ChromaDB version).")
    else:
        print("Database persisted successfully.")

    print("✅ Database populated successfully!")


def calculate_chunks_ids(chunks: list[Document]):
    """Store a ``"<source>:<page>:<index>"`` ID in each chunk's metadata.

    ``index`` counts how many chunks with the same source/page have been seen
    so far. Counting per page (rather than only across consecutive chunks)
    keeps the IDs unique even when chunks of one page are not contiguous.
    Chunks are modified in place and the same list is returned.
    """
    next_index: dict[str, int] = {}
    for chunk in chunks:
        source = chunk.metadata.get("source")
        page = chunk.metadata.get("page", 0)
        page_id = f"{source}:{page}"

        index = next_index.get(page_id, 0)
        next_index[page_id] = index + 1
        chunk.metadata["id"] = f"{page_id}:{index}"

    return chunks


def clear_database():
    """Delete the Chroma database directory at ``CHROMA_PATH`` if it exists."""
    if os.path.exists(CHROMA_PATH):
        shutil.rmtree(CHROMA_PATH)
        print(f"Cleared Chroma database at {CHROMA_PATH}.")


if __name__ == "__main__":
    main()