Agent_CS / populate_db.py
daniel-was-taken's picture
Upload main files
35dae13 verified
import argparse
import os
import shutil
from langchain_community.document_loaders import DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema.document import Document
# Import with fallback for older versions
try:
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
except ImportError:
# Fallback to older imports
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
# Directory used as the Chroma persist directory; removed and recreated by this script.
CHROMA_PATH = "chroma_db"
# Default directory scanned for source documents (also the --data_path default).
DATA_PATH = "data"
# Embedding configuration
# Model identifier passed to HuggingFaceEmbeddings as model_name.
model_name = "sentence-transformers/all-mpnet-base-v2"
# Passed to HuggingFaceEmbeddings as model_kwargs; selects the CPU device.
model_kwargs = {'device': 'cpu'}
# Passed to HuggingFaceEmbeddings as encode_kwargs; embedding normalization is off.
encode_kwargs = {'normalize_embeddings': False}
def get_embedding_function():
    """Build the embedding model used to vectorize document chunks.

    Returns:
        A ``HuggingFaceEmbeddings`` instance configured from the module-level
        ``model_name``, ``model_kwargs`` and ``encode_kwargs`` settings.
    """
    settings = {
        "model_name": model_name,
        "model_kwargs": model_kwargs,
        "encode_kwargs": encode_kwargs,
    }
    return HuggingFaceEmbeddings(**settings)
def main():
    """Parse command-line options and rebuild the Chroma database.

    Command-line options:
        --data_path: Directory to load documents from (defaults to
            ``DATA_PATH``).
        --reset: Clear the existing Chroma database before loading.
    """
    parser = argparse.ArgumentParser(description="Populate the Chroma database with documents.")
    parser.add_argument("--data_path", type=str, default=DATA_PATH, help="Path to the directory containing documents.")
    parser.add_argument("--reset", action="store_true", help="Reset the Chroma database before adding documents.")
    args = parser.parse_args()

    # Clear existing Chroma database
    if args.reset:
        print("Resetting Chroma database...")
        clear_database()

    # Honor --data_path: the flag used to be parsed and then ignored because
    # load_documents() always read the module-level DATA_PATH.
    documents = load_documents(args.data_path)
    chunks = split_documents(documents)
    add_to_chroma(chunks)


def load_documents(data_path=None):
    """Load documents from a directory.

    Args:
        data_path: Directory to load documents from. When ``None`` (the
            default) the module-level ``DATA_PATH`` is used, so existing
            zero-argument callers keep working.

    Returns:
        The documents returned by ``DirectoryLoader.load()``.
    """
    if data_path is None:
        data_path = DATA_PATH
    loader = DirectoryLoader(data_path, show_progress=True)
    documents = loader.load()
    print(f"Loaded {len(documents)} documents.")
    return documents
def split_documents(documents: list[Document]):
    """Break the loaded documents into overlapping text chunks.

    Uses a ``RecursiveCharacterTextSplitter`` with a chunk size of 1000 and an
    overlap of 200, measured with ``len``.

    Args:
        documents: Documents to split.

    Returns:
        The chunks produced by the splitter.
    """
    splitter_options = {
        "chunk_size": 1000,
        "chunk_overlap": 200,
        "length_function": len,
        "is_separator_regex": False,
    }
    splitter = RecursiveCharacterTextSplitter(**splitter_options)
    pieces = splitter.split_documents(documents)
    print(f"Split into {len(pieces)} chunks.")
    return pieces
def add_to_chroma(chunks: list[Document]):
    """Recreate the Chroma database at ``CHROMA_PATH`` and insert ``chunks``.

    Any existing database directory is removed first, so every call starts
    from an empty store. Each chunk is given an ``id`` metadata entry by
    ``calculate_chunks_ids`` and that value is used as its Chroma ID.

    Args:
        chunks: Document chunks to store. An empty list results in no insert.
    """
    # Simply recreate the database to avoid API compatibility issues
    if os.path.exists(CHROMA_PATH):
        print("Removing existing database...")
        shutil.rmtree(CHROMA_PATH)

    print("Creating new database...")
    db = Chroma(
        persist_directory=CHROMA_PATH,
        embedding_function=get_embedding_function(),
    )

    chunks_with_ids = calculate_chunks_ids(chunks)
    if not chunks_with_ids:
        # Chroma is expected to reject an insert with an empty ID list, so
        # skip it instead of crashing when no documents were found.
        print("No chunks to add; skipping insert.")
        return

    print(f"Adding {len(chunks_with_ids)} chunks to Chroma.")
    chunk_ids = [chunk.metadata["id"] for chunk in chunks_with_ids]
    db.add_documents(chunks_with_ids, ids=chunk_ids)

    # Try to persist, but don't fail if method doesn't exist
    try:
        db.persist()
    except AttributeError:
        print("Database auto-persisted (newer ChromaDB version).")
    else:
        print("Database persisted successfully.")
    print("✅ Database populated successfully!")
def calculate_chunks_ids(chunks: "list[Document]"):
    """Assign a ``source:page:index`` identifier to every chunk.

    The identifier is stored under ``chunk.metadata["id"]``. ``source`` and
    ``page`` are read from the chunk's metadata (``page`` defaults to 0), and
    ``index`` counts the chunks seen so far for that ``source:page`` pair.

    The counter is kept per page rather than being reset whenever the page
    changes, so chunks of the same page stay unique even when they are not
    adjacent in ``chunks``.

    Args:
        chunks: Chunks to label; each is modified in place.

    Returns:
        The same ``chunks`` object, with ``id`` set on every chunk's metadata.
    """
    next_index_by_page = {}
    for chunk in chunks:
        source = chunk.metadata.get("source")
        page = chunk.metadata.get("page", 0)
        page_id = f"{source}:{page}"

        chunk_index = next_index_by_page.get(page_id, 0)
        next_index_by_page[page_id] = chunk_index + 1

        chunk.metadata["id"] = f"{page_id}:{chunk_index}"
    return chunks
def clear_database():
    """Delete the Chroma database directory at ``CHROMA_PATH`` if it exists."""
    if not os.path.exists(CHROMA_PATH):
        return
    shutil.rmtree(CHROMA_PATH)
    print(f"Cleared Chroma database at {CHROMA_PATH}.")
# Run the population pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()