File size: 13,504 Bytes
a7ebe56
c57f4fd
a755a5b
 
 
 
2f9cd9e
 
2d8c06a
 
 
 
2f9cd9e
2d8c06a
a7ebe56
2d8c06a
 
 
 
 
 
2f9cd9e
2d8c06a
a7ebe56
2d8c06a
2f9cd9e
 
 
a7ebe56
2d8c06a
 
2f9cd9e
2d8c06a
2f9cd9e
2d8c06a
9904700
 
 
 
 
d2c058e
 
9904700
 
 
 
 
 
a7ebe56
 
e7f6c8a
2f9cd9e
e7f6c8a
 
 
a755a5b
 
9904700
2f9cd9e
 
9904700
 
2d8c06a
9904700
 
 
 
2d8c06a
2f9cd9e
 
 
 
 
 
2d8c06a
2f9cd9e
 
 
e7f6c8a
2f9cd9e
 
a755a5b
2f9cd9e
e7f6c8a
2f9cd9e
e7f6c8a
9904700
 
 
2d8c06a
2f9cd9e
 
 
9904700
 
2f9cd9e
a755a5b
2f9cd9e
 
 
 
e7f6c8a
5141d0b
a755a5b
e7f6c8a
2d8c06a
a755a5b
2f9cd9e
2d8c06a
2f9cd9e
2d8c06a
a755a5b
e7f6c8a
 
 
a755a5b
2d8c06a
d2f5ccc
a755a5b
e7f6c8a
 
2f9cd9e
a7ebe56
a755a5b
2f9cd9e
2d8c06a
e7f6c8a
 
 
2d8c06a
e7f6c8a
2f9cd9e
d2c058e
d2f5ccc
e7f6c8a
d2c058e
9904700
 
e7f6c8a
 
 
 
9904700
e7f6c8a
 
d2f5ccc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d8c06a
d2f5ccc
 
 
9904700
 
 
 
 
 
 
 
 
 
 
 
 
 
d2f5ccc
 
2f9cd9e
2d8c06a
 
 
2f9cd9e
9904700
 
d2f5ccc
 
 
 
 
 
 
2f9cd9e
d2f5ccc
 
 
 
 
 
 
 
 
 
 
2d8c06a
d2f5ccc
 
 
2f9cd9e
 
 
2d8c06a
 
 
 
 
2f9cd9e
d2f5ccc
 
 
 
2f9cd9e
 
2d8c06a
 
 
 
 
2f9cd9e
 
 
 
 
2d8c06a
c57f4fd
 
 
a755a5b
d2f5ccc
2f9cd9e
 
2d8c06a
d2c058e
 
 
2f9cd9e
 
2d8c06a
2f9cd9e
 
d2c058e
 
2d8c06a
2f9cd9e
 
d2c058e
2f9cd9e
 
 
 
9904700
d2c058e
 
2f9cd9e
2d8c06a
2f9cd9e
9904700
2f9cd9e
9904700
 
d2f5ccc
a755a5b
f389f3f
2d8c06a
 
e7f6c8a
2d8c06a
e7f6c8a
 
 
 
 
 
2f9cd9e
2d8c06a
f389f3f
2d8c06a
 
d2c058e
 
 
2d8c06a
2f9cd9e
e7f6c8a
c57f4fd
e7f6c8a
 
 
c57f4fd
e7f6c8a
 
 
d2f5ccc
e7f6c8a
 
c57f4fd
2f9cd9e
2d8c06a
 
 
2f9cd9e
 
a7ebe56
e7f6c8a
2d8c06a
e7f6c8a
 
c57f4fd
e7f6c8a
 
c57f4fd
2f9cd9e
 
 
 
2d8c06a
 
2f9cd9e
a7ebe56
2f9cd9e
 
a7ebe56
e7f6c8a
 
 
 
 
 
 
2f9cd9e
 
2d8c06a
 
 
 
2f9cd9e
 
 
 
e7f6c8a
2d8c06a
 
 
 
 
 
 
e7f6c8a
 
2d8c06a
f389f3f
 
e7f6c8a
d2c058e
a755a5b
 
 
 
 
 
f389f3f
a755a5b
2f9cd9e
 
 
a755a5b
 
2f9cd9e
 
 
d2f5ccc
 
 
 
2d8c06a
d2f5ccc
a755a5b
e7f6c8a
2f9cd9e
 
e7f6c8a
c57f4fd
a755a5b
c57f4fd
a7ebe56
d5b41f4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
import os
import re
import gradio as gr
import numpy as np
import faiss
from pypdf import PdfReader
from openai import OpenAI

# ===================== Config =====================
# API key must come from the environment (e.g. Hugging Face Space Secrets).
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "").strip()
if not OPENAI_API_KEY:
    raise RuntimeError("OPENAI_API_KEY 未設定。請在 Hugging Face Space 的 Secrets 設定 OPENAI_API_KEY。")

client = OpenAI(api_key=OPENAI_API_KEY)

# Model names are overridable via environment variables.
EMBED_MODEL = os.getenv("EMBED_MODEL", "text-embedding-3-small").strip()
LLM_MODEL = os.getenv("LLM_MODEL", "gpt-5.2").strip()  # either gpt-5.2 or gpt-5.1 works here

# Chunking
MIN_LEN_NCCN = int(os.getenv("MIN_LEN_NCCN", "60"))      # NCCN text extracts in small fragments, so use a looser minimum
MIN_LEN_LOCAL = int(os.getenv("MIN_LEN_LOCAL", "120"))   # minimum paragraph length for the local (院內) guideline
MAX_PARA_LEN = int(os.getenv("MAX_PARA_LEN", "1800"))    # paragraphs longer than this get sliced
SPLIT_STEP = int(os.getenv("SPLIT_STEP", "1200"))        # slice window size for over-long paragraphs

# Retrieval
K_VEC = int(os.getenv("K_VEC", "40"))                    # vector-search candidates per facet query
K_FINAL_PER_FACET = int(os.getenv("K_FINAL_PER_FACET", "6"))  # chunks kept per facet after keyword rerank
K_MAX_CONTEXT = int(os.getenv("K_MAX_CONTEXT", "18"))    # hard cap on chunks sent to the LLM

# ===================== State =====================
# Module-level mutable state, rebuilt by build_all().
chunks = []          # [{id, source, text}]
index = None         # faiss index
allowed_citations = set()  # {"NCCN|NCCN-001", "院內|LOCAL-001", ...}

# ===================== Helpers =====================
# Patterns matching page furniture / copyright noise to drop during chunking.
BOILERPLATE_PATTERNS = [
    r"Copyright\s*©",
    r"All Rights Reserved",
    r"NCCN Guidelines Version",
    r"Printed by",
    r"\b\d{1,2}/\d{1,2}/\d{4}\b",
    r"\b\d{1,2}:\d{2}:\d{2}\s*(AM|PM)\b",
    r"www\.springer\.com",
    r"Springer International Publishing",
]
BOILERPLATE_RE = re.compile("|".join(BOILERPLATE_PATTERNS), re.IGNORECASE)

# Citation marker format emitted by the LLM: 〔source|ID〕, e.g. 〔NCCN|NCCN-001〕.
CIT_RE = re.compile(r"〔([^|〔〕]+)\|([A-Z]+-\d{3})〕")


def normalize_text(text: str) -> str:
    """Collapse whitespace noise in raw extracted text.

    Converts carriage returns to newlines, squeezes runs of spaces/tabs
    into single spaces, and limits consecutive blank lines to one.
    """
    cleaned = (text or "").replace("\r", "\n")
    cleaned = re.sub(r"[ \t]+", " ", cleaned)
    cleaned = re.sub(r"\n{3,}", "\n\n", cleaned)
    return cleaned.strip()


def is_boilerplate(p: str) -> bool:
    """Return True when a paragraph is header/footer/copyright noise."""
    if not p:
        return True
    if BOILERPLATE_RE.search(p):
        return True
    # Short fragments mentioning a copyright mark or "NCCN" are page furniture.
    return len(p) < 90 and ("©" in p or "NCCN" in p)


def load_pdf_text(file) -> str:
    """Extract and join the text of every page in an uploaded PDF.

    `file` is a gradio File object exposing `.name` (a filesystem path).
    Pages with no extractable text are skipped.
    """
    pages = []
    for page in PdfReader(file.name).pages:
        extracted = page.extract_text()
        if extracted:
            pages.append(extracted)
    return "\n".join(pages).strip()


def split_into_paragraphs(text: str, source: str):
    """Split document text into cleaned paragraph chunks for indexing.

    Drops boilerplate, slices over-long paragraphs into SPLIT_STEP-sized
    windows, and enforces a per-source minimum length (source == "NCCN"
    uses the looser MIN_LEN_NCCN threshold).
    """
    text = normalize_text(text)
    if not text:
        return []

    refined = []
    for para in (p.strip() for p in text.split("\n\n") if p.strip()):
        if is_boilerplate(para):
            continue

        if len(para) > MAX_PARA_LEN:
            # Slice long paragraphs into fixed-size windows.
            for start in range(0, len(para), SPLIT_STEP):
                segment = para[start:start + SPLIT_STEP].strip()
                if segment and not is_boilerplate(segment):
                    refined.append(segment)
        else:
            refined.append(para)

    min_len = MIN_LEN_NCCN if source == "NCCN" else MIN_LEN_LOCAL
    return [p for p in refined if len(p) >= min_len and not is_boilerplate(p)]


def build_faiss_index(all_chunks):
    """Embed every chunk and (re)build the global FAISS L2 index.

    Also refreshes `allowed_citations` with the "source|id" keys that
    the answer sanitizer accepts.
    """
    global index, allowed_citations

    texts = [c["text"] for c in all_chunks]
    # One batched embeddings request for the whole corpus.
    response = client.embeddings.create(model=EMBED_MODEL, input=texts)
    vectors = np.array([item.embedding for item in response.data], dtype="float32")

    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)

    allowed_citations = {f"{c['source']}|{c['id']}" for c in all_chunks}


def keyword_score(text: str, keywords):
    """Count how many entries of `keywords` occur as substrings of lowercased `text`."""
    lowered = text.lower()
    return sum(kw in lowered for kw in keywords)


def retrieve(query: str, keywords, k_vec=K_VEC, k_final=K_FINAL_PER_FACET):
    """Vector-search the index, then rerank candidates by keyword hits.

    Returns up to `k_final` chunks, preferring higher keyword-match
    counts and, on ties, longer passages. Returns [] when no index has
    been built yet.
    """
    if index is None:
        return []

    embedded = client.embeddings.create(model=EMBED_MODEL, input=query)
    query_vec = np.array([embedded.data[0].embedding], dtype="float32")

    _, hit_ids = index.search(query_vec, k_vec)

    ranked = []
    for chunk in (chunks[i] for i in hit_ids[0]):
        if is_boilerplate(chunk["text"]):
            continue
        ranked.append((keyword_score(chunk["text"], keywords), len(chunk["text"]), chunk))

    ranked.sort(key=lambda item: (item[0], item[1]), reverse=True)
    return [item[2] for item in ranked[:k_final]]


def uniq_by_key(items):
    """Drop duplicate chunks, keeping the first occurrence per "source|id" key."""
    seen = set()
    deduped = []
    for chunk in items:
        key = f"{chunk['source']}|{chunk['id']}"
        if key in seen:
            continue
        seen.add(key)
        deduped.append(chunk)
    return deduped


def make_evidence_preview(selected_chunks):
    """Render one bullet line per chunk, previewing its first 220 characters."""
    bullets = []
    for chunk in selected_chunks:
        snippet = chunk["text"][:220].replace("\n", " ")
        bullets.append(f"- [{chunk['source']} | {chunk['id']}] {snippet} ...")
    return "\n".join(bullets)


def normalize_headers(text: str) -> str:
    """Ensure each MDT section header sits on its own line.

    If a header like 【MDT摘要 1/3】 has trailing content on the same
    line, the content is moved to the next line.
    """
    header_re = re.compile(r"^(【MDT摘要\s*[123]/3】)\s*(.+)$")
    result = []
    for raw in text.splitlines():
        stripped = raw.rstrip()
        match = header_re.match(stripped)
        if match is None:
            result.append(stripped)
            continue
        result.append(match.group(1))
        trailing = match.group(2).strip()
        if trailing:
            result.append(trailing)
    return "\n".join(result).strip()


def sanitize_answer(answer: str) -> str:
    """Post-process the LLM answer to enforce the citation contract.

    Rules:
    - Every non-header line must contain at least one citation, and every
      citation must be present in `allowed_citations`; otherwise the whole
      line is replaced with the NA marker.
    - Repeated NA markers are collapsed so they do not pile up on a line.
    """
    answer = normalize_headers(answer)

    out_lines = []
    for raw_line in answer.splitlines():
        line = raw_line.strip()
        if not line:
            out_lines.append("")
            continue

        # Section headers pass through without a citation requirement.
        if line.startswith("【MDT摘要") or line.startswith("【治療路徑流程圖】"):
            out_lines.append(line)
            continue

        cites = CIT_RE.findall(line)
        if not cites:
            out_lines.append("此版檢索未涵蓋,故不推測〔NA〕")
            continue

        # Reject the line if any cited (source, id) pair is unknown.
        ok = True
        for src, cid in cites:
            key = f"{src}|{cid}"
            if key not in allowed_citations:
                ok = False
                break

        if not ok:
            out_lines.append("此版檢索未涵蓋,故不推測〔NA〕")
        else:
            # Collapse duplicated NA markers the model may have emitted inline.
            line = re.sub(
                r"(此版檢索未涵蓋,故不推測〔NA〕\s*){2,}",
                "此版檢索未涵蓋,故不推測〔NA〕 ",
                line
            ).strip()
            out_lines.append(line)

    return "\n".join(out_lines).strip()


def call_openai_llm(system_prompt: str, user_prompt: str) -> str:
    """Run a single-turn chat completion and return the stripped reply text."""
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    # Low temperature keeps guideline summaries close to deterministic.
    response = client.chat.completions.create(
        model=LLM_MODEL,
        messages=messages,
        temperature=0.1,
    )
    return response.choices[0].message.content.strip()


# ===================== Agent =====================
def agent_answer(t, n, m, er, pr, her2, question):
    """Answer an MDT question via multi-facet retrieval + constrained LLM summary.

    Runs four fixed retrieval facets (staging, locoregional, systemic,
    follow-up) plus an optional neoadjuvant facet when the question
    mentions it, assembles a citation-tagged context, asks the LLM for a
    strictly formatted summary, then sanitizes the answer so every line
    carries an allowed citation or an NA marker.

    Returns a (answer_markdown, evidence_preview) tuple of strings.
    """
    if index is None:
        return "⚠️ 請先建立知識庫", ""

    patient = f"TNM={t}{n}{m};ER={er};PR={pr};HER2={her2}"
    user_q = (question or "").strip()

    q_lower = user_q.lower()
    # Detect whether the question touches neoadjuvant / preoperative therapy.
    mention_neoadj = any(x in q_lower for x in ["術前", "新輔助", "neoadjuvant", "preoperative"])

    # Keyword lists used to rerank vector hits within each facet.
    kw_stage = ["stage", "ajcc", "prognostic", "pathological", "tnm", "table"]
    kw_loco = ["locoregional", "bcs", "mastect", "surgery", "axillary", "slnb", "alnd", "margin", "rt", "radiation"]
    kw_sys = ["systemic", "adjuvant", "chemotherapy", "endocrine", "tamoxifen", "aromatase", "ai", "her2",
              "trastuz", "pertuz", "t-dm1", "taxane", "anthracycline", "sequence", "sequential"]
    kw_fu = ["follow", "surveillance", "history", "physical", "mammogram", "ultrasound", "mri", "cea", "ca15",
             "ca153", "ct", "pet", "bone", "x-ray", "lft", "alp", "cbc"]

    # One natural-language query per facet, seeded with the patient profile.
    q_stage = f"breast cancer staging table prognostic stage {patient}"
    q_loco = f"breast cancer locoregional treatment axillary staging radiation {patient}"
    q_sys = f"breast cancer adjuvant systemic therapy endocrine chemotherapy HER2 therapy sequencing {patient}"
    q_fu = f"breast cancer post-therapy surveillance follow-up schedule tests imaging labs {patient}"

    top_stage = retrieve(q_stage, kw_stage)
    top_loco = retrieve(q_loco, kw_loco)
    top_sys = retrieve(q_sys, kw_sys)
    top_fu = retrieve(q_fu, kw_fu)

    # Extra facet only when the question asks about neoadjuvant therapy.
    top_neoadj = []
    if mention_neoadj:
        kw_neoadj = ["neoadjuvant", "preoperative", "marker", "restaged", "tad", "dual tracers", "cn(+)", "fna", "cnb"]
        q_neoadj = f"breast cancer neoadjuvant preoperative systemic therapy axillary marker TAD SLNB {patient}"
        top_neoadj = retrieve(q_neoadj, kw_neoadj, k_vec=K_VEC, k_final=5)

    # De-duplicate across facets and cap the context size.
    selected = uniq_by_key(top_stage + top_loco + top_sys + top_fu + top_neoadj)[:K_MAX_CONTEXT]

    context = "\n\n---\n\n".join([f"[{c['source']}|{c['id']}]\n{c['text']}" for c in selected])
    evidence_preview = make_evidence_preview(selected)

    system_prompt = """
你是「乳癌指引整合智能體(NCCN + 院內)」。
你只做「文件查詢整理」,用於支援 MDT 會議快速核對指引;你不是臨床決策系統。

輸出格式(必須嚴格遵守):
【MDT摘要 1/3】
- ...(最多 8 行)
【MDT摘要 2/3】
- ...(最多 8 行)
【MDT摘要 3/3】
- ...(最多 8 行)
【治療路徑流程圖】
- 使用純文字流程(→ │ ├─ └─),限制在 10–18 行

核心規則(必須遵守):
1) 輔助導向:不要把回答寫成「建議做某治療」,改寫成「文件列出/文件指出/文件條列」的可核對要點。
2) 每一行末尾必須附 1 個引用:〔來源|ID〕(來源只能是 NCCN 或 院內;ID 必須來自提供的 [來源|ID])
3) 不允許常識補充:任何治療/追蹤/適應症/頻率/流程,必須由文件段落支持。
4) 文件沒寫到就輸出:此版檢索未涵蓋,故不推測〔NA〕
5) 不要要求使用者補文件、不要寫背景長文、不要說「截圖」。
6) 流程圖每個分支終點行必須附 citation;若無法引用則用 NA,不可補內容。
""".strip()

    user_prompt = f"""
【病患資料】
{patient}

【MDT問題】
{user_q}

【檢索到的文件段落(只能用這些)】
{context}
""".strip()

    try:
        raw = call_openai_llm(system_prompt, user_prompt)
        answer = sanitize_answer(raw)
        return answer, evidence_preview
    except Exception as e:
        # Surface LLM failures in the answer pane; evidence is still shown.
        return f"⚠️ LLM 呼叫失敗:{e}", evidence_preview


# ===================== Build KB =====================
def build_all(nccn_file, local_file):
    """Build the knowledge base from the two uploaded guideline PDFs.

    Extracts text from both PDFs, chunks them into paragraphs, assigns
    citation IDs (NCCN-### / LOCAL-###), and rebuilds the FAISS index.
    Returns a status string for the UI; on failure the global index is
    reset to None so queries are refused.
    """
    global chunks, index

    if nccn_file is None or local_file is None:
        return "⚠️ 請同時上傳 NCCN 與院內指引 PDF"

    try:
        nccn_text = load_pdf_text(nccn_file)
        local_text = load_pdf_text(local_file)
    except Exception as e:
        index = None
        chunks = []
        return f"⚠️ PDF 讀取失敗:{e}"

    nccn_paras = split_into_paragraphs(nccn_text, "NCCN")
    local_paras = split_into_paragraphs(local_text, "院內")

    # Assign 1-based, zero-padded citation IDs per source.
    new_chunks = []
    for i, p in enumerate(nccn_paras, start=1):
        new_chunks.append({"id": f"NCCN-{i:03d}", "source": "NCCN", "text": p})
    for i, p in enumerate(local_paras, start=1):
        new_chunks.append({"id": f"LOCAL-{i:03d}", "source": "院內", "text": p})

    chunks = new_chunks

    try:
        if len(chunks) == 0:
            index = None
            return "⚠️ 無可用文字段落(兩份 PDF 皆未抽到有效文字)"

        build_faiss_index(chunks)
    except Exception as e:
        index = None
        return f"⚠️ 索引建立失敗:{e}"

    # Explicitly explain a zero-paragraph NCCN result (doesn't block usage,
    # but makes the condition visible for acceptance checks).
    if len(nccn_paras) == 0:
        extra = "(NCCN 未抽到可用文字:PDF 可能為掃描/受保護層,extract_text 為空或全為頁眉版權)"
    else:
        extra = ""

    return f"索引建立完成:NCCN {len(nccn_paras)} 段、院內 {len(local_paras)} 段,共 {len(chunks)}{extra}"


# ===================== UI =====================
# Gradio UI: knowledge-base build section on top, query section below.
with gr.Blocks() as demo:
    gr.Markdown("## 乳癌指引整合智能體(NCCN + 院內)")
    gr.Markdown("### 上傳最新版 NCCN 與院內指引 PDF(外部上傳;不內建)")

    # --- Knowledge-base construction ---
    nccn_file = gr.File(label="上傳 NCCN PDF")
    local_file = gr.File(label="上傳 院內指引 PDF")
    build_btn = gr.Button("建立知識庫")
    status = gr.Textbox(label="狀態")
    build_btn.click(build_all, [nccn_file, local_file], status)

    # --- Patient profile inputs (TNM stage + receptor status) ---
    gr.Markdown("### 臨床查詢")
    with gr.Row():
        t = gr.Dropdown(["T0", "T1", "T2", "T3", "T4"], label="T", value="T1")
        n = gr.Dropdown(["N0", "N1", "N2", "N3"], label="N", value="N0")
        m = gr.Dropdown(["M0", "M1"], label="M", value="M0")

    with gr.Row():
        er = gr.Dropdown(["Positive", "Negative"], label="ER", value="Positive")
        pr = gr.Dropdown(["Positive", "Negative"], label="PR", value="Positive")
        her2 = gr.Dropdown(["Positive", "Negative"], label="HER2", value="Negative")

    question = gr.Textbox(
        label="臨床問題(MDT 用)",
        lines=3,
        placeholder="例如:初次確診的治療路徑怎麼走?請依指引整理 MDT 會議核對重點。"
    )

    # --- Query + answer display ---
    ask_btn = gr.Button("智能體分析")
    answer = gr.Markdown()

    # Retrieved source passages, collapsed by default.
    with gr.Accordion("證據預覽(檢索到的段落)", open=False):
        evidence = gr.Markdown()

    ask_btn.click(agent_answer, [t, n, m, er, pr, her2, question], [answer, evidence])

demo.launch()