| """ |
| 对比 Agent - 纵向时序分析 |
| """ |
import json
from datetime import datetime
from typing import Any, Dict, List, Optional, TypedDict

from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage
from langgraph.graph import StateGraph, END

from app.core.config import settings
from app.core.logging import logger
from app.schemas.analysis import ComparisonResult, RECISTResponse
from .prompts import COMPARATIVE_SYSTEM_PROMPT, COMPARATIVE_USER_TEMPLATE
|
|
|
|
class ComparativeAgentState(TypedDict, total=False):
    """State passed between the nodes of the comparison workflow.

    Declared as a ``TypedDict`` (the schema form LangGraph documents for
    ``StateGraph``). ``total=False`` because the nodes read every key through
    ``state.get(key, default)`` rather than assuming it is present.
    """

    patient_id: str
    baseline_scan_id: str
    baseline_date: str  # ISO-8601 string; "Z" suffix is accepted by compare()
    followup_scan_id: str
    followup_date: str  # ISO-8601 string; "Z" suffix is accepted by compare()
    days_between: int  # follow-up minus baseline in whole days; 0 if unparsable
    # Declared as ComparisonResult, but compare() also passes plain dicts here;
    # the nodes handle both forms.
    comparisons: List[ComparisonResult]
    heatmap_path: str
    analysis: str  # LLM-generated analysis text, or a failure message
    recist_evaluation: str  # comma-joined per-lesion RECIST labels
|
|
|
|
class ComparativeAgent:
    """Comparison expert agent: longitudinal (baseline vs. follow-up) analysis.

    ``compare`` runs a fixed LangGraph pipeline
    (prepare_comparison -> evaluate_recist -> generate_analysis -> format_report);
    ``compare_simple`` sends the raw comparison data to the LLM in a single call.
    """

    # Per-lesion thresholds on ``diameter_change_percent`` (inclusive):
    # a change <= PR threshold is labelled PR, >= PD threshold is labelled PD.
    PR_THRESHOLD_PERCENT = -30
    PD_THRESHOLD_PERCENT = 20

    def __init__(self, llm: Optional[ChatOpenAI] = None):
        """Initialise the comparison agent.

        Args:
            llm: LLM instance to use. When ``None``, a ``ChatOpenAI`` client is
                built from the ``LLM_*`` values in ``settings``.
        """
        if llm is None:
            llm = ChatOpenAI(
                model=settings.LLM_MODEL,
                base_url=settings.LLM_BASE_URL,
                api_key=settings.LLM_API_KEY,
                temperature=settings.LLM_TEMPERATURE,
                max_tokens=settings.LLM_MAX_TOKENS,
            )
        self.llm = llm

        # The graph is compiled once here and reused by every compare() call.
        self.workflow = self._build_workflow()

    def _build_workflow(self) -> Any:
        """Build and compile the linear LangGraph workflow.

        Returns:
            The compiled graph (the result of ``StateGraph.compile()``), which
            exposes ``invoke``.
        """
        workflow = StateGraph(ComparativeAgentState)

        workflow.add_node("prepare_comparison", self._prepare_comparison)
        workflow.add_node("evaluate_recist", self._evaluate_recist)
        workflow.add_node("generate_analysis", self._generate_analysis)
        workflow.add_node("format_report", self._format_report)

        workflow.set_entry_point("prepare_comparison")
        workflow.add_edge("prepare_comparison", "evaluate_recist")
        workflow.add_edge("evaluate_recist", "generate_analysis")
        workflow.add_edge("generate_analysis", "format_report")
        workflow.add_edge("format_report", END)

        return workflow.compile()

    def _prepare_comparison(self, state: ComparativeAgentState) -> ComparativeAgentState:
        """Prepare comparison data (currently only logs the scan pair)."""
        logger.info(f"准备对比数据: {state.get('baseline_scan_id')} vs {state.get('followup_scan_id')}")
        return state

    @classmethod
    def _classify_lesion(cls, comp: Any) -> str:
        """Return the RECIST-style label for one lesion comparison.

        ``comp`` may be a dict or an object exposing ``diameter_change_percent``.
        A missing, ``None``, non-numeric or NaN change cannot be classified and
        yields "无法评估" instead of crashing or being reported as stable.
        """
        import math

        if isinstance(comp, dict):
            raw_change = comp.get("diameter_change_percent")
        else:
            raw_change = getattr(comp, "diameter_change_percent", None)

        try:
            change = float(raw_change)
        except (TypeError, ValueError):
            return "无法评估"
        if math.isnan(change):
            return "无法评估"

        # NOTE(review): a -100% change (lesion disappeared) is labelled PR here;
        # RECIST 1.1 would call disappearance CR — confirm intended behaviour.
        if change <= cls.PR_THRESHOLD_PERCENT:
            return "PR (部分缓解)"
        if change >= cls.PD_THRESHOLD_PERCENT:
            return "PD (疾病进展)"
        return "SD (疾病稳定)"

    def _evaluate_recist(self, state: ComparativeAgentState) -> ComparativeAgentState:
        """Label every lesion comparison and store the joined labels in state.

        Sets ``state["recist_evaluation"]`` to the comma-joined per-lesion
        labels, or "无法评估" when there are no comparisons.
        """
        logger.info("执行 RECIST 评估...")

        # ``or []`` also covers an explicit ``comparisons=None``.
        comparisons = state.get("comparisons") or []
        recist_results = [self._classify_lesion(comp) for comp in comparisons]

        state["recist_evaluation"] = ", ".join(recist_results) if recist_results else "无法评估"
        return state

    @staticmethod
    def _to_jsonable(comp: Any) -> Any:
        """Convert one comparison entry into something ``json.dumps`` accepts."""
        if isinstance(comp, dict):
            return comp
        if hasattr(comp, "model_dump"):
            return comp.model_dump()
        return str(comp)

    def _invoke_llm(self, user_prompt: str, failure_log: str) -> str:
        """Send the system prompt plus ``user_prompt`` to the LLM.

        LLM failures are a deliberate best-effort boundary: they are logged as
        ``"{failure_log}: {error}"`` and turned into a failure text instead of
        being raised.

        Returns:
            The response content, or "分析生成失败: <error>" on failure.
        """
        messages = [
            SystemMessage(content=COMPARATIVE_SYSTEM_PROMPT),
            HumanMessage(content=user_prompt),
        ]
        try:
            response = self.llm.invoke(messages)
        except Exception as e:
            logger.error(f"{failure_log}: {e}")
            return f"分析生成失败: {str(e)}"
        return response.content

    def _generate_analysis(self, state: ComparativeAgentState) -> ComparativeAgentState:
        """Ask the LLM for a comparison analysis and store it in ``state["analysis"]``."""
        logger.info("生成对比分析...")

        comparisons = state.get("comparisons") or []
        comparison_data = [self._to_jsonable(comp) for comp in comparisons]
        # default=str stringifies values json cannot encode natively (e.g. dates).
        comparison_json = json.dumps(comparison_data, indent=2, ensure_ascii=False, default=str)

        user_prompt = COMPARATIVE_USER_TEMPLATE.format(
            patient_id=state.get("patient_id", "Unknown"),
            baseline_date=state.get("baseline_date", "Unknown"),
            baseline_scan_id=state.get("baseline_scan_id", "Unknown"),
            followup_date=state.get("followup_date", "Unknown"),
            followup_scan_id=state.get("followup_scan_id", "Unknown"),
            days_between=state.get("days_between", 0),
            comparison_json=comparison_json,
            heatmap_path=state.get("heatmap_path", "未生成"),
        )

        state["analysis"] = self._invoke_llm(user_prompt, "LLM 调用失败")
        return state

    def _format_report(self, state: ComparativeAgentState) -> ComparativeAgentState:
        """Format the report (currently a pass-through that only logs)."""
        logger.info("格式化对比报告...")
        return state

    @staticmethod
    def _days_between(baseline_date: str, followup_date: str) -> int:
        """Return whole days from baseline to follow-up, or 0 if not computable.

        Dates are parsed with ``datetime.fromisoformat`` after mapping a
        trailing "Z" to "+00:00". 0 is returned when a date is not a string
        (AttributeError), cannot be parsed (ValueError), or when one date is
        timezone-aware and the other naive (TypeError on subtraction).
        """
        try:
            d1 = datetime.fromisoformat(baseline_date.replace("Z", "+00:00"))
            d2 = datetime.fromisoformat(followup_date.replace("Z", "+00:00"))
            return (d2 - d1).days
        except (AttributeError, TypeError, ValueError):
            return 0

    def compare(
        self,
        patient_id: str,
        baseline_scan_id: str,
        baseline_date: str,
        followup_scan_id: str,
        followup_date: str,
        comparisons: List[Dict[str, Any]],
        heatmap_path: Optional[str] = None
    ) -> Dict[str, Any]:
        """Run the full comparison workflow.

        Args:
            patient_id: Patient ID.
            baseline_scan_id: Baseline scan ID.
            baseline_date: Baseline date (ISO-8601 string).
            followup_scan_id: Follow-up scan ID.
            followup_date: Follow-up date (ISO-8601 string).
            comparisons: Per-lesion comparison records.
            heatmap_path: Heat-map path. The prompt shows "未生成" when None;
                the returned dict carries the value as passed.

        Returns:
            A dict with the IDs, ``days_between``, ``recist_evaluation``,
            ``analysis``, ``heatmap_path`` and a local-time ``generated_at``
            ISO timestamp.
        """
        days_between = self._days_between(baseline_date, followup_date)

        initial_state = {
            "patient_id": patient_id,
            "baseline_scan_id": baseline_scan_id,
            "baseline_date": baseline_date,
            "followup_scan_id": followup_scan_id,
            "followup_date": followup_date,
            "days_between": days_between,
            "comparisons": comparisons,
            "heatmap_path": heatmap_path or "未生成",
            "analysis": "",
            "recist_evaluation": ""
        }

        result = self.workflow.invoke(initial_state)

        return {
            "patient_id": patient_id,
            "baseline_scan_id": baseline_scan_id,
            "followup_scan_id": followup_scan_id,
            "days_between": days_between,
            "recist_evaluation": result.get("recist_evaluation", ""),
            "analysis": result.get("analysis", ""),
            "heatmap_path": heatmap_path,
            "generated_at": datetime.now().isoformat()
        }

    def compare_simple(
        self,
        comparison_data: Dict[str, Any]
    ) -> str:
        """Run a single-call comparison analysis (no workflow).

        Args:
            comparison_data: Comparison data; serialised to JSON (non-JSON
                values stringified) and embedded in the prompt.

        Returns:
            The analysis text, or "分析生成失败: <error>" if the LLM call fails.
        """
        comparison_json = json.dumps(comparison_data, indent=2, ensure_ascii=False, default=str)

        user_prompt = f"""请分析以下 CT 对比数据:

```json
{comparison_json}
```

请根据 RECIST 1.1 标准生成对比分析报告。
"""

        return self._invoke_llm(user_prompt, "对比分析失败")
|
|
|
|
|
|