# doc_knowledge_base / knowledge_store.py
# Committed by cryogenic22 — commit 5f8e1d1 ("Create knowledge_store.py", verified)
"""
Knowledge Store implementation for Pharmaceutical R&D Knowledge Ecosystem.
Includes TinyDB for structured data and ChromaDB for vector embeddings.
"""
import json
import os
import re
from typing import Dict, List, Any, Optional, Union

from tinydb import TinyDB, Query
from tinydb.middlewares import CachingMiddleware
from tinydb.storages import JSONStorage
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
class KnowledgeStore:
    """
    Knowledge store combining a structured database (TinyDB) and a vector store (ChromaDB).

    Structured entities (documents, studies, compounds, objectives, endpoints,
    population criteria, study arms, assessments, analytes) live in a single
    TinyDB JSON file, one table per entity type. Free-text chunks live in a
    persistent Chroma collection embedded with a sentence-transformers model.

    TinyDB is opened through ``CachingMiddleware``, which buffers writes in
    memory. Every ``store_*`` method therefore flushes its changes before
    returning. Call ``close()``, or use the store as a context manager, to
    release the database file.
    """

    DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"

    def __init__(self, data_dir="./data", *, embedding_model: str = DEFAULT_EMBEDDING_MODEL):
        """
        Open (or create) the knowledge stores under ``data_dir``.

        Args:
            data_dir: Root directory. TinyDB data goes in ``<data_dir>/nosql_db``
                and the Chroma collection goes in ``<data_dir>/vector_db``.
            embedding_model: Hugging Face model name passed to
                ``HuggingFaceEmbeddings``.
        """
        # Shared TinyDB query builder; kept public because callers may use it.
        self.Query = Query()
        self._init_structured_store(data_dir)
        try:
            self._init_vector_store(data_dir, embedding_model)
        except Exception:
            # Don't leak the open TinyDB file handle if the embedding model
            # or the vector store fails to load.
            self.db.close()
            raise

    def _init_structured_store(self, data_dir: str) -> None:
        """Open the TinyDB file and bind one attribute per entity table."""
        nosql_dir = os.path.join(data_dir, "nosql_db")
        os.makedirs(nosql_dir, exist_ok=True)
        self.db_path = os.path.join(nosql_dir, "protocol_knowledge.json")
        # CachingMiddleware buffers writes in memory; _persist()/close() flush them.
        self.db = TinyDB(self.db_path, storage=CachingMiddleware(JSONStorage))

        self.documents_table = self.db.table('documents')
        self.studies_table = self.db.table('studies')
        self.compounds_table = self.db.table('compounds')
        self.objectives_table = self.db.table('objectives')
        self.endpoints_table = self.db.table('endpoints')
        self.population_table = self.db.table('population_criteria')
        self.arms_table = self.db.table('study_arms')
        self.assessments_table = self.db.table('assessments')
        self.analytes_table = self.db.table('analytes')

    def _init_vector_store(self, data_dir: str, embedding_model: str) -> None:
        """Load the embedding model and open the persistent Chroma collection."""
        self.vector_db_path = os.path.join(data_dir, "vector_db")
        os.makedirs(self.vector_db_path, exist_ok=True)
        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model)

        # One constructor call both creates a new collection and reopens an
        # existing one. The directory listing only decides which message to print.
        had_data = bool(os.listdir(self.vector_db_path))
        self.vector_db = Chroma(
            persist_directory=self.vector_db_path,
            embedding_function=self.embeddings,
        )
        if had_data:
            print(f"Loaded existing vector store from {self.vector_db_path}")
        else:
            print(f"Created new vector store at {self.vector_db_path}")

    # =========================================================================
    # Lifecycle
    # =========================================================================
    def close(self) -> None:
        """Flush buffered TinyDB writes to disk and close the database file."""
        self.db.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    def _persist(self) -> None:
        """Write TinyDB changes buffered by CachingMiddleware through to disk."""
        flush = getattr(self.db.storage, "flush", None)
        if callable(flush):
            flush()

    # =========================================================================
    # Structured Knowledge Store Methods (TinyDB)
    # =========================================================================
    def _upsert(self, table, condition, record: Dict) -> int:
        """
        Update the first row matching ``condition``, or insert ``record``.

        A ``condition`` of None means "no usable key", so the record is always
        inserted. Returns the TinyDB row id of the updated or inserted row.
        """
        existing = table.get(condition) if condition is not None else None
        if existing is not None:
            table.update(record, doc_ids=[existing.doc_id])
            row_id = existing.doc_id
        else:
            row_id = table.insert(record)
        self._persist()
        return row_id

    def _replace_protocol_entities(self, table, protocol_id: str, items: List[Dict]) -> List[int]:
        """
        Replace every row of ``table`` belonging to ``protocol_id`` with ``items``.

        Each item is copied and tagged with ``protocol_id``; the caller's dicts
        are not modified. Returns the new TinyDB row ids in input order.
        """
        table.remove(self.Query.protocol_id == protocol_id)
        row_ids = table.insert_multiple(
            {**item, 'protocol_id': protocol_id} for item in items
        )
        self._persist()
        return row_ids

    def store_document_metadata(self, metadata: Dict) -> int:
        """
        Insert or update document metadata and return its TinyDB row id.

        The document is matched by ``id``/``document_id`` if present, otherwise
        by ``protocol_id``. When an id is given it is also stored as
        ``document_id``, so ``get_document_by_id`` can find the document later.
        """
        doc_id = metadata.get('id') or metadata.get('document_id')
        protocol_id = metadata.get('protocol_id')
        record = dict(metadata)
        condition = None
        if doc_id:
            record.setdefault('document_id', doc_id)
            # Also match rows stored earlier under 'id' only, to avoid duplicates.
            condition = (self.Query.document_id == doc_id) | (self.Query.id == doc_id)
        elif protocol_id:
            condition = self.Query.protocol_id == protocol_id
        return self._upsert(self.documents_table, condition, record)

    def store_study_info(self, study_info: Dict) -> int:
        """Insert or update a study (keyed by ``protocol_id``) and return its row id."""
        protocol_id = study_info.get('protocol_id')
        condition = (self.Query.protocol_id == protocol_id) if protocol_id else None
        return self._upsert(self.studies_table, condition, study_info)

    def store_compound_info(self, compound_info: Dict) -> int:
        """Insert or update a compound (keyed by ``compound_id``) and return its row id."""
        compound_id = compound_info.get('compound_id')
        condition = (self.Query.compound_id == compound_id) if compound_id else None
        return self._upsert(self.compounds_table, condition, compound_info)

    def store_objectives(self, protocol_id: str, objectives: List[Dict]) -> List[int]:
        """Replace the objectives stored for a protocol."""
        return self._replace_protocol_entities(self.objectives_table, protocol_id, objectives)

    def store_endpoints(self, protocol_id: str, endpoints: List[Dict]) -> List[int]:
        """Replace the endpoints stored for a protocol."""
        return self._replace_protocol_entities(self.endpoints_table, protocol_id, endpoints)

    def store_population_criteria(self, protocol_id: str, criteria: List[Dict]) -> List[int]:
        """Replace the inclusion/exclusion criteria stored for a protocol."""
        return self._replace_protocol_entities(self.population_table, protocol_id, criteria)

    def store_study_arms(self, protocol_id: str, arms: List[Dict]) -> List[int]:
        """Replace the study arms/cohorts stored for a protocol."""
        return self._replace_protocol_entities(self.arms_table, protocol_id, arms)

    def store_assessments(self, protocol_id: str, assessments: List[Dict]) -> List[int]:
        """Replace the assessments/procedures stored for a protocol."""
        return self._replace_protocol_entities(self.assessments_table, protocol_id, assessments)

    # =========================================================================
    # Query Methods for Structured Knowledge
    # =========================================================================
    def get_study_by_protocol_id(self, protocol_id: str) -> Optional[Dict]:
        """Retrieve study information by protocol ID."""
        return self.studies_table.get(self.Query.protocol_id == protocol_id)

    def get_all_studies(self) -> List[Dict]:
        """Retrieve all studies."""
        return self.studies_table.all()

    def get_objectives_by_protocol_id(self, protocol_id: str) -> List[Dict]:
        """Retrieve all objectives for a protocol."""
        return self.objectives_table.search(self.Query.protocol_id == protocol_id)

    def get_endpoints_by_protocol_id(self, protocol_id: str) -> List[Dict]:
        """Retrieve all endpoints for a protocol."""
        return self.endpoints_table.search(self.Query.protocol_id == protocol_id)

    def get_population_criteria_by_protocol_id(self, protocol_id: str, criterion_type: Optional[str] = None) -> List[Dict]:
        """Retrieve population criteria for a protocol, optionally filtered by type (Inclusion/Exclusion)."""
        if criterion_type:
            return self.population_table.search(
                (self.Query.protocol_id == protocol_id) &
                (self.Query.criterion_type == criterion_type)
            )
        return self.population_table.search(self.Query.protocol_id == protocol_id)

    def search_criteria_by_keyword(self, keyword: str) -> List[Dict]:
        """
        Return population criteria whose ``text`` contains ``keyword``.

        The match is case-insensitive. The keyword is matched literally, so
        regex metacharacters such as ``(`` or ``%`` are safe to pass.
        """
        # TinyDB forwards ``flags`` to re.search, so it must be an re flag, not a string.
        pattern = re.escape(keyword)
        return self.population_table.search(
            self.Query.text.search(pattern, flags=re.IGNORECASE)
        )

    def get_all_documents(self) -> List[Dict]:
        """Retrieve metadata for all stored documents."""
        return self.documents_table.all()

    def get_document_by_id(self, document_id: str) -> Optional[Dict]:
        """Retrieve document by ID."""
        return self.documents_table.get(self.Query.document_id == document_id)

    def get_documents_by_protocol_id(self, protocol_id: str) -> List[Dict]:
        """Retrieve all documents associated with a protocol ID."""
        return self.documents_table.search(self.Query.protocol_id == protocol_id)

    def get_related_documents(self, protocol_id: str) -> List[Dict]:
        """Find documents whose protocol_id matches or whose related_protocols list contains it."""
        return self.documents_table.search(
            (self.Query.protocol_id == protocol_id) |
            (self.Query.related_protocols.any([protocol_id]))
        )

    def get_assessments_by_protocol_id(self, protocol_id: str) -> List[Dict]:
        """Retrieve all assessments for a protocol."""
        return self.assessments_table.search(self.Query.protocol_id == protocol_id)

    def get_protocol_summary(self, protocol_id: str) -> Dict:
        """
        Build a summary of a protocol from its study, objectives, endpoints and criteria.

        Returns an empty dict when no study is stored for ``protocol_id``.
        """
        study = self.get_study_by_protocol_id(protocol_id)
        if not study:
            return {}

        def of_type(items: List[Dict], kind: str) -> List[Dict]:
            return [item for item in items if item.get('type') == kind]

        objectives = self.get_objectives_by_protocol_id(protocol_id)
        endpoints = self.get_endpoints_by_protocol_id(protocol_id)
        return {
            "protocol_id": protocol_id,
            "title": study.get('title', ''),
            "phase": study.get('phase', ''),
            "design": study.get('design_type', ''),
            "primary_objectives": of_type(objectives, 'Primary'),
            "secondary_objectives": of_type(objectives, 'Secondary'),
            "primary_endpoints": of_type(endpoints, 'Primary'),
            "secondary_endpoints": of_type(endpoints, 'Secondary'),
            "inclusion_criteria": self.get_population_criteria_by_protocol_id(protocol_id, 'Inclusion'),
            "exclusion_criteria": self.get_population_criteria_by_protocol_id(protocol_id, 'Exclusion'),
            "planned_enrollment": study.get('planned_enrollment', '')
        }

    def find_document_entity_links(self, entity_type: str, protocol_id: Optional[str] = None) -> Dict:
        """
        Map each document to the entities of ``entity_type`` stored for its protocol.

        Useful for traceability analysis.

        Args:
            entity_type: One of "objectives", "endpoints", "population", "assessments".
            protocol_id: If given, only documents of this protocol are considered.

        Returns:
            ``{document_id: {document_title, document_type, protocol_id, entities}}``,
            or ``{"error": ...}`` for an unknown ``entity_type``. Documents without
            a ``document_id`` share the ``None`` key.
        """
        # Dispatch on name, not on table truthiness: an empty TinyDB table is
        # falsy (it defines __len__) but is still a valid entity type.
        getters = {
            "objectives": self.get_objectives_by_protocol_id,
            "endpoints": self.get_endpoints_by_protocol_id,
            "population": self.get_population_criteria_by_protocol_id,
            "assessments": self.get_assessments_by_protocol_id,
        }
        get_entities = getters.get(entity_type)
        if get_entities is None:
            return {"error": f"Unknown entity type: {entity_type}"}

        documents = (
            self.get_documents_by_protocol_id(protocol_id)
            if protocol_id
            else self.get_all_documents()
        )
        result = {}
        for doc in documents:
            doc_protocol_id = doc.get('protocol_id')
            result[doc.get('document_id')] = {
                "document_title": doc.get('title', ''),
                "document_type": doc.get('type', ''),
                "protocol_id": doc_protocol_id,
                "entities": get_entities(doc_protocol_id)
            }
        return result

    # =========================================================================
    # Vector Store Methods
    # =========================================================================
    def add_documents(self, documents: List[Dict]) -> Dict:
        """
        Add documents to the vector store.

        Each document must have 'page_content' and 'metadata' fields.

        Returns:
            ``{"status": "success", "added": n, "ids": [...]}`` on success. A
            ``"warning"`` key is added if the texts were stored but the explicit
            persist call failed. Returns ``{"status": "error", "message": ...}``
            if adding the texts failed.
        """
        if not documents:
            return {"status": "success", "added": 0, "ids": []}

        texts = [doc['page_content'] for doc in documents]
        metadatas = [doc['metadata'] for doc in documents]
        try:
            ids = self.vector_db.add_texts(texts=texts, metadatas=metadatas)
        except Exception as e:
            return {"status": "error", "message": str(e)}

        result = {"status": "success", "added": len(texts), "ids": ids}
        # Some Chroma wrapper versions auto-persist and deprecate or omit
        # persist(). The texts are already stored, so a persist failure must not
        # be reported as "error" (callers could retry and create duplicates).
        persist = getattr(self.vector_db, "persist", None)
        if callable(persist):
            try:
                persist()
            except Exception as e:
                result["warning"] = f"persist failed: {e}"
        return result

    def similarity_search(self, query: str, k: int = 5, filter_dict: Optional[Dict] = None) -> List:
        """
        Search for documents similar to the query, optionally filtered by metadata.

        Returns an empty list (and prints the error) if the search fails.
        """
        try:
            return self.vector_db.similarity_search(
                query=query,
                k=k,
                filter=filter_dict
            )
        except Exception as e:
            print(f"Error in similarity search: {e}")
            return []

    def similarity_search_with_score(self, query: str, k: int = 5, filter_dict: Optional[Dict] = None) -> List:
        """
        Search for documents similar to the query, returning (document, score) pairs.

        Returns an empty list (and prints the error) if the search fails.
        """
        try:
            return self.vector_db.similarity_search_with_score(
                query=query,
                k=k,
                filter=filter_dict
            )
        except Exception as e:
            print(f"Error in similarity search with score: {e}")
            return []

    def _embedding_dimension(self) -> Optional[int]:
        """Return the embedding vector length of the configured model."""
        # HuggingFaceEmbeddings has no ``embedding_size`` attribute. Ask the
        # underlying sentence-transformers client when it is exposed; otherwise
        # embed a short probe string and measure the result.
        client = getattr(self.embeddings, "client", None)
        get_dim = getattr(client, "get_sentence_embedding_dimension", None)
        if callable(get_dim):
            return get_dim()
        return len(self.embeddings.embed_query("dimension probe"))

    def get_vector_store_stats(self) -> Dict:
        """Return document count, embedding dimension and model name, or ``{"error": ...}``."""
        try:
            # NOTE: _collection is a private attribute of the Chroma wrapper.
            count = self.vector_db._collection.count()
            return {
                "document_count": count,
                "embedding_dimension": self._embedding_dimension(),
                "model": self.embeddings.model_name
            }
        except Exception as e:
            return {"error": str(e)}