# MedSearchPro / processing / chroma_manager.py
# (repo page header captured during export: paulhemb — "Initial Backend Deployment", commit 1367957)
# processing/chroma_manager.py
"""
ChromaDB vector database implementation
"""
import chromadb
from chromadb.config import Settings
from typing import List, Dict, Any, Optional
import numpy as np
from embeddings.embedding_models import EmbeddingManager
from embeddings.text_chunking import ResearchPaperChunker
class ChromaManager:
    """ChromaDB vector database manager for medical research paper chunks.

    Wraps a persistent ChromaDB collection plus the project's chunking
    helper behind a small ingest/search/delete/stats API. All methods
    report failures by printing and returning a falsy value rather than
    raising, matching the module's existing error-handling style.
    """

    def __init__(self,
                 persist_directory: str = "./data/vector_db/chromadb",
                 embedding_model: str = "all-MiniLM-L6-v2",
                 chunk_strategy: str = "semantic"):
        """Open (or create) the persistent collection.

        Args:
            persist_directory: On-disk location of the ChromaDB store.
            embedding_model: Model name handed to EmbeddingManager.
                NOTE(review): the collection below is created WITHOUT an
                explicit embedding_function, so Chroma's default embedder
                is what actually vectorizes documents and queries —
                confirm the EmbeddingManager instance is meant to be
                unused for embedding here.
            chunk_strategy: Strategy name handed to ResearchPaperChunker.
        """
        self.persist_directory = persist_directory
        self.embedding_manager = EmbeddingManager(embedding_model)
        self.chunker = ResearchPaperChunker(chunk_strategy)
        # Persistent client so the index survives restarts; telemetry off.
        self.client = chromadb.PersistentClient(
            path=persist_directory,
            settings=Settings(anonymized_telemetry=False)
        )
        self.collection = self.client.get_or_create_collection(
            name="medical_research_papers",
            metadata={"description": "Medical research papers vector store"}
        )
        print(f"✅ ChromaDB initialized at: {persist_directory}")

    def add_papers(self, papers: List[Dict[str, Any]], batch_size: int = 100) -> bool:
        """Chunk ``papers`` and upsert the chunks into the collection.

        Fixes vs. the previous version:
        * chunk ids and ``chunk_index`` metadata are now per-paper
          (``<paper_id>_chunk_<k>``) instead of using the batch-global
          enumeration index, so a paper's ids no longer depend on which
          other papers happen to share the batch;
        * ``upsert`` replaces ``add`` so re-ingesting a paper updates its
          chunks in place instead of failing on duplicate ids.

        Args:
            papers: Paper dicts accepted by the chunker's
                ``batch_chunk_papers``.
            batch_size: Maximum chunks per ChromaDB request.

        Returns:
            True on success; False when no chunks were produced or an
            error occurred (the error is printed, not raised).
        """
        try:
            all_chunks = self.chunker.batch_chunk_papers(papers)
            if not all_chunks:
                print("⚠️ No chunks generated from papers")
                return False

            documents: List[str] = []
            metadatas: List[Dict[str, Any]] = []
            ids: List[str] = []
            chunk_counts: Dict[str, int] = {}  # running per-paper chunk index
            for chunk in all_chunks:
                paper_id = chunk['paper_id']
                chunk_idx = chunk_counts.get(paper_id, 0)
                chunk_counts[paper_id] = chunk_idx + 1
                documents.append(chunk['text'])
                metadatas.append({
                    'paper_id': paper_id,
                    'paper_title': chunk['paper_title'],
                    'source': chunk['source'],
                    'domain': chunk['domain'],
                    'publication_date': chunk.get('publication_date', ''),
                    'chunk_index': chunk_idx,
                    'chunk_strategy': chunk.get('chunk_strategy', 'semantic'),
                    'start_char': chunk.get('start_char', 0),
                    'end_char': chunk.get('end_char', 0)
                })
                ids.append(f"{paper_id}_chunk_{chunk_idx}")

            # Upsert in batches to keep individual requests small.
            total_chunks = len(documents)
            for start in range(0, total_chunks, batch_size):
                end = min(start + batch_size, total_chunks)
                self.collection.upsert(
                    documents=documents[start:end],
                    metadatas=metadatas[start:end],
                    ids=ids[start:end]
                )
                print(f"📦 Added batch {start // batch_size + 1}: {start}-{end - 1} chunks")
            print(f"✅ Successfully added {total_chunks} chunks from {len(papers)} papers")
            return True
        except Exception as e:
            print(f"❌ Error adding papers to ChromaDB: {e}")
            return False

    def search(self,
               query: str,
               domain: str = None,
               n_results: int = 10,
               where_filter: Dict = None) -> List[Dict[str, Any]]:
        """Semantic search over stored chunks.

        Args:
            query: Natural-language query text.
            domain: Optional equality filter on the ``domain`` metadata field.
            n_results: Maximum number of chunks to return.
            where_filter: Extra metadata equality filters as ``{field: value}``.

        Returns:
            A list of dicts with keys ``text``/``metadata``/``distance``/``id``;
            empty on error.
        """
        try:
            # Chroma (>= 0.4) rejects a `where` dict with several top-level
            # fields; multiple conditions must be combined with `$and`.
            # (The old code passed a flat merged dict, which errors as soon
            # as both `domain` and a `where_filter` key are supplied.)
            conditions: List[Dict[str, Any]] = []
            if domain:
                conditions.append({'domain': domain})
            if where_filter:
                conditions.extend({key: value} for key, value in where_filter.items())
            if not conditions:
                where = None
            elif len(conditions) == 1:
                where = conditions[0]
            else:
                where = {'$and': conditions}

            results = self.collection.query(
                query_texts=[query],
                n_results=n_results,
                where=where
            )

            # Flatten Chroma's list-of-lists result (one inner list per
            # query text; we issue exactly one query).
            formatted_results = []
            if results['documents']:
                distances = results.get('distances')
                for i in range(len(results['documents'][0])):
                    formatted_results.append({
                        'text': results['documents'][0][i],
                        'metadata': results['metadatas'][0][i],
                        'distance': distances[0][i] if distances else None,
                        'id': results['ids'][0][i]
                    })
            return formatted_results
        except Exception as e:
            print(f"❌ ChromaDB search error: {e}")
            return []

    def get_collection_stats(self) -> Dict[str, Any]:
        """Return a summary dict (chunk count + configuration); {} on error."""
        try:
            count = self.collection.count()
            return {
                "total_chunks": count,
                "persist_directory": self.persist_directory,
                "embedding_model": self.embedding_manager.model_name,
                "chunk_strategy": self.chunker.strategy
            }
        except Exception as e:
            print(f"❌ Error getting collection stats: {e}")
            return {}

    def delete_paper(self, paper_id: str) -> bool:
        """Delete every chunk whose metadata ``paper_id`` matches.

        Returns True if at least one chunk was deleted, False when the
        paper is unknown or an error occurred.
        """
        try:
            # Look up the ids first; `get` with a metadata filter returns
            # everything this paper contributed, regardless of chunk count.
            results = self.collection.get(where={'paper_id': paper_id})
            if results['ids']:
                self.collection.delete(ids=results['ids'])
                print(f"✅ Deleted {len(results['ids'])} chunks for paper {paper_id}")
                return True
            else:
                print(f"⚠️ No chunks found for paper {paper_id}")
                return False
        except Exception as e:
            print(f"❌ Error deleting paper {paper_id}: {e}")
            return False
# Quick test
def test_chroma_manager():
    """Smoke-test ChromaManager: ingest two tiny papers, search, print stats."""
    sample_papers = [
        {
            'id': 'test_001',
            'title': 'AI in Medical Imaging',
            'abstract': 'Deep learning transforms medical image analysis with improved accuracy.',
            'source': 'test',
            'domain': 'medical_imaging',
        },
        {
            'id': 'test_002',
            'title': 'Genomics and Machine Learning',
            'abstract': 'Machine learning methods advance genomic sequence analysis and prediction.',
            'source': 'test',
            'domain': 'genomics',
        },
    ]

    print("🧪 Testing ChromaDB Manager")
    print("=" * 50)
    try:
        manager = ChromaManager(persist_directory="./data/test_chromadb")

        # Guard clause: nothing else is meaningful if ingestion failed.
        if not manager.add_papers(sample_papers):
            print("❌ Failed to add papers")
            return
        print("✅ Papers added successfully")

        # Exercise search and stats on the freshly ingested data.
        hits = manager.search("medical image analysis", n_results=5)
        print(f"🔍 Search results: {len(hits)} chunks found")
        print(f"📊 Collection stats: {manager.get_collection_stats()}")
    except Exception as e:
        print(f"❌ ChromaDB test failed: {e}")


if __name__ == "__main__":
    test_chroma_manager()