Spaces:
Runtime error
Runtime error
| """ | |
| LangGraph agent orchestration for document processing, content authoring, and protocol coach. | |
| """ | |
| from langgraph.graph import StateGraph, END | |
| from typing import TypedDict, Dict, List, Any, Optional, Literal, Annotated, cast | |
| import operator | |
| import uuid | |
| from schemas import DocumentExtractionState, ProtocolCoachState, ContentAuthoringState, TraceabilityState | |
| from pdf_processor import PDFProcessor | |
| from knowledge_store import KnowledgeStore | |
| from llm_interface import LLMInterface | |
# Module-level singleton handlers. They start as None and are populated by
# init_handlers(); every workflow node below reads these globals directly,
# so init_handlers() must run before any graph is invoked.
pdf_processor = None    # PDFProcessor — PDF parsing/sectioning/chunking
knowledge_store = None  # KnowledgeStore — NoSQL metadata store + vector store
llm_interface = None    # LLMInterface — LLM-backed extraction and Q&A
def init_handlers(api_key=None):
    """Create the global PDF, knowledge-store, and LLM handlers.

    Args:
        api_key: Optional API key forwarded to the LLM interface.

    Returns:
        Tuple of (pdf_processor, knowledge_store, llm_interface).
    """
    global pdf_processor, knowledge_store, llm_interface
    pdf_processor, knowledge_store, llm_interface = (
        PDFProcessor(),
        KnowledgeStore(),
        LLMInterface(api_key=api_key),
    )
    return pdf_processor, knowledge_store, llm_interface
| # ========================================================================= | |
| # Document Extraction Workflow Nodes | |
| # ========================================================================= | |
def parse_document(state: DocumentExtractionState) -> DocumentExtractionState:
    """Run the PDF processor over the document referenced by the state.

    On success, populates document_text / document_metadata / sections /
    vector_chunks and sets status to "parsed"; on any failure, sets status
    "error" with a descriptive message.
    """
    try:
        parsed = pdf_processor.process_complete_document(state["document_path"])
        if parsed["status"] == "error":
            detail = parsed.get("error", "Unknown error")
            return {
                **state,
                "status": "error",
                "error": f"Failed to parse document: {detail}",
            }
        return {
            **state,
            "document_text": parsed.get("full_text", ""),
            "document_metadata": parsed.get("metadata", {}),
            "sections": parsed.get("sections", {}),
            "vector_chunks": parsed.get("chunks", []),
            "status": "parsed",
        }
    except Exception as exc:
        return {
            **state,
            "status": "error",
            "error": f"Exception in parse_document: {str(exc)}",
        }
def extract_study_info(state: DocumentExtractionState) -> DocumentExtractionState:
    """Extract high-level study information from the parsed document via the LLM.

    Prefers a synopsis/summary/overview section when section content is
    available (``sections`` may be a list of names only, or a name->content
    dict); otherwise falls back to the first 20k characters of the full text.
    Propagates an existing "error" status untouched.

    Fixes vs. original:
    - protocol_id is only backfilled from document metadata when the metadata
      actually carries a truthy value (before, a literal None could be stored,
      which masked the metadata fallback used by downstream nodes);
    - the case-insensitive section lookup builds its index once instead of
      rebuilding a lowercase key list per candidate name and re-scanning.
    """
    if state.get("status") == "error":
        return state
    try:
        sections = state.get("sections", {})
        text_for_extraction = ""
        if isinstance(sections, dict):
            # Case-insensitive index of section names; keep the first key
            # seen for each lowercase form (matches the original's behavior).
            by_lower = {}
            for key in sections:
                by_lower.setdefault(key.lower(), key)
            for wanted in ("synopsis", "summary", "overview"):
                if wanted in by_lower:
                    text_for_extraction = sections[by_lower[wanted]]
                    break
        # List-of-names sections carry no content; fall back to the start of
        # the raw document text (first 20k chars) in that case or when no
        # summary-like section was found.
        if not text_for_extraction and "document_text" in state:
            text_for_extraction = state["document_text"][:20000]
        if not text_for_extraction:
            return {
                **state,
                "status": "error",
                "error": "No text available for study info extraction"
            }
        study_info = llm_interface.extract_study_info(text_for_extraction)
        if not study_info:
            return {
                **state,
                "status": "error",
                "error": "Failed to extract study information"
            }
        # Backfill protocol_id from document metadata, but never store None.
        if not study_info.get("protocol_id"):
            meta_pid = state.get("document_metadata", {}).get("protocol_id")
            if meta_pid:
                study_info["protocol_id"] = meta_pid
        return {
            **state,
            "extracted_study": study_info,
            "status": "study_extracted"
        }
    except Exception as e:
        return {
            **state,
            "status": "error",
            "error": f"Exception in extract_study_info: {str(e)}"
        }
def extract_objectives_endpoints(state: DocumentExtractionState) -> DocumentExtractionState:
    """Extract study objectives and endpoints from the relevant section via the LLM.

    Requires a protocol ID (from extracted study info or document metadata);
    emits "error" without one, and a "warning" status when no matching section
    is found or extraction fails.
    """
    if state.get("status") == "error":
        return state
    try:
        sections = state.get("sections", {})
        protocol_id = (
            state.get("extracted_study", {}).get("protocol_id")
            or state.get("document_metadata", {}).get("protocol_id")
        )
        if not protocol_id:
            return {
                **state,
                "status": "error",
                "error": "No protocol ID available for extraction"
            }
        # Locate the first section whose name mentions objectives/endpoints.
        text_for_extraction = ""
        for wanted in ("objectives", "objective", "endpoint", "endpoints"):
            hit = next((key for key in sections if wanted in key.lower()), None)
            if hit is not None:
                text_for_extraction = sections[hit]
            if text_for_extraction:
                break
        if not text_for_extraction:
            return {
                **state,
                "status": "warning",
                "error": "No objectives/endpoints section found"
            }
        extraction = llm_interface.extract_objectives_and_endpoints(text_for_extraction, protocol_id)
        if not extraction:
            return {
                **state,
                "status": "warning",
                "error": "Failed to extract objectives and endpoints"
            }
        return {
            **state,
            "extracted_objectives": extraction.get("objectives", []),
            "extracted_endpoints": extraction.get("endpoints", []),
            "status": "objectives_endpoints_extracted"
        }
    except Exception as exc:
        return {
            **state,
            "status": "error",
            "error": f"Exception in extract_objectives_endpoints: {str(exc)}"
        }
def extract_population_criteria(state: DocumentExtractionState) -> DocumentExtractionState:
    """Extract inclusion and exclusion criteria from the eligibility section via the LLM.

    Emits a "warning" status when no matching section is found or extraction
    fails; propagates an existing "error" status untouched.

    Fix vs. original: like extract_objectives_endpoints, fail fast with an
    "error" status when no protocol ID is available instead of passing None
    into the LLM interface.
    """
    if state.get("status") == "error":
        return state
    try:
        sections = state.get("sections", {})
        protocol_id = state.get("extracted_study", {}).get("protocol_id")
        if not protocol_id:
            protocol_id = state.get("document_metadata", {}).get("protocol_id")
        if not protocol_id:
            # Consistent with extract_objectives_endpoints.
            return {
                **state,
                "status": "error",
                "error": "No protocol ID available for extraction"
            }
        # Find the first section whose name suggests eligibility criteria.
        text_for_extraction = ""
        for section_name in ["eligibility", "inclusion", "exclusion", "criteria", "population"]:
            for key in sections.keys():
                if section_name.lower() in key.lower():
                    text_for_extraction = sections[key]
                    break
            if text_for_extraction:
                break
        if not text_for_extraction:
            return {
                **state,
                "status": "warning",
                "error": "No population criteria section found"
            }
        result = llm_interface.extract_population_criteria(text_for_extraction, protocol_id)
        if not result:
            return {
                **state,
                "status": "warning",
                "error": "Failed to extract population criteria"
            }
        return {
            **state,
            "extracted_population": result,
            "status": "population_extracted"
        }
    except Exception as e:
        return {
            **state,
            "status": "error",
            "error": f"Exception in extract_population_criteria: {str(e)}"
        }
def extract_study_design(state: DocumentExtractionState) -> DocumentExtractionState:
    """Extract study design information from the design/methodology section via the LLM.

    Emits a "warning" status when no matching section is found or extraction
    fails; propagates an existing "error" status untouched.

    Fix vs. original: like extract_objectives_endpoints, fail fast with an
    "error" status when no protocol ID is available instead of passing None
    into the LLM interface.
    """
    if state.get("status") == "error":
        return state
    try:
        sections = state.get("sections", {})
        protocol_id = state.get("extracted_study", {}).get("protocol_id")
        if not protocol_id:
            protocol_id = state.get("document_metadata", {}).get("protocol_id")
        if not protocol_id:
            # Consistent with extract_objectives_endpoints.
            return {
                **state,
                "status": "error",
                "error": "No protocol ID available for extraction"
            }
        # Find the first section whose name suggests study design content.
        text_for_extraction = ""
        for section_name in ["study design", "design", "methodology"]:
            for key in sections.keys():
                if section_name.lower() in key.lower():
                    text_for_extraction = sections[key]
                    break
            if text_for_extraction:
                break
        if not text_for_extraction:
            return {
                **state,
                "status": "warning",
                "error": "No study design section found"
            }
        result = llm_interface.extract_study_design(text_for_extraction, protocol_id)
        if not result:
            return {
                **state,
                "status": "warning",
                "error": "Failed to extract study design"
            }
        return {
            **state,
            "extracted_design": result,
            "status": "design_extracted"
        }
    except Exception as e:
        return {
            **state,
            "status": "error",
            "error": f"Exception in extract_study_design: {str(e)}"
        }
def store_in_knowledge_base(state: DocumentExtractionState) -> DocumentExtractionState:
    """Persist all extracted artifacts into the NoSQL and vector stores.

    Resolves a protocol ID (study info first, then document metadata), stores
    document metadata, study info, objectives, endpoints, and population
    criteria, then adds vector chunks. Sets status "completed" with the new
    document_id, "warning" when only the vector-store write fails, or "error"
    when no protocol ID is available or an exception occurs.

    Fixes vs. original:
    - population criteria are stored when *either* inclusion or exclusion
      criteria are present (before, an exclusion-only payload was dropped
      because only the "inclusion_criteria" key was checked);
    - document metadata is copied before adding protocol_id, so the input
      state is not mutated in place;
    - removed the unused `study_id` and `design` locals.
    """
    try:
        # Skip if a critical error occurred upstream.
        if state.get("status") == "error":
            return state
        document_metadata = dict(state.get("document_metadata", {}))
        study_info = state.get("extracted_study", {})
        objectives = state.get("extracted_objectives", [])
        endpoints = state.get("extracted_endpoints", [])
        population = state.get("extracted_population", {})
        vector_chunks = state.get("vector_chunks", [])
        # NOTE(review): extracted_design is never persisted here — confirm
        # whether the knowledge store should get a store-design call.
        protocol_id = study_info.get("protocol_id") or document_metadata.get("protocol_id")
        if not protocol_id:
            return {
                **state,
                "status": "error",
                "error": "No protocol ID available for knowledge base storage"
            }
        document_metadata["protocol_id"] = protocol_id
        doc_id = knowledge_store.store_document_metadata(document_metadata)
        if study_info:
            knowledge_store.store_study_info(study_info)
        if objectives:
            knowledge_store.store_objectives(protocol_id, objectives)
        if endpoints:
            knowledge_store.store_endpoints(protocol_id, endpoints)
        if population:
            inclusion = population.get("inclusion_criteria", [])
            exclusion = population.get("exclusion_criteria", [])
            # Tag every criterion with its type and owning protocol before storage.
            for criterion in inclusion:
                criterion["criterion_type"] = "Inclusion"
                criterion["protocol_id"] = protocol_id
            for criterion in exclusion:
                criterion["criterion_type"] = "Exclusion"
                criterion["protocol_id"] = protocol_id
            all_criteria = inclusion + exclusion
            if all_criteria:
                knowledge_store.store_population_criteria(protocol_id, all_criteria)
        if vector_chunks:
            result = knowledge_store.add_documents(vector_chunks)
            if result.get("status") == "error":
                return {
                    **state,
                    "status": "warning",
                    "error": f"Warning: Failed to add to vector store: {result.get('message')}"
                }
        return {
            **state,
            "status": "completed",
            "document_id": doc_id,
        }
    except Exception as e:
        return {
            **state,
            "status": "error",
            "error": f"Exception in store_in_knowledge_base: {str(e)}"
        }
| # ========================================================================= | |
| # Protocol Coach Workflow Nodes | |
| # ========================================================================= | |
def retrieve_context_for_query(state: ProtocolCoachState) -> ProtocolCoachState:
    """Fetch the five most relevant protocol chunks for the user's question."""
    try:
        hits = knowledge_store.similarity_search(
            query=state["query"],
            k=5  # top 5 most relevant chunks
        )
        if not hits:
            return {
                **state,
                "retrieved_context": [],
                "error": "No relevant context found"
            }
        # Reduce Document objects to plain dicts for downstream nodes.
        context = []
        for doc in hits:
            context.append({
                "page_content": doc.page_content,
                "metadata": doc.metadata
            })
        return {
            **state,
            "retrieved_context": context
        }
    except Exception as exc:
        return {
            **state,
            "error": f"Exception in retrieve_context_for_query: {str(exc)}"
        }
def answer_query(state: ProtocolCoachState) -> ProtocolCoachState:
    """Answer the user's protocol question from previously retrieved context."""
    try:
        question = state["query"]
        context = state.get("retrieved_context", [])
        history = state.get("chat_history", [])
        if not context:
            return {
                **state,
                "response": "I don't have enough context to answer that question about the protocol. Please try asking something else or upload relevant documents."
            }
        answer = llm_interface.answer_protocol_question(
            question=question,
            context=context,
            chat_history=history
        )
        if answer:
            return {**state, "response": answer}
        return {
            **state,
            "response": "I encountered an issue while generating a response. Please try again."
        }
    except Exception as exc:
        return {
            **state,
            "response": f"Error: {str(exc)}",
            "error": f"Exception in answer_query: {str(exc)}"
        }
| # ========================================================================= | |
| # Content Authoring Workflow Nodes | |
| # ========================================================================= | |
def retrieve_content_examples(state: ContentAuthoringState) -> ContentAuthoringState:
    """Look up example passages of the requested section type from the vector store."""
    try:
        section_type = state["section_type"]
        target_protocol_id = state.get("target_protocol_id")
        search_query = f"{section_type} section for clinical study protocol"
        # When authoring for a specific protocol, exclude that protocol's own
        # text from the example pool.
        filter_dict = (
            {"protocol_id": {"$ne": target_protocol_id}}
            if target_protocol_id
            else None
        )
        docs = knowledge_store.similarity_search(
            query=search_query,
            k=3,
            filter_dict=filter_dict
        )
        if not docs:
            return {
                **state,
                "retrieved_context": [],
                "error": "No relevant examples found"
            }
        # Reduce Document objects to plain dicts for downstream nodes.
        context = [
            {"page_content": doc.page_content, "metadata": doc.metadata}
            for doc in docs
        ]
        return {
            **state,
            "retrieved_context": context
        }
    except Exception as exc:
        return {
            **state,
            "error": f"Exception in retrieve_content_examples: {str(exc)}"
        }
def generate_content(state: ContentAuthoringState) -> ContentAuthoringState:
    """Draft the requested protocol section from retrieved example context."""
    try:
        section_type = state["section_type"]
        examples = state.get("retrieved_context", [])
        if not examples:
            return {
                **state,
                "generated_content": "I don't have enough examples to generate a good section. Please upload more documents or try a different section type.",
                "error": "No context available for generation"
            }
        draft = llm_interface.generate_content_from_knowledge(
            section_type=section_type,
            context=examples,
            protocol_id=state.get("target_protocol_id"),
            style_guide=state.get("style_guide")
        )
        if draft:
            return {**state, "generated_content": draft}
        return {
            **state,
            "generated_content": "I encountered an issue while generating content. Please try again.",
            "error": "Failed to generate content"
        }
    except Exception as exc:
        return {
            **state,
            "generated_content": f"Error: {str(exc)}",
            "error": f"Exception in generate_content: {str(exc)}"
        }
def critique_content(state: ContentAuthoringState) -> ContentAuthoringState:
    """Critique generated content for quality and consistency.

    Placeholder node: an LLM-driven critique pass is intended here, but the
    current implementation passes the state through unchanged.
    """
    # TODO: critique state["generated_content"] with the LLM and either
    # revise it or attach review notes before returning.
    return state
| # ========================================================================= | |
| # Traceability Workflow Nodes | |
| # ========================================================================= | |
def retrieve_document_entities(state: TraceabilityState) -> TraceabilityState:
    """Load comparable entities from the source and target documents.

    Looks up both documents, resolves their protocol IDs, then fetches the
    entities named by state["entity_type"] ("objectives", "endpoints", or
    "population") for each protocol. Sets "error" in the state on any failure.

    Fix vs. original: an unsupported entity_type now produces an explicit
    "Unsupported entity type" error instead of falling through with empty
    lists and the misleading "No ... found in one or both documents" message.
    """
    try:
        source_doc_id = state["source_document_id"]
        target_doc_id = state["target_document_id"]
        entity_type = state["entity_type"]
        source_doc = knowledge_store.get_document_by_id(source_doc_id)
        target_doc = knowledge_store.get_document_by_id(target_doc_id)
        if not source_doc or not target_doc:
            return {
                **state,
                "error": "One or both documents not found"
            }
        source_protocol_id = source_doc.get("protocol_id")
        target_protocol_id = target_doc.get("protocol_id")
        if not source_protocol_id or not target_protocol_id:
            return {
                **state,
                "error": "Protocol ID missing from one or both documents"
            }
        # Dispatch table: entity type -> knowledge-store getter.
        getters = {
            "objectives": knowledge_store.get_objectives_by_protocol_id,
            "endpoints": knowledge_store.get_endpoints_by_protocol_id,
            "population": knowledge_store.get_population_criteria_by_protocol_id,
        }
        getter = getters.get(entity_type)
        if getter is None:
            return {
                **state,
                "error": f"Unsupported entity type: {entity_type}"
            }
        source_entities = getter(source_protocol_id)
        target_entities = getter(target_protocol_id)
        if not source_entities or not target_entities:
            return {
                **state,
                "error": f"No {entity_type} found in one or both documents"
            }
        return {
            **state,
            "source_entities": source_entities,
            "target_entities": target_entities
        }
    except Exception as e:
        return {
            **state,
            "error": f"Exception in retrieve_document_entities: {str(e)}"
        }
def match_entities(state: TraceabilityState) -> TraceabilityState:
    """Pair each source entity with the first target entity that has comparable text.

    For every source entity with a non-empty description/text, records one
    pair with the first target entity that also has non-empty text. Matching
    is deliberately naive — a real system would use semantic comparison.

    Fix vs. original: the inner loop built a candidate pair for *every*
    target and kept only the first, behind a `len(...) > 0` guard that was
    always true after the empty-text check; it now stops at the first usable
    target and drops the dead condition.
    """
    try:
        if "error" in state:
            return state
        entity_type = state["entity_type"]
        matched_pairs = []
        for source_entity in state.get("source_entities", []):
            source_text = source_entity.get("description", source_entity.get("text", ""))
            if not source_text:
                continue
            for target_entity in state.get("target_entities", []):
                target_text = target_entity.get("description", target_entity.get("text", ""))
                if not target_text:
                    continue
                matched_pairs.append({
                    "source_entity": source_entity,
                    "target_entity": target_entity,
                    "source_text": source_text,
                    "target_text": target_text,
                    "entity_type": entity_type
                })
                break  # keep only the first usable target per source entity
        return {
            **state,
            "matched_pairs": matched_pairs
        }
    except Exception as e:
        return {
            **state,
            "error": f"Exception in match_entities: {str(e)}"
        }
def analyze_matches(state: TraceabilityState) -> TraceabilityState:
    """Run the LLM over matched entity pairs to surface consistency issues."""
    try:
        if "error" in state:
            return state
        pairs = state.get("matched_pairs", [])
        source_doc_id = state["source_document_id"]
        target_doc_id = state["target_document_id"]
        if not pairs:
            return {
                **state,
                "analysis": "No matching entities found between the documents."
            }
        # Re-fetch document metadata so the LLM sees both documents' details.
        source_doc = knowledge_store.get_document_by_id(source_doc_id)
        target_doc = knowledge_store.get_document_by_id(target_doc_id)
        analysis = llm_interface.find_document_connections(
            source_doc_info=source_doc,
            target_doc_info=target_doc,
            entity_pairs=pairs
        )
        return {**state, "analysis": analysis}
    except Exception as exc:
        return {
            **state,
            "error": f"Exception in analyze_matches: {str(exc)}",
            "analysis": f"Error analyzing matches: {str(exc)}"
        }
| # ========================================================================= | |
| # Graph Building Functions | |
| # ========================================================================= | |
def build_document_extraction_graph():
    """Compile the sequential document-extraction workflow.

    parse -> study info -> objectives/endpoints -> population -> design -> store.
    Error handling lives inside each node (nodes short-circuit on an "error"
    status), so the graph is a straight line with no conditional edges.
    """
    graph = StateGraph(DocumentExtractionState)
    pipeline = [
        ("parse_document", parse_document),
        ("extract_study_info", extract_study_info),
        ("extract_objectives_endpoints", extract_objectives_endpoints),
        ("extract_population_criteria", extract_population_criteria),
        ("extract_study_design", extract_study_design),
        ("store_in_knowledge_base", store_in_knowledge_base),
    ]
    for name, node_fn in pipeline:
        graph.add_node(name, node_fn)
    # Chain consecutive stages, then terminate after storage.
    for (upstream, _), (downstream, _) in zip(pipeline, pipeline[1:]):
        graph.add_edge(upstream, downstream)
    graph.add_edge(pipeline[-1][0], END)
    graph.set_entry_point(pipeline[0][0])
    return graph.compile()
def build_protocol_coach_graph():
    """Compile the two-step protocol coach workflow (retrieve context -> answer)."""
    graph = StateGraph(ProtocolCoachState)
    graph.add_node("retrieve_context", retrieve_context_for_query)
    graph.add_node("answer_query", answer_query)
    graph.set_entry_point("retrieve_context")
    graph.add_edge("retrieve_context", "answer_query")
    graph.add_edge("answer_query", END)
    return graph.compile()
def build_content_authoring_graph():
    """Compile the authoring workflow: retrieve examples -> generate -> critique."""
    graph = StateGraph(ContentAuthoringState)
    stages = [
        ("retrieve_examples", retrieve_content_examples),
        ("generate_content", generate_content),
        ("critique_content", critique_content),
    ]
    for name, node_fn in stages:
        graph.add_node(name, node_fn)
    graph.set_entry_point("retrieve_examples")
    graph.add_edge("retrieve_examples", "generate_content")
    graph.add_edge("generate_content", "critique_content")
    graph.add_edge("critique_content", END)
    return graph.compile()
def build_traceability_graph():
    """Compile the traceability workflow: retrieve entities -> match -> analyze."""
    graph = StateGraph(TraceabilityState)
    graph.add_node("retrieve_entities", retrieve_document_entities)
    graph.add_node("match_entities", match_entities)
    graph.add_node("analyze_matches", analyze_matches)
    graph.set_entry_point("retrieve_entities")
    graph.add_edge("retrieve_entities", "match_entities")
    graph.add_edge("match_entities", "analyze_matches")
    graph.add_edge("analyze_matches", END)
    return graph.compile()