# File size: 4,429 Bytes
# 8099442
# ============================================================================
# STEP 2: EMBEDDER MODULE
# Generate embeddings using CLIP and store in ChromaDB
# ============================================================================
import os
import json
from typing import List, Dict, Optional
import chromadb
from chromadb import Documents, EmbeddingFunction, Embeddings
from sentence_transformers import SentenceTransformer
import numpy as np
class CLIPEmbeddingFunction(EmbeddingFunction):
    """Chroma embedding function that encodes text with a SentenceTransformer CLIP model."""

    def __init__(self, model_name: str = "sentence-transformers/clip-ViT-B-32"):
        """Load the SentenceTransformer model identified by *model_name*."""
        self.model = SentenceTransformer(model_name)

    # Chroma's EmbeddingFunction interface names this parameter `input`,
    # so the builtin-shadowing name is kept on purpose.
    def __call__(self, input: Documents) -> Embeddings:
        """Encode *input* and return the embeddings converted via ``.tolist()``.

        A bare string is treated as a one-element batch; any other iterable
        is materialized into a list before encoding.
        """
        batch = [input] if isinstance(input, str) else list(input)
        return self.model.encode(batch).tolist()
class ChromaDBManager:
    """Manage a persistent ChromaDB collection of PDF text chunks.

    Documents are embedded with :class:`CLIPEmbeddingFunction` and stored in a
    collection created with cosine distance (``hnsw:space = cosine``).
    """

    def __init__(
        self,
        db_dir: str = "./chroma_db",
        *,
        collection_name: str = "pdf_documents",
        model_name: str = "sentence-transformers/clip-ViT-B-32",
    ):
        """Open (creating if needed) the on-disk database and the collection.

        Args:
            db_dir: Directory for Chroma's persistent storage; created if missing.
            collection_name: Name of the collection to open or create.
            model_name: SentenceTransformer model used to embed documents/queries.
        """
        self.db_dir = db_dir
        self.collection_name = collection_name
        os.makedirs(db_dir, exist_ok=True)
        self.client = chromadb.PersistentClient(path=db_dir)
        self.embedding_function = CLIPEmbeddingFunction(model_name=model_name)
        self.collection = self._get_or_create_collection()
        print(f"ChromaDB initialized. Database location: {db_dir}")

    def _get_or_create_collection(self):
        """Open or create the managed collection with this manager's settings."""
        # NOTE: the metadata only takes effect when the collection is created;
        # an existing collection keeps whatever space it was created with.
        return self.client.get_or_create_collection(
            name=self.collection_name,
            embedding_function=self.embedding_function,
            metadata={"hnsw:space": "cosine"},
        )

    def add_documents(self, documents: List[Dict]) -> None:
        """Embed and store *documents* in the collection.

        Each document dict may carry ``filename``, ``page``, ``text`` and
        ``tables``. The stored text is ``text`` followed by element ``[1]`` of
        every entry in ``tables`` (presumably the table's textual form —
        confirm against the extractor). Missing or ``None`` ``text``/``tables``
        are treated as empty.
        """
        if not documents:
            print("No documents to add")
            return

        doc_ids = []
        doc_texts = []
        doc_metadatas = []
        for idx, doc in enumerate(documents):
            # NOTE(review): IDs depend on the position within this call, so two
            # separate calls can produce the same ID for different chunks of one
            # file; Chroma does not store the second one. Kept unchanged for
            # compatibility with IDs already persisted.
            doc_ids.append(f"doc_{doc.get('filename', 'unknown')}_{idx}")
            table_text = " ".join(table[1] for table in (doc.get('tables') or []))
            doc_texts.append((doc.get('text') or '') + " " + table_text)
            doc_metadatas.append({
                "filename": doc.get('filename', ''),
                "page": str(doc.get('page', 0)),
                "source": "pdf",
            })

        self.collection.add(
            ids=doc_ids,
            documents=doc_texts,
            metadatas=doc_metadatas,
        )
        print(f"Added {len(documents)} documents to ChromaDB")

    def search(self, query: str, n_results: int = 5) -> List[Dict]:
        """Return up to *n_results* stored chunks most similar to *query*.

        ``n_results`` is clamped to the number of stored documents; an empty
        collection or a non-positive ``n_results`` returns ``[]`` without
        querying.

        Returns:
            One dict per hit, in Chroma's order, with keys ``document``,
            ``distance``, ``metadata`` and ``relevance_score`` (``1 - distance``,
            which is the cosine similarity when the collection uses cosine space).
        """
        # Requesting more neighbours than are stored raises or warns depending
        # on the Chroma version, so never ask for more than exist.
        k = min(n_results, self.collection.count())
        if k <= 0:
            return []

        results = self.collection.query(query_texts=[query], n_results=k)
        if not results['documents']:
            return []

        return [
            {
                'document': doc,
                'distance': distance,
                'metadata': metadata,
                'relevance_score': 1 - distance,
            }
            for doc, distance, metadata in zip(
                results['documents'][0],
                results['distances'][0],
                results['metadatas'][0],
            )
        ]

    def get_all_documents_count(self) -> int:
        """Return the number of items stored in the collection."""
        return self.collection.count()

    def clear_collection(self) -> None:
        """Remove every document by dropping and recreating the collection."""
        # `delete(where={})` is not a reliable "delete all": depending on the
        # Chroma version an empty filter is rejected or treated as no filter,
        # and unfiltered deletes are refused. Recreating the collection is the
        # supported way to wipe it.
        self.client.delete_collection(name=self.collection_name)
        self.collection = self._get_or_create_collection()
        print("Collection cleared")

    def get_collection_info(self) -> Dict:
        """Return the collection's name, item count and metadata."""
        return {
            "name": self.collection.name,
            "document_count": self.collection.count(),
            "metadata": self.collection.metadata,
        }
|