Spaces:
Sleeping
Sleeping
Add save_summary and get_summaries endpoints to FastAPI app; refactor create_chroma_db to handle single document input
"""Build and query a persistent Chroma vector store over a local PDF corpus,
using Google Gemini embeddings."""
import os
from typing import Optional, List

import chromadb
from chromadb.utils import embedding_functions
from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from dotenv import load_dotenv
import google.generativeai as genai

# Pull GOOGLE_API_KEY (and any other settings) from a local .env file.
load_dotenv()

# Project root is one level above this file: corpus/ holds the source
# documents, vectordb/ holds the persisted Chroma database.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
CORPUS_DIR = os.path.join(_PROJECT_ROOT, "corpus")
DB_DIR = os.path.join(_PROJECT_ROOT, "vectordb")

# Create both directories up front so later reads/writes never hit a missing path.
os.makedirs(CORPUS_DIR, exist_ok=True)
os.makedirs(DB_DIR, exist_ok=True)
def load_documents(corpus_dir: str = CORPUS_DIR) -> List:
    """Load every supported document found under *corpus_dir*.

    Args:
        corpus_dir: Directory scanned recursively (defaults to ``CORPUS_DIR``).

    Returns:
        A list of loaded langchain documents (possibly empty).

    Raises:
        FileNotFoundError: If *corpus_dir* does not exist.
    """
    if not os.path.exists(corpus_dir):
        raise FileNotFoundError(f"Corpus directory not found: {corpus_dir}")
    print(f"Loading documents from {corpus_dir}...")
    # One loader per supported extension; only PDFs are enabled for now.
    loaders = {
        # "txt": DirectoryLoader(corpus_dir, glob="**/*.txt", loader_cls=TextLoader),
        "pdf": DirectoryLoader(corpus_dir, glob="**/*.pdf", loader_cls=PyPDFLoader),
        # "docx": DirectoryLoader(corpus_dir, glob="**/*.docx", loader_cls=Docx2txtLoader),
    }
    all_docs = []
    for file_type, loader in loaders.items():
        # Best-effort: a failure for one file type should not abort the rest.
        try:
            loaded = loader.load()
            print(f"Loaded {len(loaded)} {file_type} documents")
            all_docs.extend(loaded)
        except Exception as e:
            print(f"Error loading {file_type} documents: {e}")
    return all_docs
def split_documents(documents, chunk_size=1000, chunk_overlap=200):
    """Chunk *documents* with a recursive character splitter.

    Args:
        documents: Sequence of langchain documents to split.
        chunk_size: Target size of each chunk, in characters.
        chunk_overlap: Characters shared between adjacent chunks for context.

    Returns:
        The list of chunked documents.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
    )
    chunks = splitter.split_documents(documents)
    print(f"Split {len(documents)} documents into {len(chunks)} chunks")
    return chunks
def create_chroma_db_and_document(document, collection_name="corpus_collection", db_dir=DB_DIR):
    """Add a single document to a persistent Chroma collection.

    Args:
        document: A langchain document; its ``page_content`` is embedded and stored.
        collection_name: Chroma collection to add into (created if missing).
        db_dir: Directory backing the persistent Chroma client.

    Returns:
        True if the document was added successfully, False otherwise.
    """
    # Gemini embedding function — must match the one used in query_chroma_db
    # so that stored and query vectors live in the same embedding space.
    gemini_ef = embedding_functions.GoogleGenerativeAiEmbeddingFunction(
        api_key=os.getenv("GOOGLE_API_KEY"),
        model_name="models/embedding-001",
    )
    client = chromadb.PersistentClient(path=db_dir)
    # get_or_create_collection replaces the old get-then-bare-except-create
    # dance; it also fixes a latent bug where an EXISTING collection was
    # fetched without the embedding function attached.
    collection = client.get_or_create_collection(
        name=collection_name,
        embedding_function=gemini_ef,
    )
    # Chroma requires a non-null string id per entry, but langchain documents
    # may carry id=None (or no id attribute at all) — fall back to a UUID.
    doc_id = getattr(document, "id", None)
    if not doc_id:
        import uuid
        doc_id = str(uuid.uuid4())
    try:
        collection.add(
            documents=[document.page_content],
            ids=[doc_id],
        )
        print("Document added to collection successfully.")
        return True
    except Exception as e:
        print(f"Error adding document to collection: {e}")
        return False
def query_chroma_db(query: str, collection_name="corpus_collection", n_results=5, db_dir=DB_DIR):
    """Run a similarity search against the persisted Chroma collection.

    Args:
        query: Natural-language query text.
        collection_name: Name of the collection to search.
        n_results: Maximum number of matches to return.
        db_dir: Directory backing the persistent Chroma client.

    Returns:
        The raw Chroma query result (documents, metadatas, distances, ...).
    """
    # Same embedding model as indexing, so query vectors are comparable.
    embed_fn = embedding_functions.GoogleGenerativeAiEmbeddingFunction(
        api_key=os.getenv("GOOGLE_API_KEY"),
        model_name="models/embedding-001",
    )
    store = chromadb.PersistentClient(path=db_dir)
    collection = store.get_collection(name=collection_name, embedding_function=embed_fn)
    return collection.query(query_texts=[query], n_results=n_results)
def main():
    """Build the vector database from the corpus and run a smoke-test query."""
    print("Starting vector database creation...")
    # Load documents
    documents = load_documents()
    if not documents:
        print("No documents found in corpus directory. Please add documents to proceed.")
        return
    # Split documents
    splits = split_documents(documents)
    # BUG FIX: the old code called create_chroma_db(splits), a function that
    # no longer exists — the ingestion helper was refactored to take a single
    # document at a time, so add each chunk individually.
    added = sum(1 for chunk in splits if create_chroma_db_and_document(chunk))
    print(f"Added {added}/{len(splits)} chunks to the vector database.")
    # Test query
    test_query = "What is this corpus about?"
    print(f"\nTesting query: '{test_query}'")
    results = query_chroma_db(test_query)
    print(f"Found {len(results['documents'][0])} matching documents")
    for i, (doc, metadata) in enumerate(zip(results['documents'][0], results['metadatas'][0])):
        print(f"\nResult {i+1}:")
        print(f"Document: {doc[:150]}...")
        print(f"Source: {metadata.get('source', 'Unknown')}")
    print("\nVector database creation and testing complete!")

if __name__ == "__main__":
    main()