# personalized_ht / personalized_ht4.py
# Donlagon007's picture
# Upload personalized_ht4.py
# 348087f verified
"""
Streamlit App for Hypertension Cost-Effectiveness Analysis with LangChain
Enhanced with AI Assistant and Personalized Chat using OpenAI
Added: CSV upload and patient ID retrieval feature
"""
import hashlib
import os
import re
import shutil
import uuid
from datetime import datetime, timedelta
from typing import Dict, List, Any

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st
# LangChain imports
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

# Import the hypertension model functions
from hypertension_model_fixed2 import (
    run_analysis, beta_men, beta_women,
    hazards_from_beta, Q_matrix, discrete_P,
    run_markov, icer
)
# Set page configuration
st.set_page_config(
    page_title="Hypertension CEA Tool with AI",
    page_icon="❤️",
    layout="wide",
    initial_sidebar_state="expanded",
)

# ===== Per-user isolation bootstrap =====
# Give every new session its own random ID and remember when it started.
if 'user_id' not in st.session_state:
    st.session_state.user_id = str(uuid.uuid4())
    st.session_state.session_start = datetime.now()

# Session expiry check: sessions older than 2 hours are halted.
if 'session_start' in st.session_state:
    session_duration = datetime.now() - st.session_state.session_start
    if session_duration > timedelta(hours=2):
        st.warning("⏱️ Your session has expired (2 hours). Please refresh the page.")
        st.stop()

# Seed the remaining session-state slots. Each default is produced by a
# factory so that mutable defaults (the lists) are created fresh.
for _state_key, _make_default in (
    ('assistant_messages', list),        # chat history of the AI Assistant tab
    ('recommendation_messages', list),   # chat history of the recommendation chat
    ('summary_generated', lambda: False),
    ('patients_df', lambda: None),       # uploaded patient table
    ('vectorstore', lambda: None),       # per-user vector store handle
    ('cea_results', lambda: None),       # latest cost-effectiveness results
    ('uploaded_files', list),            # hashes of files uploaded so far
    ('api_call_count', lambda: 0),       # counter checked by get_llm()
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
# Startup housekeeping: purge per-user vector stores older than 24 hours.
def cleanup_old_data(data_dir: str = "./data/chroma", max_age_hours: float = 24) -> None:
    """Delete sub-directories of ``data_dir`` whose mtime is older than the cutoff.

    This is a best-effort sweep: a folder that cannot be inspected or removed
    is skipped and the remaining folders are still processed. Nothing is
    raised to the caller and nothing is returned.

    Args:
        data_dir: Directory holding one sub-directory per user.
        max_age_hours: Folders last modified more than this many hours ago
            are removed.
    """
    if not os.path.isdir(data_dir):
        return
    cutoff_time = datetime.now() - timedelta(hours=max_age_hours)
    try:
        entries = os.listdir(data_dir)
    except OSError:
        # The directory vanished or is unreadable; nothing to clean.
        return
    for user_folder in entries:
        folder_path = os.path.join(data_dir, user_folder)
        try:
            if not os.path.isdir(folder_path):
                continue
            mod_time = datetime.fromtimestamp(os.path.getmtime(folder_path))
            if mod_time < cutoff_time:
                shutil.rmtree(folder_path)
        except (OSError, OverflowError, ValueError):
            # Failure on one folder must not stop the sweep of the others.
            continue
# Sweep stale per-user data on every script run, before the header is drawn.
cleanup_old_data()

# Header row: title block in the wide left column, OpenAI API key entry on the right.
col1, col2 = st.columns([3, 1])
col1.title("❤️ Personalized LLM for Planning Cost Effectiveness Analysis")
col1.markdown("*Multistate ML Analysis Showcase Hypertension*")
openai_api_key = col2.text_input(
    "🔑 OpenAI API Key",
    type="password",
    placeholder="sk-...",
    help="Enter your OpenAI API key to enable AI features",
)
# Show the key status right under the input box.
if not openai_api_key:
    col2.warning("⚠️ Enter API key")
else:
    col2.success("✓ API Key set")
# Build the chat model on demand (requires the API key entered in the header).
def get_llm():
    """Return a ChatOpenAI client, or None when it cannot be created.

    None is returned when no API key has been entered or when constructing
    the client raises. Once the session counter reaches 100 an error is shown
    and st.stop() is called. The counter is incremented each time a client is
    successfully constructed.
    """
    if not openai_api_key:
        return None

    # Per-session budget: at most 100 successful client creations.
    calls_so_far = st.session_state.api_call_count
    if calls_so_far >= 100:
        st.error("⚠️ API call limit reached (100 calls per session). Please start a new session.")
        st.stop()

    model_settings = {
        "model": "gpt-4o-mini",
        "temperature": 0.7,
        "openai_api_key": openai_api_key,
    }
    try:
        client = ChatOpenAI(**model_settings)
        # Record the call against the session budget.
        st.session_state.api_call_count += 1
        return client
    except Exception as exc:
        st.error(f"Error initializing OpenAI: {exc}")
        return None
# Create vector store from patient data
def create_patient_vectorstore(patients_df: pd.DataFrame):
    """Build a Chroma vector store from the patient table, isolated per user.

    One document is created per row. The store is persisted under
    ``./data/chroma/<user_id>`` in a collection named after the session's
    user ID, and any previous content of that collection is discarded first.

    Args:
        patients_df: Patient table. Rows are read by the column names used in
            the text template below; ``betel`` is optional and defaults to "No".

    Returns:
        The populated vector store, or None when no API key is set or any
        step fails (the failure is reported with st.error).
    """
    if not openai_api_key:
        return None
    try:
        # Per-user directory and collection keep sessions separated.
        user_id = st.session_state.user_id
        persist_dir = f"./data/chroma/{user_id}"
        os.makedirs(persist_dir, exist_ok=True)
        collection_name = f"user_{user_id}_patients"

        documents = []
        for _, row in patients_df.iterrows():
            patient_text = "\n".join([
                f"Patient ID: {row['patient_id']}",
                f"Sex: {row['sex']}, Age: {row['age']}, Education: {row['education']}",
                f"BMI: {row['bmi']} kg/m², Waist: {row['waist']} cm",
                f"Fasting Glucose: {row['fasting_glucose']} mg/dL",
                f"Total Cholesterol: {row['total_cholesterol']} mg/dL",
                f"Uric Acid: {row['uric_acid']} mg/dL",
                f"Smoking: {row['smoking']}, Alcohol: {row['alcohol']}, Exercise: {row['exercise']}",
                f"Betel: {row.get('betel', 'No')}, Family History: {row['family_history']}",
            ])
            documents.append(
                Document(
                    page_content=patient_text,
                    metadata={
                        "patient_id": row['patient_id'],
                        "user_id": user_id,  # tag every document with its owner
                    },
                )
            )

        embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)

        def _open_collection():
            """Open (or create) this user's collection in the user's directory."""
            return Chroma(
                collection_name=collection_name,
                embedding_function=embeddings,
                persist_directory=persist_dir,
            )

        vectorstore = _open_collection()
        # Drop any previous content, then reopen an empty collection.
        try:
            vectorstore.delete_collection()
            vectorstore = _open_collection()
        except Exception:
            # Best effort: if the reset fails, continue with the handle we have.
            # Only Exception is caught so interrupts and exits still propagate.
            pass

        # Add the new documents.
        vectorstore.add_documents(documents)
        return vectorstore
    except Exception as e:
        st.error(f"Error creating vector store: {str(e)}")
        return None
# Retrieve patient by ID
def retrieve_patient_by_id(patient_id: str):
    """Look up one patient in the uploaded table by ID.

    IDs are compared as trimmed strings, so a numeric ``patient_id`` column
    (e.g. parsed from a CSV of plain numbers) still matches the text typed
    by the user.

    Args:
        patient_id: The ID to search for.

    Returns:
        The first matching row as a dict, or None when no table is loaded,
        the table has no ``patient_id`` column, or nothing matches.
    """
    patients = st.session_state.patients_df
    if patients is None or 'patient_id' not in patients.columns:
        return None
    wanted = str(patient_id).strip()
    id_column = patients['patient_id'].astype(str).str.strip()
    matches = patients[id_column == wanted]
    if matches.empty:
        return None
    return matches.iloc[0].to_dict()
# Sidebar for patient information
st.sidebar.header("👤 Patient Information")
# ===== Patient Data Upload Section =====
with st.sidebar.expander("📂 Upload Patient Database (Optional)", expanded=False):
    uploaded_file = st.file_uploader(
        "Upload CSV/Excel with patient data",
        type=['csv', 'xlsx'],
        help="Upload a file with multiple patients (Max 10MB)"
    )
    if uploaded_file is not None:
        try:
            # File size limit (10 MB).
            file_size_mb = uploaded_file.size / (1024 * 1024)
            if file_size_mb > 10:
                st.error(f"❌ File too large ({file_size_mb:.1f}MB). Max 10MB.")
                st.stop()
            # Compare the extension case-insensitively so "DATA.CSV" is still
            # parsed as CSV rather than handed to the Excel reader.
            if uploaded_file.name.lower().endswith('.csv'):
                df = pd.read_csv(uploaded_file)
            else:
                df = pd.read_excel(uploaded_file)
            if 'patient_id' not in df.columns:
                # Lookup by ID needs this column; reject the file up front
                # instead of failing later when a patient is loaded.
                st.error("❌ The uploaded file must contain a 'patient_id' column.")
            else:
                st.session_state.patients_df = df
                # Remember which files this session has uploaded (content hash).
                file_hash = hashlib.md5(uploaded_file.getvalue()).hexdigest()
                if file_hash not in st.session_state.uploaded_files:
                    st.session_state.uploaded_files.append(file_hash)
                st.success(f"✅ Loaded {len(df)} patients ({file_size_mb:.1f}MB)")
                # Optionally create vector store
                if openai_api_key and st.button("🔄 Create Vector Store for Smart Search"):
                    with st.spinner("Creating isolated vector store..."):
                        vectorstore = create_patient_vectorstore(df)
                        if vectorstore:
                            st.session_state.vectorstore = vectorstore
                            st.success("✅ Vector store created! (Isolated to your session)")
        except Exception as e:
            st.error(f"Error loading file: {str(e)}")
    # Patient ID retrieval
    if st.session_state.patients_df is not None:
        st.markdown("---")
        patient_id_input = st.text_input("🔍 Enter Patient ID", placeholder="P001")
        if st.button("📥 Load Patient Data"):
            # Ignore accidental leading/trailing whitespace in the typed ID.
            lookup_id = patient_id_input.strip() if patient_id_input else ""
            if lookup_id:
                patient_data = retrieve_patient_by_id(lookup_id)
                if patient_data:
                    st.session_state.loaded_patient = patient_data
                    st.success(f"✅ Loaded {lookup_id}")
                    # Rerun so the sidebar widgets pick up the loaded patient.
                    st.rerun()
                else:
                    st.error(f"Patient {lookup_id} not found")
st.sidebar.markdown("---")
# Check if we have loaded patient data
def _bounded_value(raw, low, high, fallback):
    """Coerce a loaded value into ``[low, high]`` using the numeric type of ``low``.

    Returns ``fallback`` when ``raw`` is missing, NaN, or not numeric, so a
    bad cell in the uploaded table cannot make a number_input raise.
    """
    try:
        number = float(raw)
    except (TypeError, ValueError):
        return fallback
    if number != number:  # NaN is the only value not equal to itself
        return fallback
    return type(low)(min(max(number, low), high))


if 'loaded_patient' in st.session_state:
    # Use loaded patient data to pre-fill the widgets
    p = st.session_state.loaded_patient
    st.sidebar.info(f"📋 Loaded: {p['patient_id']}")
    # Accept common spellings ("Male", "male", "M"); anything else selects Female.
    sex = st.sidebar.radio("Biological Sex", ["Male", "Female"],
                           index=0 if str(p['sex']).strip().lower() in ('male', 'm') else 1,
                           key="sex_loaded")
    st.sidebar.markdown("### Demographics")
    age = st.sidebar.number_input("Age", min_value=18, max_value=100,
                                  value=_bounded_value(p['age'], 18, 100, 45), key="age_loaded")
    education_idx = 0 if p['education'] in ['High', 'High (College or above)'] else 1
    education = st.sidebar.selectbox("Education Level",
                                     ["High (College or above)", "Low (Below college)"],
                                     index=education_idx, key="edu_loaded")
    st.sidebar.markdown("### Anthropometrics")
    bmi = st.sidebar.number_input("BMI (kg/m²)", min_value=15.0, max_value=50.0,
                                  value=_bounded_value(p['bmi'], 15.0, 50.0, 27.0),
                                  format="%.1f", key="bmi_loaded")
    waist = st.sidebar.number_input("Waist Circumference (cm)", min_value=50, max_value=150,
                                    value=_bounded_value(p['waist'], 50, 150, 88), key="waist_loaded")
    st.sidebar.markdown("### Laboratory Values")
    fasting_glucose = st.sidebar.number_input("Fasting Glucose (mg/dL)", min_value=70, max_value=300,
                                              value=_bounded_value(p['fasting_glucose'], 70, 300, 100),
                                              key="glucose_loaded")
    total_cholesterol = st.sidebar.number_input("Total Cholesterol (mg/dL)", min_value=100, max_value=400,
                                                value=_bounded_value(p['total_cholesterol'], 100, 400, 190),
                                                key="chol_loaded")
    uric_acid = st.sidebar.number_input("Uric Acid (mg/dL)", min_value=2.0, max_value=15.0,
                                        value=_bounded_value(p['uric_acid'], 2.0, 15.0, 6.5),
                                        format="%.1f", key="ua_loaded")
    st.sidebar.markdown("### Lifestyle Factors")
    smoking_idx = 0 if p['smoking'] in ['No', 'Non-smoker'] else 1
    smoking = st.sidebar.selectbox("Smoking Status", ["Non-smoker", "Current smoker"],
                                   index=smoking_idx, key="smoke_loaded")
    alcohol_idx = 0 if p['alcohol'] in ['No', 'None/Occasional'] else 1
    alcohol = st.sidebar.selectbox("Alcohol Consumption", ["None/Occasional", "Regular drinker"],
                                   index=alcohol_idx, key="alcohol_loaded")
    exercise_idx = 0 if p['exercise'] in ['No', 'Infrequent'] else 1
    exercise = st.sidebar.selectbox("Exercise Frequency", ["Infrequent", "Regular (≥3 times/week)"],
                                    index=exercise_idx, key="exercise_loaded")
    # The betel question is only asked for male patients.
    if sex == "Male":
        betel_idx = 0 if p.get('betel', 'No') == 'No' else 1
        betel = st.sidebar.selectbox("Betel Nut Chewing", ["No", "Yes"],
                                     index=betel_idx, key="betel_loaded")
    family_idx = 0 if p['family_history'] == 'No' else 1
    family_history = st.sidebar.selectbox("Family History of Hypertension", ["No", "Yes"],
                                          index=family_idx, key="fh_loaded")
else:
    # Manual input (original functionality)
    sex = st.sidebar.radio("Biological Sex", ["Male", "Female"])
    st.sidebar.markdown("### Demographics")
    age = st.sidebar.number_input("Age", min_value=18, max_value=100, value=45)
    education = st.sidebar.selectbox("Education Level", ["High (College or above)", "Low (Below college)"])
    st.sidebar.markdown("### Anthropometrics")
    bmi = st.sidebar.number_input("BMI (kg/m²)", min_value=15.0, max_value=50.0, value=27.0, format="%.1f")
    # Sex-specific defaults and lower bound for waist circumference.
    if sex == "Male":
        waist = st.sidebar.number_input("Waist Circumference (cm)", min_value=60, max_value=150, value=88)
    else:
        waist = st.sidebar.number_input("Waist Circumference (cm)", min_value=50, max_value=150, value=78)
    st.sidebar.markdown("### Laboratory Values")
    fasting_glucose = st.sidebar.number_input("Fasting Glucose (mg/dL)", min_value=70, max_value=300, value=100)
    total_cholesterol = st.sidebar.number_input("Total Cholesterol (mg/dL)", min_value=100, max_value=400, value=190)
    # Sex-specific default for uric acid.
    if sex == "Male":
        uric_acid = st.sidebar.number_input("Uric Acid (mg/dL)", min_value=2.0, max_value=15.0, value=6.5,
                                            format="%.1f")
    else:
        uric_acid = st.sidebar.number_input("Uric Acid (mg/dL)", min_value=2.0, max_value=15.0, value=5.5,
                                            format="%.1f")
    st.sidebar.markdown("### Lifestyle Factors")
    smoking = st.sidebar.selectbox("Smoking Status", ["Non-smoker", "Current smoker"])
    alcohol = st.sidebar.selectbox("Alcohol Consumption", ["None/Occasional", "Regular drinker"])
    exercise = st.sidebar.selectbox("Exercise Frequency", ["Infrequent", "Regular (≥3 times/week)"])
    # The betel question is only asked for male patients.
    if sex == "Male":
        betel = st.sidebar.selectbox("Betel Nut Chewing", ["No", "Yes"])
    family_history = st.sidebar.selectbox("Family History of Hypertension", ["No", "Yes"])
# Convert inputs to feature dictionary
def create_feature_dict():
    """Turn the sidebar inputs into the 0/1 feature dict used by the model.

    Reads the module-level sidebar values. Male patients get a waist cutoff
    of 90 cm (key ``Waist_ge90``), a uric-acid cutoff of 7 and an extra
    ``Betel_current`` flag; all other patients get a waist cutoff of 80 cm
    (key ``Waist_ge80``) and a uric-acid cutoff of 6.
    """
    is_male = sex == "Male"
    if is_male:
        waist_key, waist_cutoff, ua_cutoff = "Waist_ge90", 90, 7
    else:
        waist_key, waist_cutoff, ua_cutoff = "Waist_ge80", 80, 6

    # Keys are inserted in the same order for a given sex on every call.
    features = {
        "Education_high": int(education == "High (College or above)"),
        "BMI_ge25": int(bmi >= 25),
        waist_key: int(waist >= waist_cutoff),
        "Fasting_glu_high": int(fasting_glucose >= 110),
        "TC_ge200": int(total_cholesterol >= 200),
        "UA_high": int(uric_acid >= ua_cutoff),
        "Smoking_current": int(smoking == "Current smoker"),
    }
    if is_male:
        # `betel` only exists when the sidebar asked the male-only question.
        features["Betel_current"] = int(betel == "Yes")
    features.update({
        "Alcohol_current": int(alcohol == "Regular drinker"),
        "Exercise_freq": int(exercise == "Regular (≥3 times/week)"),
        "FHx_yes": int(family_history == "Yes"),
    })
    return features


# Get patient features and info string
patient_features = create_feature_dict()
def get_patient_info_string():
    """Build the plain-text patient summary that is passed to the AI as context.

    The summary lists the headline sidebar values followed by every risk
    factor flagged in ``patient_features``; when a patient was loaded from
    the uploaded table, the patient ID is put on the first line.
    """
    def _flagged(key):
        """True when the feature is present and set to 1."""
        return patient_features.get(key, 0) == 1

    high_waist = (sex == "Male" and _flagged("Waist_ge90")) or \
                 (sex == "Female" and _flagged("Waist_ge80"))
    # (is present?, label) pairs, in display order.
    candidates = [
        (_flagged("BMI_ge25"), f"BMI {bmi:.1f} kg/m² (≥25)"),
        (high_waist, f"Waist circumference {waist} cm (high)"),
        (_flagged("Smoking_current"), "Current smoker"),
        (_flagged("Alcohol_current"), "Regular alcohol consumption"),
        (sex == "Male" and _flagged("Betel_current"), "Betel nut chewing"),
        # Exercise is the one factor that counts as a risk when it is 0.
        (patient_features.get("Exercise_freq", 0) == 0, "Insufficient exercise"),
        (_flagged("FHx_yes"), "Family history of hypertension"),
        (_flagged("UA_high"), f"High uric acid ({uric_acid:.1f} mg/dL)"),
        (_flagged("Fasting_glu_high"), f"High fasting glucose ({fasting_glucose} mg/dL)"),
        (_flagged("TC_ge200"), f"High total cholesterol ({total_cholesterol} mg/dL)"),
    ]
    risk_factors = [label for present, label in candidates if present]

    if 'loaded_patient' in st.session_state:
        patient_id_str = f"Patient ID: {st.session_state.loaded_patient['patient_id']}\n"
    else:
        patient_id_str = ""

    header = f"""{patient_id_str}{sex}, Age {age}
BMI: {bmi:.1f} kg/m²
Waist: {waist} cm
Exercise: {exercise}
Smoking: {smoking}
Alcohol: {alcohol}
Risk Factors Identified:
"""
    if not risk_factors:
        return header + "- No major modifiable risk factors detected"
    return header + "\n".join(f"- {rf}" for rf in risk_factors)
def get_cea_results_string():
    """Format the stored cost-effectiveness results as plain text for the AI.

    Reads ``st.session_state.cea_results``. When nothing is stored yet, a
    short notice is returned instead. The probabilistic section is included
    only for results marked 'Probabilistic' that carry ``psa_mean_cost``.
    """
    cea = st.session_state.cea_results
    if cea is None:
        return "No cost-effectiveness analysis has been run yet. Please complete the analysis in Tab 4 or Tab 5 first."

    sections = [f"""
COST-EFFECTIVENESS ANALYSIS RESULTS:
═══════════════════════════════════════════════
Analysis Type: {cea.get('analysis_type', 'Not specified')}
Intervention: {cea.get('intervention_name', 'Not specified')}
Time Horizon: {cea.get('time_horizon', 'N/A')} years
INCREMENTAL RESULTS:
- Incremental Cost: ${cea.get('cost_diff', 0):.2f}
- Incremental QALYs: {cea.get('qaly_diff', 0):.4f}
- ICER: ${cea.get('icer', 0):.2f} per QALY
INTERPRETATION:
- WTP Threshold: ${cea.get('wtp_threshold', 50000):,}/QALY
- Decision: {cea.get('decision', 'Not determined')}
"""]

    # Add PSA results if available
    has_psa = cea.get('analysis_type') == 'Probabilistic' and 'psa_mean_cost' in cea
    if has_psa:
        sections.append(f"""
PROBABILISTIC SENSITIVITY ANALYSIS (PSA):
- Mean Incremental Cost: ${cea.get('psa_mean_cost', 0):.2f} (95% CI: ${cea.get('psa_cost_ci_low', 0):.2f} to ${cea.get('psa_cost_ci_high', 0):.2f})
- Mean Incremental QALYs: {cea.get('psa_mean_qaly', 0):.4f} (95% CI: {cea.get('psa_qaly_ci_low', 0):.4f} to {cea.get('psa_qaly_ci_high', 0):.4f})
- Mean ICER: ${cea.get('psa_mean_icer', 0):.2f} per QALY
- Probability Cost-Effective at ${cea.get('wtp_threshold', 50000):,}/QALY: {cea.get('prob_ce', 0)*100:.1f}%
""")

    # Closing rule
    sections.append("""
═══════════════════════════════════════════════
""")
    return "".join(sections)
# Create tabs (one label per tab, in display order)
tab_labels = [
    "🤖 AI Assistant",
    "📊 Hypertension Progression",
    "📈 Intervention Comparison",
    "💰 CEA (Deterministic)",
    "🎲 CEA (Probabilistic)",
    "💬 personalized CEA chatbot",
]
tab1, tab2, tab3, tab4, tab5, tab6 = st.tabs(tab_labels)
# Tab 1: AI Assistant
def _strip_latex_markup(text):
    r"""Rewrite the LaTeX fragments the model sometimes emits as plain text.

    ``\text{x}`` becomes ``x`` and ``\frac{a}{b}`` becomes ``(a)/(b)``.
    All other braces in the response are left untouched.
    """
    # Unwrap \text{} first so that a \frac whose arguments were wrapped in
    # \text{} has brace-free arguments by the time it is matched.
    text = re.sub(r"\\text\{([^{}]*)\}", r"\1", text)
    text = re.sub(r"\\frac\{([^{}]*)\}\{([^{}]*)\}", r"(\1)/(\2)", text)
    return text


def _build_assistant_prompt(history_text, prompt):
    """Assemble the full prompt: fixed instructions, prior conversation, new question."""
    return f"""You are an expert AI assistant for a Hypertension Cost-Effectiveness Analysis tool.
Your role is to help users understand:
- How to use this analysis tool
- Markov model methodology (4 states: Normal → Prehypertension → Stage 1 HTN → Stage 2 HTN)
- Cost-effectiveness metrics (ICER, QALY, CEAC)
- Risk factor interpretation
- Available interventions
Be clear, concise, and educational. Use examples when helpful.
IMPORTANT:
- Use plain text formatting only (no LaTeX, no \\text{{}} or \\frac{{}}{{}} syntax)
- Write mathematical formulas in plain text like: ICER = (Cost_B - Cost_A) / (QALY_B - QALY_A)
- Use simple markdown formatting (**, -, numbers) for emphasis
- Avoid special characters that may not render correctly
Conversation History:
{history_text}
User Question: {prompt}
Your Response:"""


with tab1:
    st.subheader("AI Assistant - Ask Questions About This Tool")
    if not openai_api_key:
        st.warning("⚠️ Please enter your OpenAI API key in the top right corner to use the AI Assistant.")
    else:
        # Display chat messages
        for message in st.session_state.assistant_messages:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])
        # Chat input
        if prompt := st.chat_input("Ask me anything about hypertension analysis..."):
            # Snapshot the history BEFORE storing the new question, so the
            # question is sent once (as "User Question") rather than twice.
            prior_messages = st.session_state.assistant_messages[-10:]
            # Add user message
            st.session_state.assistant_messages.append({"role": "user", "content": prompt})
            with st.chat_message("user"):
                st.markdown(prompt)
            # Get AI response
            with st.chat_message("assistant"):
                with st.spinner("Thinking..."):
                    try:
                        llm = get_llm()
                        if llm:
                            history_text = ""
                            for msg in prior_messages:
                                role = "User" if msg["role"] == "user" else "Assistant"
                                history_text += f"{role}: {msg['content']}\n\n"
                            full_prompt = _build_assistant_prompt(history_text, prompt)
                            response = llm.invoke(full_prompt).content
                            # Convert stray LaTeX to plain text without
                            # deleting unrelated braces from the answer.
                            response = _strip_latex_markup(response)
                            st.markdown(response, unsafe_allow_html=False)
                            st.session_state.assistant_messages.append({"role": "assistant", "content": response})
                        else:
                            st.error("Failed to initialize AI. Please check your API key.")
                    except Exception as e:
                        st.error(f"Error: {str(e)}")
# Tab 2: Hypertension Progression
with tab2:
    st.subheader("Baseline Hypertension Progression Risk")
    # Create columns for different visualizations
    prog_col1, prog_col2 = st.columns([1, 1])
    # Prediction settings
    with prog_col1:
        st.write("Progression Projection Settings")
        projection_years = st.slider("Projection Horizon (Years)", min_value=1, max_value=20, value=10)
        # Calculate the transition rates
        sex_code = "M" if sex == "Male" else "F"
        # NOTE(review): this is the only hazards_from_beta call in the visible
        # code that passes randomize=True — confirm this is intended.
        l1, l2, l3, r = hazards_from_beta(sex_code, patient_features,
                                          lam10=0.08, lam20=0.10, lam30=0.12, rho0=0.05, randomize=True)
        # Display the annual transition rates
        st.write("Annual Transition Rates:")
        rates_df = pd.DataFrame({
            "Transition": ["Normal → Prehypertension", "Prehypertension → Stage 1",
                           "Stage 1 → Stage 2", "Prehypertension → Normal"],
            "Annual Rate (%)": [l1 * 100, l2 * 100, l3 * 100, r * 100]
        })
        st.dataframe(rates_df)
        # Build the one-year transition matrix from the rates
        Q = Q_matrix(l1, l2, l3, r)
        P = discrete_P(Q, 1.0)
        # Project the state distribution
        states = ["Normal", "Prehypertension", "Stage 1", "Stage 2"]
        s = np.array([1, 0, 0, 0], float)  # Start in Normal state
        # Project over time: projections[k] is the distribution after k years
        projections = [s.copy()]
        for _ in range(projection_years):
            s = s @ P
            projections.append(s.copy())
        proj_df = pd.DataFrame(projections, columns=states)
        proj_df.index.name = "Year"
        # Stage 2 share at the end of the projection horizon.
        # NOTE(review): shown below as "Lifetime", but it is the value at the
        # selected horizon — confirm the label.
        lifetime_risk_s2 = proj_df["Stage 2"].iloc[-1] * 100
    # Visualization of progression over time
    with prog_col2:
        # Plot state distribution over time
        fig, ax = plt.subplots(figsize=(8, 5))
        for state in states:
            ax.plot(range(projection_years + 1), proj_df[state], label=state)
        ax.set_xlabel("Year")
        ax.set_ylabel("Proportion")
        ax.set_title("Projected Hypertension State Distribution Over Time")
        ax.legend()
        ax.grid(alpha=0.3)
        st.pyplot(fig)
        plt.close(fig)
    # Summary metrics
    st.subheader("Summary Risk Metrics")
    risk_cols = st.columns(5)
    # 5-year risks. The slider allows horizons shorter than 5 years, so clamp
    # the index to the last projected year instead of reading past the list.
    year_5_idx = min(5, projection_years)
    risk_5yr_p = projections[year_5_idx][1]
    risk_5yr_htn = projections[year_5_idx][2] + projections[year_5_idx][3]
    risk_5yr_s2 = projections[year_5_idx][3]
    # 10-year risks (clamped the same way)
    year_10_idx = min(10, projection_years)
    risk_10yr_p = projections[year_10_idx][1]
    risk_10yr_htn = projections[year_10_idx][2] + projections[year_10_idx][3]
    risk_10yr_s2 = projections[year_10_idx][3]
    # Display risk metrics; labels name the year that was actually used
    with risk_cols[0]:
        st.metric(f"{year_5_idx}-Year Prehypertension Risk", f"{risk_5yr_p * 100:.1f}%")
    with risk_cols[1]:
        st.metric(f"{year_5_idx}-Year Any Hypertension Risk", f"{risk_5yr_htn * 100:.1f}%")
    with risk_cols[2]:
        st.metric(f"{year_5_idx}-Year Stage 2 Risk", f"{risk_5yr_s2 * 100:.1f}%")
    with risk_cols[3]:
        st.metric(f"{year_10_idx}-Year Any Hypertension Risk", f"{risk_10yr_htn * 100:.1f}%")
    with risk_cols[4]:
        st.metric("Lifetime Stage 2 Risk", f"{lifetime_risk_s2:.1f}%")
    # Comparison to population averages
    st.write("#### Risk Comparison to Population Average")
    avg_5yr_htn = 0.08
    avg_10yr_htn = 0.18
    risk_ratio_5yr = risk_5yr_htn / avg_5yr_htn if avg_5yr_htn > 0 else 1.0
    risk_ratio_10yr = risk_10yr_htn / avg_10yr_htn if avg_10yr_htn > 0 else 1.0
    # Only compare against a population figure when the projection actually
    # reaches that year; a truncated horizon would understate the ratio.
    if year_5_idx == 5:
        st.write(f"This patient's 5-year risk of hypertension is **{risk_ratio_5yr:.1f}x** the population average.")
    if year_10_idx == 10:
        st.write(f"This patient's 10-year risk of hypertension is **{risk_ratio_10yr:.1f}x** the population average.")
    if projection_years < 10:
        st.caption(
            "Population comparisons are shown only for horizons the projection reaches "
            "(5 years and 10 years). Increase the projection horizon to see both."
        )
    # Risk factors explanation
    st.write("#### Key Risk Factors")
    # Identify high risk factors
    risk_factors = []
    if patient_features.get("BMI_ge25", 0) == 1:
        risk_factors.append(f"BMI ≥ 25 kg/m² (current: {bmi:.1f})")
    if (sex == "Male" and patient_features.get("Waist_ge90", 0) == 1) or (
            sex == "Female" and patient_features.get("Waist_ge80", 0) == 1):
        risk_factors.append(f"High waist circumference (current: {waist} cm)")
    if patient_features.get("Smoking_current", 0) == 1:
        risk_factors.append("Current smoker")
    if patient_features.get("Alcohol_current", 0) == 1:
        risk_factors.append("Regular alcohol consumption")
    if sex == "Male" and patient_features.get("Betel_current", 0) == 1:
        risk_factors.append("Betel nut chewing")
    # Exercise is the one factor that counts as a risk when the flag is 0
    if patient_features.get("Exercise_freq", 0) == 0:
        risk_factors.append("Infrequent exercise")
    if patient_features.get("FHx_yes", 0) == 1:
        risk_factors.append("Family history of hypertension")
    if patient_features.get("UA_high", 0) == 1:
        risk_factors.append(f"High uric acid (current: {uric_acid:.1f} mg/dL)")
    if patient_features.get("Fasting_glu_high", 0) == 1:
        risk_factors.append(f"High fasting glucose (current: {fasting_glucose} mg/dL)")
    if patient_features.get("TC_ge200", 0) == 1:
        risk_factors.append(f"High total cholesterol (current: {total_cholesterol} mg/dL)")
    # Display risk factors
    if risk_factors:
        st.write("This patient has the following risk factors:")
        for factor in risk_factors:
            st.write(f"- {factor}")
    else:
        st.write("This patient has no major modifiable risk factors.")
# Tab 3: Intervention Comparison
def _format_nnt(nnt):
    """Render a number-needed-to-treat value for the comparison table.

    An infinite NNT (the intervention does not lower the projected risk) is
    shown as "N/A"; values of 100 or more are shown as "100+".
    """
    if nnt == float('inf'):
        return "N/A"
    return f"{nnt:.1f}" if nnt < 100 else "100+"


with tab3:
    st.subheader("Compare Intervention Effects")
    # Offer one intervention per risk factor the patient currently has.
    # Each entry is (feature key, display name); the list is also read by
    # the deterministic CEA tab.
    available_interventions = []
    if patient_features.get("BMI_ge25", 0) == 1:
        available_interventions.append(("BMI_ge25", "Weight Loss (BMI < 25 kg/m²)"))
    if (sex == "Male" and patient_features.get("Waist_ge90", 0) == 1) or (
            sex == "Female" and patient_features.get("Waist_ge80", 0) == 1):
        waist_feature = "Waist_ge90" if sex == "Male" else "Waist_ge80"
        available_interventions.append((waist_feature, "Waist Circumference Reduction"))
    if patient_features.get("Smoking_current", 0) == 1:
        available_interventions.append(("Smoking_current", "Smoking Cessation"))
    if patient_features.get("Alcohol_current", 0) == 1:
        available_interventions.append(("Alcohol_current", "Alcohol Reduction"))
    if sex == "Male" and patient_features.get("Betel_current", 0) == 1:
        available_interventions.append(("Betel_current", "Betel Nut Cessation"))
    if patient_features.get("Exercise_freq", 0) == 0:
        available_interventions.append(("Exercise_freq", "Regular Exercise"))
    if patient_features.get("UA_high", 0) == 1:
        available_interventions.append(("UA_high", "Uric Acid Reduction"))
    if patient_features.get("TC_ge200", 0) == 1:
        available_interventions.append(("TC_ge200", "Cholesterol Reduction"))
    if patient_features.get("Fasting_glu_high", 0) == 1:
        available_interventions.append(("Fasting_glu_high", "Glucose Control"))
    if not available_interventions:
        st.write("No modifiable risk factors available for intervention.")
    else:
        selected_intervention_names = st.multiselect(
            "Select interventions to compare:",
            [name for _, name in available_interventions],
            max_selections=3
        )
        selected_interventions = [
            feature for feature, name in available_interventions
            if name in selected_intervention_names
        ]
        comp_years = st.slider("Comparison Projection (Years)", min_value=1, max_value=20, value=10, key="comp_years")
        if selected_interventions:
            # Baseline projection with the patient's current features
            sex_code = "M" if sex == "Male" else "F"
            l1_base, l2_base, l3_base, r_base = hazards_from_beta(
                sex_code, patient_features,
                lam10=0.08, lam20=0.10, lam30=0.12, rho0=0.05
            )
            Q_base = Q_matrix(l1_base, l2_base, l3_base, r_base)
            P_base = discrete_P(Q_base, 1.0)
            states = ["Normal", "Prehypertension", "Stage 1", "Stage 2"]
            s_base = np.array([1, 0, 0, 0], float)
            proj_base = [s_base.copy()]
            for _ in range(comp_years):
                s_base = s_base @ P_base
                proj_base.append(s_base.copy())
            proj_base_df = pd.DataFrame(proj_base, columns=states)
            proj_base_df.index.name = "Year"
            # One projection per selected intervention
            intervention_data = []
            for feature in selected_interventions:
                int_features = patient_features.copy()
                # Exercise is improved by switching the flag ON; every other
                # factor is improved by switching it OFF.
                if feature == "Exercise_freq":
                    int_features[feature] = 1
                else:
                    int_features[feature] = 0
                l1_int, l2_int, l3_int, r_int = hazards_from_beta(
                    sex_code, int_features,
                    lam10=0.08, lam20=0.10, lam30=0.12, rho0=0.05
                )
                Q_int = Q_matrix(l1_int, l2_int, l3_int, r_int)
                P_int = discrete_P(Q_int, 1.0)
                s_int = np.array([1, 0, 0, 0], float)
                proj_int = [s_int.copy()]
                for _ in range(comp_years):
                    s_int = s_int @ P_int
                    proj_int.append(s_int.copy())
                proj_int_df = pd.DataFrame(proj_int, columns=states)
                # End-of-horizon probability of Stage 1 or Stage 2
                baseline_final_risk = proj_base_df["Stage 1"].iloc[-1] + proj_base_df["Stage 2"].iloc[-1]
                int_final_risk = proj_int_df["Stage 1"].iloc[-1] + proj_int_df["Stage 2"].iloc[-1]
                absolute_risk_reduction = baseline_final_risk - int_final_risk
                relative_risk_reduction = absolute_risk_reduction / baseline_final_risk if baseline_final_risk > 0 else 0
                # NNT is only defined for a positive risk reduction
                nnt = 1 / absolute_risk_reduction if absolute_risk_reduction > 0 else float('inf')
                int_name = next(name for feat, name in available_interventions if feat == feature)
                intervention_data.append({
                    "feature": feature,
                    "name": int_name,
                    "projection": proj_int_df,
                    "risk_reduction_abs": absolute_risk_reduction,
                    "risk_reduction_rel": relative_risk_reduction,
                    "nnt": nnt
                })
            st.write("#### Hypertension Risk Comparison")
            fig, ax = plt.subplots(figsize=(10, 6))
            # Year-by-year probability of Stage 1 or Stage 2, baseline first
            baseline_htn_risk = [
                proj_base_df["Stage 1"].iloc[i] + proj_base_df["Stage 2"].iloc[i]
                for i in range(len(proj_base_df))
            ]
            ax.plot(range(comp_years + 1), baseline_htn_risk, 'k-', linewidth=2, label="No Intervention")
            colors = ['b', 'g', 'r', 'c', 'm', 'y']
            for i, int_data in enumerate(intervention_data):
                int_htn_risk = [
                    int_data["projection"]["Stage 1"].iloc[j] + int_data["projection"]["Stage 2"].iloc[j]
                    for j in range(len(int_data["projection"]))
                ]
                ax.plot(
                    range(comp_years + 1),
                    int_htn_risk,
                    f"{colors[i % len(colors)]}-",
                    linewidth=2,
                    label=int_data["name"]
                )
            ax.set_xlabel("Year")
            ax.set_ylabel("Probability of Hypertension (Stage 1 or 2)")
            ax.set_title(f"Effect of Interventions on {comp_years}-Year Hypertension Risk")
            ax.legend()
            ax.grid(alpha=0.3)
            st.pyplot(fig)
            plt.close(fig)
            st.write("#### Effectiveness Comparison")
            # First row is the baseline; one further row per intervention
            metrics_data = {
                "Intervention": ["No Intervention"] + [int_data["name"] for int_data in intervention_data],
                f"{comp_years}-Year HTN Risk": [
                    baseline_htn_risk[-1] * 100
                ] + [
                    (baseline_htn_risk[-1] - int_data["risk_reduction_abs"]) * 100
                    for int_data in intervention_data
                ],
                "Absolute Risk Reduction (%)": [
                    0
                ] + [
                    int_data["risk_reduction_abs"] * 100
                    for int_data in intervention_data
                ],
                "Relative Risk Reduction (%)": [
                    0
                ] + [
                    int_data["risk_reduction_rel"] * 100
                    for int_data in intervention_data
                ],
                "Number Needed to Treat": [
                    "N/A"
                ] + [
                    _format_nnt(int_data["nnt"])
                    for int_data in intervention_data
                ]
            }
            metrics_df = pd.DataFrame(metrics_data)
            st.table(metrics_df.set_index("Intervention"))
            if intervention_data:
                most_effective = max(intervention_data, key=lambda x: x["risk_reduction_abs"])
                if most_effective["risk_reduction_abs"] > 0:
                    st.info(
                        f"**Recommendation**: Based on this analysis, "
                        f"**{most_effective['name']}** provides the greatest reduction in "
                        f"{comp_years}-year hypertension risk "
                        f"({most_effective['risk_reduction_abs'] * 100:.1f}% absolute reduction)."
                    )
                else:
                    # Do not recommend an intervention that does not lower risk.
                    st.warning(
                        f"None of the selected interventions reduces the projected "
                        f"{comp_years}-year hypertension risk for this patient."
                    )
        else:
            st.write("Please select at least one intervention to compare.")
# Tab 4: Cost-Effectiveness Analysis (Deterministic)
with tab4:
st.subheader("Cost-Effectiveness Analysis (Deterministic)")
st.info("📌 This analysis uses **point estimates** (single values) for all parameters")
st.write("### Analysis Settings")
if available_interventions:
cea_intervention = st.selectbox(
"Select intervention to analyze:",
[name for _, name in available_interventions],
index=0
)
cea_feature = next(
feature for feature, name in available_interventions
if name == cea_intervention
)
param_col1, param_col2, param_col3 = st.columns(3)
with param_col1:
cea_cycles = st.slider("Time Horizon (Years)", min_value=5, max_value=30, value=10)
discount_rate = st.slider("Discount Rate (%)", min_value=0, max_value=10, value=3) / 100
with param_col2:
st.write("#### Cost Settings ($ per year)")
cost_normal = st.number_input("Cost - Normal BP", min_value=0, max_value=5000, value=200)
cost_pre = st.number_input("Cost - Prehypertension", min_value=0, max_value=5000, value=600)
cost_s1 = st.number_input("Cost - Stage 1 HTN", min_value=0, max_value=5000, value=1200)
cost_s2 = st.number_input("Cost - Stage 2 HTN", min_value=0, max_value=5000, value=2200)
with param_col3:
st.write("#### Utility Settings (QOL 0-1)")
util_normal = st.slider("Utility - Normal BP", min_value=0.0, max_value=1.0, value=1.0, step=0.05)
util_pre = st.slider("Utility - Prehypertension", min_value=0.0, max_value=1.0, value=0.9, step=0.05)
util_s1 = st.slider("Utility - Stage 1 HTN", min_value=0.0, max_value=1.0, value=0.7, step=0.05)
util_s2 = st.slider("Utility - Stage 2 HTN", min_value=0.0, max_value=1.0, value=0.5, step=0.05)
st.write("#### Intervention Settings")
int_cost_increase = st.number_input(
"Additional Intervention Cost ($/year)",
min_value=0,
max_value=2000,
value=500
)
st.markdown("---")
C_A = np.array([cost_normal, cost_pre, cost_s1, cost_s2])
C_B = C_A.copy()
C_B[0] += int_cost_increase
U = np.array([util_normal, util_pre, util_s1, util_s2])
sex_code = "M" if sex == "Male" else "F"
try:
# Calculate baseline scenario
l1_base, l2_base, l3_base, r_base = hazards_from_beta(
sex_code, patient_features,
lam10=0.08, lam20=0.10, lam30=0.12, rho0=0.05
)
Q_base = Q_matrix(l1_base, l2_base, l3_base, r_base)
P_base = discrete_P(Q_base, 1.0)
start_dist = np.array([1, 0, 0, 0], float)
cost_A, qaly_A, _ = run_markov(P_base, C_A, U, start_dist, cea_cycles, discount_rate)
# Calculate intervention scenario
int_features = patient_features.copy()
if cea_feature == "Exercise_freq":
int_features[cea_feature] = 1
else:
int_features[cea_feature] = 0
l1_int, l2_int, l3_int, r_int = hazards_from_beta(
sex_code, int_features,
lam10=0.08, lam20=0.10, lam30=0.12, rho0=0.05
)
Q_int = Q_matrix(l1_int, l2_int, l3_int, r_int)
P_int = discrete_P(Q_int, 1.0)
cost_B, qaly_B, _ = run_markov(P_int, C_B, U, start_dist, cea_cycles, discount_rate)
# Calculate incremental values
cost_diff = cost_B - cost_A
qaly_diff = qaly_B - qaly_A
icer_val, _, _ = icer(cost_A, qaly_A, cost_B, qaly_B) # Unpack tuple (ICER, deltaC, deltaQ)
# Cost-effectiveness plane (single deterministic point)
st.write("### Cost-Effectiveness Plane")
fig, ax = plt.subplots(figsize=(10, 8))
# Plot the single point
ax.scatter(qaly_diff, cost_diff, s=300, c='red', marker='*',
edgecolors='black', linewidths=2, label='Intervention vs Baseline', zorder=5)
# Add quadrant lines
ax.axhline(0, color='gray', linestyle='--', alpha=0.5, linewidth=1)
ax.axvline(0, color='gray', linestyle='--', alpha=0.5, linewidth=1)
# Add WTP threshold line
wtp_threshold = 50000
xlim = ax.get_xlim()
ylim = ax.get_ylim()
# Extend the line across the plot
x_range = np.linspace(min(xlim[0], -0.01), max(xlim[1], 0.01), 100)
ax.plot(x_range, x_range * wtp_threshold, 'k--', alpha=0.5,
linewidth=2, label=f'WTP ${wtp_threshold:,}/QALY')
# Add quadrant labels
ax.text(0.95, 0.95, 'More Costly\nMore Effective',
transform=ax.transAxes, ha='right', va='top', fontsize=10, alpha=0.5)
ax.text(0.05, 0.95, 'More Costly\nLess Effective',
transform=ax.transAxes, ha='left', va='top', fontsize=10, alpha=0.5)
ax.text(0.95, 0.05, 'Less Costly\nMore Effective',
transform=ax.transAxes, ha='right', va='bottom', fontsize=10, alpha=0.5)
ax.text(0.05, 0.05, 'Less Costly\nLess Effective',
transform=ax.transAxes, ha='left', va='bottom', fontsize=10, alpha=0.5)
ax.set_xlabel("Incremental QALYs", fontsize=12, fontweight='bold')
ax.set_ylabel("Incremental Cost ($)", fontsize=12, fontweight='bold')
ax.set_title("Cost-Effectiveness Plane (Deterministic Analysis)",
fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(alpha=0.3)
st.pyplot(fig)
plt.close(fig)
st.write("### Summary Metrics")
metrics_cols = st.columns(3)
with metrics_cols[0]:
st.metric("Incremental Cost", f"${cost_diff:.2f}")
with metrics_cols[1]:
st.metric("Incremental QALYs", f"{qaly_diff:.3f}")
with metrics_cols[2]:
st.metric("ICER ($/QALY)", f"${icer_val:.2f}")
# Decision recommendation (deterministic; no probabilities involved)
st.write("### Cost-Effectiveness Decision")
wtp_threshold = 50000
# Determine decision text
if qaly_diff > 0:
if cost_diff <= 0:
decision_text = "DOMINANT: Saves money and improves health"
st.success("✅ **DOMINANT**: This intervention saves money and improves health.")
st.info(
f"💡 The intervention provides {qaly_diff:.3f} additional QALYs while saving ${-cost_diff:.2f}.")
elif icer_val < wtp_threshold:
decision_text = f"COST-EFFECTIVE: ICER below ${wtp_threshold:,}/QALY threshold"
st.success(
f"✅ **COST-EFFECTIVE**: ICER = ${icer_val:.2f}/QALY is below the ${wtp_threshold:,}/QALY threshold.")
st.info(
f"💡 For every additional QALY gained, it costs ${icer_val:.2f}, which is considered acceptable.")
else:
decision_text = f"NOT COST-EFFECTIVE: ICER exceeds ${wtp_threshold:,}/QALY threshold"
st.warning(
f"⚠️ **NOT COST-EFFECTIVE**: ICER = ${icer_val:.2f}/QALY exceeds the ${wtp_threshold:,}/QALY threshold.")
st.info(
f"💡 The intervention would need to cost ${qaly_diff * wtp_threshold:.2f} or less to be cost-effective at this threshold.")
else:
if cost_diff >= 0:
decision_text = "DOMINATED: Costs more and worsens health"
st.error("❌ **DOMINATED**: This intervention costs more and worsens health outcomes.")
else:
decision_text = "TRADE-OFF: Saves money but reduces QALYs"
st.warning("⚠️ **TRADE-OFF**: This intervention saves money but reduces QALYs.")
st.info(f"💡 It saves ${-cost_diff:.2f} but loses {-qaly_diff:.3f} QALYs.")
# Store CEA results in session state
st.session_state.cea_results = {
'analysis_type': 'Deterministic',
'intervention_name': cea_intervention,
'time_horizon': cea_cycles,
'cost_diff': cost_diff,
'qaly_diff': qaly_diff,
'icer': icer_val,
'wtp_threshold': wtp_threshold,
'decision': decision_text
}
# Sensitivity table: decision at several willingness-to-pay thresholds
st.write("### Sensitivity to WTP Threshold")
st.caption(
"This shows whether the intervention would be considered cost-effective at different willingness-to-pay thresholds.")
wtp_thresholds = [25000, 50000, 75000, 100000, 150000]
decision_data = []
for wtp in wtp_thresholds:
if qaly_diff > 0:
if cost_diff <= 0:
decision = "✅ Dominant (Cost-Effective)"
elif icer_val < wtp:
decision = "✅ Cost-Effective"
else:
decision = "❌ Not Cost-Effective"
else:
if cost_diff >= 0:
decision = "❌ Dominated"
else:
decision = "⚠️ Saves Money, Loses QALYs"
decision_data.append({
"WTP Threshold": f"${wtp:,}/QALY",
"Decision": decision
})
decision_df = pd.DataFrame(decision_data)
st.table(decision_df.set_index("WTP Threshold"))
except Exception as e:
st.error(f"An error occurred while running the cost-effectiveness analysis: {str(e)}")
st.error("Try different parameters or a different intervention.")
else:
st.write("No modifiable risk factors available for intervention.")
# Tab 5: Cost-Effectiveness Analysis (Probabilistic)
with tab5:
st.subheader("Cost-Effectiveness Analysis (Probabilistic)")
st.info(
"📌 This analysis uses **distributions** (with Standard Errors) for all parameters to account for uncertainty")
st.write("### Analysis Settings")
if available_interventions:
psa_intervention = st.selectbox(
"Select intervention to analyze:",
[name for _, name in available_interventions],
index=0,
key="psa_intervention"
)
psa_feature = next(
feature for feature, name in available_interventions
if name == psa_intervention
)
psa_col1, psa_col2, psa_col3 = st.columns(3)
with psa_col1:
st.write("#### Simulation Settings")
psa_cycles = st.slider("Time Horizon (Years)", min_value=5, max_value=30, value=10, key="psa_cycles")
psa_discount_rate = st.slider("Discount Rate (%)", min_value=0, max_value=10, value=3,
key="psa_discount") / 100
n_simulations = st.slider("Number of Simulations", min_value=100, max_value=10000, value=1000, step=100)
with psa_col2:
st.write("#### Cost Parameters (Mean ± SE)")
cost_normal_mean = st.number_input("Cost - Normal BP (Mean)", min_value=0, max_value=5000, value=200,
key="psa_cn_mean")
cost_normal_se = st.number_input("Cost - Normal BP (SE)", min_value=0, max_value=500, value=20,
key="psa_cn_se")
cost_pre_mean = st.number_input("Cost - Prehypertension (Mean)", min_value=0, max_value=5000, value=600,
key="psa_cp_mean")
cost_pre_se = st.number_input("Cost - Prehypertension (SE)", min_value=0, max_value=500, value=60,
key="psa_cp_se")
cost_s1_mean = st.number_input("Cost - Stage 1 HTN (Mean)", min_value=0, max_value=5000, value=1200,
key="psa_cs1_mean")
cost_s1_se = st.number_input("Cost - Stage 1 HTN (SE)", min_value=0, max_value=500, value=120,
key="psa_cs1_se")
cost_s2_mean = st.number_input("Cost - Stage 2 HTN (Mean)", min_value=0, max_value=5000, value=2200,
key="psa_cs2_mean")
cost_s2_se = st.number_input("Cost - Stage 2 HTN (SE)", min_value=0, max_value=500, value=220,
key="psa_cs2_se")
with psa_col3:
st.write("#### Utility Parameters (Mean ± SE)")
util_normal_mean = st.slider("Utility - Normal BP (Mean)", 0.0, 1.0, 1.0, 0.01, key="psa_un_mean")
util_normal_se = st.slider("Utility - Normal BP (SE)", 0.0, 0.1, 0.01, 0.001, key="psa_un_se")
util_pre_mean = st.slider("Utility - Prehypertension (Mean)", 0.0, 1.0, 0.9, 0.01, key="psa_up_mean")
util_pre_se = st.slider("Utility - Prehypertension (SE)", 0.0, 0.1, 0.02, 0.001, key="psa_up_se")
util_s1_mean = st.slider("Utility - Stage 1 HTN (Mean)", 0.0, 1.0, 0.7, 0.01, key="psa_us1_mean")
util_s1_se = st.slider("Utility - Stage 1 HTN (SE)", 0.0, 0.1, 0.03, 0.001, key="psa_us1_se")
util_s2_mean = st.slider("Utility - Stage 2 HTN (Mean)", 0.0, 1.0, 0.5, 0.01, key="psa_us2_mean")
util_s2_se = st.slider("Utility - Stage 2 HTN (SE)", 0.0, 0.1, 0.05, 0.001, key="psa_us2_se")
st.write("#### Intervention Settings")
psa_int_cost_mean = st.number_input("Additional Intervention Cost (Mean)", 0, 2000, 500, key="psa_int_cost_mean")
psa_int_cost_se = st.number_input("Additional Intervention Cost (SE)", 0, 200, 50, key="psa_int_cost_se")
st.markdown("---")
if st.button("🎲 Run Probabilistic Analysis", type="primary"):
with st.spinner(f"Running {n_simulations} Monte Carlo simulations..."):
try:
# Storage for simulation results
results_cost_A = []
results_cost_B = []
results_qaly_A = []
results_qaly_B = []
results_icer = []
results_delta_cost = []
results_delta_qaly = []
progress_bar = st.progress(0)
for sim in range(n_simulations):
# Sample one parameter set for this iteration from the mean/SE inputs:
# - costs from a gamma distribution (non-negative)
# - utilities from a beta distribution (bounded 0-1)
# Sample costs (using gamma approximation)
def sample_cost(mean, se, rng=None):
    """Draw one cost from a gamma distribution matched to *mean* and *se*.

    The gamma parameters come from the method of moments:
    ``shape = (mean / se) ** 2`` and ``scale = se ** 2 / mean``, so the
    distribution has the requested mean and standard error and never
    produces a negative cost.

    Args:
        mean: Expected cost.
        se: Standard error of the cost.
        rng: Optional ``numpy.random.Generator`` for reproducible draws.
            When omitted, the global ``np.random`` state is used.

    Returns:
        The sampled cost as a float. When ``se`` is 0 or ``mean`` is 0 there
        is nothing to sample (and the gamma scale would divide by zero), so
        ``mean`` itself is returned.
    """
    if se == 0 or mean == 0:
        # Always hand back a float so arrays built from these values are
        # float arrays and later in-place additions are not truncated.
        return float(mean)
    sampler = np.random if rng is None else rng
    shape = (mean / se) ** 2
    scale = se ** 2 / mean
    return float(sampler.gamma(shape, scale))
# Sample utilities (using beta approximation)
def sample_utility(mean, se, rng=None):
    """Draw one utility value in [0, 1] matched to *mean* and *se*.

    A beta distribution is fitted by the method of moments. If the
    requested standard error is too large for a valid beta (non-positive
    alpha or beta), the draw falls back to a normal sample clipped to
    [0, 1].

    Args:
        mean: Expected utility.
        se: Standard error of the utility.
        rng: Optional ``numpy.random.Generator`` for reproducible draws.
            When omitted, the global ``np.random`` state is used.

    Returns:
        The sampled utility as a float. When ``se`` is 0, or ``mean`` is
        exactly 0 or 1 (where the beta is degenerate), ``mean`` clipped to
        [0, 1] is returned without sampling.
    """
    if se == 0 or mean == 0 or mean == 1:
        return float(np.clip(mean, 0, 1))
    sampler = np.random if rng is None else rng
    # Shared method-of-moments factor: alpha + beta of the fitted beta.
    concentration = (mean * (1 - mean) / (se ** 2)) - 1
    alpha = mean * concentration
    beta = (1 - mean) * concentration
    if alpha > 0 and beta > 0:
        return float(sampler.beta(alpha, beta))
    return float(np.clip(sampler.normal(mean, se), 0, 1))
# Sample parameters for this iteration.
# Vector order follows the input widgets:
# [Normal BP, Prehypertension, Stage 1 HTN, Stage 2 HTN].
C_A_sim = np.array([
sample_cost(cost_normal_mean, cost_normal_se),
sample_cost(cost_pre_mean, cost_pre_se),
sample_cost(cost_s1_mean, cost_s1_se),
sample_cost(cost_s2_mean, cost_s2_se)
])
int_cost_add = sample_cost(psa_int_cost_mean, psa_int_cost_se)
# Intervention arm: same sampled state costs, with the sampled intervention
# cost added to the first state (Normal BP) only.
C_B_sim = C_A_sim.copy()
C_B_sim[0] += int_cost_add
# One utility draw per state, shared by both arms.
U_sim = np.array([
sample_utility(util_normal_mean, util_normal_se),
sample_utility(util_pre_mean, util_pre_se),
sample_utility(util_s1_mean, util_s1_se),
sample_utility(util_s2_mean, util_s2_se)
])
# Run analysis with sampled parameters
sex_code = "M" if sex == "Male" else "F"
# Get transition matrices (using point estimates for transition probabilities)
# NOTE(review): none of the inputs below depend on the values sampled above,
# so presumably both transition matrices could be built once before the loop —
# confirm hazards_from_beta / Q_matrix / discrete_P are deterministic.
l1, l2, l3, r = hazards_from_beta(sex_code, patient_features,
lam10=0.08, lam20=0.10, lam30=0.12, rho0=0.05)
# Baseline
Q_A = Q_matrix(l1, l2, l3, r)
P_A = discrete_P(Q_A, 1.0)
# Intervention: copy the patient's features and set the selected risk factor
# to its target value (1 for Exercise_freq, 0 for every other factor).
int_features = patient_features.copy()
if psa_feature == "Exercise_freq":
int_features[psa_feature] = 1
else:
int_features[psa_feature] = 0
l1_b, l2_b, l3_b, r_b = hazards_from_beta(sex_code, int_features,
lam10=0.08, lam20=0.10, lam30=0.12, rho0=0.05)
Q_B = Q_matrix(l1_b, l2_b, l3_b, r_b)
P_B = discrete_P(Q_B, 1.0)
# Run Markov models: both arms start with all probability mass in the first state.
start_dist = np.array([1, 0, 0, 0], float)
cost_A_sim, qaly_A_sim, _ = run_markov(P_A, C_A_sim, U_sim, start_dist,
psa_cycles, psa_discount_rate)
cost_B_sim, qaly_B_sim, _ = run_markov(P_B, C_B_sim, U_sim, start_dist,
psa_cycles, psa_discount_rate)
# Calculate incremental values (intervention minus baseline)
delta_cost = cost_B_sim - cost_A_sim
delta_qaly = qaly_B_sim - qaly_A_sim
# Calculate ICER. A QALY difference within 1e-9 of zero is not divided by;
# it is recorded as +inf when the intervention costs more and -inf otherwise
# (including an exactly zero cost difference).
if abs(delta_qaly) > 1e-9:
icer_sim = delta_cost / delta_qaly
else:
icer_sim = np.inf if delta_cost > 0 else -np.inf
# Store results
results_cost_A.append(cost_A_sim)
results_cost_B.append(cost_B_sim)
results_qaly_A.append(qaly_A_sim)
results_qaly_B.append(qaly_B_sim)
results_delta_cost.append(delta_cost)
results_delta_qaly.append(delta_qaly)
results_icer.append(icer_sim)
# Update progress (fraction of iterations completed so far)
progress_bar.progress((sim + 1) / n_simulations)
progress_bar.empty()
# Convert to arrays
results_delta_cost = np.array(results_delta_cost)
results_delta_qaly = np.array(results_delta_qaly)
results_icer = np.array(results_icer)
# Filter out infinite ICERs for display
results_icer_finite = results_icer[np.isfinite(results_icer)]
# Display results
st.success(f"✅ Completed {n_simulations} simulations!")
st.write("### Probabilistic Results Summary")
summary_cols = st.columns(4)
with summary_cols[0]:
st.metric("Mean ΔCost", f"${np.mean(results_delta_cost):.2f}")
st.caption(
f"95% CI: [{np.percentile(results_delta_cost, 2.5):.2f}, {np.percentile(results_delta_cost, 97.5):.2f}]")
with summary_cols[1]:
st.metric("Mean ΔQALY", f"{np.mean(results_delta_qaly):.4f}")
st.caption(
f"95% CI: [{np.percentile(results_delta_qaly, 2.5):.4f}, {np.percentile(results_delta_qaly, 97.5):.4f}]")
with summary_cols[2]:
st.metric("Mean ICER", f"${np.mean(results_icer_finite):.2f}/QALY")
st.caption(
f"95% CI: [{np.percentile(results_icer_finite, 2.5):.2f}, {np.percentile(results_icer_finite, 97.5):.2f}]")
with summary_cols[3]:
    # Probability that the intervention is cost-effective at $50k/QALY.
    wtp_threshold = 50000
    # A draw counts as cost-effective only when it gains QALYs and its
    # incremental cost is below WTP * incremental QALYs. This is the same
    # rule the CEAC uses. Comparing the raw ratio dC/dQ against the threshold
    # would wrongly accept dominated draws (dQ < 0, dC > 0 gives a negative
    # ratio) and would divide by zero when dQ == 0.
    is_cost_effective = (results_delta_qaly > 0) & (
        results_delta_cost < wtp_threshold * results_delta_qaly
    )
    prob_ce = float(np.mean(is_cost_effective))
    st.metric("Prob. Cost-Effective", f"{prob_ce * 100:.1f}%")
    st.caption(f"at ${wtp_threshold:,}/QALY")

# Determine the decision label from the probability band.
if prob_ce >= 0.95:
    likelihood = "HIGHLY LIKELY"
elif prob_ce >= 0.75:
    likelihood = "LIKELY"
elif prob_ce >= 0.50:
    likelihood = "POSSIBLY"
else:
    likelihood = "UNLIKELY"
psa_decision = f"{likelihood} COST-EFFECTIVE ({prob_ce*100:.1f}% probability at ${wtp_threshold:,}/QALY)"
# Store PSA CEA results in session state
st.session_state.cea_results = {
'analysis_type': 'Probabilistic',
'intervention_name': psa_intervention,
'time_horizon': psa_cycles,
'cost_diff': np.mean(results_delta_cost),
'qaly_diff': np.mean(results_delta_qaly),
'icer': np.mean(results_icer_finite),
'wtp_threshold': wtp_threshold,
'decision': psa_decision,
'psa_mean_cost': np.mean(results_delta_cost),
'psa_cost_ci_low': np.percentile(results_delta_cost, 2.5),
'psa_cost_ci_high': np.percentile(results_delta_cost, 97.5),
'psa_mean_qaly': np.mean(results_delta_qaly),
'psa_qaly_ci_low': np.percentile(results_delta_qaly, 2.5),
'psa_qaly_ci_high': np.percentile(results_delta_qaly, 97.5),
'psa_mean_icer': np.mean(results_icer_finite),
'prob_ce': prob_ce
}
# Cost-Effectiveness Plane
st.write("### Cost-Effectiveness Plane (Scatter Plot)")
fig, ax = plt.subplots(figsize=(10, 8))
# Plot scatter points
ax.scatter(results_delta_qaly, results_delta_cost, alpha=0.3, s=20, c='blue')
# Plot mean point
ax.scatter(np.mean(results_delta_qaly), np.mean(results_delta_cost),
color='red', s=200, marker='*', edgecolors='black', linewidths=2,
label='Mean', zorder=5)
# Add quadrant lines
ax.axhline(0, color='gray', linestyle='--', alpha=0.5, linewidth=1)
ax.axvline(0, color='gray', linestyle='--', alpha=0.5, linewidth=1)
# Add WTP threshold line
xlim = ax.get_xlim()
x_range = np.linspace(xlim[0], xlim[1], 100)
ax.plot(x_range, x_range * wtp_threshold, 'k--', alpha=0.5,
linewidth=2, label=f'WTP ${wtp_threshold:,}/QALY')
ax.set_xlabel("Incremental QALYs", fontsize=12, fontweight='bold')
ax.set_ylabel("Incremental Cost ($)", fontsize=12, fontweight='bold')
ax.set_title("Cost-Effectiveness Plane (Probabilistic Sensitivity Analysis)",
fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(alpha=0.3)
st.pyplot(fig)
plt.close(fig)
# CEAC Curve: share of simulations judged cost-effective at each
# willingness-to-pay (WTP) value between $0 and $150,000 per QALY.
st.write("### Cost-Effectiveness Acceptability Curve (CEAC)")
wtp_range = np.linspace(0, 150000, 100)
# A draw is cost-effective when it gains QALYs and its incremental cost is
# below WTP * incremental QALYs. For dQ > 0 this covers both the dominant
# case (dC < 0) and the "ICER below WTP" case. Draws that lose QALYs are
# never counted, even if they save money.
improves_health = results_delta_qaly > 0
prob_ce_array = [
    float(np.mean(improves_health & (results_delta_cost < wtp * results_delta_qaly)))
    for wtp in wtp_range
]
fig2, ax2 = plt.subplots(figsize=(10, 6))
ax2.plot(wtp_range, prob_ce_array, 'b-', linewidth=2)
# Reference lines: 50% probability and the $50,000/QALY threshold.
ax2.axhline(0.5, color='gray', linestyle='--', alpha=0.5)
ax2.axvline(50000, color='red', linestyle='--', alpha=0.5, label='$50,000/QALY')
ax2.set_xlabel("Willingness-to-Pay Threshold ($/QALY)", fontsize=12, fontweight='bold')
ax2.set_ylabel("Probability Cost-Effective", fontsize=12, fontweight='bold')
ax2.set_title("Cost-Effectiveness Acceptability Curve", fontsize=14, fontweight='bold')
ax2.set_ylim(0, 1)
ax2.grid(alpha=0.3)
ax2.legend()
st.pyplot(fig2)
# Close the figure so repeated Streamlit reruns do not accumulate open figures.
plt.close(fig2)
# Distribution histograms
st.write("### Distribution of Results")
hist_col1, hist_col2 = st.columns(2)
with hist_col1:
fig3, ax3 = plt.subplots(figsize=(8, 5))
ax3.hist(results_delta_cost, bins=50, alpha=0.7, color='blue', edgecolor='black')
ax3.axvline(np.mean(results_delta_cost), color='red', linestyle='--',
linewidth=2, label=f'Mean: ${np.mean(results_delta_cost):.2f}')
ax3.set_xlabel("Incremental Cost ($)")
ax3.set_ylabel("Frequency")
ax3.set_title("Distribution of Incremental Costs")
ax3.legend()
ax3.grid(alpha=0.3)
st.pyplot(fig3)
plt.close(fig3)
with hist_col2:
fig4, ax4 = plt.subplots(figsize=(8, 5))
ax4.hist(results_delta_qaly, bins=50, alpha=0.7, color='green', edgecolor='black')
ax4.axvline(np.mean(results_delta_qaly), color='red', linestyle='--',
linewidth=2, label=f'Mean: {np.mean(results_delta_qaly):.4f}')
ax4.set_xlabel("Incremental QALYs")
ax4.set_ylabel("Frequency")
ax4.set_title("Distribution of Incremental QALYs")
ax4.legend()
ax4.grid(alpha=0.3)
st.pyplot(fig4)
plt.close(fig4)
# ICER distribution
st.write("### ICER Distribution")
fig5, ax5 = plt.subplots(figsize=(10, 5))
ax5.hist(results_icer_finite, bins=50, alpha=0.7, color='purple', edgecolor='black')
ax5.axvline(np.mean(results_icer_finite), color='red', linestyle='--',
linewidth=2, label=f'Mean: ${np.mean(results_icer_finite):.2f}/QALY')
ax5.axvline(50000, color='orange', linestyle='--', linewidth=2,
label='$50,000/QALY threshold')
ax5.set_xlabel("ICER ($/QALY)")
ax5.set_ylabel("Frequency")
ax5.set_title("Distribution of ICER Values")
ax5.legend()
ax5.grid(alpha=0.3)
st.pyplot(fig5)
plt.close(fig5)
except Exception as e:
st.error(f"Error running probabilistic analysis: {str(e)}")
import traceback
st.code(traceback.format_exc())
else:
st.write("No modifiable risk factors available for intervention.")
# Tab 6: Personalized Chat
with tab6:
    st.subheader("💬 Personalized Health Recommendations")

    def _strip_latex(text):
        """Turn the LaTeX fragments the model sometimes emits into plain text.

        ``\\text{x}`` becomes ``x`` and ``\\frac{a}{b}`` becomes ``(a)/(b)``.
        Every other character, including unrelated braces, is left untouched.
        """
        import re  # local import: only this helper needs it

        # Unwrap \text{...} first so fractions that contain it become brace-free.
        text = re.sub(r"\\text\{([^{}]*)\}", r"\1", text)
        fraction = re.compile(r"\\frac\{([^{}]*)\}\{([^{}]*)\}")
        # Repeat until stable so nested fractions are resolved inside-out.
        previous = None
        while previous != text:
            previous = text
            text = fraction.sub(r"(\1)/(\2)", text)
        return text

    if not openai_api_key:
        st.warning("⚠️ Please enter your OpenAI API key in the top right corner to use Personalized Chat.")
    else:
        # Step 1: one-off summary that seeds the conversation.
        if not st.session_state.summary_generated:
            if st.button("✨ Generate Personalized Health Summary", type="primary"):
                with st.spinner("Analyzing your health profile..."):
                    try:
                        llm = get_llm()
                        if llm:
                            patient_info = get_patient_info_string()
                            cea_results = get_cea_results_string()
                            summary_prompt = f"""Generate a comprehensive health assessment summary for this patient:
{patient_info}
{cea_results}
Include:
1. Overall risk level assessment
2. Key modifiable risk factors
3. Top 3 priority recommendations based on both health profile AND cost-effectiveness analysis results
4. Expected health benefits and economic value of interventions
IMPORTANT:
- Use plain text formatting only (no LaTeX, no \\text{{}} or \\frac{{}}{{}} syntax)
- Write any formulas in plain text
- Use simple markdown formatting (**, -, numbers) for emphasis
- Avoid special characters that may not render correctly
Format the response with clear sections and bullet points."""
                            response = _strip_latex(llm.invoke(summary_prompt).content)
                            st.session_state.recommendation_messages.append({
                                "role": "assistant",
                                "content": response
                            })
                            st.session_state.summary_generated = True
                            # Rerun so the summary is rendered by the history loop below.
                            st.rerun()
                        else:
                            # Same message as the chat path, so the failure is not silent.
                            st.error("Failed to initialize AI.")
                    except Exception as e:
                        st.error(f"Error generating summary: {str(e)}")

        # Step 2: replay the stored conversation.
        for message in st.session_state.recommendation_messages:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

        # Step 3: follow-up questions, available once the summary exists.
        if st.session_state.summary_generated:
            if prompt := st.chat_input("Ask about your personalized recommendations..."):
                st.session_state.recommendation_messages.append({"role": "user", "content": prompt})
                with st.chat_message("user"):
                    st.markdown(prompt)
                with st.chat_message("assistant"):
                    with st.spinner("Thinking..."):
                        try:
                            llm = get_llm()
                            if llm:
                                patient_info = get_patient_info_string()
                                cea_results = get_cea_results_string()
                                # Only the 10 most recent messages are sent as context.
                                history_text = ""
                                for msg in st.session_state.recommendation_messages[-10:]:
                                    role = "Patient" if msg["role"] == "user" else "Health Coach"
                                    history_text += f"{role}: {msg['content']}\n\n"
                                full_prompt = f"""You are a personalized health coach and health economist specializing in hypertension management.
PATIENT PROFILE:
{patient_info}
{cea_results}
Provide evidence-based, actionable recommendations for:
- Weight management and DASH diet
- Exercise prescriptions
- Smoking cessation strategies
- Medication adherence
- Lifestyle modifications
- Stress management
- Cost-effectiveness interpretation and value-based decision making
Answer questions about:
- Patient health risks and progression
- Cost-effectiveness analysis results (ICER, QALYs, etc.)
- Which interventions provide the best value for money
- How to interpret deterministic vs probabilistic results
- Personalized recommendations based on both clinical and economic evidence
Be empathetic, practical, and motivating. Cite specific guidelines when relevant.
IMPORTANT:
- Use plain text formatting only (no LaTeX, no \\text{{}} or \\frac{{}}{{}} syntax)
- Write formulas in plain text
- Use simple markdown formatting (**, -, numbers) for emphasis
- Avoid special characters that may not render correctly
Conversation History:
{history_text}
Patient Question: {prompt}
Your Personalized Advice:"""
                                response = _strip_latex(llm.invoke(full_prompt).content)
                                st.markdown(response, unsafe_allow_html=False)
                                st.session_state.recommendation_messages.append({
                                    "role": "assistant",
                                    "content": response
                                })
                            else:
                                st.error("Failed to initialize AI.")
                        except Exception as e:
                            st.error(f"Error: {str(e)}")
# Footer: a horizontal rule followed by a centred, grey attribution line.
footer_html = """
<div style='text-align: center; color: gray; font-size: 0.8em;'>
Powered by LangChain & OpenAI | Hypertension CEA Tool v2.0
</div>
"""
st.markdown("---")
# Raw HTML is required here for the inline styling.
st.markdown(footer_html, unsafe_allow_html=True)