"""
LangGraph agent orchestration for document processing, content authoring, and protocol coach.
"""

from langgraph.graph import StateGraph, END
from typing import TypedDict, Dict, List, Any, Optional, Literal, Annotated, cast
import operator
import uuid

from schemas import DocumentExtractionState, ProtocolCoachState, ContentAuthoringState, TraceabilityState
from pdf_processor import PDFProcessor
from knowledge_store import KnowledgeStore
from llm_interface import LLMInterface

# Initialize handlers.
# These module-level handles start as None and are populated by init_handlers();
# the workflow nodes below read them as globals, so init_handlers() must be
# called before any graph is invoked.
pdf_processor = None
knowledge_store = None
llm_interface = None

# Upper bound on raw document characters sent to the LLM for study-info
# extraction when no synopsis-like section is available.
_STUDY_INFO_CHAR_LIMIT = 20000


def init_handlers(api_key=None):
    """Initialize handlers for PDF processing, knowledge store, and LLM.

    Args:
        api_key: Optional API key forwarded to ``LLMInterface``.

    Returns:
        Tuple ``(pdf_processor, knowledge_store, llm_interface)`` holding the
        freshly created handlers (also stored in the module globals).
    """
    global pdf_processor, knowledge_store, llm_interface
    pdf_processor = PDFProcessor()
    knowledge_store = KnowledgeStore()
    llm_interface = LLMInterface(api_key=api_key)
    return pdf_processor, knowledge_store, llm_interface


# =========================================================================
# Document Extraction Helpers
# =========================================================================

def _find_section_text(sections: Any, keywords: List[str], exact: bool = False) -> Any:
    """Return the content of the first non-empty section matching one of *keywords*.

    Keywords are tried in priority order. With ``exact=False`` a section matches
    when the keyword is a case-insensitive substring of its name; with
    ``exact=True`` the whole name must match case-insensitively. Empty sections
    are skipped so they cannot shadow a populated one.

    Returns "" when *sections* is not a dict (e.g. a bare list of section names,
    which carries no content) or when nothing matches.
    """
    if not isinstance(sections, dict):
        return ""
    for keyword in keywords:
        needle = keyword.lower()
        for name, content in sections.items():
            lowered = str(name).lower()
            is_match = lowered == needle if exact else needle in lowered
            if is_match and content:
                return content
    return ""


def _resolve_protocol_id(state: DocumentExtractionState) -> Optional[Any]:
    """Return the protocol ID from the extracted study info, else from document metadata."""
    study = state.get("extracted_study") or {}
    metadata = state.get("document_metadata") or {}
    return study.get("protocol_id") or metadata.get("protocol_id")


def _tag_population_criteria(population: Any, protocol_id: Any) -> List[Any]:
    """Return inclusion + exclusion criteria annotated with type and protocol ID.

    Each criterion is updated in place with ``criterion_type`` ("Inclusion" or
    "Exclusion") and ``protocol_id``. Missing or ``None`` criterion lists are
    treated as empty. Inclusion criteria come first in the returned list.
    """
    if not isinstance(population, dict):
        return []
    tagged = []
    for list_key, criterion_type in (
        ("inclusion_criteria", "Inclusion"),
        ("exclusion_criteria", "Exclusion"),
    ):
        for criterion in population.get(list_key) or []:
            criterion["criterion_type"] = criterion_type
            criterion["protocol_id"] = protocol_id
            tagged.append(criterion)
    return tagged


# =========================================================================
# Document Extraction Workflow Nodes
# =========================================================================

def parse_document(state: DocumentExtractionState) -> DocumentExtractionState:
    """Parse the PDF at ``state["document_path"]`` and extract its text.

    On success sets ``document_text``, ``document_metadata``, ``sections``,
    ``vector_chunks`` and ``status="parsed"``. On failure sets
    ``status="error"`` and an ``error`` message; no exception escapes.
    """
    try:
        document_path = state["document_path"]

        # Process document with PDFProcessor
        result = pdf_processor.process_complete_document(document_path)

        # Treat a missing/empty result like an explicit error report.
        if not result or result.get("status") == "error":
            error_detail = (result or {}).get("error", "Unknown error")
            return {
                **state,
                "status": "error",
                "error": f"Failed to parse document: {error_detail}",
            }

        return {
            **state,
            "document_text": result.get("full_text", ""),
            "document_metadata": result.get("metadata", {}),
            "sections": result.get("sections", {}),
            "vector_chunks": result.get("chunks", []),
            "status": "parsed",
        }
    except Exception as e:
        return {
            **state,
            "status": "error",
            "error": f"Exception in parse_document: {str(e)}",
        }


def extract_study_info(state: DocumentExtractionState) -> DocumentExtractionState:
    """Extract study information using the LLM.

    Prefers a section named exactly "synopsis", "summary" or "overview"
    (case-insensitive, in that order). Otherwise falls back to the first
    ``_STUDY_INFO_CHAR_LIMIT`` characters of the full document text. A missing
    or empty ``protocol_id`` in the LLM result is backfilled from the document
    metadata when available.
    """
    if state.get("status") == "error":
        return state

    try:
        # List-shaped sections carry names only; the helper returns "" for them,
        # so the document-text fallback below is used.
        text_for_extraction = _find_section_text(
            state.get("sections", {}),
            ["synopsis", "summary", "overview"],
            exact=True,
        )

        # If no synopsis found, use the beginning of the document
        if not text_for_extraction and "document_text" in state:
            text_for_extraction = (state.get("document_text") or "")[:_STUDY_INFO_CHAR_LIMIT]

        if not text_for_extraction:
            return {
                **state,
                "status": "error",
                "error": "No text available for study info extraction",
            }

        # Extract study info using LLM
        study_info = llm_interface.extract_study_info(text_for_extraction)

        if not study_info:
            return {
                **state,
                "status": "error",
                "error": "Failed to extract study information",
            }

        # Ensure protocol_id is populated, falling back to the document metadata.
        if not study_info.get("protocol_id"):
            fallback_id = (state.get("document_metadata") or {}).get("protocol_id")
            if fallback_id:
                study_info["protocol_id"] = fallback_id

        return {
            **state,
            "extracted_study": study_info,
            "status": "study_extracted",
        }
    except Exception as e:
        return {
            **state,
            "status": "error",
            "error": f"Exception in extract_study_info: {str(e)}",
        }


def extract_objectives_endpoints(state: DocumentExtractionState) -> DocumentExtractionState:
    """Extract objectives and endpoints using the LLM.

    A missing protocol ID is a hard error. A missing section or an empty LLM
    result yields ``status="warning"`` so the pipeline continues.
    """
    if state.get("status") == "error":
        return state

    try:
        protocol_id = _resolve_protocol_id(state)
        if not protocol_id:
            return {
                **state,
                "status": "error",
                "error": "No protocol ID available for extraction",
            }

        # Find objectives/endpoints section (substring match on section names)
        text_for_extraction = _find_section_text(
            state.get("sections", {}),
            ["objectives", "objective", "endpoint", "endpoints"],
        )

        if not text_for_extraction:
            return {
                **state,
                "status": "warning",
                "error": "No objectives/endpoints section found",
            }

        # Extract objectives and endpoints
        result = llm_interface.extract_objectives_and_endpoints(text_for_extraction, protocol_id)

        if not result:
            return {
                **state,
                "status": "warning",
                "error": "Failed to extract objectives and endpoints",
            }

        return {
            **state,
            "extracted_objectives": result.get("objectives", []),
            "extracted_endpoints": result.get("endpoints", []),
            "status": "objectives_endpoints_extracted",
        }
    except Exception as e:
        return {
            **state,
            "status": "error",
            "error": f"Exception in extract_objectives_endpoints: {str(e)}",
        }


def extract_population_criteria(state: DocumentExtractionState) -> DocumentExtractionState:
    """Extract inclusion and exclusion criteria using the LLM.

    A missing section or an empty LLM result yields ``status="warning"``.
    The protocol ID may be None here; it is passed through to the LLM as-is.
    """
    if state.get("status") == "error":
        return state

    try:
        protocol_id = _resolve_protocol_id(state)

        # Find criteria section (substring match on section names)
        text_for_extraction = _find_section_text(
            state.get("sections", {}),
            ["eligibility", "inclusion", "exclusion", "criteria", "population"],
        )

        if not text_for_extraction:
            return {
                **state,
                "status": "warning",
                "error": "No population criteria section found",
            }

        # Extract criteria
        result = llm_interface.extract_population_criteria(text_for_extraction, protocol_id)

        if not result:
            return {
                **state,
                "status": "warning",
                "error": "Failed to extract population criteria",
            }

        return {
            **state,
            "extracted_population": result,
            "status": "population_extracted",
        }
    except Exception as e:
        return {
            **state,
            "status": "error",
            "error": f"Exception in extract_population_criteria: {str(e)}",
        }


def extract_study_design(state: DocumentExtractionState) -> DocumentExtractionState:
    """Extract study design information using the LLM.

    A missing section or an empty LLM result yields ``status="warning"``.
    """
    if state.get("status") == "error":
        return state

    try:
        protocol_id = _resolve_protocol_id(state)

        # Find study design section (substring match on section names)
        text_for_extraction = _find_section_text(
            state.get("sections", {}),
            ["study design", "design", "methodology"],
        )

        if not text_for_extraction:
            return {
                **state,
                "status": "warning",
                "error": "No study design section found",
            }

        # Extract study design
        result = llm_interface.extract_study_design(text_for_extraction, protocol_id)

        if not result:
            return {
                **state,
                "status": "warning",
                "error": "Failed to extract study design",
            }

        return {
            **state,
            "extracted_design": result,
            "status": "design_extracted",
        }
    except Exception as e:
        return {
            **state,
            "status": "error",
            "error": f"Exception in extract_study_design: {str(e)}",
        }


def store_in_knowledge_base(state: DocumentExtractionState) -> DocumentExtractionState:
    """Store extracted information in the knowledge base.

    Persists document metadata (with the resolved ``protocol_id``), study info,
    objectives, endpoints, population criteria and vector chunks. Returns
    ``status="completed"`` with ``document_id``. Returns ``status="warning"``
    (still carrying ``document_id``) when only the vector-store insert reports
    an error, and ``status="error"`` when no protocol ID is available.
    """
    # Skip if there was a critical error
    if state.get("status") == "error":
        return state

    try:
        # Copy so the incoming state's metadata dict is not mutated in place.
        document_metadata = dict(state.get("document_metadata") or {})
        study_info = state.get("extracted_study") or {}
        objectives = state.get("extracted_objectives") or []
        endpoints = state.get("extracted_endpoints") or []
        population = state.get("extracted_population") or {}
        vector_chunks = state.get("vector_chunks") or []
        # NOTE(review): state["extracted_design"] is not persisted here; no
        # storage call for it is visible in this module - confirm whether
        # KnowledgeStore should store it.

        # Ensure we have a protocol ID
        protocol_id = _resolve_protocol_id(state)
        if not protocol_id:
            return {
                **state,
                "status": "error",
                "error": "No protocol ID available for knowledge base storage",
            }

        document_metadata["protocol_id"] = protocol_id

        # Store in NoSQL DB
        doc_id = knowledge_store.store_document_metadata(document_metadata)

        if study_info:
            knowledge_store.store_study_info(study_info)

        if objectives:
            knowledge_store.store_objectives(protocol_id, objectives)

        if endpoints:
            knowledge_store.store_endpoints(protocol_id, endpoints)

        # Inclusion and exclusion criteria are stored together, each tagged with
        # its criterion type; either list alone is enough to store.
        all_criteria = _tag_population_criteria(population, protocol_id)
        if all_criteria:
            knowledge_store.store_population_criteria(protocol_id, all_criteria)

        stored_state = {
            **state,
            "document_metadata": document_metadata,
            "document_id": doc_id,
        }

        # Store in vector store if chunks available
        if vector_chunks:
            result = knowledge_store.add_documents(vector_chunks)
            if result and result.get("status") == "error":
                return {
                    **stored_state,
                    "status": "warning",
                    "error": f"Warning: Failed to add to vector store: {result.get('message')}",
                }

        return {
            **stored_state,
            "status": "completed",
        }
    except Exception as e:
        return {
            **state,
            "status": "error",
            "error": f"Exception in store_in_knowledge_base: {str(e)}",
        }


# =========================================================================
# Protocol Coach Workflow Nodes
# =========================================================================
def _format_documents(docs: Any) -> List[Dict[str, Any]]:
    """Convert retrieved documents into plain ``page_content``/``metadata`` dicts."""
    return [
        {"page_content": doc.page_content, "metadata": doc.metadata}
        for doc in docs
    ]


def retrieve_context_for_query(state: ProtocolCoachState) -> ProtocolCoachState:
    """Retrieve the top 5 vector-store chunks relevant to ``state["query"]``.

    Sets ``retrieved_context`` to a list of ``page_content``/``metadata`` dicts.
    When nothing is found it also sets ``error``. Exceptions are reported via
    ``error`` rather than raised.
    """
    try:
        query = state["query"]

        # Query vector store for context
        relevant_docs = knowledge_store.similarity_search(
            query=query,
            k=5,  # Get top 5 most relevant chunks
        )

        if not relevant_docs:
            return {
                **state,
                "retrieved_context": [],
                "error": "No relevant context found",
            }

        return {
            **state,
            "retrieved_context": _format_documents(relevant_docs),
        }
    except Exception as e:
        return {
            **state,
            "error": f"Exception in retrieve_context_for_query: {str(e)}",
        }


def answer_query(state: ProtocolCoachState) -> ProtocolCoachState:
    """Generate an answer to the user query from the retrieved context.

    Always sets ``response``. A fallback message is used when there is no
    context or the LLM returns nothing. On exception it also sets ``error``.
    """
    try:
        query = state["query"]
        context = state.get("retrieved_context", [])
        chat_history = state.get("chat_history", [])

        if not context:
            return {
                **state,
                "response": "I don't have enough context to answer that question about the protocol. Please try asking something else or upload relevant documents.",
            }

        # Generate response using LLM
        response = llm_interface.answer_protocol_question(
            question=query,
            context=context,
            chat_history=chat_history,
        )

        if not response:
            return {
                **state,
                "response": "I encountered an issue while generating a response. Please try again.",
            }

        return {
            **state,
            "response": response,
        }
    except Exception as e:
        return {
            **state,
            "response": f"Error: {str(e)}",
            "error": f"Exception in answer_query: {str(e)}",
        }


# =========================================================================
# Content Authoring Workflow Nodes
# =========================================================================

def retrieve_content_examples(state: ContentAuthoringState) -> ContentAuthoringState:
    """Retrieve up to 3 example chunks of the requested section type.

    When ``target_protocol_id`` is set, a ``{"protocol_id": {"$ne": ...}}``
    filter is passed to the search so the target protocol is excluded from
    its own examples.
    """
    try:
        section_type = state["section_type"]
        target_protocol_id = state.get("target_protocol_id")

        # Create a search query based on section type
        search_query = f"{section_type} section for clinical study protocol"

        # Exclude the target protocol from examples if specified
        filter_dict = None
        if target_protocol_id:
            filter_dict = {"protocol_id": {"$ne": target_protocol_id}}

        # Query vector store for examples
        relevant_docs = knowledge_store.similarity_search(
            query=search_query,
            k=3,
            filter_dict=filter_dict,
        )

        if not relevant_docs:
            return {
                **state,
                "retrieved_context": [],
                "error": "No relevant examples found",
            }

        return {
            **state,
            "retrieved_context": _format_documents(relevant_docs),
        }
    except Exception as e:
        return {
            **state,
            "error": f"Exception in retrieve_content_examples: {str(e)}",
        }


def generate_content(state: ContentAuthoringState) -> ContentAuthoringState:
    """Generate section content from the retrieved examples.

    Always sets ``generated_content``. A fallback message plus ``error`` is
    used when there is no context, the LLM returns nothing, or an exception
    occurs.
    """
    try:
        section_type = state["section_type"]
        context = state.get("retrieved_context", [])
        target_protocol_id = state.get("target_protocol_id")
        style_guide = state.get("style_guide")

        if not context:
            return {
                **state,
                "generated_content": "I don't have enough examples to generate a good section. Please upload more documents or try a different section type.",
                "error": "No context available for generation",
            }

        # Generate content using LLM
        content = llm_interface.generate_content_from_knowledge(
            section_type=section_type,
            context=context,
            protocol_id=target_protocol_id,
            style_guide=style_guide,
        )

        if not content:
            return {
                **state,
                "generated_content": "I encountered an issue while generating content. Please try again.",
                "error": "Failed to generate content",
            }

        return {
            **state,
            "generated_content": content,
        }
    except Exception as e:
        return {
            **state,
            "generated_content": f"Error: {str(e)}",
            "error": f"Exception in generate_content: {str(e)}",
        }


def critique_content(state: ContentAuthoringState) -> ContentAuthoringState:
    """Critique generated content for quality and consistency.

    Placeholder: currently returns the state unchanged. An LLM-based critique
    would normally go here.
    """
    return state


# =========================================================================
# Traceability Workflow Nodes
# =========================================================================

def retrieve_document_entities(state: TraceabilityState) -> TraceabilityState:
    """Retrieve entities of ``state["entity_type"]`` for the source and target documents.

    Supported entity types are "objectives", "endpoints" and "population".
    Sets ``source_entities``/``target_entities`` on success. Otherwise sets
    ``error`` (unsupported type, missing document, missing protocol ID, or no
    entities on either side).
    """
    # Maps entity type -> KnowledgeStore method name. Resolved lazily with
    # getattr so only the method actually needed is looked up.
    fetcher_names = {
        "objectives": "get_objectives_by_protocol_id",
        "endpoints": "get_endpoints_by_protocol_id",
        "population": "get_population_criteria_by_protocol_id",
    }

    try:
        source_doc_id = state["source_document_id"]
        target_doc_id = state["target_document_id"]
        entity_type = state["entity_type"]

        fetcher_name = fetcher_names.get(entity_type)
        if fetcher_name is None:
            return {
                **state,
                "error": f"Unsupported entity type: {entity_type}",
            }

        # Get document metadata
        source_doc = knowledge_store.get_document_by_id(source_doc_id)
        target_doc = knowledge_store.get_document_by_id(target_doc_id)

        if not source_doc or not target_doc:
            return {
                **state,
                "error": "One or both documents not found",
            }

        # Get protocol IDs
        source_protocol_id = source_doc.get("protocol_id")
        target_protocol_id = target_doc.get("protocol_id")

        if not source_protocol_id or not target_protocol_id:
            return {
                **state,
                "error": "Protocol ID missing from one or both documents",
            }

        fetch_entities = getattr(knowledge_store, fetcher_name)
        source_entities = fetch_entities(source_protocol_id)
        target_entities = fetch_entities(target_protocol_id)

        if not source_entities or not target_entities:
            return {
                **state,
                "error": f"No {entity_type} found in one or both documents",
            }

        return {
            **state,
            "source_entities": source_entities,
            "target_entities": target_entities,
        }
    except Exception as e:
        return {
            **state,
            "error": f"Exception in retrieve_document_entities: {str(e)}",
        }


def _entity_text(entity: Any) -> Any:
    """Return an entity's comparable text: its description, else its text, else ""."""
    return entity.get("description") or entity.get("text") or ""


def match_entities(state: TraceabilityState) -> TraceabilityState:
    """Pair source entities with target entities.

    Placeholder matching: each source entity with non-empty text is paired
    with the first target entity that also has non-empty text. No similarity
    scoring is performed yet. An upstream non-empty ``error`` passes the state
    through untouched.
    """
    try:
        if state.get("error"):
            return state

        source_entities = state.get("source_entities", [])
        target_entities = state.get("target_entities", [])
        entity_type = state.get("entity_type")

        matched_pairs = []
        for source_entity in source_entities:
            source_text = _entity_text(source_entity)
            if not source_text:
                continue

            # Take the first target with usable text (see docstring).
            for target_entity in target_entities:
                target_text = _entity_text(target_entity)
                if target_text:
                    matched_pairs.append({
                        "source_entity": source_entity,
                        "target_entity": target_entity,
                        "source_text": source_text,
                        "target_text": target_text,
                        "entity_type": entity_type,
                    })
                    break

        return {
            **state,
            "matched_pairs": matched_pairs,
        }
    except Exception as e:
        return {
            **state,
            "error": f"Exception in match_entities: {str(e)}",
        }


def analyze_matches(state: TraceabilityState) -> TraceabilityState:
    """Ask the LLM to analyze matched entity pairs for consistency issues.

    Sets ``analysis``. An upstream non-empty ``error`` passes the state through
    untouched. On exception both ``error`` and ``analysis`` carry the message.
    """
    try:
        if state.get("error"):
            return state

        matched_pairs = state.get("matched_pairs", [])
        source_doc_id = state["source_document_id"]
        target_doc_id = state["target_document_id"]

        if not matched_pairs:
            return {
                **state,
                "analysis": "No matching entities found between the documents.",
            }

        # Get document metadata
        source_doc = knowledge_store.get_document_by_id(source_doc_id)
        target_doc = knowledge_store.get_document_by_id(target_doc_id)

        # Use LLM to analyze matches
        analysis = llm_interface.find_document_connections(
            source_doc_info=source_doc,
            target_doc_info=target_doc,
            entity_pairs=matched_pairs,
        )

        return {
            **state,
            "analysis": analysis,
        }
    except Exception as e:
        return {
            **state,
            "error": f"Exception in analyze_matches: {str(e)}",
            "analysis": f"Error analyzing matches: {str(e)}",
        }


# =========================================================================
# Graph Building Functions
# =========================================================================

def _build_linear_graph(state_schema: Any, nodes: List[Any]) -> Any:
    """Compile a StateGraph that runs *nodes* strictly in order and then ends.

    Args:
        state_schema: State class passed to ``StateGraph``.
        nodes: Non-empty ordered list of ``(name, callable)`` pairs; the first
            pair is the entry point and the last one is wired to ``END``.

    Returns:
        The compiled graph.
    """
    workflow = StateGraph(state_schema)

    for name, node in nodes:
        workflow.add_node(name, node)

    for (current_name, _), (next_name, _) in zip(nodes, nodes[1:]):
        workflow.add_edge(current_name, next_name)
    workflow.add_edge(nodes[-1][0], END)

    workflow.set_entry_point(nodes[0][0])
    return workflow.compile()


def build_document_extraction_graph():
    """Build and return document extraction workflow graph.

    The nodes run strictly in sequence. Rather than conditional edges, each
    node checks ``state["status"]`` itself and passes an errored state through
    unchanged, which keeps the graph linear.
    """
    return _build_linear_graph(
        DocumentExtractionState,
        [
            ("parse_document", parse_document),
            ("extract_study_info", extract_study_info),
            ("extract_objectives_endpoints", extract_objectives_endpoints),
            ("extract_population_criteria", extract_population_criteria),
            ("extract_study_design", extract_study_design),
            ("store_in_knowledge_base", store_in_knowledge_base),
        ],
    )


def build_protocol_coach_graph():
    """Build and return protocol coach workflow graph (retrieve -> answer)."""
    return _build_linear_graph(
        ProtocolCoachState,
        [
            ("retrieve_context", retrieve_context_for_query),
            ("answer_query", answer_query),
        ],
    )


def build_content_authoring_graph():
    """Build and return content authoring workflow graph (retrieve -> generate -> critique)."""
    return _build_linear_graph(
        ContentAuthoringState,
        [
            ("retrieve_examples", retrieve_content_examples),
            ("generate_content", generate_content),
            ("critique_content", critique_content),
        ],
    )


def build_traceability_graph():
    """Build and return traceability analysis workflow graph (retrieve -> match -> analyze)."""
    return _build_linear_graph(
        TraceabilityState,
        [
            ("retrieve_entities", retrieve_document_entities),
            ("match_entities", match_entities),
            ("analyze_matches", analyze_matches),
        ],
    )