# Spaces: Runtime error
"""
Knowledge Store implementation for the Pharmaceutical R&D Knowledge Ecosystem.

Combines TinyDB for structured data (documents, studies, compounds, protocol
objectives, endpoints, population criteria, study arms, assessments and
analytes) with a ChromaDB vector store of text embeddings for semantic search.
"""

import json
import os
import re
from typing import Dict, List, Any, Optional, Union

from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from tinydb import TinyDB, Query
from tinydb.middlewares import CachingMiddleware
from tinydb.storages import JSONStorage


class KnowledgeStore:
    """
    Knowledge store combining a structured database (TinyDB) and a vector store (ChromaDB).

    Structured entities live in a single TinyDB JSON file with one table per
    entity type. Free text is embedded with a sentence-transformers model and
    kept in a persistent Chroma collection for similarity search.

    TinyDB is opened through ``CachingMiddleware``, which buffers writes in
    memory instead of writing the JSON file immediately. Every ``store_*``
    method therefore calls :meth:`flush` after writing. Call :meth:`close`
    (or use the store as a context manager) when finished.
    """

    def __init__(self, data_dir="./data"):
        """
        Open, or create, the knowledge stores under ``data_dir``.

        Args:
            data_dir: Root directory. ``nosql_db/`` (TinyDB JSON file) and
                ``vector_db/`` (Chroma persistence directory) are created
                inside it if missing.
        """
        os.makedirs(os.path.join(data_dir, "nosql_db"), exist_ok=True)
        os.makedirs(os.path.join(data_dir, "vector_db"), exist_ok=True)

        # TinyDB with an in-memory write cache; see flush()/close().
        self.db_path = os.path.join(data_dir, "nosql_db", "protocol_knowledge.json")
        self.db = TinyDB(self.db_path, storage=CachingMiddleware(JSONStorage))

        # One table per entity type.
        self.documents_table = self.db.table('documents')
        self.studies_table = self.db.table('studies')
        self.compounds_table = self.db.table('compounds')
        self.objectives_table = self.db.table('objectives')
        self.endpoints_table = self.db.table('endpoints')
        self.population_table = self.db.table('population_criteria')
        self.arms_table = self.db.table('study_arms')
        self.assessments_table = self.db.table('assessments')
        self.analytes_table = self.db.table('analytes')

        # Query builder used by every lookup below.
        self.Query = Query()

        # Sentence-transformers embeddings for the vector store.
        self.embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )

        # A single constructor call: retrying with identical arguments would
        # not recover from a failure, so let the real error propagate.
        self.vector_db_path = os.path.join(data_dir, "vector_db")
        self.vector_db = Chroma(
            persist_directory=self.vector_db_path,
            embedding_function=self.embeddings,
        )
        print(f"Vector store ready at {self.vector_db_path}")

    # =========================================================================
    # Lifecycle
    # =========================================================================
    def flush(self) -> None:
        """Write TinyDB changes buffered by the storage middleware to disk."""
        flush = getattr(self.db.storage, "flush", None)
        if callable(flush):
            flush()

    def close(self) -> None:
        """Close TinyDB; the caching middleware flushes pending writes on close."""
        self.db.close()

    def __enter__(self) -> "KnowledgeStore":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        self.close()

    # =========================================================================
    # Structured Knowledge Store Methods (TinyDB)
    # =========================================================================
    def _upsert(self, table, condition, record: Dict) -> int:
        """
        Update the first record in ``table`` matching ``condition`` with the
        fields of ``record``, or insert ``record`` if nothing matches or
        ``condition`` is None. Flushes, then returns the record's doc ID.
        """
        existing = table.get(condition) if condition is not None else None
        # `is not None`: a matched record with no fields is falsy but still a match.
        if existing is not None:
            table.update(record, doc_ids=[existing.doc_id])
            doc_id = existing.doc_id
        else:
            doc_id = table.insert(record)
        self.flush()
        return doc_id

    def _replace_for_protocol(self, table, protocol_id: str, items: List[Dict]) -> List[int]:
        """
        Replace every record in ``table`` linked to ``protocol_id`` with ``items``.

        Each item is copied and stamped with ``protocol_id`` (the caller's
        dicts are left untouched). Flushes, then returns the new doc IDs.
        """
        table.remove(self.Query.protocol_id == protocol_id)
        records = [{**item, 'protocol_id': protocol_id} for item in items]
        doc_ids = table.insert_multiple(records)
        self.flush()
        return doc_ids

    def store_document_metadata(self, metadata: Dict) -> int:
        """
        Insert or update document metadata and return its TinyDB doc ID.

        An existing record is matched by the metadata's ``id`` (or
        ``document_id``) against both stored fields. If neither is present,
        it is matched by ``protocol_id``. Otherwise a new record is inserted.
        """
        doc_id = metadata.get('id') or metadata.get('document_id')
        protocol_id = metadata.get('protocol_id')
        condition = None
        if doc_id:
            # The identifier may have been stored under either key.
            condition = (self.Query.document_id == doc_id) | (self.Query.id == doc_id)
        elif protocol_id:
            condition = self.Query.protocol_id == protocol_id
        return self._upsert(self.documents_table, condition, metadata)

    def store_study_info(self, study_info: Dict) -> int:
        """Insert or update a study keyed by ``protocol_id``; return its doc ID."""
        protocol_id = study_info.get('protocol_id')
        # Without a key there is nothing to match on, so insert a new record
        # instead of merging into whichever record has protocol_id None.
        condition = None if protocol_id is None else (self.Query.protocol_id == protocol_id)
        return self._upsert(self.studies_table, condition, study_info)

    def store_compound_info(self, compound_info: Dict) -> int:
        """Insert or update a compound keyed by ``compound_id``; return its doc ID."""
        compound_id = compound_info.get('compound_id')
        condition = None if compound_id is None else (self.Query.compound_id == compound_id)
        return self._upsert(self.compounds_table, condition, compound_info)

    def store_objectives(self, protocol_id: str, objectives: List[Dict]) -> List[int]:
        """Replace the objectives stored for a protocol; return the new doc IDs."""
        return self._replace_for_protocol(self.objectives_table, protocol_id, objectives)

    def store_endpoints(self, protocol_id: str, endpoints: List[Dict]) -> List[int]:
        """Replace the endpoints stored for a protocol; return the new doc IDs."""
        return self._replace_for_protocol(self.endpoints_table, protocol_id, endpoints)

    def store_population_criteria(self, protocol_id: str, criteria: List[Dict]) -> List[int]:
        """Replace the inclusion/exclusion criteria for a protocol; return the new doc IDs."""
        return self._replace_for_protocol(self.population_table, protocol_id, criteria)

    def store_study_arms(self, protocol_id: str, arms: List[Dict]) -> List[int]:
        """Replace the study arms/cohorts for a protocol; return the new doc IDs."""
        return self._replace_for_protocol(self.arms_table, protocol_id, arms)

    def store_assessments(self, protocol_id: str, assessments: List[Dict]) -> List[int]:
        """Replace the assessments/procedures for a protocol; return the new doc IDs."""
        return self._replace_for_protocol(self.assessments_table, protocol_id, assessments)

    # =========================================================================
    # Query Methods for Structured Knowledge
    # =========================================================================
    def get_study_by_protocol_id(self, protocol_id: str) -> Optional[Dict]:
        """Return the study with this protocol ID, or None."""
        return self.studies_table.get(self.Query.protocol_id == protocol_id)

    def get_all_studies(self) -> List[Dict]:
        """Return all studies."""
        return self.studies_table.all()

    def get_objectives_by_protocol_id(self, protocol_id: str) -> List[Dict]:
        """Return all objectives for a protocol."""
        return self.objectives_table.search(self.Query.protocol_id == protocol_id)

    def get_endpoints_by_protocol_id(self, protocol_id: str) -> List[Dict]:
        """Return all endpoints for a protocol."""
        return self.endpoints_table.search(self.Query.protocol_id == protocol_id)

    def get_population_criteria_by_protocol_id(self, protocol_id: str, criterion_type: Optional[str] = None) -> List[Dict]:
        """Return population criteria for a protocol, optionally filtered by ``criterion_type`` (e.g. Inclusion/Exclusion)."""
        if criterion_type:
            return self.population_table.search(
                (self.Query.protocol_id == protocol_id) &
                (self.Query.criterion_type == criterion_type)
            )
        return self.population_table.search(self.Query.protocol_id == protocol_id)

    def search_criteria_by_keyword(self, keyword: str) -> List[Dict]:
        """
        Return population criteria whose ``text`` contains ``keyword``.

        Matching is a case-insensitive literal substring search: regex
        metacharacters in ``keyword`` are escaped. Records whose ``text`` is
        missing or not a string never match.
        """
        pattern = re.compile(re.escape(keyword), re.IGNORECASE)
        return self.population_table.search(
            self.Query.text.test(
                lambda value: isinstance(value, str) and pattern.search(value) is not None
            )
        )

    def get_all_documents(self) -> List[Dict]:
        """Return metadata for all stored documents."""
        return self.documents_table.all()

    def get_document_by_id(self, document_id: str) -> Optional[Dict]:
        """Return the document whose ``document_id`` matches, or None."""
        return self.documents_table.get(self.Query.document_id == document_id)

    def get_documents_by_protocol_id(self, protocol_id: str) -> List[Dict]:
        """Return all documents associated with a protocol ID."""
        return self.documents_table.search(self.Query.protocol_id == protocol_id)

    def get_related_documents(self, protocol_id: str) -> List[Dict]:
        """Return documents for a protocol or listing it in ``related_protocols`` (e.g. protocol and its SAP)."""
        return self.documents_table.search(
            (self.Query.protocol_id == protocol_id) |
            (self.Query.related_protocols.any([protocol_id]))
        )

    def get_assessments_by_protocol_id(self, protocol_id: str) -> List[Dict]:
        """Return all assessments for a protocol."""
        return self.assessments_table.search(self.Query.protocol_id == protocol_id)

    def get_protocol_summary(self, protocol_id: str) -> Dict:
        """
        Build a summary of a protocol from its study, objectives, endpoints and criteria.

        Returns an empty dict when no study with this protocol ID is stored.
        """
        study = self.get_study_by_protocol_id(protocol_id)
        if not study:
            return {}

        objectives = self.get_objectives_by_protocol_id(protocol_id)
        endpoints = self.get_endpoints_by_protocol_id(protocol_id)

        def of_type(records: List[Dict], kind: str) -> List[Dict]:
            return [record for record in records if record.get('type') == kind]

        return {
            "protocol_id": protocol_id,
            "title": study.get('title', ''),
            "phase": study.get('phase', ''),
            "design": study.get('design_type', ''),
            "primary_objectives": of_type(objectives, 'Primary'),
            "secondary_objectives": of_type(objectives, 'Secondary'),
            "primary_endpoints": of_type(endpoints, 'Primary'),
            "secondary_endpoints": of_type(endpoints, 'Secondary'),
            "inclusion_criteria": self.get_population_criteria_by_protocol_id(protocol_id, 'Inclusion'),
            "exclusion_criteria": self.get_population_criteria_by_protocol_id(protocol_id, 'Exclusion'),
            "planned_enrollment": study.get('planned_enrollment', ''),
        }

    def find_document_entity_links(self, entity_type: str, protocol_id: Optional[str] = None) -> Dict:
        """
        Map documents to the entities of ``entity_type`` stored for their protocol.

        Useful for traceability analysis.

        Args:
            entity_type: One of "objectives", "endpoints", "population",
                "assessments".
            protocol_id: If given, only documents for this protocol are included.

        Returns:
            ``{document_key: {"document_title", "document_type", "protocol_id",
            "entities"}}``. ``document_key`` is ``document_id``, falling back
            to ``id`` and then the TinyDB doc ID. Returns
            ``{"error": ...}`` for an unknown entity type.
        """
        # Validity is decided by dict membership, not table truthiness
        # (an empty TinyDB table is falsy).
        fetchers = {
            "objectives": self.get_objectives_by_protocol_id,
            "endpoints": self.get_endpoints_by_protocol_id,
            "population": self.get_population_criteria_by_protocol_id,
            "assessments": self.get_assessments_by_protocol_id,
        }
        fetch = fetchers.get(entity_type)
        if fetch is None:
            return {"error": f"Unknown entity type: {entity_type}"}

        documents = (
            self.get_documents_by_protocol_id(protocol_id)
            if protocol_id
            else self.get_all_documents()
        )

        result = {}
        for doc in documents:
            doc_protocol_id = doc.get('protocol_id')
            # Fallbacks keep documents without 'document_id' from all
            # overwriting each other under the key None.
            key = doc.get('document_id') or doc.get('id') or doc.doc_id
            result[key] = {
                "document_title": doc.get('title', ''),
                "document_type": doc.get('type', ''),
                "protocol_id": doc_protocol_id,
                "entities": fetch(doc_protocol_id),
            }
        return result

    # =========================================================================
    # Vector Store Methods
    # =========================================================================
    def add_documents(self, documents: List[Dict]) -> Dict[str, Any]:
        """
        Embed and add documents to the vector store.

        Each document must have 'page_content' and 'metadata' keys.

        Returns:
            ``{"status": "success", "added": n, "ids": [...]}``, or
            ``{"status": "error", "message": ...}`` if the vector store raises.
        """
        # Nothing to embed; don't send an empty batch to Chroma.
        if not documents:
            return {"status": "success", "added": 0, "ids": []}

        texts = [doc['page_content'] for doc in documents]
        metadatas = [doc['metadata'] for doc in documents]
        try:
            ids = self.vector_db.add_texts(texts=texts, metadatas=metadatas)
            self.vector_db.persist()  # Save to disk
        except Exception as e:
            return {"status": "error", "message": str(e)}
        return {"status": "success", "added": len(texts), "ids": ids}

    def similarity_search(self, query: str, k: int = 5, filter_dict: Optional[Dict] = None):
        """
        Return up to ``k`` documents most similar to ``query``.

        ``filter_dict`` is passed to Chroma as a metadata filter. Errors are
        printed and an empty list is returned.
        """
        try:
            return self.vector_db.similarity_search(query=query, k=k, filter=filter_dict)
        except Exception as e:
            print(f"Error in similarity search: {e}")
            return []

    def similarity_search_with_score(self, query: str, k: int = 5, filter_dict: Optional[Dict] = None):
        """
        Like :meth:`similarity_search`, but return ``(document, score)`` pairs.

        Errors are printed and an empty list is returned.
        """
        try:
            return self.vector_db.similarity_search_with_score(query=query, k=k, filter=filter_dict)
        except Exception as e:
            print(f"Error in similarity search with score: {e}")
            return []

    def get_vector_store_stats(self) -> Dict[str, Any]:
        """
        Return the vector-store document count, embedding dimension and model name.

        ``embedding_dimension`` is read from the embeddings' ``client`` when it
        exposes ``get_sentence_embedding_dimension()``, otherwise it is None.
        Any failure returns ``{"error": <message>}``.
        """
        try:
            count = self.vector_db._collection.count()
            client = getattr(self.embeddings, "client", None)
            get_dimension = getattr(client, "get_sentence_embedding_dimension", None)
            dimension = get_dimension() if callable(get_dimension) else None
            return {
                "document_count": count,
                "embedding_dimension": dimension,
                "model": self.embeddings.model_name,
            }
        except Exception as e:
            return {"error": str(e)}