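"""Upload PDF documents into a persistent ChromaDB collection.

Extracts text with PyPDF2, splits it into overlapping chunks, and embeds the
chunks with OpenAI's text-embedding-ada-002 through Chroma's embedding
function, adding them in small batches with retries.
"""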

import io
import os
import time

import chromadb
import chromadb.utils.embedding_functions as embedding_functions
import PyPDF2


class ChromaUploader:
    def __init__(self, collection_name, db_path, api_key=None):
        # Initialize the persistent Chroma client and remember the collection name
        self.chroma_client = chromadb.PersistentClient(path=db_path)
        self.collection_name = collection_name
        self.collection = None

        # Use the provided API key or fall back to the environment variable
        self.openai_key = api_key or os.getenv("OPENAI_API_KEY")
        if not self.openai_key:
            raise ValueError("OpenAI API key is required")

        self.openai_ef = embedding_functions.OpenAIEmbeddingFunction(
            api_key=self.openai_key,
            model_name="text-embedding-ada-002",
        )
        self._initialize_collection()

    def _initialize_collection(self):
        """
        Fetches the collection, creating it if it doesn't exist yet.
        """
        try:
            self.collection = self.chroma_client.get_collection(
                name=self.collection_name,
                embedding_function=self.openai_ef,
            )
            print(f"Collection '{self.collection_name}' already exists.")
        except Exception:
            # get_collection raises when the collection is missing, so create it
            self.collection = self.chroma_client.create_collection(
                name=self.collection_name,
                embedding_function=self.openai_ef,
            )
            print(f"Created new collection '{self.collection_name}'.")

    def add_documents(self, documents, progress_callback=None):
        """
        Adds documents to the collection with a retry mechanism and progress tracking.
        :param documents: List of document strings to be added
        :param progress_callback: Optional callback function for progress updates
        """
        if not documents:
            print("No data collected from the document to add.")
            return False

        try:
            # Create unique IDs for each document chunk
            timestamp = int(time.time() * 1_000_000)  # microseconds for uniqueness
            ids = [f"doc_{timestamp}_{i}" for i in range(len(documents))]

            # Filter out empty or near-empty documents
            valid_documents = []
            valid_ids = []
            for i, doc in enumerate(documents):
                if doc and len(doc.strip()) > 10:  # only keep docs with some content
                    valid_documents.append(doc.strip())
                    valid_ids.append(ids[i])

            if not valid_documents:
                print("No valid documents to add after filtering.")
                return False

            print(f"Attempting to add {len(valid_documents)} documents to collection...")

            # Add documents to the collection in small batches, with retries
            batch_size = 20  # small batches reduce the chance of connection issues
            total_added = 0
            total_batches = (len(valid_documents) + batch_size - 1) // batch_size

            for i in range(0, len(valid_documents), batch_size):
                batch_docs = valid_documents[i:i + batch_size]
                batch_ids = valid_ids[i:i + batch_size]
                batch_num = i // batch_size + 1

                # Report progress if a callback was provided
                if progress_callback:
                    progress = 0.6 + (0.2 * batch_num / total_batches)  # maps onto the 60-80% range
                    progress_callback(progress, f"Adding batch {batch_num}/{total_batches} to ChromaDB...")

                success = self._add_batch_with_retry(batch_docs, batch_ids, max_retries=3)
                if success:
                    total_added += len(batch_docs)
                    print(f"Successfully added batch {batch_num}, total: {total_added}/{len(valid_documents)}")
                else:
                    # Continue with the next batch instead of failing completely
                    print(f"Failed to add batch {batch_num} after retries")

            if total_added > 0:
                print(f"Successfully added {total_added} out of {len(valid_documents)} documents to collection '{self.collection_name}'.")
                return True
            print("Failed to add any documents to the collection.")
            return False
        except Exception as e:
            print(f"Error in add_documents: {e}")
            return False

    def _add_batch_with_retry(self, batch_docs, batch_ids, max_retries=3):
        """
        Adds a batch of documents, retrying on transient failures.
        """
        for attempt in range(max_retries):
            try:
                print(f"Attempt {attempt + 1}/{max_retries} for batch of {len(batch_docs)} documents...")
                self.collection.add(
                    documents=batch_docs,
                    ids=batch_ids,
                )
                return True
            except Exception as e:
                error_msg = str(e).lower()
                print(f"Attempt {attempt + 1} failed: {e}")

                if "api" in error_msg and "key" in error_msg:
                    # API key issue: retrying won't help
                    print("API key issue detected. Cannot retry.")
                    return False

                if attempt == max_retries - 1:
                    print(f"All {max_retries} attempts failed for this batch.")
                    return False

                if "connection" in error_msg or "timeout" in error_msg or "rate" in error_msg:
                    # Network or rate-limit issue: back off linearly before retrying
                    wait_time = (attempt + 1) * 2
                    print(f"Connection/rate limit issue detected. Waiting {wait_time} seconds before retry...")
                    time.sleep(wait_time)
                else:
                    # Other error: short wait before retrying
                    time.sleep(1)
        return False

    def extract_text_from_pdf_bytes(self, pdf_bytes):
        """
        Extracts text from a PDF supplied as bytes (e.g., a Gradio upload).
        :param pdf_bytes: PDF file contents as bytes
        :return: Tuple of (full extracted text, list of text chunks)
        """
        try:
            # Wrap the bytes in a file-like object for PyPDF2
            pdf_file = io.BytesIO(pdf_bytes)
            pdf_reader = PyPDF2.PdfReader(pdf_file)

            # Accumulate cleaned text page by page, tagging each page
            text = ""
            for page_num, page in enumerate(pdf_reader.pages):
                try:
                    page_text = page.extract_text()
                    cleaned_text = self._clean_extracted_text(page_text)
                    if cleaned_text.strip():  # only add non-empty pages
                        text += f"\n--- Page {page_num + 1} ---\n{cleaned_text}\n"
                except Exception as e:
                    print(f"Error extracting text from page {page_num + 1}: {e}")
                    continue

            if not text.strip():
                return "", []

            # Split the full text into overlapping chunks for embedding
            chunks = self._split_text_into_chunks(text, max_chunk_size=1000, overlap=100)
            return text.strip(), chunks
        except Exception as e:
            print(f"Error extracting text from PDF: {e}")
            return "", []

    def extract_text_from_pdf(self, pdf_path):
        """
        Extracts text from a PDF file on disk via PyPDF2.
        :param pdf_path: Path to the PDF file
        :return: Tuple of (full extracted text, list of text chunks)
        """
        try:
            # Read the file and delegate to the bytes-based extractor
            with open(pdf_path, 'rb') as file:
                pdf_bytes = file.read()
            return self.extract_text_from_pdf_bytes(pdf_bytes)
        except Exception as e:
            print(f"Error extracting text from PDF: {e}")
            return "", []

    def _clean_extracted_text(self, text):
        """
        Cleans up extracted text to improve readability and remove stray whitespace.
        :param text: Raw extracted text
        :return: Cleaned text
        """
        if not text:
            return ""

        # Keep only lines with meaningful content
        lines = []
        for line in text.split('\n'):
            cleaned_line = line.strip()
            if len(cleaned_line) > 2:  # filter out very short lines
                lines.append(cleaned_line)

        # Join lines with single spaces, then collapse runs of spaces
        cleaned_text = ' '.join(lines)
        while '  ' in cleaned_text:
            cleaned_text = cleaned_text.replace('  ', ' ')
        return cleaned_text

    def _split_text_into_chunks(self, text, max_chunk_size=1000, overlap=100):
        """
        Splits text into overlapping chunks to preserve context across boundaries.
        :param text: Text to split
        :param max_chunk_size: Maximum size of each chunk, in characters
        :param overlap: Number of characters to overlap between chunks
        :return: List of text chunks
        """
        if not text:
            return []

        chunks = []
        start = 0
        while start < len(text):
            end = start + max_chunk_size

            # If we're not at the end of the text, try to break at a sentence boundary
            if end < len(text):
                # Look for sentence endings within the last 200 characters of the window
                search_start = max(end - 200, start)
                sentence_endings = ['. ', '! ', '? ', '\n\n']
                best_end = end
                for ending in sentence_endings:
                    pos = text.rfind(ending, search_start, end)
                    if pos > start:
                        best_end = pos + len(ending)
                        break
                end = best_end

            # Keep only substantial chunks
            chunk = text[start:end].strip()
            if len(chunk) > 50:
                chunks.append(chunk)

            # Advance the window so it overlaps the previous chunk; max() guarantees
            # forward progress even when end - overlap <= start, preventing infinite loops
            start = max(start + 1, end - overlap)
        return chunks
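
    # Example: with the defaults (max_chunk_size=1000, overlap=100), a 2,500-character
    # text yields roughly three chunks, each repeating up to the last ~100 characters
    # of the previous one so sentences spanning a boundary stay intact.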

    def get_collection_count(self):
        """
        Returns the number of documents in the collection.
        """
        try:
            return self.collection.count()
        except Exception as e:
            print(f"Error getting collection count: {e}")
            return 0
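

# A minimal usage sketch, assuming a local PDF at "docs/report.pdf" and an
# OPENAI_API_KEY in the environment; the path and collection name here are
# placeholders, not part of the class above.
if __name__ == "__main__":
    uploader = ChromaUploader(collection_name="pdf_docs", db_path="./chroma_db")
    full_text, chunks = uploader.extract_text_from_pdf("docs/report.pdf")
    if uploader.add_documents(chunks, progress_callback=lambda p, msg: print(f"{p:.0%} {msg}")):
        print(f"Collection now holds {uploader.get_collection_count()} documents.")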