cryogenic22 committed on
Commit
5f8e1d1
·
verified ·
1 Parent(s): a31da12

Create knowledge_store.py

Browse files
Files changed (1) hide show
  1. knowledge_store.py +354 -0
knowledge_store.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Knowledge Store implementation for Pharmaceutical R&D Knowledge Ecosystem.
3
+ Includes TinyDB for structured data and ChromaDB for vector embeddings.
4
+ """
5
+
6
+ import os
7
+ import json
8
+ from typing import Dict, List, Any, Optional, Union
9
+ from tinydb import TinyDB, Query
10
+ from tinydb.middlewares import CachingMiddleware
11
+ from tinydb.storages import JSONStorage
12
+ from langchain_community.vectorstores import Chroma
13
+ from langchain_community.embeddings import HuggingFaceEmbeddings
14
+
15
class KnowledgeStore:
    """
    Knowledge store combining a structured database (TinyDB) and a vector
    store (ChromaDB).

    Structured entities (documents, studies, compounds, objectives, endpoints,
    population criteria, study arms, assessments, analytes) each live in their
    own TinyDB table. Free-text chunks are embedded and kept in a Chroma
    collection for similarity search.

    The TinyDB file is opened through ``CachingMiddleware``, which buffers
    writes in memory. Every ``store_*`` method therefore flushes after
    writing, and the instance can be used as a context manager (or closed
    explicitly with :meth:`close`) to release the database cleanly.
    """

    def __init__(self, data_dir="./data"):
        """
        Initialize both stores under ``data_dir``.

        Creates ``<data_dir>/nosql_db`` and ``<data_dir>/vector_db`` when they
        do not exist yet, opens the TinyDB file and the Chroma collection, and
        loads the sentence-transformers embedding model.
        """
        nosql_dir = os.path.join(data_dir, "nosql_db")
        self.vector_db_path = os.path.join(data_dir, "vector_db")
        os.makedirs(nosql_dir, exist_ok=True)
        os.makedirs(self.vector_db_path, exist_ok=True)

        # TinyDB with write caching for better performance.
        self.db_path = os.path.join(nosql_dir, "protocol_knowledge.json")
        self.db = TinyDB(
            self.db_path,
            storage=CachingMiddleware(JSONStorage)
        )

        # One table per entity type.
        self.documents_table = self.db.table('documents')
        self.studies_table = self.db.table('studies')
        self.compounds_table = self.db.table('compounds')
        self.objectives_table = self.db.table('objectives')
        self.endpoints_table = self.db.table('endpoints')
        self.population_table = self.db.table('population_criteria')
        self.arms_table = self.db.table('study_arms')
        self.assessments_table = self.db.table('assessments')
        self.analytes_table = self.db.table('analytes')

        # Sentence-transformers embedding used by the vector store.
        self.embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )

        # A single constructor call both loads an existing collection and
        # creates a new one; retrying the identical call on failure cannot
        # succeed, so errors are allowed to propagate.
        self.vector_db = Chroma(
            persist_directory=self.vector_db_path,
            embedding_function=self.embeddings
        )
        print(f"Vector store ready at {self.vector_db_path}")

        # Query constructor
        self.Query = Query()

    # =========================================================================
    # Lifecycle
    # =========================================================================

    def flush(self) -> None:
        """Write any cached TinyDB changes to disk, if the storage supports it."""
        storage = getattr(getattr(self, "db", None), "storage", None)
        flush = getattr(storage, "flush", None)
        if callable(flush):
            flush()

    def close(self) -> None:
        """Close the TinyDB database (closing also flushes cached writes)."""
        db = getattr(self, "db", None)
        if db is not None:
            db.close()

    def __enter__(self):
        """Return the store itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the database on exit; exceptions are never suppressed."""
        self.close()

    # =========================================================================
    # Internal helpers
    # =========================================================================

    def _upsert(self, table, field: str, value: Any, record: Dict) -> int:
        """
        Update the row of ``table`` whose ``field`` equals ``value``, or insert
        ``record`` when there is no such row (or ``value`` is None).

        Returns the TinyDB document ID of the updated or inserted row.
        """
        existing = None
        if value is not None:
            existing = table.get(self.Query[field] == value)

        if existing is not None:
            table.update(record, doc_ids=[existing.doc_id])
            doc_id = existing.doc_id
        else:
            doc_id = table.insert(record)

        self.flush()
        return doc_id

    def _replace_for_protocol(self, table, protocol_id: str, records: List[Dict]) -> List[int]:
        """
        Replace every row of ``table`` linked to ``protocol_id`` with ``records``.

        Each record is copied before the ``protocol_id`` link is added, so the
        caller's dictionaries are left untouched.
        """
        table.remove(self.Query.protocol_id == protocol_id)
        linked = [{**record, 'protocol_id': protocol_id} for record in (records or [])]
        doc_ids = table.insert_multiple(linked)
        self.flush()
        return doc_ids

    def _embedding_dimension(self) -> int:
        """
        Return the embedding vector length.

        Uses an ``embedding_size`` attribute when the embedding object has one;
        otherwise embeds a short probe string once and caches the length.
        """
        dimension = getattr(self, "_embedding_dim", None)
        if dimension is None:
            dimension = getattr(self.embeddings, "embedding_size", None)
            if dimension is None:
                dimension = len(self.embeddings.embed_query("dimension probe"))
            self._embedding_dim = dimension
        return dimension

    # =========================================================================
    # Structured Knowledge Store Methods (TinyDB)
    # =========================================================================

    def store_document_metadata(self, metadata: Dict) -> int:
        """
        Store basic document metadata and return the TinyDB document ID.

        An existing row is matched on ``document_id`` (falling back to ``id``
        when ``document_id`` is absent); if neither is given, on
        ``protocol_id``. The stored record always carries ``document_id`` when
        an identifier was supplied, so later lookups by ID find it.
        """
        record = dict(metadata)
        doc_key = record.get('document_id') or record.get('id')

        if doc_key:
            # Normalise so that id-only metadata is findable by document_id.
            record['document_id'] = doc_key
            return self._upsert(self.documents_table, 'document_id', doc_key, record)

        protocol_id = record.get('protocol_id') or None
        return self._upsert(self.documents_table, 'protocol_id', protocol_id, record)

    def store_study_info(self, study_info: Dict) -> int:
        """Store study information, updating the study with the same ``protocol_id`` if any."""
        return self._upsert(
            self.studies_table, 'protocol_id', study_info.get('protocol_id'), study_info
        )

    def store_compound_info(self, compound_info: Dict) -> int:
        """Store compound information, updating the compound with the same ``compound_id`` if any."""
        return self._upsert(
            self.compounds_table, 'compound_id', compound_info.get('compound_id'), compound_info
        )

    def store_objectives(self, protocol_id: str, objectives: List[Dict]) -> List[int]:
        """Replace the stored objectives of a protocol and return the new row IDs."""
        return self._replace_for_protocol(self.objectives_table, protocol_id, objectives)

    def store_endpoints(self, protocol_id: str, endpoints: List[Dict]) -> List[int]:
        """Replace the stored endpoints of a protocol and return the new row IDs."""
        return self._replace_for_protocol(self.endpoints_table, protocol_id, endpoints)

    def store_population_criteria(self, protocol_id: str, criteria: List[Dict]) -> List[int]:
        """Replace the stored inclusion/exclusion criteria of a protocol."""
        return self._replace_for_protocol(self.population_table, protocol_id, criteria)

    def store_study_arms(self, protocol_id: str, arms: List[Dict]) -> List[int]:
        """Replace the stored study arms/cohorts of a protocol."""
        return self._replace_for_protocol(self.arms_table, protocol_id, arms)

    def store_assessments(self, protocol_id: str, assessments: List[Dict]) -> List[int]:
        """Replace the stored assessments/procedures of a protocol."""
        return self._replace_for_protocol(self.assessments_table, protocol_id, assessments)

    # =========================================================================
    # Query Methods for Structured Knowledge
    # =========================================================================

    def get_study_by_protocol_id(self, protocol_id: str) -> Optional[Dict]:
        """Retrieve study information by protocol ID, or None when not stored."""
        return self.studies_table.get(self.Query.protocol_id == protocol_id)

    def get_all_studies(self) -> List[Dict]:
        """Retrieve all studies."""
        return self.studies_table.all()

    def get_objectives_by_protocol_id(self, protocol_id: str) -> List[Dict]:
        """Retrieve all objectives for a protocol."""
        return self.objectives_table.search(self.Query.protocol_id == protocol_id)

    def get_endpoints_by_protocol_id(self, protocol_id: str) -> List[Dict]:
        """Retrieve all endpoints for a protocol."""
        return self.endpoints_table.search(self.Query.protocol_id == protocol_id)

    def get_population_criteria_by_protocol_id(self, protocol_id: str, criterion_type: Optional[str] = None) -> List[Dict]:
        """Retrieve population criteria for a protocol, optionally filtered by type (Inclusion/Exclusion)."""
        condition = self.Query.protocol_id == protocol_id
        if criterion_type:
            condition = condition & (self.Query.criterion_type == criterion_type)
        return self.population_table.search(condition)

    def search_criteria_by_keyword(self, keyword: str) -> List[Dict]:
        """
        Search inclusion/exclusion criteria whose ``text`` contains a keyword.

        The match is case-insensitive and literal: regex metacharacters in
        ``keyword`` are escaped.
        """
        import re  # local import: `re` is not among the file-level imports

        # The flags argument must be an `re` flag; a string such as 'i'
        # makes the underlying `re.search` call raise TypeError.
        pattern = re.escape(keyword)
        return self.population_table.search(
            self.Query.text.search(pattern, flags=re.IGNORECASE)
        )

    def get_all_documents(self) -> List[Dict]:
        """Retrieve metadata for all stored documents."""
        return self.documents_table.all()

    def get_document_by_id(self, document_id: str) -> Optional[Dict]:
        """Retrieve document by ID, or None when not stored."""
        return self.documents_table.get(self.Query.document_id == document_id)

    def get_documents_by_protocol_id(self, protocol_id: str) -> List[Dict]:
        """Retrieve all documents associated with a protocol ID."""
        return self.documents_table.search(self.Query.protocol_id == protocol_id)

    def get_related_documents(self, protocol_id: str) -> List[Dict]:
        """Find documents related to a protocol (e.g., protocol and its SAP)."""
        return self.documents_table.search(
            (self.Query.protocol_id == protocol_id) |
            (self.Query.related_protocols.any([protocol_id]))
        )

    def get_assessments_by_protocol_id(self, protocol_id: str) -> List[Dict]:
        """Retrieve all assessments for a protocol."""
        return self.assessments_table.search(self.Query.protocol_id == protocol_id)

    def get_protocol_summary(self, protocol_id: str) -> Dict:
        """
        Create a comprehensive summary of a protocol.

        Returns an empty dict when no study is stored for ``protocol_id``.
        """
        study = self.get_study_by_protocol_id(protocol_id)
        if not study:
            return {}

        objectives = self.get_objectives_by_protocol_id(protocol_id)
        endpoints = self.get_endpoints_by_protocol_id(protocol_id)

        def of_type(rows, wanted):
            # Keep only the rows whose 'type' field equals `wanted`.
            return [row for row in rows if row.get('type') == wanted]

        return {
            "protocol_id": protocol_id,
            "title": study.get('title', ''),
            "phase": study.get('phase', ''),
            "design": study.get('design_type', ''),
            "primary_objectives": of_type(objectives, 'Primary'),
            "secondary_objectives": of_type(objectives, 'Secondary'),
            "primary_endpoints": of_type(endpoints, 'Primary'),
            "secondary_endpoints": of_type(endpoints, 'Secondary'),
            "inclusion_criteria": self.get_population_criteria_by_protocol_id(protocol_id, 'Inclusion'),
            "exclusion_criteria": self.get_population_criteria_by_protocol_id(protocol_id, 'Exclusion'),
            "planned_enrollment": study.get('planned_enrollment', '')
        }

    def find_document_entity_links(self, entity_type: str, protocol_id: Optional[str] = None) -> Dict:
        """
        Find links between documents and specific entity types.
        Useful for traceability analysis.

        ``entity_type`` is one of "objectives", "endpoints", "population" or
        "assessments"; any other value yields ``{"error": ...}``. When
        ``protocol_id`` is given only that protocol's documents are examined,
        otherwise all documents are. The result maps each document's
        ``document_id`` to its title, type, protocol ID and linked entities.
        """
        fetchers = {
            "objectives": self.get_objectives_by_protocol_id,
            "endpoints": self.get_endpoints_by_protocol_id,
            "population": self.get_population_criteria_by_protocol_id,
            "assessments": self.get_assessments_by_protocol_id,
        }
        # Dispatch on the name rather than on the table's truthiness: a table
        # object that defines __len__ is falsy while empty, which would make a
        # valid entity type look unknown.
        fetch = fetchers.get(entity_type)
        if fetch is None:
            return {"error": f"Unknown entity type: {entity_type}"}

        if protocol_id:
            documents = self.get_documents_by_protocol_id(protocol_id)
        else:
            documents = self.get_all_documents()

        entities_by_protocol = {}  # avoids re-querying for documents sharing a protocol
        result = {}
        for doc in documents:
            doc_protocol_id = doc.get('protocol_id')
            if doc_protocol_id not in entities_by_protocol:
                entities_by_protocol[doc_protocol_id] = fetch(doc_protocol_id)

            result[doc.get('document_id')] = {
                "document_title": doc.get('title', ''),
                "document_type": doc.get('type', ''),
                "protocol_id": doc_protocol_id,
                # Separate list per document so callers can mutate one safely.
                "entities": list(entities_by_protocol[doc_protocol_id])
            }

        return result

    # =========================================================================
    # Vector Store Methods
    # =========================================================================

    def add_documents(self, documents: List[Dict]) -> Dict[str, Any]:
        """
        Add documents to the vector store.
        Each document should have 'page_content' and may have 'metadata'.

        Returns ``{"status": "success", "added": n, "ids": [...]}`` on success
        and ``{"status": "error", "message": ...}`` on failure. An empty input
        list is a successful no-op.
        """
        if not documents:
            return {"status": "success", "added": 0, "ids": []}

        try:
            texts = [doc['page_content'] for doc in documents]
            metadatas = [doc.get('metadata') or {} for doc in documents]

            ids = self.vector_db.add_texts(texts=texts, metadatas=metadatas)

            # Explicit persistence is only needed (and only available) on some
            # versions of the Chroma wrapper; skip it when the method is absent.
            persist = getattr(self.vector_db, "persist", None)
            if callable(persist):
                persist()

            return {"status": "success", "added": len(texts), "ids": ids}
        except Exception as e:
            return {"status": "error", "message": str(e)}

    def similarity_search(self, query: str, k: int = 5, filter_dict: Optional[Dict] = None):
        """
        Search for documents similar to the query.
        Optionally filter by metadata.

        Returns an empty list (after printing the error) when the search fails.
        """
        try:
            return self.vector_db.similarity_search(
                query=query,
                k=k,
                filter=filter_dict
            )
        except Exception as e:
            print(f"Error in similarity search: {e}")
            return []

    def similarity_search_with_score(self, query: str, k: int = 5, filter_dict: Optional[Dict] = None):
        """
        Search for documents similar to the query, returning relevance scores.

        Returns an empty list (after printing the error) when the search fails.
        """
        try:
            return self.vector_db.similarity_search_with_score(
                query=query,
                k=k,
                filter=filter_dict
            )
        except Exception as e:
            print(f"Error in similarity search with score: {e}")
            return []

    def get_vector_store_stats(self) -> Dict[str, Any]:
        """
        Get statistics about the vector store.

        Returns the number of stored vectors, the embedding dimension and the
        embedding model name, or ``{"error": ...}`` when any of them cannot be
        obtained.
        """
        try:
            # NOTE(review): relies on the wrapper's private `_collection` attribute.
            count = self.vector_db._collection.count()
            return {
                "document_count": count,
                "embedding_dimension": self._embedding_dimension(),
                "model": self.embeddings.model_name
            }
        except Exception as e:
            return {"error": str(e)}