# project / src / embedder.py
# dnj0's picture
# Upload 4 files
# 8099442 verified
# ============================================================================
# STEP 2: EMBEDDER MODULE
# Generate embeddings using CLIP and store in ChromaDB
# ============================================================================
import os
import json
from typing import List, Dict, Optional
import chromadb
from chromadb import Documents, EmbeddingFunction, Embeddings
from sentence_transformers import SentenceTransformer
import numpy as np
class CLIPEmbeddingFunction(EmbeddingFunction):
    """ChromaDB embedding function backed by a SentenceTransformer CLIP model."""

    def __init__(self, model_name: str = "sentence-transformers/clip-ViT-B-32"):
        """Load the SentenceTransformer model identified by ``model_name``."""
        self.model = SentenceTransformer(model_name)

    def __call__(self, input: Documents) -> Embeddings:
        """Encode ``input`` and return the embeddings as nested Python lists.

        A bare string is treated as a single document; any other value is
        materialised with ``list()`` before being handed to the model.
        """
        # NOTE: ``input`` shadows the builtin, but the parameter name is kept
        # unchanged because it is part of this callable's interface.
        texts = [input] if isinstance(input, str) else list(input)
        return self.model.encode(texts).tolist()
class ChromaDBManager:
    """Manage a persistent ChromaDB collection of PDF documents.

    The manager opens (or creates) one collection inside a
    ``chromadb.PersistentClient`` rooted at ``db_dir`` and embeds text with
    ``CLIPEmbeddingFunction``.
    """

    def __init__(
        self,
        db_dir: str = "./chroma_db",
        *,
        collection_name: str = "pdf_documents",
        model_name: str = "sentence-transformers/clip-ViT-B-32",
    ):
        """Create the storage directory, client, embedder and collection.

        Args:
            db_dir: Directory used for ChromaDB's persistent storage; it is
                created if it does not exist.
            collection_name: Name of the collection to open or create.
            model_name: SentenceTransformer model passed to
                ``CLIPEmbeddingFunction``.
        """
        self.db_dir = db_dir
        self.collection_name = collection_name
        os.makedirs(db_dir, exist_ok=True)

        # Initialize persistent client
        self.client = chromadb.PersistentClient(path=db_dir)

        # Initialize embedding function with CLIP
        self.embedding_function = CLIPEmbeddingFunction(model_name=model_name)

        # Get or create collection
        self.collection = self._open_collection()
        print(f"ChromaDB initialized. Database location: {db_dir}")

    def _open_collection(self):
        """Return the managed collection, creating it with cosine space if absent."""
        return self.client.get_or_create_collection(
            name=self.collection_name,
            embedding_function=self.embedding_function,
            metadata={"hnsw:space": "cosine"},
        )

    @staticmethod
    def _document_text(doc: Dict) -> str:
        """Join a document's text and its table texts into one string."""
        # ``or`` also covers keys that are present but set to None.
        text = doc.get('text') or ''
        tables = doc.get('tables') or []
        # Assumes each table entry is indexable and element 1 holds the
        # table's text -- TODO confirm against the PDF parser's output.
        return text + " " + " ".join([table[1] for table in tables])

    def add_documents(self, documents: List[Dict]) -> None:
        """Add documents to the collection.

        Each document is a dict; the keys read here are ``filename``,
        ``text``, ``tables`` and ``page``. IDs have the form
        ``doc_<filename>_<index>`` where the index is the position within
        ``documents``, so the same file at the same position in a later call
        produces the same ID.

        Args:
            documents: Documents to store. An empty list is a no-op.
        """
        if not documents:
            print("No documents to add")
            return

        doc_ids = []
        doc_texts = []
        doc_metadatas = []
        for idx, doc in enumerate(documents):
            doc_ids.append(f"doc_{doc.get('filename', 'unknown')}_{idx}")
            doc_texts.append(self._document_text(doc))
            doc_metadatas.append({
                "filename": doc.get('filename', ''),
                "page": str(doc.get('page', 0)),
                "source": "pdf",
            })

        # Add to collection
        self.collection.add(
            ids=doc_ids,
            documents=doc_texts,
            metadatas=doc_metadatas,
        )
        print(f"Added {len(documents)} documents to ChromaDB")

    def search(self, query: str, n_results: int = 5) -> List[Dict]:
        """Return the stored documents most similar to ``query``.

        Args:
            query: Text to search for.
            n_results: Maximum number of hits; capped at the number of
                stored documents.

        Returns:
            One dict per hit with keys ``document``, ``distance``,
            ``metadata`` and ``relevance_score`` (``1 - distance``). The
            list is empty when the collection holds no documents.
        """
        available = self.collection.count()
        if available == 0:
            # Nothing can match; skip the query against an empty index.
            return []

        results = self.collection.query(
            query_texts=[query],
            # Never request more neighbours than the collection contains.
            n_results=min(n_results, available),
        )

        retrieved_docs = []
        if results['documents']:
            # Results are nested per query; index 0 is the only query sent.
            for doc, distance, metadata in zip(
                results['documents'][0],
                results['distances'][0],
                results['metadatas'][0],
            ):
                retrieved_docs.append({
                    'document': doc,
                    'distance': distance,
                    'metadata': metadata,
                    'relevance_score': 1 - distance,  # Convert distance to similarity score
                })
        return retrieved_docs

    def get_all_documents_count(self) -> int:
        """Return the number of documents in the collection."""
        return self.collection.count()

    def clear_collection(self) -> None:
        """Remove every document by dropping and recreating the collection.

        An empty ``where`` filter is not a valid way to delete everything,
        so the collection is deleted by name and opened again with the same
        settings.
        """
        self.client.delete_collection(name=self.collection_name)
        self.collection = self._open_collection()
        print("Collection cleared")

    def get_collection_info(self) -> Dict:
        """Return the collection's name, document count and metadata."""
        return {
            "name": self.collection.name,
            "document_count": self.collection.count(),
            "metadata": self.collection.metadata,
        }