# processing/chroma_manager.py
"""
ChromaDB vector database implementation
"""
import chromadb
from chromadb.config import Settings
from typing import List, Dict, Any, Optional
import numpy as np
from embeddings.embedding_models import EmbeddingManager
from embeddings.text_chunking import ResearchPaperChunker
class ChromaManager:
    """ChromaDB vector database manager.

    Wraps a persistent ChromaDB collection ("medical_research_papers") and
    provides add / search / stats / delete operations over chunked
    research-paper text. Chunking and embedding are delegated to the
    project-level ``ResearchPaperChunker`` and ``EmbeddingManager``.
    """

    def __init__(self,
                 persist_directory: str = "./data/vector_db/chromadb",
                 embedding_model: str = "all-MiniLM-L6-v2",
                 chunk_strategy: str = "semantic"):
        """Initialize the persistent client and get/create the collection.

        Args:
            persist_directory: On-disk location for the ChromaDB store.
            embedding_model: Name passed through to ``EmbeddingManager``.
            chunk_strategy: Name passed through to ``ResearchPaperChunker``.
        """
        self.persist_directory = persist_directory
        self.embedding_manager = EmbeddingManager(embedding_model)
        self.chunker = ResearchPaperChunker(chunk_strategy)

        # Initialize ChromaDB client (telemetry disabled deliberately).
        self.client = chromadb.PersistentClient(
            path=persist_directory,
            settings=Settings(anonymized_telemetry=False)
        )

        # Get or create the single collection this manager operates on.
        self.collection = self.client.get_or_create_collection(
            name="medical_research_papers",
            metadata={"description": "Medical research papers vector store"}
        )
        print(f"✅ ChromaDB initialized at: {persist_directory}")

    def add_papers(self, papers: List[Dict[str, Any]], batch_size: int = 100) -> bool:
        """Chunk papers and add every chunk to the vector store.

        Args:
            papers: Paper dicts understood by the chunker (id/title/abstract/...).
            batch_size: Max chunks per ``collection.add`` call.

        Returns:
            True on success, False when no chunks were produced or on error.
        """
        try:
            # Chunk all papers up front.
            all_chunks = self.chunker.batch_chunk_papers(papers)
            if not all_chunks:
                print("⚠️ No chunks generated from papers")
                return False

            # Prepare parallel lists for ChromaDB.
            documents = []
            metadatas = []
            ids = []
            for i, chunk in enumerate(all_chunks):
                documents.append(chunk['text'])
                metadatas.append({
                    'paper_id': chunk['paper_id'],
                    'paper_title': chunk['paper_title'],
                    'source': chunk['source'],
                    'domain': chunk['domain'],
                    'publication_date': chunk.get('publication_date', ''),
                    # NOTE(review): index is global across ALL papers in this
                    # call, not per-paper; IDs below stay unique regardless.
                    'chunk_index': i,
                    'chunk_strategy': chunk.get('chunk_strategy', 'semantic'),
                    'start_char': chunk.get('start_char', 0),
                    'end_char': chunk.get('end_char', 0)
                })
                ids.append(f"{chunk['paper_id']}_chunk_{i}")

            # Add to ChromaDB in batches to bound per-call payload size.
            total_chunks = len(documents)
            for i in range(0, total_chunks, batch_size):
                batch_end = min(i + batch_size, total_chunks)
                self.collection.add(
                    documents=documents[i:batch_end],
                    metadatas=metadatas[i:batch_end],
                    ids=ids[i:batch_end]
                )
                print(f"📦 Added batch {i // batch_size + 1}: {i}-{batch_end - 1} chunks")

            print(f"✅ Successfully added {total_chunks} chunks from {len(papers)} papers")
            return True
        except Exception as e:
            print(f"❌ Error adding papers to ChromaDB: {e}")
            return False

    def search(self,
               query: str,
               domain: str = None,
               n_results: int = 10,
               where_filter: Dict = None) -> List[Dict[str, Any]]:
        """Search for chunks similar to ``query``.

        Args:
            query: Free-text query; embedded by the collection itself.
            domain: Optional exact-match filter on the 'domain' metadata field.
            n_results: Maximum number of chunks to return.
            where_filter: Extra metadata filters merged over the domain filter.

        Returns:
            List of dicts with 'text', 'metadata', 'distance', 'id';
            empty list on error.
        """
        try:
            # Build the metadata filter; None means "no filter" to ChromaDB.
            filters = {}
            if domain:
                filters['domain'] = domain
            if where_filter:
                filters.update(where_filter)

            results = self.collection.query(
                query_texts=[query],
                n_results=n_results,
                where=filters if filters else None
            )

            # Flatten ChromaDB's per-query nested lists (single query => index 0).
            formatted_results = []
            if results['documents']:
                for i in range(len(results['documents'][0])):
                    formatted_results.append({
                        'text': results['documents'][0][i],
                        'metadata': results['metadatas'][0][i],
                        'distance': results['distances'][0][i] if results['distances'] else None,
                        'id': results['ids'][0][i]
                    })
            return formatted_results
        except Exception as e:
            print(f"❌ ChromaDB search error: {e}")
            return []

    def get_collection_stats(self) -> Dict[str, Any]:
        """Return chunk count plus the manager's configuration; {} on error."""
        try:
            count = self.collection.count()
            return {
                "total_chunks": count,
                "persist_directory": self.persist_directory,
                "embedding_model": self.embedding_manager.model_name,
                "chunk_strategy": self.chunker.strategy
            }
        except Exception as e:
            print(f"❌ Error getting collection stats: {e}")
            return {}

    def delete_paper(self, paper_id: str) -> bool:
        """Delete all chunks belonging to ``paper_id``.

        Returns:
            True if at least one chunk was deleted, False otherwise.
        """
        try:
            # Look up every chunk id stored for this paper, then delete them.
            results = self.collection.get(where={'paper_id': paper_id})
            if results['ids']:
                self.collection.delete(ids=results['ids'])
                print(f"✅ Deleted {len(results['ids'])} chunks for paper {paper_id}")
                return True
            else:
                print(f"⚠️ No chunks found for paper {paper_id}")
                return False
        except Exception as e:
            print(f"❌ Error deleting paper {paper_id}: {e}")
            return False
# Quick test
def test_chroma_manager():
    """Smoke-test ChromaManager: add two papers, search, print stats."""
    test_papers = [
        {
            'id': 'test_001',
            'title': 'AI in Medical Imaging',
            'abstract': 'Deep learning transforms medical image analysis with improved accuracy.',
            'source': 'test',
            'domain': 'medical_imaging'
        },
        {
            'id': 'test_002',
            'title': 'Genomics and Machine Learning',
            'abstract': 'Machine learning methods advance genomic sequence analysis and prediction.',
            'source': 'test',
            'domain': 'genomics'
        }
    ]

    print("🧪 Testing ChromaDB Manager")
    print("=" * 50)
    try:
        # Use a throwaway directory so the real store is never touched.
        manager = ChromaManager(persist_directory="./data/test_chromadb")

        success = manager.add_papers(test_papers)
        if success:
            print("✅ Papers added successfully")

            results = manager.search("medical image analysis", n_results=5)
            print(f"🔍 Search results: {len(results)} chunks found")

            stats = manager.get_collection_stats()
            print(f"📊 Collection stats: {stats}")
        else:
            print("❌ Failed to add papers")
    except Exception as e:
        print(f"❌ ChromaDB test failed: {e}")
if __name__ == "__main__":
    test_chroma_manager()