# cryogenic22's picture
# Create app.py
# 69b6b11 verified
"""
Streamlit application for the Pharmaceutical R&D Knowledge Ecosystem.
"""
import streamlit as st
import os
import pandas as pd
import json
import tempfile
import time
from datetime import datetime
from pdf_processor import PDFProcessor
from knowledge_store import KnowledgeStore
from llm_interface import LLMInterface
from graph_builder import (
init_handlers,
build_document_extraction_graph,
build_protocol_coach_graph,
build_content_authoring_graph,
build_traceability_graph
)
# =========================================================================
# App Setup and Configuration
# =========================================================================
# Page configuration. Streamlit (at least in older releases) requires
# set_page_config to be the first Streamlit command of a script run, which is
# why it sits directly after the imports.
st.set_page_config(
    page_title="Pharma R&D Knowledge Ecosystem",
    page_icon="💊",
    layout="wide",
    initial_sidebar_state="expanded"
)
# Seed per-session state on the first run of a browser session. Later reruns
# find these keys already present and leave whatever they hold untouched.
# (The script is re-executed on every rerun, so each missing key receives a
# freshly built list/dict.)
_DEFAULT_SESSION_STATE = {
    "chat_history": [],
    "documents": [],
    "knowledge_base_stats": {
        "documents": 0,
        "studies": 0,
        "endpoints": 0,
        "objectives": 0,
        "vectors": 0,
    },
}
for _state_key, _state_default in _DEFAULT_SESSION_STATE.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# Initialize our handlers and graphs. st.cache_resource keeps the returned
# objects cached across reruns instead of rebuilding them on every run.
@st.cache_resource
def initialize_app():
    """Initialize app resources and LangGraph workflows.

    The Anthropic API key is taken from the ``ANTHROPIC_API_KEY`` environment
    variable, falling back to Streamlit secrets. If neither provides it,
    ``None`` is passed to ``init_handlers`` (how that is handled is decided
    there, not here).

    Returns:
        dict: the handlers ("pdf_processor", "knowledge_store",
        "llm_interface") and the built workflows ("extraction_graph",
        "coach_graph", "authoring_graph", "traceability_graph").
    """
    api_key = os.environ.get("ANTHROPIC_API_KEY")
    if not api_key:
        try:
            if "ANTHROPIC_API_KEY" in st.secrets:
                api_key = st.secrets["ANTHROPIC_API_KEY"]
        except Exception:
            # Accessing st.secrets raises when no secrets.toml exists at all
            # (FileNotFoundError or a Streamlit-specific subclass, depending
            # on the version). Treat that as "key not configured" rather than
            # crashing app startup.
            pass

    # Initialize handlers
    pdf_processor, knowledge_store, llm_interface = init_handlers(api_key)

    # Build LangGraph workflows
    extraction_graph = build_document_extraction_graph()
    coach_graph = build_protocol_coach_graph()
    authoring_graph = build_content_authoring_graph()
    traceability_graph = build_traceability_graph()

    return {
        "pdf_processor": pdf_processor,
        "knowledge_store": knowledge_store,
        "llm_interface": llm_interface,
        "extraction_graph": extraction_graph,
        "coach_graph": coach_graph,
        "authoring_graph": authoring_graph,
        "traceability_graph": traceability_graph
    }
# Fetch the (cached) shared resources once and bind each one to a module-level
# name, so the helper and render functions below can use them directly.
app_resources = initialize_app()
(
    pdf_processor,
    knowledge_store,
    llm_interface,
    extraction_graph,
    coach_graph,
    authoring_graph,
    traceability_graph,
) = (
    app_resources["pdf_processor"],
    app_resources["knowledge_store"],
    app_resources["llm_interface"],
    app_resources["extraction_graph"],
    app_resources["coach_graph"],
    app_resources["authoring_graph"],
    app_resources["traceability_graph"],
)
# =========================================================================
# Helper Functions
# =========================================================================
def update_knowledge_base_stats():
    """Recompute the knowledge-base counters and save them in session state.

    The counters are:
    - the number of stored documents;
    - the number of distinct, non-empty protocol IDs ("studies");
    - the objectives and endpoints stored for each of those protocols;
    - the vector store's "document_count".

    Any exception is shown with ``st.error``, and the previous stats are kept.
    """
    try:
        all_docs = knowledge_store.get_all_documents()
        study_ids = {
            doc["protocol_id"]
            for doc in all_docs
            if "protocol_id" in doc and doc["protocol_id"]
        }

        vector_info = knowledge_store.get_vector_store_stats()
        chunk_total = vector_info.get("document_count", 0)

        objective_total = 0
        endpoint_total = 0
        for study_id in study_ids:
            study_objectives = knowledge_store.get_objectives_by_protocol_id(study_id)
            study_endpoints = knowledge_store.get_endpoints_by_protocol_id(study_id)
            objective_total += len(study_objectives)
            endpoint_total += len(study_endpoints)

        st.session_state.knowledge_base_stats = {
            "documents": len(all_docs),
            "studies": len(study_ids),
            "objectives": objective_total,
            "endpoints": endpoint_total,
            "vectors": chunk_total
        }
    except Exception as exc:
        st.error(f"Error updating knowledge base stats: {exc}")
def process_document(uploaded_file):
    """Run an uploaded PDF through the extraction workflow and store it.

    The upload is written to a temporary ``.pdf`` file, and that path is
    passed to ``extraction_graph`` as ``document_path``. Progress is shown
    with a Streamlit progress bar and status line. The temporary file is
    always removed afterwards, whatever the outcome.

    Args:
        uploaded_file: A Streamlit uploaded file. Only ``.name`` and
            ``.getbuffer()`` are used.

    Returns:
        dict: On success, ``{"status": "success", "filename",
        "document_id", "protocol_id"}``. On failure, ``{"status": "error",
        "error", "filename"}``.
    """
    # Initialised up front so `finally` knows whether a temp file exists,
    # without inspecting locals().
    file_path = None
    try:
        progress_bar = st.progress(0)
        status_text = st.empty()

        # Step 1: persist the upload so the workflow can read it by path.
        status_text.text("Saving uploaded file...")
        progress_bar.progress(10)
        # delete=False: the file must outlive the `with` block; it is removed
        # in `finally`. The path is recorded BEFORE writing so that a failed
        # write does not leak the already-created file.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
            file_path = tmp_file.name
            tmp_file.write(uploaded_file.getbuffer())

        # Step 2: run the LangGraph extraction workflow.
        status_text.text("Parsing document...")
        progress_bar.progress(20)
        initial_state = {
            "document_path": file_path,
            "status": "initialized"
        }
        result_state = extraction_graph.invoke(initial_state)

        if result_state.get("status") == "error":
            error_message = result_state.get("error", "Unknown error")
            status_text.text(f"Error: {error_message}")
            progress_bar.progress(100)
            return {
                "status": "error",
                "error": error_message,
                "filename": uploaded_file.name
            }

        status_text.text("Processing completed successfully!")
        progress_bar.progress(100)

        # Refresh the sidebar counters now that the store has new content.
        update_knowledge_base_stats()

        # `or {}` also covers a workflow that sets extracted_study to None,
        # which previously raised AttributeError after a successful run.
        extracted_study = result_state.get("extracted_study") or {}
        return {
            "status": "success",
            "filename": uploaded_file.name,
            "document_id": result_state.get("document_id"),
            "protocol_id": extracted_study.get("protocol_id")
        }
    except Exception as e:
        st.error(f"Error processing document: {e}")
        return {
            "status": "error",
            "error": str(e),
            "filename": uploaded_file.name
        }
    finally:
        if file_path is not None:
            try:
                os.unlink(file_path)
            except OSError:
                # Best-effort cleanup: a leftover temp file must not turn the
                # processing result into an error.
                pass
def chat_with_protocol_coach(query):
    """Answer ``query`` using the Protocol Coach LangGraph workflow.

    The session's ``chat_history`` list is passed to the workflow as-is (the
    same list object, not a copy).

    Returns:
        dict: On success, ``{"status": "success", "response", "context"}``,
        where ``context`` is the workflow's ``retrieved_context`` (default
        ``[]``). If anything raises, ``{"status": "error", "error": <message>}``.
    """
    # NOTE(review): render_protocol_coach_tab appends the user's query to
    # chat_history before calling this, so the workflow likely sees the
    # current question both in "query" and as the last history entry —
    # confirm whether coach_graph expects that.
    try:
        outcome = coach_graph.invoke({
            "query": query,
            "chat_history": st.session_state.chat_history,
        })
        answer = outcome.get("response", "I couldn't generate a response.")
        sources = outcome.get("retrieved_context", [])
    except Exception as exc:
        return {"status": "error", "error": str(exc)}
    return {"status": "success", "response": answer, "context": sources}
def generate_document_section(section_type, protocol_id=None, style_guide=None):
    """Draft a document section with the content-authoring LangGraph workflow.

    Args:
        section_type: Name of the section to draft (e.g. "Study Design").
        protocol_id: Optional protocol ID, forwarded as ``target_protocol_id``.
        style_guide: Optional free-text style/content guidance.

    Returns:
        dict: On success, ``{"status": "success", "content", "context"}``,
        where ``context`` is the workflow's ``retrieved_context`` (default
        ``[]``). If anything raises, ``{"status": "error", "error": <message>}``.
    """
    request = {
        "section_type": section_type,
        "target_protocol_id": protocol_id,
        "style_guide": style_guide,
    }
    try:
        outcome = authoring_graph.invoke(request)
        draft = outcome.get("generated_content", "I couldn't generate the content.")
        sources = outcome.get("retrieved_context", [])
    except Exception as exc:
        return {"status": "error", "error": str(exc)}
    return {"status": "success", "content": draft, "context": sources}
def analyze_document_traceability(source_id, target_id, entity_type):
    """Compare one entity type between two stored documents.

    Args:
        source_id: Document ID of the source document.
        target_id: Document ID of the target document.
        entity_type: Entity kind to compare (the UI passes "objectives",
            "endpoints" or "population").

    Returns:
        dict: On success, ``{"status": "success", "analysis",
        "matched_pairs"}``, where ``matched_pairs`` defaults to ``[]``. If
        anything raises, ``{"status": "error", "error": <message>}``.
    """
    request = {
        "source_document_id": source_id,
        "target_document_id": target_id,
        "entity_type": entity_type,
    }
    try:
        outcome = traceability_graph.invoke(request)
        report = outcome.get("analysis", "I couldn't perform the analysis.")
        pairs = outcome.get("matched_pairs", [])
    except Exception as exc:
        return {"status": "error", "error": str(exc)}
    return {"status": "success", "analysis": report, "matched_pairs": pairs}
# =========================================================================
# Sidebar: Document Upload and Management
# =========================================================================
def render_sidebar():
    """Render the document-management sidebar.

    The sidebar shows knowledge-base stats, PDF upload and processing, the
    list of processed documents, and a "Refresh Stats" button.

    The stats metrics are written last, into a container reserved at the top
    of the sidebar. As a result, counts changed by document processing or by
    "Refresh Stats" in this same script run appear immediately rather than one
    rerun late.
    """
    st.sidebar.title("Document Management")

    # Reserve the stats slot now (keeps it at the top) but fill it at the end.
    stats_container = st.sidebar.container()

    # Document Upload
    st.sidebar.subheader("Upload Documents")
    uploaded_files = st.sidebar.file_uploader(
        "Upload Protocol/SAP PDFs",
        type="pdf",
        accept_multiple_files=True,
        help="Upload clinical documents (Protocol, SAP, etc.) to add to the knowledge base."
    )

    # The button is only offered once at least one file has been uploaded.
    if uploaded_files and st.sidebar.button("Process Documents"):
        with st.sidebar.expander("Processing Results", expanded=True):
            for uploaded_file in uploaded_files:
                st.write(f"Processing: {uploaded_file.name}")
                result = process_document(uploaded_file)
                if result["status"] == "success":
                    st.success(f"Successfully processed {result['filename']}")
                    # Documents are de-duplicated by filename; a re-processed
                    # file keeps its original list entry.
                    already_listed = any(
                        doc.get("filename") == result["filename"]
                        for doc in st.session_state.documents
                    )
                    if not already_listed:
                        st.session_state.documents.append({
                            "filename": result["filename"],
                            "document_id": result.get("document_id"),
                            "protocol_id": result.get("protocol_id"),
                            "processed_date": datetime.now().strftime("%Y-%m-%d %H:%M")
                        })
                else:
                    st.error(f"Error processing {result['filename']}: {result.get('error', 'Unknown error')}")

    # Document list
    st.sidebar.subheader("Processed Documents")
    if not st.session_state.documents:
        st.sidebar.info("No documents processed yet.")
    else:
        for doc in st.session_state.documents:
            with st.sidebar.expander(f"{doc['filename']}"):
                st.write(f"**Protocol ID:** {doc.get('protocol_id', 'Unknown')}")
                st.write(f"**Processed:** {doc.get('processed_date', 'Unknown')}")

    # Refresh Stats Button
    if st.sidebar.button("Refresh Stats"):
        update_knowledge_base_stats()
        st.sidebar.success("Stats refreshed!")

    # Knowledge Base Stats: drawn last so they include any update made above.
    stats = st.session_state.knowledge_base_stats
    stats_container.subheader("Knowledge Base Stats")
    col1, col2 = stats_container.columns(2)
    col1.metric("Documents", stats["documents"])
    col2.metric("Studies", stats["studies"])
    col1.metric("Objectives", stats["objectives"])
    col2.metric("Endpoints", stats["endpoints"])
    stats_container.metric("Vector Chunks", stats["vectors"])
# =========================================================================
# Main Content Tabs
# =========================================================================
def render_protocol_coach_tab():
    """Render the Protocol Coach chat tab.

    The tab first replays the stored conversation. When the user submits a
    question, it:
    - records and echoes the question;
    - asks the coach and shows the answer, plus the retrieved context sources
      when the "Show Context Sources" debug toggle is on;
    - appends the assistant turn (or an error turn) to the chat history.
    """
    st.header("Protocol Coach Chatbot")
    st.info("Ask questions about the protocol documents in the knowledge base. The Protocol Coach will retrieve relevant information to answer your questions.")

    for turn in st.session_state.chat_history:
        with st.chat_message(turn["role"]):
            st.markdown(turn["content"])

    query = st.chat_input("Ask about protocols...")
    if not query:
        return

    st.session_state.chat_history.append({"role": "user", "content": query})
    with st.chat_message("user"):
        st.markdown(query)

    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            result = chat_with_protocol_coach(query)

        if result["status"] != "success":
            failure_text = f"Error: {result.get('error', 'Unknown error')}"
            st.error(failure_text)
            st.session_state.chat_history.append({"role": "assistant", "content": failure_text})
            return

        st.markdown(result["response"])

        if st.session_state.get("debug_mode", False):
            with st.expander("Context Sources"):
                for number, source in enumerate(result.get("context", []), start=1):
                    meta = source.get('metadata', {})
                    st.write(f"**Source {number}:** {meta.get('source', 'Unknown')}")
                    st.write(f"**Section:** {meta.get('section', 'Unknown')}")
                    st.write("---")

        st.session_state.chat_history.append({"role": "assistant", "content": result["response"]})
def render_content_authoring_tab():
    """Render the Content Authoring tab.

    The left column holds the generation settings:
    - a section-type picker;
    - an optional target protocol, built from the protocol IDs of processed
      documents, each listed once;
    - an optional style guide;
    - the Generate button and the "Show Context Sources" debug toggle.

    The right column shows the generated section and, in debug mode, the
    context it was built from.
    """
    st.header("Content Authoring Assistant")
    st.info("Generate document sections based on knowledge extracted from similar documents.")

    col1, col2 = st.columns([1, 1])

    with col1:
        st.subheader("Content Generation Settings")

        # Section Type Selection
        section_types = [
            "Introduction",
            "Objectives and Endpoints",
            "Study Design",
            "Study Population",
            "Statistical Considerations",
            "Inclusion Criteria",
            "Exclusion Criteria",
            "Safety Assessments",
            "Pharmacokinetic Assessments"
        ]
        section_type = st.selectbox("Select Section Type", section_types)

        # Protocol Selection for Context (Optional). Several documents (e.g. a
        # Protocol and its SAP) can share one protocol ID, so each ID is listed
        # once, in first-seen order — matching the Knowledge Explorer filter.
        protocol_options = ["--None--"]
        for doc in st.session_state.documents:
            protocol_id = doc.get("protocol_id")
            if protocol_id and protocol_id not in protocol_options:
                protocol_options.append(protocol_id)
        target_protocol = st.selectbox(
            "Target Protocol ID (Optional)",
            protocol_options
        )
        target_protocol = None if target_protocol == "--None--" else target_protocol

        # Style Guide (Optional)
        style_guide = st.text_area(
            "Style Guide (Optional)",
            placeholder="Enter any specific style guidelines or content requirements..."
        )

        generate_button = st.button("Generate Content")

        # Debug toggle; other tabs read st.session_state.debug_mode as well.
        st.session_state.debug_mode = st.checkbox("Show Context Sources", value=st.session_state.get("debug_mode", False))

    with col2:
        st.subheader("Generated Content")
        if generate_button:
            with st.spinner("Generating content..."):
                result = generate_document_section(
                    section_type=section_type,
                    protocol_id=target_protocol,
                    # An empty text area means "no style guide".
                    style_guide=style_guide or None
                )
            if result["status"] == "success":
                st.markdown(result["content"])
                # Show context sources if debug mode enabled
                if st.session_state.get("debug_mode", False):
                    with st.expander("Context Sources"):
                        for i, ctx in enumerate(result.get("context", [])):
                            st.write(f"**Source {i+1}:** {ctx.get('metadata', {}).get('source', 'Unknown')}")
                            st.write(f"**Section:** {ctx.get('metadata', {}).get('section', 'Unknown')}")
                            st.write("---")
            else:
                st.error(f"Error: {result.get('error', 'Unknown error')}")
def render_traceability_tab():
    """Render the Cross-Document Traceability tab.

    The left column lets the user pick a source document, a target document
    (both drawn from the processed-documents list) and an entity type. The
    right column runs the traceability workflow and shows its analysis. When
    the debug toggle is on, it also shows the matched entity pairs.
    """
    st.header("Cross-Document Traceability")
    st.info("Analyze relationships between related documents (e.g., Protocol and SAP).")

    def _option_label(option):
        # Selectbox options are dicts; show only their human-readable label.
        return option["label"]

    settings_col, results_col = st.columns([1, 1])

    with settings_col:
        st.subheader("Traceability Analysis Settings")

        doc_choices = [
            {
                "id": doc.get("document_id", ""),
                "label": f"{doc['filename']} ({doc.get('protocol_id', 'Unknown')})",
            }
            for doc in st.session_state.documents
        ]

        source_doc = st.selectbox(
            "Source Document",
            options=[{"id": "", "label": "--Select Source Document--"}, *doc_choices],
            format_func=_option_label
        )
        target_doc = st.selectbox(
            "Target Document",
            options=[{"id": "", "label": "--Select Target Document--"}, *doc_choices],
            format_func=_option_label
        )
        entity_type = st.selectbox(
            "Entity Type to Compare",
            options=[
                {"id": "objectives", "label": "Study Objectives"},
                {"id": "endpoints", "label": "Endpoints"},
                {"id": "population", "label": "Population Criteria"},
            ],
            format_func=_option_label
        )

        analyze_clicked = st.button("Analyze Traceability")

    with results_col:
        st.subheader("Analysis Results")
        if not analyze_clicked:
            return
        # The placeholder entries carry an empty id.
        if not (source_doc["id"] and target_doc["id"]):
            st.error("Please select both source and target documents.")
            return

        with st.spinner("Analyzing traceability..."):
            result = analyze_document_traceability(
                source_id=source_doc["id"],
                target_id=target_doc["id"],
                entity_type=entity_type["id"]
            )

        if result["status"] != "success":
            st.error(f"Error: {result.get('error', 'Unknown error')}")
            return

        st.markdown(result["analysis"])
        if st.session_state.get("debug_mode", False) and result.get("matched_pairs"):
            with st.expander("Matched Entity Pairs"):
                for number, pair in enumerate(result["matched_pairs"], start=1):
                    st.write(f"**Pair {number}**")
                    st.write(f"**Source:** {pair.get('source_text', 'Unknown')}")
                    st.write(f"**Target:** {pair.get('target_text', 'Unknown')}")
                    st.write("---")
def _collect_protocol_ids(documents):
    """Return the set of non-empty ``protocol_id`` values found in ``documents``."""
    return {doc["protocol_id"] for doc in documents if "protocol_id" in doc and doc["protocol_id"]}


def _fetch_explorer_data(entity_id, protocol_id=None):
    """Load the Knowledge Explorer rows for one entity type.

    Args:
        entity_id: One of "studies", "objectives", "endpoints", "population"
            or "documents". Any other value yields ``[]``.
        protocol_id: Optional protocol ID to restrict the rows to.

    Without a protocol filter, the per-protocol entity types (objectives,
    endpoints, population criteria) are gathered across every protocol ID
    referenced by a stored document.
    """
    if entity_id == "studies":
        if protocol_id:
            study = knowledge_store.get_study_by_protocol_id(protocol_id)
            # A missing study must show "No data found", not a one-row
            # table containing None.
            return [study] if study else []
        return knowledge_store.get_all_studies()

    if entity_id == "documents":
        if protocol_id:
            return knowledge_store.get_documents_by_protocol_id(protocol_id)
        return knowledge_store.get_all_documents()

    # Per-protocol entity types share the same "one getter per protocol" shape.
    getter_name = {
        "objectives": "get_objectives_by_protocol_id",
        "endpoints": "get_endpoints_by_protocol_id",
        "population": "get_population_criteria_by_protocol_id",
    }.get(entity_id)
    if getter_name is None:
        return []
    getter = getattr(knowledge_store, getter_name)

    if protocol_id:
        return getter(protocol_id)
    rows = []
    for pid in _collect_protocol_ids(knowledge_store.get_all_documents()):
        rows.extend(getter(pid))
    return rows


def render_knowledge_explorer_tab():
    """Render the Knowledge Base Explorer tab.

    The user picks an entity type, an optional protocol filter (IDs of
    processed documents, each listed once) and an optional free-text search.
    Matching rows are shown as a table, falling back to raw JSON when they
    cannot be tabulated. Retrieval errors are reported with ``st.error``.
    """
    st.header("Knowledge Base Explorer")
    st.info("Explore the structured data extracted from documents in the knowledge base.")

    # Entity Type Selection
    entity_types = [
        {"id": "studies", "label": "Studies"},
        {"id": "objectives", "label": "Study Objectives"},
        {"id": "endpoints", "label": "Endpoints"},
        {"id": "population", "label": "Population Criteria"},
        {"id": "documents", "label": "Documents"}
    ]
    entity_type = st.selectbox(
        "Select Entity Type",
        options=entity_types,
        format_func=lambda x: x["label"]
    )

    # Filter by Protocol ID (Optional)
    protocol_options = ["--All Protocols--"]
    for doc in st.session_state.documents:
        if doc.get("protocol_id") and doc.get("protocol_id") not in protocol_options:
            protocol_options.append(doc.get("protocol_id"))
    filter_protocol = st.selectbox(
        "Filter by Protocol ID",
        protocol_options
    )
    filter_protocol = None if filter_protocol == "--All Protocols--" else filter_protocol

    # Search Query (Optional)
    search_query = st.text_input(
        "Search Query (Optional)",
        placeholder="Enter text to search for..."
    )

    st.subheader("Results")
    try:
        # `or []` guards against a store method returning None.
        data = _fetch_explorer_data(entity_type["id"], filter_protocol) or []

        if search_query:
            needle = search_query.lower()
            # The JSON text is only used for substring matching, so values
            # json cannot encode (dates, custom objects, ...) are stringified
            # rather than failing the whole tab.
            data = [
                item for item in data
                if needle in json.dumps(item, default=str).lower()
            ]

        if not data:
            st.info("No data found.")
        else:
            st.write(f"{len(data)} items found")
            # Display as table if possible, otherwise as JSON
            try:
                df = pd.DataFrame(data)
                st.dataframe(df, use_container_width=True)
            except Exception:
                st.json(data)
    except Exception as e:
        st.error(f"Error retrieving data: {e}")
# =========================================================================
# Main App
# =========================================================================
def main():
    """Render the whole app: title, sidebar, the four feature tabs and footer.

    Knowledge-base stats are loaded before the sidebar is drawn, so the first
    page load shows real counts instead of the zero placeholders.
    """
    st.title("🧠 Pharmaceutical R&D Knowledge Ecosystem")

    # While the stored stats still report zero documents (first load, or an
    # empty knowledge base), re-query the store. This must run before
    # render_sidebar() so the sidebar metrics show the fresh values this run.
    if st.session_state.knowledge_base_stats["documents"] == 0:
        update_knowledge_base_stats()

    # Render the sidebar for document management
    render_sidebar()

    # Main content tabs
    tab1, tab2, tab3, tab4 = st.tabs([
        "📝 Content Authoring",
        "🤖 Protocol Coach",
        "🔍 Knowledge Explorer",
        "🔄 Cross-Document Traceability"
    ])

    with tab1:
        render_content_authoring_tab()
    with tab2:
        render_protocol_coach_tab()
    with tab3:
        render_knowledge_explorer_tab()
    with tab4:
        render_traceability_tab()

    # Footer
    st.markdown("---")
    st.caption("Pharmaceutical R&D Knowledge Ecosystem | A demonstration of AI-assisted document processing and knowledge management")


if __name__ == "__main__":
    main()