# rag/create_database.py
import os
import shutil
import logging
import warnings
import time
from dotenv import load_dotenv
from langchain_community.document_loaders import TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Suppress warnings
warnings.filterwarnings("ignore", category=UserWarning)
load_dotenv()
openai_api_key = os.getenv("OPENAI_API_KEY")
CHROMA_PATH = "./chroma_db"
PROCESSED_PATH = "data/books"
API_TIMEOUT = 60 # 1 minute timeout for API calls
CHUNK_SIZE = 100 # Number of document chunks to embed and add to Chroma per batch
DELAY_BETWEEN_FILES = 60 # 1 minute delay between processing files
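# Orchestrate the end-to-end ingestion run: load, split, embed, and persist.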
def main():
logger.info("Starting document processing")
generate_data_store()
logger.info("Document processing completed")
def generate_data_store():
files = get_files_to_process()
all_chunks = []
for file in files:
chunks = process_single_file(file)
all_chunks.extend(chunks)
time.sleep(DELAY_BETWEEN_FILES)
save_to_chroma(all_chunks)
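# List all .txt files in the processed-books directory.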
def get_files_to_process():
return [f for f in os.listdir(PROCESSED_PATH) if f.endswith('.txt')]
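# Load one text file and split it into chunks.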
def process_single_file(file):
logger.info(f"Processing file: {file}")
file_path = os.path.join(PROCESSED_PATH, file)
document = load_document(file_path)
chunks = split_text([document])
return chunks
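# Load a text file into a single LangChain Document.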
def load_document(file_path):
loader = TextLoader(file_path)
return loader.load()[0]
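# Choose a character-based chunk size for the splitter, scaled by how many
# lines the source file contains.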
def calculate_chunk_size(total_lines):
if total_lines < 1000:
return 1000
elif total_lines < 5000:
return 2000
elif total_lines < 10000:
return 4000
elif total_lines < 50000:
return 8000
else:
return 16000 # For very large documents
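# Split each document into overlapping character chunks (20% overlap), sizing
# the chunks by the line count of the source file.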
def split_text(documents: list[Document]):
logger.info("Starting text splitting process")
chunks = []
for doc in documents:
with open(doc.metadata['source'], 'r', encoding='utf-8') as file:
lines = file.readlines()
total_lines = len(lines)
chunk_size = calculate_chunk_size(total_lines)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_size // 5, # 20% overlap
length_function=len,
separators=['\n\n', '\n', ' ', ''],
add_start_index=True,
)
        doc_chunks = text_splitter.create_documents([''.join(lines)], [doc.metadata])  # readlines() keeps each trailing '\n', so join without an extra separator
        chunks.extend(doc_chunks)
        logger.info(f"Split document {doc.metadata['source']} into {len(doc_chunks)} chunks with chunk size {chunk_size} characters")
        if len(doc_chunks) > 1000:
            logger.warning(f"Document {doc.metadata['source']} has been split into a large number of chunks ({len(doc_chunks)}). Consider further preprocessing or summarization.")
if chunks:
sample_chunk = chunks[0]
logger.info(f"Sample chunk: {sample_chunk.page_content[:100]}...")
logger.info(f"Sample chunk metadata: {sample_chunk.metadata}")
return chunks
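# Embed the chunks with OpenAI and add them to a persistent Chroma collection
# in batches of CHUNK_SIZE.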
def save_to_chroma(chunks: list[Document]):
logger.info(f"Preparing to save chunks to Chroma at {CHROMA_PATH}")
embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key, timeout=API_TIMEOUT)
# Use a single collection for all documents
collection_name = "all_documents"
db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embeddings, collection_name=collection_name)
for i in range(0, len(chunks), CHUNK_SIZE):
chunk_batch = chunks[i:i+CHUNK_SIZE]
try:
start_time = time.time()
db.add_documents(chunk_batch)
end_time = time.time()
logger.info(f"Processed and saved chunk {i//CHUNK_SIZE + 1} of {len(chunks)//CHUNK_SIZE + 1} in {end_time - start_time:.2f} seconds")
except Exception as e:
logger.error(f"Error processing chunk {i//CHUNK_SIZE + 1}: {str(e)}")
db.persist()
logger.info(f"Successfully saved {len(chunks)} chunks to {CHROMA_PATH} in collection {collection_name}")
if __name__ == "__main__":
main()