"""
Knowledge Store implementation for Pharmaceutical R&D Knowledge Ecosystem.
Includes TinyDB for structured data and ChromaDB for vector embeddings.
"""
import os
import re
import json
from typing import Dict, List, Any, Optional, Union

from tinydb import TinyDB, Query
from tinydb.middlewares import CachingMiddleware
from tinydb.storages import JSONStorage
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings


class KnowledgeStore:
    """
    Knowledge store combining structured database (TinyDB) and vector store (ChromaDB).

    The TinyDB database is wrapped in ``CachingMiddleware``, which buffers
    writes in memory until the storage is flushed. Every ``store_*`` method
    therefore flushes once when it finishes, so its changes reach disk even if
    the process exits without cleanup. Call :meth:`close` (or use the store as
    a context manager) when done to release the underlying file.
    """

    def __init__(self, data_dir: str = "./data"):
        """Initialize knowledge stores with the specified data directory.

        Args:
            data_dir: Root directory. ``nosql_db/`` (TinyDB JSON file) and
                ``vector_db/`` (Chroma persistence) are created inside it if
                they do not exist yet.
        """
        # Ensure directories exist
        os.makedirs(os.path.join(data_dir, "nosql_db"), exist_ok=True)
        os.makedirs(os.path.join(data_dir, "vector_db"), exist_ok=True)

        # Initialize TinyDB with caching for better performance
        self.db_path = os.path.join(data_dir, "nosql_db", "protocol_knowledge.json")
        self.db = TinyDB(
            self.db_path,
            storage=CachingMiddleware(JSONStorage)
        )

        # Create tables for different entity types
        self.documents_table = self.db.table('documents')
        self.studies_table = self.db.table('studies')
        self.compounds_table = self.db.table('compounds')
        self.objectives_table = self.db.table('objectives')
        self.endpoints_table = self.db.table('endpoints')
        self.population_table = self.db.table('population_criteria')
        self.arms_table = self.db.table('study_arms')
        self.assessments_table = self.db.table('assessments')
        self.analytes_table = self.db.table('analytes')

        # Initialize vector store with sentence-transformers embedding
        self.embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )

        # Initialize vector store directory
        self.vector_db_path = os.path.join(data_dir, "vector_db")
        # The same Chroma(...) call opens an existing persisted collection or
        # creates a new one, so a single construction covers both cases (the
        # old try/except retried an identical call). Inspect the directory
        # beforehand only to report accurately which case occurred.
        has_existing_store = bool(os.listdir(self.vector_db_path))
        self.vector_db = Chroma(
            persist_directory=self.vector_db_path,
            embedding_function=self.embeddings
        )
        if has_existing_store:
            print(f"Loaded existing vector store from {self.vector_db_path}")
        else:
            print(f"Created new vector store at {self.vector_db_path}")

        # Query constructor
        self.Query = Query()

    # =========================================================================
    # Lifecycle
    # =========================================================================

    def flush(self) -> None:
        """Write any TinyDB changes buffered by ``CachingMiddleware`` to disk."""
        flush = getattr(self.db.storage, 'flush', None)
        if callable(flush):
            flush()

    def close(self) -> None:
        """Flush pending TinyDB writes and close the database file."""
        self.db.close()

    def __enter__(self) -> "KnowledgeStore":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    # =========================================================================
    # Structured Knowledge Store Methods (TinyDB)
    # =========================================================================

    def _upsert(self, table, cond, record: Dict) -> int:
        """Update the first record matching *cond*, else insert *record*.

        ``cond`` may be ``None``, meaning "no usable key": the record is then
        always inserted. Returns the TinyDB doc id of the written record.
        """
        existing = table.get(cond) if cond is not None else None
        # Compare against None: a matched Document is a dict, and an empty
        # one would be falsy.
        if existing is not None:
            table.update(record, doc_ids=[existing.doc_id])
            doc_id = existing.doc_id
        else:
            doc_id = table.insert(record)
        self.flush()
        return doc_id

    def _replace_protocol_entities(self, table, protocol_id: str, items: List[Dict]) -> List[int]:
        """Replace all records of *table* linked to *protocol_id* with *items*.

        Each item is copied and tagged with ``protocol_id`` so the caller's
        dictionaries are not mutated. Returns the new TinyDB doc ids in order.
        """
        table.remove(self.Query.protocol_id == protocol_id)
        doc_ids = table.insert_multiple(
            {**item, 'protocol_id': protocol_id} for item in items
        )
        self.flush()
        return doc_ids

    def store_document_metadata(self, metadata: Dict) -> int:
        """Store basic document metadata and return the document ID.

        An existing record is updated instead of duplicated. It is matched by
        the document identifier (supplied as ``id`` or ``document_id``; either
        stored field matches) or, when no identifier is given, by
        ``protocol_id``.
        """
        doc_key = metadata.get('id') or metadata.get('document_id')
        protocol_id = metadata.get('protocol_id')

        cond = None
        if doc_key:
            cond = (self.Query.document_id == doc_key) | (self.Query['id'] == doc_key)
        elif protocol_id:
            cond = self.Query.protocol_id == protocol_id
        return self._upsert(self.documents_table, cond, metadata)

    def store_study_info(self, study_info: Dict) -> int:
        """Store study information extracted from a protocol (upsert on ``protocol_id``)."""
        protocol_id = study_info.get('protocol_id')
        # Without a protocol_id there is nothing to match on; searching for
        # protocol_id == None would overwrite an unrelated keyless study.
        cond = None if protocol_id is None else (self.Query.protocol_id == protocol_id)
        return self._upsert(self.studies_table, cond, study_info)

    def store_compound_info(self, compound_info: Dict) -> int:
        """Store compound information (upsert on ``compound_id``)."""
        compound_id = compound_info.get('compound_id')
        cond = None if compound_id is None else (self.Query.compound_id == compound_id)
        return self._upsert(self.compounds_table, cond, compound_info)

    def store_objectives(self, protocol_id: str, objectives: List[Dict]) -> List[int]:
        """Store objectives for a protocol, replacing any previously stored ones."""
        return self._replace_protocol_entities(self.objectives_table, protocol_id, objectives)

    def store_endpoints(self, protocol_id: str, endpoints: List[Dict]) -> List[int]:
        """Store endpoints for a protocol, replacing any previously stored ones."""
        return self._replace_protocol_entities(self.endpoints_table, protocol_id, endpoints)

    def store_population_criteria(self, protocol_id: str, criteria: List[Dict]) -> List[int]:
        """Store inclusion/exclusion criteria, replacing any previously stored ones."""
        return self._replace_protocol_entities(self.population_table, protocol_id, criteria)

    def store_study_arms(self, protocol_id: str, arms: List[Dict]) -> List[int]:
        """Store study arms/cohorts, replacing any previously stored ones."""
        return self._replace_protocol_entities(self.arms_table, protocol_id, arms)

    def store_assessments(self, protocol_id: str, assessments: List[Dict]) -> List[int]:
        """Store assessments/procedures, replacing any previously stored ones."""
        return self._replace_protocol_entities(self.assessments_table, protocol_id, assessments)

    # =========================================================================
    # Query Methods for Structured Knowledge
    # =========================================================================

    def get_study_by_protocol_id(self, protocol_id: str) -> Optional[Dict]:
        """Retrieve study information by protocol ID."""
        return self.studies_table.get(self.Query.protocol_id == protocol_id)

    def get_all_studies(self) -> List[Dict]:
        """Retrieve all studies."""
        return self.studies_table.all()

    def get_objectives_by_protocol_id(self, protocol_id: str) -> List[Dict]:
        """Retrieve all objectives for a protocol."""
        return self.objectives_table.search(self.Query.protocol_id == protocol_id)

    def get_endpoints_by_protocol_id(self, protocol_id: str) -> List[Dict]:
        """Retrieve all endpoints for a protocol."""
        return self.endpoints_table.search(self.Query.protocol_id == protocol_id)

    def get_population_criteria_by_protocol_id(self, protocol_id: str,
                                               criterion_type: Optional[str] = None) -> List[Dict]:
        """Retrieve population criteria for a protocol, optionally filtered by type (Inclusion/Exclusion)."""
        if criterion_type:
            return self.population_table.search(
                (self.Query.protocol_id == protocol_id) &
                (self.Query.criterion_type == criterion_type)
            )
        return self.population_table.search(self.Query.protocol_id == protocol_id)

    def search_criteria_by_keyword(self, keyword: str) -> List[Dict]:
        """Search inclusion/exclusion criteria whose ``text`` contains *keyword*.

        Matching is case-insensitive and treats *keyword* literally: regex
        metacharacters are escaped. ``re.search`` requires an int flag, so the
        previous ``flags='i'`` raised TypeError on every call.
        """
        return self.population_table.search(
            self.Query.text.search(re.escape(keyword), flags=re.IGNORECASE)
        )

    def get_all_documents(self) -> List[Dict]:
        """Retrieve metadata for all stored documents."""
        return self.documents_table.all()

    def get_document_by_id(self, document_id: str) -> Optional[Dict]:
        """Retrieve document by ID."""
        return self.documents_table.get(self.Query.document_id == document_id)

    def get_documents_by_protocol_id(self, protocol_id: str) -> List[Dict]:
        """Retrieve all documents associated with a protocol ID."""
        return self.documents_table.search(self.Query.protocol_id == protocol_id)

    def get_related_documents(self, protocol_id: str) -> List[Dict]:
        """Find documents related to a protocol (e.g., protocol and its SAP)."""
        return self.documents_table.search(
            (self.Query.protocol_id == protocol_id) |
            (self.Query.related_protocols.any([protocol_id]))
        )

    def get_assessments_by_protocol_id(self, protocol_id: str) -> List[Dict]:
        """Retrieve all assessments for a protocol."""
        return self.assessments_table.search(self.Query.protocol_id == protocol_id)

    # Example of a more complex query that combines data
    def get_protocol_summary(self, protocol_id: str) -> Dict:
        """Create a comprehensive summary of a protocol.

        Returns an empty dict when no study is stored for *protocol_id*.
        """
        study = self.get_study_by_protocol_id(protocol_id)
        if not study:
            return {}

        objectives = self.get_objectives_by_protocol_id(protocol_id)
        endpoints = self.get_endpoints_by_protocol_id(protocol_id)

        return {
            "protocol_id": protocol_id,
            "title": study.get('title', ''),
            "phase": study.get('phase', ''),
            "design": study.get('design_type', ''),
            "primary_objectives": [obj for obj in objectives if obj.get('type') == 'Primary'],
            "secondary_objectives": [obj for obj in objectives if obj.get('type') == 'Secondary'],
            "primary_endpoints": [ep for ep in endpoints if ep.get('type') == 'Primary'],
            "secondary_endpoints": [ep for ep in endpoints if ep.get('type') == 'Secondary'],
            "inclusion_criteria": self.get_population_criteria_by_protocol_id(protocol_id, 'Inclusion'),
            "exclusion_criteria": self.get_population_criteria_by_protocol_id(protocol_id, 'Exclusion'),
            "planned_enrollment": study.get('planned_enrollment', '')
        }

    def find_document_entity_links(self, entity_type: str, protocol_id: str = None) -> Dict:
        """
        Find links between documents and specific entity types.
        Useful for traceability analysis.

        Args:
            entity_type: One of "objectives", "endpoints", "population",
                "assessments".
            protocol_id: Restrict to documents of this protocol; all
                documents when omitted.

        Returns:
            A mapping from document identifier to its title, type, protocol
            and linked entities, or ``{"error": ...}`` for an unknown
            ``entity_type``. The identifier is ``document_id``, falling back
            to ``id`` and then the TinyDB doc id, so documents never overwrite
            each other under a ``None`` key.
        """
        # Dispatch on the name rather than on table objects: TinyDB tables
        # define __len__, so an empty (but valid) table is falsy and the
        # old `if not entity_table` check misreported it as unknown.
        getters = {
            "objectives": self.get_objectives_by_protocol_id,
            "endpoints": self.get_endpoints_by_protocol_id,
            "population": self.get_population_criteria_by_protocol_id,
            "assessments": self.get_assessments_by_protocol_id,
        }
        get_entities = getters.get(entity_type)
        if get_entities is None:
            return {"error": f"Unknown entity type: {entity_type}"}

        documents = (self.get_documents_by_protocol_id(protocol_id) if protocol_id
                     else self.get_all_documents())

        result = {}
        for doc in documents:
            doc_key = doc.get('document_id') or doc.get('id') or doc.doc_id
            doc_protocol_id = doc.get('protocol_id')
            # A document without a protocol cannot be linked to entities.
            entities = get_entities(doc_protocol_id) if doc_protocol_id is not None else []
            result[doc_key] = {
                "document_title": doc.get('title', ''),
                "document_type": doc.get('type', ''),
                "protocol_id": doc_protocol_id,
                "entities": entities
            }
        return result

    # =========================================================================
    # Vector Store Methods
    # =========================================================================

    def add_documents(self, documents: List[Dict]) -> Dict[str, Any]:
        """
        Add documents to the vector store.

        Each document must have a 'page_content' field; 'metadata' is
        optional (an empty dict is used when absent).

        Returns:
            ``{"status": "success", "added": n, "ids": [...]}`` or
            ``{"status": "error", "message": ...}`` if the vector store raised.
        """
        # Chroma rejects an empty batch, so short-circuit that case.
        if not documents:
            return {"status": "success", "added": 0, "ids": []}

        texts = [doc['page_content'] for doc in documents]
        metadatas = [doc.get('metadata') or {} for doc in documents]

        try:
            ids = self.vector_db.add_texts(texts=texts, metadatas=metadatas)
            # Explicit save; presumably a no-op on chromadb >= 0.4, which
            # persists automatically — kept for older chromadb versions.
            self.vector_db.persist()
        except Exception as e:
            return {"status": "error", "message": str(e)}
        return {"status": "success", "added": len(texts), "ids": ids}

    def similarity_search(self, query: str, k: int = 5, filter_dict: Optional[Dict] = None):
        """
        Search for documents similar to the query.
        Optionally filter by metadata. Returns [] if the search fails.
        """
        try:
            return self.vector_db.similarity_search(
                query=query,
                k=k,
                filter=filter_dict
            )
        except Exception as e:
            print(f"Error in similarity search: {e}")
            return []

    def similarity_search_with_score(self, query: str, k: int = 5, filter_dict: Optional[Dict] = None):
        """
        Search for documents similar to the query, returning relevance scores.
        Returns [] if the search fails.
        """
        try:
            return self.vector_db.similarity_search_with_score(
                query=query,
                k=k,
                filter=filter_dict
            )
        except Exception as e:
            print(f"Error in similarity search with score: {e}")
            return []

    def _embedding_dimension(self) -> int:
        """Return the dimensionality of vectors produced by ``self.embeddings``.

        ``HuggingFaceEmbeddings`` has no ``embedding_size`` attribute. Ask the
        underlying sentence-transformers model if it is available; otherwise
        embed a short probe string and measure the result.
        """
        client = getattr(self.embeddings, 'client', None)
        get_dim = getattr(client, 'get_sentence_embedding_dimension', None)
        if callable(get_dim):
            dim = get_dim()
            if dim:
                return dim
        return len(self.embeddings.embed_query("dimension probe"))

    def get_vector_store_stats(self) -> Dict[str, Any]:
        """Get statistics about the vector store, or ``{"error": ...}`` on failure."""
        try:
            count = self.vector_db._collection.count()
            return {
                "document_count": count,
                "embedding_dimension": self._embedding_dimension(),
                "model": self.embeddings.model_name
            }
        except Exception as e:
            return {"error": str(e)}