# doc_knowledge_base / graph_builder.py
# Last updated by cryogenic22 — "Update graph_builder.py" (commit 6a13f3d, verified)
"""
LangGraph agent orchestration for document processing, content authoring, and protocol coach.
"""
from langgraph.graph import StateGraph, END
from typing import TypedDict, Dict, List, Any, Optional, Literal, Annotated, cast
import operator
import uuid
from schemas import DocumentExtractionState, ProtocolCoachState, ContentAuthoringState, TraceabilityState
from pdf_processor import PDFProcessor
from knowledge_store import KnowledgeStore
from llm_interface import LLMInterface
# Module-level handler singletons. They start as None and are populated by
# init_handlers(); every workflow node below reads them directly, so
# init_handlers() must be called before any graph is invoked.
pdf_processor = None    # PDFProcessor instance once initialized
knowledge_store = None  # KnowledgeStore instance once initialized
llm_interface = None    # LLMInterface instance once initialized
def init_handlers(api_key=None):
    """Create and cache the shared PDF, knowledge-store, and LLM handlers.

    Args:
        api_key: Optional API key forwarded to the LLM interface.

    Returns:
        Tuple of (pdf_processor, knowledge_store, llm_interface); the same
        objects are stored in the module-level globals used by the nodes.
    """
    global pdf_processor, knowledge_store, llm_interface
    pdf_processor, knowledge_store, llm_interface = (
        PDFProcessor(),
        KnowledgeStore(),
        LLMInterface(api_key=api_key),
    )
    return pdf_processor, knowledge_store, llm_interface
# =========================================================================
# Document Extraction Workflow Nodes
# =========================================================================
def parse_document(state: DocumentExtractionState) -> DocumentExtractionState:
    """Run the PDF processor over the document referenced by the state.

    Reads ``state["document_path"]``, extracts full text, metadata,
    sections, and vector chunks, and advances ``status`` to ``"parsed"``.
    Failures never raise; they are reported via ``status="error"`` plus
    an ``error`` message so downstream nodes can skip themselves.
    """
    try:
        result = pdf_processor.process_complete_document(state["document_path"])
        if result["status"] == "error":
            reason = result.get("error", "Unknown error")
            return {
                **state,
                "status": "error",
                "error": f"Failed to parse document: {reason}",
            }
        extracted = {
            "document_text": result.get("full_text", ""),
            "document_metadata": result.get("metadata", {}),
            "sections": result.get("sections", {}),
            "vector_chunks": result.get("chunks", []),
            "status": "parsed",
        }
        return {**state, **extracted}
    except Exception as e:
        return {
            **state,
            "status": "error",
            "error": f"Exception in parse_document: {str(e)}",
        }
def extract_study_info(state: DocumentExtractionState) -> DocumentExtractionState:
    """Pull high-level study information out of the parsed document via the LLM.

    Prefers a synopsis/summary/overview section when section content is
    available; otherwise falls back to the first 20k characters of the
    full document text. Errors are surfaced through ``status``/``error``.
    """
    if state.get("status") == "error":
        return state
    try:
        sections = state.get("sections", {})
        excerpt = ""
        if isinstance(sections, list):
            # Only section names are known, so use the raw document text.
            if "document_text" in state:
                excerpt = state["document_text"][:20000]  # first 20k chars
        else:
            # Map lowercased section names to their first original key.
            lowered = {}
            for key in sections:
                lowered.setdefault(key.lower(), key)
            for candidate in ("synopsis", "summary", "overview"):
                if candidate in lowered:
                    excerpt = sections[lowered[candidate]]
                    break
            # No usable synopsis-style section: fall back to the document start.
            if not excerpt and "document_text" in state:
                excerpt = state["document_text"][:20000]  # first 20k chars
        if not excerpt:
            return {
                **state,
                "status": "error",
                "error": "No text available for study info extraction",
            }
        study_info = llm_interface.extract_study_info(excerpt)
        if not study_info:
            return {
                **state,
                "status": "error",
                "error": "Failed to extract study information",
            }
        # Backfill the protocol ID from document metadata when the LLM omits it.
        if "protocol_id" not in study_info and "document_metadata" in state:
            study_info["protocol_id"] = state["document_metadata"].get("protocol_id")
        return {**state, "extracted_study": study_info, "status": "study_extracted"}
    except Exception as e:
        return {
            **state,
            "status": "error",
            "error": f"Exception in extract_study_info: {str(e)}",
        }
def extract_objectives_endpoints(state: DocumentExtractionState) -> DocumentExtractionState:
    """Extract study objectives and endpoints from the matching section.

    Locates the first section whose name mentions objectives/endpoints and
    asks the LLM to structure it. A missing protocol ID is an error; a
    missing section or an empty LLM result only downgrades ``status`` to
    ``"warning"`` so the pipeline can continue.
    """
    if state.get("status") == "error":
        return state
    try:
        sections = state.get("sections", {})
        protocol_id = (
            state.get("extracted_study", {}).get("protocol_id")
            or state.get("document_metadata", {}).get("protocol_id")
        )
        if not protocol_id:
            return {
                **state,
                "status": "error",
                "error": "No protocol ID available for extraction",
            }

        def _first_section_matching(fragments):
            # Mirror the original search: first key containing the fragment
            # wins for that fragment; empty content moves on to the next.
            found = ""
            for fragment in fragments:
                for key in sections:
                    if fragment in key.lower():
                        found = sections[key]
                        break
                if found:
                    return found
            return found

        text = _first_section_matching(("objectives", "objective", "endpoint", "endpoints"))
        if not text:
            return {
                **state,
                "status": "warning",
                "error": "No objectives/endpoints section found",
            }
        result = llm_interface.extract_objectives_and_endpoints(text, protocol_id)
        if not result:
            return {
                **state,
                "status": "warning",
                "error": "Failed to extract objectives and endpoints",
            }
        return {
            **state,
            "extracted_objectives": result.get("objectives", []),
            "extracted_endpoints": result.get("endpoints", []),
            "status": "objectives_endpoints_extracted",
        }
    except Exception as e:
        return {
            **state,
            "status": "error",
            "error": f"Exception in extract_objectives_endpoints: {str(e)}",
        }
def extract_population_criteria(state: DocumentExtractionState) -> DocumentExtractionState:
    """Extract inclusion and exclusion criteria using the LLM.

    Finds the first section whose name mentions eligibility, inclusion,
    exclusion, criteria, or population and asks the LLM to structure it.
    A missing section or an empty LLM result sets ``status`` to
    ``"warning"``; a missing protocol ID is reported as an error.
    """
    if state.get("status") == "error":
        return state
    try:
        sections = state.get("sections", {})
        protocol_id = state.get("extracted_study", {}).get("protocol_id")
        if not protocol_id:
            protocol_id = state.get("document_metadata", {}).get("protocol_id")
        # Fix: fail fast when no protocol ID is available, matching the
        # guard in extract_objectives_endpoints, instead of passing None
        # through to the LLM interface and the downstream store step.
        if not protocol_id:
            return {
                **state,
                "status": "error",
                "error": "No protocol ID available for extraction"
            }
        # Find the population/eligibility criteria section.
        text_for_extraction = ""
        for section_name in ["eligibility", "inclusion", "exclusion", "criteria", "population"]:
            for key in sections.keys():
                if section_name.lower() in key.lower():
                    text_for_extraction = sections[key]
                    break
            if text_for_extraction:
                break
        if not text_for_extraction:
            return {
                **state,
                "status": "warning",
                "error": "No population criteria section found"
            }
        # Structure the criteria via the LLM.
        result = llm_interface.extract_population_criteria(text_for_extraction, protocol_id)
        if not result:
            return {
                **state,
                "status": "warning",
                "error": "Failed to extract population criteria"
            }
        return {
            **state,
            "extracted_population": result,
            "status": "population_extracted"
        }
    except Exception as e:
        return {
            **state,
            "status": "error",
            "error": f"Exception in extract_population_criteria: {str(e)}"
        }
def extract_study_design(state: DocumentExtractionState) -> DocumentExtractionState:
    """Ask the LLM to structure the study-design section of the document.

    Searches section names for design/methodology keywords; a missing
    section or an empty LLM result produces a ``"warning"`` status.
    """
    if state.get("status") == "error":
        return state
    try:
        sections = state.get("sections", {})
        protocol_id = (
            state.get("extracted_study", {}).get("protocol_id")
            or state.get("document_metadata", {}).get("protocol_id")
        )
        # Take the first keyword whose matching section has real content.
        design_text = ""
        for keyword in ("study design", "design", "methodology"):
            hit = next((sections[key] for key in sections if keyword in key.lower()), "")
            if hit:
                design_text = hit
                break
        if not design_text:
            return {
                **state,
                "status": "warning",
                "error": "No study design section found",
            }
        result = llm_interface.extract_study_design(design_text, protocol_id)
        if not result:
            return {
                **state,
                "status": "warning",
                "error": "Failed to extract study design",
            }
        return {**state, "extracted_design": result, "status": "design_extracted"}
    except Exception as e:
        return {
            **state,
            "status": "error",
            "error": f"Exception in extract_study_design: {str(e)}",
        }
def store_in_knowledge_base(state: DocumentExtractionState) -> DocumentExtractionState:
    """Persist all extracted artifacts into the knowledge base.

    Writes document metadata, study info, objectives, endpoints,
    population criteria, and vector chunks when each is present. On
    success the state gains ``document_id`` and ``status`` becomes
    ``"completed"``; a vector-store failure only downgrades to
    ``"warning"``.

    Fixes over the previous version: dicts held inside the incoming state
    are no longer mutated in place (metadata and criterion dicts are
    copied before annotation), and the unused ``study_id``/``design``
    locals were removed.
    """
    try:
        # Skip if an earlier node already failed.
        if state.get("status") == "error":
            return state
        document_metadata = state.get("document_metadata", {})
        study_info = state.get("extracted_study", {})
        objectives = state.get("extracted_objectives", [])
        endpoints = state.get("extracted_endpoints", [])
        population = state.get("extracted_population", {})
        vector_chunks = state.get("vector_chunks", [])
        # Resolve the protocol ID; everything below is keyed on it.
        protocol_id = study_info.get("protocol_id") or document_metadata.get("protocol_id")
        if not protocol_id:
            return {
                **state,
                "status": "error",
                "error": "No protocol ID available for knowledge base storage"
            }
        # Copy before annotating so the dict inside the state is untouched.
        document_metadata = {**document_metadata, "protocol_id": protocol_id}
        doc_id = knowledge_store.store_document_metadata(document_metadata)
        if study_info:
            knowledge_store.store_study_info(study_info)
        if objectives:
            knowledge_store.store_objectives(protocol_id, objectives)
        if endpoints:
            knowledge_store.store_endpoints(protocol_id, endpoints)
        if population and "inclusion_criteria" in population:
            # Tag each criterion with its type and protocol on a copy,
            # keeping the originals in the state untouched.
            all_criteria = [
                {**criterion, "criterion_type": criterion_type, "protocol_id": protocol_id}
                for criterion_type, criteria in (
                    ("Inclusion", population.get("inclusion_criteria", [])),
                    ("Exclusion", population.get("exclusion_criteria", [])),
                )
                for criterion in criteria
            ]
            knowledge_store.store_population_criteria(protocol_id, all_criteria)
        if vector_chunks:
            result = knowledge_store.add_documents(vector_chunks)
            if result.get("status") == "error":
                return {
                    **state,
                    "status": "warning",
                    "error": f"Warning: Failed to add to vector store: {result.get('message')}"
                }
        return {
            **state,
            "status": "completed",
            "document_id": doc_id,
        }
    except Exception as e:
        return {
            **state,
            "status": "error",
            "error": f"Exception in store_in_knowledge_base: {str(e)}"
        }
# =========================================================================
# Protocol Coach Workflow Nodes
# =========================================================================
def retrieve_context_for_query(state: ProtocolCoachState) -> ProtocolCoachState:
    """Fetch the most relevant knowledge-base chunks for the user's query.

    Runs a similarity search (top 5 hits) and stores each hit's content
    and metadata under ``retrieved_context``. Any failure or empty result
    is surfaced through the ``error`` key instead of raising.
    """
    try:
        hits = knowledge_store.similarity_search(query=state["query"], k=5)
        if not hits:
            return {
                **state,
                "retrieved_context": [],
                "error": "No relevant context found",
            }
        context = [
            {"page_content": doc.page_content, "metadata": doc.metadata}
            for doc in hits
        ]
        return {**state, "retrieved_context": context}
    except Exception as e:
        return {
            **state,
            "error": f"Exception in retrieve_context_for_query: {str(e)}",
        }
def answer_query(state: ProtocolCoachState) -> ProtocolCoachState:
    """Produce an answer to the user's question from retrieved context.

    Falls back to a canned apology when no context was retrieved or when
    the LLM returns nothing; exceptions populate both ``response`` and
    ``error`` so the UI always has something to show.
    """
    try:
        query = state["query"]
        context = state.get("retrieved_context", [])
        chat_history = state.get("chat_history", [])
        if not context:
            return {
                **state,
                "response": "I don't have enough context to answer that question about the protocol. Please try asking something else or upload relevant documents."
            }
        answer = llm_interface.answer_protocol_question(
            question=query,
            context=context,
            chat_history=chat_history,
        )
        if answer:
            return {**state, "response": answer}
        return {
            **state,
            "response": "I encountered an issue while generating a response. Please try again."
        }
    except Exception as e:
        return {
            **state,
            "response": f"Error: {str(e)}",
            "error": f"Exception in answer_query: {str(e)}"
        }
# =========================================================================
# Content Authoring Workflow Nodes
# =========================================================================
def retrieve_content_examples(state: ContentAuthoringState) -> ContentAuthoringState:
    """Look up example passages of the requested section type.

    Builds a search query from ``section_type`` and, when a target
    protocol is given, filters that protocol out of the results so only
    other documents serve as examples. Hits land in ``retrieved_context``.
    """
    try:
        section_type = state["section_type"]
        target_protocol_id = state.get("target_protocol_id")
        # Exclude the protocol being authored from its own example pool.
        filters = (
            {"protocol_id": {"$ne": target_protocol_id}}
            if target_protocol_id
            else None
        )
        docs = knowledge_store.similarity_search(
            query=f"{section_type} section for clinical study protocol",
            k=3,
            filter_dict=filters,
        )
        if not docs:
            return {
                **state,
                "retrieved_context": [],
                "error": "No relevant examples found",
            }
        examples = [
            {"page_content": doc.page_content, "metadata": doc.metadata}
            for doc in docs
        ]
        return {**state, "retrieved_context": examples}
    except Exception as e:
        return {
            **state,
            "error": f"Exception in retrieve_content_examples: {str(e)}",
        }
def generate_content(state: ContentAuthoringState) -> ContentAuthoringState:
    """Draft a protocol section from the retrieved example passages.

    Delegates to the LLM with the section type, example context, target
    protocol, and optional style guide. Failures populate ``error`` and
    leave an apologetic placeholder in ``generated_content``.
    """
    try:
        section_type = state["section_type"]
        examples = state.get("retrieved_context", [])
        if not examples:
            return {
                **state,
                "generated_content": "I don't have enough examples to generate a good section. Please upload more documents or try a different section type.",
                "error": "No context available for generation"
            }
        draft = llm_interface.generate_content_from_knowledge(
            section_type=section_type,
            context=examples,
            protocol_id=state.get("target_protocol_id"),
            style_guide=state.get("style_guide"),
        )
        if draft:
            return {**state, "generated_content": draft}
        return {
            **state,
            "generated_content": "I encountered an issue while generating content. Please try again.",
            "error": "Failed to generate content"
        }
    except Exception as e:
        return {
            **state,
            "generated_content": f"Error: {str(e)}",
            "error": f"Exception in generate_content: {str(e)}"
        }
def critique_content(state: ContentAuthoringState) -> ContentAuthoringState:
    """Review the generated section for quality and consistency.

    Placeholder implementation: a production version would run an LLM
    critique pass here, but for now the state flows through untouched.
    """
    return state
# =========================================================================
# Traceability Workflow Nodes
# =========================================================================
def retrieve_document_entities(state: TraceabilityState) -> TraceabilityState:
    """Load comparable entities for the source and target documents.

    Resolves each document's protocol ID, then pulls the entity lists
    named by ``entity_type`` (objectives, endpoints, or population) from
    the knowledge store into ``source_entities`` / ``target_entities``.
    Problems are reported via the ``error`` key rather than raised.
    """
    try:
        source_doc_id = state["source_document_id"]
        target_doc_id = state["target_document_id"]
        entity_type = state["entity_type"]
        docs = (
            knowledge_store.get_document_by_id(source_doc_id),
            knowledge_store.get_document_by_id(target_doc_id),
        )
        if not all(docs):
            return {**state, "error": "One or both documents not found"}
        protocol_ids = [doc.get("protocol_id") for doc in docs]
        if not all(protocol_ids):
            return {**state, "error": "Protocol ID missing from one or both documents"}
        # Dispatch table: entity type -> knowledge-store accessor.
        fetchers = {
            "objectives": knowledge_store.get_objectives_by_protocol_id,
            "endpoints": knowledge_store.get_endpoints_by_protocol_id,
            "population": knowledge_store.get_population_criteria_by_protocol_id,
        }
        source_entities, target_entities = [], []
        fetch = fetchers.get(entity_type)
        if fetch is not None:
            source_entities = fetch(protocol_ids[0])
            target_entities = fetch(protocol_ids[1])
        if not source_entities or not target_entities:
            return {**state, "error": f"No {entity_type} found in one or both documents"}
        return {
            **state,
            "source_entities": source_entities,
            "target_entities": target_entities,
        }
    except Exception as e:
        return {
            **state,
            "error": f"Exception in retrieve_document_entities: {str(e)}",
        }
def match_entities(state: TraceabilityState) -> TraceabilityState:
    """Pair each source entity with a candidate target entity.

    Naive first-match heuristic: a source entity with non-empty
    description/text is paired with the first target entity whose
    description/text is also non-empty. A real system would score
    similarity (e.g. with embeddings or an LLM) instead.
    """
    try:
        if "error" in state:
            return state
        targets = state.get("target_entities", [])
        matched_pairs = []
        for src in state.get("source_entities", []):
            src_text = src.get("description", src.get("text", ""))
            if not src_text:
                continue
            # First target entity carrying usable text wins.
            candidate = next(
                (tgt for tgt in targets
                 if tgt.get("description", tgt.get("text", ""))),
                None,
            )
            if candidate is None:
                continue
            matched_pairs.append({
                "source_entity": src,
                "target_entity": candidate,
                "source_text": src_text,
                "target_text": candidate.get("description", candidate.get("text", "")),
                "entity_type": state["entity_type"],
            })
        return {**state, "matched_pairs": matched_pairs}
    except Exception as e:
        return {
            **state,
            "error": f"Exception in match_entities: {str(e)}",
        }
def analyze_matches(state: TraceabilityState) -> TraceabilityState:
    """Summarize cross-document consistency from the matched entity pairs.

    With no matches the analysis short-circuits to a fixed message;
    otherwise both documents' metadata is fetched and the LLM produces
    the consistency analysis stored under ``analysis``.
    """
    try:
        if "error" in state:
            return state
        matched_pairs = state.get("matched_pairs", [])
        source_doc_id = state["source_document_id"]
        target_doc_id = state["target_document_id"]
        if not matched_pairs:
            return {
                **state,
                "analysis": "No matching entities found between the documents."
            }
        analysis = llm_interface.find_document_connections(
            source_doc_info=knowledge_store.get_document_by_id(source_doc_id),
            target_doc_info=knowledge_store.get_document_by_id(target_doc_id),
            entity_pairs=matched_pairs,
        )
        return {**state, "analysis": analysis}
    except Exception as e:
        return {
            **state,
            "error": f"Exception in analyze_matches: {str(e)}",
            "analysis": f"Error analyzing matches: {str(e)}"
        }
# =========================================================================
# Graph Building Functions
# =========================================================================
def build_document_extraction_graph():
    """Compile the sequential document-extraction workflow.

    Pipeline: parse -> study info -> objectives/endpoints -> population
    -> design -> knowledge-base storage. Error handling lives inside each
    node (they pass through on status == "error"), so the graph stays a
    simple linear chain without conditional edges.
    """
    steps = [
        ("parse_document", parse_document),
        ("extract_study_info", extract_study_info),
        ("extract_objectives_endpoints", extract_objectives_endpoints),
        ("extract_population_criteria", extract_population_criteria),
        ("extract_study_design", extract_study_design),
        ("store_in_knowledge_base", store_in_knowledge_base),
    ]
    workflow = StateGraph(DocumentExtractionState)
    for name, node in steps:
        workflow.add_node(name, node)
    for (src, _), (dst, _) in zip(steps, steps[1:]):
        workflow.add_edge(src, dst)
    workflow.add_edge(steps[-1][0], END)
    workflow.set_entry_point(steps[0][0])
    return workflow.compile()
def build_protocol_coach_graph():
    """Compile the two-step protocol-coach workflow (retrieve -> answer)."""
    workflow = StateGraph(ProtocolCoachState)
    for name, node in (("retrieve_context", retrieve_context_for_query),
                       ("answer_query", answer_query)):
        workflow.add_node(name, node)
    workflow.set_entry_point("retrieve_context")
    workflow.add_edge("retrieve_context", "answer_query")
    workflow.add_edge("answer_query", END)
    return workflow.compile()
def build_content_authoring_graph():
    """Compile the content-authoring workflow (examples -> draft -> critique)."""
    workflow = StateGraph(ContentAuthoringState)
    for name, node in (("retrieve_examples", retrieve_content_examples),
                       ("generate_content", generate_content),
                       ("critique_content", critique_content)):
        workflow.add_node(name, node)
    workflow.set_entry_point("retrieve_examples")
    workflow.add_edge("retrieve_examples", "generate_content")
    workflow.add_edge("generate_content", "critique_content")
    workflow.add_edge("critique_content", END)
    return workflow.compile()
def build_traceability_graph():
    """Compile the traceability workflow (retrieve -> match -> analyze)."""
    workflow = StateGraph(TraceabilityState)
    for name, node in (("retrieve_entities", retrieve_document_entities),
                       ("match_entities", match_entities),
                       ("analyze_matches", analyze_matches)):
        workflow.add_node(name, node)
    workflow.set_entry_point("retrieve_entities")
    workflow.add_edge("retrieve_entities", "match_entities")
    workflow.add_edge("match_entities", "analyze_matches")
    workflow.add_edge("analyze_matches", END)
    return workflow.compile()