Spaces:
Runtime error
Runtime error
| """ | |
| Streamlit application for the Pharmaceutical R&D Knowledge Ecosystem. | |
| """ | |
| import streamlit as st | |
| import os | |
| import pandas as pd | |
| import json | |
| import tempfile | |
| import time | |
| from datetime import datetime | |
| from pdf_processor import PDFProcessor | |
| from knowledge_store import KnowledgeStore | |
| from llm_interface import LLMInterface | |
| from graph_builder import ( | |
| init_handlers, | |
| build_document_extraction_graph, | |
| build_protocol_coach_graph, | |
| build_content_authoring_graph, | |
| build_traceability_graph | |
| ) | |
# =========================================================================
# App Setup and Configuration
# =========================================================================
# Page configuration (browser title, icon, layout). This is the first st.*
# call executed by the script.
# NOTE(review): the page_icon value looks like a mis-encoded emoji; restore
# the intended character.
st.set_page_config(
    page_title="Pharma R&D Knowledge Ecosystem",
    page_icon="π",
    layout="wide",
    initial_sidebar_state="expanded",
)


def _init_session_state():
    """Seed per-session defaults for keys not yet present in st.session_state.

    Keys that already exist (written on an earlier rerun of this session) are
    left untouched.
    """
    defaults = {
        "chat_history": [],
        "documents": [],
        "knowledge_base_stats": {
            "documents": 0,
            "studies": 0,
            "endpoints": 0,
            "objectives": 0,
            "vectors": 0,
        },
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


_init_session_state()
# Initialize our handlers and graphs
def initialize_app():
    """Create the handlers and build the LangGraph workflows.

    The Anthropic API key comes from the ``ANTHROPIC_API_KEY`` environment
    variable. If that is unset or empty, ``st.secrets["ANTHROPIC_API_KEY"]`` is
    tried. If neither provides a key, the environment value (None or "") is
    passed on unchanged. How ``init_handlers`` treats a missing key is not
    visible here.

    Returns:
        dict: ``pdf_processor``, ``knowledge_store`` and ``llm_interface`` (as
        returned by ``init_handlers``), plus ``extraction_graph``,
        ``coach_graph``, ``authoring_graph`` and ``traceability_graph``.
    """
    api_key = os.environ.get("ANTHROPIC_API_KEY")
    if not api_key:
        try:
            api_key = st.secrets.get("ANTHROPIC_API_KEY", api_key)
        except Exception:
            # Reading st.secrets raises (rather than reporting "missing") when
            # no secrets file exists, and the exception class differs between
            # Streamlit versions. Treat any failure as "no secret configured"
            # instead of crashing the whole app.
            pass

    pdf_processor, knowledge_store, llm_interface = init_handlers(api_key)

    # NOTE(review): the graph builders take no arguments, so they presumably
    # rely on handlers registered by init_handlers() above; keep this order.
    extraction_graph = build_document_extraction_graph()
    coach_graph = build_protocol_coach_graph()
    authoring_graph = build_content_authoring_graph()
    traceability_graph = build_traceability_graph()

    # NOTE(review): this runs on every Streamlit rerun. Consider
    # st.cache_resource once the handlers are confirmed safe to share across
    # script-runner threads (e.g. no thread-bound DB connections).
    return {
        "pdf_processor": pdf_processor,
        "knowledge_store": knowledge_store,
        "llm_interface": llm_interface,
        "extraction_graph": extraction_graph,
        "coach_graph": coach_graph,
        "authoring_graph": authoring_graph,
        "traceability_graph": traceability_graph,
    }
# Build the shared resources, then bind each one to the module-level name
# used by the helpers and tab renderers below.
app_resources = initialize_app()
(
    pdf_processor,
    knowledge_store,
    llm_interface,
    extraction_graph,
    coach_graph,
    authoring_graph,
    traceability_graph,
) = (
    app_resources[resource_name]
    for resource_name in (
        "pdf_processor",
        "knowledge_store",
        "llm_interface",
        "extraction_graph",
        "coach_graph",
        "authoring_graph",
        "traceability_graph",
    )
)
# =========================================================================
# Helper Functions
# =========================================================================
def update_knowledge_base_stats():
    """Refresh ``st.session_state.knowledge_base_stats`` from the knowledge store.

    Stores:
    - the document count;
    - the number of distinct non-empty protocol IDs ("studies");
    - objective and endpoint counts summed over those protocols;
    - the vector store's ``document_count`` (0 when absent).

    Any exception is reported with ``st.error`` and the previous stats are kept.
    """
    try:
        all_documents = knowledge_store.get_all_documents()
        document_total = len(all_documents)

        # Distinct, non-empty protocol IDs across all documents.
        study_ids = set()
        for record in all_documents:
            if "protocol_id" in record and record["protocol_id"]:
                study_ids.add(record["protocol_id"])

        vector_total = knowledge_store.get_vector_store_stats().get("document_count", 0)

        objective_total, endpoint_total = 0, 0
        for study_id in study_ids:
            objectives = knowledge_store.get_objectives_by_protocol_id(study_id)
            endpoints = knowledge_store.get_endpoints_by_protocol_id(study_id)
            objective_total += len(objectives)
            endpoint_total += len(endpoints)

        st.session_state.knowledge_base_stats = {
            "documents": document_total,
            "studies": len(study_ids),
            "objectives": objective_total,
            "endpoints": endpoint_total,
            "vectors": vector_total,
        }
    except Exception as e:
        st.error(f"Error updating knowledge base stats: {e}")
def process_document(uploaded_file):
    """Run an uploaded PDF through the extraction workflow and store the results.

    The upload is written to a temporary ``.pdf`` file and passed to
    ``extraction_graph`` as ``document_path``; the temp file is removed
    afterwards (best effort). A progress bar and status line track each step.
    On success the knowledge-base stats in session state are refreshed.

    Args:
        uploaded_file: A Streamlit uploaded file (``.name`` and
            ``.getbuffer()`` are used).

    Returns:
        dict: ``status`` ("success" or "error") and ``filename``. On success it
        also has ``document_id`` and ``protocol_id`` (either may be None). On
        error it also has ``error``.
    """
    # Set as soon as the temp file exists so `finally` can always remove it.
    file_path = None
    try:
        progress_bar = st.progress(0)
        status_text = st.empty()

        # Step 1: persist the upload so the extraction workflow can read it from disk.
        status_text.text("Saving uploaded file...")
        progress_bar.progress(10)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
            # Record the path before writing, so a failed write is still cleaned up.
            file_path = tmp_file.name
            tmp_file.write(uploaded_file.getbuffer())

        # Step 2: run the LangGraph extraction workflow.
        status_text.text("Parsing document...")
        progress_bar.progress(20)
        initial_state = {
            "document_path": file_path,
            "status": "initialized",
        }
        result_state = extraction_graph.invoke(initial_state)

        if result_state.get("status") == "error":
            # `or` so a present-but-empty error value does not display as "None".
            error = result_state.get("error") or "Unknown error"
            status_text.text(f"Error: {error}")
            progress_bar.progress(100)
            return {
                "status": "error",
                "error": error,
                "filename": uploaded_file.name,
            }

        status_text.text("Processing completed successfully!")
        progress_bar.progress(100)

        update_knowledge_base_stats()

        # `extracted_study` may be absent or explicitly None. Neither case should
        # turn a successful extraction into an error via AttributeError.
        extracted_study = result_state.get("extracted_study") or {}
        return {
            "status": "success",
            "filename": uploaded_file.name,
            "document_id": result_state.get("document_id"),
            "protocol_id": extracted_study.get("protocol_id"),
        }
    except Exception as e:
        st.error(f"Error processing document: {e}")
        return {
            "status": "error",
            "error": str(e),
            "filename": uploaded_file.name,
        }
    finally:
        if file_path is not None:
            try:
                os.unlink(file_path)
            except OSError:
                # Best-effort cleanup: failing to delete the temp file must not
                # mask the processing result.
                pass
def chat_with_protocol_coach(query):
    """Answer *query* with the Protocol Coach LangGraph workflow.

    Args:
        query: The user's question.

    Returns:
        dict: ``{"status": "success", "response": str, "context": list}``, or
        ``{"status": "error", "error": str}`` if the workflow raised.
    """
    try:
        initial_state = {
            "query": query,
            # Pass a copy, so the workflow cannot mutate the session's history
            # list in place.
            # NOTE(review): render_protocol_coach_tab appends the current query
            # to chat_history before calling this, so the query is also the
            # last history entry; confirm the coach graph expects that.
            "chat_history": list(st.session_state.chat_history),
        }
        result_state = coach_graph.invoke(initial_state)

        # `or` rather than a .get default: a key that is present but None or
        # empty must also fall back, otherwise "None" is rendered and stored
        # in the chat history.
        return {
            "status": "success",
            "response": result_state.get("response") or "I couldn't generate a response.",
            "context": result_state.get("retrieved_context") or [],
        }
    except Exception as e:
        return {
            "status": "error",
            "error": str(e),
        }
def generate_document_section(section_type, protocol_id=None, style_guide=None):
    """Draft a document section with the Content Authoring LangGraph workflow.

    Args:
        section_type: Name of the section to generate (e.g. "Study Design").
        protocol_id: Optional protocol ID, passed to the workflow as
            ``target_protocol_id``.
        style_guide: Optional free-text style guidance.

    Returns:
        dict: ``{"status": "success", "content": str, "context": list}``, or
        ``{"status": "error", "error": str}`` if the workflow raised.
    """
    try:
        initial_state = {
            "section_type": section_type,
            "target_protocol_id": protocol_id,
            "style_guide": style_guide,
        }
        result_state = authoring_graph.invoke(initial_state)

        # `or` rather than a .get default: a key that is present but None or
        # empty also falls back, instead of rendering "None" or a blank section.
        return {
            "status": "success",
            "content": result_state.get("generated_content") or "I couldn't generate the content.",
            "context": result_state.get("retrieved_context") or [],
        }
    except Exception as e:
        return {
            "status": "error",
            "error": str(e),
        }
def analyze_document_traceability(source_id, target_id, entity_type):
    """Compare one entity type between two documents with the Traceability workflow.

    Args:
        source_id: ``document_id`` of the source document.
        target_id: ``document_id`` of the target document.
        entity_type: Entity type to compare ("objectives", "endpoints" or
            "population" in this app's UI).

    Returns:
        dict: ``{"status": "success", "analysis": str, "matched_pairs": list}``,
        or ``{"status": "error", "error": str}`` if the workflow raised.
    """
    try:
        initial_state = {
            "source_document_id": source_id,
            "target_document_id": target_id,
            "entity_type": entity_type,
        }
        result_state = traceability_graph.invoke(initial_state)

        # `or` rather than a .get default: a key that is present but None or
        # empty also falls back, so callers always get a string and a list.
        return {
            "status": "success",
            "analysis": result_state.get("analysis") or "I couldn't perform the analysis.",
            "matched_pairs": result_state.get("matched_pairs") or [],
        }
    except Exception as e:
        return {
            "status": "error",
            "error": str(e),
        }
# =========================================================================
# Sidebar: Document Upload and Management
# =========================================================================
def render_sidebar():
    """Render the sidebar: knowledge-base stats, PDF upload/processing, document list.

    Successfully processed files are appended to ``st.session_state.documents``,
    de-duplicated by filename.
    """
    st.sidebar.title("Document Management")

    # Knowledge Base Stats. The metrics are drawn into this reserved slot at
    # the end of the function. That way, counts changed during this run (by
    # "Process Documents" or "Refresh Stats") show immediately rather than on
    # the next rerun.
    st.sidebar.subheader("Knowledge Base Stats")
    stats_slot = st.sidebar.container()

    # Document Upload
    st.sidebar.subheader("Upload Documents")
    uploaded_files = st.sidebar.file_uploader(
        "Upload Protocol/SAP PDFs",
        type="pdf",
        accept_multiple_files=True,
        help="Upload clinical documents (Protocol, SAP, etc.) to add to the knowledge base.",
    )

    # The button is only offered once at least one file has been uploaded.
    if uploaded_files and st.sidebar.button("Process Documents"):
        known_filenames = {doc.get("filename") for doc in st.session_state.documents}
        with st.sidebar.expander("Processing Results", expanded=True):
            for uploaded_file in uploaded_files:
                st.write(f"Processing: {uploaded_file.name}")
                result = process_document(uploaded_file)

                if result["status"] != "success":
                    st.error(f"Error processing {result['filename']}: {result.get('error', 'Unknown error')}")
                    continue

                st.success(f"Successfully processed {result['filename']}")
                # Add to the documents list only if this filename is new.
                if result["filename"] not in known_filenames:
                    known_filenames.add(result["filename"])
                    st.session_state.documents.append({
                        "filename": result["filename"],
                        "document_id": result.get("document_id"),
                        "protocol_id": result.get("protocol_id"),
                        "processed_date": datetime.now().strftime("%Y-%m-%d %H:%M"),
                    })

    # Document list
    st.sidebar.subheader("Processed Documents")
    if not st.session_state.documents:
        st.sidebar.info("No documents processed yet.")
    else:
        for doc in st.session_state.documents:
            with st.sidebar.expander(f"{doc['filename']}"):
                # protocol_id is stored even when None, so fall back on any
                # falsy value rather than only on a missing key.
                st.write(f"**Protocol ID:** {doc.get('protocol_id') or 'Unknown'}")
                st.write(f"**Processed:** {doc.get('processed_date') or 'Unknown'}")

    # Refresh Stats Button
    if st.sidebar.button("Refresh Stats"):
        update_knowledge_base_stats()
        st.sidebar.success("Stats refreshed!")

    # Fill the reserved stats slot with the now-current values.
    stats = st.session_state.knowledge_base_stats
    col1, col2 = stats_slot.columns(2)
    col1.metric("Documents", stats["documents"])
    col2.metric("Studies", stats["studies"])
    col1.metric("Objectives", stats["objectives"])
    col2.metric("Endpoints", stats["endpoints"])
    stats_slot.metric("Vector Chunks", stats["vectors"])
# =========================================================================
# Main Content Tabs
# =========================================================================
def render_protocol_coach_tab():
    """Render the Protocol Coach chat tab.

    Replays ``st.session_state.chat_history``. When the user submits a
    question, it is echoed and sent to the Protocol Coach workflow, and both
    turns are appended to the history. Retrieved sources are listed in an
    expander when ``st.session_state.debug_mode`` is truthy.
    """
    st.header("Protocol Coach Chatbot")
    st.info("Ask questions about the protocol documents in the knowledge base. The Protocol Coach will retrieve relevant information to answer your questions.")

    # Replay the conversation so far.
    for turn in st.session_state.chat_history:
        with st.chat_message(turn["role"]):
            st.markdown(turn["content"])

    query = st.chat_input("Ask about protocols...")
    if not query:
        return

    # Record and echo the user's turn before calling the workflow.
    st.session_state.chat_history.append({"role": "user", "content": query})
    with st.chat_message("user"):
        st.markdown(query)

    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            result = chat_with_protocol_coach(query)

        if result["status"] != "success":
            error_text = f"Error: {result.get('error', 'Unknown error')}"
            st.error(error_text)
            st.session_state.chat_history.append({"role": "assistant", "content": error_text})
            return

        answer = result["response"]
        st.markdown(answer)

        if st.session_state.get("debug_mode", False):
            with st.expander("Context Sources"):
                for number, source in enumerate(result.get("context", []), start=1):
                    meta = source.get("metadata", {})
                    st.write(f"**Source {number}:** {meta.get('source', 'Unknown')}")
                    st.write(f"**Section:** {meta.get('section', 'Unknown')}")
                    st.write("---")

        st.session_state.chat_history.append({"role": "assistant", "content": answer})
def render_content_authoring_tab():
    """Render the Content Authoring tab.

    Left column: generation settings (section type, optional target protocol
    and style guide), the Generate button, and the "Show Context Sources"
    toggle stored in ``st.session_state.debug_mode``.

    Right column: the most recent generation result. It is kept in
    ``st.session_state["authoring_result"]`` so it survives reruns, such as
    toggling "Show Context Sources" in this same tab.
    """
    st.header("Content Authoring Assistant")
    st.info("Generate document sections based on knowledge extracted from similar documents.")

    col1, col2 = st.columns([1, 1])

    with col1:
        st.subheader("Content Generation Settings")

        # Section Type Selection
        section_types = [
            "Introduction",
            "Objectives and Endpoints",
            "Study Design",
            "Study Population",
            "Statistical Considerations",
            "Inclusion Criteria",
            "Exclusion Criteria",
            "Safety Assessments",
            "Pharmacokinetic Assessments",
        ]
        section_type = st.selectbox("Select Section Type", section_types)

        # Protocol Selection for Context (Optional). Several documents (e.g. a
        # Protocol and its SAP) can share a protocol ID, so list each ID once,
        # in first-seen order.
        protocol_options = ["--None--"]
        for doc in st.session_state.documents:
            protocol_id = doc.get("protocol_id")
            if protocol_id and protocol_id not in protocol_options:
                protocol_options.append(protocol_id)
        target_protocol = st.selectbox(
            "Target Protocol ID (Optional)",
            protocol_options,
        )
        target_protocol = None if target_protocol == "--None--" else target_protocol

        # Style Guide (Optional)
        style_guide = st.text_area(
            "Style Guide (Optional)",
            placeholder="Enter any specific style guidelines or content requirements...",
        )

        generate_button = st.button("Generate Content")

        # Debug toggle, shared with the other tabs via session state.
        st.session_state.debug_mode = st.checkbox("Show Context Sources", value=st.session_state.get("debug_mode", False))

    with col2:
        st.subheader("Generated Content")

        if generate_button:
            with st.spinner("Generating content..."):
                result = generate_document_section(
                    section_type=section_type,
                    protocol_id=target_protocol,
                    style_guide=style_guide if style_guide else None,
                )
            # Persist the result so later reruns keep showing it.
            st.session_state["authoring_result"] = {
                "section_type": section_type,
                "result": result,
            }

        last = st.session_state.get("authoring_result")
        if last is None:
            return

        result = last["result"]
        if result["status"] != "success":
            st.error(f"Error: {result.get('error', 'Unknown error')}")
            return

        # The settings may have changed since generation; label what is shown.
        st.caption(f"Section: {last['section_type']}")
        st.markdown(result["content"])

        # Show context sources if debug mode enabled
        if st.session_state.get("debug_mode", False):
            with st.expander("Context Sources"):
                for i, ctx in enumerate(result.get("context", [])):
                    st.write(f"**Source {i+1}:** {ctx.get('metadata', {}).get('source', 'Unknown')}")
                    st.write(f"**Section:** {ctx.get('metadata', {}).get('section', 'Unknown')}")
                    st.write("---")
def render_traceability_tab():
    """Render the Cross-Document Traceability tab.

    The user picks a source and a target document from
    ``st.session_state.documents`` plus an entity type. The tab then runs
    ``analyze_document_traceability`` and shows the analysis. When
    ``st.session_state.debug_mode`` is truthy, the matched entity pairs are
    shown as well.
    """
    st.header("Cross-Document Traceability")
    st.info("Analyze relationships between related documents (e.g., Protocol and SAP).")

    def option_label(option):
        """Display text for a selectbox option dict."""
        return option["label"]

    col1, col2 = st.columns([1, 1])

    with col1:
        st.subheader("Traceability Analysis Settings")

        # Only documents with a stored document_id can be analysed. An entry
        # without one would look selectable but be rejected like the
        # placeholder, with a misleading "please select" error.
        document_options = [
            {
                "id": doc["document_id"],
                # protocol_id is stored even when None, so fall back on any falsy value.
                "label": f"{doc['filename']} ({doc.get('protocol_id') or 'Unknown'})",
            }
            for doc in st.session_state.documents
            if doc.get("document_id")
        ]

        source_doc = st.selectbox(
            "Source Document",
            options=[{"id": "", "label": "--Select Source Document--"}] + document_options,
            format_func=option_label,
        )
        target_doc = st.selectbox(
            "Target Document",
            options=[{"id": "", "label": "--Select Target Document--"}] + document_options,
            format_func=option_label,
        )

        entity_types = [
            {"id": "objectives", "label": "Study Objectives"},
            {"id": "endpoints", "label": "Endpoints"},
            {"id": "population", "label": "Population Criteria"},
        ]
        entity_type = st.selectbox(
            "Entity Type to Compare",
            options=entity_types,
            format_func=option_label,
        )

        analyze_button = st.button("Analyze Traceability")

    with col2:
        st.subheader("Analysis Results")
        if not analyze_button:
            return
        if not source_doc["id"] or not target_doc["id"]:
            st.error("Please select both source and target documents.")
            return
        if source_doc["id"] == target_doc["id"]:
            st.error("Please select two different documents.")
            return

        with st.spinner("Analyzing traceability..."):
            result = analyze_document_traceability(
                source_id=source_doc["id"],
                target_id=target_doc["id"],
                entity_type=entity_type["id"],
            )

        if result["status"] != "success":
            st.error(f"Error: {result.get('error', 'Unknown error')}")
            return

        st.markdown(result["analysis"])

        # Show matched pairs if debug mode enabled
        if st.session_state.get("debug_mode", False) and result.get("matched_pairs"):
            with st.expander("Matched Entity Pairs"):
                for i, pair in enumerate(result["matched_pairs"]):
                    st.write(f"**Pair {i+1}**")
                    st.write(f"**Source:** {pair.get('source_text', 'Unknown')}")
                    st.write(f"**Target:** {pair.get('target_text', 'Unknown')}")
                    st.write("---")
def _kb_protocol_ids():
    """Return the distinct non-empty protocol IDs of all stored documents, sorted for stable display."""
    protocol_ids = {
        doc["protocol_id"]
        for doc in knowledge_store.get_all_documents()
        if "protocol_id" in doc and doc["protocol_id"]
    }
    return sorted(protocol_ids, key=str)


def _fetch_explorer_records(entity_id, protocol_id):
    """Return records of type *entity_id* as a list, restricted to *protocol_id* when given."""
    if entity_id == "studies":
        if protocol_id:
            study = knowledge_store.get_study_by_protocol_id(protocol_id)
            # A missing study must yield "No data found", not one bogus row.
            return [] if study is None else [study]
        return list(knowledge_store.get_all_studies() or [])

    if entity_id == "documents":
        if protocol_id:
            return list(knowledge_store.get_documents_by_protocol_id(protocol_id) or [])
        return list(knowledge_store.get_all_documents() or [])

    # Entity types the store only exposes per protocol. Unfiltered views are
    # assembled by querying every protocol ID present in the store.
    per_protocol_getters = {
        "objectives": "get_objectives_by_protocol_id",
        "endpoints": "get_endpoints_by_protocol_id",
        "population": "get_population_criteria_by_protocol_id",
    }
    getter_name = per_protocol_getters.get(entity_id)
    if getter_name is None:
        return []
    getter = getattr(knowledge_store, getter_name)

    records = []
    for pid in ([protocol_id] if protocol_id else _kb_protocol_ids()):
        # Guard against a getter returning None for a protocol with no records.
        records.extend(getter(pid) or [])
    return records


def render_knowledge_explorer_tab():
    """Render the Knowledge Base Explorer tab.

    The user picks an entity type, optionally a protocol ID (offered from
    ``st.session_state.documents``) and a free-text search. Matching records
    are shown as a table, or as JSON if pandas cannot tabulate them.
    """
    st.header("Knowledge Base Explorer")
    st.info("Explore the structured data extracted from documents in the knowledge base.")

    # Entity Type Selection
    entity_types = [
        {"id": "studies", "label": "Studies"},
        {"id": "objectives", "label": "Study Objectives"},
        {"id": "endpoints", "label": "Endpoints"},
        {"id": "population", "label": "Population Criteria"},
        {"id": "documents", "label": "Documents"},
    ]
    entity_type = st.selectbox(
        "Select Entity Type",
        options=entity_types,
        format_func=lambda x: x["label"],
    )

    # Filter by Protocol ID (Optional); each ID listed once, in first-seen order.
    protocol_options = ["--All Protocols--"]
    for doc in st.session_state.documents:
        if doc.get("protocol_id") and doc.get("protocol_id") not in protocol_options:
            protocol_options.append(doc.get("protocol_id"))
    filter_protocol = st.selectbox(
        "Filter by Protocol ID",
        protocol_options,
    )
    filter_protocol = None if filter_protocol == "--All Protocols--" else filter_protocol

    # Search Query (Optional)
    search_query = st.text_input(
        "Search Query (Optional)",
        placeholder="Enter text to search for...",
    )

    st.subheader("Results")
    try:
        data = _fetch_explorer_records(entity_type["id"], filter_protocol)

        if search_query:
            needle = search_query.lower()
            # default=str keeps records holding non-JSON values (e.g. datetimes)
            # searchable, instead of failing the whole tab.
            data = [item for item in data if needle in json.dumps(item, default=str).lower()]

        if not data:
            st.info("No data found.")
            return

        st.write(f"{len(data)} items found")
        try:
            st.dataframe(pd.DataFrame(data), use_container_width=True)
        except Exception:
            # Records pandas cannot tabulate are shown as raw JSON instead.
            st.json(data)
    except Exception as e:
        st.error(f"Error retrieving data: {e}")
# =========================================================================
# Main App
# =========================================================================
def main():
    """Render the app: title, sidebar, the four feature tabs and a footer."""
    # NOTE(review): the emoji prefixes in the title and tab labels look
    # mis-encoded (e.g. "π§"); restore the intended characters.
    st.title("π§ Pharmaceutical R&D Knowledge Ecosystem")

    # Load stats before the sidebar draws them. Loading them after
    # render_sidebar() left the sidebar showing zeros on first load until the
    # next rerun. While the store reports no documents, this re-queries on
    # every rerun.
    if st.session_state.knowledge_base_stats["documents"] == 0:
        update_knowledge_base_stats()

    # Render the sidebar for document management
    render_sidebar()

    # Main content tabs
    tab1, tab2, tab3, tab4 = st.tabs([
        "π Content Authoring",
        "π€ Protocol Coach",
        "π Knowledge Explorer",
        "π Cross-Document Traceability",
    ])
    with tab1:
        render_content_authoring_tab()
    with tab2:
        render_protocol_coach_tab()
    with tab3:
        render_knowledge_explorer_tab()
    with tab4:
        render_traceability_tab()

    # Footer
    st.markdown("---")
    st.caption("Pharmaceutical R&D Knowledge Ecosystem | A demonstration of AI-assisted document processing and knowledge management")


if __name__ == "__main__":
    main()