import streamlit as st
from PIL import Image
import base64
import time
import streamlit.components.v1 as components
import json
from typing import Dict, List, Optional
from pathlib import Path
from streamlit_autorefresh import st_autorefresh
import pandas as pd
import re
import streamlit_image_select as sis
from fpdf import FPDF
import io
import markdown
import os
from openai import OpenAI
from http import HTTPStatus
import dashscope
# =============================================================================
# 配置和常量
# =============================================================================
PAGE_CONFIG = {
"page_title": "Knee OA Demo",
"layout": "wide"
}
CHAT_CONFIG = {
"update_interval": 2,
"height": 640,
"max_width": "80%"
}
IMAGE_PATHS = {
"logo": "images/logo.png",
"framework": "images/framework.png",
"status_framework": "images/status_framework.png",
"predicting_framework": "images/predicting_framework.png",
"recommendation_framework": "images/Recommendation_framework.png"
}
CASES_FILE = "cases.json"
PARAMS_FILE = "predict_params.json"
# =============================================================================
# 工具函数
# =============================================================================
@st.cache_data
def get_base64_image(image_path: str) -> Optional[str]:
"""安全获取图片base64编码"""
try:
if Path(image_path).exists():
with open(image_path, "rb") as f:
data = f.read()
return base64.b64encode(data).decode()
except Exception as e:
st.error(f"无法加载图片 {image_path}: {e}")
return None
@st.cache_data
def load_initial_chat_history(file_path: str = "initial_chat.json") -> List[Dict]:
"""从 JSON 文件加载初始聊天历史"""
try:
if Path(file_path).exists():
with open(file_path, "r", encoding="utf-8") as f:
return json.load(f)
else:
st.warning(f"找不到初始对话文件:{file_path}")
except Exception as e:
st.error(f"加载初始聊天对话失败: {e}")
return []
@st.cache_data
def load_analysis_report(json_file="assess_result.json") -> Dict:
try:
with open(json_file, "r", encoding="utf-8") as f:
return json.load(f)
except Exception as e:
st.error(f"Unable to load the analysis report file: {e}")
return {}
@st.cache_data
def load_case_data() -> Dict:
"""加载病例数据"""
try:
if Path(CASES_FILE).exists():
with open(CASES_FILE, "r", encoding="utf-8") as f:
return json.load(f)
except Exception as e:
st.error(f"无法加载病例数据: {e}")
return {}
def load_plan(agent_type: str):
path = f"{agent_type}_plan.json"
with open(path, "r", encoding="utf-8") as f:
return json.load(f)
def clean_text_for_pdf(text: str) -> str:
replacements = {
"–": "-", # 长破折号
"—": "-", # 全角破折号
"“": "\"", # 中文引号
"”": "\"",
"’": "'",
"•": "-", # 项目符号
"→": "->",
"…": "...",
"©": "(c)",
}
for old, new in replacements.items():
text = text.replace(old, new)
return text.encode("latin-1", "ignore").decode("latin-1") # 丢弃不可编码字符
def strip_non_latin1(text: str) -> str:
"""去除无法被 Latin-1 编码的字符(如 emoji、中文)"""
return text.encode("latin-1", errors="ignore").decode("latin-1")
def generate_pdf(text: str) -> bytes:
pdf = FPDF()
pdf.add_page()
pdf.set_font("Arial", size=12) # 保留默认字体
# 🔧 逐行添加,过滤掉无法编码的字符
for line in text.split("\n"):
clean_line = strip_non_latin1(line)
pdf.multi_cell(0, 10, txt=clean_line)
return pdf.output(dest="S").encode("latin1")
def generate_report_text_from_prediction(params: dict) -> str:
lines = ["Prediction Report\n", "=" * 30 + "\n"]
for k, v in params.items():
lines.append(f"{k}: {json.dumps(v, ensure_ascii=False) if isinstance(v, (dict, list)) else v}")
return "\n".join(lines)
def safe_image_display(image_path: str, caption: str = "", **kwargs):
"""显示图片"""
try:
if Path(image_path).exists():
st.image(image_path, caption=caption, **kwargs)
else:
st.warning(f"⚠️ 图片未找到: {image_path}")
except Exception as e:
st.error(f"显示图片时出错: {e}")
def generate_report_text_from_prediction(params: dict) -> str:
lines = []
# 1. Symptom Trajectory Forecast
lines.append("📊 Symptom Trajectory Forecast (KOOS, 0–100)")
symptom_rows = [
("Right Knee Pain", "symptom_trajectory.right_knee.pain"),
("Right Knee Symptoms", "symptom_trajectory.right_knee.symptoms"),
("Sport/Recreation Function", "symptom_trajectory.right_knee.sport_recreation_function"),
("Quality of Life", "symptom_trajectory.right_knee.quality_of_life"),
("Left Knee Pain", "symptom_trajectory.left_knee.pain")
]
for label, base_key in symptom_rows:
v00 = params.get(f"{base_key}.v00", "N/A")
v01 = params.get(f"{base_key}.v01", "N/A")
v04 = params.get(f"{base_key}.v04", "N/A")
lines.append(f"- {label}: Current={v00}, Year 2={v01}, Year 4={v04}")
lines.append("")
# 2. Imaging Trajectory Forecast
lines.append("🦴 Imaging Trajectory (KL grade, 0–4)")
for side in ["right", "left"]:
v00 = params.get(f"imaging_trajectory.{side}_knee.pain.v00", "N/A")
v01 = params.get(f"imaging_trajectory.{side}_knee.pain.v01", "N/A")
v04 = params.get(f"imaging_trajectory.{side}_knee.pain.v04", "N/A")
lines.append(f"- {side.capitalize()} Knee: Current={v00}, Year 2={v01}, Year 4={v04}")
lines.append("")
# 3. SHAP Key Factors
lines.append("💡 Key Contributing Factors (SHAP)")
shap_data = params.get("key_factors.right_knee_symptoms_year2", [])
for item in shap_data:
feature = item.get("feature", "Unknown")
impact = item.get("impact", "N/A")
effect = item.get("effect", "")
lines.append(f"- {feature}: {impact} ({effect})")
return "\n".join(lines)
# =============================================================================
# 样式定义
# =============================================================================
def get_navigation_styles(logo_base64: str) -> str:
"""获取导航栏样式"""
return f"""
Knee Osteoarthritis Management Platform
膝关节骨关节炎人工智能平台
"""
def get_chat_styles() -> str:
"""返回优化后的聊天气泡样式"""
return """
"""
# =============================================================================
# 聊天功能
# =============================================================================
class ChatManager:
def __init__(self, initial_chat_file: str = "assess_chat.json"):
self.initial_history = load_initial_chat_history(initial_chat_file)
# def initialize_state(self):
# """初始化聊天状态"""
# if "chat_history" not in st.session_state:
# st.session_state.chat_history = self.initial_history.copy()
# st.session_state.chat_step = 1
# st.session_state.last_update_time = time.time()
def initialize_state(self):
if "chat_history" not in st.session_state:
st.session_state.chat_history = []
# 遍历 initial_history,保留所有用户消息,assistant 用 None 占位
for msg in self.initial_history:
if msg["role"] == "user":
st.session_state.chat_history.append(msg)
else:
st.session_state.chat_history.append({"role": "assistant", "content": None})
st.session_state.chat_step = 1 # 指向第一条 assistant 消息
st.session_state.last_update_time = time.time()
# def update_progress(self):
# """更新聊天进度(非阻塞)"""
# current_time = time.time()
# if (st.session_state.chat_step < len(st.session_state.chat_history) and
# current_time - st.session_state.last_update_time > CHAT_CONFIG["update_interval"]):
# st.session_state.chat_step += 1
# st.session_state.last_update_time = current_time
def update_progress(self):
now = time.time()
if now - st.session_state.last_update_time < CHAT_CONFIG["update_interval"]:
return
step = st.session_state.chat_step
initial = self.initial_history
history = st.session_state.chat_history
if step >= len(initial):
return # 全部处理完
next_msg = initial[step]
if next_msg["role"] == "user":
# 避免重复插入
if len(history) <= step or history[step]["content"] != next_msg["content"]:
# user 消息直接写入对应位置(如果存在替换,否则append)
if len(history) > step:
history[step] = next_msg
else:
history.append(next_msg)
st.session_state.chat_step += 1
st.session_state.last_update_time = now
# elif next_msg["role"] == "assistant":
# # 调用API前,确认上一条是用户消息,且内容不为空
# if step == 0 or history[step - 1]["role"] != "user" or not history[step - 1]["content"]:
# st.warning("⚠️ 无法生成 AI 回答:找不到上一条用户消息")
# return
# user_prompt = history[step - 1]["content"]
# app_id = "c968f91131ac432787f5ef81f51922ba"
# api_key = os.getenv("DASHSCOPE_API_KEY")
# reply = self.generate_response(user_prompt, app_id, api_key)
# # 写入 assistant 回复
# if len(history) > step:
# history[step]["content"] = reply
# else:
# history.append({"role": "assistant", "content": reply})
# st.session_state.chat_step += 1
# st.session_state.last_update_time = now
elif next_msg["role"] == "assistant": #multi
if step == 0 or history[step - 1]["role"] != "user" or not history[step - 1]["content"]:
st.warning("⚠️ 无法生成 AI 回答:找不到上一条用户消息")
return
# 取当前上下文(含用户与助手历史)
messages = history[:step] # 所有已发生的轮次
messages.append(history[step - 1]) # 确保当前用户输入在最后
app_id = st.secrets["DASHSCOPE_AGENT_ID"]
api_key = st.secrets["DASHSCOPE_API_KEY"]
reply = self.generate_response(messages, app_id, api_key)
if not reply:
st.warning("⚠️ AI 未返回内容,请稍后重试")
return
# 写入 assistant 回复
if len(history) > step:
history[step]["content"] = reply
else:
history.append({"role": "assistant", "content": reply})
st.session_state.chat_step += 1
st.session_state.last_update_time = now
def render_message(self, role: str, content: str) -> str:
"""渲染单条消息(AI左侧,用户右侧)"""
if role == "user":
return f"""
"""
else:
return f"""
"""
def render_chat_interface(self):
"""自动逐条渲染聊天界面"""
chat_html = "".join([
self.render_message(msg["role"], msg["content"])
for msg in st.session_state.chat_history[:st.session_state.chat_step]
])
height = CHAT_CONFIG["height"]
components.html(f"""
""", height=height)
def handle_user_input(self):
user_input = st.chat_input("Please enter your symptoms, medical history or problems...")
if user_input:
st.session_state.chat_history.append({"role": "user", "content": user_input})
api_key = st.secrets["DASHSCOPE_API_KEY"]
app_id = st.secrets["DASHSCOPE_AGENT_ID"]
# response = self.generate_response(user_input, app_id, api_key)
messages = [{"role": "user", "content": user_input}]
response = self.generate_response(messages, app_id, api_key)
st.session_state.chat_history.append({"role": "assistant", "content": response})
st.session_state.chat_step = len(st.session_state.chat_history)
st.rerun()
# def generate_response(self, user_input: str) -> str:
# """生成助手回复"""
# responses = {
# "Pain": "Please describe in detail the nature, frequency and triggering factors of the pain.",
# "Swelling": "When does swelling usually occur? Is there any accompanying fever?",
# "Stiffness": "How long does morning stiffness last? Was there any improvement after the activity?",
# "Cracking": "Is joint cracking accompanied by pain?"
# }
# for keyword, response in responses.items():
# if keyword in user_input:
# return response
# return "Thank you for your feedback. I will conduct an analysis based on this information. Please continue to describe your symptoms."
def generate_response(self, messages: list, app_id: str, api_key: str) -> str:
api_key = st.secrets["DASHSCOPE_API_KEY"]
app_id = st.secrets["DASHSCOPE_AGENT_ID"]
dashscope.api_key = api_key
dashscope.base_http_api_url = 'https://dashscope.aliyuncs.com/api/v1'
# 拼接历史对话为 prompt(只支持单轮,不建议传 messages 列表)
prompt = "\n".join([
f"{msg['role'].capitalize()}: {msg['content']}" for msg in messages if msg["content"]
])
try:
response = dashscope.Application.call(
app_id=app_id,
prompt=prompt
)
except Exception as e:
return f"⚠️ 智能体调用异常: {e}"
if response.status_code != HTTPStatus.OK:
return f"⚠️ 智能体调用失败: {response.message}"
return response.output.get("text", "⚠️ 没有返回内容")
# =============================================================================
# 页面渲染函数
# =============================================================================
def render_navigation():
"""渲染导航栏"""
logo_base64 = get_base64_image(IMAGE_PATHS["logo"])
if logo_base64:
st.markdown(get_navigation_styles(logo_base64), unsafe_allow_html=True)
else:
st.title("Knee Osteoarthritis Management Platform")
def inject_agent_styles():
"""注入各智能体颜色样式"""
st.markdown("""
""", unsafe_allow_html=True)
def render_home_page():
"""渲染首页"""
abstract_col, figure_col = st.columns([0.9, 1.1])
with abstract_col:
st.markdown('About
', unsafe_allow_html=True)
st.markdown("""
平台使用流程等
""")
with figure_col:
st.markdown('General Framework
', unsafe_allow_html=True)
safe_image_display(IMAGE_PATHS["framework"], "Framework Overview", use_container_width=True)
st.markdown("---")
st.markdown(
"⚠️This website is at an early stage of development and intended for research purposes only. For collaboration or to report bugs, please contact us at ……. Thank you! 本网页仅用于研究用途")
# 每秒刷新一次(根据需要调整频率),最多刷新 N 次
def generate_report_text_from_json(json_path: str = "structured_report_template.json") -> str:
with open(json_path, "r", encoding="utf-8") as f:
report_dict = json.load(f)
lines = []
for section, contents in report_dict.items():
lines.append(section)
for line in contents:
lines.append(line)
lines.append("") # 每个 section 之间空一行
return "\n".join(lines)
def render_centered_image_full(image_path, width=300):
import base64
with open(image_path, "rb") as img_file:
img_bytes = img_file.read()
encoded = base64.b64encode(img_bytes).decode("utf-8")
html = f'''
'''
st.markdown(html, unsafe_allow_html=True)
def spacer(height_px=24):
st.markdown(f"", unsafe_allow_html=True)
def render_assessment_page():
"""渲染评估页面"""
st.markdown("""
""", unsafe_allow_html=True)
chat_manager = ChatManager()
chat_manager.initialize_state()
chat_manager.update_progress()
# 自动刷新(如果还有消息未显示)
if st.session_state.chat_step < len(st.session_state.chat_history):
st_autorefresh(interval=1500, key="chat_autorefresh")
for key, default in {
"show_sidebar": False,
"selected_image_path": None,
"selected_image_label": None,
"typing_index": 0,
"chat_history": None,
"chat_step": 1,
"last_update_time": time.time(),
}.items():
if key not in st.session_state:
st.session_state[key] = default
if st.session_state.show_sidebar:
with st.sidebar:
st.markdown("### 📂 Select a sample knee image")
PREDEFINED_IMAGES = {
"Knee Image A": "images/knee_sample_1.png",
}
for idx, (label, path) in enumerate(PREDEFINED_IMAGES.items()):
col1, col2, col3 = st.columns([1, 2, 1])
with col2:
st.image(path, width=150, caption=label)
if st.button("✅ Select", key=f"select_{idx}"):
st.session_state.selected_image_path = path
st.session_state.selected_image_label = label
st.session_state.show_sidebar = False
st.rerun()
col1, col2 = st.columns([1.2, 0.8])
with col1:
chat_manager.render_chat_interface()
chat_manager.update_progress()
chat_manager.handle_user_input()
st.divider()
if st.button("📷 Upload Image"):
st.session_state.show_sidebar = True
if st.session_state.selected_image_path and st.session_state.selected_image_label:
st.success(f"✅ Selected: {st.session_state.selected_image_label}")
render_centered_image_full(st.session_state.selected_image_path, width=450)
spacer(24)
analysis_data = load_analysis_report()
report = analysis_data.get("default")
if report:
with st.expander("📝 View Structured Analysis Report"):
for knee, sections in report.items():
st.markdown(f"### 🦵 {knee}")
for section_title, items in sections.items():
st.markdown(f"**{section_title}**")
for item in items:
st.markdown(f"- {item}")
report_text = generate_report_text_from_json()
report_text = clean_text_for_pdf(report_text)
pdf_bytes = generate_pdf(report_text)
with open("custom_patient_report.json", "r", encoding="utf-8") as f:
custom_json_data = json.load(f)
json_bytes = io.BytesIO(json.dumps(custom_json_data, indent=2).encode('utf-8'))
# left_col, right_col = st.columns([4, 1])
# with left_col:
st.download_button(
label="📄 Download Structured Analysis Report as PDF",
data=pdf_bytes,
file_name="knee_report.pdf",
mime="application/pdf"
)
# with right_col:
st.download_button(
label="📄 Download the JSON file",
data=json_bytes,
file_name="knee_report.json",
mime="application/json"
)
else:
st.warning("⚠️ The structured analysis report failed to load")
with col2:
safe_image_display(IMAGE_PATHS["status_framework"], "Status framework", use_container_width=True)
def render_chat(role, message=None, table_df=None):
avatar = {
"User": "🧑",
"AI": "🩺"
}.get(role, "💬")
bg_color = {
"User": "#f1f8e9",
"AI": "#F2F2F2"
}.get(role, "#eeeeee")
content_html = ""
if message:
content_html += f"{message}
"
if table_df is not None:
table_html = table_df.to_html(index=False)
table_html = f"""
{table_html}
"""
content_html += f"{table_html}
"
st.markdown(f"""
""", unsafe_allow_html=True)
def render_centered_table(df):
html_table = df.to_html(index=False)
centered_html = f"""
"""
st.markdown(centered_html, unsafe_allow_html=True)
def render_prediction_report(params):
st.markdown("📝 Comprehensive Prediction Report
", unsafe_allow_html=True)
render_chat("AI", """
Excellent!
I’ve received your case report—thanks for submitting it!
With this complete dataset, I can now provide you with a comprehensive forecast of how your knee condition may evolve over time. Here’s what the model predicts:
""")
st.markdown("📊 Symptom Trajectory Forecast (KOOS, 0–100)
", unsafe_allow_html=True)
symptom_table = {
"Metric": ["Right Knee Pain", "Right Knee Symptoms", "Sport/Recreation Function", "Quality of Life",
"Left Knee Pain"],
"Current (V00)": [
params["symptom_trajectory.right_knee.pain.v00"],
params["symptom_trajectory.right_knee.symptoms.v00"],
params["symptom_trajectory.right_knee.sport_recreation_function.v00"],
params["symptom_trajectory.right_knee.quality_of_life.v00"],
params["symptom_trajectory.left_knee.pain.v00"]
],
"Year 2 (V01)": [
params["symptom_trajectory.right_knee.pain.v01"],
params["symptom_trajectory.right_knee.symptoms.v01"],
params["symptom_trajectory.right_knee.sport_recreation_function.v01"],
params["symptom_trajectory.right_knee.quality_of_life.v01"],
params["symptom_trajectory.left_knee.pain.v01"]
],
"Year 4 (V04)": [
params["symptom_trajectory.right_knee.pain.v04"],
params["symptom_trajectory.right_knee.symptoms.v04"],
params["symptom_trajectory.right_knee.sport_recreation_function.v04"],
params["symptom_trajectory.right_knee.quality_of_life.v04"],
params["symptom_trajectory.left_knee.pain.v04"]
]
}
render_chat("AI", "Here is the forecast of your knee-related symptoms over the coming years:",
pd.DataFrame(symptom_table))
# 影像预测(KL)
st.markdown("🦴 Imaging Trajectory (KL grade, 0–4)
", unsafe_allow_html=True)
imaging_table = {
"Knee": ["Right", "Left"],
"Current": [
params["imaging_trajectory.right_knee.pain.v00"],
params["imaging_trajectory.left_knee.pain.v00"]
],
"Year 2": [
params["imaging_trajectory.right_knee.pain.v01"],
params["imaging_trajectory.left_knee.pain.v01"]
],
"Year 4": [
params["imaging_trajectory.right_knee.pain.v04"],
params["imaging_trajectory.left_knee.pain.v04"]
]
}
render_chat("AI", "Here’s how your knee structure may change over time, based on imaging predictions:",
pd.DataFrame(imaging_table))
# SHAP 解释
st.markdown("💡 Key Contributing Factors (SHAP)
", unsafe_allow_html=True)
shap_data = params["key_factors.right_knee_symptoms_year2"]
shap_table = {
"Feature": [item["feature"] for item in shap_data],
"Impact on KOOS Symptoms": [f"{item['impact']} ({item['effect']})" for item in shap_data]
}
render_chat("AI", "These are the most impactful factors influencing your right knee symptoms at Year 2:",
pd.DataFrame(shap_table))
def load_default_params():
with open(PARAMS_FILE, "r") as f:
return json.load(f)
def render_prediction_page():
"""渲染预测页面"""
col1, col2 = st.columns([1.2, 0.8])
with col1:
params = load_default_params()
# 仅显示 key,如果 value 是基础类型,则展示值
param_display_list = []
display_to_key = {}
for k, v in params.items():
if isinstance(v, (int, float, str)):
label = f"{k} ({v})"
else:
label = f"{k} (complex)"
param_display_list.append(label)
display_to_key[label] = k
st.markdown("**Parameter Mode: `fixed parameter from Assess Agent`**")
with st.expander("Click to view parameters(Extracted from Assess Agent)"):
selected_display = st.radio(" ", param_display_list, index=0)
model = display_to_key[selected_display]
threshold = params[model]
st.markdown("**Selected value:**")
if isinstance(threshold, dict):
st.json(threshold)
elif isinstance(threshold, list):
for i, item in enumerate(threshold):
st.markdown(f"**Item {i + 1}:**")
st.json(item)
else:
st.write(threshold)
# 初始化 session state(首次运行时)
if "prediction_done" not in st.session_state:
st.session_state["prediction_done"] = False
# 点击预测按钮,修改状态
if st.button("Starting prediction", type="primary"):
with st.spinner("Analysing"):
time.sleep(3)
st.success("Prediction completed!")
st.session_state["prediction_done"] = True
# ✅ 只在 prediction_done 为 True 时渲染报告和导出按钮
if st.session_state["prediction_done"]:
render_prediction_report(params)
spacer(16)
export_col1, export_col2 = st.columns([1, 1])
with export_col1:
pdf_text = generate_report_text_from_prediction(params)
pdf_bytes = generate_pdf(pdf_text)
st.download_button(
label="📄 Download Prediction Report as PDF",
data=pdf_bytes,
file_name="prediction_report.pdf",
mime="application/pdf",
use_container_width=True
)
with export_col2:
with open(PARAMS_FILE, "rb") as f:
json_bytes = f.read()
st.download_button(
label="Download Prediction Report JSON",
data=json_bytes,
file_name="predict_params.json",
mime="application/json",
use_container_width=True
)
with col2:
safe_image_display(IMAGE_PATHS["predicting_framework"], "Framework for predicting progress risks",
use_container_width=True)
def render_agent_message_return_html(role: str, action_html: str, style_class: str) -> str:
return f"""
"""
def render_agent_message(role: str, action: str, style_class: str) -> str:
action_html = markdown.markdown(action, extensions=["extra", "nl2br"]) # 转换 Markdown 为 HTML
return f"""
"""
def extract_week_number(phase_name: str) -> int:
"""从 'Week 1–4'(含 en dash)中提取排序基准数字"""
# 替换 en dash(–)和 em dash(—)为 ASCII dash(-)
normalized = phase_name.replace("–", "-").replace("—", "-")
match = re.search(r"Week (\d+)", normalized)
return int(match.group(1)) if match else 0
def render_exercise_plan_return_html(plan: Dict) -> str:
sorted_phases = sorted(plan.items(), key=lambda x: extract_week_number(x[0]))
html_blocks = []
for i, (phase, content) in enumerate(sorted_phases, start=1):
goal = content.get("Goal", "")
prescriptions = content.get("Prescription", [])
markdown_text = f"Phase {i}: {phase}
"
markdown_text += f"GOAL: {goal}
"
for item in prescriptions:
category = item.get("Category", "Training")
description = item.get("Description", "")
markdown_text += f"{category} Training:
"
for part in description.split(", "):
markdown_text += f"- {part.strip()}
"
markdown_text += "
"
html = render_agent_message_return_html(
role="A. Exercise Prescriptionist Agent",
action_html=markdown_text,
style_class="exercise"
)
html_blocks.append(html)
return "\n".join(html_blocks)
def render_surgical_pharma_plan_return_html(plan_data: Dict) -> List[str]:
"""
返回 Surgical & Pharmacological Specialist Agent 的多个 HTML 块列表,
每个块可被单独传入 st.markdown(..., unsafe_allow_html=True) 渲染。
"""
html_blocks = []
# Step 1: 渲染 Guideline Summary
guideline_markdown = "#### Clinical Guideline Analysis\n\n### Matched Guidelines Summary\n\n"
title_map = {
"564": "Severe Functional Limitation",
"225": "Moderate Functional Limitation with Mechanical Symptoms",
"482": "Younger Patient with Single-Compartment Disease"
}
for item in plan_data.get("matched_guidelines", []):
guideline_text = item.get("guideline", "")
match = re.search(r"Scenario (\d+):", guideline_text)
gid = match.group(1) if match else "Unknown"
title = title_map.get(gid, "Clinical Scenario")
guideline_markdown += f"**Guideline {gid}: {title}**\n"
def extract_section(text, start_kw, end_kw=None):
try:
start = text.index(start_kw)
end = text.index(end_kw, start) if end_kw else None
return text[start + len(start_kw):end].strip()
except ValueError:
return ""
clinical = extract_section(guideline_text, 'The patient reports', 'Demonstrates') or extract_section(
guideline_text, 'Experiences', 'has limited') or ""
physical = extract_section(guideline_text, 'Demonstrates', 'Shows') or extract_section(guideline_text,
'has limited',
'shows') or ""
radio = extract_section(guideline_text, 'Shows', 'Total') or extract_section(guideline_text, 'exhibits',
'Total') or ""
guideline_markdown += f"- **Clinical Presentation:** {clinical}\n"
guideline_markdown += f"- **Physical Findings:** {physical}\n"
guideline_markdown += f"- **Radiographic Features:** {radio}\n"
guideline_markdown += f"**Recommendations:**\n"
recos = re.findall(
r"(Total knee arthroplasty|Unicompartmental knee arthroplasty.*?|Realignment Osteotomy.*?)\s*(Appropriate|May Be Appropriate|Rarely Appropriate)\s*(\d)",
guideline_text)
for rec in recos:
guideline_markdown += f"- {rec[0]}: {rec[1]} ({rec[2]}/9)\n"
guideline_markdown += "\n"
guideline_html = render_agent_message_return_html(
role="B. Surgical & Pharmacological Specialist Agent",
action_html=markdown.markdown(guideline_markdown, extensions=["extra", "nl2br"]),
style_class="surgical"
)
html_blocks.append(guideline_html)
# Step 2: 药物推荐表格
meds = plan_data.get("medication_plan", [])
med_table_md = "#### Pharmacological Management Plan\n\n"
med_table_md += "| Medication | Dosage | Administration Schedule | Notes |\n"
med_table_md += "|------------|--------|-------------------------|-------|\n"
for med in meds:
name = med.get("name", "")
dosage = med.get("dosage", "")
freq = med.get("frequency", "")
notes = ""
if "Ibuprofen" in name:
notes = "Monitor for GI effects; take with food"
elif "Acetaminophen" in name:
notes = "Not to exceed 3000 mg daily"
elif "Corticosteroids" in name:
notes = "Consider after failed oral analgesics"
med_table_md += f"| {name} | {dosage} | {freq} | {notes} |\n"
med_table_md += "\nNote: Medication regimen should be tailored based on patient comorbidities, concomitant medications, and individual response to therapy.\n"
med_table_html = markdown.markdown(med_table_md, extensions=["extra", "nl2br"])
wrapped_html = f"{med_table_html}
"
pharma_html = render_agent_message_return_html(
role="B. Surgical & Pharmacological Specialist Agent",
action_html=wrapped_html,
style_class="pharma"
)
html_blocks.append(pharma_html)
return html_blocks
def render_nutrition_psychology_plan_return_html(plan_data: Dict) -> List[str]:
"""返回 Nutritional & Psychological Specialist Agent 的多个 HTML 气泡块"""
html_blocks = []
# ----------- Nutrition 部分 -----------
nutrition = plan_data.get("nutrition", {})
n_goal = nutrition.get("goal", "")
n_duration = nutrition.get("duration", "")
n_content = nutrition.get("content", [])
nutrition_md = "#### Nutritional Intervention Plan\n"
nutrition_md += f"**Goal:** {n_goal}\n\n"
nutrition_md += f"**Delivery Method:** Personalized one-on-one counseling supplemented with mobile application reminders\n"
nutrition_md += f"**Program Structure:**\n"
nutrition_md += f"- **Initial Phase:** Weekly consultations (first 6 weeks)\n"
nutrition_md += f"- **Maintenance Phase:** Bi-weekly check-ins\n"
nutrition_md += f"- **Total Duration:** {n_duration} comprehensive program\n"
# 分类策略
strategies = {
"Anti-inflammatory": [],
"Macronutrient": [],
"Weight": []
}
for item in n_content:
if "Anti-inflammatory" in item or "Adequacy" in item:
strategies["Anti-inflammatory"].append(item)
elif "macronutrient" in item or "Balance" in item:
strategies["Macronutrient"].append(item)
elif "calorie" in item.lower() or "Calorie control" in item:
strategies["Weight"].append(item)
nutrition_md += "**Key Nutritional Strategies:**\n"
if strategies["Anti-inflammatory"]:
nutrition_md += "1. **Anti-inflammatory Focus**\n"
nutrition_md += " - Incorporate omega-3 rich foods (fatty fish, walnuts, flaxseeds)\n"
nutrition_md += " - Increase consumption of antioxidant-rich leafy greens\n"
nutrition_md += " - Integrate nuts and seeds for micronutrient support\n"
nutrition_md += " - Purpose: Reduce joint inflammation and support tissue repair\n"
if strategies["Macronutrient"]:
nutrition_md += "2. **Macronutrient Optimization**\n"
nutrition_md += " - Ensure adequate protein intake to support muscle maintenance\n"
nutrition_md += " - Balance complex carbohydrates for sustained energy\n"
nutrition_md += " - Include healthy fats to support joint lubrication\n"
nutrition_md += " - Purpose: Enhance musculoskeletal strength and joint function\n"
if strategies["Weight"]:
nutrition_md += "3. **Weight Management**\n"
nutrition_md += " - Implement portion awareness techniques\n"
nutrition_md += " - Monitor caloric balance through guided food journaling\n"
nutrition_md += " - Adjust intake based on activity levels and rehabilitation phases\n"
nutrition_md += " - Purpose: Reduce mechanical stress on knee joints\n"
nutrition_html = render_agent_message_return_html(
role="C. Nutritional & Psychological Specialist Agent",
action_html=markdown.markdown(nutrition_md, extensions=["extra", "nl2br"]),
style_class="nutrition"
)
html_blocks.append(nutrition_html)
# ----------- Psychology 部分 -----------
psych = plan_data.get("psychology", {})
p_goal = psych.get("goal", "")
p_duration = psych.get("duration", "")
p_content = psych.get("content", [])
psychology_md = "#### Psychological Support\n"
psychology_md += f"**Goal:** {p_goal}\n\n"
psychology_md += f"**Delivery Method:** Tele-health Cognitive Behavioral Therapy with structured daily practice components\n"
psychology_md += f"**Program Structure:**\n"
psychology_md += f"- **Intensive Phase:** Weekly sessions (first 8 weeks)\n"
psychology_md += f"- **Consolidation Phase:** Bi-weekly sessions\n"
psychology_md += f"- **Total Duration:** {p_duration} comprehensive program\n"
psychology_md += "**Evidence-Based Psychological Approaches:**\n"
for idx, item in enumerate(p_content, start=1):
if "Motivational" in item:
psychology_md += f"{idx}. **Motivational Interviewing**\n"
psychology_md += " - Explore personal values related to mobility and function\n"
psychology_md += " - Resolve ambivalence about rehabilitation commitment\n"
psychology_md += " - Develop intrinsic motivation for consistent exercise adherence\n"
psychology_md += " - Purpose: Strengthen commitment to rehabilitation protocols\n"
elif "CBT" in item:
psychology_md += f"{idx}. **Cognitive Restructuring**\n"
psychology_md += " - Identify and challenge maladaptive thoughts about pain and recovery\n"
psychology_md += " - Transform catastrophizing patterns into realistic perspectives\n"
psychology_md += " - Develop confidence in functional improvement\n"
psychology_md += " - Purpose: Reduce pain-related fear and enhance rehabilitation engagement\n"
elif "mindfulness" in item.lower():
psychology_md += f"{idx}. **Digital Mindfulness Integration**\n"
psychology_md += " - Implement scheduled mindfulness practice through mobile notifications\n"
psychology_md += " - Provide guided pain-specific meditation recordings\n"
psychology_md += " - Track stress levels in relation to symptom fluctuations\n"
psychology_md += " - Purpose: Enhance stress management and improve pain tolerance\n"
psychology_md += "*Note: Both nutritional and psychological interventions will be coordinated with physical rehabilitation to ensure comprehensive care integration.*"
psychology_html = render_agent_message_return_html(
role="C. Nutritional & Psychological Specialist Agent",
action_html=markdown.markdown(psychology_md, extensions=["extra", "nl2br"]),
style_class="psychology"
)
html_blocks.append(psychology_html)
return html_blocks
def render_clinical_decision_agent_return_html(plan_data: Dict) -> str:
primary_goal = plan_data.get("Goals", {}).get("Primary", "")
secondary_goal = plan_data.get("Goals", {}).get("Secondary", "")
plan = plan_data.get("InterventionPlan", {})
action_md = "#### Integrated Multimodal Intervention Plan\n\n"
# Medication
med_summary = plan.get("Medication", {}).get("Summary", "")
action_md += "**🩺 Medication Strategy**\n"
action_md += f"- {med_summary}\n\n"
# Nutrition
nutrition_desc = plan.get("NutritionPlan", {}).get("Description", "")
framework = plan.get("NutritionPlan", {}).get("Framework", "")
action_md += "**🥗 Nutrition Plan**\n"
action_md += f"- **Framework:** {framework}\n"
action_md += f"- {nutrition_desc}\n\n"
# Exercise
exercise = plan.get("ExercisePlan", {})
framework = exercise.get("Framework", "")
phases = exercise.get("Phases", {})
action_md += "**🏃 Exercise Plan**\n"
action_md += f"- **Framework:** {framework}\n"
for week_range, content in phases.items():
goal = content.get("Goal", "")
prescription = content.get("Prescription", "")
action_md += f" - **{week_range}:** {goal}\n"
action_md += f" - {prescription}\n"
action_md += "\n"
# Psychology
psych_summary = plan.get("PsychologicalSupport", {}).get("Summary", "")
action_md += "**🧠 Psychological Support**\n"
action_md += f"- {psych_summary}\n\n"
# Surgical
surgical_summary = plan.get("SurgicalOrInjectionConsiderations", {}).get("Summary", "")
action_md += "**🛠️ Surgical or Injection Considerations**\n"
action_md += f"- {surgical_summary}\n\n"
# Safety
safety_summary = plan.get("SafetyMonitoring", {}).get("Summary", "")
action_md += "**🔍 Safety Monitoring Plan**\n"
action_md += f"- {safety_summary}\n\n"
# Personalization
accessibility = plan_data.get("AccessibilityFeasibility", "")
rationale = plan_data.get("PersonalizationRationale", "")
evidence = plan_data.get("EvidenceCompliance", "")
action_md += "#### Personalized Treatment Context\n"
action_md += f"- **Accessibility & Feasibility:** {accessibility}\n"
action_md += f"- **Personalization Rationale:** {rationale}\n"
action_md += f"- **Evidence Compliance:** {evidence}\n"
html = render_agent_message(
role="🧩 Clinical Decision-Making Agent",
action=action_md,
style_class="decision"
)
return html
def render_progress_bar(step: int, total: int):
'''进度条函数'''
progress = step / total
st.markdown(f"""
""", unsafe_allow_html=True)
st.markdown(f"Progress: {int(progress * 100)}%", unsafe_allow_html=True)
def render_all_agents_auto():
total_agents = 4
progress_placeholder = st.empty()
# ✅ Agent A: Exercise
progress_placeholder.markdown(render_progress_bar_html(1, total_agents), unsafe_allow_html=True)
exercise_plan = load_plan("exercise")
with st.expander("A. Exercise Prescriptionist Agent", expanded=False):
html = render_exercise_plan_return_html(exercise_plan)
st.markdown(html, unsafe_allow_html=True)
# ✅ Agent B: Surgical & Pharma
progress_placeholder.markdown(render_progress_bar_html(2, total_agents), unsafe_allow_html=True)
surgical_plan = load_plan("surgical_pharma")
with st.expander("B. Surgical & Pharmacological Specialist Agent", expanded=False):
# html = render_surgical_pharma_plan_return_html(surgical_plan)
# st.markdown(html, unsafe_allow_html=True)
html_blocks = render_surgical_pharma_plan_return_html(surgical_plan)
for html in html_blocks:
st.markdown(html, unsafe_allow_html=True)
# ✅ Agent C: Nutrition & Psychology
progress_placeholder.markdown(render_progress_bar_html(3, total_agents), unsafe_allow_html=True)
nutrition_plan = load_plan("nutrition_psychology")
with st.expander("C. Nutritional & Psychological Specialist Agent", expanded=False):
# html = render_nutrition_psychology_plan_return_html(nutrition_plan)
# st.markdown(html, unsafe_allow_html=True)
html_blocks = render_nutrition_psychology_plan_return_html(nutrition_plan)
for html in html_blocks:
st.markdown(html, unsafe_allow_html=True)
with st.spinner("Clinical Decision-Making Agent reasoning..."):
time.sleep(3)
progress_placeholder.markdown(render_progress_bar_html(4, total_agents), unsafe_allow_html=True)
decision_plan = load_plan("clinical_integration")
with st.expander("D. Clinical Decision-Making Agent", expanded=False):
html = render_clinical_decision_agent_return_html(decision_plan)
st.markdown(html, unsafe_allow_html=True)
# 进度条 HTML 渲染拆出来方便复用
def render_progress_bar_html(step: int, total: int) -> str:
progress = step / total
return f"""
Progress: {int(progress * 100)}%
"""
def render_therapy_page():
"""渲染治疗推荐页面"""
inject_agent_styles()
col1, col2 = st.columns([1.5, 1])
with col2:
safe_image_display(IMAGE_PATHS["recommendation_framework"], "Framework for personalizing treatment",
use_container_width=True)
with col1:
st.subheader("🩺 Quick Therapy Demo")
# 加载病例数据
case_data = load_case_data()
if not case_data:
st.warning("Case data cannot be loaded.")
return
case_names = list(case_data.keys())
case_options = [""] + case_names
selected_case = st.selectbox("Select a sample case:", case_options)
st.markdown(get_chat_styles(), unsafe_allow_html=True)
if selected_case != "":
case = case_data[selected_case]
if "start_clicked" not in st.session_state:
st.session_state.start_clicked = False
if not st.session_state.get("start_clicked", False):
st.success(f"Selected:{selected_case}")
st.markdown("### 📑 Case Reports")
for report_name in case.get("reports", []):
st.markdown(f"📄 {report_name}")
# if st.button("▶️ Start Multi-Agent Reasoning"):
# st.session_state.start_clicked = True
# st.query_params.update({"start": "1"})
button_placeholder = st.empty() # 创建按钮容器
spacer_placeholder = st.empty() # 创建一个“空白占位”容器用于替代按钮
if "start_clicked" not in st.session_state:
st.session_state.start_clicked = False
if not st.session_state.start_clicked:
if button_placeholder.button("▶️ Start Multi-Agent Reasoning"):
st.session_state.start_clicked = True
st.query_params.update({"start": "1"})
# 清空按钮,并填入等高的空白 div 避免布局闪烁
button_placeholder.empty()
spacer_placeholder.markdown("", unsafe_allow_html=True)
else:
button_placeholder.empty()
spacer_placeholder.empty()
render_all_agents_auto()
else:
# 展示多智能体对话部分
render_all_agents_auto() # 👈 自动延迟依次渲染各个Agent对话 + 进度条