# Extracted from a web file viewer (revision 35dae13, 3,894 bytes); the viewer
# chrome and line-number gutter that were fused onto this line have been removed.
import argparse
import os
import shutil
from langchain_community.document_loaders import DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema.document import Document
# Import with fallback for older versions: prefer the standalone integration
# packages, and if either one cannot be imported, take both classes from
# langchain_community instead.
try:
    from langchain_huggingface import HuggingFaceEmbeddings
    from langchain_chroma import Chroma
except ImportError:
    # Fallback to older imports
    from langchain_community.embeddings import HuggingFaceEmbeddings
    from langchain_community.vectorstores import Chroma
# Directory used as the Chroma persist directory; it is deleted when the
# database is cleared or rebuilt.
CHROMA_PATH = "chroma_db"
# Default directory from which input documents are loaded.
DATA_PATH = "data"
# Embedding configuration
# Model identifier passed to HuggingFaceEmbeddings.
model_name = "sentence-transformers/all-mpnet-base-v2"
# Passed through as model_kwargs; requests the 'cpu' device.
model_kwargs = {'device': 'cpu'}
# Passed through as encode_kwargs; embedding normalization is switched off.
encode_kwargs = {'normalize_embeddings': False}
def get_embedding_function():
    """Return a HuggingFaceEmbeddings instance built from the module settings.

    The model name, model kwargs and encode kwargs are read from the
    module-level ``model_name``, ``model_kwargs`` and ``encode_kwargs``.
    """
    return HuggingFaceEmbeddings(
        encode_kwargs=encode_kwargs,
        model_kwargs=model_kwargs,
        model_name=model_name,
    )
def main():
    """Command-line entry point: load, split and index documents into Chroma.

    Options:
        --data_path  Directory containing the documents (defaults to DATA_PATH).
        --reset      Clear the Chroma database before adding documents.
    """
    parser = argparse.ArgumentParser(description="Populate the Chroma database with documents.")
    parser.add_argument("--data_path", type=str, default=DATA_PATH, help="Path to the directory containing documents.")
    parser.add_argument("--reset", action="store_true", help="Reset the Chroma database before adding documents.")
    args = parser.parse_args()

    # Clear existing Chroma database
    if args.reset:
        print("Resetting Chroma database...")
        clear_database()

    # Load documents from the directory the user asked for. Previously the
    # parsed --data_path value was ignored and DATA_PATH was always used.
    documents = load_documents(args.data_path)
    chunks = split_documents(documents)
    add_to_chroma(chunks)


def load_documents(data_path=None):
    """Load documents from a directory.

    Args:
        data_path: Directory to load from. When omitted (``None``), the
            module-level ``DATA_PATH`` is used, so existing zero-argument
            callers behave as before.

    Returns:
        The list of documents produced by ``DirectoryLoader.load()``.
    """
    # Resolve the default at call time so the current DATA_PATH is honoured.
    directory = DATA_PATH if data_path is None else data_path
    loader = DirectoryLoader(directory, show_progress=True)
    documents = loader.load()
    print(f"Loaded {len(documents)} documents.")
    return documents
def split_documents(documents: list[Document]):
    """Break the given documents into overlapping text chunks.

    Uses a RecursiveCharacterTextSplitter with 1000-character chunks and a
    200-character overlap, measured with ``len``; returns the resulting chunks.
    """
    splitter_options = {
        "chunk_size": 1000,
        "chunk_overlap": 200,
        "length_function": len,
        "is_separator_regex": False,
    }
    pieces = RecursiveCharacterTextSplitter(**splitter_options).split_documents(documents)
    print(f"Split into {len(pieces)} chunks.")
    return pieces
def add_to_chroma(chunks: list[Document]):
    """Rebuild the Chroma database at CHROMA_PATH from the given chunks.

    Any existing database directory is removed and a fresh one is created,
    then every chunk is stored under the id found in ``chunk.metadata["id"]``
    (assigned by ``calculate_chunks_ids``).

    Args:
        chunks: Document chunks to index. If empty, nothing is done and any
            existing database is left in place.
    """
    # Bail out before touching the disk: with nothing to insert, the original
    # flow would delete the existing database first and only then attempt an
    # empty insert, which the vector store may reject.
    if not chunks:
        print("No chunks to add; existing database left untouched.")
        return

    # Simply recreate the database to avoid API compatibility issues
    if os.path.exists(CHROMA_PATH):
        print("Removing existing database...")
        shutil.rmtree(CHROMA_PATH)

    print("Creating new database...")
    db = Chroma(
        persist_directory=CHROMA_PATH,
        embedding_function=get_embedding_function(),
    )

    chunks_with_ids = calculate_chunks_ids(chunks)
    print(f"Adding {len(chunks_with_ids)} chunks to Chroma.")
    chunk_ids = [chunk.metadata["id"] for chunk in chunks_with_ids]
    db.add_documents(chunks_with_ids, ids=chunk_ids)

    # Try to persist, but don't fail if method doesn't exist. Only the
    # persist() call itself is guarded so unrelated errors are not masked.
    try:
        db.persist()
    except AttributeError:
        print("Database auto-persisted (newer ChromaDB version).")
    else:
        print("Database persisted successfully.")

    print("✅ Database populated successfully!")
def calculate_chunks_ids(chunks: list[Document]):
    """Assign an id of the form ``"<source>:<page>:<index>"`` to every chunk.

    ``source`` and ``page`` are read from ``chunk.metadata`` (``page`` defaults
    to 0 when absent). ``index`` counts chunks that share the same
    source/page pair, starting at 0. The id is written to
    ``chunk.metadata["id"]``.

    The counter is kept per source/page pair rather than reset whenever the
    pair differs from the previous chunk's, so chunks of the same page that
    are not adjacent in the input still receive distinct ids.

    Args:
        chunks: Chunks to label; they are modified in place.

    Returns:
        The same list that was passed in.
    """
    next_index_by_page = {}
    for chunk in chunks:
        source = chunk.metadata.get("source")
        page = chunk.metadata.get("page", 0)
        page_id = f"{source}:{page}"

        chunk_index = next_index_by_page.get(page_id, 0)
        next_index_by_page[page_id] = chunk_index + 1

        chunk.metadata["id"] = f"{page_id}:{chunk_index}"
    return chunks
def clear_database():
    """Delete the Chroma database directory at CHROMA_PATH, if it exists."""
    if not os.path.exists(CHROMA_PATH):
        return
    shutil.rmtree(CHROMA_PATH)
    print(f"Cleared Chroma database at {CHROMA_PATH}.")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()