Spaces:
Runtime error
Runtime error
File size: 15,521 Bytes
5f8e1d1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 |
"""
Knowledge Store implementation for Pharmaceutical R&D Knowledge Ecosystem.
Includes TinyDB for structured data and ChromaDB for vector embeddings.
"""
import json
import os
import re
from typing import Dict, List, Any, Optional, Union

from tinydb import TinyDB, Query
from tinydb.middlewares import CachingMiddleware
from tinydb.storages import JSONStorage
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
class KnowledgeStore:
    """
    Knowledge store combining a structured database (TinyDB) and a vector store (ChromaDB).

    Structured entities are kept in a single TinyDB JSON file, one table per
    entity type. Text chunks are kept in a persistent Chroma collection,
    embedded with a HuggingFace sentence-transformers model.

    TinyDB is opened through ``CachingMiddleware``, which buffers writes in
    memory. Call :meth:`flush` or :meth:`close` (or use the store as a
    context manager) so structured data actually reaches disk.
    """

    def __init__(
        self,
        data_dir="./data",
        embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2",
    ):
        """
        Initialize both stores under ``data_dir``.

        Args:
            data_dir: Root directory. ``nosql_db/`` (TinyDB JSON file) and
                ``vector_db/`` (Chroma persistence directory) are created
                beneath it if missing.
            embedding_model: HuggingFace model name passed to
                ``HuggingFaceEmbeddings``.
        """
        nosql_dir = os.path.join(data_dir, "nosql_db")
        self.vector_db_path = os.path.join(data_dir, "vector_db")
        os.makedirs(nosql_dir, exist_ok=True)
        os.makedirs(self.vector_db_path, exist_ok=True)

        # CachingMiddleware keeps writes in memory until flush()/close().
        self.db_path = os.path.join(nosql_dir, "protocol_knowledge.json")
        self.db = TinyDB(self.db_path, storage=CachingMiddleware(JSONStorage))

        # One table per entity type.
        self.documents_table = self.db.table('documents')
        self.studies_table = self.db.table('studies')
        self.compounds_table = self.db.table('compounds')
        self.objectives_table = self.db.table('objectives')
        self.endpoints_table = self.db.table('endpoints')
        self.population_table = self.db.table('population_criteria')
        self.arms_table = self.db.table('study_arms')
        self.assessments_table = self.db.table('assessments')
        self.analytes_table = self.db.table('analytes')

        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model)

        # The same Chroma(...) call is used for a fresh or an already-populated
        # directory; whether the directory has content only decides the message.
        has_existing_data = bool(os.listdir(self.vector_db_path))
        self.vector_db = Chroma(
            persist_directory=self.vector_db_path,
            embedding_function=self.embeddings,
        )
        if has_existing_data:
            print(f"Loaded existing vector store from {self.vector_db_path}")
        else:
            print(f"Creating new vector store at {self.vector_db_path}")

        # Shared TinyDB query builder used by every lookup below.
        self.Query = Query()

    # =========================================================================
    # Persistence / lifecycle
    # =========================================================================
    def flush(self) -> None:
        """Write TinyDB changes buffered by CachingMiddleware to disk."""
        self.db.storage.flush()

    def close(self) -> None:
        """Close the TinyDB database; CachingMiddleware flushes pending writes on close."""
        self.db.close()

    def __enter__(self) -> "KnowledgeStore":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Always close (and therefore flush); exceptions are not suppressed.
        self.close()

    # =========================================================================
    # Structured Knowledge Store Methods (TinyDB)
    # =========================================================================
    def _upsert_by_key(self, table, key: str, record: Dict) -> int:
        """
        Update the first record in ``table`` whose ``key`` equals ``record[key]``,
        or insert ``record`` if none matches. Return the TinyDB doc ID.

        A record without ``key`` (or with ``key`` set to None) is always
        inserted. This way it cannot overwrite an unrelated record that
        happens to store None for that key.
        """
        value = record.get(key)
        if value is not None:
            existing = table.get(self.Query[key] == value)
            if existing is not None:
                table.update(record, doc_ids=[existing.doc_id])
                return existing.doc_id
        return table.insert(record)

    def _replace_protocol_records(self, table, protocol_id: str, records: List[Dict]) -> List[int]:
        """
        Replace every record in ``table`` linked to ``protocol_id`` with ``records``.

        Each stored record is a copy of the input with ``protocol_id`` added, so
        the caller's dicts are left unmodified. Returns the new TinyDB doc IDs.
        """
        table.remove(self.Query.protocol_id == protocol_id)
        return table.insert_multiple(
            {**record, 'protocol_id': protocol_id} for record in records
        )

    def store_document_metadata(self, metadata: Dict) -> int:
        """
        Insert or update a document's metadata and return its TinyDB doc ID.

        An existing record is located as follows:
          1. If ``metadata['id']`` or ``metadata['document_id']`` is present,
             it is matched against either stored field.
          2. Otherwise, if ``metadata['protocol_id']`` is present, the first
             stored document with that protocol ID is updated. Note that
             several documents (e.g. a protocol and its SAP) may share one.
        """
        doc_id = metadata.get('id') or metadata.get('document_id')
        protocol_id = metadata.get('protocol_id')

        existing = None
        if doc_id:
            existing = self.documents_table.get(
                (self.Query.document_id == doc_id) | (self.Query['id'] == doc_id)
            )
        elif protocol_id:
            existing = self.documents_table.get(self.Query.protocol_id == protocol_id)

        if existing is not None:
            self.documents_table.update(metadata, doc_ids=[existing.doc_id])
            return existing.doc_id
        return self.documents_table.insert(metadata)

    def store_study_info(self, study_info: Dict) -> int:
        """Insert or update study information, keyed by ``protocol_id``."""
        return self._upsert_by_key(self.studies_table, 'protocol_id', study_info)

    def store_compound_info(self, compound_info: Dict) -> int:
        """Insert or update compound information, keyed by ``compound_id``."""
        return self._upsert_by_key(self.compounds_table, 'compound_id', compound_info)

    def store_objectives(self, protocol_id: str, objectives: List[Dict]) -> List[int]:
        """Replace the stored objectives of a protocol; return the new doc IDs."""
        return self._replace_protocol_records(self.objectives_table, protocol_id, objectives)

    def store_endpoints(self, protocol_id: str, endpoints: List[Dict]) -> List[int]:
        """Replace the stored endpoints of a protocol; return the new doc IDs."""
        return self._replace_protocol_records(self.endpoints_table, protocol_id, endpoints)

    def store_population_criteria(self, protocol_id: str, criteria: List[Dict]) -> List[int]:
        """Replace the stored inclusion/exclusion criteria of a protocol; return the new doc IDs."""
        return self._replace_protocol_records(self.population_table, protocol_id, criteria)

    def store_study_arms(self, protocol_id: str, arms: List[Dict]) -> List[int]:
        """Replace the stored study arms/cohorts of a protocol; return the new doc IDs."""
        return self._replace_protocol_records(self.arms_table, protocol_id, arms)

    def store_assessments(self, protocol_id: str, assessments: List[Dict]) -> List[int]:
        """Replace the stored assessments/procedures of a protocol; return the new doc IDs."""
        return self._replace_protocol_records(self.assessments_table, protocol_id, assessments)

    # =========================================================================
    # Query Methods for Structured Knowledge
    # =========================================================================
    def get_study_by_protocol_id(self, protocol_id: str) -> Optional[Dict]:
        """Return the first study with this protocol ID, or None."""
        return self.studies_table.get(self.Query.protocol_id == protocol_id)

    def get_all_studies(self) -> List[Dict]:
        """Return all stored studies."""
        return self.studies_table.all()

    def get_objectives_by_protocol_id(self, protocol_id: str) -> List[Dict]:
        """Return all objectives linked to a protocol."""
        return self.objectives_table.search(self.Query.protocol_id == protocol_id)

    def get_endpoints_by_protocol_id(self, protocol_id: str) -> List[Dict]:
        """Return all endpoints linked to a protocol."""
        return self.endpoints_table.search(self.Query.protocol_id == protocol_id)

    def get_population_criteria_by_protocol_id(self, protocol_id: str, criterion_type: Optional[str] = None) -> List[Dict]:
        """
        Return population criteria for a protocol.

        If ``criterion_type`` is truthy (e.g. "Inclusion" or "Exclusion"), only
        criteria whose ``criterion_type`` field equals it are returned.
        """
        condition = self.Query.protocol_id == protocol_id
        if criterion_type:
            condition = condition & (self.Query.criterion_type == criterion_type)
        return self.population_table.search(condition)

    def search_criteria_by_keyword(self, keyword: str) -> List[Dict]:
        """
        Return population criteria whose ``text`` contains ``keyword``, ignoring case.

        The keyword is matched literally; regex metacharacters are escaped.
        """
        pattern = re.escape(keyword)
        return self.population_table.search(
            self.Query.text.search(pattern, flags=re.IGNORECASE)
        )

    def get_all_documents(self) -> List[Dict]:
        """Return metadata for all stored documents."""
        return self.documents_table.all()

    def get_document_by_id(self, document_id: str) -> Optional[Dict]:
        """Return the first document whose ``document_id`` matches, or None."""
        return self.documents_table.get(self.Query.document_id == document_id)

    def get_documents_by_protocol_id(self, protocol_id: str) -> List[Dict]:
        """Return all documents whose ``protocol_id`` matches."""
        return self.documents_table.search(self.Query.protocol_id == protocol_id)

    def get_related_documents(self, protocol_id: str) -> List[Dict]:
        """
        Return documents related to a protocol (e.g. the protocol and its SAP).

        A document matches if its ``protocol_id`` equals ``protocol_id`` or its
        ``related_protocols`` list contains it.
        """
        return self.documents_table.search(
            (self.Query.protocol_id == protocol_id) |
            (self.Query.related_protocols.any([protocol_id]))
        )

    def get_assessments_by_protocol_id(self, protocol_id: str) -> List[Dict]:
        """Return all assessments linked to a protocol."""
        return self.assessments_table.search(self.Query.protocol_id == protocol_id)

    def get_protocol_summary(self, protocol_id: str) -> Dict:
        """
        Build a summary of a protocol from its study, objectives, endpoints and criteria.

        Objectives and endpoints are split by their ``type`` field ("Primary" /
        "Secondary"). Returns ``{}`` if no study is stored for ``protocol_id``.
        """
        study = self.get_study_by_protocol_id(protocol_id)
        if not study:
            return {}

        objectives = self.get_objectives_by_protocol_id(protocol_id)
        endpoints = self.get_endpoints_by_protocol_id(protocol_id)

        return {
            "protocol_id": protocol_id,
            "title": study.get('title', ''),
            "phase": study.get('phase', ''),
            "design": study.get('design_type', ''),
            "primary_objectives": [obj for obj in objectives if obj.get('type') == 'Primary'],
            "secondary_objectives": [obj for obj in objectives if obj.get('type') == 'Secondary'],
            "primary_endpoints": [ep for ep in endpoints if ep.get('type') == 'Primary'],
            "secondary_endpoints": [ep for ep in endpoints if ep.get('type') == 'Secondary'],
            "inclusion_criteria": self.get_population_criteria_by_protocol_id(protocol_id, 'Inclusion'),
            "exclusion_criteria": self.get_population_criteria_by_protocol_id(protocol_id, 'Exclusion'),
            "planned_enrollment": study.get('planned_enrollment', ''),
        }

    def find_document_entity_links(self, entity_type: str, protocol_id: Optional[str] = None) -> Dict:
        """
        Map each stored document to the entities of ``entity_type`` for its protocol.

        Useful for traceability analysis.

        Args:
            entity_type: One of "objectives", "endpoints", "population", "assessments".
            protocol_id: If truthy, only documents with this protocol ID are included.

        Returns:
            ``{document_id: {"document_title", "document_type", "protocol_id",
            "entities"}}``, or ``{"error": ...}`` for an unknown entity type.
            Documents lacking ``document_id`` all share the key None, so only
            the last such document is kept.
        """
        entity_getters = {
            "objectives": self.get_objectives_by_protocol_id,
            "endpoints": self.get_endpoints_by_protocol_id,
            "population": self.get_population_criteria_by_protocol_id,
            "assessments": self.get_assessments_by_protocol_id,
        }
        # Dispatch on the type name rather than on a table object: a TinyDB
        # table defines __len__, so an empty (but valid) table is falsy.
        get_entities = entity_getters.get(entity_type)
        if get_entities is None:
            return {"error": f"Unknown entity type: {entity_type}"}

        documents = (
            self.get_documents_by_protocol_id(protocol_id)
            if protocol_id
            else self.get_all_documents()
        )

        result = {}
        for doc in documents:
            doc_protocol_id = doc.get('protocol_id')
            result[doc.get('document_id')] = {
                "document_title": doc.get('title', ''),
                "document_type": doc.get('type', ''),
                "protocol_id": doc_protocol_id,
                "entities": get_entities(doc_protocol_id),
            }
        return result

    # =========================================================================
    # Vector Store Methods
    # =========================================================================
    def add_documents(self, documents: List[Dict]):
        """
        Add text chunks to the vector store.

        Each item must have ``'page_content'`` (text) and ``'metadata'`` keys.

        Returns:
            ``{"status": "success", "added": n, "ids": [...]}``, or
            ``{"status": "error", "message": str}`` if Chroma fails.
            An empty list is a no-op success, so the embedder and Chroma are
            never called with an empty batch.
        """
        if not documents:
            return {"status": "success", "added": 0, "ids": []}

        texts = [doc['page_content'] for doc in documents]
        metadatas = [doc['metadata'] for doc in documents]

        try:
            ids = self.vector_db.add_texts(texts=texts, metadatas=metadatas)
            self.vector_db.persist()  # Save to disk
        except Exception as e:
            return {"status": "error", "message": str(e)}
        return {"status": "success", "added": len(texts), "ids": ids}

    def similarity_search(self, query: str, k: int = 5, filter_dict: Optional[Dict] = None):
        """
        Return the ``k`` chunks most similar to ``query``, optionally filtered by metadata.

        Errors are printed and an empty list is returned.
        """
        try:
            return self.vector_db.similarity_search(query=query, k=k, filter=filter_dict)
        except Exception as e:
            print(f"Error in similarity search: {e}")
            return []

    def similarity_search_with_score(self, query: str, k: int = 5, filter_dict: Optional[Dict] = None):
        """
        Like :meth:`similarity_search`, but return ``(document, score)`` pairs.

        NOTE(review): Chroma reports a distance as the score, so lower
        presumably means more similar. Confirm against the installed
        LangChain version.
        """
        try:
            return self.vector_db.similarity_search_with_score(query=query, k=k, filter=filter_dict)
        except Exception as e:
            print(f"Error in similarity search with score: {e}")
            return []

    def _embedding_dimension(self) -> int:
        """
        Return the length of the vectors produced by ``self.embeddings``.

        Uses the wrapped sentence-transformers client's
        ``get_sentence_embedding_dimension()`` when available. Otherwise it
        embeds a short probe string and measures the result.
        """
        client = getattr(self.embeddings, "client", None)
        get_dimension = getattr(client, "get_sentence_embedding_dimension", None)
        if callable(get_dimension):
            return get_dimension()
        return len(self.embeddings.embed_query("dimension probe"))

    def get_vector_store_stats(self):
        """
        Summarize the vector store: stored chunk count, embedding size and model name.

        Returns ``{"error": message}`` if any of these cannot be read.
        """
        try:
            # _collection is a private attribute of the LangChain Chroma wrapper.
            return {
                "document_count": self.vector_db._collection.count(),
                "embedding_dimension": self._embedding_dimension(),
                "model": self.embeddings.model_name,
            }
        except Exception as e:
            return {"error": str(e)}