|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
import json |
|
|
from typing import List, Dict, Optional |
|
|
import chromadb |
|
|
from chromadb import Documents, EmbeddingFunction, Embeddings |
|
|
from sentence_transformers import SentenceTransformer |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
class CLIPEmbeddingFunction(EmbeddingFunction):
    """Embedding function that delegates to a SentenceTransformer model."""

    def __init__(self, model_name: str = "sentence-transformers/clip-ViT-B-32"):
        """Load the SentenceTransformer checkpoint named ``model_name``."""
        self.model = SentenceTransformer(model_name)

    def __call__(self, input: Documents) -> Embeddings:
        """Encode ``input`` and return the result converted with ``tolist()``.

        A bare string is wrapped into a one-element batch; any other input is
        materialised as a list before being handed to the model.
        """
        # The parameter keeps the name ``input`` (shadowing the builtin)
        # because it is part of this callable's visible signature.
        batch = [input] if isinstance(input, str) else list(input)
        return self.model.encode(batch).tolist()
|
|
|
|
|
|
|
|
class ChromaDBManager:
    """Manage ChromaDB vector storage with persistent data.

    The manager owns a ``chromadb.PersistentClient`` rooted at ``db_dir`` and
    a single collection (``COLLECTION_NAME``) whose documents are embedded
    with ``CLIPEmbeddingFunction``.
    """

    # Shared by ``__init__`` and ``clear_collection`` so that a recreated
    # collection is requested with exactly the same settings as the first one.
    COLLECTION_NAME = "pdf_documents"
    COLLECTION_METADATA = {"hnsw:space": "cosine"}
    EMBEDDING_MODEL_NAME = "sentence-transformers/clip-ViT-B-32"

    def __init__(self, db_dir: str = "./chroma_db"):
        """Initialize ChromaDB with persistent storage.

        Args:
            db_dir: Directory handed to ``chromadb.PersistentClient``; it is
                created first if it does not exist.
        """
        self.db_dir = db_dir
        os.makedirs(db_dir, exist_ok=True)

        self.client = chromadb.PersistentClient(path=db_dir)
        self.embedding_function = CLIPEmbeddingFunction(
            model_name=self.EMBEDDING_MODEL_NAME
        )
        self.collection = self._open_collection()

        print(f"ChromaDB initialized. Database location: {db_dir}")

    def _open_collection(self):
        """Return the managed collection, creating it when it is missing."""
        return self.client.get_or_create_collection(
            name=self.COLLECTION_NAME,
            embedding_function=self.embedding_function,
            # Pass a copy so the class-level dict is never shared with the client.
            metadata=dict(self.COLLECTION_METADATA),
        )

    @staticmethod
    def _document_text(doc: Dict) -> str:
        """Build the text stored for ``doc``: its text, a space, then table text.

        ``text`` or ``tables`` being absent or ``None`` is treated as empty.
        """
        body = doc.get('text') or ''
        # Element 1 of each table entry is joined in as text; presumably the
        # entries are (label, text) pairs -- verify against the producer.
        table_text = " ".join(table[1] for table in (doc.get('tables') or []))
        return body + " " + table_text

    def add_documents(self, documents: List[Dict]) -> None:
        """Add documents to ChromaDB.

        Args:
            documents: Dicts read with ``.get`` for the keys ``filename``,
                ``text``, ``tables`` and ``page``. An empty list is a no-op
                apart from a printed notice.

        Each record's id is ``doc_<filename>_<position in this batch>``.
        NOTE(review): because the position restarts at 0 on every call, ids
        repeat across calls that reuse a filename -- confirm how the
        collection treats an id that already exists.
        """
        if not documents:
            print("No documents to add")
            return

        doc_ids = []
        doc_texts = []
        doc_metadatas = []
        for idx, doc in enumerate(documents):
            doc_ids.append(f"doc_{doc.get('filename', 'unknown')}_{idx}")
            doc_texts.append(self._document_text(doc))
            doc_metadatas.append({
                "filename": doc.get('filename', ''),
                # The page number is stored as a string.
                "page": str(doc.get('page', 0)),
                "source": "pdf",
            })

        self.collection.add(
            ids=doc_ids,
            documents=doc_texts,
            metadatas=doc_metadatas,
        )

        print(f"Added {len(documents)} documents to ChromaDB")

    def search(self, query: str, n_results: int = 5) -> List[Dict]:
        """Search for documents similar to query.

        Args:
            query: Text passed to the collection as the single query text.
            n_results: Number of results requested from the collection.

        Returns:
            One dict per hit with the keys ``document``, ``distance``,
            ``metadata`` and ``relevance_score`` (computed as
            ``1 - distance``); an empty list when the query yields no
            documents.
        """
        results = self.collection.query(
            query_texts=[query],
            n_results=n_results,
        )

        if not results['documents']:
            return []

        # Only one query text is sent, so only the first result row is read.
        return [
            {
                'document': doc,
                'distance': distance,
                'metadata': metadata,
                'relevance_score': 1 - distance,
            }
            for doc, distance, metadata in zip(
                results['documents'][0],
                results['distances'][0],
                results['metadatas'][0],
            )
        ]

    def get_all_documents_count(self) -> int:
        """Get total number of documents in collection."""
        return self.collection.count()

    def clear_collection(self) -> None:
        """Clear all documents from collection (for reset).

        The collection is dropped and recreated with the same name, embedding
        function and metadata. The earlier ``delete(where={})`` call could not
        work because Chroma rejects an empty ``where`` filter. After this call
        ``self.collection`` refers to the newly created collection object.
        """
        self.client.delete_collection(name=self.COLLECTION_NAME)
        self.collection = self._open_collection()
        print("Collection cleared")

    def get_collection_info(self) -> Dict:
        """Get information about the collection.

        Returns:
            A dict with the collection's ``name``, its current
            ``document_count`` and its ``metadata``.
        """
        return {
            "name": self.collection.name,
            "document_count": self.collection.count(),
            "metadata": self.collection.metadata,
        }
|
|
|