"""
Streamlit application for the Pharmaceutical R&D Knowledge Ecosystem.

The app exposes four tabs (content authoring, protocol coach chat, knowledge
explorer, cross-document traceability) plus a sidebar for uploading and
processing PDF documents. Each feature delegates to a LangGraph workflow
built by ``graph_builder``.
"""

import json
import os
import tempfile
import time
from datetime import datetime

import pandas as pd
import streamlit as st

from pdf_processor import PDFProcessor
from knowledge_store import KnowledgeStore
from llm_interface import LLMInterface
from graph_builder import (
    init_handlers,
    build_document_extraction_graph,
    build_protocol_coach_graph,
    build_content_authoring_graph,
    build_traceability_graph
)

# =========================================================================
# App Setup and Configuration
# =========================================================================

# Page configuration
st.set_page_config(
    page_title="Pharma R&D Knowledge Ecosystem",
    page_icon="💊",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Initialize session state variables if they don't exist
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

if "documents" not in st.session_state:
    st.session_state.documents = []

if "knowledge_base_stats" not in st.session_state:
    st.session_state.knowledge_base_stats = {
        "documents": 0,
        "studies": 0,
        "endpoints": 0,
        "objectives": 0,
        "vectors": 0
    }


# Initialize our handlers and graphs
@st.cache_resource
def initialize_app():
    """Initialize app resources and LangGraph workflows.

    The API key is read from the ``ANTHROPIC_API_KEY`` environment variable
    first and from ``st.secrets`` second. If neither provides a key,
    ``init_handlers`` is called with ``None``.

    Returns:
        dict: The three handlers (``pdf_processor``, ``knowledge_store``,
        ``llm_interface``) and the four compiled workflows
        (``extraction_graph``, ``coach_graph``, ``authoring_graph``,
        ``traceability_graph``), keyed by those names.
    """
    # Get API key from environment or secrets
    api_key = os.environ.get("ANTHROPIC_API_KEY")
    if not api_key and hasattr(st, "secrets"):
        try:
            if "ANTHROPIC_API_KEY" in st.secrets:
                api_key = st.secrets["ANTHROPIC_API_KEY"]
        except FileNotFoundError:
            # Streamlit raises FileNotFoundError (or a subclass of it) when
            # no secrets.toml is present; treat that as "no key configured"
            # instead of crashing at startup.
            api_key = None

    # Initialize handlers
    pdf_processor, knowledge_store, llm_interface = init_handlers(api_key)

    # Build LangGraph workflows
    return {
        "pdf_processor": pdf_processor,
        "knowledge_store": knowledge_store,
        "llm_interface": llm_interface,
        "extraction_graph": build_document_extraction_graph(),
        "coach_graph": build_protocol_coach_graph(),
        "authoring_graph": build_content_authoring_graph(),
        "traceability_graph": build_traceability_graph()
    }


# Initialize app resources
app_resources = initialize_app()
pdf_processor = app_resources["pdf_processor"]
knowledge_store = app_resources["knowledge_store"]
llm_interface = app_resources["llm_interface"]
extraction_graph = app_resources["extraction_graph"]
coach_graph = app_resources["coach_graph"]
authoring_graph = app_resources["authoring_graph"]
traceability_graph = app_resources["traceability_graph"]

# =========================================================================
# Helper Functions
# =========================================================================

# Knowledge-store method used to fetch each per-protocol entity type. Stored
# as names so only the method actually needed is looked up at call time.
_PER_PROTOCOL_GETTERS = {
    "objectives": "get_objectives_by_protocol_id",
    "endpoints": "get_endpoints_by_protocol_id",
    "population": "get_population_criteria_by_protocol_id",
}


def _collect_protocol_ids(documents):
    """Return the set of non-empty ``protocol_id`` values found in *documents*."""
    return {
        doc["protocol_id"]
        for doc in documents
        if "protocol_id" in doc and doc["protocol_id"]
    }


def _gather_for_all_protocols(fetch_by_protocol_id):
    """Concatenate ``fetch_by_protocol_id(pid)`` over every known protocol ID."""
    items = []
    for pid in _collect_protocol_ids(knowledge_store.get_all_documents()):
        items.extend(fetch_by_protocol_id(pid))
    return items


def _session_protocol_options(placeholder):
    """Build selectbox options: *placeholder* plus unique session protocol IDs."""
    options = [placeholder]
    for doc in st.session_state.documents:
        protocol_id = doc.get("protocol_id")
        if protocol_id and protocol_id not in options:
            options.append(protocol_id)
    return options


def _render_context_sources(contexts):
    """Show the retrieved context sources in an expander when debug mode is on."""
    if not st.session_state.get("debug_mode", False):
        return
    with st.expander("Context Sources"):
        for i, ctx in enumerate(contexts):
            st.write(f"**Source {i+1}:** {ctx.get('metadata', {}).get('source', 'Unknown')}")
            st.write(f"**Section:** {ctx.get('metadata', {}).get('section', 'Unknown')}")
            st.write("---")


def update_knowledge_base_stats():
    """Update the knowledge base statistics in session state.

    Counts documents, distinct protocol IDs ("studies"), objectives and
    endpoints across all protocols, and the vector store's ``document_count``.
    Any exception is reported with ``st.error`` and the previous stats are
    left untouched.
    """
    try:
        # Get counts of different entity types
        documents = knowledge_store.get_all_documents()
        protocol_ids = _collect_protocol_ids(documents)

        # Get vector store stats
        vector_stats = knowledge_store.get_vector_store_stats()
        vector_count = vector_stats.get("document_count", 0)

        # Count objectives and endpoints across all protocols
        objective_count = 0
        endpoint_count = 0
        for protocol_id in protocol_ids:
            objective_count += len(
                knowledge_store.get_objectives_by_protocol_id(protocol_id)
            )
            endpoint_count += len(
                knowledge_store.get_endpoints_by_protocol_id(protocol_id)
            )

        # Update session state
        st.session_state.knowledge_base_stats = {
            "documents": len(documents),
            "studies": len(protocol_ids),
            "objectives": objective_count,
            "endpoints": endpoint_count,
            "vectors": vector_count
        }
    except Exception as e:
        st.error(f"Error updating knowledge base stats: {e}")


def process_document(uploaded_file):
    """Process an uploaded document and store in knowledge base.

    The upload is written to a temporary ``.pdf`` file, run through the
    extraction workflow, and the temporary file is removed afterwards
    regardless of outcome.

    Args:
        uploaded_file: The object returned by ``st.file_uploader``; its
            ``getbuffer()`` and ``name`` are used.

    Returns:
        dict: ``{"status": "success", "filename", "document_id",
        "protocol_id"}`` on success, or ``{"status": "error", "error",
        "filename"}`` when the workflow reports an error or raises.
    """
    file_path = None
    try:
        # Create a progress bar
        progress_bar = st.progress(0)
        status_text = st.empty()

        # Step 1: Save the uploaded file
        status_text.text("Saving uploaded file...")
        progress_bar.progress(10)

        # Save uploaded file temporarily. The path is recorded before the
        # write so the file is still cleaned up if the write itself fails.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
            file_path = tmp_file.name
            tmp_file.write(uploaded_file.getbuffer())

        # Step 2: Process through LangGraph extraction workflow
        status_text.text("Parsing document...")
        progress_bar.progress(20)

        # Initialize state for extraction
        initial_state = {
            "document_path": file_path,
            "status": "initialized"
        }

        # Run extraction workflow
        result_state = extraction_graph.invoke(initial_state)

        # Update progress based on status
        if result_state.get("status") == "error":
            status_text.text(f"Error: {result_state.get('error', 'Unknown error')}")
            progress_bar.progress(100)
            return {
                "status": "error",
                "error": result_state.get("error", "Unknown error"),
                "filename": uploaded_file.name
            }

        # Update progress
        status_text.text("Processing completed successfully!")
        progress_bar.progress(100)

        # Update knowledge base stats
        update_knowledge_base_stats()

        # "extracted_study" may be present but None; fall back to an empty
        # dict so a successful run is not reported as an error.
        extracted_study = result_state.get("extracted_study") or {}

        # Return result
        return {
            "status": "success",
            "filename": uploaded_file.name,
            "document_id": result_state.get("document_id"),
            "protocol_id": extracted_study.get("protocol_id")
        }

    except Exception as e:
        st.error(f"Error processing document: {e}")
        return {
            "status": "error",
            "error": str(e),
            "filename": uploaded_file.name
        }
    finally:
        # Clean up temporary file (best effort: a failed delete must not
        # mask the processing result).
        if file_path is not None:
            try:
                os.unlink(file_path)
            except OSError:
                pass


def chat_with_protocol_coach(query):
    """Process a query through the Protocol Coach.

    Args:
        query: The user's question.

    Returns:
        dict: ``{"status": "success", "response", "context"}`` or
        ``{"status": "error", "error"}`` if the workflow raises.
    """
    try:
        # Initialize state for Protocol Coach
        initial_state = {
            "query": query,
            "chat_history": st.session_state.chat_history
        }

        # Run Protocol Coach workflow
        result_state = coach_graph.invoke(initial_state)

        return {
            "status": "success",
            "response": result_state.get("response", "I couldn't generate a response."),
            "context": result_state.get("retrieved_context", [])
        }
    except Exception as e:
        return {
            "status": "error",
            "error": str(e)
        }


def generate_document_section(section_type, protocol_id=None, style_guide=None):
    """Generate a document section using the content authoring workflow.

    Args:
        section_type: Name of the section to generate.
        protocol_id: Optional protocol ID passed as ``target_protocol_id``.
        style_guide: Optional free-text style guidance.

    Returns:
        dict: ``{"status": "success", "content", "context"}`` or
        ``{"status": "error", "error"}`` if the workflow raises.
    """
    try:
        # Initialize state for Content Authoring
        initial_state = {
            "section_type": section_type,
            "target_protocol_id": protocol_id,
            "style_guide": style_guide
        }

        # Run Content Authoring workflow
        result_state = authoring_graph.invoke(initial_state)

        return {
            "status": "success",
            "content": result_state.get("generated_content", "I couldn't generate the content."),
            "context": result_state.get("retrieved_context", [])
        }
    except Exception as e:
        return {
            "status": "error",
            "error": str(e)
        }


def analyze_document_traceability(source_id, target_id, entity_type):
    """Analyze traceability between two documents.

    Args:
        source_id: Document ID of the source document.
        target_id: Document ID of the target document.
        entity_type: Entity type identifier to compare.

    Returns:
        dict: ``{"status": "success", "analysis", "matched_pairs"}`` or
        ``{"status": "error", "error"}`` if the workflow raises.
    """
    try:
        # Initialize state for Traceability Analysis
        initial_state = {
            "source_document_id": source_id,
            "target_document_id": target_id,
            "entity_type": entity_type
        }

        # Run Traceability Analysis workflow
        result_state = traceability_graph.invoke(initial_state)

        return {
            "status": "success",
            "analysis": result_state.get("analysis", "I couldn't perform the analysis."),
            "matched_pairs": result_state.get("matched_pairs", [])
        }
    except Exception as e:
        return {
            "status": "error",
            "error": str(e)
        }


# =========================================================================
# Sidebar: Document Upload and Management
# =========================================================================

def render_sidebar():
    """Render the sidebar for document management.

    Shows the knowledge base stats, the PDF uploader with a "Process
    Documents" button, the list of documents processed in this session, and
    a "Refresh Stats" button.
    """
    st.sidebar.title("Document Management")

    # Knowledge Base Stats
    st.sidebar.subheader("Knowledge Base Stats")
    stats = st.session_state.knowledge_base_stats

    col1, col2 = st.sidebar.columns(2)
    col1.metric("Documents", stats["documents"])
    col2.metric("Studies", stats["studies"])
    col1.metric("Objectives", stats["objectives"])
    col2.metric("Endpoints", stats["endpoints"])
    st.sidebar.metric("Vector Chunks", stats["vectors"])

    # Document Upload
    st.sidebar.subheader("Upload Documents")
    uploaded_files = st.sidebar.file_uploader(
        "Upload Protocol/SAP PDFs",
        type="pdf",
        accept_multiple_files=True,
        help="Upload clinical documents (Protocol, SAP, etc.) to add to the knowledge base."
    )

    # Process uploaded files if any
    if uploaded_files and st.sidebar.button("Process Documents"):
        with st.sidebar.expander("Processing Results", expanded=True):
            for uploaded_file in uploaded_files:
                st.write(f"Processing: {uploaded_file.name}")
                result = process_document(uploaded_file)

                if result["status"] != "success":
                    st.error(
                        f"Error processing {result['filename']}: "
                        f"{result.get('error', 'Unknown error')}"
                    )
                    continue

                st.success(f"Successfully processed {result['filename']}")

                # Add to documents list if not already there
                doc_exists = any(
                    doc.get("filename") == result["filename"]
                    for doc in st.session_state.documents
                )
                if not doc_exists:
                    st.session_state.documents.append({
                        "filename": result["filename"],
                        "document_id": result.get("document_id"),
                        "protocol_id": result.get("protocol_id"),
                        "processed_date": datetime.now().strftime("%Y-%m-%d %H:%M")
                    })

    # Document list
    st.sidebar.subheader("Processed Documents")
    if not st.session_state.documents:
        st.sidebar.info("No documents processed yet.")
    else:
        for doc in st.session_state.documents:
            with st.sidebar.expander(f"{doc['filename']}"):
                # The keys exist even when extraction found nothing, so use
                # `or` to show "Unknown" instead of a literal "None".
                st.write(f"**Protocol ID:** {doc.get('protocol_id') or 'Unknown'}")
                st.write(f"**Processed:** {doc.get('processed_date') or 'Unknown'}")

    # Refresh Stats Button
    if st.sidebar.button("Refresh Stats"):
        update_knowledge_base_stats()
        st.sidebar.success("Stats refreshed!")


# =========================================================================
# Main Content Tabs
# =========================================================================

def render_protocol_coach_tab():
    """Render the Protocol Coach chatbot tab.

    Replays the stored chat history, then handles a new query: the user
    message and the assistant reply (or error text) are both appended to
    ``st.session_state.chat_history``.
    """
    st.header("Protocol Coach Chatbot")
    st.info(
        "Ask questions about the protocol documents in the knowledge base. "
        "The Protocol Coach will retrieve relevant information to answer your questions."
    )

    # Initialize or display chat history
    for message in st.session_state.chat_history:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # Chat input
    if query := st.chat_input("Ask about protocols..."):
        # Add user message to chat history and display
        st.session_state.chat_history.append({"role": "user", "content": query})
        with st.chat_message("user"):
            st.markdown(query)

        # Process query
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                result = chat_with_protocol_coach(query)

                if result["status"] == "success":
                    st.markdown(result["response"])

                    # Show context sources if debug mode enabled
                    _render_context_sources(result.get("context", []))

                    # Add assistant response to chat history
                    st.session_state.chat_history.append(
                        {"role": "assistant", "content": result["response"]}
                    )
                else:
                    error_text = f"Error: {result.get('error', 'Unknown error')}"
                    st.error(error_text)
                    st.session_state.chat_history.append(
                        {"role": "assistant", "content": error_text}
                    )


def render_content_authoring_tab():
    """Render the Content Authoring tab.

    The left column collects the section type, an optional target protocol,
    an optional style guide and the debug toggle; the right column shows the
    generated content after "Generate Content" is pressed.
    """
    st.header("Content Authoring Assistant")
    st.info("Generate document sections based on knowledge extracted from similar documents.")

    col1, col2 = st.columns([1, 1])

    with col1:
        st.subheader("Content Generation Settings")

        # Section Type Selection
        section_types = [
            "Introduction",
            "Objectives and Endpoints",
            "Study Design",
            "Study Population",
            "Statistical Considerations",
            "Inclusion Criteria",
            "Exclusion Criteria",
            "Safety Assessments",
            "Pharmacokinetic Assessments"
        ]
        section_type = st.selectbox("Select Section Type", section_types)

        # Protocol Selection for Context (Optional). Options are
        # de-duplicated so two documents of the same protocol do not
        # produce repeated entries.
        target_protocol = st.selectbox(
            "Target Protocol ID (Optional)",
            _session_protocol_options("--None--")
        )
        target_protocol = None if target_protocol == "--None--" else target_protocol

        # Style Guide (Optional)
        style_guide = st.text_area(
            "Style Guide (Optional)",
            placeholder="Enter any specific style guidelines or content requirements..."
        )

        # Generate Button
        generate_button = st.button("Generate Content")

        # Debug toggle
        st.session_state.debug_mode = st.checkbox(
            "Show Context Sources",
            value=st.session_state.get("debug_mode", False)
        )

    with col2:
        st.subheader("Generated Content")

        if generate_button:
            with st.spinner("Generating content..."):
                result = generate_document_section(
                    section_type=section_type,
                    protocol_id=target_protocol,
                    style_guide=style_guide if style_guide else None
                )

                if result["status"] == "success":
                    st.markdown(result["content"])

                    # Show context sources if debug mode enabled
                    _render_context_sources(result.get("context", []))
                else:
                    st.error(f"Error: {result.get('error', 'Unknown error')}")


def render_traceability_tab():
    """Render the Document Traceability tab.

    Lets the user pick a source document, a target document and an entity
    type, then displays the traceability analysis (and, in debug mode, the
    matched entity pairs).
    """
    st.header("Cross-Document Traceability")
    st.info("Analyze relationships between related documents (e.g., Protocol and SAP).")

    col1, col2 = st.columns([1, 1])

    with col1:
        st.subheader("Traceability Analysis Settings")

        # Document Selection
        document_options = [
            {
                "id": doc.get("document_id", ""),
                # `or` so a stored None is shown as "Unknown", not "None".
                "label": f"{doc['filename']} ({doc.get('protocol_id') or 'Unknown'})"
            }
            for doc in st.session_state.documents
        ]

        # Source Document
        source_options = [{"id": "", "label": "--Select Source Document--"}] + document_options
        source_doc = st.selectbox(
            "Source Document",
            options=source_options,
            format_func=lambda x: x["label"]
        )

        # Target Document
        target_options = [{"id": "", "label": "--Select Target Document--"}] + document_options
        target_doc = st.selectbox(
            "Target Document",
            options=target_options,
            format_func=lambda x: x["label"]
        )

        # Entity Type
        entity_types = [
            {"id": "objectives", "label": "Study Objectives"},
            {"id": "endpoints", "label": "Endpoints"},
            {"id": "population", "label": "Population Criteria"}
        ]
        entity_type = st.selectbox(
            "Entity Type to Compare",
            options=entity_types,
            format_func=lambda x: x["label"]
        )

        # Analyze Button
        analyze_button = st.button("Analyze Traceability")

    with col2:
        st.subheader("Analysis Results")

        if analyze_button:
            if not source_doc["id"] or not target_doc["id"]:
                st.error("Please select both source and target documents.")
            else:
                with st.spinner("Analyzing traceability..."):
                    result = analyze_document_traceability(
                        source_id=source_doc["id"],
                        target_id=target_doc["id"],
                        entity_type=entity_type["id"]
                    )

                    if result["status"] == "success":
                        st.markdown(result["analysis"])

                        # Show matched pairs if debug mode enabled
                        if st.session_state.get("debug_mode", False) and result.get("matched_pairs"):
                            with st.expander("Matched Entity Pairs"):
                                for i, pair in enumerate(result["matched_pairs"]):
                                    st.write(f"**Pair {i+1}**")
                                    st.write(f"**Source:** {pair.get('source_text', 'Unknown')}")
                                    st.write(f"**Target:** {pair.get('target_text', 'Unknown')}")
                                    st.write("---")
                    else:
                        st.error(f"Error: {result.get('error', 'Unknown error')}")


def _load_explorer_data(entity_id, filter_protocol):
    """Fetch the explorer rows for *entity_id*, optionally for one protocol."""
    if entity_id == "studies":
        if filter_protocol:
            study = knowledge_store.get_study_by_protocol_id(filter_protocol)
            # Presumably the store returns None when no study matches
            # (verify against KnowledgeStore); avoid listing a phantom row.
            return [study] if study else []
        return knowledge_store.get_all_studies()

    if entity_id == "documents":
        if filter_protocol:
            return knowledge_store.get_documents_by_protocol_id(filter_protocol)
        return knowledge_store.get_all_documents()

    if entity_id in _PER_PROTOCOL_GETTERS:
        fetch = getattr(knowledge_store, _PER_PROTOCOL_GETTERS[entity_id])
        if filter_protocol:
            return fetch(filter_protocol)
        # No filter: collect the entity across every known protocol.
        return _gather_for_all_protocols(fetch)

    return []


def render_knowledge_explorer_tab():
    """Render the Knowledge Base Explorer tab.

    Lets the user choose an entity type, optionally filter by protocol ID
    and by a case-insensitive text search, and shows the matching rows as a
    table (falling back to JSON if a DataFrame cannot be built).
    """
    st.header("Knowledge Base Explorer")
    st.info("Explore the structured data extracted from documents in the knowledge base.")

    # Entity Type Selection
    entity_types = [
        {"id": "studies", "label": "Studies"},
        {"id": "objectives", "label": "Study Objectives"},
        {"id": "endpoints", "label": "Endpoints"},
        {"id": "population", "label": "Population Criteria"},
        {"id": "documents", "label": "Documents"}
    ]
    entity_type = st.selectbox(
        "Select Entity Type",
        options=entity_types,
        format_func=lambda x: x["label"]
    )

    # Filter by Protocol ID (Optional)
    filter_protocol = st.selectbox(
        "Filter by Protocol ID",
        _session_protocol_options("--All Protocols--")
    )
    filter_protocol = None if filter_protocol == "--All Protocols--" else filter_protocol

    # Search Query (Optional)
    search_query = st.text_input(
        "Search Query (Optional)",
        placeholder="Enter text to search for..."
    )

    # Display Results
    st.subheader("Results")

    try:
        # Retrieve data based on entity type
        data = _load_explorer_data(entity_type["id"], filter_protocol)

        # Filter by search query if provided
        if search_query:
            search_lower = search_query.lower()
            # default=str: the dump is only used for substring matching, so
            # non-JSON values are stringified rather than aborting the search.
            data = [
                item for item in data
                if search_lower in json.dumps(item, default=str).lower()
            ]

        # Display results
        if not data:
            st.info("No data found.")
        else:
            st.write(f"{len(data)} items found")

            # Display as table if possible, otherwise as JSON
            try:
                df = pd.DataFrame(data)
                st.dataframe(df, use_container_width=True)
            except Exception:
                st.json(data)

    except Exception as e:
        st.error(f"Error retrieving data: {e}")


# =========================================================================
# Main App
# =========================================================================

def main():
    """Main application function.

    Renders the title, the sidebar, the four feature tabs and the footer.
    """
    st.title("🧠 Pharmaceutical R&D Knowledge Ecosystem")

    # Initialize knowledge base stats on first load. This runs before the
    # sidebar is rendered so the metrics shown there are already populated.
    if st.session_state.knowledge_base_stats["documents"] == 0:
        update_knowledge_base_stats()

    # Render the sidebar for document management
    render_sidebar()

    # Main content tabs
    tab1, tab2, tab3, tab4 = st.tabs([
        "📝 Content Authoring",
        "🤖 Protocol Coach",
        "🔍 Knowledge Explorer",
        "🔄 Cross-Document Traceability"
    ])

    with tab1:
        render_content_authoring_tab()

    with tab2:
        render_protocol_coach_tab()

    with tab3:
        render_knowledge_explorer_tab()

    with tab4:
        render_traceability_tab()

    # Footer
    st.markdown("---")
    st.caption(
        "Pharmaceutical R&D Knowledge Ecosystem | "
        "A demonstration of AI-assisted document processing and knowledge management"
    )


if __name__ == "__main__":
    main()