| import os |
| import re |
| import gradio as gr |
| import numpy as np |
| import faiss |
| from pypdf import PdfReader |
| from openai import OpenAI |
|
|
| |
# --- OpenAI credentials ---------------------------------------------------
# Read the API key from the environment and fail fast at import time, so a
# misconfigured Hugging Face Space shows a clear error instead of failing on
# the first query.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "").strip()
if not OPENAI_API_KEY:
    raise RuntimeError("OPENAI_API_KEY 未設定。請在 Hugging Face Space 的 Secrets 設定 OPENAI_API_KEY。")


# Single shared OpenAI client, used for both embeddings and chat completions.
client = OpenAI(api_key=OPENAI_API_KEY)


# Model names, overridable via environment variables.
EMBED_MODEL = os.getenv("EMBED_MODEL", "text-embedding-3-small").strip()
LLM_MODEL = os.getenv("LLM_MODEL", "gpt-5.2").strip()


# --- Chunking parameters --------------------------------------------------
# Minimum kept-paragraph length per source (NCCN pages tend to be terser),
# the maximum paragraph length before splitting, and the split window size.
MIN_LEN_NCCN = int(os.getenv("MIN_LEN_NCCN", "60"))
MIN_LEN_LOCAL = int(os.getenv("MIN_LEN_LOCAL", "120"))
MAX_PARA_LEN = int(os.getenv("MAX_PARA_LEN", "1800"))
SPLIT_STEP = int(os.getenv("SPLIT_STEP", "1200"))


# --- Retrieval parameters -------------------------------------------------
# K_VEC: vector-search candidates per facet query;
# K_FINAL_PER_FACET: candidates kept after keyword re-ranking;
# K_MAX_CONTEXT: overall cap on chunks sent to the LLM.
K_VEC = int(os.getenv("K_VEC", "40"))
K_FINAL_PER_FACET = int(os.getenv("K_FINAL_PER_FACET", "6"))
K_MAX_CONTEXT = int(os.getenv("K_MAX_CONTEXT", "18"))


# --- Global in-memory knowledge base --------------------------------------
# chunks: list of {"id", "source", "text"} dicts; index: FAISS index over
# their embeddings (None until built); allowed_citations: the "source|id"
# keys the LLM is allowed to cite.
chunks = []
index = None
allowed_citations = set()


# Page-furniture patterns (copyright lines, print timestamps, publisher
# URLs) filtered out of extracted PDF text before indexing.
BOILERPLATE_PATTERNS = [
    r"Copyright\s*©",
    r"All Rights Reserved",
    r"NCCN Guidelines Version",
    r"Printed by",
    r"\b\d{1,2}/\d{1,2}/\d{4}\b",
    r"\b\d{1,2}:\d{2}:\d{2}\s*(AM|PM)\b",
    r"www\.springer\.com",
    r"Springer International Publishing",
]
BOILERPLATE_RE = re.compile("|".join(BOILERPLATE_PATTERNS), re.IGNORECASE)


# Citation marker emitted by the LLM: 〔source|ID〕, e.g. 〔NCCN|NCCN-001〕.
CIT_RE = re.compile(r"〔([^|〔〕]+)\|([A-Z]+-\d{3})〕")
|
|
|
def normalize_text(text: str) -> str:
    """Normalize whitespace in *text*.

    Carriage returns become newlines, runs of spaces/tabs collapse to a
    single space, three-or-more consecutive newlines collapse to exactly
    two, and surrounding whitespace is stripped. None/empty input yields "".
    """
    if not text:
        return ""
    result = text.replace("\r", "\n")
    for pattern, repl in ((r"[ \t]+", " "), (r"\n{3,}", "\n\n")):
        result = re.sub(pattern, repl, result)
    return result.strip()
|
|
|
|
def is_boilerplate(p: str) -> bool:
    """Return True for paragraphs that are page furniture, not content.

    A paragraph counts as boilerplate when it is empty, matches one of the
    known copyright/header patterns in BOILERPLATE_RE, or is a short
    (< 90 chars) line mentioning "©" or "NCCN" (typical running headers).
    """
    if not p:
        return True
    if BOILERPLATE_RE.search(p) is not None:
        return True
    is_short = len(p) < 90
    return is_short and ("©" in p or "NCCN" in p)
|
|
|
|
def load_pdf_text(file) -> str:
    """Extract and concatenate the text of every page of an uploaded PDF.

    *file* is a Gradio file object exposing a ``.name`` path on disk.
    Pages yielding no extractable text (e.g. scanned images) are skipped.
    """
    reader = PdfReader(file.name)
    page_texts = (page.extract_text() for page in reader.pages)
    return "\n".join(t for t in page_texts if t).strip()
|
|
|
|
def split_into_paragraphs(text: str, source: str):
    """Split raw document text into cleaned, indexable paragraphs.

    Paragraphs are delimited by blank lines. Boilerplate is dropped,
    over-long paragraphs (> MAX_PARA_LEN chars) are cut into
    SPLIT_STEP-sized windows, and anything shorter than the per-source
    minimum (MIN_LEN_NCCN for "NCCN", MIN_LEN_LOCAL otherwise) is
    discarded in a final pass.
    """
    text = normalize_text(text)
    if not text:
        return []

    refined = []
    for para in (p.strip() for p in text.split("\n\n")):
        if not para or is_boilerplate(para):
            continue

        if len(para) > MAX_PARA_LEN:
            # Hard-wrap oversized paragraphs into fixed-size windows.
            windows = (
                para[start:start + SPLIT_STEP].strip()
                for start in range(0, len(para), SPLIT_STEP)
            )
            refined.extend(w for w in windows if w and not is_boilerplate(w))
        else:
            refined.append(para)

    min_len = MIN_LEN_NCCN if source == "NCCN" else MIN_LEN_LOCAL
    return [p for p in refined if len(p) >= min_len and not is_boilerplate(p)]
|
|
|
|
def build_faiss_index(all_chunks):
    """Embed every chunk and (re)build the global FAISS index.

    Side effects: rebinds the module-level ``index`` to a fresh
    ``IndexFlatL2`` over the chunk embeddings (row order matches
    ``all_chunks``) and refreshes ``allowed_citations`` with every valid
    "source|id" citation key.

    Fix: embeddings are now requested in batches rather than one giant
    call — the OpenAI embeddings endpoint caps the number of inputs per
    request (2048) and the total tokens, so a full guideline PDF could
    previously fail the whole build outright.
    """
    global index, allowed_citations

    texts = [c["text"] for c in all_chunks]

    # 256 inputs per request stays comfortably under the API's per-request
    # limits; responses preserve input order, so rows line up with chunks.
    batch_size = 256
    vec_rows = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        emb = client.embeddings.create(model=EMBED_MODEL, input=batch)
        vec_rows.extend(e.embedding for e in emb.data)
    vecs = np.array(vec_rows, dtype="float32")

    dim = vecs.shape[1]
    index = faiss.IndexFlatL2(dim)
    index.add(vecs)

    allowed_citations = {f"{c['source']}|{c['id']}" for c in all_chunks}
|
|
|
|
def keyword_score(text: str, keywords):
    """Count how many of *keywords* (assumed lowercase) occur as
    substrings of *text*, matching case-insensitively."""
    lowered = text.lower()
    return sum(kw in lowered for kw in keywords)
|
|
|
|
def retrieve(query: str, keywords, k_vec=K_VEC, k_final=K_FINAL_PER_FACET):
    """Two-stage retrieval: FAISS vector search, then keyword re-ranking.

    Embeds *query*, pulls the *k_vec* nearest chunks from the global
    index, drops boilerplate, ranks candidates by (keyword hits, text
    length) descending, and returns the top *k_final* chunk dicts.
    Returns [] when no index has been built.

    Fix: FAISS pads its result with -1 labels when *k_vec* exceeds the
    number of indexed vectors; ``chunks[-1]`` was previously picked up
    silently (and repeatedly). Padding slots are now skipped.
    """
    if index is None:
        return []

    q_emb = client.embeddings.create(model=EMBED_MODEL, input=query)
    q_vec = np.array([q_emb.data[0].embedding], dtype="float32")

    _, idxs = index.search(q_vec, k_vec)
    # Skip -1 padding entries returned when the index holds < k_vec vectors.
    cands = [chunks[i] for i in idxs[0] if i >= 0]

    scored = []
    for c in cands:
        if is_boilerplate(c["text"]):
            continue
        s = keyword_score(c["text"], keywords)
        scored.append((s, len(c["text"]), c))

    # Sort key uses only the first two tuple slots, so the chunk dicts in
    # slot 3 are never compared.
    scored.sort(key=lambda x: (x[0], x[1]), reverse=True)
    return [x[2] for x in scored[:k_final]]
|
|
|
|
def uniq_by_key(items):
    """Deduplicate chunk dicts by their "source|id" key, keeping the first
    occurrence of each key in original order."""
    first_seen = {}
    for chunk in items:
        first_seen.setdefault(f"{chunk['source']}|{chunk['id']}", chunk)
    return list(first_seen.values())
|
|
|
|
def make_evidence_preview(selected_chunks):
    """Render one markdown bullet per retrieved chunk: "[source | id]"
    followed by the first 220 characters (newlines flattened) and " ..."."""
    def bullet(c):
        snippet = c["text"][:220].replace("\n", " ")
        return f"- [{c['source']} | {c['id']}] {snippet} ..."

    return "\n".join(bullet(c) for c in selected_chunks)
|
|
|
|
def normalize_headers(text: str) -> str:
    """Force each 【MDT摘要 n/3】 header marker onto its own line.

    When a header line carries trailing content after the marker, the
    marker and the remainder are emitted as two separate lines; all other
    lines pass through unchanged (right-stripped).
    """
    header_re = re.compile(r"^(【MDT摘要\s*[123]/3】)\s*(.+)$")
    out = []
    for raw in text.splitlines():
        line = raw.rstrip()
        match = header_re.match(line)
        if match is None:
            out.append(line)
            continue
        out.append(match.group(1))
        remainder = match.group(2).strip()
        if remainder:
            out.append(remainder)
    return "\n".join(out).strip()
|
|
|
|
def sanitize_answer(answer: str) -> str:
    """Enforce the citation contract on the raw LLM answer.

    Rules:
    - Every non-header line must contain at least one 〔source|ID〕
      citation, and every cited key must exist in ``allowed_citations``;
      otherwise the whole line is replaced by the standard NA sentence.
    - Repeated NA sentences inside a surviving line are collapsed to one.
    """
    # Put each 【MDT摘要 n/3】 marker on its own line first.
    answer = normalize_headers(answer)

    out_lines = []
    for raw_line in answer.splitlines():
        line = raw_line.strip()
        if not line:
            # Preserve blank separator lines.
            out_lines.append("")
            continue

        # Section headers are exempt from the citation requirement.
        if line.startswith("【MDT摘要") or line.startswith("【治療路徑流程圖】"):
            out_lines.append(line)
            continue

        # A content line with no citation marker at all becomes NA.
        cites = CIT_RE.findall(line)
        if not cites:
            out_lines.append("此版檢索未涵蓋,故不推測〔NA〕")
            continue

        # Every cited "source|id" must belong to the indexed corpus.
        ok = True
        for src, cid in cites:
            key = f"{src}|{cid}"
            if key not in allowed_citations:
                ok = False
                break

        if not ok:
            out_lines.append("此版檢索未涵蓋,故不推測〔NA〕")
        else:
            # Collapse accidental NA repetitions inside a kept line.
            line = re.sub(
                r"(此版檢索未涵蓋,故不推測〔NA〕\s*){2,}",
                "此版檢索未涵蓋,故不推測〔NA〕 ",
                line
            ).strip()
            out_lines.append(line)

    return "\n".join(out_lines).strip()
|
|
|
|
def call_openai_llm(system_prompt: str, user_prompt: str) -> str:
    """Send one system+user exchange to the chat model and return the
    stripped reply text.

    Fix: ``message.content`` is Optional in the OpenAI SDK (e.g. for
    refusals or tool-call responses); it is coalesced to "" before
    ``.strip()`` so callers never see an AttributeError.
    """
    resp = client.chat.completions.create(
        model=LLM_MODEL,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        temperature=0.1
    )
    return (resp.choices[0].message.content or "").strip()
|
|
|
|
| |
def agent_answer(t, n, m, er, pr, her2, question):
    """Gradio handler: answer an MDT question from the indexed guidelines.

    Runs one faceted retrieval per topic (staging / locoregional /
    systemic / follow-up, plus neoadjuvant when the question mentions it),
    builds a citation-constrained prompt over the retrieved chunks only,
    calls the LLM, and sanitizes the answer so every content line carries
    an allowed citation.

    Returns a tuple ``(answer_markdown, evidence_preview_markdown)``.
    """
    # Knowledge base must be built before any query.
    if index is None:
        return "⚠️ 請先建立知識庫", ""

    # Compact patient descriptor embedded into every facet query.
    patient = f"TNM={t}{n}{m};ER={er};PR={pr};HER2={her2}"
    user_q = (question or "").strip()

    # Only retrieve neoadjuvant material when the question asks about it
    # (Chinese or English phrasing).
    q_lower = user_q.lower()
    mention_neoadj = any(x in q_lower for x in ["術前", "新輔助", "neoadjuvant", "preoperative"])

    # Keyword lists used to re-rank vector candidates, one per facet.
    kw_stage = ["stage", "ajcc", "prognostic", "pathological", "tnm", "table"]
    kw_loco = ["locoregional", "bcs", "mastect", "surgery", "axillary", "slnb", "alnd", "margin", "rt", "radiation"]
    kw_sys = ["systemic", "adjuvant", "chemotherapy", "endocrine", "tamoxifen", "aromatase", "ai", "her2",
              "trastuz", "pertuz", "t-dm1", "taxane", "anthracycline", "sequence", "sequential"]
    kw_fu = ["follow", "surveillance", "history", "physical", "mammogram", "ultrasound", "mri", "cea", "ca15",
             "ca153", "ct", "pet", "bone", "x-ray", "lft", "alp", "cbc"]

    # One natural-language query per facet, each carrying the patient data.
    q_stage = f"breast cancer staging table prognostic stage {patient}"
    q_loco = f"breast cancer locoregional treatment axillary staging radiation {patient}"
    q_sys = f"breast cancer adjuvant systemic therapy endocrine chemotherapy HER2 therapy sequencing {patient}"
    q_fu = f"breast cancer post-therapy surveillance follow-up schedule tests imaging labs {patient}"

    top_stage = retrieve(q_stage, kw_stage)
    top_loco = retrieve(q_loco, kw_loco)
    top_sys = retrieve(q_sys, kw_sys)
    top_fu = retrieve(q_fu, kw_fu)

    # Optional extra facet, triggered only by neoadjuvant-related questions.
    top_neoadj = []
    if mention_neoadj:
        kw_neoadj = ["neoadjuvant", "preoperative", "marker", "restaged", "tad", "dual tracers", "cn(+)", "fna", "cnb"]
        q_neoadj = f"breast cancer neoadjuvant preoperative systemic therapy axillary marker TAD SLNB {patient}"
        top_neoadj = retrieve(q_neoadj, kw_neoadj, k_vec=K_VEC, k_final=5)

    # Merge facets, drop duplicates, cap the total context size.
    selected = uniq_by_key(top_stage + top_loco + top_sys + top_fu + top_neoadj)[:K_MAX_CONTEXT]

    # Context block given to the LLM; the [source|id] tags are the only
    # citation keys it is allowed to use.
    context = "\n\n---\n\n".join([f"[{c['source']}|{c['id']}]\n{c['text']}" for c in selected])
    evidence_preview = make_evidence_preview(selected)

    system_prompt = """
你是「乳癌指引整合智能體(NCCN + 院內)」。
你只做「文件查詢整理」,用於支援 MDT 會議快速核對指引;你不是臨床決策系統。

輸出格式(必須嚴格遵守):
【MDT摘要 1/3】
- ...(最多 8 行)
【MDT摘要 2/3】
- ...(最多 8 行)
【MDT摘要 3/3】
- ...(最多 8 行)
【治療路徑流程圖】
- 使用純文字流程(→ │ ├─ └─),限制在 10–18 行

核心規則(必須遵守):
1) 輔助導向:不要把回答寫成「建議做某治療」,改寫成「文件列出/文件指出/文件條列」的可核對要點。
2) 每一行末尾必須附 1 個引用:〔來源|ID〕(來源只能是 NCCN 或 院內;ID 必須來自提供的 [來源|ID])
3) 不允許常識補充:任何治療/追蹤/適應症/頻率/流程,必須由文件段落支持。
4) 文件沒寫到就輸出:此版檢索未涵蓋,故不推測〔NA〕
5) 不要要求使用者補文件、不要寫背景長文、不要說「截圖」。
6) 流程圖每個分支終點行必須附 citation;若無法引用則用 NA,不可補內容。
""".strip()

    user_prompt = f"""
【病患資料】
{patient}

【MDT問題】
{user_q}

【檢索到的文件段落(只能用這些)】
{context}
""".strip()

    # Never surface a raw exception to the UI; return the error as status
    # text alongside whatever evidence was retrieved.
    try:
        raw = call_openai_llm(system_prompt, user_prompt)
        answer = sanitize_answer(raw)
        return answer, evidence_preview
    except Exception as e:
        return f"⚠️ LLM 呼叫失敗:{e}", evidence_preview
|
|
|
|
| |
def build_all(nccn_file, local_file):
    """Gradio handler: ingest both PDFs, chunk them, and rebuild the index.

    Returns a human-readable status string. On any failure the global
    ``index`` is cleared so subsequent queries fail safely.
    """
    global chunks, index

    if nccn_file is None or local_file is None:
        return "⚠️ 請同時上傳 NCCN 與院內指引 PDF"

    try:
        nccn_text = load_pdf_text(nccn_file)
        local_text = load_pdf_text(local_file)
    except Exception as e:
        index = None
        chunks = []
        return f"⚠️ PDF 讀取失敗:{e}"

    nccn_paras = split_into_paragraphs(nccn_text, "NCCN")
    local_paras = split_into_paragraphs(local_text, "院內")

    # Assign stable per-source IDs: NCCN-001…, LOCAL-001…
    chunks = [
        {"id": f"NCCN-{i:03d}", "source": "NCCN", "text": p}
        for i, p in enumerate(nccn_paras, start=1)
    ] + [
        {"id": f"LOCAL-{i:03d}", "source": "院內", "text": p}
        for i, p in enumerate(local_paras, start=1)
    ]

    try:
        if len(chunks) == 0:
            index = None
            return "⚠️ 無可用文字段落(兩份 PDF 皆未抽到有效文字)"
        build_faiss_index(chunks)
    except Exception as e:
        index = None
        return f"⚠️ 索引建立失敗:{e}"

    # Warn when the NCCN PDF contributed nothing (likely a scanned/protected
    # file whose text layer is empty or all boilerplate).
    extra = ""
    if len(nccn_paras) == 0:
        extra = "(NCCN 未抽到可用文字:PDF 可能為掃描/受保護層,extract_text 為空或全為頁眉版權)"

    return f"索引建立完成:NCCN {len(nccn_paras)} 段、院內 {len(local_paras)} 段,共 {len(chunks)} 段{extra}"
|
|
|
|
| |
# --- Gradio UI -------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("## 乳癌指引整合智能體(NCCN + 院內)")
    gr.Markdown("### 上傳最新版 NCCN 與院內指引 PDF(外部上傳;不內建)")

    # Knowledge-base construction: two PDF uploads wired to build_all.
    nccn_file = gr.File(label="上傳 NCCN PDF")
    local_file = gr.File(label="上傳 院內指引 PDF")
    build_btn = gr.Button("建立知識庫")
    status = gr.Textbox(label="狀態")
    build_btn.click(build_all, [nccn_file, local_file], status)

    # Query inputs: TNM stage and receptor-status dropdowns.
    gr.Markdown("### 臨床查詢")
    with gr.Row():
        t = gr.Dropdown(["T0", "T1", "T2", "T3", "T4"], label="T", value="T1")
        n = gr.Dropdown(["N0", "N1", "N2", "N3"], label="N", value="N0")
        m = gr.Dropdown(["M0", "M1"], label="M", value="M0")

    with gr.Row():
        er = gr.Dropdown(["Positive", "Negative"], label="ER", value="Positive")
        pr = gr.Dropdown(["Positive", "Negative"], label="PR", value="Positive")
        her2 = gr.Dropdown(["Positive", "Negative"], label="HER2", value="Negative")

    question = gr.Textbox(
        label="臨床問題(MDT 用)",
        lines=3,
        placeholder="例如:初次確診的治療路徑怎麼走?請依指引整理 MDT 會議核對重點。"
    )

    ask_btn = gr.Button("智能體分析")
    # Sanitized LLM answer, rendered as markdown.
    answer = gr.Markdown()

    # Collapsible preview of the retrieved chunks, for auditability.
    with gr.Accordion("證據預覽(檢索到的段落)", open=False):
        evidence = gr.Markdown()

    ask_btn.click(agent_answer, [t, n, m, er, pr, her2, question], [answer, evidence])


# Launch the app (blocking call) when the module is executed.
demo.launch()