Spaces:
Sleeping
Sleeping
Upload personalized_ht4.py
Browse files- personalized_ht4.py +62 -79
personalized_ht4.py
CHANGED
|
@@ -98,40 +98,49 @@ with col1:
|
|
| 98 |
st.markdown("*Multistate ML Analysis Showcase Hypertension*")
|
| 99 |
|
| 100 |
with col2:
|
| 101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
"🔑 OpenAI API Key",
|
| 103 |
type="password",
|
| 104 |
placeholder="sk-...",
|
| 105 |
-
help="Enter your OpenAI API key to enable AI features"
|
|
|
|
| 106 |
)
|
| 107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
if openai_api_key:
|
| 109 |
st.success("✓ API Key set")
|
| 110 |
else:
|
| 111 |
-
st.warning("⚠️ Enter API key")
|
| 112 |
-
|
| 113 |
|
| 114 |
# ali Check if API key is provided
|
| 115 |
-
def get_llm():
|
| 116 |
"""Initialize LangChain LLM with OpenAI - with rate limiting"""
|
| 117 |
-
if not
|
|
|
|
|
|
|
|
|
|
| 118 |
return None
|
| 119 |
|
| 120 |
-
# ✅ API 調用限制(每 session 最多 100 次)
|
| 121 |
if st.session_state.api_call_count >= 100:
|
| 122 |
-
st.error("⚠️ API call limit reached (100 calls per session).
|
| 123 |
st.stop()
|
| 124 |
|
| 125 |
try:
|
| 126 |
llm = ChatOpenAI(
|
| 127 |
model="gpt-4o-mini",
|
| 128 |
temperature=0.7,
|
| 129 |
-
openai_api_key=
|
| 130 |
)
|
| 131 |
-
|
| 132 |
-
# ✅ 記錄 API 調用
|
| 133 |
st.session_state.api_call_count += 1
|
| 134 |
-
|
| 135 |
return llm
|
| 136 |
except Exception as e:
|
| 137 |
st.error(f"Error initializing OpenAI: {str(e)}")
|
|
@@ -139,16 +148,17 @@ def get_llm():
|
|
| 139 |
|
| 140 |
|
| 141 |
# Create vector store from patient data
|
| 142 |
-
def create_patient_vectorstore(patients_df: pd.DataFrame):
|
| 143 |
-
"""Create vector store from patient dataframe for RAG retrieval -
|
| 144 |
-
if not
|
|
|
|
|
|
|
|
|
|
| 145 |
return None
|
| 146 |
|
| 147 |
try:
|
| 148 |
-
|
| 149 |
user_id = st.session_state.user_id
|
| 150 |
-
persist_dir = f"./data/chroma/{user_id}"
|
| 151 |
-
os.makedirs(persist_dir, exist_ok=True)
|
| 152 |
|
| 153 |
documents = []
|
| 154 |
for idx, row in patients_df.iterrows():
|
|
@@ -163,61 +173,42 @@ Betel: {row.get('betel', 'No')}, Family History: {row['family_history']}"""
|
|
| 163 |
|
| 164 |
doc = Document(
|
| 165 |
page_content=patient_text,
|
| 166 |
-
metadata={
|
| 167 |
-
"patient_id": row['patient_id'],
|
| 168 |
-
"user_id": user_id # ✅ 加入 user_id 標記
|
| 169 |
-
}
|
| 170 |
)
|
| 171 |
documents.append(doc)
|
| 172 |
|
| 173 |
-
embeddings = OpenAIEmbeddings(openai_api_key=
|
| 174 |
|
| 175 |
-
# ✅
|
| 176 |
vectorstore = Chroma(
|
| 177 |
-
collection_name=f"
|
| 178 |
-
embedding_function=embeddings
|
| 179 |
-
persist_directory=
|
| 180 |
)
|
| 181 |
|
| 182 |
-
|
| 183 |
-
try:
|
| 184 |
-
vectorstore.delete_collection()
|
| 185 |
-
vectorstore = Chroma(
|
| 186 |
-
collection_name=f"user_{user_id}_patients",
|
| 187 |
-
embedding_function=embeddings,
|
| 188 |
-
persist_directory=persist_dir
|
| 189 |
-
)
|
| 190 |
-
except:
|
| 191 |
-
pass
|
| 192 |
-
|
| 193 |
-
# 加入新文件
|
| 194 |
-
vectorstore.add_documents(documents)
|
| 195 |
-
|
| 196 |
return vectorstore
|
| 197 |
|
| 198 |
except Exception as e:
|
| 199 |
st.error(f"Error creating vector store: {str(e)}")
|
| 200 |
return None
|
| 201 |
-
|
| 202 |
-
|
| 203 |
# Retrieve patient by ID
|
| 204 |
def retrieve_patient_by_id(patient_id: str):
|
| 205 |
-
"""Retrieve patient from
|
| 206 |
-
|
| 207 |
-
session_key = f"patients_df_{user_id}"
|
| 208 |
-
|
| 209 |
-
# ✅ 從用戶專屬的 DataFrame 檢索
|
| 210 |
-
if session_key not in st.session_state or st.session_state[session_key] is None:
|
| 211 |
return None
|
| 212 |
|
| 213 |
-
patients_df = st.session_state
|
| 214 |
patient_row = patients_df[patients_df['patient_id'] == patient_id]
|
| 215 |
|
| 216 |
if patient_row.empty:
|
| 217 |
return None
|
| 218 |
|
| 219 |
-
return patient_row.iloc[0].to_dict()
|
| 220 |
-
|
|
|
|
| 221 |
# Sidebar for patient information
|
| 222 |
st.sidebar.header("👤 Patient Information")
|
| 223 |
|
|
@@ -244,10 +235,8 @@ with st.sidebar.expander("📂 Upload Patient Database (Optional)", expanded=Fal
|
|
| 244 |
else:
|
| 245 |
df = pd.read_excel(uploaded_file)
|
| 246 |
|
| 247 |
-
# ✅ 修改:使用用戶專屬的 key
|
| 248 |
-
|
| 249 |
-
session_key = f"patients_df_{user_id}"
|
| 250 |
-
st.session_state[session_key] = df
|
| 251 |
|
| 252 |
# ✅ 記錄檔案 hash
|
| 253 |
file_hash = hashlib.md5(uploaded_file.getvalue()).hexdigest()
|
|
@@ -255,51 +244,45 @@ with st.sidebar.expander("📂 Upload Patient Database (Optional)", expanded=Fal
|
|
| 255 |
st.session_state.uploaded_files.append(file_hash)
|
| 256 |
|
| 257 |
st.success(f"✅ Loaded {len(df)} patients ({file_size_mb:.1f}MB)")
|
| 258 |
-
|
| 259 |
-
# Vector store 建立
|
| 260 |
if openai_api_key and st.button("🔄 Create Vector Store for Smart Search"):
|
| 261 |
with st.spinner("Creating isolated vector store..."):
|
| 262 |
-
vectorstore = create_patient_vectorstore(df)
|
| 263 |
if vectorstore:
|
| 264 |
st.session_state.vectorstore = vectorstore
|
| 265 |
-
st.success("✅ Vector store created!
|
| 266 |
|
| 267 |
except Exception as e:
|
| 268 |
-
st.error(f"Error loading file: {str(e)}")
|
|
|
|
| 269 |
|
| 270 |
# Patient ID retrieval
|
| 271 |
# ✅ 修正:使用用戶專屬的 DataFrame key
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
# 顯示已上傳的病患數量(可選)
|
| 276 |
-
if session_key in st.session_state and st.session_state[session_key] is not None:
|
| 277 |
-
df = st.session_state[session_key]
|
| 278 |
st.caption(f"📊 {len(df)} patients loaded")
|
| 279 |
|
| 280 |
-
if
|
| 281 |
st.markdown("---")
|
| 282 |
patient_id_input = st.text_input("🔍 Enter Patient ID", placeholder="P001")
|
| 283 |
|
| 284 |
-
if st.button("
|
| 285 |
if patient_id_input:
|
| 286 |
patient_data = retrieve_patient_by_id(patient_id_input)
|
| 287 |
if patient_data:
|
| 288 |
-
#
|
| 289 |
-
user_id = st.session_state.user_id
|
| 290 |
-
loaded_key = f"loaded_patient_{user_id}"
|
| 291 |
-
st.session_state[loaded_key] = patient_data
|
| 292 |
st.session_state.loaded_patient = patient_data
|
| 293 |
|
| 294 |
-
#
|
| 295 |
st.session_state.summary_generated = False
|
| 296 |
st.session_state.recommendation_messages = []
|
| 297 |
-
|
| 298 |
-
# ✅ 清除 CEA 結果(可選,因為新病人的分析應該重新做)
|
| 299 |
st.session_state.cea_results = None
|
| 300 |
|
| 301 |
st.success(f"✅ Loaded {patient_id_input}.")
|
| 302 |
st.rerun()
|
|
|
|
|
|
|
| 303 |
|
| 304 |
st.sidebar.markdown("---")
|
| 305 |
|
|
@@ -550,7 +533,7 @@ with tab1:
|
|
| 550 |
with st.chat_message("assistant"):
|
| 551 |
with st.spinner("Thinking..."):
|
| 552 |
try:
|
| 553 |
-
llm = get_llm()
|
| 554 |
if llm:
|
| 555 |
history_text = ""
|
| 556 |
for msg in st.session_state.assistant_messages[-10:]:
|
|
@@ -1592,7 +1575,7 @@ with tab6:
|
|
| 1592 |
if st.button("✨ Generate Personalized Health Summary", type="primary"):
|
| 1593 |
with st.spinner("Analyzing your health profile..."):
|
| 1594 |
try:
|
| 1595 |
-
llm = get_llm()
|
| 1596 |
if llm:
|
| 1597 |
patient_info = get_patient_info_string()
|
| 1598 |
cea_results = get_cea_results_string()
|
|
@@ -1643,7 +1626,7 @@ Format the response with clear sections and bullet points."""
|
|
| 1643 |
with st.chat_message("assistant"):
|
| 1644 |
with st.spinner("Thinking..."):
|
| 1645 |
try:
|
| 1646 |
-
llm = get_llm()
|
| 1647 |
if llm:
|
| 1648 |
patient_info = get_patient_info_string()
|
| 1649 |
cea_results = get_cea_results_string()
|
|
|
|
| 98 |
st.markdown("*Multistate ML Analysis Showcase Hypertension*")
|
| 99 |
|
| 100 |
with col2:
|
| 101 |
+
|
| 102 |
+
if 'openai_api_key' not in st.session_state:
|
| 103 |
+
st.session_state.openai_api_key = ""
|
| 104 |
+
|
| 105 |
+
openai_api_key_input = st.text_input(
|
| 106 |
"🔑 OpenAI API Key",
|
| 107 |
type="password",
|
| 108 |
placeholder="sk-...",
|
| 109 |
+
help="Enter your OpenAI API key to enable AI features",
|
| 110 |
+
value=st.session_state.openai_api_key
|
| 111 |
)
|
| 112 |
|
| 113 |
+
if openai_api_key_input != st.session_state.openai_api_key:
|
| 114 |
+
st.session_state.openai_api_key = openai_api_key_input
|
| 115 |
+
|
| 116 |
+
openai_api_key = st.session_state.openai_api_key
|
| 117 |
+
|
| 118 |
if openai_api_key:
|
| 119 |
st.success("✓ API Key set")
|
| 120 |
else:
|
| 121 |
+
st.warning("⚠️ Enter API key")
|
| 122 |
+
|
| 123 |
|
| 124 |
# ali Check if API key is provided
|
| 125 |
+
def get_llm(api_key=None):
|
| 126 |
"""Initialize LangChain LLM with OpenAI - with rate limiting"""
|
| 127 |
+
if not api_key:
|
| 128 |
+
api_key = st.session_state.get('openai_api_key', None)
|
| 129 |
+
|
| 130 |
+
if not api_key:
|
| 131 |
return None
|
| 132 |
|
|
|
|
| 133 |
if st.session_state.api_call_count >= 100:
|
| 134 |
+
st.error("⚠️ API call limit reached (100 calls per session).")
|
| 135 |
st.stop()
|
| 136 |
|
| 137 |
try:
|
| 138 |
llm = ChatOpenAI(
|
| 139 |
model="gpt-4o-mini",
|
| 140 |
temperature=0.7,
|
| 141 |
+
openai_api_key=api_key
|
| 142 |
)
|
|
|
|
|
|
|
| 143 |
st.session_state.api_call_count += 1
|
|
|
|
| 144 |
return llm
|
| 145 |
except Exception as e:
|
| 146 |
st.error(f"Error initializing OpenAI: {str(e)}")
|
|
|
|
| 148 |
|
| 149 |
|
| 150 |
# Create vector store from patient data
|
| 151 |
+
def create_patient_vectorstore(patients_df: pd.DataFrame, api_key: str = None):
|
| 152 |
+
"""Create vector store from patient dataframe for RAG retrieval (memory-only)"""
|
| 153 |
+
if not api_key:
|
| 154 |
+
api_key = st.session_state.get('openai_api_key', None)
|
| 155 |
+
|
| 156 |
+
if not api_key:
|
| 157 |
return None
|
| 158 |
|
| 159 |
try:
|
| 160 |
+
import time
|
| 161 |
user_id = st.session_state.user_id
|
|
|
|
|
|
|
| 162 |
|
| 163 |
documents = []
|
| 164 |
for idx, row in patients_df.iterrows():
|
|
|
|
| 173 |
|
| 174 |
doc = Document(
|
| 175 |
page_content=patient_text,
|
| 176 |
+
metadata={"patient_id": row['patient_id']}
|
|
|
|
|
|
|
|
|
|
| 177 |
)
|
| 178 |
documents.append(doc)
|
| 179 |
|
| 180 |
+
embeddings = OpenAIEmbeddings(openai_api_key=api_key)
|
| 181 |
|
| 182 |
+
# ✅ 純記憶體模式,避免文件鎖問題
|
| 183 |
vectorstore = Chroma(
|
| 184 |
+
collection_name=f"patients_{user_id}_{int(time.time())}",
|
| 185 |
+
embedding_function=embeddings
|
| 186 |
+
# 不設置 persist_directory = 純記憶體
|
| 187 |
)
|
| 188 |
|
| 189 |
+
vectorstore.add_documents(documents)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
return vectorstore
|
| 191 |
|
| 192 |
except Exception as e:
|
| 193 |
st.error(f"Error creating vector store: {str(e)}")
|
| 194 |
return None
|
| 195 |
+
|
| 196 |
+
|
| 197 |
# Retrieve patient by ID
|
| 198 |
def retrieve_patient_by_id(patient_id: str):
|
| 199 |
+
"""Retrieve patient from dataframe by ID"""
|
| 200 |
+
if st.session_state.patients_df is None:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
return None
|
| 202 |
|
| 203 |
+
patients_df = st.session_state.patients_df
|
| 204 |
patient_row = patients_df[patients_df['patient_id'] == patient_id]
|
| 205 |
|
| 206 |
if patient_row.empty:
|
| 207 |
return None
|
| 208 |
|
| 209 |
+
return patient_row.iloc[0].to_dict()
|
| 210 |
+
|
| 211 |
+
|
| 212 |
# Sidebar for patient information
|
| 213 |
st.sidebar.header("👤 Patient Information")
|
| 214 |
|
|
|
|
| 235 |
else:
|
| 236 |
df = pd.read_excel(uploaded_file)
|
| 237 |
|
| 238 |
+
# ✅ 修改:使用用戶專屬的 key
|
| 239 |
+
st.session_state.patients_df = df
|
|
|
|
|
|
|
| 240 |
|
| 241 |
# ✅ 記錄檔案 hash
|
| 242 |
file_hash = hashlib.md5(uploaded_file.getvalue()).hexdigest()
|
|
|
|
| 244 |
st.session_state.uploaded_files.append(file_hash)
|
| 245 |
|
| 246 |
st.success(f"✅ Loaded {len(df)} patients ({file_size_mb:.1f}MB)")
|
| 247 |
+
|
| 248 |
+
# Vector store 建立
|
| 249 |
if openai_api_key and st.button("🔄 Create Vector Store for Smart Search"):
|
| 250 |
with st.spinner("Creating isolated vector store..."):
|
| 251 |
+
vectorstore = create_patient_vectorstore(df, st.session_state.openai_api_key)
|
| 252 |
if vectorstore:
|
| 253 |
st.session_state.vectorstore = vectorstore
|
| 254 |
+
st.success("✅ Vector store created!")
|
| 255 |
|
| 256 |
except Exception as e:
|
| 257 |
+
st.error(f"Error loading file: {str(e)}")
|
| 258 |
+
|
| 259 |
|
| 260 |
# Patient ID retrieval
|
| 261 |
# ✅ 修正:使用用戶專屬的 DataFrame key
|
| 262 |
+
if st.session_state.patients_df is not None:
|
| 263 |
+
df = st.session_state.patients_df
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
st.caption(f"📊 {len(df)} patients loaded")
|
| 265 |
|
| 266 |
+
if st.session_state.patients_df is not None:
|
| 267 |
st.markdown("---")
|
| 268 |
patient_id_input = st.text_input("🔍 Enter Patient ID", placeholder="P001")
|
| 269 |
|
| 270 |
+
if st.button("📥 Load Patient Data"):
|
| 271 |
if patient_id_input:
|
| 272 |
patient_data = retrieve_patient_by_id(patient_id_input)
|
| 273 |
if patient_data:
|
| 274 |
+
# ✅ 只用簡單命名
|
|
|
|
|
|
|
|
|
|
| 275 |
st.session_state.loaded_patient = patient_data
|
| 276 |
|
| 277 |
+
# 清除 Tab 6 的狀態
|
| 278 |
st.session_state.summary_generated = False
|
| 279 |
st.session_state.recommendation_messages = []
|
|
|
|
|
|
|
| 280 |
st.session_state.cea_results = None
|
| 281 |
|
| 282 |
st.success(f"✅ Loaded {patient_id_input}.")
|
| 283 |
st.rerun()
|
| 284 |
+
else:
|
| 285 |
+
st.error(f"❌ Patient ID '{patient_id_input}' not found.")
|
| 286 |
|
| 287 |
st.sidebar.markdown("---")
|
| 288 |
|
|
|
|
| 533 |
with st.chat_message("assistant"):
|
| 534 |
with st.spinner("Thinking..."):
|
| 535 |
try:
|
| 536 |
+
llm = get_llm(st.session_state.openai_api_key)
|
| 537 |
if llm:
|
| 538 |
history_text = ""
|
| 539 |
for msg in st.session_state.assistant_messages[-10:]:
|
|
|
|
| 1575 |
if st.button("✨ Generate Personalized Health Summary", type="primary"):
|
| 1576 |
with st.spinner("Analyzing your health profile..."):
|
| 1577 |
try:
|
| 1578 |
+
llm = get_llm(st.session_state.openai_api_key)
|
| 1579 |
if llm:
|
| 1580 |
patient_info = get_patient_info_string()
|
| 1581 |
cea_results = get_cea_results_string()
|
|
|
|
| 1626 |
with st.chat_message("assistant"):
|
| 1627 |
with st.spinner("Thinking..."):
|
| 1628 |
try:
|
| 1629 |
+
llm = get_llm(st.session_state.openai_api_key)
|
| 1630 |
if llm:
|
| 1631 |
patient_info = get_patient_info_string()
|
| 1632 |
cea_results = get_cea_results_string()
|