import os
import re

import faiss
import gradio as gr
import numpy as np
from openai import OpenAI
from pypdf import PdfReader

# ===================== Config =====================
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "").strip()
if not OPENAI_API_KEY:
    raise RuntimeError("OPENAI_API_KEY 未設定。請在 Hugging Face Space 的 Secrets 設定 OPENAI_API_KEY。")

client = OpenAI(api_key=OPENAI_API_KEY)

EMBED_MODEL = os.getenv("EMBED_MODEL", "text-embedding-3-small").strip()
LLM_MODEL = os.getenv("LLM_MODEL", "gpt-5.2").strip()  # gpt-5.2 / gpt-5.1 both fine

# Chunking
MIN_LEN_NCCN = int(os.getenv("MIN_LEN_NCCN", "60"))    # NCCN text extracts fragmented; relax the floor
MIN_LEN_LOCAL = int(os.getenv("MIN_LEN_LOCAL", "120"))
MAX_PARA_LEN = int(os.getenv("MAX_PARA_LEN", "1800"))
SPLIT_STEP = int(os.getenv("SPLIT_STEP", "1200"))

# Retrieval
K_VEC = int(os.getenv("K_VEC", "40"))
K_FINAL_PER_FACET = int(os.getenv("K_FINAL_PER_FACET", "6"))
K_MAX_CONTEXT = int(os.getenv("K_MAX_CONTEXT", "18"))

# Max inputs per embeddings request (API caps the input array; batch to stay well under it)
EMBED_BATCH_SIZE = int(os.getenv("EMBED_BATCH_SIZE", "256"))

# ===================== State =====================
chunks = []                # [{id, source, text}]
index = None               # faiss index
allowed_citations = set()  # {"NCCN|NCCN-001", "院內|LOCAL-001", ...}

# ===================== Helpers =====================
BOILERPLATE_PATTERNS = [
    r"Copyright\s*©",
    r"All Rights Reserved",
    r"NCCN Guidelines Version",
    r"Printed by",
    r"\b\d{1,2}/\d{1,2}/\d{4}\b",
    r"\b\d{1,2}:\d{2}:\d{2}\s*(AM|PM)\b",
    r"www\.springer\.com",
    r"Springer International Publishing",
]
BOILERPLATE_RE = re.compile("|".join(BOILERPLATE_PATTERNS), re.IGNORECASE)

# Citation marker format emitted by the LLM: 〔source|ID〕, e.g. 〔NCCN|NCCN-001〕
CIT_RE = re.compile(r"〔([^|〔〕]+)\|([A-Z]+-\d{3})〕")


def normalize_text(text: str) -> str:
    """Collapse whitespace noise from PDF extraction; keep paragraph breaks as blank lines."""
    text = (text or "").replace("\r", "\n")
    text = re.sub(r"[ \t]+", " ", text)
    text = re.sub(r"\n{3,}", "\n\n", text)
    return text.strip()


def is_boilerplate(p: str) -> bool:
    """Return True for empty text, known boilerplate patterns, or short copyright/header lines."""
    if not p:
        return True
    if BOILERPLATE_RE.search(p):
        return True
    # Short lines mentioning © or NCCN are almost always page headers/footers.
    if len(p) < 90 and ("©" in p or "NCCN" in p):
        return True
    return False


def load_pdf_text(file) -> str:
    """Extract plain text from an uploaded PDF (gradio File object), page by page."""
    reader = PdfReader(file.name)
    out = []
    for page in reader.pages:
        t = page.extract_text()
        if t:
            out.append(t)
    return "\n".join(out).strip()


def split_into_paragraphs(text: str, source: str):
    """Split extracted text into cleaned paragraph chunks for indexing.

    Drops boilerplate, slices over-long paragraphs into SPLIT_STEP-sized
    segments, and applies a per-source minimum length (NCCN uses a looser
    floor because its extracted text is fragmented).
    """
    text = normalize_text(text)
    if not text:
        return []
    paras = [p.strip() for p in text.split("\n\n") if p.strip()]
    refined = []
    for p in paras:
        if is_boilerplate(p):
            continue
        # Long paragraph slicing
        if len(p) > MAX_PARA_LEN:
            for i in range(0, len(p), SPLIT_STEP):
                seg = p[i:i + SPLIT_STEP].strip()
                if seg and not is_boilerplate(seg):
                    refined.append(seg)
            continue
        refined.append(p)
    min_len = MIN_LEN_NCCN if source == "NCCN" else MIN_LEN_LOCAL
    refined = [p for p in refined if len(p) >= min_len and not is_boilerplate(p)]
    return refined


def build_faiss_index(all_chunks):
    """Embed every chunk and build a flat L2 FAISS index; also record valid citation keys.

    Embedding requests are batched: the OpenAI embeddings endpoint rejects
    overly large input arrays, so a single call with all chunks fails on
    big PDFs.
    """
    global index, allowed_citations
    texts = [c["text"] for c in all_chunks]
    vec_rows = []
    for start in range(0, len(texts), EMBED_BATCH_SIZE):
        batch = texts[start:start + EMBED_BATCH_SIZE]
        resp = client.embeddings.create(model=EMBED_MODEL, input=batch)
        vec_rows.extend(e.embedding for e in resp.data)
    vecs = np.array(vec_rows, dtype="float32")
    dim = vecs.shape[1]
    index = faiss.IndexFlatL2(dim)
    index.add(vecs)
    allowed_citations = set(f"{c['source']}|{c['id']}" for c in all_chunks)


def keyword_score(text: str, keywords):
    """Count how many of the (lowercase) keywords appear in text."""
    t = text.lower()
    return sum(1 for kw in keywords if kw in t)


def retrieve(query: str, keywords, k_vec=K_VEC, k_final=K_FINAL_PER_FACET):
    """Vector-search the index, then rerank candidates by keyword hits and length.

    Returns at most k_final chunk dicts; empty list if no index is built.
    """
    if index is None:
        return []
    q_emb = client.embeddings.create(model=EMBED_MODEL, input=query)
    q_vec = np.array([q_emb.data[0].embedding], dtype="float32")
    _, idxs = index.search(q_vec, k_vec)
    # FAISS pads results with -1 when the index holds fewer than k_vec vectors;
    # chunks[-1] would silently return the *last* chunk, so filter them out.
    cands = [chunks[i] for i in idxs[0] if i >= 0]
    scored = []
    for c in cands:
        if is_boilerplate(c["text"]):
            continue
        s = keyword_score(c["text"], keywords)
        scored.append((s, len(c["text"]), c))
    # Prefer more keyword hits, then longer (more informative) chunks.
    scored.sort(key=lambda x: (x[0], x[1]), reverse=True)
    return [x[2] for x in scored[:k_final]]


def uniq_by_key(items):
    """De-duplicate chunk dicts by their "source|id" key, preserving order."""
    seen = set()
    out = []
    for c in items:
        key = f"{c['source']}|{c['id']}"
        if key not in seen:
            seen.add(key)
            out.append(c)
    return out


def make_evidence_preview(selected_chunks):
    """Render a short one-line-per-chunk preview of the retrieved evidence."""
    lines = []
    for c in selected_chunks:
        preview = c["text"][:220].replace("\n", " ")
        lines.append(f"- [{c['source']} | {c['id']}] {preview} ...")
    return "\n".join(lines)


def normalize_headers(text: str) -> str:
    """Put section headers (【MDT摘要 n/3】) on their own line, moving trailing text below."""
    lines = []
    for raw in text.splitlines():
        line = raw.rstrip()
        m = re.match(r"^(【MDT摘要\s*[123]/3】)\s*(.+)$", line)
        if m:
            lines.append(m.group(1))
            tail = m.group(2).strip()
            if tail:
                lines.append(tail)
        else:
            lines.append(line)
    return "\n".join(lines).strip()


def sanitize_answer(answer: str) -> str:
    """
    Rules:
    - Every non-header line must carry a citation, and that citation must be in
      allowed_citations; otherwise the whole line is replaced with the NA line.
    - Repeated NA markers are not stacked.
    """
    answer = normalize_headers(answer)
    out_lines = []
    for raw_line in answer.splitlines():
        line = raw_line.strip()
        if not line:
            out_lines.append("")
            continue
        # Section headers are passed through unvalidated.
        if line.startswith("【MDT摘要") or line.startswith("【治療路徑流程圖】"):
            out_lines.append(line)
            continue
        cites = CIT_RE.findall(line)
        if not cites:
            out_lines.append("此版檢索未涵蓋,故不推測〔NA〕")
            continue
        ok = True
        for src, cid in cites:
            key = f"{src}|{cid}"
            if key not in allowed_citations:
                ok = False
                break
        if not ok:
            out_lines.append("此版檢索未涵蓋,故不推測〔NA〕")
        else:
            # Collapse accidental NA repetition within an otherwise valid line.
            line = re.sub(
                r"(此版檢索未涵蓋,故不推測〔NA〕\s*){2,}",
                "此版檢索未涵蓋,故不推測〔NA〕 ",
                line
            ).strip()
            out_lines.append(line)
    return "\n".join(out_lines).strip()


def call_openai_llm(system_prompt: str, user_prompt: str) -> str:
    """Single chat-completion round trip; returns the stripped assistant message."""
    resp = client.chat.completions.create(
        model=LLM_MODEL,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        temperature=0.1
    )
    return resp.choices[0].message.content.strip()


# ===================== Agent =====================
def agent_answer(t, n, m, er, pr, her2, question):
    """Run faceted retrieval (staging / locoregional / systemic / follow-up, plus
    neoadjuvant when the question mentions it), build the prompt, call the LLM,
    and sanitize the answer so every line is backed by an allowed citation.

    Returns (answer_markdown, evidence_preview_markdown).
    """
    if index is None:
        return "⚠️ 請先建立知識庫", ""

    patient = f"TNM={t}{n}{m};ER={er};PR={pr};HER2={her2}"
    user_q = (question or "").strip()
    q_lower = user_q.lower()
    mention_neoadj = any(x in q_lower for x in ["術前", "新輔助", "neoadjuvant", "preoperative"])

    kw_stage = ["stage", "ajcc", "prognostic", "pathological", "tnm", "table"]
    kw_loco = ["locoregional", "bcs", "mastect", "surgery", "axillary", "slnb", "alnd",
               "margin", "rt", "radiation"]
    kw_sys = ["systemic", "adjuvant", "chemotherapy", "endocrine", "tamoxifen", "aromatase",
              "ai", "her2", "trastuz", "pertuz", "t-dm1", "taxane", "anthracycline",
              "sequence", "sequential"]
    kw_fu = ["follow", "surveillance", "history", "physical", "mammogram", "ultrasound",
             "mri", "cea", "ca15", "ca153", "ct", "pet", "bone", "x-ray", "lft", "alp", "cbc"]

    q_stage = f"breast cancer staging table prognostic stage {patient}"
    q_loco = f"breast cancer locoregional treatment axillary staging radiation {patient}"
    q_sys = f"breast cancer adjuvant systemic therapy endocrine chemotherapy HER2 therapy sequencing {patient}"
    q_fu = f"breast cancer post-therapy surveillance follow-up schedule tests imaging labs {patient}"

    top_stage = retrieve(q_stage, kw_stage)
    top_loco = retrieve(q_loco, kw_loco)
    top_sys = retrieve(q_sys, kw_sys)
    top_fu = retrieve(q_fu, kw_fu)

    top_neoadj = []
    if mention_neoadj:
        kw_neoadj = ["neoadjuvant", "preoperative", "marker", "restaged", "tad",
                     "dual tracers", "cn(+)", "fna", "cnb"]
        q_neoadj = f"breast cancer neoadjuvant preoperative systemic therapy axillary marker TAD SLNB {patient}"
        top_neoadj = retrieve(q_neoadj, kw_neoadj, k_vec=K_VEC, k_final=5)

    selected = uniq_by_key(top_stage + top_loco + top_sys + top_fu + top_neoadj)[:K_MAX_CONTEXT]
    context = "\n\n---\n\n".join([f"[{c['source']}|{c['id']}]\n{c['text']}" for c in selected])
    evidence_preview = make_evidence_preview(selected)

    system_prompt = """
你是「乳癌指引整合智能體(NCCN + 院內)」。
你只做「文件查詢整理」,用於支援 MDT 會議快速核對指引;你不是臨床決策系統。

輸出格式(必須嚴格遵守):
【MDT摘要 1/3】
- ...(最多 8 行)
【MDT摘要 2/3】
- ...(最多 8 行)
【MDT摘要 3/3】
- ...(最多 8 行)
【治療路徑流程圖】
- 使用純文字流程(→ │ ├─ └─),限制在 10–18 行

核心規則(必須遵守):
1) 輔助導向:不要把回答寫成「建議做某治療」,改寫成「文件列出/文件指出/文件條列」的可核對要點。
2) 每一行末尾必須附 1 個引用:〔來源|ID〕(來源只能是 NCCN 或 院內;ID 必須來自提供的 [來源|ID])
3) 不允許常識補充:任何治療/追蹤/適應症/頻率/流程,必須由文件段落支持。
4) 文件沒寫到就輸出:此版檢索未涵蓋,故不推測〔NA〕
5) 不要要求使用者補文件、不要寫背景長文、不要說「截圖」。
6) 流程圖每個分支終點行必須附 citation;若無法引用則用 NA,不可補內容。
""".strip()

    user_prompt = f"""
【病患資料】
{patient}

【MDT問題】
{user_q}

【檢索到的文件段落(只能用這些)】
{context}
""".strip()

    try:
        raw = call_openai_llm(system_prompt, user_prompt)
        answer = sanitize_answer(raw)
        return answer, evidence_preview
    except Exception as e:
        return f"⚠️ LLM 呼叫失敗:{e}", evidence_preview


# ===================== Build KB =====================
def build_all(nccn_file, local_file):
    """Load both PDFs, chunk them, build the FAISS index, and return a status string.

    Resets the global index/chunks to a safe empty state on any failure so a
    half-built knowledge base is never queried.
    """
    global chunks, index
    if nccn_file is None or local_file is None:
        return "⚠️ 請同時上傳 NCCN 與院內指引 PDF"
    try:
        nccn_text = load_pdf_text(nccn_file)
        local_text = load_pdf_text(local_file)
    except Exception as e:
        index = None
        chunks = []
        return f"⚠️ PDF 讀取失敗:{e}"

    nccn_paras = split_into_paragraphs(nccn_text, "NCCN")
    local_paras = split_into_paragraphs(local_text, "院內")

    new_chunks = []
    for i, p in enumerate(nccn_paras, start=1):
        new_chunks.append({"id": f"NCCN-{i:03d}", "source": "NCCN", "text": p})
    for i, p in enumerate(local_paras, start=1):
        new_chunks.append({"id": f"LOCAL-{i:03d}", "source": "院內", "text": p})
    chunks = new_chunks

    try:
        if len(chunks) == 0:
            index = None
            return "⚠️ 無可用文字段落(兩份 PDF 皆未抽到有效文字)"
        build_faiss_index(chunks)
    except Exception as e:
        index = None
        return f"⚠️ 索引建立失敗:{e}"

    # Explain a zero-paragraph NCCN result (usability unaffected, but makes it auditable)
    if len(nccn_paras) == 0:
        extra = "(NCCN 未抽到可用文字:PDF 可能為掃描/受保護層,extract_text 為空或全為頁眉版權)"
    else:
        extra = ""
    return f"索引建立完成:NCCN {len(nccn_paras)} 段、院內 {len(local_paras)} 段,共 {len(chunks)} 段{extra}"


# ===================== UI =====================
with gr.Blocks() as demo:
    gr.Markdown("## 乳癌指引整合智能體(NCCN + 院內)")
    gr.Markdown("### 上傳最新版 NCCN 與院內指引 PDF(外部上傳;不內建)")
    nccn_file = gr.File(label="上傳 NCCN PDF")
    local_file = gr.File(label="上傳 院內指引 PDF")
    build_btn = gr.Button("建立知識庫")
    status = gr.Textbox(label="狀態")
    build_btn.click(build_all, [nccn_file, local_file], status)

    gr.Markdown("### 臨床查詢")
    with gr.Row():
        t = gr.Dropdown(["T0", "T1", "T2", "T3", "T4"], label="T", value="T1")
        n = gr.Dropdown(["N0", "N1", "N2", "N3"], label="N", value="N0")
        m = gr.Dropdown(["M0", "M1"], label="M", value="M0")
    with gr.Row():
        er = gr.Dropdown(["Positive", "Negative"], label="ER", value="Positive")
        pr = gr.Dropdown(["Positive", "Negative"], label="PR", value="Positive")
        her2 = gr.Dropdown(["Positive", "Negative"], label="HER2", value="Negative")

    question = gr.Textbox(
        label="臨床問題(MDT 用)",
        lines=3,
        placeholder="例如:初次確診的治療路徑怎麼走?請依指引整理 MDT 會議核對重點。"
    )
    ask_btn = gr.Button("智能體分析")
    answer = gr.Markdown()
    with gr.Accordion("證據預覽(檢索到的段落)", open=False):
        evidence = gr.Markdown()
    ask_btn.click(agent_answer, [t, n, m, er, pr, her2, question], [answer, evidence])

if __name__ == "__main__":
    demo.launch()