Spaces:
Running
Running
| # processing/chroma_manager.py | |
| """ | |
| ChromaDB vector database implementation | |
| """ | |
| import chromadb | |
| from chromadb.config import Settings | |
| from typing import List, Dict, Any, Optional | |
| import numpy as np | |
| from embeddings.embedding_models import EmbeddingManager | |
| from embeddings.text_chunking import ResearchPaperChunker | |
class ChromaManager:
    """ChromaDB vector database manager.

    Owns a persistent ChromaDB client plus a single collection named
    ``medical_research_papers`` and offers helpers to chunk-and-store papers,
    run similarity searches, report basic statistics and delete a paper's
    chunks.

    NOTE(review): ``self.embedding_manager`` is constructed here but is never
    handed to the collection; documents and queries are passed to Chroma as
    raw text, so Chroma's own default embedding function appears to be the
    one actually used -- confirm that this is intended.
    """

    def __init__(self,
                 persist_directory: str = "./data/vector_db/chromadb",
                 embedding_model: str = "all-MiniLM-L6-v2",
                 chunk_strategy: str = "semantic"):
        """Open (or create) the persistent store and its collection.

        Args:
            persist_directory: Directory passed to ``chromadb.PersistentClient``.
            embedding_model: Model name forwarded to ``EmbeddingManager``.
            chunk_strategy: Strategy name forwarded to ``ResearchPaperChunker``.
        """
        self.persist_directory = persist_directory
        self.embedding_manager = EmbeddingManager(embedding_model)
        self.chunker = ResearchPaperChunker(chunk_strategy)

        # Initialize ChromaDB client (telemetry explicitly disabled)
        self.client = chromadb.PersistentClient(
            path=persist_directory,
            settings=Settings(anonymized_telemetry=False)
        )

        # Get or create collection
        self.collection = self.client.get_or_create_collection(
            name="medical_research_papers",
            metadata={"description": "Medical research papers vector store"}
        )
        print(f"β ChromaDB initialized at: {persist_directory}")

    def add_papers(self, papers: List[Dict[str, Any]], batch_size: int = 100) -> bool:
        """Chunk ``papers`` and add the chunks to the collection in batches.

        Chunk IDs have the form ``<paper_id>_chunk_<n>`` where ``n`` counts
        the chunks of that paper only, so a paper gets the same IDs no matter
        which other papers are ingested in the same call.

        Args:
            papers: Paper dicts handed to the chunker.
            batch_size: Maximum number of chunks per ``collection.add`` call;
                must be at least 1.

        Returns:
            True when every batch was added; False when no chunks were
            produced or any error occurred (the error is printed, not raised).
        """
        try:
            # A non-positive step would make range() raise (0) or silently
            # yield no batches at all (negative) while still reporting success.
            if batch_size < 1:
                raise ValueError(
                    f"batch_size must be a positive integer, got {batch_size}"
                )

            # Chunk all papers
            all_chunks = self.chunker.batch_chunk_papers(papers)
            if not all_chunks:
                print("β οΈ No chunks generated from papers")
                return False

            # Prepare data for ChromaDB
            documents = []
            metadatas = []
            ids = []
            next_index: Dict[str, int] = {}  # per-paper running chunk counter
            for chunk in all_chunks:
                paper_id = chunk['paper_id']
                counter_key = str(paper_id)
                chunk_index = next_index.get(counter_key, 0)
                next_index[counter_key] = chunk_index + 1

                documents.append(chunk['text'])
                metadatas.append({
                    'paper_id': paper_id,
                    'paper_title': chunk['paper_title'],
                    'source': chunk['source'],
                    'domain': chunk['domain'],
                    'publication_date': chunk.get('publication_date', ''),
                    'chunk_index': chunk_index,
                    'chunk_strategy': chunk.get('chunk_strategy', 'semantic'),
                    'start_char': chunk.get('start_char', 0),
                    'end_char': chunk.get('end_char', 0)
                })
                ids.append(f"{paper_id}_chunk_{chunk_index}")

            # Add to ChromaDB in batches
            total_chunks = len(documents)
            for i in range(0, total_chunks, batch_size):
                batch_end = min(i + batch_size, total_chunks)
                self.collection.add(
                    documents=documents[i:batch_end],
                    metadatas=metadatas[i:batch_end],
                    ids=ids[i:batch_end]
                )
                print(f"π¦ Added batch {i // batch_size + 1}: {i}-{batch_end - 1} chunks")

            print(f"β Successfully added {total_chunks} chunks from {len(papers)} papers")
            return True
        except Exception as e:
            print(f"β Error adding papers to ChromaDB: {e}")
            return False

    @staticmethod
    def _build_where(domain: Optional[str],
                     where_filter: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        """Combine ``domain`` and ``where_filter`` into one Chroma ``where`` clause.

        Keys from ``where_filter`` override ``domain`` when they collide. A
        single condition is returned unchanged; several conditions are wrapped
        in ``$and`` because a Chroma ``where`` dict may only carry one
        top-level operator. Returns None when there is nothing to filter on.
        """
        merged: Dict[str, Any] = {}
        if domain:
            merged['domain'] = domain
        if where_filter:
            merged.update(where_filter)

        if not merged:
            return None
        if len(merged) == 1:
            return merged
        return {"$and": [{key: value} for key, value in merged.items()]}

    def search(self,
               query: str,
               domain: Optional[str] = None,
               n_results: int = 10,
               where_filter: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
        """Search for paper chunks similar to ``query``.

        Args:
            query: Free-text query passed to the collection as ``query_texts``.
            domain: Optional value the ``domain`` metadata field must equal.
            n_results: Maximum number of chunks to return.
            where_filter: Optional extra metadata conditions; combined with
                ``domain`` (see ``_build_where``).

        Returns:
            A list of dicts with keys ``text``, ``metadata``, ``distance`` and
            ``id``; an empty list when nothing matched or an error occurred
            (the error is printed, not raised).
        """
        try:
            # Perform search
            results = self.collection.query(
                query_texts=[query],
                n_results=n_results,
                where=self._build_where(domain, where_filter)
            )

            # Format results: one query was sent, so read result row 0 only
            formatted_results = []
            if results['documents']:
                for i in range(len(results['documents'][0])):
                    formatted_results.append({
                        'text': results['documents'][0][i],
                        'metadata': results['metadatas'][0][i],
                        'distance': results['distances'][0][i] if results['distances'] else None,
                        'id': results['ids'][0][i]
                    })
            return formatted_results
        except Exception as e:
            print(f"β ChromaDB search error: {e}")
            return []

    def get_collection_stats(self) -> Dict[str, Any]:
        """Return the chunk count plus this manager's configuration.

        Returns:
            A dict with ``total_chunks``, ``persist_directory``,
            ``embedding_model`` and ``chunk_strategy``; an empty dict when an
            error occurred (the error is printed, not raised).
        """
        try:
            count = self.collection.count()
            return {
                "total_chunks": count,
                "persist_directory": self.persist_directory,
                "embedding_model": self.embedding_manager.model_name,
                "chunk_strategy": self.chunker.strategy
            }
        except Exception as e:
            print(f"β Error getting collection stats: {e}")
            return {}

    def delete_paper(self, paper_id: str) -> bool:
        """Delete all chunks whose ``paper_id`` metadata equals ``paper_id``.

        Returns:
            True when at least one chunk was deleted; False when none were
            found or an error occurred (the error is printed, not raised).
        """
        try:
            # Get all chunks for this paper
            results = self.collection.get(where={'paper_id': paper_id})
            if results['ids']:
                self.collection.delete(ids=results['ids'])
                print(f"β Deleted {len(results['ids'])} chunks for paper {paper_id}")
                return True
            else:
                print(f"β οΈ No chunks found for paper {paper_id}")
                return False
        except Exception as e:
            print(f"β Error deleting paper {paper_id}: {e}")
            return False
| # Quick test | |
def test_chroma_manager():
    """Smoke-test ChromaManager against a throwaway store.

    Adds two sample papers to ``./data/test_chromadb``, runs one search and
    prints the collection statistics. Every outcome is reported with
    ``print``; no exception escapes this function.
    """
    sample_papers = [
        dict(
            id='test_001',
            title='AI in Medical Imaging',
            abstract='Deep learning transforms medical image analysis with improved accuracy.',
            source='test',
            domain='medical_imaging',
        ),
        dict(
            id='test_002',
            title='Genomics and Machine Learning',
            abstract='Machine learning methods advance genomic sequence analysis and prediction.',
            source='test',
            domain='genomics',
        ),
    ]

    print("π§ͺ Testing ChromaDB Manager")
    print("=" * 50)

    try:
        chroma = ChromaManager(persist_directory="./data/test_chromadb")

        # Nothing else is worth checking when ingestion failed.
        if not chroma.add_papers(sample_papers):
            print("β Failed to add papers")
            return
        print("β Papers added successfully")

        # Exercise the search path
        hits = chroma.search("medical image analysis", n_results=5)
        print(f"π Search results: {len(hits)} chunks found")

        # Report what ended up in the collection
        summary = chroma.get_collection_stats()
        print(f"π Collection stats: {summary}")
    except Exception as err:
        print(f"β ChromaDB test failed: {err}")
# Run the smoke test only when this file is executed as a script.
if __name__ == "__main__":
    test_chroma_manager()