# ============================================================================
# STEP 2: EMBEDDER MODULE
# Generate embeddings using CLIP and store in ChromaDB
# ============================================================================

import os
import json
import hashlib
from typing import List, Dict, Optional

import chromadb
from chromadb import Documents, EmbeddingFunction, Embeddings
from sentence_transformers import SentenceTransformer
import numpy as np

DEFAULT_CLIP_MODEL = "sentence-transformers/clip-ViT-B-32"
DEFAULT_COLLECTION_NAME = "pdf_documents"


class CLIPEmbeddingFunction(EmbeddingFunction):
    """Custom ChromaDB embedding function backed by a CLIP SentenceTransformer.

    Only text is embedded here: each input string yields one embedding,
    returned as a plain list of floats.
    """

    def __init__(self, model_name: str = DEFAULT_CLIP_MODEL):
        """Load the SentenceTransformer model identified by *model_name*."""
        self.model = SentenceTransformer(model_name)

    def __call__(self, input: Documents) -> Embeddings:
        """Return one embedding per input document.

        A bare ``str`` is accepted and treated as a single document.
        """
        texts = [input] if isinstance(input, str) else list(input)
        if not texts:
            # Nothing to encode; avoid handing an empty batch to the model.
            return []
        # NOTE(review): CLIP's text encoder has a short context window
        # (77 tokens for ViT-B/32), so long page texts are probably truncated
        # by the tokenizer -- confirm whether upstream chunking is needed.
        return self.model.encode(texts).tolist()


class ChromaDBManager:
    """Manage ChromaDB vector storage with persistent data.

    Wraps a single cosine-space collection whose documents are embedded with
    :class:`CLIPEmbeddingFunction`.
    """

    # Maximum number of records sent in a single upsert/delete request.
    # ChromaDB limits the size of one write; 5000 is a conservative chunk.
    # NOTE(review): tune against the client's max batch size if needed.
    _WRITE_BATCH_SIZE = 5000

    def __init__(
        self,
        db_dir: str = "./chroma_db",
        collection_name: str = DEFAULT_COLLECTION_NAME,
        model_name: str = DEFAULT_CLIP_MODEL,
    ):
        """Open (or create) the persistent database and its collection.

        Args:
            db_dir: Directory holding the persistent ChromaDB files; created
                if missing.
            collection_name: Name of the collection to get or create.
            model_name: SentenceTransformer model used for embeddings.
        """
        self.db_dir = db_dir
        os.makedirs(db_dir, exist_ok=True)

        # Initialize persistent client
        self.client = chromadb.PersistentClient(path=db_dir)

        # Initialize embedding function with CLIP
        self.embedding_function = CLIPEmbeddingFunction(model_name=model_name)

        # Cosine space: ``search`` turns the returned distance into a
        # similarity score via ``1 - distance``.
        self.collection = self.client.get_or_create_collection(
            name=collection_name,
            embedding_function=self.embedding_function,
            metadata={"hnsw:space": "cosine"},
        )

        print(f"ChromaDB initialized. Database location: {db_dir}")

    @staticmethod
    def _make_doc_id(filename: str, page: str, text: str) -> str:
        """Return a deterministic ID for one (filename, page, text) chunk."""
        # Content-derived IDs do not depend on the chunk's position within a
        # single call. Separate calls for the same file therefore no longer
        # collide, and re-ingesting identical content maps to the same ID.
        digest = hashlib.sha256(
            "\x1f".join((filename, page, text)).encode("utf-8")
        ).hexdigest()[:16]
        return f"doc_{filename or 'unknown'}_{page}_{digest}"

    def add_documents(self, documents: List[Dict]) -> None:
        """Embed and store documents in the collection.

        Each dict may provide ``filename``, ``page``, ``text`` and ``tables``.
        The stored text is the page text followed by the text of every entry
        in ``tables`` (element ``[1]`` of each entry). Documents are
        upserted, so adding identical content again does not duplicate it.

        Args:
            documents: Document dicts to store; an empty list is a no-op.
        """
        if not documents:
            print("No documents to add")
            return

        doc_ids: List[str] = []
        doc_texts: List[str] = []
        doc_metadatas: List[Dict] = []
        seen_ids = set()

        for doc in documents:
            filename = doc.get('filename') or ''
            page = str(doc.get('page', 0))
            # assumes each table entry is a sequence whose [1] is its text
            # -- TODO confirm against the extractor's output format
            table_texts = [str(table[1]) for table in (doc.get('tables') or [])]
            doc_text = " ".join(
                part for part in [doc.get('text') or '', *table_texts] if part
            )

            doc_id = self._make_doc_id(filename, page, doc_text)
            if doc_id in seen_ids:
                # An identical chunk already appears in this batch; sending
                # a duplicate ID in one request would be rejected by Chroma.
                continue
            seen_ids.add(doc_id)

            doc_ids.append(doc_id)
            doc_texts.append(doc_text)
            doc_metadatas.append({
                "filename": filename,
                "page": page,
                "source": "pdf",
            })

        # Write in bounded chunks so very large batches stay within
        # ChromaDB's per-request limit.
        step = self._WRITE_BATCH_SIZE
        for start in range(0, len(doc_ids), step):
            end = start + step
            self.collection.upsert(
                ids=doc_ids[start:end],
                documents=doc_texts[start:end],
                metadatas=doc_metadatas[start:end],
            )

        print(f"Added {len(doc_ids)} documents to ChromaDB")

    def search(self, query: str, n_results: int = 5) -> List[Dict]:
        """Return up to *n_results* stored documents most similar to *query*.

        Each result dict has ``document``, ``distance``, ``metadata`` and
        ``relevance_score`` (``1 - distance``, i.e. cosine similarity for
        this cosine-space collection). Returns ``[]`` when the collection is
        empty or *n_results* is not positive.
        """
        total = self.collection.count()
        if n_results <= 0 or total == 0:
            return []

        results = self.collection.query(
            query_texts=[query],
            # Never ask for more neighbours than the collection holds.
            n_results=min(n_results, total),
        )

        if not results['documents']:
            return []

        return [
            {
                'document': doc,
                'distance': distance,
                'metadata': metadata,
                'relevance_score': 1 - distance,  # Convert distance to similarity score
            }
            for doc, distance, metadata in zip(
                results['documents'][0],
                results['distances'][0],
                results['metadatas'][0],
            )
        ]

    def get_all_documents_count(self) -> int:
        """Get total number of documents in collection."""
        return self.collection.count()

    def clear_collection(self) -> None:
        """Delete every document from the collection, keeping the collection."""
        # ``delete(where={})`` does not work: an empty filter is treated as
        # "no filter", and ChromaDB refuses an unfiltered delete. Delete by
        # explicit IDs instead. ``include=[]`` fetches only the IDs.
        ids = self.collection.get(include=[])['ids']
        step = self._WRITE_BATCH_SIZE
        for start in range(0, len(ids), step):
            self.collection.delete(ids=ids[start:start + step])
        print("Collection cleared")

    def get_collection_info(self) -> Dict:
        """Return the collection's name, document count and metadata."""
        return {
            "name": self.collection.name,
            "document_count": self.collection.count(),
            "metadata": self.collection.metadata,
        }