# NOTE: web-viewer scrape artifacts (page chrome, commit hashes, and the
# line-number gutter) were removed from the top of this file.
"""
LangGraph agent orchestration for document processing, content authoring, and protocol coach.
"""
from langgraph.graph import StateGraph, END
from typing import TypedDict, Dict, List, Any, Optional, Literal, Annotated, cast
import operator
import uuid
from schemas import DocumentExtractionState, ProtocolCoachState, ContentAuthoringState, TraceabilityState
from pdf_processor import PDFProcessor
from knowledge_store import KnowledgeStore
from llm_interface import LLMInterface
# Module-level handler singletons. The workflow node functions below read
# these globals directly; call init_handlers() before invoking any graph.
pdf_processor = None
knowledge_store = None
llm_interface = None


def init_handlers(api_key=None):
    """Initialize handlers for PDF processing, knowledge store, and LLM.

    Assigns the module globals (pdf_processor, knowledge_store,
    llm_interface) in that order and returns them as a tuple.

    Args:
        api_key: Optional API key forwarded to the LLM interface.

    Returns:
        Tuple of (pdf_processor, knowledge_store, llm_interface).
    """
    global pdf_processor, knowledge_store, llm_interface
    pdf_processor = PDFProcessor()
    knowledge_store = KnowledgeStore()
    llm_interface = LLMInterface(api_key=api_key)
    return pdf_processor, knowledge_store, llm_interface
# =========================================================================
# Document Extraction Workflow Nodes
# =========================================================================
def parse_document(state: "DocumentExtractionState") -> "DocumentExtractionState":
    """Run the PDF processor over the file at state["document_path"].

    On success, merges the extracted full text, metadata, sections, and
    vector chunks into the state and marks it "parsed"; on any failure
    the state comes back with status "error" and a message.
    """
    try:
        outcome = pdf_processor.process_complete_document(state["document_path"])
        if outcome["status"] == "error":
            reason = outcome.get("error", "Unknown error")
            return {
                **state,
                "status": "error",
                "error": f"Failed to parse document: {reason}",
            }
        return {
            **state,
            "document_text": outcome.get("full_text", ""),
            "document_metadata": outcome.get("metadata", {}),
            "sections": outcome.get("sections", {}),
            "vector_chunks": outcome.get("chunks", []),
            "status": "parsed",
        }
    except Exception as e:
        return {
            **state,
            "status": "error",
            "error": f"Exception in parse_document: {str(e)}",
        }
def extract_study_info(state: "DocumentExtractionState") -> "DocumentExtractionState":
    """Pull high-level study information out of the parsed document.

    Prefers a synopsis/summary/overview section when section content is
    available; otherwise falls back to the first 20k characters of the
    raw document text. Upstream errors pass through unchanged.
    """
    if state.get("status") == "error":
        return state
    try:
        sections = state.get("sections", {})
        excerpt = ""
        if isinstance(sections, list):
            # Only section names are known here — use the raw document text.
            if "document_text" in state:
                excerpt = state["document_text"][:20000]
        else:
            # Prefer a synopsis-like section, matched case-insensitively.
            # The first key that matches (in document order) wins.
            for wanted in ("synopsis", "summary", "overview"):
                key = next((k for k in sections if k.lower() == wanted), None)
                if key is not None:
                    excerpt = sections[key]
                    break
            if not excerpt and "document_text" in state:
                excerpt = state["document_text"][:20000]
        if not excerpt:
            return {
                **state,
                "status": "error",
                "error": "No text available for study info extraction",
            }
        study_info = llm_interface.extract_study_info(excerpt)
        if not study_info:
            return {
                **state,
                "status": "error",
                "error": "Failed to extract study information",
            }
        # Backfill the protocol id from document metadata when the LLM
        # did not return one.
        if "protocol_id" not in study_info and "document_metadata" in state:
            study_info["protocol_id"] = state["document_metadata"].get("protocol_id")
        return {**state, "extracted_study": study_info, "status": "study_extracted"}
    except Exception as e:
        return {
            **state,
            "status": "error",
            "error": f"Exception in extract_study_info: {str(e)}",
        }
def extract_objectives_endpoints(state: "DocumentExtractionState") -> "DocumentExtractionState":
    """Extract study objectives and endpoints via the LLM.

    Needs a protocol id (from the extracted study info, falling back to
    document metadata) and a section whose name mentions objectives or
    endpoints. A missing section or empty LLM result is a "warning";
    a missing protocol id is an "error". Upstream errors pass through.
    """
    if state.get("status") == "error":
        return state
    try:
        sections = state.get("sections", {})
        protocol_id = (
            state.get("extracted_study", {}).get("protocol_id")
            or state.get("document_metadata", {}).get("protocol_id")
        )
        if not protocol_id:
            return {
                **state,
                "status": "error",
                "error": "No protocol ID available for extraction",
            }
        excerpt = ""
        for needle in ("objectives", "objective", "endpoint", "endpoints"):
            # First key whose name contains the needle, in document order.
            hit = next((k for k in sections if needle in k.lower()), None)
            if hit is not None:
                excerpt = sections[hit]
            if excerpt:
                break
        if not excerpt:
            return {
                **state,
                "status": "warning",
                "error": "No objectives/endpoints section found",
            }
        result = llm_interface.extract_objectives_and_endpoints(excerpt, protocol_id)
        if not result:
            return {
                **state,
                "status": "warning",
                "error": "Failed to extract objectives and endpoints",
            }
        return {
            **state,
            "extracted_objectives": result.get("objectives", []),
            "extracted_endpoints": result.get("endpoints", []),
            "status": "objectives_endpoints_extracted",
        }
    except Exception as e:
        return {
            **state,
            "status": "error",
            "error": f"Exception in extract_objectives_endpoints: {str(e)}",
        }
def extract_population_criteria(state: "DocumentExtractionState") -> "DocumentExtractionState":
    """Extract inclusion and exclusion criteria using the LLM.

    Finds an eligibility/criteria/population section by fuzzy name match
    and hands it to the LLM with the protocol id. A missing section or
    empty LLM result yields a "warning" status; a missing protocol id is
    an "error". Upstream errors pass through unchanged.
    """
    if state.get("status") == "error":
        return state
    try:
        sections = state.get("sections", {})
        protocol_id = state.get("extracted_study", {}).get("protocol_id")
        if not protocol_id:
            protocol_id = state.get("document_metadata", {}).get("protocol_id")
        # Fix: previously a missing protocol id was silently passed through
        # as None. Mirror extract_objectives_endpoints — criteria cannot be
        # keyed in the knowledge base without it (store_in_knowledge_base
        # errors on a missing protocol id anyway).
        if not protocol_id:
            return {
                **state,
                "status": "error",
                "error": "No protocol ID available for extraction",
            }
        # Find the criteria section by fuzzy name match.
        text_for_extraction = ""
        for section_name in ("eligibility", "inclusion", "exclusion", "criteria", "population"):
            for key in sections:
                if section_name in key.lower():
                    text_for_extraction = sections[key]
                    break
            if text_for_extraction:
                break
        if not text_for_extraction:
            return {
                **state,
                "status": "warning",
                "error": "No population criteria section found",
            }
        result = llm_interface.extract_population_criteria(text_for_extraction, protocol_id)
        if not result:
            return {
                **state,
                "status": "warning",
                "error": "Failed to extract population criteria",
            }
        return {**state, "extracted_population": result, "status": "population_extracted"}
    except Exception as e:
        return {
            **state,
            "status": "error",
            "error": f"Exception in extract_population_criteria: {str(e)}",
        }
def extract_study_design(state: "DocumentExtractionState") -> "DocumentExtractionState":
    """Extract study design information using the LLM.

    Finds a design/methodology section by fuzzy name match and hands it
    to the LLM with the protocol id. A missing section or empty LLM
    result yields a "warning" status; a missing protocol id is an
    "error". Upstream errors pass through unchanged.
    """
    if state.get("status") == "error":
        return state
    try:
        sections = state.get("sections", {})
        protocol_id = state.get("extracted_study", {}).get("protocol_id")
        if not protocol_id:
            protocol_id = state.get("document_metadata", {}).get("protocol_id")
        # Fix: previously a missing protocol id was silently passed through
        # as None. Mirror extract_objectives_endpoints for consistent
        # validation across the extraction nodes.
        if not protocol_id:
            return {
                **state,
                "status": "error",
                "error": "No protocol ID available for extraction",
            }
        # Find the study design section by fuzzy name match.
        text_for_extraction = ""
        for section_name in ("study design", "design", "methodology"):
            for key in sections:
                if section_name in key.lower():
                    text_for_extraction = sections[key]
                    break
            if text_for_extraction:
                break
        if not text_for_extraction:
            return {
                **state,
                "status": "warning",
                "error": "No study design section found",
            }
        result = llm_interface.extract_study_design(text_for_extraction, protocol_id)
        if not result:
            return {
                **state,
                "status": "warning",
                "error": "Failed to extract study design",
            }
        return {**state, "extracted_design": result, "status": "design_extracted"}
    except Exception as e:
        return {
            **state,
            "status": "error",
            "error": f"Exception in extract_study_design: {str(e)}",
        }
def store_in_knowledge_base(state: "DocumentExtractionState") -> "DocumentExtractionState":
    """Persist the extracted protocol data into the knowledge base.

    Writes document metadata, study info, objectives, endpoints, and
    population criteria to the NoSQL store, and the text chunks to the
    vector store. Returns the state marked "completed" with the stored
    document id, "warning" if only the vector-store write failed, or
    "error" when no protocol id is available. Upstream errors pass
    through unchanged.
    """
    try:
        # Skip if an upstream node already failed.
        if state.get("status") == "error":
            return state
        document_metadata = state.get("document_metadata", {})
        study_info = state.get("extracted_study", {})
        objectives = state.get("extracted_objectives", [])
        endpoints = state.get("extracted_endpoints", [])
        population = state.get("extracted_population", {})
        vector_chunks = state.get("vector_chunks", [])
        # NOTE(review): extracted_design is produced upstream but never
        # persisted here — confirm whether the knowledge store should
        # receive it too. (Fix: dropped the unused `design` local and the
        # unused `study_id` binding below.)
        protocol_id = study_info.get("protocol_id") or document_metadata.get("protocol_id")
        if not protocol_id:
            return {
                **state,
                "status": "error",
                "error": "No protocol ID available for knowledge base storage",
            }
        # Keep the stored metadata keyed by protocol id.
        document_metadata["protocol_id"] = protocol_id
        doc_id = knowledge_store.store_document_metadata(document_metadata)
        if study_info:
            knowledge_store.store_study_info(study_info)
        if objectives:
            knowledge_store.store_objectives(protocol_id, objectives)
        if endpoints:
            knowledge_store.store_endpoints(protocol_id, endpoints)
        if population and "inclusion_criteria" in population:
            inclusion = population.get("inclusion_criteria", [])
            exclusion = population.get("exclusion_criteria", [])
            # Tag each criterion in place with its type and protocol so the
            # store can distinguish inclusion from exclusion.
            for criterion in inclusion:
                criterion["criterion_type"] = "Inclusion"
                criterion["protocol_id"] = protocol_id
            for criterion in exclusion:
                criterion["criterion_type"] = "Exclusion"
                criterion["protocol_id"] = protocol_id
            knowledge_store.store_population_criteria(protocol_id, inclusion + exclusion)
        if vector_chunks:
            result = knowledge_store.add_documents(vector_chunks)
            if result.get("status") == "error":
                return {
                    **state,
                    "status": "warning",
                    "error": f"Warning: Failed to add to vector store: {result.get('message')}",
                }
        return {
            **state,
            "status": "completed",
            "document_id": doc_id,
        }
    except Exception as e:
        return {
            **state,
            "status": "error",
            "error": f"Exception in store_in_knowledge_base: {str(e)}",
        }
# =========================================================================
# Protocol Coach Workflow Nodes
# =========================================================================
def retrieve_context_for_query(state: "ProtocolCoachState") -> "ProtocolCoachState":
    """Fetch the protocol chunks most relevant to the user's question.

    Runs a top-5 similarity search against the vector store and stores
    the hits (page content plus metadata) under "retrieved_context".
    When nothing is found, records an empty context and an error note.
    """
    try:
        hits = knowledge_store.similarity_search(query=state["query"], k=5)
        if not hits:
            return {
                **state,
                "retrieved_context": [],
                "error": "No relevant context found",
            }
        # Flatten the document objects into plain dicts for downstream use.
        context = [
            {"page_content": doc.page_content, "metadata": doc.metadata}
            for doc in hits
        ]
        return {**state, "retrieved_context": context}
    except Exception as e:
        return {
            **state,
            "error": f"Exception in retrieve_context_for_query: {str(e)}",
        }
def answer_query(state: "ProtocolCoachState") -> "ProtocolCoachState":
    """Answer the user's question from the previously retrieved context.

    With no context available, returns a canned "not enough context"
    response without calling the LLM; otherwise asks the LLM and falls
    back to a generic retry message if it returns nothing.
    """
    try:
        question = state["query"]
        context = state.get("retrieved_context", [])
        history = state.get("chat_history", [])
        if not context:
            return {
                **state,
                "response": "I don't have enough context to answer that question about the protocol. Please try asking something else or upload relevant documents.",
            }
        reply = llm_interface.answer_protocol_question(
            question=question,
            context=context,
            chat_history=history,
        )
        if not reply:
            return {
                **state,
                "response": "I encountered an issue while generating a response. Please try again.",
            }
        return {**state, "response": reply}
    except Exception as e:
        return {
            **state,
            "response": f"Error: {str(e)}",
            "error": f"Exception in answer_query: {str(e)}",
        }
# =========================================================================
# Content Authoring Workflow Nodes
# =========================================================================
def retrieve_content_examples(state: "ContentAuthoringState") -> "ContentAuthoringState":
    """Look up example passages matching the requested section type.

    Searches the vector store for up to three chunks resembling the
    section being authored. When a target protocol is given, its own
    chunks are excluded so the examples come from other protocols.
    """
    try:
        section_type = state["section_type"]
        target_protocol_id = state.get("target_protocol_id")
        search_query = f"{section_type} section for clinical study protocol"
        # Only filter when a target protocol is specified.
        filter_dict = (
            {"protocol_id": {"$ne": target_protocol_id}}
            if target_protocol_id
            else None
        )
        hits = knowledge_store.similarity_search(
            query=search_query,
            k=3,
            filter_dict=filter_dict,
        )
        if not hits:
            return {
                **state,
                "retrieved_context": [],
                "error": "No relevant examples found",
            }
        examples = [
            {"page_content": doc.page_content, "metadata": doc.metadata}
            for doc in hits
        ]
        return {**state, "retrieved_context": examples}
    except Exception as e:
        return {
            **state,
            "error": f"Exception in retrieve_content_examples: {str(e)}",
        }
def generate_content(state: "ContentAuthoringState") -> "ContentAuthoringState":
    """Draft the requested section from retrieved examples via the LLM.

    Requires examples in "retrieved_context"; without them (or if the LLM
    returns nothing) a user-facing message is stored in
    "generated_content" alongside an error note.
    """
    try:
        section_type = state["section_type"]
        examples = state.get("retrieved_context", [])
        target_protocol_id = state.get("target_protocol_id")
        style_guide = state.get("style_guide")
        if not examples:
            return {
                **state,
                "generated_content": "I don't have enough examples to generate a good section. Please upload more documents or try a different section type.",
                "error": "No context available for generation",
            }
        draft = llm_interface.generate_content_from_knowledge(
            section_type=section_type,
            context=examples,
            protocol_id=target_protocol_id,
            style_guide=style_guide,
        )
        if not draft:
            return {
                **state,
                "generated_content": "I encountered an issue while generating content. Please try again.",
                "error": "Failed to generate content",
            }
        return {**state, "generated_content": draft}
    except Exception as e:
        return {
            **state,
            "generated_content": f"Error: {str(e)}",
            "error": f"Exception in generate_content: {str(e)}",
        }
def critique_content(state: ContentAuthoringState) -> ContentAuthoringState:
    """Critique generated content for quality and consistency.

    Placeholder node in the content-authoring graph: a full
    implementation would ask the LLM to review the draft held in
    state["generated_content"]. Currently a deliberate no-op that
    returns the state unchanged.
    """
    return state
# =========================================================================
# Traceability Workflow Nodes
# =========================================================================
def retrieve_document_entities(state: "TraceabilityState") -> "TraceabilityState":
    """Load matching entities from the source and target documents.

    Resolves both documents to their protocol ids, then fetches entities
    of the requested type ("objectives", "endpoints", or "population")
    for each side. Any missing piece records an "error" key on the state.
    """
    try:
        source_doc_id = state["source_document_id"]
        target_doc_id = state["target_document_id"]
        entity_type = state["entity_type"]
        source_doc = knowledge_store.get_document_by_id(source_doc_id)
        target_doc = knowledge_store.get_document_by_id(target_doc_id)
        if not (source_doc and target_doc):
            return {**state, "error": "One or both documents not found"}
        source_protocol_id = source_doc.get("protocol_id")
        target_protocol_id = target_doc.get("protocol_id")
        if not (source_protocol_id and target_protocol_id):
            return {**state, "error": "Protocol ID missing from one or both documents"}
        # Dispatch table: entity type -> knowledge-store lookup.
        getters = {
            "objectives": knowledge_store.get_objectives_by_protocol_id,
            "endpoints": knowledge_store.get_endpoints_by_protocol_id,
            "population": knowledge_store.get_population_criteria_by_protocol_id,
        }
        getter = getters.get(entity_type)
        source_entities = getter(source_protocol_id) if getter else []
        target_entities = getter(target_protocol_id) if getter else []
        if not (source_entities and target_entities):
            return {**state, "error": f"No {entity_type} found in one or both documents"}
        return {
            **state,
            "source_entities": source_entities,
            "target_entities": target_entities,
        }
    except Exception as e:
        return {
            **state,
            "error": f"Exception in retrieve_document_entities: {str(e)}",
        }
def match_entities(state: "TraceabilityState") -> "TraceabilityState":
    """Pair each source entity with the first target entity that has text.

    Naive placeholder matching: any pair of non-empty texts counts as a
    match and only the first candidate per source entity is kept — a real
    implementation would score similarity (e.g. via the LLM). Entity text
    is taken from "description", falling back to "text". States carrying
    an "error" key pass through unchanged.

    Fixes: removed the dead `len(...) > 0` re-check that duplicated the
    emptiness guard, and stopped collecting every candidate match per
    source entity only to keep matches[0] — we now stop at the first.
    """
    try:
        if "error" in state:
            return state
        source_entities = state.get("source_entities", [])
        target_entities = state.get("target_entities", [])
        matched_pairs = []
        for source_entity in source_entities:
            for target_entity in target_entities:
                source_text = source_entity.get("description", source_entity.get("text", ""))
                target_text = target_entity.get("description", target_entity.get("text", ""))
                if not source_text or not target_text:
                    continue
                matched_pairs.append({
                    "source_entity": source_entity,
                    "target_entity": target_entity,
                    "source_text": source_text,
                    "target_text": target_text,
                    "entity_type": state["entity_type"],
                })
                break  # keep only the first match per source entity
        return {**state, "matched_pairs": matched_pairs}
    except Exception as e:
        return {
            **state,
            "error": f"Exception in match_entities: {str(e)}",
        }
def analyze_matches(state: "TraceabilityState") -> "TraceabilityState":
    """Ask the LLM to assess consistency across the matched entity pairs.

    With no matched pairs, stores a fixed "no matches" analysis without
    touching the knowledge store or LLM. States carrying an "error" key
    pass through unchanged; exceptions record both "error" and a
    user-facing "analysis" message.
    """
    try:
        if "error" in state:
            return state
        matched_pairs = state.get("matched_pairs", [])
        source_doc_id = state["source_document_id"]
        target_doc_id = state["target_document_id"]
        if not matched_pairs:
            return {
                **state,
                "analysis": "No matching entities found between the documents.",
            }
        source_doc = knowledge_store.get_document_by_id(source_doc_id)
        target_doc = knowledge_store.get_document_by_id(target_doc_id)
        analysis = llm_interface.find_document_connections(
            source_doc_info=source_doc,
            target_doc_info=target_doc,
            entity_pairs=matched_pairs,
        )
        return {**state, "analysis": analysis}
    except Exception as e:
        return {
            **state,
            "error": f"Exception in analyze_matches: {str(e)}",
            "analysis": f"Error analyzing matches: {str(e)}",
        }
# =========================================================================
# Graph Building Functions
# =========================================================================
def build_document_extraction_graph():
    """Compile the linear document-extraction pipeline.

    parse -> study info -> objectives/endpoints -> population -> design
    -> store. Each node checks for an upstream error status itself, so
    no conditional edges are needed.
    """
    workflow = StateGraph(DocumentExtractionState)
    pipeline = [
        ("parse_document", parse_document),
        ("extract_study_info", extract_study_info),
        ("extract_objectives_endpoints", extract_objectives_endpoints),
        ("extract_population_criteria", extract_population_criteria),
        ("extract_study_design", extract_study_design),
        ("store_in_knowledge_base", store_in_knowledge_base),
    ]
    for node_name, node_fn in pipeline:
        workflow.add_node(node_name, node_fn)
    # Chain the stages sequentially, ending at END.
    for (current, _), (following, _) in zip(pipeline, pipeline[1:]):
        workflow.add_edge(current, following)
    workflow.add_edge(pipeline[-1][0], END)
    workflow.set_entry_point(pipeline[0][0])
    return workflow.compile()
def build_protocol_coach_graph():
    """Compile the two-step protocol-coach workflow.

    retrieve_context (vector search) feeds answer_query (LLM response).
    """
    graph = StateGraph(ProtocolCoachState)
    graph.add_node("retrieve_context", retrieve_context_for_query)
    graph.add_node("answer_query", answer_query)
    graph.add_edge("retrieve_context", "answer_query")
    graph.add_edge("answer_query", END)
    graph.set_entry_point("retrieve_context")
    return graph.compile()
def build_content_authoring_graph():
    """Compile the content-authoring workflow.

    retrieve_examples -> generate_content -> critique_content (currently
    a pass-through) -> END.
    """
    graph = StateGraph(ContentAuthoringState)
    stages = [
        ("retrieve_examples", retrieve_content_examples),
        ("generate_content", generate_content),
        ("critique_content", critique_content),
    ]
    for label, node_fn in stages:
        graph.add_node(label, node_fn)
    graph.add_edge("retrieve_examples", "generate_content")
    graph.add_edge("generate_content", "critique_content")
    graph.add_edge("critique_content", END)
    graph.set_entry_point("retrieve_examples")
    return graph.compile()
def build_traceability_graph():
    """Compile the traceability-analysis workflow.

    retrieve_entities -> match_entities -> analyze_matches -> END.
    """
    graph = StateGraph(TraceabilityState)
    stages = [
        ("retrieve_entities", retrieve_document_entities),
        ("match_entities", match_entities),
        ("analyze_matches", analyze_matches),
    ]
    for label, node_fn in stages:
        graph.add_node(label, node_fn)
    graph.add_edge("retrieve_entities", "match_entities")
    graph.add_edge("match_entities", "analyze_matches")
    graph.add_edge("analyze_matches", END)
    graph.set_entry_point("retrieve_entities")
    return graph.compile()