import sys
import os
import boto3
import hashlib
import json
import threading

# Add the project root directory to the Python path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from typing import List
from langchain.text_splitter import RecursiveCharacterTextSplitter
from concurrent.futures import ThreadPoolExecutor, as_completed
from langchain_openai import OpenAIEmbeddings
import pinecone
from tqdm.auto import tqdm
from langchain.schema import Document
from config import get_settings
from dotenv import load_dotenv
from io import BytesIO
from PyPDF2 import PdfReader

load_dotenv()
class RAGPrep:
    def __init__(self, processed_hashes_file="processed_hashes.json"):
        self.settings = get_settings()
        self.index_name = self.settings.INDEX_NAME
        self.pc = self.init_pinecone()
        self.embeddings = OpenAIEmbeddings(openai_api_key=self.settings.OPENAI_API_KEY)
        self.processed_hashes_file = processed_hashes_file
        self.processed_hashes = self.load_processed_hashes()

    def init_pinecone(self):
        """Initialize the Pinecone client."""
        return pinecone.Pinecone(api_key=self.settings.PINECONE_API_KEY)
    def create_or_connect_index(self, index_name, dimension):
        """Create a Pinecone index if it does not exist, then connect to it."""
        spec = pinecone.ServerlessSpec(
            cloud=self.settings.CLOUD,
            region=self.settings.REGION
        )
        print(f"All indexes: {self.pc.list_indexes()}")
        if index_name not in self.pc.list_indexes().names():
            self.pc.create_index(
                name=index_name,
                dimension=dimension,
                metric="cosine",  # You can use 'dotproduct' or other metrics if needed
                spec=spec
            )
        return self.pc.Index(index_name)
    def load_processed_hashes(self):
        """Load previously processed PDF hashes from disk."""
        if os.path.exists(self.processed_hashes_file):
            with open(self.processed_hashes_file, "r") as f:
                return set(json.load(f))
        return set()

    def save_processed_hashes(self):
        """Save processed PDF hashes to disk."""
        with open(self.processed_hashes_file, "w") as f:
            json.dump(list(self.processed_hashes), f)

    def generate_pdf_hash(self, pdf_content: bytes) -> str:
        """Generate an MD5 hash for the given PDF content."""
        return hashlib.md5(pdf_content).hexdigest()
    def load_and_split_pdfs(self, chunk_from=50, chunk_to=100) -> List[Document]:
        """Load PDFs from S3, extract text, and split into chunks."""
        # Initialize the S3 client
        s3_client = boto3.client(
            's3',
            aws_access_key_id=self.settings.AWS_ACCESS_KEY,
            aws_secret_access_key=self.settings.AWS_SECRET_KEY,
            region_name=self.settings.AWS_REGION
        )

        # List all objects in the S3 bucket and prefix
        print(f"Listing files in S3 bucket: {self.settings.AWS_BUCKET_NAME}")
        response = s3_client.list_objects_v2(Bucket=self.settings.AWS_BUCKET_NAME, Prefix="")
        s3_keys = [obj['Key'] for obj in response.get('Contents', [])]
        print(f"Found {len(s3_keys)} files in S3")

        documents = []
        # Process each file in the requested slice
        for s3_key in s3_keys[chunk_from:chunk_to]:
            print(f"Processing file: {s3_key}")
            if not s3_key.lower().endswith(".pdf"):
                print("Not a PDF file, skipping.")
                continue
            try:
                # Read the file from S3
                obj = s3_client.get_object(Bucket=self.settings.AWS_BUCKET_NAME, Key=s3_key)
                pdf_content = obj['Body'].read()

                # Generate a hash and skip duplicates
                pdf_hash = self.generate_pdf_hash(pdf_content)
                if pdf_hash in self.processed_hashes:
                    print(f"Duplicate PDF detected: {s3_key}, skipping.")
                    continue

                # Extract text from the PDF (extract_text can return None for empty pages)
                pdf_reader = PdfReader(BytesIO(pdf_content))
                text = "".join(page.extract_text() or "" for page in pdf_reader.pages)

                # Add the document with metadata
                documents.append(Document(page_content=text, metadata={"source": s3_key}))
                self.processed_hashes.add(pdf_hash)
            except Exception as e:
                print(f"Error processing {s3_key}: {e}")

        print(f"Extracted text from {len(documents)} documents")

        # Split documents into chunks
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.settings.CHUNK_SIZE,
            chunk_overlap=self.settings.CHUNK_OVERLAP
        )
        chunks = text_splitter.split_documents(documents)
        print(f"Created {len(chunks)} chunks")

        # Persist the updated hashes
        self.save_processed_hashes()
        return chunks
    def process_and_upload(self, total_batch=200):
        """Process PDFs from S3 in slices of 50 and upload their embeddings to Pinecone."""
        # Create or connect to the index
        index = self.create_or_connect_index(self.index_name, self.settings.DIMENSIONS)

        # Load and split documents one slice of S3 objects at a time
        print('//////// chunking: ////////')
        current_batch = 0
        for _ in range(0, total_batch, 50):
            batch_size = 50  # Adjust based on your needs
            slice_start = current_batch
            chunks = self.load_and_split_pdfs(slice_start, slice_start + batch_size)
            current_batch += batch_size

            # Prepare for batch processing
            max_threads = 4  # Adjust based on your hardware

            def process_batch(batch, batch_index):
                """Embed a single batch of chunks and build the upsert tuples."""
                print(f"Processing batch {batch_index} on thread: {threading.current_thread().name}")
                print(f"Active threads: {threading.active_count()}")

                # Create ids for the batch; include the slice offset so ids stay unique
                # across slices instead of overwriting earlier vectors in Pinecone
                ids = [f"chunk_{slice_start}_{batch_index}_{j}" for j in range(len(batch))]

                # Get texts and generate embeddings
                texts = [doc.page_content for doc in batch]
                embeddings = self.embeddings.embed_documents(texts)

                # Create metadata for each chunk
                metadata = [
                    {
                        "text": doc.page_content,
                        "source": doc.metadata.get("source", "unknown"),
                        "page": doc.metadata.get("page", 0)
                    }
                    for doc in batch
                ]

                # Build the upsert payload as (id, vector, metadata) tuples
                return list(zip(ids, embeddings, metadata))

            with ThreadPoolExecutor(max_threads) as executor:
                futures = []
                print(f"Batch size being used: {batch_size}")
                for i in range(0, len(chunks), batch_size):
                    batch = chunks[i:i + batch_size]
                    futures.append(executor.submit(process_batch, batch, i))

                # Gather results and upsert to Pinecone
                for future in tqdm(as_completed(futures), total=len(futures), desc="Uploading batches"):
                    try:
                        to_upsert = future.result()
                        index.upsert(vectors=to_upsert)
                    except Exception as e:
                        print(f"Error processing batch: {e}")

            print(f"Successfully processed and uploaded {len(chunks)} chunks to Pinecone")
    def cleanup_index(self) -> bool:
        """
        Delete all vectors from the Pinecone index.

        Returns:
            bool: True if cleanup was successful, False otherwise

        Raises:
            Exception: Logs any unexpected errors during cleanup
        """
        try:
            # Check whether the index exists
            if self.index_name in self.pc.list_indexes().names():
                print(f"Index name found in {self.pc.list_indexes().names()}")
                # Attempt to delete all vectors
                index = self.pc.Index(self.index_name)
                index.delete(delete_all=True)
                print(f"Successfully cleaned up index: {self.index_name}")
                return True
            print("Index doesn't exist.")
            return True
        except Exception as e:
            print(f"Unexpected error during index cleanup: {e}")
            # Log the error as well
            import logging
            logging.error(f"Failed to cleanup index {self.index_name}. Error: {e}")
            return False
        finally:
            # Runs regardless of success or failure
            print("Cleanup operation completed.")
# Example usage (an example .env sketch follows below):
if __name__ == "__main__":
    rag_prep = RAGPrep()
    rag_prep.process_and_upload()
    # rag_prep.cleanup_index()
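
# Example .env file content: a minimal sketch only. The variable names below are
# assumptions inferred from the settings attributes this script reads; the actual
# names depend on config.get_settings(), which is not shown here, and every value
# is a placeholder to be replaced with your own credentials and settings.
#
# OPENAI_API_KEY=sk-...
# PINECONE_API_KEY=...
# INDEX_NAME=my-rag-index
# DIMENSIONS=1536            # must match the embedding model's output size
# CLOUD=aws
# REGION=us-east-1
# AWS_ACCESS_KEY=...
# AWS_SECRET_KEY=...
# AWS_REGION=us-east-1
# AWS_BUCKET_NAME=my-pdf-bucket
# CHUNK_SIZE=1000
# CHUNK_OVERLAP=200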