Spaces:
Runtime error
Runtime error
| """ | |
| Streamlit App for Hypertension Cost-Effectiveness Analysis with LangChain | |
| Enhanced with AI Assistant and Personalized Chat using OpenAI | |
| Added: CSV upload and patient ID retrieval feature | |
| """ | |
| import streamlit as st | |
| import numpy as np | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| from typing import Dict, List, Any | |
| import uuid | |
| import os | |
| import shutil | |
| from datetime import datetime, timedelta | |
| import hashlib | |
| # LangChain imports | |
| from langchain_openai import ChatOpenAI, OpenAIEmbeddings | |
| from langchain_community.vectorstores import Chroma | |
| from langchain_core.documents import Document | |
| # Import the hypertension model functions | |
| from hypertension_model_fixed2 import ( | |
| run_analysis, beta_men, beta_women, | |
| hazards_from_beta, Q_matrix, discrete_P, | |
| run_markov, icer | |
| ) | |
# Page configuration
st.set_page_config(
    page_title="Hypertension CEA Tool with AI",
    page_icon="❤️",
    layout="wide",
    initial_sidebar_state="expanded",
)

# ===== Per-user isolation =====
# The first run of a session gets a fresh random ID and a start timestamp.
if 'user_id' not in st.session_state:
    st.session_state.user_id = str(uuid.uuid4())
    st.session_state.session_start = datetime.now()

# Session expiry: halt the script once the session is more than 2 hours old.
if 'session_start' in st.session_state:
    session_duration = datetime.now() - st.session_state.session_start
    if session_duration > timedelta(hours=2):
        st.warning("⏱️ Your session has expired (2 hours). Please refresh the page.")
        st.stop()

# Seed every session-state slot the app reads later. The mapping is rebuilt on
# each script run, so the list values are never shared between sessions.
for state_key, initial_value in {
    'assistant_messages': [],        # Tab 1 chat history
    'recommendation_messages': [],
    'summary_generated': False,
    'patients_df': None,             # uploaded patient table
    'vectorstore': None,
    'cea_results': None,             # read by get_cea_results_string()
    'uploaded_files': [],            # hashes of uploaded files
    'api_call_count': 0,             # checked by get_llm()
}.items():
    if state_key not in st.session_state:
        st.session_state[state_key] = initial_value
# On startup, remove per-user vector-store folders older than 24 hours.
def cleanup_old_data(data_dir="./data/chroma", max_age_hours=24):
    """Delete stale sub-folders of the vector-store directory.

    A sub-folder is stale when its modification time is more than
    ``max_age_hours`` in the past. Cleanup is best effort: a folder that
    cannot be inspected or removed is skipped, so one bad folder no longer
    stops the remaining folders from being processed.

    Args:
        data_dir: Directory whose immediate sub-folders are examined.
        max_age_hours: Age in hours beyond which a sub-folder is deleted.

    Returns:
        The number of sub-folders that were removed.
    """
    if not os.path.isdir(data_dir):
        return 0
    cutoff_time = datetime.now() - timedelta(hours=max_age_hours)
    try:
        entries = os.listdir(data_dir)
    except OSError:
        # Directory vanished or is unreadable; nothing can be cleaned.
        return 0
    removed = 0
    for user_folder in entries:
        folder_path = os.path.join(data_dir, user_folder)
        try:
            if not os.path.isdir(folder_path):
                continue
            mod_time = datetime.fromtimestamp(os.path.getmtime(folder_path))
            if mod_time < cutoff_time:
                shutil.rmtree(folder_path)
                removed += 1
        except OSError:
            # Stay silent towards the user, but keep going with other folders.
            continue
    return removed


cleanup_old_data()
# Page header: title on the left, OpenAI API key field on the right.
col1, col2 = st.columns([3, 1])

with col1:
    st.title("❤️ Personalized LLM for Planning Cost Effectiveness Analysis")
    st.markdown("*Multistate ML Analysis Showcase Hypertension*")

with col2:
    # The key is entered as a masked field; AI features check it before use.
    openai_api_key = st.text_input(
        "🔑 OpenAI API Key",
        type="password",
        placeholder="sk-...",
        help="Enter your OpenAI API key to enable AI features",
    )
    if not openai_api_key:
        st.warning("⚠️ Enter API key")
    else:
        st.success("✓ API Key set")
# Build the chat model once an API key has been provided.
def get_llm():
    """Return a ``ChatOpenAI`` client, or ``None`` when it cannot be created.

    Returns ``None`` when no API key was entered or when constructing the
    client raises. Every successful call increments the per-session counter;
    once the counter has reached 100 an error is shown and the script run is
    stopped.
    """
    if not openai_api_key:
        return None

    # Per-session quota: at most 100 successful calls to this function.
    quota_exhausted = st.session_state.api_call_count >= 100
    if quota_exhausted:
        st.error("⚠️ API call limit reached (100 calls per session). Please start a new session.")
        st.stop()

    try:
        chat_model = ChatOpenAI(
            model="gpt-4o-mini",
            temperature=0.7,
            openai_api_key=openai_api_key,
        )
        # Record the call only after the client was built successfully.
        st.session_state.api_call_count += 1
    except Exception as e:
        st.error(f"Error initializing OpenAI: {str(e)}")
        return None
    return chat_model
# Create vector store from patient data
def create_patient_vectorstore(patients_df: pd.DataFrame):
    """Build a per-user Chroma vector store from the patient table.

    One document is created per row. The store is persisted under
    ``./data/chroma/<user_id>`` in a collection named after the session's
    ``user_id``; any earlier contents of that collection are cleared first.

    Args:
        patients_df: Patient table. Each row must provide the columns read
            below (``patient_id``, ``sex``, ``age``, ``education``, ``bmi``,
            ``waist``, ``fasting_glucose``, ``total_cholesterol``,
            ``uric_acid``, ``smoking``, ``alcohol``, ``exercise``,
            ``family_history``); ``betel`` is optional and defaults to 'No'.

    Returns:
        The populated vector store, or ``None`` when no API key is set or
        when building the store fails (the error is shown in the UI).
    """
    if not openai_api_key:
        return None
    try:
        # Per-user directory and collection keep sessions isolated.
        user_id = st.session_state.user_id
        persist_dir = f"./data/chroma/{user_id}"
        os.makedirs(persist_dir, exist_ok=True)
        collection_name = f"user_{user_id}_patients"

        documents = []
        for _, row in patients_df.iterrows():
            patient_text = f"""Patient ID: {row['patient_id']}
Sex: {row['sex']}, Age: {row['age']}, Education: {row['education']}
BMI: {row['bmi']} kg/m², Waist: {row['waist']} cm
Fasting Glucose: {row['fasting_glucose']} mg/dL
Total Cholesterol: {row['total_cholesterol']} mg/dL
Uric Acid: {row['uric_acid']} mg/dL
Smoking: {row['smoking']}, Alcohol: {row['alcohol']}, Exercise: {row['exercise']}
Betel: {row.get('betel', 'No')}, Family History: {row['family_history']}"""
            documents.append(
                Document(
                    page_content=patient_text,
                    metadata={
                        "patient_id": row['patient_id'],
                        "user_id": user_id,  # tags each document with its owner
                    },
                )
            )

        embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)

        def open_collection():
            """Open (or create) this user's collection in the persist dir."""
            return Chroma(
                collection_name=collection_name,
                embedding_function=embeddings,
                persist_directory=persist_dir,
            )

        vectorstore = open_collection()
        # Clear data left from an earlier upload. A failed reset is reported
        # instead of silently ignored. Re-opening happens outside this inner
        # handler so its failure reaches the outer handler rather than
        # leaving a handle to a deleted collection.
        try:
            vectorstore.delete_collection()
        except Exception as reset_error:
            st.warning(f"Could not clear the previous vector store: {reset_error}")
        else:
            vectorstore = open_collection()

        vectorstore.add_documents(documents)
        return vectorstore
    except Exception as e:
        st.error(f"Error creating vector store: {str(e)}")
        return None
# Retrieve patient by ID
def retrieve_patient_by_id(patient_id: str, patients_df=None):
    """Return the first patient row matching ``patient_id`` as a dict.

    IDs are compared as stripped strings, so an ID typed into a text box
    (always a ``str``) also matches a numeric ``patient_id`` column, and
    surrounding whitespace on either side is ignored.

    Args:
        patient_id: The ID to look up.
        patients_df: Table to search. When omitted, the table stored in
            ``st.session_state.patients_df`` is used.

    Returns:
        A ``{column: value}`` dict for the first matching row, or ``None``
        when there is no table, no ``patient_id`` column, or no match.
    """
    if patients_df is None:
        patients_df = st.session_state.patients_df
    if patients_df is None or patient_id is None:
        return None
    if 'patient_id' not in patients_df.columns:
        return None

    wanted = str(patient_id).strip()
    known_ids = patients_df['patient_id'].astype(str).str.strip()
    patient_row = patients_df[known_ids == wanted]
    if patient_row.empty:
        return None
    return patient_row.iloc[0].to_dict()
# Sidebar for patient information
st.sidebar.header("👤 Patient Information")

# ===== Patient data upload section =====
with st.sidebar.expander("📂 Upload Patient Database (Optional)", expanded=False):
    uploaded_file = st.file_uploader(
        "Upload CSV/Excel with patient data",
        type=['csv', 'xlsx'],
        help="Upload a file with multiple patients (Max 10MB)"
    )
    if uploaded_file is not None:
        try:
            # Reject files above the 10 MB limit before parsing them.
            file_size_mb = uploaded_file.size / (1024 * 1024)
            if file_size_mb > 10:
                st.error(f"❌ File too large ({file_size_mb:.1f}MB). Max 10MB.")
                st.stop()

            # Compare the extension case-insensitively so "DATA.CSV" is read
            # as CSV instead of being handed to the Excel reader.
            if uploaded_file.name.lower().endswith('.csv'):
                df = pd.read_csv(uploaded_file)
            else:
                df = pd.read_excel(uploaded_file)
            st.session_state.patients_df = df

            # Remember each distinct upload by the MD5 digest of its bytes.
            file_hash = hashlib.md5(uploaded_file.getvalue()).hexdigest()
            if file_hash not in st.session_state.uploaded_files:
                st.session_state.uploaded_files.append(file_hash)
            st.success(f"✅ Loaded {len(df)} patients ({file_size_mb:.1f}MB)")

            # Optionally create vector store
            if openai_api_key and st.button("🔄 Create Vector Store for Smart Search"):
                with st.spinner("Creating isolated vector store..."):
                    vectorstore = create_patient_vectorstore(df)
                    if vectorstore:
                        st.session_state.vectorstore = vectorstore
                        st.success("✅ Vector store created! (Isolated to your session)")
        except Exception as e:
            st.error(f"Error loading file: {str(e)}")

    # Patient ID retrieval (only offered once a table has been loaded)
    if st.session_state.patients_df is not None:
        st.markdown("---")
        # Strip stray whitespace so a padded ID is not treated as a different
        # ID, and so a blank-only entry counts as empty.
        patient_id_input = st.text_input("🔍 Enter Patient ID", placeholder="P001").strip()
        if st.button("📥 Load Patient Data"):
            if patient_id_input:
                patient_data = retrieve_patient_by_id(patient_id_input)
                if patient_data:
                    st.session_state.loaded_patient = patient_data
                    st.success(f"✅ Loaded {patient_id_input}")
                    st.rerun()
                else:
                    st.error(f"Patient {patient_id_input} not found")

st.sidebar.markdown("---")
def _bounded_value(record, field, low, high, fallback):
    """Read a numeric field from a loaded patient, kept inside a widget's range.

    The value is converted to the type of ``low`` (int or float). Missing,
    non-numeric, NaN or infinite values yield ``fallback``; anything else is
    clamped to ``[low, high]`` so it is always a legal ``number_input`` value.
    """
    caster = type(low)
    try:
        value = caster(record.get(field))
    except (TypeError, ValueError, OverflowError):
        return fallback
    if value != value:  # NaN is the only value unequal to itself
        return fallback
    return min(max(value, low), high)


# Check if we have loaded patient data
if 'loaded_patient' in st.session_state:
    # Pre-fill every widget from the loaded patient record.
    # Numeric fields go through _bounded_value: a raw value outside the
    # widget's min/max (or NaN) would otherwise abort the script. Clamping
    # cannot flip any risk flag, because every threshold used by
    # create_feature_dict lies strictly inside the widget ranges.
    p = st.session_state.loaded_patient
    st.sidebar.info(f"📋 Loaded: {p['patient_id']}")
    sex = st.sidebar.radio("Biological Sex", ["Male", "Female"],
                           index=0 if p['sex'] == 'Male' else 1, key="sex_loaded")

    st.sidebar.markdown("### Demographics")
    age = st.sidebar.number_input("Age", min_value=18, max_value=100,
                                  value=_bounded_value(p, 'age', 18, 100, 45), key="age_loaded")
    education_idx = 0 if p['education'] in ['High', 'High (College or above)'] else 1
    education = st.sidebar.selectbox("Education Level",
                                     ["High (College or above)", "Low (Below college)"],
                                     index=education_idx, key="edu_loaded")

    st.sidebar.markdown("### Anthropometrics")
    bmi = st.sidebar.number_input("BMI (kg/m²)", min_value=15.0, max_value=50.0,
                                  value=_bounded_value(p, 'bmi', 15.0, 50.0, 27.0),
                                  format="%.1f", key="bmi_loaded")
    waist = st.sidebar.number_input("Waist Circumference (cm)", min_value=50, max_value=150,
                                    value=_bounded_value(p, 'waist', 50, 150, 88 if sex == "Male" else 78),
                                    key="waist_loaded")

    st.sidebar.markdown("### Laboratory Values")
    fasting_glucose = st.sidebar.number_input("Fasting Glucose (mg/dL)", min_value=70, max_value=300,
                                              value=_bounded_value(p, 'fasting_glucose', 70, 300, 100),
                                              key="glucose_loaded")
    total_cholesterol = st.sidebar.number_input("Total Cholesterol (mg/dL)", min_value=100, max_value=400,
                                                value=_bounded_value(p, 'total_cholesterol', 100, 400, 190),
                                                key="chol_loaded")
    uric_acid = st.sidebar.number_input("Uric Acid (mg/dL)", min_value=2.0, max_value=15.0,
                                        value=_bounded_value(p, 'uric_acid', 2.0, 15.0,
                                                             6.5 if sex == "Male" else 5.5),
                                        format="%.1f", key="ua_loaded")

    st.sidebar.markdown("### Lifestyle Factors")
    smoking_idx = 0 if p['smoking'] in ['No', 'Non-smoker'] else 1
    smoking = st.sidebar.selectbox("Smoking Status", ["Non-smoker", "Current smoker"],
                                   index=smoking_idx, key="smoke_loaded")
    alcohol_idx = 0 if p['alcohol'] in ['No', 'None/Occasional'] else 1
    alcohol = st.sidebar.selectbox("Alcohol Consumption", ["None/Occasional", "Regular drinker"],
                                   index=alcohol_idx, key="alcohol_loaded")
    exercise_idx = 0 if p['exercise'] in ['No', 'Infrequent'] else 1
    exercise = st.sidebar.selectbox("Exercise Frequency", ["Infrequent", "Regular (≥3 times/week)"],
                                    index=exercise_idx, key="exercise_loaded")
    # The betel question is shown for male patients only.
    if sex == "Male":
        betel_idx = 0 if p.get('betel', 'No') == 'No' else 1
        betel = st.sidebar.selectbox("Betel Nut Chewing", ["No", "Yes"],
                                     index=betel_idx, key="betel_loaded")
    family_idx = 0 if p['family_history'] == 'No' else 1
    family_history = st.sidebar.selectbox("Family History of Hypertension", ["No", "Yes"],
                                          index=family_idx, key="fh_loaded")
else:
    # Manual input (original functionality)
    sex = st.sidebar.radio("Biological Sex", ["Male", "Female"])

    st.sidebar.markdown("### Demographics")
    age = st.sidebar.number_input("Age", min_value=18, max_value=100, value=45)
    education = st.sidebar.selectbox("Education Level", ["High (College or above)", "Low (Below college)"])

    st.sidebar.markdown("### Anthropometrics")
    bmi = st.sidebar.number_input("BMI (kg/m²)", min_value=15.0, max_value=50.0, value=27.0, format="%.1f")
    if sex == "Male":
        waist = st.sidebar.number_input("Waist Circumference (cm)", min_value=60, max_value=150, value=88)
    else:
        waist = st.sidebar.number_input("Waist Circumference (cm)", min_value=50, max_value=150, value=78)

    st.sidebar.markdown("### Laboratory Values")
    fasting_glucose = st.sidebar.number_input("Fasting Glucose (mg/dL)", min_value=70, max_value=300, value=100)
    total_cholesterol = st.sidebar.number_input("Total Cholesterol (mg/dL)", min_value=100, max_value=400, value=190)
    if sex == "Male":
        uric_acid = st.sidebar.number_input("Uric Acid (mg/dL)", min_value=2.0, max_value=15.0, value=6.5,
                                            format="%.1f")
    else:
        uric_acid = st.sidebar.number_input("Uric Acid (mg/dL)", min_value=2.0, max_value=15.0, value=5.5,
                                            format="%.1f")

    st.sidebar.markdown("### Lifestyle Factors")
    smoking = st.sidebar.selectbox("Smoking Status", ["Non-smoker", "Current smoker"])
    alcohol = st.sidebar.selectbox("Alcohol Consumption", ["None/Occasional", "Regular drinker"])
    exercise = st.sidebar.selectbox("Exercise Frequency", ["Infrequent", "Regular (≥3 times/week)"])
    if sex == "Male":
        betel = st.sidebar.selectbox("Betel Nut Chewing", ["No", "Yes"])
    family_history = st.sidebar.selectbox("Family History of Hypertension", ["No", "Yes"])
# Convert inputs to feature dictionary
def create_feature_dict():
    """Turn the current sidebar inputs into 0/1 risk-factor flags.

    Reads the module-level input values (``sex``, ``education``, ``bmi``,
    ``waist``, ...). The waist and uric-acid cut-offs depend on sex
    (90 cm / 7 for male, 80 cm / 6 for female), and the betel flag exists
    for male patients only. Keys are inserted in a fixed order.
    """
    is_male = sex == "Male"

    features = {
        "Education_high": int(education == "High (College or above)"),
        "BMI_ge25": int(bmi >= 25),
    }
    # Sex-specific waist flag: the key name carries the cut-off.
    if is_male:
        features["Waist_ge90"] = int(waist >= 90)
    else:
        features["Waist_ge80"] = int(waist >= 80)
    features["Fasting_glu_high"] = int(fasting_glucose >= 110)
    features["TC_ge200"] = int(total_cholesterol >= 200)
    features["UA_high"] = int(uric_acid >= (7 if is_male else 6))
    features["Smoking_current"] = int(smoking == "Current smoker")
    if is_male:
        features["Betel_current"] = int(betel == "Yes")
    features["Alcohol_current"] = int(alcohol == "Regular drinker")
    features["Exercise_freq"] = int(exercise == "Regular (≥3 times/week)")
    features["FHx_yes"] = int(family_history == "Yes")
    return features
# Get patient features and info string
patient_features = create_feature_dict()


def get_patient_info_string():
    """Describe the current patient as plain text for use as AI context.

    The text lists the headline inputs followed by one bullet per risk
    factor flagged in ``patient_features`` (or a single "no major risk
    factors" bullet). When a patient was loaded from the uploaded table,
    the text starts with that patient's ID.
    """
    def flagged(name):
        return patient_features.get(name, 0) == 1

    high_waist = (sex == "Male" and flagged("Waist_ge90")) or \
                 (sex == "Female" and flagged("Waist_ge80"))

    # (applies?, description) pairs in the order they are reported.
    checks = [
        (flagged("BMI_ge25"), f"BMI {bmi:.1f} kg/m² (≥25)"),
        (high_waist, f"Waist circumference {waist} cm (high)"),
        (flagged("Smoking_current"), "Current smoker"),
        (flagged("Alcohol_current"), "Regular alcohol consumption"),
        (sex == "Male" and flagged("Betel_current"), "Betel nut chewing"),
        # Exercise is the one factor reported when its flag is 0.
        (patient_features.get("Exercise_freq", 0) == 0, "Insufficient exercise"),
        (flagged("FHx_yes"), "Family history of hypertension"),
        (flagged("UA_high"), f"High uric acid ({uric_acid:.1f} mg/dL)"),
        (flagged("Fasting_glu_high"), f"High fasting glucose ({fasting_glucose} mg/dL)"),
        (flagged("TC_ge200"), f"High total cholesterol ({total_cholesterol} mg/dL)"),
    ]
    risk_factors = [description for applies, description in checks if applies]

    if 'loaded_patient' in st.session_state:
        patient_id_str = f"Patient ID: {st.session_state.loaded_patient['patient_id']}\n"
    else:
        patient_id_str = ""

    info = f"""{patient_id_str}{sex}, Age {age}
BMI: {bmi:.1f} kg/m²
Waist: {waist} cm
Exercise: {exercise}
Smoking: {smoking}
Alcohol: {alcohol}
Risk Factors Identified:
"""
    bullet_lines = (
        "\n".join(f"- {rf}" for rf in risk_factors)
        if risk_factors
        else "- No major modifiable risk factors detected"
    )
    return info + bullet_lines
def get_cea_results_string():
    """Format the stored cost-effectiveness results as text for AI context.

    Reads ``st.session_state.cea_results``. When nothing is stored, a hint
    to run the analysis first is returned. Otherwise the text contains the
    incremental results and decision, plus a PSA section when the stored
    analysis type is 'Probabilistic' and PSA means are present.
    """
    cea = st.session_state.cea_results
    if cea is None:
        return "No cost-effectiveness analysis has been run yet. Please complete the analysis in Tab 4 or Tab 5 first."

    sections = [f"""
COST-EFFECTIVENESS ANALYSIS RESULTS:
═══════════════════════════════════════════════
Analysis Type: {cea.get('analysis_type', 'Not specified')}
Intervention: {cea.get('intervention_name', 'Not specified')}
Time Horizon: {cea.get('time_horizon', 'N/A')} years
INCREMENTAL RESULTS:
- Incremental Cost: ${cea.get('cost_diff', 0):.2f}
- Incremental QALYs: {cea.get('qaly_diff', 0):.4f}
- ICER: ${cea.get('icer', 0):.2f} per QALY
INTERPRETATION:
- WTP Threshold: ${cea.get('wtp_threshold', 50000):,}/QALY
- Decision: {cea.get('decision', 'Not determined')}
"""]

    # Add PSA results if available
    has_psa = cea.get('analysis_type') == 'Probabilistic' and 'psa_mean_cost' in cea
    if has_psa:
        sections.append(f"""
PROBABILISTIC SENSITIVITY ANALYSIS (PSA):
- Mean Incremental Cost: ${cea.get('psa_mean_cost', 0):.2f} (95% CI: ${cea.get('psa_cost_ci_low', 0):.2f} to ${cea.get('psa_cost_ci_high', 0):.2f})
- Mean Incremental QALYs: {cea.get('psa_mean_qaly', 0):.4f} (95% CI: {cea.get('psa_qaly_ci_low', 0):.4f} to {cea.get('psa_qaly_ci_high', 0):.4f})
- Mean ICER: ${cea.get('psa_mean_icer', 0):.2f} per QALY
- Probability Cost-Effective at ${cea.get('wtp_threshold', 50000):,}/QALY: {cea.get('prob_ce', 0)*100:.1f}%
""")

    sections.append("""
═══════════════════════════════════════════════
""")
    return "".join(sections)
def _latex_to_plain_text(text):
    r"""Rewrite the LaTeX constructs the prompt forbids into plain text.

    ``\text{X}`` becomes ``X`` and ``\frac{A}{B}`` becomes ``(A)/(B)``.
    Only these two constructs are touched; every other brace in the reply
    is left as it is. ``\text`` is unwrapped first so that fractions whose
    parts are wrapped in ``\text{...}`` are converted as well.
    """
    import re  # function-scope import keeps this edit self-contained

    text = re.sub(r"\\text\{([^{}]*)\}", r"\1", text)
    return re.sub(r"\\frac\{([^{}]*)\}\{([^{}]*)\}", r"(\1)/(\2)", text)


# Create tabs
tab1, tab2, tab3, tab4, tab5, tab6 = st.tabs([
    "🤖 AI Assistant",
    "📊 Hypertension Progression",
    "📈 Intervention Comparison",
    "💰 CEA (Deterministic)",
    "🎲 CEA (Probabilistic)",
    "💬 personalized CEA chatbot"
])

# Tab 1: AI Assistant
with tab1:
    st.subheader("AI Assistant - Ask Questions About This Tool")
    if not openai_api_key:
        st.warning("⚠️ Please enter your OpenAI API key in the top right corner to use the AI Assistant.")
    else:
        # Replay the stored conversation so it survives script reruns.
        for message in st.session_state.assistant_messages:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

        # Chat input
        if prompt := st.chat_input("Ask me anything about hypertension analysis..."):
            st.session_state.assistant_messages.append({"role": "user", "content": prompt})
            with st.chat_message("user"):
                st.markdown(prompt)

            with st.chat_message("assistant"):
                with st.spinner("Thinking..."):
                    try:
                        llm = get_llm()
                        if llm:
                            # History = up to 10 messages BEFORE the current
                            # question; the question itself is sent once, in
                            # the "User Question" slot below.
                            prior_messages = st.session_state.assistant_messages[:-1][-10:]
                            history_text = "".join(
                                f"{'User' if msg['role'] == 'user' else 'Assistant'}: {msg['content']}\n\n"
                                for msg in prior_messages
                            )
                            full_prompt = f"""You are an expert AI assistant for a Hypertension Cost-Effectiveness Analysis tool.
Your role is to help users understand:
- How to use this analysis tool
- Markov model methodology (4 states: Normal → Prehypertension → Stage 1 HTN → Stage 2 HTN)
- Cost-effectiveness metrics (ICER, QALY, CEAC)
- Risk factor interpretation
- Available interventions
Be clear, concise, and educational. Use examples when helpful.
IMPORTANT:
- Use plain text formatting only (no LaTeX, no \\text{{}} or \\frac{{}}{{}} syntax)
- Write mathematical formulas in plain text like: ICER = (Cost_B - Cost_A) / (QALY_B - QALY_A)
- Use simple markdown formatting (**, -, numbers) for emphasis
- Avoid special characters that may not render correctly
Conversation History:
{history_text}
User Question: {prompt}
Your Response:"""
                            response = llm.invoke(full_prompt).content
                            # Targeted clean-up: the earlier chained replace()
                            # removed every "}" from the reply and therefore
                            # could never convert a fraction.
                            response = _latex_to_plain_text(response)
                            st.markdown(response, unsafe_allow_html=False)
                            st.session_state.assistant_messages.append({"role": "assistant", "content": response})
                        else:
                            st.error("Failed to initialize AI. Please check your API key.")
                    except Exception as e:
                        st.error(f"Error: {str(e)}")
# Tab 2: Hypertension Progression
with tab2:
    st.subheader("Baseline Hypertension Progression Risk")
    # Left column: settings and transition rates; right column: trajectory plot.
    prog_col1, prog_col2 = st.columns([1, 1])

    with prog_col1:
        st.write("Progression Projection Settings")
        projection_years = st.slider("Projection Horizon (Years)", min_value=1, max_value=20, value=10)

        # Transition rates for this patient.
        # NOTE(review): randomize=True is passed here but not in Tab 3; the
        # effect depends on hazards_from_beta, which is defined in another
        # module. Confirm the difference is intended.
        sex_code = "M" if sex == "Male" else "F"
        l1, l2, l3, r = hazards_from_beta(sex_code, patient_features,
                                          lam10=0.08, lam20=0.10, lam30=0.12, rho0=0.05, randomize=True)

        st.write("Annual Transition Rates:")
        rates_df = pd.DataFrame({
            "Transition": ["Normal → Prehypertension", "Prehypertension → Stage 1",
                           "Stage 1 → Stage 2", "Prehypertension → Normal"],
            "Annual Rate (%)": [l1 * 100, l2 * 100, l3 * 100, r * 100]
        })
        st.dataframe(rates_df)

    # Build the transition matrix and step it once per projected year.
    # Presumably P is the one-year transition matrix (verify in the model module).
    Q = Q_matrix(l1, l2, l3, r)
    P = discrete_P(Q, 1.0)
    states = ["Normal", "Prehypertension", "Stage 1", "Stage 2"]
    s = np.array([1, 0, 0, 0], float)  # Start in Normal state
    projections = [s.copy()]
    for _ in range(projection_years):
        s = s @ P
        projections.append(s.copy())
    proj_df = pd.DataFrame(projections, columns=states)
    proj_df.index.name = "Year"

    # Stage 2 share at the END of the chosen horizon.
    # NOTE(review): displayed below as "Lifetime", although it only covers
    # projection_years; consider relabelling.
    lifetime_risk_s2 = proj_df["Stage 2"].iloc[-1] * 100

    with prog_col2:
        fig, ax = plt.subplots(figsize=(8, 5))
        for state in states:
            ax.plot(range(projection_years + 1), proj_df[state], label=state)
        ax.set_xlabel("Year")
        ax.set_ylabel("Proportion")
        ax.set_title("Projected Hypertension State Distribution Over Time")
        ax.legend()
        ax.grid(alpha=0.3)
        st.pyplot(fig)
        plt.close(fig)

    # Summary metrics
    st.subheader("Summary Risk Metrics")
    risk_cols = st.columns(5)

    # `projections` holds projection_years + 1 rows, so both look-up years are
    # capped at the horizon. A hard-coded index 5 raised IndexError whenever
    # the slider was set below 5 years.
    year_5_idx = min(5, projection_years)
    risk_5yr_p = projections[year_5_idx][1]
    risk_5yr_htn = projections[year_5_idx][2] + projections[year_5_idx][3]
    risk_5yr_s2 = projections[year_5_idx][3]

    year_10_idx = min(10, projection_years)
    risk_10yr_p = projections[year_10_idx][1]
    risk_10yr_htn = projections[year_10_idx][2] + projections[year_10_idx][3]
    risk_10yr_s2 = projections[year_10_idx][3]

    # Labels name the year actually shown; with a horizon of 10 or more they
    # read "5-Year" and "10-Year" exactly as before.
    with risk_cols[0]:
        st.metric(f"{year_5_idx}-Year Prehypertension Risk", f"{risk_5yr_p * 100:.1f}%")
    with risk_cols[1]:
        st.metric(f"{year_5_idx}-Year Any Hypertension Risk", f"{risk_5yr_htn * 100:.1f}%")
    with risk_cols[2]:
        st.metric(f"{year_5_idx}-Year Stage 2 Risk", f"{risk_5yr_s2 * 100:.1f}%")
    with risk_cols[3]:
        st.metric(f"{year_10_idx}-Year Any Hypertension Risk", f"{risk_10yr_htn * 100:.1f}%")
    with risk_cols[4]:
        st.metric("Lifetime Stage 2 Risk", f"{lifetime_risk_s2:.1f}%")

    # Comparison to population averages
    st.write("#### Risk Comparison to Population Average")
    avg_5yr_htn = 0.08
    avg_10yr_htn = 0.18
    risk_ratio_5yr = risk_5yr_htn / avg_5yr_htn if avg_5yr_htn > 0 else 1.0
    risk_ratio_10yr = risk_10yr_htn / avg_10yr_htn if avg_10yr_htn > 0 else 1.0
    # Only compare like with like: a ratio is shown when the projection
    # actually reaches the year the population average refers to.
    if projection_years >= 5:
        st.write(f"This patient's 5-year risk of hypertension is **{risk_ratio_5yr:.1f}x** the population average.")
    if projection_years >= 10:
        st.write(f"This patient's 10-year risk of hypertension is **{risk_ratio_10yr:.1f}x** the population average.")
    else:
        st.caption("Set the projection horizon to at least 10 years to see both population comparisons.")

    # Risk factors explanation
    st.write("#### Key Risk Factors")
    risk_factors = []
    if patient_features.get("BMI_ge25", 0) == 1:
        risk_factors.append(f"BMI ≥ 25 kg/m² (current: {bmi:.1f})")
    if (sex == "Male" and patient_features.get("Waist_ge90", 0) == 1) or (
            sex == "Female" and patient_features.get("Waist_ge80", 0) == 1):
        risk_factors.append(f"High waist circumference (current: {waist} cm)")
    if patient_features.get("Smoking_current", 0) == 1:
        risk_factors.append("Current smoker")
    if patient_features.get("Alcohol_current", 0) == 1:
        risk_factors.append("Regular alcohol consumption")
    if sex == "Male" and patient_features.get("Betel_current", 0) == 1:
        risk_factors.append("Betel nut chewing")
    if patient_features.get("Exercise_freq", 0) == 0:
        risk_factors.append("Infrequent exercise")
    if patient_features.get("FHx_yes", 0) == 1:
        risk_factors.append("Family history of hypertension")
    if patient_features.get("UA_high", 0) == 1:
        risk_factors.append(f"High uric acid (current: {uric_acid:.1f} mg/dL)")
    if patient_features.get("Fasting_glu_high", 0) == 1:
        risk_factors.append(f"High fasting glucose (current: {fasting_glucose} mg/dL)")
    if patient_features.get("TC_ge200", 0) == 1:
        risk_factors.append(f"High total cholesterol (current: {total_cholesterol} mg/dL)")

    # Display risk factors
    if risk_factors:
        st.write("This patient has the following risk factors:")
        for factor in risk_factors:
            st.write(f"- {factor}")
    else:
        st.write("This patient has no major modifiable risk factors.")
# Tab 3: Intervention Comparison
with tab3:
    st.subheader("Compare Intervention Effects")

    def _flag_is(name, expected=1):
        """True when the patient's feature `name` equals `expected` (missing counts as 0)."""
        return patient_features.get(name, 0) == expected

    def _project_from_normal(step_matrix, n_years):
        """Return the year-by-year state rows, starting with everyone in Normal."""
        dist = np.array([1, 0, 0, 0], float)
        rows = [dist.copy()]
        for _ in range(n_years):
            dist = dist @ step_matrix
            rows.append(dist.copy())
        return rows

    # An intervention is offered only for a risk factor the patient has.
    # Each entry: (applies to this patient?, feature key, display name).
    _candidates = [
        (_flag_is("BMI_ge25"), "BMI_ge25", "Weight Loss (BMI < 25 kg/m²)"),
        ((sex == "Male" and _flag_is("Waist_ge90")) or (sex == "Female" and _flag_is("Waist_ge80")),
         "Waist_ge90" if sex == "Male" else "Waist_ge80", "Waist Circumference Reduction"),
        (_flag_is("Smoking_current"), "Smoking_current", "Smoking Cessation"),
        (_flag_is("Alcohol_current"), "Alcohol_current", "Alcohol Reduction"),
        (sex == "Male" and _flag_is("Betel_current"), "Betel_current", "Betel Nut Cessation"),
        # Exercise is the one factor that counts as a risk when its flag is 0.
        (_flag_is("Exercise_freq", 0), "Exercise_freq", "Regular Exercise"),
        (_flag_is("UA_high"), "UA_high", "Uric Acid Reduction"),
        (_flag_is("TC_ge200"), "TC_ge200", "Cholesterol Reduction"),
        (_flag_is("Fasting_glu_high"), "Fasting_glu_high", "Glucose Control"),
    ]
    available_interventions = [(key, label) for applies, key, label in _candidates if applies]

    if not available_interventions:
        st.write("No modifiable risk factors available for intervention.")
    else:
        selected_intervention_names = st.multiselect(
            "Select interventions to compare:",
            [name for _, name in available_interventions],
            max_selections=3
        )
        selected_interventions = [
            feature for feature, name in available_interventions
            if name in selected_intervention_names
        ]
        comp_years = st.slider("Comparison Projection (Years)", min_value=1, max_value=20, value=10, key="comp_years")

        if selected_interventions:
            sex_code = "M" if sex == "Male" else "F"
            states = ["Normal", "Prehypertension", "Stage 1", "Stage 2"]

            # Baseline trajectory with the patient's current risk profile.
            l1_base, l2_base, l3_base, r_base = hazards_from_beta(
                sex_code, patient_features,
                lam10=0.08, lam20=0.10, lam30=0.12, rho0=0.05
            )
            Q_base = Q_matrix(l1_base, l2_base, l3_base, r_base)
            P_base = discrete_P(Q_base, 1.0)
            proj_base = _project_from_normal(P_base, comp_years)
            proj_base_df = pd.DataFrame(proj_base, columns=states)
            proj_base_df.index.name = "Year"

            names_by_feature = dict(available_interventions)
            # End-of-horizon hypertension share (Stage 1 + Stage 2) without intervention.
            final_base_risk = proj_base_df["Stage 1"].iloc[-1] + proj_base_df["Stage 2"].iloc[-1]

            intervention_data = []
            for feature in selected_interventions:
                # Flip exactly one flag to its "healthy" value and re-run.
                int_features = patient_features.copy()
                int_features[feature] = 1 if feature == "Exercise_freq" else 0
                l1_int, l2_int, l3_int, r_int = hazards_from_beta(
                    sex_code, int_features,
                    lam10=0.08, lam20=0.10, lam30=0.12, rho0=0.05
                )
                Q_int = Q_matrix(l1_int, l2_int, l3_int, r_int)
                P_int = discrete_P(Q_int, 1.0)
                proj_int_df = pd.DataFrame(_project_from_normal(P_int, comp_years), columns=states)

                final_int_risk = proj_int_df["Stage 1"].iloc[-1] + proj_int_df["Stage 2"].iloc[-1]
                absolute_risk_reduction = final_base_risk - final_int_risk
                if final_base_risk > 0:
                    relative_risk_reduction = absolute_risk_reduction / final_base_risk
                else:
                    relative_risk_reduction = 0
                if absolute_risk_reduction > 0:
                    nnt = 1 / absolute_risk_reduction
                else:
                    nnt = float('inf')

                intervention_data.append({
                    "feature": feature,
                    "name": names_by_feature[feature],
                    "projection": proj_int_df,
                    "risk_reduction_abs": absolute_risk_reduction,
                    "risk_reduction_rel": relative_risk_reduction,
                    "nnt": nnt
                })

            st.write("#### Hypertension Risk Comparison")
            fig, ax = plt.subplots(figsize=(10, 6))
            years_axis = range(comp_years + 1)
            # Yearly hypertension share (Stage 1 + Stage 2) for the baseline.
            baseline_htn_risk = (proj_base_df["Stage 1"] + proj_base_df["Stage 2"]).tolist()
            ax.plot(years_axis, baseline_htn_risk, 'k-', linewidth=2, label="No Intervention")
            colors = ['b', 'g', 'r', 'c', 'm', 'y']
            for i, int_data in enumerate(intervention_data):
                curve = int_data["projection"]
                int_htn_risk = (curve["Stage 1"] + curve["Stage 2"]).tolist()
                ax.plot(
                    years_axis,
                    int_htn_risk,
                    f"{colors[i % len(colors)]}-",
                    linewidth=2,
                    label=int_data["name"]
                )
            ax.set_xlabel("Year")
            ax.set_ylabel("Probability of Hypertension (Stage 1 or 2)")
            ax.set_title(f"Effect of Interventions on {comp_years}-Year Hypertension Risk")
            ax.legend()
            ax.grid(alpha=0.3)
            st.pyplot(fig)
            plt.close(fig)

            st.write("#### Effectiveness Comparison")
            # First table row is the baseline, then one row per intervention.
            horizon_risk = baseline_htn_risk[-1]
            metrics_data = {
                "Intervention": ["No Intervention"] + [d["name"] for d in intervention_data],
                f"{comp_years}-Year HTN Risk": [horizon_risk * 100] + [
                    (horizon_risk - d["risk_reduction_abs"]) * 100 for d in intervention_data
                ],
                "Absolute Risk Reduction (%)": [0] + [
                    d["risk_reduction_abs"] * 100 for d in intervention_data
                ],
                "Relative Risk Reduction (%)": [0] + [
                    d["risk_reduction_rel"] * 100 for d in intervention_data
                ],
                "Number Needed to Treat": ["N/A"] + [
                    f"{d['nnt']:.1f}" if d["nnt"] < 100 else "100+" for d in intervention_data
                ]
            }
            metrics_df = pd.DataFrame(metrics_data)
            st.table(metrics_df.set_index("Intervention"))

            if intervention_data:
                most_effective = max(intervention_data, key=lambda x: x["risk_reduction_abs"])
                st.info(
                    f"**Recommendation**: Based on this analysis, "
                    f"**{most_effective['name']}** provides the greatest reduction in "
                    f"{comp_years}-year hypertension risk "
                    f"({most_effective['risk_reduction_abs'] * 100:.1f}% absolute reduction)."
                )
        else:
            st.write("Please select at least one intervention to compare.")
# Tab 4: Cost-Effectiveness Analysis (Deterministic)
def _ce_category(delta_cost, delta_qaly, icer_value, wtp):
    """Classify an incremental result against a willingness-to-pay threshold.

    Args:
        delta_cost: Incremental cost of the intervention versus baseline.
        delta_qaly: Incremental QALYs of the intervention versus baseline.
        icer_value: Incremental cost-effectiveness ratio for the comparison.
        wtp: Willingness-to-pay threshold in $/QALY.

    Returns:
        One of "dominant", "cost_effective", "not_cost_effective",
        "dominated" or "trade_off".
    """
    if delta_qaly > 0:
        if delta_cost <= 0:
            return "dominant"
        return "cost_effective" if icer_value < wtp else "not_cost_effective"
    # No QALY gain: either it costs at least as much, or it buys savings
    # at the expense of health.
    return "dominated" if delta_cost >= 0 else "trade_off"


with tab4:
    st.subheader("Cost-Effectiveness Analysis (Deterministic)")
    st.info("📌 This analysis uses **point estimates** (single values) for all parameters")
    st.write("### Analysis Settings")
    if available_interventions:
        cea_intervention = st.selectbox(
            "Select intervention to analyze:",
            [name for _, name in available_interventions],
            index=0
        )
        # Map the display name chosen by the user back to its feature key.
        cea_feature = next(
            feature for feature, name in available_interventions
            if name == cea_intervention
        )
        param_col1, param_col2, param_col3 = st.columns(3)
        with param_col1:
            cea_cycles = st.slider("Time Horizon (Years)", min_value=5, max_value=30, value=10)
            discount_rate = st.slider("Discount Rate (%)", min_value=0, max_value=10, value=3) / 100
        with param_col2:
            st.write("#### Cost Settings ($ per year)")
            cost_normal = st.number_input("Cost - Normal BP", min_value=0, max_value=5000, value=200)
            cost_pre = st.number_input("Cost - Prehypertension", min_value=0, max_value=5000, value=600)
            cost_s1 = st.number_input("Cost - Stage 1 HTN", min_value=0, max_value=5000, value=1200)
            cost_s2 = st.number_input("Cost - Stage 2 HTN", min_value=0, max_value=5000, value=2200)
        with param_col3:
            st.write("#### Utility Settings (QOL 0-1)")
            util_normal = st.slider("Utility - Normal BP", min_value=0.0, max_value=1.0, value=1.0, step=0.05)
            util_pre = st.slider("Utility - Prehypertension", min_value=0.0, max_value=1.0, value=0.9, step=0.05)
            util_s1 = st.slider("Utility - Stage 1 HTN", min_value=0.0, max_value=1.0, value=0.7, step=0.05)
            util_s2 = st.slider("Utility - Stage 2 HTN", min_value=0.0, max_value=1.0, value=0.5, step=0.05)
            # NOTE(review): the original indentation was lost; this section is
            # assumed to sit in the third column - confirm against the layout.
            st.write("#### Intervention Settings")
            int_cost_increase = st.number_input(
                "Additional Intervention Cost ($/year)",
                min_value=0,
                max_value=2000,
                value=500
            )
        st.markdown("---")

        # Per-state annual costs for baseline (A) and intervention (B); the
        # intervention surcharge is added to the first state's cost only.
        C_A = np.array([cost_normal, cost_pre, cost_s1, cost_s2])
        C_B = C_A.copy()
        C_B[0] += int_cost_increase
        U = np.array([util_normal, util_pre, util_s1, util_s2])
        sex_code = "M" if sex == "Male" else "F"
        # Single willingness-to-pay threshold used for the plot, the decision
        # text and the stored results.
        wtp_threshold = 50000
        try:
            # Baseline scenario
            l1_base, l2_base, l3_base, r_base = hazards_from_beta(
                sex_code, patient_features,
                lam10=0.08, lam20=0.10, lam30=0.12, rho0=0.05
            )
            Q_base = Q_matrix(l1_base, l2_base, l3_base, r_base)
            P_base = discrete_P(Q_base, 1.0)
            start_dist = np.array([1, 0, 0, 0], float)
            cost_A, qaly_A, _ = run_markov(P_base, C_A, U, start_dist, cea_cycles, discount_rate)

            # Intervention scenario: set the chosen feature to its target
            # value (1 for exercise frequency, 0 for every other feature).
            int_features = patient_features.copy()
            if cea_feature == "Exercise_freq":
                int_features[cea_feature] = 1
            else:
                int_features[cea_feature] = 0
            l1_int, l2_int, l3_int, r_int = hazards_from_beta(
                sex_code, int_features,
                lam10=0.08, lam20=0.10, lam30=0.12, rho0=0.05
            )
            Q_int = Q_matrix(l1_int, l2_int, l3_int, r_int)
            P_int = discrete_P(Q_int, 1.0)
            cost_B, qaly_B, _ = run_markov(P_int, C_B, U, start_dist, cea_cycles, discount_rate)

            # Incremental values; icer() returns (ICER, deltaC, deltaQ).
            cost_diff = cost_B - cost_A
            qaly_diff = qaly_B - qaly_A
            icer_val, _, _ = icer(cost_A, qaly_A, cost_B, qaly_B)

            # Cost-effectiveness plane (single point)
            st.write("### Cost-Effectiveness Plane")
            fig, ax = plt.subplots(figsize=(10, 8))
            ax.scatter(qaly_diff, cost_diff, s=300, c='red', marker='*',
                       edgecolors='black', linewidths=2, label='Intervention vs Baseline', zorder=5)
            # Quadrant axes through the origin
            ax.axhline(0, color='gray', linestyle='--', alpha=0.5, linewidth=1)
            ax.axvline(0, color='gray', linestyle='--', alpha=0.5, linewidth=1)
            # WTP line: always spans at least +/-0.01 QALYs so that the origin
            # lies strictly inside the view and the corner labels below sit in
            # the quadrants they name.
            xlim = ax.get_xlim()
            x_range = np.linspace(min(xlim[0], -0.01), max(xlim[1], 0.01), 100)
            ax.plot(x_range, x_range * wtp_threshold, 'k--', alpha=0.5,
                    linewidth=2, label=f'WTP ${wtp_threshold:,}/QALY')
            # Quadrant labels, positioned in axes-fraction coordinates
            ax.text(0.95, 0.95, 'More Costly\nMore Effective',
                    transform=ax.transAxes, ha='right', va='top', fontsize=10, alpha=0.5)
            ax.text(0.05, 0.95, 'More Costly\nLess Effective',
                    transform=ax.transAxes, ha='left', va='top', fontsize=10, alpha=0.5)
            ax.text(0.95, 0.05, 'Less Costly\nMore Effective',
                    transform=ax.transAxes, ha='right', va='bottom', fontsize=10, alpha=0.5)
            ax.text(0.05, 0.05, 'Less Costly\nLess Effective',
                    transform=ax.transAxes, ha='left', va='bottom', fontsize=10, alpha=0.5)
            ax.set_xlabel("Incremental QALYs", fontsize=12, fontweight='bold')
            ax.set_ylabel("Incremental Cost ($)", fontsize=12, fontweight='bold')
            ax.set_title("Cost-Effectiveness Plane (Deterministic Analysis)",
                         fontsize=14, fontweight='bold')
            ax.legend(fontsize=11)
            ax.grid(alpha=0.3)
            st.pyplot(fig)
            plt.close(fig)

            st.write("### Summary Metrics")
            metrics_cols = st.columns(3)
            with metrics_cols[0]:
                st.metric("Incremental Cost", f"${cost_diff:.2f}")
            with metrics_cols[1]:
                st.metric("Incremental QALYs", f"{qaly_diff:.3f}")
            with metrics_cols[2]:
                st.metric("ICER ($/QALY)", f"${icer_val:.2f}")

            # Decision at the fixed threshold (no probabilities involved).
            st.write("### Cost-Effectiveness Decision")
            category = _ce_category(cost_diff, qaly_diff, icer_val, wtp_threshold)
            if category == "dominant":
                decision_text = "DOMINANT: Saves money and improves health"
                st.success("✅ **DOMINANT**: This intervention saves money and improves health.")
                st.info(
                    f"💡 The intervention provides {qaly_diff:.3f} additional QALYs while saving ${-cost_diff:.2f}.")
            elif category == "cost_effective":
                decision_text = f"COST-EFFECTIVE: ICER below ${wtp_threshold:,}/QALY threshold"
                # Two unescaped "$" in one markdown string are rendered by
                # Streamlit as inline LaTeX, so both are backslash-escaped.
                st.success(
                    f"✅ **COST-EFFECTIVE**: ICER = \\${icer_val:.2f}/QALY is below the \\${wtp_threshold:,}/QALY threshold.")
                st.info(
                    f"💡 For every additional QALY gained, it costs ${icer_val:.2f}, which is considered acceptable.")
            elif category == "not_cost_effective":
                decision_text = f"NOT COST-EFFECTIVE: ICER exceeds ${wtp_threshold:,}/QALY threshold"
                # Escaped for the same reason as the message above.
                st.warning(
                    f"⚠️ **NOT COST-EFFECTIVE**: ICER = \\${icer_val:.2f}/QALY exceeds the \\${wtp_threshold:,}/QALY threshold.")
                st.info(
                    f"💡 The intervention would need to cost ${qaly_diff * wtp_threshold:.2f} or less to be cost-effective at this threshold.")
            elif category == "dominated":
                decision_text = "DOMINATED: Costs more and worsens health"
                st.error("❌ **DOMINATED**: This intervention costs more and worsens health outcomes.")
            else:
                decision_text = "TRADE-OFF: Saves money but reduces QALYs"
                st.warning("⚠️ **TRADE-OFF**: This intervention saves money but reduces QALYs.")
                st.info(f"💡 It saves ${-cost_diff:.2f} but loses {-qaly_diff:.3f} QALYs.")

            # Store CEA results in session state
            st.session_state.cea_results = {
                'analysis_type': 'Deterministic',
                'intervention_name': cea_intervention,
                'time_horizon': cea_cycles,
                'cost_diff': cost_diff,
                'qaly_diff': qaly_diff,
                'icer': icer_val,
                'wtp_threshold': wtp_threshold,
                'decision': decision_text
            }

            # Sensitivity table: the same classification at several thresholds
            st.write("### Sensitivity to WTP Threshold")
            st.caption(
                "This shows whether the intervention would be considered cost-effective at different willingness-to-pay thresholds.")
            wtp_thresholds = [25000, 50000, 75000, 100000, 150000]
            decision_labels = {
                "dominant": "✅ Dominant (Cost-Effective)",
                "cost_effective": "✅ Cost-Effective",
                "not_cost_effective": "❌ Not Cost-Effective",
                "dominated": "❌ Dominated",
                "trade_off": "⚠️ Saves Money, Loses QALYs",
            }
            decision_data = [
                {
                    "WTP Threshold": f"${wtp:,}/QALY",
                    "Decision": decision_labels[_ce_category(cost_diff, qaly_diff, icer_val, wtp)]
                }
                for wtp in wtp_thresholds
            ]
            decision_df = pd.DataFrame(decision_data)
            st.table(decision_df.set_index("WTP Threshold"))
        except Exception as e:
            # UI boundary: report the failure instead of crashing the page.
            st.error(f"An error occurred while running the cost-effectiveness analysis: {str(e)}")
            st.error("Try different parameters or a different intervention.")
    else:
        st.write("No modifiable risk factors available for intervention.")
# Tab 5: Cost-Effectiveness Analysis (Probabilistic)
def sample_cost(mean, se):
    """Draw one cost from a gamma distribution with the given mean and SE.

    A non-positive mean or SE cannot parameterise a gamma distribution
    (and a zero mean would divide by zero), so the mean is returned as-is.
    """
    if se <= 0 or mean <= 0:
        return float(mean)
    shape = (mean / se) ** 2
    scale = se ** 2 / mean
    return np.random.gamma(shape, scale)


def sample_utility(mean, se):
    """Draw one utility in [0, 1] from a beta distribution (method of moments).

    Falls back to a clipped normal draw when the requested mean/SE pair does
    not yield valid (positive) beta parameters.
    """
    if se == 0 or mean == 0 or mean == 1:
        return np.clip(mean, 0, 1)
    common = (mean * (1 - mean) / (se ** 2)) - 1
    alpha = mean * common
    beta = (1 - mean) * common
    if alpha > 0 and beta > 0:
        return np.random.beta(alpha, beta)
    return np.clip(np.random.normal(mean, se), 0, 1)


def _prob_cost_effective(delta_cost, delta_qaly, wtp):
    """Return the share of simulations that are cost-effective at ``wtp``.

    A draw counts only when it gains QALYs and is either cost-saving or costs
    less than ``wtp`` per QALY gained. Draws that lose QALYs never count, so a
    negative ratio from a dominated draw is not mistaken for a good ICER.
    """
    gains_health = delta_qaly > 0
    affordable = (delta_cost < 0) | (delta_cost < wtp * delta_qaly)
    return float(np.mean(gains_health & affordable))


with tab5:
    st.subheader("Cost-Effectiveness Analysis (Probabilistic)")
    st.info(
        "📌 This analysis uses **distributions** (with Standard Errors) for all parameters to account for uncertainty")
    st.write("### Analysis Settings")
    if available_interventions:
        psa_intervention = st.selectbox(
            "Select intervention to analyze:",
            [name for _, name in available_interventions],
            index=0,
            key="psa_intervention"
        )
        # Map the display name chosen by the user back to its feature key.
        psa_feature = next(
            feature for feature, name in available_interventions
            if name == psa_intervention
        )
        psa_col1, psa_col2, psa_col3 = st.columns(3)
        with psa_col1:
            st.write("#### Simulation Settings")
            psa_cycles = st.slider("Time Horizon (Years)", min_value=5, max_value=30, value=10, key="psa_cycles")
            psa_discount_rate = st.slider("Discount Rate (%)", min_value=0, max_value=10, value=3,
                                          key="psa_discount") / 100
            n_simulations = st.slider("Number of Simulations", min_value=100, max_value=10000, value=1000, step=100)
        with psa_col2:
            st.write("#### Cost Parameters (Mean ± SE)")
            cost_normal_mean = st.number_input("Cost - Normal BP (Mean)", min_value=0, max_value=5000, value=200,
                                               key="psa_cn_mean")
            cost_normal_se = st.number_input("Cost - Normal BP (SE)", min_value=0, max_value=500, value=20,
                                             key="psa_cn_se")
            cost_pre_mean = st.number_input("Cost - Prehypertension (Mean)", min_value=0, max_value=5000, value=600,
                                            key="psa_cp_mean")
            cost_pre_se = st.number_input("Cost - Prehypertension (SE)", min_value=0, max_value=500, value=60,
                                          key="psa_cp_se")
            cost_s1_mean = st.number_input("Cost - Stage 1 HTN (Mean)", min_value=0, max_value=5000, value=1200,
                                           key="psa_cs1_mean")
            cost_s1_se = st.number_input("Cost - Stage 1 HTN (SE)", min_value=0, max_value=500, value=120,
                                         key="psa_cs1_se")
            cost_s2_mean = st.number_input("Cost - Stage 2 HTN (Mean)", min_value=0, max_value=5000, value=2200,
                                           key="psa_cs2_mean")
            cost_s2_se = st.number_input("Cost - Stage 2 HTN (SE)", min_value=0, max_value=500, value=220,
                                         key="psa_cs2_se")
        with psa_col3:
            st.write("#### Utility Parameters (Mean ± SE)")
            util_normal_mean = st.slider("Utility - Normal BP (Mean)", 0.0, 1.0, 1.0, 0.01, key="psa_un_mean")
            util_normal_se = st.slider("Utility - Normal BP (SE)", 0.0, 0.1, 0.01, 0.001, key="psa_un_se")
            util_pre_mean = st.slider("Utility - Prehypertension (Mean)", 0.0, 1.0, 0.9, 0.01, key="psa_up_mean")
            util_pre_se = st.slider("Utility - Prehypertension (SE)", 0.0, 0.1, 0.02, 0.001, key="psa_up_se")
            util_s1_mean = st.slider("Utility - Stage 1 HTN (Mean)", 0.0, 1.0, 0.7, 0.01, key="psa_us1_mean")
            util_s1_se = st.slider("Utility - Stage 1 HTN (SE)", 0.0, 0.1, 0.03, 0.001, key="psa_us1_se")
            util_s2_mean = st.slider("Utility - Stage 2 HTN (Mean)", 0.0, 1.0, 0.5, 0.01, key="psa_us2_mean")
            util_s2_se = st.slider("Utility - Stage 2 HTN (SE)", 0.0, 0.1, 0.05, 0.001, key="psa_us2_se")
            # NOTE(review): the original indentation was lost; this section is
            # assumed to sit in the third column - confirm against the layout.
            st.write("#### Intervention Settings")
            psa_int_cost_mean = st.number_input("Additional Intervention Cost (Mean)", 0, 2000, 500, key="psa_int_cost_mean")
            psa_int_cost_se = st.number_input("Additional Intervention Cost (SE)", 0, 200, 50, key="psa_int_cost_se")
        st.markdown("---")

        if st.button("🎲 Run Probabilistic Analysis", type="primary"):
            with st.spinner(f"Running {n_simulations} Monte Carlo simulations..."):
                try:
                    # Transition probabilities use point estimates, so both
                    # matrices are identical for every draw: build them once.
                    # NOTE(review): assumes run_markov does not modify its
                    # matrix or start-distribution arguments - confirm.
                    sex_code = "M" if sex == "Male" else "F"
                    l1, l2, l3, r = hazards_from_beta(sex_code, patient_features,
                                                      lam10=0.08, lam20=0.10, lam30=0.12, rho0=0.05)
                    Q_A = Q_matrix(l1, l2, l3, r)
                    P_A = discrete_P(Q_A, 1.0)

                    # Intervention: chosen feature set to its target value
                    # (1 for exercise frequency, 0 for every other feature).
                    int_features = patient_features.copy()
                    if psa_feature == "Exercise_freq":
                        int_features[psa_feature] = 1
                    else:
                        int_features[psa_feature] = 0
                    l1_b, l2_b, l3_b, r_b = hazards_from_beta(sex_code, int_features,
                                                              lam10=0.08, lam20=0.10, lam30=0.12, rho0=0.05)
                    Q_B = Q_matrix(l1_b, l2_b, l3_b, r_b)
                    P_B = discrete_P(Q_B, 1.0)
                    start_dist = np.array([1, 0, 0, 0], float)

                    # Storage for simulation results
                    results_delta_cost = []
                    results_delta_qaly = []
                    results_icer = []

                    progress_bar = st.progress(0)
                    # Refresh the bar about once per percent rather than on
                    # every draw.
                    progress_step = max(1, n_simulations // 100)
                    for sim in range(n_simulations):
                        # Costs ~ gamma (non-negative); utilities ~ beta (0-1).
                        # dtype=float keeps the sampled surcharge from being
                        # truncated when every cost SE is zero.
                        C_A_sim = np.array([
                            sample_cost(cost_normal_mean, cost_normal_se),
                            sample_cost(cost_pre_mean, cost_pre_se),
                            sample_cost(cost_s1_mean, cost_s1_se),
                            sample_cost(cost_s2_mean, cost_s2_se)
                        ], dtype=float)
                        int_cost_add = sample_cost(psa_int_cost_mean, psa_int_cost_se)
                        C_B_sim = C_A_sim.copy()
                        C_B_sim[0] += int_cost_add
                        U_sim = np.array([
                            sample_utility(util_normal_mean, util_normal_se),
                            sample_utility(util_pre_mean, util_pre_se),
                            sample_utility(util_s1_mean, util_s1_se),
                            sample_utility(util_s2_mean, util_s2_se)
                        ])

                        # Run both Markov models with this draw's parameters
                        cost_A_sim, qaly_A_sim, _ = run_markov(P_A, C_A_sim, U_sim, start_dist,
                                                               psa_cycles, psa_discount_rate)
                        cost_B_sim, qaly_B_sim, _ = run_markov(P_B, C_B_sim, U_sim, start_dist,
                                                               psa_cycles, psa_discount_rate)

                        delta_cost = cost_B_sim - cost_A_sim
                        delta_qaly = qaly_B_sim - qaly_A_sim
                        # ICER is undefined for a (near-)zero QALY difference;
                        # record a signed infinity and filter it out later.
                        if abs(delta_qaly) > 1e-9:
                            icer_sim = delta_cost / delta_qaly
                        else:
                            icer_sim = np.inf if delta_cost > 0 else -np.inf

                        results_delta_cost.append(delta_cost)
                        results_delta_qaly.append(delta_qaly)
                        results_icer.append(icer_sim)

                        completed = sim + 1
                        if completed % progress_step == 0 or completed == n_simulations:
                            progress_bar.progress(completed / n_simulations)
                    progress_bar.empty()

                    # Convert to arrays
                    results_delta_cost = np.array(results_delta_cost)
                    results_delta_qaly = np.array(results_delta_qaly)
                    results_icer = np.array(results_icer)
                    # Infinite ICERs are excluded from summaries and plots
                    results_icer_finite = results_icer[np.isfinite(results_icer)]

                    mean_delta_cost = np.mean(results_delta_cost)
                    cost_ci_low, cost_ci_high = np.percentile(results_delta_cost, [2.5, 97.5])
                    mean_delta_qaly = np.mean(results_delta_qaly)
                    qaly_ci_low, qaly_ci_high = np.percentile(results_delta_qaly, [2.5, 97.5])
                    if results_icer_finite.size:
                        mean_icer = np.mean(results_icer_finite)
                        icer_ci_low, icer_ci_high = np.percentile(results_icer_finite, [2.5, 97.5])
                    else:
                        # Every draw had a zero QALY difference.
                        mean_icer = icer_ci_low = icer_ci_high = float("nan")

                    # Display results
                    st.success(f"✅ Completed {n_simulations} simulations!")
                    st.write("### Probabilistic Results Summary")
                    summary_cols = st.columns(4)
                    with summary_cols[0]:
                        st.metric("Mean ΔCost", f"${mean_delta_cost:.2f}")
                        st.caption(
                            f"95% CI: [{cost_ci_low:.2f}, {cost_ci_high:.2f}]")
                    with summary_cols[1]:
                        st.metric("Mean ΔQALY", f"{mean_delta_qaly:.4f}")
                        st.caption(
                            f"95% CI: [{qaly_ci_low:.4f}, {qaly_ci_high:.4f}]")
                    with summary_cols[2]:
                        st.metric("Mean ICER", f"${mean_icer:.2f}/QALY")
                        st.caption(
                            f"95% CI: [{icer_ci_low:.2f}, {icer_ci_high:.2f}]")
                    with summary_cols[3]:
                        # Probability cost-effective at the $50k threshold,
                        # using the same rule as the CEAC below.
                        wtp_threshold = 50000
                        prob_ce = _prob_cost_effective(results_delta_cost, results_delta_qaly, wtp_threshold)
                        st.metric("Prob. Cost-Effective", f"{prob_ce * 100:.1f}%")
                        st.caption(f"at ${wtp_threshold:,}/QALY")

                    # Determine decision based on probability
                    if prob_ce >= 0.95:
                        psa_decision = f"HIGHLY LIKELY COST-EFFECTIVE ({prob_ce*100:.1f}% probability at ${wtp_threshold:,}/QALY)"
                    elif prob_ce >= 0.75:
                        psa_decision = f"LIKELY COST-EFFECTIVE ({prob_ce*100:.1f}% probability at ${wtp_threshold:,}/QALY)"
                    elif prob_ce >= 0.50:
                        psa_decision = f"POSSIBLY COST-EFFECTIVE ({prob_ce*100:.1f}% probability at ${wtp_threshold:,}/QALY)"
                    else:
                        psa_decision = f"UNLIKELY COST-EFFECTIVE ({prob_ce*100:.1f}% probability at ${wtp_threshold:,}/QALY)"

                    # Store PSA CEA results in session state
                    st.session_state.cea_results = {
                        'analysis_type': 'Probabilistic',
                        'intervention_name': psa_intervention,
                        'time_horizon': psa_cycles,
                        'cost_diff': mean_delta_cost,
                        'qaly_diff': mean_delta_qaly,
                        'icer': mean_icer,
                        'wtp_threshold': wtp_threshold,
                        'decision': psa_decision,
                        'psa_mean_cost': mean_delta_cost,
                        'psa_cost_ci_low': cost_ci_low,
                        'psa_cost_ci_high': cost_ci_high,
                        'psa_mean_qaly': mean_delta_qaly,
                        'psa_qaly_ci_low': qaly_ci_low,
                        'psa_qaly_ci_high': qaly_ci_high,
                        'psa_mean_icer': mean_icer,
                        'prob_ce': prob_ce
                    }

                    # Cost-Effectiveness Plane
                    st.write("### Cost-Effectiveness Plane (Scatter Plot)")
                    fig, ax = plt.subplots(figsize=(10, 8))
                    ax.scatter(results_delta_qaly, results_delta_cost, alpha=0.3, s=20, c='blue')
                    ax.scatter(mean_delta_qaly, mean_delta_cost,
                               color='red', s=200, marker='*', edgecolors='black', linewidths=2,
                               label='Mean', zorder=5)
                    # Quadrant axes through the origin
                    ax.axhline(0, color='gray', linestyle='--', alpha=0.5, linewidth=1)
                    ax.axvline(0, color='gray', linestyle='--', alpha=0.5, linewidth=1)
                    # WTP threshold line across the current x-range
                    xlim = ax.get_xlim()
                    x_range = np.linspace(xlim[0], xlim[1], 100)
                    ax.plot(x_range, x_range * wtp_threshold, 'k--', alpha=0.5,
                            linewidth=2, label=f'WTP ${wtp_threshold:,}/QALY')
                    ax.set_xlabel("Incremental QALYs", fontsize=12, fontweight='bold')
                    ax.set_ylabel("Incremental Cost ($)", fontsize=12, fontweight='bold')
                    ax.set_title("Cost-Effectiveness Plane (Probabilistic Sensitivity Analysis)",
                                 fontsize=14, fontweight='bold')
                    ax.legend(fontsize=11)
                    ax.grid(alpha=0.3)
                    st.pyplot(fig)
                    plt.close(fig)

                    # CEAC: probability cost-effective over a range of thresholds
                    st.write("### Cost-Effectiveness Acceptability Curve (CEAC)")
                    wtp_range = np.linspace(0, 150000, 100)
                    prob_ce_array = [
                        _prob_cost_effective(results_delta_cost, results_delta_qaly, wtp)
                        for wtp in wtp_range
                    ]
                    fig2, ax2 = plt.subplots(figsize=(10, 6))
                    ax2.plot(wtp_range, prob_ce_array, 'b-', linewidth=2)
                    ax2.axhline(0.5, color='gray', linestyle='--', alpha=0.5)
                    ax2.axvline(50000, color='red', linestyle='--', alpha=0.5, label='$50,000/QALY')
                    ax2.set_xlabel("Willingness-to-Pay Threshold ($/QALY)", fontsize=12, fontweight='bold')
                    ax2.set_ylabel("Probability Cost-Effective", fontsize=12, fontweight='bold')
                    ax2.set_title("Cost-Effectiveness Acceptability Curve", fontsize=14, fontweight='bold')
                    ax2.set_ylim(0, 1)
                    ax2.grid(alpha=0.3)
                    ax2.legend()
                    st.pyplot(fig2)
                    plt.close(fig2)

                    # Distribution histograms
                    st.write("### Distribution of Results")
                    hist_col1, hist_col2 = st.columns(2)
                    with hist_col1:
                        fig3, ax3 = plt.subplots(figsize=(8, 5))
                        ax3.hist(results_delta_cost, bins=50, alpha=0.7, color='blue', edgecolor='black')
                        ax3.axvline(mean_delta_cost, color='red', linestyle='--',
                                    linewidth=2, label=f'Mean: ${mean_delta_cost:.2f}')
                        ax3.set_xlabel("Incremental Cost ($)")
                        ax3.set_ylabel("Frequency")
                        ax3.set_title("Distribution of Incremental Costs")
                        ax3.legend()
                        ax3.grid(alpha=0.3)
                        st.pyplot(fig3)
                        plt.close(fig3)
                    with hist_col2:
                        fig4, ax4 = plt.subplots(figsize=(8, 5))
                        ax4.hist(results_delta_qaly, bins=50, alpha=0.7, color='green', edgecolor='black')
                        ax4.axvline(mean_delta_qaly, color='red', linestyle='--',
                                    linewidth=2, label=f'Mean: {mean_delta_qaly:.4f}')
                        ax4.set_xlabel("Incremental QALYs")
                        ax4.set_ylabel("Frequency")
                        ax4.set_title("Distribution of Incremental QALYs")
                        ax4.legend()
                        ax4.grid(alpha=0.3)
                        st.pyplot(fig4)
                        plt.close(fig4)

                    # ICER distribution
                    st.write("### ICER Distribution")
                    if results_icer_finite.size:
                        fig5, ax5 = plt.subplots(figsize=(10, 5))
                        ax5.hist(results_icer_finite, bins=50, alpha=0.7, color='purple', edgecolor='black')
                        ax5.axvline(mean_icer, color='red', linestyle='--',
                                    linewidth=2, label=f'Mean: ${mean_icer:.2f}/QALY')
                        ax5.axvline(50000, color='orange', linestyle='--', linewidth=2,
                                    label='$50,000/QALY threshold')
                        ax5.set_xlabel("ICER ($/QALY)")
                        ax5.set_ylabel("Frequency")
                        ax5.set_title("Distribution of ICER Values")
                        ax5.legend()
                        ax5.grid(alpha=0.3)
                        st.pyplot(fig5)
                        plt.close(fig5)
                    else:
                        st.info("No finite ICER values to plot: no simulation produced a QALY difference.")
                except Exception as e:
                    # UI boundary: show the failure and its traceback instead
                    # of crashing the page.
                    st.error(f"Error running probabilistic analysis: {str(e)}")
                    import traceback
                    st.code(traceback.format_exc())
    else:
        st.write("No modifiable risk factors available for intervention.")
# Tab 6: Personalized Chat
def _strip_latex(text):
    """Rewrite ``\\text{...}`` and ``\\frac{a}{b}`` fragments as plain text.

    ``\\text{x}`` becomes ``x`` and ``\\frac{a}{b}`` becomes ``(a)/(b)``.
    Braces that are not part of one of these two commands are left alone.
    """
    import re

    text_cmd = re.compile(r"\\text\{([^{}]*)\}")
    frac_cmd = re.compile(r"\\frac\{([^{}]*)\}\{([^{}]*)\}")
    previous = None
    # Repeat until stable so nested forms such as \frac{\text{a}}{b} resolve.
    while previous != text:
        previous = text
        text = text_cmd.sub(r"\1", text)
        text = frac_cmd.sub(r"(\1)/(\2)", text)
    return text


with tab6:
    st.subheader("💬 Personalized Health Recommendations")
    if not openai_api_key:
        st.warning("⚠️ Please enter your OpenAI API key in the top right corner to use Personalized Chat.")
    else:
        # Step 1: one-off summary that seeds the conversation.
        if not st.session_state.summary_generated:
            if st.button("✨ Generate Personalized Health Summary", type="primary"):
                with st.spinner("Analyzing your health profile..."):
                    try:
                        llm = get_llm()
                        if llm:
                            patient_info = get_patient_info_string()
                            cea_results = get_cea_results_string()
                            summary_prompt = f"""Generate a comprehensive health assessment summary for this patient:
{patient_info}
{cea_results}
Include:
1. Overall risk level assessment
2. Key modifiable risk factors
3. Top 3 priority recommendations based on both health profile AND cost-effectiveness analysis results
4. Expected health benefits and economic value of interventions
IMPORTANT:
- Use plain text formatting only (no LaTeX, no \\text{{}} or \\frac{{}}{{}} syntax)
- Write any formulas in plain text
- Use simple markdown formatting (**, -, numbers) for emphasis
- Avoid special characters that may not render correctly
Format the response with clear sections and bullet points."""
                            response = _strip_latex(llm.invoke(summary_prompt).content)
                            st.session_state.recommendation_messages.append({
                                "role": "assistant",
                                "content": response
                            })
                            st.session_state.summary_generated = True
                            # Rerun so the chat input below appears at once.
                            st.rerun()
                        else:
                            # Same message the chat branch shows for this case.
                            st.error("Failed to initialize AI.")
                    except Exception as e:
                        st.error(f"Error generating summary: {str(e)}")

        # Replay the stored conversation on every script run.
        for message in st.session_state.recommendation_messages:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

        # Step 2: follow-up questions, available once the summary exists.
        if st.session_state.summary_generated:
            if prompt := st.chat_input("Ask about your personalized recommendations..."):
                st.session_state.recommendation_messages.append({"role": "user", "content": prompt})
                with st.chat_message("user"):
                    st.markdown(prompt)
                with st.chat_message("assistant"):
                    with st.spinner("Thinking..."):
                        try:
                            llm = get_llm()
                            if llm:
                                patient_info = get_patient_info_string()
                                cea_results = get_cea_results_string()
                                # Only the ten most recent messages are sent as
                                # context.
                                history_text = "".join(
                                    f"{'Patient' if msg['role'] == 'user' else 'Health Coach'}: {msg['content']}\n\n"
                                    for msg in st.session_state.recommendation_messages[-10:]
                                )
                                full_prompt = f"""You are a personalized health coach and health economist specializing in hypertension management.
PATIENT PROFILE:
{patient_info}
{cea_results}
Provide evidence-based, actionable recommendations for:
- Weight management and DASH diet
- Exercise prescriptions
- Smoking cessation strategies
- Medication adherence
- Lifestyle modifications
- Stress management
- Cost-effectiveness interpretation and value-based decision making
Answer questions about:
- Patient health risks and progression
- Cost-effectiveness analysis results (ICER, QALYs, etc.)
- Which interventions provide the best value for money
- How to interpret deterministic vs probabilistic results
- Personalized recommendations based on both clinical and economic evidence
Be empathetic, practical, and motivating. Cite specific guidelines when relevant.
IMPORTANT:
- Use plain text formatting only (no LaTeX, no \\text{{}} or \\frac{{}}{{}} syntax)
- Write formulas in plain text
- Use simple markdown formatting (**, -, numbers) for emphasis
- Avoid special characters that may not render correctly
Conversation History:
{history_text}
Patient Question: {prompt}
Your Personalized Advice:"""
                                response = _strip_latex(llm.invoke(full_prompt).content)
                                st.markdown(response, unsafe_allow_html=False)
                                st.session_state.recommendation_messages.append({
                                    "role": "assistant",
                                    "content": response
                                })
                            else:
                                st.error("Failed to initialize AI.")
                        except Exception as e:
                            st.error(f"Error: {str(e)}")
# Footer: a divider followed by a small, centred attribution line.
st.markdown("---")
footer_html = """
<div style='text-align: center; color: gray; font-size: 0.8em;'>
Powered by LangChain & OpenAI | Hypertension CEA Tool v2.0
</div>
"""
# Raw HTML is needed for the inline styling, hence unsafe_allow_html.
st.markdown(footer_html, unsafe_allow_html=True)