# TrialPath Data and Evaluation Pipeline: TDD Implementation Guide

> Based on in-depth research using DeepWiki, the official TREC documentation, and the ir-measures and ir_datasets libraries

---

## 1. Pipeline Architecture Overview

### 1.1 Data Flow Diagram

```
┌─────────────────────────────────────────────────────────────────┐
│                   Data & Evaluation Pipeline                    │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  ┌──────────────┐    ┌──────────────┐    ┌──────────────────┐   │
│  │   Synthea    │───▶│ FHIR Bundle  │───▶│  PatientProfile  │   │
│  │  (Java CLI)  │    │    (JSON)    │    │  (JSON Schema)   │   │
│  └──────────────┘    └──────────────┘    └────────┬─────────┘   │
│                                                   │             │
│  ┌──────────────┐    ┌──────────────┐             ▼             │
│  │  LLM Letter  │───▶│  ReportLab   │───▶ Noisy Clinical PDFs   │
│  │  Generator   │    │  + Augraphy  │     (Letters/Labs/Path)   │
│  └──────────────┘    └──────────────┘                           │
│                                                                 │
│  ┌──────────────┐    ┌──────────────┐    ┌──────────────────┐   │
│  │   MedGemma   │───▶│  Extracted   │───▶│   F1 Evaluator   │   │
│  │  Extractor   │    │   Profile    │    │  (scikit-learn)  │   │
│  └──────────────┘    └──────────────┘    └──────────────────┘   │
│                                                                 │
│  ┌──────────────┐    ┌──────────────┐    ┌──────────────────┐   │
│  │ TREC Topics  │───▶│  TrialPath   │───▶│  TREC Evaluator  │   │
│  │(ir_datasets) │    │   Matching   │    │  (ir-measures)   │   │
│  └──────────────┘    └──────────────┘    └──────────────────┘   │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
```

### 1.2 Module Relationships

| Module | Input | Output | Dependencies |
|------|------|------|------|
| `data/generate_synthetic_patients.py` | Synthea FHIR Bundles | `PatientProfile` JSON + Ground Truth | Synthea CLI, FHIR R4 |
| `data/generate_noisy_pdfs.py` | `PatientProfile` JSON | Clinical PDFs (with noise) | ReportLab, Augraphy |
| `evaluation/run_trec_benchmark.py` | TREC Topics + TrialPath Run | Recall@50, NDCG@10, P@10 | ir_datasets, ir-measures |
| `evaluation/extraction_eval.py` | Extracted vs Ground Truth Profiles | Field-level F1 | scikit-learn |
| `evaluation/criterion_eval.py` | EligibilityLedger vs Gold Standard | Criterion Accuracy | scikit-learn |
| `evaluation/latency_cost_tracker.py` | API call logs | Latency/Cost reports | time, logging |

### 1.3 Directory Structure

```
data/
├── generate_synthetic_patients.py      # Synthea FHIR → PatientProfile
├── generate_noisy_pdfs.py              # PatientProfile → Clinical PDFs
├── synthea_config/
│   ├── synthea.properties              # Synthea configuration
│   └── modules/
│       └── lung_cancer_biomarkers.json # Extended NSCLC module (with biomarkers)
├── templates/
│   ├── clinical_letter.py              # Clinical letter template
│   ├── pathology_report.py             # Pathology report template
│   ├── lab_report.py                   # Lab report template
│   └── imaging_report.py               # Imaging report template
├── noise/
│   └── noise_injector.py               # Noise injection engine
└── output/
    ├── fhir/                           # Raw Synthea FHIR output
    ├── profiles/                       # Converted PatientProfile JSON
    ├── pdfs/                           # Generated clinical PDFs
    └── ground_truth/                   # Annotation data

evaluation/
├── run_trec_benchmark.py               # TREC retrieval evaluation
├── extraction_eval.py                  # MedGemma extraction F1
├── criterion_eval.py                   # Criterion Decision Accuracy
├── latency_cost_tracker.py             # Latency and cost tracking
├── trec_data/
│   ├── topics2021.xml                  # TREC 2021 topics
│   ├── qrels2021.txt                   # TREC 2021 relevance judgments
│   └── topics2022.xml                  # TREC 2022 topics
└── reports/                            # Evaluation report output

tests/
├── test_synthea_data.py                # Synthea data validation
├── test_pdf_generation.py              # PDF generation correctness
├── test_noise_injection.py             # Noise injection effects
├── test_trec_evaluation.py             # TREC metric computation
├── test_extraction_f1.py               # F1 computation tests
├── test_latency_cost.py                # Latency and cost tests
└── test_e2e_pipeline.py                # End-to-end pipeline tests
```

---

## 2. Synthea Synthetic Patient Generation Guide

### 2.1 Synthea Overview

Synthea is an open-source synthetic patient simulator developed by MITRE and implemented in Java. It simulates disease trajectories with JSON state-machine modules and exports standard FHIR R4 Bundles.

**Key features (source: DeepWiki, synthetichealth/synthea):**
- Module-based disease simulation: each disease is defined as a JSON state machine
- FHIR R4/STU3/DSTU2 export
- Built-in `lung_cancer.json` module with an 85% NSCLC / 15% SCLC split
- Stage I-IV staging and chemotherapy/radiotherapy treatment pathways
- **Does not include NSCLC-specific biomarkers (EGFR, ALK, PD-L1, KRAS, ROS1); these require a custom extension**

### 2.2 Installation and Configuration

**System requirements:**
- Java JDK 11 or later (LTS 11 or 17 recommended)

**Option A: use the JAR directly (recommended for data generation)**

```bash
# Download the latest release JAR
# Available from https://github.com/synthetichealth/synthea/releases
wget https://github.com/synthetichealth/synthea/releases/download/master-branch-latest/synthea-with-dependencies.jar

# Verify the installation (prints usage)
java -jar synthea-with-dependencies.jar -h
```

**Option B: build from source (use this when custom modules are needed)**

```bash
git clone https://github.com/synthetichealth/synthea.git
cd synthea
./gradlew build check test
```

### 2.3 NSCLC Module Configuration

#### 2.3.1 Analysis of the Existing lung_cancer Module

Source: DeepWiki analysis of the `lung_cancer.json` module in `synthetichealth/synthea`:

- **Entry condition**: people aged 45-65, selected by probability
- **Diagnostic workflow**: symptoms (cough, hemoptysis, shortness of breath) → chest X-ray → chest CT → biopsy/cytology
- **Subtype**: 85% NSCLC, 15% SCLC
- **Staging**: Stage I-IV, based on `lung_cancer_nondiagnosis_counter`
- **Treatment**: NSCLC uses Cisplatin + Paclitaxel → radiotherapy

#### 2.3.2 Custom NSCLC Biomarker Extension Module

Because the native module has no EGFR/ALK/PD-L1 or other biomarkers, an extension submodule is required.

**File: `data/synthea_config/modules/lung_cancer_biomarkers.json`**

According to the DeepWiki research on Synthea module state types, the available state types include:

- `Initial`: module entry
- `Terminal`: module exit
- `Observation`: records a clinical observation (used for biomarkers)
- `SetAttribute`: sets a patient attribute
- `Guard`: conditional gate
- `Simple`: simple transition state
- `Encounter`: encounter state

Example structure of the biomarker observation states:

```json
{
  "name": "NSCLC Biomarker Panel",
  "states": {
    "Initial": {
      "type": "Initial",
      "conditional_transition": [
        {
          "condition": {
            "condition_type": "Attribute",
            "attribute": "Lung Cancer Type",
            "operator": "==",
            "value": "NSCLC"
          },
          "transition": "EGFR_Test_Encounter"
        },
        { "transition": "Terminal" }
      ]
    },
    "EGFR_Test_Encounter": {
      "type": "Encounter",
      "encounter_class": "ambulatory",
      "codes": [
        {
          "system": "SNOMED-CT",
          "code":
            "185349003",
          "display": "Encounter for check up"
        }
      ],
      "direct_transition": "EGFR_Mutation_Status"
    },
    "EGFR_Mutation_Status": {
      "type": "Observation",
      "category": "laboratory",
      "codes": [
        { "system": "LOINC", "code": "41103-3", "display": "EGFR gene mutations found" }
      ],
      "distributed_transition": [
        { "distribution": 0.15, "transition": "EGFR_Positive" },
        { "distribution": 0.85, "transition": "EGFR_Negative" }
      ]
    },
    "EGFR_Positive": {
      "type": "SetAttribute",
      "attribute": "egfr_status",
      "value": "positive",
      "direct_transition": "ALK_Rearrangement_Status"
    },
    "EGFR_Negative": {
      "type": "SetAttribute",
      "attribute": "egfr_status",
      "value": "negative",
      "direct_transition": "ALK_Rearrangement_Status"
    },
    "ALK_Rearrangement_Status": {
      "type": "Observation",
      "category": "laboratory",
      "codes": [
        { "system": "LOINC", "code": "46264-8", "display": "ALK gene rearrangement" }
      ],
      "distributed_transition": [
        { "distribution": 0.05, "transition": "ALK_Positive" },
        { "distribution": 0.95, "transition": "ALK_Negative" }
      ]
    },
    "ALK_Positive": {
      "type": "SetAttribute",
      "attribute": "alk_status",
      "value": "positive",
      "direct_transition": "PDL1_Expression"
    },
    "ALK_Negative": {
      "type": "SetAttribute",
      "attribute": "alk_status",
      "value": "negative",
      "direct_transition": "PDL1_Expression"
    },
    "PDL1_Expression": {
      "type": "Observation",
      "category": "laboratory",
      "codes": [
        { "system": "LOINC", "code": "85147-0", "display": "PD-L1 by immune stain" }
      ],
      "distributed_transition": [
        { "distribution": 0.30, "transition": "PDL1_High" },
        { "distribution": 0.35, "transition": "PDL1_Low" },
        { "distribution": 0.35, "transition": "PDL1_Negative" }
      ]
    },
    "PDL1_High": {
      "type": "SetAttribute",
      "attribute": "pdl1_tps",
      "value": ">=50%",
      "direct_transition": "KRAS_Mutation_Status"
    },
    "PDL1_Low": {
      "type": "SetAttribute",
      "attribute": "pdl1_tps",
      "value": "1-49%",
      "direct_transition": "KRAS_Mutation_Status"
    },
    "PDL1_Negative": {
      "type": "SetAttribute",
      "attribute": "pdl1_tps",
      "value": "<1%",
      "direct_transition":
        "KRAS_Mutation_Status"
    },
    "KRAS_Mutation_Status": {
      "type": "Observation",
      "category": "laboratory",
      "codes": [
        { "system": "LOINC", "code": "21717-3", "display": "KRAS gene mutations found" }
      ],
      "distributed_transition": [
        { "distribution": 0.25, "transition": "KRAS_Positive" },
        { "distribution": 0.75, "transition": "KRAS_Negative" }
      ]
    },
    "KRAS_Positive": {
      "type": "SetAttribute",
      "attribute": "kras_status",
      "value": "positive",
      "direct_transition": "Terminal"
    },
    "KRAS_Negative": {
      "type": "SetAttribute",
      "attribute": "kras_status",
      "value": "negative",
      "direct_transition": "Terminal"
    },
    "Terminal": {
      "type": "Terminal"
    }
  }
}
```

Note: as written, the `Observation` states record only a code, and the result is stored in patient attributes (`egfr_status` and so on). For the result to appear in the exported FHIR `Observation`, each branch would likely need its own `Observation` state that carries a value (for example via `value_code`).

**Biomarker prevalence distribution (based on the NSCLC literature):**

| Biomarker | Positive rate | LOINC Code | Notes |
|-----------|--------|------------|------|
| EGFR mutation | ~15% | 41103-3 | Higher in non-smoking Asian women |
| ALK rearrangement | ~5% | 46264-8 | More common in younger non-smokers |
| PD-L1 TPS>=50% | ~30% | 85147-0 | Eligibility threshold for immunotherapy |
| KRAS G12C | ~13% | 21717-3 | Targeted by sotorasib; the module above models any KRAS mutation at 25% |
| ROS1 fusion | ~1-2% | 46265-5 | Targeted by crizotinib; not modeled in the module above |

### 2.4 Batch Generation Command

```bash
# Generate a population of 500 patients with a fixed seed for reproducibility
# (only a subset of them will develop lung cancer)
java -jar synthea-with-dependencies.jar \
  -p 500 \
  -s 42 \
  -m lung_cancer \
  --exporter.fhir.export=true \
  --exporter.fhir_stu3.export=false \
  --exporter.fhir_dstu2.export=false \
  --exporter.ccda.export=false \
  --exporter.csv.export=false \
  --exporter.hospital.fhir.export=false \
  --exporter.practitioner.fhir.export=false \
  --exporter.pretty_print=true \
  Massachusetts

# Parameters:
# -p 500                      : population size (500 patients)
# -s 42                       : random seed (reproducible)
# -m lung_cancer              : run only the lung_cancer module
# --exporter.fhir.export=true : enable FHIR R4 export
# Massachusetts               : state to generate patients for
```

**Output location:** the `./output/fhir/` directory, one JSON file per patient.

### 2.5 FHIR Bundle Output Format

Source: DeepWiki analysis of the FHIR export system in `synthetichealth/synthea`.

**Top-level structure:**

```json
{
  "resourceType": "Bundle",
  "type": "transaction",
  "entry": [
    {
      "fullUrl": "urn:uuid:patient-uuid-here",
      "resource": {
        "resourceType": "Patient",
        ...
      },
      "request": { "method": "POST", "url": "Patient" }
    },
    {
      "fullUrl": "urn:uuid:condition-uuid-here",
      "resource": {
        "resourceType": "Condition",
        ...
      },
      "request": { "method": "POST", "url": "Condition" }
    }
  ]
}
```

**FHIR resource types generated by Synthea (confirmed via DeepWiki):**

- `Patient`: basic patient information
- `Condition`: diagnoses (such as NSCLC)
- `Observation`: lab tests and vital signs
- `MedicationRequest`: medication prescriptions
- `Procedure`: surgeries and procedures
- `DiagnosticReport`: diagnostic reports
- `DocumentReference`: clinical documents (requires the US Core IG to be enabled)
- `Encounter`: encounter records
- `AllergyIntolerance`: allergy history
- `Immunization`: immunizations
- `CarePlan`: care plans
- `ImagingStudy`: imaging studies

### 2.6 Mapping FHIR Resources to PatientProfile

```python
# Mapping logic in data/generate_synthetic_patients.py
FHIR_TO_PATIENT_PROFILE_MAP = {
    # Patient Resource → demographics
    "Patient.name": "demographics.name",
    "Patient.gender": "demographics.sex",
    "Patient.birthDate": "demographics.date_of_birth",
    "Patient.address.state": "demographics.state",

    # Condition Resource → diagnosis
    "Condition[code=SNOMED:254637007]": "diagnosis.primary",  # NSCLC
    "Condition.stage.summary": "diagnosis.stage",
    "Condition.bodySite": "diagnosis.histology",

    # Observation Resources → biomarkers
    "Observation[code=LOINC:41103-3]": "biomarkers.egfr",
    "Observation[code=LOINC:46264-8]": "biomarkers.alk",
    "Observation[code=LOINC:85147-0]": "biomarkers.pdl1_tps",
    "Observation[code=LOINC:21717-3]": "biomarkers.kras",

    # Observation Resources → labs
    "Observation[category=laboratory]": "labs[]",

    # MedicationRequest → prior_treatments
    "MedicationRequest.medicationCodeableConcept": "treatments[].medication",

    # Procedure → prior_treatments
    "Procedure.code": "treatments[].procedure",
}
```

**Conversion function pattern:**

```python
import json
from pathlib import Path
from dataclasses import dataclass, field, asdict
from typing import Optional


@dataclass
class Demographics:
    name: str = ""
    sex: str = ""
    date_of_birth: str = ""
    age: int = 0
    state: str = ""


@dataclass
class Diagnosis:
    primary: str = ""
    stage: str = ""
    histology: str = ""
    diagnosis_date: str = ""
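

# --- Illustrative helper (an assumption of this guide, not part of Synthea or FHIR) ---
# Demographics.age is declared above, but none of the parsers shown in this section
# fill it in. The sketch below shows one possible way to derive it from the ISO
# birthDate string. The name `age_from_birth_date` and its `as_of` parameter are
# hypothetical and only serve as an example.
from datetime import date
from typing import Optional


def age_from_birth_date(birth_date: str, as_of: Optional[date] = None) -> int:
    """Return age in whole years for a 'YYYY-MM-DD' birth date, or 0 if unparseable."""
    try:
        born = date.fromisoformat(birth_date[:10])
    except ValueError:
        return 0
    today = as_of or date.today()
    # Subtract one year when this year's birthday has not happened yet.
    had_birthday = (today.month, today.day) >= (born.month, born.day)
    return today.year - born.year - (0 if had_birthday else 1)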


@dataclass
class Biomarkers:
    egfr: Optional[str] = None
    alk: Optional[str] = None
    pdl1_tps: Optional[str] = None
    kras: Optional[str] = None
    ros1: Optional[str] = None


@dataclass
class LabResult:
    name: str = ""
    value: float = 0.0
    unit: str = ""
    date: str = ""
    loinc_code: str = ""


@dataclass
class Treatment:
    name: str = ""
    type: str = ""  # "medication" | "procedure" | "radiation"
    start_date: str = ""
    end_date: Optional[str] = None


@dataclass
class PatientProfile:
    patient_id: str = ""
    demographics: Demographics = field(default_factory=Demographics)
    diagnosis: Diagnosis = field(default_factory=Diagnosis)
    biomarkers: Biomarkers = field(default_factory=Biomarkers)
    labs: list[LabResult] = field(default_factory=list)
    treatments: list[Treatment] = field(default_factory=list)
    unknowns: list[str] = field(default_factory=list)
    evidence_spans: list[dict] = field(default_factory=list)


def parse_fhir_bundle(fhir_path: Path) -> PatientProfile:
    """Parse a Synthea FHIR Bundle JSON into PatientProfile."""
    with open(fhir_path) as f:
        bundle = json.load(f)

    profile = PatientProfile()
    entries = bundle.get("entry", [])

    for entry in entries:
        resource = entry.get("resource", {})
        resource_type = resource.get("resourceType")

        if resource_type == "Patient":
            _parse_patient(resource, profile)
        elif resource_type == "Condition":
            _parse_condition(resource, profile)
        elif resource_type == "Observation":
            _parse_observation(resource, profile)
        elif resource_type == "MedicationRequest":
            _parse_medication(resource, profile)
        elif resource_type == "Procedure":
            _parse_procedure(resource, profile)

    return profile


def _parse_patient(resource: dict, profile: PatientProfile):
    """Extract demographics from Patient resource."""
    names = resource.get("name", [{}])
    if names:
        given = " ".join(names[0].get("given", []))
        family = names[0].get("family", "")
        profile.demographics.name = f"{given} {family}".strip()

    profile.demographics.sex = resource.get("gender", "")
    profile.demographics.date_of_birth = resource.get("birthDate", "")
    profile.patient_id = resource.get("id", "")

    addresses = resource.get("address", [{}])
    if addresses:
        profile.demographics.state = addresses[0].get("state", "")


def _parse_condition(resource: dict, profile: PatientProfile):
    """Extract diagnosis from Condition resource."""
    code = resource.get("code", {})
    codings = code.get("coding", [])
    for coding in codings:
        # SNOMED codes for lung cancer
        if coding.get("code") in ["254637007", "254632001"]:
            profile.diagnosis.primary = coding.get("display", "")
            onset = resource.get("onsetDateTime", "")
            profile.diagnosis.diagnosis_date = onset

            # Extract stage if available
            stage_info = resource.get("stage", [])
            if stage_info:
                summary = stage_info[0].get("summary", {})
                stage_codings = summary.get("coding", [])
                if stage_codings:
                    profile.diagnosis.stage = stage_codings[0].get("display", "")


# LOINC code → Biomarkers attribute name
BIOMARKER_LOINC_MAP = {
    "41103-3": "egfr",
    "46264-8": "alk",
    "85147-0": "pdl1_tps",
    "21717-3": "kras",
    "46265-5": "ros1",
}


def _parse_observation(resource: dict, profile: PatientProfile):
    """Extract labs and biomarkers from Observation resource."""
    code = resource.get("code", {})
    codings = code.get("coding", [])
    category_list = resource.get("category", [])

    is_lab = any(
        cat_coding.get("code") == "laboratory"
        for cat in category_list
        for cat_coding in cat.get("coding", [])
    )

    for coding in codings:
        loinc = coding.get("code", "")
        display = coding.get("display", "")

        if loinc in BIOMARKER_LOINC_MAP:
            value_cc = resource.get("valueCodeableConcept", {})
            value_codings = value_cc.get("coding", [])
            value_str = value_codings[0].get("display", "") if value_codings else ""
            setattr(profile.biomarkers, BIOMARKER_LOINC_MAP[loinc], value_str)
        elif is_lab:
            value_qty = resource.get("valueQuantity", {})
            lab = LabResult(
                name=display,
                value=value_qty.get("value", 0.0),
                unit=value_qty.get("unit", ""),
                date=resource.get("effectiveDateTime", ""),
                loinc_code=loinc,
            )
            profile.labs.append(lab)


# The two parsers below are called by parse_fhir_bundle but were missing from the
# original listing. They are minimal sketches; the field choices are assumptions.
def _parse_medication(resource: dict, profile: PatientProfile):
    """Extract a prior medication from a MedicationRequest resource."""
    concept = resource.get("medicationCodeableConcept", {})
    codings = concept.get("coding", [])
    name = concept.get("text") or (codings[0].get("display", "") if codings else "")
    if name:
        profile.treatments.append(Treatment(
            name=name,
            type="medication",
            start_date=resource.get("authoredOn", ""),
        ))


def _parse_procedure(resource: dict, profile: PatientProfile):
    """Extract a prior procedure from a Procedure resource."""
    code = resource.get("code", {})
    codings = code.get("coding", [])
    name = code.get("text") or (codings[0].get("display", "") if codings else "")
    period = resource.get("performedPeriod", {})
    if name:
        profile.treatments.append(Treatment(
            name=name,
            type="procedure",
            start_date=period.get("start", resource.get("performedDateTime", "")),
            end_date=period.get("end"),
        ))
```

---

## 3. Synthetic PDF Generation Pipeline

### 3.1 Overview

Goal: convert a `PatientProfile` into realistic clinical document PDFs and inject controlled noise to simulate real-world OCR conditions.

**Technology stack:**
- **ReportLab** (`pip install reportlab`): PDF generation engine with Platypus flowables such as `SimpleDocTemplate`, `Table`, and `Paragraph`
- **Augraphy** (`pip install augraphy`): document image degradation pipeline that simulates printing, faxing, and scanning noise
- **Pillow** (`pip install Pillow`): image processing
- **pdf2image** (`pip install pdf2image`): converts PDF pages to images (so noise can be applied before converting back to PDF)

### 3.2 Clinical Letter Template

```python
# data/templates/clinical_letter.py
from reportlab.lib.pagesizes import letter
from reportlab.lib.units import inch
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
from reportlab.platypus import (
    SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle
)
from reportlab.lib import colors


def generate_clinical_letter(profile: dict, output_path: str):
    """Generate a clinical letter PDF from PatientProfile."""
    doc = SimpleDocTemplate(output_path, pagesize=letter,
                            topMargin=1*inch, bottomMargin=1*inch)
    styles = getSampleStyleSheet()
    story = []

    # Header
    header_style = ParagraphStyle(
        'Header', parent=styles['Heading1'], fontSize=14, spaceAfter=6
    )
    story.append(Paragraph("Clinical Summary Letter", header_style))
    story.append(Spacer(1, 12))

    # Patient Info
    info_data = [
        ["Patient Name:", profile["demographics"]["name"]],
        ["Date of Birth:", profile["demographics"]["date_of_birth"]],
        ["Sex:", profile["demographics"]["sex"]],
        ["MRN:", profile["patient_id"]],
    ]
    info_table = Table(info_data, colWidths=[2*inch, 4*inch])
    info_table.setStyle(TableStyle([
        ('FONTNAME', (0, 0), (0, -1), 'Helvetica-Bold'),
        ('FONTNAME', (1, 0), (1, -1), 'Helvetica'),
        ('FONTSIZE', (0, 0), (-1, -1), 10),
        ('VALIGN', (0, 0), (-1, -1), 'TOP'),
    ]))
    story.append(info_table)
    story.append(Spacer(1, 18))

    # Diagnosis Section
    story.append(Paragraph("Diagnosis", styles['Heading2']))
    dx = profile.get("diagnosis", {})
    dx_text = (
        f"Primary: {dx.get('primary', 'Unknown')}. "
        f"Stage: {dx.get('stage', 'Unknown')}. "
        f"Histology: {dx.get('histology', 'Unknown')}. "
        f"Diagnosed: {dx.get('diagnosis_date', 'Unknown')}."
    )
    story.append(Paragraph(dx_text, styles['Normal']))
    story.append(Spacer(1, 12))

    # Biomarkers Section
    story.append(Paragraph("Molecular Testing", styles['Heading2']))
    bm = profile.get("biomarkers", {})
    bm_data = [["Biomarker", "Result"]]
    for marker, value in bm.items():
        if value is not None:
            bm_data.append([marker.upper(), str(value)])

    if len(bm_data) > 1:
        bm_table = Table(bm_data, colWidths=[2.5*inch, 3.5*inch])
        bm_table.setStyle(TableStyle([
            ('BACKGROUND', (0, 0), (-1, 0), colors.lightgrey),
            ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
            ('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
            ('FONTSIZE', (0, 0), (-1, -1), 10),
        ]))
        story.append(bm_table)
    story.append(Spacer(1, 12))

    # Treatment History
    story.append(Paragraph("Treatment History", styles['Heading2']))
    treatments = profile.get("treatments", [])
    for tx in treatments:
        tx_text = f"- {tx['name']} ({tx['type']}): {tx.get('start_date', '')}"
        story.append(Paragraph(tx_text, styles['Normal']))

    doc.build(story)
```

### 3.3 Pathology Report Template

```python
# data/templates/pathology_report.py
from reportlab.lib.pagesizes import letter
from reportlab.lib.units import inch
from reportlab.lib.styles import getSampleStyleSheet
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Table


def generate_pathology_report(profile: dict, output_path: str):
    """Generate a pathology report PDF."""
    doc = SimpleDocTemplate(output_path, pagesize=letter)
    styles = getSampleStyleSheet()
    story = []

    story.append(Paragraph("SURGICAL PATHOLOGY REPORT", styles['Title']))
    story.append(Spacer(1, 12))

    # Specimen Info
    spec_data = [
        ["Specimen:", "Right lung, upper lobe, wedge resection"],
        ["Procedure:", "CT-guided needle biopsy"],
        ["Date:", profile["diagnosis"]["diagnosis_date"]],
    ]
    spec_table = Table(spec_data, colWidths=[2*inch, 4*inch])
    story.append(spec_table)
    story.append(Spacer(1, 12))

    # Final Diagnosis
    story.append(Paragraph("FINAL DIAGNOSIS", styles['Heading2']))
    story.append(Paragraph(
        f"Non-small cell lung carcinoma, {profile['diagnosis'].get('histology', 'adenocarcinoma')}, "
        f"{profile['diagnosis'].get('stage', 'Stage IIIA')}",
        styles['Normal']
    ))

    # Biomarker Results
    story.append(Spacer(1, 12))
    story.append(Paragraph("MOLECULAR/IMMUNOHISTOCHEMISTRY", styles['Heading2']))
    bm = profile.get("biomarkers", {})
    results = []
    if bm.get("egfr"):
        results.append(f"EGFR mutation analysis: {bm['egfr']}")
    if bm.get("alk"):
        results.append(f"ALK rearrangement (FISH): {bm['alk']}")
    if bm.get("pdl1_tps"):
        results.append(f"PD-L1 (22C3, TPS): {bm['pdl1_tps']}")
    if bm.get("kras"):
        results.append(f"KRAS mutation analysis: {bm['kras']}")

    for r in results:
        story.append(Paragraph(r, styles['Normal']))

    doc.build(story)
```

### 3.4 Lab Report Template

```python
# data/templates/lab_report.py
from reportlab.lib import colors
from reportlab.lib.pagesizes import letter
from reportlab.lib.units import inch
from reportlab.lib.styles import getSampleStyleSheet
from reportlab.platypus import (
    SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle
)


def generate_lab_report(profile: dict, output_path: str):
    """Generate a laboratory report PDF with CBC, CMP, etc."""
    doc = SimpleDocTemplate(output_path, pagesize=letter)
    styles = getSampleStyleSheet()
    story = []

    story.append(Paragraph("LABORATORY REPORT", styles['Title']))
    story.append(Spacer(1, 12))

    # Lab Results Table
    lab_data = [["Test", "Result", "Unit", "Reference Range", "Date"]]
    for lab in profile.get("labs", []):
        lab_data.append([
            lab["name"],
            str(lab["value"]),
            lab["unit"],
            "",  # Reference range (can be added)
            lab["date"][:10] if lab["date"] else ""
        ])

    if len(lab_data) > 1:
        lab_table = Table(lab_data, colWidths=[2*inch, 1*inch, 0.8*inch, 1.2*inch, 1*inch])
        lab_table.setStyle(TableStyle([
            ('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#003366')),
            ('TEXTCOLOR', (0, 0), (-1, 0), colors.white),
            ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
            ('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
            ('FONTSIZE', (0, 0), (-1, -1), 9),
            ('ROWBACKGROUNDS', (0, 1), (-1, -1), [colors.white, colors.HexColor('#f0f0f0')]),
        ]))
        story.append(lab_table)

    doc.build(story)
```

### 3.5 Noise Injection Strategy

```python
# data/noise/noise_injector.py
import random
import re
from pathlib import Path

from PIL import Image

# Augraphy pipeline components (optional dependency)
try:
    from augraphy import (
        AugraphyPipeline, InkBleed, Letterpress, LowInkPeriodicLines,
        DirtyDrum, SubtleNoise, Jpeg, Brightness, BleedThrough
    )
    AUGRAPHY_AVAILABLE = True
except ImportError:
    AUGRAPHY_AVAILABLE = False


class NoiseInjector:
    """Controlled noise injection engine that simulates real-world document degradation."""

    # Common OCR confusions
    OCR_ERROR_MAP = {
        "0": ["O", "o", "Q"],
        "1": ["l", "I", "|"],
        "5": ["S", "s"],
        "8": ["B"],
        "O": ["0", "Q"],
        "l": ["1", "I", "|"],
        "rn": ["m"],
        "cl": ["d"],
        "vv": ["w"],
    }

    # Medical abbreviation substitutions
    ABBREVIATION_MAP = {
        "non-small cell lung cancer": ["NSCLC", "non-small cell ca", "NSCC"],
        "adenocarcinoma": ["adeno", "adenoca", "adeno ca"],
        "squamous cell carcinoma": ["SCC", "squamous ca", "sq cell ca"],
        "Eastern Cooperative Oncology Group": ["ECOG"],
        "performance status": ["PS", "perf status"],
        "milligrams per deciliter": ["mg/dL", "mg/dl"],
        "computed tomography": ["CT", "cat scan"],
    }

    # Noise level configuration
    NOISE_LEVELS = {
        "clean": {"ocr_rate": 0.0, "abbrev_rate": 0.0, "missing_rate": 0.0},
        "mild": {"ocr_rate": 0.02, "abbrev_rate": 0.1, "missing_rate": 0.05},
        "moderate": {"ocr_rate": 0.05, "abbrev_rate": 0.2, "missing_rate": 0.1},
        "severe": {"ocr_rate": 0.10, "abbrev_rate": 0.3, "missing_rate": 0.2},
    }

    def __init__(self, noise_level: str = "mild", seed: int = 42):
        self.config = self.NOISE_LEVELS[noise_level]
        self.rng = random.Random(seed)

    def inject_text_noise(self, text: str) -> tuple[str, list[dict]]:
        """Inject OCR errors and abbreviations into text.

        Returns (noisy_text, list_of_injected_noise_records).
        """
        noise_records = []
        chars = list(text)

        # OCR substitutions. Two-character patterns ("rn", "cl", "vv") are checked
        # before single characters; a per-character lookup alone never matches them.
        i = 0
        while i < len(chars):
            for width in (2, 1):
                original = "".join(chars[i:i + width])
                if (len(original) == width
                        and original in self.OCR_ERROR_MAP
                        and self.rng.random() < self.config["ocr_rate"]):
                    replacement = self.rng.choice(self.OCR_ERROR_MAP[original])
                    chars[i:i + width] = [replacement]
                    noise_records.append({
                        "type": "ocr_error",
                        "position": i,
                        "original": original,
                        "replacement": replacement,
                    })
                    break
            i += 1

        noisy_text = "".join(chars)

        # Abbreviation substitutions
        for full_form, abbreviations in self.ABBREVIATION_MAP.items():
            if full_form in noisy_text.lower() and self.rng.random() < self.config["abbrev_rate"]:
                abbrev = self.rng.choice(abbreviations)
                noisy_text = re.sub(
                    re.escape(full_form), abbrev, noisy_text,
                    count=1, flags=re.IGNORECASE
                )
                noise_records.append({
                    "type": "abbreviation",
                    "original": full_form,
                    "replacement": abbrev,
                })

        return noisy_text, noise_records

    def inject_missing_values(self, profile: dict) -> tuple[dict, list[str]]:
        """Randomly remove fields from profile to simulate missing data.

        Returns (modified_profile, list_of_removed_fields).
        """
        import copy  # local import keeps this method self-contained

        # Work on a copy so the caller's ground-truth profile is not mutated.
        profile = copy.deepcopy(profile)
        removed = []
        removable_fields = [
            ("biomarkers", "egfr"),
            ("biomarkers", "alk"),
            ("biomarkers", "pdl1_tps"),
            ("biomarkers", "kras"),
            ("biomarkers", "ros1"),
            ("diagnosis", "stage"),
            ("diagnosis", "histology"),
        ]
        for section, field_name in removable_fields:
            if self.rng.random() < self.config["missing_rate"]:
                if section in profile and field_name in profile[section]:
                    profile[section][field_name] = None
                    removed.append(f"{section}.{field_name}")
        return profile, removed

    def degrade_image(self, image: Image.Image) -> Image.Image:
        """Apply Augraphy degradation pipeline to document image."""
        if not AUGRAPHY_AVAILABLE:
            return image

        import numpy as np
        img_array = np.array(image)

        pipeline = AugraphyPipeline(
            ink_phase=[
                InkBleed(p=0.5),
                Letterpress(p=0.3),
                LowInkPeriodicLines(p=0.3),
            ],
            paper_phase=[
                SubtleNoise(p=0.5),
            ],
            post_phase=[
                DirtyDrum(p=0.3),
                Brightness(p=0.5),
                Jpeg(p=0.5),
            ],
        )

        # Assumes an Augraphy release in which calling the pipeline returns the
        # augmented image array; older releases may return a dict instead.
        degraded = pipeline(img_array)
        return Image.fromarray(degraded)
```

---

## 4. TREC Benchmark Evaluation Guide

### 4.1 Dataset Overview

**TREC Clinical Trials Track 2021:**
- Source: NIST Text REtrieval Conference (TREC)
- Topics (queries): 75 synthetic patient descriptions (admission notes of 5-10 sentences)
- Document collection: approximately 376,000 clinical trials (ClinicalTrials.gov snapshot from April 2021)
- Qrels: 35,832 relevance judgments
- Relevance labels: 0 = not relevant, 1 = excluded, 2 = eligible

**TREC Clinical Trials Track 2022:**
- Topics: 50 synthetic patient descriptions
- Uses the same document collection snapshot

### 4.2 Data Formats

#### Topics XML Format

```xml
<topics>
  <topic number="1">
    A 62-year-old male presents with a 3-month history of progressive dyspnea
    and a 20-pound weight loss. He has a 40 pack-year smoking history.
    CT chest reveals a 4.5cm right upper lobe mass with mediastinal
    lymphadenopathy. Biopsy confirms non-small cell lung cancer,
    adenocarcinoma. EGFR mutation testing is positive for exon 19 deletion.
    PD-L1 TPS is 60%. ECOG performance status is 1.
  </topic>
  ...
</topics>
```

#### Qrels Format (whitespace-delimited)

```
topic_id 0 doc_id relevance
1 0 NCT00760162 2
1 0 NCT01234567 1
1 0 NCT09876543 0
```

- Column 1: topic number
- Column 2: fixed value 0 (iteration)
- Column 3: NCT document ID
- Column 4: relevance (0 = not relevant, 1 = excluded, 2 = eligible)

#### Run Submission Format

```
TOPIC_NO Q0 NCT_ID RANK SCORE RUN_NAME
1 Q0 NCT00760162 1 0.9999 trialpath-v1
1 Q0 NCT01234567 2 0.9998 trialpath-v1
```

### 4.3 Loading Data with ir_datasets

```python
# evaluation/run_trec_benchmark.py
import ir_datasets


def load_trec_2021():
    """Load TREC CT 2021 topics and qrels via ir_datasets."""
    dataset = ir_datasets.load("clinicaltrials/2021/trec-ct-2021")

    # Load topics (GenericQuery: query_id, text)
    topics = {}
    for query in dataset.queries_iter():
        topics[query.query_id] = query.text

    # Load qrels (TrecQrel: query_id, doc_id, relevance, iteration)
    qrels = {}
    for qrel in dataset.qrels_iter():
        if qrel.query_id not in qrels:
            qrels[qrel.query_id] = {}
        qrels[qrel.query_id][qrel.doc_id] = qrel.relevance

    return topics, qrels


def load_trec_2022():
    """Load TREC CT 2022 topics and qrels."""
    dataset = ir_datasets.load("clinicaltrials/2021/trec-ct-2022")

    topics = {q.query_id: q.text for q in dataset.queries_iter()}
    qrels = {}
    for qrel in dataset.qrels_iter():
        if qrel.query_id not in qrels:
            qrels[qrel.query_id] = {}
        qrels[qrel.query_id][qrel.doc_id] = qrel.relevance

    return topics, qrels


def load_trial_documents():
    """Load the clinical trial documents from ir_datasets."""
    dataset = ir_datasets.load("clinicaltrials/2021")

    # ClinicalTrialsDoc: doc_id, title, condition, summary,
    #                    detailed_description, eligibility
    docs = {}
    for doc in dataset.docs_iter():
        docs[doc.doc_id] = {
            "title": doc.title,
            "condition": doc.condition,
            "summary": doc.summary,
            "detailed_description": doc.detailed_description,
            "eligibility": doc.eligibility,
        }
    return docs
```

### 4.4 Mapping TrialPath Output to TREC Format

```python
def convert_trialpath_to_trec_run(
    results: dict[str, list[dict]],
    run_name: str = "trialpath-v1"
) -> str:
    """Convert TrialPath matching results to TREC run format.

    Args:
        results: {topic_id: [{"nct_id": str, "score": float}, ...]}
        run_name: Run identifier

    Returns:
        TREC-format run string
    """
    lines = []
    for topic_id, candidates in results.items():
        sorted_candidates = sorted(candidates, key=lambda x: x["score"], reverse=True)
        for rank, candidate in enumerate(sorted_candidates[:1000], 1):
            lines.append(
                f"{topic_id} Q0 {candidate['nct_id']} {rank} "
                f"{candidate['score']:.6f} {run_name}"
            )
    return "\n".join(lines)


def save_trec_run(run_str: str, output_path: str):
    """Save TREC run to file."""
    with open(output_path, 'w') as f:
        f.write(run_str)
```

### 4.5 Computing Evaluation Metrics with ir-measures

```python
# evaluation/run_trec_benchmark.py (continued)
import ir_measures
from ir_measures import nDCG, P, Recall, AP, RR, SetP, SetR, SetF


def evaluate_trec_run(
    qrels_path: str,
    run_path: str,
) -> dict:
    """Evaluate a TREC run using ir-measures.

    Target metrics:
    - Recall@50 >= 0.75
    - NDCG@10 >= 0.60
    - P@10 (informational)
    """
    qrels = list(ir_measures.read_trec_qrels(qrels_path))
    run = list(ir_measures.read_trec_run(run_path))

    # Target metrics
    measures = [
        nDCG@10,      # Target >= 0.60
        Recall@50,    # Target >= 0.75
        P@10,         # Precision at 10
        AP,           # Mean Average Precision
        RR,           # Reciprocal Rank
        nDCG@20,      # Additional depth
        Recall@100,   # Extended recall
    ]

    # Aggregate metrics
    aggregate = ir_measures.calc_aggregate(measures, qrels, run)

    # Per-query metrics
    per_query = {}
    for metric in ir_measures.iter_calc(measures, qrels, run):
        qid = metric.query_id
        if qid not in per_query:
            per_query[qid] = {}
        per_query[qid][str(metric.measure)] = metric.value

    return {
        "aggregate": {str(k): v for k, v in aggregate.items()},
        "per_query": per_query,
        "pass_fail": {
            "ndcg@10": aggregate.get(nDCG@10, 0) >= 0.60,
            "recall@50": aggregate.get(Recall@50, 0) >= 0.75,
        }
    }


def evaluate_with_eligibility_levels(
    qrels_path: str,
    run_path: str,
) -> dict:
    """Evaluate with TREC CT graded relevance (0=NR, 1=Excluded, 2=Eligible).

    Uses rel=2 for strict eligible-only evaluation.
    """
    qrels = list(ir_measures.read_trec_qrels(qrels_path))
    run = list(ir_measures.read_trec_run(run_path))

    # Standard evaluation (relevance >= 1)
    standard_measures = [nDCG@10, Recall@50, P@10]
    standard = ir_measures.calc_aggregate(standard_measures, qrels, run)

    # Strict evaluation (only eligible = relevance 2)
    strict_measures = [
        AP(rel=2),
        P(rel=2)@10,
        Recall(rel=2)@50,
    ]
    strict = ir_measures.calc_aggregate(strict_measures, qrels, run)

    return {
        "standard": {str(k): v for k, v in standard.items()},
        "strict_eligible_only": {str(k): v for k, v in strict.items()},
    }
```

### 4.6 Alternative In-Memory qrels/run Format with ir-measures

```python
import ir_measures
from ir_measures import nDCG, Recall, P, AP


def evaluate_from_dicts(
    qrels_dict: dict[str, dict[str, int]],
    run_dict: dict[str, list[tuple[str, float]]],
) -> dict:
    """Evaluate using Python dict format (no files needed).

    Args:
        qrels_dict: {query_id: {doc_id: relevance}}
        run_dict: {query_id: [(doc_id, score), ...]}
    """
    # Convert to ir-measures format
    qrels = [
        ir_measures.Qrel(qid, did, rel)
        for qid, docs in qrels_dict.items()
        for did, rel in docs.items()
    ]
    run = [
        ir_measures.ScoredDoc(qid, did, score)
        for qid, docs in run_dict.items()
        for did, score in docs
    ]

    measures = [nDCG@10, Recall@50, P@10, AP]
    aggregate = ir_measures.calc_aggregate(measures, qrels, run)
    return {str(k): v for k, v in aggregate.items()}
```

---

## 5. MedGemma Extraction Evaluation

### 5.1 Annotated Dataset Design

```python
# evaluation/extraction_eval.py
from dataclasses import dataclass
from typing import Optional


@dataclass
class AnnotatedField:
    """A single annotated field with ground truth and extraction result."""
    field_name: str                # e.g., "biomarkers.egfr"
    ground_truth: Optional[str]    # From Synthea profile (gold standard)
    extracted: Optional[str]       # From MedGemma extraction
    evidence_span: Optional[str]   # Text span in source document
    source_page: Optional[int]     # Page number in PDF


@dataclass
class ExtractionAnnotation:
    """Complete annotation for one patient's extraction."""
    patient_id: str
    fields: list[AnnotatedField]
    noise_level: str       # "clean", "mild", "moderate", "severe"
    document_type: str     # "clinical_letter", "pathology_report", etc.
```

**Annotated dataset structure:**

```json
{
  "patient_id": "synth-001",
  "noise_level": "mild",
  "document_type": "clinical_letter",
  "fields": [
    {
      "field_name": "demographics.name",
      "ground_truth": "John Smith",
      "extracted": "John Smith",
      "correct": true
    },
    {
      "field_name": "diagnosis.stage",
      "ground_truth": "Stage IIIA",
      "extracted": "Stage 3A",
      "correct": true,
      "note": "Equivalent representation"
    },
    {
      "field_name": "biomarkers.egfr",
      "ground_truth": "Exon 19 deletion",
      "extracted": "EGFR positive",
      "correct": false,
      "note": "Partial extraction - missing specific mutation"
    }
  ]
}
```

### 5.2 Field-Level F1 Computation

```python
# evaluation/extraction_eval.py
from sklearn.metrics import (
    f1_score, precision_score, recall_score,
    classification_report, confusion_matrix
)
import numpy as np

# All extractable fields
EXTRACTION_FIELDS = [
    "demographics.name",
    "demographics.sex",
    "demographics.date_of_birth",
    "demographics.age",
    "diagnosis.primary",
    "diagnosis.stage",
    "diagnosis.histology",
    "biomarkers.egfr",
    "biomarkers.alk",
    "biomarkers.pdl1_tps",
    "biomarkers.kras",
    "biomarkers.ros1",
    "labs.wbc",
    "labs.hemoglobin",
    "labs.platelets",
    "labs.creatinine",
    "labs.alt",
    "labs.ast",
    "treatments.current_regimen",
    "performance_status.ecog",
]


def compute_field_level_f1(
    annotations: list[dict],
) -> dict:
    """Compute field-level F1, precision, recall.

    Each annotated field is treated as one binary sample:
    - TP: ground_truth exists AND the extraction is marked correct
    - FN: ground_truth exists BUT extracted is None or a mismatch
    - FP: ground_truth is None BUT a value was extracted anyway
    - TN: ground_truth is None AND nothing was extracted

    Args:
        annotations: List of patient annotation dicts

    Returns:
        Per-field and aggregate metrics
    """
    field_metrics = {}

    for field_name in EXTRACTION_FIELDS:
        y_true = []  # 1 if the field has a ground truth value
        y_pred = []  # 1 if the extractor is credited with a value for the field

        for ann in annotations:
            fields = {f["field_name"]: f for f in ann["fields"]}
            if field_name in fields:
                t, p = _binary_labels(fields[field_name])
                y_true.append(t)
                y_pred.append(p)

        if len(y_true) > 0:
            precision = precision_score(y_true, y_pred, zero_division=0)
            recall = recall_score(y_true, y_pred, zero_division=0)
            f1 = f1_score(y_true, y_pred, zero_division=0)
            field_metrics[field_name] = {
                "precision": round(float(precision), 4),
                "recall": round(float(recall), 4),
                "f1": round(float(f1), 4),
                "support": sum(y_true),
            }

    # Aggregate metrics over every annotated field
    all_y_true = []
    all_y_pred = []
    for ann in annotations:
        for f in ann["fields"]:
            t, p = _binary_labels(f)
            all_y_true.append(t)
            all_y_pred.append(p)

    # Guard the empty cases so that missing data yields 0.0 rather than NaN
    micro_f1 = (
        float(f1_score(all_y_true, all_y_pred, zero_division=0))
        if all_y_true else 0.0
    )
    macro_f1 = (
        float(np.mean([m["f1"] for m in field_metrics.values()]))
        if field_metrics else 0.0
    )

    return {
        "per_field": field_metrics,
        "micro_f1": round(micro_f1, 4),
        "macro_f1": round(macro_f1, 4),
        "total_fields": len(all_y_true),
        # Plain bool (not numpy.bool_) so that `result["pass"] is True` works in tests
        "pass": bool(micro_f1 >= 0.85),  # Target: F1 >= 0.85
    }


def _binary_labels(f: dict) -> tuple[int, int]:
    """Map one annotated field dict to a (y_true, y_pred) pair."""
    has_gt = f["ground_truth"] is not None
    is_correct = f.get("correct", False)
    # A value extracted where no ground truth exists counts as a false positive
    hallucinated = not has_gt and f.get("extracted") is not None
    return int(has_gt), int(is_correct or hallucinated)


def compute_extraction_report(annotations: list[dict]) -> str:
    """Generate a scikit-learn classification_report style output."""
    all_y_true = []
    all_y_pred = []

    for field_name in EXTRACTION_FIELDS:
        for ann in annotations:
            fields = {f["field_name"]: f for f in ann["fields"]}
            if field_name in fields:
                t, p = _binary_labels(fields[field_name])
                all_y_true.append(t)
                all_y_pred.append(p)

    # labels=[0, 1] keeps target_names valid even if only one class occurs
    return classification_report(
        all_y_true, all_y_pred,
        labels=[0, 1],
        target_names=["absent", "present/correct"],
        digits=4,
        zero_division=0,
    )


def compare_with_baseline(
    medgemma_annotations: list[dict],
    gemini_only_annotations: list[dict],
) -> dict:
    """Compare MedGemma extraction vs Gemini-only baseline."""
    medgemma_metrics = compute_field_level_f1(medgemma_annotations)
    gemini_metrics = compute_field_level_f1(gemini_only_annotations)

    comparison = {}
    for field_name in EXTRACTION_FIELDS:
        mg = medgemma_metrics["per_field"].get(field_name, {})
        gm = gemini_metrics["per_field"].get(field_name, {})
        comparison[field_name] = {
            "medgemma_f1": mg.get("f1", 0),
            "gemini_f1": gm.get("f1", 0),
            "delta": round(mg.get("f1", 0) - gm.get("f1", 0), 4),
        }

    return {
        "per_field_comparison": comparison,
        "medgemma_overall_f1": medgemma_metrics["micro_f1"],
        "gemini_overall_f1": gemini_metrics["micro_f1"],
        "improvement": round(
            medgemma_metrics["micro_f1"] - gemini_metrics["micro_f1"], 4
        ),
    }
```

### 5.3 Impact of Noise Level on Extraction Performance

```python
def analyze_noise_impact(annotations: list[dict]) -> dict:
    """Analyze how noise level affects extraction F1."""
    # Group annotations by their noise level
    by_noise = {}
    for ann in annotations:
        level = ann["noise_level"]
        if level not in by_noise:
            by_noise[level] = []
        by_noise[level].append(ann)

    results = {}
    for level, level_anns in by_noise.items():
        metrics = compute_field_level_f1(level_anns)
        results[level] = {
            "micro_f1": metrics["micro_f1"],
            "macro_f1": metrics["macro_f1"],
            "n_patients": len(level_anns),
        }

    return results
```

---

## 6. End-to-End Evaluation Pipeline

### 6.1 Criterion Decision Accuracy

```python
# evaluation/criterion_eval.py


def compute_criterion_accuracy(
    predictions: list[dict],
    ground_truth: list[dict],
) -> dict:
    """Compute criterion-level decision accuracy.
    Each prediction/ground_truth entry:
    {
        "patient_id": str,
        "trial_id": str,
        "criteria": [
            {"criterion_id": str, "decision": "met"|"not_met"|"unknown", "evidence": str}
        ]
    }

    Target: >= 0.85
    """
    total = 0
    correct = 0
    by_decision_type = {
        "met": {"tp": 0, "total": 0},
        "not_met": {"tp": 0, "total": 0},
        "unknown": {"tp": 0, "total": 0},
    }

    # predictions and ground_truth must be aligned pairwise by (patient_id, trial_id)
    for pred, gt in zip(predictions, ground_truth):
        assert pred["patient_id"] == gt["patient_id"]
        assert pred["trial_id"] == gt["trial_id"]

        gt_map = {c["criterion_id"]: c["decision"] for c in gt["criteria"]}

        for criterion in pred["criteria"]:
            cid = criterion["criterion_id"]
            if cid in gt_map:
                total += 1
                gt_decision = gt_map[cid]
                pred_decision = criterion["decision"]

                by_decision_type[gt_decision]["total"] += 1

                if pred_decision == gt_decision:
                    correct += 1
                    by_decision_type[gt_decision]["tp"] += 1

    accuracy = correct / total if total > 0 else 0.0

    return {
        "overall_accuracy": round(accuracy, 4),
        "total_criteria": total,
        "correct": correct,
        "pass": accuracy >= 0.85,
        "by_decision_type": {
            k: {
                "accuracy": round(v["tp"] / v["total"], 4) if v["total"] > 0 else 0,
                "support": v["total"],
            }
            for k, v in by_decision_type.items()
        },
    }
```

### 6.2 Latency Benchmark

```python
# evaluation/latency_cost_tracker.py
import time
import json
from dataclasses import dataclass, field, asdict
from typing import Optional
from contextlib import contextmanager


@dataclass
class APICallRecord:
    """Record of a single API call."""
    service: str          # "medgemma", "gemini", "clinicaltrials_mcp"
    operation: str        # "extract", "search", "evaluate_criterion"
    latency_ms: float
    input_tokens: int = 0
    output_tokens: int = 0
    cost_usd: float = 0.0
    timestamp: str = ""


@dataclass
class SessionMetrics:
    """Aggregate metrics for a patient matching session."""
    patient_id: str
    total_latency_ms: float = 0.0
    total_cost_usd: float = 0.0
    api_calls: list[APICallRecord] = field(default_factory=list)

    @property
    def total_latency_s(self) -> float:
        return self.total_latency_ms / 1000.0

    @property
    def pass_latency(self) -> bool:
        """Target: < 15s per session."""
        return self.total_latency_s < 15.0

    @property
    def pass_cost(self) -> bool:
        """Target: < $0.50 per session."""
        return self.total_cost_usd < 0.50


class LatencyCostTracker:
    """Track latency and cost across API calls."""

    # Pricing per 1M tokens (approximate)
    PRICING = {
        "medgemma": {"input": 0.0, "output": 0.0},             # Self-hosted
        "gemini": {"input": 1.25, "output": 5.00},             # Gemini Pro
        "clinicaltrials_mcp": {"input": 0.0, "output": 0.0},   # Free API
    }

    def __init__(self):
        self.sessions: list[SessionMetrics] = []
        self._current_session: Optional[SessionMetrics] = None

    def start_session(self, patient_id: str):
        self._current_session = SessionMetrics(patient_id=patient_id)

    def end_session(self) -> Optional[SessionMetrics]:
        """Close the current session; returns None if no session was started."""
        session = self._current_session
        if session:
            session.total_latency_ms = sum(c.latency_ms for c in session.api_calls)
            session.total_cost_usd = sum(c.cost_usd for c in session.api_calls)
            self.sessions.append(session)
        self._current_session = None
        return session

    @contextmanager
    def track_call(self, service: str, operation: str):
        """Context manager to track an API call."""
        start = time.monotonic()
        record = APICallRecord(service=service, operation=operation, latency_ms=0)
        try:
            yield record
        finally:
            record.latency_ms = (time.monotonic() - start) * 1000
            # Compute cost from the token counts the caller set on the record
            pricing = self.PRICING.get(service, {"input": 0, "output": 0})
            record.cost_usd = (
                record.input_tokens * pricing["input"] / 1_000_000
                + record.output_tokens * pricing["output"] / 1_000_000
            )
            if self._current_session:
                self._current_session.api_calls.append(record)

    def summary(self) -> dict:
        """Generate aggregate summary across all sessions."""
        if not self.sessions:
            return {}

        latencies = [s.total_latency_s for s in self.sessions]
        costs = [s.total_cost_usd for s in self.sessions]

        return {
            "n_sessions": len(self.sessions),
            "latency": {
                "mean_s": round(sum(latencies) / len(latencies), 2),
                "p50_s": round(sorted(latencies)[len(latencies) // 2], 2),
                "p95_s":
                    round(sorted(latencies)[int(len(latencies) * 0.95)], 2),
                "max_s": round(max(latencies), 2),
                "pass_rate": round(
                    sum(1 for s in self.sessions if s.pass_latency) / len(self.sessions), 4
                ),
            },
            "cost": {
                "mean_usd": round(sum(costs) / len(costs), 4),
                "total_usd": round(sum(costs), 4),
                "max_usd": round(max(costs), 4),
                "pass_rate": round(
                    sum(1 for s in self.sessions if s.pass_cost) / len(self.sessions), 4
                ),
            },
            "targets": {
                "latency_pass": all(s.pass_latency for s in self.sessions),
                "cost_pass": all(s.pass_cost for s in self.sessions),
            },
        }
```

---

## 7. TDD Test Cases

### 7.1 Synthea Data Validation Tests

```python
# tests/test_synthea_data.py
import pytest
import json
from pathlib import Path

from data.generate_synthetic_patients import parse_fhir_bundle

# Expected FHIR resource types
REQUIRED_RESOURCE_TYPES = {"Patient", "Condition", "Observation", "Encounter"}

# The fhir_file, output_dir and all_profiles fixtures are not defined in this
# file; they are expected to be provided elsewhere (e.g. a conftest.py).


class TestSyntheaDataValidation:
    """Validate Synthea FHIR output for TrialPath requirements."""

    def test_fhir_bundle_is_valid_json(self, fhir_file):
        """Bundle must be valid JSON."""
        with open(fhir_file) as f:
            data = json.load(f)
        assert data["resourceType"] == "Bundle"
        assert "entry" in data

    def test_bundle_contains_required_resources(self, fhir_file):
        """Bundle must contain Patient, Condition, Observation, Encounter."""
        with open(fhir_file) as f:
            bundle = json.load(f)
        resource_types = {
            e["resource"]["resourceType"] for e in bundle["entry"]
        }
        for rt in REQUIRED_RESOURCE_TYPES:
            assert rt in resource_types, f"Missing {rt} resource"

    def test_patient_has_demographics(self, fhir_file):
        """Patient resource must have name, gender, birthDate."""
        with open(fhir_file) as f:
            bundle = json.load(f)
        patients = [
            e["resource"] for e in bundle["entry"]
            if e["resource"]["resourceType"] == "Patient"
        ]
        assert len(patients) == 1
        patient = patients[0]
        assert "name" in patient
        assert "gender" in patient
        assert "birthDate" in patient

    def test_lung_cancer_condition_present(self, fhir_file):
        """At least one Condition must be NSCLC or lung cancer."""
        with open(fhir_file) as f:
            bundle = json.load(f)
        conditions = [
            e["resource"] for e in bundle["entry"]
            if e["resource"]["resourceType"] == "Condition"
        ]
        lung_cancer_codes = {"254637007", "254632001", "162573006"}
        has_lung_cancer = any(
            c.get("code") in lung_cancer_codes
            for cond in conditions
            for c in cond.get("code", {}).get("coding", [])
        )
        assert has_lung_cancer, "No lung cancer Condition found"

    def test_patient_profile_conversion(self, fhir_file):
        """FHIR Bundle must convert to valid PatientProfile."""
        profile = parse_fhir_bundle(Path(fhir_file))
        assert profile.patient_id != ""
        assert profile.demographics.name != ""
        assert profile.demographics.sex in ("male", "female")
        assert profile.diagnosis.primary != ""

    def test_batch_generation_produces_500_patients(self, output_dir):
        """Batch generation must produce at least 500 FHIR files."""
        fhir_files = list(Path(output_dir).glob("*.json"))
        assert len(fhir_files) >= 500

    def test_nsclc_ratio(self, all_profiles):
        """~85% of lung cancer patients should be NSCLC."""
        assert all_profiles, "No profiles loaded"
        nsclc_count = sum(
            1 for p in all_profiles
            if "non-small cell" in p.diagnosis.primary.lower()
            or "nsclc" in p.diagnosis.primary.lower()
        )
        ratio = nsclc_count / len(all_profiles)
        assert 0.70 <= ratio <= 0.95, f"NSCLC ratio {ratio} outside expected range"
```

### 7.2 PDF Generation Correctness Tests

```python
# tests/test_pdf_generation.py
import pytest
from pathlib import Path
from data.templates.clinical_letter import generate_clinical_letter
from data.templates.pathology_report import generate_pathology_report
from data.templates.lab_report import generate_lab_report


class TestPDFGeneration:
    """Test that PDF generation produces valid documents."""

    SAMPLE_PROFILE = {
        "patient_id": "test-001",
        "demographics": {
            "name": "Jane Doe",
            "sex": "female",
            "date_of_birth": "1960-05-15",
        },
        "diagnosis": {
            "primary": "Non-small cell lung cancer, adenocarcinoma",
            "stage": "Stage IIIA",
            "histology": "adenocarcinoma",
            "diagnosis_date": "2024-01-15",
        },
        "biomarkers": {
            "egfr": "Exon 19 deletion",
            "alk": "Negative",
            "pdl1_tps": "60%",
            "kras": None,
        },
        "labs": [
            {"name": "WBC", "value": 7.2, "unit": "10*3/uL",
             "date": "2024-01-10", "loinc_code": "6690-2"},
            {"name": "Hemoglobin", "value": 12.5, "unit": "g/dL",
             "date": "2024-01-10", "loinc_code": "718-7"},
        ],
        "treatments": [
            {"name": "Cisplatin", "type": "medication", "start_date": "2024-02-01"},
        ],
    }

    def test_clinical_letter_generates_pdf(self, tmp_path):
        """Clinical letter must generate a non-empty PDF file."""
        output = tmp_path / "letter.pdf"
        generate_clinical_letter(self.SAMPLE_PROFILE, str(output))
        assert output.exists()
        assert output.stat().st_size > 0

    def test_pathology_report_generates_pdf(self, tmp_path):
        """Pathology report must generate a non-empty PDF file."""
        output = tmp_path / "pathology.pdf"
        generate_pathology_report(self.SAMPLE_PROFILE, str(output))
        assert output.exists()
        assert output.stat().st_size > 0

    def test_lab_report_generates_pdf(self, tmp_path):
        """Lab report must generate a non-empty PDF file."""
        output = tmp_path / "lab.pdf"
        generate_lab_report(self.SAMPLE_PROFILE, str(output))
        assert output.exists()
        assert output.stat().st_size > 0

    def test_pdf_contains_patient_name(self, tmp_path):
        """Generated PDF must contain patient name (OCR-verifiable)."""
        output = tmp_path / "letter.pdf"
        generate_clinical_letter(self.SAMPLE_PROFILE, str(output))
        # Read the PDF text layer with pdfplumber
        import pdfplumber
        with pdfplumber.open(str(output)) as pdf:
            text = ""
            for page in pdf.pages:
                text += page.extract_text() or ""
        assert "Jane Doe" in text

    def test_pdf_contains_biomarkers(self, tmp_path):
        """Generated PDF must contain biomarker results."""
        output = tmp_path / "pathology.pdf"
        generate_pathology_report(self.SAMPLE_PROFILE, str(output))
        import pdfplumber
        with pdfplumber.open(str(output)) as pdf:
            text = ""
            for page in pdf.pages:
                text += page.extract_text() or ""
        assert "EGFR" in text
        assert "Exon 19" in text or "positive" in text.lower()

    def test_missing_biomarker_handled_gracefully(self, tmp_path):
        """PDF
        generation should not crash when biomarkers are None."""
        # A shallow copy is enough here: the "biomarkers" key is rebound, not mutated
        profile = self.SAMPLE_PROFILE.copy()
        profile["biomarkers"] = {
            "egfr": None, "alk": None, "pdl1_tps": None, "kras": None
        }
        output = tmp_path / "letter.pdf"
        generate_clinical_letter(profile, str(output))
        assert output.exists()
```

### 7.3 Noise Injection Validation Tests

```python
# tests/test_noise_injection.py
import pytest
from data.noise.noise_injector import NoiseInjector


class TestNoiseInjection:
    """Test noise injection produces expected results."""

    def test_clean_noise_no_changes(self):
        """Clean level should produce no changes."""
        injector = NoiseInjector(noise_level="clean", seed=42)
        text = "Patient has EGFR mutation positive"
        noisy, records = injector.inject_text_noise(text)
        assert noisy == text
        assert len(records) == 0

    def test_mild_noise_produces_some_changes(self):
        """Mild noise should produce some but limited changes."""
        injector = NoiseInjector(noise_level="mild", seed=42)
        # Use longer text to increase chance of noise
        text = "The patient is a 65 year old male with stage IIIA " * 10
        noisy, records = injector.inject_text_noise(text)
        # Smoke test only: whether any change occurs depends on the seed,
        # so this assertion just checks that the call returns a record list
        assert len(records) >= 0

    def test_severe_noise_produces_many_changes(self):
        """Severe noise should produce noticeable changes."""
        injector = NoiseInjector(noise_level="severe", seed=42)
        text = "The 50 year old patient has stage 1 NSCLC " * 20
        noisy, records = injector.inject_text_noise(text)
        assert noisy != text  # Should differ from original
        assert len(records) > 0

    def test_ocr_error_types_are_valid(self):
        """OCR errors should only substitute known character pairs."""
        injector = NoiseInjector(noise_level="severe", seed=42)
        text = "0123456789 OIBS" * 10
        _, records = injector.inject_text_noise(text)
        for r in records:
            if r["type"] == "ocr_error":
                assert r["original"] in NoiseInjector.OCR_ERROR_MAP
                assert r["replacement"] in NoiseInjector.OCR_ERROR_MAP[r["original"]]

    def test_missing_value_injection(self):
        """Missing value injection should remove some fields."""
        injector = NoiseInjector(noise_level="moderate", seed=42)
        profile = {
            "biomarkers": {"egfr": "positive", "alk": "negative",
                           "pdl1_tps": "60%", "kras": "negative", "ros1": "negative"},
            "diagnosis": {"stage": "IIIA", "histology": "adenocarcinoma"},
        }
        modified, removed = injector.inject_missing_values(profile)
        # At 10% rate with 7 fields, expect 0-3 removals; 7 is the hard upper bound
        assert len(removed) <= 7
        for field_path in removed:
            section, field_name = field_path.split(".")
            assert modified[section][field_name] is None

    def test_noise_is_deterministic_with_seed(self):
        """Same seed should produce identical results."""
        text = "Patient has stage IIIA non-small cell lung cancer"
        inj1 = NoiseInjector(noise_level="moderate", seed=123)
        inj2 = NoiseInjector(noise_level="moderate", seed=123)
        noisy1, _ = inj1.inject_text_noise(text)
        noisy2, _ = inj2.inject_text_noise(text)
        assert noisy1 == noisy2

    def test_different_seeds_produce_different_results(self):
        """Different seeds should generally produce different noise."""
        text = "The 50 year old patient has 10 biomarker tests 0 1 5 8" * 20
        inj1 = NoiseInjector(noise_level="severe", seed=1)
        inj2 = NoiseInjector(noise_level="severe", seed=999)
        noisy1, _ = inj1.inject_text_noise(text)
        noisy2, _ = inj2.inject_text_noise(text)
        # With severe noise on long text, different seeds should differ
        assert noisy1 != noisy2
```

### 7.4 TREC Evaluation Computation Tests

```python
# tests/test_trec_evaluation.py
import pytest
import ir_measures
from ir_measures import nDCG, Recall, P, AP

# Assumption: the run-format converter lives in evaluation/run_trec_benchmark.py;
# adjust the import path if it is defined elsewhere.
from evaluation.run_trec_benchmark import convert_trialpath_to_trec_run


class TestTRECEvaluation:
    """Test TREC evaluation metric computation."""

    @pytest.fixture
    def sample_qrels(self):
        """Sample qrels with known ground truth."""
        return [
            ir_measures.Qrel("q1", "d1", 2),  # eligible
            ir_measures.Qrel("q1", "d2", 1),  # excluded
            ir_measures.Qrel("q1", "d3", 0),  # not relevant
            ir_measures.Qrel("q1", "d4", 2),  # eligible
            ir_measures.Qrel("q1", "d5", 0),  # not relevant
        ]

    @pytest.fixture
    def perfect_run(self):
        """Run that ranks all relevant docs at top."""
        return [
            ir_measures.ScoredDoc("q1", "d1", 1.0),
            ir_measures.ScoredDoc("q1", "d4", 0.9),
            ir_measures.ScoredDoc("q1", "d2", 0.8),
            ir_measures.ScoredDoc("q1", "d3", 0.1),
            ir_measures.ScoredDoc("q1", "d5", 0.05),
        ]

    @pytest.fixture
    def worst_run(self):
        """Run that ranks relevant docs at bottom."""
        return [
            ir_measures.ScoredDoc("q1", "d3", 1.0),
            ir_measures.ScoredDoc("q1", "d5", 0.9),
            ir_measures.ScoredDoc("q1", "d2", 0.5),
            ir_measures.ScoredDoc("q1", "d4", 0.2),
            ir_measures.ScoredDoc("q1", "d1", 0.1),
        ]

    def test_perfect_ndcg_at_10(self, sample_qrels, perfect_run):
        """Perfect ranking should yield NDCG@10 = 1.0."""
        result = ir_measures.calc_aggregate([nDCG@10], sample_qrels, perfect_run)
        assert result[nDCG@10] == pytest.approx(1.0, abs=0.01)

    def test_worst_ndcg_lower(self, sample_qrels, perfect_run, worst_run):
        """Worst ranking should yield lower NDCG than perfect."""
        perfect = ir_measures.calc_aggregate([nDCG@10], sample_qrels, perfect_run)
        worst = ir_measures.calc_aggregate([nDCG@10], sample_qrels, worst_run)
        assert worst[nDCG@10] < perfect[nDCG@10]

    def test_recall_at_50_perfect(self, sample_qrels, perfect_run):
        """Perfect run should retrieve all relevant docs."""
        result = ir_measures.calc_aggregate([Recall@50], sample_qrels, perfect_run)
        assert result[Recall@50] == pytest.approx(1.0, abs=0.01)

    def test_empty_run_yields_zero(self, sample_qrels):
        """Empty run should yield 0 for all metrics."""
        empty_run = []
        result = ir_measures.calc_aggregate(
            [nDCG@10, Recall@50, P@10], sample_qrels, empty_run
        )
        assert result[nDCG@10] == 0.0
        assert result[Recall@50] == 0.0
        assert result[P@10] == 0.0

    def test_per_query_results(self, sample_qrels, perfect_run):
        """Per-query results should return one entry per query."""
        results = list(ir_measures.iter_calc(
            [nDCG@10], sample_qrels, perfect_run
        ))
        assert len(results) == 1  # Only q1
        assert results[0].query_id == "q1"

    def test_trec_run_format_conversion(self):
        """Test
        TrialPath results to TREC format conversion."""
        results = {
            "1": [
                {"nct_id": "NCT001", "score": 0.95},
                {"nct_id": "NCT002", "score": 0.80},
            ]
        }
        run_str = convert_trialpath_to_trec_run(results, "test-run")
        lines = run_str.strip().split("\n")
        assert len(lines) == 2
        assert "NCT001" in lines[0]
        # TREC run columns: query_id Q0 doc_id rank score run_tag
        assert "1" == lines[0].split()[3]  # rank 1
        assert "2" == lines[1].split()[3]  # rank 2

    def test_graded_relevance_evaluation(self, sample_qrels, perfect_run):
        """Test strict eligible-only evaluation (rel=2)."""
        strict = ir_measures.calc_aggregate(
            [AP(rel=2)], sample_qrels, perfect_run
        )
        assert strict[AP(rel=2)] > 0.0

    def test_qrels_dict_format(self):
        """Test evaluation from dict format."""
        qrels = {"q1": {"d1": 2, "d2": 1, "d3": 0}}
        run = [
            ir_measures.ScoredDoc("q1", "d1", 1.0),
            ir_measures.ScoredDoc("q1", "d2", 0.5),
            ir_measures.ScoredDoc("q1", "d3", 0.1),
        ]
        result = ir_measures.calc_aggregate([nDCG@10], qrels, run)
        assert nDCG@10 in result
```

### 7.5 F1 Computation Tests

```python
# tests/test_extraction_f1.py
import pytest
from evaluation.extraction_eval import compute_field_level_f1


class TestExtractionF1:
    """Test F1 computation for field-level extraction."""

    def test_perfect_extraction(self):
        """All fields correctly extracted should yield F1=1.0."""
        annotations = [{
            "patient_id": "p1",
            "noise_level": "clean",
            "document_type": "clinical_letter",
            "fields": [
                {"field_name": "demographics.name", "ground_truth": "John",
                 "extracted": "John", "correct": True},
                {"field_name": "demographics.sex", "ground_truth": "male",
                 "extracted": "male", "correct": True},
                {"field_name": "diagnosis.primary", "ground_truth": "NSCLC",
                 "extracted": "NSCLC", "correct": True},
                {"field_name": "biomarkers.egfr", "ground_truth": "positive",
                 "extracted": "positive", "correct": True},
            ]
        }]
        result = compute_field_level_f1(annotations)
        assert result["micro_f1"] == 1.0
        assert result["pass"] is True

    def test_zero_extraction(self):
        """No correct extractions should yield F1=0."""
        annotations = [{
            "patient_id": "p1",
            "noise_level": "clean",
            "document_type": "clinical_letter",
            "fields": [
                {"field_name": "demographics.name", "ground_truth": "John",
                 "extracted": "Jane", "correct": False},
                {"field_name": "diagnosis.primary", "ground_truth": "NSCLC",
                 "extracted": None, "correct": False},
            ]
        }]
        result = compute_field_level_f1(annotations)
        assert result["micro_f1"] == 0.0
        assert result["pass"] is False

    def test_partial_extraction(self):
        """Partial extraction should yield 0 < F1 < 1."""
        annotations = [{
            "patient_id": "p1",
            "noise_level": "mild",
            "document_type": "clinical_letter",
            "fields": [
                {"field_name": "demographics.name", "ground_truth": "John",
                 "extracted": "John", "correct": True},
                {"field_name": "diagnosis.primary", "ground_truth": "NSCLC",
                 "extracted": "lung ca", "correct": False},
                {"field_name": "biomarkers.egfr", "ground_truth": "positive",
                 "extracted": "positive", "correct": True},
                {"field_name": "biomarkers.alk", "ground_truth": "negative",
                 "extracted": None, "correct": False},
            ]
        }]
        result = compute_field_level_f1(annotations)
        assert 0.0 < result["micro_f1"] < 1.0

    def test_f1_threshold_boundary(self):
        """85 of 100 fields correct (recall 0.85, precision 1.0) should pass."""
        fields = []
        for i in range(85):
            fields.append({"field_name": f"field_{i}", "ground_truth": "val",
                           "extracted": "val", "correct": True})
        for i in range(15):
            fields.append({"field_name": f"field_miss_{i}", "ground_truth": "val",
                           "extracted": None, "correct": False})
        annotations = [{"patient_id": "p1", "noise_level": "clean",
                        "document_type": "test", "fields": fields}]
        result = compute_field_level_f1(annotations)
        # 85 TP and 15 FN give precision 1.0 and recall 0.85,
        # so micro F1 = 2 * 0.85 / 1.85, about 0.919, which clears the 0.85 target
        assert result["pass"] is True

    def test_empty_annotations(self):
        """Empty annotations should not crash."""
        result = compute_field_level_f1([])
        assert result["micro_f1"] == 0.0

    def test_none_ground_truth_not_counted(self):
        """Fields with None ground truth should be handled."""
        annotations = [{
            "patient_id": "p1",
            "noise_level": "clean",
            "document_type": "test",
            "fields": [
                {"field_name": "biomarkers.ros1", "ground_truth": None,
                 "extracted": None, "correct": False},
            ]
        }]
        result = compute_field_level_f1(annotations)
        # Should not crash, though metrics may be 0
        assert "micro_f1" in result
```

### 7.6 End-to-End Pipeline Tests

```python
# tests/test_e2e_pipeline.py
import pytest
from pathlib import Path


class TestE2EPipeline:
    """End-to-end tests for the complete data & evaluation pipeline."""

    def test_fhir_to_profile_to_pdf_roundtrip(self, sample_fhir_file, tmp_path):
        """FHIR → PatientProfile → PDF should complete without error."""
        from data.generate_synthetic_patients import parse_fhir_bundle
        from data.templates.clinical_letter import generate_clinical_letter
        from dataclasses import asdict

        # Step 1: Parse FHIR
        profile = parse_fhir_bundle(Path(sample_fhir_file))
        assert profile.patient_id != ""

        # Step 2: Generate PDF
        pdf_path = tmp_path / "test_roundtrip.pdf"
        generate_clinical_letter(asdict(profile), str(pdf_path))
        assert pdf_path.exists()
        assert pdf_path.stat().st_size > 1000  # Reasonable PDF size

    def test_noisy_pdf_pipeline(self, sample_profile, tmp_path):
        """Profile → Noisy PDF should inject noise and produce valid PDF."""
        import copy
        from data.templates.clinical_letter import generate_clinical_letter
        from data.noise.noise_injector import NoiseInjector

        injector = NoiseInjector(noise_level="moderate", seed=42)

        # Inject text noise into profile fields for PDF rendering.
        # Deep copy so the nested "diagnosis" dict of the fixture is not mutated.
        profile = copy.deepcopy(sample_profile)
        dx_text = profile["diagnosis"]["primary"]
        noisy_dx, records = injector.inject_text_noise(dx_text)
        profile["diagnosis"]["primary"] = noisy_dx

        pdf_path = tmp_path / "noisy.pdf"
        generate_clinical_letter(profile, str(pdf_path))
        assert pdf_path.exists()

    def test_trec_evaluation_pipeline(self, tmp_path):
        """Complete TREC evaluation from dicts should produce metrics."""
        import ir_measures
        from ir_measures import nDCG, Recall, P

        qrels = [
            ir_measures.Qrel("1", "NCT001", 2),
            ir_measures.Qrel("1", "NCT002", 1),
            ir_measures.Qrel("1", "NCT003", 0),
        ]
        run = [
            ir_measures.ScoredDoc("1", "NCT001", 0.9),
            ir_measures.ScoredDoc("1", "NCT002", 0.5),
            ir_measures.ScoredDoc("1", "NCT003", 0.1),
        ]

        result = ir_measures.calc_aggregate(
            [nDCG@10, Recall@50, P@10], qrels, run
        )
        assert nDCG@10 in result
        assert Recall@50 in result
        assert result[nDCG@10] > 0

    def test_latency_tracker_integration(self):
        """Latency tracker should record and summarize calls."""
        import time
        from evaluation.latency_cost_tracker import LatencyCostTracker

        tracker = LatencyCostTracker()
        tracker.start_session("test-patient")

        with tracker.track_call("gemini", "search_anchors") as record:
            time.sleep(0.01)  # Simulate API call
            record.input_tokens = 500
            record.output_tokens = 200

        session = tracker.end_session()
        assert session.total_latency_ms > 0
        assert len(session.api_calls) == 1

        summary = tracker.summary()
        assert summary["n_sessions"] == 1
        assert summary["latency"]["mean_s"] > 0
```

---

## 8. Appendix

### 8.1 Data Format Specification

#### PatientProfile v1 JSON Schema

```json
{
  "$schema": "http://json-schema.org/draft-07/schema#",
  "type": "object",
  "required": ["patient_id", "demographics", "diagnosis"],
  "properties": {
    "patient_id": {"type": "string"},
    "demographics": {
      "type": "object",
      "properties": {
        "name": {"type": "string"},
        "sex": {"type": "string", "enum": ["male", "female"]},
        "date_of_birth": {"type": "string", "format": "date"},
        "age": {"type": "integer"},
        "state": {"type": "string"}
      }
    },
    "diagnosis": {
      "type": "object",
      "properties": {
        "primary": {"type": "string"},
        "stage": {"type": ["string", "null"]},
        "histology": {"type": ["string", "null"]},
        "diagnosis_date": {"type": "string", "format": "date"}
      }
    },
    "biomarkers": {
      "type": "object",
      "properties": {
        "egfr": {"type": ["string", "null"]},
        "alk": {"type": ["string", "null"]},
        "pdl1_tps": {"type": ["string", "null"]},
        "kras": {"type": ["string", "null"]},
        "ros1": {"type": ["string", "null"]}
      }
    },
    "labs": {
      "type": "array",
      "items": {
        "type": "object",
        "properties": {
          "name": {"type": "string"},
          "value": {"type": "number"},
          "unit": {"type": "string"},
          "date": {"type": "string"},
          "loinc_code": {"type": "string"}
        }
      }
    },
    "treatments": {
      "type": "array",
      "items": {
        "type": "object",
        "properties": {
          "name": {"type": "string"},
          "type": {"type": "string", "enum": ["medication", "procedure", "radiation"]},
          "start_date": {"type": "string"},
          "end_date": {"type": ["string", "null"]}
        }
      }
    },
    "unknowns": {"type": "array", "items": {"type": "string"}},
    "evidence_spans": {"type": "array"}
  }
}
```

### 8.2 Tool API Reference

#### ir_datasets

| API | Description | Return type |
|-----|-------------|-------------|
| `ir_datasets.load("clinicaltrials/2021/trec-ct-2021")` | Load the TREC CT 2021 dataset | Dataset |
| `dataset.queries_iter()` | Iterate over topics | GenericQuery(query_id, text) |
| `dataset.qrels_iter()` | Iterate over qrels | TrecQrel(query_id, doc_id, relevance, iteration) |
| `dataset.docs_iter()` | Iterate over documents | ClinicalTrialsDoc(doc_id, title, condition, summary, detailed_description, eligibility) |

**Dataset IDs:**

- `clinicaltrials/2021/trec-ct-2021`: 75 queries, 35,832 qrels
- `clinicaltrials/2021/trec-ct-2022`: 50 queries
- `clinicaltrials/2021`: 376K documents (base collection)

#### ir-measures

| API | Description |
|-----|-------------|
| `ir_measures.calc_aggregate(measures, qrels, run)` | Compute aggregate metrics |
| `ir_measures.iter_calc(measures, qrels, run)` | Iterate over per-query metrics |
| `ir_measures.read_trec_qrels(path)` | Read a TREC qrels file |
| `ir_measures.read_trec_run(path)` | Read a TREC run file |
| `ir_measures.Qrel(qid, did, rel)` | Create a qrel record |
| `ir_measures.ScoredDoc(qid, did, score)` | Create a scored document record |

**Measure objects:**

- `nDCG@10`: Normalized DCG at cutoff 10
- `Recall@50`: Recall at cutoff 50
- `P@10`: Precision at cutoff 10
- `AP`: Average Precision
- `AP(rel=2)`: AP with minimum relevance 2
- `RR`: Reciprocal Rank

#### scikit-learn Evaluation

| API | Description |
|-----|-------------|
| `f1_score(y_true, y_pred, average=None)` | Per-class F1 |
| `f1_score(y_true, y_pred, average='micro')` | Global micro F1 |
| `f1_score(y_true, y_pred, average='macro')` | Unweighted mean of per-class F1 |
| `precision_score(y_true, y_pred)` | Precision |
| `recall_score(y_true, y_pred)` | Recall |
| `classification_report(y_true, y_pred)` | Full classification report |
| `confusion_matrix(y_true, y_pred)` | Confusion matrix |

#### Synthea CLI

| Argument | Description | Example |
|----------|-------------|---------|
| `-p N` | Generate N patients | `-p 500` |
| `-s SEED` | Random seed | `-s 42` |
| `-m MODULE` | Select a disease module | `-m lung_cancer` |
| `STATE` | Select a state | `Massachusetts` |
| `--exporter.fhir.export` | Enable FHIR R4 export | `=true` |
| `--exporter.pretty_print` | Pretty-print JSON output | `=true` |

#### ReportLab Core API

| Component | Description |
|-----------|-------------|
| `SimpleDocTemplate(path, pagesize=letter)` | Create a document template |
| `Paragraph(text, style)` | Paragraph flowable |
| `Table(data, colWidths)` | Table flowable |
| `TableStyle(commands)` | Table styling |
| `Spacer(width, height)` | Spacing flowable |
| `getSampleStyleSheet()` | Get the default style sheet |

#### Augraphy Degradation Pipeline

| Component | Description |
|-----------|-------------|
| `AugraphyPipeline(ink_phase, paper_phase, post_phase)` | Full degradation pipeline |
| `InkBleed(p=0.5)` | Ink bleed effect |
| `Letterpress(p=0.3)` | Letterpress printing effect |
| `LowInkPeriodicLines(p=0.3)` | Periodic low-ink lines |
| `DirtyDrum(p=0.3)` | Dirty printer drum effect |
| `SubtleNoise(p=0.5)` | Subtle noise |
| `Jpeg(p=0.5)` | JPEG compression artifacts |
| `Brightness(p=0.5)` | Brightness variation |

### 8.3 Python Dependencies

```
# requirements-data-eval.txt
ir-datasets>=0.5.6
ir-measures>=0.3.1
reportlab>=4.0
augraphy>=8.0
Pillow>=10.0
pdfplumber>=0.10
scikit-learn>=1.3
numpy>=1.24
pandas>=2.0
pdf2image>=1.16
```

### 8.4 Success Metrics Quick Reference

| Metric | Target | Evaluation tool | Data source |
|--------|--------|-----------------|-------------|
| MedGemma Extraction F1 | >= 0.85 | scikit-learn `f1_score` | Synthetic patients + Ground Truth |
| Trial Retrieval Recall@50 | >= 0.75 | ir-measures `Recall@50` | TREC CT 2021/2022 |
| Trial Ranking NDCG@10 | >= 0.60 | ir-measures `nDCG@10` | TREC CT 2021/2022 |
| Criterion Decision Accuracy | >= 0.85 | Custom accuracy | Annotated EligibilityLedger |
| Latency | < 15s | `LatencyCostTracker` | API call timing |
| Cost | < $0.50/session | `LatencyCostTracker` | Token counting |
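
The targets in the table above can be checked mechanically once each evaluator has produced its headline number. The following is a minimal sketch, not part of the TrialPath codebase: the `TARGETS` mapping, its metric key names, and `check_targets` are hypothetical names chosen for illustration, and only the threshold values are taken from the table.

```python
import operator

# Hypothetical metric keys; comparison operators and thresholds mirror the
# success-metrics table (>= for quality metrics, < for latency and cost).
TARGETS = {
    "extraction_f1": (operator.ge, 0.85),
    "recall_at_50": (operator.ge, 0.75),
    "ndcg_at_10": (operator.ge, 0.60),
    "criterion_accuracy": (operator.ge, 0.85),
    "latency_s": (operator.lt, 15.0),
    "cost_usd": (operator.lt, 0.50),
}


def check_targets(measured: dict[str, float]) -> dict[str, bool]:
    """Return pass/fail per metric; a metric missing from `measured` fails."""
    results = {}
    for name, (compare, threshold) in TARGETS.items():
        value = measured.get(name)
        results[name] = value is not None and compare(value, threshold)
    return results


if __name__ == "__main__":
    # Illustrative numbers only, not measured results
    example = {
        "extraction_f1": 0.88,
        "recall_at_50": 0.71,
        "ndcg_at_10": 0.62,
        "criterion_accuracy": 0.85,
        "latency_s": 12.4,
        "cost_usd": 0.31,
    }
    for name, ok in check_targets(example).items():
        print(f"{name}: {'PASS' if ok else 'FAIL'}")
```

Treating a missing metric as a failure is a deliberate choice in this sketch: an evaluation stage that silently did not run should not be reported as meeting its target.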