cryogenic22 commited on
Commit
e30ad71
·
verified ·
1 Parent(s): 7e30640

Create graph_builder.py

Browse files
Files changed (1) hide show
  1. graph_builder.py +776 -0
graph_builder.py ADDED
@@ -0,0 +1,776 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LangGraph agent orchestration for document processing, content authoring, and protocol coach.
3
+ """
4
+
5
+ from langgraph.graph import StateGraph, END
6
+ from typing import TypedDict, Dict, List, Any, Optional, Literal, Annotated, cast
7
+ import operator
8
+ import uuid
9
+
10
+ from schemas import DocumentExtractionState, ProtocolCoachState, ContentAuthoringState, TraceabilityState
11
+ from pdf_processor import PDFProcessor
12
+ from knowledge_store import KnowledgeStore
13
+ from llm_interface import LLMInterface
14
+
15
+ # Initialize handlers
16
+ pdf_processor = None
17
+ knowledge_store = None
18
+ llm_interface = None
19
+
20
def init_handlers(api_key=None):
    """Create the module-level PDF, knowledge-store, and LLM handlers.

    Args:
        api_key: Optional API key forwarded to the LLM interface.

    Returns:
        Tuple of (pdf_processor, knowledge_store, llm_interface).
    """
    global pdf_processor, knowledge_store, llm_interface

    pdf_processor, knowledge_store, llm_interface = (
        PDFProcessor(),
        KnowledgeStore(),
        LLMInterface(api_key=api_key),
    )

    return pdf_processor, knowledge_store, llm_interface
29
+
30
+ # =========================================================================
31
+ # Document Extraction Workflow Nodes
32
+ # =========================================================================
33
+
34
def parse_document(state: DocumentExtractionState) -> DocumentExtractionState:
    """Run the PDF processor over the document and merge its output into state.

    On success sets document_text/metadata/sections/vector_chunks and
    status == "parsed"; on failure sets status == "error" with a message.
    """
    try:
        result = pdf_processor.process_complete_document(state["document_path"])

        if result["status"] == "error":
            failure = {
                "status": "error",
                "error": f"Failed to parse document: {result.get('error', 'Unknown error')}",
            }
            return {**state, **failure}

        parsed = {
            "document_text": result.get("full_text", ""),
            "document_metadata": result.get("metadata", {}),
            "sections": result.get("sections", {}),
            "vector_chunks": result.get("chunks", []),
            "status": "parsed",
        }
        return {**state, **parsed}
    except Exception as e:
        return {
            **state,
            "status": "error",
            "error": f"Exception in parse_document: {str(e)}",
        }
63
+
64
def extract_study_info(state: DocumentExtractionState) -> DocumentExtractionState:
    """Extract top-level study information from the parsed document via the LLM.

    Prefers a synopsis/summary/overview section; falls back to the first
    20k characters of the full text. On success stores the result under
    ``extracted_study`` and advances ``status`` to ``study_extracted``.
    Skips work entirely when an earlier node already set an error status.
    """
    if state.get("status") == "error":
        return state

    try:
        sections = state.get("sections", {})

        # Case-insensitive key lookup built once (the original rebuilt a
        # lowered list per candidate and re-scanned with next()). setdefault
        # keeps the FIRST key for a given lowercase form, matching the
        # original's first-match semantics.
        lowered_keys = {}
        for key in sections:
            lowered_keys.setdefault(key.lower(), key)

        text_for_extraction = ""
        for section_name in ("synopsis", "summary", "overview"):
            if section_name in lowered_keys:
                text_for_extraction = sections[lowered_keys[section_name]]
                break

        # If no synopsis-like section exists, use the start of the document.
        if not text_for_extraction and "document_text" in state:
            text_for_extraction = state["document_text"][:20000]  # Use first 20k chars

        if not text_for_extraction:
            return {
                **state,
                "status": "error",
                "error": "No text available for study info extraction"
            }

        # Extract study info using LLM
        study_info = llm_interface.extract_study_info(text_for_extraction)

        if not study_info:
            return {
                **state,
                "status": "error",
                "error": "Failed to extract study information"
            }

        # Backfill protocol_id from document metadata when the LLM omitted it.
        if "protocol_id" not in study_info and "document_metadata" in state:
            study_info["protocol_id"] = state["document_metadata"].get("protocol_id")

        return {
            **state,
            "extracted_study": study_info,
            "status": "study_extracted"
        }
    except Exception as e:
        return {
            **state,
            "status": "error",
            "error": f"Exception in extract_study_info: {str(e)}"
        }
117
+
118
def extract_objectives_endpoints(state: DocumentExtractionState) -> DocumentExtractionState:
    """Extract study objectives and endpoints from the relevant section via the LLM.

    Locates an objectives/endpoints section by case-insensitive substring
    match over section names, then stores ``extracted_objectives`` and
    ``extracted_endpoints``. Missing sections or an empty LLM result yield
    status == "warning" rather than a hard error.
    """
    if state.get("status") == "error":
        return state

    try:
        sections = state.get("sections", {})
        protocol_id = (
            state.get("extracted_study", {}).get("protocol_id")
            or state.get("document_metadata", {}).get("protocol_id")
        )

        if not protocol_id:
            return {
                **state,
                "status": "error",
                "error": "No protocol ID available for extraction"
            }

        # For each candidate name, consider only the first matching section key.
        text_for_extraction = ""
        for name in ("objectives", "objective", "endpoint", "endpoints"):
            key = next((k for k in sections if name.lower() in k.lower()), None)
            if key is not None:
                text_for_extraction = sections[key]
            if text_for_extraction:
                break

        if not text_for_extraction:
            return {
                **state,
                "status": "warning",
                "error": "No objectives/endpoints section found"
            }

        result = llm_interface.extract_objectives_and_endpoints(text_for_extraction, protocol_id)

        if not result:
            return {
                **state,
                "status": "warning",
                "error": "Failed to extract objectives and endpoints"
            }

        return {
            **state,
            "extracted_objectives": result.get("objectives", []),
            "extracted_endpoints": result.get("endpoints", []),
            "status": "objectives_endpoints_extracted"
        }
    except Exception as e:
        return {
            **state,
            "status": "error",
            "error": f"Exception in extract_objectives_endpoints: {str(e)}"
        }
176
+
177
def extract_population_criteria(state: DocumentExtractionState) -> DocumentExtractionState:
    """Extract inclusion/exclusion criteria from the eligibility section via the LLM.

    Stores the LLM result under ``extracted_population``; a missing section
    or empty result produces status == "warning".
    """
    if state.get("status") == "error":
        return state

    try:
        sections = state.get("sections", {})
        protocol_id = (
            state.get("extracted_study", {}).get("protocol_id")
            or state.get("document_metadata", {}).get("protocol_id")
        )

        # Scan for an eligibility-related section (case-insensitive substring match).
        text_for_extraction = ""
        for name in ("eligibility", "inclusion", "exclusion", "criteria", "population"):
            key = next((k for k in sections if name.lower() in k.lower()), None)
            if key is not None:
                text_for_extraction = sections[key]
            if text_for_extraction:
                break

        if not text_for_extraction:
            return {
                **state,
                "status": "warning",
                "error": "No population criteria section found"
            }

        result = llm_interface.extract_population_criteria(text_for_extraction, protocol_id)

        if not result:
            return {
                **state,
                "status": "warning",
                "error": "Failed to extract population criteria"
            }

        return {
            **state,
            "extracted_population": result,
            "status": "population_extracted"
        }
    except Exception as e:
        return {
            **state,
            "status": "error",
            "error": f"Exception in extract_population_criteria: {str(e)}"
        }
227
+
228
def extract_study_design(state: DocumentExtractionState) -> DocumentExtractionState:
    """Extract study-design details from the design/methodology section via the LLM.

    Stores the LLM result under ``extracted_design``; a missing section or
    empty result produces status == "warning".
    """
    if state.get("status") == "error":
        return state

    try:
        sections = state.get("sections", {})
        protocol_id = (
            state.get("extracted_study", {}).get("protocol_id")
            or state.get("document_metadata", {}).get("protocol_id")
        )

        # Scan for a design-related section (case-insensitive substring match).
        text_for_extraction = ""
        for name in ("study design", "design", "methodology"):
            key = next((k for k in sections if name.lower() in k.lower()), None)
            if key is not None:
                text_for_extraction = sections[key]
            if text_for_extraction:
                break

        if not text_for_extraction:
            return {
                **state,
                "status": "warning",
                "error": "No study design section found"
            }

        result = llm_interface.extract_study_design(text_for_extraction, protocol_id)

        if not result:
            return {
                **state,
                "status": "warning",
                "error": "Failed to extract study design"
            }

        return {
            **state,
            "extracted_design": result,
            "status": "design_extracted"
        }
    except Exception as e:
        return {
            **state,
            "status": "error",
            "error": f"Exception in extract_study_design: {str(e)}"
        }
278
+
279
def store_in_knowledge_base(state: DocumentExtractionState) -> DocumentExtractionState:
    """Persist extracted entities and vector chunks to the knowledge base.

    Writes document metadata, study info, objectives, endpoints, and
    population criteria to the NoSQL store, then adds the text chunks to
    the vector store. Returns the state with ``status`` set to
    ``completed`` and ``document_id`` populated, or ``warning``/``error``
    plus an ``error`` message.

    Fixes vs. original: dropped the unused ``study_id`` local, and the
    vector-store warning return now still carries ``document_id`` since the
    metadata write succeeded before the failure.
    """
    try:
        # Skip if there was a critical error upstream.
        if state.get("status") == "error":
            return state

        document_metadata = state.get("document_metadata", {})
        study_info = state.get("extracted_study", {})
        objectives = state.get("extracted_objectives", [])
        endpoints = state.get("extracted_endpoints", [])
        population = state.get("extracted_population", {})
        vector_chunks = state.get("vector_chunks", [])
        # NOTE(review): extracted_design is produced upstream but never
        # persisted here — confirm whether the knowledge store should get it.

        # A protocol ID is mandatory for every downstream write.
        protocol_id = study_info.get("protocol_id") or document_metadata.get("protocol_id")
        if not protocol_id:
            return {
                **state,
                "status": "error",
                "error": "No protocol ID available for knowledge base storage"
            }

        document_metadata["protocol_id"] = protocol_id

        # Store in NoSQL DB.
        doc_id = knowledge_store.store_document_metadata(document_metadata)

        if study_info:
            knowledge_store.store_study_info(study_info)

        if objectives:
            knowledge_store.store_objectives(protocol_id, objectives)

        if endpoints:
            knowledge_store.store_endpoints(protocol_id, endpoints)

        if population and "inclusion_criteria" in population:
            inclusion = population.get("inclusion_criteria", [])
            exclusion = population.get("exclusion_criteria", [])

            # Tag each criterion with its type and owning protocol before storage.
            for criterion in inclusion:
                criterion["criterion_type"] = "Inclusion"
                criterion["protocol_id"] = protocol_id
            for criterion in exclusion:
                criterion["criterion_type"] = "Exclusion"
                criterion["protocol_id"] = protocol_id

            knowledge_store.store_population_criteria(protocol_id, inclusion + exclusion)

        # Store in vector store if chunks are available.
        if vector_chunks:
            result = knowledge_store.add_documents(vector_chunks)

            if result.get("status") == "error":
                return {
                    **state,
                    "status": "warning",
                    "document_id": doc_id,  # metadata was stored before the vector failure
                    "error": f"Warning: Failed to add to vector store: {result.get('message')}"
                }

        return {
            **state,
            "status": "completed",
            "document_id": doc_id,
        }
    except Exception as e:
        return {
            **state,
            "status": "error",
            "error": f"Exception in store_in_knowledge_base: {str(e)}"
        }
365
+
366
+ # =========================================================================
367
+ # Protocol Coach Workflow Nodes
368
+ # =========================================================================
369
+
370
def retrieve_context_for_query(state: ProtocolCoachState) -> ProtocolCoachState:
    """Fetch the vector-store chunks most relevant to the user's query.

    Stores a list of {"page_content", "metadata"} dicts under
    ``retrieved_context``; an empty search result records an error note.
    """
    try:
        hits = knowledge_store.similarity_search(
            query=state["query"],
            k=5,  # top 5 most relevant chunks
        )

        if not hits:
            return {
                **state,
                "retrieved_context": [],
                "error": "No relevant context found",
            }

        context = [
            {"page_content": doc.page_content, "metadata": doc.metadata}
            for doc in hits
        ]

        return {**state, "retrieved_context": context}
    except Exception as e:
        return {
            **state,
            "error": f"Exception in retrieve_context_for_query: {str(e)}",
        }
406
+
407
def answer_query(state: ProtocolCoachState) -> ProtocolCoachState:
    """Produce an answer to the user's query from previously retrieved context.

    Falls back to a canned apology when no context or no LLM response is
    available; never raises (exceptions become an error response).
    """
    try:
        query = state["query"]
        context = state.get("retrieved_context", [])

        # Without context we cannot ground an answer; tell the user directly.
        if not context:
            return {
                **state,
                "response": "I don't have enough context to answer that question about the protocol. Please try asking something else or upload relevant documents."
            }

        response = llm_interface.answer_protocol_question(
            question=query,
            context=context,
            chat_history=state.get("chat_history", []),
        )

        if not response:
            return {
                **state,
                "response": "I encountered an issue while generating a response. Please try again."
            }

        return {**state, "response": response}
    except Exception as e:
        return {
            **state,
            "response": f"Error: {str(e)}",
            "error": f"Exception in answer_query: {str(e)}"
        }
443
+
444
+ # =========================================================================
445
+ # Content Authoring Workflow Nodes
446
+ # =========================================================================
447
+
448
def retrieve_content_examples(state: ContentAuthoringState) -> ContentAuthoringState:
    """Look up example passages of the requested section type for authoring.

    When a target protocol is specified its own text is excluded from the
    examples so generation does not copy the document being authored.
    """
    try:
        section_type = state["section_type"]
        target_protocol_id = state.get("target_protocol_id")

        search_query = f"{section_type} section for clinical study protocol"

        # Exclude the target protocol from examples if specified.
        filter_dict = (
            {"protocol_id": {"$ne": target_protocol_id}} if target_protocol_id else None
        )

        hits = knowledge_store.similarity_search(
            query=search_query,
            k=3,
            filter_dict=filter_dict,
        )

        if not hits:
            return {
                **state,
                "retrieved_context": [],
                "error": "No relevant examples found",
            }

        examples = [
            {"page_content": doc.page_content, "metadata": doc.metadata}
            for doc in hits
        ]

        return {**state, "retrieved_context": examples}
    except Exception as e:
        return {
            **state,
            "error": f"Exception in retrieve_content_examples: {str(e)}",
        }
495
+
496
def generate_content(state: ContentAuthoringState) -> ContentAuthoringState:
    """Draft the requested protocol section from retrieved examples via the LLM.

    Stores the draft under ``generated_content``; missing examples or an
    empty LLM result produce a user-facing fallback message plus an error note.
    """
    try:
        section_type = state["section_type"]
        examples = state.get("retrieved_context", [])

        if not examples:
            return {
                **state,
                "generated_content": "I don't have enough examples to generate a good section. Please upload more documents or try a different section type.",
                "error": "No context available for generation"
            }

        content = llm_interface.generate_content_from_knowledge(
            section_type=section_type,
            context=examples,
            protocol_id=state.get("target_protocol_id"),
            style_guide=state.get("style_guide"),
        )

        if not content:
            return {
                **state,
                "generated_content": "I encountered an issue while generating content. Please try again.",
                "error": "Failed to generate content"
            }

        return {**state, "generated_content": content}
    except Exception as e:
        return {
            **state,
            "generated_content": f"Error: {str(e)}",
            "error": f"Exception in generate_content: {str(e)}"
        }
536
+
537
def critique_content(state: ContentAuthoringState) -> ContentAuthoringState:
    """Critique generated content for quality and consistency.

    Placeholder node: passes the authoring state through unchanged. A
    production implementation would ask the LLM to review the generated
    section and attach critique/revisions to the state.
    """
    # This would normally use an LLM to critique content
    # For simplicity, we're returning the content unchanged
    return state
542
+
543
+ # =========================================================================
544
+ # Traceability Workflow Nodes
545
+ # =========================================================================
546
+
547
def retrieve_document_entities(state: TraceabilityState) -> TraceabilityState:
    """Load entities of the requested type from both documents' protocols.

    Supported entity types: "objectives", "endpoints", "population"; any
    other value yields empty entity lists and an error note.
    """
    try:
        entity_type = state["entity_type"]

        source_doc = knowledge_store.get_document_by_id(state["source_document_id"])
        target_doc = knowledge_store.get_document_by_id(state["target_document_id"])

        if not source_doc or not target_doc:
            return {**state, "error": "One or both documents not found"}

        source_protocol_id = source_doc.get("protocol_id")
        target_protocol_id = target_doc.get("protocol_id")

        if not source_protocol_id or not target_protocol_id:
            return {**state, "error": "Protocol ID missing from one or both documents"}

        # Dispatch on entity type; unknown types fall through to empty lists.
        method_names = {
            "objectives": "get_objectives_by_protocol_id",
            "endpoints": "get_endpoints_by_protocol_id",
            "population": "get_population_criteria_by_protocol_id",
        }
        source_entities = []
        target_entities = []
        method_name = method_names.get(entity_type)
        if method_name:
            fetch = getattr(knowledge_store, method_name)
            source_entities = fetch(source_protocol_id)
            target_entities = fetch(target_protocol_id)

        if not source_entities or not target_entities:
            return {**state, "error": f"No {entity_type} found in one or both documents"}

        return {
            **state,
            "source_entities": source_entities,
            "target_entities": target_entities,
        }
    except Exception as e:
        return {**state, "error": f"Exception in retrieve_document_entities: {str(e)}"}
604
+
605
def match_entities(state: "TraceabilityState") -> "TraceabilityState":
    """Pair each source entity with the first comparable target entity.

    For every source entity carrying a non-empty description/text, the
    first target entity that also has non-empty text is taken as its match.
    This is a deliberately simple placeholder; a production system would
    rank candidates by semantic similarity instead of taking the first.

    Fixes vs. original: removed the always-true ``len(...) > 0`` re-check
    (both texts are already known non-empty after the ``continue`` guards)
    and breaks at the first viable target instead of collecting every pair
    only to keep ``matches[0]``.

    Returns the state with ``matched_pairs`` added, or unchanged when a
    previous node recorded an ``error``.
    """
    try:
        if "error" in state:
            return state

        source_entities = state.get("source_entities", [])
        target_entities = state.get("target_entities", [])

        matched_pairs = []
        for source_entity in source_entities:
            source_text = source_entity.get("description", source_entity.get("text", ""))
            if not source_text:
                continue  # nothing to compare on the source side

            for target_entity in target_entities:
                target_text = target_entity.get("description", target_entity.get("text", ""))
                if not target_text:
                    continue

                # First viable target wins.
                matched_pairs.append({
                    "source_entity": source_entity,
                    "target_entity": target_entity,
                    "source_text": source_text,
                    "target_text": target_text,
                    "entity_type": state["entity_type"],
                })
                break

        return {**state, "matched_pairs": matched_pairs}
    except Exception as e:
        return {**state, "error": f"Exception in match_entities: {str(e)}"}
651
+
652
def analyze_matches(state: TraceabilityState) -> TraceabilityState:
    """Ask the LLM to assess consistency across the matched entity pairs.

    Stores the LLM's assessment under ``analysis``; an empty pair list gets
    a fixed no-matches message instead.
    """
    try:
        if "error" in state:
            return state

        matched_pairs = state.get("matched_pairs", [])
        source_doc_id = state["source_document_id"]
        target_doc_id = state["target_document_id"]

        if not matched_pairs:
            return {
                **state,
                "analysis": "No matching entities found between the documents."
            }

        # Re-fetch metadata so the LLM sees document-level context alongside pairs.
        source_doc = knowledge_store.get_document_by_id(source_doc_id)
        target_doc = knowledge_store.get_document_by_id(target_doc_id)

        analysis = llm_interface.find_document_connections(
            source_doc_info=source_doc,
            target_doc_info=target_doc,
            entity_pairs=matched_pairs,
        )

        return {**state, "analysis": analysis}
    except Exception as e:
        return {
            **state,
            "error": f"Exception in analyze_matches: {str(e)}",
            "analysis": f"Error analyzing matches: {str(e)}"
        }
689
+
690
+ # =========================================================================
691
+ # Graph Building Functions
692
+ # =========================================================================
693
+
694
def build_document_extraction_graph():
    """Build and compile the sequential document-extraction workflow graph.

    After every node a router inspects ``status`` and short-circuits to END
    on error instead of running the remaining extraction steps.

    Fixes vs. original: the original stacked an unconditional edge AND a
    conditional edge on every node, and its conditional mapping used
    ``False: None`` — ``None`` is not a valid LangGraph destination, so the
    graph could not be built/run. Each node now has exactly one conditional
    route (error -> END, otherwise -> next node).
    """
    workflow = StateGraph(DocumentExtractionState)

    # Nodes in pipeline order.
    pipeline = [
        ("parse_document", parse_document),
        ("extract_study_info", extract_study_info),
        ("extract_objectives_endpoints", extract_objectives_endpoints),
        ("extract_population_criteria", extract_population_criteria),
        ("extract_study_design", extract_study_design),
        ("store_in_knowledge_base", store_in_knowledge_base),
    ]
    for name, fn in pipeline:
        workflow.add_node(name, fn)

    def _route_after(next_node):
        """Build a router: abort to END when the node reported an error."""
        def route(state):
            return END if state.get("status") == "error" else next_node
        return route

    names = [name for name, _ in pipeline]
    for current, nxt in zip(names, names[1:]):
        workflow.add_conditional_edges(current, _route_after(nxt), {nxt: nxt, END: END})
    # Final node always terminates the graph.
    workflow.add_edge(names[-1], END)

    workflow.set_entry_point(names[0])
    return workflow.compile()
728
+
729
def build_protocol_coach_graph():
    """Build and compile the two-step protocol coach graph (retrieve -> answer)."""
    workflow = StateGraph(ProtocolCoachState)

    # Register the two stages.
    workflow.add_node("retrieve_context", retrieve_context_for_query)
    workflow.add_node("answer_query", answer_query)

    # Wire them in order and terminate.
    workflow.set_entry_point("retrieve_context")
    workflow.add_edge("retrieve_context", "answer_query")
    workflow.add_edge("answer_query", END)

    return workflow.compile()
743
+
744
def build_content_authoring_graph():
    """Build and compile the authoring graph: retrieve -> generate -> critique."""
    workflow = StateGraph(ContentAuthoringState)

    steps = [
        ("retrieve_examples", retrieve_content_examples),
        ("generate_content", generate_content),
        ("critique_content", critique_content),
    ]
    for name, fn in steps:
        workflow.add_node(name, fn)

    # Chain the steps sequentially, then terminate after the last one.
    names = [name for name, _ in steps]
    for src, dst in zip(names, names[1:]):
        workflow.add_edge(src, dst)
    workflow.add_edge(names[-1], END)

    workflow.set_entry_point(names[0])
    return workflow.compile()
760
+
761
def build_traceability_graph():
    """Build and compile the traceability graph: retrieve -> match -> analyze."""
    workflow = StateGraph(TraceabilityState)

    steps = [
        ("retrieve_entities", retrieve_document_entities),
        ("match_entities", match_entities),
        ("analyze_matches", analyze_matches),
    ]
    for name, fn in steps:
        workflow.add_node(name, fn)

    # Chain the steps sequentially, then terminate after the last one.
    names = [name for name, _ in steps]
    for src, dst in zip(names, names[1:]):
        workflow.add_edge(src, dst)
    workflow.add_edge(names[-1], END)

    workflow.set_entry_point(names[0])
    return workflow.compile()