# CanReg / app.py
# (Hugging Face Space file header: uploaded by wensjheng, commit d5b41f4 "Update app.py")
import os
import re
import gradio as gr
import numpy as np
import faiss
from pypdf import PdfReader
from openai import OpenAI
# ===================== Config =====================
# API key must be supplied via the environment (e.g. Hugging Face Space Secrets).
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "").strip()
if not OPENAI_API_KEY:
    raise RuntimeError("OPENAI_API_KEY 未設定。請在 Hugging Face Space 的 Secrets 設定 OPENAI_API_KEY。")
client = OpenAI(api_key=OPENAI_API_KEY)
# Model names are overridable through environment variables.
EMBED_MODEL = os.getenv("EMBED_MODEL", "text-embedding-3-small").strip()
LLM_MODEL = os.getenv("LLM_MODEL", "gpt-5.2").strip()  # gpt-5.2 / gpt-5.1 both work here
# Chunking
MIN_LEN_NCCN = int(os.getenv("MIN_LEN_NCCN", "60"))  # NCCN text is often fragmented, so use a looser minimum
MIN_LEN_LOCAL = int(os.getenv("MIN_LEN_LOCAL", "120"))  # minimum paragraph length for the local guideline
MAX_PARA_LEN = int(os.getenv("MAX_PARA_LEN", "1800"))  # paragraphs longer than this get sliced
SPLIT_STEP = int(os.getenv("SPLIT_STEP", "1200"))  # slice size used when cutting over-long paragraphs
# Retrieval
K_VEC = int(os.getenv("K_VEC", "40"))  # candidate pool size from the vector search
K_FINAL_PER_FACET = int(os.getenv("K_FINAL_PER_FACET", "6"))  # chunks kept per query facet after re-ranking
K_MAX_CONTEXT = int(os.getenv("K_MAX_CONTEXT", "18"))  # hard cap on chunks sent to the LLM
# ===================== State =====================
# Module-level knowledge-base state; rebuilt by build_all().
chunks = []  # [{id, source, text}]
index = None  # faiss index
allowed_citations = set()  # {"NCCN|NCCN-001", "院內|LOCAL-001", ...}
# ===================== Helpers =====================
# Regex fragments matching page furniture (copyright footers, print
# timestamps, publisher URLs) that should never enter the knowledge base.
BOILERPLATE_PATTERNS = [
    r"Copyright\s*©",
    r"All Rights Reserved",
    r"NCCN Guidelines Version",
    r"Printed by",
    r"\b\d{1,2}/\d{1,2}/\d{4}\b",  # print dates, e.g. 1/2/2024
    r"\b\d{1,2}:\d{2}:\d{2}\s*(AM|PM)\b",  # print timestamps
    r"www\.springer\.com",
    r"Springer International Publishing",
]
BOILERPLATE_RE = re.compile("|".join(BOILERPLATE_PATTERNS), re.IGNORECASE)
# Citation marker format expected in LLM output: 〔source|ID〕, e.g. 〔NCCN|NCCN-001〕.
CIT_RE = re.compile(r"〔([^|〔〕]+)\|([A-Z]+-\d{3})〕")
def normalize_text(text: str) -> str:
    """Collapse whitespace noise in raw extracted PDF text.

    Converts CRLF/CR line endings to LF (the previous version replaced
    "\\r" with "\\n", which turned Windows "\\r\\n" into a double newline
    and created spurious paragraph breaks), squeezes runs of spaces/tabs
    into one space, and caps consecutive blank lines at one.

    Args:
        text: raw text; ``None`` is treated as empty.

    Returns:
        The normalized, stripped text.
    """
    text = (text or "").replace("\r\n", "\n").replace("\r", "\n")
    text = re.sub(r"[ \t]+", " ", text)
    text = re.sub(r"\n{3,}", "\n\n", text)
    return text.strip()
def is_boilerplate(p: str) -> bool:
    """Return True when *p* looks like page furniture rather than content.

    Empty strings, anything matching BOILERPLATE_RE, and short fragments
    (fewer than 90 chars) containing "©" or "NCCN" are all rejected.
    """
    if not p:
        return True
    if BOILERPLATE_RE.search(p) is not None:
        return True
    return len(p) < 90 and ("©" in p or "NCCN" in p)
def load_pdf_text(file) -> str:
    """Extract and concatenate the text of every page of an uploaded PDF.

    *file* is a Gradio file object exposing a ``.name`` path. Pages whose
    ``extract_text()`` yields nothing (e.g. scanned images) are skipped.
    """
    pages = PdfReader(file.name).pages
    texts = [t for t in (page.extract_text() for page in pages) if t]
    return "\n".join(texts).strip()
def split_into_paragraphs(text: str, source: str):
    """Turn raw document text into clean paragraph chunks.

    Normalizes whitespace, splits on blank lines, drops boilerplate,
    slices paragraphs longer than MAX_PARA_LEN into SPLIT_STEP-sized
    pieces, then enforces a source-dependent minimum length
    (MIN_LEN_NCCN for "NCCN", MIN_LEN_LOCAL otherwise).
    """
    cleaned = normalize_text(text)
    if not cleaned:
        return []
    pieces = []
    for para in (p.strip() for p in cleaned.split("\n\n")):
        if not para or is_boilerplate(para):
            continue
        if len(para) <= MAX_PARA_LEN:
            pieces.append(para)
            continue
        # Over-long paragraph: cut into fixed-size slices.
        for start in range(0, len(para), SPLIT_STEP):
            segment = para[start:start + SPLIT_STEP].strip()
            if segment and not is_boilerplate(segment):
                pieces.append(segment)
    threshold = MIN_LEN_NCCN if source == "NCCN" else MIN_LEN_LOCAL
    return [p for p in pieces if len(p) >= threshold and not is_boilerplate(p)]
def build_faiss_index(all_chunks):
    """Embed every chunk and (re)build the global FAISS index.

    Also refreshes *allowed_citations*, the set of "source|id" keys that
    sanitize_answer() later accepts as valid citations.
    """
    global index, allowed_citations
    texts = [chunk["text"] for chunk in all_chunks]
    # One batched embeddings call for all chunks.
    response = client.embeddings.create(model=EMBED_MODEL, input=texts)
    matrix = np.asarray([item.embedding for item in response.data], dtype="float32")
    # Exact L2 flat index — no training step required.
    new_index = faiss.IndexFlatL2(matrix.shape[1])
    new_index.add(matrix)
    index = new_index
    allowed_citations = {f"{chunk['source']}|{chunk['id']}" for chunk in all_chunks}
def keyword_score(text: str, keywords) -> int:
    """Count how many of *keywords* occur in *text*, case-insensitively.

    The previous version lowercased only the text, so an upper- or
    mixed-case keyword could never match; both sides are now lowered.
    Matching is plain substring containment (so a short keyword like
    "ai" also matches inside longer words).

    Args:
        text: the haystack string.
        keywords: iterable of keyword strings.

    Returns:
        Number of keywords found (each counted at most once).
    """
    haystack = text.lower()
    return sum(1 for kw in keywords if kw.lower() in haystack)
def retrieve(query: str, keywords, k_vec=K_VEC, k_final=K_FINAL_PER_FACET):
    """Vector-search the index, then re-rank candidates by keyword hits.

    Args:
        query: natural-language query to embed.
        keywords: lowercase keywords used for the re-rank score.
        k_vec: candidate pool size requested from the vector search.
        k_final: number of chunks returned after re-ranking.

    Returns:
        Up to *k_final* chunk dicts, best first; [] if no index is built.
    """
    if index is None:
        return []
    q_emb = client.embeddings.create(model=EMBED_MODEL, input=query)
    q_vec = np.array([q_emb.data[0].embedding], dtype="float32")
    _, idxs = index.search(q_vec, k_vec)
    # FAISS pads the result with -1 when the index holds fewer than k_vec
    # vectors; the previous code then did chunks[-1] etc., silently pulling
    # in unrelated chunks. Filter the padding out.
    cands = [chunks[i] for i in idxs[0] if i >= 0]
    scored = []
    for c in cands:
        if is_boilerplate(c["text"]):
            continue
        s = keyword_score(c["text"], keywords)
        # Sort key: keyword hits first, longer chunks break ties.
        scored.append((s, len(c["text"]), c))
    scored.sort(key=lambda x: (x[0], x[1]), reverse=True)
    return [x[2] for x in scored[:k_final]]
def uniq_by_key(items):
    """Deduplicate chunk dicts by "source|id", keeping first occurrences in order."""
    kept = []
    seen_keys = set()
    for chunk in items:
        dedup_key = f"{chunk['source']}|{chunk['id']}"
        if dedup_key in seen_keys:
            continue
        seen_keys.add(dedup_key)
        kept.append(chunk)
    return kept
def make_evidence_preview(selected_chunks):
    """Render a markdown bullet list previewing each retrieved chunk.

    Each line shows "[source | id]" followed by the first 220 characters
    of the chunk text with newlines flattened to spaces.
    """
    rows = []
    for chunk in selected_chunks:
        snippet = chunk["text"][:220].replace("\n", " ")
        rows.append(f"- [{chunk['source']} | {chunk['id']}] {snippet} ...")
    return "\n".join(rows)
def normalize_headers(text: str) -> str:
    """Force each 【MDT摘要 x/3】 header onto its own line.

    When a header line carries trailing content, the content is moved to
    the following line; all other lines pass through right-stripped.
    """
    header_re = re.compile(r"^(【MDT摘要\s*[123]/3】)\s*(.+)$")
    result = []
    for original in text.splitlines():
        stripped = original.rstrip()
        match = header_re.match(stripped)
        if match is None:
            result.append(stripped)
            continue
        result.append(match.group(1))
        trailing = match.group(2).strip()
        if trailing:
            result.append(trailing)
    return "\n".join(result).strip()
def sanitize_answer(answer: str) -> str:
    """Enforce the citation contract on raw LLM output.

    Rules:
      - Non-header lines must carry at least one citation, and every
        citation must be present in *allowed_citations*; otherwise the
        whole line is replaced with the NA sentence.
      - Repeated NA sentences within one line are collapsed to a single
        occurrence.
    """
    na_line = "此版檢索未涵蓋,故不推測〔NA〕"
    sanitized = []
    for raw_line in normalize_headers(answer).splitlines():
        line = raw_line.strip()
        if not line:
            sanitized.append("")
            continue
        # Section headers pass through untouched.
        if line.startswith("【MDT摘要") or line.startswith("【治療路徑流程圖】"):
            sanitized.append(line)
            continue
        found = CIT_RE.findall(line)
        if not found:
            sanitized.append(na_line)
            continue
        if any(f"{src}|{cid}" not in allowed_citations for src, cid in found):
            sanitized.append(na_line)
            continue
        collapsed = re.sub(
            r"(此版檢索未涵蓋,故不推測〔NA〕\s*){2,}",
            "此版檢索未涵蓋,故不推測〔NA〕 ",
            line
        ).strip()
        sanitized.append(collapsed)
    return "\n".join(sanitized).strip()
def call_openai_llm(system_prompt: str, user_prompt: str) -> str:
    """Send a system + user prompt pair to the configured chat model.

    Args:
        system_prompt: system-role instructions.
        user_prompt: user-role content.

    Returns:
        The stripped completion text. ``message.content`` can be ``None``
        (e.g. refusals or tool-only replies); the previous version then
        crashed with ``AttributeError`` on ``.strip()``, so it is guarded
        with ``or ""`` and an empty string is returned instead.

    Raises:
        openai.OpenAIError: propagated from the SDK on API failure.
    """
    resp = client.chat.completions.create(
        model=LLM_MODEL,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        temperature=0.1  # near-deterministic output for guideline summaries
    )
    return (resp.choices[0].message.content or "").strip()
# ===================== Agent =====================
def agent_answer(t, n, m, er, pr, her2, question):
    """Answer an MDT question via multi-facet retrieval plus the LLM.

    Runs four fixed retrieval facets (staging, locoregional, systemic,
    follow-up) — plus a neoadjuvant facet when the question mentions it —
    merges the results, and asks the LLM to summarize with mandatory
    citations, which are then validated by sanitize_answer().

    Args:
        t, n, m: TNM dropdown values (e.g. "T1", "N0", "M0").
        er, pr, her2: receptor status strings ("Positive"/"Negative").
        question: free-text MDT question (may be empty).

    Returns:
        (answer_markdown, evidence_preview) tuple; the first element is a
        warning string when the knowledge base has not been built yet.
    """
    if index is None:
        return "⚠️ 請先建立知識庫", ""
    patient = f"TNM={t}{n}{m};ER={er};PR={pr};HER2={her2}"
    user_q = (question or "").strip()
    q_lower = user_q.lower()
    # Add the extra neoadjuvant facet only when the question mentions
    # preoperative / neoadjuvant therapy (Chinese or English terms).
    mention_neoadj = any(x in q_lower for x in ["術前", "新輔助", "neoadjuvant", "preoperative"])
    # Keyword lists used to re-rank vector-search candidates per facet.
    kw_stage = ["stage", "ajcc", "prognostic", "pathological", "tnm", "table"]
    kw_loco = ["locoregional", "bcs", "mastect", "surgery", "axillary", "slnb", "alnd", "margin", "rt", "radiation"]
    kw_sys = ["systemic", "adjuvant", "chemotherapy", "endocrine", "tamoxifen", "aromatase", "ai", "her2",
              "trastuz", "pertuz", "t-dm1", "taxane", "anthracycline", "sequence", "sequential"]
    kw_fu = ["follow", "surveillance", "history", "physical", "mammogram", "ultrasound", "mri", "cea", "ca15",
             "ca153", "ct", "pet", "bone", "x-ray", "lft", "alp", "cbc"]
    # Facet queries embed the patient profile so retrieval is patient-aware.
    q_stage = f"breast cancer staging table prognostic stage {patient}"
    q_loco = f"breast cancer locoregional treatment axillary staging radiation {patient}"
    q_sys = f"breast cancer adjuvant systemic therapy endocrine chemotherapy HER2 therapy sequencing {patient}"
    q_fu = f"breast cancer post-therapy surveillance follow-up schedule tests imaging labs {patient}"
    top_stage = retrieve(q_stage, kw_stage)
    top_loco = retrieve(q_loco, kw_loco)
    top_sys = retrieve(q_sys, kw_sys)
    top_fu = retrieve(q_fu, kw_fu)
    top_neoadj = []
    if mention_neoadj:
        kw_neoadj = ["neoadjuvant", "preoperative", "marker", "restaged", "tad", "dual tracers", "cn(+)", "fna", "cnb"]
        q_neoadj = f"breast cancer neoadjuvant preoperative systemic therapy axillary marker TAD SLNB {patient}"
        top_neoadj = retrieve(q_neoadj, kw_neoadj, k_vec=K_VEC, k_final=5)
    # Merge facets, deduplicate, and cap the context size sent to the LLM.
    selected = uniq_by_key(top_stage + top_loco + top_sys + top_fu + top_neoadj)[:K_MAX_CONTEXT]
    context = "\n\n---\n\n".join([f"[{c['source']}|{c['id']}]\n{c['text']}" for c in selected])
    evidence_preview = make_evidence_preview(selected)
    # System prompt: fixed output format + citation rules (runtime string,
    # intentionally kept verbatim in the original language).
    system_prompt = """
你是「乳癌指引整合智能體(NCCN + 院內)」。
你只做「文件查詢整理」,用於支援 MDT 會議快速核對指引;你不是臨床決策系統。
輸出格式(必須嚴格遵守):
【MDT摘要 1/3】
- ...(最多 8 行)
【MDT摘要 2/3】
- ...(最多 8 行)
【MDT摘要 3/3】
- ...(最多 8 行)
【治療路徑流程圖】
- 使用純文字流程(→ │ ├─ └─),限制在 10–18 行
核心規則(必須遵守):
1) 輔助導向:不要把回答寫成「建議做某治療」,改寫成「文件列出/文件指出/文件條列」的可核對要點。
2) 每一行末尾必須附 1 個引用:〔來源|ID〕(來源只能是 NCCN 或 院內;ID 必須來自提供的 [來源|ID])
3) 不允許常識補充:任何治療/追蹤/適應症/頻率/流程,必須由文件段落支持。
4) 文件沒寫到就輸出:此版檢索未涵蓋,故不推測〔NA〕
5) 不要要求使用者補文件、不要寫背景長文、不要說「截圖」。
6) 流程圖每個分支終點行必須附 citation;若無法引用則用 NA,不可補內容。
""".strip()
    user_prompt = f"""
【病患資料】
{patient}
【MDT問題】
{user_q}
【檢索到的文件段落(只能用這些)】
{context}
""".strip()
    try:
        raw = call_openai_llm(system_prompt, user_prompt)
        answer = sanitize_answer(raw)
        return answer, evidence_preview
    except Exception as e:
        # Surface the API error but still show the retrieved evidence.
        return f"⚠️ LLM 呼叫失敗:{e}", evidence_preview
# ===================== Build KB =====================
def build_all(nccn_file, local_file):
    """Build the knowledge base from the two uploaded PDFs.

    Reads both PDFs, chunks them into paragraphs, assigns citation IDs
    (NCCN-### / LOCAL-###), and rebuilds the global FAISS index.

    Returns:
        A human-readable status string for the UI.
    """
    global chunks, index
    if nccn_file is None or local_file is None:
        return "⚠️ 請同時上傳 NCCN 與院內指引 PDF"
    try:
        nccn_text = load_pdf_text(nccn_file)
        local_text = load_pdf_text(local_file)
    except Exception as e:
        index = None
        chunks = []
        return f"⚠️ PDF 讀取失敗:{e}"
    nccn_paras = split_into_paragraphs(nccn_text, "NCCN")
    local_paras = split_into_paragraphs(local_text, "院內")
    chunks = (
        [{"id": f"NCCN-{i:03d}", "source": "NCCN", "text": p}
         for i, p in enumerate(nccn_paras, start=1)]
        + [{"id": f"LOCAL-{i:03d}", "source": "院內", "text": p}
           for i, p in enumerate(local_paras, start=1)]
    )
    try:
        if not chunks:
            index = None
            return "⚠️ 無可用文字段落(兩份 PDF 皆未抽到有效文字)"
        build_faiss_index(chunks)
    except Exception as e:
        index = None
        return f"⚠️ 索引建立失敗:{e}"
    # Explain a likely cause when NCCN yielded nothing (app stays usable).
    extra = ""
    if not nccn_paras:
        extra = "(NCCN 未抽到可用文字:PDF 可能為掃描/受保護層,extract_text 為空或全為頁眉版權)"
    return f"索引建立完成:NCCN {len(nccn_paras)} 段、院內 {len(local_paras)} 段,共 {len(chunks)}{extra}"
# ===================== UI =====================
# Gradio layout: KB-building section on top, clinical query section below.
with gr.Blocks() as demo:
    gr.Markdown("## 乳癌指引整合智能體(NCCN + 院內)")
    gr.Markdown("### 上傳最新版 NCCN 與院內指引 PDF(外部上傳;不內建)")
    nccn_file = gr.File(label="上傳 NCCN PDF")
    local_file = gr.File(label="上傳 院內指引 PDF")
    build_btn = gr.Button("建立知識庫")
    status = gr.Textbox(label="狀態")
    # Rebuilds the module-level knowledge base and reports status.
    build_btn.click(build_all, [nccn_file, local_file], status)
    gr.Markdown("### 臨床查詢")
    # Patient profile inputs: TNM stage plus receptor status.
    with gr.Row():
        t = gr.Dropdown(["T0", "T1", "T2", "T3", "T4"], label="T", value="T1")
        n = gr.Dropdown(["N0", "N1", "N2", "N3"], label="N", value="N0")
        m = gr.Dropdown(["M0", "M1"], label="M", value="M0")
    with gr.Row():
        er = gr.Dropdown(["Positive", "Negative"], label="ER", value="Positive")
        pr = gr.Dropdown(["Positive", "Negative"], label="PR", value="Positive")
        her2 = gr.Dropdown(["Positive", "Negative"], label="HER2", value="Negative")
    question = gr.Textbox(
        label="臨床問題(MDT 用)",
        lines=3,
        placeholder="例如:初次確診的治療路徑怎麼走?請依指引整理 MDT 會議核對重點。"
    )
    ask_btn = gr.Button("智能體分析")
    answer = gr.Markdown()
    with gr.Accordion("證據預覽(檢索到的段落)", open=False):
        evidence = gr.Markdown()
    # Agent entry point: returns (answer markdown, evidence preview).
    ask_btn.click(agent_answer, [t, n, m, er, pr, her2, question], [answer, evidence])
demo.launch()