import streamlit as st from PIL import Image import base64 import time import streamlit.components.v1 as components import json from typing import Dict, List, Optional from pathlib import Path from streamlit_autorefresh import st_autorefresh import pandas as pd import re import streamlit_image_select as sis from fpdf import FPDF import io import markdown from utils.qwen_agent import call_qwen_agent import os import math import html # ============================================================================= # 配置和常量 # ============================================================================= PAGE_CONFIG = { "page_title": "Knee OA Demo", "layout": "wide" } CHAT_CONFIG = { "update_interval": 2, "height": 400, "max_width": "80%" } IMAGE_PATHS = { "logo": "images/logo.png", "framework": "images/framework.png", "status_framework": "images/status_framework.png", "predicting_framework": "images/predicting_framework.png", "recommendation_framework": "images/Recommendation_framework.png" } CASES_FILE = "cases.json" PARAMS_FILE = "predict_params.json" PREDICT_FILE = "predict_params_ori.json" # ============================================================================= # 工具函数 # ============================================================================= @st.cache_data def get_base64_image(image_path: str) -> Optional[str]: """安全获取图片base64编码""" try: if Path(image_path).exists(): with open(image_path, "rb") as f: data = f.read() return base64.b64encode(data).decode() except Exception as e: st.error(f"无法加载图片 {image_path}: {e}") return None @st.cache_data def load_initial_chat_history(file_path: str = "initial_chat.json") -> List[Dict]: """从 JSON 文件加载初始聊天历史""" try: if Path(file_path).exists(): with open(file_path, "r", encoding="utf-8") as f: return json.load(f) else: st.warning(f"找不到初始对话文件:{file_path}") except Exception as e: st.error(f"加载初始聊天对话失败: {e}") return [] @st.cache_data def load_analysis_report(json_file="assess_result.json") -> Dict: try: with open(json_file, "r", encoding="utf-8") as f: return json.load(f) except Exception as e: st.error(f"Unable to load the analysis report file: {e}") return {} @st.cache_data def load_case_data() -> Dict: """加载病例数据""" try: if Path(CASES_FILE).exists(): with open(CASES_FILE, "r", encoding="utf-8") as f: return json.load(f) except Exception as e: st.error(f"无法加载病例数据: {e}") return {} def load_plan(agent_type: str): path = f"{agent_type}_plan.json" with open(path, "r", encoding="utf-8") as f: return json.load(f) def clean_text_for_pdf(text: str) -> str: replacements = { "–": "-", # 长破折号 "—": "-", # 全角破折号 "“": "\"", # 中文引号 "”": "\"", "’": "'", "•": "-", # 项目符号 "→": "->", "…": "...", "©": "(c)", } for old, new in replacements.items(): text = text.replace(old, new) return text.encode("latin-1", "ignore").decode("latin-1") # 丢弃不可编码字符 def strip_non_latin1(text: str) -> str: """去除无法被 Latin-1 编码的字符(如 emoji、中文)""" return text.encode("latin-1", errors="ignore").decode("latin-1") def generate_pdf(text: str) -> bytes: pdf = FPDF() pdf.add_page() pdf.set_font("Arial", size=12) # 保留默认字体 # 🔧 逐行添加,过滤掉无法编码的字符 for line in text.split("\n"): clean_line = strip_non_latin1(line) pdf.multi_cell(0, 10, txt=clean_line) return pdf.output(dest="S").encode("latin1") def safe_image_display(image_path: str, caption: str = "", **kwargs): """显示图片""" try: if Path(image_path).exists(): st.image(image_path, caption=caption, **kwargs) else: st.warning(f"⚠️ 图片未找到: {image_path}") except Exception as e: st.error(f"显示图片时出错: {e}") def generate_report_text_from_prediction(params: dict) -> str: lines = [] # 1. Symptom Trajectory Forecast lines.append("📊 Symptom Trajectory Forecast (KOOS, 0–100)") symptom_rows = [ ("Right Knee Pain", "symptom_trajectory.right_knee.pain"), ("Right Knee Symptoms", "symptom_trajectory.right_knee.symptoms"), ("Left Knee Pain", "symptom_trajectory.left_knee.pain"), ("Left Knee Symptoms", "symptom_trajectory.left_knee.symptoms"), ("Sport/Recreation Function", "symptom_trajectory.right_knee.sport_recreation_function"), ("Quality of Life", "symptom_trajectory.right_knee.quality_of_life") ] for label, base_key in symptom_rows: v00 = params.get(f"{base_key}.v00", "N/A") v01 = params.get(f"{base_key}.v01", "N/A") v04 = params.get(f"{base_key}.v04", "N/A") lines.append(f"- {label}: Current={v00}, Year 2={v01}, Year 4={v04}") lines.append("") # 2. Imaging Trajectory Forecast lines.append("🦴 Imaging Trajectory (KL grade, 0–4)") for side in ["right", "left"]: v00 = params.get(f"imaging_trajectory.{side}_knee.pain.v00", "N/A") v01 = params.get(f"imaging_trajectory.{side}_knee.pain.v01", "N/A") v04 = params.get(f"imaging_trajectory.{side}_knee.pain.v04", "N/A") lines.append(f"- {side.capitalize()} Knee: Current={v00}, Year 2={v01}, Year 4={v04}") lines.append("") # 3. SHAP Key Factors lines.append("💡 Key Contributing Factors (SHAP)") shap_data = params.get("key_factors.right_knee_symptoms_year2", []) for item in shap_data: feature = item.get("feature", "Unknown") impact = item.get("impact", "N/A") effect = item.get("effect", "") lines.append(f"- {feature}: {impact} ({effect})") return "\n".join(lines) # ============================================================================= # 样式定义 # ============================================================================= def get_navigation_styles(logo_base64: str) -> str: """获取导航栏样式""" return f""" """ def get_chat_styles() -> str: """返回优化后的聊天气泡样式""" return """ """ # ============================================================================= # 聊天功能 # ============================================================================= class ChatManager: def __init__(self, initial_chat_file: str = "assess_chat.json"): self.initial_history = load_initial_chat_history(initial_chat_file) def initialize_state(self): """初始化聊天状态""" if "chat_history" not in st.session_state: st.session_state.chat_history = self.initial_history.copy() st.session_state.chat_step = 1 st.session_state.last_update_time = time.time() def update_progress(self): """更新聊天进度(非阻塞)""" current_time = time.time() if (st.session_state.chat_step < len(st.session_state.chat_history) and current_time - st.session_state.last_update_time > CHAT_CONFIG["update_interval"]): st.session_state.chat_step += 1 st.session_state.last_update_time = current_time # def update_progress(self): # current_time = time.time() # step = st.session_state.chat_step # history = st.session_state.chat_history # # 如果未显示完全部对话,并且刷新间隔已达 # if step < len(self.initial_history) and current_time - st.session_state.last_update_time > CHAT_CONFIG["update_interval"]: # next_msg = self.initial_history[step] # # 👤 如果下一条是用户内容,直接添加 # if next_msg["role"] == "user": # history.append(next_msg) # # 🤖 如果下一条是 AI 回答:动态生成(用上一条 user 消息作为 prompt) # elif next_msg["role"] == "assistant": # # ⛔ 安全校验:若 history 为空或上一条不是 user,跳过 # if len(history) == 0 or history[-1]["role"] != "user": # st.warning("⚠️ 无法生成 AI 回答:找不到上一条用户消息") # return # user_prompt = history[-1]["content"] # app_id = "c968f91131ac432787f5ef81f51922ba" # api_key = os.getenv("DASHSCOPE_API_KEY") # ai_reply = self.generate_response(user_prompt, app_id, api_key) # history.append({"role": "assistant", "content": ai_reply}) # # ✅ 每推进一条,step +1,更新时间 # st.session_state.chat_step += 1 # st.session_state.last_update_time = current_time def render_message(self, role: str, content: str) -> str: """渲染单条消息(AI左侧,用户右侧)""" if role == "user": return f"""
🧍 {content}
""" else: return f"""
👨‍⚕️ {content}
""" def render_chat_interface(self): for msg in st.session_state.chat_history[:st.session_state.chat_step]: print("原始 content 内容:", repr(msg["content"])) # # def process_content(content): # # # 将换行符转换为 HTML 换行标签 # # return content.replace("\n", "
") # def process_content(content): # print("替换前:", content[:50]) # 打印前50字符 # content = content.replace("\n", "
") # content = content.replace(" ", "   ") # print("替换后:", content[:50]) # 打印替换后的前50字符 # return content # chat_html = "".join([ # self.render_message(msg["role"], process_content(msg["content"])) # 应用换行处理 # for msg in st.session_state.chat_history[:st.session_state.chat_step] # ]) chat_html = "".join([ self.render_message( msg["role"], # 关键修改:先处理字面意义的 \\n(\和n组成的字符),再转
msg["content"] .replace("\\n", "\n") # 第一步:将字面的 \n 转为真正的换行符 .replace("\n", "
") # 第二步:将真正的换行符转为 HTML 换行 .replace(" ", "   ") # 保留缩进 ) for msg in st.session_state.chat_history[:st.session_state.chat_step] ]) height = CHAT_CONFIG["height"] components.html(f"""
{chat_html}
""", height=height) def handle_user_input(self): user_input = st.chat_input("Please enter your symptoms, medical history or problems...") if user_input: st.session_state.chat_history.append({"role": "user", "content": user_input}) app_id = "c968f91131ac432787f5ef81f51922ba" api_key = os.getenv("DASHSCOPE_API_KEY") response = self.generate_response(user_input, app_id, api_key) st.session_state.chat_history.append({"role": "assistant", "content": response}) st.session_state.chat_step = len(st.session_state.chat_history) st.rerun() # def generate_response(self, user_input: str) -> str: # """生成助手回复""" # responses = { # "Pain": "Please describe in detail the nature, frequency and triggering factors of the pain.", # "Swelling": "When does swelling usually occur? Is there any accompanying fever?", # "Stiffness": "How long does morning stiffness last? Was there any improvement after the activity?", # "Cracking": "Is joint cracking accompanied by pain?" # } # for keyword, response in responses.items(): # if keyword in user_input: # return response # return "Thank you for your feedback. I will conduct an analysis based on this information. Please continue to describe your symptoms." def generate_response(self, user_input: str, app_id: str, api_key: str) -> str: if not api_key: return "❌ 请在 Hugging Face 的 Secrets 中配置 DASHSCOPE_API_KEY。" try: st.write(f"🚀 正在调用 Qwen,输入:{user_input}") response = call_qwen_agent(user_input, app_id, api_key) st.write(f"✅ Qwen 返回前200字:{response[:200]}") return response except Exception as e: st.write(f"❌ Qwen 调用失败:{e}") return "调用 Qwen API 出错,请稍后再试。" # ============================================================================= # 页面渲染函数 # ============================================================================= def render_navigation(): """渲染导航栏""" logo_base64 = get_base64_image(IMAGE_PATHS["logo"]) if logo_base64: st.markdown(get_navigation_styles(logo_base64), unsafe_allow_html=True) else: st.title("Knee Osteoarthritis Management Platform") def inject_agent_styles(): """注入各智能体颜色样式""" st.markdown(""" """, unsafe_allow_html=True) def render_home_page(): """渲染首页""" abstract_col, figure_col = st.columns([0.9, 1.1]) # with abstract_col: # st.markdown('

About

', unsafe_allow_html=True) # st.markdown(""" # 平台使用流程等 # """) with abstract_col: st.markdown('

About

', unsafe_allow_html=True) st.markdown(""" KOM: Knee Osteoarthritis Chronic Disease Management System From the Sports Medicine Center, West China Hospital, Sichuan University KOM is an intelligent, multi-agent (Multi-Agent) AI system that supports the full KOA care pathway—assessment → risk prediction → individualized therapy—to enable precise, standardized, and scalable chronic disease management for knee osteoarthritis. Quick Start: How to Use the Web App In the top-right corner of the page, you’ll see three buttons, each mapping to a core module. Click from left to right to experience the end-to-end AI-assisted diagnosis and prescription flow. Interaction tips In Assessment, start a guided dialogue to capture medical history. You can also upload bilateral AP knee X-rays via the button below and complete structured data entry with on-screen prompts. Go to Risk to automatically pull prior inputs and generate 2-year / 4-year predictions for symptoms (KOOS) and radiographic outcomes, with patient-specific risk factor explanations (via SHAP). In Therapy, launch a multidisciplinary, multi-agent (MDT) discussion to produce an evidence-based, individualized, and actionable management plan (covering exercise, surgical/pharmacologic, nutrition, and psychological sub-prescriptions). Each module can also be used independently with manual data entry—handy for different clinical settings. What KOM Is KOM (Knee Osteoarthritis Manager) is the first end-to-end multi-agent AI system purpose-built for KOA, developed by the Sports Medicine Center, West China Hospital, with a cross-disciplinary team. It integrates LLMs, ResNet-based imaging, classical machine learning, and MDT-style multi-agent collaboration to cover: Disease Assessment: structured dialogue + automated X-ray analysis to generate a standardized case report. Progression Prediction: 2-/4-year forecasts for KOOS subscales and KL grades, plus individualized etiology/risk explanations. Individualized Therapy: multi-agent simulation of MDT to output evidence-based, executable plans. Modules & Capabilities 1) Assessment Agent Structured dialogue intake (LLM with optimized prompts): auto-completes missing fields, explains medical terms, and guides KOOS collection. Intelligent X-ray analysis: a deep-learning pipeline trained on the OAI dataset for knee localization, KOA grading, medial/lateral joint space narrowing, osteophytes, and subchondral sclerosis. Output: one-click case evaluation report in clinical style. 2) Progression Prediction Agent (Risk Agent) Functional outcomes: regression for KOOS subscales at 2 and 4 years. Radiographic outcomes: classification for KL grades at 2 and 4 years (ensemble of algorithms). Explainability: SHAP shows each patient’s risk contributions (e.g., osteophytes, pain scores, muscle strength) to guide interventions. 3) Treatment Multi-Agent Cluster (Therapy Agent) MDT via multi-agent collaboration: Exercise/Rehab, Orthopedics (surgery/pharmacology), Psycho-Nutrition, and a Clinical Integration agent. Evidence bases: structured entries curated from guidelines and peer-reviewed literature (Exercise 975; Surgery 1549; Rehab 934; Psychology 210; Nutrition 349).Output: individualized plans aligned with FITT-VP (exercise) and ABCMV (nutrition), emphasizing safety and actionability. Workflow at a Glance Data Intake: upload X-rays + structured interview Auto Analysis: radiographic grading & key signs → evaluation report Risk Prediction: 2-/4-year functional & imaging outcomes + personalized risk explanation MDT Therapy: multi-agent discussion → evidence-based individualized plan Review & Export: clinicians can revise at any step and export standardized documents Code & Live Demo GitHub (Open Source): https://github.com/jacobliuweizhi/KOM Live Demo (Hugging Face Spaces): https://huggingface.co/spaces/Miemie123/Streamlit?page=Tailored+Therapy+Recommendation&start=1 License: GNU AGPL v3.0. RAG references and example code are included in the repo. Ongoing Research & Productization Large-scale RCT: Evaluating KOM-assisted care vs. routine clinical workflows, focusing on real-world effectiveness and safety. Sports-Med LLM: In parallel development to enhance cross-task generalization and on-device, real-time assessment—especially for non-radiographic scenarios. Team & Contact (Corresponding Authors) Prof. Yong Nie (Department of Orthopedic Surgery, West China Hospital) | nieyong1983@wchscu.cn Prof. Kang Li (Sichuan University / Shanghai AI Lab) | likang@wchscu.cn Prof. Jian Li (Sports Medicine Center, West China Hospital) | lijian_sportsmed@163.com """) with figure_col: st.markdown('

General Framework

', unsafe_allow_html=True) safe_image_display(IMAGE_PATHS["framework"], "Framework Overview", use_container_width=True) st.markdown("---") st.markdown("⚠️This website is at an early stage of development and intended for research purposes only. Thank you! 本网页仅用于研究用途") # 每秒刷新一次(根据需要调整频率),最多刷新 N 次 def generate_report_text_from_json(json_path: str = "structured_report_template.json") -> str: with open(json_path, "r", encoding="utf-8") as f: report_dict = json.load(f) lines = [] for section, contents in report_dict.items(): lines.append(section) for line in contents: lines.append(line) lines.append("") # 每个 section 之间空一行 return "\n".join(lines) def render_centered_image_full(image_path, width=300): import base64 with open(image_path, "rb") as img_file: img_bytes = img_file.read() encoded = base64.b64encode(img_bytes).decode("utf-8") html = f'''
''' st.markdown(html, unsafe_allow_html=True) def spacer(height_px=24): st.markdown(f"
", unsafe_allow_html=True) def render_assessment_page(): """渲染评估页面""" st.markdown(""" """, unsafe_allow_html=True) chat_manager = ChatManager() chat_manager.initialize_state() chat_manager.update_progress() if st.session_state.chat_step < len(st.session_state.chat_history): st_autorefresh(interval=1500, key="chat_autorefresh") for key, default in { "show_sidebar": False, "selected_image_path": None, "selected_image_label": None, "typing_index": 0, "chat_history": None, "chat_step": 1, "last_update_time": time.time(), }.items(): if key not in st.session_state: st.session_state[key] = default if st.session_state.show_sidebar: with st.sidebar: st.markdown("### 📂 Select a sample image") PREDEFINED_IMAGES = { "Knee Image A": "images/knee_sample_1.png", } for idx, (label, path) in enumerate(PREDEFINED_IMAGES.items()): col1, col2, col3 = st.columns([1, 2, 1]) with col2: st.image(path, width=150, caption=label) if st.button("✅ Select", key=f"select_{idx}"): st.session_state.selected_image_path = path st.session_state.selected_image_label = label st.session_state.show_sidebar = False st.rerun() st.markdown('

Demo chat interface (display only).

', unsafe_allow_html=True) col1, col2 = st.columns([1.2, 0.8]) with col1: chat_manager.render_chat_interface() chat_manager.update_progress() chat_manager.handle_user_input() st.divider() if st.button("📷 Upload Image", key="upload_image_btn", use_container_width=True): st.session_state.show_sidebar = True if st.session_state.selected_image_path and st.session_state.selected_image_label: st.success(f"✅ Selected: {st.session_state.selected_image_label}") render_centered_image_full(st.session_state.selected_image_path, width=450) spacer(24) analysis_data = load_analysis_report() report = analysis_data.get("default") if report: with st.expander("📝 View Structured Analysis Report"): for knee, sections in report.items(): st.markdown(f"### 🦵 {knee}") for section_title, items in sections.items(): st.markdown(f"**{section_title}**") for item in items: st.markdown(f"- {item}") report_text = generate_report_text_from_json() report_text = clean_text_for_pdf(report_text) pdf_bytes = generate_pdf(report_text) with open("custom_patient_report_ori.json", "r", encoding="utf-8") as f: custom_json_data = json.load(f) json_bytes = io.BytesIO(json.dumps(custom_json_data, indent=2).encode('utf-8')) # left_col, right_col = st.columns([4, 1]) # with left_col: st.download_button( label="📄 Download Structured Analysis Report as PDF", data=pdf_bytes, file_name="knee_report.pdf", mime="application/pdf" ) # with right_col: st.download_button( label="📄 Download the JSON file", data=json_bytes, file_name="knee_report.json", mime="application/json" ) else: st.warning("⚠️ The structured analysis report failed to load") with col2: safe_image_display(IMAGE_PATHS["status_framework"], "Status framework", use_container_width=True) def render_chat(role, message=None, table_df=None): avatar = { "User": "🧑", "AI": "🩺" }.get(role, "💬") bg_color = { "User": "#f1f8e9", "AI": "#F2F2F2" }.get(role, "#eeeeee") content_html = "" if message: content_html += f"
{message}
" if table_df is not None: table_html = table_df.to_html(index=False) table_html = f""" {table_html} """ content_html += f"
{table_html}
" st.markdown(f"""
{avatar}
{content_html}
""", unsafe_allow_html=True) def render_centered_table(df): html_table = df.to_html(index=False) centered_html = f"""
{html_table}
""" st.markdown(centered_html, unsafe_allow_html=True) def render_prediction_report(params): st.markdown("

📝 Comprehensive Prediction Report

", unsafe_allow_html=True) render_chat("AI", """ Excellent! I’ve received your case report—thanks for submitting it! With this complete dataset, I can now provide you with a comprehensive forecast of how your knee condition may evolve over time. Here’s what the model predicts: """) st.markdown("
📊 Symptom Trajectory Forecast (KOOS, 0–100)
", unsafe_allow_html=True) symptom_table = { "Metric": ["Right Knee Pain", "Right Knee Symptoms", "Left Knee Pain", "Left Knee Symptoms","Sport/Recreation Function", "Quality of Life"], "Current (V00)": [ params["KOOSPain_R"], params["KOOSSym_R"], params["LKPain_V00"], params["LKSym_V00"], params["KOOSSport"], params["KQOL_V00"] ], "Year 2 (V01)": [ 97, 93, 89, 73, 58, 31 ], "Year 4 (V04)": [ 97, 91, 84, 66, 75, 50 ] } render_chat("AI", "Here is the forecast of your knee-related symptoms over the coming years:", pd.DataFrame(symptom_table)) # 影像预测(KL) st.markdown("
🦴 Imaging Trajectory (KL grade, 0–4)
", unsafe_allow_html=True) imaging_table = { "Knee": ["Right", "Left"], "Current": [ params["RKImg_V00"], params["LKImg_V00"] ], "Year 2": [ "Severe", "Mild" ], "Year 4": [ "Severe", "Mild" ] } render_chat("AI", "Here’s how your knee structure may change over time, based on imaging predictions:", pd.DataFrame(imaging_table)) # SHAP 解释 st.markdown("
💡 Key Contributing Factors (SHAP)
", unsafe_allow_html=True) shap_data = params["key_factors.right_knee_symptoms_year2"] shap_table = { "Feature": [item["feature"] for item in shap_data], "Impact on KOOS Symptoms": [f"{item['impact']} ({item['effect']})" for item in shap_data] } render_chat("AI", "These are the most impactful factors influencing your right knee symptoms at Year 2:", pd.DataFrame(shap_table)) # def load_default_params(): # with open(PARAMS_FILE, "r") as f: # return json.load(f) def load_default_params(file_path: str) -> dict: try: with open(file_path, "r") as f: return json.load(f) except FileNotFoundError: raise FileNotFoundError(f"参数文件不存在: {file_path}") except json.JSONDecodeError: raise ValueError(f"参数文件格式错误(非有效的JSON): {file_path}") except Exception as e: raise Exception(f"加载参数文件时出错: {str(e)}") def multi_column_radio(label, options, cols=7, index=0): key = f"multi_col_radio_{label}" if key not in st.session_state: st.session_state[key] = options[index] if options else None items_per_col = math.ceil(len(options) / cols) if options else 0 columns = st.columns(cols) for col_idx in range(cols): start_idx = col_idx * items_per_col end_idx = start_idx + items_per_col column_options = options[start_idx:end_idx] with columns[col_idx]: for option in column_options: # 为每个选项创建单选按钮 is_selected = (st.session_state[key] == option) # 使用唯一key,但不直接修改其他选项的状态 if st.checkbox(option, value=is_selected, key=f"{key}_{col_idx}_{option}"): if not is_selected: st.session_state[key] = option # 触发重新渲染以更新所有选项状态 st.rerun() return st.session_state[key] def render_prediction_page(): """渲染预测页面""" col1, col2 = st.columns([1.2, 0.8]) with col1: params = load_default_params(PARAMS_FILE) predict = load_default_params(PREDICT_FILE) param_display_list = [] display_to_key = {} exclude_key = "key_factors.right_knee_symptoms_year2" for k, v in params.items(): # if isinstance(v, (int, float, str)): if k == exclude_key: continue label = f"{k}" # else: # label = f"{k} (complex)" param_display_list.append(label) display_to_key[label] = k st.markdown("**Parameter Mode: `fixed parameter from Assessment Agent`**") with st.expander("Click to view parameters"): st.markdown("**The patient parameters are listed below, Click the box to view the values.**") selected_display = multi_column_radio(" ", param_display_list, cols=7) model = display_to_key[selected_display] threshold = params[model] st.markdown("**Selected value:**") if isinstance(threshold, dict): st.json(threshold) elif isinstance(threshold, list): for i, item in enumerate(threshold): st.markdown(f"**Item {i+1}:**") st.json(item) else: st.write(threshold) with st.expander("View abbreviation explanations"): col1_exp, col2_exp = st.columns(2) with col1_exp: st.markdown("**X-ray parameters (Knee):**") x_ray_params = { "XRKL_L": "Kellgren–Lawrence grade, left knee", "XRKL_R": "Kellgren–Lawrence grade, right knee", "XRJSL_L": "Joint space narrowing, lateral, left knee", "XRJSM_L": "Joint space narrowing, medial, left knee", "XROSFL_L": "Osteophytes, femur lateral, left knee", "XROSFM_L": "Osteophytes, femur medial, left knee", "XROSTL_L": "Osteophytes, tibia lateral, left knee", "XROSTM_L": "Osteophytes, tibia medial, left knee", "XRJSL_R": "Joint space narrowing, lateral, right knee", "XRJSM_R": "Joint space narrowing, medial, right knee", "XROSFL_R": "Osteophytes, femur lateral, right knee", "XROSFM_R": "Osteophytes, femur medial, right knee", } for abbr, full_name in x_ray_params.items(): st.markdown(f"- **{abbr}**: {full_name}") st.markdown("**Demographics / Basics:**") demo_params = { "AGE": "Age at baseline", "BMI": "Body Mass Index", "WEIGHT": "Body weight (kg)", } for abbr, full_name in demo_params.items(): st.markdown(f"- **{abbr}**: {full_name}") with col2_exp: st.markdown("**X-ray parameters (cont.):**") x_ray_params_cont = { "XROSTL_R": "Osteophytes, tibia lateral, right knee", "XROSTM_R": "Osteophytes, tibia medial, right knee", "XRSCFL_R": "Subchondral cyst, femur lateral, right knee", "RKImg_V00": "Radiographic grade, right knee, baseline", "LKImg_V00": "Radiographic grade, left knee, baseline", } for abbr, full_name in x_ray_params_cont.items(): st.markdown(f"- **{abbr}**: {full_name}") st.markdown("**Biomechanics (Force):**") bio_params = { "RFmaxF": "Right foot maximum forward force", "REmaxF": "Right foot maximum eversion force", "LFmaxF": "Left foot maximum forward force", "LEmaxF": "Left foot maximum eversion force", "RFmaxF_BMI": "Right foot max forward force normalized by BMI", "REmaxF_BMI": "Right foot max eversion force normalized by BMI", "LFmaxF_BMI": "Left foot max forward force normalized by BMI", "LEmaxF_BMI": "Left foot max eversion force normalized by BMI", } for abbr, full_name in bio_params.items(): st.markdown(f"- **{abbr}**: {full_name}") st.markdown("**KOOS Questionnaire:**") koos_params = { "KOOSPain_R": "KOOS pain score, right knee, baseline", "KOOSSym_R": "KOOS symptoms score, right knee, baseline", "KOOSPain_L": "KOOS pain score, left knee, baseline", "KOOSSym_L": "KOOS symptoms score, left knee, baseline", "KOOSSport": "KOOS sport/recreation score, baseline", "KOOSQOL": "KOOS quality of life score, baseline" } for abbr, full_name in koos_params.items(): st.markdown(f"- **{abbr}**: {full_name}") if "prediction_done" not in st.session_state: st.session_state["prediction_done"] = False if st.button("Starting prediction", type="primary"): with st.spinner("Analysing"): time.sleep(3) st.success("Prediction completed!") st.session_state["prediction_done"] = True if st.session_state["prediction_done"]: render_prediction_report(params) spacer(16) export_col1, export_col2 = st.columns([1, 1]) with export_col1: pdf_text = generate_report_text_from_prediction(predict) pdf_bytes = generate_pdf(pdf_text) st.download_button( label="📄 Download Prediction Report as PDF", data=pdf_bytes, file_name="prediction_report.pdf", mime="application/pdf", use_container_width=True ) with export_col2: with open(PARAMS_FILE, "rb") as f: json_bytes = f.read() st.download_button( label="Download Prediction Report JSON", data=json_bytes, file_name="predict_params.json", mime="application/json", use_container_width=True ) with col2: safe_image_display(IMAGE_PATHS["predicting_framework"], "Framework for predicting progress risks", use_container_width=True) def render_agent_message_return_html(role: str, action_html: str, style_class: str) -> str: return f"""
{role}
{action_html}
""" def render_agent_message(role: str, action: str, style_class: str) -> str: action_html = markdown.markdown(action, extensions=["extra", "nl2br"]) # 转换 Markdown 为 HTML return f"""
{role}
{action_html}
""" def extract_week_number(phase_name: str) -> int: """从 'Week 1–4'(含 en dash)中提取排序基准数字""" # 替换 en dash(–)和 em dash(—)为 ASCII dash(-) normalized = phase_name.replace("–", "-").replace("—", "-") match = re.search(r"Week (\d+)", normalized) return int(match.group(1)) if match else 0 def render_exercise_plan_return_html(plan: Dict) -> str: sorted_phases = sorted(plan.items(), key=lambda x: extract_week_number(x[0])) html_blocks = [] for i, (phase, content) in enumerate(sorted_phases, start=1): goal = content.get("Goal", "") prescriptions = content.get("Prescription", []) markdown_text = f"

Phase {i}: {phase}

" markdown_text += f"GOAL: {goal}

" for item in prescriptions: category = item.get("Category", "Training") description = item.get("Description", "") markdown_text += f"{category} Training:
" for part in description.split(", "): markdown_text += f"- {part.strip()}
" markdown_text += "
" html = render_agent_message_return_html( role="A. Exercise Prescriptionist Agent", action_html=markdown_text, style_class="exercise" ) html_blocks.append(html) return "\n".join(html_blocks) def render_surgical_pharma_plan_return_html(plan_data: Dict) -> List[str]: """ 返回 Surgical & Pharmacological Specialist Agent 的多个 HTML 块列表, 每个块单独传入 st.markdown(..., unsafe_allow_html=True) 渲染。 """ html_blocks = [] # Step 1: 渲染 Guideline Summary guideline_markdown = "#### Clinical Guideline Analysis\n\n### Matched Guidelines Summary\n\n" title_map = { "564": "Severe Functional Limitation", "225": "Moderate Functional Limitation with Mechanical Symptoms", "482": "Younger Patient with Single-Compartment Disease" } for item in plan_data.get("matched_guidelines", []): guideline_text = item.get("guideline", "") match = re.search(r"Scenario (\d+):", guideline_text) gid = match.group(1) if match else "Unknown" title = title_map.get(gid, "Clinical Scenario") guideline_markdown += f"**Guideline {gid}: {title}**\n" def extract_section(text, start_kw, end_kw=None): try: start = text.index(start_kw) end = text.index(end_kw, start) if end_kw else None return text[start + len(start_kw):end].strip() except ValueError: return "" clinical = extract_section(guideline_text, 'The patient reports', 'Demonstrates') or extract_section(guideline_text, 'Experiences', 'has limited') or "" physical = extract_section(guideline_text, 'Demonstrates', 'Shows') or extract_section(guideline_text, 'has limited', 'shows') or "" radio = extract_section(guideline_text, 'Shows', 'Total') or extract_section(guideline_text, 'exhibits', 'Total') or "" guideline_markdown += f"- **Clinical Presentation:** {clinical}\n" guideline_markdown += f"- **Physical Findings:** {physical}\n" guideline_markdown += f"- **Radiographic Features:** {radio}\n" guideline_markdown += f"**Recommendations:**\n" recos = re.findall(r"(Total knee arthroplasty|Unicompartmental knee arthroplasty.*?|Realignment Osteotomy.*?)\s*(Appropriate|May Be Appropriate|Rarely Appropriate)\s*(\d)", guideline_text) for rec in recos: guideline_markdown += f"- {rec[0]}: {rec[1]} ({rec[2]}/9)\n" guideline_markdown += "\n" guideline_html = render_agent_message_return_html( role="B. Surgical & Pharmacological Specialist Agent", action_html=markdown.markdown(guideline_markdown,extensions=["extra", "nl2br"]), style_class="surgical" ) html_blocks.append(guideline_html) # Step 2: 药物推荐表格 meds = plan_data.get("medication_plan", []) med_table_md = "#### Pharmacological Management Plan\n\n" med_table_md += "| Medication | Dosage | Administration Schedule | Notes |\n" med_table_md += "|------------|--------|-------------------------|-------|\n" for med in meds: name = med.get("name", "") dosage = med.get("dosage", "") freq = med.get("frequency", "") notes = "" if "Ibuprofen" in name: notes = "Monitor for GI effects; take with food" elif "Acetaminophen" in name: notes = "Not to exceed 3000 mg daily" elif "Corticosteroids" in name: notes = "Consider after failed oral analgesics" med_table_md += f"| {name} | {dosage} | {freq} | {notes} |\n" med_table_md += "\nNote: Medication regimen should be tailored based on patient comorbidities, concomitant medications, and individual response to therapy.\n" med_table_html = markdown.markdown(med_table_md, extensions=["extra", "nl2br"]) wrapped_html = f"
{med_table_html}
" pharma_html = render_agent_message_return_html( role="B. Surgical & Pharmacological Specialist Agent", action_html=wrapped_html, style_class="pharma" ) html_blocks.append(pharma_html) return html_blocks def render_nutrition_psychology_plan_return_html(plan_data: Dict) -> List[str]: """返回 Nutritional & Psychological Specialist Agent 的多个 HTML 气泡块""" html_blocks = [] # ----------- Nutrition 部分 ----------- nutrition = plan_data.get("nutrition", {}) n_goal = nutrition.get("goal", "") n_duration = nutrition.get("duration", "") n_content = nutrition.get("content", []) nutrition_md = "#### Nutritional Intervention Plan\n" nutrition_md += f"**Goal:** {n_goal}\n\n" nutrition_md += f"**Delivery Method:** Personalized one-on-one counseling supplemented with mobile application reminders\n" nutrition_md += f"**Program Structure:**\n" nutrition_md += f"- **Initial Phase:** Weekly consultations (first 6 weeks)\n" nutrition_md += f"- **Maintenance Phase:** Bi-weekly check-ins\n" nutrition_md += f"- **Total Duration:** {n_duration} comprehensive program\n" # 分类策略 strategies = { "Anti-inflammatory": [], "Macronutrient": [], "Weight": [] } for item in n_content: if "Anti-inflammatory" in item or "Adequacy" in item: strategies["Anti-inflammatory"].append(item) elif "macronutrient" in item or "Balance" in item: strategies["Macronutrient"].append(item) elif "calorie" in item.lower() or "Calorie control" in item: strategies["Weight"].append(item) nutrition_md += "**Key Nutritional Strategies:**\n" if strategies["Anti-inflammatory"]: nutrition_md += "1. **Anti-inflammatory Focus**\n" nutrition_md += " - Incorporate omega-3 rich foods (fatty fish, walnuts, flaxseeds)\n" nutrition_md += " - Increase consumption of antioxidant-rich leafy greens\n" nutrition_md += " - Integrate nuts and seeds for micronutrient support\n" nutrition_md += " - Purpose: Reduce joint inflammation and support tissue repair\n" if strategies["Macronutrient"]: nutrition_md += "2. **Macronutrient Optimization**\n" nutrition_md += " - Ensure adequate protein intake to support muscle maintenance\n" nutrition_md += " - Balance complex carbohydrates for sustained energy\n" nutrition_md += " - Include healthy fats to support joint lubrication\n" nutrition_md += " - Purpose: Enhance musculoskeletal strength and joint function\n" if strategies["Weight"]: nutrition_md += "3. **Weight Management**\n" nutrition_md += " - Implement portion awareness techniques\n" nutrition_md += " - Monitor caloric balance through guided food journaling\n" nutrition_md += " - Adjust intake based on activity levels and rehabilitation phases\n" nutrition_md += " - Purpose: Reduce mechanical stress on knee joints\n" nutrition_html = render_agent_message_return_html( role="C. Nutritional & Psychological Specialist Agent", action_html=markdown.markdown(nutrition_md,extensions=["extra", "nl2br"]), style_class="nutrition" ) html_blocks.append(nutrition_html) # ----------- Psychology 部分 ----------- psych = plan_data.get("psychology", {}) p_goal = psych.get("goal", "") p_duration = psych.get("duration", "") p_content = psych.get("content", []) psychology_md = "#### Psychological Support\n" psychology_md += f"**Goal:** {p_goal}\n\n" psychology_md += f"**Delivery Method:** Tele-health Cognitive Behavioral Therapy with structured daily practice components\n" psychology_md += f"**Program Structure:**\n" psychology_md += f"- **Intensive Phase:** Weekly sessions (first 8 weeks)\n" psychology_md += f"- **Consolidation Phase:** Bi-weekly sessions\n" psychology_md += f"- **Total Duration:** {p_duration} comprehensive program\n" psychology_md += "**Evidence-Based Psychological Approaches:**\n" for idx, item in enumerate(p_content, start=1): if "Motivational" in item: psychology_md += f"{idx}. **Motivational Interviewing**\n" psychology_md += " - Explore personal values related to mobility and function\n" psychology_md += " - Resolve ambivalence about rehabilitation commitment\n" psychology_md += " - Develop intrinsic motivation for consistent exercise adherence\n" psychology_md += " - Purpose: Strengthen commitment to rehabilitation protocols\n" elif "CBT" in item: psychology_md += f"{idx}. **Cognitive Restructuring**\n" psychology_md += " - Identify and challenge maladaptive thoughts about pain and recovery\n" psychology_md += " - Transform catastrophizing patterns into realistic perspectives\n" psychology_md += " - Develop confidence in functional improvement\n" psychology_md += " - Purpose: Reduce pain-related fear and enhance rehabilitation engagement\n" elif "mindfulness" in item.lower(): psychology_md += f"{idx}. **Digital Mindfulness Integration**\n" psychology_md += " - Implement scheduled mindfulness practice through mobile notifications\n" psychology_md += " - Provide guided pain-specific meditation recordings\n" psychology_md += " - Track stress levels in relation to symptom fluctuations\n" psychology_md += " - Purpose: Enhance stress management and improve pain tolerance\n" psychology_md += "*Note: Both nutritional and psychological interventions will be coordinated with physical rehabilitation to ensure comprehensive care integration.*" psychology_html = render_agent_message_return_html( role="C. Nutritional & Psychological Specialist Agent", action_html=markdown.markdown(psychology_md, extensions=["extra", "nl2br"]), style_class="psychology" ) html_blocks.append(psychology_html) return html_blocks def render_clinical_decision_agent_return_html(plan_data: Dict) -> str: primary_goal = plan_data.get("Goals", {}).get("Primary", "") secondary_goal = plan_data.get("Goals", {}).get("Secondary", "") plan = plan_data.get("InterventionPlan", {}) action_md = "#### Integrated Multimodal Intervention Plan\n\n" # Medication med_summary = plan.get("Medication", {}).get("Summary", "") action_md += "**🩺 Medication Strategy**\n" action_md += f"- {med_summary}\n\n" # Nutrition nutrition_desc = plan.get("NutritionPlan", {}).get("Description", "") framework = plan.get("NutritionPlan", {}).get("Framework", "") action_md += "**🥗 Nutrition Plan**\n" action_md += f"- **Framework:** {framework}\n" action_md += f"- {nutrition_desc}\n\n" # Exercise exercise = plan.get("ExercisePlan", {}) framework = exercise.get("Framework", "") phases = exercise.get("Phases", {}) action_md += "**🏃 Exercise Plan**\n" action_md += f"- **Framework:** {framework}\n" for week_range, content in phases.items(): goal = content.get("Goal", "") prescription = content.get("Prescription", "") action_md += f" - **{week_range}:** {goal}\n" action_md += f" - {prescription}\n" action_md += "\n" # Psychology psych_summary = plan.get("PsychologicalSupport", {}).get("Summary", "") action_md += "**🧠 Psychological Support**\n" action_md += f"- {psych_summary}\n\n" # Surgical surgical_summary = plan.get("SurgicalOrInjectionConsiderations", {}).get("Summary", "") action_md += "**🛠️ Surgical or Injection Considerations**\n" action_md += f"- {surgical_summary}\n\n" # Safety safety_summary = plan.get("SafetyMonitoring", {}).get("Summary", "") action_md += "**🔍 Safety Monitoring Plan**\n" action_md += f"- {safety_summary}\n\n" # Personalization accessibility = plan_data.get("AccessibilityFeasibility", "") rationale = plan_data.get("PersonalizationRationale", "") evidence = plan_data.get("EvidenceCompliance", "") action_md += "#### Personalized Treatment Context\n" action_md += f"- **Accessibility & Feasibility:** {accessibility}\n" action_md += f"- **Personalization Rationale:** {rationale}\n" action_md += f"- **Evidence Compliance:** {evidence}\n" html = render_agent_message( role="🧩 Clinical Decision-Making Agent", action=action_md, style_class="decision" ) return html def render_progress_bar(step: int, total: int): '''进度条函数''' progress = step / total st.markdown(f"""
""", unsafe_allow_html=True) st.markdown(f"Progress: {int(progress * 100)}%", unsafe_allow_html=True) def render_all_agents_auto(): total_agents = 4 progress_placeholder = st.empty() # ✅ Agent A: Exercise progress_placeholder.markdown(render_progress_bar_html(1, total_agents), unsafe_allow_html=True) exercise_plan = load_plan("exercise") with st.expander("A. Exercise Prescriptionist Agent", expanded=False): html = render_exercise_plan_return_html(exercise_plan) st.markdown(html, unsafe_allow_html=True) # ✅ Agent B: Surgical & Pharma progress_placeholder.markdown(render_progress_bar_html(2, total_agents), unsafe_allow_html=True) surgical_plan = load_plan("surgical_pharma") with st.expander("B. Surgical & Pharmacological Specialist Agent", expanded=False): # html = render_surgical_pharma_plan_return_html(surgical_plan) # st.markdown(html, unsafe_allow_html=True) html_blocks = render_surgical_pharma_plan_return_html(surgical_plan) for html in html_blocks: st.markdown(html, unsafe_allow_html=True) # ✅ Agent C: Nutrition & Psychology progress_placeholder.markdown(render_progress_bar_html(3, total_agents), unsafe_allow_html=True) nutrition_plan = load_plan("nutrition_psychology") with st.expander("C. Nutritional & Psychological Specialist Agent", expanded=False): # html = render_nutrition_psychology_plan_return_html(nutrition_plan) # st.markdown(html, unsafe_allow_html=True) html_blocks = render_nutrition_psychology_plan_return_html(nutrition_plan) for html in html_blocks: st.markdown(html, unsafe_allow_html=True) with st.spinner("Clinical Decision-Making Agent reasoning..."): time.sleep(3) progress_placeholder.markdown(render_progress_bar_html(4, total_agents), unsafe_allow_html=True) decision_plan = load_plan("clinical_integration") with st.expander("D. Clinical Decision-Making Agent", expanded=False): html = render_clinical_decision_agent_return_html(decision_plan) st.markdown(html, unsafe_allow_html=True) # 进度条 HTML 渲染拆出来方便复用 def render_progress_bar_html(step: int, total: int) -> str: progress = step / total return f"""
Progress: {int(progress * 100)}% """ def render_therapy_page(): """渲染治疗推荐页面""" inject_agent_styles() col1, col2 = st.columns([1.5, 1]) with col2: safe_image_display(IMAGE_PATHS["recommendation_framework"], "Framework for personalizing treatment", use_container_width=True) with col1: st.subheader("🩺 Quick Therapy Demo") case_data = load_case_data() if not case_data: st.warning("Case data cannot be loaded.") return case_names = list(case_data.keys()) case_options = [""] + case_names selected_case = st.selectbox("Select a sample case:", case_options) st.markdown(get_chat_styles(), unsafe_allow_html=True) if selected_case != "": case = case_data[selected_case] if "start_clicked" not in st.session_state: st.session_state.start_clicked = False if not st.session_state.get("start_clicked", False): st.success(f"Selected:{selected_case}") st.markdown("### 📑 Case Reports") for report_name in case.get("reports", []): st.markdown(f"📄 {report_name}") # if st.button("▶️ Start Multi-Agent Reasoning"): # st.session_state.start_clicked = True # st.query_params.update({"start": "1"}) button_placeholder = st.empty() spacer_placeholder = st.empty() if "start_clicked" not in st.session_state: st.session_state.start_clicked = False if not st.session_state.start_clicked: if button_placeholder.button("▶️ Start Multi-Agent Reasoning"): st.session_state.start_clicked = True st.query_params.update({"start": "1"}) button_placeholder.empty() spacer_placeholder.markdown("
", unsafe_allow_html=True) else: button_placeholder.empty() spacer_placeholder.empty() render_all_agents_auto() else: render_all_agents_auto() # ============================================================================= # 主程序 # ============================================================================= def main(): # 页面配置 st.set_page_config(**PAGE_CONFIG) # 渲染导航 render_navigation() # 获取当前页面 page = st.query_params.get("page", "Home") # 路由到对应页面 page_routes = { "Home": render_home_page, "Assessing Current Status": render_assessment_page, "Predicting Progression Risk": render_prediction_page, "Tailored Therapy Recommendation": render_therapy_page } render_func = page_routes.get(page) if render_func: render_func() else: st.error(f"未知页面: {page}") if __name__ == "__main__": main()