File size: 15,521 Bytes
5f8e1d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
"""
Knowledge Store implementation for Pharmaceutical R&D Knowledge Ecosystem.
Includes TinyDB for structured data and ChromaDB for vector embeddings.
"""

import os
import json
import re
from typing import Dict, List, Any, Optional, Union

from tinydb import TinyDB, Query
from tinydb.middlewares import CachingMiddleware
from tinydb.storages import JSONStorage
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings

class KnowledgeStore:
    """
    Knowledge store combining a structured database (TinyDB) and a vector
    store (Chroma) for semantic search.

    TinyDB is wrapped in ``CachingMiddleware``, which buffers writes in
    memory. Call :meth:`flush` or :meth:`close` (or use the instance as a
    context manager) so structured data is actually written to disk.
    """

    def __init__(self, data_dir="./data"):
        """
        Initialize both stores under ``data_dir``.

        Creates ``<data_dir>/nosql_db`` and ``<data_dir>/vector_db`` if
        needed, opens the TinyDB file ``protocol_knowledge.json`` and a
        persistent Chroma store embedded with the
        ``sentence-transformers/all-MiniLM-L6-v2`` model.
        """
        nosql_dir = os.path.join(data_dir, "nosql_db")
        self.vector_db_path = os.path.join(data_dir, "vector_db")
        os.makedirs(nosql_dir, exist_ok=True)
        os.makedirs(self.vector_db_path, exist_ok=True)

        # CachingMiddleware keeps writes in memory until flush()/close().
        self.db_path = os.path.join(nosql_dir, "protocol_knowledge.json")
        self.db = TinyDB(
            self.db_path,
            storage=CachingMiddleware(JSONStorage)
        )

        # One table per entity type
        self.documents_table = self.db.table('documents')
        self.studies_table = self.db.table('studies')
        self.compounds_table = self.db.table('compounds')
        self.objectives_table = self.db.table('objectives')
        self.endpoints_table = self.db.table('endpoints')
        self.population_table = self.db.table('population_criteria')
        self.arms_table = self.db.table('study_arms')
        self.assessments_table = self.db.table('assessments')
        self.analytes_table = self.db.table('analytes')

        # Vector store with sentence-transformers embeddings
        self.embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )

        # The same constructor opens or creates the store, so check for prior
        # on-disk content first purely to report what happened accurately.
        has_existing_store = bool(os.listdir(self.vector_db_path))
        self.vector_db = Chroma(
            persist_directory=self.vector_db_path,
            embedding_function=self.embeddings
        )
        if has_existing_store:
            print(f"Loaded existing vector store from {self.vector_db_path}")
        else:
            print(f"Created new vector store at {self.vector_db_path}")

        # Query constructor
        self.Query = Query()

    # =========================================================================
    # Lifecycle
    # =========================================================================

    def flush(self) -> None:
        """Write TinyDB changes buffered by CachingMiddleware to disk."""
        self.db.storage.flush()

    def close(self) -> None:
        """Flush pending TinyDB writes and close the underlying JSON file."""
        self.db.close()

    def __enter__(self) -> "KnowledgeStore":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    # =========================================================================
    # Structured Knowledge Store Methods (TinyDB)
    # =========================================================================

    def _upsert(self, table, field: str, value: Any, record: Dict) -> int:
        """
        Update the first row of ``table`` whose ``field`` equals ``value``,
        otherwise insert ``record``; return the affected TinyDB doc ID.

        When ``value`` is None no lookup is performed, so records lacking the
        key are inserted rather than overwriting an unrelated row.
        """
        if value is not None:
            existing = table.get(self.Query[field] == value)
            if existing is not None:
                table.update(record, doc_ids=[existing.doc_id])
                return existing.doc_id
        return table.insert(record)

    def _replace_protocol_entities(self, table, protocol_id: str,
                                   items: List[Dict]) -> List[int]:
        """
        Replace every row of ``table`` linked to ``protocol_id`` with ``items``.

        Each item is copied and tagged with ``protocol_id``; the caller's
        dicts are not modified. Returns the new TinyDB doc IDs.
        """
        table.remove(self.Query.protocol_id == protocol_id)
        return table.insert_multiple(
            {**item, 'protocol_id': protocol_id} for item in items
        )

    def store_document_metadata(self, metadata: Dict) -> int:
        """
        Store basic document metadata and return the TinyDB doc ID.

        An existing row is updated when it matches the document identifier
        (taken from ``id`` or ``document_id``, stored under either key).
        Without an identifier, the first row with the same ``protocol_id``
        is updated instead.
        """
        doc_id = metadata.get('id') or metadata.get('document_id')
        protocol_id = metadata.get('protocol_id')
        existing = None

        if doc_id:
            # Callers may key documents by either field, so match both.
            existing = self.documents_table.get(
                (self.Query.document_id == doc_id) | (self.Query['id'] == doc_id)
            )
        elif protocol_id:
            existing = self.documents_table.get(self.Query.protocol_id == protocol_id)

        if existing is not None:
            self.documents_table.update(metadata, doc_ids=[existing.doc_id])
            return existing.doc_id

        return self.documents_table.insert(metadata)

    def store_study_info(self, study_info: Dict) -> int:
        """Insert or update study information, keyed by ``protocol_id``."""
        return self._upsert(self.studies_table, 'protocol_id',
                            study_info.get('protocol_id'), study_info)

    def store_compound_info(self, compound_info: Dict) -> int:
        """Insert or update compound information, keyed by ``compound_id``."""
        return self._upsert(self.compounds_table, 'compound_id',
                            compound_info.get('compound_id'), compound_info)

    def store_objectives(self, protocol_id: str, objectives: List[Dict]) -> List[int]:
        """Replace all objectives stored for a protocol."""
        return self._replace_protocol_entities(self.objectives_table, protocol_id, objectives)

    def store_endpoints(self, protocol_id: str, endpoints: List[Dict]) -> List[int]:
        """Replace all endpoints stored for a protocol."""
        return self._replace_protocol_entities(self.endpoints_table, protocol_id, endpoints)

    def store_population_criteria(self, protocol_id: str, criteria: List[Dict]) -> List[int]:
        """Replace all inclusion/exclusion criteria stored for a protocol."""
        return self._replace_protocol_entities(self.population_table, protocol_id, criteria)

    def store_study_arms(self, protocol_id: str, arms: List[Dict]) -> List[int]:
        """Replace all study arms/cohorts stored for a protocol."""
        return self._replace_protocol_entities(self.arms_table, protocol_id, arms)

    def store_assessments(self, protocol_id: str, assessments: List[Dict]) -> List[int]:
        """Replace all assessments/procedures stored for a protocol."""
        return self._replace_protocol_entities(self.assessments_table, protocol_id, assessments)

    # =========================================================================
    # Query Methods for Structured Knowledge
    # =========================================================================

    def get_study_by_protocol_id(self, protocol_id: str) -> Optional[Dict]:
        """Retrieve study information by protocol ID."""
        return self.studies_table.get(self.Query.protocol_id == protocol_id)

    def get_all_studies(self) -> List[Dict]:
        """Retrieve all studies."""
        return self.studies_table.all()

    def get_objectives_by_protocol_id(self, protocol_id: str) -> List[Dict]:
        """Retrieve all objectives for a protocol."""
        return self.objectives_table.search(self.Query.protocol_id == protocol_id)

    def get_endpoints_by_protocol_id(self, protocol_id: str) -> List[Dict]:
        """Retrieve all endpoints for a protocol."""
        return self.endpoints_table.search(self.Query.protocol_id == protocol_id)

    def get_population_criteria_by_protocol_id(self, protocol_id: str, criterion_type: Optional[str] = None) -> List[Dict]:
        """Retrieve population criteria for a protocol, optionally filtered by type (Inclusion/Exclusion)."""
        if criterion_type:
            return self.population_table.search(
                (self.Query.protocol_id == protocol_id) &
                (self.Query.criterion_type == criterion_type)
            )
        return self.population_table.search(self.Query.protocol_id == protocol_id)

    def search_criteria_by_keyword(self, keyword: str) -> List[Dict]:
        """
        Search inclusion/exclusion criteria whose ``text`` contains ``keyword``
        (case-insensitive, literal match: regex metacharacters are escaped).
        """
        pattern = re.escape(keyword)
        return self.population_table.search(
            self.Query.text.search(pattern, flags=re.IGNORECASE)
        )

    def get_all_documents(self) -> List[Dict]:
        """Retrieve metadata for all stored documents."""
        return self.documents_table.all()

    def get_document_by_id(self, document_id: str) -> Optional[Dict]:
        """Retrieve a document whose ``document_id`` or ``id`` equals ``document_id``."""
        return self.documents_table.get(
            (self.Query.document_id == document_id) | (self.Query['id'] == document_id)
        )

    def get_documents_by_protocol_id(self, protocol_id: str) -> List[Dict]:
        """Retrieve all documents associated with a protocol ID."""
        return self.documents_table.search(self.Query.protocol_id == protocol_id)

    def get_related_documents(self, protocol_id: str) -> List[Dict]:
        """
        Find documents related to a protocol: those whose ``protocol_id``
        matches, or whose ``related_protocols`` list contains it.
        """
        return self.documents_table.search(
            (self.Query.protocol_id == protocol_id) |
            (self.Query.related_protocols.any([protocol_id]))
        )

    def get_assessments_by_protocol_id(self, protocol_id: str) -> List[Dict]:
        """Retrieve all assessments for a protocol."""
        return self.assessments_table.search(self.Query.protocol_id == protocol_id)

    def get_protocol_summary(self, protocol_id: str) -> Dict:
        """
        Build a summary of a protocol from its study record, objectives,
        endpoints and population criteria. Returns ``{}`` if no study exists.
        """
        study = self.get_study_by_protocol_id(protocol_id)
        if not study:
            return {}

        objectives = self.get_objectives_by_protocol_id(protocol_id)
        endpoints = self.get_endpoints_by_protocol_id(protocol_id)

        def of_type(rows: List[Dict], kind: str) -> List[Dict]:
            return [row for row in rows if row.get('type') == kind]

        inclusion = self.get_population_criteria_by_protocol_id(protocol_id, 'Inclusion')
        exclusion = self.get_population_criteria_by_protocol_id(protocol_id, 'Exclusion')

        return {
            "protocol_id": protocol_id,
            "title": study.get('title', ''),
            "phase": study.get('phase', ''),
            "design": study.get('design_type', ''),
            "primary_objectives": of_type(objectives, 'Primary'),
            "secondary_objectives": of_type(objectives, 'Secondary'),
            "primary_endpoints": of_type(endpoints, 'Primary'),
            "secondary_endpoints": of_type(endpoints, 'Secondary'),
            "inclusion_criteria": inclusion,
            "exclusion_criteria": exclusion,
            "planned_enrollment": study.get('planned_enrollment', '')
        }

    def find_document_entity_links(self, entity_type: str, protocol_id: Optional[str] = None) -> Dict:
        """
        Map each document to the entities of ``entity_type`` that share its
        protocol. Useful for traceability analysis.

        ``entity_type`` is one of "objectives", "endpoints", "population",
        "assessments"; anything else yields ``{"error": ...}``. When
        ``protocol_id`` is given, only that protocol's documents are used.
        Result keys are the document's ``document_id`` (falling back to ``id``).
        """
        fetchers = {
            "objectives": self.get_objectives_by_protocol_id,
            "endpoints": self.get_endpoints_by_protocol_id,
            "population": self.get_population_criteria_by_protocol_id,
            "assessments": self.get_assessments_by_protocol_id,
        }
        # A dict lookup (not truthiness of a Table, which is falsy when empty)
        # decides whether the entity type is known.
        fetch = fetchers.get(entity_type)
        if fetch is None:
            return {"error": f"Unknown entity type: {entity_type}"}

        documents = (self.get_documents_by_protocol_id(protocol_id)
                     if protocol_id else self.get_all_documents())

        result = {}
        for doc in documents:
            doc_protocol_id = doc.get('protocol_id')
            result[doc.get('document_id') or doc.get('id')] = {
                "document_title": doc.get('title', ''),
                "document_type": doc.get('type', ''),
                "protocol_id": doc_protocol_id,
                "entities": fetch(doc_protocol_id)
            }

        return result

    # =========================================================================
    # Vector Store Methods
    # =========================================================================

    def add_documents(self, documents: List[Dict]) -> Dict[str, Any]:
        """
        Add documents to the vector store.

        Each document must have 'page_content' and 'metadata' fields (a
        missing field raises KeyError). Returns a status dict; vector-store
        failures are reported as ``{"status": "error", "message": ...}``.
        """
        if not documents:
            return {"status": "success", "added": 0, "ids": []}

        texts = [doc['page_content'] for doc in documents]
        metadatas = [doc['metadata'] for doc in documents]

        try:
            ids = self.vector_db.add_texts(texts=texts, metadatas=metadatas)
            self.vector_db.persist()  # Save to disk
            return {"status": "success", "added": len(texts), "ids": ids}
        except Exception as e:
            return {"status": "error", "message": str(e)}

    def similarity_search(self, query: str, k: int = 5, filter_dict: Optional[Dict] = None):
        """
        Return up to ``k`` documents most similar to ``query``, optionally
        filtered by metadata. Errors are printed and yield ``[]``.
        """
        try:
            return self.vector_db.similarity_search(
                query=query,
                k=k,
                filter=filter_dict
            )
        except Exception as e:
            print(f"Error in similarity search: {e}")
            return []

    def similarity_search_with_score(self, query: str, k: int = 5, filter_dict: Optional[Dict] = None):
        """
        Like :meth:`similarity_search` but returns ``(document, score)``
        pairs. Errors are printed and yield ``[]``.
        """
        try:
            return self.vector_db.similarity_search_with_score(
                query=query,
                k=k,
                filter=filter_dict
            )
        except Exception as e:
            print(f"Error in similarity search with score: {e}")
            return []

    def _embedding_dimension(self) -> Optional[int]:
        """
        Return the embedding width reported by the sentence-transformers model
        behind ``self.embeddings`` (its ``client``), or None if unavailable.
        """
        client = getattr(self.embeddings, "client", None)
        get_dim = getattr(client, "get_sentence_embedding_dimension", None)
        return get_dim() if callable(get_dim) else None

    def get_vector_store_stats(self) -> Dict[str, Any]:
        """
        Return the stored-vector count, embedding dimension and model name,
        or ``{"error": ...}`` if the store cannot be queried.
        """
        try:
            count = self.vector_db._collection.count()
            return {
                "document_count": count,
                "embedding_dimension": self._embedding_dimension(),
                "model": getattr(self.embeddings, "model_name", None)
            }
        except Exception as e:
            return {"error": str(e)}